[Mlir-commits] [mlir] 23a057f - [mlir][Transform] NFC - Return omitted loop construct in transform.tile_reduction_xxx ops

Nicolas Vasilache llvmlistbot at llvm.org
Mon Dec 12 02:14:08 PST 2022


Author: Nicolas Vasilache
Date: 2022-12-12T02:14:00-08:00
New Revision: 23a057fbc423ce3f277efc54631a8a9afce85081

URL: https://github.com/llvm/llvm-project/commit/23a057fbc423ce3f277efc54631a8a9afce85081
DIFF: https://github.com/llvm/llvm-project/commit/23a057fbc423ce3f277efc54631a8a9afce85081.diff

LOG: [mlir][Transform] NFC - Return omitted loop construct in transform.tile_reduction_xxx ops

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
    mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
    mlir/test/Dialect/Linalg/transform-tile-reduction.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
index e6bfa0841fa5e..9fe6536f23f62 100644
--- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
@@ -667,7 +667,8 @@ def TileReductionUsingScfOp : Op<Transform_Dialect, "structured.tile_reduction_u
 
     #### Return modes
 
-    This 3 returned handles point to:
+    These 4 returned handles point to:
+      - the parent for op,
       - the fill op used to initialize the neutral element,
       - the parallel tiled op and
       - the result-combining op.
@@ -722,7 +723,8 @@ def TileReductionUsingScfOp : Op<Transform_Dialect, "structured.tile_reduction_u
 
   let arguments = (ins PDL_Operation:$target,
                    DefaultValuedAttr<I64ArrayAttr, "{}">:$tile_sizes);
-  let results = (outs PDL_Operation:$fill_op,
+  let results = (outs PDL_Operation:$for_op,
+                      PDL_Operation:$fill_op,
                       PDL_Operation:$split_linalg_op,
                       PDL_Operation:$combining_linalg_op);
 
@@ -756,7 +758,8 @@ def TileReductionUsingForeachThreadOp :
 
     #### Return modes
 
-    This 3 returned handles point to:
+    These 4 returned handles point to:
+      - the parent foreach_thread op,
       - the fill op used to initialize the neutral element,
       - the parallel tiled op and
       - the result-combining op.
@@ -809,7 +812,8 @@ def TileReductionUsingForeachThreadOp :
                    DefaultValuedAttr<I64ArrayAttr, "{}">:$num_threads,
                    DefaultValuedAttr<I64ArrayAttr, "{}">:$tile_sizes,
                    OptionalAttr<DeviceMappingArrayAttr>:$mapping);
-  let results = (outs PDL_Operation:$fill_op,
+  let results = (outs PDL_Operation:$foreach_thread_op,
+                      PDL_Operation:$fill_op,
                       PDL_Operation:$split_linalg_op,
                       PDL_Operation:$combining_linalg_op);
 

diff  --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index d35d96ac4310e..3ae4163b5cc6a 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -1181,7 +1181,7 @@ transform::SplitReductionOp::applyToOne(linalg::LinalgOp target,
 }
 
 //===----------------------------------------------------------------------===//
-// SplitReductionOp
+// TileReductionUsingScfOp
 //===----------------------------------------------------------------------===//
 
 DiagnosedSilenceableFailure transform::TileReductionUsingScfOp::applyToOne(
@@ -1201,6 +1201,7 @@ DiagnosedSilenceableFailure transform::TileReductionUsingScfOp::applyToOne(
 
   if (failed(result))
     return DiagnosedSilenceableFailure(reportUnknownTransformError(target));
+  results.push_back(result->loops.front());
   results.push_back(result->initialOp);
   results.push_back(result->parallelTiledOp);
   results.push_back(result->mergeOp);
@@ -1230,6 +1231,7 @@ transform::TileReductionUsingForeachThreadOp::applyToOne(
     diag << "could not tile reduction in target.";
     return DiagnosedSilenceableFailure::silenceableFailure(std::move(diag));
   }
+  results.push_back(result->loops);
   results.push_back(result->initialOp);
   results.push_back(result->parallelTiledOp);
   results.push_back(result->mergeOp);

diff  --git a/mlir/test/Dialect/Linalg/transform-tile-reduction.mlir b/mlir/test/Dialect/Linalg/transform-tile-reduction.mlir
index cd0d6d71113cc..13aec82b10a44 100644
--- a/mlir/test/Dialect/Linalg/transform-tile-reduction.mlir
+++ b/mlir/test/Dialect/Linalg/transform-tile-reduction.mlir
@@ -17,7 +17,7 @@ func.func @reduction_tile(%arg0: tensor<?x?xf32>, %out: tensor<?xf32>) -> tensor
 transform.sequence failures(propagate) {
 ^bb0(%arg1: !pdl.operation):
   %0 = transform.structured.match ops{["linalg.generic"]} in %arg1
-  %1, %2, %3 = transform.structured.tile_reduction_using_scf %0 { tile_sizes = [0, 5] }
+  %loop, %1, %2, %3 = transform.structured.tile_reduction_using_scf %0 { tile_sizes = [0, 5] }
 }
 
 // CHECK-DAG: #[[MAP0:.*]] = affine_map<(d0, d1) -> (d0, d1)>
@@ -71,7 +71,7 @@ func.func @reduction_tile_transpose(%arg0: tensor<?x?xf32>, %out: tensor<?xf32>)
 transform.sequence failures(propagate) {
 ^bb0(%arg1: !pdl.operation):
   %0 = transform.structured.match ops{["linalg.generic"]} in %arg1
-  %1, %2, %3 = transform.structured.tile_reduction_using_scf %0 { tile_sizes = [5, 0] }
+  %loop, %1, %2, %3 = transform.structured.tile_reduction_using_scf %0 { tile_sizes = [5, 0] }
 }
 
 //     CHECK: func @reduction_tile_transpose
@@ -107,7 +107,7 @@ func.func @reduction_tile_parallel(
 transform.sequence failures(propagate) {
 ^bb0(%arg1: !pdl.operation):
   %0 = transform.structured.match ops{["linalg.generic"]} in %arg1
-  %1, %2, %3 = transform.structured.tile_reduction_using_foreach_thread %0 { num_threads = [0, 5] }
+  %loop, %1, %2, %3 = transform.structured.tile_reduction_using_foreach_thread %0 { num_threads = [0, 5] }
 }
 
 // CHECK-DAG: #[[MAP0:.*]] = affine_map<(d0)[s0] -> (-(d0 * (s0 ceildiv 5)) + s0, s0 ceildiv 5)>
@@ -159,7 +159,7 @@ func.func @matmul_tile_parallel(
 transform.sequence failures(propagate) {
 ^bb0(%arg1: !pdl.operation):
   %0 = transform.structured.match ops{["linalg.matmul"]} in %arg1
-  %1, %2, %3 = transform.structured.tile_reduction_using_foreach_thread %0 { num_threads = [0, 0, 5] }
+  %loop, %1, %2, %3 = transform.structured.tile_reduction_using_foreach_thread %0 { num_threads = [0, 0, 5] }
 }
 
 // CHECK-DAG: #[[MAP0:.*]] = affine_map<(d0)[s0] -> (-(d0 * (s0 ceildiv 5)) + s0, s0 ceildiv 5)>
@@ -218,7 +218,7 @@ func.func @reduction_tile_parallel_cyclic_dist(
 transform.sequence failures(propagate) {
 ^bb0(%arg1: !pdl.operation):
   %0 = transform.structured.match ops{["linalg.generic"]} in %arg1
-  %1, %2, %3 = transform.structured.tile_reduction_using_foreach_thread %0 
+  %loop, %1, %2, %3 = transform.structured.tile_reduction_using_foreach_thread %0 
     { num_threads = [0, 5], tile_sizes = [0, 3], mapping = [#gpu.thread<x>] }
 }
 
@@ -284,7 +284,7 @@ func.func @reduction_tile_parallel_cyclic_dist(
 transform.sequence failures(propagate) {
 ^bb0(%arg1: !pdl.operation):
   %0 = transform.structured.match ops{["linalg.generic"]} in %arg1
-  %1, %2, %3 = transform.structured.tile_reduction_using_foreach_thread %0 
+  %loop, %1, %2, %3 = transform.structured.tile_reduction_using_foreach_thread %0 
     { num_threads = [0, 5], tile_sizes = [0, 3], mapping = [#gpu.thread<x>] }
   
   //      CHECK:     expecting fill


        


More information about the Mlir-commits mailing list