
MeanShift Clustering - 02 Python Case Study

Intro

  A worked example of using MeanShift.

Loading the Data

from sklearn.cluster import MeanShift, estimate_bandwidth
import matplotlib.pyplot as plt
from itertools import cycle
import numpy as np
import warnings
warnings.filterwarnings("ignore")
%matplotlib inline      
from sklearn.datasets import load_iris
import pandas as pd
pd.set_option('display.max_rows', 500)     # maximum number of rows to display
pd.set_option('display.max_columns', 500)  # maximum number of columns to display
# input-validation helpers: check whether the input is an array and convert it if not
from sklearn.utils import check_array
from sklearn.utils import check_random_state
from sklearn.neighbors import NearestNeighbors
iris_df = pd.DataFrame(
    load_iris()["data"],
    columns=["sepal_length", "sepal_width", "petal_length", "petal_width"])
iris_df["target"] = load_iris()["target"]
iris_df.head()      
   sepal_length  sepal_width  petal_length  petal_width  target
0           5.1          3.5           1.4          0.2       0
1           4.9          3.0           1.4          0.2       0
2           4.7          3.2           1.3          0.2       0
3           4.6          3.1           1.5          0.2       0
4           5.0          3.6           1.4          0.2       0
iris_df.groupby(by="target").describe()      
Per-species summary statistics (one row per species and feature):

target  feature       count   mean   std       min  25%    50%   75%    max
0       sepal_length   50.0  5.006  0.352490   4.3  4.800  5.00  5.200   5.8
0       sepal_width    50.0  3.428  0.379064   2.3  3.200  3.40  3.675   4.4
0       petal_length   50.0  1.462  0.173664   1.0  1.400  1.50  1.575   1.9
0       petal_width    50.0  0.246  0.105386   0.1  0.200  0.20  0.300   0.6
1       sepal_length   50.0  5.936  0.516171   4.9  5.600  5.90  6.300   7.0
1       sepal_width    50.0  2.770  0.313798   2.0  2.525  2.80  3.000   3.4
1       petal_length   50.0  4.260  0.469911   3.0  4.000  4.35  4.600   5.1
1       petal_width    50.0  1.326  0.197753   1.0  1.200  1.30  1.500   1.8
2       sepal_length   50.0  6.588  0.635880   4.9  6.225  6.50  6.900   7.9
2       sepal_width    50.0  2.974  0.322497   2.2  2.800  3.00  3.175   3.8
2       petal_length   50.0  5.552  0.551895   4.5  5.100  5.55  5.875   6.9
2       petal_width    50.0  2.026  0.274650   1.4  1.800  2.00  2.300   2.5

Looking at these statistics, petal_length and petal_width show the largest differences between the three species, so we use those two features for plotting.

# colors = cycle('bgrcmykbgrcmykbgrcmykbgrcmyk')
colors = ["red", "yellow", "blue"]
for k, col in zip(range(3), colors):
    # plot each species in its own colour
    sub_data = iris_df.query("target==%s" % k)
    plt.plot(sub_data.petal_length, sub_data.petal_width, "o",
             markerfacecolor=col, markeredgecolor='k', markersize=5)
plt.show()
[Figure: scatter plot of petal_length vs petal_width, one colour per species]

As the plot shows, the red points are clearly separated from the rest, while some of the blue and yellow points overlap with each other.

Clustering with default parameters

# ms = MeanShift( bin_seeding=True,cluster_all=False)
bandwidth = 0.726  # roughly the value suggested by estimate_bandwidth (see below)
ms = MeanShift(bandwidth=bandwidth, bin_seeding=True)
ms.fit(iris_df[["petal_length", "petal_width"]])
labels = ms.labels_
cluster_centers = ms.cluster_centers_
labels_unique = np.unique(labels)
n_clusters_ = len(labels_unique)

print("number of estimated clusters : %d" % n_clusters_)

# #############################################################################
# Plot result
import matplotlib.pyplot as plt
from itertools import cycle

plt.figure(1)
plt.clf()

# colors = cycle('bgrcmykbgrcmykbgrcmykbgrcmyk')
colors = ["yellow", "red", "blue"]
marker = ["o", "*", "+"]
for k, col, mark in zip(range(n_clusters_), colors, marker):
    my_members = labels == k
    cluster_center = cluster_centers[k]
    plt.plot(iris_df[my_members].petal_length,
             iris_df[my_members].petal_width,
             ".",
             markerfacecolor=col,
             markeredgecolor='k',
             markersize=6)
    plt.plot(cluster_center[0],
             cluster_center[1],
             'o',
             markerfacecolor=col,
             markeredgecolor='k',
             markersize=14)
    # draw a circle of radius = bandwidth around each cluster centre
    circle = plt.Circle((cluster_center[0], cluster_center[1]),
                        bandwidth,
                        color='black',
                        fill=False)
    plt.gca().add_artist(circle)
plt.title('Estimated number of clusters: %d' % n_clusters_)
plt.show()      
number of estimated clusters : 3      
[Figure: MeanShift clustering result with cluster centres and bandwidth circles]

From the plot, the red group forms its own cluster and is clustered cleanly; the blue and yellow clusters overlap, and each point in the overlap region is labelled by the nearest cluster centre.
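To put a number on this, one quick check (not in the original post) is to cross-tabulate the MeanShift labels against the true species; the cluster label numbering is arbitrary, so only how each species splits across the columns matters.

# hypothetical check, reusing iris_df and labels from the cells above:
# how do the MeanShift clusters line up with the true species?
import pandas as pd

print(pd.crosstab(iris_df["target"], labels,
                  rownames=["true species"], colnames=["cluster label"]))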

The estimate_bandwidth method

Given the raw data to be clustered, it produces a suggested bandwidth. The basic logic:

  • Optionally draw a random subsample of the points (the n_samples parameter)
  • For each (sampled) point, find its k nearest neighbours, where k = quantile × n_samples, and take the distance to the farthest of them
  • Average these maximum distances over all the points

Intuitively, it picks a fairly large distance, so that a window of this radius around a point covers a good share of its neighbours; a rough sketch of this logic is shown below.
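The steps above can be written out as a small sketch (a simplified re-implementation for illustration only, based on the documented quantile and n_samples parameters; it is not scikit-learn's actual code):

import numpy as np
from sklearn.neighbors import NearestNeighbors

def estimate_bandwidth_sketch(X, quantile=0.3, n_samples=None, random_state=0):
    """Rough sketch of the estimate_bandwidth logic described above."""
    X = np.asarray(X)
    rng = np.random.RandomState(random_state)
    if n_samples is not None:
        # 1. optionally work on a random subsample of the points
        X = X[rng.permutation(len(X))[:n_samples]]
    # 2. for each point, look at its k nearest neighbours, k = quantile * n
    k = max(1, int(len(X) * quantile))
    d, _ = NearestNeighbors(n_neighbors=k).fit(X).kneighbors(X)
    # 3. take the distance to the farthest of those k neighbours, then average
    return d.max(axis=1).mean()

# should come out close to the estimate_bandwidth call in the next cell
print(estimate_bandwidth_sketch(iris_df[["petal_length", "petal_width"]]))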

estimate_bandwidth(iris_df[["petal_length", "petal_width"]])      
0.7266371274126329      

Let's compute the pairwise distances ourselves to check.

from sklearn.neighbors import NearestNeighbors
nbrs = NearestNeighbors(n_neighbors=len(iris_df), n_jobs=-1)
nbrs.fit(iris_df.iloc[:,[2,3]])      
NearestNeighbors(algorithm='auto', leaf_size=30, metric='minkowski',
                 metric_params=None, n_jobs=-1, n_neighbors=150, p=2,
                 radius=1.0)      
d, index = nbrs.kneighbors(iris_df.iloc[:,[2,3]],return_distance=True)      
from functools import reduce  # python 3
# flatten the 150 x 149 distance matrix into one long list of pairwise
# distances (column 0, each point's zero distance to itself, is dropped)
total_distance = reduce(lambda x, y: x + y,
                        np.array(pd.DataFrame(d).iloc[:, 1:150]).tolist())
from scipy import stats
stats.describe(total_distance)      
DescribeResult(nobs=22350, minmax=(0.0, 6.262587324740471), mean=2.185682454621745, variance=2.6174775533104904, skewness=0.3422940721262964, kurtosis=-1.1637573960810108)      
pd.DataFrame({"total_distance":total_distance}).describe()      
total_distance
count 22350.000000
mean 2.185682
std 1.617862
min 0.000000
25% 0.640312
50% 1.941649
75% 3.544009
max 6.262587

Judging from these numbers, the estimated bandwidth (≈0.727) sits close to the 25% quantile of the pairwise distances, which is consistent with estimate_bandwidth's default quantile of 0.3.
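Assuming that default quantile=0.3 (an assumption about the defaults, not something stated above), the estimate can also be reproduced from the distance matrix d computed earlier: for each point take the largest of its int(0.3 * 150) = 45 smallest distances, then average over points.

# hypothetical reproduction of the 0.7266 value from the sorted distances d,
# assuming estimate_bandwidth's default quantile=0.3
k = int(len(iris_df) * 0.3)          # 45 neighbours (including the point itself)
print(d[:, :k].max(axis=1).mean())   # expected to be close to 0.7266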