Skip to content

Commit

Permalink
Fix aesara.gpuarray access that fails on Windows
Browse files Browse the repository at this point in the history
  • Loading branch information
brandonwillard committed Feb 26, 2021
1 parent 8493641 commit 5795760
Show file tree
Hide file tree
Showing 3 changed files with 13 additions and 15 deletions.
7 changes: 6 additions & 1 deletion aesara/gpuarray/dnn.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,12 @@
WIN32_CUDNN_NAMES = ["cudnn64_7.dll", "cudnn64_6.dll", "cudnn64_5.dll"]

if sys.platform == "win32":
aesara.gpuarray.pathparse.PathParser(config.dnn__bin_path)
try:
from aesara.gpuarray.pathparse import PathParser

PathParser(config.dnn__bin_path)
except ImportError:
pass


def _load_lib(name):
Expand Down
10 changes: 7 additions & 3 deletions aesara/misc/check_blas.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,9 +78,13 @@ def execute(execute=True, verbose=True, M=2000, N=2000, K=2000, iters=10, order=

f() # Ignore first function call to get representative time.
if execute:
sync = hasattr(aesara, "gpuarray") and isinstance(
c, aesara.gpuarray.GpuArraySharedVariable
)
try:
from aesara.gpuarray import GpuArraySharedVariable

sync = isinstance(c, GpuArraySharedVariable)
except ImportError:
sync = False

if sync:
# Make sure we don't include the time from the first call
c.get_value(borrow=True, return_internal_type=True).sync()
Expand Down
11 changes: 0 additions & 11 deletions aesara/scan/opt.py
Original file line number Diff line number Diff line change
Expand Up @@ -1063,18 +1063,7 @@ def attempt_scan_inplace(self, fgraph, node, output_indices, alloc_ops):

def apply(self, fgraph):

# Depending on the value of gpua_flag, get the list of memory
# allocation ops that the optimization should be able to
# handle
alloc_ops = (Alloc, AllocEmpty)
if self.gpua_flag:
# gpuarray might be imported but not its GpuAlloc and
# GpuAllopEmpty ops.
try:
alloc_ops += (aesara.gpuarray.GpuAlloc, aesara.gpuarray.GpuAllocEmpty)
except Exception:
pass

nodes = fgraph.toposort()[::-1]
scan_nodes = [
x
Expand Down

0 comments on commit 5795760

Please sign in to comment.