Source code for langchain_core.language_models.fake_chat_models

"""Fake ChatModel for testing purposes."""

import asyncio
import re
import time
from collections.abc import AsyncIterator, Iterator
from typing import Any, Optional, Union, cast

from typing_extensions import override

from langchain_core.callbacks import (
    AsyncCallbackManagerForLLMRun,
    CallbackManagerForLLMRun,
)
from langchain_core.language_models.chat_models import BaseChatModel, SimpleChatModel
from langchain_core.messages import AIMessage, AIMessageChunk, BaseMessage
from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
from langchain_core.runnables import RunnableConfig


class FakeMessagesListChatModel(BaseChatModel):
    """Fake ChatModel for testing purposes."""

    responses: list[BaseMessage]
    """List of responses to **cycle** through in order."""

    sleep: Optional[float] = None
    """Sleep time in seconds between responses."""

    i: int = 0
    """Internally incremented after every model invocation."""

    @override
    def _generate(
        self,
        messages: list[BaseMessage],
        stop: Optional[list[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> ChatResult:
        if self.sleep is not None:
            time.sleep(self.sleep)
        response = self.responses[self.i]
        if self.i < len(self.responses) - 1:
            self.i += 1
        else:
            self.i = 0
        generation = ChatGeneration(message=response)
        return ChatResult(generations=[generation])

    @property
    @override
    def _llm_type(self) -> str:
        return "fake-messages-list-chat-model"
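
# Usage sketch (editorial addition, not part of the original module): illustrates the
# response-cycling behavior documented on the ``responses`` field. The function name
# and the message contents are arbitrary example values.
def _example_fake_messages_list_chat_model() -> None:
    model = FakeMessagesListChatModel(
        responses=[AIMessage(content="first"), AIMessage(content="second")]
    )
    assert model.invoke("hi").content == "first"
    assert model.invoke("hi").content == "second"
    # After the last response, the internal index wraps back to the start.
    assert model.invoke("hi").content == "first"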

class FakeListChatModelError(Exception):
    """Fake error for testing purposes."""

class FakeListChatModel(SimpleChatModel):
    """Fake ChatModel for testing purposes."""

    responses: list[str]
    """List of responses to **cycle** through in order."""

    sleep: Optional[float] = None
    """Sleep time in seconds between responses."""

    i: int = 0
    """Internally incremented after every model invocation."""

    error_on_chunk_number: Optional[int] = None
    """If set, raise an error on the specified chunk number during streaming."""

    @property
    @override
    def _llm_type(self) -> str:
        return "fake-list-chat-model"

    @override
    def _call(
        self,
        messages: list[BaseMessage],
        stop: Optional[list[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> str:
        """Return the next response, cycling back to the start of the list."""
        if self.sleep is not None:
            time.sleep(self.sleep)
        response = self.responses[self.i]
        if self.i < len(self.responses) - 1:
            self.i += 1
        else:
            self.i = 0
        return response

    @override
    def _stream(
        self,
        messages: list[BaseMessage],
        stop: Union[list[str], None] = None,
        run_manager: Union[CallbackManagerForLLMRun, None] = None,
        **kwargs: Any,
    ) -> Iterator[ChatGenerationChunk]:
        response = self.responses[self.i]
        if self.i < len(self.responses) - 1:
            self.i += 1
        else:
            self.i = 0
        for i_c, c in enumerate(response):
            if self.sleep is not None:
                time.sleep(self.sleep)
            if (
                self.error_on_chunk_number is not None
                and i_c == self.error_on_chunk_number
            ):
                raise FakeListChatModelError
            yield ChatGenerationChunk(message=AIMessageChunk(content=c))

    @override
    async def _astream(
        self,
        messages: list[BaseMessage],
        stop: Union[list[str], None] = None,
        run_manager: Union[AsyncCallbackManagerForLLMRun, None] = None,
        **kwargs: Any,
    ) -> AsyncIterator[ChatGenerationChunk]:
        response = self.responses[self.i]
        if self.i < len(self.responses) - 1:
            self.i += 1
        else:
            self.i = 0
        for i_c, c in enumerate(response):
            if self.sleep is not None:
                await asyncio.sleep(self.sleep)
            if (
                self.error_on_chunk_number is not None
                and i_c == self.error_on_chunk_number
            ):
                raise FakeListChatModelError
            yield ChatGenerationChunk(message=AIMessageChunk(content=c))

    @property
    @override
    def _identifying_params(self) -> dict[str, Any]:
        return {"responses": self.responses}

    # Manually override batch and abatch to preserve batch ordering with no concurrency.
    @override
    def batch(
        self,
        inputs: list[Any],
        config: Optional[Union[RunnableConfig, list[RunnableConfig]]] = None,
        *,
        return_exceptions: bool = False,
        **kwargs: Any,
    ) -> list[BaseMessage]:
        if isinstance(config, list):
            return [self.invoke(m, c, **kwargs) for m, c in zip(inputs, config)]
        return [self.invoke(m, config, **kwargs) for m in inputs]

    @override
    async def abatch(
        self,
        inputs: list[Any],
        config: Optional[Union[RunnableConfig, list[RunnableConfig]]] = None,
        *,
        return_exceptions: bool = False,
        **kwargs: Any,
    ) -> list[BaseMessage]:
        if isinstance(config, list):
            # Do not use an async iterator here because we need explicit ordering.
            return [await self.ainvoke(m, c, **kwargs) for m, c in zip(inputs, config)]
        # Do not use an async iterator here because we need explicit ordering.
        return [await self.ainvoke(m, config, **kwargs) for m in inputs]
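
# Usage sketch (editorial addition, not part of the original module): shows how
# FakeListChatModel cycles through string responses and streams them one character
# per chunk. The function name and the response strings are arbitrary example values.
def _example_fake_list_chat_model() -> None:
    model = FakeListChatModel(responses=["alpha", "beta"])
    assert model.invoke("anything").content == "alpha"
    assert model.invoke("anything").content == "beta"
    # The index has wrapped, so streaming now replays the first response,
    # yielding one character per chunk.
    assert [c.content for c in model.stream("anything")] == list("alpha")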

class FakeChatModel(SimpleChatModel):
    """Fake Chat Model wrapper for testing purposes."""

    @override
    def _call(
        self,
        messages: list[BaseMessage],
        stop: Optional[list[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> str:
        return "fake response"

    @override
    async def _agenerate(
        self,
        messages: list[BaseMessage],
        stop: Optional[list[str]] = None,
        run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> ChatResult:
        output_str = "fake response"
        message = AIMessage(content=output_str)
        generation = ChatGeneration(message=message)
        return ChatResult(generations=[generation])

    @property
    def _llm_type(self) -> str:
        return "fake-chat-model"

    @property
    def _identifying_params(self) -> dict[str, Any]:
        return {"key": "fake"}

class GenericFakeChatModel(BaseChatModel):
    """Generic fake chat model that can be used to test the chat model interface.

    * Chat model should be usable in both sync and async tests
    * Invokes on_llm_new_token to allow for testing of callback related code for new
      tokens.
    * Includes logic to break messages into message chunks to facilitate testing of
      streaming.
    """

    messages: Iterator[Union[AIMessage, str]]
    """Get an iterator over messages.

    This can be expanded to accept other types like Callables / dicts / strings
    to make the interface more generic if needed.

    .. note::
        If you want to pass a list, you can use ``iter`` to convert it to an iterator.

    .. warning::
        Streaming is not implemented yet. We should try to implement it in the future
        by delegating to invoke and then breaking the resulting output into message
        chunks.
    """

    @override
    def _generate(
        self,
        messages: list[BaseMessage],
        stop: Optional[list[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> ChatResult:
        """Top Level call."""
        message = next(self.messages)
        message_ = AIMessage(content=message) if isinstance(message, str) else message
        generation = ChatGeneration(message=message_)
        return ChatResult(generations=[generation])

    def _stream(
        self,
        messages: list[BaseMessage],
        stop: Optional[list[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> Iterator[ChatGenerationChunk]:
        """Stream the output of the model."""
        chat_result = self._generate(
            messages, stop=stop, run_manager=run_manager, **kwargs
        )
        if not isinstance(chat_result, ChatResult):
            msg = (
                f"Expected generate to return a ChatResult, "
                f"but got {type(chat_result)} instead."
            )
            raise ValueError(msg)  # noqa: TRY004

        message = chat_result.generations[0].message

        if not isinstance(message, AIMessage):
            msg = (
                f"Expected invoke to return an AIMessage, "
                f"but got {type(message)} instead."
            )
            raise ValueError(msg)  # noqa: TRY004

        content = message.content

        if content:
            # Use a regular expression to split on whitespace with a capture group
            # so that we can preserve the whitespace in the output.
            if not isinstance(content, str):
                msg = "Expected content to be a string."
                raise ValueError(msg)

            content_chunks = cast("list[str]", re.split(r"(\s)", content))

            for token in content_chunks:
                chunk = ChatGenerationChunk(
                    message=AIMessageChunk(content=token, id=message.id)
                )
                if run_manager:
                    run_manager.on_llm_new_token(token, chunk=chunk)
                yield chunk

        if message.additional_kwargs:
            for key, value in message.additional_kwargs.items():
                # We should further break down the additional kwargs into chunks
                # Special case for function call
                if key == "function_call":
                    for fkey, fvalue in value.items():
                        if isinstance(fvalue, str):
                            # Break function call by `,`
                            fvalue_chunks = cast("list[str]", re.split(r"(,)", fvalue))
                            for fvalue_chunk in fvalue_chunks:
                                chunk = ChatGenerationChunk(
                                    message=AIMessageChunk(
                                        id=message.id,
                                        content="",
                                        additional_kwargs={
                                            "function_call": {fkey: fvalue_chunk}
                                        },
                                    )
                                )
                                if run_manager:
                                    run_manager.on_llm_new_token(
                                        "",
                                        chunk=chunk,  # No token for function call
                                    )
                                yield chunk
                        else:
                            chunk = ChatGenerationChunk(
                                message=AIMessageChunk(
                                    id=message.id,
                                    content="",
                                    additional_kwargs={"function_call": {fkey: fvalue}},
                                )
                            )
                            if run_manager:
                                run_manager.on_llm_new_token(
                                    "",
                                    chunk=chunk,  # No token for function call
                                )
                            yield chunk
                else:
                    chunk = ChatGenerationChunk(
                        message=AIMessageChunk(
                            id=message.id, content="", additional_kwargs={key: value}
                        )
                    )
                    if run_manager:
                        run_manager.on_llm_new_token(
                            "",
                            chunk=chunk,  # No token for function call
                        )
                    yield chunk

    @property
    def _llm_type(self) -> str:
        return "generic-fake-chat-model"
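
# Usage sketch (editorial addition, not part of the original module): GenericFakeChatModel
# consumes an iterator of messages and, when streaming, re-splits the text on whitespace
# so each chunk is either a word or the whitespace between words. The function name and
# message contents are arbitrary example values.
def _example_generic_fake_chat_model() -> None:
    model = GenericFakeChatModel(
        messages=iter(["hello world", AIMessage(content="goodbye")])
    )
    tokens = [chunk.content for chunk in model.stream("ignored input")]
    assert tokens == ["hello", " ", "world"]
    # The next call pulls the next item from the iterator.
    assert model.invoke("ignored input").content == "goodbye"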

class ParrotFakeChatModel(BaseChatModel):
    """Fake chat model that echoes back the last input message.

    * Chat model should be usable in both sync and async tests
    """

    @override
    def _generate(
        self,
        messages: list[BaseMessage],
        stop: Optional[list[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> ChatResult:
        """Top Level call."""
        return ChatResult(generations=[ChatGeneration(message=messages[-1])])

    @property
    def _llm_type(self) -> str:
        return "parrot-fake-chat-model"
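
# Usage sketch (editorial addition, not part of the original module): the parrot model
# returns the last input message unchanged, so invoking it with a string echoes that
# string back. The function name and input value are arbitrary example values.
def _example_parrot_fake_chat_model() -> None:
    echoed = ParrotFakeChatModel().invoke("polly")
    assert echoed.content == "polly"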