[Mlir-commits] [mlir] b99d0b3 - [mlir][TilingInterface] Update `PartialReductionOpInterface` to get it more in line with `TilingInterface`. (#95460)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Tue Jun 18 09:07:33 PDT 2024


Author: MaheshRavishankar
Date: 2024-06-18T09:07:29-07:00
New Revision: b99d0b34400176cb9183113b96b245400caaf8d8

URL: https://github.com/llvm/llvm-project/commit/b99d0b34400176cb9183113b96b245400caaf8d8
DIFF: https://github.com/llvm/llvm-project/commit/b99d0b34400176cb9183113b96b245400caaf8d8.diff

LOG: [mlir][TilingInterface] Update `PartialReductionOpInterface` to get it more in line with `TilingInterface`. (#95460)

The `TilingInterface` methods have return values that allow the
interface implementation to return multiple operations and to return
the tiled values explicitly. This avoids assuming that the interface
returns a single operation whose results are the expected tiled
values. Make `PartialReductionOpInterface::tileToPartialReduction`
return a `TilingResult` as well, for the same reason.

Similarly, make `PartialReductionOpInterface::mergeReductions` return
a `MergeResult` containing the generated operations and the values to
use as replacements.

This is purely a refactoring that prepares for deprecating
`linalg::tileReductionUsingForall` in favor of the
`scf::tileReductionUsingScf` method.

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
    mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h
    mlir/include/mlir/Interfaces/TilingInterface.h
    mlir/include/mlir/Interfaces/TilingInterface.td
    mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
    mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
    mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp
    mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
index 308ce92e35520..05e97befdec1f 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -873,9 +873,9 @@ tileToForallOpUsingTileSizes(RewriterBase &builder, TilingInterface op,
 /// Transformation information returned after reduction tiling.
 struct ForallReductionTilingResult {
   /// The partial reduction tiled op generated.
-  Operation *parallelTiledOp;
+  SmallVector<Operation *> parallelTiledOps;
   /// The final reduction operation merging all the partial reductions.
-  Operation *mergeOp;
+  SmallVector<Operation *> mergeOps;
   /// Initial values used for partial reductions.
   SmallVector<Value> initialValues;
   /// The `scf.forall` operation that iterate over the tiles.

diff  --git a/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h b/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h
index dac79111af3c9..6316f1d130d19 100644
--- a/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h
+++ b/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h
@@ -261,13 +261,15 @@ lowerToLoopsUsingSCFForOp(RewriterBase &rewriter, TilingInterface op);
 /// Transformation information returned after reduction tiling.
 struct SCFReductionTilingResult {
   /// The partial reduction tiled op generated.
-  Operation *parallelTiledOp;
+  SmallVector<Operation *> parallelTiledOps;
   /// The final reduction operation merging all the partial reductions.
-  Operation *mergeOp;
+  SmallVector<Operation *> mergeOps;
   /// Initial values used for reduction.
   SmallVector<Value> initialValues;
   /// The loop operations that iterate over the tiles.
   SmallVector<LoopLikeOpInterface> loops;
+  /// The replacements to use for the results of the tiled operation.
+  SmallVector<Value> replacements;
 };
 
 /// Method to tile a reduction and generate a parallel op within a serial loop.

diff  --git a/mlir/include/mlir/Interfaces/TilingInterface.h b/mlir/include/mlir/Interfaces/TilingInterface.h
index ca570490ccf5b..2f51496d1b110 100644
--- a/mlir/include/mlir/Interfaces/TilingInterface.h
+++ b/mlir/include/mlir/Interfaces/TilingInterface.h
@@ -33,6 +33,15 @@ struct TilingResult {
   SmallVector<Value> tiledValues;
 };
 
+/// Container for the result of merge operation of tiling.
+/// - `mergeOps` contains operations created during the merge.
+/// - `replacements` contains the values that represents the result of the
+/// merge. These are used as replacements for the original tiled operation.
+struct MergeResult {
+  SmallVector<Operation *> mergeOps;
+  SmallVector<Value> replacements;
+};
+
 } // namespace mlir
 
 /// Include the ODS generated interface header files.

diff  --git a/mlir/include/mlir/Interfaces/TilingInterface.td b/mlir/include/mlir/Interfaces/TilingInterface.td
index 8865aba3b4ef0..3f927865ccf67 100644
--- a/mlir/include/mlir/Interfaces/TilingInterface.td
+++ b/mlir/include/mlir/Interfaces/TilingInterface.td
@@ -360,7 +360,7 @@ def PartialReductionOpInterface : OpInterface<"PartialReductionOpInterface"> {
           less or equal to the tile size. This is meant to be used with
           `mergeReductions` method which will combine the partial reductions.
         }],
-        /*retType=*/"Operation*",
+        /*retType=*/"FailureOr<TilingResult>",
         /*methodName=*/"tileToPartialReduction",
         /*args=*/(ins
             "OpBuilder &":$b,
@@ -371,7 +371,7 @@ def PartialReductionOpInterface : OpInterface<"PartialReductionOpInterface"> {
             "ArrayRef<int>":$reductionDims),
         /*methodBody=*/"",
         /*defaultImplementation=*/[{
-          return nullptr;
+          return failure();
         }]
       >,
       InterfaceMethod<
@@ -380,7 +380,7 @@ def PartialReductionOpInterface : OpInterface<"PartialReductionOpInterface"> {
           tiled along the reduction dimensions. This will only apply the
           reduction the operation.
         }],
-        /*retType=*/"Operation*",
+        /*retType=*/"FailureOr<MergeResult>",
         /*methodName=*/"mergeReductions",
         /*args=*/(ins
             "OpBuilder &":$b,
@@ -389,7 +389,7 @@ def PartialReductionOpInterface : OpInterface<"PartialReductionOpInterface"> {
             "ArrayRef<int>":$reductionDim),
         /*methodBody=*/"",
         /*defaultImplementation=*/[{
-          return nullptr;
+          return failure();
         }]
       >
   ];

diff  --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index 9b3121774ab3a..2807b3ce42abd 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -2525,8 +2525,10 @@ DiagnosedSilenceableFailure transform::TileReductionUsingForOp::applyToOne(
     return emitDefaultSilenceableFailure(target);
   for (Value initValue : result->initialValues)
     results.push_back(initValue.getDefiningOp());
-  results.push_back(result->parallelTiledOp);
-  results.push_back(result->mergeOp);
+  for (auto parallelTiledOp : result->parallelTiledOps)
+    results.push_back(parallelTiledOp);
+  for (auto mergeOp : result->mergeOps)
+    results.push_back(mergeOp);
   results.push_back(result->loops.front());
   return DiagnosedSilenceableFailure::success();
 }
@@ -2577,8 +2579,10 @@ DiagnosedSilenceableFailure transform::TileReductionUsingForallOp::applyToOne(
   }
   for (Value initValue : result->initialValues)
     results.push_back(initValue.getDefiningOp());
-  results.push_back(result->parallelTiledOp);
-  results.push_back(result->mergeOp);
+  for (auto parallelTiledOp : result->parallelTiledOps)
+    results.push_back(parallelTiledOp);
+  for (auto mergeOp : result->mergeOps)
+    results.push_back(mergeOp);
   results.push_back(result->loops);
   return DiagnosedSilenceableFailure::success();
 }

diff  --git a/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp b/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
index a0a0e11a6903d..d8dee82237156 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
@@ -833,16 +833,19 @@ FailureOr<linalg::ForallReductionTilingResult> linalg::tileReductionUsingForall(
 
   // 7. Merge the partial reductions.
   b.setInsertionPointAfter(forallOp);
-  Operation *mergeOp =
+  FailureOr<MergeResult> mergeResult =
       op.mergeReductions(b, loc, forallOp->getResults(), reductionDim);
-  b.replaceOp(op, mergeOp->getResults());
+  if (failed(mergeResult)) {
+    return failure();
+  }
+  b.replaceOp(op, mergeResult->replacements);
 
   // 8. Return.
   ForallReductionTilingResult results;
   results.initialValues = initTensors;
   results.loops = forallOp;
-  results.parallelTiledOp = tiledOp;
-  results.mergeOp = mergeOp;
+  results.parallelTiledOps.push_back(tiledOp);
+  results.mergeOps.append(mergeResult->mergeOps);
   return results;
 }
 

diff  --git a/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp
index c3ab3cecfada7..b2a1e7c71f58e 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp
@@ -368,11 +368,11 @@ struct LinalgOpPartialReductionInterface
     return inits;
   }
 
-  Operation *tileToPartialReduction(Operation *op, OpBuilder &b, Location loc,
-                                    ValueRange init,
-                                    ArrayRef<OpFoldResult> offsets,
-                                    ArrayRef<OpFoldResult> sizes,
-                                    ArrayRef<int> reductionDims) const {
+  FailureOr<TilingResult>
+  tileToPartialReduction(Operation *op, OpBuilder &b, Location loc,
+                         ValueRange init, ArrayRef<OpFoldResult> offsets,
+                         ArrayRef<OpFoldResult> sizes,
+                         ArrayRef<int> reductionDims) const {
     OpBuilder::InsertionGuard guard(b);
     auto linalgOp = cast<LinalgOp>(op);
 
@@ -437,12 +437,15 @@ struct LinalgOpPartialReductionInterface
     IRMapping mapping;
     op->getRegion(0).cloneInto(&genericOp.getRegion(),
                                genericOp.getRegion().begin(), mapping);
-    return genericOp.getOperation();
+    return TilingResult{
+        {genericOp.getOperation()},
+        llvm::map_to_vector(genericOp->getResults(),
+                            [](OpResult r) -> Value { return r; })};
   }
 
-  Operation *mergeReductions(Operation *op, OpBuilder &b, Location loc,
-                             ValueRange partialReduce,
-                             ArrayRef<int> reductionDims) const {
+  FailureOr<MergeResult> mergeReductions(Operation *op, OpBuilder &b,
+                                         Location loc, ValueRange partialReduce,
+                                         ArrayRef<int> reductionDims) const {
     auto linalgOp = cast<LinalgOp>(op);
 
     // Step 1. Recover the dims that actually need to be merged from the
@@ -493,7 +496,10 @@ struct LinalgOpPartialReductionInterface
           }
           b.create<linalg::YieldOp>(loc, yieldedValues);
         });
-    return reduction.getOperation();
+    return MergeResult{
+        {reduction.getOperation()},
+        llvm::map_to_vector(reduction->getResults(),
+                            [](OpResult r) -> Value { return r; })};
   }
 };
 

diff  --git a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
index f3d6b7a530117..35edd490f72eb 100644
--- a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
@@ -718,7 +718,7 @@ mlir::scf::tileReductionUsingScf(RewriterBase &b,
   SmallVector<Value> &initTensors = maybeInitTensors.value();
 
   // 3. Define the callback to use for generating the inner most tile loop body.
-  Operation *parallelOp = nullptr;
+  SmallVector<Operation *> parallelTiledOps;
   auto innerYieldTiledValuesFn =
       [&](RewriterBase &rewriter, Location loc, ValueRange ivs,
           ValueRange regionIterArgs, SmallVector<Value> &tiledResult,
@@ -743,26 +743,33 @@ mlir::scf::tileReductionUsingScf(RewriterBase &b,
     }
 
     // 4a. Clone the operation.
-    auto clonedOp = cast<PartialReductionOpInterface>(
-        cloneOpAndUpdateDestinationArgs(b, op, regionIterArgs));
+    {
+      auto clonedOp = cast<PartialReductionOpInterface>(
+          cloneOpAndUpdateDestinationArgs(b, op, regionIterArgs));
+
+      // 4b. Tile the cloned operation.
+      FailureOr<TilingResult> partialTilingResult =
+          clonedOp.tileToPartialReduction(b, loc, regionIterArgs, offsets,
+                                          sizes, reductionDims);
+      if (failed(partialTilingResult)) {
+        return failure();
+      }
+      std::swap(parallelTiledOps, partialTilingResult->tiledOps);
+      std::swap(tiledResult, partialTilingResult->tiledValues);
 
-    // 4b. Tile the cloned operation.
-    parallelOp = clonedOp.tileToPartialReduction(b, loc, regionIterArgs,
-                                                 offsets, sizes, reductionDims);
-    // 4c. Delete the cloned operation.
-    b.eraseOp(clonedOp);
+      // 4c. Delete the cloned operation.
+      b.eraseOp(clonedOp);
+    }
 
-    tiledResult.append(parallelOp->result_begin(), parallelOp->result_end());
     // 4d. Compute the offsets and sizes needed to insert the result of the
     // tiled value back into destination before yielding the destination.
-    for (int resultIdx : llvm::seq<int>(0, parallelOp->getNumResults())) {
+    for (auto result : tiledResult) {
       SmallVector<OpFoldResult> outOffsets(offsets.size(), b.getIndexAttr(0));
       resultOffsets.emplace_back(std::move(outOffsets));
 
       SmallVector<OpFoldResult> outSizes;
       for (size_t i = 0; i < offsets.size(); i++) {
-        outSizes.push_back(
-            tensor::getMixedSize(b, loc, parallelOp->getResult(resultIdx), i));
+        outSizes.push_back(tensor::getMixedSize(b, loc, result, i));
       }
       resultSizes.emplace_back(std::move(outSizes));
     }
@@ -782,15 +789,21 @@ mlir::scf::tileReductionUsingScf(RewriterBase &b,
 
   // 5. Apply the merge reduction to combine all the partial values.
   b.setInsertionPointAfter(*loops.begin());
-  Operation *mergeOp = op.mergeReductions(b, loc, replacements, reductionDims);
-  b.replaceOp(op, mergeOp->getResults());
-
-  SCFReductionTilingResult results;
-  results.initialValues = initTensors;
-  results.loops = loops;
-  results.parallelTiledOp = parallelOp;
-  results.mergeOp = mergeOp;
-  return results;
+  FailureOr<MergeResult> mergeResult =
+      op.mergeReductions(b, loc, replacements, reductionDims);
+  if (failed(mergeResult)) {
+    return failure();
+  }
+  b.replaceOp(op, mergeResult->replacements);
+
+  SCFReductionTilingResult reductionTilingResult;
+  std::swap(reductionTilingResult.parallelTiledOps, parallelTiledOps);
+  std::swap(reductionTilingResult.mergeOps, mergeResult->mergeOps);
+  std::swap(reductionTilingResult.initialValues, initTensors);
+  std::swap(reductionTilingResult.loops, loops);
+  std::swap(reductionTilingResult.replacements, mergeResult->replacements);
+
+  return reductionTilingResult;
 }
 
 //===----------------------------------------------------------------------===//


        


More information about the Mlir-commits mailing list