Fixed a bug in the back-propagation of the K1 and K2 Euclidean loss layers
kevinlin311tw committed Oct 4, 2016
1 parent 09de548 commit 46ccfc8
Showing 4 changed files with 99 additions and 114 deletions.
10 changes: 6 additions & 4 deletions examples/SSDH/solver.prototxt
100644 → 100755
@@ -1,13 +1,15 @@
net: "./examples/SSDH/train_val.prototxt"
test_iter: 100
test_interval: 100
test_iter: 200
test_interval: 1000
base_lr: 0.001
lr_policy: "step"
gamma: 0.1
stepsize: 25000
display: 20
display: 100
max_iter: 50000
momentum: 0.9
weight_decay: 0.0005
snapshot: 10000
snapshot: 50000
snapshot_prefix: "./examples/SSDH/SSDH48"
random_seed: 42
solver_mode: GPU
116 changes: 80 additions & 36 deletions examples/SSDH/train_val.prototxt
100644 → 100755
@@ -1,4 +1,4 @@
name: "CaffeNet"
name: "SSDH"
layer {
name: "data"
type: "Data"
@@ -10,10 +10,18 @@ layer {
transform_param {
mirror: true
crop_size: 227
mean_file: "data/ilsvrc12/imagenet_mean.binaryproto"
}
mean_file: "/home/iis/Research/projects/SSDH/caffe/data/ilsvrc12/imagenet_mean.binaryproto"
}
# mean pixel / channel-wise mean instead of mean image
# transform_param {
# crop_size: 227
# mean_value: 104
# mean_value: 117
# mean_value: 123
# mirror: true
# }
data_param {
source: "data/cifar10/cifar10_train_leveldb"
source: "../data/cifar-10/train-test-split/train_leveldb"
batch_size: 32
}
}
@@ -26,13 +34,21 @@ layer {
phase: TEST
}
transform_param {
mirror: true
mirror: false
crop_size: 227
mean_file: "data/ilsvrc12/imagenet_mean.binaryproto"
}
mean_file: "/home/iis/Research/packages/caffe-master/data/ilsvrc12/imagenet_mean.binaryproto"
}
# mean pixel / channel-wise mean instead of mean image
# transform_param {
# crop_size: 227
# mean_value: 104
# mean_value: 117
# mean_value: 123
# mirror: true
# }
data_param {
source: "data/cifar10/cifar10_val_leveldb"
batch_size: 32
source: "../data/cifar-10/train-test-split/test_leveldb"
batch_size: 50
}
}
layer {
@@ -339,10 +355,10 @@
}
}
layer {
name: "latent_layer"
name: "latent"
type: "InnerProduct"
bottom: "fc7"
top: "latent_layer"
top: "latent"
param {
lr_mult: 1
decay_mult: 1
@@ -355,41 +371,70 @@ layer {
num_output: 48
weight_filler {
type: "gaussian"
std: 0.005
std: 0.01
}
bias_filler {
type: "constant"
value: 1
value: 0
}
}
}
layer {
name: "encode_neuron"
bottom: "latent_layer"
top: "encode_neuron"
bottom: "latent"
top: "latent_sigmoid"
name: "latent_sigmoid"
type: "Sigmoid"
}
layer {
name: "loss_beta"
bottom: "latent_sigmoid"
bottom: "latent_sigmoid"
name: "loss_1"
type: "K1_EuclideanLoss"
bottom: "encode_neuron"
bottom: "encode_neuron"
propagate_down: 1
propagate_down: 0
top: "loss: forcing-binary"
loss_weight: 1
}
layer {
name: "loss_gamma"
name: "latent_sigmoid_reshape"
type: "Reshape"
bottom: "latent_sigmoid"
top: "latent_sigmoid_reshape"
reshape_param {
shape {
dim: 0 # copy the dimension from below
dim: 1
dim: 1
dim: -1 # infer it from the other dimensions
}
}
}
layer {
name: "latent_sigmoid_avg"
type: "Pooling"
bottom: "latent_sigmoid_reshape"
top: "latent_sigmoid_avg"
pooling_param {
pool: AVE
kernel_h: 1
kernel_w: 48
}
}
layer {
bottom: "latent_sigmoid_avg"
bottom: "latent_sigmoid_avg"
name: "loss_2"
type: "K2_EuclideanLoss"
bottom: "encode_neuron"
bottom: "encode_neuron"
propagate_down: 1
propagate_down: 0
top: "loss: 50%-fire-rate"
loss_weight: 1
}
layer {
name: "fc8_classification"
name: "fc9"
type: "InnerProduct"
bottom: "encode_neuron"
top: "fc8_classification"
bottom: "latent_sigmoid"
top: "fc9"
param {
lr_mult: 10
decay_mult: 1
@@ -402,30 +447,29 @@ layer {
num_output: 10
weight_filler {
type: "gaussian"
std: 0.01
std: 0.2
}
bias_filler {
type: "constant"
value: 0
}
}
}
layer {
name: "loss_alpha"
type: "SoftmaxWithLoss"
bottom: "fc8_classification"
bottom: "label"
top: "loss: classfication-error"
loss_weight: 1
}
layer {
name: "accuracy"
type: "Accuracy"
bottom: "fc8_classification"
bottom: "fc9"
bottom: "label"
top: "accuracy"
include {
phase: TEST
}
}

layer {
name: "loss"
type: "SoftmaxWithLoss"
bottom: "fc9"
bottom: "label"
top: "loss: classfication-error"
loss_weight: 1
}
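
Reading the new layer wiring above, the 48-unit "latent" layer feeds a sigmoid ("latent_sigmoid") that is supervised three ways: a softmax classification loss on "fc9", a forcing-binary penalty ("loss_1", K1_EuclideanLoss) applied directly to the sigmoid outputs, and a 50%-fire-rate penalty ("loss_2", K2_EuclideanLoss) applied to the per-sample mean produced by the 1x48 average-pooling layer "latent_sigmoid_avg". Below is a minimal standalone C++ sketch of what those two auxiliary penalties appear to compute for a single sample; the function names and exact scaling are illustrative assumptions, not the Caffe layer code.

// Standalone sketch (illustration only, not the Caffe layers): the two
// auxiliary penalties on the 48 sigmoid activations of one sample.
#include <cstddef>
#include <cstdio>

// "loss_1" / K1_EuclideanLoss: reward activations far from 0.5,
//   L1 = -(1 / (2K)) * sum_k (a_k - 0.5)^2
double forcing_binary_loss(const double* a, std::size_t K) {
  double s = 0.0;
  for (std::size_t k = 0; k < K; ++k) {
    const double d = a[k] - 0.5;
    s += d * d;
  }
  return -s / (2.0 * K);
}

// "loss_2" / K2_EuclideanLoss behind the 1x48 average pooling: keep the mean
// activation (the "fire rate") near 0.5,
//   L2 = (1/2) * (mean(a) - 0.5)^2
double fire_rate_loss(const double* a, std::size_t K) {
  double mean = 0.0;
  for (std::size_t k = 0; k < K; ++k) mean += a[k];
  mean /= static_cast<double>(K);
  const double d = mean - 0.5;
  return 0.5 * d * d;
}

int main() {
  const std::size_t K = 48;
  double a[48];
  for (std::size_t k = 0; k < K; ++k) a[k] = (k % 2 == 0) ? 0.9 : 0.1;  // near-binary, ~50% firing
  std::printf("loss_1 = %f  loss_2 = %f\n", forcing_binary_loss(a, K), fire_rate_loss(a, K));
  return 0;
}
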
38 changes: 3 additions & 35 deletions src/caffe/layers/K1_euclidean_loss_layer.cpp
100644 → 100755
@@ -10,11 +10,6 @@ namespace caffe {
template <typename Dtype>
void K1_EuclideanLossLayer<Dtype>::Reshape(
const vector<Blob<Dtype>*>& bottom, const vector<Blob<Dtype>*>& top) {
/* LossLayer<Dtype>::Reshape(bottom, top);
CHECK_EQ(bottom[0]->count(1), bottom[1]->count(1))
<< "Inputs must have the same dimension.";
diff_.ReshapeLike(*bottom[0]);
*/
LossLayer<Dtype>::Reshape(bottom, top);
CHECK_EQ(bottom[0]->channels(), bottom[1]->channels());
CHECK_EQ(bottom[0]->height(), bottom[1]->height());
@@ -26,19 +21,7 @@ void K1_EuclideanLossLayer<Dtype>::Reshape(
template <typename Dtype>
void K1_EuclideanLossLayer<Dtype>::Forward_cpu(const vector<Blob<Dtype>*>& bottom,
const vector<Blob<Dtype>*>& top) {
/* int count = bottom[0]->count();
caffe_sub(
count,
bottom[0]->cpu_data(),
bottom[1]->cpu_data(),
diff_.mutable_cpu_data());
Dtype dot = caffe_cpu_dot(count, diff_.cpu_data(), diff_.cpu_data());
Dtype loss = dot / bottom[0]->num() / Dtype(2);
top[0]->mutable_cpu_data()[0] = loss;
*/

int count = bottom[0]->count();

const Dtype* bottom_data = bottom[0]->cpu_data();
Dtype loss = 0;
for (int i = 0; i < count; ++i) {
@@ -47,29 +30,16 @@ void K1_EuclideanLossLayer<Dtype>::Forward_cpu(const vector<Blob<Dtype>*>& botto
Dtype dot = sub * sub;
loss = loss - dot;
}
top[0]->mutable_cpu_data()[0] = loss/(Dtype(2)*count*bottom[0]->num());
top[0]->mutable_cpu_data()[0] = loss/(Dtype(2)*count);
}

template <typename Dtype>
void K1_EuclideanLossLayer<Dtype>::Backward_cpu(const vector<Blob<Dtype>*>& top,
const vector<bool>& propagate_down, const vector<Blob<Dtype>*>& bottom) {
/* for (int i = 0; i < 2; ++i) {
if (propagate_down[i]) {
const Dtype sign = (i == 0) ? 1 : -1;
const Dtype alpha = sign * top[0]->cpu_diff()[0] / bottom[i]->num();
caffe_cpu_axpby(
bottom[i]->count(), // count
alpha, // alpha
diff_.cpu_data(), // a
Dtype(0), // beta
bottom[i]->mutable_cpu_diff()); // b
}
}
*/
for (int i = 0; i < 2; ++i) {
if (propagate_down[i]) {
const Dtype sign = (i == 0) ? -1 : 1;
const Dtype alpha = sign * top[0]->cpu_diff()[0] / bottom[i]->count() / bottom[i]->num();
const Dtype sign = -1;
const Dtype alpha = sign * top[0]->cpu_diff()[0] / bottom[i]->count();
caffe_cpu_axpby(
bottom[i]->count(), // count
alpha, // alpha
@@ -78,8 +48,6 @@ void K1_EuclideanLossLayer<Dtype>::Backward_cpu(const vector<Blob<Dtype>*>& top,
bottom[i]->mutable_cpu_diff()); // b
}
}

//LOG(ERROR) << "K1: Loss weight: " << Dtype(top[0]->cpu_diff()[0]);
}

#ifdef CPU_ONLY
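
For reference, the K1 objective as written above is L1 = -(1/(2N)) * sum_i (a_i - 0.5)^2, whose derivative is dL1/da_i = -(a_i - 0.5)/N. That is what the corrected Backward_cpu applies: sign = -1, a single 1/count scale on the top diff, multiplied into diff_ (assumed here to hold a_i - 0.5 from the forward pass). The snippet below is an independent finite-difference check with made-up names, not the layer code.

// Independent gradient check for L1 = -(1/(2N)) * sum_i (a_i - 0.5)^2.
#include <cstdio>
#include <vector>

static double k1_loss(const std::vector<double>& a) {
  double s = 0.0;
  for (double v : a) s += (v - 0.5) * (v - 0.5);
  return -s / (2.0 * a.size());
}

int main() {
  std::vector<double> a = {0.2, 0.7, 0.95, 0.5};
  const double eps = 1e-6;
  for (std::size_t i = 0; i < a.size(); ++i) {
    // Analytic derivative: -(a_i - 0.5) / N, i.e. sign -1 and one 1/count factor.
    const double analytic = -(a[i] - 0.5) / a.size();
    std::vector<double> ap = a, am = a;
    ap[i] += eps;
    am[i] -= eps;
    const double numeric = (k1_loss(ap) - k1_loss(am)) / (2.0 * eps);
    std::printf("i=%zu  analytic=%.6f  numeric=%.6f\n", i, analytic, numeric);
  }
  return 0;
}
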
49 changes: 10 additions & 39 deletions src/caffe/layers/K2_euclidean_loss_layer.cpp
100644 → 100755
@@ -10,11 +10,6 @@ namespace caffe {
template <typename Dtype>
void K2_EuclideanLossLayer<Dtype>::Reshape(
const vector<Blob<Dtype>*>& bottom, const vector<Blob<Dtype>*>& top) {
/* LossLayer<Dtype>::Reshape(bottom, top);
CHECK_EQ(bottom[0]->count(1), bottom[1]->count(1))
<< "Inputs must have the same dimension.";
diff_.ReshapeLike(*bottom[0]);
*/
LossLayer<Dtype>::Reshape(bottom, top);
CHECK_EQ(bottom[0]->channels(), bottom[1]->channels());
CHECK_EQ(bottom[0]->height(), bottom[1]->height());
@@ -26,39 +21,26 @@ void K2_EuclideanLossLayer<Dtype>::Reshape(
template <typename Dtype>
void K2_EuclideanLossLayer<Dtype>::Forward_cpu(const vector<Blob<Dtype>*>& bottom,
const vector<Blob<Dtype>*>& top) {
/* int count = bottom[0]->count();
caffe_sub(
count,
bottom[0]->cpu_data(),
bottom[1]->cpu_data(),
diff_.mutable_cpu_data());
Dtype dot = caffe_cpu_dot(count, diff_.cpu_data(), diff_.cpu_data());
Dtype loss = dot / bottom[0]->num() / Dtype(2);
top[0]->mutable_cpu_data()[0] = loss;
*/
int count = bottom[0]->count();
const Dtype* bottom_data = bottom[0]->cpu_data();
Dtype loss = 0;
Dtype sum = 0;
Dtype avg = 0;
for (int i = 0; i < count; ++i) {
sum = sum + bottom_data[i];
}
avg = sum / count;
loss = (avg - Dtype(0.5)) * (avg - Dtype(0.5));
top[0]->mutable_cpu_data()[0] = loss / bottom[0]->num() / Dtype(2);


diff_.mutable_cpu_data()[0] = avg - Dtype(0.5);
for (int i = 0; i < count; ++i) {
Dtype sub = bottom_data[i] - Dtype(0.5);
diff_.mutable_cpu_data()[i] = sub;
Dtype dot = sub * sub;
loss = loss + dot;
}
top[0]->mutable_cpu_data()[0] = loss / (Dtype(2)*count);
}

template <typename Dtype>
void K2_EuclideanLossLayer<Dtype>::Backward_cpu(const vector<Blob<Dtype>*>& top,
const vector<bool>& propagate_down, const vector<Blob<Dtype>*>& bottom) {
/* for (int i = 0; i < 2; ++i) {
for (int i = 0; i < 2; ++i) {
if (propagate_down[i]) {
const Dtype sign = (i == 0) ? 1 : -1;
const Dtype alpha = sign * top[0]->cpu_diff()[0] / bottom[i]->num();
const Dtype sign = 1;
const Dtype alpha = sign * top[0]->cpu_diff()[0] / bottom[i]->count();
caffe_cpu_axpby(
bottom[i]->count(), // count
alpha, // alpha
@@ -67,17 +49,6 @@ void K2_EuclideanLossLayer<Dtype>::Backward_cpu(const vector<Blob<Dtype>*>& top,
bottom[i]->mutable_cpu_diff()); // b
}
}
*/
for (int i = 0; i < 2; ++i) {
if (propagate_down[i]) {
const Dtype sign = (i == 0) ? 1 : -1;
const Dtype alpha = sign * top[0]->cpu_diff()[0];
Dtype result = (alpha * diff_.cpu_data()[0] / bottom[i]->count()) / bottom[i]->num();
for (int j = 0; j < bottom[i]->count(); j++)
bottom[i]->mutable_cpu_diff()[j] = result;
}
}
//LOG(ERROR) << "K2: Loss weight: " << Dtype(top[0]->cpu_diff()[0]);
}

#ifdef CPU_ONLY
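
Because the prototxt above now routes "latent_sigmoid" through the 1x48 average pooling before K2_EuclideanLoss, the layer's bottom is already the per-sample mean, and the corrected layer reduces to a plain Euclidean distance to 0.5 with sign = +1. The fire-rate gradient then reaches each bit through the pooling layer: for L2 = (1/2) * (mean(a) - 0.5)^2, the chain rule gives dL2/da_k = (mean(a) - 0.5)/K, the same value for every bit. A short standalone illustration with hypothetical names, not the layer code:

// How the fire-rate gradient distributes over the K bits via the average pooling.
#include <cstdio>
#include <vector>

int main() {
  std::vector<double> a = {0.9, 0.8, 0.1, 0.3};   // toy "bits", K = 4
  const std::size_t K = a.size();
  double mean = 0.0;
  for (double v : a) mean += v;
  mean /= static_cast<double>(K);
  // dL2/d(mean) = mean - 0.5; the 1xK average pooling passes 1/K of it to each bit.
  const double grad_mean = mean - 0.5;
  for (std::size_t k = 0; k < K; ++k)
    std::printf("dL2/da_%zu = %.6f\n", k, grad_mean / static_cast<double>(K));
  return 0;
}
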
