[Mlir-commits] [mlir] [mlir][Linalg] Allow PartialReductionOpInterface ops in tile_reduction_using_forall (PR #157932)

Bangtian Liu llvmlistbot at llvm.org
Wed Sep 10 11:50:35 PDT 2025


https://github.com/bangtianliu created https://github.com/llvm/llvm-project/pull/157932

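(No description was provided with this pull request.)

Summary of the patch below: `transform::TileReductionUsingForallOp::applyToOne` now accepts a generic `Operation *` target instead of a `linalg::LinalgOp`, and emits a silenceable failure when the target does not implement `PartialReductionOpInterface`. The op results `split_linalg_op` and `combining_linalg_op` are renamed to `split_op` and `combining_op`.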

>From 915facfd21a0664ebdfbb587531474d3554fe2fb Mon Sep 17 00:00:00 2001
From: Bangtian Liu <liubangtian at gmail.com>
Date: Wed, 10 Sep 2025 11:51:45 -0700
Subject: [PATCH] [mlir][Linalg] Allow PartialReductionOpInterface ops in
 tile_reduction_using_forall

Signed-off-by: Bangtian Liu <liubangtian at gmail.com>
---
 .../Linalg/TransformOps/LinalgTransformOps.td     |  6 +++---
 .../Linalg/TransformOps/LinalgTransformOps.cpp    | 15 +++++++++++----
 2 files changed, 14 insertions(+), 7 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
index a19cce4b919a8..8f3232f01544f 100644
--- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
@@ -2017,8 +2017,8 @@ def TileReductionUsingForallOp :
                    DefaultValuedAttr<DenseI64ArrayAttr, "{}">:$tile_sizes,
                    OptionalAttr<DeviceMappingArrayAttr>:$mapping);
   let results = (outs Variadic<TransformHandleTypeInterface>:$fill_op,
-                      TransformHandleTypeInterface:$split_linalg_op,
-                      TransformHandleTypeInterface:$combining_linalg_op,
+                      TransformHandleTypeInterface:$split_op,
+                      TransformHandleTypeInterface:$combining_op,
                       TransformHandleTypeInterface:$forall_op);
 
   let builders = [
@@ -2042,7 +2042,7 @@ def TileReductionUsingForallOp :
   let extraClassDeclaration = [{
     ::mlir::DiagnosedSilenceableFailure applyToOne(
         ::mlir::transform::TransformRewriter &rewriter,
-        ::mlir::linalg::LinalgOp target,
+        Operation *target,
         ::mlir::transform::ApplyToEachResultList &results,
         ::mlir::transform::TransformState &state);
   }];
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index f0c1f4485b054..f3db8f7ccfaa1 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -3033,10 +3033,17 @@ void transform::TileReductionUsingForallOp::build(
 }
 
 DiagnosedSilenceableFailure transform::TileReductionUsingForallOp::applyToOne(
-    transform::TransformRewriter &rewriter, LinalgOp target,
+    transform::TransformRewriter &rewriter, Operation *target,
     transform::ApplyToEachResultList &results,
     transform::TransformState &state) {
   rewriter.setInsertionPoint(target);
+
+  auto partialReductionOp = dyn_cast<PartialReductionOpInterface>(target);
+  if (!partialReductionOp) {
+    return emitSilenceableFailure(
+        target->getLoc(),
+        "Operation should implement PartialReductionOpInterface");
+  }
   SmallVector<OpFoldResult> numThreads =
       getAsOpFoldResult(rewriter.getI64ArrayAttr(getNumThreads()));
   SmallVector<OpFoldResult> tileSizes =
@@ -3058,14 +3065,14 @@ DiagnosedSilenceableFailure transform::TileReductionUsingForallOp::applyToOne(
       extractFromIntegerArrayAttr<unsigned>(getReductionDims());
   if (reductionDims.empty()) {
     for (auto [idx, iteratorType] :
-         llvm::enumerate(target.getIteratorTypesArray())) {
+         llvm::enumerate(partialReductionOp.getLoopIteratorTypes())) {
       if (iteratorType == utils::IteratorType::reduction)
         reductionDims.push_back(idx);
     }
   }
   options.setReductionDims(reductionDims);
-  FailureOr<scf::SCFTilingResult> result = scf::tileUsingSCF(
-      rewriter, cast<TilingInterface>(target.getOperation()), options);
+  FailureOr<scf::SCFTilingResult> result =
+      scf::tileUsingSCF(rewriter, partialReductionOp, options);
 
   if (failed(result)) {
     auto diag = emitSilenceableError() << "could not tile reduction";



More information about the Mlir-commits mailing list