diff --git a/trainers/cocoop.py b/trainers/cocoop.py
index 51508c88..bb6dced4 100644
--- a/trainers/cocoop.py
+++ b/trainers/cocoop.py
@@ -54,7 +54,7 @@ def forward(self, prompts, tokenized_prompts):
 
         # x.shape = [batch_size, n_ctx, transformer.width]
         # take features from the eot embedding (eot_token is the highest number in each sequence)
-        x = x[torch.arange(x.shape[0]), tokenized_prompts.argmax(dim=-1)] @ self.text_projection
+        x = x[torch.arange(x.shape[0]), tokenized_prompts.repeat(x.shape[0] // tokenized_prompts.shape[0], 1).argmax(dim=-1)] @ self.text_projection
 
         return x
 
@@ -127,28 +127,31 @@ def construct_prompts(self, ctx, prefix, suffix, label=None):
         # suffix: remaining tokens, with shape of (n_cls, *, ctx_dim)
 
         if label is not None:
-            prefix = prefix[label]
-            suffix = suffix[label]
+            prefix = prefix[:, label]
+            suffix = suffix[:, label]
 
         prompts = torch.cat(
             [
-                prefix,  # (dim0, 1, dim)
-                ctx,  # (dim0, n_ctx, dim)
-                suffix,  # (dim0, *, dim)
+                prefix,  # (batch, dim0, 1, dim)
+                ctx,  # (batch, dim0, n_ctx, dim)
+                suffix,  # (batch, dim0, *, dim)
             ],
-            dim=1,
+            dim=2,
         )
 
         return prompts
 
     def forward(self, im_features):
         prefix = self.token_prefix
+        prefix = prefix.unsqueeze(0).expand(im_features.shape[0], -1, -1, -1)  # (batch, n_cls, 1, dim)
         suffix = self.token_suffix
+        suffix = suffix.unsqueeze(0).expand(im_features.shape[0], -1, -1, -1)  # (batch, n_cls, *, dim)
         ctx = self.ctx  # (n_ctx, ctx_dim)
+        ctx = ctx.unsqueeze(0).expand(self.n_cls, -1, -1)  # (n_cls, n_ctx, ctx_dim)
         bias = self.meta_net(im_features)  # (batch, ctx_dim)
-        bias = bias.unsqueeze(1)  # (batch, 1, ctx_dim)
-        ctx = ctx.unsqueeze(0)  # (1, n_ctx, ctx_dim)
-        ctx_shifted = ctx + bias  # (batch, n_ctx, ctx_dim)
+        bias = bias.unsqueeze(1).unsqueeze(1).expand(-1, self.n_cls, -1, -1)  # (batch, n_cls, 1, ctx_dim)
+        ctx = ctx.unsqueeze(0).expand(im_features.shape[0], -1, -1, -1)  # (batch, n_cls, n_ctx, ctx_dim)
+        ctx_shifted = ctx + bias  # (batch, n_cls, n_ctx, ctx_dim)
 
         # Use instance-conditioned context tokens for all classes
         prompts = []
@@ -179,14 +182,13 @@ def forward(self, image, label=None):
         image_features = image_features / image_features.norm(dim=-1, keepdim=True)
 
         prompts = self.prompt_learner(image_features)
-
-        logits = []
-        for pts_i, imf_i in zip(prompts, image_features):
-            text_features = self.text_encoder(pts_i, tokenized_prompts)
-            text_features = text_features / text_features.norm(dim=-1, keepdim=True)
-            l_i = logit_scale * imf_i @ text_features.t()
-            logits.append(l_i)
-        logits = torch.stack(logits)
+
+        prompts_shape = prompts.shape
+        prompts = prompts.view(prompts_shape[0] * prompts_shape[1], prompts_shape[2], prompts_shape[3])
+        text_features = self.text_encoder(prompts, tokenized_prompts)
+        text_features = text_features / text_features.norm(dim=-1, keepdim=True)
+        text_features = text_features.view(prompts_shape[0], prompts_shape[1], text_features.shape[1])
+        logits = logit_scale * torch.bmm(image_features.unsqueeze(1), text_features.permute(0, 2, 1)).squeeze(1)
 
         if self.prompt_learner.training:
             return F.cross_entropy(logits, label)
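
The `repeat()` in the `TextEncoder` hunk only works because `view()` flattens the `(batch, n_cls, ...)` prompt tensor class-major: flattened row `b * n_cls + c` holds class `c` of image `b`, which is exactly the row order `tokenized_prompts.repeat(batch, 1)` produces. A minimal, self-contained sketch of that alignment (hypothetical shapes; `49407` stands in for CLIP's eot token id, not the trainer's real tensors):

```python
import torch

# Hypothetical shapes; 49407 mimics CLIP's eot token id, which is larger
# than every other token id in the sequence, so argmax(dim=-1) locates it.
batch, n_cls, seq_len = 4, 3, 77
tokenized_prompts = torch.randint(1, 400, (n_cls, seq_len))
tokenized_prompts[torch.arange(n_cls), torch.randint(5, seq_len, (n_cls,))] = 49407

# Tiling with repeat() puts class c of "image" b at row b * n_cls + c,
# matching view(batch * n_cls, ...) applied to a (batch, n_cls, ...) tensor.
tiled = tokenized_prompts.repeat(batch, 1)  # (batch * n_cls, seq_len)
for b in range(batch):
    for c in range(n_cls):
        assert torch.equal(tiled[b * n_cls + c], tokenized_prompts[c])

# Hence the eot position found for each flattened row belongs to the right class.
eot_pos = tiled.argmax(dim=-1).view(batch, n_cls)
assert torch.equal(eot_pos, tokenized_prompts.argmax(dim=-1).unsqueeze(0).expand(batch, -1))
```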
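
Likewise, the batched `bmm` path in the last hunk is a drop-in replacement for the per-image Python loop it removes. A quick equivalence sketch with random features (names and shapes are assumptions; no CLIP encoder involved):

```python
import torch

batch, n_cls, dim = 4, 10, 512
image_features = torch.randn(batch, dim)
text_features = torch.randn(batch, n_cls, dim)  # one prompt set per image
logit_scale = 100.0

# Old path: one (n_cls,) logit row per image, stacked to (batch, n_cls).
logits_loop = torch.stack(
    [logit_scale * imf_i @ tf_i.t() for imf_i, tf_i in zip(image_features, text_features)]
)

# New path: a single batched matmul over the class dimension.
# (batch, 1, dim) x (batch, dim, n_cls) -> (batch, 1, n_cls) -> (batch, n_cls)
logits_bmm = logit_scale * torch.bmm(
    image_features.unsqueeze(1), text_features.permute(0, 2, 1)
).squeeze(1)

assert torch.allclose(logits_loop, logits_bmm, atol=1e-3)
```

The trade-off is memory: the text encoder now processes all `batch * n_cls` prompts in one transformer pass instead of `n_cls` at a time, so activation memory grows with the image batch size.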