-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathanalogy_encoder.py
68 lines (57 loc) · 2.99 KB
/
analogy_encoder.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
"""
ADOBE CONFIDENTIAL
Copyright 2024 Adobe
All Rights Reserved.
NOTICE: All information contained herein is, and remains
the property of Adobe and its suppliers, if any. The intellectual
and technical concepts contained herein are proprietary to Adobe
and its suppliers and are protected by all applicable intellectual
property laws, including trade secret and copyright laws.
Dissemination of this information or reproduction of this material
is strictly forbidden unless prior written permission is obtained
from Adobe.
"""
import torch as th
from diffusers import ModelMixin
from transformers import AutoModel, SiglipVisionConfig, Dinov2Config
from transformers import SiglipVisionModel
from diffusers.configuration_utils import ConfigMixin, register_to_config
class AnalogyEncoder(ModelMixin, ConfigMixin):
    """Frozen DINOv2 + SigLIP image encoder that returns per-token embeddings.

    Two vision backbones are held as ``image_encoder_dino`` and
    ``image_encoder_siglip``; neither receives gradients. ``forward`` returns,
    for each backbone, its normalized final-layer tokens concatenated
    feature-wise with its first hidden state (the embedding-layer output).
    """

    @register_to_config
    def __init__(self, load_pretrained=False,
                 dino_config_dict=None, siglip_config_dict=None):
        """Build (or download) both encoders and freeze them.

        Args:
            load_pretrained: If True, download fp16 weights from the Hub
                checkpoints ``facebook/dinov2-large`` and
                ``google/siglip-large-patch16-256`` (SigLIP with SDPA
                attention); the config dicts are then ignored. If False,
                build both architectures from the config dicts via
                ``AutoModel.from_config``, which does not load weights.
            dino_config_dict: Dict passed to ``Dinov2Config.from_dict``;
                required when ``load_pretrained`` is False.
            siglip_config_dict: Dict passed to ``SiglipVisionConfig.from_dict``;
                required when ``load_pretrained`` is False.

        Raises:
            ValueError: If ``load_pretrained`` is False and either config
                dict is missing.
        """
        super().__init__()
        if load_pretrained:
            image_encoder_dino = AutoModel.from_pretrained(
                'facebook/dinov2-large', torch_dtype=th.float16)
            image_encoder_siglip = SiglipVisionModel.from_pretrained(
                "google/siglip-large-patch16-256", torch_dtype=th.float16,
                attn_implementation="sdpa")
        else:
            # Fail early with a clear message instead of an opaque error
            # raised from inside ``from_dict(None)``.
            if dino_config_dict is None or siglip_config_dict is None:
                raise ValueError(
                    "dino_config_dict and siglip_config_dict are required "
                    "when load_pretrained=False")
            image_encoder_dino = AutoModel.from_config(
                Dinov2Config.from_dict(dino_config_dict))
            image_encoder_siglip = AutoModel.from_config(
                SiglipVisionConfig.from_dict(siglip_config_dict))

        # Freeze both backbones: none of their parameters receive gradients.
        image_encoder_dino.requires_grad_(False)
        image_encoder_siglip.requires_grad_(False)
        # Module.to(memory_format=...) only affects 4D parameters/buffers.
        self.image_encoder_dino = image_encoder_dino.to(memory_format=th.channels_last)
        self.image_encoder_siglip = image_encoder_siglip.to(memory_format=th.channels_last)

    @staticmethod
    def _normalize_by_first_token(embeds):
        """Divide every token of ``embeds`` by the L2 norm of token 0.

        Indexing ``[:, 0:1]`` treats dim 1 as the token axis; the norm is
        taken over the last dim and broadcast across all tokens.
        """
        first_token = embeds[:, 0:1]
        return embeds / th.linalg.vector_norm(first_token, dim=-1, keepdim=True)

    def dino_normalization(self, encoder_output):
        """Return DINOv2 final-layer tokens scaled by the norm of the first token.

        NOTE(review): assumes ``last_hidden_state`` is (batch, tokens, dim)
        with token 0 presumably the CLS token — confirm against the backbone.
        """
        return self._normalize_by_first_token(encoder_output.last_hidden_state)

    def siglip_normalization(self, encoder_output):
        """Prepend SigLIP's pooled output as an extra token, then normalize.

        The result has one more token than ``last_hidden_state``; every token
        (including the prepended pooled one) is divided by the L2 norm of the
        pooled token.
        """
        pooled_token = encoder_output.pooler_output[:, None, :]
        embeds = th.cat([pooled_token, encoder_output.last_hidden_state], dim=1)
        return self._normalize_by_first_token(embeds)

    def forward(self, dino_in, siglip_in):
        """Encode one input batch per backbone.

        Args:
            dino_in: Input for the DINOv2 encoder (presumably preprocessed
                pixel values — confirm with caller).
            siglip_in: Input for the SigLIP encoder (presumably preprocessed
                pixel values — confirm with caller).

        Returns:
            Tuple ``(dino_embd, siglip_embd)``. Each is that backbone's
            normalized final-layer tokens concatenated along the last
            (feature) dimension with its embedding-layer output
            (``hidden_states[0]``).
        """
        dino_out = self.image_encoder_dino(dino_in, output_hidden_states=True)
        dino_first = dino_out.hidden_states[0]
        dino_last = self.dino_normalization(dino_out)

        siglip_out = self.image_encoder_siglip(siglip_in, output_hidden_states=True)
        siglip_first = siglip_out.hidden_states[0]
        # siglip_normalization prepends one pooled token, so prepend the
        # token-mean here as well to keep token counts aligned for the
        # feature-wise concatenation below.
        siglip_first = th.cat(
            [th.mean(siglip_first, dim=1, keepdim=True), siglip_first], 1)
        siglip_last = self.siglip_normalization(siglip_out)

        dino_embd = th.cat([dino_last, dino_first], -1)
        siglip_embd = th.cat([siglip_last, siglip_first], -1)
        return dino_embd, siglip_embd