[Mlir-commits] [mlir] Extend `TilingInterface` to allow more flexible tiling (PR #95422)

Srinath Avadhanula llvmlistbot at llvm.org
Fri Jun 14 04:49:24 PDT 2024


https://github.com/srinathava updated https://github.com/llvm/llvm-project/pull/95422

>From 18ddecd8ab7738d44448beb3aa81b9db4f4cd6f2 Mon Sep 17 00:00:00 2001
From: Srinath Avadhanula <srinath.avadhanula at getcruise.com>
Date: Thu, 13 Jun 2024 08:29:31 -0700
Subject: [PATCH 1/2] initial commit

---
 .../SCF/Transforms/TileUsingInterface.h       |  2 +
 .../include/mlir/Interfaces/TilingInterface.h |  4 ++
 mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp      |  8 +++-
 .../Linalg/Transforms/TilingInterfaceImpl.cpp | 11 +++++-
 .../SCF/Transforms/TileUsingInterface.cpp     | 37 ++++++++++---------
 5 files changed, 41 insertions(+), 21 deletions(-)

diff --git a/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h b/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h
index dac79111af3c9..fecd33193eb0d 100644
--- a/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h
+++ b/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h
@@ -85,6 +85,7 @@ struct SCFTilingResult {
   /// Values to use as replacements for the untiled op. Is the same size as the
   /// number of results of the untiled op.
   SmallVector<Value> replacements;
+  SmallVector<Operation *> extractSliceOps;
 };
 
 /// Method to tile an op that implements the `TilingInterface` using
@@ -135,6 +136,7 @@ struct SCFFuseProducerOfSliceResult {
   OpResult origProducer;       // Original untiled producer.
   Value tiledAndFusedProducer; // Tile and fused producer value.
   SmallVector<Operation *> tiledOps;
+  SmallVector<Operation *> extractSliceOps;
 };
 std::optional<SCFFuseProducerOfSliceResult>
 tileAndFuseProducerOfSlice(RewriterBase &rewriter,
diff --git a/mlir/include/mlir/Interfaces/TilingInterface.h b/mlir/include/mlir/Interfaces/TilingInterface.h
index ca570490ccf5b..e5ed016d53fc1 100644
--- a/mlir/include/mlir/Interfaces/TilingInterface.h
+++ b/mlir/include/mlir/Interfaces/TilingInterface.h
@@ -28,9 +28,13 @@ namespace mlir {
 /// are returned to the caller for further transformations.
 /// - `tiledValues` contains the tiled value corresponding to the result of the
 /// untiled operation.
+/// - `extractSliceOps` contains all the `tensor.extract_slice` ops used in
+/// generating the `tiledOps`. Usually they define operands of the
+/// `tiledOps`, but they may also be nested in regions owned by the `tiledOps`.
 struct TilingResult {
   SmallVector<Operation *> tiledOps;
   SmallVector<Value> tiledValues;
+  SmallVector<Operation *> extractSliceOps;
 };
 
 } // namespace mlir
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index b79afebfa8158..5198e0bceaa6e 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -2501,7 +2501,13 @@ SoftmaxOp::getTiledImplementation(OpBuilder &builder,
   Operation *tiledOp =
       mlir::clone(builder, getOperation(), resultTypes, tiledOperands);
 
-  return TilingResult{{tiledOp}, SmallVector<Value>(tiledOp->getResults())};
+  SmallVector<Operation *> sliceOps;
+  for (Value operand : tiledOperands)
+    if (auto sliceOp = operand.getDefiningOp<tensor::ExtractSliceOp>())
+      sliceOps.push_back(sliceOp);
+
+  return TilingResult{
+      {tiledOp}, SmallVector<Value>(tiledOp->getResults()), sliceOps};
 }
 
 LogicalResult SoftmaxOp::getResultTilePosition(
diff --git a/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp
index c3ab3cecfada7..f25ccc38ba0a3 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp
@@ -129,7 +129,13 @@ struct LinalgOpTilingInterface
     Operation *tiledOp = clone(b, linalgOp, resultTensorTypes, tiledOperands);
     offsetIndices(b, cast<LinalgOp>(tiledOp), offsets);
 
-    return TilingResult{{tiledOp}, SmallVector<Value>(tiledOp->getResults())};
+    SmallVector<Operation *> sliceOps;
+    for (Value operand : tiledOperands)
+      if (auto sliceOp = operand.getDefiningOp<tensor::ExtractSliceOp>())
+        sliceOps.push_back(sliceOp);
+
+    return TilingResult{
+        {tiledOp}, SmallVector<Value>(tiledOp->getResults()), sliceOps};
   }
 
   /// Utility to fetch the offsets and sizes when applied as per the indexing
@@ -247,7 +253,8 @@ struct LinalgOpTilingInterface
 
     return TilingResult{
         tilingResult->tiledOps,
-        SmallVector<Value>{tilingResult->tiledValues[resultNumber]}};
+        SmallVector<Value>{tilingResult->tiledValues[resultNumber]},
+        tilingResult->extractSliceOps};
   }
 
   /// Method to generate the tiled implementation of an operation from the tile
diff --git a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
index f3d6b7a530117..fb3ec2a5fa0a8 100644
--- a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
@@ -619,7 +619,8 @@ mlir::scf::tileUsingSCF(RewriterBase &rewriter, TilingInterface op,
     if (llvm::all_of(tileSizes, isZeroIndex)) {
       tiledResults.append(clonedOp->result_begin(), clonedOp->result_end());
       tilingResult =
-          TilingResult{/*tiledOps=*/{clonedOp}, clonedOp->getResults()};
+          TilingResult{/*tiledOps=*/{clonedOp}, clonedOp->getResults(),
+                       /*extractSliceOps=*/{}};
       return success();
     }
 
@@ -675,12 +676,14 @@ mlir::scf::tileUsingSCF(RewriterBase &rewriter, TilingInterface op,
   // op.
   if (loops.empty()) {
     return scf::SCFTilingResult{tilingResult->tiledOps, loops,
-                                tilingResult->tiledValues};
+                                tilingResult->tiledValues,
+                                tilingResult->extractSliceOps};
   }
 
   SmallVector<Value> replacements = llvm::map_to_vector(
       loops.front()->getResults(), [](OpResult r) -> Value { return r; });
-  return scf::SCFTilingResult{tilingResult->tiledOps, loops, replacements};
+  return scf::SCFTilingResult{tilingResult->tiledOps, loops, replacements,
+                              tilingResult->extractSliceOps};
 }
 
 FailureOr<scf::SCFReductionTilingResult>
@@ -931,9 +934,9 @@ mlir::scf::tileAndFuseProducerOfSlice(
         ->getOpOperands()[destinationInitArg.value()->getOperandNumber()]
         .set(origDestinationTensors[resultNumber]);
   }
-  return scf::SCFFuseProducerOfSliceResult{fusableProducer,
-                                           tileAndFuseResult->tiledValues[0],
-                                           tileAndFuseResult->tiledOps};
+  return scf::SCFFuseProducerOfSliceResult{
+      fusableProducer, tileAndFuseResult->tiledValues[0],
+      tileAndFuseResult->tiledOps, tileAndFuseResult->extractSliceOps};
 }
 
 /// Reconstruct the fused producer from within the tiled-and-fused code.
@@ -962,13 +965,12 @@ LogicalResult mlir::scf::yieldReplacementForFusedProducer(
                   .getDefiningOp<DestinationStyleOpInterface>()) {
         rewriter.setInsertionPoint(tiledDestStyleOp);
         Value newRegionArg = newRegionIterArgs.back();
-        auto destSlice = rewriter.create<tensor::ExtractSliceOp>(
-            sliceOp.getLoc(), newRegionArg, sliceOp.getMixedOffsets(),
-            sliceOp.getMixedSizes(), sliceOp.getMixedStrides());
         unsigned resultNumber = fusableProducer.getResultNumber();
-        rewriter.modifyOpInPlace(tiledDestStyleOp, [&]() {
-          tiledDestStyleOp.getDpsInitsMutable()[resultNumber].set(destSlice);
-        });
+        auto origSlice = tiledDestStyleOp.getDpsInits()[resultNumber]
+                             .getDefiningOp<tensor::ExtractSliceOp>();
+        if (origSlice) {
+          origSlice.getSourceMutable().set(newRegionArg);
+        }
       }
       Block *block = rewriter.getInsertionPoint()->getBlock();
       rewriter.setInsertionPoint(block->getTerminator());
@@ -1036,15 +1038,14 @@ mlir::scf::tileConsumerAndFuseProducersUsingSCF(
   //    operations. If the producers of the source of the `tensor.extract_slice`
   //    can be tiled such that the tiled value is generated in-place, that
   //    effectively tiles + fuses the operations.
-  auto addCandidateSlices = [](Operation *fusedOp,
+  auto addCandidateSlices = [](const SmallVector<Operation *> &newSliceOps,
                                std::deque<tensor::ExtractSliceOp> &candidates) {
-    for (Value operand : fusedOp->getOperands())
-      if (auto sliceOp = operand.getDefiningOp<tensor::ExtractSliceOp>())
-        candidates.push_back(sliceOp);
+    for (auto *op : newSliceOps)
+      candidates.push_back(llvm::cast<tensor::ExtractSliceOp>(op));
   };
 
   std::deque<tensor::ExtractSliceOp> candidates;
-  addCandidateSlices(tiledAndFusedOps.back(), candidates);
+  addCandidateSlices(tilingResult->extractSliceOps, candidates);
   OpBuilder::InsertionGuard g(rewriter);
   while (!candidates.empty()) {
     // Traverse the slices in BFS fashion.
@@ -1086,7 +1087,7 @@ mlir::scf::tileConsumerAndFuseProducersUsingSCF(
             fusedResult->tiledAndFusedProducer.getDefiningOp()) {
       fusedProducers.insert(fusedResult->origProducer.getDefiningOp());
       tiledAndFusedOps.insert(tiledAndFusedOp);
-      addCandidateSlices(tiledAndFusedOp, candidates);
+      addCandidateSlices(fusedResult->extractSliceOps, candidates);
     }
   }
 

>From af5f7a5b21af2137da0598b41b7c8c032b89a264 Mon Sep 17 00:00:00 2001
From: Srinath Avadhanula <srinath.avadhanula at getcruise.com>
Date: Fri, 14 Jun 2024 04:48:24 -0700
Subject: [PATCH 2/2] also add extractSliceOps to TensorTilingInterfaceImpl

---
 .../Tensor/IR/TensorTilingInterfaceImpl.cpp   | 36 +++++++++++++++----
 1 file changed, 29 insertions(+), 7 deletions(-)

diff --git a/mlir/lib/Dialect/Tensor/IR/TensorTilingInterfaceImpl.cpp b/mlir/lib/Dialect/Tensor/IR/TensorTilingInterfaceImpl.cpp
index 9b2a97eb2b006..33db5a5f043f3 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorTilingInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorTilingInterfaceImpl.cpp
@@ -99,6 +99,16 @@ static void applyPermToRange(SmallVector<OpFoldResult> &offsets,
   applyPermutationToVector<OpFoldResult>(sizes, permutation);
 }
 
+/// Returns the `tensor.extract_slice` ops that directly define an operand of
+/// `op`, in operand order; a slice feeding several operands appears each time.
+static SmallVector<Operation *> sliceOperandsOf(Operation *op) {
+  SmallVector<Operation *> slices;
+  for (Value operand : op->getOperands())
+    if (auto slice = operand.getDefiningOp<tensor::ExtractSliceOp>())
+      slices.push_back(slice);
+  return slices;
+}
+
 struct PackOpTiling
     : public TilingInterface::ExternalModel<PackOpTiling, PackOp> {
 
@@ -192,7 +202,8 @@ struct PackOpTiling
         loc, TypeRange{extractSlice.getType()}, tiledOperands, op->getAttrs());
 
     return TilingResult{{tiledPackOp},
-                        SmallVector<Value>(tiledPackOp->getResults())};
+                        SmallVector<Value>(tiledPackOp->getResults()),
+                        sliceOperandsOf(tiledPackOp)};
   }
 
   LogicalResult
@@ -440,12 +451,16 @@ struct UnPackOpTiling
 
     if (isPerfectTilingCase)
       return TilingResult{{tiledUnpackOp},
-                          SmallVector<Value>(tiledUnpackOp->getResults())};
+                          SmallVector<Value>(tiledUnpackOp->getResults()),
+                          sliceOperandsOf(tiledUnpackOp)};
 
     auto extractSlice =
         b.create<ExtractSliceOp>(loc, tiledUnpackOp->getResult(0),
                                  resultOffsetsFromDest, sizes, destStrides);
-    return TilingResult{{tiledUnpackOp}, {extractSlice.getResult()}};
+
+    return TilingResult{{tiledUnpackOp},
+                        {extractSlice.getResult()},
+                        sliceOperandsOf(tiledUnpackOp)};
   }
 
   LogicalResult
@@ -567,7 +582,8 @@ struct UnPackOpTiling
                            tiledOperands, op->getAttrs());
 
     return TilingResult{{tiledUnPackOp},
-                        SmallVector<Value>(tiledUnPackOp->getResults())};
+                        SmallVector<Value>(tiledUnPackOp->getResults()),
+                        sliceOperandsOf(tiledUnPackOp)};
   }
 };
 
@@ -756,7 +772,9 @@ FailureOr<TilingResult> tensor::bubbleUpPadSlice(OpBuilder &b,
   // the original data source x is not used.
   if (hasZeroLen) {
     Operation *generateOp = createGenerateOp();
-    return TilingResult{{generateOp}, {castResult(generateOp->getResult(0))}};
+    return TilingResult{{generateOp},
+                        {castResult(generateOp->getResult(0))},
+                        /*extractSliceOps=*/{}};
   }
 
   // If there are dynamic dimensions: Generate an scf.if check to avoid
@@ -776,11 +794,15 @@ FailureOr<TilingResult> tensor::bubbleUpPadSlice(OpBuilder &b,
           elseOp = createPadOfExtractSlice();
           b.create<scf::YieldOp>(loc, castResult(elseOp->getResult(0)));
         });
-    return TilingResult{{elseOp}, SmallVector<Value>(result->getResults())};
+    return TilingResult{{elseOp},
+                        SmallVector<Value>(result->getResults()),
+                        sliceOperandsOf(elseOp)};
   }
 
   Operation *newPadOp = createPadOfExtractSlice();
-  return TilingResult{{newPadOp}, {castResult(newPadOp->getResult(0))}};
+  return TilingResult{{newPadOp},
+                      {castResult(newPadOp->getResult(0))},
+                      sliceOperandsOf(newPadOp)};
 }
 
 void mlir::tensor::registerTilingInterfaceExternalModels(



More information about the Mlir-commits mailing list