From 8a8c7ba627ba0248a2c7a1de50e99d884345369a Mon Sep 17 00:00:00 2001 From: Jonnie Diegelman Date: Fri, 26 Mar 2021 00:04:27 -0400 Subject: [PATCH] Fixes some chain rules stuff. Fixes #76 --- Project.toml | 2 +- src/ComponentArrays.jl | 1 - src/if_required/chainrulescore.jl | 14 ++------------ 3 files changed, 3 insertions(+), 14 deletions(-) diff --git a/Project.toml b/Project.toml index ef5ee480..6db66581 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "ComponentArrays" uuid = "b0b7db55-cfe3-40fc-9ded-d10e2dbeff66" authors = ["Jonnie Diegelman <47193959+jonniedie@users.noreply.github.com>"] -version = "0.9.2" +version = "0.9.3" [deps] ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9" diff --git a/src/ComponentArrays.jl b/src/ComponentArrays.jl index 5e090b1e..d84d88b3 100644 --- a/src/ComponentArrays.jl +++ b/src/ComponentArrays.jl @@ -1,7 +1,6 @@ module ComponentArrays using ArrayInterface -# using ChainRulesCore using LinearAlgebra using Requires diff --git a/src/if_required/chainrulescore.jl b/src/if_required/chainrulescore.jl index 247afd6d..5c2b07dc 100644 --- a/src/if_required/chainrulescore.jl +++ b/src/if_required/chainrulescore.jl @@ -1,11 +1,3 @@ -using ChainRulesCore: NO_FIELDS - -# ChainRulesCore.frule(Δ, ::typeof(getproperty), x::ComponentArray, s::Symbol) = frule((_, Δ), getproperty, x, Val(s)) -# function ChainRulesCore.frule(Δ, ::typeof(getproperty), x::ComponentArray, ::Val{s}) where s -# zero_x = zero(x) -# setproperty!(zero_x, s, Δ) -# return getproperty(x, s), zero_x -# end ChainRulesCore.rrule(::typeof(getproperty), x::ComponentArray, ::Val{s}) where s = ChainRulesCore.rrule(getproperty, x, s) function ChainRulesCore.rrule(::typeof(getproperty), x::ComponentArray, s::Symbol) function getproperty_adjoint(Δ) @@ -17,8 +9,6 @@ function ChainRulesCore.rrule(::typeof(getproperty), x::ComponentArray, s::Symbo return getproperty(x, s), getproperty_adjoint end -ChainRulesCore.rrule(::typeof(getdata), x::ComponentArray) = getdata(x), Δ->ComponentArray(Δ, getaxes(x)) - -ChainRulesCore.rrule(::Type{ComponentArray}, data, axes) = ComponentArray(data, axes), Δ->(NO_FIELDS, getdata(Δ), getaxes(Δ)) +ChainRulesCore.rrule(::typeof(getdata), x::ComponentArray) = getdata(x), Δ->(ChainRulesCore.NO_FIELDS, ComponentArray(Δ, getaxes(x))) -ChainRulesCore.rrule(::Type{ComponentArray}, data, axes) = ComponentArray(data, axes), Δ->(getdata(Δ), getaxes(Δ)) +ChainRulesCore.rrule(::Type{ComponentArray}, data, axes) = ComponentArray(data, axes), Δ->(ChainRulesCore.NO_FIELDS, getdata(Δ), getaxes(Δ))