Hi,

I have a web service that provides a REST API for training a random forest model.
I train the random forest on a 5-node Spark cluster with enough memory:
everything is cached (~22 GB).
On smaller datasets (up to 100k samples) everything works fine, but with the
biggest one (400k samples and ~70k features) training fails with a
StackOverflowError.

Additional Spark options for my web service:
    spark.executor.extraJavaOptions="-XX:ThreadStackSize=8192"
    spark.default.parallelism=200

On the 400k-sample dataset:
- with the default thread stack size, it took 4 hours of training to hit the
error;
- with the increased stack size, it took 60 hours to hit it.
I could increase the stack size further, but it's hard to say how much the job
needs, and the setting applies to all threads, so it might waste a lot of
memory.
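To put a rough number on that trade-off: on HotSpot, -XX:ThreadStackSize is given in kilobytes, so 8192 means 8 MB of stack per thread the executor JVM creates. The sketch below only does that multiplication; the thread count of 200 is a hypothetical figure, not something measured on this cluster. On most platforms a thread stack is reserved address space that is committed only as it is actually touched, so treat the result as an upper bound rather than resident memory.

```java
public class StackBudget {
    // HotSpot interprets -XX:ThreadStackSize in kilobytes (8192 -> 8 MB per thread).
    static long reservedStackMb(long threadStackSizeKb, int threadCount) {
        return threadStackSizeKb * threadCount / 1024;
    }

    public static void main(String[] args) {
        int hypotheticalThreadCount = 200; // assumption, not measured
        long mb = reservedStackMb(8192, hypotheticalThreadCount);
        System.out.println("Upper bound on reserved stack: " + mb + " MB"); // 1600 MB
    }
}
```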

Looking at the different stages in the event timeline, I see that task
deserialization time gradually increases. By the end, task deserialization
time is roughly the same as executor computing time.

The code I use to train the model:

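import org.apache.spark.mllib.tree.RandomForest;
import org.apache.spark.mllib.tree.configuration.Algo;
import org.apache.spark.mllib.tree.configuration.QuantileStrategy;
import org.apache.spark.mllib.tree.configuration.Strategy;
import org.apache.spark.mllib.tree.impurity.Variance;
import org.apache.spark.mllib.tree.model.RandomForestModel;

// labeledPoints is a JavaRDD<LabeledPoint> built elsewhere in the service.
int MAX_BINS = 16;
int NUM_CLASSES = 0;            // ignored for regression
double MIN_INFO_GAIN = 0.0;
int MAX_MEMORY_IN_MB = 256;
double SUBSAMPLING_RATE = 1.0;
boolean USE_NODEID_CACHE = true;
// Checkpoint the node-ID cache every 10 iterations; this only takes effect
// if a checkpoint directory has been set on the SparkContext.
int CHECKPOINT_INTERVAL = 10;
int RANDOM_SEED = 12345;

int NODE_SIZE = 5;              // minimum instances per node
int maxDepth = 30;
int numTrees = 50;

Strategy strategy = new Strategy(
        Algo.Regression(), Variance.instance(), maxDepth, NUM_CLASSES, MAX_BINS,
        QuantileStrategy.Sort(),
        new scala.collection.immutable.HashMap<>(), // no categorical features
        NODE_SIZE, MIN_INFO_GAIN, MAX_MEMORY_IN_MB, SUBSAMPLING_RATE,
        USE_NODEID_CACHE, CHECKPOINT_INTERVAL);
// "auto": with more than one tree, regression uses one third of the features per node.
RandomForestModel model = RandomForest.trainRegressor(
        labeledPoints.rdd(), strategy, numTrees, "auto", RANDOM_SEED);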


Any advice would be highly appreciated.

The exception (~3000 lines long):
 java.lang.StackOverflowError
        at java.io.ObjectInputStream$PeekInputStream.read(ObjectInputStream.java:2320)
        at java.io.ObjectInputStream$PeekInputStream.readFully(ObjectInputStream.java:2333)
        at java.io.ObjectInputStream$BlockDataInputStream.readInt(ObjectInputStream.java:2828)
        at java.io.ObjectInputStream.readHandle(ObjectInputStream.java:1453)
        at java.io.ObjectInputStream.readClassDesc(ObjectInputStream.java:1512)
        at java.io.ObjectInputStream.readOrdinaryObject(ObjectInputStream.java:1774)
        at java.io.ObjectInputStream.readObject0(ObjectInputStream.java:1351)
        at java.io.ObjectInputStream.defaultReadFields(ObjectInputStream.java:2000)
        at java.io.ObjectInputStream.readSerialData(ObjectInputStream.java:1924)
        at java.io.ObjectInputStream.readOrdinaryObject(ObjectInputStream.java:1801)
        at java.io.ObjectInputStream.readObject0(ObjectInputStream.java:1351)
        at java.io.ObjectInputStream.readObject(ObjectInputStream.java:371)
        at scala.collection.immutable.$colon$colon.readObject(List.scala:366)
        at sun.reflect.GeneratedMethodAccessor3.invoke(Unknown Source)
        at sun.reflect.DelegatingMethodAccessorImpl.invoke(DelegatingMethodAccessorImpl.java:43)
        at java.lang.reflect.Method.invoke(Method.java:497)
        at java.io.ObjectStreamClass.invokeReadObject(ObjectStreamClass.java:1058)
        at java.io.ObjectInputStream.readSerialData(ObjectInputStream.java:1900)
        at java.io.ObjectInputStream.readOrdinaryObject(ObjectInputStream.java:1801)
        at java.io.ObjectInputStream.readObject0(ObjectInputStream.java:1351)
        at java.io.ObjectInputStream.defaultReadFields(ObjectInputStream.java:2000)
        at java.io.ObjectInputStream.readSerialData(ObjectInputStream.java:1924)
        at java.io.ObjectInputStream.readOrdinaryObject(ObjectInputStream.java:1801)
        at java.io.ObjectInputStream.readObject0(ObjectInputStream.java:1351)
        at java.io.ObjectInputStream.defaultReadFields(ObjectInputStream.java:2000)
        at java.io.ObjectInputStream.readSerialData(ObjectInputStream.java:1924)
        at java.io.ObjectInputStream.readOrdinaryObject(ObjectInputStream.java:1801)
        at java.io.ObjectInputStream.readObject0(ObjectInputStream.java:1351)
        at java.io.ObjectInputStream.readObject(ObjectInputStream.java:371)
        at scala.collection.immutable.$colon$colon.readObject(List.scala:362)
        at sun.reflect.GeneratedMethodAccessor3.invoke(Unknown Source)
        at sun.reflect.DelegatingMethodAccessorImpl.invoke(DelegatingMethodAccessorImpl.java:43)
        at java.lang.reflect.Method.invoke(Method.java:497)
        at java.io.ObjectStreamClass.invokeReadObject(ObjectStreamClass.java:1058)
        at java.io.ObjectInputStream.readSerialData(ObjectInputStream.java:1900)
        at java.io.ObjectInputStream.readOrdinaryObject(ObjectInputStream.java:1801)
        at java.io.ObjectInputStream.readObject0(ObjectInputStream.java:1351)
        at java.io.ObjectInputStream.defaultReadFields(ObjectInputStream.java:2000)
        at java.io.ObjectInputStream.readSerialData(ObjectInputStream.java:1924)
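The repeating readObject0 / defaultReadFields / $colon$colon.readObject frames in the trace are Java deserialization recursing once per level of nesting in the object graph being read. The plain-Java sketch below (no Spark; the Node class is hypothetical) shows that mechanism in isolation: default serialization of a linked chain needs stack depth proportional to the chain's length, so a shallow chain round-trips fine while a very deep one throws StackOverflowError. In this toy the overflow is hit while writing rather than reading, but both directions recurse the same way. It only illustrates the mechanism, under the assumption that the serialized tasks become more deeply nested from stage to stage (which would also fit the growing deserialization time); it is not a diagnosis of this job.

```java
import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.io.Serializable;

public class DeepGraphDemo {
    // Hypothetical linked node: default serialization writes `next` by recursing
    // into it, so the stack needed grows with the length of the chain.
    static final class Node implements Serializable {
        private static final long serialVersionUID = 1L;
        final Node next;

        Node(Node next) {
            this.next = next;
        }
    }

    // Builds a chain of `depth` nodes iteratively (no recursion here).
    static Node chain(int depth) {
        Node head = null;
        for (int i = 0; i < depth; i++) {
            head = new Node(head);
        }
        return head;
    }

    static int length(Node node) {
        int len = 0;
        for (Node n = node; n != null; n = n.next) {
            len++;
        }
        return len;
    }

    // Serializes a chain to bytes, reads it back and returns the copy's length.
    static int roundTripLength(int depth) {
        try {
            ByteArrayOutputStream bytes = new ByteArrayOutputStream();
            try (ObjectOutputStream out = new ObjectOutputStream(bytes)) {
                out.writeObject(chain(depth));
            }
            try (ObjectInputStream in =
                     new ObjectInputStream(new ByteArrayInputStream(bytes.toByteArray()))) {
                return length((Node) in.readObject());
            }
        } catch (IOException | ClassNotFoundException e) {
            throw new IllegalStateException(e);
        }
    }

    // True if the round trip at this depth exhausts the current thread's stack.
    static boolean overflows(int depth) {
        try {
            roundTripLength(depth);
            return false;
        } catch (StackOverflowError e) {
            return true;
        }
    }

    public static void main(String[] args) {
        System.out.println(roundTripLength(100)); // 100
        System.out.println(overflows(1_000_000)); // true: far deeper than any default stack
    }
}
```

Raising the stack size only moves the depth at which this toy fails; it does not remove the recursion.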

--
Be well!
Jean Morozov
