[Bugfix] fix internvl-1.5-chat vision model preprocess and freeze weights (InternLM#1741)

* [Bugfix] fix internvl-1.5 vision model preprocess and freeze weights

* Use PNG instead of JPEG for base64 image buffer

* Update internvl.py

* Update utils.py

* Support passing Image object directly to the VL pipeline

* Update templates.py

* Update internvl_llava.py

* Update llava.py

* Update minicpmv.py

* Update qwen.py

* Update xcomposer2.py
DefTruth authored Jun 13, 2024
1 parent 679572c commit da5ce97
Showing 8 changed files with 64 additions and 29 deletions.
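One user-visible effect of this commit is that a PIL.Image.Image can now be passed to the VL pipeline directly, instead of only a URL or local path. A minimal sketch of that usage, assuming lmdeploy's standard pipeline entry point (the model name and image file here are placeholders, not part of this commit):

    from PIL import Image

    from lmdeploy import pipeline

    pipe = pipeline('OpenGVLab/InternVL-Chat-V1-5')

    # Previously images had to be given as str (URL or local path);
    # after this commit the loaded PIL object is accepted as-is.
    image = Image.open('photo.jpg')
    response = pipe(('describe this image', image))
    print(response)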
23 changes: 16 additions & 7 deletions lmdeploy/vl/model/internvl.py
@@ -106,15 +106,26 @@ def build_model(self):
             no_split_module_classes=['InternVisionEncoderLayer'],
             dtype=torch.half)
 
-        self.model = model
+        # We need eval mode to freeze the weights in the model and
+        # thus avoid randomness during inference.
+        self.model = model.eval()
         self.config = config
 
         if getattr(self.config, 'dynamic_image_size', False):
             logger.info('using InternVL-Chat-V1-5 vision preprocess')
-            MEAN = (123.675, 116.28, 103.53)
-            STD = (58.395, 57.12, 57.375)
+            MEAN = (0.485, 0.456, 0.406)
+            STD = (0.229, 0.224, 0.225)
             import torchvision.transforms as T
-            self.transform = T.Compose([T.Normalize(mean=MEAN, std=STD)])
+            from torchvision.transforms.functional import InterpolationMode
+            input_size = self.config.vision_config.image_size
+            self.transform = T.Compose([
+                T.Lambda(lambda img: img.convert('RGB')
+                         if img.mode != 'RGB' else img),
+                T.Resize((input_size, input_size),
+                         interpolation=InterpolationMode.BICUBIC),
+                T.ToTensor(),
+                T.Normalize(mean=MEAN, std=STD)
+            ])
             self._forward_func = self._forward_v1_5
         else:
             self.image_processor = CLIPImageProcessor.from_pretrained(
@@ -123,15 +134,14 @@ def build_model(self):
 
     def _preprocess_v1_5(self, images: List[Image]):
         outputs = []
-        import torchvision.transforms.functional as F
         for image in images:
             out = dynamic_preprocess(
                 image,
                 min_num=self.config.min_dynamic_patch,
                 max_num=self.config.max_dynamic_patch,
                 image_size=self.config.vision_config.image_size,
                 use_thumbnail=self.config.use_thumbnail)
-            out = [F.pil_to_tensor(x).half() for x in out]
+            out = [self.transform(x) for x in out]
             out = torch.stack(out)  # (patch) x c x h x w
             outputs.append(out)
         return outputs
@@ -142,7 +152,6 @@ def _forward_v1_5(self, images: List[Image]):
         split = [x.shape[0] for x in outputs]
         outputs = torch.cat(outputs, dim=0)
         outputs = outputs.to(self.model.device, dtype=torch.float16)
-        outputs = self.transform(outputs)
         outputs = self.model.extract_feature(outputs)
         outputs = torch.split(outputs, split, dim=0)
         outputs = [x.reshape(-1, x.shape[-1]) for x in outputs]
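A note on the MEAN/STD change above: the new values are the standard ImageNet statistics in [0, 1] space, which is what T.Normalize must use now that T.ToTensor() scales pixels by 1/255 first; the old values were the same statistics at 0-255 pixel scale, applied without that scaling. A standalone sanity check of the equivalence (not part of the commit):

    import torch

    # ImageNet statistics in [0, 1] space, used after T.ToTensor()
    MEAN_01 = torch.tensor((0.485, 0.456, 0.406))
    STD_01 = torch.tensor((0.229, 0.224, 0.225))

    # The previous values, expressed at 0-255 pixel scale
    MEAN_255 = torch.tensor((123.675, 116.28, 103.53))
    STD_255 = torch.tensor((58.395, 57.12, 57.375))

    # Dividing the 0-255 statistics by 255 recovers the [0, 1] ones,
    # so the fix aligns the constants with ToTensor's scaling.
    assert torch.allclose(MEAN_255 / 255, MEAN_01, atol=1e-6)
    assert torch.allclose(STD_255 / 255, STD_01, atol=1e-6)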
6 changes: 3 additions & 3 deletions lmdeploy/vl/model/internvl_llava.py
@@ -127,9 +127,9 @@ def build_model(self):
             no_split_module_classes=['InternVisionEncoderLayer'],
             dtype=torch.half)
 
-        self.model = model.model
-        self.vision_tower = model.model.vision_tower
-        self.mm_projector = model.model.mm_projector
+        self.model = model.model.eval()
+        self.vision_tower = model.model.vision_tower.eval()
+        self.mm_projector = model.model.mm_projector.eval()
 
     def encode_images(self, images: torch.Tensor) -> torch.Tensor:
         """encode images."""
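A note on the `.eval()` calls that recur through this commit: Module.eval() does not literally freeze parameters (that would be requires_grad_(False)); it recursively sets training=False, which switches layers such as Dropout to their deterministic inference behavior. That is the "randomness" the internvl.py comment refers to. A standalone illustration using plain PyTorch (not lmdeploy code):

    import torch
    import torch.nn as nn

    # A toy module with a stochastic layer, standing in for the
    # vision encoders patched in this commit.
    net = nn.Sequential(nn.Linear(32, 32), nn.Dropout(p=0.5))
    x = torch.randn(1, 32)

    net.train()  # training mode: Dropout is active, outputs vary
    assert not torch.equal(net(x), net(x))  # differs with high probability

    net.eval()   # inference mode: Dropout becomes a no-op
    with torch.no_grad():  # the vision forwards also run under no_grad
        assert torch.equal(net(x), net(x))  # now deterministic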
6 changes: 3 additions & 3 deletions lmdeploy/vl/model/llava.py
@@ -124,9 +124,9 @@ def build_model(self):
             no_split_module_classes=['CLIPEncoderLayer'],
             dtype=torch.half)
 
-        self.model = model.model
-        self.vision_tower = model.model.vision_tower.half()
-        self.mm_projector = model.model.mm_projector.half()
+        self.model = model.model.eval()
+        self.vision_tower = model.model.vision_tower.half().eval()
+        self.mm_projector = model.model.mm_projector.half().eval()
 
     def encode_images(self, images: torch.Tensor) -> torch.Tensor:
         """encode images."""
2 changes: 1 addition & 1 deletion lmdeploy/vl/model/minicpmv.py
@@ -49,7 +49,7 @@ def build_model(self):
         model.resampler.pos_embed = model.resampler.pos_embed.to(
             device=model.resampler.proj.device)
         self.config = config
-        self.model = model
+        self.model = model.eval()
 
         if hasattr(config, 'vision_config'):
             self._forward_func = self._forward_v2_5
2 changes: 1 addition & 1 deletion lmdeploy/vl/model/qwen.py
@@ -60,7 +60,7 @@ def build_model(self):
             no_split_module_classes=['VisualAttentionBlock'],
             dtype=torch.half)
 
-        self.model = model.transformer.visual
+        self.model = model.transformer.visual.eval()
 
     @torch.no_grad()
     def forward(self, images: List[Image]) -> List[torch.Tensor]:
2 changes: 1 addition & 1 deletion lmdeploy/vl/model/xcomposer2.py
@@ -99,7 +99,7 @@ def build_model(self):
             device_map['plora_glb_GN'], lambda x:
             (x[0].to(device=device_map['plora_glb_GN']), ))
 
-        self.model = model
+        self.model = model.eval()
 
     def _forward_7b(self, images: List[Image]) -> List[torch.Tensor]:
         """internlm-xcomposer2-7b vit forward."""
47 changes: 35 additions & 12 deletions lmdeploy/vl/templates.py
@@ -3,6 +3,7 @@
 from typing import Dict, List, Tuple, Union
 
 import PIL
+import PIL.Image
 
 from lmdeploy.model import BaseModel
 from lmdeploy.utils import get_hf_config_content, get_logger
@@ -38,15 +39,29 @@ def prompt_to_messages(self, prompt: VLPromptType):
             images = [images]
         messages['content'][0]['text'] = prompt
         for image in images:
+            # 'image_url': means url or local path to image.
+            # 'image_data': means PIL.Image.Image object.
             if isinstance(image, str):
                 image = load_image(image)
-            image_base64_data = encode_image_base64(image)
-            item = {
-                'type': 'image_url',
-                'image_url': {
-                    'url': f'data:image/jpeg;base64,{image_base64_data}'
+                image_base64_data = encode_image_base64(image)
+                item = {
+                    'type': 'image_url',
+                    'image_url': {
+                        'url':
+                        f'data:image/jpeg;base64,{image_base64_data}'
+                    }
                 }
-            }
+            elif isinstance(image, PIL.Image.Image):
+                item = {
+                    'type': 'image_data',
+                    'image_data': {
+                        'data': image
+                    }
+                }
+            else:
+                raise ValueError(
+                    'image should be a str(url/path) or PIL.Image.Image')
 
             messages['content'].append(item)
 
         return [messages]
@@ -61,14 +76,18 @@ async def async_collect_pil_images(
         if role != 'user' or isinstance(content, str):
             continue
         for item in content:
-            if item['type'] != 'image_url':
-                continue
-            url = item['image_url']['url']
-            images.append(url)
+            # 'image_url': means url or local path to image.
+            # 'image_data': means PIL.Image.Image object.
+            if item['type'] == 'image_url':
+                url = item['image_url']['url']
+                images.append(url)
+            elif item['type'] == 'image_data':
+                data = item['image_data']['data']
+                images.append(data)
 
     def _inner_call(i, images):
-        url = images[i]
-        images[i] = load_image(url)
+        url_or_data = images[i]
+        images[i] = load_image(url_or_data)
 
     await asyncio.gather(*[
         asyncio.get_event_loop().run_in_executor(
@@ -95,8 +114,12 @@ def convert_messages(self, messages, sequence_start=True):
                 continue
             num_images = 0
             for item in content:
+                # 'image_url': means url or local path to image.
+                # 'image_data': means PIL.Image.Image object.
                 if item['type'] == 'image_url':
                     num_images += 1
+                elif item['type'] == 'image_data':
+                    num_images += 1
                 elif item['type'] == 'text':
                     prompt = item['text']
             # if IMAGE_TOKEN in user prompt, use user custom prompt instead
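The commit message also mentions switching the base64 image buffer from JPEG to PNG (note the data URI above still says image/jpeg). The encoding change itself is not visible in the hunks shown here; as a sketch of the idea only, with a hypothetical standalone function name rather than the library's actual encode_image_base64:

    import base64
    from io import BytesIO

    from PIL import Image


    def encode_image_base64_png(image: Image.Image) -> str:
        """Hypothetical sketch: base64-encode a PIL image as PNG.

        PNG is lossless and preserves alpha channels, so images
        survive the base64 round-trip without JPEG's degradation.
        """
        buffer = BytesIO()
        image.save(buffer, format='PNG')
        return base64.b64encode(buffer.getvalue()).decode('utf-8')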
5 changes: 4 additions & 1 deletion lmdeploy/vl/utils.py
@@ -20,8 +20,10 @@ def load_image_from_base64(image: Union[bytes, str]) -> Image.Image:
     return Image.open(BytesIO(base64.b64decode(image)))
 
 
-def load_image(image_url: str) -> Image.Image:
+def load_image(image_url: Union[str, Image.Image]) -> Image.Image:
     """load image from url, local path or openai GPT4V."""
+    if isinstance(image_url, Image.Image):
+        return image_url
 
     FETCH_TIMEOUT = int(os.environ.get('LMDEPLOY_FETCH_TIMEOUT', 10))
     headers = {
@@ -40,6 +42,7 @@ def load_image(image_url: str) -> Image.Image:
     elif image_url.startswith('data:image'):
         img = load_image_from_base64(image_url.split(',')[1])
     else:
+        # Load image from local path
         img = Image.open(image_url)
 
     return img
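With the change above, load_image is effectively idempotent: an already-loaded image passes straight through, which is what lets templates.py funnel PIL objects and URL strings down the same collection path. A quick standalone check (assumes the module path lmdeploy.vl.utils):

    from PIL import Image

    from lmdeploy.vl.utils import load_image

    img = Image.new('RGB', (64, 64))

    # PIL objects are returned unchanged...
    assert load_image(img) is img

    # ...while strings still take the URL / base64 / local-path
    # branches, e.g. load_image('photo.jpg').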
