openpose训练代码(一): http://blog.csdn.net/u011956147/article/details/79292026
openpose训练代码(二):http://blog.csdn.net/u011956147/article/details/79292734
在上一篇openpose训练代码(一) 中讲到cpm_data_transformer,其实这个文件才是包含数据处理核心代码的文件,在上一篇博客中提到了Transform_nv函数,我们先来看看Transform_nv函数:
// Shape-checking wrapper ("function 2"): validates the datum/blob dimensions,
// then forwards the raw CPU pointers to the pointer-based Transform_nv overload
// ("function 1") that performs the actual augmentation and label generation.
// NOTE(review): the numeric literals in this listing were stripped by the blog
// platform; they are restored here from the upstream CPM data transformer
// (cpm_data_transformer.cpp) — verify against your local copy.
template<typename Dtype> void CPMDataTransformer<Dtype>::Transform_nv(const Datum& datum, Blob<Dtype>* transformed_data, Blob<Dtype>* transformed_label, int cnt) {
  //std::cout << "Function 2 is used"; std::cout.flush();
  const int datum_channels = datum.channels();
  //const int datum_height = datum.height();
  //const int datum_width = datum.width();
  const int im_channels = transformed_data->channels();
  //const int im_height = transformed_data->height();
  //const int im_width = transformed_data->width();
  const int im_num = transformed_data->num();
  //const int lb_channels = transformed_label->channels();
  //const int lb_height = transformed_label->height();
  //const int lb_width = transformed_label->width();
  const int lb_num = transformed_label->num();
  //LOG(INFO) << "image shape: " << transformed_data->num() << " " << transformed_data->channels() << " "
  //          << transformed_data->height() << " " << transformed_data->width();
  //LOG(INFO) << "label shape: " << transformed_label->num() << " " << transformed_label->channels() << " "
  //          << transformed_label->height() << " " << transformed_label->width();
  // datum pages: 3 image (BGR) + 1 metadata + mask_miss + mask_all = 6
  // (assumes the 6-channel lmdb layout produced by the python encoder — TODO confirm)
  CHECK_EQ(datum_channels, 6);
  CHECK_EQ(im_channels, 3);   // network input is a plain 3-channel image
  ///CHECK_EQ(im_channels, 4);
  //CHECK_EQ(datum_channels, 4);
  CHECK_EQ(im_num, lb_num);
  //CHECK_LE(im_height, datum_height);
  //CHECK_LE(im_width, datum_width);
  CHECK_GE(im_num, 1);
  //const int crop_size = param_.crop_size();
  // if (crop_size) {
  //   CHECK_EQ(crop_size, im_height);
  //   CHECK_EQ(crop_size, im_width);
  // } else {
  //   CHECK_EQ(datum_height, im_height);
  //   CHECK_EQ(datum_width, im_width);
  // }
  Dtype* transformed_data_pointer = transformed_data->mutable_cpu_data();
  Dtype* transformed_label_pointer = transformed_label->mutable_cpu_data();
  CPUTimer timer;
  timer.Start();
  Transform_nv(datum, transformed_data_pointer, transformed_label_pointer, cnt); //call function 1
  VLOG(2) << "Transform_nv: " << timer.MicroSeconds() / 1000.0 << " ms";
}
这个函数主要就是得到lmdb的一些参数,比如datum_channels,im_channels 等,转而调用Transform_nv函数
// Pointer-based overload ("function 1"): does the real work — decodes the image
// and masks from the datum, reads the metadata, applies augmentation, and fills
// transformed_data / transformed_label. The body is elided here ("...") and is
// presented piece by piece in the text below.
template<typename Dtype> void CPMDataTransformer<Dtype>::Transform_nv(const Datum& datum, Dtype* transformed_data, Dtype* transformed_label, int cnt) {
...
}
data是lmdb数据的首地址;datum_channels、datum_height、datum_width 分别是之前python代码确定的通道数和每页的尺寸;mask_miss 初始化为全1矩阵、mask_all 初始化为全0矩阵,为后续使用做准备。
// --- excerpt from the body of Transform_nv (function 1) ---
// Pull the raw byte buffer and page geometry out of the datum, then allocate
// the image and mask matrices that the decode loop below will fill.
const string& data = datum.data();
const int datum_channels = datum.channels();
const int datum_height = datum.height();
const int datum_width = datum.width();
// To do: make this a parameter in caffe.proto
//const int mode = 5; //related to datum.channels();
// NOTE(review): literal lost in the blog; 6 restored because the mask_all
// (mode == 6) branches below are active — confirm against your repo.
const int mode = 6;
//const int crop_size = param_.crop_size();
//const Dtype scale = param_.scale();
//const bool do_mirror = param_.mirror() && Rand(2);
//const bool has_mean_file = param_.has_mean_file();
const bool has_uint8 = data.size() > 0;
//const bool has_mean_values = mean_values_.size() > 0;
int crop_x = param_.crop_size_x();
int crop_y = param_.crop_size_y();
CHECK_GT(datum_channels, 0);
//CHECK_GE(datum_height, crop_size);
//CHECK_GE(datum_width, crop_size);
CPUTimer timer1;
timer1.Start();
//before any transformation, get the image from datum
Mat img = Mat::zeros(datum_height, datum_width, CV_8UC3);
Mat mask_all, mask_miss;
if(mode >= 5){
  mask_miss = Mat::ones(datum_height, datum_width, CV_8UC1);
}
if(mode == 6){
  mask_all = Mat::zeros(datum_height, datum_width, CV_8UC1);
}
读取原始图片数据保存在rgb中,以及读取mask_miss 和 mask_all,如下:
offset = img.rows * img.cols,为指针偏移量,和python文件一一对应。
int offset = img.rows * img.cols;
int dindex;
Dtype d_element;
for (int i = ; i < img.rows; ++i) {
for (int j = ; j < img.cols; ++j) {
Vec3b& rgb = img.at<Vec3b>(i, j);
for(int c = ; c < ; c++){
dindex = c*offset + i*img.cols + j;
if (has_uint8)
d_element = static_cast<Dtype>(static_cast<uint8_t>(data[dindex]));
else
d_element = datum.float_data(dindex);
rgb[c] = d_element;
}
if(mode >= ){
dindex = *offset + i*img.cols + j;
if (has_uint8)
d_element = static_cast<Dtype>(static_cast<uint8_t>(data[dindex]));
else
d_element = datum.float_data(dindex);
if (round(d_element/)!= && round(d_element/)!=){
cout << d_element << " " << round(d_element/) << endl;
}
mask_miss.at<uchar>(i, j) = d_element; //round(d_element/255);
}
if(mode == ){
dindex = *offset + i*img.cols + j;
if (has_uint8)
d_element = static_cast<Dtype>(static_cast<uint8_t>(data[dindex]));
else
d_element = datum.float_data(dindex);
mask_all.at<uchar>(i, j) = d_element;
}
}
}
VLOG() << " rgb[:] = datum: " << timer1.MicroSeconds()/ << " ms";
timer1.Start();
接下来开始读meta文件,就是存储的关键点和尺寸信息,其中关键的是ReadMetaData函数,这个函数就是完全按照python写入格式来读的,所以,一定要理清楚python代码的逻辑,不然,这里很容易混乱,同时,这里有一个小的技巧就是转换了关键点的顺序,TransformMetaJoints函数实现这一功能,其实就是为了和MPII数据集对应,我的理解是方便transfer 权重,代码如下:
// --- excerpt from the body of Transform_nv (function 1) ---
//color, contract (optional color/contrast preprocessing)
if(param_.do_clahe())
  clahe(img, clahe_tileSize, clahe_clipLimit);
if(param_.gray() == 1){
  // gray-scale augmentation: drop the color but keep a 3-channel image
  cv::cvtColor(img, img, CV_BGR2GRAY);
  cv::cvtColor(img, img, CV_GRAY2BGR);
}
VLOG(2) << "  color: " << timer1.MicroSeconds()/1000.0 << " ms";
timer1.Start();
// Metadata page starts right after the 3 image pages; one row = datum_width bytes.
int offset3 = 3 * offset;
int offset1 = datum_width;
int stride = param_.stride();
// Read keypoints/scale info written by the python encoder, then remap the joint
// order (TransformMetaJoints) to the network's expected ordering.
ReadMetaData(meta, data, offset3, offset1);
if(param_.transform_body_joint()) // we expect to transform body joints, and not to transform hand joints
  TransformMetaJoints(meta);
VLOG(2) << "  ReadMeta+MetaJoints: " << timer1.MicroSeconds()/1000.0 << " ms";
读取到原始数据后,接下来做的就是数据增广,原始代码主要做了如下几种数据增广:scale、rotate、crop、flip;具体实现如下,每做一个都是叠加在原来的基础上,这里在做数据增广的时候,用到了原图scale的信息:
//Start transforming
Mat img_aug = Mat::zeros(crop_y, crop_x, CV_8UC3);
Mat mask_miss_aug, mask_all_aug ;
//Mat mask_miss_aug = Mat::zeros(crop_y, crop_x, CV_8UC1);
//Mat mask_all_aug = Mat::zeros(crop_y, crop_x, CV_8UC1);
Mat img_temp, img_temp2, img_temp3; //size determined by scale
VLOG() << " input size (" << img.cols << ", " << img.rows << ")";
// We only do random transform as augmentation when training.
if (phase_ == TRAIN) {
as.scale = augmentation_scale(img, img_temp, mask_miss, mask_all, meta, mode);
//LOG(INFO) << meta.joint_self.joints.size();
//LOG(INFO) << meta.joint_self.joints[];
as.degree = augmentation_rotate(img_temp, img_temp2, mask_miss, mask_all, meta, mode);
//LOG(INFO) << meta.joint_self.joints.size();
//LOG(INFO) << meta.joint_self.joints[];
if( && param_.visualize())
visualize(img_temp2, meta, as);
as.crop = augmentation_croppad(img_temp2, img_temp3, mask_miss, mask_miss_aug, mask_all, mask_all_aug, meta, mode);
//LOG(INFO) << meta.joint_self.joints.size();
//LOG(INFO) << meta.joint_self.joints[];
if( && param_.visualize())
visualize(img_temp3, meta, as);
as.flip = augmentation_flip(img_temp3, img_aug, mask_miss_aug, mask_all_aug, meta, mode);
//LOG(INFO) << meta.joint_self.joints.size();
//LOG(INFO) << meta.joint_self.joints[];
if(param_.visualize())
visualize(img_aug, meta, as);
// imshow("img_aug", img_aug);
// Mat label_map = mask_miss_aug;
// applyColorMap(label_map, label_map, COLORMAP_JET);
// addWeighted(label_map, , img_aug, , , label_map);
// imshow("mask_miss_aug", label_map);
if (mode > ){
resize(mask_miss_aug, mask_miss_aug, Size(), /stride, /stride, INTER_CUBIC);
}
if (mode > ){
resize(mask_all_aug, mask_all_aug, Size(), /stride, /stride, INTER_CUBIC);
}
}
else {
img_aug = img.clone();
as.scale = ;
as.crop = Size();
as.flip = ;
as.degree = ;
}
VLOG() << " Aug: " << timer1.MicroSeconds()/ << " ms";
timer1.Start();
数据增广过后就是归一化,和准备label文件,有一点不同的地方就是负责背景关键点的那一个label使用的是mask_miss信息,同时,把输入归一化到 [-0.5, 0.5] 具体如下:
// --- excerpt from the body of Transform_nv (function 1) ---
// Mean-subtract and scale the augmented image into roughly [-0.5, 0.5) and
// copy it channel-major into the network input.
for (int i = 0; i < img_aug.rows; ++i) {
  for (int j = 0; j < img_aug.cols; ++j) {
    Vec3b& rgb = img_aug.at<Vec3b>(i, j);
    transformed_data[0*offset + i*img_aug.cols + j] = (rgb[0] - 128)/256.0;
    transformed_data[1*offset + i*img_aug.cols + j] = (rgb[1] - 128)/256.0;
    transformed_data[2*offset + i*img_aug.cols + j] = (rgb[2] - 128)/256.0;
  }
}
// label size is image size/ stride
// Fill the per-channel loss weights from mask_miss: channels 0..np-1 per part,
// channel np for background. (offset/grid_x/grid_y/channelOffset/np are set
// earlier in the full function body — not shown in this excerpt.)
if (mode > 4){
  for (int g_y = 0; g_y < grid_y; g_y++){
    for (int g_x = 0; g_x < grid_x; g_x++){
      for (int i = 0; i < np; i++){
        float weight = float(mask_miss_aug.at<uchar>(g_y, g_x)) /255; //mask_miss_aug.at<uchar>(i, j);
        // NOTE(review): presumably visibility 3 marks "part absent from this
        // dataset" and keeps its weight untouched — confirm against ReadMetaData.
        if (meta.joint_self.isVisible[i] != 3){
          transformed_label[i*channelOffset + g_y*grid_x + g_x] = weight;
        }
      }
      // background channel
      if(mode == 5){
        transformed_label[np*channelOffset + g_y*grid_x + g_x] = float(mask_miss_aug.at<uchar>(g_y, g_x)) /255;
      }
      if(mode > 5){
        transformed_label[np*channelOffset + g_y*grid_x + g_x] = 1;
        transformed_label[(2*np+1)*channelOffset + g_y*grid_x + g_x] = float(mask_all_aug.at<uchar>(g_y, g_x)) /255;
      }
    }
  }
}
做完上面的工作,把图片数据准备好,背景关键点准备好,就剩下其它关键点和PAF的label了,主要是在generateLabelMap函数中完成。
// --- excerpt from the body of Transform_nv (function 1) ---
//putGaussianMaps(transformed_data + 3*offset, meta.objpos, 1, img_aug.cols, img_aug.rows, param_.sigma_center());
//LOG(INFO) << "image transformation done!";
// Generate the remaining labels: per-part Gaussian heatmaps and PAF vector maps.
generateLabelMap(transformed_label, img_aug, meta);
VLOG(2) << "  putGauss+genLabel: " << timer1.MicroSeconds()/1000.0 << " ms";
//starts to visualize everything (transformed_data in 4 ch, label) fed into conv1
//if(param_.visualize()){
//dumpEverything(transformed_data, transformed_label, meta);
//}
具体的,我们来看一下generateLabelMap函数,大概的说来,主要就是做两件事,其一是在每个关键点部位放置高斯响应,其二就是在有连接的关键点之间放vector,更具体的细节,可以去查阅源代码,这里不再做更为详细的说明:
// Fills the "label" half of transformed_label for one sample:
//   channels np+1 .. np+2k      : PAF x/y pairs for each limb
//   channels after the PAFs     : per-part Gaussian heatmaps
//   channel  2*np+1             : background = max(1 - max(part heatmaps), 0)
// Channels 0..np (the per-channel loss weights) were written by the caller.
// NOTE(review): the blog's markdown ate the underscores in "param_." and all
// numeric literals; both are restored from the upstream CPM data transformer —
// verify against your local cpm_data_transformer.cpp.
template<typename Dtype>
void CPMDataTransformer<Dtype>::generateLabelMap(Dtype* transformed_label, Mat& img_aug, MetaData meta) {
  int rezX = img_aug.cols;
  int rezY = img_aug.rows;
  int stride = param_.stride();
  int grid_x = rezX / stride;
  int grid_y = rezY / stride;
  int channelOffset = grid_y * grid_x;
  // NOTE(review): literal lost; 6 chosen to match Transform_nv's mode so that
  // the mask_all channel (2*np+1) written by the caller is not cleared — confirm.
  int mode = 6; // TO DO: make this as a parameter
  // Clear the label channels (they may still hold the previous batch).
  for (int g_y = 0; g_y < grid_y; g_y++){
    for (int g_x = 0; g_x < grid_x; g_x++){
      for (int i = np+1; i < 2*(np+1); i++){
        if (mode == 6 && i == (2*np + 1))
          continue;   // keep mask_all written by the caller
        transformed_label[i*channelOffset + g_y*grid_x + g_x] = 0;
      }
    }
  }
  if (np == 56){  // COCO configuration: 18 parts + 19 limbs (2 PAF channels each)
    // Gaussian heatmaps for self and every other person; heatmaps start after
    // the 38 PAF channels, i.e. at channel np+39.
    for (int i = 0; i < 18; i++){
      Point2f center = meta.joint_self.joints[i];
      if(meta.joint_self.isVisible[i] <= 1){
        putGaussianMaps(transformed_label + (i+np+39)*channelOffset, center, param_.stride(),
                        grid_x, grid_y, param_.sigma()); //self
      }
      for(int j = 0; j < meta.numOtherPeople; j++){ //for every other person
        Point2f center = meta.joint_others[j].joints[i];
        if(meta.joint_others[j].isVisible[i] <= 1){
          putGaussianMaps(transformed_label + (i+np+39)*channelOffset, center, param_.stride(),
                          grid_x, grid_y, param_.sigma());
        }
      }
    }
    // Limb endpoint pairs (1-indexed COCO part ids) for the 19 PAF channels.
    int mid_1[19] = {2, 9,  10, 2,  12, 13, 2, 3, 4, 3,  2, 6, 7, 6,  2, 1,  1,  15, 16};
    int mid_2[19] = {9, 10, 11, 12, 13, 14, 3, 4, 5, 17, 6, 7, 8, 18, 1, 15, 16, 17, 18};
    int thre = 1;
    for(int i=0;i<19;i++){
      Mat count = Mat::zeros(grid_y, grid_x, CV_8UC1);
      Joints jo = meta.joint_self;
      if(jo.isVisible[mid_1[i]-1]<=1 && jo.isVisible[mid_2[i]-1]<=1){
        //putVecPeaks
        putVecMaps(transformed_label + (np+1+ 2*i)*channelOffset, transformed_label + (np+2+ 2*i)*channelOffset,
                   count, jo.joints[mid_1[i]-1], jo.joints[mid_2[i]-1], param_.stride(), grid_x, grid_y, param_.sigma(), thre); //self
      }
      for(int j = 0; j < meta.numOtherPeople; j++){ //for every other person
        Joints jo2 = meta.joint_others[j];
        if(jo2.isVisible[mid_1[i]-1]<=1 && jo2.isVisible[mid_2[i]-1]<=1){
          //putVecPeaks
          putVecMaps(transformed_label + (np+1+ 2*i)*channelOffset, transformed_label + (np+2+ 2*i)*channelOffset,
                     count, jo2.joints[mid_1[i]-1], jo2.joints[mid_2[i]-1], param_.stride(), grid_x, grid_y, param_.sigma(), thre); //self
        }
      }
    }
    //put background channel
    for (int g_y = 0; g_y < grid_y; g_y++){
      for (int g_x = 0; g_x < grid_x; g_x++){
        float maximum = 0;
        //second background channel: max over the 18 part heatmaps
        for (int i = np+39; i < np+57; i++){
          maximum = (maximum > transformed_label[i*channelOffset + g_y*grid_x + g_x]) ? maximum : transformed_label[i*channelOffset + g_y*grid_x + g_x];
        }
        transformed_label[(2*np+1)*channelOffset + g_y*grid_x + g_x] = max(1.0-maximum, 0.0);
      }
    }
    //LOG(INFO) << "background put";
  }
  else if (np == 43){  // MPII-style configuration: 15 parts + 14 limbs
    for (int i = 0; i < 15; i++){
      Point2f center = meta.joint_self.joints[i];
      if(meta.joint_self.isVisible[i] <= 1){
        putGaussianMaps(transformed_label + (i+np+29)*channelOffset, center, param_.stride(),
                        grid_x, grid_y, param_.sigma()); //self
      }
      for(int j = 0; j < meta.numOtherPeople; j++){ //for every other person
        Point2f center = meta.joint_others[j].joints[i];
        if(meta.joint_others[j].isVisible[i] <= 1){
          putGaussianMaps(transformed_label + (i+np+29)*channelOffset, center, param_.stride(),
                          grid_x, grid_y, param_.sigma());
        }
      }
    }
    // Limb endpoint pairs (0-indexed part ids) for the 14 PAF channels.
    int mid_1[14] = {0, 1, 2, 3, 1, 5, 6, 1,  14, 8, 9,  14, 11, 12};
    int mid_2[14] = {1, 2, 3, 4, 5, 6, 7, 14, 8,  9, 10, 11, 12, 13};
    int thre = 1;
    for(int i=0;i<14;i++){
      Mat count = Mat::zeros(grid_y, grid_x, CV_8UC1);
      Joints jo = meta.joint_self;
      if(jo.isVisible[mid_1[i]]<=1 && jo.isVisible[mid_2[i]]<=1){
        //putVecPeaks
        putVecMaps(transformed_label + (np+1+ 2*i)*channelOffset, transformed_label + (np+2+ 2*i)*channelOffset,
                   count, jo.joints[mid_1[i]], jo.joints[mid_2[i]], param_.stride(), grid_x, grid_y, param_.sigma(), thre); //self
      }
      for(int j = 0; j < meta.numOtherPeople; j++){ //for every other person
        Joints jo2 = meta.joint_others[j];
        if(jo2.isVisible[mid_1[i]]<=1 && jo2.isVisible[mid_2[i]]<=1){
          //putVecPeaks
          putVecMaps(transformed_label + (np+1+ 2*i)*channelOffset, transformed_label + (np+2+ 2*i)*channelOffset,
                     count, jo2.joints[mid_1[i]], jo2.joints[mid_2[i]], param_.stride(), grid_x, grid_y, param_.sigma(), thre); //self
        }
      }
    }
    //put background channel
    for (int g_y = 0; g_y < grid_y; g_y++){
      for (int g_x = 0; g_x < grid_x; g_x++){
        float maximum = 0;
        //second background channel: max over the 15 part heatmaps
        for (int i = np+29; i < np+44; i++){
          maximum = (maximum > transformed_label[i*channelOffset + g_y*grid_x + g_x]) ? maximum : transformed_label[i*channelOffset + g_y*grid_x + g_x];
        }
        transformed_label[(2*np+1)*channelOffset + g_y*grid_x + g_x] = max(1.0-maximum, 0.0);
      }
    }
    //LOG(INFO) << "background put";
  }
  //visualize (debug dump of every label channel; disabled)
  if(0 && param_.visualize()){
    Mat label_map;
    for(int i = 0; i < 2*(np+1); i++){
      label_map = Mat::zeros(grid_y, grid_x, CV_8UC1);
      //int MPI_index = MPI_to_ours[i];
      //Point2f center = meta.joint_self.joints[MPI_index];
      for (int g_y = 0; g_y < grid_y; g_y++){
        //printf("\n");
        for (int g_x = 0; g_x < grid_x; g_x++){
          label_map.at<uchar>(g_y,g_x) = (int)(transformed_label[i*channelOffset + g_y*grid_x + g_x]*255);
          //printf("%f ", transformed_label_entry[g_y*grid_x + g_x]*255);
        }
      }
      resize(label_map, label_map, Size(), stride, stride, INTER_LINEAR);
      applyColorMap(label_map, label_map, COLORMAP_JET);
      addWeighted(label_map, 0.5, img_aug, 0.5, 0.0, label_map);
      //center = center * (1.0/(float)param_.stride());
      //circle(label_map, center, 3, CV_RGB(255,0,255), -1);
      char imagename [100];
      sprintf(imagename, "augment_%04d_label_part_%02d.jpg", meta.write_number, i);
      //LOG(INFO) << "filename is " << imagename;
      imwrite(imagename, label_map);
    }
  }
}
原文链接:http://blog.csdn.net/u011956147/article/details/79292734