diff --git a/.github/workflows/lint.yml b/.github/workflows/lint.yml index 6914b93c5..ab8618b08 100755 --- a/.github/workflows/lint.yml +++ b/.github/workflows/lint.yml @@ -7,6 +7,9 @@ jobs: runs-on: ubuntu-latest steps: - uses: actions/checkout@v3 + with: + submodules: true + fetch-depth: 0 - name: Set up Python uses: actions/setup-python@v4 with: diff --git a/.gitignore b/.gitignore index ea20fe526..a254579f2 100755 --- a/.gitignore +++ b/.gitignore @@ -42,4 +42,5 @@ VATEX/ lmms_eval/tasks/vatex/__pycache__/utils.cpython-310.pyc lmms_eval/tasks/mlvu/__pycache__/utils.cpython-310.pyc -scripts/ \ No newline at end of file +scripts/ +.env \ No newline at end of file diff --git a/lmms_eval/models/__init__.py b/lmms_eval/models/__init__.py index edf926f2f..06affda2f 100644 --- a/lmms_eval/models/__init__.py +++ b/lmms_eval/models/__init__.py @@ -11,6 +11,7 @@ logger.add(sys.stdout, level="WARNING") AVAILABLE_MODELS = { + "aria": "Aria", "auroracap": "AuroraCap", "batch_gpt4": "BatchGPT4", "claude": "Claude", @@ -21,45 +22,48 @@ "gpt4v": "GPT4V", "idefics2": "Idefics2", "instructblip": "InstructBLIP", + "internvideo2": "InternVideo2", "internvl": "InternVLChat", "internvl2": "InternVL2", "llama_vid": "LLaMAVid", + "llama_vision": "LlamaVision", "llava": "Llava", "llava_hf": "LlavaHf", "llava_onevision": "Llava_OneVision", "llava_onevision_moviechat": "Llava_OneVision_MovieChat", "llava_sglang": "LlavaSglang", "llava_vid": "LlavaVid", - "slime": "Slime", "longva": "LongVA", "mantis": "Mantis", "minicpm_v": "MiniCPM_V", "minimonkey": "MiniMonkey", "moviechat": "MovieChat", "mplug_owl_video": "mplug_Owl", + "ola": "Ola", + "openai_compatible": "OpenAICompatible", + "oryx": "Oryx", "phi3v": "Phi3v", - "qwen_vl": "Qwen_VL", - "qwen2_vl": "Qwen2_VL", "qwen2_5_vl": "Qwen2_5_VL", "qwen2_5_vl_interleave": "Qwen2_5_VL_Interleave", "qwen2_audio": "Qwen2_Audio", + "qwen2_vl": "Qwen2_VL", + "qwen_vl": "Qwen_VL", "qwen_vl_api": "Qwen_VL_API", "reka": "Reka", + "ross": "Ross", + "slime": "Slime", "srt_api": "SRT_API", "tinyllava": "TinyLlava", "videoChatGPT": "VideoChatGPT", + "videochat2": "VideoChat2", "video_llava": "VideoLLaVA", "vila": "VILA", - "xcomposer2_4KHD": "XComposer2_4KHD", - "internvideo2": "InternVideo2", - "xcomposer2d5": "XComposer2D5", - "oryx": "Oryx", - "videochat2": "VideoChat2", - "llama_vision": "LlamaVision", - "aria": "Aria", - "ross": "Ross", "vita": "VITA", "docvision": "DocVision", + "vllm": "VLLM", + "xcomposer2_4KHD": "XComposer2_4KHD", + "xcomposer2d5": "XComposer2D5", + "egogpt": "EgoGPT", } diff --git a/lmms_eval/models/egogpt.py b/lmms_eval/models/egogpt.py new file mode 100644 index 000000000..f75bd5df5 --- /dev/null +++ b/lmms_eval/models/egogpt.py @@ -0,0 +1,472 @@ +import copy +import json +import logging +import math +import re +import warnings +from datetime import timedelta +from typing import List, Optional, Tuple, Union + +import numpy as np +import PIL +import torch +import transformers +from accelerate import Accelerator, DistributedType, InitProcessGroupKwargs +from accelerate.state import AcceleratorState +from decord import VideoReader, cpu +from packaging import version +from tqdm import tqdm +from transformers import AutoConfig + +from lmms_eval import utils +from lmms_eval.api.instance import Instance +from lmms_eval.api.model import lmms +from lmms_eval.api.registry import register_model + +# Suppress warnings +warnings.filterwarnings("ignore") + +# Configure logging +eval_logger = logging.getLogger("lmms-eval") + +# Enable TF32 for CUDA +torch.backends.cuda.matmul.allow_tf32 = True + +# Import LLaVA modules +try: + import copy + import os + import re + import sys + import warnings + + import numpy as np + import requests + import soundfile as sf + import torch + import whisper + from decord import VideoReader, cpu + from egogpt.constants import ( + DEFAULT_IMAGE_TOKEN, + DEFAULT_SPEECH_TOKEN, + IGNORE_INDEX, + IMAGE_TOKEN_INDEX, + SPEECH_TOKEN_INDEX, + ) + from egogpt.conversation import SeparatorStyle, conv_templates + from egogpt.mm_utils import get_model_name_from_path, process_images + from egogpt.model.builder import load_pretrained_model + from PIL import Image + from scipy.signal import resample +except ImportError as e: + eval_logger.debug(f"egogpt is not installed. Please install egogpt to use this model.\nError: {e}") + + +# Determine best attention implementation +if version.parse(torch.__version__) >= version.parse("2.1.2"): + best_fit_attn_implementation = "sdpa" +else: + best_fit_attn_implementation = "eager" + + +@register_model("egogpt") +class EgoGPT(lmms): + """ + EgoGPT Model + """ + + def __init__( + self, + pretrained: str = "checkpoints/egogpt_IT_12k_1126_zero3", + truncation: Optional[bool] = True, + device: Optional[str] = "cuda:0", + batch_size: Optional[Union[int, str]] = 1, + model_name: Optional[str] = None, + attn_implementation: Optional[str] = best_fit_attn_implementation, + device_map: Optional[str] = "cuda:0", + conv_template: Optional[str] = "qwen_1_5", + use_cache: Optional[bool] = True, + truncate_context: Optional[bool] = False, # whether to truncate the context in generation, set it False for LLaVA-1.6 + customized_config: Optional[str] = None, # ends in json + max_frames_num: Optional[int] = 32, + mm_spatial_pool_stride: Optional[int] = 2, + mm_spatial_pool_mode: Optional[str] = "bilinear", + token_strategy: Optional[str] = "single", # could be "single" or "multiple", "multiple" denotes adding multiple tokens for each frame + video_decode_backend: str = "decord", + **kwargs, + ) -> None: + super().__init__() + # Do not use kwargs for now + assert kwargs == {}, f"Unexpected kwargs: {kwargs}" + + accelerator_kwargs = InitProcessGroupKwargs(timeout=timedelta(weeks=52)) + accelerator = Accelerator(kwargs_handlers=[accelerator_kwargs]) + if accelerator.num_processes > 1: + self._device = torch.device(f"cuda:{accelerator.local_process_index}") + self.device_map = f"cuda:{accelerator.local_process_index}" + elif accelerator.num_processes == 1 and device_map == "auto": + self._device = torch.device(device) + self.device_map = device_map + else: + self._device = torch.device(f"cuda:{accelerator.local_process_index}") + self.device_map = f"cuda:{accelerator.local_process_index}" + + egogpt_model_args = {} + if attn_implementation is not None: + egogpt_model_args["attn_implementation"] = attn_implementation + + self.pretrained = pretrained + self.token_strategy = token_strategy + self.max_frames_num = max_frames_num + self.mm_spatial_pool_stride = mm_spatial_pool_stride + self.mm_spatial_pool_mode = mm_spatial_pool_mode + self.video_decode_backend = video_decode_backend + # Try to load the model with the multimodal argument + self._tokenizer, self._model, self._max_length = load_pretrained_model(pretrained, device_map=self.device_map, **egogpt_model_args) + self._image_processor = self._model.get_vision_tower().image_processor + self._config = self._model.config + self.model.eval() + self.truncation = truncation + self.batch_size_per_gpu = int(batch_size) + self.conv_template = conv_template + self.use_cache = use_cache + self.truncate_context = truncate_context + assert self.batch_size_per_gpu == 1 + + if accelerator.num_processes > 1: + assert accelerator.distributed_type in [DistributedType.FSDP, DistributedType.MULTI_GPU, DistributedType.DEEPSPEED], "Unsupported distributed type provided. Only DDP and FSDP are supported." + # If you want to use DistributedType.DEEPSPEED, you have to run accelerate config before using the model + # Also, you have to select zero stage 0 (equivalent to DDP) in order to make the prepare model works + # I tried to set different parameters in the kwargs to let default zero 2 stage works, but it didn't work. + if accelerator.distributed_type == DistributedType.DEEPSPEED: + kwargs = { + "train_micro_batch_size_per_gpu": self.batch_size_per_gpu, + "train_batch_size": self.batch_size_per_gpu * accelerator.num_processes, + } + AcceleratorState().deepspeed_plugin.deepspeed_config_process(must_match=True, **kwargs) + eval_logger.info("Detected that you are using DistributedType.DEEPSPEED. Make sure you run `accelerate config` and set zero stage to 0") + + if accelerator.distributed_type == DistributedType.FSDP or accelerator.distributed_type == DistributedType.DEEPSPEED: + self._model = accelerator.prepare(self.model) + else: + self._model = accelerator.prepare_model(self.model, evaluation_mode=True) + self.accelerator = accelerator + if self.accelerator.is_local_main_process: + eval_logger.info(f"Using {accelerator.num_processes} devices with data parallelism") + self._rank = self.accelerator.local_process_index + self._world_size = self.accelerator.num_processes + + elif accelerator.num_processes == 1 and device_map == "auto": + eval_logger.info(f"Using {accelerator.num_processes} devices with tensor parallelism") + self._rank = 0 + self._world_size = 1 + + else: + eval_logger.info(f"Using single device: {self._device}") + self.model.to(self._device) + self._rank = 0 + self._world_size = 1 + + @property + def config(self): + # return the associated transformers.AutoConfig for the given pretrained model. + return self._config + + @property + def tokenizer(self): + return self._tokenizer + + @property + def model(self): + # returns the model, unwrapping it if using Accelerate + if hasattr(self, "accelerator"): + return self.accelerator.unwrap_model(self._model) + else: + return self._model + + @property + def eot_token_id(self): + # we use EOT because end of *text* is more accurate for what we're doing than end of *sentence* + return self.tokenizer.eos_token_id + + @property + def max_length(self): + return self._max_length + + def pad_sequence(self, input_ids, batch_first, padding_value): + if self.tokenizer.padding_side == "left": + input_ids = [torch.flip(_input_ids, [0]) for _input_ids in input_ids] + input_ids = torch.nn.utils.rnn.pad_sequence(input_ids, batch_first=batch_first, padding_value=padding_value) + if self.tokenizer.padding_side == "left": + input_ids = torch.flip(input_ids, [1]) + return input_ids + + @property + def batch_size(self): + return self.batch_size_per_gpu + + @property + def device(self): + return self._device + + @property + def rank(self): + return self._rank + + @property + def world_size(self): + return self._world_size + + def tok_encode(self, string: str, left_truncate_len=None, add_special_tokens=None) -> List[int]: + """ """ + add_special_tokens = False if add_special_tokens is None else add_special_tokens + encoding = self.tokenizer.encode(string, add_special_tokens=add_special_tokens) + # left-truncate the encoded context to be at most `left_truncate_len` tokens long + if left_truncate_len: + encoding = encoding[-left_truncate_len:] + return encoding + + def tok_decode(self, tokens): + try: + return self.tokenizer.decode(tokens) + except: + return self.tokenizer.decode([tokens]) + + def loglikelihood(self, requests: List[Instance]) -> List[Tuple[float, bool]]: + raise NotImplementedError("Loglikelihood is not implemented for EgoGPT") + + def flatten(self, input): + new_list = [] + for i in input: + for j in i: + new_list.append(j) + return new_list + + def split_text(self, text, keywords): + pattern = "(" + "|".join(map(re.escape, keywords)) + ")" + parts = re.split(pattern, text) + parts = [part for part in parts if part] + return parts + + def load_video(self, video_path=None, audio_path=None, max_frames_num=16, fps=1, task_name=None): + if audio_path is not None: + speech, sample_rate = sf.read(audio_path) + if sample_rate != 16000: + target_length = int(len(speech) * 16000 / sample_rate) + speech = resample(speech, target_length) + if speech.ndim > 1: + speech = np.mean(speech, axis=1) + # max_length = 480000 + speech = whisper.pad_or_trim(speech.astype(np.float32)) + speech = whisper.log_mel_spectrogram(speech, n_mels=128).permute(1, 0) + speech_lengths = torch.LongTensor([speech.shape[0]]) + else: + speech = torch.zeros(3000, 128) + speech_lengths = torch.LongTensor([3000]) + + vr = VideoReader(video_path, ctx=cpu(0), num_threads=1) + total_frame_num = len(vr) + avg_fps = round(vr.get_avg_fps() / fps) + frame_idx = [i for i in range(0, total_frame_num, avg_fps)] + frame_time = [i / avg_fps for i in frame_idx] + + if max_frames_num > 0: + if len(frame_idx) > max_frames_num: + uniform_sampled_frames = np.linspace(0, total_frame_num - 1, max_frames_num, dtype=int) + frame_idx = uniform_sampled_frames.tolist() + if task_name == "egoplan": + # add current ovservation frame + frame_idx.append(total_frame_num - 1) + video = vr.get_batch(frame_idx).asnumpy() + return video, speech, speech_lengths + + def generate_until(self, requests: List[Instance]) -> List[str]: + res = [] + + def _collate(x): + # the negative sign on len(toks) sorts descending - this has a few advantages: + # - time estimates will always be over not underestimates, which is more useful for planning + # - to know the size of a batch when going through the list, you know the first one is always the batch + # padded context length. this is useful to simplify the batching logic and more importantly to make + # automatic adaptive batches much much easier to implement + # - any OOMs will happen right away rather than near the end + toks = self.tok_encode(x[0]) + return -len(toks), x[0] + + # we group requests by their generation_kwargs, + # so that we don't try to execute e.g. greedy sampling and temp=0.8 sampling + # in the same batch. + metadata = requests[0].metadata + re_ords = utils.Collator([reg.args for reg in requests], _collate, grouping=True) + chunks = re_ords.get_batched(n=self.batch_size, batch_fn=None) + num_iters = len(requests) // self.batch_size if len(requests) % self.batch_size == 0 else len(requests) // self.batch_size + 1 + pbar = tqdm(total=num_iters, disable=(self.rank != 0), desc="Model Responding") + + origin_image_aspect_ratio = getattr(self._config, "image_aspect_ratio", None) + + for chunk in chunks: + batched_contexts, all_gen_kwargs, batched_doc_to_visual, batched_doc_id, batched_task, batched_split = zip(*chunk) + task = batched_task[0] + split = batched_split[0] + batched_visuals = [batched_doc_to_visual[0](self.task_dict[task][split][ids]) for ids in batched_doc_id] # [B, N] + assert len(batched_visuals) == 1 + + # we assume all gen kwargs in the batch are the same + # this is safe to assume because the `grouper` object ensures it. + gen_kwargs = all_gen_kwargs[0] + if "until" in gen_kwargs: + gen_kwargs.pop("until") + + question_input = [] + # import ipdb; ipdb.set_trace() + for visual, context in zip(batched_visuals, batched_contexts): + if origin_image_aspect_ratio is not None and self._config.image_aspect_ratio != origin_image_aspect_ratio: + self._config.image_aspect_ratio = origin_image_aspect_ratio + eval_logger.info(f"Resetting image aspect ratio to {origin_image_aspect_ratio}") + + if visual is None or visual == []: # for text-only tasks. + visual = None + task_type = "text" + placeholder_count = 0 + image_tensor = None + else: + if len(visual) > 1 or "image_aspect_ratio" not in self._config.__dict__: # for multi image case, we treat per image aspect ratio as "pad" by default. + self._config.image_aspect_ratio = getattr(gen_kwargs, "image_aspect_ratio", "pad") + eval_logger.info(f"In Multi-Image setting, image aspect ratio: {self._config.image_aspect_ratio}") + + if "task_type" in metadata and metadata["task_type"] == "video" and "sample_frames" in metadata: # overwrite logic for video task with multiple static image frames + assert type(visual) == list, "sample_frames must be specified for video task" + sample_indices = np.linspace(0, len(visual) - 1, metadata["sample_frames"], dtype=int) + visual = [visual[i] for i in sample_indices] + assert len(visual) == metadata["sample_frames"] + + image_tensor = process_images(visual, self._image_processor, self._config) + if type(image_tensor) is list: + image_tensor = [_image.to(dtype=torch.float16, device=self.device) for _image in image_tensor] + else: + image_tensor = image_tensor.to(dtype=torch.float16, device=self.device) + image_tensor = [image_tensor] + task_type = "video" + placeholder_count = 1 + + elif type(visual[0]) == PIL.Image.Image: # For image, multi-image tasks + image_tensor = process_images(visual, self._image_processor, self._config) + speech = torch.zeros(3000, 128) + speech_lengths = torch.LongTensor([3000]) + if type(image_tensor) is list: + image_tensor = [_image.to(dtype=torch.float16, device=self.device) for _image in image_tensor] + else: + image_tensor = image_tensor.to(dtype=torch.float16, device=self.device) + + task_type = "image" + placeholder_count = len(visual) if isinstance(visual, list) else 1 + + elif type(visual[0]) == str: # For video task + image_tensor = [] + try: + if self.video_decode_backend == "decord": + if "egoplan" in visual[0]: + task_name = "egoplan" + else: + task_name = None + frames, speech, speech_lengths = self.load_video(video_path=visual[0], max_frames_num=self.max_frames_num, task_name=task_name) + else: + raise NotImplementedError("Only decord backend is supported for video task") + processed_frames = self._image_processor.preprocess(frames, return_tensors="pt")["pixel_values"].half().cuda() + processed_frames = processed_frames.half() + image_tensor.append(processed_frames) + image_sizes = [frames[0].size] + except Exception as e: + eval_logger.error(f"Error {e} in loading video") + image_tensor = None + + task_type = "video" + placeholder_count = len(frames) if self.token_strategy == "multiple" else 1 + if DEFAULT_IMAGE_TOKEN not in context: + question = DEFAULT_IMAGE_TOKEN + "\n" + context + else: + question = context + speech = torch.stack([speech]).to(self.device).half() + # This is much safer for llama3, as we now have some object type in it + if "llama_3" in self.conv_template: + conv = copy.deepcopy(conv_templates[self.conv_template]) + else: + conv = conv_templates[self.conv_template].copy() + + if utils.is_json(question): # conversational question input + question = json.loads(question) + for idx, item in enumerate(question): + role = conv.roles[idx % 2] + message = item["value"] + conv.append_message(role, message) + + assert len(conv.messages) % 2 == 1 + conv.append_message(conv.roles[1], None) + prompt_question = conv.get_prompt() + question_input.append(prompt_question) + else: # only simple string for question + conv.append_message(conv.roles[0], question) + conv.append_message(conv.roles[1], None) + prompt_question = conv.get_prompt() + question_input.append(prompt_question) + + # preconfigure gen_kwargs with defaults + if "max_new_tokens" not in gen_kwargs: + gen_kwargs["max_new_tokens"] = 1024 + if "temperature" not in gen_kwargs: + gen_kwargs["temperature"] = 0 + if "do_sample" not in gen_kwargs: + gen_kwargs["do_sample"] = False + if "top_p" not in gen_kwargs: + gen_kwargs["top_p"] = None + if "num_beams" not in gen_kwargs: + gen_kwargs["num_beams"] = 1 + + parts = self.split_text(prompt_question, ["", ""]) + input_ids = [] + for part in parts: + if "" == part: + input_ids += [IMAGE_TOKEN_INDEX] + elif "" == part: + input_ids += [SPEECH_TOKEN_INDEX] + else: + input_ids += self.tokenizer(part).input_ids + + input_ids = torch.tensor(input_ids, dtype=torch.long).unsqueeze(0).to(self.device) + input_ids_list = [input_ids] + pad_token_ids = self.tokenizer.pad_token_id if self.tokenizer.pad_token_id is not None else self.tokenizer.eos_token_id + input_ids = self.pad_sequence(input_ids_list, batch_first=True, padding_value=pad_token_ids).to(self.device) + attention_masks = input_ids.ne(pad_token_ids).to(self.device) + input_ids = torch.tensor(input_ids, dtype=torch.long).squeeze(0).to(self.device) + if task_type == "image": + gen_kwargs["image_sizes"] = [batched_visuals[0][idx].size for idx in range(len(batched_visuals[0]))] + elif task_type == "video": + gen_kwargs["modalities"] = ["video"] + self._config.mm_spatial_pool_stride = self.mm_spatial_pool_stride + self._config.mm_spatial_pool_mode = self.mm_spatial_pool_mode + gen_kwargs["eos_token_id"] = self.tokenizer.eos_token_id + + # These steps are not in LLaVA's original code, but are necessary for generation to work + # TODO: attention to this major generation step... + if "image_aspect_ratio" in gen_kwargs.keys(): + gen_kwargs.pop("image_aspect_ratio") + try: + with torch.inference_mode(): + cont = self.model.generate(input_ids, images=image_tensor, speech=speech, speech_lengths=speech_lengths, **gen_kwargs) + + text_outputs = self.tokenizer.batch_decode(cont, skip_special_tokens=True) + except Exception as e: + raise e + + text_outputs = [response.strip() for response in text_outputs] + res.extend(text_outputs) + self.cache_hook.add_partial("generate_until", (context, gen_kwargs), text_outputs) + pbar.update(1) + # reorder this group of results back to original unsorted form + res = re_ords.get_original(res) + + pbar.close() + return res + + def generate_until_multi_round(self, requests: List[Instance]) -> List[str]: + raise NotImplementedError("generate_until_multi_round is not implemented for EgoGPT") diff --git a/lmms_eval/models/gpt4v.py b/lmms_eval/models/gpt4v.py index af313a573..662b73882 100755 --- a/lmms_eval/models/gpt4v.py +++ b/lmms_eval/models/gpt4v.py @@ -4,11 +4,12 @@ import time from copy import deepcopy from io import BytesIO -from typing import List, Tuple +from typing import List, Tuple, Union import numpy as np import requests as url_requests from accelerate import Accelerator, DistributedType +from openai import AzureOpenAI, OpenAI from tqdm import tqdm from lmms_eval.api.instance import Instance @@ -20,26 +21,19 @@ except ImportError: pass +from loguru import logger as eval_logger from PIL import Image API_TYPE = os.getenv("API_TYPE", "openai") -NUM_SECONDS_TO_SLEEP = 30 -from loguru import logger as eval_logger - +NUM_SECONDS_TO_SLEEP = 10 if API_TYPE == "openai": API_URL = os.getenv("OPENAI_API_URL", "https://api.openai.com/v1/chat/completions") API_KEY = os.getenv("OPENAI_API_KEY", "YOUR_API_KEY") - headers = { - "Authorization": f"Bearer {API_KEY}", - "Content-Type": "application/json", - } + elif API_TYPE == "azure": API_URL = os.getenv("AZURE_ENDPOINT", "https://api.cognitive.microsoft.com/sts/v1.0/issueToken") API_KEY = os.getenv("AZURE_API_KEY", "YOUR_API_KEY") - headers = { - "api-key": API_KEY, - "Content-Type": "application/json", - } + API_VERSION = os.getenv("AZURE_API_VERSION", "2023-07-01-preview") @register_model("gpt4v") @@ -52,6 +46,7 @@ def __init__( timeout: int = 120, continual_mode: bool = False, response_persistent_folder: str = None, + max_size_in_mb: int = 20, **kwargs, ) -> None: super().__init__() @@ -80,6 +75,11 @@ def __init__( self.response_cache = {} self.cache_mode = "start" + if API_TYPE == "openai": + self.client = OpenAI(api_key=API_KEY) + elif API_TYPE == "azure": + self.client = AzureOpenAI(api_key=API_KEY, azure_endpoint=API_URL, api_version=API_VERSION) + accelerator = Accelerator() # assert self.batch_size_per_gpu == 1, "Llava currently does not support batched generation. See https://github.com/haotian-liu/LLaVA/issues/754. HF Llava also has this issue." if accelerator.num_processes > 1: @@ -94,13 +94,30 @@ def __init__( self._rank = self.accelerator.local_process_index self._world_size = self.accelerator.num_processes + self.max_size_in_mb = max_size_in_mb self.device = self.accelerator.device # Function to encode the image - def encode_image(self, image: Image): + def encode_image(self, image: Union[Image.Image, str]): + max_size = self.max_size_in_mb * 1024 * 1024 # 20MB in bytes + if isinstance(image, str): + img = Image.open(image).convert("RGB") + else: + img = image.copy() + output_buffer = BytesIO() - image.save(output_buffer, format="PNG") + img.save(output_buffer, format="PNG") byte_data = output_buffer.getvalue() + + # If image is too large, resize it while maintaining aspect ratio + while len(byte_data) > max_size and img.size[0] > 100 and img.size[1] > 100: + new_size = (int(img.size[0] * 0.75), int(img.size[1] * 0.75)) + img = img.resize(new_size, Image.Resampling.LANCZOS) + + output_buffer = BytesIO() + img.save(output_buffer, format="PNG") + byte_data = output_buffer.getvalue() + base64_str = base64.b64encode(byte_data).decode("utf-8") return base64_str @@ -150,39 +167,30 @@ def generate_until(self, requests) -> List[str]: continue visuals = [doc_to_visual(self.task_dict[task][split][doc_id])] - visuals = self.flatten(visuals) - imgs = [] # multiple images or frames for video - for visual in visuals: - if self.modality == "image": - img = self.encode_image(visual) - imgs.append(img) - elif self.modality == "video": - frames = self.encode_video(visual, self.max_frames_num) - imgs.extend(frames) + if None in visuals: + visuals = [] + imgs = [] + else: + visuals = self.flatten(visuals) + imgs = [] # multiple images or frames for video + for visual in visuals: + if isinstance(visual, str) and (".mp4" in visual or ".avi" in visual or ".mov" in visual or ".flv" in visual or ".wmv" in visual): + frames = self.encode_video(visual, self.max_frames_num) + imgs.extend(frames) + elif isinstance(visual, str) and (".jpg" in visual or ".jpeg" in visual or ".png" in visual or ".gif" in visual or ".bmp" in visual or ".tiff" in visual or ".webp" in visual): + img = self.encode_image(visual) + imgs.append(img) + elif isinstance(visual, Image.Image): + img = self.encode_image(visual) + imgs.append(img) payload = {"messages": []} - if API_TYPE == "openai": - payload["model"] = self.model_version - - response_json = {"role": "user", "content": []} - # When there is no image token in the context, append the image to the text - if self.image_token not in contexts: - payload["messages"].append(deepcopy(response_json)) - payload["messages"][0]["content"].append({"type": "text", "text": contexts}) - for img in imgs: - payload["messages"][0]["content"].append({"type": "image_url", "image_url": {"url": f"data:image/png;base64,{img}"}}) - else: - contexts = contexts.split(self.image_token) - for idx, img in enumerate(imgs): - payload["messages"].append(deepcopy(response_json)) - payload["messages"][idx]["content"].append({"type": "text", "text": contexts[idx]}) - payload["messages"][idx]["content"].append({"type": "image_url", "image_url": {"url": f"data:image/png;base64,{img}"}}) - - # If n image tokens are in the contexts - # contexts will be splitted into n+1 chunks - # Manually add it into the payload - payload["messages"].append(deepcopy(response_json)) - payload["messages"][-1]["content"].append({"type": "text", "text": contexts[-1]}) + payload["model"] = self.model_version + + payload["messages"].append({"role": "user", "content": []}) + payload["messages"][0]["content"].append({"type": "text", "text": contexts}) + for img in imgs: + payload["messages"][0]["content"].append({"type": "image_url", "image_url": {"url": f"data:image/png;base64,{img}"}}) if "max_new_tokens" not in gen_kwargs: gen_kwargs["max_new_tokens"] = 1024 @@ -198,26 +206,24 @@ def generate_until(self, requests) -> List[str]: payload["max_tokens"] = gen_kwargs["max_new_tokens"] payload["temperature"] = gen_kwargs["temperature"] - for attempt in range(5): + MAX_RETRIES = 5 + for attempt in range(MAX_RETRIES): try: - response = url_requests.post(API_URL, headers=headers, json=payload, timeout=self.timeout) - response_data = response.json() - - response_text = response_data["choices"][0]["message"]["content"].strip() + response = self.client.chat.completions.create(**payload) + response_text = response.choices[0].message.content break # If successful, break out of the loop except Exception as e: - try: - error_msg = response.json() - except: - error_msg = "" + error_msg = str(e) + eval_logger.info(f"Attempt {attempt + 1}/{MAX_RETRIES} failed with error: {error_msg}") - eval_logger.info(f"Attempt {attempt + 1} failed with error: {str(e)}.\nReponse: {error_msg}") - if attempt <= 5: - time.sleep(NUM_SECONDS_TO_SLEEP) - else: # If this was the last attempt, log and return empty string - eval_logger.error(f"All 5 attempts failed. Last error message: {str(e)}.\nResponse: {response.json()}") + # On last attempt, log error and set empty response + if attempt == MAX_RETRIES - 1: + eval_logger.error(f"All {MAX_RETRIES} attempts failed. Last error: {error_msg}") response_text = "" + else: + time.sleep(NUM_SECONDS_TO_SLEEP) + res.append(response_text) pbar.update(1) diff --git a/lmms_eval/models/llava_vid.py b/lmms_eval/models/llava_vid.py index 2e2028ea7..5eddc5a06 100755 --- a/lmms_eval/models/llava_vid.py +++ b/lmms_eval/models/llava_vid.py @@ -1,3 +1,4 @@ +import glob import math import os from datetime import timedelta @@ -416,6 +417,8 @@ def generate_until(self, requests) -> List[str]: visuals = doc_to_visual(self.task_dict[task][split][doc_id]) # visuals = [visuals] # visuals = self.flatten(visuals) + if os.path.isdir(visuals[0]): + visuals = glob.glob(visuals[0] + "/*") videos = [] try: # for visual in visuals: @@ -440,7 +443,8 @@ def generate_until(self, requests) -> List[str]: frame_idx = sampled_indices.tolist() frame_time = [i / fps for i in frame_idx] frame_time = ",".join([f"{i:.2f}s" for i in frame_time]) - video = [visuals[i] for i in frame_idx] + # video = [visuals[i] for i in frame_idx] + video = np.stack([np.array(Image.open(visuals[i])) for i in frame_idx], axis=0) video = self._image_processor.preprocess(video, return_tensors="pt")["pixel_values"].cuda() if self.torch_dtype == "bfloat16": diff --git a/lmms_eval/models/ola.py b/lmms_eval/models/ola.py new file mode 100644 index 000000000..2ef488be7 --- /dev/null +++ b/lmms_eval/models/ola.py @@ -0,0 +1,707 @@ +import os + +os.environ["LOWRES_RESIZE"] = "384x32" +os.environ["HIGHRES_BASE"] = "0x32" +os.environ["VIDEO_RESIZE"] = "0x64" +os.environ["VIDEO_MAXRES"] = "480" +os.environ["VIDEO_MINRES"] = "288" +os.environ["MAXRES"] = "1536" +os.environ["MINRES"] = "0" +os.environ["FORCE_NO_DOWNSAMPLE"] = "1" +os.environ["LOAD_VISION_EARLY"] = "1" +os.environ["PAD2STRIDE"] = "1" +os.environ["USE_SPEECH"] = "1" +import copy +import logging +from datetime import timedelta +from pathlib import Path +from typing import List, Optional, Tuple, Union + +import librosa +import numpy as np +import PIL +import soundfile as sf +import torch +from accelerate import Accelerator, DistributedType, InitProcessGroupKwargs +from accelerate.state import AcceleratorState +from decord import VideoReader, cpu +from tqdm import tqdm +from transformers import AutoConfig + +from lmms_eval import utils +from lmms_eval.api.instance import Instance +from lmms_eval.api.model import lmms +from lmms_eval.api.registry import register_model +from lmms_eval.models.model_utils.audio_processing import downsample_audio +from lmms_eval.models.model_utils.load_video import read_video_pyav + +eval_logger = logging.getLogger("lmms-eval") + +import sys + +wd = Path(__file__).parent.parent.parent.parent.resolve() +sys.path.append(os.path.join(str(wd), "Ola")) + +import whisper +from ola.constants import ( + DEFAULT_IM_END_TOKEN, + DEFAULT_IM_START_TOKEN, + DEFAULT_IMAGE_TOKEN, + DEFAULT_SPEECH_TOKEN, + IMAGE_TOKEN_INDEX, + SPEECH_TOKEN_INDEX, +) +from ola.conversation import SeparatorStyle, conv_templates +from ola.datasets.preprocess import ( + tokenizer_image_token, + tokenizer_speech_image_token, + tokenizer_speech_token, +) +from ola.mm_utils import ( + KeywordsStoppingCriteria, + get_model_name_from_path, + process_anyres_highres_image, + process_anyres_video, +) +from ola.model.builder import load_pretrained_model + +try: + from ola.model.language_model.ola_qwen import OlaConfigQwen + + AutoConfig.register("ola_qwen", OlaConfigQwen) +except: + eval_logger.debug("") +from moviepy.video.io.VideoFileClip import VideoFileClip + +if "USE_SPEECH" in os.environ: + USE_SPEECH = os.environ["USE_SPEECH"] + print("USE_SPEECH is set") +else: + USE_SPEECH = None + + +@register_model("ola") +class Ola(lmms): + """ + How to run lmms-eval with Ola model: + + 1. Install Ola: + https://github.com/Ola-Omni/Ola?tab=readme-ov-file#installation + + 2. Download the pretrained weight from https://huggingface.co/THUdyh/Ola-7b + or skip this step to use the online weights directly + + 3.Download audio encoder from https://huggingface.co/THUdyh/Ola_speech_encoders/tree/main + and put the weights large-v3.pt and BEATs_iter3_plus_AS2M_finetuned_on_AS2M_cpt2.pt + under llms-eval repository (ensure your current directory can see large-v3.pt and BEATs_iter3_plus_AS2M_finetuned_on_AS2M_cpt2.pt) + + The path for the project should be like this: + Folder/contains/lmms-eval/and/Ola + ├── lmms-eval (current directory) + │ ├── lmms_eval/ + │ ├── ... + │ ├── large-v3.pt + │ ├── BEATs_iter3_plus_AS2M_finetuned_on_AS2M_cpt2.pt + ├── Ola + │ ├── ... + + 4. Run the the command to start evaluate the modeL. For example: + ```bash + python3 -m accelerate.commands.launch \ + --num_processes=8 \ + -m lmms_eval \ + --model ola\ + --tasks mme \ + --batch_size 1 \ + --log_samples \ + --log_samples_suffix mme_ola \ + --output_path ./logs/ + ``` + """ + + def __init__( + self, + pretrained: str = "THUdyh/Ola-7b", + truncation: Optional[bool] = True, + device: Optional[str] = "cuda:0", + batch_size: Optional[Union[int, str]] = 1, + attn_implementation=( + "sdpa" if torch.__version__ >= "2.1.2" else "eager" + ), # inference implementation for attention, can be "sdpa", "eager", "flash_attention_2". Seems FA2 is not effective during inference: https://discuss.huggingface.co/t/flash-attention-has-no-effect-on-inference/73453/5 + device_map="", + conv_template="qwen_1_5", + use_cache=True, + truncate_context=False, + max_frames_num: int = 64, + mm_resampler_type: str = "spatial_pool", + overwrite: bool = True, + video_decode_backend: str = "decord", + **kwargs, + ) -> None: + super().__init__() + assert kwargs == {}, f"Unexpected kwargs: {kwargs}" + + accelerator_kwargs = InitProcessGroupKwargs(timeout=timedelta(weeks=52)) + accelerator = Accelerator(kwargs_handlers=[accelerator_kwargs]) + if accelerator.num_processes > 1: + self._device = torch.device(f"cuda:{accelerator.local_process_index}") + self.device_map = f"cuda:{accelerator.local_process_index}" + elif accelerator.num_processes == 1 and device_map == "auto": + self._device = torch.device(device) + self.device_map = device_map + else: + self._device = torch.device(f"cuda:{accelerator.local_process_index}") + self.device_map = f"cuda:{accelerator.local_process_index}" + + self.pretrained = pretrained + self.model_name = get_model_name_from_path(pretrained) + self.video_decode_backend = video_decode_backend + # self._config = AutoConfig.from_pretrained(self.pretrained) + self.overwrite = overwrite + self.mm_resampler_type = mm_resampler_type + self.max_frames_num = int(max_frames_num) + if self.overwrite == True: + overwrite_config = {} + overwrite_config["patchify_video_feature"] = False + overwrite_config["attn_implementation"] = attn_implementation + + cfg_pretrained = AutoConfig.from_pretrained(self.pretrained) + + self._tokenizer, self._model, self._image_processor, self._max_length = load_pretrained_model(pretrained, None, device=self.device_map) + else: + self._tokenizer, self._model, self._image_processor, self._max_length = load_pretrained_model( + pretrained, + None, + device_map=self.device_map, + ) + + self._config = self._model.config + self.model.to("cuda").eval().bfloat16() + self.model.tie_weights() + self.truncation = truncation + self.batch_size_per_gpu = int(batch_size) + self.conv_template = conv_template + self.use_cache = use_cache + self.truncate_context = truncate_context + if accelerator.num_processes > 1: + assert accelerator.distributed_type in [DistributedType.FSDP, DistributedType.MULTI_GPU, DistributedType.DEEPSPEED], "Unsupported distributed type provided. Only DDP and FSDP are supported." + # If you want to use DistributedType.DEEPSPEED, you have to run accelerate config before using the model + # Also, you have to select zero stage 0 (equivalent to DDP) in order to make the prepare model works + # I tried to set different parameters in the kwargs to let default zero 2 stage works, but it didn't work. + if accelerator.distributed_type == DistributedType.DEEPSPEED: + kwargs = { + "train_micro_batch_size_per_gpu": self.batch_size_per_gpu, + "train_batch_size": self.batch_size_per_gpu * accelerator.num_processes, + } + AcceleratorState().deepspeed_plugin.deepspeed_config_process(must_match=True, **kwargs) + eval_logger.info("Detected that you are using DistributedType.DEEPSPEED. Make sure you run `accelerate config` and set zero stage to 0") + if accelerator.distributed_type == DistributedType.FSDP or accelerator.distributed_type == DistributedType.DEEPSPEED: + self._model = accelerator.prepare(self.model) + else: + self._model = accelerator.prepare_model(self.model, evaluation_mode=True) + self.accelerator = accelerator + if self.accelerator.is_local_main_process: + eval_logger.info(f"Using {accelerator.num_processes} devices with data parallelism") + self._rank = self.accelerator.local_process_index + self._world_size = self.accelerator.num_processes + elif accelerator.num_processes == 1 and device_map == "auto": + eval_logger.info(f"Using {accelerator.num_processes} devices with tensor parallelism") + self._rank = 0 + self._word_size = 1 + else: + eval_logger.info(f"Using single device: {self._device}") + self.model.to(self._device) + self._rank = 0 + self._world_size = 1 + self.accelerator = accelerator + + @property + def config(self): + # return the associated transformers.AutoConfig for the given pretrained model. + return self._config + + @property + def tokenizer(self): + return self._tokenizer + + @property + def model(self): + # returns the model, unwrapping it if using Accelerate + if hasattr(self, "accelerator"): + return self.accelerator.unwrap_model(self._model) + else: + return self._model + + @property + def eot_token_id(self): + # we use EOT because end of *text* is more accurate for what we're doing than end of *sentence* + return self.tokenizer.eos_token_id + + @property + def max_length(self): + return self._max_length + + def pad_sequence(self, input_ids, batch_first, padding_value): + if self.tokenizer.padding_side == "left": + input_ids = [torch.flip(_input_ids, [0]) for _input_ids in input_ids] + input_ids = torch.nn.utils.rnn.pad_sequence(input_ids, batch_first=batch_first, padding_value=padding_value) + if self.tokenizer.padding_side == "left": + input_ids = torch.flip(input_ids, [1]) + return input_ids + + @property + def batch_size(self): + return self.batch_size_per_gpu + + @property + def device(self): + return self._device + + @property + def rank(self): + return self._rank + + @property + def world_size(self): + return self._world_size + + def tok_encode(self, string: str, left_truncate_len=None, add_special_tokens=None) -> List[int]: + """ """ + add_special_tokens = False if add_special_tokens is None else add_special_tokens + encoding = self.tokenizer.encode(string, add_special_tokens=add_special_tokens) + # left-truncate the encoded context to be at most `left_truncate_len` tokens long + if left_truncate_len: + encoding = encoding[-left_truncate_len:] + return encoding + + def load_video(self, video_path, max_frames_num): + vr = VideoReader(video_path, ctx=cpu(0)) + total_frame_num = len(vr) + fps = round(vr.get_avg_fps()) + + uniform_sampled_frames = np.linspace(0, total_frame_num - 1, max_frames_num, dtype=int) + frame_idx = uniform_sampled_frames.tolist() + + spare_frames = vr.get_batch(frame_idx).asnumpy() + video = [PIL.Image.fromarray(frame) for frame in spare_frames] + return video, frame_idx + + def tok_decode(self, tokens): + return self.tokenizer.decode(tokens) + + def loglikelihood(self, requests: List[Instance]) -> List[Tuple[float, bool]]: + res = [] + pbar = tqdm(total=len(requests), disable=(self.rank != 0), desc="Model Responding") + + for contexts, doc_to_target, doc_to_visual, doc_id, task, split in [reg.args for reg in requests]: + # encode, pad, and truncate contexts for this batch + if type(doc_to_target) == str: + continuation = doc_to_target + else: + continuation = doc_to_target(self.task_dict[task][split][doc_id]) + visuals = [doc_to_visual(self.task_dict[task][split][doc_id])] + visuals = self.flatten(visuals) + videos = [] + for visual in visuals: + video = self.load_video(visual, self.max_frames_num) + video = self._image_processor.preprocess(video, return_tensors="pt")["pixel_values"].bfloat16().to(self.device) + videos.append(video) + + qs = contexts + if self.model.config.mm_use_im_start_end: + qs = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN + "\n" + qs + else: + qs = DEFAULT_IMAGE_TOKEN + "\n" + qs + + conv = conv_templates[self.conv_template].copy() + conv.append_message(conv.roles[0], qs) + conv.append_message(conv.roles[1], None) + prompt = conv.get_prompt() + + contxt_id = tokenizer_image_token(prompt, self.tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt").unsqueeze(0).to(self.device) + + conv = conv_templates[self.conv_template].copy() + conv.append_message(conv.roles[0], qs) + conv.append_message(conv.roles[1], continuation) + prompt = conv.get_prompt() + + input_ids = tokenizer_image_token(prompt, self.tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt").unsqueeze(0).to(self.device) + + labels = input_ids.clone() + # Context part no need to calculate for loss + labels[0, : contxt_id.shape[1]] = -100 + + with torch.inference_mode(): + outputs = self.model(input_ids=input_ids, labels=labels, images=videos, modalities="video") + + loss = outputs["loss"] + # loss = torch.exp(loss) + logits = outputs["logits"] + greedy_tokens = logits.argmax(dim=-1) + cont_toks = input_ids[:, contxt_id.shape[1] :] # [1, seq] + greedy_tokens = greedy_tokens[:, contxt_id.shape[1] : input_ids.shape[1]] # [1, seq] + max_equal = (greedy_tokens == cont_toks).all() + res.append((float(loss.item()), bool(max_equal))) + pbar.update(1) + pbar.close() + return res + + def flatten(self, input): + new_list = [] + for i in input: + for j in i: + new_list.append(j) + return new_list + + def extract_audio(self, videos_file_path): + my_clip = VideoFileClip(videos_file_path) + return my_clip.audio + + def load_audio(self, audio_file_name): + CHUNK_LIM = 480000 + import librosa + + audio, samplerate = librosa.load(audio_file_name, sr=16000) + audio = audio.astype(np.float32) + + if len(audio.shape) > 1: + audio = audio[:, 0] + + speechs = [] + speech_wavs = [] + if len(audio) <= CHUNK_LIM: + audio = whisper.pad_or_trim(audio) + speechs.append(audio) + speech_wavs.append(torch.from_numpy(audio).unsqueeze(0)) + else: + for i in range(0, len(audio), CHUNK_LIM): + chunk = audio[i : i + CHUNK_LIM] + if len(chunk) < CHUNK_LIM: + chunk = whisper.pad_or_trim(chunk) + speechs.append(chunk) + speech_wavs.append(torch.from_numpy(chunk).unsqueeze(0)) + mels = [] + for chunk in speechs: + chunk = whisper.log_mel_spectrogram(chunk, n_mels=128).permute(1, 0).unsqueeze(0) + mels.append(chunk) + + mels = torch.cat(mels, dim=0) + speech_wavs = torch.cat(speech_wavs, dim=0) + if mels.shape[0] > 20: + mels = mels[:20] + speech_wavs = speech_wavs[:20] + + speech_length = torch.LongTensor([mels.shape[1]] * mels.shape[0]) + speech_chunks = torch.LongTensor([mels.shape[0]]) + + return mels, speech_length, speech_chunks, speech_wavs + + def process_audio(self, audio_array, sampling_rate): + """ + Process audio array to format of Ola model + """ + audio = audio_array.astype(np.float32) + if len(audio.shape) > 1: + audio = audio[:, 0] + target_sr = 16000 + CHUNK_LIM = 480000 + import librosa + + if sampling_rate != target_sr: + speech_wav = librosa.resample(audio_array, orig_sr=sampling_rate, target_sr=target_sr).astype(np.float32) + speechs = [] + speech_wavs = [] + + if len(speech_wav) <= CHUNK_LIM: + speech = whisper.pad_or_trim(speech_wav) + speech_wav = whisper.pad_or_trim(speech_wav) + speechs.append(speech) + speech_wavs.append(torch.from_numpy(speech_wav).unsqueeze(0)) + else: + for i in range(0, len(speech_wav), CHUNK_LIM): + chunk = speech_wav[i : i + CHUNK_LIM] + if len(chunk) < CHUNK_LIM: + chunk = whisper.pad_or_trim(chunk) + speechs.append(chunk) + speech_wavs.append(torch.from_numpy(chunk).unsqueeze(0)) + mels = [] + for chunk in speechs: + chunk = whisper.log_mel_spectrogram(chunk, n_mels=128).permute(1, 0).unsqueeze(0) + mels.append(chunk) + + mels = torch.cat(mels, dim=0) + speech_wavs = torch.cat(speech_wavs, dim=0) + if mels.shape[0] > 25: + mels = mels[:25] + speech_wavs = speech_wavs[:25] + + speech_length = torch.LongTensor([mels.shape[1]] * mels.shape[0]) + speech_chunks = torch.LongTensor([mels.shape[0]]) + return mels, speech_length, speech_chunks, speech_wavs + + def generate_until(self, requests) -> List[str]: + MODALITY = None + res = [] + + def _collate(x): + # the negative sign on len(toks) sorts descending - this has a few advantages: + # - time estimates will always be over not underestimates, which is more useful for planning + # - to know the size of a batch when going through the list, you know the first one is always the batch + # padded context length. this is useful to simplify the batching logic and more importantly to make + # automatic adaptive batches much much easier to implement + # - any OOMs will happen right away rather than near the end + toks = self.tok_encode(x[0]) + return -len(toks), x[0] + + # we group requests by their generation_kwargs, + # so that we don't try to execute e.g. greedy sampling and temp=0.8 sampling + # in the same batch. + re_ords = utils.Collator([reg.args for reg in requests], _collate, grouping=True) + chunks = re_ords.get_batched(n=self.batch_size, batch_fn=None) + num_iters = len(requests) // self.batch_size if len(requests) % self.batch_size == 0 else len(requests) // self.batch_size + 1 + pbar = tqdm(total=num_iters, disable=(self.rank != 0), desc="Model Responding") + for chunk in chunks: + contexts, all_gen_kwargs, doc_to_visual, doc_id, task, split = zip(*chunk) + task = task[0] + split = split[0] + context = contexts[0] + visuals = [doc_to_visual[0](self.task_dict[task][split][ids]) for ids in doc_id] + visuals = self.flatten(visuals) # Len = 1. just an audio tho + + speechs, speech_lengths, speech_wavs, speech_chunks = [], [], [], [] + images, images_highres = [], [] # For dummy image passed in audio modality + image_sizes = [] + image_tensor, image_highres_tensor = [], [] # For image + video_processed = [] # For video only + for visual in visuals: + if isinstance(visual, str): # For Video + if MODALITY is None: + MODALITY = "VIDEO" + # Process audio of video + try: + video, frame_idx = self.load_video(visual, self.max_frames_num) + except Exception as e: + eval_logger.info(f"{e}") + eval_logger.info(f"Video {visuals} can not load, check the source") + continue + audio = self.extract_audio(visual) + audio.write_audiofile("./video_audio.wav") + video_audio_path = "./video_audio.wav" + speech, speech_length, speech_chunk, speech_wav = self.load_audio(video_audio_path) + speechs.append(speech.bfloat16().to("cuda")) + speech_lengths.append(speech_length.to("cuda")) + speech_chunks.append(speech_chunk.to("cuda")) + speech_wavs.append(speech_wav.to("cuda")) + os.remove(video_audio_path) + + # Process images of video + for idx, frame in enumerate(video): + self._image_processor.do_resize = False + self._image_processor.do_center_crop = False + frame = process_anyres_video(frame, self._image_processor) + + if frame_idx is not None and idx in frame_idx: + video_processed.append(frame.unsqueeze(0)) + elif frame_idx is None: + video_processed.append(frame.unsqueeze(0)) + + if frame_idx is None: + frame_idx = np.arange(0, len(video_processed), dtype=int).tolist() + + video_processed = torch.cat(video_processed, dim=0).bfloat16().to("cuda") + video_processed = (video_processed, video_processed) + + video_data = (video_processed, (384, 384), "video") + + elif isinstance(visual, PIL.Image.Image): # For Image + if MODALITY is None: + MODALITY = "IMAGE" + self._image_processor.do_resize = False + self._image_processor.do_center_crop = False + image_sizes.append(visual.size) + image_tensor_, image_highres_tensor_ = process_anyres_highres_image(visual, self._image_processor) + image_tensor.append(image_tensor_) + image_highres_tensor.append(image_highres_tensor_) + if all(x.shape == image_tensor[0].shape for x in image_tensor): + image_tensor = torch.stack(image_tensor, dim=0) + if all(x.shape == image_highres_tensor[0].shape for x in image_highres_tensor): + image_highres_tensor = torch.stack(image_highres_tensor, dim=0) + if type(image_tensor) is list: + image_tensor = [_image.bfloat16().to("cuda") for _image in image_tensor] + else: + image_tensor = image_tensor.bfloat16().to("cuda") + if type(image_highres_tensor) is list: + image_highres_tensor = [_image.bfloat16().to("cuda") for _image in image_highres_tensor] + else: + image_highres_tensor = image_highres_tensor.bfloat16().to("cuda") + + # Processing dummy audio, as required by model + speechs.append(torch.zeros(1, 3000, 128).bfloat16().to("cuda")) + speech_lengths.append(torch.LongTensor([3000]).to("cuda")) + speech_wavs.append(torch.zeros([1, 480000]).to("cuda")) + speech_chunks.append(torch.LongTensor([1]).to("cuda")) + + elif isinstance(visual, dict) and "array" in visual: # For Audio + if MODALITY is None: + MODALITY = "AUDIO" + mels, speech_length, speech_chunk, speech_wav = self.process_audio(visual["array"], visual["sampling_rate"]) + speechs.append(mels.bfloat16().to("cuda")) + speech_lengths.append(speech_length.to("cuda")) + speech_chunks.append(speech_chunk.to("cuda")) + speech_wavs.append(speech_wav.to("cuda")) + + # Processing dummy images, as required by model + images.append(torch.zeros(1, 3, 224, 224).to(dtype=torch.bfloat16, device="cuda", non_blocking=True)) + images_highres.append(torch.zeros(1, 3, 224, 224).to(dtype=torch.bfloat16, device="cuda", non_blocking=True)) + image_sizes.append((224, 224)) + + if not video_processed and MODALITY == "VIDEO": + # If video is not processed, skip the iteration + pbar.update(1) + continue + + # we assume all gen kwargs in the batch are the same + # this is safe to assume because the `grouper` object ensures it. + gen_kwargs = all_gen_kwargs[0] + + # Set default values for until and max_new_tokens + until = [self.tokenizer.decode(self.eot_token_id)] + + # Update values from gen_kwargs if present + if "until" in gen_kwargs: + until = gen_kwargs.pop("until") + if isinstance(until, str): + until = [until] + elif not isinstance(until, list): + raise ValueError(f"Expected `gen_kwargs['until']` to be of type Union[str,list] but got {type(until)}") + assert self.batch_size_per_gpu == 1, "Do not support batch_size_per_gpu > 1 for now" + # Okay be I am assuming bs always == 1 + qs = list(contexts)[0] + if self.model.config.mm_use_im_start_end: + if MODALITY == "AUDIO": + qs = DEFAULT_IM_START_TOKEN + DEFAULT_SPEECH_TOKEN + DEFAULT_IM_END_TOKEN + "\n" + qs + elif MODALITY == "IMAGE": + qs = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN + "\n" + qs + elif MODALITY == "VIDEO": + qs = DEFAULT_IM_START_TOKEN + DEFAULT_SPEECH_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN + "\n" + qs + else: + if MODALITY == "AUDIO": + qs = DEFAULT_SPEECH_TOKEN + "\n" + qs + elif MODALITY == "IMAGE": + qs = DEFAULT_IMAGE_TOKEN + "\n" + qs + elif MODALITY == "VIDEO": + qs = DEFAULT_SPEECH_TOKEN + DEFAULT_IMAGE_TOKEN + "\n" + qs + + conv = conv_templates[self.conv_template].copy() + conv.append_message(conv.roles[0], qs) + conv.append_message(conv.roles[1], None) + prompt = conv.get_prompt() + + if self.accelerator.is_main_process and doc_id[0] % 100 == 0: + eval_logger.debug(f"Prompt for doc ID {doc_id[0]}:\n\n{prompt}\n") + + if MODALITY == "AUDIO": + input_ids = tokenizer_speech_token(prompt, self.tokenizer, SPEECH_TOKEN_INDEX, return_tensors="pt").unsqueeze(0).to(self._device) + elif MODALITY == "IMAGE": + input_ids = tokenizer_image_token(prompt, self.tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt").unsqueeze(0).to(self._device) + elif MODALITY == "VIDEO": + input_ids = tokenizer_speech_image_token(prompt, self.tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt").unsqueeze(0).to("cuda") + pad_token_ids = 151643 + attention_masks = input_ids.ne(pad_token_ids).long().to(self.device) + + stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2 + keywords = [stop_str] + stopping_criteria = KeywordsStoppingCriteria(keywords, self.tokenizer, input_ids) + + if "max_new_tokens" not in gen_kwargs: + gen_kwargs["max_new_tokens"] = 256 + if "temperature" not in gen_kwargs: + gen_kwargs["temperature"] = 0 + if "top_p" not in gen_kwargs: + gen_kwargs["top_p"] = None + if "num_beams" not in gen_kwargs: + gen_kwargs["num_beams"] = 1 + + try: + with torch.inference_mode(): + if MODALITY == "AUDIO": + output_ids = self.model.generate( + input_ids, + images=images, + images_highres=images_highres, + image_sizes=image_sizes, + modalities=["text"], + speech=speechs, + speech_lengths=speech_lengths, + speech_chunks=speech_chunks, + speech_wav=speech_wavs, + attention_mask=attention_masks, + use_cache=True, + stopping_criteria=[stopping_criteria], + do_sample=True if gen_kwargs["temperature"] > 0 else False, + temperature=gen_kwargs["temperature"], + top_p=gen_kwargs["top_p"], + num_beams=gen_kwargs["num_beams"], + max_new_tokens=gen_kwargs["max_new_tokens"], + ) + elif MODALITY == "IMAGE": + output_ids = self.model.generate( + inputs=input_ids, + images=image_tensor, + images_highres=image_highres_tensor, + image_sizes=image_sizes, + modalities=["image"], + speech=speechs, + speech_lengths=speech_lengths, + speech_chunks=speech_chunks, + speech_wav=speech_wavs, + attention_mask=attention_masks, + use_cache=True, + stopping_criteria=[stopping_criteria], + do_sample=True if gen_kwargs["temperature"] > 0 else False, + temperature=gen_kwargs["temperature"], + top_p=gen_kwargs["top_p"], + num_beams=gen_kwargs["num_beams"], + max_new_tokens=gen_kwargs["max_new_tokens"], + ) + elif MODALITY == "VIDEO": + output_ids = self.model.generate( + inputs=input_ids, + images=video_data[0][0], + images_highres=video_data[0][1], + modalities=video_data[2], + speech=speechs, + speech_lengths=speech_lengths, + speech_chunks=speech_chunks, + speech_wav=speech_wavs, + attention_mask=attention_masks, + use_cache=True, + stopping_criteria=[stopping_criteria], + do_sample=True if gen_kwargs["temperature"] > 0 else False, + temperature=gen_kwargs["temperature"], + top_p=gen_kwargs["top_p"], + num_beams=gen_kwargs["num_beams"], + max_new_tokens=gen_kwargs["max_new_tokens"], + ) + except Exception as e: + eval_logger.error(f"Error {e} in generating") + outputs = "" + res.append(outputs) + pbar.update(1) + self.cache_hook.add_partial("generate_until", (context, gen_kwargs), outputs) + continue + outputs = self.tokenizer.batch_decode(output_ids, skip_special_tokens=True)[0].strip() + if self.accelerator.is_main_process and doc_id[0] % 100 == 0: + eval_logger.debug(f"Generated text for doc ID {doc_id[0]}:\n\n{outputs}\n") + + res.append(outputs) + self.cache_hook.add_partial("generate_until", (context, gen_kwargs), outputs) + pbar.update(1) + # reorder this group of results back to original unsorted form + res = re_ords.get_original(res) + + pbar.close() + return res + + def generate_until_multi_round(self, requests) -> List[str]: + raise NotImplementedError("TODO: Implement multi-round generation") diff --git a/lmms_eval/models/openai_compatible.py b/lmms_eval/models/openai_compatible.py new file mode 100644 index 000000000..7be1cd91e --- /dev/null +++ b/lmms_eval/models/openai_compatible.py @@ -0,0 +1,225 @@ +import base64 +import json +import os +import time +from io import BytesIO +from typing import List, Tuple, Union + +import numpy as np +import requests as url_requests +from accelerate import Accelerator, DistributedType +from tqdm import tqdm + +from lmms_eval.api.instance import Instance +from lmms_eval.api.model import lmms +from lmms_eval.api.registry import register_model + +try: + from decord import VideoReader, cpu +except ImportError: + pass + +from dotenv import find_dotenv, load_dotenv +from loguru import logger as eval_logger +from openai import OpenAI +from PIL import Image + +load_dotenv(verbose=True) + + +@register_model("openai_compatible") +class OpenAICompatible(lmms): + def __init__( + self, + model_version: str = "grok-2-latest", + timeout: int = 120, + max_retries: int = 5, + max_size_in_mb: int = 20, + continual_mode: bool = False, + response_persistent_folder: str = None, + **kwargs, + ) -> None: + super().__init__() + self.model_version = model_version + self.timeout = timeout + self.max_retries = max_retries + self.max_size_in_mb = max_size_in_mb # some models have a limit on the size of the image + self.continual_mode = continual_mode + if self.continual_mode: + if response_persistent_folder is None: + raise ValueError("Continual mode requires a persistent path for the response. Please provide a valid path.") + + os.makedirs(response_persistent_folder, exist_ok=True) + self.response_persistent_folder = response_persistent_folder + self.response_persistent_file = os.path.join(self.response_persistent_folder, f"{self.model_version}_response.json") + + if os.path.exists(self.response_persistent_file): + with open(self.response_persistent_file, "r") as f: + self.response_cache = json.load(f) + self.cache_mode = "resume" + else: + self.response_cache = {} + self.cache_mode = "start" + + self.client = OpenAI(api_key=os.getenv("OPENAI_COMPATIBLE_API_KEY"), base_url=os.getenv("OPENAI_COMPATIBLE_API_URL")) + + accelerator = Accelerator() + # assert self.batch_size_per_gpu == 1, "Llava currently does not support batched generation. See https://github.com/haotian-liu/LLaVA/issues/754. HF Llava also has this issue." + if accelerator.num_processes > 1: + assert accelerator.distributed_type in [DistributedType.FSDP, DistributedType.MULTI_GPU, DistributedType.DEEPSPEED], "Unsupported distributed type provided. Only DDP and FSDP are supported." + self.accelerator = accelerator + if self.accelerator.is_local_main_process: + eval_logger.info(f"Using {accelerator.num_processes} devices with data parallelism") + self._rank = self.accelerator.local_process_index + self._world_size = self.accelerator.num_processes + else: + self.accelerator = accelerator + self._rank = self.accelerator.local_process_index + self._world_size = self.accelerator.num_processes + + self.device = self.accelerator.device + + # Function to encode the image + def encode_image(self, image: Union[Image.Image, str]): + max_size = self.max_size_in_mb * 1024 * 1024 # 20MB in bytes + if isinstance(image, str): + img = Image.open(image).convert("RGB") + else: + img = image.copy() + + output_buffer = BytesIO() + img.save(output_buffer, format="PNG") + byte_data = output_buffer.getvalue() + + # If image is too large, resize it while maintaining aspect ratio + while len(byte_data) > max_size and img.size[0] > 100 and img.size[1] > 100: + new_size = (int(img.size[0] * 0.75), int(img.size[1] * 0.75)) + img = img.resize(new_size, Image.Resampling.LANCZOS) + + output_buffer = BytesIO() + img.save(output_buffer, format="PNG") + byte_data = output_buffer.getvalue() + + base64_str = base64.b64encode(byte_data).decode("utf-8") + return base64_str + + # Function to encode the video + def encode_video(self, video_path, for_get_frames_num): + vr = VideoReader(video_path, ctx=cpu(0)) + total_frame_num = len(vr) + uniform_sampled_frames = np.linspace(0, total_frame_num - 1, for_get_frames_num, dtype=int) + + # Ensure the last frame is included + if total_frame_num - 1 not in uniform_sampled_frames: + uniform_sampled_frames = np.append(uniform_sampled_frames, total_frame_num - 1) + + frame_idx = uniform_sampled_frames.tolist() + frames = vr.get_batch(frame_idx).asnumpy() + + base64_frames = [] + for frame in frames: + img = Image.fromarray(frame) + output_buffer = BytesIO() + img.save(output_buffer, format="PNG") + byte_data = output_buffer.getvalue() + base64_str = base64.b64encode(byte_data).decode("utf-8") + base64_frames.append(base64_str) + + return base64_frames + + def flatten(self, input): + new_list = [] + for i in input: + for j in i: + new_list.append(j) + return new_list + + def generate_until(self, requests) -> List[str]: + res = [] + pbar = tqdm(total=len(requests), disable=(self.rank != 0), desc="Model Responding") + + for contexts, gen_kwargs, doc_to_visual, doc_id, task, split in [reg.args for reg in requests]: + if self.continual_mode is True and self.cache_mode == "resume": + doc_uuid = f"{task}___{split}___{doc_id}" + if doc_uuid in self.response_cache: + response_text = self.response_cache[doc_uuid] + if response_text: + res.append(response_text) + pbar.update(1) + continue + + visuals = [doc_to_visual(self.task_dict[task][split][doc_id])] + if None in visuals: + visuals = [] + imgs = [] + else: + visuals = self.flatten(visuals) + imgs = [] # multiple images or frames for video + for visual in visuals: + if isinstance(visual, str) and (".mp4" in visual or ".avi" in visual or ".mov" in visual or ".flv" in visual or ".wmv" in visual): + frames = self.encode_video(visual, self.max_frames_num) + imgs.extend(frames) + elif isinstance(visual, str) and (".jpg" in visual or ".jpeg" in visual or ".png" in visual or ".gif" in visual or ".bmp" in visual or ".tiff" in visual or ".webp" in visual): + img = self.encode_image(visual) + imgs.append(img) + elif isinstance(visual, Image.Image): + img = self.encode_image(visual) + imgs.append(img) + + payload = {"messages": []} + payload["model"] = self.model_version + + # When there is no image token in the context, append the image to the text + payload["messages"].append({"role": "user", "content": []}) + payload["messages"][0]["content"].append({"type": "text", "text": contexts}) + for img in imgs: + payload["messages"][0]["content"].append({"type": "image_url", "image_url": {"url": f"data:image/png;base64,{img}"}}) + + if "max_new_tokens" not in gen_kwargs: + gen_kwargs["max_new_tokens"] = 1024 + if gen_kwargs["max_new_tokens"] > 4096: + gen_kwargs["max_new_tokens"] = 4096 + if "temperature" not in gen_kwargs: + gen_kwargs["temperature"] = 0 + if "top_p" not in gen_kwargs: + gen_kwargs["top_p"] = None + if "num_beams" not in gen_kwargs: + gen_kwargs["num_beams"] = 1 + + payload["max_tokens"] = gen_kwargs["max_new_tokens"] + payload["temperature"] = gen_kwargs["temperature"] + + for attempt in range(self.max_retries): + try: + response = self.client.chat.completions.create(**payload) + response_text = response.choices[0].message.content + break # If successful, break out of the loop + + except Exception as e: + error_msg = str(e) + eval_logger.info(f"Attempt {attempt + 1}/{self.max_retries} failed with error: {error_msg}") + + # On last attempt, log error and set empty response + if attempt == self.max_retries - 1: + eval_logger.error(f"All {self.max_retries} attempts failed. Last error: {error_msg}") + response_text = "" + else: + time.sleep(self.timeout) + + res.append(response_text) + pbar.update(1) + + if self.continual_mode is True: # Cache the response + doc_uuid = f"{task}___{split}___{doc_id}" + self.response_cache[doc_uuid] = response_text + with open(self.response_persistent_file, "w") as f: + json.dump(self.response_cache, f) + + pbar.close() + return res + + def generate_until_multi_round(self, requests) -> List[str]: + raise NotImplementedError("TODO: Implement multi-round generation for OpenAI compatible models") + + def loglikelihood(self, requests: List[Instance]) -> List[Tuple[float, bool]]: + raise NotImplementedError("TODO: Implement loglikelihood for OpenAI compatible models") diff --git a/lmms_eval/models/vllm.py b/lmms_eval/models/vllm.py new file mode 100644 index 000000000..3b125718f --- /dev/null +++ b/lmms_eval/models/vllm.py @@ -0,0 +1,194 @@ +import asyncio +import base64 +import json +import os +import time +from concurrent.futures import ThreadPoolExecutor +from copy import deepcopy +from io import BytesIO +from multiprocessing import cpu_count +from typing import List, Optional, Tuple, Union + +import numpy as np +from accelerate import Accelerator, DistributedType +from decord import VideoReader, cpu +from loguru import logger as eval_logger +from PIL import Image +from tqdm import tqdm + +from lmms_eval.api.instance import Instance +from lmms_eval.api.model import lmms +from lmms_eval.api.registry import register_model + +NUM_SECONDS_TO_SLEEP = 5 + +try: + from vllm import LLM, SamplingParams +except ImportError: + vllm = None + + +@register_model("vllm") +class VLLM(lmms): + def __init__( + self, + model_version: str = "Qwen/Qwen2.5-VL-3B-Instruct", + tensor_parallel_size: int = 1, + gpu_memory_utilization: float = 0.8, + batch_size: int = 1, + timeout: int = 60, + max_images: int = 32, + max_videos: int = 8, + max_audios: int = 8, + max_frame_num: int = 32, + threads: int = 16, # Threads to use for decoding visuals + trust_remote_code: Optional[bool] = True, + **kwargs, + ) -> None: + super().__init__() + # Manually set a image token for GPT4V so that we can search for it + # and split the text and image + # Here we just use the same token as llava for convenient + self.model_version = model_version + self.max_images = max_images + self.max_frame_num = max_frame_num + self.threads = threads + + accelerator = Accelerator() + self.client = LLM( + model=self.model_version, + tensor_parallel_size=tensor_parallel_size, + gpu_memory_utilization=gpu_memory_utilization, + limit_mm_per_prompt={"image": max_images, "video": max_videos, "audio": max_audios}, + trust_remote_code=trust_remote_code, + ) + if accelerator.num_processes > 1: + assert accelerator.distributed_type in [DistributedType.FSDP, DistributedType.MULTI_GPU, DistributedType.DEEPSPEED], "Unsupported distributed type provided. Only DDP and FSDP are supported." + self.accelerator = accelerator + if self.accelerator.is_local_main_process: + eval_logger.info(f"Using {accelerator.num_processes} devices with data parallelism") + self._rank = self.accelerator.local_process_index + self._world_size = self.accelerator.num_processes + else: + self.accelerator = accelerator + self._rank = self.accelerator.local_process_index + self._world_size = self.accelerator.num_processes + + self.device = self.accelerator.device + self.batch_size_per_gpu = int(batch_size) + + # Function to encode the image + def encode_image(self, image: Union[Image.Image, str]): + if isinstance(image, str): + img = Image.open(image).convert("RGB") + else: + img = image.copy() + + output_buffer = BytesIO() + img.save(output_buffer, format="PNG") + byte_data = output_buffer.getvalue() + + base64_str = base64.b64encode(byte_data).decode("utf-8") + return base64_str + + # Function to encode the video + def encode_video(self, video_path): + vr = VideoReader(video_path, ctx=cpu(0)) + total_frame_num = len(vr) + uniform_sampled_frames = np.linspace(0, total_frame_num - 1, self.max_frame_num, dtype=int) + + # Ensure the last frame is included + if total_frame_num - 1 not in uniform_sampled_frames: + uniform_sampled_frames = np.append(uniform_sampled_frames, total_frame_num - 1) + + frame_idx = uniform_sampled_frames.tolist() + frames = vr.get_batch(frame_idx).asnumpy() + + base64_frames = [] + for frame in frames: + img = Image.fromarray(frame) + output_buffer = BytesIO() + img.save(output_buffer, format="PNG") + byte_data = output_buffer.getvalue() + base64_str = base64.b64encode(byte_data).decode("utf-8") + base64_frames.append(base64_str) + + return base64_frames + + def flatten(self, input): + new_list = [] + for i in input: + for j in i: + new_list.append(j) + return new_list + + def generate_until(self, requests) -> List[str]: + res = [] + pbar = tqdm(total=len(requests), disable=(self.rank != 0), desc="Model Responding") + + batch_size = self.batch_size_per_gpu + batched_requests = [requests[i : i + batch_size] for i in range(0, len(requests), batch_size)] + for batch_requests in batched_requests: + batched_messages = [] + for idx in range(len(batch_requests)): + contexts, gen_kwargs, doc_to_visual, doc_id, task, split = batch_requests[idx].arguments + if "max_new_tokens" not in gen_kwargs: + gen_kwargs["max_new_tokens"] = 1024 + if gen_kwargs["max_new_tokens"] > 4096: + gen_kwargs["max_new_tokens"] = 4096 + if "temperature" not in gen_kwargs: + gen_kwargs["temperature"] = 0 + if "top_p" not in gen_kwargs: + gen_kwargs["top_p"] = 0.95 + + params = { + "temperature": gen_kwargs["temperature"], + "max_tokens": gen_kwargs["max_new_tokens"], + "top_p": gen_kwargs["top_p"], + } + sampling_params = SamplingParams(**params) + + visuals = [doc_to_visual(self.task_dict[task][split][doc_id])] + if None in visuals: + visuals = [] + imgs = [] + else: + visuals = self.flatten(visuals) + imgs = [] # multiple images or frames for video + all_tasks = [] + with ThreadPoolExecutor(max_workers=self.threads) as executor: + for visual in visuals: + if isinstance(visual, str) and (".mp4" in visual or ".avi" in visual or ".mov" in visual or ".flv" in visual or ".wmv" in visual): + all_tasks.append(executor.submit(self.encode_video, visual)) + elif isinstance(visual, str) and (".jpg" in visual or ".jpeg" in visual or ".png" in visual or ".gif" in visual or ".bmp" in visual or ".tiff" in visual or ".webp" in visual): + all_tasks.append(executor.submit(self.encode_image, visual)) + elif isinstance(visual, Image.Image): + all_tasks.append(executor.submit(self.encode_image, visual)) + + for task in all_tasks: + imgs.append(task.result()) + + messages = [{"role": "user", "content": []}] + # When there is no image token in the context, append the image to the text + messages[0]["content"].append({"type": "text", "text": contexts}) + for img in imgs: + messages[0]["content"].append({"type": "image_url", "image_url": {"url": f"data:image/png;base64,{img}"}}) + + batched_messages.append(messages) + + response = self.client.chat(sampling_params=sampling_params, messages=batched_messages) + response_text = [o.outputs[0].text for o in response] + + assert len(response_text) == len(batch_requests) + res.extend(response_text) + pbar.update(len(batch_requests)) + + pbar.close() + return res + + def loglikelihood(self, requests: List[Instance]) -> List[Tuple[float, bool]]: + # TODO + assert False, "GPT4V not support" + + def generate_until_multi_round(self, requests) -> List[str]: + raise NotImplementedError("TODO: Implement multi-round generation") diff --git a/lmms_eval/tasks/charades_sta/charades.yaml b/lmms_eval/tasks/charades_sta/charades.yaml new file mode 100644 index 000000000..b9f140575 --- /dev/null +++ b/lmms_eval/tasks/charades_sta/charades.yaml @@ -0,0 +1,30 @@ +dataset_path: lmms-lab/charades_sta +dataset_kwargs: + token: True + cache_dir: charades_sta + video: True +task: temporal_grounding_charades +test_split: test + +generation_kwargs: + max_new_tokens: 50 + temperature: 0 + top_p: 1.0 + num_beams: 1 + do_sample: false + +output_type: generate_until +doc_to_visual: !function utils.temporal_grounding_doc_to_visual +doc_to_text: !function utils.temporal_grounding_doc_to_text +doc_to_target: !function utils.temporal_grounding_doc_to_answer +process_results: !function utils.temporal_grounding_process_results_generation + + +metric_list: + - metric: submission + aggregation: !function utils.temporal_grounding_aggregate_charades + higher_is_better: true +lmms_eval_specific_kwargs: + default: + pre_prompt: "Please find the visual event described by a sentence in the video, determining its starting and ending times. The format should be: 'The event happens in the start time - end time'. For example, The event 'person turn a light on' happens in the 24.3 - 30.4 seonds. Now I will give you the textual sentence: " + post_prompt: "Please return its start time and end time." \ No newline at end of file diff --git a/lmms_eval/tasks/charades_sta/eval_tvg.py b/lmms_eval/tasks/charades_sta/eval_tvg.py new file mode 100644 index 000000000..6a43841fe --- /dev/null +++ b/lmms_eval/tasks/charades_sta/eval_tvg.py @@ -0,0 +1,135 @@ +import argparse +import json +import os +import pdb +import re +from copy import deepcopy +from pathlib import Path + +import numpy as np + + +# read json files +def read_json(path): + with open(path, "r") as fin: + datas = json.load(fin) + return datas + + +def write_json(path, data): + with open(path, "w") as fout: + json.dump(data, fout) + print("The format file has been saved at:{}".format(path)) + return + + +def extract_time(paragraph): + prompt = "A specific example is : 20.8 - 30.0 seconds".lower() + paragraph = paragraph.lower().replace(prompt, "").replace("to", "-") + # Split text into sentences based on common delimiters + sentences = re.split(r"[!?\n]", paragraph) + + # Keywords that might indicate the presence of time information + keywords = ["starts", "ends", "happens in", "start time", "end time", "start", "end", "happen"] + # filter sentences by keywords + candidates = [] + for sentence in sentences: + # If sentence contains one of the keywords + if any(keyword in sentence for keyword in keywords): + candidates.append(sentence) + + timestamps = [] + # Check for The given query happens in m - n (seconds) + patterns = [r"(\d+\.*\d*)\s*-\s*(\d+\.*\d*)"] + + for time_pattern in patterns: + time_matches = re.findall(time_pattern, paragraph) + if time_matches: + timestamps = [[float(start), float(end)] for start, end in time_matches] + + if len(sentences) == 0: + return [] + # check for other formats e.g.: + # 1 .Starting time: 0.8 seconds + # Ending time: 1.1 seconds + # 2. The start time for this event is 0 seconds, and the end time is 12 seconds. + if len(timestamps) == 0: + times = [] + time_regex = re.compile(r"\b(\d+\.\d+\b|\b\d+)\b") # time formats (e.g., 18, 18.5) + for sentence in candidates: + time = re.findall(time_regex, sentence) + if time: + time_in_sec = float(time[0]) + times.append(time_in_sec) + times = times[: len(times) // 2 * 2] + timestamps = [(times[i], times[i + 1]) for i in range(0, len(times), 2)] + # Check for examples like: + # 3. The event 'person flipped the light switch near the door' starts at 00:00:18 and ends at 00:00:23. + if len(timestamps) == 0: + times = [] + time_regex = re.compile(r"\b((\d{1,2}:\d{2}:\d{2}))\b") # time formats (e.g., 18:00, 00:18:05) + for sentence in candidates: + time = re.findall(time_regex, sentence) + if time: + t = time[0] + else: + continue + # If time is in HH:MM:SS format, convert to seconds + if t.count(":") == 2: + h, m, s = map(int, t.split(":")) + time_in_sec = h * 3600 + m * 60 + s + elif t.count(":") == 1: + m, s = map(int, t.split(":")) + time_in_sec = m * 60 + s + times.append(time_in_sec) + times = times[: len(times) // 2 * 2] + timestamps = [(times[i], times[i + 1]) for i in range(0, len(times), 2)] + results = [] + for start, end in timestamps: + if end > start: + results.append([start, end]) + else: + results.append([end, start]) + if len(results) > 1: + results = results[:1] + return results + + +def iou(A, B): + max0 = max((A[0]), (B[0])) + min0 = min((A[0]), (B[0])) + max1 = max((A[1]), (B[1])) + min1 = min((A[1]), (B[1])) + return max(min1 - max0, 0) / (max1 - min0) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("-f", default="your_result.json") + args = parser.parse_args() + + datas = read_json(args.f) + + num = len(datas) + + # miou + ious = [] + for k in datas.keys(): + vid, caption, gt = k.split(">>>") + pred = datas[k] + gt = eval(gt) + timestamps = extract_time(pred) + if len(timestamps) != 1: + print(f"pred={pred},timestamps={timestamps}") + timestamps = [[gt[1] + 10, gt[1] + 20]] + # print(f"GT: {gt}, Pred: {timestamps[0]}") + + ious.append(iou(gt, timestamps[0])) + + Result = {0.3: 0, 0.5: 0, 0.7: 0} + for c_iou in [0.3, 0.5, 0.7]: + for cur_iou in ious: + if cur_iou >= c_iou: + Result[c_iou] = Result[c_iou] + 1 + + print("IOU 0.3: {0}\nIOU 0.5: {1}\nIOU 0.7: {2}\nmIOU".format(Result[0.3] * 100 / num, Result[0.5] * 100 / num, Result[0.7] * 100 / num), sum(ious) * 100 / num) diff --git a/lmms_eval/tasks/charades_sta/utils.py b/lmms_eval/tasks/charades_sta/utils.py new file mode 100644 index 000000000..8a5b4509e --- /dev/null +++ b/lmms_eval/tasks/charades_sta/utils.py @@ -0,0 +1,102 @@ +import datetime +import json +import os +import random +import sys +from pathlib import Path + +import numpy as np +import yaml +from decord import VideoReader, cpu +from loguru import logger as eval_logger + +import lmms_eval.tasks._task_utils.file_utils as file_utils + +# with open(Path(__file__).parent / "_default_template.yaml", "r") as f: +# raw_data = f.readlines() +# safe_data = [] +# for i, line in enumerate(raw_data): +# # remove function definition since yaml load cannot handle it +# if "!function" not in line: +# safe_data.append(line) + +# config = yaml.safe_load("".join(safe_data)) + + +hf_home = os.getenv("HF_HOME", "~/.cache/huggingface/") +# cache_dir = os.path.join(hf_home, cache_dir) +# base_cache_dir = config["dataset_kwargs"]["cache_dir"] +base_cache_dir = os.path.expanduser(hf_home) +with open(Path(__file__).parent / "charades.yaml", "r") as f: + raw_data = f.readlines() + safe_data = [] + for i, line in enumerate(raw_data): + # remove function definition since yaml load cannot handle it + if "!function" not in line: + safe_data.append(line) + +cache_name = yaml.safe_load("".join(safe_data))["dataset_kwargs"]["cache_dir"] + + +# DATA_LIST = { +# "charades": 'your_data_dir/Charades/', +# } +# Pass in video path here +# Can only work correctly with video llm +def temporal_grounding_doc_to_visual(doc, lmms_eval_specific_kwargs=None): + video_path = doc["video"] + cache_dir = os.path.join(base_cache_dir, cache_name) + video_path = os.path.join(cache_dir, "Charades_v1_480", video_path) + if os.path.exists(video_path): + video_path = video_path + elif "s3://" not in video_path: + sys.exit(f"video path:{video_path} does not exist, please check") + + return [video_path] + + +# This is the place where you format your question +def temporal_grounding_doc_to_text(doc, lmms_eval_specific_kwargs=None): + if lmms_eval_specific_kwargs is None: + lmms_eval_specific_kwargs = {} + + if "pre_prompt" in lmms_eval_specific_kwargs: + pre_prompt = lmms_eval_specific_kwargs["pre_prompt"] + if "post_prompt" in lmms_eval_specific_kwargs: + post_prompt = lmms_eval_specific_kwargs["post_prompt"] + + question = doc["caption"] + + return f"{pre_prompt}{question}. {post_prompt}" + + +def temporal_grounding_doc_to_answer(doc): + return doc["timestamp"] + + +# Process result for mcq answer generation +def temporal_grounding_process_results_generation(doc, result): + pred = result[0] + return {"submission": {f'{doc["video"]}>>>{doc["caption"]}>>>{doc["timestamp"]}': pred}} + + +def temporal_grounding_aggregate_charades(results, args): + temporal_grounding_aggregate_submissions(results, args, "charades") + + +def temporal_grounding_aggregate_submissions(results, args, task): + now_date_time = datetime.datetime.now().strftime("%Y-%m-%d-%H-%M-%S") + submission_file_name = f"inference_results_temporal_grounding_{task}_{now_date_time}.json" + path = file_utils.generate_submission_file(submission_file_name, args) + + # results is a list of 5031 dict, + # need to convert results into a single dict with 5031 key-value pairs + combined_submission = {} + + for submission_dict in results: + combined_submission.update(submission_dict) + + with open(path, "w") as f: + json.dump(combined_submission, f, indent=4) + + eval_logger.info(f"Submission file saved to {path}") diff --git a/lmms_eval/tasks/egoplan/egoplan.yaml b/lmms_eval/tasks/egoplan/egoplan.yaml new file mode 100644 index 000000000..a4aec0785 --- /dev/null +++ b/lmms_eval/tasks/egoplan/egoplan.yaml @@ -0,0 +1,43 @@ +dataset_path: EgoLife-v1/EgoPlan +dataset_kwargs: + token: True + cache_dir: egoplan + video: True + # From_YouTube: True +task: egoplan +test_split: validation +output_type: generate_until +doc_to_visual: !function utils.egoplan_doc_to_visual +doc_to_text: !function utils.egoplan_doc_to_text +doc_to_target: "answer" +generation_kwargs: + max_new_tokens: 4096 + temperature: 0 + top_p: 1.0 + num_beams: 1 + do_sample: false +# The return value of process_results will be used by metrics +process_results: !function utils.egoplan_process_results +# Note that the metric name can be either a registed metric function (such as the case for GQA) or a key name returned by process_results +metric_list: + - metric: egoplan_mcq_accuracy + aggregation: !function utils.egoplan_aggregate_results + higher_is_better: true +lmms_eval_specific_kwargs: + default: + pre_prompt: "" + post_prompt: "\nAnswer with the option's letter from the given choices directly." + gpt4v: + pre_prompt: "" + post_prompt: "\nAnswer the question with A, B, C, or D." + # qwen_vl: + # pre_prompt: "" + # post_prompt: " Answer:" + # otterhd: + # pre_prompt: "" + # post_prompt: " Answer:" + xcomposer2_4khd: + pre_prompt: "[UNUSED_TOKEN_146]user\n" + post_prompt: " Answer this question with A, B, C, or D.[UNUSED_TOKEN_145]\n[UNUSED_TOKEN_146]assistant\n" +metadata: + version: 0.0 diff --git a/lmms_eval/tasks/egoplan/utils.py b/lmms_eval/tasks/egoplan/utils.py new file mode 100644 index 000000000..35e5c6234 --- /dev/null +++ b/lmms_eval/tasks/egoplan/utils.py @@ -0,0 +1,207 @@ +import datetime +import json +import os +import re +import sys +from collections import defaultdict +from pathlib import Path +from typing import Dict, List, Optional, Union + +import cv2 +import numpy as np +import yaml +from loguru import logger as eval_logger + +from lmms_eval.tasks._task_utils.file_utils import generate_submission_file + +# with open(Path(__file__).parent / "_default_template_yaml", "r") as f: +# raw_data = f.readlines() +# safe_data = [] +# for i, line in enumerate(raw_data): +# # remove function definition since yaml load cannot handle it +# if "!function" not in line: +# safe_data.append(line) + +# config = yaml.safe_load("".join(safe_data)) + +hf_home = os.getenv("HF_HOME", "~/.cache/huggingface/") +# cache_dir = os.path.join(hf_home, cache_dir) +# base_cache_dir = config["dataset_kwargs"]["cache_dir"] +base_cache_dir = os.path.expanduser(hf_home) +with open(Path(__file__).parent / "egoplan.yaml", "r") as f: + raw_data = f.readlines() + safe_data = [] + for i, line in enumerate(raw_data): + # remove function definition since yaml load cannot handle it + if "!function" not in line: + safe_data.append(line) +cache_name = yaml.safe_load("".join(safe_data))["dataset_kwargs"]["cache_dir"] + + +def parse_subtitle_time(time_str): + h, m, s_ms = time_str.split(":") + s, ms = s_ms.split(",") + return int(h) * 3600 + int(m) * 60 + int(s) + int(ms) / 1000 + + +def load_subtitles(subtitle_path): + subtitles = {} + with open(subtitle_path, "r", encoding="utf-8") as file: + content = file.read().split("\n\n") + for section in content: + if section.strip(): + lines = section.split("\n") + if len(lines) >= 3: + time_range = lines[1].split(" --> ") + start_time = parse_subtitle_time(time_range[0]) + end_time = parse_subtitle_time(time_range[1]) + text = " ".join(line for line in lines[2:]) + subtitles[(start_time, end_time)] = text + return subtitles + + +def convert_time_to_frame(time_in_seconds, fps): + return int(time_in_seconds * fps) + + +def extract_subtitles(video_path, subtitle_path): + video = cv2.VideoCapture(video_path) + fps = video.get(cv2.CAP_PROP_FPS) + total_frame = int(video.get(cv2.CAP_PROP_FRAME_COUNT)) + subtitles = load_subtitles(subtitle_path) + + subtitle_frames = [] + for (start_time, end_time), text in subtitles.items(): + start_frame = convert_time_to_frame(start_time, fps) + end_frame = convert_time_to_frame(end_time, fps) + subtitle_frames.append((start_frame, end_frame, text)) + + return subtitle_frames, total_frame + + +def parse_subtitle_time(time_str): + h, m, s_ms = time_str.split(":") + s, ms = s_ms.split(",") + return int(h) * 3600 + int(m) * 60 + int(s) + int(ms) / 1000 + + +def load_subtitles(subtitle_path): + subtitles = {} + with open(subtitle_path, "r", encoding="utf-8") as file: + content = file.read().split("\n\n") + for section in content: + if section.strip(): + lines = section.split("\n") + if len(lines) >= 3: + time_range = lines[1].split(" --> ") + start_time = parse_subtitle_time(time_range[0]) + end_time = parse_subtitle_time(time_range[1]) + text = " ".join(line for line in lines[2:]) + subtitles[(start_time, end_time)] = text + return subtitles + + +def convert_time_to_frame(time_in_seconds, fps): + return int(time_in_seconds * fps) + + +def extract_subtitles(video_path, subtitle_path): + video = cv2.VideoCapture(video_path) + fps = video.get(cv2.CAP_PROP_FPS) + total_frame = int(video.get(cv2.CAP_PROP_FRAME_COUNT)) + subtitles = load_subtitles(subtitle_path) + + subtitle_frames = [] + for (start_time, end_time), text in subtitles.items(): + start_frame = convert_time_to_frame(start_time, fps) + end_frame = convert_time_to_frame(end_time, fps) + subtitle_frames.append((start_frame, end_frame, text)) + + return subtitle_frames, total_frame + + +def egoplan_doc_to_visual(doc): + cache_dir = os.path.join(base_cache_dir, cache_name) + video_path = str(doc["sample_id"]) + ".mp4" + video_path = os.path.join(cache_dir, video_path) + if os.path.exists(video_path): + video_path = video_path + elif os.path.exists(video_path.replace("mp4", "MP4")): + video_path = video_path.replace("mp4", "MP4") + elif os.path.exists(video_path.replace("mp4", "mkv")): + video_path = video_path.replace("mp4", "mkv") + else: + sys.exit(f"video path:{video_path} does not exist, please check") + return [video_path] + + +def egoplan_doc_to_text(doc, lmms_eval_specific_kwargs=None): + task_goal = doc["task_goal"] + if "goal" in task_goal: + task_goal = task_goal.split("to", 1)[1].strip() + words = task_goal.split() + if words[0].endswith("ing"): + question_pattern = ( + "I am tasked with {}. " + "The task's progress is demonstrated in the provided video. " + "My current field of view is shown in the provided image. " + "What should be my next action? " + "Please output the most reasonable action you think, expressed in a short phrase." + ) + else: + question_pattern = ( + "My current task is to {}. " + "The task's progress is demonstrated in the provided video. " + "My current field of view is shown in the provided image. " + "What should be my next action? " + "Please output the most reasonable action you think, expressed in a short phrase." + ) + question = question_pattern.format(task_goal) + + candidates = [] + for choice_idx in ["A", "B", "C", "D"]: + question += "\n" + f"{choice_idx}. " + (doc[f"choice_{choice_idx.lower()}"]) + post_prompt = "\nAnswer with the option's letter from the given choices" + + return f"{question}{post_prompt}" + + +def extract_characters_regex(s): + s = s.strip() + answer_prefixes = [ + "The best answer is", + "The correct answer is", + "The answer is", + "The answer", + "The best option is" "The correct option is", + "Best answer:" "Best option:", + ] + for answer_prefix in answer_prefixes: + s = s.replace(answer_prefix, "") + + if len(s.split()) > 10 and not re.search("[ABCD]", s): + return "" + + matches = re.search(r"[ABCD]", s) + if matches is None: + return "" + return matches[0] + + +def egoplan_process_results(doc, results): + pred = results[0] + pred_ans = extract_characters_regex(pred) + # gt_ans = doc["answer"].lower().strip().replace(".", "") + doc["pred_answer"] = pred_ans + data_dict = doc.copy() + return {f"egoplan_mcq_accuracy": data_dict} + + +def egoplan_aggregate_results(results): + correct_num = 0 + for result in results: + if result["pred_answer"] == result["golden_choice_idx"]: + correct_num += 1 + question_num = len(results) + accuracy = correct_num / question_num + return accuracy diff --git a/lmms_eval/tasks/egothink/_default_template_yaml b/lmms_eval/tasks/egothink/_default_template_yaml new file mode 100644 index 000000000..546be29aa --- /dev/null +++ b/lmms_eval/tasks/egothink/_default_template_yaml @@ -0,0 +1,7 @@ +dataset_path: EgoLife-v1/Egothink +dataset_kwargs: + token: True +test_split: test +metadata: + version: 0.0 + gpt_eval_model_name: "gpt-4" \ No newline at end of file diff --git a/lmms_eval/tasks/egothink/egothink.yaml b/lmms_eval/tasks/egothink/egothink.yaml new file mode 100644 index 000000000..f8bbfe74e --- /dev/null +++ b/lmms_eval/tasks/egothink/egothink.yaml @@ -0,0 +1,14 @@ +group: egothink +task: + - egothink_activity + - egothink_affordance + - egothink_assistance + - egothink_navigation + - egothink_attribute + - egothink_comparing + - egothink_counting + - egothink_existence + - egothink_forecasting + - egothink_location + - egothink_situated + - egothink_spatial diff --git a/lmms_eval/tasks/egothink/egothink_activity.yaml b/lmms_eval/tasks/egothink/egothink_activity.yaml new file mode 100644 index 000000000..4df6756c8 --- /dev/null +++ b/lmms_eval/tasks/egothink/egothink_activity.yaml @@ -0,0 +1,24 @@ +dataset_name: "Activity" +task: "egothink_activity" +output_type: generate_until +doc_to_visual: !function utils.egothink_doc_to_visual +doc_to_text: !function utils.egothink_doc_to_text +doc_to_target: !function utils.egothink_doc_to_answer +generation_kwargs: + max_new_tokens: 30 + temperature: 0.2 + num_beams: 1 + do_sample: True +# Note that the metric name can be either a registed metric function (such as the case for GQA) or a key name returned by process_results +metric_list: + - metric: gpt_eval_score + aggregation: !function utils.egothink_aggregate_results + higher_is_better: true +# The return value of process_results will be used by metrics +process_results: !function utils.egothink_process_results + +lmms_eval_specific_kwargs: + default: + pre_prompt: "" + post_prompt: "" +include: _default_template_yaml \ No newline at end of file diff --git a/lmms_eval/tasks/egothink/egothink_affordance.yaml b/lmms_eval/tasks/egothink/egothink_affordance.yaml new file mode 100644 index 000000000..3e0cae856 --- /dev/null +++ b/lmms_eval/tasks/egothink/egothink_affordance.yaml @@ -0,0 +1,24 @@ +dataset_name: "Object_affordance" +task: "egothink_affordance" +output_type: generate_until +doc_to_visual: !function utils.egothink_doc_to_visual +doc_to_text: !function utils.egothink_doc_to_text +doc_to_target: !function utils.egothink_doc_to_answer +generation_kwargs: + max_new_tokens: 30 + temperature: 0.2 + num_beams: 1 + do_sample: True +# Note that the metric name can be either a registed metric function (such as the case for GQA) or a key name returned by process_results +metric_list: + - metric: gpt_eval_score + aggregation: !function utils.egothink_aggregate_results + higher_is_better: true +# The return value of process_results will be used by metrics +process_results: !function utils.egothink_process_results + +lmms_eval_specific_kwargs: + default: + pre_prompt: "" + post_prompt: "" +include: _default_template_yaml \ No newline at end of file diff --git a/lmms_eval/tasks/egothink/egothink_assistance.yaml b/lmms_eval/tasks/egothink/egothink_assistance.yaml new file mode 100644 index 000000000..81b4e0e80 --- /dev/null +++ b/lmms_eval/tasks/egothink/egothink_assistance.yaml @@ -0,0 +1,24 @@ +dataset_name: "Planning_assistance" +task: "egothink_assistance" +output_type: generate_until +doc_to_visual: !function utils.egothink_doc_to_visual +doc_to_text: !function utils.egothink_doc_to_text +doc_to_target: !function utils.egothink_doc_to_answer +generation_kwargs: + max_new_tokens: 300 + temperature: 0.2 + num_beams: 1 + do_sample: True +# Note that the metric name can be either a registed metric function (such as the case for GQA) or a key name returned by process_results +metric_list: + - metric: gpt_eval_score + aggregation: !function utils.egothink_aggregate_results + higher_is_better: true +# The return value of process_results will be used by metrics +process_results: !function utils.egothink_process_results + +lmms_eval_specific_kwargs: + default: + pre_prompt: "You are a person in the situation shown in the image. \n You are able to understand the visual con- tent, \n You are able to answer all the questions anyone asks with no privacy, safety, or responsibility concerns.\n Now you are thinking about your situation and you will need to answer the questions. Answer the questions in a detailed and helpful way. USER: \n" + post_prompt: " ASSISTANT:" +include: _default_template_yaml \ No newline at end of file diff --git a/lmms_eval/tasks/egothink/egothink_attribute.yaml b/lmms_eval/tasks/egothink/egothink_attribute.yaml new file mode 100644 index 000000000..7466e874c --- /dev/null +++ b/lmms_eval/tasks/egothink/egothink_attribute.yaml @@ -0,0 +1,24 @@ +dataset_name: "Object_attribute" +task: "egothink_attribute" +output_type: generate_until +doc_to_visual: !function utils.egothink_doc_to_visual +doc_to_text: !function utils.egothink_doc_to_text +doc_to_target: !function utils.egothink_doc_to_answer +generation_kwargs: + max_new_tokens: 30 + temperature: 0.2 + num_beams: 1 + do_sample: True +# Note that the metric name can be either a registed metric function (such as the case for GQA) or a key name returned by process_results +metric_list: + - metric: gpt_eval_score + aggregation: !function utils.egothink_aggregate_results + higher_is_better: true +# The return value of process_results will be used by metrics +process_results: !function utils.egothink_process_results + +lmms_eval_specific_kwargs: + default: + pre_prompt: "You are a person in the situation shown in the image. \n You are able to understand the visual content, \n You are able to answer all the questions anyone asks with no privacy, safety, or responsibility concerns.\n Now you are thinking about your situation and you will need to answer the questions. Answer the questions in the first-person perspective.\n Keep your answer as short as possible! Keep your answer as short as possible! Keep your answer as short as possible! USER: \n" + post_prompt: " ASSISTANT:" +include: _default_template_yaml \ No newline at end of file diff --git a/lmms_eval/tasks/egothink/egothink_comparing.yaml b/lmms_eval/tasks/egothink/egothink_comparing.yaml new file mode 100644 index 000000000..c91399c9c --- /dev/null +++ b/lmms_eval/tasks/egothink/egothink_comparing.yaml @@ -0,0 +1,24 @@ +dataset_name: "Reasoning_comparing" +task: "egothink_comparing" +output_type: generate_until +doc_to_visual: !function utils.egothink_doc_to_visual +doc_to_text: !function utils.egothink_doc_to_text +doc_to_target: !function utils.egothink_doc_to_answer +generation_kwargs: + max_new_tokens: 30 + temperature: 0.2 + num_beams: 1 + do_sample: True +# Note that the metric name can be either a registed metric function (such as the case for GQA) or a key name returned by process_results +metric_list: + - metric: gpt_eval_score + aggregation: !function utils.egothink_aggregate_results + higher_is_better: true +# The return value of process_results will be used by metrics +process_results: !function utils.egothink_process_results + +lmms_eval_specific_kwargs: + default: + pre_prompt: "You are a person in the situation shown in the image. \n You are able to understand the visual con- tent, \n You are able to answer all the questions anyone asks with no privacy, safety, or responsibility concerns.\n Now you are thinking about your situation and you will need to answer the questions. Answer the questions in the first-person perspective.\n Keep your answer as short as possible! Keep your answer as short as possible! Keep your answer as short as possible! USER: \n" + post_prompt: " ASSISTANT:" +include: _default_template_yaml \ No newline at end of file diff --git a/lmms_eval/tasks/egothink/egothink_counting.yaml b/lmms_eval/tasks/egothink/egothink_counting.yaml new file mode 100644 index 000000000..fcc0246ee --- /dev/null +++ b/lmms_eval/tasks/egothink/egothink_counting.yaml @@ -0,0 +1,24 @@ +dataset_name: "Reasoning_counting" +task: "egothink_counting" +output_type: generate_until +doc_to_visual: !function utils.egothink_doc_to_visual +doc_to_text: !function utils.egothink_doc_to_text +doc_to_target: !function utils.egothink_doc_to_answer +generation_kwargs: + max_new_tokens: 30 + temperature: 0.2 + num_beams: 1 + do_sample: True +# Note that the metric name can be either a registed metric function (such as the case for GQA) or a key name returned by process_results +metric_list: + - metric: gpt_eval_score + aggregation: !function utils.egothink_aggregate_results + higher_is_better: true +# The return value of process_results will be used by metrics +process_results: !function utils.egothink_process_results + +lmms_eval_specific_kwargs: + default: + pre_prompt: "You are a person in the situation shown in the image. \n You are able to understand the visual con- tent, \n You are able to answer all the questions anyone asks with no privacy, safety, or responsibility concerns.\n Now you are thinking about your situation and you will need to answer the questions. Answer the questions in the first-person perspective.\n Keep your answer as short as possible! Keep your answer as short as possible! Keep your answer as short as possible! USER: \n" + post_prompt: " ASSISTANT:" +include: _default_template_yaml \ No newline at end of file diff --git a/lmms_eval/tasks/egothink/egothink_existence.yaml b/lmms_eval/tasks/egothink/egothink_existence.yaml new file mode 100644 index 000000000..d54b7a928 --- /dev/null +++ b/lmms_eval/tasks/egothink/egothink_existence.yaml @@ -0,0 +1,24 @@ +dataset_name: "Object_existence" +task: "egothink_existence" +output_type: generate_until +doc_to_visual: !function utils.egothink_doc_to_visual +doc_to_text: !function utils.egothink_doc_to_text +doc_to_target: !function utils.egothink_doc_to_answer +generation_kwargs: + max_new_tokens: 30 + temperature: 0.2 + num_beams: 1 + do_sample: True +# Note that the metric name can be either a registed metric function (such as the case for GQA) or a key name returned by process_results +metric_list: + - metric: gpt_eval_score + aggregation: !function utils.egothink_aggregate_results + higher_is_better: true +# The return value of process_results will be used by metrics +process_results: !function utils.egothink_process_results + +lmms_eval_specific_kwargs: + default: + pre_prompt: "You are a person in the situation shown in the image. \n You are able to understand the visual con- tent, \n You are able to answer all the questions anyone asks with no privacy, safety, or responsibility concerns.\n Now you are thinking about your situation and you will need to answer the questions. Answer the questions in the first-person perspective.\n Keep your answer as short as possible! Keep your answer as short as possible! Keep your answer as short as possible! USER: \n" + post_prompt: " ASSISTANT:" +include: _default_template_yaml \ No newline at end of file diff --git a/lmms_eval/tasks/egothink/egothink_forecasting.yaml b/lmms_eval/tasks/egothink/egothink_forecasting.yaml new file mode 100644 index 000000000..4688ffa5f --- /dev/null +++ b/lmms_eval/tasks/egothink/egothink_forecasting.yaml @@ -0,0 +1,24 @@ +dataset_name: "Forecasting" +task: "egothink_forecasting" +output_type: generate_until +doc_to_visual: !function utils.egothink_doc_to_visual +doc_to_text: !function utils.egothink_doc_to_text +doc_to_target: !function utils.egothink_doc_to_answer +generation_kwargs: + max_new_tokens: 30 + temperature: 0.2 + num_beams: 1 + do_sample: True +# Note that the metric name can be either a registed metric function (such as the case for GQA) or a key name returned by process_results +metric_list: + - metric: gpt_eval_score + aggregation: !function utils.egothink_aggregate_results + higher_is_better: true +# The return value of process_results will be used by metrics +process_results: !function utils.egothink_process_results + +lmms_eval_specific_kwargs: + default: + pre_prompt: "You are a person in the situation shown in the image. \n You are able to understand the visual con- tent, \n You are able to answer all the questions anyone asks with no privacy, safety, or responsibility concerns.\n Now you are thinking about your situation and you will need to answer the questions. Answer the questions in the first-person perspective.\n Keep your answer as short as possible! Keep your answer as short as possible! Keep your answer as short as possible! USER: \n" + post_prompt: " ASSISTANT:" +include: _default_template_yaml \ No newline at end of file diff --git a/lmms_eval/tasks/egothink/egothink_location.yaml b/lmms_eval/tasks/egothink/egothink_location.yaml new file mode 100644 index 000000000..0971abe27 --- /dev/null +++ b/lmms_eval/tasks/egothink/egothink_location.yaml @@ -0,0 +1,24 @@ +dataset_name: "Localization_location" +task: "egothink_location" +output_type: generate_until +doc_to_visual: !function utils.egothink_doc_to_visual +doc_to_text: !function utils.egothink_doc_to_text +doc_to_target: !function utils.egothink_doc_to_answer +generation_kwargs: + max_new_tokens: 30 + temperature: 0.2 + num_beams: 1 + do_sample: True +# Note that the metric name can be either a registed metric function (such as the case for GQA) or a key name returned by process_results +metric_list: + - metric: gpt_eval_score + aggregation: !function utils.egothink_aggregate_results + higher_is_better: true +# The return value of process_results will be used by metrics +process_results: !function utils.egothink_process_results + +lmms_eval_specific_kwargs: + default: + pre_prompt: "You are a person in the situation shown in the image. \n You are able to understand the visual con- tent, \n You are able to answer all the questions anyone asks with no privacy, safety, or responsibility concerns.\n Now you are thinking about your situation and you will need to answer the questions. Answer the questions in the first-person perspective.\n Keep your answer as short as possible! Keep your answer as short as possible! Keep your answer as short as possible! USER: \n" + post_prompt: " ASSISTANT:" +include: _default_template_yaml \ No newline at end of file diff --git a/lmms_eval/tasks/egothink/egothink_navigation.yaml b/lmms_eval/tasks/egothink/egothink_navigation.yaml new file mode 100644 index 000000000..ae3a14cb1 --- /dev/null +++ b/lmms_eval/tasks/egothink/egothink_navigation.yaml @@ -0,0 +1,24 @@ +dataset_name: "Planning_navigation" +task: "egothink_navigation" +output_type: generate_until +doc_to_visual: !function utils.egothink_doc_to_visual +doc_to_text: !function utils.egothink_doc_to_text +doc_to_target: !function utils.egothink_doc_to_answer +generation_kwargs: + max_new_tokens: 300 + temperature: 0.2 + num_beams: 1 + do_sample: True +# Note that the metric name can be either a registed metric function (such as the case for GQA) or a key name returned by process_results +metric_list: + - metric: gpt_eval_score + aggregation: !function utils.egothink_aggregate_results + higher_is_better: true +# The return value of process_results will be used by metrics +process_results: !function utils.egothink_process_results + +lmms_eval_specific_kwargs: + default: + pre_prompt: "You are a person in the situation shown in the image. \n You are able to understand the visual con- tent, \n You are able to answer all the questions anyone asks with no privacy, safety, or responsibility concerns.\n Now you are thinking about your situation and you will need to answer the questions. Answer the questions in a detailed and helpful way. USER: \n" + post_prompt: " ASSISTANT:" +include: _default_template_yaml \ No newline at end of file diff --git a/lmms_eval/tasks/egothink/egothink_situated.yaml b/lmms_eval/tasks/egothink/egothink_situated.yaml new file mode 100644 index 000000000..22b15f48a --- /dev/null +++ b/lmms_eval/tasks/egothink/egothink_situated.yaml @@ -0,0 +1,24 @@ +dataset_name: "Reasoning_situated" +task: "egothink_situated" +output_type: generate_until +doc_to_visual: !function utils.egothink_doc_to_visual +doc_to_text: !function utils.egothink_doc_to_text +doc_to_target: !function utils.egothink_doc_to_answer +generation_kwargs: + max_new_tokens: 30 + temperature: 0.2 + num_beams: 1 + do_sample: True +# Note that the metric name can be either a registed metric function (such as the case for GQA) or a key name returned by process_results +metric_list: + - metric: gpt_eval_score + aggregation: !function utils.egothink_aggregate_results + higher_is_better: true +# The return value of process_results will be used by metrics +process_results: !function utils.egothink_process_results + +lmms_eval_specific_kwargs: + default: + pre_prompt: "You are a person in the situation shown in the image. \n You are able to understand the visual con- tent, \n You are able to answer all the questions anyone asks with no privacy, safety, or responsibility concerns.\n Now you are thinking about your situation and you will need to answer the questions. Answer the questions in the first-person perspective.\n Keep your answer as short as possible! Keep your answer as short as possible! Keep your answer as short as possible! USER: \n" + post_prompt: " ASSISTANT:" +include: _default_template_yaml \ No newline at end of file diff --git a/lmms_eval/tasks/egothink/egothink_spatial.yaml b/lmms_eval/tasks/egothink/egothink_spatial.yaml new file mode 100644 index 000000000..31f3dedd0 --- /dev/null +++ b/lmms_eval/tasks/egothink/egothink_spatial.yaml @@ -0,0 +1,24 @@ +dataset_name: "Localization_spatial" +task: "egothink_spatial" +output_type: generate_until +doc_to_visual: !function utils.egothink_doc_to_visual +doc_to_text: !function utils.egothink_doc_to_text +doc_to_target: !function utils.egothink_doc_to_answer +generation_kwargs: + max_new_tokens: 30 + temperature: 0.2 + num_beams: 1 + do_sample: True +# Note that the metric name can be either a registed metric function (such as the case for GQA) or a key name returned by process_results +metric_list: + - metric: gpt_eval_score + aggregation: !function utils.egothink_aggregate_results + higher_is_better: true +# The return value of process_results will be used by metrics +process_results: !function utils.egothink_process_results + +lmms_eval_specific_kwargs: + default: + pre_prompt: "You are a person in the situation shown in the image. \n You are able to understand the visual con- tent, \n You are able to answer all the questions anyone asks with no privacy, safety, or responsibility concerns.\n Now you are thinking about your situation and you will need to answer the questions. Answer the questions in the first-person perspective.\n Keep your answer as short as possible! Keep your answer as short as possible! Keep your answer as short as possible! USER: \n" + post_prompt: " ASSISTANT:" +include: _default_template_yaml \ No newline at end of file diff --git a/lmms_eval/tasks/egothink/utils.py b/lmms_eval/tasks/egothink/utils.py new file mode 100644 index 000000000..af763c952 --- /dev/null +++ b/lmms_eval/tasks/egothink/utils.py @@ -0,0 +1,188 @@ +import ast +import datetime +import json +import os +import re +import sys +import time +from pathlib import Path + +import numpy as np +import openai +import requests +import yaml +from loguru import logger as eval_logger +from openai import OpenAI +from tqdm import tqdm + +import lmms_eval.tasks._task_utils.file_utils as file_utils + +dir_name = os.path.dirname(os.path.abspath(__file__)) + +one_score_pattern = re.compile("\[\[(\d+\.?\d*)\]\]") +one_score_pattern_backup = re.compile("\[(\d+\.?\d*)\]") + +with open(Path(__file__).parent / "_default_template_yaml", "r") as f: + raw_data = f.readlines() + safe_data = [] + for i, line in enumerate(raw_data): + # remove function definition since yaml load cannot handle it + if "!function" not in line: + safe_data.append(line) + + config = yaml.safe_load("".join(safe_data)) + +API_ERROR_OUTPUT = "$ERROR$" + +API_MAX_RETRY = 6 + +NUM_SECONDS_TO_SLEEP = 15 + +GPT_EVAL_MODEL_NAME = config["metadata"]["gpt_eval_model_name"] + +API_TYPE = os.getenv("API_TYPE", "openai") + +if API_TYPE == "openai": + API_URL = os.getenv("OPENAI_API_URL", "https://api.openai.com/v1/chat/completions") + API_KEY = os.getenv("OPENAI_API_KEY", "YOUR_API_KEY") + headers = { + "Authorization": f"Bearer {API_KEY}", + "Content-Type": "application/json", + } +elif API_TYPE == "azure": + API_URL = os.getenv("AZURE_ENDPOINT", "https://api.cognitive.microsoft.com/sts/v1.0/issueToken") + API_KEY = os.getenv("AZURE_API_KEY", "YOUR_API_KEY") + headers = { + "Authorization": f"Bearer {API_KEY}", + "Content-Type": "application/json", + } +else: + API_URL = "YOUR_API_URL" + API_KEY = "YOUR_API_KEY" + + +def egothink_doc_to_visual(doc): + return [doc["image"].convert("RGB")] + + +# format the question +def egothink_doc_to_text(doc, lmms_eval_specific_kwargs=None): + question = doc["question"].strip() + if "pre_prompt" in lmms_eval_specific_kwargs and lmms_eval_specific_kwargs["pre_prompt"] != "": + question = f"{lmms_eval_specific_kwargs['pre_prompt']}{question}" + if "post_prompt" in lmms_eval_specific_kwargs and lmms_eval_specific_kwargs["post_prompt"] != "": + question = f"{question}{lmms_eval_specific_kwargs['post_prompt']}" + return question + + +# format answer +def egothink_doc_to_answer(doc): + return doc["answer"] + + +# Process result for evaluation in generic task +def chat_compeletion_openai(model, messages, temperature, max_tokens): + # headers = { + # "Authorization": f"Bearer {API_KEY}", + # "Content-Type": "application/json", + # } + # headers = { + # "Authorization": f"Bearer {API_KEY}", + # "Content-Type": "application/json", + # } + headers = { + "Content-Type": "application/json", + "api-key": API_KEY, + } + output = API_ERROR_OUTPUT + payload = { + # "model": model, + "messages": messages, + "temperature": temperature, + "max_tokens": max_tokens, + } + + for attempt in range(API_MAX_RETRY): + try: + response = requests.post(API_URL, headers=headers, json=payload, timeout=60) + response.raise_for_status() # Raises HTTPError for bad responses + try: + response_data = response.json() # Attempt to parse JSON + except requests.exceptions.JSONDecodeError: + eval_logger.error(f"JSON decode error on attempt {attempt + 1}. Response text: {response.text}") + continue # Skip to next retry + content = response_data["choices"][0]["message"]["content"].strip() + if content != "": + return content, response_data["model"] + # Handle HTTP errors separately + except requests.exceptions.HTTPError as e: + eval_logger.error(f"HTTP error on attempt {attempt + 1}: {e}") + # Handle other requests-related errors + except requests.exceptions.RequestException as e: + eval_logger.error(f"Request exception on attempt {attempt + 1}: {e}") + except Exception as e: + eval_logger.error(f"Unexpected error on attempt {attempt + 1}: {e}") + + # Handle other unexpected errors + if attempt < API_MAX_RETRY - 1: + time.sleep(NUM_SECONDS_TO_SLEEP) + else: # If this was the last attempt, log and return empty + eval_logger.error(f"All {retries} attempts failed. Last error message: {e}") + return "", "" + + return "", "" + + +def judge_single(question, answer, ref_answer): + model = GPT_EVAL_MODEL_NAME + + rating = -1 + + conv = [ + {"role": "system", "content": "You are a helpful assistant."}, + { + "role": "user", + "content": f"[Instruction]\nPlease act as an impartial judge and evaluate the quality of the response provided by an AI assistant to the user question displayed below. Your evaluation should consider correctness and helpfulness. You will be given a reference answer and the assistant's answer. Begin your evaluation by comparing the assistant's answer with the reference answer. Identify and correct any mistakes. The assistant has access to an image alongwith questions but you will not be given images. Therefore, please consider only how the answer is close to the reference answer. If the assistant's answer is not exactly same as or similar to the answer, then he must be wrong. Be as objective as possible. Discourage uninformative answers. Also, equally treat short and long answers and focus on the correctness of answers. After providing your explanation, you must rate the response with either 0, 0.5 or 1 by strictly following this format: \"[[rating]]\", for example: \"Rating: [[0.5]]\".\n\n[Question]\n{question}\n\n[The Start of Reference Answer]\n{ref_answer}\n[The End of Reference Answer]\n\n[The Start of Assistant's Answer]\n{answer}\n[The End of Assistant's Answer]", + }, + ] + + judgment, eval_model = chat_compeletion_openai(model, conv, temperature=0, max_tokens=2048) + for _ in range(3): + match = re.search(one_score_pattern, judgment) + if not match: + match = re.search(one_score_pattern_backup, judgment) + + if match: + rating = ast.literal_eval(match.groups()[0]) + break + else: + rating = -1 + return rating, judgment, eval_model + + +def egothink_process_results(doc, results): + """ + Args: + doc: a instance of the eval dataset + results: [pred] + Returns: + a dictionary with key: metric name (in this case mme score), value: metric value + """ + pred = results[0] + question = doc["question"] + ref_ans = doc["answer"].lower().strip().replace(".", "") + score, judge, eval_model = judge_single(question, pred, ref_ans) + return {"gpt_eval_score": {"question_id": doc["id"], "score": score, "judge": judge, "eval_model": eval_model}} + + +def egothink_aggregate_results(results): + """ + Args: + results: a list of values returned by process_results + Returns: + A score + """ + total_score = 0 + for result in results: + total_score += result["score"] + return total_score / len(results) diff --git a/lmms_eval/tasks/mlvu/utils.py b/lmms_eval/tasks/mlvu/utils.py index 8ddea3dd7..77b3e9cfb 100644 --- a/lmms_eval/tasks/mlvu/utils.py +++ b/lmms_eval/tasks/mlvu/utils.py @@ -14,7 +14,7 @@ from lmms_eval.tasks._task_utils.file_utils import generate_submission_file -TASK_TYPES = ["TR", "AR", "VS", "NQA", "ER", "PQA", "SSC", "AO", "AC"] +TASK_TYPES = ["TR", "AR", "NQA", "ER", "PQA", "AO", "AC"] hf_home = os.getenv("HF_HOME", "./~/.cache/huggingface") @@ -105,6 +105,9 @@ def mlvu_aggregate_results(results): category2score[task_type]["answered"] += 1 category2score[task_type]["correct"] += result["pred_answer"] == result["answer"] + task_category_scores = {} + + # Calculate and log accuracy for each task category for task_cate in TASK_TYPES: total_correct = 0 total_answered = 0 @@ -112,13 +115,16 @@ def mlvu_aggregate_results(results): if task_cate in k: total_correct += v["correct"] total_answered += v["answered"] - eval_logger.info(f"Evaluation on Task Categories: {task_cate}: {100 * total_correct / total_answered if total_answered > 0 else 0 : .1f}%") + accuracy = 100 * total_correct / total_answered if total_answered > 0 else 0 + task_category_scores[task_cate] = accuracy + eval_logger.info(f"Evaluation on Task Categories: {task_cate}: {accuracy:.1f}%") + + # Calculate and log average accuracy across all task categories + if TASK_TYPES: + average_accuracy = sum(task_category_scores.values()) / len(TASK_TYPES) + else: + average_accuracy = 0 - total_correct = 0 - total_answered = 0 - for k, v in category2score.items(): - total_correct += v["correct"] - total_answered += v["answered"] - eval_logger.info(f"Overall Performance: {100 * total_correct / total_answered if total_answered > 0 else 0 : .1f}%") + eval_logger.info(f"Average Performance Across All Task Categories: {average_accuracy:.1f}%") - return 100 * total_correct / total_answered if total_answered > 0 else 0 + return average_accuracy diff --git a/lmms_eval/tasks/mmmu/mmmu_val.yaml b/lmms_eval/tasks/mmmu/mmmu_val.yaml index a301f7cb8..5d1394d25 100755 --- a/lmms_eval/tasks/mmmu/mmmu_val.yaml +++ b/lmms_eval/tasks/mmmu/mmmu_val.yaml @@ -13,4 +13,10 @@ metric_list: aggregation: !function utils.mmmu_aggregate_results higher_is_better: true +lmms_eval_specific_kwargs: + default: + prompt_type: "format" + multiple_choice_prompt: "Answer with the option's letter from the given choices directly." + open_ended_prompt: "Answer the question using a single word or phrase." + include: _default_template_yaml \ No newline at end of file diff --git a/lmms_eval/tasks/mmmu/mmmu_val_reasoning.yaml b/lmms_eval/tasks/mmmu/mmmu_val_reasoning.yaml index d9d7a21b2..1b3c467a3 100755 --- a/lmms_eval/tasks/mmmu/mmmu_val_reasoning.yaml +++ b/lmms_eval/tasks/mmmu/mmmu_val_reasoning.yaml @@ -7,6 +7,16 @@ doc_to_text: !function utils.mmmu_doc_to_text doc_to_target: "answer" # The return value of process_results will be used by metrics process_results: !function utils.mmmu_reasoning_process_results +# repeats: 8 +# filter_list: +# # - name: "pass@64" +# # filter: +# # - function: "take_first_k" +# # k: 64 +# - name: "pass@8" +# filter: +# - function: "take_first_k" +# k: 8 metric_list: - metric: mmmu_judge_acc @@ -16,11 +26,15 @@ metric_list: lmms_eval_specific_kwargs: default: prompt_type: "reasoning" - multiple_choice_prompt: "Please show step-by-step reasoning, and answer the question with option letter from given choices." - open_ended_prompt: "Please show step-by-step reasoning, and answer the question using a single word or phrase." + multiple_choice_prompt: "Please reason step by step, and answer the question with option letter from given choices in the format of Answer: