http://git-wip-us.apache.org/repos/asf/ignite/blob/0234ee3b/modules/ml/src/main/java/org/apache/ignite/ml/nn/trainers/distributed/MLPGroupUpdateTrainingLoopData.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/nn/trainers/distributed/MLPGroupUpdateTrainingLoopData.java b/modules/ml/src/main/java/org/apache/ignite/ml/nn/trainers/distributed/MLPGroupUpdateTrainingLoopData.java deleted file mode 100644 index 0f3d974..0000000 --- a/modules/ml/src/main/java/org/apache/ignite/ml/nn/trainers/distributed/MLPGroupUpdateTrainingLoopData.java +++ /dev/null @@ -1,116 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.ignite.ml.nn.trainers.distributed; - -import java.io.Serializable; -import java.util.List; -import org.apache.ignite.lang.IgniteBiTuple; -import org.apache.ignite.ml.math.Matrix; -import org.apache.ignite.ml.math.Vector; -import org.apache.ignite.ml.math.functions.IgniteDifferentiableVectorToDoubleFunction; -import org.apache.ignite.ml.math.functions.IgniteFunction; -import org.apache.ignite.ml.math.functions.IgniteSupplier; -import org.apache.ignite.ml.nn.MultilayerPerceptron; -import org.apache.ignite.ml.nn.updaters.ParameterUpdateCalculator; -import org.apache.ignite.ml.trainers.group.GroupTrainerCacheKey; - -/** Multilayer perceptron group update training loop data. */ -public class MLPGroupUpdateTrainingLoopData<P> implements Serializable { - /** */ - private final ParameterUpdateCalculator<MultilayerPerceptron, P> updateCalculator; - /** */ - private final int stepsCnt; - /** */ - private final IgniteFunction<List<P>, P> updateReducer; - /** */ - private final P previousUpdate; - /** */ - private final IgniteSupplier<IgniteBiTuple<Matrix, Matrix>> batchSupplier; - /** */ - private final IgniteFunction<Vector, IgniteDifferentiableVectorToDoubleFunction> loss; - /** */ - private final double tolerance; - - /** */ - private final GroupTrainerCacheKey<Void> key; - /** */ - private final MultilayerPerceptron mlp; - - /** Create multilayer perceptron group update training loop data. */ - public MLPGroupUpdateTrainingLoopData(MultilayerPerceptron mlp, - ParameterUpdateCalculator<MultilayerPerceptron, P> updateCalculator, int stepsCnt, - IgniteFunction<List<P>, P> updateReducer, P previousUpdate, - GroupTrainerCacheKey<Void> key, IgniteSupplier<IgniteBiTuple<Matrix, Matrix>> batchSupplier, - IgniteFunction<Vector, IgniteDifferentiableVectorToDoubleFunction> loss, - double tolerance) { - this.mlp = mlp; - this.updateCalculator = updateCalculator; - this.stepsCnt = stepsCnt; - this.updateReducer = updateReducer; - this.previousUpdate = previousUpdate; - this.key = key; - this.batchSupplier = batchSupplier; - this.loss = loss; - this.tolerance = tolerance; - } - - /** Get perceptron. */ - public MultilayerPerceptron mlp() { - return mlp; - } - - /** Get update calculator. */ - public ParameterUpdateCalculator<MultilayerPerceptron, P> updateCalculator() { - return updateCalculator; - } - - /** Get steps count. */ - public int stepsCnt() { - return stepsCnt; - } - - /** Get update reducer. */ - public IgniteFunction<List<P>, P> getUpdateReducer() { - return updateReducer; - } - - /** Get previous update. */ - public P previousUpdate() { - return previousUpdate; - } - - /** Get group trainer cache key. */ - public GroupTrainerCacheKey<Void> key() { - return key; - } - - /** Get batch supplier. */ - public IgniteSupplier<IgniteBiTuple<Matrix, Matrix>> batchSupplier() { - return batchSupplier; - } - - /** Get loss function. */ - public IgniteFunction<Vector, IgniteDifferentiableVectorToDoubleFunction> loss() { - return loss; - } - - /** Get tolerance. */ - public double tolerance() { - return tolerance; - } -}
http://git-wip-us.apache.org/repos/asf/ignite/blob/0234ee3b/modules/ml/src/main/java/org/apache/ignite/ml/nn/trainers/distributed/MLPMetaoptimizer.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/nn/trainers/distributed/MLPMetaoptimizer.java b/modules/ml/src/main/java/org/apache/ignite/ml/nn/trainers/distributed/MLPMetaoptimizer.java deleted file mode 100644 index 249136b..0000000 --- a/modules/ml/src/main/java/org/apache/ignite/ml/nn/trainers/distributed/MLPMetaoptimizer.java +++ /dev/null @@ -1,75 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.ignite.ml.nn.trainers.distributed; - -import java.util.ArrayList; -import java.util.List; -import java.util.stream.Collectors; -import org.apache.ignite.ml.math.functions.IgniteFunction; -import org.apache.ignite.ml.trainers.group.Metaoptimizer; - -/** Meta-optimizer for multilayer perceptron. */ -public class MLPMetaoptimizer<P> implements Metaoptimizer<MLPGroupUpdateTrainerLocalContext, - MLPGroupUpdateTrainingLoopData<P>, P, P, P, ArrayList<P>> { - /** */ - private final IgniteFunction<List<P>, P> allUpdatesReducer; - - /** Construct metaoptimizer. */ - public MLPMetaoptimizer(IgniteFunction<List<P>, P> allUpdatesReducer) { - this.allUpdatesReducer = allUpdatesReducer; - } - - /** {@inheritDoc} */ - @Override public IgniteFunction<List<P>, P> initialReducer() { - return allUpdatesReducer; - } - - /** {@inheritDoc} */ - @Override public P locallyProcessInitData(P data, MLPGroupUpdateTrainerLocalContext locCtx) { - return data; - } - - /** {@inheritDoc} */ - @Override public IgniteFunction<P, ArrayList<P>> distributedPostprocessor() { - return p -> { - ArrayList<P> res = new ArrayList<>(); - res.add(p); - return res; - }; - } - - /** {@inheritDoc} */ - @Override public IgniteFunction<List<ArrayList<P>>, ArrayList<P>> postProcessReducer() { - // Flatten. - return lists -> new ArrayList<>(lists.stream() - .flatMap(List::stream) - .collect(Collectors.toList())); - } - - /** {@inheritDoc} */ - @Override public P localProcessor(ArrayList<P> input, MLPGroupUpdateTrainerLocalContext locCtx) { - locCtx.incrementCurrentStep(); - - return allUpdatesReducer.apply(input); - } - - /** {@inheritDoc} */ - @Override public boolean shouldContinue(P input, MLPGroupUpdateTrainerLocalContext locCtx) { - return input != null && locCtx.currentStep() < locCtx.globalStepsMaxCount(); - } -} http://git-wip-us.apache.org/repos/asf/ignite/blob/0234ee3b/modules/ml/src/main/java/org/apache/ignite/ml/nn/trainers/local/LocalBatchTrainer.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/nn/trainers/local/LocalBatchTrainer.java b/modules/ml/src/main/java/org/apache/ignite/ml/nn/trainers/local/LocalBatchTrainer.java index 8579b82..64a1956 100644 --- a/modules/ml/src/main/java/org/apache/ignite/ml/nn/trainers/local/LocalBatchTrainer.java +++ b/modules/ml/src/main/java/org/apache/ignite/ml/nn/trainers/local/LocalBatchTrainer.java @@ -28,18 +28,19 @@ import org.apache.ignite.ml.math.functions.IgniteFunction; import org.apache.ignite.ml.math.functions.IgniteSupplier; import org.apache.ignite.ml.math.util.MatrixUtil; import org.apache.ignite.ml.nn.LocalBatchTrainerInput; -import org.apache.ignite.ml.nn.updaters.ParameterUpdateCalculator; +import org.apache.ignite.ml.nn.updaters.ParameterUpdater; +import org.apache.ignite.ml.nn.updaters.UpdaterParams; /** * Batch trainer. This trainer is not distributed on the cluster, but input can theoretically read data from * Ignite cache. */ -public class LocalBatchTrainer<M extends Model<Matrix, Matrix>, P> +public class LocalBatchTrainer<M extends Model<Matrix, Matrix>, P extends UpdaterParams<? super M>> implements Trainer<M, LocalBatchTrainerInput<M>> { /** * Supplier for updater function. */ - private final IgniteSupplier<ParameterUpdateCalculator<M, P>> updaterSupplier; + private final IgniteSupplier<ParameterUpdater<? super M, P>> updaterSupplier; /** * Error threshold. @@ -70,7 +71,7 @@ public class LocalBatchTrainer<M extends Model<Matrix, Matrix>, P> * @param maxIterations Maximal iterations count. */ public LocalBatchTrainer(IgniteFunction<Vector, IgniteDifferentiableVectorToDoubleFunction> loss, - IgniteSupplier<ParameterUpdateCalculator<M, P>> updaterSupplier, double errorThreshold, int maxIterations) { + IgniteSupplier<ParameterUpdater<? super M, P>> updaterSupplier, double errorThreshold, int maxIterations) { this.loss = loss; this.updaterSupplier = updaterSupplier; this.errorThreshold = errorThreshold; @@ -83,19 +84,19 @@ public class LocalBatchTrainer<M extends Model<Matrix, Matrix>, P> M mdl = data.mdl(); double err; - ParameterUpdateCalculator<? super M, P> updater = updaterSupplier.get(); + ParameterUpdater<? super M, P> updater = updaterSupplier.get(); P updaterParams = updater.init(mdl, loss); while (i < maxIterations) { - IgniteBiTuple<Matrix, Matrix> batch = data.batchSupplier().get(); + IgniteBiTuple<Matrix, Matrix> batch = data.getBatch(); Matrix input = batch.get1(); Matrix truth = batch.get2(); - updaterParams = updater.calculateNewUpdate(mdl, updaterParams, i, input, truth); + updaterParams = updater.updateParams(mdl, updaterParams, i, input, truth); // Update mdl with updater parameters. - mdl = updater.update(mdl, updaterParams); + mdl = updaterParams.update(mdl); Matrix predicted = mdl.apply(input); @@ -131,7 +132,7 @@ public class LocalBatchTrainer<M extends Model<Matrix, Matrix>, P> * @param updaterSupplier New updater supplier. * @return new trainer with the same parameters as this trainer, but with new updater supplier. */ - public LocalBatchTrainer withUpdater(IgniteSupplier<ParameterUpdateCalculator<M, P>> updaterSupplier) { + public LocalBatchTrainer withUpdater(IgniteSupplier<ParameterUpdater<? super M, P>> updaterSupplier) { return new LocalBatchTrainer<>(loss, updaterSupplier, errorThreshold, maxIterations); } http://git-wip-us.apache.org/repos/asf/ignite/blob/0234ee3b/modules/ml/src/main/java/org/apache/ignite/ml/nn/trainers/local/MLPLocalBatchTrainer.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/nn/trainers/local/MLPLocalBatchTrainer.java b/modules/ml/src/main/java/org/apache/ignite/ml/nn/trainers/local/MLPLocalBatchTrainer.java index 0c92395..7065e2f 100644 --- a/modules/ml/src/main/java/org/apache/ignite/ml/nn/trainers/local/MLPLocalBatchTrainer.java +++ b/modules/ml/src/main/java/org/apache/ignite/ml/nn/trainers/local/MLPLocalBatchTrainer.java @@ -23,16 +23,17 @@ import org.apache.ignite.ml.math.functions.IgniteFunction; import org.apache.ignite.ml.math.functions.IgniteSupplier; import org.apache.ignite.ml.nn.LossFunctions; import org.apache.ignite.ml.nn.MultilayerPerceptron; -import org.apache.ignite.ml.nn.updaters.ParameterUpdateCalculator; -import org.apache.ignite.ml.nn.updaters.RPropParameterUpdate; -import org.apache.ignite.ml.nn.updaters.RPropUpdateCalculator; +import org.apache.ignite.ml.nn.updaters.ParameterUpdater; +import org.apache.ignite.ml.nn.updaters.RPropUpdater; +import org.apache.ignite.ml.nn.updaters.RPropUpdaterParams; +import org.apache.ignite.ml.nn.updaters.UpdaterParams; /** * Local batch trainer for MLP. * * @param <P> Parameter updater parameters. */ -public class MLPLocalBatchTrainer<P> +public class MLPLocalBatchTrainer<P extends UpdaterParams<? super MultilayerPerceptron>> extends LocalBatchTrainer<MultilayerPerceptron, P> { /** * Default loss function. @@ -50,6 +51,7 @@ public class MLPLocalBatchTrainer<P> */ private static final int DEFAULT_MAX_ITERATIONS = 100; + /** * Construct a trainer. * @@ -60,7 +62,7 @@ public class MLPLocalBatchTrainer<P> */ public MLPLocalBatchTrainer( IgniteFunction<Vector, IgniteDifferentiableVectorToDoubleFunction> loss, - IgniteSupplier<ParameterUpdateCalculator<MultilayerPerceptron, P>> updaterSupplier, + IgniteSupplier<ParameterUpdater<? super MultilayerPerceptron, P>> updaterSupplier, double errorThreshold, int maxIterations) { super(loss, updaterSupplier, errorThreshold, maxIterations); } @@ -70,8 +72,7 @@ public class MLPLocalBatchTrainer<P> * * @return MLPLocalBatchTrainer with default parameters. */ - public static MLPLocalBatchTrainer<RPropParameterUpdate> getDefault() { - return new MLPLocalBatchTrainer<>(DEFAULT_LOSS, () -> new RPropUpdateCalculator<>(), DEFAULT_ERROR_THRESHOLD, - DEFAULT_MAX_ITERATIONS); + public static MLPLocalBatchTrainer<RPropUpdaterParams> getDefault() { + return new MLPLocalBatchTrainer<>(DEFAULT_LOSS, RPropUpdater::new, DEFAULT_ERROR_THRESHOLD, DEFAULT_MAX_ITERATIONS); } } http://git-wip-us.apache.org/repos/asf/ignite/blob/0234ee3b/modules/ml/src/main/java/org/apache/ignite/ml/nn/updaters/BaseSmoothParametrized.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/nn/updaters/BaseSmoothParametrized.java b/modules/ml/src/main/java/org/apache/ignite/ml/nn/updaters/BaseSmoothParametrized.java index 8e2f0df..b33c2c7 100644 --- a/modules/ml/src/main/java/org/apache/ignite/ml/nn/updaters/BaseSmoothParametrized.java +++ b/modules/ml/src/main/java/org/apache/ignite/ml/nn/updaters/BaseSmoothParametrized.java @@ -17,7 +17,6 @@ package org.apache.ignite.ml.nn.updaters; -import org.apache.ignite.ml.Model; import org.apache.ignite.ml.math.Matrix; import org.apache.ignite.ml.math.Vector; import org.apache.ignite.ml.math.functions.IgniteDifferentiableVectorToDoubleFunction; @@ -26,7 +25,7 @@ import org.apache.ignite.ml.math.functions.IgniteFunction; /** * Interface for models which are smooth functions of their parameters. */ -interface BaseSmoothParametrized<M extends BaseSmoothParametrized<M> & Model<Matrix, Matrix>> { +interface BaseSmoothParametrized<M extends BaseSmoothParametrized<M>> { /** * Compose function in the following way: feed output of this model as input to second argument to loss function. * After that we have a function g of three arguments: input, ground truth, parameters. @@ -40,8 +39,7 @@ interface BaseSmoothParametrized<M extends BaseSmoothParametrized<M> & Model<Mat * @param truthBatch Batch of ground truths. * @return Gradient of h at current point in parameters space. */ - Vector differentiateByParameters(IgniteFunction<Vector, IgniteDifferentiableVectorToDoubleFunction> loss, - Matrix inputsBatch, Matrix truthBatch); + Vector differentiateByParameters(IgniteFunction<Vector, IgniteDifferentiableVectorToDoubleFunction> loss, Matrix inputsBatch, Matrix truthBatch); /** * Get parameters vector. http://git-wip-us.apache.org/repos/asf/ignite/blob/0234ee3b/modules/ml/src/main/java/org/apache/ignite/ml/nn/updaters/NesterovParameterUpdate.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/nn/updaters/NesterovParameterUpdate.java b/modules/ml/src/main/java/org/apache/ignite/ml/nn/updaters/NesterovParameterUpdate.java deleted file mode 100644 index 8671285..0000000 --- a/modules/ml/src/main/java/org/apache/ignite/ml/nn/updaters/NesterovParameterUpdate.java +++ /dev/null @@ -1,94 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.ignite.ml.nn.updaters; - -import java.io.Serializable; -import java.util.List; -import java.util.Objects; -import org.apache.ignite.ml.math.Vector; -import org.apache.ignite.ml.math.impls.vector.DenseLocalOnHeapVector; - -/** - * Data needed for Nesterov parameters updater. - */ -public class NesterovParameterUpdate implements Serializable { - /** - * Previous step weights updates. - */ - protected Vector prevIterationUpdates; - - /** - * Construct NesterovParameterUpdate. - * - * @param paramsCnt Count of parameters on which updateCache happens. - */ - public NesterovParameterUpdate(int paramsCnt) { - prevIterationUpdates = new DenseLocalOnHeapVector(paramsCnt).assign(0); - } - - /** - * Construct NesterovParameterUpdate. - * - * @param prevIterationUpdates Previous iteration updates. - */ - public NesterovParameterUpdate(Vector prevIterationUpdates) { - this.prevIterationUpdates = prevIterationUpdates; - } - - /** - * Set previous step parameters updates. - * - * @param updates Parameters updates. - * @return This object with updated parameters updates. - */ - public NesterovParameterUpdate setPreviousUpdates(Vector updates) { - prevIterationUpdates = updates; - return this; - } - - /** - * Get previous step parameters updates. - * - * @return Previous step parameters updates. - */ - public Vector prevIterationUpdates() { - return prevIterationUpdates; - } - - /** - * Get sum of parameters updates. - * - * @param parameters Parameters to sum. - * @return Sum of parameters updates. - */ - public static NesterovParameterUpdate sum(List<NesterovParameterUpdate> parameters) { - return parameters.stream().filter(Objects::nonNull).map(NesterovParameterUpdate::prevIterationUpdates) - .reduce(Vector::plus).map(NesterovParameterUpdate::new).orElse(null); - } - - /** - * Get average of parameters updates. - * - * @param parameters Parameters to average. - * @return Average of parameters updates. - */ - public static NesterovParameterUpdate avg(List<NesterovParameterUpdate> parameters) { - NesterovParameterUpdate sum = sum(parameters); - return sum != null ? sum.setPreviousUpdates(sum.prevIterationUpdates().divide(parameters.size())) : null; - } -} http://git-wip-us.apache.org/repos/asf/ignite/blob/0234ee3b/modules/ml/src/main/java/org/apache/ignite/ml/nn/updaters/NesterovUpdateCalculator.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/nn/updaters/NesterovUpdateCalculator.java b/modules/ml/src/main/java/org/apache/ignite/ml/nn/updaters/NesterovUpdateCalculator.java deleted file mode 100644 index 5caddd4..0000000 --- a/modules/ml/src/main/java/org/apache/ignite/ml/nn/updaters/NesterovUpdateCalculator.java +++ /dev/null @@ -1,85 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.ignite.ml.nn.updaters; - -import org.apache.ignite.ml.math.Matrix; -import org.apache.ignite.ml.math.Vector; -import org.apache.ignite.ml.math.functions.IgniteDifferentiableVectorToDoubleFunction; -import org.apache.ignite.ml.math.functions.IgniteFunction; - -/** - * Class encapsulating Nesterov algorithm for MLP parameters updateCache. - */ -public class NesterovUpdateCalculator<M extends SmoothParametrized> - implements ParameterUpdateCalculator<M, NesterovParameterUpdate> { - /** - * Learning rate. - */ - private final double learningRate; - - /** - * Loss function. - */ - private IgniteFunction<Vector, IgniteDifferentiableVectorToDoubleFunction> loss; - - /** - * Momentum constant. - */ - protected double momentum; - - /** - * Construct NesterovUpdateCalculator. - * - * @param momentum Momentum constant. - */ - public NesterovUpdateCalculator(double learningRate, double momentum) { - this.learningRate = learningRate; - this.momentum = momentum; - } - - /** {@inheritDoc} */ - @Override public NesterovParameterUpdate calculateNewUpdate(SmoothParametrized mdl, - NesterovParameterUpdate updaterParameters, int iteration, Matrix inputs, Matrix groundTruth) { - // TODO:IGNITE-7350 create new updateCache object here instead of in-place change. - - if (iteration > 0) { - Vector curParams = mdl.parameters(); - mdl.setParameters(curParams.minus(updaterParameters.prevIterationUpdates().times(momentum))); - } - - Vector gradient = mdl.differentiateByParameters(loss, inputs, groundTruth); - updaterParameters.setPreviousUpdates(updaterParameters.prevIterationUpdates() - .plus(gradient.times(learningRate))); - - return updaterParameters; - } - - /** {@inheritDoc} */ - @Override public NesterovParameterUpdate init(M mdl, - IgniteFunction<Vector, IgniteDifferentiableVectorToDoubleFunction> loss) { - this.loss = loss; - - return new NesterovParameterUpdate(mdl.parametersCount()); - } - - /** {@inheritDoc} */ - @Override public <M1 extends M> M1 update(M1 obj, NesterovParameterUpdate update) { - Vector parameters = obj.parameters(); - return (M1)obj.setParameters(parameters.minus(update.prevIterationUpdates())); - } -} http://git-wip-us.apache.org/repos/asf/ignite/blob/0234ee3b/modules/ml/src/main/java/org/apache/ignite/ml/nn/updaters/NesterovUpdater.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/nn/updaters/NesterovUpdater.java b/modules/ml/src/main/java/org/apache/ignite/ml/nn/updaters/NesterovUpdater.java new file mode 100644 index 0000000..7b6a0c7 --- /dev/null +++ b/modules/ml/src/main/java/org/apache/ignite/ml/nn/updaters/NesterovUpdater.java @@ -0,0 +1,76 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.ignite.ml.nn.updaters; + +import org.apache.ignite.ml.math.Matrix; +import org.apache.ignite.ml.math.Vector; +import org.apache.ignite.ml.math.functions.IgniteDifferentiableVectorToDoubleFunction; +import org.apache.ignite.ml.math.functions.IgniteFunction; + +/** + * Class encapsulating Nesterov algorithm for MLP parameters update. + */ +public class NesterovUpdater implements ParameterUpdater<SmoothParametrized, NesterovUpdaterParams> { + /** + * Learning rate. + */ + private final double learningRate; + + /** + * Loss function. + */ + private IgniteFunction<Vector, IgniteDifferentiableVectorToDoubleFunction> loss; + + /** + * Momentum constant. + */ + protected double momentum; + + /** + * Construct NesterovUpdater. + * + * @param momentum Momentum constant. + */ + public NesterovUpdater(double learningRate, double momentum) { + this.learningRate = learningRate; + this.momentum = momentum; + } + + /** {@inheritDoc} */ + @Override public NesterovUpdaterParams init(SmoothParametrized mdl, + IgniteFunction<Vector, IgniteDifferentiableVectorToDoubleFunction> loss) { + this.loss = loss; + + return new NesterovUpdaterParams(mdl.parametersCount()); + } + + /** {@inheritDoc} */ + @Override public NesterovUpdaterParams updateParams(SmoothParametrized mdl, NesterovUpdaterParams updaterParameters, + int iteration, Matrix inputs, Matrix groundTruth) { + + if (iteration > 0) { + Vector curParams = mdl.parameters(); + mdl.setParameters(curParams.minus(updaterParameters.prevIterationUpdates().times(momentum))); + } + + Vector gradient = mdl.differentiateByParameters(loss, inputs, groundTruth); + updaterParameters.setPreviousUpdates(updaterParameters.prevIterationUpdates().plus(gradient.times(learningRate))); + + return updaterParameters; + } +} http://git-wip-us.apache.org/repos/asf/ignite/blob/0234ee3b/modules/ml/src/main/java/org/apache/ignite/ml/nn/updaters/NesterovUpdaterParams.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/nn/updaters/NesterovUpdaterParams.java b/modules/ml/src/main/java/org/apache/ignite/ml/nn/updaters/NesterovUpdaterParams.java new file mode 100644 index 0000000..d403ea1 --- /dev/null +++ b/modules/ml/src/main/java/org/apache/ignite/ml/nn/updaters/NesterovUpdaterParams.java @@ -0,0 +1,67 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.ignite.ml.nn.updaters; + +import org.apache.ignite.ml.math.Vector; +import org.apache.ignite.ml.math.impls.vector.DenseLocalOnHeapVector; + +/** + * Data needed for Nesterov parameters updater. + */ +public class NesterovUpdaterParams implements UpdaterParams<SmoothParametrized> { + /** + * Previous step weights updates. + */ + protected Vector prevIterationUpdates; + + /** + * Construct NesterovUpdaterParams. + * + * @param paramsCnt Count of parameters on which update happens. + */ + public NesterovUpdaterParams(int paramsCnt) { + prevIterationUpdates = new DenseLocalOnHeapVector(paramsCnt).assign(0); + } + + /** + * Set previous step parameters updates. + * + * @param updates Parameters updates. + * @return This object with updated parameters updates. + */ + public NesterovUpdaterParams setPreviousUpdates(Vector updates) { + prevIterationUpdates = updates; + return this; + } + + /** + * Get previous step parameters updates. + * + * @return Previous step parameters updates. + */ + public Vector prevIterationUpdates() { + return prevIterationUpdates; + } + + /** {@inheritDoc} */ + @SuppressWarnings("unchecked") + @Override public <M extends SmoothParametrized> M update(M obj) { + Vector parameters = obj.parameters(); + return (M)obj.setParameters(parameters.minus(prevIterationUpdates)); + } +} http://git-wip-us.apache.org/repos/asf/ignite/blob/0234ee3b/modules/ml/src/main/java/org/apache/ignite/ml/nn/updaters/ParameterUpdateCalculator.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/nn/updaters/ParameterUpdateCalculator.java b/modules/ml/src/main/java/org/apache/ignite/ml/nn/updaters/ParameterUpdateCalculator.java deleted file mode 100644 index 77e3763..0000000 --- a/modules/ml/src/main/java/org/apache/ignite/ml/nn/updaters/ParameterUpdateCalculator.java +++ /dev/null @@ -1,58 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.ignite.ml.nn.updaters; - -import org.apache.ignite.ml.math.Matrix; -import org.apache.ignite.ml.math.Vector; -import org.apache.ignite.ml.math.functions.IgniteDifferentiableVectorToDoubleFunction; -import org.apache.ignite.ml.math.functions.IgniteFunction; - -/** - * Interface for classes encapsulating parameters updateCache logic. - * - * @param <M> Type of model to be updated. - * @param <P> Type of parameters needed for this updater. - */ -public interface ParameterUpdateCalculator<M, P> { - /** - * Initializes the updater. - * - * @param mdl Model to be trained. - * @param loss Loss function. - */ - P init(M mdl, IgniteFunction<Vector, IgniteDifferentiableVectorToDoubleFunction> loss); - - /** - * Calculate new updateCache. - * - * @param mdl Model to be updated. - * @param updaterParameters Updater parameters to updateCache. - * @param iteration Current trainer iteration. - * @param inputs Inputs. - * @param groundTruth True values. - * @return Updated parameters. - */ - P calculateNewUpdate(M mdl, P updaterParameters, int iteration, Matrix inputs, Matrix groundTruth); - - /** - * Update given obj with this parameters. - * - * @param obj Object to be updated. - */ - <M1 extends M> M1 update(M1 obj, P update); -} http://git-wip-us.apache.org/repos/asf/ignite/blob/0234ee3b/modules/ml/src/main/java/org/apache/ignite/ml/nn/updaters/ParameterUpdater.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/nn/updaters/ParameterUpdater.java b/modules/ml/src/main/java/org/apache/ignite/ml/nn/updaters/ParameterUpdater.java new file mode 100644 index 0000000..e8e28fd --- /dev/null +++ b/modules/ml/src/main/java/org/apache/ignite/ml/nn/updaters/ParameterUpdater.java @@ -0,0 +1,51 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.ignite.ml.nn.updaters; + +import org.apache.ignite.ml.math.Matrix; +import org.apache.ignite.ml.math.Vector; +import org.apache.ignite.ml.math.functions.IgniteDifferentiableVectorToDoubleFunction; +import org.apache.ignite.ml.math.functions.IgniteFunction; + +/** + * Interface for classes encapsulating parameters update logic. + * + * @param <M> Type of model to be updated. + * @param <P> Type of parameters needed for this updater. + */ +public interface ParameterUpdater<M, P extends UpdaterParams> { + /** + * Initializes the updater. + * + * @param mdl Model to be trained. + * @param loss Loss function. + */ + P init(M mdl, IgniteFunction<Vector, IgniteDifferentiableVectorToDoubleFunction> loss); + + /** + * Update updater parameters. + * + * @param mdl Model to be updated. + * @param updaterParameters Updater parameters to update. + * @param iteration Current trainer iteration. + * @param inputs Inputs. + * @param groundTruth True values. + * @return Updated parameters. + */ + P updateParams(M mdl, P updaterParameters, int iteration, Matrix inputs, Matrix groundTruth); +} http://git-wip-us.apache.org/repos/asf/ignite/blob/0234ee3b/modules/ml/src/main/java/org/apache/ignite/ml/nn/updaters/RPropParameterUpdate.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/nn/updaters/RPropParameterUpdate.java b/modules/ml/src/main/java/org/apache/ignite/ml/nn/updaters/RPropParameterUpdate.java deleted file mode 100644 index e2fa4d5..0000000 --- a/modules/ml/src/main/java/org/apache/ignite/ml/nn/updaters/RPropParameterUpdate.java +++ /dev/null @@ -1,228 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.ignite.ml.nn.updaters; - -import java.io.Serializable; -import java.util.List; -import java.util.Objects; -import java.util.stream.Collectors; -import org.apache.ignite.ml.math.Vector; -import org.apache.ignite.ml.math.VectorUtils; -import org.apache.ignite.ml.math.impls.vector.DenseLocalOnHeapVector; - -/** - * Data needed for RProp updater. - * <p> - * See <a href="https://paginas.fe.up.pt/~ee02162/dissertacao/RPROP%20paper.pdf">RProp</a>.</p> - */ -public class RPropParameterUpdate implements Serializable { - /** - * Previous iteration parameters updates. In original paper they are labeled with "delta w". - */ - protected Vector prevIterationUpdates; - - /** - * Previous iteration model partial derivatives by parameters. - */ - protected Vector prevIterationGradient; - /** - * Previous iteration parameters deltas. In original paper they are labeled with "delta". - */ - protected Vector deltas; - - /** - * Updates mask (values by which updateCache is multiplied). - */ - protected Vector updatesMask; - - /** - * Construct RPropParameterUpdate. - * - * @param paramsCnt Parameters count. - * @param initUpdate Initial updateCache (in original work labeled as "delta_0"). - */ - RPropParameterUpdate(int paramsCnt, double initUpdate) { - prevIterationUpdates = new DenseLocalOnHeapVector(paramsCnt); - prevIterationGradient = new DenseLocalOnHeapVector(paramsCnt); - deltas = new DenseLocalOnHeapVector(paramsCnt).assign(initUpdate); - updatesMask = new DenseLocalOnHeapVector(paramsCnt); - } - - /** - * Construct instance of this class by given parameters. - * - * @param prevIterationUpdates Previous iteration parameters updates. - * @param prevIterationGradient Previous iteration model partial derivatives by parameters. - * @param deltas Previous iteration parameters deltas. - * @param updatesMask Updates mask. - */ - public RPropParameterUpdate(Vector prevIterationUpdates, Vector prevIterationGradient, - Vector deltas, Vector updatesMask) { - this.prevIterationUpdates = prevIterationUpdates; - this.prevIterationGradient = prevIterationGradient; - this.deltas = deltas; - this.updatesMask = updatesMask; - } - - /** - * Get bias deltas. - * - * @return Bias deltas. - */ - Vector deltas() { - return deltas; - } - - /** - * Get previous iteration biases updates. In original paper they are labeled with "delta w". - * - * @return Biases updates. - */ - Vector prevIterationUpdates() { - return prevIterationUpdates; - } - - /** - * Set previous iteration parameters updates. In original paper they are labeled with "delta w". - * - * @param updates New parameters updates value. - * @return This object. - */ - private RPropParameterUpdate setPrevIterationUpdates(Vector updates) { - prevIterationUpdates = updates; - - return this; - } - - /** - * Get previous iteration loss function partial derivatives by parameters. - * - * @return Previous iteration loss function partial derivatives by parameters. - */ - Vector prevIterationGradient() { - return prevIterationGradient; - } - - /** - * Set previous iteration loss function partial derivatives by parameters. - * - * @return This object. - */ - private RPropParameterUpdate setPrevIterationGradient(Vector gradient) { - prevIterationGradient = gradient; - return this; - } - - /** - * Get updates mask (values by which updateCache is multiplied). - * - * @return Updates mask (values by which updateCache is multiplied). - */ - public Vector updatesMask() { - return updatesMask; - } - - /** - * Set updates mask (values by which updateCache is multiplied). - * - * @param updatesMask New updatesMask. - * @return This object. - */ - public RPropParameterUpdate setUpdatesMask(Vector updatesMask) { - this.updatesMask = updatesMask; - - return this; - } - - /** - * Set previous iteration deltas. - * - * @param deltas New deltas. - * @return This object. - */ - public RPropParameterUpdate setDeltas(Vector deltas) { - this.deltas = deltas; - - return this; - } - - /** - * Sums updates during one training. - * - * @param updates Updates. - * @return Sum of updates during one training. - */ - public static RPropParameterUpdate sumLocal(List<RPropParameterUpdate> updates) { - List<RPropParameterUpdate> nonNullUpdates = updates.stream().filter(Objects::nonNull) - .collect(Collectors.toList()); - - if (nonNullUpdates.isEmpty()) - return null; - - Vector newDeltas = nonNullUpdates.get(nonNullUpdates.size() - 1).deltas(); - Vector newGradient = nonNullUpdates.get(nonNullUpdates.size() - 1).prevIterationGradient(); - Vector totalUpdate = nonNullUpdates.stream().map(pu -> VectorUtils.elementWiseTimes(pu.updatesMask().copy(), - pu.prevIterationUpdates())).reduce(Vector::plus).orElse(null); - - return new RPropParameterUpdate(totalUpdate, newGradient, newDeltas, - new DenseLocalOnHeapVector(newDeltas.size()).assign(1.0)); - } - - /** - * Sums updates returned by different trainings. - * - * @param updates Updates. - * @return Sum of updates during returned by different trainings. - */ - public static RPropParameterUpdate sum(List<RPropParameterUpdate> updates) { - Vector totalUpdate = updates.stream().filter(Objects::nonNull) - .map(pu -> VectorUtils.elementWiseTimes(pu.updatesMask().copy(), pu.prevIterationUpdates())) - .reduce(Vector::plus).orElse(null); - Vector totalDelta = updates.stream().filter(Objects::nonNull) - .map(RPropParameterUpdate::deltas).reduce(Vector::plus).orElse(null); - Vector totalGradient = updates.stream().filter(Objects::nonNull) - .map(RPropParameterUpdate::prevIterationGradient).reduce(Vector::plus).orElse(null); - - if (totalUpdate != null) - return new RPropParameterUpdate(totalUpdate, totalGradient, totalDelta, - new DenseLocalOnHeapVector(Objects.requireNonNull(totalDelta).size()).assign(1.0)); - - return null; - } - - /** - * Averages updates returned by different trainings. - * - * @param updates Updates. - * @return Averages of updates during returned by different trainings. - */ - public static RPropParameterUpdate avg(List<RPropParameterUpdate> updates) { - List<RPropParameterUpdate> nonNullUpdates = updates.stream() - .filter(Objects::nonNull).collect(Collectors.toList()); - int size = nonNullUpdates.size(); - - RPropParameterUpdate sum = sum(updates); - if (sum != null) - return sum. - setPrevIterationGradient(sum.prevIterationGradient().divide(size)). - setPrevIterationUpdates(sum.prevIterationUpdates().divide(size)). - setDeltas(sum.deltas().divide(size)); - - return null; - } -} http://git-wip-us.apache.org/repos/asf/ignite/blob/0234ee3b/modules/ml/src/main/java/org/apache/ignite/ml/nn/updaters/RPropUpdateCalculator.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/nn/updaters/RPropUpdateCalculator.java b/modules/ml/src/main/java/org/apache/ignite/ml/nn/updaters/RPropUpdateCalculator.java deleted file mode 100644 index 99f39c9..0000000 --- a/modules/ml/src/main/java/org/apache/ignite/ml/nn/updaters/RPropUpdateCalculator.java +++ /dev/null @@ -1,151 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.ignite.ml.nn.updaters; - -import org.apache.ignite.ml.math.Matrix; -import org.apache.ignite.ml.math.Vector; -import org.apache.ignite.ml.math.VectorUtils; -import org.apache.ignite.ml.math.functions.IgniteDifferentiableVectorToDoubleFunction; -import org.apache.ignite.ml.math.functions.IgniteFunction; -import org.apache.ignite.ml.math.util.MatrixUtil; - -/** - * Class encapsulating RProp algorithm. - * <p> - * See <a href="https://paginas.fe.up.pt/~ee02162/dissertacao/RPROP%20paper.pdf">RProp</a>.</p> - */ -public class RPropUpdateCalculator<M extends SmoothParametrized> implements ParameterUpdateCalculator<M, RPropParameterUpdate> { - /** - * Default initial update. - */ - private static double DFLT_INIT_UPDATE = 0.1; - - /** - * Default acceleration rate. - */ - private static double DFLT_ACCELERATION_RATE = 1.2; - - /** - * Default deacceleration rate. - */ - private static double DFLT_DEACCELERATION_RATE = 0.5; - - /** - * Initial update. - */ - private final double initUpdate; - - /** - * Acceleration rate. - */ - private final double accelerationRate; - - /** - * Deacceleration rate. - */ - private final double deaccelerationRate; - - /** - * Maximal value for update. - */ - private final static double UPDATE_MAX = 50.0; - - /** - * Minimal value for update. - */ - private final static double UPDATE_MIN = 1E-6; - - /** - * Loss function. - */ - protected IgniteFunction<Vector, IgniteDifferentiableVectorToDoubleFunction> loss; - - /** - * Construct RPropUpdateCalculator. - * - * @param initUpdate Initial update. - * @param accelerationRate Acceleration rate. - * @param deaccelerationRate Deacceleration rate. - */ - public RPropUpdateCalculator(double initUpdate, double accelerationRate, double deaccelerationRate) { - this.initUpdate = initUpdate; - this.accelerationRate = accelerationRate; - this.deaccelerationRate = deaccelerationRate; - } - - /** - * Construct RPropUpdateCalculator with default parameters. - */ - public RPropUpdateCalculator() { - this(DFLT_INIT_UPDATE, DFLT_ACCELERATION_RATE, DFLT_DEACCELERATION_RATE); - } - - /** {@inheritDoc} */ - @Override public RPropParameterUpdate calculateNewUpdate(SmoothParametrized mdl, RPropParameterUpdate updaterParams, - int iteration, Matrix inputs, Matrix groundTruth) { - Vector gradient = mdl.differentiateByParameters(loss, inputs, groundTruth); - Vector prevGradient = updaterParams.prevIterationGradient(); - Vector derSigns; - - if (prevGradient != null) - derSigns = VectorUtils.zipWith(prevGradient, gradient, (x, y) -> Math.signum(x * y)); - else - derSigns = gradient.like(gradient.size()).assign(1.0); - - Vector newDeltas = updaterParams.deltas().copy().map(derSigns, (prevDelta, sign) -> { - if (sign > 0) - return Math.min(prevDelta * accelerationRate, UPDATE_MAX); - else if (sign < 0) - return Math.max(prevDelta * deaccelerationRate, UPDATE_MIN); - else - return prevDelta; - }); - - Vector newPrevIterationUpdates = MatrixUtil.zipWith(gradient, updaterParams.deltas(), (der, delta, i) -> { - if (derSigns.getX(i) >= 0) - return -Math.signum(der) * delta; - - return updaterParams.prevIterationUpdates().getX(i); - }); - - Vector updatesMask = MatrixUtil.zipWith(derSigns, updaterParams.prevIterationUpdates(), (sign, upd, i) -> { - if (sign < 0) - gradient.setX(i, 0.0); - - if (sign >= 0) - return 1.0; - else - return -1.0; - }); - - return new RPropParameterUpdate(newPrevIterationUpdates, gradient.copy(), newDeltas, updatesMask); - } - - /** {@inheritDoc} */ - @Override public RPropParameterUpdate init(M mdl, - IgniteFunction<Vector, IgniteDifferentiableVectorToDoubleFunction> loss) { - this.loss = loss; - return new RPropParameterUpdate(mdl.parametersCount(), initUpdate); - } - - /** {@inheritDoc} */ - @Override public <M1 extends M> M1 update(M1 obj, RPropParameterUpdate update) { - Vector updatesToAdd = VectorUtils.elementWiseTimes(update.updatesMask().copy(), update.prevIterationUpdates()); - return (M1)obj.setParameters(obj.parameters().plus(updatesToAdd)); - } -} http://git-wip-us.apache.org/repos/asf/ignite/blob/0234ee3b/modules/ml/src/main/java/org/apache/ignite/ml/nn/updaters/RPropUpdater.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/nn/updaters/RPropUpdater.java b/modules/ml/src/main/java/org/apache/ignite/ml/nn/updaters/RPropUpdater.java new file mode 100644 index 0000000..37963b4 --- /dev/null +++ b/modules/ml/src/main/java/org/apache/ignite/ml/nn/updaters/RPropUpdater.java @@ -0,0 +1,148 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.ignite.ml.nn.updaters; + +import org.apache.ignite.ml.math.Matrix; +import org.apache.ignite.ml.math.Vector; +import org.apache.ignite.ml.math.VectorUtils; +import org.apache.ignite.ml.math.functions.IgniteDifferentiableVectorToDoubleFunction; +import org.apache.ignite.ml.math.functions.IgniteFunction; +import org.apache.ignite.ml.math.util.MatrixUtil; + +/** + * Class encapsulating RProp algorithm. + * <p> + * See <a href="https://paginas.fe.up.pt/~ee02162/dissertacao/RPROP%20paper.pdf">RProp</a>.</p> + */ +public class RPropUpdater implements ParameterUpdater<SmoothParametrized, RPropUpdaterParams> { + /** + * Default initial update. + */ + private static double DFLT_INIT_UPDATE = 0.1; + + /** + * Default acceleration rate. + */ + private static double DFLT_ACCELERATION_RATE = 1.2; + + /** + * Default deacceleration rate. + */ + private static double DFLT_DEACCELERATION_RATE = 0.5; + + /** + * Initial update. + */ + private final double initUpdate; + + /** + * Acceleration rate. + */ + private final double accelerationRate; + + /** + * Deacceleration rate. + */ + private final double deaccelerationRate; + + /** + * Maximal value for update. + */ + private final static double UPDATE_MAX = 50.0; + + /** + * Minimal value for update. + */ + private final static double UPDATE_MIN = 1E-6; + + /** + * Loss function. + */ + protected IgniteFunction<Vector, IgniteDifferentiableVectorToDoubleFunction> loss; + + /** + * Construct RPropUpdater. + * + * @param initUpdate Initial update. + * @param accelerationRate Acceleration rate. + * @param deaccelerationRate Deacceleration rate. + */ + public RPropUpdater(double initUpdate, double accelerationRate, double deaccelerationRate) { + this.initUpdate = initUpdate; + this.accelerationRate = accelerationRate; + this.deaccelerationRate = deaccelerationRate; + } + + /** + * Construct RPropUpdater with default parameters. + */ + public RPropUpdater() { + this(DFLT_INIT_UPDATE, DFLT_ACCELERATION_RATE, DFLT_DEACCELERATION_RATE); + } + + /** {@inheritDoc} */ + @Override public RPropUpdaterParams init(SmoothParametrized mdl, + IgniteFunction<Vector, IgniteDifferentiableVectorToDoubleFunction> loss) { + this.loss = loss; + return new RPropUpdaterParams(mdl.parametersCount(), initUpdate); + } + + /** {@inheritDoc} */ + @Override public RPropUpdaterParams updateParams(SmoothParametrized mdl, RPropUpdaterParams updaterParams, + int iteration, Matrix inputs, Matrix groundTruth) { + Vector gradient = mdl.differentiateByParameters(loss, inputs, groundTruth); + Vector prevGradient = updaterParams.prevIterationGradient(); + Vector derSigns; + + if (prevGradient != null) + derSigns = VectorUtils.zipWith(prevGradient, gradient, (x, y) -> Math.signum(x * y)); + else + derSigns = gradient.like(gradient.size()).assign(1.0); + + updaterParams.deltas().map(derSigns, (prevDelta, sign) -> { + if (sign > 0) + return Math.min(prevDelta * accelerationRate, UPDATE_MAX); + else if (sign < 0) + return Math.max(prevDelta * deaccelerationRate, UPDATE_MIN); + else + return prevDelta; + }); + + updaterParams.setPrevIterationBiasesUpdates(MatrixUtil.zipWith(gradient, updaterParams.deltas(), (der, delta, i) -> { + if (derSigns.getX(i) >= 0) + return -Math.signum(der) * delta; + + return updaterParams.prevIterationUpdates().getX(i); + })); + + Vector updatesMask = MatrixUtil.zipWith(derSigns, updaterParams.prevIterationUpdates(), (sign, upd, i) -> { + if (sign < 0) + gradient.setX(i, 0.0); + + if (sign >= 0) + return 1.0; + else + return -1.0; + }); + + updaterParams.setUpdatesMask(updatesMask); + updaterParams.setPrevIterationWeightsDerivatives(gradient.copy()); + + return updaterParams; + } +} http://git-wip-us.apache.org/repos/asf/ignite/blob/0234ee3b/modules/ml/src/main/java/org/apache/ignite/ml/nn/updaters/RPropUpdaterParams.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/nn/updaters/RPropUpdaterParams.java b/modules/ml/src/main/java/org/apache/ignite/ml/nn/updaters/RPropUpdaterParams.java new file mode 100644 index 0000000..080e809 --- /dev/null +++ b/modules/ml/src/main/java/org/apache/ignite/ml/nn/updaters/RPropUpdaterParams.java @@ -0,0 +1,135 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.ignite.ml.nn.updaters; + +import org.apache.ignite.ml.math.Vector; +import org.apache.ignite.ml.math.VectorUtils; +import org.apache.ignite.ml.math.impls.vector.DenseLocalOnHeapVector; + +/** + * Data needed for RProp updater. + * <p> + * See <a href="https://paginas.fe.up.pt/~ee02162/dissertacao/RPROP%20paper.pdf">RProp</a>.</p> + */ +public class RPropUpdaterParams implements UpdaterParams<SmoothParametrized> { + /** + * Previous iteration weights updates. In original paper they are labeled with "delta w". + */ + protected Vector prevIterationUpdates; + + /** + * Previous iteration model partial derivatives by parameters. + */ + protected Vector prevIterationGradient; + /** + * Previous iteration parameters deltas. In original paper they are labeled with "delta". + */ + protected Vector deltas; + + /** + * Updates mask (values by which update is multiplied). + */ + protected Vector updatesMask; + + /** + * Construct RPropUpdaterParams. + * + * @param paramsCnt Parameters count. + * @param initUpdate Initial update (in original work labeled as "delta_0"). + */ + RPropUpdaterParams(int paramsCnt, double initUpdate) { + prevIterationUpdates = new DenseLocalOnHeapVector(paramsCnt); + prevIterationGradient = new DenseLocalOnHeapVector(paramsCnt); + deltas = new DenseLocalOnHeapVector(paramsCnt).assign(initUpdate); + updatesMask = new DenseLocalOnHeapVector(paramsCnt); + } + + /** + * Get bias deltas. + * + * @return Bias deltas. + */ + Vector deltas() { + return deltas; + } + + /** + * Get previous iteration biases updates. In original paper they are labeled with "delta w". + * + * @return Biases updates. + */ + Vector prevIterationUpdates() { + return prevIterationUpdates; + } + + /** + * Set previous iteration parameters updates. In original paper they are labeled with "delta w". + * + * @param updates New parameters updates value. + * @return This object. + */ + Vector setPrevIterationBiasesUpdates(Vector updates) { + return prevIterationUpdates = updates; + } + + /** + * Get previous iteration loss function partial derivatives by parameters. + * + * @return Previous iteration loss function partial derivatives by parameters. + */ + Vector prevIterationGradient() { + return prevIterationGradient; + } + + /** + * Set previous iteration loss function partial derivatives by parameters. + * + * @return This object. + */ + RPropUpdaterParams setPrevIterationWeightsDerivatives(Vector gradient) { + prevIterationGradient = gradient; + return this; + } + + /** + * Get updates mask (values by which update is multiplied). + * + * @return Updates mask (values by which update is multiplied). + */ + public Vector updatesMask() { + return updatesMask; + } + + /** + * Set updates mask (values by which update is multiplied). + * + * @param updatesMask New updatesMask. + */ + public RPropUpdaterParams setUpdatesMask(Vector updatesMask) { + this.updatesMask = updatesMask; + + return this; + } + + /** {@inheritDoc} */ + @SuppressWarnings("unchecked") + @Override public <M extends SmoothParametrized> M update(M obj) { + Vector updatesToAdd = VectorUtils.elementWiseTimes(updatesMask.copy(), prevIterationUpdates); + return (M)obj.setParameters(obj.parameters().plus(updatesToAdd)); + } +} http://git-wip-us.apache.org/repos/asf/ignite/blob/0234ee3b/modules/ml/src/main/java/org/apache/ignite/ml/nn/updaters/SimpleGDParameter.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/nn/updaters/SimpleGDParameter.java b/modules/ml/src/main/java/org/apache/ignite/ml/nn/updaters/SimpleGDParameter.java deleted file mode 100644 index 7159621..0000000 --- a/modules/ml/src/main/java/org/apache/ignite/ml/nn/updaters/SimpleGDParameter.java +++ /dev/null @@ -1,77 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.ignite.ml.nn.updaters; - -import java.io.Serializable; -import org.apache.ignite.ml.math.Vector; -import org.apache.ignite.ml.math.impls.vector.DenseLocalOnHeapVector; - -/** - * Parameters for {@link SimpleGDUpdateCalculator}. - */ -public class SimpleGDParameter implements Serializable { - /** - * Gradient. - */ - private Vector gradient; - - /** - * Learning rate. - */ - private double learningRate; - - /** - * Construct instance of this class. - * - * @param paramsCnt Count of parameters. - * @param learningRate Learning rate. - */ - public SimpleGDParameter(int paramsCnt, double learningRate) { - gradient = new DenseLocalOnHeapVector(paramsCnt); - this.learningRate = learningRate; - } - - /** - * Construct instance of this class. - * - * @param gradient Gradient. - * @param learningRate Learning rate. - */ - public SimpleGDParameter(Vector gradient, double learningRate) { - this.gradient = gradient; - this.learningRate = learningRate; - } - - /** - * Get gradient. - * - * @return Get gradient. - */ - public Vector gradient() { - return gradient; - } - - /** - * Get learning rate. - * - * @return learning rate. - */ - public double learningRate() { - return learningRate; - } -} http://git-wip-us.apache.org/repos/asf/ignite/blob/0234ee3b/modules/ml/src/main/java/org/apache/ignite/ml/nn/updaters/SimpleGDParams.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/nn/updaters/SimpleGDParams.java b/modules/ml/src/main/java/org/apache/ignite/ml/nn/updaters/SimpleGDParams.java new file mode 100644 index 0000000..50a120a --- /dev/null +++ b/modules/ml/src/main/java/org/apache/ignite/ml/nn/updaters/SimpleGDParams.java @@ -0,0 +1,65 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.ignite.ml.nn.updaters; + +import org.apache.ignite.ml.math.Vector; +import org.apache.ignite.ml.math.impls.vector.DenseLocalOnHeapVector; + +/** + * Parameters for {@link SimpleGDUpdater}. + */ +public class SimpleGDParams implements UpdaterParams<SmoothParametrized> { + /** + * Gradient. + */ + private Vector gradient; + + /** + * Learning rate. + */ + private double learningRate; + + /** + * Construct instance of this class. + * + * @param paramsCnt Count of parameters. + * @param learningRate Learning rate. + */ + public SimpleGDParams(int paramsCnt, double learningRate) { + gradient = new DenseLocalOnHeapVector(paramsCnt); + this.learningRate = learningRate; + } + + /** + * Construct instance of this class. + * + * @param gradient Gradient. + * @param learningRate Learning rate. + */ + public SimpleGDParams(Vector gradient, double learningRate) { + this.gradient = gradient; + this.learningRate = learningRate; + } + + /** {@inheritDoc} */ + @SuppressWarnings("unchecked") + @Override public <M extends SmoothParametrized> M update(M obj) { + Vector params = obj.parameters(); + return (M)obj.setParameters(params.minus(gradient.times(learningRate))); + } +} http://git-wip-us.apache.org/repos/asf/ignite/blob/0234ee3b/modules/ml/src/main/java/org/apache/ignite/ml/nn/updaters/SimpleGDUpdateCalculator.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/nn/updaters/SimpleGDUpdateCalculator.java b/modules/ml/src/main/java/org/apache/ignite/ml/nn/updaters/SimpleGDUpdateCalculator.java deleted file mode 100644 index d2197d9..0000000 --- a/modules/ml/src/main/java/org/apache/ignite/ml/nn/updaters/SimpleGDUpdateCalculator.java +++ /dev/null @@ -1,66 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.ignite.ml.nn.updaters; - -import org.apache.ignite.ml.math.Matrix; -import org.apache.ignite.ml.math.Vector; -import org.apache.ignite.ml.math.functions.IgniteDifferentiableVectorToDoubleFunction; -import org.apache.ignite.ml.math.functions.IgniteFunction; - -/** - * Simple gradient descent parameters updater. - */ -public class SimpleGDUpdateCalculator<M extends SmoothParametrized> implements ParameterUpdateCalculator<M, SimpleGDParameter> { - /** - * Learning rate. - */ - private double learningRate; - - /** - * Loss function. - */ - protected IgniteFunction<Vector, IgniteDifferentiableVectorToDoubleFunction> loss; - - /** - * Construct SimpleGDUpdateCalculator. - * - * @param learningRate Learning rate. - */ - public SimpleGDUpdateCalculator(double learningRate) { - this.learningRate = learningRate; - } - - /** {@inheritDoc} */ - @Override public SimpleGDParameter init(M mdl, - IgniteFunction<Vector, IgniteDifferentiableVectorToDoubleFunction> loss) { - this.loss = loss; - return new SimpleGDParameter(mdl.parametersCount(), learningRate); - } - - /** {@inheritDoc} */ - @Override public SimpleGDParameter calculateNewUpdate(SmoothParametrized mlp, SimpleGDParameter updaterParameters, - int iteration, Matrix inputs, Matrix groundTruth) { - return new SimpleGDParameter(mlp.differentiateByParameters(loss, inputs, groundTruth), learningRate); - } - - /** {@inheritDoc} */ - @Override public <M1 extends M> M1 update(M1 obj, SimpleGDParameter update) { - Vector params = obj.parameters(); - return (M1)obj.setParameters(params.minus(update.gradient().times(learningRate))); - } -} http://git-wip-us.apache.org/repos/asf/ignite/blob/0234ee3b/modules/ml/src/main/java/org/apache/ignite/ml/nn/updaters/SimpleGDUpdater.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/nn/updaters/SimpleGDUpdater.java b/modules/ml/src/main/java/org/apache/ignite/ml/nn/updaters/SimpleGDUpdater.java new file mode 100644 index 0000000..5bf9c3f --- /dev/null +++ b/modules/ml/src/main/java/org/apache/ignite/ml/nn/updaters/SimpleGDUpdater.java @@ -0,0 +1,60 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.ignite.ml.nn.updaters; + +import org.apache.ignite.ml.math.Matrix; +import org.apache.ignite.ml.math.Vector; +import org.apache.ignite.ml.math.functions.IgniteDifferentiableVectorToDoubleFunction; +import org.apache.ignite.ml.math.functions.IgniteFunction; + +/** + * Simple gradient descent parameters updater. + */ +public class SimpleGDUpdater implements ParameterUpdater<SmoothParametrized, SimpleGDParams> { + /** + * Learning rate. + */ + private double learningRate; + + /** + * Loss function. + */ + protected IgniteFunction<Vector, IgniteDifferentiableVectorToDoubleFunction> loss; + + /** + * Construct SimpleGDUpdater. + * + * @param learningRate Learning rate. + */ + public SimpleGDUpdater(double learningRate) { + this.learningRate = learningRate; + } + + /** {@inheritDoc} */ + @Override public SimpleGDParams init(SmoothParametrized mlp, + IgniteFunction<Vector, IgniteDifferentiableVectorToDoubleFunction> loss) { + this.loss = loss; + return new SimpleGDParams(mlp.parametersCount(), learningRate); + } + + /** {@inheritDoc} */ + @Override public SimpleGDParams updateParams(SmoothParametrized mlp, SimpleGDParams updaterParameters, + int iteration, Matrix inputs, Matrix groundTruth) { + return new SimpleGDParams(mlp.differentiateByParameters(loss, inputs, groundTruth), learningRate); + } +} http://git-wip-us.apache.org/repos/asf/ignite/blob/0234ee3b/modules/ml/src/main/java/org/apache/ignite/ml/nn/updaters/SmoothParametrized.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/nn/updaters/SmoothParametrized.java b/modules/ml/src/main/java/org/apache/ignite/ml/nn/updaters/SmoothParametrized.java index 1534a6d..5c4f59f 100644 --- a/modules/ml/src/main/java/org/apache/ignite/ml/nn/updaters/SmoothParametrized.java +++ b/modules/ml/src/main/java/org/apache/ignite/ml/nn/updaters/SmoothParametrized.java @@ -17,11 +17,8 @@ package org.apache.ignite.ml.nn.updaters; -import org.apache.ignite.ml.Model; -import org.apache.ignite.ml.math.Matrix; - /** * Interface for models which are smooth functions of their parameters. */ -public interface SmoothParametrized<M extends SmoothParametrized<M>> extends BaseSmoothParametrized<M>, Model<Matrix, Matrix> { +public interface SmoothParametrized<M extends SmoothParametrized<M>> extends BaseSmoothParametrized<M> { } http://git-wip-us.apache.org/repos/asf/ignite/blob/0234ee3b/modules/ml/src/main/java/org/apache/ignite/ml/nn/updaters/UpdaterParams.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/nn/updaters/UpdaterParams.java b/modules/ml/src/main/java/org/apache/ignite/ml/nn/updaters/UpdaterParams.java new file mode 100644 index 0000000..cd5bc32 --- /dev/null +++ b/modules/ml/src/main/java/org/apache/ignite/ml/nn/updaters/UpdaterParams.java @@ -0,0 +1,32 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.ignite.ml.nn.updaters; + +/** + * A common interface for parameter updaters. + * + * @param <T> Type of object to be updated with this params. + */ +public interface UpdaterParams<T> { + /** + * Update given obj with this parameters. + * + * @param obj Object to be updated. + */ + <M extends T> M update(M obj); +} http://git-wip-us.apache.org/repos/asf/ignite/blob/0234ee3b/modules/ml/src/main/java/org/apache/ignite/ml/trainers/Trainer.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/trainers/Trainer.java b/modules/ml/src/main/java/org/apache/ignite/ml/trainers/Trainer.java deleted file mode 100644 index 7540d6f..0000000 --- a/modules/ml/src/main/java/org/apache/ignite/ml/trainers/Trainer.java +++ /dev/null @@ -1,30 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.ignite.ml.trainers; - -import org.apache.ignite.ml.Model; - -/** Trainer interface. */ -public interface Trainer<M extends Model, T> { - /** Train the model based on provided data. - * - * @param data Data for training. - * @return Trained model. - */ - public M train(T data); -} http://git-wip-us.apache.org/repos/asf/ignite/blob/0234ee3b/modules/ml/src/main/java/org/apache/ignite/ml/trainers/group/BaseLocalProcessorJob.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/trainers/group/BaseLocalProcessorJob.java b/modules/ml/src/main/java/org/apache/ignite/ml/trainers/group/BaseLocalProcessorJob.java deleted file mode 100644 index d70b3f1..0000000 --- a/modules/ml/src/main/java/org/apache/ignite/ml/trainers/group/BaseLocalProcessorJob.java +++ /dev/null @@ -1,145 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.ignite.ml.trainers.group; - -import java.io.Serializable; -import java.util.List; -import java.util.UUID; -import java.util.stream.Collectors; -import java.util.stream.Stream; -import org.apache.ignite.Ignite; -import org.apache.ignite.IgniteCache; -import org.apache.ignite.IgniteException; -import org.apache.ignite.Ignition; -import org.apache.ignite.cache.affinity.Affinity; -import org.apache.ignite.compute.ComputeJob; -import org.apache.ignite.ml.math.functions.IgniteFunction; -import org.apache.ignite.ml.math.functions.IgniteSupplier; - -/** - * Base job for group training. - * It's purpose is to apply worker to each element (cache key or cache entry) of given cache specified - * by keySupplier. Worker produces {@link ResultAndUpdates} object which contains 'side effects' which are updates - * needed to apply to caches and computation result. - * After we get all {@link ResultAndUpdates} we merge all 'update' parts of them for each node - * and apply them on corresponding node, also we reduce all 'result' by some given reducer. - * - * @param <K> Type of keys of cache used for group trainer. - * @param <V> Type of values of cache used for group trainer. - * @param <T> Type of elements to which workers are applier. - * @param <R> Type of result of worker. - */ -public abstract class BaseLocalProcessorJob<K, V, T, R extends Serializable> implements ComputeJob { - /** - * UUID of group training. - */ - protected UUID trainingUUID; - - /** - * Worker. - */ - protected IgniteFunction<T, ResultAndUpdates<R>> worker; - - /** - * Supplier of keys determining elements to which worker should be applied. - */ - protected IgniteSupplier<Stream<GroupTrainerCacheKey<K>>> keySupplier; - - /** - * Operator used to reduce results from worker. - */ - protected IgniteFunction<List<R>, R> reducer; - - /** - * Name of cache used for training. - */ - protected String cacheName; - - /** - * Construct instance of this class with given arguments. - * - * @param worker Worker. - * @param keySupplier Supplier of keys. - * @param reducer Reducer. - * @param trainingUUID UUID of training. - * @param cacheName Name of cache used for training. - */ - public BaseLocalProcessorJob( - IgniteFunction<T, ResultAndUpdates<R>> worker, - IgniteSupplier<Stream<GroupTrainerCacheKey<K>>> keySupplier, - IgniteFunction<List<R>, R> reducer, - UUID trainingUUID, String cacheName) { - this.worker = worker; - this.keySupplier = keySupplier; - this.reducer = reducer; - this.trainingUUID = trainingUUID; - this.cacheName = cacheName; - } - - /** {@inheritDoc} */ - @Override public void cancel() { - // NO-OP. - } - - /** {@inheritDoc} */ - @Override public R execute() throws IgniteException { - List<ResultAndUpdates<R>> resultsAndUpdates = toProcess(). - map(worker). - collect(Collectors.toList()); - - ResultAndUpdates<R> totalRes = ResultAndUpdates.sum(reducer, resultsAndUpdates); - - totalRes.applyUpdates(ignite()); - - return totalRes.result(); - } - - /** - * Get stream of elements to process. - * - * @return Stream of elements to process. - */ - protected abstract Stream<T> toProcess(); - - /** - * Ignite instance. - * - * @return Ignite instance. - */ - protected static Ignite ignite() { - return Ignition.localIgnite(); - } - - /** - * Get cache used for training. - * - * @return Cache used for training. - */ - protected IgniteCache<GroupTrainerCacheKey<K>, V> cache() { - return ignite().getOrCreateCache(cacheName); - } - - /** - * Get affinity function for cache used in group training. - * - * @return Affinity function for cache used in group training. - */ - protected Affinity<GroupTrainerCacheKey> affinity() { - return ignite().affinity(cacheName); - } -} http://git-wip-us.apache.org/repos/asf/ignite/blob/0234ee3b/modules/ml/src/main/java/org/apache/ignite/ml/trainers/group/ConstModel.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/trainers/group/ConstModel.java b/modules/ml/src/main/java/org/apache/ignite/ml/trainers/group/ConstModel.java deleted file mode 100644 index 75f8179..0000000 --- a/modules/ml/src/main/java/org/apache/ignite/ml/trainers/group/ConstModel.java +++ /dev/null @@ -1,46 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.ignite.ml.trainers.group; - -import org.apache.ignite.ml.Model; - -/** - * Model which outputs given constant. - * - * @param <T> Type of constant. - */ -public class ConstModel<T> implements Model<T, T> { - /** - * Constant to be returned by this model. - */ - private T c; - - /** - * Create instance of this class specified by input parameters. - * - * @param c Constant to be returned by this model. - */ - public ConstModel(T c) { - this.c = c; - } - - /** {@inheritDoc} */ - @Override public T apply(T val) { - return c; - } -}