[Mlir-commits] [mlir] d55d0b6 - [mlir][Linalg] Improve HoistPadding to propagate through iter_args

Nicolas Vasilache llvmlistbot at llvm.org
Wed Mar 1 05:22:26 PST 2023


Author: Nicolas Vasilache
Date: 2023-03-01T05:22:20-08:00
New Revision: d55d0b69e9b9cab7a11db6ceccfddf0f24a9a025

URL: https://github.com/llvm/llvm-project/commit/d55d0b69e9b9cab7a11db6ceccfddf0f24a9a025
DIFF: https://github.com/llvm/llvm-project/commit/d55d0b69e9b9cab7a11db6ceccfddf0f24a9a025.diff

LOG: [mlir][Linalg] Improve HoistPadding to propagate through iter_args

This revision properly plumbs the substitution of a padded op through
iter_args in the case of an scf::ForOp consumer.
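
In IR terms, the effect is roughly the following (a simplified sketch
distilled from the updated pad_and_hoist_init test below; SSA names and
the elided "..." operands are illustrative, not verbatim). Before, the
inner loop carried a dynamically shaped tensor, and the matmul kept
reading the loop-invariant %padded rather than an iter_arg:

  scf.for %i = ... -> (tensor<24x25xf32>) {
    %padded = tensor.pad ... : tensor<?x25xf32> to tensor<5x25xf32>
    %r = scf.for %j = ... -> (tensor<?x25xf32>) {
      %res = linalg.matmul ... outs(%padded : tensor<5x25xf32>) ...
      %extracted = tensor.extract_slice %res[0, 0] [%sz, 25] [1, 1]
          : tensor<5x25xf32> to tensor<?x25xf32>
      scf.yield %extracted : tensor<?x25xf32>
    }
    ...
  }

After this revision, %padded is threaded through the inner loop's
iter_args at its static type, and a single tensor.cast after the loop
recovers the dynamic type for the enclosing insert_slice:

  scf.for %i = ... -> (tensor<24x25xf32>) {
    %padded = tensor.pad ... : tensor<?x25xf32> to tensor<5x25xf32>
    %r = scf.for %j = ... iter_args(%inner = %padded) -> (tensor<5x25xf32>) {
      %res = linalg.matmul ... outs(%inner : tensor<5x25xf32>) ...
      scf.yield %res : tensor<5x25xf32>
    }
    %cast = tensor.cast %r : tensor<5x25xf32> to tensor<?x25xf32>
    tensor.insert_slice %cast into %acc[%i, 0] [%sz, 25] [1, 1]
        : tensor<?x25xf32> into tensor<24x25xf32>
  }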

Differential Revision: https://reviews.llvm.org/D145036

Added: 
    

Modified: 
    mlir/lib/Dialect/Linalg/Transforms/HoistPadding.cpp
    mlir/test/Dialect/Linalg/transform-op-hoist-pad.mlir

Removed: 
    


################################################################################
diff --git a/mlir/lib/Dialect/Linalg/Transforms/HoistPadding.cpp b/mlir/lib/Dialect/Linalg/Transforms/HoistPadding.cpp
index dcea3a12208eb..6b56565eaf130 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/HoistPadding.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/HoistPadding.cpp
@@ -100,6 +100,9 @@ struct HoistingAnalysis {
   /// The ExtractSliceOp that feeds the PadOp we want to hoist.
   tensor::ExtractSliceOp sliceOp;
 
+  /// If non-null, this is the unique scf::ForOp that consumes the `sliceOp`.
+  scf::ForOp padConsumingForOp;
+
 private:
   /// Drop any non-index dependencies of `padOp` and `sliceOp` from
   /// `backwardSlice`. The method follows the use-def chains of the index
@@ -224,9 +227,12 @@ HoistingAnalysis::HoistingAnalysis(tensor::PadOp padOp, int numLoops) {
     LLVM_DEBUG(DBGS() << "----Source not defined outside of loops -> Skip\n");
     return;
   }
+  if (sliceOp->hasOneUse()) {
+    padConsumingForOp = dyn_cast<scf::ForOp>(*(sliceOp->getUsers().begin()));
+  }
 
-  // Check the region of `padOp` depends on a constant only. Adding
-  // hoisting support for arbitrary padding regions would require cloning all
+  // Check that the region of `padOp` depends on a constant only. Adding
+  // hoisting support for arbitrary padding regions would require cloning all
   // dependencies captured by the padding region.
   Value paddingValue = padOp.getConstantPaddingValue();
   if (!paddingValue ||
@@ -259,6 +265,13 @@ HoistingAnalysis::HoistingAnalysis(tensor::PadOp padOp, int numLoops) {
     if (backwardSlice.contains(forOp))
       packingLoops.push_back(forOp);
 
+  // TODO: for multiple loops we need to track the use to the innermost loop.
+  if (packingLoops.size() > 1 && padConsumingForOp) {
+    LLVM_DEBUG(DBGS() << "--Cannot hoist multiple loops through iter_args -> "
+                         "Downgrade to 1 loop\n");
+    packingLoops.resize(1);
+  }
+
   // Note: at this point, packing loops may be empty but we would still like
   // to hoist the padding if so specified.
 
@@ -512,18 +525,21 @@ static PackingLoopNestResult buildPackingLoopNest(
     paddedTensor = maybeTransposeOp.getResult(0);
   }
 
-  // Step 4. Create InsertSliceOp at the innermost loop level, inserting an
-  // optionally transposed padded slice into the packed tensor.
-  Value inserted = rewriter.create<tensor::InsertSliceOp>(
-      loc, paddedTensor, packedTensor, offsets, sizes, strides);
-
-  // Step 5. Iteratively pop the stack and propagate the yield.
-  Value valueToYield = inserted;
-  for (Value iv : llvm::reverse(clonedLoopIvs)) {
-    auto forOp = scf::getForInductionVarOwner(iv);
-    rewriter.setInsertionPointToEnd(&forOp.getRegion().front());
-    rewriter.create<scf::YieldOp>(loc, valueToYield);
-    valueToYield = forOp.getResult(0);
+  // The innermost tensor.insert_slice and the propagated yields are only
+  // needed when there is at least one packing loop.
+  if (nPackedLoops > 0) {
+    // Step 4. Create InsertSliceOp at the innermost loop level, inserting an
+    // optionally transposed padded slice into the packed tensor.
+    Value inserted = rewriter.create<tensor::InsertSliceOp>(
+        loc, paddedTensor, packedTensor, offsets, sizes, strides);
+
+    // Step 5. Iteratively pop the stack and propagate the yield.
+    Value valueToYield = inserted;
+    for (Value iv : llvm::reverse(clonedLoopIvs)) {
+      auto forOp = scf::getForInductionVarOwner(iv);
+      rewriter.setInsertionPointToEnd(&forOp.getRegion().front());
+      rewriter.create<scf::YieldOp>(loc, valueToYield);
+      valueToYield = forOp.getResult(0);
+    }
   }
 
   return PackingLoopNestResult{offsets,
@@ -534,6 +550,36 @@ static PackingLoopNestResult buildPackingLoopNest(
                                maybeTransposeOp};
 }
 
+/// If the original consumer of `sliceOp` was a `forOp` (i.e. through an iter
+/// arg), propagate the `packedTensor` value through the same iter arg.
+/// TODO: for multiple loops we need to track the use to the innermost loop.
+static Value padThroughLoopIterArg(RewriterBase &rewriter, Value packedTensor,
+                                   tensor::ExtractSliceOp sliceOp,
+                                   scf::ForOp forOp) {
+  OpOperand *pUse = nullptr;
+  for (OpOperand &use : sliceOp->getUses()) {
+    if (use.getOwner() == forOp) {
+      assert(!pUse && "Multiple slice uses in the for loop");
+      pUse = &use;
+    }
+  }
+  assert(pUse && "No slice use in the for loop");
+  OpBuilder::InsertionGuard g(rewriter);
+  rewriter.setInsertionPointAfter(packedTensor.getDefiningOp());
+  Value casted = rewriter.create<tensor::CastOp>(
+      packedTensor.getLoc(), pUse->get().getType(), packedTensor);
+
+  std::optional<unsigned> operandNumber =
+      forOp.getIterArgNumberForOpOperand(*pUse);
+  assert(operandNumber.has_value() && "expected a proper iter arg number");
+  SmallVector<Value> initArgs = forOp.getInitArgs();
+  initArgs[operandNumber.value()] = casted;
+  rewriter.startRootUpdate(forOp);
+  forOp.getInitArgsMutable().assign(initArgs);
+  rewriter.finalizeRootUpdate(forOp);
+  return forOp.getRegionIterArgForOpOperand(*pUse);
+}
+
 /// Produce a tensor extracted from the packingResult. This can be used as a
 /// replacement for `opToHoist` in callers.
 static Value replaceByPackingLoopNestResult(
@@ -588,8 +634,12 @@ static Value replaceByPackingLoopNestResult(
 
   LLVM_DEBUG(DBGS() << "packedTensor: " << packedTensor << "\n");
 
-  // TODO: atm we are missing the plumbing of packedTensor through the loop
-  // bbarg when required (i.e. when hoisting init tensors).
+  // If the consumer of `padOp` was a `forOp`, propagate through iter args.
+  scf::ForOp forOp = analysis.padConsumingForOp;
+  if (forOp) {
+    packedTensor =
+        padThroughLoopIterArg(rewriter, packedTensor, analysis.sliceOp, forOp);
+  }
 
   // offsets = [maybe_leading_ivs, 0 .. 0].
   // sizes = [1 .. 1, transposedShape] (defined above).

diff --git a/mlir/test/Dialect/Linalg/transform-op-hoist-pad.mlir b/mlir/test/Dialect/Linalg/transform-op-hoist-pad.mlir
index 18f3b641a7565..02e4698e0d0f5 100644
--- a/mlir/test/Dialect/Linalg/transform-op-hoist-pad.mlir
+++ b/mlir/test/Dialect/Linalg/transform-op-hoist-pad.mlir
@@ -161,14 +161,13 @@ func.func @pad_and_hoist_init(
   //      CHECK: scf.for %{{.*}} -> (tensor<24x25xf32>) {
   //      CHECK:   %[[PADDED:.*]] = tensor.pad %{{.*}} 
   //      CHECK:     : tensor<?x25xf32> to tensor<5x25xf32>
-  //      CHECK:   scf.for %{{.*}} -> (tensor<?x25xf32>) {
-  //      CHECK:     %[[RES:.*]] = linalg.matmul {{.*}} outs(%[[PADDED]] : tensor<5x25xf32>
-  //
-  // TODO: atm we are missing the plumbing of packedTensor through the loop bbarg
-  // when required (i.e. when hoisting init tensors).
-  //      CHECK:     %[[RES_EXTRACTED:.*]] = tensor.extract_slice %[[RES]][0, 0] [%{{.*}}, 25] [1, 1] 
-  // CHECK-SAME:       : tensor<5x25xf32> to tensor<?x25xf32>
-  //      CHECK:     scf.yield %[[RES_EXTRACTED]] : tensor<?x25xf32>
+  //      CHECK:   scf.for %{{.*}} iter_args(%[[INNER_PADDED:[0-9a-zA-Z]*]] = %[[PADDED]]) -> (tensor<5x25xf32>)
+  //      CHECK:     %[[RES:.*]] = linalg.matmul {{.*}} outs(%[[INNER_PADDED]]
+  // CHECK-SAME:       : tensor<5x25xf32>
+  //      CHECK:     scf.yield %[[RES]] : tensor<5x25xf32>
+  //      CHECK:   %[[CAST:.*]] = tensor.cast %{{.*}} : tensor<5x25xf32> to tensor<?x25xf32>
+  //      CHECK:   tensor.insert_slice %[[CAST]] into %{{.*}}[%{{.*}}, 0] [%{{.*}}, 25] [1, 1]
+  // CHECK-SAME:     : tensor<?x25xf32> into tensor<24x25xf32>
   %0 = linalg.matmul ins(%arg0, %arg1 : tensor<24x12xf32>, tensor<12x25xf32>) outs(%arg2 : tensor<24x25xf32>) -> tensor<24x25xf32>
   func.return %0 : tensor<24x25xf32>
 }
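
For context, the test drives the rewrite through the transform dialect.
A driver looks roughly like this (a sketch reconstructed from the
conventions of transform-op-hoist-pad.mlir, not copied from this commit;
the exact attributes and handle types should be checked against the test
file):

  transform.sequence failures(propagate) {
  ^bb1(%arg1: !pdl.operation):
    %matmul = transform.structured.match ops{["linalg.matmul"]} in %arg1
    // Pad the matmul, then take the tensor.pad feeding its output operand.
    %padded = transform.structured.pad %matmul {padding_dimensions = [0, 1, 2]}
    %pad = transform.get_producer_of_operand %padded[2]
    // Hoist the pad above one enclosing loop; the iter_arg plumbing added in
    // this commit kicks in when the slice feeding the pad is consumed by an
    // scf.for.
    transform.structured.hoist_pad %pad by 1 loops
  }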

