[Mlir-commits] [mlir] b6ae3f8 - [mlir][tensor][bufferize] Implement getBufferType for CastOp

Matthias Springer llvmlistbot at llvm.org
Wed Feb 1 05:24:27 PST 2023


Author: Matthias Springer
Date: 2023-02-01T14:24:10+01:00
New Revision: b6ae3f88731c8a82668bd5a992a5ae9b41e716a9

URL: https://github.com/llvm/llvm-project/commit/b6ae3f88731c8a82668bd5a992a5ae9b41e716a9
DIFF: https://github.com/llvm/llvm-project/commit/b6ae3f88731c8a82668bd5a992a5ae9b41e716a9.diff

LOG: [mlir][tensor][bufferize] Implement getBufferType for CastOp

This interface method is used to compute the buffer type of a value during bufferization. It was missing. This interface method is used during loop bufferization.

Also fix a bug where a cast from an unranked tensor to a ranked tensor type did not always apply a fully dynamic layout map on the result memref.

Differential Revision: https://reviews.llvm.org/D143063

Added: 
    

Modified: 
    mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
    mlir/test/Dialect/SCF/one-shot-bufferize.mlir
    mlir/test/Dialect/Tensor/bufferize.mlir
    mlir/test/Dialect/Tensor/one-shot-bufferize.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
index 9c5b7d520ba1d..14401a98c8e2b 100644
--- a/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -51,6 +51,39 @@ struct CastOpInterface
     return BufferRelation::Equivalent;
   }
 
+  FailureOr<BaseMemRefType>
+  getBufferType(Operation *op, Value value, const BufferizationOptions &options,
+                const DenseMap<Value, BaseMemRefType> &fixedTypes) const {
+    auto castOp = cast<tensor::CastOp>(op);
+    auto maybeSrcBufferType =
+        bufferization::getBufferType(castOp.getSource(), options, fixedTypes);
+    if (failed(maybeSrcBufferType))
+      return failure();
+    Attribute memorySpace = maybeSrcBufferType->getMemorySpace();
+
+    // Note: `getMemRefTypeWithFullyDynamicLayout` returns an unranked memref
+    // type in case the input is an unranked tensor type.
+
+    // Case 1: Casting an unranked tensor
+    if (castOp.getSource().getType().isa<UnrankedTensorType>()) {
+      // When casting to a ranked tensor, we cannot infer any static offset or
+      // strides from the source. Assume fully dynamic.
+      return getMemRefTypeWithFullyDynamicLayout(castOp.getType(), memorySpace);
+    }
+
+    // Case 2: Casting to an unranked tensor type
+    if (castOp.getType().isa<UnrankedTensorType>()) {
+      return getMemRefTypeWithFullyDynamicLayout(castOp.getType(), memorySpace);
+    }
+
+    // Case 3: Ranked tensor -> ranked tensor. The offsets and strides do not
+    // change.
+    auto rankedResultType = castOp.getType().cast<RankedTensorType>();
+    return MemRefType::get(
+        rankedResultType.getShape(), rankedResultType.getElementType(),
+        maybeSrcBufferType->cast<MemRefType>().getLayout(), memorySpace);
+  }
+
   LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                           const BufferizationOptions &options) const {
     auto castOp = cast<tensor::CastOp>(op);
@@ -60,25 +93,19 @@ struct CastOpInterface
         getBuffer(rewriter, castOp.getSource(), options);
     if (failed(resultBuffer))
       return failure();
-    auto sourceMemRefType = resultBuffer->getType().cast<BaseMemRefType>();
-    TensorType resultTensorType =
-        castOp.getResult().getType().cast<TensorType>();
-    MemRefLayoutAttrInterface layout;
 
-    if (auto rankedMemRefType = sourceMemRefType.dyn_cast<MemRefType>())
-      if (resultTensorType.isa<RankedTensorType>())
-        layout = rankedMemRefType.getLayout();
-
-    // Compute the new memref type.
-    Type resultMemRefType = getMemRefType(castOp.getResult(), options, layout,
-                                          sourceMemRefType.getMemorySpace());
+    // Compute the new type.
+    auto resultMemRefType =
+        bufferization::getBufferType(castOp.getResult(), options);
+    if (failed(resultMemRefType))
+      return failure();
 
     // Replace the op with a memref.cast.
     assert(memref::CastOp::areCastCompatible(resultBuffer->getType(),
-                                             resultMemRefType) &&
+                                             *resultMemRefType) &&
            "CallOp::bufferize: cast incompatible");
-    replaceOpWithNewBufferizedOp<memref::CastOp>(rewriter, op, resultMemRefType,
-                                                 *resultBuffer);
+    replaceOpWithNewBufferizedOp<memref::CastOp>(
+        rewriter, op, *resultMemRefType, *resultBuffer);
 
     return success();
   }

diff  --git a/mlir/test/Dialect/SCF/one-shot-bufferize.mlir b/mlir/test/Dialect/SCF/one-shot-bufferize.mlir
index 90c88d86a11c1..587eed843c71f 100644
--- a/mlir/test/Dialect/SCF/one-shot-bufferize.mlir
+++ b/mlir/test/Dialect/SCF/one-shot-bufferize.mlir
@@ -925,3 +925,26 @@ func.func @non_block_argument_yield() {
   }
   return
 }
+
+// -----
+
+// This is a regression test. Make sure that bufferization succeeds.
+
+// CHECK-LABEL: func @regression_cast_in_loop(
+func.func @regression_cast_in_loop() -> tensor<2xindex> {
+  %false = arith.constant false
+  %c0 = arith.constant 0 : index
+  %0 = bufferization.alloc_tensor() : tensor<2xindex>
+  // CHECK: scf.while (%{{.*}} = %{{.*}}) : (memref<2xindex>) -> memref<2xindex>
+  %1 = scf.while (%arg0 = %0) : (tensor<2xindex>) -> tensor<2xindex> {
+    scf.condition(%false) %arg0 : tensor<2xindex>
+  } do {
+  // CHECK: ^bb0(%{{.*}}: memref<2xindex>):
+  ^bb0(%arg0: tensor<2xindex>):
+    %cast = tensor.cast %0 : tensor<2xindex> to tensor<?xindex>
+    %inserted = tensor.insert %c0 into %cast[%c0] : tensor<?xindex>
+    %cast_0 = tensor.cast %inserted : tensor<?xindex> to tensor<2xindex>
+    scf.yield %cast_0 : tensor<2xindex>
+  }
+  return %1 : tensor<2xindex>
+}

diff  --git a/mlir/test/Dialect/Tensor/bufferize.mlir b/mlir/test/Dialect/Tensor/bufferize.mlir
index 4948b0dccf976..cbcc1e3d339b6 100644
--- a/mlir/test/Dialect/Tensor/bufferize.mlir
+++ b/mlir/test/Dialect/Tensor/bufferize.mlir
@@ -40,8 +40,8 @@ func.func @tensor.cast(%arg0: tensor<?xindex>) -> tensor<2xindex> {
 // CHECK-LABEL:   func @tensor.cast_from_unranked(
 // CHECK-SAME:                                    %[[TENSOR:.*]]: tensor<*xf32>) -> tensor<2xf32> {
 // CHECK:           %[[MEMREF:.*]] = bufferization.to_memref %[[TENSOR]] : memref<*xf32>
-// CHECK:           %[[CASTED_MEMREF:.*]] = memref.cast %[[MEMREF]] : memref<*xf32> to memref<2xf32>
-// CHECK:           %[[RET:.*]] = bufferization.to_tensor %[[CASTED_MEMREF]] : memref<2xf32>
+// CHECK:           %[[CASTED_MEMREF:.*]] = memref.cast %[[MEMREF]] : memref<*xf32> to memref<2xf32, strided<[?], offset: ?>>
+// CHECK:           %[[RET:.*]] = bufferization.to_tensor %[[CASTED_MEMREF]] : memref<2xf32, strided<[?], offset: ?>>
 // CHECK:           return %[[RET]] : tensor<2xf32>
 func.func @tensor.cast_from_unranked(%arg0: tensor<*xf32>) -> tensor<2xf32> {
   %0 = tensor.cast %arg0 : tensor<*xf32> to tensor<2xf32>

diff  --git a/mlir/test/Dialect/Tensor/one-shot-bufferize.mlir b/mlir/test/Dialect/Tensor/one-shot-bufferize.mlir
index 59fde562bd876..25164a4ba870c 100644
--- a/mlir/test/Dialect/Tensor/one-shot-bufferize.mlir
+++ b/mlir/test/Dialect/Tensor/one-shot-bufferize.mlir
@@ -347,3 +347,26 @@ func.func @dim_not_reading(%t: tensor<?xf32>, %f: f32, %pos: index)
   %1 = tensor.dim %t, %c0 : tensor<?xf32>
   return %0, %1 : tensor<?xf32>, index
 }
+
+// -----
+
+//       CHECK: #[[$map:.*]] = affine_map<(d0) -> (d0 + 5)>
+// CHECK-LABEL: func.func @cast_retains_buffer_layout(
+//  CHECK-SAME:     %[[t:.*]]: memref<?xf32, #[[$map]]>, %[[sz:.*]]: index) -> memref<?xf32, strided<[1], offset: 7>> {
+//       CHECK:   %[[casted:.*]] = memref.cast %[[t]] : memref<?xf32, #[[$map]]> to memref<10xf32, #[[$map]]>
+//       CHECK:   %[[slice:.*]] = memref.subview %[[casted]][2] [%[[sz]]] [1] : memref<10xf32, #[[$map]]> to memref<?xf32, strided<[1], offset: 7>>
+//       CHECK:   return %[[slice]]
+func.func @cast_retains_buffer_layout(
+    %t: tensor<?xf32>
+        {bufferization.buffer_layout = affine_map<(d0) -> (d0 + 5)>},
+    %sz: index)
+  -> (tensor<10xf32>, tensor<?xf32>)
+{
+  %casted = tensor.cast %t : tensor<?xf32> to tensor<10xf32>
+  %slice = tensor.extract_slice %casted[2][%sz][1] : tensor<10xf32> to tensor<?xf32>
+
+  // Note: The %casted return type is folded away because both buffers are
+  // equivalent. Therefore, we currently lose some static type information
+  // in the caller.
+  return %casted, %slice : tensor<10xf32>, tensor<?xf32>
+}


        


More information about the Mlir-commits mailing list