From 0288928c0c5834580348ef164139f9671538e3bd Mon Sep 17 00:00:00 2001 From: RebeccaMahany Date: Fri, 7 Feb 2025 09:37:49 -0500 Subject: [PATCH] Add new store for ZTA info --- cmd/launcher/launcher.go | 11 ++- ee/agent/knapsack/knapsack.go | 4 ++ ee/agent/storage/bbolt/stores_bbolt.go | 1 + ee/agent/storage/ci/stores_ci.go | 1 + ee/agent/storage/stores.go | 1 + ee/agent/types/mocks/flags.go | 72 ++++++++++---------- ee/agent/types/mocks/knapsack.go | 92 ++++++++++++++++---------- ee/agent/types/stores.go | 1 + 8 files changed, 109 insertions(+), 74 deletions(-) diff --git a/cmd/launcher/launcher.go b/cmd/launcher/launcher.go index 190603efd..6497ef424 100644 --- a/cmd/launcher/launcher.go +++ b/cmd/launcher/launcher.go @@ -74,6 +74,7 @@ const ( desktopMenuSubsystemName = "kolide_desktop_menu" authTokensSubsystemName = "auth_tokens" katcSubsystemName = "katc_config" // Kolide ATC + ztaInfoSubsystemName = "zta_info" ) // runLauncher is the entry point into running launcher. It creates a @@ -482,8 +483,8 @@ func runLauncher(ctx context.Context, cancel func(), multiSlogger, systemMultiSl return fmt.Errorf("failed to register auth token consumer: %w", err) } - // begin log shipping and subsribe to token updates - // nil check incase it failed to create for some reason + // begin log shipping and subscribe to token updates + // nil check in case it failed to create for some reason if logShipper != nil { controlService.RegisterSubscriber(authTokensSubsystemName, logShipper) } @@ -504,6 +505,12 @@ func runLauncher(ctx context.Context, cancel func(), multiSlogger, systemMultiSl // information is made available from server_data (e.g. on a fresh install) metadataWriter.Ping() } + + // Set up consumer to receive ZTA info from the control server + ztaInfoConsumer := keyvalueconsumer.New(k.ZtaInfoStore()) + if err := controlService.RegisterConsumer(ztaInfoSubsystemName, ztaInfoConsumer); err != nil { + return fmt.Errorf("failed to register ZTA info consumer: %w", err) + } } runEECode := k.ControlServerURL() != "" || k.IAmBreakingEELicense() diff --git a/ee/agent/knapsack/knapsack.go b/ee/agent/knapsack/knapsack.go index f52c63ce1..4f3c1f77e 100644 --- a/ee/agent/knapsack/knapsack.go +++ b/ee/agent/knapsack/knapsack.go @@ -176,6 +176,10 @@ func (k *knapsack) LauncherHistoryStore() types.KVStore { return k.getKVStore(storage.LauncherHistoryStore) } +func (k *knapsack) ZtaInfoStore() types.KVStore { + return k.getKVStore(storage.ZtaInfoStore) +} + func (k *knapsack) SetLauncherWatchdogEnabled(enabled bool) error { return k.flags.SetLauncherWatchdogEnabled(enabled) } diff --git a/ee/agent/storage/bbolt/stores_bbolt.go b/ee/agent/storage/bbolt/stores_bbolt.go index 917f0cfa4..dc8c041d6 100644 --- a/ee/agent/storage/bbolt/stores_bbolt.go +++ b/ee/agent/storage/bbolt/stores_bbolt.go @@ -34,6 +34,7 @@ func MakeStores(ctx context.Context, slogger *slog.Logger, db *bbolt.DB) (map[st storage.TokenStore, storage.ControlServerActionsStore, storage.LauncherHistoryStore, + storage.ZtaInfoStore, } for _, storeName := range storeNames { diff --git a/ee/agent/storage/ci/stores_ci.go b/ee/agent/storage/ci/stores_ci.go index e16d72e4a..351a74345 100644 --- a/ee/agent/storage/ci/stores_ci.go +++ b/ee/agent/storage/ci/stores_ci.go @@ -32,6 +32,7 @@ func MakeStores(t *testing.T, slogger *slog.Logger, db *bbolt.DB) (map[storage.S storage.ServerProvidedDataStore, storage.TokenStore, storage.LauncherHistoryStore, + storage.ZtaInfoStore, } if os.Getenv("CI") == "true" { diff --git a/ee/agent/storage/stores.go b/ee/agent/storage/stores.go index 1bc37b5c2..50204501c 100644 --- a/ee/agent/storage/stores.go +++ b/ee/agent/storage/stores.go @@ -19,6 +19,7 @@ const ( TokenStore Store = "token_store" // The store used for holding bearer auth tokens, e.g. the ones used to authenticate with the observability ingest server. ControlServerActionsStore Store = "action_store" // The store used for storing actions sent by control server. LauncherHistoryStore Store = "launcher_history" // The store used for storing launcher start time history currently. + ZtaInfoStore Store = "zta_info" // The store used for storing ZTA info about this device ) func (storeType Store) String() string { diff --git a/ee/agent/types/mocks/flags.go b/ee/agent/types/mocks/flags.go index b8bc68610..a5066e6c3 100644 --- a/ee/agent/types/mocks/flags.go +++ b/ee/agent/types/mocks/flags.go @@ -378,24 +378,6 @@ func (_m *Flags) ForceControlSubsystems() bool { return r0 } -// TableGenerateTimeout provides a mock function with given fields: -func (_m *Flags) TableGenerateTimeout() time.Duration { - ret := _m.Called() - - if len(ret) == 0 { - panic("no return value specified for TableGenerateTimeout") - } - - var r0 time.Duration - if rf, ok := ret.Get(0).(func() time.Duration); ok { - r0 = rf() - } else { - r0 = ret.Get(0).(time.Duration) - } - - return r0 -} - // IAmBreakingEELicense provides a mock function with given fields: func (_m *Flags) IAmBreakingEELicense() bool { ret := _m.Called() @@ -1104,24 +1086,6 @@ func (_m *Flags) SetForceControlSubsystems(force bool) error { return r0 } -// SetTableGenerateTimeout provides a mock function with given fields: interval -func (_m *Flags) SetTableGenerateTimeout(interval time.Duration) error { - ret := _m.Called(interval) - - if len(ret) == 0 { - panic("no return value specified for SetTableGenerateTimeout") - } - - var r0 error - if rf, ok := ret.Get(0).(func(time.Duration) error); ok { - r0 = rf(interval) - } else { - r0 = ret.Error(0) - } - - return r0 -} - // SetInModernStandby provides a mock function with given fields: enabled func (_m *Flags) SetInModernStandby(enabled bool) error { ret := _m.Called(enabled) @@ -1397,6 +1361,24 @@ func (_m *Flags) SetSystrayRestartEnabled(enabled bool) error { return r0 } +// SetTableGenerateTimeout provides a mock function with given fields: interval +func (_m *Flags) SetTableGenerateTimeout(interval time.Duration) error { + ret := _m.Called(interval) + + if len(ret) == 0 { + panic("no return value specified for SetTableGenerateTimeout") + } + + var r0 error + if rf, ok := ret.Get(0).(func(time.Duration) error); ok { + r0 = rf(interval) + } else { + r0 = ret.Error(0) + } + + return r0 +} + // SetTraceBatchTimeout provides a mock function with given fields: duration func (_m *Flags) SetTraceBatchTimeout(duration time.Duration) error { ret := _m.Called(duration) @@ -1600,6 +1582,24 @@ func (_m *Flags) SystrayRestartEnabled() bool { return r0 } +// TableGenerateTimeout provides a mock function with given fields: +func (_m *Flags) TableGenerateTimeout() time.Duration { + ret := _m.Called() + + if len(ret) == 0 { + panic("no return value specified for TableGenerateTimeout") + } + + var r0 time.Duration + if rf, ok := ret.Get(0).(func() time.Duration); ok { + r0 = rf() + } else { + r0 = ret.Get(0).(time.Duration) + } + + return r0 +} + // TraceBatchTimeout provides a mock function with given fields: func (_m *Flags) TraceBatchTimeout() time.Duration { ret := _m.Called() diff --git a/ee/agent/types/mocks/knapsack.go b/ee/agent/types/mocks/knapsack.go index 0aa152a79..028a69308 100644 --- a/ee/agent/types/mocks/knapsack.go +++ b/ee/agent/types/mocks/knapsack.go @@ -526,24 +526,6 @@ func (_m *Knapsack) ForceControlSubsystems() bool { return r0 } -// TableGenerateTimeout provides a mock function with given fields: -func (_m *Knapsack) TableGenerateTimeout() time.Duration { - ret := _m.Called() - - if len(ret) == 0 { - panic("no return value specified for TableGenerateTimeout") - } - - var r0 time.Duration - if rf, ok := ret.Get(0).(func() time.Duration); ok { - r0 = rf() - } else { - r0 = ret.Get(0).(time.Duration) - } - - return r0 -} - // GetRunID provides a mock function with given fields: func (_m *Knapsack) GetRunID() string { ret := _m.Called() @@ -1516,24 +1498,6 @@ func (_m *Knapsack) SetForceControlSubsystems(force bool) error { return r0 } -// SetTableGenerateTimeout provides a mock function with given fields: interval -func (_m *Knapsack) SetTableGenerateTimeout(interval time.Duration) error { - ret := _m.Called(interval) - - if len(ret) == 0 { - panic("no return value specified for SetTableGenerateTimeout") - } - - var r0 error - if rf, ok := ret.Get(0).(func(time.Duration) error); ok { - r0 = rf(interval) - } else { - r0 = ret.Error(0) - } - - return r0 -} - // SetInModernStandby provides a mock function with given fields: enabled func (_m *Knapsack) SetInModernStandby(enabled bool) error { ret := _m.Called(enabled) @@ -1814,6 +1778,24 @@ func (_m *Knapsack) SetSystrayRestartEnabled(enabled bool) error { return r0 } +// SetTableGenerateTimeout provides a mock function with given fields: interval +func (_m *Knapsack) SetTableGenerateTimeout(interval time.Duration) error { + ret := _m.Called(interval) + + if len(ret) == 0 { + panic("no return value specified for SetTableGenerateTimeout") + } + + var r0 error + if rf, ok := ret.Get(0).(func(time.Duration) error); ok { + r0 = rf(interval) + } else { + r0 = ret.Error(0) + } + + return r0 +} + // SetTraceBatchTimeout provides a mock function with given fields: duration func (_m *Knapsack) SetTraceBatchTimeout(duration time.Duration) error { ret := _m.Called(duration) @@ -2097,6 +2079,24 @@ func (_m *Knapsack) SystrayRestartEnabled() bool { return r0 } +// TableGenerateTimeout provides a mock function with given fields: +func (_m *Knapsack) TableGenerateTimeout() time.Duration { + ret := _m.Called() + + if len(ret) == 0 { + panic("no return value specified for TableGenerateTimeout") + } + + var r0 time.Duration + if rf, ok := ret.Get(0).(func() time.Duration); ok { + r0 = rf() + } else { + r0 = ret.Get(0).(time.Duration) + } + + return r0 +} + // TokenStore provides a mock function with given fields: func (_m *Knapsack) TokenStore() types.GetterSetterDeleterIteratorUpdaterCounterAppender { ret := _m.Called() @@ -2315,6 +2315,26 @@ func (_m *Knapsack) WatchdogUtilizationLimitPercent() int { return r0 } +// ZtaInfoStore provides a mock function with given fields: +func (_m *Knapsack) ZtaInfoStore() types.GetterSetterDeleterIteratorUpdaterCounterAppender { + ret := _m.Called() + + if len(ret) == 0 { + panic("no return value specified for ZtaInfoStore") + } + + var r0 types.GetterSetterDeleterIteratorUpdaterCounterAppender + if rf, ok := ret.Get(0).(func() types.GetterSetterDeleterIteratorUpdaterCounterAppender); ok { + r0 = rf() + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(types.GetterSetterDeleterIteratorUpdaterCounterAppender) + } + } + + return r0 +} + // NewKnapsack creates a new instance of Knapsack. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. // The first argument is typically a *testing.T value. func NewKnapsack(t interface { diff --git a/ee/agent/types/stores.go b/ee/agent/types/stores.go index afaeaaaa9..793caf629 100644 --- a/ee/agent/types/stores.go +++ b/ee/agent/types/stores.go @@ -18,4 +18,5 @@ type Stores interface { ServerProvidedDataStore() KVStore TokenStore() KVStore LauncherHistoryStore() KVStore + ZtaInfoStore() KVStore }