diff --git a/paddlemix/examples/llava/pretrain.py b/paddlemix/examples/llava/pretrain.py index dc9669f87..4d9825e81 100755 --- a/paddlemix/examples/llava/pretrain.py +++ b/paddlemix/examples/llava/pretrain.py @@ -132,6 +132,14 @@ def main(): train_ds, max_length=data_args.max_length, processor=train_processor, tokenizer=tokenizer ) + input_spec = [ + paddle.static.InputSpec(name='input_ids', shape=[-1, 2048], dtype='int32'), + paddle.static.InputSpec(name='attention_mask', shape=[-1, 2048], dtype='bool'), + paddle.static.InputSpec(name='labels', shape=[-1, 2048], dtype='int32'), + paddle.static.InputSpec(name='images', shape=[-1, 3, 336, 336], dtype='float32'), + ] + model = paddle.jit.to_static(model, input_spec=input_spec) + # get Trainer trainer = get_trainer( pretrained_model_name_or_path=model_args.model_name_or_path, diff --git a/paddlemix/examples/qwen2_vl/qwen2vl_finetune.py b/paddlemix/examples/qwen2_vl/qwen2vl_finetune.py index 7456e4508..56fc58476 100644 --- a/paddlemix/examples/qwen2_vl/qwen2vl_finetune.py +++ b/paddlemix/examples/qwen2_vl/qwen2vl_finetune.py @@ -584,6 +584,17 @@ def _freeze_params(module): label_pad_token_id=IGNORE_INDEX, ) + input_spec = [ + paddle.static.InputSpec(name='input_ids', shape=[-1, 400], dtype='int32'), + paddle.static.InputSpec(name='attention_mask', shape=[-1, 400], dtype='bool'), + paddle.static.InputSpec(name='labels', shape=[-1, 400], dtype='int32'), + paddle.static.InputSpec(name='pixel_values', shape=[-1, 1224, 1176], dtype='float32'), + paddle.static.InputSpec(name='image_grid_thw', shape=[-1, 1224, 1176], dtype='int32'), + ] + + model = paddle.jit.to_static(model, input_spec=input_spec) + print("--------------------paddle.jit.to_static successful------------------------") + trainer = Trainer( model=model, args=training_args, diff --git a/paddlemix/models/llava/language_model/llava_llama.py b/paddlemix/models/llava/language_model/llava_llama.py index 8308f055d..223b5d9a8 100644 --- a/paddlemix/models/llava/language_model/llava_llama.py +++ b/paddlemix/models/llava/language_model/llava_llama.py @@ -72,7 +72,11 @@ def init_train(self): def get_model(self): return self.llama - def forward( + + def forward(self, input_ids, attention_mask, labels, images): + return self._forward(input_ids=input_ids, attention_mask=attention_mask, labels=labels, images=images) + + def _forward( self, input_ids: paddle.Tensor = None, attention_mask: Optional[paddle.Tensor] = None, diff --git a/paddlemix/models/qwen2_vl/modeling_qwen2_vl.py b/paddlemix/models/qwen2_vl/modeling_qwen2_vl.py index daa3ce821..ef67fcf41 100644 --- a/paddlemix/models/qwen2_vl/modeling_qwen2_vl.py +++ b/paddlemix/models/qwen2_vl/modeling_qwen2_vl.py @@ -1463,7 +1463,27 @@ def vision_forward( inputs_embeds[video_mask] = video_embeds return inputs_embeds - def forward( + + def forward(self, input_ids, attention_mask, labels, pixel_values, image_grid_thw): + return self._forward(input_ids=input_ids, attention_mask=attention_mask, labels=labels, pixel_values=pixel_values, image_grid_thw=image_grid_thw) + + # print('input_ids: ', input_ids.shape, input_ids.dtype) + # print('attention_mask: ', (attention_mask.shape, attention_mask.dtype) if attention_mask is not None else 'None') + # print('position_ids: ', (position_ids.shape, position_ids.dtype) if position_ids is not None else 'None') + # print('past_key_values: ', (past_key_values[0].shape, past_key_values[0].dtype) if past_key_values is not None else 'None') + # print('inputs_embeds: ', (inputs_embeds.shape, inputs_embeds.dtype) if inputs_embeds is not None else 'None') + # print('labels: ', (labels.shape, labels.dtype) if labels is not None else 'None') + # print('use_cache: ', use_cache) + # print('output_attentions: ', output_attentions) + # print('output_hidden_states: ', output_hidden_states) + # print('return_dict: ', return_dict) + # print('pixel_values: ', (pixel_values.shape, pixel_values.dtype) if pixel_values is not None else 'None') + # print('pixel_values_videos: ', (pixel_values_videos.shape, pixel_values_videos.dtype) if pixel_values_videos is not None else 'None') + # print('image_grid_thw: ', (image_grid_thw.shape, image_grid_thw.dtype) if image_grid_thw is not None else 'None') + # print('video_grid_thw: ', (video_grid_thw.shape, video_grid_thw.dtype) if video_grid_thw is not None else 'None') + # print('rope_deltas: ', (rope_deltas.shape, rope_deltas.dtype) if rope_deltas is not None else 'None') + + def _forward( self, input_ids: paddle.Tensor = None, # [1, 400] sum 49356255 attention_mask: Optional[paddle.Tensor] = None, # [1, 400] sum 396 diff --git a/tests/test_tipc/dygraph/dp/llava/benchmark_common/prepare.sh b/tests/test_tipc/dygraph/dp/llava/benchmark_common/prepare.sh index 16535a048..c6fc477fa 100644 --- a/tests/test_tipc/dygraph/dp/llava/benchmark_common/prepare.sh +++ b/tests/test_tipc/dygraph/dp/llava/benchmark_common/prepare.sh @@ -24,17 +24,17 @@ ln -s /root/.paddlemix/datasets/llava_bench_data ./ export http_proxy=agent.baidu.com:8188 export https_proxy=agent.baidu.com:8188 -export PYTHONPATH=$(dirname "$PWD"):$PYTHONPATH -python -m pip install --upgrade pip -i https://mirror.baidu.com/pypi/simple -python -m pip install einops -i https://mirror.baidu.com/pypi/simple -python -m pip install -r ../requirements.txt -python -m pip install -e ../ -python -m pip install --upgrade paddlenlp pybind11 regex sentencepiece tqdm visualdl attrdict easydict pyyaml -i https://mirror.baidu.com/pypi/simple -pip install -r ../paddlemix/appflow/requirements.txt -pip install -U ppdiffusers -bash ../build_paddle_env.sh -# python -m pip install https://paddle-wheel.bj.bcebos.com/develop/linux/linux-gpu-cuda11.8-cudnn8.6-mkl-gcc8.2-avx/paddlepaddle_gpu-0.0.0.post118-cp310-cp310-linux_x86_64.whl -python -m pip install paddlenlp==3.0.0b2 -python -m pip install huggingface_hub==0.23.0 -python -m pip list -cd - +# export PYTHONPATH=$(dirname "$PWD"):$PYTHONPATH +# python -m pip install --upgrade pip -i https://mirror.baidu.com/pypi/simple +# python -m pip install einops -i https://mirror.baidu.com/pypi/simple +# python -m pip install -r ../requirements.txt +# python -m pip install -e ../ +# python -m pip install --upgrade paddlenlp pybind11 regex sentencepiece tqdm visualdl attrdict easydict pyyaml -i https://mirror.baidu.com/pypi/simple +# pip install -r ../paddlemix/appflow/requirements.txt +# pip install -U ppdiffusers +# bash ../build_paddle_env.sh +# # python -m pip install https://paddle-wheel.bj.bcebos.com/develop/linux/linux-gpu-cuda11.8-cudnn8.6-mkl-gcc8.2-avx/paddlepaddle_gpu-0.0.0.post118-cp310-cp310-linux_x86_64.whl +# python -m pip install paddlenlp==3.0.0b2 +# python -m pip install huggingface_hub==0.23.0 +# python -m pip list +# cd - diff --git a/tests/test_tipc/dygraph/dp/llava/benchmark_common/run_benchmark.sh b/tests/test_tipc/dygraph/dp/llava/benchmark_common/run_benchmark.sh index 23c9b1f71..af9d2ec8a 100644 --- a/tests/test_tipc/dygraph/dp/llava/benchmark_common/run_benchmark.sh +++ b/tests/test_tipc/dygraph/dp/llava/benchmark_common/run_benchmark.sh @@ -121,6 +121,8 @@ function _train(){ export http_proxy=agent.baidu.com:8188 export https_proxy=agent.baidu.com:8188 + export FLAGS_prim_all=true;export FLAGS_prim_enable_dynamic=true;export FLAGS_use_cinn=true;export MIN_GRAPH_SIZE=0;export FLAGS_prim_forward_blacklist="pd_op.dropout" + #训练阶段 if [ ${train_stage} = "sft" ]; then train_cmd="../paddlemix/tools/supervised_finetune.py \