Skip to content

Commit 044f81f

Browse files
committed
implement round ste
1 parent ada4ea1 commit 044f81f

File tree

6 files changed

+60
-0
lines changed

6 files changed

+60
-0
lines changed

python/mxnet/ndarray/ndarray.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1406,6 +1406,14 @@ def round(self, *args, **kwargs):
14061406
"""
14071407
return op.round(self, *args, **kwargs)
14081408

1409+
def round_ste(self, *args, **kwargs):
1410+
"""Convenience fluent method for :py:func:`round_ste`.
1411+
1412+
The arguments are the same as for :py:func:`round_ste`, with
1413+
this array as data.
1414+
"""
1415+
return op.round_ste(self, *args, **kwargs)
1416+
14091417
def rint(self, *args, **kwargs):
14101418
"""Convenience fluent method for :py:func:`rint`.
14111419

python/mxnet/symbol/symbol.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2150,6 +2150,14 @@ def round(self, *args, **kwargs):
21502150
"""
21512151
return op.round(self, *args, **kwargs)
21522152

2153+
def round_ste(self, *args, **kwargs):
2154+
"""Convenience fluent method for :py:func:`round_ste`.
2155+
2156+
The arguments are the same as for :py:func:`round_ste`, with
2157+
this array as data.
2158+
"""
2159+
return op.round_ste(self, *args, **kwargs)
2160+
21532161
def rint(self, *args, **kwargs):
21542162
"""Convenience fluent method for :py:func:`rint`.
21552163

smd_hpi/tests/test_functions.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,3 +15,18 @@ def test_det_sign():
1515
assert_almost_equal(exp_y, y.asnumpy())
1616
y.backward()
1717
assert_almost_equal(exp_grad, x.grad.asnumpy())
18+
19+
20+
def test_round_ste():
21+
npy = np.random.uniform(-10, 10, (2, 3, 4))
22+
23+
exp_y = np.round(npy)
24+
exp_grad = np.ones_like(npy)
25+
26+
x = mx.nd.array(npy)
27+
x.attach_grad()
28+
with autograd.record():
29+
y = x.round_ste()
30+
assert_almost_equal(exp_y, y.asnumpy())
31+
y.backward()
32+
assert_almost_equal(exp_grad, x.grad.asnumpy())

src/operator/operator_tune.cc

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -273,6 +273,7 @@ IMPLEMENT_UNARY_WORKLOAD_BWD(mxnet::op::mshadow_op::sign_grad); // NOLINT()
273273
IMPLEMENT_UNARY_WORKLOAD_FWD(mxnet::op::mshadow_op::det_sign); // NOLINT()
274274
IMPLEMENT_UNARY_WORKLOAD_BWD(mxnet::op::mshadow_op::det_sign); // NOLINT()
275275
IMPLEMENT_UNARY_WORKLOAD_FWD(mxnet::op::mshadow_op::round); // NOLINT()
276+
IMPLEMENT_UNARY_WORKLOAD_BWD(mxnet::op::mshadow_op::round); // NOLINT()
276277
IMPLEMENT_UNARY_WORKLOAD_FWD(mxnet::op::mshadow_op::floor); // NOLINT()
277278
IMPLEMENT_UNARY_WORKLOAD_FWD(mxnet::op::mshadow_op::trunc); // NOLINT()
278279
IMPLEMENT_UNARY_WORKLOAD_FWD(mxnet::op::mshadow_op::rint); // NOLINT()

src/operator/tensor/elemwise_unary_op_basic.cc

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -723,6 +723,25 @@ The storage type of ``round`` output depends upon the input storage type:
723723
)code" ADD_FILELINE)
724724
.set_attr<nnvm::FGradient>("FGradient", MakeZeroGradNodes);
725725

726+
// round_ste
727+
MXNET_OPERATOR_REGISTER_UNARY_WITH_RSP(round_ste, cpu, mshadow_op::round)
728+
MXNET_ADD_SPARSE_OP_ALIAS(round_ste)
729+
.describe(R"code(Returns element-wise rounded value to the nearest integer of the input but with STE.
730+
731+
Example::
732+
733+
round_ste([-1.5, 1.5, -1.9, 1.9, 2.1]) = [-2., 2., -2., 2., 2.]
734+
735+
The storage type of ``round_ste`` output depends upon the input storage type:
736+
737+
- round_ste(default) = default
738+
- round_ste(row_sparse) = row_sparse
739+
740+
)code" ADD_FILELINE)
741+
.set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseIn{"_backward_round_ste"});
742+
743+
MXNET_OPERATOR_REGISTER_BINARY_WITH_SPARSE_CPU(_backward_round_ste, unary_bwd<mshadow_op::identity_grad>);
744+
726745
// rint
727746
MXNET_OPERATOR_REGISTER_UNARY_WITH_RSP_CSR(rint, cpu, mshadow_op::rint)
728747
.describe(R"code(Returns element-wise rounded value to the nearest integer of the input.

src/operator/tensor/elemwise_unary_op_basic.cu

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -173,6 +173,15 @@ NNVM_REGISTER_OP(round)
173173
.set_attr<FCompute>("FCompute<gpu>", UnaryOp::Compute<gpu, mshadow_op::round>)
174174
.set_attr<FComputeEx>("FComputeEx<gpu>", UnaryOp::ComputeEx<gpu, mshadow_op::round>);
175175

176+
// round_ste
177+
NNVM_REGISTER_OP(round_ste)
178+
.set_attr<FCompute>("FCompute<gpu>", UnaryOp::Compute<gpu, mshadow_op::round>)
179+
.set_attr<FComputeEx>("FComputeEx<gpu>", UnaryOp::ComputeEx<gpu, mshadow_op::round>);
180+
181+
NNVM_REGISTER_OP(_backward_round_ste)
182+
.set_attr<FCompute>("FCompute<gpu>", ElemwiseBinaryOp::Compute<
183+
gpu, unary_bwd<mshadow_op::identity_grad> >);
184+
176185
// ceil
177186
NNVM_REGISTER_OP(ceil)
178187
.set_attr<FCompute>("FCompute<gpu>", UnaryOp::Compute<gpu, mshadow_op::ceil>)

0 commit comments

Comments
 (0)