generated from milosgajdos/go-repo-template
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feature: Add support for VoyageAI embeddings
See: https://docs.voyageai.com/reference/embeddings-api Signed-off-by: Milos Gajdos <[email protected]>
- Loading branch information
1 parent
817f998
commit 6c2eea1
Showing
10 changed files
with
452 additions
and
32 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -13,3 +13,4 @@ | |
|
||
# Dependency directories (remove the comment below to include it) | ||
# vendor/ | ||
.env |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,44 @@ | ||
package main | ||
|
||
import ( | ||
"context" | ||
"flag" | ||
"fmt" | ||
"log" | ||
|
||
"github.com/milosgajdos/go-embeddings/voyage" | ||
) | ||
|
||
var ( | ||
input string | ||
model string | ||
truncation bool | ||
inputType string | ||
) | ||
|
||
func init() { | ||
flag.StringVar(&input, "input", "what is life", "input data") | ||
flag.StringVar(&model, "model", voyage.VoyageV2.String(), "model name") | ||
flag.StringVar(&inputType, "input-type", voyage.DocInput.String(), "input type") | ||
flag.BoolVar(&truncation, "truncate", false, "truncate type") | ||
} | ||
|
||
func main() { | ||
flag.Parse() | ||
|
||
c := voyage.NewClient() | ||
|
||
embReq := &voyage.EmbeddingRequest{ | ||
Input: []string{input}, | ||
Model: voyage.Model(model), | ||
InputType: voyage.InputType(inputType), | ||
Truncation: truncation, | ||
} | ||
|
||
embs, err := c.Embed(context.Background(), embReq) | ||
if err != nil { | ||
log.Fatal(err) | ||
} | ||
|
||
fmt.Printf("got %d embeddings", len(embs)) | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,85 @@ | ||
package voyage | ||
|
||
import ( | ||
"os" | ||
|
||
"github.com/milosgajdos/go-embeddings" | ||
"github.com/milosgajdos/go-embeddings/client" | ||
) | ||
|
||
const ( | ||
// BaseURL is VoyageAI HTTP API base URL. | ||
BaseURL = "https://api.voyageai.com" | ||
// EmbedAPIVersion is the latest stable embedding API version. | ||
EmbedAPIVersion = "v1" | ||
) | ||
|
||
// Client is Voyage HTTP API client. | ||
type Client struct { | ||
opts Options | ||
} | ||
|
||
// Options are client options | ||
type Options struct { | ||
APIKey string | ||
BaseURL string | ||
Version string | ||
HTTPClient *client.HTTP | ||
} | ||
|
||
// Option is functional graph option. | ||
type Option func(*Options) | ||
|
||
// NewClient creates a new HTTP API client and returns it. | ||
// By default it reads the Voyage API key from VOYAGE_API_KEY | ||
// env var and uses the default Go http.Client for making API requests. | ||
// You can override the default options via the client methods. | ||
func NewClient(opts ...Option) *Client { | ||
options := Options{ | ||
APIKey: os.Getenv("VOYAGE_API_KEY"), | ||
BaseURL: BaseURL, | ||
Version: EmbedAPIVersion, | ||
HTTPClient: client.NewHTTP(), | ||
} | ||
|
||
for _, apply := range opts { | ||
apply(&options) | ||
} | ||
|
||
return &Client{ | ||
opts: options, | ||
} | ||
} | ||
|
||
// NewEmbedder creates a client that implements embeddings.Embedder | ||
func NewEmbedder(opts ...Option) embeddings.Embedder[*EmbeddingRequest] { | ||
return NewClient(opts...) | ||
} | ||
|
||
// WithAPIKey sets the API key. | ||
func WithAPIKey(apiKey string) Option { | ||
return func(o *Options) { | ||
o.APIKey = apiKey | ||
} | ||
} | ||
|
||
// WithBaseURL sets the API base URL. | ||
func WithBaseURL(baseURL string) Option { | ||
return func(o *Options) { | ||
o.BaseURL = baseURL | ||
} | ||
} | ||
|
||
// WithVersion sets the API version. | ||
func WithVersion(version string) Option { | ||
return func(o *Options) { | ||
o.Version = version | ||
} | ||
} | ||
|
||
// WithHTTPClient sets the HTTP client. | ||
func WithHTTPClient(httpClient *client.HTTP) Option { | ||
return func(o *Options) { | ||
o.HTTPClient = httpClient | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,52 @@ | ||
package voyage | ||
|
||
import ( | ||
"testing" | ||
|
||
"github.com/milosgajdos/go-embeddings/client" | ||
"github.com/stretchr/testify/assert" | ||
) | ||
|
||
const ( | ||
voyageAPIKey = "somekey" | ||
) | ||
|
||
func TestClient(t *testing.T) { | ||
t.Setenv("VOYAGE_API_KEY", voyageAPIKey) | ||
|
||
t.Run("API key", func(t *testing.T) { | ||
c := NewClient() | ||
assert.Equal(t, c.opts.APIKey, voyageAPIKey) | ||
|
||
testVal := "foo" | ||
c = NewClient(WithAPIKey(testVal)) | ||
assert.Equal(t, c.opts.APIKey, testVal) | ||
}) | ||
|
||
t.Run("BaseURL", func(t *testing.T) { | ||
c := NewClient() | ||
assert.Equal(t, c.opts.BaseURL, BaseURL) | ||
|
||
testVal := "http://foo" | ||
c = NewClient(WithBaseURL(testVal)) | ||
assert.Equal(t, c.opts.BaseURL, testVal) | ||
}) | ||
|
||
t.Run("version", func(t *testing.T) { | ||
c := NewClient() | ||
assert.Equal(t, c.opts.Version, EmbedAPIVersion) | ||
|
||
testVal := "v3" | ||
c = NewClient(WithVersion(testVal)) | ||
assert.Equal(t, c.opts.Version, testVal) | ||
}) | ||
|
||
t.Run("http client", func(t *testing.T) { | ||
c := NewClient() | ||
assert.NotNil(t, c.opts.HTTPClient) | ||
|
||
testVal := client.NewHTTP() | ||
c = NewClient(WithHTTPClient(testVal)) | ||
assert.NotNil(t, c.opts.HTTPClient) | ||
}) | ||
} |
Oops, something went wrong.