[Mlir-commits] [mlir] 48ca8e9 - [mlir][arith][bufferize][NFC] Move buffer type computation to getBufferType

Matthias Springer llvmlistbot at llvm.org
Tue Aug 30 07:45:11 PDT 2022


Author: Matthias Springer
Date: 2022-08-30T16:43:22+02:00
New Revision: 48ca8e955521b9ea7c1466401444a68c6435646d

URL: https://github.com/llvm/llvm-project/commit/48ca8e955521b9ea7c1466401444a68c6435646d
DIFF: https://github.com/llvm/llvm-project/commit/48ca8e955521b9ea7c1466401444a68c6435646d.diff

LOG: [mlir][arith][bufferize][NFC] Move buffer type computation to getBufferType

A part of the functionality of `bufferize` is extracted into `getBufferType`.

Differential Revision: https://reviews.llvm.org/D132861

Added: 
    

Modified: 
    mlir/lib/Dialect/Arithmetic/Transforms/BufferizableOpInterfaceImpl.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/Arithmetic/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Arithmetic/Transforms/BufferizableOpInterfaceImpl.cpp
index e33ea753f011c..8398740ac3b6c 100644
--- a/mlir/lib/Dialect/Arithmetic/Transforms/BufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Arithmetic/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -155,27 +155,19 @@ struct SelectOpInterface
       return failure();
     Value trueBuffer = *maybeTrueBuffer;
     Value falseBuffer = *maybeFalseBuffer;
-    BaseMemRefType trueType = trueBuffer.getType().cast<BaseMemRefType>();
-    BaseMemRefType falseType = falseBuffer.getType().cast<BaseMemRefType>();
-    if (trueType.getMemorySpaceAsInt() != falseType.getMemorySpaceAsInt())
-      return op->emitError("inconsistent memory space on true/false operands");
 
     // The "true" and the "false" operands must have the same type. If the
     // buffers have different types, they differ only in their layout map. Cast
     // both of them to the most dynamic MemRef type.
     if (trueBuffer.getType() != falseBuffer.getType()) {
-      auto trueType = trueBuffer.getType().cast<MemRefType>();
-      int64_t dynamicOffset = ShapedType::kDynamicStrideOrOffset;
-      SmallVector<int64_t> dynamicStrides(trueType.getRank(),
-                                          ShapedType::kDynamicStrideOrOffset);
-      AffineMap stridedLayout = makeStridedLinearLayoutMap(
-          dynamicStrides, dynamicOffset, op->getContext());
-      auto castedType =
-          MemRefType::get(trueType.getShape(), trueType.getElementType(),
-                          stridedLayout, trueType.getMemorySpaceAsInt());
-      trueBuffer = rewriter.create<memref::CastOp>(loc, castedType, trueBuffer);
+      auto targetType =
+          bufferization::getBufferType(selectOp.getResult(), options);
+      if (failed(targetType))
+        return failure();
+      trueBuffer =
+          rewriter.create<memref::CastOp>(loc, *targetType, trueBuffer);
       falseBuffer =
-          rewriter.create<memref::CastOp>(loc, castedType, falseBuffer);
+          rewriter.create<memref::CastOp>(loc, *targetType, falseBuffer);
     }
 
     replaceOpWithNewBufferizedOp<arith::SelectOp>(
@@ -183,6 +175,31 @@ struct SelectOpInterface
     return success();
   }
 
+  /// Compute the buffer type of the result of an `arith.select` op.
+  ///
+  /// The buffer types of the "true" and "false" operands are queried via
+  /// `bufferization::getBufferType` (passing `fixedTypes` through). If either
+  /// query fails, failure is returned. If both operand buffer types are equal,
+  /// that type is returned. If their memory spaces differ, an error is emitted
+  /// on `op`. Otherwise, a MemRef type with a fully dynamic layout map is
+  /// returned, keeping the shape, element type and memory space of the "true"
+  /// operand's buffer type.
+  FailureOr<BaseMemRefType>
+  getBufferType(Operation *op, Value value, const BufferizationOptions &options,
+                const DenseMap<Value, BaseMemRefType> &fixedTypes) const {
+    auto selectOp = cast<arith::SelectOp>(op);
+    // Only the result of the select op may be queried.
+    assert(value == selectOp.getResult() && "invalid value");
+    auto trueType = bufferization::getBufferType(selectOp.getTrueValue(),
+                                                 options, fixedTypes);
+    auto falseType = bufferization::getBufferType(selectOp.getFalseValue(),
+                                                  options, fixedTypes);
+    if (failed(trueType) || failed(falseType))
+      return failure();
+    // Both operands bufferize to the same type: the result uses that type.
+    if (*trueType == *falseType)
+      return *trueType;
+    // Differing memory spaces are reported as an error, not reconciled.
+    if (trueType->getMemorySpaceAsInt() != falseType->getMemorySpaceAsInt())
+      return op->emitError("inconsistent memory space on true/false operands");
+
+    // If the buffers have different types, they differ only in their layout
+    // map.
+    // NOTE(review): `cast<MemRefType>` asserts that the buffer type is ranked;
+    // presumably differing unranked buffer types cannot reach this point.
+    auto memrefType = trueType->cast<MemRefType>();
+    return getMemRefTypeWithFullyDynamicLayout(
+        RankedTensorType::get(memrefType.getShape(),
+                              memrefType.getElementType()),
+        memrefType.getMemorySpaceAsInt());
+  }
+
   BufferRelation bufferRelation(Operation *op, OpResult opResult,
                                 const AnalysisState &state) const {
     return BufferRelation::None;


        


More information about the Mlir-commits mailing list