[Mlir-commits] [mlir] 9dbb8ee - [mlir][tensor] Implement getBufferType for ReshapeOp.

Ingo Müller llvmlistbot at llvm.org
Fri Jun 2 06:40:53 PDT 2023


Author: Ingo Müller
Date: 2023-06-02T13:40:48Z
New Revision: 9dbb8eefd43b7acdd5d6d030deed839eb11bbd9c

URL: https://github.com/llvm/llvm-project/commit/9dbb8eefd43b7acdd5d6d030deed839eb11bbd9c
DIFF: https://github.com/llvm/llvm-project/commit/9dbb8eefd43b7acdd5d6d030deed839eb11bbd9c.diff

LOG: [mlir][tensor] Implement getBufferType for ReshapeOp.

This function should be implemented for ops that work in one-shot
bufferization.

Reviewed By: springerm

Differential Revision: https://reviews.llvm.org/D151548

Added: 
    

Modified: 
    mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
index 1a4fc3bb5bbad..e824c73a1f079 100644
--- a/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -992,13 +992,28 @@ struct ReshapeOpInterface
         getBuffer(rewriter, reshapeOp.getShape(), options);
     if (failed(srcBuffer) || failed(shapeBuffer))
       return failure();
-    auto resultMemRefType = getMemRefTypeWithStaticIdentityLayout(
-        reshapeOp.getResult().getType(),
-        cast<BaseMemRefType>(srcBuffer->getType()).getMemorySpace());
+    auto maybeResultMemRefType =
+        bufferization::getBufferType(reshapeOp.getResult(), options);
+    if (failed(maybeResultMemRefType))
+      return failure();
     replaceOpWithNewBufferizedOp<memref::ReshapeOp>(
-        rewriter, op, resultMemRefType, *srcBuffer, *shapeBuffer);
+        rewriter, op, maybeResultMemRefType.value(), *srcBuffer, *shapeBuffer);
     return success();
   }
+
+  /// Computes the buffer type of the tensor.reshape result: a memref type with
+  /// a static identity layout for the result tensor type, allocated in the
+  /// same memory space as the buffer of the source operand. Returns failure
+  /// if the buffer type of the source operand cannot be determined.
+  FailureOr<BaseMemRefType>
+  getBufferType(Operation *op, Value value, const BufferizationOptions &options,
+                const DenseMap<Value, BaseMemRefType> &fixedTypes) const {
+    auto reshapeOp = cast<tensor::ReshapeOp>(op);
+    // tensor.reshape has a single tensor result; only that value is queried.
+    assert(value == reshapeOp.getResult() && "unexpected value provided");
+    FailureOr<BaseMemRefType> maybeSourceBufferType =
+        bufferization::getBufferType(reshapeOp.getSource(), options,
+                                     fixedTypes);
+    if (failed(maybeSourceBufferType))
+      return failure();
+    // The source buffer type is already a BaseMemRefType, so its memory space
+    // can be read directly (no cast needed).
+    return getMemRefTypeWithStaticIdentityLayout(
+        reshapeOp.getResult().getType(),
+        maybeSourceBufferType->getMemorySpace());
+  }
 };
 
 /// Analysis of ParallelInsertSliceOp.


        


More information about the Mlir-commits mailing list