This is an automated email from the ASF dual-hosted git repository. indhub pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git
The following commit(s) were added to refs/heads/master by this push: new 9d4bb9c Minor changes to Caffe Translator (#8939) 9d4bb9c is described below commit 9d4bb9c53c00d3b8f8ddd1e9e92ac8cbeb885111 Author: Indhu Bharathi <indhubhara...@gmail.com> AuthorDate: Fri Dec 8 16:28:17 2017 -0800 Minor changes to Caffe Translator (#8939) * - Add license to string template files. - Add license to gradlew - Some bug fixes and refactoring for optimizer generation. - Language change in comment that goes into generated code. - Don't generate CaffeLoss layer for Accuracy layer. It is now being translated to MXNet Accuracy metrics. - Minor bug fix in searching for the correct optimizer template. - Bump version up to 0.9.2 * Add license for Optimizer.java * Code cleanup. --- tools/caffe_translator/build.gradle | 2 +- tools/caffe_translator/gradlew | 17 +++++ .../java/io/mxnet/caffetranslator/Converter.java | 75 +++++++--------------- .../java/io/mxnet/caffetranslator/Optimizer.java | 48 ++++++++++++++ .../main/java/io/mxnet/caffetranslator/Solver.java | 52 ++++++++++++++- .../src/main/resources/templates/accuracy.st | 18 ++++++ .../src/main/resources/templates/activation.st | 18 ++++++ .../src/main/resources/templates/add.st | 18 ++++++ .../src/main/resources/templates/batchnorm.st | 18 ++++++ .../src/main/resources/templates/concat.st | 18 ++++++ .../src/main/resources/templates/convolution.st | 18 ++++++ .../src/main/resources/templates/deconvolution.st | 18 ++++++ .../src/main/resources/templates/dropout.st | 18 ++++++ .../src/main/resources/templates/fc.st | 18 ++++++ .../src/main/resources/templates/flatten.st | 18 ++++++ .../src/main/resources/templates/group.st | 18 ++++++ .../src/main/resources/templates/imports.st | 18 ++++++ .../src/main/resources/templates/init_params.st | 18 ++++++ .../src/main/resources/templates/iterator.st | 18 ++++++ .../src/main/resources/templates/logging.st | 18 ++++++ .../src/main/resources/templates/lrn.st | 18 ++++++ 
.../src/main/resources/templates/lrpolicy_exp.st | 18 ++++++ .../src/main/resources/templates/lrpolicy_inv.st | 18 ++++++ .../main/resources/templates/lrpolicy_multistep.st | 18 ++++++ .../src/main/resources/templates/lrpolicy_poly.st | 18 ++++++ .../main/resources/templates/lrpolicy_sigmoid.st | 18 ++++++ .../src/main/resources/templates/lrpolicy_step.st | 18 ++++++ .../src/main/resources/templates/maxium.st | 18 ++++++ .../main/resources/templates/metrics_classes.st | 27 ++++++-- .../src/main/resources/templates/mul.st | 18 ++++++ .../src/main/resources/templates/opt_adadelta.st | 32 +++++++++ .../src/main/resources/templates/opt_adagrad.st | 28 ++++++++ .../src/main/resources/templates/opt_adam.st | 36 +++++++++++ .../src/main/resources/templates/opt_default.st | 15 ----- .../src/main/resources/templates/opt_nesterov.st | 28 ++++++++ .../src/main/resources/templates/opt_rmsprop.st | 32 +++++++++ .../src/main/resources/templates/opt_sgd.st | 36 ++++++++--- .../src/main/resources/templates/opt_vars.st | 24 +++++++ .../main/resources/templates/param_initializer.st | 18 ++++++ .../src/main/resources/templates/params_loader.st | 18 ++++++ .../src/main/resources/templates/permute.st | 18 ++++++ .../src/main/resources/templates/pooling.st | 18 ++++++ .../src/main/resources/templates/power.st | 18 ++++++ .../src/main/resources/templates/runner.st | 18 ++++++ .../src/main/resources/templates/softmaxoutput.st | 18 ++++++ .../src/main/resources/templates/symbols.stg | 18 ++++++ .../src/main/resources/templates/top_k_accuracy.st | 18 ++++++ .../src/main/resources/templates/var.st | 18 ++++++ 48 files changed, 979 insertions(+), 85 deletions(-) diff --git a/tools/caffe_translator/build.gradle b/tools/caffe_translator/build.gradle index 4206767..da5e900 100644 --- a/tools/caffe_translator/build.gradle +++ b/tools/caffe_translator/build.gradle @@ -10,7 +10,7 @@ apply plugin: 'maven' apply plugin: 'signing' group 'org.caffetranslator' -version '0.9.1' +version '0.9.2' def 
isReleaseBuild def repositoryUrl diff --git a/tools/caffe_translator/gradlew b/tools/caffe_translator/gradlew index cccdd3d..07cc915 100755 --- a/tools/caffe_translator/gradlew +++ b/tools/caffe_translator/gradlew @@ -1,5 +1,22 @@ #!/usr/bin/env sh +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + ############################################################################## ## ## Gradle start up script for UN*X diff --git a/tools/caffe_translator/src/main/java/io/mxnet/caffetranslator/Converter.java b/tools/caffe_translator/src/main/java/io/mxnet/caffetranslator/Converter.java index 90ed9d2..96d6fec 100644 --- a/tools/caffe_translator/src/main/java/io/mxnet/caffetranslator/Converter.java +++ b/tools/caffe_translator/src/main/java/io/mxnet/caffetranslator/Converter.java @@ -154,22 +154,33 @@ public class Converter { Layer layer = layers.get(layerIndex); SymbolGenerator generator = generators.getGenerator(layer.getType()); - // If the translator cannot translate this layer to an MXNet layer, - // use CaffeOp or CaffeLoss instead. 
+ // Handle layers for which there is no Generator if (generator == null) { - if (layer.getType().toLowerCase().endsWith("loss")) { + if (layer.getType().equalsIgnoreCase("Accuracy")) { + // We handle accuracy layers at a later stage. Do nothing for now. + } else if (layer.getType().toLowerCase().endsWith("loss")) { + // This is a loss layer we don't have a generator for. Wrap it in CaffeLoss. generator = generators.getGenerator("CaffePluginLossLayer"); } else { + // This is a layer we don't have a generator for. Wrap it in CaffeOp. generator = generators.getGenerator("PluginIntLayerGenerator"); } } - GeneratorOutput out = generator.generate(layer, mlModel); - String segment = out.code; - code.append(segment); - code.append(NL); - - layerIndex += out.numLayersTranslated; + if (generator != null) { // If we have a generator + // Generate code + GeneratorOutput out = generator.generate(layer, mlModel); + String segment = out.code; + code.append(segment); + code.append(NL); + + // Update layerIndex depending on how many layers we ended up translating + layerIndex += out.numLayersTranslated; + } else { // If we don't have a generator + // We've decided to skip this layer. Generate no code. Just increment layerIndex + // by 1 and move on to the next layer. 
+ layerIndex++; + } } String loss = getLoss(mlModel, code); @@ -304,50 +315,8 @@ public class Converter { } private String generateOptimizer() { - String caffeOptimizer = solver.getProperty("type", "sgd").toLowerCase(); - ST st; - - String lr = solver.getProperty("base_lr"); - String momentum = solver.getProperty("momentum", "0.9"); - String wd = solver.getProperty("weight_decay", "0.0005"); - - switch (caffeOptimizer) { - case "adadelta": - st = gh.getTemplate("opt_default"); - st.add("opt_name", "AdaDelta"); - st.add("epsilon", solver.getProperty("delta")); - break; - case "adagrad": - st = gh.getTemplate("opt_default"); - st.add("opt_name", "AdaGrad"); - break; - case "adam": - st = gh.getTemplate("opt_default"); - st.add("opt_name", "Adam"); - break; - case "nesterov": - st = gh.getTemplate("opt_sgd"); - st.add("opt_name", "NAG"); - st.add("momentum", momentum); - break; - case "rmsprop": - st = gh.getTemplate("opt_default"); - st.add("opt_name", "RMSProp"); - break; - default: - if (!caffeOptimizer.equals("sgd")) { - System.err.println("Unknown optimizer. Will use SGD instead."); - } - - st = gh.getTemplate("opt_sgd"); - st.add("opt_name", "SGD"); - st.add("momentum", momentum); - break; - } - st.add("lr", lr); - st.add("wd", wd); - - return st.render(); + Optimizer optimizer = new Optimizer(solver); + return optimizer.generateInitCode(); } private String generateInitializer() { diff --git a/tools/caffe_translator/src/main/java/io/mxnet/caffetranslator/Optimizer.java b/tools/caffe_translator/src/main/java/io/mxnet/caffetranslator/Optimizer.java new file mode 100644 index 0000000..da24942 --- /dev/null +++ b/tools/caffe_translator/src/main/java/io/mxnet/caffetranslator/Optimizer.java @@ -0,0 +1,48 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. 
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file Optimizer.java + * \brief Generates optimizer from solver prototxt + */ + +package io.mxnet.caffetranslator; + +import org.stringtemplate.v4.ST; + +public class Optimizer { + private final GenerationHelper gh; + private final Solver solver; + + public Optimizer(Solver solver) { + this.gh = new GenerationHelper(); + this.solver = solver; + } + + public String generateInitCode() { + ST st = gh.getTemplate("opt_" + solver.getType().toLowerCase()); + if (st == null) { + System.err.println(String.format("Unknown optimizer type (%s). 
Using SGD instead.", solver.getType())); + st = gh.getTemplate("opt_sgd"); + } + + st.add("solver", solver); + return st.render(); + } +} diff --git a/tools/caffe_translator/src/main/java/io/mxnet/caffetranslator/Solver.java b/tools/caffe_translator/src/main/java/io/mxnet/caffetranslator/Solver.java index ec4c812..9693771 100644 --- a/tools/caffe_translator/src/main/java/io/mxnet/caffetranslator/Solver.java +++ b/tools/caffe_translator/src/main/java/io/mxnet/caffetranslator/Solver.java @@ -24,6 +24,7 @@ package io.mxnet.caffetranslator; +import lombok.Getter; import org.antlr.v4.runtime.CharStream; import org.antlr.v4.runtime.CharStreams; import org.antlr.v4.runtime.CommonTokenStream; @@ -31,6 +32,7 @@ import org.antlr.v4.runtime.CommonTokenStream; import java.io.File; import java.io.FileInputStream; import java.io.IOException; +import java.lang.reflect.Field; import java.nio.charset.StandardCharsets; import java.util.HashMap; import java.util.List; @@ -38,9 +40,18 @@ import java.util.Map; public class Solver { + private final String solverPath; private boolean parseDone; private Map<String, List<String>> properties; - private final String solverPath; + /** + * Fields corresponding to keys that can be present in the solver prototxt. 'setFields' sets these + * using reflection after parsing the solver prototxt. A solver object is passed to string templates + * and the templates read these fields. 
+ */ + @Getter + private String base_lr, momentum, weight_decay, lr_policy, gamma, stepsize, stepvalue, max_iter, + solver_mode, snapshot, snapshot_prefix, test_iter, test_interval, display, type, delta, + momentum2, rms_decay, solver_type; public Solver(String solverPath) { this.solverPath = solverPath; @@ -67,10 +78,49 @@ public class Solver { properties = solverListener.getProperties(); + setFields(properties); + parseDone = true; return true; } + private void setFields(Map<String, List<String>> properties) { + Class<?> cls = getClass(); + + for (Map.Entry<String, List<String>> entry : properties.entrySet()) { + String key = entry.getKey(); + try { + Field field = cls.getDeclaredField(key); + field.set(this, entry.getValue().get(0)); + } catch (NoSuchFieldException e) { + // Just ignore + } catch (IllegalAccessException e) { + /** + * This shouldn't happen. If it does happen because we overlooked something, print + * it in the console so we can investigate it. + */ + e.printStackTrace(); + } + } + + setDefaults(); + } + + private void setDefaults() { + if (type == null) { + type = "SGD"; + } + if (delta == null) { + delta = "1e-8"; + } + if (momentum2 == null) { + momentum2 = "0.999"; + } + if (rms_decay == null) { + rms_decay = "0.99"; + } + } + public String getProperty(String key) { List<String> list = getProperties(key); if (list == null) { diff --git a/tools/caffe_translator/src/main/resources/templates/accuracy.st b/tools/caffe_translator/src/main/resources/templates/accuracy.st index f741def..cbe15f6 100644 --- a/tools/caffe_translator/src/main/resources/templates/accuracy.st +++ b/tools/caffe_translator/src/main/resources/templates/accuracy.st @@ -1,2 +1,20 @@ +<! + Licensed to the Apache Software Foundation (ASF) under one + or more contributor license agreements. See the NOTICE file + distributed with this work for additional information + regarding copyright ownership. 
The ASF licenses this file + to you under the Apache License, Version 2.0 (the + "License"); you may not use this file except in compliance + with the License. You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, + software distributed under the License is distributed on an + "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + KIND, either express or implied. See the License for the + specific language governing permissions and limitations + under the License. +!> <var> = mx.metric.Accuracy(output_names=['<output_name>'], label_names=['<label_name>'], name='<name>') test_metrics.add(<var>) diff --git a/tools/caffe_translator/src/main/resources/templates/activation.st b/tools/caffe_translator/src/main/resources/templates/activation.st index 5a9c37b..042c2e3 100644 --- a/tools/caffe_translator/src/main/resources/templates/activation.st +++ b/tools/caffe_translator/src/main/resources/templates/activation.st @@ -1 +1,19 @@ +<! + Licensed to the Apache Software Foundation (ASF) under one + or more contributor license agreements. See the NOTICE file + distributed with this work for additional information + regarding copyright ownership. The ASF licenses this file + to you under the Apache License, Version 2.0 (the + "License"); you may not use this file except in compliance + with the License. You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, + software distributed under the License is distributed on an + "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + KIND, either express or implied. See the License for the + specific language governing permissions and limitations + under the License. 
+!> <var> = mx.symbol.Activation(data=<data>, act_type='<type>', name='<name>') diff --git a/tools/caffe_translator/src/main/resources/templates/add.st b/tools/caffe_translator/src/main/resources/templates/add.st index ca9428f..738ac3e 100644 --- a/tools/caffe_translator/src/main/resources/templates/add.st +++ b/tools/caffe_translator/src/main/resources/templates/add.st @@ -1 +1,19 @@ +<! + Licensed to the Apache Software Foundation (ASF) under one + or more contributor license agreements. See the NOTICE file + distributed with this work for additional information + regarding copyright ownership. The ASF licenses this file + to you under the Apache License, Version 2.0 (the + "License"); you may not use this file except in compliance + with the License. You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, + software distributed under the License is distributed on an + "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + KIND, either express or implied. See the License for the + specific language governing permissions and limitations + under the License. +!> <var> = <data1> + <data2> diff --git a/tools/caffe_translator/src/main/resources/templates/batchnorm.st b/tools/caffe_translator/src/main/resources/templates/batchnorm.st index c043c70..7f2326d 100644 --- a/tools/caffe_translator/src/main/resources/templates/batchnorm.st +++ b/tools/caffe_translator/src/main/resources/templates/batchnorm.st @@ -1,3 +1,21 @@ +<! + Licensed to the Apache Software Foundation (ASF) under one + or more contributor license agreements. See the NOTICE file + distributed with this work for additional information + regarding copyright ownership. The ASF licenses this file + to you under the Apache License, Version 2.0 (the + "License"); you may not use this file except in compliance + with the License. 
You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, + software distributed under the License is distributed on an + "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + KIND, either express or implied. See the License for the + specific language governing permissions and limitations + under the License. +!> <if(fix_beta)> <var>_beta = mx.sym.BlockGrad(mx.sym.Variable("<name>_beta", init=mx.init.Constant(0))) <endif> diff --git a/tools/caffe_translator/src/main/resources/templates/concat.st b/tools/caffe_translator/src/main/resources/templates/concat.st index 75ffa3c..3f33275 100644 --- a/tools/caffe_translator/src/main/resources/templates/concat.st +++ b/tools/caffe_translator/src/main/resources/templates/concat.st @@ -1 +1,19 @@ +<! + Licensed to the Apache Software Foundation (ASF) under one + or more contributor license agreements. See the NOTICE file + distributed with this work for additional information + regarding copyright ownership. The ASF licenses this file + to you under the Apache License, Version 2.0 (the + "License"); you may not use this file except in compliance + with the License. You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, + software distributed under the License is distributed on an + "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + KIND, either express or implied. See the License for the + specific language governing permissions and limitations + under the License. 
+!> <var> = mx.sym.concat(<data;separator=", "><if(dim)>, dim=<dim><endif>, name='<name>'); diff --git a/tools/caffe_translator/src/main/resources/templates/convolution.st b/tools/caffe_translator/src/main/resources/templates/convolution.st index c4bdd51..c167217 100644 --- a/tools/caffe_translator/src/main/resources/templates/convolution.st +++ b/tools/caffe_translator/src/main/resources/templates/convolution.st @@ -1,3 +1,21 @@ +<! + Licensed to the Apache Software Foundation (ASF) under one + or more contributor license agreements. See the NOTICE file + distributed with this work for additional information + regarding copyright ownership. The ASF licenses this file + to you under the Apache License, Version 2.0 (the + "License"); you may not use this file except in compliance + with the License. You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, + software distributed under the License is distributed on an + "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + KIND, either express or implied. See the License for the + specific language governing permissions and limitations + under the License. +!> <var> = mx.sym.Convolution(data=<data>, <if(weight)>weight=<weight>,<endif> <if(bias)>bias=<bias>,<endif> diff --git a/tools/caffe_translator/src/main/resources/templates/deconvolution.st b/tools/caffe_translator/src/main/resources/templates/deconvolution.st index 5b63f56..67483b9 100644 --- a/tools/caffe_translator/src/main/resources/templates/deconvolution.st +++ b/tools/caffe_translator/src/main/resources/templates/deconvolution.st @@ -1,3 +1,21 @@ +<! + Licensed to the Apache Software Foundation (ASF) under one + or more contributor license agreements. See the NOTICE file + distributed with this work for additional information + regarding copyright ownership. 
The ASF licenses this file + to you under the Apache License, Version 2.0 (the + "License"); you may not use this file except in compliance + with the License. You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, + software distributed under the License is distributed on an + "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + KIND, either express or implied. See the License for the + specific language governing permissions and limitations + under the License. +!> <var> = mx.sym.Deconvolution(data=<data>, <if(use_weight)>weight=weight,<endif> <if(use_bias)>bias=bias,<endif> diff --git a/tools/caffe_translator/src/main/resources/templates/dropout.st b/tools/caffe_translator/src/main/resources/templates/dropout.st index 9791c09..ed28dc7 100644 --- a/tools/caffe_translator/src/main/resources/templates/dropout.st +++ b/tools/caffe_translator/src/main/resources/templates/dropout.st @@ -1 +1,19 @@ +<! + Licensed to the Apache Software Foundation (ASF) under one + or more contributor license agreements. See the NOTICE file + distributed with this work for additional information + regarding copyright ownership. The ASF licenses this file + to you under the Apache License, Version 2.0 (the + "License"); you may not use this file except in compliance + with the License. You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, + software distributed under the License is distributed on an + "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + KIND, either express or implied. See the License for the + specific language governing permissions and limitations + under the License. 
+!> <var> = mx.sym.Dropout(data=<data>, p=<prob>, name='<name>') diff --git a/tools/caffe_translator/src/main/resources/templates/fc.st b/tools/caffe_translator/src/main/resources/templates/fc.st index 22365b3..353b424 100644 --- a/tools/caffe_translator/src/main/resources/templates/fc.st +++ b/tools/caffe_translator/src/main/resources/templates/fc.st @@ -1 +1,19 @@ +<! + Licensed to the Apache Software Foundation (ASF) under one + or more contributor license agreements. See the NOTICE file + distributed with this work for additional information + regarding copyright ownership. The ASF licenses this file + to you under the Apache License, Version 2.0 (the + "License"); you may not use this file except in compliance + with the License. You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, + software distributed under the License is distributed on an + "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + KIND, either express or implied. See the License for the + specific language governing permissions and limitations + under the License. +!> <var> = mx.symbol.FullyConnected(data=<data>, <if(weight)>weight=<weight>, <endif><if(bias)>bias=<bias>, <endif>num_hidden=<num>, <if(no_bias)>no_bias=True, <endif>name='<name>') diff --git a/tools/caffe_translator/src/main/resources/templates/flatten.st b/tools/caffe_translator/src/main/resources/templates/flatten.st index 8434335..2ee6ffa 100644 --- a/tools/caffe_translator/src/main/resources/templates/flatten.st +++ b/tools/caffe_translator/src/main/resources/templates/flatten.st @@ -1 +1,19 @@ +<! + Licensed to the Apache Software Foundation (ASF) under one + or more contributor license agreements. See the NOTICE file + distributed with this work for additional information + regarding copyright ownership. 
The ASF licenses this file + to you under the Apache License, Version 2.0 (the + "License"); you may not use this file except in compliance + with the License. You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, + software distributed under the License is distributed on an + "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + KIND, either express or implied. See the License for the + specific language governing permissions and limitations + under the License. +!> <var> = mx.sym.flatten(data=<data>, name='<name>') diff --git a/tools/caffe_translator/src/main/resources/templates/group.st b/tools/caffe_translator/src/main/resources/templates/group.st index 33e312f..9cadf65 100644 --- a/tools/caffe_translator/src/main/resources/templates/group.st +++ b/tools/caffe_translator/src/main/resources/templates/group.st @@ -1 +1,19 @@ +<! + Licensed to the Apache Software Foundation (ASF) under one + or more contributor license agreements. See the NOTICE file + distributed with this work for additional information + regarding copyright ownership. The ASF licenses this file + to you under the Apache License, Version 2.0 (the + "License"); you may not use this file except in compliance + with the License. You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, + software distributed under the License is distributed on an + "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + KIND, either express or implied. See the License for the + specific language governing permissions and limitations + under the License. 
+!> <var> = mx.sym.Group([<symbols;separator=", ">]); diff --git a/tools/caffe_translator/src/main/resources/templates/imports.st b/tools/caffe_translator/src/main/resources/templates/imports.st index b37bd33..da03a64 100644 --- a/tools/caffe_translator/src/main/resources/templates/imports.st +++ b/tools/caffe_translator/src/main/resources/templates/imports.st @@ -1,3 +1,21 @@ +<! + Licensed to the Apache Software Foundation (ASF) under one + or more contributor license agreements. See the NOTICE file + distributed with this work for additional information + regarding copyright ownership. The ASF licenses this file + to you under the Apache License, Version 2.0 (the + "License"); you may not use this file except in compliance + with the License. You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, + software distributed under the License is distributed on an + "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + KIND, either express or implied. See the License for the + specific language governing permissions and limitations + under the License. +!> from __future__ import division import copy import logging diff --git a/tools/caffe_translator/src/main/resources/templates/init_params.st b/tools/caffe_translator/src/main/resources/templates/init_params.st index 3a277b6..7c8d7b0 100644 --- a/tools/caffe_translator/src/main/resources/templates/init_params.st +++ b/tools/caffe_translator/src/main/resources/templates/init_params.st @@ -1,3 +1,21 @@ +<! + Licensed to the Apache Software Foundation (ASF) under one + or more contributor license agreements. See the NOTICE file + distributed with this work for additional information + regarding copyright ownership. The ASF licenses this file + to you under the Apache License, Version 2.0 (the + "License"); you may not use this file except in compliance + with the License. 
You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, + software distributed under the License is distributed on an + "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + KIND, either express or implied. See the License for the + specific language governing permissions and limitations + under the License. +!> <if(params_file)> arg_params, aux_params = load_params('<params_file>') module.init_params(initializer=mx.init.Xavier(), arg_params=arg_params, aux_params=aux_params, diff --git a/tools/caffe_translator/src/main/resources/templates/iterator.st b/tools/caffe_translator/src/main/resources/templates/iterator.st index 5bc2a9d..d608979 100644 --- a/tools/caffe_translator/src/main/resources/templates/iterator.st +++ b/tools/caffe_translator/src/main/resources/templates/iterator.st @@ -1,3 +1,21 @@ +<! + Licensed to the Apache Software Foundation (ASF) under one + or more contributor license agreements. See the NOTICE file + distributed with this work for additional information + regarding copyright ownership. The ASF licenses this file + to you under the Apache License, Version 2.0 (the + "License"); you may not use this file except in compliance + with the License. You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, + software distributed under the License is distributed on an + "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + KIND, either express or implied. See the License for the + specific language governing permissions and limitations + under the License. 
+!> <iter_name> = mx.io.CaffeDataIter( prototxt = <prototxt>, diff --git a/tools/caffe_translator/src/main/resources/templates/logging.st b/tools/caffe_translator/src/main/resources/templates/logging.st index 73785e5..cc94872 100644 --- a/tools/caffe_translator/src/main/resources/templates/logging.st +++ b/tools/caffe_translator/src/main/resources/templates/logging.st @@ -1,3 +1,21 @@ +<! + Licensed to the Apache Software Foundation (ASF) under one + or more contributor license agreements. See the NOTICE file + distributed with this work for additional information + regarding copyright ownership. The ASF licenses this file + to you under the Apache License, Version 2.0 (the + "License"); you may not use this file except in compliance + with the License. You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, + software distributed under the License is distributed on an + "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + KIND, either express or implied. See the License for the + specific language governing permissions and limitations + under the License. +!> def get_logger(name): formatter = logging.Formatter(fmt='%(asctime)s %(levelname)-8s %(message)s', datefmt='%Y-%m-%d %H:%M:%S') diff --git a/tools/caffe_translator/src/main/resources/templates/lrn.st b/tools/caffe_translator/src/main/resources/templates/lrn.st index ec003c1..b679898 100644 --- a/tools/caffe_translator/src/main/resources/templates/lrn.st +++ b/tools/caffe_translator/src/main/resources/templates/lrn.st @@ -1 +1,19 @@ +<! + Licensed to the Apache Software Foundation (ASF) under one + or more contributor license agreements. See the NOTICE file + distributed with this work for additional information + regarding copyright ownership. The ASF licenses this file + to you under the Apache License, Version 2.0 (the + "License"); you may not use this file except in compliance + with the License. 
You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, + software distributed under the License is distributed on an + "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + KIND, either express or implied. See the License for the + specific language governing permissions and limitations + under the License. +!> <var> = mx.sym.LRN(data=<data>, alpha=<alpha>, beta=<beta>, knorm=<knorm>, nsize=<nsize>, name=<name>) diff --git a/tools/caffe_translator/src/main/resources/templates/lrpolicy_exp.st b/tools/caffe_translator/src/main/resources/templates/lrpolicy_exp.st index 43afca2..03daae3 100644 --- a/tools/caffe_translator/src/main/resources/templates/lrpolicy_exp.st +++ b/tools/caffe_translator/src/main/resources/templates/lrpolicy_exp.st @@ -1,3 +1,21 @@ +<! + Licensed to the Apache Software Foundation (ASF) under one + or more contributor license agreements. See the NOTICE file + distributed with this work for additional information + regarding copyright ownership. The ASF licenses this file + to you under the Apache License, Version 2.0 (the + "License"); you may not use this file except in compliance + with the License. You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, + software distributed under the License is distributed on an + "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + KIND, either express or implied. See the License for the + specific language governing permissions and limitations + under the License. 
+!> lr = optimizer_params['learning_rate'] lr *= gamma optimizer_params['learning_rate'] = lr diff --git a/tools/caffe_translator/src/main/resources/templates/lrpolicy_inv.st b/tools/caffe_translator/src/main/resources/templates/lrpolicy_inv.st index 5da8aa6..e62c2d3 100644 --- a/tools/caffe_translator/src/main/resources/templates/lrpolicy_inv.st +++ b/tools/caffe_translator/src/main/resources/templates/lrpolicy_inv.st @@ -1,3 +1,21 @@ +<! + Licensed to the Apache Software Foundation (ASF) under one + or more contributor license agreements. See the NOTICE file + distributed with this work for additional information + regarding copyright ownership. The ASF licenses this file + to you under the Apache License, Version 2.0 (the + "License"); you may not use this file except in compliance + with the License. You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, + software distributed under the License is distributed on an + "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + KIND, either express or implied. See the License for the + specific language governing permissions and limitations + under the License. +!> lr = optimizer_params['learning_rate'] lr = base_lr * math.pow((1 + gamma * batch_num), -power) optimizer_params['learning_rate'] = lr diff --git a/tools/caffe_translator/src/main/resources/templates/lrpolicy_multistep.st b/tools/caffe_translator/src/main/resources/templates/lrpolicy_multistep.st index fe09301..0761908 100644 --- a/tools/caffe_translator/src/main/resources/templates/lrpolicy_multistep.st +++ b/tools/caffe_translator/src/main/resources/templates/lrpolicy_multistep.st @@ -1,3 +1,21 @@ +<! + Licensed to the Apache Software Foundation (ASF) under one + or more contributor license agreements. See the NOTICE file + distributed with this work for additional information + regarding copyright ownership. 
The ASF licenses this file + to you under the Apache License, Version 2.0 (the + "License"); you may not use this file except in compliance + with the License. You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, + software distributed under the License is distributed on an + "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + KIND, either express or implied. See the License for the + specific language governing permissions and limitations + under the License. +!> lr_update_steps = [<steps;separator=", ">] if(batch_num in lr_update_steps): lr = optimizer_params['learning_rate'] diff --git a/tools/caffe_translator/src/main/resources/templates/lrpolicy_poly.st b/tools/caffe_translator/src/main/resources/templates/lrpolicy_poly.st index e43fd78..d62c64b 100644 --- a/tools/caffe_translator/src/main/resources/templates/lrpolicy_poly.st +++ b/tools/caffe_translator/src/main/resources/templates/lrpolicy_poly.st @@ -1,3 +1,21 @@ +<! + Licensed to the Apache Software Foundation (ASF) under one + or more contributor license agreements. See the NOTICE file + distributed with this work for additional information + regarding copyright ownership. The ASF licenses this file + to you under the Apache License, Version 2.0 (the + "License"); you may not use this file except in compliance + with the License. You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, + software distributed under the License is distributed on an + "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + KIND, either express or implied. See the License for the + specific language governing permissions and limitations + under the License. 
+!> lr = optimizer_params['learning_rate'] lr = math.pow(base_lr * (1 - batch_num/max_iter), power) optimizer_params['learning_rate'] = lr diff --git a/tools/caffe_translator/src/main/resources/templates/lrpolicy_sigmoid.st b/tools/caffe_translator/src/main/resources/templates/lrpolicy_sigmoid.st index 33ba055..f44ab5a 100644 --- a/tools/caffe_translator/src/main/resources/templates/lrpolicy_sigmoid.st +++ b/tools/caffe_translator/src/main/resources/templates/lrpolicy_sigmoid.st @@ -1,3 +1,21 @@ +<! + Licensed to the Apache Software Foundation (ASF) under one + or more contributor license agreements. See the NOTICE file + distributed with this work for additional information + regarding copyright ownership. The ASF licenses this file + to you under the Apache License, Version 2.0 (the + "License"); you may not use this file except in compliance + with the License. You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, + software distributed under the License is distributed on an + "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + KIND, either express or implied. See the License for the + specific language governing permissions and limitations + under the License. +!> lr = optimizer_params['learning_rate'] lr = base_lr * ( 1/(1 + math.exp(-gamma * (batch_num - stepsize)))) optimizer_params['learning_rate'] = lr diff --git a/tools/caffe_translator/src/main/resources/templates/lrpolicy_step.st b/tools/caffe_translator/src/main/resources/templates/lrpolicy_step.st index 04468ae..1f3d975 100644 --- a/tools/caffe_translator/src/main/resources/templates/lrpolicy_step.st +++ b/tools/caffe_translator/src/main/resources/templates/lrpolicy_step.st @@ -1,3 +1,21 @@ +<! + Licensed to the Apache Software Foundation (ASF) under one + or more contributor license agreements. See the NOTICE file + distributed with this work for additional information + regarding copyright ownership. 
The ASF licenses this file + to you under the Apache License, Version 2.0 (the + "License"); you may not use this file except in compliance + with the License. You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, + software distributed under the License is distributed on an + "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + KIND, either express or implied. See the License for the + specific language governing permissions and limitations + under the License. +!> if(batch_num % stepsize == 0): lr = optimizer_params['learning_rate'] lr *= gamma diff --git a/tools/caffe_translator/src/main/resources/templates/maxium.st b/tools/caffe_translator/src/main/resources/templates/maxium.st index d9431dd..9b18246 100644 --- a/tools/caffe_translator/src/main/resources/templates/maxium.st +++ b/tools/caffe_translator/src/main/resources/templates/maxium.st @@ -1 +1,19 @@ +<! + Licensed to the Apache Software Foundation (ASF) under one + or more contributor license agreements. See the NOTICE file + distributed with this work for additional information + regarding copyright ownership. The ASF licenses this file + to you under the Apache License, Version 2.0 (the + "License"); you may not use this file except in compliance + with the License. You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, + software distributed under the License is distributed on an + "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + KIND, either express or implied. See the License for the + specific language governing permissions and limitations + under the License. 
+!> <var> = mx.sym.maximum(<data1>, <data2>) diff --git a/tools/caffe_translator/src/main/resources/templates/metrics_classes.st b/tools/caffe_translator/src/main/resources/templates/metrics_classes.st index e8323fb..e586616 100644 --- a/tools/caffe_translator/src/main/resources/templates/metrics_classes.st +++ b/tools/caffe_translator/src/main/resources/templates/metrics_classes.st @@ -1,3 +1,21 @@ +<! + Licensed to the Apache Software Foundation (ASF) under one + or more contributor license agreements. See the NOTICE file + distributed with this work for additional information + regarding copyright ownership. The ASF licenses this file + to you under the Apache License, Version 2.0 (the + "License"); you may not use this file except in compliance + with the License. You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, + software distributed under the License is distributed on an + "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + KIND, either express or implied. See the License for the + specific language governing permissions and limitations + under the License. +!> class TrainMetrics(): metric_map = {} @@ -16,17 +34,16 @@ class TrainMetrics(): self.update_metrics(module, label, reset=True) self.print_metrics(batch_num) else: - # If I'll have to print metrics 'average_loss' iterations from now, - # append a metric so I can start updating that. + # Metrics must be printed 'average_loss' iterations from now. + # Append a metric which will get updated starting now. if((batch_num + self.average_loss) % self.display == 0): self.append_one() - # If I'm less than 'average_loss' iteration away from a display step, - # update the metrics. + # Less than 'average_loss' iterations away from a display step. Update metrics.
if((batch_num + self.average_loss) % self.display \< self.average_loss): self.update_metrics(module, label) - # If I'm at a display step, print the metrics. + # At display step. Print metrics. if(batch_num % self.display == 0): self.print_metrics(batch_num, remove_heads=True) diff --git a/tools/caffe_translator/src/main/resources/templates/mul.st b/tools/caffe_translator/src/main/resources/templates/mul.st index 411a407..59c4837 100644 --- a/tools/caffe_translator/src/main/resources/templates/mul.st +++ b/tools/caffe_translator/src/main/resources/templates/mul.st @@ -1 +1,19 @@ +<! + Licensed to the Apache Software Foundation (ASF) under one + or more contributor license agreements. See the NOTICE file + distributed with this work for additional information + regarding copyright ownership. The ASF licenses this file + to you under the Apache License, Version 2.0 (the + "License"); you may not use this file except in compliance + with the License. You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, + software distributed under the License is distributed on an + "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + KIND, either express or implied. See the License for the + specific language governing permissions and limitations + under the License. +!> <var> = <data1> * (<data2>) diff --git a/tools/caffe_translator/src/main/resources/templates/opt_adadelta.st b/tools/caffe_translator/src/main/resources/templates/opt_adadelta.st new file mode 100644 index 0000000..cfd465b --- /dev/null +++ b/tools/caffe_translator/src/main/resources/templates/opt_adadelta.st @@ -0,0 +1,32 @@ +<! + Licensed to the Apache Software Foundation (ASF) under one + or more contributor license agreements. See the NOTICE file + distributed with this work for additional information + regarding copyright ownership. 
The ASF licenses this file + to you under the Apache License, Version 2.0 (the + "License"); you may not use this file except in compliance + with the License. You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, + software distributed under the License is distributed on an + "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + KIND, either express or implied. See the License for the + specific language governing permissions and limitations + under the License. +!> +<opt_vars(solver)> +<if(solver.momentum)> +rho = <solver.momentum> +<endif> +<if(solver.delta)> +epsilon = <solver.delta> +<endif> + +optimizer_params={'learning_rate':base_lr<\\> +<if(solver.wd)>, 'wd':wd<endif><\\> +<if(solver.momentum)>, 'rho':rho<endif><\\> +<if(solver.delta)>, 'epsilon':epsilon<endif>}<\\> + +module.init_optimizer(optimizer='AdaDelta', optimizer_params=optimizer_params) diff --git a/tools/caffe_translator/src/main/resources/templates/opt_adagrad.st b/tools/caffe_translator/src/main/resources/templates/opt_adagrad.st new file mode 100644 index 0000000..527cedf --- /dev/null +++ b/tools/caffe_translator/src/main/resources/templates/opt_adagrad.st @@ -0,0 +1,28 @@ +<! + Licensed to the Apache Software Foundation (ASF) under one + or more contributor license agreements. See the NOTICE file + distributed with this work for additional information + regarding copyright ownership. The ASF licenses this file + to you under the Apache License, Version 2.0 (the + "License"); you may not use this file except in compliance + with the License. You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, + software distributed under the License is distributed on an + "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + KIND, either express or implied. 
See the License for the + specific language governing permissions and limitations + under the License. +!> +<opt_vars(solver)> +<if(solver.delta)> +epsilon = <solver.delta> +<endif> + +optimizer_params={'learning_rate':base_lr<\\> +<if(solver.wd)>, 'wd':wd<endif><\\> +<if(solver.delta)>, 'epsilon':epsilon<endif>}<\\> + +module.init_optimizer(optimizer='AdaGrad', optimizer_params=optimizer_params) diff --git a/tools/caffe_translator/src/main/resources/templates/opt_adam.st b/tools/caffe_translator/src/main/resources/templates/opt_adam.st new file mode 100644 index 0000000..b0a8ca3 --- /dev/null +++ b/tools/caffe_translator/src/main/resources/templates/opt_adam.st @@ -0,0 +1,36 @@ +<! + Licensed to the Apache Software Foundation (ASF) under one + or more contributor license agreements. See the NOTICE file + distributed with this work for additional information + regarding copyright ownership. The ASF licenses this file + to you under the Apache License, Version 2.0 (the + "License"); you may not use this file except in compliance + with the License. You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, + software distributed under the License is distributed on an + "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + KIND, either express or implied. See the License for the + specific language governing permissions and limitations + under the License. 
+!> +<opt_vars(solver)> +<if(solver.momentum)> +beta1 = <solver.momentum> +<endif> +<if(solver.momentum2)> +beta2 = <solver.momentum2> +<endif> +<if(solver.delta)> +epsilon = <solver.delta> +<endif> + +optimizer_params={'learning_rate':base_lr<\\> +<if(solver.wd)>, 'wd':swd<endif><\\> +<if(solver.momentum)>, 'beta1':beta1<endif><\\> +<if(solver.momentum2)>, 'beta2':beta2<endif><\\> +<if(solver.delta)>, 'epsilon':epsilon<endif>}<\\> + +module.init_optimizer(optimizer='Adam', optimizer_params=optimizer_params) diff --git a/tools/caffe_translator/src/main/resources/templates/opt_default.st b/tools/caffe_translator/src/main/resources/templates/opt_default.st deleted file mode 100644 index e5a72ac..0000000 --- a/tools/caffe_translator/src/main/resources/templates/opt_default.st +++ /dev/null @@ -1,15 +0,0 @@ -<if(lr)> -base_lr = <lr> -<endif> -<if(momentum)> -momentum = <momentum> -<endif> -<if(wd)> -wd = <wd> -<endif> -<if(epsilon)> -epsilon = <epsilon> -<endif> - -optimizer_params={'learning_rate':base_lr <if(momentum)>, 'momentum':momentum<endif><if(wd)>, 'wd':wd<endif><if(epsilon)>, 'epsilon':epsilon<endif>} -module.init_optimizer(optimizer='<opt_name>', optimizer_params=optimizer_params) diff --git a/tools/caffe_translator/src/main/resources/templates/opt_nesterov.st b/tools/caffe_translator/src/main/resources/templates/opt_nesterov.st new file mode 100644 index 0000000..6262d48 --- /dev/null +++ b/tools/caffe_translator/src/main/resources/templates/opt_nesterov.st @@ -0,0 +1,28 @@ +<! + Licensed to the Apache Software Foundation (ASF) under one + or more contributor license agreements. See the NOTICE file + distributed with this work for additional information + regarding copyright ownership. The ASF licenses this file + to you under the Apache License, Version 2.0 (the + "License"); you may not use this file except in compliance + with the License. 
You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, + software distributed under the License is distributed on an + "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + KIND, either express or implied. See the License for the + specific language governing permissions and limitations + under the License. +!> +<opt_vars(solver)> +<if(solver.momentum)> +momentum = <solver.momentum> +<endif> + +optimizer_params={'learning_rate':base_lr<\\> +<if(solver.wd)>, 'wd':wd<endif><\\> +<if(solver.momentum)>, 'momentum':momentum<endif>}<\\> + +module.init_optimizer(optimizer='NAG', optimizer_params=optimizer_params) diff --git a/tools/caffe_translator/src/main/resources/templates/opt_rmsprop.st b/tools/caffe_translator/src/main/resources/templates/opt_rmsprop.st new file mode 100644 index 0000000..6baec42 --- /dev/null +++ b/tools/caffe_translator/src/main/resources/templates/opt_rmsprop.st @@ -0,0 +1,32 @@ +<! + Licensed to the Apache Software Foundation (ASF) under one + or more contributor license agreements. See the NOTICE file + distributed with this work for additional information + regarding copyright ownership. The ASF licenses this file + to you under the Apache License, Version 2.0 (the + "License"); you may not use this file except in compliance + with the License. You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, + software distributed under the License is distributed on an + "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + KIND, either express or implied. See the License for the + specific language governing permissions and limitations + under the License. 
+!> +<opt_vars(solver)> +<if(solver.rms_decay)> +gamma1 = <solver.rms_decay> +<endif> +<if(solver.delta)> +epsilon = <solver.delta> +<endif> + +optimizer_params={'learning_rate':base_lr<\\> +<if(solver.wd)>, 'wd':wd<endif><\\> +<if(solver.rms_decay)>, 'gamma1':gamma1<endif><\\> +<if(solver.delta)>, 'epsilon':epsilon<endif>}<\\> + +module.init_optimizer(optimizer='RMSProp', optimizer_params=optimizer_params) diff --git a/tools/caffe_translator/src/main/resources/templates/opt_sgd.st b/tools/caffe_translator/src/main/resources/templates/opt_sgd.st index 8a24e05..aa547a6 100644 --- a/tools/caffe_translator/src/main/resources/templates/opt_sgd.st +++ b/tools/caffe_translator/src/main/resources/templates/opt_sgd.st @@ -1,12 +1,28 @@ -<if(lr)> -base_lr = <lr> -<endif> -<if(momentum)> -momentum = <momentum> -<endif> -<if(wd)> -wd = <wd> +<! + Licensed to the Apache Software Foundation (ASF) under one + or more contributor license agreements. See the NOTICE file + distributed with this work for additional information + regarding copyright ownership. The ASF licenses this file + to you under the Apache License, Version 2.0 (the + "License"); you may not use this file except in compliance + with the License. You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, + software distributed under the License is distributed on an + "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + KIND, either express or implied. See the License for the + specific language governing permissions and limitations + under the License. 
+!> +<opt_vars(solver)> +<if(solver.momentum)> +momentum = <solver.momentum> <endif> -optimizer_params={'learning_rate':base_lr <if(momentum)>, 'momentum':momentum<endif><if(wd)>, 'wd':wd<endif>} -module.init_optimizer(optimizer='<opt_name>', optimizer_params=optimizer_params) +optimizer_params={'learning_rate':base_lr<\\> +<if(solver.wd)>, 'wd':wd<endif><\\> +<if(solver.momentum)>, 'momentum':momentum<endif>}<\\> + +module.init_optimizer(optimizer='SGD', optimizer_params=optimizer_params) diff --git a/tools/caffe_translator/src/main/resources/templates/opt_vars.st b/tools/caffe_translator/src/main/resources/templates/opt_vars.st new file mode 100644 index 0000000..19b2f4c --- /dev/null +++ b/tools/caffe_translator/src/main/resources/templates/opt_vars.st @@ -0,0 +1,24 @@ +<! + Licensed to the Apache Software Foundation (ASF) under one + or more contributor license agreements. See the NOTICE file + distributed with this work for additional information + regarding copyright ownership. The ASF licenses this file + to you under the Apache License, Version 2.0 (the + "License"); you may not use this file except in compliance + with the License. You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, + software distributed under the License is distributed on an + "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + KIND, either express or implied. See the License for the + specific language governing permissions and limitations + under the License. 
+!> +<if(solver.base_lr)> +base_lr = <solver.base_lr> +<endif> +<if(solver.wd)> +wd = <solver.wd> +<endif> \ No newline at end of file diff --git a/tools/caffe_translator/src/main/resources/templates/param_initializer.st b/tools/caffe_translator/src/main/resources/templates/param_initializer.st index b496fc3..abad5da 100644 --- a/tools/caffe_translator/src/main/resources/templates/param_initializer.st +++ b/tools/caffe_translator/src/main/resources/templates/param_initializer.st @@ -1,3 +1,21 @@ +<! + Licensed to the Apache Software Foundation (ASF) under one + or more contributor license agreements. See the NOTICE file + distributed with this work for additional information + regarding copyright ownership. The ASF licenses this file + to you under the Apache License, Version 2.0 (the + "License"); you may not use this file except in compliance + with the License. You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, + software distributed under the License is distributed on an + "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + KIND, either express or implied. See the License for the + specific language governing permissions and limitations + under the License. +!> class ParamInitializer(): lst_patterns = [] lst_initializers = [] diff --git a/tools/caffe_translator/src/main/resources/templates/params_loader.st b/tools/caffe_translator/src/main/resources/templates/params_loader.st index 22efec4..c124c98 100644 --- a/tools/caffe_translator/src/main/resources/templates/params_loader.st +++ b/tools/caffe_translator/src/main/resources/templates/params_loader.st @@ -1,3 +1,21 @@ +<! + Licensed to the Apache Software Foundation (ASF) under one + or more contributor license agreements. See the NOTICE file + distributed with this work for additional information + regarding copyright ownership. 
The ASF licenses this file + to you under the Apache License, Version 2.0 (the + "License"); you may not use this file except in compliance + with the License. You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, + software distributed under the License is distributed on an + "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + KIND, either express or implied. See the License for the + specific language governing permissions and limitations + under the License. +!> def load_params(params_file): save_dict = mx.nd.load(params_file) arg_params = {} diff --git a/tools/caffe_translator/src/main/resources/templates/permute.st b/tools/caffe_translator/src/main/resources/templates/permute.st index 2b06a76..9f94bdb 100644 --- a/tools/caffe_translator/src/main/resources/templates/permute.st +++ b/tools/caffe_translator/src/main/resources/templates/permute.st @@ -1 +1,19 @@ +<! + Licensed to the Apache Software Foundation (ASF) under one + or more contributor license agreements. See the NOTICE file + distributed with this work for additional information + regarding copyright ownership. The ASF licenses this file + to you under the Apache License, Version 2.0 (the + "License"); you may not use this file except in compliance + with the License. You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, + software distributed under the License is distributed on an + "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + KIND, either express or implied. See the License for the + specific language governing permissions and limitations + under the License. 
+!> <var> = mx.sym.transpose(data=<data>, axes=(<axes;separator=", ">), name='<name>') diff --git a/tools/caffe_translator/src/main/resources/templates/pooling.st b/tools/caffe_translator/src/main/resources/templates/pooling.st index 5389754..7aceffd 100644 --- a/tools/caffe_translator/src/main/resources/templates/pooling.st +++ b/tools/caffe_translator/src/main/resources/templates/pooling.st @@ -1,3 +1,21 @@ +<! + Licensed to the Apache Software Foundation (ASF) under one + or more contributor license agreements. See the NOTICE file + distributed with this work for additional information + regarding copyright ownership. The ASF licenses this file + to you under the Apache License, Version 2.0 (the + "License"); you may not use this file except in compliance + with the License. You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, + software distributed under the License is distributed on an + "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + KIND, either express or implied. See the License for the + specific language governing permissions and limitations + under the License. +!> <var> = mx.symbol.Pooling(data=<data>, pool_type='<type>', <if(global_pool)> diff --git a/tools/caffe_translator/src/main/resources/templates/power.st b/tools/caffe_translator/src/main/resources/templates/power.st index a512a67..7fe3ee8 100644 --- a/tools/caffe_translator/src/main/resources/templates/power.st +++ b/tools/caffe_translator/src/main/resources/templates/power.st @@ -1 +1,19 @@ +<! + Licensed to the Apache Software Foundation (ASF) under one + or more contributor license agreements. See the NOTICE file + distributed with this work for additional information + regarding copyright ownership. The ASF licenses this file + to you under the Apache License, Version 2.0 (the + "License"); you may not use this file except in compliance + with the License. 
You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, + software distributed under the License is distributed on an + "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + KIND, either express or implied. See the License for the + specific language governing permissions and limitations + under the License. +!> <var> = (<shift> + (<scale> * <data>)) ** <power> diff --git a/tools/caffe_translator/src/main/resources/templates/runner.st b/tools/caffe_translator/src/main/resources/templates/runner.st index 6df9671..8346ffe 100644 --- a/tools/caffe_translator/src/main/resources/templates/runner.st +++ b/tools/caffe_translator/src/main/resources/templates/runner.st @@ -1,3 +1,21 @@ +<! + Licensed to the Apache Software Foundation (ASF) under one + or more contributor license agreements. See the NOTICE file + distributed with this work for additional information + regarding copyright ownership. The ASF licenses this file + to you under the Apache License, Version 2.0 (the + "License"); you may not use this file except in compliance + with the License. You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, + software distributed under the License is distributed on an + "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + KIND, either express or implied. See the License for the + specific language governing permissions and limitations + under the License. 
+!> ctx = <ctx> module = mx.mod.Module(symbol=<loss>, context=ctx, data_names=[<data_names;separator=", ">], label_names=[<label_names;separator=", ">]) diff --git a/tools/caffe_translator/src/main/resources/templates/softmaxoutput.st b/tools/caffe_translator/src/main/resources/templates/softmaxoutput.st index bc63891..57a8e71 100644 --- a/tools/caffe_translator/src/main/resources/templates/softmaxoutput.st +++ b/tools/caffe_translator/src/main/resources/templates/softmaxoutput.st @@ -1,3 +1,21 @@ +<! + Licensed to the Apache Software Foundation (ASF) under one + or more contributor license agreements. See the NOTICE file + distributed with this work for additional information + regarding copyright ownership. The ASF licenses this file + to you under the Apache License, Version 2.0 (the + "License"); you may not use this file except in compliance + with the License. You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, + software distributed under the License is distributed on an + "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + KIND, either express or implied. See the License for the + specific language governing permissions and limitations + under the License. +!> <var> = mx.sym.SoftmaxOutput(data=<data>, label=<label>, name='<name>') <var>_metric = mx.metric.CrossEntropy(output_names=['<name>_output'], label_names=['<label_name>'], name='<name>/metric') train_metrics.add(<var>_metric) diff --git a/tools/caffe_translator/src/main/resources/templates/symbols.stg b/tools/caffe_translator/src/main/resources/templates/symbols.stg index fda9125..2a76eb0 100644 --- a/tools/caffe_translator/src/main/resources/templates/symbols.stg +++ b/tools/caffe_translator/src/main/resources/templates/symbols.stg @@ -1,3 +1,21 @@ +<! + Licensed to the Apache Software Foundation (ASF) under one + or more contributor license agreements. 
See the NOTICE file + distributed with this work for additional information + regarding copyright ownership. The ASF licenses this file + to you under the Apache License, Version 2.0 (the + "License"); you may not use this file except in compliance + with the License. You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, + software distributed under the License is distributed on an + "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + KIND, either express or implied. See the License for the + specific language governing permissions and limitations + under the License. +!> CaffePluginIntLayer(var, tops, num_data, num_weight, num_out, data, prototxt, name) ::= "<var> = mx.symbol.CaffeOp(<if(data)><data>, <endif><if(num_data)>num_data=<num_data>, <endif><if(num_out)>num_out=<num_out>, <endif><if(num_weight)>num_weight=<num_weight>, <endif>prototxt='<prototxt>', name='<name>') <if(tops)><tops:{top|<top_assign(top, var, i0)>};separator=\"\n\"> <endif>" diff --git a/tools/caffe_translator/src/main/resources/templates/top_k_accuracy.st b/tools/caffe_translator/src/main/resources/templates/top_k_accuracy.st index de93ee9..29a713f 100644 --- a/tools/caffe_translator/src/main/resources/templates/top_k_accuracy.st +++ b/tools/caffe_translator/src/main/resources/templates/top_k_accuracy.st @@ -1,2 +1,20 @@ +<! + Licensed to the Apache Software Foundation (ASF) under one + or more contributor license agreements. See the NOTICE file + distributed with this work for additional information + regarding copyright ownership. The ASF licenses this file + to you under the Apache License, Version 2.0 (the + "License"); you may not use this file except in compliance + with the License. 
You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, + software distributed under the License is distributed on an + "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + KIND, either express or implied. See the License for the + specific language governing permissions and limitations + under the License. +!> <var> = mx.metric.TopKAccuracy(top_k=<k>, output_names=['<output_name>'], label_names=['<label_name>'], name='<name>') test_metrics.add(<var>) diff --git a/tools/caffe_translator/src/main/resources/templates/var.st b/tools/caffe_translator/src/main/resources/templates/var.st index e850b689..fa08cd7 100644 --- a/tools/caffe_translator/src/main/resources/templates/var.st +++ b/tools/caffe_translator/src/main/resources/templates/var.st @@ -1 +1,19 @@ +<! + Licensed to the Apache Software Foundation (ASF) under one + or more contributor license agreements. See the NOTICE file + distributed with this work for additional information + regarding copyright ownership. The ASF licenses this file + to you under the Apache License, Version 2.0 (the + "License"); you may not use this file except in compliance + with the License. You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, + software distributed under the License is distributed on an + "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + KIND, either express or implied. See the License for the + specific language governing permissions and limitations + under the License. +!> <var> = mx.sym.Variable('<name>'<if(lr_mult)>, lr_mult=<lr_mult><endif><if(wd_mult)>, wd_mult=<wd_mult><endif><if(init)>, init=<init><endif><if(shape)>, shape=(<shape;separator=", ">)<endif>) -- To stop receiving notification emails like this one, please contact ['"comm...@mxnet.apache.org" <comm...@mxnet.apache.org>'].