import os
import re
from abc import abstractmethod
from typing import Any, Dict, Generic, Iterator, List, Literal, Optional, TypeVar, Union
from botocore.exceptions import BotoCoreError, UnknownServiceError
from packaging import version
from pydantic import SecretStr
# Roles a single chat message may declare.
MESSAGE_ROLES = Literal["system", "user", "assistant"]

# One chat message: a dict keyed by "role" / "content" whose values are
# either one of the role literals or free-form text.
MESSAGE_FORMAT = Dict[
    Literal["role", "content"],
    Union[MESSAGE_ROLES, str],
]

# Type variable for what a content handler accepts: a prompt string, a list
# of strings, a single chat message, or a list of chat messages.
INPUT_TYPE = TypeVar(
    "INPUT_TYPE",
    bound=Union[
        str,
        List[str],
        MESSAGE_FORMAT,
        List[MESSAGE_FORMAT],
    ],
)

# Type variable for what a content handler produces: text, a list of float
# vectors, a chat message, a list of chat messages, or an iterator.
OUTPUT_TYPE = TypeVar(
    "OUTPUT_TYPE",
    bound=Union[
        str,
        List[List[float]],
        MESSAGE_FORMAT,
        List[MESSAGE_FORMAT],
        Iterator,
    ],
)
[docs]
class ContentHandlerBase(Generic[INPUT_TYPE, OUTPUT_TYPE]):
    """A handler class to transform input from LLM and BaseChatModel to a
    format that SageMaker endpoint expects.

    Similarly, the class handles transforming output from the
    SageMaker endpoint to a format that LLM & BaseChatModel class expects.

    Note:
        This class derives only from ``Generic`` (not ``abc.ABC``), so the
        ``@abstractmethod`` markers document intent but are not enforced
        when the class is instantiated. Subclasses are expected to override
        both ``transform_input`` and ``transform_output``.

    Example:
        .. code-block:: python

            import json

            class ContentHandler(ContentHandlerBase):
                content_type = "application/json"
                accepts = "application/json"

                def transform_input(self, prompt: str, model_kwargs: Dict) -> bytes:
                    input_str = json.dumps({"prompt": prompt, **model_kwargs})
                    return input_str.encode("utf-8")

                def transform_output(self, output: bytes) -> str:
                    # NOTE(review): ``.read()`` implies a file-like response
                    # body rather than raw bytes -- confirm against the caller.
                    response_json = json.loads(output.read().decode("utf-8"))
                    return response_json[0]["generated_text"]
    """

    content_type: Optional[str] = "text/plain"
    """The MIME type of the input data passed to endpoint"""

    accepts: Optional[str] = "text/plain"
    """The MIME type of the response data returned from endpoint"""

    @abstractmethod
    def transform_input(self, prompt: INPUT_TYPE, model_kwargs: Dict) -> bytes:
        """Transform the input to a format the model accepts as request body.

        Args:
            prompt: The prompt (or message payload) to send.
            model_kwargs: Extra model parameters to fold into the request.

        Returns:
            Bytes or a seekable file-like object in the format specified by
            the ``content_type`` request header.
        """

    @abstractmethod
    def transform_output(self, output: bytes) -> OUTPUT_TYPE:
        """Transform the model output into the value the LLM class expects.

        Args:
            output: The response body returned by the endpoint.

        Returns:
            The response converted to this handler's output type.
        """
[docs]
def enforce_stop_tokens(text: str, stop: List[str]) -> str:
    """Cut off the text as soon as any stop word occurs.

    Stop words are matched literally, so entries containing regular
    expression metacharacters (for example ``"[INST]"``, ``"."`` or ``"("``)
    behave like plain substrings.

    Args:
        text: The text to truncate.
        stop: Stop words; the text is cut at the earliest occurrence of any
            of them. Empty strings are ignored.

    Returns:
        The part of ``text`` before the first stop word, or ``text``
        unchanged when no (non-empty) stop word is given or none occurs.
    """
    # Escape each stop word so it is not interpreted as a regex, and drop
    # empty strings, which would otherwise match at index 0 and wipe the text.
    literal_stops = [re.escape(token) for token in stop if token]
    if not literal_stops:
        return text
    return re.split("|".join(literal_stops), text, maxsplit=1)[0]
[docs]
def anthropic_tokens_supported() -> bool:
    """Report whether the requirements for Anthropic ``count_tokens()`` hold.

    Returns:
        ``False`` when ``anthropic`` cannot be imported or its version is
        newer than 0.38.0; ``True`` when both the ``anthropic`` and ``httpx``
        version requirements are satisfied.

    Raises:
        ImportError: If ``anthropic`` qualifies but ``httpx`` cannot be
            imported or is newer than 0.27.2.
    """
    try:
        import anthropic
    except ImportError:
        return False

    installed_anthropic = version.parse(anthropic.__version__)
    if installed_anthropic > version.parse("0.38.0"):
        return False

    try:
        import httpx

        installed_httpx = version.parse(httpx.__version__)
        httpx_too_new = installed_httpx > version.parse("0.27.2")
        if httpx_too_new:
            # Route the version mismatch through the handler below so both
            # failure modes surface with the same message.
            raise ImportError()
    except ImportError:
        raise ImportError("httpx<=0.27.2 is required.")

    return True
def _get_anthropic_client() -> Any:
    """Create an ``anthropic.Anthropic`` client with default arguments.

    The ``anthropic`` package is imported inside the function rather than at
    module import time.
    """
    import anthropic

    client = anthropic.Anthropic()
    return client
[docs]
def get_num_tokens_anthropic(text: str) -> int:
    """Count the tokens in ``text`` using the Anthropic client.

    Args:
        text: The string to measure.

    Returns:
        The value returned by the client's ``count_tokens`` call.
    """
    return _get_anthropic_client().count_tokens(text=text)
[docs]
def get_token_ids_anthropic(text: str) -> List[int]:
    """Encode ``text`` with the Anthropic client's tokenizer.

    Args:
        text: The string to tokenize.

    Returns:
        The ``ids`` attribute of the tokenizer's encoding of ``text``.
    """
    anthropic_tokenizer = _get_anthropic_client().get_tokenizer()
    return anthropic_tokenizer.encode(text).ids
[docs]
def create_aws_client(
    service_name: str,
    region_name: Optional[str] = None,
    credentials_profile_name: Optional[str] = None,
    aws_access_key_id: Optional[SecretStr] = None,
    aws_secret_access_key: Optional[SecretStr] = None,
    aws_session_token: Optional[SecretStr] = None,
    endpoint_url: Optional[str] = None,
    config: Any = None,
) -> Any:
    """Helper function to validate AWS credentials and create an AWS client.

    When no profile or explicit credentials are given, the client is created
    with ``boto3.client``. Otherwise a ``boto3.Session`` is built from the
    profile (which takes precedence) or from the explicit key pair, and the
    client is created from that session.

    Args:
        service_name: The name of the AWS service to create a client for.
        region_name: AWS region name. If not provided, falls back to the
            ``AWS_REGION`` and then ``AWS_DEFAULT_REGION`` environment
            variables, and finally to the session's region when a session
            is used.
        credentials_profile_name: The name of the AWS credentials profile to use.
        aws_access_key_id: AWS access key ID.
        aws_secret_access_key: AWS secret access key.
        aws_session_token: AWS session token.
        endpoint_url: The complete URL to use for the constructed client.
        config: Advanced client configuration options.

    Returns:
        boto3.client: An AWS service client instance.

    Raises:
        ValueError: If explicit credentials are incomplete (no profile and
            not both ``aws_access_key_id`` and ``aws_secret_access_key``),
            if botocore fails while loading credentials or creating the
            client, or if any other error occurs during client creation
            (including ``boto3`` not being importable).
        ModuleNotFoundError: If the installed boto3 does not know
            ``service_name``.
    """
    needs_session = bool(
        credentials_profile_name
        or aws_access_key_id
        or aws_secret_access_key
        or aws_session_token
    )
    has_key_pair = bool(aws_access_key_id and aws_secret_access_key)

    # Validate before entering the try block: raised inside it, this error
    # would be caught by the generic handler and re-wrapped with a misleading
    # "Error raised by service" prefix.
    if needs_session and not credentials_profile_name and not has_key_pair:
        raise ValueError(
            "If providing credentials, both aws_access_key_id and "
            "aws_secret_access_key must be specified."
        )

    try:
        # Imported lazily; an ImportError here is reported by the generic
        # handler below as a ValueError (kept for backward compatibility).
        import boto3

        region_name = (
            region_name
            or os.getenv("AWS_REGION")
            or os.getenv("AWS_DEFAULT_REGION")
        )

        candidate_params = {
            "service_name": service_name,
            "region_name": region_name,
            "endpoint_url": endpoint_url,
            "config": config,
        }
        # Only forward options that were actually provided.
        client_params = {k: v for k, v in candidate_params.items() if v}

        if not needs_session:
            return boto3.client(**client_params)

        if credentials_profile_name:
            session = boto3.Session(profile_name=credentials_profile_name)
        else:
            # The up-front validation guarantees both keys are present here.
            session_params = {
                "aws_access_key_id": aws_access_key_id.get_secret_value(),
                "aws_secret_access_key": aws_secret_access_key.get_secret_value(),
            }
            if aws_session_token:
                session_params["aws_session_token"] = (
                    aws_session_token.get_secret_value()
                )
            session = boto3.Session(**session_params)

        # Fall back to the session's region when none was resolved above.
        if not client_params.get("region_name") and session.region_name:
            client_params["region_name"] = session.region_name

        return session.client(**client_params)

    except UnknownServiceError as e:
        raise ModuleNotFoundError(
            "Ensure that you have installed the latest boto3 package "
            f"that contains the API for `{service_name}`."
        ) from e
    except BotoCoreError as e:
        raise ValueError(
            "Could not load credentials to authenticate with AWS client. "
            "Please check that the specified profile name and/or its credentials are valid. "
            f"Service error: {e}"
        ) from e
    except Exception as e:
        raise ValueError(f"Error raised by service:\n\n{e}") from e
[docs]
def thinking_in_params(params: dict) -> bool:
    """Check if the thinking parameter is enabled in the request.

    Args:
        params: Request parameters; may contain a ``"thinking"`` mapping
            with a ``"type"`` entry.

    Returns:
        ``True`` only when ``params["thinking"]["type"]`` equals
        ``"enabled"``; ``False`` when the key is missing, empty, or ``None``.
    """
    # ``or {}`` also covers an explicit ``"thinking": None``, which the
    # ``.get`` default alone would not (the default applies only to a
    # missing key).
    thinking = params.get("thinking") or {}
    return thinking.get("type") == "enabled"