diff --git a/internal/dev_server/api/api.yaml b/internal/dev_server/api/api.yaml index 71bcd413..01b1f717 100644 --- a/internal/dev_server/api/api.yaml +++ b/internal/dev_server/api/api.yaml @@ -9,6 +9,25 @@ info: servers: - url: "http" paths: + /dev/backup: + get: + summary: get the backup + operationId: getBackup + responses: + 200: + $ref: "#/components/responses/DbBackup" + post: + summary: post backup + operationId: restoreBackup + requestBody: + content: + application/vnd.sqlite3: + schema: + type: string + format: binary + responses: + 200: + description: 'Backup restored' /dev/projects: get: summary: lists all projects that have been configured for the dev server @@ -142,7 +161,7 @@ paths: description: limit the number of environments returned required: false schema: - type: integer + type: integer responses: 200: description: OK. List of environments @@ -286,6 +305,13 @@ components: application/json: schema: $ref: "#/components/schemas/Project" + DbBackup: + description: A backup of the local sqlite database + content: + application/vnd.sqlite3: + schema: + type: string + format: binary ErrorResponse: description: Error response object content: diff --git a/internal/dev_server/api/get_backup.go b/internal/dev_server/api/get_backup.go new file mode 100644 index 00000000..b1babc14 --- /dev/null +++ b/internal/dev_server/api/get_backup.go @@ -0,0 +1,21 @@ +package api + +import ( + "context" + + "github.com/launchdarkly/ldcli/internal/dev_server/model" +) + +func (s server) GetBackup(ctx context.Context, request GetBackupRequestObject) (GetBackupResponseObject, error) { + store := model.StoreFromContext(ctx) + backup, size, err := store.CreateBackup(ctx) + if err != nil { + return nil, err + } + + return GetBackup200ApplicationvndSqlite3Response{DbBackupApplicationvndSqlite3Response{ + Body: backup, + ContentLength: size, + }}, nil + +} diff --git a/internal/dev_server/api/restore_backup.go b/internal/dev_server/api/restore_backup.go new file mode 100644 index 00000000..f1fe6e88 --- /dev/null +++ b/internal/dev_server/api/restore_backup.go @@ -0,0 +1,14 @@ +package api + +import ( + "context" + "github.com/launchdarkly/ldcli/internal/dev_server/model" +) + +func (s server) RestoreBackup(ctx context.Context, request RestoreBackupRequestObject) (RestoreBackupResponseObject, error) { + err := model.RestoreDb(ctx, request.Body) + if err != nil { + return nil, err + } + return RestoreBackup200Response{}, nil +} diff --git a/internal/dev_server/api/server.gen.go b/internal/dev_server/api/server.gen.go index fd7f3b14..9d359b1a 100644 --- a/internal/dev_server/api/server.gen.go +++ b/internal/dev_server/api/server.gen.go @@ -7,6 +7,7 @@ import ( "context" "encoding/json" "fmt" + "io" "net/http" "github.com/gorilla/mux" @@ -170,6 +171,12 @@ type PutOverrideFlagJSONRequestBody = FlagValue // ServerInterface represents all server handlers. type ServerInterface interface { + // get the backup + // (GET /dev/backup) + GetBackup(w http.ResponseWriter, r *http.Request) + // post backup + // (POST /dev/backup) + RestoreBackup(w http.ResponseWriter, r *http.Request) // lists all projects that have been configured for the dev server // (GET /dev/projects) GetProjects(w http.ResponseWriter, r *http.Request) @@ -205,6 +212,36 @@ type ServerInterfaceWrapper struct { type MiddlewareFunc func(http.Handler) http.Handler +// GetBackup operation middleware +func (siw *ServerInterfaceWrapper) GetBackup(w http.ResponseWriter, r *http.Request) { + ctx := r.Context() + + handler := http.Handler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + siw.Handler.GetBackup(w, r) + })) + + for _, middleware := range siw.HandlerMiddlewares { + handler = middleware(handler) + } + + handler.ServeHTTP(w, r.WithContext(ctx)) +} + +// RestoreBackup operation middleware +func (siw *ServerInterfaceWrapper) RestoreBackup(w http.ResponseWriter, r *http.Request) { + ctx := r.Context() + + handler := http.Handler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + siw.Handler.RestoreBackup(w, r) + })) + + for _, middleware := range siw.HandlerMiddlewares { + handler = middleware(handler) + } + + handler.ServeHTTP(w, r.WithContext(ctx)) +} + // GetProjects operation middleware func (siw *ServerInterfaceWrapper) GetProjects(w http.ResponseWriter, r *http.Request) { ctx := r.Context() @@ -585,6 +622,10 @@ func HandlerWithOptions(si ServerInterface, options GorillaServerOptions) http.H ErrorHandlerFunc: options.ErrorHandlerFunc, } + r.HandleFunc(options.BaseURL+"/dev/backup", wrapper.GetBackup).Methods("GET") + + r.HandleFunc(options.BaseURL+"/dev/backup", wrapper.RestoreBackup).Methods("POST") + r.HandleFunc(options.BaseURL+"/dev/projects", wrapper.GetProjects).Methods("GET") r.HandleFunc(options.BaseURL+"/dev/projects/{projectKey}", wrapper.DeleteProject).Methods("DELETE") @@ -604,6 +645,12 @@ func HandlerWithOptions(si ServerInterface, options GorillaServerOptions) http.H return r } +type DbBackupApplicationvndSqlite3Response struct { + Body io.Reader + + ContentLength int64 +} + type ErrorResponseJSONResponse struct { // Code specific error code encountered Code string `json:"code"` @@ -622,6 +669,47 @@ type FlagOverrideJSONResponse struct { type ProjectJSONResponse Project +type GetBackupRequestObject struct { +} + +type GetBackupResponseObject interface { + VisitGetBackupResponse(w http.ResponseWriter) error +} + +type GetBackup200ApplicationvndSqlite3Response struct { + DbBackupApplicationvndSqlite3Response +} + +func (response GetBackup200ApplicationvndSqlite3Response) VisitGetBackupResponse(w http.ResponseWriter) error { + w.Header().Set("Content-Type", "application/vnd.sqlite3") + if response.ContentLength != 0 { + w.Header().Set("Content-Length", fmt.Sprint(response.ContentLength)) + } + w.WriteHeader(200) + + if closer, ok := response.Body.(io.ReadCloser); ok { + defer closer.Close() + } + _, err := io.Copy(w, response.Body) + return err +} + +type RestoreBackupRequestObject struct { + Body io.Reader +} + +type RestoreBackupResponseObject interface { + VisitRestoreBackupResponse(w http.ResponseWriter) error +} + +type RestoreBackup200Response struct { +} + +func (response RestoreBackup200Response) VisitRestoreBackupResponse(w http.ResponseWriter) error { + w.WriteHeader(200) + return nil +} + type GetProjectsRequestObject struct { } @@ -856,6 +944,12 @@ func (response PutOverrideFlag400JSONResponse) VisitPutOverrideFlagResponse(w ht // StrictServerInterface represents all server handlers. type StrictServerInterface interface { + // get the backup + // (GET /dev/backup) + GetBackup(ctx context.Context, request GetBackupRequestObject) (GetBackupResponseObject, error) + // post backup + // (POST /dev/backup) + RestoreBackup(ctx context.Context, request RestoreBackupRequestObject) (RestoreBackupResponseObject, error) // lists all projects that have been configured for the dev server // (GET /dev/projects) GetProjects(ctx context.Context, request GetProjectsRequestObject) (GetProjectsResponseObject, error) @@ -911,6 +1005,56 @@ type strictHandler struct { options StrictHTTPServerOptions } +// GetBackup operation middleware +func (sh *strictHandler) GetBackup(w http.ResponseWriter, r *http.Request) { + var request GetBackupRequestObject + + handler := func(ctx context.Context, w http.ResponseWriter, r *http.Request, request interface{}) (interface{}, error) { + return sh.ssi.GetBackup(ctx, request.(GetBackupRequestObject)) + } + for _, middleware := range sh.middlewares { + handler = middleware(handler, "GetBackup") + } + + response, err := handler(r.Context(), w, r, request) + + if err != nil { + sh.options.ResponseErrorHandlerFunc(w, r, err) + } else if validResponse, ok := response.(GetBackupResponseObject); ok { + if err := validResponse.VisitGetBackupResponse(w); err != nil { + sh.options.ResponseErrorHandlerFunc(w, r, err) + } + } else if response != nil { + sh.options.ResponseErrorHandlerFunc(w, r, fmt.Errorf("unexpected response type: %T", response)) + } +} + +// RestoreBackup operation middleware +func (sh *strictHandler) RestoreBackup(w http.ResponseWriter, r *http.Request) { + var request RestoreBackupRequestObject + + request.Body = r.Body + + handler := func(ctx context.Context, w http.ResponseWriter, r *http.Request, request interface{}) (interface{}, error) { + return sh.ssi.RestoreBackup(ctx, request.(RestoreBackupRequestObject)) + } + for _, middleware := range sh.middlewares { + handler = middleware(handler, "RestoreBackup") + } + + response, err := handler(r.Context(), w, r, request) + + if err != nil { + sh.options.ResponseErrorHandlerFunc(w, r, err) + } else if validResponse, ok := response.(RestoreBackupResponseObject); ok { + if err := validResponse.VisitRestoreBackupResponse(w); err != nil { + sh.options.ResponseErrorHandlerFunc(w, r, err) + } + } else if response != nil { + sh.options.ResponseErrorHandlerFunc(w, r, fmt.Errorf("unexpected response type: %T", response)) + } +} + // GetProjects operation middleware func (sh *strictHandler) GetProjects(w http.ResponseWriter, r *http.Request) { var request GetProjectsRequestObject diff --git a/internal/dev_server/db/backup/sqlite_backup_test.go b/internal/dev_server/db/backup/sqlite_backup_test.go new file mode 100644 index 00000000..7f769b4b --- /dev/null +++ b/internal/dev_server/db/backup/sqlite_backup_test.go @@ -0,0 +1,136 @@ +package backup_test + +import ( + "context" + "database/sql" + "github.com/google/uuid" + "os" + "strconv" + + "github.com/launchdarkly/ldcli/internal/dev_server/db/backup" + "github.com/stretchr/testify/require" + "testing" +) + +type myTable struct { + key string + someValue string +} + +func TestDbBackup(t *testing.T) { + ctx := context.Background() + dbPath := "test_source_for_backup.db" + + original, err := sql.Open("sqlite3", dbPath) + require.NoError(t, err) + + defer func() { + require.NoError(t, os.Remove(dbPath)) + }() + + originalResults := createAndSeedTable(ctx, original, t) + dataSize := len(originalResults) + + manager := backup.NewManager(dbPath, "main", "ld_cli_backup_test*.bak", "ld_cli_restore_test*.bak") + + for i := 0; i < 5; i++ { + t.Run("Backup_"+strconv.Itoa(i), func(t *testing.T) { + result, err := manager.MakeBackupFile(ctx) + require.NoError(t, err) + require.NotEqual(t, "", result) + + backupDb, err := sql.Open("sqlite3", result) + require.NoError(t, err) + + resultsFromBackup, err := getResults(backupDb) + require.NoError(t, err) + require.Len(t, resultsFromBackup, dataSize) + + require.Equal(t, originalResults, resultsFromBackup) + + require.NoError(t, backupDb.Close()) + }) + } +} + +func TestDbRestore(t *testing.T) { + ctx := context.Background() + dbToRestorePath := "test_source_for_restore.db" + originalDbPathOrigin := "test_source_for_restore_origin.db" + + original, err := sql.Open("sqlite3", dbToRestorePath) + require.NoError(t, err) + + defer func() { + require.NoError(t, os.Remove(dbToRestorePath)) + }() + + originals := createAndSeedTable(ctx, original, t) + require.NoError(t, original.Close()) + require.NotEmpty(t, originals) + + t.Run("Valid restore completed", func(t *testing.T) { + manager := backup.NewManager(originalDbPathOrigin, "main", "ld_cli_backup_test*.bak", "ld_cli_restore_test*.bak") + manager.AddValidationQueries("select count(1) from my_table") + + f, err := os.Open(dbToRestorePath) + require.NoError(t, err) + + restoredLocation, err := manager.RestoreToFile(ctx, f) + require.NoError(t, err) + require.NotEqual(t, "", restoredLocation) + }) + + t.Run("Query validation fails should error", func(t *testing.T) { + manager := backup.NewManager(originalDbPathOrigin, "main", "ld_cli_backup_test*.bak", "ld_cli_restore_test*.bak") + manager.AddValidationQueries("select count(1) from a_non_existent_table") + + f, err := os.Open(dbToRestorePath) + require.NoError(t, err) + + restoredLocation, err := manager.RestoreToFile(ctx, f) + require.Error(t, err) + require.Equal(t, "", restoredLocation) + }) + +} + +const createTable = `CREATE TABLE IF NOT EXISTS my_table ( +key text PRIMARY KEY, +some_value text NOT NULL +)` +const insertStatement = `INSERT INTO my_table (key, some_value) VALUES (?, ?)` + +func createAndSeedTable(ctx context.Context, db *sql.DB, t *testing.T) []myTable { + _, err := db.ExecContext(ctx, createTable) + require.NoError(t, err) + + dataSize := 50 + for i := 0; i < dataSize; i++ { + _, err = db.ExecContext(ctx, insertStatement, uuid.New(), uuid.New()) + require.NoError(t, err) + } + + originalResults, err := getResults(db) + require.NoError(t, err) + require.Len(t, originalResults, dataSize) + return originalResults +} + +func getResults(db *sql.DB) ([]myTable, error) { + rows, err := db.Query("select key, some_value from my_table order by key desc") + if err != nil { + return nil, err + } + defer rows.Close() + var res []myTable + for rows.Next() { + r := myTable{} + err = rows.Scan(&r.key, &r.someValue) + if err != nil { + return nil, err + } + res = append(res, r) + } + return res, nil +} diff --git a/internal/dev_server/db/backup/sqllite_backup.go b/internal/dev_server/db/backup/sqllite_backup.go new file mode 100644 index 00000000..d37ef710 --- /dev/null +++ b/internal/dev_server/db/backup/sqllite_backup.go @@ -0,0 +1,206 @@ +package backup + +import ( + "context" + "database/sql" + "fmt" + sqllite "github.com/mattn/go-sqlite3" + "github.com/pkg/errors" + "io" + "log" + "os" + "sync" + "sync/atomic" +) + +var c atomic.Int32 + +type Manager struct { + dbPath string + dbName string + backupFilePattern string + restoreFilePattern string + driverName string + validationQueries []string + mutex sync.Mutex + conns []*sqllite.SQLiteConn +} + +// NewManager creates a new backup manager +// Each instance of a Manager can run 1 backup or restore at a time (internally uses a mutex) +// It is safe to create multiple instances of Manager which could run Backups/Restores concurrently +func NewManager(dbPath string, dbName string, backupFilePattern string, restoreFilePattern string) *Manager { + count := c.Add(1) + m := &Manager{ + dbPath: dbPath, + dbName: dbName, + backupFilePattern: backupFilePattern, + restoreFilePattern: restoreFilePattern, + driverName: fmt.Sprintf("sqlite3-backups-%d", count), + conns: make([]*sqllite.SQLiteConn, 0), + validationQueries: make([]string, 0), + mutex: sync.Mutex{}, + } + sql.Register(m.driverName, &sqllite.SQLiteDriver{ + ConnectHook: func(conn *sqllite.SQLiteConn) error { + m.conns = append(m.conns, conn) + return nil + }, + }) + return m +} + +// AddValidationQueries Adds queries to run on a restored database to ensure meets some criteria +// These queries should cause db.Exec to return error if the database imported is invalid. +// For example, if the database does not have a vital table +func (m *Manager) AddValidationQueries(queries ...string) { + m.validationQueries = append(m.validationQueries, queries...) +} + +// assumes is that the caller has the Manager's mutex. +func (m *Manager) resetConnections() { + m.conns = make([]*sqllite.SQLiteConn, 0) +} + +// connectToDb opens a sqlite connection and pings the database to populate the underlying sqlite connection +// assumes is that the caller has the Manager's mutex. +func (m *Manager) connectToDb(ctx context.Context, path string) (*sql.DB, error) { + db, err := sql.Open(m.driverName, path) + if err != nil { + return nil, errors.Wrap(err, "open database") + } + + connCountBefore := len(m.conns) + + err = db.PingContext(ctx) + if err != nil { + return nil, errors.Wrap(err, "connecting to database database") + } + + // We expect there to only ever be 1 or 2 connections + expectedDbConnectionCount := connCountBefore + 1 + if len(m.conns) != expectedDbConnectionCount { + return nil, errors.New("error setting up backup connection: database connection count mismatch") + } + + return db, nil +} + +// RestoreToFile returns a string path of the sqlite database restored from the stream +func (m *Manager) RestoreToFile(ctx context.Context, stream io.Reader) (string, error) { + m.mutex.Lock() + defer m.mutex.Unlock() + + // Make a temp file to copy into + tempFile, err := os.CreateTemp("", m.restoreFilePattern) + if err != nil { + return "", errors.Wrapf(err, "unable to create temp file") + } + _, err = io.Copy(tempFile, stream) + if err != nil { + return "", errors.Wrapf(err, "unable to write to temp file") + } + + // connect to db + copiedDb, err := m.connectToDb(ctx, tempFile.Name()) + if err != nil { + return "", errors.Wrapf(err, "unable to connect to database") + } + defer func(copiedDb *sql.DB) { + err := copiedDb.Close() + if err != nil { + log.Println(err) + } + }(copiedDb) + + for _, query := range m.validationQueries { + _, err := copiedDb.ExecContext(ctx, query) + if err != nil { + return "", errors.Wrapf(err, "restored db failed validation query: %s", query) + } + } + + return tempFile.Name(), nil +} + +// MakeBackupFile returns a string path of the sqlite database backup +func (m *Manager) MakeBackupFile(ctx context.Context) (string, error) { + m.mutex.Lock() + defer m.mutex.Unlock() + + // clear out any connections from previous backups + m.resetConnections() + + // Make a temp file to back-up into + tempFile, err := os.CreateTemp("", m.backupFilePattern) + if err != nil { + return "", errors.Wrapf(err, "unable to create temp file") + } + + backupPath := tempFile.Name() + if err := tempFile.Close(); err != nil { + return "", errors.Wrapf(err, "unable to close temp file") + } + + // connect to source to populate sqlite connection + sourceDb, err := m.connectToDb(ctx, m.dbPath) + if err != nil { + return "", errors.Wrap(err, "open source database") + } + + defer func() { + err := sourceDb.Close() + if err != nil { + log.Printf("unable to close source connection: %s", err) + } + }() + + // connect to backup to populate sqlite connection + backupDb, err := m.connectToDb(ctx, backupPath) + if err != nil { + return "", errors.Wrap(err, "open backup database") + } + + defer func() { + err := backupDb.Close() + if err != nil { + log.Printf("unable to close source connection: %s", err) + } + }() + + // validate connection length + if len(m.conns) != 2 { + return "", errors.Wrapf(err, "no connection found to backup") + } + var srcDbConn = m.conns[0] + var backupDbConn = m.conns[1] + + err = runBackup(backupDbConn, srcDbConn, m.dbName) + if err != nil { + return "", errors.Wrapf(err, "unable to start backup db at %s", backupPath) + } + return backupPath, nil +} + +func runBackup(backupDbConn *sqllite.SQLiteConn, srcDbConn *sqllite.SQLiteConn, dbName string) error { + backup, err := backupDbConn.Backup(dbName, srcDbConn, dbName) + if err != nil { + return errors.Wrap(err, "unable to start backup db") + } + defer func(backup *sqllite.SQLiteBackup) { + err := backup.Close() + if err != nil { + log.Printf("unable to close backup connection: %s", err) + } + }(backup) + + var isDone = false + var stepError error = nil + for !isDone { + isDone, stepError = backup.Step(1) + if stepError != nil { + return errors.Wrap(stepError, "unable to backup db at %s") + } + } + return nil +} diff --git a/internal/dev_server/db/sqlite.go b/internal/dev_server/db/sqlite.go index 722dc727..a6289ac8 100644 --- a/internal/dev_server/db/sqlite.go +++ b/internal/dev_server/db/sqlite.go @@ -4,21 +4,26 @@ import ( "context" "database/sql" "encoding/json" - _ "github.com/mattn/go-sqlite3" "github.com/pkg/errors" + "io" + "os" "github.com/launchdarkly/go-sdk-common/v3/ldvalue" + "github.com/launchdarkly/ldcli/internal/dev_server/db/backup" "github.com/launchdarkly/ldcli/internal/dev_server/model" ) type Sqlite struct { database *sql.DB + dbPath string + + backupManager *backup.Manager } -var _ model.Store = Sqlite{} +var _ model.Store = &Sqlite{} -func (s Sqlite) GetDevProjectKeys(ctx context.Context) ([]string, error) { +func (s *Sqlite) GetDevProjectKeys(ctx context.Context) ([]string, error) { rows, err := s.database.Query("select key from projects") if err != nil { return nil, err @@ -35,7 +40,7 @@ func (s Sqlite) GetDevProjectKeys(ctx context.Context) ([]string, error) { return keys, nil } -func (s Sqlite) GetDevProject(ctx context.Context, key string) (*model.Project, error) { +func (s *Sqlite) GetDevProject(ctx context.Context, key string) (*model.Project, error) { var project model.Project var contextData string var flagStateData string @@ -66,7 +71,7 @@ func (s Sqlite) GetDevProject(ctx context.Context, key string) (*model.Project, return &project, nil } -func (s Sqlite) UpdateProject(ctx context.Context, project model.Project) (bool, error) { +func (s *Sqlite) UpdateProject(ctx context.Context, project model.Project) (bool, error) { flagsStateJson, err := json.Marshal(project.AllFlagsState) if err != nil { return false, errors.Wrap(err, "unable to marshal flags state when updating project") @@ -119,7 +124,7 @@ func (s Sqlite) UpdateProject(ctx context.Context, project model.Project) (bool, return true, nil } -func (s Sqlite) DeleteDevProject(ctx context.Context, key string) (bool, error) { +func (s *Sqlite) DeleteDevProject(ctx context.Context, key string) (bool, error) { result, err := s.database.Exec("DELETE FROM projects where key=?", key) if err != nil { return false, err @@ -153,7 +158,7 @@ func InsertAvailableVariations(ctx context.Context, tx *sql.Tx, project model.Pr return nil } -func (s Sqlite) InsertProject(ctx context.Context, project model.Project) (err error) { +func (s *Sqlite) InsertProject(ctx context.Context, project model.Project) (err error) { flagsStateJson, err := json.Marshal(project.AllFlagsState) if err != nil { return errors.Wrap(err, "unable to marshal flags state when writing project") @@ -203,7 +208,7 @@ VALUES (?, ?, ?, ?, ?) return tx.Commit() } -func (s Sqlite) GetAvailableVariationsForProject(ctx context.Context, projectKey string) (map[string][]model.Variation, error) { +func (s *Sqlite) GetAvailableVariationsForProject(ctx context.Context, projectKey string) (map[string][]model.Variation, error) { rows, err := s.database.QueryContext(ctx, ` SELECT flag_key, id, name, description, value FROM available_variations @@ -250,7 +255,7 @@ func (s Sqlite) GetAvailableVariationsForProject(ctx context.Context, projectKey return availableVariations, nil } -func (s Sqlite) GetOverridesForProject(ctx context.Context, projectKey string) (model.Overrides, error) { +func (s *Sqlite) GetOverridesForProject(ctx context.Context, projectKey string) (model.Overrides, error) { rows, err := s.database.QueryContext(ctx, ` SELECT flag_key, active, value, version FROM overrides @@ -295,7 +300,7 @@ func (s Sqlite) GetOverridesForProject(ctx context.Context, projectKey string) ( return overrides, nil } -func (s Sqlite) UpsertOverride(ctx context.Context, override model.Override) (model.Override, error) { +func (s *Sqlite) UpsertOverride(ctx context.Context, override model.Override) (model.Override, error) { valueJson, err := override.Value.MarshalJSON() if err != nil { return model.Override{}, errors.Wrap(err, "unable to marshal override value when writing override") @@ -324,7 +329,7 @@ func (s Sqlite) UpsertOverride(ctx context.Context, override model.Override) (mo return override, nil } -func (s Sqlite) DeactivateOverride(ctx context.Context, projectKey, flagKey string) (int, error) { +func (s *Sqlite) DeactivateOverride(ctx context.Context, projectKey, flagKey string) (int, error) { row := s.database.QueryRowContext(ctx, ` UPDATE overrides set active = false, version = version+1 @@ -345,21 +350,71 @@ func (s Sqlite) DeactivateOverride(ctx context.Context, projectKey, flagKey stri return version, nil } -func NewSqlite(ctx context.Context, dbPath string) (Sqlite, error) { +func (s *Sqlite) RestoreBackup(ctx context.Context, stream io.Reader) (string, error) { + filepath, err := s.backupManager.RestoreToFile(ctx, stream) + if err != nil { + return "", errors.Wrap(err, "unable to restore backup db") + } + err = s.database.Close() + if err != nil { + return "", errors.Wrap(err, "unable to close database before restoring backup") + } + err = os.Rename(filepath, s.dbPath) + if err != nil { + //panic because this would really leave the app in an invalid state + panic(err) + } + s.database, err = sql.Open("sqlite3", s.dbPath) + if err != nil { + //panic because this would really leave the app in an invalid state + panic(err) + } + + err = s.runMigrations(ctx) + if err != nil { + return "", errors.Wrap(err, "unable to run migrations after restoring backup") + } + + return filepath, err +} + +func (s *Sqlite) CreateBackup(ctx context.Context) (io.ReadCloser, int64, error) { + backupPath, err := s.backupManager.MakeBackupFile(ctx) + fi, err := os.Open(backupPath) + if err != nil { + return nil, 0, errors.Wrapf(err, "unable to open backup db at %s", backupPath) + } + stat, err := fi.Stat() + if err != nil { + return nil, 0, errors.Wrapf(err, "unable to stat backup db at %s", backupPath) + } + return fi, stat.Size(), nil +} + +func NewSqlite(ctx context.Context, dbPath string) (*Sqlite, error) { store := new(Sqlite) + store.dbPath = dbPath + store.backupManager = backup.NewManager(dbPath, "main", "ld_cli_*.bak", "ld_cli_restore_*.db") + store.backupManager.AddValidationQueries(validationQueries...) db, err := sql.Open("sqlite3", dbPath) if err != nil { - return Sqlite{}, err + return &Sqlite{}, err } store.database = db err = store.runMigrations(ctx) if err != nil { - return Sqlite{}, err + return &Sqlite{}, err } - return *store, nil + return store, nil +} + +var validationQueries = []string{ + "SELECT COUNT(1) from projects", + "SELECT COUNT(1) from overrides", + "SELECT COUNT(1) from available_variations", } -func (s Sqlite) runMigrations(ctx context.Context) error { +func (s *Sqlite) runMigrations(ctx context.Context) error { tx, err := s.database.BeginTx(ctx, nil) if err != nil { return err diff --git a/internal/dev_server/model/mocks/store.go b/internal/dev_server/model/mocks/store.go index 37669725..a5e2603a 100644 --- a/internal/dev_server/model/mocks/store.go +++ b/internal/dev_server/model/mocks/store.go @@ -11,6 +11,7 @@ package mocks import ( context "context" + io "io" reflect "reflect" model "github.com/launchdarkly/ldcli/internal/dev_server/model" @@ -40,6 +41,22 @@ func (m *MockStore) EXPECT() *MockStoreMockRecorder { return m.recorder } +// CreateBackup mocks base method. +func (m *MockStore) CreateBackup(arg0 context.Context) (io.ReadCloser, int64, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "CreateBackup", arg0) + ret0, _ := ret[0].(io.ReadCloser) + ret1, _ := ret[1].(int64) + ret2, _ := ret[2].(error) + return ret0, ret1, ret2 +} + +// CreateBackup indicates an expected call of CreateBackup. +func (mr *MockStoreMockRecorder) CreateBackup(arg0 any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CreateBackup", reflect.TypeOf((*MockStore)(nil).CreateBackup), arg0) +} + // DeactivateOverride mocks base method. func (m *MockStore) DeactivateOverride(arg0 context.Context, arg1, arg2 string) (int, error) { m.ctrl.T.Helper() @@ -144,6 +161,21 @@ func (mr *MockStoreMockRecorder) InsertProject(arg0, arg1 any) *gomock.Call { return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "InsertProject", reflect.TypeOf((*MockStore)(nil).InsertProject), arg0, arg1) } +// RestoreBackup mocks base method. +func (m *MockStore) RestoreBackup(arg0 context.Context, arg1 io.Reader) (string, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "RestoreBackup", arg0, arg1) + ret0, _ := ret[0].(string) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// RestoreBackup indicates an expected call of RestoreBackup. +func (mr *MockStoreMockRecorder) RestoreBackup(arg0, arg1 any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RestoreBackup", reflect.TypeOf((*MockStore)(nil).RestoreBackup), arg0, arg1) +} + // UpdateProject mocks base method. func (m *MockStore) UpdateProject(arg0 context.Context, arg1 model.Project) (bool, error) { m.ctrl.T.Helper() diff --git a/internal/dev_server/model/restore.go b/internal/dev_server/model/restore.go new file mode 100644 index 00000000..1a59554e --- /dev/null +++ b/internal/dev_server/model/restore.go @@ -0,0 +1,38 @@ +package model + +import ( + "context" + "io" +) + +func RestoreDb(ctx context.Context, stream io.Reader) error { + store := StoreFromContext(ctx) + _, err := store.RestoreBackup(ctx, stream) + if err != nil { + return err + } + + projects, err := store.GetDevProjectKeys(ctx) + if err != nil { + return err + } + + observers := GetObserversFromContext(ctx) + + for _, projectKey := range projects { + project, err := store.GetDevProject(ctx, projectKey) + if err != nil { + return err + } + allFlagsWithOverrides, err := project.GetFlagStateWithOverridesForProject(ctx) + if err != nil { + return err + } + observers.Notify(SyncEvent{ + ProjectKey: project.Key, + AllFlagsState: allFlagsWithOverrides, + }) + } + + return nil +} diff --git a/internal/dev_server/model/restore_test.go b/internal/dev_server/model/restore_test.go new file mode 100644 index 00000000..d522f410 --- /dev/null +++ b/internal/dev_server/model/restore_test.go @@ -0,0 +1,63 @@ +package model_test + +import ( + "context" + "errors" + "github.com/stretchr/testify/require" + "strings" + "testing" + + "github.com/stretchr/testify/assert" + "go.uber.org/mock/gomock" + + "github.com/launchdarkly/go-sdk-common/v3/ldcontext" + "github.com/launchdarkly/go-sdk-common/v3/ldvalue" + adapters_mocks "github.com/launchdarkly/ldcli/internal/dev_server/adapters/mocks" + "github.com/launchdarkly/ldcli/internal/dev_server/model" + "github.com/launchdarkly/ldcli/internal/dev_server/model/mocks" +) + +func TestRestoreDb(t *testing.T) { + ctx := context.Background() + mockController := gomock.NewController(t) + observers := model.NewObservers() + ctx = model.SetObserversOnContext(ctx, observers) + ctx, _, _ = adapters_mocks.WithMockApiAndSdk(ctx, mockController) + store := mocks.NewMockStore(mockController) + ctx = model.ContextWithStore(ctx, store) + projKey := "proj" + sourceEnvKey := "env" + + proj := model.Project{ + Key: projKey, + SourceEnvironmentKey: sourceEnvKey, + Context: ldcontext.New(t.Name()), + AllFlagsState: map[string]model.FlagState{ + "boolFlag": { + Version: 0, + Value: ldvalue.Bool(false), + }, + }, + } + + t.Run("Returns error if restore fails", func(t *testing.T) { + store.EXPECT().RestoreBackup(gomock.Any(), gomock.Any()).Return("", errors.New("restore failed")) + + err := model.RestoreDb(ctx, strings.NewReader("")) + assert.NotNil(t, err) + }) + + t.Run("Notifies Projects if restore completes", func(t *testing.T) { + store.EXPECT().RestoreBackup(gomock.Any(), gomock.Any()).Return("restore.db", nil) + store.EXPECT().GetDevProjectKeys(gomock.Any()).Return([]string{projKey}, nil) + store.EXPECT().GetDevProject(gomock.Any(), projKey).Return(&proj, nil) + store.EXPECT().GetOverridesForProject(gomock.Any(), projKey).Return(model.Overrides{}, nil) + observer := mocks.NewMockObserver(mockController) + observer.EXPECT().Handle(model.SyncEvent{ProjectKey: projKey, AllFlagsState: proj.AllFlagsState}) + + observers.RegisterObserver(observer) + + err := model.RestoreDb(ctx, strings.NewReader("")) + require.NoError(t, err) + }) +} diff --git a/internal/dev_server/model/store.go b/internal/dev_server/model/store.go index 74326a46..6d36bddb 100644 --- a/internal/dev_server/model/store.go +++ b/internal/dev_server/model/store.go @@ -3,6 +3,7 @@ package model import ( "context" "errors" + "io" "net/http" "github.com/gorilla/mux" @@ -28,6 +29,9 @@ type Store interface { UpsertOverride(ctx context.Context, override Override) (Override, error) GetOverridesForProject(ctx context.Context, projectKey string) (Overrides, error) GetAvailableVariationsForProject(ctx context.Context, projectKey string) (map[string][]Variation, error) + + CreateBackup(ctx context.Context) (io.ReadCloser, int64, error) + RestoreBackup(ctx context.Context, stream io.Reader) (string, error) } func ContextWithStore(ctx context.Context, store Store) context.Context {