diff --git a/lmdeploy/vl/model/internvl.py b/lmdeploy/vl/model/internvl.py
index 77b5dde0c1..ea0bac5cbe 100644
--- a/lmdeploy/vl/model/internvl.py
+++ b/lmdeploy/vl/model/internvl.py
@@ -106,15 +106,26 @@ def build_model(self):
                 no_split_module_classes=['InternVisionEncoderLayer'],
                 dtype=torch.half)
 
-        self.model = model
+        # Put the model in eval mode so that layers such as dropout
+        # behave deterministically, avoiding randomness in inference.
+        self.model = model.eval()
         self.config = config
 
         if getattr(self.config, 'dynamic_image_size', False):
             logger.info('using InternVL-Chat-V1-5 vision preprocess')
-            MEAN = (123.675, 116.28, 103.53)
-            STD = (58.395, 57.12, 57.375)
+            MEAN = (0.485, 0.456, 0.406)
+            STD = (0.229, 0.224, 0.225)
             import torchvision.transforms as T
-            self.transform = T.Compose([T.Normalize(mean=MEAN, std=STD)])
+            from torchvision.transforms.functional import InterpolationMode
+            input_size = self.config.vision_config.image_size
+            self.transform = T.Compose([
+                T.Lambda(lambda img: img.convert('RGB')
+                         if img.mode != 'RGB' else img),
+                T.Resize((input_size, input_size),
+                         interpolation=InterpolationMode.BICUBIC),
+                T.ToTensor(),
+                T.Normalize(mean=MEAN, std=STD)
+            ])
             self._forward_func = self._forward_v1_5
         else:
             self.image_processor = CLIPImageProcessor.from_pretrained(
@@ -123,7 +134,6 @@ def build_model(self):
 
     def _preprocess_v1_5(self, images: List[Image]):
         outputs = []
-        import torchvision.transforms.functional as F
         for image in images:
             out = dynamic_preprocess(
                 image,
@@ -131,7 +141,7 @@ def build_model(self):
                 max_num=self.config.max_dynamic_patch,
                 image_size=self.config.vision_config.image_size,
                 use_thumbnail=self.config.use_thumbnail)
-            out = [F.pil_to_tensor(x).half() for x in out]
+            out = [self.transform(x) for x in out]
             out = torch.stack(out)  # (patch) x c x h x w
             outputs.append(out)
         return outputs
@@ -142,7 +152,6 @@ def _forward_v1_5(self, images: List[Image]):
         split = [x.shape[0] for x in outputs]
         outputs = torch.cat(outputs, dim=0)
         outputs = outputs.to(self.model.device, dtype=torch.float16)
-        outputs = self.transform(outputs)
         outputs = self.model.extract_feature(outputs)
         outputs = torch.split(outputs, split, dim=0)
         outputs = [x.reshape(-1, x.shape[-1]) for x in outputs]
diff --git a/lmdeploy/vl/model/internvl_llava.py b/lmdeploy/vl/model/internvl_llava.py
index dd017327ef..4f11578d88 100644
--- a/lmdeploy/vl/model/internvl_llava.py
+++ b/lmdeploy/vl/model/internvl_llava.py
@@ -127,9 +127,9 @@ def build_model(self):
                 no_split_module_classes=['InternVisionEncoderLayer'],
                 dtype=torch.half)
 
-        self.model = model.model
-        self.vision_tower = model.model.vision_tower
-        self.mm_projector = model.model.mm_projector
+        self.model = model.model.eval()
+        self.vision_tower = model.model.vision_tower.eval()
+        self.mm_projector = model.model.mm_projector.eval()
 
     def encode_images(self, images: torch.Tensor) -> torch.Tensor:
         """encode images."""
diff --git a/lmdeploy/vl/model/llava.py b/lmdeploy/vl/model/llava.py
index 90a00ddf15..7b70288d43 100644
--- a/lmdeploy/vl/model/llava.py
+++ b/lmdeploy/vl/model/llava.py
@@ -124,9 +124,9 @@ def build_model(self):
                 no_split_module_classes=['CLIPEncoderLayer'],
                 dtype=torch.half)
 
-        self.model = model.model
-        self.vision_tower = model.model.vision_tower.half()
-        self.mm_projector = model.model.mm_projector.half()
+        self.model = model.model.eval()
+        self.vision_tower = model.model.vision_tower.half().eval()
+        self.mm_projector = model.model.mm_projector.half().eval()
 
     def encode_images(self, images: torch.Tensor) -> torch.Tensor:
         """encode images."""
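For context on the `MEAN`/`STD` change in `internvl.py`: the new constants are the old pixel-range values divided by 255 (e.g. 123.675 / 255 = 0.485), which matches `T.ToTensor()` scaling images to `[0, 1]` before `T.Normalize` runs. A minimal sketch of the resulting per-tile pipeline, assuming a 448x448 `vision_config.image_size` (the real value is read from the model config):

```python
# Standalone sketch of the new InternVL-Chat-V1-5 tile preprocessing;
# `input_size` is an assumed value, everything else mirrors the diff above.
import torchvision.transforms as T
from torchvision.transforms.functional import InterpolationMode
from PIL import Image

input_size = 448  # assumed config.vision_config.image_size
MEAN = (0.485, 0.456, 0.406)  # == (123.675, 116.28, 103.53) / 255
STD = (0.229, 0.224, 0.225)   # == (58.395, 57.12, 57.375) / 255

transform = T.Compose([
    T.Lambda(lambda img: img.convert('RGB') if img.mode != 'RGB' else img),
    T.Resize((input_size, input_size),
             interpolation=InterpolationMode.BICUBIC),
    T.ToTensor(),                      # uint8 [0, 255] -> float32 [0, 1]
    T.Normalize(mean=MEAN, std=STD),   # hence the fractional mean/std
])

tile = Image.new('RGB', (600, 400))    # placeholder tile from dynamic_preprocess
pixel_values = transform(tile)         # c x h x w, float32
assert pixel_values.shape == (3, input_size, input_size)
```

This also removes the need for the old `pil_to_tensor(...).half()` path in `_preprocess_v1_5`, since `_forward_v1_5` still casts to float16 before `extract_feature`.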
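On the repeated `.eval()` additions: freshly constructed `nn.Module`s default to train mode, where layers such as dropout are stochastic, so two identical requests could embed the same image differently. A toy check, not lmdeploy code, illustrating what eval mode buys:

```python
# Toy illustration only: dropout is random in train mode, disabled in eval mode.
import torch
import torch.nn as nn

drop = nn.Dropout(p=0.5)     # modules start in train mode
x = torch.randn(1, 512)

assert not torch.equal(drop(x), drop(x))  # different random masks per call

drop.eval()                  # model.eval() sets this recursively on submodules
assert torch.equal(drop(x), drop(x))      # deterministic: dropout is a no-op
```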
diff --git a/lmdeploy/vl/model/minicpmv.py b/lmdeploy/vl/model/minicpmv.py
index 8aadb5e618..ba641d243a 100644
--- a/lmdeploy/vl/model/minicpmv.py
+++ b/lmdeploy/vl/model/minicpmv.py
@@ -49,7 +49,7 @@ def build_model(self):
         model.resampler.pos_embed = model.resampler.pos_embed.to(
             device=model.resampler.proj.device)
         self.config = config
-        self.model = model
+        self.model = model.eval()
 
         if hasattr(config, 'vision_config'):
             self._forward_func = self._forward_v2_5
diff --git a/lmdeploy/vl/model/qwen.py b/lmdeploy/vl/model/qwen.py
index 40e73c2ef7..64a4c19915 100644
--- a/lmdeploy/vl/model/qwen.py
+++ b/lmdeploy/vl/model/qwen.py
@@ -60,7 +60,7 @@ def build_model(self):
                 no_split_module_classes=['VisualAttentionBlock'],
                 dtype=torch.half)
 
-        self.model = model.transformer.visual
+        self.model = model.transformer.visual.eval()
 
     @torch.no_grad()
     def forward(self, images: List[Image]) -> List[torch.Tensor]:
diff --git a/lmdeploy/vl/model/xcomposer2.py b/lmdeploy/vl/model/xcomposer2.py
index 4ab2d9613f..e4f5a0f7bb 100644
--- a/lmdeploy/vl/model/xcomposer2.py
+++ b/lmdeploy/vl/model/xcomposer2.py
@@ -99,7 +99,7 @@ def build_model(self):
                 device_map['plora_glb_GN'],
                 lambda x: (x[0].to(device=device_map['plora_glb_GN']), ))
 
-        self.model = model
+        self.model = model.eval()
 
     def _forward_7b(self, images: List[Image]) -> List[torch.Tensor]:
         """internlm-xcomposer2-7b vit forward."""
diff --git a/lmdeploy/vl/templates.py b/lmdeploy/vl/templates.py
index 23f6d5c027..c1198908ba 100644
--- a/lmdeploy/vl/templates.py
+++ b/lmdeploy/vl/templates.py
@@ -3,6 +3,7 @@
 from typing import Dict, List, Tuple, Union
 
 import PIL
+import PIL.Image
 
 from lmdeploy.model import BaseModel
 from lmdeploy.utils import get_hf_config_content, get_logger
@@ -38,15 +39,29 @@ def prompt_to_messages(self, prompt: VLPromptType):
             images = [images]
         messages['content'][0]['text'] = prompt
         for image in images:
+            # 'image_url': means url or local path to image.
+            # 'image_data': means PIL.Image.Image object.
             if isinstance(image, str):
                 image = load_image(image)
-            image_base64_data = encode_image_base64(image)
-            item = {
-                'type': 'image_url',
-                'image_url': {
-                    'url': f'data:image/jpeg;base64,{image_base64_data}'
+                image_base64_data = encode_image_base64(image)
+                item = {
+                    'type': 'image_url',
+                    'image_url': {
+                        'url':
+                        f'data:image/jpeg;base64,{image_base64_data}'
+                    }
                 }
-            }
+            elif isinstance(image, PIL.Image.Image):
+                item = {
+                    'type': 'image_data',
+                    'image_data': {
+                        'data': image
+                    }
+                }
+            else:
+                raise ValueError(
+                    'image should be a str(url/path) or PIL.Image.Image')
+
             messages['content'].append(item)
         return [messages]
 
@@ -61,14 +76,18 @@ async def async_collect_pil_images(
             if role != 'user' or isinstance(content, str):
                 continue
             for item in content:
-                if item['type'] != 'image_url':
-                    continue
-                url = item['image_url']['url']
-                images.append(url)
+                # 'image_url': means url or local path to image.
+                # 'image_data': means PIL.Image.Image object.
+                if item['type'] == 'image_url':
+                    url = item['image_url']['url']
+                    images.append(url)
+                elif item['type'] == 'image_data':
+                    data = item['image_data']['data']
+                    images.append(data)
 
         def _inner_call(i, images):
-            url = images[i]
-            images[i] = load_image(url)
+            url_or_data = images[i]
+            images[i] = load_image(url_or_data)
 
         await asyncio.gather(*[
             asyncio.get_event_loop().run_in_executor(
@@ -95,8 +114,12 @@ def convert_messages(self, messages, sequence_start=True):
                 continue
             num_images = 0
             for item in content:
+                # 'image_url': means url or local path to image.
+                # 'image_data': means PIL.Image.Image object.
                 if item['type'] == 'image_url':
                     num_images += 1
+                elif item['type'] == 'image_data':
+                    num_images += 1
                 elif item['type'] == 'text':
                     prompt = item['text']
             # if IMAGE_TOKEN in user prompt, use user custom prompt instead
diff --git a/lmdeploy/vl/utils.py b/lmdeploy/vl/utils.py
index faa762cc4d..a89cfc7a30 100644
--- a/lmdeploy/vl/utils.py
+++ b/lmdeploy/vl/utils.py
@@ -20,8 +20,10 @@ def load_image_from_base64(image: Union[bytes, str]) -> Image.Image:
     return Image.open(BytesIO(base64.b64decode(image)))
 
 
-def load_image(image_url: str) -> Image.Image:
+def load_image(image_url: Union[str, Image.Image]) -> Image.Image:
     """load image from url, local path or openai GPT4V."""
+    if isinstance(image_url, Image.Image):
+        return image_url
     FETCH_TIMEOUT = int(os.environ.get('LMDEPLOY_FETCH_TIMEOUT', 10))
 
     headers = {
@@ -40,6 +42,7 @@ def load_image(image_url: str) -> Image.Image:
     elif image_url.startswith('data:image'):
         img = load_image_from_base64(image_url.split(',')[1])
     else:
+        # Load image from local path
         img = Image.open(image_url)
 
     return img
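Taken together, the `templates.py` and `utils.py` changes let callers hand over an in-memory `PIL.Image.Image` instead of a URL or local path: `prompt_to_messages` wraps it in an `'image_data'` item, `async_collect_pil_images` collects it, and `load_image` returns it unchanged. A usage sketch assuming the standard `lmdeploy.pipeline` entry point; the model id and file name are placeholders, and only the `'image_data'` item format comes from this diff:

```python
from PIL import Image

from lmdeploy import pipeline
from lmdeploy.vl.utils import load_image

image = Image.open('demo.jpg')       # hypothetical local file
assert load_image(image) is image    # PIL input is now returned as-is

pipe = pipeline('OpenGVLab/InternVL-Chat-V1-5')  # placeholder model id

# The (prompt, image) tuple form keeps the PIL.Image in memory; internally it
# becomes an item of the form {'type': 'image_data', 'image_data': {'data': image}}
response = pipe(('describe this image', image))
print(response.text)
```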