From 9795a85dd617f2941e7d6a7c26d433e1035b3f13 Mon Sep 17 00:00:00 2001 From: krau <71133316+krau@users.noreply.github.com> Date: Mon, 20 Jan 2025 11:09:38 +0800 Subject: [PATCH] refactor: file download process --- core/core.go | 157 ++++++++++++++------------------------------------ core/utils.go | 62 ++++++++++++++++++++ 2 files changed, 104 insertions(+), 115 deletions(-) create mode 100644 core/utils.go diff --git a/core/core.go b/core/core.go index d77b794..3f6fdbc 100644 --- a/core/core.go +++ b/core/core.go @@ -7,27 +7,32 @@ import ( "io" "os" "path/filepath" - "time" "github.com/celestix/gotgproto/ext" "github.com/gotd/td/tg" "github.com/krau/SaveAny-Bot/bot" - "github.com/krau/SaveAny-Bot/common" "github.com/krau/SaveAny-Bot/config" "github.com/krau/SaveAny-Bot/logger" "github.com/krau/SaveAny-Bot/queue" - "github.com/krau/SaveAny-Bot/storage" "github.com/krau/SaveAny-Bot/types" ) func processPendingTask(task *types.Task) error { logger.L.Debugf("Start processing task: %s", task.String()) os.MkdirAll(config.Cfg.Temp.BasePath, os.ModePerm) - task.Ctx.(*ext.Context).EditMessage(task.ChatID, &tg.MessagesEditMessageRequest{ + + ctx := task.Ctx.(*ext.Context) + ctx.EditMessage(task.ChatID, &tg.MessagesEditMessageRequest{ Message: "正在下载: " + task.String(), ID: task.ReplyMessageID, }) + destPath := filepath.Join(config.Cfg.Temp.BasePath, task.File.FileName) + if task.StoragePath == "" { + task.StoragePath = task.File.FileName + } + + // process photo if task.File.FileSize == 0 { res, err := bot.Client.API().UploadGetFile(task.Ctx, &tg.UploadGetFileRequest{ Location: task.File.Location, @@ -37,105 +42,55 @@ func processPendingTask(task *types.Task) error { if err != nil { return fmt.Errorf("Failed to get file: %w", err) } - switch result := res.(type) { - case *tg.UploadFile: - dest, err := os.Create(filepath.Join(config.Cfg.Temp.BasePath, task.File.FileName)) - if err != nil { - return fmt.Errorf("Failed to create file: %w", err) - } - defer dest.Close() - destName := dest.Name() - if err := os.WriteFile(destName, result.Bytes, os.ModePerm); err != nil { - return fmt.Errorf("Failed to write file: %w", err) - } + result, ok := res.(*tg.UploadFile) + if !ok { + return fmt.Errorf("unexpected type %T", res) + } - defer func() { - if config.Cfg.Temp.CacheTTL > 0 { - common.RmFileAfter(destName, time.Duration(config.Cfg.Temp.CacheTTL)*time.Second) - } else { - if err := os.Remove(destName); err != nil { - logger.L.Errorf("Failed to purge file: %s", err) - } - } - }() + if err := os.WriteFile(destPath, result.Bytes, os.ModePerm); err != nil { + return fmt.Errorf("Failed to write file: %w", err) + } - if task.StoragePath == "" { - task.StoragePath = task.File.FileName - } + defer cleanCacheFile(destPath) - logger.L.Infof("Downloaded file: %s", dest.Name()) - task.Ctx.(*ext.Context).EditMessage(task.ChatID, &tg.MessagesEditMessageRequest{ - Message: fmt.Sprintf("下载完成: %s\n正在转存文件...", task.FileName()), - ID: task.ReplyMessageID, - }) - if config.Cfg.Retry <= 0 { - if err := storage.Save(task.Storage, task.Ctx, dest.Name(), task.StoragePath); err != nil { - return fmt.Errorf("Failed to save file: %w", err) - } - } else { - for i := 0; i < config.Cfg.Retry; i++ { - if err := storage.Save(task.Storage, task.Ctx, dest.Name(), task.StoragePath); err != nil { - logger.L.Errorf("Failed to save file: %s, retrying...", err) - if i == config.Cfg.Retry-1 { - return fmt.Errorf("Failed to save file: %w", err) - } - } else { - break - } - } - } - return nil + logger.L.Infof("Downloaded file: %s", destPath) + ctx.EditMessage(task.ChatID, &tg.MessagesEditMessageRequest{ + Message: fmt.Sprintf("下载完成: %s\n正在转存文件...", task.FileName()), + ID: task.ReplyMessageID, + }) - default: - return fmt.Errorf("unexpected type %T", res) - } + return saveFileWithRetry(task, destPath) } - barTotalCount := 5 - if task.File.FileSize > 1024*1024*200 { - barTotalCount = 10 - } else if task.File.FileSize > 1024*1024*500 { - barTotalCount = 20 - } else if task.File.FileSize > 1024*1024*1000 { - barTotalCount = 50 - } + barTotalCount := calculateBarTotalCount(task.File.FileSize) - readCloser, err := NewTelegramReader(task.Ctx, bot.Client, &task.File.Location, 0, task.File.FileSize-1, task.File.FileSize, func(bytesRead, contentLength int64) { + progressCallback := func(bytesRead, contentLength int64) { progress := float64(bytesRead) / float64(contentLength) * 100 logger.L.Tracef("Downloading %s: %.2f%%", task.String(), progress) - if task.File.FileSize < 1024*1024*50 { - return - } - - barSize := 100 / barTotalCount - if int(progress)%barSize != 0 { + if task.File.FileSize < 1024*1024*50 || int(progress)%(100/barTotalCount) != 0 { return } - - text := fmt.Sprintf("正在下载: %s\n[%s] %.2f%%", task.String(), func() string { - bar := "" - for i := 0; i < barTotalCount; i++ { - if int(progress)/barSize > i { - bar += "█" - } else { - bar += "░" - } - } - return bar - }(), progress) - task.Ctx.(*ext.Context).EditMessage(task.ChatID, &tg.MessagesEditMessageRequest{ + text := fmt.Sprintf("正在下载: %s\n[%s] %.2f%%", + task.String(), + getProgressBar(progress, barTotalCount), + progress, + ) + ctx.EditMessage(task.ChatID, &tg.MessagesEditMessageRequest{ Message: text, ID: task.ReplyMessageID, }) - }, task.File.FileSize/100) + } + readCloser, err := NewTelegramReader(task.Ctx, bot.Client, &task.File.Location, + 0, task.File.FileSize-1, task.File.FileSize, + progressCallback, task.File.FileSize/100) if err != nil { return fmt.Errorf("Failed to create reader: %w", err) } defer readCloser.Close() - dest, err := os.Create(filepath.Join(config.Cfg.Temp.BasePath, task.File.FileName)) + dest, err := os.Create(destPath) if err != nil { return fmt.Errorf("Failed to create file: %w", err) } @@ -144,44 +99,16 @@ func processPendingTask(task *types.Task) error { if _, err := io.CopyN(dest, readCloser, task.File.FileSize); err != nil { return fmt.Errorf("Failed to download file: %w", err) } - destName := dest.Name() - - defer func() { - if config.Cfg.Temp.CacheTTL > 0 { - common.RmFileAfter(destName, time.Duration(config.Cfg.Temp.CacheTTL)*time.Second) - } else { - if err := os.Remove(destName); err != nil { - logger.L.Errorf("Failed to purge file: %s", err) - } - } - }() - if task.StoragePath == "" { - task.StoragePath = task.File.FileName - } + defer cleanCacheFile(destPath) - logger.L.Infof("Downloaded file: %s", dest.Name()) - task.Ctx.(*ext.Context).EditMessage(task.ChatID, &tg.MessagesEditMessageRequest{ + logger.L.Infof("Downloaded file: %s", destPath) + ctx.EditMessage(task.ChatID, &tg.MessagesEditMessageRequest{ Message: fmt.Sprintf("下载完成: %s\n正在转存文件...", task.FileName()), ID: task.ReplyMessageID, }) - if config.Cfg.Retry <= 0 { - if err := storage.Save(task.Storage, task.Ctx, dest.Name(), task.StoragePath); err != nil { - return fmt.Errorf("Failed to save file: %w", err) - } - } else { - for i := 0; i < config.Cfg.Retry; i++ { - if err := storage.Save(task.Storage, task.Ctx, dest.Name(), task.StoragePath); err != nil { - logger.L.Errorf("Failed to save file: %s, retrying...", err) - if i == config.Cfg.Retry-1 { - return fmt.Errorf("Failed to save file: %w", err) - } - } else { - break - } - } - } - return nil + + return saveFileWithRetry(task, destPath) } func worker(queue *queue.TaskQueue, semaphore chan struct{}) { diff --git a/core/utils.go b/core/utils.go new file mode 100644 index 0000000..47497ba --- /dev/null +++ b/core/utils.go @@ -0,0 +1,62 @@ +package core + +import ( + "fmt" + "os" + "time" + + "github.com/krau/SaveAny-Bot/common" + "github.com/krau/SaveAny-Bot/config" + "github.com/krau/SaveAny-Bot/logger" + "github.com/krau/SaveAny-Bot/storage" + "github.com/krau/SaveAny-Bot/types" +) + +func saveFileWithRetry(task *types.Task, destPath string) error { + for i := 0; i <= config.Cfg.Retry; i++ { + if err := storage.Save(task.Storage, task.Ctx, destPath, task.StoragePath); err != nil { + if i == config.Cfg.Retry { + return fmt.Errorf("Failed to save file: %w", err) + } + logger.L.Errorf("Failed to save file: %s, retrying...", err) + continue + } + return nil + } + return nil +} + +func getProgressBar(progress float64, totalCount int) string { + bar := "" + barSize := 100 / totalCount + for i := 0; i < totalCount; i++ { + if int(progress)/barSize > i { + bar += "█" + } else { + bar += "░" + } + } + return bar +} + +func cleanCacheFile(destPath string) { + if config.Cfg.Temp.CacheTTL > 0 { + common.RmFileAfter(destPath, time.Duration(config.Cfg.Temp.CacheTTL)*time.Second) + } else { + if err := os.Remove(destPath); err != nil { + logger.L.Errorf("Failed to purge file: %s", err) + } + } +} + +func calculateBarTotalCount(fileSize int64) int { + barTotalCount := 5 + if fileSize > 1024*1024*1000 { + barTotalCount = 50 + } else if fileSize > 1024*1024*500 { + barTotalCount = 20 + } else if fileSize > 1024*1024*200 { + barTotalCount = 10 + } + return barTotalCount +}