Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

♻️ refactor: refactor agent runtime #6284

Draft
wants to merge 1 commit into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
340 changes: 11 additions & 329 deletions src/libs/agent-runtime/AgentRuntime.ts
Original file line number Diff line number Diff line change
Expand Up @@ -3,44 +3,10 @@ import { ClientOptions } from 'openai';
import type { TracePayload } from '@/const/trace';

import { LobeRuntimeAI } from './BaseAI';
import { LobeAi21AI } from './ai21';
import { LobeAi360AI } from './ai360';
import { LobeAnthropicAI } from './anthropic';
import { LobeAzureOpenAI } from './azureOpenai';
import { LobeAzureAI } from './azureai';
import { LobeBaichuanAI } from './baichuan';
import { LobeBedrockAI, LobeBedrockAIParams } from './bedrock';
import { LobeCloudflareAI, LobeCloudflareParams } from './cloudflare';
import { LobeDeepSeekAI } from './deepseek';
import { LobeFireworksAI } from './fireworksai';
import { LobeGiteeAI } from './giteeai';
import { LobeGithubAI } from './github';
import { LobeGoogleAI } from './google';
import { LobeGroq } from './groq';
import { LobeHigressAI } from './higress';
import { LobeHuggingFaceAI } from './huggingface';
import { LobeHunyuanAI } from './hunyuan';
import { LobeInternLMAI } from './internlm';
import { LobeJinaAI } from './jina';
import { LobeLMStudioAI } from './lmstudio';
import { LobeMinimaxAI } from './minimax';
import { LobeMistralAI } from './mistral';
import { LobeMoonshotAI } from './moonshot';
import { LobeNovitaAI } from './novita';
import { LobeNvidiaAI } from './nvidia';
import { LobeOllamaAI } from './ollama';
import { LobeBedrockAIParams } from './bedrock';
import { LobeCloudflareParams } from './cloudflare';
import { LobeOpenAI } from './openai';
import { LobeOpenRouterAI } from './openrouter';
import { LobePerplexityAI } from './perplexity';
import { LobeQwenAI } from './qwen';
import { LobeSambaNovaAI } from './sambanova';
import { LobeSenseNovaAI } from './sensenova';
import { LobeSiliconCloudAI } from './siliconcloud';
import { LobeSparkAI } from './spark';
import { LobeStepfunAI } from './stepfun';
import { LobeTaichuAI } from './taichu';
import { LobeTencentCloudAI } from './tencentcloud';
import { LobeTogetherAI } from './togetherai';
import { providerRuntimeMap } from './runtimeMap';
import {
ChatCompetitionOptions,
ChatStreamPayload,
Expand All @@ -50,13 +16,6 @@ import {
TextToImagePayload,
TextToSpeechPayload,
} from './types';
import { LobeUpstageAI } from './upstage';
import { LobeVLLMAI } from './vllm';
import { LobeVolcengineAI } from './volcengine';
import { LobeWenxinAI } from './wenxin';
import { LobeXAI } from './xai';
import { LobeZeroOneAI } from './zeroone';
import { LobeZhipuAI } from './zhipu';

export interface AgentChatOptions {
enableTrace?: boolean;
Expand Down Expand Up @@ -127,9 +86,7 @@ class AgentRuntime {
* Try to initialize the runtime with the provider and the options.
* @example
* ```ts
* const runtime = await AgentRuntime.initializeWithProviderOptions(provider, {
* [provider]: {...options},
* })
* const runtime = await AgentRuntime.initializeWithProviderOptions(provider, options)
* ```
* **Note**: If you try to get a AgentRuntime instance from client or server,
* you should use the methods to get the runtime instance at first.
Expand All @@ -138,290 +95,15 @@ class AgentRuntime {
*/
static async initializeWithProviderOptions(
provider: string,
params: Partial<{
ai21: Partial<ClientOptions>;
ai360: Partial<ClientOptions>;
anthropic: Partial<ClientOptions>;
azure: { apiKey?: string; apiVersion?: string; baseURL?: string };
azureai: { apiKey?: string; apiVersion?: string; baseURL?: string };
baichuan: Partial<ClientOptions>;
bedrock: Partial<LobeBedrockAIParams>;
cloudflare: Partial<LobeCloudflareParams>;
deepseek: Partial<ClientOptions>;
doubao: Partial<ClientOptions>;
fireworksai: Partial<ClientOptions>;
giteeai: Partial<ClientOptions>;
github: Partial<ClientOptions>;
google: { apiKey?: string; baseURL?: string };
groq: Partial<ClientOptions>;
higress: Partial<ClientOptions>;
huggingface: { apiKey?: string; baseURL?: string };
hunyuan: Partial<ClientOptions>;
internlm: Partial<ClientOptions>;
jina: Partial<ClientOptions>;
lmstudio: Partial<ClientOptions>;
minimax: Partial<ClientOptions>;
mistral: Partial<ClientOptions>;
moonshot: Partial<ClientOptions>;
novita: Partial<ClientOptions>;
nvidia: Partial<ClientOptions>;
ollama: Partial<ClientOptions>;
openai: Partial<ClientOptions>;
openrouter: Partial<ClientOptions>;
perplexity: Partial<ClientOptions>;
qwen: Partial<ClientOptions>;
sambanova: Partial<ClientOptions>;
sensenova: Partial<ClientOptions>;
siliconcloud: Partial<ClientOptions>;
spark: Partial<ClientOptions>;
stepfun: Partial<ClientOptions>;
taichu: Partial<ClientOptions>;
tencentcloud: Partial<ClientOptions>;
togetherai: Partial<ClientOptions>;
upstage: Partial<ClientOptions>;
vllm: Partial<ClientOptions>;
volcengine: Partial<ClientOptions>;
wenxin: Partial<ClientOptions>;
xai: Partial<ClientOptions>;
zeroone: Partial<ClientOptions>;
zhipu: Partial<ClientOptions>;
}>,
params: Partial<
ClientOptions &
LobeBedrockAIParams &
LobeCloudflareParams & { apiKey?: string; apiVersion?: string; baseURL?: string }
>,
) {
let runtimeModel: LobeRuntimeAI;
const providerAI = providerRuntimeMap[provider as ModelProvider] ?? LobeOpenAI;
const runtimeModel: LobeRuntimeAI = new providerAI(params);

switch (provider) {
default:
case ModelProvider.OpenAI: {
// Will use the openai as default provider
runtimeModel = new LobeOpenAI(params.openai ?? (params as any)[provider]);
break;
}

case ModelProvider.Azure: {
runtimeModel = new LobeAzureOpenAI(
params.azure?.baseURL,
params.azure?.apiKey,
params.azure?.apiVersion,
);
break;
}

case ModelProvider.AzureAI: {
runtimeModel = new LobeAzureAI(params.azureai);
break;
}

case ModelProvider.ZhiPu: {
runtimeModel = new LobeZhipuAI(params.zhipu);
break;
}

case ModelProvider.Google: {
runtimeModel = new LobeGoogleAI(params.google);
break;
}

case ModelProvider.Moonshot: {
runtimeModel = new LobeMoonshotAI(params.moonshot);
break;
}

case ModelProvider.Bedrock: {
runtimeModel = new LobeBedrockAI(params.bedrock);
break;
}

case ModelProvider.LMStudio: {
runtimeModel = new LobeLMStudioAI(params.lmstudio);
break;
}

case ModelProvider.Ollama: {
runtimeModel = new LobeOllamaAI(params.ollama);
break;
}

case ModelProvider.VLLM: {
runtimeModel = new LobeVLLMAI(params.vllm);
break;
}

case ModelProvider.Perplexity: {
runtimeModel = new LobePerplexityAI(params.perplexity);
break;
}

case ModelProvider.Anthropic: {
runtimeModel = new LobeAnthropicAI(params.anthropic);
break;
}

case ModelProvider.DeepSeek: {
runtimeModel = new LobeDeepSeekAI(params.deepseek);
break;
}

case ModelProvider.HuggingFace: {
runtimeModel = new LobeHuggingFaceAI(params.huggingface);
break;
}

case ModelProvider.Minimax: {
runtimeModel = new LobeMinimaxAI(params.minimax);
break;
}

case ModelProvider.Mistral: {
runtimeModel = new LobeMistralAI(params.mistral);
break;
}

case ModelProvider.Groq: {
runtimeModel = new LobeGroq(params.groq);
break;
}

case ModelProvider.Github: {
runtimeModel = new LobeGithubAI(params.github);
break;
}

case ModelProvider.OpenRouter: {
runtimeModel = new LobeOpenRouterAI(params.openrouter);
break;
}

case ModelProvider.TogetherAI: {
runtimeModel = new LobeTogetherAI(params.togetherai);
break;
}

case ModelProvider.FireworksAI: {
runtimeModel = new LobeFireworksAI(params.fireworksai);
break;
}

case ModelProvider.ZeroOne: {
runtimeModel = new LobeZeroOneAI(params.zeroone);
break;
}

case ModelProvider.Qwen: {
runtimeModel = new LobeQwenAI(params.qwen);
break;
}

case ModelProvider.Stepfun: {
runtimeModel = new LobeStepfunAI(params.stepfun);
break;
}

case ModelProvider.Novita: {
runtimeModel = new LobeNovitaAI(params.novita);
break;
}

case ModelProvider.Nvidia: {
runtimeModel = new LobeNvidiaAI(params.nvidia);
break;
}

case ModelProvider.Baichuan: {
runtimeModel = new LobeBaichuanAI(params.baichuan);
break;
}

case ModelProvider.Taichu: {
runtimeModel = new LobeTaichuAI(params.taichu);
break;
}

case ModelProvider.Ai360: {
runtimeModel = new LobeAi360AI(params.ai360);
break;
}

case ModelProvider.SiliconCloud: {
runtimeModel = new LobeSiliconCloudAI(params.siliconcloud);
break;
}

case ModelProvider.GiteeAI: {
runtimeModel = new LobeGiteeAI(params.giteeai);
break;
}

case ModelProvider.Upstage: {
runtimeModel = new LobeUpstageAI(params.upstage);
break;
}

case ModelProvider.Spark: {
runtimeModel = new LobeSparkAI(params.spark);
break;
}

case ModelProvider.Ai21: {
runtimeModel = new LobeAi21AI(params.ai21);
break;
}

case ModelProvider.Hunyuan: {
runtimeModel = new LobeHunyuanAI(params.hunyuan);
break;
}

case ModelProvider.SenseNova: {
runtimeModel = new LobeSenseNovaAI(params.sensenova);
break;
}

case ModelProvider.XAI: {
runtimeModel = new LobeXAI(params.xai);
break;
}

case ModelProvider.Jina: {
runtimeModel = new LobeJinaAI(params.jina);
break;
}

case ModelProvider.SambaNova: {
runtimeModel = new LobeSambaNovaAI(params.sambanova);
break;
}

case ModelProvider.Cloudflare: {
runtimeModel = new LobeCloudflareAI(params.cloudflare);
break;
}

case ModelProvider.InternLM: {
runtimeModel = new LobeInternLMAI(params.internlm);
break;
}

case ModelProvider.Higress: {
runtimeModel = new LobeHigressAI(params.higress);
break;
}

case ModelProvider.TencentCloud: {
runtimeModel = new LobeTencentCloudAI(params[provider]);
break;
}

case ModelProvider.Volcengine:
case ModelProvider.Doubao: {
runtimeModel = new LobeVolcengineAI(params.volcengine || params.doubao);
break;
}

case ModelProvider.Wenxin: {
runtimeModel = new LobeWenxinAI(params.wenxin);
break;
}
}
return new AgentRuntime(runtimeModel);
}
}
Expand Down
12 changes: 6 additions & 6 deletions src/libs/agent-runtime/azureOpenai/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -13,18 +13,18 @@ import { OpenAIStream } from '../utils/streams';
export class LobeAzureOpenAI implements LobeRuntimeAI {
client: AzureOpenAI;

constructor(endpoint?: string, apikey?: string, apiVersion?: string) {
if (!apikey || !endpoint)
constructor(params: { apiKey?: string; apiVersion?: string, baseURL?: string; } = {}) {
if (!params.apiKey || !params.baseURL)
throw AgentRuntimeError.createError(AgentRuntimeErrorType.InvalidProviderAPIKey);

this.client = new AzureOpenAI({
apiKey: apikey,
apiVersion,
apiKey: params.apiKey,
apiVersion: params.apiVersion,
dangerouslyAllowBrowser: true,
endpoint,
endpoint: params.baseURL,
});

this.baseURL = endpoint;
this.baseURL = params.baseURL;
}

baseURL: string;
Expand Down
Loading
Loading