[Mlir-commits] [mlir] c0fd5e6 - [mlir] Add traits for verifying the number of successors and providing relevant accessors.

River Riddle llvmlistbot at llvm.org
Thu Mar 5 12:58:45 PST 2020


Author: River Riddle
Date: 2020-03-05T12:49:59-08:00
New Revision: c0fd5e657e5d38a480d65b4e8f6f7a835afd6c76

URL: https://github.com/llvm/llvm-project/commit/c0fd5e657e5d38a480d65b4e8f6f7a835afd6c76
DIFF: https://github.com/llvm/llvm-project/commit/c0fd5e657e5d38a480d65b4e8f6f7a835afd6c76.diff

LOG: [mlir] Add traits for verifying the number of successors and providing relevant accessors.

This allows for simplifying OpDefGen, as well as providing specialized accessors for the different successor counts. This mirrors the existing traits for operands and results.

Differential Revision: https://reviews.llvm.org/D75313

Added: 
    

Modified: 
    mlir/include/mlir/IR/OpDefinition.h
    mlir/lib/Dialect/SPIRV/SPIRVOps.cpp
    mlir/lib/Dialect/StandardOps/IR/Ops.cpp
    mlir/lib/IR/Operation.cpp
    mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
    mlir/test/Dialect/SPIRV/control-flow-ops.mlir
    mlir/test/mlir-tblgen/op-decl.td
    mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/IR/OpDefinition.h b/mlir/include/mlir/IR/OpDefinition.h
index efbcf0afb4d7..1a8132178a50 100644
--- a/mlir/include/mlir/IR/OpDefinition.h
+++ b/mlir/include/mlir/IR/OpDefinition.h
@@ -381,6 +381,10 @@ LogicalResult verifyResultsAreBoolLike(Operation *op);
 LogicalResult verifyResultsAreFloatLike(Operation *op);
 LogicalResult verifyResultsAreSignlessIntegerLike(Operation *op);
 LogicalResult verifyIsTerminator(Operation *op);
+LogicalResult verifyZeroSuccessor(Operation *op);
+LogicalResult verifyOneSuccessor(Operation *op);
+LogicalResult verifyNSuccessors(Operation *op, unsigned numSuccessors);
+LogicalResult verifyAtLeastNSuccessors(Operation *op, unsigned numSuccessors);
 LogicalResult verifyOperandSizeAttr(Operation *op, StringRef sizeAttrName);
 LogicalResult verifyResultSizeAttr(Operation *op, StringRef sizeAttrName);
 } // namespace impl
@@ -410,6 +414,9 @@ class TraitBase {
   }
 };
 
+//===----------------------------------------------------------------------===//
+// Operand Traits
+
 namespace detail {
 /// Utility trait base that provides accessors for derived traits that have
 /// multiple operands.
@@ -522,6 +529,9 @@ template <typename ConcreteType>
 class VariadicOperands
     : public detail::MultiOperandTraitBase<ConcreteType, VariadicOperands> {};
 
+//===----------------------------------------------------------------------===//
+// Result Traits
+
 /// This class provides return value APIs for ops that are known to have
 /// zero results.
 template <typename ConcreteType>
@@ -644,6 +654,133 @@ template <typename ConcreteType>
 class VariadicResults
     : public detail::MultiResultTraitBase<ConcreteType, VariadicResults> {};
 
+//===----------------------------------------------------------------------===//
+// Terminator Traits
+
+/// This class provides the API for ops that are known to be terminators.
+template <typename ConcreteType>
+class IsTerminator : public TraitBase<ConcreteType, IsTerminator> {
+public:
+  static AbstractOperation::OperationProperties getTraitProperties() {
+    return static_cast<AbstractOperation::OperationProperties>(
+        OperationProperty::Terminator);
+  }
+  static LogicalResult verifyTrait(Operation *op) {
+    return impl::verifyIsTerminator(op);
+  }
+
+  unsigned getNumSuccessorOperands(unsigned index) {
+    return this->getOperation()->getNumSuccessorOperands(index);
+  }
+};
+
+/// This class provides verification for ops that are known to have zero
+/// successors.
+template <typename ConcreteType>
+class ZeroSuccessor : public TraitBase<ConcreteType, ZeroSuccessor> {
+public:
+  static LogicalResult verifyTrait(Operation *op) {
+    return impl::verifyZeroSuccessor(op);
+  }
+};
+
+namespace detail {
+/// Utility trait base that provides accessors for derived traits that have
+/// multiple successors.
+template <typename ConcreteType, template <typename> class TraitType>
+struct MultiSuccessorTraitBase : public TraitBase<ConcreteType, TraitType> {
+  using succ_iterator = Operation::succ_iterator;
+  using succ_range = SuccessorRange;
+
+  /// Return the number of successors.
+  unsigned getNumSuccessors() {
+    return this->getOperation()->getNumSuccessors();
+  }
+
+  /// Return the successor at `index`.
+  Block *getSuccessor(unsigned index) {
+    return this->getOperation()->getSuccessor(index);
+  }
+
+  /// Set the successor at `index`.
+  void setSuccessor(Block *block, unsigned index) {
+    this->getOperation()->setSuccessor(block, index);
+  }
+
+  /// Successor iterator access.
+  succ_iterator succ_begin() { return this->getOperation()->succ_begin(); }
+  succ_iterator succ_end() { return this->getOperation()->succ_end(); }
+  succ_range getSuccessors() { return this->getOperation()->getSuccessors(); }
+};
+} // end namespace detail
+
+/// This class provides APIs for ops that are known to have a single successor.
+template <typename ConcreteType>
+class OneSuccessor : public TraitBase<ConcreteType, OneSuccessor> {
+public:
+  Block *getSuccessor() { return this->getOperation()->getSuccessor(0); }
+  void setSuccessor(Block *succ) {
+    this->getOperation()->setSuccessor(succ, 0);
+  }
+
+  static LogicalResult verifyTrait(Operation *op) {
+    return impl::verifyOneSuccessor(op);
+  }
+};
+
+/// This class provides the API for ops that are known to have a specified
+/// number of successors.
+template <unsigned N>
+class NSuccessors {
+public:
+  static_assert(N > 1, "use ZeroSuccessor/OneSuccessor for N < 2");
+
+  template <typename ConcreteType>
+  class Impl : public detail::MultiSuccessorTraitBase<ConcreteType,
+                                                      NSuccessors<N>::Impl> {
+  public:
+    static LogicalResult verifyTrait(Operation *op) {
+      return impl::verifyNSuccessors(op, N);
+    }
+  };
+};
+
+/// This class provides APIs for ops that are known to have at least a specified
+/// number of successors.
+template <unsigned N>
+class AtLeastNSuccessors {
+public:
+  template <typename ConcreteType>
+  class Impl
+      : public detail::MultiSuccessorTraitBase<ConcreteType,
+                                               AtLeastNSuccessors<N>::Impl> {
+  public:
+    static LogicalResult verifyTrait(Operation *op) {
+      return impl::verifyAtLeastNSuccessors(op, N);
+    }
+  };
+};
+
+/// This class provides the API for ops which have an unknown number of
+/// successors.
+template <typename ConcreteType>
+class VariadicSuccessors
+    : public detail::MultiSuccessorTraitBase<ConcreteType, VariadicSuccessors> {
+public:
+  static LogicalResult verifyTrait(Operation *op) {
+    // Any count is allowed, but successors that are present must still have
+    // their blocks verified like those of the fixed-count traits. A minimum of
+    // zero makes the count check in verifyAtLeastNSuccessors vacuous, leaving
+    // only the successor-block verification. Nothing to do with no successors.
+    if (op->getNumSuccessors() == 0)
+      return success();
+    return impl::verifyAtLeastNSuccessors(op, /*numSuccessors=*/0);
+  }
+};
+
+//===----------------------------------------------------------------------===//
+// Misc Traits
+
 /// This class provides verification for ops that are known to have the same
 /// operand shape: all operands are scalars, vectors/tensors of the same
 /// shape.
@@ -789,41 +926,6 @@ class SameTypeOperands : public TraitBase<ConcreteType, SameTypeOperands> {
   }
 };
 
-/// This class provides the API for ops that are known to be terminators.
-template <typename ConcreteType>
-class IsTerminator : public TraitBase<ConcreteType, IsTerminator> {
-public:
-  static AbstractOperation::OperationProperties getTraitProperties() {
-    return static_cast<AbstractOperation::OperationProperties>(
-        OperationProperty::Terminator);
-  }
-  static LogicalResult verifyTrait(Operation *op) {
-    return impl::verifyIsTerminator(op);
-  }
-
-  unsigned getNumSuccessors() {
-    return this->getOperation()->getNumSuccessors();
-  }
-  unsigned getNumSuccessorOperands(unsigned index) {
-    return this->getOperation()->getNumSuccessorOperands(index);
-  }
-
-  Block *getSuccessor(unsigned index) {
-    return this->getOperation()->getSuccessor(index);
-  }
-
-  void setSuccessor(Block *block, unsigned index) {
-    return this->getOperation()->setSuccessor(block, index);
-  }
-
-  void addSuccessorOperand(unsigned index, Value value) {
-    return this->getOperation()->addSuccessorOperand(index, value);
-  }
-  void addSuccessorOperands(unsigned index, ArrayRef<Value> values) {
-    return this->getOperation()->addSuccessorOperand(index, values);
-  }
-};
-
 /// This class provides the API for ops that are known to be isolated from
 /// above.
 template <typename ConcreteType>

diff  --git a/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp b/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp
index 2bf1969835f3..6a638673d584 100644
--- a/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp
+++ b/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp
@@ -1894,7 +1894,7 @@ static inline bool hasOneBranchOpTo(Block &srcBlock, Block &dstBlock) {
     return false;
 
   auto branchOp = dyn_cast<spirv::BranchOp>(srcBlock.back());
-  return branchOp && branchOp.getSuccessor(0) == &dstBlock;
+  return branchOp && branchOp.getSuccessor() == &dstBlock;
 }
 
 static LogicalResult verify(spirv::LoopOp loopOp) {

diff  --git a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
index 21314970c199..1059e66d1fc5 100644
--- a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
+++ b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
@@ -477,9 +477,9 @@ struct SimplifyBrToBlockWithSinglePred : public OpRewritePattern<BranchOp> {
 };
 } // end anonymous namespace.
 
-Block *BranchOp::getDest() { return getSuccessor(0); }
+Block *BranchOp::getDest() { return getSuccessor(); }
 
-void BranchOp::setDest(Block *block) { return setSuccessor(block, 0); }
+void BranchOp::setDest(Block *block) { setSuccessor(block); }
 
 void BranchOp::eraseOperand(unsigned index) {
   getOperation()->eraseSuccessorOperand(0, index);

diff  --git a/mlir/lib/IR/Operation.cpp b/mlir/lib/IR/Operation.cpp
index 49185eb159dd..bfd4b40b317b 100644
--- a/mlir/lib/IR/Operation.cpp
+++ b/mlir/lib/IR/Operation.cpp
@@ -942,6 +942,14 @@ LogicalResult OpTrait::impl::verifySameOperandsAndResultType(Operation *op) {
   return success();
 }
 
+LogicalResult OpTrait::impl::verifyIsTerminator(Operation *op) {
+  // A terminator is only valid as the final operation of its parent block.
+  Block *parentBlock = op->getBlock();
+  if (parentBlock && &parentBlock->back() == op)
+    return success();
+  return op->emitOpError("must be the last operation in the parent block");
+}
+
 static LogicalResult verifySuccessor(Operation *op, unsigned succNo) {
   Operation::operand_range operands = op->getSuccessorOperands(succNo);
   unsigned operandCount = op->getNumSuccessorOperands(succNo);
@@ -976,18 +984,53 @@ static LogicalResult verifyTerminatorSuccessors(Operation *op) {
   return success();
 }
 
-LogicalResult OpTrait::impl::verifyIsTerminator(Operation *op) {
-  Block *block = op->getBlock();
-  // Verify that the operation is at the end of the respective parent block.
-  if (!block || &block->back() != op)
-    return op->emitOpError("must be the last operation in the parent block");
-
-  // Verify the state of the successor blocks.
-  if (op->getNumSuccessors() != 0 && failed(verifyTerminatorSuccessors(op)))
-    return failure();
+/// Verify that the operation has no successors.
+LogicalResult OpTrait::impl::verifyZeroSuccessor(Operation *op) {
+  if (op->getNumSuccessors() != 0) {
+    return op->emitOpError("requires 0 successors but found ")
+           << op->getNumSuccessors();
+  }
   return success();
 }
 
+/// Verify that the operation has exactly one successor, then verify the state
+/// of that successor block.
+LogicalResult OpTrait::impl::verifyOneSuccessor(Operation *op) {
+  if (op->getNumSuccessors() != 1) {
+    return op->emitOpError("requires 1 successor but found ")
+           << op->getNumSuccessors();
+  }
+  return verifyTerminatorSuccessors(op);
+}
+
+/// Verify that the operation has exactly `numSuccessors` successors, then
+/// verify the state of the successor blocks.
+LogicalResult OpTrait::impl::verifyNSuccessors(Operation *op,
+                                               unsigned numSuccessors) {
+  if (op->getNumSuccessors() != numSuccessors) {
+    // Match the noun to the count so N == 1 doesn't read "1 successors".
+    return op->emitOpError("requires ")
+           << numSuccessors
+           << (numSuccessors == 1 ? " successor" : " successors")
+           << " but found " << op->getNumSuccessors();
+  }
+  return verifyTerminatorSuccessors(op);
+}
+
+/// Verify that the operation has at least `numSuccessors` successors, then
+/// verify the state of all successor blocks.
+LogicalResult OpTrait::impl::verifyAtLeastNSuccessors(Operation *op,
+                                                      unsigned numSuccessors) {
+  if (op->getNumSuccessors() < numSuccessors) {
+    // Match the noun to the count so N == 1 doesn't read "1 successors".
+    return op->emitOpError("requires at least ")
+           << numSuccessors
+           << (numSuccessors == 1 ? " successor" : " successors")
+           << " but found " << op->getNumSuccessors();
+  }
+  return verifyTerminatorSuccessors(op);
+}
+
 LogicalResult OpTrait::impl::verifyResultsAreBoolLike(Operation *op) {
   for (auto resultType : op->getResultTypes()) {
     auto elementType = getTensorOrVectorElementType(resultType);

diff  --git a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
index d45f05b6d196..19bf61d311a6 100644
--- a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
@@ -356,7 +356,7 @@ LogicalResult ModuleTranslation::convertOperation(Operation &opInst,
   // Emit branches.  We need to look up the remapped blocks and ignore the block
   // arguments that were transformed into PHI nodes.
   if (auto brOp = dyn_cast<LLVM::BrOp>(opInst)) {
-    builder.CreateBr(blockMapping[brOp.getSuccessor(0)]);
+    builder.CreateBr(blockMapping[brOp.getSuccessor()]);
     return success();
   }
   if (auto condbrOp = dyn_cast<LLVM::CondBrOp>(opInst)) {

diff  --git a/mlir/test/Dialect/SPIRV/control-flow-ops.mlir b/mlir/test/Dialect/SPIRV/control-flow-ops.mlir
index cc7e09fcd1e1..a6dcac1f029f 100644
--- a/mlir/test/Dialect/SPIRV/control-flow-ops.mlir
+++ b/mlir/test/Dialect/SPIRV/control-flow-ops.mlir
@@ -24,7 +24,7 @@ func @branch_argument() -> () {
 // -----
 
 func @missing_accessor() -> () {
-  // expected-error @+1 {{has incorrect number of successors: expected 1 but found 0}}
+  // expected-error @+1 {{requires 1 successor but found 0}}
   spv.Branch
 }
 
@@ -32,7 +32,7 @@ func @missing_accessor() -> () {
 
 func @wrong_accessor_count() -> () {
   %true = spv.constant true
-  // expected-error @+1 {{incorrect number of successors: expected 1 but found 2}}
+  // expected-error @+1 {{requires 1 successor but found 2}}
   "spv.Branch"()[^one, ^two] : () -> ()
 ^one:
   spv.Return
@@ -116,7 +116,7 @@ func @wrong_condition_type() -> () {
 
 func @wrong_accessor_count() -> () {
   %true = spv.constant true
-  // expected-error @+1 {{incorrect number of successors: expected 2 but found 1}}
+  // expected-error @+1 {{requires 2 successors but found 1}}
   "spv.BranchConditional"(%true)[^one] : (i1) -> ()
 ^one:
   spv.Return

diff  --git a/mlir/test/mlir-tblgen/op-decl.td b/mlir/test/mlir-tblgen/op-decl.td
index 61f0c563974e..f07f99527ff8 100644
--- a/mlir/test/mlir-tblgen/op-decl.td
+++ b/mlir/test/mlir-tblgen/op-decl.td
@@ -54,7 +54,7 @@ def NS_AOp : NS_Op<"a_op", [NoSideEffect, NoSideEffect]> {
 // CHECK:   ArrayRef<Value> tblgen_operands;
 // CHECK: };
 
-// CHECK: class AOp : public Op<AOp, OpTrait::AtLeastNResults<1>::Impl, OpTrait::HasNoSideEffect, OpTrait::AtLeastNOperands<1>::Impl
+// CHECK: class AOp : public Op<AOp, OpTrait::AtLeastNResults<1>::Impl, OpTrait::ZeroSuccessor, OpTrait::HasNoSideEffect, OpTrait::AtLeastNOperands<1>::Impl
 // CHECK: public:
 // CHECK:   using Op::Op;
 // CHECK:   using OperandAdaptor = AOpOperandAdaptor;

diff  --git a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
index 8c6ba60b11f4..ebd82f9fb8bf 100644
--- a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
+++ b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
@@ -1390,26 +1390,8 @@ void OpEmitter::genRegionVerifier(OpMethodBody &body) {
 }
 
 void OpEmitter::genSuccessorVerifier(OpMethodBody &body) {
-  unsigned numSuccessors = op.getNumSuccessors();
-
-  const char *checkSuccessorSizeCode = R"(
-  if (this->getOperation()->getNumSuccessors() {0} {1}) {
-    return emitOpError("has incorrect number of successors: expected{2} {1}"
-                       " but found ")
-             << this->getOperation()->getNumSuccessors();
-  }
-  )";
-
-  // Verify this op has the correct number of successors.
-  unsigned numVariadicSuccessors = op.getNumVariadicSuccessors();
-  if (numVariadicSuccessors == 0) {
-    body << formatv(checkSuccessorSizeCode, "!=", numSuccessors, "");
-  } else if (numVariadicSuccessors != numSuccessors) {
-    body << formatv(checkSuccessorSizeCode, "<",
-                    numSuccessors - numVariadicSuccessors, " at least");
-  }
-
   // If we have no successors, there is nothing more to do.
+  unsigned numSuccessors = op.getNumSuccessors();
   if (numSuccessors == 0)
     return;
 
@@ -1441,31 +1423,51 @@ void OpEmitter::genSuccessorVerifier(OpMethodBody &body) {
   body << "  }\n";
 }
 
+/// Add a size count trait to the given operation class. `traitKind` is the
+/// singular entity name (e.g. "Result" or "Successor"), `numTotal` is the
+/// total number of such entities declared by the op, and `numVariadic` is how
+/// many of them are variadic. The emitted trait is:
+///   - Variadic<Kind>s           if every entity is variadic,
+///   - AtLeastN<Kind>s<N>::Impl  if some are variadic (N = non-variadic count),
+///   - Zero<Kind> / One<Kind>    if none are variadic and the count is 0 / 1,
+///   - N<Kind>s<N>::Impl         otherwise.
+static void addSizeCountTrait(OpClass &opClass, StringRef traitKind,
+                              int numTotal, int numVariadic) {
+  if (numVariadic != 0) {
+    if (numTotal == numVariadic)
+      opClass.addTrait("OpTrait::Variadic" + traitKind + "s");
+    else
+      opClass.addTrait("OpTrait::AtLeastN" + traitKind + "s<" +
+                       Twine(numTotal - numVariadic) + ">::Impl");
+    return;
+  }
+  switch (numTotal) {
+  case 0:
+    opClass.addTrait("OpTrait::Zero" + traitKind);
+    break;
+  case 1:
+    opClass.addTrait("OpTrait::One" + traitKind);
+    break;
+  default:
+    opClass.addTrait("OpTrait::N" + traitKind + "s<" + Twine(numTotal) +
+                     ">::Impl");
+    break;
+  }
+}
+
 void OpEmitter::genTraits() {
   int numResults = op.getNumResults();
   int numVariadicResults = op.getNumVariadicResults();
 
   // Add return size trait.
-  if (numVariadicResults != 0) {
-    if (numResults == numVariadicResults)
-      opClass.addTrait("OpTrait::VariadicResults");
-    else
-      opClass.addTrait("OpTrait::AtLeastNResults<" +
-                       Twine(numResults - numVariadicResults) + ">::Impl");
-  } else {
-    switch (numResults) {
-    case 0:
-      opClass.addTrait("OpTrait::ZeroResult");
-      break;
-    case 1:
-      opClass.addTrait("OpTrait::OneResult");
-      break;
-    default:
-      opClass.addTrait("OpTrait::NResults<" + Twine(numResults) + ">::Impl");
-      break;
-    }
-  }
+  addSizeCountTrait(opClass, "Result", numResults, numVariadicResults);
+
+  // Add successor size trait.
+  unsigned numSuccessors = op.getNumSuccessors();
+  unsigned numVariadicSuccessors = op.getNumVariadicSuccessors();
+  addSizeCountTrait(opClass, "Successor", numSuccessors, numVariadicSuccessors);
 
+  // Add the native and interface traits.
   for (const auto &trait : op.getTraits()) {
     if (auto opTrait = dyn_cast<tblgen::NativeOpTrait>(&trait))
       opClass.addTrait(opTrait->getTrait());


        


More information about the Mlir-commits mailing list