[Mlir-commits] [mlir] [mlir][linalg] convert arith ops to destination-passing-style. (PR #157854)

Javed Absar llvmlistbot at llvm.org
Sun Sep 14 04:27:59 PDT 2025


https://github.com/javedabsar1 updated https://github.com/llvm/llvm-project/pull/157854

>From 846f3f7802de4c691457998fd150e9849f6eef50 Mon Sep 17 00:00:00 2001
From: Javed Absar <javed.absar at gmail.com>
Date: Sun, 31 Aug 2025 18:52:07 -0400
Subject: [PATCH 1/4] [mlir][linalg] convert arith ops to
 destination-passing-style.

Converts arith ops that operate on tensors but are not in destination
passing style (DPS) to an equivalent linalg.generic, which is in DPS.
This new pass `linalg-convert-to-dps` has general use, but is especially
useful together with `lower-quant-ops`: ops like `quant.qcast` lower to
arith ops on tensors, which cannot be bufferized unless converted to DPS.

 e.g.
    `%0 = arith.uitofp %arg0 : tensor<?xi32> to tensor<?xf32>`
    gets rewritten as:

    %c0 = arith.constant 0 : index
    %dim = tensor.dim %arg0, %c0 : tensor<?xi32>
    %0 = tensor.empty(%dim) : tensor<?xf32>
    %1 = linalg.generic
          {indexing_maps = [#map, #map], iterator_types = ["parallel"]}
          ins(%arg0 : tensor<?xi32>) outs(%0 : tensor<?xf32>) {
           ^bb0(%in: i32, %out: f32):
             %2 = arith.uitofp %in : i32 to f32
             linalg.yield %2 : f32
    } -> tensor<?xf32>
---
 mlir/include/mlir/Dialect/Linalg/Passes.td    |  26 +++++
 .../Dialect/Linalg/Transforms/Transforms.h    |  17 +++
 .../TransformOps/LinalgTransformOps.cpp       |  10 +-
 .../Transforms/ConvertToDestinationStyle.cpp  | 105 +++++++++++++++++-
 ...-rewrite-in-destination-passing-style.mlir |  61 ++++++++++
 .../Dialect/Quant/lower-quant-ops-to-dps.mlir |  26 +++++
 6 files changed, 240 insertions(+), 5 deletions(-)
 create mode 100644 mlir/test/Dialect/Quant/lower-quant-ops-to-dps.mlir

diff --git a/mlir/include/mlir/Dialect/Linalg/Passes.td b/mlir/include/mlir/Dialect/Linalg/Passes.td
index 44da2965e6892..365356d3c7d6b 100644
--- a/mlir/include/mlir/Dialect/Linalg/Passes.td
+++ b/mlir/include/mlir/Dialect/Linalg/Passes.td
@@ -171,6 +171,32 @@ def LinalgFoldIntoElementwisePass : Pass<"linalg-fold-into-elementwise"> {
   let dependentDialects = ["linalg::LinalgDialect"];
 }
 
+def LinalgConvertToDPSPass : Pass<"linalg-convert-to-dps"> {
+  let summary = "Convert ops to destination-passing-style";
+  let description = [{
+    Converts ops that operate on tensors but are not in
+    destination passing style (DPS) to equivalent linalg
+    generic which is in DPS. e.g.
+    ```mlir
+      %0 = arith.uitofp %arg0 : tensor<?xi32> to tensor<?xf32>
+    ```
+    gets rewritten as:
+    ```mlir
+       %c0 = arith.constant 0 : index
+       %dim = tensor.dim %arg0, %c0 : tensor<?xi32>
+       %0 = tensor.empty(%dim) : tensor<?xf32>
+       %1 = linalg.generic
+           {indexing_maps = [#map, #map], iterator_types = ["parallel"]}
+           ins(%arg0 : tensor<?xi32>) outs(%0 : tensor<?xf32>) {
+           ^bb0(%in: i32, %out: f32):
+             %2 = arith.uitofp %in : i32 to f32
+             linalg.yield %2 : f32
+       } -> tensor<?xf32>
+     ```
+   }];
+  let dependentDialects = ["linalg::LinalgDialect"];
+}
+
 def LinalgDetensorizePass : InterfacePass<"linalg-detensorize", "FunctionOpInterface"> {
   let summary = "Detensorize linalg ops";
   let dependentDialects = [];
diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
index 0cfc8821c0add..c0d492cf69492 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -1377,6 +1377,23 @@ rewriteInDestinationPassingStyle(RewriterBase &rewriter,
 FailureOr<Operation *> rewriteInDestinationPassingStyle(RewriterBase &rewriter,
                                                         tensor::PadOp padOp);
 
+FailureOr<Operation *> rewriteInDestinationPassingStyle(RewriterBase &rewriter,
+                                                        arith::UIToFPOp op);
+FailureOr<Operation *> rewriteInDestinationPassingStyle(RewriterBase &rewriter,
+                                                        arith::SIToFPOp op);
+FailureOr<Operation *> rewriteInDestinationPassingStyle(RewriterBase &rewriter,
+                                                        arith::FPToUIOp op);
+FailureOr<Operation *> rewriteInDestinationPassingStyle(RewriterBase &rewriter,
+                                                        arith::FPToSIOp op);
+
+FailureOr<Operation *> rewriteInDestinationPassingStyle(RewriterBase &rewriter,
+                                                        arith::AddIOp op);
+
+FailureOr<Operation *> rewriteInDestinationPassingStyle(RewriterBase &rewriter,
+                                                        arith::AddFOp op);
+FailureOr<Operation *> rewriteInDestinationPassingStyle(RewriterBase &rewriter,
+                                                        arith::DivFOp op);
+
 /// Convert linalg.conv_2d_nhwc_hwcf into linalg.generic (for img2col packing)
 /// and linalg.matmul.
 ///
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index f0c1f4485b054..b150dc084aaa7 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -58,10 +58,10 @@ using namespace mlir::transform;
 /// pattern failed to apply. Extra arguments are forwarded to the pattern
 /// constructor.
 template <typename PatternTy, typename... Args>
-static FailureOr<LinalgOp> tryApply(Operation *operation, Args &&...args) {
+static FailureOr<LinalgOp> tryApply(Operation *operation, Args &&... args) {
   // Check if the given operation has the type expected by the pattern.
-  using OpTy = typename llvm::function_traits<
-      decltype(&PatternTy::returningMatchAndRewrite)>::template arg_t<0>;
+  using OpTy = typename llvm::function_traits<decltype(
+      &PatternTy::returningMatchAndRewrite)>::template arg_t<0>;
   auto op = dyn_cast<OpTy>(operation);
   if (!op)
     return failure();
@@ -2611,7 +2611,9 @@ transform::RewriteInDestinationPassingStyleOp::applyToOne(
   rewriter.setInsertionPoint(target);
   FailureOr<Operation *> maybeResult =
       TypeSwitch<Operation *, FailureOr<Operation *>>(target)
-          .Case<tensor::FromElementsOp, tensor::GenerateOp, tensor::PadOp>(
+          .Case<tensor::FromElementsOp, tensor::GenerateOp, tensor::PadOp,
+                arith::UIToFPOp, arith::SIToFPOp, arith::FPToUIOp,
+                arith::FPToSIOp, arith::AddIOp, arith::AddFOp, arith::DivFOp>(
               [&rewriter](auto op) {
                 return rewriteInDestinationPassingStyle(rewriter, op);
               });
diff --git a/mlir/lib/Dialect/Linalg/Transforms/ConvertToDestinationStyle.cpp b/mlir/lib/Dialect/Linalg/Transforms/ConvertToDestinationStyle.cpp
index 76ddee4f0e9cf..79f44ff87b3f6 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/ConvertToDestinationStyle.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/ConvertToDestinationStyle.cpp
@@ -17,13 +17,22 @@
 #include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
 #include "mlir/Dialect/Bufferization/IR/Bufferization.h"
 #include "mlir/Dialect/Linalg/IR/Linalg.h"
+#include "mlir/Dialect/Linalg/Passes.h"
 #include "mlir/Dialect/Linalg/Transforms/Transforms.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
 #include "mlir/Dialect/Utils/StaticValueUtils.h"
 #include "mlir/IR/Matchers.h"
 #include "mlir/IR/PatternMatch.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
 #include "llvm/ADT/STLExtras.h"
 
+namespace mlir {
+#define GEN_PASS_DEF_LINALGCONVERTTODPSPASS
+#include "mlir/Dialect/Linalg/Passes.h.inc"
+} // namespace mlir
+
+#define DEBUG_TYPE "linalg-convert-to-dps"
+
 using namespace mlir;
 using namespace mlir::tensor;
 
@@ -96,7 +105,7 @@ static Operation *movePaddingToFillOrGenericOp(RewriterBase &rewriter,
   OpBuilder::InsertionGuard g(rewriter);
   RankedTensorType resultType = padOp.getResultType();
 
-  // Examine the yielded value to decide if a linalg.generic is neede or a
+  // Examine the yielded value to decide if a linalg.generic is needed or a
   // linalg.fill is sufficient.
   Value yieldedValue =
       cast<tensor::YieldOp>(padOp.getBody()->getTerminator()).getValue();
@@ -603,6 +612,56 @@ Value linalg::bufferizeToAllocation(
 }
 
 namespace {
+template <typename OpTy>
+FailureOr<Operation *>
+rewriteArithInDestinationPassingStyle(RewriterBase &rewriter, OpTy op) {
+  // reject ops such as `arith.constant` and `arith.select`.
+  auto numOperands = op->getNumOperands();
+  if (numOperands == 0 || numOperands > 2)
+    return failure();
+
+  // destination passing style rewrite is only for ops on tensor types.
+  Type resultType = op->getResult(0).getType();
+  auto tensorType = dyn_cast<RankedTensorType>(resultType);
+  if (!tensorType)
+    return failure();
+
+  auto loc = op.getLoc();
+  OpBuilder::InsertionGuard g(rewriter);
+  auto dynSizes = reifyOrComputeDynamicSizes(rewriter, op->getOperand(0));
+
+  // Create tensor.empty.
+  Value empty = tensor::EmptyOp::create(rewriter, loc, resultType, dynSizes);
+
+  // Create linalg.generic
+  auto rank = tensorType.getRank();
+  SmallVector<AffineMap> indexingMaps(numOperands + 1,
+                                      rewriter.getMultiDimIdentityMap(rank));
+  SmallVector<utils::IteratorType> iteratorTypes(rank,
+                                                 utils::IteratorType::parallel);
+  auto genericOp = linalg::GenericOp::create(
+      rewriter, loc, tensorType,
+      op->getOperands(), // inputs
+      ValueRange{empty}, // outputs
+      indexingMaps, iteratorTypes,
+      [&](OpBuilder &builder, Location loc, ValueRange args) {
+        Value res;
+        if (args.size() == 2) {
+          res =
+              builder.create<OpTy>(loc, args[1].getType(), ValueRange{args[0]})
+                  .getResult();
+        } else if (args.size() == 3) {
+          res = builder.create<OpTy>(loc, args[2].getType(),
+                                     ValueRange{args[0], args[1]});
+        } else
+          llvm_unreachable("did not expect ops other than nary and binary");
+        linalg::YieldOp::create(builder, loc, res);
+      });
+
+  rewriter.replaceAllUsesWith(op, genericOp.getResult(0));
+  rewriter.eraseOp(op);
+  return genericOp.getOperation();
+}
 
 template <typename OpTy>
 LogicalResult rewriteOpInDestinationPassingStyle(OpTy op,
@@ -612,9 +671,53 @@ LogicalResult rewriteOpInDestinationPassingStyle(OpTy op,
 
 } // namespace
 
+#define STAMP_OUT_ARITH_DPS_FUNCS(OPTY)                                        \
+  FailureOr<Operation *> linalg::rewriteInDestinationPassingStyle(             \
+      RewriterBase &rewriter, OPTY op) {                                       \
+    return rewriteArithInDestinationPassingStyle<OPTY>(rewriter, op);          \
+  }
+
+STAMP_OUT_ARITH_DPS_FUNCS(arith::UIToFPOp)
+STAMP_OUT_ARITH_DPS_FUNCS(arith::SIToFPOp)
+STAMP_OUT_ARITH_DPS_FUNCS(arith::FPToUIOp)
+STAMP_OUT_ARITH_DPS_FUNCS(arith::FPToSIOp)
+
+STAMP_OUT_ARITH_DPS_FUNCS(arith::AddIOp)
+
+STAMP_OUT_ARITH_DPS_FUNCS(arith::AddFOp)
+STAMP_OUT_ARITH_DPS_FUNCS(arith::DivFOp)
+
 void linalg::populateConvertToDestinationStylePatterns(
     RewritePatternSet &patterns) {
   patterns.add(rewriteOpInDestinationPassingStyle<tensor::FromElementsOp>);
   patterns.add(rewriteOpInDestinationPassingStyle<tensor::GenerateOp>);
   patterns.add(rewriteOpInDestinationPassingStyle<tensor::PadOp>);
+
+  patterns.add(rewriteOpInDestinationPassingStyle<arith::UIToFPOp>);
+  patterns.add(rewriteOpInDestinationPassingStyle<arith::SIToFPOp>);
+  patterns.add(rewriteOpInDestinationPassingStyle<arith::FPToUIOp>);
+  patterns.add(rewriteOpInDestinationPassingStyle<arith::FPToSIOp>);
+
+  patterns.add(rewriteOpInDestinationPassingStyle<arith::AddIOp>);
+
+  patterns.add(rewriteOpInDestinationPassingStyle<arith::AddFOp>);
+  patterns.add(rewriteOpInDestinationPassingStyle<arith::DivFOp>);
 }
+
+namespace {
+struct LinalgConvertToDPSPass
+    : public impl::LinalgConvertToDPSPassBase<LinalgConvertToDPSPass> {
+  using impl::LinalgConvertToDPSPassBase<
+      LinalgConvertToDPSPass>::LinalgConvertToDPSPassBase;
+
+  void runOnOperation() override;
+};
+
+void LinalgConvertToDPSPass::runOnOperation() {
+
+  RewritePatternSet patterns(&getContext());
+  linalg::populateConvertToDestinationStylePatterns(patterns);
+  if (failed(applyPatternsGreedily(getOperation(), std::move(patterns))))
+    signalPassFailure();
+}
+} // namespace
diff --git a/mlir/test/Dialect/Linalg/transform-op-rewrite-in-destination-passing-style.mlir b/mlir/test/Dialect/Linalg/transform-op-rewrite-in-destination-passing-style.mlir
index 63c9f1f27517b..a1df34c6555f2 100644
--- a/mlir/test/Dialect/Linalg/transform-op-rewrite-in-destination-passing-style.mlir
+++ b/mlir/test/Dialect/Linalg/transform-op-rewrite-in-destination-passing-style.mlir
@@ -252,3 +252,64 @@ module attributes {transform.with_named_sequence} {
       transform.yield
   }
 }
+
+// -----
+
+// CHECK: #[[$map:.*]] = affine_map<(d0) -> (d0)>
+// CHECK-LABEL: func @arith_unary_op(
+//  CHECK-SAME:   %[[X:.+]]: tensor<64xi32>) ->  tensor<64xf32> {
+//       CHECK:   %[[EMPTY:.+]] = tensor.empty() : tensor<64xf32>
+//       CHECK:   %[[GENERIC:.+]] = linalg.generic
+//  CHECK-SAME:       {indexing_maps = [#[[$map]], #[[$map]]], iterator_types = ["parallel"]}
+//  CHECK-SAME:       ins(%[[X]] : tensor<64xi32>) outs(%[[EMPTY]] : tensor<64xf32>) {
+//       CHECK:     ^bb0(%[[x:.+]]: i32, %[[Out:.+]]: f32):
+//       CHECK:        %[[z:.+]] = arith.uitofp %[[x]] : i32 to f32
+//       CHECK:        linalg.yield %[[z]] : f32
+//       CHECK:  return %[[GENERIC]] : tensor<64xf32>
+
+func.func @arith_unary_op(%x : tensor<64xi32>) -> tensor<64xf32> {
+  %z = arith.uitofp %x : tensor<64xi32> to tensor<64xf32>
+  return %z : tensor<64xf32>
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["arith.uitofp"]} in %arg1
+      : (!transform.any_op) -> !transform.any_op
+    transform.structured.rewrite_in_destination_passing_style %0
+      : (!transform.any_op) -> !transform.any_op
+      transform.yield
+  }
+}
+
+// -----
+
+// CHECK: #[[$map:.*]] = affine_map<(d0) -> (d0)>
+// CHECK-LABEL: func @arith_binop(
+//  CHECK-SAME:   %[[X:.+]]: tensor<?xf32>,  %[[Y:.+]]: tensor<?xf32>
+//       CHECK:   %[[C0:.+]] = arith.constant 0 : index
+//       CHECK:   %[[DIM:.+]] = tensor.dim %[[X]], %[[C0]] : tensor<?xf32>
+//       CHECK:   %[[EMPTY:.+]] = tensor.empty(%[[DIM]]) : tensor<?xf32>
+//       CHECK:   %[[GENERIC:.+]] = linalg.generic
+//  CHECK-SAME:       {indexing_maps = [#[[$map]], #[[$map]], #[[$map]]], iterator_types = ["parallel"]}
+//  CHECK-SAME:       ins(%[[X]], %[[Y]] : tensor<?xf32>, tensor<?xf32>) outs(%[[EMPTY]] : tensor<?xf32>) {
+//       CHECK:     ^bb0(%[[x:.+]]: f32, %[[y:.+]]: f32, %[[Out:.+]]: f32):
+//       CHECK:        %[[z:.+]] = arith.addf %[[x]], %[[y]] : f32
+//       CHECK:        linalg.yield %[[z]] : f32
+//       CHECK:  return %[[GENERIC]] : tensor<?xf32>
+
+func.func @arith_binop(%x : tensor<?xf32>, %y : tensor<?xf32>)
+    -> tensor<?xf32> {
+  %z = arith.addf %x, %y : tensor<?xf32>
+  return %z : tensor<?xf32>
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["arith.addf"]} in %arg1
+      : (!transform.any_op) -> !transform.any_op
+    transform.structured.rewrite_in_destination_passing_style %0
+      : (!transform.any_op) -> !transform.any_op
+      transform.yield
+  }
+}
diff --git a/mlir/test/Dialect/Quant/lower-quant-ops-to-dps.mlir b/mlir/test/Dialect/Quant/lower-quant-ops-to-dps.mlir
new file mode 100644
index 0000000000000..0fc9f1e3ed9be
--- /dev/null
+++ b/mlir/test/Dialect/Quant/lower-quant-ops-to-dps.mlir
@@ -0,0 +1,26 @@
+// RUN: mlir-opt %s -lower-quant-ops -linalg-convert-to-dps \
+// RUN:             -linalg-specialize-generic-ops -cse | FileCheck %s
+
+// CHECK-LABEL: func.func @lower_qcast_to_dps(
+// CHECK-SAME:    %[[X:.+]]: tensor<10xf32>) -> tensor<10x!quant.uniform<i8:f32, 2.000000e+00:10>>
+// CHECK-DAG:     %[[CST_10I:.+]] = arith.constant dense<10> : tensor<10xi8>
+// CHECK-DAG:     %[[CST_2F:.+]] = arith.constant dense<2.000000e+00> : tensor<10xf32>
+// CHECK:         %[[E:.+]] = tensor.empty() : tensor<10xf32>
+// CHECK:         %[[DIV:.+]] = linalg.div ins(%[[X]], %[[CST_2F]] : tensor<10xf32>, tensor<10xf32>)
+// CHECK-SAME:                             outs(%[[E]] : tensor<10xf32>) -> tensor<10xf32>
+//
+// CHECK:         %[[SITOFP:.+]] = linalg.generic
+// CHECK-SAME:       ins(%[[CST_10I]] : tensor<10xi8>) outs(%[[E]] : tensor<10xf32>)
+// CHECK:            %{{.*}} = arith.sitofp %{{.*}} : i8 to f32
+//
+// CHECK:         %[[ADD:.+]] = linalg.add ins(%[[DIV]], %[[SITOFP]] : tensor<10xf32>, tensor<10xf32>)
+// CHECK:         %{{.*}} = linalg.generic
+// CHECK-SAME:       ins(%[[ADD]] : tensor<10xf32>)
+// CHECK:            %{{.*}} = arith.fptosi %{{.*}} : f32 to i8
+
+
+!qalias = !quant.uniform<i8:f32, 2.0:10>
+func.func @lower_qcast_to_dps(%arg0: tensor<10xf32>) -> tensor<10x!qalias> {
+  %0 = quant.qcast %arg0 : tensor<10xf32> to tensor<10x!qalias>
+  return %0 : tensor<10x!qalias>
+}

>From 348e8aa207b6198f25a20ebceb74623584dd03b8 Mon Sep 17 00:00:00 2001
From: Javed Absar <javed.absar at gmail.com>
Date: Sat, 13 Sep 2025 06:15:30 -0400
Subject: [PATCH 2/4] address review comments

---
 mlir/include/mlir/Dialect/Linalg/Passes.td    | 16 ++++++-------
 .../TransformOps/LinalgTransformOps.cpp       |  6 ++---
 .../Transforms/ConvertToDestinationStyle.cpp  | 23 ++++++++++++++-----
 3 files changed, 28 insertions(+), 17 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Linalg/Passes.td b/mlir/include/mlir/Dialect/Linalg/Passes.td
index 365356d3c7d6b..8deb208573203 100644
--- a/mlir/include/mlir/Dialect/Linalg/Passes.td
+++ b/mlir/include/mlir/Dialect/Linalg/Passes.td
@@ -182,17 +182,17 @@ def LinalgConvertToDPSPass : Pass<"linalg-convert-to-dps"> {
     ```
     gets rewritten as:
     ```mlir
-       %c0 = arith.constant 0 : index
-       %dim = tensor.dim %arg0, %c0 : tensor<?xi32>
-       %0 = tensor.empty(%dim) : tensor<?xf32>
-       %1 = linalg.generic
-           {indexing_maps = [#map, #map], iterator_types = ["parallel"]}
-           ins(%arg0 : tensor<?xi32>) outs(%0 : tensor<?xf32>) {
+      %c0 = arith.constant 0 : index
+      %dim = tensor.dim %arg0, %c0 : tensor<?xi32>
+      %0 = tensor.empty(%dim) : tensor<?xf32>
+      %1 = linalg.generic
+             {indexing_maps = [#map, #map], iterator_types = ["parallel"]}
+             ins(%arg0 : tensor<?xi32>) outs(%0 : tensor<?xf32>) {
            ^bb0(%in: i32, %out: f32):
              %2 = arith.uitofp %in : i32 to f32
              linalg.yield %2 : f32
-       } -> tensor<?xf32>
-     ```
+      } -> tensor<?xf32>
+    ```
    }];
   let dependentDialects = ["linalg::LinalgDialect"];
 }
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index b150dc084aaa7..94531ff854593 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -58,10 +58,10 @@ using namespace mlir::transform;
 /// pattern failed to apply. Extra arguments are forwarded to the pattern
 /// constructor.
 template <typename PatternTy, typename... Args>
-static FailureOr<LinalgOp> tryApply(Operation *operation, Args &&... args) {
+static FailureOr<LinalgOp> tryApply(Operation *operation, Args &&...args) {
   // Check if the given operation has the type expected by the pattern.
-  using OpTy = typename llvm::function_traits<decltype(
-      &PatternTy::returningMatchAndRewrite)>::template arg_t<0>;
+  using OpTy = typename llvm::function_traits<
+      decltype(&PatternTy::returningMatchAndRewrite)>::template arg_t<0>;
   auto op = dyn_cast<OpTy>(operation);
   if (!op)
     return failure();
diff --git a/mlir/lib/Dialect/Linalg/Transforms/ConvertToDestinationStyle.cpp b/mlir/lib/Dialect/Linalg/Transforms/ConvertToDestinationStyle.cpp
index 79f44ff87b3f6..6ee5246c7a1f7 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/ConvertToDestinationStyle.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/ConvertToDestinationStyle.cpp
@@ -31,8 +31,6 @@ namespace mlir {
 #include "mlir/Dialect/Linalg/Passes.h.inc"
 } // namespace mlir
 
-#define DEBUG_TYPE "linalg-convert-to-dps"
-
 using namespace mlir;
 using namespace mlir::tensor;
 
@@ -612,10 +610,23 @@ Value linalg::bufferizeToAllocation(
 }
 
 namespace {
+/// Rewrites an arith op operating on tensors, e.g.
+///  `%z = arith.addf %x, %y : tensor<5xf32>`
+/// into an equivalent linalg.generic in destination-passing-style.
+/// ```mlir
+/// %0 = tensor.empty() : tensor<5xf32>
+/// %1 = linalg.generic ...
+///        ins(%x, %y : tensor<5xf32>, tensor<5xf32>)
+///        outs(%0 : tensor<5xf32>) {
+///      ^bb0(%in: f32, %in_0: f32, %out: f32):
+///         %2 = arith.addf %in, %in_0 : f32
+///         linalg.yield %2 : f32
+///     } -> tensor<5xf32>
 template <typename OpTy>
 FailureOr<Operation *>
 rewriteArithInDestinationPassingStyle(RewriterBase &rewriter, OpTy op) {
-  // reject ops such as `arith.constant` and `arith.select`.
+  // Reject ops such as `arith.constant` and `arith.select`:
+  // constants don't need DPS conversion, and select is left as a TODO.
   auto numOperands = op->getNumOperands();
   if (numOperands == 0 || numOperands > 2)
     return failure();
@@ -630,8 +641,8 @@ rewriteArithInDestinationPassingStyle(RewriterBase &rewriter, OpTy op) {
   OpBuilder::InsertionGuard g(rewriter);
   auto dynSizes = reifyOrComputeDynamicSizes(rewriter, op->getOperand(0));
 
-  // Create tensor.empty.
-  Value empty = tensor::EmptyOp::create(rewriter, loc, resultType, dynSizes);
+  // Create tensor.empty for `outs` of destination-passing-style.
+  Value outs = tensor::EmptyOp::create(rewriter, loc, resultType, dynSizes);
 
   // Create linalg.generic
   auto rank = tensorType.getRank();
@@ -642,7 +653,7 @@ rewriteArithInDestinationPassingStyle(RewriterBase &rewriter, OpTy op) {
   auto genericOp = linalg::GenericOp::create(
       rewriter, loc, tensorType,
       op->getOperands(), // inputs
-      ValueRange{empty}, // outputs
+      ValueRange{outs}, // outputs
       indexingMaps, iteratorTypes,
       [&](OpBuilder &builder, Location loc, ValueRange args) {
         Value res;

>From 22e02b353cfc840b17cdb280f768119fa75f6b9b Mon Sep 17 00:00:00 2001
From: Javed Absar <javed.absar at gmail.com>
Date: Sat, 13 Sep 2025 09:52:38 -0400
Subject: [PATCH 3/4] fix clang-format

---
 .../lib/Dialect/Linalg/Transforms/ConvertToDestinationStyle.cpp | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/mlir/lib/Dialect/Linalg/Transforms/ConvertToDestinationStyle.cpp b/mlir/lib/Dialect/Linalg/Transforms/ConvertToDestinationStyle.cpp
index 6ee5246c7a1f7..3fec9b8e62cf3 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/ConvertToDestinationStyle.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/ConvertToDestinationStyle.cpp
@@ -653,7 +653,7 @@ rewriteArithInDestinationPassingStyle(RewriterBase &rewriter, OpTy op) {
   auto genericOp = linalg::GenericOp::create(
       rewriter, loc, tensorType,
       op->getOperands(), // inputs
-      ValueRange{outs}, // outputs
+      ValueRange{outs},  // outputs
       indexingMaps, iteratorTypes,
       [&](OpBuilder &builder, Location loc, ValueRange args) {
         Value res;

>From b6c4d4d675994c4017dfee96a8d972d2893b67ec Mon Sep 17 00:00:00 2001
From: Javed Absar <javed.absar at gmail.com>
Date: Sun, 14 Sep 2025 07:27:13 -0400
Subject: [PATCH 4/4] add fastmath propagation

---
 .../Transforms/ConvertToDestinationStyle.cpp  | 20 +++++-------
 ...-rewrite-in-destination-passing-style.mlir | 32 +++++++++++++++++++
 2 files changed, 40 insertions(+), 12 deletions(-)

diff --git a/mlir/lib/Dialect/Linalg/Transforms/ConvertToDestinationStyle.cpp b/mlir/lib/Dialect/Linalg/Transforms/ConvertToDestinationStyle.cpp
index 3fec9b8e62cf3..3fa921e5d3e0b 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/ConvertToDestinationStyle.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/ConvertToDestinationStyle.cpp
@@ -658,18 +658,14 @@ rewriteArithInDestinationPassingStyle(RewriterBase &rewriter, OpTy op) {
       [&](OpBuilder &builder, Location loc, ValueRange args) {
-        Value res;
-        if (args.size() == 2) {
-          res =
-              builder.create<OpTy>(loc, args[1].getType(), ValueRange{args[0]})
-                  .getResult();
-        } else if (args.size() == 3) {
-          res = builder.create<OpTy>(loc, args[2].getType(),
-                                     ValueRange{args[0], args[1]});
-        } else
-          llvm_unreachable("did not expect ops other than nary and binary");
+        // `args` holds one scalar per input, followed by the `outs` element.
+        assert(args.size() == numOperands + 1 &&
+               "expected one block argument per input plus one for `outs`");
+        // Re-create the op on scalars, propagating all of its attributes
+        // (e.g. `fastmath` on float ops, `overflow` flags on integer ops).
+        Value res = OpTy::create(builder, loc, args.back().getType(),
+                                 args.take_front(numOperands), op->getAttrs());
         linalg::YieldOp::create(builder, loc, res);
       });
 
-  rewriter.replaceAllUsesWith(op, genericOp.getResult(0));
-  rewriter.eraseOp(op);
+  rewriter.replaceOp(op, genericOp.getResults());
   return genericOp.getOperation();
 }
diff --git a/mlir/test/Dialect/Linalg/transform-op-rewrite-in-destination-passing-style.mlir b/mlir/test/Dialect/Linalg/transform-op-rewrite-in-destination-passing-style.mlir
index a1df34c6555f2..f22a4120e18fc 100644
--- a/mlir/test/Dialect/Linalg/transform-op-rewrite-in-destination-passing-style.mlir
+++ b/mlir/test/Dialect/Linalg/transform-op-rewrite-in-destination-passing-style.mlir
@@ -313,3 +313,35 @@ module attributes {transform.with_named_sequence} {
       transform.yield
   }
 }
+
+// -----
+
+// CHECK: #[[$map:.*]] = affine_map<(d0) -> (d0)>
+// CHECK-LABEL: func @arith_binop_fastmath(
+//  CHECK-SAME:   %[[X:.+]]: tensor<?xf32>,  %[[Y:.+]]: tensor<?xf32>
+//       CHECK:   %[[C0:.+]] = arith.constant 0 : index
+//       CHECK:   %[[DIM:.+]] = tensor.dim %[[X]], %[[C0]] : tensor<?xf32>
+//       CHECK:   %[[EMPTY:.+]] = tensor.empty(%[[DIM]]) : tensor<?xf32>
+//       CHECK:   %[[GENERIC:.+]] = linalg.generic
+//  CHECK-SAME:       {indexing_maps = [#[[$map]], #[[$map]], #[[$map]]], iterator_types = ["parallel"]}
+//  CHECK-SAME:       ins(%[[X]], %[[Y]] : tensor<?xf32>, tensor<?xf32>) outs(%[[EMPTY]] : tensor<?xf32>) {
+//       CHECK:     ^bb0(%[[x:.+]]: f32, %[[y:.+]]: f32, %[[Out:.+]]: f32):
+//       CHECK:        %[[z:.+]] = arith.addf %[[x]], %[[y]] fastmath<fast> : f32
+//       CHECK:        linalg.yield %[[z]] : f32
+//       CHECK:  return %[[GENERIC]] : tensor<?xf32>
+
+func.func @arith_binop_fastmath(%x : tensor<?xf32>, %y : tensor<?xf32>)
+    -> tensor<?xf32> {
+  %z = arith.addf %x, %y fastmath<fast> : tensor<?xf32>
+  return %z : tensor<?xf32>
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["arith.addf"]} in %arg1
+      : (!transform.any_op) -> !transform.any_op
+    transform.structured.rewrite_in_destination_passing_style %0
+      : (!transform.any_op) -> !transform.any_op
+     transform.yield
+  }
+}



More information about the Mlir-commits mailing list