基於 Permutohedral Lattice 的 Bilateral filter 源碼及部分注釋【來自於網絡】
實作基於論文《Fast High-Dimensional Filtering Using the Permutohedral Lattice》。
延伸閱讀 saliency filters精讀之permutohedral lattice
1.bilateralPermutohedral 方法:
// Bilateral-filters `img` through PermutohedralLattice::filter, using `edge`
// as the guide image whose values form the range part of each position vector.
//
//   img     : image to be filtered; rows are read with ptr<float>, so the data
//             must be CV_32F.
//   edge    : guide image, also read with ptr<float>; the original author notes
//             it has 1 or 3 channels. It is indexed with img's rows/cols, so it
//             is presumably expected to have the same size as img.
//   sigma_s : spatial standard deviation; x and y are divided by it.
//   sigma_r : range standard deviation; guide values are divided by it.
//
// Returns a new Mat with img's size and type holding the filtered values.
static Mat bilateralPermutohedral(Mat img, Mat edge, float sigma_s, float sigma_r)
{
    const int rows = img.rows;
    const int cols = img.cols;
    const int guideChannels = edge.channels();
    const int valueChannels = img.channels();

    // Dividing by sigma turns each axis into units of one standard deviation.
    const float spatialScale = 1.0f / sigma_s;
    const float rangeScale = 1.0f / sigma_r;

    // Single-frame images: one position vector and one value vector per pixel.
    Image positions(1, cols, rows, 2 + guideChannels);
    Image input(1, cols, rows, valueChannels);

    // Mat -> Image conversion.
    for (int row = 0; row < rows; ++row)
    {
        const float *valueRow = img.ptr<float>(row);
        const float *guideRow = edge.ptr<float>(row);

        for (int col = 0; col < cols; ++col)
        {
            // Position vector: (x, y, guide channels...), see section 3.1 of the paper.
            positions(col, row)[0] = spatialScale * col;
            positions(col, row)[1] = spatialScale * row;

            const float *guidePixel = guideRow + col * guideChannels;
            for (int k = 0; k < guideChannels; ++k)
            {
                positions(col, row)[2 + k] = rangeScale * guidePixel[k];
            }

            // Value vector: the pixel's own channels, unscaled.
            const float *valuePixel = valueRow + col * valueChannels;
            for (int k = 0; k < valueChannels; ++k)
            {
                input(col, row)[k] = valuePixel[k];
            }
        }
    }

    // Filter the values with respect to the position vectors.
    Image filtered = PermutohedralLattice::filter(input, positions);

    // Image -> Mat conversion of the result.
    Mat imgOut(img.size(), img.type());
    for (int row = 0; row < rows; ++row)
    {
        float *outRow = imgOut.ptr<float>(row);
        for (int col = 0; col < cols; ++col)
        {
            float *outPixel = outRow + col * valueChannels;
            for (int k = 0; k < valueChannels; ++k)
            {
                outPixel[k] = filtered(col, row)[k];
            }
        }
    }
    return imgOut;
}
2. PermutohedralLattice 類:
/***************************************************************/
/* The algorithm class that performs the filter
*
* PermutohedralLattice::filter(...) does all the work.
*
*/
/***************************************************************/
class PermutohedralLattice
{
public:
    /* Filters a given image against a reference image.
     * im  : image to be filtered (one value vector of im.channels floats per pixel).
     * ref : reference image whose edges are to be respected
     *       (one position vector of ref.channels floats per pixel of im).
     *
     * Returns a new Image with the same frames, width, height and channels as im.
     * Both images are walked with a running pointer that starts at (0, 0, 0).
     */
    static Image filter(Image im, Image ref)
    {
        // Key dimensionality d = ref.channels; value dimensionality
        // vd = im.channels + 1 (the extra slot is the homogeneous weight).
        PermutohedralLattice lattice(ref.channels, im.channels + 1, im.width * im.height * im.frames);

        // Splat into the lattice.
        float *col = new float[im.channels + 1];
        col[im.channels] = 1; // homogeneous coordinate
        float *imPtr = im(0, 0, 0);
        float *refPtr = ref(0, 0, 0); // position vector
        for (int t = 0; t < im.frames; t++)
        {
            for (int y = 0; y < im.height; y++)
            {
                for (int x = 0; x < im.width; x++)
                {
                    for (int c = 0; c < im.channels; c++)
                    {
                        col[c] = *imPtr++;
                    }
                    lattice.splat(refPtr, col);
                    refPtr += ref.channels;
                }
            }
        }

        // Blur the lattice.
        lattice.blur();

        // Slice from the lattice, visiting the pixels in the same order as above.
        Image out(im.frames, im.width, im.height, im.channels);
        lattice.beginSlice();
        float *outPtr = out(0, 0, 0);
        for (int t = 0; t < im.frames; t++)
        {
            for (int y = 0; y < im.height; y++)
            {
                for (int x = 0; x < im.width; x++)
                {
                    lattice.slice(col);
                    // Divide by the accumulated homogeneous weight to normalize.
                    float scale = 1.0f / col[im.channels];
                    for (int c = 0; c < im.channels; c++)
                    {
                        *outPtr++ = col[c] * scale;
                    }
                }
            }
        }

        // The scratch value vector was previously leaked.
        delete[] col;
        return out;
    }

    /* Constructor
     * d_     : dimensionality of key vectors (ref.channels)
     * vd_    : dimensionality of value vectors (im.channels + 1)
     * nData_ : number of points in the input (im.size * im.frames)
     */
    PermutohedralLattice(int d_, int vd_, int nData_) :
        d(d_), vd(vd_), nData(nData_), hashTable(d_, vd_)
    {
        // Allocate storage for various arrays
        elevated = new float[d + 1];
        scaleFactor = new float[d];
        greedy = new short[d + 1];
        rank = new char[d + 1];
        barycentric = new float[d + 2];
        // Each splat() records d + 1 entries, one per simplex vertex.
        replay = new ReplayEntry[nData * (d + 1)];
        nReplay = 0;
        canonical = new short[(d + 1) * (d + 1)];
        key = new short[d + 1];

        // Compute the coordinates of the canonical simplex, in which
        // the difference between a contained point and the zero
        // remainder vertex is always in ascending order. (See pg.4 of paper.)
        // canonical[i * (d + 1) + j] is coordinate j of the vertex with remainder i.
        for (int i = 0; i <= d; i++)
        {
            for (int j = 0; j <= d - i; j++)
                canonical[i * (d + 1) + j] = i;
            for (int j = d - i + 1; j <= d; j++)
                canonical[i * (d + 1) + j] = i - (d + 1);
        }

        // Compute parts of the rotation matrix E. (See pg.4-5 of paper.)
        for (int i = 0; i < d; i++)
        {
            // the diagonal entries for normalization
            scaleFactor[i] = 1.0f / (sqrtf( (float)(i + 1) * (i + 2) ));
            /* We presume that the user would like to do a Gaussian blur of standard deviation
             * 1 in each dimension (or a total variance of d, summed over dimensions.)
             * Because the total variance of the blur performed by this algorithm is not d,
             * we must scale the space to offset this.
             *
             * The total variance of the algorithm is (See pg.6 and 10 of paper):
             *  [variance of splatting] + [variance of blurring] + [variance of slicing]
             *   = d(d+1)(d+1)/12 + d(d+1)(d+1)/2 + d(d+1)(d+1)/12
             *   = 2d(d+1)(d+1)/3.
             *
             * So we need to scale the space by (d+1)sqrt(2/3).
             */
            scaleFactor[i] *= (d + 1) * sqrtf(2.0 / 3);
        }
    }

    /* Releases the work arrays allocated by the constructor (previously leaked). */
    ~PermutohedralLattice()
    {
        delete[] elevated;
        delete[] scaleFactor;
        delete[] greedy;
        delete[] rank;
        delete[] barycentric;
        delete[] replay;
        delete[] canonical;
        delete[] key;
    }

    /* Performs splatting with given position and value vectors.
     * position : key vector of d floats.
     * value    : value vector of vd floats (e.g. [r, g, b, 1]).
     *
     * Adds the value, weighted barycentrically, to the d + 1 vertices of the
     * enclosing simplex and appends d + 1 entries to the replay log for slice().
     */
    void splat(float *position, float *value)
    {
        // First rotate position into the (d+1)-dimensional hyperplane
        // (the product E * x on pg.5 of the paper).
        elevated[d] = -d * position[d - 1] * scaleFactor[d - 1];
        for (int i = d - 1; i > 0; i--)
            elevated[i] = (elevated[i + 1] -
                           i * position[i - 1] * scaleFactor[i - 1] +
                           (i + 2) * position[i] * scaleFactor[i]);
        elevated[0] = elevated[1] + 2 * position[0] * scaleFactor[0];

        // Prepare to find the closest lattice points
        float scale = 1.0f / (d + 1);
        char *myrank = rank;
        short *mygreedy = greedy;

        // Greedily search for the closest zero-colored lattice point (pg.3 of paper):
        // round every coordinate to the nearest multiple of d + 1.
        int sum = 0;
        for (int i = 0; i <= d; i++)
        {
            float v = elevated[i] * scale;
            float up = ceilf(v) * (d + 1);
            float down = floorf(v) * (d + 1);
            if (up - elevated[i] < elevated[i] - down)
                mygreedy[i] = (short)up;
            else
                mygreedy[i] = (short)down;
            sum += mygreedy[i];
        }
        sum /= d + 1; // number of (d + 1)-steps the rounded point is off the hyperplane

        // Rank differential to find the permutation between this simplex and the canonical one.
        // (See pg. 3-4 in paper.) For every pair of coordinates, the one with the
        // smaller differential gets its rank incremented.
        memset(myrank, 0, sizeof(char) * (d + 1));
        for (int i = 0; i < d; i++)
            for (int j = i + 1; j <= d; j++)
                if (elevated[i] - mygreedy[i] < elevated[j] - mygreedy[j])
                    myrank[i]++;
                else
                    myrank[j]++;

        if (sum > 0)
        {
            // sum too large - the point is off the hyperplane.
            // need to bring down the ones with the smallest differential
            for (int i = 0; i <= d; i++)
            {
                if (myrank[i] >= d + 1 - sum)
                {
                    mygreedy[i] -= d + 1;
                    myrank[i] += sum - (d + 1);
                }
                else
                    myrank[i] += sum;
            }
        }
        else if (sum < 0)
        {
            // sum too small - the point is off the hyperplane
            // need to bring up the ones with largest differential
            for (int i = 0; i <= d; i++)
            {
                if (myrank[i] < -sum)
                {
                    mygreedy[i] += d + 1;
                    myrank[i] += (d + 1) + sum;
                }
                else
                    myrank[i] += sum;
            }
        }

        // Compute barycentric coordinates (See pg.10 of paper.)
        memset(barycentric, 0, sizeof(float) * (d + 2));
        for (int i = 0; i <= d; i++)
        {
            barycentric[d - myrank[i]] += (elevated[i] - mygreedy[i]) * scale;
            barycentric[d + 1 - myrank[i]] -= (elevated[i] - mygreedy[i]) * scale;
        }
        barycentric[0] += 1.0f + barycentric[d + 1];

        // Splat the value into each vertex of the simplex, with barycentric weights.
        for (int remainder = 0; remainder <= d; remainder++)
        {
            // Compute the location of the lattice point explicitly (all but the
            // last coordinate - it's redundant because they sum to zero)
            for (int i = 0; i < d; i++)
                key[i] = mygreedy[i] + canonical[remainder * (d + 1) + myrank[i]];

            // Retrieve pointer to the value at this vertex (created if missing).
            float *val = hashTable.lookup(key, true);

            // Accumulate values with barycentric weight.
            for (int i = 0; i < vd; i++)
                val[i] += barycentric[remainder] * value[i];

            // Record this interaction to use later when slicing. An offset is
            // stored rather than the pointer itself.
            replay[nReplay].offset = val - hashTable.getValues();
            replay[nReplay].weight = barycentric[remainder];
            nReplay++;
        }
    }

    /* Prepares for slicing by rewinding the replay log to its first entry. */
    void beginSlice()
    {
        nReplay = 0;
    }

    /* Performs slicing out of position vectors. Note that the barycentric weights and the simplex
     * containing each position vector were calculated and stored in the splatting step.
     * We may reuse this to accelerate the algorithm. (See pg. 6 in paper.)
     *
     * col : output buffer of vd floats. Each call consumes the next d + 1 replay
     *       entries, so calls must follow the same order as the splat() calls.
     */
    void slice(float *col)
    {
        float *base = hashTable.getValues();
        for (int j = 0; j < vd; j++)
            col[j] = 0;
        for (int i = 0; i <= d; i++)
        {
            ReplayEntry r = replay[nReplay++];
            for (int j = 0; j < vd; j++)
            {
                col[j] += r.weight * base[r.offset + j];
            }
        }
    }

    /* Performs a Gaussian blur along each projected axis in the hyperplane.
     * Every pass mixes each vertex with its two neighbors along one axis using
     * the weights (0.25, 0.5, 0.25); missing neighbors count as zero.
     * Progress is printed to stdout.
     */
    void blur()
    {
        // Prepare arrays
        short *neighbor1 = new short[d + 1];
        short *neighbor2 = new short[d + 1];
        float *newValue = new float[vd * hashTable.size()];
        float *oldValue = hashTable.getValues();
        // NOTE(review): hashTableBase is captured once, so this function relies on
        // lookup(key, false) never reallocating the table's value storage.
        float *hashTableBase = oldValue;
        float *zero = new float[vd];
        for (int k = 0; k < vd; k++)
            zero[k] = 0;

        // For each of d+1 axes,
        for (int j = 0; j <= d; j++)
        {
            printf("blur %d\t", j);
            fflush(stdout);
            // For each vertex in the lattice, blur point i along axis j
            for (int i = 0; i < hashTable.size(); i++)
            {
                short *vertexKey = hashTable.getKeys() + i * (d); // key of the current vertex
                for (int k = 0; k < d; k++)
                {
                    neighbor1[k] = vertexKey[k] + 1;
                    neighbor2[k] = vertexKey[k] - 1;
                }
                // Keys of the neighbors along the given axis. For j == d this
                // writes slot d, which lies outside the d stored key coordinates.
                neighbor1[j] = vertexKey[j] - d;
                neighbor2[j] = vertexKey[j] + d;

                float *oldVal = oldValue + i * vd;
                float *newVal = newValue + i * vd;
                float *vm1, *vp1;

                // Lookups return pointers into the table's own storage; rebase
                // them onto whichever buffer currently holds the freshest data.
                vm1 = hashTable.lookup(neighbor1, false); // look up first neighbor
                if (vm1)
                    vm1 = vm1 - hashTableBase + oldValue;
                else
                    vm1 = zero;

                vp1 = hashTable.lookup(neighbor2, false); // look up second neighbor
                if (vp1)
                    vp1 = vp1 - hashTableBase + oldValue;
                else
                    vp1 = zero;

                // Mix values of the three vertices
                for (int k = 0; k < vd; k++)
                    newVal[k] = (0.25f * vm1[k] + 0.5f * oldVal[k] + 0.25f * vp1[k]);
            }
            float *tmp = newValue;
            newValue = oldValue;
            oldValue = tmp;
            // the freshest data is now in oldValue, and newValue is ready to be written over
        }

        // Depending where we ended up, we may have to copy data back into the
        // table. Exactly one of the two buffers is the scratch allocation made
        // above; it was obtained with new[], so it must be released with delete[].
        if (oldValue != hashTableBase)
        {
            memcpy(hashTableBase, oldValue, hashTable.size()*vd * sizeof(float));
            delete[] oldValue;
        }
        else
        {
            delete[] newValue;
        }
        printf("\n");
        delete[] zero;
        delete[] neighbor1;
        delete[] neighbor2;
    }

private:
    // Not copyable: the lattice owns raw arrays, so a memberwise copy would be
    // freed twice. Declared but intentionally left undefined.
    PermutohedralLattice(const PermutohedralLattice &);
    PermutohedralLattice &operator=(const PermutohedralLattice &);

    int d, vd, nData;
    float *elevated, *scaleFactor, *barycentric;
    short *canonical;
    short *key;
    // slicing is done by replaying splatting (ie storing the sparse matrix)
    struct ReplayEntry
    {
        int offset;   // offset of the vertex's value vector from hashTable.getValues()
        float weight; // barycentric weight used when splatting into that vertex
    } *replay;
    int nReplay; // next replay entry to write (splat) or read (slice)

public:
    char *rank;
    short *greedy;
    HashTablePermutohedral hashTable;
};
3. 用於 permutohedral lattice 的哈希表:
/***************************************************************/
/* Hash table implementation for permutohedral lattice
*
* The lattice points are stored sparsely using a hash table.
* The key for each point is its spatial location in the (d+1)-
* dimensional space.
*/
/***************************************************************/
class HashTablePermutohedral
{
public:
    /* Constructor
     * kd_: the dimensionality of the position vectors on the hyperplane
     *      (number of shorts per key).
     * vd_: the dimensionality of the value vectors (number of floats per value).
     *
     * Starts with 2^15 buckets. Key and value storage is sized for half that
     * many vectors because the table grows before it becomes half full.
     */
    HashTablePermutohedral(int kd_, int vd_) : kd(kd_), vd(vd_)
    {
        capacity = 1 << 15;
        filled = 0;
        entries = new Entry[capacity];
        keys = new short[kd * capacity / 2];
        values = new float[vd * capacity / 2];
        memset(values, 0, sizeof(float) * vd * capacity / 2);
    }

    /* Releases the bucket, key and value arrays (previously leaked). */
    ~HashTablePermutohedral()
    {
        delete[] entries;
        delete[] keys;
        delete[] values;
    }

    // Returns the number of vectors stored.
    int size()
    {
        return static_cast<int>(filled);
    }

    // Returns a pointer to the keys array (kd shorts per stored vector, in insertion order).
    short *getKeys()
    {
        return keys;
    }

    // Returns a pointer to the values array (vd floats per stored vector, in insertion order).
    // The pointer is invalidated whenever an inserting lookup grows the table.
    float *getValues()
    {
        return values;
    }

    /* Returns the offset into the values array for a given key, or -1 when the
     * key is absent and create is false.
     * key: a pointer to the position vector (kd shorts).
     * h: hash of the position vector, already reduced modulo the capacity.
     * create: a flag specifying whether an entry should be created,
     *         should an entry with the given key not be found.
     */
    int lookupOffset(short *key, size_t h, bool create = true)
    {
        // Double the hash table size if necessary, but only when this call is
        // allowed to insert: a read-only lookup must never reallocate the
        // arrays, because callers hold on to getValues()/getKeys() pointers.
        if (create && filled >= (capacity / 2) - 1)
        {
            grow();
            // h was reduced modulo the old capacity; recompute it so the key
            // lands in the bucket that later lookups will probe.
            h = hash(key) % capacity;
        }

        // Find the entry with the given key using linear probing.
        while (1)
        {
            Entry e = entries[h];

            // Check if the cell is empty.
            if (e.keyIdx == -1)
            {
                if (!create)
                    return -1; // Return not found.

                // Need to create an entry. Store the given key.
                for (int i = 0; i < kd; i++)
                    keys[filled * kd + i] = key[i];
                e.keyIdx = static_cast<int>(filled * kd);
                e.valueIdx = static_cast<int>(filled * vd);
                entries[h] = e;
                filled++;
                return e.valueIdx;
            }

            // Check if the cell has a matching key.
            bool match = true;
            for (int i = 0; i < kd && match; i++)
                match = keys[e.keyIdx + i] == key[i];
            if (match)
                return e.valueIdx;

            // Collision: advance to the next bucket, wrapping around at the end.
            h++;
            if (h == capacity)
                h = 0;
        }
    }

    /* Looks up the value vector associated with a given key vector.
     * k : pointer to the key vector to be looked up.
     * create : true if a non-existing key should be created.
     *
     * Returns a pointer to the vd floats of the value, or NULL when the key is
     * absent and create is false.
     */
    float *lookup(short *k, bool create = true)
    {
        size_t h = hash(k) % capacity;
        int offset = lookupOffset(k, h, create);
        if (offset < 0)
            return NULL;
        else
            return values + offset;
    }

    /* Hash function used in this implementation. A simple base conversion. */
    size_t hash(const short *key)
    {
        size_t k = 0;
        for (int i = 0; i < kd; i++)
        {
            k += key[i];
            k *= 2531011;
        }
        return k;
    }

private:
    // Not copyable: the table owns raw arrays, so a memberwise copy would be
    // freed twice. Declared but intentionally left undefined.
    HashTablePermutohedral(const HashTablePermutohedral &);
    HashTablePermutohedral &operator=(const HashTablePermutohedral &);

    /* Doubles the capacity of the hash table and re-buckets every entry. */
    void grow()
    {
        printf("Resizing hash table\n");
        size_t oldCapacity = capacity;
        capacity *= 2;

        // Migrate the value vectors; the unused tail stays zero-initialized.
        float *newValues = new float[vd * capacity / 2];
        memset(newValues, 0, sizeof(float) * vd * capacity / 2);
        memcpy(newValues, values, sizeof(float) * vd * filled);
        delete[] values;
        values = newValues;

        // Migrate the key vectors.
        short *newKeys = new short[kd * capacity / 2];
        memcpy(newKeys, keys, sizeof(short) * kd * filled);
        delete[] keys;
        keys = newKeys;

        // Migrate the table of indices. Key/value offsets are unchanged; only
        // the bucket each entry lives in depends on the capacity.
        Entry *newEntries = new Entry[capacity];
        for (size_t i = 0; i < oldCapacity; i++)
        {
            if (entries[i].keyIdx == -1)
                continue;
            // Re-hash the stored key for the new capacity.
            size_t h = hash(keys + entries[i].keyIdx) % capacity;
            // Probe forward until a free bucket is found.
            while (newEntries[h].keyIdx != -1)
            {
                h++;
                if (h == capacity)
                    h = 0;
            }
            newEntries[h] = entries[i];
        }
        delete[] entries;
        entries = newEntries;
    }

    // Private struct for the hash table entries; -1 marks an empty bucket.
    struct Entry
    {
        Entry() : keyIdx(-1), valueIdx(-1) {}
        int keyIdx;   // offset of this entry's key in the keys array
        int valueIdx; // offset of this entry's value in the values array
    };

    short *keys;
    float *values;
    Entry *entries;
    size_t capacity, filled; // number of buckets, and number of buckets in use
    int kd, vd;              // length of each key vector and of each value vector
};
效果圖:
