[Mlir-commits] [mlir] 758329d - [mlir][NFC] reifyResultShapes: Add extra error checking

Matthias Springer llvmlistbot at llvm.org
Fri Mar 10 02:38:09 PST 2023


Author: Matthias Springer
Date: 2023-03-10T11:37:54+01:00
New Revision: 758329dc7cd3b0da835a4f865b89003263050080

URL: https://github.com/llvm/llvm-project/commit/758329dc7cd3b0da835a4f865b89003263050080
DIFF: https://github.com/llvm/llvm-project/commit/758329dc7cd3b0da835a4f865b89003263050080.diff

LOG: [mlir][NFC] reifyResultShapes: Add extra error checking

This change adds a new helper function `mlir::reifyResultShapes` that calls the corresponding interface method and also checks the result produced by the implementation when running in debug mode. Bugs due to incorrect interface implementations can be difficult to debug.

This helper function also reduces the amount of code needed at call sites: the cast to `ReifyRankedShapedTypeOpInterface` is done in the helper function.

Differential Revision: https://reviews.llvm.org/D145777

Added: 
    

Modified: 
    mlir/include/mlir/Interfaces/InferTypeOpInterface.h
    mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
    mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
    mlir/lib/Dialect/Linalg/Transforms/ConvertToDestinationStyle.cpp
    mlir/lib/Dialect/Linalg/Transforms/FusePadOpWithLinalgProducer.cpp
    mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
    mlir/lib/Dialect/MemRef/Transforms/ResolveShapedTypeResultDims.cpp
    mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
    mlir/lib/Dialect/Tensor/IR/TensorTilingInterfaceImpl.cpp
    mlir/lib/Dialect/Tensor/Transforms/EmptyOpPatterns.cpp
    mlir/lib/Dialect/Tensor/Transforms/ExtractSliceFromReshapeUtils.cpp
    mlir/lib/Interfaces/InferTypeOpInterface.cpp
    mlir/test/lib/Dialect/Tensor/TestTensorTransforms.cpp
    mlir/test/lib/Dialect/Test/TestDialect.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Interfaces/InferTypeOpInterface.h b/mlir/include/mlir/Interfaces/InferTypeOpInterface.h
index 42f5ec47fcdf3..b63d8b6be6739 100644
--- a/mlir/include/mlir/Interfaces/InferTypeOpInterface.h
+++ b/mlir/include/mlir/Interfaces/InferTypeOpInterface.h
@@ -28,6 +28,12 @@ namespace mlir {
 class ShapedTypeComponents;
 using ReifiedRankedShapedTypeDims = SmallVector<SmallVector<OpFoldResult>>;
 
+/// Reify the result shapes of `op`, typically in terms of its operand shapes.
+/// Fails if `op` lacks ReifyRankedShapedTypeOpInterface; debug-checks output.
+LogicalResult
+reifyResultShapes(OpBuilder &b, Operation *op,
+                  ReifiedRankedShapedTypeDims &reifiedReturnShapes);
+
 /// Adaptor class to abstract the differences between whether value is from
 /// a ShapedType or ShapedTypeComponents or DenseIntElementsAttribute.
 class ShapeAdaptor {

diff  --git a/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp b/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
index 0f119d501c241..3b965cf732086 100644
--- a/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
+++ b/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
@@ -138,17 +138,15 @@ FailureOr<Value> bufferization::allocateTensorForShapedValue(
     bool reifiedShapes = false;
     if (shapedValue.getType().isa<RankedTensorType>() &&
         shapedValue.isa<OpResult>()) {
-      if (auto rankedOp = dyn_cast_or_null<ReifyRankedShapedTypeOpInterface>(
-              shapedValue.getDefiningOp())) {
-        ReifiedRankedShapedTypeDims resultDims;
-        if (succeeded(rankedOp.reifyResultShapes(b, resultDims))) {
-          reifiedShapes = true;
-          auto &shape =
-              resultDims[shapedValue.cast<OpResult>().getResultNumber()];
-          for (const auto &dim : enumerate(tensorType.getShape()))
-            if (ShapedType::isDynamic(dim.value()))
-              dynamicSizes.push_back(shape[dim.index()].get<Value>());
-        }
+      ReifiedRankedShapedTypeDims resultDims;
+      if (succeeded(
+              reifyResultShapes(b, shapedValue.getDefiningOp(), resultDims))) {
+        reifiedShapes = true;
+        auto &shape =
+            resultDims[shapedValue.cast<OpResult>().getResultNumber()];
+        for (const auto &dim : enumerate(tensorType.getShape()))
+          if (ShapedType::isDynamic(dim.value()))
+            dynamicSizes.push_back(shape[dim.index()].get<Value>());
       }
     }
 

diff  --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index b33d989b18525..f6a58793e2f0d 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -482,9 +482,7 @@ struct FoldFillWithPad final : public OpRewritePattern<tensor::PadOp> {
       return failure();
 
     ReifiedRankedShapedTypeDims reifiedShape;
-    ReifyRankedShapedTypeOpInterface interface =
-        cast<ReifyRankedShapedTypeOpInterface>(padOp.getOperation());
-    if (failed(interface.reifyResultShapes(rewriter, reifiedShape)))
+    if (failed(reifyResultShapes(rewriter, padOp, reifiedShape)))
       return rewriter.notifyMatchFailure(
           padOp, "failed to reify tensor.pad op result shape");
 

diff  --git a/mlir/lib/Dialect/Linalg/Transforms/ConvertToDestinationStyle.cpp b/mlir/lib/Dialect/Linalg/Transforms/ConvertToDestinationStyle.cpp
index 700f873d19729..3ec5094ed90b5 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/ConvertToDestinationStyle.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/ConvertToDestinationStyle.cpp
@@ -125,19 +125,17 @@ static SmallVector<Value> reifyOrComputeDynamicSizes(OpBuilder &b,
     return {};
 
   // Try to reify dynamic sizes.
-  if (auto reifiableOp =
-          value.getDefiningOp<ReifyRankedShapedTypeOpInterface>()) {
-    ReifiedRankedShapedTypeDims reifiedShape;
-    if (succeeded(reifiableOp.reifyResultShapes(b, reifiedShape))) {
-      SmallVector<Value> dynSizes;
-      for (int64_t i = 0; i < tensorType.getRank(); ++i) {
-        if (tensorType.isDynamicDim(i))
-          dynSizes.push_back(
-              reifiedShape[value.cast<OpResult>().getResultNumber()][i]
-                  .get<Value>());
-      }
-      return dynSizes;
+  ReifiedRankedShapedTypeDims reifiedShape;
+  if (value.isa<OpResult>() &&
+      succeeded(reifyResultShapes(b, value.getDefiningOp(), reifiedShape))) {
+    SmallVector<Value> dynSizes;
+    for (int64_t i = 0; i < tensorType.getRank(); ++i) {
+      if (tensorType.isDynamicDim(i))
+        dynSizes.push_back(
+            reifiedShape[value.cast<OpResult>().getResultNumber()][i]
+                .get<Value>());
     }
+    return dynSizes;
   }
 
   // Create tensor.dim ops.
@@ -293,8 +291,7 @@ mlir::linalg::rewriteInDestinationPassingStyle(RewriterBase &rewriter,
   Location loc = padOp.getLoc();
   RankedTensorType resultType = padOp.getResultType();
   ReifiedRankedShapedTypeDims reifiedShape;
-  if (failed(cast<ReifyRankedShapedTypeOpInterface>(padOp.getOperation())
-                 .reifyResultShapes(rewriter, reifiedShape)))
+  if (failed(reifyResultShapes(rewriter, padOp, reifiedShape)))
     return rewriter.notifyMatchFailure(
         padOp, "failed to reify tensor.pad op result shape");
   SmallVector<Value> dynamicSizes;

diff  --git a/mlir/lib/Dialect/Linalg/Transforms/FusePadOpWithLinalgProducer.cpp b/mlir/lib/Dialect/Linalg/Transforms/FusePadOpWithLinalgProducer.cpp
index 760b14edb7943..b6e2ffcbba368 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/FusePadOpWithLinalgProducer.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/FusePadOpWithLinalgProducer.cpp
@@ -62,10 +62,7 @@ struct FusePadOp : OpRewritePattern<tensor::PadOp> {
           padOp, "only supported for ops with all parallel iterator types");
     }
     ReifiedRankedShapedTypeDims resultShape;
-    ReifyRankedShapedTypeOpInterface reifyShapedTypeInterface =
-        dyn_cast<ReifyRankedShapedTypeOpInterface>(padOp.getOperation());
-    if (failed(reifyShapedTypeInterface.reifyResultShapes(rewriter,
-                                                          resultShape)) ||
+    if (failed(reifyResultShapes(rewriter, padOp, resultShape)) ||
         resultShape.size() != 1) {
       return rewriter.notifyMatchFailure(
           padOp, "failed to get shape of pad op result");

diff  --git a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
index 9de0f763d3292..2ba1562064093 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
@@ -205,8 +205,7 @@ linalg::rewriteAsPaddedOp(RewriterBase &rewriter, LinalgOp opToPad,
   }
 
   ReifiedRankedShapedTypeDims reifiedResultShapes;
-  if (failed(cast<ReifyRankedShapedTypeOpInterface>(opToPad.getOperation())
-                 .reifyResultShapes(rewriter, reifiedResultShapes))) {
+  if (failed(reifyResultShapes(rewriter, opToPad, reifiedResultShapes))) {
     LLVM_DEBUG(DBGS() << "--failed to reify result shapes -> FAIL\n");
     return rewriter.notifyMatchFailure(opToPad,
                                        "failed to reify result shapes");

diff  --git a/mlir/lib/Dialect/MemRef/Transforms/ResolveShapedTypeResultDims.cpp b/mlir/lib/Dialect/MemRef/Transforms/ResolveShapedTypeResultDims.cpp
index 2a18c559307b3..50ac04d6d45cc 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/ResolveShapedTypeResultDims.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/ResolveShapedTypeResultDims.cpp
@@ -84,33 +84,18 @@ struct DimOfReifyRankedShapedTypeOpInterface : public OpRewritePattern<OpTy> {
     OpResult dimValue = dimOp.getSource().template dyn_cast<OpResult>();
     if (!dimValue)
       return failure();
-    auto rankedShapeTypeOp =
-        dyn_cast<ReifyRankedShapedTypeOpInterface>(dimValue.getOwner());
-    if (!rankedShapeTypeOp)
-      return failure();
-
     std::optional<int64_t> dimIndex = dimOp.getConstantIndex();
     if (!dimIndex)
       return failure();
 
     ReifiedRankedShapedTypeDims reifiedResultShapes;
-    if (failed(
-            rankedShapeTypeOp.reifyResultShapes(rewriter, reifiedResultShapes)))
-      return failure();
-
-    if (reifiedResultShapes.size() != rankedShapeTypeOp->getNumResults())
+    if (failed(reifyResultShapes(rewriter, dimValue.getOwner(),
+                                 reifiedResultShapes)))
       return failure();
-
     unsigned resultNumber = dimValue.getResultNumber();
-    auto sourceType = dimValue.getType().dyn_cast<RankedTensorType>();
-    if (reifiedResultShapes[resultNumber].size() !=
-        static_cast<size_t>(sourceType.getRank()))
-      return failure();
-
-    rewriter.replaceOp(dimOp,
-                       getValueOrCreateConstantIndexOp(
-                           rewriter, dimOp.getLoc(),
-                           reifiedResultShapes[resultNumber][*dimIndex]));
+    Value replacement = getValueOrCreateConstantIndexOp(
+        rewriter, dimOp.getLoc(), reifiedResultShapes[resultNumber][*dimIndex]);
+    rewriter.replaceOp(dimOp, replacement);
     return success();
   }
 };

diff  --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
index e1bf889088a9e..5755ddf9fa6b2 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -81,11 +81,7 @@ FailureOr<Value> tensor::getOrCreateDestination(OpBuilder &b, Location loc,
   if (!tensorType.hasStaticShape()) {
     // Dynamic shape: Query ReifyRankedShapedTypeOpInterface.
     ReifiedRankedShapedTypeDims reifiedShapes;
-    ReifyRankedShapedTypeOpInterface reifyShapedTypeInterface =
-        dyn_cast<ReifyRankedShapedTypeOpInterface>(opResult.getDefiningOp());
-    if (!reifyShapedTypeInterface)
-      return failure();
-    if (failed(reifyShapedTypeInterface.reifyResultShapes(b, reifiedShapes)))
+    if (failed(reifyResultShapes(b, opResult.getDefiningOp(), reifiedShapes)))
       return failure();
     mixedSizes = reifiedShapes[opResult.getResultNumber()];
   } else {

diff  --git a/mlir/lib/Dialect/Tensor/IR/TensorTilingInterfaceImpl.cpp b/mlir/lib/Dialect/Tensor/IR/TensorTilingInterfaceImpl.cpp
index 457f261ef8837..ecebf21866dfa 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorTilingInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorTilingInterfaceImpl.cpp
@@ -34,10 +34,7 @@ struct PadOpTiling : public TilingInterface::ExternalModel<PadOpTiling, PadOp> {
 
   SmallVector<Range> getIterationDomain(Operation *op, OpBuilder &b) const {
     ReifiedRankedShapedTypeDims reifiedShapes;
-    ReifyRankedShapedTypeOpInterface reifyShapedTypeInterface =
-        dyn_cast<ReifyRankedShapedTypeOpInterface>(op);
-    (void)reifyShapedTypeInterface.reifyResultShapes(b, reifiedShapes);
-
+    (void)reifyResultShapes(b, op, reifiedShapes);
     Location loc = op->getLoc();
     Value zero = b.create<arith::ConstantIndexOp>(loc, 0);
     Value one = b.create<arith::ConstantIndexOp>(loc, 1);
@@ -84,7 +81,7 @@ static SmallVector<Range> getPackUnPackIterationDomain(OpTy op,
   Value zero = builder.create<arith::ConstantIndexOp>(loc, 0);
   Value one = builder.create<arith::ConstantIndexOp>(loc, 1);
   ReifiedRankedShapedTypeDims resultShape;
-  (void)op.reifyResultShapes(builder, resultShape);
+  (void)reifyResultShapes(builder, op, resultShape);
   SmallVector<Range> loopBounds(rank);
   for (auto dim : llvm::seq<int64_t>(0, rank)) {
     loopBounds[dim].offset = zero;
@@ -216,7 +213,7 @@ struct PackOpTiling
     resultOffsets.append(outputRank - inputRank, zeroAttr);
 
     ReifiedRankedShapedTypeDims outputShape;
-    (void)packOp.reifyResultShapes(b, outputShape);
+    (void)reifyResultShapes(b, packOp, outputShape);
     resultSizes.assign(sizes.begin(), sizes.end());
     for (auto dataTileDim : llvm::seq<unsigned>(inputRank, outputRank))
       resultSizes.push_back(outputShape[0][dataTileDim]);

diff  --git a/mlir/lib/Dialect/Tensor/Transforms/EmptyOpPatterns.cpp b/mlir/lib/Dialect/Tensor/Transforms/EmptyOpPatterns.cpp
index f9512fd49e695..99679bc9b1378 100644
--- a/mlir/lib/Dialect/Tensor/Transforms/EmptyOpPatterns.cpp
+++ b/mlir/lib/Dialect/Tensor/Transforms/EmptyOpPatterns.cpp
@@ -26,10 +26,7 @@ struct FoldEmptyTensorWithReshapeOp : public OpRewritePattern<ReshapeOp> {
       return failure();
     Location loc = reshapeOp.getLoc();
     ReifiedRankedShapedTypeDims resultShapes;
-    ReifyRankedShapedTypeOpInterface reifyShapedTypeInterface =
-        cast<ReifyRankedShapedTypeOpInterface>(reshapeOp.getOperation());
-    if (failed(reifyShapedTypeInterface.reifyResultShapes(rewriter,
-                                                          resultShapes)) ||
+    if (failed(reifyResultShapes(rewriter, reshapeOp, resultShapes)) ||
         !llvm::hasSingleElement(resultShapes))
       return failure();
     // TODO: Do not drop tensor type encoding.

diff  --git a/mlir/lib/Dialect/Tensor/Transforms/ExtractSliceFromReshapeUtils.cpp b/mlir/lib/Dialect/Tensor/Transforms/ExtractSliceFromReshapeUtils.cpp
index 0ef8729347df8..f1ad357098c55 100644
--- a/mlir/lib/Dialect/Tensor/Transforms/ExtractSliceFromReshapeUtils.cpp
+++ b/mlir/lib/Dialect/Tensor/Transforms/ExtractSliceFromReshapeUtils.cpp
@@ -112,9 +112,7 @@ tensor::ExtractSliceFromCollapseHelper::create(OpBuilder &b,
   // Materialize the output shape of the collapse_shape operation. This will
   // create IR describing the output shape in terms of the input shape.
   ReifiedRankedShapedTypeDims reifiedShapes;
-  ReifyRankedShapedTypeOpInterface reifyShapedTypeInterface =
-      dyn_cast<ReifyRankedShapedTypeOpInterface>(op.getOperation());
-  if (failed(reifyShapedTypeInterface.reifyResultShapes(b, reifiedShapes)))
+  if (failed(reifyResultShapes(b, op, reifiedShapes)))
     return failure();
   SmallVector<OpFoldResult> &collapseShapeOutputShape = reifiedShapes[0];
   SmallVector<ReassociationIndices> reassociationIndices =

diff  --git a/mlir/lib/Interfaces/InferTypeOpInterface.cpp b/mlir/lib/Interfaces/InferTypeOpInterface.cpp
index b76f236c3686f..7d464af78023e 100644
--- a/mlir/lib/Interfaces/InferTypeOpInterface.cpp
+++ b/mlir/lib/Interfaces/InferTypeOpInterface.cpp
@@ -22,6 +22,49 @@ namespace mlir {
 #include "mlir/Interfaces/InferTypeOpInterface.cpp.inc"
 } // namespace mlir
 
+LogicalResult
+mlir::reifyResultShapes(OpBuilder &b, Operation *op,
+                        ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
+  // Accept a null op (getDefiningOp() may return null) and report failure.
+  auto reifiableOp = dyn_cast_or_null<ReifyRankedShapedTypeOpInterface>(op);
+  if (!reifiableOp)
+    return failure();
+  LogicalResult status = reifiableOp.reifyResultShapes(b, reifiedReturnShapes);
+#ifndef NDEBUG
+  if (failed(status))
+    return failure();
+  // Verify the implementation: one entry per shaped result, in result order.
+  int64_t numReified = static_cast<int64_t>(reifiedReturnShapes.size());
+  int64_t resultIdx = 0;
+  for (OpResult result : op->getResults()) {
+    auto shapedType = result.getType().dyn_cast<ShapedType>();
+    if (!shapedType)
+      continue;
+    // Check size before indexing so a short result reports this diagnostic.
+    assert(resultIdx < numReified &&
+           "incorrect implementation of ReifyRankedShapedTypeOpInterface");
+    if (!shapedType.hasRank()) { // Nothing to check for unranked values.
+      ++resultIdx;
+      continue;
+    }
+    // Assert one OpFoldResult per dimension.
+    assert(shapedType.getRank() ==
+               static_cast<int64_t>(reifiedReturnShapes[resultIdx].size()) &&
+           "incorrect implementation of ReifyRankedShapedTypeOpInterface");
+    // Static dims must be reified as Attribute, dynamic dims as Value.
+    for (int64_t dim = 0; dim < shapedType.getRank(); ++dim)
+      assert(shapedType.isDynamicDim(dim) ==
+                 reifiedReturnShapes[resultIdx][dim].is<Value>() &&
+             "incorrect implementation of ReifyRankedShapedTypeOpInterface");
+    ++resultIdx;
+  }
+  // Assert that every shaped value result was reified (and nothing extra).
+  assert(resultIdx == numReified &&
+         "incorrect implementation of ReifyRankedShapedTypeOpInterface");
+#endif // NDEBUG
+  return status;
+}
+
 bool ShapeAdaptor::hasRank() const {
   if (val.isNull())
     return false;

diff  --git a/mlir/test/lib/Dialect/Tensor/TestTensorTransforms.cpp b/mlir/test/lib/Dialect/Tensor/TestTensorTransforms.cpp
index be375e47dcf5c..38988923fef67 100644
--- a/mlir/test/lib/Dialect/Tensor/TestTensorTransforms.cpp
+++ b/mlir/test/lib/Dialect/Tensor/TestTensorTransforms.cpp
@@ -188,7 +188,7 @@ struct RewriteExtractSliceFromCollapseShapeBase
 
     // Materialize the output shape values of the slice operation.
     ReifiedRankedShapedTypeDims reifiedShapes;
-    if (failed(op.reifyResultShapes(rewriter, reifiedShapes)))
+    if (failed(reifyResultShapes(rewriter, op, reifiedShapes)))
       return rewriter.notifyMatchFailure(op, "failed to reify result shapes");
 
     // Create the destination tensor using the above values.

diff  --git a/mlir/test/lib/Dialect/Test/TestDialect.cpp b/mlir/test/lib/Dialect/Test/TestDialect.cpp
index 5bafedee2d75d..09f2fcc108ee8 100644
--- a/mlir/test/lib/Dialect/Test/TestDialect.cpp
+++ b/mlir/test/lib/Dialect/Test/TestDialect.cpp
@@ -1241,11 +1241,16 @@ LogicalResult OpWithResultShapePerDimInterfaceOp::reifyResultShapes(
   Location loc = getLoc();
   shapes.reserve(getNumOperands());
   for (Value operand : llvm::reverse(getOperands())) {
+    auto tensorType = operand.getType().cast<RankedTensorType>();
     auto currShape = llvm::to_vector<4>(llvm::map_range(
-        llvm::seq<int64_t>(
-            0, operand.getType().cast<RankedTensorType>().getRank()),
+        llvm::seq<int64_t>(0, tensorType.getRank()),
         [&](int64_t dim) -> OpFoldResult {
-          return builder.createOrFold<tensor::DimOp>(loc, operand, dim);
+          return tensorType.isDynamicDim(dim)
+                     ? static_cast<OpFoldResult>(
+                           builder.createOrFold<tensor::DimOp>(loc, operand,
+                                                               dim))
+                     : static_cast<OpFoldResult>(
+                           builder.getIndexAttr(tensorType.getDimSize(dim)));
         }));
     shapes.emplace_back(std::move(currShape));
   }


        


More information about the Mlir-commits mailing list