diff --git a/eaia/gmail.py b/eaia/gmail.py index 40a3d7b..5b2a701 100644 --- a/eaia/gmail.py +++ b/eaia/gmail.py @@ -1,71 +1,36 @@ +import os import base64 -import email.utils import logging -from datetime import datetime, time, timedelta -from email.mime.multipart import MIMEMultipart -from email.mime.text import MIMEText +from datetime import datetime, timedelta, time from pathlib import Path -from typing import Iterable +from typing import Iterable, List -import langsmith as ls import pytz from dateutil import parser +from google.auth.transport.requests import Request +from google.oauth2.credentials import Credentials +from google_auth_oauthlib.flow import InstalledAppFlow from googleapiclient.discovery import build -from langchain_core.pydantic_v1 import BaseModel, Field -from langchain_core.tools import tool -from langchain_core.runnables.config import RunnableConfig - - -from eaia.schemas import EmailData +from email.mime.multipart import MIMEMultipart +from email.mime.text import MIMEText +import email.utils +from functools import lru_cache logger = logging.getLogger(__name__) + +# Constants _SCOPES = [ - "https://www.googleapis.com/auth/calendar.readonly", - "https://www.googleapis.com/auth/calendar.events", - "https://www.googleapis.com/auth/gmail.send", - "https://www.googleapis.com/auth/gmail.readonly", -] -_OLD_SCOPES = [ "https://www.googleapis.com/auth/gmail.modify", "https://www.googleapis.com/auth/calendar", ] _ROOT = Path(__file__).parent.absolute() -_PORT = 54191 _SECRETS_DIR = _ROOT / ".secrets" _SECRETS_PATH = str(_SECRETS_DIR / "secrets.json") _TOKEN_PATH = str(_SECRETS_DIR / "token.json") - -def get_credentials(assistant_id): - from arcadepy import Arcade - - client = Arcade() - - try: - auth_response = client.auth.start( - user_id=assistant_id, - provider="google", - scopes=_SCOPES, - ) - - if auth_response.status != "completed": - raise ValueError - except ValueError: - auth_response = client.auth.start( - user_id=assistant_id, - provider="google", - scopes=_OLD_SCOPES, - ) - - if auth_response.status != "completed": - raise ValueError - - from google.oauth2.credentials import Credentials - - # Wait for the authorization to complete - auth_response = client.auth.wait_for_completion(auth_response) - return Credentials(auth_response.context.token) - +# Utility Functions +def get_header(headers, name, default=None): + return next((header["value"] for header in headers if header["name"].lower() == name.lower()), default) def extract_message_part(msg): """Recursively walk through the email parts to find message body.""" @@ -84,260 +49,110 @@ def extract_message_part(msg): return body return "No message body available." - def parse_time(send_time: str): try: - parsed_time = parser.parse(send_time) - return parsed_time + return parser.parse(send_time) except (ValueError, TypeError) as e: raise ValueError(f"Error parsing time: {send_time} - {e}") - -def create_message(sender, to, subject, message_text, thread_id, original_message_id): - message = MIMEMultipart() - message["to"] = ", ".join(to) - message["from"] = sender - message["subject"] = subject - message["In-Reply-To"] = original_message_id - message["References"] = original_message_id - message["Message-ID"] = email.utils.make_msgid() - msg = MIMEText(message_text) - message.attach(msg) - raw = base64.urlsafe_b64encode(message.as_bytes()) - raw = raw.decode() - return {"raw": raw, "threadId": thread_id} - - -def get_recipients( - headers, - email_address, - addn_receipients=None, -): - recipients = set(addn_receipients or []) - sender = None - for header in headers: - if header["name"].lower() in ["to", "cc"]: - recipients.update(header["value"].replace(" ", "").split(",")) - if header["name"].lower() == "from": - sender = header["value"] - if sender: - recipients.add(sender) # Ensure the original sender is included in the response - for r in list(recipients): - if email_address in r: - recipients.remove(r) - return list(recipients) - - -def send_message(service, user_id, message): - message = service.users().messages().send(userId=user_id, body=message).execute() - return message - - -def send_email( - email_id, - response_text, - email_address, - gmail_token: str | None = None, - gmail_secret: str | None = None, - addn_receipients=None, - assistant_id=None, -): - creds = get_credentials(assistant_id) - - service = build("gmail", "v1", credentials=creds) - message = service.users().messages().get(userId="me", id=email_id).execute() - - headers = message["payload"]["headers"] - message_id = next( - header["value"] for header in headers if header["name"].lower() == "message-id" - ) - thread_id = message["threadId"] - - # Get recipients and sender - recipients = get_recipients(headers, email_address, addn_receipients) - - # Create the response - subject = next( - header["value"] for header in headers if header["name"].lower() == "subject" - ) - response_subject = subject - response_message = create_message( - "me", recipients, response_subject, response_text, thread_id, message_id - ) - # Send the response - send_message(service, "me", response_message) - - -def fetch_group_emails( - to_email, - minutes_since: int = 30, - gmail_token: str | None = None, - gmail_secret: str | None = None, - assistant_id=None, -) -> Iterable[EmailData]: - creds = get_credentials(assistant_id) +# Credentials Handling +@lru_cache(maxsize=1) +def get_credentials(gmail_token: str | None = None, gmail_secret: str | None = None) -> Credentials: + creds = None + _SECRETS_DIR.mkdir(parents=True, exist_ok=True) + + if gmail_token: + with open(_TOKEN_PATH, "w") as token: + token.write(gmail_token) + if gmail_secret: + with open(_SECRETS_PATH, "w") as secret: + secret.write(gmail_secret) + + if os.path.exists(_TOKEN_PATH): + creds = Credentials.from_authorized_user_file(_TOKEN_PATH) + + if not creds or not creds.valid or not creds.has_scopes(_SCOPES): + if creds and creds.expired and creds.refresh_token: + creds.refresh(Request()) + else: + flow = InstalledAppFlow.from_client_secrets_file(_SECRETS_PATH, _SCOPES) + creds = flow.run_local_server(port=54191) + with open(_TOKEN_PATH, "w") as token: + token.write(creds.to_json()) + + return creds + +# Gmail Functions +def get_gmail_service(): + creds = get_credentials() + return build("gmail", "v1", credentials=creds) + +def fetch_emails(query: str, to_email: str, minutes_since: int = 30) -> Iterable[dict]: + service = get_gmail_service() after = int((datetime.now() - timedelta(minutes=minutes_since)).timestamp()) - with ls.trace( - "Fetching emails", - inputs={"to_email": to_email, "after": after}, - metadata={"assistant_id": assistant_id}, - ) as rt: - service = build("gmail", "v1", credentials=creds) - - query = f"(to:{to_email} OR from:{to_email}) after:{after}" - messages = [] - nextPageToken = None - # Fetch messages matching the query - while True: - results = ( - service.users() - .messages() - .list(userId="me", q=query, pageToken=nextPageToken) - .execute() - ) - if "messages" in results: - messages.extend(results["messages"]) - nextPageToken = results.get("nextPageToken") - if not nextPageToken: - break - - count = 0 - for message in messages: - try: - msg = ( - service.users() - .messages() - .get(userId="me", id=message["id"]) - .execute() - ) - thread_id = msg["threadId"] - payload = msg["payload"] - headers = payload.get("headers") - # Get the thread details - thread = ( - service.users().threads().get(userId="me", id=thread_id).execute() - ) - messages_in_thread = thread["messages"] - # Check the last message in the thread - last_message = messages_in_thread[-1] - last_headers = last_message["payload"]["headers"] - from_header = next( - header["value"] - for header in last_headers - if header["name"] == "From" - ) - last_from_header = next( - header["value"] - for header in last_message["payload"].get("headers") - if header["name"] == "From" - ) - if to_email in last_from_header: - yield { - "id": message["id"], - "thread_id": message["threadId"], - "user_respond": True, - } - # Check if the last message was from you and if the current message is the last in the thread - if to_email not in from_header and message["id"] == last_message["id"]: - subject = next( - header["value"] - for header in headers - if header["name"] == "Subject" - ) - from_email = next( - ( - header["value"] - for header in headers - if header["name"] == "From" - ), - "", - ).strip() - _to_email = next( - ( - header["value"] - for header in headers - if header["name"] == "To" - ), - "", - ).strip() - if reply_to := next( - ( - header["value"] - for header in headers - if header["name"] == "Reply-To" - ), - "", - ).strip(): - from_email = reply_to - send_time = next( - header["value"] - for header in headers - if header["name"] == "Date" - ) - # Only process emails that are less than an hour old - parsed_time = parse_time(send_time) - body = extract_message_part(payload) - yield { - "from_email": from_email, - "to_email": _to_email, - "subject": subject, - "page_content": body, - "id": message["id"], - "thread_id": message["threadId"], - "send_time": parsed_time.isoformat(), - } - count += 1 - except Exception as e: - rt.error = str(e) - print(f"Failed on {message}") - - rt.add_outputs({"count": count}) - logger.info(f"Found {count} emails.") - - -def mark_as_read( - message_id, - gmail_token: str | None = None, - gmail_secret: str | None = None, - assistant_id=None, -): - creds = get_credentials(assistant_id) - - service = build("gmail", "v1", credentials=creds) - try: - service.users().messages().modify( - userId="me", id=message_id, body={"removeLabelIds": ["UNREAD"]} - ).execute() - except: - pass - + query = f"(to:{to_email} OR from:{to_email}) after:{after}" + + messages = [] + next_page_token = None + while True: + results = service.users().messages().list(userId="me", q=query, pageToken=next_page_token).execute() + if "messages" in results: + messages.extend(results["messages"]) + next_page_token = results.get("nextPageToken") + if not next_page_token: + break + + for message in messages: + try: + msg = service.users().messages().get(userId="me", id=message["id"]).execute() + payload = msg["payload"] + headers = payload.get("headers") + + subject = get_header(headers, "Subject") + from_email = get_header(headers, "From") + send_time = parse_time(get_header(headers, "Date")) + body = extract_message_part(payload) + + yield { + "id": message["id"], + "subject": subject, + "from_email": from_email, + "send_time": send_time, + "body": body, + } + except Exception as e: + logger.error(f"Failed to fetch email: {e}") -class CalInput(BaseModel): - date_strs: list[str] = Field( - description="The days for which to retrieve events. Each day should be represented by dd-mm-yyyy string." - ) +def send_email_response(message_id: str, response_text: str, additional_recipients: List[str] = None): + service = get_gmail_service() + message = service.users().messages().get(userId="me", id=message_id).execute() + headers = message["payload"]["headers"] + recipients = set(additional_recipients or []) + recipients.add(get_header(headers, "From")) -@tool(args_schema=CalInput) -def get_events_for_days(date_strs: list[str], config: RunnableConfig): - """ - Retrieves events for a list of days. If you want to check for multiple days, call this with multiple inputs. + subject = get_header(headers, "Subject") + response_subject = f"Re: {subject}" - Input in the format of ['dd-mm-yyyy', 'dd-mm-yyyy'] + reply_message = MIMEMultipart() + reply_message["to"] = ", ".join(recipients) + reply_message["subject"] = response_subject + msg = MIMEText(response_text) + reply_message.attach(msg) - Args: - date_strs: The days for which to retrieve events (dd-mm-yyyy string). + raw = base64.urlsafe_b64encode(reply_message.as_bytes()).decode() + service.users().messages().send(userId="me", body={"raw": raw}).execute() - Returns: availability for those days. - """ +# Google Calendar Functions +def get_calendar_service(): + creds = get_credentials() + return build("calendar", "v3", credentials=creds) - creds = get_credentials(config["configurable"]["assistant_id"]) - service = build("calendar", "v3", credentials=creds) +def get_events_for_days(date_strs: List[str]): + service = get_calendar_service() results = "" for date_str in date_strs: - # Convert the date string to a datetime.date object day = datetime.strptime(date_str, "%d-%m-%Y").date() - start_of_day = datetime.combine(day, time.min).isoformat() + "Z" end_of_day = datetime.combine(day, time.max).isoformat() + "Z" @@ -353,40 +168,14 @@ def get_events_for_days(date_strs: list[str], config: RunnableConfig): .execute() ) events = events_result.get("items", []) - - results += f"***FOR DAY {date_str}***\n\n" + print_events(events) + results += f"***FOR DAY {date_str}***\n\n" + format_events(events) return results - -def format_datetime_with_timezone(dt_str, timezone="US/Pacific"): - """ - Formats a datetime string with the specified timezone. - - Args: - dt_str: The datetime string to format. - timezone: The timezone to use for formatting. - - Returns: - A formatted datetime string with the timezone abbreviation. - """ - dt = datetime.fromisoformat(dt_str.replace("Z", "+00:00")) - tz = pytz.timezone(timezone) - dt = dt.astimezone(tz) - return dt.strftime("%Y-%m-%d %I:%M %p %Z") - - -def print_events(events): - """ - Prints the events in a human-readable format. - - Args: - events: List of events to print. - """ +def format_events(events): if not events: - return "No events found for this day." + return "No events found for this day.\n" result = "" - for event in events: start = event["start"].get("dateTime", event["start"].get("date")) end = event["end"].get("dateTime", event["end"].get("date")) @@ -396,63 +185,11 @@ def print_events(events): start = format_datetime_with_timezone(start) end = format_datetime_with_timezone(end) - result += f"Event: {summary}\n" - result += f"Starts: {start}\n" - result += f"Ends: {end}\n" - result += "-" * 40 + "\n" + result += f"Event: {summary}\nStarts: {start}\nEnds: {end}\n{'-' * 40}\n" return result - -def send_calendar_invite( - emails, - title, - start_time, - end_time, - email_address, - timezone="PST", - assistant_id=None, -): - creds = get_credentials(assistant_id) - service = build("calendar", "v3", credentials=creds) - - # Parse the start and end times - start_datetime = datetime.fromisoformat(start_time) - end_datetime = datetime.fromisoformat(end_time) - emails = list(set(emails + [email_address])) - event = { - "summary": title, - "start": { - "dateTime": start_datetime.isoformat(), - "timeZone": timezone, - }, - "end": { - "dateTime": end_datetime.isoformat(), - "timeZone": timezone, - }, - "attendees": [{"email": email} for email in emails], - "reminders": { - "useDefault": False, - "overrides": [ - {"method": "email", "minutes": 24 * 60}, - {"method": "popup", "minutes": 10}, - ], - }, - "conferenceData": { - "createRequest": { - "requestId": f"{title}-{start_datetime.isoformat()}", - "conferenceSolutionKey": {"type": "hangoutsMeet"}, - } - }, - } - - try: - service.events().insert( - calendarId="primary", - body=event, - sendNotifications=True, - conferenceDataVersion=1, - ).execute() - return True - except Exception as e: - print(f"An error occurred: {e}") - return False +def format_datetime_with_timezone(dt_str, timezone="US/Pacific"): + dt = datetime.fromisoformat(dt_str.replace("Z", "+00:00")) + tz = pytz.timezone(timezone) + dt = dt.astimezone(tz) + return dt.strftime("%Y-%m-%d %I:%M %p %Z")