[Mlir-commits] [mlir] bb0d5f7 - [mlir] Add NumberOfExecutions analysis + update RegionBranchOpInterface interface to query number of region invocations

Eugene Zhulenev llvmlistbot at llvm.org
Wed Nov 11 01:43:24 PST 2020


Author: Eugene Zhulenev
Date: 2020-11-11T01:43:17-08:00
New Revision: bb0d5f767dd7cf34a92ba2af2d6fdb206d883e8c

URL: https://github.com/llvm/llvm-project/commit/bb0d5f767dd7cf34a92ba2af2d6fdb206d883e8c
DIFF: https://github.com/llvm/llvm-project/commit/bb0d5f767dd7cf34a92ba2af2d6fdb206d883e8c.diff

LOG: [mlir] Add NumberOfExecutions analysis + update RegionBranchOpInterface interface to query number of region invocations

Implements RFC discussed in: https://llvm.discourse.group/t/rfc-operationinstancesinterface-or-any-better-name/2158/10

Reviewed By: silvas, ftynse, rriddle

Differential Revision: https://reviews.llvm.org/D90922

Added: 
    mlir/include/mlir/Analysis/NumberOfExecutions.h
    mlir/lib/Analysis/NumberOfExecutions.cpp
    mlir/test/Analysis/test-number-of-block-executions.mlir
    mlir/test/Analysis/test-number-of-operation-executions.mlir
    mlir/test/lib/Transforms/TestNumberOfExecutions.cpp
    mlir/unittests/Support/MathExtrasTest.cpp

Modified: 
    mlir/include/mlir/Dialect/Async/IR/Async.h
    mlir/include/mlir/Dialect/Async/IR/AsyncOps.td
    mlir/include/mlir/Dialect/SCF/SCFOps.td
    mlir/include/mlir/Interfaces/ControlFlowInterfaces.h
    mlir/include/mlir/Interfaces/ControlFlowInterfaces.td
    mlir/include/mlir/Support/MathExtras.h
    mlir/lib/Analysis/CMakeLists.txt
    mlir/lib/Dialect/Async/IR/Async.cpp
    mlir/lib/Dialect/SCF/SCF.cpp
    mlir/lib/Interfaces/ControlFlowInterfaces.cpp
    mlir/test/lib/Transforms/CMakeLists.txt
    mlir/tools/mlir-opt/mlir-opt.cpp
    mlir/unittests/Support/CMakeLists.txt

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Analysis/NumberOfExecutions.h b/mlir/include/mlir/Analysis/NumberOfExecutions.h
new file mode 100644
index 000000000000..aa3080431cf6
--- /dev/null
+++ b/mlir/include/mlir/Analysis/NumberOfExecutions.h
@@ -0,0 +1,107 @@
+//===- NumberOfExecutions.h - Number of executions analysis -----*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file contains an analysis for computing how many times a block within a
+// region is executed *each time* that region is entered. The analysis
+// iterates over all associated regions that are attached to the given top-level
+// operation.
+//
+// Number of executions can be queried for individual blocks and operations.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_ANALYSIS_NUMBER_OF_EXECUTIONS_H
+#define MLIR_ANALYSIS_NUMBER_OF_EXECUTIONS_H
+
+#include "mlir/Support/LLVM.h"
+#include "llvm/ADT/DenseMap.h"
+#include "llvm/ADT/Optional.h"
+
+namespace mlir {
+
+class Block;
+class BlockNumberOfExecutionsInfo;
+class Operation;
+class Region;
+
+/// Analysis computing how many times a block or an operation within a region
+/// is executed *each time* that region is entered. All regions attached to
+/// the top-level operation given at construction are analyzed.
+///
+/// The analysis assumes every operation completes in a finite amount of time,
+/// i.e. it neither aborts nor enters an infinite loop.
+class NumberOfExecutions {
+public:
+  /// Runs the analysis over all regions associated with `op`.
+  explicit NumberOfExecutions(Operation *op);
+
+  /// Number of times `op` executes *each time* control flow enters
+  /// `perEntryOfThisRegion`, or an empty optional when this is not known
+  /// statically.
+  Optional<int64_t>
+  getNumberOfExecutions(Operation *op, Region *perEntryOfThisRegion) const;
+
+  /// Number of times `block` executes *each time* control flow enters
+  /// `perEntryOfThisRegion`, or an empty optional when this is not known
+  /// statically.
+  Optional<int64_t>
+  getNumberOfExecutions(Block *block, Region *perEntryOfThisRegion) const;
+
+  /// Writes the number of executions of every block, per entry into
+  /// `perEntryOfThisRegion`, to `os`.
+  void printBlockExecutions(raw_ostream &os,
+                            Region *perEntryOfThisRegion) const;
+
+  /// Writes the number of executions of every operation, per entry into
+  /// `perEntryOfThisRegion`, to `os`.
+  void printOperationExecutions(raw_ostream &os,
+                                Region *perEntryOfThisRegion) const;
+
+private:
+  /// Root operation the analysis was constructed from.
+  Operation *operation;
+
+  /// Execution information of every analyzed block, keyed by the block.
+  DenseMap<Block *, BlockNumberOfExecutionsInfo> blockNumbersOfExecution;
+};
+
+/// Represents number of block executions information.
+class BlockNumberOfExecutionsInfo {
+public:
+  /// Records the counts known for `block`; either count may be `None` when it
+  /// is not known statically.
+  BlockNumberOfExecutionsInfo(Block *block,
+                              Optional<int64_t> numberOfRegionInvocations,
+                              Optional<int64_t> numberOfBlockExecutions);
+
+  /// Returns the number of times this block will be executed *each time* the
+  /// parent operation is executed. Returns an empty optional if either the
+  /// number of parent region invocations or the number of block executions
+  /// per region invocation is unknown.
+  Optional<int64_t> getNumberOfExecutions() const;
+
+  /// Returns the number of times this block will be executed if the parent
+  /// region is invoked `numberOfRegionInvocations` times. This can be
+  /// different from the number of region invocations by the parent operation.
+  /// Returns an empty optional if the number of block executions per region
+  /// invocation is unknown.
+  Optional<int64_t>
+  getNumberOfExecutions(int64_t numberOfRegionInvocations) const;
+
+  /// Returns the block this information describes.
+  Block *getBlock() const { return block; }
+
+private:
+  Block *block;
+
+  /// Number of `block` parent region invocations *each time* parent operation
+  /// is executed.
+  Optional<int64_t> numberOfRegionInvocations;
+
+  /// Number of `block` executions *each time* parent region is invoked.
+  Optional<int64_t> numberOfBlockExecutions;
+};
+
+} // end namespace mlir
+
+#endif // MLIR_ANALYSIS_NUMBER_OF_EXECUTIONS_H

diff  --git a/mlir/include/mlir/Dialect/Async/IR/Async.h b/mlir/include/mlir/Dialect/Async/IR/Async.h
index 1519ccd1bdfc..28790557f36b 100644
--- a/mlir/include/mlir/Dialect/Async/IR/Async.h
+++ b/mlir/include/mlir/Dialect/Async/IR/Async.h
@@ -19,6 +19,7 @@
 #include "mlir/IR/OpDefinition.h"
 #include "mlir/IR/OpImplementation.h"
 #include "mlir/IR/StandardTypes.h"
+#include "mlir/Interfaces/ControlFlowInterfaces.h"
 #include "mlir/Interfaces/SideEffectInterfaces.h"
 
 namespace mlir {

diff  --git a/mlir/include/mlir/Dialect/Async/IR/AsyncOps.td b/mlir/include/mlir/Dialect/Async/IR/AsyncOps.td
index a1e489007f0f..7fad5ce48214 100644
--- a/mlir/include/mlir/Dialect/Async/IR/AsyncOps.td
+++ b/mlir/include/mlir/Dialect/Async/IR/AsyncOps.td
@@ -14,6 +14,7 @@
 #define ASYNC_OPS
 
 include "mlir/Dialect/Async/IR/AsyncBase.td"
+include "mlir/Interfaces/ControlFlowInterfaces.td"
 include "mlir/Interfaces/SideEffectInterfaces.td"
 
 //===----------------------------------------------------------------------===//
@@ -26,6 +27,8 @@ class Async_Op<string mnemonic, list<OpTrait> traits = []> :
 
 def Async_ExecuteOp :
   Async_Op<"execute", [SingleBlockImplicitTerminator<"YieldOp">,
+                       DeclareOpInterfaceMethods<RegionBranchOpInterface,
+                                                 ["getNumRegionInvocations"]>,
                        AttrSizedOperandSegments]> {
   let summary = "Asynchronous execute operation";
   let description = [{

diff  --git a/mlir/include/mlir/Dialect/SCF/SCFOps.td b/mlir/include/mlir/Dialect/SCF/SCFOps.td
index 1dc6ef1f68a4..f30e8f20c4b7 100644
--- a/mlir/include/mlir/Dialect/SCF/SCFOps.td
+++ b/mlir/include/mlir/Dialect/SCF/SCFOps.td
@@ -196,6 +196,11 @@ def ForOp : SCF_Op<"for",
     /// induction variable. LoopOp only has one region, so 0 is the only valid
     /// value for `index`.
     OperandRange getSuccessorEntryOperands(unsigned index);
+
+    /// Populates `countPerRegion` with the number of invocations of the body
+    /// region, computed from the loop bounds and step when they are constants.
+    /// Reports `kUnknownNumRegionInvocations` when the count cannot be
+    /// determined statically.
+    void getNumRegionInvocations(ArrayRef<Attribute> operands,
+                                 SmallVectorImpl<int64_t> &countPerRegion);
   }];
 
   let hasCanonicalizer = 1;

diff  --git a/mlir/include/mlir/Interfaces/ControlFlowInterfaces.h b/mlir/include/mlir/Interfaces/ControlFlowInterfaces.h
index 725e13b8b9d2..57ff63fbf784 100644
--- a/mlir/include/mlir/Interfaces/ControlFlowInterfaces.h
+++ b/mlir/include/mlir/Interfaces/ControlFlowInterfaces.h
@@ -41,6 +41,9 @@ LogicalResult verifyBranchSuccessorOperands(Operation *op, unsigned succNo,
 // RegionBranchOpInterface
 //===----------------------------------------------------------------------===//
 
+/// Sentinel stored by `RegionBranchOpInterface::getNumRegionInvocations` for a
+/// region whose number of invocations is not known statically.
+extern const int64_t kUnknownNumRegionInvocations;
+
 namespace detail {
 /// Verify that types match along control flow edges described the given op.
 LogicalResult verifyTypesAlongControlFlowEdges(Operation *op);

diff  --git a/mlir/include/mlir/Interfaces/ControlFlowInterfaces.td b/mlir/include/mlir/Interfaces/ControlFlowInterfaces.td
index 217ab263877f..3c568a0e776f 100644
--- a/mlir/include/mlir/Interfaces/ControlFlowInterfaces.td
+++ b/mlir/include/mlir/Interfaces/ControlFlowInterfaces.td
@@ -130,6 +130,26 @@ def RegionBranchOpInterface : OpInterface<"RegionBranchOpInterface"> {
       "void", "getSuccessorRegions",
       (ins "Optional<unsigned>":$index, "ArrayRef<Attribute>":$operands,
            "SmallVectorImpl<RegionSuccessor> &":$regions)
+    >,
+    InterfaceMethod<[{
+        Populates `countPerRegion` with the number of times this operation will
+        invoke each of the attached regions (assuming the regions yield
+        normally, i.e. do not abort or invoke an infinite loop). If the number
+        of invocations of a region is not known statically, its entry is set to
+        `kUnknownNumRegionInvocations`. `countPerRegion` is expected to be
+        empty on entry.
+
+        `operands` is a set of optional attributes that either correspond to a
+        constant value for each operand of this operation, or null if that
+        operand is not a constant.
+      }],
+      "void", "getNumRegionInvocations",
+      (ins "ArrayRef<Attribute>":$operands,
+           "SmallVectorImpl<int64_t> &":$countPerRegion), [{}],
+      /*defaultImplementation=*/[{
+        unsigned numRegions = this->getOperation()->getNumRegions();
+        assert(countPerRegion.empty());
+        countPerRegion.resize(numRegions, kUnknownNumRegionInvocations);
+      }]
     >
   ];
 

diff  --git a/mlir/include/mlir/Support/MathExtras.h b/mlir/include/mlir/Support/MathExtras.h
index 98bdacedf657..65429e5f0265 100644
--- a/mlir/include/mlir/Support/MathExtras.h
+++ b/mlir/include/mlir/Support/MathExtras.h
@@ -19,19 +19,21 @@
 namespace mlir {
 
 /// Returns the result of MLIR's ceildiv operation on constants. The RHS is
-/// expected to be positive.
+/// expected to be non-zero. The quotient is rounded towards positive infinity
+/// without intermediate overflow; only `INT64_MIN / -1`, whose result is not
+/// representable, remains undefined.
 inline int64_t ceilDiv(int64_t lhs, int64_t rhs) {
-  assert(rhs >= 1);
+  assert(rhs != 0);
   // C/C++'s integer division rounds towards 0.
-  return lhs % rhs > 0 ? lhs / rhs + 1 : lhs / rhs;
+  int64_t quotient = lhs / rhs;
+  int64_t remainder = lhs % rhs;
+  // The remainder takes the sign of `lhs`, so a non-zero remainder with the
+  // same sign as `rhs` means the exact quotient is positive and truncation
+  // rounded it down. Unlike `lhs * rhs` or `-lhs`, this cannot overflow.
+  bool roundUp = remainder != 0 && ((remainder > 0) == (rhs > 0));
+  return roundUp ? quotient + 1 : quotient;
 }
 
 /// Returns the result of MLIR's floordiv operation on constants. The RHS is
-/// expected to be positive.
+/// expected to be non-zero. The quotient is rounded towards negative infinity
+/// without intermediate overflow; only `INT64_MIN / -1`, whose result is not
+/// representable, remains undefined.
 inline int64_t floorDiv(int64_t lhs, int64_t rhs) {
-  assert(rhs >= 1);
+  assert(rhs != 0);
   // C/C++'s integer division rounds towards 0.
-  return lhs % rhs < 0 ? lhs / rhs - 1 : lhs / rhs;
+  int64_t quotient = lhs / rhs;
+  int64_t remainder = lhs % rhs;
+  // The remainder takes the sign of `lhs`, so a non-zero remainder whose sign
+  // differs from `rhs` means the exact quotient is negative and truncation
+  // rounded it up. Unlike `lhs * rhs` or `-lhs`, this cannot overflow.
+  bool roundDown = remainder != 0 && ((remainder < 0) != (rhs < 0));
+  return roundDown ? quotient - 1 : quotient;
 }
 
 /// Returns MLIR's mod operation on constants. MLIR's mod operation yields the

diff  --git a/mlir/lib/Analysis/CMakeLists.txt b/mlir/lib/Analysis/CMakeLists.txt
index 632310bc5f13..3247ef1f56b0 100644
--- a/mlir/lib/Analysis/CMakeLists.txt
+++ b/mlir/lib/Analysis/CMakeLists.txt
@@ -6,6 +6,7 @@ set(LLVM_OPTIONAL_SOURCES
   Liveness.cpp
   LoopAnalysis.cpp
   NestedMatcher.cpp
+  NumberOfExecutions.cpp
   PresburgerSet.cpp
   SliceAnalysis.cpp
   Utils.cpp
@@ -15,6 +16,7 @@ add_mlir_library(MLIRAnalysis
   BufferAliasAnalysis.cpp
   CallGraph.cpp
   Liveness.cpp
+  NumberOfExecutions.cpp
   SliceAnalysis.cpp
 
   ADDITIONAL_HEADER_DIRS
@@ -53,5 +55,5 @@ add_mlir_library(MLIRLoopAnalysis
   MLIRPresburger
   MLIRSCF
   )
-  
+
 add_subdirectory(Presburger)

diff  --git a/mlir/lib/Analysis/NumberOfExecutions.cpp b/mlir/lib/Analysis/NumberOfExecutions.cpp
new file mode 100644
index 000000000000..425936b5eaaf
--- /dev/null
+++ b/mlir/lib/Analysis/NumberOfExecutions.cpp
@@ -0,0 +1,243 @@
+//===- NumberOfExecutions.cpp - Number of executions analysis -------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// Implementation of the number of executions analysis.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Analysis/NumberOfExecutions.h"
+#include "mlir/Dialect/StandardOps/IR/Ops.h"
+#include "mlir/IR/Matchers.h"
+#include "mlir/IR/RegionKindInterface.h"
+#include "mlir/Interfaces/ControlFlowInterfaces.h"
+
+#include "llvm/ADT/FunctionExtras.h"
+#include "llvm/ADT/SmallSet.h"
+#include "llvm/Support/raw_ostream.h"
+
+#define DEBUG_TYPE "number-of-executions-analysis"
+
+using namespace mlir;
+
+//===----------------------------------------------------------------------===//
+// NumberOfExecutions
+//===----------------------------------------------------------------------===//
+
+/// Computes blocks number of executions information for the given region.
+///
+/// Every block of `region` gets an entry in `blockInfo`. Blocks on the chain
+/// of single-successor blocks that starts at the entry block (stopping before
+/// the first block that starts a CFG loop) execute exactly once per region
+/// entry; all other blocks, and all blocks of graph regions, get an unknown
+/// number of block executions.
+static void computeRegionBlockNumberOfExecutions(
+    Region &region, DenseMap<Block *, BlockNumberOfExecutionsInfo> &blockInfo) {
+  // Empty regions (e.g. the body of an external function) have no blocks to
+  // annotate, and `region.front()` below must not be called on them.
+  if (region.empty())
+    return;
+
+  Operation *parentOp = region.getParentOp();
+  unsigned regionId = region.getRegionNumber();
+
+  auto regionKindInterface = dyn_cast<RegionKindInterface>(parentOp);
+  bool isGraphRegion =
+      regionKindInterface &&
+      regionKindInterface.getRegionKind(regionId) == RegionKind::Graph;
+
+  // CFG analysis does not make sense for Graph regions, set the number of
+  // executions for all blocks as unknown.
+  if (isGraphRegion) {
+    for (Block &block : region)
+      blockInfo.insert({&block, {&block, None, None}});
+    return;
+  }
+
+  // Number of region invocations for all attached regions.
+  SmallVector<int64_t, 4> numRegionsInvocations;
+
+  // Query RegionBranchOpInterface interface if it is available. Operands
+  // defined by constants are passed as attributes, all others as null.
+  if (auto regionInterface = dyn_cast<RegionBranchOpInterface>(parentOp)) {
+    SmallVector<Attribute, 4> operands(parentOp->getNumOperands());
+    for (auto operandIt : llvm::enumerate(parentOp->getOperands()))
+      matchPattern(operandIt.value(), m_Constant(&operands[operandIt.index()]));
+
+    regionInterface.getNumRegionInvocations(operands, numRegionsInvocations);
+  }
+
+  // Number of region invocations *each time* parent operation is invoked.
+  // The bounds check guards against interface implementations that report
+  // fewer entries than the operation has regions.
+  Optional<int64_t> numRegionInvocations;
+  if (regionId < numRegionsInvocations.size() &&
+      numRegionsInvocations[regionId] != kUnknownNumRegionInvocations)
+    numRegionInvocations = numRegionsInvocations[regionId];
+
+  // Iterative DFS looking for loops in the CFG. A successor that is still on
+  // the current DFS path closes a cycle, so it is recorded as a loop start.
+  // Fully explored blocks are never revisited, which keeps the traversal
+  // linear in the size of the CFG (enumerating every path, e.g. through a
+  // chain of diamonds, is exponential) and avoids deep native recursion.
+  llvm::SmallSet<Block *, 4> loopStart;
+  llvm::SmallSet<Block *, 4> onPath;
+  llvm::SmallSet<Block *, 4> explored;
+  SmallVector<std::pair<Block *, unsigned>, 8> stack;
+
+  stack.push_back({&region.front(), 0});
+  onPath.insert(&region.front());
+  while (!stack.empty()) {
+    Block *current = stack.back().first;
+    unsigned successorIndex = stack.back().second;
+
+    // All successors were visited: leave the current DFS path.
+    if (successorIndex == current->getNumSuccessors()) {
+      onPath.erase(current);
+      explored.insert(current);
+      stack.pop_back();
+      continue;
+    }
+
+    ++stack.back().second;
+    Block *successor = current->getSuccessor(successorIndex);
+    if (onPath.contains(successor)) {
+      // Found a loop in the CFG that starts at `successor`.
+      loopStart.insert(successor);
+    } else if (!explored.contains(successor)) {
+      onPath.insert(successor);
+      stack.push_back({successor, 0});
+    }
+  }
+
+  // Start from the entry block and follow only blocks with single successor.
+  Block *block = &region.front();
+  while (block && !loopStart.contains(block)) {
+    // Block will be executed exactly once.
+    blockInfo.insert(
+        {block, BlockNumberOfExecutionsInfo(block, numRegionInvocations,
+                                            /*numberOfBlockExecutions=*/1)});
+
+    // We reached the exit block or block with multiple successors.
+    if (block->getNumSuccessors() != 1)
+      break;
+
+    // Continue traversal.
+    block = block->getSuccessor(0);
+  }
+
+  // For all blocks that we did not visit set the executions number to unknown.
+  // `insert` keeps the entries already recorded by the traversal above.
+  for (Block &block : region)
+    blockInfo.insert({&block, BlockNumberOfExecutionsInfo(
+                                  &block, numRegionInvocations,
+                                  /*numberOfBlockExecutions=*/None)});
+}
+
+/// Builds the analysis: the blocks of every region associated with `op` are
+/// annotated with their number of executions.
+NumberOfExecutions::NumberOfExecutions(Operation *op) : operation(op) {
+  operation->walk([this](Region *region) {
+    computeRegionBlockNumberOfExecutions(*region, blockNumbersOfExecution);
+  });
+}
+
+Optional<int64_t>
+NumberOfExecutions::getNumberOfExecutions(Operation *op,
+                                          Region *perEntryOfThisRegion) const {
+  // Assuming that all operations complete in a finite amount of time (do not
+  // abort and do not go into the infinite loop), the number of operation
+  // executions is equal to the number of executions of the block that
+  // contains the operation.
+  Block *parentBlock = op->getBlock();
+
+  // An operation without an enclosing block (e.g. a detached operation or the
+  // root of the IR) cannot lie inside `perEntryOfThisRegion`.
+  if (!parentBlock)
+    return None;
+
+  return getNumberOfExecutions(parentBlock, perEntryOfThisRegion);
+}
+
+Optional<int64_t>
+NumberOfExecutions::getNumberOfExecutions(Block *block,
+                                          Region *perEntryOfThisRegion) const {
+  // Blocks that do not lie inside `perEntryOfThisRegion` are reported as
+  // unknown.
+  if (!perEntryOfThisRegion->findAncestorBlockInRegion(*block))
+    return None;
+
+  // Execution count of a single analyzed block `b`, ignoring its ancestors.
+  // Blocks owned directly by `perEntryOfThisRegion` are counted for exactly
+  // one invocation of that region; blocks of nested regions use the number of
+  // region invocations recorded by the analysis.
+  auto lookupExecutions = [&](Block *b) -> Optional<int64_t> {
+    auto it = blockNumbersOfExecution.find(b);
+    if (it == blockNumbersOfExecution.end())
+      return None;
+    const BlockNumberOfExecutionsInfo &info = it->getSecond();
+    if (info.getBlock()->getParent() == perEntryOfThisRegion)
+      return info.getNumberOfExecutions(/*numberOfRegionInvocations=*/1);
+    return info.getNumberOfExecutions();
+  };
+
+  Optional<int64_t> blockExecutions = lookupExecutions(block);
+  if (!blockExecutions)
+    return None;
+
+  // Walk up the parent operations until reaching the owner of
+  // `perEntryOfThisRegion`; the result is the product of the execution counts
+  // of all enclosing blocks.
+  int64_t numberOfExecutions = *blockExecutions;
+  Operation *regionOwner = perEntryOfThisRegion->getParentOp();
+  for (Operation *parentOp = block->getParentOp(); parentOp != regionOwner;) {
+    Optional<int64_t> parentExecutions =
+        lookupExecutions(parentOp->getBlock());
+    // We stumbled upon an operation with unknown number of executions.
+    if (!parentExecutions)
+      return None;
+
+    numberOfExecutions *= *parentExecutions;
+    parentOp = parentOp->getParentOp();
+    assert(parentOp != nullptr);
+  }
+
+  return numberOfExecutions;
+}
+
+void NumberOfExecutions::printBlockExecutions(
+    raw_ostream &os, Region *perEntryOfThisRegion) const {
+  // Blocks are numbered in the order in which `walk` visits them.
+  unsigned blockId = 0;
+
+  // Write to the caller-provided stream rather than a hard-coded one.
+  operation->walk([&](Block *block) {
+    os << "Block: " << blockId++ << "\n";
+    os << "Number of executions: ";
+    if (auto n = getNumberOfExecutions(block, perEntryOfThisRegion))
+      os << *n << "\n";
+    else
+      os << "<unknown>\n";
+  });
+}
+
+void NumberOfExecutions::printOperationExecutions(
+    raw_ostream &os, Region *perEntryOfThisRegion) const {
+  // NOTE(review): if `Block::walk` recurses into nested regions (as upstream
+  // MLIR's does), an operation inside a nested region is printed once per
+  // enclosing block. The operation-executions test appears to rely on this
+  // repeated output; confirm before deduplicating.
+  operation->walk([&](Block *block) {
+    block->walk([&](Operation *op) {
+      // Skip the operation that was used to build the analysis.
+      if (op == this->operation)
+        return;
+
+      os << "Operation: " << op->getName() << "\n";
+      os << "Number of executions: ";
+      if (auto n = getNumberOfExecutions(op, perEntryOfThisRegion))
+        os << *n << "\n";
+      else
+        os << "<unknown>\n";
+    });
+  });
+}
+
+//===----------------------------------------------------------------------===//
+// BlockNumberOfExecutionsInfo
+//===----------------------------------------------------------------------===//
+
+/// Stores `block` together with its (possibly unknown) counts.
+BlockNumberOfExecutionsInfo::BlockNumberOfExecutionsInfo(
+    Block *block, Optional<int64_t> numberOfRegionInvocations,
+    Optional<int64_t> numberOfBlockExecutions)
+    : block{block},
+      numberOfRegionInvocations{numberOfRegionInvocations},
+      numberOfBlockExecutions{numberOfBlockExecutions} {}
+
+/// Executions per parent-operation execution: region invocations times block
+/// executions per invocation, or None if either factor is unknown.
+Optional<int64_t> BlockNumberOfExecutionsInfo::getNumberOfExecutions() const {
+  if (!numberOfRegionInvocations || !numberOfBlockExecutions)
+    return None;
+  return *numberOfRegionInvocations * *numberOfBlockExecutions;
+}
+
+/// Executions for a caller-supplied number of region invocations, or None if
+/// the number of block executions per invocation is unknown.
+Optional<int64_t> BlockNumberOfExecutionsInfo::getNumberOfExecutions(
+    int64_t numberOfRegionInvocations) const {
+  if (!numberOfBlockExecutions)
+    return None;
+  return numberOfRegionInvocations * *numberOfBlockExecutions;
+}

diff  --git a/mlir/lib/Dialect/Async/IR/Async.cpp b/mlir/lib/Dialect/Async/IR/Async.cpp
index 36ef1af8cfa7..0c22ea22a9da 100644
--- a/mlir/lib/Dialect/Async/IR/Async.cpp
+++ b/mlir/lib/Dialect/Async/IR/Async.cpp
@@ -118,6 +118,27 @@ static LogicalResult verify(YieldOp op) {
 
 constexpr char kOperandSegmentSizesAttr[] = "operand_segment_sizes";
 
+/// `async.execute` has a single `body` region, which it invokes exactly once
+/// regardless of its operands.
+void ExecuteOp::getNumRegionInvocations(
+    ArrayRef<Attribute> /*operands*/,
+    SmallVectorImpl<int64_t> &countPerRegion) {
+  assert(countPerRegion.empty());
+  countPerRegion.push_back(/*body=*/1);
+}
+
+/// Control flows from the parent operation into the `body` region, and from
+/// the `body` region back to the parent operation, whose results are the
+/// successor inputs.
+void ExecuteOp::getSuccessorRegions(Optional<unsigned> index,
+                                    ArrayRef<Attribute> operands,
+                                    SmallVectorImpl<RegionSuccessor> &regions) {
+  // Branching from the parent operation: the only successor is the body.
+  if (!index.hasValue()) {
+    regions.push_back(RegionSuccessor(&body()));
+    return;
+  }
+
+  // Branching from the `body` region (the only region): back to the parent.
+  assert(*index == 0);
+  regions.push_back(RegionSuccessor(getResults()));
+}
+
 static void print(OpAsmPrinter &p, ExecuteOp op) {
   p << op.getOperationName();
 

diff  --git a/mlir/lib/Dialect/SCF/SCF.cpp b/mlir/lib/Dialect/SCF/SCF.cpp
index ea607a3a402a..bc8671b9ba85 100644
--- a/mlir/lib/Dialect/SCF/SCF.cpp
+++ b/mlir/lib/Dialect/SCF/SCF.cpp
@@ -10,6 +10,7 @@
 #include "mlir/Dialect/StandardOps/IR/Ops.h"
 #include "mlir/IR/BlockAndValueMapping.h"
 #include "mlir/IR/PatternMatch.h"
+#include "mlir/Support/MathExtras.h"
 #include "mlir/Transforms/InliningUtils.h"
 
 using namespace mlir;
@@ -284,6 +285,26 @@ void ForOp::getSuccessorRegions(Optional<unsigned> index,
   regions.push_back(RegionSuccessor(getResults()));
 }
 
+/// Populates `countPerRegion` with the number of times the body region is
+/// invoked each time the loop executes, computed from constant lower bound,
+/// upper bound and step operands. A loop whose upper bound does not exceed its
+/// lower bound is invoked zero times. If the count cannot be determined
+/// statically (non-constant operands or a non-positive step), the count is
+/// `kUnknownNumRegionInvocations`.
+void ForOp::getNumRegionInvocations(ArrayRef<Attribute> operands,
+                                    SmallVectorImpl<int64_t> &countPerRegion) {
+  assert(countPerRegion.empty());
+  countPerRegion.resize(1, kUnknownNumRegionInvocations);
+
+  auto lb = operands[0].dyn_cast_or_null<IntegerAttr>();
+  auto ub = operands[1].dyn_cast_or_null<IntegerAttr>();
+  auto step = operands[2].dyn_cast_or_null<IntegerAttr>();
+
+  // Loop bounds are not known statically.
+  if (!lb || !ub || !step)
+    return;
+
+  // A non-positive step does not give a computable, finite trip count.
+  int64_t stepValue = step.getValue().getSExtValue();
+  if (stepValue <= 0)
+    return;
+
+  int64_t lbValue = lb.getValue().getSExtValue();
+  int64_t ubValue = ub.getValue().getSExtValue();
+
+  // The body is not entered at all when `ub <= lb`. Clamp to zero instead of
+  // reporting a negative count, which could even be mistaken for the
+  // `kUnknownNumRegionInvocations` sentinel (-1).
+  countPerRegion[0] =
+      ubValue <= lbValue ? 0 : ceilDiv(ubValue - lbValue, stepValue);
+}
+
 ValueVector mlir::scf::buildLoopNest(
     OpBuilder &builder, Location loc, ValueRange lbs, ValueRange ubs,
     ValueRange steps, ValueRange iterArgs,

diff  --git a/mlir/lib/Interfaces/ControlFlowInterfaces.cpp b/mlir/lib/Interfaces/ControlFlowInterfaces.cpp
index 99e86bcb5057..5b845584b2e8 100644
--- a/mlir/lib/Interfaces/ControlFlowInterfaces.cpp
+++ b/mlir/lib/Interfaces/ControlFlowInterfaces.cpp
@@ -73,6 +73,9 @@ detail::verifyBranchSuccessorOperands(Operation *op, unsigned succNo,
 // RegionBranchOpInterface
 //===----------------------------------------------------------------------===//
 
+/// Sentinel written into `countPerRegion` by `getNumRegionInvocations` when the
+/// number of region invocations is not known statically.
+const int64_t mlir::kUnknownNumRegionInvocations{-1};
+
 /// Verify that types match along all region control flow edges originating from
 /// `sourceNo` (region # if source is a region, llvm::None if source is parent
 /// op). `getInputsTypesForRegion` is a function that returns the types of the

diff  --git a/mlir/test/Analysis/test-number-of-block-executions.mlir b/mlir/test/Analysis/test-number-of-block-executions.mlir
new file mode 100644
index 000000000000..d6ac9d2709a4
--- /dev/null
+++ b/mlir/test/Analysis/test-number-of-block-executions.mlir
@@ -0,0 +1,192 @@
+// RUN: mlir-opt %s                                                            \
+// RUN:     -test-print-number-of-block-executions                             \
+// RUN:     -split-input-file 2>&1                                             \
+// RUN:   | FileCheck %s
+
+// A single-block function: the entry block runs exactly once.
+// CHECK-LABEL: Number of executions: empty
+func @empty() {
+  // CHECK: Block: 0
+  // CHECK-NEXT: Number of executions: 1
+  return
+}
+
+// -----
+
+// A straight-line chain of unconditional branches: every block runs once.
+// CHECK-LABEL: Number of executions: sequential
+func @sequential() {
+  // CHECK: Block: 0
+  // CHECK-NEXT: Number of executions: 1
+  br ^bb1
+^bb1:
+  // CHECK: Block: 1
+  // CHECK-NEXT: Number of executions: 1
+  br ^bb2
+^bb2:
+  // CHECK: Block: 2
+  // CHECK-NEXT: Number of executions: 1
+  return
+}
+
+// -----
+
+// Only single-successor blocks reachable from the entry block are known;
+// blocks after a conditional branch, including the join block, are
+// conservatively reported as unknown.
+// CHECK-LABEL: Number of executions: conditional
+func @conditional(%cond : i1) {
+  // CHECK: Block: 0
+  // CHECK-NEXT: Number of executions: 1
+  br ^bb1
+^bb1:
+  // CHECK: Block: 1
+  // CHECK-NEXT: Number of executions: 1
+  cond_br %cond, ^bb2, ^bb3
+^bb2:
+  // CHECK: Block: 2
+  // CHECK-NEXT: Number of executions: <unknown>
+  br ^bb4
+^bb3:
+  // CHECK: Block: 3
+  // CHECK-NEXT: Number of executions: <unknown>
+  br ^bb4
+^bb4:
+  // CHECK: Block: 4
+  // CHECK-NEXT: Number of executions: <unknown>
+  return
+}
+
+// -----
+
+// The chain from the entry block stops at the first loop header (^bb1), so
+// blocks inside the cycle and the exit block after it are unknown.
+// CHECK-LABEL: Number of executions: loop
+func @loop(%cond : i1) {
+  // CHECK: Block: 0
+  // CHECK-NEXT: Number of executions: 1
+  br ^bb1
+^bb1:
+  // CHECK: Block: 1
+  // CHECK-NEXT: Number of executions: <unknown>
+  br ^bb2
+^bb2:
+  // CHECK: Block: 2
+  // CHECK-NEXT: Number of executions: <unknown>
+  br ^bb3
+^bb3:
+  // CHECK: Block: 3
+  // CHECK-NEXT: Number of executions: <unknown>
+  cond_br %cond, ^bb1, ^bb4
+^bb4:
+  // CHECK: Block: 4
+  // CHECK-NEXT: Number of executions: <unknown>
+  return
+}
+
+// -----
+
+// With a dynamic condition, neither `scf.if` branch has a known number of
+// executions.
+// CHECK-LABEL: Number of executions: scf_if_dynamic_branch
+func @scf_if_dynamic_branch(%cond : i1) {
+  // CHECK: Block: 0
+  // CHECK-NEXT: Number of executions: 1
+  scf.if %cond {
+    // CHECK: Block: 1
+    // CHECK-NEXT: Number of executions: <unknown>
+  } else {
+    // CHECK: Block: 2
+    // CHECK-NEXT: Number of executions: <unknown>
+  }
+  return
+}
+
+// -----
+
+// `async.execute` invokes its body region exactly once.
+// CHECK-LABEL: Number of executions: async_execute
+func @async_execute() {
+  // CHECK: Block: 0
+  // CHECK-NEXT: Number of executions: 1
+  async.execute {
+    // CHECK: Block: 1
+    // CHECK-NEXT: Number of executions: 1
+    async.yield
+  }
+  return
+}
+
+// -----
+
+// A known count for the `async.execute` body does not make the branches of
+// the nested `scf.if` known.
+// CHECK-LABEL: Number of executions: async_execute_with_scf_if
+func @async_execute_with_scf_if(%cond : i1) {
+  // CHECK: Block: 0
+  // CHECK-NEXT: Number of executions: 1
+  async.execute {
+    // CHECK: Block: 1
+    // CHECK-NEXT: Number of executions: 1
+    scf.if %cond {
+      // CHECK: Block: 2
+      // CHECK-NEXT: Number of executions: <unknown>
+    } else {
+      // CHECK: Block: 3
+      // CHECK-NEXT: Number of executions: <unknown>
+    }
+    async.yield
+  }
+  return
+}
+
+// -----
+
+// Constant bounds: the loop body runs ceilDiv(2 - 0, 1) = 2 times.
+// CHECK-LABEL: Number of executions: scf_for_constant_bounds
+func @scf_for_constant_bounds() {
+  // CHECK: Block: 0
+  // CHECK-NEXT: Number of executions: 1
+  %c0 = constant 0 : index
+  %c1 = constant 1 : index
+  %c2 = constant 2 : index
+
+  scf.for %i = %c0 to %c2 step %c1 {
+    // CHECK: Block: 1
+    // CHECK-NEXT: Number of executions: 2
+  }
+
+  return
+}
+
+// -----
+
+// Counts of nested regions multiply: the `async.execute` body (invoked once
+// per iteration) inside a two-trip loop runs 2 times per function entry.
+// CHECK-LABEL: Number of executions: propagate_parent_num_executions
+func @propagate_parent_num_executions() {
+  // CHECK: Block: 0
+  // CHECK-NEXT: Number of executions: 1
+  %c0 = constant 0 : index
+  %c1 = constant 1 : index
+  %c2 = constant 2 : index
+
+  scf.for %i = %c0 to %c2 step %c1 {
+    // CHECK: Block: 1
+    // CHECK-NEXT: Number of executions: 2
+    async.execute {
+      // CHECK: Block: 2
+      // CHECK-NEXT: Number of executions: 2
+      async.yield
+    }
+  }
+
+  return
+}
+
+// -----
+
+// A non-constant step makes the trip count unknown, which in turn makes every
+// block nested in the loop unknown.
+// CHECK-LABEL: Number of executions: clear_num_executions
+func @clear_num_executions(%step : index) {
+  // CHECK: Block: 0
+  // CHECK-NEXT: Number of executions: 1
+  %c0 = constant 0 : index
+  %c2 = constant 2 : index
+
+  scf.for %i = %c0 to %c2 step %step {
+    // CHECK: Block: 1
+    // CHECK-NEXT: Number of executions: <unknown>
+    async.execute {
+      // CHECK: Block: 2
+      // CHECK-NEXT: Number of executions: <unknown>
+      async.yield
+    }
+  }
+
+  return
+}

diff  --git a/mlir/test/Analysis/test-number-of-operation-executions.mlir b/mlir/test/Analysis/test-number-of-operation-executions.mlir
new file mode 100644
index 000000000000..60e9a8a86f42
--- /dev/null
+++ b/mlir/test/Analysis/test-number-of-operation-executions.mlir
@@ -0,0 +1,66 @@
+// RUN: mlir-opt %s                                                            \
+// RUN:     -test-print-number-of-operation-executions                         \
+// RUN:     -split-input-file 2>&1                                             \
+// RUN:   | FileCheck %s
+// Simplest case: the lone return of an empty function executes exactly once.
+// CHECK-LABEL: Number of executions: empty
+func @empty() {
+  // CHECK: Operation: std.return
+  // CHECK-NEXT: Number of executions: 1
+  return
+}
+
+// -----
+
+// CHECK-LABEL: Number of executions: propagate_parent_num_executions
+func @propagate_parent_num_executions() {
+  // CHECK: Operation: std.constant
+  // CHECK-NEXT: Number of executions: 1
+  %c0 = constant 0 : index
+  // CHECK: Operation: std.constant
+  // CHECK-NEXT: Number of executions: 1
+  %c1 = constant 1 : index
+  // CHECK: Operation: std.constant
+  // CHECK-NEXT: Number of executions: 1
+  %c2 = constant 2 : index
+  // The loop op runs once; every op nested in its body runs twice.
+  // CHECK-DAG: Operation: scf.for
+  // CHECK-NEXT: Number of executions: 1
+  scf.for %i = %c0 to %c2 step %c1 {
+    // CHECK-DAG: Operation: async.execute
+    // CHECK-NEXT: Number of executions: 2
+    async.execute {
+      // CHECK-DAG: Operation: async.yield
+      // CHECK-NEXT: Number of executions: 2
+      async.yield
+    }
+  }
+
+  return
+}
+
+// -----
+
+// CHECK-LABEL: Number of executions: clear_num_executions
+func @clear_num_executions(%step : index) {
+  // CHECK: Operation: std.constant
+  // CHECK-NEXT: Number of executions: 1
+  %c0 = constant 0 : index
+  // CHECK: Operation: std.constant
+  // CHECK-NEXT: Number of executions: 1
+  %c2 = constant 2 : index
+  // A non-constant %step makes the count of every op in the loop body unknown.
+  // CHECK: Operation: scf.for
+  // CHECK-NEXT: Number of executions: 1
+  scf.for %i = %c0 to %c2 step %step {
+    // CHECK: Operation: async.execute
+    // CHECK-NEXT: Number of executions: <unknown>
+    async.execute {
+      // CHECK: Operation: async.yield
+      // CHECK-NEXT: Number of executions: <unknown>
+      async.yield
+    }
+  }
+
+  return
+}

diff  --git a/mlir/test/lib/Transforms/CMakeLists.txt b/mlir/test/lib/Transforms/CMakeLists.txt
index 249addc61b9b..69d45b570a3c 100644
--- a/mlir/test/lib/Transforms/CMakeLists.txt
+++ b/mlir/test/lib/Transforms/CMakeLists.txt
@@ -24,6 +24,7 @@ add_mlir_library(MLIRTestTransforms
   TestLoopMapping.cpp
   TestLoopParametricTiling.cpp
   TestLoopUnrolling.cpp
+  TestNumberOfExecutions.cpp
   TestOpaqueLoc.cpp
   TestMemRefBoundCheck.cpp
   TestMemRefDependenceCheck.cpp

diff  --git a/mlir/test/lib/Transforms/TestNumberOfExecutions.cpp b/mlir/test/lib/Transforms/TestNumberOfExecutions.cpp
new file mode 100644
index 000000000000..908e596ac158
--- /dev/null
+++ b/mlir/test/lib/Transforms/TestNumberOfExecutions.cpp
@@ -0,0 +1,57 @@
+//===- TestNumberOfExecutions.cpp - Test number of executions analysis ----===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file contains test passes that build the number of executions analysis
+// and print the execution counts it computes for blocks and operations.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Analysis/NumberOfExecutions.h"
+#include "mlir/Pass/Pass.h"
+
+using namespace mlir;
+
+namespace {
+/// Prints the number of executions of each block nested in the function body.
+struct TestNumberOfBlockExecutionsPass
+    : public PassWrapper<TestNumberOfBlockExecutionsPass, FunctionPass> {
+  void runOnFunction() override {
+    llvm::errs() << "Number of executions: " << getFunction().getName() << "\n";
+    auto &analysis = getAnalysis<NumberOfExecutions>();
+    analysis.printBlockExecutions(llvm::errs(), &getFunction().getBody());
+  }
+};
+
+/// Prints the number of executions of each operation nested in the function.
+struct TestNumberOfOperationExecutionsPass
+    : public PassWrapper<TestNumberOfOperationExecutionsPass, FunctionPass> {
+  void runOnFunction() override {
+    llvm::errs() << "Number of executions: " << getFunction().getName() << "\n";
+    auto &analysis = getAnalysis<NumberOfExecutions>();
+    analysis.printOperationExecutions(llvm::errs(), &getFunction().getBody());
+  }
+};
+} // end anonymous namespace
+
+namespace mlir {
+namespace test {
+// Hooks called from mlir-opt's registerTestPasses() to expose these passes.
+void registerTestNumberOfBlockExecutionsPass() {
+  PassRegistration<TestNumberOfBlockExecutionsPass>(
+      /*arg=*/"test-print-number-of-block-executions",
+      /*description=*/"Print the contents of a constructed number of "
+                      "executions analysis for all blocks.");
+}
+void registerTestNumberOfOperationExecutionsPass() {
+  PassRegistration<TestNumberOfOperationExecutionsPass>(
+      /*arg=*/"test-print-number-of-operation-executions",
+      /*description=*/"Print the contents of a constructed number of "
+                      "executions analysis for all operations.");
+}
+} // namespace test
+} // namespace mlir

diff  --git a/mlir/tools/mlir-opt/mlir-opt.cpp b/mlir/tools/mlir-opt/mlir-opt.cpp
index 8f88ae35409d..58444e6a9501 100644
--- a/mlir/tools/mlir-opt/mlir-opt.cpp
+++ b/mlir/tools/mlir-opt/mlir-opt.cpp
@@ -81,6 +81,8 @@ void registerTestLoopMappingPass();
 void registerTestLoopUnrollingPass();
 void registerTestMemRefDependenceCheck();
 void registerTestMemRefStrideCalculation();
+void registerTestNumberOfBlockExecutionsPass();
+void registerTestNumberOfOperationExecutionsPass();
 void registerTestOpaqueLoc();
 void registerTestPreparationPassWithAllowedMemrefResults();
 void registerTestRecursiveTypesPass();
@@ -145,6 +147,8 @@ void registerTestPasses() {
   test::registerTestLoopUnrollingPass();
   test::registerTestMemRefDependenceCheck();
   test::registerTestMemRefStrideCalculation();
+  test::registerTestNumberOfBlockExecutionsPass();
+  test::registerTestNumberOfOperationExecutionsPass();
   test::registerTestOpaqueLoc();
   test::registerTestRecursiveTypesPass();
   test::registerTestSCFUtilsPass();

diff  --git a/mlir/unittests/Support/CMakeLists.txt b/mlir/unittests/Support/CMakeLists.txt
index cec3c2d57386..2bc673336ba5 100644
--- a/mlir/unittests/Support/CMakeLists.txt
+++ b/mlir/unittests/Support/CMakeLists.txt
@@ -1,5 +1,6 @@
 add_mlir_unittest(MLIRSupportTests
   IndentedOstreamTest.cpp
+  MathExtrasTest.cpp
 )
 
 target_link_libraries(MLIRSupportTests

diff  --git a/mlir/unittests/Support/MathExtrasTest.cpp b/mlir/unittests/Support/MathExtrasTest.cpp
new file mode 100644
index 000000000000..304b353d4167
--- /dev/null
+++ b/mlir/unittests/Support/MathExtrasTest.cpp
@@ -0,0 +1,37 @@
+//===- MathExtrasTest.cpp - MathExtras Tests ------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Support/MathExtras.h"
+#include "gmock/gmock.h"
+
+using namespace mlir;
+using ::testing::Eq;
+
+// ceilDiv rounds towards positive infinity for every sign combination.
+// Exact divisions must not be rounded, and a zero dividend yields zero.
+TEST(MathExtrasTest, CeilDivTest) {
+  EXPECT_THAT(ceilDiv(14, 3), Eq(5));
+  EXPECT_THAT(ceilDiv(14, -3), Eq(-4));
+  EXPECT_THAT(ceilDiv(-14, -3), Eq(5));
+  EXPECT_THAT(ceilDiv(-14, 3), Eq(-4));
+  EXPECT_THAT(ceilDiv(12, 3), Eq(4));
+  EXPECT_THAT(ceilDiv(-12, 3), Eq(-4));
+  EXPECT_THAT(ceilDiv(0, -3), Eq(0));
+}
+
+// floorDiv rounds towards negative infinity for every sign combination.
+// Exact divisions must not be rounded, and a zero dividend yields zero.
+TEST(MathExtrasTest, FloorDivTest) {
+  EXPECT_THAT(floorDiv(14, 3), Eq(4));
+  EXPECT_THAT(floorDiv(14, -3), Eq(-5));
+  EXPECT_THAT(floorDiv(-14, -3), Eq(4));
+  EXPECT_THAT(floorDiv(-14, 3), Eq(-5));
+  EXPECT_THAT(floorDiv(12, -3), Eq(-4));
+  EXPECT_THAT(floorDiv(-12, -3), Eq(4));
+  EXPECT_THAT(floorDiv(0, 3), Eq(0));
+}


        


More information about the Mlir-commits mailing list