diff --git a/gateways/algolia/algolia.go b/gateways/algolia/algolia.go index 3c1e48a..39f48ce 100644 --- a/gateways/algolia/algolia.go +++ b/gateways/algolia/algolia.go @@ -7,6 +7,7 @@ import ( "registry-backend/ent" "registry-backend/entity" "registry-backend/mapper" + "registry-backend/tracing" "github.com/algolia/algoliasearch-client-go/v3/algolia/search" "github.com/rs/zerolog/log" @@ -49,6 +50,8 @@ func NewAlgoliaService(cfg *config.Config) (AlgoliaService, error) { // IndexNodes indexes the provided nodes in Algolia. func (a *algolia) IndexNodes(ctx context.Context, nodes ...*ent.Node) error { + defer tracing.TraceDefaultSegment(ctx, "AlgoliaService.IndexNodes")() + index := a.client.InitIndex("nodes_index") objects := make([]entity.AlgoliaNode, len(nodes)) @@ -66,6 +69,8 @@ func (a *algolia) IndexNodes(ctx context.Context, nodes ...*ent.Node) error { // SearchNodes searches for nodes in Algolia matching the query. func (a *algolia) SearchNodes(ctx context.Context, query string, opts ...interface{}) (nodes []*ent.Node, err error) { + defer tracing.TraceDefaultSegment(ctx, "AlgoliaService.SearchNodes")() + index := a.client.InitIndex("nodes_index") res, err := index.Search(query, opts...) if err != nil { @@ -84,6 +89,8 @@ func (a *algolia) SearchNodes(ctx context.Context, query string, opts ...interfa // DeleteNode deletes the specified node from Algolia. func (a *algolia) DeleteNode(ctx context.Context, node *ent.Node) error { + defer tracing.TraceDefaultSegment(ctx, "AlgoliaService.DeleteNode")() + index := a.client.InitIndex("nodes_index") res, err := index.DeleteObject(node.ID) if err != nil { @@ -94,6 +101,8 @@ func (a *algolia) DeleteNode(ctx context.Context, node *ent.Node) error { // IndexNodeVersions implements AlgoliaService. func (a *algolia) IndexNodeVersions(ctx context.Context, nodes ...*ent.NodeVersion) error { + defer tracing.TraceDefaultSegment(ctx, "AlgoliaService.IndexNodeVersions")() + index := a.client.InitIndex("node_versions_index") objects := make([]struct { ObjectID string `json:"objectID"` @@ -122,6 +131,8 @@ func (a *algolia) IndexNodeVersions(ctx context.Context, nodes ...*ent.NodeVersi // DeleteNodeVersions implements AlgoliaService. func (a *algolia) DeleteNodeVersions(ctx context.Context, nodes ...*ent.NodeVersion) error { + defer tracing.TraceDefaultSegment(ctx, "AlgoliaService.DeleteNodeVersions")() + index := a.client.InitIndex("node_versions_index") ids := []string{} for _, node := range nodes { @@ -136,6 +147,8 @@ func (a *algolia) DeleteNodeVersions(ctx context.Context, nodes ...*ent.NodeVers // SearchNodeVersions implements AlgoliaService. func (a *algolia) SearchNodeVersions(ctx context.Context, query string, opts ...interface{}) ([]*ent.NodeVersion, error) { + defer tracing.TraceDefaultSegment(ctx, "AlgoliaService.SearchNodeVersions")() + index := a.client.InitIndex("node_versions_index") res, err := index.Search(query, opts...) if err != nil { diff --git a/gateways/pubsub/pubsub.go b/gateways/pubsub/pubsub.go index 4c78128..6b675c2 100644 --- a/gateways/pubsub/pubsub.go +++ b/gateways/pubsub/pubsub.go @@ -6,6 +6,7 @@ import ( "fmt" "net/url" "registry-backend/config" + "registry-backend/tracing" "strconv" "strings" "time" @@ -47,6 +48,8 @@ func NewPubSubService(c *config.Config) (PubSubService, error) { // PublishNodePack implements PubSubService. func (p *pubsubimpl) PublishNodePack(ctx context.Context, storageURL string) (err error) { + defer tracing.TraceDefaultSegment(ctx, "PubSubService.PublishNodePack")() + u, err := url.Parse(storageURL) if err != nil { return fmt.Errorf("invalid storage URL: %w", err) diff --git a/gateways/storage/files.go b/gateways/storage/files.go index 54fe01b..4da562b 100644 --- a/gateways/storage/files.go +++ b/gateways/storage/files.go @@ -8,9 +8,11 @@ import ( "os" "time" + "registry-backend/config" + "registry-backend/tracing" + "cloud.google.com/go/storage" "github.com/rs/zerolog/log" - "registry-backend/config" ) // StorageService defines the interface for interacting with cloud storage. @@ -52,6 +54,8 @@ func NewStorageService(cfg *config.Config) (StorageService, error) { // UploadFile uploads an object to GCP storage. func (s *storageService) UploadFile(ctx context.Context, bucket, object, filePath string) (string, error) { + defer tracing.TraceDefaultSegment(ctx, "StorageService.UploadFile")() + log.Ctx(ctx).Info().Msgf("Uploading %v to %v/%v.\n", filePath, bucket, object) // Open local file @@ -120,6 +124,8 @@ func (s *storageService) StreamFileUpload(w io.Writer, objectName, blob string) // GetFileUrl gets the public URL of a file from GCP storage. func (s *storageService) GetFileUrl(ctx context.Context, bucketName, objectPath string) (string, error) { + defer tracing.TraceDefaultSegment(ctx, "StorageService.GetFileUrl")() + // Get the public URL of a file in a bucket attrs, err := s.client.Bucket(bucketName).Object(objectPath).Attrs(ctx) if err != nil { diff --git a/server/implementation/cicd.go b/server/implementation/cicd.go index 586971a..c17a814 100644 --- a/server/implementation/cicd.go +++ b/server/implementation/cicd.go @@ -8,6 +8,7 @@ import ( "registry-backend/ent/gitcommit" "registry-backend/ent/schema" "registry-backend/mapper" + "registry-backend/tracing" "sort" "strings" @@ -17,6 +18,8 @@ import ( ) func (impl *DripStrictServerImplementation) GetGitcommit(ctx context.Context, request drip.GetGitcommitRequestObject) (drip.GetGitcommitResponseObject, error) { + defer tracing.TraceDefaultSegment(ctx, "DripStrictServerImplementation.GetGitcommit")() + var commitId uuid.UUID = uuid.Nil if request.Params.CommitId != nil { log.Ctx(ctx).Info().Msgf("getting commit data for %s", *request.Params.CommitId) @@ -124,6 +127,7 @@ func (impl *DripStrictServerImplementation) GetGitcommit(ctx context.Context, re } func (impl *DripStrictServerImplementation) GetGitcommitsummary(ctx context.Context, request drip.GetGitcommitsummaryRequestObject) (drip.GetGitcommitsummaryResponseObject, error) { + defer tracing.TraceDefaultSegment(ctx, "DripStrictServerImplementation.GetGitcommitsummary")() log.Ctx(ctx).Info().Msg("Getting git commit summary") // Prep relevant vars @@ -254,6 +258,8 @@ func (impl *DripStrictServerImplementation) GetGitcommitsummary(ctx context.Cont } func (impl *DripStrictServerImplementation) GetWorkflowResult(ctx context.Context, request drip.GetWorkflowResultRequestObject) (drip.GetWorkflowResultResponseObject, error) { + defer tracing.TraceDefaultSegment(ctx, "DripStrictServerImplementation.GetWorkflowResult")() + log.Ctx(ctx).Info().Msgf("Getting workflow result with ID %s", request.WorkflowResultId) workflowId := uuid.MustParse(request.WorkflowResultId) workflow, err := impl.Client.CIWorkflowResult.Query().WithGitcommit().WithStorageFile().Where(ciworkflowresult.IDEQ(workflowId)).First(ctx) @@ -278,6 +284,8 @@ func (impl *DripStrictServerImplementation) GetWorkflowResult(ctx context.Contex } func (impl *DripStrictServerImplementation) GetBranch(ctx context.Context, request drip.GetBranchRequestObject) (drip.GetBranchResponseObject, error) { + defer tracing.TraceDefaultSegment(ctx, "DripStrictServerImplementation.GetBranch")() + repoNameFilter := strings.ToLower(request.Params.RepoName) branches, err := impl.Client.GitCommit. @@ -295,6 +303,8 @@ func (impl *DripStrictServerImplementation) GetBranch(ctx context.Context, reque } func (impl *DripStrictServerImplementation) PostUploadArtifact(ctx context.Context, request drip.PostUploadArtifactRequestObject) (drip.PostUploadArtifactResponseObject, error) { + defer tracing.TraceDefaultSegment(ctx, "DripStrictServerImplementation.PostUploadArtifact")() + err := impl.ComfyCIService.ProcessCIRequest(ctx, impl.Client, &request) if err != nil { log.Ctx(ctx).Error().Msgf("Error processing CI request w/ err: %v", err) diff --git a/server/implementation/registry.go b/server/implementation/registry.go index b0e6fff..e6772ff 100644 --- a/server/implementation/registry.go +++ b/server/implementation/registry.go @@ -4,7 +4,6 @@ import ( "context" "errors" "fmt" - "github.com/newrelic/go-agent/v3/newrelic" "registry-backend/drip" "registry-backend/ent" "registry-backend/ent/publisher" @@ -13,6 +12,7 @@ import ( drip_logging "registry-backend/logging" "registry-backend/mapper" drip_services "registry-backend/services/registry" + "registry-backend/tracing" "time" "github.com/google/uuid" @@ -23,6 +23,7 @@ import ( func (impl *DripStrictServerImplementation) ListPublishersForUser( ctx context.Context, request drip.ListPublishersForUserRequestObject) (drip.ListPublishersForUserResponseObject, error) { + defer tracing.TraceDefaultSegment(ctx, "DripStrictServerImplementation.ListPublishersForUser")() // Extract user ID from context userId, err := mapper.GetUserIDFromContext(ctx) @@ -55,6 +56,8 @@ func (impl *DripStrictServerImplementation) ListPublishersForUser( func (s *DripStrictServerImplementation) ValidatePublisher( ctx context.Context, request drip.ValidatePublisherRequestObject) (drip.ValidatePublisherResponseObject, error) { + defer tracing.TraceDefaultSegment(ctx, "DripStrictServerImplementation.ValidatePublisher")() + // Check if the username is empty name := request.Params.Username if name == "" { @@ -91,6 +94,8 @@ func (s *DripStrictServerImplementation) ValidatePublisher( func (s *DripStrictServerImplementation) CreatePublisher( ctx context.Context, request drip.CreatePublisherRequestObject) (drip.CreatePublisherResponseObject, error) { + defer tracing.TraceDefaultSegment(ctx, "DripStrictServerImplementation.CreatePublisher")() + // Extract user ID from context userId, err := mapper.GetUserIDFromContext(ctx) if err != nil { @@ -131,6 +136,8 @@ func (s *DripStrictServerImplementation) CreatePublisher( func (s *DripStrictServerImplementation) ListPublishers( ctx context.Context, request drip.ListPublishersRequestObject) (drip.ListPublishersResponseObject, error) { + defer tracing.TraceDefaultSegment(ctx, "DripStrictServerImplementation.ListPublishers")() + pubs, err := s.RegistryService.ListPublishers(ctx, s.Client, nil) if err != nil { log.Ctx(ctx).Error().Msgf("Failed to retrieve list of publishers w/ err: %v", err) @@ -148,6 +155,8 @@ func (s *DripStrictServerImplementation) ListPublishers( func (s *DripStrictServerImplementation) DeletePublisher( ctx context.Context, request drip.DeletePublisherRequestObject) (drip.DeletePublisherResponseObject, error) { + defer tracing.TraceDefaultSegment(ctx, "DripStrictServerImplementation.DeletePublisher")() + err := s.RegistryService.DeletePublisher(ctx, s.Client, request.PublisherId) if err != nil { log.Ctx(ctx).Error().Msgf("Failed to delete publisher with ID %s w/ err: %v", request.PublisherId, err) @@ -160,6 +169,7 @@ func (s *DripStrictServerImplementation) DeletePublisher( func (s *DripStrictServerImplementation) GetPublisher( ctx context.Context, request drip.GetPublisherRequestObject) (drip.GetPublisherResponseObject, error) { + defer tracing.TraceDefaultSegment(ctx, "DripStrictServerImplementation.GetPublisher")() publisherId := request.PublisherId publisher, err := s.RegistryService.GetPublisher(ctx, s.Client, request.PublisherId) @@ -178,6 +188,7 @@ func (s *DripStrictServerImplementation) GetPublisher( func (s *DripStrictServerImplementation) UpdatePublisher( ctx context.Context, request drip.UpdatePublisherRequestObject) (drip.UpdatePublisherResponseObject, error) { + defer tracing.TraceDefaultSegment(ctx, "DripStrictServerImplementation.UpdatePublisher")() updateOne := mapper.ApiUpdatePublisherToUpdateFields(request.PublisherId, request.Body, s.Client) updatedPublisher, err := s.RegistryService.UpdatePublisher(ctx, s.Client, updateOne) @@ -192,6 +203,7 @@ func (s *DripStrictServerImplementation) UpdatePublisher( func (s *DripStrictServerImplementation) CreateNode( ctx context.Context, request drip.CreateNodeRequestObject) (drip.CreateNodeResponseObject, error) { + defer tracing.TraceDefaultSegment(ctx, "DripStrictServerImplementation.CreateNode")() node, err := s.RegistryService.CreateNode(ctx, s.Client, request.PublisherId, request.Body) if mapper.IsErrorBadRequest(err) || ent.IsConstraintError(err) { @@ -210,6 +222,7 @@ func (s *DripStrictServerImplementation) CreateNode( func (s *DripStrictServerImplementation) ListNodesForPublisher( ctx context.Context, request drip.ListNodesForPublisherRequestObject) (drip.ListNodesForPublisherResponseObject, error) { + defer tracing.TraceDefaultSegment(ctx, "DripStrictServerImplementation.ListNodesForPublisher")() nodeResults, err := s.RegistryService.ListNodes( ctx, s.Client /*page=*/, 1 /*limit=*/, 10, &entity.NodeFilter{ @@ -238,10 +251,7 @@ func (s *DripStrictServerImplementation) ListNodesForPublisher( func (s *DripStrictServerImplementation) ListAllNodes( ctx context.Context, request drip.ListAllNodesRequestObject) (drip.ListAllNodesResponseObject, error) { - if txn := newrelic.FromContext(ctx); txn != nil { - segment := txn.StartSegment("DripStrictServerImplementation.ListAllNodes") - defer segment.End() - } + defer tracing.TraceDefaultSegment(ctx, "DripStrictServerImplementation.ListAllNodes")() err := s.MixpanelService.Track(ctx, []*mixpanel.Event{ s.MixpanelService.NewEvent("List All Nodes", "", map[string]any{ @@ -316,6 +326,7 @@ func (s *DripStrictServerImplementation) ListAllNodes( // SearchNodes implements drip.StrictServerInterface. func (s *DripStrictServerImplementation) SearchNodes(ctx context.Context, request drip.SearchNodesRequestObject) (drip.SearchNodesResponseObject, error) { + defer tracing.TraceDefaultSegment(ctx, "DripStrictServerImplementation.SearchNodes")() // Set default values for pagination parameters page := 1 @@ -375,6 +386,7 @@ func (s *DripStrictServerImplementation) SearchNodes(ctx context.Context, reques func (s *DripStrictServerImplementation) DeleteNode( ctx context.Context, request drip.DeleteNodeRequestObject) (drip.DeleteNodeResponseObject, error) { + defer tracing.TraceDefaultSegment(ctx, "DripStrictServerImplementation.DeleteNode")() err := s.RegistryService.DeleteNode(ctx, s.Client, request.NodeId) if err != nil && !ent.IsNotFound(err) { @@ -388,6 +400,7 @@ func (s *DripStrictServerImplementation) DeleteNode( func (s *DripStrictServerImplementation) GetNode( ctx context.Context, request drip.GetNodeRequestObject) (drip.GetNodeResponseObject, error) { + defer tracing.TraceDefaultSegment(ctx, "DripStrictServerImplementation.GetNode")() node, err := s.RegistryService.GetNode(ctx, s.Client, request.NodeId) if ent.IsNotFound(err) { @@ -415,6 +428,7 @@ func (s *DripStrictServerImplementation) GetNode( func (s *DripStrictServerImplementation) UpdateNode( ctx context.Context, request drip.UpdateNodeRequestObject) (drip.UpdateNodeResponseObject, error) { + defer tracing.TraceDefaultSegment(ctx, "DripStrictServerImplementation.UpdateNode")() updateOneFunc := func(client *ent.Client) *ent.NodeUpdateOne { return mapper.ApiUpdateNodeToUpdateFields(request.NodeId, request.Body, client) @@ -435,6 +449,7 @@ func (s *DripStrictServerImplementation) UpdateNode( func (s *DripStrictServerImplementation) ListNodeVersions( ctx context.Context, request drip.ListNodeVersionsRequestObject) (drip.ListNodeVersionsResponseObject, error) { + defer tracing.TraceDefaultSegment(ctx, "DripStrictServerImplementation.ListNodeVersions")() apiStatus := mapper.ApiNodeVersionStatusesToDbNodeVersionStatuses(request.Params.Statuses) @@ -459,6 +474,7 @@ func (s *DripStrictServerImplementation) ListNodeVersions( func (s *DripStrictServerImplementation) PublishNodeVersion( ctx context.Context, request drip.PublishNodeVersionRequestObject) (drip.PublishNodeVersionResponseObject, error) { + defer tracing.TraceDefaultSegment(ctx, "DripStrictServerImplementation.PublishNodeVersion")() // Check if node exists, create if not node, err := s.RegistryService.GetNode(ctx, s.Client, request.NodeId) @@ -514,6 +530,7 @@ func (s *DripStrictServerImplementation) PublishNodeVersion( func (s *DripStrictServerImplementation) UpdateNodeVersion( ctx context.Context, request drip.UpdateNodeVersionRequestObject) (drip.UpdateNodeVersionResponseObject, error) { + defer tracing.TraceDefaultSegment(ctx, "DripStrictServerImplementation.UpdateNodeVersion")() // Update node version updateOne := mapper.ApiUpdateNodeVersionToUpdateFields(request.VersionId, request.Body, s.Client) @@ -537,6 +554,7 @@ func (s *DripStrictServerImplementation) UpdateNodeVersion( // PostNodeVersionReview implements drip.StrictServerInterface. func (s *DripStrictServerImplementation) PostNodeReview(ctx context.Context, request drip.PostNodeReviewRequestObject) (drip.PostNodeReviewResponseObject, error) { + defer tracing.TraceDefaultSegment(ctx, "DripStrictServerImplementation.PostNodeReview")() if request.Params.Star < 1 || request.Params.Star > 5 { log.Ctx(ctx).Error().Msgf("Invalid star received: %d", request.Params.Star) @@ -563,6 +581,7 @@ func (s *DripStrictServerImplementation) PostNodeReview(ctx context.Context, req func (s *DripStrictServerImplementation) DeleteNodeVersion( ctx context.Context, request drip.DeleteNodeVersionRequestObject) (drip.DeleteNodeVersionResponseObject, error) { + defer tracing.TraceDefaultSegment(ctx, "DripStrictServerImplementation.DeleteNodeVersion")() nodeVersion, err := s.RegistryService.GetNodeVersion(ctx, s.Client, request.VersionId) if err != nil { @@ -582,6 +601,7 @@ func (s *DripStrictServerImplementation) DeleteNodeVersion( func (s *DripStrictServerImplementation) GetNodeVersion( ctx context.Context, request drip.GetNodeVersionRequestObject) (drip.GetNodeVersionResponseObject, error) { + defer tracing.TraceDefaultSegment(ctx, "DripStrictServerImplementation.GetNodeVersion")() nodeVersion, err := s.RegistryService.GetNodeVersionByVersion(ctx, s.Client, request.NodeId, request.VersionId) if ent.IsNotFound(err) { @@ -604,6 +624,7 @@ func (s *DripStrictServerImplementation) GetNodeVersion( func (s *DripStrictServerImplementation) ListPersonalAccessTokens( ctx context.Context, request drip.ListPersonalAccessTokensRequestObject) (drip.ListPersonalAccessTokensResponseObject, error) { + defer tracing.TraceDefaultSegment(ctx, "DripStrictServerImplementation.ListPersonalAccessTokens")() // List personal access tokens personalAccessTokens, err := s.RegistryService.ListPersonalAccessTokens(ctx, s.Client, request.PublisherId) @@ -625,6 +646,7 @@ func (s *DripStrictServerImplementation) ListPersonalAccessTokens( func (s *DripStrictServerImplementation) CreatePersonalAccessToken( ctx context.Context, request drip.CreatePersonalAccessTokenRequestObject) (drip.CreatePersonalAccessTokenResponseObject, error) { + defer tracing.TraceDefaultSegment(ctx, "DripStrictServerImplementation.CreatePersonalAccessToken")() // Create personal access token description := "" @@ -649,6 +671,7 @@ func (s *DripStrictServerImplementation) CreatePersonalAccessToken( func (s *DripStrictServerImplementation) DeletePersonalAccessToken( ctx context.Context, request drip.DeletePersonalAccessTokenRequestObject) (drip.DeletePersonalAccessTokenResponseObject, error) { + defer tracing.TraceDefaultSegment(ctx, "DripStrictServerImplementation.DeletePersonalAccessToken")() // Retrieve user ID from context userId, err := mapper.GetUserIDFromContext(ctx) @@ -688,6 +711,8 @@ func (s *DripStrictServerImplementation) DeletePersonalAccessToken( func (s *DripStrictServerImplementation) InstallNode( ctx context.Context, request drip.InstallNodeRequestObject) (drip.InstallNodeResponseObject, error) { + defer tracing.TraceDefaultSegment(ctx, "DripStrictServerImplementation.InstallNode")() + // TODO(robinhuang): Refactor to separate class // Get node node, err := s.RegistryService.GetNode(ctx, s.Client, request.NodeId) @@ -761,6 +786,7 @@ func (s *DripStrictServerImplementation) InstallNode( func (s *DripStrictServerImplementation) GetPermissionOnPublisherNodes( ctx context.Context, request drip.GetPermissionOnPublisherNodesRequestObject) (drip.GetPermissionOnPublisherNodesResponseObject, error) { + defer tracing.TraceDefaultSegment(ctx, "DripStrictServerImplementation.GetPermissionOnPublisherNodes")() err := s.RegistryService.AssertNodeBelongsToPublisher(ctx, s.Client, request.PublisherId, request.NodeId) if err != nil { @@ -778,6 +804,8 @@ func (s *DripStrictServerImplementation) GetPermissionOnPublisher( // BanPublisher implements drip.StrictServerInterface. func (s *DripStrictServerImplementation) BanPublisher(ctx context.Context, request drip.BanPublisherRequestObject) (drip.BanPublisherResponseObject, error) { + defer tracing.TraceDefaultSegment(ctx, "DripStrictServerImplementation.BanPublisher")() + userId, err := mapper.GetUserIDFromContext(ctx) if err != nil { log.Ctx(ctx).Error().Msgf("Failed to get user ID from context w/ err: %v", err) @@ -814,6 +842,8 @@ func (s *DripStrictServerImplementation) BanPublisher(ctx context.Context, reque // BanPublisherNode implements drip.StrictServerInterface. func (s *DripStrictServerImplementation) BanPublisherNode(ctx context.Context, request drip.BanPublisherNodeRequestObject) (drip.BanPublisherNodeResponseObject, error) { + defer tracing.TraceDefaultSegment(ctx, "DripStrictServerImplementation.BanPublisherNode")() + userId, err := mapper.GetUserIDFromContext(ctx) if err != nil { log.Ctx(ctx).Error().Msgf("Failed to get user ID from context w/ err: %v", err) @@ -851,6 +881,8 @@ func (s *DripStrictServerImplementation) BanPublisherNode(ctx context.Context, r func (s *DripStrictServerImplementation) AdminUpdateNodeVersion( ctx context.Context, request drip.AdminUpdateNodeVersionRequestObject) (drip.AdminUpdateNodeVersionResponseObject, error) { + defer tracing.TraceDefaultSegment(ctx, "DripStrictServerImplementation.AdminUpdateNodeVersion")() + userId, err := mapper.GetUserIDFromContext(ctx) if err != nil { log.Ctx(ctx).Error().Msgf("Failed to get user ID from context w/ err: %v", err) @@ -896,6 +928,8 @@ func (s *DripStrictServerImplementation) AdminUpdateNodeVersion( func (s *DripStrictServerImplementation) SecurityScan( ctx context.Context, request drip.SecurityScanRequestObject) (drip.SecurityScanResponseObject, error) { + defer tracing.TraceDefaultSegment(ctx, "DripStrictServerImplementation.SecurityScan")() + minAge := 30 * time.Minute if request.Params.MinAge != nil { minAge = *request.Params.MinAge @@ -929,6 +963,7 @@ func (s *DripStrictServerImplementation) SecurityScan( func (s *DripStrictServerImplementation) ListAllNodeVersions( ctx context.Context, request drip.ListAllNodeVersionsRequestObject) (drip.ListAllNodeVersionsResponseObject, error) { + defer tracing.TraceDefaultSegment(ctx, "DripStrictServerImplementation.ListAllNodeVersions")() // Default values for pagination page := 1 @@ -1005,6 +1040,8 @@ func (s *DripStrictServerImplementation) ListAllNodeVersions( } func (s *DripStrictServerImplementation) ReindexNodes(ctx context.Context, request drip.ReindexNodesRequestObject) (res drip.ReindexNodesResponseObject, err error) { + defer tracing.TraceDefaultSegment(ctx, "DripStrictServerImplementation.ReindexNodes")() + reindexCtx := drip_logging.ReuseContextLogger(ctx, context.Background()) err = s.RegistryService.ReindexAllNodesBackground(reindexCtx, s.Client) if err != nil { @@ -1019,6 +1056,7 @@ func (s *DripStrictServerImplementation) ReindexNodes(ctx context.Context, reque // CreateComfyNodes bulk-stores comfy-nodes extraction result for a node version func (impl *DripStrictServerImplementation) CreateComfyNodes( ctx context.Context, request drip.CreateComfyNodesRequestObject) (res drip.CreateComfyNodesResponseObject, err error) { + defer tracing.TraceDefaultSegment(ctx, "DripStrictServerImplementation.CreateComfyNodes")() cb := mapper.ApiComfyNodeCloudBuildToDbComfyNodeCloudBuild(request.Body.CloudBuildInfo) // Check if extraction was marked as unsuccessful @@ -1065,6 +1103,7 @@ func (impl *DripStrictServerImplementation) CreateComfyNodes( // GetComfyNode returns a specific comfy-node of a certain node version func (impl *DripStrictServerImplementation) GetComfyNode( ctx context.Context, request drip.GetComfyNodeRequestObject) (res drip.GetComfyNodeResponseObject, err error) { + defer tracing.TraceDefaultSegment(ctx, "DripStrictServerImplementation.GetComfyNode")() // Retrieve the comfy-node from the registry n, err := impl.RegistryService.GetComfyNode(ctx, impl.Client, request.NodeId, request.Version, request.ComfyNodeId) @@ -1091,6 +1130,7 @@ func (impl *DripStrictServerImplementation) GetComfyNode( // ComfyNodesBackfill triggers a backfill process for comfy-nodes func (impl *DripStrictServerImplementation) ComfyNodesBackfill( ctx context.Context, request drip.ComfyNodesBackfillRequestObject) (drip.ComfyNodesBackfillResponseObject, error) { + defer tracing.TraceDefaultSegment(ctx, "DripStrictServerImplementation.ComfyNodesBackfill")() // Trigger the backfill process with a specified maximum node err := impl.RegistryService.TriggerComfyNodesBackfill(ctx, impl.Client, request.Params.MaxNode) diff --git a/server/implementation/user.go b/server/implementation/user.go index 16e2ada..80bbfb8 100644 --- a/server/implementation/user.go +++ b/server/implementation/user.go @@ -5,11 +5,14 @@ import ( "registry-backend/drip" "registry-backend/ent/user" "registry-backend/mapper" + "registry-backend/tracing" "github.com/rs/zerolog/log" ) func (impl *DripStrictServerImplementation) GetUser(ctx context.Context, request drip.GetUserRequestObject) (drip.GetUserResponseObject, error) { + defer tracing.TraceDefaultSegment(ctx, "DripStrictServerImplementation.GetUser")() + userId, err := mapper.GetUserIDFromContext(ctx) if err != nil { log.Ctx(ctx).Error().Stack().Err(err).Msg("") diff --git a/server/middleware/authentication/firebase_auth.go b/server/middleware/authentication/firebase_auth.go index 9fe6383..99190e5 100644 --- a/server/middleware/authentication/firebase_auth.go +++ b/server/middleware/authentication/firebase_auth.go @@ -7,9 +7,9 @@ import ( "regexp" "registry-backend/db" "registry-backend/ent" + "registry-backend/tracing" "strings" - "github.com/newrelic/go-agent/v3/newrelic" "github.com/rs/zerolog/log" firebase "firebase.google.com/go" @@ -51,10 +51,8 @@ func FirebaseAuthMiddleware(entClient *ent.Client) echo.MiddlewareFunc { return func(next echo.HandlerFunc) echo.HandlerFunc { return func(ctx echo.Context) error { - if txn := newrelic.FromContext(ctx.Request().Context()); txn != nil { - segment := txn.StartSegment("FirebaseAuthMiddleware") - defer segment.End() - } + defer tracing.TraceDefaultSegment(ctx.Request().Context(), "FirebaseAuthMiddleware")() + // Check if the request is in the allow list. reqPath := ctx.Request().URL.Path reqMethod := ctx.Request().Method diff --git a/server/middleware/authentication/jwt_admin_auth.go b/server/middleware/authentication/jwt_admin_auth.go index 84e9522..474918b 100644 --- a/server/middleware/authentication/jwt_admin_auth.go +++ b/server/middleware/authentication/jwt_admin_auth.go @@ -6,11 +6,11 @@ import ( "net/http" "regexp" "registry-backend/ent" + "registry-backend/tracing" "strings" "github.com/golang-jwt/jwt/v5" "github.com/labstack/echo/v4" - "github.com/newrelic/go-agent/v3/newrelic" ) // JWTAdminAuthMiddleware checks for a JWT token in the Authorization header, @@ -34,10 +34,8 @@ func JWTAdminAuthMiddleware(entClient *ent.Client, secret string) echo.Middlewar return func(next echo.HandlerFunc) echo.HandlerFunc { return func(c echo.Context) error { - if txn := newrelic.FromContext(c.Request().Context()); txn != nil { - segment := txn.StartSegment("JWTAdminAuthMiddleware") - defer segment.End() - } + defer tracing.TraceDefaultSegment(c.Request().Context(), "JWTAdminAuthMiddleware")() + reqPath := c.Request().URL.Path // Check if the request path matches any of the protected endpoints diff --git a/server/middleware/authentication/service_account_auth.go b/server/middleware/authentication/service_account_auth.go index c3bc649..d8333e5 100644 --- a/server/middleware/authentication/service_account_auth.go +++ b/server/middleware/authentication/service_account_auth.go @@ -4,10 +4,10 @@ import ( "net/http" "os" "regexp" + "registry-backend/tracing" "strings" "github.com/labstack/echo/v4" - "github.com/newrelic/go-agent/v3/newrelic" "github.com/rs/zerolog/log" "google.golang.org/api/idtoken" ) @@ -23,10 +23,8 @@ func ServiceAccountAuthMiddleware() echo.MiddlewareFunc { return func(next echo.HandlerFunc) echo.HandlerFunc { return func(ctx echo.Context) error { - if txn := newrelic.FromContext(ctx.Request().Context()); txn != nil { - segment := txn.StartSegment("ServiceAccountAuthMiddleware") - defer segment.End() - } + defer tracing.TraceDefaultSegment(ctx.Request().Context(), "ServiceAccountAuthMiddleware")() + // Check if the request reqPath and method are in the checklist reqPath := ctx.Request().URL.Path reqMethod := ctx.Request().Method diff --git a/server/middleware/metric/metric_middleware.go b/server/middleware/metric/metric_middleware.go index 7c3f235..3d086c2 100644 --- a/server/middleware/metric/metric_middleware.go +++ b/server/middleware/metric/metric_middleware.go @@ -4,6 +4,7 @@ import ( "context" "os" "registry-backend/config" + "registry-backend/tracing" "strconv" "sync" "sync/atomic" @@ -14,7 +15,6 @@ import ( "cloud.google.com/go/monitoring/apiv3/v2/monitoringpb" "github.com/labstack/echo/v4" - "github.com/newrelic/go-agent/v3/newrelic" "github.com/rs/zerolog/log" "google.golang.org/protobuf/types/known/timestamppb" ) @@ -37,10 +37,7 @@ func init() { func MetricsMiddleware(client *monitoring.MetricClient, config *config.Config) echo.MiddlewareFunc { return func(next echo.HandlerFunc) echo.HandlerFunc { return func(c echo.Context) error { - if txn := newrelic.FromContext(c.Request().Context()); txn != nil { - segment := txn.StartSegment("MetricsMiddleware") - defer segment.End() - } + defer tracing.TraceDefaultSegment(c.Request().Context(), "MetricsMiddleware")() startTime := time.Now() err := next(c) endTime := time.Now() diff --git a/server/middleware/request_logger.go b/server/middleware/request_logger.go index a87e4c2..3c0f0c5 100644 --- a/server/middleware/request_logger.go +++ b/server/middleware/request_logger.go @@ -5,6 +5,7 @@ import ( "context" "fmt" "io" + "registry-backend/tracing" "github.com/labstack/echo/v4" echo_middleware "github.com/labstack/echo/v4/middleware" @@ -75,10 +76,7 @@ func RequestLoggerMiddleware() echo.MiddlewareFunc { mw := func(next echo.HandlerFunc) echo.HandlerFunc { return func(c echo.Context) error { - if txn := newrelic.FromContext(c.Request().Context()); txn != nil { - segment := txn.StartSegment("RequestLoggerMiddleware") - defer segment.End() - } + defer tracing.TraceDefaultSegment(c.Request().Context(), "RequestLoggerMiddleware")() req := c.Request() reader := &teeReader{ReadCloser: req.Body} diff --git a/server/middleware/response_logger.go b/server/middleware/response_logger.go index 90cd430..75cb3d4 100644 --- a/server/middleware/response_logger.go +++ b/server/middleware/response_logger.go @@ -4,9 +4,9 @@ import ( "bytes" "fmt" "net/http" + "registry-backend/tracing" "github.com/labstack/echo/v4" - "github.com/newrelic/go-agent/v3/newrelic" "github.com/rs/zerolog/log" ) @@ -30,10 +30,8 @@ func (rw *responseWriter) Write(p []byte) (n int, err error) { func ResponseLoggerMiddleware() echo.MiddlewareFunc { return func(next echo.HandlerFunc) echo.HandlerFunc { return func(c echo.Context) error { - if txn := newrelic.FromContext(c.Request().Context()); txn != nil { - segment := txn.StartSegment("ResponseLoggerMiddleware") - defer segment.End() - } + defer tracing.TraceDefaultSegment(c.Request().Context(), "DripStrictServerImplementation.ResponseLoggerMiddleware")() + // Create a custom response writer to capture the response body rw := &responseWriter{ ResponseWriter: c.Response().Writer, diff --git a/server/middleware/tracing_middleware.go b/server/middleware/tracing_middleware.go index 412d304..d191ae9 100644 --- a/server/middleware/tracing_middleware.go +++ b/server/middleware/tracing_middleware.go @@ -3,10 +3,10 @@ package middleware import ( "context" drip_logging "registry-backend/logging" + "registry-backend/tracing" "github.com/google/uuid" "github.com/labstack/echo/v4" - "github.com/newrelic/go-agent/v3/newrelic" "github.com/rs/zerolog" ) @@ -20,10 +20,7 @@ func generateFallbackCorrelationID() string { // TracingMiddleware is a middleware that adds a trace ID to the context func TracingMiddleware(next echo.HandlerFunc) echo.HandlerFunc { return func(c echo.Context) error { - if txn := newrelic.FromContext(c.Request().Context()); txn != nil { - segment := txn.StartSegment("TracingMiddleware") - defer segment.End() - } + defer tracing.TraceDefaultSegment(c.Request().Context(), "TracingMiddleware")() traceID := c.Request().Header.Get("X-Cloud-Trace-Context") diff --git a/services/comfy_ci/comfy_ci_svc.go b/services/comfy_ci/comfy_ci_svc.go index 1d61b24..60a8e65 100644 --- a/services/comfy_ci/comfy_ci_svc.go +++ b/services/comfy_ci/comfy_ci_svc.go @@ -10,6 +10,7 @@ import ( "registry-backend/ent/gitcommit" "registry-backend/mapper" drip_metric "registry-backend/server/middleware/metric" + "registry-backend/tracing" "strings" "time" @@ -34,6 +35,8 @@ func NewComfyCIService(config *config.Config) *ComfyCIService { // ProcessCIRequest handles the incoming request and creates/updates the necessary entities. func (s *ComfyCIService) ProcessCIRequest(ctx context.Context, client *ent.Client, req *drip.PostUploadArtifactRequestObject) error { + defer tracing.TraceDefaultSegment(ctx, "ComfyCIService.ProcessCIRequest")() + return db.WithTx(ctx, client, func(tx *ent.Tx) error { existingCommit, err := tx.GitCommit.Query().Where(gitcommit.CommitHashEQ(req.Body.CommitHash)).Where(gitcommit.RepoNameEQ(req.Body.Repo)).Only(ctx) if ent.IsNotSingular(err) { @@ -119,6 +122,8 @@ func (s *ComfyCIService) ProcessCIRequest(ctx context.Context, client *ent.Clien // UpsertCommit creates or updates a GitCommit entity. func (s *ComfyCIService) UpsertCommit(ctx context.Context, client *ent.Client, hash, branchName, repoName, commitIsoTime, commitMessage, prNumber, author string) (uuid.UUID, error) { + defer tracing.TraceDefaultSegment(ctx, "ComfyCIService.UpsertCommit")() + log.Ctx(ctx).Info().Msgf("Upserting commit %s", hash) commitTime, err := time.Parse(time.RFC3339, commitIsoTime) if err != nil { @@ -149,6 +154,8 @@ func (s *ComfyCIService) UpsertCommit(ctx context.Context, client *ent.Client, h // UpsertRunResult creates or updates a ActionRunResult entity. func (s *ComfyCIService) UpsertRunResult(ctx context.Context, client *ent.Client, gitcommit *ent.GitCommit, os, cudaVersion, workflowName, runId, jobId string, startTime, endTime int64, avgVram, peakVram int, pythonVersion, pytorchVersion, jobTriggerUser, comfyRunFlags string, status drip.WorkflowRunStatus, machineStats *drip.MachineStats) (uuid.UUID, error) { + defer tracing.TraceDefaultSegment(ctx, "ComfyCIService.UpsertRunResult")() + log.Ctx(ctx).Info().Msgf("Upserting workflow result for commit %s", gitcommit.CommitHash) dbWorkflowRunStatus, err := mapper.ApiWorkflowRunStatusToDb(status) if err != nil { @@ -180,6 +187,8 @@ func (s *ComfyCIService) UpsertRunResult(ctx context.Context, client *ent.Client } func (s *ComfyCIService) UpdateWorkflowResult(ctx context.Context, client *ent.Client, id uuid.UUID, status drip.WorkflowRunStatus, files []*drip.StorageFile) error { + defer tracing.TraceDefaultSegment(ctx, "ComfyCIService.UpdateWorkflowResult")() + dbWorkflowRunStatus, err := mapper.ApiWorkflowRunStatusToDb(status) if err != nil { return err @@ -199,6 +208,8 @@ func (s *ComfyCIService) UpdateWorkflowResult(ctx context.Context, client *ent.C // UpsertStorageFile creates or updates a RunFile entity. func (s *ComfyCIService) UpsertStorageFile(ctx context.Context, client *ent.Client, publicUrl, bucketName, filePath, fileType string) (*ent.StorageFile, error) { + defer tracing.TraceDefaultSegment(ctx, "ComfyCIService.UpsertStorageFile")() + log.Ctx(ctx).Info().Msgf("Upserting storage file for URL %s", publicUrl) return client.StorageFile. Create(). @@ -217,6 +228,8 @@ type ObjectInfo struct { // GetPublicUrlForOutputFiles downloads the artifact, extracts it, and uploads each file to GCS func GetPublicUrlForOutputFiles(ctx context.Context, bucketName, objects string) ([]ObjectInfo, error) { + defer tracing.TraceDefaultSegment(ctx, "GetPublicUrlForOutputFiles")() + objectArr := strings.Split(objects, ",") var result []ObjectInfo for _, object := range objectArr { diff --git a/services/registry/registry_svc.go b/services/registry/registry_svc.go index a065dee..10eb127 100644 --- a/services/registry/registry_svc.go +++ b/services/registry/registry_svc.go @@ -30,6 +30,7 @@ import ( "registry-backend/gateways/storage" "registry-backend/mapper" drip_metric "registry-backend/server/middleware/metric" + "registry-backend/tracing" "strings" "sync" "time" @@ -67,16 +68,13 @@ func NewRegistryService(storageSvc storage.StorageService, pubsubService pubsub. // ListNodes retrieves a paginated list of nodes with optional filtering. func (s *RegistryService) ListNodes(ctx context.Context, client *ent.Client, page, limit int, filter *entity.NodeFilter) (*entity.ListNodesResult, error) { // Start New Relic transaction segment - var txn *newrelic.Transaction - if txnCtx := newrelic.FromContext(ctx); txnCtx != nil { - txn = txnCtx + txn, deferer := tracing.TraceSegment(ctx, "RegistryService.ListNodes", func(txn *newrelic.Transaction) { txn.Application().RecordCustomMetric( "Custom/ListNodes/Limit", float64(limit), ) - segment := txn.StartSegment("RegistryService.ListNodes") - defer segment.End() - } + }) + defer deferer() // Ensure valid pagination parameters if page < 1 { @@ -180,10 +178,8 @@ func (s *RegistryService) ListNodes(ctx context.Context, client *ent.Client, pag // ListPublishers queries the Publisher table with an optional user ID filter via PublisherPermission func (s *RegistryService) ListPublishers(ctx context.Context, client *ent.Client, filter *entity.PublisherFilter) ([]*ent.Publisher, error) { - if txn := newrelic.FromContext(ctx); txn != nil { - segment := txn.StartSegment("RegistryService.ListPublishers") - defer segment.End() - } + defer tracing.TraceDefaultSegment(ctx, "RegistryService.ListPublishers")() + log.Ctx(ctx).Info().Msg("Listing publishers") query := client.Publisher.Query() @@ -206,10 +202,8 @@ func (s *RegistryService) ListPublishers(ctx context.Context, client *ent.Client } func (s *RegistryService) CreatePublisher(ctx context.Context, client *ent.Client, userId string, publisher *drip.Publisher) (*ent.Publisher, error) { - if txn := newrelic.FromContext(ctx); txn != nil { - segment := txn.StartSegment("RegistryService.CreatePublisher") - defer segment.End() - } + defer tracing.TraceDefaultSegment(ctx, "RegistryService.CreatePublisher")() + publisherValid := mapper.ValidatePublisher(publisher) if publisherValid != nil { return nil, fmt.Errorf("invalid publisher: %w", publisherValid) @@ -237,10 +231,8 @@ func (s *RegistryService) CreatePublisher(ctx context.Context, client *ent.Clien } func (s *RegistryService) UpdatePublisher(ctx context.Context, client *ent.Client, update *ent.PublisherUpdateOne) (*ent.Publisher, error) { - if txn := newrelic.FromContext(ctx); txn != nil { - segment := txn.StartSegment("RegistryService.UpdatePublisher") - defer segment.End() - } + defer tracing.TraceDefaultSegment(ctx, "RegistryService.UpdatePublisher")() + log.Ctx(ctx).Info().Msgf("updating publisher fields: %v", update.Mutation().Fields()) publisher, err := update.Save(ctx) log.Ctx(ctx).Info().Msgf("success: updated publisher: %v", publisher) @@ -252,10 +244,8 @@ func (s *RegistryService) UpdatePublisher(ctx context.Context, client *ent.Clien } func (s *RegistryService) GetPublisher(ctx context.Context, client *ent.Client, publisherID string) (*ent.Publisher, error) { - if txn := newrelic.FromContext(ctx); txn != nil { - segment := txn.StartSegment("RegistryService.GetPublisher") - defer segment.End() - } + defer tracing.TraceDefaultSegment(ctx, "RegistryService.GetPublisher")() + log.Ctx(ctx).Info().Msgf("getting publisher: %v", publisherID) publisher, err := client.Publisher. Query(). @@ -271,10 +261,8 @@ func (s *RegistryService) GetPublisher(ctx context.Context, client *ent.Client, } func (s *RegistryService) CreatePersonalAccessToken(ctx context.Context, client *ent.Client, publisherID, name, description string) (*ent.PersonalAccessToken, error) { - if txn := newrelic.FromContext(ctx); txn != nil { - segment := txn.StartSegment("RegistryService.CreatePersonalAccessToken") - defer segment.End() - } + defer tracing.TraceDefaultSegment(ctx, "RegistryService.CreatePersonalAccessToken")() + log.Ctx(ctx).Info().Msgf("creating personal access token for publisher: %v", publisherID) token := uuid.New().String() pat, err := client.PersonalAccessToken. @@ -292,10 +280,8 @@ func (s *RegistryService) CreatePersonalAccessToken(ctx context.Context, client } func (s *RegistryService) ListPersonalAccessTokens(ctx context.Context, client *ent.Client, publisherID string) ([]*ent.PersonalAccessToken, error) { - if txn := newrelic.FromContext(ctx); txn != nil { - segment := txn.StartSegment("RegistryService.ListPersonalAccessTokens") - defer segment.End() - } + defer tracing.TraceDefaultSegment(ctx, "RegistryService.ListPersonalAccessTokens")() + pats, err := client.PersonalAccessToken.Query(). Where(personalaccesstoken.PublisherIDEQ(publisherID)). All(ctx) @@ -306,10 +292,8 @@ func (s *RegistryService) ListPersonalAccessTokens(ctx context.Context, client * } func (s *RegistryService) DeletePersonalAccessToken(ctx context.Context, client *ent.Client, tokenID uuid.UUID) error { - if txn := newrelic.FromContext(ctx); txn != nil { - segment := txn.StartSegment("RegistryService.DeletePersonalAccessToken") - defer segment.End() - } + defer tracing.TraceDefaultSegment(ctx, "RegistryService.DeletePersonalAccessToken")() + log.Ctx(ctx).Info().Msgf("deleting personal access token: %v", tokenID) err := client.PersonalAccessToken. DeleteOneID(tokenID). @@ -321,10 +305,8 @@ func (s *RegistryService) DeletePersonalAccessToken(ctx context.Context, client } func (s *RegistryService) CreateNode(ctx context.Context, client *ent.Client, publisherId string, node *drip.Node) (*ent.Node, error) { - if txn := newrelic.FromContext(ctx); txn != nil { - segment := txn.StartSegment("RegistryService.CreateNode") - defer segment.End() - } + defer tracing.TraceDefaultSegment(ctx, "RegistryService.CreateNode")() + validNode := mapper.ValidateNode(node) if validNode != nil { return nil, fmt.Errorf("invalid node: %w", validNode) @@ -355,10 +337,8 @@ func (s *RegistryService) CreateNode(ctx context.Context, client *ent.Client, pu } func (s *RegistryService) UpdateNode(ctx context.Context, client *ent.Client, updateFunc func(client *ent.Client) *ent.NodeUpdateOne) (*ent.Node, error) { - if txn := newrelic.FromContext(ctx); txn != nil { - segment := txn.StartSegment("RegistryService.UpdateNode") - defer segment.End() - } + defer tracing.TraceDefaultSegment(ctx, "RegistryService.UpdateNode")() + var n *ent.Node err := db.WithTx(ctx, client, func(tx *ent.Tx) (err error) { update := updateFunc(tx.Client()) @@ -380,10 +360,7 @@ func (s *RegistryService) UpdateNode(ctx context.Context, client *ent.Client, up } func (s *RegistryService) GetNode(ctx context.Context, client *ent.Client, nodeID string) (*ent.Node, error) { - if txn := newrelic.FromContext(ctx); txn != nil { - segment := txn.StartSegment("RegistryService.GetNode") - defer segment.End() - } + defer tracing.TraceDefaultSegment(ctx, "RegistryService.GetNode")() log.Ctx(ctx).Info().Msgf("getting node: %v", nodeID) node, err := client.Node.Get(ctx, nodeID) @@ -394,10 +371,8 @@ func (s *RegistryService) GetNode(ctx context.Context, client *ent.Client, nodeI } func (s *RegistryService) CreateNodeVersion(ctx context.Context, client *ent.Client, publisherID, nodeID string, nodeVersion *drip.NodeVersion) (*NodeVersionCreation, error) { - if txn := newrelic.FromContext(ctx); txn != nil { - segment := txn.StartSegment("RegistryService.CreateNodeVersion") - defer segment.End() - } + defer tracing.TraceDefaultSegment(ctx, "RegistryService.CreateNodeVersion")() + log.Ctx(ctx).Info().Msgf("creating node version: %v for nodeId %v", nodeVersion, nodeID) bucketName := "comfy-registry" return db.WithTxResult(ctx, client, func(tx *ent.Tx) (*NodeVersionCreation, error) { @@ -464,10 +439,8 @@ type NodeVersionCreation struct { } func (s *RegistryService) ListNodeVersions(ctx context.Context, client *ent.Client, filter *entity.NodeVersionFilter) (*entity.ListNodeVersionsResult, error) { - if txn := newrelic.FromContext(ctx); txn != nil { - segment := txn.StartSegment("RegistryService.ListNodeVersions") - defer segment.End() - } + defer tracing.TraceDefaultSegment(ctx, "RegistryService.ListNodeVersions")() + query := client.NodeVersion.Query(). WithStorageFile(). Order(ent.Desc(nodeversion.FieldVersion)) @@ -535,10 +508,8 @@ func (s *RegistryService) ListNodeVersions(ctx context.Context, client *ent.Clie } func (s *RegistryService) AddNodeReview(ctx context.Context, client *ent.Client, nodeId, userID string, star int) (n *ent.Node, err error) { - if txn := newrelic.FromContext(ctx); txn != nil { - segment := txn.StartSegment("RegistryService.AddNodeReview") - defer segment.End() - } + defer tracing.TraceDefaultSegment(ctx, "RegistryService.AddNodeReview")() + log.Ctx(ctx).Info().Msgf("add review to node: %v ", nodeId) err = db.WithTx(ctx, client, func(tx *ent.Tx) error { @@ -573,10 +544,8 @@ func (s *RegistryService) AddNodeReview(ctx context.Context, client *ent.Client, } func (s *RegistryService) GetNodeVersionByVersion(ctx context.Context, client *ent.Client, nodeId, nodeVersion string) (*ent.NodeVersion, error) { - if txn := newrelic.FromContext(ctx); txn != nil { - segment := txn.StartSegment("RegistryService.GetNodeVersionByVersion") - defer segment.End() - } + defer tracing.TraceDefaultSegment(ctx, "RegistryService.GetNodeVersionByVersion")() + log.Ctx(ctx).Info().Msgf("getting node version %v@%v", nodeId, nodeVersion) return client.NodeVersion. Query(). @@ -587,20 +556,16 @@ func (s *RegistryService) GetNodeVersionByVersion(ctx context.Context, client *e } func (s *RegistryService) GetNodeVersion(ctx context.Context, client *ent.Client, nodeVersionId string) (*ent.NodeVersion, error) { - if txn := newrelic.FromContext(ctx); txn != nil { - segment := txn.StartSegment("RegistryService.GetNodeVersion") - defer segment.End() - } + defer tracing.TraceDefaultSegment(ctx, "RegistryService.GetNodeVersion")() + log.Ctx(ctx).Info().Msgf("getting node version %v", nodeVersionId) return client.NodeVersion. Get(ctx, uuid.MustParse(nodeVersionId)) } func (s *RegistryService) UpdateNodeVersion(ctx context.Context, client *ent.Client, update *ent.NodeVersionUpdateOne) (*ent.NodeVersion, error) { - if txn := newrelic.FromContext(ctx); txn != nil { - segment := txn.StartSegment("RegistryService.UpdateNodeVersion") - defer segment.End() - } + defer tracing.TraceDefaultSegment(ctx, "RegistryService.UpdateNodeVersion")() + log.Ctx(ctx).Info().Msgf("updating node version fields: %v", update.Mutation().Fields()) return db.WithTxResult(ctx, client, func(tx *ent.Tx) (*ent.NodeVersion, error) { node, err := update.Save(ctx) @@ -618,10 +583,8 @@ func (s *RegistryService) UpdateNodeVersion(ctx context.Context, client *ent.Cli } func (s *RegistryService) RecordNodeInstallation(ctx context.Context, client *ent.Client, node *ent.Node) (*ent.Node, error) { - if txn := newrelic.FromContext(ctx); txn != nil { - segment := txn.StartSegment("RegistryService.RecordNodeInstallation") - defer segment.End() - } + defer tracing.TraceDefaultSegment(ctx, "RegistryService.RecordNodeInstallation")() + var n *ent.Node err := db.WithTx(ctx, client, func(tx *ent.Tx) (err error) { n, err = tx.Node.UpdateOne(node).AddTotalInstall(1).Save(ctx) @@ -639,10 +602,8 @@ func (s *RegistryService) RecordNodeInstallation(ctx context.Context, client *en } func (s *RegistryService) GetLatestNodeVersion(ctx context.Context, client *ent.Client, nodeId string) (*ent.NodeVersion, error) { - if txn := newrelic.FromContext(ctx); txn != nil { - segment := txn.StartSegment("RegistryService.GetLatestNodeVersion") - defer segment.End() - } + defer tracing.TraceDefaultSegment(ctx, "RegistryService.GetLatestNodeVersion")() + log.Ctx(ctx).Info().Msgf("Getting latest version of node: %v", nodeId) nodeVersion, err := client.NodeVersion. Query(). @@ -679,10 +640,8 @@ func (s *RegistryService) MarkComfyNodeExtractionFailed( nodeVersion string, info *schema.ComfyNodeCloudBuildInfo, ) error { - if txn := newrelic.FromContext(ctx); txn != nil { - segment := txn.StartSegment("RegistryService.MarkComfyNodeExtractionFailed") - defer segment.End() - } + defer tracing.TraceDefaultSegment(ctx, "RegistryService.MarkComfyNodeExtractionFailed")() + u := client.NodeVersion. Update(). Where( @@ -703,10 +662,8 @@ func (s *RegistryService) CreateComfyNodes( comfyNodes map[string]drip.ComfyNode, info *schema.ComfyNodeCloudBuildInfo, ) error { - if txn := newrelic.FromContext(ctx); txn != nil { - segment := txn.StartSegment("RegistryService.CreateComfyNodes") - defer segment.End() - } + defer tracing.TraceDefaultSegment(ctx, "RegistryService.CreateComfyNodes")() + return db.WithTx(ctx, client, func(tx *ent.Tx) error { // Query the NodeVersion with the given nodeID and nodeVersion, lock it for updates nv, err := tx.NodeVersion.Query(). @@ -793,10 +750,8 @@ func (s *RegistryService) GetComfyNode( client *ent.Client, nodeID, nodeVersion, comfyNodeName string, ) (*ent.ComfyNode, error) { - if txn := newrelic.FromContext(ctx); txn != nil { - segment := txn.StartSegment("RegistryService.GetComfyNode") - defer segment.End() - } + defer tracing.TraceDefaultSegment(ctx, "RegistryService.GetComfyNode")() + // Query the NodeVersion with the given nodeID and nodeVersion, ensuring extraction status is success nv, err := client.NodeVersion.Query(). Where(nodeversion.VersionEQ(nodeVersion)). @@ -822,10 +777,8 @@ func (s *RegistryService) GetComfyNode( func (s *RegistryService) TriggerComfyNodesBackfill( ctx context.Context, client *ent.Client, max *int) error { - if txn := newrelic.FromContext(ctx); txn != nil { - segment := txn.StartSegment("RegistryService.TriggerComfyNodesBackfill") - defer segment.End() - } + defer tracing.TraceDefaultSegment(ctx, "RegistryService.TriggerComfyNodesBackfill")() + // Query all NodeVersions with pending comfy node extraction status q := client.NodeVersion. Query(). @@ -868,10 +821,8 @@ func (s *RegistryService) AssertPublisherPermissions(ctx context.Context, userID string, permissions []schema.PublisherPermissionType, ) (err error) { - if txn := newrelic.FromContext(ctx); txn != nil { - segment := txn.StartSegment("RegistryService.AssertPublisherPermissions") - defer segment.End() - } + defer tracing.TraceDefaultSegment(ctx, "RegistryService.AssertPublisherPermissions")() + w, err := client.Publisher.Get(ctx, publisherID) if err != nil { return fmt.Errorf("fail to query publisher by id: %s %w", publisherID, err) @@ -896,10 +847,8 @@ func (s *RegistryService) IsPersonalAccessTokenValidForPublisher(ctx context.Con publisherID string, accessToken string, ) (bool, error) { - if txn := newrelic.FromContext(ctx); txn != nil { - segment := txn.StartSegment("RegistryService.IsPersonalAccessTokenValidForPublisher") - defer segment.End() - } + defer tracing.TraceDefaultSegment(ctx, "RegistryService.IsPersonalAccessTokenValidForPublisher")() + w, err := client.Publisher.Get(ctx, publisherID) if err != nil { log.Ctx(ctx).Error().Err(err).Msgf("fail to find publisher by id: %s", publisherID) @@ -920,10 +869,8 @@ func (s *RegistryService) IsPersonalAccessTokenValidForPublisher(ctx context.Con } func (s *RegistryService) AssertNodeBelongsToPublisher(ctx context.Context, client *ent.Client, publisherID string, nodeID string) error { - if txn := newrelic.FromContext(ctx); txn != nil { - segment := txn.StartSegment("RegistryService.AssertNodeBelongsToPublisher") - defer segment.End() - } + defer tracing.TraceDefaultSegment(ctx, "RegistryService.AssertNodeBelongsToPublisher")() + node, err := client.Node.Get(ctx, nodeID) if err != nil { return fmt.Errorf("failed to get node: %w", err) @@ -935,10 +882,8 @@ func (s *RegistryService) AssertNodeBelongsToPublisher(ctx context.Context, clie } func (s *RegistryService) AssertAccessTokenBelongsToPublisher(ctx context.Context, client *ent.Client, publisherID string, tokenId uuid.UUID) error { - if txn := newrelic.FromContext(ctx); txn != nil { - segment := txn.StartSegment("RegistryService.AssertAccessTokenBelongsToPublisher") - defer segment.End() - } + defer tracing.TraceDefaultSegment(ctx, "RegistryService.AssertAccessTokenBelongsToPublisher")() + pat, err := client.PersonalAccessToken.Query().Where( personalaccesstoken.IDEQ(tokenId), personalaccesstoken.PublisherIDEQ(publisherID), @@ -953,10 +898,8 @@ func (s *RegistryService) AssertAccessTokenBelongsToPublisher(ctx context.Contex } func (s *RegistryService) DeletePublisher(ctx context.Context, client *ent.Client, publisherID string) error { - if txn := newrelic.FromContext(ctx); txn != nil { - segment := txn.StartSegment("RegistryService.DeletePublisher") - defer segment.End() - } + defer tracing.TraceDefaultSegment(ctx, "RegistryService.DeletePublisher")() + log.Ctx(ctx).Info().Msgf("deleting publisher: %v", publisherID) return db.WithTx(ctx, client, func(tx *ent.Tx) error { client = tx.Client() @@ -990,10 +933,8 @@ func (s *RegistryService) DeletePublisher(ctx context.Context, client *ent.Clien } func (s *RegistryService) DeleteNode(ctx context.Context, client *ent.Client, nodeID string) error { - if txn := newrelic.FromContext(ctx); txn != nil { - segment := txn.StartSegment("RegistryService.DeleteNode") - defer segment.End() - } + defer tracing.TraceDefaultSegment(ctx, "RegistryService.DeleteNode")() + log.Ctx(ctx).Info().Msgf("deleting node: %v", nodeID) db.WithTx(ctx, client, func(tx *ent.Tx) error { nv, err := tx.Client().NodeVersion.Query().Where(nodeversion.NodeID(nodeID)).All(ctx) @@ -1020,10 +961,8 @@ func (s *RegistryService) DeleteNode(ctx context.Context, client *ent.Client, no } func (s *RegistryService) DeleteNodeVersion(ctx context.Context, client *ent.Client, nodeIDVersion string) error { - if txn := newrelic.FromContext(ctx); txn != nil { - segment := txn.StartSegment("RegistryService.DeleteNodeVersion") - defer segment.End() - } + defer tracing.TraceDefaultSegment(ctx, "RegistryService.DeleteNodeVersion")() + log.Ctx(ctx).Info().Msgf("deleting node version: %v", nodeIDVersion) db.WithTx(ctx, client, func(tx *ent.Tx) error { nv, err := tx.Client().NodeVersion.Get(ctx, uuid.MustParse(nodeIDVersion)) @@ -1069,10 +1008,8 @@ func IsPermissionError(err error) bool { } func (s *RegistryService) BanPublisher(ctx context.Context, client *ent.Client, id string) error { - if txn := newrelic.FromContext(ctx); txn != nil { - segment := txn.StartSegment("RegistryService.BanPublisher") - defer segment.End() - } + defer tracing.TraceDefaultSegment(ctx, "RegistryService.BanPublisher")() + log.Ctx(ctx).Info().Msgf("banning publisher: %v", id) pub, err := client.Publisher.Get(ctx, id) if err != nil { @@ -1124,10 +1061,8 @@ func (s *RegistryService) BanPublisher(ctx context.Context, client *ent.Client, } func (s *RegistryService) BanNode(ctx context.Context, client *ent.Client, publisherid, id string) error { - if txn := newrelic.FromContext(ctx); txn != nil { - segment := txn.StartSegment("RegistryService.BanNode") - defer segment.End() - } + defer tracing.TraceDefaultSegment(ctx, "RegistryService.BanNode")() + log.Ctx(ctx).Info().Msgf("banning publisher node: %v %v", publisherid, id) return db.WithTx(ctx, client, func(tx *ent.Tx) error { @@ -1163,10 +1098,8 @@ func (s *RegistryService) BanNode(ctx context.Context, client *ent.Client, publi } func (s *RegistryService) AssertNodeBanned(ctx context.Context, client *ent.Client, nodeID string) error { - if txn := newrelic.FromContext(ctx); txn != nil { - segment := txn.StartSegment("RegistryService.AssertNodeBanned") - defer segment.End() - } + defer tracing.TraceDefaultSegment(ctx, "RegistryService.AssertNodeBanned")() + node, err := client.Node.Get(ctx, nodeID) if ent.IsNotFound(err) { return nil @@ -1181,10 +1114,8 @@ func (s *RegistryService) AssertNodeBanned(ctx context.Context, client *ent.Clie } func (s *RegistryService) AssertPublisherBanned(ctx context.Context, client *ent.Client, publisherID string) error { - if txn := newrelic.FromContext(ctx); txn != nil { - segment := txn.StartSegment("RegistryService.AssertPublisherBanned") - defer segment.End() - } + defer tracing.TraceDefaultSegment(ctx, "RegistryService.AssertPublisherBanned")() + publisher, err := client.Publisher.Get(ctx, publisherID) if ent.IsNotFound(err) { return nil @@ -1199,10 +1130,8 @@ func (s *RegistryService) AssertPublisherBanned(ctx context.Context, client *ent } func (s *RegistryService) ReindexAllNodes(ctx context.Context, client *ent.Client) error { - if txn := newrelic.FromContext(ctx); txn != nil { - segment := txn.StartSegment("RegistryService.ReindexAllNodes") - defer segment.End() - } + defer tracing.TraceDefaultSegment(ctx, "RegistryService.ReindexAllNodes")() + log.Ctx(ctx).Info().Msgf("reindexing nodes") nodes, err := s.decorateNodeQueryWithLatestVersion(client.Node.Query()).All(ctx) if err != nil { @@ -1232,10 +1161,8 @@ func (s *RegistryService) ReindexAllNodes(ctx context.Context, client *ent.Clien var reindexLock = sync.Mutex{} func (s *RegistryService) ReindexAllNodesBackground(ctx context.Context, client *ent.Client) (err error) { - if txn := newrelic.FromContext(ctx); txn != nil { - segment := txn.StartSegment("RegistryService.ReindexAllNodesBackground") - defer segment.End() - } + defer tracing.TraceDefaultSegment(ctx, "RegistryService.ReindexAllNodesBackground")() + if !reindexLock.TryLock() { return fmt.Errorf("another reindex is in progress") } @@ -1286,10 +1213,8 @@ func (s *RegistryService) decorateNodeQueryWithLatestVersion(q *ent.NodeQuery) * func (s *RegistryService) PerformSecurityCheck( ctx context.Context, client *ent.Client, nodeVersion *ent.NodeVersion) error { - if txn := newrelic.FromContext(ctx); txn != nil { - segment := txn.StartSegment("RegistryService.PerformSecurityCheck") - defer segment.End() - } + defer tracing.TraceDefaultSegment(ctx, "RegistryService.PerformSecurityCheck")() + log.Ctx(ctx).Info().Msgf("Scanning node %s@%s w/ version ID: %s", nodeVersion.NodeID, nodeVersion.Version, nodeVersion.ID) diff --git a/tracing/tracing.go b/tracing/tracing.go new file mode 100644 index 0000000..8d70940 --- /dev/null +++ b/tracing/tracing.go @@ -0,0 +1,24 @@ +package tracing + +import ( + "context" + + "github.com/newrelic/go-agent/v3/newrelic" +) + +func TraceDefaultSegment(ctx context.Context, segmentName string) func() { + _, f := TraceSegment(ctx, segmentName) + return f +} + +func TraceSegment(ctx context.Context, segmentName string, opts ...func(*newrelic.Transaction)) (*newrelic.Transaction, func()) { + txn := newrelic.FromContext(ctx) + if txn == nil { + return nil, func() {} + } + for _, opt := range opts { + opt(txn) + } + segment := txn.StartSegment(segmentName) + return txn, func() { segment.End() } +}