import argparse

from pyspark.ml import Pipeline, PipelineModel
from pyspark.ml.evaluation import BinaryClassificationEvaluator, MulticlassClassificationEvaluator
from pyspark.ml.feature import OneHotEncoder, VectorAssembler
from pyspark.mllib.evaluation import BinaryClassificationMetrics
from pyspark.sql import DataFrame, SparkSession
from pyspark.sql.functions import coalesce, col, lit
from pyspark.sql.utils import AnalysisException
from xgboost.spark import SparkXGBClassifier
9
10
if __name__ == "__main__":
    # Baseline step: fit a distributed XGBoost classifier and print AUC-PR and
    # F-beta scores computed on the model's own predictions.
    parser = argparse.ArgumentParser(description="Baseline model fit and eval step")
    parser.add_argument(
        "--maxdepth",
        type=int,
        help="Decide which featureframe to run model fit on based on --maxdepth"
        + " used in the automatic feature engineering step",
    )
    args = parser.parse_args()
    # NOTE(review): maxdepth is parsed but never used below. Per the help text
    # it is meant to select the featureframe to load -- wire it into the
    # (currently missing) data-loading step.
    maxdepth = args.maxdepth

    # Start or fetch active Spark session
    spark = SparkSession.builder.getOrCreate()

    xgb_classifier = SparkXGBClassifier(
        max_depth=10,
        missing=0.0,
        # XGBoost has no `n_trees` parameter; the number of boosting rounds is
        # controlled by `n_estimators`.
        n_estimators=10,
        # NOTE(review): using the label as the instance weight gives every
        # label-0 row a weight of 0 -- confirm this is intended.
        weight_col="label",
        validation_indicator_col="is_validation",
        early_stopping_rounds=1,
        eval_metric="aucpr",
        num_workers=36,
        label_col="label",
        features_col="features",
    )

    # NOTE(review): train_df is not defined anywhere in the visible source. It
    # must be loaded before this point and needs the "features", "label" and
    # "is_validation" columns referenced by the classifier above.
    model = xgb_classifier.fit(train_df)

    # Score once and reuse the result for both the preview and every metric.
    # NOTE(review): metrics are computed on the same frame the model was fit on.
    predictions = model.transform(train_df)
    predictions.select("prediction", "probability").show(truncate=False)

    binaryEval = BinaryClassificationEvaluator(labelCol="label")
    multiEval = MulticlassClassificationEvaluator(labelCol="label", predictionCol="prediction")

    def _f_measure(beta):
        """Return the fMeasureByLabel score of `predictions` for the given beta."""
        # NOTE(review): metricLabel is left at the evaluator's default, which
        # the Spark docs give as 0.0 -- i.e. the F-measure of class 0. Set
        # multiEval.metricLabel explicitly if class 1 is the class of interest.
        return multiEval.evaluate(
            predictions,
            {multiEval.metricName: "fMeasureByLabel", multiEval.beta: beta},
        )

    # AUC-PR is computed from the raw scores (the evaluator's default
    # rawPredictionCol) rather than from hard 0/1 predictions, which would
    # collapse the precision-recall curve to a single operating point.
    aucpr = binaryEval.evaluate(predictions, {binaryEval.metricName: "areaUnderPR"})
    fbeta2 = _f_measure(2.0)
    fbeta1 = _f_measure(1.0)
    fbeta05 = _f_measure(0.5)

    print("AUC-PR: %f" % aucpr)
    print("F-beta(2): %f" % fbeta2)
    print("F-beta(1): %f" % fbeta1)
    print("F-beta(0.5): %f" % fbeta05)