Skip to content

Commit

Permalink
Add support for embedding models in the ModelMapper class.
Browse files Browse the repository at this point in the history
  • Loading branch information
huangzhhui committed Dec 15, 2024
1 parent 3ca0d7b commit 171f493
Show file tree
Hide file tree
Showing 2 changed files with 50 additions and 7 deletions.
1 change: 1 addition & 0 deletions publish/odin.php
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,7 @@
],
],
'text-embedding-ada-002' => [
'type' => 'embedding',
'implementation' => AzureOpenAIModel::class,
'config' => [
'api_key' => env('AZURE_OPENAI_TEXT_EMBEDDING_ADA_002_API_KEY'),
Expand Down
56 changes: 49 additions & 7 deletions src/ModelMapper.php
Original file line number Diff line number Diff line change
Expand Up @@ -13,28 +13,40 @@
namespace Hyperf\Odin;

use Hyperf\Contract\ConfigInterface;
use Hyperf\Odin\Model\EmbeddingInterface;
use Hyperf\Odin\Model\ModelInterface;
use InvalidArgumentException;

class ModelMapper
{
protected string $defaultModel = '';

protected string $defaultEmbeddingModel = '';

protected array $models = [];

public function __construct(protected ConfigInterface $config)
{
$this->defaultModel = $config->get('odin.llm.default', 'gpt-3.5-turbo');
$this->defaultEmbeddingModel = $config->get('odin.llm.default_embedding', 'text-embedding-ada-002');
$models = $config->get('odin.llm.models', []);
foreach ($models as $model => $item) {
if (! $model || ! isset($item['implementation'])) {
continue;
}
$implementation = $item['implementation'];
$modelObject = new $implementation($model, $item['config'] ?? []);
if (! $modelObject instanceof ModelInterface) {
throw new InvalidArgumentException(sprintf('Model %s must be an instance of %s.', $model, ModelInterface::class));
if (isset($item['type']) && $item['type'] === 'embedding') {
if (! $modelObject instanceof EmbeddingInterface) {
throw new InvalidArgumentException(sprintf('Model %s must be an instance of %s.', $model, EmbeddingInterface::class));
}
$this->models['embedding'][$model] = $modelObject;
} else {
if (! $modelObject instanceof ModelInterface) {
throw new InvalidArgumentException(sprintf('Model %s must be an instance of %s.', $model, ModelInterface::class));
}
$this->models['chat'][$model] = $modelObject;
}
$this->models[$model] = $modelObject;
}
}

Expand All @@ -43,19 +55,49 @@ public function getDefaultModel(): ModelInterface
return $this->getModel($this->defaultModel);
}

public function getDefaultEmbeddingModel(): EmbeddingInterface
{
return $this->getModel($this->defaultEmbeddingModel);
}

/**
* Alias for getChatModel(string $model) method.
*/
public function getModel(string $model): ModelInterface
{
return $this->getChatModel($model);
}

public function getChatModel(string $model): ModelInterface
{
if ($model === '') {
$model = $this->defaultModel;
}
if (! isset($this->models[$model])) {
throw new InvalidArgumentException(sprintf('Model %s is not defined.', $model));
if (! isset($this->models['chat'][$model])) {
throw new InvalidArgumentException(sprintf('Chat Model %s is not defined.', $model));
}
return $this->models[$model];
return $this->models['chat'][$model];
}

public function getModels(): array
public function getEmbeddingModel(string $model): EmbeddingInterface
{
if ($model === '') {
$model = $this->defaultEmbeddingModel;
}
if (! isset($this->models['embedding'][$model])) {
throw new InvalidArgumentException(sprintf('Embedding Model %s is not defined.', $model));
}
return $this->models['embedding'][$model];
}

public function getModels(string $type = ''): array
{
if ($type === 'embedding') {
return $this->models['embedding'];
}
if ($type === 'chat') {
return $this->models['chat'];
}
return $this->models;
}
}

0 comments on commit 171f493

Please sign in to comment.