[Mlir-commits] [mlir] 1e4faf2 - [mlir][IR] Add a Region::getOps method that returns a range of immediately nested operations
River Riddle
llvmlistbot at llvm.org
Mon May 4 18:02:12 PDT 2020
Author: River Riddle
Date: 2020-05-04T17:46:25-07:00
New Revision: 1e4faf23ffde601679cf8e48a9cd576918a0cf2c
URL: https://github.com/llvm/llvm-project/commit/1e4faf23ffde601679cf8e48a9cd576918a0cf2c
DIFF: https://github.com/llvm/llvm-project/commit/1e4faf23ffde601679cf8e48a9cd576918a0cf2c.diff
LOG: [mlir][IR] Add a Region::getOps method that returns a range of immediately nested operations
This allows for walking the operations nested directly within a region, without traversing nested regions.
Differential Revision: https://reviews.llvm.org/D79056
Added:
Modified:
mlir/include/mlir/IR/Block.h
mlir/include/mlir/IR/BlockSupport.h
mlir/include/mlir/IR/Function.h
mlir/include/mlir/IR/OpDefinition.h
mlir/include/mlir/IR/Region.h
mlir/lib/Analysis/CallGraph.cpp
mlir/lib/Conversion/LoopsToGPU/LoopsToGPUPass.cpp
mlir/lib/IR/Region.cpp
mlir/lib/IR/SymbolTable.cpp
mlir/lib/Transforms/Inliner.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/IR/Block.h b/mlir/include/mlir/IR/Block.h
index 12f82f84b52a..859fa1713ffd 100644
--- a/mlir/include/mlir/IR/Block.h
+++ b/mlir/include/mlir/IR/Block.h
@@ -156,54 +156,23 @@ class Block : public IRObjectWithUseList<BlockOperand>,
/// Recomputes the ordering of child operations within the block.
void recomputeOpOrder();
-private:
- /// A utility iterator that filters out operations that are not 'OpT'.
- template <typename OpT>
- class op_filter_iterator
- : public llvm::filter_iterator<Block::iterator, bool (*)(Operation &)> {
- static bool filter(Operation &op) { return llvm::isa<OpT>(op); }
-
- public:
- op_filter_iterator(Block::iterator it, Block::iterator end)
- : llvm::filter_iterator<Block::iterator, bool (*)(Operation &)>(
- it, end, &filter) {}
-
- /// Allow implicit conversion to the underlying block iterator.
- operator Block::iterator() const { return this->wrapped(); }
- };
-
-public:
/// This class provides iteration over the held operations of a block for a
/// specific operation type.
template <typename OpT>
- class op_iterator : public llvm::mapped_iterator<op_filter_iterator<OpT>,
- OpT (*)(Operation &)> {
- static OpT unwrap(Operation &op) { return cast<OpT>(op); }
-
- public:
- using reference = OpT;
-
- /// Initializes the iterator to the specified filter iterator.
- op_iterator(op_filter_iterator<OpT> it)
- : llvm::mapped_iterator<op_filter_iterator<OpT>, OpT (*)(Operation &)>(
- it, &unwrap) {}
-
- /// Allow implicit conversion to the underlying block iterator.
- operator Block::iterator() const { return this->wrapped(); }
- };
+ using op_iterator = detail::op_iterator<OpT, iterator>;
/// Return an iterator range over the operations within this block that are of
/// 'OpT'.
template <typename OpT> iterator_range<op_iterator<OpT>> getOps() {
auto endIt = end();
- return {op_filter_iterator<OpT>(begin(), endIt),
- op_filter_iterator<OpT>(endIt, endIt)};
+ return {detail::op_filter_iterator<OpT, iterator>(begin(), endIt),
+ detail::op_filter_iterator<OpT, iterator>(endIt, endIt)};
}
template <typename OpT> op_iterator<OpT> op_begin() {
- return op_filter_iterator<OpT>(begin(), end());
+ return detail::op_filter_iterator<OpT, iterator>(begin(), end());
}
template <typename OpT> op_iterator<OpT> op_end() {
- return op_filter_iterator<OpT>(end(), end());
+ return detail::op_filter_iterator<OpT, iterator>(end(), end());
}
/// Return an iterator range over the operation within this block excluding
diff --git a/mlir/include/mlir/IR/BlockSupport.h b/mlir/include/mlir/IR/BlockSupport.h
index 3c246749c584..10b8c48c6db6 100644
--- a/mlir/include/mlir/IR/BlockSupport.h
+++ b/mlir/include/mlir/IR/BlockSupport.h
@@ -75,6 +75,46 @@ class SuccessorRange final
friend RangeBaseT;
};
+//===----------------------------------------------------------------------===//
+// Operation Iterators
+//===----------------------------------------------------------------------===//
+
+namespace detail {
+/// A utility iterator that filters out operations that are not 'OpT'.
+template <typename OpT, typename IteratorT>
+class op_filter_iterator
+ : public llvm::filter_iterator<IteratorT, bool (*)(Operation &)> {
+ static bool filter(Operation &op) { return llvm::isa<OpT>(op); }
+
+public:
+ op_filter_iterator(IteratorT it, IteratorT end)
+ : llvm::filter_iterator<IteratorT, bool (*)(Operation &)>(it, end,
+ &filter) {}
+
+ /// Allow implicit conversion to the underlying iterator.
+ operator IteratorT() const { return this->wrapped(); }
+};
+
+/// This class provides iteration over the held operations of a block for a
+/// specific operation type.
+template <typename OpT, typename IteratorT>
+class op_iterator
+ : public llvm::mapped_iterator<op_filter_iterator<OpT, IteratorT>,
+ OpT (*)(Operation &)> {
+ static OpT unwrap(Operation &op) { return cast<OpT>(op); }
+
+public:
+ using reference = OpT;
+
+ /// Initializes the iterator to the specified filter iterator.
+ op_iterator(op_filter_iterator<OpT, IteratorT> it)
+ : llvm::mapped_iterator<op_filter_iterator<OpT, IteratorT>,
+ OpT (*)(Operation &)>(it, &unwrap) {}
+
+ /// Allow implicit conversion to the underlying block iterator.
+ operator IteratorT() const { return this->wrapped(); }
+};
+} // end namespace detail
} // end namespace mlir
namespace llvm {
diff --git a/mlir/include/mlir/IR/Function.h b/mlir/include/mlir/IR/Function.h
index 2617a488edf4..e39b88091f2b 100644
--- a/mlir/include/mlir/IR/Function.h
+++ b/mlir/include/mlir/IR/Function.h
@@ -32,9 +32,10 @@ namespace mlir {
/// symbols referenced by name via a string attribute).
class FuncOp
: public Op<FuncOp, OpTrait::ZeroOperands, OpTrait::ZeroResult,
- OpTrait::IsIsolatedFromAbove, OpTrait::FunctionLike,
- OpTrait::AutomaticAllocationScope, OpTrait::PolyhedralScope,
- CallableOpInterface::Trait, SymbolOpInterface::Trait> {
+ OpTrait::OneRegion, OpTrait::IsIsolatedFromAbove,
+ OpTrait::FunctionLike, OpTrait::AutomaticAllocationScope,
+ OpTrait::PolyhedralScope, CallableOpInterface::Trait,
+ SymbolOpInterface::Trait> {
public:
using Op::Op;
using Op::print;
diff --git a/mlir/include/mlir/IR/OpDefinition.h b/mlir/include/mlir/IR/OpDefinition.h
index 121ed1b568ff..b1830de88eef 100644
--- a/mlir/include/mlir/IR/OpDefinition.h
+++ b/mlir/include/mlir/IR/OpDefinition.h
@@ -583,6 +583,13 @@ class OneRegion : public TraitBase<ConcreteType, OneRegion> {
public:
Region &getRegion() { return this->getOperation()->getRegion(0); }
+ /// Returns a range of operations within the region of this operation.
+ auto getOps() { return getRegion().getOps(); }
+ template <typename OpT>
+ auto getOps() {
+ return getRegion().template getOps<OpT>();
+ }
+
static LogicalResult verifyTrait(Operation *op) {
return impl::verifyOneRegion(op);
}
diff --git a/mlir/include/mlir/IR/Region.h b/mlir/include/mlir/IR/Region.h
index f824b65c646f..0efa58ca4c0a 100644
--- a/mlir/include/mlir/IR/Region.h
+++ b/mlir/include/mlir/IR/Region.h
@@ -34,6 +34,10 @@ class Region {
/// parent container. The region must have a valid parent container.
Location getLoc();
+ //===--------------------------------------------------------------------===//
+ // Block list management
+ //===--------------------------------------------------------------------===//
+
using BlockListType = llvm::iplist<Block>;
BlockListType &getBlocks() { return blocks; }
@@ -58,6 +62,72 @@ class Region {
return &Region::blocks;
}
+ //===--------------------------------------------------------------------===//
+ // Operation list utilities
+ //===--------------------------------------------------------------------===//
+
+ /// This class provides iteration over the held operations of blocks directly
+ /// within a region.
+ class OpIterator final
+ : public llvm::iterator_facade_base<OpIterator, std::forward_iterator_tag,
+ Operation> {
+ public:
+ /// Initialize OpIterator for a region, specify `end` to return the iterator
+ /// to last operation.
+ explicit OpIterator(Region *region, bool end = false);
+
+ using llvm::iterator_facade_base<OpIterator, std::forward_iterator_tag,
+ Operation>::operator++;
+ OpIterator &operator++();
+ Operation *operator->() const { return &*operation; }
+ Operation &operator*() const { return *operation; }
+
+ /// Compare this iterator with another.
+ bool operator==(const OpIterator &rhs) const {
+ return operation == rhs.operation;
+ }
+ bool operator!=(const OpIterator &rhs) const { return !(*this == rhs); }
+
+ private:
+ void skipOverBlocksWithNoOps();
+
+ /// The region whose operations are being iterated over.
+ Region *region;
+ /// The block of 'region' whose operations are being iterated over.
+ Region::iterator block;
+ /// The current operation within 'block'.
+ Block::iterator operation;
+ };
+
+ /// This class provides iteration over the held operations of a region for a
+ /// specific operation type.
+ template <typename OpT>
+ using op_iterator = detail::op_iterator<OpT, OpIterator>;
+
+ /// Return iterators that walk the operations nested directly within this
+ /// region.
+ OpIterator op_begin() { return OpIterator(this); }
+ OpIterator op_end() { return OpIterator(this, /*end=*/true); }
+ iterator_range<OpIterator> getOps() { return {op_begin(), op_end()}; }
+
+ /// Return iterators that walk operations of type 'T' nested directly within
+ /// this region.
+ template <typename OpT> op_iterator<OpT> op_begin() {
+ return detail::op_filter_iterator<OpT, OpIterator>(op_begin(), op_end());
+ }
+ template <typename OpT> op_iterator<OpT> op_end() {
+ return detail::op_filter_iterator<OpT, OpIterator>(op_end(), op_end());
+ }
+ template <typename OpT> iterator_range<op_iterator<OpT>> getOps() {
+ auto endIt = op_end();
+ return {detail::op_filter_iterator<OpT, OpIterator>(op_begin(), endIt),
+ detail::op_filter_iterator<OpT, OpIterator>(endIt, endIt)};
+ }
+
+ //===--------------------------------------------------------------------===//
+ // Misc. utilities
+ //===--------------------------------------------------------------------===//
+
/// Return the region containing this region or nullptr if the region is
/// attached to a top-level operation.
Region *getParentRegion();
@@ -120,6 +190,10 @@ class Region {
/// they are to be deleted.
void dropAllReferences();
+ //===--------------------------------------------------------------------===//
+ // Operation Walkers
+ //===--------------------------------------------------------------------===//
+
/// Walk the operations in this region in postorder, calling the callback for
/// each operation. This method is invoked for void-returning callbacks.
/// See Operation::walk for more details.
@@ -142,6 +216,10 @@ class Region {
return WalkResult::advance();
}
+ //===--------------------------------------------------------------------===//
+ // CFG view utilities
+ //===--------------------------------------------------------------------===//
+
/// Displays the CFG in a window. This is for use from the debugger and
/// depends on Graphviz to generate the graph.
/// This function is defined in ViewRegionGraph and only works with that
diff --git a/mlir/lib/Analysis/CallGraph.cpp b/mlir/lib/Analysis/CallGraph.cpp
index 1cb317fc0f6d..94965c7a623d 100644
--- a/mlir/lib/Analysis/CallGraph.cpp
+++ b/mlir/lib/Analysis/CallGraph.cpp
@@ -87,9 +87,8 @@ static void computeCallGraph(Operation *op, CallGraph &cg,
}
for (Region &region : op->getRegions())
- for (Block &block : region)
- for (Operation &nested : block)
- computeCallGraph(&nested, cg, parentNode, resolveCalls);
+ for (Operation &nested : region.getOps())
+ computeCallGraph(&nested, cg, parentNode, resolveCalls);
}
CallGraph::CallGraph(Operation *op) : externalNode(/*callableRegion=*/nullptr) {
diff --git a/mlir/lib/Conversion/LoopsToGPU/LoopsToGPUPass.cpp b/mlir/lib/Conversion/LoopsToGPU/LoopsToGPUPass.cpp
index a597dd7bf078..94c6a3b20f2b 100644
--- a/mlir/lib/Conversion/LoopsToGPU/LoopsToGPUPass.cpp
+++ b/mlir/lib/Conversion/LoopsToGPU/LoopsToGPUPass.cpp
@@ -36,18 +36,17 @@ struct ForLoopMapper : public ConvertSimpleLoopsToGPUBase<ForLoopMapper> {
}
void runOnFunction() override {
- for (Block &block : getFunction())
- for (Operation &op : llvm::make_early_inc_range(block)) {
- if (auto forOp = dyn_cast<AffineForOp>(&op)) {
- if (failed(convertAffineLoopNestToGPULaunch(forOp, numBlockDims,
- numThreadDims)))
- signalPassFailure();
- } else if (auto forOp = dyn_cast<ForOp>(&op)) {
- if (failed(convertLoopNestToGPULaunch(forOp, numBlockDims,
- numThreadDims)))
- signalPassFailure();
- }
+ for (Operation &op : llvm::make_early_inc_range(getFunction().getOps())) {
+ if (auto forOp = dyn_cast<AffineForOp>(&op)) {
+ if (failed(convertAffineLoopNestToGPULaunch(forOp, numBlockDims,
+ numThreadDims)))
+ signalPassFailure();
+ } else if (auto forOp = dyn_cast<ForOp>(&op)) {
+ if (failed(
+ convertLoopNestToGPULaunch(forOp, numBlockDims, numThreadDims)))
+ signalPassFailure();
}
+ }
}
};
@@ -81,14 +80,10 @@ struct ImperfectlyNestedForLoopMapper
funcOp.getLoc(), builder.getIntegerAttr(builder.getIndexType(), val));
workGroupSizeVal.push_back(constOp);
}
- for (Block &block : getFunction()) {
- for (Operation &op : llvm::make_early_inc_range(block)) {
- if (auto forOp = dyn_cast<ForOp>(&op)) {
- if (failed(convertLoopToGPULaunch(forOp, numWorkGroupsVal,
- workGroupSizeVal))) {
- return signalPassFailure();
- }
- }
+ for (ForOp forOp : llvm::make_early_inc_range(funcOp.getOps<ForOp>())) {
+ if (failed(convertLoopToGPULaunch(forOp, numWorkGroupsVal,
+ workGroupSizeVal))) {
+ return signalPassFailure();
}
}
}
diff --git a/mlir/lib/IR/Region.cpp b/mlir/lib/IR/Region.cpp
index 4f5054112a41..aa2acc00dde4 100644
--- a/mlir/lib/IR/Region.cpp
+++ b/mlir/lib/IR/Region.cpp
@@ -146,34 +146,32 @@ static bool isIsolatedAbove(Region &region, Region &limit,
// Traverse all operations in the region.
while (!pendingRegions.empty()) {
- for (Block &block : *pendingRegions.pop_back_val()) {
- for (Operation &op : block) {
- for (Value operand : op.getOperands()) {
- // operand should be non-null here if the IR is well-formed. But
- // we don't assert here as this function is called from the verifier
- // and so could be called on invalid IR.
- if (!operand) {
- if (noteLoc)
- op.emitOpError("block's operand not defined").attachNote(noteLoc);
- return false;
- }
+ for (Operation &op : pendingRegions.pop_back_val()->getOps()) {
+ for (Value operand : op.getOperands()) {
+ // operand should be non-null here if the IR is well-formed. But
+ // we don't assert here as this function is called from the verifier
+ // and so could be called on invalid IR.
+ if (!operand) {
+ if (noteLoc)
+ op.emitOpError("block's operand not defined").attachNote(noteLoc);
+ return false;
+ }
- // Check that any value that is used by an operation is defined in the
- // same region as either an operation result or a block argument.
- if (operand.getParentRegion()->isProperAncestor(&limit)) {
- if (noteLoc) {
- op.emitOpError("using value defined outside the region")
- .attachNote(noteLoc)
- << "required by region isolation constraints";
- }
- return false;
+ // Check that any value that is used by an operation is defined in the
+ // same region as either an operation result or a block argument.
+ if (operand.getParentRegion()->isProperAncestor(&limit)) {
+ if (noteLoc) {
+ op.emitOpError("using value defined outside the region")
+ .attachNote(noteLoc)
+ << "required by region isolation constraints";
}
+ return false;
}
- // Schedule any regions the operations contain for further checking.
- pendingRegions.reserve(pendingRegions.size() + op.getNumRegions());
- for (Region &subRegion : op.getRegions())
- pendingRegions.push_back(&subRegion);
}
+ // Schedule any regions the operations contain for further checking.
+ pendingRegions.reserve(pendingRegions.size() + op.getNumRegions());
+ for (Region &subRegion : op.getRegions())
+ pendingRegions.push_back(&subRegion);
}
}
return true;
@@ -219,6 +217,40 @@ void llvm::ilist_traits<::mlir::Block>::transferNodesFromList(
first->parentValidOpOrderPair.setPointer(curParent);
}
+//===----------------------------------------------------------------------===//
+// Region::OpIterator
+//===----------------------------------------------------------------------===//
+
+Region::OpIterator::OpIterator(Region *region, bool end)
+ : region(region), block(end ? region->end() : region->begin()) {
+ if (!region->empty())
+ skipOverBlocksWithNoOps();
+}
+
+Region::OpIterator &Region::OpIterator::operator++() {
+ // We increment over operations, if we reach the last use then move to next
+ // block.
+ if (operation != block->end())
+ ++operation;
+ if (operation == block->end()) {
+ ++block;
+ skipOverBlocksWithNoOps();
+ }
+ return *this;
+}
+
+void Region::OpIterator::skipOverBlocksWithNoOps() {
+ while (block != region->end() && block->empty())
+ ++block;
+
+ // If we are at the last block, then set the operation to first operation of
+ // next block (sentinel value used for end).
+ if (block == region->end())
+ operation = {};
+ else
+ operation = block->begin();
+}
+
//===----------------------------------------------------------------------===//
// RegionRange
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/IR/SymbolTable.cpp b/mlir/lib/IR/SymbolTable.cpp
index e195225d675a..1d2235b61936 100644
--- a/mlir/lib/IR/SymbolTable.cpp
+++ b/mlir/lib/IR/SymbolTable.cpp
@@ -245,11 +245,9 @@ Operation *SymbolTable::lookupSymbolIn(Operation *symbolTableOp,
assert(symbolTableOp->hasTrait<OpTrait::SymbolTable>());
// Look for a symbol with the given name.
- for (auto &block : symbolTableOp->getRegion(0)) {
- for (auto &op : block)
- if (getNameIfSymbol(&op) == symbol)
- return &op;
- }
+ for (auto &op : symbolTableOp->getRegion(0).front().without_terminator())
+ if (getNameIfSymbol(&op) == symbol)
+ return &op;
return nullptr;
}
Operation *SymbolTable::lookupSymbolIn(Operation *symbolTableOp,
@@ -444,21 +442,19 @@ static Optional<WalkResult> walkSymbolUses(
function_ref<WalkResult(SymbolTable::SymbolUse, ArrayRef<int>)> callback) {
SmallVector<Region *, 1> worklist(llvm::make_pointer_range(regions));
while (!worklist.empty()) {
- for (Block &block : *worklist.pop_back_val()) {
- for (Operation &op : block) {
- if (walkSymbolRefs(&op, callback).wasInterrupted())
- return WalkResult::interrupt();
-
- // Check that this isn't a potentially unknown symbol table.
- if (isPotentiallyUnknownSymbolTable(&op))
- return llvm::None;
-
- // If this op defines a new symbol table scope, we can't traverse. Any
- symbol references nested within 'op' are different semantically.
- if (!op.hasTrait<OpTrait::SymbolTable>()) {
- for (Region &region : op.getRegions())
- worklist.push_back(&region);
- }
+ for (Operation &op : worklist.pop_back_val()->getOps()) {
+ if (walkSymbolRefs(&op, callback).wasInterrupted())
+ return WalkResult::interrupt();
+
+ // Check that this isn't a potentially unknown symbol table.
+ if (isPotentiallyUnknownSymbolTable(&op))
+ return llvm::None;
+
+ // If this op defines a new symbol table scope, we can't traverse. Any
+ // symbol references nested within 'op' are different semantically.
+ if (!op.hasTrait<OpTrait::SymbolTable>()) {
+ for (Region &region : op.getRegions())
+ worklist.push_back(&region);
}
}
}
diff --git a/mlir/lib/Transforms/Inliner.cpp b/mlir/lib/Transforms/Inliner.cpp
index ee645cb55511..f8f48d86b562 100644
--- a/mlir/lib/Transforms/Inliner.cpp
+++ b/mlir/lib/Transforms/Inliner.cpp
@@ -122,23 +122,21 @@ CGUseList::CGUseList(Operation *op, CallGraph &cg) {
// Walk each of the symbol tables looking for discardable callgraph nodes.
auto walkFn = [&](Operation *symbolTableOp, bool allUsesVisible) {
- for (Block &block : symbolTableOp->getRegion(0)) {
- for (Operation &op : block) {
- // If this is a callgraph operation, check to see if it is discardable.
- if (auto callable = dyn_cast<CallableOpInterface>(&op)) {
- if (auto *node = cg.lookupNode(callable.getCallableRegion())) {
- SymbolOpInterface symbol = dyn_cast<SymbolOpInterface>(&op);
- if (symbol && (allUsesVisible || symbol.isPrivate()) &&
- symbol.canDiscardOnUseEmpty()) {
- discardableSymNodeUses.try_emplace(node, 0);
- }
- continue;
+ for (Operation &op : symbolTableOp->getRegion(0).getOps()) {
+ // If this is a callgraph operation, check to see if it is discardable.
+ if (auto callable = dyn_cast<CallableOpInterface>(&op)) {
+ if (auto *node = cg.lookupNode(callable.getCallableRegion())) {
+ SymbolOpInterface symbol = dyn_cast<SymbolOpInterface>(&op);
+ if (symbol && (allUsesVisible || symbol.isPrivate()) &&
+ symbol.canDiscardOnUseEmpty()) {
+ discardableSymNodeUses.try_emplace(node, 0);
}
+ continue;
}
- // Otherwise, check for any referenced nodes. These will be always-live.
- walkReferencedSymbolNodes(&op, cg, alwaysLiveNodes,
- [](CallGraphNode *, Operation *) {});
}
+ // Otherwise, check for any referenced nodes. These will be always-live.
+ walkReferencedSymbolNodes(&op, cg, alwaysLiveNodes,
+ [](CallGraphNode *, Operation *) {});
}
};
SymbolTable::walkSymbolTables(op, /*allSymUsesVisible=*/!op->getBlock(),
More information about the Mlir-commits
mailing list