時空上下文視覺跟蹤(STC)算法的解讀與代碼複現
[email protected]
http://blog.csdn.net/zouxy09
本博文主要是關注一篇視覺跟蹤的論文。這篇論文是Kaihua Zhang等人今年投稿到一個會議的文章,因為會議還沒有出結果,是以作者還沒有釋出他的Matlab源代碼。但為了讓我們先睹為快,作者把論文放在arxiv這個網站上面供大家下載下傳了。對于裡面所描述的神奇的效果,大家都躍躍欲試,也有人将其複現了。我這裡也花了一天的時間去複現了單尺度的C++版本,主要是基于OpenCV。多尺度的有點複雜,這個後面再做考慮了。另外,能力有限,論文解讀和代碼實作可能會出現錯誤,是以如果代碼裡面出現錯誤,還望大家不吝指點。
論文見:
Kaihua Zhang, Lei Zhang, Ming-Hsuan Yang,and David Zhang, Fast Trackingvia Spatio-Temporal Context Learning
目前作者已公開了支援多尺度的Matlab代碼了哦。可以到以下網址下載下傳:
http://www4.comp.polyu.edu.hk/~cslzhang/STC/STC.htm
一、概述
該論文提出一種簡單卻非常有效的視覺跟蹤方法。更迷人的一點是,它速度很快,原作者實作的Matlab代碼在i7的電腦上達到350fps。
該論文的關鍵點是對時空上下文(Spatio-Temporal Context)資訊的利用。主要思想是通過貝葉斯架構對要跟蹤的目标和它的局部上下文區域的時空關系進行模組化,得到目标和其周圍區域低級特征的統計相關性。然後綜合這一時空關系和生物視覺系統上的focus of attention特性來評估新的一幀中目标出現位置的置信圖,置信最大的位置就是我們得到的新的一幀的目标位置。另外,時空模型的學習和目标的檢測都是通過FFT(傅裡葉變換)來實作,是以學習和檢測的速度都比較快。
二、工作過程
具體過程見下圖:

(1)t幀:在該幀目标(第一幀由人工指定)已經知道的情況下,我們計算得到一個目标的置信圖(Confidence Map,也就是目标的似然)。通過生物視覺系統上的focus of attention特性我們可以得到另一張機率圖(先驗機率)。通過對這兩個機率圖的傅裡葉變換做除再反傅裡葉變換,就可以得到模組化目标和周圍背景的空間相關性的空間上下文模型(條件機率)。然後我們用這個模型去更新跟蹤下一幀需要的時空上下文模型(可能這裡還不太能了解,看到後面的相關理論分析和算法描述後可能會清晰一點)。
(2)t+1幀:利用t幀的上下文資訊(時空上下文模型),卷積圖像得到一個目标的置信圖,值最大的位置就是我們的目标所在地。或者了解為圖像各個地方對該上下文資訊的響應,響應最大的地方自然就是滿足這個上下文的地方,也就是目标了。
三、相關理論描述
3.1、上下文的重要性
時間和空間上的上下文資訊對跟蹤來說是非常重要的。雖然對跟蹤,我們一直利用了時間上的上下文資訊(用t去跟蹤t+1等),但對空間上下文資訊的利用卻比較匮乏。為什麼空間上下文資訊會重要呢?考慮我們人,例如我們需要在人群中識别某個人臉(衆裡尋他千百度),那我們為什麼隻關注它的臉呢?如果這個人穿的衣服啊帽子和其他人不一樣,那麼這時候的識别和跟蹤就會更加容易和魯棒。或者場景中這個人和其他的東西有一定的關系,例如他靠在一棵樹上,那麼他和樹就存在了一定的關系,而樹在場景中是不會動的(除非你搖動攝像頭了),那我們借助樹來輔助找到這個人是不是比單單去找這個人要容易,特别是人被部分遮擋住的時候。還有一些就是如果這個人帶着女朋友(有其他物體陪着一起運動),那麼可以将他們看成一個集合結構,作為一組進行跟蹤,這樣會比跟蹤他們其中一個要容易。
總之,一個目标很少與整個場景隔離或者沒有任何聯系,因為總存在一些和目标運動存在短時或者長時相關的目标。這種空間上下文的相關性就是我們可以利用的。
在視覺跟蹤,局部上下文包括一個目标和它的附近的一定區域的背景。因為,在連續幀間目标周圍的局部場景其實存在着很強的時空關系。例如,上圖中的目标存在着嚴重的阻擋,導緻目标的外觀發生了很大的變化。然而,因為隻有小部分的上下文區域是被阻擋的,整體的上下問區域是保持相似的,是以該目标的局部上下文不會發生很大的變化。是以,目前幀局部上下文會有助于幫助預測下一幀中目标的位置。圖中,黃色框的是目标,然後它和它的周圍區域,也就是紅色框包圍的區域,就是該目标的上下文區域。左:雖然出現嚴重的阻擋導緻目标的外觀發現很大的變化,但目标中心(由黃色的點表示)和其上下文區域中的周圍區域的其他位置(由紅色點表示)的空間關系幾乎沒有發生什麼變化。中:學習到的時空上下文模型(藍色方框内的區域具有相似的值,表示該區域與目标中心具有相似的空間關系)。右:學習到的置信圖。
時間資訊:鄰近幀間目标變化不會很大。位置也不會發生突變。
空間資訊:目标和目标周圍的背景存在某種特定的關系,當目标的外觀發生很大變化時,這種關系可以幫助區分目标和背景。
對目标這兩個資訊的組合就是時空上下文資訊,該論文就是利用這兩個資訊來進行對阻擋等魯棒并且快速的跟蹤。
3.2、具體細節
跟蹤問題可以描述為計算一個估計目标位置x似然的置信圖:
置信圖c(x)最大的那個位置x*就是目标的位置。從公式上可以看到,似然函數可以分解為兩個機率部分。一個是模組化目标與周圍上下文資訊的空間關系的條件機率P(x|c(z),o),一個是模組化局部上下文各個點x的上下文先驗機率P(c(x)|o)。而條件機率P(x|c(z),o),也就是目标位置和它的空間上下文的關系我們需要學習出來。
(1)Spatial Context Model 空間上下文模型
空間上下文模型描述的是條件機率函數:
hsc(x-z)是一個關于目标x和局部上下文位置z的相對距離和方向的函數,它編碼了目标和它的空間上下文的空間關系。需要注意的是,這個函數并不是徑向對稱的。這有助于分辨二義性。例如圖三,左眼和右眼相對于位置x*來說他們的距離是一樣的,但相對位置也就是方向是不一樣的。是以他們會有不一樣的空間關系。這樣就對防止誤跟蹤有幫助。
另外,這個模型是通過線上學習得到的。随着跟蹤的進行不斷更新。
(2)Context Prior Model 上下文先驗模型
這是先驗機率,模組化為:
其中I(z)是點z的灰階,描述的是這個上下文z的外觀。w是一個權重函數,z離x越近,權值越大。定義如下:
這個權重函數是由生物視覺系統的focus of attention 啟發得到的,它表示人看東西的時候,會聚焦在一個确定的圖像區域。通俗的來說,就是離我們的目标越近的點,會越受關注,越遠就不好意思了,你的光芒會被無情的忽略掉。那多遠的距離會被多大程度的忽略呢?這就得看參數sigma(相當于高斯權重函數的方差)了,這個值越大,越多的風景映入眼簾,祖國大好河山,盡收眼底。如果這個值越小,那就相當于坐井觀天了。
(3)Confidence Map 置信圖
定義為:
這個公式的參數β是很重要的,太大太小效果可能差之千裡。具體分析見原論文。這個置信圖是在給定目标的位置x*的基礎上,我們通過這個公式來計算得到上下文區域任何一點x的似然得到的。
(4)時空模型的快速學習
我們需要基于上下文先驗模型和置信圖來學習這個時空模型:
裡面的卷積可以通過FFT來加速(時域的卷積相當于頻域的乘積),具體如下:
這樣,我們就可以通過兩個FFT和一個IFFT來學習我們要的空間上下文模型了:
然後我們用這個模型去更新時空上下文模型:
(4)最後的跟蹤
得到時空上下文模型後,我們就可以在新的一幀計算目标的置信圖了:
同樣是通過FFT來加速。然後置信圖中值最大的位置,就是我們的目标位置了。
(5)多尺度的實作
多尺度可以通過調整方差sigma來實作。具體分析見原論文。(感覺這個是很remarkable的一點)。尺度和方差sigma的更新如下:
四、算法描述
簡單的算法描述如下,程式設計實作其實也是這個過程。(另外,不知道我的尺度更新的位置對不對,望指點)
(1)t幀:
根據該幀圖像I和得到的目标位置x*。順序進行以下計算:
1)學習空間上下文模型:
2)更新跟蹤下一幀目标需要的時空上下文模型:
3)更新尺度等參數:
(2)t+1幀:
1)計算置信圖:
2)找到最大值,這個最大值的位置就是我們要求的目标位置:
五、代碼實作
我的代碼是基于VS2010+OpenCV2.4.2的(暫時還沒加入邊界處理,也就是跟蹤框到達圖像邊緣的時候程式就會出錯)。代碼可以讀入視訊,也可以讀攝像頭,兩者的選擇隻需要在代碼中稍微修改即可。對于視訊來說,運作會先顯示第一幀,然後我們用滑鼠框選要跟蹤的目标,然後跟蹤器開始跟蹤每一幀。對攝像頭來說,就會一直采集圖像,然後我們用滑鼠框選要跟蹤的目标,接着跟蹤器開始跟蹤後面的每一幀。
另外,為了消去光照的影響,需要先對圖像去均值化,還需要加Hamming窗以減少圖像邊緣對FFT帶來的頻率影響。Hamming窗如下:
另外,OpenCV沒有複數(FFT後是複數)的乘除運算,是以需要自己編寫,參考如下:
複數除法:
複數乘法:
具體代碼如下:
STCTracker.h
[cpp] view plain copy
- // Fast object tracking algorithm
- // Author : zouxy
- // Date : 2013-11-21
- // HomePage : http://blog.csdn.net/zouxy09
- // Email : [email protected]
- // Reference: Kaihua Zhang, et al. Fast Tracking via Spatio-Temporal Context Learning
- // HomePage : http://www4.comp.polyu.edu.hk/~cskhzhang/
- // Email: [email protected]
- #pragma once
- #include <opencv2/opencv.hpp>
- using namespace cv;
- using namespace std;
- class STCTracker
- {
- public:
- STCTracker();
- ~STCTracker();
- void init(const Mat frame, const Rect box);
- void tracking(const Mat frame, Rect &trackBox);
- private:
- void createHammingWin();
- void complexOperation(const Mat src1, const Mat src2, Mat &dst, int flag = 0);
- void getCxtPriorPosteriorModel(const Mat image);
- void learnSTCModel(const Mat image);
- private:
- double sigma; // scale parameter (variance)
- double alpha; // scale parameter
- double beta; // shape parameter
- double rho; // learning parameter
- Point center; // the object position
- Rect cxtRegion; // context region
- Mat cxtPriorPro; // prior probability
- Mat cxtPosteriorPro; // posterior probability
- Mat STModel; // conditional probability
- Mat STCModel; // spatio-temporal context model
- Mat hammingWin; // Hamming window
- };
STCTracker.cpp
[cpp] view plain copy
- // Fast object tracking algorithm
- // Author : zouxy
- // Date : 2013-11-21
- // HomePage : http://blog.csdn.net/zouxy09
- // Email : [email protected]
- // Reference: Kaihua Zhang, et al. Fast Tracking via Spatio-Temporal Context Learning
- // HomePage : http://www4.comp.polyu.edu.hk/~cskhzhang/
- // Email: [email protected]
- #include "STCTracker.h"
- STCTracker::STCTracker()
- {
- }
- STCTracker::~STCTracker()
- {
- }
- void STCTracker::createHammingWin()
- {
- for (int i = 0; i < hammingWin.rows; i++)
- {
- for (int j = 0; j < hammingWin.cols; j++)
- {
- hammingWin.at<double>(i, j) = (0.54 - 0.46 * cos( 2 * CV_PI * i / hammingWin.rows ))
- * (0.54 - 0.46 * cos( 2 * CV_PI * j / hammingWin.cols ));
- }
- }
- }
- void STCTracker::complexOperation(const Mat src1, const Mat src2, Mat &dst, int flag)
- {
- CV_Assert(src1.size == src2.size);
- CV_Assert(src1.channels() == 2);
- Mat A_Real, A_Imag, B_Real, B_Imag, R_Real, R_Imag;
- vector<Mat> planes;
- split(src1, planes);
- planes[0].copyTo(A_Real);
- planes[1].copyTo(A_Imag);
- split(src2, planes);
- planes[0].copyTo(B_Real);
- planes[1].copyTo(B_Imag);
- dst.create(src1.rows, src1.cols, CV_64FC2);
- split(dst, planes);
- R_Real = planes[0];
- R_Imag = planes[1];
- for (int i = 0; i < A_Real.rows; i++)
- {
- for (int j = 0; j < A_Real.cols; j++)
- {
- double a = A_Real.at<double>(i, j);
- double b = A_Imag.at<double>(i, j);
- double c = B_Real.at<double>(i, j);
- double d = B_Imag.at<double>(i, j);
- if (flag)
- {
- // division: (a+bj) / (c+dj)
- R_Real.at<double>(i, j) = (a * c + b * d) / (c * c + d * d + 0.000001);
- R_Imag.at<double>(i, j) = (b * c - a * d) / (c * c + d * d + 0.000001);
- }
- else
- {
- // multiplication: (a+bj) * (c+dj)
- R_Real.at<double>(i, j) = a * c - b * d;
- R_Imag.at<double>(i, j) = b * c + a * d;
- }
- }
- }
- merge(planes, dst);
- }
- void STCTracker::getCxtPriorPosteriorModel(const Mat image)
- {
- CV_Assert(image.size == cxtPriorPro.size);
- double sum_prior(0), sum_post(0);
- for (int i = 0; i < cxtRegion.height; i++)
- {
- for (int j = 0; j < cxtRegion.width; j++)
- {
- double x = j + cxtRegion.x;
- double y = i + cxtRegion.y;
- double dist = sqrt((center.x - x) * (center.x - x) + (center.y - y) * (center.y - y));
- // equation (5) in the paper
- cxtPriorPro.at<double>(i, j) = exp(- dist * dist / (2 * sigma * sigma));
- sum_prior += cxtPriorPro.at<double>(i, j);
- // equation (6) in the paper
- cxtPosteriorPro.at<double>(i, j) = exp(- pow(dist / sqrt(alpha), beta));
- sum_post += cxtPosteriorPro.at<double>(i, j);
- }
- }
- cxtPriorPro.convertTo(cxtPriorPro, -1, 1.0/sum_prior);
- cxtPriorPro = cxtPriorPro.mul(image);
- cxtPosteriorPro.convertTo(cxtPosteriorPro, -1, 1.0/sum_post);
- }
- void STCTracker::learnSTCModel(const Mat image)
- {
- // step 1: Get context prior and posterior probability
- getCxtPriorPosteriorModel(image);
- // step 2-1: Execute 2D DFT for prior probability
- Mat priorFourier;
- Mat planes1[] = {cxtPriorPro, Mat::zeros(cxtPriorPro.size(), CV_64F)};
- merge(planes1, 2, priorFourier);
- dft(priorFourier, priorFourier);
- // step 2-2: Execute 2D DFT for posterior probability
- Mat postFourier;
- Mat planes2[] = {cxtPosteriorPro, Mat::zeros(cxtPosteriorPro.size(), CV_64F)};
- merge(planes2, 2, postFourier);
- dft(postFourier, postFourier);
- // step 3: Calculate the division
- Mat conditionalFourier;
- complexOperation(postFourier, priorFourier, conditionalFourier, 1);
- // step 4: Execute 2D inverse DFT for conditional probability and we obtain STModel
- dft(conditionalFourier, STModel, DFT_INVERSE | DFT_REAL_OUTPUT | DFT_SCALE);
- // step 5: Use the learned spatial context model to update spatio-temporal context model
- addWeighted(STCModel, 1.0 - rho, STModel, rho, 0.0, STCModel);
- }
- void STCTracker::init(const Mat frame, const Rect box)
- {
- // initial some parameters
- alpha = 2.25;
- beta = 1;
- rho = 0.075;
- sigma = 0.5 * (box.width + box.height);
- // the object position
- center.x = box.x + 0.5 * box.width;
- center.y = box.y + 0.5 * box.height;
- // the context region
- cxtRegion.width = 2 * box.width;
- cxtRegion.height = 2 * box.height;
- cxtRegion.x = center.x - cxtRegion.width * 0.5;
- cxtRegion.y = center.y - cxtRegion.height * 0.5;
- cxtRegion &= Rect(0, 0, frame.cols, frame.rows);
- // the prior, posterior and conditional probability and spatio-temporal context model
- cxtPriorPro = Mat::zeros(cxtRegion.height, cxtRegion.width, CV_64FC1);
- cxtPosteriorPro = Mat::zeros(cxtRegion.height, cxtRegion.width, CV_64FC1);
- STModel = Mat::zeros(cxtRegion.height, cxtRegion.width, CV_64FC1);
- STCModel = Mat::zeros(cxtRegion.height, cxtRegion.width, CV_64FC1);
- // create a Hamming window
- hammingWin = Mat::zeros(cxtRegion.height, cxtRegion.width, CV_64FC1);
- createHammingWin();
- Mat gray;
- cvtColor(frame, gray, CV_RGB2GRAY);
- // normalized by subtracting the average intensity of that region
- Scalar average = mean(gray(cxtRegion));
- Mat context;
- gray(cxtRegion).convertTo(context, CV_64FC1, 1.0, - average[0]);
- // multiplies a Hamming window to reduce the frequency effect of image boundary
- context = context.mul(hammingWin);
- // learn Spatio-Temporal context model from first frame
- learnSTCModel(context);
- }
- void STCTracker::tracking(const Mat frame, Rect &trackBox)
- {
- Mat gray;
- cvtColor(frame, gray, CV_RGB2GRAY);
- // normalized by subtracting the average intensity of that region
- Scalar average = mean(gray(cxtRegion));
- Mat context;
- gray(cxtRegion).convertTo(context, CV_64FC1, 1.0, - average[0]);
- // multiplies a Hamming window to reduce the frequency effect of image boundary
- context = context.mul(hammingWin);
- // step 1: Get context prior probability
- getCxtPriorPosteriorModel(context);
- // step 2-1: Execute 2D DFT for prior probability
- Mat priorFourier;
- Mat planes1[] = {cxtPriorPro, Mat::zeros(cxtPriorPro.size(), CV_64F)};
- merge(planes1, 2, priorFourier);
- dft(priorFourier, priorFourier);
- // step 2-2: Execute 2D DFT for conditional probability
- Mat STCModelFourier;
- Mat planes2[] = {STCModel, Mat::zeros(STCModel.size(), CV_64F)};
- merge(planes2, 2, STCModelFourier);
- dft(STCModelFourier, STCModelFourier);
- // step 3: Calculate the multiplication
- Mat postFourier;
- complexOperation(STCModelFourier, priorFourier, postFourier, 0);
- // step 4: Execute 2D inverse DFT for posterior probability namely confidence map
- Mat confidenceMap;
- dft(postFourier, confidenceMap, DFT_INVERSE | DFT_REAL_OUTPUT| DFT_SCALE);
- // step 5: Find the max position
- Point point;
- minMaxLoc(confidenceMap, 0, 0, 0, &point);
- // step 6-1: update center, trackBox and context region
- center.x = cxtRegion.x + point.x;
- center.y = cxtRegion.y + point.y;
- trackBox.x = center.x - 0.5 * trackBox.width;
- trackBox.y = center.y - 0.5 * trackBox.height;
- trackBox &= Rect(0, 0, frame.cols, frame.rows);
- cxtRegion.x = center.x - cxtRegion.width * 0.5;
- cxtRegion.y = center.y - cxtRegion.height * 0.5;
- cxtRegion &= Rect(0, 0, frame.cols, frame.rows);
- // step 7: learn Spatio-Temporal context model from this frame for tracking next frame
- average = mean(gray(cxtRegion));
- gray(cxtRegion).convertTo(context, CV_64FC1, 1.0, - average[0]);
- context = context.mul(hammingWin);
- learnSTCModel(context);
- }
runTracker.cpp
[cpp] view plain copy
- // Fast object tracking algorithm
- // Author : zouxy
- // Date : 2013-11-21
- // HomePage : http://blog.csdn.net/zouxy09
- // Email : [email protected]
- // Reference: Kaihua Zhang, et al. Fast Tracking via Spatio-Temporal Context Learning
- // HomePage : http://www4.comp.polyu.edu.hk/~cskhzhang/
- // Email: [email protected]
- #include "STCTracker.h"
- // Global variables
- Rect box;
- bool drawing_box = false;
- bool gotBB = false;
- // bounding box mouse callback
- void mouseHandler(int event, int x, int y, int flags, void *param){
- switch( event ){
- case CV_EVENT_MOUSEMOVE:
- if (drawing_box){
- box.width = x-box.x;
- box.height = y-box.y;
- }
- break;
- case CV_EVENT_LBUTTONDOWN:
- drawing_box = true;
- box = Rect( x, y, 0, 0 );
- break;
- case CV_EVENT_LBUTTONUP:
- drawing_box = false;
- if( box.width < 0 ){
- box.x += box.width;
- box.width *= -1;
- }
- if( box.height < 0 ){
- box.y += box.height;
- box.height *= -1;
- }
- gotBB = true;
- break;
- }
- }
- int main(int argc, char * argv[])
- {
- VideoCapture capture;
- capture.open("handwave.wmv");
- bool fromfile = true;
- if (!capture.isOpened())
- {
- cout << "capture device failed to open!" << endl;
- return -1;
- }
- //Register mouse callback to draw the bounding box
- cvNamedWindow("Tracker", CV_WINDOW_AUTOSIZE);
- cvSetMouseCallback("Tracker", mouseHandler, NULL );
- Mat frame;
- capture >> frame;
- while(!gotBB)
- {
- if (!fromfile)
- capture >> frame;
- imshow("Tracker", frame);
- if (cvWaitKey(20) == 27)
- return 1;
- }
- //Remove callback
- cvSetMouseCallback("Tracker", NULL, NULL );
- STCTracker stcTracker;
- stcTracker.init(frame, box);
- int frameCount = 0;
- while (1)
- {
- capture >> frame;
- if (frame.empty())
- return -1;
- double t = (double)cvGetTickCount();
- frameCount++;
- // tracking
- stcTracker.tracking(frame, box);
- // show the result
- stringstream buf;
- buf << frameCount;
- string num = buf.str();
- putText(frame, num, Point(20, 30), FONT_HERSHEY_SIMPLEX, 1, Scalar(0, 0, 255), 3);
- rectangle(frame, box, Scalar(0, 0, 255), 3);
- imshow("Tracker", frame);
- t = (double)cvGetTickCount() - t;
- cout << "cost time: " << t / ((double)cvGetTickFrequency()*1000.) << endl;
- if ( cvWaitKey(1) == 27 )
- break;
- }
- return 0;
- }