[Mlir-commits] [mlir] 26222db - [mlir][DeclarativeParser] Add support for the TypesMatchWith trait.

River Riddle llvmlistbot at llvm.org
Fri Feb 21 15:17:34 PST 2020


Author: River Riddle
Date: 2020-02-21T15:15:31-08:00
New Revision: 26222db01b079023d0fe3bb60f2c1b38f4f19d5a

URL: https://github.com/llvm/llvm-project/commit/26222db01b079023d0fe3bb60f2c1b38f4f19d5a
DIFF: https://github.com/llvm/llvm-project/commit/26222db01b079023d0fe3bb60f2c1b38f4f19d5a.diff

LOG: [mlir][DeclarativeParser] Add support for the TypesMatchWith trait.

This allows injecting type constraints that are not direct 1-1 mappings, for example when one type is equal to the element type of another. It also makes it possible to move several more parsers over to the declarative form.

Differential Revision: https://reviews.llvm.org/D74648

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
    mlir/include/mlir/Dialect/VectorOps/VectorOps.td
    mlir/include/mlir/IR/OpImplementation.h
    mlir/include/mlir/IR/OperationSupport.h
    mlir/lib/Dialect/StandardOps/IR/Ops.cpp
    mlir/lib/Dialect/VectorOps/VectorOps.cpp
    mlir/test/IR/invalid-ops.mlir
    mlir/tools/mlir-tblgen/OpFormatGen.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
index a0b739ea6ec7..fe28f8d7143f 100644
--- a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
+++ b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
@@ -357,6 +357,8 @@ def CallIndirectOp : Std_Op<"call_indirect", [
 
   let verifier = ?;
   let hasCanonicalizer = 1;
+
+  let assemblyFormat = "$callee `(` $operands `)` attr-dict `:` type($callee)";
 }
 
 def CeilFOp : FloatUnaryOp<"ceilf"> {
@@ -490,6 +492,8 @@ def CmpIOp : Std_Op<"cmpi",
   let verifier = [{ return success(); }];
 
   let hasFolder = 1;
+
+  let assemblyFormat = "$predicate `,` $lhs `,` $rhs attr-dict `:` type($lhs)";
 }
 
 def CondBranchOp : Std_Op<"cond_br", [Terminator]> {
@@ -761,6 +765,10 @@ def ExtractElementOp : Std_Op<"extract_element",
   }];
 
   let hasFolder = 1;
+
+  let assemblyFormat = [{
+    $aggregate `[` $indices `]` attr-dict `:` type($aggregate)
+  }];
 }
 
 def IndexCastOp : CastOp<"index_cast">, Arguments<(ins AnyType:$in)> {
@@ -853,6 +861,8 @@ def LoadOp : Std_Op<"load",
   }];
 
   let hasFolder = 1;
+
+  let assemblyFormat = "$memref `[` $indices `]` attr-dict `:` type($memref)";
 }
 
 def LogOp : FloatUnaryOp<"log"> {
@@ -1090,6 +1100,10 @@ def SelectOp : Std_Op<"select", [NoSideEffect, SameOperandsAndResultShape,
   }];
 
   let hasFolder = 1;
+
+  let assemblyFormat = [{
+    $condition `,` $true_value `,` $false_value attr-dict `:` type($result)
+  }];
 }
 
 def SignExtendIOp : Std_Op<"sexti",
@@ -1222,6 +1236,8 @@ def SplatOp : Std_Op<"splat", [NoSideEffect,
                   [{ build(builder, result, aggregateType, element); }]>];
 
   let hasFolder = 1;
+
+  let assemblyFormat = "$input attr-dict `:` type($aggregate)";
 }
 
 def StoreOp : Std_Op<"store",
@@ -1264,6 +1280,10 @@ def StoreOp : Std_Op<"store",
   }];
 
   let hasFolder = 1;
+
+  let assemblyFormat = [{
+    $value `,` $memref `[` $indices `]` attr-dict `:` type($memref)
+  }];
 }
 
 def SubFOp : FloatArithmeticOp<"subf"> {
@@ -1517,11 +1537,12 @@ def TensorLoadOp : Std_Op<"tensor_load",
       result.addTypes(resultType);
   }]>];
 
-
   let extraClassDeclaration = [{
     /// The result of a tensor_load is always a tensor.
     TensorType getType() { return getResult().getType().cast<TensorType>(); }
   }];
+
+  let assemblyFormat = "$memref attr-dict `:` type($memref)";
 }
 
 def TensorStoreOp : Std_Op<"tensor_store",
@@ -1545,6 +1566,8 @@ def TensorStoreOp : Std_Op<"tensor_store",
   let arguments = (ins AnyTensor:$tensor, AnyMemRef:$memref);
   // TensorStoreOp is fully verified by traits.
   let verifier = ?;
+
+  let assemblyFormat = "$tensor `,` $memref attr-dict `:` type($memref)";
 }
 
 def TruncateIOp : Std_Op<"trunci", [NoSideEffect, SameOperandsAndResultShape]> {

diff  --git a/mlir/include/mlir/Dialect/VectorOps/VectorOps.td b/mlir/include/mlir/Dialect/VectorOps/VectorOps.td
index ce6029a5d497..70917ff2b882 100644
--- a/mlir/include/mlir/Dialect/VectorOps/VectorOps.td
+++ b/mlir/include/mlir/Dialect/VectorOps/VectorOps.td
@@ -363,6 +363,10 @@ def Vector_ExtractElementOp :
       return vector().getType().cast<VectorType>();
     }
   }];
+
+  let assemblyFormat = [{
+    $vector `[` $position `:` type($position) `]` attr-dict `:` type($vector)
+  }];
 }
 
 def Vector_ExtractOp :
@@ -512,6 +516,11 @@ def Vector_InsertElementOp :
       return dest().getType().cast<VectorType>();
     }
   }];
+
+  let assemblyFormat = [{
+    $source `,` $dest `[` $position `:` type($position) `]` attr-dict `:`
+    type($result)
+  }];
 }
 
 def Vector_InsertOp :

diff  --git a/mlir/include/mlir/IR/OpImplementation.h b/mlir/include/mlir/IR/OpImplementation.h
index 0d0736b4512f..08fac5ea19ef 100644
--- a/mlir/include/mlir/IR/OpImplementation.h
+++ b/mlir/include/mlir/IR/OpImplementation.h
@@ -496,6 +496,12 @@ class OpAsmParser {
         return failure();
     return success();
   }
+  template <typename Operands>
+  ParseResult resolveOperands(Operands &&operands, Type type, llvm::SMLoc loc,
+                              SmallVectorImpl<Value> &result) {
+    return resolveOperands(std::forward<Operands>(operands),
+                           ArrayRef<Type>(type), loc, result);
+  }
   template <typename Operands, typename Types>
   ParseResult resolveOperands(Operands &&operands, Types &&types,
                               llvm::SMLoc loc, SmallVectorImpl<Value> &result) {

diff  --git a/mlir/include/mlir/IR/OperationSupport.h b/mlir/include/mlir/IR/OperationSupport.h
index b0c97e2dc1a4..89fffa0df60d 100644
--- a/mlir/include/mlir/IR/OperationSupport.h
+++ b/mlir/include/mlir/IR/OperationSupport.h
@@ -294,6 +294,11 @@ struct OperationState {
   void addTypes(ArrayRef<Type> newTypes) {
     types.append(newTypes.begin(), newTypes.end());
   }
+  template <typename RangeT>
+  std::enable_if_t<!std::is_convertible<RangeT, ArrayRef<Type>>::value>
+  addTypes(RangeT &&newTypes) {
+    types.append(newTypes.begin(), newTypes.end());
+  }
 
   /// Add an attribute with the specified name.
   void addAttribute(StringRef name, Attribute attr) {

diff  --git a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
index 42daf193b9cb..aa0a42812342 100644
--- a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
+++ b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
@@ -505,29 +505,6 @@ struct SimplifyIndirectCallWithKnownCallee
 };
 } // end anonymous namespace.
 
-static ParseResult parseCallIndirectOp(OpAsmParser &parser,
-                                       OperationState &result) {
-  FunctionType calleeType;
-  OpAsmParser::OperandType callee;
-  llvm::SMLoc operandsLoc;
-  SmallVector<OpAsmParser::OperandType, 4> operands;
-  return failure(
-      parser.parseOperand(callee) || parser.getCurrentLocation(&operandsLoc) ||
-      parser.parseOperandList(operands, OpAsmParser::Delimiter::Paren) ||
-      parser.parseOptionalAttrDict(result.attributes) ||
-      parser.parseColonType(calleeType) ||
-      parser.resolveOperand(callee, calleeType, result.operands) ||
-      parser.resolveOperands(operands, calleeType.getInputs(), operandsLoc,
-                             result.operands) ||
-      parser.addTypesToList(calleeType.getResults(), result.types));
-}
-
-static void print(OpAsmPrinter &p, CallIndirectOp op) {
-  p << "call_indirect " << op.getCallee() << '(' << op.getArgOperands() << ')';
-  p.printOptionalAttrDict(op.getAttrs(), /*elidedAttrs=*/{"callee"});
-  p << " : " << op.getCallee().getType();
-}
-
 void CallIndirectOp::getCanonicalizationPatterns(
     OwningRewritePatternList &results, MLIRContext *context) {
   results.insert<SimplifyIndirectCallWithKnownCallee>(context);
@@ -570,55 +547,6 @@ static void buildCmpIOp(Builder *build, OperationState &result,
       build->getI64IntegerAttr(static_cast<int64_t>(predicate)));
 }
 
-static ParseResult parseCmpIOp(OpAsmParser &parser, OperationState &result) {
-  SmallVector<OpAsmParser::OperandType, 2> ops;
-  SmallVector<NamedAttribute, 4> attrs;
-  Attribute predicateNameAttr;
-  Type type;
-  if (parser.parseAttribute(predicateNameAttr, CmpIOp::getPredicateAttrName(),
-                            attrs) ||
-      parser.parseComma() || parser.parseOperandList(ops, 2) ||
-      parser.parseOptionalAttrDict(attrs) || parser.parseColonType(type) ||
-      parser.resolveOperands(ops, type, result.operands))
-    return failure();
-
-  if (!predicateNameAttr.isa<StringAttr>())
-    return parser.emitError(parser.getNameLoc(),
-                            "expected string comparison predicate attribute");
-
-  // Rewrite string attribute to an enum value.
-  StringRef predicateName = predicateNameAttr.cast<StringAttr>().getValue();
-  Optional<CmpIPredicate> predicate = symbolizeCmpIPredicate(predicateName);
-  if (!predicate.hasValue())
-    return parser.emitError(parser.getNameLoc())
-           << "unknown comparison predicate \"" << predicateName << "\"";
-
-  auto builder = parser.getBuilder();
-  Type i1Type = getCheckedI1SameShape(type);
-  if (!i1Type)
-    return parser.emitError(parser.getNameLoc(),
-                            "expected type with valid i1 shape");
-
-  attrs[0].second = builder.getI64IntegerAttr(static_cast<int64_t>(*predicate));
-  result.attributes = attrs;
-
-  result.addTypes({i1Type});
-  return success();
-}
-
-static void print(OpAsmPrinter &p, CmpIOp op) {
-  p << "cmpi ";
-
-  Builder b(op.getContext());
-  auto predicateValue =
-      op.getAttrOfType<IntegerAttr>(CmpIOp::getPredicateAttrName()).getInt();
-  p << '"' << stringifyCmpIPredicate(static_cast<CmpIPredicate>(predicateValue))
-    << '"' << ", " << op.lhs() << ", " << op.rhs();
-  p.printOptionalAttrDict(op.getAttrs(),
-                          /*elidedAttrs=*/{CmpIOp::getPredicateAttrName()});
-  p << " : " << op.lhs().getType();
-}
-
 // Compute `lhs` `pred` `rhs`, where `pred` is one of the known integer
 // comparison predicates.
 static bool applyCmpPredicate(CmpIPredicate predicate, const APInt &lhs,
@@ -1486,30 +1414,6 @@ LogicalResult DmaWaitOp::fold(ArrayRef<Attribute> cstOperands,
 // ExtractElementOp
 //===----------------------------------------------------------------------===//
 
-static void print(OpAsmPrinter &p, ExtractElementOp op) {
-  p << "extract_element " << op.getAggregate() << '[' << op.getIndices();
-  p << ']';
-  p.printOptionalAttrDict(op.getAttrs());
-  p << " : " << op.getAggregate().getType();
-}
-
-static ParseResult parseExtractElementOp(OpAsmParser &parser,
-                                         OperationState &result) {
-  OpAsmParser::OperandType aggregateInfo;
-  SmallVector<OpAsmParser::OperandType, 4> indexInfo;
-  ShapedType type;
-
-  auto indexTy = parser.getBuilder().getIndexType();
-  return failure(
-      parser.parseOperand(aggregateInfo) ||
-      parser.parseOperandList(indexInfo, OpAsmParser::Delimiter::Square) ||
-      parser.parseOptionalAttrDict(result.attributes) ||
-      parser.parseColonType(type) ||
-      parser.resolveOperand(aggregateInfo, type, result.operands) ||
-      parser.resolveOperands(indexInfo, indexTy, result.operands) ||
-      parser.addTypeToList(type.getElementType(), result.types));
-}
-
 static LogicalResult verify(ExtractElementOp op) {
   // Verify the # indices match if we have a ranked type.
   auto aggregateType = op.getAggregate().getType().cast<ShapedType>();
@@ -1577,28 +1481,6 @@ OpFoldResult IndexCastOp::fold(ArrayRef<Attribute> cstOperands) {
 // LoadOp
 //===----------------------------------------------------------------------===//
 
-static void print(OpAsmPrinter &p, LoadOp op) {
-  p << "load " << op.getMemRef() << '[' << op.getIndices() << ']';
-  p.printOptionalAttrDict(op.getAttrs());
-  p << " : " << op.getMemRefType();
-}
-
-static ParseResult parseLoadOp(OpAsmParser &parser, OperationState &result) {
-  OpAsmParser::OperandType memrefInfo;
-  SmallVector<OpAsmParser::OperandType, 4> indexInfo;
-  MemRefType type;
-
-  auto indexTy = parser.getBuilder().getIndexType();
-  return failure(
-      parser.parseOperand(memrefInfo) ||
-      parser.parseOperandList(indexInfo, OpAsmParser::Delimiter::Square) ||
-      parser.parseOptionalAttrDict(result.attributes) ||
-      parser.parseColonType(type) ||
-      parser.resolveOperand(memrefInfo, type, result.operands) ||
-      parser.resolveOperands(indexInfo, indexTy, result.operands) ||
-      parser.addTypeToList(type.getElementType(), result.types));
-}
-
 static LogicalResult verify(LoadOp op) {
   if (op.getNumOperands() != 1 + op.getMemRefType().getRank())
     return op.emitOpError("incorrect number of indices for load");
@@ -1902,31 +1784,6 @@ bool SIToFPOp::areCastCompatible(Type a, Type b) {
 // SelectOp
 //===----------------------------------------------------------------------===//
 
-static ParseResult parseSelectOp(OpAsmParser &parser, OperationState &result) {
-  SmallVector<OpAsmParser::OperandType, 3> ops;
-  SmallVector<NamedAttribute, 4> attrs;
-  Type type;
-  if (parser.parseOperandList(ops, 3) ||
-      parser.parseOptionalAttrDict(result.attributes) ||
-      parser.parseColonType(type))
-    return failure();
-
-  auto i1Type = getCheckedI1SameShape(type);
-  if (!i1Type)
-    return parser.emitError(parser.getNameLoc(),
-                            "expected type with valid i1 shape");
-
-  std::array<Type, 3> types = {i1Type, type, type};
-  return failure(parser.resolveOperands(ops, types, parser.getNameLoc(),
-                                        result.operands) ||
-                 parser.addTypeToList(type, result.types));
-}
-
-static void print(OpAsmPrinter &p, SelectOp op) {
-  p << "select " << op.getOperands() << " : " << op.getTrueValue().getType();
-  p.printOptionalAttrDict(op.getAttrs());
-}
-
 OpFoldResult SelectOp::fold(ArrayRef<Attribute> operands) {
   auto condition = getCondition();
 
@@ -1968,25 +1825,6 @@ static LogicalResult verify(SignExtendIOp op) {
 // SplatOp
 //===----------------------------------------------------------------------===//
 
-static void print(OpAsmPrinter &p, SplatOp op) {
-  p << "splat " << op.getOperand();
-  p.printOptionalAttrDict(op.getAttrs());
-  p << " : " << op.getType();
-}
-
-static ParseResult parseSplatOp(OpAsmParser &parser, OperationState &result) {
-  OpAsmParser::OperandType splatValueInfo;
-  ShapedType shapedType;
-
-  return failure(parser.parseOperand(splatValueInfo) ||
-                 parser.parseOptionalAttrDict(result.attributes) ||
-                 parser.parseColonType(shapedType) ||
-                 parser.resolveOperand(splatValueInfo,
-                                       shapedType.getElementType(),
-                                       result.operands) ||
-                 parser.addTypeToList(shapedType, result.types));
-}
-
 static LogicalResult verify(SplatOp op) {
   // TODO: we could replace this by a trait.
   if (op.getOperand().getType() !=
@@ -2017,32 +1855,6 @@ OpFoldResult SplatOp::fold(ArrayRef<Attribute> operands) {
 // StoreOp
 //===----------------------------------------------------------------------===//
 
-static void print(OpAsmPrinter &p, StoreOp op) {
-  p << "store " << op.getValueToStore();
-  p << ", " << op.getMemRef() << '[' << op.getIndices() << ']';
-  p.printOptionalAttrDict(op.getAttrs());
-  p << " : " << op.getMemRefType();
-}
-
-static ParseResult parseStoreOp(OpAsmParser &parser, OperationState &result) {
-  OpAsmParser::OperandType storeValueInfo;
-  OpAsmParser::OperandType memrefInfo;
-  SmallVector<OpAsmParser::OperandType, 4> indexInfo;
-  MemRefType memrefType;
-
-  auto indexTy = parser.getBuilder().getIndexType();
-  return failure(
-      parser.parseOperand(storeValueInfo) || parser.parseComma() ||
-      parser.parseOperand(memrefInfo) ||
-      parser.parseOperandList(indexInfo, OpAsmParser::Delimiter::Square) ||
-      parser.parseOptionalAttrDict(result.attributes) ||
-      parser.parseColonType(memrefType) ||
-      parser.resolveOperand(storeValueInfo, memrefType.getElementType(),
-                            result.operands) ||
-      parser.resolveOperand(memrefInfo, memrefType, result.operands) ||
-      parser.resolveOperands(indexInfo, indexTy, result.operands));
-}
-
 static LogicalResult verify(StoreOp op) {
   if (op.getNumOperands() != 2 + op.getMemRefType().getRank())
     return op.emitOpError("store index operand count not equal to memref rank");
@@ -2156,51 +1968,6 @@ static Type getTensorTypeFromMemRefType(Type type) {
   return NoneType::get(type.getContext());
 }
 
-//===----------------------------------------------------------------------===//
-// TensorLoadOp
-//===----------------------------------------------------------------------===//
-
-static void print(OpAsmPrinter &p, TensorLoadOp op) {
-  p << "tensor_load " << op.getOperand();
-  p.printOptionalAttrDict(op.getAttrs());
-  p << " : " << op.getOperand().getType();
-}
-
-static ParseResult parseTensorLoadOp(OpAsmParser &parser,
-                                     OperationState &result) {
-  OpAsmParser::OperandType op;
-  Type type;
-  return failure(
-      parser.parseOperand(op) ||
-      parser.parseOptionalAttrDict(result.attributes) ||
-      parser.parseColonType(type) ||
-      parser.resolveOperand(op, type, result.operands) ||
-      parser.addTypeToList(getTensorTypeFromMemRefType(type), result.types));
-}
-
-//===----------------------------------------------------------------------===//
-// TensorStoreOp
-//===----------------------------------------------------------------------===//
-
-static void print(OpAsmPrinter &p, TensorStoreOp op) {
-  p << "tensor_store " << op.tensor() << ", " << op.memref();
-  p.printOptionalAttrDict(op.getAttrs());
-  p << " : " << op.memref().getType();
-}
-
-static ParseResult parseTensorStoreOp(OpAsmParser &parser,
-                                      OperationState &result) {
-  SmallVector<OpAsmParser::OperandType, 2> ops;
-  Type type;
-  llvm::SMLoc loc = parser.getCurrentLocation();
-  return failure(
-      parser.parseOperandList(ops, /*requiredOperandCount=*/2) ||
-      parser.parseOptionalAttrDict(result.attributes) ||
-      parser.parseColonType(type) ||
-      parser.resolveOperands(ops, {getTensorTypeFromMemRefType(type), type},
-                             loc, result.operands));
-}
-
 //===----------------------------------------------------------------------===//
 // TruncateIOp
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/lib/Dialect/VectorOps/VectorOps.cpp b/mlir/lib/Dialect/VectorOps/VectorOps.cpp
index 35dbc83d4595..53c5cbd57319 100644
--- a/mlir/lib/Dialect/VectorOps/VectorOps.cpp
+++ b/mlir/lib/Dialect/VectorOps/VectorOps.cpp
@@ -412,31 +412,6 @@ SmallVector<AffineMap, 4> ContractionOp::getIndexingMaps() {
 // ExtractElementOp
 //===----------------------------------------------------------------------===//
 
-static void print(OpAsmPrinter &p, vector::ExtractElementOp op) {
-  p << op.getOperationName() << " " << op.vector() << "[" << op.position()
-    << " : " << op.position().getType() << "]";
-  p.printOptionalAttrDict(op.getAttrs());
-  p << " : " << op.vector().getType();
-}
-
-static ParseResult parseExtractElementOp(OpAsmParser &parser,
-                                         OperationState &result) {
-  OpAsmParser::OperandType vector, position;
-  Type positionType;
-  VectorType vectorType;
-  if (parser.parseOperand(vector) || parser.parseLSquare() ||
-      parser.parseOperand(position) || parser.parseColonType(positionType) ||
-      parser.parseRSquare() ||
-      parser.parseOptionalAttrDict(result.attributes) ||
-      parser.parseColonType(vectorType))
-    return failure();
-  Type resultType = vectorType.getElementType();
-  return failure(
-      parser.resolveOperand(vector, vectorType, result.operands) ||
-      parser.resolveOperand(position, positionType, result.operands) ||
-      parser.addTypeToList(resultType, result.types));
-}
-
 static LogicalResult verify(vector::ExtractElementOp op) {
   VectorType vectorType = op.getVectorType();
   if (vectorType.getRank() != 1)
@@ -715,33 +690,6 @@ static ParseResult parseShuffleOp(OpAsmParser &parser, OperationState &result) {
 // InsertElementOp
 //===----------------------------------------------------------------------===//
 
-static void print(OpAsmPrinter &p, InsertElementOp op) {
-  p << op.getOperationName() << " " << op.source() << ", " << op.dest() << "["
-    << op.position() << " : " << op.position().getType() << "]";
-  p.printOptionalAttrDict(op.getAttrs());
-  p << " : " << op.dest().getType();
-}
-
-static ParseResult parseInsertElementOp(OpAsmParser &parser,
-                                        OperationState &result) {
-  OpAsmParser::OperandType source, dest, position;
-  Type positionType;
-  VectorType destType;
-  if (parser.parseOperand(source) || parser.parseComma() ||
-      parser.parseOperand(dest) || parser.parseLSquare() ||
-      parser.parseOperand(position) || parser.parseColonType(positionType) ||
-      parser.parseRSquare() ||
-      parser.parseOptionalAttrDict(result.attributes) ||
-      parser.parseColonType(destType))
-    return failure();
-  Type sourceType = destType.getElementType();
-  return failure(
-      parser.resolveOperand(source, sourceType, result.operands) ||
-      parser.resolveOperand(dest, destType, result.operands) ||
-      parser.resolveOperand(position, positionType, result.operands) ||
-      parser.addTypeToList(destType, result.types));
-}
-
 static LogicalResult verify(InsertElementOp op) {
   auto dstVectorType = op.getDestVectorType();
   if (dstVectorType.getRank() != 1)

diff  --git a/mlir/test/IR/invalid-ops.mlir b/mlir/test/IR/invalid-ops.mlir
index cebc87ea94a1..20f90c76e3d1 100644
--- a/mlir/test/IR/invalid-ops.mlir
+++ b/mlir/test/IR/invalid-ops.mlir
@@ -226,7 +226,7 @@ func @func_with_ops(i32, i32) {
 // Integer comparisons are not recognized for float types.
 func @func_with_ops(f32, f32) {
 ^bb0(%a : f32, %b : f32):
-  %r = cmpi "eq", %a, %b : f32 // expected-error {{operand #0 must be integer-like}}
+  %r = cmpi "eq", %a, %b : f32 // expected-error {{'lhs' must be integer-like, but got 'f32'}}
 }
 
 // -----
@@ -298,13 +298,13 @@ func @func_with_ops(i1, tensor<42xi32>, tensor<?xi32>) {
 // -----
 
 func @invalid_select_shape(%cond : i1, %idx : () -> ()) {
-  // expected-error@+1 {{expected type with valid i1 shape}}
+  // expected-error@+1 {{'result' must be integer-like or floating-point-like, but got '() -> ()'}}
   %sel = select %cond, %idx, %idx : () -> ()
 
 // -----
 
 func @invalid_cmp_shape(%idx : () -> ()) {
-  // expected-error@+1 {{expected type with valid i1 shape}}
+  // expected-error@+1 {{'lhs' must be integer-like, but got '() -> ()'}}
   %cmp = cmpi "eq", %idx, %idx : () -> ()
 
 // -----
@@ -340,7 +340,7 @@ func @dma_wait_no_tag_memref(%tag : f32, %c0 : index) {
 // -----
 
 func @invalid_cmp_attr(%idx : i32) {
-  // expected-error@+1 {{expected string comparison predicate attribute}}
+  // expected-error@+1 {{invalid kind of attribute specified}}
   %cmp = cmpi i1, %idx, %idx : i32
 
 // -----

diff  --git a/mlir/tools/mlir-tblgen/OpFormatGen.cpp b/mlir/tools/mlir-tblgen/OpFormatGen.cpp
index b8aeb904e187..62653bc2da03 100644
--- a/mlir/tools/mlir-tblgen/OpFormatGen.cpp
+++ b/mlir/tools/mlir-tblgen/OpFormatGen.cpp
@@ -219,16 +219,26 @@ struct OperationFormat {
     void setBuilderIdx(int idx) { builderIdx = idx; }
 
     /// Get the variable this type is resolved to, or None.
-    Optional<StringRef> getVariable() const { return variableName; }
-    void setVariable(StringRef variable) { variableName = variable; }
+    const NamedTypeConstraint *getVariable() const { return variable; }
+    Optional<StringRef> getVarTransformer() const {
+      return variableTransformer;
+    }
+    void setVariable(const NamedTypeConstraint *var,
+                     Optional<StringRef> transformer) {
+      variable = var;
+      variableTransformer = transformer;
+    }
 
   private:
     /// If the type is resolved with a buildable type, this is the index into
     /// 'buildableTypes' in the parent format.
     Optional<int> builderIdx;
     /// If the type is resolved based upon another operand or result, this is
-    /// the name of the variable that this type is resolved to.
-    Optional<StringRef> variableName;
+    /// the variable that this type is resolved to.
+    const NamedTypeConstraint *variable;
+    /// If the type is resolved based upon another operand or result, this is
+    /// a transformer to apply to the variable when resolving.
+    Optional<StringRef> variableTransformer;
   };
 
   OperationFormat(const Operator &op)
@@ -487,6 +497,34 @@ void OperationFormat::genParser(Operator &op, OpClass &opClass) {
 
 void OperationFormat::genParserTypeResolution(Operator &op,
                                               OpMethodBody &body) {
+  // If any of type resolutions use transformed variables, make sure that the
+  // types of those variables are resolved.
+  SmallPtrSet<const NamedTypeConstraint *, 8> verifiedVariables;
+  FmtContext verifierFCtx;
+  for (TypeResolution &resolver :
+       llvm::concat<TypeResolution>(resultTypes, operandTypes)) {
+    Optional<StringRef> transformer = resolver.getVarTransformer();
+    if (!transformer)
+      continue;
+    // Ensure that we don't verify the same variables twice.
+    const NamedTypeConstraint *variable = resolver.getVariable();
+    if (!verifiedVariables.insert(variable).second)
+      continue;
+
+    auto constraint = variable->constraint;
+    body << "  for (Type type : " << variable->name << "Types) {\n"
+         << "    (void)type;\n"
+         << "    if (!("
+         << tgfmt(constraint.getConditionTemplate(),
+                  &verifierFCtx.withSelf("type"))
+         << ")) {\n"
+         << formatv("      return parser.emitError(parser.getNameLoc()) << "
+                    "\"'{0}' must be {1}, but got \" << type;\n",
+                    variable->name, constraint.getDescription())
+         << "    }\n"
+         << "  }\n";
+  }
+
   // Initialize the set of buildable types.
   if (!buildableTypes.empty()) {
     body << "  Builder &builder = parser.getBuilder();\n";
@@ -498,18 +536,27 @@ void OperationFormat::genParserTypeResolution(Operator &op,
            << tgfmt(it.first, &typeBuilderCtx) << ";\n";
   }
 
+  // Emit the code necessary for a type resolver.
+  auto emitTypeResolver = [&](TypeResolution &resolver, StringRef curVar) {
+    if (Optional<int> val = resolver.getBuilderIdx()) {
+      body << "odsBuildableType" << *val;
+    } else if (const NamedTypeConstraint *var = resolver.getVariable()) {
+      if (Optional<StringRef> tform = resolver.getVarTransformer())
+        body << tgfmt(*tform, &FmtContext().withSelf(var->name + "Types[0]"));
+      else
+        body << var->name << "Types";
+    } else {
+      body << curVar << "Types";
+    }
+  };
+
   // Resolve each of the result types.
   if (allResultTypes) {
     body << "  result.addTypes(allResultTypes);\n";
   } else {
     for (unsigned i = 0, e = op.getNumResults(); i != e; ++i) {
       body << "  result.addTypes(";
-      if (Optional<int> val = resultTypes[i].getBuilderIdx())
-        body << "odsBuildableType" << *val;
-      else if (Optional<StringRef> var = resultTypes[i].getVariable())
-        body << *var << "Types";
-      else
-        body << op.getResultName(i) << "Types";
+      emitTypeResolver(resultTypes[i], op.getResultName(i));
       body << ");\n";
     }
   }
@@ -552,25 +599,19 @@ void OperationFormat::genParserTypeResolution(Operator &op,
   if (hasAllOperands) {
     body << "  if (parser.resolveOperands(allOperands, ";
 
-    auto emitOperandType = [&](int idx) {
-      if (Optional<int> val = operandTypes[idx].getBuilderIdx())
-        body << "ArrayRef<Type>(odsBuildableType" << *val << ")";
-      else if (Optional<StringRef> var = operandTypes[idx].getVariable())
-        body << *var << "Types";
-      else
-        body << op.getOperand(idx).name << "Types";
-    };
-
     // Group all of the operand types together to perform the resolution all at
     // once. Use llvm::concat to perform the merge. llvm::concat does not allow
     // the case of a single range, so guard it here.
     if (op.getNumOperands() > 1) {
       body << "llvm::concat<const Type>(";
-      interleaveComma(llvm::seq<int>(0, op.getNumOperands()), body,
-                      emitOperandType);
+      interleaveComma(llvm::seq<int>(0, op.getNumOperands()), body, [&](int i) {
+        body << "ArrayRef<Type>(";
+        emitTypeResolver(operandTypes[i], op.getOperand(i).name);
+        body << ")";
+      });
       body << ")";
     } else {
-      emitOperandType(/*idx=*/0);
+      emitTypeResolver(operandTypes.front(), op.getOperand(0).name);
     }
 
     body << ", allOperandLoc, result.operands))\n"
@@ -583,13 +624,12 @@ void OperationFormat::genParserTypeResolution(Operator &op,
   for (unsigned i = 0, e = op.getNumOperands(); i != e; ++i) {
     NamedTypeConstraint &operand = op.getOperand(i);
     body << "  if (parser.resolveOperands(" << operand.name << "Operands, ";
-    if (Optional<int> val = operandTypes[i].getBuilderIdx())
-      body << "odsBuildableType" << *val << ", ";
-    else if (Optional<StringRef> var = operandTypes[i].getVariable())
-      body << *var << "Types, " << operand.name << "OperandsLoc, ";
-    else
-      body << operand.name << "Types, " << operand.name << "OperandsLoc, ";
-    body << "result.operands))\n    return failure();\n";
+    emitTypeResolver(operandTypes[i], operand.name);
+
+    // If this isn't a buildable type, verify the sizes match by adding the loc.
+    if (!operandTypes[i].getBuilderIdx())
+      body << ", " << operand.name << "OperandsLoc";
+    body << ", result.operands))\n    return failure();\n";
   }
 }
 
@@ -954,18 +994,30 @@ class FormatParser {
   LogicalResult parse();
 
 private:
+  /// This struct represents a type resolution instance. It includes a specific
+  /// type as well as an optional transformer to apply to that type in order to
+  /// properly resolve the type of a variable.
+  struct TypeResolutionInstance {
+    const NamedTypeConstraint *type;
+    Optional<StringRef> transformer;
+  };
+
   /// Given the values of an `AllTypesMatch` trait, check for inferrable type
   /// resolution.
   void handleAllTypesMatchConstraint(
       ArrayRef<StringRef> values,
-      llvm::StringMap<const NamedTypeConstraint *> &variableTyResolver);
+      llvm::StringMap<TypeResolutionInstance> &variableTyResolver);
   /// Check for inferrable type resolution given all operands, and or results,
   /// have the same type. If 'includeResults' is true, the results also have the
   /// same type as all of the operands.
   void handleSameTypesConstraint(
-      llvm::StringMap<const NamedTypeConstraint *> &variableTyResolver,
+      llvm::StringMap<TypeResolutionInstance> &variableTyResolver,
       bool includeResults);
 
+  /// Returns an argument with the given name that has been seen within the
+  /// format.
+  const NamedTypeConstraint *findSeenArg(StringRef name);
+
   /// Parse a specific element.
   LogicalResult parseElement(std::unique_ptr<Element> &element,
                              bool isTopLevel);
@@ -1044,16 +1096,21 @@ LogicalResult FormatParser::parse() {
     return emitError(loc, "format missing 'attr-dict' directive");
 
   // Check for any type traits that we can use for inferring types.
-  llvm::StringMap<const NamedTypeConstraint *> variableTyResolver;
+  llvm::StringMap<TypeResolutionInstance> variableTyResolver;
   for (const OpTrait &trait : op.getTraits()) {
     const llvm::Record &def = trait.getDef();
-    if (def.isSubClassOf("AllTypesMatch"))
+    if (def.isSubClassOf("AllTypesMatch")) {
       handleAllTypesMatchConstraint(def.getValueAsListOfStrings("values"),
                                     variableTyResolver);
-    else if (def.getName() == "SameTypeOperands")
+    } else if (def.getName() == "SameTypeOperands") {
       handleSameTypesConstraint(variableTyResolver, /*includeResults=*/false);
-    else if (def.getName() == "SameOperandsAndResultType")
+    } else if (def.getName() == "SameOperandsAndResultType") {
       handleSameTypesConstraint(variableTyResolver, /*includeResults=*/true);
+    } else if (def.isSubClassOf("TypesMatchWith")) {
+      if (const auto *lhsArg = findSeenArg(def.getValueAsString("lhs")))
+        variableTyResolver[def.getValueAsString("rhs")] = {
+            lhsArg, def.getValueAsString("transformer")};
+    }
   }
 
   // Check that all of the result types can be inferred.
@@ -1066,7 +1123,8 @@ LogicalResult FormatParser::parse() {
       // Check to see if we can infer this type from another variable.
       auto varResolverIt = variableTyResolver.find(op.getResultName(i));
       if (varResolverIt != variableTyResolver.end()) {
-        fmt.resultTypes[i].setVariable(varResolverIt->second->name);
+        fmt.resultTypes[i].setVariable(varResolverIt->second.type,
+                                       varResolverIt->second.transformer);
         continue;
       }
 
@@ -1102,7 +1160,8 @@ LogicalResult FormatParser::parse() {
     // Check to see if we can infer this type from another variable.
     auto varResolverIt = variableTyResolver.find(op.getOperand(i).name);
     if (varResolverIt != variableTyResolver.end()) {
-      fmt.operandTypes[i].setVariable(varResolverIt->second->name);
+      fmt.operandTypes[i].setVariable(varResolverIt->second.type,
+                                      varResolverIt->second.transformer);
       continue;
     }
 
@@ -1121,30 +1180,23 @@ LogicalResult FormatParser::parse() {
 
 void FormatParser::handleAllTypesMatchConstraint(
     ArrayRef<StringRef> values,
-    llvm::StringMap<const NamedTypeConstraint *> &variableTyResolver) {
+    llvm::StringMap<TypeResolutionInstance> &variableTyResolver) {
   for (unsigned i = 0, e = values.size(); i != e; ++i) {
     // Check to see if this value matches a resolved operand or result type.
-    const NamedTypeConstraint *arg = nullptr;
-    if ((arg = findArg(op.getOperands(), values[i]))) {
-      if (!seenOperandTypes.test(arg - op.operand_begin()))
-        continue;
-    } else if ((arg = findArg(op.getResults(), values[i]))) {
-      if (!seenResultTypes.test(arg - op.result_begin()))
-        continue;
-    } else {
+    const NamedTypeConstraint *arg = findSeenArg(values[i]);
+    if (!arg)
       continue;
-    }
 
     // Mark this value as the type resolver for the other variables.
     for (unsigned j = 0; j != i; ++j)
-      variableTyResolver[values[j]] = arg;
+      variableTyResolver[values[j]] = {arg, llvm::None};
     for (unsigned j = i + 1; j != e; ++j)
-      variableTyResolver[values[j]] = arg;
+      variableTyResolver[values[j]] = {arg, llvm::None};
   }
 }
 
 void FormatParser::handleSameTypesConstraint(
-    llvm::StringMap<const NamedTypeConstraint *> &variableTyResolver,
+    llvm::StringMap<TypeResolutionInstance> &variableTyResolver,
     bool includeResults) {
   const NamedTypeConstraint *resolver = nullptr;
   int resolvedIt = -1;
@@ -1160,14 +1212,22 @@ void FormatParser::handleSameTypesConstraint(
   // Set the resolvers for each operand and result.
   for (unsigned i = 0, e = op.getNumOperands(); i != e; ++i)
     if (!seenOperandTypes.test(i) && !op.getOperand(i).name.empty())
-      variableTyResolver[op.getOperand(i).name] = resolver;
+      variableTyResolver[op.getOperand(i).name] = {resolver, llvm::None};
   if (includeResults) {
     for (unsigned i = 0, e = op.getNumResults(); i != e; ++i)
       if (!seenResultTypes.test(i) && !op.getResultName(i).empty())
-        variableTyResolver[op.getResultName(i)] = resolver;
+        variableTyResolver[op.getResultName(i)] = {resolver, llvm::None};
   }
 }
 
+const NamedTypeConstraint *FormatParser::findSeenArg(StringRef name) {
+  if (auto *arg = findArg(op.getOperands(), name))
+    return seenOperandTypes.test(arg - op.operand_begin()) ? arg : nullptr;
+  if (auto *arg = findArg(op.getResults(), name))
+    return seenResultTypes.test(arg - op.result_begin()) ? arg : nullptr;
+  return nullptr;
+}
+
 LogicalResult FormatParser::parseElement(std::unique_ptr<Element> &element,
                                          bool isTopLevel) {
   // Directives.
@@ -1191,7 +1251,8 @@ LogicalResult FormatParser::parseVariable(std::unique_ptr<Element> &element,
   StringRef name = varTok.getSpelling().drop_front();
   llvm::SMLoc loc = varTok.getLoc();
 
-  // Check that the parsed argument is something actually registered on the op.
+  // Check that the parsed argument is something actually registered on the
+  // op.
   /// Attributes
   if (const NamedAttribute *attr = findArg(op.getAttributes(), name)) {
     if (isTopLevel && !seenAttrs.insert(attr).second)


        


More information about the Mlir-commits mailing list