[Mlir-commits] [mlir] [mlir][math] Add vector support for math-to-apfloat (PR #172715)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Fri Jan 9 13:44:29 PST 2026


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir

Author: Maksim Levental (makslevental)

<details>
<summary>Changes</summary>

This PR adds vector type support to `math-to-apfloat`. It also adds a `source-types` pass option to both `arith-to-apfloat` and `math-to-apfloat`, which restricts which types are converted; the option follows the convention and semantics established by `-arith-emulate-unsupported-floats`. Note that by default (i.e., when `source-types` is empty) all floating-point types are converted (i.e., empty means "convert all").

TODO: add lit tests

---

Patch is 34.25 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/172715.diff


7 Files Affected:

- (modified) mlir/include/mlir/Conversion/ArithAndMathToAPFloat/ArithToAPFloat.h (+1) 
- (modified) mlir/include/mlir/Conversion/ArithAndMathToAPFloat/MathToAPFloat.h (+1) 
- (modified) mlir/include/mlir/Conversion/Passes.td (+8) 
- (modified) mlir/lib/Conversion/ArithAndMathToAPFloat/ArithToAPFloat.cpp (+54-97) 
- (modified) mlir/lib/Conversion/ArithAndMathToAPFloat/MathToAPFloat.cpp (+123-82) 
- (modified) mlir/lib/Conversion/ArithAndMathToAPFloat/Utils.cpp (+45-2) 
- (modified) mlir/lib/Conversion/ArithAndMathToAPFloat/Utils.h (+61) 


``````````diff
diff --git a/mlir/include/mlir/Conversion/ArithAndMathToAPFloat/ArithToAPFloat.h b/mlir/include/mlir/Conversion/ArithAndMathToAPFloat/ArithToAPFloat.h
index 6702aca045ba4..2dacc2e11b049 100644
--- a/mlir/include/mlir/Conversion/ArithAndMathToAPFloat/ArithToAPFloat.h
+++ b/mlir/include/mlir/Conversion/ArithAndMathToAPFloat/ArithToAPFloat.h
@@ -9,6 +9,7 @@
 #ifndef MLIR_CONVERSION_ARITHANDMATHTOAPFLOAT_ARITHTOAPFLOAT_H
 #define MLIR_CONVERSION_ARITHANDMATHTOAPFLOAT_ARITHTOAPFLOAT_H
 
+#include "llvm/ADT/SmallVector.h"
 #include <memory>
 
 namespace mlir {
diff --git a/mlir/include/mlir/Conversion/ArithAndMathToAPFloat/MathToAPFloat.h b/mlir/include/mlir/Conversion/ArithAndMathToAPFloat/MathToAPFloat.h
index 6cb44c89ecebb..06548c250a27b 100644
--- a/mlir/include/mlir/Conversion/ArithAndMathToAPFloat/MathToAPFloat.h
+++ b/mlir/include/mlir/Conversion/ArithAndMathToAPFloat/MathToAPFloat.h
@@ -9,6 +9,7 @@
 #ifndef MLIR_CONVERSION_ARITHANDMATHTOAPFLOAT_MATHTOAPFLOAT_H
 #define MLIR_CONVERSION_ARITHANDMATHTOAPFLOAT_MATHTOAPFLOAT_H
 
+#include "llvm/ADT/SmallVector.h"
 #include <memory>
 
 namespace mlir {
diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td
index 7f24e58671aab..fb2860bee6d43 100644
--- a/mlir/include/mlir/Conversion/Passes.td
+++ b/mlir/include/mlir/Conversion/Passes.td
@@ -198,6 +198,10 @@ def ArithToAPFloatConversionPass
     calls (APFloatWrappers.cpp). APFloat is a software implementation of
     floating-point arithmetic operations.
   }];
+  let options = [
+    ListOption<"sourceTypeStrs", "source-types", "std::string",
+      "MLIR types without arithmetic support on a given target">,
+  ];
   let dependentDialects = ["arith::ArithDialect", "func::FuncDialect",
                            "vector::VectorDialect"];
 }
@@ -787,6 +791,10 @@ def MathToAPFloatConversionPass
     calls (APFloatWrappers.cpp). APFloat is a software implementation of
     floating-point mathmetic operations.
   }];
+  let options = [
+    ListOption<"sourceTypeStrs", "source-types", "std::string",
+      "MLIR types without arithmetic support on a given target">,
+  ];
   let dependentDialects = ["math::MathDialect", "func::FuncDialect"];
 }
 
diff --git a/mlir/lib/Conversion/ArithAndMathToAPFloat/ArithToAPFloat.cpp b/mlir/lib/Conversion/ArithAndMathToAPFloat/ArithToAPFloat.cpp
index 813a854f2fc97..52eb32de6586b 100644
--- a/mlir/lib/Conversion/ArithAndMathToAPFloat/ArithToAPFloat.cpp
+++ b/mlir/lib/Conversion/ArithAndMathToAPFloat/ArithToAPFloat.cpp
@@ -46,86 +46,20 @@ lookupOrCreateBinaryFn(OpBuilder &b, SymbolOpInterface symTable, StringRef name,
                               {i32Type, i64Type, i64Type}, symbolTables);
 }
 
-/// Given two operands of vector type and vector result type (with the same
-/// shape), call the given function for each pair of scalar operands and
-/// package the result into a vector. If the given operands and result type are
-/// not vectors, call the function directly. The second operand is optional.
-template <typename Fn, typename... Values>
-static Value forEachScalarValue(RewriterBase &rewriter, Location loc,
-                                Value operand1, Value operand2, Type resultType,
-                                Fn fn) {
-  auto vecTy1 = dyn_cast<VectorType>(operand1.getType());
-  if (operand2) {
-    // Sanity check: Operand types must match.
-    assert(vecTy1 == dyn_cast<VectorType>(operand2.getType()) &&
-           "expected same vector types");
-  }
-  if (!vecTy1) {
-    // Not a vector. Call the function directly.
-    return fn(operand1, operand2, resultType);
-  }
-
-  // Prepare scalar operands.
-  ResultRange sclars1 =
-      vector::ToElementsOp::create(rewriter, loc, operand1)->getResults();
-  SmallVector<Value> scalars2;
-  if (!operand2) {
-    // No second operand. Create a vector of empty values.
-    scalars2.assign(vecTy1.getNumElements(), Value());
-  } else {
-    llvm::append_range(
-        scalars2,
-        vector::ToElementsOp::create(rewriter, loc, operand2)->getResults());
-  }
-
-  // Call the function for each pair of scalar operands.
-  auto resultVecType = cast<VectorType>(resultType);
-  SmallVector<Value> results;
-  for (auto [scalar1, scalar2] : llvm::zip_equal(sclars1, scalars2)) {
-    Value result = fn(scalar1, scalar2, resultVecType.getElementType());
-    results.push_back(result);
-  }
-
-  // Package the results into a vector.
-  return vector::FromElementsOp::create(
-      rewriter, loc,
-      vecTy1.cloneWith(/*shape=*/std::nullopt, results.front().getType()),
-      results);
-}
-
-/// Check preconditions for the conversion:
-/// 1. All operands / results must be integers or floats (or vectors thereof).
-/// 2. The bitwidth of the operands / results must be <= 64.
-static LogicalResult checkPreconditions(RewriterBase &rewriter, Operation *op) {
-  for (Value value : llvm::concat<Value>(op->getOperands(), op->getResults())) {
-    Type type = value.getType();
-    if (auto vecTy = dyn_cast<VectorType>(type)) {
-      type = vecTy.getElementType();
-    }
-    if (!type.isIntOrFloat()) {
-      return rewriter.notifyMatchFailure(
-          op, "only integers and floats (or vectors thereof) are supported");
-    }
-    if (type.getIntOrFloatBitWidth() > 64)
-      return rewriter.notifyMatchFailure(op,
-                                         "bitwidth > 64 bits is not supported");
-  }
-  return success();
-}
-
 /// Rewrite a binary arithmetic operation to an APFloat function call.
 template <typename OpTy>
 struct BinaryArithOpToAPFloatConversion final : OpRewritePattern<OpTy> {
   BinaryArithOpToAPFloatConversion(MLIRContext *context,
                                    const char *APFloatName,
                                    SymbolOpInterface symTable,
+                                   ArrayRef<Type> sourceTypes,
                                    PatternBenefit benefit = 1)
       : OpRewritePattern<OpTy>(context, benefit), symTable(symTable),
-        APFloatName(APFloatName) {};
+        APFloatName(APFloatName), sourceTypes(sourceTypes) {};
 
   LogicalResult matchAndRewrite(OpTy op,
                                 PatternRewriter &rewriter) const override {
-    if (failed(checkPreconditions(rewriter, op)))
+    if (failed(checkPreconditions(rewriter, op, sourceTypes)))
       return failure();
 
     // Get APFloat function from runtime library.
@@ -170,17 +104,19 @@ struct BinaryArithOpToAPFloatConversion final : OpRewritePattern<OpTy> {
 
   SymbolOpInterface symTable;
   const char *APFloatName;
+  ArrayRef<Type> sourceTypes;
 };
 
 template <typename OpTy>
 struct FpToFpConversion final : OpRewritePattern<OpTy> {
   FpToFpConversion(MLIRContext *context, SymbolOpInterface symTable,
-                   PatternBenefit benefit = 1)
-      : OpRewritePattern<OpTy>(context, benefit), symTable(symTable) {}
+                   ArrayRef<Type> sourceTypes, PatternBenefit benefit = 1)
+      : OpRewritePattern<OpTy>(context, benefit), symTable(symTable),
+        sourceTypes(sourceTypes) {}
 
   LogicalResult matchAndRewrite(OpTy op,
                                 PatternRewriter &rewriter) const override {
-    if (failed(checkPreconditions(rewriter, op)))
+    if (failed(checkPreconditions(rewriter, op, sourceTypes)))
       return failure();
 
     // Get APFloat function from runtime library.
@@ -227,18 +163,20 @@ struct FpToFpConversion final : OpRewritePattern<OpTy> {
   }
 
   SymbolOpInterface symTable;
+  ArrayRef<Type> sourceTypes;
 };
 
 template <typename OpTy>
 struct FpToIntConversion final : OpRewritePattern<OpTy> {
   FpToIntConversion(MLIRContext *context, SymbolOpInterface symTable,
-                    bool isUnsigned, PatternBenefit benefit = 1)
+                    bool isUnsigned, ArrayRef<Type> sourceTypes,
+                    PatternBenefit benefit = 1)
       : OpRewritePattern<OpTy>(context, benefit), symTable(symTable),
-        isUnsigned(isUnsigned) {}
+        isUnsigned(isUnsigned), sourceTypes(sourceTypes) {}
 
   LogicalResult matchAndRewrite(OpTy op,
                                 PatternRewriter &rewriter) const override {
-    if (failed(checkPreconditions(rewriter, op)))
+    if (failed(checkPreconditions(rewriter, op, sourceTypes)))
       return failure();
 
     // Get APFloat function from runtime library.
@@ -289,18 +227,20 @@ struct FpToIntConversion final : OpRewritePattern<OpTy> {
 
   SymbolOpInterface symTable;
   bool isUnsigned;
+  ArrayRef<Type> sourceTypes;
 };
 
 template <typename OpTy>
 struct IntToFpConversion final : OpRewritePattern<OpTy> {
   IntToFpConversion(MLIRContext *context, SymbolOpInterface symTable,
-                    bool isUnsigned, PatternBenefit benefit = 1)
+                    bool isUnsigned, ArrayRef<Type> sourceTypes,
+                    PatternBenefit benefit = 1)
       : OpRewritePattern<OpTy>(context, benefit), symTable(symTable),
-        isUnsigned(isUnsigned) {}
+        isUnsigned(isUnsigned), sourceTypes(sourceTypes) {}
 
   LogicalResult matchAndRewrite(OpTy op,
                                 PatternRewriter &rewriter) const override {
-    if (failed(checkPreconditions(rewriter, op)))
+    if (failed(checkPreconditions(rewriter, op, sourceTypes)))
       return failure();
 
     // Get APFloat function from runtime library.
@@ -361,16 +301,19 @@ struct IntToFpConversion final : OpRewritePattern<OpTy> {
 
   SymbolOpInterface symTable;
   bool isUnsigned;
+  ArrayRef<Type> sourceTypes;
 };
 
 struct CmpFOpToAPFloatConversion final : OpRewritePattern<arith::CmpFOp> {
   CmpFOpToAPFloatConversion(MLIRContext *context, SymbolOpInterface symTable,
+                            ArrayRef<Type> sourceTypes,
                             PatternBenefit benefit = 1)
-      : OpRewritePattern<arith::CmpFOp>(context, benefit), symTable(symTable) {}
+      : OpRewritePattern<arith::CmpFOp>(context, benefit), symTable(symTable),
+        sourceTypes(sourceTypes) {}
 
   LogicalResult matchAndRewrite(arith::CmpFOp op,
                                 PatternRewriter &rewriter) const override {
-    if (failed(checkPreconditions(rewriter, op)))
+    if (failed(checkPreconditions(rewriter, op, sourceTypes)))
       return failure();
 
     // Get APFloat function from runtime library.
@@ -512,16 +455,19 @@ struct CmpFOpToAPFloatConversion final : OpRewritePattern<arith::CmpFOp> {
   }
 
   SymbolOpInterface symTable;
+  ArrayRef<Type> sourceTypes;
 };
 
 struct NegFOpToAPFloatConversion final : OpRewritePattern<arith::NegFOp> {
   NegFOpToAPFloatConversion(MLIRContext *context, SymbolOpInterface symTable,
+                            ArrayRef<Type> sourceTypes,
                             PatternBenefit benefit = 1)
-      : OpRewritePattern<arith::NegFOp>(context, benefit), symTable(symTable) {}
+      : OpRewritePattern<arith::NegFOp>(context, benefit), symTable(symTable),
+        sourceTypes(sourceTypes) {}
 
   LogicalResult matchAndRewrite(arith::NegFOp op,
                                 PatternRewriter &rewriter) const override {
-    if (failed(checkPreconditions(rewriter, op)))
+    if (failed(checkPreconditions(rewriter, op, sourceTypes)))
       return failure();
 
     // Get APFloat function from runtime library.
@@ -564,6 +510,7 @@ struct NegFOpToAPFloatConversion final : OpRewritePattern<arith::NegFOp> {
   }
 
   SymbolOpInterface symTable;
+  ArrayRef<Type> sourceTypes;
 };
 
 namespace {
@@ -577,36 +524,46 @@ struct ArithToAPFloatConversionPass final
 void ArithToAPFloatConversionPass::runOnOperation() {
   MLIRContext *context = &getContext();
   RewritePatternSet patterns(context);
-  patterns.add<BinaryArithOpToAPFloatConversion<arith::AddFOp>>(context, "add",
-                                                                getOperation());
+
+  FailureOr<SmallVector<Type>> sourceTypes =
+      parseSourceTypes(llvm::to_vector(sourceTypeStrs), context);
+  if (failed(sourceTypes))
+    return signalPassFailure();
+
+  patterns.add<BinaryArithOpToAPFloatConversion<arith::AddFOp>>(
+      context, "add", getOperation(), *sourceTypes);
   patterns.add<BinaryArithOpToAPFloatConversion<arith::SubFOp>>(
-      context, "subtract", getOperation());
+      context, "subtract", getOperation(), *sourceTypes);
   patterns.add<BinaryArithOpToAPFloatConversion<arith::MulFOp>>(
-      context, "multiply", getOperation());
+      context, "multiply", getOperation(), *sourceTypes);
   patterns.add<BinaryArithOpToAPFloatConversion<arith::DivFOp>>(
-      context, "divide", getOperation());
+      context, "divide", getOperation(), *sourceTypes);
   patterns.add<BinaryArithOpToAPFloatConversion<arith::RemFOp>>(
-      context, "remainder", getOperation());
+      context, "remainder", getOperation(), *sourceTypes);
   patterns.add<BinaryArithOpToAPFloatConversion<arith::MinNumFOp>>(
-      context, "minnum", getOperation());
+      context, "minnum", getOperation(), *sourceTypes);
   patterns.add<BinaryArithOpToAPFloatConversion<arith::MaxNumFOp>>(
-      context, "maxnum", getOperation());
+      context, "maxnum", getOperation(), *sourceTypes);
   patterns.add<BinaryArithOpToAPFloatConversion<arith::MinimumFOp>>(
-      context, "minimum", getOperation());
+      context, "minimum", getOperation(), *sourceTypes);
   patterns.add<BinaryArithOpToAPFloatConversion<arith::MaximumFOp>>(
-      context, "maximum", getOperation());
+      context, "maximum", getOperation(), *sourceTypes);
   patterns
       .add<FpToFpConversion<arith::ExtFOp>, FpToFpConversion<arith::TruncFOp>,
            CmpFOpToAPFloatConversion, NegFOpToAPFloatConversion>(
-          context, getOperation());
+          context, getOperation(), *sourceTypes);
   patterns.add<FpToIntConversion<arith::FPToSIOp>>(context, getOperation(),
-                                                   /*isUnsigned=*/false);
+                                                   /*isUnsigned=*/false,
+                                                   *sourceTypes);
   patterns.add<FpToIntConversion<arith::FPToUIOp>>(context, getOperation(),
-                                                   /*isUnsigned=*/true);
+                                                   /*isUnsigned=*/true,
+                                                   *sourceTypes);
   patterns.add<IntToFpConversion<arith::SIToFPOp>>(context, getOperation(),
-                                                   /*isUnsigned=*/false);
+                                                   /*isUnsigned=*/false,
+                                                   *sourceTypes);
   patterns.add<IntToFpConversion<arith::UIToFPOp>>(context, getOperation(),
-                                                   /*isUnsigned=*/true);
+                                                   /*isUnsigned=*/true,
+                                                   *sourceTypes);
   LogicalResult result = success();
   ScopedDiagnosticHandler scopedHandler(context, [&result](Diagnostic &diag) {
     if (diag.getSeverity() == DiagnosticSeverity::Error) {
diff --git a/mlir/lib/Conversion/ArithAndMathToAPFloat/MathToAPFloat.cpp b/mlir/lib/Conversion/ArithAndMathToAPFloat/MathToAPFloat.cpp
index 784028f5cf2eb..b5e15e5c42bed 100644
--- a/mlir/lib/Conversion/ArithAndMathToAPFloat/MathToAPFloat.cpp
+++ b/mlir/lib/Conversion/ArithAndMathToAPFloat/MathToAPFloat.cpp
@@ -28,21 +28,15 @@ using namespace mlir::func;
 
 struct AbsFOpToAPFloatConversion final : OpRewritePattern<math::AbsFOp> {
   AbsFOpToAPFloatConversion(MLIRContext *context, SymbolOpInterface symTable,
+                            ArrayRef<Type> sourceTypes,
                             PatternBenefit benefit = 1)
-      : OpRewritePattern<math::AbsFOp>(context, benefit), symTable(symTable) {}
+      : OpRewritePattern<math::AbsFOp>(context, benefit), symTable(symTable),
+        sourceTypes(sourceTypes) {}
 
   LogicalResult matchAndRewrite(math::AbsFOp op,
                                 PatternRewriter &rewriter) const override {
-    // Cast operands to 64-bit integers.
-    auto operand = op.getOperand();
-    auto floatTy = dyn_cast<FloatType>(operand.getType());
-    if (!floatTy)
-      return rewriter.notifyMatchFailure(op,
-                                         "only scalar FloatTypes supported");
-    if (floatTy.getIntOrFloatBitWidth() > 64) {
-      return rewriter.notifyMatchFailure(op,
-                                         "bitwidth > 64 bits is not supported");
-    }
+    if (failed(checkPreconditions(rewriter, op, sourceTypes)))
+      return failure();
     // Get APFloat function from runtime library.
     auto i32Type = IntegerType::get(symTable->getContext(), 32);
     auto i64Type = IntegerType::get(symTable->getContext(), 64);
@@ -52,49 +46,50 @@ struct AbsFOpToAPFloatConversion final : OpRewritePattern<math::AbsFOp> {
       return fn;
     Location loc = op.getLoc();
     rewriter.setInsertionPoint(op);
-    auto intWType = rewriter.getIntegerType(floatTy.getWidth());
-    Value operandBits = arith::ExtUIOp::create(
-        rewriter, loc, i64Type,
-        arith::BitcastOp::create(rewriter, loc, intWType, operand));
-
-    // Call APFloat function.
-    Value semValue = getAPFloatSemanticsValue(rewriter, loc, floatTy);
-    SmallVector<Value> params = {semValue, operandBits};
-    Value negatedBits = func::CallOp::create(rewriter, loc, TypeRange(i64Type),
-                                             SymbolRefAttr::get(*fn), params)
-                            ->getResult(0);
-
-    // Truncate result to the original width.
-    Value truncatedBits =
-        arith::TruncIOp::create(rewriter, loc, intWType, negatedBits);
-    rewriter.replaceOp(
-        op, arith::BitcastOp::create(rewriter, loc, floatTy, truncatedBits));
+    // Scalarize and convert to APFloat runtime calls.
+    Value repl = forEachScalarValue(
+        rewriter, loc, op.getOperand(), /*operand2=*/Value(), op.getType(),
+        [&](Value operand, Value, Type resultType) {
+          auto floatTy = cast<FloatType>(operand.getType());
+          auto intWType = rewriter.getIntegerType(floatTy.getWidth());
+          Value operandBits = arith::ExtUIOp::create(
+              rewriter, loc, i64Type,
+              arith::BitcastOp::create(rewriter, loc, intWType, operand));
+          // Call APFloat function.
+          Value semValue = getAPFloatSemanticsValue(rewriter, loc, floatTy);
+          SmallVector<Value> params = {semValue, operandBits};
+          Value negatedBits =
+              func::CallOp::create(rewriter, loc, TypeRange(i64Type),
+                                   SymbolRefAttr::get(*fn), params)
+                  ->getResult(0);
+          // Truncate result to the original width.
+          auto truncatedBits =
+              arith::TruncIOp::create(rewriter, loc, intWType, negatedBits);
+          return arith::BitcastOp::create(rewriter, loc, floatTy,
+                                          truncatedBits);
+        });
+
+    rewriter.replaceOp(op, repl);
     return success();
   }
 
   SymbolOpInterface symTable;
+  ArrayRef<Type> sourceTypes;
 };
 
 template <typename OpTy>
 struct IsOpToAPFloatConversion final : OpRewritePattern<OpTy> {
   IsOpToAPFloatConversion(MLIRContext *context, const char *APFloatName,
                           SymbolOpInterface symTable,
+                          ArrayRef<Type> sourceTypes,
                           PatternBenefit benefit = 1)
       : OpRewritePattern<OpTy>(context, benefit), symTable(symTable),
-        APFloatName(APFloatName) {};
+        APFloatName(APFloatName), sourceTypes(sourceTypes) {};
 
   LogicalResult matchAndRewrite(OpTy op,
                                 PatternRewriter &rewriter) const override {
-    // Cast operands to 64-bit integers.
-    auto operand = op.getOperand();
-    auto floatTy = dyn_cast<FloatType>(operand.getType());
-    if (!floatTy)
-      return rewriter.notifyMatchFailure(op,
-                                         "only scalar FloatTypes supported");
-    if (floatTy.getIntOrFloatBitWidth() > 64) {
-      return rewriter.notifyMatchFailure(op,
-                                         "bitwidth > 64 bits is not supported");
-    }
+    if (failed(checkPreconditions(rewriter, op, sourceTypes)))
+      return failure();
     // Get APFloat function from runtime library.
     auto i1 = IntegerType::get(symTable->getContext(), 1);
     auto i32Type = IntegerType::get(symTable->getContext(), 32);
@@ -...
[truncated]

``````````

</details>


https://github.com/llvm/llvm-project/pull/172715


More information about the Mlir-commits mailing list