https://github.com/matthias-springer created https://github.com/llvm/llvm-project/pull/169759
Add support for `arith.negf`.

>From 410d05f5a41910efd8dca6c0f031a34efd677247 Mon Sep 17 00:00:00 2001
From: Matthias Springer <[email protected]>
Date: Thu, 27 Nov 2025 04:55:05 +0000
Subject: [PATCH] [mlir][arith] Add support for `negf` to `ArithToAPFloat`

---
 .../ArithToAPFloat/ArithToAPFloat.cpp        | 62 +++++++++++++++++++
 mlir/lib/ExecutionEngine/APFloatWrappers.cpp | 12 ++++
 .../ArithToApfloat/arith-to-apfloat.mlir     | 10 +++
 .../Arith/CPU/test-apfloat-emulation.mlir    |  4 ++
 4 files changed, 88 insertions(+)

Review changes relative to the first revision:
- Check the operand type before inserting the runtime declaration, and bail
  out with a match failure (instead of asserting in `cast<FloatType>`) on
  vector/tensor operands and on floats wider than 64 bits.
- Do not emit `arith.extui`/`arith.trunci` for 64-bit floats: both ops require
  a strictly wider/narrower result type, so i64 -> i64 would not verify.
- Wrap lines exceeding 80 columns.

diff --git a/mlir/lib/Conversion/ArithToAPFloat/ArithToAPFloat.cpp b/mlir/lib/Conversion/ArithToAPFloat/ArithToAPFloat.cpp
index 566632bd8707f..230abb51e8158 100644
--- a/mlir/lib/Conversion/ArithToAPFloat/ArithToAPFloat.cpp
+++ b/mlir/lib/Conversion/ArithToAPFloat/ArithToAPFloat.cpp
@@ -449,6 +449,67 @@ struct CmpFOpToAPFloatConversion final : OpRewritePattern<arith::CmpFOp> {
   SymbolOpInterface symTable;
 };
 
+/// Lowers a scalar `arith.negf` to a call to the `_mlir_apfloat_neg` runtime
+/// function. The operand's bit pattern is passed as an i64 and the i64 result
+/// is converted back to the original float type.
+struct NegFOpToAPFloatConversion final : OpRewritePattern<arith::NegFOp> {
+  NegFOpToAPFloatConversion(MLIRContext *context, SymbolOpInterface symTable,
+                            PatternBenefit benefit = 1)
+      : OpRewritePattern<arith::NegFOp>(context, benefit), symTable(symTable) {}
+
+  LogicalResult matchAndRewrite(arith::NegFOp op,
+                                PatternRewriter &rewriter) const override {
+    // Check the operand type before modifying the IR: a pattern must not
+    // insert the runtime function declaration and then fail to match. Only
+    // scalar floats of at most 64 bits are supported, because the runtime
+    // function exchanges bit patterns as i64 values.
+    auto floatTy = dyn_cast<FloatType>(op.getOperand().getType());
+    if (!floatTy)
+      return rewriter.notifyMatchFailure(op, "expected scalar float operand");
+    unsigned width = floatTy.getWidth();
+    if (width > 64)
+      return rewriter.notifyMatchFailure(
+          op, "float types wider than 64 bits are not supported");
+
+    // Get APFloat function from runtime library.
+    auto i32Type = IntegerType::get(symTable->getContext(), 32);
+    auto i64Type = IntegerType::get(symTable->getContext(), 64);
+    FailureOr<FuncOp> fn =
+        lookupOrCreateApFloatFn(rewriter, symTable, "neg", {i32Type, i64Type});
+    if (failed(fn))
+      return fn;
+
+    // Cast operand to 64-bit integer. `arith.extui` requires a strictly wider
+    // result type, so 64-bit operands are passed through unchanged.
+    rewriter.setInsertionPoint(op);
+    Location loc = op.getLoc();
+    auto intWType = rewriter.getIntegerType(width);
+    Value operandBits =
+        arith::BitcastOp::create(rewriter, loc, intWType, op.getOperand());
+    if (width < 64)
+      operandBits = arith::ExtUIOp::create(rewriter, loc, i64Type, operandBits);
+
+    // Call APFloat function.
+    Value semValue = getSemanticsValue(rewriter, loc, floatTy);
+    SmallVector<Value> params = {semValue, operandBits};
+    Value negatedBits =
+        func::CallOp::create(rewriter, loc, TypeRange(i64Type),
+                             SymbolRefAttr::get(*fn), params)
+            ->getResult(0);
+
+    // Truncate result to the original width (not needed for 64-bit floats).
+    if (width < 64)
+      negatedBits =
+          arith::TruncIOp::create(rewriter, loc, intWType, negatedBits);
+    Value result =
+        arith::BitcastOp::create(rewriter, loc, floatTy, negatedBits);
+    rewriter.replaceOp(op, result);
+    return success();
+  }
+
+  SymbolOpInterface symTable;
+};
+
 namespace {
 struct ArithToAPFloatConversionPass final
     : impl::ArithToAPFloatConversionPassBase<ArithToAPFloatConversionPass> {
@@ -482,6 +543,7 @@ void ArithToAPFloatConversionPass::runOnOperation() {
   patterns.add<IntToFpConversion<arith::UIToFPOp>>(context, getOperation(),
                                                    /*isUnsigned=*/true);
   patterns.add<CmpFOpToAPFloatConversion>(context, getOperation());
+  patterns.add<NegFOpToAPFloatConversion>(context, getOperation());
   LogicalResult result = success();
   ScopedDiagnosticHandler scopedHandler(context, [&result](Diagnostic &diag) {
     if (diag.getSeverity() == DiagnosticSeverity::Error) {
diff --git a/mlir/lib/ExecutionEngine/APFloatWrappers.cpp b/mlir/lib/ExecutionEngine/APFloatWrappers.cpp
index 77f7137264888..f2d5254be6b57 100644
--- a/mlir/lib/ExecutionEngine/APFloatWrappers.cpp
+++ b/mlir/lib/ExecutionEngine/APFloatWrappers.cpp
@@ -142,4 +142,16 @@ MLIR_APFLOAT_WRAPPERS_EXPORT int8_t _mlir_apfloat_compare(int32_t semantics,
   llvm::APFloat y(sem, llvm::APInt(bitWidth, b));
   return static_cast<int8_t>(x.compare(y));
 }
+
+/// Flips the sign of `a`, the bit pattern of a float with the given semantics,
+/// and returns the bit pattern of the result, zero-extended to 64 bits.
+MLIR_APFLOAT_WRAPPERS_EXPORT uint64_t _mlir_apfloat_neg(int32_t semantics,
+                                                        uint64_t a) {
+  const llvm::fltSemantics &sem = llvm::APFloatBase::EnumToSemantics(
+      static_cast<llvm::APFloatBase::Semantics>(semantics));
+  unsigned bitWidth = llvm::APFloatBase::semanticsSizeInBits(sem);
+  llvm::APFloat x(sem, llvm::APInt(bitWidth, a));
+  x.changeSign();
+  return x.bitcastToAPInt().getZExtValue();
+}
 }
diff --git a/mlir/test/Conversion/ArithToApfloat/arith-to-apfloat.mlir b/mlir/test/Conversion/ArithToApfloat/arith-to-apfloat.mlir
index 78ce3640ecc67..775cb5ea60f22 100644
--- a/mlir/test/Conversion/ArithToApfloat/arith-to-apfloat.mlir
+++ b/mlir/test/Conversion/ArithToApfloat/arith-to-apfloat.mlir
@@ -213,3 +213,13 @@ func.func @cmpf(%arg0: f4E2M1FN, %arg1: f4E2M1FN) {
   %0 = arith.cmpf "ult", %arg0, %arg1 : f4E2M1FN
   return
 }
+
+// -----
+
+// CHECK: func.func private @_mlir_apfloat_neg(i32, i64) -> i64
+// CHECK: %[[sem:.*]] = arith.constant 2 : i32
+// CHECK: %[[res:.*]] = call @_mlir_apfloat_neg(%[[sem]], %{{.*}}) : (i32, i64) -> i64
+func.func @negf(%arg0: f32) {
+  %0 = arith.negf %arg0 : f32
+  return
+}
diff --git a/mlir/test/Integration/Dialect/Arith/CPU/test-apfloat-emulation.mlir b/mlir/test/Integration/Dialect/Arith/CPU/test-apfloat-emulation.mlir
index 433d058d025cf..555cc9a531966 100644
--- a/mlir/test/Integration/Dialect/Arith/CPU/test-apfloat-emulation.mlir
+++ b/mlir/test/Integration/Dialect/Arith/CPU/test-apfloat-emulation.mlir
@@ -43,6 +43,10 @@ func.func @entry() {
   %cvt = arith.truncf %b2 : f32 to f8E4M3FN
   vector.print %cvt : f8E4M3FN
 
+  // CHECK-NEXT: -2.25
+  %negated = arith.negf %cvt : f8E4M3FN
+  vector.print %negated : f8E4M3FN
+
   // CHECK-NEXT: 1
   %cmp1 = arith.cmpf "olt", %cvt, %c1 : f8E4M3FN
   vector.print %cmp1 : i1

_______________________________________________
llvm-branch-commits mailing list
[email protected]
https://lists.llvm.org/cgi-bin/mailman/listinfo/llvm-branch-commits
