Skip to content

Commit

Permalink
feat(pipelines): implement parts of speech model
Browse files Browse the repository at this point in the history
  • Loading branch information
sno2 committed Nov 19, 2021
1 parent 20486e4 commit 294a5fd
Show file tree
Hide file tree
Showing 11 changed files with 140 additions and 15 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -1,2 +1,3 @@
/target
Cargo.lock
.vscode
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ methods on those classes.
- `NERModel`
- `QAModel`
- `SentimentModel`
- `POSModel`

To test out these pipelines, you can try running the `dev.ts` file. However,
this will automatically install the necessary models so I advise you comment out
Expand Down
6 changes: 6 additions & 0 deletions dev.ts
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,12 @@ const SECTION = (name: string) => console.log(`\n${name}\n`);

const manager = new ModelManager();

// Demo: print a section banner, create a POS tagging model through the
// manager, and log its predictions for one sample sentence.
SECTION("POS TAGGING MODEL");

// NOTE(review): per the README, running this file installs the necessary
// models automatically — confirm before running.
const posModel = await manager.createPOSModel();

console.log(await posModel.predict(["What are the parts in this?"]));

const convoModel = await manager.createConversationModel();

const convoManager = await convoModel.createConversationManager();
Expand Down
3 changes: 3 additions & 0 deletions mod.ts
Original file line number Diff line number Diff line change
Expand Up @@ -15,3 +15,6 @@ export type {
TranslateInit,
TranslationModelInit,
} from "./models/translation/mod.ts";

export { POSModel } from "./models/pos.ts";
export type { POSEntity } from "./models/pos.ts";
32 changes: 26 additions & 6 deletions model_manager.ts
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import { SentimentModel } from "./models/sentiment.ts";
import { TranslationModel } from "./models/translation/mod.ts";
import { ConversationModel } from "./models/conversation.ts";
import type { TranslationModelInit } from "./models/translation/mod.ts";
import { POSModel } from "./models/pos.ts";
import { encode } from "./utils/encode.ts";
import { decode } from "./utils/decode.ts";

Expand Down Expand Up @@ -42,12 +43,6 @@ const symbolDefinitions = {
result: "isize",
nonblocking: true,
},
error_len: { parameters: [], result: "usize", nonblocking: true },
fill_result: {
parameters: ["buffer", "usize"],
result: "void",
nonblocking: true,
},
create_conversation_model: {
parameters: [],
result: "isize",
Expand All @@ -68,6 +63,22 @@ const symbolDefinitions = {
result: "isize",
nonblocking: true,
},
create_pos_model: {
parameters: [],
result: "isize",
nonblocking: true,
},
pos_predict: {
parameters: ["usize", "buffer", "usize"],
result: "isize",
nonblocking: true,
},
error_len: { parameters: [], result: "usize", nonblocking: true },
fill_result: {
parameters: ["buffer", "usize"],
result: "void",
nonblocking: true,
},
fill_error: { parameters: ["buffer"], result: "void", nonblocking: true },
delete_model: { parameters: ["usize"], result: "isize", nonblocking: true },
} as const;
Expand Down Expand Up @@ -192,6 +203,15 @@ export class ModelManager {
return model;
}

/**
 * Creates a parts-of-speech model through the native bindings and registers
 * it with this manager.
 *
 * The value resolved by `create_pos_model` is passed through `assertCode`
 * and the result is used as the resource id of the new `POSModel`.
 *
 * @returns The newly created `POSModel`, already tracked in `#models`.
 */
async createPOSModel(): Promise<POSModel> {
  const pending = this.bindings.create_pos_model();
  const rid = await pending.then(this.assertCode);
  const posModel = new POSModel(this, rid);
  this.#models.push(posModel);
  return posModel;
}

close() {
this.#close();
this.#isClosed = true;
Expand Down
28 changes: 28 additions & 0 deletions models/pos.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
import type { ModelManager } from "../model_manager.ts";
import { Model } from "../model.ts";
import { encode } from "../utils/encode.ts";
import { decode } from "../utils/decode.ts";

/**
 * Element type of the array that `POSModel.predict` declares as its result.
 * The values come straight from `JSON.parse` of the native output, so this
 * shape is not validated at runtime.
 */
export interface POSEntity {
// presumably the token text the tag applies to — confirm against native output
word: string;
// presumably the model's confidence for `label` — TODO confirm range
score: number;
// presumably the part-of-speech tag name — TODO confirm tag set
label: string;
}

/**
 * Handle to a parts-of-speech model identified by a resource id (`rid`)
 * and driven through the owning `ModelManager`'s native bindings.
 */
export class POSModel extends Model {
  constructor(manager: ModelManager, rid: number) {
    super(manager, rid);
  }

  /**
   * Runs the native `pos_predict` call on the given inputs.
   *
   * The inputs are JSON-encoded and handed to the native side, which reports
   * the byte length of its result; that many bytes are then copied out with
   * `fill_result`, decoded, and parsed as JSON.
   *
   * @param inputs Strings to tag.
   * @returns The parsed JSON result (not validated against `POSEntity`).
   */
  async predict(inputs: string[]): Promise<POSEntity[]> {
    const { bindings: native, assertCode: checkCode } = this.manager;
    const payload = encode(JSON.stringify(inputs));
    const pending = native.pos_predict(this.rid, payload, payload.length);
    const resultLength = await pending.then(checkCode);
    const resultBytes = new Uint8Array(resultLength);
    await native.fill_result(resultBytes, resultLength);
    return JSON.parse(decode(resultBytes));
  }
}
11 changes: 4 additions & 7 deletions models/translation/mod.ts
Original file line number Diff line number Diff line change
Expand Up @@ -26,15 +26,12 @@ export class TranslationModel extends Model {
this.#init = init;
}

async translate(init: TranslateInit) {
const { sourceLanguage, targetLanguage } = init;
async translate(init: TranslateInit): Promise<string[]> {
const { bindings, assertCode } = this.manager;
const bytes = encode(JSON.stringify(init));
const len = await bindings.translation_translate(
this.rid,
bytes,
bytes.length,
).then(assertCode);
const len = await bindings
.translation_translate(this.rid, bytes, bytes.length)
.then(assertCode);
const buf = new Uint8Array(len);
await bindings.fill_result(buf, len);
const json = decode(buf);
Expand Down
7 changes: 6 additions & 1 deletion src/allocators.rs
Original file line number Diff line number Diff line change
@@ -1,12 +1,17 @@
use crate::create_allocator;
use rust_bert::pipelines::{conversation, ner, question_answering, sentiment, translation};
use rust_bert::pipelines::{
conversation, ner, pos_tagging, question_answering, sentiment, translation,
zero_shot_classification,
};

/// A pipeline model held by this crate, with one variant per supported
/// `rust_bert::pipelines` model type.
///
/// FFI entry points match on the variant to check that a resource id refers
/// to the kind of model they expect.
pub enum Model {
/// Wraps a `translation::TranslationModel`.
TranslationModel(translation::TranslationModel),
/// Wraps a `question_answering::QuestionAnsweringModel`.
QuestionAnsweringModel(question_answering::QuestionAnsweringModel),
/// Wraps a `ner::NERModel`.
NERModel(ner::NERModel),
/// Wraps a `sentiment::SentimentModel`.
SentimentModel(sentiment::SentimentModel),
/// Wraps a `conversation::ConversationModel`.
ConversationModel(conversation::ConversationModel),
/// Wraps a `pos_tagging::POSModel`.
POSModel(pos_tagging::POSModel),
/// Wraps a `zero_shot_classification::ZeroShotClassificationModel`.
ZeroShotClassificationModel(zero_shot_classification::ZeroShotClassificationModel),
}

pub enum ModelResource {
Expand Down
6 changes: 5 additions & 1 deletion src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,14 +1,18 @@
mod allocators;
mod conversation;
mod ner;
mod pos_tagging;
mod qa;
mod sentiment;
mod translation;
mod zero_shot_classification;

pub use zero_shot_classification::*;
pub use allocators::*;
pub use conversation::*;
pub use ner::*;
use once_cell::sync::Lazy;
pub use pos_tagging::*;
pub use qa::*;
pub use sentiment::*;
use std::sync::Mutex;
Expand All @@ -27,7 +31,7 @@ pub fn set_result(v: Vec<u8>) -> usize {
/// Inspired by deno_sqlite3's `exec` helper by @littledivvy
pub fn exec<F>(f: F) -> isize
where
F: FnOnce() -> Result<isize, anyhow::Error>,
F: FnOnce() -> Result<isize, anyhow::Error>,
{
match f() {
Ok(a) => a,
Expand Down
32 changes: 32 additions & 0 deletions src/pos_tagging.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
use crate::{exec, models, set_result, Model};
use anyhow::Context;
use rust_bert::pipelines::pos_tagging::POSModel;

/// FFI entry point: builds a `POSModel` with its default configuration and
/// stores it in the model allocator.
///
/// On success the closure yields the allocated resource id cast to `isize`,
/// which `exec` passes through as the return value. Failures (model loading
/// or allocation) are turned into a return code by `exec`.
#[no_mangle]
pub extern "C" fn create_pos_model() -> isize {
    exec(|| {
        let tagger = POSModel::new(Default::default())
            .context("Failed to load Parts of Speech Tagging model.")?;

        let rid = models::allocate(Model::POSModel(tagger))?;
        Ok(rid as isize)
    })
}

/// FFI entry point: runs parts-of-speech prediction on a JSON-encoded list
/// of strings.
///
/// # Parameters
/// - `rid`: resource id of a model previously stored as `Model::POSModel`.
/// - `buf` / `buf_len`: pointer to, and byte length of, a JSON array of
///   strings.
///
/// # Returns
/// On success, the value of `set_result` (cast to `isize`) for the
/// JSON-serialized predictions. Any failure — null buffer with a non-zero
/// length, malformed JSON, wrong model kind at `rid`, or serialization
/// failure — is reported through `exec`.
///
/// # Safety contract for callers
/// When `buf_len` is non-zero, `buf` must point to `buf_len` readable bytes
/// that stay valid and unmodified for the duration of the call.
#[no_mangle]
pub extern "C" fn pos_predict(rid: usize, buf: *const u8, buf_len: usize) -> isize {
    exec(|| {
        // `slice::from_raw_parts` requires a non-null pointer even for an
        // empty slice, so never hand it a pointer we have not checked.
        let bytes: &[u8] = if buf_len == 0 {
            &[]
        } else if buf.is_null() {
            return Err(anyhow::anyhow!(
                "Received a null input buffer with non-zero length {}.",
                buf_len
            ));
        } else {
            // SAFETY: `buf` is non-null and, per this function's contract,
            // points to `buf_len` readable bytes that outlive this call.
            unsafe { std::slice::from_raw_parts(buf, buf_len) }
        };

        // Keep the underlying serde error text in the message so the caller
        // can see why the input was rejected.
        let inputs: Vec<String> = serde_json::from_slice(bytes)
            .map_err(|e| anyhow::anyhow!("Failed to parse POS input: {}", e))?;

        models::with_access(rid, |model| {
            let model = match model {
                Model::POSModel(m) => m,
                _ => return Err(anyhow::anyhow!("Expected POS Model at rid '{}'.", rid)),
            };

            let outputs = model.predict(&inputs);
            let outputs = serde_json::to_vec(&outputs).context("Failed to serialize POS tags.")?;
            Ok(set_result(outputs) as isize)
        })
    })
}
28 changes: 28 additions & 0 deletions src/zero_shot_classification.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
use crate::{exec, models, Model};
use anyhow::Context;
use rust_bert::pipelines::zero_shot_classification::ZeroShotClassificationModel;
use serde::Deserialize;

/// FFI entry point: builds a `ZeroShotClassificationModel` with its default
/// configuration and stores it in the model allocator.
///
/// On success the closure yields the allocated resource id cast to `isize`,
/// which `exec` passes through as the return value. Failures (model creation
/// or allocation) are turned into a return code by `exec`.
#[no_mangle]
pub extern "C" fn create_zero_shot_model() -> isize {
    exec(|| {
        let classifier = ZeroShotClassificationModel::new(Default::default())
            .context("Failed to create zero shot classification model.")?;

        let rid = models::allocate(Model::ZeroShotClassificationModel(classifier))?;
        Ok(rid as isize)
    })
}

/// Deserializable input type for zero shot prediction; it currently has no
/// fields.
// NOTE(review): `zero_shot_predict` in this file parses a `Vec<String>` and
// does not reference this type — looks like a placeholder; confirm intent.
#[derive(Deserialize)]
pub struct ZeroShotInput {}

/// FFI entry point for zero shot prediction.
///
/// Currently a stub: it parses `buf` as a JSON array of strings to validate
/// the input, discards the parsed value, and yields `0`. `rid` is not used
/// and no model is consulted.
///
/// `buf` must point to `buf_len` readable bytes for the duration of the call.
#[no_mangle]
pub extern "C" fn zero_shot_predict(rid: usize, buf: *const u8, buf_len: usize) -> isize {
    exec(|| {
        let raw = unsafe { std::slice::from_raw_parts(buf, buf_len) };
        // Parsed only for validation; nothing consumes the result yet.
        let _inputs = serde_json::from_slice::<Vec<String>>(raw)
            .context("Failed to parse zero shot input.")?;

        Ok(0)
    })
}

0 comments on commit 294a5fd

Please sign in to comment.