Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Revise Draftail integration UI #35

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
49 changes: 42 additions & 7 deletions src/wagtail_ai/static_src/AIControl.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,11 @@ const LOADING_MESSAGES = [
'Interpreting your message, loading...',
];

function LoadingOverlay() {
function LoadingOverlay({
cancelHandler,
}: {
cancelHandler: React.MouseEventHandler<HTMLButtonElement>;
}) {
const loadingMessage =
LOADING_MESSAGES[Math.floor(Math.random() * LOADING_MESSAGES.length)];

Expand All @@ -35,6 +39,10 @@ function LoadingOverlay() {
</svg>
{loadingMessage}
</span>

<button onClick={cancelHandler} className="button button-secondary">
Cancel request
</button>
</div>
);
}
Expand Down Expand Up @@ -65,22 +73,45 @@ function AIControl({ getEditorState, onChange }: ControlComponentProps) {
const editorState = getEditorState() as EditorState;
const [isLoading, setIsLoading] = useState<Boolean>(false);
const [isDropdownOpen, setIsDropdownOpen] = useState<Boolean>(false);
const [error, setError] = useState(null);
const [error, setError] = useState<null | string>(null);
const aIControlRef = useRef<any>();

const container = aIControlRef?.current
? aIControlRef?.current.closest('[data-draftail-editor-wrapper]')
: null;

const abortController = new AbortController();

const cancelRequest: React.MouseEventHandler<HTMLButtonElement> = (e) => {
e.preventDefault();
// Call the abort method to cancel the request
abortController.abort();
setIsLoading(false); // Set loading to false to hide the overlay
};

const handleAction = async (prompt: Prompt) => {
setError(null);
setIsDropdownOpen(false);
setIsLoading(true);
try {
if (prompt.method === 'append') {
onChange(await processAction(editorState, prompt, handleAppend));
onChange(
await processAction(
editorState,
prompt,
handleAppend,
abortController,
),
);
} else {
onChange(await processAction(editorState, prompt, handleReplace));
onChange(
await processAction(
editorState,
prompt,
handleReplace,
abortController,
),
);
}
} catch (err) {
setError(err.message);
Expand All @@ -105,20 +136,24 @@ function AIControl({ getEditorState, onChange }: ControlComponentProps) {
) : null}
{error && container?.parentNode
? createPortal(
<>
<div className="w-field__errors">
<svg
className="icon icon-warning w-field__errors-icon"
aria-hidden="true"
>
<use href="#icon-warning"></use>
</svg>
&nbsp;
Morsey187 marked this conversation as resolved.
Show resolved Hide resolved
<p className="error-message">{error}</p>
</>,
</div>,
container.parentNode.previousElementSibling,
)
: null}
{isLoading && container
? createPortal(<LoadingOverlay />, container)
? createPortal(
<LoadingOverlay cancelHandler={cancelRequest} />,
container,
)
: null}
</>
);
Expand Down
33 changes: 26 additions & 7 deletions src/wagtail_ai/static_src/main.css
Original file line number Diff line number Diff line change
@@ -1,34 +1,52 @@
.Draftail-AI-LoadingOverlay {
background-color: rgba(255, 255, 255, 0.5);
backdrop-filter: blur(4px);
background-color: rgba(255, 255, 255, 0.7);

.w-theme-dark & {
background-color: rgba(60, 60, 60, 0.7);
}

@media (prefers-color-scheme: dark) {
.w-theme-system & {
background-color: rgba(60, 60, 60, 0.7);
}
}

backdrop-filter: blur(8px);
position: absolute;
z-index: 19;
width: 100%;
height: 100%;
top: 0;
left: 0;
z-index: 100;
display: flex;
align-items: center;
justify-content: center;
flex-direction: column;
}

.Draftail-AI-LoadingOverlay > span > svg {
margin-right: 5px;
}

.Draftail-AI-LoadingOverlay > span {
margin-bottom: 12px;
}

.Draftail-AI-ButtonDropdown {
position: absolute;
border: 1px solid var(--w-color-grey-200);
border-radius: 0.3125rem;
background-color: white;
background-color: var(--w-color-surface-tooltip);
color: var(--w-color-text-label-menus-default);
display: flex;
flex-direction: column;
min-width: 300px;
z-index: 200;
padding: 0.5rem 0;
}

.Draftail-AI-ButtonDropdown > button {
background-color: transparent;
color: var(--w-color-text-label-menus-default);
background-color: initial;
padding: 8px;
text-align: left;
}
Expand All @@ -39,7 +57,8 @@
}

.Draftail-AI-ButtonDropdown > button:hover {
background-color: var(--w-color-grey-200);
background-color: var(--w-color-surface-menu-item-active);
color: var(--w-color-text-label-menus-active);
}

/* Temp fix for wagtail issue https://github.com/wagtail/wagtail/issues/11302 */
Expand Down
10 changes: 8 additions & 2 deletions src/wagtail_ai/static_src/utils.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ class APIRequestError extends Error {}
const fetchAIResponse = async (
text: string,
prompt: Prompt,
signal: AbortSignal,
): Promise<string> => {
const formData = new FormData();
formData.append('text', text);
Expand All @@ -21,6 +22,7 @@ const fetchAIResponse = async (
const res = await fetch(window.WAGTAIL_AI_PROCESS_URL, {
method: 'POST',
body: formData,
signal: signal,
});
const json = await res.json();
if (res.ok) {
Expand All @@ -29,7 +31,6 @@ const fetchAIResponse = async (
throw new APIRequestError(json.error);
}
} catch (err) {
console.log('here');
throw new APIRequestError(err.message);
}
};
Expand Down Expand Up @@ -90,9 +91,14 @@ export const processAction = async (
editorState: EditorState,
response: string,
) => EditorState,
abortController: AbortController, // Pass the AbortController instance
): Promise<EditorState> => {
const content = editorState.getCurrentContent();
const plainText = content.getPlainText();
const response = await fetchAIResponse(plainText, prompt);
const response = await fetchAIResponse(
plainText,
prompt,
abortController.signal,
);
return editorStateHandler(editorState, response);
};
34 changes: 30 additions & 4 deletions src/wagtail_ai/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from django.http import JsonResponse
from django.views.decorators.csrf import csrf_exempt

from . import ai, prompts
from . import ai, prompts, types

logger = logging.getLogger(__name__)

Expand All @@ -13,14 +13,38 @@ class AIHandlerException(Exception):
pass


def _process_backend_request(
ai_backend: ai.AIBackend, pre_prompt: str, context: str
) -> types.AIResponse:
"""
Method for processing prompt requests and handling errors.

Errors will either be an API or Python library error, this method uses exception
chaining to retain the original error and raise a more generic error message to be sent to the front-end.

:return: The response message from the AI backend.
:raises AIHandlerException: Raised for specific error scenarios to be communicated to the front-end.
"""
try:
response = ai_backend.prompt_with_context(
pre_prompt=pre_prompt, context=context
)
except Exception as e:
# Raise a more generic error to send to the front-end
raise AIHandlerException(
"Error processing request, Please try again later."
) from e
return response


def _replace_handler(*, prompt: prompts.Prompt, text: str) -> str:
ai_backend = ai.get_ai_backend(alias=prompt.backend)
splitter = ai_backend.get_text_splitter()
texts = splitter.split_text(text)

for split in texts:
response = ai_backend.prompt_with_context(
pre_prompt=prompt.prompt, context=split
response = _process_backend_request(
ai_backend, pre_prompt=prompt.prompt, context=split
)
# Remove extra blank lines returned by the API
message = os.linesep.join([s for s in response.text().splitlines() if s])
Expand All @@ -35,7 +59,9 @@ def _append_handler(*, prompt: prompts.Prompt, text: str) -> str:
if length_calculator.get_splitter_length(text) > ai_backend.config.token_limit:
raise AIHandlerException("Cannot run completion on text this long")

response = ai_backend.prompt_with_context(pre_prompt=prompt.prompt, context=text)
response = _process_backend_request(
ai_backend, pre_prompt=prompt.prompt, context=text
)
# Remove extra blank lines returned by the API
message = os.linesep.join([s for s in response.text().splitlines() if s])

Expand Down
Loading