"""Semantic Cache for Azure CosmosDB NoSql and Mongo vCore API."""
from __future__ import annotations
import hashlib
import json
import logging
from enum import Enum
from typing import (
Any,
Dict,
List,
Optional,
Type,
Union,
)
from azure.cosmos import CosmosClient
from langchain_core.caches import RETURN_VAL_TYPE, BaseCache
from langchain_core.embeddings import Embeddings
from langchain_core.load.dump import dumps
from langchain_core.load.load import loads
from langchain_core.outputs import Generation
from langchain_azure_ai.vectorstores.azure_cosmos_db_mongo_vcore import (
AzureCosmosDBMongoVCoreVectorSearch,
CosmosDBSimilarityType,
CosmosDBVectorSearchType,
)
from langchain_azure_ai.vectorstores.azure_cosmos_db_no_sql import (
AzureCosmosDBNoSqlVectorSearch,
)
# Module-level logger named after the import path (``__name__``) so it joins
# the package logger hierarchy and inherits its handler/level configuration;
# ``__file__`` would produce a filesystem-path name outside that hierarchy.
logger = logging.getLogger(__name__)
def _hash(_input: str) -> str:
"""Use a deterministic hashing approach."""
return hashlib.md5(_input.encode()).hexdigest()
def _dump_generations_to_json(generations: RETURN_VAL_TYPE) -> str:
"""Dump generations to json.
Args:
generations (RETURN_VAL_TYPE): A list of language model generations.
Returns:
str: Json representing a list of generations.
Warning: would not work well with arbitrary subclasses of `Generation`
"""
return json.dumps([generation.dict() for generation in generations])
def _load_generations_from_json(generations_json: str) -> RETURN_VAL_TYPE:
"""Load generations from json.
Args:
generations_json (str): A string of json representing a list of generations.
Raises:
ValueError: Could not decode json string to list of generations.
Returns:
RETURN_VAL_TYPE: A list of generations.
Warning: would not work well with arbitrary subclasses of `Generation`
"""
try:
results = json.loads(generations_json)
return [Generation(**generation_dict) for generation_dict in results]
except json.JSONDecodeError:
raise ValueError(
f"Could not decode json to list of generations: {generations_json}"
)
def _dumps_generations(generations: RETURN_VAL_TYPE) -> str:
"""Serialization for generic RETURN_VAL_TYPE, i.e. sequence of `Generation`.
Args:
generations (RETURN_VAL_TYPE): A list of language model generations.
Returns:
str: a single string representing a list of generations.
This function (+ its counterpart `_loads_generations`) rely on
the dumps/loads pair with Reviver, so are able to deal
with all subclasses of Generation.
Each item in the list can be `dumps`ed to a string,
then we make the whole list of strings into a json-dumped.
"""
return json.dumps([dumps(_item) for _item in generations])
def _loads_generations(generations_str: str) -> Union[RETURN_VAL_TYPE, None]:
"""Deserialization of a string into a generic RETURN_VAL_TYPE.
See `_dumps_generations`, the inverse of this function.
Args:
generations_str (str): A string representing a list of generations.
Compatible with the legacy cache-blob format
Does not raise exceptions for malformed entries, just logs a warning
and returns none: the caller should be prepared for such a cache miss.
Returns:
RETURN_VAL_TYPE: A list of generations.
"""
try:
generations = [loads(_item_str) for _item_str in json.loads(generations_str)]
return generations
except (json.JSONDecodeError, TypeError):
# deferring the (soft) handling to after the legacy-format attempt
pass
try:
gen_dicts = json.loads(generations_str)
# not relying on `_load_generations_from_json` (which could disappear):
generations = [Generation(**generation_dict) for generation_dict in gen_dicts]
logger.warning(
f"Legacy 'Generation' cached blob encountered: '{generations_str}'"
)
return generations
except (json.JSONDecodeError, TypeError):
logger.warning(
f"Malformed/unparsable cached blob encountered: '{generations_str}'"
)
return None
class AzureCosmosDBMongoVCoreSemanticCache(BaseCache):
    """Cache that uses Cosmos DB Mongo vCore vector-store backend.

    Prompts are embedded with the configured ``Embeddings`` and stored with
    their serialized generations in one collection
    (``database_name.collection_name``) that is shared by every
    ``llm_string``; each ``llm_string`` gets its own memoized vector-store
    client and index name (``cache:<md5(llm_string)>``).
    """

    DEFAULT_DATABASE_NAME = "CosmosMongoVCoreCacheDB"
    DEFAULT_COLLECTION_NAME = "CosmosMongoVCoreCacheColl"

    def __init__(
        self,
        cosmosdb_connection_string: str,
        database_name: str,
        collection_name: str,
        embedding: Embeddings,
        *,
        cosmosdb_client: Optional[Any] = None,
        num_lists: int = 100,
        similarity: CosmosDBSimilarityType = CosmosDBSimilarityType.COS,
        kind: CosmosDBVectorSearchType = CosmosDBVectorSearchType.VECTOR_IVF,
        dimensions: int = 1536,
        m: int = 16,
        ef_construction: int = 64,
        max_degree: int = 32,
        l_build: int = 50,
        l_search: int = 40,
        ef_search: int = 40,
        score_threshold: Optional[float] = None,
        application_name: str = "LangChainAzure-CDBMongoVCore-SemanticCache-Python",
    ):
        """AzureCosmosDBMongoVCoreSemanticCache constructor.

        Args:
            cosmosdb_connection_string: Cosmos DB Mongo vCore connection string.
                Required (must be non-empty) even when ``cosmosdb_client`` is
                given; it is used to build vector stores only when no client
                is supplied.
            cosmosdb_client: Cosmos DB Mongo vCore client. When provided, the
                cache collection is taken from
                ``cosmosdb_client[database_name][collection_name]``.
            embedding (Embedding): Embedding provider for semantic encoding and search.
            database_name: Database name for the CosmosDBMongoVCoreSemanticCache.
                Falls back to ``DEFAULT_DATABASE_NAME`` when empty.
            collection_name: Collection name for the CosmosDBMongoVCoreSemanticCache.
                Falls back to ``DEFAULT_COLLECTION_NAME`` when empty.
            num_lists: This integer is the number of clusters that the
                inverted file (IVF) index uses to group the vector data.
                We recommend that numLists is set to documentCount/1000
                for up to 1 million documents and to sqrt(documentCount)
                for more than 1 million documents.
                Using a numLists value of 1 is akin to performing
                brute-force search, which has limited performance
            dimensions: Number of dimensions for vector similarity.
                The maximum number of supported dimensions is 2000
            similarity: Similarity metric to use with the IVF index.
                Possible options are:
                    - CosmosDBSimilarityType.COS (cosine distance),
                    - CosmosDBSimilarityType.L2 (Euclidean distance), and
                    - CosmosDBSimilarityType.IP (inner product).
            kind: Type of vector index to create.
                Possible options are:
                    - vector-ivf
                    - vector-hnsw
                    - vector-diskann: available as a preview feature only,
                      to enable visit https://learn.microsoft.com/en-us/azure/azure-resource-manager/management/preview-features
            m: The max number of connections per layer (16 by default, minimum
                value is 2, maximum value is 100). Higher m is suitable for datasets
                with high dimensionality and/or high accuracy requirements.
            ef_construction: the size of the dynamic candidate list for constructing
                the graph (64 by default, minimum value is 4, maximum
                value is 1000). Higher ef_construction will result in
                better index quality and higher accuracy, but it will
                also increase the time required to build the index.
                ef_construction has to be at least 2 * m
            ef_search: The size of the dynamic candidate list for search
                (40 by default). A higher value provides better
                recall at the cost of speed.
            max_degree: Max number of neighbors.
                Default value is 32, range from 20 to 2048.
                Only vector-diskann search supports this for now.
            l_build: l value for index building.
                Default value is 50, range from 10 to 500.
                Only vector-diskann search supports this for now.
            l_search: l value for index searching.
                Default value is 40, range from 10 to 10000.
                Only vector-diskann search supports this.
            score_threshold: Maximum score used to filter the vector search documents.
            application_name: Application name for the client for tracking and
                logging (only used when building from the connection string).

        Raises:
            ValueError: If ``similarity``/``kind`` are not members of their
                enums, or if ``cosmosdb_connection_string`` is empty.
        """
        self._validate_enum_value(similarity, CosmosDBSimilarityType)
        self._validate_enum_value(kind, CosmosDBVectorSearchType)

        if not cosmosdb_connection_string:
            raise ValueError("CosmosDB connection string cannot be empty.")

        self.cosmosdb_connection_string = cosmosdb_connection_string
        self.cosmosdb_client = cosmosdb_client
        self.embedding = embedding
        # Empty names fall back to the class-level defaults.
        self.database_name = database_name or self.DEFAULT_DATABASE_NAME
        self.collection_name = collection_name or self.DEFAULT_COLLECTION_NAME
        self.num_lists = num_lists
        self.dimensions = dimensions
        self.similarity = similarity
        self.kind = kind
        self.m = m
        self.ef_construction = ef_construction
        self.max_degree = max_degree
        self.l_build = l_build
        self.l_search = l_search
        self.ef_search = ef_search
        self.score_threshold = score_threshold
        # Memoized vector-store clients keyed by index name (one per llm_string).
        self._cache_dict: Dict[str, AzureCosmosDBMongoVCoreVectorSearch] = {}
        self.application_name = application_name

    def _index_name(self, llm_string: str) -> str:
        """Return the index name for ``llm_string``: ``cache:<md5 hex digest>``."""
        hashed_index = _hash(llm_string)
        return f"cache:{hashed_index}"

    def _get_llm_cache(self, llm_string: str) -> AzureCosmosDBMongoVCoreVectorSearch:
        """Return the vector store for ``llm_string``, creating it on first use.

        A new store is built from ``cosmosdb_client`` when one was supplied,
        otherwise from the connection string. The store is memoized in
        ``_cache_dict`` and its vector index is created with the configured
        parameters if ``index_exists()`` reports it missing.
        """
        index_name = self._index_name(llm_string)

        # return vectorstore client for the specific llm string
        if index_name in self._cache_dict:
            return self._cache_dict[index_name]

        # create new vectorstore client for the specific llm string
        if self.cosmosdb_client:
            collection = self.cosmosdb_client[self.database_name][self.collection_name]
            vectorstore = AzureCosmosDBMongoVCoreVectorSearch(
                collection=collection,
                embedding=self.embedding,
                index_name=index_name,
            )
        else:
            namespace = self.database_name + "." + self.collection_name
            vectorstore = AzureCosmosDBMongoVCoreVectorSearch.from_connection_string(
                connection_string=self.cosmosdb_connection_string,
                namespace=namespace,
                embedding=self.embedding,
                index_name=index_name,
                application_name=self.application_name,
            )
        self._cache_dict[index_name] = vectorstore

        # create index for the vectorstore
        if not vectorstore.index_exists():
            vectorstore.create_index(
                self.num_lists,
                self.dimensions,
                self.similarity,
                self.kind,
                self.m,
                self.ef_construction,
                self.max_degree,
                self.l_build,
            )
        return vectorstore

    def lookup(self, prompt: str, llm_string: str) -> Optional[RETURN_VAL_TYPE]:
        """Look up based on prompt and llm_string.

        Runs a top-1 similarity search for ``prompt`` and decodes the stored
        ``return_val`` of the hit. Hits written for a different ``llm_string``
        are ignored: every ``llm_string`` shares one collection, so the
        nearest neighbour may belong to another LLM configuration.

        Returns:
            The cached generations, or None on a cache miss.

        Raises:
            Whatever ``_load_generations_from_json`` raises when a stored
            value can be decoded by neither the current nor the legacy format.
        """
        llm_cache = self._get_llm_cache(llm_string)
        generations: List = []

        results = llm_cache.similarity_search(
            query=prompt,
            k=1,
            kind=self.kind,
            ef_search=self.ef_search,
            l_search=self.l_search,
            score_threshold=self.score_threshold,  # type: ignore[arg-type]
        )
        for document in results or []:
            metadata = document.metadata
            # Do not serve another model's answer from the shared collection.
            if metadata.get("llm_string") != llm_string:
                continue
            try:
                generations.extend(loads(metadata["return_val"]))
            except Exception:
                logger.warning(
                    "Retrieving a cache value that could not be deserialized "
                    "properly. This is likely due to the cache being in an "
                    "older format. Please recreate your cache to avoid this "
                    "error."
                )
                # In a previous life we stored the raw text directly
                # in the table, so assume it's in that format.
                generations.extend(
                    _load_generations_from_json(metadata["return_val"])
                )
        return generations if generations else None

    def update(self, prompt: str, llm_string: str, return_val: RETURN_VAL_TYPE) -> None:
        """Update cache based on prompt and llm_string.

        Stores ``prompt`` as the embedded text, with ``llm_string``, ``prompt``
        and the ``dumps``-serialized generations as metadata.

        Raises:
            ValueError: If any item of ``return_val`` is not a ``Generation``.
        """
        for gen in return_val:
            if not isinstance(gen, Generation):
                raise ValueError(
                    "CosmosDBMongoVCoreSemanticCache only supports caching of "
                    f"normal LLM generations, got {type(gen)}"
                )
        llm_cache = self._get_llm_cache(llm_string)
        metadata = {
            "llm_string": llm_string,
            "prompt": prompt,
            "return_val": dumps(list(return_val)),
        }
        llm_cache.add_texts(texts=[prompt], metadatas=[metadata])

    def clear(self, **kwargs: Any) -> None:
        """Clear semantic cache; requires an ``llm_string`` keyword argument.

        Only acts when a vector store for ``llm_string`` was already created
        by this instance (it is in ``_cache_dict``); otherwise it is a no-op.

        NOTE: ``delete_many({})`` uses an empty filter, so it removes every
        document of the collection returned by ``get_collection()`` --
        presumably the collection shared by all ``llm_string`` values, not
        only the entries written for ``llm_string``.

        Raises:
            KeyError: If ``llm_string`` is not passed.
        """
        index_name = self._index_name(kwargs["llm_string"])
        if index_name in self._cache_dict:
            self._cache_dict[index_name].get_collection().delete_many({})

    @staticmethod
    def _validate_enum_value(value: Any, enum_type: Type[Enum]) -> None:
        """Raise ValueError unless ``value`` is a member of ``enum_type``."""
        if not isinstance(value, enum_type):
            raise ValueError(f"Invalid enum value: {value}. Expected {enum_type}.")
class AzureCosmosDBNoSqlSemanticCache(BaseCache):
    """Cache that uses Cosmos DB NoSQL backend.

    Every ``llm_string`` gets its own memoized vector-store client, but all of
    them target the same ``database_name``/``container_name``.
    """

    def __init__(
        self,
        embedding: Embeddings,
        cosmos_client: CosmosClient,
        database_name: str = "CosmosNoSqlCacheDB",
        container_name: str = "CosmosNoSqlCacheContainer",
        *,
        vector_embedding_policy: Dict[str, Any],
        indexing_policy: Dict[str, Any],
        cosmos_container_properties: Dict[str, Any],
        cosmos_database_properties: Dict[str, Any],
        vector_search_fields: Dict[str, Any],
        search_type: str = "vector",
        create_container: bool = True,
    ):
        """AzureCosmosDBNoSqlSemanticCache constructor.

        Args:
            embedding: CosmosDB Embedding.
            cosmos_client: CosmosDB client
            database_name: CosmosDB database name
            container_name: CosmosDB container name
            vector_embedding_policy: CosmosDB vector embedding policy
            indexing_policy: CosmosDB indexing policy
            cosmos_container_properties: CosmosDB container properties
            cosmos_database_properties: CosmosDB database properties
            vector_search_fields: Vector Search Fields for the container.
            search_type: CosmosDB search type.
            create_container: Create the container if it doesn't exist.
        """
        self.cosmos_client = cosmos_client
        self.database_name = database_name
        self.container_name = container_name
        self.embedding = embedding
        self.vector_embedding_policy = vector_embedding_policy
        self.indexing_policy = indexing_policy
        self.cosmos_container_properties = cosmos_container_properties
        self.cosmos_database_properties = cosmos_database_properties
        self.vector_search_fields = vector_search_fields
        self.search_type = search_type
        self.create_container = create_container
        # Memoized vector-store clients keyed by cache name (one per llm_string).
        self._cache_dict: Dict[str, AzureCosmosDBNoSqlVectorSearch] = {}

    def _cache_name(self, llm_string: str) -> str:
        """Return the cache key for ``llm_string``: ``cache:<md5 hex digest>``."""
        hashed_index = _hash(llm_string)
        return f"cache:{hashed_index}"

    def _get_llm_cache(self, llm_string: str) -> AzureCosmosDBNoSqlVectorSearch:
        """Return the vector store for ``llm_string``, creating it on first use.

        Raises:
            ValueError: If no ``cosmos_client`` is configured.
        """
        cache_name = self._cache_name(llm_string)

        # return vectorstore client for the specific llm string
        if cache_name in self._cache_dict:
            return self._cache_dict[cache_name]

        # Without a client nothing can be built; fail with a clear message
        # instead of a KeyError on _cache_dict.
        if not self.cosmos_client:
            raise ValueError(
                "AzureCosmosDBNoSqlSemanticCache requires a cosmos_client."
            )

        # create new vectorstore client to create the cache
        vectorstore = AzureCosmosDBNoSqlVectorSearch(
            cosmos_client=self.cosmos_client,
            embedding=self.embedding,
            vector_embedding_policy=self.vector_embedding_policy,
            indexing_policy=self.indexing_policy,
            cosmos_container_properties=self.cosmos_container_properties,
            cosmos_database_properties=self.cosmos_database_properties,
            database_name=self.database_name,
            container_name=self.container_name,
            search_type=self.search_type,
            vector_search_fields=self.vector_search_fields,
            create_container=self.create_container,
        )
        self._cache_dict[cache_name] = vectorstore
        return vectorstore

    def lookup(self, prompt: str, llm_string: str) -> Optional[RETURN_VAL_TYPE]:
        """Look up based on prompt.

        Runs a top-1 similarity search for ``prompt`` and decodes the stored
        ``return_val`` of the hit. Hits written for a different ``llm_string``
        are ignored: all ``llm_string`` values share one container, so the
        nearest neighbour may belong to another LLM configuration.

        Returns:
            The cached generations, or None on a cache miss.
        """
        llm_cache = self._get_llm_cache(llm_string)
        generations: List = []

        results = llm_cache.similarity_search(
            query=prompt,
            k=1,
        )
        for document in results or []:
            metadata = document.metadata
            # Do not serve another model's answer from the shared container.
            if metadata.get("llm_string") != llm_string:
                continue
            try:
                generations.extend(loads(metadata["return_val"]))
            except Exception:
                logger.warning(
                    "Retrieving a cache value that could not be deserialized "
                    "properly. This is likely due to the cache being in an "
                    "older format. Please recreate your cache to avoid this "
                    "error."
                )
                # Fall back to the legacy plain-JSON format.
                generations.extend(
                    _load_generations_from_json(metadata["return_val"])
                )
        return generations if generations else None

    def update(self, prompt: str, llm_string: str, return_val: RETURN_VAL_TYPE) -> None:
        """Update cache based on prompt and llm_string.

        Stores ``prompt`` as the embedded text, with ``llm_string``, ``prompt``
        and the ``dumps``-serialized generations as metadata.

        Raises:
            ValueError: If any item of ``return_val`` is not a ``Generation``.
        """
        for gen in return_val:
            if not isinstance(gen, Generation):
                raise ValueError(
                    "CosmosDBNoSqlSemanticCache only supports caching of "
                    f"normal LLM generations, got {type(gen)}"
                )
        llm_cache = self._get_llm_cache(llm_string)
        metadata = {
            "llm_string": llm_string,
            "prompt": prompt,
            "return_val": dumps(list(return_val)),
        }
        llm_cache.add_texts(texts=[prompt], metadatas=[metadata])

    def clear(self, **kwargs: Any) -> None:
        """Clear semantic cache; requires an ``llm_string`` keyword argument.

        Only acts when a vector store for ``llm_string`` was already created
        by this instance; otherwise it is a no-op. It deletes the whole
        ``database_name`` database, which removes the entries of every
        ``llm_string`` (they all share one container).

        Raises:
            KeyError: If ``llm_string`` is not passed.
        """
        cache_name = self._cache_name(llm_string=kwargs["llm_string"])
        if cache_name in self._cache_dict:
            self.cosmos_client.delete_database(database=self.database_name)
            # Every memoized client points into the database just deleted.
            # Forget them all so the next lookup/update builds a fresh
            # AzureCosmosDBNoSqlVectorSearch instead of reusing a stale handle
            # (presumably that constructor re-provisions the database/container
            # when create_container is True -- confirm).
            self._cache_dict.clear()