[Mlir-commits] [mlir] [mlir][math] Add vector support for math-to-apfloat (PR #172715)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Fri Jan 9 13:44:29 PST 2026
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir
Author: Maksim Levental (makslevental)
<details>
<summary>Changes</summary>
This PR adds vector type support to `math-to-apfloat`. It also adds a `source-types` pass option (matching the convention and semantics of the pass option of the same name established by `-arith-emulate-unsupported-floats`) to both `arith-to-apfloat` and `math-to-apfloat`, which restricts the conversion to the listed types. Note that by default (i.e., when `source-types` is empty) all floating-point types are converted (empty list -> convert all).
TODO: add lit tests
---
Patch is 34.25 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/172715.diff
7 Files Affected:
- (modified) mlir/include/mlir/Conversion/ArithAndMathToAPFloat/ArithToAPFloat.h (+1)
- (modified) mlir/include/mlir/Conversion/ArithAndMathToAPFloat/MathToAPFloat.h (+1)
- (modified) mlir/include/mlir/Conversion/Passes.td (+8)
- (modified) mlir/lib/Conversion/ArithAndMathToAPFloat/ArithToAPFloat.cpp (+54-97)
- (modified) mlir/lib/Conversion/ArithAndMathToAPFloat/MathToAPFloat.cpp (+123-82)
- (modified) mlir/lib/Conversion/ArithAndMathToAPFloat/Utils.cpp (+45-2)
- (modified) mlir/lib/Conversion/ArithAndMathToAPFloat/Utils.h (+61)
``````````diff
diff --git a/mlir/include/mlir/Conversion/ArithAndMathToAPFloat/ArithToAPFloat.h b/mlir/include/mlir/Conversion/ArithAndMathToAPFloat/ArithToAPFloat.h
index 6702aca045ba4..2dacc2e11b049 100644
--- a/mlir/include/mlir/Conversion/ArithAndMathToAPFloat/ArithToAPFloat.h
+++ b/mlir/include/mlir/Conversion/ArithAndMathToAPFloat/ArithToAPFloat.h
@@ -9,6 +9,7 @@
#ifndef MLIR_CONVERSION_ARITHANDMATHTOAPFLOAT_ARITHTOAPFLOAT_H
#define MLIR_CONVERSION_ARITHANDMATHTOAPFLOAT_ARITHTOAPFLOAT_H
+#include "llvm/ADT/SmallVector.h"
#include <memory>
namespace mlir {
diff --git a/mlir/include/mlir/Conversion/ArithAndMathToAPFloat/MathToAPFloat.h b/mlir/include/mlir/Conversion/ArithAndMathToAPFloat/MathToAPFloat.h
index 6cb44c89ecebb..06548c250a27b 100644
--- a/mlir/include/mlir/Conversion/ArithAndMathToAPFloat/MathToAPFloat.h
+++ b/mlir/include/mlir/Conversion/ArithAndMathToAPFloat/MathToAPFloat.h
@@ -9,6 +9,7 @@
#ifndef MLIR_CONVERSION_ARITHANDMATHTOAPFLOAT_MATHTOAPFLOAT_H
#define MLIR_CONVERSION_ARITHANDMATHTOAPFLOAT_MATHTOAPFLOAT_H
+#include "llvm/ADT/SmallVector.h"
#include <memory>
namespace mlir {
diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td
index 7f24e58671aab..fb2860bee6d43 100644
--- a/mlir/include/mlir/Conversion/Passes.td
+++ b/mlir/include/mlir/Conversion/Passes.td
@@ -198,6 +198,10 @@ def ArithToAPFloatConversionPass
calls (APFloatWrappers.cpp). APFloat is a software implementation of
floating-point arithmetic operations.
}];
+ let options = [
+ ListOption<"sourceTypeStrs", "source-types", "std::string",
+ "MLIR types without arithmetic support on a given target">,
+ ];
let dependentDialects = ["arith::ArithDialect", "func::FuncDialect",
"vector::VectorDialect"];
}
@@ -787,6 +791,10 @@ def MathToAPFloatConversionPass
calls (APFloatWrappers.cpp). APFloat is a software implementation of
floating-point mathmetic operations.
}];
+ let options = [
+ ListOption<"sourceTypeStrs", "source-types", "std::string",
+ "MLIR types without arithmetic support on a given target">,
+ ];
let dependentDialects = ["math::MathDialect", "func::FuncDialect"];
}
diff --git a/mlir/lib/Conversion/ArithAndMathToAPFloat/ArithToAPFloat.cpp b/mlir/lib/Conversion/ArithAndMathToAPFloat/ArithToAPFloat.cpp
index 813a854f2fc97..52eb32de6586b 100644
--- a/mlir/lib/Conversion/ArithAndMathToAPFloat/ArithToAPFloat.cpp
+++ b/mlir/lib/Conversion/ArithAndMathToAPFloat/ArithToAPFloat.cpp
@@ -46,86 +46,20 @@ lookupOrCreateBinaryFn(OpBuilder &b, SymbolOpInterface symTable, StringRef name,
{i32Type, i64Type, i64Type}, symbolTables);
}
-/// Given two operands of vector type and vector result type (with the same
-/// shape), call the given function for each pair of scalar operands and
-/// package the result into a vector. If the given operands and result type are
-/// not vectors, call the function directly. The second operand is optional.
-template <typename Fn, typename... Values>
-static Value forEachScalarValue(RewriterBase &rewriter, Location loc,
- Value operand1, Value operand2, Type resultType,
- Fn fn) {
- auto vecTy1 = dyn_cast<VectorType>(operand1.getType());
- if (operand2) {
- // Sanity check: Operand types must match.
- assert(vecTy1 == dyn_cast<VectorType>(operand2.getType()) &&
- "expected same vector types");
- }
- if (!vecTy1) {
- // Not a vector. Call the function directly.
- return fn(operand1, operand2, resultType);
- }
-
- // Prepare scalar operands.
- ResultRange sclars1 =
- vector::ToElementsOp::create(rewriter, loc, operand1)->getResults();
- SmallVector<Value> scalars2;
- if (!operand2) {
- // No second operand. Create a vector of empty values.
- scalars2.assign(vecTy1.getNumElements(), Value());
- } else {
- llvm::append_range(
- scalars2,
- vector::ToElementsOp::create(rewriter, loc, operand2)->getResults());
- }
-
- // Call the function for each pair of scalar operands.
- auto resultVecType = cast<VectorType>(resultType);
- SmallVector<Value> results;
- for (auto [scalar1, scalar2] : llvm::zip_equal(sclars1, scalars2)) {
- Value result = fn(scalar1, scalar2, resultVecType.getElementType());
- results.push_back(result);
- }
-
- // Package the results into a vector.
- return vector::FromElementsOp::create(
- rewriter, loc,
- vecTy1.cloneWith(/*shape=*/std::nullopt, results.front().getType()),
- results);
-}
-
-/// Check preconditions for the conversion:
-/// 1. All operands / results must be integers or floats (or vectors thereof).
-/// 2. The bitwidth of the operands / results must be <= 64.
-static LogicalResult checkPreconditions(RewriterBase &rewriter, Operation *op) {
- for (Value value : llvm::concat<Value>(op->getOperands(), op->getResults())) {
- Type type = value.getType();
- if (auto vecTy = dyn_cast<VectorType>(type)) {
- type = vecTy.getElementType();
- }
- if (!type.isIntOrFloat()) {
- return rewriter.notifyMatchFailure(
- op, "only integers and floats (or vectors thereof) are supported");
- }
- if (type.getIntOrFloatBitWidth() > 64)
- return rewriter.notifyMatchFailure(op,
- "bitwidth > 64 bits is not supported");
- }
- return success();
-}
-
/// Rewrite a binary arithmetic operation to an APFloat function call.
template <typename OpTy>
struct BinaryArithOpToAPFloatConversion final : OpRewritePattern<OpTy> {
BinaryArithOpToAPFloatConversion(MLIRContext *context,
const char *APFloatName,
SymbolOpInterface symTable,
+ ArrayRef<Type> sourceTypes,
PatternBenefit benefit = 1)
: OpRewritePattern<OpTy>(context, benefit), symTable(symTable),
- APFloatName(APFloatName) {};
+ APFloatName(APFloatName), sourceTypes(sourceTypes) {};
LogicalResult matchAndRewrite(OpTy op,
PatternRewriter &rewriter) const override {
- if (failed(checkPreconditions(rewriter, op)))
+ if (failed(checkPreconditions(rewriter, op, sourceTypes)))
return failure();
// Get APFloat function from runtime library.
@@ -170,17 +104,19 @@ struct BinaryArithOpToAPFloatConversion final : OpRewritePattern<OpTy> {
SymbolOpInterface symTable;
const char *APFloatName;
+ ArrayRef<Type> sourceTypes;
};
template <typename OpTy>
struct FpToFpConversion final : OpRewritePattern<OpTy> {
FpToFpConversion(MLIRContext *context, SymbolOpInterface symTable,
- PatternBenefit benefit = 1)
- : OpRewritePattern<OpTy>(context, benefit), symTable(symTable) {}
+ ArrayRef<Type> sourceTypes, PatternBenefit benefit = 1)
+ : OpRewritePattern<OpTy>(context, benefit), symTable(symTable),
+ sourceTypes(sourceTypes) {}
LogicalResult matchAndRewrite(OpTy op,
PatternRewriter &rewriter) const override {
- if (failed(checkPreconditions(rewriter, op)))
+ if (failed(checkPreconditions(rewriter, op, sourceTypes)))
return failure();
// Get APFloat function from runtime library.
@@ -227,18 +163,20 @@ struct FpToFpConversion final : OpRewritePattern<OpTy> {
}
SymbolOpInterface symTable;
+ ArrayRef<Type> sourceTypes;
};
template <typename OpTy>
struct FpToIntConversion final : OpRewritePattern<OpTy> {
FpToIntConversion(MLIRContext *context, SymbolOpInterface symTable,
- bool isUnsigned, PatternBenefit benefit = 1)
+ bool isUnsigned, ArrayRef<Type> sourceTypes,
+ PatternBenefit benefit = 1)
: OpRewritePattern<OpTy>(context, benefit), symTable(symTable),
- isUnsigned(isUnsigned) {}
+ isUnsigned(isUnsigned), sourceTypes(sourceTypes) {}
LogicalResult matchAndRewrite(OpTy op,
PatternRewriter &rewriter) const override {
- if (failed(checkPreconditions(rewriter, op)))
+ if (failed(checkPreconditions(rewriter, op, sourceTypes)))
return failure();
// Get APFloat function from runtime library.
@@ -289,18 +227,20 @@ struct FpToIntConversion final : OpRewritePattern<OpTy> {
SymbolOpInterface symTable;
bool isUnsigned;
+ ArrayRef<Type> sourceTypes;
};
template <typename OpTy>
struct IntToFpConversion final : OpRewritePattern<OpTy> {
IntToFpConversion(MLIRContext *context, SymbolOpInterface symTable,
- bool isUnsigned, PatternBenefit benefit = 1)
+ bool isUnsigned, ArrayRef<Type> sourceTypes,
+ PatternBenefit benefit = 1)
: OpRewritePattern<OpTy>(context, benefit), symTable(symTable),
- isUnsigned(isUnsigned) {}
+ isUnsigned(isUnsigned), sourceTypes(sourceTypes) {}
LogicalResult matchAndRewrite(OpTy op,
PatternRewriter &rewriter) const override {
- if (failed(checkPreconditions(rewriter, op)))
+ if (failed(checkPreconditions(rewriter, op, sourceTypes)))
return failure();
// Get APFloat function from runtime library.
@@ -361,16 +301,19 @@ struct IntToFpConversion final : OpRewritePattern<OpTy> {
SymbolOpInterface symTable;
bool isUnsigned;
+ ArrayRef<Type> sourceTypes;
};
struct CmpFOpToAPFloatConversion final : OpRewritePattern<arith::CmpFOp> {
CmpFOpToAPFloatConversion(MLIRContext *context, SymbolOpInterface symTable,
+ ArrayRef<Type> sourceTypes,
PatternBenefit benefit = 1)
- : OpRewritePattern<arith::CmpFOp>(context, benefit), symTable(symTable) {}
+ : OpRewritePattern<arith::CmpFOp>(context, benefit), symTable(symTable),
+ sourceTypes(sourceTypes) {}
LogicalResult matchAndRewrite(arith::CmpFOp op,
PatternRewriter &rewriter) const override {
- if (failed(checkPreconditions(rewriter, op)))
+ if (failed(checkPreconditions(rewriter, op, sourceTypes)))
return failure();
// Get APFloat function from runtime library.
@@ -512,16 +455,19 @@ struct CmpFOpToAPFloatConversion final : OpRewritePattern<arith::CmpFOp> {
}
SymbolOpInterface symTable;
+ ArrayRef<Type> sourceTypes;
};
struct NegFOpToAPFloatConversion final : OpRewritePattern<arith::NegFOp> {
NegFOpToAPFloatConversion(MLIRContext *context, SymbolOpInterface symTable,
+ ArrayRef<Type> sourceTypes,
PatternBenefit benefit = 1)
- : OpRewritePattern<arith::NegFOp>(context, benefit), symTable(symTable) {}
+ : OpRewritePattern<arith::NegFOp>(context, benefit), symTable(symTable),
+ sourceTypes(sourceTypes) {}
LogicalResult matchAndRewrite(arith::NegFOp op,
PatternRewriter &rewriter) const override {
- if (failed(checkPreconditions(rewriter, op)))
+ if (failed(checkPreconditions(rewriter, op, sourceTypes)))
return failure();
// Get APFloat function from runtime library.
@@ -564,6 +510,7 @@ struct NegFOpToAPFloatConversion final : OpRewritePattern<arith::NegFOp> {
}
SymbolOpInterface symTable;
+ ArrayRef<Type> sourceTypes;
};
namespace {
@@ -577,36 +524,46 @@ struct ArithToAPFloatConversionPass final
void ArithToAPFloatConversionPass::runOnOperation() {
MLIRContext *context = &getContext();
RewritePatternSet patterns(context);
- patterns.add<BinaryArithOpToAPFloatConversion<arith::AddFOp>>(context, "add",
- getOperation());
+
+ FailureOr<SmallVector<Type>> sourceTypes =
+ parseSourceTypes(llvm::to_vector(sourceTypeStrs), context);
+ if (failed(sourceTypes))
+ return signalPassFailure();
+
+ patterns.add<BinaryArithOpToAPFloatConversion<arith::AddFOp>>(
+ context, "add", getOperation(), *sourceTypes);
patterns.add<BinaryArithOpToAPFloatConversion<arith::SubFOp>>(
- context, "subtract", getOperation());
+ context, "subtract", getOperation(), *sourceTypes);
patterns.add<BinaryArithOpToAPFloatConversion<arith::MulFOp>>(
- context, "multiply", getOperation());
+ context, "multiply", getOperation(), *sourceTypes);
patterns.add<BinaryArithOpToAPFloatConversion<arith::DivFOp>>(
- context, "divide", getOperation());
+ context, "divide", getOperation(), *sourceTypes);
patterns.add<BinaryArithOpToAPFloatConversion<arith::RemFOp>>(
- context, "remainder", getOperation());
+ context, "remainder", getOperation(), *sourceTypes);
patterns.add<BinaryArithOpToAPFloatConversion<arith::MinNumFOp>>(
- context, "minnum", getOperation());
+ context, "minnum", getOperation(), *sourceTypes);
patterns.add<BinaryArithOpToAPFloatConversion<arith::MaxNumFOp>>(
- context, "maxnum", getOperation());
+ context, "maxnum", getOperation(), *sourceTypes);
patterns.add<BinaryArithOpToAPFloatConversion<arith::MinimumFOp>>(
- context, "minimum", getOperation());
+ context, "minimum", getOperation(), *sourceTypes);
patterns.add<BinaryArithOpToAPFloatConversion<arith::MaximumFOp>>(
- context, "maximum", getOperation());
+ context, "maximum", getOperation(), *sourceTypes);
patterns
.add<FpToFpConversion<arith::ExtFOp>, FpToFpConversion<arith::TruncFOp>,
CmpFOpToAPFloatConversion, NegFOpToAPFloatConversion>(
- context, getOperation());
+ context, getOperation(), *sourceTypes);
patterns.add<FpToIntConversion<arith::FPToSIOp>>(context, getOperation(),
- /*isUnsigned=*/false);
+ /*isUnsigned=*/false,
+ *sourceTypes);
patterns.add<FpToIntConversion<arith::FPToUIOp>>(context, getOperation(),
- /*isUnsigned=*/true);
+ /*isUnsigned=*/true,
+ *sourceTypes);
patterns.add<IntToFpConversion<arith::SIToFPOp>>(context, getOperation(),
- /*isUnsigned=*/false);
+ /*isUnsigned=*/false,
+ *sourceTypes);
patterns.add<IntToFpConversion<arith::UIToFPOp>>(context, getOperation(),
- /*isUnsigned=*/true);
+ /*isUnsigned=*/true,
+ *sourceTypes);
LogicalResult result = success();
ScopedDiagnosticHandler scopedHandler(context, [&result](Diagnostic &diag) {
if (diag.getSeverity() == DiagnosticSeverity::Error) {
diff --git a/mlir/lib/Conversion/ArithAndMathToAPFloat/MathToAPFloat.cpp b/mlir/lib/Conversion/ArithAndMathToAPFloat/MathToAPFloat.cpp
index 784028f5cf2eb..b5e15e5c42bed 100644
--- a/mlir/lib/Conversion/ArithAndMathToAPFloat/MathToAPFloat.cpp
+++ b/mlir/lib/Conversion/ArithAndMathToAPFloat/MathToAPFloat.cpp
@@ -28,21 +28,15 @@ using namespace mlir::func;
struct AbsFOpToAPFloatConversion final : OpRewritePattern<math::AbsFOp> {
AbsFOpToAPFloatConversion(MLIRContext *context, SymbolOpInterface symTable,
+ ArrayRef<Type> sourceTypes,
PatternBenefit benefit = 1)
- : OpRewritePattern<math::AbsFOp>(context, benefit), symTable(symTable) {}
+ : OpRewritePattern<math::AbsFOp>(context, benefit), symTable(symTable),
+ sourceTypes(sourceTypes) {}
LogicalResult matchAndRewrite(math::AbsFOp op,
PatternRewriter &rewriter) const override {
- // Cast operands to 64-bit integers.
- auto operand = op.getOperand();
- auto floatTy = dyn_cast<FloatType>(operand.getType());
- if (!floatTy)
- return rewriter.notifyMatchFailure(op,
- "only scalar FloatTypes supported");
- if (floatTy.getIntOrFloatBitWidth() > 64) {
- return rewriter.notifyMatchFailure(op,
- "bitwidth > 64 bits is not supported");
- }
+ if (failed(checkPreconditions(rewriter, op, sourceTypes)))
+ return failure();
// Get APFloat function from runtime library.
auto i32Type = IntegerType::get(symTable->getContext(), 32);
auto i64Type = IntegerType::get(symTable->getContext(), 64);
@@ -52,49 +46,50 @@ struct AbsFOpToAPFloatConversion final : OpRewritePattern<math::AbsFOp> {
return fn;
Location loc = op.getLoc();
rewriter.setInsertionPoint(op);
- auto intWType = rewriter.getIntegerType(floatTy.getWidth());
- Value operandBits = arith::ExtUIOp::create(
- rewriter, loc, i64Type,
- arith::BitcastOp::create(rewriter, loc, intWType, operand));
-
- // Call APFloat function.
- Value semValue = getAPFloatSemanticsValue(rewriter, loc, floatTy);
- SmallVector<Value> params = {semValue, operandBits};
- Value negatedBits = func::CallOp::create(rewriter, loc, TypeRange(i64Type),
- SymbolRefAttr::get(*fn), params)
- ->getResult(0);
-
- // Truncate result to the original width.
- Value truncatedBits =
- arith::TruncIOp::create(rewriter, loc, intWType, negatedBits);
- rewriter.replaceOp(
- op, arith::BitcastOp::create(rewriter, loc, floatTy, truncatedBits));
+ // Scalarize and convert to APFloat runtime calls.
+ Value repl = forEachScalarValue(
+ rewriter, loc, op.getOperand(), /*operand2=*/Value(), op.getType(),
+ [&](Value operand, Value, Type resultType) {
+ auto floatTy = cast<FloatType>(operand.getType());
+ auto intWType = rewriter.getIntegerType(floatTy.getWidth());
+ Value operandBits = arith::ExtUIOp::create(
+ rewriter, loc, i64Type,
+ arith::BitcastOp::create(rewriter, loc, intWType, operand));
+ // Call APFloat function.
+ Value semValue = getAPFloatSemanticsValue(rewriter, loc, floatTy);
+ SmallVector<Value> params = {semValue, operandBits};
+ Value negatedBits =
+ func::CallOp::create(rewriter, loc, TypeRange(i64Type),
+ SymbolRefAttr::get(*fn), params)
+ ->getResult(0);
+ // Truncate result to the original width.
+ auto truncatedBits =
+ arith::TruncIOp::create(rewriter, loc, intWType, negatedBits);
+ return arith::BitcastOp::create(rewriter, loc, floatTy,
+ truncatedBits);
+ });
+
+ rewriter.replaceOp(op, repl);
return success();
}
SymbolOpInterface symTable;
+ ArrayRef<Type> sourceTypes;
};
template <typename OpTy>
struct IsOpToAPFloatConversion final : OpRewritePattern<OpTy> {
IsOpToAPFloatConversion(MLIRContext *context, const char *APFloatName,
SymbolOpInterface symTable,
+ ArrayRef<Type> sourceTypes,
PatternBenefit benefit = 1)
: OpRewritePattern<OpTy>(context, benefit), symTable(symTable),
- APFloatName(APFloatName) {};
+ APFloatName(APFloatName), sourceTypes(sourceTypes) {};
LogicalResult matchAndRewrite(OpTy op,
PatternRewriter &rewriter) const override {
- // Cast operands to 64-bit integers.
- auto operand = op.getOperand();
- auto floatTy = dyn_cast<FloatType>(operand.getType());
- if (!floatTy)
- return rewriter.notifyMatchFailure(op,
- "only scalar FloatTypes supported");
- if (floatTy.getIntOrFloatBitWidth() > 64) {
- return rewriter.notifyMatchFailure(op,
- "bitwidth > 64 bits is not supported");
- }
+ if (failed(checkPreconditions(rewriter, op, sourceTypes)))
+ return failure();
// Get APFloat function from runtime library.
auto i1 = IntegerType::get(symTable->getContext(), 1);
auto i32Type = IntegerType::get(symTable->getContext(), 32);
@@ -...
[truncated]
``````````
</details>
https://github.com/llvm/llvm-project/pull/172715
More information about the Mlir-commits mailing list