1from sagemaker.hyperpod.common.config.metadata import Metadata
2from sagemaker.hyperpod.inference.config.constants import *
3from sagemaker.hyperpod.common.utils import (
4 get_default_namespace,
5 get_cluster_instance_types,
6 setup_logging,
7 get_current_cluster,
8 get_current_region,
9)
10from sagemaker.hyperpod.inference.config.hp_endpoint_config import (
11 InferenceEndpointConfigStatus,
12 _HPEndpoint,
13)
14from sagemaker.hyperpod.common.telemetry.telemetry_logging import (
15 _hyperpod_telemetry_emitter,
16)
17from sagemaker.hyperpod.common.telemetry.constants import Feature
18from sagemaker.hyperpod.inference.hp_endpoint_base import HPEndpointBase
19from typing import Dict, List, Optional
20from sagemaker_core.main.resources import Endpoint
21from pydantic import Field, ValidationError
22from kubernetes import client
23
24
[docs]
class HPEndpoint(_HPEndpoint, HPEndpointBase):
    """HyperPod inference endpoint managed through an ``INFERENCE_ENDPOINT_CONFIG_KIND`` resource.

    Endpoint spec fields are inherited from ``_HPEndpoint``. ``metadata`` holds
    the resource name, namespace, labels and annotations; ``status`` holds the
    last fetched ``InferenceEndpointConfigStatus`` (``None`` when unknown).
    """

    metadata: Optional[Metadata] = Field(default=None)
    status: Optional[InferenceEndpointConfigStatus] = Field(default=None)

    # -- internal helpers ---------------------------------------------------

    @staticmethod
    def _parse_status(
        status: Optional[Dict],
    ) -> Optional[InferenceEndpointConfigStatus]:
        """Validate a raw ``status`` payload from the API.

        Returns ``None`` when the payload is absent or fails validation, so a
        missing or unrecognised status never prevents loading the endpoint.
        """
        if status is None:
            return None
        try:
            return InferenceEndpointConfigStatus.model_validate(status, by_name=True)
        except ValidationError:
            return None

    @classmethod
    def _from_response(cls, response: Dict) -> "HPEndpoint":
        """Build an endpoint from a raw resource dict.

        ``response`` must contain ``spec`` and ``metadata`` keys; ``status`` is
        optional and parsed leniently via ``_parse_status``.
        """
        endpoint = cls.model_validate(response["spec"], by_name=True)
        endpoint.status = cls._parse_status(response.get("status"))
        endpoint.metadata = Metadata.model_validate(response["metadata"], by_name=True)
        return endpoint

    def _require_metadata(self) -> Metadata:
        """Return ``self.metadata``, raising if it has not been set."""
        if not self.metadata:
            raise Exception(
                "Metadata not found! Please provide object name and namespace in metadata field."
            )
        return self.metadata

    # -- creation -------------------------------------------------------------

    def _create_internal(self, spec, debug=False):
        """Submit a create request for ``spec`` (shared by ``create`` and ``create_from_dict``).

        The resource name is ``self.metadata.name`` when set, otherwise
        ``spec.endpointName``; the namespace falls back to the default
        namespace. Labels and annotations from ``self.metadata`` are carried
        over. After the create call returns, ``self.metadata`` is replaced with
        the resolved metadata.

        Raises:
            Exception: if neither a metadata name nor ``spec.endpointName`` is
                provided, or if the cluster's instance types could be fetched
                and ``spec.instanceType`` is not among them.
        """
        logger = setup_logging(self.get_logger(), debug)

        name = self.metadata.name if self.metadata else None
        namespace = self.metadata.namespace if self.metadata else None

        if not spec.endpointName and not name:
            raise Exception('Either metadata name or endpoint name must be provided')

        namespace = namespace or get_default_namespace()
        name = name or spec.endpointName

        # Create metadata object with labels and annotations if available
        metadata = Metadata(
            name=name,
            namespace=namespace,
            labels=self.metadata.labels if self.metadata else None,
            annotations=self.metadata.annotations if self.metadata else None,
        )

        self.validate_instance_type(spec.instanceType)

        self.call_create_api(
            metadata=metadata,
            kind=INFERENCE_ENDPOINT_CONFIG_KIND,
            spec=spec,
            debug=debug,
        )

        self.metadata = metadata

        logger.info(
            f"Creating sagemaker model and endpoint. Endpoint name: {spec.endpointName}.\n The process may take a few minutes..."
        )

    @_hyperpod_telemetry_emitter(Feature.HYPERPOD, "create_endpoint")
    def create(
        self,
        debug=False
    ) -> None:
        """Create the endpoint resource from this object's own spec fields.

        ``metadata`` and ``status`` belong to this wrapper rather than to the
        ``_HPEndpoint`` spec, so they are excluded from the spec payload;
        ``metadata`` is still used by ``_create_internal`` for name/namespace.
        """
        spec = _HPEndpoint(
            **self.model_dump(
                by_alias=True,
                exclude_none=True,
                exclude={"metadata", "status"},
            )
        )
        self._create_internal(spec, debug)

    @_hyperpod_telemetry_emitter(Feature.HYPERPOD, "create_endpoint_from_dict")
    def create_from_dict(
        self,
        input: Dict,
        debug=False
    ) -> None:
        """Create the endpoint resource from a raw spec dict.

        ``input`` is validated into an ``_HPEndpoint`` spec; ``self.metadata``
        (if set) still supplies the resource name/namespace/labels/annotations.
        """
        spec = _HPEndpoint.model_validate(input, by_name=True)
        self._create_internal(spec, debug)

    # -- read / update --------------------------------------------------------

    def refresh(self):
        """Re-fetch the resource and update ``self.status`` (only the status is refreshed).

        A missing or unparseable status is stored as ``None``, matching ``get``.

        Returns:
            self, to allow chaining.

        Raises:
            Exception: if ``self.metadata`` is not set.
        """
        metadata = self._require_metadata()

        response = self.call_get_api(
            name=metadata.name,
            kind=INFERENCE_ENDPOINT_CONFIG_KIND,
            namespace=metadata.namespace,
        )

        self.status = self._parse_status(response.get("status"))

        return self

    @classmethod
    @_hyperpod_telemetry_emitter(Feature.HYPERPOD, "list_endpoints")
    def list(
        cls,
        namespace: Optional[str] = None,
    ) -> List["HPEndpoint"]:
        """List all endpoint resources in ``namespace`` (default namespace if omitted).

        Each item of the list response is parsed directly instead of being
        re-fetched one by one, avoiding N extra API calls and a failure when
        an item is deleted between the list and the follow-up fetch.
        """
        if not namespace:
            namespace = get_default_namespace()

        response = cls.call_list_api(
            kind=INFERENCE_ENDPOINT_CONFIG_KIND,
            namespace=namespace,
        )

        items = (response or {}).get("items") or []
        return [cls._from_response(item) for item in items]

    @classmethod
    @_hyperpod_telemetry_emitter(Feature.HYPERPOD, "get_endpoint")
    def get(cls, name: str, namespace: Optional[str] = None) -> "HPEndpoint":
        """Fetch a single endpoint resource by ``name`` (default namespace if omitted).

        The returned object has ``metadata`` populated and ``status`` set to
        the parsed status, or ``None`` if it is absent or invalid.
        """
        if not namespace:
            namespace = get_default_namespace()

        response = cls.call_get_api(
            name=name,
            kind=INFERENCE_ENDPOINT_CONFIG_KIND,
            namespace=namespace,
        )

        return cls._from_response(response)

    # -- delete / invoke ------------------------------------------------------

    @_hyperpod_telemetry_emitter(Feature.HYPERPOD, "delete_endpoint")
    def delete(self) -> None:
        """Request deletion of this endpoint resource.

        Raises:
            Exception: if ``self.metadata`` is not set (previously this surfaced
                as an ``AttributeError`` on ``None``).
        """
        logger = setup_logging(self.get_logger())
        metadata = self._require_metadata()

        self.call_delete_api(
            name=metadata.name,
            kind=INFERENCE_ENDPOINT_CONFIG_KIND,
            namespace=metadata.namespace,
        )
        logger.info(f"Deleting HPEndpoint: {metadata.name}...")

    @_hyperpod_telemetry_emitter(Feature.HYPERPOD, "invoke_endpoint")
    def invoke(self, body, content_type="application/json"):
        """Invoke the SageMaker endpoint named ``self.endpointName`` in the current region.

        Returns whatever ``Endpoint.invoke`` returns.

        Raises:
            Exception: if ``self.endpointName`` is not set.
        """
        if not self.endpointName:
            raise Exception("SageMaker endpoint name not found in this object!")

        endpoint = Endpoint.get(self.endpointName, region=get_current_region())

        return endpoint.invoke(body=body, content_type=content_type)

    def validate_instance_type(self, instance_type: str):
        """Check ``instance_type`` against the current HyperPod cluster's instance types.

        Validation is best-effort: if the cluster's instance types cannot be
        fetched, a warning is logged and the check is skipped.

        Raises:
            Exception: if the instance types were fetched and ``instance_type``
                is not among them.
        """
        logger = setup_logging(self.get_logger())

        cluster_instance_types = None

        # verify supported instance types from HyperPod cluster
        try:
            cluster_instance_types = get_cluster_instance_types(
                cluster=get_current_cluster(),
                region=get_current_region(),
            )
        except Exception as e:
            logger.warning(f"Failed to get instance types from HyperPod cluster: {e}")

        if cluster_instance_types and (instance_type not in cluster_instance_types):
            raise Exception(
                f"Current HyperPod cluster does not have instance type {instance_type}. Supported instance types are {cluster_instance_types}"
            )

    @classmethod
    @_hyperpod_telemetry_emitter(Feature.HYPERPOD, "list_pods_endpoint")
    def list_pods(
        cls,
        namespace: Optional[str] = None,
        endpoint_name: Optional[str] = None,
    ) -> List[str]:
        """Return names of pods in ``namespace`` whose ``app`` label matches an endpoint.

        If ``endpoint_name`` is given, only that endpoint is matched; otherwise
        every endpoint resource in the namespace is matched.
        """
        cls.verify_kube_config()

        if not namespace:
            namespace = get_default_namespace()

        v1 = client.CoreV1Api()
        list_pods_response = v1.list_namespaced_pod(namespace=namespace)

        if endpoint_name:
            endpoints = {endpoint_name}
        else:
            list_response = cls.call_list_api(
                kind=INFERENCE_ENDPOINT_CONFIG_KIND,
                namespace=namespace,
            )
            items = (list_response or {}).get("items") or []
            endpoints = {item["metadata"]["name"]} if False else {
                item["metadata"]["name"] for item in items
            }

        # list_namespaced_pod returns every pod in the namespace, so keep only
        # pods created for an endpoint. The Kubernetes client leaves
        # ``metadata.labels`` as None for a pod with no labels, hence the guard.
        return [
            pod.metadata.name
            for pod in list_pods_response.items
            if (pod.metadata.labels or {}).get("app") in endpoints
        ]
227 return pods