[Mlir-commits] [mlir] ee70039 - [mlir] Fix handling of some region branch terminator successors

llvmlistbot at llvm.org llvmlistbot at llvm.org
Wed Jun 8 10:17:07 PDT 2022


Author: Mogball
Date: 2022-06-08T17:17:03Z
New Revision: ee70039ae27ae3a3db4aec65b044f24080d328ea

URL: https://github.com/llvm/llvm-project/commit/ee70039ae27ae3a3db4aec65b044f24080d328ea
DIFF: https://github.com/llvm/llvm-project/commit/ee70039ae27ae3a3db4aec65b044f24080d328ea.diff

LOG: [mlir] Fix handling of some region branch terminator successors

When `RegionBranchOpInterface::getSuccessorRegions` is called with a region index (i.e. for a source other than the parent op), it expects to be passed the operands of the source region's terminator, not the operands of the parent op. Several callers did not respect this.

This fixes a bug in integer range inference and in ForwardDataFlowSolver, and changes `scf.while` to allow narrowing of successors using constant inputs.

Fixes #55873

Reviewed By: mehdi_amini, krzysz00

Differential Revision: https://reviews.llvm.org/D127261

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/SCF/SCFOps.td
    mlir/include/mlir/Interfaces/ControlFlowInterfaces.td
    mlir/lib/Analysis/DataFlowAnalysis.cpp
    mlir/lib/Analysis/IntRangeAnalysis.cpp
    mlir/lib/Dialect/SCF/SCF.cpp
    mlir/lib/Interfaces/ControlFlowInterfaces.cpp
    mlir/test/Dialect/SCF/invalid.mlir
    mlir/test/Interfaces/InferIntRangeInterface/infer-int-range-test-ops.mlir
    mlir/test/Transforms/sccp-structured.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/SCF/SCFOps.td b/mlir/include/mlir/Dialect/SCF/SCFOps.td
index 3b11d7ede650d..a74714481db74 100644
--- a/mlir/include/mlir/Dialect/SCF/SCFOps.td
+++ b/mlir/include/mlir/Dialect/SCF/SCFOps.td
@@ -964,7 +964,7 @@ def WhileOp : SCF_Op<"while",
 
   let hasCanonicalizer = 1;
   let hasCustomAssemblyFormat = 1;
-  let hasRegionVerifier = 1;
+  let hasVerifier = 1;
 }
 
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/include/mlir/Interfaces/ControlFlowInterfaces.td b/mlir/include/mlir/Interfaces/ControlFlowInterfaces.td
index 9266bc33b4924..e74f1f7517a24 100644
--- a/mlir/include/mlir/Interfaces/ControlFlowInterfaces.td
+++ b/mlir/include/mlir/Interfaces/ControlFlowInterfaces.td
@@ -207,10 +207,7 @@ def RegionBranchOpInterface : OpInterface<"RegionBranchOpInterface"> {
   let extraClassDeclaration = [{
     /// Convenience helper in case none of the operands is known.
     void getSuccessorRegions(Optional<unsigned> index,
-                             SmallVectorImpl<RegionSuccessor> &regions) {
-       SmallVector<Attribute, 2> nullAttrs(getOperation()->getNumOperands());
-       getSuccessorRegions(index, nullAttrs, regions);
-    }
+                             SmallVectorImpl<RegionSuccessor> &regions);
 
     /// Return `true` if control flow originating from the given region may
     /// eventually branch back to the same region. (Maybe after passing through

diff  --git a/mlir/lib/Analysis/DataFlowAnalysis.cpp b/mlir/lib/Analysis/DataFlowAnalysis.cpp
index 239d9e4060bca..b9c963cdb5743 100644
--- a/mlir/lib/Analysis/DataFlowAnalysis.cpp
+++ b/mlir/lib/Analysis/DataFlowAnalysis.cpp
@@ -576,45 +576,12 @@ void ForwardDataFlowSolver::visitTerminatorOperation(
     if (!regionInterface || !isBlockExecutable(parentOp->getBlock()))
       return;
 
-    // If the branch is a RegionBranchTerminatorOpInterface,
-    // construct the set of operand lattices as the set of non control-flow
-    // arguments of the parent and the values this op returns. This allows
-    // for the correct lattices to be passed to getSuccessorsForOperands()
-    // in cases such as scf.while.
-    ArrayRef<AbstractLatticeElement *> branchOpLattices = operandLattices;
-    SmallVector<AbstractLatticeElement *, 0> parentLattices;
-    if (auto regionTerminator =
-            dyn_cast<RegionBranchTerminatorOpInterface>(op)) {
-      parentLattices.reserve(regionInterface->getNumOperands());
-      for (Value parentOperand : regionInterface->getOperands()) {
-        AbstractLatticeElement *operandLattice =
-            analysis.lookupLatticeElement(parentOperand);
-        if (!operandLattice || operandLattice->isUninitialized())
-          return;
-        parentLattices.push_back(operandLattice);
-      }
-      unsigned regionNumber = parentRegion->getRegionNumber();
-      OperandRange iterArgs =
-          regionInterface.getSuccessorEntryOperands(regionNumber);
-      OperandRange terminatorArgs =
-          regionTerminator.getSuccessorOperands(regionNumber);
-      assert(iterArgs.size() == terminatorArgs.size() &&
-             "Number of iteration arguments for region should equal number of "
-             "those arguments defined by terminator");
-      if (!iterArgs.empty()) {
-        unsigned iterStart = iterArgs.getBeginOperandIndex();
-        unsigned terminatorStart = terminatorArgs.getBeginOperandIndex();
-        for (unsigned i = 0, e = iterArgs.size(); i < e; ++i)
-          parentLattices[iterStart + i] = operandLattices[terminatorStart + i];
-      }
-      branchOpLattices = parentLattices;
-    }
     // Query the set of successors of the current region using the current
     // optimistic lattice state.
     SmallVector<RegionSuccessor, 1> regionSuccessors;
     analysis.getSuccessorsForOperands(regionInterface,
                                       parentRegion->getRegionNumber(),
-                                      branchOpLattices, regionSuccessors);
+                                      operandLattices, regionSuccessors);
     if (regionSuccessors.empty())
       return;
 
@@ -622,11 +589,11 @@ void ForwardDataFlowSolver::visitTerminatorOperation(
     // propagate the operand states to the successors.
     if (isRegionReturnLike(op)) {
       auto getOperands = [&](Optional<unsigned> regionIndex) {
-        // Determine the individual region  successor operands for the given
+        // Determine the individual region successor operands for the given
         // region index (if any).
         return *getRegionBranchSuccessorOperands(op, regionIndex);
       };
-      return visitRegionSuccessors(parentOp, regionSuccessors, branchOpLattices,
+      return visitRegionSuccessors(parentOp, regionSuccessors, operandLattices,
                                    getOperands);
     }
 

diff  --git a/mlir/lib/Analysis/IntRangeAnalysis.cpp b/mlir/lib/Analysis/IntRangeAnalysis.cpp
index 7e6d61ff89560..fc01607c92ee3 100644
--- a/mlir/lib/Analysis/IntRangeAnalysis.cpp
+++ b/mlir/lib/Analysis/IntRangeAnalysis.cpp
@@ -214,12 +214,24 @@ void detail::IntRangeAnalysisImpl::getSuccessorsForOperands(
     RegionBranchOpInterface branch, Optional<unsigned> sourceIndex,
     ArrayRef<LatticeElement<IntRangeLattice> *> operands,
     SmallVectorImpl<RegionSuccessor> &successors) {
-  auto toConstantAttr = [&branch](auto enumPair) -> Attribute {
-    Optional<APInt> maybeConstValue =
-        enumPair.value()->getValue().value.getConstantValue();
+  // Get a type with which to construct a constant.
+  auto getOperandType = [branch, sourceIndex](unsigned index) {
+    // The types of all return-like operations are the same.
+    if (!sourceIndex)
+      return branch->getOperand(index).getType();
+
+    for (Block &block : branch->getRegion(*sourceIndex)) {
+      Operation *terminator = block.getTerminator();
+      if (getRegionBranchSuccessorOperands(terminator, *sourceIndex))
+        return terminator->getOperand(index).getType();
+    }
+    return Type();
+  };
 
-    if (maybeConstValue) {
-      return IntegerAttr::get(branch->getOperand(enumPair.index()).getType(),
+  auto toConstantAttr = [&getOperandType](auto enumPair) -> Attribute {
+    if (Optional<APInt> maybeConstValue =
+            enumPair.value()->getValue().value.getConstantValue()) {
+      return IntegerAttr::get(getOperandType(enumPair.index()),
                               *maybeConstValue);
     }
     return {};

diff  --git a/mlir/lib/Dialect/SCF/SCF.cpp b/mlir/lib/Dialect/SCF/SCF.cpp
index 2460e535d3ad5..f2673a519378f 100644
--- a/mlir/lib/Dialect/SCF/SCF.cpp
+++ b/mlir/lib/Dialect/SCF/SCF.cpp
@@ -2631,21 +2631,26 @@ Block::BlockArgListType WhileOp::getAfterArguments() {
 void WhileOp::getSuccessorRegions(Optional<unsigned> index,
                                   ArrayRef<Attribute> operands,
                                   SmallVectorImpl<RegionSuccessor> &regions) {
-  (void)operands;
-
+  // The parent op always branches to the condition region.
   if (!index.hasValue()) {
     regions.emplace_back(&getBefore(), getBefore().getArguments());
     return;
   }
 
   assert(*index < 2 && "there are only two regions in a WhileOp");
-  if (*index == 0) {
-    regions.emplace_back(&getAfter(), getAfter().getArguments());
-    regions.emplace_back(getResults());
+  // The body region always branches back to the condition region.
+  if (*index == 1) {
+    regions.emplace_back(&getBefore(), getBefore().getArguments());
     return;
   }
 
-  regions.emplace_back(&getBefore(), getBefore().getArguments());
+  // Try to narrow the successor to the condition region.
+  assert(!operands.empty() && "expected at least one operand");
+  auto cond = operands[0].dyn_cast_or_null<BoolAttr>();
+  if (!cond || !cond.getValue())
+    regions.emplace_back(getResults());
+  if (!cond || cond.getValue())
+    regions.emplace_back(&getAfter(), getAfter().getArguments());
 }
 
 /// Parses a `while` op.
@@ -2745,7 +2750,7 @@ static TerminatorTy verifyAndGetTerminator(scf::WhileOp op, Region &region,
   return nullptr;
 }
 
-LogicalResult scf::WhileOp::verifyRegions() {
+LogicalResult scf::WhileOp::verify() {
   auto beforeTerminator = verifyAndGetTerminator<scf::ConditionOp>(
       *this, getBefore(),
       "expects the 'before' region to terminate with 'scf.condition'");

diff  --git a/mlir/lib/Interfaces/ControlFlowInterfaces.cpp b/mlir/lib/Interfaces/ControlFlowInterfaces.cpp
index a8068b482e766..7336fd4a6bc6a 100644
--- a/mlir/lib/Interfaces/ControlFlowInterfaces.cpp
+++ b/mlir/lib/Interfaces/ControlFlowInterfaces.cpp
@@ -9,7 +9,6 @@
 #include <utility>
 
 #include "mlir/Interfaces/ControlFlowInterfaces.h"
-#include "mlir/IR/BuiltinTypes.h"
 #include "llvm/ADT/SmallPtrSet.h"
 
 using namespace mlir;
@@ -97,15 +96,7 @@ verifyTypesAlongAllEdges(Operation *op, Optional<unsigned> sourceNo,
   auto regionInterface = cast<RegionBranchOpInterface>(op);
 
   SmallVector<RegionSuccessor, 2> successors;
-  unsigned numInputs;
-  if (sourceNo) {
-    Region &srcRegion = op->getRegion(sourceNo.getValue());
-    numInputs = srcRegion.getNumArguments();
-  } else {
-    numInputs = op->getNumOperands();
-  }
-  SmallVector<Attribute, 2> operands(numInputs, nullptr);
-  regionInterface.getSuccessorRegions(sourceNo, operands, successors);
+  regionInterface.getSuccessorRegions(sourceNo, successors);
 
   for (RegionSuccessor &succ : successors) {
     Optional<unsigned> succRegionNo;
@@ -327,6 +318,27 @@ bool RegionBranchOpInterface::isRepetitiveRegion(unsigned index) {
   return isRegionReachable(region, region);
 }
 
+void RegionBranchOpInterface::getSuccessorRegions(
+    Optional<unsigned> index, SmallVectorImpl<RegionSuccessor> &regions) {
+  unsigned numInputs = 0;
+  if (index) {
+    // If the predecessor is a region, get the number of operands from an
+    // exiting terminator in the region.
+    for (Block &block : getOperation()->getRegion(*index)) {
+      Operation *terminator = block.getTerminator();
+      if (getRegionBranchSuccessorOperands(terminator, *index)) {
+        numInputs = terminator->getNumOperands();
+        break;
+      }
+    }
+  } else {
+    // Otherwise, use the number of parent operation operands.
+    numInputs = getOperation()->getNumOperands();
+  }
+  SmallVector<Attribute, 2> operands(numInputs, nullptr);
+  getSuccessorRegions(index, operands, regions);
+}
+
 Region *mlir::getEnclosingRepetitiveRegion(Operation *op) {
   while (Region *region = op->getParentRegion()) {
     op = region->getParentOp();

diff  --git a/mlir/test/Dialect/SCF/invalid.mlir b/mlir/test/Dialect/SCF/invalid.mlir
index 6ba30f2c2e8b6..c251f97420748 100644
--- a/mlir/test/Dialect/SCF/invalid.mlir
+++ b/mlir/test/Dialect/SCF/invalid.mlir
@@ -476,7 +476,7 @@ func.func @while_cross_region_type_mismatch() {
 func.func @while_cross_region_type_mismatch() {
   %true = arith.constant true
   // expected-error @+1 {{'scf.while' op  along control flow edge from Region #0 to Region #1: source type #0 'i1' should match input type #0 'i32'}}
-  scf.while : () -> () {
+  %0 = scf.while : () -> (i1) {
     scf.condition(%true) %true : i1
   } do {
   ^bb0(%arg0: i32):

diff  --git a/mlir/test/Interfaces/InferIntRangeInterface/infer-int-range-test-ops.mlir b/mlir/test/Interfaces/InferIntRangeInterface/infer-int-range-test-ops.mlir
index 45d506d00d65d..f9c551c0b9929 100644
--- a/mlir/test/Interfaces/InferIntRangeInterface/infer-int-range-test-ops.mlir
+++ b/mlir/test/Interfaces/InferIntRangeInterface/infer-int-range-test-ops.mlir
@@ -100,3 +100,22 @@ func.func @func_args_unbound(%arg0 : index) -> index {
   %0 = test.reflect_bounds %arg0
   func.return %0 : index
 }
+
+// CHECK-LABEL: func @propagate_across_while_loop()
+func.func @propagate_across_while_loop() -> index {
+  // CHECK-DAG: %[[C0:.*]] = "test.constant"() {value = 0
+  // CHECK-DAG: %[[C1:.*]] = "test.constant"() {value = 1
+  %0 = test.with_bounds { umin = 0 : index, umax = 0 : index,
+                          smin = 0 : index, smax = 0 : index }
+  %1 = scf.while : () -> index {
+    %true = arith.constant true
+    // CHECK: scf.condition(%{{.*}}) %[[C0]]
+    scf.condition(%true) %0 : index
+  } do {
+  ^bb0(%i1: index):
+    scf.yield
+  }
+  // CHECK: return %[[C1]]
+  %2 = test.increment %1
+  return %2 : index
+}

diff  --git a/mlir/test/Transforms/sccp-structured.mlir b/mlir/test/Transforms/sccp-structured.mlir
index b2e21a6692d1d..d3249d9fe8296 100644
--- a/mlir/test/Transforms/sccp-structured.mlir
+++ b/mlir/test/Transforms/sccp-structured.mlir
@@ -179,3 +179,43 @@ func.func @affine_loop_zero_iter(%arg0 : index, %arg1 : index, %arg2 : index) ->
   // CHECK: return %[[C0]] : i32
   return %s0 : i32
 }
+
+// CHECK-LABEL: func @while_loop_different_arg_count
+func.func @while_loop_different_arg_count() -> index {
+  // CHECK-DAG: %[[TRUE:.*]] = arith.constant true
+  // CHECK-DAG: %[[C0:.*]] = arith.constant 0
+  // CHECK-DAG: %[[C1:.*]] = arith.constant 1
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  // CHECK: %[[WHILE:.*]] = scf.while
+  %0 = scf.while (%arg3 = %c0, %arg4 = %c1) : (index, index) -> index {
+    %1 = arith.cmpi slt, %arg3, %c1 : index
+    // CHECK: scf.condition(%[[TRUE]]) %[[C1]]
+    scf.condition(%1) %arg4 : index
+  } do {
+  ^bb0(%arg3: index):
+    %1 = arith.muli %arg3, %c1 : index
+    // CHECK: scf.yield %[[C0]], %[[C1]]
+    scf.yield %c0, %1 : index, index
+  }
+  // CHECK: return %[[WHILE]]
+  return %0 : index
+}
+
+// CHECK-LABEL: func @while_loop_false_condition
+func.func @while_loop_false_condition(%arg0 : index) -> index {
+  // CHECK: %[[C0:.*]] = arith.constant 0
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %0 = arith.muli %arg0, %c0 : index
+  %1 = scf.while (%arg1 = %0) : (index) -> index {
+    %2 = arith.cmpi slt, %arg1, %c0 : index
+    scf.condition(%2) %arg1 : index
+  } do {
+  ^bb0(%arg2 : index):
+    %3 = arith.addi %arg2, %c1 : index
+    scf.yield %3 : index
+  }
+  // CHECK: return %[[C0]]
+  func.return %1 : index
+}


        


More information about the Mlir-commits mailing list