[Mlir-commits] [mlir] [mlir][math] Add vector support for math-to-apfloat (PR #172715)

Maksim Levental llvmlistbot at llvm.org
Fri Jan 9 10:47:51 PST 2026


https://github.com/makslevental updated https://github.com/llvm/llvm-project/pull/172715

>From 356f1519aa2bc393c07e4f5fbb38020d88a6705c Mon Sep 17 00:00:00 2001
From: makslevental <maksim.levental at gmail.com>
Date: Wed, 17 Dec 2025 10:53:55 -0800
Subject: [PATCH 1/3] [mlir][math] Add vector support for math-to-apfloat

---
 .../ArithAndMathToAPFloat/ArithToAPFloat.cpp  | 67 -------------------
 .../ArithAndMathToAPFloat/Utils.cpp           | 25 ++++++-
 .../Conversion/ArithAndMathToAPFloat/Utils.h  | 56 ++++++++++++++++
 3 files changed, 79 insertions(+), 69 deletions(-)

diff --git a/mlir/lib/Conversion/ArithAndMathToAPFloat/ArithToAPFloat.cpp b/mlir/lib/Conversion/ArithAndMathToAPFloat/ArithToAPFloat.cpp
index 813a854f2fc97..98185697e4591 100644
--- a/mlir/lib/Conversion/ArithAndMathToAPFloat/ArithToAPFloat.cpp
+++ b/mlir/lib/Conversion/ArithAndMathToAPFloat/ArithToAPFloat.cpp
@@ -46,73 +46,6 @@ lookupOrCreateBinaryFn(OpBuilder &b, SymbolOpInterface symTable, StringRef name,
                               {i32Type, i64Type, i64Type}, symbolTables);
 }
 
-/// Given two operands of vector type and vector result type (with the same
-/// shape), call the given function for each pair of scalar operands and
-/// package the result into a vector. If the given operands and result type are
-/// not vectors, call the function directly. The second operand is optional.
-template <typename Fn, typename... Values>
-static Value forEachScalarValue(RewriterBase &rewriter, Location loc,
-                                Value operand1, Value operand2, Type resultType,
-                                Fn fn) {
-  auto vecTy1 = dyn_cast<VectorType>(operand1.getType());
-  if (operand2) {
-    // Sanity check: Operand types must match.
-    assert(vecTy1 == dyn_cast<VectorType>(operand2.getType()) &&
-           "expected same vector types");
-  }
-  if (!vecTy1) {
-    // Not a vector. Call the function directly.
-    return fn(operand1, operand2, resultType);
-  }
-
-  // Prepare scalar operands.
-  ResultRange sclars1 =
-      vector::ToElementsOp::create(rewriter, loc, operand1)->getResults();
-  SmallVector<Value> scalars2;
-  if (!operand2) {
-    // No second operand. Create a vector of empty values.
-    scalars2.assign(vecTy1.getNumElements(), Value());
-  } else {
-    llvm::append_range(
-        scalars2,
-        vector::ToElementsOp::create(rewriter, loc, operand2)->getResults());
-  }
-
-  // Call the function for each pair of scalar operands.
-  auto resultVecType = cast<VectorType>(resultType);
-  SmallVector<Value> results;
-  for (auto [scalar1, scalar2] : llvm::zip_equal(sclars1, scalars2)) {
-    Value result = fn(scalar1, scalar2, resultVecType.getElementType());
-    results.push_back(result);
-  }
-
-  // Package the results into a vector.
-  return vector::FromElementsOp::create(
-      rewriter, loc,
-      vecTy1.cloneWith(/*shape=*/std::nullopt, results.front().getType()),
-      results);
-}
-
-/// Check preconditions for the conversion:
-/// 1. All operands / results must be integers or floats (or vectors thereof).
-/// 2. The bitwidth of the operands / results must be <= 64.
-static LogicalResult checkPreconditions(RewriterBase &rewriter, Operation *op) {
-  for (Value value : llvm::concat<Value>(op->getOperands(), op->getResults())) {
-    Type type = value.getType();
-    if (auto vecTy = dyn_cast<VectorType>(type)) {
-      type = vecTy.getElementType();
-    }
-    if (!type.isIntOrFloat()) {
-      return rewriter.notifyMatchFailure(
-          op, "only integers and floats (or vectors thereof) are supported");
-    }
-    if (type.getIntOrFloatBitWidth() > 64)
-      return rewriter.notifyMatchFailure(op,
-                                         "bitwidth > 64 bits is not supported");
-  }
-  return success();
-}
-
 /// Rewrite a binary arithmetic operation to an APFloat function call.
 template <typename OpTy>
 struct BinaryArithOpToAPFloatConversion final : OpRewritePattern<OpTy> {
diff --git a/mlir/lib/Conversion/ArithAndMathToAPFloat/Utils.cpp b/mlir/lib/Conversion/ArithAndMathToAPFloat/Utils.cpp
index 2b5857367dc40..340e015404d86 100644
--- a/mlir/lib/Conversion/ArithAndMathToAPFloat/Utils.cpp
+++ b/mlir/lib/Conversion/ArithAndMathToAPFloat/Utils.cpp
@@ -9,14 +9,35 @@
 #include "Utils.h"
 
 #include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
 #include "mlir/IR/Builders.h"
 #include "mlir/IR/BuiltinTypeInterfaces.h"
 #include "mlir/IR/Location.h"
+#include "mlir/IR/PatternMatch.h"
 #include "mlir/IR/Value.h"
 
-mlir::Value mlir::getAPFloatSemanticsValue(OpBuilder &b, Location loc,
-                                           FloatType floatTy) {
+using namespace mlir;
+
+Value mlir::getAPFloatSemanticsValue(OpBuilder &b, Location loc,
+                                     FloatType floatTy) {
   int32_t sem = llvm::APFloatBase::SemanticsToEnum(floatTy.getFloatSemantics());
   return arith::ConstantOp::create(b, loc, b.getI32Type(),
                                    b.getIntegerAttr(b.getI32Type(), sem));
 }
+
+LogicalResult mlir::checkPreconditions(RewriterBase &rewriter, Operation *op) {
+  for (Value value : llvm::concat<Value>(op->getOperands(), op->getResults())) {
+    Type type = value.getType();
+    if (auto vecTy = dyn_cast<VectorType>(type)) { // Check the element type.
+      type = vecTy.getElementType();
+    }
+    if (!type.isIntOrFloat()) {
+      return rewriter.notifyMatchFailure(
+          op, "only integers and floats (or vectors thereof) are supported");
+    }
+    if (type.getIntOrFloatBitWidth() > 64) // Callers widen values to i64.
+      return rewriter.notifyMatchFailure(op,
+                                         "bitwidth > 64 bits is not supported");
+  }
+  return success(); // All operands and results are supported.
+}
diff --git a/mlir/lib/Conversion/ArithAndMathToAPFloat/Utils.h b/mlir/lib/Conversion/ArithAndMathToAPFloat/Utils.h
index 5f11d24261b43..d38d3b4c93945 100644
--- a/mlir/lib/Conversion/ArithAndMathToAPFloat/Utils.h
+++ b/mlir/lib/Conversion/ArithAndMathToAPFloat/Utils.h
@@ -9,6 +9,9 @@
 #ifndef MLIR_CONVERSION_ARITHANDMATHTOAPFLOAT_UTILS_H_
 #define MLIR_CONVERSION_ARITHANDMATHTOAPFLOAT_UTILS_H_
 
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/IR/PatternMatch.h"
+
 namespace mlir {
 class Value;
 class OpBuilder;
@@ -16,6 +19,59 @@ class Location;
 class FloatType;
 
 Value getAPFloatSemanticsValue(OpBuilder &b, Location loc, FloatType floatTy);
+
+/// Given two operands of vector type and vector result type (with the same
+/// shape), call the given function for each pair of scalar operands and
+/// package the result into a vector. If the given operands and result type are
+/// not vectors, call `fn` directly. `operand2` may be null (unary ops).
+template <typename Fn>
+Value forEachScalarValue(RewriterBase &rewriter, Location loc,
+                         Value operand1, Value operand2, Type resultType,
+                         Fn fn) {
+  auto vecTy1 = dyn_cast<VectorType>(operand1.getType());
+  if (operand2) {
+    // Sanity check: Operand types must match.
+    assert(vecTy1 == dyn_cast<VectorType>(operand2.getType()) &&
+           "expected same vector types");
+  }
+  if (!vecTy1) {
+    // Not a vector. Call the function directly.
+    return fn(operand1, operand2, resultType);
+  }
+
+  // Prepare scalar operands.
+  ResultRange scalars1 =
+      vector::ToElementsOp::create(rewriter, loc, operand1)->getResults();
+  SmallVector<Value> scalars2;
+  if (!operand2) {
+    // No second operand: pass a null Value alongside every element.
+    scalars2.assign(vecTy1.getNumElements(), Value());
+  } else {
+    llvm::append_range(
+        scalars2,
+        vector::ToElementsOp::create(rewriter, loc, operand2)->getResults());
+  }
+
+  // Call the function for each pair of scalar operands.
+  auto resultVecType = cast<VectorType>(resultType);
+  SmallVector<Value> results;
+  for (auto [scalar1, scalar2] : llvm::zip_equal(scalars1, scalars2)) {
+    Value result = fn(scalar1, scalar2, resultVecType.getElementType());
+    results.push_back(result);
+  }
+
+  // Package the results into a vector.
+  return vector::FromElementsOp::create(
+      rewriter, loc,
+      vecTy1.cloneWith(/*shape=*/std::nullopt, results.front().getType()),
+      results);
+}
+
+/// Check preconditions for the conversion:
+/// 1. All operands / results must be integers or floats (or vectors thereof).
+/// 2. The bitwidth of the operands / results must be <= 64.
+LogicalResult checkPreconditions(RewriterBase &rewriter, Operation *op);
+
 } // namespace mlir
 
 #endif // MLIR_CONVERSION_ARITHANDMATHTOAPFLOAT_UTILS_H_

>From ddbcffafcf762a6357ed2e5b9cb3fe58e182e4ea Mon Sep 17 00:00:00 2001
From: makslevental <maksim.levental at gmail.com>
Date: Wed, 17 Dec 2025 15:40:12 -0800
Subject: [PATCH 2/3] vectorize IsOp and AbsF lowerings (tests not yet updated)

---
 .../ArithAndMathToAPFloat/MathToAPFloat.cpp   | 93 +++++++++----------
 1 file changed, 46 insertions(+), 47 deletions(-)

diff --git a/mlir/lib/Conversion/ArithAndMathToAPFloat/MathToAPFloat.cpp b/mlir/lib/Conversion/ArithAndMathToAPFloat/MathToAPFloat.cpp
index 784028f5cf2eb..fad59b6a4530f 100644
--- a/mlir/lib/Conversion/ArithAndMathToAPFloat/MathToAPFloat.cpp
+++ b/mlir/lib/Conversion/ArithAndMathToAPFloat/MathToAPFloat.cpp
@@ -33,16 +33,8 @@ struct AbsFOpToAPFloatConversion final : OpRewritePattern<math::AbsFOp> {
 
   LogicalResult matchAndRewrite(math::AbsFOp op,
                                 PatternRewriter &rewriter) const override {
-    // Cast operands to 64-bit integers.
-    auto operand = op.getOperand();
-    auto floatTy = dyn_cast<FloatType>(operand.getType());
-    if (!floatTy)
-      return rewriter.notifyMatchFailure(op,
-                                         "only scalar FloatTypes supported");
-    if (floatTy.getIntOrFloatBitWidth() > 64) {
-      return rewriter.notifyMatchFailure(op,
-                                         "bitwidth > 64 bits is not supported");
-    }
+    if (failed(checkPreconditions(rewriter, op)))
+      return failure();
     // Get APFloat function from runtime library.
     auto i32Type = IntegerType::get(symTable->getContext(), 32);
     auto i64Type = IntegerType::get(symTable->getContext(), 64);
@@ -52,23 +44,30 @@ struct AbsFOpToAPFloatConversion final : OpRewritePattern<math::AbsFOp> {
       return fn;
     Location loc = op.getLoc();
     rewriter.setInsertionPoint(op);
-    auto intWType = rewriter.getIntegerType(floatTy.getWidth());
-    Value operandBits = arith::ExtUIOp::create(
-        rewriter, loc, i64Type,
-        arith::BitcastOp::create(rewriter, loc, intWType, operand));
-
-    // Call APFloat function.
-    Value semValue = getAPFloatSemanticsValue(rewriter, loc, floatTy);
-    SmallVector<Value> params = {semValue, operandBits};
-    Value negatedBits = func::CallOp::create(rewriter, loc, TypeRange(i64Type),
-                                             SymbolRefAttr::get(*fn), params)
-                            ->getResult(0);
-
-    // Truncate result to the original width.
-    Value truncatedBits =
-        arith::TruncIOp::create(rewriter, loc, intWType, negatedBits);
-    rewriter.replaceOp(
-        op, arith::BitcastOp::create(rewriter, loc, floatTy, truncatedBits));
+    // Scalarize and convert to APFloat runtime calls.
+    Value repl = forEachScalarValue(
+        rewriter, loc, op.getOperand(), /*operand2=*/Value(), op.getType(),
+        [&](Value operand, Value /*operand2*/, Type /*resultType*/) -> Value {
+          auto floatTy = cast<FloatType>(operand.getType());
+          auto intWType = rewriter.getIntegerType(floatTy.getWidth());
+          Value operandBits = arith::ExtUIOp::create(
+              rewriter, loc, i64Type,
+              arith::BitcastOp::create(rewriter, loc, intWType, operand));
+          // Call the APFloat abs function on the i64 bit pattern.
+          Value semValue = getAPFloatSemanticsValue(rewriter, loc, floatTy);
+          SmallVector<Value> params = {semValue, operandBits};
+          Value absBits =
+              func::CallOp::create(rewriter, loc, TypeRange(i64Type),
+                                   SymbolRefAttr::get(*fn), params)
+                  ->getResult(0);
+          // Truncate result to the original width.
+          Value truncatedBits =
+              arith::TruncIOp::create(rewriter, loc, intWType, absBits);
+          return arith::BitcastOp::create(rewriter, loc, floatTy,
+                                          truncatedBits);
+        });
+
+    rewriter.replaceOp(op, repl);
     return success();
   }
 
@@ -85,16 +84,8 @@ struct IsOpToAPFloatConversion final : OpRewritePattern<OpTy> {
 
   LogicalResult matchAndRewrite(OpTy op,
                                 PatternRewriter &rewriter) const override {
-    // Cast operands to 64-bit integers.
-    auto operand = op.getOperand();
-    auto floatTy = dyn_cast<FloatType>(operand.getType());
-    if (!floatTy)
-      return rewriter.notifyMatchFailure(op,
-                                         "only scalar FloatTypes supported");
-    if (floatTy.getIntOrFloatBitWidth() > 64) {
-      return rewriter.notifyMatchFailure(op,
-                                         "bitwidth > 64 bits is not supported");
-    }
+    if (failed(checkPreconditions(rewriter, op)))
+      return failure();
     // Get APFloat function from runtime library.
     auto i1 = IntegerType::get(symTable->getContext(), 1);
     auto i32Type = IntegerType::get(symTable->getContext(), 32);
@@ -107,16 +98,24 @@ struct IsOpToAPFloatConversion final : OpRewritePattern<OpTy> {
       return fn;
     Location loc = op.getLoc();
     rewriter.setInsertionPoint(op);
-    auto intWType = rewriter.getIntegerType(floatTy.getWidth());
-    Value operandBits = arith::ExtUIOp::create(
-        rewriter, loc, i64Type,
-        arith::BitcastOp::create(rewriter, loc, intWType, operand));
-
-    // Call APFloat function.
-    Value semValue = getAPFloatSemanticsValue(rewriter, loc, floatTy);
-    SmallVector<Value> params = {semValue, operandBits};
-    rewriter.replaceOpWithNewOp<func::CallOp>(op, TypeRange(i1),
-                                              SymbolRefAttr::get(*fn), params);
+    // Scalarize and convert to APFloat runtime calls.
+    Value repl = forEachScalarValue(
+        rewriter, loc, op.getOperand(), /*operand2=*/Value(), op.getType(),
+        [&](Value operand, Value /*operand2*/, Type /*resultType*/) -> Value {
+          auto floatTy = cast<FloatType>(operand.getType());
+          auto intWType = rewriter.getIntegerType(floatTy.getWidth());
+          Value operandBits = arith::ExtUIOp::create(
+              rewriter, loc, i64Type,
+              arith::BitcastOp::create(rewriter, loc, intWType, operand));
+          // cast<> rather than dyn_cast<>: floatTy is dereferenced above.
+          // Call APFloat function.
+          Value semValue = getAPFloatSemanticsValue(rewriter, loc, floatTy);
+          SmallVector<Value> params = {semValue, operandBits};
+          return func::CallOp::create(rewriter, loc, TypeRange(i1),
+                                      SymbolRefAttr::get(*fn), params)
+              .getResult(0);
+        });
+    rewriter.replaceOp(op, repl);
     return success();
   }
 

>From a1c003e74c02d650a24c97185ff293a24baaa939 Mon Sep 17 00:00:00 2001
From: makslevental <maksim.levental at gmail.com>
Date: Fri, 9 Jan 2026 10:19:00 -0800
Subject: [PATCH 3/3] vectorize FmaOp lowering

---
 .../ArithAndMathToAPFloat/MathToAPFloat.cpp   | 93 +++++++++++++++-----
 1 file changed, 69 insertions(+), 24 deletions(-)

diff --git a/mlir/lib/Conversion/ArithAndMathToAPFloat/MathToAPFloat.cpp b/mlir/lib/Conversion/ArithAndMathToAPFloat/MathToAPFloat.cpp
index fad59b6a4530f..172af125e959e 100644
--- a/mlir/lib/Conversion/ArithAndMathToAPFloat/MathToAPFloat.cpp
+++ b/mlir/lib/Conversion/ArithAndMathToAPFloat/MathToAPFloat.cpp
@@ -130,16 +130,14 @@ struct FmaOpToAPFloatConversion final : OpRewritePattern<math::FmaOp> {
 
   LogicalResult matchAndRewrite(math::FmaOp op,
                                 PatternRewriter &rewriter) const override {
-    // Cast operands to 64-bit integers.
-    auto floatTy = cast<FloatType>(op.getResult().getType());
-    if (!floatTy)
-      return rewriter.notifyMatchFailure(op,
-                                         "only scalar FloatTypes supported");
-    if (floatTy.getIntOrFloatBitWidth() > 64) {
-      return rewriter.notifyMatchFailure(op,
-                                         "bitwidth > 64 bits is not supported");
-    }
-
+    if (failed(checkPreconditions(rewriter, op)))
+      return failure();
+    // The APFloat semantics come from the (element) float type; the vector
+    // case below is lowered one element at a time.
+    Type elementTy = op.getResult().getType();
+    if (auto vecTy = dyn_cast<VectorType>(elementTy))
+      elementTy = vecTy.getElementType();
+    auto floatTy = cast<FloatType>(elementTy);
     auto i32Type = IntegerType::get(symTable->getContext(), 32);
     auto i64Type = IntegerType::get(symTable->getContext(), 64);
     FailureOr<FuncOp> fn = lookupOrCreateFnDecl(
@@ -150,8 +148,65 @@ struct FmaOpToAPFloatConversion final : OpRewritePattern<math::FmaOp> {
     Location loc = op.getLoc();
     rewriter.setInsertionPoint(op);
 
-    auto intWType = rewriter.getIntegerType(floatTy.getWidth());
-    auto int64Type = rewriter.getI64Type();
+    IntegerType intWType = rewriter.getIntegerType(floatTy.getWidth());
+    IntegerType int64Type = rewriter.getI64Type();
+
+    // Reinterpret a scalar float as its integer bit pattern and zero-extend it
+    // to i64, the operand type of the APFloat runtime function.
+    auto toI64Bits = [&rewriter, &loc, &intWType,
+                      &int64Type](Value scalar) -> Value {
+      return arith::ExtUIOp::create(
+          rewriter, loc, int64Type,
+          arith::BitcastOp::create(rewriter, loc, intWType, scalar));
+    };
+
+    auto scalarFMA = [&rewriter, &loc, &floatTy, &fn, &intWType](
+                         Value operand, Value multiplicand, Value addend) {
+      // Call APFloat function.
+      Value semValue = getAPFloatSemanticsValue(rewriter, loc, floatTy);
+      SmallVector<Value> params = {semValue, operand, multiplicand, addend};
+      auto resultOp =
+          func::CallOp::create(rewriter, loc, TypeRange(rewriter.getI64Type()),
+                               SymbolRefAttr::get(*fn), params);
+
+      // Truncate result to the original width.
+      auto trunc = arith::TruncIOp::create(rewriter, loc, intWType,
+                                           resultOp->getResult(0));
+      return arith::BitcastOp::create(rewriter, loc, floatTy, trunc);
+    };
+
+    if (VectorType vecTy1 = dyn_cast<VectorType>(op.getA().getType())) {
+      // Sanity check: Operand types must match.
+      assert(vecTy1 == dyn_cast<VectorType>(op.getB().getType()) &&
+             "expected same vector types");
+      assert(vecTy1 == dyn_cast<VectorType>(op.getC().getType()) &&
+             "expected same vector types");
+      // Prepare scalar operands.
+      ResultRange scalarOperands =
+          vector::ToElementsOp::create(rewriter, loc, op.getA())->getResults();
+      ResultRange scalarMultiplicands =
+          vector::ToElementsOp::create(rewriter, loc, op.getB())->getResults();
+      ResultRange scalarAddends =
+          vector::ToElementsOp::create(rewriter, loc, op.getC())->getResults();
+      // Call the function for each triple of scalar operands, converting the
+      // float elements to i64 bits one at a time so the IR order is fixed.
+      SmallVector<Value> results;
+      for (auto [operand, multiplicand, addend] : llvm::zip_equal(
+               scalarOperands, scalarMultiplicands, scalarAddends)) {
+        Value aBits = toI64Bits(operand);
+        Value bBits = toI64Bits(multiplicand);
+        Value cBits = toI64Bits(addend);
+        results.push_back(scalarFMA(aBits, bBits, cBits));
+      }
+      // Package the results into a vector.
+      auto fromElements = vector::FromElementsOp::create(
+          rewriter, loc,
+          vecTy1.cloneWith(/*shape=*/std::nullopt, results.front().getType()),
+          results);
+      rewriter.replaceOp(op, fromElements);
+      return success();
+    }
+
     Value operand = arith::ExtUIOp::create(
         rewriter, loc, int64Type,
         arith::BitcastOp::create(rewriter, loc, intWType, op.getA()));
@@ -161,18 +216,8 @@ struct FmaOpToAPFloatConversion final : OpRewritePattern<math::FmaOp> {
     Value addend = arith::ExtUIOp::create(
         rewriter, loc, int64Type,
         arith::BitcastOp::create(rewriter, loc, intWType, op.getC()));
-
-    // Call APFloat function.
-    Value semValue = getAPFloatSemanticsValue(rewriter, loc, floatTy);
-    SmallVector<Value> params = {semValue, operand, multiplicand, addend};
-    auto resultOp =
-        func::CallOp::create(rewriter, loc, TypeRange(rewriter.getI64Type()),
-                             SymbolRefAttr::get(*fn), params);
-
-    // Truncate result to the original width.
-    Value truncatedBits = arith::TruncIOp::create(rewriter, loc, intWType,
-                                                  resultOp->getResult(0));
-    rewriter.replaceOpWithNewOp<arith::BitcastOp>(op, floatTy, truncatedBits);
+    Value repl = scalarFMA(operand, multiplicand, addend);
+    rewriter.replaceOp(op, repl);
     return success();
   }
 



More information about the Mlir-commits mailing list