[Mlir-commits] [mlir] [mlir] Allow loop-like operations in `AbstractDenseForwardDataFlowAnalysis` (PR #66179)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Wed Sep 13 00:28:26 PDT 2023
https://github.com/victor-eds created https://github.com/llvm/llvm-project/pull/66179:
Remove an assertion that loop-like operations violate: a region's terminator may branch back into the entry block of that same region.
>From 0bd0667b6f5edbc7d87ca728074bda2fa0c54d92 Mon Sep 17 00:00:00 2001
From: Victor Perez <victor.perez at codeplay.com>
Date: Wed, 13 Sep 2023 08:26:55 +0100
Subject: [PATCH] [mlir] Allow loop-like operations in
`AbstractDenseForwardDataFlowAnalysis`
Remove an assertion that loop-like operations violate: a region's terminator may branch back into the entry block of that same region.
Signed-off-by: Victor Perez <victor.perez at codeplay.com>
---
mlir/lib/Analysis/DataFlow/DenseAnalysis.cpp | 2 -
.../Analysis/DataFlow/test-last-modified.mlir | 119 ++++++++++++++++++
.../TestDenseForwardDataFlowAnalysis.cpp | 24 ++--
mlir/test/lib/Dialect/Test/TestDialect.cpp | 10 ++
mlir/test/lib/Dialect/Test/TestOps.td | 12 ++
5 files changed, 156 insertions(+), 11 deletions(-)
diff --git a/mlir/lib/Analysis/DataFlow/DenseAnalysis.cpp b/mlir/lib/Analysis/DataFlow/DenseAnalysis.cpp
index eab408cd5977c3a..5e75883e61468ec 100644
--- a/mlir/lib/Analysis/DataFlow/DenseAnalysis.cpp
+++ b/mlir/lib/Analysis/DataFlow/DenseAnalysis.cpp
@@ -199,8 +199,6 @@ void AbstractDenseForwardDataFlowAnalysis::visitRegionBranchOperation(
op == branch ? std::optional<unsigned>()
: op->getBlock()->getParent()->getRegionNumber();
if (auto *toBlock = point.dyn_cast<Block *>()) {
- assert(op == branch ||
- toBlock->getParent() != op->getBlock()->getParent());
unsigned regionTo = toBlock->getParent()->getRegionNumber();
visitRegionBranchControlFlowTransfer(branch, regionFrom, regionTo,
*before, after);
diff --git a/mlir/test/Analysis/DataFlow/test-last-modified.mlir b/mlir/test/Analysis/DataFlow/test-last-modified.mlir
index 069cbbcc0cc1684..43326614d67933c 100644
--- a/mlir/test/Analysis/DataFlow/test-last-modified.mlir
+++ b/mlir/test/Analysis/DataFlow/test-last-modified.mlir
@@ -229,3 +229,122 @@ func.func @store_with_a_region_after_containing_a_store(%arg0: memref<f32>) -> m
memref.store %1, %arg0[] {tag_name = "post"} : memref<f32>
return {tag = "return"} %arg0 : memref<f32>
}
+
+// The loop op's write is modelled on edges leaving it (store_before_region =
+// true): its own write ("region") is expected inside the body and after it.
+// CHECK-LABEL: test_tag: store_with_a_loop_region_before::before:
+// CHECK: operand #0
+// CHECK: - pre
+// CHECK: test_tag: inside_region:
+// CHECK: operand #0
+// CHECK: - region
+// CHECK: test_tag: after:
+// CHECK: operand #0
+// CHECK: - region
+// CHECK: test_tag: return:
+// CHECK: operand #0
+// CHECK: - post
+func.func @store_with_a_loop_region_before(%arg0: memref<f32>) -> memref<f32> {
+ %0 = arith.constant 0.0 : f32
+ %1 = arith.constant 1.0 : f32
+ memref.store %0, %arg0[] {tag_name = "pre"} : memref<f32>
+ memref.load %arg0[] {tag = "store_with_a_loop_region_before::before"} : memref<f32>
+ test.store_with_a_loop_region %arg0 attributes { tag_name = "region", store_before_region = true } {
+ memref.load %arg0[] {tag = "inside_region"} : memref<f32>
+ test.store_with_a_region_terminator
+ } : memref<f32>
+ memref.load %arg0[] {tag = "after"} : memref<f32>
+ memref.store %1, %arg0[] {tag_name = "post"} : memref<f32>
+ return {tag = "return"} %arg0 : memref<f32>
+}
+
+// The loop op's write is modelled on edges returning to it
+// (store_before_region = false): the body still sees "pre", while the code
+// after the op sees the op's own write ("region").
+// CHECK-LABEL: test_tag: store_with_a_loop_region_after::before:
+// CHECK: operand #0
+// CHECK: - pre
+// CHECK: test_tag: inside_region:
+// CHECK: operand #0
+// CHECK: - pre
+// CHECK: test_tag: after:
+// CHECK: operand #0
+// CHECK: - region
+// CHECK: test_tag: return:
+// CHECK: operand #0
+// CHECK: - post
+func.func @store_with_a_loop_region_after(%arg0: memref<f32>) -> memref<f32> {
+ %0 = arith.constant 0.0 : f32
+ %1 = arith.constant 1.0 : f32
+ memref.store %0, %arg0[] {tag_name = "pre"} : memref<f32>
+ memref.load %arg0[] {tag = "store_with_a_loop_region_after::before"} : memref<f32>
+ test.store_with_a_loop_region %arg0 attributes { tag_name = "region", store_before_region = false } {
+ memref.load %arg0[] {tag = "inside_region"} : memref<f32>
+ test.store_with_a_region_terminator
+ } : memref<f32>
+ memref.load %arg0[] {tag = "after"} : memref<f32>
+ memref.store %1, %arg0[] {tag_name = "post"} : memref<f32>
+ return {tag = "return"} %arg0 : memref<f32>
+}
+
+// Write on edges leaving the op plus a store inside the body. Because the body
+// may branch back into itself, its entry can see both the op's write
+// ("region") and the previous iteration's store ("inner"); so can the op's exit.
+// CHECK-LABEL: test_tag: store_with_a_loop_region_before_containing_a_store::before:
+// CHECK: operand #0
+// CHECK: - pre
+// CHECK: test_tag: enter_region:
+// CHECK: operand #0
+// CHECK-DAG: - region
+// CHECK-DAG: - inner
+// CHECK: test_tag: exit_region:
+// CHECK: operand #0
+// CHECK: - inner
+// CHECK: test_tag: after:
+// CHECK: operand #0
+// CHECK-DAG: - region
+// CHECK-DAG: - inner
+// CHECK: test_tag: return:
+// CHECK: operand #0
+// CHECK: - post
+func.func @store_with_a_loop_region_before_containing_a_store(%arg0: memref<f32>) -> memref<f32> {
+ %0 = arith.constant 0.0 : f32
+ %1 = arith.constant 1.0 : f32
+ memref.store %0, %arg0[] {tag_name = "pre"} : memref<f32>
+ memref.load %arg0[] {tag = "store_with_a_loop_region_before_containing_a_store::before"} : memref<f32>
+ test.store_with_a_loop_region %arg0 attributes { tag_name = "region", store_before_region = true } {
+ memref.load %arg0[] {tag = "enter_region"} : memref<f32>
+ %2 = arith.constant 2.0 : f32
+ memref.store %2, %arg0[] {tag_name = "inner"} : memref<f32>
+ memref.load %arg0[] {tag = "exit_region"} : memref<f32>
+ test.store_with_a_region_terminator
+ } : memref<f32>
+ memref.load %arg0[] {tag = "after"} : memref<f32>
+ memref.store %1, %arg0[] {tag_name = "post"} : memref<f32>
+ return {tag = "return"} %arg0 : memref<f32>
+}
+
+// Write on edges returning to the op plus a store inside the body: the body
+// entry can see "pre" (first iteration) and "inner" (back edge), while after
+// the op its own write ("region") is expected.
+// CHECK-LABEL: test_tag: store_with_a_loop_region_after_containing_a_store::before:
+// CHECK: operand #0
+// CHECK: - pre
+// CHECK: test_tag: enter_region:
+// CHECK: operand #0
+// CHECK-DAG: - pre
+// CHECK-DAG: - inner
+// CHECK: test_tag: exit_region:
+// CHECK: operand #0
+// CHECK: - inner
+// CHECK: test_tag: after:
+// CHECK: operand #0
+// CHECK: - region
+// CHECK: test_tag: return:
+// CHECK: operand #0
+// CHECK: - post
+func.func @store_with_a_loop_region_after_containing_a_store(%arg0: memref<f32>) -> memref<f32> {
+ %0 = arith.constant 0.0 : f32
+ %1 = arith.constant 1.0 : f32
+ memref.store %0, %arg0[] {tag_name = "pre"} : memref<f32>
+ memref.load %arg0[] {tag = "store_with_a_loop_region_after_containing_a_store::before"} : memref<f32>
+ test.store_with_a_loop_region %arg0 attributes { tag_name = "region", store_before_region = false } {
+ memref.load %arg0[] {tag = "enter_region"} : memref<f32>
+ %2 = arith.constant 2.0 : f32
+ memref.store %2, %arg0[] {tag_name = "inner"} : memref<f32>
+ memref.load %arg0[] {tag = "exit_region"} : memref<f32>
+ test.store_with_a_region_terminator
+ } : memref<f32>
+ memref.load %arg0[] {tag = "after"} : memref<f32>
+ memref.store %1, %arg0[] {tag_name = "post"} : memref<f32>
+ return {tag = "return"} %arg0 : memref<f32>
+}
diff --git a/mlir/test/lib/Analysis/DataFlow/TestDenseForwardDataFlowAnalysis.cpp b/mlir/test/lib/Analysis/DataFlow/TestDenseForwardDataFlowAnalysis.cpp
index 1a21719a44b994f..2520ed3d83b9efa 100644
--- a/mlir/test/lib/Analysis/DataFlow/TestDenseForwardDataFlowAnalysis.cpp
+++ b/mlir/test/lib/Analysis/DataFlow/TestDenseForwardDataFlowAnalysis.cpp
@@ -17,6 +17,8 @@
#include "mlir/Analysis/DataFlow/DenseAnalysis.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"
#include "mlir/Pass/Pass.h"
+#include "mlir/Support/LLVM.h"
+#include "llvm/ADT/TypeSwitch.h"
#include <optional>
using namespace mlir;
@@ -133,15 +135,31 @@ void LastModifiedAnalysis::visitRegionBranchControlFlowTransfer(
RegionBranchOpInterface branch, std::optional<unsigned> regionFrom,
std::optional<unsigned> regionTo, const LastModification &before,
LastModification *after) {
- auto testStoreWithARegion =
- dyn_cast<::test::TestStoreWithARegion>(branch.getOperation());
- if (testStoreWithARegion &&
- ((!regionTo && !testStoreWithARegion.getStoreBeforeRegion()) ||
- (!regionFrom && testStoreWithARegion.getStoreBeforeRegion()))) {
- return visitOperation(branch, before, after);
- }
- AbstractDenseForwardDataFlowAnalysis::visitRegionBranchControlFlowTransfer(
- branch, regionFrom, regionTo, before, after);
+  // Transfer across the edge regionFrom -> regionTo of `branch`, where
+  // std::nullopt stands for the op itself. The test store ops model their
+  // write on one specific edge; every other edge uses the base transfer.
+  auto defaultHandling = [&]() {
+    AbstractDenseForwardDataFlowAnalysis::visitRegionBranchControlFlowTransfer(
+        branch, regionFrom, regionTo, before, after);
+  };
+  // Capture by reference: `before` is a reference, so a `[=]` capture would
+  // copy the whole lattice into the closure on every call.
+  TypeSwitch<Operation *>(branch.getOperation())
+      .Case<::test::TestStoreWithARegion, ::test::TestStoreWithALoopRegion>(
+          [&](auto storeWithRegion) {
+            // The write is modelled on edges leaving the op when
+            // `store_before_region` is set, and on edges returning to the op
+            // otherwise. On such an edge `visitOperation` replaces the default
+            // transfer rather than preceding it: running both would also feed
+            // `before` through the base transfer, letting modifications the
+            // store supersedes leak past it.
+            bool storeBefore = storeWithRegion.getStoreBeforeRegion();
+            if ((!regionTo && !storeBefore) || (!regionFrom && storeBefore))
+              visitOperation(branch, before, after);
+            else
+              defaultHandling();
+          })
+      .Default([&](Operation *) { defaultHandling(); });
}
namespace {
diff --git a/mlir/test/lib/Dialect/Test/TestDialect.cpp b/mlir/test/lib/Dialect/Test/TestDialect.cpp
index ae4c9a85605e1c5..c02e6240ec7f361 100644
--- a/mlir/test/lib/Dialect/Test/TestDialect.cpp
+++ b/mlir/test/lib/Dialect/Test/TestDialect.cpp
@@ -1318,6 +1318,16 @@ void TestStoreWithARegion::getSuccessorRegions(
regions.emplace_back();
}
+void TestStoreWithALoopRegion::getSuccessorRegions(
+    RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> &regions) {
+  // Both the op itself and the body's terminator may branch into the body
+  // (another iteration) or back to the parent op, so the body may be skipped
+  // entirely and `point` need not be inspected. Region::getArguments() is
+  // used instead of front().getArguments() so an empty body is not derefed.
+  regions.emplace_back(&getBody(), getBody().getArguments());
+  regions.emplace_back();
+}
+
LogicalResult
TestVersionedOpA::readProperties(::mlir::DialectBytecodeReader &reader,
::mlir::OperationState &state) {
diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td
index 9ceadab8fa4a086..0aa8ce4de9756fa 100644
--- a/mlir/test/lib/Dialect/Test/TestOps.td
+++ b/mlir/test/lib/Dialect/Test/TestOps.td
@@ -2953,6 +2953,18 @@ def TestStoreWithARegion : TEST_Op<"store_with_a_region",
"$address attr-dict-with-keyword regions `:` type($address)";
}
+// Like `test.store_with_a_region`, but the body may branch back into itself.
+def TestStoreWithALoopRegion : TEST_Op<"store_with_a_loop_region", [
+    DeclareOpInterfaceMethods<RegionBranchOpInterface>, SingleBlock]> {
+  // `store_before_region` selects whether the test analysis applies the write
+  // on edges leaving the op (true) or on edges returning to it (false).
+  let arguments = (ins Arg<AnyMemRef, "", [MemWrite]>:$address,
+                       BoolAttr:$store_before_region);
+  let regions = (region AnyRegion:$body);
+  let assemblyFormat =
+    "$address attr-dict-with-keyword regions `:` type($address)";
+}
+
def TestStoreWithARegionTerminator : TEST_Op<"store_with_a_region_terminator",
[ReturnLike, Terminator, NoMemoryEffect]> {
let assemblyFormat = "attr-dict";
More information about the Mlir-commits
mailing list