[Mlir-commits] [mlir] ae8cb64 - [mlir][scf][bufferize] Fix bug in WhileOp analysis verification

Matthias Springer llvmlistbot at llvm.org
Mon May 15 06:46:29 PDT 2023


Author: Matthias Springer
Date: 2023-05-15T15:42:56+02:00
New Revision: ae8cb6437294ca99ba203607c0dd522db4dbf6b6

URL: https://github.com/llvm/llvm-project/commit/ae8cb6437294ca99ba203607c0dd522db4dbf6b6
DIFF: https://github.com/llvm/llvm-project/commit/ae8cb6437294ca99ba203607c0dd522db4dbf6b6.diff

LOG: [mlir][scf][bufferize] Fix bug in WhileOp analysis verification

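A condition argument or yielded value cannot be equivalent to a block argument that does not exist. When the terminator has more operands than its block has arguments, the verification previously looked up a nonexistent block argument; it now checks the index against the number of block arguments first and reports such operands as not equivalent. This fixes #59442.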

Differential Revision: https://reviews.llvm.org/D145575

Added: 
    

Modified: 
    mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp
    mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize-invalid.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp
index ad395a9ac457b..4b0d0e40740f0 100644
--- a/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -902,10 +902,12 @@ struct WhileOpInterface
 
     auto conditionOp = whileOp.getConditionOp();
     for (const auto &it : llvm::enumerate(conditionOp.getArgs())) {
+      Block *block = conditionOp->getBlock();
       if (!isa<TensorType>(it.value().getType()))
         continue;
-      if (!state.areEquivalentBufferizedValues(
-              it.value(), conditionOp->getBlock()->getArgument(it.index())))
+      if (it.index() >= block->getNumArguments() ||
+          !state.areEquivalentBufferizedValues(it.value(),
+                                               block->getArgument(it.index())))
         return conditionOp->emitError()
                << "Condition arg #" << it.index()
                << " is not equivalent to the corresponding iter bbArg";
@@ -913,10 +915,12 @@ struct WhileOpInterface
 
     auto yieldOp = whileOp.getYieldOp();
     for (const auto &it : llvm::enumerate(yieldOp.getResults())) {
+      Block *block = yieldOp->getBlock();
       if (!isa<TensorType>(it.value().getType()))
         continue;
-      if (!state.areEquivalentBufferizedValues(
-              it.value(), yieldOp->getBlock()->getArgument(it.index())))
+      if (it.index() >= block->getNumArguments() ||
+          !state.areEquivalentBufferizedValues(it.value(),
+                                               block->getArgument(it.index())))
         return yieldOp->emitError()
                << "Yield operand #" << it.index()
                << " is not equivalent to the corresponding iter bbArg";

diff  --git a/mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize-invalid.mlir b/mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize-invalid.mlir
index 189ef6be0dff9..a2d47f08adbae 100644
--- a/mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize-invalid.mlir
+++ b/mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize-invalid.mlir
@@ -324,3 +324,17 @@ func.func @copy_of_unranked_tensor(%t: tensor<*xf32>) -> tensor<*xf32> {
 
 // This function may write to buffer(%ptr).
 func.func private @maybe_writing_func(%ptr : tensor<*xf32>)
+
+// -----
+
+func.func @regression_scf_while() {
+  %false = arith.constant false
+  %8 = bufferization.alloc_tensor() : tensor<10x10xf32>
+  scf.while (%arg0 = %8) : (tensor<10x10xf32>) -> () {
+    scf.condition(%false)
+  } do {
+    // expected-error @+1 {{Yield operand #0 is not equivalent to the corresponding iter bbArg}}
+    scf.yield %8 : tensor<10x10xf32>
+  }
+  return
+}


        


More information about the Mlir-commits mailing list