Skip to content

Commit

Permalink
Llava static (#508)
Browse files Browse the repository at this point in the history
  • Loading branch information
LokeZhou authored Jul 1, 2024
1 parent ed68f37 commit d83fc2e
Show file tree
Hide file tree
Showing 8 changed files with 667 additions and 4 deletions.
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -42,4 +42,4 @@ repos:
name: copyright_checker
entry: python .copyright.hook
language: system
files: \.(c|cc|cxx|cpp|cu|h|hpp|hxx|proto|xpu|kps|py|sh)$
files: \.(c|cc|cxx|cpp|cu|h|hpp|hxx|proto|xpu|kps|py|sh)$
59 changes: 59 additions & 0 deletions deploy/llava/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
# LLaVA

## 1. 模型介绍

[LLaVA](https://arxiv.org/pdf/2310.03744.pdf) 是基于大规模语言模型 llama 的视觉语言模型。支持多个多模态任务,包括零样本图像描述生成(Zero-shot Image Caption)、视觉问答(VQA)、细粒度视觉定位(Referring Expression Comprehension)等任务。

其性能优于其他模型,在多个任务上取得了更好的效果。

<p align="center">
<img src="https://github.com/haotian-liu/LLaVA/blob/main/images/llava_v1_5_radar.jpg" align="middle" width = "600" />
</p>

注:图片引用自[LLaVA](https://github.com/haotian-liu/LLaVA).


## 2. 安装依赖

* `paddlenlp_ops`依赖安装

```bash
git clone https://github.com/PaddlePaddle/PaddleNLP.git
cd PaddleNLP
pip install -e .
cd csrc
python setup_cuda.py install
```

* `fused_ln`需要安装[此目录](https://github.com/PaddlePaddle/PaddleNLP/tree/develop/model_zoo/gpt-3/external_ops)下的自定义OP, `python setup.py install`

## 3. 示例

### 3.1 转出静态图推理所需的视觉模型和语言模型

*`PaddleMIX`目录下,执行转换脚本,得到视觉模型部分静态图

```bash
#!/bin/bash

python deploy/llava/export_model.py \
--model_name_or_path "paddlemix/llava/llava-v1.5-7b" \
--save_path "./llava_static" \
--fp16
```


### 3.2 静态图推理

*`PaddleMIX`目录下,运行执行脚本,进行静态图推理

```bash
#!/bin/bash

python3.10 deploy/llava/run_static_predict.py --model_name_or_path "paddlemix/llava/llava-v1.5-7b" \
--image_file "https://bj.bcebos.com/v1/paddlenlp/models/community/GroundingDino/000000004505.jpg" \
--first_model_path "llava_static/encode_image/clip" \
--second_model_path "llava_static/encode_text/llama" \
--fp16

```
88 changes: 88 additions & 0 deletions deploy/llava/export_model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import argparse

import paddle

from deploy.llava.llama_inference_model import LlamaForClipInferenceModel
from paddlemix.auto import AutoConfigMIX, AutoModelMIX
from paddlemix.utils.log import logger


def export_encode_text(model, config, compute_dtype):

# save to static model
save_path = args.save_path + "/encode_text/llama"
model.to_static(save_path, config, compute_dtype)
logger.info(f"static model has been to {save_path}")


def export_encode_image(model, compute_dtype):

# convert to static graph with specific input description
model = paddle.jit.to_static(
model.encode_images,
input_spec=[
paddle.static.InputSpec(shape=[None, 3, 336, 336], dtype=compute_dtype), # images
],
)

# save to static model
save_path = args.save_path + "/encode_image/clip"
paddle.jit.save(model, save_path)
logger.info(f"static model has been to {save_path}")


if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument(
"--model_name_or_path",
default="paddlemix/llava/llava-v1.5-7b",
type=str,
help="The dir name of llava checkpoint.",
)
parser.add_argument(
"--save_path",
default="./llava_static",
type=str,
help="The saving path of static llava vision.",
)
parser.add_argument("--fp16", action="store_true")

args = parser.parse_args()

compute_dtype = "float16" if args.fp16 else "bfloat16"
if not paddle.amp.is_bfloat16_supported() and compute_dtype == "bfloat16":
logger.warning("bfloat16 is not supported on your device,change to float32")
compute_dtype = "float32"

model = AutoModelMIX.from_pretrained(args.model_name_or_path, dtype=compute_dtype)
vision_tower = model.get_vision_tower()
vision_tower.load_model()
model.eval()
export_encode_image(model, compute_dtype)

config = AutoConfigMIX.from_pretrained(args.model_name_or_path)
config.tensor_parallel_degree = 1
config.tensor_parallel_rank = 0
config.weight_only_quant_bits = -1
config.quant_type = None

model = LlamaForClipInferenceModel.from_pretrained(args.model_name_or_path, config=config)

model.to(dtype=compute_dtype)
model.eval()

export_encode_text(model, config, compute_dtype)
127 changes: 127 additions & 0 deletions deploy/llava/llama_inference_model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,127 @@
# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import paddle
from paddlenlp.experimental.transformers import LlamaForCausalLMInferenceModel


class LlamaForClipInferenceModel(LlamaForCausalLMInferenceModel):
"""
This class is 99% like LlamaForCausalLMInferenceModel.
Used only for llava's second part.
"""

@paddle.no_grad()
def generate_text_with_image_features(
self,
input_ids: paddle.Tensor,
image_features: paddle.Tensor,
img_pos: paddle.Tensor,
attention_mask=None,
position_ids=None,
penalty_score=None,
frequency_score=None,
presence_score=None,
min_length=None,
max_length=None,
temperature=None,
top_p=None,
eos_token_id=None,
seq_len_encoder=None,
seq_len_decoder=None,
step_idx=None,
stop_flags=None,
tgt_ids=None,
tgt_pos=None,
tgt_generation_mask=None,
pre_ids=None,
stop_nums=None,
cache_kvs=[],
**generate_kwargs
) -> paddle.Tensor:

inputs_embeds = self.llama.embed_tokens(input_ids)
for batch_idx, pos in enumerate(img_pos):
for idx, p in enumerate(pos):
index = paddle.arange(p[0], p[1]).unsqueeze(-1)
inputs_embeds[batch_idx] = paddle.scatter(inputs_embeds[batch_idx], index, image_features[idx])

outputs = self.generate(
inputs_embeds=inputs_embeds,
attention_mask=attention_mask,
position_ids=position_ids,
penalty_score=penalty_score,
frequency_score=frequency_score,
presence_score=presence_score,
min_length=min_length,
max_length=max_length,
temperature=temperature,
top_p=top_p,
eos_token_id=eos_token_id,
seq_len_encoder=seq_len_encoder,
seq_len_decoder=seq_len_decoder,
step_idx=step_idx,
stop_flags=stop_flags,
tgt_ids=tgt_ids,
tgt_pos=tgt_pos,
tgt_generation_mask=tgt_generation_mask,
pre_ids=pre_ids,
stop_nums=stop_nums,
cache_kvs=cache_kvs,
)
return outputs

def to_static(self, output_path: str, config: dict, compute_dtype: str):

cache_kvs_shapes = self.get_cache_kvs_shape(config, max_length=config.get("max_length", None))

input_spec = [
paddle.static.InputSpec(shape=[None, None], dtype="int32", name="inputs_ids"),
paddle.static.InputSpec(
shape=[None, None, None], dtype=compute_dtype, name="image_features"
), # image_features
paddle.static.InputSpec(shape=[None, None, 2], dtype="int64", name="img_pos"), # img_pos
paddle.static.InputSpec(
shape=[None, None, None, None], dtype="int64", name="attention_mask"
), # attention_mask
paddle.static.InputSpec(shape=[None, None], dtype="int64", name="position_ids"), # position_ids
paddle.static.InputSpec(shape=[None, 1], dtype="float32", name="penalty_score"), # penalty_score
paddle.static.InputSpec(shape=[None, 1], dtype="float32", name="frequency_score"), # frequency_score
paddle.static.InputSpec(shape=[None, 1], dtype="float32", name="presence_score"), # presence_score
paddle.static.InputSpec(shape=[None, 1], dtype="int64", name="min_length"), # min_decode_length
paddle.static.InputSpec(shape=[None, 1], dtype="int64", name="max_length"), # max_decode_length
paddle.static.InputSpec(shape=[None, 1], dtype="float32", name="temperature"), # temperature
paddle.static.InputSpec(shape=[None, 1], dtype="float32", name="top_p"), # top_p
paddle.static.InputSpec(shape=[None], dtype="int64", name="eos_token_id"), # eos_token_id
paddle.static.InputSpec(shape=[None, 1], dtype="int32", name="seq_len_encoder"), # seq_len_encoder
paddle.static.InputSpec(shape=[None, 1], dtype="int32", name="seq_len_decoder"), # seq_len_decoder
paddle.static.InputSpec(shape=[None, 1], dtype="int64", name="step_idx"), # step_idx
paddle.static.InputSpec(shape=[None, 1], dtype="bool", name="stop_flags"), # stop_flags
paddle.static.InputSpec(shape=[None, 1], dtype="int64", name="tgt_ids"), # tgt_ids
paddle.static.InputSpec(shape=[None, 1], dtype="int64", name="tgt_pos"), # tgt_pos
paddle.static.InputSpec(shape=[None, 1, 1, None], name="tgt_generation_mask"), # tgt_generation_mask
paddle.static.InputSpec(shape=[None, None], dtype="int64", name="pre_ids"), # pre_ids
paddle.static.InputSpec(shape=[1], dtype="int64", name="stop_nums"), # stop_nums
[
paddle.static.InputSpec(
shape=shape,
dtype=compute_dtype,
name="cache_kvs_{}".format(i),
)
for i, shape in enumerate(cache_kvs_shapes)
], # cache_kvs
]

model = paddle.jit.to_static(self.generate_text_with_image_features, input_spec=input_spec)
paddle.jit.save(model, output_path, skip_prune_program=True)
Loading

0 comments on commit d83fc2e

Please sign in to comment.