Fix sparse tracers
holl- committed Aug 3, 2024
1 parent 780f8e0 commit 6a6ff51
Showing 1 changed file with 11 additions and 5 deletions.
16 changes: 11 additions & 5 deletions phiml/math/_trace.py
@@ -48,6 +48,9 @@ def __add__(self, other):
result = {dim: d1.get(dim, 0) + d2.get(dim, 0) for dim in all_dims}
return Shift(result)

def __bool__(self):
return bool(self.by_dim)


class ShiftLinTracer(Tensor):
"""
@@ -435,6 +438,8 @@ def __init__(self, source: TracerSource, matrix: SparseCoordinateTensor, bias: T
assert isinstance(matrix, Tensor)
assert bias.shape in shape
assert matrix.shape.only(shape) == shape.only(matrix.shape, reorder=True)
if any(d.endswith('_src') for d in matrix.shape.names):
assert any(not d.endswith('_src') for d in matrix.shape.names) # If there is an input dim, there needs to be an output dim
for dim in source.shape:
if '~' + dim.name + '_src' in matrix.shape:
assert (dim in matrix.shape) == (dim in shape), f"Inconsistent traced output dim {dim} for tracer {shape} with matrix {matrix.shape}"
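# Sketch of the invariant asserted above (assumption: matrix input dims carry a
# '_src' suffix while output dims keep the traced names; names here are hypothetical):
def has_valid_src_dims(dim_names):
    has_input = any(n.endswith('_src') for n in dim_names)
    has_output = any(not n.endswith('_src') for n in dim_names)
    return has_output or not has_input  # an input dim requires at least one output dim

assert has_valid_src_dims(['~x_src', 'x'])   # input + output dim: ok
assert not has_valid_src_dims(['~x_src'])    # input dim without any output dim: rejected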
@@ -750,7 +755,7 @@ def tracer_to_coo(tracer: Tensor, sparsify_batch: bool, separate_independent: bo
pattern_dims = tracer._source.shape.only(pattern_dim_names(tracer))
assert batch(pattern_dims).is_empty, f"Batch dimensions may not be sliced in linear operations but got pattern for {batch(pattern_dims)}"
out_shape, src_shape, typed_src_shape, missing_dims, sliced_src_shape = matrix_dims_for_tracer(tracer, sparsify_batch)
out_shape_original = rename_dims(out_shape, [*tracer._out_name_to_original.keys()], [*tracer._out_name_to_original.values()])
original_out_names = [tracer._out_name_to_original.get(d, d) for d in out_shape.names]
batch_val = merge_shapes(*tracer.val.values()).without(out_shape)
if non_batch(out_shape).is_empty:
assert len(tracer.val) == 1 and not non_batch(tracer.val[Shift({})])
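# Sketch of the per-dim name lookup introduced above: each output dim is mapped back
# to its original source name, falling back to its own name when it was never renamed.
# Plain-dict illustration, not the phiml Shape API; dim names are hypothetical:
out_name_to_original = {'y_out': 'y'}   # output dim 'y_out' was renamed from source dim 'y'
out_names = ['x', 'y_out']              # 'x' was never renamed
original_out_names = [out_name_to_original.get(d, d) for d in out_names]
assert original_out_names == ['x', 'y']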
@@ -763,11 +768,11 @@ def tracer_to_coo(tracer: Tensor, sparsify_batch: bool, separate_independent: bo
native_shift_values = math.reshaped_native(shift_val, [batch_val, *out_shape])
mask = np.sum(abs(native_shift_values), 0) # only 0 where no batch entry has a non-zero value
out_idx = numpy.nonzero(mask)
src_idx = [(component + shift_[dim]) % typed_src_shape.get_size(dim) for component, dim in zip(out_idx, out_shape_original)]
src_idx = [(component + shift_[dim]) % typed_src_shape.get_size(dim) for component, dim in zip(out_idx, original_out_names)]
values.append(native_shift_values[(slice(None), *out_idx)])
else: # add full stencil tensor
out_idx = np.unravel_index(np.arange(out_shape.volume), out_shape.sizes) if out_shape else 0
src_idx = [(component + shift_[dim]) % typed_src_shape.get_size(dim) for component, dim in zip(out_idx, out_shape_original)]
src_idx = [(component + shift_[dim]) % typed_src_shape.get_size(dim) for component, dim in zip(out_idx, original_out_names)]
values.append(math.reshaped_native(shift_val, [batch_val, out_shape]))
out_indices.append(out_idx)
src_idx_all = []
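# Sketch of the src_idx arithmetic above: for a stencil entry with offset `shift`,
# each output index reads from source index (out + shift) mod size, i.e. shifts wrap
# around periodically. NumPy illustration, not the phiml code path:
import numpy as np

size_x = 4
out_idx_x = np.arange(size_x)            # output positions along one dim
shift_x = 1                              # offset stored in the Shift entry for that dim
src_idx_x = (out_idx_x + shift_x) % size_x
assert src_idx_x.tolist() == [1, 2, 3, 0]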
@@ -870,7 +875,8 @@ def to_sparse_tracer(tracer: Tensor, ref: Optional[Tensor]) -> SparseLinTracer:
return tracer
if isinstance(tracer, ShiftLinTracer):
matrix, bias = tracer_to_coo(tracer, sparsify_batch=False, separate_independent=False)
matrix = rename_dims(matrix, dual, [n + '_src' for n in dual(matrix).as_batch().names])
src_dims = dual(matrix) - set(tracer._renamed)
matrix = rename_dims(matrix, src_dims, [n + '_src' for n in src_dims.as_batch().names])
return SparseLinTracer(tracer._source, matrix, bias, tracer.shape)
assert isinstance(tracer, GatherLinTracer)
if tracer._selection is None:
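# Sketch of the fix above: only dual (input) dims that are not already recorded in
# tracer._renamed receive the '_src' suffix. Plain-string illustration, not the phiml
# Shape API; dim names are hypothetical and shown without type markers:
dual_dims = {'x', 'y'}
already_renamed = {'y'}
src_dims = dual_dims - already_renamed
suffixed = {d: d + '_src' for d in src_dims}
assert suffixed == {'x': 'x_src'}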
@@ -905,7 +911,7 @@ def to_gather_tracer(t: Tensor) -> GatherLinTracer:
assert isinstance(t, ShiftLinTracer)
if len(t.val) > 1 or next(iter(t.val)):
raise NotImplementedError(f"Converting off-diagonal elements to sparse tracer not supported")
return GatherLinTracer(t._source, t.val[EMPTY_SHAPE], t._bias, t._shape, None, t._renamed)
return GatherLinTracer(t._source, t.val[Shift({})], t._bias, t._shape, None, t._renamed)


def expand_matrix(matrix: Tensor, dims: Shape) -> Tensor:
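# Sketch of the key fix above: a ShiftLinTracer stores its values keyed by Shift
# objects, and the diagonal (un-shifted) entry sits under an empty Shift, so the
# lookup must use Shift({}) rather than EMPTY_SHAPE. Plain-dict illustration with a
# tuple key standing in for Shift({}):
val = {(): [1.0, 2.0, 3.0]}   # () plays the role of Shift({}): no offset in any dim
diagonal = val[()]            # analogous to t.val[Shift({})]
assert diagonal == [1.0, 2.0, 3.0]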