diff --git a/ape/layers/multi_scale_deform_attn.py b/ape/layers/multi_scale_deform_attn.py
index 4e10027..4e12e8f 100644
--- a/ape/layers/multi_scale_deform_attn.py
+++ b/ape/layers/multi_scale_deform_attn.py
@@ -151,6 +151,7 @@ def __init__(
         img2col_step: int = 64,
         dropout: float = 0.1,
         batch_first: bool = False,
+        pytorch_attn: bool = False,
     ):
         super().__init__()
         if embed_dim % num_heads != 0:
@@ -184,6 +185,8 @@ def __init__(
 
         self.init_weights()
 
+        self.pytorch_attn = pytorch_attn
+
     def init_weights(self):
         """
         Default initialization for Parameters of Module.
@@ -314,7 +317,7 @@ def forward(
         )
 
         # the original impl for fp32 training
-        if torch.cuda.is_available() and value.is_cuda:
+        if torch.cuda.is_available() and value.is_cuda and not self.pytorch_attn:
             if torch.jit.is_scripting() or torch.jit.is_tracing():
                 output = torch.ops.ape.ms_deform_attn_forward(
                     # value.to(torch.float32),
diff --git a/ape/modeling/ape_deta/deformable_transformer.py b/ape/modeling/ape_deta/deformable_transformer.py
index 29d41a1..60d8894 100644
--- a/ape/modeling/ape_deta/deformable_transformer.py
+++ b/ape/modeling/ape_deta/deformable_transformer.py
@@ -27,6 +27,7 @@ def __init__(
         post_norm: bool = False,
         num_feature_levels: int = 4,
         use_act_checkpoint: bool = False,
+        pytorch_attn=False,
     ):
         super(DeformableDetrTransformerEncoder, self).__init__(
             transformer_layers=BaseTransformerLayer(
@@ -36,6 +37,7 @@ def __init__(
                     dropout=attn_dropout,
                     batch_first=True,
                     num_levels=num_feature_levels,
+                    pytorch_attn=pytorch_attn,
                 ),
                 ffn=FFN(
                     embed_dim=embed_dim,
@@ -106,6 +108,7 @@ def __init__(
         return_intermediate: bool = True,
         num_feature_levels: int = 4,
         use_act_checkpoint: bool = False,
+        pytorch_attn=False,
     ):
         super(DeformableDetrTransformerDecoder, self).__init__(
             transformer_layers=BaseTransformerLayer(
@@ -122,6 +125,7 @@ def __init__(
                         dropout=attn_dropout,
                         batch_first=True,
                         num_levels=num_feature_levels,
+                        pytorch_attn=pytorch_attn,
                     ),
                 ],
                 ffn=FFN(
diff --git a/ape/modeling/ape_deta/deformable_transformer_vl.py b/ape/modeling/ape_deta/deformable_transformer_vl.py
index e642746..17681da 100644
--- a/ape/modeling/ape_deta/deformable_transformer_vl.py
+++ b/ape/modeling/ape_deta/deformable_transformer_vl.py
@@ -29,6 +29,7 @@ def __init__(
         num_feature_levels: int = 4,
         vl_layer=None,
         use_act_checkpoint=False,
+        pytorch_attn=False,
     ):
         super(DeformableDetrTransformerEncoderVL, self).__init__(
             transformer_layers=BaseTransformerLayer(
@@ -38,6 +39,7 @@ def __init__(
                     dropout=attn_dropout,
                     batch_first=True,
                     num_levels=num_feature_levels,
+                    pytorch_attn=pytorch_attn,
                 ),
                 ffn=FFN(
                     embed_dim=embed_dim,
@@ -122,6 +124,7 @@ def __init__(
         num_feature_levels: int = 4,
         use_act_checkpoint: bool = False,
         look_forward_twice: bool = False,
+        pytorch_attn=False,
     ):
         super(DeformableDetrTransformerDecoderVL, self).__init__(
             transformer_layers=BaseTransformerLayer(
@@ -138,6 +141,7 @@ def __init__(
                         dropout=attn_dropout,
                         batch_first=True,
                         num_levels=num_feature_levels,
+                        pytorch_attn=pytorch_attn,
                     ),
                 ],
                 ffn=FFN(
diff --git a/demo/app.py b/demo/app.py
index 40307e6..d3ffc1f 100644
--- a/demo/app.py
+++ b/demo/app.py
@@ -24,8 +24,8 @@
 
 this_dir = path.dirname(path.abspath(__file__))
 
-os.system("git clone https://github.com/shenyunhang/APE.git")
-os.system("python3.10 -m pip install -e APE/")
+# os.system("git clone https://github.com/shenyunhang/APE.git")
+# os.system("python3.10 -m pip install -e APE/")
 
 example_list = [
     [
@@ -83,7 +83,7 @@
     [
         this_dir + "/examples/Transformers.webp",
         "Optimus Prime",
-        0.08,
+        0.11,
         ["object detection", "instance segmentation"],
     ],
 ]
@@ -296,6 +296,8 @@ def load_APE_A():
         "model.model_language.cache_dir=''",
         "model.model_vision.select_box_nums_for_evaluation=500",
         "model.model_vision.backbone.net.xattn=False",
+        "model.model_vision.transformer.encoder.pytorch_attn=True",
+        "model.model_vision.transformer.decoder.pytorch_attn=True",
     ]
     if running_device == "cpu":
         args.opts += [
@@ -342,6 +344,8 @@ def load_APE_B():
         "model.model_vision.select_box_nums_for_evaluation=500",
         "model.model_vision.text_feature_bank_reset=True",
         "model.model_vision.backbone.net.xattn=False",
+        "model.model_vision.transformer.encoder.pytorch_attn=True",
+        "model.model_vision.transformer.decoder.pytorch_attn=True",
     ]
     if running_device == "cpu":
         args.opts += [
@@ -388,6 +392,8 @@ def load_APE_C():
         "model.model_vision.select_box_nums_for_evaluation=500",
         "model.model_vision.text_feature_bank_reset=True",
         "model.model_vision.backbone.net.xattn=False",
+        "model.model_vision.transformer.encoder.pytorch_attn=True",
+        "model.model_vision.transformer.decoder.pytorch_attn=True",
     ]
     if running_device == "cpu":
         args.opts += [
@@ -434,6 +440,8 @@ def load_APE_D():
         "model.model_vision.select_box_nums_for_evaluation=500",
         "model.model_vision.text_feature_bank_reset=True",
         "model.model_vision.backbone.net.xattn=False",
+        "model.model_vision.transformer.encoder.pytorch_attn=True",
+        "model.model_vision.transformer.decoder.pytorch_attn=True",
     ]
     if running_device == "cpu":
         args.opts += [
@@ -577,7 +585,7 @@ def APE_D_tab():
             )
 
             score_threshold = gr.Slider(
-                label="Score Threshold", minimum=0.01, maximum=1.0, value=0.3, step=0.01
+                label="Score Threshold", minimum=0.01, maximum=1.0, value=0.1, step=0.01
             )
 
             output_type = gr.CheckboxGroup(
@@ -625,7 +633,7 @@ def comparison_tab():
             )
 
             score_threshold = gr.Slider(
-                label="Score Threshold", minimum=0.01, maximum=1.0, value=0.2, step=0.01
+                label="Score Threshold", minimum=0.01, maximum=1.0, value=0.1, step=0.01
             )
 
             output_type = gr.CheckboxGroup(
diff --git a/demo/pre-requirements.txt b/demo/pre-requirements.txt
index 0adc281..bda06b1 100644
--- a/demo/pre-requirements.txt
+++ b/demo/pre-requirements.txt
@@ -1,6 +1,4 @@
-opencv-python
-torch==1.12.1
-torchvision
-transformers==4.32.1
-scipy
-cython
+--index-url https://download.pytorch.org/whl/cu118
+torch==2.0.1
+torchvision==0.15.2
+torchaudio==2.0.2
diff --git a/demo/requirements.txt b/demo/requirements.txt
index e9802f3..c1086f9 100644
--- a/demo/requirements.txt
+++ b/demo/requirements.txt
@@ -1,7 +1,6 @@
+transformers
+cython
 opencv-python
-torch==1.12.1
-torchvision
-transformers==4.32.1
 scipy
 einops
 lvis
diff --git a/requirements.txt b/requirements.txt
index 51a38c2..526c0cc 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,7 +1,8 @@
-opencv-python
 torch==1.12.1
 torchvision
 transformers==4.32.1
+cython
+opencv-python
 scipy
 einops
 lvis