[Mlir-commits] [mlir] 8f42939 - [mlir][bufferize][NFC] Make getContiguousMemRefType a static function

Matthias Springer llvmlistbot at llvm.org
Fri May 13 02:31:01 PDT 2022


Author: Matthias Springer
Date: 2022-05-13T11:27:43+02:00
New Revision: 8f42939a07547d02424655cb7355963b54329c1e

URL: https://github.com/llvm/llvm-project/commit/8f42939a07547d02424655cb7355963b54329c1e
DIFF: https://github.com/llvm/llvm-project/commit/8f42939a07547d02424655cb7355963b54329c1e.diff

LOG: [mlir][bufferize][NFC] Make getContiguousMemRefType a static function

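There is no longer any need to expose this function as public API: its remaining caller outside BufferizableOpInterface.cpp (the `tensor.generate` bufferization) now takes the memref type from the buffer returned by `createAlloc` instead.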

Differential Revision: https://reviews.llvm.org/D125361

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
    mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
    mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
index 4d008add40cd9..bb9ec01380e4e 100644
--- a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
+++ b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
@@ -546,11 +546,6 @@ OpTy replaceOpWithNewBufferizedOp(RewriterBase &rewriter, Operation *op,
   return newOp;
 }
 
-/// Return a contiguous MemRefType (i.e. with canonical/empty layout map)
-/// with the same shape as `shapedType` and specified `addressSpace`.
-MemRefType getContiguousMemRefType(ShapedType shapedType,
-                                   Attribute memorySpace = {});
-
 /// Return a MemRefType to which the `tensorType` can be bufferized in a
 /// composable fashion. The layout must be the most dynamic possible and
 /// canonicalize away once bufferization is finished.

diff  --git a/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp b/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
index f5b6203c2395f..75b05168902ef 100644
--- a/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
+++ b/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
@@ -441,6 +441,13 @@ LogicalResult BufferizationOptions::createDealloc(OpBuilder &b, Location loc,
   return success();
 }
 
+static MemRefType getContiguousMemRefType(ShapedType shapedType,
+                                          Attribute memorySpace = {}) {
+  MemRefLayoutAttrInterface layout = {};
+  return MemRefType::get(shapedType.getShape(), shapedType.getElementType(),
+                         layout, memorySpace);
+}
+
 /// Compute the type of the `memref` to use for allocating the buffer for
 /// `shapedValue`. Also returns (by reference in `dynShape`), the value for the
 /// dynamic dimensions in the returned `memref` type.
@@ -644,13 +651,6 @@ bool bufferization::isFunctionArgument(Value value) {
   return isa<func::FuncOp>(bbArg.getOwner()->getParentOp());
 }
 
-MemRefType bufferization::getContiguousMemRefType(ShapedType shapedType,
-                                                  Attribute memorySpace) {
-  MemRefLayoutAttrInterface layout = {};
-  return MemRefType::get(shapedType.getShape(), shapedType.getElementType(),
-                         layout, memorySpace);
-}
-
 BaseMemRefType bufferization::getMemRefType(TensorType tensorType,
                                             const BufferizationOptions &options,
                                             MemRefLayoutAttrInterface layout,

diff  --git a/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
index 3ec52e94423f6..238d5365fb87e 100644
--- a/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -445,13 +445,12 @@ struct GenerateOpInterface
 
     // Allocate memory.
     Location loc = op->getLoc();
-    MemRefType memrefType =
-        getContiguousMemRefType(generateOp.getType().cast<RankedTensorType>());
     FailureOr<Value> maybeResult =
         state.createAlloc(rewriter, loc, generateOp.result());
     if (failed(maybeResult))
       return failure();
     Value result = *maybeResult;
+    MemRefType memrefType = result.getType().cast<MemRefType>();
 
     // Collect loop bounds.
     int64_t rank = memrefType.getRank();


        


More information about the Mlir-commits mailing list