Source code for sagemaker.hyperpod.inference.hp_jumpstart_endpoint

  1from typing import Dict, List, Optional
  2from pydantic import Field, ValidationError
  3from sagemaker.hyperpod.inference.config.constants import *
  4from sagemaker.hyperpod.inference.constant import INSTANCE_MIG_PROFILES
  5from sagemaker.hyperpod.inference.hp_endpoint_base import HPEndpointBase
  6from sagemaker.hyperpod.common.config.metadata import Metadata
  7from sagemaker.hyperpod.common.utils import (
  8    get_current_cluster,
  9    get_current_region,
 10    get_jumpstart_model_instance_types,
 11    get_cluster_instance_types,
 12    get_default_namespace,
 13    setup_logging,
 14)
 15from sagemaker_core.main.resources import Endpoint
 16from sagemaker.hyperpod.inference.config.hp_jumpstart_endpoint_config import (
 17    _HPJumpStartEndpoint,
 18    JumpStartModelStatus,
 19)
 20from sagemaker.hyperpod.common.telemetry.telemetry_logging import (
 21    _hyperpod_telemetry_emitter,
 22)
 23from sagemaker.hyperpod.common.telemetry.constants import Feature
 24from kubernetes import client
 25
 26
class HPJumpStartEndpoint(_HPJumpStartEndpoint, HPEndpointBase):
    """JumpStart model endpoint managed as a custom resource on a HyperPod cluster.

    Extends the ``_HPJumpStartEndpoint`` spec model with Kubernetes metadata
    and the last observed status, and exposes create / get / list / refresh /
    delete / invoke operations on top of the inherited ``call_*_api`` helpers.
    """

    # Kubernetes object metadata (name, namespace, labels, annotations).
    metadata: Optional[Metadata] = Field(default=None)
    # Status as last read from the cluster by ``get`` or ``refresh``.
    status: Optional[JumpStartModelStatus] = Field(default=None)

    def _create_internal(self, spec, debug=False):
        """Shared create logic used by ``create`` and ``create_from_dict``.

        Resolves the resource name and namespace, validates the requested
        instance type (or MIG profile), submits the resource through
        ``call_create_api`` and stores the resolved metadata on ``self``.

        Args:
            spec: Validated ``_HPJumpStartEndpoint`` spec to submit.
            debug: Forwarded to ``setup_logging`` and ``call_create_api``.

        Raises:
            Exception: If neither ``metadata.name`` nor the SageMaker endpoint
                name is set, or if instance-type validation fails.
            ValueError: If the MIG profile is not valid for the instance type.
        """
        logger = setup_logging(self.get_logger(), debug)

        endpoint_name = ""
        if spec.sageMakerEndpoint and spec.sageMakerEndpoint.name:
            endpoint_name = spec.sageMakerEndpoint.name

        name = self.metadata.name if self.metadata else None
        namespace = self.metadata.namespace if self.metadata else None

        if not endpoint_name and not name:
            raise Exception("Either metadata name or endpoint name must be provided")

        # Fall back to the endpoint name / default namespace when unset.
        if not name:
            name = endpoint_name
        if not namespace:
            namespace = get_default_namespace()

        # Carry over user-supplied labels and annotations, if any.
        metadata = Metadata(
            name=name,
            namespace=namespace,
            labels=self.metadata.labels if self.metadata else None,
            annotations=self.metadata.annotations if self.metadata else None,
        )

        # Without an accelerator partition type, check the instance type
        # against the model and the cluster; with one, check that the MIG
        # profile is valid for the instance type instead.
        if not spec.server.acceleratorPartitionType:
            self.validate_instance_type(spec.model.modelId, spec.server.instanceType)
        else:
            self.validate_mig_profile(
                spec.server.acceleratorPartitionType, spec.server.instanceType
            )

        self.call_create_api(
            metadata=metadata,
            kind=JUMPSTART_MODEL_KIND,
            spec=spec,
            debug=debug,
        )

        self.metadata = metadata

        logger.info(
            f"Creating JumpStart model and sagemaker endpoint. Endpoint name: {endpoint_name}.\n The process may take a few minutes..."
        )

    @_hyperpod_telemetry_emitter(Feature.HYPERPOD, "create_js_endpoint")
    def create(self, debug=False) -> None:
        """Create the endpoint from the fields set on this object.

        Args:
            debug: Enable debug logging for the create call.
        """
        # Logger setup happens inside _create_internal.
        spec = _HPJumpStartEndpoint(**self.model_dump(by_alias=True, exclude_none=True))
        self._create_internal(spec, debug)

    @_hyperpod_telemetry_emitter(Feature.HYPERPOD, "create_js_endpoint_from_dict")
    def create_from_dict(self, input: Dict, debug=False) -> None:
        """Create the endpoint from a dictionary describing the spec.

        The dictionary is validated into an ``_HPJumpStartEndpoint`` and then
        goes through the same path as ``create``, so both entry points submit
        the resource identically and keep labels/annotations from
        ``self.metadata``.

        Args:
            input: Mapping with the spec fields (validated by field name).
            debug: Enable debug logging for the create call.

        Raises:
            Exception: If neither ``metadata.name`` nor the SageMaker endpoint
                name is available.
        """
        spec = _HPJumpStartEndpoint.model_validate(input, by_name=True)
        self._create_internal(spec, debug)

    def refresh(self):
        """Re-read this endpoint's status from the cluster.

        Returns:
            This object, with ``status`` updated (``None`` when the response
            carries no status).

        Raises:
            Exception: If ``metadata`` is not set.
        """
        if not self.metadata:
            raise Exception(
                "Metadata is empty. Please provide name and namespace in metadata field."
            )

        response = HPJumpStartEndpoint.call_get_api(
            name=self.metadata.name,
            kind=JUMPSTART_MODEL_KIND,
            namespace=self.metadata.namespace,
        )

        if isinstance(response, dict) and "status" in response:
            self.status = JumpStartModelStatus.model_validate(
                response["status"], by_name=True
            )
        else:
            self.status = None

        return self

    @classmethod
    @_hyperpod_telemetry_emitter(Feature.HYPERPOD, "list_js_endpoints")
    def list(
        cls,
        namespace: Optional[str] = None,
    ) -> List["HPJumpStartEndpoint"]:
        """List JumpStart endpoints in a namespace.

        Args:
            namespace: Namespace to query; the default namespace when omitted.

        Returns:
            One fully populated endpoint object (fetched via ``get``) per
            listed resource; an empty list when there are none.
        """
        if not namespace:
            namespace = get_default_namespace()

        response = cls.call_list_api(
            kind=JUMPSTART_MODEL_KIND,
            namespace=namespace,
        )

        # Tolerate an empty response or a response without an "items" key.
        items = (response.get("items") if response else None) or []
        return [
            cls.get(item["metadata"]["name"], namespace=namespace) for item in items
        ]

    @classmethod
    @_hyperpod_telemetry_emitter(Feature.HYPERPOD, "get_js_endpoint")
    def get(cls, name: str, namespace: Optional[str] = None):
        """Fetch a single JumpStart endpoint by name.

        Args:
            name: Resource name.
            namespace: Namespace to query; the default namespace when omitted.

        Returns:
            An ``HPJumpStartEndpoint`` built from the response's ``spec``,
            with ``metadata`` and (when parseable) ``status`` attached.

        Raises:
            Exception: If the API response is not a dictionary.
        """
        if not namespace:
            namespace = get_default_namespace()

        response = cls.call_get_api(
            name=name,
            kind=JUMPSTART_MODEL_KIND,
            namespace=namespace,
        )

        if not isinstance(response, dict):
            raise Exception(f"Expected dictionary response, got {type(response)}")

        endpoint = HPJumpStartEndpoint.model_validate(response["spec"], by_name=True)

        endpoint.status = None
        status = response.get("status")
        if status is not None:
            try:
                endpoint.status = JumpStartModelStatus.model_validate(
                    status, by_name=True
                )
            except ValidationError:
                # A status that does not validate is treated as absent.
                endpoint.status = None

        endpoint.metadata = Metadata.model_validate(response["metadata"], by_name=True)

        return endpoint

    @_hyperpod_telemetry_emitter(Feature.HYPERPOD, "delete_js_endpoint")
    def delete(self) -> None:
        """Delete this endpoint's resource from the cluster.

        Raises:
            Exception: If ``metadata`` is not set, since name and namespace
                are needed to address the resource.
        """
        logger = setup_logging(self.get_logger())

        # Same guard as refresh(): fail with a clear message rather than an
        # AttributeError on None.
        if not self.metadata:
            raise Exception(
                "Metadata is empty. Please provide name and namespace in metadata field."
            )

        self.call_delete_api(
            name=self.metadata.name,
            kind=JUMPSTART_MODEL_KIND,
            namespace=self.metadata.namespace,
        )
        logger.info(
            f"Deleting JumpStart model and sagemaker endpoint: {self.metadata.name}. This may take a few minutes..."
        )

    @_hyperpod_telemetry_emitter(Feature.HYPERPOD, "invoke_js_endpoint")
    def invoke(self, body, content_type="application/json"):
        """Invoke the SageMaker endpoint backing this model.

        Args:
            body: Request payload passed to ``Endpoint.invoke``.
            content_type: Content type of ``body``.

        Returns:
            Whatever ``Endpoint.invoke`` returns.

        Raises:
            Exception: If no SageMaker endpoint name is set on this object.
        """
        if not self.sageMakerEndpoint or not self.sageMakerEndpoint.name:
            raise Exception("SageMaker endpoint name not found in this object!")

        endpoint = Endpoint.get(
            self.sageMakerEndpoint.name, region=get_current_region()
        )

        return endpoint.invoke(body=body, content_type=content_type)

    def validate_instance_type(self, model_id: str, instance_type: str):
        """Check an instance type against the model and the current cluster.

        Each lookup is best-effort: if the supported types cannot be fetched,
        a warning is logged and that check is skipped.

        Args:
            model_id: JumpStart model identifier.
            instance_type: Instance type requested for the endpoint.

        Raises:
            Exception: If the instance type is not in the model's supported
                types or not among the cluster's instance types.
        """
        logger = setup_logging(self.get_logger())

        model_types = None
        cluster_instance_types = None

        # Supported instance types according to the model hub.
        try:
            model_types = get_jumpstart_model_instance_types(
                model_id, get_current_region()
            )
        except Exception as e:
            logger.warning(
                f"Failed to fetch supported instance type from SageMakerPublicHub content: {e}"
            )

        if model_types and (instance_type not in model_types):
            raise Exception(
                f"Instance type {instance_type} not supported by JumpStart model {model_id}. Supported types are {model_types}"
            )

        # Instance types available in the HyperPod cluster.
        try:
            cluster_instance_types = get_cluster_instance_types(
                cluster=get_current_cluster(),
                region=get_current_region(),
            )
        except Exception as e:
            logger.warning(f"Failed to get instance types from HyperPod cluster: {e}")

        if cluster_instance_types and (instance_type not in cluster_instance_types):
            raise Exception(
                f"Current HyperPod cluster does not have instance type {instance_type}. Supported instance types are {cluster_instance_types}"
            )

    def validate_mig_profile(self, mig_profile: str, instance_type: str):
        """Validate that a MIG profile is supported for the given instance type.

        Args:
            mig_profile: MIG profile (e.g., "1g.10gb").
            instance_type: SageMaker instance type (e.g., "ml.p4d.24xlarge").

        Raises:
            ValueError: If the instance type has no entry in
                ``INSTANCE_MIG_PROFILES`` or the profile is not listed for it.
        """
        logger = setup_logging(self.get_logger())

        if instance_type not in INSTANCE_MIG_PROFILES:
            error_msg = (
                f"Instance type '{instance_type}' does not support MIG profiles. "
                f"Supported instance types: {list(INSTANCE_MIG_PROFILES.keys())}"
            )
            logger.error(error_msg)
            raise ValueError(error_msg)

        if mig_profile not in INSTANCE_MIG_PROFILES[instance_type]:
            error_msg = (
                f"MIG profile '{mig_profile}' is not supported for instance type '{instance_type}'. "
                f"Supported MIG profiles for {instance_type}: {INSTANCE_MIG_PROFILES[instance_type]}"
            )
            logger.error(error_msg)
            raise ValueError(error_msg)

        logger.info(
            f"MIG profile '{mig_profile}' is valid for instance type '{instance_type}'"
        )

    @classmethod
    @_hyperpod_telemetry_emitter(Feature.HYPERPOD, "list_pods_endpoint")
    def list_pods(cls, namespace=None, endpoint_name=None):
        """List names of pods that belong to JumpStart endpoints.

        A pod is matched when its ``app`` label equals an endpoint name.

        Args:
            namespace: Namespace to query; the default namespace when omitted.
            endpoint_name: Restrict the result to this endpoint. When omitted,
                all JumpStart endpoints listed in the namespace are used.

        Returns:
            List of matching pod names.
        """
        cls.verify_kube_config()

        if not namespace:
            namespace = get_default_namespace()

        v1 = client.CoreV1Api()
        list_pods_response = v1.list_namespaced_pod(namespace=namespace)

        endpoints = set()
        if endpoint_name:
            endpoints.add(endpoint_name)
        else:
            list_response = cls.call_list_api(
                kind=JUMPSTART_MODEL_KIND,
                namespace=namespace,
            )
            items = (list_response.get("items") if list_response else None) or []
            endpoints.update(item["metadata"]["name"] for item in items)

        # list_namespaced_pod returns every pod in the namespace, so keep only
        # those whose "app" label names a JumpStart endpoint. Pods without any
        # labels have metadata.labels set to None, hence the `or {}`.
        return [
            pod.metadata.name
            for pod in list_pods_response.items
            if (pod.metadata.labels or {}).get("app") in endpoints
        ]