[Mlir-commits] [mlir] 1e4faf2 - [mlir][IR] Add a Region::getOps method that returns a range of immediately nested operations

River Riddle llvmlistbot at llvm.org
Mon May 4 18:02:12 PDT 2020


Author: River Riddle
Date: 2020-05-04T17:46:25-07:00
New Revision: 1e4faf23ffde601679cf8e48a9cd576918a0cf2c

URL: https://github.com/llvm/llvm-project/commit/1e4faf23ffde601679cf8e48a9cd576918a0cf2c
DIFF: https://github.com/llvm/llvm-project/commit/1e4faf23ffde601679cf8e48a9cd576918a0cf2c.diff

LOG: [mlir][IR] Add a Region::getOps method that returns a range of immediately nested operations

This allows for walking the operations nested directly within a region, without traversing nested regions.

Differential Revision: https://reviews.llvm.org/D79056

Added: 
    

Modified: 
    mlir/include/mlir/IR/Block.h
    mlir/include/mlir/IR/BlockSupport.h
    mlir/include/mlir/IR/Function.h
    mlir/include/mlir/IR/OpDefinition.h
    mlir/include/mlir/IR/Region.h
    mlir/lib/Analysis/CallGraph.cpp
    mlir/lib/Conversion/LoopsToGPU/LoopsToGPUPass.cpp
    mlir/lib/IR/Region.cpp
    mlir/lib/IR/SymbolTable.cpp
    mlir/lib/Transforms/Inliner.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/IR/Block.h b/mlir/include/mlir/IR/Block.h
index 12f82f84b52a..859fa1713ffd 100644
--- a/mlir/include/mlir/IR/Block.h
+++ b/mlir/include/mlir/IR/Block.h
@@ -156,54 +156,23 @@ class Block : public IRObjectWithUseList<BlockOperand>,
   /// Recomputes the ordering of child operations within the block.
   void recomputeOpOrder();
 
-private:
-  /// A utility iterator that filters out operations that are not 'OpT'.
-  template <typename OpT>
-  class op_filter_iterator
-      : public llvm::filter_iterator<Block::iterator, bool (*)(Operation &)> {
-    static bool filter(Operation &op) { return llvm::isa<OpT>(op); }
-
-  public:
-    op_filter_iterator(Block::iterator it, Block::iterator end)
-        : llvm::filter_iterator<Block::iterator, bool (*)(Operation &)>(
-              it, end, &filter) {}
-
-    /// Allow implicit conversion to the underlying block iterator.
-    operator Block::iterator() const { return this->wrapped(); }
-  };
-
-public:
   /// This class provides iteration over the held operations of a block for a
   /// specific operation type.
   template <typename OpT>
-  class op_iterator : public llvm::mapped_iterator<op_filter_iterator<OpT>,
-                                                   OpT (*)(Operation &)> {
-    static OpT unwrap(Operation &op) { return cast<OpT>(op); }
-
-  public:
-    using reference = OpT;
-
-    /// Initializes the iterator to the specified filter iterator.
-    op_iterator(op_filter_iterator<OpT> it)
-        : llvm::mapped_iterator<op_filter_iterator<OpT>, OpT (*)(Operation &)>(
-              it, &unwrap) {}
-
-    /// Allow implicit conversion to the underlying block iterator.
-    operator Block::iterator() const { return this->wrapped(); }
-  };
+  using op_iterator = detail::op_iterator<OpT, iterator>;
 
   /// Return an iterator range over the operations within this block that are of
   /// 'OpT'.
   template <typename OpT> iterator_range<op_iterator<OpT>> getOps() {
     auto endIt = end();
-    return {op_filter_iterator<OpT>(begin(), endIt),
-            op_filter_iterator<OpT>(endIt, endIt)};
+    return {detail::op_filter_iterator<OpT, iterator>(begin(), endIt),
+            detail::op_filter_iterator<OpT, iterator>(endIt, endIt)};
   }
   template <typename OpT> op_iterator<OpT> op_begin() {
-    return op_filter_iterator<OpT>(begin(), end());
+    return detail::op_filter_iterator<OpT, iterator>(begin(), end());
   }
   template <typename OpT> op_iterator<OpT> op_end() {
-    return op_filter_iterator<OpT>(end(), end());
+    return detail::op_filter_iterator<OpT, iterator>(end(), end());
   }
 
   /// Return an iterator range over the operation within this block excluding

diff  --git a/mlir/include/mlir/IR/BlockSupport.h b/mlir/include/mlir/IR/BlockSupport.h
index 3c246749c584..10b8c48c6db6 100644
--- a/mlir/include/mlir/IR/BlockSupport.h
+++ b/mlir/include/mlir/IR/BlockSupport.h
@@ -75,6 +75,46 @@ class SuccessorRange final
   friend RangeBaseT;
 };
 
+//===----------------------------------------------------------------------===//
+// Operation Iterators
+//===----------------------------------------------------------------------===//
+
+namespace detail {
+/// A utility iterator that filters out operations that are not 'OpT'.
+template <typename OpT, typename IteratorT>
+class op_filter_iterator
+    : public llvm::filter_iterator<IteratorT, bool (*)(Operation &)> {
+  static bool filter(Operation &op) { return llvm::isa<OpT>(op); }
+
+public:
+  op_filter_iterator(IteratorT it, IteratorT end)
+      : llvm::filter_iterator<IteratorT, bool (*)(Operation &)>(it, end,
+                                                                &filter) {}
+
+  /// Allow implicit conversion to the underlying iterator.
+  operator IteratorT() const { return this->wrapped(); }
+};
+
+/// This class provides iteration over the held operations of a block for a
+/// specific operation type.
+template <typename OpT, typename IteratorT>
+class op_iterator
+    : public llvm::mapped_iterator<op_filter_iterator<OpT, IteratorT>,
+                                   OpT (*)(Operation &)> {
+  static OpT unwrap(Operation &op) { return cast<OpT>(op); }
+
+public:
+  using reference = OpT;
+
+  /// Initializes the iterator to the specified filter iterator.
+  op_iterator(op_filter_iterator<OpT, IteratorT> it)
+      : llvm::mapped_iterator<op_filter_iterator<OpT, IteratorT>,
+                              OpT (*)(Operation &)>(it, &unwrap) {}
+
+  /// Allow implicit conversion to the underlying block iterator.
+  operator IteratorT() const { return this->wrapped(); }
+};
+} // end namespace detail
 } // end namespace mlir
 
 namespace llvm {

diff  --git a/mlir/include/mlir/IR/Function.h b/mlir/include/mlir/IR/Function.h
index 2617a488edf4..e39b88091f2b 100644
--- a/mlir/include/mlir/IR/Function.h
+++ b/mlir/include/mlir/IR/Function.h
@@ -32,9 +32,10 @@ namespace mlir {
 /// symbols referenced by name via a string attribute).
 class FuncOp
     : public Op<FuncOp, OpTrait::ZeroOperands, OpTrait::ZeroResult,
-                OpTrait::IsIsolatedFromAbove, OpTrait::FunctionLike,
-                OpTrait::AutomaticAllocationScope, OpTrait::PolyhedralScope,
-                CallableOpInterface::Trait, SymbolOpInterface::Trait> {
+                OpTrait::OneRegion, OpTrait::IsIsolatedFromAbove,
+                OpTrait::FunctionLike, OpTrait::AutomaticAllocationScope,
+                OpTrait::PolyhedralScope, CallableOpInterface::Trait,
+                SymbolOpInterface::Trait> {
 public:
   using Op::Op;
   using Op::print;

diff  --git a/mlir/include/mlir/IR/OpDefinition.h b/mlir/include/mlir/IR/OpDefinition.h
index 121ed1b568ff..b1830de88eef 100644
--- a/mlir/include/mlir/IR/OpDefinition.h
+++ b/mlir/include/mlir/IR/OpDefinition.h
@@ -583,6 +583,13 @@ class OneRegion : public TraitBase<ConcreteType, OneRegion> {
 public:
   Region &getRegion() { return this->getOperation()->getRegion(0); }
 
+  /// Returns a range of operations within the region of this operation.
+  auto getOps() { return getRegion().getOps(); }
+  template <typename OpT>
+  auto getOps() {
+    return getRegion().template getOps<OpT>();
+  }
+
   static LogicalResult verifyTrait(Operation *op) {
     return impl::verifyOneRegion(op);
   }

diff  --git a/mlir/include/mlir/IR/Region.h b/mlir/include/mlir/IR/Region.h
index f824b65c646f..0efa58ca4c0a 100644
--- a/mlir/include/mlir/IR/Region.h
+++ b/mlir/include/mlir/IR/Region.h
@@ -34,6 +34,10 @@ class Region {
   /// parent container. The region must have a valid parent container.
   Location getLoc();
 
+  //===--------------------------------------------------------------------===//
+  // Block list management
+  //===--------------------------------------------------------------------===//
+
   using BlockListType = llvm::iplist<Block>;
   BlockListType &getBlocks() { return blocks; }
 
@@ -58,6 +62,72 @@ class Region {
     return &Region::blocks;
   }
 
+  //===--------------------------------------------------------------------===//
+  // Operation list utilities
+  //===--------------------------------------------------------------------===//
+
+  /// This class provides iteration over the held operations of blocks directly
+  /// within a region.
+  class OpIterator final
+      : public llvm::iterator_facade_base<OpIterator, std::forward_iterator_tag,
+                                          Operation> {
+  public:
+    /// Initialize OpIterator for a region, specify `end` to return the iterator
+    /// to last operation.
+    explicit OpIterator(Region *region, bool end = false);
+
+    using llvm::iterator_facade_base<OpIterator, std::forward_iterator_tag,
+                                     Operation>::operator++;
+    OpIterator &operator++();
+    Operation *operator->() const { return &*operation; }
+    Operation &operator*() const { return *operation; }
+
+    /// Compare this iterator with another.
+    bool operator==(const OpIterator &rhs) const {
+      return operation == rhs.operation;
+    }
+    bool operator!=(const OpIterator &rhs) const { return !(*this == rhs); }
+
+  private:
+    void skipOverBlocksWithNoOps();
+
+    /// The region whose operations are being iterated over.
+    Region *region;
+    /// The block of 'region' whose operations are being iterated over.
+    Region::iterator block;
+    /// The current operation within 'block'.
+    Block::iterator operation;
+  };
+
+  /// This class provides iteration over the held operations of a region for a
+  /// specific operation type.
+  template <typename OpT>
+  using op_iterator = detail::op_iterator<OpT, OpIterator>;
+
+  /// Return iterators that walk the operations nested directly within this
+  /// region.
+  OpIterator op_begin() { return OpIterator(this); }
+  OpIterator op_end() { return OpIterator(this, /*end=*/true); }
+  iterator_range<OpIterator> getOps() { return {op_begin(), op_end()}; }
+
+  /// Return iterators that walk operations of type 'T' nested directly within
+  /// this region.
+  template <typename OpT> op_iterator<OpT> op_begin() {
+    return detail::op_filter_iterator<OpT, OpIterator>(op_begin(), op_end());
+  }
+  template <typename OpT> op_iterator<OpT> op_end() {
+    return detail::op_filter_iterator<OpT, OpIterator>(op_end(), op_end());
+  }
+  template <typename OpT> iterator_range<op_iterator<OpT>> getOps() {
+    auto endIt = op_end();
+    return {detail::op_filter_iterator<OpT, OpIterator>(op_begin(), endIt),
+            detail::op_filter_iterator<OpT, OpIterator>(endIt, endIt)};
+  }
+
+  //===--------------------------------------------------------------------===//
+  // Misc. utilities
+  //===--------------------------------------------------------------------===//
+
   /// Return the region containing this region or nullptr if the region is
   /// attached to a top-level operation.
   Region *getParentRegion();
@@ -120,6 +190,10 @@ class Region {
   /// they are to be deleted.
   void dropAllReferences();
 
+  //===--------------------------------------------------------------------===//
+  // Operation Walkers
+  //===--------------------------------------------------------------------===//
+
   /// Walk the operations in this region in postorder, calling the callback for
   /// each operation. This method is invoked for void-returning callbacks.
   /// See Operation::walk for more details.
@@ -142,6 +216,10 @@ class Region {
     return WalkResult::advance();
   }
 
+  //===--------------------------------------------------------------------===//
+  // CFG view utilities
+  //===--------------------------------------------------------------------===//
+
   /// Displays the CFG in a window. This is for use from the debugger and
   /// depends on Graphviz to generate the graph.
   /// This function is defined in ViewRegionGraph and only works with that

diff  --git a/mlir/lib/Analysis/CallGraph.cpp b/mlir/lib/Analysis/CallGraph.cpp
index 1cb317fc0f6d..94965c7a623d 100644
--- a/mlir/lib/Analysis/CallGraph.cpp
+++ b/mlir/lib/Analysis/CallGraph.cpp
@@ -87,9 +87,8 @@ static void computeCallGraph(Operation *op, CallGraph &cg,
   }
 
   for (Region &region : op->getRegions())
-    for (Block &block : region)
-      for (Operation &nested : block)
-        computeCallGraph(&nested, cg, parentNode, resolveCalls);
+    for (Operation &nested : region.getOps())
+      computeCallGraph(&nested, cg, parentNode, resolveCalls);
 }
 
 CallGraph::CallGraph(Operation *op) : externalNode(/*callableRegion=*/nullptr) {

diff  --git a/mlir/lib/Conversion/LoopsToGPU/LoopsToGPUPass.cpp b/mlir/lib/Conversion/LoopsToGPU/LoopsToGPUPass.cpp
index a597dd7bf078..94c6a3b20f2b 100644
--- a/mlir/lib/Conversion/LoopsToGPU/LoopsToGPUPass.cpp
+++ b/mlir/lib/Conversion/LoopsToGPU/LoopsToGPUPass.cpp
@@ -36,18 +36,17 @@ struct ForLoopMapper : public ConvertSimpleLoopsToGPUBase<ForLoopMapper> {
   }
 
   void runOnFunction() override {
-    for (Block &block : getFunction())
-      for (Operation &op : llvm::make_early_inc_range(block)) {
-        if (auto forOp = dyn_cast<AffineForOp>(&op)) {
-          if (failed(convertAffineLoopNestToGPULaunch(forOp, numBlockDims,
-                                                      numThreadDims)))
-            signalPassFailure();
-        } else if (auto forOp = dyn_cast<ForOp>(&op)) {
-          if (failed(convertLoopNestToGPULaunch(forOp, numBlockDims,
-                                                numThreadDims)))
-            signalPassFailure();
-        }
+    for (Operation &op : llvm::make_early_inc_range(getFunction().getOps())) {
+      if (auto forOp = dyn_cast<AffineForOp>(&op)) {
+        if (failed(convertAffineLoopNestToGPULaunch(forOp, numBlockDims,
+                                                    numThreadDims)))
+          signalPassFailure();
+      } else if (auto forOp = dyn_cast<ForOp>(&op)) {
+        if (failed(
+                convertLoopNestToGPULaunch(forOp, numBlockDims, numThreadDims)))
+          signalPassFailure();
       }
+    }
   }
 };
 
@@ -81,14 +80,10 @@ struct ImperfectlyNestedForLoopMapper
           funcOp.getLoc(), builder.getIntegerAttr(builder.getIndexType(), val));
       workGroupSizeVal.push_back(constOp);
     }
-    for (Block &block : getFunction()) {
-      for (Operation &op : llvm::make_early_inc_range(block)) {
-        if (auto forOp = dyn_cast<ForOp>(&op)) {
-          if (failed(convertLoopToGPULaunch(forOp, numWorkGroupsVal,
-                                            workGroupSizeVal))) {
-            return signalPassFailure();
-          }
-        }
+    for (ForOp forOp : llvm::make_early_inc_range(funcOp.getOps<ForOp>())) {
+      if (failed(convertLoopToGPULaunch(forOp, numWorkGroupsVal,
+                                        workGroupSizeVal))) {
+        return signalPassFailure();
       }
     }
   }

diff  --git a/mlir/lib/IR/Region.cpp b/mlir/lib/IR/Region.cpp
index 4f5054112a41..aa2acc00dde4 100644
--- a/mlir/lib/IR/Region.cpp
+++ b/mlir/lib/IR/Region.cpp
@@ -146,34 +146,32 @@ static bool isIsolatedAbove(Region &region, Region &limit,
 
   // Traverse all operations in the region.
   while (!pendingRegions.empty()) {
-    for (Block &block : *pendingRegions.pop_back_val()) {
-      for (Operation &op : block) {
-        for (Value operand : op.getOperands()) {
-          // operand should be non-null here if the IR is well-formed. But
-          // we don't assert here as this function is called from the verifier
-          // and so could be called on invalid IR.
-          if (!operand) {
-            if (noteLoc)
-              op.emitOpError("block's operand not defined").attachNote(noteLoc);
-            return false;
-          }
+    for (Operation &op : pendingRegions.pop_back_val()->getOps()) {
+      for (Value operand : op.getOperands()) {
+        // operand should be non-null here if the IR is well-formed. But
+        // we don't assert here as this function is called from the verifier
+        // and so could be called on invalid IR.
+        if (!operand) {
+          if (noteLoc)
+            op.emitOpError("block's operand not defined").attachNote(noteLoc);
+          return false;
+        }
 
-          // Check that any value that is used by an operation is defined in the
-          // same region as either an operation result or a block argument.
-          if (operand.getParentRegion()->isProperAncestor(&limit)) {
-            if (noteLoc) {
-              op.emitOpError("using value defined outside the region")
-                      .attachNote(noteLoc)
-                  << "required by region isolation constraints";
-            }
-            return false;
+        // Check that any value that is used by an operation is defined in the
+        // same region as either an operation result or a block argument.
+        if (operand.getParentRegion()->isProperAncestor(&limit)) {
+          if (noteLoc) {
+            op.emitOpError("using value defined outside the region")
+                    .attachNote(noteLoc)
+                << "required by region isolation constraints";
           }
+          return false;
         }
-        // Schedule any regions the operations contain for further checking.
-        pendingRegions.reserve(pendingRegions.size() + op.getNumRegions());
-        for (Region &subRegion : op.getRegions())
-          pendingRegions.push_back(&subRegion);
       }
+      // Schedule any regions the operations contain for further checking.
+      pendingRegions.reserve(pendingRegions.size() + op.getNumRegions());
+      for (Region &subRegion : op.getRegions())
+        pendingRegions.push_back(&subRegion);
     }
   }
   return true;
@@ -219,6 +217,40 @@ void llvm::ilist_traits<::mlir::Block>::transferNodesFromList(
     first->parentValidOpOrderPair.setPointer(curParent);
 }
 
+//===----------------------------------------------------------------------===//
+// Region::OpIterator
+//===----------------------------------------------------------------------===//
+
+Region::OpIterator::OpIterator(Region *region, bool end)
+    : region(region), block(end ? region->end() : region->begin()) {
+  if (!region->empty())
+    skipOverBlocksWithNoOps();
+}
+
+Region::OpIterator &Region::OpIterator::operator++() {
+  // We increment over operations, if we reach the last use then move to next
+  // block.
+  if (operation != block->end())
+    ++operation;
+  if (operation == block->end()) {
+    ++block;
+    skipOverBlocksWithNoOps();
+  }
+  return *this;
+}
+
+void Region::OpIterator::skipOverBlocksWithNoOps() {
+  while (block != region->end() && block->empty())
+    ++block;
+
+  // If we are at the last block, then set the operation to first operation of
+  // next block (sentinel value used for end).
+  if (block == region->end())
+    operation = {};
+  else
+    operation = block->begin();
+}
+
 //===----------------------------------------------------------------------===//
 // RegionRange
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/lib/IR/SymbolTable.cpp b/mlir/lib/IR/SymbolTable.cpp
index e195225d675a..1d2235b61936 100644
--- a/mlir/lib/IR/SymbolTable.cpp
+++ b/mlir/lib/IR/SymbolTable.cpp
@@ -245,11 +245,9 @@ Operation *SymbolTable::lookupSymbolIn(Operation *symbolTableOp,
   assert(symbolTableOp->hasTrait<OpTrait::SymbolTable>());
 
   // Look for a symbol with the given name.
-  for (auto &block : symbolTableOp->getRegion(0)) {
-    for (auto &op : block)
-      if (getNameIfSymbol(&op) == symbol)
-        return &op;
-  }
+  for (auto &op : symbolTableOp->getRegion(0).front().without_terminator())
+    if (getNameIfSymbol(&op) == symbol)
+      return &op;
   return nullptr;
 }
 Operation *SymbolTable::lookupSymbolIn(Operation *symbolTableOp,
@@ -444,21 +442,19 @@ static Optional<WalkResult> walkSymbolUses(
     function_ref<WalkResult(SymbolTable::SymbolUse, ArrayRef<int>)> callback) {
   SmallVector<Region *, 1> worklist(llvm::make_pointer_range(regions));
   while (!worklist.empty()) {
-    for (Block &block : *worklist.pop_back_val()) {
-      for (Operation &op : block) {
-        if (walkSymbolRefs(&op, callback).wasInterrupted())
-          return WalkResult::interrupt();
-
-        // Check that this isn't a potentially unknown symbol table.
-        if (isPotentiallyUnknownSymbolTable(&op))
-          return llvm::None;
-
-        // If this op defines a new symbol table scope, we can't traverse. Any
-        // symbol references nested within 'op' are different semantically.
-        if (!op.hasTrait<OpTrait::SymbolTable>()) {
-          for (Region &region : op.getRegions())
-            worklist.push_back(&region);
-        }
+    for (Operation &op : worklist.pop_back_val()->getOps()) {
+      if (walkSymbolRefs(&op, callback).wasInterrupted())
+        return WalkResult::interrupt();
+
+      // Check that this isn't a potentially unknown symbol table.
+      if (isPotentiallyUnknownSymbolTable(&op))
+        return llvm::None;
+
+      // If this op defines a new symbol table scope, we can't traverse. Any
+      // symbol references nested within 'op' are different semantically.
+      if (!op.hasTrait<OpTrait::SymbolTable>()) {
+        for (Region &region : op.getRegions())
+          worklist.push_back(&region);
       }
     }
   }

diff  --git a/mlir/lib/Transforms/Inliner.cpp b/mlir/lib/Transforms/Inliner.cpp
index ee645cb55511..f8f48d86b562 100644
--- a/mlir/lib/Transforms/Inliner.cpp
+++ b/mlir/lib/Transforms/Inliner.cpp
@@ -122,23 +122,21 @@ CGUseList::CGUseList(Operation *op, CallGraph &cg) {
 
   // Walk each of the symbol tables looking for discardable callgraph nodes.
   auto walkFn = [&](Operation *symbolTableOp, bool allUsesVisible) {
-    for (Block &block : symbolTableOp->getRegion(0)) {
-      for (Operation &op : block) {
-        // If this is a callgraph operation, check to see if it is discardable.
-        if (auto callable = dyn_cast<CallableOpInterface>(&op)) {
-          if (auto *node = cg.lookupNode(callable.getCallableRegion())) {
-            SymbolOpInterface symbol = dyn_cast<SymbolOpInterface>(&op);
-            if (symbol && (allUsesVisible || symbol.isPrivate()) &&
-                symbol.canDiscardOnUseEmpty()) {
-              discardableSymNodeUses.try_emplace(node, 0);
-            }
-            continue;
+    for (Operation &op : symbolTableOp->getRegion(0).getOps()) {
+      // If this is a callgraph operation, check to see if it is discardable.
+      if (auto callable = dyn_cast<CallableOpInterface>(&op)) {
+        if (auto *node = cg.lookupNode(callable.getCallableRegion())) {
+          SymbolOpInterface symbol = dyn_cast<SymbolOpInterface>(&op);
+          if (symbol && (allUsesVisible || symbol.isPrivate()) &&
+              symbol.canDiscardOnUseEmpty()) {
+            discardableSymNodeUses.try_emplace(node, 0);
           }
+          continue;
         }
-        // Otherwise, check for any referenced nodes. These will be always-live.
-        walkReferencedSymbolNodes(&op, cg, alwaysLiveNodes,
-                                  [](CallGraphNode *, Operation *) {});
       }
+      // Otherwise, check for any referenced nodes. These will be always-live.
+      walkReferencedSymbolNodes(&op, cg, alwaysLiveNodes,
+                                [](CallGraphNode *, Operation *) {});
     }
   };
   SymbolTable::walkSymbolTables(op, /*allSymUsesVisible=*/!op->getBlock(),


        


More information about the Mlir-commits mailing list