From 54365f8ca0a060cfc926a5c88bb005a75d238f08 Mon Sep 17 00:00:00 2001 From: Mohamed Tarek Date: Tue, 8 Feb 2022 01:59:10 +0200 Subject: [PATCH 1/5] ruleconfig support and Zygote tests --- Project.toml | 2 ++ src/AbstractDifferentiation.jl | 4 ++++ src/ruleconfig.jl | 21 +++++++++++++++++++++ test/ruleconfig.jl | 33 +++++++++++++++++++++++++++++++++ 4 files changed, 60 insertions(+) create mode 100644 src/ruleconfig.jl create mode 100644 test/ruleconfig.jl diff --git a/Project.toml b/Project.toml index d5d742b..91446c2 100644 --- a/Project.toml +++ b/Project.toml @@ -4,12 +4,14 @@ authors = ["Mohamed Tarek and contributors"] version = "0.4.0" [deps] +ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" Compat = "34da2185-b29b-5c13-b0c7-acf172513d20" ExprTools = "e2ba6199-217a-4e67-a87a-7c52f15ade04" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" Requires = "ae029012-a4dd-5104-9daa-d747884805df" [compat] +ChainRulesCore = "1" Compat = "3" ExprTools = "0.1" ForwardDiff = "0.10" diff --git a/src/AbstractDifferentiation.jl b/src/AbstractDifferentiation.jl index deebc2a..dc53ce5 100644 --- a/src/AbstractDifferentiation.jl +++ b/src/AbstractDifferentiation.jl @@ -643,11 +643,15 @@ end @inline asarray(x) = [x] @inline asarray(x::AbstractArray) = x +include("ruleconfig.jl") function __init__() @require ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" include("forwarddiff.jl") @require ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267" include("reversediff.jl") @require FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000" include("finitedifferences.jl") @require Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c" include("tracker.jl") + @require Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" begin + ZygoteBackend() = ReverseRuleConfigBackend(Zygote.ZygoteRuleConfig()) + end end end diff --git a/src/ruleconfig.jl b/src/ruleconfig.jl new file mode 100644 index 0000000..d9f9a20 --- /dev/null +++ b/src/ruleconfig.jl @@ -0,0 +1,21 @@ +using ChainRulesCore: RuleConfig, rrule_via_ad + +""" + ReverseRuleConfigBackend + +AD backend that uses reverse mode with any ChainRules-compatible reverse-mode AD package. +""" +struct ReverseRuleConfigBackend{RC <: RuleConfig} <: AbstractReverseMode + ruleconfig::RC +end + +AD.@primitive function pullback_function(ab::ReverseRuleConfigBackend, f, xs...) + return (vs) -> begin + _, back = rrule_via_ad(ab.ruleconfig, f, xs...) + if vs isa Tuple && length(vs) === 1 + return Base.tail(back(vs[1])) + else + return Base.tail(back(vs)) + end + end +end diff --git a/test/ruleconfig.jl b/test/ruleconfig.jl new file mode 100644 index 0000000..61bcdbf --- /dev/null +++ b/test/ruleconfig.jl @@ -0,0 +1,33 @@ +using AbstractDifferentiation +using Test +using Zygote, Yota + +@testset "ReverseRuleConfigBackend(ZygoteRuleConfig())" begin + backends = [@inferred(AD.ZygoteBackend())] + @testset for backend in backends + @testset "Derivative" begin + test_derivatives(backend) + end + @testset "Gradient" begin + test_gradients(backend) + end + @testset "Jacobian" begin + test_jacobians(backend) + end + @testset "jvp" begin + test_jvp(backend) + end + @testset "j′vp" begin + test_j′vp(backend) + end + @testset "Lazy Derivative" begin + test_lazy_derivatives(backend) + end + @testset "Lazy Gradient" begin + test_lazy_gradients(backend) + end + @testset "Lazy Jacobian" begin + test_lazy_jacobians(backend) + end + end +end From 0a0d8154c4ed1a659af55f50c34a896c0b96c141 Mon Sep 17 00:00:00 2001 From: Mohamed Tarek Date: Tue, 8 Feb 2022 02:22:49 +0200 Subject: [PATCH 2/5] actually run the tests :) --- test/runtests.jl | 1 + 1 file changed, 1 insertion(+) diff --git a/test/runtests.jl b/test/runtests.jl index 0348523..d79bafc 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -8,4 +8,5 @@ using Test include("reversediff.jl") include("finitedifferences.jl") include("tracker.jl") + include("ruleconfig.jl") end From ce1c0c3b69c0c4c7e2fbe2168c3f0b376034fb4b Mon Sep 17 00:00:00 2001 From: Mohamed Tarek Date: Tue, 8 Feb 2022 02:44:41 +0200 Subject: [PATCH 3/5] avoid loading Yota --- test/ruleconfig.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/ruleconfig.jl b/test/ruleconfig.jl index 61bcdbf..412d89c 100644 --- a/test/ruleconfig.jl +++ b/test/ruleconfig.jl @@ -1,6 +1,6 @@ using AbstractDifferentiation using Test -using Zygote, Yota +using Zygote @testset "ReverseRuleConfigBackend(ZygoteRuleConfig())" begin backends = [@inferred(AD.ZygoteBackend())] From 3f645b66482cf97232aff8dbf651541b274e0ed4 Mon Sep 17 00:00:00 2001 From: Mohamed Tarek Date: Tue, 8 Feb 2022 02:45:41 +0200 Subject: [PATCH 4/5] yota tests --- src/AbstractDifferentiation.jl | 3 +++ test/ruleconfig.jl | 4 ++-- 2 files changed, 5 insertions(+), 2 deletions(-) diff --git a/src/AbstractDifferentiation.jl b/src/AbstractDifferentiation.jl index dc53ce5..534324d 100644 --- a/src/AbstractDifferentiation.jl +++ b/src/AbstractDifferentiation.jl @@ -652,6 +652,9 @@ function __init__() @require Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" begin ZygoteBackend() = ReverseRuleConfigBackend(Zygote.ZygoteRuleConfig()) end + @require Yota = "cd998857-8626-517d-b929-70ad188a48f0" begin + YotaBackend() = ReverseRuleConfigBackend(Yota.YotaRuleConfig()) + end end end diff --git a/test/ruleconfig.jl b/test/ruleconfig.jl index 412d89c..6ca6a46 100644 --- a/test/ruleconfig.jl +++ b/test/ruleconfig.jl @@ -1,9 +1,9 @@ using AbstractDifferentiation using Test -using Zygote +using Zygote, Yota @testset "ReverseRuleConfigBackend(ZygoteRuleConfig())" begin - backends = [@inferred(AD.ZygoteBackend())] + backends = [@inferred(AD.ZygoteBackend()), @inferred(AD.YotaBackend())] @testset for backend in backends @testset "Derivative" begin test_derivatives(backend) From 9250d378079bc6766a06017e11185397edb25abf Mon Sep 17 00:00:00 2001 From: Mohamed Tarek Date: Tue, 8 Feb 2022 02:47:49 +0200 Subject: [PATCH 5/5] add Yota to extras --- Project.toml | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index 91446c2..e4984b2 100644 --- a/Project.toml +++ b/Project.toml @@ -9,6 +9,7 @@ Compat = "34da2185-b29b-5c13-b0c7-acf172513d20" ExprTools = "e2ba6199-217a-4e67-a87a-7c52f15ade04" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" Requires = "ae029012-a4dd-5104-9daa-d747884805df" +Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" [compat] ChainRulesCore = "1" @@ -26,7 +27,8 @@ Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c" +Yota = "cd998857-8626-517d-b929-70ad188a48f0" Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" [targets] -test = ["Test", "FiniteDifferences", "ForwardDiff", "Random", "ReverseDiff", "Tracker", "Zygote"] +test = ["Test", "FiniteDifferences", "ForwardDiff", "Random", "ReverseDiff", "Tracker", "Yota", "Zygote"]