[Mlir-commits] [mlir] b0921f6 - [mlir] Add verify method to adaptor

Jacques Pienaar llvmlistbot at llvm.org
Fri Jun 5 09:47:53 PDT 2020


Author: Jacques Pienaar
Date: 2020-06-05T09:47:37-07:00
New Revision: b0921f68e1eeb3ac0cf4e178014237e14c20be03

URL: https://github.com/llvm/llvm-project/commit/b0921f68e1eeb3ac0cf4e178014237e14c20be03
DIFF: https://github.com/llvm/llvm-project/commit/b0921f68e1eeb3ac0cf4e178014237e14c20be03.diff

LOG: [mlir] Add verify method to adaptor

This allows verifying op-independent attributes (i.e., attributes whose checks do not require the op to have been created) before constructing an operation. These checks include whether required attributes are defined and whether attributes satisfy their constraints (such as being an I32 attribute). This is not perfect: for example, given a disjunctive constraint where one part relies on the op and the other does not, this change does not try to separate the op-independent part from the op-dependent part, so the whole constraint is only verified once the op exists.
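To make the op-independent checks concrete, here is a minimal, self-contained C++ model (not MLIR code). `ToyAdaptor`, its `std::map` attribute storage and the `toy.op`/`alignment` names are hypothetical. The point is only that, like the generated `OperandAdaptor::verify(Location)`, it sees nothing but the attribute dictionary, yet can still check that a required attribute is present and that it satisfies a value constraint.

```cpp
#include <cstdint>
#include <limits>
#include <map>
#include <optional>
#include <string>

// Hypothetical stand-in for a generated op adaptor. It holds only the
// attribute dictionary (modeled as name -> integer), never an operation.
struct ToyAdaptor {
  std::map<std::string, int64_t> attrs;

  // Op-independent verification: returns an error message on failure and
  // std::nullopt on success. Covers the two kinds of checks that can run
  // before the op exists: attribute presence and attribute constraints.
  std::optional<std::string> verify() const {
    auto it = attrs.find("alignment");
    if (it == attrs.end())
      return "'toy.op' op requires attribute 'alignment'";
    int64_t value = it->second;
    // Constraint: positive and representable as a 32-bit integer.
    if (value <= 0 || value > std::numeric_limits<int32_t>::max())
      return "'toy.op' op attribute 'alignment' failed to satisfy "
             "constraint: positive 32-bit integer";
    return std::nullopt;
  }
};
```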

The next step is to move these checks out into a trait that could be verified earlier than in the generated `verify` method. The first use case is inferring the return type while constructing the op. At that point there is no Operation yet, so one ends up duplicating the same checks, e.g., verifying that attribute A is defined before querying A in the shape function. Instead, this allows one to invoke a single method to verify all the traits; if it is also run first during op verification, then all other traits can use the attributes knowing they have already been verified.
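A minimal sketch of that use case, again as a standalone C++ model: `ToyAdaptor`, `inferResultType`, the `size` attribute and the string-valued result type are all hypothetical, not MLIR API. The shape function calls the adaptor's `verify()` once up front and can then read the attribute without repeating the presence check.

```cpp
#include <cstdint>
#include <map>
#include <optional>
#include <string>

// Hypothetical adaptor over the attributes of an op that is not built yet.
struct ToyAdaptor {
  const std::map<std::string, int64_t> &attrs;

  // Op-independent checks: 'size' must be present and positive.
  bool verify(std::string &error) const {
    auto it = attrs.find("size");
    if (it == attrs.end()) {
      error = "requires attribute 'size'";
      return false;
    }
    if (it->second <= 0) {
      error = "attribute 'size' failed to satisfy constraint: positive integer";
      return false;
    }
    return true;
  }
};

// Hypothetical return-type inference run while the op is being constructed.
// Once the adaptor has verified 'size', attrs.at("size") cannot throw, so the
// body does not duplicate the presence check.
std::optional<std::string>
inferResultType(const std::map<std::string, int64_t> &attrs,
                std::string &error) {
  if (!ToyAdaptor{attrs}.verify(error))
    return std::nullopt;
  return "vector<" + std::to_string(attrs.at("size")) + "xf32>";
}
```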

It is a little unusual to have these on the adaptor, but I see the adaptor as the place to collect information about the op before the op is constructed (e.g., avoiding stringly typed accessors, verifying what can be verified before the op exists), while remaining cheap to use with an already constructed op (i.e., a layer of indirection between the op constructed and the op being constructed). From that point of view it made sense to me.
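In the `genVerifier` change in the diff, the op's generated `verify()` now starts by constructing the adaptor from the op (`<AdaptorName>(*this)`) and calling `verify(this->getLoc())`, and only then emits the op-dependent attribute checks. A plain C++ model of that layering might look like the following; the `ToyAdaptor`/`ToyOp` types and the `num_outputs` rule are hypothetical. Results are used for the op-dependent check because the adaptor in this commit has no access to results.

```cpp
#include <cstddef>
#include <cstdint>
#include <map>
#include <string>

// Hypothetical adaptor: a cheap, non-owning view over an attribute
// dictionary. It can be built before an op exists or from an existing op.
struct ToyAdaptor {
  const std::map<std::string, int64_t> &attrs;

  // Op-independent check: the required attribute must be present.
  bool verify(std::string &error) const {
    if (attrs.count("num_outputs") == 0) {
      error = "requires attribute 'num_outputs'";
      return false;
    }
    return true;
  }
};

// Hypothetical constructed op. ToyAdaptor knows nothing about results, so
// result-based checks must live here.
struct ToyOp {
  std::map<std::string, int64_t> attrs;
  std::size_t numResults = 0;

  bool verify(std::string &error) const {
    // 1. Delegate the op-independent checks to the adaptor first.
    if (!ToyAdaptor{attrs}.verify(error))
      return false;
    // 2. Op-dependent check; at() is safe because step 1 succeeded.
    if (attrs.at("num_outputs") != static_cast<int64_t>(numResults)) {
      error = "'num_outputs' does not match the number of results";
      return false;
    }
    return true;
  }
};
```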

Differential Revision: https://reviews.llvm.org/D80842

Added: 
    

Modified: 
    mlir/docs/OpDefinitions.md
    mlir/test/Dialect/GPU/invalid.mlir
    mlir/test/Dialect/LLVMIR/global.mlir
    mlir/test/Dialect/SPIRV/composite-ops.mlir
    mlir/test/Dialect/Vector/invalid.mlir
    mlir/test/IR/invalid-ops.mlir
    mlir/test/mlir-tblgen/op-attribute.td
    mlir/test/mlir-tblgen/predicate.td
    mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/docs/OpDefinitions.md b/mlir/docs/OpDefinitions.md
index 42c431d13f8e..0c0b08509ff4 100644
--- a/mlir/docs/OpDefinitions.md
+++ b/mlir/docs/OpDefinitions.md
@@ -626,7 +626,8 @@ let verifier = [{
 ```
 
 Code placed in `verifier` will be called after the auto-generated verification
-code.
+code. Apart from that, the order in which traits are verified should not be
+relied upon.
 
 ### Declarative Assembly Format
 

diff  --git a/mlir/test/Dialect/GPU/invalid.mlir b/mlir/test/Dialect/GPU/invalid.mlir
index b0cc4dd7a6eb..36b2ee9b5a8a 100644
--- a/mlir/test/Dialect/GPU/invalid.mlir
+++ b/mlir/test/Dialect/GPU/invalid.mlir
@@ -254,7 +254,7 @@ func @reduce_op_and_body(%arg0 : f32) {
 // -----
 
 func @reduce_invalid_op(%arg0 : f32) {
-  // expected-error@+1 {{gpu.all_reduce' op attribute 'op' failed to satisfy constraint}}
+  // expected-error@+1 {{attribute 'op' failed to satisfy constraint}}
   %res = "gpu.all_reduce"(%arg0) ({}) {op = "foo"} : (f32) -> (f32)
   return
 }
@@ -321,14 +321,14 @@ func @reduce_incorrect_yield(%arg0 : f32) {
 // -----
 
 func @shuffle_mismatching_type(%arg0 : f32, %arg1 : i32, %arg2 : i32) {
-  // expected-error@+1 {{'gpu.shuffle' op requires the same type for value operand and result}}
+  // expected-error@+1 {{requires the same type for value operand and result}}
   %shfl, %pred = "gpu.shuffle"(%arg0, %arg1, %arg2) { mode = "xor" } : (f32, i32, i32) -> (i32, i1)
 }
 
 // -----
 
 func @shuffle_unsupported_type(%arg0 : index, %arg1 : i32, %arg2 : i32) {
-  // expected-error@+1 {{'gpu.shuffle' op requires value operand type to be f32 or i32}}
+  // expected-error@+1 {{requires value operand type to be f32 or i32}}
   %shfl, %pred = gpu.shuffle %arg0, %arg1, %arg2 xor : index
 }
 

diff  --git a/mlir/test/Dialect/LLVMIR/global.mlir b/mlir/test/Dialect/LLVMIR/global.mlir
index 0b97a8ebb1e5..b5b5639a5bd9 100644
--- a/mlir/test/Dialect/LLVMIR/global.mlir
+++ b/mlir/test/Dialect/LLVMIR/global.mlir
@@ -65,12 +65,12 @@ func @references() {
 
 // -----
 
-// expected-error @+1 {{op requires string attribute 'sym_name'}}
+// expected-error @+1 {{requires string attribute 'sym_name'}}
 "llvm.mlir.global"() ({}) {type = !llvm.i64, constant, value = 42 : i64} : () -> ()
 
 // -----
 
-// expected-error @+1 {{op requires attribute 'type'}}
+// expected-error @+1 {{requires attribute 'type'}}
 "llvm.mlir.global"() ({}) {sym_name = "foo", constant, value = 42 : i64} : () -> ()
 
 // -----

diff  --git a/mlir/test/Dialect/SPIRV/composite-ops.mlir b/mlir/test/Dialect/SPIRV/composite-ops.mlir
index ca3f60311576..04153162e0dc 100644
--- a/mlir/test/Dialect/SPIRV/composite-ops.mlir
+++ b/mlir/test/Dialect/SPIRV/composite-ops.mlir
@@ -124,7 +124,7 @@ func @composite_extract_invalid_index_type_1() -> () {
 // -----
 
 func @composite_extract_invalid_index_type_2(%arg0 : !spv.array<4x!spv.array<4xf32>>) -> () {
-  // expected-error @+1 {{op attribute 'indices' failed to satisfy constraint: 32-bit integer array attribute}}
+  // expected-error @+1 {{attribute 'indices' failed to satisfy constraint: 32-bit integer array attribute}}
   %0 = spv.CompositeExtract %arg0[1] : !spv.array<4x!spv.array<4xf32>>
   return
 }

diff  --git a/mlir/test/Dialect/Vector/invalid.mlir b/mlir/test/Dialect/Vector/invalid.mlir
index 1f6da8190bae..52d0586e98f2 100644
--- a/mlir/test/Dialect/Vector/invalid.mlir
+++ b/mlir/test/Dialect/Vector/invalid.mlir
@@ -1069,7 +1069,7 @@ func @reduce_elt_type_mismatch(%arg0: vector<16xf32>) -> i32 {
 // -----
 
 func @reduce_unsupported_attr(%arg0: vector<16xf32>) -> i32 {
-  // expected-error@+1 {{'vector.reduction' op attribute 'kind' failed to satisfy constraint: string attribute}}
+  // expected-error@+1 {{attribute 'kind' failed to satisfy constraint: string attribute}}
   %0 = vector.reduction 1234, %arg0 : vector<16xf32> into i32
 }
 

diff  --git a/mlir/test/IR/invalid-ops.mlir b/mlir/test/IR/invalid-ops.mlir
index c8e40c520139..1ccf322ee8b5 100644
--- a/mlir/test/IR/invalid-ops.mlir
+++ b/mlir/test/IR/invalid-ops.mlir
@@ -58,7 +58,7 @@ func @constant_wrong_type() {
 func @affine_apply_no_map() {
 ^bb0:
   %i = constant 0 : index
-  %x = "affine.apply" (%i) { } : (index) -> (index) //  expected-error {{'affine.apply' op requires attribute 'map'}}
+  %x = "affine.apply" (%i) { } : (index) -> (index) //  expected-error {{requires attribute 'map'}}
   return
 }
 
@@ -1205,7 +1205,7 @@ func @assume_alignment(%0: memref<4x4xf16>) {
 
 // 0 alignment value.
 func @assume_alignment(%0: memref<4x4xf16>) {
-  // expected-error@+1 {{'std.assume_alignment' op attribute 'alignment' failed to satisfy constraint: 32-bit signless integer attribute whose value is positive}}
+  // expected-error@+1 {{attribute 'alignment' failed to satisfy constraint: 32-bit signless integer attribute whose value is positive}}
   std.assume_alignment %0, 0 : memref<4x4xf16>
   return
 }

diff  --git a/mlir/test/mlir-tblgen/op-attribute.td b/mlir/test/mlir-tblgen/op-attribute.td
index 522dc2459fca..b4c850269a1d 100644
--- a/mlir/test/mlir-tblgen/op-attribute.td
+++ b/mlir/test/mlir-tblgen/op-attribute.td
@@ -30,6 +30,20 @@ def AOp : NS_Op<"a_op", []> {
 
 // DEF-LABEL: AOp definitions
 
+// Test verify method
+// ---
+
+// DEF:      LogicalResult AOpOperandAdaptor::verify
+// DEF:      auto tblgen_aAttr = odsAttrs.get("aAttr");
+// DEF-NEXT: if (!tblgen_aAttr) return emitError(loc, "'test.a_op' op ""requires attribute 'aAttr'");
+// DEF:        if (!((some-condition))) return emitError(loc, "'test.a_op' op ""attribute 'aAttr' failed to satisfy constraint: some attribute kind");
+// DEF:      auto tblgen_bAttr = odsAttrs.get("bAttr");
+// DEF-NEXT: if (tblgen_bAttr) {
+// DEF-NEXT:   if (!((some-condition))) return emitError(loc, "'test.a_op' op ""attribute 'bAttr' failed to satisfy constraint: some attribute kind");
+// DEF:      auto tblgen_cAttr = odsAttrs.get("cAttr");
+// DEF-NEXT: if (tblgen_cAttr) {
+// DEF-NEXT:   if (!((some-condition))) return emitError(loc, "'test.a_op' op ""attribute 'cAttr' failed to satisfy constraint: some attribute kind");
+
 // Test getter methods
 // ---
 
@@ -80,20 +94,6 @@ def AOp : NS_Op<"a_op", []> {
 // DEF:        ArrayRef<NamedAttribute> attributes
 // DEF:      odsState.addAttributes(attributes);
 
-// Test verify method
-// ---
-
-// DEF:      AOp::verify()
-// DEF:      auto tblgen_aAttr = this->getAttr("aAttr");
-// DEF-NEXT: if (!tblgen_aAttr) return emitOpError("requires attribute 'aAttr'");
-// DEF:        if (!((some-condition))) return emitOpError("attribute 'aAttr' failed to satisfy constraint: some attribute kind");
-// DEF:      auto tblgen_bAttr = this->getAttr("bAttr");
-// DEF-NEXT: if (tblgen_bAttr) {
-// DEF-NEXT:   if (!((some-condition))) return emitOpError("attribute 'bAttr' failed to satisfy constraint: some attribute kind");
-// DEF:      auto tblgen_cAttr = this->getAttr("cAttr");
-// DEF-NEXT: if (tblgen_cAttr) {
-// DEF-NEXT:   if (!((some-condition))) return emitOpError("attribute 'cAttr' failed to satisfy constraint: some attribute kind");
-
 def SomeTypeAttr : TypeAttrBase<"SomeType", "some type attribute">;
 
 def BOp : NS_Op<"b_op", []> {
@@ -114,27 +114,11 @@ def BOp : NS_Op<"b_op", []> {
   );
 }
 
-// Test common attribute kind getters' return types
-// ---
-
-// DEF: Attribute BOp::any_attr()
-// DEF: bool BOp::bool_attr()
-// DEF: APInt BOp::i32_attr()
-// DEF: APInt BOp::i64_attr()
-// DEF: APFloat BOp::f32_attr()
-// DEF: APFloat BOp::f64_attr()
-// DEF: StringRef BOp::str_attr()
-// DEF: ElementsAttr BOp::elements_attr()
-// DEF: StringRef BOp::function_attr()
-// DEF: SomeType BOp::type_attr()
-// DEF: ArrayAttr BOp::array_attr()
-// DEF: ArrayAttr BOp::some_attr_array()
-// DEF: Type BOp::type_attr()
 
 // Test common attribute kinds' constraints
 // ---
 
-// DEF-LABEL: BOp::verify
+// DEF-LABEL: BOpOperandAdaptor::verify
 // DEF: if (!((true)))
 // DEF: if (!((tblgen_bool_attr.isa<BoolAttr>())))
 // DEF: if (!(((tblgen_i32_attr.isa<IntegerAttr>())) && ((tblgen_i32_attr.cast<IntegerAttr>().getType().isSignlessInteger(32)))))
@@ -149,6 +133,23 @@ def BOp : NS_Op<"b_op", []> {
 // DEF: if (!(((tblgen_some_attr_array.isa<ArrayAttr>())) && (llvm::all_of(tblgen_some_attr_array.cast<ArrayAttr>(), [](Attribute attr) { return (some-condition); }))))
 // DEF: if (!(((tblgen_type_attr.isa<TypeAttr>())) && ((tblgen_type_attr.cast<TypeAttr>().getValue().isa<Type>()))))
 
+// Test common attribute kind getters' return types
+// ---
+
+// DEF: Attribute BOp::any_attr()
+// DEF: bool BOp::bool_attr()
+// DEF: APInt BOp::i32_attr()
+// DEF: APInt BOp::i64_attr()
+// DEF: APFloat BOp::f32_attr()
+// DEF: APFloat BOp::f64_attr()
+// DEF: StringRef BOp::str_attr()
+// DEF: ElementsAttr BOp::elements_attr()
+// DEF: StringRef BOp::function_attr()
+// DEF: SomeType BOp::type_attr()
+// DEF: ArrayAttr BOp::array_attr()
+// DEF: ArrayAttr BOp::some_attr_array()
+// DEF: Type BOp::type_attr()
+
 // Test building constant values for array attribute kinds
 // ---
 

diff  --git a/mlir/test/mlir-tblgen/predicate.td b/mlir/test/mlir-tblgen/predicate.td
index aa7b50710cde..a617208d157a 100644
--- a/mlir/test/mlir-tblgen/predicate.td
+++ b/mlir/test/mlir-tblgen/predicate.td
@@ -1,4 +1,4 @@
-// RUN: mlir-tblgen -gen-op-defs -I %S/../../include %s | FileCheck %s
+// RUN: mlir-tblgen -gen-op-defs -I %S/../../include %s | FileCheck %s --dump-input-on-failure
 
 include "mlir/IR/OpBase.td"
 
@@ -32,41 +32,41 @@ def OpF : NS_Op<"op_for_int_min_val", []> {
   let arguments = (ins Confined<I32Attr, [IntMinValue<10>]>:$attr);
 }
 
-// CHECK-LABEL: OpF::verify()
+// CHECK-LABEL: OpFOperandAdaptor::verify
 // CHECK:       (tblgen_attr.cast<IntegerAttr>().getInt() >= 10)
-// CHECK-SAME:    return emitOpError("attribute 'attr' failed to satisfy constraint: 32-bit signless integer attribute whose minimum value is 10");
+// CHECK-SAME:  "attribute 'attr' failed to satisfy constraint: 32-bit signless integer attribute whose minimum value is 10"
 
 def OpFX : NS_Op<"op_for_int_max_val", []> {
   let arguments = (ins Confined<I32Attr, [IntMaxValue<10>]>:$attr);
 }
 
-// CHECK-LABEL: OpFX::verify()
+// CHECK-LABEL: OpFXOperandAdaptor::verify
 // CHECK:       (tblgen_attr.cast<IntegerAttr>().getInt() <= 10)
-// CHECK-SAME:    return emitOpError("attribute 'attr' failed to satisfy constraint: 32-bit signless integer attribute whose maximum value is 10");
+// CHECK-SAME:  "attribute 'attr' failed to satisfy constraint: 32-bit signless integer attribute whose maximum value is 10"
 
 def OpG : NS_Op<"op_for_arr_min_count", []> {
   let arguments = (ins Confined<ArrayAttr, [ArrayMinCount<8>]>:$attr);
 }
 
-// CHECK-LABEL: OpG::verify()
+// CHECK-LABEL: OpGOperandAdaptor::verify
 // CHECK:       (tblgen_attr.cast<ArrayAttr>().size() >= 8)
-// CHECK-SAME:    return emitOpError("attribute 'attr' failed to satisfy constraint: array attribute with at least 8 elements");
+// CHECK-SAME:  "attribute 'attr' failed to satisfy constraint: array attribute with at least 8 elements"
 
 def OpH : NS_Op<"op_for_arr_value_at_index", []> {
   let arguments = (ins Confined<ArrayAttr, [IntArrayNthElemEq<0, 8>]>:$attr);
 }
 
-// CHECK-LABEL: OpH::verify()
+// CHECK-LABEL: OpHOperandAdaptor::verify
 // CHECK: (((tblgen_attr.cast<ArrayAttr>().size() > 0)) && ((tblgen_attr.cast<ArrayAttr>()[0].cast<IntegerAttr>().getInt() == 8)))))
-// CHECK-SAME:    return emitOpError("attribute 'attr' failed to satisfy constraint: array attribute whose 0-th element must be 8");
+// CHECK-SAME:  "attribute 'attr' failed to satisfy constraint: array attribute whose 0-th element must be 8"
 
 def OpI: NS_Op<"op_for_arr_min_value_at_index", []> {
   let arguments = (ins Confined<ArrayAttr, [IntArrayNthElemMinValue<0, 8>]>:$attr);
 }
 
-// CHECK-LABEL: OpI::verify()
+// CHECK-LABEL: OpIOperandAdaptor::verify
 // CHECK: (((tblgen_attr.cast<ArrayAttr>().size() > 0)) && ((tblgen_attr.cast<ArrayAttr>()[0].cast<IntegerAttr>().getInt() >= 8)))))
-// CHECK-SAME:    return emitOpError("attribute 'attr' failed to satisfy constraint: array attribute whose 0-th element must be at least 8");
+// CHECK-SAME: "attribute 'attr' failed to satisfy constraint: array attribute whose 0-th element must be at least 8"
 
 def OpJ: NS_Op<"op_for_TCopVTEtAreSameAt", [
                 PredOpTrait<"operands indexed at 0, 2, 3 should all have "
@@ -80,11 +80,11 @@ def OpJ: NS_Op<"op_for_TCopVTEtAreSameAt", [
   );
 }
 
-// CHECK-LABEL: OpJ::verify()
+// CHECK-LABEL: OpJOperandAdaptor::verify
 // CHECK:      llvm::is_splat(llvm::map_range(
 // CHECK-SAME:   llvm::ArrayRef<unsigned>({0, 2, 3}),
 // CHECK-SAME:   [this](unsigned i) { return getElementTypeOrSelf(this->getOperand(i)); }))
-// CHECK:   return emitOpError("failed to verify that operands indexed at 0, 2, 3 should all have the same type");
+// CHECK: "failed to verify that operands indexed at 0, 2, 3 should all have the same type"
 
 def OpK : NS_Op<"op_for_AnyTensorOf", []> {
   let arguments = (ins TensorOf<[F32, I32]>:$x);

diff  --git a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
index 7b0cd9d7a482..21dccd4f3d5a 100644
--- a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
+++ b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
@@ -321,6 +321,116 @@ class OpEmitter {
 };
 } // end anonymous namespace
 
+// Populate the format context `ctx` with substitutions of attributes, operands
+// and results.
+// - attrGet corresponds to the name of the function to call to get value of
+//   attribute (the generated function call returns an Attribute);
+// - operandGet corresponds to the name of the function with which to retrieve
+//   an operand (the generated function call returns an OperandRange);
+// - resultGet corresponds to the name of the function to get a result (the
+//   generated function call returns a ValueRange);
+static void populateSubstitutions(const Operator &op, const char *attrGet,
+                                  const char *operandGet, const char *resultGet,
+                                  FmtContext &ctx) {
+  // Populate substitutions for attributes and named operands.
+  for (const auto &namedAttr : op.getAttributes())
+    ctx.addSubst(namedAttr.name,
+                 formatv("{0}(\"{1}\")", attrGet, namedAttr.name));
+  for (int i = 0, e = op.getNumOperands(); i < e; ++i) {
+    auto &value = op.getOperand(i);
+    if (value.name.empty())
+      continue;
+
+    if (value.isVariadic())
+      ctx.addSubst(value.name, formatv("{0}({1})", operandGet, i));
+    else
+      ctx.addSubst(value.name, formatv("(*{0}({1}).begin())", operandGet, i));
+  }
+
+  // Populate substitutions for results.
+  for (int i = 0, e = op.getNumResults(); i < e; ++i) {
+    auto &value = op.getResult(i);
+    if (value.name.empty())
+      continue;
+
+    if (value.isVariadic())
+      ctx.addSubst(value.name, formatv("{0}({1})", resultGet, i));
+    else
+      ctx.addSubst(value.name, formatv("(*{0}({1}).begin())", resultGet, i));
+  }
+}
+
+// Generate attribute verification. If emitVerificationRequiringOp is set, then
+// only checks whose predicates depend on the op being known are emitted;
+// otherwise, only checks that do not depend on the op (including the presence
+// of required attributes) are emitted.
+// - emitErrorPrefix is the prefix of the error-emitting call, consisting of
+//   the entire function call up to the start of the error message fragment;
+// - emitVerificationRequiringOp specifies whether to emit the checks that
+//   require the op to exist (true) or those that do not (false);
+static void genAttributeVerifier(const Operator &op, const char *attrGet,
+                                 const Twine &emitErrorPrefix,
+                                 bool emitVerificationRequiringOp,
+                                 FmtContext &ctx, OpMethodBody &body) {
+  for (const auto &namedAttr : op.getAttributes()) {
+    const auto &attr = namedAttr.attr;
+    if (attr.isDerivedAttr())
+      continue;
+
+    auto attrName = namedAttr.name;
+    bool allowMissingAttr = attr.hasDefaultValue() || attr.isOptional();
+    auto attrPred = attr.getPredicate();
+    auto condition = attrPred.isNull() ? "" : attrPred.getCondition();
+    // There is a condition to emit only if whether it uses $_op matches
+    // whether op-dependent verifications are being emitted.
+    bool hasConditionToEmit = (!(condition.find("$_op") != StringRef::npos) ^
+                               emitVerificationRequiringOp);
+
+    // Prefix with `tblgen_` to avoid hiding the attribute accessor.
+    auto varName = tblgenNamePrefix + attrName;
+
+    // If the attribute is
+    //  1. Required (not allowed missing) and not in op verification, or
+    //  2. Has a condition that will get verified
+    // then the variable will be used.
+    //
+    // Otherwise (the attribute may be missing, or this is the op verifier,
+    // and no condition is emitted in this pass) there is nothing to verify
+    // here for this attribute.
+    if ((allowMissingAttr || emitVerificationRequiringOp) &&
+        !hasConditionToEmit)
+      continue;
+
+    body << formatv("  {\n  auto {0} = {1}(\"{2}\");\n", varName, attrGet,
+                    attrName);
+
+    if (!emitVerificationRequiringOp && !allowMissingAttr) {
+      body << "  if (!" << varName << ") return " << emitErrorPrefix
+           << "\"requires attribute '" << attrName << "'\");\n";
+    }
+
+    if (!hasConditionToEmit) {
+      body << "  }\n";
+      continue;
+    }
+
+    if (allowMissingAttr) {
+      // If the attribute has a default value, then only verify the predicate if
+      // set. This does effectively assume that the default value is valid.
+      // TODO: verify the default value is valid (perhaps in debug mode only).
+      body << "  if (" << varName << ") {\n";
+    }
+
+    body << tgfmt("    if (!($0)) return $1\"attribute '$2' "
+                  "failed to satisfy constraint: $3\");\n",
+                  /*ctx=*/nullptr, tgfmt(condition, &ctx.withSelf(varName)),
+                  emitErrorPrefix, attrName, attr.getDescription());
+    if (allowMissingAttr)
+      body << "  }\n";
+    body << "  }\n";
+  }
+}
+
 OpEmitter::OpEmitter(const Operator &op)
     : def(op.getDef()), op(op),
       opClass(op.getCppClassName(), op.getExtraClassDeclaration()) {
@@ -1512,110 +1622,27 @@ void OpEmitter::genPrinter() {
 }
 
 void OpEmitter::genVerifier() {
-  auto valueInit = def.getValueInit("verifier");
-  CodeInit *codeInit = dyn_cast<CodeInit>(valueInit);
-  bool hasCustomVerify = codeInit && !codeInit->getValue().empty();
-
   auto &method = opClass.newMethod("LogicalResult", "verify", /*params=*/"");
   auto &body = method.body();
+  body << "  if (failed(" << op.getAdaptorName()
+       << "(*this).verify(this->getLoc()))) "
+       << "return failure();\n";
 
-  const char *checkAttrSizedValueSegmentsCode = R"(
-  {
-    auto sizeAttr = getAttrOfType<DenseIntElementsAttr>("{0}");
-    auto numElements = sizeAttr.getType().cast<ShapedType>().getNumElements();
-    if (numElements != {1}) {{
-      return emitOpError("'{0}' attribute for specifying {2} segments "
-                         "must have {1} elements");
-    }
-  }
-  )";
-
-  // Verify a few traits first so that we can use
-  // getODSOperands()/getODSResults() in the rest of the verifier.
-  for (auto &trait : op.getTraits()) {
-    if (auto *t = dyn_cast<tblgen::NativeOpTrait>(&trait)) {
-      if (t->getTrait() == "OpTrait::AttrSizedOperandSegments") {
-        body << formatv(checkAttrSizedValueSegmentsCode,
-                        "operand_segment_sizes", op.getNumOperands(),
-                        "operand");
-      } else if (t->getTrait() == "OpTrait::AttrSizedResultSegments") {
-        body << formatv(checkAttrSizedValueSegmentsCode, "result_segment_sizes",
-                        op.getNumResults(), "result");
-      }
-    }
-  }
-
-  // Populate substitutions for attributes and named operands and results.
-  for (const auto &namedAttr : op.getAttributes())
-    verifyCtx.addSubst(namedAttr.name,
-                       formatv("this->getAttr(\"{0}\")", namedAttr.name));
-  for (int i = 0, e = op.getNumOperands(); i < e; ++i) {
-    auto &value = op.getOperand(i);
-    if (value.name.empty())
-      continue;
-
-    if (value.isVariadic())
-      verifyCtx.addSubst(value.name, formatv("this->getODSOperands({0})", i));
-    else
-      verifyCtx.addSubst(value.name,
-                         formatv("(*this->getODSOperands({0}).begin())", i));
-  }
-  for (int i = 0, e = op.getNumResults(); i < e; ++i) {
-    auto &value = op.getResult(i);
-    if (value.name.empty())
-      continue;
-
-    if (value.isVariadic())
-      verifyCtx.addSubst(value.name, formatv("this->getODSResults({0})", i));
-    else
-      verifyCtx.addSubst(value.name,
-                         formatv("(*this->getODSResults({0}).begin())", i));
-  }
-
-  // Verify the attributes have the correct type.
-  for (const auto &namedAttr : op.getAttributes()) {
-    const auto &attr = namedAttr.attr;
-    if (attr.isDerivedAttr())
-      continue;
-
-    auto attrName = namedAttr.name;
-    // Prefix with `tblgen_` to avoid hiding the attribute accessor.
-    auto varName = tblgenNamePrefix + attrName;
-    body << formatv("  auto {0} = this->getAttr(\"{1}\");\n", varName,
-                    attrName);
-
-    bool allowMissingAttr = attr.hasDefaultValue() || attr.isOptional();
-    if (allowMissingAttr) {
-      // If the attribute has a default value, then only verify the predicate if
-      // set. This does effectively assume that the default value is valid.
-      // TODO: verify the debug value is valid (perhaps in debug mode only).
-      body << "  if (" << varName << ") {\n";
-    } else {
-      body << "  if (!" << varName
-           << ") return emitOpError(\"requires attribute '" << attrName
-           << "'\");\n  {\n";
-    }
-
-    auto attrPred = attr.getPredicate();
-    if (!attrPred.isNull()) {
-      body << tgfmt(
-          "    if (!($0)) return emitOpError(\"attribute '$1' "
-          "failed to satisfy constraint: $2\");\n",
-          /*ctx=*/nullptr,
-          tgfmt(attrPred.getCondition(), &verifyCtx.withSelf(varName)),
-          attrName, attr.getDescription());
-    }
-
-    body << "  }\n";
-  }
+  auto *valueInit = def.getValueInit("verifier");
+  CodeInit *codeInit = dyn_cast<CodeInit>(valueInit);
+  bool hasCustomVerify = codeInit && !codeInit->getValue().empty();
+  populateSubstitutions(op, "this->getAttr", "this->getODSOperands",
+                        "this->getODSResults", verifyCtx);
 
+  genAttributeVerifier(op, "this->getAttr", "emitOpError(",
+                       /*emitVerificationRequiringOp=*/true, verifyCtx, body);
   genOperandResultVerifier(body, op.getOperands(), "operand");
   genOperandResultVerifier(body, op.getResults(), "result");
 
   for (auto &trait : op.getTraits()) {
     if (auto *t = dyn_cast<tblgen::PredOpTrait>(&trait)) {
-      body << tgfmt("  if (!($0)) {\n    "
-                    "return emitOpError(\"failed to verify that $1\");\n  }\n",
+      body << tgfmt("  if (!($0))\n    "
+                    "return emitOpError(\"failed to verify that $1\");\n",
                     &verifyCtx, tgfmt(t->getPredTemplate(), &verifyCtx),
                     t->getDescription());
     }
@@ -1890,12 +1917,17 @@ class OpOperandAdaptorEmitter {
 private:
   explicit OpOperandAdaptorEmitter(const Operator &op);
 
+  // Add verification function. This generates a verify method for the adaptor
+  // which verifies all the op-independent attribute constraints.
+  void addVerification();
+
+  const Operator &op;
   Class adaptor;
 };
 } // end namespace
 
 OpOperandAdaptorEmitter::OpOperandAdaptorEmitter(const Operator &op)
-    : adaptor(op.getAdaptorName()) {
+    : op(op), adaptor(op.getAdaptorName()) {
   adaptor.newField("ValueRange", "odsOperands");
   adaptor.newField("DictionaryAttr", "odsAttrs");
   const auto *attrSizedOperands =
@@ -1957,6 +1989,50 @@ OpOperandAdaptorEmitter::OpOperandAdaptorEmitter(const Operator &op)
     if (!attr.isDerivedAttr())
       emitAttr(name, attr);
   }
+
+  // Add verification function.
+  addVerification();
+}
+
+void OpOperandAdaptorEmitter::addVerification() {
+  auto &method = adaptor.newMethod("LogicalResult", "verify",
+                                   /*params=*/"Location loc");
+  auto &body = method.body();
+
+  const char *checkAttrSizedValueSegmentsCode = R"(
+  {
+    auto sizeAttr = odsAttrs.get("{0}").cast<DenseIntElementsAttr>();
+    auto numElements = sizeAttr.getType().cast<ShapedType>().getNumElements();
+    if (numElements != {1})
+      return emitError(loc, "'{0}' attribute for specifying {2} segments "
+                       "must have {1} elements");
+  }
+  )";
+
+  // Verify a few traits first so that we can use
+  // getODSOperands()/getODSResults() in the rest of the verifier.
+  for (auto &trait : op.getTraits()) {
+    if (auto *t = dyn_cast<tblgen::NativeOpTrait>(&trait)) {
+      if (t->getTrait() == "OpTrait::AttrSizedOperandSegments") {
+        body << formatv(checkAttrSizedValueSegmentsCode,
+                        "operand_segment_sizes", op.getNumOperands(),
+                        "operand");
+      } else if (t->getTrait() == "OpTrait::AttrSizedResultSegments") {
+        body << formatv(checkAttrSizedValueSegmentsCode, "result_segment_sizes",
+                        op.getNumResults(), "result");
+      }
+    }
+  }
+
+  FmtContext verifyCtx;
+  populateSubstitutions(op, "odsAttrs.get", "getODSOperands",
+                        "<no results should be genarated>", verifyCtx);
+  genAttributeVerifier(op, "odsAttrs.get",
+                       Twine("emitError(loc, \"'") + op.getOperationName() +
+                           "' op \"",
+                       /*emitVerificationRequiringOp*/ false, verifyCtx, body);
+
+  body << "  return success();";
 }
 
 void OpOperandAdaptorEmitter::emitDecl(const Operator &op, raw_ostream &os) {
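The pass selection in `genAttributeVerifier` (the XOR that computes `hasConditionToEmit`, followed by the early `continue`) decides which generated method checks each attribute. As a reading aid, the following standalone C++ sketch restates that decision; `AttrInfo`, `EmittedChecks` and `emitsCheck` are hypothetical names, not part of mlir-tblgen, and the `$_op` test is the same substring search the generator uses.

```cpp
#include <string>

// Hypothetical summary of one ODS attribute, as seen by the generator.
struct AttrInfo {
  bool allowMissing;      // has a default value or is optional
  std::string condition;  // predicate template; may mention $_op
};

// Which checks get emitted for an attribute in a given pass.
struct EmittedChecks {
  bool presence;   // "requires attribute '...'"
  bool predicate;  // "attribute '...' failed to satisfy constraint: ..."
};

// opPass == true  models Op::verify() (emitVerificationRequiringOp = true);
// opPass == false models OperandAdaptor::verify(loc).
EmittedChecks emitsCheck(const AttrInfo &attr, bool opPass) {
  bool usesOp = attr.condition.find("$_op") != std::string::npos;
  // Equivalent to (!usesOp) ^ opPass: the adaptor pass emits predicates that
  // do not mention $_op, the op pass emits those that do.
  bool hasConditionToEmit = (usesOp == opPass);
  EmittedChecks checks{false, false};
  if ((attr.allowMissing || opPass) && !hasConditionToEmit)
    return checks;  // nothing to verify for this attribute in this pass
  // Presence of required attributes is only checked in the adaptor pass.
  checks.presence = !opPass && !attr.allowMissing;
  checks.predicate = hasConditionToEmit;
  return checks;
}
```

For example, a required attribute whose predicate never mentions `$_op` gets both its presence and predicate checks in the adaptor, and nothing further in the op's `verify()`.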


        

