From 7e12e367dba4860362e6656b3a90e839b3ce3ed4 Mon Sep 17 00:00:00 2001 From: Manuel Mosquera Date: Mon, 31 Jul 2023 10:58:01 -0500 Subject: [PATCH 1/3] :sparkles: Add a planner agent --- src/main_agent.py | 17 ++++++++++++++--- src/openai_utils.py | 11 +++++++++++ src/planner_agent/planner_agent.py | 29 +++++++++++++++++++++++++++++ src/planner_agent/prompts.py | 4 ++++ 4 files changed, 58 insertions(+), 3 deletions(-) create mode 100644 src/planner_agent/planner_agent.py create mode 100644 src/planner_agent/prompts.py diff --git a/src/main_agent.py b/src/main_agent.py index fdb2bea..cd065ff 100644 --- a/src/main_agent.py +++ b/src/main_agent.py @@ -3,6 +3,8 @@ import requests from openai_utils import get_code_from_open_ai from html_cleaner import get_cleaned_html +from planner_agent.planner_agent import get_plan +import queue def send_message(message): @@ -21,10 +23,19 @@ def send_message(message): def main(): history_messages = [] + goal = input("Enter a goal for the web browser agent or exit to quit: ") + + if goal.lower() == "exit": + return + + print('This is the plan for the goal: ') + actions = queue.Queue() + for action in get_plan(goal): + actions.put(action) + print(action) + while True: - action = input("Enter a web action or exit to quit: ") - if action.lower() == "exit": - break + action = actions.get() is_error = input("Is this an error? (y/n): ") is_error = is_error.lower() == "y" history_messages, message = get_code_from_open_ai( diff --git a/src/openai_utils.py b/src/openai_utils.py index a699333..461af90 100644 --- a/src/openai_utils.py +++ b/src/openai_utils.py @@ -1,6 +1,7 @@ import os import openai from dotenv import load_dotenv, find_dotenv +from typing import Literal _ = load_dotenv(find_dotenv()) # read local .env file @@ -31,6 +32,16 @@ "content": """{response}""", } +TypeMessage = dict[Literal["role", "content"], str] + +def format_message(message: str, role:Literal["user", "system", "assistant"]="user") -> TypeMessage: + return { + "role": role, + "content": message + } + +def compile_messages(*messages: tuple[TypeMessage]) -> list[TypeMessage]: return list(messages) + def replace_action_in_message(message, action): new_message = message.copy() diff --git a/src/planner_agent/planner_agent.py b/src/planner_agent/planner_agent.py new file mode 100644 index 0000000..bcf9a57 --- /dev/null +++ b/src/planner_agent/planner_agent.py @@ -0,0 +1,29 @@ +import os +import sys +import re + +os.chdir('./src') +# setting path +sys.path.append(os.getcwd()) + +# importing project modules +from openai_utils import compile_messages, format_message, get_completion_from_messages +from planner_agent.prompts import CREATE_PLAN + + +def get_raw_plan(goal: str) -> str: + create_plan_prompt = format_message(CREATE_PLAN.format(goal=goal)) + context = compile_messages(create_plan_prompt) + plan = get_completion_from_messages(context) + return plan + +def parse_plan(plan: str) -> list[str]: + # Use regex to find all numbered items in the text + numbered_items = re.findall(r'\d+\.\s+(.+)', plan) + + return [item.strip() for item in numbered_items] + +def get_plan(goal: str) -> list[str]: + raw_plan = get_raw_plan(goal) + actions = parse_plan(raw_plan) + return actions diff --git a/src/planner_agent/prompts.py b/src/planner_agent/prompts.py new file mode 100644 index 0000000..9867914 --- /dev/null +++ b/src/planner_agent/prompts.py @@ -0,0 +1,4 @@ +CREATE_PLAN = """Given that you are an AI agent using selenium to navigate a web browser. What are the high level actions you would need to take to achieve the goal "{goal}"? +Give me your answer only in the following format: +To achieve the goal "{goal}" I will need to perform the following actions: +1. ... """ \ No newline at end of file From 394395fcb1f8f10aefaa5f2c256a66d07b286f36 Mon Sep 17 00:00:00 2001 From: Manuel Mosquera Date: Mon, 31 Jul 2023 17:50:14 -0500 Subject: [PATCH 2/3] :sparkles: Improve the initial plan with user feedback --- src/main_agent.py | 2 -- src/planner_agent/planner_agent.py | 22 ++++++++++++++++++++-- src/planner_agent/prompts.py | 10 ++++++---- 3 files changed, 26 insertions(+), 8 deletions(-) diff --git a/src/main_agent.py b/src/main_agent.py index cd065ff..84a9f7a 100644 --- a/src/main_agent.py +++ b/src/main_agent.py @@ -28,11 +28,9 @@ def main(): if goal.lower() == "exit": return - print('This is the plan for the goal: ') actions = queue.Queue() for action in get_plan(goal): actions.put(action) - print(action) while True: action = actions.get() diff --git a/src/planner_agent/planner_agent.py b/src/planner_agent/planner_agent.py index bcf9a57..86816c9 100644 --- a/src/planner_agent/planner_agent.py +++ b/src/planner_agent/planner_agent.py @@ -15,15 +15,33 @@ def get_raw_plan(goal: str) -> str: create_plan_prompt = format_message(CREATE_PLAN.format(goal=goal)) context = compile_messages(create_plan_prompt) plan = get_completion_from_messages(context) - return plan + return plan, context def parse_plan(plan: str) -> list[str]: # Use regex to find all numbered items in the text numbered_items = re.findall(r'\d+\.\s+(.+)', plan) + # Get rid of the first item, which is open the web browser + numbered_items = numbered_items[1:] + return [item.strip() for item in numbered_items] def get_plan(goal: str) -> list[str]: - raw_plan = get_raw_plan(goal) + raw_plan, context = get_raw_plan(goal) + + # Get user feedback on the plan + while True: + print("Here is the plan:") + print(raw_plan) + print("Does this plan look good?") + user_feedback = input("Press enter to continue with the plan or provide some feedback: ") + if user_feedback == "": + break + else: + # Improve the plan based on the user feedback + context.append(format_message(raw_plan, role="assistant")) + context.append(format_message(user_feedback)) + raw_plan = get_completion_from_messages(context) + actions = parse_plan(raw_plan) return actions diff --git a/src/planner_agent/prompts.py b/src/planner_agent/prompts.py index 9867914..67fb8fb 100644 --- a/src/planner_agent/prompts.py +++ b/src/planner_agent/prompts.py @@ -1,4 +1,6 @@ -CREATE_PLAN = """Given that you are an AI agent using selenium to navigate a web browser. What are the high level actions you would need to take to achieve the goal "{goal}"? -Give me your answer only in the following format: -To achieve the goal "{goal}" I will need to perform the following actions: -1. ... """ \ No newline at end of file +PLAN_FORMAT = """Give me your answer only in the following format: +To achieve the goal I will need to perform the following actions: +1. Open the web browser +2. ... """ + +CREATE_PLAN = """Given that you are an AI agent using selenium to navigate a web browser. What are the high level actions you would need to take to achieve the goal "{goal}"?""" + PLAN_FORMAT \ No newline at end of file From 84d26d2f6c2c57d9f5383341c959325b1d930898 Mon Sep 17 00:00:00 2001 From: Manuel Mosquera Date: Mon, 31 Jul 2023 18:09:21 -0500 Subject: [PATCH 3/3] :rotating_light: Add test mode --- src/main_agent.py | 11 +++++++++-- src/planner_agent/constants.py | 10 ++++++++++ src/planner_agent/planner_agent.py | 11 ++++++++--- 3 files changed, 27 insertions(+), 5 deletions(-) create mode 100644 src/planner_agent/constants.py diff --git a/src/main_agent.py b/src/main_agent.py index 84a9f7a..41b6065 100644 --- a/src/main_agent.py +++ b/src/main_agent.py @@ -5,6 +5,7 @@ from html_cleaner import get_cleaned_html from planner_agent.planner_agent import get_plan import queue +import os def send_message(message): @@ -22,14 +23,20 @@ def send_message(message): def main(): + is_test_mode = os.getenv("TEST_MODE", False) + history_messages = [] - goal = input("Enter a goal for the web browser agent or exit to quit: ") + + if is_test_mode: + goal = "Search a blue car" + else: + goal = input("Enter a goal for the web browser agent or exit to quit: ") if goal.lower() == "exit": return actions = queue.Queue() - for action in get_plan(goal): + for action in get_plan(goal, test_mode=is_test_mode): actions.put(action) while True: diff --git a/src/planner_agent/constants.py b/src/planner_agent/constants.py new file mode 100644 index 0000000..c363c7f --- /dev/null +++ b/src/planner_agent/constants.py @@ -0,0 +1,10 @@ +DEFAULT_RAW_PLAN = """To achieve the goal I will need to perform the following actions: +1. Open the web browser. +2. Navigate to a search engine or a car dealership website. +3. Locate the search bar on the webpage. +4. Enter the keyword "blue car" into the search bar. +5. Submit the search query. +6. Wait for the search results to load. +7. Click on the most relevant search result webpage. +8. Wait for the selected webpage to load. +9. Close the web browser.""" \ No newline at end of file diff --git a/src/planner_agent/planner_agent.py b/src/planner_agent/planner_agent.py index 86816c9..1644bdc 100644 --- a/src/planner_agent/planner_agent.py +++ b/src/planner_agent/planner_agent.py @@ -9,6 +9,7 @@ # importing project modules from openai_utils import compile_messages, format_message, get_completion_from_messages from planner_agent.prompts import CREATE_PLAN +from planner_agent.constants import DEFAULT_RAW_PLAN def get_raw_plan(goal: str) -> str: @@ -26,11 +27,15 @@ def parse_plan(plan: str) -> list[str]: return [item.strip() for item in numbered_items] -def get_plan(goal: str) -> list[str]: - raw_plan, context = get_raw_plan(goal) +def get_plan(goal: str, test_mode = False) -> list[str]: + if test_mode: + raw_plan = DEFAULT_RAW_PLAN + print(raw_plan) + else: + raw_plan, context = get_raw_plan(goal) # Get user feedback on the plan - while True: + while True and not test_mode: print("Here is the plan:") print(raw_plan) print("Does this plan look good?")