Skip to content

Commit

Permalink
feat: provides model properties to model service while running an AI …
Browse files Browse the repository at this point in the history
…App (#904)

* feat: provides model properties to model service while running an AI App

Fixes #903

Signed-off-by: Jeff MAURY <[email protected]>

* refactor: refactor following @alex7083 review

Signed-off-by: Jeff MAURY <[email protected]>

---------

Signed-off-by: Jeff MAURY <[email protected]>
  • Loading branch information
jeffmaury authored Apr 17, 2024
1 parent f95a8cb commit e2d5a2e
Show file tree
Hide file tree
Showing 4 changed files with 45 additions and 18 deletions.
27 changes: 22 additions & 5 deletions packages/backend/src/managers/applicationManager.spec.ts
Original file line number Diff line number Diff line change
Expand Up @@ -819,9 +819,11 @@ describe('createApplicationPod', () => {
vi.spyOn(manager, 'createPod').mockResolvedValue(pod);
const createAndAddContainersToPodMock = vi
.spyOn(manager, 'createAndAddContainersToPod')
.mockImplementation((_pod: ApplicationPodInfo, _images: ImageInfo[], _modelPath: string) => Promise.resolve([]));
.mockImplementation((_pod: ApplicationPodInfo, _images: ImageInfo[], _modelInfo: ModelInfo, _modelPath: string) =>
Promise.resolve([]),
);
await manager.createApplicationPod({ id: 'recipe-id' } as Recipe, { id: 'model-id' } as ModelInfo, images, 'path');
expect(createAndAddContainersToPodMock).toBeCalledWith(pod, images, 'path');
expect(createAndAddContainersToPodMock).toBeCalledWith(pod, images, { id: 'model-id' }, 'path');
expect(mocks.updateTaskMock).toBeCalledWith({
id: expect.any(String),
state: 'success',
Expand Down Expand Up @@ -935,17 +937,17 @@ describe('createAndAddContainersToPod', () => {
modelService: true,
ports: ['8085'],
};
test('check that containers are correctly created', async () => {
async function checkContainers(modelInfo: ModelInfo, extraEnvs: string[]) {
mocks.createContainerMock.mockResolvedValue({
id: 'container-1',
});
vi.spyOn(podman, 'isQEMUMachine').mockResolvedValue(false);
vi.spyOn(manager, 'getRandomName').mockReturnValue('name');
await manager.createAndAddContainersToPod(pod, [imageInfo1, imageInfo2], 'path');
await manager.createAndAddContainersToPod(pod, [imageInfo1, imageInfo2], modelInfo, 'path');
expect(mocks.createContainerMock).toHaveBeenNthCalledWith(1, 'engine', {
Image: 'id',
Detach: true,
Env: ['MODEL_ENDPOINT=http://localhost:8085'],
Env: ['MODEL_ENDPOINT=http://localhost:8085', ...extraEnvs],
start: false,
name: 'name',
pod: 'id',
Expand Down Expand Up @@ -980,6 +982,21 @@ describe('createAndAddContainersToPod', () => {
Timeout: 2000000000,
},
});
}

test('check that containers are correctly created with no model properties', async () => {
await checkContainers({} as ModelInfo, []);
});

test('check that containers are correctly created with model properties', async () => {
await checkContainers(
{
properties: {
modelName: 'myModel',
},
} as unknown as ModelInfo,
['MODEL_MODEL_NAME=myModel'],
);
});
});

Expand Down
5 changes: 4 additions & 1 deletion packages/backend/src/managers/applicationManager.ts
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ import type { TaskRegistry } from '../registries/TaskRegistry';
import { Publisher } from '../utils/Publisher';
import { isQEMUMachine } from '../utils/podman';
import { SECOND } from '../utils/inferenceUtils';
import { getModelPropertiesForEnvironment } from '../utils/modelsUtils';

export const LABEL_MODEL_ID = 'ai-lab-model-id';
export const LABEL_MODEL_PORTS = 'ai-lab-model-ports';
Expand Down Expand Up @@ -251,7 +252,7 @@ export class ApplicationManager extends Publisher<ApplicationState[]> implements

let attachedContainers: ContainerAttachedInfo[];
try {
attachedContainers = await this.createAndAddContainersToPod(podInfo, images, modelPath);
attachedContainers = await this.createAndAddContainersToPod(podInfo, images, model, modelPath);
task.state = 'success';
} catch (e) {
console.error(`error when creating pod ${podInfo.Id}`, e);
Expand All @@ -269,6 +270,7 @@ export class ApplicationManager extends Publisher<ApplicationState[]> implements
async createAndAddContainersToPod(
podInfo: ApplicationPodInfo,
images: ImageInfo[],
modelInfo: ModelInfo,
modelPath: string,
): Promise<ContainerAttachedInfo[]> {
const containers: ContainerAttachedInfo[] = [];
Expand Down Expand Up @@ -299,6 +301,7 @@ export class ApplicationManager extends Publisher<ApplicationState[]> implements
if (modelService && modelService.ports.length > 0) {
const endPoint = `http://localhost:${modelService.ports[0]}`;
envs = [`MODEL_ENDPOINT=${endPoint}`];
envs.push(...getModelPropertiesForEnvironment(modelInfo));
}
}
if (image.ports.length > 0) {
Expand Down
18 changes: 6 additions & 12 deletions packages/backend/src/utils/inferenceUtils.ts
Original file line number Diff line number Diff line change
Expand Up @@ -16,18 +16,19 @@
* SPDX-License-Identifier: Apache-2.0
***********************************************************************/
import {
containerEngine,
provider,
type ContainerCreateOptions,
containerEngine,
type ContainerProviderConnection,
type PullEvent,
type ProviderContainerConnection,
type ImageInfo,
type ListImagesOptions,
provider,
type ProviderContainerConnection,
type PullEvent,
} from '@podman-desktop/api';
import type { CreationInferenceServerOptions, InferenceServerConfig } from '@shared/src/models/InferenceServerConfig';
import { DISABLE_SELINUX_LABEL_SECURITY_OPTION } from './utils';
import { getFreeRandomPort } from './ports';
import { getModelPropertiesForEnvironment } from './modelsUtils';

export const SECOND: number = 1_000_000_000;

Expand Down Expand Up @@ -116,14 +117,7 @@ export function generateContainerCreateOptions(
}

const envs: string[] = [`MODEL_PATH=/models/${modelInfo.file.file}`, 'HOST=0.0.0.0', 'PORT=8000'];
if (modelInfo.properties) {
envs.push(
...Object.entries(modelInfo.properties).map(([key, value]) => {
const formattedKey = key.replace(/[A-Z]/g, m => `_${m}`).toUpperCase();
return `MODEL_${formattedKey}=${value}`;
}),
);
}
envs.push(...getModelPropertiesForEnvironment(modelInfo));

return {
Image: imageInfo.Id,
Expand Down
13 changes: 13 additions & 0 deletions packages/backend/src/utils/modelsUtils.ts
Original file line number Diff line number Diff line change
Expand Up @@ -71,3 +71,16 @@ export async function deleteRemoteModel(machine: string, modelInfo: ModelInfo):
console.error('Something went wrong while trying to stat remote model path', err);
}
}

export function getModelPropertiesForEnvironment(modelInfo: ModelInfo): string[] {
const envs: string[] = [];
if (modelInfo.properties) {
envs.push(
...Object.entries(modelInfo.properties).map(([key, value]) => {
const formattedKey = key.replace(/[A-Z]/g, m => `_${m}`).toUpperCase();
return `MODEL_${formattedKey}=${value}`;
}),
);
}
return envs;
}

0 comments on commit e2d5a2e

Please sign in to comment.