Source code for langchain_aws.utils

import re
from typing import Any, List

from packaging import version


def enforce_stop_tokens(text: str, stop: List[str]) -> str:
    """Cut off the text as soon as any stop words occur.

    Args:
        text: The text to truncate.
        stop: Stop sequences, matched literally. Empty strings are ignored.

    Returns:
        The portion of ``text`` that precedes the earliest occurrence of any
        stop sequence, or ``text`` unchanged when no usable stop sequence is
        given or none of them occurs.
    """
    # Escape every token so characters such as "." or "(" are matched
    # literally instead of being interpreted as regex syntax, and drop empty
    # tokens: an empty alternative matches at position 0 and would truncate
    # the whole text to "".
    literal_tokens = [re.escape(token) for token in (stop or ()) if token]
    if not literal_tokens:
        # Nothing to match against; an empty pattern must not reach re.split.
        return text
    return re.split("|".join(literal_tokens), text, maxsplit=1)[0]
def anthropic_tokens_supported() -> bool:
    """Report whether the requirements for Anthropic count_tokens() are met.

    Returns:
        ``False`` when the ``anthropic`` package cannot be imported or its
        version is greater than 0.38.0; ``True`` when ``anthropic`` passes
        that check and ``httpx`` is importable at version 0.27.2 or lower.

    Raises:
        ImportError: If ``httpx`` cannot be imported or its version is
            greater than 0.27.2.
    """
    try:
        import anthropic
    except ImportError:
        return False

    installed_anthropic = version.parse(anthropic.__version__)
    if installed_anthropic > version.parse("0.38.0"):
        return False

    try:
        import httpx

        installed_httpx = version.parse(httpx.__version__)
        if installed_httpx > version.parse("0.27.2"):
            # Routed through the handler below so a missing httpx and a
            # too-new httpx surface with the same error message.
            raise ImportError()
    except ImportError:
        raise ImportError("httpx<=0.27.2 is required.")

    return True
def _get_anthropic_client() -> Any:
    """Build and return a new ``anthropic.Anthropic`` client.

    The ``anthropic`` package is imported inside the function, so it is only
    needed when this helper is actually called. The client is constructed
    with no arguments.
    """
    import anthropic

    client = anthropic.Anthropic()
    return client
def get_num_tokens_anthropic(text: str) -> int:
    """Count the tokens in a string of text.

    Args:
        text: The text to measure.

    Returns:
        Whatever the Anthropic client's ``count_tokens(text=...)`` call
        returns for ``text``.
    """
    return _get_anthropic_client().count_tokens(text=text)
def get_token_ids_anthropic(text: str) -> List[int]:
    """Return the token ids for a string of text.

    Args:
        text: The text to tokenize.

    Returns:
        The ``ids`` attribute of the encoding produced by the Anthropic
        client's tokenizer for ``text``.
    """
    anthropic_tokenizer = _get_anthropic_client().get_tokenizer()
    encoding = anthropic_tokenizer.encode(text)
    return encoding.ids