孤立森林(Isolation Forest)算法,是一種用于排查異常資料的算法。它的名字聽上去十分高端,感覺不可能在OI裡用到,事實上也的确沒有聽說過這方面的題,而我接觸這個算法是因為數學模組化大賽時小組中發現這個算法可用,然後就去學了一下。本身要感性了解它并不困難,不過打起來還是有點繁瑣。當時隻用到了弱化再弱化的一維的情況,是以總結一下一維資料的做法,感性了解即可。
有一個數列儲存的一些計算資料,大部分正常,少部分是錯誤的,現在要排除出過于離譜的資料。
我們把資料放到數軸上,容易發現大部分資料都是堆在一起的,隻有少部分資料裡中心點較遠。
那麼我們在數列的最大值和最小值之間随機選取一個數作為中點,把數列一分為二。
接下來遞歸進入左右兩邊繼續同樣的操作,将值域不斷分割。顯然每個區間包含的點數是越來越少的,最終會變成1,這個時候就不再分割,記錄下這個點遞歸了多少層,然後return。

那麼這些區間就會形成一個搜尋樹,而越是離中心地區遠的點就越容易提早被孤立出來,它們的深度就越淺,是以以深度就可以大緻判斷那些資料是錯誤的。當然跑一次的偶然機率太大,我們可以多跑幾次(一般100次),給每一個點的深度求一個平均值以減小誤差。
但是這還遠遠沒有完,如果出現“異常資料也抱團”和“正常異常資料太接近”的情況,孤立森林的誤差就會大大增加, 是以我們應該減小資料的規模來避免它,每次取\(\psi\)個數跑IForest即可,而這個\(\psi\)一般取常數256。
同時,當分割已經分割出很多異常資料時,仍然“抱團的資料”基本上可以視為正常,我們規定一個最大深度\(Hmax = ceil(log_2 \psi) = 8\)(根節點深度是0),達到最大深度時,即使開沒有分割完畢也不再繼續下分。
最後我們得到每一個點相對準确的平均深度\(h_x\),可以用這個來計算“異常分數值”\(s(x, \psi)\),當\(\psi\)固定時可以按照下式計算:
\[s(x, \psi) = 2 ^ {\frac {h_x} {c(\psi)}}
\]
其中\(c(i) = 2H(i-1)-\frac {2(i-1)} i,\ H(i) = ln(i)+0.577215665\),這個c值因為\(\psi\)是常數可以直接計算出。
現在我們得到了所有s值,s值越接近1的就可以認為越異常。這裡放上CSDN的一張圖:
// 注意生成随機數不要寫錯
inline d64 rand(d64 l, d64 r) {
return rand() / 32768.0 * (r - l + 1) + l;
}
inline int rand(int l, int r) {
return rand() * rand() % (r - l + 1) + l;
}
unordered_map <double, int> HASH;
int sc[MAX], uc[MAX];
double b[MAX], c[MAX];
void DFS(d64 l, d64 r, int dep)
{
int xl = lower_bound(b, b+256, l) - b;
int xr = upper_bound(b, b+256, r) - b;
//找到本區間最靠左, 最靠右的點編号
if (xr - xl <= 1 || dep == 8) // 如果隻剩一個或者到達深度上限
{
for (int i = xl; i < xr; i++)
{
// 統計這個數被抽到的次數和深度總和,計算平均數
sc[HASH[b[i]]] += dep;
uc[HASH[b[i]]] ++;
}return;
}
double mid = rand(b[xl], b[xr-1]);// 随機斷開,遞歸分治
DFS(l, mid, dep+1);
DFS(mid, r, dep+1);
}
void Iforest(int l, int r)
{
for (int i = 1; i <= N; i++)
HASH[b[i]] = i, ord[i] = i;
for (int T = 1; T <= 900; T++) // 算法運作次數
{
for (int i = 1; i <= N; i++)
swap(b[rand(1, N)], b[rand(1, N)]);
// 随機打亂序列
sort(b + 1, b + 257); // 選取前256個數跑算法
DFS(b[1], b[256], 0);
}
for (int i = 1; i <= N; i++)
{
d64 avee = (d64)sc[i] / uc[i];
c[i] = pow(2, -avee / 12.5237); // 這裡的c就是剛才的s值
}
sort(ord + 1, ord + N + 1, [](int x, int y) {return c[x] > c[y];});
for (int i = 1; i <= N; i++)
rank[ord[i]] = i;
// 記錄每個點的排名,這裡選取了c值最大的2%作為異常資料
for (int i = l; i <= r; i++)
puts(rank[i] * 50 <= N ? "Abnormal" : "Normal");
}