Source code for langchain_aws.document_compressors.rerank

from copy import deepcopy
from typing import Any, Dict, List, Optional, Sequence, Union

import boto3
from langchain_core.callbacks.manager import Callbacks
from langchain_core.documents import BaseDocumentCompressor, Document
from langchain_core.utils import from_env
from pydantic import ConfigDict, Field, model_validator


class BedrockRerank(BaseDocumentCompressor):
    """Document compressor that uses AWS Bedrock Rerank API.

    Documents are sent inline to the ``rerank`` operation of a
    ``bedrock-agent-runtime`` client, and the returned indices and relevance
    scores are used to pick and annotate the most relevant documents.
    """

    model_arn: str
    """The ARN of the reranker model."""
    client: Any = None
    """Bedrock client to use for compressing documents."""
    top_n: Optional[int] = 3
    """Number of documents to return."""
    # Optional because the default resolves to None when the environment
    # variable is not set.
    region_name: Optional[str] = Field(
        default_factory=from_env("AWS_DEFAULT_REGION", default=None)
    )
    """AWS region to initialize the Bedrock client."""
    credentials_profile_name: Optional[str] = Field(
        default_factory=from_env("AWS_PROFILE", default=None)
    )
    """AWS profile for authentication, optional."""

    model_config = ConfigDict(
        extra="forbid",
        arbitrary_types_allowed=True,
    )

    @model_validator(mode="before")
    @classmethod
    def initialize_client(cls, values: Dict[str, Any]) -> Any:
        """Initialize the AWS Bedrock client.

        Runs in ``before`` mode on the raw input values. When no truthy
        ``client`` is supplied, a ``bedrock-agent-runtime`` client is created
        from a ``boto3.Session`` (using ``credentials_profile_name`` when it
        is given) and stored under ``client``.

        Args:
            values: Raw values passed to the model constructor.

        Returns:
            The input values, with ``client`` populated when it was missing.
        """
        if values.get("client"):
            return values

        # Work on a shallow copy so the caller's mapping is not mutated.
        values = dict(values)
        profile_name = values.get("credentials_profile_name")
        session = (
            boto3.Session(profile_name=profile_name)
            if profile_name
            else boto3.Session()
        )
        values["client"] = session.client(
            "bedrock-agent-runtime",
            region_name=values.get("region_name"),
        )
        return values

    @staticmethod
    def _serialize_document(doc: Union[str, Document, dict]) -> Dict[str, Any]:
        """Convert one input document into the inline-document request shape."""
        if isinstance(doc, Document):
            return {"textDocument": {"text": doc.page_content}, "type": "TEXT"}
        if isinstance(doc, str):
            return {"textDocument": {"text": doc}, "type": "TEXT"}
        # Anything else is passed through as a JSON document.
        return {"jsonDocument": doc, "type": "JSON"}

    def rerank(
        self,
        documents: Sequence[Union[str, Document, dict]],
        query: str,
        top_n: Optional[int] = None,
        additional_model_request_fields: Optional[Dict[str, Any]] = None,
    ) -> List[Dict[str, Any]]:
        """Returns an ordered list of documents based on their relevance to the query.

        Args:
            query: The query to use for reranking.
            documents: A sequence of documents to rerank.
            top_n: The number of top-ranked results to return. Defaults to
                self.top_n.
            additional_model_request_fields: Additional fields to pass to the
                model.

        Returns:
            List[Dict[str, Any]]: One dict per result with the keys ``index``
            (position of the document in ``documents``) and
            ``relevance_score``. Empty when ``documents`` is empty.
        """
        if not documents:
            return []

        bedrock_configuration: Dict[str, Any] = {
            "modelConfiguration": {
                "modelArn": self.model_arn,
                "additionalModelRequestFields": additional_model_request_fields
                or {},
            },
        }

        # Only send ``numberOfResults`` when a limit is actually configured;
        # passing None would make the request invalid. The limit is capped at
        # the number of documents because more results than sources can never
        # be returned.
        number_of_results = top_n or self.top_n
        if number_of_results is not None:
            bedrock_configuration["numberOfResults"] = min(
                number_of_results, len(documents)
            )

        request_body: Dict[str, Any] = {
            "queries": [{"textQuery": {"text": query}, "type": "TEXT"}],
            "rerankingConfiguration": {
                "bedrockRerankingConfiguration": bedrock_configuration,
                "type": "BEDROCK_RERANKING_MODEL",
            },
            "sources": [
                {
                    "inlineDocumentSource": self._serialize_document(doc),
                    "type": "INLINE",
                }
                for doc in documents
            ],
        }

        response = self.client.rerank(**request_body)
        return [
            {"index": result["index"], "relevance_score": result["relevanceScore"]}
            for result in response.get("results", [])
        ]

    def compress_documents(
        self,
        documents: Sequence[Document],
        query: str,
        callbacks: Optional[Callbacks] = None,
    ) -> Sequence[Document]:
        """Compress documents using Bedrock's rerank API.

        Args:
            documents: A sequence of documents to compress.
            query: The query to use for compressing the documents.
            callbacks: Callbacks to run during the compression process.
                Accepted for interface compatibility; not used here.

        Returns:
            A sequence of compressed documents: copies of the selected input
            documents, in the order returned by ``rerank``, each with a
            ``relevance_score`` entry added to its metadata.
        """
        compressed = []
        for result in self.rerank(documents, query):
            source = documents[result["index"]]
            # Deep-copy the metadata so the caller's documents stay untouched.
            metadata = deepcopy(source.metadata)
            metadata["relevance_score"] = result["relevance_score"]
            compressed.append(
                Document(page_content=source.page_content, metadata=metadata)
            )
        return compressed