[Mlir-commits] [mlir] [mlir][math] Add vector support for math-to-apfloat (PR #172715)
Maksim Levental
llvmlistbot at llvm.org
Fri Jan 9 13:29:20 PST 2026
https://github.com/makslevental updated https://github.com/llvm/llvm-project/pull/172715
>From 356f1519aa2bc393c07e4f5fbb38020d88a6705c Mon Sep 17 00:00:00 2001
From: makslevental <maksim.levental at gmail.com>
Date: Wed, 17 Dec 2025 10:53:55 -0800
Subject: [PATCH 1/4] [mlir][math] Add vector support for math-to-apfloat
---
.../ArithAndMathToAPFloat/ArithToAPFloat.cpp | 67 -------------------
.../ArithAndMathToAPFloat/Utils.cpp | 25 ++++++-
.../Conversion/ArithAndMathToAPFloat/Utils.h | 56 ++++++++++++++++
3 files changed, 79 insertions(+), 69 deletions(-)
diff --git a/mlir/lib/Conversion/ArithAndMathToAPFloat/ArithToAPFloat.cpp b/mlir/lib/Conversion/ArithAndMathToAPFloat/ArithToAPFloat.cpp
index 813a854f2fc97..98185697e4591 100644
--- a/mlir/lib/Conversion/ArithAndMathToAPFloat/ArithToAPFloat.cpp
+++ b/mlir/lib/Conversion/ArithAndMathToAPFloat/ArithToAPFloat.cpp
@@ -46,73 +46,6 @@ lookupOrCreateBinaryFn(OpBuilder &b, SymbolOpInterface symTable, StringRef name,
{i32Type, i64Type, i64Type}, symbolTables);
}
-/// Given two operands of vector type and vector result type (with the same
-/// shape), call the given function for each pair of scalar operands and
-/// package the result into a vector. If the given operands and result type are
-/// not vectors, call the function directly. The second operand is optional.
-template <typename Fn, typename... Values>
-static Value forEachScalarValue(RewriterBase &rewriter, Location loc,
- Value operand1, Value operand2, Type resultType,
- Fn fn) {
- auto vecTy1 = dyn_cast<VectorType>(operand1.getType());
- if (operand2) {
- // Sanity check: Operand types must match.
- assert(vecTy1 == dyn_cast<VectorType>(operand2.getType()) &&
- "expected same vector types");
- }
- if (!vecTy1) {
- // Not a vector. Call the function directly.
- return fn(operand1, operand2, resultType);
- }
-
- // Prepare scalar operands.
- ResultRange sclars1 =
- vector::ToElementsOp::create(rewriter, loc, operand1)->getResults();
- SmallVector<Value> scalars2;
- if (!operand2) {
- // No second operand. Create a vector of empty values.
- scalars2.assign(vecTy1.getNumElements(), Value());
- } else {
- llvm::append_range(
- scalars2,
- vector::ToElementsOp::create(rewriter, loc, operand2)->getResults());
- }
-
- // Call the function for each pair of scalar operands.
- auto resultVecType = cast<VectorType>(resultType);
- SmallVector<Value> results;
- for (auto [scalar1, scalar2] : llvm::zip_equal(sclars1, scalars2)) {
- Value result = fn(scalar1, scalar2, resultVecType.getElementType());
- results.push_back(result);
- }
-
- // Package the results into a vector.
- return vector::FromElementsOp::create(
- rewriter, loc,
- vecTy1.cloneWith(/*shape=*/std::nullopt, results.front().getType()),
- results);
-}
-
-/// Check preconditions for the conversion:
-/// 1. All operands / results must be integers or floats (or vectors thereof).
-/// 2. The bitwidth of the operands / results must be <= 64.
-static LogicalResult checkPreconditions(RewriterBase &rewriter, Operation *op) {
- for (Value value : llvm::concat<Value>(op->getOperands(), op->getResults())) {
- Type type = value.getType();
- if (auto vecTy = dyn_cast<VectorType>(type)) {
- type = vecTy.getElementType();
- }
- if (!type.isIntOrFloat()) {
- return rewriter.notifyMatchFailure(
- op, "only integers and floats (or vectors thereof) are supported");
- }
- if (type.getIntOrFloatBitWidth() > 64)
- return rewriter.notifyMatchFailure(op,
- "bitwidth > 64 bits is not supported");
- }
- return success();
-}
-
/// Rewrite a binary arithmetic operation to an APFloat function call.
template <typename OpTy>
struct BinaryArithOpToAPFloatConversion final : OpRewritePattern<OpTy> {
diff --git a/mlir/lib/Conversion/ArithAndMathToAPFloat/Utils.cpp b/mlir/lib/Conversion/ArithAndMathToAPFloat/Utils.cpp
index 2b5857367dc40..340e015404d86 100644
--- a/mlir/lib/Conversion/ArithAndMathToAPFloat/Utils.cpp
+++ b/mlir/lib/Conversion/ArithAndMathToAPFloat/Utils.cpp
@@ -9,14 +9,35 @@
#include "Utils.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinTypeInterfaces.h"
#include "mlir/IR/Location.h"
+#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/Value.h"
-mlir::Value mlir::getAPFloatSemanticsValue(OpBuilder &b, Location loc,
- FloatType floatTy) {
+using namespace mlir;
+
+Value mlir::getAPFloatSemanticsValue(OpBuilder &b, Location loc,
+ FloatType floatTy) {
int32_t sem = llvm::APFloatBase::SemanticsToEnum(floatTy.getFloatSemantics());
return arith::ConstantOp::create(b, loc, b.getI32Type(),
b.getIntegerAttr(b.getI32Type(), sem));
}
+
+/// Rejects `op` unless every operand and result is an integer or float scalar
+/// (or a vector of one) whose bitwidth is at most 64.
+LogicalResult mlir::checkPreconditions(RewriterBase &rewriter, Operation *op) {
+ for (Value v : llvm::concat<Value>(op->getOperands(), op->getResults())) {
+ // Vectors are judged by their element type.
+ auto vecTy = dyn_cast<VectorType>(v.getType());
+ Type elemTy = vecTy ? vecTy.getElementType() : v.getType();
+ if (!elemTy.isIntOrFloat())
+ return rewriter.notifyMatchFailure(
+ op, "only integers and floats (or vectors thereof) are supported");
+ if (elemTy.getIntOrFloatBitWidth() > 64)
+ return rewriter.notifyMatchFailure(op,
+ "bitwidth > 64 bits is not supported");
+ }
+ return success();
+}
diff --git a/mlir/lib/Conversion/ArithAndMathToAPFloat/Utils.h b/mlir/lib/Conversion/ArithAndMathToAPFloat/Utils.h
index 5f11d24261b43..d38d3b4c93945 100644
--- a/mlir/lib/Conversion/ArithAndMathToAPFloat/Utils.h
+++ b/mlir/lib/Conversion/ArithAndMathToAPFloat/Utils.h
@@ -9,6 +9,9 @@
#ifndef MLIR_CONVERSION_ARITHANDMATHTOAPFLOAT_UTILS_H_
#define MLIR_CONVERSION_ARITHANDMATHTOAPFLOAT_UTILS_H_
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/IR/PatternMatch.h"
+
namespace mlir {
class Value;
class OpBuilder;
@@ -16,6 +19,59 @@ class Location;
class FloatType;
Value getAPFloatSemanticsValue(OpBuilder &b, Location loc, FloatType floatTy);
+
+/// Given two operands of vector type and a vector result type (with the same
+/// shape), call `fn` for each pair of scalar operands and package the scalar
+/// results into a vector of type `resultType`. If the operands are not
+/// vectors, call `fn` directly. The second operand is optional: when it is
+/// null, `fn` receives a null Value as its second argument.
+///
+/// `fn` is invoked as `fn(scalar1, scalar2, elementType)`; in the vector case
+/// it must return a value of `resultType`'s element type.
+template <typename Fn>
+Value forEachScalarValue(mlir::RewriterBase &rewriter, Location loc,
+ Value operand1, Value operand2, Type resultType,
+ Fn fn) {
+ auto vecTy1 = dyn_cast<VectorType>(operand1.getType());
+ if (operand2) {
+ // Sanity check: Operand types must match.
+ assert(vecTy1 == dyn_cast<VectorType>(operand2.getType()) &&
+ "expected same vector types");
+ }
+ if (!vecTy1) {
+ // Not a vector. Call the function directly.
+ return fn(operand1, operand2, resultType);
+ }
+
+ // Prepare scalar operands.
+ ResultRange scalars1 =
+ vector::ToElementsOp::create(rewriter, loc, operand1)->getResults();
+ SmallVector<Value> scalars2;
+ if (!operand2) {
+ // No second operand. Create a vector of empty values.
+ scalars2.assign(vecTy1.getNumElements(), Value());
+ } else {
+ llvm::append_range(
+ scalars2,
+ vector::ToElementsOp::create(rewriter, loc, operand2)->getResults());
+ }
+
+ // Call the function for each pair of scalar operands.
+ auto resultVecType = cast<VectorType>(resultType);
+ Type resultElemType = resultVecType.getElementType();
+ SmallVector<Value> results;
+ results.reserve(scalars2.size());
+ for (auto [scalar1, scalar2] : llvm::zip_equal(scalars1, scalars2))
+ results.push_back(fn(scalar1, scalar2, resultElemType));
+
+ // Package the results into a vector of exactly `resultType`. Deriving the
+ // type from `results.front()` would be UB for zero-element vectors.
+ return vector::FromElementsOp::create(rewriter, loc, resultVecType,
+ results);
+}
+
+/// Check preconditions for the conversion:
+/// 1. All operands / results must be integers or floats (or vectors thereof).
+/// 2. The bitwidth of the operands / results must be <= 64.
+LogicalResult checkPreconditions(RewriterBase &rewriter, Operation *op);
+
} // namespace mlir
#endif // MLIR_CONVERSION_ARITHANDMATHTOAPFLOAT_UTILS_H_
>From ddbcffafcf762a6357ed2e5b9cb3fe58e182e4ea Mon Sep 17 00:00:00 2001
From: makslevental <maksim.levental at gmail.com>
Date: Wed, 17 Dec 2025 15:40:12 -0800
Subject: [PATCH 2/4] vectorize isop and abs (but not tests)
---
.../ArithAndMathToAPFloat/MathToAPFloat.cpp | 93 +++++++++----------
1 file changed, 46 insertions(+), 47 deletions(-)
diff --git a/mlir/lib/Conversion/ArithAndMathToAPFloat/MathToAPFloat.cpp b/mlir/lib/Conversion/ArithAndMathToAPFloat/MathToAPFloat.cpp
index 784028f5cf2eb..fad59b6a4530f 100644
--- a/mlir/lib/Conversion/ArithAndMathToAPFloat/MathToAPFloat.cpp
+++ b/mlir/lib/Conversion/ArithAndMathToAPFloat/MathToAPFloat.cpp
@@ -33,16 +33,8 @@ struct AbsFOpToAPFloatConversion final : OpRewritePattern<math::AbsFOp> {
LogicalResult matchAndRewrite(math::AbsFOp op,
PatternRewriter &rewriter) const override {
- // Cast operands to 64-bit integers.
- auto operand = op.getOperand();
- auto floatTy = dyn_cast<FloatType>(operand.getType());
- if (!floatTy)
- return rewriter.notifyMatchFailure(op,
- "only scalar FloatTypes supported");
- if (floatTy.getIntOrFloatBitWidth() > 64) {
- return rewriter.notifyMatchFailure(op,
- "bitwidth > 64 bits is not supported");
- }
+ if (failed(checkPreconditions(rewriter, op)))
+ return failure();
// Get APFloat function from runtime library.
auto i32Type = IntegerType::get(symTable->getContext(), 32);
auto i64Type = IntegerType::get(symTable->getContext(), 64);
@@ -52,23 +44,30 @@ struct AbsFOpToAPFloatConversion final : OpRewritePattern<math::AbsFOp> {
return fn;
Location loc = op.getLoc();
rewriter.setInsertionPoint(op);
- auto intWType = rewriter.getIntegerType(floatTy.getWidth());
- Value operandBits = arith::ExtUIOp::create(
- rewriter, loc, i64Type,
- arith::BitcastOp::create(rewriter, loc, intWType, operand));
-
- // Call APFloat function.
- Value semValue = getAPFloatSemanticsValue(rewriter, loc, floatTy);
- SmallVector<Value> params = {semValue, operandBits};
- Value negatedBits = func::CallOp::create(rewriter, loc, TypeRange(i64Type),
- SymbolRefAttr::get(*fn), params)
- ->getResult(0);
-
- // Truncate result to the original width.
- Value truncatedBits =
- arith::TruncIOp::create(rewriter, loc, intWType, negatedBits);
- rewriter.replaceOp(
- op, arith::BitcastOp::create(rewriter, loc, floatTy, truncatedBits));
+ // Scalarize and convert to APFloat runtime calls.
+ Value repl = forEachScalarValue(
+ rewriter, loc, op.getOperand(), /*operand2=*/Value(), op.getType(),
+ [&](Value operand, Value, Type resultType) {
+ auto floatTy = cast<FloatType>(operand.getType());
+ auto intWType = rewriter.getIntegerType(floatTy.getWidth());
+ Value operandBits = arith::ExtUIOp::create(
+ rewriter, loc, i64Type,
+ arith::BitcastOp::create(rewriter, loc, intWType, operand));
+ // Call APFloat function.
+ Value semValue = getAPFloatSemanticsValue(rewriter, loc, floatTy);
+ SmallVector<Value> params = {semValue, operandBits};
+ Value negatedBits =
+ func::CallOp::create(rewriter, loc, TypeRange(i64Type),
+ SymbolRefAttr::get(*fn), params)
+ ->getResult(0);
+ // Truncate result to the original width.
+ auto truncatedBits =
+ arith::TruncIOp::create(rewriter, loc, intWType, negatedBits);
+ return arith::BitcastOp::create(rewriter, loc, floatTy,
+ truncatedBits);
+ });
+
+ rewriter.replaceOp(op, repl);
return success();
}
@@ -85,16 +84,8 @@ struct IsOpToAPFloatConversion final : OpRewritePattern<OpTy> {
LogicalResult matchAndRewrite(OpTy op,
PatternRewriter &rewriter) const override {
- // Cast operands to 64-bit integers.
- auto operand = op.getOperand();
- auto floatTy = dyn_cast<FloatType>(operand.getType());
- if (!floatTy)
- return rewriter.notifyMatchFailure(op,
- "only scalar FloatTypes supported");
- if (floatTy.getIntOrFloatBitWidth() > 64) {
- return rewriter.notifyMatchFailure(op,
- "bitwidth > 64 bits is not supported");
- }
+ if (failed(checkPreconditions(rewriter, op)))
+ return failure();
// Get APFloat function from runtime library.
auto i1 = IntegerType::get(symTable->getContext(), 1);
auto i32Type = IntegerType::get(symTable->getContext(), 32);
@@ -107,16 +98,24 @@ struct IsOpToAPFloatConversion final : OpRewritePattern<OpTy> {
return fn;
Location loc = op.getLoc();
rewriter.setInsertionPoint(op);
- auto intWType = rewriter.getIntegerType(floatTy.getWidth());
- Value operandBits = arith::ExtUIOp::create(
- rewriter, loc, i64Type,
- arith::BitcastOp::create(rewriter, loc, intWType, operand));
-
- // Call APFloat function.
- Value semValue = getAPFloatSemanticsValue(rewriter, loc, floatTy);
- SmallVector<Value> params = {semValue, operandBits};
- rewriter.replaceOpWithNewOp<func::CallOp>(op, TypeRange(i1),
- SymbolRefAttr::get(*fn), params);
+ // Scalarize and convert to APFloat runtime calls.
+ Value repl = forEachScalarValue(
+ rewriter, loc, op.getOperand(), /*operand2=*/Value(), op.getType(),
+ [&](Value operand, Value, Type resultType) {
+ // `operand` is a scalar here (vectors are unpacked by
+ // forEachScalarValue). Use a checked `cast`: the result is
+ // dereferenced immediately, so a null `dyn_cast` would crash.
+ auto floatTy = cast<FloatType>(operand.getType());
+ auto intWType = rewriter.getIntegerType(floatTy.getWidth());
+ Value operandBits = arith::ExtUIOp::create(
+ rewriter, loc, i64Type,
+ arith::BitcastOp::create(rewriter, loc, intWType, operand));
+
+ // Call APFloat function.
+ Value semValue = getAPFloatSemanticsValue(rewriter, loc, floatTy);
+ SmallVector<Value> params = {semValue, operandBits};
+ return func::CallOp::create(rewriter, loc, TypeRange(i1),
+ SymbolRefAttr::get(*fn), params)
+ .getResult(0);
+ });
+ rewriter.replaceOp(op, repl);
return success();
}
>From a1c003e74c02d650a24c97185ff293a24baaa939 Mon Sep 17 00:00:00 2001
From: makslevental <maksim.levental at gmail.com>
Date: Fri, 9 Jan 2026 10:19:00 -0800
Subject: [PATCH 3/4] vectorize fma
---
.../ArithAndMathToAPFloat/MathToAPFloat.cpp | 72 +++++++++++++------
1 file changed, 50 insertions(+), 22 deletions(-)
diff --git a/mlir/lib/Conversion/ArithAndMathToAPFloat/MathToAPFloat.cpp b/mlir/lib/Conversion/ArithAndMathToAPFloat/MathToAPFloat.cpp
index fad59b6a4530f..172af125e959e 100644
--- a/mlir/lib/Conversion/ArithAndMathToAPFloat/MathToAPFloat.cpp
+++ b/mlir/lib/Conversion/ArithAndMathToAPFloat/MathToAPFloat.cpp
@@ -130,16 +130,10 @@ struct FmaOpToAPFloatConversion final : OpRewritePattern<math::FmaOp> {
LogicalResult matchAndRewrite(math::FmaOp op,
PatternRewriter &rewriter) const override {
+ if (failed(checkPreconditions(rewriter, op)))
+ return failure();
// Cast operands to 64-bit integers.
- auto floatTy = cast<FloatType>(op.getResult().getType());
+ // `op` may operate on vectors, but the APFloat runtime is called once per
+ // element, so take the semantics from the (scalar) element type. A plain
+ // cast<FloatType> of a vector type would assert.
+ Type elemType = op.getResult().getType();
+ if (auto resultVecTy = dyn_cast<VectorType>(elemType))
+ elemType = resultVecTy.getElementType();
+ auto floatTy = cast<FloatType>(elemType);
- if (!floatTy)
- return rewriter.notifyMatchFailure(op,
- "only scalar FloatTypes supported");
- if (floatTy.getIntOrFloatBitWidth() > 64) {
- return rewriter.notifyMatchFailure(op,
- "bitwidth > 64 bits is not supported");
- }
-
auto i32Type = IntegerType::get(symTable->getContext(), 32);
auto i64Type = IntegerType::get(symTable->getContext(), 64);
FailureOr<FuncOp> fn = lookupOrCreateFnDecl(
@@ -150,8 +144,52 @@ struct FmaOpToAPFloatConversion final : OpRewritePattern<math::FmaOp> {
Location loc = op.getLoc();
rewriter.setInsertionPoint(op);
- auto intWType = rewriter.getIntegerType(floatTy.getWidth());
- auto int64Type = rewriter.getI64Type();
+ IntegerType intWType = rewriter.getIntegerType(floatTy.getWidth());
+ IntegerType int64Type = rewriter.getI64Type();
+
+ // Emit the runtime call on i64-encoded operands and turn the i64 result
+ // back into a `floatTy` value.
+ auto scalarFMA = [&rewriter, &loc, &floatTy, &fn, &intWType](
+ Value operand, Value multiplicand, Value addend) {
+ // Call APFloat function.
+ Value semValue = getAPFloatSemanticsValue(rewriter, loc, floatTy);
+ SmallVector<Value> params = {semValue, operand, multiplicand, addend};
+ auto resultOp =
+ func::CallOp::create(rewriter, loc, TypeRange(rewriter.getI64Type()),
+ SymbolRefAttr::get(*fn), params);
+
+ // Truncate result to the original width.
+ auto trunc = arith::TruncIOp::create(rewriter, loc, intWType,
+ resultOp->getResult(0));
+ return arith::BitcastOp::create(rewriter, loc, floatTy, trunc);
+ };
+
+ // Bitcast a float scalar to its same-width integer and zero-extend it to
+ // i64, the same operand encoding the scalar path below uses.
+ auto toI64Bits = [&](Value scalar) -> Value {
+ return arith::ExtUIOp::create(
+ rewriter, loc, int64Type,
+ arith::BitcastOp::create(rewriter, loc, intWType, scalar));
+ };
+
+ if (VectorType vecTy1 = dyn_cast<VectorType>(op.getA().getType())) {
+ // Sanity check: Operand types must match.
+ assert(vecTy1 == dyn_cast<VectorType>(op.getB().getType()) &&
+ "expected same vector types");
+ assert(vecTy1 == dyn_cast<VectorType>(op.getC().getType()) &&
+ "expected same vector types");
+ // Prepare scalar operands.
+ ResultRange scalarOperands =
+ vector::ToElementsOp::create(rewriter, loc, op.getA())->getResults();
+ ResultRange scalarMultiplicands =
+ vector::ToElementsOp::create(rewriter, loc, op.getB())->getResults();
+ ResultRange scalarAddends =
+ vector::ToElementsOp::create(rewriter, loc, op.getC())->getResults();
+ // Call the function for each triple of scalar operands. The extracted
+ // elements are still floats, so encode them as i64 bit patterns first:
+ // the runtime function takes i64 parameters.
+ SmallVector<Value> results;
+ results.reserve(scalarOperands.size());
+ for (auto [operand, multiplicand, addend] : llvm::zip_equal(
+ scalarOperands, scalarMultiplicands, scalarAddends)) {
+ results.push_back(scalarFMA(toI64Bits(operand), toI64Bits(multiplicand),
+ toI64Bits(addend)));
+ }
+ // Package the results into a vector. The replacement must have the op's
+ // result type; using it directly also avoids `results.front()` on an
+ // empty vector.
+ auto fromElements = vector::FromElementsOp::create(
+ rewriter, loc, cast<VectorType>(op.getResult().getType()), results);
+ rewriter.replaceOp(op, fromElements);
+ return success();
+ }
+
Value operand = arith::ExtUIOp::create(
rewriter, loc, int64Type,
arith::BitcastOp::create(rewriter, loc, intWType, op.getA()));
@@ -161,18 +199,8 @@ struct FmaOpToAPFloatConversion final : OpRewritePattern<math::FmaOp> {
Value addend = arith::ExtUIOp::create(
rewriter, loc, int64Type,
arith::BitcastOp::create(rewriter, loc, intWType, op.getC()));
-
- // Call APFloat function.
- Value semValue = getAPFloatSemanticsValue(rewriter, loc, floatTy);
- SmallVector<Value> params = {semValue, operand, multiplicand, addend};
- auto resultOp =
- func::CallOp::create(rewriter, loc, TypeRange(rewriter.getI64Type()),
- SymbolRefAttr::get(*fn), params);
-
- // Truncate result to the original width.
- Value truncatedBits = arith::TruncIOp::create(rewriter, loc, intWType,
- resultOp->getResult(0));
- rewriter.replaceOpWithNewOp<arith::BitcastOp>(op, floatTy, truncatedBits);
+ Value repl = scalarFMA(operand, multiplicand, addend);
+ rewriter.replaceOp(op, repl);
return success();
}
>From 102d672855ad74a37f0f28f26c1f4d1bfa8b712d Mon Sep 17 00:00:00 2001
From: makslevental <maksim.levental at gmail.com>
Date: Fri, 9 Jan 2026 11:27:42 -0800
Subject: [PATCH 4/4] add source type pre-condition
---
.../ArithAndMathToAPFloat/ArithToAPFloat.h | 1 +
.../ArithAndMathToAPFloat/MathToAPFloat.h | 1 +
mlir/include/mlir/Conversion/Passes.td | 8 ++
.../ArithAndMathToAPFloat/ArithToAPFloat.cpp | 84 ++++++++++++-------
.../ArithAndMathToAPFloat/MathToAPFloat.cpp | 46 ++++++----
.../ArithAndMathToAPFloat/Utils.cpp | 30 ++++++-
.../Conversion/ArithAndMathToAPFloat/Utils.h | 7 +-
7 files changed, 126 insertions(+), 51 deletions(-)
diff --git a/mlir/include/mlir/Conversion/ArithAndMathToAPFloat/ArithToAPFloat.h b/mlir/include/mlir/Conversion/ArithAndMathToAPFloat/ArithToAPFloat.h
index 6702aca045ba4..2dacc2e11b049 100644
--- a/mlir/include/mlir/Conversion/ArithAndMathToAPFloat/ArithToAPFloat.h
+++ b/mlir/include/mlir/Conversion/ArithAndMathToAPFloat/ArithToAPFloat.h
@@ -9,6 +9,7 @@
#ifndef MLIR_CONVERSION_ARITHANDMATHTOAPFLOAT_ARITHTOAPFLOAT_H
#define MLIR_CONVERSION_ARITHANDMATHTOAPFLOAT_ARITHTOAPFLOAT_H
+#include "llvm/ADT/SmallVector.h"
#include <memory>
namespace mlir {
diff --git a/mlir/include/mlir/Conversion/ArithAndMathToAPFloat/MathToAPFloat.h b/mlir/include/mlir/Conversion/ArithAndMathToAPFloat/MathToAPFloat.h
index 6cb44c89ecebb..06548c250a27b 100644
--- a/mlir/include/mlir/Conversion/ArithAndMathToAPFloat/MathToAPFloat.h
+++ b/mlir/include/mlir/Conversion/ArithAndMathToAPFloat/MathToAPFloat.h
@@ -9,6 +9,7 @@
#ifndef MLIR_CONVERSION_ARITHANDMATHTOAPFLOAT_MATHTOAPFLOAT_H
#define MLIR_CONVERSION_ARITHANDMATHTOAPFLOAT_MATHTOAPFLOAT_H
+#include "llvm/ADT/SmallVector.h"
#include <memory>
namespace mlir {
diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td
index 7f24e58671aab..fb2860bee6d43 100644
--- a/mlir/include/mlir/Conversion/Passes.td
+++ b/mlir/include/mlir/Conversion/Passes.td
@@ -198,6 +198,10 @@ def ArithToAPFloatConversionPass
calls (APFloatWrappers.cpp). APFloat is a software implementation of
floating-point arithmetic operations.
}];
+ let options = [
+ ListOption<"sourceTypeStrs", "source-types", "std::string",
+ "MLIR types without arithmetic support on a given target">,
+ ];
let dependentDialects = ["arith::ArithDialect", "func::FuncDialect",
"vector::VectorDialect"];
}
@@ -787,6 +791,10 @@ def MathToAPFloatConversionPass
calls (APFloatWrappers.cpp). APFloat is a software implementation of
floating-point mathmetic operations.
}];
+ let options = [
+ ListOption<"sourceTypeStrs", "source-types", "std::string",
+ "MLIR types without arithmetic support on a given target">,
+ ];
- let dependentDialects = ["math::MathDialect", "func::FuncDialect"];
+ // The patterns scalarize vector operands with vector.to_elements /
+ // vector.from_elements, so the vector dialect must be loaded.
+ let dependentDialects = ["math::MathDialect", "func::FuncDialect",
+ "vector::VectorDialect"];
}
diff --git a/mlir/lib/Conversion/ArithAndMathToAPFloat/ArithToAPFloat.cpp b/mlir/lib/Conversion/ArithAndMathToAPFloat/ArithToAPFloat.cpp
index 98185697e4591..52eb32de6586b 100644
--- a/mlir/lib/Conversion/ArithAndMathToAPFloat/ArithToAPFloat.cpp
+++ b/mlir/lib/Conversion/ArithAndMathToAPFloat/ArithToAPFloat.cpp
@@ -52,13 +52,14 @@ struct BinaryArithOpToAPFloatConversion final : OpRewritePattern<OpTy> {
BinaryArithOpToAPFloatConversion(MLIRContext *context,
const char *APFloatName,
SymbolOpInterface symTable,
+ ArrayRef<Type> sourceTypes,
PatternBenefit benefit = 1)
: OpRewritePattern<OpTy>(context, benefit), symTable(symTable),
- APFloatName(APFloatName) {};
+ APFloatName(APFloatName), sourceTypes(sourceTypes) {};
LogicalResult matchAndRewrite(OpTy op,
PatternRewriter &rewriter) const override {
- if (failed(checkPreconditions(rewriter, op)))
+ if (failed(checkPreconditions(rewriter, op, sourceTypes)))
return failure();
// Get APFloat function from runtime library.
@@ -103,17 +104,19 @@ struct BinaryArithOpToAPFloatConversion final : OpRewritePattern<OpTy> {
SymbolOpInterface symTable;
const char *APFloatName;
+ ArrayRef<Type> sourceTypes;
};
template <typename OpTy>
struct FpToFpConversion final : OpRewritePattern<OpTy> {
FpToFpConversion(MLIRContext *context, SymbolOpInterface symTable,
- PatternBenefit benefit = 1)
- : OpRewritePattern<OpTy>(context, benefit), symTable(symTable) {}
+ ArrayRef<Type> sourceTypes, PatternBenefit benefit = 1)
+ : OpRewritePattern<OpTy>(context, benefit), symTable(symTable),
+ sourceTypes(sourceTypes) {}
LogicalResult matchAndRewrite(OpTy op,
PatternRewriter &rewriter) const override {
- if (failed(checkPreconditions(rewriter, op)))
+ if (failed(checkPreconditions(rewriter, op, sourceTypes)))
return failure();
// Get APFloat function from runtime library.
@@ -160,18 +163,20 @@ struct FpToFpConversion final : OpRewritePattern<OpTy> {
}
SymbolOpInterface symTable;
+ ArrayRef<Type> sourceTypes;
};
template <typename OpTy>
struct FpToIntConversion final : OpRewritePattern<OpTy> {
FpToIntConversion(MLIRContext *context, SymbolOpInterface symTable,
- bool isUnsigned, PatternBenefit benefit = 1)
+ bool isUnsigned, ArrayRef<Type> sourceTypes,
+ PatternBenefit benefit = 1)
: OpRewritePattern<OpTy>(context, benefit), symTable(symTable),
- isUnsigned(isUnsigned) {}
+ isUnsigned(isUnsigned), sourceTypes(sourceTypes) {}
LogicalResult matchAndRewrite(OpTy op,
PatternRewriter &rewriter) const override {
- if (failed(checkPreconditions(rewriter, op)))
+ if (failed(checkPreconditions(rewriter, op, sourceTypes)))
return failure();
// Get APFloat function from runtime library.
@@ -222,18 +227,20 @@ struct FpToIntConversion final : OpRewritePattern<OpTy> {
SymbolOpInterface symTable;
bool isUnsigned;
+ ArrayRef<Type> sourceTypes;
};
template <typename OpTy>
struct IntToFpConversion final : OpRewritePattern<OpTy> {
IntToFpConversion(MLIRContext *context, SymbolOpInterface symTable,
- bool isUnsigned, PatternBenefit benefit = 1)
+ bool isUnsigned, ArrayRef<Type> sourceTypes,
+ PatternBenefit benefit = 1)
: OpRewritePattern<OpTy>(context, benefit), symTable(symTable),
- isUnsigned(isUnsigned) {}
+ isUnsigned(isUnsigned), sourceTypes(sourceTypes) {}
LogicalResult matchAndRewrite(OpTy op,
PatternRewriter &rewriter) const override {
- if (failed(checkPreconditions(rewriter, op)))
+ if (failed(checkPreconditions(rewriter, op, sourceTypes)))
return failure();
// Get APFloat function from runtime library.
@@ -294,16 +301,19 @@ struct IntToFpConversion final : OpRewritePattern<OpTy> {
SymbolOpInterface symTable;
bool isUnsigned;
+ ArrayRef<Type> sourceTypes;
};
struct CmpFOpToAPFloatConversion final : OpRewritePattern<arith::CmpFOp> {
CmpFOpToAPFloatConversion(MLIRContext *context, SymbolOpInterface symTable,
+ ArrayRef<Type> sourceTypes,
PatternBenefit benefit = 1)
- : OpRewritePattern<arith::CmpFOp>(context, benefit), symTable(symTable) {}
+ : OpRewritePattern<arith::CmpFOp>(context, benefit), symTable(symTable),
+ sourceTypes(sourceTypes) {}
LogicalResult matchAndRewrite(arith::CmpFOp op,
PatternRewriter &rewriter) const override {
- if (failed(checkPreconditions(rewriter, op)))
+ if (failed(checkPreconditions(rewriter, op, sourceTypes)))
return failure();
// Get APFloat function from runtime library.
@@ -445,16 +455,19 @@ struct CmpFOpToAPFloatConversion final : OpRewritePattern<arith::CmpFOp> {
}
SymbolOpInterface symTable;
+ ArrayRef<Type> sourceTypes;
};
struct NegFOpToAPFloatConversion final : OpRewritePattern<arith::NegFOp> {
NegFOpToAPFloatConversion(MLIRContext *context, SymbolOpInterface symTable,
+ ArrayRef<Type> sourceTypes,
PatternBenefit benefit = 1)
- : OpRewritePattern<arith::NegFOp>(context, benefit), symTable(symTable) {}
+ : OpRewritePattern<arith::NegFOp>(context, benefit), symTable(symTable),
+ sourceTypes(sourceTypes) {}
LogicalResult matchAndRewrite(arith::NegFOp op,
PatternRewriter &rewriter) const override {
- if (failed(checkPreconditions(rewriter, op)))
+ if (failed(checkPreconditions(rewriter, op, sourceTypes)))
return failure();
// Get APFloat function from runtime library.
@@ -497,6 +510,7 @@ struct NegFOpToAPFloatConversion final : OpRewritePattern<arith::NegFOp> {
}
SymbolOpInterface symTable;
+ ArrayRef<Type> sourceTypes;
};
namespace {
@@ -510,36 +524,46 @@ struct ArithToAPFloatConversionPass final
void ArithToAPFloatConversionPass::runOnOperation() {
MLIRContext *context = &getContext();
RewritePatternSet patterns(context);
- patterns.add<BinaryArithOpToAPFloatConversion<arith::AddFOp>>(context, "add",
- getOperation());
+
+ FailureOr<SmallVector<Type>> sourceTypes =
+ parseSourceTypes(llvm::to_vector(sourceTypeStrs), context);
+ if (failed(sourceTypes))
+ return signalPassFailure();
+
+ patterns.add<BinaryArithOpToAPFloatConversion<arith::AddFOp>>(
+ context, "add", getOperation(), *sourceTypes);
patterns.add<BinaryArithOpToAPFloatConversion<arith::SubFOp>>(
- context, "subtract", getOperation());
+ context, "subtract", getOperation(), *sourceTypes);
patterns.add<BinaryArithOpToAPFloatConversion<arith::MulFOp>>(
- context, "multiply", getOperation());
+ context, "multiply", getOperation(), *sourceTypes);
patterns.add<BinaryArithOpToAPFloatConversion<arith::DivFOp>>(
- context, "divide", getOperation());
+ context, "divide", getOperation(), *sourceTypes);
patterns.add<BinaryArithOpToAPFloatConversion<arith::RemFOp>>(
- context, "remainder", getOperation());
+ context, "remainder", getOperation(), *sourceTypes);
patterns.add<BinaryArithOpToAPFloatConversion<arith::MinNumFOp>>(
- context, "minnum", getOperation());
+ context, "minnum", getOperation(), *sourceTypes);
patterns.add<BinaryArithOpToAPFloatConversion<arith::MaxNumFOp>>(
- context, "maxnum", getOperation());
+ context, "maxnum", getOperation(), *sourceTypes);
patterns.add<BinaryArithOpToAPFloatConversion<arith::MinimumFOp>>(
- context, "minimum", getOperation());
+ context, "minimum", getOperation(), *sourceTypes);
patterns.add<BinaryArithOpToAPFloatConversion<arith::MaximumFOp>>(
- context, "maximum", getOperation());
+ context, "maximum", getOperation(), *sourceTypes);
patterns
.add<FpToFpConversion<arith::ExtFOp>, FpToFpConversion<arith::TruncFOp>,
CmpFOpToAPFloatConversion, NegFOpToAPFloatConversion>(
- context, getOperation());
+ context, getOperation(), *sourceTypes);
patterns.add<FpToIntConversion<arith::FPToSIOp>>(context, getOperation(),
- /*isUnsigned=*/false);
+ /*isUnsigned=*/false,
+ *sourceTypes);
patterns.add<FpToIntConversion<arith::FPToUIOp>>(context, getOperation(),
- /*isUnsigned=*/true);
+ /*isUnsigned=*/true,
+ *sourceTypes);
patterns.add<IntToFpConversion<arith::SIToFPOp>>(context, getOperation(),
- /*isUnsigned=*/false);
+ /*isUnsigned=*/false,
+ *sourceTypes);
patterns.add<IntToFpConversion<arith::UIToFPOp>>(context, getOperation(),
- /*isUnsigned=*/true);
+ /*isUnsigned=*/true,
+ *sourceTypes);
LogicalResult result = success();
ScopedDiagnosticHandler scopedHandler(context, [&result](Diagnostic &diag) {
if (diag.getSeverity() == DiagnosticSeverity::Error) {
diff --git a/mlir/lib/Conversion/ArithAndMathToAPFloat/MathToAPFloat.cpp b/mlir/lib/Conversion/ArithAndMathToAPFloat/MathToAPFloat.cpp
index 172af125e959e..b5e15e5c42bed 100644
--- a/mlir/lib/Conversion/ArithAndMathToAPFloat/MathToAPFloat.cpp
+++ b/mlir/lib/Conversion/ArithAndMathToAPFloat/MathToAPFloat.cpp
@@ -28,12 +28,14 @@ using namespace mlir::func;
struct AbsFOpToAPFloatConversion final : OpRewritePattern<math::AbsFOp> {
AbsFOpToAPFloatConversion(MLIRContext *context, SymbolOpInterface symTable,
+ ArrayRef<Type> sourceTypes,
PatternBenefit benefit = 1)
- : OpRewritePattern<math::AbsFOp>(context, benefit), symTable(symTable) {}
+ : OpRewritePattern<math::AbsFOp>(context, benefit), symTable(symTable),
+ sourceTypes(sourceTypes) {}
LogicalResult matchAndRewrite(math::AbsFOp op,
PatternRewriter &rewriter) const override {
- if (failed(checkPreconditions(rewriter, op)))
+ if (failed(checkPreconditions(rewriter, op, sourceTypes)))
return failure();
// Get APFloat function from runtime library.
auto i32Type = IntegerType::get(symTable->getContext(), 32);
@@ -72,19 +74,21 @@ struct AbsFOpToAPFloatConversion final : OpRewritePattern<math::AbsFOp> {
}
SymbolOpInterface symTable;
+ ArrayRef<Type> sourceTypes;
};
template <typename OpTy>
struct IsOpToAPFloatConversion final : OpRewritePattern<OpTy> {
IsOpToAPFloatConversion(MLIRContext *context, const char *APFloatName,
SymbolOpInterface symTable,
+ ArrayRef<Type> sourceTypes,
PatternBenefit benefit = 1)
: OpRewritePattern<OpTy>(context, benefit), symTable(symTable),
- APFloatName(APFloatName) {};
+ APFloatName(APFloatName), sourceTypes(sourceTypes) {};
LogicalResult matchAndRewrite(OpTy op,
PatternRewriter &rewriter) const override {
- if (failed(checkPreconditions(rewriter, op)))
+ if (failed(checkPreconditions(rewriter, op, sourceTypes)))
return failure();
// Get APFloat function from runtime library.
auto i1 = IntegerType::get(symTable->getContext(), 1);
@@ -121,16 +125,19 @@ struct IsOpToAPFloatConversion final : OpRewritePattern<OpTy> {
SymbolOpInterface symTable;
const char *APFloatName;
+ ArrayRef<Type> sourceTypes;
};
struct FmaOpToAPFloatConversion final : OpRewritePattern<math::FmaOp> {
FmaOpToAPFloatConversion(MLIRContext *context, SymbolOpInterface symTable,
+ ArrayRef<Type> sourceTypes,
PatternBenefit benefit = 1)
- : OpRewritePattern<math::FmaOp>(context, benefit), symTable(symTable) {};
+ : OpRewritePattern<math::FmaOp>(context, benefit), symTable(symTable),
+ sourceTypes(sourceTypes) {};
LogicalResult matchAndRewrite(math::FmaOp op,
PatternRewriter &rewriter) const override {
- if (failed(checkPreconditions(rewriter, op)))
+ if (failed(checkPreconditions(rewriter, op, sourceTypes)))
return failure();
// Cast operands to 64-bit integers.
auto floatTy = cast<FloatType>(op.getResult().getType());
@@ -205,6 +212,7 @@ struct FmaOpToAPFloatConversion final : OpRewritePattern<math::FmaOp> {
}
SymbolOpInterface symTable;
+ ArrayRef<Type> sourceTypes;
};
namespace {
@@ -219,16 +227,22 @@ void MathToAPFloatConversionPass::runOnOperation() {
MLIRContext *context = &getContext();
RewritePatternSet patterns(context);
- patterns.add<AbsFOpToAPFloatConversion>(context, getOperation());
- patterns.add<IsOpToAPFloatConversion<math::IsFiniteOp>>(context, "finite",
- getOperation());
- patterns.add<IsOpToAPFloatConversion<math::IsInfOp>>(context, "infinite",
- getOperation());
- patterns.add<IsOpToAPFloatConversion<math::IsNaNOp>>(context, "nan",
- getOperation());
- patterns.add<IsOpToAPFloatConversion<math::IsNormalOp>>(context, "normal",
- getOperation());
- patterns.add<FmaOpToAPFloatConversion>(context, getOperation());
+ FailureOr<SmallVector<Type>> sourceTypes =
+ parseSourceTypes(llvm::to_vector(sourceTypeStrs), context);
+ if (failed(sourceTypes))
+ return signalPassFailure();
+
+ patterns.add<AbsFOpToAPFloatConversion>(context, getOperation(),
+ *sourceTypes);
+ patterns.add<IsOpToAPFloatConversion<math::IsFiniteOp>>(
+ context, "finite", getOperation(), *sourceTypes);
+ patterns.add<IsOpToAPFloatConversion<math::IsInfOp>>(
+ context, "infinite", getOperation(), *sourceTypes);
+ patterns.add<IsOpToAPFloatConversion<math::IsNaNOp>>(
+ context, "nan", getOperation(), *sourceTypes);
+ patterns.add<IsOpToAPFloatConversion<math::IsNormalOp>>(
+ context, "normal", getOperation(), *sourceTypes);
+ patterns.add<FmaOpToAPFloatConversion>(context, getOperation(), *sourceTypes);
LogicalResult result = success();
ScopedDiagnosticHandler scopedHandler(context, [&result](Diagnostic &diag) {
diff --git a/mlir/lib/Conversion/ArithAndMathToAPFloat/Utils.cpp b/mlir/lib/Conversion/ArithAndMathToAPFloat/Utils.cpp
index 340e015404d86..fb31d9528d2cf 100644
--- a/mlir/lib/Conversion/ArithAndMathToAPFloat/Utils.cpp
+++ b/mlir/lib/Conversion/ArithAndMathToAPFloat/Utils.cpp
@@ -7,6 +7,7 @@
//===----------------------------------------------------------------------===//
#include "Utils.h"
+#include "mlir/Dialect/Arith/Utils/Utils.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
@@ -25,19 +26,40 @@ Value mlir::getAPFloatSemanticsValue(OpBuilder &b, Location loc,
b.getIntegerAttr(b.getI32Type(), sem));
}
-LogicalResult mlir::checkPreconditions(RewriterBase &rewriter, Operation *op) {
+/// Returns failure (recording a match-failure note) unless every operand and
+/// result of `op` is an integer or float (or a vector thereof) of at most 64
+/// bits and, when `sourceTypes` is non-empty, every floating-point (element)
+/// type is listed in `sourceTypes`. Integer types are never filtered.
+LogicalResult mlir::checkPreconditions(RewriterBase &rewriter, Operation *op,
+                                       ArrayRef<Type> sourceTypes) {
   for (Value value : llvm::concat<Value>(op->getOperands(), op->getResults())) {
     Type type = value.getType();
-    if (auto vecTy = dyn_cast<VectorType>(type)) {
+    if (auto vecTy = dyn_cast<VectorType>(type))
       type = vecTy.getElementType();
-    }
+    // The source-type filter is a list of floating-point types (as produced by
+    // parseSourceTypes), so only apply it to floating-point types. Otherwise
+    // any op with an integer operand or result (e.g. the i1 result of
+    // math.isnan / math.isfinite) would always be rejected once a filter is
+    // set, because no integer type can ever appear in `sourceTypes`.
+    if (!sourceTypes.empty() && isa<FloatType>(type) &&
+        !llvm::is_contained(sourceTypes, type))
+      return rewriter.notifyMatchFailure(op, "unsupported source type");
     if (!type.isIntOrFloat()) {
       return rewriter.notifyMatchFailure(
           op, "only integers and floats (or vectors thereof) are supported");
     }
-    if (type.getIntOrFloatBitWidth() > 64)
+    if (type.getIntOrFloatBitWidth() > 64) {
       return rewriter.notifyMatchFailure(op,
                                          "bitwidth > 64 bits is not supported");
+    }
   }
   return success();
 }
+
+/// Parse each entry of `sourceTypeStrs` as a floating-point type name via
+/// `arith::parseFloatType` and return the resulting types in order. On the
+/// first entry that does not name a known floating-point type, emit an error
+/// (at an unknown location, as no IR location is available here) and return
+/// failure.
+FailureOr<SmallVector<Type>>
+mlir::parseSourceTypes(SmallVector<std::string> sourceTypeStrs,
+                       MLIRContext *ctx) {
+  SmallVector<Type> sourceTypes;
+  sourceTypes.reserve(sourceTypeStrs.size());
+  for (StringRef sourceTypeStr : sourceTypeStrs) {
+    std::optional<FloatType> maybeSourceType =
+        arith::parseFloatType(ctx, sourceTypeStr);
+    if (!maybeSourceType) {
+      emitError(UnknownLoc::get(ctx), "could not map source type '" +
+                                          sourceTypeStr +
+                                          "' to a known floating-point type");
+      return failure();
+    }
+    sourceTypes.push_back(*maybeSourceType);
+  }
+  return sourceTypes;
+}
diff --git a/mlir/lib/Conversion/ArithAndMathToAPFloat/Utils.h b/mlir/lib/Conversion/ArithAndMathToAPFloat/Utils.h
index d38d3b4c93945..4be95590e6b98 100644
--- a/mlir/lib/Conversion/ArithAndMathToAPFloat/Utils.h
+++ b/mlir/lib/Conversion/ArithAndMathToAPFloat/Utils.h
@@ -11,6 +11,8 @@
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/IR/PatternMatch.h"
+#include "mlir/Pass/Pass.h"
+#include "llvm/ADT/SmallVector.h"
namespace mlir {
class Value;
@@ -70,8 +72,11 @@ Value forEachScalarValue(mlir::RewriterBase &rewriter, Location loc,
/// Check preconditions for the conversion:
/// 1. All operands / results must be integers or floats (or vectors thereof).
/// 2. The bitwidth of the operands / results must be <= 64.
-LogicalResult checkPreconditions(RewriterBase &rewriter, Operation *op);
+/// 3. If `sourceTypes` is non-empty, ops are additionally filtered by their
+/// operand / result (element) types against `sourceTypes`; an empty list
+/// disables this filter.
+LogicalResult checkPreconditions(RewriterBase &rewriter, Operation *op,
+ ArrayRef<Type> sourceTypes);
+/// Parse each entry of `sourceTypeStrs` as a floating-point type name. Emits an
+/// error and returns failure if any entry does not name a known floating-point
+/// type.
+FailureOr<SmallVector<Type>>
+parseSourceTypes(SmallVector<std::string> sourceTypeStrs, MLIRContext *ctx);
} // namespace mlir
#endif // MLIR_CONVERSION_ARITHANDMATHTOAPFLOAT_UTILS_H_
More information about the Mlir-commits
mailing list