Merge tag '0.22.0' into go1 Release 0.22.0
Project: http://git-wip-us.apache.org/repos/asf/qpid-proton/repo Commit: http://git-wip-us.apache.org/repos/asf/qpid-proton/commit/6f799990 Tree: http://git-wip-us.apache.org/repos/asf/qpid-proton/tree/6f799990 Diff: http://git-wip-us.apache.org/repos/asf/qpid-proton/diff/6f799990 Branch: refs/heads/go1 Commit: 6f799990cdf739b3caacf66eb2a9a29b14c9abeb Parents: 6e5b4d5 e3797ce Author: Alan Conway <acon...@redhat.com> Authored: Tue Apr 10 17:15:21 2018 -0400 Committer: Alan Conway <acon...@redhat.com> Committed: Tue Apr 10 17:15:21 2018 -0400 ---------------------------------------------------------------------- amqp/marshal.go | 3 +- amqp/unmarshal.go | 4 +- electron/auth_test.go | 92 ++++++++++++++++++++++++--------------------- electron/connection.go | 65 +++++++++++++++++++++----------- 4 files changed, 94 insertions(+), 70 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/qpid-proton/blob/6f799990/amqp/marshal.go ---------------------------------------------------------------------- diff --cc amqp/marshal.go index 33b30a8,0000000..99584a2 mode 100644,000000..100644 --- a/amqp/marshal.go +++ b/amqp/marshal.go @@@ -1,360 -1,0 +1,359 @@@ +/* +Licensed to the Apache Software Foundation (ASF) under one +or more contributor license agreements. See the NOTICE file +distributed with this work for additional information +regarding copyright ownership. The ASF licenses this file +to you under the Apache License, Version 2.0 (the +"License"); you may not use this file except in compliance +with the License. You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, +software distributed under the License is distributed on an +"AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +KIND, either express or implied. See the License for the +specific language governing permissions and limitations +under the License. 
+*/ + +package amqp + +// #include <proton/codec.h> +import "C" + +import ( + "fmt" + "io" + "reflect" + "time" + "unsafe" +) + +// Error returned if Go data cannot be marshaled as an AMQP type. +type MarshalError struct { + // The Go type. + GoType reflect.Type + s string +} + +func (e MarshalError) Error() string { return e.s } + +func newMarshalError(v interface{}, s string) *MarshalError { + t := reflect.TypeOf(v) + return &MarshalError{GoType: t, s: fmt.Sprintf("cannot marshal %s: %s", t, s)} +} + +func dataMarshalError(v interface{}, data *C.pn_data_t) error { + if pe := PnError(C.pn_data_error(data)); pe != nil { + return newMarshalError(v, pe.Error()) + } + return nil +} + +/* +Marshal encodes a Go value as AMQP data in buffer. +If buffer is nil, or is not large enough, a new buffer is created. + +Returns the buffer used for encoding with len() adjusted to the actual size of data. + +Go types are encoded as follows + + +-------------------------------------+--------------------------------------------+ + |Go type |AMQP type | + +-------------------------------------+--------------------------------------------+ + |bool |bool | + +-------------------------------------+--------------------------------------------+ + |int8, int16, int32, int64 (int) |byte, short, int, long (int or long) | + +-------------------------------------+--------------------------------------------+ + |uint8, uint16, uint32, uint64 (uint) |ubyte, ushort, uint, ulong (uint or ulong) | + +-------------------------------------+--------------------------------------------+ + |float32, float64 |float, double. 
| + +-------------------------------------+--------------------------------------------+ + |string |string | + +-------------------------------------+--------------------------------------------+ + |[]byte, Binary |binary | + +-------------------------------------+--------------------------------------------+ + |Symbol |symbol | + +-------------------------------------+--------------------------------------------+ + |Char |char | + +-------------------------------------+--------------------------------------------+ + |interface{} |the contained type | + +-------------------------------------+--------------------------------------------+ + |nil |null | + +-------------------------------------+--------------------------------------------+ + |map[K]T |map with K and T converted as above | + +-------------------------------------+--------------------------------------------+ + |Map |map, may have mixed types for keys, values | + +-------------------------------------+--------------------------------------------+ + |AnyMap |map (See AnyMap) | + +-------------------------------------+--------------------------------------------+ + |List, []interface{} |list, may have mixed-type values | + +-------------------------------------+--------------------------------------------+ + |[]T, [N]T |array, T is mapped as per this table | + +-------------------------------------+--------------------------------------------+ + |Described |described type | + +-------------------------------------+--------------------------------------------+ + |time.Time |timestamp | + +-------------------------------------+--------------------------------------------+ + |UUID |uuid | + +-------------------------------------+--------------------------------------------+ + +The following Go types cannot be marshaled: uintptr, function, channel, struct, complex64/128 + - AMQP types not yet supported: - - decimal32/64/128, ++AMQP types not yet supported: decimal32/64/128 +*/ + +func Marshal(v interface{}, 
buffer []byte) (outbuf []byte, err error) { + data := C.pn_data(0) + defer C.pn_data_free(data) + if err = recoverMarshal(v, data); err != nil { + return buffer, err + } + encode := func(buf []byte) ([]byte, error) { + n := int(C.pn_data_encode(data, cPtr(buf), cLen(buf))) + switch { + case n == int(C.PN_OVERFLOW): + return buf, overflow + case n < 0: + return buf, dataMarshalError(v, data) + default: + return buf[:n], nil + } + } + return encodeGrow(buffer, encode) +} + +// Internal use only +func MarshalUnsafe(v interface{}, pnData unsafe.Pointer) (err error) { + return recoverMarshal(v, (*C.pn_data_t)(pnData)) +} + +func recoverMarshal(v interface{}, data *C.pn_data_t) (err error) { + defer func() { // Convert panic to error return + if r := recover(); r != nil { + if err2, ok := r.(*MarshalError); ok { + err = err2 // Convert internal panic to error + } else { + panic(r) // Unrecognized error, continue to panic + } + } + }() + marshal(v, data) // Panics on error + return +} + +const minEncode = 256 + +// overflow is returned when an encoding function can't fit data in the buffer. +var overflow = fmt.Errorf("buffer too small") + +// encodeFn encodes into buffer[0:len(buffer)]. +// Returns buffer with length adjusted for data encoded. +// If buffer too small, returns overflow as error. +type encodeFn func(buffer []byte) ([]byte, error) + +// encodeGrow calls encode() into buffer, if it returns overflow grows the buffer. +// Returns the final buffer. 
+func encodeGrow(buffer []byte, encode encodeFn) ([]byte, error) { + if buffer == nil || len(buffer) == 0 { + buffer = make([]byte, minEncode) + } + var err error + for buffer, err = encode(buffer); err == overflow; buffer, err = encode(buffer) { + buffer = make([]byte, 2*len(buffer)) + } + return buffer, err +} + +// Marshal v to data +func marshal(i interface{}, data *C.pn_data_t) { + switch v := i.(type) { + case nil: + C.pn_data_put_null(data) + case bool: + C.pn_data_put_bool(data, C.bool(v)) + + // Signed integers + case int8: + C.pn_data_put_byte(data, C.int8_t(v)) + case int16: + C.pn_data_put_short(data, C.int16_t(v)) + case int32: + C.pn_data_put_int(data, C.int32_t(v)) + case int64: + C.pn_data_put_long(data, C.int64_t(v)) + case int: + if intIs64 { + C.pn_data_put_long(data, C.int64_t(v)) + } else { + C.pn_data_put_int(data, C.int32_t(v)) + } + + // Unsigned integers + case uint8: + C.pn_data_put_ubyte(data, C.uint8_t(v)) + case uint16: + C.pn_data_put_ushort(data, C.uint16_t(v)) + case uint32: + C.pn_data_put_uint(data, C.uint32_t(v)) + case uint64: + C.pn_data_put_ulong(data, C.uint64_t(v)) + case uint: + if intIs64 { + C.pn_data_put_ulong(data, C.uint64_t(v)) + } else { + C.pn_data_put_uint(data, C.uint32_t(v)) + } + + // Floating point + case float32: + C.pn_data_put_float(data, C.float(v)) + case float64: + C.pn_data_put_double(data, C.double(v)) + + // String-like (string, binary, symbol) + case string: + C.pn_data_put_string(data, pnBytes([]byte(v))) + case []byte: + C.pn_data_put_binary(data, pnBytes(v)) + case Binary: + C.pn_data_put_binary(data, pnBytes([]byte(v))) + case Symbol: + C.pn_data_put_symbol(data, pnBytes([]byte(v))) + + // Other simple types + case time.Time: + C.pn_data_put_timestamp(data, C.pn_timestamp_t(v.UnixNano()/1000)) + case UUID: + C.pn_data_put_uuid(data, *(*C.pn_uuid_t)(unsafe.Pointer(&v[0]))) + case Char: + C.pn_data_put_char(data, (C.pn_char_t)(v)) + + // Described types + case Described: + 
C.pn_data_put_described(data) + C.pn_data_enter(data) + marshal(v.Descriptor, data) + marshal(v.Value, data) + C.pn_data_exit(data) + + // Restricted type annotation-key, marshals as contained value + case AnnotationKey: + marshal(v.Get(), data) + + // Special type to represent AMQP maps with keys that are illegal in Go + case AnyMap: + C.pn_data_put_map(data) + C.pn_data_enter(data) + defer C.pn_data_exit(data) + for _, kv := range v { + marshal(kv.Key, data) + marshal(kv.Value, data) + } + + default: + // Examine complex types (Go map, slice, array) by reflected structure + switch reflect.TypeOf(i).Kind() { + + case reflect.Map: + m := reflect.ValueOf(v) + C.pn_data_put_map(data) + if C.pn_data_enter(data) { + defer C.pn_data_exit(data) + } else { + panic(dataMarshalError(i, data)) + } + for _, key := range m.MapKeys() { + marshal(key.Interface(), data) + marshal(m.MapIndex(key).Interface(), data) + } + + case reflect.Slice, reflect.Array: + // Note: Go array and slice are mapped the same way: + // if the element type has an entry in arrayTypeMap, map to AMQP array (single type), + // otherwise (e.g. interface{} elements) map to AMQP list (mixed type) + s := reflect.ValueOf(v) + if pnType, ok := arrayTypeMap[s.Type().Elem()]; ok { + C.pn_data_put_array(data, false, pnType) + } else { + C.pn_data_put_list(data) + } + C.pn_data_enter(data) + defer C.pn_data_exit(data) + for j := 0; j < s.Len(); j++ { + marshal(s.Index(j).Interface(), data) + } + + default: + panic(newMarshalError(v, "no conversion")) + } + } + if err := dataMarshalError(i, data); err != nil { + panic(err) + } +} + +// Mapping from Go element type to AMQP array type for types that can go in an AMQP array +// NOTE: this must be kept consistent with marshal() which does the actual marshalling. 
+var arrayTypeMap = map[reflect.Type]C.pn_type_t{ + nil: C.PN_NULL, + reflect.TypeOf(true): C.PN_BOOL, + + // Each entry must name the AMQP type written by the matching pn_data_put_* call in marshal(). + reflect.TypeOf(int8(0)): C.PN_BYTE, + reflect.TypeOf(int16(0)): C.PN_SHORT, + reflect.TypeOf(int32(0)): C.PN_INT, + reflect.TypeOf(int64(0)): C.PN_LONG, + + reflect.TypeOf(uint8(0)): C.PN_UBYTE, + reflect.TypeOf(uint16(0)): C.PN_USHORT, + reflect.TypeOf(uint32(0)): C.PN_UINT, + reflect.TypeOf(uint64(0)): C.PN_ULONG, + + reflect.TypeOf(float32(0)): C.PN_FLOAT, + reflect.TypeOf(float64(0)): C.PN_DOUBLE, + + reflect.TypeOf(""): C.PN_STRING, + reflect.TypeOf((*Symbol)(nil)).Elem(): C.PN_SYMBOL, + reflect.TypeOf((*Binary)(nil)).Elem(): C.PN_BINARY, + reflect.TypeOf([]byte{}): C.PN_BINARY, + + reflect.TypeOf((*time.Time)(nil)).Elem(): C.PN_TIMESTAMP, + reflect.TypeOf((*UUID)(nil)).Elem(): C.PN_UUID, + reflect.TypeOf((*Char)(nil)).Elem(): C.PN_CHAR, +} + +// Compute mapping of int/uint at runtime as they depend on execution environment. +func init() { + if intIs64 { + arrayTypeMap[reflect.TypeOf(int(0))] = C.PN_LONG + arrayTypeMap[reflect.TypeOf(uint(0))] = C.PN_ULONG + } else { + arrayTypeMap[reflect.TypeOf(int(0))] = C.PN_INT + arrayTypeMap[reflect.TypeOf(uint(0))] = C.PN_UINT + } +} + +func clearMarshal(v interface{}, data *C.pn_data_t) { + C.pn_data_clear(data) + marshal(v, data) +} + +// Encoder encodes AMQP values to an io.Writer +type Encoder struct { + writer io.Writer + buffer []byte +} + +// NewEncoder returns a new Encoder that writes to w. 
+func NewEncoder(w io.Writer) *Encoder { + return &Encoder{w, make([]byte, minEncode)} +} + +// Encode marshals v as AMQP (see Marshal) and writes the encoded bytes to the Encoder's writer. +// The Encoder's buffer is passed to Marshal and replaced by the buffer it returns. +func (e *Encoder) Encode(v interface{}) (err error) { + e.buffer, err = Marshal(v, e.buffer) + if err == nil { + _, err = e.writer.Write(e.buffer) + } + return err +} http://git-wip-us.apache.org/repos/asf/qpid-proton/blob/6f799990/amqp/unmarshal.go ---------------------------------------------------------------------- diff --cc amqp/unmarshal.go index 97e8437,0000000..2c6e3f1 mode 100644,000000..100644 --- a/amqp/unmarshal.go +++ b/amqp/unmarshal.go @@@ -1,733 -1,0 +1,731 @@@ +/* +Licensed to the Apache Software Foundation (ASF) under one +or more contributor license agreements. See the NOTICE file +distributed with this work for additional information +regarding copyright ownership. The ASF licenses this file +to you under the Apache License, Version 2.0 (the +"License"); you may not use this file except in compliance +with the License. You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, +software distributed under the License is distributed on an +"AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +KIND, either express or implied. See the License for the +specific language governing permissions and limitations +under the License. +*/ + +package amqp + +// #include <proton/codec.h> +import "C" + +import ( + "bytes" + "fmt" + "io" + "reflect" + "time" + "unsafe" +) + +const minDecode = 1024 + +// Error returned if AMQP data cannot be unmarshaled as the desired Go type. +type UnmarshalError struct { + // The name of the AMQP type. + AMQPType string + // The Go type. + GoType reflect.Type + + s string +} + +func (e UnmarshalError) Error() string { return e.s } + +// Error returned if there are not enough bytes to decode a complete AMQP value. 
+var EndOfData = &UnmarshalError{s: "Not enough data for AMQP value"} + +var badData = &UnmarshalError{s: "Unexpected error in data"} + +func newUnmarshalError(pnType C.pn_type_t, v interface{}) *UnmarshalError { + e := &UnmarshalError{ + AMQPType: C.pn_type_t(pnType).String(), + GoType: reflect.TypeOf(v), + } + if e.GoType == nil || e.GoType.Kind() != reflect.Ptr { + e.s = fmt.Sprintf("cannot unmarshal to Go type %v, not a pointer", e.GoType) + } else { + e.s = fmt.Sprintf("cannot unmarshal AMQP %v to Go %v", e.AMQPType, e.GoType.Elem()) + } + return e +} + +func doPanic(data *C.pn_data_t, v interface{}) { + e := newUnmarshalError(C.pn_data_type(data), v) + panic(e) +} + +func doPanicMsg(data *C.pn_data_t, v interface{}, msg string) { + e := newUnmarshalError(C.pn_data_type(data), v) + e.s = e.s + ": " + msg + panic(e) +} + +func panicIfBadData(data *C.pn_data_t, v interface{}) { + if C.pn_data_errno(data) != 0 { + doPanicMsg(data, v, PnError(C.pn_data_error(data)).Error()) + } +} + +func panicUnless(ok bool, data *C.pn_data_t, v interface{}) { + if !ok { + doPanic(data, v) + } +} + +// checkOp panics with badData when ok is false; used by the checked enter/exit/next helpers. +func checkOp(ok bool, v interface{}) { + if !ok { + // badData is already an *UnmarshalError, the only panic value recoverUnmarshal + // converts to an error; &badData (an **UnmarshalError) would be re-panicked. + panic(badData) + } +} + +// +// Decoding from a pn_data_t +// +// NOTE: we use panic() to signal a decoding error, simplifies decoding logic. +// We recover() at the highest possible level - i.e. in the exported Unmarshal or Decode. +// + +// Decoder decodes AMQP values from an io.Reader. +// +type Decoder struct { + reader io.Reader + buffer bytes.Buffer +} + +// NewDecoder returns a new decoder that reads from r. +// +// The decoder has its own buffer and may read more data than required for the +// AMQP values requested. Use Buffered to see if there is data left in the +// buffer. +// +func NewDecoder(r io.Reader) *Decoder { + return &Decoder{r, bytes.Buffer{}} +} + +// Buffered returns a reader of the data remaining in the Decoder's buffer. The +// reader is valid until the next call to Decode. 
+// +func (d *Decoder) Buffered() io.Reader { + return bytes.NewReader(d.buffer.Bytes()) +} + +// Decode reads the next AMQP value from the Reader and stores it in the value pointed to by v. +// +// See the documentation for Unmarshal for details about the conversion of AMQP into a Go value. +// +func (d *Decoder) Decode(v interface{}) (err error) { + data := C.pn_data(0) + defer C.pn_data_free(data) + var n int + for n, err = decode(data, d.buffer.Bytes()); err == EndOfData; { + err = d.more() + if err == nil { + n, err = decode(data, d.buffer.Bytes()) + } + } + if err == nil { + if err = recoverUnmarshal(v, data); err == nil { + d.buffer.Next(n) + } + } + return +} + +/* + +Unmarshal decodes AMQP-encoded bytes and stores the result in the Go value +pointed to by v. Legal conversions from the source AMQP type to the target Go +type as follows: + + +----------------------------+-------------------------------------------------+ + |Target Go type | Allowed AMQP types + +============================+==================================================+ + |bool |bool | + +----------------------------+--------------------------------------------------+ + |int, int8, int16, int32, |Equivalent or smaller signed integer type: | + |int64 |byte, short, int, long or char | + +----------------------------+--------------------------------------------------+ + |uint, uint8, uint16, uint32,|Equivalent or smaller unsigned integer type: | + |uint64 |ubyte, ushort, uint, ulong | + +----------------------------+--------------------------------------------------+ + |float32, float64 |Equivalent or smaller float or double | + +----------------------------+--------------------------------------------------+ + |string, []byte |string, symbol or binary | + +----------------------------+--------------------------------------------------+ + |Symbol |symbol | + +----------------------------+--------------------------------------------------+ + |Char |char | + 
+----------------------------+--------------------------------------------------+ + |Described |AMQP described type [1] | + +----------------------------+--------------------------------------------------+ + |time.Time |timestamp | + +----------------------------+--------------------------------------------------+ + |UUID |uuid | + +----------------------------+--------------------------------------------------+ + |map[interface{}]interface{} |Any AMQP map | + +----------------------------+--------------------------------------------------+ + |map[K]T |map, provided all keys and values can unmarshal | + | |to types K,T | + +----------------------------+--------------------------------------------------+ + |[]interface{} |AMQP list or array | + +----------------------------+--------------------------------------------------+ + |[]T |list or array if elements can unmarshal as T | + +----------------------------+--------------------------------------------------+ + |interface{} |any AMQP type [2] | + +----------------------------+--------------------------------------------------+ + +[1] An AMQP described value can also unmarshal to a plain value, discarding the +descriptor. Unmarshalling into the special amqp.Described type preserves the +descriptor. + +[2] Any AMQP value can be unmarshalled to an interface{}. 
The Go type is +determined by the AMQP type as follows: + + +----------------------------+--------------------------------------------------+ + |Source AMQP Type |Go Type in target interface{} | + +============================+==================================================+ + |bool |bool | + +----------------------------+--------------------------------------------------+ + |byte,short,int,long |int8,int16,int32,int64 | + +----------------------------+--------------------------------------------------+ + |ubyte,ushort,uint,ulong |uint8,uint16,uint32,uint64 | + +----------------------------+--------------------------------------------------+ + |float, double |float32, float64 | + +----------------------------+--------------------------------------------------+ + |string |string | + +----------------------------+--------------------------------------------------+ + |symbol |Symbol | + +----------------------------+--------------------------------------------------+ + |char |Char | + +----------------------------+--------------------------------------------------+ + |binary |Binary | + +----------------------------+--------------------------------------------------+ + |null |nil | + +----------------------------+--------------------------------------------------+ + |described type |Described | + +----------------------------+--------------------------------------------------+ + |timestamp |time.Time | + +----------------------------+--------------------------------------------------+ + |uuid |UUID | + +----------------------------+--------------------------------------------------+ + |map |Map or AnyMap[4] | + +----------------------------+--------------------------------------------------+ + |list |List | + +----------------------------+--------------------------------------------------+ + |array |[]T for simple types, T is chosen as above [3] | + +----------------------------+--------------------------------------------------+ + +[3] An AMQP array of simple 
types unmarshalls as a slice of the corresponding Go type. +An AMQP array containing complex types (lists, maps or nested arrays) unmarshals +to the generic array type amqp.Array + +[4] An AMQP map unmarshals as the generic `type Map map[interface{}]interface{}` +unless it contains key values that are illegal as Go map types, in which case +it unmarshals as type AnyMap. + +The following Go types cannot be unmarshaled: uintptr, function, interface, +channel, array (use slice), struct + - AMQP types not yet supported: - - decimal32/64/128 - - maps with key values that are not legal Go map keys. ++AMQP types not yet supported: decimal32/64/128 +*/ +func Unmarshal(bytes []byte, v interface{}) (n int, err error) { + data := C.pn_data(0) + defer C.pn_data_free(data) + n, err = decode(data, bytes) + if err == nil { + err = recoverUnmarshal(v, data) + } + return +} + +// Internal +func UnmarshalUnsafe(pnData unsafe.Pointer, v interface{}) (err error) { + return recoverUnmarshal(v, (*C.pn_data_t)(pnData)) +} + +// more reads more data when we can't parse a complete AMQP type +func (d *Decoder) more() error { + var readSize int64 = minDecode + if int64(d.buffer.Len()) > readSize { // Grow by doubling + readSize = int64(d.buffer.Len()) + } + var n int64 + n, err := d.buffer.ReadFrom(io.LimitReader(d.reader, readSize)) + if n == 0 && err == nil { // ReadFrom won't report io.EOF, just returns 0 + err = io.EOF + } + return err +} + +// Call unmarshal(), convert panic to error value +func recoverUnmarshal(v interface{}, data *C.pn_data_t) (err error) { + defer func() { + if r := recover(); r != nil { + if uerr, ok := r.(*UnmarshalError); ok { + err = uerr + } else { + panic(r) + } + } + }() + unmarshal(v, data) + return nil +} + +// Unmarshal from data into value pointed at by v. Returns v. 
+// NOTE: If you update this you also need to update getInterface() +func unmarshal(v interface{}, data *C.pn_data_t) { + rt := reflect.TypeOf(v) + rv := reflect.ValueOf(v) + panicUnless(v != nil && rt.Kind() == reflect.Ptr && !rv.IsNil(), data, v) + + // Check for PN_DESCRIBED first, as described types can unmarshal into any of the Go types. + // An interface{} target is handled in the switch below, even for described types. + if _, isInterface := v.(*interface{}); !isInterface && bool(C.pn_data_is_described(data)) { + getDescribed(data, v) + return + } + + // Unmarshal based on the target type + pnType := C.pn_data_type(data) + switch v := v.(type) { + + case *bool: + panicUnless(pnType == C.PN_BOOL, data, v) + *v = bool(C.pn_data_get_bool(data)) + + case *int8: + panicUnless(pnType == C.PN_BYTE, data, v) + *v = int8(C.pn_data_get_byte(data)) + + case *uint8: + panicUnless(pnType == C.PN_UBYTE, data, v) + *v = uint8(C.pn_data_get_ubyte(data)) + + case *int16: + switch C.pn_data_type(data) { + case C.PN_BYTE: + *v = int16(C.pn_data_get_byte(data)) + case C.PN_SHORT: + *v = int16(C.pn_data_get_short(data)) + default: + doPanic(data, v) + } + + case *uint16: + switch pnType { + case C.PN_UBYTE: + *v = uint16(C.pn_data_get_ubyte(data)) + case C.PN_USHORT: + *v = uint16(C.pn_data_get_ushort(data)) + default: + doPanic(data, v) + } + + case *int32: + switch pnType { + case C.PN_CHAR: + *v = int32(C.pn_data_get_char(data)) + case C.PN_BYTE: + *v = int32(C.pn_data_get_byte(data)) + case C.PN_SHORT: + *v = int32(C.pn_data_get_short(data)) + case C.PN_INT: + *v = int32(C.pn_data_get_int(data)) + default: + doPanic(data, v) + } + + case *uint32: + switch pnType { + case C.PN_CHAR: + *v = uint32(C.pn_data_get_char(data)) + case C.PN_UBYTE: + *v = uint32(C.pn_data_get_ubyte(data)) + case C.PN_USHORT: + *v = uint32(C.pn_data_get_ushort(data)) + case C.PN_UINT: + *v = uint32(C.pn_data_get_uint(data)) + default: + doPanic(data, v) + } + + case *int64: + switch pnType { + case 
C.PN_CHAR: + *v = int64(C.pn_data_get_char(data)) + case C.PN_BYTE: + *v = int64(C.pn_data_get_byte(data)) + case C.PN_SHORT: + *v = int64(C.pn_data_get_short(data)) + case C.PN_INT: + *v = int64(C.pn_data_get_int(data)) + case C.PN_LONG: + *v = int64(C.pn_data_get_long(data)) + default: + doPanic(data, v) + } + + case *uint64: + // Accept any equivalent or smaller unsigned type, as documented for Unmarshal. + switch pnType { + case C.PN_CHAR: + *v = uint64(C.pn_data_get_char(data)) + case C.PN_UBYTE: + *v = uint64(C.pn_data_get_ubyte(data)) + case C.PN_USHORT: + *v = uint64(C.pn_data_get_ushort(data)) + case C.PN_UINT: + *v = uint64(C.pn_data_get_uint(data)) + case C.PN_ULONG: + *v = uint64(C.pn_data_get_ulong(data)) + default: + doPanic(data, v) + } + + case *int: + switch pnType { + case C.PN_CHAR: + *v = int(C.pn_data_get_char(data)) + case C.PN_BYTE: + *v = int(C.pn_data_get_byte(data)) + case C.PN_SHORT: + *v = int(C.pn_data_get_short(data)) + case C.PN_INT: + *v = int(C.pn_data_get_int(data)) + case C.PN_LONG: + if intIs64 { + *v = int(C.pn_data_get_long(data)) + } else { + doPanic(data, v) + } + default: + doPanic(data, v) + } + + case *uint: + switch pnType { + case C.PN_CHAR: + *v = uint(C.pn_data_get_char(data)) + case C.PN_UBYTE: + *v = uint(C.pn_data_get_ubyte(data)) + case C.PN_USHORT: + *v = uint(C.pn_data_get_ushort(data)) + case C.PN_UINT: + *v = uint(C.pn_data_get_uint(data)) + case C.PN_ULONG: + if intIs64 { + *v = uint(C.pn_data_get_ulong(data)) + } else { + doPanic(data, v) + } + default: + doPanic(data, v) + } + + case *float32: + panicUnless(pnType == C.PN_FLOAT, data, v) + *v = float32(C.pn_data_get_float(data)) + + case *float64: + switch pnType { + case C.PN_FLOAT: + *v = float64(C.pn_data_get_float(data)) + case C.PN_DOUBLE: + *v = float64(C.pn_data_get_double(data)) + default: + doPanic(data, v) + } + + case *string: + switch pnType { + case C.PN_STRING: + *v = goString(C.pn_data_get_string(data)) + case C.PN_SYMBOL: + *v = goString(C.pn_data_get_symbol(data)) + case C.PN_BINARY: + *v = goString(C.pn_data_get_binary(data)) + default: + doPanic(data, v) + } + + case *[]byte: 
+ switch pnType { + case C.PN_STRING: + *v = goBytes(C.pn_data_get_string(data)) + case C.PN_SYMBOL: + *v = goBytes(C.pn_data_get_symbol(data)) + case C.PN_BINARY: + *v = goBytes(C.pn_data_get_binary(data)) + default: + doPanic(data, v) + } + return + + case *Char: + panicUnless(pnType == C.PN_CHAR, data, v) + *v = Char(C.pn_data_get_char(data)) + + case *Binary: + panicUnless(pnType == C.PN_BINARY, data, v) + *v = Binary(goBytes(C.pn_data_get_binary(data))) + + case *Symbol: + panicUnless(pnType == C.PN_SYMBOL, data, v) + *v = Symbol(goBytes(C.pn_data_get_symbol(data))) + + case *time.Time: + panicUnless(pnType == C.PN_TIMESTAMP, data, v) + *v = time.Unix(0, int64(C.pn_data_get_timestamp(data))*1000) + + case *UUID: + panicUnless(pnType == C.PN_UUID, data, v) + pn := C.pn_data_get_uuid(data) + copy((*v)[:], C.GoBytes(unsafe.Pointer(&pn.bytes), 16)) + + case *AnnotationKey: + panicUnless(pnType == C.PN_ULONG || pnType == C.PN_SYMBOL || pnType == C.PN_STRING, data, v) + unmarshal(&v.value, data) + + case *AnyMap: + panicUnless(C.pn_data_type(data) == C.PN_MAP, data, v) + n := int(C.pn_data_get_map(data)) / 2 + if cap(*v) < n { + *v = make(AnyMap, n) + } + *v = (*v)[:n] + data.enter(*v) + defer data.exit(*v) + for i := 0; i < n; i++ { + data.next(*v) + unmarshal(&(*v)[i].Key, data) + data.next(*v) + unmarshal(&(*v)[i].Value, data) + } + + case *interface{}: + getInterface(data, v) + + default: // This is not one of the fixed well-known types, reflect for map and slice types + + switch rt.Elem().Kind() { + case reflect.Map: + getMap(data, v) + case reflect.Slice: + getSequence(data, v) + default: + doPanic(data, v) + } + } +} + +// Unmarshalling into an interface{} the type is determined by the AMQP source type, +// since the interface{} target can hold any Go type. 
+func getInterface(data *C.pn_data_t, vp *interface{}) { + pnType := C.pn_data_type(data) + switch pnType { + case C.PN_BOOL: + *vp = bool(C.pn_data_get_bool(data)) + case C.PN_UBYTE: + *vp = uint8(C.pn_data_get_ubyte(data)) + case C.PN_BYTE: + *vp = int8(C.pn_data_get_byte(data)) + case C.PN_USHORT: + *vp = uint16(C.pn_data_get_ushort(data)) + case C.PN_SHORT: + *vp = int16(C.pn_data_get_short(data)) + case C.PN_UINT: + *vp = uint32(C.pn_data_get_uint(data)) + case C.PN_INT: + *vp = int32(C.pn_data_get_int(data)) + case C.PN_CHAR: + *vp = Char(C.pn_data_get_char(data)) + case C.PN_ULONG: + *vp = uint64(C.pn_data_get_ulong(data)) + case C.PN_LONG: + *vp = int64(C.pn_data_get_long(data)) + case C.PN_FLOAT: + *vp = float32(C.pn_data_get_float(data)) + case C.PN_DOUBLE: + *vp = float64(C.pn_data_get_double(data)) + case C.PN_BINARY: + *vp = Binary(goBytes(C.pn_data_get_binary(data))) + case C.PN_STRING: + *vp = goString(C.pn_data_get_string(data)) + case C.PN_SYMBOL: + *vp = Symbol(goString(C.pn_data_get_symbol(data))) + case C.PN_TIMESTAMP: + *vp = time.Unix(0, int64(C.pn_data_get_timestamp(data))*1000) + case C.PN_UUID: + var u UUID + unmarshal(&u, data) + *vp = u + case C.PN_MAP: + // We will try to unmarshal as a Map first, if that fails try AnyMap + m := make(Map, int(C.pn_data_get_map(data))/2) + if err := recoverUnmarshal(&m, data); err == nil { + *vp = m + } else { + am := make(AnyMap, int(C.pn_data_get_map(data))/2) + unmarshal(&am, data) + *vp = am + } + case C.PN_LIST: + l := List{} + unmarshal(&l, data) + *vp = l + case C.PN_ARRAY: + sp := getArrayStore(data) // interface{} containing T* for suitable T + unmarshal(sp, data) + *vp = reflect.ValueOf(sp).Elem().Interface() + case C.PN_DESCRIBED: + d := Described{} + unmarshal(&d, data) + *vp = d + case C.PN_NULL: + *vp = nil + case C.PN_INVALID: + // Allow decoding from an empty data object to an interface, treat it like NULL. + // This happens when optional values or properties are omitted from a message. 
+ *vp = nil + default: // Don't know how to handle this + panic(newUnmarshalError(pnType, vp)) + } +} + +// Return an interface{} containing a pointer to an appropriate slice or Array +func getArrayStore(data *C.pn_data_t) interface{} { + // TODO aconway 2017-11-10: described arrays. + switch C.pn_data_get_array_type(data) { + case C.PN_BOOL: + return new([]bool) + case C.PN_UBYTE: + return new([]uint8) + case C.PN_BYTE: + return new([]int8) + case C.PN_USHORT: + return new([]uint16) + case C.PN_SHORT: + return new([]int16) + case C.PN_UINT: + return new([]uint32) + case C.PN_INT: + return new([]int32) + case C.PN_CHAR: + return new([]Char) + case C.PN_ULONG: + return new([]uint64) + case C.PN_LONG: + return new([]int64) + case C.PN_FLOAT: + return new([]float32) + case C.PN_DOUBLE: + return new([]float64) + case C.PN_BINARY: + return new([]Binary) + case C.PN_STRING: + return new([]string) + case C.PN_SYMBOL: + return new([]Symbol) + case C.PN_TIMESTAMP: + return new([]time.Time) + case C.PN_UUID: + return new([]UUID) + } + return new(Array) // Not a simple type, use generic Array +} + +var typeOfInterface = reflect.TypeOf(interface{}(nil)) + +// get into map pointed at by v +func getMap(data *C.pn_data_t, v interface{}) { + panicUnless(C.pn_data_type(data) == C.PN_MAP, data, v) + n := int(C.pn_data_get_map(data)) / 2 + mapValue := reflect.ValueOf(v).Elem() + mapValue.Set(reflect.MakeMap(mapValue.Type())) // Clear the map + data.enter(v) + defer data.exit(v) + // Allocate re-usable key/val values + keyType := mapValue.Type().Key() + keyPtr := reflect.New(keyType) + valPtr := reflect.New(mapValue.Type().Elem()) + for i := 0; i < n; i++ { + data.next(v) + unmarshal(keyPtr.Interface(), data) + if keyType.Kind() == reflect.Interface && !keyPtr.Elem().Elem().Type().Comparable() { + doPanicMsg(data, v, fmt.Sprintf("key %#v is not comparable", keyPtr.Elem().Interface())) + } + data.next(v) + unmarshal(valPtr.Interface(), data) + mapValue.SetMapIndex(keyPtr.Elem(), 
valPtr.Elem()) + } +} + +func getSequence(data *C.pn_data_t, vp interface{}) { + var count int + pnType := C.pn_data_type(data) + switch pnType { + case C.PN_LIST: + count = int(C.pn_data_get_list(data)) + case C.PN_ARRAY: + count = int(C.pn_data_get_array(data)) + default: + doPanic(data, vp) + } + listValue := reflect.MakeSlice(reflect.TypeOf(vp).Elem(), count, count) + data.enter(vp) + defer data.exit(vp) + for i := 0; i < count; i++ { + data.next(vp) + val := reflect.New(listValue.Type().Elem()) + unmarshal(val.Interface(), data) + listValue.Index(i).Set(val.Elem()) + } + reflect.ValueOf(vp).Elem().Set(listValue) +} + +func getDescribed(data *C.pn_data_t, vp interface{}) { + d, isDescribed := vp.(*Described) + data.enter(vp) + defer data.exit(vp) + data.next(vp) + if isDescribed { + unmarshal(&d.Descriptor, data) + data.next(vp) + unmarshal(&d.Value, data) + } else { + data.next(vp) // Skip descriptor + unmarshal(vp, data) // Unmarshal plain value + } +} + +// decode from bytes. +// Return bytes decoded or 0 if we could not decode a complete object. 
+// +func decode(data *C.pn_data_t, bytes []byte) (int, error) { + n := C.pn_data_decode(data, cPtr(bytes), cLen(bytes)) + if n == C.PN_UNDERFLOW { + C.pn_error_clear(C.pn_data_error(data)) + return 0, EndOfData + } else if n <= 0 { + return 0, &UnmarshalError{s: fmt.Sprintf("unmarshal %v", PnErrorCode(n))} + } + return int(n), nil +} + +// Checked versions of pn_data functions + +func (data *C.pn_data_t) enter(v interface{}) { checkOp(bool(C.pn_data_enter(data)), v) } +func (data *C.pn_data_t) exit(v interface{}) { checkOp(bool(C.pn_data_exit(data)), v) } +func (data *C.pn_data_t) next(v interface{}) { checkOp(bool(C.pn_data_next(data)), v) } http://git-wip-us.apache.org/repos/asf/qpid-proton/blob/6f799990/electron/auth_test.go ---------------------------------------------------------------------- diff --cc electron/auth_test.go index 9fa9fa2,0000000..162b366 mode 100644,000000..100644 --- a/electron/auth_test.go +++ b/electron/auth_test.go @@@ -1,137 -1,0 +1,143 @@@ +/* +Licensed to the Apache Software Foundation (ASF) under one +or more contributor license agreements. See the NOTICE file +distributed with this work for additional information +regarding copyright ownership. The ASF licenses this file +to you under the Apache License, Version 2.0 (the +"License"); you may not use this file except in compliance +with the License. You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, +software distributed under the License is distributed on an +"AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +KIND, either express or implied. See the License for the +specific language governing permissions and limitations +under the License. 
+*/ + +package electron + +import ( + "fmt" + "io/ioutil" + "os" + "os/exec" + "path/filepath" + "strings" + "testing" +) + +func testAuthClientServer(t *testing.T, copts []ConnectionOption, sopts []ConnectionOption) (got connectionSettings, err error) { + client, server := newClientServerOpts(t, copts, sopts) + defer closeClientServer(client, server) + + go func() { + for in := range server.Incoming() { + switch in := in.(type) { + case *IncomingConnection: + got = connectionSettings{user: in.User(), virtualHost: in.VirtualHost()} + } + in.Accept() + } + }() + + err = client.Sync() + return +} + +func TestAuthAnonymous(t *testing.T) { - configureSASL() + got, err := testAuthClientServer(t, + []ConnectionOption{User("fred"), VirtualHost("vhost"), SASLAllowInsecure(true)}, + []ConnectionOption{SASLAllowedMechs("ANONYMOUS"), SASLAllowInsecure(true)}) + fatalIf(t, err) + errorIf(t, checkEqual(connectionSettings{user: "anonymous", virtualHost: "vhost"}, got)) +} + +func TestAuthPlain(t *testing.T) { - if !SASLExtended() { - t.Skip() - } - fatalIf(t, configureSASL()) ++ extendedSASL.startTest(t) + got, err := testAuthClientServer(t, + []ConnectionOption{SASLAllowInsecure(true), SASLAllowedMechs("PLAIN"), User("fred@proton"), Password([]byte("xxx"))}, + []ConnectionOption{SASLAllowInsecure(true), SASLAllowedMechs("PLAIN")}) + fatalIf(t, err) + errorIf(t, checkEqual(connectionSettings{user: "fred@proton"}, got)) +} + +func TestAuthBadPass(t *testing.T) { - if !SASLExtended() { - t.Skip() - } - fatalIf(t, configureSASL()) ++ extendedSASL.startTest(t) + _, err := testAuthClientServer(t, + []ConnectionOption{SASLAllowInsecure(true), SASLAllowedMechs("PLAIN"), User("fred@proton"), Password([]byte("yyy"))}, + []ConnectionOption{SASLAllowInsecure(true), SASLAllowedMechs("PLAIN")}) + if err == nil { + t.Error("Expected auth failure for bad pass") + } +} + +func TestAuthBadUser(t *testing.T) { - if !SASLExtended() { - t.Skip() - } - fatalIf(t, configureSASL()) ++ 
extendedSASL.startTest(t) + _, err := testAuthClientServer(t, + []ConnectionOption{SASLAllowInsecure(true), SASLAllowedMechs("PLAIN"), User("foo@bar"), Password([]byte("yyy"))}, + []ConnectionOption{SASLAllowInsecure(true), SASLAllowedMechs("PLAIN")}) + if err == nil { + t.Error("Expected auth failure for bad user") + } +} + - var confDir string - var confErr error ++type extendedSASLState struct { ++ err error ++ dir string ++} + - func configureSASL() error { - if confDir != "" || confErr != nil { - return confErr - } - confDir, confErr = ioutil.TempDir("", "") - if confErr != nil { - return confErr ++func (s *extendedSASLState) setup() { ++ if SASLExtended() { ++ if s.dir, s.err = ioutil.TempDir("", ""); s.err == nil { ++ GlobalSASLConfigDir(s.dir) ++ GlobalSASLConfigName("test") ++ conf := filepath.Join(s.dir, "test.conf") ++ db := filepath.Join(s.dir, "proton.sasldb") ++ saslpasswd := os.Getenv("SASLPASSWD") ++ if saslpasswd == "" { ++ saslpasswd = "saslpasswd2" ++ } ++ cmd := exec.Command(saslpasswd, "-c", "-p", "-f", db, "-u", "proton", "fred") ++ cmd.Stdin = strings.NewReader("xxx") // Password ++ if _, s.err = cmd.CombinedOutput(); s.err == nil { ++ confStr := fmt.Sprintf(` ++sasldb_path: %s ++mech_list: EXTERNAL DIGEST-MD5 SCRAM-SHA-1 CRAM-MD5 PLAIN ANONYMOUS ++`, db) ++ s.err = ioutil.WriteFile(conf, []byte(confStr), os.ModePerm) ++ } ++ } + } ++ // Note we don't do anything with s.err now, tests that need the ++ // extended SASL config will fail if s.err != nil. If no such tests ++ // are run then it is not an error that we couldn't set it up. 
++} + - GlobalSASLConfigDir(confDir) - GlobalSASLConfigName("test") - conf := filepath.Join(confDir, "test.conf") - - db := filepath.Join(confDir, "proton.sasldb") - saslpasswd := os.Getenv("SASLPASSWD"); - if saslpasswd == "" { - saslpasswd = "saslpasswd2" - } - cmd := exec.Command(saslpasswd, "-c", "-p", "-f", db, "-u", "proton", "fred") - cmd.Stdin = strings.NewReader("xxx") // Password - if out, err := cmd.CombinedOutput(); err != nil { - confErr = fmt.Errorf("saslpasswd2 failed: %s\n%s", err, out) - return confErr ++func (s extendedSASLState) teardown() { ++ if s.dir != "" { ++ _ = os.RemoveAll(s.dir) + } - confStr := "sasldb_path: " + db + "\nmech_list: EXTERNAL DIGEST-MD5 SCRAM-SHA-1 CRAM-MD5 PLAIN ANONYMOUS\n" - if err := ioutil.WriteFile(conf, []byte(confStr), os.ModePerm); err != nil { - confErr = fmt.Errorf("write conf file %s failed: %s", conf, err) ++} ++ ++func (s extendedSASLState) startTest(t *testing.T) { ++ if !SASLExtended() { ++ t.Skipf("Extended SASL not enabled") ++ } else if extendedSASL.err != nil { ++ t.Skipf("Extended SASL setup error: %v", extendedSASL.err) + } - return confErr +} + ++var extendedSASL extendedSASLState ++ +func TestMain(m *testing.M) { ++ // Do global SASL setup/teardown in main. ++ // Doing it on-demand makes the tests fragile to parallel test runs and ++ // changes in test ordering. ++ extendedSASL.setup() + status := m.Run() - if confDir != "" { - _ = os.RemoveAll(confDir) - } ++ extendedSASL.teardown() + os.Exit(status) +} http://git-wip-us.apache.org/repos/asf/qpid-proton/blob/6f799990/electron/connection.go ---------------------------------------------------------------------- diff --cc electron/connection.go index 731e64d,0000000..9c0ef31 mode 100644,000000..100644 --- a/electron/connection.go +++ b/electron/connection.go @@@ -1,421 -1,0 +1,442 @@@ +/* +Licensed to the Apache Software Foundation (ASF) under one +or more contributor license agreements. 
See the NOTICE file +distributed with this work for additional information +regarding copyright ownership. The ASF licenses this file +to you under the Apache License, Version 2.0 (the +"License"); you may not use this file except in compliance +with the License. You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, +software distributed under the License is distributed on an +"AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +KIND, either express or implied. See the License for the +specific language governing permissions and limitations +under the License. +*/ + +package electron + +// #include <proton/disposition.h> +import "C" + +import ( + "net" + "qpid.apache.org/proton" + "sync" + "time" +) + +// Settings associated with a Connection. +type ConnectionSettings interface { + // Authenticated user name associated with the connection. + User() string + + // The AMQP virtual host name for the connection. + // + // Optional, useful when the server has multiple names and provides different + // service based on the name the client uses to connect. + // + // By default it is set to the DNS host name that the client uses to connect, + // but it can be set to something different at the client side with the + // VirtualHost() option. + // + // Returns error if the connection fails to authenticate. + VirtualHost() string + + // Heartbeat is the maximum delay between sending frames that the remote peer + // has requested of us. If the interval expires an empty "heartbeat" frame + // will be sent automatically to keep the connection open. + Heartbeat() time.Duration +} + +// Connection is an AMQP connection, created by a Container. +type Connection interface { + Endpoint + ConnectionSettings + + // Sender opens a new sender on the DefaultSession. + Sender(...LinkOption) (Sender, error) + + // Receiver opens a new Receiver on the DefaultSession(). 
+ Receiver(...LinkOption) (Receiver, error) + + // DefaultSession() returns a default session for the connection. It is opened + // on the first call to DefaultSession and returned on subsequent calls. + DefaultSession() (Session, error) + + // Session opens a new session. + Session(...SessionOption) (Session, error) + + // Container for the connection. + Container() Container + + // Disconnect the connection abruptly with an error. + Disconnect(error) + + // Wait waits for the connection to be disconnected. + Wait() error + + // WaitTimeout is like Wait but returns Timeout if the timeout expires. + WaitTimeout(time.Duration) error + + // Incoming returns a channel for incoming endpoints opened by the remote peer. + // See the Incoming interface for more detail. + // + // Note: this channel will first return an *IncomingConnection for the + // connection itself which allows you to look at security information and + // decide whether to Accept() or Reject() the connection. Then it will return + // *IncomingSession, *IncomingSender and *IncomingReceiver as they are opened + // by the remote end. + // + // Note 2: you must receiving from Incoming() and call Accept/Reject to avoid + // blocking electron event loop. Normally you would run a loop in a goroutine + // to handle incoming types that interest and Accept() those that don't. 
+ Incoming() <-chan Incoming +} + +type connectionSettings struct { + user, virtualHost string + heartbeat time.Duration +} + +func (c connectionSettings) User() string { return c.user } +func (c connectionSettings) VirtualHost() string { return c.virtualHost } +func (c connectionSettings) Heartbeat() time.Duration { return c.heartbeat } + +// ConnectionOption can be passed when creating a connection to configure various options +type ConnectionOption func(*connection) + +// User returns a ConnectionOption sets the user name for a connection +func User(user string) ConnectionOption { + return func(c *connection) { + c.user = user + c.pConnection.SetUser(user) + } +} + +// VirtualHost returns a ConnectionOption to set the AMQP virtual host for the connection. +// Only applies to outbound client connection. +func VirtualHost(virtualHost string) ConnectionOption { + return func(c *connection) { + c.virtualHost = virtualHost + c.pConnection.SetHostname(virtualHost) + } +} + +// Password returns a ConnectionOption to set the password used to establish a +// connection. Only applies to outbound client connection. +// +// The connection will erase its copy of the password from memory as soon as it +// has been used to authenticate. If you are concerned about passwords staying in +// memory you should never store them as strings, and should overwrite your +// copy as soon as you are done with it. +// +func Password(password []byte) ConnectionOption { + return func(c *connection) { c.pConnection.SetPassword(password) } +} + +// Server returns a ConnectionOption to put the connection in server mode for incoming connections. +// +// A server connection will do protocol negotiation to accept a incoming AMQP +// connection. 
Normally you would call this for a connection created by +// net.Listener.Accept() +// +func Server() ConnectionOption { + return func(c *connection) { c.engine.Server(); c.server = true; AllowIncoming()(c) } +} + +// AllowIncoming returns a ConnectionOption to enable incoming endpoints, see +// Connection.Incoming() This is automatically set for Server() connections. +func AllowIncoming() ConnectionOption { + return func(c *connection) { c.incoming = make(chan Incoming) } +} + +// Parent returns a ConnectionOption that associates the Connection with it's Container +// If not set a connection will create its own default container. +func Parent(cont Container) ConnectionOption { + return func(c *connection) { c.container = cont.(*container) } +} + +type connection struct { + endpoint + connectionSettings + + defaultSessionOnce, closeOnce sync.Once + + container *container + conn net.Conn + server bool + incoming chan Incoming + handler *handler + engine *proton.Engine + pConnection proton.Connection + + defaultSession Session +} + +// NewConnection creates a connection with the given options. 
+func NewConnection(conn net.Conn, opts ...ConnectionOption) (*connection, error) { + c := &connection{ + conn: conn, + } + c.handler = newHandler(c) + var err error + c.engine, err = proton.NewEngine(c.conn, c.handler.delegator) + if err != nil { + return nil, err + } + c.pConnection = c.engine.Connection() + for _, set := range opts { + set(c) + } + if c.container == nil { + c.container = NewContainer("").(*container) + } + c.pConnection.SetContainer(c.container.Id()) - globalSASLInit(c.engine) - ++ saslConfig.setup(c.engine) + c.endpoint.init(c.engine.String()) + go c.run() + return c, nil +} + +func (c *connection) run() { + if !c.server { + c.pConnection.Open() + } + _ = c.engine.Run() + if c.incoming != nil { + close(c.incoming) + } + _ = c.closed(Closed) +} + +func (c *connection) Close(err error) { + c.err.Set(err) + c.engine.Close(err) +} + +func (c *connection) Disconnect(err error) { + c.err.Set(err) + c.engine.Disconnect(err) +} + +func (c *connection) Session(opts ...SessionOption) (Session, error) { + var s Session + err := c.engine.InjectWait(func() error { + if c.Error() != nil { + return c.Error() + } + pSession, err := c.engine.Connection().Session() + if err == nil { + pSession.Open() + if err == nil { + s = newSession(c, pSession, opts...) + } + } + return err + }) + return s, err +} + +func (c *connection) Container() Container { return c.container } + +func (c *connection) DefaultSession() (s Session, err error) { + c.defaultSessionOnce.Do(func() { + c.defaultSession, err = c.Session() + }) + if err == nil { + err = c.Error() + } + return c.defaultSession, err +} + +func (c *connection) Sender(opts ...LinkOption) (Sender, error) { + if s, err := c.DefaultSession(); err == nil { + return s.Sender(opts...) + } else { + return nil, err + } +} + +func (c *connection) Receiver(opts ...LinkOption) (Receiver, error) { + if s, err := c.DefaultSession(); err == nil { + return s.Receiver(opts...) 
+ } else { + return nil, err + } +} + +func (c *connection) Connection() Connection { return c } + +func (c *connection) Wait() error { return c.WaitTimeout(Forever) } +func (c *connection) WaitTimeout(timeout time.Duration) error { + _, err := timedReceive(c.done, timeout) + if err == Timeout { + return Timeout + } + return c.Error() +} + +func (c *connection) Incoming() <-chan Incoming { + assert(c.incoming != nil, "Incoming() is only allowed for a Connection created with the Server() option: %s", c) + return c.incoming +} + +type IncomingConnection struct { + incoming + connectionSettings + c *connection +} + +func newIncomingConnection(c *connection) *IncomingConnection { + c.user = c.pConnection.Transport().User() + c.virtualHost = c.pConnection.RemoteHostname() + return &IncomingConnection{ + incoming: makeIncoming(c.pConnection), + connectionSettings: c.connectionSettings, + c: c} +} + +// AcceptConnection is like Accept() but takes ConnectionOption s +// For example you can set the Heartbeat() for the accepted connection. +func (in *IncomingConnection) AcceptConnection(opts ...ConnectionOption) Connection { + return in.accept(func() Endpoint { + for _, opt := range opts { + opt(in.c) + } + in.c.pConnection.Open() + return in.c + }).(Connection) +} + +func (in *IncomingConnection) Accept() Endpoint { + return in.AcceptConnection() +} + +func sasl(c *connection) proton.SASL { return c.engine.Transport().SASL() } + +// SASLEnable returns a ConnectionOption that enables SASL authentication. +// Only required if you don't set any other SASL options. +func SASLEnable() ConnectionOption { return func(c *connection) { sasl(c) } } + +// SASLAllowedMechs returns a ConnectionOption to set the list of allowed SASL +// mechanisms. +// +// Can be used on the client or the server to restrict the SASL for a connection. +// mechs is a space-separated list of mechanism names. 
+// +func SASLAllowedMechs(mechs string) ConnectionOption { + return func(c *connection) { sasl(c).AllowedMechs(mechs) } +} + +// SASLAllowInsecure returns a ConnectionOption that allows or disallows clear +// text SASL authentication mechanisms +// +// By default the SASL layer is configured not to allow mechanisms that disclose +// the clear text of the password over an unencrypted AMQP connection. This specifically +// will disallow the use of the PLAIN mechanism without using SSL encryption. +// +// This default is to avoid disclosing password information accidentally over an +// insecure network. +// +func SASLAllowInsecure(b bool) ConnectionOption { + return func(c *connection) { sasl(c).SetAllowInsecureMechs(b) } +} + +// Heartbeat returns a ConnectionOption that requests the maximum delay +// between sending frames for the remote peer. If we don't receive any frames +// within 2*delay we will close the connection. +// +func Heartbeat(delay time.Duration) ConnectionOption { + // Proton-C divides the idle-timeout by 2 before sending, so compensate. + return func(c *connection) { c.engine.Transport().SetIdleTimeout(2 * delay) } +} + ++type saslConfigState struct { ++ lock sync.Mutex ++ name string ++ dir string ++ initialized bool ++} ++ ++func (s *saslConfigState) set(target *string, value string) { ++ s.lock.Lock() ++ defer s.lock.Unlock() ++ if s.initialized { ++ panic("SASL configuration cannot be changed after a Connection has been created") ++ } ++ *target = value ++} ++ ++// Apply the global SASL configuration the first time a proton.Engine needs it ++// ++// TODO aconway 2016-09-15: Current pn_sasl C impl config is broken, so all we ++// can realistically offer is global configuration. Later if/when the pn_sasl C ++// impl is fixed we can offer per connection over-rides. 
++func (s *saslConfigState) setup(eng *proton.Engine) { ++ s.lock.Lock() ++ defer s.lock.Unlock() ++ if !s.initialized { ++ s.initialized = true ++ sasl := eng.Transport().SASL() ++ if s.name != "" { ++ sasl.ConfigName(saslConfig.name) ++ } ++ if s.dir != "" { ++ sasl.ConfigPath(saslConfig.dir) ++ } ++ } ++} ++ ++var saslConfig saslConfigState ++ +// GlobalSASLConfigDir sets the SASL configuration directory for every +// Connection created in this process. If not called, the default is determined +// by your SASL installation. +// +// You can set SASLAllowInsecure and SASLAllowedMechs on individual connections. +// - func GlobalSASLConfigDir(dir string) { globalSASLConfigDir = dir } ++// Must be called at most once, before any connections are created. ++func GlobalSASLConfigDir(dir string) { saslConfig.set(&saslConfig.dir, dir) } + +// GlobalSASLConfigName sets the SASL configuration name for every Connection +// created in this process. If not called the default is "proton-server". +// +// The complete configuration file name is +// <sasl-config-dir>/<sasl-config-name>.conf +// +// You can set SASLAllowInsecure and SASLAllowedMechs on individual connections. +// - func GlobalSASLConfigName(dir string) { globalSASLConfigName = dir } ++// Must be called at most once, before any connections are created. ++func GlobalSASLConfigName(name string) { saslConfig.set(&saslConfig.name, name) } + +// Do we support extended SASL negotiation? +// All implementations of Proton support ANONYMOUS and EXTERNAL on both +// client and server sides and PLAIN on the client side. +// +// Extended SASL implememtations use an external library (Cyrus SASL) +// to support other mechanisms beyond these basic ones. +func SASLExtended() bool { return proton.SASLExtended() } + - var ( - globalSASLConfigName string - globalSASLConfigDir string - ) - - // TODO aconway 2016-09-15: Current pn_sasl C impl config is broken, so all we - // can realistically offer is global configuration. 
Later if/when the pn_sasl C - // impl is fixed we can offer per connection over-rides. - func globalSASLInit(eng *proton.Engine) { - sasl := eng.Transport().SASL() - if globalSASLConfigName != "" { - sasl.ConfigName(globalSASLConfigName) - } - if globalSASLConfigDir != "" { - sasl.ConfigPath(globalSASLConfigDir) - } - } - +// Dial is shorthand for using net.Dial() then NewConnection() +// See net.Dial() for the meaning of the network, address arguments. +func Dial(network, address string, opts ...ConnectionOption) (c Connection, err error) { + conn, err := net.Dial(network, address) + if err == nil { + c, err = NewConnection(conn, opts...) + } + return +} + +// DialWithDialer is shorthand for using dialer.Dial() then NewConnection() +// See net.Dial() for the meaning of the network, address arguments. +func DialWithDialer(dialer *net.Dialer, network, address string, opts ...ConnectionOption) (c Connection, err error) { + conn, err := dialer.Dial(network, address) + if err == nil { + c, err = NewConnection(conn, opts...) + } + return +} --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@qpid.apache.org For additional commands, e-mail: commits-h...@qpid.apache.org