天天看点

深度剖析Apriori算法

Apriori算法是第一个关联规则挖掘算法,也是最经典的算法。它利用逐层搜索的迭代方法找出数据库中项集的关系,以形成规则,其过程由连接(类矩阵运算)与剪枝(去掉那些没必要的中间结果)组成。

Apriori算法是关联规则中常用的一种算法。该算法主要包含两个步骤:首先找出数据集中所有的频繁项集,这些项集出现的频繁性要大于或等于最小支持度,然后根据频繁项集产生强关联规则,这些规则必须满足最小支持度和最小置信度。

上面提到了最小支持度和最小置信度,事实上,在关联规则中用于度量规则质量的两个主要指标即为支持度和置信度。那么,什么是支持度和置信度呢?接下来进行讲解。

给定关联规则X=>Y,即根据X推出Y。形式化定义为:

深度剖析Apriori算法

算法步骤:

  1. 找出出现频率最大的一个项L1。
  2. 根据L1找频繁“2项集”的集合C2.
  3. 并剪掉不满足支持度阈值的项,得到L2。
  4. 根据L2找频繁“3项集”的集合C3。
  5. 根据性质和支持度阈值进行剪枝,得到L3。
  6. 循环上述过程,直到得到空集C,即直到不能发现更大的频集L。
  7. 计算最大频集L的非空子集,两两计算置信度,得到大于置信度阈值的强关联规则。

    Apriori性质:一个频繁项集的任一子集也应该是频繁项集。也就是,生成一个k-itemset的候选项时,如果这个候选项有子集不在(k-1)-itemset(已经确定是frequent的)中时,那么这个候选项就不用拿去和支持度判断了,直接删除。

举个栗子:

1. 首先看一下我们的数据,假设给定如下电子商务网站的用户交易数据集。

![](http://cookdata.cn/media/bbs/images/2020_09_11_16_52_IMG_6697_[原始大小]_1599825711294_5d14.jpg =700x*)

代码如下(这里我用函数进行封装,便于之后的操作)

# 返回为dict类型,代表不同用户购买了哪些商品,key为str类型,代表不同用户,value为frozenset类型,其中的元素代表不同的商品。
def load_example():
    data = {
        'user1': ['I1', 'I2', 'I5'],
        'user2': ['I2', 'I4'],
        'user3': ['I2', 'I3'],
        'user4': ['I1', 'I2', 'I4'],
        'user5': ['I1', 'I3'],
        'user6': ['I2', 'I3'],
        'user7': ['I1', 'I3'],
        'user8': ['I1', 'I2', 'I3', 'I5'],
        'user9': ['I1', 'I2', 'I3']
    }
    return {i: frozenset(data[i]) for i in data}      
2. 计算频繁1项集。扫描交易数据集,统计每种商品出现的次数,选取大于或等于最小支持度的商品,得到了候选项集。
深度剖析Apriori算法

代码如下

# 函数返回为dict类型,key为frozenset类型,表示每个商品名称,value为int类型,表示该商品出现的次数。
# data: 商品清单,类型为dict,key为str类型,代表用户名称,value为frozenset类型,代表该用户所购买的商品名称。
def freq_one(data):
    freq_1 = {}
    for item in data:
        for record in data[item]:
            if frozenset([record]) in freq_1:
                freq_1[frozenset([record])] += 1
            else:
                freq_1[frozenset([record])] = 1
    return {v: freq_1[v] for v in freq_1 if freq_1[v] >= 2}      
3. 计算频繁k项集(k>=2)。根据频繁1项集,计算频繁2项集。首先将频繁1项集和频繁1项集进行连接运算,得到2项集。

![](http://cookdata.cn/media/bbs/images/2020_09_11_16_52_IMG_6699_[原始大小]_1599825756996_5d14.jpg =350x*)

扫描用户交易数据集,计算包含每个候选2项集的记录数。

![](http://cookdata.cn/media/bbs/images/2020_09_11_16_52_IMG_6700_[原始大小]_1599825802444_5d14.jpg =600x*)

这里呢我定义最小支持度为2/9,即支持度计数为2,根据最小支持度,得到频繁2项集,如图所示。

![](http://cookdata.cn/media/bbs/images/2020_09_11_16_52_IMG_6701_[原始大小]_1599825818128_5d14.jpg =600x*)

根据频繁2项集,再计算频繁3项集。首先将频繁2项集进行连接,得到{{I1,I2,I3},{I1,I2,I5},{I1,I3,I5},{I2,I3,I4},{I2,I3,I5},{I2,I4,I5}},然后根据频繁项集性质进行剪枝(第一种剪枝),即频繁项集的非空子集必须是频繁的。

{I1,I2,I3}的2项子集为{I1,I2},{I1,I3},{I2,I3},都在频繁2项集中则保留;

{I1,I2,I5}的2项子集为{I1,I2},{I1,I5},{I2,I5},都在频繁2项集中则保留;

{I1,I3,I5}的2项子集为{I1,I3},{I1,I5},{I3,I5},由于{I3,I5}不是频繁2项集,移除该候选集;

{I2,I3,I4}的2项子集为{I2,I3},{I2,I4},{I3,I4},由于{I3,I4}不是频繁2项集,移除该候选集;

{I2,I3,I5}的2项子集为{I2,I3},{I2,I5},{I3,I5},由于{I3,I5}不是频繁2项集,移除该候选集;

{I2,I4,I5}的2项子集为{I2,I4},{I2,I5},{I4,I5},由于{I4,I5}不是频繁2项集,移除该候选集。

通过剪枝,得到候选集{{I1,I2,I3},{I1,I2,I5}},扫描交易数据库,计算包含候选3项集的记录数(第二种阈值剪枝)

根据频繁3项集,计算频繁4项集。重复上述的思路,得到{I1,I2,I3,I5},根据频繁项集定理,它的子集{I2,I3,I5}为非频繁项集,所以移除该候选集。从而,频繁4项集为空,至此,计算频繁项集的步骤结束。

代码如下(这里我先定义两个函数,这两个函数将会在计算频繁k项集中用到。第一个函数是用来得到所有k项子集,第二个是用来得到所有非空子集并剔除自身)

# 函数返回为list类型,每一个元素都是k项子集。
# item是一个可迭代对象。
# k是想要得到的第几项集(k要小于等于item的长度),类型为int。
def get_subset(item, k):
    import itertools as its
    return [frozenset(item) for item in its.combinations(item, k)]      
# 函数返回为list类型,每一个元素也为一个list。
# item是一个可迭代对象。
def get_all_subsets_not_self(item):
    subsets = []
    for i in range(1, len(item)):
        subsets.append(get_subset(item, i))
    return subsets      

然后定义计算频繁k项集的函数,如果我们想要得到频繁2项集,则函数第二个参数​

​frequent_k_minus_one​

​​要传入频繁1项集,参数k要传入2;同理如果我们想要得到频繁3项集,则函数第二个参数​

​frequent_k_minus_one​

​要传入频繁2项集,参数k要传入3。

# 函数返回为dict类型,key为frozenset类型,表示每几个的商品名称,value为int类型,表示这几个商品出现的次数。
# data是商品清单,类型为dict,key为str类型,代表用户名称,value为frozenset类型,代表该用户所购买的商品名称。
# frequent_k_minus_one是频繁k-1项集,类型为dict。
# k是生成的第k项集(k>=2),类型为int。
# min_support是最小支持度,默认为2,类型为int。
def freq_k(data, frequent_k_minus_one, k, min_support=2):
    # 连接步,生成k项候选集
    items = frequent_k_minus_one.keys()
    candidate_items = [m.union(n) for m in items for n in items if m != n and len(m.union(n)) == k]
    # 剪枝步,剔除不能成为频繁k项集的候选集
    final_candidate = set()
    for candidate in candidate_items:
        if set(items) > set(get_subset(candidate, (k - 1))):
            final_candidate.add(candidate)
    # 遍历数据集data,对final_candidate中的元素进行统计
    current_k = dict()
    for record in data.items():
        for item in final_candidate:
            if item.issubset(record[1]):
                if item in current_k:
                    current_k[item] += 1
                else:
                    current_k[item] = 1
    # 返回支持度大于最小阈值的频繁项
    return {v: current_k[v] for v in current_k if current_k[v] >= min_support}      

同时,我还定义了两个函数,第一个可以直接生成最终频繁项集,第二个可以生成所有频繁项集,代码如下所示。

# 函数返回为tuple类型,其中,第一个元素为int类型,代表最终生成的频繁第几项集,第二个元素为set类型,代表最终生成的频繁项集元素。
# data是商品清单,类型为dict,key为str类型,代表用户名称,value为frozenset类型,代表该用户所购买的商品名称。
# freq_one是频繁1项集,类型为dict。
# min_support是最小支持度,默认为2,类型为int。
def frequent_final(data, freq_one, min_support=2):
    frequent_k = freq_one
    k = 2
    while True:
        frequent_k = freq_k(data=data, frequent_k_minus_notallow=frequent_k, k=k, min_support=min_support)
        k += 1
        if not (freq_k(data=data, frequent_k_minus_notallow=frequent_k, k=k, min_support=min_support)):
            break
    return (k - 1, frequent_k)      
# 函数返回为list类型,每一个元素为dict类型。
# data是商品清单,类型为dict,key为str类型,代表用户名称,value为frozenset类型,代表该用户所购买的商品名称。
# freq_one是频繁1项集,类型为dict。
# min_support是最小支持度,默认为2,类型为int。
def frequent_all(data, freq_one, min_support=2):
    frequent_k = freq_one
    frequent_list = [frequent_k]
    k = 2

    while True:
        frequent_k = freq_k(data=data, frequent_k_minus_notallow=frequent_k, k=k, min_support=min_support)
        frequent_list.append(frequent_k)
        k += 1
        if not (freq_k(data=data, frequent_k_minus_notallow=frequent_k, k=k, min_support=min_support)):
            break
    return frequent_list      
4. 我设置最小置信度为60%,即0.6,根据频繁项集,计算关联规则。这里以频繁3项集{I1,I2,I5}为例,计算关联规则。{I1,I2,I5}的非空子集为{I1,I2}、{I1,I5}、{I2,I5}、{I1}、{I2}和{I5}。规则1,{I1,I2}=>{I5},置信度为{I1,I2,I5}的支持度除以{I1,I2}的支持度,即2/4=50%,因其小于最小置信度,所以删除该规则。同理,最后可以得到{I1,I5}=>{I2},{I2,I5}=>{I1}和{I5}=>{I1,I2}为3条强关联规则。

代码如下

# 函数返回为list类型,每一个元素为一个tuple,tuple中第一个元素为str类型,代表由A->B的强关联商品,第二个元素为float类型,代表由商品A->B的置信度。
# data是商品清单,类型为dict,key为str类型,代表用户名称,value为frozenset类型,代表该用户所购买的商品名称。
# freq_one是频繁1项集,类型为dict。
# min_support是最小支持度,默认为2,类型为int。
# min_conf是最小置信度,默认为0.6,类型为float。
def rule(data, freq_one, min_support=2, min_cnotallow=0.6): 
    strong_rule = []
    final_itemsets = frequent_all(data=data, freq_notallow=freq_one, min_support=min_support)
    for item_set in final_itemsets[1:]:
        for AB in item_set:
            for A in [j for i in get_all_subsets_not_self(AB) for j in i]:
                B = frozenset([i for i in AB if i not in A])
                if len(B) > 0:
                    x = float(final_itemsets[len(AB) - 1][AB])
                    y = float(final_itemsets[len(A) - 1][A])
                    confidence = x / y
                    if confidence >= min_confidence:
                        strong_rules.append((str(A) + ' -> ' + str(B), confidence))
    return strong_rules      

至此我们就得到了商品之间的关联关系。最后我就可以使用以上定义好的函数直接调用得到关联规则,代码只需一行,如下所示:

print(rule(data=load_example(), freq_notallow=freq_one(data=load_example()),min_support=2, min_cnotallow=0.6))      

运行结果如下:

深度剖析Apriori算法

继续阅读