[ https://issues.apache.org/jira/browse/BEAM-3612?focusedWorklogId=165044&page=com.atlassian.jira.plugin.system.issuetabpanels:worklog-tabpanel#worklog-165044 ]
ASF GitHub Bot logged work on BEAM-3612: ---------------------------------------- Author: ASF GitHub Bot Created on: 12/Nov/18 18:03 Start Date: 12/Nov/18 18:03 Worklog Time Spent: 10m Work Description: aaltay closed pull request #7000: [BEAM-3612] Add a shim generator tool URL: https://github.com/apache/beam/pull/7000 This is a PR merged from a forked repository. As GitHub hides the original diff on merge, it is displayed below for the sake of provenance: As this is a foreign pull request (from a fork), the diff is supplied below (as it won't show otherwise due to GitHub magic): diff --git a/sdks/go/cmd/starcgen/starcgen.go b/sdks/go/cmd/starcgen/starcgen.go new file mode 100644 index 00000000000..87e80110b39 --- /dev/null +++ b/sdks/go/cmd/starcgen/starcgen.go @@ -0,0 +1,154 @@ +// Licensed to the Apache Software Foundation (ASF) under one or more +// contributor license agreements. See the NOTICE file distributed with +// this work for additional information regarding copyright ownership. +// The ASF licenses this file to You under the Apache License, Version 2.0 +// (the "License"); you may not use this file except in compliance with +// the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// starcgen is a tool to generate specialized type assertion shims to be +// used in Apache Beam Go SDK pipelines instead of the default reflection shim. +// This is done through static analysis of go sources for the package in question. 
+package main + +import ( + "flag" + "fmt" + "go/ast" + "go/importer" + "go/parser" + "go/token" + "io" + "log" + "os" + "path/filepath" + "strings" + + "github.com/apache/beam/sdks/go/pkg/beam/util/starcgenx" +) + +var ( + inputs = flag.String("inputs", "", "comma separated list of file with types to create") + output = flag.String("output", "", "output file with types to create") + ids = flag.String("identifiers", "", "comma separated list of package local identifiers for which to generate code") +) + +// Generate takes the typechecked inputs, and generates the shim file for the relevant +// identifiers. +func Generate(w io.Writer, filename, pkg string, ids []string, fset *token.FileSet, files []*ast.File) error { + e := starcgenx.NewExtractor(pkg) + e.Ids = ids + + // Importing from source should work in most cases. + imp := importer.For("source", nil) + if err := e.FromAsts(imp, fset, files); err != nil { + // Always print out the debugging info to the file. + if _, errw := w.Write(e.Bytes()); errw != nil { + return fmt.Errorf("error writing debug data to file after err %v:%v", err, errw) + } + return fmt.Errorf("error extracting from asts: %v", err) + } + + e.Print("*/\n") + data := e.Generate(filename) + if err := write(w, []byte(license)); err != nil { + return err + } + return write(w, data) +} + +func write(w io.Writer, data []byte) error { + n, err := w.Write(data) + if err != nil && n < len(data) { + return fmt.Errorf("short write of data got %d, want %d", n, len(data)) + } + return err +} + +func usage() { + fmt.Fprintf(os.Stderr, "Usage: %v [options] --inputs=<comma separated of go files>\n", filepath.Base(os.Args[0])) + flag.PrintDefaults() +} + +func main() { + flag.Usage = usage + flag.Parse() + + log.SetFlags(log.Lshortfile) + log.SetPrefix("starcgen: ") + + ipts := strings.Split(*inputs, ",") + fset := token.NewFileSet() + var fs []*ast.File + var pkg string + + dir, err := filepath.Abs(filepath.Dir(os.Args[0])) + if err != nil { + log.Fatal(err) 
+ } + + for _, i := range ipts { + f, err := parser.ParseFile(fset, i, nil, 0) + if err != nil { + err1 := err + f, err = parser.ParseFile(fset, filepath.Join(dir, i), nil, 0) + if err != nil { + log.Print(err1) + log.Fatal(err) // parse error + } + } + + if pkg == "" { + pkg = f.Name.Name + } else if pkg != f.Name.Name { + log.Fatalf("Input file %v has mismatched package path, got %q, want %q", i, f.Name.Name, pkg) + } + fs = append(fs, f) + } + if pkg == "" { + log.Fatalf("No package detected in input files: %v", inputs) + } + + if *output == "" { + name := pkg + if len(ipts) == 1 { + name = filepath.Base(ipts[0]) + if index := strings.Index(name, "."); index > 0 { + name = name[:index] + } + } + *output = filepath.Join(filepath.Dir(ipts[0]), name+".shims.go") + } + + f, err := os.OpenFile(*output, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0644) + if err != nil { + log.Fatal(err) + } + if err := Generate(f, *output, pkg, strings.Split(*ids, ","), fset, fs); err != nil { + log.Fatal(err) + } +} + +const license = `// Licensed to the Apache Software Foundation (ASF) under one or more +// contributor license agreements. See the NOTICE file distributed with +// this work for additional information regarding copyright ownership. +// The ASF licenses this file to You under the Apache License, Version 2.0 +// (the "License"); you may not use this file except in compliance with +// the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +` diff --git a/sdks/go/cmd/starcgen/starcgen_test.go b/sdks/go/cmd/starcgen/starcgen_test.go new file mode 100644 index 00000000000..7282ada8a27 --- /dev/null +++ b/sdks/go/cmd/starcgen/starcgen_test.go @@ -0,0 +1,123 @@ +// Licensed to the Apache Software Foundation (ASF) under one or more +// contributor license agreements. See the NOTICE file distributed with +// this work for additional information regarding copyright ownership. +// The ASF licenses this file to You under the Apache License, Version 2.0 +// (the "License"); you may not use this file except in compliance with +// the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +package main + +import ( + "bytes" + "go/ast" + "go/parser" + "go/token" + "strings" + "testing" +) + +func TestGenerate(t *testing.T) { + tests := []struct { + name string + pkg string + files []string + ids []string + expected []string + excluded []string + }{ + {name: "genAllSingleFile", files: []string{hello1}, pkg: "hello", ids: []string{}, + expected: []string{"runtime.RegisterFunction(MyTitle)", "runtime.RegisterFunction(MyOtherDoFn)", "runtime.RegisterType(reflect.TypeOf((*foo)(nil)).Elem())", "funcMakerContext۰ContextStringГString", "funcMakerFooГString"}, + }, + {name: "genSpecificSingleFile", files: []string{hello1}, pkg: "hello", ids: []string{"MyTitle"}, + expected: []string{"runtime.RegisterFunction(MyTitle)", "funcMakerContext۰ContextStringГString"}, + excluded: []string{"MyOtherDoFn", "runtime.RegisterType(reflect.TypeOf((*foo)(nil)).Elem())", "funcMakerFooГString"}, + }, + {name: "genAllMultiFile", files: []string{hello1, hello2}, pkg: "hello", ids: []string{}, + expected: []string{"runtime.RegisterFunction(MyTitle)", "runtime.RegisterFunction(MyOtherDoFn)", "runtime.RegisterFunction(anotherFn)", "runtime.RegisterType(reflect.TypeOf((*foo)(nil)).Elem())", "funcMakerContext۰ContextStringГString", "funcMakerFooГString", "funcMakerShimx۰EmitterГString", "funcMakerShimx۰EmitterГFoo"}, + }, + {name: "genSpecificMultiFile1", files: []string{hello1, hello2}, pkg: "hello", ids: []string{"MyTitle"}, + expected: []string{"runtime.RegisterFunction(MyTitle)", "funcMakerContext۰ContextStringГString"}, + excluded: []string{"MyOtherDoFn", "anotherFn", "runtime.RegisterType(reflect.TypeOf((*foo)(nil)).Elem())", "funcMakerFooГString", "funcMakerShimx۰EmitterГString", "funcMakerShimx۰EmitterГFoo"}, + }, + {name: "genSpecificMultiFile2", files: []string{hello1, hello2}, pkg: "hello", ids: []string{"anotherFn"}, + expected: []string{"funcMakerShimx۰EmitterГString", "funcMakerShimx۰EmitterГString"}, + excluded: []string{"MyOtherDoFn", "MyTitle", 
"runtime.RegisterType(reflect.TypeOf((*foo)(nil)).Elem())", "funcMakerFooГString"}, + }, + } + for _, test := range tests { + test := test + t.Run(test.name, func(t *testing.T) { + fset := token.NewFileSet() + var fs []*ast.File + for i, f := range test.files { + n, err := parser.ParseFile(fset, "", f, 0) + if err != nil { + t.Fatalf("couldn't parse test.files[%d]: %v", i, err) + } + fs = append(fs, n) + } + var b bytes.Buffer + if err := Generate(&b, test.name+".go", test.pkg, test.ids, fset, fs); err != nil { + t.Fatal(err) + } + s := string(b.Bytes()) + for _, i := range test.expected { + if !strings.Contains(s, i) { + t.Errorf("expected %q in generated file", i) + } + } + for _, i := range test.excluded { + if strings.Contains(s, i) { + t.Errorf("found %q in generated file", i) + } + } + t.Log(s) + }) + } +} + +const hello1 = ` +package hello + +import ( + "context" + "strings" +) + +func MyTitle(ctx context.Context, v string) string { + return strings.Title(v) +} + +type foo struct{} + +func MyOtherDoFn(v foo) string { + return "constant" +} +` + +const hello2 = ` +package hello + +import ( + "context" + "strings" + + "github.com/apache/beam/sdks/go/pkg/beam/util/shimx" +) + +func anotherFn(v shimx.Emitter) string { + return v.Name +} + +func fooFn(v shimx.Emitter) foo { + return foo{} +} +` diff --git a/sdks/go/pkg/beam/util/shimx/generate.go b/sdks/go/pkg/beam/util/shimx/generate.go new file mode 100644 index 00000000000..6c0eb4a231e --- /dev/null +++ b/sdks/go/pkg/beam/util/shimx/generate.go @@ -0,0 +1,413 @@ +// Licensed to the Apache Software Foundation (ASF) under one or more +// contributor license agreements. See the NOTICE file distributed with +// this work for additional information regarding copyright ownership. +// The ASF licenses this file to You under the Apache License, Version 2.0 +// (the "License"); you may not use this file except in compliance with +// the License. 
You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package shimx specifies the templates for generating type assertion shims for +// Apache Beam Go SDK pipelines. +// +// In particular, the shims are used by the Beam Go SDK to avoid reflection at runtime, +// which is the default mode of operation. The shims are specialized for the code +// in question, using type assertion to convert arguments as required, and invoke the +// user code. +// +// Similar shims are required for emitters, and iterators in order to propagate values +// out of, and in to user functions respectively without reflection overhead. +// +// Registering user types is required to support user types as PCollection +// types, while registering functions is required to avoid possibly expensive function +// resolution at worker start up, which defaults to using DWARF Symbol tables. +// +// The generator largely relies on basic types and strings to ensure that it's usable +// by dynamic processes via reflection, or by any static analysis approach that is +// used in the future. +package shimx + +import ( + "fmt" + "io" + "sort" + "strings" + "text/template" +) + +// Beam imports that the generated code requires. 
+var ( + ExecImport = "github.com/apache/beam/sdks/go/pkg/beam/core/runtime/exec" + TypexImport = "github.com/apache/beam/sdks/go/pkg/beam/core/typex" + ReflectxImport = "github.com/apache/beam/sdks/go/pkg/beam/core/util/reflectx" + RuntimeImport = "github.com/apache/beam/sdks/go/pkg/beam/core/runtime" +) + +func validateBeamImports() { + checkImportSuffix(ExecImport, "exec") + checkImportSuffix(TypexImport, "typex") + checkImportSuffix(ReflectxImport, "reflectx") + checkImportSuffix(RuntimeImport, "runtime") +} + +func checkImportSuffix(path, suffix string) { + if !strings.HasSuffix(path, suffix) { + panic(fmt.Sprintf("expected %v to end with %v. can't generate valid code", path, suffix)) + } +} + +// Top is the top level inputs into the template file for generating shims. +type Top struct { + FileName, ToolName, Package string + + Imports []string // the full import paths + Functions []string // the plain names of the functions to be registered. + Types []string // the plain names of the types to be registered. + Emitters []Emitter + Inputs []Input + Shims []Func +} + +// sort orders the shims consistently to minimize diffs in the generated code. +func (t *Top) sort() { + sort.Strings(t.Imports) + sort.Strings(t.Functions) + sort.Strings(t.Types) + sort.SliceStable(t.Emitters, func(i, j int) bool { + return t.Emitters[i].Name < t.Emitters[j].Name + }) + sort.SliceStable(t.Inputs, func(i, j int) bool { + return t.Inputs[i].Name < t.Inputs[j].Name + }) + sort.SliceStable(t.Shims, func(i, j int) bool { + return t.Shims[i].Name < t.Shims[j].Name + }) +} + +// processImports removes imports that are otherwise handled by the template +// This method is on the value to shallow copy the Field references to avoid +// mutating the user provided instance. 
+func (t Top) processImports() *Top { + pred := map[string]bool{"reflect": true} + var filtered []string + if len(t.Emitters) > 0 { + pred["context"] = true + } + if len(t.Inputs) > 0 { + pred["fmt"] = true + pred["io"] = true + } + if len(t.Types) > 0 || len(t.Functions) > 0 { + filtered = append(filtered, RuntimeImport) + pred[RuntimeImport] = true + } + if len(t.Shims) > 0 { + filtered = append(filtered, ReflectxImport) + pred[ReflectxImport] = true + } + if len(t.Emitters) > 0 || len(t.Inputs) > 0 { + filtered = append(filtered, ExecImport) + pred[ExecImport] = true + } + needTypexImport := len(t.Emitters) > 0 + for _, i := range t.Inputs { + if i.Time { + needTypexImport = true + break + } + } + if needTypexImport { + filtered = append(filtered, TypexImport) + pred[TypexImport] = true + } + for _, imp := range t.Imports { + if !pred[imp] { + filtered = append(filtered, imp) + } + } + t.Imports = filtered + return &t +} + +// Emitter represents an emitter shim to be generated. +type Emitter struct { + Name, Type string // The user name of the function, the type of the emit. + Time bool // if this uses event time. + Key, Val string // The type of the emits. +} + +// Input represents an iterator shim to be generated. +type Input struct { + Name, Type string // The user name of the function, the type of the iterator (including the bool). + Time bool // if this uses event time. + Key, Val string // The type of the inputs, pointers removed. +} + +// Func represents a type assertion shim for function invocation to be generated. +type Func struct { + Name, Type string + In, Out []string +} + +// Name creates a capitalized identifier from a type string. The identifier +// follows the rules of go identifiers and should be compileable. +// See https://golang.org/ref/spec#Identifiers for details. 
+func Name(t string) string { + if strings.HasPrefix(t, "[]") { + return Name(t[2:] + "Slice") + } + + t = strings.Replace(t, "beam.", "typex.", -1) + t = strings.Replace(t, ".", "۰", -1) // For packages + t = strings.Replace(t, "*", "Ꮨ", -1) // For pointers + t = strings.Replace(t, "[", "_", -1) // For maps + t = strings.Replace(t, "]", "_", -1) + return strings.Title(t) +} + +// FuncName returns a compilable Go identifier for a function, given valid +// type names as generated by Name. +// See https://golang.org/ref/spec#Identifiers for details. +func FuncName(inNames, outNames []string) string { + return fmt.Sprintf("%sГ%s", strings.Join(inNames, ""), strings.Join(outNames, "")) +} + +// File writes go code to the given writer. +func File(w io.Writer, top *Top) { + validateBeamImports() + top = top.processImports() + top.sort() + vampireTemplate.Funcs(funcMap).Execute(w, top) +} + +var vampireTemplate = template.Must(template.New("vampire").Funcs(funcMap).Parse(`// Code generated by {{.ToolName}}. DO NOT EDIT. 
+// File: {{.FileName}} + +package {{.Package}} + +import ( + +{{- if .Emitters}} + "context" +{{- end}} +{{- if .Inputs}} + "fmt" + "io" +{{- end}} + "reflect" +{{- if .Imports}} + + // Library imports +{{- end}} +{{- range $import := .Imports}} + "{{$import}}" +{{- end}} +) + +func init() { +{{- range $x := .Functions}} + runtime.RegisterFunction({{$x}}) +{{- end}} +{{- range $x := .Types}} + runtime.RegisterType(reflect.TypeOf((*{{$x}})(nil)).Elem()) +{{- end}} +{{- range $x := .Shims}} + reflectx.RegisterFunc(reflect.TypeOf((*{{$x.Type}})(nil)).Elem(), funcMaker{{$x.Name}}) +{{- end}} +{{- range $x := .Emitters}} + exec.RegisterEmitter(reflect.TypeOf((*{{$x.Type}})(nil)).Elem(), emitMaker{{$x.Name}}) +{{- end}} +{{- range $x := .Inputs}} + exec.RegisterInput(reflect.TypeOf((*{{$x.Type}})(nil)).Elem(), iterMaker{{$x.Name}}) +{{- end}} +} + +{{range $x := .Shims -}} +type caller{{$x.Name}} struct { + fn {{$x.Type}} +} + +func funcMaker{{$x.Name}}(fn interface{}) reflectx.Func { + f := fn.({{$x.Type}}) + return &caller{{$x.Name}}{fn: f} +} + +func (c *caller{{$x.Name}}) Name() string { + return reflectx.FunctionName(c.fn) +} + +func (c *caller{{$x.Name}}) Type() reflect.Type { + return reflect.TypeOf(c.fn) +} + +func (c *caller{{$x.Name}}) Call(args []interface{}) []interface{} { + {{mktuplef (len $x.Out) "out%d"}}{{- if len $x.Out}} := {{end -}}c.fn({{mkparams "args[%d].(%v)" $x.In}}) + return []interface{}{ {{- mktuplef (len $x.Out) "out%d" -}} } +} + +func (c *caller{{$x.Name}}) Call{{len $x.In}}x{{len $x.Out}}({{mkargs (len $x.In) "arg%v" "interface{}"}}) ({{- mktuple (len $x.Out) "interface{}"}}) { + {{if len $x.Out}}return {{end}}c.fn({{mkparams "arg%d.(%v)" $x.In}}) +} + +{{end}} +{{if .Emitters -}} +type emitNative struct { + n exec.ElementProcessor + fn interface{} + + ctx context.Context + ws []typex.Window + et typex.EventTime +} + +func (e *emitNative) Init(ctx context.Context, ws []typex.Window, et typex.EventTime) error { + e.ctx = ctx + e.ws = ws + 
e.et = et + return nil +} + +func (e *emitNative) Value() interface{} { + return e.fn +} + +{{end -}} +{{range $x := .Emitters -}} +func emitMaker{{$x.Name}}(n exec.ElementProcessor) exec.ReusableEmitter { + ret := &emitNative{n: n} + ret.fn = ret.invoke{{.Name}} + return ret +} + +func (e *emitNative) invoke{{$x.Name}}({{if $x.Time -}} t typex.EventTime, {{end}}{{if $x.Key}}key {{$x.Key}}, {{end}}val {{$x.Val}}) { + value := exec.FullValue{Windows: e.ws, Timestamp: {{- if $x.Time}} t{{else}} e.et{{end}}, {{- if $x.Key}} Elm: key, Elm2: val {{else}} Elm: val{{end -}} } + if err := e.n.ProcessElement(e.ctx, value); err != nil { + panic(err) + } +} + +{{end}} +{{- if .Inputs -}} +type iterNative struct { + s exec.ReStream + fn interface{} + + // cur is the "current" stream, if any. + cur exec.Stream +} + +func (v *iterNative) Init() error { + cur, err := v.s.Open() + if err != nil { + return err + } + v.cur = cur + return nil +} + +func (v *iterNative) Value() interface{} { + return v.fn +} + +func convToString(v interface{}) string { + switch v.(type) { + case []byte: + return string(v.([]byte)) + default: + return v.(string) + } +} + +func (v *iterNative) Reset() error { + if err := v.cur.Close(); err != nil { + return err + } + v.cur = nil + return nil +} +{{- end}} +{{- range $x := .Inputs}} +func iterMaker{{$x.Name}}(s exec.ReStream) exec.ReusableInput { + ret := &iterNative{s: s} + ret.fn = ret.read{{$x.Name}} + return ret +} + +func (v *iterNative) read{{$x.Name}}({{if $x.Time -}} et *typex.EventTime, {{end}}{{if $x.Key}}key *{{$x.Key}}, {{end}}value *{{$x.Val}}) bool { + elm, err := v.cur.Read() + if err != nil { + if err == io.EOF { + return false + } + panic(fmt.Sprintf("broken stream: %v", err)) + } + +{{- if $x.Time}} + *et = elm.Timestamp +{{- end}} +{{- if eq $x.Key "string"}} + *key = convToString(elm.Elm) +{{- else if $x.Key}} + *key = elm.Elm.({{$x.Key}}) +{{- end}} +{{- if eq $x.Val "string"}} + *value = convToString(elm.Elm{{- if $x.Key -}} 2 {{- 
end -}}) +{{- else}} + *value = elm.Elm{{- if $x.Key -}} 2 {{- end -}}.({{$x.Val}}) +{{- end}} + return true +} +{{- end}} + +// DO NOT MODIFY: GENERATED CODE +`)) + +// funcMap contains useful functions for use in the template. +var funcMap template.FuncMap = map[string]interface{}{ + "mkargs": mkargs, + "mkparams": mkparams, + "mktuple": mktuple, + "mktuplef": mktuplef, +} + +// mkargs(n, format, type) returns "<fmt.Sprintf(format, 0)>, .., <fmt.Sprintf(format, n-1)> type". +// If n is 0, it returns the empty string. +func mkargs(n int, format, typ string) string { + if n == 0 { + return "" + } + return fmt.Sprintf("%v %v", mktuplef(n, format), typ) +} + +// mkparams(format, []type) returns "<fmt.Sprintf(format, 0, type[0])>, .., <fmt.Sprintf(format, n-1, type[n-1])>". +func mkparams(format string, types []string) string { + var ret []string + for i, t := range types { + ret = append(ret, fmt.Sprintf(format, i, t)) + } + return strings.Join(ret, ", ") +} + +// mktuple(n, v) returns "v, v, ..., v". +func mktuple(n int, v string) string { + var ret []string + for i := 0; i < n; i++ { + ret = append(ret, v) + } + return strings.Join(ret, ", ") +} + +// mktuplef(n, format) returns "<fmt.Sprintf(format, 0)>, .., <fmt.Sprintf(format, n-1)>" +func mktuplef(n int, format string) string { + var ret []string + for i := 0; i < n; i++ { + ret = append(ret, fmt.Sprintf(format, i)) + } + return strings.Join(ret, ", ") +} diff --git a/sdks/go/pkg/beam/util/shimx/generate_test.go b/sdks/go/pkg/beam/util/shimx/generate_test.go new file mode 100644 index 00000000000..3696bbab7f3 --- /dev/null +++ b/sdks/go/pkg/beam/util/shimx/generate_test.go @@ -0,0 +1,217 @@ +// Licensed to the Apache Software Foundation (ASF) under one or more +// contributor license agreements. See the NOTICE file distributed with +// this work for additional information regarding copyright ownership. 
+// The ASF licenses this file to You under the Apache License, Version 2.0 +// (the "License"); you may not use this file except in compliance with +// the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package shimx + +import ( + "bytes" + "sort" + "testing" +) + +func TestTop_Sort(t *testing.T) { + top := Top{ + Imports: []string{"z", "a", "r"}, + Functions: []string{"z", "a", "r"}, + Types: []string{"z", "a", "r"}, + Emitters: []Emitter{{Name: "z"}, {Name: "a"}, {Name: "r"}}, + Inputs: []Input{{Name: "z"}, {Name: "a"}, {Name: "r"}}, + Shims: []Func{{Name: "z"}, {Name: "a"}, {Name: "r"}}, + } + + top.sort() + if !sort.SliceIsSorted(top.Imports, func(i, j int) bool { return top.Imports[i] < top.Imports[j] }) { + t.Errorf("top.Imports not sorted: got %v, want it sorted", top.Imports) + } + if !sort.SliceIsSorted(top.Functions, func(i, j int) bool { return top.Functions[i] < top.Functions[j] }) { + t.Errorf("top.Types not sorted: got %v, want it sorted", top.Functions) + } + if !sort.SliceIsSorted(top.Types, func(i, j int) bool { return top.Types[i] < top.Types[j] }) { + t.Errorf("top.Types not sorted: got %v, want it sorted", top.Types) + } + if !sort.SliceIsSorted(top.Emitters, func(i, j int) bool { return top.Emitters[i].Name < top.Emitters[j].Name }) { + t.Errorf("top.Emitters not sorted by name: got %v, want it sorted", top.Emitters) + } + if !sort.SliceIsSorted(top.Inputs, func(i, j int) bool { return top.Inputs[i].Name < top.Inputs[j].Name }) { + t.Errorf("top.Inputs not sorted by name: got %v, want it sorted", top.Inputs) + } + if !sort.SliceIsSorted(top.Shims, func(i, j int) 
bool { return top.Shims[i].Name < top.Shims[j].Name }) { + t.Errorf("top.Shims not sorted: got %v, want it sorted", top.Shims) + } +} + +func TestTop_ProcessImports(t *testing.T) { + needsFiltering := []string{"context", "keepit", "fmt", "io", "reflect", "unrelated"} + + tests := []struct { + name string + got *Top + want []string + }{ + {name: "default", got: &Top{}, want: []string{"context", "keepit", "fmt", "io", "unrelated"}}, + {name: "emit", got: &Top{Emitters: []Emitter{{Name: "emit"}}}, want: []string{ExecImport, TypexImport, "keepit", "fmt", "io", "unrelated"}}, + {name: "iter", got: &Top{Inputs: []Input{{Name: "iter"}}}, want: []string{ExecImport, "context", "keepit", "unrelated"}}, + {name: "iterWTime", got: &Top{Inputs: []Input{{Name: "iterWTime", Time: true}}}, want: []string{ExecImport, TypexImport, "context", "keepit", "unrelated"}}, + {name: "shim", got: &Top{Shims: []Func{{Name: "emit"}}}, want: []string{ReflectxImport, "context", "keepit", "fmt", "io", "unrelated"}}, + {name: "iter&emit", got: &Top{Emitters: []Emitter{{Name: "emit"}}, Inputs: []Input{{Name: "iter"}}}, want: []string{ExecImport, TypexImport, "keepit", "unrelated"}}, + {name: "functions", got: &Top{Functions: []string{"func1"}}, want: []string{RuntimeImport, "context", "keepit", "fmt", "io", "unrelated"}}, + {name: "types", got: &Top{Types: []string{"func1"}}, want: []string{RuntimeImport, "context", "keepit", "fmt", "io", "unrelated"}}, + } + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + top := test.got + top.Imports = needsFiltering + top = top.processImports() + for i := range top.Imports { + if top.Imports[i] != test.want[i] { + t.Fatalf("want %v, got %v", test.want, top.Imports) + } + } + }) + } +} + +func TestName(t *testing.T) { + tests := []struct { + have, want string + }{ + {"int", "Int"}, + {"foo.MyInt", "Foo۰MyInt"}, + {"[]beam.X", "Typex۰XSlice"}, + {"map[int]beam.X", "Map_int_typex۰X"}, + {"map[string]*beam.X", "Map_string_Ꮨtypex۰X"}, + } + for 
_, test := range tests { + if got := Name(test.have); got != test.want { + t.Errorf("Name(%v) = %v, want %v", test.have, got, test.want) + } + } +} + +func TestFuncName(t *testing.T) { + tests := []struct { + in, out []string + want string + }{ + {in: []string{"Int"}, out: []string{"Int"}, want: "IntГInt"}, + {in: []string{"Int"}, out: []string{}, want: "IntГ"}, + {in: []string{}, out: []string{"Bool"}, want: "ГBool"}, + {in: []string{"Bool", "String"}, out: []string{"Int", "Bool"}, want: "BoolStringГIntBool"}, + {in: []string{"String", "Map_int_typex۰X"}, out: []string{"Int", "Typex۰XSlice"}, want: "StringMap_int_typex۰XГIntTypex۰XSlice"}, + } + for _, test := range tests { + if got := FuncName(test.in, test.out); got != test.want { + t.Errorf("FuncName(%v,%v) = %v, want %v", test.in, test.out, got, test.want) + } + } +} + +func TestFile(t *testing.T) { + top := Top{ + Package: "gentest", + Imports: []string{"z", "a", "r"}, + Functions: []string{"z", "a", "r"}, + Types: []string{"z", "a", "r"}, + Emitters: []Emitter{ + {Name: "z", Type: "func(int)", Val: "Int"}, + {Name: "a", Type: "func(bool, int) bool", Key: "bool", Val: "int"}, + {Name: "r", Type: "func(typex.EventTime, string) bool", Time: true, Val: "string"}, + }, + Inputs: []Input{ + {Name: "z", Type: "func(*int) bool"}, + {Name: "a", Type: "func(*bool, *int) bool", Key: "bool", Val: "int"}, + {Name: "r", Type: "func(*typex.EventTime, *string) bool", Time: true, Val: "string"}, + }, + Shims: []Func{ + {Name: "z", Type: "func(string, func(int))", In: []string{"string", "func(int)"}}, + {Name: "a", Type: "func(float64) (int, int)", In: []string{"float64"}, Out: []string{"int", "int"}}, + {Name: "r", Type: "func(string, func(int))", In: []string{"string", "func(int)"}}, + }, + } + top.sort() + + var b bytes.Buffer + if err := vampireTemplate.Funcs(funcMap).Execute(&b, top); err != nil { + t.Errorf("error generating template: %v", err) + } +} + +func TestMkargs(t *testing.T) { + tests := []struct { + n int + 
format, typ string + want string + }{ + {n: 0, format: "Foo", typ: "Bar", want: ""}, + {n: 1, format: "arg%d", typ: "Bar", want: "arg0 Bar"}, + {n: 4, format: "a%d", typ: "Baz", want: "a0, a1, a2, a3 Baz"}, + } + for _, test := range tests { + if got := mkargs(test.n, test.format, test.typ); got != test.want { + t.Errorf("mkargs(%v,%v,%v) = %v, want %v", test.n, test.format, test.typ, got, test.want) + } + } +} + +func TestMkparams(t *testing.T) { + tests := []struct { + format string + types []string + want string + }{ + {format: "Foo", types: []string{}, want: ""}, + {format: "arg%d %v", types: []string{"Bar"}, want: "arg0 Bar"}, + {format: "a%d %v", types: []string{"Foo", "Baz", "interface{}"}, want: "a0 Foo, a1 Baz, a2 interface{}"}, + } + for _, test := range tests { + if got := mkparams(test.format, test.types); got != test.want { + t.Errorf("mkparams(%v,%v) = %v, want %v", test.format, test.types, got, test.want) + } + } +} + +func TestMktuple(t *testing.T) { + tests := []struct { + n int + v string + want string + }{ + {n: 0, v: "Foo", want: ""}, + {n: 1, v: "Bar", want: "Bar"}, + {n: 4, v: "Baz", want: "Baz, Baz, Baz, Baz"}, + } + for _, test := range tests { + if got := mktuple(test.n, test.v); got != test.want { + t.Errorf("mktuple(%v,%v) = %v, want %v", test.n, test.v, got, test.want) + } + } +} + +func TestMktuplef(t *testing.T) { + tests := []struct { + n int + format, typ string + want string + }{ + {n: 0, format: "Foo%d", want: ""}, + {n: 1, format: "arg%d", want: "arg0"}, + {n: 4, format: "a%d", want: "a0, a1, a2, a3"}, + } + for _, test := range tests { + if got := mktuplef(test.n, test.format); got != test.want { + t.Errorf("mktuplef(%v,%v) = %v, want %v", test.n, test.format, got, test.want) + } + } +} diff --git a/sdks/go/pkg/beam/util/starcgenx/starcgenx.go b/sdks/go/pkg/beam/util/starcgenx/starcgenx.go new file mode 100644 index 00000000000..003a3c91df7 --- /dev/null +++ b/sdks/go/pkg/beam/util/starcgenx/starcgenx.go @@ -0,0 +1,562 @@ +// 
Licensed to the Apache Software Foundation (ASF) under one or more +// contributor license agreements. See the NOTICE file distributed with +// this work for additional information regarding copyright ownership. +// The ASF licenses this file to You under the Apache License, Version 2.0 +// (the "License"); you may not use this file except in compliance with +// the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package starcgenx is a Static Analysis Type Assertion shim and Registration Code Generator +// which provides an extractor to extract types from a package, in order to generate +// appropriate shims for a package so code can be generated for it. +// +// It's written for use by the starcgen tool, but separate to permit +// alternative "go/importer" Importers for accessing types from imported packages. +package starcgenx + +import ( + "bytes" + "fmt" + "go/ast" + "go/token" + "go/types" + "strings" + + "github.com/apache/beam/sdks/go/pkg/beam/util/shimx" +) + +// NewExtractor returns an extractor for the given package. +func NewExtractor(pkg string) *Extractor { + return &Extractor{ + Package: pkg, + functions: make(map[string]struct{}), + types: make(map[string]struct{}), + funcs: make(map[string]*types.Signature), + emits: make(map[string]shimx.Emitter), + iters: make(map[string]shimx.Input), + imports: make(map[string]struct{}), + allExported: true, + } +} + +// Extractor contains and uniquifies the cache of types and things that need to be generated. 
+type Extractor struct { + w bytes.Buffer + Package string + debug bool + + // Ids is an optional slice of package local identifiers + Ids []string + + // Register and uniquify the needed shims for each kind. + // Functions to Register + functions map[string]struct{} + // Types to Register (structs, essentially) + types map[string]struct{} + // FuncShims needed + funcs map[string]*types.Signature + // Emitter Shims needed + emits map[string]shimx.Emitter + // Iterator Shims needed + iters map[string]shimx.Input + + // list of packages we need to import. + imports map[string]struct{} + + allExported bool // Marks if all ptransforms are exported and available in main. +} + +// Summary prints out a summary of the shims and registrations to +// be generated to the buffer. +func (e *Extractor) Summary() { + e.Print("\n") + e.Print("Summary\n") + e.Printf("All exported?: %v\n", e.allExported) + e.Printf("%d\t Functions\n", len(e.functions)) + e.Printf("%d\t Types\n", len(e.types)) + e.Printf("%d\t Shims\n", len(e.funcs)) + e.Printf("%d\t Emits\n", len(e.emits)) + e.Printf("%d\t Inputs\n", len(e.iters)) +} + +// lifecycleMethodName returns if the passed in string is one of the lifecycle method names used +// by the Go SDK as DoFn or CombineFn lifecycle methods. These are the only methods that need +// shims generated for them, as per beam/core/graph/fn.go +// TODO(lostluck): Move this to beam/core/graph/fn.go, so it can stay up to date. +func lifecycleMethodName(n string) bool { + switch n { + case "ProcessElement", "StartBundle", "FinishBundle", "Setup", "Teardown", "CreateAccumulator", "AddInput", "MergeAccumulators", "ExtractOutput", "Compact": + return true + default: + return false + } +} + +// Bytes returns the contents of the extractor buffer. +func (e *Extractor) Bytes() []byte { + return e.w.Bytes() +} + +// Print forwards to fmt.Fprint to the extractor buffer. 
+func (e *Extractor) Print(s string) { + if e.debug { + fmt.Fprint(&e.w, s) + } +} + +// Printf forwards to fmt.Printf to the extractor buffer. +func (e *Extractor) Printf(f string, args ...interface{}) { + if e.debug { + fmt.Fprintf(&e.w, f, args...) + } +} + +// FromAsts analyses the contents of a package +func (e *Extractor) FromAsts(imp types.Importer, fset *token.FileSet, files []*ast.File) error { + conf := types.Config{ + Importer: imp, + IgnoreFuncBodies: true, + DisableUnusedImportCheck: true, + } + info := &types.Info{ + Defs: make(map[*ast.Ident]types.Object), + } + if len(e.Ids) != 0 { + // TODO(lostluck): This becomes unnnecessary iff we can figure out + // which ParDos are being passed to beam.ParDo or beam.Combine. + // If there are ids, we need to also look at function bodies, and uses. + var checkFuncBodies bool + for _, v := range e.Ids { + if strings.Contains(v, ".") { + checkFuncBodies = true + break + } + } + conf.IgnoreFuncBodies = !checkFuncBodies + info.Uses = make(map[*ast.Ident]types.Object) + } + + if _, err := conf.Check(e.Package, fset, files, info); err != nil { + return fmt.Errorf("failed to type check package %s : %v", e.Package, err) + } + + e.Print("/*\n") + var idsRequired, idsFound map[string]bool + if len(e.Ids) > 0 { + e.Printf("Filtering by %d identifiers: %q\n", len(e.Ids), strings.Join(e.Ids, ", ")) + idsRequired = make(map[string]bool) + idsFound = make(map[string]bool) + for _, id := range e.Ids { + idsRequired[id] = true + } + } + e.Print("CHECKING DEFS\n") + for id, obj := range info.Defs { + e.fromObj(fset, id, obj, idsRequired, idsFound) + } + e.Print("CHECKING USES\n") + for id, obj := range info.Uses { + e.fromObj(fset, id, obj, idsRequired, idsFound) + } + var notFound []string + for _, k := range e.Ids { + if !idsFound[k] { + notFound = append(notFound, k) + } + } + if len(notFound) > 0 { + return fmt.Errorf("couldn't find the following identifiers; please check for typos, or remove them: %v", 
strings.Join(notFound, ", ")) + } + e.Print("*/\n") + + return nil +} + +func (e *Extractor) isRequired(ident string, obj types.Object, idsRequired, idsFound map[string]bool) bool { + if idsRequired == nil { + return true + } + if idsFound == nil { + panic("broken invariant: idsFound map is nil, but idsRequired map exists") + } + // If we're filtering IDs, then it needs to be in the filtered identifiers, + // or it's receiver type identifier needs to be in the filtered identifiers. + if idsRequired[ident] { + idsFound[ident] = true + return true + } + // Check if this is a function. + sig, ok := obj.Type().(*types.Signature) + if !ok { + return false + } + // If this is a function, and it has a receiver, it's a method. + if recv := sig.Recv(); recv != nil && lifecycleMethodName(ident) { + // We don't want to care about pointers, so dereference to value type. + t := recv.Type() + p, ok := t.(*types.Pointer) + for ok { + t = p.Elem() + p, ok = t.(*types.Pointer) + } + ts := types.TypeString(t, e.qualifier) + e.Printf("RRR has %v, ts: %s %s--- ", sig, ts, ident) + if !idsRequired[ts] { + e.Print("IGNORE\n") + return false + } + e.Print("KEEP\n") + idsFound[ts] = true + return true + } + return false +} + +func (e *Extractor) fromObj(fset *token.FileSet, id *ast.Ident, obj types.Object, idsRequired, idsFound map[string]bool) { + if obj == nil { // Omit the package declaration. + e.Printf("%s: %q has no object, probably a package\n", + fset.Position(id.Pos()), id.Name) + return + } + + pkg := obj.Pkg() + if pkg == nil { + e.Printf("%s: %q has no package \n", + fset.Position(id.Pos()), id.Name) + // No meaningful identifier. + return + } + ident := fmt.Sprintf("%s.%s", pkg.Name(), obj.Name()) + if pkg.Name() == e.Package { + ident = obj.Name() + } + if !e.isRequired(ident, obj, idsRequired, idsFound) { + return + } + + switch ot := obj.(type) { + case *types.Var: + // Vars are tricky since they could be anything, and anywhere (package scope, parameters, etc) + // eg. 
Flags, or Field Tags, among others. + // I'm increasingly convinced that we should simply igonore vars. + // Do nothing for vars. + case *types.Func: + sig := obj.Type().(*types.Signature) + if recv := sig.Recv(); recv != nil { + // Methods don't need registering, but they do need shim generation. + e.Printf("%s: %q is a method of %v -> %v--- %T %v %v %v\n", + fset.Position(id.Pos()), id.Name, recv.Type(), obj, obj, id, obj.Pkg(), obj.Type()) + if !lifecycleMethodName(id.Name) { + // If this is not a lifecycle method, we should ignore it. + return + } + } else if id.Name != "init" { + // init functions are special and should be ignored. + // Functions need registering, as well as shim generation. + e.Printf("%s: %q is a top level func %v --- %T %v %v %v\n", + fset.Position(id.Pos()), ident, obj, obj, id, obj.Pkg(), obj.Type()) + e.functions[ident] = struct{}{} + } + // For functions from other packages. + if pkg.Name() != e.Package { + e.imports[pkg.Path()] = struct{}{} + } + + e.funcs[e.sigKey(sig)] = sig + e.extractParameters(sig) + e.Printf("\t%v\n", sig) + case *types.TypeName: + e.Printf("%s: %q is a type %v --- %T %v %v %v %v\n", + fset.Position(id.Pos()), id.Name, obj, obj, id, obj.Pkg(), obj.Type(), obj.Name()) + // Probably need to sanity check that this type actually is/has a ProcessElement + // or MergeAccumulators defined for this type so unnecessary registrations don't happen, + // an can explicitly produce an error if an explicitly named type *isn't* a DoFn or CombineFn. + e.extractType(ot) + default: + e.Printf("%s: %q defines %v --- %T %v %v %v\n", + fset.Position(id.Pos()), types.ObjectString(obj, e.qualifier), obj, obj, id, obj.Pkg(), obj.Type()) + } +} + +func (e *Extractor) extractType(ot *types.TypeName) { + name := types.TypeString(ot.Type(), e.qualifier) + // Unwrap an alias by one level. + // Attempting to deference a full chain of aliases runs the risk of crossing + // a visibility boundary such as internal packages. 
+ // A single level is safe since the code we're analysing imports it, + // so we can assume the generated code can access it too. + if ot.IsAlias() { + if t, ok := ot.Type().(*types.Named); ok { + ot = t.Obj() + name = types.TypeString(t, e.qualifier) + + if pkg := t.Obj().Pkg(); pkg != nil { + e.imports[pkg.Path()] = struct{}{} + } + } + } + e.types[name] = struct{}{} +} + +// Examines the signature and extracts types of parameters for generating +// necessary imports and emitter and iterator code. +func (e *Extractor) extractParameters(sig *types.Signature) { + in := sig.Params() // *types.Tuple + for i := 0; i < in.Len(); i++ { + s := in.At(i) // *types.Var + + // Pointer types need to be iteratively unwrapped until we're at the base type, + // so we can get the import if necessary. + t := s.Type() + p, ok := t.(*types.Pointer) + for ok { + t = p.Elem() + p, ok = t.(*types.Pointer) + } + // Here's were we ensure we register new imports. + if t, ok := t.(*types.Named); ok { + if pkg := t.Obj().Pkg(); pkg != nil { + e.imports[pkg.Path()] = struct{}{} + } + e.extractType(t.Obj()) + } + + if a, ok := s.Type().(*types.Signature); ok { + // Check if the type is an emitter or iterator for the specialized + // shim generation for those types. + if emt, ok := e.makeEmitter(a); ok { + e.emits[emt.Name] = emt + } + if ipt, ok := e.makeInput(a); ok { + e.iters[ipt.Name] = ipt + } + // Tail recurse on functional parameters. 
+ e.extractParameters(a) + } + } +} + +func (e *Extractor) qualifier(pkg *types.Package) string { + n := tail(pkg.Name()) + if n == e.Package { + return "" + } + return n +} + +// tail returns the last "/"-separated element of path (path itself if there is no "/"). +func tail(path string) string { + if i := strings.LastIndex(path, "/"); i >= 0 { + path = path[i+1:] + } + return path +} + +func (e *Extractor) tupleStrings(t *types.Tuple) []string { + var vs []string + for i := 0; i < t.Len(); i++ { + v := t.At(i) + vs = append(vs, types.TypeString(v.Type(), e.qualifier)) + } + return vs +} + +// sigKey produces a variable name agnostic key for the function signature. +func (e *Extractor) sigKey(sig *types.Signature) string { + ps, rs := e.tupleStrings(sig.Params()), e.tupleStrings(sig.Results()) + return fmt.Sprintf("func(%v) (%v)", strings.Join(ps, ","), strings.Join(rs, ",")) +} + +// Generate produces an additional file for the Go package that was extracted, +// to be included in a subsequent compilation. +func (e *Extractor) Generate(filename string) []byte { + var functions []string + for fn := range e.functions { + // No extra processing necessary, since these should all be package local. 
+ functions = append(functions, fn) + } + var typs []string + for t := range e.types { + typs = append(typs, t) + } + var shims []shimx.Func + for sig, t := range e.funcs { + shim := shimx.Func{Type: sig} + var inNames []string + in := t.Params() // *types.Tuple + for i := 0; i < in.Len(); i++ { + s := in.At(i) // *types.Var + shim.In = append(shim.In, types.TypeString(s.Type(), e.qualifier)) + inNames = append(inNames, e.NameType(s.Type())) + } + var outNames []string + out := t.Results() // *types.Tuple + for i := 0; i < out.Len(); i++ { + s := out.At(i) + shim.Out = append(shim.Out, types.TypeString(s.Type(), e.qualifier)) + outNames = append(outNames, e.NameType(s.Type())) + } + shim.Name = shimx.FuncName(inNames, outNames) + shims = append(shims, shim) + } + var emits []shimx.Emitter + for _, t := range e.emits { + emits = append(emits, t) + } + var inputs []shimx.Input + for _, t := range e.iters { + inputs = append(inputs, t) + } + + var imports []string + for k := range e.imports { + if k == "" || k == e.Package { + continue + } + imports = append(imports, k) + } + + top := shimx.Top{ + FileName: filename, + ToolName: "starcgen", + Package: e.Package, + Imports: imports, + Functions: functions, + Types: typs, + Shims: shims, + Emitters: emits, + Inputs: inputs, + } + e.Print("\n") + shimx.File(&e.w, &top) + return e.w.Bytes() +} + +func (e *Extractor) makeEmitter(sig *types.Signature) (shimx.Emitter, bool) { + // Emitters must have no return values. + if sig.Results().Len() != 0 { + return shimx.Emitter{}, false + } + p := sig.Params() + emt := shimx.Emitter{Type: e.sigKey(sig)} + switch p.Len() { + case 1: + emt.Time = false + emt.Val = e.varString(p.At(0)) + case 2: + // TODO(rebo): Fix this when imports are resolved. + // This is the tricky one... Need to verify what happens with aliases. + // And get a candle to compare this against somehwere. isEventTime(p.At(0)) maybe. 
+ // if p.At(0) == typex.EventTimeType { + // emt.Time = true + // } else { + emt.Key = e.varString(p.At(0)) + //} + emt.Val = e.varString(p.At(1)) + case 3: + // If there's 3, the first one must be typex.EventTime. + emt.Time = true + emt.Key = e.varString(p.At(1)) + emt.Val = e.varString(p.At(2)) + default: + return shimx.Emitter{}, false + } + if emt.Time { + emt.Name = fmt.Sprintf("ET%s%s", shimx.Name(emt.Key), shimx.Name(emt.Val)) + } else { + emt.Name = fmt.Sprintf("%s%s", shimx.Name(emt.Key), shimx.Name(emt.Val)) + } + return emt, true +} + +// makeInput checks if the given signature is an iterator or not, and if so, +// returns a shimx.Input struct for the signature for use by the code +// generator. The canonical check for an iterater signature is in the +// funcx.UnfoldIter function which uses the reflect library, +// and this logic is replicated here. +func (e *Extractor) makeInput(sig *types.Signature) (shimx.Input, bool) { + r := sig.Results() + if r.Len() != 1 { + return shimx.Input{}, false + } + // Iterators must return a bool. + if b, ok := r.At(0).Type().(*types.Basic); !ok || b.Kind() != types.Bool { + return shimx.Input{}, false + } + p := sig.Params() + for i := 0; i < p.Len(); i++ { + // All params for iterators must be pointers. + if _, ok := p.At(i).Type().(*types.Pointer); !ok { + return shimx.Input{}, false + } + } + itr := shimx.Input{Type: e.sigKey(sig)} + switch p.Len() { + case 1: + itr.Time = false + itr.Val = e.deref(p.At(0)) + case 2: + // TODO(rebo): Fix this when imports are resolved. + // This is the tricky one... Need to verify what happens with aliases. + // And get a candle to compare this against somehwere. isEventTime(p.At(0)) maybe. + // if p.At(0) == typex.EventTimeType { + // itr.Time = true + // } else { + itr.Key = e.deref(p.At(0)) + //} + itr.Val = e.deref(p.At(1)) + case 3: + // If there's 3, the first one must be typex.EventTime. 
+ itr.Time = true + itr.Key = e.deref(p.At(1)) + itr.Val = e.deref(p.At(2)) + default: + return shimx.Input{}, false + } + if itr.Time { + itr.Name = fmt.Sprintf("ET%s%s", shimx.Name(itr.Key), shimx.Name(itr.Val)) + } else { + itr.Name = fmt.Sprintf("%s%s", shimx.Name(itr.Key), shimx.Name(itr.Val)) + } + return itr, true +} + +// deref returns the string identifier for the element type of a pointer var. +// deref panics if the var type is not a pointer. +func (e *Extractor) deref(v *types.Var) string { + p := v.Type().(*types.Pointer) + return types.TypeString(p.Elem(), e.qualifier) +} + +// varString provides the correct type for a variable within the +// package for which we're generated code. +func (e *Extractor) varString(v *types.Var) string { + return types.TypeString(v.Type(), e.qualifier) +} + +// NameType turns a types.Type into a string based on its name. +// It prefixes Emit or Iter if the function satisfies the constraints of those types. +func (e *Extractor) NameType(t types.Type) string { + switch a := t.(type) { + case *types.Signature: + if emt, ok := e.makeEmitter(a); ok { + return "Emit" + emt.Name + } + if ipt, ok := e.makeInput(a); ok { + return "Iter" + ipt.Name + } + return shimx.Name(e.sigKey(a)) + case *types.Pointer: + return e.NameType(a.Elem()) + case *types.Slice: + return "Sliceof" + e.NameType(a.Elem()) + default: + return shimx.Name(types.TypeString(t, e.qualifier)) + } +} diff --git a/sdks/go/pkg/beam/util/starcgenx/starcgenx_test.go b/sdks/go/pkg/beam/util/starcgenx/starcgenx_test.go new file mode 100644 index 00000000000..9141acb114e --- /dev/null +++ b/sdks/go/pkg/beam/util/starcgenx/starcgenx_test.go @@ -0,0 +1,145 @@ +// Licensed to the Apache Software Foundation (ASF) under one or more +// contributor license agreements. See the NOTICE file distributed with +// this work for additional information regarding copyright ownership. 
+// The ASF licenses this file to You under the Apache License, Version 2.0 +// (the "License"); you may not use this file except in compliance with +// the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package starcgenx + +import ( + "go/ast" + "go/importer" + "go/parser" + "go/token" + "strings" + "testing" +) + +func TestExtractor(t *testing.T) { + tests := []struct { + name string + pkg string + files []string + ids []string + expected []string + excluded []string + }{ + {name: "pardo1", files: []string{pardo}, pkg: "pardo", + expected: []string{"runtime.RegisterFunction(MyIdent)", "runtime.RegisterFunction(MyDropVal)", "runtime.RegisterFunction(MyOtherDoFn)", "runtime.RegisterType(reflect.TypeOf((*foo)(nil)).Elem())", "funcMakerStringГString", "funcMakerIntStringГInt", "funcMakerFooГStringFoo"}, + }, + {name: "emits1", files: []string{emits}, pkg: "emits", + expected: []string{"runtime.RegisterFunction(anotherFn)", "runtime.RegisterFunction(emitFn)", "runtime.RegisterType(reflect.TypeOf((*reInt)(nil)).Elem())", "funcMakerEmitIntIntГ", "emitMakerIntInt", "funcMakerIntIntEmitIntIntГError"}, + }, + {name: "iters1", files: []string{iters}, pkg: "iters", + expected: []string{"runtime.RegisterFunction(iterFn)", "funcMakerStringIterIntГ", "iterMakerInt"}, + }, + {name: "structs1", files: []string{structs}, pkg: "structs", ids: []string{"myDoFn"}, + expected: []string{"runtime.RegisterType(reflect.TypeOf((*myDoFn)(nil)).Elem())", "funcMakerEmitIntГ", "emitMakerInt", "funcMakerValTypeValTypeEmitIntГ", "runtime.RegisterType(reflect.TypeOf((*valType)(nil)).Elem())"}, + excluded: 
[]string{"funcMakerStringГ", "emitMakerString", "nonPipelineType"}, + }, + } + for _, test := range tests { + test := test + t.Run(test.name, func(t *testing.T) { + fset := token.NewFileSet() + var fs []*ast.File + for i, f := range test.files { + n, err := parser.ParseFile(fset, "", f, 0) + if err != nil { + t.Fatalf("couldn't parse test.files[%d]: %v", i, err) + } + fs = append(fs, n) + } + e := NewExtractor(test.pkg) + e.Ids = test.ids + if err := e.FromAsts(importer.Default(), fset, fs); err != nil { + t.Fatal(err) + } + data := e.Generate("test_shims.go") + s := string(data) + for _, i := range test.expected { + if !strings.Contains(s, i) { + t.Errorf("expected %q in generated file", i) + } + } + for _, i := range test.excluded { + if strings.Contains(s, i) { + t.Errorf("found %q in generated file", i) + } + } + t.Log(s) + }) + } +} + +const pardo = ` +package pardo + +func MyIdent(v string) string { + return v +} + +func MyDropVal(k int,v string) int { + return k +} + +// A user defined type +type foo struct{} + +func MyOtherDoFn(v foo) (string,foo) { + return "constant" +} +` + +const emits = ` +package emits + +type reInt int + +func anotherFn(emit func(int,int)) { + emit(0, 0) +} + +func emitFn(k,v int, emit func(int,int)) error { + for i := 0; i < v; i++ { emit(k, i) } + return nil +} +` +const iters = ` +package iters + +func iterFn(k string, iters func(*int) bool) {} +` + +const structs = ` +package structs + +type myDoFn struct{} + +// valType should be picked up via processElement +type valType int + +func (f *myDoFn) ProcessElement(k, v valType, emit func(int)) {} + +func (f *myDoFn) Setup(emit func(int)) {} +func (f *myDoFn) StartBundle(emit func(int)) {} +func (f *myDoFn) FinishBundle(emit func(int)) error {} +func (f *myDoFn) Teardown(emit func(int)) {} + +type nonPipelineType int + +// UnrelatedMethods shouldn't have shims or tangents generated for them +func (f *myDoFn) UnrelatedMethod1(v string) {} +func (f *myDoFn) UnrelatedMethod2(notEmit 
func(string)) {} + +func (f *myDoFn) UnrelatedMethod3(notEmit func(nonPipelineType)) {} +` ---------------------------------------------------------------- This is an automated message from the Apache Git Service. To respond to the message, please log on GitHub and use the URL above to go to the specific comment. For queries about this service, please contact Infrastructure at: us...@infra.apache.org Issue Time Tracking ------------------- Worklog Id: (was: 165044) Time Spent: 6h 20m (was: 6h 10m) > Make it easy to generate type-specialized Go SDK reflectx.Funcs > --------------------------------------------------------------- > > Key: BEAM-3612 > URL: https://issues.apache.org/jira/browse/BEAM-3612 > Project: Beam > Issue Type: Improvement > Components: sdk-go > Reporter: Henning Rohde > Assignee: Robert Burke > Priority: Major > Time Spent: 6h 20m > Remaining Estimate: 0h > -- This message was sent by Atlassian JIRA (v7.6.3#76005)