diff --git a/core/core.go b/core/core.go index 20c72a5..d213b0b 100644 --- a/core/core.go +++ b/core/core.go @@ -21,17 +21,50 @@ import ( func processPendingTask(task *types.Task) error { logger.L.Debugf("Start processing task: %s", task.String()) - os.MkdirAll(config.Cfg.Temp.BasePath, os.ModePerm) - - logger.L.Debugf("Start downloading file: %s", task.String()) - task.Ctx.(*ext.Context).EditMessage(task.ChatID, &tg.MessagesEditMessageRequest{ - Message: "开始下载文件...", + Message: "正在下载文件...", ID: task.ReplyMessageID, }) - readCloser, err := NewTelegramReader(task.Ctx, bot.Client, task.File.Location, 0, task.File.FileSize-1, task.File.FileSize) + barTotalCount := 5 + if task.File.FileSize > 1024*1024*200 { + barTotalCount = 10 + } else if task.File.FileSize > 1024*1024*500 { + barTotalCount = 20 + } else if task.File.FileSize > 1024*1024*1000 { + barTotalCount = 50 + } + + readCloser, err := NewTelegramReader(task.Ctx, bot.Client, task.File.Location, 0, task.File.FileSize-1, task.File.FileSize, func(bytesRead, contentLength int64) { + progress := float64(bytesRead) / float64(contentLength) * 100 + logger.L.Tracef("Downloading %s: %.2f%%", task.String(), progress) + if task.File.FileSize < 1024*1024*50 { + return + } + + barSize := 100 / barTotalCount + if int(progress)%barSize != 0 { + return + } + + text := fmt.Sprintf("正在下载文件\n[%s] %.2f%%", func() string { + bar := "" + for i := 0; i < barTotalCount; i++ { + if int(progress)/barSize > i { + bar += "█" + } else { + bar += "░" + } + } + return bar + }(), progress) + task.Ctx.(*ext.Context).EditMessage(task.ChatID, &tg.MessagesEditMessageRequest{ + Message: text, + ID: task.ReplyMessageID, + }) + }, task.File.FileSize/100) + if err != nil { return fmt.Errorf("Failed to create reader: %w", err) } @@ -41,18 +74,18 @@ func processPendingTask(task *types.Task) error { if err != nil { return fmt.Errorf("Failed to create file: %w", err) } - logger.L.Debug("Created file: ", dest.Name()) defer dest.Close() if _, err := io.CopyN(dest, readCloser, task.File.FileSize); err != nil { return fmt.Errorf("Failed to download file: %w", err) } + destName := dest.Name() defer func() { if config.Cfg.Temp.CacheTTL > 0 { - common.RmFileAfter(dest.Name(), time.Duration(config.Cfg.Temp.CacheTTL)*time.Second) + common.RmFileAfter(destName, time.Duration(config.Cfg.Temp.CacheTTL)*time.Second) } else { - if err := os.Remove(dest.Name()); err != nil { + if err := os.Remove(destName); err != nil { logger.L.Errorf("Failed to purge file: %s", err) } } diff --git a/core/reader.go b/core/reader.go index ad07c0a..ebef27a 100644 --- a/core/reader.go +++ b/core/reader.go @@ -10,17 +10,20 @@ import ( ) type telegramReader struct { - ctx context.Context - client *gotgproto.Client - location *tg.InputDocumentFileLocation - start int64 - end int64 - next func() ([]byte, error) - buffer []byte - bytesread int64 - chunkSize int64 - i int64 - contentLength int64 + client *gotgproto.Client + location *tg.InputDocumentFileLocation + bytesread int64 + chunkSize int64 + i int64 + contentLength int64 + start int64 + end int64 + next func() ([]byte, error) + progressCallback func(bytesRead, contentLength int64) + callbackInterval int64 + lastProgress int64 + buffer []byte + ctx context.Context } func (*telegramReader) Close() error { @@ -50,6 +53,12 @@ func (r *telegramReader) Read(p []byte) (n int, err error) { n = copy(p, r.buffer[r.i:]) r.i += int64(n) r.bytesread += int64(n) + + if r.progressCallback != nil && (r.bytesread-r.lastProgress >= r.callbackInterval || r.bytesread == r.contentLength) { + r.progressCallback(r.bytesread, r.contentLength) + r.lastProgress = r.bytesread + } + return n, nil } @@ -60,16 +69,20 @@ func NewTelegramReader( start int64, end int64, contentLength int64, + progressCallback func(bytesRead, contentLength int64), + callbackInterval int64, ) (io.ReadCloser, error) { r := &telegramReader{ - ctx: ctx, - location: location, - client: client, - start: start, - end: end, - chunkSize: int64(1024 * 1024), - contentLength: contentLength, + ctx: ctx, + location: location, + client: client, + start: start, + end: end, + chunkSize: int64(1024 * 1024), + contentLength: contentLength, + progressCallback: progressCallback, + callbackInterval: callbackInterval, } r.next = r.partStream()