from __future__ import annotations
import asyncio
import base64
import json
import logging
import mimetypes
import uuid
import warnings
from difflib import get_close_matches
from operator import itemgetter
from typing import (
Any,
AsyncIterator,
Callable,
Dict,
Iterator,
List,
Mapping,
Optional,
Sequence,
Tuple,
Type,
Union,
cast,
)
import filetype # type: ignore[import]
import google.api_core
# TODO: remove ignore once the google package is published with types
import proto # type: ignore[import]
from google.ai.generativelanguage_v1beta import (
GenerativeServiceAsyncClient as v1betaGenerativeServiceAsyncClient,
)
from google.ai.generativelanguage_v1beta.types import (
Blob,
Candidate,
CodeExecution,
Content,
FileData,
FunctionCall,
FunctionDeclaration,
FunctionResponse,
GenerateContentRequest,
GenerateContentResponse,
GenerationConfig,
Part,
SafetySetting,
ToolConfig,
VideoMetadata,
)
from google.ai.generativelanguage_v1beta.types import Tool as GoogleTool
from langchain_core.callbacks.manager import (
AsyncCallbackManagerForLLMRun,
CallbackManagerForLLMRun,
)
from langchain_core.language_models import LanguageModelInput
from langchain_core.language_models.chat_models import BaseChatModel, LangSmithParams
from langchain_core.messages import (
AIMessage,
AIMessageChunk,
BaseMessage,
FunctionMessage,
HumanMessage,
SystemMessage,
ToolMessage,
is_data_content_block,
)
from langchain_core.messages.ai import UsageMetadata
from langchain_core.messages.tool import invalid_tool_call, tool_call, tool_call_chunk
from langchain_core.output_parsers.base import OutputParserLike
from langchain_core.output_parsers.openai_tools import (
JsonOutputKeyToolsParser,
PydanticToolsParser,
parse_tool_calls,
)
from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
from langchain_core.runnables import Runnable, RunnableConfig, RunnablePassthrough
from langchain_core.tools import BaseTool
from langchain_core.utils.function_calling import convert_to_openai_tool
from pydantic import (
BaseModel,
ConfigDict,
Field,
SecretStr,
model_validator,
)
from tenacity import (
before_sleep_log,
retry,
retry_if_exception_type,
stop_after_attempt,
wait_exponential,
)
from typing_extensions import Self, is_typeddict
from langchain_google_genai._common import (
GoogleGenerativeAIError,
SafetySettingDict,
_BaseGoogleGenerativeAI,
get_client_info,
)
from langchain_google_genai._function_utils import (
_tool_choice_to_tool_config,
_ToolChoiceType,
_ToolConfigDict,
_ToolDict,
convert_to_genai_function_declarations,
is_basemodel_subclass_safe,
tool_to_dict,
)
from langchain_google_genai._image_utils import (
ImageBytesLoader,
image_bytes_to_b64_string,
)
from . import _genai_extension as genaix
logger = logging.getLogger(__name__)
_FunctionDeclarationType = Union[
FunctionDeclaration,
dict[str, Any],
Callable[..., Any],
]
class ChatGoogleGenerativeAIError(GoogleGenerativeAIError):
"""
Custom exception class for errors associated with the `Google GenAI` API.
This exception is raised for issues specific to Google GenAI API usage in the
ChatGoogleGenerativeAI class, such as unsupported message types or roles.
"""
def _create_retry_decorator() -> Callable[[Any], Any]:
"""
Creates and returns a preconfigured tenacity retry decorator.
The retry decorator is configured to handle Google API exceptions such as
ResourceExhausted, ServiceUnavailable, and other GoogleAPIError subclasses.
It uses an exponential backoff strategy for retries.
Returns:
Callable[[Any], Any]: A retry decorator configured for handling specific
Google API exceptions.
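Example (illustrative sketch; ``client.generate_content`` stands in for any
callable that calls the Gemini API):
    decorator = _create_retry_decorator()
    generate_with_retry = decorator(client.generate_content)
    response = generate_with_retry(request=request)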
"""
multiplier = 2
min_seconds = 1
max_seconds = 60
max_retries = 2
return retry(
reraise=True,
stop=stop_after_attempt(max_retries),
wait=wait_exponential(multiplier=multiplier, min=min_seconds, max=max_seconds),
retry=(
retry_if_exception_type(google.api_core.exceptions.ResourceExhausted)
| retry_if_exception_type(google.api_core.exceptions.ServiceUnavailable)
| retry_if_exception_type(google.api_core.exceptions.GoogleAPIError)
),
before_sleep=before_sleep_log(logger, logging.WARNING),
)
def _chat_with_retry(generation_method: Callable, **kwargs: Any) -> Any:
"""
Executes a chat generation method with retry logic using tenacity.
This function is a wrapper that applies a retry mechanism to a provided
chat generation function. It is useful for handling intermittent issues
like network errors or temporary service unavailability.
Args:
generation_method (Callable): The chat generation method to be executed.
**kwargs (Any): Additional keyword arguments to pass to the generation method.
Returns:
Any: The result from the chat generation method.
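Example (illustrative; ``client`` and ``request`` are placeholders for a
generative service client and a prepared request):
    response = _chat_with_retry(generation_method=client.generate_content, request=request)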
"""
retry_decorator = _create_retry_decorator()
@retry_decorator
def _chat_with_retry(**kwargs: Any) -> Any:
try:
return generation_method(**kwargs)
# Do not retry for these errors.
except google.api_core.exceptions.FailedPrecondition as exc:
if "location is not supported" in exc.message:
error_msg = (
"Your location is not supported by google-generativeai "
"at the moment. Try to use ChatVertexAI LLM from "
"langchain_google_vertexai."
)
raise ValueError(error_msg)
# Re-raise any other FailedPrecondition errors instead of swallowing them.
raise
except google.api_core.exceptions.InvalidArgument as e:
raise ChatGoogleGenerativeAIError(
f"Invalid argument provided to Gemini: {e}"
) from e
except Exception as e:
raise e
return _chat_with_retry(**kwargs)
async def _achat_with_retry(generation_method: Callable, **kwargs: Any) -> Any:
"""
Executes a chat generation method with retry logic using tenacity.
This function is a wrapper that applies a retry mechanism to a provided
chat generation function. It is useful for handling intermittent issues
like network errors or temporary service unavailability.
Args:
generation_method (Callable): The chat generation method to be executed.
**kwargs (Any): Additional keyword arguments to pass to the generation method.
Returns:
Any: The result from the chat generation method.
"""
retry_decorator = _create_retry_decorator()
from google.api_core.exceptions import InvalidArgument # type: ignore
@retry_decorator
async def _achat_with_retry(**kwargs: Any) -> Any:
try:
return await generation_method(**kwargs)
except InvalidArgument as e:
# Do not retry for these errors.
raise ChatGoogleGenerativeAIError(
f"Invalid argument provided to Gemini: {e}"
) from e
except Exception as e:
raise e
return await _achat_with_retry(**kwargs)
def _is_lc_content_block(part: dict) -> bool:
return "type" in part
def _convert_to_parts(
raw_content: Union[str, Sequence[Union[str, dict]]],
) -> List[Part]:
"""Converts a list of LangChain messages into a google parts."""
parts = []
content = [raw_content] if isinstance(raw_content, str) else raw_content
image_loader = ImageBytesLoader()
for part in content:
if isinstance(part, str):
parts.append(Part(text=part))
elif isinstance(part, Mapping):
if _is_lc_content_block(part):
if part["type"] == "text":
parts.append(Part(text=part["text"]))
elif is_data_content_block(part):
if part["source_type"] == "url":
bytes_ = image_loader._bytes_from_url(part["url"])
elif part["source_type"] == "base64":
bytes_ = base64.b64decode(part["data"])
else:
raise ValueError("source_type must be url or base64.")
inline_data: dict = {"data": bytes_}
if "mime_type" in part:
inline_data["mime_type"] = part["mime_type"]
else:
source = cast(str, part.get("url") or part.get("data"))
mime_type, _ = mimetypes.guess_type(source)
if not mime_type:
kind = filetype.guess(bytes_)
if kind:
mime_type = kind.mime
if mime_type:
inline_data["mime_type"] = mime_type
parts.append(Part(inline_data=inline_data))
elif part["type"] == "image_url":
img_url = part["image_url"]
if isinstance(img_url, dict):
if "url" not in img_url:
raise ValueError(
f"Unrecognized message image format: {img_url}"
)
img_url = img_url["url"]
parts.append(image_loader.load_part(img_url))
# Handle media type like LangChain.js
# https://github.com/langchain-ai/langchainjs/blob/e536593e2585f1dd7b0afc187de4d07cb40689ba/libs/langchain-google-common/src/utils/gemini.ts#L93-L106
elif part["type"] == "media":
if "mime_type" not in part:
raise ValueError(f"Missing mime_type in media part: {part}")
mime_type = part["mime_type"]
media_part = Part()
if "data" in part:
media_part.inline_data = Blob(
data=part["data"], mime_type=mime_type
)
elif "file_uri" in part:
media_part.file_data = FileData(
file_uri=part["file_uri"], mime_type=mime_type
)
else:
raise ValueError(
f"Media part must have either data or file_uri: {part}"
)
if "video_metadata" in part:
metadata = VideoMetadata(part["video_metadata"])
media_part.video_metadata = metadata
parts.append(media_part)
else:
raise ValueError(
f"Unrecognized message part type: {part['type']}. Only text, "
f"image_url, and media types are supported."
)
else:
# Fall back to treating unrecognized mappings as plain text.
logger.warning(
"Unrecognized message part format. Assuming it's a text part."
)
parts.append(Part(text=str(part)))
else:
# TODO: Maybe some of Google's native stuff
# would hit this branch.
raise ChatGoogleGenerativeAIError(
"Gemini only supports text and inline_data parts."
)
return parts
def _convert_tool_message_to_part(
message: ToolMessage | FunctionMessage, name: Optional[str] = None
) -> Part:
"""Converts a tool or function message to a google part."""
# Legacy agent stores tool name in message.additional_kwargs instead of message.name
name = message.name or name or message.additional_kwargs.get("name")
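# Illustrative mapping (hypothetical values): JSON string content such as
# '{"temperature": 20}' is parsed to response={"temperature": 20}; a plain
# string like "sunny" is wrapped as response={"output": "sunny"}.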
response: Any
if not isinstance(message.content, str):
response = message.content
else:
try:
response = json.loads(message.content)
except json.JSONDecodeError:
response = message.content # leave as str representation
part = Part(
function_response=FunctionResponse(
name=name,
response=(
{"output": response} if not isinstance(response, dict) else response
),
)
)
return part
def _get_ai_message_tool_messages_parts(
tool_messages: Sequence[ToolMessage], ai_message: AIMessage
) -> list[Part]:
"""
Finds relevant tool messages for the AI message and converts them to a single
list of Parts.
"""
# We are interested only in the tool messages that are part of the AI message
tool_calls_ids = {tool_call["id"]: tool_call for tool_call in ai_message.tool_calls}
parts = []
for i, message in enumerate(tool_messages):
if not tool_calls_ids:
break
if message.tool_call_id in tool_calls_ids:
tool_call = tool_calls_ids[message.tool_call_id]
part = _convert_tool_message_to_part(message, name=tool_call.get("name"))
parts.append(part)
# remove the id from the dict, so that we do not iterate over it again
tool_calls_ids.pop(message.tool_call_id)
return parts
def _parse_chat_history(
input_messages: Sequence[BaseMessage], convert_system_message_to_human: bool = False
) -> Tuple[Optional[Content], List[Content]]:
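"""Splits LangChain messages into an optional system instruction and a list of
Gemini ``Content`` turns: SystemMessages feed the system instruction,
AIMessages become "model" turns, and Human-, Function- and ToolMessages
become "user" turns."""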
messages: List[Content] = []
if convert_system_message_to_human:
warnings.warn("Convert_system_message_to_human will be deprecated!")
system_instruction: Optional[Content] = None
messages_without_tool_messages = [
message for message in input_messages if not isinstance(message, ToolMessage)
]
tool_messages = [
message for message in input_messages if isinstance(message, ToolMessage)
]
for i, message in enumerate(messages_without_tool_messages):
if isinstance(message, SystemMessage):
system_parts = _convert_to_parts(message.content)
if i == 0:
system_instruction = Content(parts=system_parts)
elif system_instruction is not None:
system_instruction.parts.extend(system_parts)
else:
pass
continue
elif isinstance(message, AIMessage):
role = "model"
if message.tool_calls:
ai_message_parts = []
for tool_call in message.tool_calls:
function_call = FunctionCall(
{
"name": tool_call["name"],
"args": tool_call["args"],
}
)
ai_message_parts.append(Part(function_call=function_call))
tool_messages_parts = _get_ai_message_tool_messages_parts(
tool_messages=tool_messages, ai_message=message
)
messages.append(Content(role=role, parts=ai_message_parts))
messages.append(Content(role="user", parts=tool_messages_parts))
continue
elif raw_function_call := message.additional_kwargs.get("function_call"):
function_call = FunctionCall(
{
"name": raw_function_call["name"],
"args": json.loads(raw_function_call["arguments"]),
}
)
parts = [Part(function_call=function_call)]
else:
parts = _convert_to_parts(message.content)
elif isinstance(message, HumanMessage):
role = "user"
parts = _convert_to_parts(message.content)
if i == 1 and convert_system_message_to_human and system_instruction:
parts = [p for p in system_instruction.parts] + parts
system_instruction = None
elif isinstance(message, FunctionMessage):
role = "user"
parts = [_convert_tool_message_to_part(message)]
else:
raise ValueError(
f"Unexpected message with type {type(message)} at the position {i}."
)
messages.append(Content(role=role, parts=parts))
return system_instruction, messages
def _parse_response_candidate(
response_candidate: Candidate, streaming: bool = False
) -> AIMessage:
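"""Converts a response ``Candidate`` into an AIMessage (or AIMessageChunk when
``streaming=True``), collecting text, executable code, code-execution results,
inline images, and function calls from the candidate's parts."""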
content: Union[None, str, List[Union[str, dict]]] = None
additional_kwargs = {}
tool_calls = []
invalid_tool_calls = []
tool_call_chunks = []
for part in response_candidate.content.parts:
try:
text: Optional[str] = part.text
# Remove erroneous newline character if present
if text is not None:
text = text.rstrip("\n")
except AttributeError:
text = None
if text is not None:
if not content:
content = text
elif isinstance(content, str) and text:
content = [content, text]
elif isinstance(content, list) and text:
content.append(text)
elif text:
raise Exception("Unexpected content type")
if hasattr(part, "executable_code") and part.executable_code is not None:
if part.executable_code.code and part.executable_code.language:
code_message = {
"type": "executable_code",
"executable_code": part.executable_code.code,
"language": part.executable_code.language,
}
if not content:
content = [code_message]
elif isinstance(content, str):
content = [content, code_message]
elif isinstance(content, list):
content.append(code_message)
else:
raise Exception("Unexpected content type")
if (
hasattr(part, "code_execution_result")
and part.code_execution_result is not None
):
if part.code_execution_result.output:
execution_result = {
"type": "code_execution_result",
"code_execution_result": part.code_execution_result.output,
}
if not content:
content = [execution_result]
elif isinstance(content, str):
content = [content, execution_result]
elif isinstance(content, list):
content.append(execution_result)
else:
raise Exception("Unexpected content type")
if part.inline_data.mime_type.startswith("image/"):
image_format = part.inline_data.mime_type[6:]
message = {
"type": "image_url",
"image_url": {
"url": image_bytes_to_b64_string(
part.inline_data.data, image_format=image_format
)
},
}
if not content:
content = [message]
elif isinstance(content, str) and message:
content = [content, message]
elif isinstance(content, list) and message:
content.append(message)
elif message:
raise Exception("Unexpected content type")
if part.function_call:
function_call = {"name": part.function_call.name}
# dump to match other function calling llm for now
function_call_args_dict = proto.Message.to_dict(part.function_call)["args"]
function_call["arguments"] = json.dumps(
{k: function_call_args_dict[k] for k in function_call_args_dict}
)
additional_kwargs["function_call"] = function_call
if streaming:
tool_call_chunks.append(
tool_call_chunk(
name=function_call.get("name"),
args=function_call.get("arguments"),
id=function_call.get("id", str(uuid.uuid4())),
index=function_call.get("index"), # type: ignore
)
)
else:
try:
tool_call_dict = parse_tool_calls(
[{"function": function_call}],
return_id=False,
)[0]
except Exception as e:
invalid_tool_calls.append(
invalid_tool_call(
name=function_call.get("name"),
args=function_call.get("arguments"),
id=function_call.get("id", str(uuid.uuid4())),
error=str(e),
)
)
else:
tool_calls.append(
tool_call(
name=tool_call_dict["name"],
args=tool_call_dict["args"],
id=tool_call_dict.get("id", str(uuid.uuid4())),
)
)
if content is None:
content = ""
if any(isinstance(item, dict) and "executable_code" in item for item in content):
warnings.warn(
"""
⚠️ Warning: Output may vary each run.
- 'executable_code': Always present.
- 'execution_result' & 'image_url': May be absent for some queries.
Validate before using in production.
"""
)
if streaming:
return AIMessageChunk(
content=cast(Union[str, List[Union[str, Dict[Any, Any]]]], content),
additional_kwargs=additional_kwargs,
tool_call_chunks=tool_call_chunks,
)
return AIMessage(
content=cast(Union[str, List[Union[str, Dict[Any, Any]]]], content),
additional_kwargs=additional_kwargs,
tool_calls=tool_calls,
invalid_tool_calls=invalid_tool_calls,
)
def _response_to_result(
response: GenerateContentResponse,
stream: bool = False,
prev_usage: Optional[UsageMetadata] = None,
) -> ChatResult:
"""Converts a PaLM API response into a LangChain ChatResult."""
llm_output = {"prompt_feedback": proto.Message.to_dict(response.prompt_feedback)}
# previous usage metadata needs to be subtracted because gemini api returns
# already-accumulated token counts with each chunk
prev_input_tokens = prev_usage["input_tokens"] if prev_usage else 0
prev_output_tokens = prev_usage["output_tokens"] if prev_usage else 0
prev_total_tokens = prev_usage["total_tokens"] if prev_usage else 0
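# e.g. (hypothetical counts) if the previous chunk reported cumulative
# 18/3/21 input/output/total tokens and this chunk reports 18/5/23,
# this chunk contributes 0/2/2 new tokens.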
# Get usage metadata
try:
input_tokens = response.usage_metadata.prompt_token_count
output_tokens = response.usage_metadata.candidates_token_count
total_tokens = response.usage_metadata.total_token_count
cache_read_tokens = response.usage_metadata.cached_content_token_count
if input_tokens + output_tokens + cache_read_tokens + total_tokens > 0:
lc_usage = UsageMetadata(
input_tokens=input_tokens - prev_input_tokens,
output_tokens=output_tokens - prev_output_tokens,
total_tokens=total_tokens - prev_total_tokens,
input_token_details={"cache_read": cache_read_tokens},
)
else:
lc_usage = None
except AttributeError:
lc_usage = None
generations: List[ChatGeneration] = []
for candidate in response.candidates:
generation_info = {}
if candidate.finish_reason:
generation_info["finish_reason"] = candidate.finish_reason.name
# Add model_name in last chunk
generation_info["model_name"] = response.model_version
generation_info["safety_ratings"] = [
proto.Message.to_dict(safety_rating, use_integers_for_enums=False)
for safety_rating in candidate.safety_ratings
]
message = _parse_response_candidate(candidate, streaming=stream)
message.usage_metadata = lc_usage
if stream:
generations.append(
ChatGenerationChunk(
message=cast(AIMessageChunk, message),
generation_info=generation_info,
)
)
else:
generations.append(
ChatGeneration(message=message, generation_info=generation_info)
)
if not response.candidates:
# Likely a "prompt feedback" violation (e.g., toxic input)
# Raising an error would be different from how OpenAI handles it,
# so we'll just log a warning and continue with an empty message.
logger.warning(
"Gemini produced an empty response. Continuing with empty message\n"
f"Feedback: {response.prompt_feedback}"
)
if stream:
generations = [
ChatGenerationChunk(
message=AIMessageChunk(content=""), generation_info={}
)
]
else:
generations = [ChatGeneration(message=AIMessage(""), generation_info={})]
return ChatResult(generations=generations, llm_output=llm_output)
def _is_event_loop_running() -> bool:
try:
asyncio.get_running_loop()
return True
except RuntimeError:
return False
class ChatGoogleGenerativeAI(_BaseGoogleGenerativeAI, BaseChatModel):
"""`Google AI` chat models integration.
Instantiation:
To use, you must either:
1. Set the ``GOOGLE_API_KEY`` environment variable with your API key, or
2. Pass your API key to the ChatGoogleGenerativeAI constructor via the
``google_api_key`` kwarg.
.. code-block:: python
from langchain_google_genai import ChatGoogleGenerativeAI
llm = ChatGoogleGenerativeAI(model="gemini-1.5-pro")
llm.invoke("Write me a ballad about LangChain")
Invoke:
.. code-block:: python
messages = [
("system", "Translate the user sentence to French."),
("human", "I love programming."),
]
llm.invoke(messages)
.. code-block:: python
AIMessage(
content="J'adore programmer. \\n",
response_metadata={'prompt_feedback': {'block_reason': 0, 'safety_ratings': []}, 'finish_reason': 'STOP', 'safety_ratings': [{'category': 'HARM_CATEGORY_SEXUALLY_EXPLICIT', 'probability': 'NEGLIGIBLE', 'blocked': False}, {'category': 'HARM_CATEGORY_HATE_SPEECH', 'probability': 'NEGLIGIBLE', 'blocked': False}, {'category': 'HARM_CATEGORY_HARASSMENT', 'probability': 'NEGLIGIBLE', 'blocked': False}, {'category': 'HARM_CATEGORY_DANGEROUS_CONTENT', 'probability': 'NEGLIGIBLE', 'blocked': False}]},
id='run-56cecc34-2e54-4b52-a974-337e47008ad2-0',
usage_metadata={'input_tokens': 18, 'output_tokens': 5, 'total_tokens': 23}
)
Stream:
.. code-block:: python
for chunk in llm.stream(messages):
print(chunk)
.. code-block:: python
AIMessageChunk(content='J', response_metadata={'finish_reason': 'STOP', 'safety_ratings': []}, id='run-e905f4f4-58cb-4a10-a960-448a2bb649e3', usage_metadata={'input_tokens': 18, 'output_tokens': 1, 'total_tokens': 19})
AIMessageChunk(content="'adore programmer. \\n", response_metadata={'finish_reason': 'STOP', 'safety_ratings': [{'category': 'HARM_CATEGORY_SEXUALLY_EXPLICIT', 'probability': 'NEGLIGIBLE', 'blocked': False}, {'category': 'HARM_CATEGORY_HATE_SPEECH', 'probability': 'NEGLIGIBLE', 'blocked': False}, {'category': 'HARM_CATEGORY_HARASSMENT', 'probability': 'NEGLIGIBLE', 'blocked': False}, {'category': 'HARM_CATEGORY_DANGEROUS_CONTENT', 'probability': 'NEGLIGIBLE', 'blocked': False}]}, id='run-e905f4f4-58cb-4a10-a960-448a2bb649e3', usage_metadata={'input_tokens': 18, 'output_tokens': 5, 'total_tokens': 23})
.. code-block:: python
stream = llm.stream(messages)
full = next(stream)
for chunk in stream:
full += chunk
full
.. code-block:: python
AIMessageChunk(
content="J'adore programmer. \\n",
response_metadata={'finish_reason': 'STOPSTOP', 'safety_ratings': [{'category': 'HARM_CATEGORY_SEXUALLY_EXPLICIT', 'probability': 'NEGLIGIBLE', 'blocked': False}, {'category': 'HARM_CATEGORY_HATE_SPEECH', 'probability': 'NEGLIGIBLE', 'blocked': False}, {'category': 'HARM_CATEGORY_HARASSMENT', 'probability': 'NEGLIGIBLE', 'blocked': False}, {'category': 'HARM_CATEGORY_DANGEROUS_CONTENT', 'probability': 'NEGLIGIBLE', 'blocked': False}]},
id='run-3ce13a42-cd30-4ad7-a684-f1f0b37cdeec',
usage_metadata={'input_tokens': 36, 'output_tokens': 6, 'total_tokens': 42}
)
Async:
.. code-block:: python
await llm.ainvoke(messages)
# stream:
# async for chunk in llm.astream(messages):
# batch:
# await llm.abatch([messages])
Context Caching:
Context caching allows you to store and reuse content (e.g., PDFs, images) for faster processing.
The `cached_content` parameter accepts a cache name created via the Google Generative AI API.
Below are two examples: caching a single file directly and caching multiple files using `Part`.
Single File Example:
This caches a single file and queries it.
.. code-block:: python
from google import genai
from google.genai import types
import time
from langchain_google_genai import ChatGoogleGenerativeAI
from langchain_core.messages import HumanMessage
client = genai.Client()
# Upload file
file = client.files.upload(file="./example_file")
while file.state.name == 'PROCESSING':
time.sleep(2)
file = client.files.get(name=file.name)
# Create cache
model = 'models/gemini-1.5-flash-001'
cache = client.caches.create(
model=model,
config=types.CreateCachedContentConfig(
display_name='Cached Content',
system_instruction=(
'You are an expert content analyzer, and your job is to answer '
'the user\'s query based on the file you have access to.'
),
contents=[file],
ttl="300s",
)
)
# Query with LangChain
llm = ChatGoogleGenerativeAI(
model=model,
cached_content=cache.name,
)
message = HumanMessage(content="Summarize the main points of the content.")
llm.invoke([message])
Multiple Files Example:
This caches two files using `Part` and queries them together.
.. code-block:: python
from google import genai
from google.genai.types import CreateCachedContentConfig, Content, Part
import time
from langchain_google_genai import ChatGoogleGenerativeAI
from langchain_core.messages import HumanMessage
client = genai.Client()
# Upload files
file_1 = client.files.upload(file="./file1")
while file_1.state.name == 'PROCESSING':
time.sleep(2)
file_1 = client.files.get(name=file_1.name)
file_2 = client.files.upload(file="./file2")
while file_2.state.name == 'PROCESSING':
time.sleep(2)
file_2 = client.files.get(name=file_2.name)
# Create cache with multiple files
contents = [
Content(
role="user",
parts=[
Part.from_uri(file_uri=file_1.uri, mime_type=file_1.mime_type),
Part.from_uri(file_uri=file_2.uri, mime_type=file_2.mime_type),
],
)
]
model = "gemini-1.5-flash-001"
cache = client.caches.create(
model=model,
config=CreateCachedContentConfig(
display_name='Cached Contents',
system_instruction=(
'You are an expert content analyzer, and your job is to answer '
'the user\'s query based on the files you have access to.'
),
contents=contents,
ttl="300s",
)
)
# Query with LangChain
llm = ChatGoogleGenerativeAI(
model=model,
cached_content=cache.name,
)
message = HumanMessage(content="Provide a summary of the key information across both files.")
llm.invoke([message])
Tool calling:
.. code-block:: python
from pydantic import BaseModel, Field
class GetWeather(BaseModel):
'''Get the current weather in a given location'''
location: str = Field(
..., description="The city and state, e.g. San Francisco, CA"
)
class GetPopulation(BaseModel):
'''Get the current population in a given location'''
location: str = Field(
..., description="The city and state, e.g. San Francisco, CA"
)
llm_with_tools = llm.bind_tools([GetWeather, GetPopulation])
ai_msg = llm_with_tools.invoke(
"Which city is hotter today and which is bigger: LA or NY?"
)
ai_msg.tool_calls
.. code-block:: python
[{'name': 'GetWeather',
'args': {'location': 'Los Angeles, CA'},
'id': 'c186c99f-f137-4d52-947f-9e3deabba6f6'},
{'name': 'GetWeather',
'args': {'location': 'New York City, NY'},
'id': 'cebd4a5d-e800-4fa5-babd-4aa286af4f31'},
{'name': 'GetPopulation',
'args': {'location': 'Los Angeles, CA'},
'id': '4f92d897-f5e4-4d34-a3bc-93062c92591e'},
{'name': 'GetPopulation',
'args': {'location': 'New York City, NY'},
'id': '634582de-5186-4e4b-968b-f192f0a93678'}]
Use Search with Gemini 2:
.. code-block:: python
from google.ai.generativelanguage_v1beta.types import Tool as GenAITool
llm = ChatGoogleGenerativeAI(model="gemini-2.0-flash-exp")
resp = llm.invoke(
"When is the next total solar eclipse in US?",
tools=[GenAITool(google_search={})],
)
Structured output:
.. code-block:: python
from typing import Optional
from pydantic import BaseModel, Field
class Joke(BaseModel):
'''Joke to tell user.'''
setup: str = Field(description="The setup of the joke")
punchline: str = Field(description="The punchline to the joke")
rating: Optional[int] = Field(description="How funny the joke is, from 1 to 10")
structured_llm = llm.with_structured_output(Joke)
structured_llm.invoke("Tell me a joke about cats")
.. code-block:: python
Joke(
setup='Why are cats so good at video games?',
punchline='They have nine lives on the internet',
rating=None
)
Image input:
.. code-block:: python
import base64
import httpx
from langchain_core.messages import HumanMessage
image_url = "https://upload.wikimedia.org/wikipedia/commons/thumb/d/dd/Gfp-wisconsin-madison-the-nature-boardwalk.jpg/2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg"
image_data = base64.b64encode(httpx.get(image_url).content).decode("utf-8")
message = HumanMessage(
content=[
{"type": "text", "text": "describe the weather in this image"},
{
"type": "image_url",
"image_url": {"url": f"data:image/jpeg;base64,{image_data}"},
},
]
)
ai_msg = llm.invoke([message])
ai_msg.content
.. code-block:: python
'The weather in this image appears to be sunny and pleasant. The sky is a bright blue with scattered white clouds, suggesting fair weather. The lush green grass and trees indicate a warm and possibly slightly breezy day. There are no signs of rain or storms.'
Token usage:
.. code-block:: python
ai_msg = llm.invoke(messages)
ai_msg.usage_metadata
.. code-block:: python
{'input_tokens': 18, 'output_tokens': 5, 'total_tokens': 23}
Response metadata:
.. code-block:: python
ai_msg = llm.invoke(messages)
ai_msg.response_metadata
.. code-block:: python
{
'prompt_feedback': {'block_reason': 0, 'safety_ratings': []},
'finish_reason': 'STOP',
'safety_ratings': [{'category': 'HARM_CATEGORY_SEXUALLY_EXPLICIT', 'probability': 'NEGLIGIBLE', 'blocked': False}, {'category': 'HARM_CATEGORY_HATE_SPEECH', 'probability': 'NEGLIGIBLE', 'blocked': False}, {'category': 'HARM_CATEGORY_HARASSMENT', 'probability': 'NEGLIGIBLE', 'blocked': False}, {'category': 'HARM_CATEGORY_DANGEROUS_CONTENT', 'probability': 'NEGLIGIBLE', 'blocked': False}]
}
""" # noqa: E501
client: Any = Field(default=None, exclude=True) #: :meta private:
async_client_running: Any = Field(default=None, exclude=True) #: :meta private:
default_metadata: Sequence[Tuple[str, str]] = Field(
default_factory=list
) #: :meta private:
convert_system_message_to_human: bool = False
"""Whether to merge any leading SystemMessage into the following HumanMessage.
Gemini does not support system messages; any unsupported messages will
raise an error."""
cached_content: Optional[str] = None
"""The name of the cached content used as context to serve the prediction.
Note: only used in explicit caching, where users have control over caching
(e.g., what content to cache) and enjoy guaranteed cost savings. Format:
``cachedContents/{cachedContent}``.
"""
def __init__(self, **kwargs: Any) -> None:
"""Needed for arg validation."""
# Get all valid field names, including aliases
valid_fields = set()
for field_name, field_info in self.model_fields.items():
valid_fields.add(field_name)
if hasattr(field_info, "alias") and field_info.alias is not None:
valid_fields.add(field_info.alias)
# Check for unrecognized arguments
for arg in kwargs:
if arg not in valid_fields:
suggestions = get_close_matches(arg, valid_fields, n=1)
suggestion = (
f" Did you mean: '{suggestions[0]}'?" if suggestions else ""
)
logger.warning(
f"Unexpected argument '{arg}' "
f"provided to ChatGoogleGenerativeAI.{suggestion}"
)
super().__init__(**kwargs)
model_config = ConfigDict(
populate_by_name=True,
)
@property
def lc_secrets(self) -> Dict[str, str]:
return {"google_api_key": "GOOGLE_API_KEY"}
@property
def _llm_type(self) -> str:
return "chat-google-generative-ai"
@property
def _supports_code_execution(self) -> bool:
return (
"gemini-1.5-pro" in self.model
or "gemini-1.5-flash" in self.model
or "gemini-2" in self.model
)
@classmethod
def is_lc_serializable(cls) -> bool:
return True
@model_validator(mode="after")
def validate_environment(self) -> Self:
"""Validates params and passes them to google-generativeai package."""
if self.temperature is not None and not 0 <= self.temperature <= 2.0:
raise ValueError("temperature must be in the range [0.0, 2.0]")
if self.top_p is not None and not 0 <= self.top_p <= 1:
raise ValueError("top_p must be in the range [0.0, 1.0]")
if self.top_k is not None and self.top_k <= 0:
raise ValueError("top_k must be positive")
if not any(
self.model.startswith(prefix) for prefix in ("models/", "tunedModels/")
):
self.model = f"models/{self.model}"
additional_headers = self.additional_headers or {}
self.default_metadata = tuple(additional_headers.items())
client_info = get_client_info(f"ChatGoogleGenerativeAI:{self.model}")
google_api_key = None
if not self.credentials:
if isinstance(self.google_api_key, SecretStr):
google_api_key = self.google_api_key.get_secret_value()
else:
google_api_key = self.google_api_key
transport: Optional[str] = self.transport
self.client = genaix.build_generative_service(
credentials=self.credentials,
api_key=google_api_key,
client_info=client_info,
client_options=self.client_options,
transport=transport,
)
self.async_client_running = None
return self
@property
def async_client(self) -> v1betaGenerativeServiceAsyncClient:
google_api_key = None
if not self.credentials:
if isinstance(self.google_api_key, SecretStr):
google_api_key = self.google_api_key.get_secret_value()
else:
google_api_key = self.google_api_key
# NOTE: genaix.build_generative_async_service requires
# a running event loop, which causes an error
# when initialized inside a ThreadPoolExecutor.
# this check ensures that async client is only initialized
# within an asyncio event loop to avoid the error
if not self.async_client_running and _is_event_loop_running():
# async clients don't support "rest" transport
# https://github.com/googleapis/gapic-generator-python/issues/1962
transport = self.transport
if transport == "rest":
transport = "grpc_asyncio"
self.async_client_running = genaix.build_generative_async_service(
credentials=self.credentials,
api_key=google_api_key,
client_info=get_client_info(f"ChatGoogleGenerativeAI:{self.model}"),
client_options=self.client_options,
transport=transport,
)
return self.async_client_running
@property
def _identifying_params(self) -> Dict[str, Any]:
"""Get the identifying parameters."""
return {
"model": self.model,
"temperature": self.temperature,
"top_k": self.top_k,
"n": self.n,
"safety_settings": self.safety_settings,
"response_modalities": self.response_modalities,
}
def invoke(
self,
input: LanguageModelInput,
config: Optional[RunnableConfig] = None,
*,
code_execution: Optional[bool] = None,
stop: Optional[list[str]] = None,
**kwargs: Any,
) -> BaseMessage:
"""
Enable code execution. Supported on: gemini-1.5-pro, gemini-1.5-flash,
gemini-2.0-flash, and gemini-2.0-pro. When enabled, the model can execute
code to solve problems.
"""
"""Override invoke to add code_execution parameter."""
if code_execution is not None:
if not self._supports_code_execution:
raise ValueError(
"Code execution is only supported on Gemini 1.5 Pro, Gemini 1.5 Flash, "
f"Gemini 2.0 Flash, and Gemini 2.0 Pro models. Current model: {self.model}"
)
if "tools" not in kwargs:
code_execution_tool = GoogleTool(code_execution=CodeExecution())
kwargs["tools"] = [code_execution_tool]
else:
raise ValueError(
"Tools are already defined; the code_execution tool cannot be added."
)
return super().invoke(input, config, stop=stop, **kwargs)
def _get_ls_params(
self, stop: Optional[List[str]] = None, **kwargs: Any
) -> LangSmithParams:
"""Get standard params for tracing."""
params = self._get_invocation_params(stop=stop, **kwargs)
ls_params = LangSmithParams(
ls_provider="google_genai",
ls_model_name=self.model,
ls_model_type="chat",
ls_temperature=params.get("temperature", self.temperature),
)
if ls_max_tokens := params.get("max_output_tokens", self.max_output_tokens):
ls_params["ls_max_tokens"] = ls_max_tokens
if ls_stop := stop or params.get("stop", None):
ls_params["ls_stop"] = ls_stop
return ls_params
def _prepare_params(
self,
stop: Optional[List[str]],
generation_config: Optional[Dict[str, Any]] = None,
) -> GenerationConfig:
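"""Builds the request ``GenerationConfig`` from constructor params, dropping
``None`` values; entries in ``generation_config`` override those defaults."""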
gen_config = {
k: v
for k, v in {
"candidate_count": self.n,
"temperature": self.temperature,
"stop_sequences": stop,
"max_output_tokens": self.max_output_tokens,
"top_k": self.top_k,
"top_p": self.top_p,
"response_modalities": self.response_modalities,
}.items()
if v is not None
}
if generation_config:
gen_config = {**gen_config, **generation_config}
return GenerationConfig(**gen_config)
def _generate(
self,
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
*,
tools: Optional[Sequence[Union[_ToolDict, GoogleTool]]] = None,
functions: Optional[Sequence[_FunctionDeclarationType]] = None,
safety_settings: Optional[SafetySettingDict] = None,
tool_config: Optional[Union[Dict, _ToolConfigDict]] = None,
generation_config: Optional[Dict[str, Any]] = None,
cached_content: Optional[str] = None,
tool_choice: Optional[Union[_ToolChoiceType, bool]] = None,
**kwargs: Any,
) -> ChatResult:
request = self._prepare_request(
messages,
stop=stop,
tools=tools,
functions=functions,
safety_settings=safety_settings,
tool_config=tool_config,
generation_config=generation_config,
cached_content=cached_content or self.cached_content,
tool_choice=tool_choice,
)
response: GenerateContentResponse = _chat_with_retry(
request=request,
**kwargs,
generation_method=self.client.generate_content,
metadata=self.default_metadata,
)
return _response_to_result(response)
async def _agenerate(
self,
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
*,
tools: Optional[Sequence[Union[_ToolDict, GoogleTool]]] = None,
functions: Optional[Sequence[_FunctionDeclarationType]] = None,
safety_settings: Optional[SafetySettingDict] = None,
tool_config: Optional[Union[Dict, _ToolConfigDict]] = None,
generation_config: Optional[Dict[str, Any]] = None,
cached_content: Optional[str] = None,
tool_choice: Optional[Union[_ToolChoiceType, bool]] = None,
**kwargs: Any,
) -> ChatResult:
if not self.async_client:
updated_kwargs = {
**kwargs,
**{
"tools": tools,
"functions": functions,
"safety_settings": safety_settings,
"tool_config": tool_config,
"generation_config": generation_config,
},
}
return await super()._agenerate(
messages, stop, run_manager, **updated_kwargs
)
request = self._prepare_request(
messages,
stop=stop,
tools=tools,
functions=functions,
safety_settings=safety_settings,
tool_config=tool_config,
generation_config=generation_config,
cached_content=cached_content or self.cached_content,
tool_choice=tool_choice,
)
response: GenerateContentResponse = await _achat_with_retry(
request=request,
**kwargs,
generation_method=self.async_client.generate_content,
metadata=self.default_metadata,
)
return _response_to_result(response)
def _stream(
self,
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
*,
tools: Optional[Sequence[Union[_ToolDict, GoogleTool]]] = None,
functions: Optional[Sequence[_FunctionDeclarationType]] = None,
safety_settings: Optional[SafetySettingDict] = None,
tool_config: Optional[Union[Dict, _ToolConfigDict]] = None,
generation_config: Optional[Dict[str, Any]] = None,
cached_content: Optional[str] = None,
tool_choice: Optional[Union[_ToolChoiceType, bool]] = None,
**kwargs: Any,
) -> Iterator[ChatGenerationChunk]:
request = self._prepare_request(
messages,
stop=stop,
tools=tools,
functions=functions,
safety_settings=safety_settings,
tool_config=tool_config,
generation_config=generation_config,
cached_content=cached_content or self.cached_content,
tool_choice=tool_choice,
)
response: GenerateContentResponse = _chat_with_retry(
request=request,
generation_method=self.client.stream_generate_content,
**kwargs,
metadata=self.default_metadata,
)
prev_usage_metadata: UsageMetadata | None = None
for chunk in response:
_chat_result = _response_to_result(
chunk, stream=True, prev_usage=prev_usage_metadata
)
gen = cast(ChatGenerationChunk, _chat_result.generations[0])
message = cast(AIMessageChunk, gen.message)
curr_usage_metadata: UsageMetadata | dict[str, int] = (
message.usage_metadata or {}
)
prev_usage_metadata = (
message.usage_metadata
if prev_usage_metadata is None
else UsageMetadata(
input_tokens=prev_usage_metadata.get("input_tokens", 0)
+ curr_usage_metadata.get("input_tokens", 0),
output_tokens=prev_usage_metadata.get("output_tokens", 0)
+ curr_usage_metadata.get("output_tokens", 0),
total_tokens=prev_usage_metadata.get("total_tokens", 0)
+ curr_usage_metadata.get("total_tokens", 0),
)
)
if run_manager:
run_manager.on_llm_new_token(gen.text)
yield gen
async def _astream(
self,
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
*,
tools: Optional[Sequence[Union[_ToolDict, GoogleTool]]] = None,
functions: Optional[Sequence[_FunctionDeclarationType]] = None,
safety_settings: Optional[SafetySettingDict] = None,
tool_config: Optional[Union[Dict, _ToolConfigDict]] = None,
generation_config: Optional[Dict[str, Any]] = None,
cached_content: Optional[str] = None,
tool_choice: Optional[Union[_ToolChoiceType, bool]] = None,
**kwargs: Any,
) -> AsyncIterator[ChatGenerationChunk]:
if not self.async_client:
updated_kwargs = {
**kwargs,
**{
"tools": tools,
"functions": functions,
"safety_settings": safety_settings,
"tool_config": tool_config,
"generation_config": generation_config,
},
}
async for value in super()._astream(
messages, stop, run_manager, **updated_kwargs
):
yield value
else:
request = self._prepare_request(
messages,
stop=stop,
tools=tools,
functions=functions,
safety_settings=safety_settings,
tool_config=tool_config,
generation_config=generation_config,
cached_content=cached_content or self.cached_content,
tool_choice=tool_choice,
)
prev_usage_metadata: UsageMetadata | None = None
async for chunk in await _achat_with_retry(
request=request,
generation_method=self.async_client.stream_generate_content,
**kwargs,
metadata=self.default_metadata,
):
_chat_result = _response_to_result(
chunk, stream=True, prev_usage=prev_usage_metadata
)
gen = cast(ChatGenerationChunk, _chat_result.generations[0])
message = cast(AIMessageChunk, gen.message)
curr_usage_metadata: UsageMetadata | dict[str, int] = (
message.usage_metadata or {}
)
prev_usage_metadata = (
message.usage_metadata
if prev_usage_metadata is None
else UsageMetadata(
input_tokens=prev_usage_metadata.get("input_tokens", 0)
+ curr_usage_metadata.get("input_tokens", 0),
output_tokens=prev_usage_metadata.get("output_tokens", 0)
+ curr_usage_metadata.get("output_tokens", 0),
total_tokens=prev_usage_metadata.get("total_tokens", 0)
+ curr_usage_metadata.get("total_tokens", 0),
)
)
if run_manager:
await run_manager.on_llm_new_token(gen.text)
yield gen
def _prepare_request(
self,
messages: List[BaseMessage],
*,
stop: Optional[List[str]] = None,
tools: Optional[Sequence[Union[_ToolDict, GoogleTool]]] = None,
functions: Optional[Sequence[_FunctionDeclarationType]] = None,
safety_settings: Optional[SafetySettingDict] = None,
tool_config: Optional[Union[Dict, _ToolConfigDict]] = None,
tool_choice: Optional[Union[_ToolChoiceType, bool]] = None,
generation_config: Optional[Dict[str, Any]] = None,
cached_content: Optional[str] = None,
) -> GenerateContentRequest:
if tool_choice and tool_config:
raise ValueError(
"Must specify at most one of tool_choice and tool_config, received "
f"both:\n\n{tool_choice=}\n\n{tool_config=}"
)
formatted_tools = None
code_execution_tool = GoogleTool(code_execution=CodeExecution())
if tools == [code_execution_tool]:
formatted_tools = tools
elif tools:
formatted_tools = [convert_to_genai_function_declarations(tools)]
elif functions:
formatted_tools = [convert_to_genai_function_declarations(functions)]
filtered_messages = []
for message in messages:
if isinstance(message, HumanMessage) and not message.content:
warnings.warn(
"HumanMessage with empty content was removed to prevent API error"
)
else:
filtered_messages.append(message)
messages = filtered_messages
system_instruction, history = _parse_chat_history(
messages,
convert_system_message_to_human=self.convert_system_message_to_human,
)
if tool_choice:
if not formatted_tools:
msg = (
f"Received {tool_choice=} but no {tools=}. 'tool_choice' can only "
f"be specified if 'tools' is specified."
)
raise ValueError(msg)
all_names: List[str] = []
for t in formatted_tools:
if hasattr(t, "function_declarations"):
t_with_declarations = cast(Any, t)
all_names.extend(
f.name for f in t_with_declarations.function_declarations
)
elif isinstance(t, GoogleTool) and hasattr(t, "code_execution"):
continue
else:
raise TypeError(
f"Tool {t} doesn't have function_declarations attribute"
)
tool_config = _tool_choice_to_tool_config(tool_choice, all_names)
formatted_tool_config = None
if tool_config:
formatted_tool_config = ToolConfig(
function_calling_config=tool_config["function_calling_config"]
)
formatted_safety_settings = []
if safety_settings:
formatted_safety_settings = [
SafetySetting(category=c, threshold=t)
for c, t in safety_settings.items()
]
request = GenerateContentRequest(
model=self.model,
contents=history,
tools=formatted_tools,
tool_config=formatted_tool_config,
safety_settings=formatted_safety_settings,
generation_config=self._prepare_params(
stop, generation_config=generation_config
),
cached_content=cached_content,
)
if system_instruction:
request.system_instruction = system_instruction
return request
def get_num_tokens(self, text: str) -> int:
"""Get the number of tokens present in the text.
Useful for checking if an input will fit in a model's context window.
Args:
text: The string input to tokenize.
Returns:
The integer number of tokens in the text.
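Example (illustrative; the exact count depends on the model's tokenizer):
    llm = ChatGoogleGenerativeAI(model="gemini-1.5-flash")
    llm.get_num_tokens("What is the meaning of life?")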
"""
result = self.client.count_tokens(
model=self.model, contents=[Content(parts=[Part(text=text)])]
)
return result.total_tokens
def with_structured_output(
self,
schema: Union[Dict, Type[BaseModel]],
*,
include_raw: bool = False,
**kwargs: Any,
) -> Runnable[LanguageModelInput, Union[Dict, BaseModel]]:
_ = kwargs.pop("method", None)
_ = kwargs.pop("strict", None)
if kwargs:
raise ValueError(f"Received unsupported arguments {kwargs}")
tool_name = _get_tool_name(schema) # type: ignore[arg-type]
if isinstance(schema, type) and is_basemodel_subclass_safe(schema):
parser: OutputParserLike = PydanticToolsParser(
tools=[schema], first_tool_only=True
)
else:
parser = JsonOutputKeyToolsParser(key_name=tool_name, first_tool_only=True)
tool_choice = tool_name if self._supports_tool_choice else None
try:
llm = self.bind_tools(
[schema],
tool_choice=tool_choice,
ls_structured_output_format={
"kwargs": {"method": "function_calling"},
"schema": convert_to_openai_tool(schema),
},
)
except Exception:
llm = self.bind_tools([schema], tool_choice=tool_choice)
if include_raw:
parser_with_fallback = RunnablePassthrough.assign(
parsed=itemgetter("raw") | parser, parsing_error=lambda _: None
).with_fallbacks(
[RunnablePassthrough.assign(parsed=lambda _: None)],
exception_key="parsing_error",
)
return {"raw": llm} | parser_with_fallback
else:
return llm | parser
@property
def _supports_tool_choice(self) -> bool:
return (
"gemini-1.5-pro" in self.model
or "gemini-1.5-flash" in self.model
or "gemini-2" in self.model
)
def _get_tool_name(
tool: Union[_ToolDict, GoogleTool, Dict],
) -> str:
try:
genai_tool = tool_to_dict(convert_to_genai_function_declarations([tool]))
return [f["name"] for f in genai_tool["function_declarations"]][0] # type: ignore[index]
except ValueError as e: # other TypedDict
if is_typeddict(tool):
return convert_to_openai_tool(cast(Dict, tool))["function"]["name"]
else:
raise e