1from typing import Dict, List, Optional
2from pydantic import Field, ValidationError
3from sagemaker.hyperpod.inference.config.constants import *
4from sagemaker.hyperpod.inference.constant import INSTANCE_MIG_PROFILES
5from sagemaker.hyperpod.inference.hp_endpoint_base import HPEndpointBase
6from sagemaker.hyperpod.common.config.metadata import Metadata
7from sagemaker.hyperpod.common.utils import (
8 get_current_cluster,
9 get_current_region,
10 get_jumpstart_model_instance_types,
11 get_cluster_instance_types,
12 get_default_namespace,
13 setup_logging,
14)
15from sagemaker_core.main.resources import Endpoint
16from sagemaker.hyperpod.inference.config.hp_jumpstart_endpoint_config import (
17 _HPJumpStartEndpoint,
18 JumpStartModelStatus,
19)
20from sagemaker.hyperpod.common.telemetry.telemetry_logging import (
21 _hyperpod_telemetry_emitter,
22)
23from sagemaker.hyperpod.common.telemetry.constants import Feature
24from kubernetes import client
25
26
[docs]
class HPJumpStartEndpoint(_HPJumpStartEndpoint, HPEndpointBase):
    """JumpStart model endpoint handled as a ``JUMPSTART_MODEL_KIND`` resource.

    The spec fields come from ``_HPJumpStartEndpoint``; the ``call_*_api``
    helpers, ``get_logger`` and ``verify_kube_config`` used below are inherited
    from the base classes.
    """

    # Object metadata (name, namespace, labels, annotations) of the resource.
    metadata: Optional[Metadata] = Field(default=None)
    # Status last read back from the cluster; set by ``get`` and ``refresh``.
    status: Optional[JumpStartModelStatus] = Field(default=None)

    def _create_internal(self, spec, debug=False):
        """Resolve metadata, validate the server config and submit ``spec``.

        This is the single create path shared by ``create`` and
        ``create_from_dict``.

        Args:
            spec: ``_HPJumpStartEndpoint`` spec to submit.
            debug: Forwarded to ``setup_logging`` and ``call_create_api``.

        Raises:
            Exception: If neither ``metadata.name`` nor
                ``spec.sageMakerEndpoint.name`` is set, or if the instance
                type is rejected by ``validate_instance_type``.
            ValueError: If the MIG profile is rejected by
                ``validate_mig_profile``.
        """
        logger = setup_logging(self.get_logger(), debug)

        endpoint_name = ""
        if spec.sageMakerEndpoint and spec.sageMakerEndpoint.name:
            endpoint_name = spec.sageMakerEndpoint.name

        name = self.metadata.name if self.metadata else None
        namespace = self.metadata.namespace if self.metadata else None

        if not endpoint_name and not name:
            raise Exception("Either metadata name or endpoint name must be provided")

        # Fall back to the endpoint name / default namespace when the metadata
        # does not supply them.
        name = name or endpoint_name
        namespace = namespace or get_default_namespace()

        # Carry over labels and annotations when metadata was supplied.
        metadata = Metadata(
            name=name,
            namespace=namespace,
            labels=self.metadata.labels if self.metadata else None,
            annotations=self.metadata.annotations if self.metadata else None,
        )

        # With an accelerator partition type set, validate it as a MIG profile
        # for the instance type; otherwise validate the instance type itself.
        if spec.server.acceleratorPartitionType:
            self.validate_mig_profile(
                spec.server.acceleratorPartitionType, spec.server.instanceType
            )
        else:
            self.validate_instance_type(spec.model.modelId, spec.server.instanceType)

        self.call_create_api(
            metadata=metadata,
            kind=JUMPSTART_MODEL_KIND,
            spec=spec,
            debug=debug,
        )

        # Only record the resolved metadata once the create call has returned.
        self.metadata = metadata

        logger.info(
            f"Creating JumpStart model and sagemaker endpoint. Endpoint name: {endpoint_name}.\n The process may take a few minutes..."
        )

    @_hyperpod_telemetry_emitter(Feature.HYPERPOD, "create_js_endpoint")
    def create(
        self,
        debug=False
    ) -> None:
        """Create the endpoint from this object's own fields.

        Args:
            debug: Forwarded to the shared create logic.
        """
        # Rebuild a bare spec from this model, dropping unset fields.
        # Logging is configured inside ``_create_internal``.
        spec = _HPJumpStartEndpoint(**self.model_dump(by_alias=True, exclude_none=True))
        self._create_internal(spec, debug)

    @_hyperpod_telemetry_emitter(Feature.HYPERPOD, "create_js_endpoint_from_dict")
    def create_from_dict(self, input: Dict, debug=False) -> None:
        """Create the endpoint from a spec supplied as a dictionary.

        The dictionary is validated into an ``_HPJumpStartEndpoint`` and then
        goes through the same path as ``create``, so name/namespace
        resolution, label/annotation propagation and validation are identical.

        Args:
            input: Spec fields, validated with ``by_name=True``.
            debug: Forwarded to the shared create logic.

        Raises:
            Exception: If neither ``metadata.name`` nor the spec's SageMaker
                endpoint name is set, or if instance-type validation fails.
            ValueError: If MIG-profile validation fails.
        """
        spec = _HPJumpStartEndpoint.model_validate(input, by_name=True)
        self._create_internal(spec, debug)

    def refresh(self):
        """Re-read this resource's status from the cluster.

        Returns:
            This same object, with ``status`` updated (``None`` when the
            response carries no ``status`` entry).

        Raises:
            Exception: If ``metadata`` is not set.
        """
        if not self.metadata:
            raise Exception(
                "Metadata is empty. Please provide name and namespace in metadata field."
            )

        response = HPJumpStartEndpoint.call_get_api(
            name=self.metadata.name,
            kind=JUMPSTART_MODEL_KIND,
            namespace=self.metadata.namespace,
        )

        if isinstance(response, dict) and "status" in response:
            self.status = JumpStartModelStatus.model_validate(
                response["status"], by_name=True
            )
        else:
            self.status = None

        return self

    @classmethod
    @_hyperpod_telemetry_emitter(Feature.HYPERPOD, "list_js_endpoints")
    def list(
        cls,
        namespace: Optional[str] = None,
    ) -> List["HPJumpStartEndpoint"]:
        """Return every JumpStart endpoint in ``namespace``.

        Args:
            namespace: Namespace to list; the default namespace when omitted.

        Returns:
            One object per listed item, each loaded through ``get``. Empty
            when the response is empty or has no ``items``.
        """
        if not namespace:
            namespace = get_default_namespace()

        response = cls.call_list_api(
            kind=JUMPSTART_MODEL_KIND,
            namespace=namespace,
        )

        # Tolerate an empty response or a missing/empty "items" entry.
        items = (response or {}).get("items") or []
        return [
            cls.get(item["metadata"]["name"], namespace=namespace) for item in items
        ]

    @classmethod
    @_hyperpod_telemetry_emitter(Feature.HYPERPOD, "get_js_endpoint")
    def get(cls, name: str, namespace: Optional[str] = None) -> "HPJumpStartEndpoint":
        """Load one JumpStart endpoint by name.

        Args:
            name: Resource name.
            namespace: Namespace to look in; the default namespace when omitted.

        Returns:
            The endpoint built from the response's ``spec``, with ``metadata``
            attached and ``status`` set when it is present and valid.

        Raises:
            Exception: If the API response is not a dictionary.
        """
        if not namespace:
            namespace = get_default_namespace()

        response = cls.call_get_api(
            name=name,
            kind=JUMPSTART_MODEL_KIND,
            namespace=namespace,
        )

        if not isinstance(response, dict):
            raise Exception(f"Expected dictionary response, got {type(response)}")

        endpoint = HPJumpStartEndpoint.model_validate(response["spec"], by_name=True)

        # A missing or unparsable status is reported as None rather than
        # failing the whole lookup.
        endpoint.status = None
        raw_status = response.get("status")
        if raw_status is not None:
            try:
                endpoint.status = JumpStartModelStatus.model_validate(
                    raw_status, by_name=True
                )
            except ValidationError:
                endpoint.status = None

        endpoint.metadata = Metadata.model_validate(response["metadata"], by_name=True)

        return endpoint

    @_hyperpod_telemetry_emitter(Feature.HYPERPOD, "delete_js_endpoint")
    def delete(self) -> None:
        """Delete the resource named by ``metadata``.

        Raises:
            Exception: If ``metadata`` is not set.
        """
        # Same guard as ``refresh``: fail with a clear message instead of an
        # AttributeError on ``None``.
        if not self.metadata:
            raise Exception(
                "Metadata is empty. Please provide name and namespace in metadata field."
            )

        logger = setup_logging(self.get_logger())

        self.call_delete_api(
            name=self.metadata.name,
            kind=JUMPSTART_MODEL_KIND,
            namespace=self.metadata.namespace,
        )
        logger.info(
            f"Deleting JumpStart model and sagemaker endpoint: {self.metadata.name}. This may take a few minutes..."
        )

    @_hyperpod_telemetry_emitter(Feature.HYPERPOD, "invoke_js_endpoint")
    def invoke(self, body, content_type="application/json"):
        """Invoke the SageMaker endpoint named in ``sageMakerEndpoint``.

        Args:
            body: Request payload passed to ``Endpoint.invoke``.
            content_type: Content type passed to ``Endpoint.invoke``.

        Returns:
            Whatever ``Endpoint.invoke`` returns.

        Raises:
            Exception: If no SageMaker endpoint name is set on this object.
        """
        if not self.sageMakerEndpoint or not self.sageMakerEndpoint.name:
            raise Exception("SageMaker endpoint name not found in this object!")

        endpoint = Endpoint.get(
            self.sageMakerEndpoint.name, region=get_current_region()
        )

        return endpoint.invoke(body=body, content_type=content_type)

    def validate_instance_type(self, model_id: str, instance_type: str):
        """Check ``instance_type`` against the model hub and the cluster.

        Both lookups are best-effort: if one fails, a warning is logged and
        that check is skipped.

        Args:
            model_id: JumpStart model id.
            instance_type: Instance type requested for the endpoint.

        Raises:
            Exception: If the model's supported types were fetched and do not
                include ``instance_type``, or the cluster's instance types
                were fetched and do not include it.
        """
        logger = setup_logging(self.get_logger())

        # Check 1: instance types supported by the model.
        model_types = None
        try:
            model_types = get_jumpstart_model_instance_types(
                model_id, get_current_region()
            )
        except Exception as e:
            logger.warning(
                f"Failed to fetch supported instance type from SageMakerPublicHub content: {e}"
            )

        if model_types and (instance_type not in model_types):
            raise Exception(
                f"Instance type {instance_type} not supported by JumpStart model {model_id}. Supported types are {model_types}"
            )

        # Check 2: instance types present in the current HyperPod cluster.
        cluster_instance_types = None
        try:
            cluster_instance_types = get_cluster_instance_types(
                cluster=get_current_cluster(),
                region=get_current_region(),
            )
        except Exception as e:
            logger.warning(f"Failed to get instance types from HyperPod cluster: {e}")

        if cluster_instance_types and (instance_type not in cluster_instance_types):
            raise Exception(
                f"Current HyperPod cluster does not have instance type {instance_type}. Supported instance types are {cluster_instance_types}"
            )

    def validate_mig_profile(self, mig_profile: str, instance_type: str):
        """
        Validate if the MIG profile is supported for the given instance type.

        Args:
            mig_profile: MIG profile (e.g., "1g.10gb")
            instance_type: SageMaker instance type (e.g., "ml.p4d.24xlarge")

        Raises:
            ValueError: If the instance type doesn't support MIG profiles or if the MIG profile is not supported for the instance type
        """
        logger = setup_logging(self.get_logger())

        if instance_type not in INSTANCE_MIG_PROFILES:
            error_msg = (
                f"Instance type '{instance_type}' does not support MIG profiles. "
                f"Supported instance types: {list(INSTANCE_MIG_PROFILES.keys())}"
            )
            logger.error(error_msg)
            raise ValueError(error_msg)

        if mig_profile not in INSTANCE_MIG_PROFILES[instance_type]:
            error_msg = (
                f"MIG profile '{mig_profile}' is not supported for instance type '{instance_type}'. "
                f"Supported MIG profiles for {instance_type}: {INSTANCE_MIG_PROFILES[instance_type]}"
            )
            logger.error(error_msg)
            raise ValueError(error_msg)

        logger.info(
            f"MIG profile '{mig_profile}' is valid for instance type '{instance_type}'"
        )

    @classmethod
    @_hyperpod_telemetry_emitter(Feature.HYPERPOD, "list_pods_endpoint")
    def list_pods(cls, namespace=None, endpoint_name=None):
        """List names of pods whose ``app`` label matches a JumpStart endpoint.

        Args:
            namespace: Namespace to search; the default namespace when omitted.
            endpoint_name: If given, only pods labelled with this name are
                returned; otherwise every listed JumpStart endpoint is matched.

        Returns:
            Names of the matching pods.
        """
        cls.verify_kube_config()

        if not namespace:
            namespace = get_default_namespace()

        v1 = client.CoreV1Api()
        list_pods_response = v1.list_namespaced_pod(namespace=namespace)

        endpoints = set()
        if endpoint_name:
            endpoints.add(endpoint_name)
        else:
            list_response = cls.call_list_api(
                kind=JUMPSTART_MODEL_KIND,
                namespace=namespace,
            )
            # Tolerate an empty response or a missing/empty "items" entry.
            for item in (list_response or {}).get("items") or []:
                endpoints.add(item["metadata"]["name"])

        # list_namespaced_pod returns all pods in the namespace, so keep only
        # those created for a JumpStart endpoint. A pod without labels has
        # ``labels`` set to None, hence the ``or {}``.
        return [
            pod.metadata.name
            for pod in list_pods_response.items
            if (pod.metadata.labels or {}).get("app") in endpoints
        ]
337
338 return pods