http://git-wip-us.apache.org/repos/asf/groovy/blob/0edfcde9/src/main/groovy/CompilerCustomizationBuilder.groovy ---------------------------------------------------------------------- diff --git a/src/main/groovy/CompilerCustomizationBuilder.groovy b/src/main/groovy/CompilerCustomizationBuilder.groovy new file mode 100644 index 0000000..59b8cc5 --- /dev/null +++ b/src/main/groovy/CompilerCustomizationBuilder.groovy @@ -0,0 +1,64 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.codehaus.groovy.control.customizers.builder + +import groovy.transform.CompileStatic +import org.codehaus.groovy.control.CompilerConfiguration + +/** + * <p>A builder which allows easy configuration of compilation customizers. Instead of creating + * various compilation customizers by hand, you may use this builder instead, which provides a + * shorter syntax and removes most of the verbosity. + * + */ +@CompileStatic +class CompilerCustomizationBuilder extends FactoryBuilderSupport { + public CompilerCustomizationBuilder() { + registerFactories() + } + + public static CompilerConfiguration withConfig(CompilerConfiguration config, Closure code) { + CompilerCustomizationBuilder builder = new CompilerCustomizationBuilder() + config.invokeMethod('addCompilationCustomizers', builder.invokeMethod('customizers', code)) + + config + } + + @Override + protected Object postNodeCompletion(final Object parent, final Object node) { + Object value = super.postNodeCompletion(parent, node) + Object factory = getContextAttribute(CURRENT_FACTORY) + if (factory instanceof PostCompletionFactory) { + value = factory.postCompleteNode(this, parent, value) + setParent(parent, value) + } + + value + } + + private void registerFactories() { + registerFactory("ast", new ASTTransformationCustomizerFactory()) + registerFactory("customizers", new CustomizersFactory()) + registerFactory("imports", new ImportCustomizerFactory()) + registerFactory("inline", new InlinedASTCustomizerFactory()) + registerFactory("secureAst", new SecureASTCustomizerFactory()) + registerFactory("source", new SourceAwareCustomizerFactory()) + } +}
http://git-wip-us.apache.org/repos/asf/groovy/blob/0edfcde9/src/main/groovy/ConditionalInterruptibleASTTransformation.groovy ---------------------------------------------------------------------- diff --git a/src/main/groovy/ConditionalInterruptibleASTTransformation.groovy b/src/main/groovy/ConditionalInterruptibleASTTransformation.groovy new file mode 100644 index 0000000..2cda121 --- /dev/null +++ b/src/main/groovy/ConditionalInterruptibleASTTransformation.groovy @@ -0,0 +1,145 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.codehaus.groovy.transform + +import groovy.transform.ConditionalInterrupt +import org.codehaus.groovy.ast.AnnotatedNode +import org.codehaus.groovy.ast.AnnotationNode +import org.codehaus.groovy.ast.ClassHelper +import org.codehaus.groovy.ast.ClassNode +import org.codehaus.groovy.ast.FieldNode +import org.codehaus.groovy.ast.MethodNode +import org.codehaus.groovy.ast.Parameter +import org.codehaus.groovy.ast.PropertyNode +import org.codehaus.groovy.ast.expr.ArgumentListExpression +import org.codehaus.groovy.ast.expr.ClosureExpression +import org.codehaus.groovy.ast.expr.Expression +import org.codehaus.groovy.ast.expr.MethodCallExpression +import org.codehaus.groovy.ast.expr.VariableExpression +import org.codehaus.groovy.ast.tools.ClosureUtils +import org.codehaus.groovy.control.CompilePhase + +/** + * Allows "interrupt-safe" executions of scripts by adding a custom conditional + * check on loops (for, while, do) and first statement of closures. By default, also adds an interrupt check + * statement on the beginning of method calls. + * + * @see groovy.transform.ConditionalInterrupt + * @author Cedric Champeau + * @author Hamlet D'Arcy + * @author Paul King + * @since 1.8.0 + */ +@GroovyASTTransformation(phase = CompilePhase.CANONICALIZATION) +public class ConditionalInterruptibleASTTransformation extends AbstractInterruptibleASTTransformation { + + private static final ClassNode MY_TYPE = ClassHelper.make(ConditionalInterrupt) + + private ClosureExpression conditionNode + private String conditionMethod + private MethodCallExpression conditionCallExpression + private ClassNode currentClass + + protected ClassNode type() { + return MY_TYPE + } + + protected void setupTransform(AnnotationNode node) { + super.setupTransform(node) + def member = node.getMember("value") + if (!member || !(member instanceof ClosureExpression)) internalError("Expected closure value for annotation parameter 'value'. Found $member") + conditionNode = member; + conditionMethod = 'conditionalTransform' + node.hashCode() + '$condition' + conditionCallExpression = new MethodCallExpression(new VariableExpression('this'), conditionMethod, new ArgumentListExpression()) + } + + protected String getErrorMessage() { + 'Execution interrupted. The following condition failed: ' + convertClosureToSource(conditionNode) + } + + void visitClass(ClassNode type) { + currentClass = type + def method = type.addMethod(conditionMethod, ACC_PRIVATE | ACC_SYNTHETIC, ClassHelper.OBJECT_TYPE, Parameter.EMPTY_ARRAY, ClassNode.EMPTY_ARRAY, conditionNode.code) + method.synthetic = true + if (applyToAllMembers) { + super.visitClass(type) + } + } + + protected Expression createCondition() { + conditionCallExpression + } + + @Override + void visitAnnotations(AnnotatedNode node) { + // this transformation does not apply on annotation nodes + // visiting could lead to stack overflows + } + + @Override + void visitField(FieldNode node) { + if (!node.isStatic() && !node.isSynthetic()) { + super.visitField node + } + } + + @Override + void visitProperty(PropertyNode node) { + if (!node.isStatic() && !node.isSynthetic()) { + super.visitProperty node + } + } + + @Override + void visitClosureExpression(ClosureExpression closureExpr) { + if (closureExpr == conditionNode) return // do not visit the closure from the annotation itself + def code = closureExpr.code + closureExpr.code = wrapBlock(code) + super.visitClosureExpression closureExpr + } + + @Override + void visitMethod(MethodNode node) { + if (node.name == conditionMethod && !node.isSynthetic()) return // do not visit the generated method + if (node.name == 'run' && currentClass.isScript() && node.parameters.length == 0) { + // the run() method should not have the statement added, otherwise the script binding won't be set before + // the condition is actually tested + super.visitMethod(node) + } else { + if (checkOnMethodStart && !node.isSynthetic() && !node.isStatic() && !node.isAbstract()) { + def code = node.code + node.code = wrapBlock(code); + } + if (!node.isSynthetic() && !node.isStatic()) super.visitMethod(node) + } + } + + /** + * Converts a ClosureExpression into the String source. + * @param expression a closure + * @return the source the closure was created from + */ + private String convertClosureToSource(ClosureExpression expression) { + try { + return ClosureUtils.convertClosureToSource(this.source.source, expression); + } catch(Exception e) { + return e.message + } + } +} http://git-wip-us.apache.org/repos/asf/groovy/blob/0edfcde9/src/main/groovy/GrapeMain.groovy ---------------------------------------------------------------------- diff --git a/src/main/groovy/GrapeMain.groovy b/src/main/groovy/GrapeMain.groovy new file mode 100644 index 0000000..c78d25e --- /dev/null +++ b/src/main/groovy/GrapeMain.groovy @@ -0,0 +1,308 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.codehaus.groovy.tools + +import groovy.grape.Grape +import groovy.transform.Field +import org.apache.commons.cli.CommandLine +import org.apache.commons.cli.DefaultParser +import org.apache.commons.cli.HelpFormatter +import org.apache.commons.cli.Option +import org.apache.commons.cli.OptionGroup +import org.apache.commons.cli.Options +import org.apache.ivy.util.DefaultMessageLogger +import org.apache.ivy.util.Message + +//commands + +@Field install = {arg, cmd -> + if (arg.size() > 5 || arg.size() < 3) { + println 'install requires two to four arguments: <group> <module> [<version> [<classifier>]]' + return + } + def ver = '*' + if (arg.size() >= 4) { + ver = arg[3] + } + def classifier = null + if (arg.size() >= 5) { + classifier = arg[4] + } + + // set the instance so we can re-set the logger + Grape.getInstance() + setupLogging() + + cmd.getOptionValues('r')?.each { String url -> + Grape.addResolver(name:url, root:url) + } + + try { + Grape.grab(autoDownload: true, group: arg[1], module: arg[2], version: ver, classifier: classifier, noExceptions: true) + } catch (Exception e) { + println "An error occured : $ex" + } +} + +@Field uninstall = {arg, cmd -> + if (arg.size() != 4) { + println 'uninstall requires three arguments: <group> <module> <version>' + // TODO make version optional? support classifier? +// println 'uninstall requires two to four arguments, <group> <module> [<version>] [<classifier>]' + return + } + String group = arg[1] + String module = arg[2] + String ver = arg[3] +// def classifier = null + + // set the instance so we can re-set the logger + Grape.getInstance() + setupLogging() + + if (!Grape.enumerateGrapes().find {String groupName, Map g -> + g.any {String moduleName, List<String> versions -> + group == groupName && module == moduleName && ver in versions + } + }) { + println "uninstall did not find grape matching: $group $module $ver" + def fuzzyMatches = Grape.enumerateGrapes().findAll { String groupName, Map g -> + g.any {String moduleName, List<String> versions -> + groupName.contains(group) || moduleName.contains(module) || + group.contains(groupName) || module.contains(moduleName) + } + } + if (fuzzyMatches) { + println 'possible matches:' + fuzzyMatches.each { String groupName, Map g -> println " $groupName: $g" } + } + return + } + Grape.instance.uninstallArtifact(group, module, ver) +} + +@Field list = {arg, cmd -> + println "" + + int moduleCount = 0 + int versionCount = 0 + + // set the instance so we can re-set the logger + Grape.getInstance() + setupLogging() + + Grape.enumerateGrapes().each {String groupName, Map group -> + group.each {String moduleName, List<String> versions -> + println "$groupName $moduleName $versions" + moduleCount++ + versionCount += versions.size() + } + } + println "" + println "$moduleCount Grape modules cached" + println "$versionCount Grape module versions cached" +} + +@Field resolve = {arg, cmd -> + Options options = new Options(); + options.addOption(Option.builder("a").hasArg(false).longOpt("ant").build()); + options.addOption(Option.builder("d").hasArg(false).longOpt("dos").build()); + options.addOption(Option.builder("s").hasArg(false).longOpt("shell").build()); + options.addOption(Option.builder("i").hasArg(false).longOpt("ivy").build()); + CommandLine cmd2 = new DefaultParser().parse(options, arg[1..-1] as String[], true); + arg = cmd2.args + + // set the instance so we can re-set the logger + Grape.getInstance() + setupLogging(Message.MSG_ERR) + + if ((arg.size() % 3) != 0) { + println 'There needs to be a multiple of three arguments: (group module version)+' + return + } + if (args.size() < 3) { + println 'At least one Grape reference is required' + return + } + def before, between, after + def ivyFormatRequested = false + + if (cmd2.hasOption('a')) { + before = '<pathelement location="' + between = '">\n<pathelement location="' + after = '">' + } else if (cmd2.hasOption('d')) { + before = 'set CLASSPATH=' + between = ';' + after = '' + } else if (cmd2.hasOption('s')) { + before = 'export CLASSPATH=' + between = ':' + after = '' + } else if (cmd2.hasOption('i')) { + ivyFormatRequested = true + before = '<dependency ' + between = '">\n<dependency ' + after = '">' + } else { + before = '' + between = '\n' + after = '\n' + } + + iter = arg.iterator() + def params = [[:]] + def depsInfo = [] // this list will contain the module/group/version info of all resolved dependencies + if(ivyFormatRequested) { + params << depsInfo + } + while (iter.hasNext()) { + params.add([group: iter.next(), module: iter.next(), version: iter.next()]) + } + try { + def results = [] + def uris = Grape.resolve(* params) + if(!ivyFormatRequested) { + for (URI uri: uris) { + if (uri.scheme == 'file') { + results += new File(uri).path + } else { + results += uri.toASCIIString() + } + } + } else { + depsInfo.each { dep -> + results += ('org="' + dep.group + '" name="' + dep.module + '" revision="' + dep.revision) + } + } + + if (results) { + println "${before}${results.join(between)}${after}" + } else { + println 'Nothing was resolved' + } + } catch (Exception e) { + println "Error in resolve:\n\t$e.message" + if (e.message =~ /unresolved dependency/) println "Perhaps the grape is not installed?" + } +} + +@Field help = { arg, cmd -> grapeHelp() } + +@Field commands = [ + 'install': [closure: install, + shortHelp: 'Installs a particular grape'], + 'uninstall': [closure: uninstall, + shortHelp: 'Uninstalls a particular grape (non-transitively removes the respective jar file from the grape cache)'], + 'list': [closure: list, + shortHelp: 'Lists all installed grapes'], + 'resolve': [closure: resolve, + shortHelp: 'Enumerates the jars used by a grape'], + 'help': [closure: help, + shortHelp: 'Usage information'] +] + +@Field grapeHelp = { + int spacesLen = commands.keySet().max {it.length()}.length() + 3 + String spaces = ' ' * spacesLen + + PrintWriter pw = new PrintWriter(binding.variables.out ?: System.out) + new HelpFormatter().printHelp( + pw, + 80, + "grape [options] <command> [args]\n", + "options:", + options, + 2, + 4, + null, // footer + true); + pw.flush() + + println "" + println "commands:" + commands.each {String k, v -> + println " ${(k + spaces).substring(0, spacesLen)} $v.shortHelp" + } + println "" +} + +@Field setupLogging = {int defaultLevel = 2 -> // = Message.MSG_INFO -> some parsing error :( + if (cmd.hasOption('q')) { + Message.setDefaultLogger(new DefaultMessageLogger(Message.MSG_ERR)) + } else if (cmd.hasOption('w')) { + Message.setDefaultLogger(new DefaultMessageLogger(Message.MSG_WARN)) + } else if (cmd.hasOption('i')) { + Message.setDefaultLogger(new DefaultMessageLogger(Message.MSG_INFO)) + } else if (cmd.hasOption('V')) { + Message.setDefaultLogger(new DefaultMessageLogger(Message.MSG_VERBOSE)) + } else if (cmd.hasOption('d')) { + Message.setDefaultLogger(new DefaultMessageLogger(Message.MSG_DEBUG)) + } else { + Message.setDefaultLogger(new DefaultMessageLogger(defaultLevel)) + } +} + +// command line parsing +@Field Options options = new Options(); + +options.addOption(Option.builder("D").longOpt("define").desc("define a system property").numberOfArgs(2).valueSeparator().argName("name=value").build()); +options.addOption(Option.builder("r").longOpt("resolver").desc("define a grab resolver (for install)").hasArg(true).argName("url").build()); +options.addOption(Option.builder("h").hasArg(false).desc("usage information").longOpt("help").build()); + +// Logging Level Options +options.addOptionGroup( + new OptionGroup() + .addOption(Option.builder("q").hasArg(false).desc("Log level 0 - only errors").longOpt("quiet").build()) + .addOption(Option.builder("w").hasArg(false).desc("Log level 1 - errors and warnings").longOpt("warn").build()) + .addOption(Option.builder("i").hasArg(false).desc("Log level 2 - info").longOpt("info").build()) + .addOption(Option.builder("V").hasArg(false).desc("Log level 3 - verbose").longOpt("verbose").build()) + .addOption(Option.builder("d").hasArg(false).desc("Log level 4 - debug").longOpt("debug").build()) +) +options.addOption(Option.builder("v").hasArg(false).desc("display the Groovy and JVM versions").longOpt("version").build()); + +@Field CommandLine cmd + +cmd = new DefaultParser().parse(options, args, true); + +if (cmd.hasOption('h')) { + grapeHelp() + return +} + +if (cmd.hasOption('v')) { + String version = GroovySystem.getVersion(); + println "Groovy Version: $version JVM: ${System.getProperty('java.version')}" + return +} + +if (options.hasOption('D')) { + options.getOptionProperties('D')?.each { k, v -> + System.setProperty(k, v) + } +} + +String[] arg = cmd.args +if (arg?.length == 0) { + grapeHelp() +} else if (commands.containsKey(arg[0])) { + commands[arg[0]].closure(arg, cmd) +} else { + println "grape: '${arg[0]}' is not a grape command. See 'grape --help'" +} http://git-wip-us.apache.org/repos/asf/groovy/blob/0edfcde9/src/main/groovy/HasRecursiveCalls.groovy ---------------------------------------------------------------------- diff --git a/src/main/groovy/HasRecursiveCalls.groovy b/src/main/groovy/HasRecursiveCalls.groovy new file mode 100644 index 0000000..79f8e6d --- /dev/null +++ b/src/main/groovy/HasRecursiveCalls.groovy @@ -0,0 +1,64 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.codehaus.groovy.transform.tailrec + +import groovy.transform.CompileStatic +import org.codehaus.groovy.ast.CodeVisitorSupport +import org.codehaus.groovy.ast.MethodNode +import org.codehaus.groovy.ast.expr.MethodCallExpression +import org.codehaus.groovy.ast.expr.StaticMethodCallExpression + +/** + * + * Check if there are any recursive calls in a method + * + * @author Johannes Link + */ +@CompileStatic +class HasRecursiveCalls extends CodeVisitorSupport { + MethodNode method + boolean hasRecursiveCalls = false + + public void visitMethodCallExpression(MethodCallExpression call) { + if (isRecursive(call)) { + hasRecursiveCalls = true + } else { + super.visitMethodCallExpression(call) + } + } + + public void visitStaticMethodCallExpression(StaticMethodCallExpression call) { + if (isRecursive(call)) { + hasRecursiveCalls = true + } else { + super.visitStaticMethodCallExpression(call) + } + } + + private boolean isRecursive(call) { + new RecursivenessTester().isRecursive(method: method, call: call) + } + + synchronized boolean test(MethodNode method) { + hasRecursiveCalls = false + this.method = method + this.method.code.visit(this) + hasRecursiveCalls + } +} \ No newline at end of file http://git-wip-us.apache.org/repos/asf/groovy/blob/0edfcde9/src/main/groovy/InWhileLoopWrapper.groovy ---------------------------------------------------------------------- diff --git a/src/main/groovy/InWhileLoopWrapper.groovy b/src/main/groovy/InWhileLoopWrapper.groovy new file mode 100644 index 0000000..981f146 --- /dev/null +++ b/src/main/groovy/InWhileLoopWrapper.groovy @@ -0,0 +1,81 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.codehaus.groovy.transform.tailrec + +import groovy.transform.CompileStatic +import org.codehaus.groovy.ast.ClassHelper +import org.codehaus.groovy.ast.MethodNode +import org.codehaus.groovy.ast.Parameter +import org.codehaus.groovy.ast.VariableScope +import org.codehaus.groovy.ast.expr.BooleanExpression +import org.codehaus.groovy.ast.expr.ConstantExpression +import org.codehaus.groovy.ast.stmt.BlockStatement +import org.codehaus.groovy.ast.stmt.CatchStatement +import org.codehaus.groovy.ast.stmt.ContinueStatement +import org.codehaus.groovy.ast.stmt.EmptyStatement +import org.codehaus.groovy.ast.stmt.Statement +import org.codehaus.groovy.ast.stmt.TryCatchStatement +import org.codehaus.groovy.ast.stmt.WhileStatement + +/** + * Wrap the body of a method in a while loop, nested in a try-catch. + * This is the first step in making a tail recursive method iterative. + * + * There are two ways to invoke the next iteration step: + * 1. "continue _RECURE_HERE_" is used by recursive calls outside of closures + * 2. "throw LOOP_EXCEPTION" is used by recursive calls within closures b/c you cannot invoke "continue" from there + * + * @author Johannes Link + */ +@CompileStatic +class InWhileLoopWrapper { + + static final String LOOP_LABEL = '_RECUR_HERE_' + static final GotoRecurHereException LOOP_EXCEPTION = new GotoRecurHereException() + + void wrap(MethodNode method) { + BlockStatement oldBody = method.code as BlockStatement + TryCatchStatement tryCatchStatement = new TryCatchStatement( + oldBody, + EmptyStatement.INSTANCE + ) + tryCatchStatement.addCatch(new CatchStatement( + new Parameter(ClassHelper.make(GotoRecurHereException), 'ignore'), + new ContinueStatement(InWhileLoopWrapper.LOOP_LABEL) + )) + + WhileStatement whileLoop = new WhileStatement( + new BooleanExpression(new ConstantExpression(true)), + new BlockStatement([tryCatchStatement] as List<Statement>, new VariableScope(method.variableScope)) + ) + List<Statement> whileLoopStatements = ((BlockStatement) whileLoop.loopBlock).statements + if (whileLoopStatements.size() > 0) + whileLoopStatements[0].statementLabel = LOOP_LABEL + BlockStatement newBody = new BlockStatement([] as List<Statement>, new VariableScope(method.variableScope)) + newBody.addStatement(whileLoop) + method.code = newBody + } +} + +/** + * Exception will be thrown by recursive calls in closures and caught in while loop to continue to LOOP_LABEL + */ +class GotoRecurHereException extends Throwable { + +} http://git-wip-us.apache.org/repos/asf/groovy/blob/0edfcde9/src/main/groovy/RecursivenessTester.groovy ---------------------------------------------------------------------- diff --git a/src/main/groovy/RecursivenessTester.groovy b/src/main/groovy/RecursivenessTester.groovy new file mode 100644 index 0000000..7c9545a --- /dev/null +++ b/src/main/groovy/RecursivenessTester.groovy @@ -0,0 +1,100 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.codehaus.groovy.transform.tailrec + +import org.codehaus.groovy.ast.ClassHelper +import org.codehaus.groovy.ast.ClassNode +import org.codehaus.groovy.ast.MethodNode +import org.codehaus.groovy.ast.expr.ConstantExpression +import org.codehaus.groovy.ast.expr.MethodCallExpression +import org.codehaus.groovy.ast.expr.StaticMethodCallExpression +import org.codehaus.groovy.ast.expr.VariableExpression + +/** + * + * Test if a method call is recursive if called within a given method node. + * Handles static calls as well. + * + * Currently known simplifications: + * - Does not check for method overloading or overridden methods. + * - Does not check for matching return types; even void and any object type are considered to be compatible. + * - Argument type matching could be more specific in case of static compilation. + * - Method names via a GString are never considered to be recursive + * + * @author Johannes Link + */ +class RecursivenessTester { + public boolean isRecursive(params) { + assert params.method.class == MethodNode + assert params.call.class == MethodCallExpression || StaticMethodCallExpression + + isRecursive(params.method, params.call) + } + + public boolean isRecursive(MethodNode method, MethodCallExpression call) { + if (!isCallToThis(call)) + return false + // Could be a GStringExpression + if (! (call.method instanceof ConstantExpression)) + return false + if (call.method.value != method.name) + return false + methodParamsMatchCallArgs(method, call) + } + + public boolean isRecursive(MethodNode method, StaticMethodCallExpression call) { + if (!method.isStatic()) + return false + if (method.declaringClass != call.ownerType) + return false + if (call.method != method.name) + return false + methodParamsMatchCallArgs(method, call) + } + + private boolean isCallToThis(MethodCallExpression call) { + if (call.objectExpression == null) + return call.isImplicitThis() + if (! (call.objectExpression instanceof VariableExpression)) { + return false + } + return call.objectExpression.isThisExpression() + } + + private boolean methodParamsMatchCallArgs(method, call) { + if (method.parameters.size() != call.arguments.expressions.size()) + return false + def classNodePairs = [method.parameters*.type, call.arguments*.type].transpose() + return classNodePairs.every { ClassNode paramType, ClassNode argType -> + return areTypesCallCompatible(argType, paramType) + } + } + + /** + * Parameter type and calling argument type can both be derived from the other since typing information is + * optional in Groovy. + * Since int is not derived from Integer (nor the other way around) we compare the boxed types + */ + private areTypesCallCompatible(ClassNode argType, ClassNode paramType) { + ClassNode boxedArg = ClassHelper.getWrapper(argType) + ClassNode boxedParam = ClassHelper.getWrapper(paramType) + return boxedArg.isDerivedFrom(boxedParam) || boxedParam.isDerivedFrom(boxedArg) + } + +} http://git-wip-us.apache.org/repos/asf/groovy/blob/0edfcde9/src/main/groovy/ReturnAdderForClosures.groovy ---------------------------------------------------------------------- diff --git a/src/main/groovy/ReturnAdderForClosures.groovy b/src/main/groovy/ReturnAdderForClosures.groovy new file mode 100644 index 0000000..64ebce7 --- /dev/null +++ b/src/main/groovy/ReturnAdderForClosures.groovy @@ -0,0 +1,48 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.codehaus.groovy.transform.tailrec + +import org.codehaus.groovy.ast.ClassHelper +import org.codehaus.groovy.ast.ClassNode +import org.codehaus.groovy.ast.CodeVisitorSupport +import org.codehaus.groovy.ast.MethodNode +import org.codehaus.groovy.ast.Parameter +import org.codehaus.groovy.ast.expr.ClosureExpression +import org.codehaus.groovy.classgen.ReturnAdder + +/** + * Adds explicit return statements to implicit return points in a closure. This is necessary since + * tail-recursion is detected by having the recursive call within the return statement. + * + * @author Johannes Link + */ +class ReturnAdderForClosures extends CodeVisitorSupport { + + synchronized void visitMethod(MethodNode method) { + method.code.visit(this) + } + + public void visitClosureExpression(ClosureExpression expression) { + //Create a dummy method with the closure's code as the method's code. Then user ReturnAdder, which only works for methods. + MethodNode node = new MethodNode("dummy", 0, ClassHelper.OBJECT_TYPE, Parameter.EMPTY_ARRAY, ClassNode.EMPTY_ARRAY, expression.code); + new ReturnAdder().visitMethod(node); + super.visitClosureExpression(expression) + } + +} http://git-wip-us.apache.org/repos/asf/groovy/blob/0edfcde9/src/main/groovy/ReturnStatementToIterationConverter.groovy ---------------------------------------------------------------------- diff --git a/src/main/groovy/ReturnStatementToIterationConverter.groovy b/src/main/groovy/ReturnStatementToIterationConverter.groovy new file mode 100644 index 0000000..2c75f4f --- /dev/null +++ b/src/main/groovy/ReturnStatementToIterationConverter.groovy @@ -0,0 +1,148 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.codehaus.groovy.transform.tailrec + +import groovy.transform.CompileStatic +import org.codehaus.groovy.ast.ClassNode +import org.codehaus.groovy.ast.expr.BinaryExpression +import org.codehaus.groovy.ast.expr.Expression +import org.codehaus.groovy.ast.expr.MethodCallExpression +import org.codehaus.groovy.ast.expr.StaticMethodCallExpression +import org.codehaus.groovy.ast.expr.TupleExpression +import org.codehaus.groovy.ast.expr.VariableExpression +import org.codehaus.groovy.ast.stmt.BlockStatement +import org.codehaus.groovy.ast.stmt.ExpressionStatement +import org.codehaus.groovy.ast.stmt.ReturnStatement +import org.codehaus.groovy.ast.stmt.Statement + +import static org.codehaus.groovy.ast.tools.GeneralUtils.assignS +import static org.codehaus.groovy.ast.tools.GeneralUtils.varX + +/** + * Translates all return statements into an invocation of the next iteration. This can be either + * - "continue LOOP_LABEL": Outside closures + * - "throw LOOP_EXCEPTION": Inside closures + * + * Moreover, before adding the recur statement the iteration parameters (originally the method args) + * are set to their new value. To prevent variable aliasing parameters will be copied into temp vars + * before they are changes so that their current iteration value can be used when setting other params. + * + * There's probably place for optimizing the amount of variable copying being done, e.g. + * parameters that are only handed through must not be copied at all. + * + * @author Johannes Link + */ +@CompileStatic +class ReturnStatementToIterationConverter { + + Statement recurStatement = AstHelper.recurStatement() + + Statement convert(ReturnStatement statement, Map<Integer, Map> positionMapping) { + Expression recursiveCall = statement.expression + if (!isAMethodCalls(recursiveCall)) + return statement + + Map<String, Map> tempMapping = [:] + Map tempDeclarations = [:] + List<ExpressionStatement> argAssignments = [] + + BlockStatement result = new BlockStatement() + result.statementLabel = statement.statementLabel + + /* Create temp declarations for all method arguments. + * Add the declarations and var mapping to tempMapping and tempDeclarations for further reference. + */ + getArguments(recursiveCall).eachWithIndex { Expression expression, int index -> + ExpressionStatement tempDeclaration = createTempDeclaration(index, positionMapping, tempMapping, tempDeclarations) + result.addStatement(tempDeclaration) + } + + /* + * Assign the iteration variables their new value before recuring + */ + getArguments(recursiveCall).eachWithIndex { Expression expression, int index -> + ExpressionStatement argAssignment = createAssignmentToIterationVariable(expression, index, positionMapping) + argAssignments.add(argAssignment) + result.addStatement(argAssignment) + } + + Set<String> unusedTemps = replaceAllArgUsages(argAssignments, tempMapping) + for (String temp : unusedTemps) { + result.statements.remove(tempDeclarations[temp]) + } + result.addStatement(recurStatement) + + return result + } + + private ExpressionStatement createAssignmentToIterationVariable(Expression expression, int index, Map<Integer, Map> positionMapping) { + String argName = positionMapping[index]['name'] + ClassNode argAndTempType = positionMapping[index]['type'] as ClassNode + ExpressionStatement argAssignment = (ExpressionStatement) assignS(varX(argName, argAndTempType), expression) + argAssignment + } + + private ExpressionStatement createTempDeclaration(int index, Map<Integer, Map> positionMapping, Map<String, Map> tempMapping, Map tempDeclarations) { + String argName = positionMapping[index]['name'] + String tempName = "_${argName}_" + ClassNode argAndTempType = positionMapping[index]['type'] as ClassNode + ExpressionStatement tempDeclaration = AstHelper.createVariableAlias(tempName, argAndTempType, argName) + tempMapping[argName] = [name: tempName, type: argAndTempType] + tempDeclarations[tempName] = tempDeclaration + return tempDeclaration + } + + private List<Expression> getArguments(Expression recursiveCall) { + if (recursiveCall instanceof MethodCallExpression) + return ((TupleExpression) ((MethodCallExpression) recursiveCall).arguments).expressions + if (recursiveCall instanceof StaticMethodCallExpression) + return ((TupleExpression) ((StaticMethodCallExpression) recursiveCall).arguments).expressions + } + + private boolean isAMethodCalls(Expression expression) { + expression.class in [MethodCallExpression, StaticMethodCallExpression] + } + + private Set<String> replaceAllArgUsages(List<ExpressionStatement> iterationVariablesAssignmentNodes, Map<String, Map> tempMapping) { + Set<String> unusedTempNames = tempMapping.values().collect {Map nameAndType -> (String) nameAndType['name']} as Set<String> + VariableReplacedListener tracker = new UsedVariableTracker() + for (ExpressionStatement statement : iterationVariablesAssignmentNodes) { + replaceArgUsageByTempUsage((BinaryExpression) statement.expression, tempMapping, tracker) + } + unusedTempNames = unusedTempNames - tracker.usedVariableNames + return unusedTempNames + } + + private void replaceArgUsageByTempUsage(BinaryExpression binary, Map tempMapping, UsedVariableTracker tracker) { + VariableAccessReplacer replacer = new VariableAccessReplacer(nameAndTypeMapping: tempMapping, listener: tracker) + // Replacement must only happen in binary.rightExpression. It's a hack in VariableExpressionReplacer which takes care of that. + replacer.replaceIn(binary) + } +} + +@CompileStatic +class UsedVariableTracker implements VariableReplacedListener { + + final Set<String> usedVariableNames = [] as Set + + @Override + void variableReplaced(VariableExpression oldVar, VariableExpression newVar) { + usedVariableNames.add(newVar.name) + } +} http://git-wip-us.apache.org/repos/asf/groovy/blob/0edfcde9/src/main/groovy/StatementReplacer.groovy ---------------------------------------------------------------------- diff --git a/src/main/groovy/StatementReplacer.groovy b/src/main/groovy/StatementReplacer.groovy new file mode 100644 index 0000000..3a9dab3 --- /dev/null +++ b/src/main/groovy/StatementReplacer.groovy @@ -0,0 +1,109 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.codehaus.groovy.transform.tailrec + +import groovy.transform.CompileStatic +import org.codehaus.groovy.ast.ASTNode +import org.codehaus.groovy.ast.CodeVisitorSupport +import org.codehaus.groovy.ast.expr.ClosureExpression +import org.codehaus.groovy.ast.stmt.BlockStatement +import org.codehaus.groovy.ast.stmt.DoWhileStatement +import org.codehaus.groovy.ast.stmt.ForStatement +import org.codehaus.groovy.ast.stmt.IfStatement +import org.codehaus.groovy.ast.stmt.Statement +import org.codehaus.groovy.ast.stmt.WhileStatement + +/** + * Tool for replacing Statement objects in an AST by other Statement instances. + * + * Within @TailRecursive it is used to swap ReturnStatements with looping back to RECUR label + * + * @author Johannes Link + */ +@CompileStatic +class StatementReplacer extends CodeVisitorSupport { + + Closure<Boolean> when = { Statement node -> false } + Closure<Statement> replaceWith = { Statement statement -> statement } + int closureLevel = 0 + + void replaceIn(ASTNode root) { + root.visit(this) + } + + public void visitClosureExpression(ClosureExpression expression) { + closureLevel++ + try { + super.visitClosureExpression(expression) + } finally { + closureLevel-- + } + } + + public void visitBlockStatement(BlockStatement block) { + List<Statement> copyOfStatements = new ArrayList<Statement>(block.statements) + copyOfStatements.eachWithIndex { Statement statement, int index -> + replaceIfNecessary(statement) { Statement node -> block.statements[index] = node } + } + super.visitBlockStatement(block); + } + + public void visitIfElse(IfStatement ifElse) { + replaceIfNecessary(ifElse.ifBlock) { Statement s -> ifElse.ifBlock = s } + replaceIfNecessary(ifElse.elseBlock) { Statement s -> ifElse.elseBlock = s } + super.visitIfElse(ifElse); + } + + public void visitForLoop(ForStatement forLoop) { + replaceIfNecessary(forLoop.loopBlock) { Statement s -> forLoop.loopBlock = s } + super.visitForLoop(forLoop); + } + + public void visitWhileLoop(WhileStatement loop) { + replaceIfNecessary(loop.loopBlock) { Statement s -> loop.loopBlock = s } + super.visitWhileLoop(loop); + } + + public void visitDoWhileLoop(DoWhileStatement loop) { + replaceIfNecessary(loop.loopBlock) { Statement s -> loop.loopBlock = s } + super.visitDoWhileLoop(loop); + } + + + private void replaceIfNecessary(Statement nodeToCheck, Closure replacementCode) { + if (conditionFulfilled(nodeToCheck)) { + ASTNode replacement = replaceWith(nodeToCheck) + replacement.setSourcePosition(nodeToCheck); + replacement.copyNodeMetaData(nodeToCheck); + replacementCode(replacement) + } + } + + private boolean conditionFulfilled(ASTNode nodeToCheck) { + if (when.maximumNumberOfParameters < 2) + return when(nodeToCheck) + else + return when(nodeToCheck, isInClosure()) + } + + private boolean isInClosure() { + closureLevel > 0 + } + +} http://git-wip-us.apache.org/repos/asf/groovy/blob/0edfcde9/src/main/groovy/StringUtil.groovy ---------------------------------------------------------------------- diff --git a/src/main/groovy/StringUtil.groovy b/src/main/groovy/StringUtil.groovy new file mode 100644 index 0000000..ed83e53 --- /dev/null +++ b/src/main/groovy/StringUtil.groovy @@ -0,0 +1,57 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.codehaus.groovy.util + +import groovy.transform.CompileStatic + +/** + * String utility functions. + */ +@CompileStatic +class StringUtil { + /** + * Provides Groovy with functionality similar to the unix tr command + * which translates a string replacing characters from a source set + * with characters from a replacement set. + * + * @since 1.7.3 + */ + static String tr(String text, String source, String replacement) { + if (!text || !source) { return text } + source = expandHyphen(source) + replacement = expandHyphen(replacement) + + // padding replacement with a last character, if necessary + replacement = replacement.padRight(source.size(), replacement[replacement.size() - 1]) + + return text.collect { String original -> + if (source.contains(original)) { + replacement[source.lastIndexOf(original)] + } else { + original + } + }.join('') + } + + // no expansion for hyphen at start or end of Strings + private static String expandHyphen(String text) { + if (!text.contains('-')) { return text } + return text.replaceAll(/(.)-(.)/, { all, begin, end -> (begin..end).join('') }) + } +} http://git-wip-us.apache.org/repos/asf/groovy/blob/0edfcde9/src/main/groovy/TailRecursiveASTTransformation.groovy ---------------------------------------------------------------------- diff --git a/src/main/groovy/TailRecursiveASTTransformation.groovy b/src/main/groovy/TailRecursiveASTTransformation.groovy new file mode 100644 index 0000000..0605f18 --- /dev/null +++ b/src/main/groovy/TailRecursiveASTTransformation.groovy @@ -0,0 +1,261 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.codehaus.groovy.transform.tailrec + +import groovy.transform.CompileStatic +import groovy.transform.Memoized +import groovy.transform.TailRecursive +import org.codehaus.groovy.ast.ASTNode +import org.codehaus.groovy.ast.AnnotationNode +import org.codehaus.groovy.ast.ClassHelper +import org.codehaus.groovy.ast.ClassNode +import org.codehaus.groovy.ast.MethodNode +import org.codehaus.groovy.ast.Parameter +import org.codehaus.groovy.ast.expr.Expression +import org.codehaus.groovy.ast.expr.MethodCallExpression +import org.codehaus.groovy.ast.expr.StaticMethodCallExpression +import org.codehaus.groovy.ast.expr.TernaryExpression +import org.codehaus.groovy.ast.expr.VariableExpression +import org.codehaus.groovy.ast.stmt.BlockStatement +import org.codehaus.groovy.ast.stmt.ReturnStatement +import org.codehaus.groovy.ast.stmt.Statement +import org.codehaus.groovy.classgen.ReturnAdder +import org.codehaus.groovy.classgen.VariableScopeVisitor +import org.codehaus.groovy.control.CompilePhase +import org.codehaus.groovy.control.SourceUnit +import org.codehaus.groovy.transform.AbstractASTTransformation +import org.codehaus.groovy.transform.GroovyASTTransformation + +/** + * Handles generation of code for the @TailRecursive annotation. + * + * It's doing its work in the earliest possible compile phase + * + * @author Johannes Link + */ +@CompileStatic +@GroovyASTTransformation(phase = CompilePhase.SEMANTIC_ANALYSIS) +class TailRecursiveASTTransformation extends AbstractASTTransformation { + + private static final Class MY_CLASS = TailRecursive.class; + private static final ClassNode MY_TYPE = new ClassNode(MY_CLASS); + static final String MY_TYPE_NAME = "@" + MY_TYPE.getNameWithoutPackage() + private HasRecursiveCalls hasRecursiveCalls = new HasRecursiveCalls() + private TernaryToIfStatementConverter ternaryToIfStatement = new TernaryToIfStatementConverter() + + + @Override + public void visit(ASTNode[] nodes, SourceUnit source) { + init(nodes, source); + + MethodNode method = nodes[1] as MethodNode + + if (method.isAbstract()) { + addError("Annotation " + MY_TYPE_NAME + " cannot be used for abstract methods.", method); + return; + } + + if (hasAnnotation(method, ClassHelper.make(Memoized))) { + ClassNode memoizedClassNode = ClassHelper.make(Memoized) + for (AnnotationNode annotationNode in method.annotations) { + if (annotationNode.classNode == MY_TYPE) + break + if (annotationNode.classNode == memoizedClassNode) { + addError("Annotation " + MY_TYPE_NAME + " must be placed before annotation @Memoized.", annotationNode) + return + } + } + } + + if (!hasRecursiveMethodCalls(method)) { + AnnotationNode annotationNode = method.getAnnotations(ClassHelper.make(TailRecursive))[0] + addError("No recursive calls detected. You must remove annotation " + MY_TYPE_NAME + ".", annotationNode) + return; + } + + transformToIteration(method, source) + ensureAllRecursiveCallsHaveBeenTransformed(method) + } + + private boolean hasAnnotation(MethodNode methodNode, ClassNode annotation) { + List annots = methodNode.getAnnotations(annotation); + return (annots != null && annots.size() > 0); + } + + + private void transformToIteration(MethodNode method, SourceUnit source) { + if (method.isVoidMethod()) { + transformVoidMethodToIteration(method, source) + } else { + transformNonVoidMethodToIteration(method, source) + } + } + + private void transformVoidMethodToIteration(MethodNode method, SourceUnit source) { + addError("Void methods are not supported by @TailRecursive yet.", method) + } + + private void transformNonVoidMethodToIteration(MethodNode method, SourceUnit source) { + addMissingDefaultReturnStatement(method) + replaceReturnsWithTernariesToIfStatements(method) + wrapMethodBodyWithWhileLoop(method) + + Map<String, Map> nameAndTypeMapping = name2VariableMappingFor(method) + replaceAllAccessToParams(method, nameAndTypeMapping) + addLocalVariablesForAllParameters(method, nameAndTypeMapping) //must happen after replacing access to params + + Map<Integer, Map> positionMapping = position2VariableMappingFor(method) + replaceAllRecursiveReturnsWithIteration(method, positionMapping) + repairVariableScopes(source, method) + } + + private void repairVariableScopes(SourceUnit source, MethodNode method) { + new VariableScopeVisitor(source).visitClass(method.declaringClass) + } + + private void replaceReturnsWithTernariesToIfStatements(MethodNode method) { + Closure<Boolean> whenReturnWithTernary = { ASTNode node -> + if (!(node instanceof ReturnStatement)) { + return false + } + return (((ReturnStatement) node).expression instanceof TernaryExpression) + } + Closure<Statement> replaceWithIfStatement = { ReturnStatement statement -> + ternaryToIfStatement.convert(statement) + } + StatementReplacer replacer = new StatementReplacer(when: whenReturnWithTernary, replaceWith: replaceWithIfStatement) + replacer.replaceIn(method.code) + + } + + private void addLocalVariablesForAllParameters(MethodNode method, Map<String, Map> nameAndTypeMapping) { + BlockStatement code = method.code as BlockStatement + nameAndTypeMapping.each { String paramName, Map localNameAndType -> + code.statements.add(0, AstHelper.createVariableDefinition( + (String) localNameAndType['name'], + (ClassNode) localNameAndType['type'], + new VariableExpression(paramName, (ClassNode) localNameAndType['type']) + )) + } + } + + private void replaceAllAccessToParams(MethodNode method, Map<String, Map> nameAndTypeMapping) { + new VariableAccessReplacer(nameAndTypeMapping: nameAndTypeMapping).replaceIn(method.code) + } + + // Public b/c there are tests for this method + Map<String, Map> name2VariableMappingFor(MethodNode method) { + Map<String, Map> nameAndTypeMapping = [:] + method.parameters.each { Parameter param -> + String paramName = param.name + ClassNode paramType = param.type as ClassNode + String iterationVariableName = iterationVariableName(paramName) + nameAndTypeMapping[paramName] = [name: iterationVariableName, type: paramType] + } + return nameAndTypeMapping + } + + // Public b/c there are tests for this method + Map<Integer, Map> position2VariableMappingFor(MethodNode method) { + Map<Integer, Map> positionMapping = [:] + method.parameters.eachWithIndex { Parameter param, int index -> + String paramName = param.name + ClassNode paramType = param.type as ClassNode + String iterationVariableName = this.iterationVariableName(paramName) + positionMapping[index] = [name: iterationVariableName, type: paramType] + } + return positionMapping + } + + private String iterationVariableName(String paramName) { + '_' + paramName + '_' + } + + private void replaceAllRecursiveReturnsWithIteration(MethodNode method, Map positionMapping) { + replaceRecursiveReturnsOutsideClosures(method, positionMapping) + replaceRecursiveReturnsInsideClosures(method, positionMapping) + } + + private void replaceRecursiveReturnsOutsideClosures(MethodNode method, Map<Integer, Map> positionMapping) { + Closure<Boolean> whenRecursiveReturn = { Statement statement, boolean inClosure -> + if (inClosure) + return false + if (!(statement instanceof ReturnStatement)) { + return false + } + Expression inner = ((ReturnStatement) statement).expression + if (!(inner instanceof MethodCallExpression) && !(inner instanceof StaticMethodCallExpression)) { + return false + } + return isRecursiveIn(inner, method) + } + Closure<Statement> replaceWithContinueBlock = { ReturnStatement statement -> + new ReturnStatementToIterationConverter().convert(statement, positionMapping) + } + def replacer = new StatementReplacer(when: whenRecursiveReturn, replaceWith: replaceWithContinueBlock) + replacer.replaceIn(method.code) + } + + private void replaceRecursiveReturnsInsideClosures(MethodNode method, Map<Integer, Map> positionMapping) { + Closure<Boolean> whenRecursiveReturn = { Statement statement, boolean inClosure -> + if (!inClosure) + return false + if (!(statement instanceof ReturnStatement)) { + return false + } + Expression inner = ((ReturnStatement )statement).expression + if (!(inner instanceof MethodCallExpression) && !(inner instanceof StaticMethodCallExpression)) { + return false + } + return isRecursiveIn(inner, method) + } + Closure<Statement> replaceWithThrowLoopException = { ReturnStatement statement -> + new ReturnStatementToIterationConverter(recurStatement: AstHelper.recurByThrowStatement()).convert(statement, positionMapping) + } + StatementReplacer replacer = new StatementReplacer(when: whenRecursiveReturn, replaceWith: replaceWithThrowLoopException) + replacer.replaceIn(method.code) + } + + private void wrapMethodBodyWithWhileLoop(MethodNode method) { + new InWhileLoopWrapper().wrap(method) + } + + private void addMissingDefaultReturnStatement(MethodNode method) { + new ReturnAdder().visitMethod(method) + new ReturnAdderForClosures().visitMethod(method) + } + + private void ensureAllRecursiveCallsHaveBeenTransformed(MethodNode method) { + List<Expression> remainingRecursiveCalls = new CollectRecursiveCalls().collect(method) + for(Expression expression : remainingRecursiveCalls) { + addError("Recursive call could not be transformed by @TailRecursive. Maybe it's not a tail call.", expression) + } + } + + private boolean hasRecursiveMethodCalls(MethodNode method) { + hasRecursiveCalls.test(method) + } + + private boolean isRecursiveIn(Expression methodCall, MethodNode method) { + if (methodCall instanceof MethodCallExpression) + return new RecursivenessTester().isRecursive(method, (MethodCallExpression) methodCall) + if (methodCall instanceof StaticMethodCallExpression) + return new RecursivenessTester().isRecursive(method, (StaticMethodCallExpression) methodCall) + } +} \ No newline at end of file http://git-wip-us.apache.org/repos/asf/groovy/blob/0edfcde9/src/main/groovy/TernaryToIfStatementConverter.groovy ---------------------------------------------------------------------- diff --git a/src/main/groovy/TernaryToIfStatementConverter.groovy b/src/main/groovy/TernaryToIfStatementConverter.groovy new file mode 100644 index 0000000..c47e1d2 --- /dev/null +++ b/src/main/groovy/TernaryToIfStatementConverter.groovy @@ -0,0 +1,42 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.codehaus.groovy.transform.tailrec + +import groovy.transform.CompileStatic +import org.codehaus.groovy.ast.expr.TernaryExpression +import org.codehaus.groovy.ast.stmt.IfStatement +import org.codehaus.groovy.ast.stmt.ReturnStatement +import org.codehaus.groovy.ast.stmt.Statement + +/** + * Since a ternary statement has more than one exit point tail-recursiveness testing cannot be easily done. + * Therefore this class translates a ternary statement (or Elvis operator) into the equivalent if-else statement. + * + * @author Johannes Link + */ +@CompileStatic +class TernaryToIfStatementConverter { + + Statement convert(ReturnStatement statementWithInnerTernaryExpression) { + if (!(statementWithInnerTernaryExpression.expression instanceof TernaryExpression)) + return statementWithInnerTernaryExpression + TernaryExpression ternary = statementWithInnerTernaryExpression.expression as TernaryExpression + return new IfStatement(ternary.booleanExpression, new ReturnStatement(ternary.trueExpression), new ReturnStatement(ternary.falseExpression)) + } +} \ No newline at end of file http://git-wip-us.apache.org/repos/asf/groovy/blob/0edfcde9/src/main/groovy/ThreadInterruptibleASTTransformation.groovy ---------------------------------------------------------------------- diff --git a/src/main/groovy/ThreadInterruptibleASTTransformation.groovy b/src/main/groovy/ThreadInterruptibleASTTransformation.groovy new file mode 100644 index 0000000..a4fb4c3 --- /dev/null +++ b/src/main/groovy/ThreadInterruptibleASTTransformation.groovy @@ -0,0 +1,98 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.codehaus.groovy.transform + +import groovy.transform.CompileStatic +import groovy.transform.ThreadInterrupt +import org.codehaus.groovy.ast.ClassHelper +import org.codehaus.groovy.ast.ClassNode +import org.codehaus.groovy.ast.MethodNode +import org.codehaus.groovy.ast.Parameter +import org.codehaus.groovy.ast.expr.ArgumentListExpression +import org.codehaus.groovy.ast.expr.ClassExpression +import org.codehaus.groovy.ast.expr.ClosureExpression +import org.codehaus.groovy.ast.expr.Expression +import org.codehaus.groovy.ast.expr.MethodCallExpression +import org.codehaus.groovy.control.CompilePhase + +/** + * Allows "interrupt-safe" executions of scripts by adding Thread.currentThread().isInterrupted() + * checks on loops (for, while, do) and first statement of closures. By default, also adds an interrupt check + * statement on the beginning of method calls. + * + * @see groovy.transform.ThreadInterrupt + * + * @author Cedric Champeau + * @author Hamlet D'Arcy + * + * @since 1.8.0 + */ +@GroovyASTTransformation(phase = CompilePhase.CANONICALIZATION) +@CompileStatic +public class ThreadInterruptibleASTTransformation extends AbstractInterruptibleASTTransformation { + + private static final ClassNode MY_TYPE = ClassHelper.make(ThreadInterrupt) + private static final ClassNode THREAD_TYPE = ClassHelper.make(Thread) + private static final MethodNode CURRENTTHREAD_METHOD + private static final MethodNode ISINTERRUPTED_METHOD + + static { + CURRENTTHREAD_METHOD = THREAD_TYPE.getMethod('currentThread', Parameter.EMPTY_ARRAY) + ISINTERRUPTED_METHOD = THREAD_TYPE.getMethod('isInterrupted', Parameter.EMPTY_ARRAY) + } + + protected ClassNode type() { + return MY_TYPE; + } + + protected String getErrorMessage() { + 'Execution interrupted. The current thread has been interrupted.' + } + + protected Expression createCondition() { + def currentThread = new MethodCallExpression(new ClassExpression(THREAD_TYPE), + 'currentThread', + ArgumentListExpression.EMPTY_ARGUMENTS) + currentThread.methodTarget = CURRENTTHREAD_METHOD + def isInterrupted = new MethodCallExpression( + currentThread, + 'isInterrupted', ArgumentListExpression.EMPTY_ARGUMENTS) + isInterrupted.methodTarget = ISINTERRUPTED_METHOD + [currentThread, isInterrupted]*.implicitThis = false + + isInterrupted + } + + + @Override + public void visitClosureExpression(ClosureExpression closureExpr) { + def code = closureExpr.code + closureExpr.code = wrapBlock(code) + super.visitClosureExpression closureExpr + } + + @Override + public void visitMethod(MethodNode node) { + if (checkOnMethodStart && !node.isSynthetic() && !node.isAbstract()) { + def code = node.code + node.code = wrapBlock(code); + } + super.visitMethod(node) + } +} http://git-wip-us.apache.org/repos/asf/groovy/blob/0edfcde9/src/main/groovy/TimedInterruptibleASTTransformation.groovy ---------------------------------------------------------------------- diff --git a/src/main/groovy/TimedInterruptibleASTTransformation.groovy b/src/main/groovy/TimedInterruptibleASTTransformation.groovy new file mode 100644 index 0000000..fbc923b --- /dev/null +++ b/src/main/groovy/TimedInterruptibleASTTransformation.groovy @@ -0,0 +1,321 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.codehaus.groovy.transform + +import groovy.transform.TimedInterrupt +import org.codehaus.groovy.ast.ASTNode +import org.codehaus.groovy.ast.AnnotatedNode +import org.codehaus.groovy.ast.AnnotationNode +import org.codehaus.groovy.ast.ClassCodeVisitorSupport +import org.codehaus.groovy.ast.ClassHelper +import org.codehaus.groovy.ast.ClassNode +import org.codehaus.groovy.ast.FieldNode +import org.codehaus.groovy.ast.MethodNode +import org.codehaus.groovy.ast.PropertyNode +import org.codehaus.groovy.ast.expr.ClosureExpression +import org.codehaus.groovy.ast.expr.ConstantExpression +import org.codehaus.groovy.ast.expr.DeclarationExpression +import org.codehaus.groovy.ast.expr.Expression +import org.codehaus.groovy.ast.stmt.BlockStatement +import org.codehaus.groovy.ast.stmt.DoWhileStatement +import org.codehaus.groovy.ast.stmt.ForStatement +import org.codehaus.groovy.ast.stmt.WhileStatement +import org.codehaus.groovy.control.CompilePhase +import org.codehaus.groovy.control.SourceUnit + +import java.util.concurrent.TimeUnit +import java.util.concurrent.TimeoutException + +import static org.codehaus.groovy.ast.ClassHelper.make +import static org.codehaus.groovy.ast.tools.GeneralUtils.args +import static org.codehaus.groovy.ast.tools.GeneralUtils.callX +import static org.codehaus.groovy.ast.tools.GeneralUtils.classX +import static org.codehaus.groovy.ast.tools.GeneralUtils.constX +import static org.codehaus.groovy.ast.tools.GeneralUtils.ctorX +import static org.codehaus.groovy.ast.tools.GeneralUtils.ifS +import static org.codehaus.groovy.ast.tools.GeneralUtils.ltX +import static org.codehaus.groovy.ast.tools.GeneralUtils.plusX +import static org.codehaus.groovy.ast.tools.GeneralUtils.propX +import static org.codehaus.groovy.ast.tools.GeneralUtils.throwS +import static org.codehaus.groovy.ast.tools.GeneralUtils.varX + +/** + * Allows "interrupt-safe" executions of scripts by adding timer expiration + * checks on loops (for, while, do) and first statement of closures. By default, + * also adds an interrupt check statement on the beginning of method calls. + * + * @author Cedric Champeau + * @author Hamlet D'Arcy + * @author Paul King + * @see groovy.transform.ThreadInterrupt + * @since 1.8.0 + */ +@GroovyASTTransformation(phase = CompilePhase.CANONICALIZATION) +public class TimedInterruptibleASTTransformation extends AbstractASTTransformation { + + private static final ClassNode MY_TYPE = make(TimedInterrupt) + private static final String CHECK_METHOD_START_MEMBER = 'checkOnMethodStart' + private static final String APPLY_TO_ALL_CLASSES = 'applyToAllClasses' + private static final String APPLY_TO_ALL_MEMBERS = 'applyToAllMembers' + private static final String THROWN_EXCEPTION_TYPE = "thrown" + + public void visit(ASTNode[] nodes, SourceUnit source) { + init(nodes, source); + AnnotationNode node = nodes[0] + AnnotatedNode annotatedNode = nodes[1] + if (!MY_TYPE.equals(node.getClassNode())) { + internalError("Transformation called from wrong annotation: $node.classNode.name") + } + + def checkOnMethodStart = getConstantAnnotationParameter(node, CHECK_METHOD_START_MEMBER, Boolean.TYPE, true) + def applyToAllMembers = getConstantAnnotationParameter(node, APPLY_TO_ALL_MEMBERS, Boolean.TYPE, true) + def applyToAllClasses = applyToAllMembers ? getConstantAnnotationParameter(node, APPLY_TO_ALL_CLASSES, Boolean.TYPE, true) : false + def maximum = getConstantAnnotationParameter(node, 'value', Long.TYPE, Long.MAX_VALUE) + def thrown = AbstractInterruptibleASTTransformation.getClassAnnotationParameter(node, THROWN_EXCEPTION_TYPE, make(TimeoutException)) + + Expression unit = node.getMember('unit') ?: propX(classX(TimeUnit), "SECONDS") + + // should be limited to the current SourceUnit or propagated to the whole CompilationUnit + // DO NOT inline visitor creation in code below. It has state that must not persist between calls + if (applyToAllClasses) { + // guard every class and method defined in this script + source.getAST()?.classes?.each { ClassNode it -> + def visitor = new TimedInterruptionVisitor(source, checkOnMethodStart, applyToAllClasses, applyToAllMembers, maximum, unit, thrown, node.hashCode()) + visitor.visitClass(it) + } + } else if (annotatedNode instanceof ClassNode) { + // only guard this particular class + def visitor = new TimedInterruptionVisitor(source, checkOnMethodStart, applyToAllClasses, applyToAllMembers, maximum, unit, thrown, node.hashCode()) + visitor.visitClass annotatedNode + } else if (!applyToAllMembers && annotatedNode instanceof MethodNode) { + // only guard this particular method (plus initCode for class) + def visitor = new TimedInterruptionVisitor(source, checkOnMethodStart, applyToAllClasses, applyToAllMembers, maximum, unit, thrown, node.hashCode()) + visitor.visitMethod annotatedNode + visitor.visitClass annotatedNode.declaringClass + } else if (!applyToAllMembers && annotatedNode instanceof FieldNode) { + // only guard this particular field (plus initCode for class) + def visitor = new TimedInterruptionVisitor(source, checkOnMethodStart, applyToAllClasses, applyToAllMembers, maximum, unit, thrown, node.hashCode()) + visitor.visitField annotatedNode + visitor.visitClass annotatedNode.declaringClass + } else if (!applyToAllMembers && annotatedNode instanceof DeclarationExpression) { + // only guard this particular declaration (plus initCode for class) + def visitor = new TimedInterruptionVisitor(source, checkOnMethodStart, applyToAllClasses, applyToAllMembers, maximum, unit, thrown, node.hashCode()) + visitor.visitDeclarationExpression annotatedNode + visitor.visitClass annotatedNode.declaringClass + } else { + // only guard the script class + source.getAST()?.classes?.each { ClassNode it -> + if (it.isScript()) { + def visitor = new TimedInterruptionVisitor(source, checkOnMethodStart, applyToAllClasses, applyToAllMembers, maximum, unit, thrown, node.hashCode()) + visitor.visitClass(it) + } + } + } + } + + static def getConstantAnnotationParameter(AnnotationNode node, String parameterName, Class type, defaultValue) { + def member = node.getMember(parameterName) + if (member) { + if (member instanceof ConstantExpression) { + // TODO not sure this try offers value - testing Groovy annotation type handing - throw GroovyBugError or remove? + try { + return member.value.asType(type) + } catch (ignore) { + internalError("Expecting boolean value for ${parameterName} annotation parameter. Found $member") + } + } else { + internalError("Expecting boolean value for ${parameterName} annotation parameter. Found $member") + } + } + return defaultValue + } + + private static void internalError(String message) { + throw new RuntimeException("Internal error: $message") + } + + private static class TimedInterruptionVisitor extends ClassCodeVisitorSupport { + final private SourceUnit source + final private boolean checkOnMethodStart + final private boolean applyToAllClasses + final private boolean applyToAllMembers + private FieldNode expireTimeField = null + private FieldNode startTimeField = null + private final Expression unit + private final maximum + private final ClassNode thrown + private final String basename + + TimedInterruptionVisitor(source, checkOnMethodStart, applyToAllClasses, applyToAllMembers, maximum, unit, thrown, hash) { + this.source = source + this.checkOnMethodStart = checkOnMethodStart + this.applyToAllClasses = applyToAllClasses + this.applyToAllMembers = applyToAllMembers + this.unit = unit + this.maximum = maximum + this.thrown = thrown + this.basename = 'timedInterrupt' + hash + } + + /** + * @return Returns the interruption check statement. + */ + final createInterruptStatement() { + ifS( + + ltX( + propX(varX("this"), basename + '$expireTime'), + callX(make(System), 'nanoTime') + ), + throwS( + ctorX(thrown, + args( + plusX( + plusX( + constX('Execution timed out after ' + maximum + ' '), + callX(callX(unit, 'name'), 'toLowerCase', propX(classX(Locale), 'US')) + ), + plusX( + constX('. Start time: '), + propX(varX("this"), basename + '$startTime') + ) + ) + + ) + ) + ) + ) + } + + /** + * Takes a statement and wraps it into a block statement which first element is the interruption check statement. + * @param statement the statement to be wrapped + * @return a {@link BlockStatement block statement} which first element is for checking interruption, and the + * second one the statement to be wrapped. + */ + private wrapBlock(statement) { + def stmt = new BlockStatement(); + stmt.addStatement(createInterruptStatement()); + stmt.addStatement(statement); + stmt + } + + @Override + void visitClass(ClassNode node) { + if (node.getDeclaredField(basename + '$expireTime')) { + return + } + expireTimeField = node.addField(basename + '$expireTime', + ACC_FINAL | ACC_PRIVATE, + ClassHelper.long_TYPE, + plusX( + callX(make(System), 'nanoTime'), + callX( + propX(classX(TimeUnit), 'NANOSECONDS'), + 'convert', + args(constX(maximum, true), unit) + ) + ) + ); + expireTimeField.synthetic = true + startTimeField = node.addField(basename + '$startTime', + ACC_FINAL | ACC_PRIVATE, + make(Date), + ctorX(make(Date)) + ) + startTimeField.synthetic = true + + // force these fields to be initialized first + node.fields.remove(expireTimeField) + node.fields.remove(startTimeField) + node.fields.add(0, startTimeField) + node.fields.add(0, expireTimeField) + if (applyToAllMembers) { + super.visitClass node + } + } + + @Override + void visitClosureExpression(ClosureExpression closureExpr) { + def code = closureExpr.code + if (code instanceof BlockStatement) { + code.statements.add(0, createInterruptStatement()) + } else { + closureExpr.code = wrapBlock(code) + } + super.visitClosureExpression closureExpr + } + + @Override + void visitField(FieldNode node) { + if (!node.isStatic() && !node.isSynthetic()) { + super.visitField node + } + } + + @Override + void visitProperty(PropertyNode node) { + if (!node.isStatic() && !node.isSynthetic()) { + super.visitProperty node + } + } + + /** + * Shortcut method which avoids duplicating code for every type of loop. + * Actually wraps the loopBlock of different types of loop statements. + */ + private visitLoop(loopStatement) { + def statement = loopStatement.loopBlock + loopStatement.loopBlock = wrapBlock(statement) + } + + @Override + void visitForLoop(ForStatement forStatement) { + visitLoop(forStatement) + super.visitForLoop(forStatement) + } + + @Override + void visitDoWhileLoop(final DoWhileStatement doWhileStatement) { + visitLoop(doWhileStatement) + super.visitDoWhileLoop(doWhileStatement) + } + + @Override + void visitWhileLoop(final WhileStatement whileStatement) { + visitLoop(whileStatement) + super.visitWhileLoop(whileStatement) + } + + @Override + void visitMethod(MethodNode node) { + if (checkOnMethodStart && !node.isSynthetic() && !node.isStatic() && !node.isAbstract()) { + def code = node.code + node.code = wrapBlock(code); + } + if (!node.isSynthetic() && !node.isStatic()) { + super.visitMethod(node) + } + } + + protected SourceUnit getSourceUnit() { + return source; + } + } +} http://git-wip-us.apache.org/repos/asf/groovy/blob/0edfcde9/src/main/groovy/TransformTestHelper.groovy ---------------------------------------------------------------------- diff --git a/src/main/groovy/TransformTestHelper.groovy b/src/main/groovy/TransformTestHelper.groovy new file mode 100644 index 0000000..d9921d5 --- /dev/null +++ b/src/main/groovy/TransformTestHelper.groovy @@ -0,0 +1,123 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.codehaus.groovy.tools.ast + +import org.codehaus.groovy.ast.ClassNode +import org.codehaus.groovy.classgen.GeneratorContext +import org.codehaus.groovy.control.CompilationUnit +import org.codehaus.groovy.control.CompilationUnit.PrimaryClassNodeOperation +import org.codehaus.groovy.control.CompilePhase +import org.codehaus.groovy.control.CompilerConfiguration +import org.codehaus.groovy.control.SourceUnit +import org.codehaus.groovy.transform.ASTTransformation + +import java.security.CodeSource + +/* +* This TestHarness exists so that a global transform can be run without +* using the Jar services mechanism, which requires building a jar. +* +* To use this simply create an instance of TransformTestHelper with +* an ASTTransformation and CompilePhase, then invoke parse(File) or +* parse(String). +* +* This test harness is not exactly the same as executing a global transformation +* but can greatly aide in debugging and testing a transform. You should still +* test your global transformation when packaged as a jar service before +* releasing it. +* +* @author Hamlet D'Arcy +*/ +class TransformTestHelper { + + private final ASTTransformation transform + private final CompilePhase phase + + /** + * Creates the test helper. + * @param transform + * the transform to run when compiling the file later + * @param phase + * the phase to run the transform in + */ + def TransformTestHelper(ASTTransformation transform, CompilePhase phase) { + this.transform = transform + this.phase = phase + } + + /** + * Compiles the File into a Class applying the transform specified in the constructor. + * @input input + * must be a groovy source file + */ + public Class parse(File input) { + TestHarnessClassLoader loader = new TestHarnessClassLoader(transform, phase) + return loader.parseClass(input) + } + + /** + * Compiles the String into a Class applying the transform specified in the constructor. + * @input input + * must be a valid groovy source string + */ + public Class parse(String input) { + TestHarnessClassLoader loader = new TestHarnessClassLoader(transform, phase) + return loader.parseClass(input) + } +} + +/** +* ClassLoader exists so that TestHarnessOperation can be wired into the compile. +* +* @author Hamlet D'Arcy +*/ [email protected] class TestHarnessClassLoader extends GroovyClassLoader { + + private final ASTTransformation transform + private final CompilePhase phase + + TestHarnessClassLoader(ASTTransformation transform, CompilePhase phase) { + this.transform = transform + this.phase = phase + } + + protected CompilationUnit createCompilationUnit(CompilerConfiguration config, CodeSource codeSource) { + CompilationUnit cu = super.createCompilationUnit(config, codeSource) + cu.addPhaseOperation(new TestHarnessOperation(transform), phase.getPhaseNumber()) + return cu + } +} + +/** +* Operation exists so that an AstTransformation can be run against the SourceUnit. +* +* @author Hamlet D'Arcy +*/ [email protected] class TestHarnessOperation extends PrimaryClassNodeOperation { + + private final ASTTransformation transform + + def TestHarnessOperation(transform) { + this.transform = transform; + } + + public void call(SourceUnit source, GeneratorContext context, ClassNode classNode) { + transform.visit(null, source) + } +}
