Skip to content

Commit

Permalink
feat: save file by cmd
Browse files Browse the repository at this point in the history
  • Loading branch information
krau committed Nov 9, 2024
1 parent 454d69c commit e3cd659
Show file tree
Hide file tree
Showing 2 changed files with 105 additions and 5 deletions.
107 changes: 104 additions & 3 deletions bot/handlers.go
Original file line number Diff line number Diff line change
Expand Up @@ -130,13 +130,114 @@ func setDefaultStorage(ctx *ext.Context, update *ext.Update) error {
}

func saveCmd(ctx *ext.Context, update *ext.Update) error {
// TODO: Implement save command
res, ok := update.EffectiveMessage.GetReplyTo()
if !ok || res == nil {
ctx.Reply(update, "请回复要保存的文件", nil)
return dispatcher.EndGroups
}
replyHeader, ok := res.(*tg.MessageReplyHeader)
if !ok {
ctx.Reply(update, "请回复要保存的文件", nil)
return dispatcher.EndGroups
}
replyToMsgID, ok := replyHeader.GetReplyToMsgID()
if !ok {
ctx.Reply(update, "请回复要保存的文件", nil)
return dispatcher.EndGroups
}
msg, err := GetTGMessage(ctx, Client, replyToMsgID)

supported, _ := supportedMediaFilter(msg)
if !supported {
ctx.Reply(update, "不支持的消息类型或消息中没有文件", nil)
return dispatcher.EndGroups
}

user, err := dao.GetUserByUserID(update.GetUserChat().GetID())
if err != nil {
logger.L.Errorf("Failed to get user: %s", err)
return dispatcher.EndGroups
}

replied, err := ctx.Reply(update, "正在获取文件信息...", nil)
if err != nil {
logger.L.Errorf("Failed to reply: %s", err)
return dispatcher.EndGroups
}
file, err := FileFromMessage(ctx, Client, update.EffectiveChat().GetID(), msg.ID)
if err != nil {
logger.L.Errorf("Failed to get file from message: %s", err)
ctx.EditMessage(update.EffectiveChat().GetID(), &tg.MessagesEditMessageRequest{
Message: "无法获取文件",
ID: replied.ID,
})
return dispatcher.EndGroups
}

if file.FileName == "" {
ctx.EditMessage(update.EffectiveChat().GetID(), &tg.MessagesEditMessageRequest{
Message: "无法获取文件名",
ID: replied.ID,
})
return dispatcher.EndGroups
}

if err := dao.AddReceivedFile(&types.ReceivedFile{
Processing: false,
FileName: file.FileName,
ChatID: update.EffectiveChat().GetID(),
MessageID: replyToMsgID,
ReplyMessageID: replied.ID,
}); err != nil {
logger.L.Errorf("Failed to add received file: %s", err)
if _, err := ctx.EditMessage(update.EffectiveChat().GetID(), &tg.MessagesEditMessageRequest{
Message: "无法保存文件",
ID: replied.ID,
}); err != nil {
logger.L.Errorf("Failed to edit message: %s", err)
}
return dispatcher.EndGroups
}

if !user.Silent {
text := "请选择存储位置"
_, err = ctx.EditMessage(update.EffectiveChat().GetID(), &tg.MessagesEditMessageRequest{
Message: text,
ReplyMarkup: getAddTaskMarkup(msg.ID),
ID: replied.ID,
})
if err != nil {
logger.L.Errorf("Failed to reply: %s", err)
}
return dispatcher.EndGroups
}

if user.DefaultStorage == "" {
ctx.Reply(update, "请先使用 /storage 设置默认存储位置", nil)
return dispatcher.EndGroups
}
queue.AddTask(types.Task{
Ctx: ctx,
Status: types.Pending,
File: file,
Storage: types.StorageType(user.DefaultStorage),
ChatID: update.EffectiveChat().GetID(),
ReplyMessageID: replied.ID,
MessageID: msg.ID,
})
_, err = ctx.EditMessage(update.EffectiveChat().GetID(), &tg.MessagesEditMessageRequest{
Message: fmt.Sprintf("已添加到队列: %s\n当前排队任务数: %d", file.FileName, queue.Len()),
ID: replied.ID,
})
if err != nil {
logger.L.Errorf("Failed to edit message: %s", err)
}
return dispatcher.EndGroups
}

func handleFileMessage(ctx *ext.Context, update *ext.Update) error {
logger.L.Trace("Got media: ", update.EffectiveMessage.Media.TypeName())
supported, err := supportedMediaFilter(update.EffectiveMessage)
supported, err := supportedMediaFilter(update.EffectiveMessage.Message)
if err != nil {
return err
}
Expand Down Expand Up @@ -226,7 +327,7 @@ func handleFileMessage(ctx *ext.Context, update *ext.Update) error {
func AddToQueue(ctx *ext.Context, update *ext.Update) error {
args := strings.Split(string(update.CallbackQuery.Data), " ")
messageID, _ := strconv.Atoi(args[1])
logger.L.Trace("Got add to queue: chatID: %d, messageID: %d, storage: %s", update.EffectiveChat().GetID(), messageID, args[2])
logger.L.Tracef("Got add to queue: chatID: %d, messageID: %d, storage: %s", update.EffectiveChat().GetID(), messageID, args[2])
record, err := dao.GetReceivedFileByChatAndMessageID(update.EffectiveChat().GetID(), messageID)
if err != nil {
logger.L.Errorf("Failed to get received file: %s", err)
Expand Down
3 changes: 1 addition & 2 deletions bot/utils.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,15 +6,14 @@ import (

"github.com/celestix/gotgproto"
"github.com/celestix/gotgproto/dispatcher"
tgTypes "github.com/celestix/gotgproto/types"
"github.com/gotd/td/tg"
"github.com/krau/SaveAny-Bot/common"
"github.com/krau/SaveAny-Bot/logger"
"github.com/krau/SaveAny-Bot/storage"
"github.com/krau/SaveAny-Bot/types"
)

func supportedMediaFilter(m *tgTypes.Message) (bool, error) {
func supportedMediaFilter(m *tg.Message) (bool, error) {
if not := m.Media == nil; not {
return false, dispatcher.EndGroups
}
Expand Down

0 comments on commit e3cd659

Please sign in to comment.