[Mlir-commits] [mlir] e7c7b16 - [mlir] Region/BranchOpInterface: Allow implicit type conversions along control-flow edges

llvmlistbot at llvm.org llvmlistbot at llvm.org
Fri Mar 4 12:33:18 PST 2022


Author: Mogball
Date: 2022-03-04T20:33:14Z
New Revision: e7c7b16a849fb40169a708d27cb59747139ed8c7

URL: https://github.com/llvm/llvm-project/commit/e7c7b16a849fb40169a708d27cb59747139ed8c7
DIFF: https://github.com/llvm/llvm-project/commit/e7c7b16a849fb40169a708d27cb59747139ed8c7.diff

LOG: [mlir] Region/BranchOpInterface: Allow implicit type conversions along control-flow edges

RegionBranchOpInterface and BranchOpInterface are allowed to make implicit type conversions along control-flow edges. In effect, this adds an interface method, `areTypesCompatible`, to both interfaces, which should return whether the types of corresponding successor operands and block arguments are compatible. Users of the interfaces must henceforth be aware that types may mismatch, although current users (in MLIR core) are not affected by this change. By default, type equality is used.

`async.execute` already has unequal types along control-flow edges (`!async.value<f32>` vs. `f32`), but it opted out of calling `RegionBranchOpInterface::verifyTypes` in its verifier. That method has now been removed, and `RegionBranchOpInterface` now verifies types along control-flow edges by default in its verifier.

Reviewed By: rriddle

Differential Revision: https://reviews.llvm.org/D120790

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Async/IR/AsyncOps.td
    mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
    mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
    mlir/include/mlir/Interfaces/ControlFlowInterfaces.td
    mlir/lib/Dialect/Async/IR/Async.cpp
    mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
    mlir/lib/Dialect/SCF/SCF.cpp
    mlir/lib/Dialect/Shape/IR/Shape.cpp
    mlir/lib/Interfaces/ControlFlowInterfaces.cpp
    mlir/test/Transforms/control-flow-sink.mlir
    mlir/unittests/Interfaces/ControlFlowInterfacesTest.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Async/IR/AsyncOps.td b/mlir/include/mlir/Dialect/Async/IR/AsyncOps.td
index d199bb2b42b09..f2138131e1f83 100644
--- a/mlir/include/mlir/Dialect/Async/IR/AsyncOps.td
+++ b/mlir/include/mlir/Dialect/Async/IR/AsyncOps.td
@@ -30,7 +30,8 @@ class Async_Op<string mnemonic, list<Trait> traits = []> :
 def Async_ExecuteOp :
   Async_Op<"execute", [SingleBlockImplicitTerminator<"YieldOp">,
                        DeclareOpInterfaceMethods<RegionBranchOpInterface,
-                                                 ["getSuccessorEntryOperands"]>,
+                                                 ["getSuccessorEntryOperands",
+                                                  "areTypesCompatible"]>,
                        AttrSizedOperandSegments,
                        AutomaticAllocationScope]> {
   let summary = "Asynchronous execute operation";

diff  --git a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
index 63fcebada1ee5..d94c01e40db3f 100644
--- a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
+++ b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
@@ -273,7 +273,6 @@ def MemRef_AllocaScopeOp : MemRef_Op<"alloca_scope",
   let results = (outs Variadic<AnyType>:$results);
   let regions = (region SizedRegion<1>:$bodyRegion);
   let hasCustomAssemblyFormat = 1;
-  let hasVerifier = 1;
   let hasCanonicalizer = 1;
 }
 

diff  --git a/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td b/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
index 31a69522c86c3..51534dd554157 100644
--- a/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
+++ b/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
@@ -865,7 +865,6 @@ def Shape_AssumingOp : Shape_Op<"assuming", [
 
   let hasCanonicalizer = 1;
   let hasCustomAssemblyFormat = 1;
-  let hasVerifier = 1;
 }
 
 def Shape_AssumingYieldOp : Shape_Op<"assuming_yield",

diff  --git a/mlir/include/mlir/Interfaces/ControlFlowInterfaces.td b/mlir/include/mlir/Interfaces/ControlFlowInterfaces.td
index 1fadf5bf57edc..25da76866ba5c 100644
--- a/mlir/include/mlir/Interfaces/ControlFlowInterfaces.td
+++ b/mlir/include/mlir/Interfaces/ControlFlowInterfaces.td
@@ -24,6 +24,13 @@ def BranchOpInterface : OpInterface<"BranchOpInterface"> {
   let description = [{
     This interface provides information for branching terminator operations,
     i.e. terminator operations with successors.
+
+    This interface is meant to model well-defined cases of control-flow of
+    value propagation, where what occurs along control-flow edges is assumed to
+    be side-effect free. For example, corresponding successor operands and
+    successor block arguments may have different types. In such cases,
+    `areTypesCompatible` can be implemented to compare types along control-flow
+    edges. By default, type equality is used.
   }];
   let cppNamespace = "::mlir";
 
@@ -73,7 +80,15 @@ def BranchOpInterface : OpInterface<"BranchOpInterface"> {
       "::mlir::Block *", "getSuccessorForOperands",
       (ins "::mlir::ArrayRef<::mlir::Attribute>":$operands), [{}],
       /*defaultImplementation=*/[{ return nullptr; }]
-    >
+    >,
+    InterfaceMethod<[{
+        This method is called to compare types along control-flow edges. By
+        default, the types are checked as equal.
+      }],
+      "bool", "areTypesCompatible",
+      (ins "::mlir::Type":$lhs, "::mlir::Type":$rhs), [{}],
+       [{ return lhs == rhs; }]
+    >,
   ];
 
   let verify = [{
@@ -96,6 +111,13 @@ def RegionBranchOpInterface : OpInterface<"RegionBranchOpInterface"> {
     This interface provides information for region operations that contain
     branching behavior between held regions, i.e. this interface allows for
     expressing control flow information for region holding operations.
+
+    This interface is meant to model well-defined cases of control-flow of
+    value propagation, where what occurs along control-flow edges is assumed to
+    be side-effect free. For example, corresponding successor operands and
+    successor block arguments may have different types. In such cases,
+    `areTypesCompatible` can be implemented to compare types along control-flow
+    edges. By default, type equality is used.
   }];
   let cppNamespace = "::mlir";
 
@@ -156,12 +178,20 @@ def RegionBranchOpInterface : OpInterface<"RegionBranchOpInterface"> {
        [{ invocationBounds.append($_op->getNumRegions(),
                                   ::mlir::InvocationBounds::getUnknown()); }]
     >,
+    InterfaceMethod<[{
+        This method is called to compare types along control-flow edges. By
+        default, the types are checked as equal.
+      }],
+      "bool", "areTypesCompatible",
+      (ins "::mlir::Type":$lhs, "::mlir::Type":$rhs), [{}],
+       [{ return lhs == rhs; }]
+    >,
   ];
 
   let verify = [{
     static_assert(!ConcreteOp::template hasTrait<OpTrait::ZeroRegion>(),
                   "expected operation to have non-zero regions");
-    return success();
+    return detail::verifyTypesAlongControlFlowEdges($_op);
   }];
 
   let extraClassDeclaration = [{
@@ -171,11 +201,6 @@ def RegionBranchOpInterface : OpInterface<"RegionBranchOpInterface"> {
        SmallVector<Attribute, 2> nullAttrs(getOperation()->getNumOperands());
        getSuccessorRegions(index, nullAttrs, regions);
     }
-
-    /// Verify types along control flow edges described by this interface.
-    static LogicalResult verifyTypes(Operation *op) {
-      return detail::verifyTypesAlongControlFlowEdges(op);
-    }
   }];
 }
 

diff  --git a/mlir/lib/Dialect/Async/IR/Async.cpp b/mlir/lib/Dialect/Async/IR/Async.cpp
index 916e2d1af451b..541c69ec58bc0 100644
--- a/mlir/lib/Dialect/Async/IR/Async.cpp
+++ b/mlir/lib/Dialect/Async/IR/Async.cpp
@@ -50,7 +50,6 @@ LogicalResult YieldOp::verify() {
 
 MutableOperandRange
 YieldOp::getMutableSuccessorOperands(Optional<unsigned> index) {
-  assert(!index.hasValue());
   return operandsMutable();
 }
 
@@ -65,6 +64,15 @@ OperandRange ExecuteOp::getSuccessorEntryOperands(unsigned index) {
   return operands();
 }
 
+bool ExecuteOp::areTypesCompatible(Type lhs, Type rhs) {
+  const auto getValueOrTokenType = [](Type type) {
+    if (auto value = type.dyn_cast<ValueType>())
+      return value.getValueType();
+    return type;
+  };
+  return getValueOrTokenType(lhs) == getValueOrTokenType(rhs);
+}
+
 void ExecuteOp::getSuccessorRegions(Optional<unsigned> index,
                                     ArrayRef<Attribute>,
                                     SmallVectorImpl<RegionSuccessor> &regions) {

diff  --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
index bf8edc0043488..28f2052d66fd6 100644
--- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
@@ -244,10 +244,6 @@ ParseResult AllocaScopeOp::parse(OpAsmParser &parser, OperationState &result) {
   return success();
 }
 
-LogicalResult AllocaScopeOp::verify() {
-  return RegionBranchOpInterface::verifyTypes(*this);
-}
-
 void AllocaScopeOp::getSuccessorRegions(
     Optional<unsigned> index, ArrayRef<Attribute> operands,
     SmallVectorImpl<RegionSuccessor> &regions) {

diff  --git a/mlir/lib/Dialect/SCF/SCF.cpp b/mlir/lib/Dialect/SCF/SCF.cpp
index 23ef80dfb02d8..588c675f50423 100644
--- a/mlir/lib/Dialect/SCF/SCF.cpp
+++ b/mlir/lib/Dialect/SCF/SCF.cpp
@@ -315,8 +315,7 @@ LogicalResult ForOp::verify() {
 
     i++;
   }
-
-  return RegionBranchOpInterface::verifyTypes(*this);
+  return success();
 }
 
 Optional<Value> ForOp::getSingleInductionVar() { return getInductionVar(); }
@@ -1075,8 +1074,7 @@ void IfOp::build(OpBuilder &builder, OperationState &result, Value cond,
 LogicalResult IfOp::verify() {
   if (getNumResults() != 0 && getElseRegion().empty())
     return emitOpError("must have an else block if defining values");
-
-  return RegionBranchOpInterface::verifyTypes(*this);
+  return success();
 }
 
 ParseResult IfOp::parse(OpAsmParser &parser, OperationState &result) {
@@ -2335,9 +2333,6 @@ static TerminatorTy verifyAndGetTerminator(scf::WhileOp op, Region &region,
 }
 
 LogicalResult scf::WhileOp::verify() {
-  if (failed(RegionBranchOpInterface::verifyTypes(*this)))
-    return failure();
-
   auto beforeTerminator = verifyAndGetTerminator<scf::ConditionOp>(
       *this, getBefore(),
       "expects the 'before' region to terminate with 'scf.condition'");

diff  --git a/mlir/lib/Dialect/Shape/IR/Shape.cpp b/mlir/lib/Dialect/Shape/IR/Shape.cpp
index 0f633312eaddc..9396cab054701 100644
--- a/mlir/lib/Dialect/Shape/IR/Shape.cpp
+++ b/mlir/lib/Dialect/Shape/IR/Shape.cpp
@@ -415,10 +415,6 @@ void AssumingOp::build(
   result.addTypes(assumingTypes);
 }
 
-LogicalResult AssumingOp::verify() {
-  return RegionBranchOpInterface::verifyTypes(*this);
-}
-
 //===----------------------------------------------------------------------===//
 // AddOp
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/lib/Interfaces/ControlFlowInterfaces.cpp b/mlir/lib/Interfaces/ControlFlowInterfaces.cpp
index c8319bc1db960..02845c011472a 100644
--- a/mlir/lib/Interfaces/ControlFlowInterfaces.cpp
+++ b/mlir/lib/Interfaces/ControlFlowInterfaces.cpp
@@ -62,7 +62,8 @@ detail::verifyBranchSuccessorOperands(Operation *op, unsigned succNo,
   // Check the types.
   auto operandIt = operands->begin();
   for (unsigned i = 0; i != operandCount; ++i, ++operandIt) {
-    if ((*operandIt).getType() != destBB->getArgument(i).getType())
+    if (!cast<BranchOpInterface>(op).areTypesCompatible(
+            (*operandIt).getType(), destBB->getArgument(i).getType()))
       return op->emitError() << "type mismatch for bb argument #" << i
                              << " of successor #" << succNo;
   }
@@ -132,7 +133,7 @@ verifyTypesAlongAllEdges(Operation *op, Optional<unsigned> sourceNo,
          llvm::enumerate(llvm::zip(*sourceTypes, succInputsTypes))) {
       Type sourceType = std::get<0>(typesIdx.value());
       Type inputType = std::get<1>(typesIdx.value());
-      if (sourceType != inputType) {
+      if (!regionInterface.areTypesCompatible(sourceType, inputType)) {
         InFlightDiagnostic diag = op->emitOpError(" along control flow edge ");
         return printEdgeName(diag)
                << ": source type #" << typesIdx.index() << " " << sourceType
@@ -169,6 +170,18 @@ LogicalResult detail::verifyTypesAlongControlFlowEdges(Operation *op) {
   // attached regions.
   assert(op->getNumRegions() != 0);
 
+  auto areTypesCompatible = [&](TypeRange lhs, TypeRange rhs) {
+    if (lhs.size() != rhs.size())
+      return false;
+    for (auto types : llvm::zip(lhs, rhs)) {
+      if (!regionInterface.areTypesCompatible(std::get<0>(types),
+                                              std::get<1>(types))) {
+        return false;
+      }
+    }
+    return true;
+  };
+
   // Verify types along control flow edges originating from each region.
   for (unsigned regionNo : llvm::seq(0U, op->getNumRegions())) {
     Region &region = op->getRegion(regionNo);
@@ -192,7 +205,8 @@ LogicalResult detail::verifyTypesAlongControlFlowEdges(Operation *op) {
 
       // Found more than one ReturnLike terminator. Make sure the operand types
       // match with the first one.
-      if (regionReturnOperands->getTypes() != terminatorOperands->getTypes())
+      if (!areTypesCompatible(regionReturnOperands->getTypes(),
+                              terminatorOperands->getTypes()))
         return op->emitOpError("Region #")
                << regionNo
                << " operands mismatch between return-like terminators";

diff  --git a/mlir/test/Transforms/control-flow-sink.mlir b/mlir/test/Transforms/control-flow-sink.mlir
index 1327e01ab1b19..09ebd5b981f72 100644
--- a/mlir/test/Transforms/control-flow-sink.mlir
+++ b/mlir/test/Transforms/control-flow-sink.mlir
@@ -3,32 +3,38 @@
 // Test that operations can be sunk.
 
 // CHECK-LABEL: @test_simple_sink
-// CHECK-SAME:  (%[[ARG0:.*]]: i1, %[[ARG1:.*]]: i32, %[[ARG2:.*]]: i32)
+// CHECK-SAME:  (%[[ARG0:.*]]: i32, %[[ARG1:.*]]: i32, %[[ARG2:.*]]: i32)
 // CHECK-NEXT: %[[V0:.*]] = arith.subi %[[ARG2]], %[[ARG1]]
-// CHECK-NEXT: %[[V1:.*]] = test.region_if %[[ARG0]]: i1 -> i32 then {
+// CHECK-NEXT: %[[V1:.*]] = test.region_if %[[ARG0]]: i32 -> i32 then {
+// CHECK-NEXT: ^bb0(%{{.*}}: i32):
 // CHECK-NEXT:   %[[V2:.*]] = arith.subi %[[ARG1]], %[[ARG2]]
 // CHECK-NEXT:   test.region_if_yield %[[V2]]
 // CHECK-NEXT: } else {
+// CHECK-NEXT: ^bb0(%{{.*}}: i32):
 // CHECK-NEXT:   %[[V2:.*]] = arith.addi %[[ARG1]], %[[ARG1]]
 // CHECK-NEXT:   %[[V3:.*]] = arith.addi %[[V0]], %[[V2]]
 // CHECK-NEXT:   test.region_if_yield %[[V3]]
 // CHECK-NEXT: } join {
+// CHECK-NEXT: ^bb0(%{{.*}}: i32):
 // CHECK-NEXT:   %[[V2:.*]] = arith.addi %[[ARG2]], %[[ARG2]]
 // CHECK-NEXT:   %[[V3:.*]] = arith.addi %[[V2]], %[[V0]]
 // CHECK-NEXT:   test.region_if_yield %[[V3]]
 // CHECK-NEXT: }
 // CHECK-NEXT: return %[[V1]]
-func @test_simple_sink(%arg0: i1, %arg1: i32, %arg2: i32) -> i32 {
+func @test_simple_sink(%arg0: i32, %arg1: i32, %arg2: i32) -> i32 {
   %0 = arith.subi %arg1, %arg2 : i32
   %1 = arith.subi %arg2, %arg1 : i32
   %2 = arith.addi %arg1, %arg1 : i32
   %3 = arith.addi %arg2, %arg2 : i32
-  %4 = test.region_if %arg0: i1 -> i32 then {
+  %4 = test.region_if %arg0: i32 -> i32 then {
+  ^bb0(%arg3: i32):
     test.region_if_yield %0 : i32
   } else {
+  ^bb0(%arg3: i32):
     %5 = arith.addi %1, %2 : i32
     test.region_if_yield %5 : i32
   } join {
+  ^bb0(%arg3: i32):
     %5 = arith.addi %3, %1 : i32
     test.region_if_yield %5 : i32
   }
@@ -38,37 +44,49 @@ func @test_simple_sink(%arg0: i1, %arg1: i32, %arg2: i32) -> i32 {
 // Test that a region op can be sunk.
 
 // CHECK-LABEL: @test_region_sink
-// CHECK-SAME:  (%[[ARG0:.*]]: i1, %[[ARG1:.*]]: i32, %[[ARG2:.*]]: i32)
-// CHECK-NEXT: %[[V0:.*]] = test.region_if %[[ARG0]]: i1 -> i32 then {
-// CHECK-NEXT:   %[[V1:.*]] = test.region_if %[[ARG0]]: i1 -> i32 then {
+// CHECK-SAME:  (%[[ARG0:.*]]: i32, %[[ARG1:.*]]: i32, %[[ARG2:.*]]: i32)
+// CHECK-NEXT: %[[V0:.*]] = test.region_if %[[ARG0]]: i32 -> i32 then {
+// CHECK-NEXT: ^bb0(%{{.*}}: i32):
+// CHECK-NEXT:   %[[V1:.*]] = test.region_if %[[ARG0]]: i32 -> i32 then {
+// CHECK-NEXT:   ^bb0(%{{.*}}: i32):
 // CHECK-NEXT:     test.region_if_yield %[[ARG1]]
 // CHECK-NEXT:   } else {
+// CHECK-NEXT:   ^bb0(%{{.*}}: i32):
 // CHECK-NEXT:     %[[V2:.*]] = arith.subi %[[ARG1]], %[[ARG2]]
 // CHECK-NEXT:     test.region_if_yield %[[V2]]
 // CHECK-NEXT:   } join {
+// CHECK-NEXT:   ^bb0(%{{.*}}: i32):
 // CHECK-NEXT:     test.region_if_yield %[[ARG2]]
 // CHECK-NEXT:   }
 // CHECK-NEXT:   test.region_if_yield %[[V1]]
 // CHECK-NEXT: } else {
+// CHECK-NEXT: ^bb0(%{{.*}}: i32):
 // CHECK-NEXT:   test.region_if_yield %[[ARG1]]
 // CHECK-NEXT: } join {
+// CHECK-NEXT: ^bb0(%{{.*}}: i32):
 // CHECK-NEXT:   test.region_if_yield %[[ARG2]]
 // CHECK-NEXT: }
 // CHECK-NEXT: return %[[V0]]
-func @test_region_sink(%arg0: i1, %arg1: i32, %arg2: i32) -> i32 {
+func @test_region_sink(%arg0: i32, %arg1: i32, %arg2: i32) -> i32 {
   %0 = arith.subi %arg1, %arg2 : i32
-  %1 = test.region_if %arg0: i1 -> i32 then {
+  %1 = test.region_if %arg0: i32 -> i32 then {
+  ^bb0(%arg3: i32):
     test.region_if_yield %arg1 : i32
   } else {
+  ^bb0(%arg3: i32):
     test.region_if_yield %0 : i32
   } join {
+  ^bb0(%arg3: i32):
     test.region_if_yield %arg2 : i32
   }
-  %2 = test.region_if %arg0: i1 -> i32 then {
+  %2 = test.region_if %arg0: i32 -> i32 then {
+  ^bb0(%arg3: i32):
     test.region_if_yield %1 : i32
   } else {
+  ^bb0(%arg3: i32):
     test.region_if_yield %arg1 : i32
   } join {
+  ^bb0(%arg3: i32):
     test.region_if_yield %arg2 : i32
   }
   return %2 : i32
@@ -77,8 +95,9 @@ func @test_region_sink(%arg0: i1, %arg1: i32, %arg2: i32) -> i32 {
 // Test that an entire subgraph can be sunk.
 
 // CHECK-LABEL: @test_subgraph_sink
-// CHECK-SAME:  (%[[ARG0:.*]]: i1, %[[ARG1:.*]]: i32, %[[ARG2:.*]]: i32)
-// CHECK-NEXT: %[[V0:.*]] = test.region_if %[[ARG0]]: i1 -> i32 then {
+// CHECK-SAME:  (%[[ARG0:.*]]: i32, %[[ARG1:.*]]: i32, %[[ARG2:.*]]: i32)
+// CHECK-NEXT: %[[V0:.*]] = test.region_if %[[ARG0]]: i32 -> i32 then {
+// CHECK-NEXT: ^bb0(%{{.*}}: i32):
 // CHECK-NEXT:   %[[V1:.*]] = arith.subi %[[ARG1]], %[[ARG2]]
 // CHECK-NEXT:   %[[V2:.*]] = arith.addi %[[ARG1]], %[[ARG2]]
 // CHECK-NEXT:   %[[V3:.*]] = arith.subi %[[ARG2]], %[[ARG1]]
@@ -87,23 +106,28 @@ func @test_region_sink(%arg0: i1, %arg1: i32, %arg2: i32) -> i32 {
 // CHECK-NEXT:   %[[V6:.*]] = arith.addi %[[V5]], %[[V4]]
 // CHECK-NEXT:   test.region_if_yield %[[V6]]
 // CHECK-NEXT: } else {
+// CHECK-NEXT: ^bb0(%{{.*}}: i32):
 // CHECK-NEXT:   test.region_if_yield %[[ARG1]]
 // CHECK-NEXT: } join {
+// CHECK-NEXT: ^bb0(%{{.*}}: i32):
 // CHECK-NEXT:   test.region_if_yield %[[ARG2]]
 // CHECK-NEXT: }
 // CHECK-NEXT: return %[[V0]]
-func @test_subgraph_sink(%arg0: i1, %arg1: i32, %arg2: i32) -> i32 {
+func @test_subgraph_sink(%arg0: i32, %arg1: i32, %arg2: i32) -> i32 {
   %0 = arith.addi %arg1, %arg2 : i32
   %1 = arith.subi %arg1, %arg2 : i32
   %2 = arith.subi %arg2, %arg1 : i32
   %3 = arith.muli %0, %1 : i32
   %4 = arith.muli %2, %2 : i32
   %5 = arith.addi %3, %4 : i32
-  %6 = test.region_if %arg0: i1 -> i32 then {
+  %6 = test.region_if %arg0: i32 -> i32 then {
+  ^bb0(%arg3: i32):
     test.region_if_yield %5 : i32
   } else {
+  ^bb0(%arg3: i32):
     test.region_if_yield %arg1 : i32
   } join {
+  ^bb0(%arg3: i32):
     test.region_if_yield %arg2 : i32
   }
   return %6 : i32
@@ -112,7 +136,7 @@ func @test_subgraph_sink(%arg0: i1, %arg1: i32, %arg2: i32) -> i32 {
 // Test that ops can be sunk into regions with multiple blocks.
 
 // CHECK-LABEL: @test_multiblock_region_sink
-// CHECK-SAME:  (%[[ARG0:.*]]: i1, %[[ARG1:.*]]: i32, %[[ARG2:.*]]: i32)
+// CHECK-SAME:  (%[[ARG0:.*]]: i32, %[[ARG1:.*]]: i32, %[[ARG2:.*]]: i32)
 // CHECK-NEXT: %[[V0:.*]] = arith.addi %[[ARG1]], %[[ARG2]]
 // CHECK-NEXT: %[[V1:.*]] = "test.any_cond"() ({
 // CHECK-NEXT:   %[[V3:.*]] = arith.addi %[[V0]], %[[ARG2]]
@@ -124,7 +148,7 @@ func @test_subgraph_sink(%arg0: i1, %arg1: i32, %arg2: i32) -> i32 {
 // CHECK-NEXT: })
 // CHECK-NEXT: %[[V2:.*]] = arith.addi %[[V0]], %[[V1]]
 // CHECK-NEXT: return %[[V2]]
-func @test_multiblock_region_sink(%arg0: i1, %arg1: i32, %arg2: i32) -> i32 {
+func @test_multiblock_region_sink(%arg0: i32, %arg1: i32, %arg2: i32) -> i32 {
   %0 = arith.addi %arg1, %arg2 : i32
   %1 = arith.addi %0, %arg2 : i32
   %2 = arith.addi %1, %arg1 : i32
@@ -141,37 +165,49 @@ func @test_multiblock_region_sink(%arg0: i1, %arg1: i32, %arg2: i32) -> i32 {
 // Test that ops can be sunk recursively into nested regions.
 
 // CHECK-LABEL: @test_nested_region_sink
-// CHECK-SAME:  (%[[ARG0:.*]]: i1, %[[ARG1:.*]]: i32) -> i32 {
-// CHECK-NEXT: %[[V0:.*]] = test.region_if %[[ARG0]]: i1 -> i32 then {
-// CHECK-NEXT:   %[[V1:.*]] = test.region_if %[[ARG0]]: i1 -> i32 then {
+// CHECK-SAME:  (%[[ARG0:.*]]: i32, %[[ARG1:.*]]: i32) -> i32 {
+// CHECK-NEXT: %[[V0:.*]] = test.region_if %[[ARG0]]: i32 -> i32 then {
+// CHECK-NEXT: ^bb0(%{{.*}}: i32):
+// CHECK-NEXT:   %[[V1:.*]] = test.region_if %[[ARG0]]: i32 -> i32 then {
+// CHECK-NEXT:   ^bb0(%{{.*}}: i32):
 // CHECK-NEXT:     %[[V2:.*]] = arith.addi %[[ARG1]], %[[ARG1]]
 // CHECK-NEXT:     test.region_if_yield %[[V2]]
 // CHECK-NEXT:   } else {
+// CHECK-NEXT:   ^bb0(%{{.*}}: i32):
 // CHECK-NEXT:     test.region_if_yield %[[ARG1]]
 // CHECK-NEXT:   } join {
+// CHECK-NEXT:   ^bb0(%{{.*}}: i32):
 // CHECK-NEXT:     test.region_if_yield %[[ARG1]]
 // CHECK-NEXT:   }
 // CHECK-NEXT:   test.region_if_yield %[[V1]]
 // CHECK-NEXT: } else {
+// CHECK-NEXT: ^bb0(%{{.*}}: i32):
 // CHECK-NEXT:   test.region_if_yield %[[ARG1]]
 // CHECK-NEXT: } join {
+// CHECK-NEXT: ^bb0(%{{.*}}: i32):
 // CHECK-NEXT:   test.region_if_yield %[[ARG1]]
 // CHECK-NEXT: }
 // CHECK-NEXT: return %[[V0]]
-func @test_nested_region_sink(%arg0: i1, %arg1: i32) -> i32 {
+func @test_nested_region_sink(%arg0: i32, %arg1: i32) -> i32 {
   %0 = arith.addi %arg1, %arg1 : i32
-  %1 = test.region_if %arg0: i1 -> i32 then {
-    %2 = test.region_if %arg0: i1 -> i32 then {
+  %1 = test.region_if %arg0: i32 -> i32 then {
+  ^bb0(%arg3: i32):
+    %2 = test.region_if %arg0: i32 -> i32 then {
+    ^bb0(%arg4: i32):
       test.region_if_yield %0 : i32
     } else {
+    ^bb0(%arg4: i32):
       test.region_if_yield %arg1 : i32
     } join {
+    ^bb0(%arg4: i32):
       test.region_if_yield %arg1 : i32
     }
     test.region_if_yield %2 : i32
   } else {
+  ^bb0(%arg3: i32):
     test.region_if_yield %arg1 : i32
   } join {
+  ^bb0(%arg3: i32):
     test.region_if_yield %arg1 : i32
   }
   return %1 : i32

diff  --git a/mlir/unittests/Interfaces/ControlFlowInterfacesTest.cpp b/mlir/unittests/Interfaces/ControlFlowInterfacesTest.cpp
index b7f6d0ab925e8..1dfcc6d239a9e 100644
--- a/mlir/unittests/Interfaces/ControlFlowInterfacesTest.cpp
+++ b/mlir/unittests/Interfaces/ControlFlowInterfacesTest.cpp
@@ -54,8 +54,7 @@ struct SequentialRegionsOp
   void getSuccessorRegions(Optional<unsigned> index,
                            ArrayRef<Attribute> operands,
                            SmallVectorImpl<RegionSuccessor> &regions) {
-    assert(index.hasValue() && "expected index");
-    if (*index == 0) {
+    if (index == 0u) {
       Operation *thisOp = this->getOperation();
       regions.push_back(RegionSuccessor(&thisOp->getRegion(1)));
     }


        


More information about the Mlir-commits mailing list