optima2005 commented on a change in pull request #4639: [Relay/Topi][Op] Conv1D URL: https://github.com/apache/incubator-tvm/pull/4639#discussion_r365048590
########## File path: python/tvm/relay/frontend/onnx.py ########## @@ -223,37 +223,64 @@ class Conv(OnnxOpConverter): @classmethod def _impl_v1(cls, inputs, attr, params): - # infer pads for auto_pad + # Use shape of input to determine convolution type. + input_shape = infer_shape(inputs[0]) + if 'auto_pad' in attr: attr['auto_pad'] = attr['auto_pad'].decode('utf-8') if attr['auto_pad'] in ('SAME_UPPER', 'SAME_LOWER'): - input_shape = infer_shape(inputs[0]) - in_h, in_w = input_shape[2], input_shape[3] - stride_h, stride_w = attr['strides'] - kernel_h, kernel_w = attr['kernel_shape'] - dilation_h, dilation_w = attr['dilations'] - dilated_kernel_h = (kernel_h - 1) * dilation_h + 1 - dilated_kernel_w = (kernel_w - 1) * dilation_w + 1 - pad_v = get_pad_pair(in_h, dilated_kernel_h, stride_h) - pad_h = get_pad_pair(in_w, dilated_kernel_w, stride_w) - attr['pads'] = (pad_v[0], pad_h[0], pad_v[1], pad_h[1]) + pad_tuple = [] + for axis in range(len(input_shape) - 2): + axis_shape = input_shape[2 + axis] + stride = attr['strides'][axis] + kernel = attr['kernel_shape'][axis] + dilation = attr['dilations'][axis] + dilated_kernel = (kernel - 1) * dilation + 1 + pad = get_pad_pair(axis_shape, dilated_kernel, stride) + pad_tuple.append(pad) + pad_tuple = tuple([val for pair in zip(*pad_tuple) for val in pair]) + attr['pads'] = pad_tuple elif attr['auto_pad'] == 'VALID': - attr['pads'] = (0, 0) + attr['pads'] = tuple([0 for i in range(len(input_shape) - 2)]) elif attr['auto_pad'] == 'NOTSET': pass else: msg = 'Value {} in attribute "auto_pad" of operator Conv is invalid.' - raise tvm.error.OpAttributeInvalid(msg.format(attr['auto_pad'])) + raise tvm.error.OpAttributeInvalid( + msg.format(attr['auto_pad'])) attr.pop('auto_pad') - out = AttrCvt( - op_name=dimension_picker('conv'), - transforms={ + # Handle attribute conversion for different convolution types + + # Conv1D + if len(input_shape) == 3: Review comment: I suggest to update dimension_picker() and dimension_constraint(). 
Please see the TensorFlow frontend, and only switch those default values conditionally. ---------------------------------------------------------------- This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. For queries about this service, please contact Infrastructure at: us...@infra.apache.org With regards, Apache Git Services