[Mlir-commits] [mlir] [mlir][math] Add vector support for math-to-apfloat (PR #172715)
Maksim Levental
llvmlistbot at llvm.org
Fri Jan 16 07:49:49 PST 2026
https://github.com/makslevental updated https://github.com/llvm/llvm-project/pull/172715
>From 356f1519aa2bc393c07e4f5fbb38020d88a6705c Mon Sep 17 00:00:00 2001
From: makslevental <maksim.levental at gmail.com>
Date: Wed, 17 Dec 2025 10:53:55 -0800
Subject: [PATCH 1/8] [mlir][math] Add vector support for math-to-apfloat
---
.../ArithAndMathToAPFloat/ArithToAPFloat.cpp | 67 -------------------
.../ArithAndMathToAPFloat/Utils.cpp | 25 ++++++-
.../Conversion/ArithAndMathToAPFloat/Utils.h | 56 ++++++++++++++++
3 files changed, 79 insertions(+), 69 deletions(-)
diff --git a/mlir/lib/Conversion/ArithAndMathToAPFloat/ArithToAPFloat.cpp b/mlir/lib/Conversion/ArithAndMathToAPFloat/ArithToAPFloat.cpp
index 813a854f2fc97..98185697e4591 100644
--- a/mlir/lib/Conversion/ArithAndMathToAPFloat/ArithToAPFloat.cpp
+++ b/mlir/lib/Conversion/ArithAndMathToAPFloat/ArithToAPFloat.cpp
@@ -46,73 +46,6 @@ lookupOrCreateBinaryFn(OpBuilder &b, SymbolOpInterface symTable, StringRef name,
{i32Type, i64Type, i64Type}, symbolTables);
}
-/// Given two operands of vector type and vector result type (with the same
-/// shape), call the given function for each pair of scalar operands and
-/// package the result into a vector. If the given operands and result type are
-/// not vectors, call the function directly. The second operand is optional.
-template <typename Fn, typename... Values>
-static Value forEachScalarValue(RewriterBase &rewriter, Location loc,
- Value operand1, Value operand2, Type resultType,
- Fn fn) {
- auto vecTy1 = dyn_cast<VectorType>(operand1.getType());
- if (operand2) {
- // Sanity check: Operand types must match.
- assert(vecTy1 == dyn_cast<VectorType>(operand2.getType()) &&
- "expected same vector types");
- }
- if (!vecTy1) {
- // Not a vector. Call the function directly.
- return fn(operand1, operand2, resultType);
- }
-
- // Prepare scalar operands.
- ResultRange sclars1 =
- vector::ToElementsOp::create(rewriter, loc, operand1)->getResults();
- SmallVector<Value> scalars2;
- if (!operand2) {
- // No second operand. Create a vector of empty values.
- scalars2.assign(vecTy1.getNumElements(), Value());
- } else {
- llvm::append_range(
- scalars2,
- vector::ToElementsOp::create(rewriter, loc, operand2)->getResults());
- }
-
- // Call the function for each pair of scalar operands.
- auto resultVecType = cast<VectorType>(resultType);
- SmallVector<Value> results;
- for (auto [scalar1, scalar2] : llvm::zip_equal(sclars1, scalars2)) {
- Value result = fn(scalar1, scalar2, resultVecType.getElementType());
- results.push_back(result);
- }
-
- // Package the results into a vector.
- return vector::FromElementsOp::create(
- rewriter, loc,
- vecTy1.cloneWith(/*shape=*/std::nullopt, results.front().getType()),
- results);
-}
-
-/// Check preconditions for the conversion:
-/// 1. All operands / results must be integers or floats (or vectors thereof).
-/// 2. The bitwidth of the operands / results must be <= 64.
-static LogicalResult checkPreconditions(RewriterBase &rewriter, Operation *op) {
- for (Value value : llvm::concat<Value>(op->getOperands(), op->getResults())) {
- Type type = value.getType();
- if (auto vecTy = dyn_cast<VectorType>(type)) {
- type = vecTy.getElementType();
- }
- if (!type.isIntOrFloat()) {
- return rewriter.notifyMatchFailure(
- op, "only integers and floats (or vectors thereof) are supported");
- }
- if (type.getIntOrFloatBitWidth() > 64)
- return rewriter.notifyMatchFailure(op,
- "bitwidth > 64 bits is not supported");
- }
- return success();
-}
-
/// Rewrite a binary arithmetic operation to an APFloat function call.
template <typename OpTy>
struct BinaryArithOpToAPFloatConversion final : OpRewritePattern<OpTy> {
diff --git a/mlir/lib/Conversion/ArithAndMathToAPFloat/Utils.cpp b/mlir/lib/Conversion/ArithAndMathToAPFloat/Utils.cpp
index 2b5857367dc40..340e015404d86 100644
--- a/mlir/lib/Conversion/ArithAndMathToAPFloat/Utils.cpp
+++ b/mlir/lib/Conversion/ArithAndMathToAPFloat/Utils.cpp
@@ -9,14 +9,35 @@
#include "Utils.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinTypeInterfaces.h"
#include "mlir/IR/Location.h"
+#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/Value.h"
-mlir::Value mlir::getAPFloatSemanticsValue(OpBuilder &b, Location loc,
- FloatType floatTy) {
+using namespace mlir;
+
+Value mlir::getAPFloatSemanticsValue(OpBuilder &b, Location loc,
+ FloatType floatTy) {
int32_t sem = llvm::APFloatBase::SemanticsToEnum(floatTy.getFloatSemantics());
return arith::ConstantOp::create(b, loc, b.getI32Type(),
b.getIntegerAttr(b.getI32Type(), sem));
}
+
+LogicalResult mlir::checkPreconditions(RewriterBase &rewriter, Operation *op) {
+ for (Value value : llvm::concat<Value>(op->getOperands(), op->getResults())) {
+ Type type = value.getType();
+ if (auto vecTy = dyn_cast<VectorType>(type)) {
+ type = vecTy.getElementType();
+ }
+ if (!type.isIntOrFloat()) {
+ return rewriter.notifyMatchFailure(
+ op, "only integers and floats (or vectors thereof) are supported");
+ }
+ if (type.getIntOrFloatBitWidth() > 64)
+ return rewriter.notifyMatchFailure(op,
+ "bitwidth > 64 bits is not supported");
+ }
+ return success();
+}
diff --git a/mlir/lib/Conversion/ArithAndMathToAPFloat/Utils.h b/mlir/lib/Conversion/ArithAndMathToAPFloat/Utils.h
index 5f11d24261b43..d38d3b4c93945 100644
--- a/mlir/lib/Conversion/ArithAndMathToAPFloat/Utils.h
+++ b/mlir/lib/Conversion/ArithAndMathToAPFloat/Utils.h
@@ -9,6 +9,9 @@
#ifndef MLIR_CONVERSION_ARITHANDMATHTOAPFLOAT_UTILS_H_
#define MLIR_CONVERSION_ARITHANDMATHTOAPFLOAT_UTILS_H_
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/IR/PatternMatch.h"
+
namespace mlir {
class Value;
class OpBuilder;
@@ -16,6 +19,59 @@ class Location;
class FloatType;
Value getAPFloatSemanticsValue(OpBuilder &b, Location loc, FloatType floatTy);
+
+/// Given two operands of vector type and vector result type (with the same
+/// shape), call the given function for each pair of scalar operands and
+/// package the result into a vector. If the given operands and result type are
+/// not vectors, call the function directly. The second operand is optional.
+template <typename Fn, typename... Values>
+Value forEachScalarValue(mlir::RewriterBase &rewriter, Location loc,
+ Value operand1, Value operand2, Type resultType,
+ Fn fn) {
+ auto vecTy1 = dyn_cast<VectorType>(operand1.getType());
+ if (operand2) {
+ // Sanity check: Operand types must match.
+ assert(vecTy1 == dyn_cast<VectorType>(operand2.getType()) &&
+ "expected same vector types");
+ }
+ if (!vecTy1) {
+ // Not a vector. Call the function directly.
+ return fn(operand1, operand2, resultType);
+ }
+
+ // Prepare scalar operands.
+ ResultRange sclars1 =
+ vector::ToElementsOp::create(rewriter, loc, operand1)->getResults();
+ SmallVector<Value> scalars2;
+ if (!operand2) {
+ // No second operand. Create a vector of empty values.
+ scalars2.assign(vecTy1.getNumElements(), Value());
+ } else {
+ llvm::append_range(
+ scalars2,
+ vector::ToElementsOp::create(rewriter, loc, operand2)->getResults());
+ }
+
+ // Call the function for each pair of scalar operands.
+ auto resultVecType = cast<VectorType>(resultType);
+ SmallVector<Value> results;
+ for (auto [scalar1, scalar2] : llvm::zip_equal(sclars1, scalars2)) {
+ Value result = fn(scalar1, scalar2, resultVecType.getElementType());
+ results.push_back(result);
+ }
+
+ // Package the results into a vector.
+ return vector::FromElementsOp::create(
+ rewriter, loc,
+ vecTy1.cloneWith(/*shape=*/std::nullopt, results.front().getType()),
+ results);
+}
+
+/// Check preconditions for the conversion:
+/// 1. All operands / results must be integers or floats (or vectors thereof).
+/// 2. The bitwidth of the operands / results must be <= 64.
+LogicalResult checkPreconditions(RewriterBase &rewriter, Operation *op);
+
} // namespace mlir
#endif // MLIR_CONVERSION_ARITHANDMATHTOAPFLOAT_UTILS_H_
>From ddbcffafcf762a6357ed2e5b9cb3fe58e182e4ea Mon Sep 17 00:00:00 2001
From: makslevental <maksim.levental at gmail.com>
Date: Wed, 17 Dec 2025 15:40:12 -0800
Subject: [PATCH 2/8] vectorize IsOp and AbsF conversions (tests not yet updated)
---
.../ArithAndMathToAPFloat/MathToAPFloat.cpp | 93 +++++++++----------
1 file changed, 46 insertions(+), 47 deletions(-)
diff --git a/mlir/lib/Conversion/ArithAndMathToAPFloat/MathToAPFloat.cpp b/mlir/lib/Conversion/ArithAndMathToAPFloat/MathToAPFloat.cpp
index 784028f5cf2eb..fad59b6a4530f 100644
--- a/mlir/lib/Conversion/ArithAndMathToAPFloat/MathToAPFloat.cpp
+++ b/mlir/lib/Conversion/ArithAndMathToAPFloat/MathToAPFloat.cpp
@@ -33,16 +33,8 @@ struct AbsFOpToAPFloatConversion final : OpRewritePattern<math::AbsFOp> {
LogicalResult matchAndRewrite(math::AbsFOp op,
PatternRewriter &rewriter) const override {
- // Cast operands to 64-bit integers.
- auto operand = op.getOperand();
- auto floatTy = dyn_cast<FloatType>(operand.getType());
- if (!floatTy)
- return rewriter.notifyMatchFailure(op,
- "only scalar FloatTypes supported");
- if (floatTy.getIntOrFloatBitWidth() > 64) {
- return rewriter.notifyMatchFailure(op,
- "bitwidth > 64 bits is not supported");
- }
+ if (failed(checkPreconditions(rewriter, op)))
+ return failure();
// Get APFloat function from runtime library.
auto i32Type = IntegerType::get(symTable->getContext(), 32);
auto i64Type = IntegerType::get(symTable->getContext(), 64);
@@ -52,23 +44,30 @@ struct AbsFOpToAPFloatConversion final : OpRewritePattern<math::AbsFOp> {
return fn;
Location loc = op.getLoc();
rewriter.setInsertionPoint(op);
- auto intWType = rewriter.getIntegerType(floatTy.getWidth());
- Value operandBits = arith::ExtUIOp::create(
- rewriter, loc, i64Type,
- arith::BitcastOp::create(rewriter, loc, intWType, operand));
-
- // Call APFloat function.
- Value semValue = getAPFloatSemanticsValue(rewriter, loc, floatTy);
- SmallVector<Value> params = {semValue, operandBits};
- Value negatedBits = func::CallOp::create(rewriter, loc, TypeRange(i64Type),
- SymbolRefAttr::get(*fn), params)
- ->getResult(0);
-
- // Truncate result to the original width.
- Value truncatedBits =
- arith::TruncIOp::create(rewriter, loc, intWType, negatedBits);
- rewriter.replaceOp(
- op, arith::BitcastOp::create(rewriter, loc, floatTy, truncatedBits));
+ // Scalarize and convert to APFloat runtime calls.
+ Value repl = forEachScalarValue(
+ rewriter, loc, op.getOperand(), /*operand2=*/Value(), op.getType(),
+ [&](Value operand, Value, Type resultType) {
+ auto floatTy = cast<FloatType>(operand.getType());
+ auto intWType = rewriter.getIntegerType(floatTy.getWidth());
+ Value operandBits = arith::ExtUIOp::create(
+ rewriter, loc, i64Type,
+ arith::BitcastOp::create(rewriter, loc, intWType, operand));
+ // Call APFloat function.
+ Value semValue = getAPFloatSemanticsValue(rewriter, loc, floatTy);
+ SmallVector<Value> params = {semValue, operandBits};
+ Value negatedBits =
+ func::CallOp::create(rewriter, loc, TypeRange(i64Type),
+ SymbolRefAttr::get(*fn), params)
+ ->getResult(0);
+ // Truncate result to the original width.
+ auto truncatedBits =
+ arith::TruncIOp::create(rewriter, loc, intWType, negatedBits);
+ return arith::BitcastOp::create(rewriter, loc, floatTy,
+ truncatedBits);
+ });
+
+ rewriter.replaceOp(op, repl);
return success();
}
@@ -85,16 +84,8 @@ struct IsOpToAPFloatConversion final : OpRewritePattern<OpTy> {
LogicalResult matchAndRewrite(OpTy op,
PatternRewriter &rewriter) const override {
- // Cast operands to 64-bit integers.
- auto operand = op.getOperand();
- auto floatTy = dyn_cast<FloatType>(operand.getType());
- if (!floatTy)
- return rewriter.notifyMatchFailure(op,
- "only scalar FloatTypes supported");
- if (floatTy.getIntOrFloatBitWidth() > 64) {
- return rewriter.notifyMatchFailure(op,
- "bitwidth > 64 bits is not supported");
- }
+ if (failed(checkPreconditions(rewriter, op)))
+ return failure();
// Get APFloat function from runtime library.
auto i1 = IntegerType::get(symTable->getContext(), 1);
auto i32Type = IntegerType::get(symTable->getContext(), 32);
@@ -107,16 +98,24 @@ struct IsOpToAPFloatConversion final : OpRewritePattern<OpTy> {
return fn;
Location loc = op.getLoc();
rewriter.setInsertionPoint(op);
- auto intWType = rewriter.getIntegerType(floatTy.getWidth());
- Value operandBits = arith::ExtUIOp::create(
- rewriter, loc, i64Type,
- arith::BitcastOp::create(rewriter, loc, intWType, operand));
-
- // Call APFloat function.
- Value semValue = getAPFloatSemanticsValue(rewriter, loc, floatTy);
- SmallVector<Value> params = {semValue, operandBits};
- rewriter.replaceOpWithNewOp<func::CallOp>(op, TypeRange(i1),
- SymbolRefAttr::get(*fn), params);
+ // Scalarize and convert to APFloat runtime calls.
+ Value repl = forEachScalarValue(
+ rewriter, loc, op.getOperand(), /*operand2=*/Value(), op.getType(),
+ [&](Value operand, Value, Type resultType) {
+ auto floatTy = dyn_cast<FloatType>(operand.getType());
+ auto intWType = rewriter.getIntegerType(floatTy.getWidth());
+ Value operandBits = arith::ExtUIOp::create(
+ rewriter, loc, i64Type,
+ arith::BitcastOp::create(rewriter, loc, intWType, operand));
+
+ // Call APFloat function.
+ Value semValue = getAPFloatSemanticsValue(rewriter, loc, floatTy);
+ SmallVector<Value> params = {semValue, operandBits};
+ return func::CallOp::create(rewriter, loc, TypeRange(i1),
+ SymbolRefAttr::get(*fn), params)
+ .getResult(0);
+ });
+ rewriter.replaceOp(op, repl);
return success();
}
>From a1c003e74c02d650a24c97185ff293a24baaa939 Mon Sep 17 00:00:00 2001
From: makslevental <maksim.levental at gmail.com>
Date: Fri, 9 Jan 2026 10:19:00 -0800
Subject: [PATCH 3/8] vectorize fma
---
.../ArithAndMathToAPFloat/MathToAPFloat.cpp | 72 +++++++++++++------
1 file changed, 50 insertions(+), 22 deletions(-)
diff --git a/mlir/lib/Conversion/ArithAndMathToAPFloat/MathToAPFloat.cpp b/mlir/lib/Conversion/ArithAndMathToAPFloat/MathToAPFloat.cpp
index fad59b6a4530f..172af125e959e 100644
--- a/mlir/lib/Conversion/ArithAndMathToAPFloat/MathToAPFloat.cpp
+++ b/mlir/lib/Conversion/ArithAndMathToAPFloat/MathToAPFloat.cpp
@@ -130,16 +130,10 @@ struct FmaOpToAPFloatConversion final : OpRewritePattern<math::FmaOp> {
LogicalResult matchAndRewrite(math::FmaOp op,
PatternRewriter &rewriter) const override {
+ if (failed(checkPreconditions(rewriter, op)))
+ return failure();
// Cast operands to 64-bit integers.
auto floatTy = cast<FloatType>(op.getResult().getType());
- if (!floatTy)
- return rewriter.notifyMatchFailure(op,
- "only scalar FloatTypes supported");
- if (floatTy.getIntOrFloatBitWidth() > 64) {
- return rewriter.notifyMatchFailure(op,
- "bitwidth > 64 bits is not supported");
- }
-
auto i32Type = IntegerType::get(symTable->getContext(), 32);
auto i64Type = IntegerType::get(symTable->getContext(), 64);
FailureOr<FuncOp> fn = lookupOrCreateFnDecl(
@@ -150,8 +144,52 @@ struct FmaOpToAPFloatConversion final : OpRewritePattern<math::FmaOp> {
Location loc = op.getLoc();
rewriter.setInsertionPoint(op);
- auto intWType = rewriter.getIntegerType(floatTy.getWidth());
- auto int64Type = rewriter.getI64Type();
+ IntegerType intWType = rewriter.getIntegerType(floatTy.getWidth());
+ IntegerType int64Type = rewriter.getI64Type();
+
+ auto scalarFMA = [&rewriter, &loc, &floatTy, &fn, &intWType](
+ Value operand, Value multiplicand, Value addend) {
+ // Call APFloat function.
+ Value semValue = getAPFloatSemanticsValue(rewriter, loc, floatTy);
+ SmallVector<Value> params = {semValue, operand, multiplicand, addend};
+ auto resultOp =
+ func::CallOp::create(rewriter, loc, TypeRange(rewriter.getI64Type()),
+ SymbolRefAttr::get(*fn), params);
+
+ // Truncate result to the original width.
+ auto trunc = arith::TruncIOp::create(rewriter, loc, intWType,
+ resultOp->getResult(0));
+ return arith::BitcastOp::create(rewriter, loc, floatTy, trunc);
+ };
+
+ if (VectorType vecTy1 = dyn_cast<VectorType>(op.getA().getType())) {
+ // Sanity check: Operand types must match.
+ assert(vecTy1 == dyn_cast<VectorType>(op.getB().getType()) &&
+ "expected same vector types");
+ assert(vecTy1 == dyn_cast<VectorType>(op.getC().getType()) &&
+ "expected same vector types");
+ // Prepare scalar operands.
+ ResultRange scalarOperands =
+ vector::ToElementsOp::create(rewriter, loc, op.getA())->getResults();
+ ResultRange scalarMultiplicands =
+ vector::ToElementsOp::create(rewriter, loc, op.getB())->getResults();
+ ResultRange scalarAddends =
+ vector::ToElementsOp::create(rewriter, loc, op.getC())->getResults();
+ // Call the function for each (a, b, c) triple of scalar operands.
+ SmallVector<Value> results;
+ for (auto [operand, multiplicand, addend] : llvm::zip_equal(
+ scalarOperands, scalarMultiplicands, scalarAddends)) {
+ results.push_back(scalarFMA(operand, multiplicand, addend));
+ }
+ // Package the results into a vector.
+ auto fromElements = vector::FromElementsOp::create(
+ rewriter, loc,
+ vecTy1.cloneWith(/*shape=*/std::nullopt, results.front().getType()),
+ results);
+ rewriter.replaceOp(op, fromElements);
+ return success();
+ }
+
Value operand = arith::ExtUIOp::create(
rewriter, loc, int64Type,
arith::BitcastOp::create(rewriter, loc, intWType, op.getA()));
@@ -161,18 +199,8 @@ struct FmaOpToAPFloatConversion final : OpRewritePattern<math::FmaOp> {
Value addend = arith::ExtUIOp::create(
rewriter, loc, int64Type,
arith::BitcastOp::create(rewriter, loc, intWType, op.getC()));
-
- // Call APFloat function.
- Value semValue = getAPFloatSemanticsValue(rewriter, loc, floatTy);
- SmallVector<Value> params = {semValue, operand, multiplicand, addend};
- auto resultOp =
- func::CallOp::create(rewriter, loc, TypeRange(rewriter.getI64Type()),
- SymbolRefAttr::get(*fn), params);
-
- // Truncate result to the original width.
- Value truncatedBits = arith::TruncIOp::create(rewriter, loc, intWType,
- resultOp->getResult(0));
- rewriter.replaceOpWithNewOp<arith::BitcastOp>(op, floatTy, truncatedBits);
+ Value repl = scalarFMA(operand, multiplicand, addend);
+ rewriter.replaceOp(op, repl);
return success();
}
>From 982454b51447387c8f756b83bc3991e8b7cfe4c1 Mon Sep 17 00:00:00 2001
From: makslevental <maksim.levental at gmail.com>
Date: Thu, 15 Jan 2026 15:33:11 -0800
Subject: [PATCH 4/8] remove template
---
.../ArithAndMathToAPFloat/Utils.cpp | 42 +++++++++++++++++++
.../Conversion/ArithAndMathToAPFloat/Utils.h | 41 +-----------------
2 files changed, 43 insertions(+), 40 deletions(-)
diff --git a/mlir/lib/Conversion/ArithAndMathToAPFloat/Utils.cpp b/mlir/lib/Conversion/ArithAndMathToAPFloat/Utils.cpp
index 340e015404d86..01f55a8da15a3 100644
--- a/mlir/lib/Conversion/ArithAndMathToAPFloat/Utils.cpp
+++ b/mlir/lib/Conversion/ArithAndMathToAPFloat/Utils.cpp
@@ -25,6 +25,48 @@ Value mlir::getAPFloatSemanticsValue(OpBuilder &b, Location loc,
b.getIntegerAttr(b.getI32Type(), sem));
}
+Value mlir::forEachScalarValue(
+ mlir::RewriterBase &rewriter, Location loc, Value operand1, Value operand2,
+ Type resultType, llvm::function_ref<Value(Value, Value, Type)> fn) {
+ auto vecTy1 = dyn_cast<VectorType>(operand1.getType());
+ if (operand2) {
+ // Sanity check: Operand types must match.
+ assert(vecTy1 == dyn_cast<VectorType>(operand2.getType()) &&
+ "expected same vector types");
+ }
+ if (!vecTy1) {
+ // Not a vector. Call the function directly.
+ return fn(operand1, operand2, resultType);
+ }
+
+ // Prepare scalar operands.
+ ResultRange sclars1 =
+ vector::ToElementsOp::create(rewriter, loc, operand1)->getResults();
+ SmallVector<Value> scalars2;
+ if (!operand2) {
+ // No second operand. Create a vector of empty values.
+ scalars2.assign(vecTy1.getNumElements(), Value());
+ } else {
+ llvm::append_range(
+ scalars2,
+ vector::ToElementsOp::create(rewriter, loc, operand2)->getResults());
+ }
+
+ // Call the function for each pair of scalar operands.
+ auto resultVecType = cast<VectorType>(resultType);
+ SmallVector<Value> results;
+ for (auto [scalar1, scalar2] : llvm::zip_equal(sclars1, scalars2)) {
+ Value result = fn(scalar1, scalar2, resultVecType.getElementType());
+ results.push_back(result);
+ }
+
+ // Package the results into a vector.
+ return vector::FromElementsOp::create(
+ rewriter, loc,
+ vecTy1.cloneWith(/*shape=*/std::nullopt, results.front().getType()),
+ results);
+}
+
LogicalResult mlir::checkPreconditions(RewriterBase &rewriter, Operation *op) {
for (Value value : llvm::concat<Value>(op->getOperands(), op->getResults())) {
Type type = value.getType();
diff --git a/mlir/lib/Conversion/ArithAndMathToAPFloat/Utils.h b/mlir/lib/Conversion/ArithAndMathToAPFloat/Utils.h
index d38d3b4c93945..dfadf9449b497 100644
--- a/mlir/lib/Conversion/ArithAndMathToAPFloat/Utils.h
+++ b/mlir/lib/Conversion/ArithAndMathToAPFloat/Utils.h
@@ -24,48 +24,9 @@ Value getAPFloatSemanticsValue(OpBuilder &b, Location loc, FloatType floatTy);
/// shape), call the given function for each pair of scalar operands and
/// package the result into a vector. If the given operands and result type are
/// not vectors, call the function directly. The second operand is optional.
-template <typename Fn, typename... Values>
Value forEachScalarValue(mlir::RewriterBase &rewriter, Location loc,
Value operand1, Value operand2, Type resultType,
- Fn fn) {
- auto vecTy1 = dyn_cast<VectorType>(operand1.getType());
- if (operand2) {
- // Sanity check: Operand types must match.
- assert(vecTy1 == dyn_cast<VectorType>(operand2.getType()) &&
- "expected same vector types");
- }
- if (!vecTy1) {
- // Not a vector. Call the function directly.
- return fn(operand1, operand2, resultType);
- }
-
- // Prepare scalar operands.
- ResultRange sclars1 =
- vector::ToElementsOp::create(rewriter, loc, operand1)->getResults();
- SmallVector<Value> scalars2;
- if (!operand2) {
- // No second operand. Create a vector of empty values.
- scalars2.assign(vecTy1.getNumElements(), Value());
- } else {
- llvm::append_range(
- scalars2,
- vector::ToElementsOp::create(rewriter, loc, operand2)->getResults());
- }
-
- // Call the function for each pair of scalar operands.
- auto resultVecType = cast<VectorType>(resultType);
- SmallVector<Value> results;
- for (auto [scalar1, scalar2] : llvm::zip_equal(sclars1, scalars2)) {
- Value result = fn(scalar1, scalar2, resultVecType.getElementType());
- results.push_back(result);
- }
-
- // Package the results into a vector.
- return vector::FromElementsOp::create(
- rewriter, loc,
- vecTy1.cloneWith(/*shape=*/std::nullopt, results.front().getType()),
- results);
-}
+ llvm::function_ref<Value(Value, Value, Type)> fn);
/// Check preconditions for the conversion:
/// 1. All operands / results must be integers or floats (or vectors thereof).
>From f2038d191583cfdd53deae07f86644023851cbb2 Mon Sep 17 00:00:00 2001
From: makslevental <maksim.levental at gmail.com>
Date: Thu, 15 Jan 2026 16:44:18 -0800
Subject: [PATCH 5/8] add/fix tests
---
.../ArithAndMathToAPFloat/MathToAPFloat.cpp | 32 +++++++++------
.../CPU/test-apfloat-emulation-vector.mlir | 2 -
.../CPU/test-apfloat-emulation-vector.mlir | 41 +++++++++++++++++++
.../Math/CPU/test-apfloat-emulation.mlir | 14 +++++--
4 files changed, 70 insertions(+), 19 deletions(-)
create mode 100644 mlir/test/Integration/Dialect/Math/CPU/test-apfloat-emulation-vector.mlir
diff --git a/mlir/lib/Conversion/ArithAndMathToAPFloat/MathToAPFloat.cpp b/mlir/lib/Conversion/ArithAndMathToAPFloat/MathToAPFloat.cpp
index 172af125e959e..1ce0cab0bd08f 100644
--- a/mlir/lib/Conversion/ArithAndMathToAPFloat/MathToAPFloat.cpp
+++ b/mlir/lib/Conversion/ArithAndMathToAPFloat/MathToAPFloat.cpp
@@ -133,7 +133,13 @@ struct FmaOpToAPFloatConversion final : OpRewritePattern<math::FmaOp> {
if (failed(checkPreconditions(rewriter, op)))
return failure();
// Cast operands to 64-bit integers.
- auto floatTy = cast<FloatType>(op.getResult().getType());
+ mlir::Type resType = op.getResult().getType();
+ auto floatTy = dyn_cast<FloatType>(resType);
+ if (!floatTy) {
+ auto vecTy1 = dyn_cast<VectorType>(resType);
+ assert(vecTy1 && "expected VectorType");
+ floatTy = llvm::cast<FloatType>(vecTy1.getElementType());
+ }
auto i32Type = IntegerType::get(symTable->getContext(), 32);
auto i64Type = IntegerType::get(symTable->getContext(), 64);
FailureOr<FuncOp> fn = lookupOrCreateFnDecl(
@@ -147,8 +153,17 @@ struct FmaOpToAPFloatConversion final : OpRewritePattern<math::FmaOp> {
IntegerType intWType = rewriter.getIntegerType(floatTy.getWidth());
IntegerType int64Type = rewriter.getI64Type();
- auto scalarFMA = [&rewriter, &loc, &floatTy, &fn, &intWType](
- Value operand, Value multiplicand, Value addend) {
+ auto scalarFMA = [&rewriter, &loc, &floatTy, &fn, &intWType,
+ &int64Type](Value a, Value b, Value c) {
+ Value operand = arith::ExtUIOp::create(
+ rewriter, loc, int64Type,
+ arith::BitcastOp::create(rewriter, loc, intWType, a));
+ Value multiplicand = arith::ExtUIOp::create(
+ rewriter, loc, int64Type,
+ arith::BitcastOp::create(rewriter, loc, intWType, b));
+ Value addend = arith::ExtUIOp::create(
+ rewriter, loc, int64Type,
+ arith::BitcastOp::create(rewriter, loc, intWType, c));
// Call APFloat function.
Value semValue = getAPFloatSemanticsValue(rewriter, loc, floatTy);
SmallVector<Value> params = {semValue, operand, multiplicand, addend};
@@ -190,16 +205,7 @@ struct FmaOpToAPFloatConversion final : OpRewritePattern<math::FmaOp> {
return success();
}
- Value operand = arith::ExtUIOp::create(
- rewriter, loc, int64Type,
- arith::BitcastOp::create(rewriter, loc, intWType, op.getA()));
- Value multiplicand = arith::ExtUIOp::create(
- rewriter, loc, int64Type,
- arith::BitcastOp::create(rewriter, loc, intWType, op.getB()));
- Value addend = arith::ExtUIOp::create(
- rewriter, loc, int64Type,
- arith::BitcastOp::create(rewriter, loc, intWType, op.getC()));
- Value repl = scalarFMA(operand, multiplicand, addend);
+ Value repl = scalarFMA(op.getA(), op.getB(), op.getC());
rewriter.replaceOp(op, repl);
return success();
}
diff --git a/mlir/test/Integration/Dialect/Arith/CPU/test-apfloat-emulation-vector.mlir b/mlir/test/Integration/Dialect/Arith/CPU/test-apfloat-emulation-vector.mlir
index dfd9e7c4aaa14..cc773e60dda3e 100644
--- a/mlir/test/Integration/Dialect/Arith/CPU/test-apfloat-emulation-vector.mlir
+++ b/mlir/test/Integration/Dialect/Arith/CPU/test-apfloat-emulation-vector.mlir
@@ -1,6 +1,4 @@
// REQUIRES: system-linux || system-darwin
-// TODO: Run only on Linux until we figure out how to build
-// mlir_apfloat_wrappers in a platform-independent way.
// All floating-point arithmetics is lowered through APFloat.
// RUN: mlir-opt %s --convert-arith-to-apfloat --convert-vector-to-scf \
diff --git a/mlir/test/Integration/Dialect/Math/CPU/test-apfloat-emulation-vector.mlir b/mlir/test/Integration/Dialect/Math/CPU/test-apfloat-emulation-vector.mlir
new file mode 100644
index 0000000000000..c0b2d858c1fec
--- /dev/null
+++ b/mlir/test/Integration/Dialect/Math/CPU/test-apfloat-emulation-vector.mlir
@@ -0,0 +1,41 @@
+// REQUIRES: system-linux || system-darwin
+
+// All floating-point math operations are lowered through APFloat.
+// RUN: mlir-opt %s --convert-math-to-apfloat --convert-vector-to-scf \
+// RUN: --convert-scf-to-cf --convert-to-llvm | \
+// RUN: mlir-runner -e entry --entry-point-result=void \
+// RUN: --shared-libs=%mlir_c_runner_utils \
+// RUN: --shared-libs=%mlir_apfloat_wrappers | FileCheck %s
+
+func.func @entry() {
+
+ %neg14fp8 = arith.constant dense<[-1.4, -1.4, -1.4, -1.4]> : vector<4xf8E4M3FN>
+ %absfp8 = math.absf %neg14fp8 : vector<4xf8E4M3FN>
+ // CHECK: ( 1.375, 1.375, 1.375, 1.375 )
+ vector.print %absfp8 : vector<4xf8E4M3FN>
+
+ %a1_vec = arith.constant dense<[2.0, 2.0, 2.0, 2.0]> : vector<4xf8E4M3FN>
+ %b1_vec = arith.constant dense<[4.0, 4.0, 4.0, 4.0]> : vector<4xf8E4M3FN>
+ %c1_vec = arith.constant dense<[8.0, 8.0, 8.0, 8.0]> : vector<4xf8E4M3FN>
+ %d1_vec = math.fma %a1_vec, %b1_vec, %c1_vec : vector<4xf8E4M3FN> // not supported by LLVM
+ // CHECK: ( 16, 16, 16, 16 )
+ vector.print %d1_vec : vector<4xf8E4M3FN>
+
+ // CHECK: ( 0, 0, 0, 0 )
+ %isinffp8 = math.isinf %neg14fp8 : vector<4xf8E4M3FN>
+ vector.print %isinffp8 : vector<4xi1>
+
+ %isnanfp8 = math.isnan %neg14fp8 : vector<4xf8E4M3FN>
+ // CHECK: ( 0, 0, 0, 0 )
+ vector.print %isnanfp8 : vector<4xi1>
+
+ %isnormalfp8 = math.isnormal %neg14fp8 : vector<4xf8E4M3FN>
+ // CHECK: ( 1, 1, 1, 1 )
+ vector.print %isnormalfp8 : vector<4xi1>
+
+ %isfinitefp8 = math.isfinite %neg14fp8 : vector<4xf8E4M3FN>
+ // CHECK: ( 1, 1, 1, 1 )
+ vector.print %isfinitefp8 : vector<4xi1>
+
+ return
+}
diff --git a/mlir/test/Integration/Dialect/Math/CPU/test-apfloat-emulation.mlir b/mlir/test/Integration/Dialect/Math/CPU/test-apfloat-emulation.mlir
index c890b470b563a..0cc3d3f2218f0 100644
--- a/mlir/test/Integration/Dialect/Math/CPU/test-apfloat-emulation.mlir
+++ b/mlir/test/Integration/Dialect/Math/CPU/test-apfloat-emulation.mlir
@@ -24,15 +24,18 @@ func.func @entry() {
// CHECK: 16
vector.print %fmafp8 : f8E8M0FNU
- // CHECK: 0
%isinffp8 = math.isinf %neg14fp8 : f8E4M3FN
- vector.print %isinffp8 : i1
// CHECK: 0
+ vector.print %isinffp8 : i1
+
%isnanfp8 = math.isnan %neg14fp8 : f8E4M3FN
+ // CHECK: 0
vector.print %isnanfp8 : i1
+
%isnormalfp8 = math.isnormal %neg14fp8 : f8E4M3FN
// CHECK: 1
vector.print %isnormalfp8 : i1
+
%isfinitefp8 = math.isfinite %neg14fp8 : f8E4M3FN
// CHECK: 1
vector.print %isfinitefp8 : i1
@@ -51,15 +54,18 @@ func.func @entry() {
// CHECK: 16
vector.print %fmafp32 : f32
- // CHECK: 0
%isinffp32 = math.isinf %neg14fp32 : f32
- vector.print %isinffp32 : i1
// CHECK: 0
+ vector.print %isinffp32 : i1
+
%isnanfp32 = math.isnan %neg14fp32 : f32
+ // CHECK: 0
vector.print %isnanfp32 : i1
+
%isnormalfp32 = math.isnormal %neg14fp32 : f32
// CHECK: 1
vector.print %isnormalfp32 : i1
+
%isfinitefp32 = math.isfinite %neg14fp32 : f32
// CHECK: 1
vector.print %isfinitefp32 : i1
>From d972311361745abfc09df64e326b6525382d335c Mon Sep 17 00:00:00 2001
From: makslevental <maksim.levental at gmail.com>
Date: Thu, 15 Jan 2026 18:25:37 -0800
Subject: [PATCH 6/8] [mlir][Python] remove stray nb::cast
---
mlir/lib/Bindings/Python/TransformInterpreter.cpp | 6 ------
1 file changed, 6 deletions(-)
diff --git a/mlir/lib/Bindings/Python/TransformInterpreter.cpp b/mlir/lib/Bindings/Python/TransformInterpreter.cpp
index a9f204ff9d0a5..b263e65ff8cf8 100644
--- a/mlir/lib/Bindings/Python/TransformInterpreter.cpp
+++ b/mlir/lib/Bindings/Python/TransformInterpreter.cpp
@@ -71,12 +71,6 @@ static void populateTransformInterpreterSubmodule(nb::module_ &m) {
PyOperationBase &transformModule, const PyTransformOptions &options) {
mlir::python::CollectDiagnosticsToStringScope scope(
mlirOperationGetContext(transformRoot.getOperation()));
-
- // Calling back into Python to invalidate everything under the payload
- // root. This is awkward, but we don't have access to PyMlirContext
- // object here otherwise.
- nb::object obj = nb::cast(payloadRoot);
-
MlirLogicalResult result = mlirTransformApplyNamedSequence(
payloadRoot.getOperation(), transformRoot.getOperation(),
transformModule.getOperation(), options.options);
>From 96bfc572d47b41f9724b8dfabaab2af02eb76fcb Mon Sep 17 00:00:00 2001
From: makslevental <maksim.levental at gmail.com>
Date: Thu, 15 Jan 2026 18:35:11 -0800
Subject: [PATCH 7/8] address comments
---
.../Conversion/ArithAndMathToAPFloat/MathToAPFloat.cpp | 9 ++++-----
1 file changed, 4 insertions(+), 5 deletions(-)
diff --git a/mlir/lib/Conversion/ArithAndMathToAPFloat/MathToAPFloat.cpp b/mlir/lib/Conversion/ArithAndMathToAPFloat/MathToAPFloat.cpp
index 1ce0cab0bd08f..af4a42aa308b3 100644
--- a/mlir/lib/Conversion/ArithAndMathToAPFloat/MathToAPFloat.cpp
+++ b/mlir/lib/Conversion/ArithAndMathToAPFloat/MathToAPFloat.cpp
@@ -102,7 +102,7 @@ struct IsOpToAPFloatConversion final : OpRewritePattern<OpTy> {
Value repl = forEachScalarValue(
rewriter, loc, op.getOperand(), /*operand2=*/Value(), op.getType(),
[&](Value operand, Value, Type resultType) {
- auto floatTy = dyn_cast<FloatType>(operand.getType());
+ auto floatTy = cast<FloatType>(operand.getType());
auto intWType = rewriter.getIntegerType(floatTy.getWidth());
Value operandBits = arith::ExtUIOp::create(
rewriter, loc, i64Type,
@@ -110,7 +110,7 @@ struct IsOpToAPFloatConversion final : OpRewritePattern<OpTy> {
// Call APFloat function.
Value semValue = getAPFloatSemanticsValue(rewriter, loc, floatTy);
- SmallVector<Value> params = {semValue, operandBits};
+ Value params[] = {semValue, operandBits};
return func::CallOp::create(rewriter, loc, TypeRange(i1),
SymbolRefAttr::get(*fn), params)
.getResult(0);
@@ -136,8 +136,7 @@ struct FmaOpToAPFloatConversion final : OpRewritePattern<math::FmaOp> {
mlir::Type resType = op.getResult().getType();
auto floatTy = dyn_cast<FloatType>(resType);
if (!floatTy) {
- auto vecTy1 = dyn_cast<VectorType>(resType);
- assert(vecTy1 && "expected VectorType");
+ auto vecTy1 = cast<VectorType>(resType);
floatTy = llvm::cast<FloatType>(vecTy1.getElementType());
}
auto i32Type = IntegerType::get(symTable->getContext(), 32);
@@ -177,7 +176,7 @@ struct FmaOpToAPFloatConversion final : OpRewritePattern<math::FmaOp> {
return arith::BitcastOp::create(rewriter, loc, floatTy, trunc);
};
- if (VectorType vecTy1 = dyn_cast<VectorType>(op.getA().getType())) {
+ if (auto vecTy1 = dyn_cast<VectorType>(op.getA().getType())) {
// Sanity check: Operand types must match.
assert(vecTy1 == dyn_cast<VectorType>(op.getB().getType()) &&
"expected same vector types");
>From f09321d7efd21c55da79cfbb9769794bf8570aaf Mon Sep 17 00:00:00 2001
From: makslevental <maksim.levental at gmail.com>
Date: Fri, 16 Jan 2026 07:49:33 -0800
Subject: [PATCH 8/8] revert stray TransformInterpreter.cpp change from PATCH 6/8
---
mlir/lib/Bindings/Python/TransformInterpreter.cpp | 6 ++++++
1 file changed, 6 insertions(+)
diff --git a/mlir/lib/Bindings/Python/TransformInterpreter.cpp b/mlir/lib/Bindings/Python/TransformInterpreter.cpp
index b263e65ff8cf8..a9f204ff9d0a5 100644
--- a/mlir/lib/Bindings/Python/TransformInterpreter.cpp
+++ b/mlir/lib/Bindings/Python/TransformInterpreter.cpp
@@ -71,6 +71,12 @@ static void populateTransformInterpreterSubmodule(nb::module_ &m) {
PyOperationBase &transformModule, const PyTransformOptions &options) {
mlir::python::CollectDiagnosticsToStringScope scope(
mlirOperationGetContext(transformRoot.getOperation()));
+
+ // Calling back into Python to invalidate everything under the payload
+ // root. This is awkward, but we don't have access to PyMlirContext
+ // object here otherwise.
+ nb::object obj = nb::cast(payloadRoot);
+
MlirLogicalResult result = mlirTransformApplyNamedSequence(
payloadRoot.getOperation(), transformRoot.getOperation(),
transformModule.getOperation(), options.options);
More information about the Mlir-commits
mailing list