[Mlir-commits] [mlir] 537f220 - [mlir] Support getSuccessorInputs from parent op

llvmlistbot at llvm.org llvmlistbot at llvm.org
Mon Jun 13 15:21:39 PDT 2022


Author: Mogball
Date: 2022-06-13T22:21:34Z
New Revision: 537f220891168d4feaebf37d44ae559b2393b8ad

URL: https://github.com/llvm/llvm-project/commit/537f220891168d4feaebf37d44ae559b2393b8ad
DIFF: https://github.com/llvm/llvm-project/commit/537f220891168d4feaebf37d44ae559b2393b8ad.diff

LOG: [mlir] Support getSuccessorInputs from parent op

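Ops that implement `RegionBranchOpInterface` are allowed to indicate in `getSuccessorRegions` that they can branch back to themselves, but there was no API that allowed them to specify the operands forwarded in that case. This patch enables that by changing the region-index parameter of `getSuccessorEntryOperands` from `unsigned` to `Optional<unsigned>`; passing `None` requests the operands that are forwarded to the operation's own results.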

Fixes #54928

Reviewed By: rriddle

Differential Revision: https://reviews.llvm.org/D127239

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/SCF/SCFOps.td
    mlir/include/mlir/Interfaces/ControlFlowInterfaces.td
    mlir/lib/Analysis/AliasAnalysis/LocalAliasAnalysis.cpp
    mlir/lib/Analysis/DataFlowAnalysis.cpp
    mlir/lib/Dialect/Affine/IR/AffineOps.cpp
    mlir/lib/Dialect/Async/IR/Async.cpp
    mlir/lib/Dialect/SCF/SCF.cpp
    mlir/lib/Dialect/Transform/IR/TransformOps.cpp
    mlir/lib/Interfaces/ControlFlowInterfaces.cpp
    mlir/test/Analysis/test-alias-analysis.mlir
    mlir/test/Transforms/sccp-structured.mlir
    mlir/test/lib/Dialect/Test/TestDialect.cpp
    mlir/test/lib/Dialect/Test/TestOps.td

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/SCF/SCFOps.td b/mlir/include/mlir/Dialect/SCF/SCFOps.td
index a74714481db7..d15c51d4251e 100644
--- a/mlir/include/mlir/Dialect/SCF/SCFOps.td
+++ b/mlir/include/mlir/Dialect/SCF/SCFOps.td
@@ -309,7 +309,7 @@ def ForOp : SCF_Op<"for",
     /// correspond to the loop iterator operands, i.e., those exclusing the
     /// induction variable. LoopOp only has one region, so 0 is the only valid
     /// value for `index`.
-    OperandRange getSuccessorEntryOperands(unsigned index);
+    OperandRange getSuccessorEntryOperands(Optional<unsigned> index);
   }];
 
   let hasCanonicalizer = 1;
@@ -955,7 +955,7 @@ def WhileOp : SCF_Op<"while",
   let regions = (region SizedRegion<1>:$before, SizedRegion<1>:$after);
 
   let extraClassDeclaration = [{
-    OperandRange getSuccessorEntryOperands(unsigned index);
+    OperandRange getSuccessorEntryOperands(Optional<unsigned> index);
     ConditionOp getConditionOp();
     YieldOp getYieldOp();
     Block::BlockArgListType getBeforeArguments();

diff  --git a/mlir/include/mlir/Interfaces/ControlFlowInterfaces.td b/mlir/include/mlir/Interfaces/ControlFlowInterfaces.td
index e74f1f7517a2..2fc4977e75db 100644
--- a/mlir/include/mlir/Interfaces/ControlFlowInterfaces.td
+++ b/mlir/include/mlir/Interfaces/ControlFlowInterfaces.td
@@ -134,12 +134,14 @@ def RegionBranchOpInterface : OpInterface<"RegionBranchOpInterface"> {
     InterfaceMethod<[{
         Returns the operands of this operation used as the entry arguments when
         entering the region at `index`, which was specified as a successor of
-        this operation by `getSuccessorRegions`. These operands should
-        correspond 1-1 with the successor inputs specified in
+        this operation by `getSuccessorRegions`, or the operands forwarded to
+        the operation's results when it branches back to itself. These operands
+        should correspond 1-1 with the successor inputs specified in
         `getSuccessorRegions`.
       }],
       "::mlir::OperandRange", "getSuccessorEntryOperands",
-      (ins "unsigned":$index), [{}], /*defaultImplementation=*/[{
+      (ins "::llvm::Optional<unsigned>":$index), [{}], 
+      /*defaultImplementation=*/[{
         auto operandEnd = this->getOperation()->operand_end();
         return ::mlir::OperandRange(operandEnd, operandEnd);
       }]

diff  --git a/mlir/lib/Analysis/AliasAnalysis/LocalAliasAnalysis.cpp b/mlir/lib/Analysis/AliasAnalysis/LocalAliasAnalysis.cpp
index 78eb0e414bdf..196312d4f6a5 100644
--- a/mlir/lib/Analysis/AliasAnalysis/LocalAliasAnalysis.cpp
+++ b/mlir/lib/Analysis/AliasAnalysis/LocalAliasAnalysis.cpp
@@ -78,12 +78,12 @@ static void collectUnderlyingAddressValues(RegionBranchOpInterface branch,
   if (region) {
     // Determine the actual region number from the passed region.
     regionIndex = region->getRegionNumber();
-    if (Optional<unsigned> operandIndex =
-            getOperandIndexIfPred(/*predIndex=*/llvm::None)) {
-      collectUnderlyingAddressValues(
-          branch.getSuccessorEntryOperands(*regionIndex)[*operandIndex],
-          maxDepth, visited, output);
-    }
+  }
+  if (Optional<unsigned> operandIndex =
+          getOperandIndexIfPred(/*predIndex=*/llvm::None)) {
+    collectUnderlyingAddressValues(
+        branch.getSuccessorEntryOperands(regionIndex)[*operandIndex], maxDepth,
+        visited, output);
   }
   // Check branches from each child region.
   Operation *op = branch.getOperation();

diff  --git a/mlir/lib/Analysis/DataFlowAnalysis.cpp b/mlir/lib/Analysis/DataFlowAnalysis.cpp
index b9c963cdb574..073a23b88a6a 100644
--- a/mlir/lib/Analysis/DataFlowAnalysis.cpp
+++ b/mlir/lib/Analysis/DataFlowAnalysis.cpp
@@ -470,11 +470,10 @@ void ForwardDataFlowSolver::visitRegionBranchOperation(
   // also allow for the parent operation to have itself as a region successor.
   if (successors.empty())
     return markAllPessimisticFixpoint(branch, branch->getResults());
-  return visitRegionSuccessors(
-      branch, successors, operandLattices, [&](Optional<unsigned> index) {
-        assert(index && "expected valid region index");
-        return branch.getSuccessorEntryOperands(*index);
-      });
+  return visitRegionSuccessors(branch, successors, operandLattices,
+                               [&](Optional<unsigned> index) {
+                                 return branch.getSuccessorEntryOperands(index);
+                               });
 }
 
 void ForwardDataFlowSolver::visitRegionSuccessors(

diff  --git a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
index 85cb05a9168f..318b4909c31a 100644
--- a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
+++ b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
@@ -1731,11 +1731,11 @@ void AffineForOp::getCanonicalizationPatterns(RewritePatternSet &results,
 /// correspond to the loop iterator operands, i.e., those excluding the
 /// induction variable. AffineForOp only has one region, so zero is the only
 /// valid value for `index`.
-OperandRange AffineForOp::getSuccessorEntryOperands(unsigned index) {
-  assert(index == 0 && "invalid region index");
+OperandRange AffineForOp::getSuccessorEntryOperands(Optional<unsigned> index) {
+  assert(!index || *index == 0 && "invalid region index");
 
   // The initial operands map to the loop arguments after the induction
-  // variable.
+  // variable or are forwarded to the results when the trip count is zero.
   return getIterOperands();
 }
 

diff  --git a/mlir/lib/Dialect/Async/IR/Async.cpp b/mlir/lib/Dialect/Async/IR/Async.cpp
index 656e75b30754..888b14c434ad 100644
--- a/mlir/lib/Dialect/Async/IR/Async.cpp
+++ b/mlir/lib/Dialect/Async/IR/Async.cpp
@@ -59,8 +59,8 @@ YieldOp::getMutableSuccessorOperands(Optional<unsigned> index) {
 
 constexpr char kOperandSegmentSizesAttr[] = "operand_segment_sizes";
 
-OperandRange ExecuteOp::getSuccessorEntryOperands(unsigned index) {
-  assert(index == 0 && "invalid region index");
+OperandRange ExecuteOp::getSuccessorEntryOperands(Optional<unsigned> index) {
+  assert(index && *index == 0 && "invalid region index");
   return operands();
 }
 

diff  --git a/mlir/lib/Dialect/SCF/SCF.cpp b/mlir/lib/Dialect/SCF/SCF.cpp
index f2673a519378..c434a85142ae 100644
--- a/mlir/lib/Dialect/SCF/SCF.cpp
+++ b/mlir/lib/Dialect/SCF/SCF.cpp
@@ -473,8 +473,8 @@ ForOp mlir::scf::getForInductionVarOwner(Value val) {
 /// correspond to the loop iterator operands, i.e., those excluding the
 /// induction variable. LoopOp only has one region, so 0 is the only valid value
 /// for `index`.
-OperandRange ForOp::getSuccessorEntryOperands(unsigned index) {
-  assert(index == 0 && "invalid region index");
+OperandRange ForOp::getSuccessorEntryOperands(Optional<unsigned> index) {
+  assert(index && *index == 0 && "invalid region index");
 
   // The initial operands map to the loop arguments after the induction
   // variable.
@@ -2605,8 +2605,8 @@ LogicalResult ReduceReturnOp::verify() {
 // WhileOp
 //===----------------------------------------------------------------------===//
 
-OperandRange WhileOp::getSuccessorEntryOperands(unsigned index) {
-  assert(index == 0 &&
+OperandRange WhileOp::getSuccessorEntryOperands(Optional<unsigned> index) {
+  assert(index && *index == 0 &&
          "WhileOp is expected to branch only to the first region");
 
   return getInits();

diff  --git a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
index 9dadd269fa15..95d43601a6ea 100644
--- a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
+++ b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
@@ -312,8 +312,9 @@ void transform::SequenceOp::getEffects(
   }
 }
 
-OperandRange transform::SequenceOp::getSuccessorEntryOperands(unsigned index) {
-  assert(index == 0 && "unexpected region index");
+OperandRange
+transform::SequenceOp::getSuccessorEntryOperands(Optional<unsigned> index) {
+  assert(index && *index == 0 && "unexpected region index");
   if (getOperation()->getNumOperands() == 1)
     return getOperation()->getOperands();
   return OperandRange(getOperation()->operand_end(),

diff  --git a/mlir/lib/Interfaces/ControlFlowInterfaces.cpp b/mlir/lib/Interfaces/ControlFlowInterfaces.cpp
index 7336fd4a6bc6..991f16e23b7e 100644
--- a/mlir/lib/Interfaces/ControlFlowInterfaces.cpp
+++ b/mlir/lib/Interfaces/ControlFlowInterfaces.cpp
@@ -8,6 +8,7 @@
 
 #include <utility>
 
+#include "mlir/IR/BuiltinTypes.h"
 #include "mlir/Interfaces/ControlFlowInterfaces.h"
 #include "llvm/ADT/SmallPtrSet.h"
 
@@ -151,16 +152,7 @@ LogicalResult detail::verifyTypesAlongControlFlowEdges(Operation *op) {
   auto regionInterface = cast<RegionBranchOpInterface>(op);
 
   auto inputTypesFromParent = [&](Optional<unsigned> regionNo) -> TypeRange {
-    if (regionNo.hasValue()) {
-      return regionInterface.getSuccessorEntryOperands(regionNo.getValue())
-          .getTypes();
-    }
-
-    // If the successor of a parent op is the parent itself
-    // RegionBranchOpInterface does not have an API to query what the entry
-    // operands will be in that case. Vend out the result types of the op in
-    // that case so that type checking succeeds for this case.
-    return op->getResultTypes();
+    return regionInterface.getSuccessorEntryOperands(regionNo).getTypes();
   };
 
   // Verify types along control flow edges originating from the parent.

diff  --git a/mlir/test/Analysis/test-alias-analysis.mlir b/mlir/test/Analysis/test-alias-analysis.mlir
index 1b9816e4d671..0e19282dbc9e 100644
--- a/mlir/test/Analysis/test-alias-analysis.mlir
+++ b/mlir/test/Analysis/test-alias-analysis.mlir
@@ -191,6 +191,31 @@ func.func @region_loop_control_flow(%arg: memref<2xf32>, %loopI0 : index,
 
 // -----
 
+// CHECK-LABEL: Testing : "region_loop_zero_trip_count"
+// CHECK-DAG: alloca_1#0 <-> alloca_2#0: NoAlias
+// CHECK-DAG: alloca_1#0 <-> for_alloca#0: MustAlias
+// CHECK-DAG: alloca_1#0 <-> for_alloca.region0#0: MayAlias
+// CHECK-DAG: alloca_1#0 <-> for_alloca.region0#1: MayAlias
+
+// CHECK-DAG: alloca_2#0 <-> for_alloca#0: NoAlias
+// CHECK-DAG: alloca_2#0 <-> for_alloca.region0#0: MayAlias
+// CHECK-DAG: alloca_2#0 <-> for_alloca.region0#1: MayAlias
+
+// CHECK-DAG: for_alloca#0 <-> for_alloca.region0#0: MayAlias
+// CHECK-DAG: for_alloca#0 <-> for_alloca.region0#1: MayAlias
+
+// CHECK-DAG: for_alloca.region0#0 <-> for_alloca.region0#1: MayAlias
+func.func @region_loop_zero_trip_count() attributes {test.ptr = "func"} {
+  %0 = memref.alloca() {test.ptr = "alloca_1"} : memref<i32>
+  %1 = memref.alloca() {test.ptr = "alloca_2"} : memref<i32>
+  %result = affine.for %i = 0 to 0 iter_args(%si = %0) -> (memref<i32>) {
+    affine.yield %si : memref<i32>
+  } {test.ptr = "for_alloca"}
+  return
+}
+
+// -----
+
 // CHECK-LABEL: Testing : "view_like"
 // CHECK-DAG: alloc_1#0 <-> view#0: NoAlias
 

diff  --git a/mlir/test/Transforms/sccp-structured.mlir b/mlir/test/Transforms/sccp-structured.mlir
index d3249d9fe829..32af26d62301 100644
--- a/mlir/test/Transforms/sccp-structured.mlir
+++ b/mlir/test/Transforms/sccp-structured.mlir
@@ -154,7 +154,7 @@ func.func @loop_region_branch_terminator_op(%arg1 : i32) {
 /// interface as well.
 
 // CHECK-LABEL: func @affine_loop_one_iter(
-func.func @affine_loop_one_iter(%arg0 : index, %arg1 : index, %arg2 : index) -> i32 {
+func.func @affine_loop_one_iter() -> i32 {
   // CHECK: %[[C1:.*]] = arith.constant 1 : i32
   %s0 = arith.constant 0 : i32
   %s1 = arith.constant 1 : i32
@@ -167,17 +167,27 @@ func.func @affine_loop_one_iter(%arg0 : index, %arg1 : index, %arg2 : index) ->
 }
 
 // CHECK-LABEL: func @affine_loop_zero_iter(
-func.func @affine_loop_zero_iter(%arg0 : index, %arg1 : index, %arg2 : index) -> i32 {
-  // This exposes a crash in sccp/forward data flow analysis: https://github.com/llvm/llvm-project/issues/54928
+func.func @affine_loop_zero_iter() -> i32 {
+  // CHECK: %[[C1:.*]] = arith.constant 1 : i32
+  %s1 = arith.constant 1 : i32
+  %result = affine.for %i = 0 to 0 iter_args(%si = %s1) -> (i32) {
+   %sn = arith.addi %si, %si : i32
+   affine.yield %sn : i32
+  }
+  // CHECK: return %[[C1]] : i32
+  return %result : i32
+}
+
+// CHECK-LABEL: func @affine_loop_unknown_trip_count(
+func.func @affine_loop_unknown_trip_count(%ub: index) -> i32 {
   // CHECK: %[[C0:.*]] = arith.constant 0 : i32
   %s0 = arith.constant 0 : i32
-  // %result = affine.for %i = 0 to 0 iter_args(%si = %s0) -> (i32) {
-  //  %sn = arith.addi %si, %si : i32
-  //  affine.yield %sn : i32
-  // }
-  // return %result : i32
+  %result = affine.for %i = 0 to %ub iter_args(%si = %s0) -> (i32) {
+   %sn = arith.addi %si, %si : i32
+   affine.yield %sn : i32
+  }
   // CHECK: return %[[C0]] : i32
-  return %s0 : i32
+  return %result : i32
 }
 
 // CHECK-LABEL: func @while_loop_different_arg_count

diff  --git a/mlir/test/lib/Dialect/Test/TestDialect.cpp b/mlir/test/lib/Dialect/Test/TestDialect.cpp
index 415b545a6eae..ab0c543ed69d 100644
--- a/mlir/test/lib/Dialect/Test/TestDialect.cpp
+++ b/mlir/test/lib/Dialect/Test/TestDialect.cpp
@@ -1301,8 +1301,8 @@ ParseResult RegionIfOp::parse(OpAsmParser &parser, OperationState &result) {
                                 parser.getCurrentLocation(), result.operands);
 }
 
-OperandRange RegionIfOp::getSuccessorEntryOperands(unsigned index) {
-  assert(index < 2 && "invalid region index");
+OperandRange RegionIfOp::getSuccessorEntryOperands(Optional<unsigned> index) {
+  assert(index && *index < 2 && "invalid region index");
   return getOperands();
 }
 
@@ -1339,7 +1339,7 @@ void AnyCondOp::getSuccessorRegions(Optional<unsigned> index,
                                     SmallVectorImpl<RegionSuccessor> &regions) {
   // The parent op branches into the only region, and the region branches back
   // to the parent op.
-  if (index)
+  if (!index)
     regions.emplace_back(&getRegion());
   else
     regions.emplace_back(getResults());

diff  --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td
index 94556c4d59a7..9ce9b21f15fd 100644
--- a/mlir/test/lib/Dialect/Test/TestOps.td
+++ b/mlir/test/lib/Dialect/Test/TestOps.td
@@ -2549,7 +2549,8 @@ def RegionIfOp : TEST_Op<"region_if",
     ::mlir::Block::BlockArgListType getJoinArgs() {
       return getBody(2)->getArguments();
     }
-    ::mlir::OperandRange getSuccessorEntryOperands(unsigned index);
+    ::mlir::OperandRange getSuccessorEntryOperands(
+        ::llvm::Optional<unsigned> index);
   }];
   let hasCustomAssemblyFormat = 1;
 }


        


More information about the Mlir-commits mailing list