[Mlir-commits] [mlir] [mlir][linalg] convert arith ops to destination-passing-style. (PR #157854)
Javed Absar
llvmlistbot at llvm.org
Sun Sep 14 04:27:59 PDT 2025
https://github.com/javedabsar1 updated https://github.com/llvm/llvm-project/pull/157854
>From 846f3f7802de4c691457998fd150e9849f6eef50 Mon Sep 17 00:00:00 2001
From: Javed Absar <javed.absar at gmail.com>
Date: Sun, 31 Aug 2025 18:52:07 -0400
Subject: [PATCH 1/4] [mlir][linalg] convert arith ops to
destination-passing-style.
Converts arith ops that operate on tensors but are not in destination
passing style (DPS) to an equivalent linalg.generic, which is in DPS.
This new pass `linalg-convert-to-dps` has general use, but is specifically
useful for lower-quant-ops, which operates on tensors: ops like qcast
generate arith ops on tensors, which cannot bufferize without DPS.
e.g.
`%0 = arith.uitofp %arg0 : tensor<?xi32> to tensor<?xf32>`
gets rewritten as:
%c0 = arith.constant 0 : index
%dim = tensor.dim %arg0, %c0 : tensor<?xi32>
%0 = tensor.empty(%dim) : tensor<?xf32>
%1 = linalg.generic
{indexing_maps = [#map, #map], iterator_types = ["parallel"]}
ins(%arg0 : tensor<?xi32>) outs(%0 : tensor<?xf32>) {
^bb0(%in: i32, %out: f32):
%2 = arith.uitofp %in : i32 to f32
linalg.yield %2 : f32
} -> tensor<?xf32>
---
mlir/include/mlir/Dialect/Linalg/Passes.td | 26 +++++
.../Dialect/Linalg/Transforms/Transforms.h | 17 +++
.../TransformOps/LinalgTransformOps.cpp | 10 +-
.../Transforms/ConvertToDestinationStyle.cpp | 105 +++++++++++++++++-
...-rewrite-in-destination-passing-style.mlir | 61 ++++++++++
.../Dialect/Quant/lower-quant-ops-to-dps.mlir | 26 +++++
6 files changed, 240 insertions(+), 5 deletions(-)
create mode 100644 mlir/test/Dialect/Quant/lower-quant-ops-to-dps.mlir
diff --git a/mlir/include/mlir/Dialect/Linalg/Passes.td b/mlir/include/mlir/Dialect/Linalg/Passes.td
index 44da2965e6892..365356d3c7d6b 100644
--- a/mlir/include/mlir/Dialect/Linalg/Passes.td
+++ b/mlir/include/mlir/Dialect/Linalg/Passes.td
@@ -171,6 +171,32 @@ def LinalgFoldIntoElementwisePass : Pass<"linalg-fold-into-elementwise"> {
let dependentDialects = ["linalg::LinalgDialect"];
}
+def LinalgConvertToDPSPass : Pass<"linalg-convert-to-dps"> {
+ let summary = "Convert ops to destination-passing-style";
+ let description = [{
+ Converts ops that operate on tensors but are not in
+ destination passing style (DPS) to equivalent linalg
+ generic which is in DPS. e.g.
+ ```mlir
+ %0 = arith.uitofp %arg0 : tensor<?xi32> to tensor<?xf32>
+ ```
+ gets rewritten as:
+ ```mlir
+ %c0 = arith.constant 0 : index
+ %dim = tensor.dim %arg0, %c0 : tensor<?xi32>
+ %0 = tensor.empty(%dim) : tensor<?xf32>
+ %1 = linalg.generic
+ {indexing_maps = [#map, #map], iterator_types = ["parallel"]}
+ ins(%arg0 : tensor<?xi32>) outs(%0 : tensor<?xf32>) {
+ ^bb0(%in: i32, %out: f32):
+ %2 = arith.uitofp %in : i32 to f32
+ linalg.yield %2 : f32
+ } -> tensor<?xf32>
+ ```
+ }];
+ let dependentDialects = ["linalg::LinalgDialect"];
+}
+
def LinalgDetensorizePass : InterfacePass<"linalg-detensorize", "FunctionOpInterface"> {
let summary = "Detensorize linalg ops";
let dependentDialects = [];
diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
index 0cfc8821c0add..c0d492cf69492 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -1377,6 +1377,23 @@ rewriteInDestinationPassingStyle(RewriterBase &rewriter,
FailureOr<Operation *> rewriteInDestinationPassingStyle(RewriterBase &rewriter,
tensor::PadOp padOp);
+FailureOr<Operation *> rewriteInDestinationPassingStyle(RewriterBase &rewriter,
+ arith::UIToFPOp op);
+FailureOr<Operation *> rewriteInDestinationPassingStyle(RewriterBase &rewriter,
+ arith::SIToFPOp op);
+FailureOr<Operation *> rewriteInDestinationPassingStyle(RewriterBase &rewriter,
+ arith::FPToUIOp op);
+FailureOr<Operation *> rewriteInDestinationPassingStyle(RewriterBase &rewriter,
+ arith::FPToSIOp op);
+
+FailureOr<Operation *> rewriteInDestinationPassingStyle(RewriterBase &rewriter,
+ arith::AddIOp op);
+
+FailureOr<Operation *> rewriteInDestinationPassingStyle(RewriterBase &rewriter,
+ arith::AddFOp op);
+FailureOr<Operation *> rewriteInDestinationPassingStyle(RewriterBase &rewriter,
+ arith::DivFOp op);
+
/// Convert linalg.conv_2d_nhwc_hwcf into linalg.generic (for img2col packing)
/// and linalg.matmul.
///
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index f0c1f4485b054..b150dc084aaa7 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -58,10 +58,10 @@ using namespace mlir::transform;
/// pattern failed to apply. Extra arguments are forwarded to the pattern
/// constructor.
template <typename PatternTy, typename... Args>
-static FailureOr<LinalgOp> tryApply(Operation *operation, Args &&...args) {
+static FailureOr<LinalgOp> tryApply(Operation *operation, Args &&... args) {
// Check if the given operation has the type expected by the pattern.
- using OpTy = typename llvm::function_traits<
- decltype(&PatternTy::returningMatchAndRewrite)>::template arg_t<0>;
+ using OpTy = typename llvm::function_traits<decltype(
+ &PatternTy::returningMatchAndRewrite)>::template arg_t<0>;
auto op = dyn_cast<OpTy>(operation);
if (!op)
return failure();
@@ -2611,7 +2611,9 @@ transform::RewriteInDestinationPassingStyleOp::applyToOne(
rewriter.setInsertionPoint(target);
FailureOr<Operation *> maybeResult =
TypeSwitch<Operation *, FailureOr<Operation *>>(target)
- .Case<tensor::FromElementsOp, tensor::GenerateOp, tensor::PadOp>(
+ .Case<tensor::FromElementsOp, tensor::GenerateOp, tensor::PadOp,
+ arith::UIToFPOp, arith::SIToFPOp, arith::FPToUIOp,
+ arith::FPToSIOp, arith::AddIOp, arith::AddFOp, arith::DivFOp>(
[&rewriter](auto op) {
return rewriteInDestinationPassingStyle(rewriter, op);
});
diff --git a/mlir/lib/Dialect/Linalg/Transforms/ConvertToDestinationStyle.cpp b/mlir/lib/Dialect/Linalg/Transforms/ConvertToDestinationStyle.cpp
index 76ddee4f0e9cf..79f44ff87b3f6 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/ConvertToDestinationStyle.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/ConvertToDestinationStyle.cpp
@@ -17,13 +17,22 @@
#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
+#include "mlir/Dialect/Linalg/Passes.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Utils/StaticValueUtils.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/PatternMatch.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "llvm/ADT/STLExtras.h"
+namespace mlir {
+#define GEN_PASS_DEF_LINALGCONVERTTODPSPASS
+#include "mlir/Dialect/Linalg/Passes.h.inc"
+} // namespace mlir
+
+#define DEBUG_TYPE "linalg-convert-to-dps"
+
using namespace mlir;
using namespace mlir::tensor;
@@ -96,7 +105,7 @@ static Operation *movePaddingToFillOrGenericOp(RewriterBase &rewriter,
OpBuilder::InsertionGuard g(rewriter);
RankedTensorType resultType = padOp.getResultType();
- // Examine the yielded value to decide if a linalg.generic is neede or a
+ // Examine the yielded value to decide if a linalg.generic is needed or a
// linalg.fill is sufficient.
Value yieldedValue =
cast<tensor::YieldOp>(padOp.getBody()->getTerminator()).getValue();
@@ -603,6 +612,56 @@ Value linalg::bufferizeToAllocation(
}
namespace {
+template <typename OpTy>
+FailureOr<Operation *>
+rewriteArithInDestinationPassingStyle(RewriterBase &rewriter, OpTy op) {
+ // reject ops such as `arith.constant` and `arith.select`.
+ auto numOperands = op->getNumOperands();
+ if (numOperands == 0 || numOperands > 2)
+ return failure();
+
+ // destination passing style rewrite is only for ops on tensor types.
+ Type resultType = op->getResult(0).getType();
+ auto tensorType = dyn_cast<RankedTensorType>(resultType);
+ if (!tensorType)
+ return failure();
+
+ auto loc = op.getLoc();
+ OpBuilder::InsertionGuard g(rewriter);
+ auto dynSizes = reifyOrComputeDynamicSizes(rewriter, op->getOperand(0));
+
+ // Create tensor.empty.
+ Value empty = tensor::EmptyOp::create(rewriter, loc, resultType, dynSizes);
+
+ // Create linalg.generic
+ auto rank = tensorType.getRank();
+ SmallVector<AffineMap> indexingMaps(numOperands + 1,
+ rewriter.getMultiDimIdentityMap(rank));
+ SmallVector<utils::IteratorType> iteratorTypes(rank,
+ utils::IteratorType::parallel);
+ auto genericOp = linalg::GenericOp::create(
+ rewriter, loc, tensorType,
+ op->getOperands(), // inputs
+ ValueRange{empty}, // outputs
+ indexingMaps, iteratorTypes,
+ [&](OpBuilder &builder, Location loc, ValueRange args) {
+ Value res;
+ if (args.size() == 2) {
+ res =
+ builder.create<OpTy>(loc, args[1].getType(), ValueRange{args[0]})
+ .getResult();
+ } else if (args.size() == 3) {
+ res = builder.create<OpTy>(loc, args[2].getType(),
+ ValueRange{args[0], args[1]});
+ } else
+ llvm_unreachable("did not expect ops other than nary and binary");
+ linalg::YieldOp::create(builder, loc, res);
+ });
+
+ rewriter.replaceAllUsesWith(op, genericOp.getResult(0));
+ rewriter.eraseOp(op);
+ return genericOp.getOperation();
+}
template <typename OpTy>
LogicalResult rewriteOpInDestinationPassingStyle(OpTy op,
@@ -612,9 +671,53 @@ LogicalResult rewriteOpInDestinationPassingStyle(OpTy op,
} // namespace
+#define STAMP_OUT_ARITH_DPS_FUNCS(OPTY) \
+ FailureOr<Operation *> linalg::rewriteInDestinationPassingStyle( \
+ RewriterBase &rewriter, OPTY op) { \
+ return rewriteArithInDestinationPassingStyle<OPTY>(rewriter, op); \
+ }
+
+STAMP_OUT_ARITH_DPS_FUNCS(arith::UIToFPOp)
+STAMP_OUT_ARITH_DPS_FUNCS(arith::SIToFPOp)
+STAMP_OUT_ARITH_DPS_FUNCS(arith::FPToUIOp)
+STAMP_OUT_ARITH_DPS_FUNCS(arith::FPToSIOp)
+
+STAMP_OUT_ARITH_DPS_FUNCS(arith::AddIOp)
+
+STAMP_OUT_ARITH_DPS_FUNCS(arith::AddFOp)
+STAMP_OUT_ARITH_DPS_FUNCS(arith::DivFOp)
+
void linalg::populateConvertToDestinationStylePatterns(
RewritePatternSet &patterns) {
patterns.add(rewriteOpInDestinationPassingStyle<tensor::FromElementsOp>);
patterns.add(rewriteOpInDestinationPassingStyle<tensor::GenerateOp>);
patterns.add(rewriteOpInDestinationPassingStyle<tensor::PadOp>);
+
+ patterns.add(rewriteOpInDestinationPassingStyle<arith::UIToFPOp>);
+ patterns.add(rewriteOpInDestinationPassingStyle<arith::SIToFPOp>);
+ patterns.add(rewriteOpInDestinationPassingStyle<arith::FPToUIOp>);
+ patterns.add(rewriteOpInDestinationPassingStyle<arith::FPToSIOp>);
+
+ patterns.add(rewriteOpInDestinationPassingStyle<arith::AddIOp>);
+
+ patterns.add(rewriteOpInDestinationPassingStyle<arith::AddFOp>);
+ patterns.add(rewriteOpInDestinationPassingStyle<arith::DivFOp>);
}
+
+namespace {
+struct LinalgConvertToDPSPass
+ : public impl::LinalgConvertToDPSPassBase<LinalgConvertToDPSPass> {
+ using impl::LinalgConvertToDPSPassBase<
+ LinalgConvertToDPSPass>::LinalgConvertToDPSPassBase;
+
+ void runOnOperation() override;
+};
+
+void LinalgConvertToDPSPass::runOnOperation() {
+
+ RewritePatternSet patterns(&getContext());
+ linalg::populateConvertToDestinationStylePatterns(patterns);
+ if (failed(applyPatternsGreedily(getOperation(), std::move(patterns))))
+ signalPassFailure();
+}
+} // namespace
diff --git a/mlir/test/Dialect/Linalg/transform-op-rewrite-in-destination-passing-style.mlir b/mlir/test/Dialect/Linalg/transform-op-rewrite-in-destination-passing-style.mlir
index 63c9f1f27517b..a1df34c6555f2 100644
--- a/mlir/test/Dialect/Linalg/transform-op-rewrite-in-destination-passing-style.mlir
+++ b/mlir/test/Dialect/Linalg/transform-op-rewrite-in-destination-passing-style.mlir
@@ -252,3 +252,64 @@ module attributes {transform.with_named_sequence} {
transform.yield
}
}
+
+// -----
+
+// CHECK: #[[$map:.*]] = affine_map<(d0) -> (d0)>
+// CHECK-LABEL: func @arith_unary_op(
+// CHECK-SAME: %[[X:.+]]: tensor<64xi32>) -> tensor<64xf32> {
+// CHECK: %[[EMPTY:.+]] = tensor.empty() : tensor<64xf32>
+// CHECK: %[[GENERIC:.+]] = linalg.generic
+// CHECK-SAME: {indexing_maps = [#[[$map]], #[[$map]]], iterator_types = ["parallel"]}
+// CHECK-SAME: ins(%[[X:.+]] : tensor<64xi32>) outs(%[[EMPTY]] : tensor<64xf32>) {
+// CHECK: ^bb0(%[[x:.+]]: i32, %[[Out:.+]]: f32):
+// CHECK: %[[z:.+]] = arith.uitofp %[[x]] : i32 to f32
+// CHECK: linalg.yield %[[z]] : f32
+// CHECK: return %[[GENERIC]] : tensor<64xf32>
+
+func.func @arith_unary_op(%x : tensor<64xi32>) -> tensor<64xf32> {
+ %z = arith.uitofp %x : tensor<64xi32> to tensor<64xf32>
+ return %z : tensor<64xf32>
+}
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+ %0 = transform.structured.match ops{["arith.uitofp"]} in %arg1
+ : (!transform.any_op) -> !transform.any_op
+ transform.structured.rewrite_in_destination_passing_style %0
+ : (!transform.any_op) -> !transform.any_op
+ transform.yield
+ }
+}
+
+// -----
+
+// CHECK: #[[$map:.*]] = affine_map<(d0) -> (d0)>
+// CHECK-LABEL: func @arith_binop(
+// CHECK-SAME: %[[X:.+]]: tensor<?xf32>, %[[Y:.+]]: tensor<?xf32>
+// CHECK: %[[C0:.+]] = arith.constant 0 : index
+// CHECK: %[[DIM:.+]] = tensor.dim %[[X]], %[[C0]] : tensor<?xf32>
+// CHECK: %[[EMPTY:.+]] = tensor.empty(%[[DIM]]) : tensor<?xf32>
+// CHECK: %[[GENERIC:.+]] = linalg.generic
+// CHECK-SAME: {indexing_maps = [#[[$map]], #[[$map]], #[[$map]]], iterator_types = ["parallel"]}
+// CHECK-SAME: ins(%[[X:.+]], %[[Y:.+]] : tensor<?xf32>, tensor<?xf32>) outs(%[[EMPTY]] : tensor<?xf32>) {
+// CHECK: ^bb0(%[[x:.+]]: f32, %[[y:.+]]: f32, %[[Out:.+]]: f32):
+// CHECK: %[[z:.+]] = arith.addf %[[x]], %[[y]] : f32
+// CHECK: linalg.yield %[[z]] : f32
+// CHECK: return %[[GENERIC]] : tensor<?xf32>
+
+func.func @arith_binop(%x : tensor<?xf32>, %y : tensor<?xf32>)
+ -> tensor<?xf32> {
+ %z = arith.addf %x, %y : tensor<?xf32>
+ return %z : tensor<?xf32>
+}
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+ %0 = transform.structured.match ops{["arith.addf"]} in %arg1
+ : (!transform.any_op) -> !transform.any_op
+ transform.structured.rewrite_in_destination_passing_style %0
+ : (!transform.any_op) -> !transform.any_op
+ transform.yield
+ }
+}
diff --git a/mlir/test/Dialect/Quant/lower-quant-ops-to-dps.mlir b/mlir/test/Dialect/Quant/lower-quant-ops-to-dps.mlir
new file mode 100644
index 0000000000000..0fc9f1e3ed9be
--- /dev/null
+++ b/mlir/test/Dialect/Quant/lower-quant-ops-to-dps.mlir
@@ -0,0 +1,26 @@
+// RUN: mlir-opt %s -lower-quant-ops -linalg-convert-to-dps \
+// RUN: -linalg-specialize-generic-ops -cse | FileCheck %s
+
+// CHECK-LABEL: func.func @lower_qcast_to_dps(
+// CHECK-SAME: %[[X:.+]]: tensor<10xf32>) -> tensor<10x!quant.uniform<i8:f32, 2.000000e+00:10>>
+// CHECK-DAG: %[[CST_10I:.+]] = arith.constant dense<10> : tensor<10xi8>
+// CHECK-DAG: %[[CST_2F:.+]] = arith.constant dense<2.000000e+00> : tensor<10xf32>
+// CHECK: %[[E:.+]] = tensor.empty() : tensor<10xf32>
+// CHECK: %[[DIV:.+]] = linalg.div ins(%[[X]], %[[CST_2F]] : tensor<10xf32>, tensor<10xf32>)
+// CHECK-SAME: outs(%[[E]] : tensor<10xf32>) -> tensor<10xf32>
+//
+// CHECK: %[[SITOFP:.+]] = linalg.generic
+// CHECK-SAME: ins(%[[CST_10I]] : tensor<10xi8>) outs(%[[E]] : tensor<10xf32>)
+// CHECK: %{{.*}} = arith.sitofp %{{.*}} : i8 to f32
+//
+// CHECK: %[[ADD:.+]] = linalg.add ins(%[[DIV]], %[[SITOFP]] : tensor<10xf32>, tensor<10xf32>)
+// CHECK: %{{.*}} = linalg.generic
+// CHECK-SAME: ins(%[[ADD]] : tensor<10xf32>)
+// CHECK: %{{.*}} = arith.fptosi %{{.*}} : f32 to i8
+
+
+!qalias = !quant.uniform<i8:f32, 2.0:10>
+func.func @lower_qcast_to_dps(%arg0: tensor<10xf32>) -> tensor<10x!qalias> {
+ %0 = quant.qcast %arg0 : tensor<10xf32> to tensor<10x!qalias>
+ return %0 : tensor<10x!qalias>
+}
>From 348e8aa207b6198f25a20ebceb74623584dd03b8 Mon Sep 17 00:00:00 2001
From: Javed Absar <javed.absar at gmail.com>
Date: Sat, 13 Sep 2025 06:15:30 -0400
Subject: [PATCH 2/4] address review comments
---
mlir/include/mlir/Dialect/Linalg/Passes.td | 16 ++++++-------
.../TransformOps/LinalgTransformOps.cpp | 6 ++---
.../Transforms/ConvertToDestinationStyle.cpp | 23 ++++++++++++++-----
3 files changed, 28 insertions(+), 17 deletions(-)
diff --git a/mlir/include/mlir/Dialect/Linalg/Passes.td b/mlir/include/mlir/Dialect/Linalg/Passes.td
index 365356d3c7d6b..8deb208573203 100644
--- a/mlir/include/mlir/Dialect/Linalg/Passes.td
+++ b/mlir/include/mlir/Dialect/Linalg/Passes.td
@@ -182,17 +182,17 @@ def LinalgConvertToDPSPass : Pass<"linalg-convert-to-dps"> {
```
gets rewritten as:
```mlir
- %c0 = arith.constant 0 : index
- %dim = tensor.dim %arg0, %c0 : tensor<?xi32>
- %0 = tensor.empty(%dim) : tensor<?xf32>
- %1 = linalg.generic
- {indexing_maps = [#map, #map], iterator_types = ["parallel"]}
- ins(%arg0 : tensor<?xi32>) outs(%0 : tensor<?xf32>) {
+ %c0 = arith.constant 0 : index
+ %dim = tensor.dim %arg0, %c0 : tensor<?xi32>
+ %0 = tensor.empty(%dim) : tensor<?xf32>
+ %1 = linalg.generic
+ {indexing_maps = [#map, #map], iterator_types = ["parallel"]}
+ ins(%arg0 : tensor<?xi32>) outs(%0 : tensor<?xf32>) {
^bb0(%in: i32, %out: f32):
%2 = arith.uitofp %in : i32 to f32
linalg.yield %2 : f32
- } -> tensor<?xf32>
- ```
+ } -> tensor<?xf32>
+ ```
}];
let dependentDialects = ["linalg::LinalgDialect"];
}
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index b150dc084aaa7..94531ff854593 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -58,10 +58,10 @@ using namespace mlir::transform;
/// pattern failed to apply. Extra arguments are forwarded to the pattern
/// constructor.
template <typename PatternTy, typename... Args>
-static FailureOr<LinalgOp> tryApply(Operation *operation, Args &&... args) {
+static FailureOr<LinalgOp> tryApply(Operation *operation, Args &&...args) {
// Check if the given operation has the type expected by the pattern.
- using OpTy = typename llvm::function_traits<decltype(
- &PatternTy::returningMatchAndRewrite)>::template arg_t<0>;
+ using OpTy = typename llvm::function_traits<
+ decltype(&PatternTy::returningMatchAndRewrite)>::template arg_t<0>;
auto op = dyn_cast<OpTy>(operation);
if (!op)
return failure();
diff --git a/mlir/lib/Dialect/Linalg/Transforms/ConvertToDestinationStyle.cpp b/mlir/lib/Dialect/Linalg/Transforms/ConvertToDestinationStyle.cpp
index 79f44ff87b3f6..6ee5246c7a1f7 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/ConvertToDestinationStyle.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/ConvertToDestinationStyle.cpp
@@ -31,8 +31,6 @@ namespace mlir {
#include "mlir/Dialect/Linalg/Passes.h.inc"
} // namespace mlir
-#define DEBUG_TYPE "linalg-convert-to-dps"
-
using namespace mlir;
using namespace mlir::tensor;
@@ -612,10 +610,23 @@ Value linalg::bufferizeToAllocation(
}
namespace {
+/// Rewrites an arith op operating on tensors, e.g.
+/// `%z = arith.addf %x, %y : tensor<5xf32>`
+/// into an equivalent linalg.generic in destination-passing-style.
+/// ```mlir
+/// %0 = tensor.empty() : tensor<5xf32>
+/// %1 = linalg.generic ...
+/// ins(%x, %y : tensor<5xf32>, tensor<5xf32>)
+/// outs(%0 : tensor<5xf32>) {
+/// ^bb0(%in: f32, %in_0: f32, %out: f32):
+/// %2 = arith.addf %in, %in_0 : f32
+/// linalg.yield %2 : f32
+/// } -> tensor<5xf32>
template <typename OpTy>
FailureOr<Operation *>
rewriteArithInDestinationPassingStyle(RewriterBase &rewriter, OpTy op) {
- // reject ops such as `arith.constant` and `arith.select`.
+ // Reject ops such as `arith.constant` and `arith.select`.
+ // constants don't need dps conversion and select is a `todo`.
auto numOperands = op->getNumOperands();
if (numOperands == 0 || numOperands > 2)
return failure();
@@ -630,8 +641,8 @@ rewriteArithInDestinationPassingStyle(RewriterBase &rewriter, OpTy op) {
OpBuilder::InsertionGuard g(rewriter);
auto dynSizes = reifyOrComputeDynamicSizes(rewriter, op->getOperand(0));
- // Create tensor.empty.
- Value empty = tensor::EmptyOp::create(rewriter, loc, resultType, dynSizes);
+ // Create tensor.empty for `outs` of destination-passing-style.
+ Value outs = tensor::EmptyOp::create(rewriter, loc, resultType, dynSizes);
// Create linalg.generic
auto rank = tensorType.getRank();
@@ -642,7 +653,7 @@ rewriteArithInDestinationPassingStyle(RewriterBase &rewriter, OpTy op) {
auto genericOp = linalg::GenericOp::create(
rewriter, loc, tensorType,
op->getOperands(), // inputs
- ValueRange{empty}, // outputs
+ ValueRange{outs}, // outputs
indexingMaps, iteratorTypes,
[&](OpBuilder &builder, Location loc, ValueRange args) {
Value res;
>From 22e02b353cfc840b17cdb280f768119fa75f6b9b Mon Sep 17 00:00:00 2001
From: Javed Absar <javed.absar at gmail.com>
Date: Sat, 13 Sep 2025 09:52:38 -0400
Subject: [PATCH 3/4] fix clang-format
---
.../lib/Dialect/Linalg/Transforms/ConvertToDestinationStyle.cpp | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/mlir/lib/Dialect/Linalg/Transforms/ConvertToDestinationStyle.cpp b/mlir/lib/Dialect/Linalg/Transforms/ConvertToDestinationStyle.cpp
index 6ee5246c7a1f7..3fec9b8e62cf3 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/ConvertToDestinationStyle.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/ConvertToDestinationStyle.cpp
@@ -653,7 +653,7 @@ rewriteArithInDestinationPassingStyle(RewriterBase &rewriter, OpTy op) {
auto genericOp = linalg::GenericOp::create(
rewriter, loc, tensorType,
op->getOperands(), // inputs
- ValueRange{outs}, // outputs
+ ValueRange{outs}, // outputs
indexingMaps, iteratorTypes,
[&](OpBuilder &builder, Location loc, ValueRange args) {
Value res;
>From b6c4d4d675994c4017dfee96a8d972d2893b67ec Mon Sep 17 00:00:00 2001
From: Javed Absar <javed.absar at gmail.com>
Date: Sun, 14 Sep 2025 07:27:13 -0400
Subject: [PATCH 4/4] add fastmath propagation
---
.../Transforms/ConvertToDestinationStyle.cpp | 35 ++++++++++++++++---
...-rewrite-in-destination-passing-style.mlir | 32 +++++++++++++++++
2 files changed, 62 insertions(+), 5 deletions(-)
diff --git a/mlir/lib/Dialect/Linalg/Transforms/ConvertToDestinationStyle.cpp b/mlir/lib/Dialect/Linalg/Transforms/ConvertToDestinationStyle.cpp
index 3fec9b8e62cf3..3fa921e5d3e0b 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/ConvertToDestinationStyle.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/ConvertToDestinationStyle.cpp
@@ -650,6 +650,11 @@ rewriteArithInDestinationPassingStyle(RewriterBase &rewriter, OpTy op) {
rewriter.getMultiDimIdentityMap(rank));
SmallVector<utils::IteratorType> iteratorTypes(rank,
utils::IteratorType::parallel);
+
+ // Check 'fast-math'. If present, propagate it.
+ auto fmfOpInterface =
+ llvm::dyn_cast<arith::ArithFastMathInterface>(op.getOperation());
+
auto genericOp = linalg::GenericOp::create(
rewriter, loc, tensorType,
op->getOperands(), // inputs
@@ -658,12 +663,32 @@ rewriteArithInDestinationPassingStyle(RewriterBase &rewriter, OpTy op) {
[&](OpBuilder &builder, Location loc, ValueRange args) {
Value res;
if (args.size() == 2) {
- res =
- builder.create<OpTy>(loc, args[1].getType(), ValueRange{args[0]})
- .getResult();
+ if (fmfOpInterface) {
+ auto attr = fmfOpInterface.getFastMathFlagsAttr();
+ auto fmf = rewriter.getNamedAttr("fastmath", attr);
+ res = builder
+ .create<OpTy>(loc, args[1].getType(), ValueRange{args[0]},
+ fmf)
+ .getResult();
+ } else {
+ res = builder
+ .create<OpTy>(loc, args[1].getType(), ValueRange{args[0]})
+ .getResult();
+ }
} else if (args.size() == 3) {
- res = builder.create<OpTy>(loc, args[2].getType(),
- ValueRange{args[0], args[1]});
+ if (fmfOpInterface) {
+ auto attr = fmfOpInterface.getFastMathFlagsAttr();
+ auto fmf = rewriter.getNamedAttr("fastmath", attr);
+ res = builder
+ .create<OpTy>(loc, args[2].getType(),
+ ValueRange{args[0], args[1]}, fmf)
+ .getResult();
+ } else {
+ res = builder
+ .create<OpTy>(loc, args[2].getType(),
+ ValueRange{args[0], args[1]})
+ .getResult();
+ }
} else
llvm_unreachable("did not expect ops other than nary and binary");
linalg::YieldOp::create(builder, loc, res);
diff --git a/mlir/test/Dialect/Linalg/transform-op-rewrite-in-destination-passing-style.mlir b/mlir/test/Dialect/Linalg/transform-op-rewrite-in-destination-passing-style.mlir
index a1df34c6555f2..f22a4120e18fc 100644
--- a/mlir/test/Dialect/Linalg/transform-op-rewrite-in-destination-passing-style.mlir
+++ b/mlir/test/Dialect/Linalg/transform-op-rewrite-in-destination-passing-style.mlir
@@ -313,3 +313,35 @@ module attributes {transform.with_named_sequence} {
transform.yield
}
}
+
+// -----
+
+// CHECK: #[[$map:.*]] = affine_map<(d0) -> (d0)>
+// CHECK-LABEL: func @arith_binop_fastmath(
+// CHECK-SAME: %[[X:.+]]: tensor<?xf32>, %[[Y:.+]]: tensor<?xf32>
+// CHECK: %[[C0:.+]] = arith.constant 0 : index
+// CHECK: %[[DIM:.+]] = tensor.dim %[[X]], %[[C0]] : tensor<?xf32>
+// CHECK: %[[EMPTY:.+]] = tensor.empty(%[[DIM]]) : tensor<?xf32>
+// CHECK: %[[GENERIC:.+]] = linalg.generic
+// CHECK-SAME: {indexing_maps = [#[[$map]], #[[$map]], #[[$map]]], iterator_types = ["parallel"]}
+// CHECK-SAME: ins(%[[X:.+]], %[[Y:.+]] : tensor<?xf32>, tensor<?xf32>) outs(%[[EMPTY]] : tensor<?xf32>) {
+// CHECK: ^bb0(%[[x:.+]]: f32, %[[y:.+]]: f32, %[[Out:.+]]: f32):
+// CHECK: %[[z:.+]] = arith.addf %[[x]], %[[y]] fastmath<fast> : f32
+// CHECK: linalg.yield %[[z]] : f32
+// CHECK: return %[[GENERIC]] : tensor<?xf32>
+
+func.func @arith_binop_fastmath(%x : tensor<?xf32>, %y : tensor<?xf32>)
+ -> tensor<?xf32> {
+ %z = arith.addf %x, %y fastmath<fast> : tensor<?xf32>
+ return %z : tensor<?xf32>
+}
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+ %0 = transform.structured.match ops{["arith.addf"]} in %arg1
+ : (!transform.any_op) -> !transform.any_op
+ transform.structured.rewrite_in_destination_passing_style %0
+ : (!transform.any_op) -> !transform.any_op
+ transform.yield
+ }
+}
More information about the Mlir-commits
mailing list