Recommendation System Notes - 04 - User Recall Set Based on the ALS Model

1. Environment

hadoop-2.10.0

hive-3.1.2

hbase-2.2.2

spark-2.4.4
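
python-2.7 (the implementation below targets Python 2, per the PYSPARK_PYTHON setting in the script)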

2. Table Schemas

HBase

create 'multiple_recall', {NAME=>'als', TTL=>2592000, VERSIONS=>9999}


create 'history_recall', {NAME=>'channel', TTL=>2592000, VERSIONS=>9999}
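
Note: TTL=>2592000 is 30 days expressed in seconds, so recall data expires automatically; VERSIONS=>9999 retains many cell versions, which is what lets read_cells in the code below page through historical recall results.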
           

3. Python Implementation
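
The script below reads user click behaviour from Hive, trains a Spark ML ALS model, maps the indexed IDs back to real user and article IDs, and writes per-channel recall results to HBase.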

# -*- coding:utf-8 -*-

import os
import sys
import numpy as np
from datetime import datetime
BASE_PATH = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
sys.path.insert(0, BASE_PATH)
print sys.path
from offline import BaseSparkSession

default_encoding = 'utf-8'
if sys.getdefaultencoding() != default_encoding:
    reload(sys)
    sys.setdefaultencoding(default_encoding)

# Windows development-machine paths; raw strings avoid backslash-escape issues
os.environ['PYSPARK_PYTHON'] = r'F:\develop\python\Python27\python.exe'
os.environ['HADOOP_HOME'] = r'F:\develop\hadoop\hadoop-2.10.0'
os.environ['HADOOP_CONF_DIR'] = r'F:\develop\hadoop\hadoop-2.10.0-conf'
os.environ['SPARK_HOME'] = r'F:\develop\spark\spark-2.4.4-bin-hadoop2.7'


class UserRecallBasedModel(BaseSparkSession):

    def __init__(self):
        self.SPARK_APP_NAME = 'user_recall_based_model'
        self.SPARK_MASTER_URL = 'yarn'
        self.SPARK_YARN_QUEUE = 'queue3'
        self.ENABLE_HIVE_SUPPORT = True
        self.spark_session = self.create_spark_session()

    # Generate user recall results
    def gen_user_recall(self):
        self.spark_session.sql("use portal")
        # User article click behaviour: user_id, article_id, click_flag (the rating signal for ALS)
        user_click_behaviour_df = self.spark_session.sql("select user_id, article_id, click_flag from t_user_behaviour")
        user_click_behaviour_df.show()

        # Build user ID and article ID indexes: Spark's ALS needs numeric IDs
        from pyspark.ml.feature import StringIndexer
        from pyspark.ml import Pipeline
        user_id_indexer = StringIndexer(inputCol="user_id", outputCol="indexer_user_id")
        article_id_indexer = StringIndexer(inputCol="article_id", outputCol="indexer_article_id")
        pipeline = Pipeline(stages=[user_id_indexer, article_id_indexer])
        pipeline_model = pipeline.fit(user_click_behaviour_df)
        user_article_click_df = pipeline_model.transform(user_click_behaviour_df)
        user_article_click_df.show()

        # Mapping tables to translate the indexed IDs back to the original user and article IDs
        user_id_mapping_df = user_article_click_df.groupBy(["user_id"]) \
            .max("indexer_user_id").withColumnRenamed("max(indexer_user_id)", "indexer_user_id")
        user_id_mapping_df.show()
        article_id_mapping_df = user_article_click_df.groupBy(["article_id"]) \
            .max("indexer_article_id").withColumnRenamed("max(indexer_article_id)", "indexer_article_id")
        article_id_mapping_df.show()

        # Train the ALS model and recommend the top 20 articles per user
        from pyspark.ml.recommendation import ALS
        als = ALS(userCol="indexer_user_id", itemCol="indexer_article_id", ratingCol="click_flag")
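        # Note: click_flag is click feedback rather than an explicit rating;
        # ALS defaults to explicit ratings, so implicitPrefs=True may be worth
        # evaluating here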
        als_model = als.fit(user_article_click_df)
        als_result = als_model.recommendForAllUsers(20)
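        # als_result schema for reference:
        #   indexer_user_id: int
        #   recommendations: array<struct<indexer_article_id:int, rating:float>>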
        als_result.show()

        # Join back to get the real user IDs
        als_recall_df = als_result.join(user_id_mapping_df, on=["indexer_user_id"], how="left") \
            .select(["user_id", "indexer_user_id", "recommendations"])
        als_recall_df.show()

        # Explode the recommendations and join back to get the real article IDs
        import pyspark.sql.functions as F
        als_recall_df = als_recall_df.withColumn('indexer_article', F.explode("recommendations")) \
            .drop("recommendations").select(["user_id", "indexer_article"])
        als_recall_df = als_recall_df.rdd.map(lambda row: (row.user_id, row.indexer_article[0])) \
            .toDF(["user_id", "indexer_article_id"])
        als_recall_df.show()

        als_recall_df = als_recall_df.join(article_id_mapping_df, on=["indexer_article_id"], how="left") \
            .select(["user_id", "article_id"])
        als_recall_df.show()

        # Attach channel IDs and group the recommended articles by user and channel
        article_channel_df = self.spark_session.sql("select id as article_id, channel_id from t_article")
        als_recall_df = als_recall_df.join(article_channel_df, on=["article_id"], how="left")
        als_recall_df = als_recall_df.groupBy(["user_id", "channel_id"]).agg(F.collect_list("article_id")) \
            .withColumnRenamed("collect_list(article_id)", "articles").dropna()
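        # Final shape: user_id, channel_id, articles (list of recommended article IDs)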
        als_recall_df.show()

        # Store the user recall results in HBase
        def insert_user_recall(partition):

            sys.path.insert(0, os.path.join(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))))
            import json
            from recoutils.hbase_utils import HBaseUtils
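            # HBaseUtils is the project's HBase helper; 9090 is HBase's default
            # Thrift server port, and size is presumably the connection pool size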
            hbase_utils = HBaseUtils(host="192.168.0.1", port=9090, size=5)

            for row in partition:
                # Read all stored versions of the historical recall data
                history_recall_list = hbase_utils.read_cells("history_recall",
                                                             "u:{}".format(row.user_id).encode(),
                                                             "channel:{}".format(row.channel_id).encode())
                history_recommend_list = []
                for history_recall in history_recall_list:
                    # Values were written with json.dumps, so json.loads is the
                    # safe inverse (rather than eval)
                    history_recommend_list.extend(json.loads(history_recall))

                # Drop articles that have already been recommended
                recommend_list = list(set(row.articles) - set(history_recommend_list))

                # Store the new recall results and append them to the history
                if recommend_list:
                    hbase_utils.insert("multiple_recall",
                                       "u:{}".format(row.user_id).encode(),
                                       {"als:{}".format(row.channel_id).encode(): json.dumps(recommend_list).encode()})
                    hbase_utils.insert("history_recall",
                                       "u:{}".format(row.user_id).encode(),
                                       {"channel:{}".format(row.channel_id).encode(): json.dumps(recommend_list).encode()})

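        # Runs on the executors once per partition, so each partition opens its
        # own HBase connection pool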
        als_recall_df.foreachPartition(insert_user_recall)


if __name__ == '__main__':
    user_recall_based_model = UserRecallBasedModel()
    user_recall_based_model.gen_user_recall()
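
4. Reading the Recall Results (Sketch)

A minimal sketch of how an online service might read the ALS recall set back out of the multiple_recall table. This is not part of the original project: it assumes happybase is installed and that an HBase Thrift server is reachable at 192.168.0.1:9090 (matching the HBaseUtils settings above); the function name and the example IDs are illustrative.

# -*- coding:utf-8 -*-

import json
import happybase


def read_als_recall(user_id, channel_id, host="192.168.0.1", port=9090):
    # Row key and column qualifier mirror the write path above:
    # row "u:<user_id>", column "als:<channel_id>"
    connection = happybase.Connection(host=host, port=port)
    try:
        table = connection.table("multiple_recall")
        row_key = "u:{}".format(user_id).encode()
        column = "als:{}".format(channel_id).encode()
        value = table.row(row_key, columns=[column]).get(column)
        return json.loads(value) if value else []
    finally:
        connection.close()


if __name__ == '__main__':
    print read_als_recall(1, 18)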