From 2c5465e5ab7dd7c101364fb953a6e4100090ae88 Mon Sep 17 00:00:00 2001 From: Milos Gajdos Date: Thu, 30 Nov 2023 20:22:00 +0000 Subject: [PATCH] Add hugging face API but it's a total mess Signed-off-by: Milos Gajdos --- cmd/huggingface/main.go | 49 +++++++++++++++++++++++ huggingface/client.go | 74 ++++++++++++++++++++++++++++++++++ huggingface/client_test.go | 52 ++++++++++++++++++++++++ huggingface/embedding.go | 81 ++++++++++++++++++++++++++++++++++++++ huggingface/error.go | 17 ++++++++ 5 files changed, 273 insertions(+) create mode 100644 cmd/huggingface/main.go create mode 100644 huggingface/client.go create mode 100644 huggingface/client_test.go create mode 100644 huggingface/embedding.go create mode 100644 huggingface/error.go diff --git a/cmd/huggingface/main.go b/cmd/huggingface/main.go new file mode 100644 index 0000000..2c6aca6 --- /dev/null +++ b/cmd/huggingface/main.go @@ -0,0 +1,49 @@ +package main + +import ( + "context" + "flag" + "fmt" + "log" + + "github.com/milosgajdos/go-embeddings/cohere" + "github.com/milosgajdos/go-embeddings/huggingface" +) + +var ( + input string + model string + wait bool +) + +func init() { + flag.StringVar(&input, "input", "what is life", "input data") + flag.StringVar(&model, "model", string(cohere.EnglishV3), "model name") + flag.BoolVar(&wait, "wait", false, "wait for model to start") +} + +func main() { + flag.Parse() + + c := huggingface.NewClient(). + WithModel(model) + + embReq := &huggingface.EmbeddingRequest{ + Inputs: []string{input}, + Options: huggingface.Options{ + WaitForModel: &wait, + }, + } + + embResp, err := c.Embeddings(context.Background(), embReq) + if err != nil { + log.Fatal(err) + } + + embs, err := huggingface.ToEmbeddings(embResp) + if err != nil { + log.Fatal(err) + } + + fmt.Printf("got %d embeddings", len(embs)) +} diff --git a/huggingface/client.go b/huggingface/client.go new file mode 100644 index 0000000..76d9105 --- /dev/null +++ b/huggingface/client.go @@ -0,0 +1,74 @@ +package huggingface + +import ( + "encoding/json" + "net/http" + "os" +) + +const ( + // BaseURL is Cohere HTTP API base URL. + BaseURL = "https://api-inference.huggingface.co/models" +) + +// Client is Cohere HTTP API client. +type Client struct { + apiKey string + baseURL string + model string + hc *http.Client +} + +// NewClient creates a new HTTP API client and returns it. +// By default it reads the Cohere API key from HUGGINGFACE_API_KEY +// env var and uses the default Go http.Client for making API requests. +// You can override the default options via the client methods. +func NewClient() *Client { + return &Client{ + apiKey: os.Getenv("HUGGINGFACE_API_KEY"), + baseURL: BaseURL, + hc: &http.Client{}, + } +} + +// WithAPIKey sets the API key. +func (c *Client) WithAPIKey(apiKey string) *Client { + c.apiKey = apiKey + return c +} + +// WithBaseURL sets the API base URL. +func (c *Client) WithBaseURL(baseURL string) *Client { + c.baseURL = baseURL + return c +} + +// WithModel sets the model name +func (c *Client) WithModel(model string) *Client { + c.model = model + return c +} + +// WithHTTPClient sets the HTTP client. +func (c *Client) WithHTTPClient(httpClient *http.Client) *Client { + c.hc = httpClient + return c +} + +func (c *Client) doRequest(req *http.Request) (*http.Response, error) { + resp, err := c.hc.Do(req) + if err != nil { + return nil, err + } + if resp.StatusCode >= http.StatusOK && resp.StatusCode < http.StatusBadRequest { + return resp, nil + } + defer resp.Body.Close() + + var apiErr APIError + if err := json.NewDecoder(resp.Body).Decode(&apiErr); err != nil { + return nil, err + } + + return nil, apiErr +} diff --git a/huggingface/client_test.go b/huggingface/client_test.go new file mode 100644 index 0000000..b441e2e --- /dev/null +++ b/huggingface/client_test.go @@ -0,0 +1,52 @@ +package huggingface + +import ( + "net/http" + "testing" + + "github.com/stretchr/testify/assert" +) + +const ( + huggingFaceKey = "somekey" +) + +func TestClient(t *testing.T) { + t.Setenv("HUGGINGFACE_API_KEY", huggingFaceKey) + + t.Run("API key", func(t *testing.T) { + c := NewClient() + assert.Equal(t, c.apiKey, huggingFaceKey) + + testVal := "foo" + c.WithAPIKey(testVal) + assert.Equal(t, c.apiKey, testVal) + }) + + t.Run("BaseURL", func(t *testing.T) { + c := NewClient() + assert.Equal(t, c.baseURL, BaseURL) + + testVal := "http://foo" + c.WithBaseURL(testVal) + assert.Equal(t, c.baseURL, testVal) + }) + + t.Run("Model", func(t *testing.T) { + c := NewClient() + assert.Equal(t, c.model, "") + + testVal := "foo/bar" + c.WithModel(testVal) + assert.Equal(t, c.model, testVal) + }) + + t.Run("http client", func(t *testing.T) { + c := NewClient() + assert.NotNil(t, c.hc) + + testVal := &http.Client{} + c.WithHTTPClient(testVal) + assert.NotNil(t, c.hc) + }) +} diff --git a/huggingface/embedding.go b/huggingface/embedding.go new file mode 100644 index 0000000..880bed4 --- /dev/null +++ b/huggingface/embedding.go @@ -0,0 +1,81 @@ +package huggingface + +import ( + "bytes" + "context" + "encoding/json" + "net/http" + "net/url" + + "github.com/milosgajdos/go-embeddings" + "github.com/milosgajdos/go-embeddings/request" +) + +// EmbeddingRequest sent to API endpoint. +type EmbeddingRequest struct { + Inputs []string `json:"inputs"` + Options Options `json:"options,omitempty"` +} + +// Options +type Options struct { + WaitForModel *bool `json:"wait_for_model,omitempty"` +} + +// EmbedddingResponse is returned by API. +// TODO: hugging face APIs are a mess +type EmbedddingResponse [][][][]float64 + +// ToEmbeddings converts the raw API response, +// parses it into a slice of embeddings and returns it. +func ToEmbeddings(e *EmbedddingResponse) ([]*embeddings.Embedding, error) { + emb := *e + embs := make([]*embeddings.Embedding, 0, len(emb)) + //for i := range emb { + // vals := emb[i] + // floats := make([]float64, len(vals)) + // copy(floats, vals) + // emb := &embeddings.Embedding{ + // Vector: floats, + // } + // embs = append(embs, emb) + //} + return embs, nil +} + +// Embeddings returns embeddings for every object in EmbeddingRequest. +func (c *Client) Embeddings(ctx context.Context, embReq *EmbeddingRequest) (*EmbedddingResponse, error) { + u, err := url.Parse(c.baseURL + "/" + c.model) + if err != nil { + return nil, err + } + + var body = &bytes.Buffer{} + enc := json.NewEncoder(body) + enc.SetEscapeHTML(false) + if err := enc.Encode(embReq); err != nil { + return nil, err + } + + options := []request.Option{ + request.WithBearer(c.apiKey), + } + + req, err := request.NewHTTP(ctx, http.MethodPost, u.String(), body, options...) + if err != nil { + return nil, err + } + + resp, err := c.doRequest(req) + if err != nil { + return nil, err + } + defer resp.Body.Close() + + e := new(EmbedddingResponse) + if err := json.NewDecoder(resp.Body).Decode(e); err != nil { + return nil, err + } + + return e, nil +} diff --git a/huggingface/error.go b/huggingface/error.go new file mode 100644 index 0000000..219d789 --- /dev/null +++ b/huggingface/error.go @@ -0,0 +1,17 @@ +package huggingface + +import "encoding/json" + +// APIError is error returned by API +type APIError struct { + Message string `json:"error"` +} + +// Error implements errors interface. +func (e APIError) Error() string { + b, err := json.Marshal(e) + if err != nil { + return "unknown error" + } + return string(b) +}