-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathgenerate_caption.py
64 lines (48 loc) · 2.14 KB
/
generate_caption.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
import torch
import torch.nn as nn
import pandas as pd
from torch.utils.data import DataLoader
from PIL import Image
import cv2
from dataset import RoadSignDataset
from tqdm import tqdm
from transformers import Blip2Processor, Blip2ForConditionalGeneration
import argparse
def parse_arguments() -> argparse.Namespace:
    """Parse command-line options for the captioning script.

    Returns:
        argparse.Namespace with:
            cuda (str): GPU index to use when CUDA is available.
            data_path (str): dataset CSV stem, resolved as ``dataset/<data_path>.csv``.
    """
    arg_parser = argparse.ArgumentParser(description="Generate Caption by BLIP")
    arg_parser.add_argument("--cuda", type=str, default="0", help="GPU to use")
    arg_parser.add_argument("--data_path", type=str, default="metadata_test", help="Path to testing dataset")
    return arg_parser.parse_args()
def generate_caption(image, BLIP_process, BLIP_model):
    """
    Generate a caption for a single image with BLIP-2.

    :param image: image to generate a caption for (any format the processor accepts)
    :param BLIP_process: BLIP processor that pre-processes the input
    :param BLIP_model: BLIP-2 conditional-generation model
    :return: generated caption string, whitespace-stripped
    """
    # Derive device and dtype from the model itself instead of reading the
    # module-level `device` global and hard-coding float16. The original broke
    # when this function was imported from another module (NameError on
    # `device`) and fed float16 inputs to a model that __main__ loads in
    # float32; matching the model's own dtype is correct either way.
    inputs = BLIP_process(images=[image], return_tensors="pt").to(
        BLIP_model.device, BLIP_model.dtype)
    generated_ids = BLIP_model.generate(**inputs)
    generated_text = BLIP_process.batch_decode(generated_ids, skip_special_tokens=True)[0].strip()
    return generated_text
def inference_image(BLIP, processor, data_path):
    """Caption every image in the test split and write the results to a CSV.

    :param BLIP: BLIP-2 model used for generation
    :param processor: BLIP processor used to pre-process each image
    :param data_path: dataset CSV stem, loaded from ``dataset/<data_path>.csv``
    """
    dataset = RoadSignDataset(f"dataset/{data_path}.csv", return_raw_data=True)
    rows = []
    for idx in tqdm(range(len(dataset))):
        image, label = dataset[idx]
        image_id = dataset.get_image_id(idx)
        caption = generate_caption(image, processor, BLIP)
        rows.append([image_id, label, caption])
    # Persist one row per image: its id, ground-truth label, and the caption.
    # NOTE(review): "Bilp" in the output filename looks like a typo for "Blip";
    # kept as-is since downstream scripts may expect this exact name.
    frame = pd.DataFrame(rows, columns=["image_id", "label", "caption"])
    frame.to_csv(f"result_Bilp_caption_{data_path}.csv", index=False)
if __name__ == "__main__":
    # No manual image pre-processing needed: the BLIP processor handles it.
    args = parse_arguments()
    device = "cpu"
    if torch.cuda.is_available():
        device = f"cuda:{args.cuda}"
    processor = Blip2Processor.from_pretrained("Salesforce/blip2-opt-2.7b")
    # Load the weights in float16 so they match the half-precision inputs that
    # generate_caption builds; the original float32 load caused a dtype
    # mismatch inside the vision tower on CUDA. The trailing .to(device) is
    # dropped: device_map={"": device} already places the model, and calling
    # .to() on an accelerate-dispatched model is rejected by newer versions.
    BLIP = Blip2ForConditionalGeneration.from_pretrained(
        "Salesforce/blip2-opt-2.7b",
        torch_dtype=torch.float16,
        device_map={"": device},
    )
    inference_image(BLIP, processor, args.data_path)