"""Experiment with different models."""from__future__importannotationsfromtypingimportList,Optional,Sequencefromlangchain_core.language_models.llmsimportBaseLLMfromlangchain_core.prompts.promptimportPromptTemplatefromlangchain_core.utils.inputimportget_color_mapping,print_textfromlangchain.chains.baseimportChainfromlangchain.chains.llmimportLLMChain
class ModelLaboratory:
    """Experiment with different models."""

    def __init__(self, chains: Sequence[Chain], names: Optional[List[str]] = None):
        """Initialize with chains to experiment with.

        Args:
            chains: list of chains to experiment with.
            names: Optional labels for the chains; must match ``chains`` in
                length when given.

        Raises:
            ValueError: if any entry is not a ``Chain``, if a chain has more
                than one input or output variable, or if ``names`` has the
                wrong length.
        """
        # Validate each chain up front so failures surface at construction
        # time rather than during `compare`.
        for candidate in chains:
            if not isinstance(candidate, Chain):
                raise ValueError(
                    "ModelLaboratory should now be initialized with Chains. "
                    "If you want to initialize with LLMs, use the `from_llms` method "
                    "instead (`ModelLaboratory.from_llms(...)`)"
                )
            if len(candidate.input_keys) != 1:
                raise ValueError(
                    "Currently only support chains with one input variable, "
                    f"got {candidate.input_keys}"
                )
            if len(candidate.output_keys) != 1:
                raise ValueError(
                    "Currently only support chains with one output variable, "
                    f"got {candidate.output_keys}"
                )
        if names is not None and len(names) != len(chains):
            raise ValueError("Length of chains does not match length of names.")
        self.chains = chains
        # One distinct print color per chain, keyed by its stringified index.
        self.chain_colors = get_color_mapping([str(idx) for idx in range(len(chains))])
        self.names = names

    @classmethod
    def from_llms(
        cls, llms: List[BaseLLM], prompt: Optional[PromptTemplate] = None
    ) -> ModelLaboratory:
        """Initialize with LLMs to experiment with and optional prompt.

        Args:
            llms: list of LLMs to experiment with
            prompt: Optional prompt to use to prompt the LLMs. Defaults to None.
                If a prompt was provided, it should only have one input variable.
        """
        if prompt is None:
            # Pass-through prompt: the input text becomes the entire prompt.
            prompt = PromptTemplate(input_variables=["_input"], template="{_input}")
        wrapped = [LLMChain(llm=llm, prompt=prompt) for llm in llms]
        return cls(wrapped, names=[str(llm) for llm in llms])

    def compare(self, text: str) -> None:
        """Compare model outputs on an input text.

        If a prompt was provided with starting the laboratory, then this text will be
        fed into the prompt. If no prompt was provided, then the input text is the
        entire prompt.

        Args:
            text: input text to run all models on.
        """
        # ANSI bold header for the shared input.
        print(f"\033[1mInput:\033[0m\n{text}\n")  # noqa: T201
        for idx, chain in enumerate(self.chains):
            label = self.names[idx] if self.names is not None else str(chain)
            print_text(label, end="\n")
            result = chain.run(text)
            print_text(result, color=self.chain_colors[str(idx)], end="\n\n")