[Mlir-commits] [mlir] 9445b39 - [mlir] Support verification order (2/3)

Chia-hung Duan llvmlistbot at llvm.org
Fri Feb 25 11:09:03 PST 2022


Author: Chia-hung Duan
Date: 2022-02-25T19:04:56Z
New Revision: 9445b39673c81e9b5ffeda0f71be0c1476e9f313

URL: https://github.com/llvm/llvm-project/commit/9445b39673c81e9b5ffeda0f71be0c1476e9f313
DIFF: https://github.com/llvm/llvm-project/commit/9445b39673c81e9b5ffeda0f71be0c1476e9f313.diff

LOG: [mlir] Support verification order (2/3)

    This change gives an explicit order of verifier execution and adds
    `hasRegionVerifier` and `verifyWithRegions` to increase the granularity
    of verifier classification. The order is as follows:

    1. Traits marked as StructuralOpTrait are verified first; they can be run
       independently.
    2. `verifyInvariants`, which is generated by ODS, verifies types,
       attributes, etc.
    3. Other traits/interfaces that have marked their verifier as
       `verifyTrait` or `verifyWithRegions=0`.
    4. The custom verifier defined in the op, if it sets `hasVerifier=1`.

    If an operation has regions, there may be a second phase:

    5. Traits/interfaces that have marked their verifier as
       `verifyRegionTrait` or `verifyWithRegions=1`. This implies the
       verifier needs to access the operations in its regions.
    6. The custom verifier defined in the op, if it sets
       `hasRegionVerifier=1`.

    Note that the second phase runs after the operations in the regions are
    verified. Given this order, later verifiers can rely on what earlier ones
    have already checked and avoid verifying the same things twice.

Reviewed By: Mogball

Differential Revision: https://reviews.llvm.org/D116789

Added: 
    

Modified: 
    mlir/docs/OpDefinitions.md
    mlir/docs/Traits.md
    mlir/include/mlir/Dialect/Affine/IR/AffineOps.h
    mlir/include/mlir/IR/OpBase.td
    mlir/include/mlir/IR/OpDefinition.h
    mlir/include/mlir/IR/OperationSupport.h
    mlir/include/mlir/TableGen/Interfaces.h
    mlir/include/mlir/TableGen/Trait.h
    mlir/lib/Dialect/Affine/IR/AffineOps.cpp
    mlir/lib/IR/MLIRContext.cpp
    mlir/lib/IR/Verifier.cpp
    mlir/lib/TableGen/Interfaces.cpp
    mlir/lib/TableGen/Trait.cpp
    mlir/test/Dialect/Arithmetic/invalid.mlir
    mlir/test/Dialect/GPU/invalid.mlir
    mlir/test/Dialect/LLVMIR/global.mlir
    mlir/test/Dialect/Linalg/invalid.mlir
    mlir/test/Dialect/Linalg/named-ops.mlir
    mlir/test/Dialect/SPIRV/IR/bit-ops.mlir
    mlir/test/Dialect/Shape/invalid.mlir
    mlir/test/Dialect/traits.mlir
    mlir/test/IR/invalid-module-op.mlir
    mlir/test/IR/traits.mlir
    mlir/test/mlir-tblgen/op-decl-and-defs.td
    mlir/test/mlir-tblgen/op-interface.td
    mlir/test/mlir-tblgen/types.mlir
    mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
    mlir/tools/mlir-tblgen/OpInterfacesGen.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/docs/OpDefinitions.md b/mlir/docs/OpDefinitions.md
index e9aa37f5fa76c..09a3ca570060f 100644
--- a/mlir/docs/OpDefinitions.md
+++ b/mlir/docs/OpDefinitions.md
@@ -567,10 +567,39 @@ _additional_ verification, you can use
 let hasVerifier = 1;
 ```
 
-This will generate a `LogicalResult verify()` method declaration on the op class
-that can be defined with any additional verification constraints. This method
-will be invoked after the auto-generated verification code. The order of trait
-verification excluding those of `hasVerifier` should not be relied upon.
+or
+
+```tablegen
+let hasRegionVerifier = 1;
+```
+
+This will generate a `LogicalResult verify()` or a
+`LogicalResult verifyRegions()` method declaration, respectively, on the op
+class that can be defined with any additional verification constraints. These
+methods are invoked according to the verification order described below.
+
+#### Verification Ordering
+
+The verification of an operation involves several steps:
+
+1. Traits marked as `StructuralOpTrait` are verified first; they can be run
+   independently.
+1. `verifyInvariants`, which is generated by ODS, verifies types, attributes,
+   etc.
+1. Other traits/interfaces that have marked their verifier as `verifyTrait` or
+   `verifyWithRegions=0` are run.
+1. The custom verifier defined on the op is invoked if it sets `hasVerifier=1`.
+
+If an operation has regions, there may be a second phase:
+
+1. Traits/interfaces that have marked their verifier as `verifyRegionTrait` or
+   `verifyWithRegions=1` are run. This implies the verifier needs to access the
+   operations in its regions.
+1. The custom verifier defined on the op is invoked if it sets
+   `hasRegionVerifier=1`.
+
+Note that the second phase runs after the operations in the regions are
+verified. Verifiers further down the order can rely on certain invariants being
+verified by a previous verifier and do not need to re-verify them.
 
 ### Declarative Assembly Format
 

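As a rough illustration of the verification ordering documented above, the following standalone C++ sketch models one operation's verification as two ordered lists of checks that stop at the first failure. The `Step`, `runInOrder`, `phaseOne`, and `phaseTwo` names are hypothetical and are not MLIR APIs; the labels only echo the documented steps, and checks return `bool` instead of `LogicalResult`.

```cpp
// Hypothetical model of the per-operation verification order; none of these
// names are MLIR APIs.
#include <cassert>
#include <functional>
#include <string>
#include <vector>

// A verification step: a label for tracing plus the check itself.
struct Step {
  std::string label;
  std::function<bool()> check;
};

// Runs `steps` in order, recording each executed label in `trace`, and stops
// at the first step whose check fails.
bool runInOrder(const std::vector<Step> &steps,
                std::vector<std::string> &trace) {
  for (const Step &step : steps) {
    trace.push_back(step.label);
    if (!step.check())
      return false;
  }
  return true;
}

// Phase 1: runs before the ops nested in the regions are verified.
// `invariantsHold` lets a caller simulate a failing ODS check.
std::vector<Step> phaseOne(bool invariantsHold) {
  return {
      {"structural traits", [] { return true; }},
      {"verifyInvariants (ODS)", [=] { return invariantsHold; }},
      {"verifyTrait / verifyWithRegions=0", [] { return true; }},
      {"verify() (hasVerifier=1)", [] { return true; }},
  };
}

// Phase 2: runs only after every op in the regions has been verified.
std::vector<Step> phaseTwo() {
  return {
      {"verifyRegionTrait / verifyWithRegions=1", [] { return true; }},
      {"verifyRegions() (hasRegionVerifier=1)", [] { return true; }},
  };
}
```

Because `runInOrder` stops at the first failing step, a failing ODS check means the custom `verify()` never runs, which is why later verifiers may assume earlier invariants hold. This appears consistent with the updated test expectations in this commit, where ODS type-constraint messages replace earlier messages such as "are cast incompatible".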
diff  --git a/mlir/docs/Traits.md b/mlir/docs/Traits.md
index 065c5155b935e..4a6915c74d231 100644
--- a/mlir/docs/Traits.md
+++ b/mlir/docs/Traits.md
@@ -36,9 +36,12 @@ class MyTrait : public TraitBase<ConcreteType, MyTrait> {
 };
 ```
 
-Operation traits may also provide a `verifyTrait` hook, that is called when
-verifying the concrete operation. The trait verifiers will currently always be
-invoked before the main `Op::verify`.
+Operation traits may also provide a `verifyTrait` or `verifyRegionTrait` hook
+that is called when verifying the concrete operation. The difference between
+these two is whether the verifier needs to access the regions; if so, the
+operations in the regions will be verified before the verification of this
+trait. The [verification order](OpDefinitions.md/#verification-ordering)
+determines when a verifier will be invoked.
 
 ```c++
 template <typename ConcreteType>
@@ -53,8 +56,9 @@ public:
 ```
 
 Note: It is generally good practice to define the implementation of the
-`verifyTrait` hook out-of-line as a free function when possible to avoid
-instantiating the implementation for every concrete operation type.
+`verifyTrait` or `verifyRegionTrait` hook out-of-line as a free function when
+possible to avoid instantiating the implementation for every concrete operation
+type.
 
 Operation traits may also provide a `foldTrait` hook that is called when folding
 the concrete operation. The trait folders will only be invoked if the concrete

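A minimal sketch of the two hook kinds, using toy types rather than MLIR's, together with the out-of-line advice from the note in `Traits.md`: each trait template only forwards to a non-template free function. `Operation`, `AtMostOneRegion`, `OnlyToyOpsInRegion`, and `LoopLikeOp` are hypothetical names, and the hooks return `bool` instead of `LogicalResult` to keep the example self-contained.

```cpp
// Toy stand-ins for illustration only; `Operation` is NOT mlir::Operation.
#include <cassert>
#include <string>
#include <vector>

struct Operation {
  int numRegions = 0;
  std::vector<std::string> nestedOpNames; // names of ops inside the region(s)
};

// Non-template helpers. In a real code base these would be declared in a
// header and defined once in a .cpp file, so the trait templates below only
// forward to them instead of instantiating the logic for every op type.
bool verifyAtMostOneRegionImpl(const Operation *op) {
  return op->numRegions <= 1;
}
bool verifyOnlyToyOpsImpl(const Operation *op) {
  for (const std::string &name : op->nestedOpNames)
    if (name.rfind("toy.", 0) != 0) // name must start with "toy."
      return false;
  return true;
}

// Only looks at the op itself, so it is exposed as a `verifyTrait` hook.
template <typename ConcreteType>
struct AtMostOneRegion {
  static bool verifyTrait(const Operation *op) {
    return verifyAtMostOneRegionImpl(op);
  }
};

// Inspects the ops nested in the regions, so it is exposed as a
// `verifyRegionTrait` hook and would run after those ops are verified.
template <typename ConcreteType>
struct OnlyToyOpsInRegion {
  static bool verifyRegionTrait(const Operation *op) {
    return verifyOnlyToyOpsImpl(op);
  }
};

// A hypothetical op type using both traits.
struct LoopLikeOp : AtMostOneRegion<LoopLikeOp>,
                    OnlyToyOpsInRegion<LoopLikeOp> {};
```

The hook names differ, so both are reachable through `LoopLikeOp` without ambiguity, and a verifier can select them separately for the two phases.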
diff  --git a/mlir/include/mlir/Dialect/Affine/IR/AffineOps.h b/mlir/include/mlir/Dialect/Affine/IR/AffineOps.h
index d2bef9d6ace26..96d134d802841 100644
--- a/mlir/include/mlir/Dialect/Affine/IR/AffineOps.h
+++ b/mlir/include/mlir/Dialect/Affine/IR/AffineOps.h
@@ -76,7 +76,7 @@ bool isTopLevelValue(Value value);
 class AffineDmaStartOp
     : public Op<AffineDmaStartOp, OpTrait::MemRefsNormalizable,
                 OpTrait::VariadicOperands, OpTrait::ZeroResult,
-                AffineMapAccessInterface::Trait> {
+                OpTrait::OpInvariants, AffineMapAccessInterface::Trait> {
 public:
   using Op::Op;
   static ArrayRef<StringRef> getAttributeNames() { return {}; }
@@ -227,7 +227,8 @@ class AffineDmaStartOp
   static StringRef getOperationName() { return "affine.dma_start"; }
   static ParseResult parse(OpAsmParser &parser, OperationState &result);
   void print(OpAsmPrinter &p);
-  LogicalResult verifyInvariants();
+  LogicalResult verifyInvariantsImpl();
+  LogicalResult verifyInvariants() { return verifyInvariantsImpl(); }
   LogicalResult fold(ArrayRef<Attribute> cstOperands,
                      SmallVectorImpl<OpFoldResult> &results);
 
@@ -268,7 +269,7 @@ class AffineDmaStartOp
 class AffineDmaWaitOp
     : public Op<AffineDmaWaitOp, OpTrait::MemRefsNormalizable,
                 OpTrait::VariadicOperands, OpTrait::ZeroResult,
-                AffineMapAccessInterface::Trait> {
+                OpTrait::OpInvariants, AffineMapAccessInterface::Trait> {
 public:
   using Op::Op;
   static ArrayRef<StringRef> getAttributeNames() { return {}; }
@@ -315,7 +316,8 @@ class AffineDmaWaitOp
   static StringRef getTagMapAttrName() { return "tag_map"; }
   static ParseResult parse(OpAsmParser &parser, OperationState &result);
   void print(OpAsmPrinter &p);
-  LogicalResult verifyInvariants();
+  LogicalResult verifyInvariantsImpl();
+  LogicalResult verifyInvariants() { return verifyInvariantsImpl(); }
   LogicalResult fold(ArrayRef<Attribute> cstOperands,
                      SmallVectorImpl<OpFoldResult> &results);
 };

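The `OpInvariants` wrapper lets hand-written ops such as `AffineDmaStartOp` have their `verifyInvariantsImpl()` run at a fixed position among the traits. The standalone sketch below imitates that idea with a hypothetical `OpInvariantsLike` trait and a C++17 fold expression that runs trait hooks left to right; it is not the MLIR implementation, and every name in it is illustrative.

```cpp
// Illustrative only: mimics the shape of OpTrait::OpInvariants without MLIR.
#include <cassert>
#include <string>
#include <vector>

// Records which hooks ran, in order.
std::vector<std::string> &callLog() {
  static std::vector<std::string> log;
  return log;
}

template <typename ConcreteType>
struct StructuralTrait {
  static bool verifyTrait() {
    callLog().push_back("structural");
    return true;
  }
};

// Like OpInvariants: its only job is to forward to the op's
// verifyInvariantsImpl(), so the invariants run at this trait's position.
template <typename ConcreteType>
struct OpInvariantsLike {
  static bool verifyTrait() { return ConcreteType::verifyInvariantsImpl(); }
};

template <typename ConcreteType>
struct UserTrait {
  static bool verifyTrait() {
    callLog().push_back("user trait");
    return true;
  }
};

// Runs each trait's hook left to right; `&&` stops at the first failure.
template <typename... Traits>
bool verifyTraitsInOrder() {
  return (Traits::verifyTrait() && ...);
}

struct DmaLikeOp {
  static inline bool invariantsHold = true; // lets a caller simulate failure

  static bool verifyInvariantsImpl() {
    callLog().push_back("invariants");
    return invariantsHold;
  }
  // The position of OpInvariantsLike in this list decides when the
  // invariants run, just as the trait list passed to Op<...> does.
  static bool verify() {
    return verifyTraitsInOrder<StructuralTrait<DmaLikeOp>,
                               OpInvariantsLike<DmaLikeOp>,
                               UserTrait<DmaLikeOp>>();
  }
};
```

Moving the checks out of `verifyInvariants()` into a trait turns their timing into a matter of list order, which a code generator can control.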
diff  --git a/mlir/include/mlir/IR/OpBase.td b/mlir/include/mlir/IR/OpBase.td
index 5ad8bd45b339e..11bc1a639794a 100644
--- a/mlir/include/mlir/IR/OpBase.td
+++ b/mlir/include/mlir/IR/OpBase.td
@@ -2023,6 +2023,10 @@ class PredAttrTrait<string descr, Pred pred> : PredTrait<descr, pred>;
 // OpTrait definitions
 //===----------------------------------------------------------------------===//
 
+// Traits that describe the structure of an operation are marked with
+// `StructuralOpTrait`; they are verified first.
+class StructuralOpTrait;
+
 // These classes are used to define operation specific traits.
 class NativeOpTrait<string name, list<Trait> traits = []>
     : NativeTrait<name, "Op"> {
@@ -2053,7 +2057,8 @@ class PredOpTrait<string descr, Pred pred, list<Trait> traits = []>
 // Op defines an affine scope.
 def AffineScope : NativeOpTrait<"AffineScope">;
 // Op defines an automatic allocation scope.
-def AutomaticAllocationScope : NativeOpTrait<"AutomaticAllocationScope">;
+def AutomaticAllocationScope :
+  NativeOpTrait<"AutomaticAllocationScope">;
 // Op supports operand broadcast behavior.
 def ResultsBroadcastableShape :
   NativeOpTrait<"ResultsBroadcastableShape">;
@@ -2074,9 +2079,11 @@ def SameTypeOperands : NativeOpTrait<"SameTypeOperands">;
 // Op has same shape for all operands.
 def SameOperandsShape : NativeOpTrait<"SameOperandsShape">;
 // Op has same operand and result shape.
-def SameOperandsAndResultShape : NativeOpTrait<"SameOperandsAndResultShape">;
+def SameOperandsAndResultShape :
+  NativeOpTrait<"SameOperandsAndResultShape">;
 // Op has the same element type (or type itself, if scalar) for all operands.
-def SameOperandsElementType : NativeOpTrait<"SameOperandsElementType">;
+def SameOperandsElementType :
+  NativeOpTrait<"SameOperandsElementType">;
 // Op has the same operand and result element type (or type itself, if scalar).
 def SameOperandsAndResultElementType :
   NativeOpTrait<"SameOperandsAndResultElementType">;
@@ -2104,21 +2111,23 @@ def ElementwiseMappable : TraitList<[
 ]>;
 
 // Op's regions have a single block.
-def SingleBlock : NativeOpTrait<"SingleBlock">;
+def SingleBlock : NativeOpTrait<"SingleBlock">, StructuralOpTrait;
 
 // Op's regions have a single block with the specified terminator.
 class SingleBlockImplicitTerminator<string op>
-    : ParamNativeOpTrait<"SingleBlockImplicitTerminator", op>;
+    : ParamNativeOpTrait<"SingleBlockImplicitTerminator", op>,
+      StructuralOpTrait;
 
 // Op's regions don't have terminator.
-def NoTerminator : NativeOpTrait<"NoTerminator">;
+def NoTerminator : NativeOpTrait<"NoTerminator">, StructuralOpTrait;
 
 // Op's parent operation is the provided one.
 class HasParent<string op>
-    : ParamNativeOpTrait<"HasParent", op>;
+    : ParamNativeOpTrait<"HasParent", op>, StructuralOpTrait;
 
 class ParentOneOf<list<string> ops>
-    : ParamNativeOpTrait<"HasParent", !interleave(ops, ", ")>;
+    : ParamNativeOpTrait<"HasParent", !interleave(ops, ", ")>,
+      StructuralOpTrait;
 
 // Op result type is derived from the first attribute. If the attribute is an
 // subclass of `TypeAttrBase`, its value is used, otherwise, the type of the
@@ -2147,13 +2156,15 @@ def SameVariadicResultSize : GenInternalOpTrait<"SameVariadicResultSize">;
 // vector that has the same number of elements as the number of ODS declared
 // operands. That means even if some operands are non-variadic, the attribute
 // still need to have an element for its size, which is always 1.
-def AttrSizedOperandSegments : NativeOpTrait<"AttrSizedOperandSegments">;
+def AttrSizedOperandSegments :
+  NativeOpTrait<"AttrSizedOperandSegments">, StructuralOpTrait;
 // Similar to AttrSizedOperandSegments, but used for results. The attribute
 // should be named as `result_segment_sizes`.
-def AttrSizedResultSegments  : NativeOpTrait<"AttrSizedResultSegments">;
+def AttrSizedResultSegments  :
+  NativeOpTrait<"AttrSizedResultSegments">, StructuralOpTrait;
 
 // Op attached regions have no arguments
-def NoRegionArguments : NativeOpTrait<"NoRegionArguments">;
+def NoRegionArguments : NativeOpTrait<"NoRegionArguments">, StructuralOpTrait;
 
 //===----------------------------------------------------------------------===//
 // OpInterface definitions
@@ -2191,6 +2202,11 @@ class OpInterfaceTrait<string name, code verifyBody = [{}],
   // the operation being verified.
   code verify = verifyBody;
 
+  // A bit indicating if the verifier needs to access the ops in the regions. If
+  // it is set to `1`, the region ops will be verified before invoking this
+  // verifier.
+  bit verifyWithRegions = 0;
+
   // Specify the list of traits that need to be verified before the verification
   // of this OpInterfaceTrait.
   list<Trait> dependentTraits = traits;
@@ -2467,6 +2483,16 @@ class Op<Dialect dialect, string mnemonic, list<Trait> props = []> {
   // operation class. The operation should implement this method and verify the
   // additional necessary invariants.
   bit hasVerifier = 0;
+
+  // A bit indicating if the operation has additional invariants that need to
+  // be verified and that are associated with regions (aside from those verified
+  // by the traits). If set to `1`, an additional `LogicalResult verifyRegions()`
+  // declaration will be generated on the operation class. The operation should
+  // implement this method and verify the additional necessary invariants
+  // associated with regions. Note that this method is invoked after all the
+  // region ops are verified.
+  bit hasRegionVerifier = 0;
+
   // A custom code block corresponding to the extra verification code of the
   // operation.
   // NOTE: This field is deprecated in favor of `hasVerifier` and is slated for

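The `OpDefinitionsGen.cpp` changes are not shown in this excerpt. Assuming the generator follows the documented order (structural traits, then the `OpInvariants` wrapper, then the remaining traits), the reordering could be sketched as below. `TraitInfo` and `orderTraits` are hypothetical, and `isStructural` stands in for `NativeTrait::isStructuralOpTrait()`.

```cpp
// Hypothetical sketch of trait ordering; not the actual OpDefinitionsGen code.
#include <cassert>
#include <string>
#include <vector>

struct TraitInfo {
  std::string name;
  bool isStructural; // stands in for NativeTrait::isStructuralOpTrait()
};

// Returns the trait names in the documented verification order: structural
// traits, then the ODS invariants wrapper, then all other traits. The
// relative order inside each group is preserved.
std::vector<std::string> orderTraits(const std::vector<TraitInfo> &traits) {
  std::vector<std::string> ordered;
  for (const TraitInfo &trait : traits)
    if (trait.isStructural)
      ordered.push_back(trait.name);
  ordered.push_back("OpInvariants");
  for (const TraitInfo &trait : traits)
    if (!trait.isStructural)
      ordered.push_back(trait.name);
  return ordered;
}
```

Preserving the original relative order inside each group means the order written in ODS still matters among the non-structural traits.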
diff  --git a/mlir/include/mlir/IR/OpDefinition.h b/mlir/include/mlir/IR/OpDefinition.h
index e38b0cb5db576..268b666e13669 100644
--- a/mlir/include/mlir/IR/OpDefinition.h
+++ b/mlir/include/mlir/IR/OpDefinition.h
@@ -200,7 +200,8 @@ class OpState {
 protected:
   /// If the concrete type didn't implement a custom verifier hook, just fall
   /// back to this one which accepts everything.
-  LogicalResult verifyInvariants() { return success(); }
+  LogicalResult verify() { return success(); }
+  LogicalResult verifyRegions() { return success(); }
 
   /// Parse the custom form of an operation. Unless overridden, this method will
   /// first try to get an operation parser from the op's dialect. Otherwise the
@@ -376,6 +377,18 @@ struct MultiOperandTraitBase : public TraitBase<ConcreteType, TraitType> {
 };
 } // namespace detail
 
+/// `verifyInvariantsImpl` verifies invariants such as types, attributes, etc.
+/// It should be run after the structural traits and before any other
+/// user-defined traits. To run it in the correct order, it is wrapped in the
+/// OpInvariants trait so that tblgen can place it in the right position.
+template <typename ConcreteType>
+class OpInvariants : public TraitBase<ConcreteType, OpInvariants> {
+public:
+  static LogicalResult verifyTrait(Operation *op) {
+    return cast<ConcreteType>(op).verifyInvariantsImpl();
+  }
+};
+
 /// This class provides the API for ops that are known to have no
 /// SSA operand.
 template <typename ConcreteType>
@@ -1572,6 +1585,14 @@ using has_verify_trait = decltype(T::verifyTrait(std::declval<Operation *>()));
 template <typename T>
 using detect_has_verify_trait = llvm::is_detected<has_verify_trait, T>;
 
+/// Trait to check if T provides a `verifyRegionTrait` method.
+template <typename T, typename... Args>
+using has_verify_region_trait =
+    decltype(T::verifyRegionTrait(std::declval<Operation *>()));
+template <typename T>
+using detect_has_verify_region_trait =
+    llvm::is_detected<has_verify_region_trait, T>;
+
 /// The internal implementation of `verifyTraits` below that returns the result
 /// of verifying the current operation with all of the provided trait types
 /// `Ts`.
@@ -1589,6 +1610,26 @@ template <typename TraitTupleT>
 static LogicalResult verifyTraits(Operation *op) {
   return verifyTraitsImpl(op, (TraitTupleT *)nullptr);
 }
+
+/// The internal implementation of `verifyRegionTraits` below that returns the
+/// result of verifying the current operation with all of the provided trait
+/// types `Ts`.
+template <typename... Ts>
+static LogicalResult verifyRegionTraitsImpl(Operation *op,
+                                            std::tuple<Ts...> *) {
+  LogicalResult result = success();
+  (void)std::initializer_list<int>{
+      (result = succeeded(result) ? Ts::verifyRegionTrait(op) : failure(),
+       0)...};
+  return result;
+}
+
+/// Given a tuple type containing a set of traits that contain a
+/// `verifyRegionTrait` method, return the result of verifying the given operation.
+template <typename TraitTupleT>
+static LogicalResult verifyRegionTraits(Operation *op) {
+  return verifyRegionTraitsImpl(op, (TraitTupleT *)nullptr);
+}
 } // namespace op_definition_impl
 
 //===----------------------------------------------------------------------===//
@@ -1603,7 +1644,8 @@ class Op : public OpState, public Traits<ConcreteType>... {
 public:
   /// Inherit getOperation from `OpState`.
   using OpState::getOperation;
-  using OpState::verifyInvariants;
+  using OpState::verify;
+  using OpState::verifyRegions;
 
   /// Return if this operation contains the provided trait.
   template <template <typename T> class Trait>
@@ -1704,6 +1746,10 @@ class Op : public OpState, public Traits<ConcreteType>... {
   using VerifiableTraitsTupleT =
       typename detail::FilterTypes<op_definition_impl::detect_has_verify_trait,
                                    Traits<ConcreteType>...>::type;
+  /// A tuple type containing the region traits that have a verify function.
+  using VerifiableRegionTraitsTupleT = typename detail::FilterTypes<
+      op_definition_impl::detect_has_verify_region_trait,
+      Traits<ConcreteType>...>::type;
 
   /// Returns an interface map containing the interfaces registered to this
   /// operation.
@@ -1839,11 +1885,22 @@ class Op : public OpState, public Traits<ConcreteType>... {
                   "Op class shouldn't define new data members");
     return failure(
         failed(op_definition_impl::verifyTraits<VerifiableTraitsTupleT>(op)) ||
-        failed(cast<ConcreteType>(op).verifyInvariants()));
+        failed(cast<ConcreteType>(op).verify()));
   }
   static OperationName::VerifyInvariantsFn getVerifyInvariantsFn() {
     return static_cast<LogicalResult (*)(Operation *)>(&verifyInvariants);
   }
+  /// Implementation of `VerifyRegionInvariantsFn` OperationName hook.
+  static LogicalResult verifyRegionInvariants(Operation *op) {
+    static_assert(hasNoDataMembers(),
+                  "Op class shouldn't define new data members");
+    return failure(failed(op_definition_impl::verifyRegionTraits<
+                          VerifiableRegionTraitsTupleT>(op)) ||
+                   failed(cast<ConcreteType>(op).verifyRegions()));
+  }
+  static OperationName::VerifyRegionInvariantsFn getVerifyRegionInvariantsFn() {
+    return static_cast<LogicalResult (*)(Operation *)>(&verifyRegionInvariants);
+  }
 
   static constexpr bool hasNoDataMembers() {
     // Checking that the derived class does not define any member by comparing

diff  --git a/mlir/include/mlir/IR/OperationSupport.h b/mlir/include/mlir/IR/OperationSupport.h
index 5ae07287b88c0..f72c24480d71c 100644
--- a/mlir/include/mlir/IR/OperationSupport.h
+++ b/mlir/include/mlir/IR/OperationSupport.h
@@ -73,6 +73,8 @@ class OperationName {
       llvm::unique_function<void(Operation *, OpAsmPrinter &, StringRef) const>;
   using VerifyInvariantsFn =
       llvm::unique_function<LogicalResult(Operation *) const>;
+  using VerifyRegionInvariantsFn =
+      llvm::unique_function<LogicalResult(Operation *) const>;
 
 protected:
   /// This class represents a type erased version of an operation. It contains
@@ -112,6 +114,7 @@ class OperationName {
     ParseAssemblyFn parseAssemblyFn;
     PrintAssemblyFn printAssemblyFn;
     VerifyInvariantsFn verifyInvariantsFn;
+    VerifyRegionInvariantsFn verifyRegionInvariantsFn;
 
     /// A list of attribute names registered to this operation in StringAttr
     /// form. This allows for operation classes to use StringAttr for attribute
@@ -238,16 +241,18 @@ class RegisteredOperationName : public OperationName {
   static void insert(Dialect &dialect) {
     insert(T::getOperationName(), dialect, TypeID::get<T>(),
            T::getParseAssemblyFn(), T::getPrintAssemblyFn(),
-           T::getVerifyInvariantsFn(), T::getFoldHookFn(),
-           T::getGetCanonicalizationPatternsFn(), T::getInterfaceMap(),
-           T::getHasTraitFn(), T::getAttributeNames());
+           T::getVerifyInvariantsFn(), T::getVerifyRegionInvariantsFn(),
+           T::getFoldHookFn(), T::getGetCanonicalizationPatternsFn(),
+           T::getInterfaceMap(), T::getHasTraitFn(), T::getAttributeNames());
   }
   /// The use of this method is in general discouraged in favor of
   /// 'insert<CustomOp>(dialect)'.
   static void
   insert(StringRef name, Dialect &dialect, TypeID typeID,
          ParseAssemblyFn &&parseAssembly, PrintAssemblyFn &&printAssembly,
-         VerifyInvariantsFn &&verifyInvariants, FoldHookFn &&foldHook,
+         VerifyInvariantsFn &&verifyInvariants,
+         VerifyRegionInvariantsFn &&verifyRegionInvariants,
+         FoldHookFn &&foldHook,
          GetCanonicalizationPatternsFn &&getCanonicalizationPatterns,
          detail::InterfaceMap &&interfaceMap, HasTraitFn &&hasTrait,
          ArrayRef<StringRef> attrNames);
@@ -272,12 +277,15 @@ class RegisteredOperationName : public OperationName {
     return impl->printAssemblyFn(op, p, defaultDialect);
   }
 
-  /// This hook implements the verifier for this operation.  It should emits an
-  /// error message and returns failure if a problem is detected, or returns
+  /// These hooks implement the verifiers for this operation. Each should emit
+  /// an error message and return failure if a problem is detected, or return
   /// success if everything is ok.
   LogicalResult verifyInvariants(Operation *op) const {
     return impl->verifyInvariantsFn(op);
   }
+  LogicalResult verifyRegionInvariants(Operation *op) const {
+    return impl->verifyRegionInvariantsFn(op);
+  }
 
   /// This hook implements a generalized folder for this operation.  Operations
   /// can implement this to provide simplifications rules that are applied by

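The new `VerifyRegionInvariantsFn` slot is stored next to `VerifyInvariantsFn` and is filled from the op class when the op is registered. A simplified model of that type-erased registration follows; `std::function` stands in for `llvm::unique_function`, and `OpInfo`, `registry`, `insertOp`, and `LoopOp` are hypothetical names.

```cpp
// Simplified model of type-erased per-op hooks; not the MLIR classes.
#include <cassert>
#include <functional>
#include <map>
#include <string>

struct Operation {
  std::string name;
  bool regionsOk; // toy flag checked by the region hook
};

using VerifyFn = std::function<bool(Operation *)>;

// Analogue of the per-op impl record: one slot per verification phase.
struct OpInfo {
  VerifyFn verifyInvariantsFn;
  VerifyFn verifyRegionInvariantsFn;
};

std::map<std::string, OpInfo> &registry() {
  static std::map<std::string, OpInfo> ops;
  return ops;
}

// Analogue of insert<T>(dialect): pulls both static hooks out of the op class.
template <typename T>
void insertOp() {
  registry()[T::getOperationName()] =
      OpInfo{&T::verifyInvariants, &T::verifyRegionInvariants};
}

struct LoopOp {
  static std::string getOperationName() { return "toy.loop"; }
  static bool verifyInvariants(Operation *) { return true; }
  static bool verifyRegionInvariants(Operation *op) { return op->regionsOk; }
};
```

Keeping the two hooks in separate slots lets the verifier call them at different times without knowing the concrete op type.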
diff  --git a/mlir/include/mlir/TableGen/Interfaces.h b/mlir/include/mlir/TableGen/Interfaces.h
index 74a15b86385aa..9d50b26ac9a3c 100644
--- a/mlir/include/mlir/TableGen/Interfaces.h
+++ b/mlir/include/mlir/TableGen/Interfaces.h
@@ -98,6 +98,10 @@ class Interface {
   // Return the verify method body if it has one.
   llvm::Optional<StringRef> getVerify() const;
 
+  // If there's a verify method, return whether it needs to access the ops in
+  // the regions.
+  bool verifyWithRegions() const;
+
   // Returns the Tablegen definition this interface was constructed from.
   const llvm::Record &getDef() const { return *def; }
 

diff  --git a/mlir/include/mlir/TableGen/Trait.h b/mlir/include/mlir/TableGen/Trait.h
index c3d0d2af7ed44..8da5303855fee 100644
--- a/mlir/include/mlir/TableGen/Trait.h
+++ b/mlir/include/mlir/TableGen/Trait.h
@@ -65,6 +65,9 @@ class NativeTrait : public Trait {
   // Returns the trait corresponding to a C++ trait class.
   std::string getFullyQualifiedTraitName() const;
 
+  // Returns true if this is a structural op trait.
+  bool isStructuralOpTrait() const;
+
   static bool classof(const Trait *t) { return t->getKind() == Kind::Native; }
 };
 

diff  --git a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
index 2fa26232cf507..90aa7baef8e31 100644
--- a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
+++ b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
@@ -1117,7 +1117,7 @@ ParseResult AffineDmaStartOp::parse(OpAsmParser &parser,
   return success();
 }
 
-LogicalResult AffineDmaStartOp::verifyInvariants() {
+LogicalResult AffineDmaStartOp::verifyInvariantsImpl() {
   if (!getOperand(getSrcMemRefOperandIndex()).getType().isa<MemRefType>())
     return emitOpError("expected DMA source to be of memref type");
   if (!getOperand(getDstMemRefOperandIndex()).getType().isa<MemRefType>())
@@ -1219,7 +1219,7 @@ ParseResult AffineDmaWaitOp::parse(OpAsmParser &parser,
   return success();
 }
 
-LogicalResult AffineDmaWaitOp::verifyInvariants() {
+LogicalResult AffineDmaWaitOp::verifyInvariantsImpl() {
   if (!getOperand(0).getType().isa<MemRefType>())
     return emitOpError("expected DMA tag to be of memref type");
   Region *scope = getAffineScope(*this);

diff  --git a/mlir/lib/IR/MLIRContext.cpp b/mlir/lib/IR/MLIRContext.cpp
index 6c4dcaca311ef..1b6d354eff37e 100644
--- a/mlir/lib/IR/MLIRContext.cpp
+++ b/mlir/lib/IR/MLIRContext.cpp
@@ -693,7 +693,8 @@ RegisteredOperationName::parseAssembly(OpAsmParser &parser,
 void RegisteredOperationName::insert(
     StringRef name, Dialect &dialect, TypeID typeID,
     ParseAssemblyFn &&parseAssembly, PrintAssemblyFn &&printAssembly,
-    VerifyInvariantsFn &&verifyInvariants, FoldHookFn &&foldHook,
+    VerifyInvariantsFn &&verifyInvariants,
+    VerifyRegionInvariantsFn &&verifyRegionInvariants, FoldHookFn &&foldHook,
     GetCanonicalizationPatternsFn &&getCanonicalizationPatterns,
     detail::InterfaceMap &&interfaceMap, HasTraitFn &&hasTrait,
     ArrayRef<StringRef> attrNames) {
@@ -749,6 +750,7 @@ void RegisteredOperationName::insert(
   impl.parseAssemblyFn = std::move(parseAssembly);
   impl.printAssemblyFn = std::move(printAssembly);
   impl.verifyInvariantsFn = std::move(verifyInvariants);
+  impl.verifyRegionInvariantsFn = std::move(verifyRegionInvariants);
   impl.attributeNames = cachedAttrNames;
 }
 

diff  --git a/mlir/lib/IR/Verifier.cpp b/mlir/lib/IR/Verifier.cpp
index bbc560d429d79..0c8724de8cdab 100644
--- a/mlir/lib/IR/Verifier.cpp
+++ b/mlir/lib/IR/Verifier.cpp
@@ -217,6 +217,11 @@ LogicalResult OperationVerifier::verifyOperation(
     }
   }
 
+  // After the region ops are verified, run the verifiers that check additional
+  // region invariants.
+  if (registeredInfo && failed(registeredInfo->verifyRegionInvariants(&op)))
+    return failure();
+
   // If this is a registered operation, there is nothing left to do.
   if (registeredInfo)
     return success();

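The `Verifier.cpp` hunk calls `verifyRegionInvariants` only after the loop over the op's regions has finished. Under the simplifying assumption that nested ops are verified recursively (the real `OperationVerifier` is more involved), the resulting post-order can be sketched as follows; `Op` and `verifyOperation` here are hypothetical.

```cpp
// Hypothetical tree model; not mlir::OperationVerifier.
#include <cassert>
#include <string>
#include <vector>

struct Op {
  std::string name;
  std::vector<Op> nested;     // ops inside this op's regions (flattened)
  bool invariantsHold = true; // lets a caller simulate a phase-one failure
};

// Phase one on `op`, then every nested op (recursively), then phase two on
// `op`. Appends one trace entry per phase that actually runs.
bool verifyOperation(const Op &op, std::vector<std::string> &trace) {
  trace.push_back(op.name + ": invariants");
  if (!op.invariantsHold)
    return false;
  for (const Op &child : op.nested)
    if (!verifyOperation(child, trace))
      return false;
  trace.push_back(op.name + ": region invariants");
  return true;
}
```

In this model a failure inside a nested op stops verification before the parent's region phase, so a `verifyRegions()` implementation only sees region ops that already passed their own checks.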
diff  --git a/mlir/lib/TableGen/Interfaces.cpp b/mlir/lib/TableGen/Interfaces.cpp
index d26ca0b689f9b..4d72ceeb45fc9 100644
--- a/mlir/lib/TableGen/Interfaces.cpp
+++ b/mlir/lib/TableGen/Interfaces.cpp
@@ -125,6 +125,10 @@ llvm::Optional<StringRef> Interface::getVerify() const {
   return value.empty() ? llvm::Optional<StringRef>() : value;
 }
 
+bool Interface::verifyWithRegions() const {
+  return def->getValueAsBit("verifyWithRegions");
+}
+
 //===----------------------------------------------------------------------===//
 // AttrInterface
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/lib/TableGen/Trait.cpp b/mlir/lib/TableGen/Trait.cpp
index 4e28e9987e752..ee4b999c4bd47 100644
--- a/mlir/lib/TableGen/Trait.cpp
+++ b/mlir/lib/TableGen/Trait.cpp
@@ -50,6 +50,10 @@ std::string NativeTrait::getFullyQualifiedTraitName() const {
                               : (cppNamespace + "::" + trait).str();
 }
 
+bool NativeTrait::isStructuralOpTrait() const {
+  return def->isSubClassOf("StructuralOpTrait");
+}
+
 //===----------------------------------------------------------------------===//
 // InternalTrait
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/test/Dialect/Arithmetic/invalid.mlir b/mlir/test/Dialect/Arithmetic/invalid.mlir
index c0e6ebd1947e6..a4f13536a4bb8 100644
--- a/mlir/test/Dialect/Arithmetic/invalid.mlir
+++ b/mlir/test/Dialect/Arithmetic/invalid.mlir
@@ -168,7 +168,7 @@ func @func_with_ops(i32, i32) {
 func @func_with_ops() {
 ^bb0:
   %c = arith.constant dense<0> : vector<42 x i32>
-  // expected-error@+1 {{all non-scalar operands/results must have the same shape and base type}}
+  // expected-error@+1 {{op failed to verify that result type has i1 element type and same shape as operands}}
   %r = "arith.cmpi"(%c, %c) {predicate = 0} : (vector<42 x i32>, vector<42 x i32>) -> vector<41 x i1>
 }
 
@@ -249,7 +249,7 @@ func @cmpf_canonical_wrong_result_type(%a : f32, %b : f32) -> f32 {
 // -----
 
 func @cmpf_result_shape_mismatch(%a : vector<42xf32>) {
-  // expected-error@+1 {{all non-scalar operands/results must have the same shape and base type}}
+  // expected-error@+1 {{op failed to verify that result type has i1 element type and same shape as operands}}
   %r = "arith.cmpf"(%a, %a) {predicate = 0} : (vector<42 x f32>, vector<42 x f32>) -> vector<41 x i1>
 }
 
@@ -285,7 +285,7 @@ func @index_cast_index_to_index(%arg0: index) {
 // -----
 
 func @index_cast_float(%arg0: index, %arg1: f32) {
-  // expected-error@+1 {{are cast incompatible}}
+  // expected-error@+1 {{op result #0 must be signless-integer-like or memref of signless-integer, but got 'f32'}}
   %0 = arith.index_cast %arg0 : index to f32
   return
 }
@@ -293,7 +293,7 @@ func @index_cast_float(%arg0: index, %arg1: f32) {
 // -----
 
 func @index_cast_float_to_index(%arg0: f32) {
-  // expected-error@+1 {{are cast incompatible}}
+  // expected-error@+1 {{op operand #0 must be signless-integer-like or memref of signless-integer, but got 'f32'}}
   %0 = arith.index_cast %arg0 : f32 to index
   return
 }
@@ -301,7 +301,7 @@ func @index_cast_float_to_index(%arg0: f32) {
 // -----
 
 func @sitofp_i32_to_i64(%arg0 : i32) {
-  // expected-error@+1 {{are cast incompatible}}
+  // expected-error@+1 {{op result #0 must be floating-point-like, but got 'i64'}}
   %0 = arith.sitofp %arg0 : i32 to i64
   return
 }
@@ -309,7 +309,7 @@ func @sitofp_i32_to_i64(%arg0 : i32) {
 // -----
 
 func @sitofp_f32_to_i32(%arg0 : f32) {
-  // expected-error@+1 {{are cast incompatible}}
+  // expected-error@+1 {{op operand #0 must be signless-fixed-width-integer-like, but got 'f32'}}
   %0 = arith.sitofp %arg0 : f32 to i32
   return
 }
@@ -333,7 +333,7 @@ func @fpext_f16_to_f16(%arg0 : f16) {
 // -----
 
 func @fpext_i32_to_f32(%arg0 : i32) {
-  // expected-error@+1 {{are cast incompatible}}
+  // expected-error@+1 {{op operand #0 must be floating-point-like, but got 'i32'}}
   %0 = arith.extf %arg0 : i32 to f32
   return
 }
@@ -341,7 +341,7 @@ func @fpext_i32_to_f32(%arg0 : i32) {
 // -----
 
 func @fpext_f32_to_i32(%arg0 : f32) {
-  // expected-error@+1 {{are cast incompatible}}
+  // expected-error@+1 {{op result #0 must be floating-point-like, but got 'i32'}}
   %0 = arith.extf %arg0 : f32 to i32
   return
 }
@@ -373,7 +373,7 @@ func @fpext_vec_f16_to_f16(%arg0 : vector<2xf16>) {
 // -----
 
 func @fpext_vec_i32_to_f32(%arg0 : vector<2xi32>) {
-  // expected-error@+1 {{are cast incompatible}}
+  // expected-error@+1 {{op operand #0 must be floating-point-like, but got 'vector<2xi32>'}}
   %0 = arith.extf %arg0 : vector<2xi32> to vector<2xf32>
   return
 }
@@ -381,7 +381,7 @@ func @fpext_vec_i32_to_f32(%arg0 : vector<2xi32>) {
 // -----
 
 func @fpext_vec_f32_to_i32(%arg0 : vector<2xf32>) {
-  // expected-error@+1 {{are cast incompatible}}
+  // expected-error@+1 {{op result #0 must be floating-point-like, but got 'vector<2xi32>'}}
   %0 = arith.extf %arg0 : vector<2xf32> to vector<2xi32>
   return
 }
@@ -405,7 +405,7 @@ func @fptrunc_f32_to_f32(%arg0 : f32) {
 // -----
 
 func @fptrunc_i32_to_f32(%arg0 : i32) {
-  // expected-error@+1 {{are cast incompatible}}
+  // expected-error@+1 {{op operand #0 must be floating-point-like, but got 'i32'}}
   %0 = arith.truncf %arg0 : i32 to f32
   return
 }
@@ -413,7 +413,7 @@ func @fptrunc_i32_to_f32(%arg0 : i32) {
 // -----
 
 func @fptrunc_f32_to_i32(%arg0 : f32) {
-  // expected-error@+1 {{are cast incompatible}}
+  // expected-error@+1 {{op result #0 must be floating-point-like, but got 'i32'}}
   %0 = arith.truncf %arg0 : f32 to i32
   return
 }
@@ -445,7 +445,7 @@ func @fptrunc_vec_f32_to_f32(%arg0 : vector<2xf32>) {
 // -----
 
 func @fptrunc_vec_i32_to_f32(%arg0 : vector<2xi32>) {
-  // expected-error@+1 {{are cast incompatible}}
+  // expected-error@+1 {{op operand #0 must be floating-point-like, but got 'vector<2xi32>'}}
   %0 = arith.truncf %arg0 : vector<2xi32> to vector<2xf32>
   return
 }
@@ -453,7 +453,7 @@ func @fptrunc_vec_i32_to_f32(%arg0 : vector<2xi32>) {
 // -----
 
 func @fptrunc_vec_f32_to_i32(%arg0 : vector<2xf32>) {
-  // expected-error@+1 {{are cast incompatible}}
+  // expected-error@+1 {{op result #0 must be floating-point-like, but got 'vector<2xi32>'}}
   %0 = arith.truncf %arg0 : vector<2xf32> to vector<2xi32>
   return
 }
@@ -461,7 +461,7 @@ func @fptrunc_vec_f32_to_i32(%arg0 : vector<2xf32>) {
 // -----
 
 func @sexti_index_as_operand(%arg0 : index) {
-  // expected-error@+1 {{are cast incompatible}}
+  // expected-error@+1 {{op operand #0 must be signless-fixed-width-integer-like, but got 'index'}}
   %0 = arith.extsi %arg0 : index to i128
   return
 }
@@ -469,7 +469,7 @@ func @sexti_index_as_operand(%arg0 : index) {
 // -----
 
 func @zexti_index_as_operand(%arg0 : index) {
-  // expected-error@+1 {{operand type 'index' and result type}}
+  // expected-error@+1 {{op operand #0 must be signless-fixed-width-integer-like, but got 'index'}}
   %0 = arith.extui %arg0 : index to i128
   return
 }
@@ -477,7 +477,7 @@ func @zexti_index_as_operand(%arg0 : index) {
 // -----
 
 func @trunci_index_as_operand(%arg0 : index) {
-  // expected-error@+1 {{operand type 'index' and result type}}
+  // expected-error@+1 {{op operand #0 must be signless-fixed-width-integer-like, but got 'index'}}
   %2 = arith.trunci %arg0 : index to i128
   return
 }
@@ -485,7 +485,7 @@ func @trunci_index_as_operand(%arg0 : index) {
 // -----
 
 func @sexti_index_as_result(%arg0 : i1) {
-  // expected-error@+1 {{result type 'index' are cast incompatible}}
+  // expected-error@+1 {{op result #0 must be signless-fixed-width-integer-like, but got 'index'}}
   %0 = arith.extsi %arg0 : i1 to index
   return
 }
@@ -493,7 +493,7 @@ func @sexti_index_as_result(%arg0 : i1) {
 // -----
 
 func @zexti_index_as_operand(%arg0 : i1) {
-  // expected-error@+1 {{result type 'index' are cast incompatible}}
+  // expected-error@+1 {{op result #0 must be signless-fixed-width-integer-like, but got 'index'}}
   %0 = arith.extui %arg0 : i1 to index
   return
 }
@@ -501,7 +501,7 @@ func @zexti_index_as_operand(%arg0 : i1) {
 // -----
 
 func @trunci_index_as_result(%arg0 : i128) {
-  // expected-error@+1 {{result type 'index' are cast incompatible}}
+  // expected-error@+1 {{op result #0 must be signless-fixed-width-integer-like, but got 'index'}}
   %2 = arith.trunci %arg0 : i128 to index
   return
 }

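The arith cast expectations above now report ODS operand/result type constraints instead of the older cast-compatibility text. That is consistent with the order given in the commit log: the ODS-generated `verifyInvariants` (stage 2) runs before later verifier stages, and the first failure is the one reported. A minimal, self-contained C++ model of that first-failure-wins chain follows; `Stage`, `runVerifiers`, and `verifySitofpModel` are hypothetical names for illustration only and are not MLIR APIs.

```cpp
#include <functional>
#include <optional>
#include <string>
#include <vector>

// Hypothetical stand-in for one verifier stage: returns an error message on
// failure, or std::nullopt on success. Not an MLIR type.
using Stage = std::function<std::optional<std::string>()>;

// Run the stages in order and report only the first failure; later stages
// are never reached once an earlier one fails.
std::optional<std::string> runVerifiers(const std::vector<Stage> &stages) {
  for (const Stage &stage : stages)
    if (std::optional<std::string> error = stage())
      return error;
  return std::nullopt;
}

// Toy model of `arith.sitofp %arg0 : i32 to i64` (hypothetical helper): the
// ODS result constraint is checked before the cast-compatibility check.
std::optional<std::string> verifySitofpModel(const std::string &resultType,
                                             bool resultIsFloatLike) {
  std::vector<Stage> stages = {
      // ODS-generated invariants (stage 2 in the commit log).
      [&]() -> std::optional<std::string> {
        if (!resultIsFloatLike)
          return "op result #0 must be floating-point-like, but got '" +
                 resultType + "'";
        return std::nullopt;
      },
      // A later stage performing the cast-compatibility check (message
      // shortened); it would also reject this cast but is not reached.
      [&]() -> std::optional<std::string> {
        if (!resultIsFloatLike)
          return std::string("are cast incompatible");
        return std::nullopt;
      },
  };
  return runVerifiers(stages);
}
```

With an `i64` result, `verifySitofpModel("i64", false)` yields the ODS constraint message, which is the text the updated test now expects.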
diff  --git a/mlir/test/Dialect/GPU/invalid.mlir b/mlir/test/Dialect/GPU/invalid.mlir
index 006e1188de3b3..cdaef0196929c 100644
--- a/mlir/test/Dialect/GPU/invalid.mlir
+++ b/mlir/test/Dialect/GPU/invalid.mlir
@@ -301,7 +301,7 @@ func @reduce_incorrect_yield(%arg0 : f32) {
 // -----
 
 func @shuffle_mismatching_type(%arg0 : f32, %arg1 : i32, %arg2 : i32) {
-  // expected-error at +1 {{inferred type(s) 'f32', 'i1' are incompatible with return type(s) of operation 'i32', 'i1'}}
+  // expected-error at +1 {{op failed to verify that all of {value, result} have same type}}
   %shfl, %pred = "gpu.shuffle"(%arg0, %arg1, %arg2) { mode = #gpu<"shuffle_mode xor"> } : (f32, i32, i32) -> (i32, i1)
   return
 }

diff  --git a/mlir/test/Dialect/LLVMIR/global.mlir b/mlir/test/Dialect/LLVMIR/global.mlir
index b831f3eb7c012..84f1d9e93584e 100644
--- a/mlir/test/Dialect/LLVMIR/global.mlir
+++ b/mlir/test/Dialect/LLVMIR/global.mlir
@@ -80,8 +80,8 @@ llvm.mlir.global internal constant @sectionvar("teststring")  {section = ".mysec
 
 // -----
 
-// expected-error @+1 {{requires string attribute 'sym_name'}}
-"llvm.mlir.global"() ({}) {type = i64, constant, value = 42 : i64} : () -> ()
+// expected-error @+1 {{op requires attribute 'sym_name'}}
+"llvm.mlir.global"() ({}) {type = i64, constant, global_type = i64, value = 42 : i64} : () -> ()
 
 // -----
 

diff  --git a/mlir/test/Dialect/Linalg/invalid.mlir b/mlir/test/Dialect/Linalg/invalid.mlir
index 40defe47a1cb6..c45c58abf91c8 100644
--- a/mlir/test/Dialect/Linalg/invalid.mlir
+++ b/mlir/test/Dialect/Linalg/invalid.mlir
@@ -214,15 +214,15 @@ func @generic_shaped_operand_block_arg_type(%arg0: memref<f32>) {
 
 // -----
 
-func @generic_scalar_operand_block_arg_type(%arg0: f32) {
+func @generic_scalar_operand_block_arg_type(%arg0: tensor<f32>) {
   // expected-error @+1 {{expected type of bb argument #0 ('i1') to match element or self type of the corresponding operand ('f32')}}
   linalg.generic {
     indexing_maps =  [ affine_map<() -> ()> ],
     iterator_types = []}
-      outs(%arg0 : f32) {
+      outs(%arg0 : tensor<f32>) {
     ^bb(%i: i1):
     linalg.yield %i : i1
-  }
+  } -> tensor<f32>
 }
 
 // -----
@@ -243,7 +243,7 @@ func @generic_result_0_element_type(%arg0: memref<?xf32, affine_map<(i)[off]->(o
 
 func @generic_result_tensor_type(%arg0: memref<?xf32, affine_map<(i)[off]->(off + i)>>,
                                  %arg1: tensor<?xf32>) {
-  // expected-error @+1 {{expected type of operand #1 ('tensor<?xf32>') to match type of corresponding result ('f32')}}
+  // expected-error @+1 {{expected type of operand #1 ('tensor<?xf32>') to match type of corresponding result ('tensor<f32>')}}
   %0 = linalg.generic {
     indexing_maps = [ affine_map<(i) -> (i)> , affine_map<(i) -> (i)> ],
     iterator_types = ["parallel"]}
@@ -251,7 +251,7 @@ func @generic_result_tensor_type(%arg0: memref<?xf32, affine_map<(i)[off]->(off
       outs(%arg1 : tensor<?xf32>) {
     ^bb(%i: f32, %j: f32):
       linalg.yield %i: f32
-  } -> f32
+  } -> tensor<f32>
 }
 
 // -----
@@ -362,11 +362,11 @@ func @illegal_fill_tensor_no_return(%arg0 : index, %arg1 : index, %arg2 : f32)
 
 // -----
 
-func @illegal_fill_memref_with_return(%arg0 : memref<?x?xf32>, %arg1 : f32) -> memref<?x?xf32>
+func @illegal_fill_memref_with_return(%arg0 : memref<?x?xf32>, %arg1 : f32) -> tensor<?x?xf32>
 {
-  // expected-error @+1 {{expected the number of results (1) to be equal to the number of output tensors (0)}}
-  %0 = linalg.fill(%arg1, %arg0) : f32, memref<?x?xf32> -> memref<?x?xf32>
-  return %0 : memref<?x?xf32>
+  // expected-error @+1 {{op expected the number of results (1) to be equal to the number of output tensors (0)}}
+  %0 = linalg.fill(%arg1, %arg0) : f32, memref<?x?xf32> -> tensor<?x?xf32>
+  return %0 : tensor<?x?xf32>
 }
 
 // -----
@@ -384,7 +384,7 @@ func @illegal_fill_memref_with_tensor_return
 func @illegal_fill_tensor_with_memref_return
   (%arg0 : tensor<?x?xf32>, %arg1 : f32) -> memref<?x?xf32>
 {
-  // expected-error @+1 {{expected type of operand #1 ('tensor<?x?xf32>') to match type of corresponding result ('memref<?x?xf32>')}}
+  // expected-error @+1 {{op result #0 must be ranked tensor of any type values, but got 'memref<?x?xf32>'}}
   %0 = linalg.fill(%arg1, %arg0) : f32, tensor<?x?xf32> -> memref<?x?xf32>
   return %0 : memref<?x?xf32>
 }
@@ -477,7 +477,7 @@ func @tiled_loop_incorrent_iterator_types_count(%A: memref<192x192xf32>,
   %c0 = arith.constant 0 : index
   %c192 = arith.constant 192 : index
   // expected-error @+1 {{expected iterator types array attribute size = 1 to match the number of loops = 2}}
-  %0 = "linalg.tiled_loop"(%c0, %c0, %c192, %c192, %c24, %c24, %A, %B, %C_tensor, %C) ({
+  %0 = "linalg.tiled_loop"(%c0, %c0, %c192, %c192, %c24, %c24, %A, %B, %C_tensor, %C) ( {
     ^bb0(%arg4: index, %arg5: index, %A_: memref<192x192xf32>,
          %B_: memref<192x192xf32>, %CT_: tensor<192x192xf32>,
          %C_: memref<192x192xf32>):
@@ -502,7 +502,7 @@ func @tiled_loop_incorrent_block_arg_type(%A: memref<192xf32>) {
   %c192 = arith.constant 192 : index
   %c24 = arith.constant 24 : index
   // expected-error @+1 {{expected output arg 0 with type = 'memref<192xf32>' to match region arg 1 type = 'memref<100xf32>'}}
-  "linalg.tiled_loop"(%c0, %c192, %c24, %A) ({
+  "linalg.tiled_loop"(%c0, %c192, %c24, %A) ( {
     ^bb0(%arg4: index, %A_: memref<100xf32>):
       call @foo(%A_) : (memref<100xf32>)-> ()
       linalg.yield

diff  --git a/mlir/test/Dialect/Linalg/named-ops.mlir b/mlir/test/Dialect/Linalg/named-ops.mlir
index 8de70c2ce8758..a06ca5f801278 100644
--- a/mlir/test/Dialect/Linalg/named-ops.mlir
+++ b/mlir/test/Dialect/Linalg/named-ops.mlir
@@ -111,7 +111,7 @@ func @depthwise_conv_2d_input_nhwc_filter_default_attributes(%input: memref<1x11
 // -----
 
 func @depthwise_conv_2d_input_nhwc_filter_wrong_stride_element_type(%input: memref<1x113x113x96xf32>, %filter: memref<3x3x96xf32>, %output: memref<1x56x56x96xf32>) {
-  // expected-error @+1 {{incorrect element type for index attribute 'strides'}}
+  // expected-error @+1 {{op attribute 'strides' failed to satisfy constraint: 64-bit signless int elements attribute of shape [2]}}
   linalg.depthwise_conv_2d_nhwc_hwc {dilations = dense<1> : vector<2xi64>, strides = dense<2.0> : vector<2xf32>}
     ins(%input, %filter: memref<1x113x113x96xf32>, memref<3x3x96xf32>)
     outs(%output: memref<1x56x56x96xf32>)
@@ -121,7 +121,7 @@ func @depthwise_conv_2d_input_nhwc_filter_wrong_stride_element_type(%input: memr
 // -----
 
 func @depthwise_conv_2d_input_nhwc_filter_wrong_stride_size(%input: memref<1x113x113x96xf32>, %filter: memref<3x3x96xf32>, %output: memref<1x56x56x96xf32>) {
-  // expected-error @+1 {{incorrect shape for index attribute 'strides'}}
+  // expected-error @+1 {{op attribute 'strides' failed to satisfy constraint: 64-bit signless int elements attribute of shape [2]}}
   linalg.depthwise_conv_2d_nhwc_hwc {dilations = dense<1> : vector<2xi64>, strides = dense<2> : vector<3xi64> }
     ins(%input, %filter: memref<1x113x113x96xf32>, memref<3x3x96xf32>)
     outs(%output: memref<1x56x56x96xf32>)

diff  --git a/mlir/test/Dialect/SPIRV/IR/bit-ops.mlir b/mlir/test/Dialect/SPIRV/IR/bit-ops.mlir
index cd775b4d945fa..607d4b1f3a983 100644
--- a/mlir/test/Dialect/SPIRV/IR/bit-ops.mlir
+++ b/mlir/test/Dialect/SPIRV/IR/bit-ops.mlir
@@ -59,7 +59,7 @@ func @bit_field_u_extract_vec(%base: vector<3xi32>, %offset: i8, %count: i8) ->
 // -----
 
 func @bit_field_u_extract_invalid_result_type(%base: vector<3xi32>, %offset: i32, %count: i16) -> vector<4xi32> {
-  // expected-error @+1 {{inferred type(s) 'vector<3xi32>' are incompatible with return type(s) of operation 'vector<4xi32>'}}
+  // expected-error @+1 {{failed to verify that all of {base, result} have same type}}
   %0 = "spv.BitFieldUExtract" (%base, %offset, %count) : (vector<3xi32>, i32, i16) -> vector<4xi32>
   spv.ReturnValue %0 : vector<4xi32>
 }
@@ -181,7 +181,7 @@ func @shift_left_logical(%arg0: i32, %arg1 : i16) -> i32 {
 // -----
 
 func @shift_left_logical_invalid_result_type(%arg0: i32, %arg1 : i16) -> i16 {
-  // expected-error @+1 {{op inferred type(s) 'i32' are incompatible with return type(s) of operation 'i16'}}
+  // expected-error @+1 {{op failed to verify that all of {operand1, result} have same type}}
   %0 = "spv.ShiftLeftLogical" (%arg0, %arg1) : (i32, i16) -> (i16)
   spv.ReturnValue %0 : i16
 }

diff  --git a/mlir/test/Dialect/Shape/invalid.mlir b/mlir/test/Dialect/Shape/invalid.mlir
index cdbe57a81e182..648786e88c6eb 100644
--- a/mlir/test/Dialect/Shape/invalid.mlir
+++ b/mlir/test/Dialect/Shape/invalid.mlir
@@ -98,8 +98,8 @@ func @shape_of(%value_arg : !shape.value_shape,
 // -----
 
 func @shape_of_incompatible_return_types(%value_arg : tensor<1x2xindex>) {
-  // expected-error@+1 {{'shape.shape_of' op inferred type(s) 'tensor<2xindex>' are incompatible with return type(s) of operation 'tensor<3xf32>'}}
-  %0 = shape.shape_of %value_arg : tensor<1x2xindex> -> tensor<3xf32>
+  // expected-error@+1 {{'shape.shape_of' op inferred type(s) 'tensor<2xindex>' are incompatible with return type(s) of operation 'tensor<3xindex>'}}
+  %0 = shape.shape_of %value_arg : tensor<1x2xindex> -> tensor<3xindex>
   return
 }
 

diff  --git a/mlir/test/Dialect/traits.mlir b/mlir/test/Dialect/traits.mlir
index daf09ebd79a2e..74bbcf7f01661 100644
--- a/mlir/test/Dialect/traits.mlir
+++ b/mlir/test/Dialect/traits.mlir
@@ -58,7 +58,7 @@ func @broadcast_tensor_tensor_tensor(tensor<8x1x?x1xi32>, tensor<7x1x5xi32>) ->
 // Check incompatible vector and tensor result type
 func @broadcast_scalar_vector_vector(tensor<4xf32>, tensor<4xf32>) -> vector<4xf32> {
 ^bb0(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>):
-  // expected-error @+1 {{cannot broadcast vector with tensor}}
+  // expected-error @+1 {{op result #0 must be tensor of any type values, but got 'vector<4xf32>'}}
   %0 = "test.broadcastable"(%arg0, %arg1) : (tensor<4xf32>, tensor<4xf32>) -> vector<4xf32>
   return %0 : vector<4xf32>
 }

diff  --git a/mlir/test/IR/invalid-module-op.mlir b/mlir/test/IR/invalid-module-op.mlir
index 8caace93dc7b0..8db29bfd165ca 100644
--- a/mlir/test/IR/invalid-module-op.mlir
+++ b/mlir/test/IR/invalid-module-op.mlir
@@ -3,7 +3,7 @@
 // -----
 
 func @module_op() {
-  // expected-error@+1 {{Operations with a 'SymbolTable' must have exactly one block}}
+  // expected-error@+1 {{'builtin.module' op expects region #0 to have 0 or 1 blocks}}
   builtin.module {
   ^bb1:
     "test.dummy"() : () -> ()

diff  --git a/mlir/test/IR/traits.mlir b/mlir/test/IR/traits.mlir
index e6283b52caa52..7432069ee13dc 100644
--- a/mlir/test/IR/traits.mlir
+++ b/mlir/test/IR/traits.mlir
@@ -332,12 +332,12 @@ func @failedSingleBlockImplicitTerminator_missing_terminator() {
 
 // Test the invariants of operations with the Symbol Trait.
 
-// expected-error@+1 {{requires string attribute 'sym_name'}}
+// expected-error@+1 {{op requires attribute 'sym_name'}}
 "test.symbol"() {} : () -> ()
 
 // -----
 
-// expected-error@+1 {{requires visibility attribute 'sym_visibility' to be a string attribute}}
+// expected-error@+1 {{op attribute 'sym_visibility' failed to satisfy constraint: string attribute}}
 "test.symbol"() {sym_name = "foo_2", sym_visibility} : () -> ()
 
 // -----
@@ -364,7 +364,7 @@ func private @foo()
 // -----
 
 // Test that operation with the SymbolTable Trait fails with  too many blocks.
-// expected-error@+1 {{Operations with a 'SymbolTable' must have exactly one block}}
+// expected-error@+1 {{op expects region #0 to have 0 or 1 blocks}}
 "test.symbol_scope"() ({
   ^entry:
     "test.finish" () : () -> ()
@@ -668,4 +668,4 @@ func @failed_attr_traits() {
   // expected-error@+1 {{'attr' attribute should have trait 'TestAttrTrait'}}
   "test.attr_with_trait"() {attr = 42 : i32} : () -> ()
   return
-}
\ No newline at end of file
+}

diff  --git a/mlir/test/mlir-tblgen/op-decl-and-defs.td b/mlir/test/mlir-tblgen/op-decl-and-defs.td
index 2c97422830f53..047a237db1278 100644
--- a/mlir/test/mlir-tblgen/op-decl-and-defs.td
+++ b/mlir/test/mlir-tblgen/op-decl-and-defs.td
@@ -68,7 +68,7 @@ def NS_AOp : NS_Op<"a_op", [IsolatedFromAbove, IsolatedFromAbove]> {
 // CHECK:   ::mlir::ValueRange odsOperands;
 // CHECK: };
 
-// CHECK: class AOp : public ::mlir::Op<AOp, ::mlir::OpTrait::AtLeastNRegions<1>::Impl, ::mlir::OpTrait::AtLeastNResults<1>::Impl, ::mlir::OpTrait::ZeroSuccessor, ::mlir::OpTrait::AtLeastNOperands<1>::Impl, ::mlir::OpTrait::IsIsolatedFromAbove
+// CHECK: class AOp : public ::mlir::Op<AOp, ::mlir::OpTrait::AtLeastNRegions<1>::Impl, ::mlir::OpTrait::AtLeastNResults<1>::Impl, ::mlir::OpTrait::ZeroSuccessor, ::mlir::OpTrait::AtLeastNOperands<1>::Impl, ::mlir::OpTrait::OpInvariants, ::mlir::OpTrait::IsIsolatedFromAbove
 // CHECK-NOT: ::mlir::OpTrait::IsIsolatedFromAbove
 // CHECK: public:
 // CHECK:   using Op::Op;

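The updated CHECK line expects `::mlir::OpTrait::OpInvariants` to sit after the region/result/successor/operand count traits and before `IsIsolatedFromAbove`. The sketch below models how the patched `genTraits()` (in the `OpDefinitionsGen.cpp` part of this diff) partitions native traits around `OpInvariants`: structural native traits first, then `OpInvariants`, then the remaining native and interface traits. `TraitInfo` and `orderTraits` are hypothetical stand-ins for the TableGen classes, and treating the count traits as a pre-built prefix is an assumption drawn from the CHECK line.

```cpp
#include <string>
#include <vector>

// Hypothetical, simplified view of a trait attached to an op definition.
struct TraitInfo {
  std::string name;
  bool isInterface = false;  // interface trait vs. native trait
  bool isStructural = false; // only meaningful for native traits
};

// Order traits the way the patched genTraits() does. `countTraits` stands for
// the region/result/successor/operand traits that precede these loops.
std::vector<std::string> orderTraits(const std::vector<std::string> &countTraits,
                                     const std::vector<TraitInfo> &traits) {
  std::vector<std::string> ordered(countTraits);
  // 1. Structural native traits, so they are verified earliest.
  for (const TraitInfo &t : traits)
    if (!t.isInterface && t.isStructural)
      ordered.push_back(t.name);
  // 2. The wrapper around the ODS-generated verifyInvariants().
  ordered.push_back("::mlir::OpTrait::OpInvariants");
  // 3. Everything else: non-structural native traits and interface traits.
  for (const TraitInfo &t : traits)
    if (t.isInterface || !t.isStructural)
      ordered.push_back(t.name);
  return ordered;
}
```

Passing `IsIsolatedFromAbove` as a non-structural native trait places it after `OpInvariants`, matching the expected class declaration.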
diff  --git a/mlir/test/mlir-tblgen/op-interface.td b/mlir/test/mlir-tblgen/op-interface.td
index 64392c33f036c..51403fb286a8e 100644
--- a/mlir/test/mlir-tblgen/op-interface.td
+++ b/mlir/test/mlir-tblgen/op-interface.td
@@ -42,6 +42,19 @@ def TestOpInterface : OpInterface<"TestOpInterface"> {
   ];
 }
 
+def TestOpInterfaceVerify : OpInterface<"TestOpInterfaceVerify"> {
+  let verify = [{
+    return foo();
+  }];
+}
+
+def TestOpInterfaceVerifyRegion : OpInterface<"TestOpInterfaceVerifyRegion"> {
+  let verify = [{
+    return foo();
+  }];
+  let verifyWithRegions = 1;
+}
+
 // Define Ops with TestOpInterface and
 // DeclareOpInterfaceMethods<TestOpInterface> traits to check that there
 // are not duplicated C++ classes generated.
@@ -65,6 +78,12 @@ def DeclareMethodsWithDefaultOp : Op<TestDialect, "declare_methods_op",
 // DECL: template<typename ConcreteOp>
 // DECL: int detail::TestOpInterfaceInterfaceTraits::Model<ConcreteOp>::foo
 
+// DECL-LABEL: struct TestOpInterfaceVerifyTrait
+// DECL: verifyTrait
+
+// DECL-LABEL: struct TestOpInterfaceVerifyRegionTrait
+// DECL: verifyRegionTrait
+
 // OP_DECL-LABEL: class DeclareMethodsOp : public
 // OP_DECL: int foo(int input);
 // OP_DECL-NOT: int default_foo(int input);

diff  --git a/mlir/test/mlir-tblgen/types.mlir b/mlir/test/mlir-tblgen/types.mlir
index 33ad65e780fad..a8dbf60ad44ae 100644
--- a/mlir/test/mlir-tblgen/types.mlir
+++ b/mlir/test/mlir-tblgen/types.mlir
@@ -58,7 +58,7 @@ func @complex_f64_tensor_success() {
 // -----
 
 func @complex_f64_failure() {
-  // expected-error@+1 {{op inferred type(s) 'complex<f64>' are incompatible with return type(s) of operation 'f64'}}
+  // expected-error@+1 {{op result #0 must be complex type with 64-bit float elements, but got 'f64'}}
   "test.complex_f64"() : () -> (f64)
   return
 }
@@ -438,7 +438,7 @@ func @operand_rank_equals_result_size_failure(%arg : tensor<1x2x3x4xi32>) {
 // -----
 
 func @same_types_element_mismatch(%arg0: tensor<* x i32>, %arg1: tensor<* x f32>) {
-  // expected-error@+1 {{op inferred type(s) 'tensor<*xi32>' are incompatible with return type(s) of operation 'tensor<*xf32>'}}
+  // expected-error@+1 {{op failed to verify that all of {x, res} have same type}}
   "test.operand0_and_result_have_same_type"(%arg0, %arg1) : (tensor<* x i32>, tensor<* x f32>) -> tensor<* x f32>
   return
 }
@@ -446,7 +446,7 @@ func @same_types_element_mismatch(%arg0: tensor<* x i32>, %arg1: tensor<* x f32>
 // -----
 
 func @same_types_shape_mismatch(%arg0: tensor<1x2xi32>, %arg1: tensor<2x1xi32>) {
-  // expected-error@+1 {{op inferred type(s) 'tensor<1x2xi32>' are incompatible with return type(s) of operation 'tensor<2x1xi32>'}}
+  // expected-error@+1 {{op failed to verify that all of {x, res} have same type}}
   "test.operand0_and_result_have_same_type"(%arg0, %arg1) : (tensor<1x2xi32>, tensor<2x1xi32>) -> tensor<2x1xi32>
   return
 }

diff  --git a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
index 6a45e1f7750ff..2c56c92ae1272 100644
--- a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
+++ b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
@@ -394,6 +394,9 @@ class OpEmitter {
   // Generates verify method for the operation.
   void genVerifier();
 
+  // Generates custom verify methods for the operation.
+  void genCustomVerifier();
+
   // Generates verify statements for operands and results in the operation.
   // The generated code will be attached to `body`.
   void genOperandResultVerifier(MethodBody &body,
@@ -593,6 +596,7 @@ OpEmitter::OpEmitter(const Operator &op,
   genParser();
   genPrinter();
   genVerifier();
+  genCustomVerifier();
   genCanonicalizerDecls();
   genFolderDecls();
   genTypeInterfaceMethods();
@@ -2236,47 +2240,76 @@ static void genNativeTraitAttrVerifier(MethodBody &body,
 }
 
 void OpEmitter::genVerifier() {
-  auto *method = opClass.addMethod("::mlir::LogicalResult", "verifyInvariants");
-  ERROR_IF_PRUNED(method, "verifyInvariants", op);
-  auto &body = method->body();
+  auto *implMethod =
+      opClass.addMethod("::mlir::LogicalResult", "verifyInvariantsImpl");
+  ERROR_IF_PRUNED(implMethod, "verifyInvariantsImpl", op);
+  auto &implBody = implMethod->body();
 
   OpOrAdaptorHelper emitHelper(op, /*isOp=*/true);
-  genNativeTraitAttrVerifier(body, emitHelper);
+  genNativeTraitAttrVerifier(implBody, emitHelper);
 
-  auto *valueInit = def.getValueInit("verifier");
-  StringInit *stringInit = dyn_cast<StringInit>(valueInit);
-  bool hasCustomVerifyCodeBlock = stringInit && !stringInit->getValue().empty();
   populateSubstitutions(emitHelper, verifyCtx);
 
-  genAttributeVerifier(emitHelper, verifyCtx, body, staticVerifierEmitter);
-  genOperandResultVerifier(body, op.getOperands(), "operand");
-  genOperandResultVerifier(body, op.getResults(), "result");
+  genAttributeVerifier(emitHelper, verifyCtx, implBody, staticVerifierEmitter);
+  genOperandResultVerifier(implBody, op.getOperands(), "operand");
+  genOperandResultVerifier(implBody, op.getResults(), "result");
 
   for (auto &trait : op.getTraits()) {
     if (auto *t = dyn_cast<tblgen::PredTrait>(&trait)) {
-      body << tgfmt("  if (!($0))\n    "
-                    "return emitOpError(\"failed to verify that $1\");\n",
-                    &verifyCtx, tgfmt(t->getPredTemplate(), &verifyCtx),
-                    t->getSummary());
+      implBody << tgfmt("  if (!($0))\n    "
+                        "return emitOpError(\"failed to verify that $1\");\n",
+                        &verifyCtx, tgfmt(t->getPredTemplate(), &verifyCtx),
+                        t->getSummary());
     }
   }
 
-  genRegionVerifier(body);
-  genSuccessorVerifier(body);
+  genRegionVerifier(implBody);
+  genSuccessorVerifier(implBody);
+
+  implBody << "  return ::mlir::success();\n";
+
+  // TODO: Some places use the `verifyInvariants` to do operation verification.
+  // This may not act as their expectation because this doesn't call any
+  // verifiers of native/interface traits. Needs to review those use cases and
+  // see if we should use the mlir::verify() instead.
+  auto *valueInit = def.getValueInit("verifier");
+  StringInit *stringInit = dyn_cast<StringInit>(valueInit);
+  bool hasCustomVerifyCodeBlock = stringInit && !stringInit->getValue().empty();
+
+  auto *method = opClass.addMethod("::mlir::LogicalResult", "verifyInvariants");
+  ERROR_IF_PRUNED(method, "verifyInvariants", op);
+  auto &body = method->body();
+  if (hasCustomVerifyCodeBlock || def.getValueAsBit("hasVerifier")) {
+    body << "  if(::mlir::succeeded(verifyInvariantsImpl()) && "
+            "::mlir::succeeded(verify()))\n";
+    body << "    return ::mlir::success();\n";
+    body << "  return ::mlir::failure();";
+  } else {
+    body << "  return verifyInvariantsImpl();";
+  }
+}
+
+void OpEmitter::genCustomVerifier() {
+  auto *valueInit = def.getValueInit("verifier");
+  StringInit *stringInit = dyn_cast<StringInit>(valueInit);
+  bool hasCustomVerifyCodeBlock = stringInit && !stringInit->getValue().empty();
 
   if (def.getValueAsBit("hasVerifier")) {
-    auto *method = opClass.declareMethod<Method::Private>(
-        "::mlir::LogicalResult", "verify");
+    auto *method = opClass.declareMethod("::mlir::LogicalResult", "verify");
     ERROR_IF_PRUNED(method, "verify", op);
-    body << "  return verify();\n";
-
+  } else if (def.getValueAsBit("hasRegionVerifier")) {
+    auto *method =
+        opClass.declareMethod("::mlir::LogicalResult", "verifyRegions");
+    ERROR_IF_PRUNED(method, "verifyRegions", op);
   } else if (hasCustomVerifyCodeBlock) {
+    auto *method = opClass.addMethod("::mlir::LogicalResult", "verify");
+    ERROR_IF_PRUNED(method, "verify", op);
+    auto &body = method->body();
+
     FmtContext fctx;
     fctx.addSubst("cppClass", opClass.getClassName());
     auto printer = stringInit->getValue().ltrim().rtrim(" \t\v\f\r");
     body << "  " << tgfmt(printer, &fctx);
-  } else {
-    body << "  return ::mlir::success();\n";
   }
 }
 
@@ -2508,12 +2541,27 @@ void OpEmitter::genTraits() {
     }
   }
 
+  // Internally defined (structural) op traits are added first so that they
+  // are verified earlier.
+  for (const auto &trait : op.getTraits()) {
+    if (auto *opTrait = dyn_cast<tblgen::NativeTrait>(&trait)) {
+      if (opTrait->isStructuralOpTrait())
+        opClass.addTrait(opTrait->getFullyQualifiedTraitName());
+    }
+  }
+
+  // OpInvariants wraps verifyInvariants, which needs to run before the
+  // native/interface traits and after all traits marked `StructuralOpTrait`.
+  opClass.addTrait("::mlir::OpTrait::OpInvariants");
+
   // Add the native and interface traits.
   for (const auto &trait : op.getTraits()) {
-    if (auto *opTrait = dyn_cast<tblgen::NativeTrait>(&trait))
-      opClass.addTrait(opTrait->getFullyQualifiedTraitName());
-    else if (auto *opTrait = dyn_cast<tblgen::InterfaceTrait>(&trait))
+    if (auto *opTrait = dyn_cast<tblgen::NativeTrait>(&trait)) {
+      if (!opTrait->isStructuralOpTrait())
+        opClass.addTrait(opTrait->getFullyQualifiedTraitName());
+    } else if (auto *opTrait = dyn_cast<tblgen::InterfaceTrait>(&trait)) {
       opClass.addTrait(opTrait->getFullyQualifiedTraitName());
+    }
   }
 }
 

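After this change, `genVerifier()` always emits `verifyInvariantsImpl()` with the ODS checks and a thin `verifyInvariants()` wrapper that chains the custom `verify()` only when the op sets `hasVerifier = 1` or still has a legacy `verifier` code block. The sketch below reproduces that final string-building branch in plain C++ so the two possible bodies can be compared; `emitVerifyInvariantsBody` is a hypothetical name, and a `std::string` replaces MLIR's `MethodBody`.

```cpp
#include <string>

// Hypothetical re-creation of the branch at the end of OpEmitter::genVerifier():
// returns the C++ body emitted for an op's `verifyInvariants()` method.
std::string emitVerifyInvariantsBody(bool hasVerifier,
                                     bool hasCustomVerifyCodeBlock) {
  std::string body;
  if (hasCustomVerifyCodeBlock || hasVerifier) {
    // ODS invariants first; because && short-circuits, the custom verify()
    // runs only if verifyInvariantsImpl() succeeded.
    body += "  if(::mlir::succeeded(verifyInvariantsImpl()) && "
            "::mlir::succeeded(verify()))\n";
    body += "    return ::mlir::success();\n";
    body += "  return ::mlir::failure();";
  } else {
    // No custom verifier: the wrapper simply forwards to the ODS checks.
    body += "  return verifyInvariantsImpl();";
  }
  return body;
}
```

Note that `hasRegionVerifier` does not affect this wrapper: per `genCustomVerifier()`, it only declares `verifyRegions()`, which belongs to the later, region-aware phase.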
diff  --git a/mlir/tools/mlir-tblgen/OpInterfacesGen.cpp b/mlir/tools/mlir-tblgen/OpInterfacesGen.cpp
index 0513c5e7b1ab2..76a5164eba3c8 100644
--- a/mlir/tools/mlir-tblgen/OpInterfacesGen.cpp
+++ b/mlir/tools/mlir-tblgen/OpInterfacesGen.cpp
@@ -413,9 +413,12 @@ void InterfaceGenerator::emitTraitDecl(const Interface &interface,
 
     tblgen::FmtContext verifyCtx;
     verifyCtx.withOp("op");
-    os << "    static ::mlir::LogicalResult verifyTrait(::mlir::Operation *op) "
-          "{\n      "
-       << tblgen::tgfmt(verify->trim(), &verifyCtx) << "\n    }\n";
+    os << llvm::formatv(
+              "    static ::mlir::LogicalResult {0}(::mlir::Operation *op) ",
+              (interface.verifyWithRegions() ? "verifyRegionTrait"
+                                             : "verifyTrait"))
+       << "{\n      " << tblgen::tgfmt(verify->trim(), &verifyCtx)
+       << "\n    }\n";
   }
   if (auto extraTraitDecls = interface.getExtraTraitClassDeclaration())
     os << tblgen::tgfmt(*extraTraitDecls, &traitMethodFmt) << "\n";

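With `verifyWithRegions = 1`, `OpInterfacesGen` now names the interface's static hook `verifyRegionTrait` instead of `verifyTrait`, which places it in the second phase described in the commit log: that phase runs only after the operations nested in the op's regions have been verified. A minimal sketch of that traversal order on a toy op tree follows. `ToyOp` and `verifyOp` are hypothetical and model only the ordering; they are not MLIR's verifier implementation.

```cpp
#include <functional>
#include <string>
#include <vector>

// Hypothetical toy operation: a name, nested ops (standing in for regions),
// and one hook per verification phase.
struct ToyOp {
  std::string name;
  std::vector<ToyOp> nested;
  std::function<bool(const ToyOp &)> verifyOpPhase;     // commit-log steps 1-4
  std::function<bool(const ToyOp &)> verifyRegionPhase; // commit-log steps 5-6
};

// Verify `op`: run its first phase, then every nested op, then its region
// phase. `log` records the visiting order so it can be inspected.
bool verifyOp(const ToyOp &op, std::vector<std::string> &log) {
  log.push_back(op.name + ":op");
  if (op.verifyOpPhase && !op.verifyOpPhase(op))
    return false;
  for (const ToyOp &child : op.nested)
    if (!verifyOp(child, log))
      return false;
  log.push_back(op.name + ":regions");
  if (op.verifyRegionPhase && !op.verifyRegionPhase(op))
    return false;
  return true;
}
```

Because the parent's region phase is logged after `leaf:op` and `leaf:regions`, a region verifier can assume the nested ops already passed their own checks, which is how the commit log says duplicate verification can be avoided.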