Skip to content

Commit

Permalink
Added support for multi model (#52)
Browse files Browse the repository at this point in the history
* Added support for multi model

* update model name

---------

Co-authored-by: Zuhwa Chooi <[email protected]>
  • Loading branch information
Zuhwa and Zuhwa Chooi authored Feb 12, 2025
1 parent d33e3df commit 389c3f6
Show file tree
Hide file tree
Showing 6 changed files with 60 additions and 43 deletions.
2 changes: 2 additions & 0 deletions examples/twitter.ts
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import {
GameAgent,
GameFunction,
GameWorker,
LLMModel,
} from "@virtuals-protocol/game";

const postTweetFunction = new GameFunction({
Expand Down Expand Up @@ -108,6 +109,7 @@ const agent = new GameAgent("API_KEY", {
goal: "Search and reply to tweets",
description: "A bot that searches for tweets and replies to them",
workers: [postTweetWorker],
llmModel: LLMModel.DeepSeek_R1, // Optional: Set the LLM model default (LLMModel.Llama_3_1_405B_Instruct)
// Optional: Get the agent state
getAgentState: async () => {
return {
Expand Down
83 changes: 42 additions & 41 deletions src/agent.ts
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import GameClient from "./api";
import GameClientV2 from "./apiV2";
import { ExecutableGameFunctionResponseJSON } from "./function";
import { ActionType, IGameClient } from "./interface/GameClient";
import { ActionType, IGameClient, LLMModel } from "./interface/GameClient";
import GameWorker from "./worker";

interface IGameAgent {
Expand All @@ -10,6 +10,7 @@ interface IGameAgent {
description: string;
workers: GameWorker[];
getAgentState?: () => Promise<Record<string, any>>;
llmModel?: LLMModel | string;
}

class GameAgent implements IGameAgent {
Expand All @@ -31,9 +32,11 @@ class GameAgent implements IGameAgent {
}

constructor(apiKey: string, options: IGameAgent) {
const llmModel = options.llmModel || LLMModel.Llama_3_1_405B_Instruct;

this.gameClient = apiKey.startsWith("apt-")
? new GameClientV2(apiKey)
: new GameClient(apiKey);
? new GameClientV2(apiKey, llmModel)
: new GameClient(apiKey, llmModel);
this.workerId = options.workers[0].id;

this.name = options.name;
Expand Down Expand Up @@ -176,44 +179,42 @@ class GameAgent implements IGameAgent {
}

save(): Record<string, any> {
return {
agentId: this.agentId,
mapId: this.mapId,
gameActionResult: this.gameActionResult,
}
}

async initWorkers() {
this.workers.forEach((worker) => {
worker.setAgentId(this.agentId || '')
worker.setLogger(this.log.bind(this))
worker.setGameClient(this.gameClient)
})
}

static async load(
apiKey: string,
name: string,
goal: string,
description: string,
savedState: Record<string, any>,
workers: GameWorker[]
): Promise<GameAgent> {
const agent = new GameAgent(apiKey, {
name: name,
goal: goal,
description: description,
workers,
})

agent.agentId = savedState.agentId
agent.mapId = savedState.mapId
agent.gameActionResult = savedState.gameActionResult


return agent
}
return {
agentId: this.agentId,
mapId: this.mapId,
gameActionResult: this.gameActionResult,
};
}

async initWorkers() {
this.workers.forEach((worker) => {
worker.setAgentId(this.agentId || "");
worker.setLogger(this.log.bind(this));
worker.setGameClient(this.gameClient);
});
}

static async load(
apiKey: string,
name: string,
goal: string,
description: string,
savedState: Record<string, any>,
workers: GameWorker[]
): Promise<GameAgent> {
const agent = new GameAgent(apiKey, {
name: name,
goal: goal,
description: description,
workers,
});

agent.agentId = savedState.agentId;
agent.mapId = savedState.mapId;
agent.gameActionResult = savedState.gameActionResult;

return agent;
}
}

export default GameAgent;
export default GameAgent;
4 changes: 3 additions & 1 deletion src/api.ts
Original file line number Diff line number Diff line change
Expand Up @@ -5,14 +5,15 @@ import {
GameAction,
GameAgent,
IGameClient,
LLMModel,
Map,
} from "./interface/GameClient";

class GameClient implements IGameClient {
public client: Axios | null = null;
private runnerUrl = "https://game.virtuals.io";

constructor(private apiKey: string) {}
constructor(private apiKey: string, private llmModel: LLMModel | string) {}

async init() {
const accessToken = await this.getAccessToken();
Expand All @@ -22,6 +23,7 @@ class GameClient implements IGameClient {
headers: {
"Content-Type": "application/json",
Authorization: `Bearer ${accessToken}`,
model_name: this.llmModel,
},
});
}
Expand Down
4 changes: 3 additions & 1 deletion src/apiV2.ts
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import {
GameAction,
GameAgent,
IGameClient,
LLMModel,
Map,
} from "./interface/GameClient";
import GameWorker from "./worker";
Expand All @@ -12,12 +13,13 @@ class GameClientV2 implements IGameClient {
public client: Axios;
private baseUrl = "https://sdk.game.virtuals.io/v2";

constructor(private apiKey: string) {
constructor(private apiKey: string, private llmModel: LLMModel | string) {
this.client = axios.create({
baseURL: this.baseUrl,
headers: {
"Content-Type": "application/json",
"x-api-key": this.apiKey,
model_name: this.llmModel,
},
});
}
Expand Down
2 changes: 2 additions & 0 deletions src/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,13 @@ import GameFunction, {
ExecutableGameFunctionResponse,
ExecutableGameFunctionStatus,
} from "./function";
import { LLMModel } from "./interface/GameClient";

export {
GameAgent,
GameFunction,
GameWorker,
ExecutableGameFunctionResponse,
ExecutableGameFunctionStatus,
LLMModel,
};
8 changes: 8 additions & 0 deletions src/interface/GameClient.ts
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,14 @@ export interface GameAction {
agent_state?: Record<string, any>;
}

export enum LLMModel {
Llama_3_1_405B_Instruct = "Llama-3.1-405B-Instruct",
Llama_3_3_70B_Instruct = "Llama-3.3-70B-Instruct",
DeepSeek_R1 = "DeepSeek-R1",
DeepSeek_V3 = "DeepSeek-V3",
Qwen_2_5_72B_Instruct = "Qwen-2.5-72B-Instruct",
}

export interface IGameClient {
client: Axios | null;
createMap(workers: GameWorker[]): Promise<Map>;
Expand Down

0 comments on commit 389c3f6

Please sign in to comment.