[Mlir-commits] [mlir] e83d846 - [mlir][linalg] Adapt hoistPaddingOnTensors signature to support patterns (NFC).

Tobias Gysi llvmlistbot at llvm.org
Thu Oct 28 23:52:24 PDT 2021


Author: Tobias Gysi
Date: 2021-10-29T06:51:38Z
New Revision: e83d8466fbd92f48186e5544295f4ea32f6f0e59

URL: https://github.com/llvm/llvm-project/commit/e83d8466fbd92f48186e5544295f4ea32f6f0e59
DIFF: https://github.com/llvm/llvm-project/commit/e83d8466fbd92f48186e5544295f4ea32f6f0e59.diff

LOG: [mlir][linalg] Adapt hoistPaddingOnTensors signature to support patterns (NFC).

Adapt hoistPaddingOnTensors to leave replacing and erasing the old pad tensor operation to the caller. This change makes the function pattern-friendly.
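
For illustration only (not part of this commit), here is a minimal sketch of how
a rewrite pattern could drive the new signature; the pattern name, the fixed
numLoops value, and the surrounding includes/usings are assumptions:

  // Hypothetical sketch: drive hoistPaddingOnTensors from a rewrite pattern.
  // Assumes `using namespace mlir;` plus the Linalg dialect and
  // mlir/IR/PatternMatch.h headers.
  struct HoistPadTensorOpSketch : OpRewritePattern<linalg::PadTensorOp> {
    using OpRewritePattern<linalg::PadTensorOp>::OpRewritePattern;

    LogicalResult matchAndRewrite(linalg::PadTensorOp padTensorOp,
                                  PatternRewriter &rewriter) const override {
      linalg::PadTensorOp hoistedOp;
      // numLoops = 2 is an arbitrary illustrative choice.
      FailureOr<Value> newResult = linalg::hoistPaddingOnTensors(
          padTensorOp, /*numLoops=*/2, hoistedOp);
      if (failed(newResult))
        return failure();
      // Replacing and erasing the original pad op is now the caller's job.
      rewriter.replaceOp(padTensorOp, newResult.getValue());
      return success();
    }
  };

Because the function no longer replaces and erases the original op behind the
caller's back, the replacement can be routed through a PatternRewriter, which is
what makes the new signature usable from patterns.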

Depends On D112003

Reviewed By: nicolasvasilache

Differential Revision: https://reviews.llvm.org/D112255

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Linalg/Transforms/HoistPadding.h
    mlir/lib/Dialect/Linalg/Transforms/HoistPadding.cpp
    mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp

Removed: 
    


################################################################################
diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/HoistPadding.h b/mlir/include/mlir/Dialect/Linalg/Transforms/HoistPadding.h
index a21b35a7eb30e..61fc1bfd5e404 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/HoistPadding.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/HoistPadding.h
@@ -9,8 +9,10 @@
 #ifndef MLIR_DIALECT_LINALG_TRANSFORMS_HOIST_PADDING_H_
 #define MLIR_DIALECT_LINALG_TRANSFORMS_HOIST_PADDING_H_
 
+#include "mlir/Support/LogicalResult.h"
+
 namespace mlir {
-struct LogicalResult;
+class Value;
 
 namespace linalg {
 class PadTensorOp;
@@ -57,7 +59,8 @@ class PadTensorOp;
 ///      }
 ///    }
 /// ```
-LogicalResult hoistPaddingOnTensors(PadTensorOp &padTensorOp, int nLoops);
+FailureOr<Value> hoistPaddingOnTensors(PadTensorOp opToHoist, int numLoops,
+                                       PadTensorOp &hoistedOp);
 
 } // namespace linalg
 } // namespace mlir

diff --git a/mlir/lib/Dialect/Linalg/Transforms/HoistPadding.cpp b/mlir/lib/Dialect/Linalg/Transforms/HoistPadding.cpp
index a08ac26b43cef..5346e236b1672 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/HoistPadding.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/HoistPadding.cpp
@@ -355,11 +355,12 @@ static Value buildLoopIterationCount(OpBuilder &b, scf::ForOp outer,
                                        ValueRange{ivVal, lbVal, stepVal});
 }
 
-LogicalResult mlir::linalg::hoistPaddingOnTensors(PadTensorOp &padTensorOp,
-                                                  int nLoops) {
-  LLVM_DEBUG(DBGS() << "Try to hoist " << *(padTensorOp) << " by " << nLoops
+FailureOr<Value> mlir::linalg::hoistPaddingOnTensors(PadTensorOp opToHoist,
+                                                     int numLoops,
+                                                     PadTensorOp &hoistedOp) {
+  LLVM_DEBUG(DBGS() << "Try to hoist " << *(opToHoist) << " by " << numLoops
                     << " loops\n");
-  HoistingAnalysis analysis(padTensorOp, nLoops);
+  HoistingAnalysis analysis(opToHoist, numLoops);
   if (!analysis.isValid()) {
     LLVM_DEBUG(DBGS() << "Analysis failed -> Skip\n");
     return failure();
@@ -376,8 +377,8 @@ LogicalResult mlir::linalg::hoistPaddingOnTensors(PadTensorOp &padTensorOp,
   // Update actual number of loops, which may be smaller.
   int nPackedLoops = analysis.packingLoops.size();
 
-  Location loc = padTensorOp->getLoc();
-  RankedTensorType paddedTensorType = padTensorOp.getResultType();
+  Location loc = opToHoist->getLoc();
+  RankedTensorType paddedTensorType = opToHoist.getResultType();
   int paddedRank = paddedTensorType.getRank();
 
   // Create the packed tensor<?x?x..?xpadded_shape> into which we amortize
@@ -404,8 +405,8 @@ LogicalResult mlir::linalg::hoistPaddingOnTensors(PadTensorOp &padTensorOp,
   clonedLoopIvs.reserve(nPackedLoops);
   leadingPackedTensorIndexings.reserve(nPackedLoops);
   BlockAndValueMapping bvm;
-  // Insert `padTensorOp` into the backwardSlice so we clone it too.
-  analysis.backwardSlice.insert(padTensorOp);
+  // Insert `opToHoist` into the backwardSlice so we clone it too.
+  analysis.backwardSlice.insert(opToHoist);
   // Stack step 1. iteratively clone loops and push `packedTensor`.
   for (Operation *op : analysis.backwardSlice) {
     // Specifically sit out in the extract_slice(packedTensor) case: this is the
@@ -466,7 +467,7 @@ LogicalResult mlir::linalg::hoistPaddingOnTensors(PadTensorOp &padTensorOp,
                                     b.getIndexAttr(1));
 
   Value inserted =
-      b.create<tensor::InsertSliceOp>(loc, bvm.lookup(padTensorOp.result()),
+      b.create<tensor::InsertSliceOp>(loc, bvm.lookup(opToHoist.result()),
                                       packedTensor, offsets, sizes, strides);
 
   // Stack step 3. iteratively pop the stack and propagate the yield.
@@ -480,7 +481,7 @@ LogicalResult mlir::linalg::hoistPaddingOnTensors(PadTensorOp &padTensorOp,
 
   // Now the packed tensor is ready, replace the original padding op by a
   // 1x..x1 slice [originalLoopIvs, 0 .. 0][1 .. 1, paddedShape][1 .. 1].
-  b.setInsertionPoint(padTensorOp);
+  b.setInsertionPoint(opToHoist);
   SmallVector<Value> loopIterationCounts = llvm::to_vector<4>(
       llvm::map_range(analysis.packingLoops, [&](Operation *loop) {
         return buildLoopIterationCount(b, outer, cast<scf::ForOp>(loop));
@@ -495,18 +496,10 @@ LogicalResult mlir::linalg::hoistPaddingOnTensors(PadTensorOp &padTensorOp,
   // strides = [1 .. 1] (defined above)
   packedTensor =
       scf::getForInductionVarOwner(clonedLoopIvs.front())->getResult(0);
-  padTensorOp.replaceAllUsesWith(
-      b.create<tensor::ExtractSliceOp>(loc, padTensorOp.getResultType(),
-                                       packedTensor, offsets, sizes, strides)
-          ->getResult(0));
+  Value newResult = b.create<tensor::ExtractSliceOp>(
+      loc, opToHoist.getResultType(), packedTensor, offsets, sizes, strides);
 
-  Operation *toErase = padTensorOp;
-
-  // Make the newly cloned `padTensorOp` available to the caller.
-  padTensorOp =
-      cast<PadTensorOp>(bvm.lookup(padTensorOp.result()).getDefiningOp());
-
-  toErase->erase();
-
-  return success();
+  // Make the newly cloned `opToHoist` available to the caller.
+  hoistedOp = cast<PadTensorOp>(bvm.lookup(opToHoist.result()).getDefiningOp());
+  return newResult;
 }

diff --git a/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp b/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp
index 1e8620aee75ec..78e76ca9ea311 100644
--- a/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp
+++ b/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp
@@ -771,7 +771,13 @@ void TestLinalgTransforms::runOnFunction() {
                             /*peeledLoops=*/{}, /*scalarizeDynamicDims=*/true);
   if (testHoistPadding) {
     getFunction().walk([&](linalg::PadTensorOp padTensorOp) {
-      (void)linalg::hoistPaddingOnTensors(padTensorOp, testHoistPadding);
+      PadTensorOp hoistedOp;
+      FailureOr<Value> newResult = linalg::hoistPaddingOnTensors(
+          padTensorOp, testHoistPadding, hoistedOp);
+      if (succeeded(newResult)) {
+        padTensorOp.getResult().replaceAllUsesWith(newResult.getValue());
+        padTensorOp->erase();
+      }
     });
   }
   if (testInterchangePattern.hasValue())


        

