[Mlir-commits] [mlir] d55d0b6 - [mlir][Linalg] Improve HoistPadding to propagate through iter_args
Nicolas Vasilache
llvmlistbot at llvm.org
Wed Mar 1 05:22:26 PST 2023
Author: Nicolas Vasilache
Date: 2023-03-01T05:22:20-08:00
New Revision: d55d0b69e9b9cab7a11db6ceccfddf0f24a9a025
URL: https://github.com/llvm/llvm-project/commit/d55d0b69e9b9cab7a11db6ceccfddf0f24a9a025
DIFF: https://github.com/llvm/llvm-project/commit/d55d0b69e9b9cab7a11db6ceccfddf0f24a9a025.diff
LOG: [mlir][Linalg] Improve HoistPadding to propagate through iter_args
This revision properly plumbs the substitution of a padded op through
iter_args in the case of an scf::ForOp consumer.
Differential Revision: https://reviews.llvm.org/D145036
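For context, here is a rough, hypothetical sketch of the IR shape this change enables.
It is not taken from the commit; the function name, operand names, shapes and loop
bounds are illustrative only, and the authoritative pattern is in the updated CHECK
lines of transform-op-hoist-pad.mlir below. The key point is that the hoisted padded
tensor is threaded through the consuming scf.for's iter_args, and the result is cast
and inserted back into the accumulator after the inner loop:

// Illustrative sketch only; names and constants are hypothetical.
func.func @pad_hoisted_through_iter_args(
    %Atile: tensor<5x12xf32>, %B: tensor<12x25xf32>, %init: tensor<24x25xf32>)
    -> tensor<24x25xf32> {
  %c0 = arith.constant 0 : index
  %c5 = arith.constant 5 : index
  %c12 = arith.constant 12 : index
  %c24 = arith.constant 24 : index
  %cst = arith.constant 0.0 : f32
  %res = scf.for %i = %c0 to %c24 step %c5
      iter_args(%acc = %init) -> (tensor<24x25xf32>) {
    // Possibly partial tile of the accumulator: ?x25 with ? <= 5.
    %sz = affine.min affine_map<(d0) -> (5, -d0 + 24)>(%i)
    %slice = tensor.extract_slice %acc[%i, 0] [%sz, 25] [1, 1]
        : tensor<24x25xf32> to tensor<?x25xf32>
    // Pad the slice up to the static tile shape.
    %h = arith.subi %c5, %sz : index
    %padded = tensor.pad %slice low[0, 0] high[%h, 0] {
    ^bb0(%p0: index, %p1: index):
      tensor.yield %cst : f32
    } : tensor<?x25xf32> to tensor<5x25xf32>
    // The padded tensor is now plumbed through the inner loop's iter_args
    // instead of being re-extracted inside the loop.
    %inner_res = scf.for %j = %c0 to %c12 step %c5
        iter_args(%inner = %padded) -> (tensor<5x25xf32>) {
      %mm = linalg.matmul ins(%Atile, %B : tensor<5x12xf32>, tensor<12x25xf32>)
                          outs(%inner : tensor<5x25xf32>) -> tensor<5x25xf32>
      scf.yield %mm : tensor<5x25xf32>
    }
    // After the inner loop, cast back to the dynamic slice type and insert
    // into the accumulator, as in the updated CHECK lines.
    %cast = tensor.cast %inner_res : tensor<5x25xf32> to tensor<?x25xf32>
    %next = tensor.insert_slice %cast into %acc[%i, 0] [%sz, 25] [1, 1]
        : tensor<?x25xf32> into tensor<24x25xf32>
    scf.yield %next : tensor<24x25xf32>
  }
  return %res : tensor<24x25xf32>
}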
Added:
Modified:
mlir/lib/Dialect/Linalg/Transforms/HoistPadding.cpp
mlir/test/Dialect/Linalg/transform-op-hoist-pad.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/Linalg/Transforms/HoistPadding.cpp b/mlir/lib/Dialect/Linalg/Transforms/HoistPadding.cpp
index dcea3a12208eb..6b56565eaf130 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/HoistPadding.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/HoistPadding.cpp
@@ -100,6 +100,9 @@ struct HoistingAnalysis {
/// The ExtractSliceOp that feeds the PadOp we want to hoist.
tensor::ExtractSliceOp sliceOp;
+ /// If non-empty, this is the unique scf::ForOp that consumes the `sliceOp`.
+ scf::ForOp padConsumingForOp;
+
private:
/// Drop any non-index dependencies of `padOp` and `sliceOp` from
/// `backwardSlice`. The method follows the use-def chains of the index
@@ -224,9 +227,12 @@ HoistingAnalysis::HoistingAnalysis(tensor::PadOp padOp, int numLoops) {
LLVM_DEBUG(DBGS() << "----Source not defined outside of loops -> Skip\n");
return;
}
+ if (sliceOp->hasOneUse()) {
+ padConsumingForOp = dyn_cast<scf::ForOp>(*(sliceOp->getUsers().begin()));
+ }
- // Check the region of `padOp` depends on a constant only. Adding
- // hoisting support for arbitrary padding regions would require cloning all
+ // Check the region of `padOp` depends on a constant only. Adding hoisting
+ // support for arbitrary padding regions would require cloning all
// dependencies captured by the padding region.
Value paddingValue = padOp.getConstantPaddingValue();
if (!paddingValue ||
@@ -259,6 +265,13 @@ HoistingAnalysis::HoistingAnalysis(tensor::PadOp padOp, int numLoops) {
if (backwardSlice.contains(forOp))
packingLoops.push_back(forOp);
+ // TODO: for multiple loops we need to track the use to the innermost loop.
+ if (packingLoops.size() > 1 && padConsumingForOp) {
+ LLVM_DEBUG(DBGS() << "--Cannot hoist multiple loops through iter_args -> "
+ "Downgrade to 1 loop\n");
+ packingLoops.resize(1);
+ }
+
// Note: at this point, packing loops may be empty but we would still like
// to hoist the padding if so specified.
@@ -512,18 +525,21 @@ static PackingLoopNestResult buildPackingLoopNest(
paddedTensor = maybeTransposeOp.getResult(0);
}
- // Step 4. Create InsertSliceOp at the innermost loop level, inserting an
- // optionally transposed padded slice into the packed tensor.
- Value inserted = rewriter.create<tensor::InsertSliceOp>(
- loc, paddedTensor, packedTensor, offsets, sizes, strides);
-
- // Step 5. Iteratively pop the stack and propagate the yield.
- Value valueToYield = inserted;
- for (Value iv : llvm::reverse(clonedLoopIvs)) {
- auto forOp = scf::getForInductionVarOwner(iv);
- rewriter.setInsertionPointToEnd(&forOp.getRegion().front());
- rewriter.create<scf::YieldOp>(loc, valueToYield);
- valueToYield = forOp.getResult(0);
+ // Innermost tensor.insert_slice and yields are optional / need loops.
+ if (nPackedLoops > 0) {
+ // Step 4. Create InsertSliceOp at the innermost loop level, inserting an
+ // optionally transposed padded slice into the packed tensor.
+ Value inserted = rewriter.create<tensor::InsertSliceOp>(
+ loc, paddedTensor, packedTensor, offsets, sizes, strides);
+
+ // Step 5. Iteratively pop the stack and propagate the yield.
+ Value valueToYield = inserted;
+ for (Value iv : llvm::reverse(clonedLoopIvs)) {
+ auto forOp = scf::getForInductionVarOwner(iv);
+ rewriter.setInsertionPointToEnd(&forOp.getRegion().front());
+ rewriter.create<scf::YieldOp>(loc, valueToYield);
+ valueToYield = forOp.getResult(0);
+ }
}
return PackingLoopNestResult{offsets,
@@ -534,6 +550,36 @@ static PackingLoopNestResult buildPackingLoopNest(
maybeTransposeOp};
}
+// If the original consumer of `sliceOp` was a `forOp` (i.e. through an iter
+// arg), propagate the `packedTensor` value through the same iter arg.
+// TODO: for multiple loops we need to track the use to the innermost loop.
+static Value padThroughLoopIterArg(RewriterBase &rewriter, Value packedTensor,
+ tensor::ExtractSliceOp sliceOp,
+ scf::ForOp forOp) {
+ OpOperand *pUse = nullptr;
+ for (OpOperand &use : sliceOp->getUses()) {
+ if (use.getOwner() == forOp) {
+ assert(!pUse && "Multiple slice uses in the for loop");
+ pUse = &use;
+ }
+ }
+ assert(pUse && "No slice use in the for loop");
+ OpBuilder::InsertionGuard g(rewriter);
+ rewriter.setInsertionPointAfter(packedTensor.getDefiningOp());
+ Value casted = rewriter.create<tensor::CastOp>(
+ packedTensor.getLoc(), pUse->get().getType(), packedTensor);
+
+ std::optional<unsigned> operandNumber =
+ forOp.getIterArgNumberForOpOperand(*pUse);
+ assert(operandNumber.has_value() && "expected a proper iter arg number");
+ SmallVector<Value> initArgs = forOp.getInitArgs();
+ initArgs[operandNumber.value()] = casted;
+ rewriter.startRootUpdate(forOp);
+ forOp.getInitArgsMutable().assign(initArgs);
+ rewriter.finalizeRootUpdate(forOp);
+ return forOp.getRegionIterArgForOpOperand(*pUse);
+}
+
/// Produce a tensor extracted from the packingResult. This can be used as a
/// replacement for `opToHoist` in callers.
static Value replaceByPackingLoopNestResult(
@@ -588,8 +634,12 @@ static Value replaceByPackingLoopNestResult(
LLVM_DEBUG(DBGS() << "packedTensor: " << packedTensor << "\n");
- // TODO: atm we are missing the plumbing of packedTensor through the loop
- // bbarg when required (i.e. when hoisting init tensors).
+ // If the consumer of `padOp` was a `forOp`, propagate through iter args.
+ scf::ForOp forOp = analysis.padConsumingForOp;
+ if (forOp) {
+ packedTensor =
+ padThroughLoopIterArg(rewriter, packedTensor, analysis.sliceOp, forOp);
+ }
// offsets = [maybe_leading_ivs, 0 .. 0].
// sizes = [1 .. 1, transposedShape] (defined above).
diff --git a/mlir/test/Dialect/Linalg/transform-op-hoist-pad.mlir b/mlir/test/Dialect/Linalg/transform-op-hoist-pad.mlir
index 18f3b641a7565..02e4698e0d0f5 100644
--- a/mlir/test/Dialect/Linalg/transform-op-hoist-pad.mlir
+++ b/mlir/test/Dialect/Linalg/transform-op-hoist-pad.mlir
@@ -161,14 +161,13 @@ func.func @pad_and_hoist_init(
// CHECK: scf.for %{{.*}} -> (tensor<24x25xf32>) {
// CHECK: %[[PADDED:.*]] = tensor.pad %{{.*}}
// CHECK: : tensor<?x25xf32> to tensor<5x25xf32>
- // CHECK: scf.for %{{.*}} -> (tensor<?x25xf32>) {
- // CHECK: %[[RES:.*]] = linalg.matmul {{.*}} outs(%[[PADDED]] : tensor<5x25xf32>
- //
- // TODO: atm we are missing the plumbing of packedTensor through the loop bbarg
- // when required (i.e. when hoisting init tensors).
- // CHECK: %[[RES_EXTRACTED:.*]] = tensor.extract_slice %[[RES]][0, 0] [%{{.*}}, 25] [1, 1]
- // CHECK-SAME: : tensor<5x25xf32> to tensor<?x25xf32>
- // CHECK: scf.yield %[[RES_EXTRACTED]] : tensor<?x25xf32>
+ // CHECK: scf.for %{{.*}} iter_args(%[[INNER_PADDED:[0-9a-zA-Z]*]] = %[[PADDED]]) -> (tensor<5x25xf32>)
+ // CHECK: %[[RES:.*]] = linalg.matmul {{.*}} outs(%[[INNER_PADDED]]
+ // CHECK-SAME: : tensor<5x25xf32>
+ // CHECK: scf.yield %[[RES]] : tensor<5x25xf32>
+ // CHECK: %[[CAST:.*]] = tensor.cast %{{.*}} : tensor<5x25xf32> to tensor<?x25xf32>
+ // CHECK: tensor.insert_slice %[[CAST]] into %{{.*}}[%{{.*}}, 0] [%{{.*}}, 25] [1, 1]
+ // CHECK-SAME: : tensor<?x25xf32> into tensor<24x25xf32>
%0 = linalg.matmul ins(%arg0, %arg1 : tensor<24x12xf32>, tensor<12x25xf32>) outs(%arg2 : tensor<24x25xf32>) -> tensor<24x25xf32>
func.return %0 : tensor<24x25xf32>
}