[Mlir-commits] [mlir] bdc7ce9 - [mlir][NFC] Update Vector operations to use `hasVerifier` instead of `verifier`
River Riddle
llvmlistbot at llvm.org
Wed Feb 2 13:35:43 PST 2022
Author: River Riddle
Date: 2022-02-02T13:34:30-08:00
New Revision: bdc7ce975a8dae66a010320a11b4eb75b4c6c895
URL: https://github.com/llvm/llvm-project/commit/bdc7ce975a8dae66a010320a11b4eb75b4c6c895
DIFF: https://github.com/llvm/llvm-project/commit/bdc7ce975a8dae66a010320a11b4eb75b4c6c895.diff
LOG: [mlir][NFC] Update Vector operations to use `hasVerifier` instead of `verifier`
The `verifier` field is deprecated and slated for removal.
Differential Revision: https://reviews.llvm.org/D118820
Added:
Modified:
mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
mlir/lib/Dialect/Vector/IR/VectorOps.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
index ce3fe69c613b5..1a48f7d7d9bab 100644
--- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
@@ -29,12 +29,10 @@ class Vector_Op<string mnemonic, list<Trait> traits = []> :
Op<Vector_Dialect, mnemonic, traits> {
// For every vector op, there needs to be a:
// * void print(OpAsmPrinter &p, ${C++ class of Op} op)
- // * LogicalResult verify(${C++ class of Op} op)
// * ParseResult parse${C++ class of Op}(OpAsmParser &parser,
// OperationState &result)
// functions.
let printer = [{ return ::print(p, *this); }];
- let verifier = [{ return ::verify(*this); }];
let parser = [{ return ::parse$cppClass(parser, result); }];
}
@@ -255,6 +253,7 @@ def Vector_ContractionOp :
}];
let hasCanonicalizer = 1;
+ let hasVerifier = 1;
}
def Vector_ReductionOp :
@@ -290,6 +289,7 @@ def Vector_ReductionOp :
return vector().getType().cast<VectorType>();
}
}];
+ let hasVerifier = 1;
}
def Vector_MultiDimReductionOp :
@@ -373,6 +373,7 @@ def Vector_MultiDimReductionOp :
let assemblyFormat =
"$kind `,` $source attr-dict $reduction_dims `:` type($source) `to` type($dest)";
let hasFolder = 1;
+ let hasVerifier = 1;
}
def Vector_BroadcastOp :
@@ -420,6 +421,7 @@ def Vector_BroadcastOp :
let assemblyFormat = "$source attr-dict `:` type($source) `to` type($vector)";
let hasFolder = 1;
let hasCanonicalizer = 1;
+ let hasVerifier = 1;
}
def Vector_ShuffleOp :
@@ -475,6 +477,7 @@ def Vector_ShuffleOp :
return vector().getType().cast<VectorType>();
}
}];
+ let hasVerifier = 1;
}
def Vector_ExtractElementOp :
@@ -521,6 +524,7 @@ def Vector_ExtractElementOp :
return vector().getType().cast<VectorType>();
}
}];
+ let hasVerifier = 1;
}
def Vector_ExtractOp :
@@ -555,6 +559,7 @@ def Vector_ExtractOp :
}];
let hasCanonicalizer = 1;
let hasFolder = 1;
+ let hasVerifier = 1;
}
def Vector_ExtractMapOp :
@@ -623,6 +628,7 @@ def Vector_ExtractMapOp :
}];
let hasFolder = 1;
+ let hasVerifier = 1;
}
def Vector_FMAOp :
@@ -648,8 +654,6 @@ def Vector_FMAOp :
%3 = vector.fma %0, %1, %2: vector<8x16xf32>
```
}];
- // Fully specified by traits.
- let verifier = ?;
let assemblyFormat = "$lhs `,` $rhs `,` $acc attr-dict `:` type($lhs)";
let builders = [
OpBuilder<(ins "Value":$lhs, "Value":$rhs, "Value":$acc),
@@ -706,7 +710,7 @@ def Vector_InsertElementOp :
return dest().getType().cast<VectorType>();
}
}];
-
+ let hasVerifier = 1;
}
def Vector_InsertOp :
@@ -749,6 +753,7 @@ def Vector_InsertOp :
let hasCanonicalizer = 1;
let hasFolder = 1;
+ let hasVerifier = 1;
}
def Vector_InsertMapOp :
@@ -816,6 +821,7 @@ def Vector_InsertMapOp :
$vector `,` $dest `[` $ids `]` attr-dict
`:` type($vector) `into` type($result)
}];
+ let hasVerifier = 1;
}
def Vector_InsertStridedSliceOp :
@@ -873,6 +879,7 @@ def Vector_InsertStridedSliceOp :
}];
let hasFolder = 1;
+ let hasVerifier = 1;
}
def Vector_OuterProductOp :
@@ -960,6 +967,7 @@ def Vector_OuterProductOp :
return CombiningKind::ADD;
}
}];
+ let hasVerifier = 1;
}
// TODO: Add transformation which decomposes ReshapeOp into an optimized
@@ -1081,6 +1089,7 @@ def Vector_ReshapeOp :
$vector `,` `[` $input_shape `]` `,` `[` $output_shape `]` `,`
$fixed_vector_sizes attr-dict `:` type($vector) `to` type($result)
}];
+ let hasVerifier = 1;
}
def Vector_ExtractStridedSliceOp :
@@ -1133,6 +1142,7 @@ def Vector_ExtractStridedSliceOp :
}];
let hasCanonicalizer = 1;
let hasFolder = 1;
+ let hasVerifier = 1;
let assemblyFormat = "$vector attr-dict `:` type($vector) `to` type(results)";
}
@@ -1340,6 +1350,7 @@ def Vector_TransferReadOp :
];
let hasCanonicalizer = 1;
let hasFolder = 1;
+ let hasVerifier = 1;
}
def Vector_TransferWriteOp :
@@ -1477,6 +1488,7 @@ def Vector_TransferWriteOp :
];
let hasFolder = 1;
let hasCanonicalizer = 1;
+ let hasVerifier = 1;
}
def Vector_LoadOp : Vector_Op<"load"> {
@@ -1552,6 +1564,7 @@ def Vector_LoadOp : Vector_Op<"load"> {
}];
let hasFolder = 1;
+ let hasVerifier = 1;
let assemblyFormat =
"$base `[` $indices `]` attr-dict `:` type($base) `,` type($result)";
@@ -1628,6 +1641,7 @@ def Vector_StoreOp : Vector_Op<"store"> {
}];
let hasFolder = 1;
+ let hasVerifier = 1;
let assemblyFormat = "$valueToStore `,` $base `[` $indices `]` attr-dict "
"`:` type($base) `,` type($valueToStore)";
@@ -1687,6 +1701,7 @@ def Vector_MaskedLoadOp :
"type($base) `,` type($mask) `,` type($pass_thru) `into` type($result)";
let hasCanonicalizer = 1;
let hasFolder = 1;
+ let hasVerifier = 1;
}
def Vector_MaskedStoreOp :
@@ -1740,6 +1755,7 @@ def Vector_MaskedStoreOp :
"attr-dict `:` type($base) `,` type($mask) `,` type($valueToStore)";
let hasCanonicalizer = 1;
let hasFolder = 1;
+ let hasVerifier = 1;
}
def Vector_GatherOp :
@@ -1805,6 +1821,7 @@ def Vector_GatherOp :
"type($index_vec) `,` type($mask) `,` type($pass_thru) "
"`into` type($result)";
let hasCanonicalizer = 1;
+ let hasVerifier = 1;
}
def Vector_ScatterOp :
@@ -1867,6 +1884,7 @@ def Vector_ScatterOp :
"$mask `,` $valueToStore attr-dict `:` type($base) `,` "
"type($index_vec) `,` type($mask) `,` type($valueToStore)";
let hasCanonicalizer = 1;
+ let hasVerifier = 1;
}
def Vector_ExpandLoadOp :
@@ -1925,6 +1943,7 @@ def Vector_ExpandLoadOp :
let assemblyFormat = "$base `[` $indices `]` `,` $mask `,` $pass_thru attr-dict `:` "
"type($base) `,` type($mask) `,` type($pass_thru) `into` type($result)";
let hasCanonicalizer = 1;
+ let hasVerifier = 1;
}
def Vector_CompressStoreOp :
@@ -1980,6 +1999,7 @@ def Vector_CompressStoreOp :
"$base `[` $indices `]` `,` $mask `,` $valueToStore attr-dict `:` "
"type($base) `,` type($mask) `,` type($valueToStore)";
let hasCanonicalizer = 1;
+ let hasVerifier = 1;
}
def Vector_ShapeCastOp :
@@ -2031,6 +2051,7 @@ def Vector_ShapeCastOp :
let assemblyFormat = "$source attr-dict `:` type($source) `to` type($result)";
let hasFolder = 1;
let hasCanonicalizer = 1;
+ let hasVerifier = 1;
}
def Vector_BitCastOp :
@@ -2070,6 +2091,7 @@ def Vector_BitCastOp :
}];
let assemblyFormat = "$source attr-dict `:` type($source) `to` type($result)";
let hasFolder = 1;
+ let hasVerifier = 1;
}
def Vector_TypeCastOp :
@@ -2116,6 +2138,7 @@ def Vector_TypeCastOp :
let assemblyFormat = [{
$memref attr-dict `:` type($memref) `to` type($result)
}];
+ let hasVerifier = 1;
}
def Vector_ConstantMaskOp :
@@ -2157,6 +2180,7 @@ def Vector_ConstantMaskOp :
static StringRef getMaskDimSizesAttrName() { return "mask_dim_sizes"; }
}];
let assemblyFormat = "$mask_dim_sizes attr-dict `:` type(results)";
+ let hasVerifier = 1;
}
def Vector_CreateMaskOp :
@@ -2194,6 +2218,7 @@ def Vector_CreateMaskOp :
}];
let hasCanonicalizer = 1;
+ let hasVerifier = 1;
let assemblyFormat = "$operands attr-dict `:` type(results)";
}
@@ -2245,6 +2270,7 @@ def Vector_TransposeOp :
}];
let hasCanonicalizer = 1;
let hasFolder = 1;
+ let hasVerifier = 1;
}
def Vector_PrintOp :
@@ -2272,7 +2298,6 @@ def Vector_PrintOp :
newline).
```
}];
- let verifier = ?;
let extraClassDeclaration = [{
Type getPrintType() {
return source().getType();
@@ -2348,7 +2373,6 @@ def Vector_MatmulOp : Vector_Op<"matrix_multiply", [NoSideEffect,
lhs.getType().cast<VectorType>().getElementType()));
}]>,
];
- let verifier = ?;
let assemblyFormat = "$lhs `,` $rhs attr-dict "
"`:` `(` type($lhs) `,` type($rhs) `)` `->` type($res)";
}
@@ -2393,7 +2417,6 @@ def Vector_FlatTransposeOp : Vector_Op<"flat_transpose", [NoSideEffect,
: (vector<16xf32>) -> vector<16xf32>
```
}];
- let verifier = ?;
let assemblyFormat = "$matrix attr-dict `:` type($matrix) `->` type($res)";
}
@@ -2426,7 +2449,6 @@ def VectorScaleOp : Vector_Op<"vscale",
}];
let results = (outs Index:$res);
let assemblyFormat = "attr-dict";
- let verifier = ?;
}
//===----------------------------------------------------------------------===//
@@ -2485,6 +2507,7 @@ def Vector_ScanOp :
let assemblyFormat =
"$kind `,` $source `,` $initial_value attr-dict `:` "
"type($source) `,` type($initial_value) ";
+ let hasVerifier = 1;
}
#endif // VECTOR_OPS
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 95821ff6f36d0..e034b98e5d0cc 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -354,15 +354,15 @@ void vector::MultiDimReductionOp::build(OpBuilder &builder,
CombiningKindAttr::get(kind, builder.getContext()));
}
-static LogicalResult verify(MultiDimReductionOp op) {
- auto reductionMask = op.getReductionMask();
+LogicalResult MultiDimReductionOp::verify() {
+ auto reductionMask = getReductionMask();
auto targetType = MultiDimReductionOp::inferDestType(
- op.getSourceVectorType().getShape(), reductionMask,
- op.getSourceVectorType().getElementType());
+ getSourceVectorType().getShape(), reductionMask,
+ getSourceVectorType().getElementType());
// TODO: update to support 0-d vectors when available.
- if (targetType != op.getDestType())
- return op.emitError("invalid output vector type: ")
- << op.getDestType() << " (expected: " << targetType << ")";
+ if (targetType != getDestType())
+ return emitError("invalid output vector type: ")
+ << getDestType() << " (expected: " << targetType << ")";
return success();
}
@@ -377,29 +377,29 @@ OpFoldResult MultiDimReductionOp::fold(ArrayRef<Attribute> operands) {
// ReductionOp
//===----------------------------------------------------------------------===//
-static LogicalResult verify(ReductionOp op) {
+LogicalResult ReductionOp::verify() {
// Verify for 1-D vector.
- int64_t rank = op.getVectorType().getRank();
+ int64_t rank = getVectorType().getRank();
if (rank != 1)
- return op.emitOpError("unsupported reduction rank: ") << rank;
+ return emitOpError("unsupported reduction rank: ") << rank;
// Verify supported reduction kind.
- StringRef strKind = op.kind();
+ StringRef strKind = kind();
auto maybeKind = symbolizeCombiningKind(strKind);
if (!maybeKind)
- return op.emitOpError("unknown reduction kind: ") << strKind;
+ return emitOpError("unknown reduction kind: ") << strKind;
- Type eltType = op.dest().getType();
+ Type eltType = dest().getType();
if (!isSupportedCombiningKind(*maybeKind, eltType))
- return op.emitOpError("unsupported reduction type '")
- << eltType << "' for kind '" << op.kind() << "'";
+ return emitOpError("unsupported reduction type '")
+ << eltType << "' for kind '" << strKind << "'";
// Verify optional accumulator.
- if (!op.acc().empty()) {
+ if (!acc().empty()) {
if (strKind != "add" && strKind != "mul")
- return op.emitOpError("no accumulator for reduction kind: ") << strKind;
+ return emitOpError("no accumulator for reduction kind: ") << strKind;
if (!eltType.isa<FloatType>())
- return op.emitOpError("no accumulator for type: ") << eltType;
+ return emitOpError("no accumulator for type: ") << eltType;
}
return success();
@@ -676,78 +676,78 @@ static LogicalResult verifyOutputShape(
return success();
}
-static LogicalResult verify(ContractionOp op) {
- auto lhsType = op.getLhsType();
- auto rhsType = op.getRhsType();
- auto accType = op.getAccType();
- auto resType = op.getResultType();
+LogicalResult ContractionOp::verify() {
+ auto lhsType = getLhsType();
+ auto rhsType = getRhsType();
+ auto accType = getAccType();
+ auto resType = getResultType();
// Verify that an indexing map was specified for each vector operand.
- if (op.indexing_maps().size() != 3)
- return op.emitOpError("expected an indexing map for each vector operand");
+ if (indexing_maps().size() != 3)
+ return emitOpError("expected an indexing map for each vector operand");
// Verify that each index map has 'numIterators' inputs, no symbols, and
// that the number of map outputs equals the rank of its associated
// vector operand.
- unsigned numIterators = op.iterator_types().getValue().size();
- for (const auto &it : llvm::enumerate(op.indexing_maps())) {
+ unsigned numIterators = iterator_types().getValue().size();
+ for (const auto &it : llvm::enumerate(indexing_maps())) {
auto index = it.index();
auto map = it.value().cast<AffineMapAttr>().getValue();
if (map.getNumSymbols() != 0)
- return op.emitOpError("expected indexing map ")
+ return emitOpError("expected indexing map ")
<< index << " to have no symbols";
- auto vectorType = op.getOperand(index).getType().dyn_cast<VectorType>();
+ auto vectorType = getOperand(index).getType().dyn_cast<VectorType>();
unsigned rank = vectorType ? vectorType.getShape().size() : 0;
// Verify that the map has the right number of inputs, outputs, and indices.
// This also correctly accounts for (..) -> () for rank-0 results.
if (map.getNumDims() != numIterators)
- return op.emitOpError("expected indexing map ")
+ return emitOpError("expected indexing map ")
<< index << " to have " << numIterators << " number of inputs";
if (map.getNumResults() != rank)
- return op.emitOpError("expected indexing map ")
+ return emitOpError("expected indexing map ")
<< index << " to have " << rank << " number of outputs";
if (!map.isProjectedPermutation())
- return op.emitOpError("expected indexing map ")
+ return emitOpError("expected indexing map ")
<< index << " to be a projected permutation of its inputs";
}
- auto contractingDimMap = op.getContractingDimMap();
- auto batchDimMap = op.getBatchDimMap();
+ auto contractingDimMap = getContractingDimMap();
+ auto batchDimMap = getBatchDimMap();
// Verify at least one contracting dimension pair was specified.
if (contractingDimMap.empty())
- return op.emitOpError("expected at least one contracting dimension pair");
+ return emitOpError("expected at least one contracting dimension pair");
// Verify contracting dimension map was properly constructed.
if (!verifyDimMap(lhsType, rhsType, contractingDimMap))
- return op.emitOpError("invalid contracting dimension map");
+ return emitOpError("invalid contracting dimension map");
// Verify batch dimension map was properly constructed.
if (!verifyDimMap(lhsType, rhsType, batchDimMap))
- return op.emitOpError("invalid batch dimension map");
+ return emitOpError("invalid batch dimension map");
// Verify 'accType' and 'resType' shape.
- if (failed(verifyOutputShape(op, lhsType, rhsType, accType, resType,
+ if (failed(verifyOutputShape(*this, lhsType, rhsType, accType, resType,
contractingDimMap, batchDimMap)))
return failure();
// Verify that either two vector masks are set or none are set.
- auto lhsMaskType = op.getLHSVectorMaskType();
- auto rhsMaskType = op.getRHSVectorMaskType();
+ auto lhsMaskType = getLHSVectorMaskType();
+ auto rhsMaskType = getRHSVectorMaskType();
if ((lhsMaskType && !rhsMaskType) || (!lhsMaskType && rhsMaskType))
- return op.emitOpError("invalid number of vector masks specified");
+ return emitOpError("invalid number of vector masks specified");
if (lhsMaskType && rhsMaskType) {
// Verify mask rank == argument rank.
if (lhsMaskType.getShape().size() != lhsType.getShape().size() ||
rhsMaskType.getShape().size() != rhsType.getShape().size())
- return op.emitOpError("invalid vector mask rank");
+ return emitOpError("invalid vector mask rank");
}
// Verify supported combining kind.
auto vectorType = resType.dyn_cast<VectorType>();
auto elementType = vectorType ? vectorType.getElementType() : resType;
- if (!isSupportedCombiningKind(op.kind(), elementType))
- return op.emitOpError("unsupported contraction type");
+ if (!isSupportedCombiningKind(kind(), elementType))
+ return emitOpError("unsupported contraction type");
return success();
}
@@ -923,17 +923,17 @@ void vector::ExtractElementOp::build(OpBuilder &builder, OperationState &result,
result.addTypes(source.getType().cast<VectorType>().getElementType());
}
-static LogicalResult verify(vector::ExtractElementOp op) {
- VectorType vectorType = op.getVectorType();
+LogicalResult vector::ExtractElementOp::verify() {
+ VectorType vectorType = getVectorType();
if (vectorType.getRank() == 0) {
- if (op.position())
- return op.emitOpError("expected position to be empty with 0-D vector");
+ if (position())
+ return emitOpError("expected position to be empty with 0-D vector");
return success();
}
if (vectorType.getRank() != 1)
- return op.emitOpError("unexpected >1 vector rank");
- if (!op.position())
- return op.emitOpError("expected position for 1-D vector");
+ return emitOpError("unexpected >1 vector rank");
+ if (!position())
+ return emitOpError("expected position for 1-D vector");
return success();
}
@@ -1003,16 +1003,16 @@ static ParseResult parseExtractOp(OpAsmParser &parser, OperationState &result) {
parser.addTypeToList(resType, result.types));
}
-static LogicalResult verify(vector::ExtractOp op) {
- auto positionAttr = op.position().getValue();
- if (positionAttr.size() > static_cast<unsigned>(op.getVectorType().getRank()))
- return op.emitOpError(
+LogicalResult vector::ExtractOp::verify() {
+ auto positionAttr = position().getValue();
+ if (positionAttr.size() > static_cast<unsigned>(getVectorType().getRank()))
+ return emitOpError(
"expected position attribute of rank smaller than vector rank");
for (const auto &en : llvm::enumerate(positionAttr)) {
auto attr = en.value().dyn_cast<IntegerAttr>();
if (!attr || attr.getInt() < 0 ||
- attr.getInt() >= op.getVectorType().getDimSize(en.index()))
- return op.emitOpError("expected position attribute #")
+ attr.getInt() >= getVectorType().getDimSize(en.index()))
+ return emitOpError("expected position attribute #")
<< (en.index() + 1)
<< " to be a non-negative integer smaller than the corresponding "
"vector dimension";
@@ -1565,24 +1565,21 @@ void ExtractMapOp::build(OpBuilder &builder, OperationState &result,
ExtractMapOp::build(builder, result, resultType, vector, ids);
}
-static LogicalResult verify(ExtractMapOp op) {
- if (op.getSourceVectorType().getRank() != op.getResultType().getRank())
- return op.emitOpError(
- "expected source and destination vectors of same rank");
+LogicalResult ExtractMapOp::verify() {
+ if (getSourceVectorType().getRank() != getResultType().getRank())
+ return emitOpError("expected source and destination vectors of same rank");
unsigned numId = 0;
- for (unsigned i = 0, e = op.getSourceVectorType().getRank(); i < e; ++i) {
- if (op.getSourceVectorType().getDimSize(i) %
- op.getResultType().getDimSize(i) !=
+ for (unsigned i = 0, e = getSourceVectorType().getRank(); i < e; ++i) {
+ if (getSourceVectorType().getDimSize(i) % getResultType().getDimSize(i) !=
0)
- return op.emitOpError("source vector dimensions must be a multiple of "
- "destination vector dimensions");
- if (op.getSourceVectorType().getDimSize(i) !=
- op.getResultType().getDimSize(i))
+ return emitOpError("source vector dimensions must be a multiple of "
+ "destination vector dimensions");
+ if (getSourceVectorType().getDimSize(i) != getResultType().getDimSize(i))
numId++;
}
- if (numId != op.ids().size())
- return op.emitOpError("expected number of ids must match the number of "
- "dimensions distributed");
+ if (numId != ids().size())
+ return emitOpError("expected number of ids must match the number of "
+ "dimensions distributed");
return success();
}
@@ -1666,19 +1663,19 @@ mlir::vector::isBroadcastableTo(Type srcType, VectorType dstVectorType,
return BroadcastableToResult::Success;
}
-static LogicalResult verify(BroadcastOp op) {
+LogicalResult BroadcastOp::verify() {
std::pair<int, int> mismatchingDims;
- BroadcastableToResult res = isBroadcastableTo(
- op.getSourceType(), op.getVectorType(), &mismatchingDims);
+ BroadcastableToResult res =
+ isBroadcastableTo(getSourceType(), getVectorType(), &mismatchingDims);
if (res == BroadcastableToResult::Success)
return success();
if (res == BroadcastableToResult::SourceRankHigher)
- return op.emitOpError("source rank higher than destination rank");
+ return emitOpError("source rank higher than destination rank");
if (res == BroadcastableToResult::DimensionMismatch)
- return op.emitOpError("dimension mismatch (")
+ return emitOpError("dimension mismatch (")
<< mismatchingDims.first << " vs. " << mismatchingDims.second << ")";
if (res == BroadcastableToResult::SourceTypeNotAVector)
- return op.emitOpError("source type is not a vector");
+ return emitOpError("source type is not a vector");
llvm_unreachable("unexpected vector.broadcast op error");
}
@@ -1741,36 +1738,35 @@ static void print(OpAsmPrinter &p, ShuffleOp op) {
p << " : " << op.v1().getType() << ", " << op.v2().getType();
}
-static LogicalResult verify(ShuffleOp op) {
- VectorType resultType = op.getVectorType();
- VectorType v1Type = op.getV1VectorType();
- VectorType v2Type = op.getV2VectorType();
+LogicalResult ShuffleOp::verify() {
+ VectorType resultType = getVectorType();
+ VectorType v1Type = getV1VectorType();
+ VectorType v2Type = getV2VectorType();
// Verify ranks.
int64_t resRank = resultType.getRank();
int64_t v1Rank = v1Type.getRank();
int64_t v2Rank = v2Type.getRank();
if (resRank != v1Rank || v1Rank != v2Rank)
- return op.emitOpError("rank mismatch");
+ return emitOpError("rank mismatch");
// Verify all but leading dimension sizes.
for (int64_t r = 1; r < v1Rank; ++r) {
int64_t resDim = resultType.getDimSize(r);
int64_t v1Dim = v1Type.getDimSize(r);
int64_t v2Dim = v2Type.getDimSize(r);
if (resDim != v1Dim || v1Dim != v2Dim)
- return op.emitOpError("dimension mismatch");
+ return emitOpError("dimension mismatch");
}
// Verify mask length.
- auto maskAttr = op.mask().getValue();
+ auto maskAttr = mask().getValue();
int64_t maskLength = maskAttr.size();
if (maskLength != resultType.getDimSize(0))
- return op.emitOpError("mask length mismatch");
+ return emitOpError("mask length mismatch");
// Verify all indices.
int64_t indexSize = v1Type.getDimSize(0) + v2Type.getDimSize(0);
for (const auto &en : llvm::enumerate(maskAttr)) {
auto attr = en.value().dyn_cast<IntegerAttr>();
if (!attr || attr.getInt() < 0 || attr.getInt() >= indexSize)
- return op.emitOpError("mask index #")
- << (en.index() + 1) << " out of range";
+ return emitOpError("mask index #") << (en.index() + 1) << " out of range";
}
return success();
}
@@ -1824,17 +1820,17 @@ void InsertElementOp::build(OpBuilder &builder, OperationState &result,
result.addTypes(dest.getType());
}
-static LogicalResult verify(InsertElementOp op) {
- auto dstVectorType = op.getDestVectorType();
+LogicalResult InsertElementOp::verify() {
+ auto dstVectorType = getDestVectorType();
if (dstVectorType.getRank() == 0) {
- if (op.position())
- return op.emitOpError("expected position to be empty with 0-D vector");
+ if (position())
+ return emitOpError("expected position to be empty with 0-D vector");
return success();
}
if (dstVectorType.getRank() != 1)
- return op.emitOpError("unexpected >1 vector rank");
- if (!op.position())
- return op.emitOpError("expected position for 1-D vector");
+ return emitOpError("unexpected >1 vector rank");
+ if (!position())
+ return emitOpError("expected position for 1-D vector");
return success();
}
@@ -1860,27 +1856,27 @@ void InsertOp::build(OpBuilder &builder, OperationState &result, Value source,
build(builder, result, source, dest, positionConstants);
}
-static LogicalResult verify(InsertOp op) {
- auto positionAttr = op.position().getValue();
- auto destVectorType = op.getDestVectorType();
+LogicalResult InsertOp::verify() {
+ auto positionAttr = position().getValue();
+ auto destVectorType = getDestVectorType();
if (positionAttr.size() > static_cast<unsigned>(destVectorType.getRank()))
- return op.emitOpError(
+ return emitOpError(
"expected position attribute of rank smaller than dest vector rank");
- auto srcVectorType = op.getSourceType().dyn_cast<VectorType>();
+ auto srcVectorType = getSourceType().dyn_cast<VectorType>();
if (srcVectorType &&
(static_cast<unsigned>(srcVectorType.getRank()) + positionAttr.size() !=
static_cast<unsigned>(destVectorType.getRank())))
- return op.emitOpError("expected position attribute rank + source rank to "
+ return emitOpError("expected position attribute rank + source rank to "
"match dest vector rank");
if (!srcVectorType &&
(positionAttr.size() != static_cast<unsigned>(destVectorType.getRank())))
- return op.emitOpError(
+ return emitOpError(
"expected position attribute rank to match the dest vector rank");
for (const auto &en : llvm::enumerate(positionAttr)) {
auto attr = en.value().dyn_cast<IntegerAttr>();
if (!attr || attr.getInt() < 0 ||
attr.getInt() >= destVectorType.getDimSize(en.index()))
- return op.emitOpError("expected position attribute #")
+ return emitOpError("expected position attribute #")
<< (en.index() + 1)
<< " to be a non-negative integer smaller than the corresponding "
"dest vector dimension";
@@ -1933,24 +1929,21 @@ void InsertMapOp::build(OpBuilder &builder, OperationState &result,
InsertMapOp::build(builder, result, dest.getType(), vector, dest, ids);
}
-static LogicalResult verify(InsertMapOp op) {
- if (op.getSourceVectorType().getRank() != op.getResultType().getRank())
- return op.emitOpError(
- "expected source and destination vectors of same rank");
+LogicalResult InsertMapOp::verify() {
+ if (getSourceVectorType().getRank() != getResultType().getRank())
+ return emitOpError("expected source and destination vectors of same rank");
unsigned numId = 0;
- for (unsigned i = 0, e = op.getResultType().getRank(); i < e; i++) {
- if (op.getResultType().getDimSize(i) %
- op.getSourceVectorType().getDimSize(i) !=
+ for (unsigned i = 0, e = getResultType().getRank(); i < e; i++) {
+ if (getResultType().getDimSize(i) % getSourceVectorType().getDimSize(i) !=
0)
- return op.emitOpError(
+ return emitOpError(
"destination vector size must be a multiple of source vector size");
- if (op.getResultType().getDimSize(i) !=
- op.getSourceVectorType().getDimSize(i))
+ if (getResultType().getDimSize(i) != getSourceVectorType().getDimSize(i))
numId++;
}
- if (numId != op.ids().size())
- return op.emitOpError("expected number of ids must match the number of "
- "dimensions distributed");
+ if (numId != ids().size())
+ return emitOpError("expected number of ids must match the number of "
+ "dimensions distributed");
return success();
}
@@ -2062,19 +2055,18 @@ static ArrayAttr makeI64ArrayAttr(ArrayRef<int64_t> values,
return ArrayAttr::get(context, llvm::to_vector<8>(attrs));
}
-static LogicalResult verify(InsertStridedSliceOp op) {
- auto sourceVectorType = op.getSourceVectorType();
- auto destVectorType = op.getDestVectorType();
- auto offsets = op.offsets();
- auto strides = op.strides();
+LogicalResult InsertStridedSliceOp::verify() {
+ auto sourceVectorType = getSourceVectorType();
+ auto destVectorType = getDestVectorType();
+ auto offsets = offsetsAttr();
+ auto strides = stridesAttr();
if (offsets.size() != static_cast<unsigned>(destVectorType.getRank()))
- return op.emitOpError(
+ return emitOpError(
"expected offsets of same size as destination vector rank");
if (strides.size() != static_cast<unsigned>(sourceVectorType.getRank()))
- return op.emitOpError(
- "expected strides of same size as source vector rank");
+ return emitOpError("expected strides of same size as source vector rank");
if (sourceVectorType.getRank() > destVectorType.getRank())
- return op.emitOpError(
+ return emitOpError(
"expected source rank to be smaller than destination rank");
auto sourceShape = sourceVectorType.getShape();
@@ -2084,13 +2076,14 @@ static LogicalResult verify(InsertStridedSliceOp op) {
sourceShapeAsDestShape.append(sourceShape.begin(), sourceShape.end());
auto offName = InsertStridedSliceOp::getOffsetsAttrName();
auto stridesName = InsertStridedSliceOp::getStridesAttrName();
- if (failed(
- isIntegerArrayAttrConfinedToShape(op, offsets, destShape, offName)) ||
- failed(isIntegerArrayAttrConfinedToRange(op, strides, 1, 1, stridesName,
+ if (failed(isIntegerArrayAttrConfinedToShape(*this, offsets, destShape,
+ offName)) ||
+ failed(isIntegerArrayAttrConfinedToRange(*this, strides, 1, 1,
+ stridesName,
/*halfOpen=*/false)) ||
failed(isSumOfIntegerArrayAttrConfinedToShape(
- op, offsets,
- makeI64ArrayAttr(sourceShapeAsDestShape, op.getContext()), destShape,
+ *this, offsets,
+ makeI64ArrayAttr(sourceShapeAsDestShape, getContext()), destShape,
offName, "source vector shape",
/*halfOpen=*/false, /*min=*/1)))
return failure();
@@ -2161,39 +2154,39 @@ static ParseResult parseOuterProductOp(OpAsmParser &parser,
parser.addTypeToList(resType, result.types));
}
-static LogicalResult verify(OuterProductOp op) {
- Type tRHS = op.getOperandTypeRHS();
- VectorType vLHS = op.getOperandVectorTypeLHS(),
+LogicalResult OuterProductOp::verify() {
+ Type tRHS = getOperandTypeRHS();
+ VectorType vLHS = getOperandVectorTypeLHS(),
vRHS = tRHS.dyn_cast<VectorType>(),
- vACC = op.getOperandVectorTypeACC(), vRES = op.getVectorType();
+ vACC = getOperandVectorTypeACC(), vRES = getVectorType();
if (vLHS.getRank() != 1)
- return op.emitOpError("expected 1-d vector for operand #1");
+ return emitOpError("expected 1-d vector for operand #1");
if (vRHS) {
// Proper OUTER operation.
if (vRHS.getRank() != 1)
- return op.emitOpError("expected 1-d vector for operand #2");
+ return emitOpError("expected 1-d vector for operand #2");
if (vRES.getRank() != 2)
- return op.emitOpError("expected 2-d vector result");
+ return emitOpError("expected 2-d vector result");
if (vLHS.getDimSize(0) != vRES.getDimSize(0))
- return op.emitOpError("expected #1 operand dim to match result dim #1");
+ return emitOpError("expected #1 operand dim to match result dim #1");
if (vRHS.getDimSize(0) != vRES.getDimSize(1))
- return op.emitOpError("expected #2 operand dim to match result dim #2");
+ return emitOpError("expected #2 operand dim to match result dim #2");
} else {
// An AXPY operation.
if (vRES.getRank() != 1)
- return op.emitOpError("expected 1-d vector result");
+ return emitOpError("expected 1-d vector result");
if (vLHS.getDimSize(0) != vRES.getDimSize(0))
- return op.emitOpError("expected #1 operand dim to match result dim #1");
+ return emitOpError("expected #1 operand dim to match result dim #1");
}
if (vACC && vACC != vRES)
- return op.emitOpError("expected operand #3 of same type as result type");
+ return emitOpError("expected operand #3 of same type as result type");
// Verify supported combining kind.
- if (!isSupportedCombiningKind(op.kind(), vRES.getElementType()))
- return op.emitOpError("unsupported outerproduct type");
+ if (!isSupportedCombiningKind(kind(), vRES.getElementType()))
+ return emitOpError("unsupported outerproduct type");
return success();
}
@@ -2202,22 +2195,22 @@ static LogicalResult verify(OuterProductOp op) {
// ReshapeOp
//===----------------------------------------------------------------------===//
-static LogicalResult verify(ReshapeOp op) {
+LogicalResult ReshapeOp::verify() {
// Verify that rank(numInputs/outputs) + numFixedVec dim matches vec rank.
- auto inputVectorType = op.getInputVectorType();
- auto outputVectorType = op.getOutputVectorType();
- int64_t inputShapeRank = op.getNumInputShapeSizes();
- int64_t outputShapeRank = op.getNumOutputShapeSizes();
+ auto inputVectorType = getInputVectorType();
+ auto outputVectorType = getOutputVectorType();
+ int64_t inputShapeRank = getNumInputShapeSizes();
+ int64_t outputShapeRank = getNumOutputShapeSizes();
SmallVector<int64_t, 4> fixedVectorSizes;
- op.getFixedVectorSizes(fixedVectorSizes);
+ getFixedVectorSizes(fixedVectorSizes);
int64_t numFixedVectorSizes = fixedVectorSizes.size();
if (inputVectorType.getRank() != inputShapeRank + numFixedVectorSizes)
- return op.emitError("invalid input shape for vector type ")
+ return emitError("invalid input shape for vector type ")
<< inputVectorType;
if (outputVectorType.getRank() != outputShapeRank + numFixedVectorSizes)
- return op.emitError("invalid output shape for vector type ")
+ return emitError("invalid output shape for vector type ")
<< outputVectorType;
// Verify that the 'fixedVectorSizes' match an input/output vector shape
@@ -2226,7 +2219,7 @@ static LogicalResult verify(ReshapeOp op) {
for (unsigned i = 0; i < numFixedVectorSizes; ++i) {
unsigned index = inputVectorRank - numFixedVectorSizes - i;
if (fixedVectorSizes[i] != inputVectorType.getShape()[index])
- return op.emitError("fixed vector size must match input vector for dim ")
+ return emitError("fixed vector size must match input vector for dim ")
<< i;
}
@@ -2234,7 +2227,7 @@ static LogicalResult verify(ReshapeOp op) {
for (unsigned i = 0; i < numFixedVectorSizes; ++i) {
unsigned index = outputVectorRank - numFixedVectorSizes - i;
if (fixedVectorSizes[i] != outputVectorType.getShape()[index])
- return op.emitError("fixed vector size must match output vector for dim ")
+ return emitError("fixed vector size must match output vector for dim ")
<< i;
}
@@ -2243,18 +2236,18 @@ static LogicalResult verify(ReshapeOp op) {
auto isDefByConstant = [](Value operand) {
return isa_and_nonnull<arith::ConstantIndexOp>(operand.getDefiningOp());
};
- if (llvm::all_of(op.input_shape(), isDefByConstant) &&
- llvm::all_of(op.output_shape(), isDefByConstant)) {
+ if (llvm::all_of(input_shape(), isDefByConstant) &&
+ llvm::all_of(output_shape(), isDefByConstant)) {
int64_t numInputElements = 1;
- for (auto operand : op.input_shape())
+ for (auto operand : input_shape())
numInputElements *=
cast<arith::ConstantIndexOp>(operand.getDefiningOp()).value();
int64_t numOutputElements = 1;
- for (auto operand : op.output_shape())
+ for (auto operand : output_shape())
numOutputElements *=
cast<arith::ConstantIndexOp>(operand.getDefiningOp()).value();
if (numInputElements != numOutputElements)
- return op.emitError("product of input and output shape sizes must match");
+ return emitError("product of input and output shape sizes must match");
}
return success();
}
@@ -2301,42 +2294,37 @@ void ExtractStridedSliceOp::build(OpBuilder &builder, OperationState &result,
result.addAttribute(getStridesAttrName(), stridesAttr);
}
-static LogicalResult verify(ExtractStridedSliceOp op) {
- auto type = op.getVectorType();
- auto offsets = op.offsets();
- auto sizes = op.sizes();
- auto strides = op.strides();
- if (offsets.size() != sizes.size() || offsets.size() != strides.size()) {
- op.emitOpError(
- "expected offsets, sizes and strides attributes of same size");
- return failure();
- }
+LogicalResult ExtractStridedSliceOp::verify() {
+ auto type = getVectorType();
+ auto offsets = offsetsAttr();
+ auto sizes = sizesAttr();
+ auto strides = stridesAttr();
+ if (offsets.size() != sizes.size() || offsets.size() != strides.size())
+ return emitOpError("expected offsets, sizes and strides attributes of same size");
auto shape = type.getShape();
- auto offName = ExtractStridedSliceOp::getOffsetsAttrName();
- auto sizesName = ExtractStridedSliceOp::getSizesAttrName();
- auto stridesName = ExtractStridedSliceOp::getStridesAttrName();
- if (failed(isIntegerArrayAttrSmallerThanShape(op, offsets, shape, offName)) ||
- failed(isIntegerArrayAttrSmallerThanShape(op, sizes, shape, sizesName)) ||
- failed(isIntegerArrayAttrSmallerThanShape(op, strides, shape,
+ auto offName = getOffsetsAttrName();
+ auto sizesName = getSizesAttrName();
+ auto stridesName = getStridesAttrName();
+ if (failed(isIntegerArrayAttrSmallerThanShape(*this, offsets, shape, offName)) ||
+ failed(isIntegerArrayAttrSmallerThanShape(*this, sizes, shape, sizesName)) ||
+ failed(isIntegerArrayAttrSmallerThanShape(*this, strides, shape,
stridesName)) ||
- failed(isIntegerArrayAttrConfinedToShape(op, offsets, shape, offName)) ||
- failed(isIntegerArrayAttrConfinedToShape(op, sizes, shape, sizesName,
+ failed(isIntegerArrayAttrConfinedToShape(*this, offsets, shape, offName)) ||
+ failed(isIntegerArrayAttrConfinedToShape(*this, sizes, shape, sizesName,
/*halfOpen=*/false,
/*min=*/1)) ||
- failed(isIntegerArrayAttrConfinedToRange(op, strides, 1, 1, stridesName,
+ failed(isIntegerArrayAttrConfinedToRange(*this, strides, 1, 1, stridesName,
/*halfOpen=*/false)) ||
- failed(isSumOfIntegerArrayAttrConfinedToShape(op, offsets, sizes, shape,
+ failed(isSumOfIntegerArrayAttrConfinedToShape(*this, offsets, sizes, shape,
offName, sizesName,
/*halfOpen=*/false)))
return failure();
- auto resultType = inferStridedSliceOpResultType(
- op.getVectorType(), op.offsets(), op.sizes(), op.strides());
- if (op.getResult().getType() != resultType) {
- op.emitOpError("expected result type to be ") << resultType;
- return failure();
- }
+ auto resultType =
+ inferStridedSliceOpResultType(getVectorType(), offsets, sizes, strides);
+ if (getResult().getType() != resultType)
+ return emitOpError("expected result type to be ") << resultType;
return success();
}
@@ -2828,44 +2816,43 @@ static ParseResult parseTransferReadOp(OpAsmParser &parser,
return parser.addTypeToList(vectorType, result.types);
}
-static LogicalResult verify(TransferReadOp op) {
+LogicalResult TransferReadOp::verify() {
// Consistency of elemental types in source and vector.
- ShapedType shapedType = op.getShapedType();
- VectorType vectorType = op.getVectorType();
- VectorType maskType = op.getMaskType();
- auto paddingType = op.padding().getType();
- auto permutationMap = op.permutation_map();
+ ShapedType shapedType = getShapedType();
+ VectorType vectorType = getVectorType();
+ VectorType maskType = getMaskType();
+ auto paddingType = padding().getType();
+ auto permutationMap = permutation_map();
auto sourceElementType = shapedType.getElementType();
- if (static_cast<int64_t>(op.indices().size()) != shapedType.getRank())
- return op.emitOpError("requires ") << shapedType.getRank() << " indices";
+ if (static_cast<int64_t>(indices().size()) != shapedType.getRank())
+ return emitOpError("requires ") << shapedType.getRank() << " indices";
- if (failed(
- verifyTransferOp(cast<VectorTransferOpInterface>(op.getOperation()),
- shapedType, vectorType, maskType, permutationMap,
- op.in_bounds() ? *op.in_bounds() : ArrayAttr())))
+ if (failed(verifyTransferOp(cast<VectorTransferOpInterface>(getOperation()),
+ shapedType, vectorType, maskType, permutationMap,
+ in_bounds() ? *in_bounds() : ArrayAttr())))
return failure();
if (auto sourceVectorElementType = sourceElementType.dyn_cast<VectorType>()) {
// Source has vector element type.
// Check that 'sourceVectorElementType' and 'paddingType' types match.
if (sourceVectorElementType != paddingType)
- return op.emitOpError(
+ return emitOpError(
"requires source element type and padding type to match.");
} else {
// Check that 'paddingType' is valid to store in a vector type.
if (!VectorType::isValidElementType(paddingType))
- return op.emitOpError("requires valid padding vector elemental type");
+ return emitOpError("requires valid padding vector elemental type");
// Check that padding type and vector element types match.
if (paddingType != sourceElementType)
- return op.emitOpError(
+ return emitOpError(
"requires formal padding and source of the same elemental type");
}
return verifyPermutationMap(permutationMap,
- [&op](Twine t) { return op.emitOpError(t); });
+ [&](Twine t) { return emitOpError(t); });
}
/// This is a common class used for patterns of the form
@@ -3208,29 +3195,28 @@ static void print(OpAsmPrinter &p, TransferWriteOp op) {
p << " : " << op.getVectorType() << ", " << op.getShapedType();
}
-static LogicalResult verify(TransferWriteOp op) {
+LogicalResult TransferWriteOp::verify() {
// Consistency of elemental types in shape and vector.
- ShapedType shapedType = op.getShapedType();
- VectorType vectorType = op.getVectorType();
- VectorType maskType = op.getMaskType();
- auto permutationMap = op.permutation_map();
+ ShapedType shapedType = getShapedType();
+ VectorType vectorType = getVectorType();
+ VectorType maskType = getMaskType();
+ auto permutationMap = permutation_map();
- if (llvm::size(op.indices()) != shapedType.getRank())
- return op.emitOpError("requires ") << shapedType.getRank() << " indices";
+ if (llvm::size(indices()) != shapedType.getRank())
+ return emitOpError("requires ") << shapedType.getRank() << " indices";
// We do not allow broadcast dimensions on TransferWriteOps for the moment,
// as the semantics is unclear. This can be revisited later if necessary.
- if (op.hasBroadcastDim())
- return op.emitOpError("should not have broadcast dimensions");
+ if (hasBroadcastDim())
+ return emitOpError("should not have broadcast dimensions");
- if (failed(
- verifyTransferOp(cast<VectorTransferOpInterface>(op.getOperation()),
- shapedType, vectorType, maskType, permutationMap,
- op.in_bounds() ? *op.in_bounds() : ArrayAttr())))
+ if (failed(verifyTransferOp(cast<VectorTransferOpInterface>(getOperation()),
+ shapedType, vectorType, maskType, permutationMap,
+ in_bounds() ? *in_bounds() : ArrayAttr())))
return failure();
return verifyPermutationMap(permutationMap,
- [&op](Twine t) { return op.emitOpError(t); });
+ [&](Twine t) { return emitOpError(t); });
}
/// Fold:
@@ -3514,25 +3500,25 @@ static LogicalResult verifyLoadStoreMemRefLayout(Operation *op,
return success();
}
-static LogicalResult verify(vector::LoadOp op) {
- VectorType resVecTy = op.getVectorType();
- MemRefType memRefTy = op.getMemRefType();
+LogicalResult vector::LoadOp::verify() {
+ VectorType resVecTy = getVectorType();
+ MemRefType memRefTy = getMemRefType();
- if (failed(verifyLoadStoreMemRefLayout(op, memRefTy)))
+ if (failed(verifyLoadStoreMemRefLayout(*this, memRefTy)))
return failure();
// Checks for vector memrefs.
Type memElemTy = memRefTy.getElementType();
if (auto memVecTy = memElemTy.dyn_cast<VectorType>()) {
if (memVecTy != resVecTy)
- return op.emitOpError("base memref and result vector types should match");
+ return emitOpError("base memref and result vector types should match");
memElemTy = memVecTy.getElementType();
}
if (resVecTy.getElementType() != memElemTy)
- return op.emitOpError("base and result element types should match");
- if (llvm::size(op.indices()) != memRefTy.getRank())
- return op.emitOpError("requires ") << memRefTy.getRank() << " indices";
+ return emitOpError("base and result element types should match");
+ if (llvm::size(indices()) != memRefTy.getRank())
+ return emitOpError("requires ") << memRefTy.getRank() << " indices";
return success();
}
@@ -3546,26 +3532,26 @@ OpFoldResult LoadOp::fold(ArrayRef<Attribute>) {
// StoreOp
//===----------------------------------------------------------------------===//
-static LogicalResult verify(vector::StoreOp op) {
- VectorType valueVecTy = op.getVectorType();
- MemRefType memRefTy = op.getMemRefType();
+LogicalResult vector::StoreOp::verify() {
+ VectorType valueVecTy = getVectorType();
+ MemRefType memRefTy = getMemRefType();
- if (failed(verifyLoadStoreMemRefLayout(op, memRefTy)))
+ if (failed(verifyLoadStoreMemRefLayout(*this, memRefTy)))
return failure();
// Checks for vector memrefs.
Type memElemTy = memRefTy.getElementType();
if (auto memVecTy = memElemTy.dyn_cast<VectorType>()) {
if (memVecTy != valueVecTy)
- return op.emitOpError(
+ return emitOpError(
"base memref and valueToStore vector types should match");
memElemTy = memVecTy.getElementType();
}
if (valueVecTy.getElementType() != memElemTy)
- return op.emitOpError("base and valueToStore element type should match");
- if (llvm::size(op.indices()) != memRefTy.getRank())
- return op.emitOpError("requires ") << memRefTy.getRank() << " indices";
+ return emitOpError("base and valueToStore element type should match");
+ if (llvm::size(indices()) != memRefTy.getRank())
+ return emitOpError("requires ") << memRefTy.getRank() << " indices";
return success();
}
@@ -3578,20 +3564,20 @@ LogicalResult StoreOp::fold(ArrayRef<Attribute> operands,
// MaskedLoadOp
//===----------------------------------------------------------------------===//
-static LogicalResult verify(MaskedLoadOp op) {
- VectorType maskVType = op.getMaskVectorType();
- VectorType passVType = op.getPassThruVectorType();
- VectorType resVType = op.getVectorType();
- MemRefType memType = op.getMemRefType();
+LogicalResult MaskedLoadOp::verify() {
+ VectorType maskVType = getMaskVectorType();
+ VectorType passVType = getPassThruVectorType();
+ VectorType resVType = getVectorType();
+ MemRefType memType = getMemRefType();
if (resVType.getElementType() != memType.getElementType())
- return op.emitOpError("base and result element type should match");
- if (llvm::size(op.indices()) != memType.getRank())
- return op.emitOpError("requires ") << memType.getRank() << " indices";
+ return emitOpError("base and result element type should match");
+ if (llvm::size(indices()) != memType.getRank())
+ return emitOpError("requires ") << memType.getRank() << " indices";
if (resVType.getDimSize(0) != maskVType.getDimSize(0))
- return op.emitOpError("expected result dim to match mask dim");
+ return emitOpError("expected result dim to match mask dim");
if (resVType != passVType)
- return op.emitOpError("expected pass_thru of same type as result type");
+ return emitOpError("expected pass_thru of same type as result type");
return success();
}
@@ -3632,17 +3618,17 @@ OpFoldResult MaskedLoadOp::fold(ArrayRef<Attribute>) {
// MaskedStoreOp
//===----------------------------------------------------------------------===//
-static LogicalResult verify(MaskedStoreOp op) {
- VectorType maskVType = op.getMaskVectorType();
- VectorType valueVType = op.getVectorType();
- MemRefType memType = op.getMemRefType();
+LogicalResult MaskedStoreOp::verify() {
+ VectorType maskVType = getMaskVectorType();
+ VectorType valueVType = getVectorType();
+ MemRefType memType = getMemRefType();
if (valueVType.getElementType() != memType.getElementType())
- return op.emitOpError("base and valueToStore element type should match");
- if (llvm::size(op.indices()) != memType.getRank())
- return op.emitOpError("requires ") << memType.getRank() << " indices";
+ return emitOpError("base and valueToStore element type should match");
+ if (llvm::size(indices()) != memType.getRank())
+ return emitOpError("requires ") << memType.getRank() << " indices";
if (valueVType.getDimSize(0) != maskVType.getDimSize(0))
- return op.emitOpError("expected valueToStore dim to match mask dim");
+ return emitOpError("expected valueToStore dim to match mask dim");
return success();
}
@@ -3682,22 +3668,22 @@ LogicalResult MaskedStoreOp::fold(ArrayRef<Attribute> operands,
// GatherOp
//===----------------------------------------------------------------------===//
-static LogicalResult verify(GatherOp op) {
- VectorType indVType = op.getIndexVectorType();
- VectorType maskVType = op.getMaskVectorType();
- VectorType resVType = op.getVectorType();
- MemRefType memType = op.getMemRefType();
+LogicalResult GatherOp::verify() {
+ VectorType indVType = getIndexVectorType();
+ VectorType maskVType = getMaskVectorType();
+ VectorType resVType = getVectorType();
+ MemRefType memType = getMemRefType();
if (resVType.getElementType() != memType.getElementType())
- return op.emitOpError("base and result element type should match");
- if (llvm::size(op.indices()) != memType.getRank())
- return op.emitOpError("requires ") << memType.getRank() << " indices";
+ return emitOpError("base and result element type should match");
+ if (llvm::size(indices()) != memType.getRank())
+ return emitOpError("requires ") << memType.getRank() << " indices";
if (resVType.getDimSize(0) != indVType.getDimSize(0))
- return op.emitOpError("expected result dim to match indices dim");
+ return emitOpError("expected result dim to match indices dim");
if (resVType.getDimSize(0) != maskVType.getDimSize(0))
- return op.emitOpError("expected result dim to match mask dim");
- if (resVType != op.getPassThruVectorType())
- return op.emitOpError("expected pass_thru of same type as result type");
+ return emitOpError("expected result dim to match mask dim");
+ if (resVType != getPassThruVectorType())
+ return emitOpError("expected pass_thru of same type as result type");
return success();
}
@@ -3730,20 +3716,20 @@ void GatherOp::getCanonicalizationPatterns(RewritePatternSet &results,
// ScatterOp
//===----------------------------------------------------------------------===//
-static LogicalResult verify(ScatterOp op) {
- VectorType indVType = op.getIndexVectorType();
- VectorType maskVType = op.getMaskVectorType();
- VectorType valueVType = op.getVectorType();
- MemRefType memType = op.getMemRefType();
+LogicalResult ScatterOp::verify() {
+ VectorType indVType = getIndexVectorType();
+ VectorType maskVType = getMaskVectorType();
+ VectorType valueVType = getVectorType();
+ MemRefType memType = getMemRefType();
if (valueVType.getElementType() != memType.getElementType())
- return op.emitOpError("base and valueToStore element type should match");
- if (llvm::size(op.indices()) != memType.getRank())
- return op.emitOpError("requires ") << memType.getRank() << " indices";
+ return emitOpError("base and valueToStore element type should match");
+ if (llvm::size(indices()) != memType.getRank())
+ return emitOpError("requires ") << memType.getRank() << " indices";
if (valueVType.getDimSize(0) != indVType.getDimSize(0))
- return op.emitOpError("expected valueToStore dim to match indices dim");
+ return emitOpError("expected valueToStore dim to match indices dim");
if (valueVType.getDimSize(0) != maskVType.getDimSize(0))
- return op.emitOpError("expected valueToStore dim to match mask dim");
+ return emitOpError("expected valueToStore dim to match mask dim");
return success();
}
@@ -3776,20 +3762,20 @@ void ScatterOp::getCanonicalizationPatterns(RewritePatternSet &results,
// ExpandLoadOp
//===----------------------------------------------------------------------===//
-static LogicalResult verify(ExpandLoadOp op) {
- VectorType maskVType = op.getMaskVectorType();
- VectorType passVType = op.getPassThruVectorType();
- VectorType resVType = op.getVectorType();
- MemRefType memType = op.getMemRefType();
+LogicalResult ExpandLoadOp::verify() {
+ VectorType maskVType = getMaskVectorType();
+ VectorType passVType = getPassThruVectorType();
+ VectorType resVType = getVectorType();
+ MemRefType memType = getMemRefType();
if (resVType.getElementType() != memType.getElementType())
- return op.emitOpError("base and result element type should match");
- if (llvm::size(op.indices()) != memType.getRank())
- return op.emitOpError("requires ") << memType.getRank() << " indices";
+ return emitOpError("base and result element type should match");
+ if (llvm::size(indices()) != memType.getRank())
+ return emitOpError("requires ") << memType.getRank() << " indices";
if (resVType.getDimSize(0) != maskVType.getDimSize(0))
- return op.emitOpError("expected result dim to match mask dim");
+ return emitOpError("expected result dim to match mask dim");
if (resVType != passVType)
- return op.emitOpError("expected pass_thru of same type as result type");
+ return emitOpError("expected pass_thru of same type as result type");
return success();
}
@@ -3824,17 +3810,17 @@ void ExpandLoadOp::getCanonicalizationPatterns(RewritePatternSet &results,
// CompressStoreOp
//===----------------------------------------------------------------------===//
-static LogicalResult verify(CompressStoreOp op) {
- VectorType maskVType = op.getMaskVectorType();
- VectorType valueVType = op.getVectorType();
- MemRefType memType = op.getMemRefType();
+LogicalResult CompressStoreOp::verify() {
+ VectorType maskVType = getMaskVectorType();
+ VectorType valueVType = getVectorType();
+ MemRefType memType = getMemRefType();
if (valueVType.getElementType() != memType.getElementType())
- return op.emitOpError("base and valueToStore element type should match");
- if (llvm::size(op.indices()) != memType.getRank())
- return op.emitOpError("requires ") << memType.getRank() << " indices";
+ return emitOpError("base and valueToStore element type should match");
+ if (llvm::size(indices()) != memType.getRank())
+ return emitOpError("requires ") << memType.getRank() << " indices";
if (valueVType.getDimSize(0) != maskVType.getDimSize(0))
- return op.emitOpError("expected valueToStore dim to match mask dim");
+ return emitOpError("expected valueToStore dim to match mask dim");
return success();
}
@@ -3930,13 +3916,13 @@ static LogicalResult verifyVectorShapeCast(Operation *op,
return success();
}
-static LogicalResult verify(ShapeCastOp op) {
- auto sourceVectorType = op.source().getType().dyn_cast_or_null<VectorType>();
- auto resultVectorType = op.result().getType().dyn_cast_or_null<VectorType>();
+LogicalResult ShapeCastOp::verify() {
+ auto sourceVectorType = source().getType().dyn_cast_or_null<VectorType>();
+ auto resultVectorType = result().getType().dyn_cast_or_null<VectorType>();
// Check if source/result are of vector type.
if (sourceVectorType && resultVectorType)
- return verifyVectorShapeCast(op, sourceVectorType, resultVectorType);
+ return verifyVectorShapeCast(*this, sourceVectorType, resultVectorType);
return success();
}
@@ -4005,16 +3991,16 @@ void ShapeCastOp::getCanonicalizationPatterns(RewritePatternSet &results,
// VectorBitCastOp
//===----------------------------------------------------------------------===//
-static LogicalResult verify(BitCastOp op) {
- auto sourceVectorType = op.getSourceVectorType();
- auto resultVectorType = op.getResultVectorType();
+LogicalResult BitCastOp::verify() {
+ auto sourceVectorType = getSourceVectorType();
+ auto resultVectorType = getResultVectorType();
for (int64_t i = 0, e = sourceVectorType.getRank() - 1; i < e; i++) {
if (sourceVectorType.getDimSize(i) != resultVectorType.getDimSize(i))
- return op.emitOpError("dimension size mismatch at: ") << i;
+ return emitOpError("dimension size mismatch at: ") << i;
}
- DataLayout dataLayout = DataLayout::closest(op);
+ DataLayout dataLayout = DataLayout::closest(*this);
auto sourceElementBits =
dataLayout.getTypeSizeInBits(sourceVectorType.getElementType());
auto resultElementBits =
@@ -4022,11 +4008,11 @@ static LogicalResult verify(BitCastOp op) {
if (sourceVectorType.getRank() == 0) {
if (sourceElementBits != resultElementBits)
- return op.emitOpError("source/result bitwidth of the 0-D vector element "
+ return emitOpError("source/result bitwidth of the 0-D vector element "
"types must be equal");
} else if (sourceElementBits * sourceVectorType.getShape().back() !=
resultElementBits * resultVectorType.getShape().back()) {
- return op.emitOpError(
+ return emitOpError(
"source/result bitwidth of the minor 1-D vectors must be equal");
}
@@ -4096,26 +4082,25 @@ void TypeCastOp::build(OpBuilder &builder, OperationState &result,
memRefType.getMemorySpace()));
}
-static LogicalResult verify(TypeCastOp op) {
- MemRefType canonicalType = canonicalizeStridedLayout(op.getMemRefType());
+LogicalResult TypeCastOp::verify() {
+ MemRefType canonicalType = canonicalizeStridedLayout(getMemRefType());
if (!canonicalType.getLayout().isIdentity())
- return op.emitOpError(
- "expects operand to be a memref with identity layout");
- if (!op.getResultMemRefType().getLayout().isIdentity())
- return op.emitOpError("expects result to be a memref with identity layout");
- if (op.getResultMemRefType().getMemorySpace() !=
- op.getMemRefType().getMemorySpace())
- return op.emitOpError("expects result in same memory space");
-
- auto sourceType = op.getMemRefType();
- auto resultType = op.getResultMemRefType();
+ return emitOpError("expects operand to be a memref with identity layout");
+ if (!getResultMemRefType().getLayout().isIdentity())
+ return emitOpError("expects result to be a memref with identity layout");
+ if (getResultMemRefType().getMemorySpace() !=
+ getMemRefType().getMemorySpace())
+ return emitOpError("expects result in same memory space");
+
+ auto sourceType = getMemRefType();
+ auto resultType = getResultMemRefType();
if (getElementTypeOrSelf(getElementTypeOrSelf(sourceType)) !=
getElementTypeOrSelf(getElementTypeOrSelf(resultType)))
- return op.emitOpError(
+ return emitOpError(
"expects result and operand with same underlying scalar type: ")
<< resultType;
if (extractShape(sourceType) != extractShape(resultType))
- return op.emitOpError(
+ return emitOpError(
"expects concatenated result and operand shapes to be equal: ")
<< resultType;
return success();
@@ -4154,27 +4139,27 @@ OpFoldResult vector::TransposeOp::fold(ArrayRef<Attribute> operands) {
return vector();
}
-static LogicalResult verify(vector::TransposeOp op) {
- VectorType vectorType = op.getVectorType();
- VectorType resultType = op.getResultType();
+LogicalResult vector::TransposeOp::verify() {
+ VectorType vectorType = getVectorType();
+ VectorType resultType = getResultType();
int64_t rank = resultType.getRank();
if (vectorType.getRank() != rank)
- return op.emitOpError("vector result rank mismatch: ") << rank;
+ return emitOpError("vector result rank mismatch: ") << rank;
// Verify transposition array.
- auto transpAttr = op.transp().getValue();
+ auto transpAttr = transp().getValue();
int64_t size = transpAttr.size();
if (rank != size)
- return op.emitOpError("transposition length mismatch: ") << size;
+ return emitOpError("transposition length mismatch: ") << size;
SmallVector<bool, 8> seen(rank, false);
for (const auto &ta : llvm::enumerate(transpAttr)) {
int64_t i = ta.value().cast<IntegerAttr>().getInt();
if (i < 0 || i >= rank)
- return op.emitOpError("transposition index out of range: ") << i;
+ return emitOpError("transposition index out of range: ") << i;
if (seen[i])
- return op.emitOpError("duplicate position index: ") << i;
+ return emitOpError("duplicate position index: ") << i;
seen[i] = true;
if (resultType.getDimSize(ta.index()) != vectorType.getDimSize(i))
- return op.emitOpError("dimension size mismatch at: ") << i;
+ return emitOpError("dimension size mismatch at: ") << i;
}
return success();
}
@@ -4236,31 +4221,30 @@ void vector::TransposeOp::getTransp(SmallVectorImpl<int64_t> &results) {
// ConstantMaskOp
//===----------------------------------------------------------------------===//
-static LogicalResult verify(ConstantMaskOp &op) {
- auto resultType = op.getResult().getType().cast<VectorType>();
+LogicalResult ConstantMaskOp::verify() {
+ auto resultType = getResult().getType().cast<VectorType>();
// Check the corner case of 0-D vectors first.
if (resultType.getRank() == 0) {
- if (op.mask_dim_sizes().size() != 1)
- return op->emitError("array attr must have length 1 for 0-D vectors");
- auto dim = op.mask_dim_sizes()[0].cast<IntegerAttr>().getInt();
+ if (mask_dim_sizes().size() != 1)
+ return emitError("array attr must have length 1 for 0-D vectors");
+ auto dim = mask_dim_sizes()[0].cast<IntegerAttr>().getInt();
if (dim != 0 && dim != 1)
- return op->emitError(
- "mask dim size must be either 0 or 1 for 0-D vectors");
+ return emitError("mask dim size must be either 0 or 1 for 0-D vectors");
return success();
}
// Verify that array attr size matches the rank of the vector result.
- if (static_cast<int64_t>(op.mask_dim_sizes().size()) != resultType.getRank())
- return op.emitOpError(
+ if (static_cast<int64_t>(mask_dim_sizes().size()) != resultType.getRank())
+ return emitOpError(
"must specify array attr of size equal vector result rank");
// Verify that each array attr element is in bounds of corresponding vector
// result dimension size.
auto resultShape = resultType.getShape();
SmallVector<int64_t, 4> maskDimSizes;
- for (const auto &it : llvm::enumerate(op.mask_dim_sizes())) {
+ for (const auto &it : llvm::enumerate(mask_dim_sizes())) {
int64_t attrValue = it.value().cast<IntegerAttr>().getInt();
if (attrValue < 0 || attrValue > resultShape[it.index()])
- return op.emitOpError(
+ return emitOpError(
"array attr of size out of bounds of vector result dimension size");
maskDimSizes.push_back(attrValue);
}
@@ -4269,8 +4253,8 @@ static LogicalResult verify(ConstantMaskOp &op) {
bool anyZeros = llvm::is_contained(maskDimSizes, 0);
bool allZeros = llvm::all_of(maskDimSizes, [](int64_t s) { return s == 0; });
if (anyZeros && !allZeros)
- return op.emitOpError("expected all mask dim sizes to be zeros, "
- "as a result of conjunction with zero mask dim");
+ return emitOpError("expected all mask dim sizes to be zeros, "
+ "as a result of conjunction with zero mask dim");
return success();
}
@@ -4278,16 +4262,16 @@ static LogicalResult verify(ConstantMaskOp &op) {
// CreateMaskOp
//===----------------------------------------------------------------------===//
-static LogicalResult verify(CreateMaskOp op) {
- auto vectorType = op.getResult().getType().cast<VectorType>();
+LogicalResult CreateMaskOp::verify() {
+ auto vectorType = getResult().getType().cast<VectorType>();
// Verify that an operand was specified for each result vector each dimension.
if (vectorType.getRank() == 0) {
- if (op->getNumOperands() != 1)
- return op.emitOpError(
+ if (getNumOperands() != 1)
+ return emitOpError(
"must specify exactly one operand for 0-D create_mask");
- } else if (op.getNumOperands() !=
- op.getResult().getType().cast<VectorType>().getRank()) {
- return op.emitOpError(
+ } else if (getNumOperands() !=
+ getResult().getType().cast<VectorType>().getRank()) {
+ return emitOpError(
"must specify an operand for each result vector dimension");
}
return success();
@@ -4342,20 +4326,20 @@ void CreateMaskOp::getCanonicalizationPatterns(RewritePatternSet &results,
// ScanOp
//===----------------------------------------------------------------------===//
-static LogicalResult verify(ScanOp op) {
- VectorType srcType = op.getSourceType();
- VectorType initialType = op.getInitialValueType();
+LogicalResult ScanOp::verify() {
+ VectorType srcType = getSourceType();
+ VectorType initialType = getInitialValueType();
// Check reduction dimension < rank.
int64_t srcRank = srcType.getRank();
- int64_t reductionDim = op.reduction_dim();
+ int64_t reductionDim = reduction_dim();
if (reductionDim >= srcRank)
- return op.emitOpError("reduction dimension ")
+ return emitOpError("reduction dimension ")
<< reductionDim << " has to be less than " << srcRank;
// Check that rank(initial_value) = rank(src) - 1.
int64_t initialValueRank = initialType.getRank();
if (initialValueRank != srcRank - 1)
- return op.emitOpError("initial value rank ")
+ return emitOpError("initial value rank ")
<< initialValueRank << " has to be equal to " << srcRank - 1;
// Check shapes of initial value and src.
@@ -4370,7 +4354,7 @@ static LogicalResult verify(ScanOp op) {
[](std::tuple<int64_t, int64_t> s) {
return std::get<0>(s) != std::get<1>(s);
})) {
- return op.emitOpError("incompatible input/initial value shapes");
+ return emitOpError("incompatible input/initial value shapes");
}
return success();
More information about the Mlir-commits
mailing list