Hey guys, I'm running into some pretty bad performance issues when using CrossValidator, due to the caching behavior of DataFrames.
The root of the problem is that while I have cached the DataFrame holding my features and labels, the cache lives at the DataFrame level, whereas CrossValidator/LogisticRegression both drop down to dataset.rdd, which ignores the caching I did earlier. This is made worse by the fact that, for each combination of a fold and a param set from the grid, my entire input dataset gets recomputed because the cache is bypassed.

My current workaround is to force the input DataFrame to be built on top of a cached RDD before I pass it into CrossValidator.fit(). I did this with the following horrible hack (I had to drop down to Java from PySpark because of an issue with vector types not being inferred correctly):

    from pyspark.sql import DataFrame

    def checkpoint_dataframe_caching(df):
        # Rebuild the DataFrame on top of its own cached (Java) RDD,
        # preserving the original schema.
        return DataFrame(
            sqlContext._ssql_ctx.createDataFrame(df._jdf.rdd().cache(),
                                                 df._jdf.schema()),
            sqlContext)

With this, I still cache the underlying RDD once more than necessary (in addition to DataFrame#cache()), but at least cross-validation no longer recomputes the RDD graph. Note that input_df.rdd.cache() doesn't work, because the Python CrossValidator implementation applies further DataFrame transformations (e.g. filter), after which filtered_df.rdd returns a completely different RDD that recomputes the entire graph.

Is it the intention of Spark SQL that calling DataFrame#rdd discards any caching that was done for the query? Is the fix as simple as having DataFrame#rdd reference the cached query, or is there something more subtle going on?

Best,
Justin
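P.S. For anyone who wants to reproduce this, here is a minimal sketch of the pattern I'm describing, assuming a 1.x-style sqlContext is in scope (the toy data and column names are just for illustration; whether the cached plan should be reused on conversion is exactly my question above):

    # Cache at the DataFrame level, then convert to an RDD the way
    # CrossValidator does internally.
    df = sqlContext.createDataFrame([(0, 1.0), (1, 2.0)], ["label", "x"])
    df.cache()
    df.count()  # materialize the DataFrame cache

    # Each .rdd call plans a fresh conversion; after a filter (which the
    # Python CrossValidator applies per fold), the resulting RDD is a
    # different object whose lineage runs back to the original source.
    direct_rdd = df.rdd
    per_fold_rdd = df.filter(df.label >= 0).rdd
    print(per_fold_rdd.toDebugString())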
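And here's roughly how I wire the workaround into cross-validation (the estimator, grid, and evaluator below are placeholders, not my real pipeline; train_data is my cached features-and-labels DataFrame):

    from pyspark.ml.classification import LogisticRegression
    from pyspark.ml.evaluation import BinaryClassificationEvaluator
    from pyspark.ml.tuning import CrossValidator, ParamGridBuilder

    lr = LogisticRegression()
    grid = ParamGridBuilder().addGrid(lr.regParam, [0.01, 0.1]).build()
    cv = CrossValidator(estimator=lr, estimatorParamMaps=grid,
                        evaluator=BinaryClassificationEvaluator(),
                        numFolds=3)

    # Re-root the DataFrame on a cached RDD before handing it to fit(),
    # so the per-fold filter -> .rdd conversions stop recomputing the input.
    model = cv.fit(checkpoint_dataframe_caching(train_data))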