Source code for langchain_aws.utils
import re
from typing import Any, List
from packaging import version
[docs]
def enforce_stop_tokens(text: str, stop: List[str]) -> str:
    """Cut off the text as soon as any stop words occur.

    Args:
        text: The text to truncate.
        stop: Stop words, matched literally. Empty strings are ignored.

    Returns:
        The part of ``text`` that precedes the earliest stop word, or ``text``
        unchanged when no stop word occurs or ``stop`` has no usable entries.
    """
    # Escape every stop word so regex metacharacters (".", "(", "|", ...) are
    # matched literally instead of being interpreted as part of a pattern.
    escaped = [re.escape(word) for word in stop if word]
    if not escaped:
        # An empty pattern matches at position 0 and would truncate to "".
        return text
    return re.split("|".join(escaped), text, maxsplit=1)[0]
[docs]
def anthropic_tokens_supported() -> bool:
    """Check if all requirements for Anthropic count_tokens() are met.

    Returns:
        ``True`` when ``anthropic`` is importable at version 0.38.0 or lower
        and ``httpx`` is importable at version 0.27.2 or lower. ``False`` when
        ``anthropic`` is not importable or is newer than 0.38.0.

    Raises:
        ImportError: If ``anthropic`` passes the check but ``httpx`` is
            missing or newer than 0.27.2.
    """
    try:
        import anthropic
    except ImportError:
        return False

    if version.parse(anthropic.__version__) > version.parse("0.38.0"):
        return False

    httpx_requirement = "httpx<=0.27.2 is required."
    try:
        import httpx
    except ImportError as err:
        # Keep the original import failure attached as the cause.
        raise ImportError(httpx_requirement) from err

    if version.parse(httpx.__version__) > version.parse("0.27.2"):
        raise ImportError(httpx_requirement)

    return True
def _get_anthropic_client() -> Any:
    """Build and return a new ``anthropic.Anthropic`` client.

    The ``anthropic`` package is imported lazily, at call time, rather than
    when this module is loaded.
    """
    from anthropic import Anthropic

    return Anthropic()
[docs]
def get_num_tokens_anthropic(text: str) -> int:
    """Return the token count of ``text``.

    The count is whatever ``count_tokens`` on the client returned by
    ``_get_anthropic_client()`` reports for the given text.
    """
    return _get_anthropic_client().count_tokens(text=text)
[docs]
def get_token_ids_anthropic(text: str) -> List[int]:
    """Return the token ids for ``text``.

    The text is encoded with the tokenizer obtained from the client returned
    by ``_get_anthropic_client()``; the ``ids`` attribute of the encoding is
    returned.
    """
    encoding = _get_anthropic_client().get_tokenizer().encode(text)
    return encoding.ids