[Mlir-commits] [mlir] [mlir][bufferization] Ownership dealloc: support `IsolatedFromAbove` (PR #97669)
Nikhil Kalra
llvmlistbot at llvm.org
Thu Jul 4 14:45:37 PDT 2024
https://github.com/nikalra updated https://github.com/llvm/llvm-project/pull/97669
>From 103d76069cdc7d92a6fdf10954c3624d9c539109 Mon Sep 17 00:00:00 2001
From: Nikhil Kalra <nkalra at apple.com>
Date: Wed, 3 Jul 2024 20:37:02 -0700
Subject: [PATCH 1/2] [mlir][bufferization] Ownership dealloc: support
`IsolatedFromAbove`
Handle `IsolatedFromAbove` operations in `ownership-based-buffer-deallocation` by applying the same contract as function boundaries: an op marked `IsolatedFromAbove` does not take ownership of the MemRef arguments of its entry block and instead relies on the caller to deallocate them.
---
.../Transforms/BufferViewFlowAnalysis.cpp | 5 +
.../OwnershipBasedBufferDeallocation.cpp | 5 +-
.../dealloc-isolated-group.mlir | 125 ++++++++++++++++++
mlir/test/lib/Dialect/Test/TestOpDefs.cpp | 12 ++
mlir/test/lib/Dialect/Test/TestOps.td | 19 ++-
5 files changed, 163 insertions(+), 3 deletions(-)
create mode 100644 mlir/test/Dialect/Bufferization/Transforms/OwnershipBasedBufferDeallocation/dealloc-isolated-group.mlir
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/BufferViewFlowAnalysis.cpp b/mlir/lib/Dialect/Bufferization/Transforms/BufferViewFlowAnalysis.cpp
index 72f47b8b468ea..d9525cb640e1c 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/BufferViewFlowAnalysis.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/BufferViewFlowAnalysis.cpp
@@ -174,6 +174,11 @@ void BufferViewFlowAnalysis::build(Operation *op) {
}
}
+ if (op->hasTrait<OpTrait::IsIsolatedFromAbove>()) {
+ // Mark the entry block arguments and results as terminal.
+ populateTerminalValues(op);
+ }
+
return WalkResult::advance();
}
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/OwnershipBasedBufferDeallocation.cpp b/mlir/lib/Dialect/Bufferization/Transforms/OwnershipBasedBufferDeallocation.cpp
index ca5d0688b5b59..a52906174fb07 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/OwnershipBasedBufferDeallocation.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/OwnershipBasedBufferDeallocation.cpp
@@ -640,8 +640,9 @@ LogicalResult BufferDeallocation::deallocate(Block *block) {
continue;
// Adhere to function boundary ABI: no ownership of function argument
- // MemRefs is taken.
- if (isa<FunctionOpInterface>(block->getParentOp()) &&
+ // MemRefs is taken. Likewise for ops marked IsolatedFromAbove.
+ if ((isa<FunctionOpInterface>(block->getParentOp()) ||
+ block->getParentOp()->hasTrait<OpTrait::IsIsolatedFromAbove>()) &&
block->isEntryBlock()) {
Value newArg = buildBoolValue(builder, arg.getLoc(), false);
state.updateOwnership(arg, newArg);
diff --git a/mlir/test/Dialect/Bufferization/Transforms/OwnershipBasedBufferDeallocation/dealloc-isolated-group.mlir b/mlir/test/Dialect/Bufferization/Transforms/OwnershipBasedBufferDeallocation/dealloc-isolated-group.mlir
new file mode 100644
index 0000000000000..0c3ceda5237cc
--- /dev/null
+++ b/mlir/test/Dialect/Bufferization/Transforms/OwnershipBasedBufferDeallocation/dealloc-isolated-group.mlir
@@ -0,0 +1,125 @@
+// RUN: mlir-opt -verify-diagnostics -ownership-based-buffer-deallocation \
+// RUN: -buffer-deallocation-simplification -split-input-file %s | FileCheck %s
+
+func.func @function_call() {
+ %alloc = memref.alloc() : memref<f64>
+ %alloc2 = memref.alloc() : memref<f64>
+ %ret = test.isolated_one_region_with_recursive_memory_effects %alloc {
+ ^bb0(%arg1: memref<f64>):
+ test.region_yield %arg1 : memref<f64>
+ } : (memref<f64>) -> memref<f64>
+ test.copy(%ret, %alloc2) : (memref<f64>, memref<f64>)
+ return
+}
+
+// CHECK-LABEL: func @function_call()
+// CHECK: [[ALLOC0:%.+]] = memref.alloc(
+// CHECK-NEXT: [[ALLOC1:%.+]] = memref.alloc(
+// CHECK-NEXT: [[RET:%.+]]:2 = test.isolated_one_region_with_recursive_memory_effects [[ALLOC0]]
+// CHECK-NEXT: ^bb0([[ARG:%.+]]: memref<f64>)
+// CHECK: test.region_yield [[ARG]]
+// CHECK-NOT: bufferization.dealloc
+// CHECK: }
+// CHECK-NEXT: test.copy
+// CHECK-NEXT: [[BASE:%[a-zA-Z0-9_]+]]{{.*}} = memref.extract_strided_metadata [[RET]]#0
+// CHECK-NEXT: bufferization.dealloc ([[ALLOC0]], [[ALLOC1]], [[BASE]] :{{.*}}) if (%true, %true, [[RET]]#1)
+
+// -----
+
+func.func @function_call_requries_merged_ownership_mid_block(%arg0: i1) {
+ %alloc = memref.alloc() : memref<f64>
+ %alloc2 = memref.alloca() : memref<f64>
+ %0 = arith.select %arg0, %alloc, %alloc2 : memref<f64>
+ %ret = test.isolated_one_region_with_recursive_memory_effects %0 {
+ ^bb0(%arg1: memref<f64>):
+ test.region_yield %arg1 : memref<f64>
+ } : (memref<f64>) -> (memref<f64>)
+ test.copy(%ret, %alloc) : (memref<f64>, memref<f64>)
+ return
+}
+
+// CHECK-LABEL: func @function_call_requries_merged_ownership_mid_block
+// CHECK: [[ALLOC0:%.+]] = memref.alloc(
+// CHECK-NEXT: [[ALLOC1:%.+]] = memref.alloca(
+// CHECK-NEXT: [[SELECT:%.+]] = arith.select{{.*}}[[ALLOC0]], [[ALLOC1]]
+// CHECK-NEXT: [[RET:%.+]]:2 = test.isolated_one_region_with_recursive_memory_effects [[SELECT]]
+// CHECK-NEXT: ^bb0([[ARG:%.+]]: memref<f64>)
+// CHECK: test.region_yield [[ARG]]
+// CHECK-NOT: bufferization.dealloc
+// CHECK: }
+// CHECK-NEXT: test.copy
+// CHECK-NEXT: [[BASE:%[a-zA-Z0-9_]+]]{{.*}} = memref.extract_strided_metadata [[RET]]#0
+// CHECK-NEXT: bufferization.dealloc ([[ALLOC0]], [[BASE]] :
+// CHECK-SAME: if (%true, [[RET]]#1)
+// CHECK-NOT: retain
+// CHECK-NEXT: return
+
+// -----
+
+func.func @g(%arg0: memref<f32>) -> memref<f32> {
+ %0 = test.isolated_one_region_with_recursive_memory_effects %arg0 {
+ ^bb0(%arg1: memref<f32>):
+ test.region_yield %arg1 : memref<f32>
+ } : (memref<f32>) -> (memref<f32>)
+ return %0 : memref<f32>
+}
+
+// CHECK-LABEL: func.func @g(
+// CHECK-SAME: %[[VAL_0:.*]]: memref<f32>) -> memref<f32> {
+// CHECK: %[[BLOCK:.*]]:2 = test.isolated_one_region_with_recursive_memory_effects %[[VAL_0]] {
+// CHECK: ^bb0(%[[ARG:.*]]: memref<f32>):
+// CHECK: test.region_yield %[[ARG]], %false : memref<f32>, i1
+// CHECK: } : (memref<f32>) -> (memref<f32>, i1)
+// CHECK: %[[VAL_4:.*]] = scf.if %[[BLOCK]]#1 -> (memref<f32>) {
+// CHECK: scf.yield %[[BLOCK]]#0 : memref<f32>
+// CHECK: } else {
+// CHECK: %[[VAL_6:.*]] = bufferization.clone %[[BLOCK]]#0 : memref<f32> to memref<f32>
+// CHECK: scf.yield %[[VAL_6]] : memref<f32>
+// CHECK: }
+// CHECK: %[[BUF:.*]], %[[OFFSET:.*]] = memref.extract_strided_metadata %[[BLOCK]]#0 : memref<f32> -> memref<f32>, index
+// CHECK: %[[VAL_11:.*]] = bufferization.dealloc (%[[BUF]] : memref<f32>) if (%[[BLOCK]]#1) retain (%[[VAL_4]] : memref<f32>)
+// CHECK: return %[[VAL_4]] : memref<f32>
+// CHECK: }
+
+// -----
+
+func.func @alloc_yielded_from_block() {
+ %alloc = memref.alloc() : memref<f64>
+ %alloc2 = memref.alloc() : memref<f64>
+ %ret = test.isolated_one_region_with_recursive_memory_effects %alloc {
+ ^bb0(%arg1: memref<f64>):
+ %0 = memref.load %arg1[] : memref<f64>
+ %c1 = arith.constant 1.0 : f64
+ %r0 = arith.cmpf oeq, %0, %c1 : f64
+ %1 = scf.if %r0 -> memref<f64> {
+ %alloc3 = memref.alloc() : memref<f64>
+ scf.yield %alloc3 : memref<f64>
+ } else {
+ scf.yield %arg1 : memref<f64>
+ }
+ test.region_yield %1 : memref<f64>
+ } : (memref<f64>) -> memref<f64>
+ test.copy(%ret, %alloc2) : (memref<f64>, memref<f64>)
+ return
+}
+
+// CHECK-LABEL: func.func @alloc_yielded_from_block() {
+// CHECK: %true = arith.constant true
+// CHECK: %[[ALLOC:.*]] = memref.alloc() : memref<f64>
+// CHECK: %[[BLOCK:.*]]:2 = test.isolated_one_region_with_recursive_memory_effects %[[ALLOC]] {
+// CHECK: ^bb0(%[[ARG:.*]]: memref<f64>):
+// CHECK: %[[VAL_9:.*]] = arith.cmpf oeq
+// CHECK: %[[VAL_10:.*]]:2 = scf.if %[[VAL_9]] -> (memref<f64>, i1) {
+// CHECK: %[[BLOCK_ALLOC:.*]] = memref.alloc() : memref<f64>
+// CHECK: scf.yield %[[BLOCK_ALLOC]], %true_{{[0-9]*}} : memref<f64>, i1
+// CHECK: } else {
+// CHECK: scf.yield %[[ARG]], %false : memref<f64>, i1
+// CHECK: }
+// CHECK: test.region_yield %[[VAL_10]]#0, %[[VAL_10]]#1 : memref<f64>, i1
+// CHECK: } : (memref<f64>) -> (memref<f64>, i1)
+// CHECK: test.copy
+// CHECK: %[[BUF:.*]], %[[OFFSET:.*]] = memref.extract_strided_metadata %[[BLOCK]]#0 : memref<f64> -> memref<f64>, index
+// CHECK: bufferization.dealloc (%[[ALLOC]], %{{.*}}, %[[BUF]] : memref<f64>, memref<f64>, memref<f64>) if (%true, %true, %[[BLOCK]]#1)
+// CHECK: return
+// CHECK: }
+
diff --git a/mlir/test/lib/Dialect/Test/TestOpDefs.cpp b/mlir/test/lib/Dialect/Test/TestOpDefs.cpp
index fbaa102d3e33c..6666c9b86db42 100644
--- a/mlir/test/lib/Dialect/Test/TestOpDefs.cpp
+++ b/mlir/test/lib/Dialect/Test/TestOpDefs.cpp
@@ -110,6 +110,18 @@ void IsolatedRegionOp::print(OpAsmPrinter &p) {
p.printRegion(getRegion(), /*printEntryBlockArgs=*/false);
}
+//===----------------------------------------------------------------------===//
+// IsolatedOneRegionWithRecursiveMemoryEffectsOp
+//===----------------------------------------------------------------------===//
+
+void IsolatedOneRegionWithRecursiveMemoryEffectsOp::getSuccessorRegions(
+ RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> ®ions) {
+ if (point.isParent())
+ regions.emplace_back(&getBody());
+ else
+ regions.emplace_back((*this)->getResults());
+}
+
//===----------------------------------------------------------------------===//
// SSACFGRegionOp
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td
index e1ec1428ee6d6..bbe84572868b2 100644
--- a/mlir/test/lib/Dialect/Test/TestOps.td
+++ b/mlir/test/lib/Dialect/Test/TestOps.td
@@ -507,6 +507,23 @@ def OneRegionWithRecursiveMemoryEffectsOp
let regions = (region SizedRegion<1>:$body);
}
+def IsolatedOneRegionWithRecursiveMemoryEffectsOp
+ : TEST_Op<"isolated_one_region_with_recursive_memory_effects", [
+ RecursiveMemoryEffects,
+ IsolatedFromAbove,
+ SingleBlockImplicitTerminator<"RegionYieldOp">,
+ DeclareOpInterfaceMethods<RegionBranchOpInterface>]> {
+ let description = [{
+ IsolatedFromAbove Op that has one region and recursive side effects.
+ }];
+ let arguments = (ins Variadic<AnyType>:$operands);
+ let results = (outs Variadic<AnyType>:$results);
+ let regions = (region SizedRegion<1>:$body);
+ let assemblyFormat = [{
+ attr-dict-with-keyword $operands $body `:` functional-type(operands, results)
+ }];
+}
+
//===----------------------------------------------------------------------===//
// NoTerminator Operation
//===----------------------------------------------------------------------===//
@@ -2147,7 +2164,7 @@ def RegionYieldOp : TEST_Op<"region_yield",
This operation is used in a region and yields the corresponding type for
that operation.
}];
- let arguments = (ins AnyType:$result);
+ let arguments = (ins Variadic<AnyType>:$result);
let assemblyFormat = [{
$result `:` type($result) attr-dict
}];
>From 79b975f14f385c5d1d855bfbdf920683bd1c5881 Mon Sep 17 00:00:00 2001
From: Nikhil Kalra <nkalra at apple.com>
Date: Thu, 4 Jul 2024 14:45:24 -0700
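Subject: [PATCH 2/2] Address review feedback: drop the `IsolatedFromAbove` special case; forward operands through the test op's `RegionBranchOpInterface` and update tests accordingly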
---
.../Transforms/BufferViewFlowAnalysis.cpp | 5 ---
.../OwnershipBasedBufferDeallocation.cpp | 5 +--
.../dealloc-isolated-group.mlir | 42 ++++++++++---------
mlir/test/lib/Dialect/Test/TestOpDefs.cpp | 9 +++-
mlir/test/lib/Dialect/Test/TestOps.td | 2 +-
5 files changed, 33 insertions(+), 30 deletions(-)
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/BufferViewFlowAnalysis.cpp b/mlir/lib/Dialect/Bufferization/Transforms/BufferViewFlowAnalysis.cpp
index d9525cb640e1c..72f47b8b468ea 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/BufferViewFlowAnalysis.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/BufferViewFlowAnalysis.cpp
@@ -174,11 +174,6 @@ void BufferViewFlowAnalysis::build(Operation *op) {
}
}
- if (op->hasTrait<OpTrait::IsIsolatedFromAbove>()) {
- // Mark the entry block arguments and results as terminal.
- populateTerminalValues(op);
- }
-
return WalkResult::advance();
}
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/OwnershipBasedBufferDeallocation.cpp b/mlir/lib/Dialect/Bufferization/Transforms/OwnershipBasedBufferDeallocation.cpp
index a52906174fb07..ca5d0688b5b59 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/OwnershipBasedBufferDeallocation.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/OwnershipBasedBufferDeallocation.cpp
@@ -640,9 +640,8 @@ LogicalResult BufferDeallocation::deallocate(Block *block) {
continue;
// Adhere to function boundary ABI: no ownership of function argument
- // MemRefs is taken. Likewise for ops marked IsolatedFromAbove.
- if ((isa<FunctionOpInterface>(block->getParentOp()) ||
- block->getParentOp()->hasTrait<OpTrait::IsIsolatedFromAbove>()) &&
+ // MemRefs is taken.
+ if (isa<FunctionOpInterface>(block->getParentOp()) &&
block->isEntryBlock()) {
Value newArg = buildBoolValue(builder, arg.getLoc(), false);
state.updateOwnership(arg, newArg);
diff --git a/mlir/test/Dialect/Bufferization/Transforms/OwnershipBasedBufferDeallocation/dealloc-isolated-group.mlir b/mlir/test/Dialect/Bufferization/Transforms/OwnershipBasedBufferDeallocation/dealloc-isolated-group.mlir
index 0c3ceda5237cc..f0f097e0c235b 100644
--- a/mlir/test/Dialect/Bufferization/Transforms/OwnershipBasedBufferDeallocation/dealloc-isolated-group.mlir
+++ b/mlir/test/Dialect/Bufferization/Transforms/OwnershipBasedBufferDeallocation/dealloc-isolated-group.mlir
@@ -1,5 +1,5 @@
// RUN: mlir-opt -verify-diagnostics -ownership-based-buffer-deallocation \
-// RUN: -buffer-deallocation-simplification -split-input-file %s | FileCheck %s
+// RUN: -buffer-deallocation-simplification -split-input-file -canonicalize %s | FileCheck %s
func.func @function_call() {
%alloc = memref.alloc() : memref<f64>
@@ -15,14 +15,15 @@ func.func @function_call() {
// CHECK-LABEL: func @function_call()
// CHECK: [[ALLOC0:%.+]] = memref.alloc(
// CHECK-NEXT: [[ALLOC1:%.+]] = memref.alloc(
-// CHECK-NEXT: [[RET:%.+]]:2 = test.isolated_one_region_with_recursive_memory_effects [[ALLOC0]]
-// CHECK-NEXT: ^bb0([[ARG:%.+]]: memref<f64>)
-// CHECK: test.region_yield [[ARG]]
+// CHECK-NEXT: [[RET:%.+]]:2 = test.isolated_one_region_with_recursive_memory_effects [[ALLOC0]], %false
+// CHECK-NEXT: ^bb0([[ARG:%.+]]: memref<f64>, [[OWN:%.+]]: i1)
+// CHECK: test.region_yield [[ARG]], [[OWN]]
// CHECK-NOT: bufferization.dealloc
// CHECK: }
// CHECK-NEXT: test.copy
// CHECK-NEXT: [[BASE:%[a-zA-Z0-9_]+]]{{.*}} = memref.extract_strided_metadata [[RET]]#0
-// CHECK-NEXT: bufferization.dealloc ([[ALLOC0]], [[ALLOC1]], [[BASE]] :{{.*}}) if (%true, %true, [[RET]]#1)
+// CHECK-NEXT: bufferization.dealloc ([[ALLOC1]] :{{.*}}) if (%true)
+// CHECK-NEXT: bufferization.dealloc ([[ALLOC0]], [[BASE]] :{{.*}}) if (%true, [[RET]]#1)
// -----
@@ -42,9 +43,9 @@ func.func @function_call_requries_merged_ownership_mid_block(%arg0: i1) {
// CHECK: [[ALLOC0:%.+]] = memref.alloc(
// CHECK-NEXT: [[ALLOC1:%.+]] = memref.alloca(
// CHECK-NEXT: [[SELECT:%.+]] = arith.select{{.*}}[[ALLOC0]], [[ALLOC1]]
-// CHECK-NEXT: [[RET:%.+]]:2 = test.isolated_one_region_with_recursive_memory_effects [[SELECT]]
-// CHECK-NEXT: ^bb0([[ARG:%.+]]: memref<f64>)
-// CHECK: test.region_yield [[ARG]]
+// CHECK-NEXT: [[RET:%.+]]:2 = test.isolated_one_region_with_recursive_memory_effects [[SELECT]], %false
+// CHECK-NEXT: ^bb0([[ARG:%.+]]: memref<f64>, [[OWN:%.+]]: i1)
+// CHECK: test.region_yield [[ARG]], [[OWN]]
// CHECK-NOT: bufferization.dealloc
// CHECK: }
// CHECK-NEXT: test.copy
@@ -66,10 +67,10 @@ func.func @g(%arg0: memref<f32>) -> memref<f32> {
// CHECK-LABEL: func.func @g(
// CHECK-SAME: %[[VAL_0:.*]]: memref<f32>) -> memref<f32> {
-// CHECK: %[[BLOCK:.*]]:2 = test.isolated_one_region_with_recursive_memory_effects %[[VAL_0]] {
-// CHECK: ^bb0(%[[ARG:.*]]: memref<f32>):
-// CHECK: test.region_yield %[[ARG]], %false : memref<f32>, i1
-// CHECK: } : (memref<f32>) -> (memref<f32>, i1)
+// CHECK: %[[BLOCK:.*]]:2 = test.isolated_one_region_with_recursive_memory_effects %[[VAL_0]], %false {
+// CHECK: ^bb0(%[[ARG:.*]]: memref<f32>, [[OWN:%.+]]: i1):
+// CHECK: test.region_yield %[[ARG]], [[OWN]] : memref<f32>, i1
+// CHECK: }
// CHECK: %[[VAL_4:.*]] = scf.if %[[BLOCK]]#1 -> (memref<f32>) {
// CHECK: scf.yield %[[BLOCK]]#0 : memref<f32>
// CHECK: } else {
@@ -106,20 +107,21 @@ func.func @alloc_yielded_from_block() {
// CHECK-LABEL: func.func @alloc_yielded_from_block() {
// CHECK: %true = arith.constant true
// CHECK: %[[ALLOC:.*]] = memref.alloc() : memref<f64>
-// CHECK: %[[BLOCK:.*]]:2 = test.isolated_one_region_with_recursive_memory_effects %[[ALLOC]] {
-// CHECK: ^bb0(%[[ARG:.*]]: memref<f64>):
+// CHECK: %[[BLOCK:.*]]:2 = test.isolated_one_region_with_recursive_memory_effects %[[ALLOC]], %false {
+// CHECK: ^bb0(%[[ARG:.*]]: memref<f64>, [[OWN:%.+]]: i1):
// CHECK: %[[VAL_9:.*]] = arith.cmpf oeq
-// CHECK: %[[VAL_10:.*]]:2 = scf.if %[[VAL_9]] -> (memref<f64>, i1) {
+// CHECK: %[[VAL_10:.*]] = scf.if %[[VAL_9]] -> (memref<f64>) {
// CHECK: %[[BLOCK_ALLOC:.*]] = memref.alloc() : memref<f64>
-// CHECK: scf.yield %[[BLOCK_ALLOC]], %true_{{[0-9]*}} : memref<f64>, i1
+// CHECK: scf.yield %[[BLOCK_ALLOC]] : memref<f64>
// CHECK: } else {
-// CHECK: scf.yield %[[ARG]], %false : memref<f64>, i1
+// CHECK: scf.yield %[[ARG]] : memref<f64>
// CHECK: }
-// CHECK: test.region_yield %[[VAL_10]]#0, %[[VAL_10]]#1 : memref<f64>, i1
-// CHECK: } : (memref<f64>) -> (memref<f64>, i1)
+// CHECK: bufferization.dealloc ({{.*}}) if ([[OWN]])
+// CHECK: test.region_yield %[[VAL_10]]
+// CHECK: }
// CHECK: test.copy
// CHECK: %[[BUF:.*]], %[[OFFSET:.*]] = memref.extract_strided_metadata %[[BLOCK]]#0 : memref<f64> -> memref<f64>, index
-// CHECK: bufferization.dealloc (%[[ALLOC]], %{{.*}}, %[[BUF]] : memref<f64>, memref<f64>, memref<f64>) if (%true, %true, %[[BLOCK]]#1)
+// CHECK: bufferization.dealloc (%[[ALLOC]], %[[BUF]] : memref<f64>, memref<f64>) if (%true, %[[BLOCK]]#1)
// CHECK: return
// CHECK: }
diff --git a/mlir/test/lib/Dialect/Test/TestOpDefs.cpp b/mlir/test/lib/Dialect/Test/TestOpDefs.cpp
index 6666c9b86db42..8ba5f7e695c40 100644
--- a/mlir/test/lib/Dialect/Test/TestOpDefs.cpp
+++ b/mlir/test/lib/Dialect/Test/TestOpDefs.cpp
@@ -10,6 +10,7 @@
#include "TestOps.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/Verifier.h"
+#include "mlir/Interfaces/ControlFlowInterfaces.h"
#include "mlir/Interfaces/FunctionImplementation.h"
#include "mlir/Interfaces/MemorySlotInterfaces.h"
@@ -114,10 +115,16 @@ void IsolatedRegionOp::print(OpAsmPrinter &p) {
// IsolatedOneRegionWithRecursiveMemoryEffectsOp
//===----------------------------------------------------------------------===//
+OperandRange
+IsolatedOneRegionWithRecursiveMemoryEffectsOp::getEntrySuccessorOperands(
+ RegionBranchPoint) {
+ return getOperands();
+}
+
void IsolatedOneRegionWithRecursiveMemoryEffectsOp::getSuccessorRegions(
RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> ®ions) {
if (point.isParent())
- regions.emplace_back(&getBody());
+ regions.emplace_back(&getBody(), getBody().getArguments());
else
regions.emplace_back((*this)->getResults());
}
diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td
index bbe84572868b2..a465fe3c198ae 100644
--- a/mlir/test/lib/Dialect/Test/TestOps.td
+++ b/mlir/test/lib/Dialect/Test/TestOps.td
@@ -512,7 +512,7 @@ def IsolatedOneRegionWithRecursiveMemoryEffectsOp
RecursiveMemoryEffects,
IsolatedFromAbove,
SingleBlockImplicitTerminator<"RegionYieldOp">,
- DeclareOpInterfaceMethods<RegionBranchOpInterface>]> {
+ DeclareOpInterfaceMethods<RegionBranchOpInterface, ["getEntrySuccessorOperands"]>]> {
let description = [{
IsolatedFromAbove Op that has one region and recursive side effects.
}];
More information about the Mlir-commits
mailing list