[Mlir-commits] [mlir] efa16ee - [mlir][linalg][NFC] Simplify padOperandToSmallestStaticBoundingBox
Matthias Springer
llvmlistbot at llvm.org
Wed Jun 7 00:00:11 PDT 2023
Author: Matthias Springer
Date: 2023-06-07T08:54:07+02:00
New Revision: efa16ee20ac5d22341a767e93175259a1daadaa3
URL: https://github.com/llvm/llvm-project/commit/efa16ee20ac5d22341a767e93175259a1daadaa3
DIFF: https://github.com/llvm/llvm-project/commit/efa16ee20ac5d22341a767e93175259a1daadaa3.diff
LOG: [mlir][linalg][NFC] Simplify padOperandToSmallestStaticBoundingBox
The implementation uses `ValueBoundsOpInterface` to compute upper bounds for tensor dim sizes. It is therefore not necessary to skip over certain ops and reify shape dims; `ValueBoundsOpInterface` already takes care of that.
Differential Revision: https://reviews.llvm.org/D152256
Added:
Modified:
mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
index 4dceab36f4501..5beb2ffee545b 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
@@ -86,72 +86,28 @@ static FailureOr<Value> padOperandToSmallestStaticBoundingBox(
Value paddingValue = rewriter.create<arith::ConstantOp>(
opToPad.getLoc(), cast<TypedAttr>(paddingAttr));
- // Follow the use-def chain if `currOpOperand` is defined by a LinalgOp.
- OpOperand *currOpOperand = opOperand;
- while (auto linalgOp = currOpOperand->get().getDefiningOp<LinalgOp>()) {
- OpResult result = cast<OpResult>(currOpOperand->get());
- currOpOperand = linalgOp.getDpsInitOperand(result.getResultNumber());
- }
-
- SmallVector<OpFoldResult> mixedSizes;
- if (auto reifiableOp =
- llvm::dyn_cast_or_null<ReifyRankedShapedTypeOpInterface>(
- currOpOperand->get().getDefiningOp())) {
- ReifiedRankedShapedTypeDims reifiedReturnShapes;
- LogicalResult status =
- reifiableOp.reifyResultShapes(rewriter, reifiedReturnShapes);
- mixedSizes = reifiedReturnShapes[0];
- if (failed(status)) {
- LLVM_DEBUG(DBGS() << "--failed to reify result shapes\n");
- return rewriter.notifyMatchFailure(opToPad,
- "failed to reify result shapes");
- }
- } else if (hasStaticShape) {
- mixedSizes = getAsIndexOpFoldResult(rewriter.getContext(), shape);
- } else {
- // TODO: may want to add support for going through loop iter args.
- // This is not strictly necessary as we can pad before hoisting but it would
- // make the system more resilient to minor transformation reordering.
- LLVM_DEBUG(DBGS() << "--not a ReifyRankedShapedTypeOpInterface op\n");
- return rewriter.notifyMatchFailure(
- opToPad, "not a ReifyRankedShapedTypeOpInterface op");
- }
- LLVM_DEBUG(llvm::interleaveComma(mixedSizes, DBGS() << "--mixedSizes: ");
- llvm::dbgs() << "\n");
-
// Upper bound the sizes to obtain a static bounding box.
SmallVector<int64_t> paddedShape(shape.begin(), shape.end());
- int64_t shapeIdx = 0;
- for (const auto &en : enumerate(mixedSizes)) {
- LLVM_DEBUG(DBGS() << "----mixedSizes: " << en.value() << "\n");
+ for (int64_t i = 0, e = shape.size(); i < e; ++i) {
+ LLVM_DEBUG(DBGS() << "--compute padded size for dim " << i << "\n");
// Skip dimensions that do not require padding.
- if (!shapeDimsToPad.contains(shapeIdx)) {
- shapeIdx++;
- LLVM_DEBUG(DBGS() << "------dim does not require padding, SKIP\n");
- continue;
- }
- // If the size is an attribute add it directly to `paddedShape`.
- if (en.value().is<Attribute>()) {
- paddedShape[shapeIdx++] =
- dyn_cast<IntegerAttr>(en.value().get<Attribute>()).getInt();
- LLVM_DEBUG(
- DBGS() << "------dim is an attr, add it to padded shape, SKIP\n");
+ if (!shapeDimsToPad.contains(i)) {
+ LLVM_DEBUG(DBGS() << "----dim does not require padding, SKIP\n");
continue;
}
// Otherwise, try to compute a constant upper bound for the size value.
FailureOr<int64_t> upperBound =
ValueBoundsConstraintSet::computeConstantBound(
- presburger::BoundType::UB, en.value().get<Value>(),
- /*dim=*/std::nullopt, /*stopCondition=*/nullptr, /*closedUB=*/true);
+ presburger::BoundType::UB, opOperand->get(),
+ /*dim=*/i, /*stopCondition=*/nullptr, /*closedUB=*/true);
if (failed(upperBound)) {
- LLVM_DEBUG(DBGS() << "--count not compute a bounding box for padding");
+ LLVM_DEBUG(DBGS() << "----count not compute a bounding box for padding");
return rewriter.notifyMatchFailure(
opToPad, "count not compute a bounding box for padding");
}
- paddedShape[shapeIdx++] = *upperBound;
+ paddedShape[i] = *upperBound;
+ LLVM_DEBUG(DBGS() << "----new dim size: " << paddedShape[i] << "\n");
}
- assert(shapeIdx == static_cast<int64_t>(shape.size()) &&
- "expect the dynamic and static ranks to match");
// Pad the operand to the bounding box defined by `paddedShape`.
auto paddedTensorType = RankedTensorType::get(
More information about the Mlir-commits mailing list