diff --git a/clickhouse.go b/clickhouse.go index 10c80e6..17caec2 100644 --- a/clickhouse.go +++ b/clickhouse.go @@ -1,6 +1,7 @@ package main import ( + "bytes" "crypto/tls" "errors" "fmt" @@ -21,6 +22,7 @@ type ClickhouseServer struct { Bad bool Client *http.Client LogQueries bool + Credentials *Credentials } // Clickhouse - main clickhouse sender object @@ -33,15 +35,15 @@ type Clickhouse struct { Dumper Dumper wg sync.WaitGroup Transport *http.Transport + Credentials *Credentials } // ClickhouseRequest - request struct for queue type ClickhouseRequest struct { - Params string - Query string - Content string - Count int - isInsert bool + Params string + Query string + Content string + Count int } // ErrServerIsDown - signals about server is down @@ -52,7 +54,7 @@ var ErrNoServers = errors.New("No working clickhouse servers") // NewClickhouse - get clickhouse object func NewClickhouse(downTimeout int, connectTimeout int, tlsServerName string, tlsSkipVerify bool) (c *Clickhouse) { - tlsConfig := &tls.Config{} + tlsConfig := new(tls.Config) if tlsServerName != "" { tlsConfig.ServerName = tlsServerName } @@ -60,21 +62,32 @@ func NewClickhouse(downTimeout int, connectTimeout int, tlsServerName string, tl tlsConfig.InsecureSkipVerify = tlsSkipVerify } - c = new(Clickhouse) - c.DownTimeout = downTimeout - c.ConnectTimeout = connectTimeout - if c.ConnectTimeout <= 0 { - c.ConnectTimeout = 10 + if connectTimeout <= 0 { + connectTimeout = 10 } - c.Servers = make([]*ClickhouseServer, 0) - c.Queue = queue.New(1000) - c.Transport = &http.Transport{ - TLSClientConfig: tlsConfig, + + c = &Clickhouse{ + DownTimeout: downTimeout, + ConnectTimeout: connectTimeout, + Servers: make([]*ClickhouseServer, 0), + Queue: queue.New(1000), + Transport: &http.Transport{ + TLSClientConfig: tlsConfig, + }, + Credentials: &Credentials{ + User: "default", + Pass: "", + }, } + go c.Run() return c } +func (click *Clickhouse) SetCreds(creds *Credentials) { + click.Credentials = creds +} + // AddServer - add clickhouse server url func (c *Clickhouse) AddServer(url string, logQueries bool) { c.mu.Lock() @@ -125,6 +138,8 @@ func (c *Clickhouse) GetNextServer() (srv *ClickhouseServer) { if srv != nil { srv.LastRequest = time.Now() } + + srv.Credentials = c.Credentials return srv } @@ -190,29 +205,39 @@ func (c *Clickhouse) WaitFlush() (err error) { // SendQuery - sends query to server and return result func (srv *ClickhouseServer) SendQuery(r *ClickhouseRequest) (response string, status int, err error) { if srv.URL != "" { + log.Printf("INFO: sending %+v rows to %+v\n", r.Count, srv.URL) + if cnf.Debug { + log.Printf("DEBUG: query %+v\n", r.Query) + } + url := srv.URL if r.Params != "" { url += "?" + r.Params } - if r.isInsert && srv.LogQueries { - log.Printf("INFO: sending %+v rows to %+v of %+v\n", r.Count, srv.URL, r.Query) - } - resp, err := srv.Client.Post(url, "text/plain", strings.NewReader(r.Content)) + + conn := srv.Client + req, _ := http.NewRequest("POST", url, strings.NewReader(r.Content)) + req.Header.Add("X-ClickHouse-User", srv.Credentials.User) + req.Header.Add("X-ClickHouse-Key", srv.Credentials.Pass) + resp, err := conn.Do(req) if err != nil { srv.Bad = true return err.Error(), http.StatusBadGateway, ErrServerIsDown } - if r.isInsert && srv.LogQueries { - log.Printf("INFO: sent %+v rows to %+v of %+v\n", r.Count, srv.URL, r.Query) - } + defer resp.Body.Close() buf, _ := ioutil.ReadAll(resp.Body) s := string(buf) if resp.StatusCode >= 502 { srv.Bad = true err = ErrServerIsDown } else if resp.StatusCode >= 400 { - err = fmt.Errorf("Wrong server status %+v:\nresponse: %+v\nrequest: %#v", resp.StatusCode, s, r.Content) + err = fmt.Errorf("ERROR: Wrong server status %+v:\nresponse: %+v\n", resp.StatusCode, s) + if cnf.Debug { + err = fmt.Errorf("ERROR: Wrong server status %+v:\nresponse: %+v\nRequest: %#v\n", resp.StatusCode, s, r.Content) + } } + + log.Printf("INFO: sent %+v rows to %+v\n", r.Count, srv.URL) return s, resp.StatusCode, err } @@ -234,3 +259,26 @@ func (c *Clickhouse) SendQuery(r *ClickhouseRequest) (response string, status in return "", http.StatusServiceUnavailable, ErrNoServers } } + +func (c *Clickhouse) PassThru(req *http.Request, clientReqBody []byte) (res *http.Response, buf *bytes.Buffer) { + for { + s := c.GetNextServer() + if s != nil { + reqBuf := bytes.NewBuffer(clientReqBody) + + clickReq, _ := http.NewRequest(req.Method, s.URL, reqBuf) + + CopyHeader(clickReq.Header, req.Header) + res, err := s.Client.Do(clickReq) + if errors.Is(err, ErrServerIsDown) { + log.Printf("ERROR: server down (%+v): %+v\n", res.Status, res) + continue + } + + resBody, _ := ioutil.ReadAll(res.Body) + defer res.Body.Close() + + return res, bytes.NewBuffer(resBody) + } + } +} diff --git a/collector.go b/collector.go index cff5db6..986b3d3 100644 --- a/collector.go +++ b/collector.go @@ -72,7 +72,7 @@ func NewCollector(sender Sender, count int, interval int, cleanInterval int, rem // Content - get text content of rowsfor query func (t *Table) Content() string { rowDelimiter := "\n" - if t.Format == "RowBinary" { + if strings.HasPrefix(t.Format, "RowBinary") { rowDelimiter = "" } return t.Query + "\n" + strings.Join(t.Rows, rowDelimiter) @@ -81,11 +81,10 @@ func (t *Table) Content() string { // Flush - sends collected data in table to clickhouse func (t *Table) Flush() { req := ClickhouseRequest{ - Params: t.Params, - Query: t.Query, - Content: t.Content(), - Count: len(t.Rows), - isInsert: true, + Params: t.Params, + Query: t.Query, + Content: t.Content(), + Count: len(t.Rows), } t.Sender.Send(&req) t.Rows = make([]string, 0, t.FlushCount) diff --git a/dump.go b/dump.go index 0f82d3f..36050ce 100644 --- a/dump.go +++ b/dump.go @@ -162,7 +162,13 @@ func (d *FileDumper) ProcessNextDump(sender Sender) error { query = lines[1] data = strings.Join(lines[1:], "\n") } - _, status, err := sender.SendQuery(&ClickhouseRequest{Params: params, Content: data, Query: query, Count: len(lines[2:]), isInsert: true}) + cr := &ClickhouseRequest{ + Params: params, + Content: data, + Query: query, + Count: len(lines[2:]), + } + _, status, err := sender.SendQuery(cr) if err != nil { return fmt.Errorf("server error (%+v) %+v", status, err) } diff --git a/main.go b/main.go index 62f5e88..09e1b5e 100644 --- a/main.go +++ b/main.go @@ -23,9 +23,9 @@ func main() { return } - cnf, err := ReadConfig(*configFile) + err := ReadConfig(*configFile) if err != nil { log.Fatalf("ERROR: %+v\n", err) } - RunServer(cnf) + RunServer() } diff --git a/sender.go b/sender.go index 8cb082f..f399f51 100644 --- a/sender.go +++ b/sender.go @@ -1,6 +1,7 @@ package main import ( + "bytes" "log" "net/http" "sync" @@ -10,9 +11,11 @@ import ( type Sender interface { Send(r *ClickhouseRequest) SendQuery(r *ClickhouseRequest) (response string, status int, err error) + PassThru(req *http.Request, clientReqBody []byte) (res *http.Response, buf *bytes.Buffer) Len() int64 Empty() bool WaitFlush() (err error) + SetCreds(c *Credentials) } type fakeSender struct { @@ -32,6 +35,9 @@ func (s *fakeSender) SendQuery(r *ClickhouseRequest) (response string, status in log.Printf("DEBUG: send query: %+v\n", s.sendQueryHistory) return "", http.StatusOK, nil } +func (c *fakeSender) PassThru(req *http.Request, clientReqBody []byte) (res *http.Response, buf *bytes.Buffer) { + return +} func (s *fakeSender) Len() int64 { return 0 @@ -44,3 +50,6 @@ func (s *fakeSender) Empty() bool { func (s *fakeSender) WaitFlush() error { return nil } + +func (s *fakeSender) SetCreds(c *Credentials) { +} diff --git a/server.go b/server.go index df45553..872b87a 100644 --- a/server.go +++ b/server.go @@ -42,7 +42,8 @@ func NewServer(listen string, collector *Collector, debug bool, logQueries bool) } func (server *Server) writeHandler(c echo.Context) error { - q, _ := ioutil.ReadAll(c.Request().Body) + req := c.Request() + q, _ := ioutil.ReadAll(req.Body) s := string(q) if server.Debug { @@ -50,14 +51,8 @@ func (server *Server) writeHandler(c echo.Context) error { } qs := c.QueryString() - user, password, ok := c.Request().BasicAuth() - if ok { - if qs == "" { - qs = "user=" + user + "&password=" + password - } else { - qs = "user=" + user + "&password=" + password + "&" + qs - } - } + server.Collector.Sender.SetCreds(getAuth(req)) + params, content, insert := server.Collector.ParseQuery(qs, s) if insert { if len(content) == 0 { @@ -67,8 +62,14 @@ func (server *Server) writeHandler(c echo.Context) error { go server.Collector.Push(params, content) return c.String(http.StatusOK, "") } - resp, status, _ := server.Collector.Sender.SendQuery(&ClickhouseRequest{Params: qs, Content: s, isInsert: false}) - return c.String(status, resp) + + res, buf := server.Collector.Sender.PassThru(req, q) + + defer res.Body.Close() + CopyHeader(c.Response().Header(), res.Header) + c.Response().WriteHeader(res.StatusCode) + c.Response().Header().Set("Collection", "close") + return c.Stream(200, "application/octet-stream", buf) } func (server *Server) statusHandler(c echo.Context) error { @@ -100,7 +101,7 @@ func (server *Server) tablesCleanHandler(c echo.Context) error { } // Start - start http server -func (server *Server) Start(cnf Config) error { +func (server *Server) Start() error { if cnf.UseTLS { return server.echo.StartTLS(server.Listen, cnf.TLSCertFile, cnf.TLSKeyFile) } else { @@ -141,7 +142,7 @@ func SafeQuit(collect *Collector, sender Sender) { } // RunServer - run all -func RunServer(cnf Config) { +func RunServer() { InitMetrics(cnf.MetricsPrefix) dumper := NewDumper(cnf.DumpDir) sender := NewClickhouse(cnf.Clickhouse.DownTimeout, cnf.Clickhouse.ConnectTimeout, cnf.Clickhouse.tlsServerName, cnf.Clickhouse.tlsSkipVerify) @@ -176,7 +177,7 @@ func RunServer(cnf Config) { dumper.Listen(sender, cnf.DumpCheckInterval) } - err := srv.Start(cnf) + err := srv.Start() if err != nil { log.Printf("ListenAndServe: %+v\n", err) SafeQuit(collect, sender) diff --git a/server_test.go b/server_test.go index c383f7c..2156b9d 100644 --- a/server_test.go +++ b/server_test.go @@ -18,10 +18,10 @@ import ( ) func TestRunServer(t *testing.T) { - cnf, _ := ReadConfig("wrong_config.json") + _ = ReadConfig("wrong_config.json") collector := NewCollector(&fakeSender{}, 1000, 1000, 0, true) server := InitServer("", collector, false, true) - go server.Start(cnf) + go server.Start() server.echo.POST("/", server.writeHandler) status, resp := request("POST", "/", "", server.echo) @@ -125,11 +125,11 @@ func TestServer_MultiServer(t *testing.T) { assert.True(t, sender.Empty()) os.Setenv("DUMP_CHECK_INTERVAL", "10") - cnf, err := ReadConfig("wrong_config.json") + err := ReadConfig("wrong_config.json") os.Unsetenv("DUMP_CHECK_INTERVAL") assert.Nil(t, err) assert.Equal(t, 10, cnf.DumpCheckInterval) - go RunServer(cnf) + go RunServer() time.Sleep(1000) } diff --git a/utils.go b/utils.go index 0ee6a6b..e9478d5 100644 --- a/utils.go +++ b/utils.go @@ -3,6 +3,7 @@ package main import ( "encoding/json" "log" + "net/http" "os" "strconv" "strings" @@ -10,6 +11,8 @@ import ( const sampleConfig = "config.sample.json" +var cnf Config + type clickhouseConfig struct { Servers []string `json:"servers"` tlsServerName string `json:"tls_server_name"` @@ -29,11 +32,16 @@ type Config struct { DumpCheckInterval int `json:"dump_check_interval"` DumpDir string `json:"dump_dir"` Debug bool `json:"debug"` - LogQueries bool `json:"log_queries"` + LogQueries bool `json:"log_queries"` MetricsPrefix string `json:"metrics_prefix"` UseTLS bool `json:"use_tls"` - TLSCertFile string `json:"tls_cert_file"` - TLSKeyFile string `json:"tls_key_file"` + TLSCertFile string `json:"tls_cert_file"` + TLSKeyFile string `json:"tls_key_file"` +} + +type Credentials struct { + User string + Pass string } // ReadJSON - read json file to struct @@ -82,8 +90,8 @@ func readEnvString(name string, value *string) { } // ReadConfig init config data -func ReadConfig(configFile string) (Config, error) { - cnf := Config{} +func ReadConfig(configFile string) error { + cnf = Config{} err := ReadJSON(configFile, &cnf) if err != nil { log.Printf("INFO: Config file %+v not found. Used%+v\n", configFile, sampleConfig) @@ -116,5 +124,49 @@ func ReadConfig(configFile string) (Config, error) { cnf.Clickhouse.tlsServerName = tlsServerName } - return cnf, err + return err +} + +// getAuth retrieves auth credentials from request +// according to CH documentation @see "https://clickhouse.yandex/docs/en/interfaces/http/" +func getAuth(req *http.Request) *Credentials { + // check X-ClickHouse- headers + name := req.Header.Get("X-ClickHouse-User") + pass := req.Header.Get("X-ClickHouse-Key") + if name != "" { + return &Credentials{ + User: name, + Pass: pass, + } + } + // if header is empty - check basicAuth + if name, pass, ok := req.BasicAuth(); ok { + return &Credentials{ + User: name, + Pass: pass, + } + } + // if basicAuth is empty - check URL params `user` and `password` + params := req.URL.Query() + if name := params.Get("user"); name != "" { + pass := params.Get("password") + return &Credentials{ + User: name, + Pass: pass, + } + } + // if still no credentials - treat it as `default` user request + return &Credentials{ + User: "default", + Pass: "", + } +} + +func CopyHeader(dst, src http.Header) { + for k, vv := range src { + for _, v := range vv { + dst.Add(k, v) + } + } + dst.Add("Connection", "Close") }