From e2d5a2e6b5839337575fbd622ef271ca1d46be1b Mon Sep 17 00:00:00 2001 From: Jeff MAURY Date: Wed, 17 Apr 2024 08:44:10 +0200 Subject: [PATCH] feat: provides model properties to model service while running an AI App (#904) * feat: provides model properties to model service while running an AI App Fixes #903 Signed-off-by: Jeff MAURY * refactor: refactor following @alex7083 review Signed-off-by: Jeff MAURY --------- Signed-off-by: Jeff MAURY --- .../src/managers/applicationManager.spec.ts | 27 +++++++++++++++---- .../src/managers/applicationManager.ts | 5 +++- packages/backend/src/utils/inferenceUtils.ts | 18 +++++-------- packages/backend/src/utils/modelsUtils.ts | 13 +++++++++ 4 files changed, 45 insertions(+), 18 deletions(-) diff --git a/packages/backend/src/managers/applicationManager.spec.ts b/packages/backend/src/managers/applicationManager.spec.ts index e0a292d98..33d5a82e1 100644 --- a/packages/backend/src/managers/applicationManager.spec.ts +++ b/packages/backend/src/managers/applicationManager.spec.ts @@ -819,9 +819,11 @@ describe('createApplicationPod', () => { vi.spyOn(manager, 'createPod').mockResolvedValue(pod); const createAndAddContainersToPodMock = vi .spyOn(manager, 'createAndAddContainersToPod') - .mockImplementation((_pod: ApplicationPodInfo, _images: ImageInfo[], _modelPath: string) => Promise.resolve([])); + .mockImplementation((_pod: ApplicationPodInfo, _images: ImageInfo[], _modelInfo: ModelInfo, _modelPath: string) => + Promise.resolve([]), + ); await manager.createApplicationPod({ id: 'recipe-id' } as Recipe, { id: 'model-id' } as ModelInfo, images, 'path'); - expect(createAndAddContainersToPodMock).toBeCalledWith(pod, images, 'path'); + expect(createAndAddContainersToPodMock).toBeCalledWith(pod, images, { id: 'model-id' }, 'path'); expect(mocks.updateTaskMock).toBeCalledWith({ id: expect.any(String), state: 'success', @@ -935,17 +937,17 @@ describe('createAndAddContainersToPod', () => { modelService: true, ports: ['8085'], }; - test('check that containers are correctly created', async () => { + async function checkContainers(modelInfo: ModelInfo, extraEnvs: string[]) { mocks.createContainerMock.mockResolvedValue({ id: 'container-1', }); vi.spyOn(podman, 'isQEMUMachine').mockResolvedValue(false); vi.spyOn(manager, 'getRandomName').mockReturnValue('name'); - await manager.createAndAddContainersToPod(pod, [imageInfo1, imageInfo2], 'path'); + await manager.createAndAddContainersToPod(pod, [imageInfo1, imageInfo2], modelInfo, 'path'); expect(mocks.createContainerMock).toHaveBeenNthCalledWith(1, 'engine', { Image: 'id', Detach: true, - Env: ['MODEL_ENDPOINT=http://localhost:8085'], + Env: ['MODEL_ENDPOINT=http://localhost:8085', ...extraEnvs], start: false, name: 'name', pod: 'id', @@ -980,6 +982,21 @@ describe('createAndAddContainersToPod', () => { Timeout: 2000000000, }, }); + } + + test('check that containers are correctly created with no model properties', async () => { + await checkContainers({} as ModelInfo, []); + }); + + test('check that containers are correctly created with model properties', async () => { + await checkContainers( + { + properties: { + modelName: 'myModel', + }, + } as unknown as ModelInfo, + ['MODEL_MODEL_NAME=myModel'], + ); }); }); diff --git a/packages/backend/src/managers/applicationManager.ts b/packages/backend/src/managers/applicationManager.ts index 7be0a9efe..1efff532f 100644 --- a/packages/backend/src/managers/applicationManager.ts +++ b/packages/backend/src/managers/applicationManager.ts @@ -48,6 +48,7 @@ import type { TaskRegistry } from '../registries/TaskRegistry'; import { Publisher } from '../utils/Publisher'; import { isQEMUMachine } from '../utils/podman'; import { SECOND } from '../utils/inferenceUtils'; +import { getModelPropertiesForEnvironment } from '../utils/modelsUtils'; export const LABEL_MODEL_ID = 'ai-lab-model-id'; export const LABEL_MODEL_PORTS = 'ai-lab-model-ports'; @@ -251,7 +252,7 @@ export class ApplicationManager extends Publisher implements let attachedContainers: ContainerAttachedInfo[]; try { - attachedContainers = await this.createAndAddContainersToPod(podInfo, images, modelPath); + attachedContainers = await this.createAndAddContainersToPod(podInfo, images, model, modelPath); task.state = 'success'; } catch (e) { console.error(`error when creating pod ${podInfo.Id}`, e); @@ -269,6 +270,7 @@ export class ApplicationManager extends Publisher implements async createAndAddContainersToPod( podInfo: ApplicationPodInfo, images: ImageInfo[], + modelInfo: ModelInfo, modelPath: string, ): Promise { const containers: ContainerAttachedInfo[] = []; @@ -299,6 +301,7 @@ export class ApplicationManager extends Publisher implements if (modelService && modelService.ports.length > 0) { const endPoint = `http://localhost:${modelService.ports[0]}`; envs = [`MODEL_ENDPOINT=${endPoint}`]; + envs.push(...getModelPropertiesForEnvironment(modelInfo)); } } if (image.ports.length > 0) { diff --git a/packages/backend/src/utils/inferenceUtils.ts b/packages/backend/src/utils/inferenceUtils.ts index 973086262..aee340a62 100644 --- a/packages/backend/src/utils/inferenceUtils.ts +++ b/packages/backend/src/utils/inferenceUtils.ts @@ -16,18 +16,19 @@ * SPDX-License-Identifier: Apache-2.0 ***********************************************************************/ import { - containerEngine, - provider, type ContainerCreateOptions, + containerEngine, type ContainerProviderConnection, - type PullEvent, - type ProviderContainerConnection, type ImageInfo, type ListImagesOptions, + provider, + type ProviderContainerConnection, + type PullEvent, } from '@podman-desktop/api'; import type { CreationInferenceServerOptions, InferenceServerConfig } from '@shared/src/models/InferenceServerConfig'; import { DISABLE_SELINUX_LABEL_SECURITY_OPTION } from './utils'; import { getFreeRandomPort } from './ports'; +import { getModelPropertiesForEnvironment } from './modelsUtils'; export const SECOND: number = 1_000_000_000; @@ -116,14 +117,7 @@ export function generateContainerCreateOptions( } const envs: string[] = [`MODEL_PATH=/models/${modelInfo.file.file}`, 'HOST=0.0.0.0', 'PORT=8000']; - if (modelInfo.properties) { - envs.push( - ...Object.entries(modelInfo.properties).map(([key, value]) => { - const formattedKey = key.replace(/[A-Z]/g, m => `_${m}`).toUpperCase(); - return `MODEL_${formattedKey}=${value}`; - }), - ); - } + envs.push(...getModelPropertiesForEnvironment(modelInfo)); return { Image: imageInfo.Id, diff --git a/packages/backend/src/utils/modelsUtils.ts b/packages/backend/src/utils/modelsUtils.ts index 117e448cb..5e9ed2f5b 100644 --- a/packages/backend/src/utils/modelsUtils.ts +++ b/packages/backend/src/utils/modelsUtils.ts @@ -71,3 +71,16 @@ export async function deleteRemoteModel(machine: string, modelInfo: ModelInfo): console.error('Something went wrong while trying to stat remote model path', err); } } + +export function getModelPropertiesForEnvironment(modelInfo: ModelInfo): string[] { + const envs: string[] = []; + if (modelInfo.properties) { + envs.push( + ...Object.entries(modelInfo.properties).map(([key, value]) => { + const formattedKey = key.replace(/[A-Z]/g, m => `_${m}`).toUpperCase(); + return `MODEL_${formattedKey}=${value}`; + }), + ); + } + return envs; +}