[Mlir-commits] [mlir] [mlir][arith] `arith-to-apfloat`: Add vector support (PR #171024)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Sun Dec 7 02:03:00 PST 2025
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir
Author: Matthias Springer (matthias-springer)
<details>
<summary>Changes</summary>
Add support for vectorized operations such as `arith.addf ... : vector<4xf4E2M1FN>`. The computation is scalarized: the scalar operands are extracted with `vector.to_elements`, the scalar computation is performed once per element, and the results are packed back into a vector with `vector.from_elements`.
---
Patch is 35.67 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/171024.diff
4 Files Affected:
- (modified) mlir/include/mlir/Conversion/Passes.td (+2-1)
- (modified) mlir/lib/Conversion/ArithToAPFloat/ArithToAPFloat.cpp (+364-264)
- (modified) mlir/lib/Conversion/ArithToAPFloat/CMakeLists.txt (+1)
- (modified) mlir/test/Conversion/ArithToApfloat/arith-to-apfloat.mlir (+39)
``````````diff
diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td
index 75ab4b64b7f38..fcbaf3ccc1486 100644
--- a/mlir/include/mlir/Conversion/Passes.td
+++ b/mlir/include/mlir/Conversion/Passes.td
@@ -198,7 +198,8 @@ def ArithToAPFloatConversionPass
calls (APFloatWrappers.cpp). APFloat is a software implementation of
floating-point arithmetic operations.
}];
- let dependentDialects = ["func::FuncDialect"];
+ let dependentDialects = ["arith::ArithDialect", "func::FuncDialect",
+ "vector::VectorDialect"];
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Conversion/ArithToAPFloat/ArithToAPFloat.cpp b/mlir/lib/Conversion/ArithToAPFloat/ArithToAPFloat.cpp
index 4776ba0f49b94..e18316eae486b 100644
--- a/mlir/lib/Conversion/ArithToAPFloat/ArithToAPFloat.cpp
+++ b/mlir/lib/Conversion/ArithToAPFloat/ArithToAPFloat.cpp
@@ -12,6 +12,7 @@
#include "mlir/Dialect/Arith/Transforms/Passes.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Func/Utils/Utils.h"
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/Verifier.h"
#include "mlir/Transforms/WalkPatternRewriteDriver.h"
@@ -90,6 +91,75 @@ static Value getSemanticsValue(OpBuilder &b, Location loc, FloatType floatTy) {
b.getIntegerAttr(b.getI32Type(), sem));
}
+/// Given two operands of vector type and vector result type (with the same
+/// shape), call the given function for each pair of scalar operands and
+/// package the result into a vector. If the given operands and result type are
+/// not vectors, call the function directly. The second operand is optional.
+template <typename Fn, typename... Values>
+static Value forEachScalarValue(RewriterBase &rewriter, Location loc,
+ Value operand1, Value operand2, Type resultType,
+ Fn fn) {
+ auto vecTy1 = dyn_cast<VectorType>(operand1.getType());
+ if (operand2) {
+ // Sanity check: Operand types must match.
+ assert(vecTy1 == dyn_cast<VectorType>(operand2.getType()) &&
+ "expected same vector types");
+ }
+ if (!vecTy1) {
+ // Not a vector. Call the function directly.
+ return fn(operand1, operand2, resultType);
+ }
+
+ // Prepare scalar operands.
+ auto sclars1 = vector::ToElementsOp::create(rewriter, loc, operand1);
+ SmallVector<Value> scalars2;
+ if (!operand2) {
+ // No second operand. Create a vector of empty values.
+ scalars2.assign(vecTy1.getNumElements(), Value());
+ } else {
+ llvm::append_range(
+ scalars2,
+ vector::ToElementsOp::create(rewriter, loc, operand2)->getResults());
+ }
+
+ // Call the function for each pair of scalar operands.
+ auto resultVecType = cast<VectorType>(resultType);
+ SmallVector<Value> results;
+ for (auto [scalar1, scalar2] : llvm::zip(sclars1->getResults(), scalars2)) {
+ Value result = fn(scalar1, scalar2, resultVecType.getElementType());
+ results.push_back(result);
+ }
+
+ // Package the results into a vector.
+ return vector::FromElementsOp::create(
+ rewriter, loc,
+ vecTy1.cloneWith(/*shape=*/std::nullopt, results.front().getType()),
+ results);
+}
+
+/// Check preconditions for the conversion:
+/// 1. All operands / results must be integers or floats (or vectors thereof).
+/// 2. The bitwidth of the operands / results must be <= 64.
+static LogicalResult checkPreconditions(RewriterBase &rewriter, Operation *op) {
+ SmallVector<Value> values;
+ llvm::append_range(values, op->getOperands());
+ llvm::append_range(values, op->getResults());
+ for (Value value : values) {
+ Type type = value.getType();
+ if (auto vecTy = dyn_cast<VectorType>(type)) {
+ type = vecTy.getElementType();
+ }
+ if (!type.isIntOrFloat()) {
+ return rewriter.notifyMatchFailure(
+ op, "only integers and floats (or vectors thereof) are supported");
+ }
+ if (type.getIntOrFloatBitWidth() > 64)
+ return rewriter.notifyMatchFailure(op,
+ "bitwidth > 64 bits is not supported");
+ }
+ return success();
+}
+
/// Rewrite a binary arithmetic operation to an APFloat function call.
template <typename OpTy>
struct BinaryArithOpToAPFloatConversion final : OpRewritePattern<OpTy> {
@@ -102,9 +172,8 @@ struct BinaryArithOpToAPFloatConversion final : OpRewritePattern<OpTy> {
LogicalResult matchAndRewrite(OpTy op,
PatternRewriter &rewriter) const override {
- if (op.getType().getIntOrFloatBitWidth() > 64)
- return rewriter.notifyMatchFailure(op,
- "bitwidth > 64 bits is not supported");
+ if (failed(checkPreconditions(rewriter, op)))
+ return failure();
// Get APFloat function from runtime library.
FailureOr<FuncOp> fn =
@@ -112,31 +181,37 @@ struct BinaryArithOpToAPFloatConversion final : OpRewritePattern<OpTy> {
if (failed(fn))
return fn;
- rewriter.setInsertionPoint(op);
- // Cast operands to 64-bit integers.
+ // Scalarize and convert to APFloat runtime calls.
Location loc = op.getLoc();
- auto floatTy = cast<FloatType>(op.getType());
- auto intWType = rewriter.getIntegerType(floatTy.getWidth());
- auto int64Type = rewriter.getI64Type();
- Value lhsBits = arith::ExtUIOp::create(
- rewriter, loc, int64Type,
- arith::BitcastOp::create(rewriter, loc, intWType, op.getLhs()));
- Value rhsBits = arith::ExtUIOp::create(
- rewriter, loc, int64Type,
- arith::BitcastOp::create(rewriter, loc, intWType, op.getRhs()));
-
- // Call APFloat function.
- Value semValue = getSemanticsValue(rewriter, loc, floatTy);
- SmallVector<Value> params = {semValue, lhsBits, rhsBits};
- auto resultOp =
- func::CallOp::create(rewriter, loc, TypeRange(rewriter.getI64Type()),
- SymbolRefAttr::get(*fn), params);
-
- // Truncate result to the original width.
- Value truncatedBits = arith::TruncIOp::create(rewriter, loc, intWType,
- resultOp->getResult(0));
- rewriter.replaceOp(
- op, arith::BitcastOp::create(rewriter, loc, floatTy, truncatedBits));
+ rewriter.setInsertionPoint(op);
+ Value repl = forEachScalarValue(
+ rewriter, loc, op.getLhs(), op.getRhs(), op.getType(),
+ [&](Value lhs, Value rhs, Type resultType) {
+ // Cast operands to 64-bit integers.
+ auto floatTy = cast<FloatType>(resultType);
+ auto intWType = rewriter.getIntegerType(floatTy.getWidth());
+ auto int64Type = rewriter.getI64Type();
+ Value lhsBits = arith::ExtUIOp::create(
+ rewriter, loc, int64Type,
+ arith::BitcastOp::create(rewriter, loc, intWType, lhs));
+ Value rhsBits = arith::ExtUIOp::create(
+ rewriter, loc, int64Type,
+ arith::BitcastOp::create(rewriter, loc, intWType, rhs));
+
+ // Call APFloat function.
+ Value semValue = getSemanticsValue(rewriter, loc, floatTy);
+ SmallVector<Value> params = {semValue, lhsBits, rhsBits};
+ auto resultOp = func::CallOp::create(rewriter, loc,
+ TypeRange(rewriter.getI64Type()),
+ SymbolRefAttr::get(*fn), params);
+
+ // Truncate result to the original width.
+ Value truncatedBits = arith::TruncIOp::create(rewriter, loc, intWType,
+ resultOp->getResult(0));
+ return arith::BitcastOp::create(rewriter, loc, floatTy,
+ truncatedBits);
+ });
+ rewriter.replaceOp(op, repl);
return success();
}
@@ -152,10 +227,8 @@ struct FpToFpConversion final : OpRewritePattern<OpTy> {
LogicalResult matchAndRewrite(OpTy op,
PatternRewriter &rewriter) const override {
- if (op.getType().getIntOrFloatBitWidth() > 64 ||
- op.getOperand().getType().getIntOrFloatBitWidth() > 64)
- return rewriter.notifyMatchFailure(op,
- "bitwidth > 64 bits is not supported");
+ if (failed(checkPreconditions(rewriter, op)))
+ return failure();
// Get APFloat function from runtime library.
auto i32Type = IntegerType::get(symTable->getContext(), 32);
@@ -165,30 +238,36 @@ struct FpToFpConversion final : OpRewritePattern<OpTy> {
if (failed(fn))
return fn;
- rewriter.setInsertionPoint(op);
- // Cast operands to 64-bit integers.
+ // Scalarize and convert to APFloat runtime calls.
Location loc = op.getLoc();
- auto inFloatTy = cast<FloatType>(op.getOperand().getType());
- auto inIntWType = rewriter.getIntegerType(inFloatTy.getWidth());
- Value operandBits = arith::ExtUIOp::create(
- rewriter, loc, i64Type,
- arith::BitcastOp::create(rewriter, loc, inIntWType, op.getOperand()));
-
- // Call APFloat function.
- Value inSemValue = getSemanticsValue(rewriter, loc, inFloatTy);
- auto outFloatTy = cast<FloatType>(op.getType());
- Value outSemValue = getSemanticsValue(rewriter, loc, outFloatTy);
- std::array<Value, 3> params = {inSemValue, outSemValue, operandBits};
- auto resultOp =
- func::CallOp::create(rewriter, loc, TypeRange(rewriter.getI64Type()),
- SymbolRefAttr::get(*fn), params);
-
- // Truncate result to the original width.
- auto outIntWType = rewriter.getIntegerType(outFloatTy.getWidth());
- Value truncatedBits = arith::TruncIOp::create(rewriter, loc, outIntWType,
- resultOp->getResult(0));
- rewriter.replaceOp(
- op, arith::BitcastOp::create(rewriter, loc, outFloatTy, truncatedBits));
+ rewriter.setInsertionPoint(op);
+ Value repl = forEachScalarValue(
+ rewriter, loc, op.getOperand(), /*operand2=*/Value(), op.getType(),
+ [&](Value operand1, Value operand2, Type resultType) {
+ // Cast operands to 64-bit integers.
+ auto inFloatTy = cast<FloatType>(operand1.getType());
+ auto inIntWType = rewriter.getIntegerType(inFloatTy.getWidth());
+ Value operandBits = arith::ExtUIOp::create(
+ rewriter, loc, i64Type,
+ arith::BitcastOp::create(rewriter, loc, inIntWType, operand1));
+
+ // Call APFloat function.
+ Value inSemValue = getSemanticsValue(rewriter, loc, inFloatTy);
+ auto outFloatTy = cast<FloatType>(resultType);
+ Value outSemValue = getSemanticsValue(rewriter, loc, outFloatTy);
+ std::array<Value, 3> params = {inSemValue, outSemValue, operandBits};
+ auto resultOp = func::CallOp::create(rewriter, loc,
+ TypeRange(rewriter.getI64Type()),
+ SymbolRefAttr::get(*fn), params);
+
+ // Truncate result to the original width.
+ auto outIntWType = rewriter.getIntegerType(outFloatTy.getWidth());
+ Value truncatedBits = arith::TruncIOp::create(
+ rewriter, loc, outIntWType, resultOp->getResult(0));
+ return arith::BitcastOp::create(rewriter, loc, outFloatTy,
+ truncatedBits);
+ });
+ rewriter.replaceOp(op, repl);
return success();
}
@@ -204,10 +283,8 @@ struct FpToIntConversion final : OpRewritePattern<OpTy> {
LogicalResult matchAndRewrite(OpTy op,
PatternRewriter &rewriter) const override {
- if (op.getType().getIntOrFloatBitWidth() > 64 ||
- op.getOperand().getType().getIntOrFloatBitWidth() > 64)
- return rewriter.notifyMatchFailure(op,
- "bitwidth > 64 bits is not supported");
+ if (failed(checkPreconditions(rewriter, op)))
+ return failure();
// Get APFloat function from runtime library.
auto i1Type = IntegerType::get(symTable->getContext(), 1);
@@ -219,33 +296,39 @@ struct FpToIntConversion final : OpRewritePattern<OpTy> {
if (failed(fn))
return fn;
- rewriter.setInsertionPoint(op);
- // Cast operands to 64-bit integers.
+ // Scalarize and convert to APFloat runtime calls.
Location loc = op.getLoc();
- auto inFloatTy = cast<FloatType>(op.getOperand().getType());
- auto inIntWType = rewriter.getIntegerType(inFloatTy.getWidth());
- Value operandBits = arith::ExtUIOp::create(
- rewriter, loc, i64Type,
- arith::BitcastOp::create(rewriter, loc, inIntWType, op.getOperand()));
-
- // Call APFloat function.
- Value inSemValue = getSemanticsValue(rewriter, loc, inFloatTy);
- auto outIntTy = cast<IntegerType>(op.getType());
- Value outWidthValue = arith::ConstantOp::create(
- rewriter, loc, i32Type,
- rewriter.getIntegerAttr(i32Type, outIntTy.getWidth()));
- Value isUnsignedValue = arith::ConstantOp::create(
- rewriter, loc, i1Type, rewriter.getIntegerAttr(i1Type, isUnsigned));
- SmallVector<Value> params = {inSemValue, outWidthValue, isUnsignedValue,
- operandBits};
- auto resultOp =
- func::CallOp::create(rewriter, loc, TypeRange(rewriter.getI64Type()),
- SymbolRefAttr::get(*fn), params);
-
- // Truncate result to the original width.
- Value truncatedBits = arith::TruncIOp::create(rewriter, loc, outIntTy,
- resultOp->getResult(0));
- rewriter.replaceOp(op, truncatedBits);
+ rewriter.setInsertionPoint(op);
+ Value repl = forEachScalarValue(
+ rewriter, loc, op.getOperand(), /*operand2=*/Value(), op.getType(),
+ [&](Value operand1, Value operand2, Type resultType) {
+ // Cast operands to 64-bit integers.
+ auto inFloatTy = cast<FloatType>(operand1.getType());
+ auto inIntWType = rewriter.getIntegerType(inFloatTy.getWidth());
+ Value operandBits = arith::ExtUIOp::create(
+ rewriter, loc, i64Type,
+ arith::BitcastOp::create(rewriter, loc, inIntWType, operand1));
+
+ // Call APFloat function.
+ Value inSemValue = getSemanticsValue(rewriter, loc, inFloatTy);
+ auto outIntTy = cast<IntegerType>(resultType);
+ Value outWidthValue = arith::ConstantOp::create(
+ rewriter, loc, i32Type,
+ rewriter.getIntegerAttr(i32Type, outIntTy.getWidth()));
+ Value isUnsignedValue = arith::ConstantOp::create(
+ rewriter, loc, i1Type,
+ rewriter.getIntegerAttr(i1Type, isUnsigned));
+ SmallVector<Value> params = {inSemValue, outWidthValue,
+ isUnsignedValue, operandBits};
+ auto resultOp = func::CallOp::create(rewriter, loc,
+ TypeRange(rewriter.getI64Type()),
+ SymbolRefAttr::get(*fn), params);
+
+ // Truncate result to the original width.
+ return arith::TruncIOp::create(rewriter, loc, outIntTy,
+ resultOp->getResult(0));
+ });
+ rewriter.replaceOp(op, repl);
return success();
}
@@ -262,10 +345,8 @@ struct IntToFpConversion final : OpRewritePattern<OpTy> {
LogicalResult matchAndRewrite(OpTy op,
PatternRewriter &rewriter) const override {
- if (op.getType().getIntOrFloatBitWidth() > 64 ||
- op.getOperand().getType().getIntOrFloatBitWidth() > 64)
- return rewriter.notifyMatchFailure(op,
- "bitwidth > 64 bits is not supported");
+ if (failed(checkPreconditions(rewriter, op)))
+ return failure();
// Get APFloat function from runtime library.
auto i1Type = IntegerType::get(symTable->getContext(), 1);
@@ -277,42 +358,48 @@ struct IntToFpConversion final : OpRewritePattern<OpTy> {
if (failed(fn))
return fn;
- rewriter.setInsertionPoint(op);
- // Cast operands to 64-bit integers.
+ // Scalarize and convert to APFloat runtime calls.
Location loc = op.getLoc();
- auto inIntTy = cast<IntegerType>(op.getOperand().getType());
- Value operandBits = op.getOperand();
- if (operandBits.getType().getIntOrFloatBitWidth() < 64) {
- if (isUnsigned) {
- operandBits =
- arith::ExtUIOp::create(rewriter, loc, i64Type, operandBits);
- } else {
- operandBits =
- arith::ExtSIOp::create(rewriter, loc, i64Type, operandBits);
- }
- }
-
- // Call APFloat function.
- auto outFloatTy = cast<FloatType>(op.getType());
- Value outSemValue = getSemanticsValue(rewriter, loc, outFloatTy);
- Value inWidthValue = arith::ConstantOp::create(
- rewriter, loc, i32Type,
- rewriter.getIntegerAttr(i32Type, inIntTy.getWidth()));
- Value isUnsignedValue = arith::ConstantOp::create(
- rewriter, loc, i1Type, rewriter.getIntegerAttr(i1Type, isUnsigned));
- SmallVector<Value> params = {outSemValue, inWidthValue, isUnsignedValue,
- operandBits};
- auto resultOp =
- func::CallOp::create(rewriter, loc, TypeRange(rewriter.getI64Type()),
- SymbolRefAttr::get(*fn), params);
-
- // Truncate result to the original width.
- auto outIntWType = rewriter.getIntegerType(outFloatTy.getWidth());
- Value truncatedBits = arith::TruncIOp::create(rewriter, loc, outIntWType,
- resultOp->getResult(0));
- Value result =
- arith::BitcastOp::create(rewriter, loc, outFloatTy, truncatedBits);
- rewriter.replaceOp(op, result);
+ rewriter.setInsertionPoint(op);
+ Value repl = forEachScalarValue(
+ rewriter, loc, op.getOperand(), /*operand2=*/Value(), op.getType(),
+ [&](Value operand1, Value operand2, Type resultType) {
+ // Cast operands to 64-bit integers.
+ auto inIntTy = cast<IntegerType>(operand1.getType());
+ Value operandBits = operand1;
+ if (operandBits.getType().getIntOrFloatBitWidth() < 64) {
+ if (isUnsigned) {
+ operandBits =
+ arith::ExtUIOp::create(rewriter, loc, i64Type, operandBits);
+ } else {
+ operandBits =
+ arith::ExtSIOp::create(rewriter, loc, i64Type, operandBits);
+ }
+ }
+
+ // Call APFloat function.
+ auto outFloatTy = cast<FloatType>(resultType);
+ Value outSemValue = getSemanticsValue(rewriter, loc, outFloatTy);
+ Value inWidthValue = arith::ConstantOp::create(
+ rewriter, loc, i32Type,
+ rewriter.getIntegerAttr(i32Type, inIntTy.getWidth()));
+ Value isUnsignedValue = arith::ConstantOp::create(
+ rewriter, loc, i1Type,
+ rewriter.getIntegerAttr(i1Type, isUnsigned));
+ SmallVector<Value> params = {outSemValue, inWidthValue,
+ isUnsignedValue, operandBits};
+ auto resultOp = func::CallOp::create(rewriter, loc,
+ TypeRange(rewriter.getI64Type()),
+ SymbolRefAttr::get(*fn), params);
+
+ // Truncate result to the original width.
+ auto outIntWType = rewriter.getIntegerType(outFloatTy.getWidth());
+ Value truncatedBits = arith::TruncIOp::create(
+ rewriter, loc, outIntWType, resultOp->getResult(0));
+ return arith::BitcastOp::create(rewriter, loc, outFloatTy,
+ truncatedBits);
+ });
+ rewriter.replaceOp(op, repl);
return success();
}
@@ -327,9 +414,8 @@ struct CmpFOpToAPFloatConversion final : OpRewritePattern<arith::CmpFOp> {
LogicalResult matchAndRewrite(arith::CmpFOp op,
PatternRewriter &rewriter) const override {
- if (op.getLhs().getType().getIntOrFloatBitWidth() > 64)
- return rewriter.notifyMatchFailure(op,
- ...
[truncated]
``````````
</details>
https://github.com/llvm/llvm-project/pull/171024
More information about the Mlir-commits mailing list