[Mlir-commits] [mlir] 6cd7b65 - [mlir][bufferization] Prevent crash in one shot bufferization with unranked tensor cast

Kai Sasaki llvmlistbot at llvm.org
Thu May 18 16:54:51 PDT 2023


Author: Kai Sasaki
Date: 2023-05-19T08:54:43+09:00
New Revision: 6cd7b655d83047565f03cfbb0bdf70cae3acf1c2

URL: https://github.com/llvm/llvm-project/commit/6cd7b655d83047565f03cfbb0bdf70cae3acf1c2
DIFF: https://github.com/llvm/llvm-project/commit/6cd7b655d83047565f03cfbb0bdf70cae3acf1c2.diff

LOG: [mlir][bufferization] Prevent crash in one shot bufferization with unranked tensor cast

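One-Shot Bufferize does not support bufferizing a cast between two unranked tensors. Such a cast would lower to a `memref.cast` from one unranked memref to another, which `memref::CastOp::areCastCompatible` rejects, so the assertion in `CastOpInterface::bufferize` fails. To prevent the crash, compare the result buffer type with the existing buffer type before asserting. If they are identical, the cast is a no-op and the op is replaced with the existing buffer. Reported in https://github.com/llvm/llvm-project/issues/62369.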

Reviewed By: springerm

Differential Revision: https://reviews.llvm.org/D149239

Modified: 
    mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
    mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize.mlir


################################################################################
diff  --git a/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
index d93d88630fd86..9253bc2ffeb7e 100644
--- a/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -94,6 +94,11 @@ struct CastOpInterface
         bufferization::getBufferType(castOp.getResult(), options);
     if (failed(resultMemRefType))
       return failure();
+    if (resultBuffer->getType() == *resultMemRefType) {
+      // This cast is a no-op.
+      replaceOpWithBufferizedValues(rewriter, op, *resultBuffer);
+      return success();
+    }
 
     // Replace the op with a memref.cast.
     assert(memref::CastOp::areCastCompatible(resultBuffer->getType(),

diff  --git a/mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize.mlir b/mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize.mlir
index e980b2b777654..acefa14db487f 100644
--- a/mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize.mlir
+++ b/mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize.mlir
@@ -199,3 +199,12 @@ func.func @read_of_alias(%t: tensor<100xf32>, %pos1: index, %pos2: index,
   %3 = tensor.extract %0[%pos3] : tensor<100xf32>
   return %2, %3 : f32, f32
 }
+
+// -----
+
+// CHECK-LABEL: func @from_unranked_to_unranked
+func.func @from_unranked_to_unranked(%arg0: tensor<*xi32>) -> tensor<*xi32> {
+  // CHECK: return %arg{{.*}} : tensor<*xi32>
+  %0 = tensor.cast %arg0 : tensor<*xi32> to tensor<*xi32>
+  return %0 : tensor<*xi32>
+}

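The control flow of the first hunk can be illustrated with a small standalone model. This is a sketch, not MLIR code. `ToyMemRefType`, `toyAreCastCompatible`, `CastLowering`, and `toyBufferizeCast` are hypothetical names introduced only for illustration. The compatibility rule is a simplified reading of `memref::CastOp::areCastCompatible`: layouts and memory spaces are ignored, and unranked-to-unranked casts are rejected. The sketch shows why checking type equality first keeps `tensor<*xi32>` to `tensor<*xi32>` away from the assertion.

```cpp
#include <cassert>
#include <cstddef>
#include <string>
#include <vector>

// Hypothetical stand-in for a memref type (NOT mlir::MemRefType).
struct ToyMemRefType {
  bool ranked;              // false models an unranked memref such as memref<*xi32>
  std::vector<long> shape;  // ignored when !ranked; -1 models a dynamic dimension
  std::string elementType;  // e.g. "i32"

  bool operator==(const ToyMemRefType &other) const {
    return ranked == other.ranked && elementType == other.elementType &&
           (!ranked || shape == other.shape);
  }
};

// Simplified model of memref cast compatibility (layout and memory space ignored).
bool toyAreCastCompatible(const ToyMemRefType &a, const ToyMemRefType &b) {
  if (a.elementType != b.elementType) return false;
  if (!a.ranked && !b.ranked) return false;  // unranked -> unranked is unsupported
  if (a.ranked && b.ranked) {
    if (a.shape.size() != b.shape.size()) return false;
    for (std::size_t i = 0; i < a.shape.size(); ++i)
      if (a.shape[i] != -1 && b.shape[i] != -1 && a.shape[i] != b.shape[i])
        return false;  // conflicting static dimensions
  }
  return true;  // ranked <-> unranked, or compatible ranked shapes
}

enum class CastLowering { NoOp, MemRefCast };

// Mirrors the patched ordering: identical types short-circuit before the assertion.
CastLowering toyBufferizeCast(const ToyMemRefType &sourceBuffer,
                              const ToyMemRefType &resultBuffer) {
  if (sourceBuffer == resultBuffer)
    return CastLowering::NoOp;  // replace the op with the existing buffer
  // Without the check above, unranked -> unranked reached this assertion and aborted.
  assert(toyAreCastCompatible(sourceBuffer, resultBuffer));
  return CastLowering::MemRefCast;
}
```

For example, `toyBufferizeCast(ToyMemRefType{false, {}, "i32"}, ToyMemRefType{false, {}, "i32"})` yields `CastLowering::NoOp`, matching the new `@from_unranked_to_unranked` test, in which the function simply returns its argument. Casts between a ranked and an unranked type still take the `memref.cast` path.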