import org.apache.spark.sql.{SparkSession}
//action: userid ~ docid ~ behavior(label) ~ time ~ ip
//160520092238579653~160704235940001~0~20160705000040909~1.49.185.165
//160520092238579653~160704235859003~0~20160705000040909~1.49.185.165
// One user-action record parsed from the action file: the docid of the viewed
// document and its click label (per the sample lines above; 0 = not clicked).
case class Action(docid: String, label:Int)
//document:docid ~ channelname ~ source ~ keyword:score
//160705131650005~科技~偏执电商~支付宝:0.17621 医疗:0.14105 复星:0.07106 动作:0.05235 邮局:0.04428
//160705024106002~体育~平大爷的刺~阿杜:0.23158 杜兰特:0.09447 巨头:0.08470 拯救者:0.06638 勇士:0.05453
// One document record parsed from the document file: docid, channel name,
// source, and tags (tag names joined with '|', scores stripped).
// NOTE(review): "Dccument" and "channal" are misspellings of "Document" and
// "channel"; kept as-is because renaming would break the references below.
case class Dccument(docid: String, channal: String, source: String, tags: String)
object GenTrainingData {
  /** Joins user-action data with document data and writes LIBSVM-formatted
   *  training/test files (70/30 random split) for the click-prediction model.
   *
   *  Expected args: actionPath documentPath outputDir sparkMaster
   *  e.g. 2rd_data/ch09/action.txt 2rd_data/ch09/document.txt output/ch11 local[2]
   */
  def main(args: Array[String]): Unit = {
    val Array(actionPath, documentPath, output, mode) = args
    // Create the Spark session.
    val spark = SparkSession.builder
      .master(mode)
      .appName(this.getClass.getName)
      .getOrCreate()
    import spark.implicits._

    // Action line: userid~docid~behavior(label)~time~ip; keep docid + label only.
    // The length guard (missing in the original) prevents a short/malformed line
    // from throwing ArrayIndexOutOfBoundsException.
    val actionDF = spark.sparkContext.textFile(actionPath)
      .map(_.split("~"))
      .filter(_.length > 2)
      .map(x => Action(x(1).trim, x(2).trim.toInt))
      .toDF()

    // Document line: docid~channelname~source~"tag:score tag:score ...";
    // keep only the tag names, joined with '|'.
    val documentDF = spark.sparkContext.textFile(documentPath)
      .map(_.split("~"))
      .filter(_.length > 3)
      .map { x =>
        val xtags = x(3).split(" ")
          .filter(_.nonEmpty)
          .map(b => b.substring(0, b.indexOf(":")))
          .mkString("|")
        Dccument(x(0).trim, x(1).trim, x(2).trim, xtags)
      }
      .toDF()

    // Inner join on docid: one row per matching (action, document) pair.
    val tempDF = documentDF.join(actionDF, documentDF("docid").equalTo(actionDF("docid")))

    // Build the tag -> feature-index dictionary. `sorted` makes the mapping
    // deterministic across runs (RDD.distinct gives no ordering guarantee).
    val tagIndex: Map[String, Int] = tempDF.select($"tags").rdd
      .flatMap(row => row.toString.replace("[", "").replace("]", "").split('|'))
      .distinct()
      .collect()
      .sorted
      .zipWithIndex
      .toMap
    println(tagIndex.size) // number of distinct tags (~23937)

    // Convert each joined row to a LIBSVM line: "<label> i1:1 i2:1 ...".
    // BUG FIX vs. original: unknown tags were mapped to feature index 0 via
    // indexes.get(v).getOrElse(-1)+1, and repeated tags produced duplicate
    // indices — both are rejected by MLUtils.loadLibSVMFile, which requires
    // one-based, strictly ascending indices. Unknown tags are now dropped and
    // indices de-duplicated before sorting.
    val libsvmDF = tempDF.select($"label", $"tags").map { row =>
      val label = row(0).toString
      val terms = row(1).toString.replace("[", "").replace("]", "")
        .split('|')                        // note: single-quoted Char, not a regex
        .flatMap(tag => tagIndex.get(tag)) // drop tags missing from the dictionary
        .map(_ + 1)                        // LIBSVM feature indices are one-based
        .distinct                          // LIBSVM forbids repeated indices
        .sorted                            // LIBSVM requires ascending order
        .map(i => i + ":" + 1)
        .mkString(" ")
      label + " " + terms
    }
    libsvmDF.show(100)

    // 70/30 train/test split, coalesced to a single part file each.
    // Known issue (harmless): "Exception while deleting local spark dir" may be
    // logged on shutdown; the output files are still written correctly.
    val Array(trainingData, testData) = libsvmDF.randomSplit(Array(0.7, 0.3))
    trainingData.rdd.coalesce(1).saveAsTextFile(output + "/training")
    testData.rdd.coalesce(1).saveAsTextFile(output + "/test")

    spark.stop() // release the session (was commented out in the original)
  }
}
import org.apache.spark.SparkContext
import org.apache.spark.SparkConf
import org.apache.spark.mllib.classification.LogisticRegressionWithLBFGS
import org.apache.spark.mllib.util.MLUtils
import org.apache.spark.mllib.evaluation.BinaryClassificationMetrics
object LRTrainAndTest {
  /** Trains a logistic-regression (L-BFGS) click model on LIBSVM data and
   *  reports AUC on the test set; raw scores are saved to the output path.
   *
   *  Usage: LRTrainAndTest trainingPath testPath output numFeatures
   *         partitions regParam numIterations numCorrections
   *  e.g.   2rd_data/ch11/training/part-00000 2rd_data/ch11/test/part-00000
   *         output/ch11/label 23937 50 0.01 100 10
   *  (args(0) is the TRAINING file, args(1) the TEST file — the original
   *  example comment listed them in the wrong order.)
   */
  def main(args: Array[String]): Unit = {
    if (args.length < 8) {
      System.err.println("Usage:LRTrainAndTest <trainingPath> <testPath> <output> <numFeatures> <partitions> <RegParam> <NumIterations> <NumCorrections>")
      System.exit(1)
    }
    val conf = new SparkConf()
      .setMaster("local")
      .setAppName("ADTest with logistic regression")
    val sc = new SparkContext(conf)

    val numFeatures = args(3).toInt // feature dimension, e.g. 23937
    val partitions  = args(4).toInt // typically 50-1000

    // Each line: "<label> i1:v1 i2:v2 ..." — label 1 = clicked, 0 = not clicked.
    // Example source row: 1 娱乐 腾讯娱乐 曲妖精|棉袄|王子文|老大爷|黑色
    val training = MLUtils.loadLibSVMFile(sc, args(0), numFeatures, partitions)
    val test     = MLUtils.loadLibSVMFile(sc, args(1), numFeatures, partitions)

    // Optimizer hyper-parameters.
    val lr = new LogisticRegressionWithLBFGS()
    lr.optimizer
      .setRegParam(args(5).toDouble)    // e.g. 0.01
      .setNumIterations(args(6).toInt)  // e.g. 100
      .setNumCorrections(args(7).toInt) // e.g. 10

    // Binary classification; clearThreshold() makes predict() return raw
    // scores instead of 0/1 labels, which BinaryClassificationMetrics needs
    // to compute a meaningful ROC curve.
    val lrModel = lr.setNumClasses(2).run(training)
    lrModel.clearThreshold()

    // Score the test set. cache() added: the RDD is consumed twice below
    // (saveAsTextFile + AUC); without it the model would re-score every row
    // on the second pass.
    val predictionAndLabel = test.map(p => (lrModel.predict(p.features), p.label)).cache()
    predictionAndLabel.map(x => x._1 + "\t" + x._2).repartition(1)
      .saveAsTextFile(args(2))

    // Compute and report AUC.
    val metrics = new BinaryClassificationMetrics(predictionAndLabel)
    val str = s"the value of auc is ${metrics.areaUnderROC()}"
    println(str)

    sc.stop() // release the context (was missing in the original)
  }
}