Skip to content

Commit

Permalink
feat: test InstructLab in a container
Browse files Browse the repository at this point in the history
Signed-off-by: Jeff MAURY <[email protected]>
  • Loading branch information
jeffmaury committed Dec 23, 2024
1 parent 472de48 commit 0df95a8
Show file tree
Hide file tree
Showing 15 changed files with 598 additions and 19 deletions.
3 changes: 3 additions & 0 deletions packages/backend/src/assets/instructlab-images.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
{
"default": "docker.io/redhat/instructlab@sha256:c6b2ecb4547b1f43b5539ee99bdbf5c9ae40599fabe1c740622295d9721b91c4"
}
14 changes: 14 additions & 0 deletions packages/backend/src/instructlab-api-impl.ts
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,25 @@
import type { InstructlabAPI } from '@shared/src/InstructlabAPI';
import type { InstructlabManager } from './managers/instructlab/instructlabManager';
import type { InstructlabSession } from '@shared/src/models/instructlab/IInstructlabSession';
import type { InstructlabContainerConfiguration } from '@shared/src/models/instructlab/IInstructlabContainerConfiguration';
import { navigation } from '@podman-desktop/api';

export class InstructlabApiImpl implements InstructlabAPI {
constructor(private instructlabManager: InstructlabManager) {}

async getIsntructlabSessions(): Promise<InstructlabSession[]> {
return this.instructlabManager.getSessions();
}

requestCreateInstructlabContainer(config: InstructlabContainerConfiguration): Promise<string> {
return this.instructlabManager.requestCreateInstructlabContainer(config);
}

routeToInstructLabContainerTerminal(containerId: string): Promise<void> {
return navigation.navigateToContainerTerminal(containerId);
}

getInstructlabContainerId(): Promise<string | undefined> {
return this.instructlabManager.getInstructLabContainer();
}
}
4 changes: 2 additions & 2 deletions packages/backend/src/managers/inference/inferenceManager.ts
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ import type { InferenceServer, InferenceServerStatus, InferenceType } from '@sha
import type { PodmanConnection, PodmanConnectionEvent } from '../podmanConnection';
import { containerEngine, Disposable } from '@podman-desktop/api';
import type { ContainerInfo, TelemetryLogger, Webview, ContainerProviderConnection } from '@podman-desktop/api';
import type { ContainerRegistry, ContainerStart } from '../../registries/ContainerRegistry';
import type { ContainerRegistry, ContainerEvent } from '../../registries/ContainerRegistry';
import { getInferenceType, isTransitioning, LABEL_INFERENCE_SERVER } from '../../utils/inferenceUtils';
import { Publisher } from '../../utils/Publisher';
import { Messages } from '@shared/Messages';
Expand Down Expand Up @@ -318,7 +318,7 @@ export class InferenceManager extends Publisher<InferenceServer[]> implements Di
* Listener for container start events
* @param event the event containing the id of the container
*/
private watchContainerStart(event: ContainerStart): void {
private watchContainerStart(event: ContainerEvent): void {
// We might have a start event for an inference server we already know about
if (this.#servers.has(event.id)) return;

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
/**********************************************************************
* Copyright (C) 2024 Red Hat, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
* SPDX-License-Identifier: Apache-2.0
***********************************************************************/
import { TaskRegistry } from '../../registries/TaskRegistry';
import { beforeAll, beforeEach, expect, test, vi } from 'vitest';
import type { ContainerInfo, Webview } from '@podman-desktop/api';
import { containerEngine, EventEmitter } from '@podman-desktop/api';
import type { PodmanConnection } from '../podmanConnection';
import { INSTRUCTLAB_CONTAINER_LABEL, InstructlabManager } from './instructlabManager';
import { ContainerRegistry } from '../../registries/ContainerRegistry';
import { TestEventEmitter } from '../../tests/utils';

vi.mock('@podman-desktop/api', () => {
return {
EventEmitter: vi.fn(),
containerEngine: {
listContainers: vi.fn(),
onEvent: vi.fn(),
},
};
});

const taskRegistry = new TaskRegistry({ postMessage: vi.fn().mockReturnValue(undefined) } as unknown as Webview);

const podmanConnection: PodmanConnection = {
onPodmanConnectionEvent: vi.fn(),
} as unknown as PodmanConnection;

let instructlabManager: InstructlabManager;

beforeAll(() => {
vi.mocked(EventEmitter).mockImplementation(() => new TestEventEmitter() as unknown as EventEmitter<unknown>);
});

beforeEach(() => {
const containerRegistry = new ContainerRegistry();
containerRegistry.init();
instructlabManager = new InstructlabManager('', taskRegistry, podmanConnection, containerRegistry);
});

test('getInstructLabContainer should return undefined if no containers', async () => {
vi.mocked(containerEngine.listContainers).mockResolvedValue([]);
const containerId = await instructlabManager.getInstructLabContainer();
expect(containerId).toBeUndefined();
});

test('getInstructLabContainer should return undefined if no instructlab container', async () => {
vi.mocked(containerEngine.listContainers).mockResolvedValue([{ Id: 'dummyId' } as unknown as ContainerInfo]);
const containerId = await instructlabManager.getInstructLabContainer();
expect(containerId).toBeUndefined();
});

test('getInstructLabContainer should return id if instructlab container', async () => {
vi.mocked(containerEngine.listContainers).mockResolvedValue([
{
Id: 'dummyId',
State: 'running',
Labels: { [`${INSTRUCTLAB_CONTAINER_LABEL}`]: 'dummyLabel' },
} as unknown as ContainerInfo,
]);
const containerId = await instructlabManager.getInstructLabContainer();
expect(containerId).toBe('dummyId');
});
201 changes: 201 additions & 0 deletions packages/backend/src/managers/instructlab/instructlabManager.ts
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,63 @@
***********************************************************************/

import type { InstructlabSession } from '@shared/src/models/instructlab/IInstructlabSession';
import type { InstructlabContainerConfiguration } from '@shared/src/models/instructlab/IInstructlabContainerConfiguration';
import { getRandomString } from '../../utils/randomUtils';
import type { TaskRegistry } from '../../registries/TaskRegistry';
import { type ContainerProviderConnection, containerEngine, type ContainerCreateOptions } from '@podman-desktop/api';
import type { PodmanConnection, PodmanConnectionEvent } from '../podmanConnection';
import instructlab_images from '../../assets/instructlab-images.json';
import { getImageInfo } from '../../utils/inferenceUtils';
import path from 'node:path';
import fs from 'node:fs/promises';
import type { ContainerRegistry, ContainerEvent } from '../../registries/ContainerRegistry';

export const INSTRUCTLAB_CONTAINER_LABEL = 'ai-lab-instructlab-container';

export class InstructlabManager {
#initialized: boolean;
#containerId: string | undefined;

constructor(
private readonly appUserDirectory: string,
private taskRegistry: TaskRegistry,
private podmanConnection: PodmanConnection,
private containerRegistry: ContainerRegistry,
) {
this.#initialized = false;
this.podmanConnection.onPodmanConnectionEvent(this.watchMachineEvent.bind(this));
this.containerRegistry.onStartContainerEvent(this.onStartContainerEvent.bind(this));
this.containerRegistry.onStopContainerEvent(this.onStopContainerEvent.bind(this));
}

private async refreshInstructlabContainer(id?: string): Promise<void> {
const containers = await containerEngine.listContainers();
const containerId = (this.#containerId = containers
.filter(c => !id || c.Id === id)
.filter(c => c.State === 'running' && c.Labels && INSTRUCTLAB_CONTAINER_LABEL in c.Labels)
.map(c => c.Id)
.at(0));
if ((id && containerId) || !id) {
this.#containerId = containerId;
}
}

private async watchMachineEvent(event: PodmanConnectionEvent): Promise<void> {
if ((event.status === 'started' && !this.#containerId) || (event.status === 'stopped' && this.#containerId)) {
await this.refreshInstructlabContainer();
}
}

private async onStartContainerEvent(event: ContainerEvent): Promise<void> {
await this.refreshInstructlabContainer(event.id);
}

private onStopContainerEvent(event: ContainerEvent): void {
if (this.#containerId === event.id) {
this.#containerId = undefined;
}
}

public getSessions(): InstructlabSession[] {
return [
{
Expand All @@ -39,4 +94,150 @@ export class InstructlabManager {
},
];
}

async getInstructLabContainer(): Promise<string | undefined> {
if (!this.#initialized) {
const containers = await containerEngine.listContainers();
this.#containerId = containers
.filter(c => c.State === 'running' && c.Labels && INSTRUCTLAB_CONTAINER_LABEL in c.Labels)
.map(c => c.Id)
.at(0);
this.#initialized = true;
}
return this.#containerId;
}

async requestCreateInstructlabContainer(config: InstructlabContainerConfiguration): Promise<string> {
// create a tracking id to put in the labels
const trackingId: string = getRandomString();

const labels = {
trackingId: trackingId,
};

const task = this.taskRegistry.createTask('Creating InstructLab container', 'loading', {
trackingId: trackingId,
});

let connection: ContainerProviderConnection | undefined;
if (config.connection) {
connection = this.podmanConnection.getContainerProviderConnection(config.connection);
} else {
connection = this.podmanConnection.findRunningContainerProviderConnection();
}

if (!connection) throw new Error('cannot find running container provider connection');

this.createInstructlabContainer(connection, labels)
.then((containerId: string) => {
this.#containerId = containerId;
this.taskRegistry.updateTask({
...task,
state: 'success',
labels: {
...task.labels,
containerId: containerId,
},
});
})
.catch((err: unknown) => {
// Get all tasks using the tracker
const tasks = this.taskRegistry.getTasksByLabels({
trackingId: trackingId,
});
// Filter the one no in loading state
tasks
.filter(t => t.state === 'loading' && t.id !== task.id)
.forEach(t => {
this.taskRegistry.updateTask({
...t,
state: 'error',
});
});
// Update the main task
this.taskRegistry.updateTask({
...task,
state: 'error',
error: `Something went wrong while trying to create an inference server ${String(err)}.`,
});
});
return trackingId;
}

async createInstructlabContainer(
connection: ContainerProviderConnection,
labels: { [p: string]: string },
): Promise<string> {
const image = instructlab_images.default;
const pullingTask = this.taskRegistry.createTask(`Pulling ${image}.`, 'loading', labels);
const imageInfo = await getImageInfo(connection, image, () => {})
.catch((err: unknown) => {
pullingTask.state = 'error';
pullingTask.progress = undefined;
pullingTask.error = `Something went wrong while pulling ${image}: ${String(err)}`;
throw err;
})
.then(imageInfo => {
pullingTask.state = 'success';
pullingTask.progress = undefined;
return imageInfo;
})
.finally(() => {
this.taskRegistry.updateTask(pullingTask);
});

const folder = await this.getInstructLabContainerFolder();

const containerTask = this.taskRegistry.createTask('Starting InstructLab container', 'loading', labels);
const createContainerOptions: ContainerCreateOptions = {
Image: imageInfo.Id,
name: `instructlab-${labels['trackingId']}`,
Labels: { [INSTRUCTLAB_CONTAINER_LABEL]: image },
HostConfig: {
AutoRemove: true,
Mounts: [
{
Target: '/instructlab/.cache/instructlab',
Source: path.join(folder, '.cache'),
Type: 'bind',
},
{
Target: '/instructlab/.config/instructlab',
Source: path.join(folder, '.config'),
Type: 'bind',
},
{
Target: '/instructlab/.local/share/instructlab',
Source: path.join(folder, '.local'),
Type: 'bind',
},
],
},
OpenStdin: true,
start: true,
};
try {
const { id } = await containerEngine.createContainer(imageInfo.engineId, createContainerOptions);
// update the task
containerTask.state = 'success';
containerTask.progress = undefined;
return id;
} catch (err: unknown) {
containerTask.state = 'error';
containerTask.progress = undefined;
containerTask.error = `Something went wrong while creating container: ${String(err)}`;
throw err;
} finally {
this.taskRegistry.updateTask(containerTask);
}
}

private async getInstructLabContainerFolder(): Promise<string> {
const instructlabPath = path.join(this.appUserDirectory, 'instructlab', 'container');
await fs.mkdir(instructlabPath, { recursive: true });
await fs.mkdir(path.join(instructlabPath, '.cache'), { recursive: true });
await fs.mkdir(path.join(instructlabPath, '.config'), { recursive: true });
await fs.mkdir(path.join(instructlabPath, '.local'), { recursive: true });
return instructlabPath;
}
}
Loading

0 comments on commit 0df95a8

Please sign in to comment.