This is an automated email from the ASF dual-hosted git repository.
masahi pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 88a7ced [Frontend][PaddlePaddle] Support conv2d_transpose/rnn/fill_constant_batch_size_like (#9564)
88a7ced is described below
commit 88a7ceddb875eea4bb0d05b89338bc81d0d08607
Author: Jason <[email protected]>
AuthorDate: Fri Nov 26 16:43:45 2021 +0800
[Frontend][PaddlePaddle] Support conv2d_transpose/rnn/fill_constant_batch_size_like (#9564)
* add conv2dtranspose, rnn, fill_batch_size_like
* fix conv_transpose and add RNN test case
* Update paddlepaddle.py
* Create paddlepaddle.py
* fix scale attr of convert_interpolate
* black code
Co-authored-by: heliqi <[email protected]>
---
python/tvm/relay/frontend/paddlepaddle.py | 406 ++++++++++++++++++++-
tests/python/frontend/paddlepaddle/test_forward.py | 69 +++-
2 files changed, 466 insertions(+), 9 deletions(-)
diff --git a/python/tvm/relay/frontend/paddlepaddle.py b/python/tvm/relay/frontend/paddlepaddle.py
index 9672385..46f96b7 100644
--- a/python/tvm/relay/frontend/paddlepaddle.py
+++ b/python/tvm/relay/frontend/paddlepaddle.py
@@ -301,6 +301,66 @@ def convert_conv2d(g, op, block):
g.add_node(op.output("Output")[0], out)
+def convert_conv2d_transpose(g, op, block):
+ """Operator converter for conv2d_transpose."""
+
+ dilations = op.attr("dilations")
+ groups = op.attr("groups")
+ paddings = op.attr("paddings")
+ padding_algorithm = op.attr("padding_algorithm")
+ strides = op.attr("strides")
+ output_padding = op.attr("output_padding") if op.attr("output_padding") else [0, 0]
+
+ kernel = g.get_node(op.input("Filter")[0])
+ input_x = g.get_node(op.input("Input")[0])
+ _, out_channels, k_h, k_w = infer_shape(kernel)
+ k_size = [k_h, k_w]
+ if padding_algorithm == "VALID":
+ paddings = [0, 0]
+ elif padding_algorithm == "SAME":
+ # SAME padding of conv2d_transpose is not the same as conv2d's
+ # We cannot use auto_pad here; only static shapes are supported for now
+ dilations = [1, 1]
+ input_shape = shape_of(input_x)
+ h_w = _op.strided_slice(input_shape, [2], [4])
+ try:
+ h_w = infer_value(h_w, g.get_params()).numpy().tolist()
+ except Exception as e:
+ msg = "The SAME padding algorithm of conv2d_transpose not support
dynamic shape"
+ raise tvm.error.OpAttributeInvalid(msg) from e
+ paddings = []
+ for i in range(2):
+ if strides[i] == 1 or h_w[i] % strides[i] == 0:
+ pad = max(k_size[i] - strides[i], 0)
+ else:
+ pad = max(k_size[i] - (h_w[i] % strides[i]), 0)
+ pad_before = pad // 2
+ pad_after = pad - pad_before
+ paddings.insert(-1, pad_before)
+ paddings.append(pad_after)
+ elif padding_algorithm == "EXPLICIT":
+ if len(paddings) == 2:
+ paddings = [paddings[0], paddings[1], paddings[0], paddings[1]]
+ elif len(paddings) == 4:
+ paddings = [paddings[0], paddings[2], paddings[1], paddings[3]]
+ else:
+ msg = 'Value {} in attribute "padding" of operator Conv is not "valid."'
+ raise tvm.error.OpAttributeInvalid(msg.format(padding_algorithm))
+
+ out = _op.nn.conv2d_transpose(
+ input_x,
+ kernel,
+ strides=strides,
+ padding=paddings,
+ dilation=dilations,
+ groups=groups,
+ channels=out_channels * groups,
+ kernel_size=k_size,
+ output_padding=output_padding,
+ )
+ g.add_node(op.output("Output")[0], out)
+
+
def convert_cumsum(g, op, block):
"""Operator converter for cumsum."""
@@ -462,6 +522,36 @@ def convert_fill_constant(g, op, block):
g.add_node(op.output("Out")[0], out)
+def convert_fill_constant_batch_size_like(g, op, block):
+ """Operator converter for fill_constant_batch_size_like."""
+
+ x = g.get_node(op.input("Input")[0])
+ value = op.attr("value")
+ shape = op.attr("shape")
+ input_dim_idx = op.attr("input_dim_idx")
+ output_dim_idx = op.attr("output_dim_idx")
+ dtype = op.attr("dtype")
+
+ dtype = _convert_dtype_value(dtype)
+ input_shape = shape_of(x)
+ batch = _op.strided_slice(input_shape, begin=[input_dim_idx], end=[input_dim_idx + 1]).astype("int32")
+ shape_before = shape[:output_dim_idx]
+ shape_before = _expr.const(shape_before, dtype="int32")
+ shape_after = shape[output_dim_idx + 1 :]
+ shape_after = _expr.const(shape_after, dtype="int32")
+
+ out_shape = _op.concatenate([shape_before, batch, shape_after], axis=0)
+ out_shape, infered = try_infer_value(out_shape, g.get_params())
+ if infered:
+ out_shape = out_shape.tolist()
+ constant = _expr.const(value, dtype=dtype).astype(dtype)
+ out = _op.full(constant, out_shape, dtype=dtype)
+
+ g.add_node(op.output("Out")[0], out)
+
+
def convert_flatten(g, op, block):
"""Operator converter for flatten."""
@@ -620,6 +710,9 @@ def convert_interpolate(g, op, block):
layout = op.attr("data_layout")
out_h = op.attr("out_h")
out_w = op.attr("out_w")
+ scale = op.attr("scale")
+ if not isinstance(scale, (list, tuple)):
+ scale = [scale, scale]
x = g.get_node(op.input("X")[0])
x_shape = infer_shape(x)
@@ -629,13 +722,7 @@ def convert_interpolate(g, op, block):
input_scale = op.input("Scale")
rounding_method, interp_method, coordinate_transformation_mode = get_interpolate_mode(op)
- if input_out_size:
- # if out_size is a tensor
- out_size = g.get_node(input_out_size[0])
- out_size, infered = try_infer_value(out_size, parameters=g.get_params())
- if infered:
- out_size = out_size.tolist()
- elif input_size_tensor:
+ if input_size_tensor:
# if out_size is a list of tensor
out_size = list()
for name in input_size_tensor:
@@ -659,6 +746,24 @@ def convert_interpolate(g, op, block):
out_size, infered = try_infer_value(out_size, parameters=g.get_params())
if infered:
out_size = out_size.tolist()
+ elif scale and scale[0] > 0 and scale[1] > 0:
+ # use attribute scale
+ input_shape = shape_of(x).astype("float32")
+ input_scale = _expr.const(np.array([scale[0], scale[1]]).astype("float32"))
+ if layout.startswith("NC"):
+ out_size = _op.strided_slice(input_shape, begin=[2], end=[4]) * input_scale
+ else:
+ out_size = _op.strided_slice(input_shape, begin=[1], end=[3]) * input_scale
+ out_size = out_size.astype("int32")
+ out_size, infered = try_infer_value(out_size, parameters=g.get_params())
+ if infered:
+ out_size = out_size.tolist()
+ elif input_out_size:
+ # if out_size is a tensor
+ out_size = g.get_node(input_out_size[0])
+ out_size, infered = try_infer_value(out_size, parameters=g.get_params())
+ if infered:
+ out_size = out_size.tolist()
else:
# if out_size is a constant value
out_size = [out_h, out_w]
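With this reordering, the output size appears to be resolved in the following priority: the SizeTensor list, then the Scale input tensor, then the scale attribute, then the OutSize tensor, and finally the out_h/out_w attributes. A small sketch of the scale-attribute branch (plain NumPy instead of the relay ops used above):

    import numpy as np

    def out_size_from_scale(x_shape, scale, layout="NCHW"):
        # Normalize a scalar scale to [scale_h, scale_w], as the converter does.
        scale = scale if isinstance(scale, (list, tuple)) else [scale, scale]
        hw = x_shape[2:4] if layout.startswith("NC") else x_shape[1:3]
        return (np.array(hw, dtype="float32") * np.array(scale, dtype="float32")).astype("int32")

    print(out_size_from_scale([1, 2, 8, 12], 2.0))  # [16 24]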
@@ -1078,6 +1183,290 @@ def convert_reshape(g, op, block):
g.add_node(op.output("Out")[0], out)
+def convert_rnn(g, op, block):
+ """Operator converter for rnn."""
+
+ def generate_lstm(
+ input_seqs,
+ hidden_state,
+ cell_state,
+ w_inp,
+ w_hid,
+ b_inp,
+ b_hid,
+ f_act,
+ g_act,
+ h_act,
+ backwards=False,
+ ):
+ """Implementation of LSTM cell for paddlepaddle of TVM"""
+
+ h_list = []
+ seq_length = len(input_seqs)
+ for i in range(seq_length):
+ step = input_seqs[i] if not backwards else input_seqs[seq_length - (i + 1)]
+ step = _op.squeeze(step, axis=[0])
+ gates = _op.nn.dense(step, w_inp) + _op.nn.dense(hidden_state, w_hid)
+ if b_inp is not None:
+ gates += b_inp
+ if b_hid is not None:
+ gates += b_hid
+ i, f, c, o = _op.split(gates, 4, axis=-1)
+
+ i = f_act(i)
+ f = f_act(f)
+
+ c = g_act(c)
+ C = f * cell_state + i * c
+
+ o = f_act(o)
+
+ H = o * h_act(C)
+
+ hidden_state = H
+ cell_state = C
+ h_list.append(_op.expand_dims(H, axis=0))
+
+ if backwards:
+ h_list = h_list[::-1]
+
+ # Concatenate outputs and add back in direction axis.
+ concatenated = _op.concatenate(h_list, 0)
+ output = _op.expand_dims(concatenated, axis=1)
+ hidden_state = _op.expand_dims(hidden_state, axis=0)
+ cell_state = _op.expand_dims(cell_state, axis=0)
+
+ return output, hidden_state, cell_state
+
+ def generate_gru(
+ input_seqs, hidden_state, w_inp, w_hid, b_inp, b_hid, rz_act, n_act, backwards=False
+ ):
+ """Implementation of GRU cell for paddlepaddle of TVM"""
+
+ h_list = []
+ seq_length = len(input_seqs)
+ for i in range(seq_length):
+ step = input_seqs[i] if not backwards else input_seqs[seq_length - (i + 1)]
+ step = _op.squeeze(step, axis=[0])
+ xwt = _op.nn.dense(step, w_inp)
+ hwt = _op.nn.dense(hidden_state, w_hid)
+ if b_inp is not None:
+ xwt += b_inp
+ if b_hid is not None:
+ hwt += b_hid
+ i_r, i_z, i_n = _op.split(xwt, 3, axis=-1)
+ h_r, h_z, h_n = _op.split(hwt, 3, axis=-1)
+
+ r_gate = rz_act(i_r + h_r)
+ z_gate = rz_act(i_z + h_z)
+ n_gate = n_act(i_n + r_gate * h_n)
+
+ hidden_state = (hidden_state - n_gate) * z_gate + n_gate
+ h_list.append(_op.expand_dims(hidden_state, axis=0))
+
+ if backwards:
+ h_list = h_list[::-1]
+
+ # Concatenate outputs and add back in direction axis.
+ concatenated = _op.concatenate(h_list, 0)
+ output = _op.expand_dims(concatenated, axis=1)
+ hidden_state = _op.expand_dims(hidden_state, axis=0)
+
+ return output, hidden_state
+
+ def generate_simplernn(
+ input_seqs, hidden_state, w_inp, w_hid, b_inp, b_hid, n_act, backwards=False
+ ):
+ """Implementation of SimpleRNN cell for paddlepaddle of TVM"""
+
+ h_list = []
+ seq_length = len(input_seqs)
+ for i in range(seq_length):
+ step = input_seqs[i] if not backwards else input_seqs[seq_length - (i + 1)]
+ step = _op.squeeze(step, axis=[0])
+ xwt = _op.nn.dense(step, w_inp)
+ hwt = _op.nn.dense(hidden_state, w_hid)
+ if b_inp is not None:
+ xwt += b_inp
+ if b_hid is not None:
+ hwt += b_hid
+
+ n_gate = n_act(xwt + hwt)
+
+ hidden_state = n_gate
+ h_list.append(_op.expand_dims(hidden_state, axis=0))
+
+ if backwards:
+ h_list = h_list[::-1]
+
+ # Concatenate outputs and add back in direction axis.
+ concatenated = _op.concatenate(h_list, 0)
+ output = _op.expand_dims(concatenated, axis=1)
+ hidden_state = _op.expand_dims(hidden_state, axis=0)
+
+ return output, hidden_state
+
+ def make_param_inputs(g, node, layer, hidden_size, num_layers):
+ """Param for weight and bias."""
+
+ bidirect_len = 4 if node.attr("is_bidirec") else 2
+ all_layer_param_len = len(node.input("WeightList"))
+ weight_list = node.input("WeightList")[: all_layer_param_len // 2]
+ bias_list = node.input("WeightList")[all_layer_param_len // 2 :]
+
+ layer_weight_list = weight_list[layer * bidirect_len : layer * bidirect_len + bidirect_len]
+ layer_bias_list = bias_list[layer * bidirect_len : layer * bidirect_len + bidirect_len]
+ param_list = layer_weight_list + layer_bias_list
+ param_list_len = len(param_list)
+
+ input_weights = param_list[0 : param_list_len // 2 : 2]
+ hidden_weights = param_list[1 : param_list_len // 2 : 2]
+
+ input_bias = param_list[param_list_len // 2 : param_list_len : 2]
+ hidden_bias = param_list[param_list_len // 2 + 1 : param_list_len : 2]
+
+ return input_weights, hidden_weights, input_bias, hidden_bias
+
+ def make_init_param_inputs(g, node, layer):
+ """Init param for inputs."""
+
+ mode = node.attr("mode")
+ if mode == "LSTM":
+ all_init_h, all_init_c = node.input("PreState")
+ bidirect_len = 2 if node.attr("is_bidirec") else 1
+ init_h = _op.strided_slice(
+ g.get_node(all_init_h),
+ [layer * bidirect_len],
+ [layer * bidirect_len + bidirect_len],
+ axes=[0],
+ )
+ init_c = _op.strided_slice(
+ g.get_node(all_init_c),
+ [layer * bidirect_len],
+ [layer * bidirect_len + bidirect_len],
+ axes=[0],
+ )
+ return init_h, init_c
+ all_init_h = node.input("PreState")[0]
+ bidirect_len = 2 if node.attr("is_bidirec") else 1
+ init_h = _op.strided_slice(
+ g.get_node(all_init_h),
+ [layer * bidirect_len],
+ [layer * bidirect_len + bidirect_len],
+ axes=[0],
+ )
+ return init_h
+
+ hidden_size = op.attr("hidden_size")
+ num_layers = op.attr("num_layers")
+ is_bidirec = op.attr("is_bidirec")
+ mode = op.attr("mode")
+
+ input_x = g.get_node(op.input("Input")[0])
+
+ num_directions = 1
+ if is_bidirec:
+ num_directions = 2
+
+ x_shape = infer_shape(input_x)
+ time_steps = x_shape[0]
+ x_steps = _op.split(input_x, indices_or_sections=time_steps, axis=0)
+ for layer in range(num_layers):
+ input_weights, hidden_weights, input_bias, hidden_bias = make_param_inputs(
+ g, op, layer, hidden_size, num_layers
+ )
+ if mode == "LSTM":
+ init_h, init_c = make_init_param_inputs(g, op, layer)
+ init_hs = _op.split(init_h, num_directions)
+ init_cs = _op.split(init_c, num_directions)
+ result_output = []
+ result_H = []
+ result_C = []
+ for i in range(num_directions):
+ H_t = _op.squeeze(init_hs[i], axis=[0])
+ C_t = _op.squeeze(init_cs[i], axis=[0])
+ W = g.get_node(input_weights[i])
+ R = g.get_node(hidden_weights[i])
+ WB = g.get_node(input_bias[i])
+ RB = g.get_node(hidden_bias[i])
+ output, H, C = generate_lstm(
+ input_seqs=x_steps,
+ hidden_state=H_t,
+ cell_state=C_t,
+ w_inp=W,
+ w_hid=R,
+ b_inp=WB,
+ b_hid=RB,
+ f_act=_op.sigmoid,
+ g_act=_op.tanh,
+ h_act=_op.tanh,
+ backwards=i == 1,
+ )
+ result_output.append(output)
+ result_H.append(H)
+ result_C.append(C)
+ output = _op.concatenate(result_output, axis=1)
+ H = _op.concatenate(result_H, axis=0)
+ C = _op.concatenate(result_C, axis=0)
+ elif mode == "GRU":
+ init_h = make_init_param_inputs(g, op, layer)
+ init_hs = _op.split(init_h, num_directions)
+ result_output = []
+ result_H = []
+ for i in range(num_directions):
+ H_t = _op.squeeze(init_hs[i], axis=[0])
+ W = g.get_node(input_weights[i])
+ R = g.get_node(hidden_weights[i])
+ WB = g.get_node(input_bias[i])
+ RB = g.get_node(hidden_bias[i])
+ output, H = generate_gru(
+ input_seqs=x_steps,
+ hidden_state=H_t,
+ w_inp=W,
+ w_hid=R,
+ b_inp=WB,
+ b_hid=RB,
+ rz_act=_op.sigmoid,
+ n_act=_op.tanh,
+ backwards=i == 1,
+ )
+ result_output.append(output)
+ result_H.append(H)
+ output = _op.concatenate(result_output, axis=1)
+ H = _op.concatenate(result_H, axis=0)
+ elif mode == "RNN_TANH":
+ init_h = make_init_param_inputs(g, op, layer)
+ init_hs = _op.split(init_h, num_directions)
+ result_output = []
+ result_H = []
+ for i in range(num_directions):
+ H_t = _op.squeeze(init_hs[i], axis=[0])
+ W = g.get_node(input_weights[i])
+ R = g.get_node(hidden_weights[i])
+ WB = g.get_node(input_bias[i])
+ RB = g.get_node(hidden_bias[i])
+ output, H = generate_simplernn(
+ input_seqs=x_steps,
+ hidden_state=H_t,
+ w_inp=W,
+ w_hid=R,
+ b_inp=WB,
+ b_hid=RB,
+ n_act=_op.tanh,
+ backwards=i == 1,
+ )
+ result_output.append(output)
+ result_H.append(H)
+ output = _op.concatenate(result_output, axis=1)
+ H = _op.concatenate(result_H, axis=0)
+
+ output = _op.transpose(output, axes=[0, 2, 1, 3])
+ output = _op.reshape(output, newshape=(0, 0, -1))
+ x_steps = _op.split(output, indices_or_sections=time_steps, axis=0)
+
+ g.add_node(op.output("Out")[0], output)
+
+
def convert_scale(g, op, block):
"""Operator converter for scale."""
@@ -1313,6 +1702,7 @@ _convert_map = {
"ceil": convert_unary_op,
"concat": convert_concat,
"conv2d": convert_conv2d,
+ "conv2d_transpose": convert_conv2d_transpose,
"cos": convert_unary_op,
"cosh": convert_unary_op,
"cumsum": convert_cumsum,
@@ -1337,6 +1727,7 @@ _convert_map = {
"feed": convert_feed,
"fill_any_like": convert_fill_any_like,
"fill_constant": convert_fill_constant,
+ "fill_constant_batch_size_like": convert_fill_constant_batch_size_like,
"flatten_contiguous_range": convert_flatten,
"floor": convert_unary_op,
"floor_mod": convert_elementwise_op,
@@ -1386,6 +1777,7 @@ _convert_map = {
"reduce_prod": convert_reduce,
"reduce_sum": convert_reduce,
"reduce_mean": convert_reduce,
+ "rnn": convert_rnn,
"rsqrt": convert_unary_op,
"scale": convert_scale,
"scatter": convert_scatter,
diff --git a/tests/python/frontend/paddlepaddle/test_forward.py b/tests/python/frontend/paddlepaddle/test_forward.py
index e427d6f..9add8ad 100644
--- a/tests/python/frontend/paddlepaddle/test_forward.py
+++ b/tests/python/frontend/paddlepaddle/test_forward.py
@@ -461,6 +461,37 @@ def test_forward_conv():
@tvm.testing.uses_gpu
+def test_forward_conv_transpose():
+ class Conv2DTranspose(nn.Layer):
+ def __init__(self, stride=1, padding=0, dilation=1, groups=1, padding_mode="zeros"):
+ super(Conv2DTranspose, self).__init__()
+ self.conv = nn.Conv2DTranspose(
+ 6,
+ 3,
+ 3,
+ stride=stride,
+ padding=padding,
+ dilation=dilation,
+ groups=groups,
+ )
+ self.softmax = nn.Softmax()
+
+ @paddle.jit.to_static
+ def forward(self, inputs):
+ return self.softmax(self.conv(inputs))
+
+ input_shapes = [[1, 6, 10, 10], [2, 6, 8, 8]]
+
+ for input_shape in input_shapes:
+ input_data = paddle.rand(input_shape, dtype="float32")
+ verify_model(Conv2DTranspose(), input_data=input_data)
+ verify_model(Conv2DTranspose(stride=2, padding="VALID"),
input_data=input_data)
+ verify_model(Conv2DTranspose(stride=2, padding="SAME", dilation=1),
input_data=input_data)
+ verify_model(Conv2DTranspose(stride=2, padding=3),
input_data=input_data)
+ verify_model(Conv2DTranspose(stride=3, padding="SAME", groups=1),
input_data=input_data)
+
+
+@tvm.testing.uses_gpu
def test_forward_dot():
class Dot(nn.Layer):
@paddle.jit.to_static
@@ -839,12 +870,17 @@ def test_forward_interpolate():
input_data = paddle.rand([1, 2, 8, 12]).astype("float32")
verify_model(Interpolate(), input_data)
verify_model(Interpolate(use_list=True), input_data)
- verify_model(Interpolate(use_scale=True), input_data)
+ verify_model(Interpolate(use_scale=True, use_const=True), input_data)
verify_model(Interpolate("bilinear", use_scale=True), input_data)
verify_model(Interpolate("bilinear", use_scale=True, align_corners=True),
input_data)
verify_model(
Interpolate(
- "bilinear", use_scale=True, align_corners=True, align_mode=1,
data_format="NHWC"
+ "bilinear",
+ use_scale=True,
+ align_corners=True,
+ align_mode=1,
+ data_format="NHWC",
+ use_const=True,
),
input_data,
)
@@ -1284,5 +1320,34 @@ def test_forward_math_api():
verify_model(MathAPI(api_name), input_data=input_data)
+@tvm.testing.uses_gpu
+def test_forward_rnn():
+ class RNN(nn.Layer):
+ def __init__(self, api_name, input_size, hidden_size, num_layers, direction="forward"):
+ super(RNN, self).__init__()
+ rnn_func = getattr(paddle.nn, api_name, None)
+ self.rnn = rnn_func(input_size, hidden_size, num_layers, direction=direction)
+
+ @paddle.jit.to_static
+ def forward(self, inputs, prev_h):
+ y, h = self.rnn(inputs, prev_h)
+ return y
+
+ input_size, hidden_size, num_layers = 8, 16, 2
+ input_shape = [4, 5, 8]
+ input_data = paddle.rand(input_shape, dtype="float32")
+
+ for api_name in ("SimpleRNN", "GRU"):
+ prev_h = paddle.rand([4, 4, 16], dtype="float32")
+ verify_model(
+ RNN(api_name, input_size, hidden_size, num_layers, direction="bidirectional"),
+ input_data=[input_data, prev_h],
+ )
+ prev_h = paddle.rand([2, 4, 16], dtype="float32")
+ verify_model(
+ RNN(api_name, input_size, hidden_size, num_layers), input_data=[input_data, prev_h]
+ )
+
+
if __name__ == "__main__":
pytest.main([__file__])
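For reference, the prev_h shapes in test_forward_rnn follow PaddlePaddle's state-layout convention, [num_layers * num_directions, batch_size, hidden_size], which is consistent with the values used above:

    num_layers, hidden_size, batch_size = 2, 16, 4
    for direction, num_directions in (("forward", 1), ("bidirectional", 2)):
        print(direction, [num_layers * num_directions, batch_size, hidden_size])
    # forward [2, 4, 16]
    # bidirectional [4, 4, 16]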