[Mlir-commits] [mlir] 6247988 - One-shot-bufferize: fix for inconsistent while arg types in before/after.
Johannes Reifferscheid
llvmlistbot at llvm.org
Thu Sep 8 01:24:23 PDT 2022
Author: Johannes Reifferscheid
Date: 2022-09-08T10:24:11+02:00
New Revision: 6247988e0751422fa10d70e64939c987dd3b81d9
URL: https://github.com/llvm/llvm-project/commit/6247988e0751422fa10d70e64939c987dd3b81d9
DIFF: https://github.com/llvm/llvm-project/commit/6247988e0751422fa10d70e64939c987dd3b81d9.diff
LOG: One-shot-bufferize: fix for inconsistent while arg types in before/after.
Currently, if the `before` and `after` regions of a while op have
tensor args at different indices, bufferization crashes.
This moves the pass-through check for args into the handling of the
condition block: that is where the loop results are produced, so it is
also where copies must be made.
Reviewed By: springerm
Differential Revision: https://reviews.llvm.org/D133477
Added:
mlir/test/Dialect/SCF/one-shot-bufferize-allow-return-allocs-no-deallocs.mlir
Modified:
mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp
mlir/test/Dialect/SCF/one-shot-bufferize-tensor-copy-insertion.mlir
mlir/test/Dialect/SCF/one-shot-bufferize.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp
index fb8c8dd3e2b8e..27be21430b17b 100644
--- a/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -762,13 +762,6 @@ struct WhileOpInterface
OpBuilder::InsertionGuard g(rewriter);
auto whileOp = cast<scf::WhileOp>(op);
auto conditionOp = whileOp.getConditionOp();
- auto yieldOp = whileOp.getYieldOp();
-
- // Indices of all bbArgs that have tensor type. These are the ones that
- // are bufferized. The "before" and "after" regions may have different args.
- DenseSet<int64_t> indicesBefore = getTensorIndices(whileOp.getInits());
- DenseSet<int64_t> indicesAfter =
- getTensorIndices(whileOp.getAfterArguments());
// For every yielded value, is the value equivalent to its corresponding
// bbArg?
@@ -783,8 +776,9 @@ struct WhileOpInterface
for (int64_t idx = 0;
idx < static_cast<int64_t>(conditionOp.getArgs().size()); ++idx) {
Value value = conditionOp.getArgs()[idx];
- if (!indicesBefore.contains(idx) ||
- equivalentYieldsBefore.contains(idx)) {
+ if (!value.getType().isa<TensorType>() ||
+ (equivalentYieldsAfter.contains(idx) &&
+ equivalentYieldsBefore.contains(idx))) {
beforeYieldValues.push_back(value);
continue;
}
@@ -799,27 +793,6 @@ struct WhileOpInterface
conditionOp.getArgsMutable().assign(beforeYieldValues);
});
- // Update "after" region.
- rewriter.setInsertionPoint(yieldOp);
- SmallVector<Value> afterYieldValues;
- for (int64_t idx = 0;
- idx < static_cast<int64_t>(yieldOp.getResults().size()); ++idx) {
- Value value = yieldOp.getResults()[idx];
- if (!indicesAfter.contains(idx) || equivalentYieldsAfter.contains(idx)) {
- afterYieldValues.push_back(value);
- continue;
- }
- FailureOr<Value> alloc =
- allocateTensorForShapedValue(rewriter, yieldOp.getLoc(), value,
- /*escape=*/true, state.getOptions());
- if (failed(alloc))
- return failure();
- afterYieldValues.push_back(*alloc);
- }
- rewriter.updateRootInPlace(yieldOp, [&]() {
- yieldOp.getResultsMutable().assign(afterYieldValues);
- });
-
return success();
}
diff --git a/mlir/test/Dialect/SCF/one-shot-bufferize-allow-return-allocs-no-deallocs.mlir b/mlir/test/Dialect/SCF/one-shot-bufferize-allow-return-allocs-no-deallocs.mlir
new file mode 100644
index 0000000000000..7e894b775fe06
--- /dev/null
+++ b/mlir/test/Dialect/SCF/one-shot-bufferize-allow-return-allocs-no-deallocs.mlir
@@ -0,0 +1,20 @@
+// RUN: mlir-opt %s \
+// RUN: -one-shot-bufferize="allow-return-allocs create-deallocs=0" \
+// RUN: -split-input-file | \
+// RUN: FileCheck %s --dump-input=always
+
+// A regression test to check that different before and after argument types are
+// bufferized successfully.
+func.func @different_before_after_args() -> tensor<f32> {
+ %true = arith.constant true
+ %cst = arith.constant dense<0.0> : tensor<f32>
+ %0 = scf.while (%arg4 = %true) : (i1) -> (tensor<f32>) {
+ scf.condition(%true) %cst : tensor<f32>
+ } do {
+ ^bb0(%arg4: tensor<f32>):
+ scf.yield %true : i1
+ }
+ return %0 : tensor<f32>
+}
+
+// CHECK-LABEL: @different_before_after_args
\ No newline at end of file
diff --git a/mlir/test/Dialect/SCF/one-shot-bufferize-tensor-copy-insertion.mlir b/mlir/test/Dialect/SCF/one-shot-bufferize-tensor-copy-insertion.mlir
index ec0ffa657d876..a5337831a0a50 100644
--- a/mlir/test/Dialect/SCF/one-shot-bufferize-tensor-copy-insertion.mlir
+++ b/mlir/test/Dialect/SCF/one-shot-bufferize-tensor-copy-insertion.mlir
@@ -98,9 +98,7 @@ func.func @scf_while_non_equiv_condition_and_body(%A: tensor<5xi1>,
^bb0(%b0: tensor<5xi1>, %b1: tensor<5xi1>):
// CHECK: } do {
// CHECK: ^bb0(%[[b0:.*]]: tensor<5xi1>, %[[b1:.*]]: tensor<5xi1>):
- // CHECK-DAG: %[[yield2:.*]] = bufferization.alloc_tensor() copy(%[[b1]]) {bufferization.escape = [true]} : tensor<5xi1>
- // CHECK-DAG: %[[yield3:.*]] = bufferization.alloc_tensor() copy(%[[b0]]) {bufferization.escape = [true]} : tensor<5xi1>
- // CHECK: scf.yield %[[yield2]], %[[yield3]]
+ // CHECK: scf.yield %[[b1]], %[[b0]]
// CHECK: }
scf.yield %b1, %b0 : tensor<5xi1>, tensor<5xi1>
}
diff --git a/mlir/test/Dialect/SCF/one-shot-bufferize.mlir b/mlir/test/Dialect/SCF/one-shot-bufferize.mlir
index dab4331a2586f..b37999315208d 100644
--- a/mlir/test/Dialect/SCF/one-shot-bufferize.mlir
+++ b/mlir/test/Dialect/SCF/one-shot-bufferize.mlir
@@ -430,8 +430,8 @@ func.func @scf_while_non_equiv_condition_and_body(%arg0: tensor<5xi1>,
%idx: index)
-> (tensor<5xi1>, tensor<5xi1>)
{
- // CHECK: %[[clone1:.*]] = bufferization.clone %[[arg1]]
- // CHECK: %[[clone0:.*]] = bufferization.clone %[[arg0]]
+ // CHECK-DAG: %[[clone1:.*]] = bufferization.clone %[[arg1]]
+ // CHECK-DAG: %[[clone0:.*]] = bufferization.clone %[[arg0]]
// CHECK: %[[loop:.*]]:2 = scf.while (%[[w0:.*]] = %[[clone0]], %[[w1:.*]] = %[[clone1]]) {{.*}} {
%r0, %r1 = scf.while (%w0 = %arg0, %w1 = %arg1)
: (tensor<5xi1>, tensor<5xi1>) -> (tensor<5xi1>, tensor<5xi1>) {
@@ -454,19 +454,13 @@ func.func @scf_while_non_equiv_condition_and_body(%arg0: tensor<5xi1>,
// CHECK: } do {
// CHECK: ^bb0(%[[b0:.*]]: memref<5xi1>, %[[b1:.*]]: memref<5xi1>):
// CHECK: memref.store %{{.*}}, %[[b0]]
- // CHECK: %[[a3:.*]] = memref.alloc() {{.*}} : memref<5xi1>
- // CHECK: memref.copy %[[b1]], %[[a3]]
+ // CHECK: %[[casted1:.*]] = memref.cast %[[b1]]
+ // CHECK: %[[casted0:.*]] = memref.cast %[[b0]]
+ // CHECK: %[[cloned1:.*]] = bufferization.clone %[[casted1]]
// CHECK: memref.dealloc %[[b1]]
- // CHECK: %[[a2:.*]] = memref.alloc() {{.*}} : memref<5xi1>
- // CHECK: memref.copy %[[b0]], %[[a2]]
+ // CHECK: %[[cloned0:.*]] = bufferization.clone %[[casted0]]
// CHECK: memref.dealloc %[[b0]]
- // CHECK: %[[casted3:.*]] = memref.cast %[[a3]]
- // CHECK: %[[casted2:.*]] = memref.cast %[[a2]]
- // CHECK: %[[cloned2:.*]] = bufferization.clone %[[casted2]]
- // CHECK: memref.dealloc %[[a2]]
- // CHECK: %[[cloned3:.*]] = bufferization.clone %[[casted3]]
- // CHECK: memref.dealloc %[[a3]]
- // CHECK: scf.yield %[[cloned3]], %[[cloned2]]
+ // CHECK: scf.yield %[[cloned1]], %[[cloned0]]
// CHECK: }
%pos = "dummy.some_op"() : () -> (index)
%val = "dummy.another_op"() : () -> (i1)
More information about the Mlir-commits mailing list