
Commit

update
Fei-Wang committed Mar 21, 2023
1 parent 20b0892 commit 8b22d27
Showing 17 changed files with 186 additions and 87 deletions.
File renamed without changes.
@@ -1,7 +1,7 @@
 transformers_pretrained_model_dir = 'bert-base-chinese'

 model = dict(
-    type='RMModel',
+    type='NamiBertForSequenceClassification',
     bert=dict(
         type='NamiAutoModel',
         model_type='BertForSequenceClassification',
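The model dict above is a registry-style config: nami's MODELS registry resolves the type key to a registered class (after this commit, NamiBertForSequenceClassification) and hands the remaining keys to its constructor. A minimal sketch of how such a config is typically consumed, assuming nami.registry.MODELS follows the mmengine-style Registry API (the exact build entry point is an assumption, not shown in this diff):

    # Sketch under assumptions: an mmengine-style Registry whose build()
    # pops cfg['type'], looks up the registered class, and instantiates it
    # with the remaining keys.
    from nami.registry import MODELS

    model_cfg = dict(
        type='NamiBertForSequenceClassification',
        bert=dict(
            type='NamiAutoModel',
            model_type='BertForSequenceClassification',
        ),
    )
    model = MODELS.build(model_cfg)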
2 changes: 1 addition & 1 deletion configs/_base_/models/gpt.py
@@ -7,7 +7,7 @@
 # )
 #
 model = dict(
-    type='GPTModel',
+    type='NamiGPT2LMHeadModel',
     gpt=dict(
         type='NamiAutoModel',
         model_type='GPT2LMHeadModel',
@@ -1,4 +1,4 @@
 _base_ = [
-    '../_base_/datasets/reward_model.py', '../_base_/models/reward_model.py',
+    '../_base_/datasets/bert.py', '../_base_/models/bert.py',
     '../_base_/schedules/default_schedule.py', '../_base_/default_runtime.py'
 ]
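Swapping the two _base_ paths is all the rename requires here, because the loader merges every listed file into a single config. A sketch of that composition, assuming an mmengine-style Config.fromfile (the experiment file name below is hypothetical):

    from mmengine.config import Config  # assumption: mmengine-style loader

    cfg = Config.fromfile('configs/bert_experiment.py')  # hypothetical path
    print(cfg.model.type)  # 'NamiBertForSequenceClassification' after this commit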
2 changes: 1 addition & 1 deletion src/nami/models/__init__.py
@@ -1 +1 @@
-from . import gpt, reward_model, utils
+from . import gpt2, bert, utils
7 changes: 7 additions & 0 deletions src/nami/models/bert/__init__.py
@@ -0,0 +1,7 @@
+from transformers import AutoConfig, AutoModelForSequenceClassification
+
+from .configuration_bert import NamiBertConfig
+from .modeling_bert import NamiBertForSequenceClassification
+
+AutoConfig.register("nami_bert", NamiBertConfig)
+AutoModelForSequenceClassification.register(NamiBertConfig, NamiBertForSequenceClassification)
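Importing nami.models.bert executes the two register() calls above, after which the stock transformers Auto* factories can resolve the custom model_type. A usage sketch, assuming NamiBertForSequenceClassification can be instantiated from its config alone:

    from transformers import AutoConfig, AutoModelForSequenceClassification
    import nami.models.bert  # noqa: F401  (side effect: runs register())

    config = AutoConfig.for_model('nami_bert')  # resolves to NamiBertConfig
    model = AutoModelForSequenceClassification.from_config(config)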
5 changes: 5 additions & 0 deletions src/nami/models/bert/configuration_bert.py
@@ -0,0 +1,5 @@
+from ..utils import Config
+
+
+class NamiBertConfig(Config):
+    model_type = 'nami_bert'
@@ -1,14 +1,11 @@
 from nami.registry import MODELS
-from ..utils import HFModel, Config
-
-
-class RMConfig(Config):
-    model_type = 'nami_rm'
+from .configuration_bert import NamiBertConfig
+from ..utils import HFModel


 @MODELS.register_module()
-class RMModel(HFModel):
-    config_class = RMConfig
+class NamiBertForSequenceClassification(HFModel):
+    config_class = NamiBertConfig

     def forward(
         self,
6 changes: 0 additions & 6 deletions src/nami/models/gpt/__init__.py

This file was deleted.

63 changes: 0 additions & 63 deletions src/nami/models/gpt/modeling_gpt.py

This file was deleted.

8 changes: 8 additions & 0 deletions src/nami/models/gpt2/__init__.py
@@ -0,0 +1,8 @@
+from transformers import AutoConfig, AutoModelForCausalLM
+
+from .configuration_gpt2 import NamiGPT2Config
+from .modeling_gpt2 import NamiGPT2LMHeadModel, NamiGPT2LMHeadModelWithValueHead
+
+AutoConfig.register("nami_gpt2", NamiGPT2Config)
+AutoModelForCausalLM.register(NamiGPT2Config, NamiGPT2LMHeadModel)
+AutoModelForCausalLM.register(NamiGPT2Config, NamiGPT2LMHeadModelWithValueHead)
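One caveat with the last two lines: an Auto* factory's mapping is keyed by config class, so both register calls target the same NamiGPT2Config slot and cannot both take effect (depending on the transformers version, the second call silently overwrites the first or raises a ValueError). A hedged sketch of the usual workaround, giving the value-head variant its own config subclass; the names below are hypothetical and not part of this commit:

    # Hypothetical fix, not the commit's code: one config class per model class.
    from transformers import AutoConfig, AutoModelForCausalLM
    from nami.models.gpt2 import NamiGPT2Config, NamiGPT2LMHeadModelWithValueHead

    class NamiGPT2ValueHeadConfig(NamiGPT2Config):
        model_type = 'nami_gpt2_value_head'

    # The model's config_class must match the config it is registered under.
    NamiGPT2LMHeadModelWithValueHead.config_class = NamiGPT2ValueHeadConfig

    AutoConfig.register('nami_gpt2_value_head', NamiGPT2ValueHeadConfig)
    AutoModelForCausalLM.register(NamiGPT2ValueHeadConfig,
                                  NamiGPT2LMHeadModelWithValueHead)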
5 changes: 5 additions & 0 deletions src/nami/models/gpt2/configuration_gpt2.py
@@ -0,0 +1,5 @@
+from ..utils import Config
+
+
+class NamiGPT2Config(Config):
+    model_type = 'nami_gpt2'
124 changes: 124 additions & 0 deletions src/nami/models/gpt2/modeling_gpt2.py
@@ -0,0 +1,124 @@
+from nami.registry import MODELS
+from .configuration_gpt2 import NamiGPT2Config
+from ..utils import HFModel
+
+
+@MODELS.register_module()
+class NamiGPT2LMHeadModel(HFModel):
+    config_class = NamiGPT2Config
+
+    def forward(
+            self,
+            input_ids=None,
+            past_key_values=None,
+            attention_mask=None,
+            token_type_ids=None,
+            position_ids=None,
+            head_mask=None,
+            inputs_embeds=None,
+            encoder_hidden_states=None,
+            encoder_attention_mask=None,
+            labels=None,
+            use_cache=None,
+            output_attentions=None,
+            output_hidden_states=None,
+            return_dict=None,
+            mode: str = 'pred', **kwargs):
+        output = self.gpt(
+            input_ids=input_ids,
+            past_key_values=past_key_values,
+            attention_mask=attention_mask,
+            token_type_ids=token_type_ids,
+            position_ids=position_ids,
+            head_mask=head_mask,
+            inputs_embeds=inputs_embeds,
+            encoder_attention_mask=encoder_attention_mask,
+            labels=None,
+            use_cache=use_cache,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict
+        )
+        self.input_ids = input_ids
+        self.labels = labels
+        if mode == 'train':
+            loss = self._get_loss(output, input_ids, labels)
+            return dict(loss=loss)
+        elif mode == 'eval':
+            loss = self._get_loss(output, input_ids, labels)
+            return dict(loss=loss)
+        else:
+            return output
+
+    def _get_loss(self, output, input_ids, labels):
+        shift_logits = output.logits[..., :-1, :].contiguous()
+        if labels is None:
+            labels = input_ids
+        shift_labels = labels[..., 1:].contiguous()
+        loss = self.loss(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
+
+        return loss
+
+
+@MODELS.register_module()
+class NamiGPT2LMHeadModelWithValueHead(HFModel):
+    config_class = NamiGPT2Config
+
+    def forward(
+            self,
+            input_ids=None,
+            past_key_values=None,
+            attention_mask=None,
+            token_type_ids=None,
+            position_ids=None,
+            head_mask=None,
+            inputs_embeds=None,
+            encoder_hidden_states=None,
+            encoder_attention_mask=None,
+            labels=None,
+            use_cache=None,
+            output_attentions=None,
+            output_hidden_states=None,
+            return_dict=None,
+            mode: str = 'pred', **kwargs):
+
+        output_hidden_states = True  # set True to compute value
+        output = self.gpt(
+            input_ids=input_ids,
+            past_key_values=past_key_values,
+            attention_mask=attention_mask,
+            token_type_ids=token_type_ids,
+            position_ids=position_ids,
+            head_mask=head_mask,
+            inputs_embeds=inputs_embeds,
+            encoder_attention_mask=encoder_attention_mask,
+            labels=None,
+            use_cache=use_cache,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict
+        )
+
+        lm_logits = output.logits
+        last_hidden_state = output.hidden_states[-1]
+        value = self.value_head(last_hidden_state).squeeze(-1)
+
+        # self.input_ids = input_ids
+        # self.labels = labels
+        if mode == 'train':
+            loss = self._get_loss(output, input_ids, labels)
+            return dict(loss=loss)
+        elif mode == 'eval':
+            loss = self._get_loss(output, input_ids, labels)
+            return dict(loss=loss)
+        else:
+            return output
+
+    def _get_loss(self, output, input_ids, labels):
+        shift_logits = output.logits[..., :-1, :].contiguous()
+        if labels is None:
+            labels = input_ids
+        shift_labels = labels[..., 1:].contiguous()
+        loss = self.loss(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
+
+        return loss
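_get_loss above implements the standard causal-LM objective: logits at position t are scored against the token at position t+1, so both tensors drop one position before being flattened (and when labels is None, input_ids serve as their own targets). A self-contained sketch of the same shift in plain PyTorch, assuming self.loss is a token-level criterion such as nn.CrossEntropyLoss:

    import torch
    import torch.nn.functional as F

    batch, seq_len, vocab = 2, 5, 11
    logits = torch.randn(batch, seq_len, vocab)         # stand-in for output.logits
    input_ids = torch.randint(vocab, (batch, seq_len))

    # Same shift as _get_loss: position t predicts token t+1.
    shift_logits = logits[..., :-1, :].contiguous()     # (2, 4, 11)
    shift_labels = input_ids[..., 1:].contiguous()      # (2, 4)
    loss = F.cross_entropy(shift_logits.view(-1, vocab), shift_labels.view(-1))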
6 changes: 0 additions & 6 deletions src/nami/models/reward_model/__init__.py

This file was deleted.

1 change: 1 addition & 0 deletions src/nami/models/utils/__init__.py
@@ -1,5 +1,6 @@
 from .base_model import HFModel, NamiAutoModel, Config
 from .builder import register_transformers_automodel, register_transformers_models
+from .layers import LinearHead

 register_transformers_automodel()
 register_transformers_models()
27 changes: 27 additions & 0 deletions src/nami/models/utils/layers.py
@@ -0,0 +1,27 @@
+import torch.nn as nn
+
+from franky.model import BaseModule
+from nami.registry import MODELS
+
+
+@MODELS.register_module()
+class LinearHead(BaseModule):
+    def __init__(self,
+                 num_classes,
+                 in_channels,
+                 dropout_prob=0.1,
+                 init_cfg=dict(type='BasicNLP', layer='Linear'),
+                 **kwargs):
+        super(LinearHead, self).__init__(init_cfg=init_cfg)
+
+        self.in_channels = in_channels
+        self.num_classes = num_classes
+        self.dropout = nn.Dropout(dropout_prob) if dropout_prob else nn.Identity()
+
+        self.fc = nn.Linear(self.in_channels, self.num_classes)
+
+    def forward(self, feats):
+        output = self.dropout(feats)
+        output = self.fc(output)
+
+        return output
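LinearHead is a dropout-plus-linear projection registered with MODELS; it plausibly backs heads such as value_head in NamiGPT2LMHeadModelWithValueHead above (num_classes=1 would match the squeeze(-1) there, though this diff does not show that wiring). A minimal usage sketch, assuming franky is installed so BaseModule can be constructed:

    import torch
    from nami.models.utils import LinearHead  # re-exported by this commit

    head = LinearHead(num_classes=1, in_channels=768)
    feats = torch.randn(4, 10, 768)    # e.g. (batch, seq, hidden) features
    values = head(feats).squeeze(-1)   # -> shape (4, 10)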
2 changes: 1 addition & 1 deletion src/nami/version.py
@@ -1 +1 @@
-__version__ = '0.0.4'
+__version__ = '0.0.5'
