This is an automated email from the ASF dual-hosted git repository.

skm pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/master by this push:
     new 5241c1b  Add Gluon Transformer Crop (#14259)
5241c1b is described below

commit 5241c1b5c1dc029d3796aed64f3740550928fafa
Author: Jake Lee <gstu1...@gmail.com>
AuthorDate: Thu Apr 4 07:15:28 2019 +0800

    Add Gluon Transformer Crop (#14259)
    
    * implement crop
    
    * add crop operator
    
    * fix for linter
    
    * add backward and refactor the code
    
    * fix error namespace
    
    * fix the website build failure
    
    * start adding the unit test of backward
    
    * add unit test for backward
    
    * address the comment
    
    * add missing statement
    
    * fix the website error
    
    * fix the website building
    
    * add missing doc
---
 python/mxnet/gluon/data/vision/transforms.py    |  61 ++++++++
 python/mxnet/image/image.py                     |   2 +-
 src/operator/image/crop-inl.h                   | 190 ++++++++++++++++++++++++
 src/operator/image/crop.cc                      |  85 +++++++++++
 tests/python/unittest/test_gluon_data_vision.py |  75 +++++++++-
 5 files changed, 410 insertions(+), 3 deletions(-)

diff --git a/python/mxnet/gluon/data/vision/transforms.py 
b/python/mxnet/gluon/data/vision/transforms.py
index 9310e15..dff7f66 100644
--- a/python/mxnet/gluon/data/vision/transforms.py
+++ b/python/mxnet/gluon/data/vision/transforms.py
@@ -228,6 +228,67 @@ class RandomResizedCrop(Block):
         return image.random_size_crop(x, *self._args)[0]
 
 
+class CropResize(HybridBlock):
+    r"""Crop the input image and optionally resize it.
+
+    Makes a crop of the original image then optionally resizes it to the 
specified size.
+
+    Parameters
+    ----------
+    x : int
+        Left boundary of the cropping area
+    y : int
+        Top boundary of the cropping area
+    width : int
+        Width of the cropping area
+    height : int
+        Height of the cropping area
+    size : int or tuple of (w, h)
+        Optional, resize to new size after cropping
+    interpolation : int, optional
+        Interpolation method for resizing. By default uses bilinear
+        interpolation. See OpenCV's resize function for available choices.
+        
https://docs.opencv.org/2.4/modules/imgproc/doc/geometric_transformations.html?highlight=resize#resize
+        Note that Resize on gpu uses the contrib.bilinearResize2D operator,
+        which only supports bilinear interpolation (1). The result will be 
slightly
+        different on gpu compared to cpu: OpenCV tends to align centers while 
bilinearResize2D
+        uses an algorithm that aligns corners.
+
+
+    Inputs:
+        - **data**: input tensor with (H x W x C) or (N x H x W x C) shape.
+
+    Outputs:
+        - **out**: output tensor with (H x W x C) or (N x H x W x C) shape.
+
+    Examples
+    --------
+    >>> transformer = vision.transforms.CropResize(x=0, y=0, width=100, 
height=100)
+    >>> image = mx.nd.random.uniform(0, 255, (224, 224, 
3)).astype(dtype=np.uint8)
+    >>> transformer(image)
+    <NDArray 100x100x3 @cpu(0)>
+    >>> image = mx.nd.random.uniform(0, 255, (3, 224, 224, 
3)).astype(dtype=np.uint8)
+    >>> transformer(image)
+    <NDArray 3x100x100x3 @cpu(0)>
+    >>> transformer = vision.transforms.CropResize(x=0, y=0, width=100, 
height=100, size=(50, 50), interpolation=1)
+    >>> transformer(image)
+    <NDArray 3x50x50x3 @cpu(0)>
+    """
+    def __init__(self, x, y, width, height, size=None, interpolation=None):
+        super(CropResize, self).__init__()
+        self._x = x
+        self._y = y
+        self._width = width
+        self._height = height
+        self._size = size
+        self._interpolation = interpolation
+
+    def hybrid_forward(self, F, x):
+        out = F.image.crop(x, self._x, self._y, self._width, self._height)
+        if self._size:
+            out = F.image.resize(out, self._size, False, self._interpolation)
+        return out
+
 class CenterCrop(Block):
     """Crops the image `src` to the given `size` by trimming on all four
     sides and preserving the center of the image. Upsamples if `src` is
diff --git a/python/mxnet/image/image.py b/python/mxnet/image/image.py
index 1dd6656..d2631e8 100644
--- a/python/mxnet/image/image.py
+++ b/python/mxnet/image/image.py
@@ -428,7 +428,7 @@ def fixed_crop(src, x0, y0, w, h, size=None, interp=2):
     NDArray
         An `NDArray` containing the cropped image.
     """
-    out = nd.crop(src, begin=(y0, x0, 0), end=(y0 + h, x0 + w, 
int(src.shape[2])))
+    out = nd.slice(src, begin=(y0, x0, 0), end=(y0 + h, x0 + w, 
int(src.shape[2])))
     if size is not None and (w, h) != size:
         sizes = (h, w, size[1], size[0])
         out = imresize(out, *size, interp=_get_interp_method(interp, sizes))
diff --git a/src/operator/image/crop-inl.h b/src/operator/image/crop-inl.h
new file mode 100644
index 0000000..a1a4b23
--- /dev/null
+++ b/src/operator/image/crop-inl.h
@@ -0,0 +1,190 @@
+/*
+* Licensed to the Apache Software Foundation (ASF) under one
+* or more contributor license agreements.  See the NOTICE file
+* distributed with this work for additional information
+* regarding copyright ownership.  The ASF licenses this file
+* to you under the Apache License, Version 2.0 (the
+* "License"); you may not use this file except in compliance
+* with the License.  You may obtain a copy of the License at
+*
+*   http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing,
+* software distributed under the License is distributed on an
+* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+* KIND, either express or implied.  See the License for the
+* specific language governing permissions and limitations
+* under the License.
+*/
+
+/*!
+ *  Copyright (c) 2019 by Contributors
+ * \file crop-inl.h
+ * \brief the image crop operator implementation
+ */
+
+#ifndef MXNET_OPERATOR_IMAGE_CROP_INL_H_
+#define MXNET_OPERATOR_IMAGE_CROP_INL_H_
+
+
+#include <algorithm>
+#include <vector>
+
+#include "mxnet/base.h"
+#include "dmlc/optional.h"
+#include "image_utils.h"
+#include "../mxnet_op.h"
+#include "../operator_common.h"
+#include "../../common/static_array.h"
+#include "../tensor/matrix_op-inl.h"
+#include "resize-inl.h"
+
+namespace mxnet {
+namespace op {
+namespace image {
+
+struct CropParam : public dmlc::Parameter<CropParam> {
+  int x;
+  int y;
+  int width;
+  int height;
+  DMLC_DECLARE_PARAMETER(CropParam) {
+    DMLC_DECLARE_FIELD(x)
+    .describe("Left boundary of the cropping area.");
+    DMLC_DECLARE_FIELD(y)
+    .describe("Top boundary of the cropping area.");
+    DMLC_DECLARE_FIELD(width)
+    .describe("Width of the cropping area.");
+    DMLC_DECLARE_FIELD(height)
+    .describe("Height of the cropping area.");
+  }
+};
+
+inline bool CropShape(const nnvm::NodeAttrs& attrs,
+                             std::vector<TShape> *in_attrs,
+                             std::vector<TShape> *out_attrs) {
+  // input attrs should only be (h, w, c) or (n, h, w, c)
+  if (in_attrs->at(0).ndim() == 3U) {
+    CHECK((in_attrs->at(0)[2] == 1) || (in_attrs->at(0)[2] == 3))
+      << "Expect channel of the input image is 1 or 3, but got"
+      << in_attrs->at(0)[2];
+  } else if (in_attrs->at(0).ndim() == 4U) {
+    CHECK((in_attrs->at(0)[3] == 1) || (in_attrs->at(0)[3] == 3))
+      << "Expect channel of the input image is 1 or 3, but got"
+      << in_attrs->at(0)[3];
+  } else {
+    LOG(FATAL) << "Image Crop expects inputs of 3D (h, w, c) or 4D (n, h, w, 
c). But got "
+      << in_attrs->at(0).ndim();
+  }
+
+  const auto& ishape = (*in_attrs)[0];
+  const CropParam& param = nnvm::get<CropParam>(attrs.parsed);
+
+  CHECK((param.height > 0) && (param.width > 0))
+    << "Input height and width must be greater than 0";
+  CHECK(param.x + param.width <= ishape[ishape.ndim() - 2])
+    << " x + width should not be greater than input width";
+  CHECK(param.y + param.height <= ishape[ishape.ndim() - 3])
+    << " y + height should not be greater than input height";
+  if (ishape.ndim() == 3) {
+    SHAPE_ASSIGN_CHECK(*out_attrs, 0, TShape({param.height, param.width, 
ishape[C]}));
+  } else {
+    SHAPE_ASSIGN_CHECK(*out_attrs, 0, TShape({ishape[N], param.height, 
param.width, ishape[kC]}));
+  }
+  return true;
+}
+
+inline void CropImpl(int x,
+                      int y,
+                      int width,
+                      int height,
+                      const std::vector<TBlob> &inputs,
+                      const std::vector<TBlob> &outputs,
+                      const OpContext &ctx,
+                      const std::vector<OpReqType> &req) {
+  using namespace mshadow;
+  const TBlob& data = inputs[0];
+  const TBlob& out = outputs[0];
+  MXNET_NDIM_SWITCH(data.ndim(), ndim, {
+    Stream<cpu>* s = ctx.get_stream<cpu>();
+    common::StaticArray<index_t, ndim> begin = {0}, step = {1};
+    if (ndim == 3) {
+      begin[0] = y;
+      begin[1] = x;
+    } else {
+      begin[1] = y;
+      begin[2] = x;
+    }
+    MSHADOW_TYPE_SWITCH(out.type_flag_, DType, {
+      MXNET_ASSIGN_REQ_SWITCH(req[0], Req, {
+        size_t num_threads = out.shape_.FlatTo2D()[0];
+        mxnet_op::Kernel<slice_forward<ndim, Req, cpu>, cpu>::Launch(s, 
num_threads,
+          out.dptr<DType>(), data.dptr<DType>(),
+          data.shape_.get<ndim>(), out.shape_.get<ndim>(), begin, step);
+      })
+    })
+  })
+}
+
+inline void CropBackwardImpl(int x,
+                      int y,
+                      int width,
+                      int height,
+                      const std::vector<TBlob> &inputs,
+                      const std::vector<TBlob> &outputs,
+                      const OpContext &ctx,
+                      const std::vector<OpReqType> &req) {
+  using namespace mshadow;
+  if (req[0] == kNullOp) return;
+  const TBlob& output_grad = inputs[0];
+  const TBlob& input_grad = outputs[0];
+  Stream<cpu>* s = ctx.get_stream<cpu>();
+  if (req[0] == kWriteTo) {
+    Fill(s, input_grad, req[0], 0);
+  } else if (req[0] == kWriteInplace) {
+    LOG(FATAL) << "_backward_image_crop does not support kWriteInplace";
+  }
+  MXNET_NDIM_SWITCH(output_grad.ndim(), ndim, {
+    common::StaticArray<index_t, ndim> begin = {0}, step = {1};
+    if (ndim == 3) {
+      begin[0] = y;
+      begin[1] = x;
+    } else {
+      begin[1] = y;
+      begin[2] = x;
+    }
+    MSHADOW_TYPE_SWITCH(output_grad.type_flag_, DType, {
+      MXNET_ASSIGN_REQ_SWITCH(req[0], Req, {
+        size_t num_threads = output_grad.shape_.FlatTo2D()[0];
+        mxnet_op::Kernel<slice_assign<ndim, Req, cpu>, cpu>::Launch(s, 
num_threads,
+          input_grad.dptr<DType>(), output_grad.dptr<DType>(),
+          input_grad.shape_.get<ndim>(), output_grad.shape_.get<ndim>(), 
begin, step);
+      })
+    })
+  })
+}
+
+inline void CropOpForward(const nnvm::NodeAttrs &attrs,
+                   const OpContext &ctx,
+                   const std::vector<TBlob> &inputs,
+                   const std::vector<OpReqType> &req,
+                   const std::vector<TBlob> &outputs) {
+  CHECK_EQ(outputs.size(), 1U);
+  const CropParam& param = nnvm::get<CropParam>(attrs.parsed);
+  CropImpl(param.x, param.y, param.width, param.height, inputs, outputs, ctx, 
req);
+}
+
+inline void CropOpBackward(const nnvm::NodeAttrs &attrs,
+                   const OpContext &ctx,
+                   const std::vector<TBlob> &inputs,
+                   const std::vector<OpReqType> &req,
+                   const std::vector<TBlob> &outputs) {
+  CHECK_EQ(outputs.size(), 1U);
+  const CropParam& param = nnvm::get<CropParam>(attrs.parsed);
+  CropBackwardImpl(param.x, param.y, param.width, param.height, inputs, 
outputs, ctx, req);
+}
+}  // namespace image
+}  // namespace op
+}  // namespace mxnet
+
+#endif  // MXNET_OPERATOR_IMAGE_CROP_INL_H_
diff --git a/src/operator/image/crop.cc b/src/operator/image/crop.cc
new file mode 100644
index 0000000..52d2f11
--- /dev/null
+++ b/src/operator/image/crop.cc
@@ -0,0 +1,85 @@
+/*
+* Licensed to the Apache Software Foundation (ASF) under one
+* or more contributor license agreements.  See the NOTICE file
+* distributed with this work for additional information
+* regarding copyright ownership.  The ASF licenses this file
+* to you under the Apache License, Version 2.0 (the
+* "License"); you may not use this file except in compliance
+* with the License.  You may obtain a copy of the License at
+*
+*   http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing,
+* software distributed under the License is distributed on an
+* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+* KIND, either express or implied.  See the License for the
+* specific language governing permissions and limitations
+* under the License.
+*/
+
+/*!
+ *  Copyright (c) 2019 by Contributors
+ * \file crop.cc
+ * \brief the image crop operator registration
+ */
+
+#include "mxnet/base.h"
+#include "crop-inl.h"
+#include "../operator_common.h"
+#include "../elemwise_op_common.h"
+
+namespace mxnet {
+namespace op {
+namespace image {
+
+DMLC_REGISTER_PARAMETER(CropParam);
+
+NNVM_REGISTER_OP(_image_crop)
+.describe(R"code(Crop an image NDArray of shape (H x W x C) or (N x H x W x C) 
+to the given size.
+Example:
+    .. code-block:: python
+        image = mx.nd.random.uniform(0, 255, (4, 2, 3)).astype(dtype=np.uint8)
+        mx.nd.image.crop(image, 1, 1, 2, 2)
+            [[[144  34   4]
+              [ 82 157  38]]
+
+             [[156 111 230]
+              [177  25  15]]]
+            <NDArray 2x2x3 @cpu(0)>
+        image = mx.nd.random.uniform(0, 255, (2, 4, 2, 
3)).astype(dtype=np.uint8)
+        mx.nd.image.crop(image, 1, 1, 2, 2)            
+            [[[[ 35 198  50]
+               [242  94 168]]
+
+              [[223 119 129]
+               [249  14 154]]]
+
+
+              [[[137 215 106]
+                [ 79 174 133]]
+
+               [[116 142 109]
+                [ 35 239  50]]]]
+            <NDArray 2x2x2x3 @cpu(0)>
+)code" ADD_FILELINE)
+.set_num_inputs(1)
+.set_num_outputs(1)
+.set_attr_parser(ParamParser<CropParam>)
+.set_attr<mxnet::FInferShape>("FInferShape", CropShape)
+.set_attr<nnvm::FInferType>("FInferType", ElemwiseType<1, 1>)
+.set_attr<FCompute>("FCompute<cpu>", CropOpForward)
+.set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseNone{ 
"_backward_image_crop" })
+.add_argument("data", "NDArray-or-Symbol", "The input.")
+.add_arguments(CropParam::__FIELDS__());
+
+NNVM_REGISTER_OP(_backward_image_crop)
+.set_attr_parser(ParamParser<CropParam>)
+.set_num_inputs(1)
+.set_num_outputs(1)
+.set_attr<nnvm::TIsBackward>("TIsBackward", true)
+.set_attr<FCompute>("FCompute<cpu>", CropOpBackward);
+
+}  // namespace image
+}  // namespace op
+}  // namespace mxnet
diff --git a/tests/python/unittest/test_gluon_data_vision.py 
b/tests/python/unittest/test_gluon_data_vision.py
index a855fc8..cc15bec 100644
--- a/tests/python/unittest/test_gluon_data_vision.py
+++ b/tests/python/unittest/test_gluon_data_vision.py
@@ -15,14 +15,16 @@
 # specific language governing permissions and limitations
 # under the License.
 from __future__ import print_function
+from collections import namedtuple
+
 import mxnet as mx
 import mxnet.ndarray as nd
 from mxnet.base import MXNetError
 from mxnet import gluon
 from mxnet.base import MXNetError
 from mxnet.gluon.data.vision import transforms
-from mxnet.test_utils import assert_almost_equal
-from mxnet.test_utils import almost_equal
+from mxnet import image
+from mxnet.test_utils import *
 from common import assertRaises, setup_module, with_seed, teardown
 
 import numpy as np
@@ -119,6 +121,75 @@ def test_resize():
 
 
 @with_seed()
+def test_crop_resize():
+    def _test_crop_resize_with_diff_type(dtype):
+        # test normal case
+        data_in = nd.arange(60).reshape((5, 4, 3)).astype(dtype)
+        out_nd = transforms.CropResize(0, 0, 3, 2)(data_in)
+        out_np = out_nd.asnumpy()
+        assert(out_np.sum() == 180)
+        assert((out_np[0:2,1,1].flatten() == [4, 16]).all())
+        # test 4D input
+        data_bath_in = nd.arange(180).reshape((2, 6, 5, 3)).astype(dtype)
+        out_batch_nd = transforms.CropResize(1, 2, 3, 4)(data_bath_in)
+        out_batch_np = out_batch_nd.asnumpy()
+        assert(out_batch_np.sum() == 7524)
+        assert((out_batch_np[0:2,0:4,1,1].flatten() == [37,  52,  67,  82, 
127, 142, 157, 172]).all())
+        # test normal case with resize
+        data_in = nd.random.uniform(0, 255, (300, 200, 3)).astype(dtype)
+        out_nd = transforms.CropResize(0, 0, 100, 50, (25, 25), 2)(data_in)
+        data_expected = image.imresize(nd.slice(data_in, (0, 0, 0), (50, 100 , 
3)), 25, 25, 2)
+        assert_almost_equal(out_nd.asnumpy(), data_expected.asnumpy())
+        # test 4D input with resize
+        data_bath_in = nd.random.uniform(0, 255, (3, 300, 200, 
3)).astype(dtype)
+        out_batch_nd = transforms.CropResize(0, 0, 100, 50, (25, 25), 
2)(data_bath_in)
+        for i in range(len(out_batch_nd)):
+            assert_almost_equal(image.imresize(nd.slice(data_bath_in[i], (0, 
0, 0), (50, 100, 3)), 25, 25, 2).asnumpy(),
+                out_batch_nd[i].asnumpy())
+        # test with resize height and width should be greater than 0
+        transformer = transforms.CropResize(0, 0, 100, 50, (-25, 25), 2)
+        assertRaises(MXNetError, transformer, data_in)
+        # test height and width should be greater than 0 
+        transformer = transforms.CropResize(0, 0, -100, -50)
+        assertRaises(MXNetError, transformer, data_in)
+        # test cropped area is bigger than input data
+        transformer = transforms.CropResize(150, 200, 200, 500)
+        assertRaises(MXNetError, transformer, data_in)
+        assertRaises(MXNetError, transformer, data_bath_in)
+
+    for dtype in ['uint8', 'float32', 'float64']:
+        _test_crop_resize_with_diff_type(dtype)  
+
+    # test nd.image.crop backward
+    def test_crop_backward(test_nd_arr, TestCase):
+        a_np = test_nd_arr.asnumpy()
+        b_np = a_np[(slice(TestCase.y, TestCase.y + TestCase.height), 
slice(TestCase.x, TestCase.x + TestCase.width), slice(0, 3))]
+
+        data = mx.sym.Variable('data')
+        crop_sym = mx.sym.image.crop(data, TestCase.x, TestCase.y, 
TestCase.width, TestCase.height)
+
+        expected_in_grad = np.zeros_like(a_np)
+        expected_in_grad[(slice(TestCase.y, TestCase.y + TestCase.height), 
slice(TestCase.x, TestCase.x + TestCase.width), slice(0, 3))] = b_np
+        check_symbolic_backward(crop_sym, [a_np], [b_np], [expected_in_grad])
+
+    TestCase = namedtuple('TestCase', ['x', 'y', 'width', 'height'])
+    test_list = [TestCase(0, 0, 3, 3), TestCase(2, 1, 1, 2), TestCase(0, 1, 3, 
2)]
+
+    for dtype in ['uint8', 'float32', 'float64']:
+        data_in = nd.arange(60).reshape((5, 4, 3)).astype(dtype)
+        for test_case in test_list:
+            test_crop_backward(data_in, test_case)
+        
+
+
+    # check numeric gradient of nd.image.crop
+    # in_data = np.arange(36).reshape(3, 4, 3)
+    # data = mx.sym.Variable('data')
+    # image_crop_sym = mx.sym.image.crop(data, 0, 0, 2, 2)
+    # check_numeric_gradient(image_crop_sym, [in_data])
+
+
+@with_seed()
 def test_flip_left_right():
     data_in = np.random.uniform(0, 255, (300, 300, 3)).astype(dtype=np.uint8)
     flip_in = data_in[:, ::-1, :]

Reply via email to