1from pydantic import BaseModel, ConfigDict, Field
2from typing import Optional, List, Literal
3
4
[docs]
class Dimensions(BaseModel):
    # One name/value pair describing a CloudWatch metric dimension.
    # Both fields are required; keys other than the declared ones are
    # rejected because the model is configured with extra="forbid".
    model_config = ConfigDict(extra="forbid")

    name: str = Field(
        description="CloudWatch Metric dimension name",
    )
    value: str = Field(
        description="CloudWatch Metric dimension value",
    )
10
11
[docs]
class CloudWatchTrigger(BaseModel):
    """CloudWatch metric trigger to use for autoscaling"""

    # Undeclared keys are rejected rather than silently dropped.
    model_config = ConfigDict(extra="forbid")

    # Every field is optional. Multi-word camelCase fields declare a
    # snake_case alias; single-word fields (dimensions, name, namespace)
    # have no alias.

    activationTargetValue: Optional[float] = Field(
        alias="activation_target_value",
        default=0,
        description="Activation Value for CloudWatch metric to scale from 0 to 1. Only applicable if minReplicaCount = 0",
    )
    dimensions: Optional[List[Dimensions]] = Field(
        default=None,
        description="Dimensions for Cloudwatch metrics",
    )
    metricCollectionPeriod: Optional[int] = Field(
        alias="metric_collection_period",
        default=300,
        description="Defines the Period for CloudWatch query",
    )
    metricCollectionStartTime: Optional[int] = Field(
        alias="metric_collection_start_time",
        default=300,
        description="Defines the StartTime for CloudWatch query",
    )
    metricName: Optional[str] = Field(
        alias="metric_name",
        default=None,
        description="Metric name to query for Cloudwatch trigger",
    )
    metricStat: Optional[str] = Field(
        alias="metric_stat",
        default="Average",
        description="Statistics metric to be used by Trigger. Used to define Stat for CloudWatch query. Default is Average.",
    )
    # NOTE(review): the description below mentions "AverageValue", but the
    # accepted literals are "Value" and "Average" -- confirm which spelling
    # the consumer of this spec expects.
    metricType: Optional[Literal["Value", "Average"]] = Field(
        alias="metric_type",
        default="Average",
        description="The type of metric to be used by HPA. Enum: AverageValue - Uses average value of metric per pod, Value - Uses absolute metric value",
    )
    minValue: Optional[float] = Field(
        alias="min_value",
        default=0,
        description="Minimum metric value used in case of empty response from CloudWatch. Default is 0.",
    )
    name: Optional[str] = Field(
        default=None,
        description="Name for the CloudWatch trigger",
    )
    namespace: Optional[str] = Field(
        default=None,
        description="AWS CloudWatch namespace for metric",
    )
    targetValue: Optional[float] = Field(
        alias="target_value",
        default=None,
        description="TargetValue for CloudWatch metric",
    )
    useCachedMetrics: Optional[bool] = Field(
        alias="use_cached_metrics",
        default=True,
        description="Enable caching of metric values during polling interval. Default is true",
    )
71
72
[docs]
class PrometheusTrigger(BaseModel):
    """Prometheus metric trigger to use for autoscaling"""

    # Undeclared keys are rejected rather than silently dropped.
    model_config = ConfigDict(extra="forbid")

    # Every field is optional; multi-word camelCase fields declare a
    # snake_case alias.

    activationTargetValue: Optional[float] = Field(
        alias="activation_target_value",
        default=0,
        description="Activation Value for Prometheus metric to scale from 0 to 1. Only applicable if minReplicaCount = 0",
    )
    customHeaders: Optional[str] = Field(
        alias="custom_headers",
        default=None,
        description="Custom headers to include while querying the prometheus endpoint.",
    )
    # NOTE(review): the description below mentions "AverageValue", but the
    # accepted literals are "Value" and "Average" -- confirm which spelling
    # the consumer of this spec expects.
    metricType: Optional[Literal["Value", "Average"]] = Field(
        alias="metric_type",
        default="Average",
        description="The type of metric to be used by HPA. Enum: AverageValue - Uses average value of metric per pod, Value - Uses absolute metric value",
    )
    name: Optional[str] = Field(
        default=None,
        description="Name for the Prometheus trigger",
    )
    namespace: Optional[str] = Field(
        default=None,
        description="Namespace for namespaced queries",
    )
    query: Optional[str] = Field(
        default=None,
        description="PromQLQuery for the metric.",
    )
    serverAddress: Optional[str] = Field(
        alias="server_address",
        default=None,
        description="Server address for AMP workspace",
    )
    targetValue: Optional[float] = Field(
        alias="target_value",
        default=None,
        description="Target metric value for scaling",
    )
    useCachedMetrics: Optional[bool] = Field(
        alias="use_cached_metrics",
        default=True,
        description="Enable caching of metric values during polling interval. Default is true",
    )
117
118
[docs]
class AutoScalingSpec(BaseModel):
    # Autoscaling settings: replica bounds, timing windows, and the optional
    # CloudWatch / Prometheus triggers. All fields are optional and carry the
    # defaults stated in their descriptions; undeclared keys are rejected.
    model_config = ConfigDict(extra="forbid")

    cloudWatchTrigger: Optional[CloudWatchTrigger] = Field(
        alias="cloud_watch_trigger",
        default=None,
        description="CloudWatch metric trigger to use for autoscaling",
    )
    cooldownPeriod: Optional[int] = Field(
        alias="cooldown_period",
        default=300,
        description="The period to wait after the last trigger reported active before scaling the resource back to 0. Default 300 seconds.",
    )
    initialCooldownPeriod: Optional[int] = Field(
        alias="initial_cooldown_period",
        default=300,
        description="The delay before the cooldownPeriod starts after the initial creation of the ScaledObject. Default 300 seconds.",
    )
    maxReplicaCount: Optional[int] = Field(
        alias="max_replica_count",
        default=5,
        description="The maximum number of model pods to scale to. Default 5.",
    )
    minReplicaCount: Optional[int] = Field(
        alias="min_replica_count",
        default=1,
        description="The minimum number of model pods to scale down to. Default 1.",
    )
    pollingInterval: Optional[int] = Field(
        alias="polling_interval",
        default=30,
        description="This is the interval to check each trigger on. Default 30 seconds.",
    )
    prometheusTrigger: Optional[PrometheusTrigger] = Field(
        alias="prometheus_trigger",
        default=None,
        description="Prometheus metric trigger to use for autoscaling",
    )
    scaleDownStabilizationTime: Optional[int] = Field(
        alias="scale_down_stabilization_time",
        default=300,
        description="The time window to stabilize for HPA before scaling down. Default 300 seconds.",
    )
    scaleUpStabilizationTime: Optional[int] = Field(
        alias="scale_up_stabilization_time",
        default=0,
        description="The time window to stabilize for HPA before scaling up. Default 0 seconds.",
    )
167
168
[docs]
class EnvironmentVariables(BaseModel):
    # A single environment variable expressed as a required name/value pair
    # of strings. Undeclared keys are rejected (extra="forbid").
    model_config = ConfigDict(extra="forbid")

    name: str
    value: str
174
175
class ModelMetrics(BaseModel):
    """Configuration for model container metrics scraping"""

    # Undeclared keys are rejected rather than silently dropped.
    model_config = ConfigDict(extra="forbid")

    # Both fields are optional and fall back to "/metrics" and 8080.
    path: Optional[str] = Field(
        default="/metrics",
        description="Path where the model exposes metrics",
    )
    port: Optional[int] = Field(
        default=8080,
        description="Port where the model exposes metrics. If not specified, a default port will be used.",
    )
188
189
[docs]
class Metrics(BaseModel):
    """Configuration for metrics collection and exposure"""

    # Undeclared keys are rejected rather than silently dropped.
    model_config = ConfigDict(extra="forbid")

    # Collection is on by default with a 15 second scrape interval; the
    # nested model-container scraping config is omitted unless supplied.
    enabled: Optional[bool] = Field(
        default=True,
        description="Enable metrics collection for this model deployment",
    )
    metricsScrapeIntervalSeconds: Optional[int] = Field(
        alias="metrics_scrape_interval_seconds",
        default=15,
        description="Scrape interval in seconds for metrics collection from sidecar and model container.",
    )
    modelMetrics: Optional[ModelMetrics] = Field(
        alias="model_metrics",
        default=None,
        description="Configuration for model container metrics scraping",
    )
208
209
[docs]
class AdditionalConfigs(BaseModel):
    # A single additional configuration entry expressed as a required
    # name/value pair of strings. Undeclared keys are rejected.
    model_config = ConfigDict(extra="forbid")

    name: str
    value: str
215
216
[docs]
class Model(BaseModel):
    # Identifies the hub model to deploy. Only modelId (alias model_id) is
    # required; everything else has a default. Undeclared keys are rejected.
    model_config = ConfigDict(extra="forbid")

    acceptEula: bool = Field(
        alias="accept_eula",
        default=False,
        description="For models that require a Model Access Config, specify True or False to indicate whether model terms of use have been accepted.",
    )
    additionalConfigs: Optional[List[AdditionalConfigs]] = Field(
        alias="additional_configs",
        default=None,
    )
    gatedModelDownloadRole: Optional[str] = Field(
        alias="gated_model_download_role",
        default=None,
        description="The Amazon Resource Name (ARN) of an IAM role that will be used to download gated model",
    )
    modelHubName: Optional[str] = Field(
        alias="model_hub_name",
        default="SageMakerPublicHub",
        description="The name of the model hub content. Can be an ARN or a simple name.",
    )
    # Required: no default is declared for this field.
    modelId: str = Field(
        alias="model_id",
        description="The unique identifier of the model within the specified hub (hubContentArn).",
    )
    modelVersion: Optional[str] = Field(
        alias="model_version",
        default=None,
        description="The version of the model to deploy, in semantic versioning format (e.g., 1.0.0).",
    )
247
248
[docs]
class SageMakerEndpoint(BaseModel):
    # Optional SageMaker endpoint name. The default is the empty string,
    # which the field description says skips endpoint creation.
    model_config = ConfigDict(extra="forbid")

    name: Optional[str] = Field(
        default="",
        description="Name of a SageMaker endpoint to be created for this JumpStartModel. The default value of empty string, when used, will skip endpoint creation.",
    )
256
257
[docs]
class Validations(BaseModel):
    # Toggles for server-side validations; currently a single flag that
    # defaults to True. Undeclared keys are rejected.
    model_config = ConfigDict(extra="forbid")

    acceleratorPartitionValidation: Optional[bool] = Field(
        alias="accelerator_partition_validation",
        default=True,
        description="Enable MIG validation for GPU partitioning. Default is true.",
    )
266
267
[docs]
class Server(BaseModel):
    # Inference server settings. instanceType (alias instance_type) is the
    # only required field. Undeclared keys are rejected.
    model_config = ConfigDict(extra="forbid")

    executionRole: Optional[str] = Field(
        alias="execution_role",
        default=None,
        description="The Amazon Resource Name (ARN) of an IAM role that will be used to deploy and manage the inference server",
    )
    # Required: no default is declared for this field.
    instanceType: str = Field(
        alias="instance_type",
        description="The EC2 instance type to use for the inference server. Must be one of the supported types.",
    )
    acceleratorPartitionType: Optional[str] = Field(
        alias="accelerator_partition_type",
        default=None,
        description="MIG profile to use for GPU partitioning",
    )
    validations: Optional[Validations] = Field(
        default=None,
        description="Validations configuration for the server",
    )
291
292
[docs]
class TlsConfig(BaseModel):
    # TLS settings; holds one optional string exposed under the alias
    # tls_certificate_output_s3_uri. Undeclared keys are rejected.
    model_config = ConfigDict(extra="forbid")

    tlsCertificateOutputS3Uri: Optional[str] = Field(
        alias="tls_certificate_output_s3_uri",
        default=None,
    )
299
300
class _HPJumpStartEndpoint(BaseModel):
    """Config defines the desired state of JumpStartModel."""

    # Unlike the nested models in this module, undeclared keys here are
    # ignored (extra="ignore") instead of raising a validation error.
    model_config = ConfigDict(extra="ignore")

    autoScalingSpec: Optional[AutoScalingSpec] = Field(
        alias="auto_scaling_spec",
        default=None,
    )
    environmentVariables: Optional[List[EnvironmentVariables]] = Field(
        alias="environment_variables",
        default=None,
        description="Additional environment variables to be passed to the inference server. Limited to 100 key-value pairs.",
    )
    maxDeployTimeInSeconds: Optional[int] = Field(
        alias="max_deploy_time_in_seconds",
        default=3600,
        description="Maximum allowed time in seconds for the deployment to complete before timing out. Defaults to 1 hour (3600 seconds)",
    )
    metrics: Optional[Metrics] = Field(
        default=None,
        description="Configuration for metrics collection and exposure",
    )
    # Required: the model to deploy.
    model: Model
    replicas: Optional[int] = Field(
        default=1,
        description="The desired number of inference server replicas. Default 1.",
    )
    sageMakerEndpoint: Optional[SageMakerEndpoint] = Field(
        alias="sage_maker_endpoint",
        default=None,
    )
    # Required: the inference server settings.
    server: Server
    tlsConfig: Optional[TlsConfig] = Field(
        alias="tls_config",
        default=None,
    )
332
333
[docs]
class Conditions(BaseModel):
    """DeploymentCondition describes the state of a deployment at a certain point."""

    # Undeclared keys are rejected rather than silently dropped.
    model_config = ConfigDict(extra="forbid")

    # status and type are required; every other field defaults to None.
    lastTransitionTime: Optional[str] = Field(
        alias="last_transition_time",
        default=None,
        description="Last time the condition transitioned from one status to another.",
    )
    lastUpdateTime: Optional[str] = Field(
        alias="last_update_time",
        default=None,
        description="The last time this condition was updated.",
    )
    message: Optional[str] = Field(
        default=None,
        description="A human readable message indicating details about the transition.",
    )
    reason: Optional[str] = Field(
        default=None,
        description="The reason for the condition's last transition.",
    )
    status: str = Field(
        description="Status of the condition, one of True, False, Unknown.",
    )
    type: str = Field(
        description="Type of deployment condition.",
    )
    observedGeneration: Optional[int] = Field(
        alias="observed_generation",
        default=None,
        description="observedGeneration represents the .metadata.generation that the condition was set based upon. For instance, if .metadata.generation is currently 12, but the .status.conditions[x].observedGeneration is 9, the condition is out of date with respect to the current state of the instance.",
    )
365
366
[docs]
class Status(BaseModel):
    """Status of the Deployment Object"""

    # Undeclared keys are rejected rather than silently dropped.
    model_config = ConfigDict(extra="forbid")

    # Every field is optional and defaults to None.
    availableReplicas: Optional[int] = Field(
        alias="available_replicas",
        default=None,
        description="Total number of available pods (ready for at least minReadySeconds) targeted by this deployment.",
    )
    collisionCount: Optional[int] = Field(
        alias="collision_count",
        default=None,
        description="Count of hash collisions for the Deployment. The Deployment controller uses this field as a collision avoidance mechanism when it needs to create the name for the newest ReplicaSet.",
    )
    conditions: Optional[List[Conditions]] = Field(
        default=None,
        description="Represents the latest available observations of a deployment's current state.",
    )
    observedGeneration: Optional[int] = Field(
        alias="observed_generation",
        default=None,
        description="The generation observed by the deployment controller.",
    )
    readyReplicas: Optional[int] = Field(
        alias="ready_replicas",
        default=None,
        description="readyReplicas is the number of pods targeted by this Deployment with a Ready Condition.",
    )
    replicas: Optional[int] = Field(
        default=None,
        description="Total number of non-terminated pods targeted by this deployment (their labels match the selector).",
    )
    unavailableReplicas: Optional[int] = Field(
        alias="unavailable_replicas",
        default=None,
        description="Total number of unavailable pods targeted by this deployment. This is the total number of pods that are still required for the deployment to have 100% available capacity. They may either be pods that are running but not yet available or pods that still have not been created.",
    )
    updatedReplicas: Optional[int] = Field(
        alias="updated_replicas",
        default=None,
        description="Total number of non-terminated pods targeted by this deployment that have the desired template spec.",
    )
410
411
[docs]
class DeploymentStatus(BaseModel):
    """Details of the native kubernetes deployment that hosts the model"""

    # Undeclared keys are rejected rather than silently dropped.
    model_config = ConfigDict(extra="forbid")

    # lastUpdated (alias last_updated) and name are required; the remaining
    # fields default to None.
    deploymentObjectOverallState: Optional[str] = Field(
        alias="deployment_object_overall_state",
        default=None,
        description="Overall State of the Deployment Object",
    )
    lastUpdated: str = Field(
        alias="last_updated",
        description="Last Update Time",
    )
    message: Optional[str] = Field(
        default=None,
        description="Message populated in the root CRD while updating the status of underlying Deployment",
    )
    name: str = Field(
        description="Name of the Deployment Object",
    )
    reason: Optional[str] = Field(
        default=None,
        description="Reason populated in the root CRD while updating the status of underlying Deployment",
    )
    status: Optional[Status] = Field(
        default=None,
        description="Status of the Deployment Object",
    )
435
436
[docs]
class Sagemaker(BaseModel):
    """Status of the SageMaker endpoint"""

    # Undeclared keys are rejected rather than silently dropped.
    model_config = ConfigDict(extra="forbid")

    # The three ARN fields are optional; state is required.
    configArn: Optional[str] = Field(
        alias="config_arn",
        default=None,
        description="The Amazon Resource Name (ARN) of the endpoint configuration.",
    )
    endpointArn: Optional[str] = Field(
        alias="endpoint_arn",
        default=None,
        description="The Amazon Resource Name (ARN) of the SageMaker endpoint",
    )
    modelArn: Optional[str] = Field(
        alias="model_arn",
        default=None,
        description="The ARN of the model created in SageMaker.",
    )
    state: str = Field(
        description="The current state of the SageMaker endpoint",
    )
458
459
[docs]
class Endpoints(BaseModel):
    """EndpointStatus contains the status of SageMaker endpoints"""

    # Undeclared keys are rejected rather than silently dropped.
    model_config = ConfigDict(extra="forbid")

    # Optional nested SageMaker endpoint status; None when not provided.
    sagemaker: Optional[Sagemaker] = Field(
        default=None,
        description="Status of the SageMaker endpoint",
    )
468
469
[docs]
# NOTE: this class was previously also called ``ModelMetrics``. That second
# definition rebound the module-level name already used by the spec-side
# "Configuration for model container metrics scraping" class defined earlier
# in this module, so ``ModelMetrics`` resolved to the status model after
# import and instances of it were not instances of the class that
# ``Metrics.modelMetrics`` is annotated with. The status model now has its
# own name so both classes stay reachable.
# NOTE(review): callers that relied on ``ModelMetrics`` meaning the status
# model must switch to ``ModelMetricsStatus``.
class ModelMetricsStatus(BaseModel):
    """Status of model container metrics collection"""

    # Undeclared keys are rejected rather than silently dropped.
    model_config = ConfigDict(extra="forbid")

    # Unlike the spec-side model, neither field has a non-None default.
    path: Optional[str] = Field(
        default=None,
        description="The path where metrics are available",
    )
    port: Optional[int] = Field(
        default=None,
        description="The port on which metrics are exposed",
    )


class MetricsStatus(BaseModel):
    """Status of metrics collection"""

    # Undeclared keys are rejected rather than silently dropped.
    model_config = ConfigDict(extra="forbid")

    # enabled is the only required field; the rest default to None.
    enabled: bool = Field(
        description="Whether metrics collection is enabled",
    )
    errorMessage: Optional[str] = Field(
        alias="error_message",
        default=None,
        description="Error message if metrics collection is in error state",
    )
    metricsScrapeIntervalSeconds: Optional[int] = Field(
        alias="metrics_scrape_interval_seconds",
        default=None,
        description="Scrape interval in seconds for metrics collection from sidecar and model container.",
    )
    # Refers to the status model defined directly above, not the spec-side
    # ModelMetrics configuration class.
    modelMetrics: Optional[ModelMetricsStatus] = Field(
        alias="model_metrics",
        default=None,
        description="Status of model container metrics collection",
    )
    state: Optional[str] = Field(
        default=None,
        description="Current state of metrics collection",
    )
507
508
[docs]
class TlsCertificate(BaseModel):
    """CertificateStatus represents the status of TLS certificates"""

    # Undeclared keys are rejected rather than silently dropped.
    model_config = ConfigDict(extra="forbid")

    # Every field is optional and defaults to None.
    certificateARN: Optional[str] = Field(
        alias="certificate_arn",
        default=None,
        description="The Amazon Resource Name (ARN) of the ACM certificate",
    )
    certificateDomainNames: Optional[List[str]] = Field(
        alias="certificate_domain_names",
        default=None,
        description="The certificate domain names that is attached to the certificate",
    )
    certificateName: Optional[str] = Field(
        alias="certificate_name",
        default=None,
        description="The certificate name of cert manager",
    )
    importedCertificates: Optional[List[str]] = Field(
        alias="imported_certificates",
        default=None,
        description="Used for tracking the imported certificates to ACM",
    )
    issuerName: Optional[str] = Field(
        alias="issuer_name",
        default=None,
        description="The issuer name of cert manager",
    )
    lastCertExpiryTime: Optional[str] = Field(
        alias="last_cert_expiry_time",
        default=None,
        description="The last certificate expiry time",
    )
    tlsCertificateOutputS3Bucket: Optional[str] = Field(
        alias="tls_certificate_output_s3_bucket",
        default=None,
        description="S3 bucket that stores the certificate that needs to be trusted",
    )
    tlsCertificateS3Keys: Optional[List[str]] = Field(
        alias="tls_certificate_s3_keys",
        default=None,
        description="The output tls certificate S3 key that points to the .pem file",
    )
552
553
[docs]
class JumpStartModelStatus(BaseModel):
    """ModelDeploymentStatus defines the observed state of ModelDeployment"""

    # Undeclared keys are rejected rather than silently dropped.
    model_config = ConfigDict(extra="forbid")

    # Every field is optional and defaults to None.
    conditions: Optional[List[Conditions]] = Field(
        default=None,
        description="Detailed conditions representing the state of the deployment",
    )
    deploymentStatus: Optional[DeploymentStatus] = Field(
        alias="deployment_status",
        default=None,
        description="Details of the native kubernetes deployment that hosts the model",
    )
    endpoints: Optional[Endpoints] = Field(
        default=None,
        description="EndpointStatus contains the status of SageMaker endpoints",
    )
    metricsStatus: Optional[MetricsStatus] = Field(
        alias="metrics_status",
        default=None,
        description="Status of metrics collection",
    )
    observedGeneration: Optional[int] = Field(
        alias="observed_generation",
        default=None,
        description="Latest generation reconciled by controller",
    )
    replicas: Optional[int] = Field(
        default=None,
        description="The observed number of inference server replicas.",
    )
    selector: Optional[str] = Field(
        default=None,
        description="LabelSelector for the deployment.",
    )
    # When present, state must be one of the eight phase names listed here.
    state: Optional[
        Literal[
            "DeploymentPending",
            "DeploymentInProgress",
            "DeploymentFailed",
            "DeploymentComplete",
            "DeletionPending",
            "DeletionInProgress",
            "DeletionFailed",
            "DeletionComplete",
        ]
    ] = Field(
        default=None,
        description="Current phase of the model deployment",
    )
    tlsCertificate: Optional[TlsCertificate] = Field(
        alias="tls_certificate",
        default=None,
        description="CertificateStatus represents the status of TLS certificates",
    )
602 )