Skip to content

Commit

Permalink
fix and test for HIPS#552
Browse files Browse the repository at this point in the history
  • Loading branch information
Gattocrucco committed Apr 23, 2020
1 parent c6f630a commit 99d52db
Show file tree
Hide file tree
Showing 3 changed files with 9 additions and 2 deletions.
4 changes: 3 additions & 1 deletion autograd/numpy/numpy_boxes.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@ class ArrayBox(Box):
@primitive
def __getitem__(A, idx): return A[idx]

def item(self): return self[(0,) * len(self.shape)]

# Constants w.r.t float data just pass though
shape = property(lambda self: self._value.shape)
ndim = property(lambda self: self._value.ndim)
Expand All @@ -20,7 +22,7 @@ def __getitem__(A, idx): return A[idx]
T = property(lambda self: anp.transpose(self))
def __len__(self): return len(self._value)
def astype(self, *args, **kwargs): return anp._astype(self, *args, **kwargs)

def __neg__(self): return anp.negative(self)
def __add__(self, other): return anp.add( self, other)
def __sub__(self, other): return anp.subtract(self, other)
Expand Down
2 changes: 1 addition & 1 deletion autograd/numpy/numpy_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ def column_stack(tup):
def array(A, *args, **kwargs):
t = builtins.type(A)
if t in (list, tuple):
return array_from_args(args, kwargs, *map(array, A))
return array_from_args(args, kwargs, *map(lambda a: a if a.shape else a.item(), map(array, A)))
else:
return _array_from_scalar_or_array(args, kwargs, A)

Expand Down
5 changes: 5 additions & 0 deletions tests/test_misc.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,3 +102,8 @@ def test_flatten_complex():
val = 1 + 1j
flat, unflatten = flatten(val)
assert np.all(val == unflatten(flat))

def test_object_array():
x = object()
a = np.array([x])
assert a.item() is x

0 comments on commit 99d52db

Please sign in to comment.