[Mlir-commits] [mlir] b3ed6e5 - [mlir][Linalg] Fix hoist padding through scf.for iter_arg
Nicolas Vasilache
llvmlistbot at llvm.org
Thu Apr 13 05:25:14 PDT 2023
Author: Nicolas Vasilache
Date: 2023-04-13T05:21:38-07:00
New Revision: b3ed6e545568a2b483561b416c36942ce2e5d2a2
URL: https://github.com/llvm/llvm-project/commit/b3ed6e545568a2b483561b416c36942ce2e5d2a2
DIFF: https://github.com/llvm/llvm-project/commit/b3ed6e545568a2b483561b416c36942ce2e5d2a2.diff
LOG: [mlir][Linalg] Fix hoist padding through scf.for iter_arg
Previously, hoisting through an iter_arg would mistakenly yield the unpadded value and
cast it to the padded value.
This was incorrect and resulted in out-of-bounds accesses.
The correct formulation is to yield the padded value and extract a smaller dynamic slice
out of it.
Differential Revision: https://reviews.llvm.org/D148173
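For illustration, a minimal before/after sketch of the IR around the pad-consuming loop, adapted from the updated test below (value names such as %0, %res, %out, %i, and %n are placeholders, not the pass's actual SSA names):

Before this change, a value of the padded type was bridged back to the unpadded type with a cast before being inserted into the outer tensor:

  %cast = tensor.cast %0 : tensor<5x25xf32> to tensor<?x25xf32>
  tensor.insert_slice %cast into %out[%i, 0] [%n, 25] [1, 1]
      : tensor<?x25xf32> into tensor<24x25xf32>

After this change, the padded value is propagated through the loop iter_arg, and the valid region is carved out of the loop result with a dynamic extract_slice before reinsertion:

  %extracted = tensor.extract_slice %res[%i, 0] [%n, 25] [1, 1]
      : tensor<5x25xf32> to tensor<?x25xf32>
  tensor.insert_slice %extracted into %out[%i, 0] [%n, 25] [1, 1]
      : tensor<?x25xf32> into tensor<24x25xf32>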
Added:
Modified:
mlir/lib/Dialect/Linalg/Transforms/HoistPadding.cpp
mlir/test/Dialect/Linalg/transform-op-hoist-pad.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/Linalg/Transforms/HoistPadding.cpp b/mlir/lib/Dialect/Linalg/Transforms/HoistPadding.cpp
index 5386420b7a39c..7a6c58aee5938 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/HoistPadding.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/HoistPadding.cpp
@@ -24,6 +24,7 @@
#include "mlir/IR/AsmState.h"
#include "mlir/IR/Dominance.h"
#include "mlir/IR/Matchers.h"
+#include "mlir/Interfaces/DestinationStyleOpInterface.h"
#include "mlir/Transforms/RegionUtils.h"
#include "llvm/Support/Debug.h"
@@ -153,9 +154,9 @@ struct HoistPaddingAnalysis {
bool isValid() { return valid.has_value() && valid.value(); }
bool isInvalid() { return valid.has_value() && !valid.value(); }
- /// Footprint of the packedTensor, computed from the packingLoops.
- SmallVector<Value> getPackedTensorSizes(RewriterBase &rewriter,
- Location loc) const;
+ /// Footprint of the hoistedPackedTensor, computed from the packingLoops.
+ SmallVector<Value> getHoistedPackedTensorSizes(RewriterBase &rewriter,
+ Location loc) const;
/// Performs optional hoisting to enable hoist padding to occur. This may be
/// necessary when `sliceOp` is not defined outside of the outermost enclosing
@@ -450,8 +451,8 @@ LogicalResult HoistPaddingAnalysis::dropNonIndexDependencies() {
}
SmallVector<Value>
-HoistPaddingAnalysis::getPackedTensorSizes(RewriterBase &rewriter,
- Location loc) const {
+HoistPaddingAnalysis::getHoistedPackedTensorSizes(RewriterBase &rewriter,
+ Location loc) const {
SmallVector<Value> dynamicTensorSizes;
// Upper bound the packing loop lengths to size the packed tensor. Taking
@@ -525,7 +526,8 @@ static Value buildLoopIterationCount(RewriterBase &rewriter, scf::ForOp outer,
// Build a packing loop nest by iteratively traversing the backward slice and
// clone the operations, iteratively stepping into the loops that we encounter.
// The implementation proceeds in a stack-like fashion:
-// 1. Iteratively clone and step into the loops, pushing the `packedTensor`
+// 1. Iteratively clone and step into the loops, pushing the
+// `hoistedPackedTensor`
// deeper in the stack.
// 2. At the innermost loop level, create a GenericOp if `transposeVector` is
// non-empty.
@@ -537,7 +539,7 @@ static PackingResult buildPackingLoopNestImpl(
ArrayRef<int64_t> transposeVector, RankedTensorType transposedTensorType,
tensor::EmptyOp emptyOp, const HoistPaddingAnalysis &analysis) {
SmallVector<OpFoldResult> offsets, sizes, strides;
- SmallVector<Value> clonedLoopIvs, leadingPackedTensorIndexings;
+ SmallVector<Value> clonedLoopIvs, leadingHoistedPackedTensorIndexings;
scf::ForOp outerLoop = analysis.outermostEnclosingForOp;
@@ -558,14 +560,14 @@ static PackingResult buildPackingLoopNestImpl(
bbArg = operand.get().dyn_cast<BlockArgument>();
}
- // Step 1. iteratively clone loops and push `packedTensor`.
- Value packedTensor = emptyOp.getResult();
+ // Step 1. iteratively clone loops and push `hoistedPackedTensor`.
+ Value hoistedPackedTensor = emptyOp.getResult();
OpBuilder::InsertionGuard g(rewriter);
for (Operation *op : analysis.backwardSlice) {
- // Specifically sit out in the extract_slice(packedTensor) case: this is
- // the piece we seek to replace.
+ // Specifically sit out in the extract_slice(hoistedPackedTensor) case: this
+ // is the piece we seek to replace.
if (auto sliceOp = dyn_cast<tensor::ExtractSliceOp>(op)) {
- if (bvm.lookupOrDefault(sliceOp.getSource()) == packedTensor) {
+ if (bvm.lookupOrDefault(sliceOp.getSource()) == hoistedPackedTensor) {
LLVM_DEBUG(DBGS() << "--Skip: " << sliceOp << "\n");
continue;
}
@@ -579,11 +581,12 @@ static PackingResult buildPackingLoopNestImpl(
continue;
}
- // Create a packing loop that takes `packedTensor` as iteration argument.
+ // Create a packing loop that takes `hoistedPackedTensor` as iteration
+ // argument.
auto clonedForOp = rewriter.create<scf::ForOp>(
loc, bvm.lookupOrDefault(forOp.getLowerBound()),
bvm.lookupOrDefault(forOp.getUpperBound()),
- bvm.lookupOrDefault(forOp.getStep()), packedTensor);
+ bvm.lookupOrDefault(forOp.getStep()), hoistedPackedTensor);
// Map the induction var, region args and results to the `clonedForOp`.
bvm.map(forOp.getInductionVar(), clonedForOp.getInductionVar());
@@ -600,16 +603,18 @@ static PackingResult buildPackingLoopNestImpl(
// Assert the loop-independent iteration count can be computed.
if (!loopIndependentIterationCount)
llvm_unreachable("loop independence prerequisite not met");
- leadingPackedTensorIndexings.push_back(loopIndependentIterationCount);
- packedTensor = clonedForOp.getRegionIterArgs().front();
+ leadingHoistedPackedTensorIndexings.push_back(
+ loopIndependentIterationCount);
+ hoistedPackedTensor = clonedForOp.getRegionIterArgs().front();
}
// Step 2. Construct offsets, sizes and strides for the innermost level of the
// packing loop.
int64_t nPackedLoops = clonedLoopIvs.size();
// offsets = [clonedLoopIvs, 0 .. 0].
- offsets = SmallVector<OpFoldResult>{leadingPackedTensorIndexings.begin(),
- leadingPackedTensorIndexings.end()};
+ offsets =
+ SmallVector<OpFoldResult>{leadingHoistedPackedTensorIndexings.begin(),
+ leadingHoistedPackedTensorIndexings.end()};
offsets.append(paddedRank, rewriter.getIndexAttr(0));
// sizes = [1 .. 1, transposedShape].
sizes = SmallVector<OpFoldResult>(nPackedLoops, rewriter.getIndexAttr(1));
@@ -627,7 +632,8 @@ static PackingResult buildPackingLoopNestImpl(
Value paddedTensor = bvm.lookup(opToHoist.getResult());
if (!transposeVector.empty()) {
Value outputTensor = rewriter.create<tensor::ExtractSliceOp>(
- loc, transposedTensorType, packedTensor, offsets, sizes, strides);
+ loc, transposedTensorType, hoistedPackedTensor, offsets, sizes,
+ strides);
maybeTransposeOp = makeTransposeOp(rewriter, loc, paddedTensor,
outputTensor, transposeVector);
paddedTensor = maybeTransposeOp.getResult(0);
@@ -638,7 +644,7 @@ static PackingResult buildPackingLoopNestImpl(
// Step 4. Create InsertSliceOp at the innermost loop level, inserting an
// optionally transposed padded slice into the packed tensor.
Value inserted = rewriter.create<tensor::InsertSliceOp>(
- loc, paddedTensor, packedTensor, offsets, sizes, strides);
+ loc, paddedTensor, hoistedPackedTensor, offsets, sizes, strides);
// Step 5. Iteratively pop the stack and propagate the yield.
Value valueToYield = inserted;
@@ -655,7 +661,7 @@ static PackingResult buildPackingLoopNestImpl(
sizes,
strides,
clonedLoopIvs,
- leadingPackedTensorIndexings,
+ leadingHoistedPackedTensorIndexings,
maybeTransposeOp,
cast<tensor::PadOp>(bvm.lookup(opToHoist.getResult()).getDefiningOp())};
}
@@ -688,7 +694,7 @@ static FailureOr<PackingResult> buildPackingLoopNestImpl(
SmallVector<int64_t> packedShape(nPackedLoops, ShapedType::kDynamic);
// TODO: go grab dims when needed, atm tensor::PadOp yields a static tensor.
llvm::append_range(packedShape, transposedTensorType->getShape());
- auto packedTensorType = RankedTensorType::get(
+ auto hoistedPackedTensorType = RankedTensorType::get(
packedShape, transposedTensorType->getElementType());
// Set the insertion point right before the outer loop and start packing.
@@ -696,10 +702,10 @@ static FailureOr<PackingResult> buildPackingLoopNestImpl(
OpBuilder::InsertionGuard g(rewriter);
rewriter.setInsertionPoint(outerLoop);
SmallVector<Value> dynamicTensorSizes =
- analysis.getPackedTensorSizes(rewriter, loc);
+ analysis.getHoistedPackedTensorSizes(rewriter, loc);
auto emptyOp = rewriter.create<tensor::EmptyOp>(
- loc, packedTensorType.getShape(), packedTensorType.getElementType(),
- dynamicTensorSizes);
+ loc, hoistedPackedTensorType.getShape(),
+ hoistedPackedTensorType.getElementType(), dynamicTensorSizes);
return buildPackingLoopNestImpl(rewriter, bvm, opToHoist, transposeVector,
*transposedTensorType, emptyOp, analysis);
@@ -727,14 +733,71 @@ FailureOr<PackingResult> mlir::linalg::detail::buildPackingLoopNest(
// hoistPaddingOnTensors Implementation.
//===----------------------------------------------------------------------===//
-// If the original consumer of `sliceOp` was a `forOp` (i.e. through an iter
-// arg), propagate the `packedTensor` value through the same iter arg.
-// TODO: for multiple loops we need to track the use to the innermost loop.
-static Value padThroughLoopIterArg(RewriterBase &rewriter, Value packedTensor,
- tensor::ExtractSliceOp sliceOp,
- scf::ForOp forOp) {
+/// Return true if we can walk back the use-def chain from `extractSliceOp` to
+/// expectedSource going through DestinationStyleOpInterface inits only.
+/// This is a poor man's analysis that is sufficient to check that the
+/// extractSliceOp matches the tensor.pad we want to hoist.
+/// In the future, it will be easier to ensure this with a matching symmetric
+/// tensor.unpad op.
+static bool tracesBackToExpectedValue(tensor::ExtractSliceOp extractSliceOp,
+ Value expectedSource) {
+ LLVM_DEBUG(DBGS() << "Start tracesBackToExpectedValue on: " << extractSliceOp
+ << "\n");
+ LLVM_DEBUG(DBGS() << "--with extractSlice: " << extractSliceOp << "\n");
+ Value source = extractSliceOp.getSource();
+ LLVM_DEBUG(DBGS() << "--with starting source: " << source << "\n");
+ while (source && source != expectedSource) {
+ auto destOp =
+ dyn_cast_or_null<DestinationStyleOpInterface>(source.getDefiningOp());
+ if (!destOp)
+ break;
+ LLVM_DEBUG(DBGS() << "--step dest op: " << destOp << "\n");
+ source =
+ destOp.getDpsInitOperand(source.cast<OpResult>().getResultNumber())
+ ->get();
+ }
+ LLVM_DEBUG(DBGS() << "--final source: " << source << "\n");
+ LLVM_DEBUG(DBGS() << "--expected source: " << expectedSource << "\n");
+ return source == expectedSource;
+}
+
+/// If the original consumer of `outerSliceOp` was a `forOp` (i.e. through an
+/// iter arg), propagate the `hoistedPackedTensor` value through the same iter
+/// arg.
+/// TODO: for multiple loops we need to track the use to the innermost loop.
+///
+/// Match:
+/// ```
+/// %outerSliceOp = tensor.extract_slice ..
+/// %f = scf.for ... iter_args(%arg0 = %outerSliceOp) {
+/// %hoistedPackedTensor = tensor.pad %arg0
+/// %1 = compute %hoistedPackedTensor
+/// %2 = tensor.extract_slice %1
+/// scf.yield %2
+/// }
+/// ```
+///
+/// and rewrite as:
+/// ```
+/// %outerSliceOp = tensor.extract_slice ..
+/// %hoistedPackedTensor = tensor.pad %outerSliceOp
+/// %f = scf.for ... iter_args(%arg0 = %hoistedPackedTensor) {
+/// %1 = compute %arg0
+/// scf.yield %1
+/// }
+/// %2 = tensor.extract_slice %forOp
+/// ```
+///
+/// Return null when no rewrite happened.
+static tensor::ExtractSliceOp
+padThroughLoopIterArg(RewriterBase &rewriter, Value paddedValueBeforeHoisting,
+ Value hoistedPackedTensor,
+ tensor::ExtractSliceOp outerSliceOp, scf::ForOp forOp) {
+ LLVM_DEBUG(DBGS() << "Start padThroughLoopIterArg on: " << forOp << "\n");
+ LLVM_DEBUG(DBGS() << "--paddedValueBeforeHoisting: "
+ << paddedValueBeforeHoisting << "\n");
OpOperand *pUse = nullptr;
- for (OpOperand &use : sliceOp->getUses()) {
+ for (OpOperand &use : outerSliceOp->getUses()) {
if (use.getOwner() == forOp) {
assert(!pUse && "Multiple slice uses in the for loop");
pUse = &use;
@@ -742,20 +805,67 @@ static Value padThroughLoopIterArg(RewriterBase &rewriter, Value packedTensor,
}
assert(pUse && "No slice use in the for loop");
OpBuilder::InsertionGuard g(rewriter);
- rewriter.setInsertionPointAfter(packedTensor.getDefiningOp());
- Value casted = rewriter.create<tensor::CastOp>(
- packedTensor.getLoc(), pUse->get().getType(), packedTensor);
+ rewriter.setInsertionPointAfter(hoistedPackedTensor.getDefiningOp());
- std::optional<unsigned> operandNumber =
+ std::optional<unsigned> maybeOperandNumber =
forOp.getIterArgNumberForOpOperand(*pUse);
- assert(operandNumber.has_value() && "expected a proper iter arg number");
+ assert(maybeOperandNumber.has_value() && "expected a proper iter arg number");
+
+ int64_t operandNumber = maybeOperandNumber.value();
+ auto yieldOp = cast<scf::YieldOp>(forOp.getBody(0)->getTerminator());
+ auto yieldingExtractSliceOp = yieldOp->getOperand(operandNumber)
+ .getDefiningOp<tensor::ExtractSliceOp>();
+ if (!yieldingExtractSliceOp)
+ return tensor::ExtractSliceOp();
+
+ // Poor man's analysis sufficient to ensure extractSlice matches tensor.pad.
+ // In the future, it will be easier to ensure this with a matching symmetric
+ // tensor.unpad op.
+ if (!tracesBackToExpectedValue(yieldingExtractSliceOp,
+ paddedValueBeforeHoisting))
+ return tensor::ExtractSliceOp();
SmallVector<Value> initArgs = forOp.getInitArgs();
- initArgs[operandNumber.value()] = casted;
- rewriter.startRootUpdate(forOp);
- forOp.getInitArgsMutable().assign(initArgs);
- rewriter.finalizeRootUpdate(forOp);
- return forOp.getRegionIterArgForOpOperand(*pUse);
+ initArgs[operandNumber] = hoistedPackedTensor;
+ SmallVector<Value> yieldOperands = yieldOp.getOperands();
+ yieldOperands[operandNumber] = yieldingExtractSliceOp.getSource();
+
+ int64_t numOriginalForOpResults = initArgs.size();
+ LLVM_DEBUG(DBGS() << "numOriginalForOpResults: " << numOriginalForOpResults
+ << "\n");
+ tensor::ExtractSliceOp extracted;
+ {
+ OpBuilder::InsertionGuard g(rewriter);
+ rewriter.setInsertionPointAfter(forOp);
+ extracted = rewriter.create<tensor::ExtractSliceOp>(
+ hoistedPackedTensor.getLoc(), hoistedPackedTensor,
+ outerSliceOp.getMixedOffsets(), outerSliceOp.getMixedSizes(),
+ outerSliceOp.getMixedStrides());
+ rewriter.replaceAllUsesWith(forOp.getResult(operandNumber), extracted);
+ }
+ scf::ForOp newForOp =
+ replaceLoopWithNewYields(rewriter, forOp, initArgs, yieldOperands);
+
+ LLVM_DEBUG(DBGS() << "newForOp results: " << newForOp.getNumResults()
+ << "\n");
+ LLVM_DEBUG(DBGS() << "replace source of: " << extracted << "\n");
+ LLVM_DEBUG(DBGS() << "with result #"
+ << numOriginalForOpResults + operandNumber
+ << " of forOp, giving us: " << extracted << "\n");
+ rewriter.startRootUpdate(extracted);
+ extracted.getSourceMutable().assign(
+ newForOp.getResult(numOriginalForOpResults + operandNumber));
+ rewriter.finalizeRootUpdate(extracted);
+
+ LLVM_DEBUG(DBGS() << "replace uses of: " << paddedValueBeforeHoisting
+ << "\n");
+ LLVM_DEBUG(DBGS() << "with region iter arg #"
+ << numOriginalForOpResults + operandNumber << "\n");
+ rewriter.replaceAllUsesWith(
+ paddedValueBeforeHoisting,
+ newForOp.getRegionIterArg(numOriginalForOpResults + operandNumber));
+
+ return extracted;
}
/// Produce a tensor extracted from the packingResult. This can be used as a
@@ -781,7 +891,7 @@ static Value replaceByPackingResult(RewriterBase &rewriter,
scf::ForOp outerLoop = analysis.outermostEnclosingForOp;
ArrayRef<scf::ForOp> packingLoops = analysis.packingLoops;
- Value packedTensor;
+ Value hoistedPackedTensor;
SmallVector<Value> loopIterationCounts;
SmallVector<OpFoldResult> offsets(nPackedLoops + paddedRank,
rewriter.getIndexAttr(0));
@@ -798,29 +908,29 @@ static Value replaceByPackingResult(RewriterBase &rewriter,
// offsets = [maybe_leading_ivs = originalLoopIvs, 0 .. 0].
std::copy(loopIterationCounts.begin(), loopIterationCounts.end(),
offsets.begin());
- packedTensor =
+ hoistedPackedTensor =
scf::getForInductionVarOwner(packingResult.clonedLoopIvs.front())
->getResult(0);
} else {
// If no loops were created, this is just hoisting without packing.
- packedTensor = bvm.lookup(opToHoist.getResult());
+ hoistedPackedTensor = bvm.lookup(opToHoist.getResult());
}
- LLVM_DEBUG(DBGS() << "packedTensor: " << packedTensor << "\n");
+ LLVM_DEBUG(DBGS() << "hoistedPackedTensor: " << hoistedPackedTensor << "\n");
// If the consumer of `padOp` was a `forOp`, propagate through iter args.
scf::ForOp forOp = analysis.padConsumingForOp;
if (forOp) {
- packedTensor =
- padThroughLoopIterArg(rewriter, packedTensor, analysis.sliceOp, forOp);
+ return padThroughLoopIterArg(rewriter, opToHoist, hoistedPackedTensor,
+ analysis.sliceOp, forOp);
}
// offsets = [maybe_leading_ivs, 0 .. 0].
// sizes = [1 .. 1, transposedShape] (defined above).
// strides = [1 .. 1] (defined above)
return rewriter.create<tensor::ExtractSliceOp>(
- loc, transposedTensorType, packedTensor, offsets, packingResult.sizes,
- packingResult.strides);
+ loc, transposedTensorType, hoistedPackedTensor, offsets,
+ packingResult.sizes, packingResult.strides);
}
FailureOr<Value> mlir::linalg::hoistPaddingOnTensors(
diff --git a/mlir/test/Dialect/Linalg/transform-op-hoist-pad.mlir b/mlir/test/Dialect/Linalg/transform-op-hoist-pad.mlir
index fd0d3091af3aa..871163ab40cf3 100644
--- a/mlir/test/Dialect/Linalg/transform-op-hoist-pad.mlir
+++ b/mlir/test/Dialect/Linalg/transform-op-hoist-pad.mlir
@@ -161,12 +161,13 @@ func.func @pad_and_hoist_init(
// CHECK: scf.for %{{.*}} -> (tensor<24x25xf32>) {
// CHECK: %[[PADDED:.*]] = tensor.pad %{{.*}}
// CHECK: : tensor<?x25xf32> to tensor<5x25xf32>
- // CHECK: scf.for %{{.*}} iter_args(%[[INNER_PADDED:[0-9a-zA-Z]*]] = %[[PADDED]]) -> (tensor<5x25xf32>)
+ // CHECK: %[[SCF_YIELD:.*]] = scf.for %{{.*}} iter_args(%[[INNER_PADDED:[0-9a-zA-Z]*]] = %[[PADDED]]) -> (tensor<5x25xf32>)
// CHECK: %[[RES:.*]] = linalg.matmul {{.*}} outs(%[[INNER_PADDED]]
// CHECK-SAME: : tensor<5x25xf32>
// CHECK: scf.yield %[[RES]] : tensor<5x25xf32>
- // CHECK: %[[CAST:.*]] = tensor.cast %{{.*}} : tensor<5x25xf32> to tensor<?x25xf32>
- // CHECK: tensor.insert_slice %[[CAST]] into %{{.*}}[%{{.*}}, 0] [%{{.*}}, 25] [1, 1]
+ // CHECK: %[[EXTRACTED:.*]] = tensor.extract_slice %[[SCF_YIELD]][%{{.*}}, 0] [%{{.*}}, 25] [1, 1]
+ // CHECK-SAME: : tensor<5x25xf32> to tensor<?x25xf32>
+ // CHECK: tensor.insert_slice %[[EXTRACTED]] into %{{.*}}[%{{.*}}, 0] [%{{.*}}, 25] [1, 1]
// CHECK-SAME: : tensor<?x25xf32> into tensor<24x25xf32>
%0 = linalg.matmul ins(%arg0, %arg1 : tensor<24x12xf32>, tensor<12x25xf32>) outs(%arg2 : tensor<24x25xf32>) -> tensor<24x25xf32>
func.return %0 : tensor<24x25xf32>