Source code for elasticsearch.helpers.vectorstore._async.embedding_service
# Licensed to Elasticsearch B.V. under one or more contributor
# license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright
# ownership. Elasticsearch B.V. licenses this file to you under
# the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from abc import ABC, abstractmethod
from typing import List
from elasticsearch import AsyncElasticsearch
from elasticsearch._version import __versionstr__ as lib_version
class AsyncEmbeddingService(ABC):
    """Abstract interface for asynchronous text-embedding providers.

    Concrete subclasses must implement both coroutine methods below; the class
    cannot be instantiated until they are overridden.
    """

    @abstractmethod
    async def embed_documents(self, texts: List[str]) -> List[List[float]]:
        """Generate embeddings for a list of documents.

        :param texts: A list of document strings to generate embeddings for.

        :return: A list of embeddings, one for each document in the input.
        """

    @abstractmethod
    async def embed_query(self, query: str) -> List[float]:
        """Generate an embedding for a single query text.

        :param query: The query text to generate an embedding for.

        :return: The embedding for the input query text.
        """
class AsyncElasticsearchEmbeddings(AsyncEmbeddingService):
    """Elasticsearch as a service for embedding model inference.

    You need to have an embedding model downloaded and deployed in Elasticsearch:
    - https://www.elastic.co/guide/en/elasticsearch/reference/current/infer-trained-model.html
    - https://www.elastic.co/guide/en/machine-learning/current/ml-nlp-deploy-models.html
    """  # noqa: E501

    def __init__(
        self,
        *,
        client: AsyncElasticsearch,
        model_id: str,
        input_field: str = "text_field",
        user_agent: str = f"elasticsearch-py-es/{lib_version}",
    ):
        """
        :param client: Elasticsearch client connection. The instance stored on
            ``self.client`` is the one returned by ``client.options(...)`` with
            the ``User-Agent`` header set to ``user_agent``.
        :param model_id: The model_id of the model deployed in the Elasticsearch cluster.
        :param input_field: The name of the key for the input text field in the
            document. Defaults to 'text_field'.
        :param user_agent: user agent header specific to the 3rd party integration.
            Used for usage tracking in Elastic Cloud.
        """
        # Add integration-specific usage header for tracking usage in Elastic Cloud.
        # client.options preserves existing (non-user-agent) headers.
        self.client = client.options(headers={"User-Agent": user_agent})
        self.model_id = model_id
        self.input_field = input_field

    async def embed_documents(self, texts: List[str]) -> List[List[float]]:
        """Generate embeddings for a list of documents.

        :param texts: A list of document strings to generate embeddings for.

        :return: A list of embeddings, one for each document in the input.
            An empty input yields an empty list.
        """
        return await self._embedding_func(texts)

    async def embed_query(self, text: str) -> List[float]:
        """Generate an embedding for a single query text.

        Note that the parameter is named ``text`` here; it is kept as-is so
        existing keyword callers keep working.

        :param text: The query text to generate an embedding for.

        :return: The embedding for the input query text.
        """
        # Reuse the batch path with a one-element batch and unwrap the result.
        embeddings = await self._embedding_func([text])
        return embeddings[0]

    async def _embedding_func(self, texts: List[str]) -> List[List[float]]:
        """Run the deployed model over ``texts`` and collect the predictions.

        :param texts: The strings to send to the model, each wrapped in a
            document keyed by ``self.input_field``.

        :return: The ``predicted_value`` of every entry in the response's
            ``inference_results``, in response order.
        """
        # Nothing to infer: skip the request instead of sending an empty
        # ``docs`` list to the cluster.
        if not texts:
            return []

        response = await self.client.ml.infer_trained_model(
            model_id=self.model_id,
            docs=[{self.input_field: text} for text in texts],
        )
        return [doc["predicted_value"] for doc in response["inference_results"]]