diff --git a/src/app/app-routing.module.ts b/src/app/app-routing.module.ts index de1f51b..424393f 100644 --- a/src/app/app-routing.module.ts +++ b/src/app/app-routing.module.ts @@ -1,6 +1,5 @@ import { NgModule } from '@angular/core'; import { RouterModule, Routes } from '@angular/router'; -import { SettingsComponent } from './components/settings/settings.component'; import { SvgCreatorComponent } from './components/svg-creator/svg-creator.component'; import { LayoutComponent } from './layout/layout.component'; diff --git a/src/app/components/svg-creator/svg-creator.component.html b/src/app/components/svg-creator/svg-creator.component.html index 0cbd3a7..3e0edd4 100644 --- a/src/app/components/svg-creator/svg-creator.component.html +++ b/src/app/components/svg-creator/svg-creator.component.html @@ -23,18 +23,28 @@
Input
+
-
SVG Preview
+
SVG {{ previewTitle }}
@@ -44,6 +54,7 @@
SVG Preview
type="button" icon="pi pi-file-import" label="Import SVG" + [disabled]="submitting" (click)="chooseFile()"> SVG Preview icon="pi pi-file-export" [disabled]="!sanitizedSvg" label="Export SVG" + [disabled]="submitting" (click)="exportSVG()">
; @@ -34,8 +35,31 @@ export class SvgCreatorComponent { this.clearFile(); } + ngOnInit() { + this.previewTitle = 'Preview'; + this.creator.svgCode$.subscribe(rs => { + if (rs.done) { + // SSE 结束后统一解析 + const svgCodes = this.creator.extractSVGCode(rs.content); + if (svgCodes.length) { + this.svgCode = svgCodes.join('\n'); + } + } else { + // 临时显示输出 + this._svgCode = rs.content + .replace(/[\s\S]*(?=())[\s\S]*/g, ''); + } + }); + } + + ngOnDestroy() { + this.creator.svgCode$.unsubscribe(); + } + async submit() { this.submitting = true; + this.previewTitle = 'Generating...'; let svgCodes = this.creator.extractSVGCode(this.userInput); if (svgCodes.length) { this.svgCode = svgCodes.join('\n'); @@ -43,13 +67,14 @@ export class SvgCreatorComponent { const nonSVG = this.creator.extractNonSVGCode(this.userInput); // 仅在有其他内容的时候,调用 API if (nonSVG) { - const outputMsg = await this.creator.analyzeInput(this.userInput, this.svgCode); - svgCodes = this.creator.extractSVGCode(outputMsg); - if (svgCodes.length) { - this.svgCode = svgCodes.join('\n'); - } + await this.creator.analyzeInputStreaming(this.userInput, this.svgCode); } this.submitting = false; + this.previewTitle = 'Preview'; + } + + stop() { + this.creator.stopAnalyze(); } clear() { @@ -59,10 +84,8 @@ export class SvgCreatorComponent { } exportSVG() { - console.log('exportSVG'); const svg = this.svgPreview?.nativeElement.querySelector('svg'); if (!svg) { - console.log('null svg'); return; } const serializer = new XMLSerializer(); diff --git a/src/app/services/chatgpt-api.service.ts b/src/app/services/chatgpt-api.service.ts index 88215ee..cda56c0 100644 --- a/src/app/services/chatgpt-api.service.ts +++ b/src/app/services/chatgpt-api.service.ts @@ -1,21 +1,20 @@ import { Injectable } from '@angular/core'; -import { HttpClient } from '@angular/common/http'; -import { lastValueFrom } from 'rxjs'; +import { Subject } from 'rxjs'; import { ConfigService } from './config.service'; -import { Message } from '../types'; +import { Message, StreamingResult } from '../types'; import { MessageService } from 'primeng/api'; @Injectable({ providedIn: 'root', }) export class ChatGPTAPIService { - constructor( - private config: ConfigService, - private http: HttpClient, - private messageSvc: MessageService - ) {} + constructor(private config: ConfigService, private messageSvc: MessageService) {} - async doChat(messages: Message[]): Promise { + async doChatStream( + messages: Message[], + sub: Subject, + signal: AbortSignal + ): Promise { const apiKey = await this.config.get('openAPIKey'); if (!apiKey) { this.messageSvc.add({ @@ -23,7 +22,7 @@ export class ChatGPTAPIService { summary: 'Missing API Key', detail: 'Please update your openAPIKey in Settings.', }); - return ''; + return false; } const apiHost = await this.config.get('openAPIHost'); const currentModel = await this.config.get('currentModel'); @@ -33,23 +32,75 @@ export class ChatGPTAPIService { messages: messages, max_tokens: 2048, temperature: 0.5, - }; - const headers = { - 'Content-Type': 'application/json', - Authorization: `Bearer ${apiKey}`, + stream: true, }; try { - const rsp$ = this.http.post(apiURL, data, { headers }); - const rsp = await lastValueFrom(rsp$); - return rsp.choices[0].message.content.trim(); - } catch (err) { - console.error(err); + const rsp = await fetch(apiURL, { + method: 'POST', + headers: { + 'Content-Type': 'application/json', + Authorization: `Bearer ${apiKey}`, + }, + body: JSON.stringify(data), + signal, + }); + if (!rsp.ok) { + this.messageSvc.add({ + severity: 'error', + summary: 'Error', + detail: 'Failed call openai api', + }); + return false; + } + if (!rsp.body) { + return true; + } + await this.parseSSE(rsp.body, sub); + } catch (err: any) { this.messageSvc.add({ severity: 'error', summary: 'Error', - detail: 'Failed call openai api', + detail: err.message || 'Failed call openai api', }); - return ''; + return false; } + return true; + } + + async parseSSE(rs: ReadableStream, sub: Subject): Promise { + let readData; + let partialMessage = ''; + let resultContent = ''; + const reader = rs.getReader(); + const decoder = new TextDecoder('utf-8'); + do { + readData = await reader.read(); + if (!readData.done) { + partialMessage += decoder.decode(readData.value, { stream: true }); + // Split the message by new lines + const lines = partialMessage.split('\n'); + // If the last line is not empty, it means that it's an incomplete message + if (lines[lines.length - 1] !== '') { + partialMessage = lines.pop() ?? ''; + } else { + partialMessage = ''; + } + // Process each line + for (const line of lines) { + if (line.startsWith('data: ')) { + const jsonString = line.slice('data: '.length); + // Check if the jsonString is not "[DONE]" + if (jsonString !== '[DONE]') { + const payload = JSON.parse(jsonString); + resultContent += payload?.choices[0]?.delta?.content ?? ''; + sub.next({ content: resultContent, done: false }); + } else { + sub.next({ content: resultContent, done: true }); + return; + } + } + } + } + } while (!readData.done); } } diff --git a/src/app/services/creator.service.ts b/src/app/services/creator.service.ts index 97d1cd9..6dd8751 100644 --- a/src/app/services/creator.service.ts +++ b/src/app/services/creator.service.ts @@ -1,5 +1,6 @@ import { Injectable } from '@angular/core'; -import { Message } from '../types'; +import { Subject } from 'rxjs'; +import { Message, StreamingResult } from '../types'; import { ChatGPTAPIService } from './chatgpt-api.service'; const SVG_REG_EXP: RegExp = //gi; @@ -8,7 +9,9 @@ const SVG_REG_EXP: RegExp = //gi; providedIn: 'root', }) export class CreatorService { - readonly systemPrompts: Message[] = [ + public svgCode$ = new Subject(); + + private readonly systemPrompts: Message[] = [ { role: 'user', content: @@ -19,22 +22,34 @@ export class CreatorService { content: 'ok', }, ]; + private abortCtl?: AbortController; constructor(private api: ChatGPTAPIService) {} - async analyzeInput(msg: string, originalSVGCode: string) { + async analyzeInputStreaming(msg: string, originalSVGCode: string) { let headMsg = ''; if (originalSVGCode) { headMsg += `Given the original SVG: ${originalSVGCode}\n`; } - const res = await this.api.doChat([ - ...this.systemPrompts, - { - role: 'user', - content: headMsg + msg, - }, - ]); - return res; + this.abortCtl = new AbortController(); + const rsp = await this.api.doChatStream( + [ + ...this.systemPrompts, + { + role: 'user', + content: headMsg + msg, + }, + ], + this.svgCode$, + this.abortCtl.signal + ); + if (!rsp) { + return; + } + } + + stopAnalyze() { + this.abortCtl?.abort(); } // 提取 SVG 代码 diff --git a/src/app/types.ts b/src/app/types.ts index b18c3a0..3200b0d 100644 --- a/src/app/types.ts +++ b/src/app/types.ts @@ -1,5 +1,7 @@ export type Message = { role: string; content: string }; +export type StreamingResult = { content: string; done: boolean }; + export const enum GPTModels { GPT_3_5_TURBO = 'gpt-3.5-turbo', GPT_4 = 'gpt-4', diff --git a/src/manifest.json b/src/manifest.json index e2407d1..51276c5 100644 --- a/src/manifest.json +++ b/src/manifest.json @@ -1,6 +1,6 @@ { "name": "ChatGPT SVG Creator", - "version": "1.0.1", + "version": "1.1.0", "description": "An extension uses ChatGPT to create and preview SVG", "manifest_version": 3, "action": {},