Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

AzureOpenAI and Skylark support stream ChatCompletion #2

Open
wants to merge 9 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 1 addition & 3 deletions src/Api/AzureOpenAI/AzureOpenAIConfig.php
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,7 @@ class AzureOpenAIConfig
{
public function __construct(
protected array $config = [],
) {

}
) {}

public function getApiKey(): ?string
{
Expand Down
27 changes: 12 additions & 15 deletions src/Api/AzureOpenAI/Client.php
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@
use Hyperf\Odin\Exception\NotImplementedException;
use Hyperf\Odin\Message\MessageInterface;
use Hyperf\Odin\Tool\ToolInterface;
use InvalidArgumentException;
use Psr\Log\LoggerInterface;

class Client implements ClientInterface
Expand All @@ -33,12 +32,13 @@ class Client implements ClientInterface
*/
protected array $clients = [];

protected ?LoggerInterface $logger;
protected ?LoggerInterface $logger = null;

protected bool $debug = false;

protected string $model;
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

理论上应该想办法去掉这个 $this->model 才对

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

对,其实这个 model 没啥用,里面的 $clients 其实也只会有一个值


public function __construct(AzureOpenAIConfig $config, LoggerInterface $logger, string $model)
public function __construct(AzureOpenAIConfig $config, ?LoggerInterface $logger, string $model)
{
$this->logger = $logger;
$this->model = $model;
Expand All @@ -54,7 +54,7 @@ public function chat(
array $tools = [],
bool $stream = false,
): ChatCompletionResponse {
$deploymentPath = $this->buildDeploymentPath($model);
$deploymentPath = $this->buildDeploymentPath();
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

为啥所有 $model 传参都去掉了?这样会让这个类在 DI 单例的情况下只能支持一种 Model

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

因为下面的那个

protected function buildDeploymentPath(): string
{
    return 'openai/deployments/' . $this->config->getDeploymentName();
}

中的 $this->config->getDeploymentName() 已经不提供入参了,语法检测会提示让我去掉

对于 AzureOpenAI 来说,传入一种配置,其实也就只能访问一个模型,因为 url 里面会有模型名,这里传入不同的 model 其实是无影响的,AzureOpenAI 的接口没有使用该参数

$messagesArr = [];
foreach ($messages as $message) {
if ($message instanceof MessageInterface) {
Expand All @@ -65,6 +65,7 @@ public function chat(
'messages' => $messagesArr,
'model' => $model,
'temperature' => $temperature,
'stream' => $stream,
];
if ($maxTokens) {
$json['max_tokens'] = $maxTokens;
Expand All @@ -91,10 +92,9 @@ public function chat(
$this->debug && $this->logger?->debug(sprintf("Send Messages: %s\nTools: %s", json_encode($messagesArr, JSON_UNESCAPED_UNICODE), json_encode($tools, JSON_UNESCAPED_UNICODE)));
$response = $this->getClient($model)->post($deploymentPath . '/chat/completions', [
'query' => [
'api-version' => $this->config->getApiVersion($model),
'api-version' => $this->config->getApiVersion(),
],
'json' => $json,
'verify' => false,
]);
$chatCompletionResponse = new ChatCompletionResponse($response);
$this->debug && $this->logger?->debug('Receive: ' . $chatCompletionResponse);
Expand All @@ -107,10 +107,10 @@ public function completions(
float $temperature = 0.9,
int $maxTokens = 200
): TextCompletionResponse {
$deploymentPath = $this->buildDeploymentPath($model);
$deploymentPath = $this->buildDeploymentPath();
$response = $this->getClient($model)->post($deploymentPath . '/completions', [
'query' => [
'api-version' => $this->config->getApiVersion($model),
'api-version' => $this->config->getApiVersion(),
],
'json' => [
'prompt' => $prompt,
Expand All @@ -133,14 +133,14 @@ public function embedding(
string $model = 'text-embedding-ada-002',
?string $user = null
): ListResponse {
$deploymentPath = $this->buildDeploymentPath($model);
$deploymentPath = $this->buildDeploymentPath();
$json = [
'input' => $input,
];
$user && $json['user'] = $user;
$response = $this->getClient($model)->post($deploymentPath . '/embeddings', [
'query' => [
'api-version' => $this->config->getApiVersion($model),
'api-version' => $this->config->getApiVersion(),
],
'json' => $json,
'verify' => false,
Expand All @@ -161,9 +161,6 @@ public function setDebug(bool $debug): static

protected function initConfig(AzureOpenAIConfig $config): static
{
if (! $config instanceof AzureOpenAIConfig) {
throw new InvalidArgumentException('AzureOpenAIConfig is required');
}
$this->config = $config;
$headers = [
'api-key' => $config->getApiKey(),
Expand All @@ -182,8 +179,8 @@ protected function getClient(string $model): ?GuzzleClient
return $this->clients[$model];
}

protected function buildDeploymentPath(string $model = 'gpt-3.5-turbo'): string
protected function buildDeploymentPath(): string
{
return 'openai/deployments/' . $this->config->getDeploymentName($model);
return 'openai/deployments/' . $this->config->getDeploymentName();
}
}
14 changes: 11 additions & 3 deletions src/Api/OpenAI/Response/ChatCompletionChoice.php
Original file line number Diff line number Diff line change
Expand Up @@ -22,12 +22,20 @@ public function __construct(
public ?int $index = null,
public ?string $logprobs = null,
public ?string $finishReason = null
) {
}
) {}

public static function fromArray(array $choice): static
{
return new static(Message::fromArray($choice['message']), $choice['index'] ?? null, $choice['logprobs'] ?? null, $choice['finish_reason'] ?? null);
$message = $choice['message'] ?? [];
if (isset($choice['delta'])) {
$message = [
'role' => $choice['delta']['role'] ?? 'assistant',
'content' => $choice['delta']['content'] ?? '',
'tool_calls' => $choice['delta']['tool_calls'] ?? [],
];
}

return new static(Message::fromArray($message), $choice['index'] ?? null, $choice['logprobs'] ?? null, $choice['finish_reason'] ?? null);
}

public function getMessage(): MessageInterface
Expand Down
78 changes: 76 additions & 2 deletions src/Api/OpenAI/Response/ChatCompletionResponse.php
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,10 @@

namespace Hyperf\Odin\Api\OpenAI\Response;

use Generator;
use Hyperf\Odin\Exception\RuntimeException;
use Psr\Http\Message\ResponseInterface as PsrResponseInterface;
use Psr\Http\Message\StreamInterface;
use Stringable;

class ChatCompletionResponse extends AbstractResponse implements Stringable
Expand All @@ -24,13 +28,15 @@ class ChatCompletionResponse extends AbstractResponse implements Stringable

protected ?string $model = null;

protected array|null $choices = [];
protected ?array $choices = [];

protected ?Usage $usage = null;

protected bool $isChunked = false;

public function __toString(): string
{
return trim($this->getChoices()[0]?->getMessage()?->getContent() ? : '');
return trim($this->getChoices()[0]?->getMessage()?->getContent() ?: '');
}

public function getId(): ?string
Expand Down Expand Up @@ -82,6 +88,9 @@ public function getFirstChoice(): ?ChatCompletionChoice
return $this->choices[0] ?? null;
}

/**
* @return null|ChatCompletionChoice[]
*/
public function getChoices(): ?array
{
return $this->choices;
Expand All @@ -104,8 +113,59 @@ public function setUsage(?Usage $usage): static
return $this;
}

public function isChunked(): bool
{
return $this->isChunked;
}

public function getStreamIterator(): Generator
{
while (! $this->originResponse->getBody()->eof()) {
$line = $this->readLine($this->originResponse->getBody());

if (! str_starts_with($line, 'data:')) {
continue;
}
$data = trim(substr($line, strlen('data:')));
if (str_starts_with('[DONE]', $data)) {
break;
}
$content = json_decode($data, true);
if (json_last_error() !== JSON_ERROR_NONE) {
throw new RuntimeException('Invalid JSON response | ' . $line);
}
if (isset($content['error'])) {
throw new RuntimeException('Steam Error | ' . $content['error']);
}
$this->setId($content['id'] ?? null);
$this->setObject($content['object'] ?? null);
$this->setCreated($content['created'] ?? null);
$this->setModel($content['model'] ?? null);
if (empty($content['choices'])) {
continue;
}
foreach ($content['choices'] as $choice) {
yield ChatCompletionChoice::fromArray($choice);
}
}
}

public function setOriginResponse(PsrResponseInterface $originResponse): static
{
$this->originResponse = $originResponse;
$this->success = $originResponse->getStatusCode() === 200;
$this->parseContent();
return $this;
}

protected function parseContent(): static
{
if ($this->originResponse->hasHeader('Transfer-Encoding')
&& $this->originResponse->getHeaderLine('Transfer-Encoding') === 'chunked') {
$this->isChunked = true;
return $this;
}
$this->content = $this->originResponse->getBody()->getContents();
$content = json_decode($this->content, true);
if (isset($content['id'])) {
$this->setId($content['id']);
Expand Down Expand Up @@ -137,4 +197,18 @@ protected function buildChoices(mixed $choices): array
return $result;
}

private function readLine(StreamInterface $stream): string
{
$buffer = '';
while (! $stream->eof()) {
if ('' === ($byte = $stream->read(1))) {
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这代码性能爆炸

Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

不影响大局,暂时先这样吧

return $buffer;
}
$buffer .= $byte;
if ($byte === "\n") {
break;
}
}
return $buffer;
}
}
31 changes: 19 additions & 12 deletions src/Api/OpenAI/Response/ToolCall.php
Original file line number Diff line number Diff line change
Expand Up @@ -16,20 +16,13 @@

class ToolCall implements Arrayable
{

/**
* @param string $name
* @param array $arguments
* @param bool $shouldFix Sometimes the API will return a wrong function call. If this flag is true will attempt to fix that.
*/
public function __construct(
protected string $name,
protected array $arguments,
protected string $id,
protected string $type = 'function'
)
{
}
protected string $type = 'function',
protected string $streamArguments = '',
) {}

public static function fromArray(array $toolCalls): array
{
Expand All @@ -50,7 +43,7 @@ public static function fromArray(array $toolCalls): array
$name = $function['name'] ?? '';
$id = $toolCall['id'] ?? '';
$type = $toolCall['type'] ?? 'function';
$static = new static($name, $arguments, $id, $type);
$static = new static($name, $arguments, $id, $type, $function['arguments']);
$toolCallsResult[] = $static;
}
return $toolCallsResult;
Expand Down Expand Up @@ -81,12 +74,16 @@ public function setName(string $name): static

public function getArguments(): array
{
if (! empty($this->streamArguments)) {
$arguments = json_decode($this->streamArguments, true);
return is_array($arguments) ? $arguments : [];
}
return $this->arguments;
}

public function getSerializedArguments(): string
{
return json_encode($this->arguments);
return json_encode($this->getArguments(), JSON_UNESCAPED_UNICODE);
}

public function setArguments(array $arguments): static
Expand Down Expand Up @@ -116,4 +113,14 @@ public function setType(string $type): static
$this->type = $type;
return $this;
}

public function getStreamArguments(): string
{
return $this->streamArguments;
}

public function appendStreamArguments(string $arguments): void
{
$this->streamArguments .= $arguments;
}
}
Loading
Loading