[Mlir-commits] [mlir] Reapply "[mlir][math] Add FP software implementation lowering pass: math-to-apfloat" (#172714) (PR #172716)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Wed Dec 17 12:09:42 PST 2025


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir-func

@llvm/pr-subscribers-mlir-execution-engine

Author: Maksim Levental (makslevental)

<details>
<summary>Changes</summary>

Fix the buildbot failure by linking `MLIRTransformUtils` into `MLIRMathToAPFloat`. Also move the headers to `mlir/Conversion/ArithAndMathToAPFloat`.

---

Patch is 44.22 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/172716.diff


18 Files Affected:

- (renamed) mlir/include/mlir/Conversion/ArithAndMathToAPFloat/ArithToAPFloat.h (+3-3) 
- (added) mlir/include/mlir/Conversion/ArithAndMathToAPFloat/MathToAPFloat.h (+21) 
- (modified) mlir/include/mlir/Conversion/Passes.h (+2-1) 
- (modified) mlir/include/mlir/Conversion/Passes.td (+15) 
- (modified) mlir/include/mlir/Dialect/Func/Utils/Utils.h (+16) 
- (renamed) mlir/lib/Conversion/ArithAndMathToAPFloat/ArithToAPFloat.cpp (+26-68) 
- (added) mlir/lib/Conversion/ArithAndMathToAPFloat/CMakeLists.txt (+50) 
- (added) mlir/lib/Conversion/ArithAndMathToAPFloat/MathToAPFloat.cpp (+219) 
- (added) mlir/lib/Conversion/ArithAndMathToAPFloat/Utils.cpp (+22) 
- (added) mlir/lib/Conversion/ArithAndMathToAPFloat/Utils.h (+21) 
- (removed) mlir/lib/Conversion/ArithToAPFloat/CMakeLists.txt (-19) 
- (modified) mlir/lib/Conversion/CMakeLists.txt (+1-1) 
- (modified) mlir/lib/Conversion/MathToLLVM/MathToLLVM.cpp (+5-2) 
- (modified) mlir/lib/Dialect/Func/Utils/Utils.cpp (+39) 
- (modified) mlir/lib/ExecutionEngine/APFloatWrappers.cpp (+64-1) 
- (renamed) mlir/test/Conversion/ArithAndMathToAPFloat/arith-to-apfloat.mlir () 
- (added) mlir/test/Conversion/ArithAndMathToAPFloat/math-to-apfloat.mlir (+66) 
- (added) mlir/test/Integration/Dialect/Math/CPU/test-apfloat-emulation.mlir (+68) 


``````````diff
diff --git a/mlir/include/mlir/Conversion/ArithToAPFloat/ArithToAPFloat.h b/mlir/include/mlir/Conversion/ArithAndMathToAPFloat/ArithToAPFloat.h
similarity index 73%
rename from mlir/include/mlir/Conversion/ArithToAPFloat/ArithToAPFloat.h
rename to mlir/include/mlir/Conversion/ArithAndMathToAPFloat/ArithToAPFloat.h
index 64a42a228199e..6702aca045ba4 100644
--- a/mlir/include/mlir/Conversion/ArithToAPFloat/ArithToAPFloat.h
+++ b/mlir/include/mlir/Conversion/ArithAndMathToAPFloat/ArithToAPFloat.h
@@ -6,8 +6,8 @@
 //
 //===----------------------------------------------------------------------===//
 
-#ifndef MLIR_CONVERSION_ARITHTOAPFLOAT_ARITHTOAPFLOAT_H
-#define MLIR_CONVERSION_ARITHTOAPFLOAT_ARITHTOAPFLOAT_H
+#ifndef MLIR_CONVERSION_ARITHANDMATHTOAPFLOAT_ARITHTOAPFLOAT_H
+#define MLIR_CONVERSION_ARITHANDMATHTOAPFLOAT_ARITHTOAPFLOAT_H
 
 #include <memory>
 
@@ -18,4 +18,4 @@ class Pass;
 #include "mlir/Conversion/Passes.h.inc"
 } // namespace mlir
 
-#endif // MLIR_CONVERSION_ARITHTOAPFLOAT_ARITHTOAPFLOAT_H
+#endif // MLIR_CONVERSION_ARITHANDMATHTOAPFLOAT_ARITHTOAPFLOAT_H
diff --git a/mlir/include/mlir/Conversion/ArithAndMathToAPFloat/MathToAPFloat.h b/mlir/include/mlir/Conversion/ArithAndMathToAPFloat/MathToAPFloat.h
new file mode 100644
index 0000000000000..6cb44c89ecebb
--- /dev/null
+++ b/mlir/include/mlir/Conversion/ArithAndMathToAPFloat/MathToAPFloat.h
@@ -0,0 +1,21 @@
+//===- MathToAPFloat.h - Math to APFloat impl conversion ---*- C++ ------*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_CONVERSION_ARITHANDMATHTOAPFLOAT_MATHTOAPFLOAT_H
+#define MLIR_CONVERSION_ARITHANDMATHTOAPFLOAT_MATHTOAPFLOAT_H
+
+#include <memory>
+
+namespace mlir {
+class Pass;
+
+#define GEN_PASS_DECL_MATHTOAPFLOATCONVERSIONPASS
+#include "mlir/Conversion/Passes.h.inc"
+} // namespace mlir
+
+#endif // MLIR_CONVERSION_ARITHANDMATHTOAPFLOAT_MATHTOAPFLOAT_H
diff --git a/mlir/include/mlir/Conversion/Passes.h b/mlir/include/mlir/Conversion/Passes.h
index 82bdfd02661a6..7c2b450ca6710 100644
--- a/mlir/include/mlir/Conversion/Passes.h
+++ b/mlir/include/mlir/Conversion/Passes.h
@@ -11,8 +11,9 @@
 
 #include "mlir/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.h"
 #include "mlir/Conversion/AffineToStandard/AffineToStandard.h"
+#include "mlir/Conversion/ArithAndMathToAPFloat/ArithToAPFloat.h"
+#include "mlir/Conversion/ArithAndMathToAPFloat/MathToAPFloat.h"
 #include "mlir/Conversion/ArithToAMDGPU/ArithToAMDGPU.h"
-#include "mlir/Conversion/ArithToAPFloat/ArithToAPFloat.h"
 #include "mlir/Conversion/ArithToArmSME/ArithToArmSME.h"
 #include "mlir/Conversion/ArithToEmitC/ArithToEmitCPass.h"
 #include "mlir/Conversion/ArithToLLVM/ArithToLLVM.h"
diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td
index fcbaf3ccc1486..7f24e58671aab 100644
--- a/mlir/include/mlir/Conversion/Passes.td
+++ b/mlir/include/mlir/Conversion/Passes.td
@@ -775,6 +775,21 @@ def ConvertMathToLibmPass : Pass<"convert-math-to-libm", "ModuleOp"> {
   ];
 }
 
+//===----------------------------------------------------------------------===//
+// MathToAPFloat
+//===----------------------------------------------------------------------===//
+
+def MathToAPFloatConversionPass
+    : Pass<"convert-math-to-apfloat", "ModuleOp"> {
+  let summary = "Convert Math ops to APFloat runtime library calls";
+  let description = [{
+    This pass converts supported Math ops to APFloat-based runtime library
+    calls (APFloatWrappers.cpp). APFloat is a software implementation of
+    floating-point arithmetic operations.
+  }];
+  let dependentDialects = ["math::MathDialect", "func::FuncDialect"];
+}
+
 //===----------------------------------------------------------------------===//
 // MathToLLVM
 //===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Dialect/Func/Utils/Utils.h b/mlir/include/mlir/Dialect/Func/Utils/Utils.h
index 00d50874a2e8d..079c1f461b6ed 100644
--- a/mlir/include/mlir/Dialect/Func/Utils/Utils.h
+++ b/mlir/include/mlir/Dialect/Func/Utils/Utils.h
@@ -67,6 +67,22 @@ FailureOr<FuncOp> lookupFnDecl(SymbolOpInterface symTable, StringRef name,
                                FunctionType funcT,
                                SymbolTableCollection *symbolTables = nullptr);
 
+/// Create a FuncOp decl and insert it into the `symTable` operation. If
+/// `setPrivate` is true, the decl is marked private. If `symbolTables` is
+/// provided, the decl is also registered in that SymbolTableCollection.
+FuncOp createFnDecl(OpBuilder &b, SymbolOpInterface symTable, StringRef name,
+                    FunctionType funcT, bool setPrivate,
+                    SymbolTableCollection *symbolTables = nullptr);
+
+/// Helper function to look up or create the symbol for a runtime library
+/// function with the given parameter types. The declared function returns
+/// i64 unless a different `resultType` is specified.
+FailureOr<FuncOp>
+lookupOrCreateFnDecl(OpBuilder &b, SymbolOpInterface symTable, StringRef name,
+                     TypeRange paramTypes,
+                     SymbolTableCollection *symbolTables = nullptr,
+                     Type resultType = {});
+
 } // namespace func
 } // namespace mlir
 
diff --git a/mlir/lib/Conversion/ArithToAPFloat/ArithToAPFloat.cpp b/mlir/lib/Conversion/ArithAndMathToAPFloat/ArithToAPFloat.cpp
similarity index 88%
rename from mlir/lib/Conversion/ArithToAPFloat/ArithToAPFloat.cpp
rename to mlir/lib/Conversion/ArithAndMathToAPFloat/ArithToAPFloat.cpp
index 79816fc6e3bf1..813a854f2fc97 100644
--- a/mlir/lib/Conversion/ArithToAPFloat/ArithToAPFloat.cpp
+++ b/mlir/lib/Conversion/ArithAndMathToAPFloat/ArithToAPFloat.cpp
@@ -6,8 +6,9 @@
 //
 //===----------------------------------------------------------------------===//
 
-#include "mlir/Conversion/ArithToAPFloat/ArithToAPFloat.h"
+#include "Utils.h"
 
+#include "mlir/Conversion/ArithAndMathToAPFloat/ArithToAPFloat.h"
 #include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/Arith/Transforms/Passes.h"
 #include "mlir/Dialect/Func/IR/FuncOps.h"
@@ -25,47 +26,6 @@ namespace mlir {
 using namespace mlir;
 using namespace mlir::func;
 
-static FuncOp createFnDecl(OpBuilder &b, SymbolOpInterface symTable,
-                           StringRef name, FunctionType funcT, bool setPrivate,
-                           SymbolTableCollection *symbolTables = nullptr) {
-  OpBuilder::InsertionGuard g(b);
-  assert(!symTable->getRegion(0).empty() && "expected non-empty region");
-  b.setInsertionPointToStart(&symTable->getRegion(0).front());
-  FuncOp funcOp = FuncOp::create(b, symTable->getLoc(), name, funcT);
-  if (setPrivate)
-    funcOp.setPrivate();
-  if (symbolTables) {
-    SymbolTable &symbolTable = symbolTables->getSymbolTable(symTable);
-    symbolTable.insert(funcOp, symTable->getRegion(0).front().begin());
-  }
-  return funcOp;
-}
-
-/// Helper function to look up or create the symbol for a runtime library
-/// function with the given parameter types. Returns an int64_t, unless a
-/// different result type is specified.
-static FailureOr<FuncOp>
-lookupOrCreateApFloatFn(OpBuilder &b, SymbolOpInterface symTable,
-                        StringRef name, TypeRange paramTypes,
-                        SymbolTableCollection *symbolTables = nullptr,
-                        Type resultType = {}) {
-  if (!resultType)
-    resultType = IntegerType::get(symTable->getContext(), 64);
-  std::string funcName = (llvm::Twine("_mlir_apfloat_") + name).str();
-  auto funcT = FunctionType::get(b.getContext(), paramTypes, {resultType});
-  FailureOr<FuncOp> func =
-      lookupFnDecl(symTable, funcName, funcT, symbolTables);
-  // Failed due to type mismatch.
-  if (failed(func))
-    return func;
-  // Successfully matched existing decl.
-  if (*func)
-    return *func;
-
-  return createFnDecl(b, symTable, funcName, funcT,
-                      /*setPrivate=*/true, symbolTables);
-}
-
 /// Helper function to look up or create the symbol for a runtime library
 /// function for a binary arithmetic operation.
 ///
@@ -81,14 +41,9 @@ lookupOrCreateBinaryFn(OpBuilder &b, SymbolOpInterface symTable, StringRef name,
                        SymbolTableCollection *symbolTables = nullptr) {
   auto i32Type = IntegerType::get(symTable->getContext(), 32);
   auto i64Type = IntegerType::get(symTable->getContext(), 64);
-  return lookupOrCreateApFloatFn(b, symTable, name, {i32Type, i64Type, i64Type},
-                                 symbolTables);
-}
-
-static Value getSemanticsValue(OpBuilder &b, Location loc, FloatType floatTy) {
-  int32_t sem = llvm::APFloatBase::SemanticsToEnum(floatTy.getFloatSemantics());
-  return arith::ConstantOp::create(b, loc, b.getI32Type(),
-                                   b.getIntegerAttr(b.getI32Type(), sem));
+  std::string funcName = (llvm::Twine("_mlir_apfloat_") + name).str();
+  return lookupOrCreateFnDecl(b, symTable, funcName,
+                              {i32Type, i64Type, i64Type}, symbolTables);
 }
 
 /// Given two operands of vector type and vector result type (with the same
@@ -197,7 +152,7 @@ struct BinaryArithOpToAPFloatConversion final : OpRewritePattern<OpTy> {
               arith::BitcastOp::create(rewriter, loc, intWType, rhs));
 
           // Call APFloat function.
-          Value semValue = getSemanticsValue(rewriter, loc, floatTy);
+          Value semValue = getAPFloatSemanticsValue(rewriter, loc, floatTy);
           SmallVector<Value> params = {semValue, lhsBits, rhsBits};
           auto resultOp = func::CallOp::create(rewriter, loc,
                                                TypeRange(rewriter.getI64Type()),
@@ -231,8 +186,9 @@ struct FpToFpConversion final : OpRewritePattern<OpTy> {
     // Get APFloat function from runtime library.
     auto i32Type = IntegerType::get(symTable->getContext(), 32);
     auto i64Type = IntegerType::get(symTable->getContext(), 64);
-    FailureOr<FuncOp> fn = lookupOrCreateApFloatFn(
-        rewriter, symTable, "convert", {i32Type, i32Type, i64Type});
+    FailureOr<FuncOp> fn =
+        lookupOrCreateFnDecl(rewriter, symTable, "_mlir_apfloat_convert",
+                             {i32Type, i32Type, i64Type});
     if (failed(fn))
       return fn;
 
@@ -250,9 +206,10 @@ struct FpToFpConversion final : OpRewritePattern<OpTy> {
               arith::BitcastOp::create(rewriter, loc, inIntWType, operand1));
 
           // Call APFloat function.
-          Value inSemValue = getSemanticsValue(rewriter, loc, inFloatTy);
+          Value inSemValue = getAPFloatSemanticsValue(rewriter, loc, inFloatTy);
           auto outFloatTy = cast<FloatType>(resultType);
-          Value outSemValue = getSemanticsValue(rewriter, loc, outFloatTy);
+          Value outSemValue =
+              getAPFloatSemanticsValue(rewriter, loc, outFloatTy);
           std::array<Value, 3> params = {inSemValue, outSemValue, operandBits};
           auto resultOp = func::CallOp::create(rewriter, loc,
                                                TypeRange(rewriter.getI64Type()),
@@ -289,8 +246,8 @@ struct FpToIntConversion final : OpRewritePattern<OpTy> {
     auto i32Type = IntegerType::get(symTable->getContext(), 32);
     auto i64Type = IntegerType::get(symTable->getContext(), 64);
     FailureOr<FuncOp> fn =
-        lookupOrCreateApFloatFn(rewriter, symTable, "convert_to_int",
-                                {i32Type, i32Type, i1Type, i64Type});
+        lookupOrCreateFnDecl(rewriter, symTable, "_mlir_apfloat_convert_to_int",
+                             {i32Type, i32Type, i1Type, i64Type});
     if (failed(fn))
       return fn;
 
@@ -308,7 +265,7 @@ struct FpToIntConversion final : OpRewritePattern<OpTy> {
               arith::BitcastOp::create(rewriter, loc, inIntWType, operand1));
 
           // Call APFloat function.
-          Value inSemValue = getSemanticsValue(rewriter, loc, inFloatTy);
+          Value inSemValue = getAPFloatSemanticsValue(rewriter, loc, inFloatTy);
           auto outIntTy = cast<IntegerType>(resultType);
           Value outWidthValue = arith::ConstantOp::create(
               rewriter, loc, i32Type,
@@ -350,9 +307,9 @@ struct IntToFpConversion final : OpRewritePattern<OpTy> {
     auto i1Type = IntegerType::get(symTable->getContext(), 1);
     auto i32Type = IntegerType::get(symTable->getContext(), 32);
     auto i64Type = IntegerType::get(symTable->getContext(), 64);
-    FailureOr<FuncOp> fn =
-        lookupOrCreateApFloatFn(rewriter, symTable, "convert_from_int",
-                                {i32Type, i32Type, i1Type, i64Type});
+    FailureOr<FuncOp> fn = lookupOrCreateFnDecl(
+        rewriter, symTable, "_mlir_apfloat_convert_from_int",
+        {i32Type, i32Type, i1Type, i64Type});
     if (failed(fn))
       return fn;
 
@@ -377,7 +334,8 @@ struct IntToFpConversion final : OpRewritePattern<OpTy> {
 
           // Call APFloat function.
           auto outFloatTy = cast<FloatType>(resultType);
-          Value outSemValue = getSemanticsValue(rewriter, loc, outFloatTy);
+          Value outSemValue =
+              getAPFloatSemanticsValue(rewriter, loc, outFloatTy);
           Value inWidthValue = arith::ConstantOp::create(
               rewriter, loc, i32Type,
               rewriter.getIntegerAttr(i32Type, inIntTy.getWidth()));
@@ -421,8 +379,8 @@ struct CmpFOpToAPFloatConversion final : OpRewritePattern<arith::CmpFOp> {
     auto i32Type = IntegerType::get(symTable->getContext(), 32);
     auto i64Type = IntegerType::get(symTable->getContext(), 64);
     FailureOr<FuncOp> fn =
-        lookupOrCreateApFloatFn(rewriter, symTable, "compare",
-                                {i32Type, i64Type, i64Type}, nullptr, i8Type);
+        lookupOrCreateFnDecl(rewriter, symTable, "_mlir_apfloat_compare",
+                             {i32Type, i64Type, i64Type}, nullptr, i8Type);
     if (failed(fn))
       return fn;
 
@@ -443,7 +401,7 @@ struct CmpFOpToAPFloatConversion final : OpRewritePattern<arith::CmpFOp> {
               arith::BitcastOp::create(rewriter, loc, intWType, rhs));
 
           // Call APFloat function.
-          Value semValue = getSemanticsValue(rewriter, loc, floatTy);
+          Value semValue = getAPFloatSemanticsValue(rewriter, loc, floatTy);
           SmallVector<Value> params = {semValue, lhsBits, rhsBits};
           Value comparisonResult =
               func::CallOp::create(rewriter, loc, TypeRange(i8Type),
@@ -569,8 +527,8 @@ struct NegFOpToAPFloatConversion final : OpRewritePattern<arith::NegFOp> {
     // Get APFloat function from runtime library.
     auto i32Type = IntegerType::get(symTable->getContext(), 32);
     auto i64Type = IntegerType::get(symTable->getContext(), 64);
-    FailureOr<FuncOp> fn =
-        lookupOrCreateApFloatFn(rewriter, symTable, "neg", {i32Type, i64Type});
+    FailureOr<FuncOp> fn = lookupOrCreateFnDecl(
+        rewriter, symTable, "_mlir_apfloat_neg", {i32Type, i64Type});
     if (failed(fn))
       return fn;
 
@@ -588,7 +546,7 @@ struct NegFOpToAPFloatConversion final : OpRewritePattern<arith::NegFOp> {
               arith::BitcastOp::create(rewriter, loc, intWType, operand1));
 
           // Call APFloat function.
-          Value semValue = getSemanticsValue(rewriter, loc, floatTy);
+          Value semValue = getAPFloatSemanticsValue(rewriter, loc, floatTy);
           SmallVector<Value> params = {semValue, operandBits};
           Value negatedBits =
               func::CallOp::create(rewriter, loc, TypeRange(i64Type),
diff --git a/mlir/lib/Conversion/ArithAndMathToAPFloat/CMakeLists.txt b/mlir/lib/Conversion/ArithAndMathToAPFloat/CMakeLists.txt
new file mode 100644
index 0000000000000..e8fd9c493b975
--- /dev/null
+++ b/mlir/lib/Conversion/ArithAndMathToAPFloat/CMakeLists.txt
@@ -0,0 +1,50 @@
+add_mlir_library(ArithAndMathToAPFloatUtils
+  Utils.cpp
+  PARTIAL_SOURCES_INTENDED
+
+  LINK_LIBS PUBLIC
+  MLIRArithDialect
+  )
+
+add_mlir_conversion_library(MLIRArithToAPFloat
+  ArithToAPFloat.cpp
+  PARTIAL_SOURCES_INTENDED
+
+  ADDITIONAL_HEADER_DIRS
+  ${MLIR_MAIN_INCLUDE_DIR}/mlir/Conversion/ArithToLLVM
+
+  DEPENDS
+  MLIRConversionPassIncGen
+
+  LINK_COMPONENTS
+  Core
+
+  LINK_LIBS PUBLIC
+  ArithAndMathToAPFloatUtils
+  MLIRArithDialect
+  MLIRArithTransforms
+  MLIRFuncDialect
+  MLIRFuncUtils
+  MLIRVectorDialect
+  )
+
+add_mlir_conversion_library(MLIRMathToAPFloat
+  MathToAPFloat.cpp
+  PARTIAL_SOURCES_INTENDED
+
+  ADDITIONAL_HEADER_DIRS
+  ${MLIR_MAIN_INCLUDE_DIR}/mlir/Conversion/MathToLLVM
+
+  DEPENDS
+  MLIRConversionPassIncGen
+
+  LINK_COMPONENTS
+  Core
+
+  LINK_LIBS PUBLIC
+  MLIRTransformUtils
+  ArithAndMathToAPFloatUtils
+  MLIRMathDialect
+  MLIRFuncDialect
+  MLIRFuncUtils
+  )
diff --git a/mlir/lib/Conversion/ArithAndMathToAPFloat/MathToAPFloat.cpp b/mlir/lib/Conversion/ArithAndMathToAPFloat/MathToAPFloat.cpp
new file mode 100644
index 0000000000000..784028f5cf2eb
--- /dev/null
+++ b/mlir/lib/Conversion/ArithAndMathToAPFloat/MathToAPFloat.cpp
@@ -0,0 +1,219 @@
+//===- MathToAPFloat.cpp - Math to APFloat Conversion ---------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "Utils.h"
+
+#include "mlir/Conversion/ArithAndMathToAPFloat/MathToAPFloat.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/Func/Utils/Utils.h"
+#include "mlir/Dialect/Math/IR/Math.h"
+#include "mlir/Dialect/Math/Transforms/Passes.h"
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/IR/Verifier.h"
+#include "mlir/Transforms/WalkPatternRewriteDriver.h"
+
+namespace mlir {
+#define GEN_PASS_DEF_MATHTOAPFLOATCONVERSIONPASS
+#include "mlir/Conversion/Passes.h.inc"
+} // namespace mlir
+
+using namespace mlir;
+using namespace mlir::func;
+
+struct AbsFOpToAPFloatConversion final : OpRewritePattern<math::AbsFOp> {
+  AbsFOpToAPFloatConversion(MLIRContext *context, SymbolOpInterface symTable,
+                            PatternBenefit benefit = 1)
+      : OpRewritePattern<math::AbsFOp>(context, benefit), symTable(symTable) {}
+
+  LogicalResult matchAndRewrite(math::AbsFOp op,
+                                PatternRewriter &rewriter) const override {
+    // Cast operands to 64-bit integers.
+    auto operand = op.getOperand();
+    auto floatTy = dyn_cast<FloatType>(operand.getType());
+    if (!floatTy)
+      return rewriter.notifyMatchFailure(op,
+                                         "only scalar FloatTypes supported");
+    if (floatTy.getIntOrFloatBitWidth() > 64) {
+      return rewriter.notifyMatchFailure(op,
+                                         "bitwidth > 64 bits is not supported");
+    }
+    // Get APFloat function from runtime library.
+    auto i32Type = IntegerType::get(symTable->getContext(), 32);
+    auto i64Type = IntegerType::get(symTable->getContext(), 64);
+    FailureOr<FuncOp> fn = lookupOrCreateFnDecl(
+        rewriter, symTable, "_mlir_apfloat_abs", {i32Type, i64Type});
+    if (failed(fn))
+      return fn;
+    Location loc = op.getLoc();
+    rewriter.setInsertionPoint(op);
+    auto intWType = rewriter.getIntegerType(floatTy.getWidth());
+    Value operandBits = arith::ExtUIOp::create(
+        rewriter, loc, i64Type,
+        arith::BitcastOp::create(rewriter, loc, intWType, operand));
+
+    // Call APFloat function.
+    Value semValue = getAPFloatSemanticsValue(rewriter, loc, floatTy);
+    SmallVector<Value> params = {semValue, operandBits};
+    Value negatedBits = func::CallOp::create(rewriter, loc, TypeRange(i64Type),
+                                             SymbolRefAttr::get(*fn), params)
+                            ->getResult(0);
+
+    // Truncate result to the original width.
+    Value truncatedBits =
+        arith::TruncIOp::create(rewriter, loc, intWType, negatedBits);
+    rewriter.replaceOp(
+        op, arith::BitcastOp::create(rewriter, loc, floatTy, truncatedBits));
+    return success();
+  }
+
+  SymbolOpInterface symTable;
+};
+
+template <typename OpTy>
+struct IsOpToAPFloatConversion final : OpRewritePattern<...
[truncated]

``````````

</details>


https://github.com/llvm/llvm-project/pull/172716


More information about the Mlir-commits mailing list