[Mlir-commits] [mlir] eaf20c4 - [mlir] Fix a cast that should be a dyn_cast.

Johannes Reifferscheid llvmlistbot at llvm.org
Thu Sep 22 04:13:33 PDT 2022


Author: Johannes Reifferscheid
Date: 2022-09-22T13:13:21+02:00
New Revision: eaf20c4fc257db0bcbd97b0f39836a53eeb3039a

URL: https://github.com/llvm/llvm-project/commit/eaf20c4fc257db0bcbd97b0f39836a53eeb3039a
DIFF: https://github.com/llvm/llvm-project/commit/eaf20c4fc257db0bcbd97b0f39836a53eeb3039a.diff

LOG: [mlir] Fix a cast that should be a dyn_cast.

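The scf.while bufferization interface used cast<BlockArgument>, which crashed
when the queried value was not a block argument. This fixes that crash; see the
new test case for an example.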

Reviewed By: ftynse

Differential Revision: https://reviews.llvm.org/D134424

Added: 
    

Modified: 
    mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp
    mlir/test/Dialect/SCF/one-shot-bufferize.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp
index b782d1be4bc82..2f1e1a837b890 100644
--- a/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -891,7 +891,7 @@ struct WhileOpInterface
     assert(value.getType().isa<TensorType>() && "expected tensor type");
 
     // Case 1: Block argument of the "before" region.
-    if (auto bbArg = value.cast<BlockArgument>()) {
+    if (auto bbArg = value.dyn_cast<BlockArgument>()) {
       if (bbArg.getOwner()->getParent() == &whileOp.getBefore()) {
         Value initArg = whileOp.getInits()[bbArg.getArgNumber()];
         auto yieldOp = whileOp.getYieldOp();

diff  --git a/mlir/test/Dialect/SCF/one-shot-bufferize.mlir b/mlir/test/Dialect/SCF/one-shot-bufferize.mlir
index 7a15d9b43c58b..6d3dbeee1c8e3 100644
--- a/mlir/test/Dialect/SCF/one-shot-bufferize.mlir
+++ b/mlir/test/Dialect/SCF/one-shot-bufferize.mlir
@@ -876,3 +876,26 @@ func.func @buffer_type_of_collapse_shape(%arg0: tensor<f64>) {
   }
   return
 }
+
+// -----
+
+// Regression test: scf.while "after" regions yielding non-block-argument
+// values (%0, %ret) used to crash. Just check that the IR bufferizes.
+// CHECK-LABEL: func @non_block_argument_yield
+func.func @non_block_argument_yield() {
+  %true = arith.constant true 
+  %0 = bufferization.alloc_tensor() : tensor<i32>
+  %1 = scf.while (%arg0 = %0) : (tensor<i32>) -> (tensor<i32>) {
+    scf.condition(%true) %arg0 : tensor<i32>
+  } do {
+  ^bb0(%arg0: tensor<i32>):
+    %ret = scf.while (%arg1 = %0) : (tensor<i32>) -> (tensor<i32>) {
+      scf.condition(%true) %arg1 : tensor<i32>
+    } do {
+    ^bb0(%arg7: tensor<i32>):
+      scf.yield %0 : tensor<i32>
+    }
+    scf.yield %ret : tensor<i32>
+  }
+  return
+}


        


More information about the Mlir-commits mailing list