#include <functional>
#include <utility>
#include <vector>
#include "caffe/layers/accuracy_layer.hpp"
#include "caffe/util/math_functions.hpp"
namespace caffe {
/**
 * @brief Caches the accuracy settings from the layer's accuracy_param.
 *
 * Stores top_k (a prediction counts as correct when the true class is among
 * the k highest-scoring classes) and whether an ignore label is configured.
 * ignore_label_ is only assigned when one is present.
 *
 * @param bottom Input blobs; not read here.
 * @param top    Output blobs; not read here.
 */
template <typename Dtype>
void AccuracyLayer<Dtype>::LayerSetUp(
    const vector<Blob<Dtype>*>& bottom, const vector<Blob<Dtype>*>& top) {
  // k for top-k accuracy.
  top_k_ = this->layer_param_.accuracy_param().top_k();

  // Whether some label value should be excluded from the statistics.
  has_ignore_label_ = this->layer_param_.accuracy_param().has_ignore_label();
  if (!has_ignore_label_) {
    return;
  }
  ignore_label_ = this->layer_param_.accuracy_param().ignore_label();
}
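/* Notes on the `axis` parameter:
 * `axis` selects which dimension of the prediction blob is the label (class)
 * axis. For an (N x C x H x W) blob, axis=0 makes N the class dimension;
 * axis=1 makes C the class dimension, leaving N as the outer sample count and
 * H x W as the inner sample count. As the code below shows, for axis=k:
 *   outer_num_ = product of shape[0, k)
 *   inner_num_ = product of shape[k+1, shape.size())
 * In the common case the prediction blob has shape (N x C), where N is the
 * number of samples and C is the number of classes. With axis=1 this gives
 * outer_num_ = N and inner_num_ = product of shape[2, 2) = 1 (no inner
 * dimension).
 */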
/**
 * @brief Validates the bottom shapes and sizes the output blobs.
 *
 * Splits the prediction blob (bottom[0]) around the label axis into
 * outer_num_ (product of the dimensions before the axis) and inner_num_
 * (product of the dimensions after it), then checks that the label blob
 * (bottom[1]) holds exactly one label per (outer, inner) position and that
 * top_k_ does not exceed the number of classes.
 *
 * @param bottom bottom[0] holds the predictions, bottom[1] the labels.
 * @param top    top[0] is reshaped to a scalar (overall accuracy); when
 *               present, top[1] is reshaped to a vector with one entry per
 *               class (per-class accuracy).
 */
template <typename Dtype>
void AccuracyLayer<Dtype>::Reshape(
    const vector<Blob<Dtype>*>& bottom, const vector<Blob<Dtype>*>& top) {
  // Resolve the label axis first so that every later check can use the real
  // class dimension instead of a ratio of element counts.
  label_axis_ =
      bottom[0]->CanonicalAxisIndex(this->layer_param_.accuracy_param().axis());
  outer_num_ = bottom[0]->count(0, label_axis_);
  inner_num_ = bottom[0]->count(label_axis_ + 1);
  CHECK_EQ(outer_num_ * inner_num_, bottom[1]->count())
      << "Number of labels must match number of predictions; "
      << "e.g., if label axis == 1 and prediction shape is (N, C, H, W), "
      << "label count (number of labels) must be N*H*W, "
      << "with integer values in {0, 1, ..., C-1}.";
  // Compare top_k_ against the class dimension directly. Dividing
  // bottom[0]->count() by bottom[1]->count() would divide by zero for an
  // empty label blob and would give a truncated, misleading class count when
  // the label count does not match the predictions.
  CHECK_LE(top_k_, bottom[0]->shape(label_axis_))
      << "top_k must be less than or equal to the number of classes.";
  vector<int> top_shape(0);  // Accuracy is a scalar; 0 axes.
  top[0]->Reshape(top_shape);
  if (top.size() > 1) {
    // Per-class accuracy is a vector; 1 axes.
    vector<int> top_shape_per_class(1);
    top_shape_per_class[0] = bottom[0]->shape(label_axis_);
    top[1]->Reshape(top_shape_per_class);
    // nums_buffer_ gets the same shape; Forward_cpu uses it to count the
    // samples seen per class.
    nums_buffer_.Reshape(top_shape_per_class);
  }
}
/**
 * @brief Computes top-k accuracy (and optionally per-class accuracy) on CPU.
 *
 * For every (outer, inner) position the scores along the label axis are
 * paired with their class index and partially sorted in descending order;
 * the prediction counts as correct when the true label is among the first
 * top_k_ entries. Positions whose label equals ignore_label_ (when
 * configured) are skipped and do not contribute to any statistic.
 *
 * @param bottom bottom[0] holds the prediction scores, bottom[1] the labels
 *               (stored as Dtype and cast to int).
 * @param top    top[0][0] receives correct / counted, or 0 when nothing was
 *               counted. When top[1] is present, entry c receives the
 *               accuracy over samples whose label is c, or 0 when the batch
 *               has no such sample.
 */
template <typename Dtype>
void AccuracyLayer<Dtype>::Forward_cpu(const vector<Blob<Dtype>*>& bottom,
    const vector<Blob<Dtype>*>& top) {
  Dtype accuracy = 0;
  const Dtype* bottom_data = bottom[0]->cpu_data();
  const Dtype* bottom_label = bottom[1]->cpu_data();
  const int num_labels = bottom[0]->shape(label_axis_);
  // Stride between consecutive outer positions. This equals
  // bottom[0]->count() / outer_num_ for the shapes set up in Reshape, but
  // does not divide by zero when the batch is empty.
  const int dim = num_labels * inner_num_;
  const bool per_class = top.size() > 1;
  if (per_class) {
    caffe_set(nums_buffer_.count(), Dtype(0), nums_buffer_.mutable_cpu_data());
    caffe_set(top[1]->count(), Dtype(0), top[1]->mutable_cpu_data());
  }
  // Scratch (score, class index) pairs, allocated once and reused for every
  // prediction instead of being rebuilt inside the loops.
  std::vector<std::pair<Dtype, int> > bottom_data_vector;
  bottom_data_vector.reserve(num_labels);
  int count = 0;
  for (int i = 0; i < outer_num_; ++i) {
    for (int j = 0; j < inner_num_; ++j) {
      const int label_value =
          static_cast<int>(bottom_label[i * inner_num_ + j]);
      if (has_ignore_label_ && label_value == ignore_label_) {
        continue;
      }
      // Validate the label before it is used as an index below.
      DCHECK_GE(label_value, 0);
      DCHECK_LT(label_value, num_labels);
      // Per-class sample count, the denominator of per-class accuracy.
      if (per_class) ++nums_buffer_.mutable_cpu_data()[label_value];
      // Top-k accuracy: collect the score of every class at this position.
      bottom_data_vector.clear();
      for (int k = 0; k < num_labels; ++k) {
        bottom_data_vector.push_back(std::make_pair(
            bottom_data[i * dim + k * inner_num_ + j], k));
      }
      // Only the first top_k_ entries need to be in descending order.
      std::partial_sort(
          bottom_data_vector.begin(), bottom_data_vector.begin() + top_k_,
          bottom_data_vector.end(), std::greater<std::pair<Dtype, int> >());
      // check if true label is in top k predictions
      for (int k = 0; k < top_k_; k++) {
        if (bottom_data_vector[k].second == label_value) {
          ++accuracy;
          if (per_class) ++top[1]->mutable_cpu_data()[label_value];
          break;
        }
      }
      ++count;
    }
  }
  // When every label was ignored (or the batch is empty) nothing was counted;
  // report 0 instead of the NaN that 0 / 0 would produce.
  top[0]->mutable_cpu_data()[0] = (count == 0) ? Dtype(0) : accuracy / count;
  if (per_class) {
    Dtype* per_class_accuracy = top[1]->mutable_cpu_data();
    const Dtype* per_class_total = nums_buffer_.cpu_data();
    for (int c = 0; c < top[1]->count(); ++c) {
      // A class with no samples in this batch gets accuracy 0.
      per_class_accuracy[c] = per_class_total[c] == 0
          ? Dtype(0)
          : per_class_accuracy[c] / per_class_total[c];
    }
  }
  // Accuracy layer should not be used as a loss function.
}
// Caffe macro, defined outside this file; presumably emits the explicit
// template instantiations of AccuracyLayer for the supported Dtype types.
// NOTE(review): confirm against the macro definition.
INSTANTIATE_CLASS(AccuracyLayer);
// Caffe macro, defined outside this file; presumably registers the layer
// under the type name "Accuracy" with the layer factory.
// NOTE(review): confirm against the macro definition.
REGISTER_LAYER_CLASS(Accuracy);
} // namespace caffe