[Mlir-commits] [mlir] efa16ee - [mlir][linalg][NFC] Simplify padOperandToSmallestStaticBoundingBox

Matthias Springer llvmlistbot at llvm.org
Wed Jun 7 00:00:11 PDT 2023


Author: Matthias Springer
Date: 2023-06-07T08:54:07+02:00
New Revision: efa16ee20ac5d22341a767e93175259a1daadaa3

URL: https://github.com/llvm/llvm-project/commit/efa16ee20ac5d22341a767e93175259a1daadaa3
DIFF: https://github.com/llvm/llvm-project/commit/efa16ee20ac5d22341a767e93175259a1daadaa3.diff

LOG: [mlir][linalg][NFC] Simplify padOperandToSmallestStaticBoundingBox

The implementation is based on `ValueBoundsOpInterface` to compute upper bounds for tensor dim sizes. It is not necessary to skip over certain ops and reify shape dims; `ValueBoundsOpInterface` already takes care of that.

Differential Revision: https://reviews.llvm.org/D152256

Added: 
    

Modified: 
    mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
index 4dceab36f4501..5beb2ffee545b 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
@@ -86,72 +86,28 @@ static FailureOr<Value> padOperandToSmallestStaticBoundingBox(
   Value paddingValue = rewriter.create<arith::ConstantOp>(
       opToPad.getLoc(), cast<TypedAttr>(paddingAttr));
 
-  // Follow the use-def chain if `currOpOperand` is defined by a LinalgOp.
-  OpOperand *currOpOperand = opOperand;
-  while (auto linalgOp = currOpOperand->get().getDefiningOp<LinalgOp>()) {
-    OpResult result = cast<OpResult>(currOpOperand->get());
-    currOpOperand = linalgOp.getDpsInitOperand(result.getResultNumber());
-  }
-
-  SmallVector<OpFoldResult> mixedSizes;
-  if (auto reifiableOp =
-          llvm::dyn_cast_or_null<ReifyRankedShapedTypeOpInterface>(
-              currOpOperand->get().getDefiningOp())) {
-    ReifiedRankedShapedTypeDims reifiedReturnShapes;
-    LogicalResult status =
-        reifiableOp.reifyResultShapes(rewriter, reifiedReturnShapes);
-    mixedSizes = reifiedReturnShapes[0];
-    if (failed(status)) {
-      LLVM_DEBUG(DBGS() << "--failed to reify result shapes\n");
-      return rewriter.notifyMatchFailure(opToPad,
-                                         "failed to reify result shapes");
-    }
-  } else if (hasStaticShape) {
-    mixedSizes = getAsIndexOpFoldResult(rewriter.getContext(), shape);
-  } else {
-    // TODO: may want to add support for going through loop iter args.
-    // This is not strictly necessary as we can pad before hoisting but it would
-    // make the system more resilient to minor transformation reordering.
-    LLVM_DEBUG(DBGS() << "--not a ReifyRankedShapedTypeOpInterface op\n");
-    return rewriter.notifyMatchFailure(
-        opToPad, "not a ReifyRankedShapedTypeOpInterface op");
-  }
-  LLVM_DEBUG(llvm::interleaveComma(mixedSizes, DBGS() << "--mixedSizes:  ");
-             llvm::dbgs() << "\n");
-
   // Upper bound the sizes to obtain a static bounding box.
   SmallVector<int64_t> paddedShape(shape.begin(), shape.end());
-  int64_t shapeIdx = 0;
-  for (const auto &en : enumerate(mixedSizes)) {
-    LLVM_DEBUG(DBGS() << "----mixedSizes:  " << en.value() << "\n");
+  for (int64_t i = 0, e = shape.size(); i < e; ++i) {
+    LLVM_DEBUG(DBGS() << "--compute padded size for dim " << i << "\n");
     // Skip dimensions that do not require padding.
-    if (!shapeDimsToPad.contains(shapeIdx)) {
-      shapeIdx++;
-      LLVM_DEBUG(DBGS() << "------dim does not require padding, SKIP\n");
-      continue;
-    }
-    // If the size is an attribute add it directly to `paddedShape`.
-    if (en.value().is<Attribute>()) {
-      paddedShape[shapeIdx++] =
-          dyn_cast<IntegerAttr>(en.value().get<Attribute>()).getInt();
-      LLVM_DEBUG(
-          DBGS() << "------dim is an attr, add it to padded shape, SKIP\n");
+    if (!shapeDimsToPad.contains(i)) {
+      LLVM_DEBUG(DBGS() << "----dim does not require padding, SKIP\n");
       continue;
     }
     // Otherwise, try to compute a constant upper bound for the size value.
     FailureOr<int64_t> upperBound =
         ValueBoundsConstraintSet::computeConstantBound(
-            presburger::BoundType::UB, en.value().get<Value>(),
-            /*dim=*/std::nullopt, /*stopCondition=*/nullptr, /*closedUB=*/true);
+            presburger::BoundType::UB, opOperand->get(),
+            /*dim=*/i, /*stopCondition=*/nullptr, /*closedUB=*/true);
     if (failed(upperBound)) {
-      LLVM_DEBUG(DBGS() << "--count not compute a bounding box for padding");
+      LLVM_DEBUG(DBGS() << "----could not compute a bounding box for padding");
      return rewriter.notifyMatchFailure(
          opToPad, "could not compute a bounding box for padding");
     }
-    paddedShape[shapeIdx++] = *upperBound;
+    paddedShape[i] = *upperBound;
+    LLVM_DEBUG(DBGS() << "----new dim size: " << paddedShape[i] << "\n");
   }
-  assert(shapeIdx == static_cast<int64_t>(shape.size()) &&
-         "expect the dynamic and static ranks to match");
 
   // Pad the operand to the bounding box defined by `paddedShape`.
   auto paddedTensorType = RankedTensorType::get(


        


More information about the Mlir-commits mailing list