Skip to content

Commit

Permalink
Merge branch 'main' into becca/TestRestart-windows
Browse files Browse the repository at this point in the history
  • Loading branch information
RebeccaMahany authored Jan 7, 2025
2 parents 71d76ce + 557a918 commit cfea863
Show file tree
Hide file tree
Showing 13 changed files with 230 additions and 56 deletions.
20 changes: 16 additions & 4 deletions ee/agent/reset.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,11 @@ import (
"fmt"
"log/slog"
"os"
"strings"
"time"

"github.com/golang-jwt/jwt/v5"
"github.com/kolide/launcher/ee/agent/storage"
"github.com/kolide/launcher/ee/agent/types"
"github.com/kolide/launcher/pkg/osquery/runsimple"
"github.com/kolide/launcher/pkg/traces"
Expand Down Expand Up @@ -287,10 +289,20 @@ func currentMunemo(k types.Knapsack) (string, error) {
// as a record of the current state of this database before reset. It appends this record
// to previous records if they exist, and returns the collection ready for storage.
func prepareDatabaseResetRecords(ctx context.Context, k types.Knapsack, resetReason string) ([]byte, error) { // nolint:unused
nodeKey, err := k.ConfigStore().Get([]byte("nodeKey"))
if err != nil {
k.Slogger().Log(ctx, slog.LevelWarn, "could not get node key from store", "err", err)
nodeKeys := make([]string, 0)
for _, registrationId := range k.RegistrationIDs() {
nodeKey, err := k.ConfigStore().Get(storage.KeyByIdentifier([]byte("nodeKey"), storage.IdentifierTypeRegistration, []byte(registrationId)))
if err != nil {
k.Slogger().Log(ctx, slog.LevelWarn,
"could not get node key from store",
"registration_id", registrationId,
"err", err,
)
continue
}
nodeKeys = append(nodeKeys, string(nodeKey))
}
nodeKey := strings.Join(nodeKeys, ",")

localPubKey, err := getLocalPubKey(k)
if err != nil {
Expand Down Expand Up @@ -328,7 +340,7 @@ func prepareDatabaseResetRecords(ctx context.Context, k types.Knapsack, resetRea
}

dataToStore := dbResetRecord{
NodeKey: string(nodeKey),
NodeKey: nodeKey,
PubKeys: [][]byte{localPubKey},
Serial: string(serial),
HardwareUUID: string(hardwareUuid),
Expand Down
48 changes: 28 additions & 20 deletions ee/agent/startupsettings/writer.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ import (
"log/slog"

"github.com/kolide/launcher/ee/agent/flags/keys"
"github.com/kolide/launcher/ee/agent/storage"
agentsqlite "github.com/kolide/launcher/ee/agent/storage/sqlite"
"github.com/kolide/launcher/ee/agent/types"
"github.com/kolide/launcher/pkg/traces"
Expand Down Expand Up @@ -70,23 +71,27 @@ func (s *startupSettingsWriter) WriteSettings() error {
}
updatedFlags["use_tuf_autoupdater"] = "enabled" // Hardcode for backwards compatibility circa v1.5.3

atcConfig, err := s.extractAutoTableConstructionConfig()
if err != nil {
s.knapsack.Slogger().Log(context.TODO(), slog.LevelDebug,
"extracting auto_table_construction config",
"err", err,
)
} else {
updatedFlags["auto_table_construction"] = atcConfig
}
for _, registrationId := range s.knapsack.RegistrationIDs() {
atcConfig, err := s.extractAutoTableConstructionConfig(registrationId)
if err != nil {
s.knapsack.Slogger().Log(context.TODO(), slog.LevelDebug,
"extracting auto_table_construction config",
"err", err,
)
} else {
atcConfigKey := storage.KeyByIdentifier([]byte("auto_table_construction"), storage.IdentifierTypeRegistration, []byte(registrationId))
updatedFlags[string(atcConfigKey)] = atcConfig
}

if katcConfig, err := s.extractKATCConstructionConfig(); err != nil {
s.knapsack.Slogger().Log(context.TODO(), slog.LevelDebug,
"extracting katc_config",
"err", err,
)
} else {
updatedFlags["katc_config"] = katcConfig
if katcConfig, err := s.extractKATCConstructionConfig(registrationId); err != nil {
s.knapsack.Slogger().Log(context.TODO(), slog.LevelDebug,
"extracting katc_config",
"err", err,
)
} else {
katcConfigKey := storage.KeyByIdentifier([]byte("katc_config"), storage.IdentifierTypeRegistration, []byte(registrationId))
updatedFlags[string(katcConfigKey)] = katcConfig
}
}

if _, err := s.kvStore.Update(updatedFlags); err != nil {
Expand All @@ -112,8 +117,8 @@ func (s *startupSettingsWriter) Close() error {
return s.kvStore.Close()
}

func (s *startupSettingsWriter) extractAutoTableConstructionConfig() (string, error) {
osqConfig, err := s.knapsack.ConfigStore().Get([]byte("config"))
func (s *startupSettingsWriter) extractAutoTableConstructionConfig(registrationId string) (string, error) {
osqConfig, err := s.knapsack.ConfigStore().Get(storage.KeyByIdentifier([]byte("config"), storage.IdentifierTypeRegistration, []byte(registrationId)))
if err != nil {
return "", fmt.Errorf("could not get osquery config from store: %w", err)
}
Expand All @@ -140,10 +145,13 @@ func (s *startupSettingsWriter) extractAutoTableConstructionConfig() (string, er
return string(atcJson), nil
}

func (s *startupSettingsWriter) extractKATCConstructionConfig() (string, error) {
func (s *startupSettingsWriter) extractKATCConstructionConfig(registrationId string) (string, error) {
kolideCfg := make(map[string]string)
if err := s.knapsack.KatcConfigStore().ForEach(func(k []byte, v []byte) error {
kolideCfg[string(k)] = string(v)
key, _, identifier := storage.SplitKey(k)
if string(identifier) == registrationId {
kolideCfg[string(key)] = string(v)
}
return nil
}); err != nil {
return "", fmt.Errorf("could not get Kolide ATC config from store: %w", err)
Expand Down
4 changes: 4 additions & 0 deletions ee/agent/startupsettings/writer_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ import (
"github.com/kolide/launcher/ee/agent/flags/keys"
"github.com/kolide/launcher/ee/agent/storage/inmemory"
agentsqlite "github.com/kolide/launcher/ee/agent/storage/sqlite"
"github.com/kolide/launcher/ee/agent/types"
typesmocks "github.com/kolide/launcher/ee/agent/types/mocks"
"github.com/kolide/launcher/pkg/log/multislogger"
"github.com/stretchr/testify/mock"
Expand All @@ -35,6 +36,7 @@ func TestOpenWriter_NewDatabase(t *testing.T) {
k.On("ConfigStore").Return(inmemory.NewStore())
k.On("Slogger").Return(multislogger.NewNopLogger())
k.On("KatcConfigStore").Return(inmemory.NewStore())
k.On("RegistrationIDs").Return([]string{types.DefaultRegistrationID})

// Set up storage db, which should create the database and set all flags
s, err := OpenWriter(context.TODO(), k)
Expand Down Expand Up @@ -87,6 +89,7 @@ func TestOpenWriter_DatabaseAlreadyExists(t *testing.T) {
k.On("RegisterChangeObserver", mock.Anything, keys.UpdateChannel)
k.On("RegisterChangeObserver", mock.Anything, keys.PinnedLauncherVersion)
k.On("RegisterChangeObserver", mock.Anything, keys.PinnedOsquerydVersion)
k.On("RegistrationIDs").Return([]string{types.DefaultRegistrationID})

// Set up flag
updateChannelVal := "alpha"
Expand Down Expand Up @@ -132,6 +135,7 @@ func TestFlagsChanged(t *testing.T) {
k.On("RegisterChangeObserver", mock.Anything, keys.UpdateChannel)
k.On("RegisterChangeObserver", mock.Anything, keys.PinnedLauncherVersion)
k.On("RegisterChangeObserver", mock.Anything, keys.PinnedOsquerydVersion)
k.On("RegistrationIDs").Return([]string{types.DefaultRegistrationID})
updateChannelVal := "beta"
k.On("UpdateChannel").Return(updateChannelVal).Once()
pinnedLauncherVersion := "1.2.3"
Expand Down
47 changes: 47 additions & 0 deletions ee/agent/storage/keys.go
Original file line number Diff line number Diff line change
@@ -1,5 +1,52 @@
package storage

import "bytes"

var (
// Well-known keys
ObservabilityIngestAuthTokenKey = []byte("observability_ingest_auth_token")

// Identifier types in complex keys
IdentifierTypeRegistration = []byte("registration")

defaultIdentifier = []byte("default")
)

const (
keyDelimiter byte = 58 // :
)

func KeyByIdentifier(key []byte, identifierType []byte, identifier []byte) []byte {
// The default value is stored under `key`, without any identifier
if len(identifier) == 0 || bytes.Equal(identifier, defaultIdentifier) {
return key
}

// Key will take the form `<key>:<identifierType>:<identifier>` -- allocate
// a new key with the appropriate capacity.
totalSize := len(key) + 1 + len(identifierType) + 1 + len(identifier)
newKey := make([]byte, 0, totalSize)

newKey = append(newKey, key...)
newKey = append(newKey, keyDelimiter)
newKey = append(newKey, identifierType...)
newKey = append(newKey, keyDelimiter)
newKey = append(newKey, identifier...)

return newKey
}

func SplitKey(key []byte) ([]byte, []byte, []byte) {
if !bytes.Contains(key, []byte{keyDelimiter}) {
return key, nil, defaultIdentifier
}

// Key takes the form `<key>:<identifierType>:<identifier>` -- split
// on the keyDelimiter.
parts := bytes.SplitN(key, []byte{keyDelimiter}, 3)
if len(parts) != 3 {
return key, nil, defaultIdentifier
}

return parts[0], parts[1], parts[2]
}
92 changes: 92 additions & 0 deletions ee/agent/storage/keys_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,92 @@
package storage

import (
"testing"

"github.com/stretchr/testify/require"
)

func TestKeyByIdentifier(t *testing.T) {
t.Parallel()

for _, tt := range []struct {
testCaseName string
key []byte
identifierType []byte
identifier []byte
expectedKey []byte
}{
{
testCaseName: "default identifier",
key: []byte("nodeKey"),
identifierType: IdentifierTypeRegistration,
identifier: []byte("default"),
expectedKey: []byte("nodeKey"),
},
{
testCaseName: "empty identifier",
key: []byte("config"),
identifierType: IdentifierTypeRegistration,
identifier: nil,
expectedKey: []byte("config"),
},
{
testCaseName: "registration identifier",
key: []byte("uuid"),
identifierType: IdentifierTypeRegistration,
identifier: []byte("some-test-registration-id"),
expectedKey: []byte("uuid:registration:some-test-registration-id"),
},
} {
tt := tt
t.Run(tt.testCaseName, func(t *testing.T) {
t.Parallel()

require.Equal(t, tt.expectedKey, KeyByIdentifier(tt.key, tt.identifierType, tt.identifier))
})
}
}

func TestSplitKey(t *testing.T) {
t.Parallel()

for _, tt := range []struct {
testCaseName string
key []byte
expectedKey []byte
expectedIdentifierType []byte
expectedIdentifier []byte
}{
{
testCaseName: "default node key",
key: []byte("nodeKey"),
expectedKey: []byte("nodeKey"),
expectedIdentifierType: nil,
expectedIdentifier: []byte("default"),
},
{
testCaseName: "uuid by registration",
key: []byte("uuid:registration:some-test-registration-id"),
expectedKey: []byte("uuid"),
expectedIdentifierType: IdentifierTypeRegistration,
expectedIdentifier: []byte("some-test-registration-id"),
},
{
testCaseName: "katc table by registration",
key: []byte("katc_some_test_table:registration:another-test-registration-id"),
expectedKey: []byte("katc_some_test_table"),
expectedIdentifierType: IdentifierTypeRegistration,
expectedIdentifier: []byte("another-test-registration-id"),
},
} {
tt := tt
t.Run(tt.testCaseName, func(t *testing.T) {
t.Parallel()

splitKey, identifierType, identifier := SplitKey(tt.key)
require.Equal(t, tt.expectedKey, splitKey)
require.Equal(t, tt.expectedIdentifierType, identifierType)
require.Equal(t, tt.expectedIdentifier, identifier)
})
}
}
1 change: 1 addition & 0 deletions ee/uninstall/uninstall_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@ func TestUninstall(t *testing.T) {
k.On("EnrollSecretPath").Return(enrollSecretPath)
k.On("Slogger").Return(multislogger.NewNopLogger())
k.On("RootDirectory").Return(tempRootDir)
k.On("RegistrationIDs").Return([]string{types.DefaultRegistrationID})
testConfigStore, err := storageci.NewStore(t, multislogger.NewNopLogger(), storage.ConfigStore.String())
require.NoError(t, err, "could not create test config store")
k.On("ConfigStore").Return(testConfigStore)
Expand Down
Loading

0 comments on commit cfea863

Please sign in to comment.