天天看點

推薦系統lambda架構學習筆記之電商推薦案例(六)

一、個性化電商廣告推薦系統介紹

1.1資料集介紹

  • Ali_Display_Ad_Click是阿裡巴巴提供的一個淘寶展示廣告點選率預估資料集

    資料集來源:天池競賽

  • 原始樣本骨架raw_sample

    淘寶網站中随機抽樣了114萬使用者8天内的廣告展示/點選日志(2600萬條記錄),構成原始的樣本骨架。 字段說明如下:

    1. user_id:脫敏過的使用者ID;
    2. adgroup_id:脫敏過的廣告單元ID;
    3. time_stamp:時間戳;
    4. pid:資源位;
    5. noclk:為1代表沒有點選;為0代表點選;
    6. clk:為0代表沒有點選;為1代表點選;
    用前面7天的做訓練樣本(20170506-20170512),用第8天的做測試樣本(20170513)
  • 廣告基本資訊表ad_feature

    本資料集涵蓋了raw_sample中全部廣告的基本資訊(約80萬條目)。字段說明如下:

    1. adgroup_id:脫敏過的廣告ID;
    2. cate_id:脫敏過的商品類目ID;
    3. campaign_id:脫敏過的廣告計劃ID;
    4. customer_id: 脫敏過的廣告主ID;
    5. brand_id:脫敏過的品牌ID;
    6. price: 寶貝的價格
    其中一個廣告ID對應一個商品(寶貝),一個寶貝屬于一個類目,一個寶貝屬于一個品牌。
  • 使用者基本資訊表user_profile

    本資料集涵蓋了raw_sample中全部使用者的基本資訊(約100多萬使用者)。字段說明如下:

    1. userid:脫敏過的使用者ID;
    2. cms_segid:微群ID;
    3. cms_group_id:cms_group_id;
    4. final_gender_code:性别 1:男,2:女;
    5. age_level:年齡層次; 1234
    6. pvalue_level:消費檔次,1:低檔,2:中檔,3:高檔;
    7. shopping_level:購物深度,1:淺層使用者,2:中度使用者,3:深度使用者
    8. occupation:是否大學生 ,1:是,0:否
    9. new_user_class_level:城市層級
  • 使用者的行為日志behavior_log

    本資料集涵蓋了raw_sample中全部使用者22天内的購物行為(共七億條記錄)。字段說明如下:

    user:脫敏過的使用者ID;

    time_stamp:時間戳;

    btag:行為類型, 包括以下四種:

    ​ 類型 | 說明

    ​ pv | 浏覽

    ​ cart | 加入購物車

    ​ fav | 喜歡

    ​ buy | 購買

    cate_id:脫敏過的商品類目id;

    brand_id: 脫敏過的品牌id;

    這裡以user + time_stamp為key,會有很多重複的記錄;這是因為我們的不同的類型的行為資料是不同部門記錄的,在打包到一起的時候,實際上會有小的偏差(即兩個一樣的time_stamp實際上是差異比較小的兩個時間)

1.2 項目效果展示

推薦系統lambda架構學習筆記之電商推薦案例(六)

1.3 項目實作分析

  • 主要包括
    • 一份廣告點選的樣本資料raw_sample.csv:展現的是使用者對不同位置廣告點選、沒點選的情況
    • 一份廣告基本資訊資料ad_feature.csv:展現的是每個廣告的類目(id)、品牌(id)、價格特征
    • 一份使用者基本資訊資料user_profile.csv:展現的是使用者群組、性别、年齡、消費購物檔次、所在城市級别等特征
    • 一份使用者行為日志資料behavior_log.csv:展現使用者對商品類目(id)、品牌(id)的浏覽、加購物車、收藏、購買等資訊
    我們是在對非搜尋類型的廣告進行點選率預測和推薦(沒有搜尋詞、沒有廣告的内容特征資訊)
    1. 推薦業務處理主要流程: 召回 ===> 排序 ===> 過濾
      • 離線處理業務流
        • raw_sample.csv ==> 曆史樣本資料
        • ad_feature.csv ==> 廣告特征資料
        • user_profile.csv ==> 使用者特征資料
        • raw_sample.csv + ad_feature.csv + user_profile.csv ==> CTR點選率預測模型
        • behavior_log.csv ==> 評分資料 ==> user-cate/brand評分資料 ==> 協同過濾 ==> top-N cate/brand ==> 關聯廣告
        • 協同過濾召回 ==> top-N cate/brand ==> 關聯對應的廣告完成召回
      • 線上處理業務流
        • 資料處理部分:
          • 實時行為日志 ==> 實時特征 ==> 緩存
          • 實時行為日志 ==> 實時商品類别/品牌 ==> 實時廣告召回集 ==> 緩存
        • 推薦任務部分:
          • CTR點選率預測模型 + 廣告/使用者特征(緩存) + 對應的召回集(緩存) ==> 點選率排序 ==> top-N 廣告推薦結果
    2. 涉及技術:Flume、Kafka、Spark-streming\HDFS、Spark SQL、Spark ML、Redis
      • Flume:日志資料收集
      • Kafka:實時日志資料處理隊列
      • HDFS:存儲資料
      • Spark SQL:離線處理
      • Spark ML:模型訓練
      • Redis:緩存

1.4 點選率預測(CTR–Click-Through-Rate)概念

  • 電商廣告推薦通常使用廣告點選率(CTR–Click-Through-Rate)預測來實作

    點選率預測 VS 推薦算法

    點選率預測需要給出精準的點選機率,比如廣告A點選率0.5%、廣告B的點選率0.12%等;而推薦算法很多時候隻需要得出一個最優的次序A>B>C即可。

    點選率預測使用的算法通常是如邏輯回歸(Logic Regression)這樣的機器學習算法,而推薦算法則是一些基于協同過濾推薦、基于内容的推薦等思想實作的算法

    點選率 VS 轉化率

    點選率預測是對每次廣告的點選情況做出預測,可以判定這次為點選或不點選,也可以給出點選或不點選的機率

    轉化率指的是從狀态A進入到狀态B的機率,電商的轉化率通常是指到達網站後,進而有成交記錄的使用者比率,如使用者成交量/使用者通路量

    搜尋和非搜尋廣告點選率預測的差別

    搜尋中有很強的搜尋信号-“查詢詞(Query)”,查詢詞和廣告内容的比對程度很大程度影響了點選機率,搜尋廣告的點選率普遍較高

    非搜尋廣告(例如展示廣告,資訊流廣告)的點選率的計算很多就來源于使用者的興趣和廣告自身的特征,以及上下文環境。通常好位置能達到百分之幾的點選率。對于很多底部的廣告,點選率非常低,常常是千分之幾,甚至更低

二、根據使用者行為資料建立ALS模型并召回商品

2.0 使用者行為資料拆分

  • Pandas讀取大檔案-分批處理
    • pandas的資料分批讀取 chunk 厚厚的一塊 相當大的數量或部分
    import pandas as pd
    reader = pd.read_csv('behavior_log.csv',chunksize=100,iterator=True)
    count = 0;
    for chunk in reader:
        count += 1
        if count ==1:
            chunk.to_csv('test4.csv',index = False)
        elif count>1 and count<1000:
            chunk.to_csv('test4.csv',index = False, mode = 'a',header = False)
        else:
            break
    pd.read_csv('test4.csv')
               

2.1 預處理behavior_log資料集

  • 建立spark session
import os
# 配置spark driver和pyspark運作時,所使用的python解釋器路徑
PYSPARK_PYTHON = "/miniconda2/envs/py365/bin/python"
JAVA_HOME='/root/bigdata/jdk'
SPARK_HOME = "/root/bigdata/spark"
# 當存在多個版本時,不指定很可能會導緻出錯
os.environ["PYSPARK_PYTHON"] = PYSPARK_PYTHON
os.environ["PYSPARK_DRIVER_PYTHON"] = PYSPARK_PYTHON
os.environ['JAVA_HOME']=JAVA_HOME
os.environ["SPARK_HOME"] = SPARK_HOME
# spark配置資訊
from pyspark import SparkConf
from pyspark.sql import SparkSession

SPARK_APP_NAME = "preprocessingBehaviorLog"
SPARK_URL = "spark://192.168.19.137:7077"

conf = SparkConf()    # 建立spark config對象
config = (
	("spark.app.name", SPARK_APP_NAME),    # 設定啟動的spark的app名稱,沒有提供,将随機産生一個名稱
	("spark.executor.memory", "6g"),    # 設定該app啟動時占用的記憶體用量,預設1g
	("spark.master", SPARK_URL),    # spark master的位址
    ("spark.executor.cores", "4"),    # 設定spark executor使用的CPU核心數
    # 以下三項配置,可以控制執行器數量
#     ("spark.dynamicAllocation.enabled", True),
#     ("spark.dynamicAllocation.initialExecutors", 1),    # 1個執行器
#     ("spark.shuffle.service.enabled", True)
# 	('spark.sql.pivotMaxValues', '99999'),  # 當需要pivot DF,且值很多時,需要修改,預設是10000
)
# 檢視更詳細配置及說明:https://spark.apache.org/docs/latest/configuration.html

conf.setAll(config)

# 利用config對象,建立spark session
spark = SparkSession.builder.config(conf=conf).getOrCreate()
           
  • 從hdfs中加載csv檔案為DataFrame
# 從hdfs加載CSV檔案為DataFrame
df = spark.read.csv("hdfs://localhost:9000/data/behavior_log.csv", header=True)
df.show()    # 檢視dataframe,預設顯示前20條
# 大緻檢視一下資料類型
df.printSchema()    # 列印目前dataframe的結構
           

顯示結果:

+------+----------+----+-----+------+
|  user|time_stamp|btag| cate| brand|
+------+----------+----+-----+------+
|558157|1493741625|  pv| 6250| 91286|
|558157|1493741626|  pv| 6250| 91286|
|558157|1493741627|  pv| 6250| 91286|
|728690|1493776998|  pv|11800| 62353|
|332634|1493809895|  pv| 1101|365477|
|857237|1493816945|  pv| 1043|110616|
|619381|1493774638|  pv|  385|428950|
|467042|1493772641|  pv| 8237|301299|
|467042|1493772644|  pv| 8237|301299|
|991528|1493780710|  pv| 7270|274795|
|991528|1493780712|  pv| 7270|274795|
|991528|1493780712|  pv| 7270|274795|
|991528|1493780712|  pv| 7270|274795|
|991528|1493780714|  pv| 7270|274795|
|991528|1493780765|  pv| 7270|274795|
|991528|1493780714|  pv| 7270|274795|
|991528|1493780765|  pv| 7270|274795|
|991528|1493780764|  pv| 7270|274795|
|991528|1493780633|  pv| 7270|274795|
|991528|1493780764|  pv| 7270|274795|
+------+----------+----+-----+------+
only showing top 20 rows

root
 |-- user: string (nullable = true)
 |-- time_stamp: string (nullable = true)
 |-- btag: string (nullable = true)
 |-- cate: string (nullable = true)
 |-- brand: string (nullable = true)
           
  • 從hdfs加載資料為dataframe,并設定結構
from pyspark.sql.types import StructType, StructField, StringType, IntegerType, LongType
# 建構結構對象
schema = StructType([
    StructField("userId", IntegerType()),
    StructField("timestamp", LongType()),
    StructField("btag", StringType()),
    StructField("cateId", IntegerType()),
    StructField("brandId", IntegerType())
])
# 從hdfs加載資料為dataframe,并設定結構
behavior_log_df = spark.read.csv("hdfs://localhost:9000/data/behavior_log.csv", header=True, schema=schema)
behavior_log_df.show()
behavior_log_df.count() 
           

顯示結果:

+------+----------+----+------+-------+
|userId| timestamp|btag|cateId|brandId|
+------+----------+----+------+-------+
|558157|1493741625|  pv|  6250|  91286|
|558157|1493741626|  pv|  6250|  91286|
|558157|1493741627|  pv|  6250|  91286|
|728690|1493776998|  pv| 11800|  62353|
|332634|1493809895|  pv|  1101| 365477|
|857237|1493816945|  pv|  1043| 110616|
|619381|1493774638|  pv|   385| 428950|
|467042|1493772641|  pv|  8237| 301299|
|467042|1493772644|  pv|  8237| 301299|
|991528|1493780710|  pv|  7270| 274795|
|991528|1493780712|  pv|  7270| 274795|
|991528|1493780712|  pv|  7270| 274795|
|991528|1493780712|  pv|  7270| 274795|
|991528|1493780714|  pv|  7270| 274795|
|991528|1493780765|  pv|  7270| 274795|
|991528|1493780714|  pv|  7270| 274795|
|991528|1493780765|  pv|  7270| 274795|
|991528|1493780764|  pv|  7270| 274795|
|991528|1493780633|  pv|  7270| 274795|
|991528|1493780764|  pv|  7270| 274795|
+------+----------+----+------+-------+
only showing top 20 rows

root
 |-- userId: integer (nullable = true)
 |-- timestamp: long (nullable = true)
 |-- btag: string (nullable = true)
 |-- cateId: integer (nullable = true)
 |-- brandId: integer (nullable = true)
           
  • 分析資料集字段的類型和格式
    • 檢視是否有空值
    • 檢視每列資料的類型
    • 檢視每列資料的類别情況
print("檢視userId的資料情況:", behavior_log_df.groupBy("userId").count().count())
# 約113w使用者
# 注意:behavior_log_df.groupBy("userId").count()  傳回的是一個dataframe,這裡的count計算的是每一個分組的個數,但目前還沒有進行計算
# 當調用df.count()時才開始進行計算,這裡的count計算的是dataframe的條目數,也就是共有多少個分組
           
檢視user的資料情況: 1136340
           
print("檢視btag的資料情況:", behavior_log_df.groupBy("btag").count().collect())    # collect會把計算結果全部加載到記憶體,謹慎使用
# 隻有四種類型資料:pv、fav、cart、buy
# 這裡由于類型隻有四個,是以直接使用collect,把資料全部加載出來
           
print("檢視cateId的資料情況:", behavior_log_df.groupBy("cateId").count().count())
# 約12968類别id
           
檢視cateId的資料情況: 12968
           
print("檢視brandId的資料情況:", behavior_log_df.groupBy("brandId").count().count())
# 約460561品牌id
           
檢視brandId的資料情況: 460561
           
print("判斷資料是否有空值:", behavior_log_df.count(), behavior_log_df.dropna().count())
# 約7億條目723268134 723268134
# 本資料集無空值條目,可放心處理
           
判斷資料是否有空值: 723268134 723268134
           
  • pivot透視操作,把某列裡的字段值轉換成行并進行聚合運算(pyspark.sql.GroupedData.pivot)
    • 如果透視的字段中的不同屬性值超過10000個,則需要設定spark.sql.pivotMaxValues,否則計算過程中會出現錯誤。文檔介紹。
# 統計每個使用者對各類商品的pv、fav、cart、buy數量
cate_count_df = behavior_log_df.groupBy(behavior_log_df.userId, behavior_log_df.cateId).pivot("btag",["pv","fav","cart","buy"]).count()
cate_count_df.printSchema()    # 此時還沒有開始計算
           

顯示效果:

root
 |-- userId: integer (nullable = true)
 |-- cateId: integer (nullable = true)
 |-- pv: long (nullable = true)
 |-- fav: long (nullable = true)
 |-- cart: long (nullable = true)
 |-- buy: long (nullable = true)
           
  • 統計每個使用者對各個品牌的pv、fav、cart、buy數量并儲存結果
# 統計每個使用者對各個品牌的pv、fav、cart、buy數量
brand_count_df = behavior_log_df.groupBy(behavior_log_df.userId, behavior_log_df.brandId).pivot("btag",["pv","fav","cart","buy"]).count()
# brand_count_df.show()    # 同上
# 113w * 46w
# 由于運算時間比較長,是以這裡先将結果存儲起來,供後續其他操作使用
# 寫入資料時才開始計算
cate_count_df.write.csv("hdfs://localhost:9000/preprocessing_dataset/cate_count.csv", header=True)
brand_count_df.write.csv("hdfs://localhost:9000/preprocessing_dataset/brand_count.csv", header=True)
           

2.2 根據使用者對類目偏好打分訓練ALS模型

  • 根據您統計的次數 + 打分規則 ==> 偏好打分資料集 ==> ALS模型
# spark ml的模型訓練是基于記憶體的,如果資料過大,記憶體空間小,疊代次數過多的化,可能會造成記憶體溢出,報錯
# 設定Checkpoint的話,會把所有資料落盤,這樣如果異常退出,下次重新開機後,可以接着上次的訓練節點繼續運作
# 但該方法其實名額不治本,因為無法防止記憶體溢出,是以還是會報錯
# 如果資料量大,應考慮的是增加記憶體、或限制疊代次數和訓練資料量級等
spark.sparkContext.setCheckpointDir("hdfs://localhost:9000/checkPoint/")
from pyspark.sql.types import StructType, StructField, StringType, IntegerType, LongType, FloatType

# 建構結構對象
schema = StructType([
    StructField("userId", IntegerType()),
    StructField("cateId", IntegerType()),
    StructField("pv", IntegerType()),
    StructField("fav", IntegerType()),
    StructField("cart", IntegerType()),
    StructField("buy", IntegerType())
])

# 從hdfs加載CSV檔案
cate_count_df = spark.read.csv("hdfs://localhost:9000/preprocessing_dataset/cate_count.csv", header=True, schema=schema)
cate_count_df.printSchema()
cate_count_df.first()    # 第一行資料
           

顯示結果:

root
 |-- userId: integer (nullable = true)
 |-- cateId: integer (nullable = true)
 |-- pv: integer (nullable = true)
 |-- fav: integer (nullable = true)
 |-- cart: integer (nullable = true)
 |-- buy: integer (nullable = true)

Row(userId=1061650, cateId=4520, pv=2326, fav=None, cart=53, buy=None)
           
  • 處理每一行資料:r表示row對象
def process_row(r):
    # 處理每一行資料:r表示row對象
    
    # 偏好評分規則:
	#     m: 使用者對應的行為次數
    #     該偏好權重比例,次數上限僅供參考,具體數值應根據産品業務場景權衡
	#     pv: if m<=20: score=0.2*m; else score=4
	#     fav: if m<=20: score=0.4*m; else score=8
	#     cart: if m<=20: score=0.6*m; else score=12
	#     buy: if m<=20: score=1*m; else score=20
    
    # 注意這裡要全部設為浮點數,spark運算時對類型比較敏感,要保持資料類型都一緻
	pv_count = r.pv if r.pv else 0.0
	fav_count = r.fav if r.fav else 0.0
	cart_count = r.cart if r.cart else 0.0
	buy_count = r.buy if r.buy else 0.0

	pv_score = 0.2*pv_count if pv_count<=20 else 4.0
	fav_score = 0.4*fav_count if fav_count<=20 else 8.0
	cart_score = 0.6*cart_count if cart_count<=20 else 12.0
	buy_score = 1.0*buy_count if buy_count<=20 else 20.0

	rating = pv_score + fav_score + cart_score + buy_score
	# 傳回使用者ID、分類ID、使用者對分類的偏好打分
	return r.userId, r.cateId, rating
           
  • 傳回一個PythonRDD類型
# 傳回一個PythonRDD類型,此時還沒開始計算
cate_count_df.rdd.map(process_row).toDF(["userId", "cateId", "rating"])
           

顯示結果:

  • 使用者對商品類别的打分資料
# 使用者對商品類别的打分資料
# map傳回的結果是rdd類型,需要調用toDF方法轉換為Dataframe
cate_rating_df = cate_count_df.rdd.map(process_row).toDF(["userId", "cateId", "rating"])
# 注意:toDF不是每個rdd都有的方法,僅局限于此處的rdd
# 可通過該方法獲得 user-cate-matrix
# 但由于cateId字段過多,這裡運算量比很大,機器記憶體要求很高才能執行,否則無法完成任務
# 請謹慎使用

# 但好在我們訓練ALS模型時,不需要轉換為user-cate-matrix,是以這裡可以不用運作
# cate_rating_df.groupBy("userId").povit("cateId").min("rating")
# 使用者對類别的偏好打分資料
cate_rating_df
           

顯示結果:

DataFrame[userId: bigint, cateId: bigint, rating: double]
           
  • 通常如果USER-ITEM打分資料應該是通過一下方式進行處理轉換為USER-ITEM-MATRIX
推薦系統lambda架構學習筆記之電商推薦案例(六)

但這裡我們将使用的Spark的ALS模型進行CF推薦,是以注意這裡資料輸入不需要提前轉換為矩陣,直接是 USER-ITEM-RATE的資料

  • 基于Spark的ALS隐因子模型進行CF評分預測
    • ALS的意思是交替最小二乘法(Alternating Least Squares),是Spark2.*中加入的進行基于模型的協同過濾(model-based CF)的推薦系統算法。

      同SVD,它也是一種矩陣分解技術,對資料進行降維處理。

    • 詳細使用方法:pyspark.ml.recommendation.ALS
    • 注意:由于資料量巨大,是以這裡也不考慮基于記憶體的CF算法

      參考:為什麼Spark中隻有ALS

# 使用pyspark中的ALS矩陣分解方法實作CF評分預測
# 文檔位址:https://spark.apache.org/docs/2.2.2/api/python/pyspark.ml.html?highlight=vectors#module-pyspark.ml.recommendation
from pyspark.ml.recommendation import ALS   # ml:dataframe, mllib:rdd

# 利用打分資料,訓練ALS模型
# checkpointInterval 間隔多少次生成一次checkpoint
als = ALS(userCol='userId', itemCol='cateId', ratingCol='rating', checkpointInterval=5)

# 此處訓練時間較長
model = als.fit(cate_rating_df)
           
  • 模型訓練好後,調用方法進行使用,具體API檢視
# model.recommendForAllUsers(N) 給所有使用者推薦TOP-N個物品
ret = model.recommendForAllUsers(3)
# 由于是給所有使用者進行推薦,此處運算時間也較長
ret.show()
# 推薦結果存放在recommendations列中,
ret.select("recommendations").show()
           

顯示結果:

+------+--------------------+
|userId|     recommendations|
+------+--------------------+
|   148|[[3347, 12.547271...|
|   463|[[1610, 9.250818]...|
|   471|[[1610, 10.246621...|
|   496|[[1610, 5.162216]...|
|   833|[[5607, 9.065482]...|
|  1088|[[104, 6.886987],...|
|  1238|[[5631, 14.51981]...|
|  1342|[[5720, 10.89842]...|
|  1580|[[5731, 8.466453]...|
|  1591|[[1610, 12.835257...|
|  1645|[[1610, 11.968531...|
|  1829|[[1610, 17.576496...|
|  1959|[[1610, 8.353473]...|
|  2122|[[1610, 12.652732...|
|  2142|[[1610, 12.48068]...|
|  2366|[[1610, 11.904813...|
|  2659|[[5607, 11.699315...|
|  2866|[[1610, 7.752719]...|
|  3175|[[3347, 2.3429515...|
|  3749|[[1610, 3.641833]...|
+------+--------------------+
only showing top 20 rows

+--------------------+
|     recommendations|
+--------------------+
|[[3347, 12.547271...|
|[[1610, 9.250818]...|
|[[1610, 10.246621...|
|[[1610, 5.162216]...|
|[[5607, 9.065482]...|
|[[104, 6.886987],...|
|[[5631, 14.51981]...|
|[[5720, 10.89842]...|
|[[5731, 8.466453]...|
|[[1610, 12.835257...|
|[[1610, 11.968531...|
|[[1610, 17.576496...|
|[[1610, 8.353473]...|
|[[1610, 12.652732...|
|[[1610, 12.48068]...|
|[[1610, 11.904813...|
|[[5607, 11.699315...|
|[[1610, 7.752719]...|
|[[3347, 2.3429515...|
|[[1610, 3.641833]...|
+--------------------+
only showing top 20 rows
           
  • model.recommendForUserSubset 給部分使用者推薦TOP-N個物品
# 注意:recommendForUserSubset API,2.2.2版本中無法使用
dataset = spark.createDataFrame([[1],[2],[3]])
dataset = dataset.withColumnRenamed("_1", "userId")
ret = model.recommendForUserSubset(dataset, 3)

# 隻給部分用推薦,運算時間短
ret.show()
ret.collect()    # 注意: collect會将所有資料加載到記憶體,慎用
           

顯示結果:

+------+--------------------+
|userId|     recommendations|
+------+--------------------+
|     1|[[1610, 25.4989],...|
|     3|[[5607, 13.665942...|
|     2|[[5579, 5.9051886...|
+------+--------------------+

[Row(userId=1, recommendations=[Row(cateId=1610, rating=25.498899459838867), Row(cateId=5737, rating=24.901548385620117), Row(cateId=3347, rating=20.736785888671875)]),
 Row(userId=3, recommendations=[Row(cateId=5607, rating=13.665942192077637), Row(cateId=1610, rating=11.770171165466309), Row(cateId=3347, rating=10.35690689086914)]),
 Row(userId=2, recommendations=[Row(cateId=5579, rating=5.90518856048584), Row(cateId=2447, rating=5.624575138092041), Row(cateId=5690, rating=5.2555742263793945)])]
           
  • transform中提供userId和cateId可以對打分進行預測,利用打分結果排序後
# transform中提供userId和cateId可以對打分進行預測,利用打分結果排序後,同樣可以實作TOP-N的推薦
model.transform
# 将模型進行存儲
model.save("hdfs://localhost:9000/models/userCateRatingALSModel.obj")
# 測試存儲的模型
from pyspark.ml.recommendation import ALSModel
# 從hdfs加載之前存儲的模型
als_model = ALSModel.load("hdfs://localhost:9000/models/userCateRatingALSModel.obj")
# model.recommendForAllUsers(N) 給使用者推薦TOP-N個物品
result = als_model.recommendForAllUsers(3)
result.show()
           

顯示結果:

+------+--------------------+
|userId|     recommendations|
+------+--------------------+
|   148|[[3347, 12.547271...|
|   463|[[1610, 9.250818]...|
|   471|[[1610, 10.246621...|
|   496|[[1610, 5.162216]...|
|   833|[[5607, 9.065482]...|
|  1088|[[104, 6.886987],...|
|  1238|[[5631, 14.51981]...|
|  1342|[[5720, 10.89842]...|
|  1580|[[5731, 8.466453]...|
|  1591|[[1610, 12.835257...|
|  1645|[[1610, 11.968531...|
|  1829|[[1610, 17.576496...|
|  1959|[[1610, 8.353473]...|
|  2122|[[1610, 12.652732...|
|  2142|[[1610, 12.48068]...|
|  2366|[[1610, 11.904813...|
|  2659|[[5607, 11.699315...|
|  2866|[[1610, 7.752719]...|
|  3175|[[3347, 2.3429515...|
|  3749|[[1610, 3.641833]...|
+------+--------------------+
only showing top 20 rows
           
  • 召回到redis
import redis
host = "192.168.19.137"
port = 6379    
# 召回到redis
def recall_cate_by_cf(partition):
    # 建立redis 連接配接池
    pool = redis.ConnectionPool(host=host, port=port)
    # 建立redis用戶端
    client = redis.Redis(connection_pool=pool)
    for row in partition:
        client.hset("recall_cate", row.userId, [i.cateId for i in row.recommendations])
# 對每個分片的資料進行處理 #mapPartitions Transformation   map
# foreachPartition Action操作             foreachRDD
result.foreachPartition(recall_cate_by_cf)

# 注意:這裡這是召回的是使用者最感興趣的n個類别
# 總的條目數,檢視redis中總的條目數是否一緻
result.count()
           

顯示結果:

1136340
           

2.3 根據使用者對品牌偏好打分訓練ALS模型

from pyspark.sql.types import StructType, StructField, StringType, IntegerType

schema = StructType([
    StructField("userId", IntegerType()),
    StructField("brandId", IntegerType()),
    StructField("pv", IntegerType()),
    StructField("fav", IntegerType()),
    StructField("cart", IntegerType()),
    StructField("buy", IntegerType())
])
# 從hdfs加載預處理好的品牌的統計資料
brand_count_df = spark.read.csv("hdfs://localhost:9000/preprocessing_dataset/brand_count.csv", header=True, schema=schema)
# brand_count_df.show()
def process_row(r):
    # 處理每一行資料:r表示row對象
    
    # 偏好評分規則:
	#     m: 使用者對應的行為次數
    #     該偏好權重比例,次數上限僅供參考,具體數值應根據産品業務場景權衡
	#     pv: if m<=20: score=0.2*m; else score=4
	#     fav: if m<=20: score=0.4*m; else score=8
	#     cart: if m<=20: score=0.6*m; else score=12
	#     buy: if m<=20: score=1*m; else score=20
    
    # 注意這裡要全部設為浮點數,spark運算時對類型比較敏感,要保持資料類型都一緻
	pv_count = r.pv if r.pv else 0.0
	fav_count = r.fav if r.fav else 0.0
	cart_count = r.cart if r.cart else 0.0
	buy_count = r.buy if r.buy else 0.0

	pv_score = 0.2*pv_count if pv_count<=20 else 4.0
	fav_score = 0.4*fav_count if fav_count<=20 else 8.0
	cart_score = 0.6*cart_count if cart_count<=20 else 12.0
	buy_score = 1.0*buy_count if buy_count<=20 else 20.0

	rating = pv_score + fav_score + cart_score + buy_score
	# 傳回使用者ID、品牌ID、使用者對品牌的偏好打分
	return r.userId, r.brandId, rating
# 使用者對品牌的打分資料
brand_rating_df = brand_count_df.rdd.map(process_row).toDF(["userId", "brandId", "rating"])
# brand_rating_df.show()
           
  • 基于Spark的ALS隐因子模型進行CF評分預測
    • ALS的意思是交替最小二乘法(Alternating Least Squares),是Spark中進行基于模型的協同過濾(model-based CF)的推薦系統算法,也是目前Spark内唯一一個推薦算法。

      同SVD,它也是一種矩陣分解技術,但理論上,ALS在海量資料的處理上要優于SVD。

      更多了解:pyspark.ml.recommendation.ALS

      注意:由于資料量巨大,是以這裡不考慮基于記憶體的CF算法

      參考:為什麼Spark中隻有ALS

  • 使用pyspark中的ALS矩陣分解方法實作CF評分預測
# 使用pyspark中的ALS矩陣分解方法實作CF評分預測
# 文檔位址:https://spark.apache.org/docs/latest/api/python/pyspark.ml.html?highlight=vectors#module-pyspark.ml.recommendation
from pyspark.ml.recommendation import ALS

als = ALS(userCol='userId', itemCol='brandId', ratingCol='rating', checkpointInterval=2)
# 利用打分資料,訓練ALS模型
# 此處訓練時間較長
model = als.fit(brand_rating_df)
# model.recommendForAllUsers(N) 給使用者推薦TOP-N個物品
model.recommendForAllUsers(3).show()
# 将模型進行存儲
model.save("hdfs://localhost:9000/models/userBrandRatingModel.obj")
# 測試存儲的模型
from pyspark.ml.recommendation import ALSModel
# 從hdfs加載模型
my_model = ALSModel.load("hdfs://localhost:9000/models/userBrandRatingModel.obj")
my_model
# model.recommendForAllUsers(N) 給使用者推薦TOP-N個物品
my_model.recommendForAllUsers(3).first()
           

三、 CTR預估資料準備

3.1 分析并預處理raw_sample資料集

# 從HDFS中加載樣本資料資訊
df = spark.read.csv("hdfs://localhost:9000/data/raw_sample.csv", header=True)
df.show()    # 展示資料,預設前20條
df.printSchema()
           

顯示結果:

+------+----------+----------+-----------+------+---+
|  user|time_stamp|adgroup_id|        pid|nonclk|clk|
+------+----------+----------+-----------+------+---+
|581738|1494137644|         1|430548_1007|     1|  0|
|449818|1494638778|         3|430548_1007|     1|  0|
|914836|1494650879|         4|430548_1007|     1|  0|
|914836|1494651029|         5|430548_1007|     1|  0|
|399907|1494302958|         8|430548_1007|     1|  0|
|628137|1494524935|         9|430548_1007|     1|  0|
|298139|1494462593|         9|430539_1007|     1|  0|
|775475|1494561036|         9|430548_1007|     1|  0|
|555266|1494307136|        11|430539_1007|     1|  0|
|117840|1494036743|        11|430548_1007|     1|  0|
|739815|1494115387|        11|430539_1007|     1|  0|
|623911|1494625301|        11|430548_1007|     1|  0|
|623911|1494451608|        11|430548_1007|     1|  0|
|421590|1494034144|        11|430548_1007|     1|  0|
|976358|1494156949|        13|430548_1007|     1|  0|
|286630|1494218579|        13|430539_1007|     1|  0|
|286630|1494289247|        13|430539_1007|     1|  0|
|771431|1494153867|        13|430548_1007|     1|  0|
|707120|1494220810|        13|430548_1007|     1|  0|
|530454|1494293746|        13|430548_1007|     1|  0|
+------+----------+----------+-----------+------+---+
only showing top 20 rows

root
 |-- user: string (nullable = true)
 |-- time_stamp: string (nullable = true)
 |-- adgroup_id: string (nullable = true)
 |-- pid: string (nullable = true)
 |-- nonclk: string (nullable = true)
 |-- clk: string (nullable = true)
           
  • 分析資料集字段的類型和格式
    • 檢視是否有空值
    • 檢視每列資料的類型
    • 檢視每列資料的類别情況
print("樣本資料集總條目數:", df.count())
# 約2600w
print("使用者user總數:", df.groupBy("user").count().count())
# 約 114w,略多餘日志資料中使用者數
print("廣告id adgroup_id總數:", df.groupBy("adgroup_id").count().count())
# 約85w
print("廣告展示位pid情況:", df.groupBy("pid").count().collect())
# 隻有兩種廣告展示位,占比約為六比四
print("廣告點選資料情況clk:", df.groupBy("clk").count().collect())
# 點和不點比率約: 1:20
           

顯示結果:

樣本資料集總條目數: 26557961
使用者user總數: 1141729
廣告id adgroup_id總數: 846811
廣告展示位pid情況: [Row(pid='430548_1007', count=16472898), Row(pid='430539_1007', count=10085063)]
廣告點選資料情況clk: [Row(clk='0', count=25191905), Row(clk='1', count=1366056)]
           
  • 使用dataframe.withColumn更改df列資料結構;使用dataframe.withColumnRenamed更改列名稱
# 更改表結構,轉換為對應的資料類型
from pyspark.sql.types import StructType, StructField, IntegerType, FloatType, LongType, StringType

# 列印df結構資訊
df.printSchema()   
# 更改df表結構:更改列類型和列名稱
raw_sample_df = df.\
    withColumn("user", df.user.cast(IntegerType())).withColumnRenamed("user", "userId").\
    withColumn("time_stamp", df.time_stamp.cast(LongType())).withColumnRenamed("time_stamp", "timestamp").\
    withColumn("adgroup_id", df.adgroup_id.cast(IntegerType())).withColumnRenamed("adgroup_id", "adgroupId").\
    withColumn("pid", df.pid.cast(StringType())).\
    withColumn("nonclk", df.nonclk.cast(IntegerType())).\
    withColumn("clk", df.clk.cast(IntegerType()))
raw_sample_df.printSchema()
raw_sample_df.show()
           

顯示結果:

root
 |-- user: string (nullable = true)
 |-- time_stamp: string (nullable = true)
 |-- adgroup_id: string (nullable = true)
 |-- pid: string (nullable = true)
 |-- nonclk: string (nullable = true)
 |-- clk: string (nullable = true)

root
 |-- userId: integer (nullable = true)
 |-- timestamp: long (nullable = true)
 |-- adgroupId: integer (nullable = true)
 |-- pid: string (nullable = true)
 |-- nonclk: integer (nullable = true)
 |-- clk: integer (nullable = true)

+------+----------+---------+-----------+------+---+
|userId| timestamp|adgroupId|        pid|nonclk|clk|
+------+----------+---------+-----------+------+---+
|581738|1494137644|        1|430548_1007|     1|  0|
|449818|1494638778|        3|430548_1007|     1|  0|
|914836|1494650879|        4|430548_1007|     1|  0|
|914836|1494651029|        5|430548_1007|     1|  0|
|399907|1494302958|        8|430548_1007|     1|  0|
|628137|1494524935|        9|430548_1007|     1|  0|
|298139|1494462593|        9|430539_1007|     1|  0|
|775475|1494561036|        9|430548_1007|     1|  0|
|555266|1494307136|       11|430539_1007|     1|  0|
|117840|1494036743|       11|430548_1007|     1|  0|
|739815|1494115387|       11|430539_1007|     1|  0|
|623911|1494625301|       11|430548_1007|     1|  0|
|623911|1494451608|       11|430548_1007|     1|  0|
|421590|1494034144|       11|430548_1007|     1|  0|
|976358|1494156949|       13|430548_1007|     1|  0|
|286630|1494218579|       13|430539_1007|     1|  0|
|286630|1494289247|       13|430539_1007|     1|  0|
|771431|1494153867|       13|430548_1007|     1|  0|
|707120|1494220810|       13|430548_1007|     1|  0|
|530454|1494293746|       13|430548_1007|     1|  0|
+------+----------+---------+-----------+------+---+
only showing top 20 rows
           
  • 特征選取(Feature Selection)
    • 特征選擇就是選擇那些靠譜的Feature,去掉備援的Feature,對于搜尋廣告,Query關鍵詞和廣告的比對程度很重要;但對于展示廣告,廣告本身的曆史表現,往往是最重要的Feature。

      根據經驗,該資料集中,隻有廣告展示位pid對比較重要,且資料不同資料之間的占比約為6:4,是以pid可以作為一個關鍵特征

      nonclk和clk在這裡是作為目标值,不做為特征

  • 熱獨編碼 OneHotEncode
    • 熱獨編碼是一種經典編碼,是使用N位狀态寄存器(如0和1)來對N個狀态進行編碼,每個狀态都由他獨立的寄存器位,并且在任意時候,其中隻有一位有效。

      假設有三組特征,分别表示年齡,城市,裝置;

      [“男”, “女”][0,1]

      [“北京”, “上海”, “廣州”][0,1,2]

      [“蘋果”, “小米”, “華為”, “微軟”][0,1,2,3]

      傳統變化: 對每一組特征,使用枚舉類型,從0開始;

      ["男“,”上海“,”小米“]=[ 0,1,1]

      ["女“,”北京“,”蘋果“] =[1,0,0]

      傳統變化後的資料不是連續的,而是随機配置設定的,不容易應用在分類器中

      而經過熱獨編碼,資料會變成稀疏的,友善分類器處理:

      ["男“,”上海“,”小米“]=[ 1,0,0,1,0,0,1,0,0]

      ["女“,”北京“,”蘋果“] =[0,1,1,0,0,1,0,0,0]

      這樣做保留了特征的多樣性,但是也要注意如果資料過于稀疏(樣本較少、次元過高),其效果反而會變差

  • Spark中使用熱獨編碼
    • 注意:熱編碼隻能對字元串類型的列資料進行處理

      StringIndexer:對指定字元串列資料進行特征處理,如将性别資料“男”、“女”轉化為0和1

      OneHotEncoder:對特征列資料,進行熱編碼,通常需結合StringIndexer一起使用

      Pipeline:讓資料按順序依次被處理,将前一次的處理結果作為下一次的輸入

  • 特征處理
'''特征處理'''
'''
pid 資源位。該特征屬于分類特征,隻有兩類取值,是以考慮進行熱編碼處理即可,分為是否在資源位1、是否在資源位2 兩個特征
'''
from pyspark.ml.feature import OneHotEncoder
from pyspark.ml.feature import StringIndexer
from pyspark.ml import Pipeline

# StringIndexer對指定字元串列進行特征處理 [A,B,C]->[0,1,2]
stringindexer = StringIndexer(inputCol='pid', outputCol='pid_feature')

# 對處理出來的特征處理列進行,熱獨編碼
encoder = OneHotEncoder(dropLast=False, inputCol='pid_feature', outputCol='pid_value')
# 利用管道對每一個資料進行熱獨編碼處理
pipeline = Pipeline(stages=[stringindexer, encoder])
pipeline_model = pipeline.fit(raw_sample_df)
new_df = pipeline_model.transform(raw_sample_df)
new_df.show()
           

顯示結果:

+------+----------+---------+-----------+------+---+-----------+-------------+
|userId| timestamp|adgroupId|        pid|nonclk|clk|pid_feature|    pid_value|
+------+----------+---------+-----------+------+---+-----------+-------------+
|581738|1494137644|        1|430548_1007|     1|  0|        0.0|(2,[0],[1.0])|
|449818|1494638778|        3|430548_1007|     1|  0|        0.0|(2,[0],[1.0])|
|914836|1494650879|        4|430548_1007|     1|  0|        0.0|(2,[0],[1.0])|
|914836|1494651029|        5|430548_1007|     1|  0|        0.0|(2,[0],[1.0])|
|399907|1494302958|        8|430548_1007|     1|  0|        0.0|(2,[0],[1.0])|
|628137|1494524935|        9|430548_1007|     1|  0|        0.0|(2,[0],[1.0])|
|298139|1494462593|        9|430539_1007|     1|  0|        1.0|(2,[1],[1.0])|
|775475|1494561036|        9|430548_1007|     1|  0|        0.0|(2,[0],[1.0])|
|555266|1494307136|       11|430539_1007|     1|  0|        1.0|(2,[1],[1.0])|
|117840|1494036743|       11|430548_1007|     1|  0|        0.0|(2,[0],[1.0])|
|739815|1494115387|       11|430539_1007|     1|  0|        1.0|(2,[1],[1.0])|
|623911|1494625301|       11|430548_1007|     1|  0|        0.0|(2,[0],[1.0])|
|623911|1494451608|       11|430548_1007|     1|  0|        0.0|(2,[0],[1.0])|
|421590|1494034144|       11|430548_1007|     1|  0|        0.0|(2,[0],[1.0])|
|976358|1494156949|       13|430548_1007|     1|  0|        0.0|(2,[0],[1.0])|
|286630|1494218579|       13|430539_1007|     1|  0|        1.0|(2,[1],[1.0])|
|286630|1494289247|       13|430539_1007|     1|  0|        1.0|(2,[1],[1.0])|
|771431|1494153867|       13|430548_1007|     1|  0|        0.0|(2,[0],[1.0])|
|707120|1494220810|       13|430548_1007|     1|  0|        0.0|(2,[0],[1.0])|
|530454|1494293746|       13|430548_1007|     1|  0|        0.0|(2,[0],[1.0])|
+------+----------+---------+-----------+------+---+-----------+-------------+
only showing top 20 rows

           
  • 傳回字段pid_value是一個稀疏向量類型資料 pyspark.ml.linalg.SparseVector
from pyspark.ml.linalg import SparseVector
# 參數:次元、索引清單、值清單
# (4, [1, 3], [3.0, 4.0])) ->[0, 3.0, 0, 4.0]
print(SparseVector(4, [1, 3], [3.0, 4.0]))
print(SparseVector(4, [1, 3], [3.0, 4.0]).toArray())
print("*********")
print(new_df.select("pid_value").first())
print(new_df.select("pid_value").first().pid_value.toArray())
           

顯示結果:

(4,[1,3],[3.0,4.0])
[0. 3. 0. 4.]
*********
Row(pid_value=SparseVector(2, {0: 1.0}))
[1. 0.]

           
  • 檢視最大時間
+------+----------+---------+-----------+------+---+-----------+-------------+
|userId| timestamp|adgroupId|        pid|nonclk|clk|pid_feature|    pid_value|
+------+----------+---------+-----------+------+---+-----------+-------------+
|177002|1494691186|   593001|430548_1007|     1|  0|        0.0|(2,[0],[1.0])|
|243671|1494691186|   600195|430548_1007|     1|  0|        0.0|(2,[0],[1.0])|
|488527|1494691184|   494312|430548_1007|     1|  0|        0.0|(2,[0],[1.0])|
|488527|1494691184|   431082|430548_1007|     1|  0|        0.0|(2,[0],[1.0])|
| 17054|1494691184|   742741|430548_1007|     1|  0|        0.0|(2,[0],[1.0])|
| 17054|1494691184|   756665|430548_1007|     1|  0|        0.0|(2,[0],[1.0])|
|488527|1494691184|   687854|430548_1007|     1|  0|        0.0|(2,[0],[1.0])|
|839493|1494691183|   561681|430548_1007|     1|  0|        0.0|(2,[0],[1.0])|
|704223|1494691183|   624504|430539_1007|     1|  0|        1.0|(2,[1],[1.0])|
|839493|1494691183|   582235|430548_1007|     1|  0|        0.0|(2,[0],[1.0])|
|704223|1494691183|   675674|430539_1007|     1|  0|        1.0|(2,[1],[1.0])|
|628998|1494691180|   618965|430548_1007|     1|  0|        0.0|(2,[0],[1.0])|
|674444|1494691179|   427579|430548_1007|     1|  0|        0.0|(2,[0],[1.0])|
|627200|1494691179|   782038|430548_1007|     1|  0|        0.0|(2,[0],[1.0])|
|627200|1494691179|   420769|430548_1007|     1|  0|        0.0|(2,[0],[1.0])|
|674444|1494691179|   588664|430548_1007|     1|  0|        0.0|(2,[0],[1.0])|
|738335|1494691179|   451004|430539_1007|     1|  0|        1.0|(2,[1],[1.0])|
|627200|1494691179|   817569|430548_1007|     1|  0|        0.0|(2,[0],[1.0])|
|322244|1494691179|   820018|430548_1007|     1|  0|        0.0|(2,[0],[1.0])|
|322244|1494691179|   735220|430548_1007|     1|  0|        0.0|(2,[0],[1.0])|
+------+----------+---------+-----------+------+---+-----------+-------------+
only showing top 20 rows
           
# 本樣本資料集共計8天資料
# 前七天為訓練資料、最後一天為測試資料

from datetime import datetime
datetime.fromtimestamp(1494691186)
print("該時間之前的資料為訓練樣本,該時間以後的資料為測試樣本:", datetime.fromtimestamp(1494691186-24*60*60))
           

顯示結果:

該時間之前的資料為訓練樣本,該時間以後的資料為測試樣本: 2017-05-12 23:59:46

           
  • 訓練樣本
# 訓練樣本:
train_sample = raw_sample_df.filter(raw_sample_df.timestamp<=(1494691186-24*60*60))
print("訓練樣本個數:")
print(train_sample.count())
# 測試樣本
test_sample = raw_sample_df.filter(raw_sample_df.timestamp>(1494691186-24*60*60))
print("測試樣本個數:")
print(test_sample.count())

# 注意:還需要加入廣告基本特征和使用者基本特征才能做程一份完整的樣本資料集
           

顯示結果:

訓練樣本個數:
23249291
測試樣本個數:
3308670

           

3.2 分析并預處理ad_feature資料集

# 從HDFS中加載廣告基本資訊資料,傳回spark dafaframe對象
df = spark.read.csv("hdfs://localhost:9000/data/ad_feature.csv", header=True)
df.show()    # 展示資料,預設前20條
           

顯示結果:

+----------+-------+-----------+--------+------+-----+
|adgroup_id|cate_id|campaign_id|customer| brand|price|
+----------+-------+-----------+--------+------+-----+
|     63133|   6406|      83237|       1| 95471|170.0|
|    313401|   6406|      83237|       1| 87331|199.0|
|    248909|    392|      83237|       1| 32233| 38.0|
|    208458|    392|      83237|       1|174374|139.0|
|    110847|   7211|     135256|       2|145952|32.99|
|    607788|   6261|     387991|       6|207800|199.0|
|    375706|   4520|     387991|       6|  NULL| 99.0|
|     11115|   7213|     139747|       9|186847| 33.0|
|     24484|   7207|     139744|       9|186847| 19.0|
|     28589|   5953|     395195|      13|  NULL|428.0|
|     23236|   5953|     395195|      13|  NULL|368.0|
|    300556|   5953|     395195|      13|  NULL|639.0|
|     92560|   5953|     395195|      13|  NULL|368.0|
|    590965|   4284|      28145|      14|454237|249.0|
|    529913|   4284|      70206|      14|  NULL|249.0|
|    546930|   4284|      28145|      14|  NULL|249.0|
|    639794|   6261|      70206|      14| 37004| 89.9|
|    335413|   4284|      28145|      14|  NULL|249.0|
|    794890|   4284|      70206|      14|454237|249.0|
|    684020|   6261|      70206|      14| 37004| 99.0|
+----------+-------+-----------+--------+------+-----+
only showing top 20 rows
           
# 注意:由于本資料集中存在NULL字樣的資料,無法直接設定schema,隻能先将NULL類型的資料處理掉,然後進行類型轉換

from pyspark.sql.types import StructType, StructField, IntegerType, FloatType

# 替換掉NULL字元串,替換掉
df = df.replace("NULL", "-1")

# 列印df結構資訊
df.printSchema()   
# 更改df表結構:更改列類型和列名稱
ad_feature_df = df.\
    withColumn("adgroup_id", df.adgroup_id.cast(IntegerType())).withColumnRenamed("adgroup_id", "adgroupId").\
    withColumn("cate_id", df.cate_id.cast(IntegerType())).withColumnRenamed("cate_id", "cateId").\
    withColumn("campaign_id", df.campaign_id.cast(IntegerType())).withColumnRenamed("campaign_id", "campaignId").\
    withColumn("customer", df.customer.cast(IntegerType())).withColumnRenamed("customer", "customerId").\
    withColumn("brand", df.brand.cast(IntegerType())).withColumnRenamed("brand", "brandId").\
    withColumn("price", df.price.cast(FloatType()))
ad_feature_df.printSchema()
ad_feature_df.show()
           

顯示結果:

root
 |-- adgroup_id: string (nullable = true)
 |-- cate_id: string (nullable = true)
 |-- campaign_id: string (nullable = true)
 |-- customer: string (nullable = true)
 |-- brand: string (nullable = true)
 |-- price: string (nullable = true)

root
 |-- adgroupId: integer (nullable = true)
 |-- cateId: integer (nullable = true)
 |-- campaignId: integer (nullable = true)
 |-- customerId: integer (nullable = true)
 |-- brandId: integer (nullable = true)
 |-- price: float (nullable = true)

+---------+------+----------+----------+-------+-----+
|adgroupId|cateId|campaignId|customerId|brandId|price|
+---------+------+----------+----------+-------+-----+
|    63133|  6406|     83237|         1|  95471|170.0|
|   313401|  6406|     83237|         1|  87331|199.0|
|   248909|   392|     83237|         1|  32233| 38.0|
|   208458|   392|     83237|         1| 174374|139.0|
|   110847|  7211|    135256|         2| 145952|32.99|
|   607788|  6261|    387991|         6| 207800|199.0|
|   375706|  4520|    387991|         6|     -1| 99.0|
|    11115|  7213|    139747|         9| 186847| 33.0|
|    24484|  7207|    139744|         9| 186847| 19.0|
|    28589|  5953|    395195|        13|     -1|428.0|
|    23236|  5953|    395195|        13|     -1|368.0|
|   300556|  5953|    395195|        13|     -1|639.0|
|    92560|  5953|    395195|        13|     -1|368.0|
|   590965|  4284|     28145|        14| 454237|249.0|
|   529913|  4284|     70206|        14|     -1|249.0|
|   546930|  4284|     28145|        14|     -1|249.0|
|   639794|  6261|     70206|        14|  37004| 89.9|
|   335413|  4284|     28145|        14|     -1|249.0|
|   794890|  4284|     70206|        14| 454237|249.0|
|   684020|  6261|     70206|        14|  37004| 99.0|
+---------+------+----------+----------+-------+-----+
only showing top 20 rows
           
  • 檢視各項資料的特征
print("總廣告條數:",df.count())   # 資料條數
_1 = ad_feature_df.groupBy("cateId").count().count()
print("cateId數值個數:", _1)
_2 = ad_feature_df.groupBy("campaignId").count().count()
print("campaignId數值個數:", _2)
_3 = ad_feature_df.groupBy("customerId").count().count()
print("customerId數值個數:", _3)
_4 = ad_feature_df.groupBy("brandId").count().count()
print("brandId數值個數:", _4)
ad_feature_df.sort("price").show()
ad_feature_df.sort("price", ascending=False).show()
print("價格高于1w的條目個數:", ad_feature_df.select("price").filter("price>10000").count())
print("價格低于1的條目個數", ad_feature_df.select("price").filter("price<1").count())
           

顯示結果:

總廣告條數: 846811
cateId數值個數: 6769
campaignId數值個數: 423436
customerId數值個數: 255875
brandId數值個數: 99815
+---------+------+----------+----------+-------+-----+
|adgroupId|cateId|campaignId|customerId|brandId|price|
+---------+------+----------+----------+-------+-----+
|   485749|  9970|    352666|    140520|     -1| 0.01|
|    88975|  9996|    198424|    182415|     -1| 0.01|
|   109704| 10539|     59774|     90351| 202710| 0.01|
|    49911|  7032|    129079|    172334|     -1| 0.01|
|   339334|  9994|    310408|    211292| 383023| 0.01|
|     6636|  6703|    392038|     46239| 406713| 0.01|
|    92241|  6130|     72781|    149714|     -1| 0.01|
|    20397| 10539|    410958|     65726|  79971| 0.01|
|   345870|  9995|    179595|    191036|  79971| 0.01|
|    77797|  9086|    218276|     31183|     -1| 0.01|
|    14435|  1136|    135610|     17788|     -1| 0.01|
|    42055|  9994|     43866|    113068| 123242| 0.01|
|    41925|  7032|     85373|    114532|     -1| 0.01|
|    67558|  9995|     90141|     83948|     -1| 0.01|
|   149570|  7043|    126746|    176076|     -1| 0.01|
|   518883|  7185|    403318|     58013|     -1| 0.01|
|     2246|  9996|    413653|     60214| 182966| 0.01|
|   290675|  4824|    315371|    240984|     -1| 0.01|
|   552638| 10305|    403318|     58013|     -1| 0.01|
|    89831| 10539|     90141|     83948| 211816| 0.01|
+---------+------+----------+----------+-------+-----+
only showing top 20 rows

+---------+------+----------+----------+-------+-----------+
|adgroupId|cateId|campaignId|customerId|brandId|      price|
+---------+------+----------+----------+-------+-----------+
|   658722|  1093|    218101|    207754|     -1|      1.0E8|
|   468220|  1093|    270719|    207754|     -1|      1.0E8|
|   179746|  1093|    270027|    102509| 405447|      1.0E8|
|   443295|  1093|     44251|    102509| 300681|      1.0E8|
|    31899|   685|    218918|     31239| 278301|      1.0E8|
|   243384|   685|    218918|     31239| 278301|      1.0E8|
|   554311|  1093|    266086|    207754|     -1|      1.0E8|
|   513942|   745|      8401|     86243|     -1|8.8888888E7|
|   201060|   745|      8401|     86243|     -1|5.5555556E7|
|   289563|   685|     37665|    120847| 278301|      1.5E7|
|    35156|   527|    417722|     72273| 278301|      1.0E7|
|    33756|   527|    416333|     70894|     -1|  9900000.0|
|   335495|   739|    170121|    148946| 326126|  9600000.0|
|   218306|   206|    162394|      4339| 221720|  8888888.0|
|   213567|  7213|    239302|    205612| 406125|  5888888.0|
|   375920|   527|    217512|    148946| 326126|  4760000.0|
|   262215|   527|    132721|     11947| 417898|  3980000.0|
|   154623|   739|    170121|    148946| 326126|  3900000.0|
|   152414|   739|    170121|    148946| 326126|  3900000.0|
|   448651|   527|    422260|     41289| 209959|  3800000.0|
+---------+------+----------+----------+-------+-----------+
only showing top 20 rows

價格高于1w的條目個數: 6527
價格低于1的條目個數 5762

           
  • 特征選擇
    • cateId:脫敏過的商品類目ID;
    • campaignId:脫敏過的廣告計劃ID;
    • customerId:脫敏過的廣告主ID;
    • brandId:脫敏過的品牌ID;

    以上四個特征均屬于分類特征,但由于分類值個數均過于龐大,如果去做熱獨編碼處理,會導緻資料過于稀疏 且目前我們缺少對這些特征更加具體的資訊,(如商品類目具體資訊、品牌具體資訊等),進而無法對這些特征的資料做聚類、降維處理 是以這裡不選取它們作為特征

    而隻選取price作為特征資料,因為價格本身是一個統計類型連續數值型資料,且能很好的展現廣告的價值屬性特征,通常也不需要做其他處理(離散化、歸一化、标準化等),是以這裡直接将當做特征資料來使用

3.3 分析并預處理user_profile資料集

# 從HDFS加載使用者基本資訊資料
df = spark.read.csv("hdfs://localhost:9000/data/user_profile.csv", header=True)
# 發現pvalue_level和new_user_class_level存在空值:(注意此處的null表示空值,而如果是NULL,則往往表示是一個字元串)
# 是以直接利用schema就可以加載進該資料,無需替換null值
df.show()
           

顯示結果:

+------+---------+------------+-----------------+---------+------------+--------------+----------+---------------------+
|userid|cms_segid|cms_group_id|final_gender_code|age_level|pvalue_level|shopping_level|occupation|new_user_class_level |
+------+---------+------------+-----------------+---------+------------+--------------+----------+---------------------+
|   234|        0|           5|                2|        5|        null|             3|         0|                    3|
|   523|        5|           2|                2|        2|           1|             3|         1|                    2|
|   612|        0|           8|                1|        2|           2|             3|         0|                 null|
|  1670|        0|           4|                2|        4|        null|             1|         0|                 null|
|  2545|        0|          10|                1|        4|        null|             3|         0|                 null|
|  3644|       49|           6|                2|        6|           2|             3|         0|                    2|
|  5777|       44|           5|                2|        5|           2|             3|         0|                    2|
|  6211|        0|           9|                1|        3|        null|             3|         0|                    2|
|  6355|        2|           1|                2|        1|           1|             3|         0|                    4|
|  6823|       43|           5|                2|        5|           2|             3|         0|                    1|
|  6972|        5|           2|                2|        2|           2|             3|         1|                    2|
|  9293|        0|           5|                2|        5|        null|             3|         0|                    4|
|  9510|       55|           8|                1|        2|           2|             2|         0|                    2|
| 10122|       33|           4|                2|        4|           2|             3|         0|                    2|
| 10549|        0|           4|                2|        4|           2|             3|         0|                 null|
| 10812|        0|           4|                2|        4|        null|             2|         0|                 null|
| 10912|        0|           4|                2|        4|           2|             3|         0|                 null|
| 10996|        0|           5|                2|        5|        null|             3|         0|                    4|
| 11256|        8|           2|                2|        2|           1|             3|         0|                    3|
| 11310|       31|           4|                2|        4|           1|             3|         0|                    4|
+------+---------+------------+-----------------+---------+------------+--------------+----------+---------------------+
           
# 注意:這裡的null會直接被pyspark識别為None資料,也就是na資料,是以這裡可以直接利用schema導入資料

from pyspark.sql.types import StructType, StructField, StringType, IntegerType, LongType, FloatType

# 建構表結構schema對象
schema = StructType([
    StructField("userId", IntegerType()),
    StructField("cms_segid", IntegerType()),
    StructField("cms_group_id", IntegerType()),
    StructField("final_gender_code", IntegerType()),
    StructField("age_level", IntegerType()),
    StructField("pvalue_level", IntegerType()),
    StructField("shopping_level", IntegerType()),
    StructField("occupation", IntegerType()),
    StructField("new_user_class_level", IntegerType())
])
# 利用schema從hdfs加載
user_profile_df = spark.read.csv("hdfs://localhost:9000/data/user_profile.csv", header=True, schema=schema)
user_profile_df.printSchema()
user_profile_df.show()
           

顯示結果:

root
 |-- userId: integer (nullable = true)
 |-- cms_segid: integer (nullable = true)
 |-- cms_group_id: integer (nullable = true)
 |-- final_gender_code: integer (nullable = true)
 |-- age_level: integer (nullable = true)
 |-- pvalue_level: integer (nullable = true)
 |-- shopping_level: integer (nullable = true)
 |-- occupation: integer (nullable = true)
 |-- new_user_class_level: integer (nullable = true)

+------+---------+------------+-----------------+---------+------------+--------------+----------+--------------------+
|userId|cms_segid|cms_group_id|final_gender_code|age_level|pvalue_level|shopping_level|occupation|new_user_class_level|
+------+---------+------------+-----------------+---------+------------+--------------+----------+--------------------+
|   234|        0|           5|                2|        5|        null|             3|         0|                   3|
|   523|        5|           2|                2|        2|           1|             3|         1|                   2|
|   612|        0|           8|                1|        2|           2|             3|         0|                null|
|  1670|        0|           4|                2|        4|        null|             1|         0|                null|
|  2545|        0|          10|                1|        4|        null|             3|         0|                null|
|  3644|       49|           6|                2|        6|           2|             3|         0|                   2|
|  5777|       44|           5|                2|        5|           2|             3|         0|                   2|
|  6211|        0|           9|                1|        3|        null|             3|         0|                   2|
|  6355|        2|           1|                2|        1|           1|             3|         0|                   4|
|  6823|       43|           5|                2|        5|           2|             3|         0|                   1|
|  6972|        5|           2|                2|        2|           2|             3|         1|                   2|
|  9293|        0|           5|                2|        5|        null|             3|         0|                   4|
|  9510|       55|           8|                1|        2|           2|             2|         0|                   2|
| 10122|       33|           4|                2|        4|           2|             3|         0|                   2|
| 10549|        0|           4|                2|        4|           2|             3|         0|                null|
| 10812|        0|           4|                2|        4|        null|             2|         0|                null|
| 10912|        0|           4|                2|        4|           2|             3|         0|                null|
| 10996|        0|           5|                2|        5|        null|             3|         0|                   4|
| 11256|        8|           2|                2|        2|           1|             3|         0|                   3|
| 11310|       31|           4|                2|        4|           1|             3|         0|                   4|
+------+---------+------------+-----------------+---------+------------+--------------+----------+--------------------+
only showing top 20 rows
           
  • 顯示特征情況
print("分類特征值個數情況: ")
print("cms_segid: ", user_profile_df.groupBy("cms_segid").count().count())
print("cms_group_id: ", user_profile_df.groupBy("cms_group_id").count().count())
print("final_gender_code: ", user_profile_df.groupBy("final_gender_code").count().count())
print("age_level: ", user_profile_df.groupBy("age_level").count().count())
print("shopping_level: ", user_profile_df.groupBy("shopping_level").count().count())
print("occupation: ", user_profile_df.groupBy("occupation").count().count())

print("含缺失值的特征情況: ")
user_profile_df.groupBy("pvalue_level").count().show()
user_profile_df.groupBy("new_user_class_level").count().show()

t_count = user_profile_df.count()
pl_na_count = t_count - user_profile_df.dropna(subset=["pvalue_level"]).count()
print("pvalue_level的空值情況:", pl_na_count, "空值占比:%0.2f%%"%(pl_na_count/t_count*100))
nul_na_count = t_count - user_profile_df.dropna(subset=["new_user_class_level"]).count()
print("new_user_class_level的空值情況:", nul_na_count, "空值占比:%0.2f%%"%(nul_na_count/t_count*100))
           

顯示内容:

分類特征值個數情況: 
cms_segid:  97
cms_group_id:  13
final_gender_code:  2
age_level:  7
shopping_level:  3
occupation:  2
含缺失值的特征情況: 
+------------+------+
|pvalue_level| count|
+------------+------+
|        null|575917|
|           1|154436|
|           3| 37759|
|           2|293656|
+------------+------+

+--------------------+------+
|new_user_class_level| count|
+--------------------+------+
|                null|344920|
|                   1| 80548|
|                   3|173047|
|                   4|138833|
|                   2|324420|
+--------------------+------+

pvalue_level的空值情況: 575917 空值占比:54.24%
new_user_class_level的空值情況: 344920 空值占比:32.49%
           
  • 缺失值處理
    • 注意,一般情況下:
      • 缺失率低于10%:可直接進行相應的填充,如預設值、均值、算法拟合等等;
      • 高于10%:往往會考慮舍棄該特征
      • 特征處理,如1維轉多元
      但根據我們的經驗,我們的廣告推薦其實和使用者的消費水準、使用者所在城市等級都有比較大的關聯,是以在這裡pvalue_level、new_user_class_level都是比較重要的特征,我們不考慮舍棄
  • 缺失值處理方案:
    • 填充方案:結合使用者的其他特征值,利用随機森林算法進行預測;但産生了大量人為建構的資料,一定程度上增加了資料的噪音
    • 把變量映射到高維空間:如pvalue_level的1維資料,轉換成是否1、是否2、是否3、是否缺失的4維資料;這樣保證了所有原始資料不變,同時能提高精确度,但這樣會導緻資料變得比較稀疏,如果樣本量很小,反而會導緻樣本效果較差,是以也不能濫用
  • 填充方案
    • 利用随機森林對pvalue_level的缺失值進行預測
from pyspark.mllib.regression import LabeledPoint

# 剔除掉缺失值資料,将餘下的資料作為訓練資料
# user_profile_df.dropna(subset=["pvalue_level"]): 将pvalue_level中的空值所在行資料剔除後的資料,作為訓練樣本
train_data = user_profile_df.dropna(subset=["pvalue_level"]).rdd.map(
    lambda r:LabeledPoint(r.pvalue_level-1, [r.cms_segid, r.cms_group_id, r.final_gender_code, r.age_level, r.shopping_level, r.occupation])
)

# 注意随機森林輸入資料時,由于label的分類數是從0開始的,但pvalue_level的目前隻分别是1,2,3,是以需要對應分别-1來作為目标值
# 自然那麼最終得出預測值後,需要對應+1才能還原回來

# 我們使用cms_segid, cms_group_id, final_gender_code, age_level, shopping_level, occupation作為特征值,pvalue_level作為目标值
           
  • Labeled point

A labeled point is a local vector, either dense or sparse, associated with a label/response. In MLlib, labeled points are used in supervised learning algorithms. We use a double to store a label, so we can use labeled points in both regression and classification. For binary classification, a label should be either 0 (negative) or 1 (positive). For multiclass classification, labels should be class indices starting from zero: 0, 1, 2, ….

标記點是與标簽/響應相關聯的密集或稀疏的局部矢量。在MLlib中,标記點用于監督學習算法。我們使用double來存儲标簽,是以我們可以在回歸和分類中使用标記點。對于二分類情況,目标值應為0(負)或1(正)。對于多分類,标簽應該是從零開始的類索引:0, 1, 2, …。

Python

A labeled point is represented by LabeledPoint.

标記點表示為 LabeledPoint。

Refer to the LabeledPoint Python docs for more details on the API.

有關API的更多詳細資訊,請參閱LabeledPointPython文檔。

from pyspark.mllib.linalg import SparseVector
from pyspark.mllib.regression import LabeledPoint

# Create a labeled point with a positive label and a dense feature vector.
pos = LabeledPoint(1.0, [1.0, 0.0, 3.0])

# Create a labeled point with a negative label and a sparse feature vector.
neg = LabeledPoint(0.0, SparseVector(3, [0, 2], [1.0, 3.0]))
           
  • 随機森林:pyspark.mllib.tree.RandomForest
from pyspark.mllib.tree import RandomForest
# 訓練分類模型
# 參數1 訓練的資料
#參數2 目标值的分類個數 0,1,2
#參數3 特征中是否包含分類的特征 {2:2,3:7} {2:2} 表示 在特征中 第二個特征是分類的: 有兩個分類
#參數4 随機森林中 樹的棵數
model = RandomForest.trainClassifier(train_data, 3, {}, 5)
           
  • 随機森林模型:pyspark.mllib.tree.RandomForestModel
# 預測單個資料
# 注意用法:https://spark.apache.org/docs/latest/api/python/pyspark.mllib.html?highlight=tree%20random#pyspark.mllib.tree.RandomForestModel.predict
model.predict([0.0, 4.0 ,2.0 , 4.0, 1.0, 0.0])
           

顯示結果:

1.0
           
  • 篩選出缺失值條目
pl_na_df = user_profile_df.na.fill(-1).where("pvalue_level=-1")
pl_na_df.show(10)

def row(r):
    return r.cms_segid, r.cms_group_id, r.final_gender_code, r.age_level, r.shopping_level, r.occupation

# 轉換為普通的rdd類型
rdd = pl_na_df.rdd.map(row)
# 預測全部的pvalue_level值:
predicts = model.predict(rdd)
# 檢視前20條
print(predicts.take(20))
print("預測值總數", predicts.count())

# 這裡注意predict參數,如果是預測多個,那麼參數必須是直接有清單構成的rdd參數,而不能是dataframe.rdd類型
# 是以這裡經過map函數處理,将每一行資料轉換為普通的清單資料
           

顯示結果:

+------+---------+------------+-----------------+---------+------------+--------------+----------+--------------------+
|userId|cms_segid|cms_group_id|final_gender_code|age_level|pvalue_level|shopping_level|occupation|new_user_class_level|
+------+---------+------------+-----------------+---------+------------+--------------+----------+--------------------+
|   234|        0|           5|                2|        5|          -1|             3|         0|                   3|
|  1670|        0|           4|                2|        4|          -1|             1|         0|                  -1|
|  2545|        0|          10|                1|        4|          -1|             3|         0|                  -1|
|  6211|        0|           9|                1|        3|          -1|             3|         0|                   2|
|  9293|        0|           5|                2|        5|          -1|             3|         0|                   4|
| 10812|        0|           4|                2|        4|          -1|             2|         0|                  -1|
| 10996|        0|           5|                2|        5|          -1|             3|         0|                   4|
| 11602|        0|           5|                2|        5|          -1|             3|         0|                   2|
| 11727|        0|           3|                2|        3|          -1|             3|         0|                   1|
| 12195|        0|          10|                1|        4|          -1|             3|         0|                   2|
+------+---------+------------+-----------------+---------+------------+--------------+----------+--------------------+
only showing top 10 rows

[1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0, 1.0, 0.0, 1.0, 1.0, 1.0]
預測值總數 575917
           
  • 轉換為pandas dataframe
# 這裡資料量比較小,直接轉換為pandas dataframe來處理,因為友善,但注意如果資料量較大不推薦,因為這樣會把全部資料加載到記憶體中
temp = predicts.map(lambda x:int(x)).collect()
pdf = pl_na_df.toPandas()
import numpy as np
 # 在pandas df的基礎上直接替換掉列資料
pdf["pvalue_level"] = np.array(temp) + 1  # 注意+1 還原預測值
pdf
           
  • 與非缺失資料進行拼接,完成pvalue_level的缺失值預測
new_user_profile_df = user_profile_df.dropna(subset=["pvalue_level"]).unionAll(spark.createDataFrame(pdf, schema=schema))
new_user_profile_df.show()

# 注意:unionAll的使用,兩個df的表結構必須完全一樣

           

顯示結果:

+------+---------+------------+-----------------+---------+------------+--------------+----------+--------------------+
|userId|cms_segid|cms_group_id|final_gender_code|age_level|pvalue_level|shopping_level|occupation|new_user_class_level|
+------+---------+------------+-----------------+---------+------------+--------------+----------+--------------------+
|   523|        5|           2|                2|        2|           1|             3|         1|                   2|
|   612|        0|           8|                1|        2|           2|             3|         0|                null|
|  3644|       49|           6|                2|        6|           2|             3|         0|                   2|
|  5777|       44|           5|                2|        5|           2|             3|         0|                   2|
|  6355|        2|           1|                2|        1|           1|             3|         0|                   4|
|  6823|       43|           5|                2|        5|           2|             3|         0|                   1|
|  6972|        5|           2|                2|        2|           2|             3|         1|                   2|
|  9510|       55|           8|                1|        2|           2|             2|         0|                   2|
| 10122|       33|           4|                2|        4|           2|             3|         0|                   2|
| 10549|        0|           4|                2|        4|           2|             3|         0|                null|
| 10912|        0|           4|                2|        4|           2|             3|         0|                null|
| 11256|        8|           2|                2|        2|           1|             3|         0|                   3|
| 11310|       31|           4|                2|        4|           1|             3|         0|                   4|
| 11739|       20|           3|                2|        3|           2|             3|         0|                   4|
| 12549|       33|           4|                2|        4|           2|             3|         0|                   2|
| 15155|       36|           5|                2|        5|           2|             1|         0|                null|
| 15347|       20|           3|                2|        3|           2|             3|         0|                   3|
| 15455|        8|           2|                2|        2|           2|             3|         0|                   3|
| 15783|        0|           4|                2|        4|           2|             3|         0|                null|
| 16749|        5|           2|                2|        2|           1|             3|         1|                   4|
+------+---------+------------+-----------------+---------+------------+--------------+----------+--------------------+
only showing top 20 rows
           
  • 利用随機森林對new_user_class_level的缺失值進行預測
from pyspark.mllib.regression import LabeledPoint

# 選出new_user_class_level全部的
train_data2 = user_profile_df.dropna(subset=["new_user_class_level"]).rdd.map(
    lambda r:LabeledPoint(r.new_user_class_level - 1, [r.cms_segid, r.cms_group_id, r.final_gender_code, r.age_level, r.shopping_level, r.occupation])
)
from pyspark.mllib.tree import RandomForest
model2 = RandomForest.trainClassifier(train_data2, 4, {}, 5)
model2.predict([0.0, 4.0 ,2.0 , 4.0, 1.0, 0.0])
# 預測值實際應該為2
           

顯示結果:

1.0
           
nul_na_df = user_profile_df.na.fill(-1).where("new_user_class_level=-1")
nul_na_df.show(10)

def row(r):
    return r.cms_segid, r.cms_group_id, r.final_gender_code, r.age_level, r.shopping_level, r.occupation

rdd2 = nul_na_df.rdd.map(row)
predicts2 = model.predict(rdd2)
predicts2.take(20)
           
  • 顯示結果:
+------+---------+------------+-----------------+---------+------------+--------------+----------+--------------------+
|userId|cms_segid|cms_group_id|final_gender_code|age_level|pvalue_level|shopping_level|occupation|new_user_class_level|
+------+---------+------------+-----------------+---------+------------+--------------+----------+--------------------+
|   612|        0|           8|                1|        2|           2|             3|         0|                  -1|
|  1670|        0|           4|                2|        4|          -1|             1|         0|                  -1|
|  2545|        0|          10|                1|        4|          -1|             3|         0|                  -1|
| 10549|        0|           4|                2|        4|           2|             3|         0|                  -1|
| 10812|        0|           4|                2|        4|          -1|             2|         0|                  -1|
| 10912|        0|           4|                2|        4|           2|             3|         0|                  -1|
| 12620|        0|           4|                2|        4|          -1|             2|         0|                  -1|
| 14437|        0|           5|                2|        5|          -1|             3|         0|                  -1|
| 14574|        0|           1|                2|        1|          -1|             2|         0|                  -1|
| 14985|        0|          11|                1|        5|          -1|             2|         0|                  -1|
+------+---------+------------+-----------------+---------+------------+--------------+----------+--------------------+
only showing top 10 rows

[1.0,
 1.0,
 1.0,
 1.0,
 1.0,
 1.0,
 1.0,
 1.0,
 0.0,
 1.0,
 1.0,
 1.0,
 1.0,
 1.0,
 1.0,
 0.0,
 1.0,
 0.0,
 0.0,
 1.0]

           
  • 總結:可以發現由于這兩個字段的缺失過多,是以預測出來的值已經大大失真,但如果缺失率在10%以下,這種方法是比較有效的一種
user_profile_df = user_profile_df.na.fill(-1)
user_profile_df.show()
# new_df = new_df.withColumn("pvalue_level", new_df.pvalue_level.cast(StringType()))\
#     .withColumn("new_user_class_level", new_df.new_user_class_level.cast(StringType()))
           

顯示結果:

+------+---------+------------+-----------------+---------+------------+--------------+----------+--------------------+
|userId|cms_segid|cms_group_id|final_gender_code|age_level|pvalue_level|shopping_level|occupation|new_user_class_level|
+------+---------+------------+-----------------+---------+------------+--------------+----------+--------------------+
|   234|        0|           5|                2|        5|          -1|             3|         0|                   3|
|   523|        5|           2|                2|        2|           1|             3|         1|                   2|
|   612|        0|           8|                1|        2|           2|             3|         0|                  -1|
|  1670|        0|           4|                2|        4|          -1|             1|         0|                  -1|
|  2545|        0|          10|                1|        4|          -1|             3|         0|                  -1|
|  3644|       49|           6|                2|        6|           2|             3|         0|                   2|
|  5777|       44|           5|                2|        5|           2|             3|         0|                   2|
|  6211|        0|           9|                1|        3|          -1|             3|         0|                   2|
|  6355|        2|           1|                2|        1|           1|             3|         0|                   4|
|  6823|       43|           5|                2|        5|           2|             3|         0|                   1|
|  6972|        5|           2|                2|        2|           2|             3|         1|                   2|
|  9293|        0|           5|                2|        5|          -1|             3|         0|                   4|
|  9510|       55|           8|                1|        2|           2|             2|         0|                   2|
| 10122|       33|           4|                2|        4|           2|             3|         0|                   2|
| 10549|        0|           4|                2|        4|           2|             3|         0|                  -1|
| 10812|        0|           4|                2|        4|          -1|             2|         0|                  -1|
| 10912|        0|           4|                2|        4|           2|             3|         0|                  -1|
| 10996|        0|           5|                2|        5|          -1|             3|         0|                   4|
| 11256|        8|           2|                2|        2|           1|             3|         0|                   3|
| 11310|       31|           4|                2|        4|           1|             3|         0|                   4|
+------+---------+------------+-----------------+---------+------------+--------------+----------+--------------------+
only showing top 20 rows
           
  • 低維轉高維方式
    • 我們接下來采用将變量映射到高維空間的方法來處理資料,即将缺失項也當做一個單獨的特征來對待,保證資料的原始性

      由于該思想正好和熱獨編碼實作方法一樣,是以這裡直接使用熱獨編碼方式處理資料

from pyspark.ml.feature import OneHotEncoder
from pyspark.ml.feature import StringIndexer
from pyspark.ml import Pipeline

# 使用熱獨編碼轉換pvalue_level的一維資料為多元,其中缺失值單獨作為一個特征值

# 需要先将缺失值全部替換為數值,與原有特征一起處理
from pyspark.sql.types import StringType
user_profile_df = user_profile_df.na.fill(-1)
user_profile_df.show()

# 熱獨編碼時,必須先将待處理字段轉為字元串類型才可處理
user_profile_df = user_profile_df.withColumn("pvalue_level", user_profile_df.pvalue_level.cast(StringType()))\
    .withColumn("new_user_class_level", user_profile_df.new_user_class_level.cast(StringType()))
user_profile_df.printSchema()

# 對pvalue_level進行熱獨編碼,求值
stringindexer = StringIndexer(inputCol='pvalue_level', outputCol='pl_onehot_feature')
encoder = OneHotEncoder(dropLast=False, inputCol='pl_onehot_feature', outputCol='pl_onehot_value')
pipeline = Pipeline(stages=[stringindexer, encoder])
pipeline_fit = pipeline.fit(user_profile_df)
user_profile_df2 = pipeline_fit.transform(user_profile_df)
# pl_onehot_value列的值為稀疏向量,存儲熱獨編碼的結果
user_profile_df2.printSchema()
user_profile_df2.show()
           

顯示結果:

+------+---------+------------+-----------------+---------+------------+--------------+----------+--------------------+
|userId|cms_segid|cms_group_id|final_gender_code|age_level|pvalue_level|shopping_level|occupation|new_user_class_level|
+------+---------+------------+-----------------+---------+------------+--------------+----------+--------------------+
|   234|        0|           5|                2|        5|          -1|             3|         0|                   3|
|   523|        5|           2|                2|        2|           1|             3|         1|                   2|
|   612|        0|           8|                1|        2|           2|             3|         0|                  -1|
|  1670|        0|           4|                2|        4|          -1|             1|         0|                  -1|
|  2545|        0|          10|                1|        4|          -1|             3|         0|                  -1|
|  3644|       49|           6|                2|        6|           2|             3|         0|                   2|
|  5777|       44|           5|                2|        5|           2|             3|         0|                   2|
|  6211|        0|           9|                1|        3|          -1|             3|         0|                   2|
|  6355|        2|           1|                2|        1|           1|             3|         0|                   4|
|  6823|       43|           5|                2|        5|           2|             3|         0|                   1|
|  6972|        5|           2|                2|        2|           2|             3|         1|                   2|
|  9293|        0|           5|                2|        5|          -1|             3|         0|                   4|
|  9510|       55|           8|                1|        2|           2|             2|         0|                   2|
| 10122|       33|           4|                2|        4|           2|             3|         0|                   2|
| 10549|        0|           4|                2|        4|           2|             3|         0|                  -1|
| 10812|        0|           4|                2|        4|          -1|             2|         0|                  -1|
| 10912|        0|           4|                2|        4|           2|             3|         0|                  -1|
| 10996|        0|           5|                2|        5|          -1|             3|         0|                   4|
| 11256|        8|           2|                2|        2|           1|             3|         0|                   3|
| 11310|       31|           4|                2|        4|           1|             3|         0|                   4|
+------+---------+------------+-----------------+---------+------------+--------------+----------+--------------------+
only showing top 20 rows

root
 |-- userId: integer (nullable = true)
 |-- cms_segid: integer (nullable = true)
 |-- cms_group_id: integer (nullable = true)
 |-- final_gender_code: integer (nullable = true)
 |-- age_level: integer (nullable = true)
 |-- pvalue_level: string (nullable = true)
 |-- shopping_level: integer (nullable = true)
 |-- occupation: integer (nullable = true)
 |-- new_user_class_level: string (nullable = true)

root
 |-- userId: integer (nullable = true)
 |-- cms_segid: integer (nullable = true)
 |-- cms_group_id: integer (nullable = true)
 |-- final_gender_code: integer (nullable = true)
 |-- age_level: integer (nullable = true)
 |-- pvalue_level: string (nullable = true)
 |-- shopping_level: integer (nullable = true)
 |-- occupation: integer (nullable = true)
 |-- new_user_class_level: string (nullable = true)
 |-- pl_onehot_feature: double (nullable = false)
 |-- pl_onehot_value: vector (nullable = true)

+------+---------+------------+-----------------+---------+------------+--------------+----------+--------------------+-----------------+---------------+
|userId|cms_segid|cms_group_id|final_gender_code|age_level|pvalue_level|shopping_level|occupation|new_user_class_level|pl_onehot_feature|pl_onehot_value|
+------+---------+------------+-----------------+---------+------------+--------------+----------+--------------------+-----------------+---------------+
|   234|        0|           5|                2|        5|          -1|             3|         0|                   3|              0.0|  (4,[0],[1.0])|
|   523|        5|           2|                2|        2|           1|             3|         1|                   2|              2.0|  (4,[2],[1.0])|
|   612|        0|           8|                1|        2|           2|             3|         0|                  -1|              1.0|  (4,[1],[1.0])|
|  1670|        0|           4|                2|        4|          -1|             1|         0|                  -1|              0.0|  (4,[0],[1.0])|
|  2545|        0|          10|                1|        4|          -1|             3|         0|                  -1|              0.0|  (4,[0],[1.0])|
|  3644|       49|           6|                2|        6|           2|             3|         0|                   2|              1.0|  (4,[1],[1.0])|
|  5777|       44|           5|                2|        5|           2|             3|         0|                   2|              1.0|  (4,[1],[1.0])|
|  6211|        0|           9|                1|        3|          -1|             3|         0|                   2|              0.0|  (4,[0],[1.0])|
|  6355|        2|           1|                2|        1|           1|             3|         0|                   4|              2.0|  (4,[2],[1.0])|
|  6823|       43|           5|                2|        5|           2|             3|         0|                   1|              1.0|  (4,[1],[1.0])|
|  6972|        5|           2|                2|        2|           2|             3|         1|                   2|              1.0|  (4,[1],[1.0])|
|  9293|        0|           5|                2|        5|          -1|             3|         0|                   4|              0.0|  (4,[0],[1.0])|
|  9510|       55|           8|                1|        2|           2|             2|         0|                   2|              1.0|  (4,[1],[1.0])|
| 10122|       33|           4|                2|        4|           2|             3|         0|                   2|              1.0|  (4,[1],[1.0])|
| 10549|        0|           4|                2|        4|           2|             3|         0|                  -1|              1.0|  (4,[1],[1.0])|
| 10812|        0|           4|                2|        4|          -1|             2|         0|                  -1|              0.0|  (4,[0],[1.0])|
| 10912|        0|           4|                2|        4|           2|             3|         0|                  -1|              1.0|  (4,[1],[1.0])|
| 10996|        0|           5|                2|        5|          -1|             3|         0|                   4|              0.0|  (4,[0],[1.0])|
| 11256|        8|           2|                2|        2|           1|             3|         0|                   3|              2.0|  (4,[2],[1.0])|
| 11310|       31|           4|                2|        4|           1|             3|         0|                   4|              2.0|  (4,[2],[1.0])|
+------+---------+------------+-----------------+---------+------------+--------------+----------+--------------------+-----------------+---------------+
only showing top 20 rows

           
  • 使用熱編碼轉換new_user_class_level的一維資料為多元
stringindexer = StringIndexer(inputCol='new_user_class_level', outputCol='nucl_onehot_feature')
encoder = OneHotEncoder(dropLast=False, inputCol='nucl_onehot_feature', outputCol='nucl_onehot_value')
pipeline = Pipeline(stages=[stringindexer, encoder])
pipeline_fit = pipeline.fit(user_profile_df2)
user_profile_df3 = pipeline_fit.transform(user_profile_df2)
user_profile_df3.show()
           

顯示結果:

+------+---------+------------+-----------------+---------+------------+--------------+----------+--------------------+-----------------+---------------+-------------------+-----------------+
|userId|cms_segid|cms_group_id|final_gender_code|age_level|pvalue_level|shopping_level|occupation|new_user_class_level|pl_onehot_feature|pl_onehot_value|nucl_onehot_feature|nucl_onehot_value|
+------+---------+------------+-----------------+---------+------------+--------------+----------+--------------------+-----------------+---------------+-------------------+-----------------+
|   234|        0|           5|                2|        5|          -1|             3|         0|                   3|              0.0|  (4,[0],[1.0])|                2.0|    (5,[2],[1.0])|
|   523|        5|           2|                2|        2|           1|             3|         1|                   2|              2.0|  (4,[2],[1.0])|                1.0|    (5,[1],[1.0])|
|   612|        0|           8|                1|        2|           2|             3|         0|                  -1|              1.0|  (4,[1],[1.0])|                0.0|    (5,[0],[1.0])|
|  1670|        0|           4|                2|        4|          -1|             1|         0|                  -1|              0.0|  (4,[0],[1.0])|                0.0|    (5,[0],[1.0])|
|  2545|        0|          10|                1|        4|          -1|             3|         0|                  -1|              0.0|  (4,[0],[1.0])|                0.0|    (5,[0],[1.0])|
|  3644|       49|           6|                2|        6|           2|             3|         0|                   2|              1.0|  (4,[1],[1.0])|                1.0|    (5,[1],[1.0])|
|  5777|       44|           5|                2|        5|           2|             3|         0|                   2|              1.0|  (4,[1],[1.0])|                1.0|    (5,[1],[1.0])|
|  6211|        0|           9|                1|        3|          -1|             3|         0|                   2|              0.0|  (4,[0],[1.0])|                1.0|    (5,[1],[1.0])|
|  6355|        2|           1|                2|        1|           1|             3|         0|                   4|              2.0|  (4,[2],[1.0])|                3.0|    (5,[3],[1.0])|
|  6823|       43|           5|                2|        5|           2|             3|         0|                   1|              1.0|  (4,[1],[1.0])|                4.0|    (5,[4],[1.0])|
|  6972|        5|           2|                2|        2|           2|             3|         1|                   2|              1.0|  (4,[1],[1.0])|                1.0|    (5,[1],[1.0])|
|  9293|        0|           5|                2|        5|          -1|             3|         0|                   4|              0.0|  (4,[0],[1.0])|                3.0|    (5,[3],[1.0])|
|  9510|       55|           8|                1|        2|           2|             2|         0|                   2|              1.0|  (4,[1],[1.0])|                1.0|    (5,[1],[1.0])|
| 10122|       33|           4|                2|        4|           2|             3|         0|                   2|              1.0|  (4,[1],[1.0])|                1.0|    (5,[1],[1.0])|
| 10549|        0|           4|                2|        4|           2|             3|         0|                  -1|              1.0|  (4,[1],[1.0])|                0.0|    (5,[0],[1.0])|
| 10812|        0|           4|                2|        4|          -1|             2|         0|                  -1|              0.0|  (4,[0],[1.0])|                0.0|    (5,[0],[1.0])|
| 10912|        0|           4|                2|        4|           2|             3|         0|                  -1|              1.0|  (4,[1],[1.0])|                0.0|    (5,[0],[1.0])|
| 10996|        0|           5|                2|        5|          -1|             3|         0|                   4|              0.0|  (4,[0],[1.0])|                3.0|    (5,[3],[1.0])|
| 11256|        8|           2|                2|        2|           1|             3|         0|                   3|              2.0|  (4,[2],[1.0])|                2.0|    (5,[2],[1.0])|
| 11310|       31|           4|                2|        4|           1|             3|         0|                   4|              2.0|  (4,[2],[1.0])|                3.0|    (5,[3],[1.0])|
+------+---------+------------+-----------------+---------+------------+--------------+----------+--------------------+-----------------+---------------+-------------------+-----------------+
only showing top 20 rows
           
  • 使用者特征合并
from pyspark.ml.feature import VectorAssembler
feature_df = VectorAssembler().setInputCols(["age_level", "pl_onehot_value", "nucl_onehot_value"]).setOutputCol("features").transform(user_profile_df3)
feature_df.show()
           

顯示結果:

+------+---------+------------+-----------------+---------+------------+--------------+----------+--------------------+-----------------+---------------+-------------------+-----------------+--------------------+
|userId|cms_segid|cms_group_id|final_gender_code|age_level|pvalue_level|shopping_level|occupation|new_user_class_level|pl_onehot_feature|pl_onehot_value|nucl_onehot_feature|nucl_onehot_value|            features|
+------+---------+------------+-----------------+---------+------------+--------------+----------+--------------------+-----------------+---------------+-------------------+-----------------+--------------------+
|   234|        0|           5|                2|        5|          -1|             3|         0|                   3|              0.0|  (4,[0],[1.0])|                2.0|    (5,[2],[1.0])|(10,[0,1,7],[5.0,...|
|   523|        5|           2|                2|        2|           1|             3|         1|                   2|              2.0|  (4,[2],[1.0])|                1.0|    (5,[1],[1.0])|(10,[0,3,6],[2.0,...|
|   612|        0|           8|                1|        2|           2|             3|         0|                  -1|              1.0|  (4,[1],[1.0])|                0.0|    (5,[0],[1.0])|(10,[0,2,5],[2.0,...|
|  1670|        0|           4|                2|        4|          -1|             1|         0|                  -1|              0.0|  (4,[0],[1.0])|                0.0|    (5,[0],[1.0])|(10,[0,1,5],[4.0,...|
|  2545|        0|          10|                1|        4|          -1|             3|         0|                  -1|              0.0|  (4,[0],[1.0])|                0.0|    (5,[0],[1.0])|(10,[0,1,5],[4.0,...|
|  3644|       49|           6|                2|        6|           2|             3|         0|                   2|              1.0|  (4,[1],[1.0])|                1.0|    (5,[1],[1.0])|(10,[0,2,6],[6.0,...|
|  5777|       44|           5|                2|        5|           2|             3|         0|                   2|              1.0|  (4,[1],[1.0])|                1.0|    (5,[1],[1.0])|(10,[0,2,6],[5.0,...|
|  6211|        0|           9|                1|        3|          -1|             3|         0|                   2|              0.0|  (4,[0],[1.0])|                1.0|    (5,[1],[1.0])|(10,[0,1,6],[3.0,...|
|  6355|        2|           1|                2|        1|           1|             3|         0|                   4|              2.0|  (4,[2],[1.0])|                3.0|    (5,[3],[1.0])|(10,[0,3,8],[1.0,...|
|  6823|       43|           5|                2|        5|           2|             3|         0|                   1|              1.0|  (4,[1],[1.0])|                4.0|    (5,[4],[1.0])|(10,[0,2,9],[5.0,...|
|  6972|        5|           2|                2|        2|           2|             3|         1|                   2|              1.0|  (4,[1],[1.0])|                1.0|    (5,[1],[1.0])|(10,[0,2,6],[2.0,...|
|  9293|        0|           5|                2|        5|          -1|             3|         0|                   4|              0.0|  (4,[0],[1.0])|                3.0|    (5,[3],[1.0])|(10,[0,1,8],[5.0,...|
|  9510|       55|           8|                1|        2|           2|             2|         0|                   2|              1.0|  (4,[1],[1.0])|                1.0|    (5,[1],[1.0])|(10,[0,2,6],[2.0,...|
| 10122|       33|           4|                2|        4|           2|             3|         0|                   2|              1.0|  (4,[1],[1.0])|                1.0|    (5,[1],[1.0])|(10,[0,2,6],[4.0,...|
| 10549|        0|           4|                2|        4|           2|             3|         0|                  -1|              1.0|  (4,[1],[1.0])|                0.0|    (5,[0],[1.0])|(10,[0,2,5],[4.0,...|
| 10812|        0|           4|                2|        4|          -1|             2|         0|                  -1|              0.0|  (4,[0],[1.0])|                0.0|    (5,[0],[1.0])|(10,[0,1,5],[4.0,...|
| 10912|        0|           4|                2|        4|           2|             3|         0|                  -1|              1.0|  (4,[1],[1.0])|                0.0|    (5,[0],[1.0])|(10,[0,2,5],[4.0,...|
| 10996|        0|           5|                2|        5|          -1|             3|         0|                   4|              0.0|  (4,[0],[1.0])|                3.0|    (5,[3],[1.0])|(10,[0,1,8],[5.0,...|
| 11256|        8|           2|                2|        2|           1|             3|         0|                   3|              2.0|  (4,[2],[1.0])|                2.0|    (5,[2],[1.0])|(10,[0,3,7],[2.0,...|
| 11310|       31|           4|                2|        4|           1|             3|         0|                   4|              2.0|  (4,[2],[1.0])|                3.0|    (5,[3],[1.0])|(10,[0,3,8],[4.0,...|
+------+---------+------------+-----------------+---------+------------+--------------+----------+--------------------+-----------------+---------------+-------------------+-----------------+--------------------+
only showing top 20 rows
           

顯示結果:

+--------------------+
|            features|
+--------------------+
|(10,[0,1,7],[5.0,...|
|(10,[0,3,6],[2.0,...|
|(10,[0,2,5],[2.0,...|
|(10,[0,1,5],[4.0,...|
|(10,[0,1,5],[4.0,...|
|(10,[0,2,6],[6.0,...|
|(10,[0,2,6],[5.0,...|
|(10,[0,1,6],[3.0,...|
|(10,[0,3,8],[1.0,...|
|(10,[0,2,9],[5.0,...|
|(10,[0,2,6],[2.0,...|
|(10,[0,1,8],[5.0,...|
|(10,[0,2,6],[2.0,...|
|(10,[0,2,6],[4.0,...|
|(10,[0,2,5],[4.0,...|
|(10,[0,1,5],[4.0,...|
|(10,[0,2,5],[4.0,...|
|(10,[0,1,8],[5.0,...|
|(10,[0,3,7],[2.0,...|
|(10,[0,3,8],[4.0,...|
+--------------------+
only showing top 20 rows
           
  • 特征選取

除了前面處理的pvalue_level和new_user_class_level需要作為特征以外,(能展現出使用者的購買力特征),還有:

前面分析的以下幾個分類特征值個數情況:

- cms_segid:  97
- cms_group_id:  13
- final_gender_code:  2
- age_level:  7
- shopping_level:  3
- occupation:  2
-pvalue_level
-new_user_class_level
-price
           

根據經驗,以上幾個分類特征都一定程度能展現使用者在購物方面的特征,且類别都較少,都可以用來作為使用者特征

四、 LR實作CTR預估

4.1 Spark邏輯回歸(LR)訓練點選率預測模型

  • 本小節主要根據廣告點選樣本資料集(raw_sample)、廣告基本特征資料集(ad_feature)、使用者基本資訊資料集(user_profile)建構出了一個完整的樣本資料集,并按日期劃分為了訓練集(前七天)和測試集(最後一天),利用邏輯回歸進行訓練。

    訓練模型時,通過對類别特征資料進行處理,一定程度達到提高了模型的效果

'''從HDFS中加載樣本資料資訊'''
_raw_sample_df1 = spark.read.csv("hdfs://localhost:9000/data/raw_sample.csv", header=True)
# _raw_sample_df1.show()    # 展示資料,預設前20條
# 更改表結構,轉換為對應的資料類型
from pyspark.sql.types import StructType, StructField, IntegerType, FloatType, LongType, StringType
  
# 更改df表結構:更改列類型和列名稱
_raw_sample_df2 = _raw_sample_df1.\
    withColumn("user", _raw_sample_df1.user.cast(IntegerType())).withColumnRenamed("user", "userId").\
    withColumn("time_stamp", _raw_sample_df1.time_stamp.cast(LongType())).withColumnRenamed("time_stamp", "timestamp").\
    withColumn("adgroup_id", _raw_sample_df1.adgroup_id.cast(IntegerType())).withColumnRenamed("adgroup_id", "adgroupId").\
    withColumn("pid", _raw_sample_df1.pid.cast(StringType())).\
    withColumn("nonclk", _raw_sample_df1.nonclk.cast(IntegerType())).\
    withColumn("clk", _raw_sample_df1.clk.cast(IntegerType()))
_raw_sample_df2.printSchema()
_raw_sample_df2.show()

# 樣本資料pid特征處理
from pyspark.ml.feature import OneHotEncoder
from pyspark.ml.feature import StringIndexer
from pyspark.ml import Pipeline

stringindexer = StringIndexer(inputCol='pid', outputCol='pid_feature')
encoder = OneHotEncoder(dropLast=False, inputCol='pid_feature', outputCol='pid_value')
pipeline = Pipeline(stages=[stringindexer, encoder])
pipeline_fit = pipeline.fit(_raw_sample_df2)
raw_sample_df = pipeline_fit.transform(_raw_sample_df2)
raw_sample_df.show()

'''pid和特征的對應關系
430548_1007:0
430549_1007:1
'''
           

顯示結果:

root
 |-- userId: integer (nullable = true)
 |-- timestamp: long (nullable = true)
 |-- adgroupId: integer (nullable = true)
 |-- pid: string (nullable = true)
 |-- nonclk: integer (nullable = true)
 |-- clk: integer (nullable = true)

+------+----------+---------+-----------+------+---+
|userId| timestamp|adgroupId|        pid|nonclk|clk|
+------+----------+---------+-----------+------+---+
|581738|1494137644|        1|430548_1007|     1|  0|
|449818|1494638778|        3|430548_1007|     1|  0|
|914836|1494650879|        4|430548_1007|     1|  0|
|914836|1494651029|        5|430548_1007|     1|  0|
|399907|1494302958|        8|430548_1007|     1|  0|
|628137|1494524935|        9|430548_1007|     1|  0|
|298139|1494462593|        9|430539_1007|     1|  0|
|775475|1494561036|        9|430548_1007|     1|  0|
|555266|1494307136|       11|430539_1007|     1|  0|
|117840|1494036743|       11|430548_1007|     1|  0|
|739815|1494115387|       11|430539_1007|     1|  0|
|623911|1494625301|       11|430548_1007|     1|  0|
|623911|1494451608|       11|430548_1007|     1|  0|
|421590|1494034144|       11|430548_1007|     1|  0|
|976358|1494156949|       13|430548_1007|     1|  0|
|286630|1494218579|       13|430539_1007|     1|  0|
|286630|1494289247|       13|430539_1007|     1|  0|
|771431|1494153867|       13|430548_1007|     1|  0|
|707120|1494220810|       13|430548_1007|     1|  0|
|530454|1494293746|       13|430548_1007|     1|  0|
+------+----------+---------+-----------+------+---+
only showing top 20 rows

+------+----------+---------+-----------+------+---+-----------+-------------+
|userId| timestamp|adgroupId|        pid|nonclk|clk|pid_feature|    pid_value|
+------+----------+---------+-----------+------+---+-----------+-------------+
|581738|1494137644|        1|430548_1007|     1|  0|        0.0|(2,[0],[1.0])|
|449818|1494638778|        3|430548_1007|     1|  0|        0.0|(2,[0],[1.0])|
|914836|1494650879|        4|430548_1007|     1|  0|        0.0|(2,[0],[1.0])|
|914836|1494651029|        5|430548_1007|     1|  0|        0.0|(2,[0],[1.0])|
|399907|1494302958|        8|430548_1007|     1|  0|        0.0|(2,[0],[1.0])|
|628137|1494524935|        9|430548_1007|     1|  0|        0.0|(2,[0],[1.0])|
|298139|1494462593|        9|430539_1007|     1|  0|        1.0|(2,[1],[1.0])|
|775475|1494561036|        9|430548_1007|     1|  0|        0.0|(2,[0],[1.0])|
|555266|1494307136|       11|430539_1007|     1|  0|        1.0|(2,[1],[1.0])|
|117840|1494036743|       11|430548_1007|     1|  0|        0.0|(2,[0],[1.0])|
|739815|1494115387|       11|430539_1007|     1|  0|        1.0|(2,[1],[1.0])|
|623911|1494625301|       11|430548_1007|     1|  0|        0.0|(2,[0],[1.0])|
|623911|1494451608|       11|430548_1007|     1|  0|        0.0|(2,[0],[1.0])|
|421590|1494034144|       11|430548_1007|     1|  0|        0.0|(2,[0],[1.0])|
|976358|1494156949|       13|430548_1007|     1|  0|        0.0|(2,[0],[1.0])|
|286630|1494218579|       13|430539_1007|     1|  0|        1.0|(2,[1],[1.0])|
|286630|1494289247|       13|430539_1007|     1|  0|        1.0|(2,[1],[1.0])|
|771431|1494153867|       13|430548_1007|     1|  0|        0.0|(2,[0],[1.0])|
|707120|1494220810|       13|430548_1007|     1|  0|        0.0|(2,[0],[1.0])|
|530454|1494293746|       13|430548_1007|     1|  0|        0.0|(2,[0],[1.0])|
+------+----------+---------+-----------+------+---+-----------+-------------+
only showing top 20 rows

'pid和特征的對應關系\n430548_1007:0\n430549_1007:1\n'
           
  • 從HDFS中加載廣告基本資訊資料
_ad_feature_df = spark.read.csv("hdfs://localhost:9000/datasets/ad_feature.csv", header=True)

# 更改表結構,轉換為對應的資料類型
from pyspark.sql.types import StructType, StructField, IntegerType, FloatType

# 替換掉NULL字元串
_ad_feature_df = _ad_feature_df.replace("NULL", "-1")
 
# 更改df表結構:更改列類型和列名稱
ad_feature_df = _ad_feature_df.\
    withColumn("adgroup_id", _ad_feature_df.adgroup_id.cast(IntegerType())).withColumnRenamed("adgroup_id", "adgroupId").\
    withColumn("cate_id", _ad_feature_df.cate_id.cast(IntegerType())).withColumnRenamed("cate_id", "cateId").\
    withColumn("campaign_id", _ad_feature_df.campaign_id.cast(IntegerType())).withColumnRenamed("campaign_id", "campaignId").\
    withColumn("customer", _ad_feature_df.customer.cast(IntegerType())).withColumnRenamed("customer", "customerId").\
    withColumn("brand", _ad_feature_df.brand.cast(IntegerType())).withColumnRenamed("brand", "brandId").\
    withColumn("price", _ad_feature_df.price.cast(FloatType()))
ad_feature_df.printSchema()
ad_feature_df.show()
           

顯示結果:

root
 |-- adgroupId: integer (nullable = true)
 |-- cateId: integer (nullable = true)
 |-- campaignId: integer (nullable = true)
 |-- customerId: integer (nullable = true)
 |-- brandId: integer (nullable = true)
 |-- price: float (nullable = true)

+---------+------+----------+----------+-------+-----+
|adgroupId|cateId|campaignId|customerId|brandId|price|
+---------+------+----------+----------+-------+-----+
|    63133|  6406|     83237|         1|  95471|170.0|
|   313401|  6406|     83237|         1|  87331|199.0|
|   248909|   392|     83237|         1|  32233| 38.0|
|   208458|   392|     83237|         1| 174374|139.0|
|   110847|  7211|    135256|         2| 145952|32.99|
|   607788|  6261|    387991|         6| 207800|199.0|
|   375706|  4520|    387991|         6|     -1| 99.0|
|    11115|  7213|    139747|         9| 186847| 33.0|
|    24484|  7207|    139744|         9| 186847| 19.0|
|    28589|  5953|    395195|        13|     -1|428.0|
|    23236|  5953|    395195|        13|     -1|368.0|
|   300556|  5953|    395195|        13|     -1|639.0|
|    92560|  5953|    395195|        13|     -1|368.0|
|   590965|  4284|     28145|        14| 454237|249.0|
|   529913|  4284|     70206|        14|     -1|249.0|
|   546930|  4284|     28145|        14|     -1|249.0|
|   639794|  6261|     70206|        14|  37004| 89.9|
|   335413|  4284|     28145|        14|     -1|249.0|
|   794890|  4284|     70206|        14| 454237|249.0|
|   684020|  6261|     70206|        14|  37004| 99.0|
+---------+------+----------+----------+-------+-----+
only showing top 20 rows
           
  • 從HDFS加載使用者基本資訊資料
from pyspark.sql.types import StructType, StructField, StringType, IntegerType, LongType, FloatType

# 建構表結構schema對象
schema = StructType([
    StructField("userId", IntegerType()),
    StructField("cms_segid", IntegerType()),
    StructField("cms_group_id", IntegerType()),
    StructField("final_gender_code", IntegerType()),
    StructField("age_level", IntegerType()),
    StructField("pvalue_level", IntegerType()),
    StructField("shopping_level", IntegerType()),
    StructField("occupation", IntegerType()),
    StructField("new_user_class_level", IntegerType())
])
# 利用schema從hdfs加載
_user_profile_df1 = spark.read.csv("hdfs://localhost:9000/datasets/user_profile.csv", header=True, schema=schema)
# user_profile_df.printSchema()
# user_profile_df.show()

'''對缺失資料進行特征熱編碼'''
from pyspark.ml.feature import OneHotEncoder
from pyspark.ml.feature import StringIndexer
from pyspark.ml import Pipeline

# 使用熱編碼轉換pvalue_level的一維資料為多元,增加n-1個虛拟變量,n為pvalue_level的取值範圍

# 需要先将缺失值全部替換為數值,便于處理,否則會抛出異常
from pyspark.sql.types import StringType
_user_profile_df2 = _user_profile_df1.na.fill(-1)
# _user_profile_df2.show()

# 熱編碼時,必須先将待處理字段轉為字元串類型才可處理
_user_profile_df3 = _user_profile_df2.withColumn("pvalue_level", _user_profile_df2.pvalue_level.cast(StringType()))\
    .withColumn("new_user_class_level", _user_profile_df2.new_user_class_level.cast(StringType()))
# _user_profile_df3.printSchema()

# 對pvalue_level進行熱編碼,求值
# 運作過程是先将pvalue_level轉換為一列新的特征資料,然後對該特征資料求出的熱編碼值,存在了新的一列資料中,類型為一個稀疏矩陣
stringindexer = StringIndexer(inputCol='pvalue_level', outputCol='pl_onehot_feature')
encoder = OneHotEncoder(dropLast=False, inputCol='pl_onehot_feature', outputCol='pl_onehot_value')
pipeline = Pipeline(stages=[stringindexer, encoder])
pipeline_fit = pipeline.fit(_user_profile_df3)
_user_profile_df4 = pipeline_fit.transform(_user_profile_df3)
# pl_onehot_value列的值為稀疏矩陣,存儲熱編碼的結果
# _user_profile_df4.printSchema()
# _user_profile_df4.show()

# 使用熱編碼轉換new_user_class_level的一維資料為多元
stringindexer = StringIndexer(inputCol='new_user_class_level', outputCol='nucl_onehot_feature')
encoder = OneHotEncoder(dropLast=False, inputCol='nucl_onehot_feature', outputCol='nucl_onehot_value')
pipeline = Pipeline(stages=[stringindexer, encoder])
pipeline_fit = pipeline.fit(_user_profile_df4)
user_profile_df = pipeline_fit.transform(_user_profile_df4)
user_profile_df.show()
           

顯示結果:

+------+---------+------------+-----------------+---------+------------+--------------+----------+--------------------+-----------------+---------------+-------------------+-----------------+
|userId|cms_segid|cms_group_id|final_gender_code|age_level|pvalue_level|shopping_level|occupation|new_user_class_level|pl_onehot_feature|pl_onehot_value|nucl_onehot_feature|nucl_onehot_value|
+------+---------+------------+-----------------+---------+------------+--------------+----------+--------------------+-----------------+---------------+-------------------+-----------------+
|   234|        0|           5|                2|        5|          -1|             3|         0|                   3|              0.0|  (4,[0],[1.0])|                2.0|    (5,[2],[1.0])|
|   523|        5|           2|                2|        2|           1|             3|         1|                   2|              2.0|  (4,[2],[1.0])|                1.0|    (5,[1],[1.0])|
|   612|        0|           8|                1|        2|           2|             3|         0|                  -1|              1.0|  (4,[1],[1.0])|                0.0|    (5,[0],[1.0])|
|  1670|        0|           4|                2|        4|          -1|             1|         0|                  -1|              0.0|  (4,[0],[1.0])|                0.0|    (5,[0],[1.0])|
|  2545|        0|          10|                1|        4|          -1|             3|         0|                  -1|              0.0|  (4,[0],[1.0])|                0.0|    (5,[0],[1.0])|
|  3644|       49|           6|                2|        6|           2|             3|         0|                   2|              1.0|  (4,[1],[1.0])|                1.0|    (5,[1],[1.0])|
|  5777|       44|           5|                2|        5|           2|             3|         0|                   2|              1.0|  (4,[1],[1.0])|                1.0|    (5,[1],[1.0])|
|  6211|        0|           9|                1|        3|          -1|             3|         0|                   2|              0.0|  (4,[0],[1.0])|                1.0|    (5,[1],[1.0])|
|  6355|        2|           1|                2|        1|           1|             3|         0|                   4|              2.0|  (4,[2],[1.0])|                3.0|    (5,[3],[1.0])|
|  6823|       43|           5|                2|        5|           2|             3|         0|                   1|              1.0|  (4,[1],[1.0])|                4.0|    (5,[4],[1.0])|
|  6972|        5|           2|                2|        2|           2|             3|         1|                   2|              1.0|  (4,[1],[1.0])|                1.0|    (5,[1],[1.0])|
|  9293|        0|           5|                2|        5|          -1|             3|         0|                   4|              0.0|  (4,[0],[1.0])|                3.0|    (5,[3],[1.0])|
|  9510|       55|           8|                1|        2|           2|             2|         0|                   2|              1.0|  (4,[1],[1.0])|                1.0|    (5,[1],[1.0])|
| 10122|       33|           4|                2|        4|           2|             3|         0|                   2|              1.0|  (4,[1],[1.0])|                1.0|    (5,[1],[1.0])|
| 10549|        0|           4|                2|        4|           2|             3|         0|                  -1|              1.0|  (4,[1],[1.0])|                0.0|    (5,[0],[1.0])|
| 10812|        0|           4|                2|        4|          -1|             2|         0|                  -1|              0.0|  (4,[0],[1.0])|                0.0|    (5,[0],[1.0])|
| 10912|        0|           4|                2|        4|           2|             3|         0|                  -1|              1.0|  (4,[1],[1.0])|                0.0|    (5,[0],[1.0])|
| 10996|        0|           5|                2|        5|          -1|             3|         0|                   4|              0.0|  (4,[0],[1.0])|                3.0|    (5,[3],[1.0])|
| 11256|        8|           2|                2|        2|           1|             3|         0|                   3|              2.0|  (4,[2],[1.0])|                2.0|    (5,[2],[1.0])|
| 11310|       31|           4|                2|        4|           1|             3|         0|                   4|              2.0|  (4,[2],[1.0])|                3.0|    (5,[3],[1.0])|
+------+---------+------------+-----------------+---------+------------+--------------+----------+--------------------+-----------------+---------------+-------------------+-----------------+
only showing top 20 rows

           
  • 熱編碼中:"pvalue_level"特征對應關系:
+------------+----------------------+
|pvalue_level|pl_onehot_feature     |
+------------+----------------------+
|          -1|                   0.0|
|           3|                   3.0|
|           1|                   2.0|
|           2|                   1.0|
+------------+----------------------+
           
  • “new_user_class_level”的特征對應關系
+--------------------+------------------------+
|new_user_class_level|nucl_onehot_feature     |
+--------------------+------------------------+
|                  -1|                     0.0|
|                   3|                     2.0|
|                   1|                     4.0|
|                   4|                     3.0|
|                   2|                     1.0|
+--------------------+------------------------+
           
user_profile_df.groupBy("pvalue_level").min("pl_onehot_feature").show()
user_profile_df.groupBy("new_user_class_level").min("nucl_onehot_feature").show()
           

顯示結果:

+------------+----------------------+
|pvalue_level|min(pl_onehot_feature)|
+------------+----------------------+
|          -1|                   0.0|
|           3|                   3.0|
|           1|                   2.0|
|           2|                   1.0|
+------------+----------------------+

+--------------------+------------------------+
|new_user_class_level|min(nucl_onehot_feature)|
+--------------------+------------------------+
|                  -1|                     0.0|
|                   3|                     2.0|
|                   1|                     4.0|
|                   4|                     3.0|
|                   2|                     1.0|
+--------------------+------------------------+

           
  • Dataframe資料合并:pyspark.sql.DataFrame.join
# raw_sample_df和ad_feature_df合并條件
condition = [raw_sample_df.adgroupId==ad_feature_df.adgroupId]
_ = raw_sample_df.join(ad_feature_df, condition, 'outer')

# _和user_profile_df合并條件
condition2 = [_.userId==user_profile_df.userId]
datasets = _.join(user_profile_df, condition2, "outer")
# 檢視datasets的結構
datasets.printSchema()
# 檢視datasets條目數
print(datasets.count())
           

顯示結果:

root
 |-- userId: integer (nullable = true)
 |-- timestamp: long (nullable = true)
 |-- adgroupId: integer (nullable = true)
 |-- pid: string (nullable = true)
 |-- nonclk: integer (nullable = true)
 |-- clk: integer (nullable = true)
 |-- pid_feature: double (nullable = true)
 |-- pid_value: vector (nullable = true)
 |-- adgroupId: integer (nullable = true)
 |-- cateId: integer (nullable = true)
 |-- campaignId: integer (nullable = true)
 |-- customerId: integer (nullable = true)
 |-- brandId: integer (nullable = true)
 |-- price: float (nullable = true)
 |-- userId: integer (nullable = true)
 |-- cms_segid: integer (nullable = true)
 |-- cms_group_id: integer (nullable = true)
 |-- final_gender_code: integer (nullable = true)
 |-- age_level: integer (nullable = true)
 |-- pvalue_level: string (nullable = true)
 |-- shopping_level: integer (nullable = true)
 |-- occupation: integer (nullable = true)
 |-- new_user_class_level: string (nullable = true)
 |-- pl_onehot_feature: double (nullable = true)
 |-- pl_onehot_value: vector (nullable = true)
 |-- nucl_onehot_feature: double (nullable = true)
 |-- nucl_onehot_value: vector (nullable = true)

26557961
           
  • 訓練CTRModel_Normal:直接将對應的特征的特征值組合成對應的特征向量進行訓練
# 剔除備援、不需要的字段
useful_cols = [
    # 
    # 時間字段,劃分訓練集和測試集
    "timestamp",
    # label目标值字段
    "clk",  
    # 特征值字段
    "pid_value",       # 資源位的特征向量
    "price",    # 廣告價格
    "cms_segid",    # 使用者微群ID
    "cms_group_id",    # 使用者組ID
    "final_gender_code",    # 使用者性别特征,[1,2]
    "age_level",    # 年齡等級,1-
    "shopping_level",
    "occupation",
    "pl_onehot_value",
    "nucl_onehot_value"
]
# 篩選指定字段資料,建構新的資料集
datasets_1 = datasets.select(*useful_cols)
# 由于前面使用的是outer方式合并的資料,産生了部分空值資料,這裡必須先剔除掉
datasets_1 = datasets_1.dropna()
print("剔除空值資料後,還剩:", datasets_1.count())
           

顯示結果:

剔除空值資料後,還剩: 25029435

           
  • 根據特征字段計算出特征向量,并劃分出訓練資料集和測試資料集
from pyspark.ml.feature import VectorAssembler
# 根據特征字段計算特征向量
datasets_1 = VectorAssembler().setInputCols(useful_cols[2:]).setOutputCol("features").transform(datasets_1)
# 訓練資料集: 約7天的資料
train_datasets_1 = datasets_1.filter(datasets_1.timestamp<=(1494691186-24*60*60))
# 測試資料集:約1天的資料量
test_datasets_1 = datasets_1.filter(datasets_1.timestamp>(1494691186-24*60*60))
# 所有的特征的特征向量已經彙總到在features字段中
train_datasets_1.show(5)
test_datasets_1.show(5)
           

顯示結果:

+----------+---+-------------+------+---------+------------+-----------------+---------+--------------+----------+---------------+-----------------+--------------------+
| timestamp|clk|    pid_value| price|cms_segid|cms_group_id|final_gender_code|age_level|shopping_level|occupation|pl_onehot_value|nucl_onehot_value|            features|
+----------+---+-------------+------+---------+------------+-----------------+---------+--------------+----------+---------------+-----------------+--------------------+
|1494261938|  0|(2,[1],[1.0])| 108.0|        0|          11|                1|        5|             3|         0|  (4,[0],[1.0])|    (5,[1],[1.0])|(18,[1,2,4,5,6,7,...|
|1494261938|  0|(2,[1],[1.0])|1880.0|        0|          11|                1|        5|             3|         0|  (4,[0],[1.0])|    (5,[1],[1.0])|(18,[1,2,4,5,6,7,...|
|1494553913|  0|(2,[1],[1.0])|2360.0|       19|           3|                2|        3|             3|         0|  (4,[1],[1.0])|    (5,[1],[1.0])|(18,[1,2,3,4,5,6,...|
|1494553913|  0|(2,[1],[1.0])|2200.0|       19|           3|                2|        3|             3|         0|  (4,[1],[1.0])|    (5,[1],[1.0])|(18,[1,2,3,4,5,6,...|
|1494436784|  0|(2,[1],[1.0])|5649.0|       19|           3|                2|        3|             3|         0|  (4,[1],[1.0])|    (5,[1],[1.0])|(18,[1,2,3,4,5,6,...|
+----------+---+-------------+------+---------+------------+-----------------+---------+--------------+----------+---------------+-----------------+--------------------+
only showing top 5 rows

+----------+---+-------------+-----+---------+------------+-----------------+---------+--------------+----------+---------------+-----------------+--------------------+
| timestamp|clk|    pid_value|price|cms_segid|cms_group_id|final_gender_code|age_level|shopping_level|occupation|pl_onehot_value|nucl_onehot_value|            features|
+----------+---+-------------+-----+---------+------------+-----------------+---------+--------------+----------+---------------+-----------------+--------------------+
|1494677292|  0|(2,[1],[1.0])|176.0|       19|           3|                2|        3|             3|         0|  (4,[1],[1.0])|    (5,[1],[1.0])|(18,[1,2,3,4,5,6,...|
|1494677292|  0|(2,[1],[1.0])|698.0|       19|           3|                2|        3|             3|         0|  (4,[1],[1.0])|    (5,[1],[1.0])|(18,[1,2,3,4,5,6,...|
|1494677292|  0|(2,[1],[1.0])|697.0|       19|           3|                2|        3|             3|         0|  (4,[1],[1.0])|    (5,[1],[1.0])|(18,[1,2,3,4,5,6,...|
|1494684007|  0|(2,[1],[1.0])|247.0|       18|           3|                2|        3|             3|         0|  (4,[1],[1.0])|    (5,[4],[1.0])|(18,[1,2,3,4,5,6,...|
|1494684007|  0|(2,[1],[1.0])|109.0|       18|           3|                2|        3|             3|         0|  (4,[1],[1.0])|    (5,[4],[1.0])|(18,[1,2,3,4,5,6,...|
+----------+---+-------------+-----+---------+------------+-----------------+---------+--------------+----------+---------------+-----------------+--------------------+
only showing top 5 rows

           
  • 建立邏輯回歸訓練器,并訓練模型:LogisticRegression、 LogisticRegressionModel
from pyspark.ml.classification import LogisticRegression
lr = LogisticRegression()
# 設定目标字段、特征值字段并訓練
model = lr.setLabelCol("clk").setFeaturesCol("features").fit(train_datasets_1)
# 對模型進行存儲
model.save("hdfs://localhost:9000/models/CTRModel_Normal.obj")
# 載入訓練好的模型
from pyspark.ml.classification import LogisticRegressionModel
model = LogisticRegressionModel.load("hdfs://localhost:9000/models/CTRModel_Normal.obj")
# 根據測試資料進行預測
result_1 = model.transform(test_datasets_1)
# 按probability升序排列資料,probability表示預測結果的機率
# 如果預測值是0,其機率是0.9248,那麼反之可推出1的可能性就是1-0.9248=0.0752,即點選機率約為7.52%
# 因為前面提到廣告的點選率一般都比較低,是以預測值通常都是0,是以通常需要反減得出點選的機率
result_1.select("clk", "price", "probability", "prediction").sort("probability").show(100)
           

顯示結果:

+---+-----------+--------------------+----------+
|clk|      price|         probability|prediction|
+---+-----------+--------------------+----------+
|  0|      1.0E8|[0.86822033939259...|       0.0|
|  0|      1.0E8|[0.88410457194969...|       0.0|
|  0|      1.0E8|[0.89175497837562...|       0.0|
|  1|5.5555556E7|[0.92481456486873...|       0.0|
|  0|      1.5E7|[0.93741450446939...|       0.0|
|  0|      1.5E7|[0.93757135079959...|       0.0|
|  0|      1.5E7|[0.93834723093801...|       0.0|
|  0|     1099.0|[0.93972095713786...|       0.0|
|  0|      338.0|[0.93972134993018...|       0.0|
|  0|      311.0|[0.93972136386626...|       0.0|
|  0|      300.0|[0.93972136954393...|       0.0|
|  0|      278.0|[0.93972138089925...|       0.0|
|  0|      188.0|[0.93972142735283...|       0.0|
|  0|      176.0|[0.93972143354663...|       0.0|
|  0|      168.0|[0.93972143767584...|       0.0|
|  0|      158.0|[0.93972144283734...|       0.0|
|  1|      138.0|[0.93972145316035...|       0.0|
|  0|      125.0|[0.93972145987031...|       0.0|
|  0|      119.0|[0.93972146296721...|       0.0|
|  0|       78.0|[0.93972148412937...|       0.0|
|  0|      59.98|[0.93972149343040...|       0.0|
|  0|       58.0|[0.93972149445238...|       0.0|
|  0|       56.0|[0.93972149548468...|       0.0|
|  0|       38.0|[0.93972150477538...|       0.0|
|  1|       35.0|[0.93972150632383...|       0.0|
|  0|       33.0|[0.93972150735613...|       0.0|
|  0|       30.0|[0.93972150890458...|       0.0|
|  0|       27.6|[0.93972151014334...|       0.0|
|  0|       18.0|[0.93972151509838...|       0.0|
|  0|       30.0|[0.93980311191464...|       0.0|
|  0|       28.0|[0.93980311294563...|       0.0|
|  0|       25.0|[0.93980311449212...|       0.0|
|  0|      688.0|[0.93999362023323...|       0.0|
|  0|      339.0|[0.93999379960808...|       0.0|
|  0|      335.0|[0.93999380166395...|       0.0|
|  0|      220.0|[0.93999386077017...|       0.0|
|  0|      176.0|[0.93999388338470...|       0.0|
|  0|      158.0|[0.93999389263610...|       0.0|
|  0|      158.0|[0.93999389263610...|       0.0|
|  1|      149.0|[0.93999389726180...|       0.0|
|  0|      122.5|[0.93999391088191...|       0.0|
|  0|       99.0|[0.93999392296012...|       0.0|
|  0|       88.0|[0.93999392861375...|       0.0|
|  0|       79.0|[0.93999393323945...|       0.0|
|  0|       75.0|[0.93999393529532...|       0.0|
|  0|       68.0|[0.93999393889308...|       0.0|
|  0|       68.0|[0.93999393889308...|       0.0|
|  0|       59.9|[0.93999394305620...|       0.0|
|  0|      44.98|[0.93999395072458...|       0.0|
|  0|       35.5|[0.93999395559698...|       0.0|
|  0|       33.0|[0.93999395688189...|       0.0|
|  0|       32.8|[0.93999395698469...|       0.0|
|  0|       30.0|[0.93999395842379...|       0.0|
|  0|       28.0|[0.93999395945172...|       0.0|
|  0|       19.9|[0.93999396361485...|       0.0|
|  0|       19.8|[0.93999396366625...|       0.0|
|  0|       19.8|[0.93999396366625...|       0.0|
|  0|       12.0|[0.93999396767518...|       0.0|
|  0|        6.7|[0.93999397039920...|       0.0|
|  0|      568.0|[0.94000369247841...|       0.0|
|  0|      398.0|[0.94000377983931...|       0.0|
|  0|      158.0|[0.94000390317214...|       0.0|
|  0|     5718.0|[0.94001886593718...|       0.0|
|  0|     5718.0|[0.94001886593718...|       0.0|
|  1|     5608.0|[0.94001892245145...|       0.0|
|  0|     4120.0|[0.94001968693052...|       0.0|
|  0|     1027.5|[0.94002127571285...|       0.0|
|  0|     1027.5|[0.94002127571285...|       0.0|
|  0|      989.0|[0.94002129549211...|       0.0|
|  0|      672.0|[0.94002145834965...|       0.0|
|  0|      660.0|[0.94002146451460...|       0.0|
|  0|      598.0|[0.94002149636681...|       0.0|
|  0|      598.0|[0.94002149636681...|       0.0|
|  0|      563.0|[0.94002151434789...|       0.0|
|  0|      509.0|[0.94002154209012...|       0.0|
|  0|      509.0|[0.94002154209012...|       0.0|
|  0|      500.0|[0.94002154671382...|       0.0|
|  0|      498.0|[0.94002154774131...|       0.0|
|  0|      440.0|[0.94002157753851...|       0.0|
|  0|      430.0|[0.94002158267595...|       0.0|
|  0|      388.0|[0.94002160425322...|       0.0|
|  0|      369.0|[0.94002161401436...|       0.0|
|  0|      368.0|[0.94002161452811...|       0.0|
|  0|      368.0|[0.94002161452811...|       0.0|
|  0|      368.0|[0.94002161452811...|       0.0|
|  0|      368.0|[0.94002161452811...|       0.0|
|  0|      366.0|[0.94002161555560...|       0.0|
|  0|      366.0|[0.94002161555560...|       0.0|
|  0|      348.0|[0.94002162480299...|       0.0|
|  0|      299.0|[0.94002164997645...|       0.0|
|  0|      299.0|[0.94002164997645...|       0.0|
|  0|      299.0|[0.94002164997645...|       0.0|
|  0|      298.0|[0.94002165049020...|       0.0|
|  0|      297.0|[0.94002165100394...|       0.0|
|  0|      278.0|[0.94002166076508...|       0.0|
|  1|      275.0|[0.94002166230631...|       0.0|
|  0|      275.0|[0.94002166230631...|       0.0|
|  0|      273.0|[0.94002166333380...|       0.0|
|  0|      258.0|[0.94002167103995...|       0.0|
|  0|      256.0|[0.94002167206744...|       0.0|
+---+-----------+--------------------+----------+
only showing top 100 rows
           
  • 檢視樣本中點選的被實際點選的條目的預測情況

顯示結果:

+---+-----------+--------------------+----------+
|clk|      price|         probability|prediction|
+---+-----------+--------------------+----------+
|  1|5.5555556E7|[0.92481456486873...|       0.0|
|  1|      138.0|[0.93972145316035...|       0.0|
|  1|       35.0|[0.93972150632383...|       0.0|
|  1|      149.0|[0.93999389726180...|       0.0|
|  1|     5608.0|[0.94001892245145...|       0.0|
|  1|      275.0|[0.94002166230631...|       0.0|
|  1|       35.0|[0.94002178560473...|       0.0|
|  1|       49.0|[0.94004219516957...|       0.0|
|  1|      915.0|[0.94021082858784...|       0.0|
|  1|      598.0|[0.94021099096349...|       0.0|
|  1|      568.0|[0.94021100633025...|       0.0|
|  1|      398.0|[0.94021109340848...|       0.0|
|  1|      368.0|[0.94021110877521...|       0.0|
|  1|      299.0|[0.94021114411869...|       0.0|
|  1|      278.0|[0.94021115487539...|       0.0|
|  1|      259.0|[0.94021116460765...|       0.0|
|  1|      258.0|[0.94021116511987...|       0.0|
|  1|      258.0|[0.94021116511987...|       0.0|
|  1|      258.0|[0.94021116511987...|       0.0|
|  1|      195.0|[0.94021119738998...|       0.0|
|  1|      188.0|[0.94021120097554...|       0.0|
|  1|      178.0|[0.94021120609778...|       0.0|
|  1|      159.0|[0.94021121583003...|       0.0|
|  1|      149.0|[0.94021122095226...|       0.0|
|  1|      138.0|[0.94021122658672...|       0.0|
|  1|       58.0|[0.94021126756458...|       0.0|
|  1|       49.0|[0.94021127217459...|       0.0|
|  1|       35.0|[0.94021127934572...|       0.0|
|  1|       25.0|[0.94021128446795...|       0.0|
|  1|     2890.0|[0.94028789742257...|       0.0|
|  1|      220.0|[0.94028926340218...|       0.0|
|  1|      188.0|[0.94031410659516...|       0.0|
|  1|       68.0|[0.94031416796289...|       0.0|
|  1|       58.0|[0.94031417307687...|       0.0|
|  1|      198.0|[0.94035413548387...|       0.0|
|  1|      208.0|[0.94039204931181...|       0.0|
|  1|     8888.0|[0.94045237642030...|       0.0|
|  1|      519.0|[0.94045664687995...|       0.0|
|  1|      478.0|[0.94045666780037...|       0.0|
|  1|      349.0|[0.94045673362308...|       0.0|
|  1|      348.0|[0.94045673413334...|       0.0|
|  1|      316.0|[0.94045675046144...|       0.0|
|  1|      298.0|[0.94045675964600...|       0.0|
|  1|      298.0|[0.94045675964600...|       0.0|
|  1|      199.0|[0.94045681016104...|       0.0|
|  1|      199.0|[0.94045681016104...|       0.0|
|  1|      198.0|[0.94045681067129...|       0.0|
|  1|      187.1|[0.94045681623305...|       0.0|
|  1|      176.0|[0.94045682189685...|       0.0|
|  1|      168.0|[0.94045682597887...|       0.0|
|  1|      160.0|[0.94045683006090...|       0.0|
|  1|      158.0|[0.94045683108140...|       0.0|
|  1|      158.0|[0.94045683108140...|       0.0|
|  1|      135.0|[0.94045684281721...|       0.0|
|  1|      129.0|[0.94045684587872...|       0.0|
|  1|      127.0|[0.94045684689923...|       0.0|
|  1|      125.0|[0.94045684791973...|       0.0|
|  1|      124.0|[0.94045684842999...|       0.0|
|  1|      118.0|[0.94045685149150...|       0.0|
|  1|      109.0|[0.94045685608377...|       0.0|
|  1|      108.0|[0.94045685659402...|       0.0|
|  1|       99.0|[0.94045686118630...|       0.0|
|  1|       98.0|[0.94045686169655...|       0.0|
|  1|       79.8|[0.94045687098314...|       0.0|
|  1|       79.0|[0.94045687139134...|       0.0|
|  1|       77.0|[0.94045687241185...|       0.0|
|  1|       72.5|[0.94045687470798...|       0.0|
|  1|       69.0|[0.94045687649386...|       0.0|
|  1|       68.0|[0.94045687700412...|       0.0|
|  1|       60.0|[0.94045688108613...|       0.0|
|  1|      43.98|[0.94045688926037...|       0.0|
|  1|       40.0|[0.94045689129118...|       0.0|
|  1|       39.9|[0.94045689134220...|       0.0|
|  1|       39.6|[0.94045689149528...|       0.0|
|  1|       32.0|[0.94045689537319...|       0.0|
|  1|       31.0|[0.94045689588345...|       0.0|
|  1|      25.98|[0.94045689844491...|       0.0|
|  1|       23.0|[0.94045689996546...|       0.0|
|  1|       19.0|[0.94045690200647...|       0.0|
|  1|       16.9|[0.94045690307800...|       0.0|
|  1|       10.0|[0.94045690659874...|       0.0|
|  1|        3.5|[0.94045690991538...|       0.0|
|  1|        3.5|[0.94045690991538...|       0.0|
|  1|        0.4|[0.94045691149716...|       0.0|
|  1|     3960.0|[0.94055740378069...|       0.0|
|  1|     3088.0|[0.94055784801535...|       0.0|
|  1|     1689.0|[0.94055856072019...|       0.0|
|  1|      998.0|[0.94055891273943...|       0.0|
|  1|      888.0|[0.94055896877705...|       0.0|
|  1|      788.0|[0.94055901972029...|       0.0|
|  1|      737.0|[0.94055904570133...|       0.0|
|  1|      629.0|[0.94055910071996...|       0.0|
|  1|      599.0|[0.94055911600291...|       0.0|
|  1|      599.0|[0.94055911600291...|       0.0|
|  1|      599.0|[0.94055911600291...|       0.0|
|  1|      499.0|[0.94055916694603...|       0.0|
|  1|      468.0|[0.94055918273839...|       0.0|
|  1|      459.0|[0.94055918732327...|       0.0|
|  1|      399.0|[0.94055921788912...|       0.0|
|  1|      399.0|[0.94055921788912...|       0.0|
+---+-----------+--------------------+----------+
only showing top 100 rows

           

五、離線推薦資料緩存

5.1離線資料緩存之離線召回集

  • 這裡主要是利用我們前面訓練的ALS模型進行協同過濾召回,但是注意,我們ALS模型召回的是使用者最感興趣的類别,而我們需要的是使用者可能感興趣的廣告的集合,是以我們還需要根據召回的類别比對出對應的廣告。

    是以這裡我們除了需要我們訓練的ALS模型以外,還需要有一個廣告和類别的對應關系。

# 從HDFS中加載廣告基本資訊資料,傳回spark dafaframe對象
df = spark.read.csv("hdfs://localhost:9000/data/ad_feature.csv", header=True)

# 注意:由于本資料集中存在NULL字樣的資料,無法直接設定schema,隻能先将NULL類型的資料處理掉,然後進行類型轉換

from pyspark.sql.types import StructType, StructField, IntegerType, FloatType

# 替換掉NULL字元串,替換掉
df = df.replace("NULL", "-1")

# 更改df表結構:更改列類型和列名稱
ad_feature_df = df.\
    withColumn("adgroup_id", df.adgroup_id.cast(IntegerType())).withColumnRenamed("adgroup_id", "adgroupId").\
    withColumn("cate_id", df.cate_id.cast(IntegerType())).withColumnRenamed("cate_id", "cateId").\
    withColumn("campaign_id", df.campaign_id.cast(IntegerType())).withColumnRenamed("campaign_id", "campaignId").\
    withColumn("customer", df.customer.cast(IntegerType())).withColumnRenamed("customer", "customerId").\
    withColumn("brand", df.brand.cast(IntegerType())).withColumnRenamed("brand", "brandId").\
    withColumn("price", df.price.cast(FloatType()))

# 這裡我們隻需要adgroupId、和cateId
_ = ad_feature_df.select("adgroupId", "cateId")
# 由于這裡資料集其實很少,是以我們再直接轉成Pandas dataframe來處理,把資料載入記憶體
pdf = _.toPandas()


# 手動釋放一些記憶體
del df
del ad_feature_df
del _
import gc
gc.collect()
           
  • 根據指定的類别找到對應的廣告
import numpy as np
pdf.where(pdf.cateId==11156).dropna().adgroupId

np.random.choice(pdf.where(pdf.cateId==11156).dropna().adgroupId.astype(np.int64), 200)
           

顯示結果:

313       138953.0
314       467512.0
1661      140008.0
1666      238772.0
1669      237471.0
1670      238761.0
			...   
843456    352273.0
846728    818681.0
846729    838953.0
846810    845337.0
Name: adgroupId, Length: 731, dtype: float64

           
  • 利用ALS模型進行類别的召回
# 加載als模型,注意必須先有spark上下文管理器,即sparkContext,但這裡sparkSession建立後,自動建立了sparkContext

from pyspark.ml.recommendation import ALSModel
# 從hdfs加載之前存儲的模型
als_model = ALSModel.load("hdfs://localhost:9000/models/userCateRatingALSModel.obj")
# 傳回模型中關于使用者的所有屬性   df:   id   features
als_model.userFactors
           

顯示結果:

DataFrame[id: int, features: array<float>]
           
import pandas as pd
cateId_df = pd.DataFrame(pdf.cateId.unique(),columns=["cateId"])
cateId_df
           

顯示結果:

cateId
0	1
1	2
2	3
3	4
4	5
5	6
6	7
...	...
6766	12948
6767	12955
6768	12960
6769 rows × 1 columns

           
cateId_df.insert(0, "userId", np.array([8 for i in range(6769)]))
cateId_df
           

顯示結果:

userId cateId
0	8	1
1	8	2
2	8	3
3	8	4
4	8	5
...	...	...
6766	8	12948
6767	8	12955
6768	8	12960
6769 rows × 2 columns

           
  • 傳入 userid、cataId的df,對應預測值進行排序

顯示結果:

+------+------+----------+
|userId|cateId|prediction|
+------+------+----------+
|     8|  7214|  9.917084|
|     8|   877|  7.479664|
|     8|  7266| 7.4762917|
|     8| 10856| 7.3395424|
|     8|  4766|  7.149538|
|     8|  7282| 6.6835284|
|     8|  7270| 6.2145095|
|     8|   201| 6.0623236|
|     8|  4267| 5.9155636|
|     8|  7267|  5.838009|
|     8|  5392| 5.6882005|
|     8|  6261| 5.6804466|
|     8|  6306| 5.2992325|
|     8| 11050|  5.245261|
|     8|  8655| 5.1701374|
|     8|  4610|  5.139578|
|     8|   932|   5.12694|
|     8| 12276| 5.0776596|
|     8|  8071|  4.979195|
|     8|  6580| 4.8523283|
+------+------+----------+
only showing top 20 rows

           
import numpy as np
import pandas as pd

import redis

# 存儲使用者召回,使用redis第9号資料庫,類型:sets類型
client = redis.StrictRedis(host="192.168.19.137", port=6379, db=9)
# 周遊als_model 中 所有使用者的id
for r in als_model.userFactors.select("id").collect():
    
    userId = r.id
    
    #準備 目前使用者 和 所有類别 一一對應的dataframe
    cateId_df = pd.DataFrame(pdf.cateId.unique(),columns=["cateId"])
    cateId_df.insert(0, "userId", np.array([userId for i in range(6769)]))
    ret = set()
    
    # 利用模型,傳入datasets(userId, cateId),這裡控制了userId一樣,是以相當于是在求某使用者對所有分類的興趣程度
    cateId_list = als_model.transform(spark.createDataFrame(cateId_df)).sort("prediction", ascending=False).na.drop()
    # 找到前 20個 最感興趣的類别 從前20個分類中選出500個進行召回
    for i in cateId_list.head(20):
        need = 500 - len(ret)    # 如果不足500個,那麼随機選出need個廣告
        ret = ret.union(np.random.choice(pdf.where(pdf.cateId==i.cateId).adgroupId.dropna().astype(np.int64), need))
        if len(ret) >= 500:    # 如果達到500個則退出
            break
    client.sadd(userId, *ret)
    
# 如果redis所在機器,記憶體不足,會抛出異常
           

5.2 離線資料緩存之離線特征

# "pid", 廣告資源位,屬于場景特征,也就是說,每一種廣告通常是可以防止在多種資源外下的
# 是以這裡對于pid,應該是由廣告系統發起推薦請求時,向推薦系統明确要推薦的使用者是誰,以及對應的資源位,或者說有哪些
# 這樣如果有多個資源位,那麼每個資源位都會對應相應的一個推薦清單

# 需要進行緩存的特征值
    
feature_cols_from_ad = [
    "price"    # 來自廣告基本資訊中
]

# 使用者特征
feature_cols_from_user = [
    "cms_group_id",
    "final_gender_code",
    "age_level",
    "shopping_level",
    "occupation",
    "pvalue_level",
    "new_user_class_level"
]
           
  • 從HDFS中加載廣告基本資訊資料
_ad_feature_df = spark.read.csv("hdfs://localhost:9000/data/ad_feature.csv", header=True)

# 更改表結構,轉換為對應的資料類型
from pyspark.sql.types import StructType, StructField, IntegerType, FloatType

# 替換掉NULL字元串
_ad_feature_df = _ad_feature_df.replace("NULL", "-1")
 
# 更改df表結構:更改列類型和列名稱
ad_feature_df = _ad_feature_df.\
    withColumn("adgroup_id", _ad_feature_df.adgroup_id.cast(IntegerType())).withColumnRenamed("adgroup_id", "adgroupId").\
    withColumn("cate_id", _ad_feature_df.cate_id.cast(IntegerType())).withColumnRenamed("cate_id", "cateId").\
    withColumn("campaign_id", _ad_feature_df.campaign_id.cast(IntegerType())).withColumnRenamed("campaign_id", "campaignId").\
    withColumn("customer", _ad_feature_df.customer.cast(IntegerType())).withColumnRenamed("customer", "customerId").\
    withColumn("brand", _ad_feature_df.brand.cast(IntegerType())).withColumnRenamed("brand", "brandId").\
    withColumn("price", _ad_feature_df.price.cast(FloatType()))
    
def foreachPartition(partition):
    
    import redis
    import json
    client = redis.StrictRedis(host="192.168.19.137", port=6379, db=10)
    
    for r in partition:
        data = {
            "price": r.price
        }
        # 轉成json字元串再儲存,能保證資料再次倒出來時,能有效的轉換成python類型
        client.hset("ad_features", r.adgroupId, json.dumps(data))
        
ad_feature_df.foreachPartition(foreachPartition)
           
  • 從HDFS加載使用者基本資訊資料
from pyspark.sql.types import StructType, StructField, StringType, IntegerType, LongType, FloatType

# 建構表結構schema對象
schema = StructType([
    StructField("userId", IntegerType()),
    StructField("cms_segid", IntegerType()),
    StructField("cms_group_id", IntegerType()),
    StructField("final_gender_code", IntegerType()),
    StructField("age_level", IntegerType()),
    StructField("pvalue_level", IntegerType()),
    StructField("shopping_level", IntegerType()),
    StructField("occupation", IntegerType()),
    StructField("new_user_class_level", IntegerType())
])
# 利用schema從hdfs加載
user_profile_df = spark.read.csv("hdfs://localhost:8020/csv/user_profile.csv", header=True, schema=schema)
user_profile_df
           

顯示結果:

DataFrame[userId: int, cms_segid: int, cms_group_id: int, final_gender_code: int, age_level: int, pvalue_level: int, shopping_level: int, occupation: int, new_user_class_level: int]

           
def foreachPartition2(partition):
    
    import redis
    import json
    client = redis.StrictRedis(host="192.168.199.188", port=6379, db=10)
    
    for r in partition:
        data = {
            "cms_group_id": r.cms_group_id,
            "final_gender_code": r.final_gender_code,
            "age_level": r.age_level,
            "shopping_level": r.shopping_level,
            "occupation": r.occupation,
            "pvalue_level": r.pvalue_level,
            "new_user_class_level": r.new_user_class_level
        }
        # 轉成json字元串再儲存,能保證資料再次倒出來時,能有效的轉換成python類型
        client.hset("user_features1", r.userId, json.dumps(data))
        
user_profile_df.foreachPartition(foreachPartition2)
           

六、 實時産生推薦結果

6.1 推薦任務處理

  • CTR預測模型 + 特征 ==> 預測結果 ==> TOP-N清單
  • 特征擷取
import redis
import json
import pandas as pd
from pyspark.ml.linalg import DenseVector


def create_datasets(userId, pid):
    client_of_recall = redis.StrictRedis(host="192.168.19.137", port=6379, db=9)
    client_of_features = redis.StrictRedis(host="192.168.19.137", port=6379, db=10)
    # 擷取使用者特征
    user_feature = json.loads(client_of_features.hget("user_features", userId))
    
    # 擷取使用者召回集
    recall_sets = client_of_recall.smembers(userId)
    
    result = []
    
    # 周遊召回集
    for adgroupId in recall_sets:
        adgroupId = int(adgroupId)
        # 擷取該廣告的特征值
        ad_feature = json.loads(client_of_features.hget("ad_features", adgroupId))
        
        features = {}
        features.update(user_feature)
        features.update(ad_feature)

        for k,v in features.items():
            if v is None:
                features[k] = -1

        features_col = [
            # 特征值
            "price",
            "cms_group_id",
            "final_gender_code",
            "age_level",
            "shopping_level",
            "occupation",
            "pid", 
            "pvalue_level",
            "new_user_class_level"
        ]
        '''
        "cms_group_id", 類别型特征,約13個分類 ==> 13維
        "final_gender_code", 類别型特征,2個分類 ==> 2維
        "age_level", 類别型特征,7個分類 ==>7維
        "shopping_level", 類别型特征,3個分類 ==> 3維
        "occupation", 類别型特征,2個分類 ==> 2維
        '''

        price = float(features["price"])

        pid_value = [0 for i in range(2)]#[0,0]
        cms_group_id_value = [0 for i in range(13)]
        final_gender_code_value = [0 for i in range(2)]
        age_level_value = [0 for i in range(7)]
        shopping_level_value = [0 for i in range(3)]
        occupation_value = [0 for i in range(2)]
        pvalue_level_value = [0 for i in range(4)]
        new_user_class_level_value = [0 for i in range(5)]

        pid_value[pid_rela[pid]] = 1
        cms_group_id_value[cms_group_id_rela[int(features["cms_group_id"])]] = 1
        final_gender_code_value[final_gender_code_rela[int(features["final_gender_code"])]] = 1
        age_level_value[age_level_rela[int(features["age_level"])]] = 1
        shopping_level_value[shopping_level_rela[int(features["shopping_level"])]] = 1
        occupation_value[occupation_rela[int(features["occupation"])]] = 1
        pvalue_level_value[pvalue_level_rela[int(features["pvalue_level"])]] = 1
        new_user_class_level_value[new_user_class_level_rela[int(features["new_user_class_level"])]] = 1
 #         print(pid_value)
#         print(cms_group_id_value)
#         print(final_gender_code_value)
#         print(age_level_value)
#         print(shopping_level_value)
#         print(occupation_value)
#         print(pvalue_level_value)
#         print(new_user_class_level_value)
        
        vector = DenseVector([price] + pid_value + cms_group_id_value + final_gender_code_value\
        + age_level_value + shopping_level_value + occupation_value + pvalue_level_value + new_user_class_level_value)
        
        result.append((userId, adgroupId, vector))
        
    return result

# create_datasets(88, "430548_1007")
           
  • 載入訓練好的模型
from pyspark.ml.classification import LogisticRegressionModel
CTR_model = LogisticRegressionModel.load("hdfs://localhost:9000/models/CTRModel_AllOneHot.obj")
pdf = pd.DataFrame(create_datasets(8, "430548_1007"), columns=["userId", "adgroupId", "features"])
datasets = spark.createDataFrame(pdf)
datasets.show()
           

顯示結果:

+------+---------+--------------------+
|userId|adgroupId|            features|
+------+---------+--------------------+
|     8|   445914|[9.89999961853027...|
|     8|   258252|[7.59999990463256...|
|     8|   129682|[8.5,1.0,0.0,1.0,...|
|     8|   763027|[68.0,1.0,0.0,1.0...|
|     8|   292027|[16.0,1.0,0.0,1.0...|
|     8|   430023|[34.2000007629394...|
|     8|   133457|[169.0,1.0,0.0,1....|
|     8|   816999|[5.0,1.0,0.0,1.0,...|
|     8|   221714|[4.80000019073486...|
|     8|   186334|[106.0,1.0,0.0,1....|
|     8|   169717|[2.20000004768371...|
|     8|    31314|[15.8000001907348...|
|     8|   815312|[2.29999995231628...|
|     8|   199445|[5.0,1.0,0.0,1.0,...|
|     8|   746178|[16.7999992370605...|
|     8|   290950|[6.5,1.0,0.0,1.0,...|
|     8|   221585|[18.5,1.0,0.0,1.0...|
|     8|   692672|[47.0,1.0,0.0,1.0...|
|     8|   797982|[33.0,1.0,0.0,1.0...|
|     8|   815219|[2.40000009536743...|
+------+---------+--------------------+
only showing top 20 rows
           
prediction = CTR_model.transform(datasets).sort("probability")
prediction.show()
           
+------+---------+--------------------+--------------------+--------------------+----------+
|userId|adgroupId|            features|       rawPrediction|         probability|prediction|
+------+---------+--------------------+--------------------+--------------------+----------+
|     8|   631204|[19888.0,1.0,0.0,...|[2.69001234046578...|[0.93643471623189...|       0.0|
|     8|   583215|[3750.0,1.0,0.0,1...|[2.69016170680037...|[0.93644360664433...|       0.0|
|     8|   275819|[3280.0,1.0,0.0,1...|[2.69016605691669...|[0.93644386554961...|       0.0|
|     8|   401433|[1200.0,1.0,0.0,1...|[2.69018530849532...|[0.93644501133142...|       0.0|
|     8|    29466|[640.0,1.0,0.0,1....|[2.69019049161265...|[0.93644531980785...|       0.0|
|     8|   173327|[356.0,1.0,0.0,1....|[2.69019312019358...|[0.93644547624893...|       0.0|
|     8|   241402|[269.0,1.0,0.0,1....|[2.69019392542787...|[0.93644552417271...|       0.0|
|     8|   351366|[246.0,1.0,0.0,1....|[2.69019413830591...|[0.93644553684221...|       0.0|
|     8|   229827|[238.0,1.0,0.0,1....|[2.69019421235044...|[0.93644554124900...|       0.0|
|     8|   164807|[228.0,1.0,0.0,1....|[2.69019430490611...|[0.93644554675747...|       0.0|
|     8|   227731|[199.0,1.0,0.0,1....|[2.69019457331754...|[0.93644556273205...|       0.0|
|     8|   265403|[198.0,1.0,0.0,1....|[2.69019458257311...|[0.93644556328290...|       0.0|
|     8|   569939|[188.0,1.0,0.0,1....|[2.69019467512877...|[0.93644556879138...|       0.0|
|     8|   277335|[181.5,1.0,0.0,1....|[2.69019473528996...|[0.93644557237189...|       0.0|
|     8|   575633|[180.0,1.0,0.0,1....|[2.69019474917331...|[0.93644557319816...|       0.0|
|     8|   201867|[179.0,1.0,0.0,1....|[2.69019475842887...|[0.93644557374900...|       0.0|
|     8|    25542|[176.0,1.0,0.0,1....|[2.69019478619557...|[0.93644557540155...|       0.0|
|     8|   133457|[169.0,1.0,0.0,1....|[2.69019485098454...|[0.93644557925748...|       0.0|
|     8|   494224|[169.0,1.0,0.0,1....|[2.69019485098454...|[0.93644557925748...|       0.0|
|     8|   339382|[163.0,1.0,0.0,1....|[2.69019490651794...|[0.93644558256256...|       0.0|
+------+---------+--------------------+--------------------+--------------------+----------+
only showing top 20 rows
           
  • TOP-20
# TOP-20
prediction.select("adgroupId").head(20)
           

顯示結果:

[Row(adgroupId=631204),
 Row(adgroupId=583215),
 Row(adgroupId=275819),
 Row(adgroupId=401433),
 Row(adgroupId=29466),
 Row(adgroupId=173327),
 Row(adgroupId=241402),
 Row(adgroupId=351366),
 Row(adgroupId=229827),
 Row(adgroupId=164807),
 Row(adgroupId=227731),
 Row(adgroupId=265403),
 Row(adgroupId=569939),
 Row(adgroupId=277335),
 Row(adgroupId=575633),
 Row(adgroupId=201867),
 Row(adgroupId=25542),
 Row(adgroupId=133457),
 Row(adgroupId=494224),
 Row(adgroupId=339382)]
           
[i.adgroupId for i in prediction.select("adgroupId").head(20)]
           

顯示結果:

[631204,
 583215,
 275819,
 401433,
 29466,
 173327,
 241402,
 351366,
 229827,
 164807,
 227731,
 265403,
 569939,
 277335,
 575633,
 201867,
 25542,
 133457,
 494224,
 339382]