[Mlir-commits] [mlir] 6247988 - One-shot-bufferize: fix for inconsistent while arg types in before/after.

Johannes Reifferscheid llvmlistbot at llvm.org
Thu Sep 8 01:24:23 PDT 2022


Author: Johannes Reifferscheid
Date: 2022-09-08T10:24:11+02:00
New Revision: 6247988e0751422fa10d70e64939c987dd3b81d9

URL: https://github.com/llvm/llvm-project/commit/6247988e0751422fa10d70e64939c987dd3b81d9
DIFF: https://github.com/llvm/llvm-project/commit/6247988e0751422fa10d70e64939c987dd3b81d9.diff

LOG: One-shot-bufferize: fix for inconsistent while arg types in before/after.

Currently, if the `before` and `after` regions of a while op have
tensor args at different indices, this leads to a crash.

This moves the pass-through check for args to the handling of the
condition block, since that is where the results are produced, so
it's also where copies must be made.

Reviewed By: springerm

Differential Revision: https://reviews.llvm.org/D133477

Added: 
    mlir/test/Dialect/SCF/one-shot-bufferize-allow-return-allocs-no-deallocs.mlir

Modified: 
    mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp
    mlir/test/Dialect/SCF/one-shot-bufferize-tensor-copy-insertion.mlir
    mlir/test/Dialect/SCF/one-shot-bufferize.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp
index fb8c8dd3e2b8e..27be21430b17b 100644
--- a/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -762,13 +762,6 @@ struct WhileOpInterface
     OpBuilder::InsertionGuard g(rewriter);
     auto whileOp = cast<scf::WhileOp>(op);
     auto conditionOp = whileOp.getConditionOp();
-    auto yieldOp = whileOp.getYieldOp();
-
-    // Indices of all bbArgs that have tensor type. These are the ones that
-    // are bufferized. The "before" and "after" regions may have different args.
-    DenseSet<int64_t> indicesBefore = getTensorIndices(whileOp.getInits());
-    DenseSet<int64_t> indicesAfter =
-        getTensorIndices(whileOp.getAfterArguments());
 
     // For every yielded value, is the value equivalent to its corresponding
     // bbArg?
@@ -783,8 +776,9 @@ struct WhileOpInterface
     for (int64_t idx = 0;
          idx < static_cast<int64_t>(conditionOp.getArgs().size()); ++idx) {
       Value value = conditionOp.getArgs()[idx];
-      if (!indicesBefore.contains(idx) ||
-          equivalentYieldsBefore.contains(idx)) {
+      if (!value.getType().isa<TensorType>() ||
+          (equivalentYieldsAfter.contains(idx) &&
+              equivalentYieldsBefore.contains(idx))) {
         beforeYieldValues.push_back(value);
         continue;
       }
@@ -799,27 +793,6 @@ struct WhileOpInterface
       conditionOp.getArgsMutable().assign(beforeYieldValues);
     });
 
-    // Update "after" region.
-    rewriter.setInsertionPoint(yieldOp);
-    SmallVector<Value> afterYieldValues;
-    for (int64_t idx = 0;
-         idx < static_cast<int64_t>(yieldOp.getResults().size()); ++idx) {
-      Value value = yieldOp.getResults()[idx];
-      if (!indicesAfter.contains(idx) || equivalentYieldsAfter.contains(idx)) {
-        afterYieldValues.push_back(value);
-        continue;
-      }
-      FailureOr<Value> alloc =
-          allocateTensorForShapedValue(rewriter, yieldOp.getLoc(), value,
-                                       /*escape=*/true, state.getOptions());
-      if (failed(alloc))
-        return failure();
-      afterYieldValues.push_back(*alloc);
-    }
-    rewriter.updateRootInPlace(yieldOp, [&]() {
-      yieldOp.getResultsMutable().assign(afterYieldValues);
-    });
-
     return success();
   }
 

diff  --git a/mlir/test/Dialect/SCF/one-shot-bufferize-allow-return-allocs-no-deallocs.mlir b/mlir/test/Dialect/SCF/one-shot-bufferize-allow-return-allocs-no-deallocs.mlir
new file mode 100644
index 0000000000000..7e894b775fe06
--- /dev/null
+++ b/mlir/test/Dialect/SCF/one-shot-bufferize-allow-return-allocs-no-deallocs.mlir
@@ -0,0 +1,20 @@
+// RUN: mlir-opt %s \
+// RUN:     -one-shot-bufferize="allow-return-allocs create-deallocs=0" \
+// RUN:     -split-input-file | \
+// RUN: FileCheck %s --dump-input=always
+
+// A regression test to check that different before and after argument types are
+// bufferized successfully.
+func.func @different_before_after_args() -> tensor<f32> {
+  %true = arith.constant true
+  %cst = arith.constant dense<0.0> : tensor<f32>
+  %0 = scf.while (%arg4 = %true) : (i1) -> (tensor<f32>) {
+    scf.condition(%true) %cst : tensor<f32>
+  } do {
+  ^bb0(%arg4: tensor<f32>):
+    scf.yield %true : i1
+  }
+  return %0 : tensor<f32>
+}
+
+// CHECK-LABEL: @different_before_after_args
\ No newline at end of file

diff  --git a/mlir/test/Dialect/SCF/one-shot-bufferize-tensor-copy-insertion.mlir b/mlir/test/Dialect/SCF/one-shot-bufferize-tensor-copy-insertion.mlir
index ec0ffa657d876..a5337831a0a50 100644
--- a/mlir/test/Dialect/SCF/one-shot-bufferize-tensor-copy-insertion.mlir
+++ b/mlir/test/Dialect/SCF/one-shot-bufferize-tensor-copy-insertion.mlir
@@ -98,9 +98,7 @@ func.func @scf_while_non_equiv_condition_and_body(%A: tensor<5xi1>,
   ^bb0(%b0: tensor<5xi1>, %b1: tensor<5xi1>):
     // CHECK: } do {
     // CHECK: ^bb0(%[[b0:.*]]: tensor<5xi1>, %[[b1:.*]]: tensor<5xi1>):
-    // CHECK-DAG: %[[yield2:.*]] = bufferization.alloc_tensor() copy(%[[b1]]) {bufferization.escape = [true]} : tensor<5xi1>
-    // CHECK-DAG: %[[yield3:.*]] = bufferization.alloc_tensor() copy(%[[b0]]) {bufferization.escape = [true]} : tensor<5xi1>
-    // CHECK: scf.yield %[[yield2]], %[[yield3]]
+    // CHECK: scf.yield %[[b1]], %[[b0]]
     // CHECK: }
     scf.yield %b1, %b0 : tensor<5xi1>, tensor<5xi1>
   }

diff  --git a/mlir/test/Dialect/SCF/one-shot-bufferize.mlir b/mlir/test/Dialect/SCF/one-shot-bufferize.mlir
index dab4331a2586f..b37999315208d 100644
--- a/mlir/test/Dialect/SCF/one-shot-bufferize.mlir
+++ b/mlir/test/Dialect/SCF/one-shot-bufferize.mlir
@@ -430,8 +430,8 @@ func.func @scf_while_non_equiv_condition_and_body(%arg0: tensor<5xi1>,
                                                   %idx: index)
   -> (tensor<5xi1>, tensor<5xi1>)
 {
-  // CHECK: %[[clone1:.*]] = bufferization.clone %[[arg1]]
-  // CHECK: %[[clone0:.*]] = bufferization.clone %[[arg0]]
+  // CHECK-DAG: %[[clone1:.*]] = bufferization.clone %[[arg1]]
+  // CHECK-DAG: %[[clone0:.*]] = bufferization.clone %[[arg0]]
   // CHECK: %[[loop:.*]]:2 = scf.while (%[[w0:.*]] = %[[clone0]], %[[w1:.*]] = %[[clone1]]) {{.*}} {
   %r0, %r1 = scf.while (%w0 = %arg0, %w1 = %arg1)
       : (tensor<5xi1>, tensor<5xi1>) -> (tensor<5xi1>, tensor<5xi1>) {
@@ -454,19 +454,13 @@ func.func @scf_while_non_equiv_condition_and_body(%arg0: tensor<5xi1>,
     // CHECK: } do {
     // CHECK: ^bb0(%[[b0:.*]]: memref<5xi1>, %[[b1:.*]]: memref<5xi1>):
     // CHECK: memref.store %{{.*}}, %[[b0]]
-    // CHECK: %[[a3:.*]] = memref.alloc() {{.*}} : memref<5xi1>
-    // CHECK: memref.copy %[[b1]], %[[a3]]
+    // CHECK: %[[casted1:.*]] = memref.cast %[[b1]]
+    // CHECK: %[[casted0:.*]] = memref.cast %[[b0]]
+    // CHECK: %[[cloned1:.*]] = bufferization.clone %[[casted1]]
     // CHECK: memref.dealloc %[[b1]]
-    // CHECK: %[[a2:.*]] = memref.alloc() {{.*}} : memref<5xi1>
-    // CHECK: memref.copy %[[b0]], %[[a2]]
+    // CHECK: %[[cloned0:.*]] = bufferization.clone %[[casted0]]
     // CHECK: memref.dealloc %[[b0]]
-    // CHECK: %[[casted3:.*]] = memref.cast %[[a3]]
-    // CHECK: %[[casted2:.*]] = memref.cast %[[a2]]
-    // CHECK: %[[cloned2:.*]] = bufferization.clone %[[casted2]]
-    // CHECK: memref.dealloc %[[a2]]
-    // CHECK: %[[cloned3:.*]] = bufferization.clone %[[casted3]]
-    // CHECK: memref.dealloc %[[a3]]
-    // CHECK: scf.yield %[[cloned3]], %[[cloned2]]
+    // CHECK: scf.yield %[[cloned1]], %[[cloned0]]
     // CHECK: }
     %pos = "dummy.some_op"() : () -> (index)
     %val = "dummy.another_op"() : () -> (i1)


        


More information about the Mlir-commits mailing list