[Mlir-commits] [mlir] 04ff600 - [mlir][tensor][bufferize] Implement getBufferType for Expand/CollapseShapeOp

Matthias Springer llvmlistbot at llvm.org
Wed Sep 21 02:32:14 PDT 2022


Author: Matthias Springer
Date: 2022-09-21T18:31:59+09:00
New Revision: 04ff6009fcaf2f920c2ccadfc882d51cd00eb928

URL: https://github.com/llvm/llvm-project/commit/04ff6009fcaf2f920c2ccadfc882d51cd00eb928
DIFF: https://github.com/llvm/llvm-project/commit/04ff6009fcaf2f920c2ccadfc882d51cd00eb928.diff

LOG: [mlir][tensor][bufferize] Implement getBufferType for Expand/CollapseShapeOp

The getBufferType interface method must be implemented for all ops whose result memref type differs from the input memref type.
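
For illustration (a hypothetical example, not part of this commit): when the source tensor bufferizes to a memref with a non-identity layout, the reshape's result buffer type carries a different, recomputed layout, so the op cannot simply reuse the source memref type:

  // Hypothetical IR, not taken from this commit.
  func.func @collapse(%t: tensor<4x6xf32>) -> tensor<24xf32> {
    %0 = tensor.collapse_shape %t [[0, 1]]
        : tensor<4x6xf32> into tensor<24xf32>
    return %0 : tensor<24xf32>
  }

  // If %t bufferizes to, e.g., memref<4x6xf32, strided<[6, 1], offset: ?>>,
  // the collapsed result has the different type
  // memref<24xf32, strided<[1], offset: ?>> (or bufferizes to a new
  // allocation when the source strides are not guaranteed collapsible).
  // The new getBufferType implementations compute exactly this type.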

Differential Revision: https://reviews.llvm.org/D134331

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
    mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
    mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
    mlir/test/Dialect/SCF/one-shot-bufferize.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
index 2a3b1ee924f49..2ea19bbb0216c 100644
--- a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
+++ b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
@@ -1342,7 +1342,13 @@ def MemRef_ExpandShapeOp : MemRef_ReassociativeReshapeOp<"expand_shape"> {
     OpBuilder<(ins "ArrayRef<int64_t>":$resultShape, "Value":$src,
                "ArrayRef<ReassociationIndices>":$reassociation)>
   ];
-  let extraClassDeclaration = commonExtraClassDeclaration;
+
+  let extraClassDeclaration = commonExtraClassDeclaration # [{
+    static FailureOr<MemRefType> computeExpandedType(
+        MemRefType srcType, ArrayRef<int64_t> resultShape,
+        ArrayRef<ReassociationIndices> reassociation);
+  }];
+
   let hasVerifier = 1;
 }
 
@@ -1389,6 +1395,7 @@ def MemRef_CollapseShapeOp : MemRef_ReassociativeReshapeOp<"collapse_shape"> {
     Note: This op currently assumes that the inner strides of the
     source/result layout map are the faster-varying ones.
   }];
+
   let builders = [
     // Builders for a contracting reshape whose result type is computed from
     // `src` and `reassociation`.
@@ -1422,12 +1429,16 @@ def MemRef_CollapseShapeOp : MemRef_ReassociativeReshapeOp<"collapse_shape"> {
       build($_builder, $_state, resultType, src, reassociationMaps, attrs);
     }]>
   ];
+
   let extraClassDeclaration = commonExtraClassDeclaration # [{
     /// Return `true` if this source MemRef type is guaranteed to be collapsible
     /// according to the given reassociation indices. In the presence of dynamic
     /// strides this is usually not the case.
     static bool isGuaranteedCollapsible(
         MemRefType srcType, ArrayRef<ReassociationIndices> reassociation);
+
+    static MemRefType computeCollapsedType(
+        MemRefType srcType, ArrayRef<ReassociationIndices> reassociation);
   }];
 
   let hasVerifier = 1;

diff  --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
index f1c61c49473b5..227352b28de5a 100644
--- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
@@ -1801,9 +1801,9 @@ computeExpandedLayoutMap(MemRefType srcType, ArrayRef<int64_t> resultShape,
   return StridedLayoutAttr::get(srcType.getContext(), srcOffset, resultStrides);
 }
 
-static FailureOr<MemRefType>
-computeExpandedType(MemRefType srcType, ArrayRef<int64_t> resultShape,
-                    ArrayRef<ReassociationIndices> reassociation) {
+FailureOr<MemRefType> ExpandShapeOp::computeExpandedType(
+    MemRefType srcType, ArrayRef<int64_t> resultShape,
+    ArrayRef<ReassociationIndices> reassociation) {
   if (srcType.getLayout().isIdentity()) {
     // If the source is contiguous (i.e., no layout map specified), so is the
     // result.
@@ -1827,7 +1827,7 @@ void ExpandShapeOp::build(OpBuilder &builder, OperationState &result,
   // Only ranked memref source values are supported.
   auto srcType = src.getType().cast<MemRefType>();
   FailureOr<MemRefType> resultType =
-      computeExpandedType(srcType, resultShape, reassociation);
+      ExpandShapeOp::computeExpandedType(srcType, resultShape, reassociation);
   // Failure of this assertion usually indicates a problem with the source
   // type, e.g., could not get strides/offset.
   assert(succeeded(resultType) && "could not compute layout");
@@ -1846,7 +1846,7 @@ LogicalResult ExpandShapeOp::verify() {
     return failure();
 
   // Compute expected result type (including layout map).
-  FailureOr<MemRefType> expectedResultType = computeExpandedType(
+  FailureOr<MemRefType> expectedResultType = ExpandShapeOp::computeExpandedType(
       srcType, resultType.getShape(), getReassociationIndices());
   if (failed(expectedResultType))
     return emitOpError("invalid source layout map");
@@ -1943,9 +1943,8 @@ bool CollapseShapeOp::isGuaranteedCollapsible(
                                              /*strict=*/true));
 }
 
-static MemRefType
-computeCollapsedType(MemRefType srcType,
-                     ArrayRef<ReassociationIndices> reassociation) {
+MemRefType CollapseShapeOp::computeCollapsedType(
+    MemRefType srcType, ArrayRef<ReassociationIndices> reassociation) {
   SmallVector<int64_t> resultShape;
   resultShape.reserve(reassociation.size());
   for (const ReassociationIndices &group : reassociation) {
@@ -1979,7 +1978,8 @@ void CollapseShapeOp::build(OpBuilder &b, OperationState &result, Value src,
                             ArrayRef<ReassociationIndices> reassociation,
                             ArrayRef<NamedAttribute> attrs) {
   auto srcType = src.getType().cast<MemRefType>();
-  MemRefType resultType = computeCollapsedType(srcType, reassociation);
+  MemRefType resultType =
+      CollapseShapeOp::computeCollapsedType(srcType, reassociation);
   build(b, result, resultType, src, attrs);
   result.addAttribute(::mlir::getReassociationAttrName(),
                       getReassociationIndicesAttribute(b, reassociation));
@@ -2039,9 +2039,9 @@ struct CollapseShapeOpMemRefCastFolder
     if (!CastOp::canFoldIntoConsumerOp(cast))
       return failure();
 
-    Type newResultType =
-        computeCollapsedType(cast.getOperand().getType().cast<MemRefType>(),
-                             op.getReassociationIndices());
+    Type newResultType = CollapseShapeOp::computeCollapsedType(
+        cast.getOperand().getType().cast<MemRefType>(),
+        op.getReassociationIndices());
 
     if (newResultType == op.getResultType()) {
       rewriter.updateRootInPlace(

diff  --git a/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
index c6945ff0a067f..737e64258db55 100644
--- a/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -109,6 +109,29 @@ struct CollapseShapeOpInterface
     return BufferRelation::Equivalent;
   }
 
+  FailureOr<BaseMemRefType>
+  getBufferType(Operation *op, Value value, const BufferizationOptions &options,
+                const DenseMap<Value, BaseMemRefType> &fixedTypes) const {
+    auto collapseShapeOp = cast<tensor::CollapseShapeOp>(op);
+    auto maybeSrcBufferType = bufferization::getBufferType(
+        collapseShapeOp.getSrc(), options, fixedTypes);
+    if (failed(maybeSrcBufferType))
+      return failure();
+    auto srcBufferType = maybeSrcBufferType->cast<MemRefType>();
+    bool canBeCollapsed = memref::CollapseShapeOp::isGuaranteedCollapsible(
+        srcBufferType, collapseShapeOp.getReassociationIndices());
+
+    if (!canBeCollapsed) {
+      // If dims cannot be collapsed, this op bufferizes to a new allocation.
+      RankedTensorType tensorResultType = collapseShapeOp.getResultType();
+      return bufferization::getMemRefTypeWithStaticIdentityLayout(
+          tensorResultType, srcBufferType.getMemorySpaceAsInt());
+    }
+
+    return memref::CollapseShapeOp::computeCollapsedType(
+        srcBufferType, collapseShapeOp.getReassociationIndices());
+  }
+
   LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                           const BufferizationOptions &options) const {
     auto collapseShapeOp = cast<tensor::CollapseShapeOp>(op);
@@ -232,6 +255,23 @@ struct ExpandShapeOpInterface
     return BufferRelation::Equivalent;
   }
 
+  FailureOr<BaseMemRefType>
+  getBufferType(Operation *op, Value value, const BufferizationOptions &options,
+                const DenseMap<Value, BaseMemRefType> &fixedTypes) const {
+    auto expandShapeOp = cast<tensor::ExpandShapeOp>(op);
+    auto maybeSrcBufferType = bufferization::getBufferType(
+        expandShapeOp.getSrc(), options, fixedTypes);
+    if (failed(maybeSrcBufferType))
+      return failure();
+    auto srcBufferType = maybeSrcBufferType->cast<MemRefType>();
+    auto maybeResultType = memref::ExpandShapeOp::computeExpandedType(
+        srcBufferType, expandShapeOp.getResultType().getShape(),
+        expandShapeOp.getReassociationIndices());
+    if (failed(maybeResultType))
+      return failure();
+    return *maybeResultType;
+  }
+
   LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                           const BufferizationOptions &options) const {
     auto expandShapeOp = cast<tensor::ExpandShapeOp>(op);

diff  --git a/mlir/test/Dialect/SCF/one-shot-bufferize.mlir b/mlir/test/Dialect/SCF/one-shot-bufferize.mlir
index ea2b8dbdc2bd9..7a15d9b43c58b 100644
--- a/mlir/test/Dialect/SCF/one-shot-bufferize.mlir
+++ b/mlir/test/Dialect/SCF/one-shot-bufferize.mlir
@@ -858,3 +858,21 @@ func.func @non_tensor_for_arg(%A : tensor<?xf32> {bufferization.writable = true}
   }
   return %r1#1 : tensor<?xf32>
 }
+
+// -----
+
+// This is a regression test. Just check that the IR bufferizes.
+
+// CHECK-LABEL: func @buffer_type_of_collapse_shape
+func.func @buffer_type_of_collapse_shape(%arg0: tensor<f64>) {
+  %true = arith.constant true
+  %0 = scf.while (%arg1 = %arg0) : (tensor<f64>) -> (tensor<f64>) {
+    scf.condition(%true) %arg1 : tensor<f64>
+  } do {
+  ^bb0(%_: tensor<f64>):
+    %3 = bufferization.alloc_tensor() : tensor<1xf64>
+    %16 = tensor.collapse_shape %3 [] : tensor<1xf64> into tensor<f64>
+    scf.yield %16 : tensor<f64>
+  }
+  return
+}
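
For reference (an assumption about the output; the test only checks that bufferization succeeds), the loop body above is expected to bufferize to roughly:

  // Rough sketch of the bufferized loop body (not verified by CHECK lines).
  %alloc = memref.alloc() : memref<1xf64>
  %collapsed = memref.collapse_shape %alloc []
      : memref<1xf64> into memref<f64>
  // The scf.while iter_arg then carries memref<f64> values, which is the
  // buffer type reported by the new getBufferType implementation.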
