Source code for sagemaker.hyperpod.inference.hp_endpoint

  1from sagemaker.hyperpod.common.config.metadata import Metadata
  2from sagemaker.hyperpod.inference.config.constants import *
  3from sagemaker.hyperpod.common.utils import (
  4    get_default_namespace,
  5    get_cluster_instance_types,
  6    setup_logging,
  7    get_current_cluster,
  8    get_current_region,
  9)
 10from sagemaker.hyperpod.inference.config.hp_endpoint_config import (
 11    InferenceEndpointConfigStatus,
 12    _HPEndpoint,
 13)
 14from sagemaker.hyperpod.common.telemetry.telemetry_logging import (
 15    _hyperpod_telemetry_emitter,
 16)
 17from sagemaker.hyperpod.common.telemetry.constants import Feature
 18from sagemaker.hyperpod.inference.hp_endpoint_base import HPEndpointBase
 19from typing import Dict, List, Optional
 20from sagemaker_core.main.resources import Endpoint
 21from pydantic import Field, ValidationError
 22from kubernetes import client
 23
 24
class HPEndpoint(_HPEndpoint, HPEndpointBase):
    """Client-side handle for a HyperPod inference endpoint.

    The fields inherited from ``_HPEndpoint`` form the ``spec`` of an
    ``INFERENCE_ENDPOINT_CONFIG_KIND`` resource. That resource is created, read
    and deleted through the ``call_*_api`` helpers inherited from
    ``HPEndpointBase``.
    """

    # Resource name/namespace plus optional labels/annotations. Filled in by
    # create()/get(); required by refresh() and delete().
    metadata: Optional[Metadata] = Field(default=None)
    # Last observed resource status. None until populated by get()/refresh(),
    # or when the reported status is missing or does not validate.
    status: Optional[InferenceEndpointConfigStatus] = Field(default=None)

    @staticmethod
    def _parse_status(raw_status) -> Optional[InferenceEndpointConfigStatus]:
        """Parse a raw ``status`` payload, tolerating absent or malformed data.

        Args:
            raw_status: The ``status`` value of a resource dict, or None.

        Returns:
            The parsed status. Returns None when ``raw_status`` is None
            (presumably the resource has not reported a status yet) or when
            it fails validation.
        """
        if raw_status is None:
            return None
        try:
            return InferenceEndpointConfigStatus.model_validate(
                raw_status, by_name=True
            )
        except ValidationError:
            return None

    @classmethod
    def _from_response(cls, response: Dict) -> "HPEndpoint":
        """Build an endpoint object from a raw resource dict.

        ``response`` must contain ``spec`` and ``metadata`` keys; ``status`` is
        optional. This is the shape read from ``call_get_api``. Items returned
        by ``call_list_api`` are assumed to have the same shape — TODO confirm
        against HPEndpointBase.
        """
        endpoint = cls.model_validate(response["spec"], by_name=True)
        endpoint.status = cls._parse_status(response.get("status"))
        endpoint.metadata = Metadata.model_validate(
            response["metadata"], by_name=True
        )
        return endpoint

    def _require_metadata(self) -> None:
        """Raise if this object has no metadata to address the resource with."""
        if not self.metadata:
            raise Exception(
                "Metadata not found! Please provide object name and namespace in metadata field."
            )

    def _create_internal(self, spec, debug=False):
        """Create the endpoint resource from ``spec`` (shared by create paths).

        The resource name comes from ``self.metadata.name`` and falls back to
        ``spec.endpointName``. The namespace falls back to the default
        namespace. On success, ``self.metadata`` is replaced with the metadata
        actually submitted.

        Args:
            spec: A ``_HPEndpoint`` holding the resource spec.
            debug: Passed to ``setup_logging`` and ``call_create_api``.

        Raises:
            Exception: If neither a metadata name nor ``spec.endpointName`` is
                set, or if the instance type is not available in the current
                HyperPod cluster (see ``validate_instance_type``).
        """
        logger = self.get_logger()
        logger = setup_logging(logger, debug)

        name = self.metadata.name if self.metadata else None
        namespace = self.metadata.namespace if self.metadata else None

        if not spec.endpointName and not name:
            raise Exception('Either metadata name or endpoint name must be provided')

        if not namespace:
            namespace = get_default_namespace()

        if not name:
            name = spec.endpointName

        # Carry over user-supplied labels and annotations, if any.
        metadata = Metadata(
            name=name,
            namespace=namespace,
            labels=self.metadata.labels if self.metadata else None,
            annotations=self.metadata.annotations if self.metadata else None,
        )

        self.validate_instance_type(spec.instanceType)

        self.call_create_api(
            metadata=metadata,
            kind=INFERENCE_ENDPOINT_CONFIG_KIND,
            spec=spec,
            debug=debug,
        )

        self.metadata = metadata

        logger.info(
            f"Creating sagemaker model and endpoint. Endpoint name: {spec.endpointName}.\n The process may take a few minutes..."
        )

    @_hyperpod_telemetry_emitter(Feature.HYPERPOD, "create_endpoint")
    def create(
        self,
        debug=False
    ) -> None:
        """Create the endpoint resource from this object's own fields."""
        spec = _HPEndpoint(**self.model_dump(by_alias=True, exclude_none=True))
        self._create_internal(spec, debug)

    @_hyperpod_telemetry_emitter(Feature.HYPERPOD, "create_endpoint_from_dict")
    def create_from_dict(
        self,
        input: Dict,
        debug=False
    ) -> None:
        """Create the endpoint resource from a raw spec dictionary.

        ``input`` is validated into a ``_HPEndpoint`` (by field name). This
        object's ``metadata`` still supplies the resource name and namespace.
        """
        spec = _HPEndpoint.model_validate(input, by_name=True)
        self._create_internal(spec, debug)

    def refresh(self):
        """Re-read the resource from the cluster and update ``self.status``.

        A missing or non-validating status block yields ``status = None``.
        This matches ``get``; previously a missing status raised ``KeyError``.

        Returns:
            self, to allow chaining.

        Raises:
            Exception: If ``self.metadata`` is not set.
        """
        self._require_metadata()

        response = self.call_get_api(
            name=self.metadata.name,
            kind=INFERENCE_ENDPOINT_CONFIG_KIND,
            namespace=self.metadata.namespace,
        )

        self.status = self._parse_status(response.get("status"))

        return self

    @classmethod
    @_hyperpod_telemetry_emitter(Feature.HYPERPOD, "list_endpoints")
    def list(
        cls,
        namespace: Optional[str] = None,
    ) -> List["HPEndpoint"]:
        """List endpoint resources in ``namespace`` (default namespace if unset).

        Objects are built directly from the list response. The previous code
        issued one follow-up ``get`` per item, which cost N extra API calls
        and failed if an item disappeared in between.
        """
        if not namespace:
            namespace = get_default_namespace()

        response = cls.call_list_api(
            kind=INFERENCE_ENDPOINT_CONFIG_KIND,
            namespace=namespace,
        )

        items = (response or {}).get("items") or []
        return [cls._from_response(item) for item in items]

    @classmethod
    @_hyperpod_telemetry_emitter(Feature.HYPERPOD, "get_endpoint")
    def get(cls, name: str, namespace: Optional[str] = None) -> "HPEndpoint":
        """Fetch one endpoint resource by name.

        Args:
            name: Resource (metadata) name.
            namespace: Namespace to look in; defaults to the default namespace.

        Returns:
            The endpoint, with ``metadata`` set and ``status`` parsed (None if
            absent or invalid).
        """
        if not namespace:
            namespace = get_default_namespace()

        response = cls.call_get_api(
            name=name,
            kind=INFERENCE_ENDPOINT_CONFIG_KIND,
            namespace=namespace,
        )

        return cls._from_response(response)

    @_hyperpod_telemetry_emitter(Feature.HYPERPOD, "delete_endpoint")
    def delete(self) -> None:
        """Request deletion of this endpoint's resource.

        Raises:
            Exception: If ``self.metadata`` is not set. Previously this case
                surfaced as an ``AttributeError`` on ``None``.
        """
        self._require_metadata()

        logger = self.get_logger()
        logger = setup_logging(logger)

        self.call_delete_api(
            name=self.metadata.name,
            kind=INFERENCE_ENDPOINT_CONFIG_KIND,
            namespace=self.metadata.namespace,
        )
        logger.info(f"Deleting HPEndpoint: {self.metadata.name}...")

    @_hyperpod_telemetry_emitter(Feature.HYPERPOD, "invoke_endpoint")
    def invoke(self, body, content_type="application/json"):
        """Invoke the SageMaker endpoint named ``self.endpointName``.

        The endpoint is looked up with sagemaker_core ``Endpoint.get`` in the
        current region, and ``body``/``content_type`` are forwarded to its
        ``invoke``.

        Raises:
            Exception: If ``self.endpointName`` is not set.
        """
        if not self.endpointName:
            raise Exception("SageMaker endpoint name not found in this object!")

        endpoint = Endpoint.get(self.endpointName, region=get_current_region())

        return endpoint.invoke(body=body, content_type=content_type)

    def validate_instance_type(self, instance_type: str):
        """Check that ``instance_type`` exists in the current HyperPod cluster.

        If the cluster's instance types cannot be fetched (or the lookup
        returns nothing), a warning is logged and validation is skipped.

        Raises:
            Exception: If the cluster reports instance types and
                ``instance_type`` is not among them.
        """
        logger = self.get_logger()
        logger = setup_logging(logger)

        cluster_instance_types = None

        # Best effort: failure to query the cluster must not block creation.
        try:
            cluster_instance_types = get_cluster_instance_types(
                cluster=get_current_cluster(),
                region=get_current_region(),
            )
        except Exception as e:
            logger.warning(f"Failed to get instance types from HyperPod cluster: {e}")

        if cluster_instance_types and (instance_type not in cluster_instance_types):
            raise Exception(
                f"Current HyperPod cluster does not have instance type {instance_type}. Supported instance types are {cluster_instance_types}"
            )

    @classmethod
    @_hyperpod_telemetry_emitter(Feature.HYPERPOD, "list_pods_endpoint")
    def list_pods(cls, namespace=None, endpoint_name=None):
        """List names of pods in ``namespace`` that belong to endpoint resources.

        A pod matches when its ``app`` label equals ``endpoint_name``. If no
        ``endpoint_name`` is given, a pod matches when its ``app`` label equals
        the name of any endpoint resource in the namespace.

        Args:
            namespace: Namespace to search; defaults to the default namespace.
            endpoint_name: Restrict results to this endpoint resource name.

        Returns:
            A list of pod names.
        """
        cls.verify_kube_config()

        if not namespace:
            namespace = get_default_namespace()

        v1 = client.CoreV1Api()
        list_pods_response = v1.list_namespaced_pod(namespace=namespace)

        if endpoint_name:
            endpoints = {endpoint_name}
        else:
            list_response = cls.call_list_api(
                kind=INFERENCE_ENDPOINT_CONFIG_KIND,
                namespace=namespace,
            )
            items = (list_response or {}).get("items") or []
            endpoints = {item["metadata"]["name"] for item in items}

        # list_namespaced_pod returns every pod in the namespace, so keep only
        # pods whose "app" label names one of the endpoints. Pods without
        # labels have metadata.labels == None, hence the `or {}`.
        return [
            pod.metadata.name
            for pod in list_pods_response.items
            if (pod.metadata.labels or {}).get("app") in endpoints
        ]