Skip to content

Commit

Permalink
Adding a dense pattern with reshape rather than broadcast on the bias - #114.
Browse files (browse the repository at this point in the history)

PiperOrigin-RevId: 528136116
  • Loading branch information
botev authored and KfacJaxDev committed Apr 29, 2023
1 parent 04a384b commit 3be9b1a
Showing 1 changed file with 13 additions and 5 deletions.
18 changes: 13 additions & 5 deletions kfac_jax/_src/tag_graph_matcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -961,10 +961,16 @@ def _dense(x: Array, params: Sequence[Array]) -> Array:
return y if not opt_b else y + opt_b[0]


def _dense_with_reshape(x: Array, params: Sequence[Array],) -> Array:
  """Dense layer whose bias is explicitly reshaped before being added.

  Args:
    x: The input that is right-multiplied by the weights.
    params: A two-element sequence ``(w, b)`` holding the weights and the
      bias. Unlike an optional-bias layer, the bias is required here.

  Returns:
    ``matmul(x, w)`` plus the bias reshaped to ``[1, b.size]``.
  """
  weights, bias = params
  product = jnp.matmul(x, weights)
  # The bias is turned into a single row of `bias.size` entries via an
  # explicit reshape, instead of relying on implicit broadcasting of `b`.
  bias_row = jnp.reshape(bias, (1, bias.size))
  return product + bias_row


def _dense_parameter_extractor(
eqns: Sequence[JaxprEqn],
) -> Mapping[str, Any]:
"""Extracts all parameters from the conv_general_dilated operator."""
"""Extracts all parameters from the `dot_general` operator."""
for eqn in eqns:
if eqn.primitive.name == "dot_general":
return dict(**eqn.params)
Expand All @@ -973,6 +979,7 @@ def _dense_parameter_extractor(

def _make_dense_pattern(
with_bias: bool,
reshape: bool,
in_dim: int = 13,
out_dim: int = 7,
) -> GraphPattern:
Expand All @@ -982,7 +989,7 @@ def _make_dense_pattern(
return GraphPattern(
name="dense_with_bias" if with_bias else "dense_no_bias",
tag_primitive=tags.dense,
compute_func=_dense,
compute_func=_dense_with_reshape if reshape else _dense,
parameters_extractor_func=_dense_parameter_extractor,
example_args=[np.zeros(x_shape), [np.zeros(s) for s in p_shapes]],
)
Expand All @@ -1007,7 +1014,7 @@ def _conv2d(x: Array, params: Sequence[Array]) -> Array:
def _conv2d_parameter_extractor(
eqns: Sequence[JaxprEqn],
) -> Mapping[str, Any]:
"""Extracts all parameters from the conv_general_dilated operator."""
"""Extracts all parameters from the `conv_general_dilated` operator."""
for eqn in eqns:
if eqn.primitive.name == "conv_general_dilated":
return dict(**eqn.params)
Expand Down Expand Up @@ -1161,8 +1168,9 @@ def _make_normalization_haiku_pattern(


DEFAULT_GRAPH_PATTERNS = (
_make_dense_pattern(True),
_make_dense_pattern(False),
_make_dense_pattern(True, False),
_make_dense_pattern(True, True),
_make_dense_pattern(False, False),
_make_conv2d_pattern(True),
_make_conv2d_pattern(False),
_make_scale_and_shift_pattern(1, True, True),
Expand Down

0 comments on commit 3be9b1a

Please sign in to comment.