-
-
Notifications
You must be signed in to change notification settings - Fork 3
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
AzureOpenAI and Skylark support stream ChatCompletion #2
base: master
Are you sure you want to change the base?
Changes from 3 commits
6f596fb
7c53757
a7880de
1177f76
0cd928e
fb6476b
3dcb480
7854c08
050a253
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -21,7 +21,6 @@ | |
use Hyperf\Odin\Exception\NotImplementedException; | ||
use Hyperf\Odin\Message\MessageInterface; | ||
use Hyperf\Odin\Tool\ToolInterface; | ||
use InvalidArgumentException; | ||
use Psr\Log\LoggerInterface; | ||
|
||
class Client implements ClientInterface | ||
|
@@ -33,12 +32,13 @@ class Client implements ClientInterface | |
*/ | ||
protected array $clients = []; | ||
|
||
protected ?LoggerInterface $logger; | ||
protected ?LoggerInterface $logger = null; | ||
|
||
protected bool $debug = false; | ||
|
||
protected string $model; | ||
|
||
public function __construct(AzureOpenAIConfig $config, LoggerInterface $logger, string $model) | ||
public function __construct(AzureOpenAIConfig $config, ?LoggerInterface $logger, string $model) | ||
{ | ||
$this->logger = $logger; | ||
$this->model = $model; | ||
|
@@ -54,7 +54,7 @@ public function chat( | |
array $tools = [], | ||
bool $stream = false, | ||
): ChatCompletionResponse { | ||
$deploymentPath = $this->buildDeploymentPath($model); | ||
$deploymentPath = $this->buildDeploymentPath(); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 为啥所有 $model 传参都去掉了?这样会让这个类在 DI 单例的情况下只能支持一种 Model There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 因为下面的那个 protected function buildDeploymentPath(): string
{
return 'openai/deployments/' . $this->config->getDeploymentName();
} 中的 对于 AzureOpenAI 来说,传入一种配置,其实也就只能访问一个模型,因为 url 里面会有模型名,这里传入不同的 model 其实是无影响的,AzureOpenAI 的接口没有使用该参数 |
||
$messagesArr = []; | ||
foreach ($messages as $message) { | ||
if ($message instanceof MessageInterface) { | ||
|
@@ -65,6 +65,7 @@ public function chat( | |
'messages' => $messagesArr, | ||
'model' => $model, | ||
'temperature' => $temperature, | ||
'stream' => $stream, | ||
]; | ||
if ($maxTokens) { | ||
$json['max_tokens'] = $maxTokens; | ||
|
@@ -91,10 +92,9 @@ public function chat( | |
$this->debug && $this->logger?->debug(sprintf("Send Messages: %s\nTools: %s", json_encode($messagesArr, JSON_UNESCAPED_UNICODE), json_encode($tools, JSON_UNESCAPED_UNICODE))); | ||
$response = $this->getClient($model)->post($deploymentPath . '/chat/completions', [ | ||
'query' => [ | ||
'api-version' => $this->config->getApiVersion($model), | ||
'api-version' => $this->config->getApiVersion(), | ||
], | ||
'json' => $json, | ||
'verify' => false, | ||
]); | ||
$chatCompletionResponse = new ChatCompletionResponse($response); | ||
$this->debug && $this->logger?->debug('Receive: ' . $chatCompletionResponse); | ||
|
@@ -107,10 +107,10 @@ public function completions( | |
float $temperature = 0.9, | ||
int $maxTokens = 200 | ||
): TextCompletionResponse { | ||
$deploymentPath = $this->buildDeploymentPath($model); | ||
$deploymentPath = $this->buildDeploymentPath(); | ||
$response = $this->getClient($model)->post($deploymentPath . '/completions', [ | ||
'query' => [ | ||
'api-version' => $this->config->getApiVersion($model), | ||
'api-version' => $this->config->getApiVersion(), | ||
], | ||
'json' => [ | ||
'prompt' => $prompt, | ||
|
@@ -133,14 +133,14 @@ public function embedding( | |
string $model = 'text-embedding-ada-002', | ||
?string $user = null | ||
): ListResponse { | ||
$deploymentPath = $this->buildDeploymentPath($model); | ||
$deploymentPath = $this->buildDeploymentPath(); | ||
$json = [ | ||
'input' => $input, | ||
]; | ||
$user && $json['user'] = $user; | ||
$response = $this->getClient($model)->post($deploymentPath . '/embeddings', [ | ||
'query' => [ | ||
'api-version' => $this->config->getApiVersion($model), | ||
'api-version' => $this->config->getApiVersion(), | ||
], | ||
'json' => $json, | ||
'verify' => false, | ||
|
@@ -161,9 +161,6 @@ public function setDebug(bool $debug): static | |
|
||
protected function initConfig(AzureOpenAIConfig $config): static | ||
{ | ||
if (! $config instanceof AzureOpenAIConfig) { | ||
throw new InvalidArgumentException('AzureOpenAIConfig is required'); | ||
} | ||
$this->config = $config; | ||
$headers = [ | ||
'api-key' => $config->getApiKey(), | ||
|
@@ -182,8 +179,8 @@ protected function getClient(string $model): ?GuzzleClient | |
return $this->clients[$model]; | ||
} | ||
|
||
protected function buildDeploymentPath(string $model = 'gpt-3.5-turbo'): string | ||
protected function buildDeploymentPath(): string | ||
{ | ||
return 'openai/deployments/' . $this->config->getDeploymentName($model); | ||
return 'openai/deployments/' . $this->config->getDeploymentName(); | ||
} | ||
} |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -12,6 +12,10 @@ | |
|
||
namespace Hyperf\Odin\Api\OpenAI\Response; | ||
|
||
use Generator; | ||
use Hyperf\Odin\Exception\RuntimeException; | ||
use Psr\Http\Message\ResponseInterface as PsrResponseInterface; | ||
use Psr\Http\Message\StreamInterface; | ||
use Stringable; | ||
|
||
class ChatCompletionResponse extends AbstractResponse implements Stringable | ||
|
@@ -24,13 +28,15 @@ class ChatCompletionResponse extends AbstractResponse implements Stringable | |
|
||
protected ?string $model = null; | ||
|
||
protected array|null $choices = []; | ||
protected ?array $choices = []; | ||
|
||
protected ?Usage $usage = null; | ||
|
||
protected bool $isChunked = false; | ||
|
||
public function __toString(): string | ||
{ | ||
return trim($this->getChoices()[0]?->getMessage()?->getContent() ? : ''); | ||
return trim($this->getChoices()[0]?->getMessage()?->getContent() ?: ''); | ||
} | ||
|
||
public function getId(): ?string | ||
|
@@ -82,6 +88,9 @@ public function getFirstChoice(): ?ChatCompletionChoice | |
return $this->choices[0] ?? null; | ||
} | ||
|
||
/** | ||
* @return null|ChatCompletionChoice[] | ||
*/ | ||
public function getChoices(): ?array | ||
{ | ||
return $this->choices; | ||
|
@@ -104,8 +113,59 @@ public function setUsage(?Usage $usage): static | |
return $this; | ||
} | ||
|
||
public function isChunked(): bool | ||
{ | ||
return $this->isChunked; | ||
} | ||
|
||
public function getStreamIterator(): Generator | ||
{ | ||
while (! $this->originResponse->getBody()->eof()) { | ||
$line = $this->readLine($this->originResponse->getBody()); | ||
|
||
if (! str_starts_with($line, 'data:')) { | ||
continue; | ||
} | ||
$data = trim(substr($line, strlen('data:'))); | ||
if (str_starts_with('[DONE]', $data)) { | ||
break; | ||
} | ||
$content = json_decode($data, true); | ||
if (json_last_error() !== JSON_ERROR_NONE) { | ||
throw new RuntimeException('Invalid JSON response | ' . $line); | ||
} | ||
if (isset($content['error'])) { | ||
throw new RuntimeException('Steam Error | ' . $content['error']); | ||
} | ||
$this->setId($content['id'] ?? null); | ||
$this->setObject($content['object'] ?? null); | ||
$this->setCreated($content['created'] ?? null); | ||
$this->setModel($content['model'] ?? null); | ||
if (empty($content['choices'])) { | ||
continue; | ||
} | ||
foreach ($content['choices'] as $choice) { | ||
yield ChatCompletionChoice::fromArray($choice); | ||
} | ||
} | ||
} | ||
|
||
public function setOriginResponse(PsrResponseInterface $originResponse): static | ||
{ | ||
$this->originResponse = $originResponse; | ||
$this->success = $originResponse->getStatusCode() === 200; | ||
$this->parseContent(); | ||
return $this; | ||
} | ||
|
||
protected function parseContent(): static | ||
{ | ||
if ($this->originResponse->hasHeader('Transfer-Encoding') | ||
&& $this->originResponse->getHeaderLine('Transfer-Encoding') === 'chunked') { | ||
$this->isChunked = true; | ||
return $this; | ||
} | ||
$this->content = $this->originResponse->getBody()->getContents(); | ||
$content = json_decode($this->content, true); | ||
if (isset($content['id'])) { | ||
$this->setId($content['id']); | ||
|
@@ -137,4 +197,18 @@ protected function buildChoices(mixed $choices): array | |
return $result; | ||
} | ||
|
||
private function readLine(StreamInterface $stream): string | ||
{ | ||
$buffer = ''; | ||
while (! $stream->eof()) { | ||
if ('' === ($byte = $stream->read(1))) { | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 这代码性能爆炸 There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 不影响大局,暂时先这样吧 |
||
return $buffer; | ||
} | ||
$buffer .= $byte; | ||
if ($byte === "\n") { | ||
break; | ||
} | ||
} | ||
return $buffer; | ||
} | ||
} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
理论上应该想办法去掉这个 $this->model 才对
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
对,其实这个 model 没啥用,里面的 $clients 其实也只会有一个值