[Mlir-commits] [mlir] 26222db - [mlir][DeclarativeParser] Add support for the TypesMatchWith trait.
River Riddle
llvmlistbot at llvm.org
Fri Feb 21 15:17:34 PST 2020
Author: River Riddle
Date: 2020-02-21T15:15:31-08:00
New Revision: 26222db01b079023d0fe3bb60f2c1b38f4f19d5a
URL: https://github.com/llvm/llvm-project/commit/26222db01b079023d0fe3bb60f2c1b38f4f19d5a
DIFF: https://github.com/llvm/llvm-project/commit/26222db01b079023d0fe3bb60f2c1b38f4f19d5a.diff
LOG: [mlir][DeclarativeParser] Add support for the TypesMatchWith trait.
This allows injecting type constraints that are not direct one-to-one mappings, for example when one type is equal to the element type of another. As a result, several more parsers can now be moved over to the declarative form.
Differential Revision: https://reviews.llvm.org/D74648
Added:
Modified:
mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
mlir/include/mlir/Dialect/VectorOps/VectorOps.td
mlir/include/mlir/IR/OpImplementation.h
mlir/include/mlir/IR/OperationSupport.h
mlir/lib/Dialect/StandardOps/IR/Ops.cpp
mlir/lib/Dialect/VectorOps/VectorOps.cpp
mlir/test/IR/invalid-ops.mlir
mlir/tools/mlir-tblgen/OpFormatGen.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
index a0b739ea6ec7..fe28f8d7143f 100644
--- a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
+++ b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
@@ -357,6 +357,8 @@ def CallIndirectOp : Std_Op<"call_indirect", [
let verifier = ?;
let hasCanonicalizer = 1;
+
+ let assemblyFormat = "$callee `(` $operands `)` attr-dict `:` type($callee)";
}
def CeilFOp : FloatUnaryOp<"ceilf"> {
@@ -490,6 +492,8 @@ def CmpIOp : Std_Op<"cmpi",
let verifier = [{ return success(); }];
let hasFolder = 1;
+
+ let assemblyFormat = "$predicate `,` $lhs `,` $rhs attr-dict `:` type($lhs)";
}
def CondBranchOp : Std_Op<"cond_br", [Terminator]> {
@@ -761,6 +765,10 @@ def ExtractElementOp : Std_Op<"extract_element",
}];
let hasFolder = 1;
+
+ let assemblyFormat = [{
+ $aggregate `[` $indices `]` attr-dict `:` type($aggregate)
+ }];
}
def IndexCastOp : CastOp<"index_cast">, Arguments<(ins AnyType:$in)> {
@@ -853,6 +861,8 @@ def LoadOp : Std_Op<"load",
}];
let hasFolder = 1;
+
+ let assemblyFormat = "$memref `[` $indices `]` attr-dict `:` type($memref)";
}
def LogOp : FloatUnaryOp<"log"> {
@@ -1090,6 +1100,10 @@ def SelectOp : Std_Op<"select", [NoSideEffect, SameOperandsAndResultShape,
}];
let hasFolder = 1;
+
+ let assemblyFormat = [{
+ $condition `,` $true_value `,` $false_value attr-dict `:` type($result)
+ }];
}
def SignExtendIOp : Std_Op<"sexti",
@@ -1222,6 +1236,8 @@ def SplatOp : Std_Op<"splat", [NoSideEffect,
[{ build(builder, result, aggregateType, element); }]>];
let hasFolder = 1;
+
+ let assemblyFormat = "$input attr-dict `:` type($aggregate)";
}
def StoreOp : Std_Op<"store",
@@ -1264,6 +1280,10 @@ def StoreOp : Std_Op<"store",
}];
let hasFolder = 1;
+
+ let assemblyFormat = [{
+ $value `,` $memref `[` $indices `]` attr-dict `:` type($memref)
+ }];
}
def SubFOp : FloatArithmeticOp<"subf"> {
@@ -1517,11 +1537,12 @@ def TensorLoadOp : Std_Op<"tensor_load",
result.addTypes(resultType);
}]>];
-
let extraClassDeclaration = [{
/// The result of a tensor_load is always a tensor.
TensorType getType() { return getResult().getType().cast<TensorType>(); }
}];
+
+ let assemblyFormat = "$memref attr-dict `:` type($memref)";
}
def TensorStoreOp : Std_Op<"tensor_store",
@@ -1545,6 +1566,8 @@ def TensorStoreOp : Std_Op<"tensor_store",
let arguments = (ins AnyTensor:$tensor, AnyMemRef:$memref);
// TensorStoreOp is fully verified by traits.
let verifier = ?;
+
+ let assemblyFormat = "$tensor `,` $memref attr-dict `:` type($memref)";
}
def TruncateIOp : Std_Op<"trunci", [NoSideEffect, SameOperandsAndResultShape]> {
diff --git a/mlir/include/mlir/Dialect/VectorOps/VectorOps.td b/mlir/include/mlir/Dialect/VectorOps/VectorOps.td
index ce6029a5d497..70917ff2b882 100644
--- a/mlir/include/mlir/Dialect/VectorOps/VectorOps.td
+++ b/mlir/include/mlir/Dialect/VectorOps/VectorOps.td
@@ -363,6 +363,10 @@ def Vector_ExtractElementOp :
return vector().getType().cast<VectorType>();
}
}];
+
+ let assemblyFormat = [{
+ $vector `[` $position `:` type($position) `]` attr-dict `:` type($vector)
+ }];
}
def Vector_ExtractOp :
@@ -512,6 +516,11 @@ def Vector_InsertElementOp :
return dest().getType().cast<VectorType>();
}
}];
+
+ let assemblyFormat = [{
+ $source `,` $dest `[` $position `:` type($position) `]` attr-dict `:`
+ type($result)
+ }];
}
def Vector_InsertOp :
diff --git a/mlir/include/mlir/IR/OpImplementation.h b/mlir/include/mlir/IR/OpImplementation.h
index 0d0736b4512f..08fac5ea19ef 100644
--- a/mlir/include/mlir/IR/OpImplementation.h
+++ b/mlir/include/mlir/IR/OpImplementation.h
@@ -496,6 +496,12 @@ class OpAsmParser {
return failure();
return success();
}
+ template <typename Operands>
+ ParseResult resolveOperands(Operands &&operands, Type type, llvm::SMLoc loc,
+ SmallVectorImpl<Value> &result) {
+ return resolveOperands(std::forward<Operands>(operands),
+ ArrayRef<Type>(type), loc, result);
+ }
template <typename Operands, typename Types>
ParseResult resolveOperands(Operands &&operands, Types &&types,
llvm::SMLoc loc, SmallVectorImpl<Value> &result) {
diff --git a/mlir/include/mlir/IR/OperationSupport.h b/mlir/include/mlir/IR/OperationSupport.h
index b0c97e2dc1a4..89fffa0df60d 100644
--- a/mlir/include/mlir/IR/OperationSupport.h
+++ b/mlir/include/mlir/IR/OperationSupport.h
@@ -294,6 +294,11 @@ struct OperationState {
void addTypes(ArrayRef<Type> newTypes) {
types.append(newTypes.begin(), newTypes.end());
}
+ template <typename RangeT>
+ std::enable_if_t<!std::is_convertible<RangeT, ArrayRef<Type>>::value>
+ addTypes(RangeT &&newTypes) {
+ types.append(newTypes.begin(), newTypes.end());
+ }
/// Add an attribute with the specified name.
void addAttribute(StringRef name, Attribute attr) {
diff --git a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
index 42daf193b9cb..aa0a42812342 100644
--- a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
+++ b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
@@ -505,29 +505,6 @@ struct SimplifyIndirectCallWithKnownCallee
};
} // end anonymous namespace.
-static ParseResult parseCallIndirectOp(OpAsmParser &parser,
- OperationState &result) {
- FunctionType calleeType;
- OpAsmParser::OperandType callee;
- llvm::SMLoc operandsLoc;
- SmallVector<OpAsmParser::OperandType, 4> operands;
- return failure(
- parser.parseOperand(callee) || parser.getCurrentLocation(&operandsLoc) ||
- parser.parseOperandList(operands, OpAsmParser::Delimiter::Paren) ||
- parser.parseOptionalAttrDict(result.attributes) ||
- parser.parseColonType(calleeType) ||
- parser.resolveOperand(callee, calleeType, result.operands) ||
- parser.resolveOperands(operands, calleeType.getInputs(), operandsLoc,
- result.operands) ||
- parser.addTypesToList(calleeType.getResults(), result.types));
-}
-
-static void print(OpAsmPrinter &p, CallIndirectOp op) {
- p << "call_indirect " << op.getCallee() << '(' << op.getArgOperands() << ')';
- p.printOptionalAttrDict(op.getAttrs(), /*elidedAttrs=*/{"callee"});
- p << " : " << op.getCallee().getType();
-}
-
void CallIndirectOp::getCanonicalizationPatterns(
OwningRewritePatternList &results, MLIRContext *context) {
results.insert<SimplifyIndirectCallWithKnownCallee>(context);
@@ -570,55 +547,6 @@ static void buildCmpIOp(Builder *build, OperationState &result,
build->getI64IntegerAttr(static_cast<int64_t>(predicate)));
}
-static ParseResult parseCmpIOp(OpAsmParser &parser, OperationState &result) {
- SmallVector<OpAsmParser::OperandType, 2> ops;
- SmallVector<NamedAttribute, 4> attrs;
- Attribute predicateNameAttr;
- Type type;
- if (parser.parseAttribute(predicateNameAttr, CmpIOp::getPredicateAttrName(),
- attrs) ||
- parser.parseComma() || parser.parseOperandList(ops, 2) ||
- parser.parseOptionalAttrDict(attrs) || parser.parseColonType(type) ||
- parser.resolveOperands(ops, type, result.operands))
- return failure();
-
- if (!predicateNameAttr.isa<StringAttr>())
- return parser.emitError(parser.getNameLoc(),
- "expected string comparison predicate attribute");
-
- // Rewrite string attribute to an enum value.
- StringRef predicateName = predicateNameAttr.cast<StringAttr>().getValue();
- Optional<CmpIPredicate> predicate = symbolizeCmpIPredicate(predicateName);
- if (!predicate.hasValue())
- return parser.emitError(parser.getNameLoc())
- << "unknown comparison predicate \"" << predicateName << "\"";
-
- auto builder = parser.getBuilder();
- Type i1Type = getCheckedI1SameShape(type);
- if (!i1Type)
- return parser.emitError(parser.getNameLoc(),
- "expected type with valid i1 shape");
-
- attrs[0].second = builder.getI64IntegerAttr(static_cast<int64_t>(*predicate));
- result.attributes = attrs;
-
- result.addTypes({i1Type});
- return success();
-}
-
-static void print(OpAsmPrinter &p, CmpIOp op) {
- p << "cmpi ";
-
- Builder b(op.getContext());
- auto predicateValue =
- op.getAttrOfType<IntegerAttr>(CmpIOp::getPredicateAttrName()).getInt();
- p << '"' << stringifyCmpIPredicate(static_cast<CmpIPredicate>(predicateValue))
- << '"' << ", " << op.lhs() << ", " << op.rhs();
- p.printOptionalAttrDict(op.getAttrs(),
- /*elidedAttrs=*/{CmpIOp::getPredicateAttrName()});
- p << " : " << op.lhs().getType();
-}
-
// Compute `lhs` `pred` `rhs`, where `pred` is one of the known integer
// comparison predicates.
static bool applyCmpPredicate(CmpIPredicate predicate, const APInt &lhs,
@@ -1486,30 +1414,6 @@ LogicalResult DmaWaitOp::fold(ArrayRef<Attribute> cstOperands,
// ExtractElementOp
//===----------------------------------------------------------------------===//
-static void print(OpAsmPrinter &p, ExtractElementOp op) {
- p << "extract_element " << op.getAggregate() << '[' << op.getIndices();
- p << ']';
- p.printOptionalAttrDict(op.getAttrs());
- p << " : " << op.getAggregate().getType();
-}
-
-static ParseResult parseExtractElementOp(OpAsmParser &parser,
- OperationState &result) {
- OpAsmParser::OperandType aggregateInfo;
- SmallVector<OpAsmParser::OperandType, 4> indexInfo;
- ShapedType type;
-
- auto indexTy = parser.getBuilder().getIndexType();
- return failure(
- parser.parseOperand(aggregateInfo) ||
- parser.parseOperandList(indexInfo, OpAsmParser::Delimiter::Square) ||
- parser.parseOptionalAttrDict(result.attributes) ||
- parser.parseColonType(type) ||
- parser.resolveOperand(aggregateInfo, type, result.operands) ||
- parser.resolveOperands(indexInfo, indexTy, result.operands) ||
- parser.addTypeToList(type.getElementType(), result.types));
-}
-
static LogicalResult verify(ExtractElementOp op) {
// Verify the # indices match if we have a ranked type.
auto aggregateType = op.getAggregate().getType().cast<ShapedType>();
@@ -1577,28 +1481,6 @@ OpFoldResult IndexCastOp::fold(ArrayRef<Attribute> cstOperands) {
// LoadOp
//===----------------------------------------------------------------------===//
-static void print(OpAsmPrinter &p, LoadOp op) {
- p << "load " << op.getMemRef() << '[' << op.getIndices() << ']';
- p.printOptionalAttrDict(op.getAttrs());
- p << " : " << op.getMemRefType();
-}
-
-static ParseResult parseLoadOp(OpAsmParser &parser, OperationState &result) {
- OpAsmParser::OperandType memrefInfo;
- SmallVector<OpAsmParser::OperandType, 4> indexInfo;
- MemRefType type;
-
- auto indexTy = parser.getBuilder().getIndexType();
- return failure(
- parser.parseOperand(memrefInfo) ||
- parser.parseOperandList(indexInfo, OpAsmParser::Delimiter::Square) ||
- parser.parseOptionalAttrDict(result.attributes) ||
- parser.parseColonType(type) ||
- parser.resolveOperand(memrefInfo, type, result.operands) ||
- parser.resolveOperands(indexInfo, indexTy, result.operands) ||
- parser.addTypeToList(type.getElementType(), result.types));
-}
-
static LogicalResult verify(LoadOp op) {
if (op.getNumOperands() != 1 + op.getMemRefType().getRank())
return op.emitOpError("incorrect number of indices for load");
@@ -1902,31 +1784,6 @@ bool SIToFPOp::areCastCompatible(Type a, Type b) {
// SelectOp
//===----------------------------------------------------------------------===//
-static ParseResult parseSelectOp(OpAsmParser &parser, OperationState &result) {
- SmallVector<OpAsmParser::OperandType, 3> ops;
- SmallVector<NamedAttribute, 4> attrs;
- Type type;
- if (parser.parseOperandList(ops, 3) ||
- parser.parseOptionalAttrDict(result.attributes) ||
- parser.parseColonType(type))
- return failure();
-
- auto i1Type = getCheckedI1SameShape(type);
- if (!i1Type)
- return parser.emitError(parser.getNameLoc(),
- "expected type with valid i1 shape");
-
- std::array<Type, 3> types = {i1Type, type, type};
- return failure(parser.resolveOperands(ops, types, parser.getNameLoc(),
- result.operands) ||
- parser.addTypeToList(type, result.types));
-}
-
-static void print(OpAsmPrinter &p, SelectOp op) {
- p << "select " << op.getOperands() << " : " << op.getTrueValue().getType();
- p.printOptionalAttrDict(op.getAttrs());
-}
-
OpFoldResult SelectOp::fold(ArrayRef<Attribute> operands) {
auto condition = getCondition();
@@ -1968,25 +1825,6 @@ static LogicalResult verify(SignExtendIOp op) {
// SplatOp
//===----------------------------------------------------------------------===//
-static void print(OpAsmPrinter &p, SplatOp op) {
- p << "splat " << op.getOperand();
- p.printOptionalAttrDict(op.getAttrs());
- p << " : " << op.getType();
-}
-
-static ParseResult parseSplatOp(OpAsmParser &parser, OperationState &result) {
- OpAsmParser::OperandType splatValueInfo;
- ShapedType shapedType;
-
- return failure(parser.parseOperand(splatValueInfo) ||
- parser.parseOptionalAttrDict(result.attributes) ||
- parser.parseColonType(shapedType) ||
- parser.resolveOperand(splatValueInfo,
- shapedType.getElementType(),
- result.operands) ||
- parser.addTypeToList(shapedType, result.types));
-}
-
static LogicalResult verify(SplatOp op) {
// TODO: we could replace this by a trait.
if (op.getOperand().getType() !=
@@ -2017,32 +1855,6 @@ OpFoldResult SplatOp::fold(ArrayRef<Attribute> operands) {
// StoreOp
//===----------------------------------------------------------------------===//
-static void print(OpAsmPrinter &p, StoreOp op) {
- p << "store " << op.getValueToStore();
- p << ", " << op.getMemRef() << '[' << op.getIndices() << ']';
- p.printOptionalAttrDict(op.getAttrs());
- p << " : " << op.getMemRefType();
-}
-
-static ParseResult parseStoreOp(OpAsmParser &parser, OperationState &result) {
- OpAsmParser::OperandType storeValueInfo;
- OpAsmParser::OperandType memrefInfo;
- SmallVector<OpAsmParser::OperandType, 4> indexInfo;
- MemRefType memrefType;
-
- auto indexTy = parser.getBuilder().getIndexType();
- return failure(
- parser.parseOperand(storeValueInfo) || parser.parseComma() ||
- parser.parseOperand(memrefInfo) ||
- parser.parseOperandList(indexInfo, OpAsmParser::Delimiter::Square) ||
- parser.parseOptionalAttrDict(result.attributes) ||
- parser.parseColonType(memrefType) ||
- parser.resolveOperand(storeValueInfo, memrefType.getElementType(),
- result.operands) ||
- parser.resolveOperand(memrefInfo, memrefType, result.operands) ||
- parser.resolveOperands(indexInfo, indexTy, result.operands));
-}
-
static LogicalResult verify(StoreOp op) {
if (op.getNumOperands() != 2 + op.getMemRefType().getRank())
return op.emitOpError("store index operand count not equal to memref rank");
@@ -2156,51 +1968,6 @@ static Type getTensorTypeFromMemRefType(Type type) {
return NoneType::get(type.getContext());
}
-//===----------------------------------------------------------------------===//
-// TensorLoadOp
-//===----------------------------------------------------------------------===//
-
-static void print(OpAsmPrinter &p, TensorLoadOp op) {
- p << "tensor_load " << op.getOperand();
- p.printOptionalAttrDict(op.getAttrs());
- p << " : " << op.getOperand().getType();
-}
-
-static ParseResult parseTensorLoadOp(OpAsmParser &parser,
- OperationState &result) {
- OpAsmParser::OperandType op;
- Type type;
- return failure(
- parser.parseOperand(op) ||
- parser.parseOptionalAttrDict(result.attributes) ||
- parser.parseColonType(type) ||
- parser.resolveOperand(op, type, result.operands) ||
- parser.addTypeToList(getTensorTypeFromMemRefType(type), result.types));
-}
-
-//===----------------------------------------------------------------------===//
-// TensorStoreOp
-//===----------------------------------------------------------------------===//
-
-static void print(OpAsmPrinter &p, TensorStoreOp op) {
- p << "tensor_store " << op.tensor() << ", " << op.memref();
- p.printOptionalAttrDict(op.getAttrs());
- p << " : " << op.memref().getType();
-}
-
-static ParseResult parseTensorStoreOp(OpAsmParser &parser,
- OperationState &result) {
- SmallVector<OpAsmParser::OperandType, 2> ops;
- Type type;
- llvm::SMLoc loc = parser.getCurrentLocation();
- return failure(
- parser.parseOperandList(ops, /*requiredOperandCount=*/2) ||
- parser.parseOptionalAttrDict(result.attributes) ||
- parser.parseColonType(type) ||
- parser.resolveOperands(ops, {getTensorTypeFromMemRefType(type), type},
- loc, result.operands));
-}
-
//===----------------------------------------------------------------------===//
// TruncateIOp
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/VectorOps/VectorOps.cpp b/mlir/lib/Dialect/VectorOps/VectorOps.cpp
index 35dbc83d4595..53c5cbd57319 100644
--- a/mlir/lib/Dialect/VectorOps/VectorOps.cpp
+++ b/mlir/lib/Dialect/VectorOps/VectorOps.cpp
@@ -412,31 +412,6 @@ SmallVector<AffineMap, 4> ContractionOp::getIndexingMaps() {
// ExtractElementOp
//===----------------------------------------------------------------------===//
-static void print(OpAsmPrinter &p, vector::ExtractElementOp op) {
- p << op.getOperationName() << " " << op.vector() << "[" << op.position()
- << " : " << op.position().getType() << "]";
- p.printOptionalAttrDict(op.getAttrs());
- p << " : " << op.vector().getType();
-}
-
-static ParseResult parseExtractElementOp(OpAsmParser &parser,
- OperationState &result) {
- OpAsmParser::OperandType vector, position;
- Type positionType;
- VectorType vectorType;
- if (parser.parseOperand(vector) || parser.parseLSquare() ||
- parser.parseOperand(position) || parser.parseColonType(positionType) ||
- parser.parseRSquare() ||
- parser.parseOptionalAttrDict(result.attributes) ||
- parser.parseColonType(vectorType))
- return failure();
- Type resultType = vectorType.getElementType();
- return failure(
- parser.resolveOperand(vector, vectorType, result.operands) ||
- parser.resolveOperand(position, positionType, result.operands) ||
- parser.addTypeToList(resultType, result.types));
-}
-
static LogicalResult verify(vector::ExtractElementOp op) {
VectorType vectorType = op.getVectorType();
if (vectorType.getRank() != 1)
@@ -715,33 +690,6 @@ static ParseResult parseShuffleOp(OpAsmParser &parser, OperationState &result) {
// InsertElementOp
//===----------------------------------------------------------------------===//
-static void print(OpAsmPrinter &p, InsertElementOp op) {
- p << op.getOperationName() << " " << op.source() << ", " << op.dest() << "["
- << op.position() << " : " << op.position().getType() << "]";
- p.printOptionalAttrDict(op.getAttrs());
- p << " : " << op.dest().getType();
-}
-
-static ParseResult parseInsertElementOp(OpAsmParser &parser,
- OperationState &result) {
- OpAsmParser::OperandType source, dest, position;
- Type positionType;
- VectorType destType;
- if (parser.parseOperand(source) || parser.parseComma() ||
- parser.parseOperand(dest) || parser.parseLSquare() ||
- parser.parseOperand(position) || parser.parseColonType(positionType) ||
- parser.parseRSquare() ||
- parser.parseOptionalAttrDict(result.attributes) ||
- parser.parseColonType(destType))
- return failure();
- Type sourceType = destType.getElementType();
- return failure(
- parser.resolveOperand(source, sourceType, result.operands) ||
- parser.resolveOperand(dest, destType, result.operands) ||
- parser.resolveOperand(position, positionType, result.operands) ||
- parser.addTypeToList(destType, result.types));
-}
-
static LogicalResult verify(InsertElementOp op) {
auto dstVectorType = op.getDestVectorType();
if (dstVectorType.getRank() != 1)
diff --git a/mlir/test/IR/invalid-ops.mlir b/mlir/test/IR/invalid-ops.mlir
index cebc87ea94a1..20f90c76e3d1 100644
--- a/mlir/test/IR/invalid-ops.mlir
+++ b/mlir/test/IR/invalid-ops.mlir
@@ -226,7 +226,7 @@ func @func_with_ops(i32, i32) {
// Integer comparisons are not recognized for float types.
func @func_with_ops(f32, f32) {
^bb0(%a : f32, %b : f32):
- %r = cmpi "eq", %a, %b : f32 // expected-error {{operand #0 must be integer-like}}
+ %r = cmpi "eq", %a, %b : f32 // expected-error {{'lhs' must be integer-like, but got 'f32'}}
}
// -----
@@ -298,13 +298,13 @@ func @func_with_ops(i1, tensor<42xi32>, tensor<?xi32>) {
// -----
func @invalid_select_shape(%cond : i1, %idx : () -> ()) {
- // expected-error at +1 {{expected type with valid i1 shape}}
+ // expected-error at +1 {{'result' must be integer-like or floating-point-like, but got '() -> ()'}}
%sel = select %cond, %idx, %idx : () -> ()
// -----
func @invalid_cmp_shape(%idx : () -> ()) {
- // expected-error at +1 {{expected type with valid i1 shape}}
+ // expected-error at +1 {{'lhs' must be integer-like, but got '() -> ()'}}
%cmp = cmpi "eq", %idx, %idx : () -> ()
// -----
@@ -340,7 +340,7 @@ func @dma_wait_no_tag_memref(%tag : f32, %c0 : index) {
// -----
func @invalid_cmp_attr(%idx : i32) {
- // expected-error at +1 {{expected string comparison predicate attribute}}
+ // expected-error at +1 {{invalid kind of attribute specified}}
%cmp = cmpi i1, %idx, %idx : i32
// -----
diff --git a/mlir/tools/mlir-tblgen/OpFormatGen.cpp b/mlir/tools/mlir-tblgen/OpFormatGen.cpp
index b8aeb904e187..62653bc2da03 100644
--- a/mlir/tools/mlir-tblgen/OpFormatGen.cpp
+++ b/mlir/tools/mlir-tblgen/OpFormatGen.cpp
@@ -219,16 +219,26 @@ struct OperationFormat {
void setBuilderIdx(int idx) { builderIdx = idx; }
/// Get the variable this type is resolved to, or None.
- Optional<StringRef> getVariable() const { return variableName; }
- void setVariable(StringRef variable) { variableName = variable; }
+ const NamedTypeConstraint *getVariable() const { return variable; }
+ Optional<StringRef> getVarTransformer() const {
+ return variableTransformer;
+ }
+ void setVariable(const NamedTypeConstraint *var,
+ Optional<StringRef> transformer) {
+ variable = var;
+ variableTransformer = transformer;
+ }
private:
/// If the type is resolved with a buildable type, this is the index into
/// 'buildableTypes' in the parent format.
Optional<int> builderIdx;
/// If the type is resolved based upon another operand or result, this is
- /// the name of the variable that this type is resolved to.
- Optional<StringRef> variableName;
+ /// the variable that this type is resolved to.
+ const NamedTypeConstraint *variable;
+ /// If the type is resolved based upon another operand or result, this is
+ /// a transformer to apply to the variable when resolving.
+ Optional<StringRef> variableTransformer;
};
OperationFormat(const Operator &op)
@@ -487,6 +497,34 @@ void OperationFormat::genParser(Operator &op, OpClass &opClass) {
void OperationFormat::genParserTypeResolution(Operator &op,
OpMethodBody &body) {
+ // If any of type resolutions use transformed variables, make sure that the
+ // types of those variables are resolved.
+ SmallPtrSet<const NamedTypeConstraint *, 8> verifiedVariables;
+ FmtContext verifierFCtx;
+ for (TypeResolution &resolver :
+ llvm::concat<TypeResolution>(resultTypes, operandTypes)) {
+ Optional<StringRef> transformer = resolver.getVarTransformer();
+ if (!transformer)
+ continue;
+ // Ensure that we don't verify the same variables twice.
+ const NamedTypeConstraint *variable = resolver.getVariable();
+ if (!verifiedVariables.insert(variable).second)
+ continue;
+
+ auto constraint = variable->constraint;
+ body << " for (Type type : " << variable->name << "Types) {\n"
+ << " (void)type;\n"
+ << " if (!("
+ << tgfmt(constraint.getConditionTemplate(),
+ &verifierFCtx.withSelf("type"))
+ << ")) {\n"
+ << formatv(" return parser.emitError(parser.getNameLoc()) << "
+ "\"'{0}' must be {1}, but got \" << type;\n",
+ variable->name, constraint.getDescription())
+ << " }\n"
+ << " }\n";
+ }
+
// Initialize the set of buildable types.
if (!buildableTypes.empty()) {
body << " Builder &builder = parser.getBuilder();\n";
@@ -498,18 +536,27 @@ void OperationFormat::genParserTypeResolution(Operator &op,
<< tgfmt(it.first, &typeBuilderCtx) << ";\n";
}
+ // Emit the code necessary for a type resolver.
+ auto emitTypeResolver = [&](TypeResolution &resolver, StringRef curVar) {
+ if (Optional<int> val = resolver.getBuilderIdx()) {
+ body << "odsBuildableType" << *val;
+ } else if (const NamedTypeConstraint *var = resolver.getVariable()) {
+ if (Optional<StringRef> tform = resolver.getVarTransformer())
+ body << tgfmt(*tform, &FmtContext().withSelf(var->name + "Types[0]"));
+ else
+ body << var->name << "Types";
+ } else {
+ body << curVar << "Types";
+ }
+ };
+
// Resolve each of the result types.
if (allResultTypes) {
body << " result.addTypes(allResultTypes);\n";
} else {
for (unsigned i = 0, e = op.getNumResults(); i != e; ++i) {
body << " result.addTypes(";
- if (Optional<int> val = resultTypes[i].getBuilderIdx())
- body << "odsBuildableType" << *val;
- else if (Optional<StringRef> var = resultTypes[i].getVariable())
- body << *var << "Types";
- else
- body << op.getResultName(i) << "Types";
+ emitTypeResolver(resultTypes[i], op.getResultName(i));
body << ");\n";
}
}
@@ -552,25 +599,19 @@ void OperationFormat::genParserTypeResolution(Operator &op,
if (hasAllOperands) {
body << " if (parser.resolveOperands(allOperands, ";
- auto emitOperandType = [&](int idx) {
- if (Optional<int> val = operandTypes[idx].getBuilderIdx())
- body << "ArrayRef<Type>(odsBuildableType" << *val << ")";
- else if (Optional<StringRef> var = operandTypes[idx].getVariable())
- body << *var << "Types";
- else
- body << op.getOperand(idx).name << "Types";
- };
-
// Group all of the operand types together to perform the resolution all at
// once. Use llvm::concat to perform the merge. llvm::concat does not allow
// the case of a single range, so guard it here.
if (op.getNumOperands() > 1) {
body << "llvm::concat<const Type>(";
- interleaveComma(llvm::seq<int>(0, op.getNumOperands()), body,
- emitOperandType);
+ interleaveComma(llvm::seq<int>(0, op.getNumOperands()), body, [&](int i) {
+ body << "ArrayRef<Type>(";
+ emitTypeResolver(operandTypes[i], op.getOperand(i).name);
+ body << ")";
+ });
body << ")";
} else {
- emitOperandType(/*idx=*/0);
+ emitTypeResolver(operandTypes.front(), op.getOperand(0).name);
}
body << ", allOperandLoc, result.operands))\n"
@@ -583,13 +624,12 @@ void OperationFormat::genParserTypeResolution(Operator &op,
for (unsigned i = 0, e = op.getNumOperands(); i != e; ++i) {
NamedTypeConstraint &operand = op.getOperand(i);
body << " if (parser.resolveOperands(" << operand.name << "Operands, ";
- if (Optional<int> val = operandTypes[i].getBuilderIdx())
- body << "odsBuildableType" << *val << ", ";
- else if (Optional<StringRef> var = operandTypes[i].getVariable())
- body << *var << "Types, " << operand.name << "OperandsLoc, ";
- else
- body << operand.name << "Types, " << operand.name << "OperandsLoc, ";
- body << "result.operands))\n return failure();\n";
+ emitTypeResolver(operandTypes[i], operand.name);
+
+ // If this isn't a buildable type, verify the sizes match by adding the loc.
+ if (!operandTypes[i].getBuilderIdx())
+ body << ", " << operand.name << "OperandsLoc";
+ body << ", result.operands))\n return failure();\n";
}
}
@@ -954,18 +994,30 @@ class FormatParser {
LogicalResult parse();
private:
+ /// This struct represents a type resolution instance. It includes a specific
+ /// type as well as an optional transformer to apply to that type in order to
+ /// properly resolve the type of a variable.
+ struct TypeResolutionInstance {
+ const NamedTypeConstraint *type;
+ Optional<StringRef> transformer;
+ };
+
/// Given the values of an `AllTypesMatch` trait, check for inferrable type
/// resolution.
void handleAllTypesMatchConstraint(
ArrayRef<StringRef> values,
- llvm::StringMap<const NamedTypeConstraint *> &variableTyResolver);
+ llvm::StringMap<TypeResolutionInstance> &variableTyResolver);
/// Check for inferrable type resolution given all operands, and or results,
/// have the same type. If 'includeResults' is true, the results also have the
/// same type as all of the operands.
void handleSameTypesConstraint(
- llvm::StringMap<const NamedTypeConstraint *> &variableTyResolver,
+ llvm::StringMap<TypeResolutionInstance> &variableTyResolver,
bool includeResults);
+ /// Returns an argument with the given name that has been seen within the
+ /// format.
+ const NamedTypeConstraint *findSeenArg(StringRef name);
+
/// Parse a specific element.
LogicalResult parseElement(std::unique_ptr<Element> &element,
bool isTopLevel);
@@ -1044,16 +1096,21 @@ LogicalResult FormatParser::parse() {
return emitError(loc, "format missing 'attr-dict' directive");
// Check for any type traits that we can use for inferring types.
- llvm::StringMap<const NamedTypeConstraint *> variableTyResolver;
+ llvm::StringMap<TypeResolutionInstance> variableTyResolver;
for (const OpTrait &trait : op.getTraits()) {
const llvm::Record &def = trait.getDef();
- if (def.isSubClassOf("AllTypesMatch"))
+ if (def.isSubClassOf("AllTypesMatch")) {
handleAllTypesMatchConstraint(def.getValueAsListOfStrings("values"),
variableTyResolver);
- else if (def.getName() == "SameTypeOperands")
+ } else if (def.getName() == "SameTypeOperands") {
handleSameTypesConstraint(variableTyResolver, /*includeResults=*/false);
- else if (def.getName() == "SameOperandsAndResultType")
+ } else if (def.getName() == "SameOperandsAndResultType") {
handleSameTypesConstraint(variableTyResolver, /*includeResults=*/true);
+ } else if (def.isSubClassOf("TypesMatchWith")) {
+ if (const auto *lhsArg = findSeenArg(def.getValueAsString("lhs")))
+ variableTyResolver[def.getValueAsString("rhs")] = {
+ lhsArg, def.getValueAsString("transformer")};
+ }
}
// Check that all of the result types can be inferred.
@@ -1066,7 +1123,8 @@ LogicalResult FormatParser::parse() {
// Check to see if we can infer this type from another variable.
auto varResolverIt = variableTyResolver.find(op.getResultName(i));
if (varResolverIt != variableTyResolver.end()) {
- fmt.resultTypes[i].setVariable(varResolverIt->second->name);
+ fmt.resultTypes[i].setVariable(varResolverIt->second.type,
+ varResolverIt->second.transformer);
continue;
}
@@ -1102,7 +1160,8 @@ LogicalResult FormatParser::parse() {
// Check to see if we can infer this type from another variable.
auto varResolverIt = variableTyResolver.find(op.getOperand(i).name);
if (varResolverIt != variableTyResolver.end()) {
- fmt.operandTypes[i].setVariable(varResolverIt->second->name);
+ fmt.operandTypes[i].setVariable(varResolverIt->second.type,
+ varResolverIt->second.transformer);
continue;
}
@@ -1121,30 +1180,23 @@ LogicalResult FormatParser::parse() {
void FormatParser::handleAllTypesMatchConstraint(
ArrayRef<StringRef> values,
- llvm::StringMap<const NamedTypeConstraint *> &variableTyResolver) {
+ llvm::StringMap<TypeResolutionInstance> &variableTyResolver) {
for (unsigned i = 0, e = values.size(); i != e; ++i) {
// Check to see if this value matches a resolved operand or result type.
- const NamedTypeConstraint *arg = nullptr;
- if ((arg = findArg(op.getOperands(), values[i]))) {
- if (!seenOperandTypes.test(arg - op.operand_begin()))
- continue;
- } else if ((arg = findArg(op.getResults(), values[i]))) {
- if (!seenResultTypes.test(arg - op.result_begin()))
- continue;
- } else {
+ const NamedTypeConstraint *arg = findSeenArg(values[i]);
+ if (!arg)
continue;
- }
// Mark this value as the type resolver for the other variables.
for (unsigned j = 0; j != i; ++j)
- variableTyResolver[values[j]] = arg;
+ variableTyResolver[values[j]] = {arg, llvm::None};
for (unsigned j = i + 1; j != e; ++j)
- variableTyResolver[values[j]] = arg;
+ variableTyResolver[values[j]] = {arg, llvm::None};
}
}
void FormatParser::handleSameTypesConstraint(
- llvm::StringMap<const NamedTypeConstraint *> &variableTyResolver,
+ llvm::StringMap<TypeResolutionInstance> &variableTyResolver,
bool includeResults) {
const NamedTypeConstraint *resolver = nullptr;
int resolvedIt = -1;
@@ -1160,14 +1212,22 @@ void FormatParser::handleSameTypesConstraint(
// Set the resolvers for each operand and result.
for (unsigned i = 0, e = op.getNumOperands(); i != e; ++i)
if (!seenOperandTypes.test(i) && !op.getOperand(i).name.empty())
- variableTyResolver[op.getOperand(i).name] = resolver;
+ variableTyResolver[op.getOperand(i).name] = {resolver, llvm::None};
if (includeResults) {
for (unsigned i = 0, e = op.getNumResults(); i != e; ++i)
if (!seenResultTypes.test(i) && !op.getResultName(i).empty())
- variableTyResolver[op.getResultName(i)] = resolver;
+ variableTyResolver[op.getResultName(i)] = {resolver, llvm::None};
}
}
+const NamedTypeConstraint *FormatParser::findSeenArg(StringRef name) {
+ if (auto *arg = findArg(op.getOperands(), name))
+ return seenOperandTypes.test(arg - op.operand_begin()) ? arg : nullptr;
+ if (auto *arg = findArg(op.getResults(), name))
+ return seenResultTypes.test(arg - op.result_begin()) ? arg : nullptr;
+ return nullptr;
+}
+
LogicalResult FormatParser::parseElement(std::unique_ptr<Element> &element,
bool isTopLevel) {
// Directives.
@@ -1191,7 +1251,8 @@ LogicalResult FormatParser::parseVariable(std::unique_ptr<Element> &element,
StringRef name = varTok.getSpelling().drop_front();
llvm::SMLoc loc = varTok.getLoc();
- // Check that the parsed argument is something actually registered on the op.
+ // Check that the parsed argument is something actually registered on the
+ // op.
/// Attributes
if (const NamedAttribute *attr = findArg(op.getAttributes(), name)) {
if (isTopLevel && !seenAttrs.insert(attr).second)
More information about the Mlir-commits
mailing list