[Mlir-commits] [mlir] 12e41d9 - [mlir][bufferize] Infer memref types when possible

Matthias Springer llvmlistbot at llvm.org
Sun May 15 17:02:24 PDT 2022


Author: Matthias Springer
Date: 2022-05-16T02:02:08+02:00
New Revision: 12e41d9264b6f84213be86aab75016fb82ebc1d1

URL: https://github.com/llvm/llvm-project/commit/12e41d9264b6f84213be86aab75016fb82ebc1d1
DIFF: https://github.com/llvm/llvm-project/commit/12e41d9264b6f84213be86aab75016fb82ebc1d1.diff

LOG: [mlir][bufferize] Infer memref types when possible

Instead of recomputing memref types from tensor types, try to infer them when possible. This results in more precise layout maps.

Differential Revision: https://reviews.llvm.org/D125614

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
    mlir/lib/Dialect/Arithmetic/Transforms/BufferizableOpInterfaceImpl.cpp
    mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp
    mlir/lib/Dialect/Shape/Transforms/BufferizableOpInterfaceImpl.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
index bb9ec01380e4e..a6233295cb650 100644
--- a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
+++ b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
@@ -503,6 +503,9 @@ struct BufferizationState {
             Optional<Operation *> customCopyInsertionPoint = None);
 
   /// Return the buffer type for a given OpOperand (tensor) after bufferization.
+  ///
+  /// Note: Op implementations should preferably call `getBuffer()->getType()`.
+  /// This function should only be used if `getBuffer` cannot be used.
   BaseMemRefType getBufferType(OpOperand &opOperand) const;
 
   /// Return a reference to the BufferizationOptions.
@@ -546,9 +549,18 @@ OpTy replaceOpWithNewBufferizedOp(RewriterBase &rewriter, Operation *op,
   return newOp;
 }
 
-/// Return a MemRefType to which the `tensorType` can be bufferized in a
-/// composable fashion. The layout must be the most dynamic possible and
-/// canonicalize away once bufferization is finished.
+/// Return a MemRefType to which the `tensorType` can be bufferized.
+///
+/// If possible, op bufferization implementations should not use this function
+/// and instead infer precise memref types for tensor results by themselves.
+///
+/// Unless a layout map was specified, `options` flags determine what kind of
+/// layout map will be used. For best composability (without copies), the fully
+/// dynamic layout map is used by default.
+///
+/// Note: Canonicalization patterns could clean up layout maps and infer more
+/// precise layout maps after bufferization. However, many possible
+/// canonicalizations are currently not implemented.
 BaseMemRefType getMemRefType(TensorType tensorType,
                              const BufferizationOptions &options,
                              MemRefLayoutAttrInterface layout = {},

diff  --git a/mlir/lib/Dialect/Arithmetic/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Arithmetic/Transforms/BufferizableOpInterfaceImpl.cpp
index 4f1add5a899b9..3bdd3d92cfdc0 100644
--- a/mlir/lib/Dialect/Arithmetic/Transforms/BufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Arithmetic/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -82,17 +82,22 @@ struct IndexCastOpInterface
   LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                           BufferizationState &state) const {
     auto castOp = cast<arith::IndexCastOp>(op);
+    auto resultTensorType = castOp.getType().cast<TensorType>();
 
     Value source = *state.getBuffer(rewriter, op->getOpOperand(0) /*in*/);
     auto sourceType = source.getType().cast<BaseMemRefType>();
 
     // Result type should have same layout and address space as the source type.
-    MemRefLayoutAttrInterface layout = {};
-    if (auto rankedMemRefType = sourceType.dyn_cast<MemRefType>())
-      layout = rankedMemRefType.getLayout();
-    Type resultType =
-        getMemRefType(castOp.getType().cast<TensorType>(), state.getOptions(),
-                      layout, sourceType.getMemorySpace());
+    BaseMemRefType resultType;
+    if (auto rankedMemRefType = sourceType.dyn_cast<MemRefType>()) {
+      resultType = MemRefType::get(
+          rankedMemRefType.getShape(), resultTensorType.getElementType(),
+          rankedMemRefType.getLayout(), rankedMemRefType.getMemorySpace());
+    } else {
+      auto unrankedMemrefType = sourceType.cast<UnrankedMemRefType>();
+      resultType = UnrankedMemRefType::get(resultTensorType.getElementType(),
+                                           unrankedMemrefType.getMemorySpace());
+    }
 
     replaceOpWithNewBufferizedOp<arith::IndexCastOp>(rewriter, op, resultType,
                                                      source);
@@ -146,15 +151,14 @@ struct SelectOpInterface
     // both of them to the most dynamic MemRef type.
     if (trueBuffer.getType() != falseBuffer.getType()) {
       auto trueType = trueBuffer.getType().cast<MemRefType>();
-      auto tensorType = selectOp.getTrueValue().getType().cast<TensorType>();
       int64_t dynamicOffset = ShapedType::kDynamicStrideOrOffset;
-      SmallVector<int64_t> dynamicStrides(tensorType.getRank(),
+      SmallVector<int64_t> dynamicStrides(trueType.getRank(),
                                           ShapedType::kDynamicStrideOrOffset);
       AffineMap stridedLayout = makeStridedLinearLayoutMap(
           dynamicStrides, dynamicOffset, op->getContext());
-      BaseMemRefType castedType = bufferization::getMemRefType(
-          tensorType, state.getOptions(), AffineMapAttr::get(stridedLayout),
-          trueType.getMemorySpace());
+      auto castedType =
+          MemRefType::get(trueType.getShape(), trueType.getElementType(),
+                          stridedLayout, trueType.getMemorySpaceAsInt());
       trueBuffer = rewriter.create<memref::CastOp>(loc, castedType, trueBuffer);
       falseBuffer =
           rewriter.create<memref::CastOp>(loc, castedType, falseBuffer);

diff  --git a/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp
index 337b0aa57ea3e..638deb6a68c4b 100644
--- a/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -79,6 +79,7 @@ struct ExecuteRegionOpInterface
     SmallVector<Type> newResultTypes;
     for (Type type : executeRegionOp->getResultTypes()) {
       if (auto tensorType = type.dyn_cast<TensorType>()) {
+        // TODO: Infer the result type instead of computing it.
         newResultTypes.push_back(getMemRefType(tensorType, state.getOptions()));
       } else {
         newResultTypes.push_back(type);
@@ -188,6 +189,7 @@ struct IfOpInterface
     SmallVector<Type> newTypes;
     for (Type returnType : ifOp->getResultTypes()) {
       if (auto tensorType = returnType.dyn_cast<TensorType>()) {
+        // TODO: Infer the result type instead of computing it.
         newTypes.push_back(getMemRefType(tensorType, state.getOptions()));
       } else {
         newTypes.push_back(returnType);

diff  --git a/mlir/lib/Dialect/Shape/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Shape/Transforms/BufferizableOpInterfaceImpl.cpp
index ce3c85e6454e0..ac868a3723a78 100644
--- a/mlir/lib/Dialect/Shape/Transforms/BufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Shape/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -66,6 +66,7 @@ struct AssumingOpInterface
     SmallVector<Type> newResultTypes;
     for (Type type : assumingOp->getResultTypes()) {
       if (auto tensorType = type.dyn_cast<TensorType>()) {
+        // TODO: Infer the result type instead of computing it.
         newResultTypes.push_back(getMemRefType(tensorType, state.getOptions()));
       } else {
         newResultTypes.push_back(type);


        


More information about the Mlir-commits mailing list