https://github.com/Lewuathe updated https://github.com/llvm/llvm-project/pull/76316
>From a5810363e546da073543cb2d62cceb956c46b2e6 Mon Sep 17 00:00:00 2001 From: Kai Sasaki <lewua...@gmail.com> Date: Fri, 15 Dec 2023 15:53:54 +0900 Subject: [PATCH 1/2] [mlir][complex] Prevent underflow in complex.abs --- .../ComplexToStandard/ComplexToStandard.cpp | 56 ++++++--- .../convert-to-standard.mlir | 115 ++++++++++++++---- .../ComplexToStandard/full-conversion.mlir | 25 +++- .../Dialect/Complex/CPU/correctness.mlir | 38 ++++++ 4 files changed, 194 insertions(+), 40 deletions(-) diff --git a/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp b/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp index bf753c7062f3664..7c1db57b55f996b 100644 --- a/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp +++ b/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp @@ -26,29 +26,57 @@ namespace mlir { using namespace mlir; namespace { +// The algorithm is listed in https://dl.acm.org/doi/pdf/10.1145/363717.363780. struct AbsOpConversion : public OpConversionPattern<complex::AbsOp> { using OpConversionPattern<complex::AbsOp>::OpConversionPattern; LogicalResult matchAndRewrite(complex::AbsOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - auto type = op.getType(); + mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter); arith::FastMathFlagsAttr fmf = op.getFastMathFlagsAttr(); - Value real = - rewriter.create<complex::ReOp>(loc, type, adaptor.getComplex()); - Value imag = - rewriter.create<complex::ImOp>(loc, type, adaptor.getComplex()); - Value realSqr = - rewriter.create<arith::MulFOp>(loc, real, real, fmf.getValue()); - Value imagSqr = - rewriter.create<arith::MulFOp>(loc, imag, imag, fmf.getValue()); - Value sqNorm = - rewriter.create<arith::AddFOp>(loc, realSqr, imagSqr, fmf.getValue()); - - rewriter.replaceOpWithNewOp<math::SqrtOp>(op, sqNorm); + Type elementType = op.getType(); + Value arg = adaptor.getComplex(); + + Value zero = + b.create<arith::ConstantOp>(elementType, b.getZeroAttr(elementType)); + Value one = b.create<arith::ConstantOp>(elementType, + b.getFloatAttr(elementType, 1.0)); + + Value real = b.create<complex::ReOp>(elementType, arg); + Value imag = b.create<complex::ImOp>(elementType, arg); + + Value realIsZero = + b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, real, zero); + Value imagIsZero = + b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, imag, zero); + + // Real > Imag + Value imagDivReal = b.create<arith::DivFOp>(imag, real, fmf.getValue()); + Value imagSq = + b.create<arith::MulFOp>(imagDivReal, imagDivReal, fmf.getValue()); + Value imagSqPlusOne = b.create<arith::AddFOp>(imagSq, one, fmf.getValue()); + Value imagSqrt = b.create<math::SqrtOp>(imagSqPlusOne, fmf.getValue()); + Value absImag = b.create<arith::MulFOp>(imagSqrt, real, fmf.getValue()); + + // Real <= Imag + Value realDivImag = b.create<arith::DivFOp>(real, imag, fmf.getValue()); + Value realSq = + b.create<arith::MulFOp>(realDivImag, realDivImag, fmf.getValue()); + Value realSqPlusOne = b.create<arith::AddFOp>(realSq, one, fmf.getValue()); + Value realSqrt = b.create<math::SqrtOp>(realSqPlusOne, fmf.getValue()); + Value absReal = b.create<arith::MulFOp>(realSqrt, imag, fmf.getValue()); + + rewriter.replaceOpWithNewOp<arith::SelectOp>( + op, realIsZero, imag, + b.create<arith::SelectOp>( + imagIsZero, real, + b.create<arith::SelectOp>( + b.create<arith::CmpFOp>(arith::CmpFPredicate::OGT, real, imag), + absImag, absReal))); + return success(); } }; diff --git a/mlir/test/Conversion/ComplexToStandard/convert-to-standard.mlir b/mlir/test/Conversion/ComplexToStandard/convert-to-standard.mlir index 3af28150fd5c3f3..1028c9aae92c056 100644 --- a/mlir/test/Conversion/ComplexToStandard/convert-to-standard.mlir +++ b/mlir/test/Conversion/ComplexToStandard/convert-to-standard.mlir @@ -7,13 +7,28 @@ func.func @complex_abs(%arg: complex<f32>) -> f32 { %abs = complex.abs %arg: complex<f32> return %abs : f32 } + +// CHECK: %[[ZERO:.*]] = arith.constant 0.000000e+00 : f32 +// CHECK: %[[ONE:.*]] = arith.constant 1.000000e+00 : f32 // CHECK: %[[REAL:.*]] = complex.re %[[ARG]] : complex<f32> // CHECK: %[[IMAG:.*]] = complex.im %[[ARG]] : complex<f32> -// CHECK-DAG: %[[REAL_SQ:.*]] = arith.mulf %[[REAL]], %[[REAL]] : f32 -// CHECK-DAG: %[[IMAG_SQ:.*]] = arith.mulf %[[IMAG]], %[[IMAG]] : f32 -// CHECK: %[[SQ_NORM:.*]] = arith.addf %[[REAL_SQ]], %[[IMAG_SQ]] : f32 -// CHECK: %[[NORM:.*]] = math.sqrt %[[SQ_NORM]] : f32 -// CHECK: return %[[NORM]] : f32 +// CHECK: %[[IS_REAL_ZERO:.*]] = arith.cmpf oeq, %[[REAL]], %[[ZERO]] : f32 +// CHECK: %[[IS_IMAG_ZERO:.*]] = arith.cmpf oeq, %[[IMAG]], %[[ZERO]] : f32 +// CHECK: %[[IMAG_DIV_REAL:.*]] = arith.divf %[[IMAG]], %[[REAL]] : f32 +// CHECK: %[[IMAG_SQ:.*]] = arith.mulf %[[IMAG_DIV_REAL]], %[[IMAG_DIV_REAL]] : f32 +// CHECK: %[[IMAG_SQ_PLUS_ONE:.*]] = arith.addf %[[IMAG_SQ]], %[[ONE]] : f32 +// CHECK: %[[IMAG_SQRT:.*]] = math.sqrt %[[IMAG_SQ_PLUS_ONE]] : f32 +// CHECK: %[[ABS_IMAG:.*]] = arith.mulf %[[IMAG_SQRT]], %[[REAL]] : f32 +// CHECK: %[[REAL_DIV_IMAG:.*]] = arith.divf %[[REAL]], %[[IMAG]] : f32 +// CHECK: %[[REAL_SQ:.*]] = arith.mulf %[[REAL_DIV_IMAG]], %[[REAL_DIV_IMAG]] : f32 +// CHECK: %[[REAL_SQ_PLUS_ONE:.*]] = arith.addf %[[REAL_SQ]], %[[ONE]] : f32 +// CHECK: %[[REAL_SQRT:.*]] = math.sqrt %[[REAL_SQ_PLUS_ONE]] : f32 +// CHECK: %[[ABS_REAL:.*]] = arith.mulf %[[REAL_SQRT]], %[[IMAG]] : f32 +// CHECK: %[[REAL_GT_IMAG:.*]] = arith.cmpf ogt, %[[REAL]], %[[IMAG]] : f32 +// CHECK: %[[ABS1:.*]] = arith.select %[[REAL_GT_IMAG]], %[[ABS_IMAG]], %[[ABS_REAL]] : f32 +// CHECK: %[[ABS2:.*]] = arith.select %[[IS_IMAG_ZERO]], %[[REAL]], %[[ABS1]] : f32 +// CHECK: %[[ABS3:.*]] = arith.select %[[IS_REAL_ZERO]], %[[IMAG]], %[[ABS2]] : f32 +// CHECK: return %[[ABS3]] : f32 // ----- @@ -241,12 +256,26 @@ func.func @complex_log(%arg: complex<f32>) -> complex<f32> { %log = complex.log %arg: complex<f32> return %log : complex<f32> } +// CHECK: %[[ZERO:.*]] = arith.constant 0.000000e+00 : f32 +// CHECK: %[[ONE:.*]] = arith.constant 1.000000e+00 : f32 // CHECK: %[[REAL:.*]] = complex.re %[[ARG]] : complex<f32> // CHECK: %[[IMAG:.*]] = complex.im %[[ARG]] : complex<f32> -// CHECK: %[[SQR_REAL:.*]] = arith.mulf %[[REAL]], %[[REAL]] : f32 -// CHECK: %[[SQR_IMAG:.*]] = arith.mulf %[[IMAG]], %[[IMAG]] : f32 -// CHECK: %[[SQ_NORM:.*]] = arith.addf %[[SQR_REAL]], %[[SQR_IMAG]] : f32 -// CHECK: %[[NORM:.*]] = math.sqrt %[[SQ_NORM]] : f32 +// CHECK: %[[IS_REAL_ZERO:.*]] = arith.cmpf oeq, %[[REAL]], %[[ZERO]] : f32 +// CHECK: %[[IS_IMAG_ZERO:.*]] = arith.cmpf oeq, %[[IMAG]], %[[ZERO]] : f32 +// CHECK: %[[IMAG_DIV_REAL:.*]] = arith.divf %[[IMAG]], %[[REAL]] : f32 +// CHECK: %[[IMAG_SQ:.*]] = arith.mulf %[[IMAG_DIV_REAL]], %[[IMAG_DIV_REAL]] : f32 +// CHECK: %[[IMAG_SQ_PLUS_ONE:.*]] = arith.addf %[[IMAG_SQ]], %[[ONE]] : f32 +// CHECK: %[[IMAG_SQRT:.*]] = math.sqrt %[[IMAG_SQ_PLUS_ONE]] : f32 +// CHECK: %[[ABS_IMAG:.*]] = arith.mulf %[[IMAG_SQRT]], %[[REAL]] : f32 +// CHECK: %[[REAL_DIV_IMAG:.*]] = arith.divf %[[REAL]], %[[IMAG]] : f32 +// CHECK: %[[REAL_SQ:.*]] = arith.mulf %[[REAL_DIV_IMAG]], %[[REAL_DIV_IMAG]] : f32 +// CHECK: %[[REAL_SQ_PLUS_ONE:.*]] = arith.addf %[[REAL_SQ]], %[[ONE]] : f32 +// CHECK: %[[REAL_SQRT:.*]] = math.sqrt %[[REAL_SQ_PLUS_ONE]] : f32 +// CHECK: %[[ABS_REAL:.*]] = arith.mulf %[[REAL_SQRT]], %[[IMAG]] : f32 +// CHECK: %[[REAL_GT_IMAG:.*]] = arith.cmpf ogt, %[[REAL]], %[[IMAG]] : f32 +// CHECK: %[[ABS1:.*]] = arith.select %[[REAL_GT_IMAG]], %[[ABS_IMAG]], %[[ABS_REAL]] : f32 +// CHECK: %[[ABS2:.*]] = arith.select %[[IS_IMAG_ZERO]], %[[REAL]], %[[ABS1]] : f32 +// CHECK: %[[NORM:.*]] = arith.select %[[IS_REAL_ZERO]], %[[IMAG]], %[[ABS2]] : f32 // CHECK: %[[RESULT_REAL:.*]] = math.log %[[NORM]] : f32 // CHECK: %[[REAL2:.*]] = complex.re %[[ARG]] : complex<f32> // CHECK: %[[IMAG2:.*]] = complex.im %[[ARG]] : complex<f32> @@ -469,12 +498,26 @@ func.func @complex_sign(%arg: complex<f32>) -> complex<f32> { // CHECK: %[[REAL_IS_ZERO:.*]] = arith.cmpf oeq, %[[REAL]], %[[ZERO]] : f32 // CHECK: %[[IMAG_IS_ZERO:.*]] = arith.cmpf oeq, %[[IMAG]], %[[ZERO]] : f32 // CHECK: %[[IS_ZERO:.*]] = arith.andi %[[REAL_IS_ZERO]], %[[IMAG_IS_ZERO]] : i1 +// CHECK: %[[ZERO:.*]] = arith.constant 0.000000e+00 : f32 +// CHECK: %[[ONE:.*]] = arith.constant 1.000000e+00 : f32 // CHECK: %[[REAL2:.*]] = complex.re %[[ARG]] : complex<f32> // CHECK: %[[IMAG2:.*]] = complex.im %[[ARG]] : complex<f32> -// CHECK: %[[SQR_REAL:.*]] = arith.mulf %[[REAL2]], %[[REAL2]] : f32 -// CHECK: %[[SQR_IMAG:.*]] = arith.mulf %[[IMAG2]], %[[IMAG2]] : f32 -// CHECK: %[[SQ_NORM:.*]] = arith.addf %[[SQR_REAL]], %[[SQR_IMAG]] : f32 -// CHECK: %[[NORM:.*]] = math.sqrt %[[SQ_NORM]] : f32 +// CHECK: %[[IS_REAL_ZERO:.*]] = arith.cmpf oeq, %[[REAL2]], %[[ZERO]] : f32 +// CHECK: %[[IS_IMAG_ZERO:.*]] = arith.cmpf oeq, %[[IMAG2]], %[[ZERO]] : f32 +// CHECK: %[[IMAG_DIV_REAL:.*]] = arith.divf %[[IMAG2]], %[[REAL2]] : f32 +// CHECK: %[[IMAG_SQ:.*]] = arith.mulf %[[IMAG_DIV_REAL]], %[[IMAG_DIV_REAL]] : f32 +// CHECK: %[[IMAG_SQ_PLUS_ONE:.*]] = arith.addf %[[IMAG_SQ]], %[[ONE]] : f32 +// CHECK: %[[IMAG_SQRT:.*]] = math.sqrt %[[IMAG_SQ_PLUS_ONE]] : f32 +// CHECK: %[[ABS_IMAG:.*]] = arith.mulf %[[IMAG_SQRT]], %[[REAL2]] : f32 +// CHECK: %[[REAL_DIV_IMAG:.*]] = arith.divf %[[REAL2]], %[[IMAG2]] : f32 +// CHECK: %[[REAL_SQ:.*]] = arith.mulf %[[REAL_DIV_IMAG]], %[[REAL_DIV_IMAG]] : f32 +// CHECK: %[[REAL_SQ_PLUS_ONE:.*]] = arith.addf %[[REAL_SQ]], %[[ONE]] : f32 +// CHECK: %[[REAL_SQRT:.*]] = math.sqrt %[[REAL_SQ_PLUS_ONE]] : f32 +// CHECK: %[[ABS_REAL:.*]] = arith.mulf %[[REAL_SQRT]], %[[IMAG2]] : f32 +// CHECK: %[[REAL_GT_IMAG:.*]] = arith.cmpf ogt, %[[REAL2]], %[[IMAG2]] : f32 +// CHECK: %[[ABS1:.*]] = arith.select %[[REAL_GT_IMAG]], %[[ABS_IMAG]], %[[ABS_REAL]] : f32 +// CHECK: %[[ABS2:.*]] = arith.select %[[IS_IMAG_ZERO]], %[[REAL2]], %[[ABS1]] : f32 +// CHECK: %[[NORM:.*]] = arith.select %[[IS_REAL_ZERO]], %[[IMAG2]], %[[ABS2]] : f32 // CHECK: %[[REAL_SIGN:.*]] = arith.divf %[[REAL]], %[[NORM]] : f32 // CHECK: %[[IMAG_SIGN:.*]] = arith.divf %[[IMAG]], %[[NORM]] : f32 // CHECK: %[[SIGN:.*]] = complex.create %[[REAL_SIGN]], %[[IMAG_SIGN]] : complex<f32> @@ -716,13 +759,27 @@ func.func @complex_abs_with_fmf(%arg: complex<f32>) -> f32 { %abs = complex.abs %arg fastmath<nnan,contract> : complex<f32> return %abs : f32 } +// CHECK: %[[ZERO:.*]] = arith.constant 0.000000e+00 : f32 +// CHECK: %[[ONE:.*]] = arith.constant 1.000000e+00 : f32 // CHECK: %[[REAL:.*]] = complex.re %[[ARG]] : complex<f32> // CHECK: %[[IMAG:.*]] = complex.im %[[ARG]] : complex<f32> -// CHECK-DAG: %[[REAL_SQ:.*]] = arith.mulf %[[REAL]], %[[REAL]] fastmath<nnan,contract> : f32 -// CHECK-DAG: %[[IMAG_SQ:.*]] = arith.mulf %[[IMAG]], %[[IMAG]] fastmath<nnan,contract> : f32 -// CHECK: %[[SQ_NORM:.*]] = arith.addf %[[REAL_SQ]], %[[IMAG_SQ]] fastmath<nnan,contract> : f32 -// CHECK: %[[NORM:.*]] = math.sqrt %[[SQ_NORM]] : f32 -// CHECK: return %[[NORM]] : f32 +// CHECK: %[[IS_REAL_ZERO:.*]] = arith.cmpf oeq, %[[REAL]], %[[ZERO]] : f32 +// CHECK: %[[IS_IMAG_ZERO:.*]] = arith.cmpf oeq, %[[IMAG]], %[[ZERO]] : f32 +// CHECK: %[[IMAG_DIV_REAL:.*]] = arith.divf %[[IMAG]], %[[REAL]] fastmath<nnan,contract> : f32 +// CHECK: %[[IMAG_SQ:.*]] = arith.mulf %[[IMAG_DIV_REAL]], %[[IMAG_DIV_REAL]] fastmath<nnan,contract> : f32 +// CHECK: %[[IMAG_SQ_PLUS_ONE:.*]] = arith.addf %[[IMAG_SQ]], %[[ONE]] fastmath<nnan,contract> : f32 +// CHECK: %[[IMAG_SQRT:.*]] = math.sqrt %[[IMAG_SQ_PLUS_ONE]] fastmath<nnan,contract> : f32 +// CHECK: %[[ABS_IMAG:.*]] = arith.mulf %[[IMAG_SQRT]], %[[REAL]] fastmath<nnan,contract> : f32 +// CHECK: %[[REAL_DIV_IMAG:.*]] = arith.divf %[[REAL]], %[[IMAG]] fastmath<nnan,contract> : f32 +// CHECK: %[[REAL_SQ:.*]] = arith.mulf %[[REAL_DIV_IMAG]], %[[REAL_DIV_IMAG]] fastmath<nnan,contract> : f32 +// CHECK: %[[REAL_SQ_PLUS_ONE:.*]] = arith.addf %[[REAL_SQ]], %[[ONE]] fastmath<nnan,contract> : f32 +// CHECK: %[[REAL_SQRT:.*]] = math.sqrt %[[REAL_SQ_PLUS_ONE]] fastmath<nnan,contract> : f32 +// CHECK: %[[ABS_REAL:.*]] = arith.mulf %[[REAL_SQRT]], %[[IMAG]] fastmath<nnan,contract> : f32 +// CHECK: %[[REAL_GT_IMAG:.*]] = arith.cmpf ogt, %[[REAL]], %[[IMAG]] : f32 +// CHECK: %[[ABS1:.*]] = arith.select %[[REAL_GT_IMAG]], %[[ABS_IMAG]], %[[ABS_REAL]] : f32 +// CHECK: %[[ABS2:.*]] = arith.select %[[IS_IMAG_ZERO]], %[[REAL]], %[[ABS1]] : f32 +// CHECK: %[[ABS3:.*]] = arith.select %[[IS_REAL_ZERO]], %[[IMAG]], %[[ABS2]] : f32 +// CHECK: return %[[ABS3]] : f32 // ----- @@ -807,12 +864,26 @@ func.func @complex_log_with_fmf(%arg: complex<f32>) -> complex<f32> { %log = complex.log %arg fastmath<nnan,contract> : complex<f32> return %log : complex<f32> } +// CHECK: %[[ZERO:.*]] = arith.constant 0.000000e+00 : f32 +// CHECK: %[[ONE:.*]] = arith.constant 1.000000e+00 : f32 // CHECK: %[[REAL:.*]] = complex.re %[[ARG]] : complex<f32> // CHECK: %[[IMAG:.*]] = complex.im %[[ARG]] : complex<f32> -// CHECK: %[[SQR_REAL:.*]] = arith.mulf %[[REAL]], %[[REAL]] fastmath<nnan,contract> : f32 -// CHECK: %[[SQR_IMAG:.*]] = arith.mulf %[[IMAG]], %[[IMAG]] fastmath<nnan,contract> : f32 -// CHECK: %[[SQ_NORM:.*]] = arith.addf %[[SQR_REAL]], %[[SQR_IMAG]] fastmath<nnan,contract> : f32 -// CHECK: %[[NORM:.*]] = math.sqrt %[[SQ_NORM]] : f32 +// CHECK: %[[IS_REAL_ZERO:.*]] = arith.cmpf oeq, %[[REAL]], %[[ZERO]] : f32 +// CHECK: %[[IS_IMAG_ZERO:.*]] = arith.cmpf oeq, %[[IMAG]], %[[ZERO]] : f32 +// CHECK: %[[IMAG_DIV_REAL:.*]] = arith.divf %[[IMAG]], %[[REAL]] fastmath<nnan,contract> : f32 +// CHECK: %[[IMAG_SQ:.*]] = arith.mulf %[[IMAG_DIV_REAL]], %[[IMAG_DIV_REAL]] fastmath<nnan,contract> : f32 +// CHECK: %[[IMAG_SQ_PLUS_ONE:.*]] = arith.addf %[[IMAG_SQ]], %[[ONE]] fastmath<nnan,contract> : f32 +// CHECK: %[[IMAG_SQRT:.*]] = math.sqrt %[[IMAG_SQ_PLUS_ONE]] fastmath<nnan,contract> : f32 +// CHECK: %[[ABS_IMAG:.*]] = arith.mulf %[[IMAG_SQRT]], %[[REAL]] fastmath<nnan,contract> : f32 +// CHECK: %[[REAL_DIV_IMAG:.*]] = arith.divf %[[REAL]], %[[IMAG]] fastmath<nnan,contract> : f32 +// CHECK: %[[REAL_SQ:.*]] = arith.mulf %[[REAL_DIV_IMAG]], %[[REAL_DIV_IMAG]] fastmath<nnan,contract> : f32 +// CHECK: %[[REAL_SQ_PLUS_ONE:.*]] = arith.addf %[[REAL_SQ]], %[[ONE]] fastmath<nnan,contract> : f32 +// CHECK: %[[REAL_SQRT:.*]] = math.sqrt %[[REAL_SQ_PLUS_ONE]] fastmath<nnan,contract> : f32 +// CHECK: %[[ABS_REAL:.*]] = arith.mulf %[[REAL_SQRT]], %[[IMAG]] fastmath<nnan,contract> : f32 +// CHECK: %[[REAL_GT_IMAG:.*]] = arith.cmpf ogt, %[[REAL]], %[[IMAG]] : f32 +// CHECK: %[[ABS1:.*]] = arith.select %[[REAL_GT_IMAG]], %[[ABS_IMAG]], %[[ABS_REAL]] : f32 +// CHECK: %[[ABS2:.*]] = arith.select %[[IS_IMAG_ZERO]], %[[REAL]], %[[ABS1]] : f32 +// CHECK: %[[NORM:.*]] = arith.select %[[IS_REAL_ZERO]], %[[IMAG]], %[[ABS2]] : f32 // CHECK: %[[RESULT_REAL:.*]] = math.log %[[NORM]] fastmath<nnan,contract> : f32 // CHECK: %[[REAL2:.*]] = complex.re %[[ARG]] : complex<f32> // CHECK: %[[IMAG2:.*]] = complex.im %[[ARG]] : complex<f32> diff --git a/mlir/test/Conversion/ComplexToStandard/full-conversion.mlir b/mlir/test/Conversion/ComplexToStandard/full-conversion.mlir index 9983dd46f094334..d710dc8e1adeb7c 100644 --- a/mlir/test/Conversion/ComplexToStandard/full-conversion.mlir +++ b/mlir/test/Conversion/ComplexToStandard/full-conversion.mlir @@ -6,12 +6,29 @@ func.func @complex_abs(%arg: complex<f32>) -> f32 { %abs = complex.abs %arg: complex<f32> return %abs : f32 } +// CHECK: %[[ZERO:.*]] = llvm.mlir.constant(0.000000e+00 : f32) : f32 +// CHECK: %[[ONE:.*]] = llvm.mlir.constant(1.000000e+00 : f32) : f32 // CHECK: %[[REAL:.*]] = llvm.extractvalue %[[ARG]][0] : ![[C_TY]] // CHECK: %[[IMAG:.*]] = llvm.extractvalue %[[ARG]][1] : ![[C_TY]] -// CHECK-DAG: %[[REAL_SQ:.*]] = llvm.fmul %[[REAL]], %[[REAL]] : f32 -// CHECK-DAG: %[[IMAG_SQ:.*]] = llvm.fmul %[[IMAG]], %[[IMAG]] : f32 -// CHECK: %[[SQ_NORM:.*]] = llvm.fadd %[[REAL_SQ]], %[[IMAG_SQ]] : f32 -// CHECK: %[[NORM:.*]] = llvm.intr.sqrt(%[[SQ_NORM]]) : (f32) -> f32 +// CHECK: %[[REAL_IS_ZERO:.*]] = llvm.fcmp "oeq" %[[REAL]], %[[ZERO]] : f32 +// CHECK: %[[IMAG_IS_ZERO:.*]] = llvm.fcmp "oeq" %[[IMAG]], %[[ZERO]] : f32 + +// CHECK: %[[IMAG_DIV_REAL:.*]] = llvm.fdiv %[[IMAG]], %[[REAL]] : f32 +// CHECK: %[[IMAG_SQ:.*]] = llvm.fmul %[[IMAG_DIV_REAL]], %[[IMAG_DIV_REAL]] : f32 +// CHECK: %[[IMAG_SQ_PLUS_ONE:.*]] = llvm.fadd %[[IMAG_SQ]], %[[ONE]] : f32 +// CHECK: %[[IMAG_SQRT:.*]] = llvm.intr.sqrt(%[[IMAG_SQ_PLUS_ONE]]) : (f32) -> f32 +// CHECK: %[[ABS_IMAG:.*]] = llvm.fmul %[[IMAG_SQRT]], %[[REAL]] : f32 + +// CHECK: %[[REAL_DIV_IMAG:.*]] = llvm.fdiv %[[REAL]], %[[IMAG]] : f32 +// CHECK: %[[REAL_SQ:.*]] = llvm.fmul %[[REAL_DIV_IMAG]], %[[REAL_DIV_IMAG]] : f32 +// CHECK: %[[REAL_SQ_PLUS_ONE:.*]] = llvm.fadd %[[REAL_SQ]], %[[ONE]] : f32 +// CHECK: %[[REAL_SQRT:.*]] = llvm.intr.sqrt(%[[REAL_SQ_PLUS_ONE]]) : (f32) -> f32 +// CHECK: %[[ABS_REAL:.*]] = llvm.fmul %[[REAL_SQRT]], %[[IMAG]] : f32 + +// CHECK: %[[REAL_GT_IMAG:.*]] = llvm.fcmp "ogt" %[[REAL]], %[[IMAG]] : f32 +// CHECK: %[[ABS1:.*]] = llvm.select %[[REAL_GT_IMAG]], %[[ABS_IMAG]], %[[ABS_REAL]] : i1, f32 +// CHECK: %[[ABS2:.*]] = llvm.select %[[IMAG_IS_ZERO]], %[[REAL]], %[[ABS1]] : i1, f32 +// CHECK: %[[NORM:.*]] = llvm.select %[[REAL_IS_ZERO]], %[[IMAG]], %[[ABS2]] : i1, f32 // CHECK: llvm.return %[[NORM]] : f32 // CHECK-LABEL: llvm.func @complex_eq diff --git a/mlir/test/Integration/Dialect/Complex/CPU/correctness.mlir b/mlir/test/Integration/Dialect/Complex/CPU/correctness.mlir index 349b92a7aefa2e3..b7849945b3cf498 100644 --- a/mlir/test/Integration/Dialect/Complex/CPU/correctness.mlir +++ b/mlir/test/Integration/Dialect/Complex/CPU/correctness.mlir @@ -106,6 +106,27 @@ func.func @angle(%arg: complex<f32>) -> f32 { func.return %angle : f32 } +func.func @test_element_f64(%input: tensor<?xcomplex<f64>>, + %func: (complex<f64>) -> f64) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %size = tensor.dim %input, %c0: tensor<?xcomplex<f64>> + + scf.for %i = %c0 to %size step %c1 { + %elem = tensor.extract %input[%i]: tensor<?xcomplex<f64>> + + %val = func.call_indirect %func(%elem) : (complex<f64>) -> f64 + vector.print %val : f64 + scf.yield + } + func.return +} + +func.func @abs(%arg: complex<f64>) -> f64 { + %abs = complex.abs %arg : complex<f64> + func.return %abs : f64 +} + func.func @entry() { // complex.sqrt test %sqrt_test = arith.constant dense<[ @@ -300,5 +321,22 @@ func.func @entry() { call @test_element(%angle_test_cast, %angle_func) : (tensor<?xcomplex<f32>>, (complex<f32>) -> f32) -> () + // complex.abs test + %abs_test = arith.constant dense<[ + (1.0, 1.0), + // CHECK: 1.414 + (1.0e300, 1.0e300), + // CHECK-NEXT: 1.41421e+300 + (1.0e-300, 1.0e-300) + // CHECK-NEXT: 1.41421e-300 + ]> : tensor<3xcomplex<f64>> + %abs_test_cast = tensor.cast %abs_test + : tensor<3xcomplex<f64>> to tensor<?xcomplex<f64>> + + %abs_func = func.constant @abs : (complex<f64>) -> f64 + + call @test_element_f64(%abs_test_cast, %abs_func) + : (tensor<?xcomplex<f64>>, (complex<f64>) -> f64) -> () + func.return } >From c4e8c8b2c67933896f2f42a86fc994811011b6ad Mon Sep 17 00:00:00 2001 From: Kai Sasaki <lewua...@gmail.com> Date: Wed, 24 Jan 2024 09:22:50 +0900 Subject: [PATCH 2/2] [mlir][complex] Add test case to check zero element handling --- .../Integration/Dialect/Complex/CPU/correctness.mlir | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/mlir/test/Integration/Dialect/Complex/CPU/correctness.mlir b/mlir/test/Integration/Dialect/Complex/CPU/correctness.mlir index b7849945b3cf498..c8327e94def8abf 100644 --- a/mlir/test/Integration/Dialect/Complex/CPU/correctness.mlir +++ b/mlir/test/Integration/Dialect/Complex/CPU/correctness.mlir @@ -327,11 +327,17 @@ func.func @entry() { // CHECK: 1.414 (1.0e300, 1.0e300), // CHECK-NEXT: 1.41421e+300 - (1.0e-300, 1.0e-300) + (1.0e-300, 1.0e-300), // CHECK-NEXT: 1.41421e-300 - ]> : tensor<3xcomplex<f64>> + (5.0, 0.0), + // CHECK-NEXT: 5 + (0.0, 6.0), + // CHECK-NEXT: 6 + (7.0, 8.0) + // CHECK-NEXT: 10.6301 + ]> : tensor<6xcomplex<f64>> %abs_test_cast = tensor.cast %abs_test - : tensor<3xcomplex<f64>> to tensor<?xcomplex<f64>> + : tensor<6xcomplex<f64>> to tensor<?xcomplex<f64>> %abs_func = func.constant @abs : (complex<f64>) -> f64 _______________________________________________ cfe-commits mailing list cfe-commits@lists.llvm.org https://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits