[Mlir-commits] [mlir] bb0d5f7 - [mlir] Add NumberOfExecutions analysis + update RegionBranchOpInterface interface to query number of region invocations
Eugene Zhulenev
llvmlistbot at llvm.org
Wed Nov 11 01:43:24 PST 2020
Author: Eugene Zhulenev
Date: 2020-11-11T01:43:17-08:00
New Revision: bb0d5f767dd7cf34a92ba2af2d6fdb206d883e8c
URL: https://github.com/llvm/llvm-project/commit/bb0d5f767dd7cf34a92ba2af2d6fdb206d883e8c
DIFF: https://github.com/llvm/llvm-project/commit/bb0d5f767dd7cf34a92ba2af2d6fdb206d883e8c.diff
LOG: [mlir] Add NumberOfExecutions analysis + update RegionBranchOpInterface interface to query number of region invocations
Implements RFC discussed in: https://llvm.discourse.group/t/rfc-operationinstancesinterface-or-any-better-name/2158/10
Reviewed By: silvas, ftynse, rriddle
Differential Revision: https://reviews.llvm.org/D90922
Added:
mlir/include/mlir/Analysis/NumberOfExecutions.h
mlir/lib/Analysis/NumberOfExecutions.cpp
mlir/test/Analysis/test-number-of-block-executions.mlir
mlir/test/Analysis/test-number-of-operation-executions.mlir
mlir/test/lib/Transforms/TestNumberOfExecutions.cpp
mlir/unittests/Support/MathExtrasTest.cpp
Modified:
mlir/include/mlir/Dialect/Async/IR/Async.h
mlir/include/mlir/Dialect/Async/IR/AsyncOps.td
mlir/include/mlir/Dialect/SCF/SCFOps.td
mlir/include/mlir/Interfaces/ControlFlowInterfaces.h
mlir/include/mlir/Interfaces/ControlFlowInterfaces.td
mlir/include/mlir/Support/MathExtras.h
mlir/lib/Analysis/CMakeLists.txt
mlir/lib/Dialect/Async/IR/Async.cpp
mlir/lib/Dialect/SCF/SCF.cpp
mlir/lib/Interfaces/ControlFlowInterfaces.cpp
mlir/test/lib/Transforms/CMakeLists.txt
mlir/tools/mlir-opt/mlir-opt.cpp
mlir/unittests/Support/CMakeLists.txt
Removed:
################################################################################
diff --git a/mlir/include/mlir/Analysis/NumberOfExecutions.h b/mlir/include/mlir/Analysis/NumberOfExecutions.h
new file mode 100644
index 000000000000..aa3080431cf6
--- /dev/null
+++ b/mlir/include/mlir/Analysis/NumberOfExecutions.h
@@ -0,0 +1,107 @@
+//===- NumberOfExecutions.h - Number of executions analysis -----*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file contains an analysis for computing how many times a block within a
+// region is executed *each time* that region is entered. The analysis
+// iterates over all associated regions that are attached to the given top-level
+// operation.
+//
+// The number of executions can be queried for individual blocks and for
+// individual operations.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_ANALYSIS_NUMBER_OF_EXECUTIONS_H
+#define MLIR_ANALYSIS_NUMBER_OF_EXECUTIONS_H
+
+#include "mlir/Support/LLVM.h"
+#include "llvm/ADT/DenseMap.h"
+#include "llvm/ADT/Optional.h"
+
+namespace mlir {
+
+class Block;
+class BlockNumberOfExecutionsInfo;
+class Operation;
+class Region;
+
+/// Represents an analysis for computing how many times a block or an operation
+/// within a region is executed *each time* that region is entered. The analysis
+/// iterates over all associated regions that are attached to the given
+/// top-level operation.
+///
+/// This analysis assumes that all operations complete in a finite amount of
+/// time (they do not abort and do not enter an infinite loop).
+class NumberOfExecutions {
+public:
+ /// Creates a new NumberOfExecutions analysis that computes how many times a
+ /// block within a region is executed for all associated regions. All
+ /// per-block information is computed here, at construction time, by walking
+ /// the regions nested under `op`.
+ explicit NumberOfExecutions(Operation *op);
+
+ /// Returns the number of times operation `op` is executed *each time* the
+ /// control flow enters the region `perEntryOfThisRegion`. Returns empty
+ /// optional if this is not known statically. The result is the number of
+ /// executions of the block that contains `op`.
+ Optional<int64_t> getNumberOfExecutions(Operation *op,
+ Region *perEntryOfThisRegion) const;
+
+ /// Returns the number of times block `block` is executed *each time* the
+ /// control flow enters the region `perEntryOfThisRegion`. Returns empty
+ /// optional if this is not known statically, or if `block` is not nested
+ /// inside `perEntryOfThisRegion`.
+ Optional<int64_t> getNumberOfExecutions(Block *block,
+ Region *perEntryOfThisRegion) const;
+
+ /// Dumps the number of block executions *each time* the control flow enters
+ /// the region `perEntryOfThisRegion` to the given stream.
+ void printBlockExecutions(raw_ostream &os,
+ Region *perEntryOfThisRegion) const;
+
+ /// Dumps the number of operation executions *each time* the control flow
+ /// enters the region `perEntryOfThisRegion` to the given stream.
+ void printOperationExecutions(raw_ostream &os,
+ Region *perEntryOfThisRegion) const;
+
+private:
+ /// The operation this analysis was constructed from.
+ Operation *operation;
+
+ /// A mapping from blocks to number of executions information. Populated once
+ /// by the constructor and only read afterwards.
+ DenseMap<Block *, BlockNumberOfExecutionsInfo> blockNumbersOfExecution;
+};
+
+/// Represents number of block executions information: how often the parent
+/// region of a block is invoked by its parent operation, and how often the
+/// block itself runs per region invocation. Either quantity may be unknown.
+class BlockNumberOfExecutionsInfo {
+public:
+  /// Stores the given counts for `block`. An empty optional means the
+  /// corresponding count is not known statically.
+  BlockNumberOfExecutionsInfo(Block *block,
+                              Optional<int64_t> numberOfRegionInvocations,
+                              Optional<int64_t> numberOfBlockExecutions);
+
+  /// Returns the number of times this block will be executed *each time* the
+  /// parent operation is executed. Returns empty optional if either the number
+  /// of region invocations or the number of block executions is unknown.
+  Optional<int64_t> getNumberOfExecutions() const;
+
+  /// Returns the number of times this block will be executed if the parent
+  /// region is invoked `numberOfRegionInvocations` times. This can be
+  /// different from the number of region invocations by the parent operation.
+  /// Returns empty optional if the number of block executions is unknown.
+  Optional<int64_t>
+  getNumberOfExecutions(int64_t numberOfRegionInvocations) const;
+
+  /// Returns the block this information describes.
+  Block *getBlock() const { return block; }
+
+private:
+  /// The block this information describes.
+  Block *block;
+
+  /// Number of `block` parent region invocations *each time* parent operation
+  /// is executed.
+  Optional<int64_t> numberOfRegionInvocations;
+
+  /// Number of `block` executions *each time* parent region is invoked.
+  Optional<int64_t> numberOfBlockExecutions;
+};
+
+} // end namespace mlir
+
+#endif // MLIR_ANALYSIS_NUMBER_OF_EXECUTIONS_H
diff --git a/mlir/include/mlir/Dialect/Async/IR/Async.h b/mlir/include/mlir/Dialect/Async/IR/Async.h
index 1519ccd1bdfc..28790557f36b 100644
--- a/mlir/include/mlir/Dialect/Async/IR/Async.h
+++ b/mlir/include/mlir/Dialect/Async/IR/Async.h
@@ -19,6 +19,7 @@
#include "mlir/IR/OpDefinition.h"
#include "mlir/IR/OpImplementation.h"
#include "mlir/IR/StandardTypes.h"
+#include "mlir/Interfaces/ControlFlowInterfaces.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"
namespace mlir {
diff --git a/mlir/include/mlir/Dialect/Async/IR/AsyncOps.td b/mlir/include/mlir/Dialect/Async/IR/AsyncOps.td
index a1e489007f0f..7fad5ce48214 100644
--- a/mlir/include/mlir/Dialect/Async/IR/AsyncOps.td
+++ b/mlir/include/mlir/Dialect/Async/IR/AsyncOps.td
@@ -14,6 +14,7 @@
#define ASYNC_OPS
include "mlir/Dialect/Async/IR/AsyncBase.td"
+include "mlir/Interfaces/ControlFlowInterfaces.td"
include "mlir/Interfaces/SideEffectInterfaces.td"
//===----------------------------------------------------------------------===//
@@ -26,6 +27,8 @@ class Async_Op<string mnemonic, list<OpTrait> traits = []> :
def Async_ExecuteOp :
Async_Op<"execute", [SingleBlockImplicitTerminator<"YieldOp">,
+ DeclareOpInterfaceMethods<RegionBranchOpInterface,
+ ["getNumRegionInvocations"]>,
AttrSizedOperandSegments]> {
let summary = "Asynchronous execute operation";
let description = [{
diff --git a/mlir/include/mlir/Dialect/SCF/SCFOps.td b/mlir/include/mlir/Dialect/SCF/SCFOps.td
index 1dc6ef1f68a4..f30e8f20c4b7 100644
--- a/mlir/include/mlir/Dialect/SCF/SCFOps.td
+++ b/mlir/include/mlir/Dialect/SCF/SCFOps.td
@@ -196,6 +196,11 @@ def ForOp : SCF_Op<"for",
/// induction variable. LoopOp only has one region, so 0 is the only valid
/// value for `index`.
OperandRange getSuccessorEntryOperands(unsigned index);
+
+ /// Returns the number of invocations of the body block if the loop bounds
+ /// are constants. Returns `kUnknownNumRegionInvocations` otherwise.
+ void getNumRegionInvocations(ArrayRef<Attribute> operands,
+ SmallVectorImpl<int64_t> &countPerRegion);
}];
let hasCanonicalizer = 1;
diff --git a/mlir/include/mlir/Interfaces/ControlFlowInterfaces.h b/mlir/include/mlir/Interfaces/ControlFlowInterfaces.h
index 725e13b8b9d2..57ff63fbf784 100644
--- a/mlir/include/mlir/Interfaces/ControlFlowInterfaces.h
+++ b/mlir/include/mlir/Interfaces/ControlFlowInterfaces.h
@@ -41,6 +41,9 @@ LogicalResult verifyBranchSuccessorOperands(Operation *op, unsigned succNo,
// RegionBranchOpInterface
//===----------------------------------------------------------------------===//
+// A constant value to represent unknown number of region invocations.
+extern const int64_t kUnknownNumRegionInvocations;
+
namespace detail {
/// Verify that types match along control flow edges described the given op.
LogicalResult verifyTypesAlongControlFlowEdges(Operation *op);
diff --git a/mlir/include/mlir/Interfaces/ControlFlowInterfaces.td b/mlir/include/mlir/Interfaces/ControlFlowInterfaces.td
index 217ab263877f..3c568a0e776f 100644
--- a/mlir/include/mlir/Interfaces/ControlFlowInterfaces.td
+++ b/mlir/include/mlir/Interfaces/ControlFlowInterfaces.td
@@ -130,6 +130,26 @@ def RegionBranchOpInterface : OpInterface<"RegionBranchOpInterface"> {
"void", "getSuccessorRegions",
(ins "Optional<unsigned>":$index, "ArrayRef<Attribute>":$operands,
"SmallVectorImpl<RegionSuccessor> &":$regions)
+ >,
+ InterfaceMethod<[{
+ Populates countPerRegion with the number of times this operation will
+ invoke the attached regions (assuming the regions yield normally, i.e.
+ do not abort or invoke an infinite loop). If the number of region
+ invocations is not known statically it will set the number of
+ invocations to `kUnknownNumRegionInvocations`.
+
+ `operands` is a set of optional attributes that either correspond to a
+ constant values for each operand of this operation, or null if that
+ operand is not a constant.
+ }],
+ "void", "getNumRegionInvocations",
+ (ins "ArrayRef<Attribute>":$operands,
+ "SmallVectorImpl<int64_t> &":$countPerRegion), [{}],
+ /*defaultImplementation=*/[{
+ unsigned numRegions = this->getOperation()->getNumRegions();
+ assert(countPerRegion.empty());
+ countPerRegion.resize(numRegions, kUnknownNumRegionInvocations);
+ }]
>
];
diff --git a/mlir/include/mlir/Support/MathExtras.h b/mlir/include/mlir/Support/MathExtras.h
index 98bdacedf657..65429e5f0265 100644
--- a/mlir/include/mlir/Support/MathExtras.h
+++ b/mlir/include/mlir/Support/MathExtras.h
@@ -19,19 +19,21 @@
namespace mlir {
/// Returns the result of MLIR's ceildiv operation on constants. The RHS is
-/// expected to be positive.
+/// expected to be non-zero. The signs of the operands are compared directly
+/// instead of testing the sign of `lhs * rhs`: that product overflows (which
+/// is undefined behavior) for large operands and then selects the wrong
+/// rounding branch.
inline int64_t ceilDiv(int64_t lhs, int64_t rhs) {
- assert(rhs >= 1);
+  assert(rhs != 0);
// C/C++'s integer division rounds towards 0.
- return lhs % rhs > 0 ? lhs / rhs + 1 : lhs / rhs;
+  int64_t x = (rhs > 0) ? -1 : 1;
+  // Non-zero operands of the same sign have a positive quotient that must be
+  // rounded up; otherwise truncation towards zero already is the ceiling.
+  return ((lhs != 0) && ((lhs > 0) == (rhs > 0))) ? ((lhs + x) / rhs) + 1
+                                                  : -(-lhs / rhs);
}
/// Returns the result of MLIR's floordiv operation on constants. The RHS is
-/// expected to be positive.
+/// expected to be non-zero. The signs of the operands are compared directly
+/// instead of testing the sign of `lhs * rhs`: that product overflows (which
+/// is undefined behavior) for large operands and then selects the wrong
+/// rounding branch.
inline int64_t floorDiv(int64_t lhs, int64_t rhs) {
- assert(rhs >= 1);
+  assert(rhs != 0);
// C/C++'s integer division rounds towards 0.
- return lhs % rhs < 0 ? lhs / rhs - 1 : lhs / rhs;
+  int64_t x = (rhs < 0) ? 1 : -1;
+  // Non-zero operands of opposite sign have a negative quotient that must be
+  // rounded down; otherwise truncation towards zero already is the floor.
+  return ((lhs != 0) && ((lhs < 0) != (rhs < 0))) ? -((-lhs + x) / rhs) - 1
+                                                  : lhs / rhs;
}
/// Returns MLIR's mod operation on constants. MLIR's mod operation yields the
diff --git a/mlir/lib/Analysis/CMakeLists.txt b/mlir/lib/Analysis/CMakeLists.txt
index 632310bc5f13..3247ef1f56b0 100644
--- a/mlir/lib/Analysis/CMakeLists.txt
+++ b/mlir/lib/Analysis/CMakeLists.txt
@@ -6,6 +6,7 @@ set(LLVM_OPTIONAL_SOURCES
Liveness.cpp
LoopAnalysis.cpp
NestedMatcher.cpp
+ NumberOfExecutions.cpp
PresburgerSet.cpp
SliceAnalysis.cpp
Utils.cpp
@@ -15,6 +16,7 @@ add_mlir_library(MLIRAnalysis
BufferAliasAnalysis.cpp
CallGraph.cpp
Liveness.cpp
+ NumberOfExecutions.cpp
SliceAnalysis.cpp
ADDITIONAL_HEADER_DIRS
@@ -53,5 +55,5 @@ add_mlir_library(MLIRLoopAnalysis
MLIRPresburger
MLIRSCF
)
-
+
add_subdirectory(Presburger)
diff --git a/mlir/lib/Analysis/NumberOfExecutions.cpp b/mlir/lib/Analysis/NumberOfExecutions.cpp
new file mode 100644
index 000000000000..425936b5eaaf
--- /dev/null
+++ b/mlir/lib/Analysis/NumberOfExecutions.cpp
@@ -0,0 +1,243 @@
+//===- NumberOfExecutions.cpp - Number of executions analysis -------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// Implementation of the number of executions analysis.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Analysis/NumberOfExecutions.h"
+#include "mlir/Dialect/StandardOps/IR/Ops.h"
+#include "mlir/IR/Matchers.h"
+#include "mlir/IR/RegionKindInterface.h"
+#include "mlir/Interfaces/ControlFlowInterfaces.h"
+
+#include "llvm/ADT/FunctionExtras.h"
+#include "llvm/ADT/SmallSet.h"
+#include "llvm/Support/raw_ostream.h"
+
+#define DEBUG_TYPE "number-of-executions-analysis"
+
+using namespace mlir;
+
+//===----------------------------------------------------------------------===//
+// NumberOfExecutions
+//===----------------------------------------------------------------------===//
+
+/// Computes the number-of-executions information for every block of `region`
+/// and records it in `blockInfo`.
+///
+/// Blocks of graph regions get fully unknown information. For all other
+/// regions the per-region invocation count is taken from
+/// RegionBranchOpInterface when the parent operation implements it. Blocks
+/// reached from the entry block by following single-successor edges, stopping
+/// at the first block that starts a CFG loop, are marked as executed exactly
+/// once per region invocation; every other block gets an unknown count.
+static void computeRegionBlockNumberOfExecutions(
+    Region &region, DenseMap<Block *, BlockNumberOfExecutionsInfo> &blockInfo) {
+  // A region without blocks has nothing to analyze, and taking `front()` of
+  // it below would be invalid.
+  if (region.empty())
+    return;
+
+  Operation *parentOp = region.getParentOp();
+  int regionId = region.getRegionNumber();
+
+  auto regionKindInterface = dyn_cast<RegionKindInterface>(parentOp);
+  bool isGraphRegion =
+      regionKindInterface &&
+      regionKindInterface.getRegionKind(regionId) == RegionKind::Graph;
+
+  // CFG analysis does not make sense for Graph regions, set the number of
+  // executions for all blocks as unknown.
+  if (isGraphRegion) {
+    for (Block &block : region)
+      blockInfo.insert({&block, {&block, None, None}});
+    return;
+  }
+
+  // Number of region invocations for all attached regions.
+  SmallVector<int64_t, 4> numRegionsInvocations;
+
+  // Query RegionBranchOpInterface interface if it is available.
+  if (auto regionInterface = dyn_cast<RegionBranchOpInterface>(parentOp)) {
+    // Operands that are not constants keep their null attribute.
+    SmallVector<Attribute, 4> operands(parentOp->getNumOperands());
+    for (auto operandIt : llvm::enumerate(parentOp->getOperands()))
+      matchPattern(operandIt.value(), m_Constant(&operands[operandIt.index()]));
+
+    regionInterface.getNumRegionInvocations(operands, numRegionsInvocations);
+  }
+
+  // Number of region invocations *each time* parent operation is invoked.
+  Optional<int64_t> numRegionInvocations;
+
+  if (!numRegionsInvocations.empty() &&
+      numRegionsInvocations[regionId] != kUnknownNumRegionInvocations) {
+    numRegionInvocations = numRegionsInvocations[regionId];
+  }
+
+  // DFS traversal looking for loops in the CFG.
+  llvm::SmallSet<Block *, 4> loopStart;
+
+  llvm::unique_function<void(Block *, llvm::SmallSet<Block *, 4> &)> dfs =
+      [&](Block *block, llvm::SmallSet<Block *, 4> &visited) {
+        // Found a loop in the CFG that starts at the `block`.
+        if (visited.contains(block)) {
+          loopStart.insert(block);
+          return;
+        }
+
+        // Continue DFS traversal. `visited` only holds the blocks on the
+        // current DFS path, so a block is erased again once its successors
+        // have been explored.
+        visited.insert(block);
+        for (Block *successor : block->getSuccessors())
+          dfs(successor, visited);
+        visited.erase(block);
+      };
+
+  llvm::SmallSet<Block *, 4> visited;
+  dfs(&region.front(), visited);
+
+  // Start from the entry block and follow only blocks with single successor.
+  Block *block = &region.front();
+  while (block && !loopStart.contains(block)) {
+    // Block will be executed exactly once.
+    blockInfo.insert(
+        {block, BlockNumberOfExecutionsInfo(block, numRegionInvocations,
+                                            /*numberOfBlockExecutions=*/1)});
+
+    // We reached the exit block or block with multiple successors.
+    if (block->getNumSuccessors() != 1)
+      break;
+
+    // Continue traversal.
+    block = block->getSuccessor(0);
+  }
+
+  // For all blocks that we did not visit set the executions number to unknown.
+  for (Block &block : region)
+    if (blockInfo.count(&block) == 0)
+      blockInfo.insert({&block, BlockNumberOfExecutionsInfo(
+                                    &block, numRegionInvocations,
+                                    /*numberOfBlockExecutions=*/None)});
+}
+
+/// Builds the analysis for `op`: every region visited by walking `op` is
+/// analyzed and its per-block information is stored in
+/// `blockNumbersOfExecution`.
+NumberOfExecutions::NumberOfExecutions(Operation *op) : operation(op) {
+  operation->walk([this](Region *nestedRegion) {
+    Region &current = *nestedRegion;
+    computeRegionBlockNumberOfExecutions(current, blockNumbersOfExecution);
+  });
+}
+
+Optional<int64_t>
+NumberOfExecutions::getNumberOfExecutions(Operation *op,
+                                          Region *perEntryOfThisRegion) const {
+  // The analysis assumes every operation finishes in finite time (no aborts,
+  // no infinite loops), so an operation runs exactly as often as the block
+  // that contains it.
+  Block *enclosingBlock = op->getBlock();
+  return getNumberOfExecutions(enclosingBlock, perEntryOfThisRegion);
+}
+
+Optional<int64_t>
+NumberOfExecutions::getNumberOfExecutions(Block *block,
+                                          Region *perEntryOfThisRegion) const {
+  // Nothing can be said about a block that is not nested inside
+  // `perEntryOfThisRegion`.
+  if (!perEntryOfThisRegion->findAncestorBlockInRegion(*block))
+    return None;
+
+  // Execution count of a single block. When the block is owned directly by
+  // `perEntryOfThisRegion`, the region is considered to be entered once.
+  auto executionsOf = [&](const BlockNumberOfExecutionsInfo &info) {
+    return info.getBlock()->getParent() == perEntryOfThisRegion
+               ? info.getNumberOfExecutions(/*numberOfRegionInvocations=*/1)
+               : info.getNumberOfExecutions();
+  };
+
+  // Look up the information recorded for `block` itself; give up right away
+  // if it is missing or its count is unknown.
+  auto ownIt = blockNumbersOfExecution.find(block);
+  if (ownIt == blockNumbersOfExecution.end())
+    return None;
+  Optional<int64_t> ownExecutions = executionsOf(ownIt->getSecond());
+  if (!ownExecutions.hasValue())
+    return None;
+
+  // Walk up the chain of parent operations until the one that owns
+  // `perEntryOfThisRegion`, multiplying in the execution count of each block
+  // that holds an ancestor operation.
+  int64_t total = *ownExecutions;
+  Operation *const regionOwner = perEntryOfThisRegion->getParentOp();
+
+  for (Operation *ancestor = block->getParentOp(); ancestor != regionOwner;) {
+    auto ancestorIt = blockNumbersOfExecution.find(ancestor->getBlock());
+    if (ancestorIt == blockNumbersOfExecution.end())
+      return None;
+
+    // An ancestor with an unknown number of executions makes the whole
+    // product unknown.
+    Optional<int64_t> ancestorExecutions =
+        executionsOf(ancestorIt->getSecond());
+    if (!ancestorExecutions.hasValue())
+      return None;
+
+    total *= *ancestorExecutions;
+    ancestor = ancestor->getParentOp();
+
+    assert(ancestor != nullptr);
+  }
+
+  return total;
+}
+
+/// Prints, for every block visited by walking the analyzed operation, a
+/// running block index and the number of executions of that block per entry
+/// of `perEntryOfThisRegion` (or `<unknown>`). Output goes to `os`.
+void NumberOfExecutions::printBlockExecutions(
+    raw_ostream &os, Region *perEntryOfThisRegion) const {
+  // Blocks are numbered in the order in which the walk visits them.
+  unsigned blockId = 0;
+
+  operation->walk([&](Block *block) {
+    os << "Block: " << blockId++ << "\n";
+    os << "Number of executions: ";
+    if (auto n = getNumberOfExecutions(block, perEntryOfThisRegion))
+      os << *n << "\n";
+    else
+      os << "<unknown>\n";
+  });
+}
+
+/// Prints, for every operation nested under the analyzed operation, its name
+/// and the number of times it executes per entry of `perEntryOfThisRegion`
+/// (or `<unknown>`). Output goes to `os`.
+void NumberOfExecutions::printOperationExecutions(
+    raw_ostream &os, Region *perEntryOfThisRegion) const {
+  operation->walk([&](Block *block) {
+    block->walk([&](Operation *operation) {
+      // Skip the operation that was used to build the analysis.
+      if (operation == this->operation)
+        return;
+
+      os << "Operation: " << operation->getName() << "\n";
+      os << "Number of executions: ";
+      if (auto n = getNumberOfExecutions(operation, perEntryOfThisRegion))
+        os << *n << "\n";
+      else
+        os << "<unknown>\n";
+    });
+  });
+}
+
+//===----------------------------------------------------------------------===//
+// BlockNumberOfExecutionsInfo
+//===----------------------------------------------------------------------===//
+
+/// Stores the block together with the two (possibly unknown) counts: how many
+/// times the parent region is invoked, and how many times the block executes
+/// per region invocation.
+BlockNumberOfExecutionsInfo::BlockNumberOfExecutionsInfo(
+ Block *block, Optional<int64_t> numberOfRegionInvocations,
+ Optional<int64_t> numberOfBlockExecutions)
+ : block(block), numberOfRegionInvocations(numberOfRegionInvocations),
+ numberOfBlockExecutions(numberOfBlockExecutions) {}
+
+/// Returns the product of the region invocation count and the per-invocation
+/// block execution count, or None if either of them is unknown.
+Optional<int64_t> BlockNumberOfExecutionsInfo::getNumberOfExecutions() const {
+  if (!numberOfRegionInvocations || !numberOfBlockExecutions)
+    return None;
+
+  const int64_t regionInvocations = *numberOfRegionInvocations;
+  const int64_t blockExecutions = *numberOfBlockExecutions;
+  return regionInvocations * blockExecutions;
+}
+
+/// Returns the block execution count scaled by the caller-supplied number of
+/// region invocations (the stored region invocation count is ignored), or
+/// None if the per-invocation block execution count is unknown.
+Optional<int64_t> BlockNumberOfExecutionsInfo::getNumberOfExecutions(
+    int64_t numberOfRegionInvocations) const {
+  if (!numberOfBlockExecutions)
+    return None;
+
+  const int64_t blockExecutions = *numberOfBlockExecutions;
+  return numberOfRegionInvocations * blockExecutions;
+}
diff --git a/mlir/lib/Dialect/Async/IR/Async.cpp b/mlir/lib/Dialect/Async/IR/Async.cpp
index 36ef1af8cfa7..0c22ea22a9da 100644
--- a/mlir/lib/Dialect/Async/IR/Async.cpp
+++ b/mlir/lib/Dialect/Async/IR/Async.cpp
@@ -118,6 +118,27 @@ static LogicalResult verify(YieldOp op) {
constexpr char kOperandSegmentSizesAttr[] = "operand_segment_sizes";
+/// Reports a single invocation for the only entry pushed into
+/// `countPerRegion`; the operand constants are not consulted.
+void ExecuteOp::getNumRegionInvocations(
+    ArrayRef<Attribute> operands, SmallVectorImpl<int64_t> &countPerRegion) {
+  // The caller is expected to pass an empty output vector.
+  assert(countPerRegion.empty());
+  (void)operands;
+
+  const int64_t bodyInvocations = 1;
+  countPerRegion.push_back(bodyInvocations);
+}
+
+/// Appends the control-flow successors to `regions`. When `index` is set
+/// (control flow leaves region `index`, which must be 0), the successor is the
+/// parent operation, represented by its results. When `index` is empty
+/// (control flow enters the operation), the successor is the `body` region.
+/// `operands` is not used.
+void ExecuteOp::getSuccessorRegions(Optional<unsigned> index,
+                                    ArrayRef<Attribute> operands,
+                                    SmallVectorImpl<RegionSuccessor> &regions) {
+  // The `body` region branches back to the parent operation.
+  if (index.hasValue()) {
+    assert(*index == 0);
+    regions.push_back(RegionSuccessor(getResults()));
+    return;
+  }
+
+  // Otherwise the successor is the body region.
+  regions.push_back(RegionSuccessor(&body()));
+}
+
static void print(OpAsmPrinter &p, ExecuteOp op) {
p << op.getOperationName();
diff --git a/mlir/lib/Dialect/SCF/SCF.cpp b/mlir/lib/Dialect/SCF/SCF.cpp
index ea607a3a402a..bc8671b9ba85 100644
--- a/mlir/lib/Dialect/SCF/SCF.cpp
+++ b/mlir/lib/Dialect/SCF/SCF.cpp
@@ -10,6 +10,7 @@
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/IR/BlockAndValueMapping.h"
#include "mlir/IR/PatternMatch.h"
+#include "mlir/Support/MathExtras.h"
#include "mlir/Transforms/InliningUtils.h"
using namespace mlir;
@@ -284,6 +285,26 @@ void ForOp::getSuccessorRegions(Optional<unsigned> index,
regions.push_back(RegionSuccessor(getResults()));
}
+/// Sets `countPerRegion` to a single entry holding the number of invocations
+/// of the loop body. The count is `ceilDiv(ub - lb, step)` when the lower
+/// bound, upper bound and step (operands 0, 1 and 2) are all integer
+/// constants and the step is non-zero; it is `kUnknownNumRegionInvocations`
+/// otherwise. A negative quotient (empty iteration space) is reported as zero
+/// invocations so that it can never be mistaken for the "unknown" sentinel.
+void ForOp::getNumRegionInvocations(ArrayRef<Attribute> operands,
+                                    SmallVectorImpl<int64_t> &countPerRegion) {
+  assert(countPerRegion.empty());
+  countPerRegion.resize(1);
+
+  // Null attributes mean the corresponding operand is not a constant.
+  auto lb = operands[0].dyn_cast_or_null<IntegerAttr>();
+  auto ub = operands[1].dyn_cast_or_null<IntegerAttr>();
+  auto step = operands[2].dyn_cast_or_null<IntegerAttr>();
+
+  // Loop bounds are not known statically.
+  if (!lb || !ub || !step || step.getValue().getSExtValue() == 0) {
+    countPerRegion[0] = kUnknownNumRegionInvocations;
+    return;
+  }
+
+  int64_t tripCount =
+      ceilDiv(ub.getValue().getSExtValue() - lb.getValue().getSExtValue(),
+              step.getValue().getSExtValue());
+
+  // The body of a loop with an empty iteration space never runs.
+  countPerRegion[0] = tripCount < 0 ? 0 : tripCount;
+}
+
ValueVector mlir::scf::buildLoopNest(
OpBuilder &builder, Location loc, ValueRange lbs, ValueRange ubs,
ValueRange steps, ValueRange iterArgs,
diff --git a/mlir/lib/Interfaces/ControlFlowInterfaces.cpp b/mlir/lib/Interfaces/ControlFlowInterfaces.cpp
index 99e86bcb5057..5b845584b2e8 100644
--- a/mlir/lib/Interfaces/ControlFlowInterfaces.cpp
+++ b/mlir/lib/Interfaces/ControlFlowInterfaces.cpp
@@ -73,6 +73,9 @@ detail::verifyBranchSuccessorOperands(Operation *op, unsigned succNo,
// RegionBranchOpInterface
//===----------------------------------------------------------------------===//
+// A constant value to represent unknown number of region invocations.
+const int64_t mlir::kUnknownNumRegionInvocations = -1;
+
/// Verify that types match along all region control flow edges originating from
/// `sourceNo` (region # if source is a region, llvm::None if source is parent
/// op). `getInputsTypesForRegion` is a function that returns the types of the
diff --git a/mlir/test/Analysis/test-number-of-block-executions.mlir b/mlir/test/Analysis/test-number-of-block-executions.mlir
new file mode 100644
index 000000000000..d6ac9d2709a4
--- /dev/null
+++ b/mlir/test/Analysis/test-number-of-block-executions.mlir
@@ -0,0 +1,192 @@
+// RUN: mlir-opt %s \
+// RUN: -test-print-number-of-block-executions \
+// RUN: -split-input-file 2>&1 \
+// RUN: | FileCheck %s --dump-input=always
+
+// CHECK-LABEL: Number of executions: empty
+func @empty() {
+ // CHECK: Block: 0
+ // CHECK-NEXT: Number of executions: 1
+ return
+}
+
+// -----
+
+// CHECK-LABEL: Number of executions: sequential
+func @sequential() {
+ // CHECK: Block: 0
+ // CHECK-NEXT: Number of executions: 1
+ br ^bb1
+^bb1:
+ // CHECK: Block: 1
+ // CHECK-NEXT: Number of executions: 1
+ br ^bb2
+^bb2:
+ // CHECK: Block: 2
+ // CHECK-NEXT: Number of executions: 1
+ return
+}
+
+// -----
+
+// CHECK-LABEL: Number of executions: conditional
+func @conditional(%cond : i1) {
+ // CHECK: Block: 0
+ // CHECK-NEXT: Number of executions: 1
+ br ^bb1
+^bb1:
+ // CHECK: Block: 1
+ // CHECK-NEXT: Number of executions: 1
+ cond_br %cond, ^bb2, ^bb3
+^bb2:
+ // CHECK: Block: 2
+ // CHECK-NEXT: Number of executions: <unknown>
+ br ^bb4
+^bb3:
+ // CHECK: Block: 3
+ // CHECK-NEXT: Number of executions: <unknown>
+ br ^bb4
+^bb4:
+ // CHECK: Block: 4
+ // CHECK-NEXT: Number of executions: <unknown>
+ return
+}
+
+// -----
+
+// CHECK-LABEL: Number of executions: loop
+func @loop(%cond : i1) {
+ // CHECK: Block: 0
+ // CHECK-NEXT: Number of executions: 1
+ br ^bb1
+^bb1:
+ // CHECK: Block: 1
+ // CHECK-NEXT: Number of executions: <unknown>
+ br ^bb2
+^bb2:
+ // CHECK: Block: 2
+ // CHECK-NEXT: Number of executions: <unknown>
+ br ^bb3
+^bb3:
+ // CHECK: Block: 3
+ // CHECK-NEXT: Number of executions: <unknown>
+ cond_br %cond, ^bb1, ^bb4
+^bb4:
+ // CHECK: Block: 4
+ // CHECK-NEXT: Number of executions: <unknown>
+ return
+}
+
+// -----
+
+// CHECK-LABEL: Number of executions: scf_if_dynamic_branch
+func @scf_if_dynamic_branch(%cond : i1) {
+ // CHECK: Block: 0
+ // CHECK-NEXT: Number of executions: 1
+ scf.if %cond {
+ // CHECK: Block: 1
+ // CHECK-NEXT: Number of executions: <unknown>
+ } else {
+ // CHECK: Block: 2
+ // CHECK-NEXT: Number of executions: <unknown>
+ }
+ return
+}
+
+// -----
+
+// CHECK-LABEL: Number of executions: async_execute
+func @async_execute() {
+ // CHECK: Block: 0
+ // CHECK-NEXT: Number of executions: 1
+ async.execute {
+ // CHECK: Block: 1
+ // CHECK-NEXT: Number of executions: 1
+ async.yield
+ }
+ return
+}
+
+// -----
+
+// CHECK-LABEL: Number of executions: async_execute_with_scf_if
+func @async_execute_with_scf_if(%cond : i1) {
+ // CHECK: Block: 0
+ // CHECK-NEXT: Number of executions: 1
+ async.execute {
+ // CHECK: Block: 1
+ // CHECK-NEXT: Number of executions: 1
+ scf.if %cond {
+ // CHECK: Block: 2
+ // CHECK-NEXT: Number of executions: <unknown>
+ } else {
+ // CHECK: Block: 3
+ // CHECK-NEXT: Number of executions: <unknown>
+ }
+ async.yield
+ }
+ return
+}
+
+// -----
+
+// CHECK-LABEL: Number of executions: scf_for_constant_bounds
+func @scf_for_constant_bounds() {
+ // CHECK: Block: 0
+ // CHECK-NEXT: Number of executions: 1
+ %c0 = constant 0 : index
+ %c1 = constant 1 : index
+ %c2 = constant 2 : index
+
+ scf.for %i = %c0 to %c2 step %c1 {
+ // CHECK: Block: 1
+ // CHECK-NEXT: Number of executions: 2
+ }
+
+ return
+}
+
+// -----
+
+// CHECK-LABEL: Number of executions: propagate_parent_num_executions
+func @propagate_parent_num_executions() {
+ // CHECK: Block: 0
+ // CHECK-NEXT: Number of executions: 1
+ %c0 = constant 0 : index
+ %c1 = constant 1 : index
+ %c2 = constant 2 : index
+
+ scf.for %i = %c0 to %c2 step %c1 {
+ // CHECK: Block: 1
+ // CHECK-NEXT: Number of executions: 2
+ async.execute {
+ // CHECK: Block: 2
+ // CHECK-NEXT: Number of executions: 2
+ async.yield
+ }
+ }
+
+ return
+}
+
+// -----
+
+// CHECK-LABEL: Number of executions: clear_num_executions
+func @clear_num_executions(%step : index) {
+ // CHECK: Block: 0
+ // CHECK-NEXT: Number of executions: 1
+ %c0 = constant 0 : index
+ %c2 = constant 2 : index
+
+ scf.for %i = %c0 to %c2 step %step {
+ // CHECK: Block: 1
+ // CHECK-NEXT: Number of executions: <unknown>
+ async.execute {
+ // CHECK: Block: 2
+ // CHECK-NEXT: Number of executions: <unknown>
+ async.yield
+ }
+ }
+
+ return
+}
diff --git a/mlir/test/Analysis/test-number-of-operation-executions.mlir b/mlir/test/Analysis/test-number-of-operation-executions.mlir
new file mode 100644
index 000000000000..60e9a8a86f42
--- /dev/null
+++ b/mlir/test/Analysis/test-number-of-operation-executions.mlir
@@ -0,0 +1,66 @@
+// RUN: mlir-opt %s \
+// RUN: -test-print-number-of-operation-executions \
+// RUN: -split-input-file 2>&1 \
+// RUN: | FileCheck %s
+
+// CHECK-LABEL: Number of executions: empty
+func @empty() {
+ // CHECK: Operation: std.return
+ // CHECK-NEXT: Number of executions: 1
+ return
+}
+
+// -----
+
+// CHECK-LABEL: Number of executions: propagate_parent_num_executions
+func @propagate_parent_num_executions() {
+ // CHECK: Operation: std.constant
+ // CHECK-NEXT: Number of executions: 1
+ %c0 = constant 0 : index
+ // CHECK: Operation: std.constant
+ // CHECK-NEXT: Number of executions: 1
+ %c1 = constant 1 : index
+ // CHECK: Operation: std.constant
+ // CHECK-NEXT: Number of executions: 1
+ %c2 = constant 2 : index
+
+ // CHECK-DAG: Operation: scf.for
+ // CHECK-NEXT: Number of executions: 1
+ scf.for %i = %c0 to %c2 step %c1 {
+ // CHECK-DAG: Operation: async.execute
+ // CHECK-NEXT: Number of executions: 2
+ async.execute {
+ // CHECK-DAG: Operation: async.yield
+ // CHECK-NEXT: Number of executions: 2
+ async.yield
+ }
+ }
+
+ return
+}
+
+// -----
+
+// CHECK-LABEL: Number of executions: clear_num_executions
+func @clear_num_executions(%step : index) {
+ // CHECK: Operation: std.constant
+ // CHECK-NEXT: Number of executions: 1
+ %c0 = constant 0 : index
+ // CHECK: Operation: std.constant
+ // CHECK-NEXT: Number of executions: 1
+ %c2 = constant 2 : index
+
+ // CHECK: Operation: scf.for
+ // CHECK-NEXT: Number of executions: 1
+ scf.for %i = %c0 to %c2 step %step {
+ // CHECK: Operation: async.execute
+ // CHECK-NEXT: Number of executions: <unknown>
+ async.execute {
+ // CHECK: Operation: async.yield
+ // CHECK-NEXT: Number of executions: <unknown>
+ async.yield
+ }
+ }
+
+ return
+}
diff --git a/mlir/test/lib/Transforms/CMakeLists.txt b/mlir/test/lib/Transforms/CMakeLists.txt
index 249addc61b9b..69d45b570a3c 100644
--- a/mlir/test/lib/Transforms/CMakeLists.txt
+++ b/mlir/test/lib/Transforms/CMakeLists.txt
@@ -24,6 +24,7 @@ add_mlir_library(MLIRTestTransforms
TestLoopMapping.cpp
TestLoopParametricTiling.cpp
TestLoopUnrolling.cpp
+ TestNumberOfExecutions.cpp
TestOpaqueLoc.cpp
TestMemRefBoundCheck.cpp
TestMemRefDependenceCheck.cpp
diff --git a/mlir/test/lib/Transforms/TestNumberOfExecutions.cpp b/mlir/test/lib/Transforms/TestNumberOfExecutions.cpp
new file mode 100644
index 000000000000..908e596ac158
--- /dev/null
+++ b/mlir/test/lib/Transforms/TestNumberOfExecutions.cpp
@@ -0,0 +1,57 @@
+//===- TestNumberOfExecutions.cpp - Test number of executions analysis ----===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file contains test passes that construct the number of executions
+// analysis and print its results for blocks and operations.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Analysis/NumberOfExecutions.h"
+#include "mlir/Pass/Pass.h"
+
+using namespace mlir;
+
+namespace {
+
+struct TestNumberOfBlockExecutionsPass
+ : public PassWrapper<TestNumberOfBlockExecutionsPass, FunctionPass> {
+ void runOnFunction() override {
+ llvm::errs() << "Number of executions: " << getFunction().getName() << "\n";
+ getAnalysis<NumberOfExecutions>().printBlockExecutions(
+ llvm::errs(), &getFunction().getBody());
+ }
+};
+
+struct TestNumberOfOperationExecutionsPass
+ : public PassWrapper<TestNumberOfOperationExecutionsPass, FunctionPass> {
+ void runOnFunction() override {
+ llvm::errs() << "Number of executions: " << getFunction().getName() << "\n";
+ getAnalysis<NumberOfExecutions>().printOperationExecutions(
+ llvm::errs(), &getFunction().getBody());
+ }
+};
+
+} // end anonymous namespace
+
+namespace mlir {
+namespace test {
+void registerTestNumberOfBlockExecutionsPass() {
+ PassRegistration<TestNumberOfBlockExecutionsPass>(
+ "test-print-number-of-block-executions",
+ "Print the contents of a constructed number of executions analysis for "
+ "all blocks.");
+}
+
+void registerTestNumberOfOperationExecutionsPass() {
+ PassRegistration<TestNumberOfOperationExecutionsPass>(
+ "test-print-number-of-operation-executions",
+ "Print the contents of a constructed number of executions analysis for "
+ "all operations.");
+}
+} // namespace test
+} // namespace mlir
diff --git a/mlir/tools/mlir-opt/mlir-opt.cpp b/mlir/tools/mlir-opt/mlir-opt.cpp
index 8f88ae35409d..58444e6a9501 100644
--- a/mlir/tools/mlir-opt/mlir-opt.cpp
+++ b/mlir/tools/mlir-opt/mlir-opt.cpp
@@ -81,6 +81,8 @@ void registerTestLoopMappingPass();
void registerTestLoopUnrollingPass();
void registerTestMemRefDependenceCheck();
void registerTestMemRefStrideCalculation();
+void registerTestNumberOfBlockExecutionsPass();
+void registerTestNumberOfOperationExecutionsPass();
void registerTestOpaqueLoc();
void registerTestPreparationPassWithAllowedMemrefResults();
void registerTestRecursiveTypesPass();
@@ -145,6 +147,8 @@ void registerTestPasses() {
test::registerTestLoopUnrollingPass();
test::registerTestMemRefDependenceCheck();
test::registerTestMemRefStrideCalculation();
+ test::registerTestNumberOfBlockExecutionsPass();
+ test::registerTestNumberOfOperationExecutionsPass();
test::registerTestOpaqueLoc();
test::registerTestRecursiveTypesPass();
test::registerTestSCFUtilsPass();
diff --git a/mlir/unittests/Support/CMakeLists.txt b/mlir/unittests/Support/CMakeLists.txt
index cec3c2d57386..2bc673336ba5 100644
--- a/mlir/unittests/Support/CMakeLists.txt
+++ b/mlir/unittests/Support/CMakeLists.txt
@@ -1,5 +1,6 @@
add_mlir_unittest(MLIRSupportTests
IndentedOstreamTest.cpp
+ MathExtrasTest.cpp
)
target_link_libraries(MLIRSupportTests
diff --git a/mlir/unittests/Support/MathExtrasTest.cpp b/mlir/unittests/Support/MathExtrasTest.cpp
new file mode 100644
index 000000000000..304b353d4167
--- /dev/null
+++ b/mlir/unittests/Support/MathExtrasTest.cpp
@@ -0,0 +1,27 @@
+//===- MathExtrasTest.cpp - MathExtras Tests ------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Support/MathExtras.h"
+#include "gmock/gmock.h"
+
+using namespace mlir;
+using ::testing::Eq;
+
+TEST(MathExtrasTest, CeilDivTest) {
+ EXPECT_THAT(ceilDiv(14, 3), Eq(5));
+ EXPECT_THAT(ceilDiv(14, -3), Eq(-4));
+ EXPECT_THAT(ceilDiv(-14, -3), Eq(5));
+ EXPECT_THAT(ceilDiv(-14, 3), Eq(-4));
+}
+
+TEST(MathExtrasTest, FloorDivTest) {
+ EXPECT_THAT(floorDiv(14, 3), Eq(4));
+ EXPECT_THAT(floorDiv(14, -3), Eq(-5));
+ EXPECT_THAT(floorDiv(-14, -3), Eq(4));
+ EXPECT_THAT(floorDiv(-14, 3), Eq(-5));
+}
More information about the Mlir-commits
mailing list