Commit

fix import error
shenyunhang committed May 8, 2024
1 parent e79b148 commit 3bc4e5c
Showing 10 changed files with 23 additions and 18 deletions.
4 changes: 2 additions & 2 deletions ape/data/dataset_mapper.py
@@ -5,7 +5,7 @@
from detectron2.data import transforms as T
from detectron2.data.dataset_mapper import DatasetMapper as DatasetMapper_d2

-from . import detection_utils as utils_sota
+from . import detection_utils as utils_ape

"""
This file contains the default mapping that's applied to "dataset dicts".
@@ -33,7 +33,7 @@ class DatasetMapper_ape(DatasetMapper_d2):

def __init__(self, cfg, is_train: bool = True):
super().__init__(cfg, is_train)
-augmentations = utils_sota.build_augmentation(cfg, is_train)
+augmentations = utils_ape.build_augmentation(cfg, is_train)
self.augmentations = T.AugmentationList(augmentations)

logger = logging.getLogger(__name__)
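For context, a minimal sketch of the mapper pattern this file follows, using only detectron2's public API; detectron2's own detection_utils stands in here for the APE-local module, and the class name is illustrative:

import logging

from detectron2.data import detection_utils as utils  # stand-in for ape.data.detection_utils
from detectron2.data import transforms as T
from detectron2.data.dataset_mapper import DatasetMapper as DatasetMapper_d2


class DatasetMapperSketch(DatasetMapper_d2):
    """Reuse the detectron2 mapper but rebuild the augmentation list ourselves."""

    def __init__(self, cfg, is_train: bool = True):
        super().__init__(cfg, is_train)
        # The commit only renames the import alias (utils_sota -> utils_ape);
        # the call itself is unchanged.
        augmentations = utils.build_augmentation(cfg, is_train)
        self.augmentations = T.AugmentationList(augmentations)
        logging.getLogger(__name__).info("Augmentations used: %s", augmentations)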
8 changes: 4 additions & 4 deletions ape/data/dataset_mapper_copypaste.py
@@ -17,7 +17,7 @@
from detectron2.data.detection_utils import convert_image_to_rgb
from detectron2.structures import BitMasks, Boxes, Instances

-from . import detection_utils as utils_sota
+from . import detection_utils as utils_ape
from . import mapper_utils

"""
@@ -124,10 +124,10 @@ def __init__(

@classmethod
def from_config(cls, cfg, is_train: bool = True):
-augs = utils_sota.build_augmentation(cfg, is_train)
+augs = utils_ape.build_augmentation(cfg, is_train)
augs_d2 = utils.build_augmentation(cfg, is_train)
-augs_aa = utils_sota.build_augmentation_aa(cfg, is_train)
-augs_lsj = utils_sota.build_augmentation_lsj(cfg, is_train)
+augs_aa = utils_ape.build_augmentation_aa(cfg, is_train)
+augs_lsj = utils_ape.build_augmentation_lsj(cfg, is_train)
if cfg.INPUT.CROP.ENABLED and is_train:
raise NotImplementedError("cfg.INPUT.CROP.ENABLED is not supported yet")
augs.insert(0, T.RandomCrop(cfg.INPUT.CROP.TYPE, cfg.INPUT.CROP.SIZE))
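In the from_config hunk above, several augmentation pipelines are built side by side: APE's default, detectron2's default, and the aa and lsj variants. build_augmentation_aa and build_augmentation_lsj are APE-local helpers whose bodies are not part of this diff; purely as an illustration, a large-scale-jitter (lsj) pipeline in the spirit of detectron2's LSJ configs could look like this:

from detectron2.data import detection_utils as utils
from detectron2.data import transforms as T


def build_augmentation_lsj_sketch(cfg, is_train: bool, image_size: int = 1024):
    # Large-scale jitter: random rescale in [0.1, 2.0], then a fixed-size crop.
    if not is_train:
        return utils.build_augmentation(cfg, is_train)
    return [
        T.ResizeScale(
            min_scale=0.1,
            max_scale=2.0,
            target_height=image_size,
            target_width=image_size,
        ),
        T.FixedSizeCrop(crop_size=(image_size, image_size)),
    ]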
1 change: 0 additions & 1 deletion ape/data/mapper_utils.py
@@ -380,7 +380,6 @@ def copypaste(dataset_dict, dataset_dict_bg, image_format, instance_mask_format)
return None, None


-# from SotA-T/ape/data/datasets/coco.py
def maybe_load_annotation_from_file(record, meta=None, extra_annotation_keys=None):

file_name = record["file_name"]
9 changes: 3 additions & 6 deletions ape/layers/zero_shot_fc.py
@@ -6,9 +6,6 @@
from torch import nn
from torch.nn import functional as F

-# from sota.modeling.text import build_clip_text_encoder, get_clip_embeddings
-# from ..modeling.text import build_clip_text_encoder, get_clip_embeddings

logger = logging.getLogger(__name__)


@@ -54,7 +51,7 @@ def __init__(
torch.nn.init.normal_(self.linear.weight, std=0.01)

if len(zs_vocabulary) > 0:
-from sota.modeling.text import get_clip_embeddings
+from ape.modeling.text import get_clip_embeddings

logger.info("Generating weight for " + zs_vocabulary)
zs_vocabulary = zs_vocabulary.split(",")
@@ -67,7 +64,7 @@ def __init__(
elif zs_weight_path == "zeros":
zs_weight = torch.zeros((zs_weight_dim, num_classes))
elif zs_weight_path == "online":
-from sota.modeling.text import build_clip_text_encoder
+from ape.modeling.text import build_clip_text_encoder

zs_weight = torch.zeros((zs_weight_dim, num_classes))
self.text_encoder = build_clip_text_encoder(text_model, pretrain=True)
@@ -111,7 +108,7 @@ def forward(self, x, classifier=None):
x = self.linear(x)
if classifier is not None:
if isinstance(classifier, str):
-from sota.modeling.text import get_clip_embeddings
+from ape.modeling.text import get_clip_embeddings

zs_weight = get_clip_embeddings(
self.text_encoder, classifier, prompt="", device=x.device
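All three hunks in this file retarget lazy imports from the removed sota namespace to ape.modeling.text; the surrounding logic is untouched. As a rough, self-contained sketch of the zero-shot head idea (projected visual features scored against text embeddings), with the shapes and normalization chosen for illustration rather than copied from this file:

import torch
from torch import nn
from torch.nn import functional as F


class ZeroShotFCSketch(nn.Module):
    def __init__(self, in_dim: int, zs_weight_dim: int = 512):
        super().__init__()
        self.linear = nn.Linear(in_dim, zs_weight_dim)
        torch.nn.init.normal_(self.linear.weight, std=0.01)

    def forward(self, x: torch.Tensor, zs_weight: torch.Tensor) -> torch.Tensor:
        # x: (N, in_dim) region features.
        # zs_weight: (zs_weight_dim, num_classes) text embeddings, e.g. from a CLIP text encoder.
        x = F.normalize(self.linear(x), p=2, dim=1)
        zs_weight = F.normalize(zs_weight, p=2, dim=0)
        return x @ zs_weight  # (N, num_classes) cosine-similarity logits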
2 changes: 1 addition & 1 deletion datasets/tools/odinw/convert.py
@@ -6,7 +6,7 @@

from detectron2.data import MetadataCatalog

-import sota_t
+import ape


print(MetadataCatalog.keys())
2 changes: 1 addition & 1 deletion datasets/tools/openimages2coco/utils.py
@@ -9,7 +9,7 @@
from tqdm import tqdm

from detectron2.data.detection_utils import read_image
-from sota.data.mapper_utils import mask_to_polygons
+from ape.data.mapper_utils import mask_to_polygons


def csvread(file):
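mask_to_polygons now comes from ape.data.mapper_utils. As an illustrative stand-in only (not necessarily how APE implements it), a helper of this kind typically traces a binary mask into COCO-style polygons with OpenCV:

import cv2
import numpy as np


def mask_to_polygons_sketch(mask: np.ndarray) -> list:
    # Simplified: outer contours only, holes are ignored.
    mask = np.ascontiguousarray(mask.astype(np.uint8))
    contours, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
    polygons = []
    for contour in contours:
        points = contour.flatten().tolist()
        if len(points) >= 6:  # a valid polygon needs at least three (x, y) points
            polygons.append(points)
    return polygons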
10 changes: 10 additions & 0 deletions demo/.gitattributes
@@ -0,0 +1,10 @@
+examples/094_56726435.jpg filter=lfs diff=lfs merge=lfs -text
+examples/199_3946193540.jpg filter=lfs diff=lfs merge=lfs -text
+examples/SolvayConference1927.jpg filter=lfs diff=lfs merge=lfs -text
+examples/TheGreatWall.jpg filter=lfs diff=lfs merge=lfs -text
+examples/Totoro01.png filter=lfs diff=lfs merge=lfs -text
+examples/Transformers.webp filter=lfs diff=lfs merge=lfs -text
+examples/013_438973263.jpg filter=lfs diff=lfs merge=lfs -text
+examples/Pisa.jpg filter=lfs diff=lfs merge=lfs -text
+examples/Terminator3.jpg filter=lfs diff=lfs merge=lfs -text
+examples/MatrixRevolutionForZion.jpg filter=lfs diff=lfs merge=lfs -text
2 changes: 1 addition & 1 deletion demo/pre-requirements.txt
@@ -1,4 +1,4 @@
--index-url https://download.pytorch.org/whl/cu118
-pytorch==2.2.1
+torch==2.2.1
torchvision==0.17.1
torchaudio==2.2.1
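The fix here is the package name: on PyPI and on the download.pytorch.org index the framework is distributed as torch, not pytorch, so the old pin fails against this index. A quick post-install sanity check (the exact version suffixes are what the cu118 index is expected to produce):

import torch
import torchaudio
import torchvision

print(torch.__version__)          # e.g. 2.2.1+cu118
print(torchvision.__version__)    # e.g. 0.17.1+cu118
print(torchaudio.__version__)     # e.g. 2.2.1+cu118
print(torch.cuda.is_available())  # False on CPU-only machines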
2 changes: 1 addition & 1 deletion tools/train_net.py
@@ -626,7 +626,6 @@ def main(args):

default_setup(cfg, args)

-setup_logger(cfg.train.output_dir, distributed_rank=comm.get_rank(), name="sota")
setup_logger(cfg.train.output_dir, distributed_rank=comm.get_rank(), name="ape")
setup_logger(cfg.train.output_dir, distributed_rank=comm.get_rank(), name="timm")

@@ -640,6 +639,7 @@ def main(args):
logger = logging.getLogger("ape")
logger.info("Model:\n{}".format(model))
model.to(cfg.train.device)
+model.to(torch.float16)
model = create_ddp_model(model)

ema.may_build_model_ema(cfg, model)
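Two independent changes land in this file: the logger for the retired sota namespace is dropped, and the model is cast to float16 before DDP wrapping. A minimal sketch of that launch sequence, using the same detectron2 utilities as the diff (the helper names and arguments here are illustrative):

import logging

import torch
from detectron2.engine.defaults import create_ddp_model
from detectron2.utils import comm
from detectron2.utils.logger import setup_logger


def setup_project_loggers(output_dir: str) -> None:
    # One logger per namespace actually used at runtime; the stale "sota" one is gone.
    for name in ("ape", "timm"):
        setup_logger(output_dir, distributed_rank=comm.get_rank(), name=name)


def prepare_model(model: torch.nn.Module, device: str) -> torch.nn.Module:
    logging.getLogger("ape").info("Model:\n%s", model)
    model.to(device)
    model.to(torch.float16)  # added in this commit: keep all weights in half precision
    return create_ddp_model(model)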
1 change: 0 additions & 1 deletion tools/train_net_fsdp.py
@@ -664,7 +664,6 @@ def main(args):

default_setup(cfg, args)

-setup_logger(cfg.train.output_dir, distributed_rank=comm.get_rank(), name="sota")
setup_logger(cfg.train.output_dir, distributed_rank=comm.get_rank(), name="ape")
setup_logger(cfg.train.output_dir, distributed_rank=comm.get_rank(), name="timm")

