[Mlir-commits] [mlir] [mlir][math] Add vector support for math-to-apfloat (PR #172715)
Maksim Levental
llvmlistbot at llvm.org
Fri Jan 9 10:47:51 PST 2026
https://github.com/makslevental updated https://github.com/llvm/llvm-project/pull/172715
>From 356f1519aa2bc393c07e4f5fbb38020d88a6705c Mon Sep 17 00:00:00 2001
From: makslevental <maksim.levental at gmail.com>
Date: Wed, 17 Dec 2025 10:53:55 -0800
Subject: [PATCH 1/3] [mlir][math] Add vector support for math-to-apfloat
---
.../ArithAndMathToAPFloat/ArithToAPFloat.cpp | 67 -------------------
.../ArithAndMathToAPFloat/Utils.cpp | 25 ++++++-
.../Conversion/ArithAndMathToAPFloat/Utils.h | 56 ++++++++++++++++
3 files changed, 79 insertions(+), 69 deletions(-)
diff --git a/mlir/lib/Conversion/ArithAndMathToAPFloat/ArithToAPFloat.cpp b/mlir/lib/Conversion/ArithAndMathToAPFloat/ArithToAPFloat.cpp
index 813a854f2fc97..98185697e4591 100644
--- a/mlir/lib/Conversion/ArithAndMathToAPFloat/ArithToAPFloat.cpp
+++ b/mlir/lib/Conversion/ArithAndMathToAPFloat/ArithToAPFloat.cpp
@@ -46,73 +46,6 @@ lookupOrCreateBinaryFn(OpBuilder &b, SymbolOpInterface symTable, StringRef name,
{i32Type, i64Type, i64Type}, symbolTables);
}
-/// Given two operands of vector type and vector result type (with the same
-/// shape), call the given function for each pair of scalar operands and
-/// package the result into a vector. If the given operands and result type are
-/// not vectors, call the function directly. The second operand is optional.
-template <typename Fn, typename... Values>
-static Value forEachScalarValue(RewriterBase &rewriter, Location loc,
- Value operand1, Value operand2, Type resultType,
- Fn fn) {
- auto vecTy1 = dyn_cast<VectorType>(operand1.getType());
- if (operand2) {
- // Sanity check: Operand types must match.
- assert(vecTy1 == dyn_cast<VectorType>(operand2.getType()) &&
- "expected same vector types");
- }
- if (!vecTy1) {
- // Not a vector. Call the function directly.
- return fn(operand1, operand2, resultType);
- }
-
- // Prepare scalar operands.
- ResultRange sclars1 =
- vector::ToElementsOp::create(rewriter, loc, operand1)->getResults();
- SmallVector<Value> scalars2;
- if (!operand2) {
- // No second operand. Create a vector of empty values.
- scalars2.assign(vecTy1.getNumElements(), Value());
- } else {
- llvm::append_range(
- scalars2,
- vector::ToElementsOp::create(rewriter, loc, operand2)->getResults());
- }
-
- // Call the function for each pair of scalar operands.
- auto resultVecType = cast<VectorType>(resultType);
- SmallVector<Value> results;
- for (auto [scalar1, scalar2] : llvm::zip_equal(sclars1, scalars2)) {
- Value result = fn(scalar1, scalar2, resultVecType.getElementType());
- results.push_back(result);
- }
-
- // Package the results into a vector.
- return vector::FromElementsOp::create(
- rewriter, loc,
- vecTy1.cloneWith(/*shape=*/std::nullopt, results.front().getType()),
- results);
-}
-
-/// Check preconditions for the conversion:
-/// 1. All operands / results must be integers or floats (or vectors thereof).
-/// 2. The bitwidth of the operands / results must be <= 64.
-static LogicalResult checkPreconditions(RewriterBase &rewriter, Operation *op) {
- for (Value value : llvm::concat<Value>(op->getOperands(), op->getResults())) {
- Type type = value.getType();
- if (auto vecTy = dyn_cast<VectorType>(type)) {
- type = vecTy.getElementType();
- }
- if (!type.isIntOrFloat()) {
- return rewriter.notifyMatchFailure(
- op, "only integers and floats (or vectors thereof) are supported");
- }
- if (type.getIntOrFloatBitWidth() > 64)
- return rewriter.notifyMatchFailure(op,
- "bitwidth > 64 bits is not supported");
- }
- return success();
-}
-
/// Rewrite a binary arithmetic operation to an APFloat function call.
template <typename OpTy>
struct BinaryArithOpToAPFloatConversion final : OpRewritePattern<OpTy> {
diff --git a/mlir/lib/Conversion/ArithAndMathToAPFloat/Utils.cpp b/mlir/lib/Conversion/ArithAndMathToAPFloat/Utils.cpp
index 2b5857367dc40..340e015404d86 100644
--- a/mlir/lib/Conversion/ArithAndMathToAPFloat/Utils.cpp
+++ b/mlir/lib/Conversion/ArithAndMathToAPFloat/Utils.cpp
@@ -9,14 +9,35 @@
#include "Utils.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinTypeInterfaces.h"
#include "mlir/IR/Location.h"
+#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/Value.h"
-mlir::Value mlir::getAPFloatSemanticsValue(OpBuilder &b, Location loc,
- FloatType floatTy) {
+using namespace mlir;
+
+Value mlir::getAPFloatSemanticsValue(OpBuilder &b, Location loc,
+ FloatType floatTy) {
int32_t sem = llvm::APFloatBase::SemanticsToEnum(floatTy.getFloatSemantics());
return arith::ConstantOp::create(b, loc, b.getI32Type(),
b.getIntegerAttr(b.getI32Type(), sem));
}
+
+LogicalResult mlir::checkPreconditions(RewriterBase &rewriter, Operation *op) {
+ for (Value value : llvm::concat<Value>(op->getOperands(), op->getResults())) {
+ Type type = value.getType();
+ if (auto vecTy = dyn_cast<VectorType>(type)) {
+ type = vecTy.getElementType();
+ }
+ if (!type.isIntOrFloat()) {
+ return rewriter.notifyMatchFailure(
+ op, "only integers and floats (or vectors thereof) are supported");
+ }
+ if (type.getIntOrFloatBitWidth() > 64)
+ return rewriter.notifyMatchFailure(op,
+ "bitwidth > 64 bits is not supported");
+ }
+ return success();
+}
diff --git a/mlir/lib/Conversion/ArithAndMathToAPFloat/Utils.h b/mlir/lib/Conversion/ArithAndMathToAPFloat/Utils.h
index 5f11d24261b43..d38d3b4c93945 100644
--- a/mlir/lib/Conversion/ArithAndMathToAPFloat/Utils.h
+++ b/mlir/lib/Conversion/ArithAndMathToAPFloat/Utils.h
@@ -9,6 +9,9 @@
#ifndef MLIR_CONVERSION_ARITHANDMATHTOAPFLOAT_UTILS_H_
#define MLIR_CONVERSION_ARITHANDMATHTOAPFLOAT_UTILS_H_
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/IR/PatternMatch.h"
+
namespace mlir {
class Value;
class OpBuilder;
@@ -16,6 +19,59 @@ class Location;
class FloatType;
Value getAPFloatSemanticsValue(OpBuilder &b, Location loc, FloatType floatTy);
+
+/// Given two operands of vector type and vector result type (with the same
+/// shape), call the given function for each pair of scalar operands and
+/// package the result into a vector. If the given operands and result type are
+/// not vectors, call the function directly. The second operand is optional.
+template <typename Fn, typename... Values>
+Value forEachScalarValue(mlir::RewriterBase &rewriter, Location loc,
+ Value operand1, Value operand2, Type resultType,
+ Fn fn) {
+ auto vecTy1 = dyn_cast<VectorType>(operand1.getType());
+ if (operand2) {
+ // Sanity check: Operand types must match.
+ assert(vecTy1 == dyn_cast<VectorType>(operand2.getType()) &&
+ "expected same vector types");
+ }
+ if (!vecTy1) {
+ // Not a vector. Call the function directly.
+ return fn(operand1, operand2, resultType);
+ }
+
+ // Prepare scalar operands.
+ ResultRange sclars1 =
+ vector::ToElementsOp::create(rewriter, loc, operand1)->getResults();
+ SmallVector<Value> scalars2;
+ if (!operand2) {
+ // No second operand. Fill the list with null Values, one per element.
+ scalars2.assign(vecTy1.getNumElements(), Value());
+ } else {
+ llvm::append_range(
+ scalars2,
+ vector::ToElementsOp::create(rewriter, loc, operand2)->getResults());
+ }
+
+ // Call the function for each pair of scalar operands.
+ auto resultVecType = cast<VectorType>(resultType);
+ SmallVector<Value> results;
+ for (auto [scalar1, scalar2] : llvm::zip_equal(sclars1, scalars2)) {
+ Value result = fn(scalar1, scalar2, resultVecType.getElementType());
+ results.push_back(result);
+ }
+
+ // Package the results into a vector.
+ return vector::FromElementsOp::create(
+ rewriter, loc,
+ vecTy1.cloneWith(/*shape=*/std::nullopt, results.front().getType()),
+ results);
+}
+
+/// Check preconditions for the conversion:
+/// 1. All operands / results must be integers or floats (or vectors thereof).
+/// 2. The bitwidth of the operands / results must be <= 64.
+LogicalResult checkPreconditions(RewriterBase &rewriter, Operation *op);
+
} // namespace mlir
#endif // MLIR_CONVERSION_ARITHANDMATHTOAPFLOAT_UTILS_H_
>From ddbcffafcf762a6357ed2e5b9cb3fe58e182e4ea Mon Sep 17 00:00:00 2001
From: makslevental <maksim.levental at gmail.com>
Date: Wed, 17 Dec 2025 15:40:12 -0800
Subject: [PATCH 2/3] vectorize isop and abs (but not tests)
---
.../ArithAndMathToAPFloat/MathToAPFloat.cpp | 93 +++++++++----------
1 file changed, 46 insertions(+), 47 deletions(-)
diff --git a/mlir/lib/Conversion/ArithAndMathToAPFloat/MathToAPFloat.cpp b/mlir/lib/Conversion/ArithAndMathToAPFloat/MathToAPFloat.cpp
index 784028f5cf2eb..fad59b6a4530f 100644
--- a/mlir/lib/Conversion/ArithAndMathToAPFloat/MathToAPFloat.cpp
+++ b/mlir/lib/Conversion/ArithAndMathToAPFloat/MathToAPFloat.cpp
@@ -33,16 +33,8 @@ struct AbsFOpToAPFloatConversion final : OpRewritePattern<math::AbsFOp> {
LogicalResult matchAndRewrite(math::AbsFOp op,
PatternRewriter &rewriter) const override {
- // Cast operands to 64-bit integers.
- auto operand = op.getOperand();
- auto floatTy = dyn_cast<FloatType>(operand.getType());
- if (!floatTy)
- return rewriter.notifyMatchFailure(op,
- "only scalar FloatTypes supported");
- if (floatTy.getIntOrFloatBitWidth() > 64) {
- return rewriter.notifyMatchFailure(op,
- "bitwidth > 64 bits is not supported");
- }
+ if (failed(checkPreconditions(rewriter, op)))
+ return failure();
// Get APFloat function from runtime library.
auto i32Type = IntegerType::get(symTable->getContext(), 32);
auto i64Type = IntegerType::get(symTable->getContext(), 64);
@@ -52,23 +44,30 @@ struct AbsFOpToAPFloatConversion final : OpRewritePattern<math::AbsFOp> {
return fn;
Location loc = op.getLoc();
rewriter.setInsertionPoint(op);
- auto intWType = rewriter.getIntegerType(floatTy.getWidth());
- Value operandBits = arith::ExtUIOp::create(
- rewriter, loc, i64Type,
- arith::BitcastOp::create(rewriter, loc, intWType, operand));
-
- // Call APFloat function.
- Value semValue = getAPFloatSemanticsValue(rewriter, loc, floatTy);
- SmallVector<Value> params = {semValue, operandBits};
- Value negatedBits = func::CallOp::create(rewriter, loc, TypeRange(i64Type),
- SymbolRefAttr::get(*fn), params)
- ->getResult(0);
-
- // Truncate result to the original width.
- Value truncatedBits =
- arith::TruncIOp::create(rewriter, loc, intWType, negatedBits);
- rewriter.replaceOp(
- op, arith::BitcastOp::create(rewriter, loc, floatTy, truncatedBits));
+ // Scalarize and convert to APFloat runtime calls.
+ Value repl = forEachScalarValue(
+ rewriter, loc, op.getOperand(), /*operand2=*/Value(), op.getType(),
+ [&](Value operand, Value, Type resultType) {
+ auto floatTy = cast<FloatType>(operand.getType());
+ auto intWType = rewriter.getIntegerType(floatTy.getWidth());
+ Value operandBits = arith::ExtUIOp::create(
+ rewriter, loc, i64Type,
+ arith::BitcastOp::create(rewriter, loc, intWType, operand));
+ // Call APFloat function.
+ Value semValue = getAPFloatSemanticsValue(rewriter, loc, floatTy);
+ SmallVector<Value> params = {semValue, operandBits};
+ Value negatedBits =
+ func::CallOp::create(rewriter, loc, TypeRange(i64Type),
+ SymbolRefAttr::get(*fn), params)
+ ->getResult(0);
+ // Truncate result to the original width.
+ auto truncatedBits =
+ arith::TruncIOp::create(rewriter, loc, intWType, negatedBits);
+ return arith::BitcastOp::create(rewriter, loc, floatTy,
+ truncatedBits);
+ });
+
+ rewriter.replaceOp(op, repl);
return success();
}
@@ -85,16 +84,8 @@ struct IsOpToAPFloatConversion final : OpRewritePattern<OpTy> {
LogicalResult matchAndRewrite(OpTy op,
PatternRewriter &rewriter) const override {
- // Cast operands to 64-bit integers.
- auto operand = op.getOperand();
- auto floatTy = dyn_cast<FloatType>(operand.getType());
- if (!floatTy)
- return rewriter.notifyMatchFailure(op,
- "only scalar FloatTypes supported");
- if (floatTy.getIntOrFloatBitWidth() > 64) {
- return rewriter.notifyMatchFailure(op,
- "bitwidth > 64 bits is not supported");
- }
+ if (failed(checkPreconditions(rewriter, op)))
+ return failure();
// Get APFloat function from runtime library.
auto i1 = IntegerType::get(symTable->getContext(), 1);
auto i32Type = IntegerType::get(symTable->getContext(), 32);
@@ -107,16 +98,24 @@ struct IsOpToAPFloatConversion final : OpRewritePattern<OpTy> {
return fn;
Location loc = op.getLoc();
rewriter.setInsertionPoint(op);
- auto intWType = rewriter.getIntegerType(floatTy.getWidth());
- Value operandBits = arith::ExtUIOp::create(
- rewriter, loc, i64Type,
- arith::BitcastOp::create(rewriter, loc, intWType, operand));
-
- // Call APFloat function.
- Value semValue = getAPFloatSemanticsValue(rewriter, loc, floatTy);
- SmallVector<Value> params = {semValue, operandBits};
- rewriter.replaceOpWithNewOp<func::CallOp>(op, TypeRange(i1),
- SymbolRefAttr::get(*fn), params);
+ // Scalarize and convert to APFloat runtime calls.
+ Value repl = forEachScalarValue(
+ rewriter, loc, op.getOperand(), /*operand2=*/Value(), op.getType(),
+ [&](Value operand, Value, Type resultType) {
+ auto floatTy = dyn_cast<FloatType>(operand.getType());
+ auto intWType = rewriter.getIntegerType(floatTy.getWidth());
+ Value operandBits = arith::ExtUIOp::create(
+ rewriter, loc, i64Type,
+ arith::BitcastOp::create(rewriter, loc, intWType, operand));
+
+ // Call APFloat function.
+ Value semValue = getAPFloatSemanticsValue(rewriter, loc, floatTy);
+ SmallVector<Value> params = {semValue, operandBits};
+ return func::CallOp::create(rewriter, loc, TypeRange(i1),
+ SymbolRefAttr::get(*fn), params)
+ .getResult(0);
+ });
+ rewriter.replaceOp(op, repl);
return success();
}
>From a1c003e74c02d650a24c97185ff293a24baaa939 Mon Sep 17 00:00:00 2001
From: makslevental <maksim.levental at gmail.com>
Date: Fri, 9 Jan 2026 10:19:00 -0800
Subject: [PATCH 3/3] vectorize fma
---
.../ArithAndMathToAPFloat/MathToAPFloat.cpp | 72 +++++++++++++------
1 file changed, 50 insertions(+), 22 deletions(-)
diff --git a/mlir/lib/Conversion/ArithAndMathToAPFloat/MathToAPFloat.cpp b/mlir/lib/Conversion/ArithAndMathToAPFloat/MathToAPFloat.cpp
index fad59b6a4530f..172af125e959e 100644
--- a/mlir/lib/Conversion/ArithAndMathToAPFloat/MathToAPFloat.cpp
+++ b/mlir/lib/Conversion/ArithAndMathToAPFloat/MathToAPFloat.cpp
@@ -130,16 +130,10 @@ struct FmaOpToAPFloatConversion final : OpRewritePattern<math::FmaOp> {
LogicalResult matchAndRewrite(math::FmaOp op,
PatternRewriter &rewriter) const override {
+ if (failed(checkPreconditions(rewriter, op)))
+ return failure();
// Cast operands to 64-bit integers.
auto floatTy = cast<FloatType>(op.getResult().getType());
- if (!floatTy)
- return rewriter.notifyMatchFailure(op,
- "only scalar FloatTypes supported");
- if (floatTy.getIntOrFloatBitWidth() > 64) {
- return rewriter.notifyMatchFailure(op,
- "bitwidth > 64 bits is not supported");
- }
-
auto i32Type = IntegerType::get(symTable->getContext(), 32);
auto i64Type = IntegerType::get(symTable->getContext(), 64);
FailureOr<FuncOp> fn = lookupOrCreateFnDecl(
@@ -150,8 +144,52 @@ struct FmaOpToAPFloatConversion final : OpRewritePattern<math::FmaOp> {
Location loc = op.getLoc();
rewriter.setInsertionPoint(op);
- auto intWType = rewriter.getIntegerType(floatTy.getWidth());
- auto int64Type = rewriter.getI64Type();
+ IntegerType intWType = rewriter.getIntegerType(floatTy.getWidth());
+ IntegerType int64Type = rewriter.getI64Type();
+
+ auto scalarFMA = [&rewriter, &loc, &floatTy, &fn, &intWType](
+ Value operand, Value multiplicand, Value addend) {
+ // Call APFloat function.
+ Value semValue = getAPFloatSemanticsValue(rewriter, loc, floatTy);
+ SmallVector<Value> params = {semValue, operand, multiplicand, addend};
+ auto resultOp =
+ func::CallOp::create(rewriter, loc, TypeRange(rewriter.getI64Type()),
+ SymbolRefAttr::get(*fn), params);
+
+ // Truncate result to the original width.
+ auto trunc = arith::TruncIOp::create(rewriter, loc, intWType,
+ resultOp->getResult(0));
+ return arith::BitcastOp::create(rewriter, loc, floatTy, trunc);
+ };
+
+ if (VectorType vecTy1 = dyn_cast<VectorType>(op.getA().getType())) {
+ // Sanity check: Operand types must match.
+ assert(vecTy1 == dyn_cast<VectorType>(op.getB().getType()) &&
+ "expected same vector types");
+ assert(vecTy1 == dyn_cast<VectorType>(op.getC().getType()) &&
+ "expected same vector types");
+ // Prepare scalar operands.
+ ResultRange scalarOperands =
+ vector::ToElementsOp::create(rewriter, loc, op.getA())->getResults();
+ ResultRange scalarMultiplicands =
+ vector::ToElementsOp::create(rewriter, loc, op.getB())->getResults();
+ ResultRange scalarAddends =
+ vector::ToElementsOp::create(rewriter, loc, op.getC())->getResults();
+ // Call the function for each triple of scalar operands (a, b, c).
+ SmallVector<Value> results;
+ for (auto [operand, multiplicand, addend] : llvm::zip_equal(
+ scalarOperands, scalarMultiplicands, scalarAddends)) {
+ results.push_back(scalarFMA(operand, multiplicand, addend));
+ }
+ // Package the results into a vector.
+ auto fromElements = vector::FromElementsOp::create(
+ rewriter, loc,
+ vecTy1.cloneWith(/*shape=*/std::nullopt, results.front().getType()),
+ results);
+ rewriter.replaceOp(op, fromElements);
+ return success();
+ }
+
Value operand = arith::ExtUIOp::create(
rewriter, loc, int64Type,
arith::BitcastOp::create(rewriter, loc, intWType, op.getA()));
@@ -161,18 +199,8 @@ struct FmaOpToAPFloatConversion final : OpRewritePattern<math::FmaOp> {
Value addend = arith::ExtUIOp::create(
rewriter, loc, int64Type,
arith::BitcastOp::create(rewriter, loc, intWType, op.getC()));
-
- // Call APFloat function.
- Value semValue = getAPFloatSemanticsValue(rewriter, loc, floatTy);
- SmallVector<Value> params = {semValue, operand, multiplicand, addend};
- auto resultOp =
- func::CallOp::create(rewriter, loc, TypeRange(rewriter.getI64Type()),
- SymbolRefAttr::get(*fn), params);
-
- // Truncate result to the original width.
- Value truncatedBits = arith::TruncIOp::create(rewriter, loc, intWType,
- resultOp->getResult(0));
- rewriter.replaceOpWithNewOp<arith::BitcastOp>(op, floatTy, truncatedBits);
+ Value repl = scalarFMA(operand, multiplicand, addend);
+ rewriter.replaceOp(op, repl);
return success();
}
More information about the Mlir-commits
mailing list