From e3cd659eb3769e51645466d368bac1d99d1744ae Mon Sep 17 00:00:00 2001 From: krau <71133316+krau@users.noreply.github.com> Date: Sat, 9 Nov 2024 11:34:28 +0800 Subject: [PATCH] feat: save file by cmd --- bot/handlers.go | 107 ++++++++++++++++++++++++++++++++++++++++++++++-- bot/utils.go | 3 +- 2 files changed, 105 insertions(+), 5 deletions(-) diff --git a/bot/handlers.go b/bot/handlers.go index ab741df..2d0489e 100644 --- a/bot/handlers.go +++ b/bot/handlers.go @@ -130,13 +130,114 @@ func setDefaultStorage(ctx *ext.Context, update *ext.Update) error { } func saveCmd(ctx *ext.Context, update *ext.Update) error { - // TODO: Implement save command + res, ok := update.EffectiveMessage.GetReplyTo() + if !ok || res == nil { + ctx.Reply(update, "请回复要保存的文件", nil) + return dispatcher.EndGroups + } + replyHeader, ok := res.(*tg.MessageReplyHeader) + if !ok { + ctx.Reply(update, "请回复要保存的文件", nil) + return dispatcher.EndGroups + } + replyToMsgID, ok := replyHeader.GetReplyToMsgID() + if !ok { + ctx.Reply(update, "请回复要保存的文件", nil) + return dispatcher.EndGroups + } + msg, err := GetTGMessage(ctx, Client, replyToMsgID) + + supported, _ := supportedMediaFilter(msg) + if !supported { + ctx.Reply(update, "不支持的消息类型或消息中没有文件", nil) + return dispatcher.EndGroups + } + + user, err := dao.GetUserByUserID(update.GetUserChat().GetID()) + if err != nil { + logger.L.Errorf("Failed to get user: %s", err) + return dispatcher.EndGroups + } + + replied, err := ctx.Reply(update, "正在获取文件信息...", nil) + if err != nil { + logger.L.Errorf("Failed to reply: %s", err) + return dispatcher.EndGroups + } + file, err := FileFromMessage(ctx, Client, update.EffectiveChat().GetID(), msg.ID) + if err != nil { + logger.L.Errorf("Failed to get file from message: %s", err) + ctx.EditMessage(update.EffectiveChat().GetID(), &tg.MessagesEditMessageRequest{ + Message: "无法获取文件", + ID: replied.ID, + }) + return dispatcher.EndGroups + } + + if file.FileName == "" { + ctx.EditMessage(update.EffectiveChat().GetID(), &tg.MessagesEditMessageRequest{ + Message: "无法获取文件名", + ID: replied.ID, + }) + return dispatcher.EndGroups + } + + if err := dao.AddReceivedFile(&types.ReceivedFile{ + Processing: false, + FileName: file.FileName, + ChatID: update.EffectiveChat().GetID(), + MessageID: replyToMsgID, + ReplyMessageID: replied.ID, + }); err != nil { + logger.L.Errorf("Failed to add received file: %s", err) + if _, err := ctx.EditMessage(update.EffectiveChat().GetID(), &tg.MessagesEditMessageRequest{ + Message: "无法保存文件", + ID: replied.ID, + }); err != nil { + logger.L.Errorf("Failed to edit message: %s", err) + } + return dispatcher.EndGroups + } + + if !user.Silent { + text := "请选择存储位置" + _, err = ctx.EditMessage(update.EffectiveChat().GetID(), &tg.MessagesEditMessageRequest{ + Message: text, + ReplyMarkup: getAddTaskMarkup(msg.ID), + ID: replied.ID, + }) + if err != nil { + logger.L.Errorf("Failed to reply: %s", err) + } + return dispatcher.EndGroups + } + + if user.DefaultStorage == "" { + ctx.Reply(update, "请先使用 /storage 设置默认存储位置", nil) + return dispatcher.EndGroups + } + queue.AddTask(types.Task{ + Ctx: ctx, + Status: types.Pending, + File: file, + Storage: types.StorageType(user.DefaultStorage), + ChatID: update.EffectiveChat().GetID(), + ReplyMessageID: replied.ID, + MessageID: msg.ID, + }) + _, err = ctx.EditMessage(update.EffectiveChat().GetID(), &tg.MessagesEditMessageRequest{ + Message: fmt.Sprintf("已添加到队列: %s\n当前排队任务数: %d", file.FileName, queue.Len()), + ID: replied.ID, + }) + if err != nil { + logger.L.Errorf("Failed to edit message: %s", err) + } return dispatcher.EndGroups } func handleFileMessage(ctx *ext.Context, update *ext.Update) error { logger.L.Trace("Got media: ", update.EffectiveMessage.Media.TypeName()) - supported, err := supportedMediaFilter(update.EffectiveMessage) + supported, err := supportedMediaFilter(update.EffectiveMessage.Message) if err != nil { return err } @@ -226,7 +327,7 @@ func handleFileMessage(ctx *ext.Context, update *ext.Update) error { func AddToQueue(ctx *ext.Context, update *ext.Update) error { args := strings.Split(string(update.CallbackQuery.Data), " ") messageID, _ := strconv.Atoi(args[1]) - logger.L.Trace("Got add to queue: chatID: %d, messageID: %d, storage: %s", update.EffectiveChat().GetID(), messageID, args[2]) + logger.L.Tracef("Got add to queue: chatID: %d, messageID: %d, storage: %s", update.EffectiveChat().GetID(), messageID, args[2]) record, err := dao.GetReceivedFileByChatAndMessageID(update.EffectiveChat().GetID(), messageID) if err != nil { logger.L.Errorf("Failed to get received file: %s", err) diff --git a/bot/utils.go b/bot/utils.go index c928287..31a8778 100644 --- a/bot/utils.go +++ b/bot/utils.go @@ -6,7 +6,6 @@ import ( "github.com/celestix/gotgproto" "github.com/celestix/gotgproto/dispatcher" - tgTypes "github.com/celestix/gotgproto/types" "github.com/gotd/td/tg" "github.com/krau/SaveAny-Bot/common" "github.com/krau/SaveAny-Bot/logger" @@ -14,7 +13,7 @@ import ( "github.com/krau/SaveAny-Bot/types" ) -func supportedMediaFilter(m *tgTypes.Message) (bool, error) { +func supportedMediaFilter(m *tg.Message) (bool, error) { if not := m.Media == nil; not { return false, dispatcher.EndGroups }