天天看點

MeanShift聚類-02python案例

Intro

  Meanshift的使用案例~

資料引入

from sklearn.cluster import MeanShift, estimate_bandwidth
import matplotlib.pyplot as plt
from itertools import cycle
import numpy as np
import warnings
warnings.filterwarnings("ignore")
%matplotlib inline      
from sklearn.datasets import load_iris
import pandas as pd
pd.set_option('display.max_rows', 500) # 列印最大行數
pd.set_option('display.max_columns', 500) # 列印最大列數      
# 檢查是否是array格式,如果不是,轉換成array
from sklearn.utils import check_array
from sklearn.utils import check_random_state
from sklearn.neighbors import NearestNeighbors
from sklearn.utils import      
# Load the iris data into a DataFrame with named feature columns,
# then attach the integer class label as a "target" column.
iris = load_iris()
feature_columns = ["sepal_length", "sepal_width", "petal_length", "petal_width"]
iris_df = pd.DataFrame(iris["data"], columns=feature_columns)
iris_df["target"] = iris["target"]
iris_df.head()
sepal_length sepal_width petal_length petal_width target
0 5.1 3.5 1.4 0.2
1 4.9 3.0 1.4 0.2
2 4.7 3.2 1.3 0.2
3 4.6 3.1 1.5 0.2
4 5.0 3.6 1.4 0.2
iris_df.groupby(by="target").describe()      
sepal_length sepal_width petal_length petal_width
count mean std min 25% 50% 75% max count mean std min 25% 50% 75% max count mean std min 25% 50% 75% max count mean std min 25% 50% 75% max
target
0 50.0 5.006 0.352490 4.3 4.800 5.0 5.2 5.8 50.0 3.428 0.379064 2.3 3.200 3.4 3.675 4.4 50.0 1.462 0.173664 1.0 1.4 1.50 1.575 1.9 50.0 0.246 0.105386 0.1 0.2 0.2 0.3 0.6
1 50.0 5.936 0.516171 4.9 5.600 5.9 6.3 7.0 50.0 2.770 0.313798 2.0 2.525 2.8 3.000 3.4 50.0 4.260 0.469911 3.0 4.0 4.35 4.600 5.1 50.0 1.326 0.197753 1.0 1.2 1.3 1.5 1.8
2 50.0 6.588 0.635880 4.9 6.225 6.5 6.9 7.9 50.0 2.974 0.322497 2.2 2.800 3.0 3.175 3.8 50.0 5.552 0.551895 4.5 5.1 5.55 5.875 6.9 50.0 2.026 0.274650 1.4 1.8 2.0 2.3 2.5

從資料上看,三個種類之間,petal_length和petal_width的差異比較大,用它來畫圖。

# Scatter the three species in petal space, one color and one marker per class.
# BUG FIX: the original zipped a `marker` list into the loop but then plotted
# every class with the literal "o", leaving `mark` unused — pass it through so
# each species gets its own marker shape as intended.
colors = ["red", "yellow", "blue"]
markers = ["o", "*", "+"]
for k, col, mark in zip(range(3), colors, markers):
    sub_data = iris_df.query("target==%s" % k)
    plt.plot(sub_data.petal_length, sub_data.petal_width, mark,
             markerfacecolor=col, markeredgecolor='k', markersize=5)
plt.show()
MeanShift聚類-02python案例

可以看到紅色點和其餘點相差很多,藍色和黃色有部分點交錯在一起

預設參數進行聚類

# Cluster the two petal features with MeanShift; `bandwidth` matches the value
# suggested by estimate_bandwidth for this data.
bandwidth = 0.726
ms = MeanShift(bandwidth=bandwidth, bin_seeding=True)
ms.fit(iris_df[["petal_length", "petal_width"]])
labels = ms.labels_
cluster_centers = ms.cluster_centers_
labels_unique = np.unique(labels)
n_clusters_ = len(labels_unique)

print("number of estimated clusters : %d" % n_clusters_)

# #############################################################################
# Plot result
# NOTE: the duplicate mid-file imports of matplotlib.pyplot and itertools.cycle
# were removed — both are already imported at the top of the script.

plt.figure(1)
plt.clf()

# One color per cluster. The unused `marker` list and `mark` loop variable from
# the original were dropped: every member point was drawn with '.' regardless.
colors = ["yellow", "red", "blue"]
for k, col in zip(range(n_clusters_), colors):
    my_members = labels == k
    cluster_center = cluster_centers[k]
    # Members of cluster k as small dots.
    plt.plot(iris_df[my_members].petal_length,
             iris_df[my_members].petal_width,
             ".",
             markerfacecolor=col,
             markeredgecolor='k',
             markersize=6)
    # The cluster center as one large dot.
    plt.plot(cluster_center[0],
             cluster_center[1],
             'o',
             markerfacecolor=col,
             markeredgecolor='k',
             markersize=14)
    # Circle of radius `bandwidth` around each center.
    circle = plt.Circle((cluster_center[0], cluster_center[1]),
                        bandwidth,
                        color='black',
                        fill=False)
    plt.gca().add_artist(circle)  # plt.gcf().gca() simplified to plt.gca()
plt.title('Estimated number of clusters: %d' % n_clusters_)
plt.show()
number of estimated clusters : 3      
MeanShift聚類-02python案例

從圖上看,紅色部分自成一派,聚類效果就好,藍黃兩類互有交叉,以最靠近的類别中心來打label.

estimate_bandwidth方法

根據聚類的原始資料,生成建議的bandwidth,基礎邏輯:

  • 先抽樣,擷取部分樣本
  • 計算這些樣本和所有點的最大距離
  • 對距離求平均

從邏輯上看,更像是找一個較大的距離,使得能涵蓋更多的點

estimate_bandwidth(iris_df[["petal_length", "petal_width"]])      
0.7266371274126329      

計算距離,check下

# BUG FIX: the original line was a truncated "from sklearn.neighbors import"
# (a SyntaxError); NearestNeighbors is the class used right below.
from sklearn.neighbors import NearestNeighbors

# Fit on the two petal columns (iloc 2 and 3) and ask for as many neighbors as
# there are rows, so kneighbors later returns every pairwise distance.
nbrs = NearestNeighbors(n_neighbors=len(iris_df), n_jobs=-1)
nbrs.fit(iris_df.iloc[:, [2, 3]])
NearestNeighbors(algorithm='auto', leaf_size=30, metric='minkowski',
                 metric_params=None, n_jobs=-1, n_neighbors=150, p=2,
                 radius=1.0)      
d, index = nbrs.kneighbors(iris_df.iloc[:,[2,3]],return_distance=True)      
from functools import reduce  # python 3 (kept: may be used elsewhere in the file)
# Flatten all non-self distances into one list. Column 0 of `d` is each point's
# distance to itself (always 0.0), so it is skipped — hence 22350 = 150*149
# values in the output below. The original built this with
# reduce(lambda x, y: x + y, ...) over the DataFrame rows, which is quadratic
# in the number of rows; a row-major ravel() is equivalent and single-pass.
total_distance = d[:, 1:].ravel().tolist()
# BUG FIX: the original line was a truncated "from scipy import" (a
# SyntaxError); `stats` is the submodule used on the next line.
from scipy import stats

# Summary statistics (count, min/max, mean, variance, skew, kurtosis).
stats.describe(total_distance)
DescribeResult(nobs=22350, minmax=(0.0, 6.262587324740471), mean=2.185682454621745, variance=2.6174775533104904, skewness=0.3422940721262964, kurtosis=-1.1637573960810108)      
pd.DataFrame({"total_distance":total_distance}).describe()      
total_distance
count 22350.000000
mean 2.185682
std 1.617862
min 0.000000
25% 0.640312
50% 1.941649
75% 3.544009
max 6.262587

從資料上看,有點接近25%分位數。