天天看点

Spark MLlib ALS 电影推荐

package zqr.com;

import org.apache.spark.SparkConf;
import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;

import org.apache.spark.api.java.function.PairFunction;
import org.apache.spark.mllib.recommendation.ALS;
import org.apache.spark.mllib.recommendation.MatrixFactorizationModel;
import org.apache.spark.mllib.recommendation.Rating;
import scala.Tuple2;

import java.util.*;

public class MovieRecommended {

    /** Ratings CSV used when no path is given; columns 0-2 are user id, movie id, rating. */
    private static final String DEFAULT_RATINGS_PATH =
            "/usr/local/spark/testdata/ml-latest-small/ratings.csv";

    /** Movies CSV used when no second path is given; column 0 is the movie id, the title follows. */
    private static final String DEFAULT_MOVIES_PATH =
            "/usr/local/spark/testdata/ml-latest-small/movies.csv";

    /** Number of latent factors in the ALS factorization. */
    private static final int RANK = 10;

    /** Number of ALS iterations. */
    private static final int NUM_ITERATIONS = 10;

    /** ALS regularization parameter (lambda). */
    private static final double LAMBDA = 0.01;

    /** How many top-scored movies to print. */
    private static final int TOP_N = 10;

    /** Lowest accepted rating (inclusive). */
    private static final double MIN_RATING = 0.0;

    /** Highest accepted rating (inclusive) -- 5.0 ratings must be kept. */
    private static final double MAX_RATING = 5.0;

    /**
     * Reads a user id from stdin, trains an ALS model on the ratings file, prints the
     * {@value #TOP_N} highest predicted (movieId, score) pairs among the movies that user has
     * rated, and then prints the titles of those movies in the same order.
     *
     * @param args optional overrides: {@code args[0]} = ratings CSV path,
     *     {@code args[1]} = movies CSV path; defaults are used when absent
     */
    public static void main(String[] args) {
        // The Scanner is deliberately not closed: closing it would also close System.in.
        OptionalInt maybeUser = readUserId(new Scanner(System.in));
        if (!maybeUser.isPresent()) {
            return;
        }
        int userId = maybeUser.getAsInt();
        String ratingsPath = args.length > 0 ? args[0] : DEFAULT_RATINGS_PATH;
        String moviesPath = args.length > 1 ? args[1] : DEFAULT_MOVIES_PATH;

        SparkConf conf = new SparkConf().setAppName("movie recommended system").setMaster("local");

        // try-with-resources stops the Spark context on every exit path (the original never did).
        try (JavaSparkContext sc = new JavaSparkContext(conf)) {
            JavaRDD<Rating> ratings = loadRatings(sc, ratingsPath);
            MatrixFactorizationModel model =
                    ALS.train(JavaRDD.toRDD(ratings), RANK, NUM_ITERATIONS, LAMBDA);

            List<Rating> top = topPredictions(model, ratings, userId, TOP_N);
            if (top.isEmpty()) {
                System.out.println("没有该用户的评分记录: " + userId);
                return;
            }

            List<Integer> topIds = new ArrayList<Integer>(top.size());
            for (Rating r : top) {
                System.out.println(r.product() + "      " + r.rating());
                topIds.add(r.product());
            }

            Map<Integer, String> titles = loadTitles(sc, moviesPath, topIds);
            for (Integer id : topIds) {
                String title = titles.get(id);
                if (title != null) {
                    System.out.println(title);
                }
            }
        }
    }

    /**
     * Prompts for and reads one line from {@code in}, interpreting it as an integer user id.
     *
     * @param in source of the user's input
     * @return the parsed id, or empty (after printing a message) if the line is not an integer
     */
    private static OptionalInt readUserId(Scanner in) {
        System.out.println("输入用户id:");
        String line = in.nextLine().trim();
        try {
            return OptionalInt.of(Integer.parseInt(line));
        } catch (NumberFormatException e) {
            System.out.println("用户id必须是整数: " + line);
            return OptionalInt.empty();
        }
    }

    /**
     * Loads the ratings CSV as an RDD of {@link Rating}, dropping every line that
     * {@link #parseRating} rejects (header, malformed, out-of-range).
     *
     * @param sc   active Spark context
     * @param path ratings CSV path
     * @return the parsed ratings, cached because they are scanned by ALS training and again
     *     when building the prediction input
     */
    private static JavaRDD<Rating> loadRatings(JavaSparkContext sc, String path) {
        return sc.textFile(path)
                .map(MovieRecommended::parseRating)
                .filter(r -> r != null)
                .cache();
    }

    /**
     * Parses one ratings line ({@code user,movie,rating,...}) into a {@link Rating}.
     *
     * @param line raw CSV line
     * @return the rating, or {@code null} for the {@code userId} header, a line with fewer than
     *     three fields, a non-numeric field, or a score outside [MIN_RATING, MAX_RATING]
     */
    private static Rating parseRating(String line) {
        String[] fields = line.split(",");
        if (fields.length < 3 || "userId".equals(fields[0])) {
            return null;
        }
        try {
            double score = Double.parseDouble(fields[2].trim());
            // Negated range check so that NaN is rejected as well.
            if (!(score >= MIN_RATING && score <= MAX_RATING)) {
                return null;
            }
            return new Rating(
                    Integer.parseInt(fields[0].trim()),
                    Integer.parseInt(fields[1].trim()),
                    score);
        } catch (NumberFormatException e) {
            // Skip a malformed line instead of failing the whole Spark job.
            return null;
        }
    }

    /**
     * Scores every movie the given user has rated and returns the best {@code n}.
     *
     * <p>NOTE(review): only pairs already present in {@code ratings} are scored, so these are
     * predictions for movies the user has already seen. For recommendations of unseen movies,
     * consider {@code model.recommendProducts(userId, n)} instead.
     *
     * @param model   trained ALS model
     * @param ratings parsed ratings
     * @param userId  user whose rated movies are scored
     * @param n       maximum number of results
     * @return up to {@code n} predictions, highest predicted score first; empty if the user has
     *     no ratings
     */
    private static List<Rating> topPredictions(
            MatrixFactorizationModel model, JavaRDD<Rating> ratings, final int userId, int n) {
        // Filter on the cluster so only this user's pairs are predicted and collected,
        // instead of every pair in the dataset.
        JavaRDD<Tuple2<Object, Object>> userProducts = ratings
                .filter(r -> r.user() == userId)
                .map(r -> new Tuple2<Object, Object>(r.user(), r.product()));

        List<Rating> predicted = new ArrayList<Rating>(
                model.predict(JavaRDD.toRDD(userProducts)).toJavaRDD().collect());
        predicted.sort((a, b) -> Double.compare(b.rating(), a.rating()));
        return predicted.subList(0, Math.min(n, predicted.size()));
    }

    /**
     * Looks up the titles of the given movie ids in the movies CSV using a single Spark job.
     *
     * <p>Results are collected to the driver rather than printed inside {@code foreach}: the
     * latter runs on executors, so its output would not reach the driver console outside local
     * mode, and it would launch one job per id.
     *
     * @param sc   active Spark context
     * @param path movies CSV path
     * @param ids  movie ids to resolve
     * @return id -> title for every id found in the file
     */
    private static Map<Integer, String> loadTitles(
            JavaSparkContext sc, String path, Collection<Integer> ids) {
        // HashSet is Serializable, so it can be shipped inside the filter closure.
        final Set<Integer> wanted = new HashSet<Integer>(ids);
        List<Tuple2<Integer, String>> rows = sc.textFile(path)
                .map(MovieRecommended::parseMovie)
                .filter(m -> m != null && wanted.contains(m._1()))
                .collect();

        Map<Integer, String> titles = new HashMap<Integer, String>();
        for (Tuple2<Integer, String> row : rows) {
            titles.put(row._1(), row._2());
        }
        return titles;
    }

    /**
     * Parses one movies line into (movieId, title).
     *
     * <p>The title is everything between the first and the last comma. This keeps titles that
     * themselves contain commas (CSV-quoted, e.g. {@code 1,"Title, The (1995)",Genre}) intact.
     * It assumes the MovieLens layout {@code movieId,title,genres}, where the final column
     * contains no commas -- TODO confirm for other datasets. Surrounding quotes are removed and
     * doubled quotes are unescaped.
     *
     * @param line raw CSV line
     * @return the pair, or {@code null} for the header or a malformed line
     */
    private static Tuple2<Integer, String> parseMovie(String line) {
        int firstComma = line.indexOf(',');
        int lastComma = line.lastIndexOf(',');
        if (firstComma < 0 || lastComma <= firstComma) {
            return null;
        }
        int id;
        try {
            id = Integer.parseInt(line.substring(0, firstComma).trim());
        } catch (NumberFormatException e) {
            // Header line ("movieId,...") or a corrupt id.
            return null;
        }
        String title = line.substring(firstComma + 1, lastComma).trim();
        if (title.length() >= 2 && title.startsWith("\"") && title.endsWith("\"")) {
            title = title.substring(1, title.length() - 1).replace("\"\"", "\"");
        }
        return new Tuple2<Integer, String>(id, title);
    }
}
      

继续阅读