[Mlir-commits] [mlir] a3ad8f9 - [MLIR] Add type checking capability to RegionBranchOpInterface

Rahul Joshi llvmlistbot at llvm.org
Wed Jul 15 11:14:30 PDT 2020


Author: Rahul Joshi
Date: 2020-07-15T11:14:07-07:00
New Revision: a3ad8f92b44d79487a34b1151251b413ef769070

URL: https://github.com/llvm/llvm-project/commit/a3ad8f92b44d79487a34b1151251b413ef769070
DIFF: https://github.com/llvm/llvm-project/commit/a3ad8f92b44d79487a34b1151251b413ef769070.diff

LOG: [MLIR] Add type checking capability to RegionBranchOpInterface

- Add function `verifyTypes` that ops can call to do type checking verification
  along the control flow edges described by the op's RegionBranchOpInterface.
- We cannot rely on the verify methods on the OpInterface because the interface
  functions assume valid ops, so they may crash if invoked on unverified ops.
  (For example, scf.for getSuccessorRegions() calls getRegionIterArgs(), which
  dereferences the getBody() block. If the scf.for is invalid with no body, this
  can lead to a segfault.) `verifyTypes` can be called after op verification to
  avoid this.

Differential Revision: https://reviews.llvm.org/D82829

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/SCF/SCFOps.td
    mlir/include/mlir/Interfaces/ControlFlowInterfaces.h
    mlir/include/mlir/Interfaces/ControlFlowInterfaces.td
    mlir/lib/Dialect/SCF/SCF.cpp
    mlir/lib/Interfaces/ControlFlowInterfaces.cpp
    mlir/test/Dialect/SCF/invalid.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/SCF/SCFOps.td b/mlir/include/mlir/Dialect/SCF/SCFOps.td
index c859cb794e85..78aefec00bf7 100644
--- a/mlir/include/mlir/Dialect/SCF/SCFOps.td
+++ b/mlir/include/mlir/Dialect/SCF/SCFOps.td
@@ -418,7 +418,8 @@ def ReduceReturnOp :
   let assemblyFormat = "$result attr-dict `:` type($result)";
 }
 
-def YieldOp : SCF_Op<"yield", [NoSideEffect, ReturnLike, Terminator]> {
+def YieldOp : SCF_Op<"yield", [NoSideEffect, ReturnLike, Terminator,
+                               ParentOneOf<["IfOp", "ForOp", "ParallelOp"]>]> {
   let summary = "loop yield and termination operation";
   let description = [{
     "scf.yield" yields an SSA value from the SCF dialect op region and
@@ -437,5 +438,8 @@ def YieldOp : SCF_Op<"yield", [NoSideEffect, ReturnLike, Terminator]> {
     OpBuilder<"OpBuilder &builder, OperationState &result",
               [{ /* nothing to do */ }]>
   ];
+  // Override the default verifier (defined in SCF_Op): the parent ops check
+  // the yielded operands, so no custom verification is needed here.
+  let verifier = ?;
 }
 #endif // MLIR_DIALECT_SCF_SCFOPS

diff  --git a/mlir/include/mlir/Interfaces/ControlFlowInterfaces.h b/mlir/include/mlir/Interfaces/ControlFlowInterfaces.h
index 7e609ca13a09..725e13b8b9d2 100644
--- a/mlir/include/mlir/Interfaces/ControlFlowInterfaces.h
+++ b/mlir/include/mlir/Interfaces/ControlFlowInterfaces.h
@@ -18,6 +18,7 @@
 
 namespace mlir {
 class BranchOpInterface;
+class RegionBranchOpInterface;
 
 //===----------------------------------------------------------------------===//
 // BranchOpInterface
@@ -40,12 +41,21 @@ LogicalResult verifyBranchSuccessorOperands(Operation *op, unsigned succNo,
 // RegionBranchOpInterface
 //===----------------------------------------------------------------------===//
 
+namespace detail {
+/// Verify that types match along the control flow edges described by `op`.
+LogicalResult verifyTypesAlongControlFlowEdges(Operation *op);
+} // namespace detail
+
 /// This class represents a successor of a region. A region successor can either
 /// be another region, or the parent operation. If the successor is a region,
-/// this class accepts the destination region, as well as a set of arguments
+/// this class represents the destination region, as well as a set of arguments
 /// from that region that will be populated by values from the current region.
-/// If the successor is the parent operation, this class accepts an optional set
-/// of results that will be populated by values from the current region.
+/// If the successor is the parent operation, this class represents an optional
+/// set of results that will be populated by values from the current region.
+///
+/// RegionBranchOpInterface assumes that the values from the current region
+/// that are used to populate the successor inputs are the operands of the
+/// return-like terminator operations in the blocks within this region.
 class RegionSuccessor {
 public:
   /// Initialize a successor that branches to another region of the parent
@@ -61,6 +71,9 @@ class RegionSuccessor {
   /// parent operation.
   Region *getSuccessor() const { return region; }
 
+  /// Returns true when this successor is the parent op (no region is set).
+  bool isParent() const { return !region; }
+
   /// Return the inputs to the successor that are remapped by the exit values of
   /// the current region.
   ValueRange getSuccessorInputs() const { return inputs; }

diff  --git a/mlir/include/mlir/Interfaces/ControlFlowInterfaces.td b/mlir/include/mlir/Interfaces/ControlFlowInterfaces.td
index 8b5a0b769ab1..2cd2b9cb8b43 100644
--- a/mlir/include/mlir/Interfaces/ControlFlowInterfaces.td
+++ b/mlir/include/mlir/Interfaces/ControlFlowInterfaces.td
@@ -103,9 +103,9 @@ def RegionBranchOpInterface : OpInterface<"RegionBranchOpInterface"> {
   let methods = [
     InterfaceMethod<[{
         Returns the operands of this operation used as the entry arguments when
-        entering the region at `index`, which was specified as a successor by
-        `getSuccessorRegions`. These operands should correspond 1-1 with the
-        successor inputs specified in `getSuccessorRegions`, and may corre
+        entering the region at `index`, which was specified as a successor of this
+        operation by `getSuccessorRegions`. These operands should correspond 1-1
+        with the successor inputs specified in `getSuccessorRegions`.
       }],
       "OperandRange", "getSuccessorEntryOperands",
       (ins "unsigned":$index), [{}], /*defaultImplementation=*/[{
@@ -132,6 +132,19 @@ def RegionBranchOpInterface : OpInterface<"RegionBranchOpInterface"> {
            "SmallVectorImpl<RegionSuccessor> &":$regions)
     >
   ];
+
+  let verify = [{
+    static_assert(!ConcreteOpType::template hasTrait<OpTrait::ZeroRegion>(),
+                  "expected operation to have non-zero regions");
+    return success();
+  }];
+
+  let extraClassDeclaration = [{
+    /// Verify types along control flow edges described by this interface.
+    static LogicalResult verifyTypes(Operation *op) {
+      return detail::verifyTypesAlongControlFlowEdges(op);
+    }
+  }];
 }
 
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/lib/Dialect/SCF/SCF.cpp b/mlir/lib/Dialect/SCF/SCF.cpp
index 67a3ae34c1d9..d0958e54269f 100644
--- a/mlir/lib/Dialect/SCF/SCF.cpp
+++ b/mlir/lib/Dialect/SCF/SCF.cpp
@@ -137,7 +137,8 @@ static LogicalResult verify(ForOp op) {
 
     i++;
   }
-  return success();
+
+  return RegionBranchOpInterface::verifyTypes(op);
 }
 
 static void print(OpAsmPrinter &p, ForOp op) {
@@ -413,7 +414,7 @@ static LogicalResult verify(IfOp op) {
   if (op.getNumResults() != 0 && op.elseRegion().empty())
     return op.emitOpError("must have an else block if defining values");
 
-  return success();
+  return RegionBranchOpInterface::verifyTypes(op);
 }
 
 static ParseResult parseIfOp(OpAsmParser &parser, OperationState &result) {
@@ -592,6 +593,12 @@ static LogicalResult verify(ParallelOp op) {
       return op.emitOpError(
           "expects arguments for the induction variable to be of index type");
 
+  // Check that the yield has no operands.
+  Operation *yield = body->getTerminator();
+  if (yield->getNumOperands() != 0)
+    return yield->emitOpError() << "not allowed to have operands inside '"
+                                << ParallelOp::getOperationName() << "'";
+
   // Check that the number of results is the same as the number of ReduceOps.
   SmallVector<ReduceOp, 4> reductions(body->getOps<ReduceOp>());
   auto resultsSize = op.results().size();
@@ -869,31 +876,6 @@ static LogicalResult verify(ReduceReturnOp op) {
 //===----------------------------------------------------------------------===//
 // YieldOp
 //===----------------------------------------------------------------------===//
-static LogicalResult verify(YieldOp op) {
-  auto parentOp = op.getParentOp();
-  auto results = parentOp->getResults();
-  auto operands = op.getOperands();
-
-  if (isa<IfOp, ForOp>(parentOp)) {
-    if (parentOp->getNumResults() != op.getNumOperands())
-      return op.emitOpError() << "parent of yield must have same number of "
-                                 "results as the yield operands";
-    for (auto e : llvm::zip(results, operands)) {
-      if (std::get<0>(e).getType() != std::get<1>(e).getType())
-        return op.emitOpError()
-               << "types mismatch between yield op and its parent";
-    }
-  } else if (isa<ParallelOp>(parentOp)) {
-    if (op.getNumOperands() != 0)
-      return op.emitOpError()
-             << "yield inside scf.parallel is not allowed to have operands";
-  } else {
-    return op.emitOpError()
-           << "yield only terminates If, For or Parallel regions";
-  }
-
-  return success();
-}
 
 static ParseResult parseYieldOp(OpAsmParser &parser, OperationState &result) {
   SmallVector<OpAsmParser::OperandType, 4> operands;

diff  --git a/mlir/lib/Interfaces/ControlFlowInterfaces.cpp b/mlir/lib/Interfaces/ControlFlowInterfaces.cpp
index c1fa833f26da..fc79c820165d 100644
--- a/mlir/lib/Interfaces/ControlFlowInterfaces.cpp
+++ b/mlir/lib/Interfaces/ControlFlowInterfaces.cpp
@@ -8,6 +8,7 @@
 
 #include "mlir/Interfaces/ControlFlowInterfaces.h"
 #include "mlir/IR/StandardTypes.h"
+#include "llvm/ADT/SmallPtrSet.h"
 
 using namespace mlir;
 
@@ -24,8 +25,9 @@ using namespace mlir;
 /// Returns the `BlockArgument` corresponding to operand `operandIndex` in some
 /// successor if 'operandIndex' is within the range of 'operands', or None if
 /// `operandIndex` isn't a successor operand index.
-Optional<BlockArgument> mlir::detail::getBranchSuccessorArgument(
-    Optional<OperandRange> operands, unsigned operandIndex, Block *successor) {
+Optional<BlockArgument>
+detail::getBranchSuccessorArgument(Optional<OperandRange> operands,
+                                   unsigned operandIndex, Block *successor) {
   // Check that the operands are valid.
   if (!operands || operands->empty())
     return llvm::None;
@@ -43,8 +45,8 @@ Optional<BlockArgument> mlir::detail::getBranchSuccessorArgument(
 
 /// Verify that the given operands match those of the given successor block.
 LogicalResult
-mlir::detail::verifyBranchSuccessorOperands(Operation *op, unsigned succNo,
-                                            Optional<OperandRange> operands) {
+detail::verifyBranchSuccessorOperands(Operation *op, unsigned succNo,
+                                      Optional<OperandRange> operands) {
   if (!operands)
     return success();
 
@@ -66,3 +68,139 @@ mlir::detail::verifyBranchSuccessorOperands(Operation *op, unsigned succNo,
   }
   return success();
 }
+
+//===----------------------------------------------------------------------===//
+// RegionBranchOpInterface
+//===----------------------------------------------------------------------===//
+
+/// Verify that types match along all region control flow edges originating from
+/// `sourceNo` (region # if source is a region, llvm::None if source is parent
+/// op). `getInputsTypesForRegion` is a function that returns the types of the
+/// inputs that flow from `sourceNo` to the given region.
+static LogicalResult verifyTypesAlongAllEdges(
+    Operation *op, Optional<unsigned> sourceNo,
+    function_ref<TypeRange(Optional<unsigned>)> getInputsTypesForRegion) {
+  auto regionInterface = cast<RegionBranchOpInterface>(op);
+
+  // Treat every input as unknown (null attribute); presumably this makes the
+  // interface report all possible successors so that each edge gets checked.
+  unsigned numInputs;
+  if (sourceNo)
+    numInputs = op->getRegion(sourceNo.getValue()).getNumArguments();
+  else
+    numInputs = op->getNumOperands();
+  SmallVector<Attribute, 2> operands(numInputs, nullptr);
+  SmallVector<RegionSuccessor, 2> successors;
+  regionInterface.getSuccessorRegions(sourceNo, operands, successors);
+
+  for (RegionSuccessor &succ : successors) {
+    Optional<unsigned> succRegionNo;
+    if (!succ.isParent())
+      succRegionNo = succ.getSuccessor()->getRegionNumber();
+
+    auto printEdgeName = [&](InFlightDiagnostic &diag) -> InFlightDiagnostic & {
+      diag << "from ";
+      if (sourceNo)
+        diag << "Region #" << sourceNo.getValue();
+      else
+        diag << op->getName();
+
+      diag << " to ";
+      if (succRegionNo)
+        diag << "Region #" << succRegionNo.getValue();
+      else
+        diag << op->getName();
+      return diag;
+    };
+
+    TypeRange sourceTypes = getInputsTypesForRegion(succRegionNo);
+    TypeRange succInputsTypes = succ.getSuccessorInputs().getTypes();
+    if (sourceTypes.size() != succInputsTypes.size()) {
+      InFlightDiagnostic diag = op->emitOpError("region control flow edge ");
+      return printEdgeName(diag)
+             << " has " << sourceTypes.size()
+             << " source operands, but target successor needs "
+             << succInputsTypes.size();
+    }
+
+    for (auto typesIdx :
+         llvm::enumerate(llvm::zip(sourceTypes, succInputsTypes))) {
+      Type sourceType = std::get<0>(typesIdx.value());
+      Type inputType = std::get<1>(typesIdx.value());
+      if (sourceType != inputType) {
+        InFlightDiagnostic diag = op->emitOpError("along control flow edge ");
+        return printEdgeName(diag)
+               << " source #" << typesIdx.index() << " type " << sourceType
+               << " should match input #" << typesIdx.index() << " type "
+               << inputType;
+      }
+    }
+  }
+  return success();
+}
+
+/// Verify that types match along the control flow edges described by `op`.
+LogicalResult detail::verifyTypesAlongControlFlowEdges(Operation *op) {
+  auto regionInterface = cast<RegionBranchOpInterface>(op);
+
+  auto inputTypesFromParent = [&](Optional<unsigned> regionNo) -> TypeRange {
+    if (regionNo)
+      return regionInterface.getSuccessorEntryOperands(*regionNo).getTypes();
+
+    // If the successor of the parent op is the parent itself,
+    // RegionBranchOpInterface does not have an API to query what the entry
+    // operands will be in that case. Vend out the result types of the op so
+    // that type checking trivially succeeds for this case.
+    return op->getResultTypes();
+  };
+
+  // Verify types along control flow edges originating from the parent.
+  if (failed(verifyTypesAlongAllEdges(op, llvm::None, inputTypesFromParent)))
+    return failure();
+
+  // RegionBranchOpInterface should not be implemented by Ops that do not have
+  // attached regions.
+  assert(op->getNumRegions() != 0);
+
+  // Verify types along control flow edges originating from each region.
+  for (unsigned regionNo : llvm::seq(0U, op->getNumRegions())) {
+    Region &region = op->getRegion(regionNo);
+
+    // Since the interface cannot distinguish between different ReturnLike
+    // ops within the region branching to different successors, all ReturnLike
+    // ops in this region should have the same operand types. We will then use
+    // one of them as the representative for type matching.
+
+    Operation *regionReturn = nullptr;
+    for (Block &block : region) {
+      // Nested blocks may not have been verified yet, so tolerate an empty
+      // block instead of assuming it ends in a terminator.
+      Operation *terminator = block.empty() ? nullptr : &block.back();
+      if (!terminator || !terminator->hasTrait<OpTrait::ReturnLike>())
+        continue;
+
+      if (!regionReturn) {
+        regionReturn = terminator;
+        continue;
+      }
+
+      // Found more than one ReturnLike terminator. Make sure the operand types
+      // match with the first one.
+      if (regionReturn->getOperandTypes() != terminator->getOperandTypes())
+        return op->emitOpError("Region #")
+               << regionNo
+               << " operands mismatch between return-like terminators";
+    }
+
+    auto inputTypesFromRegion = [&](Optional<unsigned>) -> TypeRange {
+      // All successors get the same set of operands.
+      return regionReturn ? TypeRange(regionReturn->getOperands().getTypes())
+                          : TypeRange();
+    };
+
+    if (failed(verifyTypesAlongAllEdges(op, regionNo, inputTypesFromRegion)))
+      return failure();
+  }
+
+  return success();
+}

diff  --git a/mlir/test/Dialect/SCF/invalid.mlir b/mlir/test/Dialect/SCF/invalid.mlir
index 37e760495eb1..517e8855c97b 100644
--- a/mlir/test/Dialect/SCF/invalid.mlir
+++ b/mlir/test/Dialect/SCF/invalid.mlir
@@ -325,13 +325,13 @@ func @reduceReturn_not_inside_reduce(%arg0 : f32) {
 
 func @std_if_incorrect_yield(%arg0: i1, %arg1: f32)
 {
+  // expected-error@+1 {{region control flow edge from Region #0 to scf.if has 1 source operands, but target successor needs 2}}
   %x, %y = scf.if %arg0 -> (f32, f32) {
     %0 = addf %arg1, %arg1 : f32
-    // expected-error@+1 {{parent of yield must have same number of results as the yield operands}}
     scf.yield %0 : f32
   } else {
     %0 = subf %arg1, %arg1 : f32
-    scf.yield %0 : f32
+    scf.yield %0, %0 : f32, f32
   }
   return
 }
@@ -396,14 +396,39 @@ func @std_for_operands_mismatch_3(%arg0 : index, %arg1 : index, %arg2 : index) {
   return
 }
 
+// -----
+
+func @std_for_operands_mismatch_4(%arg0 : index, %arg1 : index, %arg2 : index) {
+  %s0 = constant 0.0 : f32
+  %t0 = constant 1.0 : f32
+  // expected-error @+1 {{along control flow edge from Region #0 to Region #0 source #1 type 'i32' should match input #1 type 'f32'}}
+  %result1:2 = scf.for %i0 = %arg0 to %arg1 step %arg2
+                    iter_args(%si = %s0, %ti = %t0) -> (f32, f32) {
+    %sn = addf %si, %si : f32
+    %ic = constant 1 : i32  // Deliberately i32; iter_arg %ti is f32.
+    scf.yield %sn, %ic : f32, i32  // Yield #1 type mismatches the loop body.
+  }
+  return
+}
+
+
 // -----
 
 func @parallel_invalid_yield(
     %arg0: index, %arg1: index, %arg2: index) {
   scf.parallel (%i0) = (%arg0) to (%arg1) step (%arg2) {
     %c0 = constant 1.0 : f32
-    // expected-error@+1 {{yield inside scf.parallel is not allowed to have operands}}
+    // expected-error@+1 {{'scf.yield' op not allowed to have operands inside 'scf.parallel'}}
     scf.yield %c0 : f32
   }
   return
 }
+
+// -----
+func @yield_invalid_parent_op() {
+  "my.op"() ({
+    // expected-error@+1 {{'scf.yield' op expects parent op to be one of 'scf.if, scf.for, scf.parallel'}}
+    scf.yield
+  }) : () -> ()
+  return
+}


        


More information about the Mlir-commits mailing list