Skip to content

Commit

Permalink
No public description
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 637931928
Change-Id: I0c49a20c44f6359c21e42351ffb03a02e80b2d6a
  • Loading branch information
Nanodo Team authored and peterjliu committed May 29, 2024
1 parent 8fd56f3 commit c9ca893
Showing 1 changed file with 5 additions and 5 deletions.
10 changes: 5 additions & 5 deletions nanodo/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,15 +100,15 @@ class TBlock(nn.Module):

  @nn.compact
  def __call__(self, in_BxLxD: jax.Array):
- docfg = self.docfg
+ cfg = self.docfg

  # "pre-layernorm"
- x_BxLxD = nn.LayerNorm(dtype=docfg.dtype, use_bias=False)(in_BxLxD)
- x_BxLxD = CausalAttn(docfg)(x_BxLxD)
+ x_BxLxD = nn.LayerNorm(dtype=cfg.dtype, use_bias=False)(in_BxLxD)
+ x_BxLxD = CausalAttn(cfg)(x_BxLxD)
  x_BxLxD += in_BxLxD

- z_BxLxD = nn.LayerNorm(dtype=docfg.dtype, use_bias=False)(x_BxLxD)
- z_BxLxD = Mlp(docfg)(z_BxLxD)
+ z_BxLxD = nn.LayerNorm(dtype=cfg.dtype, use_bias=False)(x_BxLxD)
+ z_BxLxD = Mlp(cfg)(z_BxLxD)

  return x_BxLxD + z_BxLxD

Expand Down

0 comments on commit c9ca893

Please sign in to comment.