[Mlir-commits] [mlir] 6b6c966 - [mlir][ODS] Add a new trait `TypesMatchWith`

River Riddle llvmlistbot at llvm.org
Wed Feb 19 10:22:41 PST 2020


Author: River Riddle
Date: 2020-02-19T10:18:58-08:00
New Revision: 6b6c96695c0054ebad6816171ef89d5cb76a058b

URL: https://github.com/llvm/llvm-project/commit/6b6c96695c0054ebad6816171ef89d5cb76a058b
DIFF: https://github.com/llvm/llvm-project/commit/6b6c96695c0054ebad6816171ef89d5cb76a058b.diff

LOG: [mlir][ODS] Add a new trait `TypesMatchWith`

Summary:
This trait takes a description string plus three arguments: lhs, rhs, and transformer. It verifies that the type of 'rhs' matches the type of 'lhs' after the given 'transformer' has been applied to it. This allows for adding constraints like: "the type of 'a' must match the element type of 'b'". A follow-up revision will add support in the declarative parser for using these equality constraints, so that more C++ parsers can be ported to the declarative form.

Differential Revision: https://reviews.llvm.org/D74647

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/StandardOps/Ops.td
    mlir/include/mlir/Dialect/VectorOps/VectorOps.td
    mlir/include/mlir/IR/OpBase.td
    mlir/include/mlir/IR/OperationSupport.h
    mlir/include/mlir/Support/STLExtras.h
    mlir/lib/Dialect/StandardOps/Ops.cpp
    mlir/test/Dialect/VectorOps/invalid.mlir
    mlir/test/IR/invalid-ops.mlir
    mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/StandardOps/Ops.td b/mlir/include/mlir/Dialect/StandardOps/Ops.td
index 73a88681da3c..b6186bd4ec76 100644
--- a/mlir/include/mlir/Dialect/StandardOps/Ops.td
+++ b/mlir/include/mlir/Dialect/StandardOps/Ops.td
@@ -310,7 +310,15 @@ def CallOp : Std_Op<"call", [CallOpInterface]> {
   }];
 }
 
-def CallIndirectOp : Std_Op<"call_indirect", [CallOpInterface]> {
+def CallIndirectOp : Std_Op<"call_indirect", [
+      CallOpInterface,
+      TypesMatchWith<"callee input types match argument types",
+                     "callee", "operands",
+                     "$_self.cast<FunctionType>().getInputs()">,
+      TypesMatchWith<"callee result types match result types",
+                     "callee", "results",
+                     "$_self.cast<FunctionType>().getResults()">
+    ]> {
   let summary = "indirect call operation";
   let description = [{
     The "call_indirect" operation represents an indirect call to a value of
@@ -322,7 +330,7 @@ def CallIndirectOp : Std_Op<"call_indirect", [CallOpInterface]> {
   }];
 
   let arguments = (ins FunctionType:$callee, Variadic<AnyType>:$operands);
-  let results = (outs Variadic<AnyType>);
+  let results = (outs Variadic<AnyType>:$results);
 
   let builders = [OpBuilder<
     "Builder *, OperationState &result, Value callee,"
@@ -347,6 +355,7 @@ def CallIndirectOp : Std_Op<"call_indirect", [CallOpInterface]> {
     CallInterfaceCallable getCallableForCallee() { return getCallee(); }
   }];
 
+  let verifier = ?;
   let hasCanonicalizer = 1;
 }
 
@@ -361,7 +370,10 @@ def CeilFOp : FloatUnaryOp<"ceilf"> {
 }
 
 def CmpFOp : Std_Op<"cmpf",
-    [NoSideEffect, SameTypeOperands, SameOperandsAndResultShape]> {
+    [NoSideEffect, SameTypeOperands, SameOperandsAndResultShape,
+     TypesMatchWith<
+       "result type has i1 element type and same shape as operands",
+       "lhs", "result", "getI1SameShape($_self)">]> {
   let summary = "floating-point comparison operation";
   let description = [{
     The "cmpf" operation compares its two operands according to the float
@@ -386,7 +398,7 @@ def CmpFOp : Std_Op<"cmpf",
   }];
 
   let arguments = (ins FloatLike:$lhs, FloatLike:$rhs);
-  let results = (outs BoolLike);
+  let results = (outs BoolLike:$result);
 
   let builders = [OpBuilder<
     "Builder *builder, OperationState &result, CmpFPredicate predicate,"
@@ -426,7 +438,10 @@ def CmpIPredicateAttr : I64EnumAttr<
 }
 
 def CmpIOp : Std_Op<"cmpi",
-    [NoSideEffect, SameTypeOperands, SameOperandsAndResultShape]> {
+    [NoSideEffect, SameTypeOperands, SameOperandsAndResultShape,
+     TypesMatchWith<
+       "result type has i1 element type and same shape as operands",
+       "lhs", "result", "getI1SameShape($_self)">]> {
   let summary = "integer comparison operation";
   let description = [{
     The "cmpi" operation compares its two operands according to the integer
@@ -454,7 +469,7 @@ def CmpIOp : Std_Op<"cmpi",
       IntegerLike:$lhs,
       IntegerLike:$rhs
   );
-  let results = (outs BoolLike);
+  let results = (outs BoolLike:$result);
 
   let builders = [OpBuilder<
     "Builder *builder, OperationState &result, CmpIPredicate predicate,"
@@ -708,7 +723,11 @@ def ExpOp : FloatUnaryOp<"exp"> {
   let summary = "base-e exponential of the specified value";
 }
 
-def ExtractElementOp : Std_Op<"extract_element", [NoSideEffect]> {
+def ExtractElementOp : Std_Op<"extract_element",
+    [NoSideEffect,
+     TypesMatchWith<"result type matches element type of aggregate",
+                    "aggregate", "result",
+                    "$_self.cast<ShapedType>().getElementType()">]> {
   let summary = "element extract operation";
   let description = [{
     The "extract_element" op reads a tensor or vector and returns one element
@@ -723,7 +742,7 @@ def ExtractElementOp : Std_Op<"extract_element", [NoSideEffect]> {
 
   let arguments = (ins AnyTypeOf<[AnyVector, AnyTensor]>:$aggregate,
                        Variadic<Index>:$indices);
-  let results = (outs AnyType);
+  let results = (outs AnyType:$result);
 
   let builders = [OpBuilder<
     "Builder *builder, OperationState &result, Value aggregate,"
@@ -796,7 +815,10 @@ def FPTruncOp : CastOp<"fptrunc">, Arguments<(ins AnyType:$in)> {
   let hasFolder = 0;
 }
 
-def LoadOp : Std_Op<"load"> {
+def LoadOp : Std_Op<"load",
+     [TypesMatchWith<"result type matches element type of 'memref'",
+                     "memref", "result",
+                     "$_self.cast<MemRefType>().getElementType()">]> {
   let summary = "load operation";
   let description = [{
     The "load" op reads an element from a memref specified by an index list. The
@@ -809,7 +831,7 @@ def LoadOp : Std_Op<"load"> {
   }];
 
   let arguments = (ins AnyMemRef:$memref, Variadic<Index>:$indices);
-  let results = (outs AnyType);
+  let results = (outs AnyType:$result);
 
   let builders = [OpBuilder<
     "Builder *, OperationState &result, Value memref,"
@@ -1029,7 +1051,11 @@ def ReturnOp : Std_Op<"return", [Terminator, HasParent<"FuncOp">]> {
   >];
 }
 
-def SelectOp : Std_Op<"select", [NoSideEffect, SameOperandsAndResultShape]> {
+def SelectOp : Std_Op<"select", [NoSideEffect, SameOperandsAndResultShape,
+     AllTypesMatch<["true_value", "false_value", "result"]>,
+     TypesMatchWith<"condition type matches i1 equivalent of result type",
+                     "result", "condition",
+                     "getI1SameShape($_self)">]> {
   let summary = "select operation";
   let description = [{
     The "select" operation chooses one value based on a binary condition
@@ -1044,9 +1070,10 @@ def SelectOp : Std_Op<"select", [NoSideEffect, SameOperandsAndResultShape]> {
       %3 = select %2, %0, %1 : i32
   }];
 
-  let arguments = (ins BoolLike:$condition, AnyType:$true_value,
-                       AnyType:$false_value);
-  let results = (outs AnyType);
+  let arguments = (ins BoolLike:$condition, IntegerOrFloatLike:$true_value,
+                       IntegerOrFloatLike:$false_value);
+  let results = (outs IntegerOrFloatLike:$result);
+  let verifier = ?;
 
   let builders = [OpBuilder<
     "Builder *builder, OperationState &result, Value condition,"
@@ -1158,7 +1185,10 @@ def SIToFPOp : CastOp<"sitofp">, Arguments<(ins AnyType:$in)> {
   let hasFolder = 0;
 }
 
-def SplatOp : Std_Op<"splat", [NoSideEffect]> {
+def SplatOp : Std_Op<"splat", [NoSideEffect,
+     TypesMatchWith<"operand type matches element type of result",
+                    "aggregate", "input",
+                    "$_self.cast<ShapedType>().getElementType()">]> {
   let summary = "splat or broadcast operation";
   let description = [{
     The "splat" op reads a value of integer or float type and broadcasts it into
@@ -1193,7 +1223,10 @@ def SplatOp : Std_Op<"splat", [NoSideEffect]> {
   let hasFolder = 1;
 }
 
-def StoreOp : Std_Op<"store"> {
+def StoreOp : Std_Op<"store",
+     [TypesMatchWith<"type of 'value' matches element type of 'memref'",
+                     "memref", "value",
+                     "$_self.cast<MemRefType>().getElementType()">]> {
   let summary = "store operation";
   let description = [{
     The "store" op writes an element to a memref specified by an index list.
@@ -1455,7 +1488,10 @@ def TensorCastOp : CastOp<"tensor_cast"> {
 }
 
 def TensorLoadOp : Std_Op<"tensor_load",
-    [SameOperandsAndResultShape, SameOperandsAndResultElementType]> {
+    [SameOperandsAndResultShape, SameOperandsAndResultElementType,
+     TypesMatchWith<"result type matches tensor equivalent of 'memref'",
+                    "memref", "result",
+                    "getTensorTypeFromMemRefType($_self)">]> {
   let summary = "tensor load operation";
   let description = [{
     The "tensor_load" operation creates a tensor from a memref, making an
@@ -1466,8 +1502,8 @@ def TensorLoadOp : Std_Op<"tensor_load",
        %12 = tensor_load %10 : memref<4x?xf32, #layout, memspace0>
   }];
 
-  let arguments = (ins AnyMemRef);
-  let results = (outs AnyTensor);
+  let arguments = (ins AnyMemRef:$memref);
+  let results = (outs AnyTensor:$result);
   // TensorLoadOp is fully verified by traits.
   let verifier = ?;
 
@@ -1488,7 +1524,10 @@ def TensorLoadOp : Std_Op<"tensor_load",
 }
 
 def TensorStoreOp : Std_Op<"tensor_store",
-    [SameOperandsShape, SameOperandsElementType]> {
+    [SameOperandsShape, SameOperandsElementType,
+     TypesMatchWith<"type of 'value' matches tensor equivalent of 'memref'",
+                    "memref", "tensor",
+                    "getTensorTypeFromMemRefType($_self)">]> {
   let summary = "tensor store operation";
   let description = [{
     The "tensor_store" operation stores the contents of a tensor into a memref.

diff  --git a/mlir/include/mlir/Dialect/VectorOps/VectorOps.td b/mlir/include/mlir/Dialect/VectorOps/VectorOps.td
index 5ead87681ad0..10837b94bfc2 100644
--- a/mlir/include/mlir/Dialect/VectorOps/VectorOps.td
+++ b/mlir/include/mlir/Dialect/VectorOps/VectorOps.td
@@ -339,10 +339,11 @@ def Vector_ShuffleOp :
 
 def Vector_ExtractElementOp :
   Vector_Op<"extractelement", [NoSideEffect,
-     PredOpTrait<"operand and result have same element type",
-                 TCresVTEtIsSameAsOpBase<0, 0>>]>,
+     TypesMatchWith<"result type matches element type of vector operand",
+                    "vector", "result",
+                    "$_self.cast<ShapedType>().getElementType()">]>,
     Arguments<(ins AnyVector:$vector, AnyInteger:$position)>,
-    Results<(outs AnyType)> {
+    Results<(outs AnyType:$result)> {
   let summary = "extractelement operation";
   let description = [{
     Takes an 1-D vector and a dynamic index position and extracts the
@@ -482,12 +483,12 @@ def Vector_FMAOp :
 
 def Vector_InsertElementOp :
   Vector_Op<"insertelement", [NoSideEffect,
-     PredOpTrait<"source operand and result have same element type",
-                 TCresVTEtIsSameAsOpBase<0, 0>>,
-     PredOpTrait<"dest operand and result have same type",
-                 TCresIsSameAsOpBase<0, 1>>]>,
+     TypesMatchWith<"source operand type matches element type of result",
+                    "result", "source",
+                    "$_self.cast<ShapedType>().getElementType()">,
+     AllTypesMatch<["dest", "result"]>]>,
      Arguments<(ins AnyType:$source, AnyVector:$dest, AnyInteger:$position)>,
-     Results<(outs AnyVector)> {
+     Results<(outs AnyVector:$result)> {
   let summary = "insertelement operation";
   let description = [{
     Takes a scalar source, an 1-D destination vector and a dynamic index

diff  --git a/mlir/include/mlir/IR/OpBase.td b/mlir/include/mlir/IR/OpBase.td
index 88240aed1949..2f02a885242b 100644
--- a/mlir/include/mlir/IR/OpBase.td
+++ b/mlir/include/mlir/IR/OpBase.td
@@ -595,6 +595,11 @@ def FloatLike : TypeConstraint<Or<[AnyFloat.predicate,
         VectorOf<[AnyFloat]>.predicate, TensorOf<[AnyFloat]>.predicate]>,
     "floating-point-like">;
 
+// Type constraint for integer-like or float-like types.
+def IntegerOrFloatLike : TypeConstraint<Or<[IntegerLike.predicate,
+                                            FloatLike.predicate]>,
+    "integer-like or floating-point-like">;
+
 
 //===----------------------------------------------------------------------===//
 // Attribute definitions
@@ -1716,6 +1721,17 @@ class AllShapesMatch<list<string> names> :
 class AllTypesMatch<list<string> names> :
     AllMatchSameOperatorTrait<names, "$_self.getType()", "type">;
 
+// A type constraint that denotes `transform(lhs.getType()) == rhs.getType()`.
+class TypesMatchWith<string description, string lhsArg, string rhsArg,
+                     string transform> :
+    PredOpTrait<description, CPred<
+      !subst("$_self", "$" # lhsArg # ".getType()", transform)
+      # " == $" # rhsArg # ".getType()">> {
+  string lhs = lhsArg;
+  string rhs = rhsArg;
+  string transformer = transform;
+}
+
 // Type Constraint operand `idx`'s Element type is `type`.
 class TCopVTEtIs<int idx, Type type> : And<[
    CPred<"$_op.getNumOperands() > " # idx>,

diff  --git a/mlir/include/mlir/IR/OperationSupport.h b/mlir/include/mlir/IR/OperationSupport.h
index a9a6ff46242c..a7999607d28c 100644
--- a/mlir/include/mlir/IR/OperationSupport.h
+++ b/mlir/include/mlir/IR/OperationSupport.h
@@ -610,6 +610,12 @@ class ValueTypeRange final
   ValueTypeRange(Container &&c) : ValueTypeRange(c.begin(), c.end()) {}
 };
 
+template <typename RangeT>
+inline bool operator==(ArrayRef<Type> lhs, const ValueTypeRange<RangeT> &rhs) {
+  return lhs.size() == llvm::size(rhs) &&
+         std::equal(lhs.begin(), lhs.end(), rhs.begin());
+}
+
 //===----------------------------------------------------------------------===//
 // OperandRange
 
@@ -625,6 +631,7 @@ class OperandRange final
   using type_iterator = ValueTypeIterator<iterator>;
   using type_range = ValueTypeRange<OperandRange>;
   type_range getTypes() const { return {begin(), end()}; }
+  auto getType() const { return getTypes(); }
 
 private:
   /// See `detail::indexed_accessor_range_base` for details.
@@ -656,6 +663,7 @@ class ResultRange final
   using type_iterator = ArrayRef<Type>::iterator;
   using type_range = ArrayRef<Type>;
   type_range getTypes() const;
+  auto getType() const { return getTypes(); }
 
 private:
   /// See `indexed_accessor_range` for details.
@@ -725,6 +733,7 @@ class ValueRange final
   using type_iterator = ValueTypeIterator<iterator>;
   using type_range = ValueTypeRange<ValueRange>;
   type_range getTypes() const { return {begin(), end()}; }
+  auto getType() const { return getTypes(); }
 
 private:
   using OwnerT = detail::ValueRangeOwner;

diff  --git a/mlir/include/mlir/Support/STLExtras.h b/mlir/include/mlir/Support/STLExtras.h
index 14336aad6a25..527234921d95 100644
--- a/mlir/include/mlir/Support/STLExtras.h
+++ b/mlir/include/mlir/Support/STLExtras.h
@@ -233,6 +233,12 @@ class indexed_accessor_range_base {
     return DerivedT::dereference_iterator(base, index);
   }
 
+  /// Compare this range with another.
+  template <typename OtherT> bool operator==(const OtherT &other) {
+    return size() == llvm::size(other) &&
+           std::equal(begin(), end(), other.begin());
+  }
+
   /// Return the size of this range.
   size_t size() const { return count; }
 

diff  --git a/mlir/lib/Dialect/StandardOps/Ops.cpp b/mlir/lib/Dialect/StandardOps/Ops.cpp
index 9a0bcb9d58a0..7318861938a4 100644
--- a/mlir/lib/Dialect/StandardOps/Ops.cpp
+++ b/mlir/lib/Dialect/StandardOps/Ops.cpp
@@ -528,30 +528,6 @@ static void print(OpAsmPrinter &p, CallIndirectOp op) {
   p << " : " << op.getCallee().getType();
 }
 
-static LogicalResult verify(CallIndirectOp op) {
-  // The callee must be a function.
-  auto fnType = op.getCallee().getType().dyn_cast<FunctionType>();
-  if (!fnType)
-    return op.emitOpError("callee must have function type");
-
-  // Verify that the operand and result types match the callee.
-  if (fnType.getNumInputs() != op.getNumOperands() - 1)
-    return op.emitOpError("incorrect number of operands for callee");
-
-  for (unsigned i = 0, e = fnType.getNumInputs(); i != e; ++i)
-    if (op.getOperand(i + 1).getType() != fnType.getInput(i))
-      return op.emitOpError("operand type mismatch");
-
-  if (fnType.getNumResults() != op.getNumResults())
-    return op.emitOpError("incorrect number of results for callee");
-
-  for (unsigned i = 0, e = fnType.getNumResults(); i != e; ++i)
-    if (op.getResult(i).getType() != fnType.getResult(i))
-      return op.emitOpError("result type mismatch");
-
-  return success();
-}
-
 void CallIndirectOp::getCanonicalizationPatterns(
     OwningRewritePatternList &results, MLIRContext *context) {
   results.insert<SimplifyIndirectCallWithKnownCallee>(context);
@@ -562,8 +538,8 @@ void CallIndirectOp::getCanonicalizationPatterns(
 //===----------------------------------------------------------------------===//
 
 // Return the type of the same shape (scalar, vector or tensor) containing i1.
-static Type getCheckedI1SameShape(Builder *build, Type type) {
-  auto i1Type = build->getI1Type();
+static Type getCheckedI1SameShape(Type type) {
+  auto i1Type = IntegerType::get(1, type.getContext());
   if (type.isIntOrIndexOrFloat())
     return i1Type;
   if (auto tensorType = type.dyn_cast<RankedTensorType>())
@@ -575,8 +551,8 @@ static Type getCheckedI1SameShape(Builder *build, Type type) {
   return Type();
 }
 
-static Type getI1SameShape(Builder *build, Type type) {
-  Type res = getCheckedI1SameShape(build, type);
+static Type getI1SameShape(Type type) {
+  Type res = getCheckedI1SameShape(type);
   assert(res && "expected type with valid i1 shape");
   return res;
 }
@@ -588,7 +564,7 @@ static Type getI1SameShape(Builder *build, Type type) {
 static void buildCmpIOp(Builder *build, OperationState &result,
                         CmpIPredicate predicate, Value lhs, Value rhs) {
   result.addOperands({lhs, rhs});
-  result.types.push_back(getI1SameShape(build, lhs.getType()));
+  result.types.push_back(getI1SameShape(lhs.getType()));
   result.addAttribute(
       CmpIOp::getPredicateAttrName(),
       build->getI64IntegerAttr(static_cast<int64_t>(predicate)));
@@ -618,7 +594,7 @@ static ParseResult parseCmpIOp(OpAsmParser &parser, OperationState &result) {
            << "unknown comparison predicate \"" << predicateName << "\"";
 
   auto builder = parser.getBuilder();
-  Type i1Type = getCheckedI1SameShape(&builder, type);
+  Type i1Type = getCheckedI1SameShape(type);
   if (!i1Type)
     return parser.emitError(parser.getNameLoc(),
                             "expected type with valid i1 shape");
@@ -741,7 +717,7 @@ CmpFPredicate CmpFOp::getPredicateByName(StringRef name) {
 static void buildCmpFOp(Builder *build, OperationState &result,
                         CmpFPredicate predicate, Value lhs, Value rhs) {
   result.addOperands({lhs, rhs});
-  result.types.push_back(getI1SameShape(build, lhs.getType()));
+  result.types.push_back(getI1SameShape(lhs.getType()));
   result.addAttribute(
       CmpFOp::getPredicateAttrName(),
       build->getI64IntegerAttr(static_cast<int64_t>(predicate)));
@@ -772,7 +748,7 @@ static ParseResult parseCmpFOp(OpAsmParser &parser, OperationState &result) {
                                 "\"");
 
   auto builder = parser.getBuilder();
-  Type i1Type = getCheckedI1SameShape(&builder, type);
+  Type i1Type = getCheckedI1SameShape(type);
   if (!i1Type)
     return parser.emitError(parser.getNameLoc(),
                             "expected type with valid i1 shape");
@@ -1534,13 +1510,8 @@ static ParseResult parseExtractElementOp(OpAsmParser &parser,
 }
 
 static LogicalResult verify(ExtractElementOp op) {
-  auto aggregateType = op.getAggregate().getType().cast<ShapedType>();
-
-  // This should be possible with tablegen type constraints
-  if (op.getType() != aggregateType.getElementType())
-    return op.emitOpError("result type must match element type of aggregate");
-
   // Verify the # indices match if we have a ranked type.
+  auto aggregateType = op.getAggregate().getType().cast<ShapedType>();
   if (aggregateType.hasRank() &&
       aggregateType.getRank() != op.getNumOperands() - 1)
     return op.emitOpError("incorrect number of indices for extract_element");
@@ -1628,12 +1599,8 @@ static ParseResult parseLoadOp(OpAsmParser &parser, OperationState &result) {
 }
 
 static LogicalResult verify(LoadOp op) {
-  if (op.getType() != op.getMemRefType().getElementType())
-    return op.emitOpError("result type must match element type of memref");
-
   if (op.getNumOperands() != 1 + op.getMemRefType().getRank())
     return op.emitOpError("incorrect number of indices for load");
-
   return success();
 }
 
@@ -1943,7 +1910,7 @@ static ParseResult parseSelectOp(OpAsmParser &parser, OperationState &result) {
       parser.parseColonType(type))
     return failure();
 
-  auto i1Type = getCheckedI1SameShape(&parser.getBuilder(), type);
+  auto i1Type = getCheckedI1SameShape(type);
   if (!i1Type)
     return parser.emitError(parser.getNameLoc(),
                             "expected type with valid i1 shape");
@@ -1959,17 +1926,6 @@ static void print(OpAsmPrinter &p, SelectOp op) {
   p.printOptionalAttrDict(op.getAttrs());
 }
 
-static LogicalResult verify(SelectOp op) {
-  auto trueType = op.getTrueValue().getType();
-  auto falseType = op.getFalseValue().getType();
-
-  if (trueType != falseType)
-    return op.emitOpError(
-        "requires 'true' and 'false' arguments to be of the same type");
-
-  return success();
-}
-
 OpFoldResult SelectOp::fold(ArrayRef<Attribute> operands) {
   auto condition = getCondition();
 
@@ -2087,11 +2043,6 @@ static ParseResult parseStoreOp(OpAsmParser &parser, OperationState &result) {
 }
 
 static LogicalResult verify(StoreOp op) {
-  // First operand must have same type as memref element type.
-  if (op.getValueToStore().getType() != op.getMemRefType().getElementType())
-    return op.emitOpError(
-        "first operand must have same type memref element type");
-
   if (op.getNumOperands() != 2 + op.getMemRefType().getRank())
     return op.emitOpError("store index operand count not equal to memref rank");
 
@@ -2198,10 +2149,10 @@ OpFoldResult TensorCastOp::fold(ArrayRef<Attribute> operands) {
 // Helpers for Tensor[Load|Store]Op
 //===----------------------------------------------------------------------===//
 
-static Type getTensorTypeFromMemRefType(Builder &b, Type type) {
+static Type getTensorTypeFromMemRefType(Type type) {
   if (auto memref = type.dyn_cast<MemRefType>())
     return RankedTensorType::get(memref.getShape(), memref.getElementType());
-  return b.getNoneType();
+  return NoneType::get(type.getContext());
 }
 
 //===----------------------------------------------------------------------===//
@@ -2218,13 +2169,12 @@ static ParseResult parseTensorLoadOp(OpAsmParser &parser,
                                      OperationState &result) {
   OpAsmParser::OperandType op;
   Type type;
-  return failure(parser.parseOperand(op) ||
-                 parser.parseOptionalAttrDict(result.attributes) ||
-                 parser.parseColonType(type) ||
-                 parser.resolveOperand(op, type, result.operands) ||
-                 parser.addTypeToList(
-                     getTensorTypeFromMemRefType(parser.getBuilder(), type),
-                     result.types));
+  return failure(
+      parser.parseOperand(op) ||
+      parser.parseOptionalAttrDict(result.attributes) ||
+      parser.parseColonType(type) ||
+      parser.resolveOperand(op, type, result.operands) ||
+      parser.addTypeToList(getTensorTypeFromMemRefType(type), result.types));
 }
 
 //===----------------------------------------------------------------------===//
@@ -2246,9 +2196,8 @@ static ParseResult parseTensorStoreOp(OpAsmParser &parser,
       parser.parseOperandList(ops, /*requiredOperandCount=*/2) ||
       parser.parseOptionalAttrDict(result.attributes) ||
       parser.parseColonType(type) ||
-      parser.resolveOperands(
-          ops, {getTensorTypeFromMemRefType(parser.getBuilder(), type), type},
-          loc, result.operands));
+      parser.resolveOperands(ops, {getTensorTypeFromMemRefType(type), type},
+                             loc, result.operands));
 }
 
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/test/Dialect/VectorOps/invalid.mlir b/mlir/test/Dialect/VectorOps/invalid.mlir
index 2a45820be7b0..3743a5cfcf3c 100644
--- a/mlir/test/Dialect/VectorOps/invalid.mlir
+++ b/mlir/test/Dialect/VectorOps/invalid.mlir
@@ -131,9 +131,9 @@ func @insert_element(%arg0: f32, %arg1: vector<4x4xf32>) {
 // -----
 
 func @insert_element_wrong_type(%arg0: i32, %arg1: vector<4xf32>) {
-  %c = constant 3 : index
-  // expected-error at +1 {{'vector.insertelement' op failed to verify that source operand and result have same element type}}
-  %0 = "vector.insertelement" (%arg0, %arg1, %c) : (i32, vector<4xf32>, index) -> (vector<4xf32>)
+  %c = constant 3 : i32
+  // expected-error at +1 {{'vector.insertelement' op failed to verify that source operand type matches element type of result}}
+  %0 = "vector.insertelement" (%arg0, %arg1, %c) : (i32, vector<4xf32>, i32) -> (vector<4xf32>)
 }
 
 // -----

diff  --git a/mlir/test/IR/invalid-ops.mlir b/mlir/test/IR/invalid-ops.mlir
index 9f48d9d6bc70..cebc87ea94a1 100644
--- a/mlir/test/IR/invalid-ops.mlir
+++ b/mlir/test/IR/invalid-ops.mlir
@@ -275,7 +275,7 @@ func @func_with_ops(i32, i32, i32) {
 
 func @func_with_ops(i1, i32, i64) {
 ^bb0(%cond : i1, %t : i32, %f : i64):
-  // expected-error at +1 {{'true' and 'false' arguments to be of the same type}}
+  // expected-error at +1 {{all of {true_value, false_value, result} have same type}}
   %r = "std.select"(%cond, %t, %f) : (i1, i32, i64) -> i32
 }
 
@@ -460,7 +460,7 @@ func @extract_element_invalid_index_type(%v : vector<3xf32>, %i : i32) {
 // -----
 
 func @extract_element_element_result_type_mismatch(%v : vector<3xf32>, %i : index) {
-  // expected-error at +1 {{result type must match element type of aggregate}}
+  // expected-error at +1 {{result type matches element type of aggregate}}
   %0 = "std.extract_element"(%v, %i) : (vector<3xf32>, index) -> f64
   return
 }

diff  --git a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
index 6af100ed95fa..a5962957be77 100644
--- a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
+++ b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
@@ -1127,29 +1127,55 @@ void OpEmitter::genVerifier() {
   auto &method = opClass.newMethod("LogicalResult", "verify", /*params=*/"");
   auto &body = method.body();
 
+  const char *checkAttrSizedValueSegmentsCode = R"(
+  auto sizeAttr = getAttrOfType<DenseIntElementsAttr>("{0}");
+  auto numElements = sizeAttr.getType().cast<ShapedType>().getNumElements();
+  if (numElements != {1}) {{
+    return emitOpError("'{0}' attribute for specifying {2} segments "
+                       "must have {1} elements");
+  }
+  )";
+
+  // Verify a few traits first so that we can use
+  // getODSOperands()/getODSResults() in the rest of the verifier.
+  for (auto &trait : op.getTraits()) {
+    if (auto *t = dyn_cast<tblgen::NativeOpTrait>(&trait)) {
+      if (t->getTrait() == "OpTrait::AttrSizedOperandSegments") {
+        body << formatv(checkAttrSizedValueSegmentsCode,
+                        "operand_segment_sizes", op.getNumOperands(),
+                        "operand");
+      } else if (t->getTrait() == "OpTrait::AttrSizedResultSegments") {
+        body << formatv(checkAttrSizedValueSegmentsCode, "result_segment_sizes",
+                        op.getNumResults(), "result");
+      }
+    }
+  }
+
   // Populate substitutions for attributes and named operands and results.
   for (const auto &namedAttr : op.getAttributes())
     verifyCtx.addSubst(namedAttr.name,
                        formatv("this->getAttr(\"{0}\")", namedAttr.name));
   for (int i = 0, e = op.getNumOperands(); i < e; ++i) {
     auto &value = op.getOperand(i);
-    // Skip from from first variadic operands for now. Else getOperand index
-    // used below doesn't match.
+    if (value.name.empty())
+      continue;
+
     if (value.isVariadic())
-      break;
-    if (!value.name.empty())
+      verifyCtx.addSubst(value.name, formatv("this->getODSOperands({0})", i));
+    else
       verifyCtx.addSubst(value.name,
-                         formatv("this->getOperation()->getOperand({0})", i));
+                         formatv("(*this->getODSOperands({0}).begin())", i));
   }
   for (int i = 0, e = op.getNumResults(); i < e; ++i) {
     auto &value = op.getResult(i);
-    // Skip from from first variadic results for now. Else getResult index used
-    // below doesn't match.
+    if (value.name.empty())
+      continue;
+
     if (value.isVariadic())
-      break;
-    if (!value.name.empty())
+      verifyCtx.addSubst(value.name, formatv("this->getODSResults({0})", i));
+    else
       verifyCtx.addSubst(value.name,
-                         formatv("this->getOperation()->getResult({0})", i));
+                         formatv("(*this->getODSResults({0}).begin())", i));
   }
 
   // Verify the attributes have the correct type.
@@ -1189,14 +1215,8 @@ void OpEmitter::genVerifier() {
     body << "  }\n";
   }
 
-  const char *code = R"(
-  auto sizeAttr = getAttrOfType<DenseIntElementsAttr>("{0}");
-  auto numElements = sizeAttr.getType().cast<ShapedType>().getNumElements();
-  if (numElements != {1}) {{
-    return emitOpError("'{0}' attribute for specifying {2} segments "
-                       "must have {1} elements");
-  }
-  )";
+  genOperandResultVerifier(body, op.getOperands(), "operand");
+  genOperandResultVerifier(body, op.getResults(), "result");
 
   for (auto &trait : op.getTraits()) {
     if (auto *t = dyn_cast<tblgen::PredOpTrait>(&trait)) {
@@ -1204,23 +1224,9 @@ void OpEmitter::genVerifier() {
                     "return emitOpError(\"failed to verify that $1\");\n  }\n",
                     &verifyCtx, tgfmt(t->getPredTemplate(), &verifyCtx),
                     t->getDescription());
-    } else if (auto *t = dyn_cast<tblgen::NativeOpTrait>(&trait)) {
-      if (t->getTrait() == "OpTrait::AttrSizedOperandSegments") {
-        body << formatv(code, "operand_segment_sizes", op.getNumOperands(),
-                        "operand");
-      } else if (t->getTrait() == "OpTrait::AttrSizedResultSegments") {
-        body << formatv(code, "result_segment_sizes", op.getNumResults(),
-                        "result");
-      }
     }
   }
 
-  // These should happen after we verified the traits because
-  // getODSOperands()/getODSResults() may depend on traits (e.g.,
-  // AttrSizedOperandSegments/AttrSizedResultSegments).
-  genOperandResultVerifier(body, op.getOperands(), "operand");
-  genOperandResultVerifier(body, op.getResults(), "result");
-
   genRegionVerifier(body);
 
   if (hasCustomVerify) {


        


More information about the Mlir-commits mailing list