diff --git a/multiple.go b/multiple.go index 35130fe..a961d53 100644 --- a/multiple.go +++ b/multiple.go @@ -19,6 +19,7 @@ package sqlx import ( "context" "database/sql" + "strconv" "strings" "github.com/zeromicro/go-zero/core/stores/sqlx" @@ -40,6 +41,11 @@ var ( var _ sqlx.SqlConn = (*multipleSqlConn)(nil) type ( + FollowerDB struct { + name string + datasource string + Added bool + } DBConf struct { Leader string Followers []string `json:",optional"` @@ -48,16 +54,15 @@ type ( SqlOption func(*sqlOptions) sqlOptions struct { - accept func(error) bool + accept func(error) bool + watcher <-chan FollowerDB } multipleSqlConn struct { - leader sqlx.SqlConn - enableFollower bool - p2cPicker picker // picker - followers []sqlx.SqlConn - conf DBConf - accept func(error) bool - driveName string + leader sqlx.SqlConn + p2cPicker picker // picker + conf DBConf + driveName string + sqlOptions *sqlOptions } ) @@ -69,21 +74,36 @@ func NewMultipleSqlConn(driverName string, conf DBConf, opts ...SqlOption) sqlx. } leader := sqlx.NewSqlConn(driverName, conf.Leader, sqlx.WithAcceptable(sqlOpt.accept)) - followers := make([]sqlx.SqlConn, 0, len(conf.Followers)) - for _, datasource := range conf.Followers { - followers = append(followers, sqlx.NewSqlConn(driverName, datasource, sqlx.WithAcceptable(sqlOpt.accept))) - } conn := &multipleSqlConn{ - leader: leader, - enableFollower: len(followers) != 0, - followers: followers, - conf: conf, - driveName: driverName, - accept: sqlOpt.accept, + leader: leader, + conf: conf, + driveName: driverName, + sqlOptions: &sqlOpt, } - conn.p2cPicker = newP2cPicker(followers, conn.accept) + p2cPickerObj := newP2cPicker(driverName, sqlOpt.accept) + go func() { + if sqlOpt.watcher == nil { + return + } + + for { + select { + case follow := <-sqlOpt.watcher: + if follow.Added { + p2cPickerObj.add(follow.name, follow.datasource) + } else { + p2cPickerObj.del(follow.name) + } + } + } + }() + + for i, datasource := range conf.Followers { + p2cPickerObj.add(strconv.Itoa(i), datasource) + } + conn.p2cPicker = p2cPickerObj return conn } @@ -178,10 +198,6 @@ func (m *multipleSqlConn) getQueryDB(ctx context.Context, query string) queryDB return queryDB{conn: m.leader} } - if !m.enableFollower { - return queryDB{conn: m.leader} - } - if !m.containSelect(query) { return queryDB{conn: m.leader} } @@ -212,10 +228,10 @@ func (m *multipleSqlConn) startSpanWithLeader(ctx context.Context) (context.Cont return ctx, span } -func (m *multipleSqlConn) startSpanWithFollower(ctx context.Context, db int) (context.Context, oteltrace.Span) { +func (m *multipleSqlConn) startSpanWithFollower(ctx context.Context, dbName string) (context.Context, oteltrace.Span) { ctx, span := m.startSpan(ctx) span.SetAttributes(followerTypeAttributeKey) - span.SetAttributes(followerDBSqlAttributeKey.Int(db)) + span.SetAttributes(followerDBSqlAttributeKey.String(dbName)) return ctx, span } @@ -239,7 +255,7 @@ type queryDB struct { error error done func(err error) follower bool - followerDB int + followerDB string } func (q *queryDB) query(ctx context.Context, query func(ctx context.Context, conn sqlx.SqlConn) error) (err error) { @@ -261,6 +277,12 @@ func WithAccept(accept func(err error) bool) SqlOption { } } +func WithWatchFollowerDB(watcher <-chan FollowerDB) SqlOption { + return func(opts *sqlOptions) { + opts.watcher = watcher + } +} + type forceLeaderKey struct{} func ForceLeaderContext(ctx context.Context) context.Context { diff --git a/picker.go b/picker.go index 22ceeff..9ca1f6b 100644 --- a/picker.go +++ b/picker.go @@ -53,57 +53,78 @@ type ( pickResult struct { conn sqlx.SqlConn done func(err error) - followerDB int + followerDB string } p2cPicker struct { - conns []*subConn - r *rand.Rand - stamp *syncx.AtomicDuration - lock sync.Mutex - accept func(err error) bool + r *rand.Rand + stamp *syncx.AtomicDuration + accept func(err error) bool + driverName string + + connsMap map[string]*subConn + lock sync.Mutex } ) -func newP2cPicker(followers []sqlx.SqlConn, accept func(err error) bool) *p2cPicker { - conns := make([]*subConn, 0, len(followers)) - for i, follower := range followers { - conns = append(conns, &subConn{ - success: initSuccess, - db: i, - conn: follower, - }) +func newP2cPicker(driverName string, accept func(err error) bool) *p2cPicker { + return &p2cPicker{ + r: rand.New(rand.NewSource(time.Now().UnixNano())), + stamp: syncx.NewAtomicDuration(), + accept: accept, + connsMap: map[string]*subConn{}, + driverName: driverName, } +} - return &p2cPicker{ - conns: conns, - r: rand.New(rand.NewSource(time.Now().UnixNano())), - stamp: syncx.NewAtomicDuration(), - accept: accept, +func (p *p2cPicker) del(name string) { + p.lock.Lock() + defer p.lock.Unlock() + p.connsMap[name] = nil + delete(p.connsMap, name) +} + +func (p *p2cPicker) add(name, dns string) { + p.lock.Lock() + defer p.lock.Unlock() + p.connsMap[name] = newSubConn(p.driverName, name, dns) +} + +func (p *p2cPicker) getSubConns() []*subConn { + p.lock.Lock() + defer p.lock.Unlock() + conns := make([]*subConn, 0, len(p.connsMap)) + for _, conn := range p.connsMap { + if conn != nil { + conns = append(conns, conn) + } } + + return conns } func (p *p2cPicker) pick() (*pickResult, error) { p.lock.Lock() defer p.lock.Unlock() + conns := p.getSubConns() var chosen *subConn - switch len(p.conns) { + switch len(conns) { case 0: return nil, ErrNoFollowerAvailable case 1: - chosen = p.choose(p.conns[0], nil) + chosen = p.choose(conns[0], nil) case 2: - chosen = p.choose(p.conns[0], p.conns[1]) + chosen = p.choose(conns[0], conns[1]) default: var node1, node2 *subConn for i := 0; i < pickTimes; i++ { - a := p.r.Intn(len(p.conns)) - b := p.r.Intn(len(p.conns) - 1) + a := p.r.Intn(len(conns)) + b := p.r.Intn(len(conns) - 1) if b >= a { b++ } - node1 = p.conns[a] - node2 = p.conns[b] + node1 = conns[a] + node2 = conns[b] if node1.healthy() && node2.healthy() { break } @@ -118,7 +139,7 @@ func (p *p2cPicker) pick() (*pickResult, error) { return &pickResult{ conn: chosen.conn, done: p.buildDoneFunc(chosen), - followerDB: chosen.db, + followerDB: chosen.name, }, nil } @@ -191,10 +212,11 @@ func (p *p2cPicker) choose(c1, c2 *subConn) *subConn { func (p *p2cPicker) logStats() { p.lock.Lock() defer p.lock.Unlock() - stats := make([]string, 0, len(p.conns)) - for _, conn := range p.conns { - stats = append(stats, fmt.Sprintf("db: %d, load: %d, reqs: %d", - conn.db, conn.load(), atomic.SwapInt64(&conn.requests, 0))) + conns := p.getSubConns() + stats := make([]string, 0, len(conns)) + for _, conn := range conns { + stats = append(stats, fmt.Sprintf("db: %s, load: %d, reqs: %d", + conn.name, conn.load(), atomic.SwapInt64(&conn.requests, 0))) } logx.Statf("follower db - p2c - %s", strings.Join(stats, "; ")) @@ -207,7 +229,7 @@ type subConn struct { requests int64 // 用来保存请求总数 last int64 // 用来保存上一次请求耗时, 用于计算 ewma 值 pick int64 // 保存上一次被选中的时间点 - db int + name string conn sqlx.SqlConn } @@ -229,10 +251,18 @@ func (c *subConn) load() int64 { } func (p *p2cPicker) acceptable(err error) bool { - ok := err == nil || err == sql.ErrNoRows || err == sql.ErrTxDone || err == context.Canceled + ok := err == nil || errors.Is(err, sql.ErrNoRows) || errors.Is(err, sql.ErrTxDone) || errors.Is(err, context.Canceled) if p.accept == nil { return ok } return ok || p.accept(err) } + +func newSubConn(driverName, name, datasource string) *subConn { + return &subConn{ + success: initSuccess, + name: name, + conn: sqlx.NewSqlConn(driverName, datasource), + } +}