
Image description with OpenAI #81

Merged · 21 commits · Mar 26, 2024
Changes from 7 commits
21 changes: 20 additions & 1 deletion src/wagtail_ai/ai/__init__.py
@@ -11,7 +11,7 @@
from ..text_splitters.length import NaiveTextSplitterCalculator
from ..types import TextSplitterLengthCalculatorProtocol, TextSplitterProtocol
from ..utils import deprecation
from .base import AIBackend, BaseAIBackendConfigSettings
from .base import AIBackend, BackendFeature, BaseAIBackendConfigSettings


class TextSplittingSettingsDict(TypedDict):
@@ -161,3 +161,22 @@ def get_ai_backend(alias: str) -> AIBackend:
)

return ai_backend_cls(config=config)


class BackendNotFound(Exception):
pass


def get_backend(feature: BackendFeature = BackendFeature.TEXT_COMPLETION) -> AIBackend:
match feature:
case BackendFeature.TEXT_COMPLETION:
alias = "default"
case BackendFeature.IMAGE_DESCRIPTION:
alias = settings.WAGTAIL_AI.get("IMAGE_DESCRIPTION_BACKEND")
case _:
alias = None

if alias is None:
raise BackendNotFound(f"No backend found for {feature.name}")

return get_ai_backend(alias)
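
For context, get_backend() maps BackendFeature.IMAGE_DESCRIPTION to the alias named by WAGTAIL_AI["IMAGE_DESCRIPTION_BACKEND"] and keeps the existing "default" alias for text completion. Below is a hedged sketch of the corresponding settings; the BACKENDS/CLASS/CONFIG layout follows wagtail-ai's existing backend configuration, while the "vision" alias, model ID, and limits are illustrative assumptions, not part of this diff.

# A hedged sketch, not part of this diff: settings wiring for the new
# IMAGE_DESCRIPTION_BACKEND lookup. The "vision" alias, model ID, and limits
# are illustrative.
WAGTAIL_AI = {
    "BACKENDS": {
        # ... the existing "default" text-completion backend stays as-is ...
        "vision": {
            "CLASS": "wagtail_ai.ai.openai.OpenAIBackend",
            "CONFIG": {
                "MODEL_ID": "gpt-4-vision-preview",  # illustrative model ID
                "TOKEN_LIMIT": 300,  # becomes max_tokens in the API payload
                "TIMEOUT_SECONDS": 15,  # matches the OpenAIBackendConfig default
            },
        },
    },
    "IMAGE_DESCRIPTION_BACKEND": "vision",
}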
15 changes: 12 additions & 3 deletions src/wagtail_ai/ai/base.py
@@ -1,5 +1,6 @@
from abc import ABCMeta, abstractmethod
from abc import ABCMeta
from dataclasses import dataclass
from enum import Enum
from typing import (
Any,
ClassVar,
@@ -13,6 +14,7 @@
)

from django.core.exceptions import ImproperlyConfigured
from django.core.files import File

from .. import tokens
from ..types import (
@@ -22,6 +24,11 @@
)


class BackendFeature(Enum):
TEXT_COMPLETION = "TEXT_COMPLETION"
IMAGE_DESCRIPTION = "IMAGE_DESCRIPTION"


class BaseAIBackendConfigSettings(TypedDict):
MODEL_ID: Required[str]
TOKEN_LIMIT: NotRequired[int | None]
@@ -99,14 +106,13 @@ def __init__(
) -> None:
self.config = config

@abstractmethod
def prompt_with_context(
self, *, pre_prompt: str, context: str, post_prompt: str | None = None
) -> AIResponse:
"""
Given a prompt and a context, return a response.
"""
...
raise NotImplementedError("This backend does not support text completion")

def get_text_splitter(self) -> TextSplitterProtocol:
return self.config.text_splitter_class(
@@ -116,3 +122,6 @@ def get_text_splitter(self) -> TextSplitterProtocol:

def get_splitter_length_calculator(self) -> TextSplitterLengthCalculatorProtocol:
return self.config.text_splitter_length_calculator_class()

def describe_image(self, *, image_file: File, prompt: str) -> str:
raise NotImplementedError("This backend does not support generating alt tags")
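
Since prompt_with_context() now raises NotImplementedError instead of being abstract, a backend can implement only the features it supports. Here is a minimal sketch, not part of this diff, of a backend that provides just describe_image(); the class name, config wiring, and return value are illustrative assumptions modelled on the OpenAIBackend added below.

# A minimal sketch, not part of this diff: a backend that only supports image
# description. prompt_with_context() no longer needs to be implemented because
# the base class raises NotImplementedError instead of being abstract.
from django.core.files import File

from wagtail_ai.ai.base import AIBackend, BaseAIBackendConfig


class EchoImageBackend(AIBackend[BaseAIBackendConfig]):
    config_cls = BaseAIBackendConfig

    def describe_image(self, *, image_file: File, prompt: str) -> str:
        # A real backend would send the image bytes to a vision model here.
        return f"{prompt} ({image_file.name})"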
86 changes: 86 additions & 0 deletions src/wagtail_ai/ai/openai.py
@@ -0,0 +1,86 @@
import base64
import os
from dataclasses import dataclass
from typing import Any, NotRequired, Self

import requests
from django.core.files import File

from .base import AIBackend, BaseAIBackendConfig, BaseAIBackendConfigSettings


class OpenAIBackendConfigSettingsDict(BaseAIBackendConfigSettings):
TIMEOUT_SECONDS: NotRequired[int | None]


@dataclass(kw_only=True)
class OpenAIBackendConfig(BaseAIBackendConfig[OpenAIBackendConfigSettingsDict]):
timeout_seconds: int

@classmethod
def from_settings(
cls, config: OpenAIBackendConfigSettingsDict, **kwargs: Any
) -> Self:
timeout_seconds = config.get("TIMEOUT_SECONDS")
if timeout_seconds is None:
timeout_seconds = 15
kwargs.setdefault("timeout_seconds", timeout_seconds)

return super().from_settings(config, **kwargs)


class OpenAIBackend(AIBackend[OpenAIBackendConfig]):
config_cls = OpenAIBackendConfig

def describe_image(self, *, image_file: File, prompt: str) -> str:
if not prompt:
raise ValueError("Prompt must not be empty.")
with image_file.open() as f:
base64_image = base64.b64encode(f.read()).decode("utf-8")

response = self.chat_completions(
messages=[
{
"role": "user",
"content": [
{
"type": "text",
"text": prompt,
},
{
"type": "image_url",
"image_url": {
"url": f"data:image/jpeg;base64,{base64_image}"
},
},
],
},
],
)
return response["choices"][0]["message"]["content"]

def chat_completions(self, messages: list[dict[str, Any]]):
headers = {
"Content-Type": "application/json",
"Authorization": f"Bearer {self.get_openai_api_key()}",
}
payload = {
"model": self.config.model_id,
"messages": messages,
"max_tokens": self.config.token_limit,
}
response = requests.post(
"https://api.openai.com/v1/chat/completions",
headers=headers,
json=payload,
timeout=self.config.timeout_seconds,
)

response.raise_for_status()
return response.json()

def get_openai_api_key(self) -> str:
env_key = os.environ.get("OPENAI_API_KEY")
if env_key is None:
raise RuntimeError("OPENAI_API_KEY environment variable is not set.")
return env_key
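
For reference, a hedged sketch of calling the new backend directly: it assumes OPENAI_API_KEY is set in the environment (see get_openai_api_key() above) and that an IMAGE_DESCRIPTION_BACKEND alias pointing at OpenAIBackend is configured; the local file name and prompt are illustrative.

# A hedged usage sketch, not part of this diff. Requires OPENAI_API_KEY in the
# environment and an IMAGE_DESCRIPTION_BACKEND alias configured for OpenAIBackend.
from django.core.files import File

from wagtail_ai.ai import get_backend
from wagtail_ai.ai.base import BackendFeature

backend = get_backend(BackendFeature.IMAGE_DESCRIPTION)
with open("photo.jpg", "rb") as fh:  # illustrative local file
    alt_text = backend.describe_image(
        image_file=File(fh, name="photo.jpg"),
        prompt="Describe this image for use as alt text.",
    )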
35 changes: 28 additions & 7 deletions src/wagtail_ai/forms.py
@@ -1,6 +1,8 @@
from django import forms
from django.core.exceptions import ValidationError
from django.utils.translation import gettext_lazy as _
from wagtail.admin.staticfiles import versioned_static
from wagtail.images.forms import BaseImageForm


class PromptTextField(forms.CharField):
@@ -18,7 +20,17 @@ class PromptUUIDField(forms.UUIDField):
}


class PromptForm(forms.Form):
class ApiForm(forms.Form):
def errors_for_json_response(self) -> str:
errors_for_response = []
for _field, errors in self.errors.get_json_data().items():
for error in errors:
errors_for_response.append(error["message"])

return " \n".join(errors_for_response)


class PromptForm(ApiForm):
text = PromptTextField()
prompt = PromptUUIDField()

@@ -31,10 +43,19 @@ def clean_prompt(self):

return prompt_uuid

def errors_for_json_response(self) -> str:
errors_for_response = []
for _field, errors in self.errors.get_json_data().items():
for error in errors:
errors_for_response.append(error["message"])

return " \n".join(errors_for_response)
class DescribeImageApiForm(ApiForm):
image_id = forms.CharField()


class DescribeImageForm(BaseImageForm):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
if self.instance and self.instance.pk:
widget = self.fields["title"].widget
widget.attrs["data-wagtail-ai-image-id"] = str(self.instance.pk)
widget.template_name = "wagtail_ai/widgets/image_title.html"

class Media:
js = [versioned_static("wagtail_ai/image-description.js")]
css = {"all": [versioned_static("wagtail_ai/image-description.css")]}
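
The ApiForm refactor shares errors_for_json_response() between the prompt and describe-image endpoints. Below is a hedged sketch, not part of these 7 commits, of the view side this enables: request.POST feeds the form directly (the point raised in the FormData-versus-JSON review discussion on the api.tsx diff below), and the response uses the message/error keys that api.tsx expects. The view function name and prompt text are illustrative.

# A hedged sketch, not part of this diff: a describe-image view built on the
# new forms and backend selection. Names and prompt text are illustrative.
from django.http import JsonResponse
from wagtail.images import get_image_model

from wagtail_ai.ai import get_backend
from wagtail_ai.ai.base import BackendFeature
from wagtail_ai.forms import DescribeImageApiForm


def describe_image(request):
    form = DescribeImageApiForm(request.POST)  # FormData maps straight onto request.POST
    if not form.is_valid():
        # api.tsx surfaces json.error via APIRequestError on non-OK responses.
        return JsonResponse({"error": form.errors_for_json_response()}, status=400)
    image = get_image_model().objects.get(pk=form.cleaned_data["image_id"])
    description = get_backend(BackendFeature.IMAGE_DESCRIPTION).describe_image(
        image_file=image.file,
        prompt="Describe this image for use as alt text.",
    )
    return JsonResponse({"message": description})  # api.tsx reads json.message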
8 changes: 2 additions & 6 deletions src/wagtail_ai/static_src/AIControl.tsx
@@ -4,12 +4,8 @@ import { ToolbarButton } from 'draftail';
import { createPortal } from 'react-dom';
import { useOutsideAlerter } from './hooks';
import WandIcon from './WandIcon';
import {
handleAppend,
handleReplace,
processAction,
getAIConfiguration,
} from './utils';
import { handleAppend, handleReplace, processAction } from './utils';
import { getAIConfiguration } from './api';

import type { ControlComponentProps } from 'draftail';
import type { Prompt } from './custom';
50 changes: 50 additions & 0 deletions src/wagtail_ai/static_src/api.tsx
@@ -0,0 +1,50 @@
import type { ApiUrlName, Prompt, WagtailAiConfiguration } from './custom';

class APIRequestError extends Error {}

export const getAIConfiguration = (): WagtailAiConfiguration => {
const configurationElement =
document.querySelector<HTMLScriptElement>('#wagtail-ai-config');
if (!configurationElement || !configurationElement.textContent) {
throw new Error('No wagtail-ai configuration found.');
}

try {
return JSON.parse(configurationElement.textContent);
} catch (err) {
throw new SyntaxError(
`Error parsing wagtail-ai configuration: ${err.message}`,
);
}
};

// TODO rename
export const fetchAIResponse = async (
text: string,
prompt: Prompt,
signal: AbortSignal,
): Promise<string> => {
const formData = new FormData();
formData.append('text', text);
formData.append('prompt', prompt.uuid);
return fetchResponse('TEXT_COMPLETION', formData, signal);
};

export const fetchResponse = async (
action: keyof typeof ApiUrlName,
body: FormData,
Review thread (inline comment on this line):

Member: Question: out of curiosity, why do we prefer FormData instead of JSON? Is that so we could send files in the future, or is that the convention in Wagtail?

Member (author): I was wondering this too, but went with what the text completion endpoint was already doing.

Member: Feels like an extra step to parse JSON in the view, whereas with FormData we're just passing request.POST to the form?

signal?: AbortSignal,
): Promise<string> => {
try {
const urls = getAIConfiguration().urls;
const res = await fetch(urls[action], { method: 'POST', body, signal });
const json = await res.json();
if (res.ok) {
return json.message;
} else {
throw new APIRequestError(json.error);
}
} catch (err) {
throw new APIRequestError(err.message);
}
};
12 changes: 11 additions & 1 deletion src/wagtail_ai/static_src/custom.d.ts
@@ -1,9 +1,19 @@
/* eslint-disable no-unused-vars */
export {};

export interface WagtailAiConfigurationUrls {}

export enum ApiUrlName {
TEXT_COMPLETION = 'TEXT_COMPLETION',
DESCRIBE_IMAGE = 'DESCRIBE_IMAGE',
}

export interface WagtailAiConfiguration {
aiPrompts: Prompt[];
aiProcessUrl: string;
urls: {
[ApiUrlName.TEXT_COMPLETION]: string;
[ApiUrlName.DESCRIBE_IMAGE]: string;
};
}

export type Prompt = {
16 changes: 16 additions & 0 deletions src/wagtail_ai/static_src/image_description.css
@@ -0,0 +1,16 @@
.wagtail-ai-image-title {
display: flex;
flex-direction: row;
align-items: stretch;
gap: 0.3125rem;
}

.wagtail-ai-button {
background: white;
border: 1px solid var(--w-color-border-field-default);
border-radius: 0.3125rem;
}

.wagtail-ai-button:hover {
border-color: var(--w-color-border-field-hover);
}
39 changes: 39 additions & 0 deletions src/wagtail_ai/static_src/image_description.tsx
@@ -0,0 +1,39 @@
import './image_description.css';
import { fetchResponse } from './api';

// TODO find a way to import the same SVG file here and in WandIcon.tsx
const wandIcon = `<svg
width="16"
height="16"
class="Draftail-Icon"
aria-hidden="true"
viewBox="0 0 576 512"
>
<path d="M234.7 42.7L197 56.8c-3 1.1-5 4-5 7.2s2 6.1 5 7.2l37.7 14.1L248.8 123c1.1 3 4 5 7.2 5s6.1-2 7.2-5l14.1-37.7L315 71.2c3-1.1 5-4 5-7.2s-2-6.1-5-7.2L277.3 42.7 263.2 5c-1.1-3-4-5-7.2-5s-6.1 2-7.2 5L234.7 42.7zM46.1 395.4c-18.7 18.7-18.7 49.1 0 67.9l34.6 34.6c18.7 18.7 49.1 18.7 67.9 0L529.9 116.5c18.7-18.7 18.7-49.1 0-67.9L495.3 14.1c-18.7-18.7-49.1-18.7-67.9 0L46.1 395.4zM484.6 82.6l-105 105-23.3-23.3 105-105 23.3 23.3zM7.5 117.2C3 118.9 0 123.2 0 128s3 9.1 7.5 10.8L64 160l21.2 56.5c1.7 4.5 6 7.5 10.8 7.5s9.1-3 10.8-7.5L128 160l56.5-21.2c4.5-1.7 7.5-6 7.5-10.8s-3-9.1-7.5-10.8L128 96 106.8 39.5C105.1 35 100.8 32 96 32s-9.1 3-10.8 7.5L64 96 7.5 117.2zm352 256c-4.5 1.7-7.5 6-7.5 10.8s3 9.1 7.5 10.8L416 416l21.2 56.5c1.7 4.5 6 7.5 10.8 7.5s9.1-3 10.8-7.5L480 416l56.5-21.2c4.5-1.7 7.5-6 7.5-10.8s-3-9.1-7.5-10.8L480 352l-21.2-56.5c-1.7-4.5-6-7.5-10.8-7.5s-9.1 3-10.8 7.5L416 352l-56.5 21.2z" />
</svg>`;

document.addEventListener('wagtail-ai:image-form', (event) => {
const form = event.target as HTMLFormElement;
const input = form.querySelector('[name*=title]') as HTMLInputElement;
const imageId = input.dataset.wagtailAiImageId as string; // dataset camel-cases data-wagtail-ai-image-id
const inputContainer = input.parentNode as HTMLDivElement;
inputContainer.classList.add('wagtail-ai-image-title'); // TODO better class name
const button = document.createElement('button');
button.classList.add('wagtail-ai-button'); // TODO better class name
button.innerHTML = wandIcon;
inputContainer.appendChild(button);

button.addEventListener('click', async (event) => {
event.preventDefault();
button.disabled = true;
button.innerText = '…';

const formData = new FormData();
formData.append('image_id', imageId);
// TODO error handling
const response = await fetchResponse('DESCRIBE_IMAGE', formData);
input.value = response;
button.innerHTML = wandIcon;
button.disabled = false;
});
});