Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[wip] llava pir #988

Open
wants to merge 3 commits into
base: develop
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions paddlemix/examples/llava/pretrain.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,6 +132,14 @@ def main():
train_ds, max_length=data_args.max_length, processor=train_processor, tokenizer=tokenizer
)

# Describe the tensors the model is traced with, then convert the model to a
# static graph. Each row is (name, shape, dtype); -1 leaves the leading
# dimension unspecified.
# NOTE(review): 2048 is hard-coded here while the dataset above is built with
# data_args.max_length -- confirm the two agree.
input_spec = [
    paddle.static.InputSpec(name=spec_name, shape=spec_shape, dtype=spec_dtype)
    for spec_name, spec_shape, spec_dtype in (
        ('input_ids', [-1, 2048], 'int32'),
        ('attention_mask', [-1, 2048], 'bool'),
        ('labels', [-1, 2048], 'int32'),
        ('images', [-1, 3, 336, 336], 'float32'),
    )
]
model = paddle.jit.to_static(model, input_spec=input_spec)

# get Trainer
trainer = get_trainer(
pretrained_model_name_or_path=model_args.model_name_or_path,
Expand Down
11 changes: 11 additions & 0 deletions paddlemix/examples/qwen2_vl/qwen2vl_finetune.py
Original file line number Diff line number Diff line change
Expand Up @@ -584,6 +584,17 @@ def _freeze_params(module):
label_pad_token_id=IGNORE_INDEX,
)

# Input signature used to trace the model into a static graph; -1 leaves the
# leading dimension unspecified.
# NOTE(review): the sequence length (400) and patch count (1224) are
# hard-coded -- confirm they match what the data collator produces.
input_spec = [
    paddle.static.InputSpec(name='input_ids', shape=[-1, 400], dtype='int32'),
    paddle.static.InputSpec(name='attention_mask', shape=[-1, 400], dtype='bool'),
    paddle.static.InputSpec(name='labels', shape=[-1, 400], dtype='int32'),
    paddle.static.InputSpec(name='pixel_values', shape=[-1, 1224, 1176], dtype='float32'),
    # image_grid_thw holds one (t, h, w) grid triple per image, so its trailing
    # dimension is 3; the previous [-1, 1224, 1176] was copied from
    # pixel_values. NOTE(review): confirm rank and dtype against the collator.
    paddle.static.InputSpec(name='image_grid_thw', shape=[-1, 3], dtype='int32'),
]

model = paddle.jit.to_static(model, input_spec=input_spec)
print("--------------------paddle.jit.to_static successful------------------------")

trainer = Trainer(
model=model,
args=training_args,
Expand Down
6 changes: 5 additions & 1 deletion paddlemix/models/llava/language_model/llava_llama.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,11 @@ def init_train(self):
def get_model(self):
return self.llama

def forward(

def forward(self, input_ids=None, attention_mask=None, labels=None, images=None):
    """Delegate to ``_forward`` using a fixed, four-argument signature.

    Every argument is passed through by keyword and the result of
    ``_forward`` is returned unchanged. The arguments default to ``None``
    so that callers which omit some of them keep working, as they did
    before the implementation was renamed to ``_forward``.

    Args:
        input_ids: Forwarded as ``input_ids``.
        attention_mask: Forwarded as ``attention_mask``.
        labels: Forwarded as ``labels``.
        images: Forwarded as ``images``.

    Returns:
        Whatever ``self._forward`` returns.
    """
    # NOTE(review): the signature is presumably kept this narrow so the
    # layer can be traced with a fixed list of tensor inputs -- confirm.
    return self._forward(
        input_ids=input_ids,
        attention_mask=attention_mask,
        labels=labels,
        images=images,
    )

def _forward(
self,
input_ids: paddle.Tensor = None,
attention_mask: Optional[paddle.Tensor] = None,
Expand Down
22 changes: 21 additions & 1 deletion paddlemix/models/qwen2_vl/modeling_qwen2_vl.py
Original file line number Diff line number Diff line change
Expand Up @@ -1463,7 +1463,27 @@ def vision_forward(
inputs_embeds[video_mask] = video_embeds
return inputs_embeds

def forward(

def forward(self, input_ids=None, attention_mask=None, labels=None, pixel_values=None, image_grid_thw=None):
    """Delegate to ``_forward`` using a fixed, five-argument signature.

    Every argument is passed through by keyword and the result of
    ``_forward`` is returned unchanged. The arguments default to ``None``
    so that callers which omit some of them keep working, as they did
    before the implementation was renamed to ``_forward``.

    Args:
        input_ids: Forwarded as ``input_ids``.
        attention_mask: Forwarded as ``attention_mask``.
        labels: Forwarded as ``labels``.
        pixel_values: Forwarded as ``pixel_values``.
        image_grid_thw: Forwarded as ``image_grid_thw``.

    Returns:
        Whatever ``self._forward`` returns.
    """
    # NOTE(review): the signature is presumably kept this narrow so the
    # layer can be traced with a fixed list of tensor inputs -- confirm.
    return self._forward(
        input_ids=input_ids,
        attention_mask=attention_mask,
        labels=labels,
        pixel_values=pixel_values,
        image_grid_thw=image_grid_thw,
    )

def _forward(
self,
input_ids: paddle.Tensor = None, # [1, 400] sum 49356255
attention_mask: Optional[paddle.Tensor] = None, # [1, 400] sum 396
Expand Down
28 changes: 14 additions & 14 deletions tests/test_tipc/dygraph/dp/llava/benchmark_common/prepare.sh
Original file line number Diff line number Diff line change
Expand Up @@ -24,17 +24,17 @@ ln -s /root/.paddlemix/datasets/llava_bench_data ./
export http_proxy=agent.baidu.com:8188
export https_proxy=agent.baidu.com:8188

export PYTHONPATH=$(dirname "$PWD"):$PYTHONPATH
python -m pip install --upgrade pip -i https://mirror.baidu.com/pypi/simple
python -m pip install einops -i https://mirror.baidu.com/pypi/simple
python -m pip install -r ../requirements.txt
python -m pip install -e ../
python -m pip install --upgrade paddlenlp pybind11 regex sentencepiece tqdm visualdl attrdict easydict pyyaml -i https://mirror.baidu.com/pypi/simple
pip install -r ../paddlemix/appflow/requirements.txt
pip install -U ppdiffusers
bash ../build_paddle_env.sh
# python -m pip install https://paddle-wheel.bj.bcebos.com/develop/linux/linux-gpu-cuda11.8-cudnn8.6-mkl-gcc8.2-avx/paddlepaddle_gpu-0.0.0.post118-cp310-cp310-linux_x86_64.whl
python -m pip install paddlenlp==3.0.0b2
python -m pip install huggingface_hub==0.23.0
python -m pip list
cd -
# export PYTHONPATH=$(dirname "$PWD"):$PYTHONPATH
# python -m pip install --upgrade pip -i https://mirror.baidu.com/pypi/simple
# python -m pip install einops -i https://mirror.baidu.com/pypi/simple
# python -m pip install -r ../requirements.txt
# python -m pip install -e ../
# python -m pip install --upgrade paddlenlp pybind11 regex sentencepiece tqdm visualdl attrdict easydict pyyaml -i https://mirror.baidu.com/pypi/simple
# pip install -r ../paddlemix/appflow/requirements.txt
# pip install -U ppdiffusers
# bash ../build_paddle_env.sh
# # python -m pip install https://paddle-wheel.bj.bcebos.com/develop/linux/linux-gpu-cuda11.8-cudnn8.6-mkl-gcc8.2-avx/paddlepaddle_gpu-0.0.0.post118-cp310-cp310-linux_x86_64.whl
# python -m pip install paddlenlp==3.0.0b2
# python -m pip install huggingface_hub==0.23.0
# python -m pip list
# cd -
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,8 @@ function _train(){
export http_proxy=agent.baidu.com:8188
export https_proxy=agent.baidu.com:8188

export FLAGS_prim_all=true;export FLAGS_prim_enable_dynamic=true;export FLAGS_use_cinn=true;export MIN_GRAPH_SIZE=0;export FLAGS_prim_forward_blacklist="pd_op.dropout"

# Training stage
if [ ${train_stage} = "sft" ]; then
train_cmd="../paddlemix/tools/supervised_finetune.py \
Expand Down