"""Chain for applying self-critique using the SmartGPT workflow."""
from typing import Any, Dict, List, Optional, Tuple, Type
from langchain.chains.base import Chain
from langchain_core.callbacks.manager import CallbackManagerForChainRun
from langchain_core.language_models import BaseLanguageModel
from langchain_core.outputs import LLMResult
from langchain_core.prompt_values import PromptValue
from langchain_core.prompts.base import BasePromptTemplate
from langchain_core.prompts.chat import (
    AIMessagePromptTemplate,
    BaseMessagePromptTemplate,
    ChatPromptTemplate,
    HumanMessagePromptTemplate,
)
from langchain_core.utils.input import get_colored_text
from pydantic import ConfigDict, model_validator
class SmartLLMChain(Chain):
"""Chain for applying self-critique using the SmartGPT workflow.
See details at https://youtu.be/wVzuvf9D9BU
A SmartLLMChain is an LLMChain that instead of simply passing the prompt to the LLM
performs these 3 steps:
1. Ideate: Pass the user prompt to an ideation LLM n_ideas times,
each result is an "idea"
2. Critique: Pass the ideas to a critique LLM which looks for flaws in the ideas
& picks the best one
3. Resolve: Pass the critique to a resolver LLM which improves upon the best idea
& outputs only the (improved version of) the best output
In total, SmartLLMChain pass will use n_ideas+2 LLM calls
Note that SmartLLMChain will only improve results (compared to a basic LLMChain),
when the underlying models have the capability for reflection, which smaller models
often don't.
Finally, a SmartLLMChain assumes that each underlying LLM outputs exactly 1 result.
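
    Example:
        A minimal usage sketch. The ``ChatOpenAI`` model, its temperature, and
        the ``langchain_experimental`` import path are illustrative assumptions,
        not requirements of the chain itself.

        .. code-block:: python

            from langchain_core.prompts import PromptTemplate
            from langchain_experimental.smart_llm import SmartLLMChain
            from langchain_openai import ChatOpenAI

            prompt = PromptTemplate.from_template("{question}")
            # temperature > 0 so the ideation step can yield distinct ideas
            chain = SmartLLMChain(
                llm=ChatOpenAI(temperature=0.7),
                prompt=prompt,
                n_ideas=3,
                return_intermediate_steps=True,
            )
            result = chain.invoke({"question": "What is 2 + 2?"})
            print(result["resolution"])                 # final, improved answer
            print(result["ideas"], result["critique"])  # intermediate steps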
"""
class SmartLLMChainHistory:
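        """Container for the question, ideas, and critique passed between steps."""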
question: str = ""
ideas: List[str] = []
critique: str = ""
@property
def n_ideas(self) -> int:
return len(self.ideas)
def ideation_prompt_inputs(self) -> Dict[str, Any]:
return {"question": self.question}
def critique_prompt_inputs(self) -> Dict[str, Any]:
return {
"question": self.question,
**{f"idea_{i+1}": idea for i, idea in enumerate(self.ideas)},
}
def resolve_prompt_inputs(self) -> Dict[str, Any]:
return {
"question": self.question,
**{f"idea_{i+1}": idea for i, idea in enumerate(self.ideas)},
"critique": self.critique,
}
prompt: BasePromptTemplate
"""Prompt object to use."""
output_key: str = "resolution"
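    """Key under which the final resolution is returned."""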
    ideation_llm: Optional[BaseLanguageModel] = None
    """LLM to use in the ideation step. If None is given, 'llm' will be used."""
    critique_llm: Optional[BaseLanguageModel] = None
    """LLM to use in the critique step. If None is given, 'llm' will be used."""
    resolver_llm: Optional[BaseLanguageModel] = None
    """LLM to use in the resolve step. If None is given, 'llm' will be used."""
    llm: Optional[BaseLanguageModel] = None
    """LLM to use for each step, if no specific llm for that step is given."""
    n_ideas: int = 3
    """Number of ideas to generate in the ideation step."""
return_intermediate_steps: bool = False
"""Whether to return ideas and critique, in addition to resolution."""
history: SmartLLMChainHistory = SmartLLMChainHistory()
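    """Shared state (question, ideas, critique) accumulated during a run."""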
model_config = ConfigDict(
extra="forbid",
)
@model_validator(mode="before")
@classmethod
def validate_inputs(cls, values: Dict[str, Any]) -> Any:
"""Ensure we have an LLM for each step."""
llm = values.get("llm")
ideation_llm = values.get("ideation_llm")
critique_llm = values.get("critique_llm")
resolver_llm = values.get("resolver_llm")
if not llm and not ideation_llm:
raise ValueError(
"Either ideation_llm or llm needs to be given. Pass llm, "
"if you want to use the same llm for all steps, or pass "
"ideation_llm, critique_llm and resolver_llm if you want "
"to use different llms for each step."
)
if not llm and not critique_llm:
raise ValueError(
"Either critique_llm or llm needs to be given. Pass llm, "
"if you want to use the same llm for all steps, or pass "
"ideation_llm, critique_llm and resolver_llm if you want "
"to use different llms for each step."
)
if not llm and not resolver_llm:
raise ValueError(
"Either resolve_llm or llm needs to be given. Pass llm, "
"if you want to use the same llm for all steps, or pass "
"ideation_llm, critique_llm and resolver_llm if you want "
"to use different llms for each step."
)
if llm and ideation_llm and critique_llm and resolver_llm:
raise ValueError(
"LLMs are given for each step (ideation_llm, critique_llm,"
" resolver_llm), but backup LLM (llm) is also given, which"
" would not be used."
)
return values
@property
def input_keys(self) -> List[str]:
"""Defines the input keys."""
return self.prompt.input_variables
@property
def output_keys(self) -> List[str]:
"""Defines the output keys."""
if self.return_intermediate_steps:
return ["ideas", "critique", self.output_key]
return [self.output_key]
def prep_prompts(
self,
inputs: Dict[str, Any],
run_manager: Optional[CallbackManagerForChainRun] = None,
) -> Tuple[PromptValue, Optional[List[str]]]:
"""Prepare prompts from inputs."""
        # `stop` is optional; it is forwarded unchanged to each LLM call.
        stop = inputs.get("stop")
        selected_inputs = {k: inputs[k] for k in self.prompt.input_variables}
        prompt = self.prompt.format_prompt(**selected_inputs)
        _colored_text = get_colored_text(prompt.to_string(), "green")
        _text = "Prompt after formatting:\n" + _colored_text
        if run_manager:
            run_manager.on_text(_text, end="\n", verbose=self.verbose)
        return prompt, stop
def _call(
self,
        inputs: Dict[str, Any],
run_manager: Optional[CallbackManagerForChainRun] = None,
) -> Dict[str, Any]:
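        # Run the three SmartGPT steps in order, storing each intermediate
        # result on self.history so later prompts can reference it.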
        prompt, stop = self.prep_prompts(inputs, run_manager=run_manager)
self.history.question = prompt.to_string()
ideas = self._ideate(stop, run_manager)
self.history.ideas = ideas
critique = self._critique(stop, run_manager)
self.history.critique = critique
resolution = self._resolve(stop, run_manager)
if self.return_intermediate_steps:
return {"ideas": ideas, "critique": critique, self.output_key: resolution}
return {self.output_key: resolution}
def _get_text_from_llm_result(self, result: LLMResult, step: str) -> str:
"""Between steps, only the LLM result text is passed, not the LLMResult object.
This function extracts the text from an LLMResult."""
if len(result.generations) != 1:
raise ValueError(
f"In SmartLLM the LLM result in step {step} is not "
"exactly 1 element. This should never happen"
)
if len(result.generations[0]) != 1:
raise ValueError(
f"In SmartLLM the LLM in step {step} returned more than "
"1 output. SmartLLM only works with LLMs returning "
"exactly 1 output."
)
return result.generations[0][0].text
def get_prompt_strings(
self, stage: str
) -> List[Tuple[Type[BaseMessagePromptTemplate], str]]:
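        # Prompt construction is cumulative: ideation uses just the question,
        # critique appends the ideas plus a critique request, and resolve
        # appends the critique plus a resolution request.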
role_strings: List[Tuple[Type[BaseMessagePromptTemplate], str]] = []
role_strings.append(
(
HumanMessagePromptTemplate,
"Question: {question}\nAnswer: Let's work this out in a step by "
"step way to be sure we have the right answer:",
)
)
if stage == "ideation":
return role_strings
role_strings.extend(
[
*[
(
AIMessagePromptTemplate,
"Idea " + str(i + 1) + ": {idea_" + str(i + 1) + "}",
)
for i in range(self.n_ideas)
],
(
HumanMessagePromptTemplate,
"You are a researcher tasked with investigating the "
f"{self.n_ideas} response options provided. List the flaws and "
"faulty logic of each answer option. Let's work this out in a step"
" by step way to be sure we have all the errors:",
),
]
)
if stage == "critique":
return role_strings
role_strings.extend(
[
(AIMessagePromptTemplate, "Critique: {critique}"),
(
HumanMessagePromptTemplate,
"You are a resolver tasked with 1) finding which of "
f"the {self.n_ideas} answer options the researcher thought was "
"best, 2) improving that answer and 3) printing the answer in "
"full. Don't output anything for step 1 or 2, only the full "
"answer in 3. Let's work this out in a step by step way to "
"be sure we have the right answer:",
),
]
)
if stage == "resolve":
return role_strings
raise ValueError(
"stage should be either 'ideation', 'critique' or 'resolve',"
f" but it is '{stage}'. This should never happen."
)
def ideation_prompt(self) -> ChatPromptTemplate:
return ChatPromptTemplate.from_strings(self.get_prompt_strings("ideation"))
def critique_prompt(self) -> ChatPromptTemplate:
return ChatPromptTemplate.from_strings(self.get_prompt_strings("critique"))
def resolve_prompt(self) -> ChatPromptTemplate:
return ChatPromptTemplate.from_strings(self.get_prompt_strings("resolve"))
def _ideate(
self,
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForChainRun] = None,
) -> List[str]:
"""Generate n_ideas ideas as response to user prompt."""
llm = self.ideation_llm if self.ideation_llm else self.llm
prompt = self.ideation_prompt().format_prompt(
**self.history.ideation_prompt_inputs()
)
callbacks = run_manager.get_child() if run_manager else None
if llm:
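            # Sample the ideation LLM n_ideas times with the same prompt;
            # distinct ideas rely on non-deterministic (e.g. temperature > 0)
            # sampling in the underlying model.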
ideas = [
self._get_text_from_llm_result(
llm.generate_prompt([prompt], stop, callbacks),
step="ideate",
)
for _ in range(self.n_ideas)
]
for i, idea in enumerate(ideas):
_colored_text = get_colored_text(idea, "blue")
_text = f"Idea {i+1}:\n" + _colored_text
if run_manager:
run_manager.on_text(_text, end="\n", verbose=self.verbose)
return ideas
else:
raise ValueError("llm is none, which should never happen")
def _critique(
self,
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForChainRun] = None,
) -> str:
"""Critique each of the ideas from ideation stage & select best one."""
llm = self.critique_llm if self.critique_llm else self.llm
prompt = self.critique_prompt().format_prompt(
**self.history.critique_prompt_inputs()
)
        callbacks = run_manager.get_child() if run_manager else None
if llm:
critique = self._get_text_from_llm_result(
llm.generate_prompt([prompt], stop, callbacks), step="critique"
)
_colored_text = get_colored_text(critique, "yellow")
_text = "Critique:\n" + _colored_text
if run_manager:
run_manager.on_text(_text, end="\n", verbose=self.verbose)
return critique
else:
raise ValueError("llm is none, which should never happen")
def _resolve(
self,
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForChainRun] = None,
) -> str:
"""Improve upon the best idea as chosen in critique step & return it."""
llm = self.resolver_llm if self.resolver_llm else self.llm
prompt = self.resolve_prompt().format_prompt(
**self.history.resolve_prompt_inputs()
)
        callbacks = run_manager.get_child() if run_manager else None
if llm:
resolution = self._get_text_from_llm_result(
llm.generate_prompt([prompt], stop, callbacks), step="resolve"
)
_colored_text = get_colored_text(resolution, "green")
_text = "Resolution:\n" + _colored_text
if run_manager:
run_manager.on_text(_text, end="\n", verbose=self.verbose)
return resolution
else:
raise ValueError("llm is none, which should never happen")