[Mlir-commits] [mlir] [mlir] Allow loop-like operations in `AbstractDenseForwardDataFlowAnalysis` (PR #66179)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Wed Sep 13 00:29:27 PDT 2023


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir
            
<details>
<summary>Changes</summary>
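Remove the assertion in `AbstractDenseForwardDataFlowAnalysis::visitRegionBranchOperation` that loop-like operations violate. Such operations branch from the end of a region back to the entry of the same region, and the assertion forbade exactly that. Add a loop-like test op (`test.store_with_a_loop_region`) and last-modified tests that exercise it.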
--
Full diff: https://github.com/llvm/llvm-project/pull/66179.diff

5 Files Affected:

- (modified) mlir/lib/Analysis/DataFlow/DenseAnalysis.cpp (-2) 
- (modified) mlir/test/Analysis/DataFlow/test-last-modified.mlir (+119) 
- (modified) mlir/test/lib/Analysis/DataFlow/TestDenseForwardDataFlowAnalysis.cpp (+15-9) 
- (modified) mlir/test/lib/Dialect/Test/TestDialect.cpp (+10) 
- (modified) mlir/test/lib/Dialect/Test/TestOps.td (+12) 


<pre>
diff --git a/mlir/lib/Analysis/DataFlow/DenseAnalysis.cpp b/mlir/lib/Analysis/DataFlow/DenseAnalysis.cpp
index eab408cd5977c3a..5e75883e61468ec 100644
--- a/mlir/lib/Analysis/DataFlow/DenseAnalysis.cpp
+++ b/mlir/lib/Analysis/DataFlow/DenseAnalysis.cpp
@@ -199,8 +199,8 @@ void AbstractDenseForwardDataFlowAnalysis::visitRegionBranchOperation(
         op == branch ? std::optional<unsigned>()
                      : op->getBlock()->getParent()->getRegionNumber();
     if (auto *toBlock = point.dyn_cast<Block *>()) {
-      assert(op == branch ||
-             toBlock->getParent() != op->getBlock()->getParent());
+      // Loop-like operations may branch from a region back into the same
+      // region, so the source and destination regions may coincide here.
       unsigned regionTo = toBlock->getParent()->getRegionNumber();
       visitRegionBranchControlFlowTransfer(branch, regionFrom, regionTo,
                                            *before, after);
diff --git a/mlir/test/Analysis/DataFlow/test-last-modified.mlir b/mlir/test/Analysis/DataFlow/test-last-modified.mlir
index 069cbbcc0cc1684..43326614d67933c 100644
--- a/mlir/test/Analysis/DataFlow/test-last-modified.mlir
+++ b/mlir/test/Analysis/DataFlow/test-last-modified.mlir
@@ -229,3 +229,122 @@ func.func @store_with_a_region_after_containing_a_store(%arg0: memref<f32>) -> m
   memref.store %1, %arg0[] {tag_name = "post"} : memref<f32>
   return {tag = "return"} %arg0 : memref<f32>
 }
+// Loop-like variants: the body of test.store_with_a_loop_region may re-run.
+// CHECK-LABEL: test_tag: store_with_a_loop_region_before::before:
+// CHECK:  operand #0
+// CHECK:   - pre
+// CHECK: test_tag: inside_region:
+// CHECK:  operand #0
+// CHECK:   - region
+// CHECK: test_tag: after:
+// CHECK:  operand #0
+// CHECK:   - region
+// CHECK: test_tag: return:
+// CHECK:  operand #0
+// CHECK:   - post
+func.func @store_with_a_loop_region_before(%arg0: memref<f32>) -> memref<f32> {
+  %0 = arith.constant 0.0 : f32
+  %1 = arith.constant 1.0 : f32
+  memref.store %0, %arg0[] {tag_name = "pre"} : memref<f32>
+  memref.load %arg0[] {tag = "store_with_a_loop_region_before::before"} : memref<f32>
+  test.store_with_a_loop_region %arg0 attributes { tag_name = "region", store_before_region = true } {
+    memref.load %arg0[] {tag = "inside_region"} : memref<f32>
+    test.store_with_a_region_terminator
+  } : memref<f32>
+  memref.load %arg0[] {tag = "after"} : memref<f32>
+  memref.store %1, %arg0[] {tag_name = "post"} : memref<f32>
+  return {tag = "return"} %arg0 : memref<f32>
+}
+
+// CHECK-LABEL: test_tag: store_with_a_loop_region_after::before:
+// CHECK:  operand #0
+// CHECK:   - pre
+// CHECK: test_tag: inside_region:
+// CHECK:  operand #0
+// CHECK:   - pre
+// CHECK: test_tag: after:
+// CHECK:  operand #0
+// CHECK:   - region
+// CHECK: test_tag: return:
+// CHECK:  operand #0
+// CHECK:   - post
+func.func @store_with_a_loop_region_after(%arg0: memref<f32>) -> memref<f32> {
+  %0 = arith.constant 0.0 : f32
+  %1 = arith.constant 1.0 : f32
+  memref.store %0, %arg0[] {tag_name = "pre"} : memref<f32>
+  memref.load %arg0[] {tag = "store_with_a_loop_region_after::before"} : memref<f32>
+  test.store_with_a_loop_region %arg0 attributes { tag_name = "region", store_before_region = false } {
+    memref.load %arg0[] {tag = "inside_region"} : memref<f32>
+    test.store_with_a_region_terminator
+  } : memref<f32>
+  memref.load %arg0[] {tag = "after"} : memref<f32>
+  memref.store %1, %arg0[] {tag_name = "post"} : memref<f32>
+  return {tag = "return"} %arg0 : memref<f32>
+}
+
+// CHECK-LABEL:     test_tag: store_with_a_loop_region_before_containing_a_store::before:
+// CHECK:      operand #0
+// CHECK:       - pre
+// CHECK:     test_tag: enter_region:
+// CHECK:      operand #0
+// CHECK-DAG:   - region
+// CHECK-DAG:   - inner
+// CHECK:     test_tag: exit_region:
+// CHECK:      operand #0
+// CHECK:       - inner
+// CHECK:     test_tag: after:
+// CHECK:      operand #0
+// CHECK-DAG:   - region
+// CHECK-DAG:   - inner
+// CHECK:     test_tag: return:
+// CHECK:      operand #0
+// CHECK:       - post
+func.func @store_with_a_loop_region_before_containing_a_store(%arg0: memref<f32>) -> memref<f32> {
+  %0 = arith.constant 0.0 : f32
+  %1 = arith.constant 1.0 : f32
+  memref.store %0, %arg0[] {tag_name = "pre"} : memref<f32>
+  memref.load %arg0[] {tag = "store_with_a_loop_region_before_containing_a_store::before"} : memref<f32>
+  test.store_with_a_loop_region %arg0 attributes { tag_name = "region", store_before_region = true } {
+    memref.load %arg0[] {tag = "enter_region"} : memref<f32>
+    %2 = arith.constant 2.0 : f32
+    memref.store %2, %arg0[] {tag_name = "inner"} : memref<f32>
+    memref.load %arg0[] {tag = "exit_region"} : memref<f32>
+    test.store_with_a_region_terminator
+  } : memref<f32>
+  memref.load %arg0[] {tag = "after"} : memref<f32>
+  memref.store %1, %arg0[] {tag_name = "post"} : memref<f32>
+  return {tag = "return"} %arg0 : memref<f32>
+}
+
+// CHECK-LABEL:     test_tag: store_with_a_loop_region_after_containing_a_store::before:
+// CHECK:      operand #0
+// CHECK:       - pre
+// CHECK:     test_tag: enter_region:
+// CHECK:      operand #0
+// CHECK-DAG:   - pre
+// CHECK-DAG:   - inner
+// CHECK:     test_tag: exit_region:
+// CHECK:      operand #0
+// CHECK:       - inner
+// CHECK:     test_tag: after:
+// CHECK:      operand #0
+// CHECK:       - region
+// CHECK:     test_tag: return:
+// CHECK:      operand #0
+// CHECK:       - post
+func.func @store_with_a_loop_region_after_containing_a_store(%arg0: memref<f32>) -> memref<f32> {
+  %0 = arith.constant 0.0 : f32
+  %1 = arith.constant 1.0 : f32
+  memref.store %0, %arg0[] {tag_name = "pre"} : memref<f32>
+  memref.load %arg0[] {tag = "store_with_a_loop_region_after_containing_a_store::before"} : memref<f32>
+  test.store_with_a_loop_region %arg0 attributes { tag_name = "region", store_before_region = false } {
+    memref.load %arg0[] {tag = "enter_region"} : memref<f32>
+    %2 = arith.constant 2.0 : f32
+    memref.store %2, %arg0[] {tag_name = "inner"} : memref<f32>
+    memref.load %arg0[] {tag = "exit_region"} : memref<f32>
+    test.store_with_a_region_terminator
+  } : memref<f32>
+  memref.load %arg0[] {tag = "after"} : memref<f32>
+  memref.store %1, %arg0[] {tag_name = "post"} : memref<f32>
+  return {tag = "return"} %arg0 : memref<f32>
+}
diff --git a/mlir/test/lib/Analysis/DataFlow/TestDenseForwardDataFlowAnalysis.cpp b/mlir/test/lib/Analysis/DataFlow/TestDenseForwardDataFlowAnalysis.cpp
index 1a21719a44b994f..2520ed3d83b9efa 100644
--- a/mlir/test/lib/Analysis/DataFlow/TestDenseForwardDataFlowAnalysis.cpp
+++ b/mlir/test/lib/Analysis/DataFlow/TestDenseForwardDataFlowAnalysis.cpp
@@ -17,6 +17,8 @@
 #include "mlir/Analysis/DataFlow/DenseAnalysis.h"
 #include "mlir/Interfaces/SideEffectInterfaces.h"
 #include "mlir/Pass/Pass.h"
+#include "mlir/Support/LLVM.h"
+#include "llvm/ADT/TypeSwitch.h"
 #include <optional>
 
 using namespace mlir;
@@ -133,15 +135,27 @@ void LastModifiedAnalysis::visitRegionBranchControlFlowTransfer(
     RegionBranchOpInterface branch, std::optional<unsigned> regionFrom,
     std::optional<unsigned> regionTo, const LastModification &before,
     LastModification *after) {
-  auto testStoreWithARegion =
-      dyn_cast<::test::TestStoreWithARegion>(branch.getOperation());
-  if (testStoreWithARegion &&
-      ((!regionTo && !testStoreWithARegion.getStoreBeforeRegion()) ||
-       (!regionFrom && testStoreWithARegion.getStoreBeforeRegion()))) {
-    return visitOperation(branch, before, after);
-  }
-  AbstractDenseForwardDataFlowAnalysis::visitRegionBranchControlFlowTransfer(
-      branch, regionFrom, regionTo, before, after);
+  // Transfer implemented by the base class.
+  auto defaultHandling = [&]() {
+    AbstractDenseForwardDataFlowAnalysis::visitRegionBranchControlFlowTransfer(
+        branch, regionFrom, regionTo, before, after);
+  };
+  // The test store ops write their operand on the edges leaving the op itself
+  // when `store_before_region` is set, and on the edges returning to the op
+  // otherwise. Those edges get only the op's own effect via `visitOperation`;
+  // also applying the default transfer (assumed to join `before` into `after`)
+  // would re-add the pre-store modifications. The case lambdas capture by
+  // reference: they run immediately, and `[=]` would copy `before`.
+  TypeSwitch<Operation *>(branch.getOperation())
+      .Case<::test::TestStoreWithARegion, ::test::TestStoreWithALoopRegion>(
+          [&](auto storeWithRegion) {
+            if ((!regionTo && !storeWithRegion.getStoreBeforeRegion()) ||
+                (!regionFrom && storeWithRegion.getStoreBeforeRegion()))
+              visitOperation(branch, before, after);
+            else
+              defaultHandling();
+          })
+      .Default([&](auto) { defaultHandling(); });
 }
 
 namespace {
diff --git a/mlir/test/lib/Dialect/Test/TestDialect.cpp b/mlir/test/lib/Dialect/Test/TestDialect.cpp
index ae4c9a85605e1c5..c02e6240ec7f361 100644
--- a/mlir/test/lib/Dialect/Test/TestDialect.cpp
+++ b/mlir/test/lib/Dialect/Test/TestDialect.cpp
@@ -1318,6 +1318,16 @@ void TestStoreWithARegion::getSuccessorRegions(
     regions.emplace_back();
 }
 
+void TestStoreWithALoopRegion::getSuccessorRegions(
+    RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> &regions) {
+  // Both the operation itself and the region may be branching into the body or
+  // back into the operation itself. It is possible for the operation not to
+  // enter the body; an empty body (`AnyRegion` permits one) is never entered.
+  if (!getBody().empty())
+    regions.emplace_back(&getBody(), getBody().front().getArguments());
+  regions.emplace_back();
+}
+
 LogicalResult
 TestVersionedOpA::readProperties(::mlir::DialectBytecodeReader &reader,
                                  ::mlir::OperationState &state) {
diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td
index 9ceadab8fa4a086..0aa8ce4de9756fa 100644
--- a/mlir/test/lib/Dialect/Test/TestOps.td
+++ b/mlir/test/lib/Dialect/Test/TestOps.td
@@ -2953,6 +2953,18 @@ def TestStoreWithARegion : TEST_Op<"store_with_a_region",
     "$address attr-dict-with-keyword regions `:` type($address)";
 }
 
+def TestStoreWithALoopRegion : TEST_Op<"store_with_a_loop_region",
+    [DeclareOpInterfaceMethods<RegionBranchOpInterface>,
+     SingleBlock]> {
+  let arguments = (ins
+    Arg<AnyMemRef, "", [MemWrite]>:$address,
+    BoolAttr:$store_before_region // Read by LastModifiedAnalysis.
+  );
+  let regions = (region AnyRegion:$body); // Body may branch back to itself.
+  let assemblyFormat =
+    "$address attr-dict-with-keyword regions `:` type($address)";
+}
+
 def TestStoreWithARegionTerminator : TEST_Op<"store_with_a_region_terminator",
     [ReturnLike, Terminator, NoMemoryEffect]> {
   let assemblyFormat = "attr-dict";
</pre>
</details>


https://github.com/llvm/llvm-project/pull/66179


More information about the Mlir-commits mailing list