[Mlir-commits] [mlir] d7a9bf9 - [mlir][tensor] Fix verifier and bufferization of collapse_shape

Matthias Springer llvmlistbot at llvm.org
Fri Apr 8 02:23:08 PDT 2022


Author: Matthias Springer
Date: 2022-04-08T18:20:40+09:00
New Revision: d7a9bf91431a08bf43cc5b7111a043de9defaee9

URL: https://github.com/llvm/llvm-project/commit/d7a9bf91431a08bf43cc5b7111a043de9defaee9
DIFF: https://github.com/llvm/llvm-project/commit/d7a9bf91431a08bf43cc5b7111a043de9defaee9.diff

LOG: [mlir][tensor] Fix verifier and bufferization of collapse_shape

Insert a buffer copy unless the dims are guaranteed to be collapsible. In the verifier, accept a collapse unless it is guaranteed to be non-collapsible. E.g., collapsing a subview with dynamic strides still verifies, but bufferizes through a copy into an identity-layout buffer.

Differential Revision: https://reviews.llvm.org/D123316

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
    mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
    mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
    mlir/lib/Dialect/Linalg/Transforms/BufferizableOpInterfaceImpl.cpp
    mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
    mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
    mlir/test/Dialect/Tensor/bufferize.mlir

Removed: 
    


################################################################################
diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
index cbd83b85a2787..7ccc2480f4be2 100644
--- a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
+++ b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
@@ -427,6 +427,8 @@ class AlwaysCopyAnalysisState : public AnalysisState {
 /// BufferizationState provides helper functions for performing bufferization
 /// rewrites and handling memref buffers.
 struct BufferizationState {
+  enum ForceInPlacability { FORCE_INPLACE, FORCE_OUT_OF_PLACE };
+
   BufferizationState(const AnalysisState &analysisState)
       : analysisState(analysisState) {}
 
@@ -448,11 +450,19 @@ struct BufferizationState {
   /// Return the buffer (memref) for a given OpOperand (tensor). Allocate
   /// a new buffer and copy over data from the existing buffer if out-of-place
   /// bufferization was decided.
+  ///
+  /// Whether a buffer is in-place or out-of-place is queried from the analysis
+  /// state. Some analyses may always conservatively opt for out-of-place
+  /// bufferization. Inplacability decisions can be overridden with the optional
+  /// `overrideInPlace` parameter.
   FailureOr<Value>
   getBuffer(RewriterBase &rewriter, OpOperand &opOperand,
-            bool forceInPlace = false,
+            Optional<ForceInPlacability> overrideInPlace = None,
             Optional<Operation *> customCopyInsertionPoint = None);
 
+  /// Return the buffer type for a given OpOperand (tensor) after bufferization.
+  BaseMemRefType getBufferType(OpOperand &opOperand) const;
+
   /// Return a reference to the BufferizationOptions.
   const BufferizationOptions &getOptions() const {
     return analysisState.getOptions();
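
For context, a hedged sketch of how a caller might use the new overload; the
surrounding `rewriter`, `state`, and `op` values are assumed to be in scope and
are not part of this commit:

    // Force a fresh allocation plus copy, regardless of what the analysis
    // decided for this operand.
    FailureOr<Value> copied = state.getBuffer(
        rewriter, op->getOpOperand(0),
        BufferizationState::ForceInPlacability::FORCE_OUT_OF_PLACE);
    if (failed(copied))
      return failure();
    Value buffer = *copied;

Passing no override (the default `None`) keeps the analysis decision, while
FORCE_INPLACE suppresses the copy even for out-of-place operands.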

diff --git a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
index 6b727f9183b28..5f7ec96162f88 100644
--- a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
+++ b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
@@ -1295,7 +1295,14 @@ def MemRef_CollapseShapeOp : MemRef_ReassociativeReshapeOp<"collapse_shape"> {
                "ArrayRef<ReassociationIndices>":$reassociation)>
   ];
 
-  let extraClassDeclaration = commonExtraClassDeclaration;
+  let extraClassDeclaration = commonExtraClassDeclaration # [{
+    /// Return `true` if the given source MemRef type is guaranteed to be
+    /// collapsible according to the given reassociation indices. In the
+    /// presence of dynamic strides this is usually not the case.
+    static bool isGuaranteedCollapsible(
+        MemRefType srcType, ArrayRef<ReassociationIndices> reassociation);
+  }];
+
   let hasVerifier = 1;
 }
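
A hedged usage sketch of the new helper; `srcType` is an assumed stand-in for a
source MemRefType (the real call site is in the tensor bufferization below):

    // Decide whether an existing buffer can be collapsed without a copy.
    SmallVector<ReassociationIndices> reassoc = {{0}, {1, 2, 3}};
    bool collapsible =
        memref::CollapseShapeOp::isGuaranteedCollapsible(srcType, reassoc);
    // If not, copy into an identity-layout buffer before collapsing.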
 

diff --git a/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp b/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
index 091462f1ed73a..f2c67ed754a50 100644
--- a/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
+++ b/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
@@ -247,12 +247,12 @@ Value mlir::bufferization::lookupBuffer(RewriterBase &rewriter, Value tensor,
                                                     tensor);
 }
 
-/// Return the result buffer (memref) for a given OpResult (tensor). Allocate
+/// Return the buffer (memref) for a given OpOperand (tensor). Allocate
 /// a new buffer and copy over data from the existing buffer if out-of-place
-/// bufferization is necessary.
+/// bufferization was decided.
 FailureOr<Value>
 BufferizationState::getBuffer(RewriterBase &rewriter, OpOperand &opOperand,
-                              bool forceInPlace,
+                              Optional<ForceInPlacability> overrideInPlace,
                               Optional<Operation *> customCopyInsertionPoint) {
   const BufferizationOptions &options = analysisState.getOptions();
   OpBuilder::InsertionGuard guard(rewriter);
@@ -263,7 +263,11 @@ BufferizationState::getBuffer(RewriterBase &rewriter, OpOperand &opOperand,
   Value operand = opOperand.get();
   Value operandBuffer = lookupBuffer(rewriter, operand, options);
 
-  if (forceInPlace || analysisState.isInPlace(opOperand))
+  // Can `operandBuffer` be used directly or do we need a copy?
+  bool inplace =
+      overrideInPlace != FORCE_OUT_OF_PLACE &&
+      (overrideInPlace == FORCE_INPLACE || analysisState.isInPlace(opOperand));
+  if (inplace)
     return operandBuffer;
 
   // Bufferizing out-of-place: Allocate a new buffer.
@@ -317,6 +321,18 @@ BufferizationState::getBuffer(RewriterBase &rewriter, OpOperand &opOperand,
   return resultBuffer;
 }
 
+/// Return the buffer type for a given OpOperand (tensor) after bufferization.
+BaseMemRefType BufferizationState::getBufferType(OpOperand &opOperand) const {
+  Value tensor = opOperand.get();
+  auto tensorType = tensor.getType().dyn_cast<TensorType>();
+  assert(tensorType && "unexpected non-tensor type");
+
+  if (auto toTensorOp = tensor.getDefiningOp<bufferization::ToTensorOp>())
+    return toTensorOp.memref().getType().cast<BaseMemRefType>();
+
+  return getMemRefType(tensorType, getOptions());
+}
+
 void bufferization::replaceOpWithBufferizedValues(RewriterBase &rewriter,
                                                   Operation *op,
                                                   ValueRange values) {
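
As a hedged illustration, a bufferization pattern can use getBufferType to
inspect the eventual memref layout without materializing the buffer; `state`
and `opOperand` are assumed to be in scope:

    // Query the post-bufferization type of a tensor operand.
    auto memrefType = state.getBufferType(opOperand).dyn_cast<MemRefType>();
    if (memrefType && !memrefType.getLayout().isIdentity()) {
      // Non-identity layouts may need a copy (see the collapse_shape change).
    }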

diff --git a/mlir/lib/Dialect/Linalg/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/Transforms/BufferizableOpInterfaceImpl.cpp
index 4ee0e6360e117..6146175debb7d 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/BufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -48,8 +48,9 @@ static LogicalResult bufferizeLinalgOp(RewriterBase &rewriter, LinalgOp op,
       continue;
     }
     // Input operands are never written to.
-    newInputBuffers.push_back(
-        *state.getBuffer(rewriter, *opOperand, /*forceInPlace=*/true));
+    newInputBuffers.push_back(*state.getBuffer(
+        rewriter, *opOperand,
+        BufferizationState::ForceInPlacability::FORCE_INPLACE));
   }
 
   // New output operands for the cloned op.

diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
index 86dc0199b0c39..9e1010d4896de 100644
--- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
@@ -1812,10 +1812,12 @@ void ExpandShapeOp::getCanonicalizationPatterns(RewritePatternSet &results,
 ///
 /// Note: All collapsed dims in a reassociation group must be contiguous. It is
 /// not possible to check this by inspecting a MemRefType in the general case.
-/// But it is assumed. If this is not the case, the behavior is undefined.
+/// If non-contiguity cannot be checked statically, the collapse is assumed to
+/// be valid (and thus accepted by this function) unless `strict = true`.
 static FailureOr<AffineMap>
 computeCollapsedLayoutMap(MemRefType srcType,
-                          ArrayRef<ReassociationIndices> reassociation) {
+                          ArrayRef<ReassociationIndices> reassociation,
+                          bool strict = false) {
   int64_t srcOffset;
   SmallVector<int64_t> srcStrides;
   auto srcShape = srcType.getShape();
@@ -1837,11 +1839,26 @@ computeCollapsedLayoutMap(MemRefType srcType,
     auto stride = Wrapper::stride(resultStrides[resultStrideIndex--]);
     for (int64_t idx : llvm::reverse(trailingReassocs)) {
       stride = stride * Wrapper::size(srcShape[idx]);
-      // Both are either static strides of the same value, or both are dynamic.
-      // The dynamic case is best effort atm : we can't check it statically.
-      // One exception to the dynamic check is when the srcShape is `1`, in
-      // which case it can never produce a non-contiguity.
-      if (stride != Wrapper::stride(srcStrides[idx - 1]) && srcShape[idx] != 1)
+
+      // Both source and result stride must have the same static value. In that
+      // case, we can be sure that the dimensions are collapsible (because they
+      // are contiguous).
+      //
+      // One special case is a source dim of size `1`, which can never produce
+      // non-contiguity.
+      if (srcShape[idx] == 1)
+        continue;
+
+      // If `strict = false` (default during op verification), we accept cases
+      // where one or both strides are dynamic. This is best effort: We reject
+      // ops where obviously non-contiguous dims are collapsed, but accept ops
+      // where we cannot be sure statically. Such ops may fail at runtime. See
+      // the op documentation for details.
+      auto srcStride = Wrapper::stride(srcStrides[idx - 1]);
+      if (strict && (stride.saturated || srcStride.saturated))
+        return failure();
+
+      if (!stride.saturated && !srcStride.saturated && stride != srcStride)
         return failure();
     }
   }
@@ -1849,6 +1866,16 @@ computeCollapsedLayoutMap(MemRefType srcType,
                                     srcType.getContext());
 }
 
+bool CollapseShapeOp::isGuaranteedCollapsible(
+    MemRefType srcType, ArrayRef<ReassociationIndices> reassociation) {
+  // MemRefs with standard layout are always collapsible.
+  if (srcType.getLayout().isIdentity())
+    return true;
+
+  return succeeded(computeCollapsedLayoutMap(srcType, reassociation,
+                                             /*strict=*/true));
+}
+
 static MemRefType
 computeCollapsedType(MemRefType srcType,
                      ArrayRef<ReassociationIndices> reassociation) {
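
A worked instance of the check above, with illustrative numbers: collapsing
dims [1, 2] of a source with shape [2, 3, 4] starts from the innermost stride
and multiplies by the source sizes while walking the group outwards.

    // strides [12, 4, 1]: 1 * size(2) = 4 == stride(1)  -> collapsible
    // strides [24, 8, 1]: 1 * size(2) = 4 != stride(1)  -> failure()
    // strides [ ?, ?, 1]: contiguity unknown; accepted unless strict = true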

diff --git a/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
index 01d8da85ce962..94df8c54d941e 100644
--- a/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -109,12 +109,12 @@ struct CollapseShapeOpInterface
                           BufferizationState &state) const {
     auto collapseShapeOp = cast<tensor::CollapseShapeOp>(op);
     RankedTensorType tensorResultType = collapseShapeOp.getResultType();
-    Value buffer =
-        *state.getBuffer(rewriter, collapseShapeOp->getOpOperand(0) /*src*/);
+    OpOperand &srcOperand = collapseShapeOp->getOpOperand(0) /*src*/;
+    auto bufferType = state.getBufferType(srcOperand).cast<MemRefType>();
 
     if (tensorResultType.getRank() == 0) {
       // 0-d collapses must go through a different op builder.
-      auto bufferType = buffer.getType().cast<MemRefType>();
+      Value buffer = *state.getBuffer(rewriter, srcOperand);
       MemRefType resultType;
 
       if (bufferType.getLayout().isIdentity()) {
@@ -141,6 +141,18 @@ struct CollapseShapeOpInterface
       return success();
     }
 
+    // If the dims are not collapsible (due to an incompatible source layout
+    // map), force an out-of-place bufferization, i.e., a buffer copy. This
+    // newly allocated buffer will have no layout map and thus be collapsible.
+    bool canBeCollapsed = memref::CollapseShapeOp::isGuaranteedCollapsible(
+        bufferType, collapseShapeOp.getReassociationIndices());
+    Optional<BufferizationState::ForceInPlacability> overrideInPlace =
+        canBeCollapsed
+            ? None
+            : Optional<BufferizationState::ForceInPlacability>(
+                  BufferizationState::ForceInPlacability::FORCE_OUT_OF_PLACE);
+    Value buffer = *state.getBuffer(rewriter, srcOperand, overrideInPlace);
+
     // Result type is inferred by the builder.
     replaceOpWithNewBufferizedOp<memref::CollapseShapeOp>(
         rewriter, op, buffer, collapseShapeOp.getReassociationIndices());
@@ -248,9 +260,12 @@ struct ExtractSliceOpInterface
                           BufferizationState &state) const {
     auto extractSliceOp = cast<tensor::ExtractSliceOp>(op);
     Location loc = extractSliceOp.getLoc();
+
+    // Even if this op was decided to bufferize out-of-place, do not insert the
+    // buffer copy yet; it is inserted later in this function.
     Value srcMemref =
         *state.getBuffer(rewriter, extractSliceOp->getOpOperand(0) /*source*/,
-                         /*forceInPlace=*/true);
+                         BufferizationState::ForceInPlacability::FORCE_INPLACE);
     auto srcMemrefType = srcMemref.getType().cast<MemRefType>();
     auto dstTensorType =
         extractSliceOp.result().getType().cast<RankedTensorType>();

diff --git a/mlir/test/Dialect/Tensor/bufferize.mlir b/mlir/test/Dialect/Tensor/bufferize.mlir
index da27b9c80b6e5..204eaab203486 100644
--- a/mlir/test/Dialect/Tensor/bufferize.mlir
+++ b/mlir/test/Dialect/Tensor/bufferize.mlir
@@ -384,3 +384,20 @@ func @tensor.collapse_shape_of_slice(%arg0: tensor<2xi32>) -> tensor<i32> {
   %1 = tensor.collapse_shape %0 [] : tensor<1xi32> into tensor<i32>
   return %1 : tensor<i32>
 }
+
+// CHECK-LABEL: func @tensor.collapse_shape_of_slice2(
+func @tensor.collapse_shape_of_slice2(
+    %arg0: tensor<?x?x?x?xi64>, %o1: index, %o2: index, %o3: index, %o4: index)
+    -> tensor<87x63648xi64> {
+  // CHECK: %[[subview:.*]] = memref.subview %{{.*}} : memref<?x?x?x?xi64> to memref<87x78x68x12xi64, #{{.*}}>
+  %0 = tensor.extract_slice %arg0[%o1, %o2, %o3, %o4] [87, 78, 68, 12] [1, 1, 1, 1] : tensor<?x?x?x?xi64> to tensor<87x78x68x12xi64>
+
+  // This memref may not be collapsible, so the buffer must be copied to get rid
+  // of the layout map.
+  // CHECK: %[[alloc:.*]] = memref.alloc() {{.*}} : memref<87x78x68x12xi64>
+  // CHECK: memref.copy %[[subview]], %[[alloc]]
+  // CHECK: memref.collapse_shape %[[alloc]] [
+  // CHECK-SAME: [0], [1, 2, 3]] : memref<87x78x68x12xi64> into memref<87x63648xi64>
+  %1 = tensor.collapse_shape %0 [[0], [1, 2, 3]] : tensor<87x78x68x12xi64> into tensor<87x63648xi64>
+  return %1 : tensor<87x63648xi64>
+}
