[Mlir-commits] [mlir] 1492ae7 - [mlir][Linalg] Use ReifyRankedShapedTypeOpInterface for pad transforms.

Hanhan Wang llvmlistbot at llvm.org
Wed May 10 17:23:58 PDT 2023


Author: Hanhan Wang
Date: 2023-05-10T17:23:46-07:00
New Revision: 1492ae750b81ec14cef0d4d044c634a4c77babc8

URL: https://github.com/llvm/llvm-project/commit/1492ae750b81ec14cef0d4d044c634a4c77babc8
DIFF: https://github.com/llvm/llvm-project/commit/1492ae750b81ec14cef0d4d044c634a4c77babc8.diff

LOG: [mlir][Linalg] Use ReifyRankedShapedTypeOpInterface for pad transforms.

The shape information this transform needs is not specific to tensor.empty
and tensor.extract_slice ops. We can infer the smallest static bounding
box for the pad transform whenever the defining op implements
ReifyRankedShapedTypeOpInterface. This revision extends the transform's
usability for downstream projects. No tests are added because the
existing tests cover the change, and most upstream MLIR ops implementing
ReifyRankedShapedTypeOpInterface are already covered by tests, except
tensor.generate and bufferization.alloc_tensor.

Reviewed By: mravishankar

Differential Revision: https://reviews.llvm.org/D150227

Added: 
    

Modified: 
    mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
index 90d2aa5211572..984ff35515230 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
@@ -93,29 +93,28 @@ static FailureOr<Value> padOperandToSmallestStaticBoundingBox(
     currOpOperand = linalgOp.getDpsInitOperand(result.getResultNumber());
   }
 
-  // Fail if `currOpOperand` is not defined by an ExtractSliceOp or EmptyOp.
-  auto sliceOp = currOpOperand->get().getDefiningOp<tensor::ExtractSliceOp>();
-  auto emptyOp = currOpOperand->get().getDefiningOp<tensor::EmptyOp>();
-
-  llvm::SmallBitVector droppedDims;
   SmallVector<OpFoldResult> mixedSizes;
-  if (sliceOp) {
-    // Compute the dropped dimensions if `sliceOp` is rank-reducing.
-    droppedDims = sliceOp.getDroppedDims();
-    mixedSizes = sliceOp.getMixedSizes();
-  } else if (emptyOp) {
-    mixedSizes = emptyOp.getMixedSizes();
-    droppedDims.resize(mixedSizes.size());
+  if (auto reifiableOp =
+          llvm::dyn_cast_or_null<ReifyRankedShapedTypeOpInterface>(
+              currOpOperand->get().getDefiningOp())) {
+    ReifiedRankedShapedTypeDims reifiedReturnShapes;
+    LogicalResult status =
+        reifiableOp.reifyResultShapes(rewriter, reifiedReturnShapes);
+    mixedSizes = reifiedReturnShapes[0];
+    if (failed(status)) {
+      LLVM_DEBUG(DBGS() << "--failed to reify result shapes\n");
+      return rewriter.notifyMatchFailure(opToPad,
+                                         "failed to reify result shapes");
+    }
   } else if (hasStaticShape) {
     mixedSizes = getAsIndexOpFoldResult(rewriter.getContext(), shape);
-    droppedDims.resize(mixedSizes.size());
   } else {
     // TODO: may want to add support for going through loop iter args.
     // This is not strictly necessary as we can pad before hoisting but it would
     // make the system more resilient to minor transformation reordering.
-    LLVM_DEBUG(DBGS() << "--not defined by an extractSlice or emptyOp\n");
+    LLVM_DEBUG(DBGS() << "--not a ReifyRankedShapedTypeOpInterface op\n");
     return rewriter.notifyMatchFailure(
-        opToPad, "not defined by an extractSlice or emptyOp");
+        opToPad, "not a ReifyRankedShapedTypeOpInterface op");
   }
   LLVM_DEBUG(llvm::interleaveComma(mixedSizes, DBGS() << "--mixedSizes:  ");
              llvm::dbgs() << "\n");
@@ -125,11 +124,6 @@ static FailureOr<Value> padOperandToSmallestStaticBoundingBox(
   int64_t shapeIdx = 0;
   for (const auto &en : enumerate(mixedSizes)) {
     LLVM_DEBUG(DBGS() << "----mixedSizes:  " << en.value() << "\n");
-    // Skip dropped dimensions.
-    if (droppedDims.test(en.index())) {
-      LLVM_DEBUG(DBGS() << "------dim is dropped, SKIP\n");
-      continue;
-    }
     // Skip dimensions that do not require padding.
     if (!shapeDimsToPad.contains(shapeIdx)) {
       shapeIdx++;


        


More information about the Mlir-commits mailing list