-
Notifications
You must be signed in to change notification settings - Fork 20
/
sagemaker_predictor.py
132 lines (116 loc) · 5.86 KB
/
sagemaker_predictor.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
import time
import json
import logging
import sagemaker
import pandas as pd
from datetime import datetime
from typing import Dict, Optional
from fmbench.utils import count_tokens
from sagemaker.predictor import Predictor
from sagemaker.serializers import JSONSerializer
from fmbench.scripts.sagemaker_metrics import get_endpoint_metrics
from fmbench.scripts.fmbench_predictor import (FMBenchPredictor,
FMBenchPredictionResponse)
# set a logger
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
class SageMakerPredictor(FMBenchPredictor):
    """FMBench predictor backed by an Amazon SageMaker real-time endpoint.

    Wraps a ``sagemaker.predictor.Predictor`` and adapts its responses to the
    ``FMBenchPredictionResponse`` contract (response JSON, latency, token counts).
    """

    # overriding abstract method
    def __init__(self,
                 endpoint_name: str,
                 inference_spec: Optional[Dict],
                 metadata: Optional[Dict]):
        """Create a predictor for an existing SageMaker endpoint.

        :param endpoint_name: name of the deployed SageMaker endpoint.
        :param inference_spec: optional dict that may contain "parameters"
            (inference parameters) and "split_input_and_parameters" (bool).
        :param metadata: optional dict that may contain "variant_name",
            used later when fetching endpoint metrics.
        """
        self._predictor: Optional[sagemaker.base_predictor.Predictor] = None
        self._endpoint_name: str = endpoint_name
        self._inference_spec: Optional[Dict] = inference_spec
        self._variant_name: Optional[str] = None
        if metadata is not None:
            self._variant_name = metadata.get("variant_name")
        try:
            # Create a SageMaker Predictor object; JSONSerializer so dict
            # payloads are sent as application/json.
            self._predictor = Predictor(
                endpoint_name=self._endpoint_name,
                sagemaker_session=sagemaker.Session(),
                serializer=JSONSerializer()
            )
        except Exception as e:
            # Best-effort: _predictor stays None and get_prediction will
            # log an error for every call instead of raising here.
            logger.error(f"create_predictor, exception occurred while creating predictor "
                         f"for endpoint_name={self._endpoint_name}, exception={e}")
        logger.info(f"__init__ _predictor={self._predictor}, _inference_spec={self._inference_spec}")

    def get_prediction(self, payload: Dict) -> FMBenchPredictionResponse:
        """Invoke the endpoint once for ``payload`` and time the call.

        :param payload: dict with at least an "inputs" key holding the prompt.
        :return: FMBenchPredictionResponse; on any failure the fields that
            could not be determined are None (errors are logged, not raised).
        """
        response_json: Optional[Dict] = None
        response: Optional[str] = None
        latency: Optional[float] = None
        prompt_tokens: Optional[int] = None
        completion_tokens: Optional[int] = None
        # represents the number of tokens in the prompt payload, counted with
        # the default/user provided tokenizer
        prompt_tokens = count_tokens(payload["inputs"])
        try:
            st = time.perf_counter()
            # inference_spec is optional; tolerate it (or its keys) being absent
            split_input_and_inference_params = None
            inference_params = None
            if self._inference_spec is not None:
                split_input_and_inference_params = self._inference_spec.get("split_input_and_parameters")
                inference_params = self._inference_spec.get("parameters")
            if split_input_and_inference_params is True:
                # container expects the prompt and the inference parameters
                # as two separate arguments
                response = self._predictor.predict(payload["inputs"],
                                                   inference_params)
            else:
                # container expects the inference parameters merged into the
                # payload itself; skip the merge when none were configured
                if inference_params is not None:
                    payload = payload | dict(parameters=inference_params)
                response = self._predictor.predict(payload)
            latency = time.perf_counter() - st
            # predict() may return raw bytes; normalize to str before parsing
            if isinstance(response, bytes):
                response = response.decode('utf-8')
            response_json = json.loads(response)
            # some containers wrap the result in a single-element list
            if isinstance(response_json, list):
                response_json = response_json[0]
            # normalize classifier-style output: expose "predicted_label"
            # under the "generated_text" key expected downstream
            if response_json.get("generated_text") is None:
                if response_json.get("predicted_label") is not None:
                    response_json["generated_text"] = response_json.get("predicted_label")
            # counts the completion tokens for the model using the
            # default/user provided tokenizer
            completion_tokens = count_tokens(response_json.get("generated_text"))
        except Exception as e:
            logger.error(f"get_prediction, exception occurred while getting prediction for payload={payload} "
                         f"from predictor={self._endpoint_name}, response={response}, exception={e}")
        return FMBenchPredictionResponse(response_json=response_json,
                                         latency=latency,
                                         completion_tokens=completion_tokens,
                                         prompt_tokens=prompt_tokens)

    @property
    def endpoint_name(self) -> str:
        """The endpoint name property."""
        return self._endpoint_name

    def calculate_cost(self,
                       instance_type: str,
                       instance_count: int,
                       pricing: Dict,
                       duration: float,
                       prompt_tokens: int,
                       completion_tokens: int) -> Optional[float]:
        """Calculate the cost of each experiment run.

        Uses instance-based (hourly) pricing only; token counts are accepted
        for interface compatibility but not used here.

        :return: cost in the pricing file's currency units, or None when the
            instance type has no configured rate or the pricing dict is
            malformed (the error is logged).
        """
        experiment_cost: Optional[float] = None
        try:
            instance_based_pricing = pricing['pricing']['instance_based']
            hourly_rate = instance_based_pricing.get(instance_type, None)
            logger.info(f"the hourly rate for running on {instance_type} is {hourly_rate}, "
                        f"instance_count={instance_count}")
            if hourly_rate is None:
                # explicit check instead of letting the arithmetic raise TypeError
                logger.error(f"calculate_cost, no hourly rate configured for "
                             f"instance_type={instance_type}")
            else:
                # calculating the experiment cost for instance based pricing;
                # a missing/zero instance count defaults to a single instance
                instance_count = instance_count if instance_count else 1
                experiment_cost = (hourly_rate / 3600) * duration * instance_count
        except Exception as e:
            logger.error(f"exception occurred during experiment cost calculation, exception={e}")
        return experiment_cost

    def get_metrics(self,
                    start_time: datetime,
                    end_time: datetime,
                    period: int = 60) -> pd.DataFrame:
        """Fetch CloudWatch metrics for this endpoint over [start_time, end_time].

        NOTE(review): ``period`` is accepted but not forwarded to
        get_endpoint_metrics — confirm whether that helper takes a period.
        """
        return get_endpoint_metrics(self._endpoint_name, self._variant_name, start_time, end_time)

    @property
    def inference_parameters(self) -> Optional[Dict]:
        """The inference parameters property (None when no spec was given)."""
        # guard: inference_spec is optional and may be None
        if self._inference_spec is None:
            return None
        return self._inference_spec.get("parameters")
def create_predictor(endpoint_name: str, inference_spec: Optional[Dict], metadata: Optional[Dict]):
    """Factory entry point used by FMBench to build a SageMaker predictor."""
    return SageMakerPredictor(endpoint_name=endpoint_name,
                              inference_spec=inference_spec,
                              metadata=metadata)