[Mlir-commits] [mlir] eaf20c4 - [mlir] Fix a cast that should be a dyn_cast.
Johannes Reifferscheid
llvmlistbot at llvm.org
Thu Sep 22 04:13:33 PDT 2022
Author: Johannes Reifferscheid
Date: 2022-09-22T13:13:21+02:00
New Revision: eaf20c4fc257db0bcbd97b0f39836a53eeb3039a
URL: https://github.com/llvm/llvm-project/commit/eaf20c4fc257db0bcbd97b0f39836a53eeb3039a
DIFF: https://github.com/llvm/llvm-project/commit/eaf20c4fc257db0bcbd97b0f39836a53eeb3039a.diff
LOG: [mlir] Fix a cast that should be a dyn_cast.
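This fixes a crash for certain IR, where the value being analyzed is not a
block argument (for example, the result of a nested scf.while that is then
yielded); see the new test case for an example.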
Reviewed By: ftynse
Differential Revision: https://reviews.llvm.org/D134424
Added:
Modified:
mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp
mlir/test/Dialect/SCF/one-shot-bufferize.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp
index b782d1be4bc82..2f1e1a837b890 100644
--- a/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -891,7 +891,7 @@ struct WhileOpInterface
assert(value.getType().isa<TensorType>() && "expected tensor type");
// Case 1: Block argument of the "before" region.
- if (auto bbArg = value.cast<BlockArgument>()) {
+ if (auto bbArg = value.dyn_cast<BlockArgument>()) {
if (bbArg.getOwner()->getParent() == &whileOp.getBefore()) {
Value initArg = whileOp.getInits()[bbArg.getArgNumber()];
auto yieldOp = whileOp.getYieldOp();
diff --git a/mlir/test/Dialect/SCF/one-shot-bufferize.mlir b/mlir/test/Dialect/SCF/one-shot-bufferize.mlir
index 7a15d9b43c58b..6d3dbeee1c8e3 100644
--- a/mlir/test/Dialect/SCF/one-shot-bufferize.mlir
+++ b/mlir/test/Dialect/SCF/one-shot-bufferize.mlir
@@ -876,3 +876,26 @@ func.func @buffer_type_of_collapse_shape(%arg0: tensor<f64>) {
}
return
}
+
+// -----
+
+// This is a regression test. Just check that the IR bufferizes.
+
+// CHECK-LABEL: func @non_block_argument_yield
+func.func @non_block_argument_yield() {
+ %true = arith.constant true
+ %0 = bufferization.alloc_tensor() : tensor<i32>
+ %1 = scf.while (%arg0 = %0) : (tensor<i32>) -> (tensor<i32>) {
+ scf.condition(%true) %arg0 : tensor<i32>
+ } do {
+ ^bb0(%arg0: tensor<i32>):
+ %ret = scf.while (%arg1 = %0) : (tensor<i32>) -> (tensor<i32>) {
+ scf.condition(%true) %arg1 : tensor<i32>
+ } do {
+ ^bb0(%arg7: tensor<i32>):
+ scf.yield %0 : tensor<i32>
+ }
+ scf.yield %ret : tensor<i32>
+ }
+ return
+}
More information about the Mlir-commits
mailing list