Skip to content

Commit

Permalink
Update Notebook with langchain code and functions
Browse files Browse the repository at this point in the history
  • Loading branch information
nishitpatel01 committed Jul 17, 2023
1 parent 14799fe commit 233ff7d
Showing 1 changed file with 267 additions and 2 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -15,13 +15,278 @@
"source": [
"## Summary\n",
"\n",
"This notebook products introduction to Langchain and its core concepts and components with easy to follow examples"
"This notebook Illustrates how to create basic recommendations and generate matching using langchain and Vertex GenAI LLM models"
]
},
{
"cell_type": "markdown",
"id": "35fdfb72-a3d1-4ee3-8a5d-bd261d119473",
"metadata": {},
"source": [
"#### Overview\n",
"\n",
"\n",
"#TODO: Brief description here"
]
},
{
"cell_type": "markdown",
"id": "22af7867-b7ad-49cf-a19e-1734a243014e",
"metadata": {},
"source": [
"#### Setup "
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "349e0a64-13c8-476f-a548-51cf7ca683a9",
"metadata": {},
"outputs": [],
"source": [
"!pip install -U google-cloud-aiplatform langchain --user"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "e68c4d34-49db-423e-81a2-f1bf3e54a491",
"metadata": {},
"outputs": [],
"source": [
"import json \n",
"\n",
"import vertexai\n",
"from vertexai.preview.language_models import TextGenerationModel\n",
"\n",
"import langchain\n",
"from pydantic import BaseModel\n",
"from langchain.llms.base import LLM\n",
"from langchain import PromptTemplate, LLMChain\n",
"from langchain.llms import VertexAI"
]
},
{
"cell_type": "markdown",
"id": "7b047858-06aa-469b-ae1f-14965f635848",
"metadata": {},
"source": [
"Create custom vertex wrapper class to call LLM"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "0e862bbe-e522-434c-b786-55b4193b04c6",
"metadata": {},
"outputs": [],
"source": [
"# LLM custom wrapper\n",
"\n",
"class VertexLLMTextExractor(LLM):\n",
" model: TextGenerationModel\n",
" predict_kwargs: dict\n",
"\n",
" def __init__(self, model, **predict_kwargs):\n",
" super().__init__(model=model, predict_kwargs=predict_kwargs)\n",
"\n",
" @property\n",
" def _llm_type(self):\n",
" return 'VertexLLM'\n",
"\n",
" def _call(self, prompt, stop=None):\n",
" result = self.model.predict(prompt, **self.predict_kwargs)\n",
" return str(result)\n",
"\n",
" @property\n",
" def _identifying_params(self):\n",
" return {}\n",
"\n",
"# Call llm model\n",
"\n",
"model = TextGenerationModel.from_pretrained(\"text-bison@001\")\n",
"parameters = {\n",
" \"max_output_tokens\": 1024,\n",
" \"temperature\": 0.2,\n",
" \"top_k\": 40,\n",
" \"top_p\": 0.8,\n",
"}\n",
"\n",
"llm = VertexLLMTextExractor(\n",
" model,\n",
" **parameters\n",
")"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "07df8bcc-b0a2-4ec7-aa9e-45ce7dc8e82f",
"metadata": {},
"outputs": [],
"source": [
"# Set of helper functions\n",
"\n",
"# Function to get schema\n",
"def get_response_schema(chain: str):\n",
" \n",
" from langchain.output_parsers import StructuredOutputParser, ResponseSchema\n",
" from langchain.prompts import HumanMessagePromptTemplate\n",
" \n",
" # Define recommended color & brand schema\n",
" recommendation_response_schema = [\n",
" ResponseSchema(name=\"recommended_brand_name\", description=\"recommended brand name from llm output\"),\n",
" ResponseSchema(name=\"recommended_color_name\", description=\"recommended color name from llm output\")\n",
" ]\n",
" \n",
" # Format response intructions\n",
" response_schema_output_parser = StructuredOutputParser.from_response_schemas(recommendation_response_schema)\n",
" recommendation_response_format_instructions = response_schema_output_parser.get_format_instructions()\n",
" \n",
" # Define matched color schema\n",
" matches_response_schema = [\n",
" ResponseSchema(name=\"recommended_brand_name\", description=\"given recommended brand name\"),\n",
" ResponseSchema(name=\"recommended_color_name\", description=\"given recommended color name\"),\n",
" ResponseSchema(name=\"matched_brand_name\", description=\"matched brand name for given recommended_color_name and recommended_brand_name combination\"),\n",
" ResponseSchema(name=\"matched_color_name\", description=\"matched color name for given recommended_color_name and recommended_brand_name combination\"),\n",
" ResponseSchema(name=\"matched_uri\", description=\"color uri of matched color name for given recommended_color_name and recommended_brand_name combination\")\n",
" ]\n",
" \n",
" # Format response intructions\n",
" matches_response_schema_output_parser = StructuredOutputParser.from_response_schemas(matches_response_schema)\n",
" matches_response_format_instructions = matches_response_schema_output_parser.get_format_instructions()\n",
" \n",
" if chain == 'recommend':\n",
" return recommendation_response_format_instructions\n",
" elif chain == 'match':\n",
" return matches_response_format_instructions\n",
" else:\n",
" pass\n",
" \n",
" \n",
"# Function to generate prompt template\n",
"def generate_prompt(chain: str, input_prompt_text: str):\n",
" \n",
" format_intruction = get_response_schema(chain)\n",
" \n",
" if chain == 'recommend':\n",
" \n",
" # Create prompt template\n",
" prompt = PromptTemplate(\n",
" input_variables=[\"user_input\"],\n",
" partial_variables={\"format_instructions\": format_intruction},\n",
" template=recommend_template\n",
" )\n",
"\n",
" color_recommendations_promptValue = prompt.format(user_input=input_prompt_text) \n",
" return prompt\n",
" \n",
" elif chain == 'match':\n",
" matches_prompt = PromptTemplate(\n",
" input_variables=[\"recommended_brand_name, recommended_color_name\"],\n",
" partial_variables={\"format_instructions\": format_intruction},\n",
" template=match_template\n",
" )\n",
" return matches_prompt\n",
"\n",
" else:\n",
" return None\n",
" \n",
" "
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "14fe59a9-23eb-4e4f-bf2e-ef4b60b75784",
"metadata": {},
"outputs": [],
"source": [
"# Function to recommend and match colors\n",
"\n",
"def recommend_and_matches(input_prompt_text: str):\n",
" \n",
" import json\n",
" from langchain.chains import LLMChain, SimpleSequentialChain\n",
" \n",
" # Simple sequential chain\n",
" # Holds recommended colors from user input response\n",
" recommended_color_chain = LLMChain(llm=llm, prompt=generate_prompt('recommend', input_prompt_text))\n",
"\n",
" # Holds matchee colors from recommended colors\n",
" matched_color_chain = LLMChain(llm=llm,prompt=generate_prompt('match', input_prompt_text))\n",
" \n",
" # Build final chain\n",
" overall_chain = SimpleSequentialChain(chains=[recommended_color_chain, matched_color_chain], verbose=False)\n",
" colors = overall_chain.run(user_input)\n",
" \n",
" json_colors = json.loads(colors.strip('```json```'))\n",
" \n",
" return json_colors #(json.dumps(json_colors, indent = 4)) \n",
" "
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "e035639e-e96b-4d89-b83e-6150c016b50a",
"metadata": {},
"outputs": [],
"source": [
"recommend_and_matches(user_input)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "91f0ca04-dd2b-4eb9-b70e-2f23fdf3908f",
"metadata": {},
"outputs": [],
"source": [
"print(get_response_schema('recommend'))\n",
"print(get_response_schema('match')) "
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "f656df25-101c-456b-b819-ce910a584201",
"metadata": {},
"outputs": [],
"source": [
"print(generate_prompt('recommend',user_input))\n",
"print('**************************************')\n",
"print(generate_prompt('match',user_input))"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "7fcfa4db-c6bb-49c7-b1cb-7fd0ac1b04dc",
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": null,
"id": "b2baf3ef-fa98-44f6-b803-892a1dd43acc",
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": null,
"id": "734fb04e-cbbd-4788-9b74-a63e4ca98c07",
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": null,
"id": "80de8667-b358-4199-bf14-4cf39301d6a3",
"id": "9696006b-22cc-49b4-b40d-63d37fae67b5",
"metadata": {},
"outputs": [],
"source": []
Expand Down

0 comments on commit 233ff7d

Please sign in to comment.