Source code for langchain_aws.utils

import os
import re
from abc import abstractmethod
from typing import Any, Dict, Generic, Iterator, List, Literal, Optional, TypeVar, Union

from botocore.exceptions import BotoCoreError, UnknownServiceError
from langchain_core.messages import AIMessage
from packaging import version
from pydantic import SecretStr

MESSAGE_ROLES = Literal["system", "user", "assistant"]
MESSAGE_FORMAT = Dict[Literal["role", "content"], Union[MESSAGE_ROLES, str]]

INPUT_TYPE = TypeVar(
    "INPUT_TYPE", bound=Union[str, List[str], MESSAGE_FORMAT, List[MESSAGE_FORMAT]]
)
OUTPUT_TYPE = TypeVar(
    "OUTPUT_TYPE",
    bound=Union[str, List[List[float]], MESSAGE_FORMAT, List[MESSAGE_FORMAT], Iterator],
)


[docs] class ContentHandlerBase(Generic[INPUT_TYPE, OUTPUT_TYPE]): """A handler class to transform input from LLM and BaseChatModel to a format that SageMaker endpoint expects. Similarly, the class handles transforming output from the SageMaker endpoint to a format that LLM & BaseChatModel class expects. """ """ Example: .. code-block:: python class ContentHandler(ContentHandlerBase): content_type = "application/json" accepts = "application/json" def transform_input(self, prompt: str, model_kwargs: Dict) -> bytes: input_str = json.dumps({prompt: prompt, **model_kwargs}) return input_str.encode('utf-8') def transform_output(self, output: bytes) -> str: response_json = json.loads(output.read().decode("utf-8")) return response_json[0]["generated_text"] """ content_type: Optional[str] = "text/plain" """The MIME type of the input data passed to endpoint""" accepts: Optional[str] = "text/plain" """The MIME type of the response data returned from endpoint"""
[docs] @abstractmethod def transform_input(self, prompt: INPUT_TYPE, model_kwargs: Dict) -> bytes: """Transforms the input to a format that model can accept as the request Body. Should return bytes or seekable file like object in the format specified in the content_type request header. """
[docs] @abstractmethod def transform_output(self, output: bytes) -> OUTPUT_TYPE: """Transforms the output from the model to string that the LLM class expects. """
[docs] def enforce_stop_tokens(text: str, stop: List[str]) -> str: """Cut off the text as soon as any stop words occur.""" return re.split("|".join(stop), text, maxsplit=1)[0]
[docs] def anthropic_tokens_supported() -> bool: """Check if all requirements for Anthropic count_tokens() are met.""" try: import anthropic except ImportError: return False if version.parse(anthropic.__version__) > version.parse("0.38.0"): return False try: import httpx if version.parse(httpx.__version__) > version.parse("0.27.2"): raise ImportError() except ImportError: raise ImportError("httpx<=0.27.2 is required.") return True
def _get_anthropic_client() -> Any: import anthropic return anthropic.Anthropic()
[docs] def get_num_tokens_anthropic(text: str) -> int: """Get the number of tokens in a string of text.""" client = _get_anthropic_client() return client.count_tokens(text=text)
[docs] def get_token_ids_anthropic(text: str) -> List[int]: """Get the token ids for a string of text.""" client = _get_anthropic_client() tokenizer = client.get_tokenizer() encoded_text = tokenizer.encode(text) return encoded_text.ids
[docs] def create_aws_client( service_name: str, region_name: Optional[str] = None, credentials_profile_name: Optional[str] = None, aws_access_key_id: Optional[SecretStr] = None, aws_secret_access_key: Optional[SecretStr] = None, aws_session_token: Optional[SecretStr] = None, endpoint_url: Optional[str] = None, config: Any = None, ): """Helper function to validate AWS credentials and create an AWS client. Args: service_name: The name of the AWS service to create a client for. region_name: AWS region name. If not provided, will try to get from environment variables. credentials_profile_name: The name of the AWS credentials profile to use. aws_access_key_id: AWS access key ID. aws_secret_access_key: AWS secret access key. aws_session_token: AWS session token. endpoint_url: The complete URL to use for the constructed client. config: Advanced client configuration options. Returns: boto3.client: An AWS service client instance. """ try: import boto3 region_name = ( region_name or os.getenv("AWS_REGION") or os.getenv("AWS_DEFAULT_REGION") ) client_params = { "service_name": service_name, "region_name": region_name, "endpoint_url": endpoint_url, "config": config, } client_params = { k: v for k, v in client_params.items() if v } needs_session = bool( credentials_profile_name or aws_access_key_id or aws_secret_access_key or aws_session_token ) if not needs_session: return boto3.client(**client_params) if credentials_profile_name: session = boto3.Session(profile_name=credentials_profile_name) elif aws_access_key_id and aws_secret_access_key: session_params = { "aws_access_key_id": aws_access_key_id.get_secret_value(), "aws_secret_access_key": aws_secret_access_key.get_secret_value(), } if aws_session_token: session_params["aws_session_token"] = aws_session_token.get_secret_value() session = boto3.Session(**session_params) else: raise ValueError( "If providing credentials, both aws_access_key_id and " "aws_secret_access_key must be specified." ) if not client_params.get("region_name") and session.region_name: client_params["region_name"] = session.region_name return session.client(**client_params) except UnknownServiceError as e: raise ModuleNotFoundError( f"Ensure that you have installed the latest boto3 package " f"that contains the API for `{service_name}`." ) from e except BotoCoreError as e: raise ValueError( "Could not load credentials to authenticate with AWS client. " "Please check that the specified profile name and/or its credentials are valid. " f"Service error: {e}" ) from e except Exception as e: raise ValueError(f"Error raised by service:\n\n{e}") from e
[docs] def thinking_in_params(params: dict) -> bool: """Check if the thinking parameter is enabled in the request.""" return params.get("thinking", {}).get("type") == "enabled"
[docs] def trim_message_whitespace(messages: List[Any]) -> List[Any]: """Trim trailing whitespace from final AIMessage content.""" if not messages or not isinstance(messages[-1], AIMessage): return messages last_message = messages[-1] if isinstance(last_message.content, str): trimmed = last_message.content.rstrip() if trimmed != last_message.content: last_message.content = trimmed elif isinstance(last_message.content, list): for j, block in enumerate(last_message.content): if isinstance(block, dict) and block.get("type") == "text" \ and isinstance(block.get("text"), str): trimmed = block["text"].rstrip() if trimmed != block["text"]: last_message.content[j]["text"] = trimmed return messages