[flang-commits] [flang] 42e5f1d - [mlir] Refactor how additional verification is specified in ODS

River Riddle via flang-commits flang-commits at lists.llvm.org
Wed Feb 2 13:35:34 PST 2022


Author: River Riddle
Date: 2022-02-02T13:34:28-08:00
New Revision: 42e5f1d97b3ecf6f967a0e63ca39f05d3262e2b2

URL: https://github.com/llvm/llvm-project/commit/42e5f1d97b3ecf6f967a0e63ca39f05d3262e2b2
DIFF: https://github.com/llvm/llvm-project/commit/42e5f1d97b3ecf6f967a0e63ca39f05d3262e2b2.diff

LOG: [mlir] Refactor how additional verification is specified in ODS

Currently, if an operation requires additional verification, it specifies an inline
code block (`let verifier = "blah"`). This is problematic for several reasons: it
requires defining C++ inside of TableGen, which is discouraged when possible, and,
more importantly, nearly all usages simply forward to a static function
`static LogicalResult verify(SomeOp op)`.
This commit adds support for a `hasVerifier` bit field that specifies whether an additional
verifier is needed; when set to `1`, it declares a `LogicalResult verify()` method that the
operation must define. For migration purposes, the existing `verifier` behavior is untouched.
Upstream usages will be replaced in a follow-up to keep this patch focused on the `hasVerifier`
implementation.

The main user-facing change is that what was once `MyOp::verify` is now `MyOp::verifyInvariants`.
This better matches the name by which this method is called everywhere else, and also frees up
`verify` for user-defined additional verification. When generated (for additional verification),
the `verify` declaration is now private to the operation class, which should also help avoid
accidental usages after this switch.

Differential Revision: https://reviews.llvm.org/D118742

Added: 
    

Modified: 
    flang/tools/tco/tco.cpp
    mlir/docs/OpDefinitions.md
    mlir/include/mlir/Dialect/Affine/IR/AffineOps.h
    mlir/include/mlir/IR/OpBase.td
    mlir/include/mlir/IR/OpDefinition.h
    mlir/include/mlir/Parser.h
    mlir/lib/Dialect/Affine/IR/AffineOps.cpp
    mlir/lib/Target/SPIRV/Serialization/Serializer.cpp
    mlir/test/lib/Dialect/Test/TestDialect.cpp
    mlir/test/lib/Dialect/Test/TestOps.td
    mlir/test/mlir-tblgen/op-decl-and-defs.td
    mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp

Removed: 
    


################################################################################
diff  --git a/flang/tools/tco/tco.cpp b/flang/tools/tco/tco.cpp
index e363394f23289..d242e70c0641e 100644
--- a/flang/tools/tco/tco.cpp
+++ b/flang/tools/tco/tco.cpp
@@ -86,7 +86,7 @@ compileFIR(const mlir::PassPipelineCLParser &passPipeline) {
     errs() << "Error can't load file " << inputFilename << '\n';
     return mlir::failure();
   }
-  if (mlir::failed(owningRef->verify())) {
+  if (mlir::failed(owningRef->verifyInvariants())) {
     errs() << "Error verifying FIR module\n";
     return mlir::failure();
   }

diff  --git a/mlir/docs/OpDefinitions.md b/mlir/docs/OpDefinitions.md
index f26afa5666e23..1058b33480073 100644
--- a/mlir/docs/OpDefinitions.md
+++ b/mlir/docs/OpDefinitions.md
@@ -564,14 +564,13 @@ Verification code will be automatically generated for
 _additional_ verification, you can use
 
 ```tablegen
-let verifier = [{
-  ...
-}];
+let hasVerifier = 1;
 ```
 
-Code placed in `verifier` will be called after the auto-generated verification
-code. The order of trait verification excluding those of `verifier` should not
-be relied upon.
+This will generate a `LogicalResult verify()` method declaration on the op class
+that can be defined with any additional verification constraints. This method
+will be invoked after the auto-generated verification code. The order of trait
+verification excluding those of `hasVerifier` should not be relied upon.
 
 ### Declarative Assembly Format
 

diff  --git a/mlir/include/mlir/Dialect/Affine/IR/AffineOps.h b/mlir/include/mlir/Dialect/Affine/IR/AffineOps.h
index c350a53b1fb77..bcfc327be0e4c 100644
--- a/mlir/include/mlir/Dialect/Affine/IR/AffineOps.h
+++ b/mlir/include/mlir/Dialect/Affine/IR/AffineOps.h
@@ -225,7 +225,7 @@ class AffineDmaStartOp
   static StringRef getOperationName() { return "affine.dma_start"; }
   static ParseResult parse(OpAsmParser &parser, OperationState &result);
   void print(OpAsmPrinter &p);
-  LogicalResult verify();
+  LogicalResult verifyInvariants();
   LogicalResult fold(ArrayRef<Attribute> cstOperands,
                      SmallVectorImpl<OpFoldResult> &results);
 
@@ -313,7 +313,7 @@ class AffineDmaWaitOp
   static StringRef getTagMapAttrName() { return "tag_map"; }
   static ParseResult parse(OpAsmParser &parser, OperationState &result);
   void print(OpAsmPrinter &p);
-  LogicalResult verify();
+  LogicalResult verifyInvariants();
   LogicalResult fold(ArrayRef<Attribute> cstOperands,
                      SmallVectorImpl<OpFoldResult> &results);
 };

diff  --git a/mlir/include/mlir/IR/OpBase.td b/mlir/include/mlir/IR/OpBase.td
index 992ae35e4548b..92c9d524b586f 100644
--- a/mlir/include/mlir/IR/OpBase.td
+++ b/mlir/include/mlir/IR/OpBase.td
@@ -2451,7 +2451,16 @@ class Op<Dialect dialect, string mnemonic, list<Trait> props = []> {
   // Custom assembly format.
   string assemblyFormat = ?;
 
-  // Custom verifier.
+  // A bit indicating if the operation has additional invariants that need to
+  // be verified (aside from those verified by other ODS constructs). If set to
+  // `1`, an additional `LogicalResult verify()` declaration will be generated
+  // on the operation class. The operation should implement this method and
+  // verify the additional necessary invariants.
+  bit hasVerifier = 0;
+  // A custom code block corresponding to the extra verification code of the
+  // operation.
+  // NOTE: This field is deprecated in favor of `hasVerifier` and is slated for
+  // deletion.
   code verifier = ?;
 
   // Whether this op has associated canonicalization patterns.

diff  --git a/mlir/include/mlir/IR/OpDefinition.h b/mlir/include/mlir/IR/OpDefinition.h
index d1400c8557834..d83a7be5c0cd3 100644
--- a/mlir/include/mlir/IR/OpDefinition.h
+++ b/mlir/include/mlir/IR/OpDefinition.h
@@ -201,7 +201,7 @@ class OpState {
 protected:
   /// If the concrete type didn't implement a custom verifier hook, just fall
   /// back to this one which accepts everything.
-  LogicalResult verify() { return success(); }
+  LogicalResult verifyInvariants() { return success(); }
 
   /// Parse the custom form of an operation. Unless overridden, this method will
   /// first try to get an operation parser from the op's dialect. Otherwise the
@@ -1604,6 +1604,7 @@ class Op : public OpState, public Traits<ConcreteType>... {
 public:
   /// Inherit getOperation from `OpState`.
   using OpState::getOperation;
+  using OpState::verifyInvariants;
 
   /// Return if this operation contains the provided trait.
   template <template <typename T> class Trait>
@@ -1834,8 +1835,15 @@ class Op : public OpState, public Traits<ConcreteType>... {
     return cast<ConcreteType>(op).print(p);
   }
   /// Implementation of `VerifyInvariantsFn` OperationName hook.
+  static LogicalResult verifyInvariants(Operation *op) {
+    static_assert(hasNoDataMembers(),
+                  "Op class shouldn't define new data members");
+    return failure(
+        failed(op_definition_impl::verifyTraits<VerifiableTraitsTupleT>(op)) ||
+        failed(cast<ConcreteType>(op).verifyInvariants()));
+  }
   static OperationName::VerifyInvariantsFn getVerifyInvariantsFn() {
-    return &verifyInvariants;
+    return static_cast<LogicalResult (*)(Operation *)>(&verifyInvariants);
   }
 
   static constexpr bool hasNoDataMembers() {
@@ -1845,14 +1853,6 @@ class Op : public OpState, public Traits<ConcreteType>... {
     return sizeof(ConcreteType) == sizeof(EmptyOp);
   }
 
-  static LogicalResult verifyInvariants(Operation *op) {
-    static_assert(hasNoDataMembers(),
-                  "Op class shouldn't define new data members");
-    return failure(
-        failed(op_definition_impl::verifyTraits<VerifiableTraitsTupleT>(op)) ||
-        failed(cast<ConcreteType>(op).verify()));
-  }
-
   /// Allow access to internal implementation methods.
   friend RegisteredOperationName;
 };

diff  --git a/mlir/include/mlir/Parser.h b/mlir/include/mlir/Parser.h
index 908c1188a96ee..9bf55fee3d1be 100644
--- a/mlir/include/mlir/Parser.h
+++ b/mlir/include/mlir/Parser.h
@@ -67,7 +67,7 @@ inline OwningOpRef<ContainerOpT> constructContainerOpForParserIfNecessary(
 
   // After splicing, verify just this operation to ensure it can properly
   // contain the operations inside of it.
-  if (failed(op.verify()))
+  if (failed(op.verifyInvariants()))
     return OwningOpRef<ContainerOpT>();
   return opRef;
 }

diff  --git a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
index 958c5322e8e00..1d11ed6f82436 100644
--- a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
+++ b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
@@ -1119,7 +1119,7 @@ ParseResult AffineDmaStartOp::parse(OpAsmParser &parser,
   return success();
 }
 
-LogicalResult AffineDmaStartOp::verify() {
+LogicalResult AffineDmaStartOp::verifyInvariants() {
   if (!getOperand(getSrcMemRefOperandIndex()).getType().isa<MemRefType>())
     return emitOpError("expected DMA source to be of memref type");
   if (!getOperand(getDstMemRefOperandIndex()).getType().isa<MemRefType>())
@@ -1221,7 +1221,7 @@ ParseResult AffineDmaWaitOp::parse(OpAsmParser &parser,
   return success();
 }
 
-LogicalResult AffineDmaWaitOp::verify() {
+LogicalResult AffineDmaWaitOp::verifyInvariants() {
   if (!getOperand(0).getType().isa<MemRefType>())
     return emitOpError("expected DMA tag to be of memref type");
   Region *scope = getAffineScope(*this);

diff  --git a/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp b/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp
index d6df234ecf87a..740582125998d 100644
--- a/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp
+++ b/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp
@@ -86,7 +86,7 @@ Serializer::Serializer(spirv::ModuleOp module,
 LogicalResult Serializer::serialize() {
   LLVM_DEBUG(llvm::dbgs() << "+++ starting serialization +++\n");
 
-  if (failed(module.verify()))
+  if (failed(module.verifyInvariants()))
     return failure();
 
   // TODO: handle the other sections

diff  --git a/mlir/test/lib/Dialect/Test/TestDialect.cpp b/mlir/test/lib/Dialect/Test/TestDialect.cpp
index 21db4e0a9d11a..e23173d9ca576 100644
--- a/mlir/test/lib/Dialect/Test/TestDialect.cpp
+++ b/mlir/test/lib/Dialect/Test/TestDialect.cpp
@@ -1118,6 +1118,26 @@ void StringAttrPrettyNameOp::getAsmResultNames(
         setNameFn(getResult(i), str.getValue());
 }
 
+//===----------------------------------------------------------------------===//
+// ResultTypeWithTraitOp
+//===----------------------------------------------------------------------===//
+
+LogicalResult ResultTypeWithTraitOp::verify() {
+  if ((*this)->getResultTypes()[0].hasTrait<TypeTrait::TestTypeTrait>())
+    return success();
+  return emitError("result type should have trait 'TestTypeTrait'");
+}
+
+//===----------------------------------------------------------------------===//
+// AttrWithTraitOp
+//===----------------------------------------------------------------------===//
+
+LogicalResult AttrWithTraitOp::verify() {
+  if (getAttr().hasTrait<AttributeTrait::TestAttrTrait>())
+    return success();
+  return emitError("'attr' attribute should have trait 'TestAttrTrait'");
+}
+
 //===----------------------------------------------------------------------===//
 // RegionIfOp
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td
index c37007f9ab0e6..6e14b62a3e73c 100644
--- a/mlir/test/lib/Dialect/Test/TestOps.td
+++ b/mlir/test/lib/Dialect/Test/TestOps.td
@@ -666,27 +666,16 @@ def DefaultDialectOp : TEST_Op<"default_dialect", [OpAsmOpInterface]> {
 // This operation requires its return type to have the trait 'TestTypeTrait'.
 def ResultTypeWithTraitOp : TEST_Op<"result_type_with_trait", []> {
   let results = (outs AnyType);
-
-  let verifier = [{
-    if((*this)->getResultTypes()[0].hasTrait<TypeTrait::TestTypeTrait>())
-      return success();
-    return this->emitError("result type should have trait 'TestTypeTrait'");
-  }];
+  let hasVerifier = 1;
 }
 
 // This operation requires its "attr" attribute to have the
 // trait 'TestAttrTrait'.
 def AttrWithTraitOp : TEST_Op<"attr_with_trait", []> {
   let arguments = (ins AnyAttr:$attr);
-
-  let verifier = [{
-    if (this->getAttr().hasTrait<AttributeTrait::TestAttrTrait>())
-      return success();
-    return this->emitError("'attr' attribute should have trait 'TestAttrTrait'");
-  }];
+  let hasVerifier = 1;
 }
 
-
 //===----------------------------------------------------------------------===//
 // Test Locations
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/test/mlir-tblgen/op-decl-and-defs.td b/mlir/test/mlir-tblgen/op-decl-and-defs.td
index aace7cb14a4f6..0f801b6792fd1 100644
--- a/mlir/test/mlir-tblgen/op-decl-and-defs.td
+++ b/mlir/test/mlir-tblgen/op-decl-and-defs.td
@@ -98,7 +98,7 @@ def NS_AOp : NS_Op<"a_op", [IsolatedFromAbove, IsolatedFromAbove]> {
 // CHECK:   static void build(::mlir::OpBuilder &, ::mlir::OperationState &odsState, ::mlir::TypeRange resultTypes, ::mlir::ValueRange operands, ::llvm::ArrayRef<::mlir::NamedAttribute> attributes, unsigned numRegions)
 // CHECK:   static ::mlir::ParseResult parse(::mlir::OpAsmParser &parser, ::mlir::OperationState &result);
 // CHECK:   void print(::mlir::OpAsmPrinter &p);
-// CHECK:   ::mlir::LogicalResult verify();
+// CHECK:   ::mlir::LogicalResult verifyInvariants();
 // CHECK:   static void getCanonicalizationPatterns(::mlir::RewritePatternSet &results, ::mlir::MLIRContext *context);
 // CHECK:   ::mlir::LogicalResult fold(::llvm::ArrayRef<::mlir::Attribute> operands, ::llvm::SmallVectorImpl<::mlir::OpFoldResult> &results);
 // CHECK:   // Display a graph for debugging purposes.

diff  --git a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
index 870a7e9dc6a55..c11f8484aa04f 100644
--- a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
+++ b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
@@ -2208,8 +2208,8 @@ static void genNativeTraitAttrVerifier(MethodBody &body,
 }
 
 void OpEmitter::genVerifier() {
-  auto *method = opClass.addMethod("::mlir::LogicalResult", "verify");
-  ERROR_IF_PRUNED(method, "verify", op);
+  auto *method = opClass.addMethod("::mlir::LogicalResult", "verifyInvariants");
+  ERROR_IF_PRUNED(method, "verifyInvariants", op);
   auto &body = method->body();
 
   OpOrAdaptorHelper emitHelper(op, /*isOp=*/true);
@@ -2217,7 +2217,7 @@ void OpEmitter::genVerifier() {
 
   auto *valueInit = def.getValueInit("verifier");
   StringInit *stringInit = dyn_cast<StringInit>(valueInit);
-  bool hasCustomVerify = stringInit && !stringInit->getValue().empty();
+  bool hasCustomVerifyCodeBlock = stringInit && !stringInit->getValue().empty();
   populateSubstitutions(emitHelper, verifyCtx);
 
   genAttributeVerifier(emitHelper, verifyCtx, body, staticVerifierEmitter);
@@ -2236,7 +2236,13 @@ void OpEmitter::genVerifier() {
   genRegionVerifier(body);
   genSuccessorVerifier(body);
 
-  if (hasCustomVerify) {
+  if (def.getValueAsBit("hasVerifier")) {
+    auto *method = opClass.declareMethod<Method::Private>(
+        "::mlir::LogicalResult", "verify");
+    ERROR_IF_PRUNED(method, "verify", op);
+    body << "  return verify();\n";
+
+  } else if (hasCustomVerifyCodeBlock) {
     FmtContext fctx;
     fctx.addSubst("cppClass", opClass.getClassName());
     auto printer = stringInit->getValue().ltrim().rtrim(" \t\v\f\r");


        


More information about the flang-commits mailing list