[Mlir-commits] [mlir] a5c2f78 - [mlir][interfaces] Add insideMutuallyExclusiveRegions helper

Matthias Springer llvmlistbot at llvm.org
Thu Nov 25 00:49:41 PST 2021


Author: Matthias Springer
Date: 2021-11-25T17:44:39+09:00
New Revision: a5c2f7828796ce9c3e19e78fbd783fb0206b971d

URL: https://github.com/llvm/llvm-project/commit/a5c2f7828796ce9c3e19e78fbd783fb0206b971d
DIFF: https://github.com/llvm/llvm-project/commit/a5c2f7828796ce9c3e19e78fbd783fb0206b971d.diff

LOG: [mlir][interfaces] Add insideMutuallyExclusiveRegions helper

Add a helper function to ControlFlowInterfaces for checking if two ops
are in mutually exclusive regions according to RegionBranchOpInterface.

Utilize this new helper in Linalg ComprehensiveBufferize. This makes the
analysis independent of the SCF dialect and generalizes it to other ops
that implement RegionBranchOpInterface.

Differential Revision: https://reviews.llvm.org/D114220

Added: 
    mlir/unittests/Interfaces/ControlFlowInterfacesTest.cpp

Modified: 
    mlir/include/mlir/Interfaces/ControlFlowInterfaces.h
    mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.cpp
    mlir/lib/Interfaces/ControlFlowInterfaces.cpp
    mlir/unittests/Interfaces/CMakeLists.txt

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Interfaces/ControlFlowInterfaces.h b/mlir/include/mlir/Interfaces/ControlFlowInterfaces.h
index fd4c560422377..da3549ff3d77f 100644
--- a/mlir/include/mlir/Interfaces/ControlFlowInterfaces.h
+++ b/mlir/include/mlir/Interfaces/ControlFlowInterfaces.h
@@ -87,6 +87,10 @@ class RegionSuccessor {
   ValueRange inputs;
 };
 
+/// Return `true` if `a` and `b` are in regions of their closest common
+/// RegionBranchOpInterface ancestor that are not reachable from each other.
+bool insideMutuallyExclusiveRegions(Operation *a, Operation *b);
+
 //===----------------------------------------------------------------------===//
 // RegionBranchTerminatorOpInterface
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.cpp
index c07adcf3fbc03..2b55d22ad4962 100644
--- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.cpp
+++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.cpp
@@ -430,9 +430,8 @@ hasReadAfterWriteInterference(const DenseSet<OpOperand *> &usesRead,
                                               aliasInfo))
             continue;
 
-      // Special rules for branches.
-      // TODO: Use an interface.
-      if (scf::insideMutuallyExclusiveBranches(readingOp, conflictingWritingOp))
+      // Ops are not conflicting if they are in mutually exclusive regions.
+      if (insideMutuallyExclusiveRegions(readingOp, conflictingWritingOp))
         continue;
 
       LDBG("WRITE = #" << printValueInfo(lastWrite) << "\n");

diff  --git a/mlir/lib/Interfaces/ControlFlowInterfaces.cpp b/mlir/lib/Interfaces/ControlFlowInterfaces.cpp
index 498517b8f63fb..26c80795c6509 100644
--- a/mlir/lib/Interfaces/ControlFlowInterfaces.cpp
+++ b/mlir/lib/Interfaces/ControlFlowInterfaces.cpp
@@ -219,6 +219,78 @@ LogicalResult detail::verifyTypesAlongControlFlowEdges(Operation *op) {
   return success();
 }
 
+/// Return `true` if `a` and `b` are in mutually exclusive regions.
+///
+/// 1. Find the first common ancestor of `a` and `b` that implements
+///    RegionBranchOpInterface.
+/// 2. Determine the regions `regionA` and `regionB` in which `a` and `b` are
+///    contained.
+/// 3. Check if `regionA` and `regionB` are mutually exclusive. They are
+///    mutually exclusive if they are not reachable from each other as per
+///    RegionBranchOpInterface::getSuccessorRegions.
+bool mlir::insideMutuallyExclusiveRegions(Operation *a, Operation *b) {
+  assert(a && "expected non-empty operation");
+  assert(b && "expected non-empty operation");
+
+  auto branchOp = a->getParentOfType<RegionBranchOpInterface>();
+  while (branchOp) {
+    // Check if b is inside branchOp. (We already know that a is.)
+    if (!branchOp->isProperAncestor(b)) {
+      // Check next enclosing RegionBranchOpInterface.
+      branchOp = branchOp->getParentOfType<RegionBranchOpInterface>();
+      continue;
+    }
+
+    // b is contained in branchOp. Retrieve the regions in which `a` and `b`
+    // are contained.
+    Region *regionA = nullptr, *regionB = nullptr;
+    for (Region &r : branchOp->getRegions()) {
+      if (r.findAncestorOpInRegion(*a)) {
+        assert(!regionA && "already found a region for a");
+        regionA = &r;
+      }
+      if (r.findAncestorOpInRegion(*b)) {
+        assert(!regionB && "already found a region for b");
+        regionB = &r;
+      }
+    }
+    assert(regionA && regionB && "could not find region of op");
+
+    // Check whether region `r` is reachable from region `begin`. The successor
+    // graph may be cyclic (e.g., loops), so visited regions are tracked.
+    auto isRegionReachable = [&](Region *begin, Region *r) {
+      SmallVector<bool> visited(branchOp->getNumRegions(), false);
+      SmallVector<Region *> worklist = {begin};
+      while (!worklist.empty()) {
+        Region *current = worklist.pop_back_val();
+        if (current == r)
+          return true;
+        assert(current->getParentOp() == branchOp.getOperation() &&
+               "could not find region in op");
+        unsigned index = current->getRegionNumber();
+        if (visited[index])
+          continue;
+        visited[index] = true;
+        SmallVector<RegionSuccessor> successors;
+        branchOp.getSuccessorRegions(index, successors);
+        for (RegionSuccessor &successor : successors)
+          if (Region *next = successor.getSuccessor()) // Null: parent op.
+            worklist.push_back(next);
+      }
+      return false;
+    };
+
+    // `a` and `b` are in mutually exclusive regions if neither region is
+    // reachable from the other region.
+    return !isRegionReachable(regionA, regionB) &&
+           !isRegionReachable(regionB, regionA);
+  }
+
+  // Could not find a common RegionBranchOpInterface among a's and b's
+  // ancestors.
+  return false;
+}
+
 //===----------------------------------------------------------------------===//
 // RegionBranchTerminatorOpInterface
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/unittests/Interfaces/CMakeLists.txt b/mlir/unittests/Interfaces/CMakeLists.txt
index 003bbc41ef7c0..d86710c2781a7 100644
--- a/mlir/unittests/Interfaces/CMakeLists.txt
+++ b/mlir/unittests/Interfaces/CMakeLists.txt
@@ -1,10 +1,12 @@
 add_mlir_unittest(MLIRInterfacesTests
+  ControlFlowInterfacesTest.cpp
   DataLayoutInterfacesTest.cpp
   InferTypeOpInterfaceTest.cpp
 )
 
 target_link_libraries(MLIRInterfacesTests
   PRIVATE
+  MLIRControlFlowInterfaces
   MLIRDataLayoutInterfaces
   MLIRDLTI
   MLIRInferTypeOpInterface

diff  --git a/mlir/unittests/Interfaces/ControlFlowInterfacesTest.cpp b/mlir/unittests/Interfaces/ControlFlowInterfacesTest.cpp
new file mode 100644
index 0000000000000..e3934aa0208ae
--- /dev/null
+++ b/mlir/unittests/Interfaces/ControlFlowInterfacesTest.cpp
@@ -0,0 +1,145 @@
+//===- ControlFlowInterfacesTest.cpp - Unit Tests for Control Flow Interf. ===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Interfaces/ControlFlowInterfaces.h"
+#include "mlir/IR/BuiltinOps.h"
+#include "mlir/IR/Dialect.h"
+#include "mlir/IR/DialectImplementation.h"
+#include "mlir/IR/OpDefinition.h"
+#include "mlir/IR/OpImplementation.h"
+#include "mlir/Parser.h"
+
+#include <gtest/gtest.h>
+
+using namespace mlir;
+
+/// Trivial terminator op that fills the regions of the test IR below.
+struct DummyOp : public Op<DummyOp, OpTrait::IsTerminator> {
+  using Op::Op;
+
+  static StringRef getOperationName() { return "cftest.dummy_op"; }
+  static ArrayRef<StringRef> getAttributeNames() { return {}; }
+};
+
+/// All regions of this op are mutually exclusive: none has a successor.
+struct MutuallyExclusiveRegionsOp
+    : public Op<MutuallyExclusiveRegionsOp, RegionBranchOpInterface::Trait> {
+  using Op::Op;
+
+  static StringRef getOperationName() {
+    return "cftest.mutually_exclusive_regions_op";
+  }
+  static ArrayRef<StringRef> getAttributeNames() { return {}; }
+
+  // Report no successors at all, neither for the parent op nor any region.
+  void getSuccessorRegions(Optional<unsigned> /*index*/,
+                           ArrayRef<Attribute> /*operands*/,
+                           SmallVectorImpl<RegionSuccessor> & /*regions*/) {}
+};
+
+/// Test op whose region 0 branches to region 1, i.e. regions run in order.
+struct SequentialRegionsOp
+    : public Op<SequentialRegionsOp, RegionBranchOpInterface::Trait> {
+  using Op::Op;
+
+  static StringRef getOperationName() { return "cftest.sequential_regions_op"; }
+  static ArrayRef<StringRef> getAttributeNames() { return {}; }
+
+  // Region 0 flows into region 1; no other region has a successor. Querying
+  // the entry successors of the parent op (no index) is not supported.
+  void getSuccessorRegions(Optional<unsigned> index,
+                           ArrayRef<Attribute> operands,
+                           SmallVectorImpl<RegionSuccessor> &regions) {
+    assert(index.hasValue() && "expected index");
+    if (index.getValue() != 0)
+      return;
+    regions.push_back(RegionSuccessor(&getOperation()->getRegion(1)));
+  }
+};
+
+/// The "cftest" dialect, which registers the test ops defined above.
+struct CFTestDialect : Dialect {
+  static StringRef getDialectNamespace() { return "cftest"; }
+  explicit CFTestDialect(MLIRContext *ctx)
+      : Dialect(getDialectNamespace(), ctx, TypeID::get<CFTestDialect>()) {
+    addOperations<DummyOp, MutuallyExclusiveRegionsOp, SequentialRegionsOp>();
+  }
+};
+
+TEST(RegionBranchOpInterface, MutuallyExclusiveOps) {
+  const char *ir = R"MLIR(
+"cftest.mutually_exclusive_regions_op"() (
+      {"cftest.dummy_op"() : () -> ()},  // op1
+      {"cftest.dummy_op"() : () -> ()}   // op2
+  ) : () -> ()
+  )MLIR";
+
+  DialectRegistry registry;
+  registry.insert<CFTestDialect>();
+  MLIRContext ctx(registry);
+  OwningModuleRef module = parseSourceString(ir, &ctx);
+  ASSERT_TRUE(module) << "failed to parse test IR";
+  Operation *testOp = &module->getBody()->getOperations().front();
+  Operation *op1 = &testOp->getRegion(0).front().front();
+  Operation *op2 = &testOp->getRegion(1).front().front();
+  // Neither region has a successor, so both directions must report `true`.
+  EXPECT_TRUE(insideMutuallyExclusiveRegions(op1, op2));
+  EXPECT_TRUE(insideMutuallyExclusiveRegions(op2, op1));
+}
+
+TEST(RegionBranchOpInterface, NotMutuallyExclusiveOps) {
+  const char *ir = R"MLIR(
+"cftest.sequential_regions_op"() (
+      {"cftest.dummy_op"() : () -> ()},  // op1
+      {"cftest.dummy_op"() : () -> ()}   // op2
+  ) : () -> ()
+  )MLIR";
+
+  DialectRegistry registry;
+  registry.insert<CFTestDialect>();
+  MLIRContext ctx(registry);
+  OwningModuleRef module = parseSourceString(ir, &ctx);
+  ASSERT_TRUE(module) << "failed to parse test IR";
+  Operation *testOp = &module->getBody()->getOperations().front();
+  Operation *op1 = &testOp->getRegion(0).front().front();
+  Operation *op2 = &testOp->getRegion(1).front().front();
+  // Region 0 branches to region 1, so the ops are not mutually exclusive.
+  EXPECT_FALSE(insideMutuallyExclusiveRegions(op1, op2));
+  EXPECT_FALSE(insideMutuallyExclusiveRegions(op2, op1));
+}
+
+TEST(RegionBranchOpInterface, NestedMutuallyExclusiveOps) {
+  const char *ir = R"MLIR(
+"cftest.mutually_exclusive_regions_op"() (
+      {
+        "cftest.sequential_regions_op"() (
+              {"cftest.dummy_op"() : () -> ()},  // op1
+              {"cftest.dummy_op"() : () -> ()}   // op3
+          ) : () -> ()
+        "cftest.dummy_op"() : () -> ()
+      },
+      {"cftest.dummy_op"() : () -> ()}           // op2
+  ) : () -> ()
+  )MLIR";
+
+  DialectRegistry registry;
+  registry.insert<CFTestDialect>();
+  MLIRContext ctx(registry);
+  OwningModuleRef module = parseSourceString(ir, &ctx);
+  ASSERT_TRUE(module) << "failed to parse test IR";
+  Operation *testOp = &module->getBody()->getOperations().front();
+  Operation *seqOp = &testOp->getRegion(0).front().front();
+  Operation *op1 = &seqOp->getRegion(0).front().front();
+  Operation *op2 = &testOp->getRegion(1).front().front();
+  Operation *op3 = &seqOp->getRegion(1).front().front();
+
+  // Outer op: op1/op3 vs. op2 are exclusive. Inner op: op1 flows into op3.
+  EXPECT_TRUE(insideMutuallyExclusiveRegions(op1, op2));
+  EXPECT_TRUE(insideMutuallyExclusiveRegions(op3, op2));
+  EXPECT_FALSE(insideMutuallyExclusiveRegions(op1, op3));
+}


        


More information about the Mlir-commits mailing list