[Mlir-commits] [mlir] 6b6c966 - [mlir][ODS] Add a new trait `TypesMatchWith`
River Riddle
llvmlistbot at llvm.org
Wed Feb 19 10:22:41 PST 2020
Author: River Riddle
Date: 2020-02-19T10:18:58-08:00
New Revision: 6b6c96695c0054ebad6816171ef89d5cb76a058b
URL: https://github.com/llvm/llvm-project/commit/6b6c96695c0054ebad6816171ef89d5cb76a058b
DIFF: https://github.com/llvm/llvm-project/commit/6b6c96695c0054ebad6816171ef89d5cb76a058b.diff
LOG: [mlir][ODS] Add a new trait `TypesMatchWith`
Summary:
This trait takes three arguments: lhs, rhs, transformer. It verifies that the type of 'rhs' matches the type of 'lhs' when the given 'transformer' is applied to 'lhs'. This allows for adding constraints like: "the type of 'a' must match the element type of 'b'". A followup revision will add support in the declarative parser for using these equality constraints to port more C++ parsers to the declarative form.
Differential Revision: https://reviews.llvm.org/D74647
Added:
Modified:
mlir/include/mlir/Dialect/StandardOps/Ops.td
mlir/include/mlir/Dialect/VectorOps/VectorOps.td
mlir/include/mlir/IR/OpBase.td
mlir/include/mlir/IR/OperationSupport.h
mlir/include/mlir/Support/STLExtras.h
mlir/lib/Dialect/StandardOps/Ops.cpp
mlir/test/Dialect/VectorOps/invalid.mlir
mlir/test/IR/invalid-ops.mlir
mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/StandardOps/Ops.td b/mlir/include/mlir/Dialect/StandardOps/Ops.td
index 73a88681da3c..b6186bd4ec76 100644
--- a/mlir/include/mlir/Dialect/StandardOps/Ops.td
+++ b/mlir/include/mlir/Dialect/StandardOps/Ops.td
@@ -310,7 +310,15 @@ def CallOp : Std_Op<"call", [CallOpInterface]> {
}];
}
-def CallIndirectOp : Std_Op<"call_indirect", [CallOpInterface]> {
+def CallIndirectOp : Std_Op<"call_indirect", [
+ CallOpInterface,
+ TypesMatchWith<"callee input types match argument types",
+ "callee", "operands",
+ "$_self.cast<FunctionType>().getInputs()">,
+ TypesMatchWith<"callee result types match result types",
+ "callee", "results",
+ "$_self.cast<FunctionType>().getResults()">
+ ]> {
let summary = "indirect call operation";
let description = [{
The "call_indirect" operation represents an indirect call to a value of
@@ -322,7 +330,7 @@ def CallIndirectOp : Std_Op<"call_indirect", [CallOpInterface]> {
}];
let arguments = (ins FunctionType:$callee, Variadic<AnyType>:$operands);
- let results = (outs Variadic<AnyType>);
+ let results = (outs Variadic<AnyType>:$results);
let builders = [OpBuilder<
"Builder *, OperationState &result, Value callee,"
@@ -347,6 +355,7 @@ def CallIndirectOp : Std_Op<"call_indirect", [CallOpInterface]> {
CallInterfaceCallable getCallableForCallee() { return getCallee(); }
}];
+ let verifier = ?;
let hasCanonicalizer = 1;
}
@@ -361,7 +370,10 @@ def CeilFOp : FloatUnaryOp<"ceilf"> {
}
def CmpFOp : Std_Op<"cmpf",
- [NoSideEffect, SameTypeOperands, SameOperandsAndResultShape]> {
+ [NoSideEffect, SameTypeOperands, SameOperandsAndResultShape,
+ TypesMatchWith<
+ "result type has i1 element type and same shape as operands",
+ "lhs", "result", "getI1SameShape($_self)">]> {
let summary = "floating-point comparison operation";
let description = [{
The "cmpf" operation compares its two operands according to the float
@@ -386,7 +398,7 @@ def CmpFOp : Std_Op<"cmpf",
}];
let arguments = (ins FloatLike:$lhs, FloatLike:$rhs);
- let results = (outs BoolLike);
+ let results = (outs BoolLike:$result);
let builders = [OpBuilder<
"Builder *builder, OperationState &result, CmpFPredicate predicate,"
@@ -426,7 +438,10 @@ def CmpIPredicateAttr : I64EnumAttr<
}
def CmpIOp : Std_Op<"cmpi",
- [NoSideEffect, SameTypeOperands, SameOperandsAndResultShape]> {
+ [NoSideEffect, SameTypeOperands, SameOperandsAndResultShape,
+ TypesMatchWith<
+ "result type has i1 element type and same shape as operands",
+ "lhs", "result", "getI1SameShape($_self)">]> {
let summary = "integer comparison operation";
let description = [{
The "cmpi" operation compares its two operands according to the integer
@@ -454,7 +469,7 @@ def CmpIOp : Std_Op<"cmpi",
IntegerLike:$lhs,
IntegerLike:$rhs
);
- let results = (outs BoolLike);
+ let results = (outs BoolLike:$result);
let builders = [OpBuilder<
"Builder *builder, OperationState &result, CmpIPredicate predicate,"
@@ -708,7 +723,11 @@ def ExpOp : FloatUnaryOp<"exp"> {
let summary = "base-e exponential of the specified value";
}
-def ExtractElementOp : Std_Op<"extract_element", [NoSideEffect]> {
+def ExtractElementOp : Std_Op<"extract_element",
+ [NoSideEffect,
+ TypesMatchWith<"result type matches element type of aggregate",
+ "aggregate", "result",
+ "$_self.cast<ShapedType>().getElementType()">]> {
let summary = "element extract operation";
let description = [{
The "extract_element" op reads a tensor or vector and returns one element
@@ -723,7 +742,7 @@ def ExtractElementOp : Std_Op<"extract_element", [NoSideEffect]> {
let arguments = (ins AnyTypeOf<[AnyVector, AnyTensor]>:$aggregate,
Variadic<Index>:$indices);
- let results = (outs AnyType);
+ let results = (outs AnyType:$result);
let builders = [OpBuilder<
"Builder *builder, OperationState &result, Value aggregate,"
@@ -796,7 +815,10 @@ def FPTruncOp : CastOp<"fptrunc">, Arguments<(ins AnyType:$in)> {
let hasFolder = 0;
}
-def LoadOp : Std_Op<"load"> {
+def LoadOp : Std_Op<"load",
+ [TypesMatchWith<"result type matches element type of 'memref'",
+ "memref", "result",
+ "$_self.cast<MemRefType>().getElementType()">]> {
let summary = "load operation";
let description = [{
The "load" op reads an element from a memref specified by an index list. The
@@ -809,7 +831,7 @@ def LoadOp : Std_Op<"load"> {
}];
let arguments = (ins AnyMemRef:$memref, Variadic<Index>:$indices);
- let results = (outs AnyType);
+ let results = (outs AnyType:$result);
let builders = [OpBuilder<
"Builder *, OperationState &result, Value memref,"
@@ -1029,7 +1051,11 @@ def ReturnOp : Std_Op<"return", [Terminator, HasParent<"FuncOp">]> {
>];
}
-def SelectOp : Std_Op<"select", [NoSideEffect, SameOperandsAndResultShape]> {
+def SelectOp : Std_Op<"select", [NoSideEffect, SameOperandsAndResultShape,
+ AllTypesMatch<["true_value", "false_value", "result"]>,
+ TypesMatchWith<"condition type matches i1 equivalent of result type",
+ "result", "condition",
+ "getI1SameShape($_self)">]> {
let summary = "select operation";
let description = [{
The "select" operation chooses one value based on a binary condition
@@ -1044,9 +1070,10 @@ def SelectOp : Std_Op<"select", [NoSideEffect, SameOperandsAndResultShape]> {
%3 = select %2, %0, %1 : i32
}];
- let arguments = (ins BoolLike:$condition, AnyType:$true_value,
- AnyType:$false_value);
- let results = (outs AnyType);
+ let arguments = (ins BoolLike:$condition, IntegerOrFloatLike:$true_value,
+ IntegerOrFloatLike:$false_value);
+ let results = (outs IntegerOrFloatLike:$result);
+ let verifier = ?;
let builders = [OpBuilder<
"Builder *builder, OperationState &result, Value condition,"
@@ -1158,7 +1185,10 @@ def SIToFPOp : CastOp<"sitofp">, Arguments<(ins AnyType:$in)> {
let hasFolder = 0;
}
-def SplatOp : Std_Op<"splat", [NoSideEffect]> {
+def SplatOp : Std_Op<"splat", [NoSideEffect,
+ TypesMatchWith<"operand type matches element type of result",
+ "aggregate", "input",
+ "$_self.cast<ShapedType>().getElementType()">]> {
let summary = "splat or broadcast operation";
let description = [{
The "splat" op reads a value of integer or float type and broadcasts it into
@@ -1193,7 +1223,10 @@ def SplatOp : Std_Op<"splat", [NoSideEffect]> {
let hasFolder = 1;
}
-def StoreOp : Std_Op<"store"> {
+def StoreOp : Std_Op<"store",
+ [TypesMatchWith<"type of 'value' matches element type of 'memref'",
+ "memref", "value",
+ "$_self.cast<MemRefType>().getElementType()">]> {
let summary = "store operation";
let description = [{
The "store" op writes an element to a memref specified by an index list.
@@ -1455,7 +1488,10 @@ def TensorCastOp : CastOp<"tensor_cast"> {
}
def TensorLoadOp : Std_Op<"tensor_load",
- [SameOperandsAndResultShape, SameOperandsAndResultElementType]> {
+ [SameOperandsAndResultShape, SameOperandsAndResultElementType,
+ TypesMatchWith<"result type matches tensor equivalent of 'memref'",
+ "memref", "result",
+ "getTensorTypeFromMemRefType($_self)">]> {
let summary = "tensor load operation";
let description = [{
The "tensor_load" operation creates a tensor from a memref, making an
@@ -1466,8 +1502,8 @@ def TensorLoadOp : Std_Op<"tensor_load",
%12 = tensor_load %10 : memref<4x?xf32, #layout, memspace0>
}];
- let arguments = (ins AnyMemRef);
- let results = (outs AnyTensor);
+ let arguments = (ins AnyMemRef:$memref);
+ let results = (outs AnyTensor:$result);
// TensorLoadOp is fully verified by traits.
let verifier = ?;
@@ -1488,7 +1524,10 @@ def TensorLoadOp : Std_Op<"tensor_load",
}
def TensorStoreOp : Std_Op<"tensor_store",
- [SameOperandsShape, SameOperandsElementType]> {
+ [SameOperandsShape, SameOperandsElementType,
+ TypesMatchWith<"type of 'value' matches tensor equivalent of 'memref'",
+ "memref", "tensor",
+ "getTensorTypeFromMemRefType($_self)">]> {
let summary = "tensor store operation";
let description = [{
The "tensor_store" operation stores the contents of a tensor into a memref.
diff --git a/mlir/include/mlir/Dialect/VectorOps/VectorOps.td b/mlir/include/mlir/Dialect/VectorOps/VectorOps.td
index 5ead87681ad0..10837b94bfc2 100644
--- a/mlir/include/mlir/Dialect/VectorOps/VectorOps.td
+++ b/mlir/include/mlir/Dialect/VectorOps/VectorOps.td
@@ -339,10 +339,11 @@ def Vector_ShuffleOp :
def Vector_ExtractElementOp :
Vector_Op<"extractelement", [NoSideEffect,
- PredOpTrait<"operand and result have same element type",
- TCresVTEtIsSameAsOpBase<0, 0>>]>,
+ TypesMatchWith<"result type matches element type of vector operand",
+ "vector", "result",
+ "$_self.cast<ShapedType>().getElementType()">]>,
Arguments<(ins AnyVector:$vector, AnyInteger:$position)>,
- Results<(outs AnyType)> {
+ Results<(outs AnyType:$result)> {
let summary = "extractelement operation";
let description = [{
Takes an 1-D vector and a dynamic index position and extracts the
@@ -482,12 +483,12 @@ def Vector_FMAOp :
def Vector_InsertElementOp :
Vector_Op<"insertelement", [NoSideEffect,
- PredOpTrait<"source operand and result have same element type",
- TCresVTEtIsSameAsOpBase<0, 0>>,
- PredOpTrait<"dest operand and result have same type",
- TCresIsSameAsOpBase<0, 1>>]>,
+ TypesMatchWith<"source operand type matches element type of result",
+ "result", "source",
+ "$_self.cast<ShapedType>().getElementType()">,
+ AllTypesMatch<["dest", "result"]>]>,
Arguments<(ins AnyType:$source, AnyVector:$dest, AnyInteger:$position)>,
- Results<(outs AnyVector)> {
+ Results<(outs AnyVector:$result)> {
let summary = "insertelement operation";
let description = [{
Takes a scalar source, an 1-D destination vector and a dynamic index
diff --git a/mlir/include/mlir/IR/OpBase.td b/mlir/include/mlir/IR/OpBase.td
index 88240aed1949..2f02a885242b 100644
--- a/mlir/include/mlir/IR/OpBase.td
+++ b/mlir/include/mlir/IR/OpBase.td
@@ -595,6 +595,11 @@ def FloatLike : TypeConstraint<Or<[AnyFloat.predicate,
VectorOf<[AnyFloat]>.predicate, TensorOf<[AnyFloat]>.predicate]>,
"floating-point-like">;
+// Type constraint for integer-like or float-like types.
+def IntegerOrFloatLike : TypeConstraint<Or<[IntegerLike.predicate,
+ FloatLike.predicate]>,
+ "integer-like or floating-point-like">;
+
//===----------------------------------------------------------------------===//
// Attribute definitions
@@ -1716,6 +1721,17 @@ class AllShapesMatch<list<string> names> :
class AllTypesMatch<list<string> names> :
AllMatchSameOperatorTrait<names, "$_self.getType()", "type">;
+// A type constraint that denotes `transform(lhs.getType()) == rhs.getType()`.
+class TypesMatchWith<string description, string lhsArg, string rhsArg,
+ string transform> :
+ PredOpTrait<description, CPred<
+ !subst("$_self", "$" # lhsArg # ".getType()", transform)
+ # " == $" # rhsArg # ".getType()">> {
+ string lhs = lhsArg;
+ string rhs = rhsArg;
+ string transformer = transform;
+}
+
// Type Constraint operand `idx`'s Element type is `type`.
class TCopVTEtIs<int idx, Type type> : And<[
CPred<"$_op.getNumOperands() > " # idx>,
diff --git a/mlir/include/mlir/IR/OperationSupport.h b/mlir/include/mlir/IR/OperationSupport.h
index a9a6ff46242c..a7999607d28c 100644
--- a/mlir/include/mlir/IR/OperationSupport.h
+++ b/mlir/include/mlir/IR/OperationSupport.h
@@ -610,6 +610,12 @@ class ValueTypeRange final
ValueTypeRange(Container &&c) : ValueTypeRange(c.begin(), c.end()) {}
};
+template <typename RangeT>
+inline bool operator==(ArrayRef<Type> lhs, const ValueTypeRange<RangeT> &rhs) {
+ return lhs.size() == llvm::size(rhs) &&
+ std::equal(lhs.begin(), lhs.end(), rhs.begin());
+}
+
//===----------------------------------------------------------------------===//
// OperandRange
@@ -625,6 +631,7 @@ class OperandRange final
using type_iterator = ValueTypeIterator<iterator>;
using type_range = ValueTypeRange<OperandRange>;
type_range getTypes() const { return {begin(), end()}; }
+ auto getType() const { return getTypes(); }
private:
/// See `detail::indexed_accessor_range_base` for details.
@@ -656,6 +663,7 @@ class ResultRange final
using type_iterator = ArrayRef<Type>::iterator;
using type_range = ArrayRef<Type>;
type_range getTypes() const;
+ auto getType() const { return getTypes(); }
private:
/// See `indexed_accessor_range` for details.
@@ -725,6 +733,7 @@ class ValueRange final
using type_iterator = ValueTypeIterator<iterator>;
using type_range = ValueTypeRange<ValueRange>;
type_range getTypes() const { return {begin(), end()}; }
+ auto getType() const { return getTypes(); }
private:
using OwnerT = detail::ValueRangeOwner;
diff --git a/mlir/include/mlir/Support/STLExtras.h b/mlir/include/mlir/Support/STLExtras.h
index 14336aad6a25..527234921d95 100644
--- a/mlir/include/mlir/Support/STLExtras.h
+++ b/mlir/include/mlir/Support/STLExtras.h
@@ -233,6 +233,12 @@ class indexed_accessor_range_base {
return DerivedT::dereference_iterator(base, index);
}
+ /// Compare this range with another.
+ template <typename OtherT> bool operator==(const OtherT &other) {
+ return size() == llvm::size(other) &&
+ std::equal(begin(), end(), other.begin());
+ }
+
/// Return the size of this range.
size_t size() const { return count; }
diff --git a/mlir/lib/Dialect/StandardOps/Ops.cpp b/mlir/lib/Dialect/StandardOps/Ops.cpp
index 9a0bcb9d58a0..7318861938a4 100644
--- a/mlir/lib/Dialect/StandardOps/Ops.cpp
+++ b/mlir/lib/Dialect/StandardOps/Ops.cpp
@@ -528,30 +528,6 @@ static void print(OpAsmPrinter &p, CallIndirectOp op) {
p << " : " << op.getCallee().getType();
}
-static LogicalResult verify(CallIndirectOp op) {
- // The callee must be a function.
- auto fnType = op.getCallee().getType().dyn_cast<FunctionType>();
- if (!fnType)
- return op.emitOpError("callee must have function type");
-
- // Verify that the operand and result types match the callee.
- if (fnType.getNumInputs() != op.getNumOperands() - 1)
- return op.emitOpError("incorrect number of operands for callee");
-
- for (unsigned i = 0, e = fnType.getNumInputs(); i != e; ++i)
- if (op.getOperand(i + 1).getType() != fnType.getInput(i))
- return op.emitOpError("operand type mismatch");
-
- if (fnType.getNumResults() != op.getNumResults())
- return op.emitOpError("incorrect number of results for callee");
-
- for (unsigned i = 0, e = fnType.getNumResults(); i != e; ++i)
- if (op.getResult(i).getType() != fnType.getResult(i))
- return op.emitOpError("result type mismatch");
-
- return success();
-}
-
void CallIndirectOp::getCanonicalizationPatterns(
OwningRewritePatternList &results, MLIRContext *context) {
results.insert<SimplifyIndirectCallWithKnownCallee>(context);
@@ -562,8 +538,8 @@ void CallIndirectOp::getCanonicalizationPatterns(
//===----------------------------------------------------------------------===//
// Return the type of the same shape (scalar, vector or tensor) containing i1.
-static Type getCheckedI1SameShape(Builder *build, Type type) {
- auto i1Type = build->getI1Type();
+static Type getCheckedI1SameShape(Type type) {
+ auto i1Type = IntegerType::get(1, type.getContext());
if (type.isIntOrIndexOrFloat())
return i1Type;
if (auto tensorType = type.dyn_cast<RankedTensorType>())
@@ -575,8 +551,8 @@ static Type getCheckedI1SameShape(Builder *build, Type type) {
return Type();
}
-static Type getI1SameShape(Builder *build, Type type) {
- Type res = getCheckedI1SameShape(build, type);
+static Type getI1SameShape(Type type) {
+ Type res = getCheckedI1SameShape(type);
assert(res && "expected type with valid i1 shape");
return res;
}
@@ -588,7 +564,7 @@ static Type getI1SameShape(Builder *build, Type type) {
static void buildCmpIOp(Builder *build, OperationState &result,
CmpIPredicate predicate, Value lhs, Value rhs) {
result.addOperands({lhs, rhs});
- result.types.push_back(getI1SameShape(build, lhs.getType()));
+ result.types.push_back(getI1SameShape(lhs.getType()));
result.addAttribute(
CmpIOp::getPredicateAttrName(),
build->getI64IntegerAttr(static_cast<int64_t>(predicate)));
@@ -618,7 +594,7 @@ static ParseResult parseCmpIOp(OpAsmParser &parser, OperationState &result) {
<< "unknown comparison predicate \"" << predicateName << "\"";
auto builder = parser.getBuilder();
- Type i1Type = getCheckedI1SameShape(&builder, type);
+ Type i1Type = getCheckedI1SameShape(type);
if (!i1Type)
return parser.emitError(parser.getNameLoc(),
"expected type with valid i1 shape");
@@ -741,7 +717,7 @@ CmpFPredicate CmpFOp::getPredicateByName(StringRef name) {
static void buildCmpFOp(Builder *build, OperationState &result,
CmpFPredicate predicate, Value lhs, Value rhs) {
result.addOperands({lhs, rhs});
- result.types.push_back(getI1SameShape(build, lhs.getType()));
+ result.types.push_back(getI1SameShape(lhs.getType()));
result.addAttribute(
CmpFOp::getPredicateAttrName(),
build->getI64IntegerAttr(static_cast<int64_t>(predicate)));
@@ -772,7 +748,7 @@ static ParseResult parseCmpFOp(OpAsmParser &parser, OperationState &result) {
"\"");
auto builder = parser.getBuilder();
- Type i1Type = getCheckedI1SameShape(&builder, type);
+ Type i1Type = getCheckedI1SameShape(type);
if (!i1Type)
return parser.emitError(parser.getNameLoc(),
"expected type with valid i1 shape");
@@ -1534,13 +1510,8 @@ static ParseResult parseExtractElementOp(OpAsmParser &parser,
}
static LogicalResult verify(ExtractElementOp op) {
- auto aggregateType = op.getAggregate().getType().cast<ShapedType>();
-
- // This should be possible with tablegen type constraints
- if (op.getType() != aggregateType.getElementType())
- return op.emitOpError("result type must match element type of aggregate");
-
// Verify the # indices match if we have a ranked type.
+ auto aggregateType = op.getAggregate().getType().cast<ShapedType>();
if (aggregateType.hasRank() &&
aggregateType.getRank() != op.getNumOperands() - 1)
return op.emitOpError("incorrect number of indices for extract_element");
@@ -1628,12 +1599,8 @@ static ParseResult parseLoadOp(OpAsmParser &parser, OperationState &result) {
}
static LogicalResult verify(LoadOp op) {
- if (op.getType() != op.getMemRefType().getElementType())
- return op.emitOpError("result type must match element type of memref");
-
if (op.getNumOperands() != 1 + op.getMemRefType().getRank())
return op.emitOpError("incorrect number of indices for load");
-
return success();
}
@@ -1943,7 +1910,7 @@ static ParseResult parseSelectOp(OpAsmParser &parser, OperationState &result) {
parser.parseColonType(type))
return failure();
- auto i1Type = getCheckedI1SameShape(&parser.getBuilder(), type);
+ auto i1Type = getCheckedI1SameShape(type);
if (!i1Type)
return parser.emitError(parser.getNameLoc(),
"expected type with valid i1 shape");
@@ -1959,17 +1926,6 @@ static void print(OpAsmPrinter &p, SelectOp op) {
p.printOptionalAttrDict(op.getAttrs());
}
-static LogicalResult verify(SelectOp op) {
- auto trueType = op.getTrueValue().getType();
- auto falseType = op.getFalseValue().getType();
-
- if (trueType != falseType)
- return op.emitOpError(
- "requires 'true' and 'false' arguments to be of the same type");
-
- return success();
-}
-
OpFoldResult SelectOp::fold(ArrayRef<Attribute> operands) {
auto condition = getCondition();
@@ -2087,11 +2043,6 @@ static ParseResult parseStoreOp(OpAsmParser &parser, OperationState &result) {
}
static LogicalResult verify(StoreOp op) {
- // First operand must have same type as memref element type.
- if (op.getValueToStore().getType() != op.getMemRefType().getElementType())
- return op.emitOpError(
- "first operand must have same type memref element type");
-
if (op.getNumOperands() != 2 + op.getMemRefType().getRank())
return op.emitOpError("store index operand count not equal to memref rank");
@@ -2198,10 +2149,10 @@ OpFoldResult TensorCastOp::fold(ArrayRef<Attribute> operands) {
// Helpers for Tensor[Load|Store]Op
//===----------------------------------------------------------------------===//
-static Type getTensorTypeFromMemRefType(Builder &b, Type type) {
+static Type getTensorTypeFromMemRefType(Type type) {
if (auto memref = type.dyn_cast<MemRefType>())
return RankedTensorType::get(memref.getShape(), memref.getElementType());
- return b.getNoneType();
+ return NoneType::get(type.getContext());
}
//===----------------------------------------------------------------------===//
@@ -2218,13 +2169,12 @@ static ParseResult parseTensorLoadOp(OpAsmParser &parser,
OperationState &result) {
OpAsmParser::OperandType op;
Type type;
- return failure(parser.parseOperand(op) ||
- parser.parseOptionalAttrDict(result.attributes) ||
- parser.parseColonType(type) ||
- parser.resolveOperand(op, type, result.operands) ||
- parser.addTypeToList(
- getTensorTypeFromMemRefType(parser.getBuilder(), type),
- result.types));
+ return failure(
+ parser.parseOperand(op) ||
+ parser.parseOptionalAttrDict(result.attributes) ||
+ parser.parseColonType(type) ||
+ parser.resolveOperand(op, type, result.operands) ||
+ parser.addTypeToList(getTensorTypeFromMemRefType(type), result.types));
}
//===----------------------------------------------------------------------===//
@@ -2246,9 +2196,8 @@ static ParseResult parseTensorStoreOp(OpAsmParser &parser,
parser.parseOperandList(ops, /*requiredOperandCount=*/2) ||
parser.parseOptionalAttrDict(result.attributes) ||
parser.parseColonType(type) ||
- parser.resolveOperands(
- ops, {getTensorTypeFromMemRefType(parser.getBuilder(), type), type},
- loc, result.operands));
+ parser.resolveOperands(ops, {getTensorTypeFromMemRefType(type), type},
+ loc, result.operands));
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/VectorOps/invalid.mlir b/mlir/test/Dialect/VectorOps/invalid.mlir
index 2a45820be7b0..3743a5cfcf3c 100644
--- a/mlir/test/Dialect/VectorOps/invalid.mlir
+++ b/mlir/test/Dialect/VectorOps/invalid.mlir
@@ -131,9 +131,9 @@ func @insert_element(%arg0: f32, %arg1: vector<4x4xf32>) {
// -----
func @insert_element_wrong_type(%arg0: i32, %arg1: vector<4xf32>) {
- %c = constant 3 : index
- // expected-error @+1 {{'vector.insertelement' op failed to verify that source operand and result have same element type}}
- %0 = "vector.insertelement" (%arg0, %arg1, %c) : (i32, vector<4xf32>, index) -> (vector<4xf32>)
+ %c = constant 3 : i32
+ // expected-error @+1 {{'vector.insertelement' op failed to verify that source operand type matches element type of result}}
+ %0 = "vector.insertelement" (%arg0, %arg1, %c) : (i32, vector<4xf32>, i32) -> (vector<4xf32>)
}
// -----
diff --git a/mlir/test/IR/invalid-ops.mlir b/mlir/test/IR/invalid-ops.mlir
index 9f48d9d6bc70..cebc87ea94a1 100644
--- a/mlir/test/IR/invalid-ops.mlir
+++ b/mlir/test/IR/invalid-ops.mlir
@@ -275,7 +275,7 @@ func @func_with_ops(i32, i32, i32) {
func @func_with_ops(i1, i32, i64) {
^bb0(%cond : i1, %t : i32, %f : i64):
- // expected-error @+1 {{'true' and 'false' arguments to be of the same type}}
+ // expected-error @+1 {{all of {true_value, false_value, result} have same type}}
%r = "std.select"(%cond, %t, %f) : (i1, i32, i64) -> i32
}
@@ -460,7 +460,7 @@ func @extract_element_invalid_index_type(%v : vector<3xf32>, %i : i32) {
// -----
func @extract_element_element_result_type_mismatch(%v : vector<3xf32>, %i : index) {
- // expected-error @+1 {{result type must match element type of aggregate}}
+ // expected-error @+1 {{result type matches element type of aggregate}}
%0 = "std.extract_element"(%v, %i) : (vector<3xf32>, index) -> f64
return
}
diff --git a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
index 6af100ed95fa..a5962957be77 100644
--- a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
+++ b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
@@ -1127,29 +1127,55 @@ void OpEmitter::genVerifier() {
auto &method = opClass.newMethod("LogicalResult", "verify", /*params=*/"");
auto &body = method.body();
+ const char *checkAttrSizedValueSegmentsCode = R"(
+ auto sizeAttr = getAttrOfType<DenseIntElementsAttr>("{0}");
+ auto numElements = sizeAttr.getType().cast<ShapedType>().getNumElements();
+ if (numElements != {1}) {{
+ return emitOpError("'{0}' attribute for specifying {2} segments "
+ "must have {1} elements");
+ }
+ )";
+
+ // Verify a few traits first so that we can use
+ // getODSOperands()/getODSResults() in the rest of the verifier.
+ for (auto &trait : op.getTraits()) {
+ if (auto *t = dyn_cast<tblgen::NativeOpTrait>(&trait)) {
+ if (t->getTrait() == "OpTrait::AttrSizedOperandSegments") {
+ body << formatv(checkAttrSizedValueSegmentsCode,
+ "operand_segment_sizes", op.getNumOperands(),
+ "operand");
+ } else if (t->getTrait() == "OpTrait::AttrSizedResultSegments") {
+ body << formatv(checkAttrSizedValueSegmentsCode, "result_segment_sizes",
+ op.getNumResults(), "result");
+ }
+ }
+ }
+
// Populate substitutions for attributes and named operands and results.
for (const auto &namedAttr : op.getAttributes())
verifyCtx.addSubst(namedAttr.name,
formatv("this->getAttr(\"{0}\")", namedAttr.name));
for (int i = 0, e = op.getNumOperands(); i < e; ++i) {
auto &value = op.getOperand(i);
- // Skip from from first variadic operands for now. Else getOperand index
- // used below doesn't match.
+ if (value.name.empty())
+ continue;
+
if (value.isVariadic())
- break;
- if (!value.name.empty())
+ verifyCtx.addSubst(value.name, formatv("this->getODSOperands({0})", i));
+ else
verifyCtx.addSubst(value.name,
- formatv("this->getOperation()->getOperand({0})", i));
+ formatv("(*this->getODSOperands({0}).begin())", i));
}
for (int i = 0, e = op.getNumResults(); i < e; ++i) {
auto &value = op.getResult(i);
- // Skip from from first variadic results for now. Else getResult index used
- // below doesn't match.
+ if (value.name.empty())
+ continue;
+
if (value.isVariadic())
- break;
- if (!value.name.empty())
+ verifyCtx.addSubst(value.name, formatv("this->getODSResults({0})", i));
+ else
verifyCtx.addSubst(value.name,
- formatv("this->getOperation()->getResult({0})", i));
+ formatv("(*this->getODSResults({0}).begin())", i));
}
// Verify the attributes have the correct type.
@@ -1189,14 +1215,8 @@ void OpEmitter::genVerifier() {
body << " }\n";
}
- const char *code = R"(
- auto sizeAttr = getAttrOfType<DenseIntElementsAttr>("{0}");
- auto numElements = sizeAttr.getType().cast<ShapedType>().getNumElements();
- if (numElements != {1}) {{
- return emitOpError("'{0}' attribute for specifying {2} segments "
- "must have {1} elements");
- }
- )";
+ genOperandResultVerifier(body, op.getOperands(), "operand");
+ genOperandResultVerifier(body, op.getResults(), "result");
for (auto &trait : op.getTraits()) {
if (auto *t = dyn_cast<tblgen::PredOpTrait>(&trait)) {
@@ -1204,23 +1224,9 @@ void OpEmitter::genVerifier() {
"return emitOpError(\"failed to verify that $1\");\n }\n",
&verifyCtx, tgfmt(t->getPredTemplate(), &verifyCtx),
t->getDescription());
- } else if (auto *t = dyn_cast<tblgen::NativeOpTrait>(&trait)) {
- if (t->getTrait() == "OpTrait::AttrSizedOperandSegments") {
- body << formatv(code, "operand_segment_sizes", op.getNumOperands(),
- "operand");
- } else if (t->getTrait() == "OpTrait::AttrSizedResultSegments") {
- body << formatv(code, "result_segment_sizes", op.getNumResults(),
- "result");
- }
}
}
- // These should happen after we verified the traits because
- // getODSOperands()/getODSResults() may depend on traits (e.g.,
- // AttrSizedOperandSegments/AttrSizedResultSegments).
- genOperandResultVerifier(body, op.getOperands(), "operand");
- genOperandResultVerifier(body, op.getResults(), "result");
-
genRegionVerifier(body);
if (hasCustomVerify) {
More information about the Mlir-commits
mailing list