[Mlir-commits] [mlir] 17723e4 - [mlir][Linalg] Allow PartialReductionOpInterface ops in tile_reduction_using_forall (#157932)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Wed Sep 10 12:49:28 PDT 2025


Author: Bangtian Liu
Date: 2025-09-10T15:49:24-04:00
New Revision: 17723e472e228be5404ab4377498b52a0c5db03b

URL: https://github.com/llvm/llvm-project/commit/17723e472e228be5404ab4377498b52a0c5db03b
DIFF: https://github.com/llvm/llvm-project/commit/17723e472e228be5404ab4377498b52a0c5db03b.diff

LOG: [mlir][Linalg] Allow PartialReductionOpInterface ops in tile_reduction_using_forall (#157932)

Following [PR #120118](https://github.com/llvm/llvm-project/pull/120118), this PR
extends `transform.structured.tile_reduction_using_forall` so that it can
be applied to any operation implementing `PartialReductionOpInterface`,
rather than being restricted to `LinalgOp`. Accordingly, the
`split_linalg_op` and `combining_linalg_op` results are renamed to
`split_op` and `combining_op`.

Existing tests for Linalg ops remain valid:
https://github.com/llvm/llvm-project/blob/2a2296b1aab4614bf6c95c3003000832c9d43de5/mlir/test/Dialect/Linalg/transform-tile-reduction.mlir#L114

Additional tests for non-Linalg operations (e.g., IREE custom ops that
implement `PartialReductionOpInterface`) will be added on the IREE side.

Signed-off-by: Bangtian Liu <liubangtian at gmail.com>

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
    mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
index a19cce4b919a8..8f3232f01544f 100644
--- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
@@ -2017,8 +2017,8 @@ def TileReductionUsingForallOp :
                    DefaultValuedAttr<DenseI64ArrayAttr, "{}">:$tile_sizes,
                    OptionalAttr<DeviceMappingArrayAttr>:$mapping);
   let results = (outs Variadic<TransformHandleTypeInterface>:$fill_op,
-                      TransformHandleTypeInterface:$split_linalg_op,
-                      TransformHandleTypeInterface:$combining_linalg_op,
+                      TransformHandleTypeInterface:$split_op,
+                      TransformHandleTypeInterface:$combining_op,
                       TransformHandleTypeInterface:$forall_op);
 
   let builders = [
@@ -2042,7 +2042,7 @@ def TileReductionUsingForallOp :
   let extraClassDeclaration = [{
     ::mlir::DiagnosedSilenceableFailure applyToOne(
         ::mlir::transform::TransformRewriter &rewriter,
-        ::mlir::linalg::LinalgOp target,
+        Operation *target,
         ::mlir::transform::ApplyToEachResultList &results,
         ::mlir::transform::TransformState &state);
   }];

diff  --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index f0c1f4485b054..f3db8f7ccfaa1 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -3033,10 +3033,17 @@ void transform::TileReductionUsingForallOp::build(
 }
 
 DiagnosedSilenceableFailure transform::TileReductionUsingForallOp::applyToOne(
-    transform::TransformRewriter &rewriter, LinalgOp target,
+    transform::TransformRewriter &rewriter, Operation *target,
     transform::ApplyToEachResultList &results,
     transform::TransformState &state) {
   rewriter.setInsertionPoint(target);
+
+  auto partialReductionOp = dyn_cast<PartialReductionOpInterface>(target);
+  if (!partialReductionOp) {
+    return emitSilenceableFailure(
+        target->getLoc(),
+        "Operation should implement PartialReductionOpInterface");
+  }
   SmallVector<OpFoldResult> numThreads =
       getAsOpFoldResult(rewriter.getI64ArrayAttr(getNumThreads()));
   SmallVector<OpFoldResult> tileSizes =
@@ -3058,14 +3065,14 @@ DiagnosedSilenceableFailure transform::TileReductionUsingForallOp::applyToOne(
       extractFromIntegerArrayAttr<unsigned>(getReductionDims());
   if (reductionDims.empty()) {
     for (auto [idx, iteratorType] :
-         llvm::enumerate(target.getIteratorTypesArray())) {
+         llvm::enumerate(partialReductionOp.getLoopIteratorTypes())) {
       if (iteratorType == utils::IteratorType::reduction)
         reductionDims.push_back(idx);
     }
   }
   options.setReductionDims(reductionDims);
-  FailureOr<scf::SCFTilingResult> result = scf::tileUsingSCF(
-      rewriter, cast<TilingInterface>(target.getOperation()), options);
+  FailureOr<scf::SCFTilingResult> result =
+      scf::tileUsingSCF(rewriter, partialReductionOp, options);
 
   if (failed(result)) {
     auto diag = emitSilenceableError() << "could not tile reduction";


        


More information about the Mlir-commits mailing list