From 641b6bfc5d8d2421511ab85f3622f35b17955bfd Mon Sep 17 00:00:00 2001 From: Dan Lapid Date: Thu, 1 Sep 2022 11:11:12 +0300 Subject: [PATCH] Develop * Added watcher tests * Added test reporting and test summary to ci * Replaced Uber rate limiter with Golang rate limiter * Added buffer fill support to windows and linux --- .github/workflows/coverage.yml | 39 ++++ .github/workflows/test.yml | 46 ++-- cmd/receiver/main.go | 12 +- cmd/sender/main.go | 12 +- cmd/sendfiles/main.go | 5 - cmd/watcher/main.go | 23 +- config.toml | 1 + go.mod | 7 +- go.sum | 16 +- pkg/bandwidthlimiter/bandwidthlimiter.go | 16 +- pkg/bandwidthlimiter/bandwidthlimiter_test.go | 28 ++- pkg/config/config.go | 1 + pkg/config/config_test.go | 72 +++++++ pkg/database/database.go | 20 +- pkg/fecdecoder/fecdecoder.go | 3 +- pkg/fecencoder/fecencoder.go | 3 +- pkg/filecloser/filecloser.go | 119 +++++------ pkg/sender/sender.go | 2 +- pkg/udpreceiver/udpreceiver.go | 9 +- pkg/udpsender/udpsender.go | 3 +- pkg/utils/darwin_ioctl.go | 26 +++ pkg/utils/linux_ioctl.go | 51 +++++ pkg/utils/linux_procudp.go | 201 ++++++++++++++++++ pkg/utils/unix_ioctl.go | 39 ++-- pkg/utils/utils.go | 17 +- pkg/utils/utils_test.go | 168 +++++++++++++++ pkg/utils/windows_ioctl.go | 65 +++++- pkg/watcher/watcher.go | 32 ++- tests/system_test.go | 104 +++++++-- 29 files changed, 896 insertions(+), 244 deletions(-) create mode 100644 .github/workflows/coverage.yml create mode 100644 pkg/config/config_test.go create mode 100644 pkg/utils/darwin_ioctl.go create mode 100644 pkg/utils/linux_ioctl.go create mode 100644 pkg/utils/linux_procudp.go create mode 100644 pkg/utils/utils_test.go diff --git a/.github/workflows/coverage.yml b/.github/workflows/coverage.yml new file mode 100644 index 0000000..6140dac --- /dev/null +++ b/.github/workflows/coverage.yml @@ -0,0 +1,39 @@ +name: Coverage +on: + push: + branches: + - main + pull_request: + branches: + - main +jobs: + coverage: + strategy: + fail-fast: false + matrix: + os: + - ubuntu-latest + go: + - "1.19" + runs-on: ${{ matrix.os }} + steps: + - name: Checkout code + uses: actions/checkout@v3 + with: + fetch-depth: 1 + + - name: Install Go + uses: actions/setup-go@v3 + with: + go-version: ${{ matrix.go }} + + - name: Calc coverage + run: | + go test -v -covermode=count -coverprofile=coverage.out -coverpkg ./pkg/... ./... + - name: Convert coverage.out to coverage.lcov + uses: jandelgado/gcov2lcov-action@master + - name: Coveralls + uses: coverallsapp/github-action@master + with: + github-token: ${{ secrets.GITHUB_TOKEN }} + path-to-lcov: coverage.lcov diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index f54905f..82118c2 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -30,36 +30,28 @@ jobs: go-version: ${{ matrix.go }} - name: test + if: runner.os != 'Windows' run: | - go test -v -race ./... + go install github.com/jstemmer/go-junit-report/v2@latest + go test -v -race ./... 2>&1 | go-junit-report -set-exit-code -iocopy -out report.xml - coverage: - strategy: - fail-fast: false - matrix: - os: - - ubuntu-latest - go: - - "1.19" - runs-on: ${{ matrix.os }} - steps: - - name: Checkout code - uses: actions/checkout@v3 - with: - fetch-depth: 1 + - name: test-windows + if: runner.os == 'Windows' + run: | + go install github.com/jstemmer/go-junit-report/v2@latest + go test -v ./... 2>&1 | go-junit-report -set-exit-code -iocopy -out report.xml - - name: Install Go - uses: actions/setup-go@v3 + - name: Test Report + uses: dorny/test-reporter@v1 with: - go-version: ${{ matrix.go }} + name: ${{ matrix.os }} Tests + path: report.xml + reporter: java-junit + fail-on-error: "false" + if: always() - - name: Calc coverage - run: | - go test -v -covermode=count -coverprofile=coverage.out -coverpkg ./pkg/... ./... - - name: Convert coverage.out to coverage.lcov - uses: jandelgado/gcov2lcov-action@master - - name: Coveralls - uses: coverallsapp/github-action@master + - name: Test Summary + uses: test-summary/action@v1 with: - github-token: ${{ secrets.GITHUB_TOKEN }} - path-to-lcov: coverage.lcov + paths: report.xml + if: always() diff --git a/cmd/receiver/main.go b/cmd/receiver/main.go index 0856643..b0a5588 100644 --- a/cmd/receiver/main.go +++ b/cmd/receiver/main.go @@ -6,9 +6,6 @@ import ( "oneway-filesync/pkg/database" "oneway-filesync/pkg/receiver" "oneway-filesync/pkg/utils" - "os" - "os/signal" - "syscall" "github.com/sirupsen/logrus" ) @@ -28,16 +25,9 @@ func main() { return } - if err = database.ConfigureDatabase(db); err != nil { - logrus.Errorf("Failed setting up db with err %v", err) - return - } - ctx, cancel := context.WithCancel(context.Background()) // Create a cancelable context and pass it to all goroutines, allows us to gracefully shut down the program receiver.Receiver(ctx, db, conf) - done := make(chan os.Signal, 1) - signal.Notify(done, syscall.SIGINT, syscall.SIGTERM) - <-done + <-utils.CtrlC() cancel() // Gracefully shutdown and stop all goroutines } diff --git a/cmd/sender/main.go b/cmd/sender/main.go index 76eabf7..a8bec6e 100644 --- a/cmd/sender/main.go +++ b/cmd/sender/main.go @@ -6,9 +6,6 @@ import ( "oneway-filesync/pkg/database" "oneway-filesync/pkg/sender" "oneway-filesync/pkg/utils" - "os" - "os/signal" - "syscall" "github.com/sirupsen/logrus" ) @@ -27,16 +24,9 @@ func main() { return } - if err = database.ConfigureDatabase(db); err != nil { - logrus.Errorf("Failed setting up db with err %v", err) - return - } - ctx, cancel := context.WithCancel(context.Background()) // Create a cancelable context and pass it to all goroutines, allows us to gracefully shut down the program sender.Sender(ctx, db, conf) - done := make(chan os.Signal, 1) - signal.Notify(done, syscall.SIGINT, syscall.SIGTERM) - <-done + <-utils.CtrlC() cancel() // Gracefully shutdown and stop all goroutines } diff --git a/cmd/sendfiles/main.go b/cmd/sendfiles/main.go index 142109a..acf8b2e 100644 --- a/cmd/sendfiles/main.go +++ b/cmd/sendfiles/main.go @@ -19,11 +19,6 @@ func main() { return } - if err = database.ConfigureDatabase(db); err != nil { - fmt.Printf("Failed setting up db with err %v\n", err) - return - } - path := os.Args[1] err = filepath.Walk(path, func(filepath string, info os.FileInfo, e error) error { if !info.IsDir() { diff --git a/cmd/watcher/main.go b/cmd/watcher/main.go index a0a557a..5943ad1 100644 --- a/cmd/watcher/main.go +++ b/cmd/watcher/main.go @@ -2,24 +2,21 @@ package main import ( "context" + "oneway-filesync/pkg/config" "oneway-filesync/pkg/database" "oneway-filesync/pkg/utils" "oneway-filesync/pkg/watcher" - "os" - "os/signal" - "syscall" - "github.com/rjeczalik/notify" "github.com/sirupsen/logrus" ) func main() { utils.InitializeLogging("watcher.log") - if len(os.Args) < 2 { - logrus.Errorf("Usage: %s ", os.Args[0]) + conf, err := config.GetConfig("config.toml") + if err != nil { + logrus.Errorf("Failed reading config with err %v", err) return } - path := os.Args[1] db, err := database.OpenDatabase("s_") if err != nil { @@ -27,17 +24,9 @@ func main() { return } - if err = database.ConfigureDatabase(db); err != nil { - logrus.Errorf("Failed setting up db with err %v", err) - return - } - - events := make(chan notify.EventInfo, 20) ctx, cancel := context.WithCancel(context.Background()) // Create a cancelable context and pass it to all goroutines, allows us to gracefully shut down the program - watcher.CreateWatcher(ctx, db, path, events) + watcher.Watcher(ctx, db, conf) - done := make(chan os.Signal, 1) - signal.Notify(done, syscall.SIGINT, syscall.SIGTERM) - <-done + <-utils.CtrlC() cancel() // Gracefully shutdown and stop all goroutines } diff --git a/config.toml b/config.toml index 2efa3f4..5a82243 100644 --- a/config.toml +++ b/config.toml @@ -5,3 +5,4 @@ ChunkSize = 8192 ChunkFecRequired = 5 ChunkFecTotal = 10 OutDir = "./out" +WatchDir = "./tmp" diff --git a/go.mod b/go.mod index ca29ed0..8a2a3b8 100644 --- a/go.mod +++ b/go.mod @@ -5,17 +5,16 @@ go 1.19 require ( github.com/BurntSushi/toml v1.2.0 github.com/klauspost/reedsolomon v1.10.0 - github.com/rjeczalik/notify v0.9.2 + github.com/rjeczalik/notify v0.9.3-0.20210809113154-3472d85e95cd github.com/sirupsen/logrus v1.9.0 github.com/zhuangsirui/binpacker v2.0.0+incompatible - go.uber.org/ratelimit v0.2.0 - golang.org/x/sys v0.0.0-20220825204002-c680a09ffe64 + golang.org/x/sys v0.0.0-20220829200755-d48e67d00261 + golang.org/x/time v0.0.0-20220722155302-e5dcc9cfc0b9 gorm.io/driver/sqlite v1.3.6 gorm.io/gorm v1.23.8 ) require ( - github.com/andres-erbsen/clock v0.0.0-20160526145045-9e14626cd129 // indirect github.com/jinzhu/inflection v1.0.0 // indirect github.com/jinzhu/now v1.1.5 // indirect github.com/klauspost/cpuid/v2 v2.0.14 // indirect diff --git a/go.sum b/go.sum index 459c504..5c91d89 100644 --- a/go.sum +++ b/go.sum @@ -1,7 +1,5 @@ github.com/BurntSushi/toml v1.2.0 h1:Rt8g24XnyGTyglgET/PRUNlrUeu9F5L+7FilkXfZgs0= github.com/BurntSushi/toml v1.2.0/go.mod h1:CxXYINrC8qIiEnFrOxCa7Jy5BFHlXnUU2pbicEuybxQ= -github.com/andres-erbsen/clock v0.0.0-20160526145045-9e14626cd129 h1:MzBOUgng9orim59UnfUTLRjMpd09C5uEVQ6RPGeCaVI= -github.com/andres-erbsen/clock v0.0.0-20160526145045-9e14626cd129/go.mod h1:rFgpPQZYZ8vdbc+48xibu8ALc3yeyd64IhHS+PU6Yyg= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= @@ -20,23 +18,21 @@ github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZb github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/rjeczalik/notify v0.9.2 h1:MiTWrPj55mNDHEiIX5YUSKefw/+lCQVoAFmD6oQm5w8= github.com/rjeczalik/notify v0.9.2/go.mod h1:aErll2f0sUX9PXZnVNyeiObbmTlk5jnMoCa4QEjJeqM= +github.com/rjeczalik/notify v0.9.3-0.20210809113154-3472d85e95cd h1:LHLg0gdpRUCvujg2Zol6e2Uknq5vHycLxqEzYwxt1vY= +github.com/rjeczalik/notify v0.9.3-0.20210809113154-3472d85e95cd/go.mod h1:gF3zSOrafR9DQEWSE8TjfI9NkooDxbyT4UgRGKZA0lc= github.com/sirupsen/logrus v1.9.0 h1:trlNQbNUG3OdDrDil03MCb1H2o9nJ1x4/5LYw7byDE0= github.com/sirupsen/logrus v1.9.0/go.mod h1:naHLuLoDiP4jHNo9R0sCBMtWGeIprob74mVsIT4qYEQ= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= -github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= -github.com/stretchr/testify v1.6.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= github.com/stretchr/testify v1.7.0 h1:nwc3DEeHmmLAfoZucVR881uASk0Mfjw8xYJ99tb5CcY= github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= github.com/zhuangsirui/binpacker v2.0.0+incompatible h1:s2wDYWXT4IznT7NUFzn5gJHqjtWz/zIwUxdiFGNomdk= github.com/zhuangsirui/binpacker v2.0.0+incompatible/go.mod h1:TdE7uEZ8Q7sMzbCpk2Y+ksFB8yA5AErPz0meDB612rU= -go.uber.org/atomic v1.7.0 h1:ADUqmZGgLDDfbSL9ZmPxKTybcoEYHgpYfELNoN+7hsw= -go.uber.org/atomic v1.7.0/go.mod h1:fEN4uk6kAWBTFdckzkM89CLk9XfWZrxpCo0nPH17wJc= -go.uber.org/ratelimit v0.2.0 h1:UQE2Bgi7p2B85uP5dC2bbRtig0C+OeNRnNEafLjsLPA= -go.uber.org/ratelimit v0.2.0/go.mod h1:YYBV4e4naJvhpitQrWJu1vCpgB7CboMe0qhltKt6mUg= golang.org/x/sys v0.0.0-20180926160741-c2ed4eda69e7/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20220715151400-c0bba94af5f8/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.0.0-20220825204002-c680a09ffe64 h1:UiNENfZ8gDvpiWw7IpOMQ27spWmThO1RwwdQVbJahJM= -golang.org/x/sys v0.0.0-20220825204002-c680a09ffe64/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.0.0-20220829200755-d48e67d00261 h1:v6hYoSR9T5oet+pMXwUWkbiVqx/63mlHjefrHmxwfeY= +golang.org/x/sys v0.0.0-20220829200755-d48e67d00261/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/time v0.0.0-20220722155302-e5dcc9cfc0b9 h1:ftMN5LMiBFjbzleLqtoBZk7KdJwhuybIU+FckUHgoyQ= +golang.org/x/time v0.0.0-20220722155302-e5dcc9cfc0b9/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c h1:dUUwHk2QECo/6vqA44rthZ8ie2QXMNeKRTHCNY2nXvo= gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/pkg/bandwidthlimiter/bandwidthlimiter.go b/pkg/bandwidthlimiter/bandwidthlimiter.go index 29f06f8..ac68e4c 100644 --- a/pkg/bandwidthlimiter/bandwidthlimiter.go +++ b/pkg/bandwidthlimiter/bandwidthlimiter.go @@ -4,11 +4,12 @@ import ( "context" "oneway-filesync/pkg/structs" - "go.uber.org/ratelimit" + "github.com/sirupsen/logrus" + "golang.org/x/time/rate" ) type bandwidthLimiterConfig struct { - rl ratelimit.Limiter + rl *rate.Limiter input chan *structs.Chunk output chan *structs.Chunk } @@ -19,15 +20,18 @@ func worker(ctx context.Context, conf *bandwidthLimiterConfig) { case <-ctx.Done(): return case buf := <-conf.input: - conf.rl.Take() - conf.output <- buf + if err := conf.rl.WaitN(ctx, len(buf.Data)); err != nil { + logrus.Error(err) + } else { + conf.output <- buf + } } } } -func CreateBandwidthLimiter(ctx context.Context, chunks_per_sec int, input chan *structs.Chunk, output chan *structs.Chunk) { +func CreateBandwidthLimiter(ctx context.Context, bandwidth int, chunksize int, input chan *structs.Chunk, output chan *structs.Chunk) { conf := bandwidthLimiterConfig{ - rl: ratelimit.New(chunks_per_sec), + rl: rate.NewLimiter(rate.Limit(bandwidth), chunksize), input: input, output: output, } diff --git a/pkg/bandwidthlimiter/bandwidthlimiter_test.go b/pkg/bandwidthlimiter/bandwidthlimiter_test.go index 87f6e53..43c8662 100644 --- a/pkg/bandwidthlimiter/bandwidthlimiter_test.go +++ b/pkg/bandwidthlimiter/bandwidthlimiter_test.go @@ -10,31 +10,37 @@ import ( func TestCreateBandwidthLimiter(t *testing.T) { type args struct { - chunks int - chunks_per_sec int + chunk_count int + chunk_size int + bytes_per_sec int } tests := []struct { name string args args }{ - {name: "test1", args: args{chunks: 100, chunks_per_sec: 10}}, + {name: "test1", args: args{chunk_count: 100, chunk_size: 8000, bytes_per_sec: 240000}}, + {name: "test2", args: args{chunk_count: 300, chunk_size: 8000, bytes_per_sec: 240000}}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - expected := float64(tt.args.chunks) / float64(tt.args.chunks_per_sec) - ch_in := make(chan *structs.Chunk, tt.args.chunks) - ch_out := make(chan *structs.Chunk, tt.args.chunks) - for i := 0; i < tt.args.chunks; i++ { - ch_in <- &structs.Chunk{} + expected := (float64(tt.args.chunk_size) / float64(tt.args.bytes_per_sec)) * float64(tt.args.chunk_count) + ch_in := make(chan *structs.Chunk, tt.args.chunk_count) + ch_out := make(chan *structs.Chunk, tt.args.chunk_count) + + chunk := structs.Chunk{Data: make([]byte, tt.args.chunk_size)} + for i := 0; i < tt.args.chunk_count; i++ { + ch_in <- &chunk } + ctx, cancel := context.WithCancel(context.Background()) start := time.Now() - bandwidthlimiter.CreateBandwidthLimiter(ctx, tt.args.chunks_per_sec, ch_in, ch_out) - for i := 0; i < tt.args.chunks; i++ { + bandwidthlimiter.CreateBandwidthLimiter(ctx, tt.args.bytes_per_sec, tt.args.chunk_size, ch_in, ch_out) + for i := 0; i < tt.args.chunk_count; i++ { <-ch_out } timepast := time.Since(start) - if timepast > time.Duration(expected+1)*time.Second || timepast < time.Duration(expected-1)*time.Second { + + if timepast > time.Duration(expected*1.2)*time.Second || timepast < time.Duration(expected*0.8)*time.Second { t.Fatalf("Bandwidthlimiter took %f seconds instead of %f", timepast.Seconds(), expected) } cancel() diff --git a/pkg/config/config.go b/pkg/config/config.go index 78c53e5..9b7bbb0 100644 --- a/pkg/config/config.go +++ b/pkg/config/config.go @@ -12,6 +12,7 @@ type Config struct { ChunkFecRequired int ChunkFecTotal int OutDir string + WatchDir string } func GetConfig(file string) (Config, error) { diff --git a/pkg/config/config_test.go b/pkg/config/config_test.go new file mode 100644 index 0000000..4215fc6 --- /dev/null +++ b/pkg/config/config_test.go @@ -0,0 +1,72 @@ +package config_test + +import ( + "oneway-filesync/pkg/config" + "os" + "reflect" + "testing" +) + +func TestGetConfig(t *testing.T) { + type args struct { + configtext string + } + tests := []struct { + name string + args args + want config.Config + wantErr bool + }{ + { + name: "test1", + args: args{configtext: ` + ReceiverIP = "127.0.0.1" + ReceiverPort = 5000 + BandwidthLimit = 10000000 + ChunkSize = 8192 + ChunkFecRequired = 5 + ChunkFecTotal = 10 + OutDir = "./out" + WatchDir = "./tmp"`}, + want: config.Config{ + ReceiverIP: "127.0.0.1", + ReceiverPort: 5000, + BandwidthLimit: 10000000, + ChunkSize: 8192, + ChunkFecRequired: 5, + ChunkFecTotal: 10, + OutDir: "./out", + WatchDir: "./tmp", + }, + wantErr: false, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + var filename string + func() { + f, err := os.CreateTemp("", "") + if (err != nil) != tt.wantErr { + t.Errorf("CreateTemp() error = %v, wantErr %v", err, tt.wantErr) + return + } + defer f.Close() + filename = f.Name() + _, err = f.WriteString(tt.args.configtext) + if (err != nil) != tt.wantErr { + t.Errorf("WriteString() error = %v, wantErr %v", err, tt.wantErr) + return + } + }() + defer os.Remove(filename) + got, err := config.GetConfig(filename) + if (err != nil) != tt.wantErr { + t.Errorf("GetConfig() error = %v, wantErr %v", err, tt.wantErr) + return + } + if !reflect.DeepEqual(got, tt.want) { + t.Errorf("GetConfig() = %v, want %v", got, tt.want) + } + }) + } +} diff --git a/pkg/database/database.go b/pkg/database/database.go index 2ecf5b2..1abface 100644 --- a/pkg/database/database.go +++ b/pkg/database/database.go @@ -26,20 +26,27 @@ type ReceivedFile struct { const DBFILE = "gorm.db" +func configureDatabase(db *gorm.DB) error { + return db.AutoMigrate(&File{}) +} + // Opens a connection to the database, // eventually we can choose to receive the user, password, host, database name // from the the configuration file, because we expect this database to be run locally // we leave it as defaults for now. func OpenDatabase(tableprefix string) (*gorm.DB, error) { - return gorm.Open(sqlite.Open(DBFILE), + db, err := gorm.Open(sqlite.Open(DBFILE), &gorm.Config{ NamingStrategy: schema.NamingStrategy{TablePrefix: tableprefix}, Logger: gormlogger.Discard, }) -} - -func ConfigureDatabase(db *gorm.DB) error { - return db.AutoMigrate(&File{}) + if err != nil { + return nil, err + } + if err = configureDatabase(db); err != nil { + return nil, err + } + return db, nil } func ClearDatabase(db *gorm.DB) error { @@ -79,6 +86,5 @@ func QueueFileForSending(db *gorm.DB, path string) error { Success: false, } - result := db.Create(&file) - return result.Error + return db.Create(&file).Error } diff --git a/pkg/fecdecoder/fecdecoder.go b/pkg/fecdecoder/fecdecoder.go index edaa18d..96105a9 100644 --- a/pkg/fecdecoder/fecdecoder.go +++ b/pkg/fecdecoder/fecdecoder.go @@ -19,8 +19,7 @@ type fecDecoderConfig struct { func worker(ctx context.Context, conf *fecDecoderConfig) { fec, err := reedsolomon.New(conf.required, conf.total-conf.required) if err != nil { - logrus.Errorf("Error creating fec object: %v", err) - return + logrus.Fatalf("Error creating fec object: %v", err) } for { select { diff --git a/pkg/fecencoder/fecencoder.go b/pkg/fecencoder/fecencoder.go index bc1ef15..e5abdbb 100644 --- a/pkg/fecencoder/fecencoder.go +++ b/pkg/fecencoder/fecencoder.go @@ -25,8 +25,7 @@ type fecEncoderConfig struct { func worker(ctx context.Context, conf *fecEncoderConfig) { fec, err := reedsolomon.New(conf.required, conf.total-conf.required) if err != nil { - logrus.Errorf("Error creating fec object: %v", err) - return + logrus.Fatalf("Error creating fec object: %v", err) } for { select { diff --git a/pkg/filecloser/filecloser.go b/pkg/filecloser/filecloser.go index ee50fb6..d97cbe1 100644 --- a/pkg/filecloser/filecloser.go +++ b/pkg/filecloser/filecloser.go @@ -2,6 +2,7 @@ package filecloser import ( "context" + "errors" "fmt" "oneway-filesync/pkg/database" "oneway-filesync/pkg/structs" @@ -22,6 +23,50 @@ func normalizePath(path string) string { return filepath.Join(strings.Split(newpath, "/")...) } } +func closeFile(file *structs.OpenTempFile, outdir string) error { + l := logrus.WithFields(logrus.Fields{ + "TempFile": file.TempFile, + "Path": file.Path, + "Hash": fmt.Sprintf("%x", file.Hash), + }) + + f, err := os.Open(file.TempFile) + if err != nil { + l.Errorf("Error opening tempfile: %v", err) + return err + } + + hash, err := structs.HashFile(f) + err2 := f.Close() + if err != nil { + l.Errorf("Error hashing tempfile: %v", err) + return err + } + if err2 != nil { + l.Errorf("Error closing tempfile: %v", err2) + // Not returning error on purpose + } + if hash != file.Hash { + l.WithField("TempFileHash", fmt.Sprintf("%x", hash)).Errorf("Hash mismatch") + return errors.New("hash mismatch") + } + + newpath := filepath.Join(outdir, normalizePath(file.Path)) + err = os.MkdirAll(filepath.Dir(newpath), os.ModePerm) + if err != nil { + l.Errorf("Failed creating directory path: %v", err) + return err + } + + err = os.Rename(file.TempFile, newpath) + if err != nil { + l.Errorf("Failed moving tempfile to new location: %v", err) + return err + } + + l.WithField("NewPath", newpath).Infof("Successfully finished writing file") + return nil +} type fileCloserConfig struct { db *gorm.DB @@ -42,83 +87,19 @@ func worker(ctx context.Context, conf *fileCloserConfig) { Finished: true, } - f, err := os.Open(file.TempFile) - if err != nil { - logrus.WithFields(logrus.Fields{ - "TempFile": file.TempFile, - "Path": file.Path, - "Hash": fmt.Sprintf("%x", file.Hash), - }).Errorf("Error opening tempfile: %v", err) - dbentry.Success = false - conf.db.Save(&dbentry) - continue - } - - hash, err := structs.HashFile(f) - err2 := f.Close() + err := closeFile(file, conf.outdir) if err != nil { - logrus.WithFields(logrus.Fields{ - "TempFile": file.TempFile, - "Path": file.Path, - "Hash": fmt.Sprintf("%x", file.Hash), - }).Errorf("Error hashing tempfile: %v", err) dbentry.Success = false - conf.db.Save(&dbentry) - continue + } else { + dbentry.Success = true } - if hash != file.Hash { - logrus.WithFields(logrus.Fields{ - "TempFile": file.TempFile, - "Path": file.Path, - "Hash": fmt.Sprintf("%x", file.Hash), - "TempFileHash": fmt.Sprintf("%x", hash), - }).Errorf("Hash mismatch") - dbentry.Success = false - conf.db.Save(&dbentry) - continue - } - if err2 != nil { + if err := conf.db.Save(&dbentry).Error; err != nil { logrus.WithFields(logrus.Fields{ "TempFile": file.TempFile, "Path": file.Path, "Hash": fmt.Sprintf("%x", file.Hash), - }).Errorf("Error ckisubg tempfile: %v", err) + }).Errorf("Failed committing to db: %v", err) } - - newpath := filepath.Join(conf.outdir, normalizePath(file.Path)) - err = os.MkdirAll(filepath.Dir(newpath), os.ModePerm) - if err != nil { - logrus.WithFields(logrus.Fields{ - "TempFile": file.TempFile, - "Path": file.Path, - "Hash": fmt.Sprintf("%x", file.Hash), - "TempFileHash": fmt.Sprintf("%x", hash), - }).Errorf("Failed creating directory path: %v", err) - dbentry.Success = false - conf.db.Save(&dbentry) - continue - } - - err = os.Rename(file.TempFile, newpath) - if err != nil { - logrus.WithFields(logrus.Fields{ - "TempFile": file.TempFile, - "Path": file.Path, - "Hash": fmt.Sprintf("%x", file.Hash), - "TempFileHash": fmt.Sprintf("%x", hash), - }).Errorf("Failed moving tempfile to new location: %v", err) - dbentry.Success = false - conf.db.Save(&dbentry) - continue - } - - logrus.WithFields(logrus.Fields{ - "Path": file.Path, - "Hash": fmt.Sprintf("%x", file.Hash), - "NewPath": newpath, - }).Infof("Successfully finished writing file") - dbentry.Success = true - conf.db.Save(&dbentry) } } } diff --git a/pkg/sender/sender.go b/pkg/sender/sender.go index beb5b5b..81ebf01 100644 --- a/pkg/sender/sender.go +++ b/pkg/sender/sender.go @@ -25,6 +25,6 @@ func Sender(ctx context.Context, db *gorm.DB, conf config.Config) { queuereader.CreateQueueReader(ctx, db, queue_chan) filereader.CreateFileReader(ctx, db, conf.ChunkSize, conf.ChunkFecRequired, queue_chan, chunks_chan, maxprocs) fecencoder.CreateFecEncoder(ctx, conf.ChunkSize, conf.ChunkFecRequired, conf.ChunkFecTotal, chunks_chan, shares_chan, maxprocs) - bandwidthlimiter.CreateBandwidthLimiter(ctx, conf.BandwidthLimit/conf.ChunkSize, shares_chan, bw_limited_chunks) + bandwidthlimiter.CreateBandwidthLimiter(ctx, conf.BandwidthLimit, conf.ChunkSize, shares_chan, bw_limited_chunks) udpsender.CreateUdpSender(ctx, conf.ReceiverIP, conf.ReceiverPort, bw_limited_chunks, maxprocs) } diff --git a/pkg/udpreceiver/udpreceiver.go b/pkg/udpreceiver/udpreceiver.go index e7af94f..4d71d2f 100644 --- a/pkg/udpreceiver/udpreceiver.go +++ b/pkg/udpreceiver/udpreceiver.go @@ -6,7 +6,6 @@ import ( "net" "oneway-filesync/pkg/structs" "oneway-filesync/pkg/utils" - "runtime" "time" "github.com/sirupsen/logrus" @@ -19,11 +18,6 @@ type udpReceiverConfig struct { } func manager(ctx context.Context, conf *udpReceiverConfig) { - if runtime.GOOS != "linux" && runtime.GOOS != "darwin" { - logrus.Infof("Buffers fill detection not supported on the current OS") - return - } - ticker := time.NewTicker(200 * time.Millisecond) rawconn, err := conf.conn.SyscallConn() if err != nil { @@ -88,8 +82,7 @@ func CreateUdpReceiver(ctx context.Context, ip string, port int, chunksize int, conn, err := net.ListenUDP("udp", &addr) if err != nil { - logrus.Errorf("Error creating udp socket: %v", err) - return + logrus.Fatalf("Error creating udp socket: %v", err) } go func() { <-ctx.Done() diff --git a/pkg/udpsender/udpsender.go b/pkg/udpsender/udpsender.go index c5be83c..846c250 100644 --- a/pkg/udpsender/udpsender.go +++ b/pkg/udpsender/udpsender.go @@ -18,8 +18,7 @@ type udpSenderConfig struct { func worker(ctx context.Context, conf *udpSenderConfig) { conn, err := net.Dial("udp", fmt.Sprintf("%s:%d", conf.ip, conf.port)) if err != nil { - logrus.Errorf("Error creating udp socket: %v", err) - return + logrus.Fatalf("Error creating udp socket: %v", err) } defer conn.Close() for { diff --git a/pkg/utils/darwin_ioctl.go b/pkg/utils/darwin_ioctl.go new file mode 100644 index 0000000..3d6b993 --- /dev/null +++ b/pkg/utils/darwin_ioctl.go @@ -0,0 +1,26 @@ +//go:build darwin + +package utils + +import ( + "syscall" + + "golang.org/x/sys/unix" +) + +const FIONREAD uint = 0x4004667f + +func GetAvailableBytes(rawconn syscall.RawConn) (int, error) { + var err error + var avail int + err2 := rawconn.Control(func(fd uintptr) { + avail, err = unix.IoctlGetInt(int(fd), FIONREAD) + }) + if err2 != nil { + return 0, err2 + } + if err != nil { + return 0, err + } + return avail, nil +} diff --git a/pkg/utils/linux_ioctl.go b/pkg/utils/linux_ioctl.go new file mode 100644 index 0000000..6822973 --- /dev/null +++ b/pkg/utils/linux_ioctl.go @@ -0,0 +1,51 @@ +//go:build linux + +package utils + +import ( + "errors" + "fmt" + "os" + "strconv" + "strings" + "syscall" +) + +// Under linux FIONREAD returns the size of the waiting datagram if one exists and not the total available bytes +// See: https://manpages.debian.org/bullseye/manpages/udp.7.en.html#FIONREAD +// Sadly the only way to get the available bytes under linux is through proc/udp +func GetAvailableBytes(rawconn syscall.RawConn) (int, error) { + var err error + var link string + err2 := rawconn.Control(func(fd uintptr) { + link, err = os.Readlink(fmt.Sprintf("/proc/%d/fd/%d", os.Getpid(), int(fd))) + }) + if err2 != nil { + return 0, err2 + } + if err != nil { + return 0, err + } + + parts := strings.Split(link, ":[") + if parts[0] != "socket" { + return 0, errors.New("failed parsing /proc//fd/ link") + } + + inode, err := strconv.ParseUint(parts[1][:len(parts[1])-1], 0, 64) + if err != nil { + return 0, err + } + + netudp, err := GetNetUDP() + if err != nil { + return 0, err + } + for _, l := range netudp { + if l.Inode == inode { + // The division by 2 is due to the same overehead mentioned in SO_RCVBUF + return int(l.RxQueue / 2), nil + } + } + return 0, errors.New("socket inode was not found in proc/net/udp") +} diff --git a/pkg/utils/linux_procudp.go b/pkg/utils/linux_procudp.go new file mode 100644 index 0000000..f8d081f --- /dev/null +++ b/pkg/utils/linux_procudp.go @@ -0,0 +1,201 @@ +// Copyright 2020 The Prometheus Authors +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// FROM: https://github.com/prometheus/procfs/blob/master/net_ip_socket.go + +package utils + +import ( + "bufio" + "encoding/hex" + "fmt" + "io" + "net" + "os" + "strconv" + "strings" +) + +const ( + // readLimit is used by io.LimitReader while reading the content of the + // /proc/net/udp{,6} files. The number of lines inside such a file is dynamic + // as each line represents a single used socket. + // In theory, the number of available sockets is 65535 (2^16 - 1) per IP. + // With e.g. 150 Byte per line and the maximum number of 65535, + // the reader needs to handle 150 Byte * 65535 =~ 10 MB for a single IP. + readLimit = 4294967296 // Byte -> 4 GiB +) + +// This contains generic data structures for both udp and tcp sockets. +type ( + // NetIPSocket represents the contents of /proc/net/{t,u}dp{,6} file without the header. + NetIPSocket []*netIPSocketLine + + // netIPSocketLine represents the fields parsed from a single line + // in /proc/net/{t,u}dp{,6}. Fields which are not used by IPSocket are skipped. + // For the proc file format details, see https://linux.die.net/man/5/proc. + netIPSocketLine struct { + Sl uint64 + LocalAddr net.IP + LocalPort uint64 + RemAddr net.IP + RemPort uint64 + St uint64 + TxQueue uint64 + RxQueue uint64 + UID uint64 + Inode uint64 + } + + // NetUDP represents the contents of /proc/net/udp{,6} file without the header. + NetUDP []*netIPSocketLine +) + +func newNetIPSocket(file string) (NetIPSocket, error) { + f, err := os.Open(file) + if err != nil { + return nil, err + } + defer f.Close() + + var netIPSocket NetIPSocket + + lr := io.LimitReader(f, readLimit) + s := bufio.NewScanner(lr) + s.Scan() // skip first line with headers + for s.Scan() { + fields := strings.Fields(s.Text()) + line, err := parseNetIPSocketLine(fields) + if err != nil { + return nil, err + } + netIPSocket = append(netIPSocket, line) + } + if err := s.Err(); err != nil { + return nil, err + } + return netIPSocket, nil +} + +// the /proc/net/{t,u}dp{,6} files are network byte order for ipv4 and for ipv6 the address is four words consisting of four bytes each. In each of those four words the four bytes are written in reverse order. + +func parseIP(hexIP string) (net.IP, error) { + var byteIP []byte + byteIP, err := hex.DecodeString(hexIP) + if err != nil { + return nil, fmt.Errorf("cannot parse address field in socket line %q", hexIP) + } + switch len(byteIP) { + case 4: + return net.IP{byteIP[3], byteIP[2], byteIP[1], byteIP[0]}, nil + case 16: + i := net.IP{ + byteIP[3], byteIP[2], byteIP[1], byteIP[0], + byteIP[7], byteIP[6], byteIP[5], byteIP[4], + byteIP[11], byteIP[10], byteIP[9], byteIP[8], + byteIP[15], byteIP[14], byteIP[13], byteIP[12], + } + return i, nil + default: + return nil, fmt.Errorf("unable to parse IP %s", hexIP) + } +} + +// parseNetIPSocketLine parses a single line, represented by a list of fields. +func parseNetIPSocketLine(fields []string) (*netIPSocketLine, error) { + line := &netIPSocketLine{} + if len(fields) < 10 { + return nil, fmt.Errorf( + "cannot parse net socket line as it has less then 10 columns %q", + strings.Join(fields, " "), + ) + } + var err error // parse error + + // sl + s := strings.Split(fields[0], ":") + if len(s) != 2 { + return nil, fmt.Errorf("cannot parse sl field in socket line %q", fields[0]) + } + + if line.Sl, err = strconv.ParseUint(s[0], 0, 64); err != nil { + return nil, fmt.Errorf("cannot parse sl value in socket line: %w", err) + } + // local_address + l := strings.Split(fields[1], ":") + if len(l) != 2 { + return nil, fmt.Errorf("cannot parse local_address field in socket line %q", fields[1]) + } + if line.LocalAddr, err = parseIP(l[0]); err != nil { + return nil, err + } + if line.LocalPort, err = strconv.ParseUint(l[1], 16, 64); err != nil { + return nil, fmt.Errorf("cannot parse local_address port value in socket line: %w", err) + } + + // remote_address + r := strings.Split(fields[2], ":") + if len(r) != 2 { + return nil, fmt.Errorf("cannot parse rem_address field in socket line %q", fields[1]) + } + if line.RemAddr, err = parseIP(r[0]); err != nil { + return nil, err + } + if line.RemPort, err = strconv.ParseUint(r[1], 16, 64); err != nil { + return nil, fmt.Errorf("cannot parse rem_address port value in socket line: %w", err) + } + + // st + if line.St, err = strconv.ParseUint(fields[3], 16, 64); err != nil { + return nil, fmt.Errorf("cannot parse st value in socket line: %w", err) + } + + // tx_queue and rx_queue + q := strings.Split(fields[4], ":") + if len(q) != 2 { + return nil, fmt.Errorf( + "cannot parse tx/rx queues in socket line as it has a missing colon %q", + fields[4], + ) + } + if line.TxQueue, err = strconv.ParseUint(q[0], 16, 64); err != nil { + return nil, fmt.Errorf("cannot parse tx_queue value in socket line: %w", err) + } + if line.RxQueue, err = strconv.ParseUint(q[1], 16, 64); err != nil { + return nil, fmt.Errorf("cannot parse rx_queue value in socket line: %w", err) + } + + // uid + if line.UID, err = strconv.ParseUint(fields[7], 0, 64); err != nil { + return nil, fmt.Errorf("cannot parse uid value in socket line: %w", err) + } + + // inode + if line.Inode, err = strconv.ParseUint(fields[9], 0, 64); err != nil { + return nil, fmt.Errorf("cannot parse inode value in socket line: %w", err) + } + + return line, nil +} + +// newNetUDP creates a new NetUDP{,6} from the contents of the given file. +func newNetUDP(file string) (NetUDP, error) { + n, err := newNetIPSocket(file) + n1 := NetUDP(n) + return n1, err +} + +// NetUDP returns the IPv4 kernel/networking statistics for UDP datagrams +// read from /proc/net/udp. +func GetNetUDP() (NetUDP, error) { + return newNetUDP(("/proc/net/udp")) +} diff --git a/pkg/utils/unix_ioctl.go b/pkg/utils/unix_ioctl.go index 54d5e8d..3df03c8 100644 --- a/pkg/utils/unix_ioctl.go +++ b/pkg/utils/unix_ioctl.go @@ -3,13 +3,25 @@ package utils import ( - "errors" "runtime" "syscall" "golang.org/x/sys/unix" ) +// Removed CtrlC test due to: https://github.com/golang/go/issues/46354 +// func sendCtrlC(pid int) error { +// p, err := os.FindProcess(pid) +// if err != nil { +// return err +// } +// err = p.Signal(os.Interrupt) +// if err != nil { +// return err +// } +// return nil +// } + func GetReadBuffer(rawconn syscall.RawConn) (int, error) { var err error var bufsize int @@ -22,29 +34,10 @@ func GetReadBuffer(rawconn syscall.RawConn) (int, error) { if err != nil { return 0, err } - return bufsize, nil -} - -func GetAvailableBytes(rawconn syscall.RawConn) (int, error) { - var FIONREAD uint = 0 if runtime.GOOS == "linux" { - FIONREAD = 0x541B - } else if runtime.GOOS == "darwin" { - FIONREAD = 0x4004667f + // See https://man7.org/linux/man-pages/man7/socket.7.html SO_RCVBUF + return bufsize / 2, nil } else { - return 0, errors.New("unsupported OS") + return bufsize, nil } - var err error - var avail int - err2 := rawconn.Control(func(fd uintptr) { - avail, err = unix.IoctlGetInt(int(fd), FIONREAD) - }) - if err2 != nil { - return 0, err2 - } - if err != nil { - return 0, err - } - return avail, nil - } diff --git a/pkg/utils/utils.go b/pkg/utils/utils.go index a89edb1..1ff70de 100644 --- a/pkg/utils/utils.go +++ b/pkg/utils/utils.go @@ -4,16 +4,29 @@ import ( "fmt" "io" "os" + "os/signal" "runtime" "strings" "sync" + "syscall" "github.com/sirupsen/logrus" ) func formatFilePath(path string) string { - arr := strings.Split(path, "/") - return arr[len(arr)-1] + if strings.Contains(path, "\\") { + arr := strings.Split(path, "\\") + return arr[len(arr)-1] + } else { + arr := strings.Split(path, "/") + return arr[len(arr)-1] + } +} + +func CtrlC() chan os.Signal { + done := make(chan os.Signal, 1) + signal.Notify(done, syscall.SIGINT, syscall.SIGTERM) + return done } func InitializeLogging(logFile string) { diff --git a/pkg/utils/utils_test.go b/pkg/utils/utils_test.go new file mode 100644 index 0000000..d6b1e93 --- /dev/null +++ b/pkg/utils/utils_test.go @@ -0,0 +1,168 @@ +package utils + +import ( + "fmt" + "math/rand" + "net" + "os" + "testing" + "time" + + "github.com/sirupsen/logrus" +) + +func Test_formatFilePath(t *testing.T) { + type args struct { + path string + } + tests := []struct { + name string + args args + want string + }{ + {"test1", args{"/a/b/c/d.tmp"}, "d.tmp"}, + {"test2", args{"C:\\a\\b\\c\\d.tmp"}, "d.tmp"}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := formatFilePath(tt.args.path); got != tt.want { + t.Errorf("formatFilePath() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestInitializeLogging(t *testing.T) { + type args struct { + logFile string + } + tests := []struct { + name string + args args + }{ + {"test1", args{"logfile.txt"}}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + InitializeLogging(tt.args.logFile) + if _, err := os.Stat(tt.args.logFile); os.IsExist(err) { + t.Fatal("logfile did not create") + } + os.Remove(tt.args.logFile) + }) + } +} + +// Removed CtrlC test due to: https://github.com/golang/go/issues/46354 +// func TestCtrlC(t *testing.T) { +// ch := CtrlC() +// err := sendCtrlC(os.Getpid()) +// if err != nil { +// t.Fatal(err) +// } +// _, ok := <-ch +// if !ok { +// t.Fatal("Ctrl c not caught") +// } +// } + +func TestGetReadBuffer(t *testing.T) { + ip := "127.0.0.1" + port := rand.Intn(30000) + 30000 + addr := net.UDPAddr{ + IP: net.ParseIP(ip), + Port: port, + } + + conn, err := net.ListenUDP("udp", &addr) + if err != nil { + t.Fatal(err) + } + defer conn.Close() + + rawconn, err := conn.SyscallConn() + if err != nil { + t.Fatal(err) + } + + type args struct { + bufsize int + } + tests := []struct { + name string + args args + want int + wantErr bool + }{ + {"test1", args{8 * 1024}, 8 * 1024, false}, + {"test2", args{100 * 1024}, 100 * 1024, false}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := conn.SetReadBuffer(tt.args.bufsize) + if err != nil { + t.Error(err) + return + } + time.Sleep(300 * time.Millisecond) + + got, err := GetReadBuffer(rawconn) + if (err != nil) != tt.wantErr { + t.Errorf("GetReadBuffer() error = %v, wantErr %v", err, tt.wantErr) + return + } + if got != tt.want { + t.Errorf("GetReadBuffer() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestGetAvailableBytes(t *testing.T) { + ip := "127.0.0.1" + port := rand.Intn(30000) + 30000 + addr := net.UDPAddr{ + IP: net.ParseIP(ip), + Port: port, + } + + receiving_conn, err := net.ListenUDP("udp", &addr) + if err != nil { + t.Fatal(err) + } + defer receiving_conn.Close() + + sending_conn, err := net.Dial("udp", fmt.Sprintf("%s:%d", ip, port)) + if err != nil { + logrus.Errorf("Error creating udp socket: %v", err) + return + } + defer sending_conn.Close() + + rawconn, err := receiving_conn.SyscallConn() + if err != nil { + t.Fatal(err) + } + + chunksize := 8192 + chunk := make([]byte, chunksize) + for i := 0; i < 5; i++ { + expected := (i + 1) * chunksize + _, err := sending_conn.Write(chunk) + if err != nil { + t.Error(err) + return + } + time.Sleep(300 * time.Millisecond) + + avail, err := GetAvailableBytes(rawconn) + if err != nil { + t.Errorf("GetAvailableBytes() error = %v", err) + return + } + if avail < expected { + t.Errorf("GetAvailableBytes() = %v, want %v", avail, expected) + return + } + } +} diff --git a/pkg/utils/windows_ioctl.go b/pkg/utils/windows_ioctl.go index e643e4e..175ce02 100644 --- a/pkg/utils/windows_ioctl.go +++ b/pkg/utils/windows_ioctl.go @@ -3,14 +3,73 @@ package utils import ( - "errors" "syscall" + "unsafe" ) +var ( + ws2_32 = syscall.NewLazyDLL("ws2_32.dll") + ioctlsocket = ws2_32.NewProc("ioctlsocket") + // kernel32 = syscall.NewLazyDLL("kernel32.dll") + // generateConsoleCtrlEvent = kernel32.NewProc("GenerateConsoleCtrlEvent") +) + +const FIONREAD int32 = 0x4004667f + +func ioctlSocket(s syscall.Handle, cmd int32) (int, error) { + v := uint32(0) + rc, _, err := ioctlsocket.Call(uintptr(s), uintptr(cmd), uintptr(unsafe.Pointer(&v))) + if rc == 0 { + return int(v), nil + } else { + return 0, err + } +} + +func getsockoptInt(fd syscall.Handle, level, opt int) (int, error) { + v := int32(0) + l := int32(unsafe.Sizeof(v)) + err := syscall.Getsockopt(fd, int32(level), int32(opt), (*byte)(unsafe.Pointer(&v)), &l) + return int(v), err +} + +// Removed CtrlC test due to: https://github.com/golang/go/issues/46354 +// func sendCtrlC(pid int) error { +// r, _, e := generateConsoleCtrlEvent.Call(syscall.CTRL_C_EVENT, uintptr(pid)) +// if r == 0 { +// return e +// } else { +// return nil +// } +// } + func GetReadBuffer(rawconn syscall.RawConn) (int, error) { - return 0, errors.New("unsupported OS") + var err error + var bufsize int + err2 := rawconn.Control(func(fd uintptr) { + bufsize, err = getsockoptInt(syscall.Handle(fd), syscall.SOL_SOCKET, syscall.SO_RCVBUF) + }) + if err2 != nil { + return 0, err2 + } + if err != nil { + return 0, err + } + return bufsize, nil } func GetAvailableBytes(rawconn syscall.RawConn) (int, error) { - return 0, errors.New("unsupported OS") + var err error + var avail int + err2 := rawconn.Control(func(fd uintptr) { + avail, err = ioctlSocket(syscall.Handle(fd), FIONREAD) + }) + if err2 != nil { + return 0, err2 + } + if err != nil { + return 0, err + } + return avail, nil + } diff --git a/pkg/watcher/watcher.go b/pkg/watcher/watcher.go index ef23174..2d583c2 100644 --- a/pkg/watcher/watcher.go +++ b/pkg/watcher/watcher.go @@ -2,7 +2,9 @@ package watcher import ( "context" + "oneway-filesync/pkg/config" "oneway-filesync/pkg/database" + "os" "path/filepath" "time" @@ -11,26 +13,39 @@ import ( "gorm.io/gorm" ) +func isDirectory(path string) (bool, error) { + fileInfo, err := os.Stat(path) + if err != nil { + return false, err + } + + return fileInfo.IsDir(), err +} + type watcherConfig struct { db *gorm.DB input chan notify.EventInfo cache map[string]time.Time } -// To save up on resources we only send files that haven't changed for the past 60 seconds +// To save up on resources we only send files that haven't changed for the past 30 seconds // otherwise many consecutive small changes will cause a large overhead on the sender/receiver func worker(ctx context.Context, conf *watcherConfig) { - ticker := time.NewTicker(15 * time.Second) + ticker := time.NewTicker(10 * time.Second) for { select { case <-ctx.Done(): notify.Stop(conf.input) return case ei := <-conf.input: - conf.cache[ei.Path()] = time.Now() + isdir, err := isDirectory(ei.Path()) + if err == nil && !isdir { + conf.cache[ei.Path()] = time.Now() + logrus.Infof("Noticed change in file '%s'", ei.Path()) + } case <-ticker.C: for path, lastupdated := range conf.cache { - if time.Since(lastupdated).Seconds() > 60 { + if time.Since(lastupdated).Seconds() > 30 { delete(conf.cache, path) err := database.QueueFileForSending(conf.db, path) if err != nil { @@ -46,8 +61,7 @@ func worker(ctx context.Context, conf *watcherConfig) { func CreateWatcher(ctx context.Context, db *gorm.DB, watchdir string, input chan notify.EventInfo) { if err := notify.Watch(filepath.Join(watchdir, "..."), input, notify.Write, notify.Create); err != nil { - logrus.Errorf("%v", err) - return + logrus.Fatalf("%v", err) } conf := watcherConfig{ db: db, @@ -56,3 +70,9 @@ func CreateWatcher(ctx context.Context, db *gorm.DB, watchdir string, input chan } go worker(ctx, &conf) } + +func Watcher(ctx context.Context, db *gorm.DB, conf config.Config) { + events := make(chan notify.EventInfo, 500) + + CreateWatcher(ctx, db, conf.WatchDir, events) +} diff --git a/tests/system_test.go b/tests/system_test.go index 4a8d009..4419322 100644 --- a/tests/system_test.go +++ b/tests/system_test.go @@ -10,6 +10,7 @@ import ( "oneway-filesync/pkg/database" "oneway-filesync/pkg/receiver" "oneway-filesync/pkg/sender" + "oneway-filesync/pkg/watcher" "os" "path/filepath" "strings" @@ -68,12 +69,11 @@ func pathReplace(path string) string { return newpath } -func waitForFinishedFile(t *testing.T, db *gorm.DB, path string, timeout time.Duration, outdir string) { - start := time.Now() +func waitForFinishedFile(t *testing.T, db *gorm.DB, path string, endtime time.Time, outdir string) { ticker := time.NewTicker(1 * time.Second) for { <-ticker.C - if time.Since(start) > timeout { + if time.Now().After(endtime) { t.Fatalf("File '%s' did not transfer in time", path) } var file database.File @@ -86,13 +86,14 @@ func waitForFinishedFile(t *testing.T, db *gorm.DB, path string, timeout time.Du diff := getDiff(t, path, tmpfilepath) t.Fatalf("File '%s' transferred but not successfully %d different bytes", path, diff) } else { + t.Logf("File '%s' transferred successfully", path) return } } } -func tempFile(t *testing.T, size int) string { - file, err := os.CreateTemp("", "") +func tempFile(t *testing.T, size int, tmpdir string) string { + file, err := os.CreateTemp(tmpdir, "") if err != nil { log.Fatal(err) } @@ -101,7 +102,11 @@ func tempFile(t *testing.T, size int) string { if err != nil { log.Fatal(err) } - return file.Name() + tempfilepath, err := filepath.Abs(file.Name()) + if err != nil { + log.Fatal(err) + } + return tempfilepath } func setupTest(t *testing.T, conf config.Config) (*gorm.DB, *gorm.DB, func()) { @@ -109,28 +114,31 @@ func setupTest(t *testing.T, conf config.Config) (*gorm.DB, *gorm.DB, func()) { if err != nil { t.Fatalf("Failed setting up db with err: %v\n", err) } - if err := database.ConfigureDatabase(senderdb); err != nil { - t.Fatalf("Failed setting up db with err: %v\n", err) - } receiverdb, err := database.OpenDatabase("t_r_") if err != nil { t.Fatalf("Failed setting up db with err: %v\n", err) } - if err := database.ConfigureDatabase(receiverdb); err != nil { - t.Fatalf("Failed setting up db with err: %v\n", err) - } if err := os.MkdirAll(conf.OutDir, os.ModePerm); err != nil { t.Fatalf("Failed creating outdir with err: %v\n", err) } + if err := os.MkdirAll(conf.WatchDir, os.ModePerm); err != nil { + t.Fatalf("Failed creating watchdir with err: %v\n", err) + } + ctx, cancel := context.WithCancel(context.Background()) // Create a cancelable context and pass it to all goroutines, allows us to gracefully shut down the program receiver.Receiver(ctx, receiverdb, conf) sender.Sender(ctx, senderdb, conf) + watcher.Watcher(ctx, senderdb, conf) return senderdb, receiverdb, func() { cancel() + time.Sleep(2 * time.Second) + if err := os.RemoveAll(conf.WatchDir); err != nil { + t.Log(err) + } if err := os.RemoveAll(conf.OutDir); err != nil { t.Log(err) } @@ -140,6 +148,16 @@ func setupTest(t *testing.T, conf config.Config) (*gorm.DB, *gorm.DB, func()) { if err := database.ClearDatabase(senderdb); err != nil { t.Log(err) } + if indb, err := receiverdb.DB(); err == nil { + if err := indb.Close(); err != nil { + t.Log(err) + } + } + if indb, err := senderdb.DB(); err == nil { + if err := indb.Close(); err != nil { + t.Log(err) + } + } if err := os.Remove(database.DBFILE); err != nil { t.Log(err) } @@ -155,6 +173,7 @@ func TestSetup(t *testing.T) { ChunkFecRequired: 5, ChunkFecTotal: 10, OutDir: "tests_out", + WatchDir: "tests_watch", }) defer teardowntest() } @@ -168,39 +187,90 @@ func TestSmallFile(t *testing.T) { ChunkFecRequired: 5, ChunkFecTotal: 10, OutDir: "tests_out", + WatchDir: "tests_watch", } senderdb, receiverdb, teardowntest := setupTest(t, conf) defer teardowntest() - testfile := tempFile(t, 500) + testfile := tempFile(t, 500, "") defer os.Remove(testfile) err := database.QueueFileForSending(senderdb, testfile) if err != nil { t.Fatal(err) } - waitForFinishedFile(t, receiverdb, testfile, time.Minute, conf.OutDir) + waitForFinishedFile(t, receiverdb, testfile, time.Now().Add(time.Minute), conf.OutDir) } func TestLargeFile(t *testing.T) { conf := config.Config{ ReceiverIP: "127.0.0.1", ReceiverPort: 5000, - BandwidthLimit: 4 * 1024 * 1024, + BandwidthLimit: 1024 * 1024, ChunkSize: 8192, ChunkFecRequired: 5, ChunkFecTotal: 10, OutDir: "tests_out", + WatchDir: "tests_watch", } senderdb, receiverdb, teardowntest := setupTest(t, conf) defer teardowntest() - testfile := tempFile(t, 50*1024*1024) + testfile := tempFile(t, 20*1024*1024, "") defer os.Remove(testfile) err := database.QueueFileForSending(senderdb, testfile) if err != nil { t.Fatal(err) } - waitForFinishedFile(t, receiverdb, testfile, time.Minute*2, conf.OutDir) + waitForFinishedFile(t, receiverdb, testfile, time.Now().Add(time.Minute*2), conf.OutDir) +} + +func TestWatcherFiles(t *testing.T) { + conf := config.Config{ + ReceiverIP: "127.0.0.1", + ReceiverPort: 5000, + BandwidthLimit: 1024 * 1024, + ChunkSize: 8192, + ChunkFecRequired: 5, + ChunkFecTotal: 10, + OutDir: "tests_out", + WatchDir: "tests_watch", + } + _, receiverdb, teardowntest := setupTest(t, conf) + defer teardowntest() + + for i := 0; i < 30; i++ { + tempfile := tempFile(t, 30000, conf.WatchDir) + defer os.Remove(tempfile) + defer waitForFinishedFile(t, receiverdb, tempfile, time.Now().Add(time.Minute*5), conf.OutDir) + } + tmpdir1 := filepath.Join(conf.WatchDir, "tmp1") + err := os.Mkdir(tmpdir1, os.ModePerm) + if err != nil { + t.Fatal(err) + } + defer os.Remove(tmpdir1) + time.Sleep(time.Second) + + for i := 0; i < 10; i++ { + tempfile := tempFile(t, 30000, tmpdir1) + defer os.Remove(tempfile) + defer waitForFinishedFile(t, receiverdb, tempfile, time.Now().Add(time.Minute*5), conf.OutDir) + } + + tmpdir2 := filepath.Join(tmpdir1, "tmp2") + err = os.Mkdir(tmpdir2, os.ModePerm) + if err != nil { + t.Fatal(err) + } + defer os.Remove(tmpdir2) + time.Sleep(time.Second) + + for i := 0; i < 10; i++ { + tempfile := tempFile(t, 30000, tmpdir2) + defer os.Remove(tempfile) + defer waitForFinishedFile(t, receiverdb, tempfile, time.Now().Add(time.Minute*5), conf.OutDir) + } + }