Skip to content

Commit

Permalink
fix: handle navigation and container lifecycle
Browse files Browse the repository at this point in the history
Signed-off-by: Jeff MAURY <[email protected]>
  • Loading branch information
jeffmaury committed Feb 6, 2025
1 parent 1ea3cd2 commit c8249b9
Show file tree
Hide file tree
Showing 7 changed files with 41 additions and 43 deletions.
2 changes: 1 addition & 1 deletion packages/backend/src/instructlab-api-impl.ts
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ export class InstructlabApiImpl implements InstructlabAPI {
return this.instructlabManager.getSessions();
}

requestCreateInstructlabContainer(config: InstructlabContainerConfiguration): Promise<string> {
requestCreateInstructlabContainer(config: InstructlabContainerConfiguration): Promise<void> {
return this.instructlabManager.requestCreateInstructlabContainer(config);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ import { TestEventEmitter } from '../../tests/utils';
import { VMType } from '@shared/src/models/IPodman';
import type { Task } from '@shared/src/models/ITask';
import instructlab_images from '../../assets/instructlab-images.json';
import { INSTRUCTLAB_CONTAINER_TRACKINGID } from '@shared/src/models/instructlab/IInstructlabContainerInfo';

vi.mock('@podman-desktop/api', () => {
return {
Expand Down Expand Up @@ -61,6 +62,7 @@ beforeEach(() => {
const containerRegistry = new ContainerRegistry();
containerRegistry.init();
instructlabManager = new InstructlabManager('', taskRegistry, podmanConnection, containerRegistry, telemetryMock);
taskRegistry.deleteByLabels({ trackingId: INSTRUCTLAB_CONTAINER_TRACKINGID });
});

test('getInstructLabContainer should return undefined if no containers', async () => {
Expand Down Expand Up @@ -92,9 +94,9 @@ test('requestCreateInstructlabContainer throws error if no podman connection', a
await expect(containerIdPromise).rejects.toBeInstanceOf(Error);
});

async function waitTasks(containerId: string, nb: number): Promise<Task[]> {
async function waitTasks(id: string, nb: number): Promise<Task[]> {
return vi.waitFor(() => {
const tasks = taskRegistry.getTasksByLabels({ trackingId: containerId });
const tasks = taskRegistry.getTasksByLabels({ trackingId: id });
if (tasks.length !== nb) {
throw new Error('not completed');
}
Expand All @@ -113,9 +115,8 @@ test('requestCreateInstructlabContainer returns id and error if listImage return
},
});
vi.mocked(containerEngine.listImages).mockRejectedValue(new Error());
const containerId = await instructlabManager.requestCreateInstructlabContainer({});
expect(containerId).toBeDefined();
const tasks = await waitTasks(containerId, 2);
await instructlabManager.requestCreateInstructlabContainer({});
const tasks = await waitTasks(INSTRUCTLAB_CONTAINER_TRACKINGID, 2);
expect(tasks.some(task => task.state === 'error')).toBeTruthy();
});

Expand All @@ -132,9 +133,8 @@ test('requestCreateInstructlabContainer returns id and error if listImage return
vi.mocked(containerEngine.listImages).mockResolvedValue([
{ RepoTags: [instructlab_images.default] } as unknown as ImageInfo,
]);
const containerId = await instructlabManager.requestCreateInstructlabContainer({});
expect(containerId).toBeDefined();
const tasks = await waitTasks(containerId, 3);
await instructlabManager.requestCreateInstructlabContainer({});
const tasks = await waitTasks(INSTRUCTLAB_CONTAINER_TRACKINGID, 3);
expect(tasks.some(task => task.state === 'error')).toBeTruthy();
});

Expand All @@ -154,8 +154,7 @@ test('requestCreateInstructlabContainer returns id and no error if createContain
vi.mocked(containerEngine.createContainer).mockResolvedValue({
id: 'containerId',
} as unknown as ContainerCreateResult);
const containerId = await instructlabManager.requestCreateInstructlabContainer({});
expect(containerId).toBeDefined();
const tasks = await waitTasks(containerId, 3);
await instructlabManager.requestCreateInstructlabContainer({});
const tasks = await waitTasks(INSTRUCTLAB_CONTAINER_TRACKINGID, 3);
expect(tasks.some(task => task.state === 'error')).toBeFalsy();
});
12 changes: 7 additions & 5 deletions packages/backend/src/managers/instructlab/instructlabManager.ts
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@

import type { InstructlabSession } from '@shared/src/models/instructlab/IInstructlabSession';
import type { InstructlabContainerConfiguration } from '@shared/src/models/instructlab/IInstructlabContainerConfiguration';
import { getRandomString } from '../../utils/randomUtils';
import type { TaskRegistry } from '../../registries/TaskRegistry';
import {
type TelemetryLogger,
Expand All @@ -33,6 +32,8 @@ import path from 'node:path';
import fs from 'node:fs/promises';
import type { ContainerRegistry, ContainerEvent } from '../../registries/ContainerRegistry';
import { DISABLE_SELINUX_LABEL_SECURITY_OPTION } from '../../utils/utils';
import { INSTRUCTLAB_CONTAINER_TRACKINGID } from '@shared/src/models/instructlab/IInstructlabContainerInfo';
import { getRandomName } from '../../utils/randomUtils';

export const INSTRUCTLAB_CONTAINER_LABEL = 'ai-lab-instructlab-container';

Expand Down Expand Up @@ -76,8 +77,10 @@ export class InstructlabManager {
}

private onStopContainerEvent(event: ContainerEvent): void {
console.log('event id:', event.id, ' containerId: ', this.#containerId);
if (this.#containerId === event.id) {
this.#containerId = undefined;
this.taskRegistry.deleteByLabels({ trackingId: INSTRUCTLAB_CONTAINER_TRACKINGID });
}
}

Expand Down Expand Up @@ -114,9 +117,9 @@ export class InstructlabManager {
return this.#containerId;
}

async requestCreateInstructlabContainer(config: InstructlabContainerConfiguration): Promise<string> {
async requestCreateInstructlabContainer(config: InstructlabContainerConfiguration): Promise<void> {
// create a tracking id to put in the labels
const trackingId: string = getRandomString();
const trackingId: string = INSTRUCTLAB_CONTAINER_TRACKINGID;

const labels = {
trackingId: trackingId,
Expand Down Expand Up @@ -170,7 +173,6 @@ export class InstructlabManager {
});
this.telemetryLogger.logError('instructlab.startContainer', { error: err });
});
return trackingId;
}

async createInstructlabContainer(
Expand Down Expand Up @@ -200,7 +202,7 @@ export class InstructlabManager {
const containerTask = this.taskRegistry.createTask('Starting InstructLab container', 'loading', labels);
const createContainerOptions: ContainerCreateOptions = {
Image: imageInfo.Id,
name: `instructlab-${labels['trackingId']}`,
name: getRandomName('instructlab'),
Labels: { [INSTRUCTLAB_CONTAINER_LABEL]: image },
HostConfig: {
AutoRemove: true,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,14 +24,12 @@ import { instructlabClient, studioClient } from '/@/utils/client';
import type { ContainerProviderConnectionInfo } from '@shared/src/models/IContainerConnectionInfo';
import { VMType } from '@shared/src/models/IPodman';
import userEvent from '@testing-library/user-event';
import { tasks } from '/@/stores/tasks';
import type { Task } from '@shared/src/models/ITask';
import * as tasks from '/@/stores/tasks';
import { writable } from 'svelte/store';

vi.mock('../../stores/tasks', async () => {
return {
tasks: {
subscribe: vi.fn().mockReturnValue((): void => {}),
},
tasks: vi.fn(),
};
});

Expand Down Expand Up @@ -73,6 +71,7 @@ const containerProviderConnection: ContainerProviderConnectionInfo = {

beforeEach(() => {
getContainerConnectionInfoMock.mockReturnValue([containerProviderConnection]);
vi.mocked(tasks).tasks = writable([]);
});

test('start button should be displayed if no InstructLab container', async () => {
Expand All @@ -82,16 +81,12 @@ test('start button should be displayed if no InstructLab container', async () =>
expect(startBtn).toBeDefined();
});

test('start button should be displayed and disabled', async () => {
vi.mocked(tasks.subscribe).mockImplementation((f: (tasks: Task[]) => void) => {
f([]);
return (): void => {};
});
render(StartInstructLabContainer, { trackingId: 'trackingId' });
test('start button should be displayed and enabled', async () => {
render(StartInstructLabContainer);

const startBtn = screen.getByTitle('Start InstructLab container');
expect(startBtn).toBeDefined();
expect(startBtn).toBeDisabled();
expect(startBtn).toBeEnabled();
});

test('open button should be displayed if no InstructLab container', async () => {
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
<script lang="ts">
import { router } from 'tinro';
import { tasks } from '/@/stores/tasks';
import type { ContainerProviderConnectionInfo } from '@shared/src/models/IContainerConnectionInfo';
import { Button, ErrorMessage, FormPage } from '@podman-desktop/ui-svelte';
Expand All @@ -9,14 +8,8 @@ import TrackedTasks from '/@/lib/progress/TrackedTasks.svelte';
import { instructlabClient, studioClient } from '/@/utils/client';
import type { Task } from '@shared/src/models/ITask';
import { onMount } from 'svelte';
interface Props {
// The tracking id is a unique identifier provided by the
// backend when calling requestCreateInstructLabContainer
trackingId?: string;
}
let { trackingId }: Props = $props();
import { INSTRUCTLAB_CONTAINER_TRACKINGID } from '@shared/src/models/instructlab/IInstructlabContainerInfo';
import { filterByLabel } from '/@/utils/taskUtils';
// The container provider connection to use
let containerProviderConnection: ContainerProviderConnectionInfo | undefined = $state(undefined);
Expand All @@ -34,7 +27,11 @@ let containerId: string | undefined = $state(undefined);
// available means the container is started
let available: boolean = $derived(!!containerId);
// loading state
let loading = $derived(trackingId !== undefined && !errorMsg);
let loading = $derived(
containerId === undefined &&
filterByLabel($tasks, { trackingId: INSTRUCTLAB_CONTAINER_TRACKINGID }).length > 0 &&
!errorMsg,
);
onMount(async () => {
containerId = await instructlabClient.getInstructlabContainerId();
Expand Down Expand Up @@ -62,10 +59,9 @@ function processTasks(trackedTasks: Task[]): void {
async function submit(): Promise<void> {
errorMsg = undefined;
try {
trackingId = await instructlabClient.requestCreateInstructlabContainer({
await instructlabClient.requestCreateInstructlabContainer({
connection: $state.snapshot(containerProviderConnection),
});
router.location.query.set('trackingId', trackingId);
} catch (err: unknown) {
console.error('Something wrong while trying to create the InstructLab container.', err);
errorMsg = String(err);
Expand Down Expand Up @@ -100,7 +96,11 @@ function openDocumentation(): void {
</div>
</header>
<!-- tasks tracked -->
<TrackedTasks class="mx-5 mt-5" onChange={processTasks} trackingId={trackingId} tasks={$tasks} />
<TrackedTasks
class="mx-5 mt-5"
onChange={processTasks}
trackingId={INSTRUCTLAB_CONTAINER_TRACKINGID}
tasks={$tasks} />

<!-- form -->
<div class="bg-[var(--pd-content-card-bg)] m-5 space-y-6 px-8 sm:pb-6 xl:pb-8 rounded-lg h-fit">
Expand Down
2 changes: 1 addition & 1 deletion packages/shared/src/InstructlabAPI.ts
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ export abstract class InstructlabAPI {
*
* @param config
*/
abstract requestCreateInstructlabContainer(config: InstructlabContainerConfiguration): Promise<string>;
abstract requestCreateInstructlabContainer(config: InstructlabContainerConfiguration): Promise<void>;

abstract routeToInstructLabContainerTerminal(containerId: string): Promise<void>;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@
*
* SPDX-License-Identifier: Apache-2.0
***********************************************************************/
export const INSTRUCTLAB_CONTAINER_TRACKINGID = 'instructlab.trackingid';

export interface InstructlabContainerInfo {
/**
* The container engine it is running on
Expand Down

0 comments on commit c8249b9

Please sign in to comment.