Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor trace exporter to utilize enrollment details from knapsack #2122

Merged
merged 3 commits into from
Feb 26, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 0 additions & 4 deletions cmd/launcher/launcher.go
Original file line number Diff line number Diff line change
Expand Up @@ -395,10 +395,6 @@ func runLauncher(ctx context.Context, cancel func(), multiSlogger, systemMultiSl
"build", versionInfo.Revision,
)

if traceExporter != nil {
traceExporter.SetOsqueryClient(osqueryRunner)
}

// Create the control service and services that depend on it
var runner *desktopRunner.DesktopUsersProcessesRunner
var actionsQueue *actionqueue.ActionQueue
Expand Down
102 changes: 48 additions & 54 deletions pkg/traces/exporter/exporter.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@ package exporter

import (
"context"
"errors"
"log/slog"
"sync"
"sync/atomic"
Expand Down Expand Up @@ -35,18 +34,13 @@ var archAttributeMap = map[string]attribute.KeyValue{
"arm": semconv.HostArchARM32,
}

var osqueryClientRecheckInterval = 30 * time.Second

type querier interface {
Query(query string) ([]map[string]string, error)
}
var enrollmentDetailsRecheckInterval = 5 * time.Second

type TraceExporter struct {
provider *sdktrace.TracerProvider
providerLock sync.Mutex
bufSpanProcessor *bufspanprocessor.BufSpanProcessor
knapsack types.Knapsack
osqueryClient querier
slogger *slog.Logger
attrs []attribute.KeyValue // resource attributes, identifying this device + installation
attrLock sync.RWMutex
Expand Down Expand Up @@ -113,15 +107,18 @@ func NewTraceExporter(ctx context.Context, k types.Knapsack, initialTraceBuffer

t.addDeviceIdentifyingAttributes()

return t, nil
}

func (t *TraceExporter) SetOsqueryClient(client querier) {
t.osqueryClient = client
// Check if enrollment details are already available, add them immediately if so
enrollmentDetails := t.knapsack.GetEnrollmentDetails()
if hasRequiredEnrollmentDetails(enrollmentDetails) {
t.addAttributesFromEnrollmentDetails(enrollmentDetails)
} else {
// Launch a goroutine to wait for enrollment details
gowrapper.Go(context.TODO(), t.slogger, func() {
t.addAttributesFromOsquery()
})
}

gowrapper.Go(context.TODO(), t.slogger, func() {
t.addAttributesFromOsquery()
})
return t, nil
}

// addDeviceIdentifyingAttributes gets device identifiers from the server-provided
Expand Down Expand Up @@ -168,64 +165,61 @@ func (t *TraceExporter) addDeviceIdentifyingAttributes() {
}
}

// addAttributesFromOsquery retrieves device and OS details from osquery and adds them
// to our resource attributes. Since this is called on startup when the osquery client
// may not be ready yet, we perform a few retries.
func (t *TraceExporter) addAttributesFromOsquery() {
t.attrLock.Lock()
defer t.attrLock.Unlock()
// hasRequiredEnrollmentDetails reports whether details carries a non-empty
// value for every field the trace exporter turns into a resource attribute:
// the osquery version, OS name, OS version, and hostname.
func hasRequiredEnrollmentDetails(details types.EnrollmentDetails) bool {
	// A single empty field means the details are not complete yet.
	if details.OsqueryVersion == "" ||
		details.OSName == "" ||
		details.OSVersion == "" ||
		details.Hostname == "" {
		return false
	}

	return true
}

osqueryInfoQuery := `
SELECT
osquery_info.version as osquery_version,
os_version.name as os_name,
os_version.version as os_version,
system_info.hostname
FROM
os_version,
system_info,
osquery_info;
`

// The osqueryd client may not have initialized yet, so retry for up to three minutes on error.
var resp []map[string]string
var err error
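// addAttributesFromOsquery polls the knapsack until enrollment details are
// available and then adds the relevant attributes. It gives up after three
// minutes, or earlier if the exporter's context is cancelled.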
func (t *TraceExporter) addAttributesFromOsquery() {
// Wait until enrollment details are available
retryTimeout := time.Now().Add(3 * time.Minute)
for {
if time.Now().After(retryTimeout) {
err = errors.New("could not get osquery details before timeout")
break
t.slogger.Log(context.TODO(), slog.LevelWarn,
"could not get enrollment details before timeout",
)
return
}

resp, err = t.osqueryClient.Query(osqueryInfoQuery)
if err == nil && len(resp) > 0 {
break
enrollmentDetails := t.knapsack.GetEnrollmentDetails()
if hasRequiredEnrollmentDetails(enrollmentDetails) {
t.addAttributesFromEnrollmentDetails(enrollmentDetails)
return
}

select {
case <-t.ctx.Done():
t.slogger.Log(context.TODO(), slog.LevelDebug,
"trace exporter interrupted while waiting to add osquery attributes",
"trace exporter interrupted while waiting for enrollment details",
)
return
case <-time.After(osqueryClientRecheckInterval):
case <-time.After(enrollmentDetailsRecheckInterval):
continue
}
}
}

if err != nil || len(resp) == 0 {
t.slogger.Log(context.TODO(), slog.LevelWarn,
"trace exporter could not fetch osquery attributes",
"err", err,
)
return
}
func (t *TraceExporter) addAttributesFromEnrollmentDetails(details types.EnrollmentDetails) {
t.attrLock.Lock()
defer t.attrLock.Unlock()

// Add OS and system attributes from enrollment details
t.attrs = append(t.attrs,
attribute.String("launcher.osquery_version", resp[0]["osquery_version"]),
semconv.OSName(resp[0]["os_name"]),
semconv.OSVersion(resp[0]["os_version"]),
semconv.HostName(resp[0]["hostname"]),
attribute.String("launcher.osquery_version", details.OsqueryVersion),
semconv.OSName(details.OSName),
semconv.OSVersion(details.OSVersion),
semconv.HostName(details.Hostname),
)

t.slogger.Log(context.TODO(), slog.LevelDebug,
"added attributes from enrollment details",
)
}

Expand Down
67 changes: 20 additions & 47 deletions pkg/traces/exporter/exporter_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@ import (
storageci "github.com/kolide/launcher/ee/agent/storage/ci"
"github.com/kolide/launcher/ee/agent/types"
typesmocks "github.com/kolide/launcher/ee/agent/types/mocks"
"github.com/kolide/launcher/ee/localserver/mocks"
"github.com/kolide/launcher/pkg/log/multislogger"
"github.com/kolide/launcher/pkg/traces/bufspanprocessor"
"github.com/stretchr/testify/mock"
Expand Down Expand Up @@ -46,19 +45,14 @@ func TestNewTraceExporter(t *testing.T) { //nolint:paralleltest
mockKnapsack.On("TraceBatchTimeout").Return(1 * time.Minute)
mockKnapsack.On("RegisterChangeObserver", mock.Anything, keys.ExportTraces, keys.TraceSamplingRate, keys.TraceIngestServerURL, keys.DisableTraceIngestTLS, keys.TraceBatchTimeout).Return(nil)
mockKnapsack.On("Slogger").Return(multislogger.NewNopLogger())

osqueryClient := mocks.NewQuerier(t)
osqueryClient.On("Query", mock.Anything).Return([]map[string]string{
{
"osquery_version": "5.8.0",
"os_name": runtime.GOOS,
"os_version": "3.4.5",
"hostname": "Test-Hostname2",
},
}, nil)
mockKnapsack.On("GetEnrollmentDetails").Return(types.EnrollmentDetails{
OsqueryVersion: "5.8.0",
OSName: runtime.GOOS,
OSVersion: "3.4.5",
Hostname: "Test-Hostname2",
})

traceExporter, err := NewTraceExporter(context.Background(), mockKnapsack, NewInitialTraceBuffer())
traceExporter.SetOsqueryClient(osqueryClient)
require.NoError(t, err)

// Wait a few seconds to allow the attributes from enrollment details to be added
Expand All @@ -75,7 +69,6 @@ func TestNewTraceExporter(t *testing.T) { //nolint:paralleltest
require.NotNil(t, traceExporter.provider, "expected provider to be created")

mockKnapsack.AssertExpectations(t)
osqueryClient.AssertExpectations(t)
}

func TestNewTraceExporter_exportNotEnabled(t *testing.T) {
Expand Down Expand Up @@ -189,7 +182,6 @@ func Test_addDeviceIdentifyingAttributes(t *testing.T) {

traceExporter := &TraceExporter{
knapsack: mockKnapsack,
osqueryClient: mocks.NewQuerier(t),
slogger: multislogger.NewNopLogger(),
attrs: make([]attribute.KeyValue, 0),
attrLock: sync.RWMutex{},
Expand Down Expand Up @@ -231,19 +223,16 @@ func Test_addAttributesFromOsquery(t *testing.T) {
expectedOsVersion := "1.2.3"
expectedHostname := "Test-Hostname"

osqueryClient := mocks.NewQuerier(t)
osqueryClient.On("Query", mock.Anything).Return([]map[string]string{
{
"osquery_version": expectedOsqueryVersion,
"os_name": expectedOsName,
"os_version": expectedOsVersion,
"hostname": expectedHostname,
},
}, nil)
mockKnapsack := typesmocks.NewKnapsack(t)
mockKnapsack.On("GetEnrollmentDetails").Return(types.EnrollmentDetails{
OsqueryVersion: expectedOsqueryVersion,
OSName: expectedOsName,
OSVersion: expectedOsVersion,
Hostname: expectedHostname,
})

traceExporter := &TraceExporter{
knapsack: typesmocks.NewKnapsack(t),
osqueryClient: osqueryClient,
knapsack: mockKnapsack,
slogger: multislogger.NewNopLogger(),
attrs: make([]attribute.KeyValue, 0),
attrLock: sync.RWMutex{},
Expand Down Expand Up @@ -274,7 +263,6 @@ func Test_addAttributesFromOsquery(t *testing.T) {
}
}

osqueryClient.AssertExpectations(t)
}

func TestPing(t *testing.T) {
Expand Down Expand Up @@ -362,25 +350,20 @@ func TestFlagsChanged_ExportTraces(t *testing.T) { //nolint:paralleltest
mockKnapsack.On("TraceIngestServerURL").Return("https://example.com")
}

osqueryClient := mocks.NewQuerier(t)

if tt.shouldReplaceProvider {
mockKnapsack.On("ServerProvidedDataStore").Return(s)
osqueryClient.On("Query", mock.Anything).Return([]map[string]string{
{
"osquery_version": "5.9.0",
"os_name": "Windows",
"os_version": "11",
"hostname": "Test device",
},
}, nil)
mockKnapsack.On("GetEnrollmentDetails").Return(types.EnrollmentDetails{
OsqueryVersion: "5.8.0",
OSName: "Windows",
OSVersion: "11",
Hostname: "Test device",
})
}

ctx, cancel := context.WithCancel(context.Background())
traceExporter := &TraceExporter{
knapsack: mockKnapsack,
bufSpanProcessor: &bufspanprocessor.BufSpanProcessor{},
osqueryClient: osqueryClient,
slogger: multislogger.NewNopLogger(),
attrs: make([]attribute.KeyValue, 0),
attrLock: sync.RWMutex{},
Expand All @@ -400,7 +383,6 @@ func TestFlagsChanged_ExportTraces(t *testing.T) { //nolint:paralleltest

if tt.shouldReplaceProvider {
mockKnapsack.AssertExpectations(t)
osqueryClient.AssertExpectations(t)
require.Greater(t, len(traceExporter.attrs), 0)
require.NotNil(t, traceExporter.provider)
}
Expand Down Expand Up @@ -448,13 +430,11 @@ func TestFlagsChanged_TraceSamplingRate(t *testing.T) { //nolint:paralleltest
if tt.shouldReplaceProvider {
mockKnapsack.On("TraceIngestServerURL").Return("https://example.com")
}
osqueryClient := mocks.NewQuerier(t)

ctx, cancel := context.WithCancel(context.Background())
traceExporter := &TraceExporter{
knapsack: mockKnapsack,
bufSpanProcessor: &bufspanprocessor.BufSpanProcessor{},
osqueryClient: osqueryClient,
slogger: multislogger.NewNopLogger(),
attrs: make([]attribute.KeyValue, 0),
attrLock: sync.RWMutex{},
Expand Down Expand Up @@ -518,13 +498,11 @@ func TestFlagsChanged_TraceIngestServerURL(t *testing.T) { //nolint:paralleltest
t.Run(tt.testName, func(t *testing.T) {
mockKnapsack := typesmocks.NewKnapsack(t)
mockKnapsack.On("TraceIngestServerURL").Return(tt.newObservabilityIngestServerURL)
osqueryClient := mocks.NewQuerier(t)

ctx, cancel := context.WithCancel(context.Background())
traceExporter := &TraceExporter{
knapsack: mockKnapsack,
bufSpanProcessor: &bufspanprocessor.BufSpanProcessor{},
osqueryClient: osqueryClient,
slogger: multislogger.NewNopLogger(),
attrs: make([]attribute.KeyValue, 0),
attrLock: sync.RWMutex{},
Expand Down Expand Up @@ -591,15 +569,13 @@ func TestFlagsChanged_DisableTraceIngestTLS(t *testing.T) { //nolint:paralleltes
if tt.shouldReplaceProvider {
mockKnapsack.On("TraceIngestServerURL").Return("https://example.com")
}
osqueryClient := mocks.NewQuerier(t)

clientAuthenticator := newClientAuthenticator("test token", tt.currentDisableTraceIngestTLS)

ctx, cancel := context.WithCancel(context.Background())
traceExporter := &TraceExporter{
knapsack: mockKnapsack,
bufSpanProcessor: &bufspanprocessor.BufSpanProcessor{},
osqueryClient: osqueryClient,
slogger: multislogger.NewNopLogger(),
attrs: make([]attribute.KeyValue, 0),
attrLock: sync.RWMutex{},
Expand Down Expand Up @@ -668,13 +644,10 @@ func TestFlagsChanged_TraceBatchTimeout(t *testing.T) { //nolint:paralleltest
mockKnapsack.On("TraceIngestServerURL").Return("https://example.com")
}

osqueryClient := mocks.NewQuerier(t)

ctx, cancel := context.WithCancel(context.Background())
traceExporter := &TraceExporter{
knapsack: mockKnapsack,
bufSpanProcessor: &bufspanprocessor.BufSpanProcessor{},
osqueryClient: osqueryClient,
slogger: multislogger.NewNopLogger(),
attrs: make([]attribute.KeyValue, 0),
attrLock: sync.RWMutex{},
Expand Down
Loading