How can one person write so many fucking bugs
- ProductNeuron backprop error resulting in a zero neuron delta for
  connecting neurons (see the sketch after this list)
- Network evaluation broken for context neurons
- Timestep handling error in backprop
- Network traversal didn't actually work with certain topologies
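The ProductNeuron bug is the classic multiplicative-identity mistake: the partial derivative of a product with respect to one input is the product of the other inputs, and folding that product from a seed of 0 instead of 1 makes every factor, and hence the delta, come out 0. A minimal standalone sketch in plain Ruby (only bigdecimal from the standard library, none of RANN's classes):

    require "bigdecimal"
    require "bigdecimal/util"

    # d(x1 * x2 * ... * xn)/dxj is the product of the *other* inputs, so
    # the fold over them must start from the multiplicative identity, 1.
    other_inputs = ["0.5".to_d, "2".to_d, "4".to_d]

    buggy = other_inputs.reduce(0.to_d) { |m, x| m * x }
    fixed = other_inputs.reduce(1.to_d) { |m, x| m * x }

    puts buggy.to_s("F") # => 0.0 -- seeded with 0, the delta always vanishes
    puts fixed.to_s("F") # => 4.0 -- the correct partial derivative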
Mike Campbell committed Dec 3, 2017
1 parent 48c2bf5 commit f951825
Showing 4 changed files with 54 additions and 27 deletions.
7 changes: 7 additions & 0 deletions CHANGES.md
@@ -1,3 +1,10 @@
+- So. Many. Bugs. Turns out the product neuron stuff was still broken, and
+  network evaluation wasn't behaving correctly with context neurons (recurrent
+  connections). Also an error in timestep handling during backprop, and network
+  traversal just generally didn't work for certain topologies.
+
+    *Michael Campbell*
+
 - Don't lock the input connections into the LSTM layer; they act as the fully
   connected part of the network, and that's where the majority of learning
   takes place, derp.
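To see why locked input connections were hurting, here is a toy Connection struct (hypothetical, not RANN's API): a locked weight ignores every gradient update, so if the connections feeding the LSTM block are locked, the one genuinely fully connected part of the network can never learn.

    # Hypothetical stand-in for RANN's connection types, not its real API.
    Connection = Struct.new(:weight, :locked) do
      def apply_delta delta
        self.weight -= delta unless locked
      end
    end

    locked    = Connection.new 0.5, true
    trainable = Connection.new 0.5, false

    [locked, trainable].each { |c| c.apply_delta 0.1 }

    p locked.weight    # => 0.5 -- gradient updates are ignored, no learning
    p trainable.weight # => 0.4 -- what the fix allows for the LSTM's inputs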
68 changes: 45 additions & 23 deletions lib/rann/backprop.rb
@@ -113,46 +113,67 @@ def self.run_single network, inputs, targets
       error = mse targets, outputs
 
       # backward pass with unravelling for recurrent networks
-      node_deltas = Hash.new{ |h, k| h[k] = Hash.new 0.to_d }
+      node_deltas = Hash.new{ |h, k| h[k] = {} }
       gradients = Hash.new 0.to_d
 
       initial_timestep = inputs.size - 1
       neuron_stack = network.output_neurons.map{ |n| [n, initial_timestep] }
+      skipped = []
 
       while current = neuron_stack.shift
         neuron, timestep = current
         next if node_deltas[timestep].key? neuron.id
 
-        from_here = bptt_connecting_to neuron, network, timestep
-        neuron_stack.push *from_here
-
-        # neuron delta is summation of neuron deltas deltas for the connections
-        # from this neuron
-        step_one =
-          if neuron.output?
-            output_index = network.output_neurons.index neuron
-            mse_delta targets[output_index], outputs[output_index]
-          else
+        if neuron.output?
+          output_index = network.output_neurons.index neuron
+          step_one = mse_delta targets[output_index], outputs[output_index]
+        else
+          sum =
             network.connections_from(neuron).reduce 0.to_d do |m, c|
               out_timestep = c.output_neuron.context? ? timestep + 1 : timestep
               output_node_delta = node_deltas[out_timestep][c.output_neuron.id]
 
-              # connection delta is the output neuron delta multiplied by the
-              # connection's weight
-              connection_delta =
-                if c.output_neuron.is_a? ProductNeuron
-                  intermediate =
-                    network.connections_to(c.output_neuron).reject{ |c2| c2 == c }.reduce 0.to_d do |m, c2|
-                      m * states[timestep][:values][c2.input_neuron.id] * c2.weight
-                    end
-                  output_node_delta * intermediate * c.weight
-                else
-                  output_node_delta * c.weight
-                end
-
-              m + connection_delta
+              if out_timestep > initial_timestep
+                m
+              else
+                # complicated network case, see NOTES.md
+                # can't find node delta, re-enqueue at back of queue and record
+                # the skip.
+                if !output_node_delta
+                  if skipped.size == neuron_stack.size + 1
+                    output_node_delta = 0.to_d
+                  else
+                    neuron_stack.push current
+                    skipped << current
+                    break
+                  end
+                end
+
+                # connection delta is the output neuron delta multiplied by the
+                # connection's weight
+                connection_delta =
+                  if c.output_neuron.is_a? ProductNeuron
+                    intermediate =
+                      network.connections_to(c.output_neuron).reject{ |c2| c2 == c }.reduce 1.to_d do |m, c2|
+                        m * states[timestep][:values][c2.input_neuron.id] * c2.weight
+                      end
+                    output_node_delta * intermediate * c.weight
+                  else
+                    output_node_delta * c.weight
+                  end
+
+                m + connection_delta
+              end
             end
-          end
+
+          step_one = sum || next
+        end
+
+        from_here = bptt_connecting_to neuron, network, timestep
+        neuron_stack.push *from_here
+        skipped.clear
 
         node_delta =
           ACTIVATION_DERIVATIVES[neuron.activation_function]
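The reworked loop above defers neurons whose downstream deltas aren't known yet: re-enqueue at the back of the queue, record the skip, and only fall back to a zero delta once everything still queued has been skipped in a row (which means the remaining dependencies are circular at this timestep). A distilled sketch of that queue discipline, with hypothetical names rather than RANN's API:

    # Hypothetical distillation of the deferral logic, not RANN's API.
    # deps maps a node to the downstream nodes whose deltas it needs.
    def resolve_deltas queue, deps
      deltas  = {}
      skipped = []

      while current = queue.shift
        needed  = deps[current] || []
        unknown = needed.reject { |n| deltas.key? n }

        if unknown.any? && skipped.size < queue.size + 1
          queue.push current # dependency not ready: retry after the others
          skipped << current
          next
        end

        # Either all dependencies are resolved, or everything still queued
        # has been skipped once (a cycle): missing deltas default to zero.
        deltas[current] = needed.sum { |n| deltas.fetch n, 0 } + 1
        skipped.clear
      end

      deltas
    end

    # "h" is queued before its dependency "out" to exercise the retry path.
    p resolve_deltas ["h", "out"], { "h" => ["out"] }
    # => {"out"=>1, "h"=>2}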
@@ -233,11 +254,12 @@ def self.bptt_connecting_to neuron, network, timestep
       # halt traversal if we're at a context and we're at the base timestep
       return [] if neuron.context? && timestep == 0
 
+      timestep -= 1 if neuron.context?
+
       network.connections_to(neuron).each.with_object [] do |c, a|
         # don't enqueue connections from inputs
         next if c.input_neuron.input?
 
-        timestep -= timestep if neuron.context?
         a << [c.input_neuron, timestep]
       end
     end
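The one-liner this hunk fixes is the timestep bug from the commit message: timestep -= timestep zeroes the timestep, so every hop through a context neuron jumped all the way back to t = 0 instead of one step. Hoisting the decrement above the loop also applies it once per neuron instead of once per incoming connection. In miniature:

    timestep = 5
    timestep -= timestep # buggy: subtracts itself, always lands on 0
    p timestep           # => 0

    timestep = 5
    timestep -= 1        # fixed: steps back to the previous timestep
    p timestep           # => 4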
4 changes: 1 addition & 3 deletions lib/rann/lstm.rb
@@ -35,8 +35,7 @@ def init
         memory_standard = RANN::Neuron.new("LSTM #{name} Mem Standard #{j}", 2, :standard, :linear).tap{ |n| @network.add n }
         memory_tanh = RANN::Neuron.new("LSTM #{name} Mem Tanh #{j}", 1, :standard, :tanh).tap{ |n| @network.add n }
         memory_o_product = RANN::ProductNeuron.new("LSTM #{name} Mem/Hidden 4 Product #{j}", 2, :standard, :linear).tap{ |n| @network.add n }
-        output = RANN::Neuron.new("LSTM #{name} Output #{j}", 1, :standard, :linear).tap{ |n| @network.add n }
-        @outputs << output
+        @outputs << memory_o_product
         memory_context = RANN::Neuron.new("LSTM #{name} Mem Context #{j}", 1, :context).tap{ |n| @network.add n }
         output_context = RANN::Neuron.new("LSTM #{name} Output Context #{j}", 1, :context).tap{ |n| @network.add n }
 
@@ -52,7 +51,6 @@ def init
         @network.add RANN::LockedConnection.new memory_standard, memory_tanh, 1.to_d
         @network.add RANN::LockedConnection.new o, memory_o_product, 1.to_d
         @network.add RANN::LockedConnection.new memory_tanh, memory_o_product, 1.to_d
-        @network.add RANN::LockedConnection.new memory_o_product, output, 1.to_d
         @network.add RANN::LockedConnection.new memory_standard, memory_context, 1.to_d
         @network.add RANN::LockedConnection.new memory_context, memory_product, 1.to_d
         @network.add RANN::LockedConnection.new memory_context, i, 1.to_d
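The deleted Output neuron was a pure pass-through: a linear neuron fed by one locked connection of weight 1 computes the identity, so wiring memory_o_product straight into @outputs is mathematically equivalent and saves a neuron plus a connection per LSTM unit. As toy arithmetic (not RANN code):

    # A linear neuron just passes through its weighted input sum.
    linear = ->(weighted_sum) { weighted_sum }

    x = 0.73                    # whatever memory_o_product emits
    p linear.call(1.0 * x) == x # => true: the Output neuron added nothing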
2 changes: 1 addition & 1 deletion lib/rann/network.rb
@@ -41,7 +41,7 @@ def evaluate input
       # would probably be easier to detect circular dependency this way too?
       begin
         i = 0
-        until output_neurons.all?{ |neuron| neuron.value }
+        until connections.select(&:enabled?).all? &:processed?
           i += 1
           connections.each do |connection|
             next if !connection.enabled?
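The old stop condition asked whether every output neuron had a value, but with context neurons an output can hold a stale value before the forward pass has fully propagated, so evaluation could stop a pass early. Requiring that every enabled connection has been processed is the condition that actually means the pass is complete. Both predicates in miniature, with a hypothetical struct rather than RANN's Connection:

    # Hypothetical struct, not RANN's Connection class.
    Conn = Struct.new(:enabled, :processed) do
      def enabled?;   enabled;   end
      def processed?; processed; end
    end

    connections = [
      Conn.new(true,  true),  # fired
      Conn.new(true,  false), # still waiting on an upstream value
      Conn.new(false, false), # disabled, irrelevant
    ]
    output_has_value = true # a context neuron handed the output a stale value

    old_done = output_has_value
    new_done = connections.select(&:enabled?).all?(&:processed?)

    p old_done # => true  -- would stop evaluating one pass too early
    p new_done # => false -- keeps going until every live connection fires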