diff --git a/server/middleware/metric/metric.go b/server/middleware/metric/metric.go index 7f0dd72..ca7aec0 100644 --- a/server/middleware/metric/metric.go +++ b/server/middleware/metric/metric.go @@ -5,7 +5,6 @@ import ( "sort" "strings" "sync" - "sync/atomic" "time" "cloud.google.com/go/monitoring/apiv3/v2/monitoringpb" @@ -13,16 +12,6 @@ import ( "google.golang.org/protobuf/types/known/timestamppb" ) -func (c *CounterMetric) Increment(key any, i int64) int64 { - v, _ := c.LoadOrStore(key, new(atomic.Int64)) - ai, ok := v.(*atomic.Int64) - if !ok { - ai = new(atomic.Int64) - } - ai.Add(i) // Initialize and increment atomically - return ai.Load() -} - type customCounterKey struct { t string l string diff --git a/server/middleware/metric/metric_middleware.go b/server/middleware/metric/metric_middleware.go index 492c143..b4c81f0 100644 --- a/server/middleware/metric/metric_middleware.go +++ b/server/middleware/metric/metric_middleware.go @@ -56,12 +56,13 @@ func MetricsMiddleware(client *monitoring.MetricClient, config *config.Config) e // CounterMetric safely increments counters using concurrent maps and atomic operations. type CounterMetric struct{ sync.Map } -func (c *CounterMetric) increment(key any, i int64) int64 { - v, loaded := c.LoadOrStore(key, new(atomic.Int64)) - ai := v.(*atomic.Int64) - if !loaded { - ai.Add(i) // Initialize and increment atomically +func (c *CounterMetric) Increment(key any, i int64) int64 { + v, _ := c.LoadOrStore(key, new(atomic.Int64)) + ai, ok := v.(*atomic.Int64) + if !ok { + ai = new(atomic.Int64) } + ai.Add(i) // Initialize and increment atomically return ai.Load() } @@ -172,7 +173,7 @@ var reqCountMetric = CounterMetric{Map: sync.Map{}} // createRequestMetric constructs a cumulative metric for counting requests. func createRequestMetric(c echo.Context) *monitoringpb.TimeSeries { key := endpointMetricKeyFromEcho(c) - val := reqCountMetric.increment(key, 1) + val := reqCountMetric.Increment(key, 1) return &monitoringpb.TimeSeries{ Metric: &metricpb.Metric{ Type: MetricTypePrefix + "/request_count", @@ -204,7 +205,7 @@ func createErrorMetric(c echo.Context, err error) *monitoringpb.TimeSeries { } key := endpointMetricKeyFromEcho(c) - val := reqErrCountMetric.increment(key, 1) + val := reqErrCountMetric.Increment(key, 1) return &monitoringpb.TimeSeries{ Metric: &metricpb.Metric{ Type: MetricTypePrefix + "/request_errors", diff --git a/server/middleware/metric/metric_test.go b/server/middleware/metric/metric_test.go new file mode 100644 index 0000000..c15fb2b --- /dev/null +++ b/server/middleware/metric/metric_test.go @@ -0,0 +1,39 @@ +package metric + +import ( + "sync" + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestCounterMetric(t *testing.T) { + m := CounterMetric{Map: sync.Map{}} + count := 1000 + expexted := 0 + for i := 0; i < count; i++ { + expexted += i + } + expexted += 1 + + wg := sync.WaitGroup{} + wg.Add(count * 2) + for i := 0; i < count; i++ { + i := i + go func() { + defer wg.Done() + m.Increment("test", int64(i)) + }() + go func() { + defer wg.Done() + m.Increment("test1", int64(i)) + }() + } + wg.Wait() + v := m.Increment("test", 1) + v1 := m.Increment("test1", 1) + + assert.Equal(t, int64(expexted), v) + assert.Equal(t, int64(expexted), v1) + +}