Skip to content

Commit

Permalink
Merge pull request #100 from Juice-jl/compressed-io
Browse files Browse the repository at this point in the history
Support for `read` and `write` of circuits to *.gz compressed files
  • Loading branch information
guyvdbroeck authored Nov 26, 2021
2 parents a6d61b1 + 4c1f95f commit 912ad7e
Show file tree
Hide file tree
Showing 8 changed files with 60 additions and 11 deletions.
10 changes: 5 additions & 5 deletions .github/workflows/slow_tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,10 @@ name: Slow Tests
# Controls when the action will run. Triggers the workflow on push or pull request
# events but only for the master branch
on:
pull_request:
branches:
master
# tests are not being maintained so don't run on pull requests
# pull_request:
# branches:
# master

schedule:
- cron: '0 0 */7 * *'
Expand All @@ -26,14 +27,13 @@ jobs:
- uses: actions/checkout@v2
- uses: julia-actions/setup-julia@latest
with:
version: 1.5
version: 1.6

# Runs a single command using the runners shell
- name: Unit Tests
run: |
julia --project=test -e 'using Pkg; Pkg.instantiate(); Pkg.build(); Pkg.precompile();'
julia --project=test -e 'using Pkg; Pkg.develop("ProbabilisticCircuits");'
julia --project=test --check-bounds=yes --depwarn=yes test/_manual_/aqua_test.jl
julia --project=test --check-bounds=yes --depwarn=yes test/_manual_/strudel_marginal_tests.jl
julia --project=test --check-bounds=yes --depwarn=yes test/_manual_/strudel_likelihood_tests.jl
julia --project=test --check-bounds=yes --depwarn=yes test/_manual_/ensembles_tests.jl
Expand Down
6 changes: 3 additions & 3 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "ProbabilisticCircuits"
uuid = "2396afbe-23d7-11ea-1e05-f1aa98e17a44"
authors = ["Guy Van den Broeck <[email protected]>"]
version = "0.3.1"
version = "0.3.2"

[deps]
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
Expand Down Expand Up @@ -38,11 +38,11 @@ CUDA = "3"
Clustering = "0.14"
DataFrames = "1.2"
DataStructures = "0.17, 0.18"
DirectedAcyclicGraphs = "0.1.2"
Distributions = "0.25"
DirectedAcyclicGraphs = "0.1.1"
Graphs = "1.4"
Lerche = "0.5"
LogicCircuits = "0.3.1"
LogicCircuits = "0.3.2"
LoopVectorization = "0.11, 0.12"
MetaGraphs = "0.7"
Metis = "1.0"
Expand Down
13 changes: 11 additions & 2 deletions src/io/io.jl
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,12 @@ include("plot.jl")
# if no logic circuit file format is given on read, infer file format from extension

function file2pcformat(file)
if endswith(file,".jpc")
if endswith(file,".gz")
# duplicate code from `LogicCircuits.file2logicformat` -- not sure how to make this nicer
file_inner, _ = splitext(file)
format_inner = file2pcformat(file_inner)
GzipFormat(format_inner)
elseif endswith(file,".jpc")
JpcFormat()
elseif endswith(file,".psdd")
PsddFormat()
Expand Down Expand Up @@ -64,6 +69,10 @@ Base.parse(::Type{ProbCircuit}, args...) =
Base.read(io::IO, ::Type{ProbCircuit}, args...) =
read(io, PlainProbCircuit, args...)

Base.read(io::IO, ::Type{ProbCircuit}, f::GzipFormat) =
# avoid method ambiguity
read(io, PlainProbCircuit, f)

# copy read/write API for tuples of files

function Base.read(files::Tuple{AbstractString, AbstractString}, ::Type{C}, args...) where C <: StructProbCircuit
Expand All @@ -81,4 +90,4 @@ function Base.write(files::Tuple{AbstractString,AbstractString},
write((io1, io2), circuit, args...)
end
end
end
end
2 changes: 1 addition & 1 deletion src/io/spn_io.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
export zoo_spn, zoo_spn_file,
SpnFormat, SpnVtreeFormat
SpnFormat

struct SpnFormat <: FileFormat end

Expand Down
File renamed without changes.
17 changes: 17 additions & 0 deletions test/io/jpc_io_test.jl
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,23 @@ include("../helper/pc_equals.jl")
test_my_circuit(pc3)
test_pc_equals(pc1, pc3)
@test vtree(pc1) == vtree(pc3)

# read/write compressed
write("$jpc_path.gz", pc1)
pc2 = read("$jpc_path.gz", ProbCircuit)

test_my_circuit(pc2)
test_pc_equals(pc1, pc2)

# read/write compressed structured
paths = ("$jpc_path.gz", vtree_path)
write(paths, pc1)
pc3 = read(paths, StructProbCircuit)

@test pc3 isa StructProbCircuit
test_my_circuit(pc3)
test_pc_equals(pc1, pc3)
@test vtree(pc1) == vtree(pc3)

end

Expand Down
16 changes: 16 additions & 0 deletions test/io/psdd_io_test.jl
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,22 @@ include("../helper/pc_equals.jl")
@test vtree(pc1) == vtree(pc3)
test_pc_equals(pc0, pc3)

# read/write compressed
write("$psdd_path.gz", pc1)
pc2 = read("$psdd_path.gz", ProbCircuit)

test_my_circuit(pc2)
test_pc_equals(pc1, pc2)

# read/write compressed structured
paths = ("$psdd_path.gz", vtree_path)
write(paths, pc1)
pc3 = read(paths, StructProbCircuit)

@test pc3 isa StructProbCircuit
test_my_circuit(pc3)
test_pc_equals(pc1, pc3)
@test vtree(pc1) == vtree(pc3)
end

end
Expand Down
7 changes: 7 additions & 0 deletions test/io/spn_io_test.jl
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,13 @@ include("../helper/pc_equals.jl")
test_my_circuit(pc2)
test_pc_equals(pc1, pc2)

# try compressed
write("$spn_path.gz", pc1)
pc2 = read("$spn_path.gz", ProbCircuit)

test_my_circuit(pc2)
test_pc_equals(pc1, pc2)

end

end
Expand Down

0 comments on commit 912ad7e

Please sign in to comment.