Skip to content

Commit

Permalink
feat: support stream mode (#1)
Browse files Browse the repository at this point in the history
  • Loading branch information
xieziyu authored Apr 4, 2023
1 parent fdb2a00 commit 36d4d6d
Show file tree
Hide file tree
Showing 7 changed files with 147 additions and 45 deletions.
1 change: 0 additions & 1 deletion src/app/app-routing.module.ts
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
import { NgModule } from '@angular/core';
import { RouterModule, Routes } from '@angular/router';
import { SettingsComponent } from './components/settings/settings.component';
import { SvgCreatorComponent } from './components/svg-creator/svg-creator.component';
import { LayoutComponent } from './layout/layout.component';

Expand Down
16 changes: 14 additions & 2 deletions src/app/components/svg-creator/svg-creator.component.html
Original file line number Diff line number Diff line change
Expand Up @@ -23,18 +23,28 @@ <h5>Input</h5>
<button
pButton
pRipple
*ngIf="!submitting"
label="Clear"
icon="pi pi-trash"
iconPos="right"
class="p-button-danger"
class="p-button-secondary"
(click)="clear()"></button>
<button
pButton
pRipple
*ngIf="submitting"
label="Stop"
icon="pi pi-stop-circle"
iconPos="right"
class="p-button-danger"
(click)="stop()"></button>
</div>
</div>
</div>

<div class="col-12 md:col-6">
<div class="card">
<h5>SVG Preview</h5>
<h5>SVG {{ previewTitle }}</h5>
<markdown ngPreserveWhitespaces clipboard [data]="svgCode | language : 'svg'"></markdown>
<p-divider></p-divider>
<div class="flex justify-content-center flex-wrap gap-2">
Expand All @@ -44,6 +54,7 @@ <h5>SVG Preview</h5>
type="button"
icon="pi pi-file-import"
label="Import SVG"
[disabled]="submitting"
(click)="chooseFile()">
<input
#advancedfileinput
Expand All @@ -61,6 +72,7 @@ <h5>SVG Preview</h5>
icon="pi pi-file-export"
[disabled]="!sanitizedSvg"
label="Export SVG"
[disabled]="submitting"
(click)="exportSVG()"></button>
</div>
<div
Expand Down
41 changes: 32 additions & 9 deletions src/app/components/svg-creator/svg-creator.component.ts
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
import { Component, ElementRef, ViewChild } from '@angular/core';
import { Component, ElementRef, OnDestroy, OnInit, ViewChild } from '@angular/core';
import { DomSanitizer, SafeHtml } from '@angular/platform-browser';
import { CreatorService } from '../../services/creator.service';

Expand All @@ -7,11 +7,12 @@ import { CreatorService } from '../../services/creator.service';
templateUrl: './svg-creator.component.html',
styleUrls: ['./svg-creator.component.scss'],
})
export class SvgCreatorComponent {
export class SvgCreatorComponent implements OnInit, OnDestroy {
userInput = '';
sanitizedSvg: SafeHtml = '';

submitting = false;
previewTitle = 'Preview';

@ViewChild('svgPreview', { static: true })
svgPreview?: ElementRef<HTMLElement>;
Expand All @@ -34,22 +35,46 @@ export class SvgCreatorComponent {
this.clearFile();
}

ngOnInit() {
this.previewTitle = 'Preview';
this.creator.svgCode$.subscribe(rs => {
if (rs.done) {
// SSE 结束后统一解析
const svgCodes = this.creator.extractSVGCode(rs.content);
if (svgCodes.length) {
this.svgCode = svgCodes.join('\n');
}
} else {
// 临时显示输出
this._svgCode = rs.content
.replace(/[\s\S]*(?=(<svg))/gi, '')
.replace(/(?<=(\/svg>))[\s\S]*/g, '');
}
});
}

ngOnDestroy() {
this.creator.svgCode$.unsubscribe();
}

async submit() {
this.submitting = true;
this.previewTitle = 'Generating...';
let svgCodes = this.creator.extractSVGCode(this.userInput);
if (svgCodes.length) {
this.svgCode = svgCodes.join('\n');
}
const nonSVG = this.creator.extractNonSVGCode(this.userInput);
// 仅在有其他内容的时候,调用 API
if (nonSVG) {
const outputMsg = await this.creator.analyzeInput(this.userInput, this.svgCode);
svgCodes = this.creator.extractSVGCode(outputMsg);
if (svgCodes.length) {
this.svgCode = svgCodes.join('\n');
}
await this.creator.analyzeInputStreaming(this.userInput, this.svgCode);
}
this.submitting = false;
this.previewTitle = 'Preview';
}

stop() {
this.creator.stopAnalyze();
}

clear() {
Expand All @@ -59,10 +84,8 @@ export class SvgCreatorComponent {
}

exportSVG() {
console.log('exportSVG');
const svg = this.svgPreview?.nativeElement.querySelector('svg');
if (!svg) {
console.log('null svg');
return;
}
const serializer = new XMLSerializer();
Expand Down
93 changes: 72 additions & 21 deletions src/app/services/chatgpt-api.service.ts
Original file line number Diff line number Diff line change
@@ -1,29 +1,28 @@
import { Injectable } from '@angular/core';
import { HttpClient } from '@angular/common/http';
import { lastValueFrom } from 'rxjs';
import { Subject } from 'rxjs';
import { ConfigService } from './config.service';
import { Message } from '../types';
import { Message, StreamingResult } from '../types';
import { MessageService } from 'primeng/api';

@Injectable({
providedIn: 'root',
})
export class ChatGPTAPIService {
constructor(
private config: ConfigService,
private http: HttpClient,
private messageSvc: MessageService
) {}
constructor(private config: ConfigService, private messageSvc: MessageService) {}

async doChat(messages: Message[]): Promise<string> {
async doChatStream(
messages: Message[],
sub: Subject<StreamingResult>,
signal: AbortSignal
): Promise<boolean> {
const apiKey = await this.config.get('openAPIKey');
if (!apiKey) {
this.messageSvc.add({
severity: 'warn',
summary: 'Missing API Key',
detail: 'Please update your openAPIKey in Settings.',
});
return '';
return false;
}
const apiHost = await this.config.get('openAPIHost');
const currentModel = await this.config.get('currentModel');
Expand All @@ -33,23 +32,75 @@ export class ChatGPTAPIService {
messages: messages,
max_tokens: 2048,
temperature: 0.5,
};
const headers = {
'Content-Type': 'application/json',
Authorization: `Bearer ${apiKey}`,
stream: true,
};
try {
const rsp$ = this.http.post<any>(apiURL, data, { headers });
const rsp = await lastValueFrom(rsp$);
return rsp.choices[0].message.content.trim();
} catch (err) {
console.error(err);
const rsp = await fetch(apiURL, {
method: 'POST',
headers: {
'Content-Type': 'application/json',
Authorization: `Bearer ${apiKey}`,
},
body: JSON.stringify(data),
signal,
});
if (!rsp.ok) {
this.messageSvc.add({
severity: 'error',
summary: 'Error',
detail: 'Failed call openai api',
});
return false;
}
if (!rsp.body) {
return true;
}
await this.parseSSE(rsp.body, sub);
} catch (err: any) {
this.messageSvc.add({
severity: 'error',
summary: 'Error',
detail: 'Failed call openai api',
detail: err.message || 'Failed call openai api',
});
return '';
return false;
}
return true;
}

async parseSSE(rs: ReadableStream, sub: Subject<StreamingResult>): Promise<void> {
let readData;
let partialMessage = '';
let resultContent = '';
const reader = rs.getReader();
const decoder = new TextDecoder('utf-8');
do {
readData = await reader.read();
if (!readData.done) {
partialMessage += decoder.decode(readData.value, { stream: true });
// Split the message by new lines
const lines = partialMessage.split('\n');
// If the last line is not empty, it means that it's an incomplete message
if (lines[lines.length - 1] !== '') {
partialMessage = lines.pop() ?? '';
} else {
partialMessage = '';
}
// Process each line
for (const line of lines) {
if (line.startsWith('data: ')) {
const jsonString = line.slice('data: '.length);
// Check if the jsonString is not "[DONE]"
if (jsonString !== '[DONE]') {
const payload = JSON.parse(jsonString);
resultContent += payload?.choices[0]?.delta?.content ?? '';
sub.next({ content: resultContent, done: false });
} else {
sub.next({ content: resultContent, done: true });
return;
}
}
}
}
} while (!readData.done);
}
}
37 changes: 26 additions & 11 deletions src/app/services/creator.service.ts
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import { Injectable } from '@angular/core';
import { Message } from '../types';
import { Subject } from 'rxjs';
import { Message, StreamingResult } from '../types';
import { ChatGPTAPIService } from './chatgpt-api.service';

const SVG_REG_EXP: RegExp = /<svg[\s\S]*?<\/svg>/gi;
Expand All @@ -8,7 +9,9 @@ const SVG_REG_EXP: RegExp = /<svg[\s\S]*?<\/svg>/gi;
providedIn: 'root',
})
export class CreatorService {
readonly systemPrompts: Message[] = [
public svgCode$ = new Subject<StreamingResult>();

private readonly systemPrompts: Message[] = [
{
role: 'user',
content:
Expand All @@ -19,22 +22,34 @@ export class CreatorService {
content: 'ok',
},
];
private abortCtl?: AbortController;

constructor(private api: ChatGPTAPIService) {}

async analyzeInput(msg: string, originalSVGCode: string) {
async analyzeInputStreaming(msg: string, originalSVGCode: string) {
let headMsg = '';
if (originalSVGCode) {
headMsg += `Given the original SVG: ${originalSVGCode}\n`;
}
const res = await this.api.doChat([
...this.systemPrompts,
{
role: 'user',
content: headMsg + msg,
},
]);
return res;
this.abortCtl = new AbortController();
const rsp = await this.api.doChatStream(
[
...this.systemPrompts,
{
role: 'user',
content: headMsg + msg,
},
],
this.svgCode$,
this.abortCtl.signal
);
if (!rsp) {
return;
}
}

stopAnalyze() {
this.abortCtl?.abort();
}

// 提取 SVG 代码
Expand Down
2 changes: 2 additions & 0 deletions src/app/types.ts
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
export type Message = { role: string; content: string };

export type StreamingResult = { content: string; done: boolean };

export const enum GPTModels {
GPT_3_5_TURBO = 'gpt-3.5-turbo',
GPT_4 = 'gpt-4',
Expand Down
2 changes: 1 addition & 1 deletion src/manifest.json
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
{
"name": "ChatGPT SVG Creator",
"version": "1.0.1",
"version": "1.1.0",
"description": "An extension uses ChatGPT to create and preview SVG",
"manifest_version": 3,
"action": {},
Expand Down

0 comments on commit 36d4d6d

Please sign in to comment.