[Mlir-commits] [mlir] 5ad4213 - [mlir][Linalg] Allow PartialReductionOpInterface ops in tile_reduction_using_for (#120118)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Fri Dec 27 05:20:02 PST 2024


Author: Kunwar Grover
Date: 2024-12-27T13:19:58Z
New Revision: 5ad4213ef48253a6be1f9880f17555fc36efdd19

URL: https://github.com/llvm/llvm-project/commit/5ad4213ef48253a6be1f9880f17555fc36efdd19
DIFF: https://github.com/llvm/llvm-project/commit/5ad4213ef48253a6be1f9880f17555fc36efdd19.diff

LOG: [mlir][Linalg] Allow PartialReductionOpInterface ops in tile_reduction_using_for (#120118)

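The API used internally (`scf::tileReductionUsingScf`) expects an operation
implementing PartialReductionOpInterface. This patch allows any operation
implementing this interface to be the target of this transform op (instead
of only a LinalgOp). To reflect this, the `split_linalg_op` and
`combining_linalg_op` results are renamed to `split_op` and `combining_op`.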

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
    mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
index 2e713bca24efc5..081bf9b6d3b239 100644
--- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
@@ -1765,8 +1765,8 @@ def TileReductionUsingForOp : Op<Transform_Dialect, "structured.tile_reduction_u
   let arguments = (ins TransformHandleTypeInterface:$target,
                    DefaultValuedAttr<DenseI64ArrayAttr, "{}">:$tile_sizes);
   let results = (outs Variadic<TransformHandleTypeInterface>:$fill_op,
-                      TransformHandleTypeInterface:$split_linalg_op,
-                      TransformHandleTypeInterface:$combining_linalg_op,
+                      TransformHandleTypeInterface:$split_op,
+                      TransformHandleTypeInterface:$combining_op,
                       TransformHandleTypeInterface:$for_op);
 
   let builders = [
@@ -1784,7 +1784,7 @@ def TileReductionUsingForOp : Op<Transform_Dialect, "structured.tile_reduction_u
   let extraClassDeclaration = [{
     ::mlir::DiagnosedSilenceableFailure applyToOne(
         ::mlir::transform::TransformRewriter &rewriter,
-        ::mlir::linalg::LinalgOp target,
+        Operation *target,
         ::mlir::transform::ApplyToEachResultList &results,
         ::mlir::transform::TransformState &state);
   }];

diff  --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index 221ca27b80fdd0..a1d619c8cd19dc 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -2626,12 +2626,19 @@ void transform::TileReductionUsingForOp::build(
 }
 
 DiagnosedSilenceableFailure transform::TileReductionUsingForOp::applyToOne(
-    transform::TransformRewriter &rewriter, LinalgOp target,
+    transform::TransformRewriter &rewriter, Operation *target,
     transform::ApplyToEachResultList &results,
     transform::TransformState &state) {
   rewriter.setInsertionPoint(target);
+
+  auto partialReductionOp = dyn_cast<PartialReductionOpInterface>(target);
+  if (!partialReductionOp) {
+    return emitSilenceableFailure(
+        target->getLoc(),
+        "Operation should implement PartialReductionOpInterface");
+  }
   FailureOr<scf::SCFTilingResult> result = scf::tileReductionUsingScf(
-      rewriter, cast<PartialReductionOpInterface>(target.getOperation()),
+      rewriter, partialReductionOp,
       getAsOpFoldResult(rewriter.getI64ArrayAttr(getTileSizes())));
 
   if (failed(result))


        


More information about the Mlir-commits mailing list