[Mlir-commits] [mlir] b98dc03 - [mlir][NFC] Update MemRef/Tensor operations to use `hasVerifier` instead of `verifier`
River Riddle
llvmlistbot at llvm.org
Wed Feb 2 13:35:46 PST 2022
Author: River Riddle
Date: 2022-02-02T13:34:30-08:00
New Revision: b98dc0351aefaa069f6a8f0cdc3800dc69918741
URL: https://github.com/llvm/llvm-project/commit/b98dc0351aefaa069f6a8f0cdc3800dc69918741
DIFF: https://github.com/llvm/llvm-project/commit/b98dc0351aefaa069f6a8f0cdc3800dc69918741.diff
LOG: [mlir][NFC] Update MemRef/Tensor operations to use `hasVerifier` instead of `verifier`
The verifier field is deprecated, and slated for removal.
Differential Revision: https://reviews.llvm.org/D118821
Added:
Modified:
mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
index 346a2eecedbfb..2af71109a786a 100644
--- a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
+++ b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
@@ -28,7 +28,6 @@ def MemRefTypeAttr
class MemRef_Op<string mnemonic, list<Trait> traits = []>
: Op<MemRef_Dialect, mnemonic, traits> {
let printer = [{ return ::print(p, *this); }];
- let verifier = [{ return ::verify(*this); }];
let parser = [{ return ::parse$cppClass(parser, result); }];
}
@@ -93,6 +92,7 @@ class AllocLikeOp<string mnemonic,
}];
let hasCanonicalizer = 1;
+ let hasVerifier = 1;
}
//===----------------------------------------------------------------------===//
@@ -115,6 +115,7 @@ def AssumeAlignmentOp : MemRef_Op<"assume_alignment"> {
let results = (outs);
let assemblyFormat = "$memref `,` $alignment attr-dict `:` type($memref)";
+ let hasVerifier = 1;
}
//===----------------------------------------------------------------------===//
@@ -162,6 +163,7 @@ def MemRef_AllocOp : AllocLikeOp<"alloc", DefaultResource, []> {
memref<8x64xf32, affine_map<(d0, d1)[s0] -> ((d0 + s0), d1)>, 1>
```
}];
+ let hasVerifier = 1;
}
//===----------------------------------------------------------------------===//
@@ -205,6 +207,7 @@ def MemRef_AllocaOp : AllocLikeOp<"alloca", AutomaticAllocationScopeResource> {
an alignment on any convenient boundary compatible with the type will be
chosen.
}];
+ let hasVerifier = 1;
}
//===----------------------------------------------------------------------===//
@@ -253,6 +256,7 @@ def MemRef_AllocaScopeOp : MemRef_Op<"alloca_scope",
let results = (outs Variadic<AnyType>:$results);
let regions = (region SizedRegion<1>:$bodyRegion);
+ let hasVerifier = 1;
}
//===----------------------------------------------------------------------===//
@@ -279,11 +283,7 @@ def MemRef_AllocaScopeReturnOp : MemRef_Op<"alloca_scope.return",
let arguments = (ins Variadic<AnyType>:$results);
let builders = [OpBuilder<(ins), [{ /*nothing to do */ }]>];
- let assemblyFormat =
- [{ attr-dict ($results^ `:` type($results))? }];
-
- // No custom verification needed.
- let verifier = ?;
+ let assemblyFormat = "attr-dict ($results^ `:` type($results))?";
}
//===----------------------------------------------------------------------===//
@@ -355,7 +355,6 @@ def MemRef_CastOp : MemRef_Op<"cast", [
let arguments = (ins AnyRankedOrUnrankedMemRef:$source);
let results = (outs AnyRankedOrUnrankedMemRef:$dest);
let assemblyFormat = "$source attr-dict `:` type($source) `to` type($dest)";
- let verifier = "return impl::verifyCastOp(*this, areCastCompatible);";
let builders = [
OpBuilder<(ins "Value":$source, "Type":$destType), [{
impl::buildCastOp($_builder, $_state, source, destType);
@@ -370,6 +369,7 @@ def MemRef_CastOp : MemRef_Op<"cast", [
}];
let hasFolder = 1;
+ let hasVerifier = 1;
}
//===----------------------------------------------------------------------===//
@@ -408,7 +408,6 @@ def CopyOp : MemRef_Op<"copy",
let hasCanonicalizer = 1;
let hasFolder = 1;
- let verifier = ?;
}
//===----------------------------------------------------------------------===//
@@ -434,7 +433,6 @@ def MemRef_DeallocOp : MemRef_Op<"dealloc", [MemRefsNormalizable]> {
let arguments = (ins Arg<AnyRankedOrUnrankedMemRef, "", [MemFree]>:$memref);
let hasFolder = 1;
- let verifier = ?;
let assemblyFormat = "$memref attr-dict `:` type($memref)";
}
@@ -488,6 +486,7 @@ def MemRef_DimOp : MemRef_Op<"dim", [NoSideEffect, MemRefsNormalizable]> {
let hasCanonicalizer = 1;
let hasFolder = 1;
+ let hasVerifier = 1;
}
//===----------------------------------------------------------------------===//
@@ -646,6 +645,7 @@ def MemRef_DmaStartOp : MemRef_Op<"dma_start"> {
}
}];
let hasFolder = 1;
+ let hasVerifier = 1;
}
//===----------------------------------------------------------------------===//
@@ -697,6 +697,7 @@ def MemRef_DmaWaitOp : MemRef_Op<"dma_wait"> {
Value getNumElements() { return numElements(); }
}];
let hasFolder = 1;
+ let hasVerifier = 1;
}
//===----------------------------------------------------------------------===//
@@ -757,6 +758,7 @@ def GenericAtomicRMWOp : MemRef_Op<"generic_atomic_rmw", [
return memref().getType().cast<MemRefType>();
}
}];
+ let hasVerifier = 1;
}
def AtomicYieldOp : MemRef_Op<"atomic_yield", [
@@ -772,6 +774,7 @@ def AtomicYieldOp : MemRef_Op<"atomic_yield", [
let arguments = (ins AnyType:$result);
let assemblyFormat = "$result attr-dict `:` type($result)";
+ let hasVerifier = 1;
}
//===----------------------------------------------------------------------===//
@@ -797,9 +800,6 @@ def MemRef_GetGlobalOp : MemRef_Op<"get_global",
let arguments = (ins FlatSymbolRefAttr:$name);
let results = (outs AnyStaticShapeMemRef:$result);
let assemblyFormat = "$name `:` type($result) attr-dict";
-
- // `GetGlobalOp` is fully verified by its traits.
- let verifier = ?;
}
//===----------------------------------------------------------------------===//
@@ -866,6 +866,7 @@ def MemRef_GlobalOp : MemRef_Op<"global", [Symbol]> {
return !isExternal() && initial_value().getValue().isa<UnitAttr>();
}
}];
+ let hasVerifier = 1;
}
//===----------------------------------------------------------------------===//
@@ -939,6 +940,7 @@ def LoadOp : MemRef_Op<"load",
}];
let hasFolder = 1;
+ let hasVerifier = 1;
let assemblyFormat = "$memref `[` $indices `]` attr-dict `:` type($memref)";
}
@@ -982,6 +984,7 @@ def MemRef_PrefetchOp : MemRef_Op<"prefetch"> {
}];
let hasFolder = 1;
+ let hasVerifier = 1;
}
//===----------------------------------------------------------------------===//
@@ -1034,6 +1037,7 @@ def MemRef_ReinterpretCastOp:
let parser = ?;
let printer = ?;
+ let hasVerifier = 1;
let builders = [
// Build a ReinterpretCastOp with mixed static and dynamic entries.
@@ -1096,7 +1100,6 @@ def MemRef_RankOp : MemRef_Op<"rank", [NoSideEffect]> {
let arguments = (ins AnyRankedOrUnrankedMemRef:$memref);
let results = (outs Index);
- let verifier = ?;
let hasFolder = 1;
let assemblyFormat = "$memref attr-dict `:` type($memref)";
}
@@ -1161,6 +1164,7 @@ def MemRef_ReshapeOp: MemRef_Op<"reshape", [
let assemblyFormat = [{
$source `(` $shape `)` attr-dict `:` functional-type(operands, results)
}];
+ let hasVerifier = 1;
}
//===----------------------------------------------------------------------===//
@@ -1226,6 +1230,7 @@ class MemRef_ReassociativeReshapeOp<string mnemonic, list<Trait> traits = []> :
let hasFolder = 1;
let hasCanonicalizer = 1;
+ let hasVerifier = 1;
let printer = [{ return ::print(p, *this); }];
let parser = [{ return ::parseReshapeLikeOp(parser, result); }];
}
@@ -1265,6 +1270,7 @@ def MemRef_ExpandShapeOp : MemRef_ReassociativeReshapeOp<"expand_shape"> {
```
}];
let extraClassDeclaration = commonExtraClassDeclaration;
+ let hasVerifier = 1;
}
def MemRef_CollapseShapeOp : MemRef_ReassociativeReshapeOp<"collapse_shape"> {
@@ -1302,6 +1308,7 @@ def MemRef_CollapseShapeOp : MemRef_ReassociativeReshapeOp<"collapse_shape"> {
```
}];
let extraClassDeclaration = commonExtraClassDeclaration;
+ let hasVerifier = 1;
}
//===----------------------------------------------------------------------===//
@@ -1369,6 +1376,7 @@ def MemRef_StoreOp : MemRef_Op<"store",
}];
let hasFolder = 1;
+ let hasVerifier = 1;
let assemblyFormat = [{
$value `,` $memref `[` $indices `]` attr-dict `:` type($memref)
@@ -1617,6 +1625,7 @@ def SubViewOp : BaseOpWithOffsetSizesAndStrides<
let hasCanonicalizer = 1;
let hasFolder = 1;
+ let hasVerifier = 1;
}
//===----------------------------------------------------------------------===//
@@ -1645,8 +1654,6 @@ def TensorStoreOp : MemRef_Op<"tensor_store",
let arguments = (ins AnyTensor:$tensor, Arg<AnyRankedOrUnrankedMemRef,
"the reference to store to", [MemWrite]>:$memref);
- // TensorStoreOp is fully verified by traits.
- let verifier = ?;
let assemblyFormat = "$tensor `,` $memref attr-dict `:` type($memref)";
}
@@ -1681,6 +1688,7 @@ def MemRef_TransposeOp : MemRef_Op<"transpose", [NoSideEffect]>,
}];
let hasFolder = 1;
+ let hasVerifier = 1;
}
//===----------------------------------------------------------------------===//
@@ -1749,6 +1757,7 @@ def MemRef_ViewOp : MemRef_Op<"view", [
}];
let hasCanonicalizer = 1;
+ let hasVerifier = 1;
}
//===----------------------------------------------------------------------===//
@@ -1796,6 +1805,7 @@ def AtomicRMWOp : MemRef_Op<"atomic_rmw", [
}
}];
let hasFolder = 1;
+ let hasVerifier = 1;
}
#endif // MEMREF_OPS
diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
index c87ee778665b5..278fedbbd3cb4 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
@@ -21,7 +21,6 @@ include "mlir/Interfaces/SideEffectInterfaces.td"
class SparseTensor_Op<string mnemonic, list<Trait> traits = []>
: Op<SparseTensor_Dialect, mnemonic, traits> {
let printer = [{ return ::print(p, *this); }];
- let verifier = [{ return ::verify(*this); }];
let parser = [{ return ::parse$cppClass(parser, result); }];
}
@@ -50,6 +49,7 @@ def SparseTensor_NewOp : SparseTensor_Op<"new", [NoSideEffect]>,
```
}];
let assemblyFormat = "$source attr-dict `:` type($source) `to` type($result)";
+ let hasVerifier = 1;
}
def SparseTensor_InitOp : SparseTensor_Op<"init", [NoSideEffect]>,
@@ -72,6 +72,7 @@ def SparseTensor_InitOp : SparseTensor_Op<"init", [NoSideEffect]>,
```
}];
let assemblyFormat = "`[` $sizes `]` attr-dict `:` type($result)";
+ let hasVerifier = 1;
}
def SparseTensor_ConvertOp : SparseTensor_Op<"convert",
@@ -113,6 +114,7 @@ def SparseTensor_ConvertOp : SparseTensor_Op<"convert",
}];
let assemblyFormat = "$source attr-dict `:` type($source) `to` type($dest)";
let hasFolder = 1;
+ let hasVerifier = 1;
}
def SparseTensor_ToPointersOp : SparseTensor_Op<"pointers", [NoSideEffect]>,
@@ -137,6 +139,7 @@ def SparseTensor_ToPointersOp : SparseTensor_Op<"pointers", [NoSideEffect]>,
}];
let assemblyFormat = "$tensor `,` $dim attr-dict `:` type($tensor)"
" `to` type($result)";
+ let hasVerifier = 1;
}
def SparseTensor_ToIndicesOp : SparseTensor_Op<"indices", [NoSideEffect]>,
@@ -161,6 +164,7 @@ def SparseTensor_ToIndicesOp : SparseTensor_Op<"indices", [NoSideEffect]>,
}];
let assemblyFormat = "$tensor `,` $dim attr-dict `:` type($tensor)"
" `to` type($result)";
+ let hasVerifier = 1;
}
def SparseTensor_ToValuesOp : SparseTensor_Op<"values", [NoSideEffect]>,
@@ -183,6 +187,7 @@ def SparseTensor_ToValuesOp : SparseTensor_Op<"values", [NoSideEffect]>,
```
}];
let assemblyFormat = "$tensor attr-dict `:` type($tensor) `to` type($result)";
+ let hasVerifier = 1;
}
//===----------------------------------------------------------------------===//
@@ -217,6 +222,7 @@ def SparseTensor_LexInsertOp : SparseTensor_Op<"lex_insert", []>,
}];
let assemblyFormat = "$tensor `,` $indices `,` $value attr-dict `:`"
" type($tensor) `,` type($indices) `,` type($value)";
+ let hasVerifier = 1;
}
def SparseTensor_ExpandOp : SparseTensor_Op<"expand", []>,
@@ -258,6 +264,7 @@ def SparseTensor_ExpandOp : SparseTensor_Op<"expand", []>,
}];
let assemblyFormat = "$tensor attr-dict `:` type($tensor) `to` type($values)"
" `,` type($filled) `,` type($added) `,` type($count)";
+ let hasVerifier = 1;
}
def SparseTensor_CompressOp : SparseTensor_Op<"compress", []>,
@@ -292,6 +299,7 @@ def SparseTensor_CompressOp : SparseTensor_Op<"compress", []>,
" $added `,` $count attr-dict `:` type($tensor) `,`"
" type($indices) `,` type($values) `,` type($filled) `,`"
" type($added) `,` type($count)";
+ let hasVerifier = 1;
}
def SparseTensor_LoadOp : SparseTensor_Op<"load", [SameOperandsAndResultType]>,
@@ -324,6 +332,7 @@ def SparseTensor_LoadOp : SparseTensor_Op<"load", [SameOperandsAndResultType]>,
```
}];
let assemblyFormat = "$tensor (`hasInserts` $hasInserts^)? attr-dict `:` type($tensor)";
+ let hasVerifier = 1;
}
def SparseTensor_ReleaseOp : SparseTensor_Op<"release", []>,
@@ -349,6 +358,7 @@ def SparseTensor_ReleaseOp : SparseTensor_Op<"release", []>,
```
}];
let assemblyFormat = "$tensor attr-dict `:` type($tensor)";
+ let hasVerifier = 1;
}
def SparseTensor_OutOp : SparseTensor_Op<"out", []>,
@@ -369,6 +379,7 @@ def SparseTensor_OutOp : SparseTensor_Op<"out", []>,
```
}];
let assemblyFormat = "$tensor `,` $dest attr-dict `:` type($tensor) `,` type($dest)";
+ let hasVerifier = 1;
}
#endif // SPARSETENSOR_OPS
diff --git a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
index 54af06b425052..a2f15b380b2e8 100644
--- a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
+++ b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
@@ -20,7 +20,6 @@ include "mlir/Interfaces/ViewLikeInterface.td"
class Tensor_Op<string mnemonic, list<Trait> traits = []>
: Op<Tensor_Dialect, mnemonic, traits> {
let printer = [{ return ::print(p, *this); }];
- let verifier = [{ return ::verify(*this); }];
let parser = [{ return ::parse$cppClass(parser, result); }];
}
@@ -59,7 +58,6 @@ def Tensor_CastOp : Tensor_Op<"cast", [
let assemblyFormat = "$source attr-dict `:` type($source) `to` type($dest)";
let hasCanonicalizer = 1;
- let verifier = ?;
}
//===----------------------------------------------------------------------===//
@@ -111,6 +109,7 @@ def Tensor_DimOp : Tensor_Op<"dim", [NoSideEffect]> {
let hasCanonicalizer = 1;
let hasFolder = 1;
+ let hasVerifier = 1;
}
//===----------------------------------------------------------------------===//
@@ -151,6 +150,7 @@ def Tensor_ExtractOp : Tensor_Op<"extract",
}]>];
let hasFolder = 1;
+ let hasVerifier = 1;
}
@@ -303,6 +303,7 @@ def Tensor_ExtractSliceOp : BaseOpWithOffsetSizesAndStrides<
let hasCanonicalizer = 1;
let hasFolder = 1;
+ let hasVerifier = 1;
}
//===----------------------------------------------------------------------===//
@@ -339,9 +340,6 @@ def Tensor_FromElementsOp : Tensor_Op<"from_elements", [
let assemblyFormat = "$elements attr-dict `:` type($result)";
- // This op is fully verified by its traits.
- let verifier = ?;
-
let skipDefaultBuilders = 1;
let builders = [
OpBuilder<(ins "Type":$resultType, "ValueRange":$elements)>,
@@ -394,6 +392,7 @@ def Tensor_GenerateOp : Tensor_Op<"generate",
];
let hasCanonicalizer = 1;
+ let hasVerifier = 1;
}
//===----------------------------------------------------------------------===//
@@ -445,6 +444,7 @@ def Tensor_InsertOp : Tensor_Op<"insert",
}]>];
let hasFolder = 1;
+ let hasVerifier = 1;
}
//===----------------------------------------------------------------------===//
@@ -564,6 +564,7 @@ def Tensor_InsertSliceOp : BaseOpWithOffsetSizesAndStrides<
let hasCanonicalizer = 1;
let hasFolder = 1;
+ let hasVerifier = 1;
}
//===----------------------------------------------------------------------===//
@@ -586,7 +587,6 @@ def Tensor_RankOp : Tensor_Op<"rank", [NoSideEffect]> {
let arguments = (ins AnyTensor:$tensor);
let results = (outs Index);
- let verifier = ?;
let hasFolder = 1;
let assemblyFormat = "$tensor attr-dict `:` type($tensor)";
}
@@ -650,6 +650,7 @@ def Tensor_ReshapeOp: Tensor_Op<"reshape", [NoSideEffect]> {
let assemblyFormat = [{
$source `(` $shape `)` attr-dict `:` functional-type(operands, results)
}];
+ let hasVerifier = 1;
}
//===----------------------------------------------------------------------===//
@@ -718,6 +719,7 @@ class Tensor_ReassociativeReshapeOp<string mnemonic, list<Trait> traits = []> :
let hasFolder = 1;
let hasCanonicalizer = 1;
+ let hasVerifier = 1;
let printer = [{ return ::print(p, *this); }];
let parser = [{ return ::parseReshapeLikeOp(parser, result); }];
}
@@ -748,6 +750,7 @@ def Tensor_ExpandShapeOp : Tensor_ReassociativeReshapeOp<"expand_shape"> {
```
}];
let extraClassDeclaration = commonExtraClassDeclaration;
+ let hasVerifier = 1;
}
def Tensor_CollapseShapeOp : Tensor_ReassociativeReshapeOp<"collapse_shape"> {
@@ -776,6 +779,7 @@ def Tensor_CollapseShapeOp : Tensor_ReassociativeReshapeOp<"collapse_shape"> {
```
}];
let extraClassDeclaration = commonExtraClassDeclaration;
+ let hasVerifier = 1;
}
//===----------------------------------------------------------------------===//
@@ -961,6 +965,7 @@ def Tensor_PadOp : Tensor_Op<"pad", [AttrSizedOperandSegments, NoSideEffect,
let hasCanonicalizer = 1;
let hasFolder = 1;
+ let hasVerifier = 1;
}
@@ -984,7 +989,6 @@ def Tensor_YieldOp : Tensor_Op<"yield",
// Dummy builder to appease code in templated ensureTerminator that
// GenerateOp's auto-generated parser calls.
let builders = [OpBuilder<(ins), [{ /* nothing to do */ }]>];
- let verifier = ?;
}
#endif // TENSOR_OPS
diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
index d84d7089f173b..a282bd2b8ae18 100644
--- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
@@ -67,6 +67,10 @@ Type mlir::memref::getTensorTypeFromMemRefType(Type type) {
return NoneType::get(type.getContext());
}
+LogicalResult memref::CastOp::verify() {
+ return impl::verifyCastOp(*this, areCastCompatible);
+}
+
//===----------------------------------------------------------------------===//
// AllocOp / AllocaOp
//===----------------------------------------------------------------------===//
@@ -95,15 +99,15 @@ static LogicalResult verifyAllocLikeOp(AllocLikeOp op) {
return success();
}
-static LogicalResult verify(AllocOp op) { return verifyAllocLikeOp(op); }
+LogicalResult AllocOp::verify() { return verifyAllocLikeOp(*this); }
-static LogicalResult verify(AllocaOp op) {
+LogicalResult AllocaOp::verify() {
// An alloca op needs to have an ancestor with an allocation scope trait.
- if (!op->getParentWithTrait<OpTrait::AutomaticAllocationScope>())
- return op.emitOpError(
+ if (!(*this)->getParentWithTrait<OpTrait::AutomaticAllocationScope>())
+ return emitOpError(
"requires an ancestor op with AutomaticAllocationScope trait");
- return verifyAllocLikeOp(op);
+ return verifyAllocLikeOp(*this);
}
namespace {
@@ -246,11 +250,8 @@ static ParseResult parseAllocaScopeOp(OpAsmParser &parser,
return success();
}
-static LogicalResult verify(AllocaScopeOp op) {
- if (failed(RegionBranchOpInterface::verifyTypes(op)))
- return failure();
-
- return success();
+LogicalResult AllocaScopeOp::verify() {
+ return RegionBranchOpInterface::verifyTypes(*this);
}
void AllocaScopeOp::getSuccessorRegions(
@@ -268,10 +269,9 @@ void AllocaScopeOp::getSuccessorRegions(
// AssumeAlignmentOp
//===----------------------------------------------------------------------===//
-static LogicalResult verify(AssumeAlignmentOp op) {
- unsigned alignment = op.alignment();
- if (!llvm::isPowerOf2_32(alignment))
- return op.emitOpError("alignment must be power of 2");
+LogicalResult AssumeAlignmentOp::verify() {
+ if (!llvm::isPowerOf2_32(alignment()))
+ return emitOpError("alignment must be power of 2");
return success();
}
@@ -556,17 +556,17 @@ Optional<int64_t> DimOp::getConstantIndex() {
return {};
}
-static LogicalResult verify(DimOp op) {
+LogicalResult DimOp::verify() {
// Assume unknown index to be in range.
- Optional<int64_t> index = op.getConstantIndex();
+ Optional<int64_t> index = getConstantIndex();
if (!index.hasValue())
return success();
// Check that constant index is not knowingly out of range.
- auto type = op.source().getType();
+ auto type = source().getType();
if (auto memrefType = type.dyn_cast<MemRefType>()) {
if (index.getValue() >= memrefType.getRank())
- return op.emitOpError("index is out of range");
+ return emitOpError("index is out of range");
} else if (type.isa<UnrankedMemRefType>()) {
// Assume index to be in range.
} else {
@@ -866,67 +866,66 @@ static ParseResult parseDmaStartOp(OpAsmParser &parser,
return success();
}
-static LogicalResult verify(DmaStartOp op) {
- unsigned numOperands = op.getNumOperands();
+LogicalResult DmaStartOp::verify() {
+ unsigned numOperands = getNumOperands();
// Mandatory non-variadic operands are: src memref, dst memref, tag memref and
// the number of elements.
if (numOperands < 4)
- return op.emitOpError("expected at least 4 operands");
+ return emitOpError("expected at least 4 operands");
// Check types of operands. The order of these calls is important: the later
// calls rely on some type properties to compute the operand position.
// 1. Source memref.
- if (!op.getSrcMemRef().getType().isa<MemRefType>())
- return op.emitOpError("expected source to be of memref type");
- if (numOperands < op.getSrcMemRefRank() + 4)
- return op.emitOpError()
- << "expected at least " << op.getSrcMemRefRank() + 4 << " operands";
- if (!op.getSrcIndices().empty() &&
- !llvm::all_of(op.getSrcIndices().getTypes(),
+ if (!getSrcMemRef().getType().isa<MemRefType>())
+ return emitOpError("expected source to be of memref type");
+ if (numOperands < getSrcMemRefRank() + 4)
+ return emitOpError() << "expected at least " << getSrcMemRefRank() + 4
+ << " operands";
+ if (!getSrcIndices().empty() &&
+ !llvm::all_of(getSrcIndices().getTypes(),
[](Type t) { return t.isIndex(); }))
- return op.emitOpError("expected source indices to be of index type");
+ return emitOpError("expected source indices to be of index type");
// 2. Destination memref.
- if (!op.getDstMemRef().getType().isa<MemRefType>())
- return op.emitOpError("expected destination to be of memref type");
- unsigned numExpectedOperands =
- op.getSrcMemRefRank() + op.getDstMemRefRank() + 4;
+ if (!getDstMemRef().getType().isa<MemRefType>())
+ return emitOpError("expected destination to be of memref type");
+ unsigned numExpectedOperands = getSrcMemRefRank() + getDstMemRefRank() + 4;
if (numOperands < numExpectedOperands)
- return op.emitOpError()
- << "expected at least " << numExpectedOperands << " operands";
- if (!op.getDstIndices().empty() &&
- !llvm::all_of(op.getDstIndices().getTypes(),
+ return emitOpError() << "expected at least " << numExpectedOperands
+ << " operands";
+ if (!getDstIndices().empty() &&
+ !llvm::all_of(getDstIndices().getTypes(),
[](Type t) { return t.isIndex(); }))
- return op.emitOpError("expected destination indices to be of index type");
+ return emitOpError("expected destination indices to be of index type");
// 3. Number of elements.
- if (!op.getNumElements().getType().isIndex())
- return op.emitOpError("expected num elements to be of index type");
+ if (!getNumElements().getType().isIndex())
+ return emitOpError("expected num elements to be of index type");
// 4. Tag memref.
- if (!op.getTagMemRef().getType().isa<MemRefType>())
- return op.emitOpError("expected tag to be of memref type");
- numExpectedOperands += op.getTagMemRefRank();
+ if (!getTagMemRef().getType().isa<MemRefType>())
+ return emitOpError("expected tag to be of memref type");
+ numExpectedOperands += getTagMemRefRank();
if (numOperands < numExpectedOperands)
- return op.emitOpError()
- << "expected at least " << numExpectedOperands << " operands";
- if (!op.getTagIndices().empty() &&
- !llvm::all_of(op.getTagIndices().getTypes(),
+ return emitOpError() << "expected at least " << numExpectedOperands
+ << " operands";
+ if (!getTagIndices().empty() &&
+ !llvm::all_of(getTagIndices().getTypes(),
[](Type t) { return t.isIndex(); }))
- return op.emitOpError("expected tag indices to be of index type");
+ return emitOpError("expected tag indices to be of index type");
// Optional stride-related operands must be either both present or both
// absent.
if (numOperands != numExpectedOperands &&
numOperands != numExpectedOperands + 2)
- return op.emitOpError("incorrect number of operands");
+ return emitOpError("incorrect number of operands");
// 5. Strides.
- if (op.isStrided()) {
- if (!op.getStride().getType().isIndex() ||
- !op.getNumElementsPerStride().getType().isIndex())
- return op.emitOpError(
+ if (isStrided()) {
+ if (!getStride().getType().isIndex() ||
+ !getNumElementsPerStride().getType().isIndex())
+ return emitOpError(
"expected stride and num elements per stride to be of type index");
}
@@ -949,14 +948,14 @@ LogicalResult DmaWaitOp::fold(ArrayRef<Attribute> cstOperands,
return foldMemRefCast(*this);
}
-static LogicalResult verify(DmaWaitOp op) {
+LogicalResult DmaWaitOp::verify() {
// Check that the number of tag indices matches the tagMemRef rank.
- unsigned numTagIndices = op.tagIndices().size();
- unsigned tagMemRefRank = op.getTagMemRefRank();
+ unsigned numTagIndices = tagIndices().size();
+ unsigned tagMemRefRank = getTagMemRefRank();
if (numTagIndices != tagMemRefRank)
- return op.emitOpError() << "expected tagIndices to have the same number of "
- "elements as the tagMemRef rank, expected "
- << tagMemRefRank << ", but got " << numTagIndices;
+ return emitOpError() << "expected tagIndices to have the same number of "
+ "elements as the tagMemRef rank, expected "
+ << tagMemRefRank << ", but got " << numTagIndices;
return success();
}
@@ -979,14 +978,13 @@ void GenericAtomicRMWOp::build(OpBuilder &builder, OperationState &result,
}
}
-static LogicalResult verify(GenericAtomicRMWOp op) {
- auto &body = op.getRegion();
+LogicalResult GenericAtomicRMWOp::verify() {
+ auto &body = getRegion();
if (body.getNumArguments() != 1)
- return op.emitOpError("expected single number of entry block arguments");
+ return emitOpError("expected single number of entry block arguments");
- if (op.getResult().getType() != body.getArgument(0).getType())
- return op.emitOpError(
- "expected block argument of the same type result type");
+ if (getResult().getType() != body.getArgument(0).getType())
+ return emitOpError("expected block argument of the same type result type");
bool hasSideEffects =
body.walk([&](Operation *nestedOp) {
@@ -1034,12 +1032,12 @@ static void print(OpAsmPrinter &p, GenericAtomicRMWOp op) {
// AtomicYieldOp
//===----------------------------------------------------------------------===//
-static LogicalResult verify(AtomicYieldOp op) {
- Type parentType = op->getParentOp()->getResultTypes().front();
- Type resultType = op.result().getType();
+LogicalResult AtomicYieldOp::verify() {
+ Type parentType = (*this)->getParentOp()->getResultTypes().front();
+ Type resultType = result().getType();
if (parentType != resultType)
- return op.emitOpError() << "types mismatch between yield op: " << resultType
- << " and its parent: " << parentType;
+ return emitOpError() << "types mismatch between yield op: " << resultType
+ << " and its parent: " << parentType;
return success();
}
@@ -1090,19 +1088,19 @@ parseGlobalMemrefOpTypeAndInitialValue(OpAsmParser &parser, TypeAttr &typeAttr,
return success();
}
-static LogicalResult verify(GlobalOp op) {
- auto memrefType = op.type().dyn_cast<MemRefType>();
+LogicalResult GlobalOp::verify() {
+ auto memrefType = type().dyn_cast<MemRefType>();
if (!memrefType || !memrefType.hasStaticShape())
- return op.emitOpError("type should be static shaped memref, but got ")
- << op.type();
+ return emitOpError("type should be static shaped memref, but got ")
+ << type();
// Verify that the initial value, if present, is either a unit attribute or
// an elements attribute.
- if (op.initial_value().hasValue()) {
- Attribute initValue = op.initial_value().getValue();
+ if (initial_value().hasValue()) {
+ Attribute initValue = initial_value().getValue();
if (!initValue.isa<UnitAttr>() && !initValue.isa<ElementsAttr>())
- return op.emitOpError("initial value should be a unit or elements "
- "attribute, but got ")
+ return emitOpError("initial value should be a unit or elements "
+ "attribute, but got ")
<< initValue;
// Check that the type of the initial value is compatible with the type of
@@ -1111,17 +1109,17 @@ static LogicalResult verify(GlobalOp op) {
Type initType = initValue.getType();
Type tensorType = getTensorTypeFromMemRefType(memrefType);
if (initType != tensorType)
- return op.emitOpError("initial value expected to be of type ")
+ return emitOpError("initial value expected to be of type ")
<< tensorType << ", but was of type " << initType;
}
}
- if (Optional<uint64_t> alignAttr = op.alignment()) {
+ if (Optional<uint64_t> alignAttr = alignment()) {
uint64_t alignment = alignAttr.getValue();
if (!llvm::isPowerOf2_64(alignment))
- return op->emitError() << "alignment attribute value " << alignment
- << " is not a power of 2";
+ return emitError() << "alignment attribute value " << alignment
+ << " is not a power of 2";
}
// TODO: verify visibility for declarations.
@@ -1154,9 +1152,9 @@ GetGlobalOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
// LoadOp
//===----------------------------------------------------------------------===//
-static LogicalResult verify(LoadOp op) {
- if (op.getNumOperands() != 1 + op.getMemRefType().getRank())
- return op.emitOpError("incorrect number of indices for load");
+LogicalResult LoadOp::verify() {
+ if (getNumOperands() != 1 + getMemRefType().getRank())
+ return emitOpError("incorrect number of indices for load");
return success();
}
@@ -1224,9 +1222,9 @@ static ParseResult parsePrefetchOp(OpAsmParser &parser,
return success();
}
-static LogicalResult verify(PrefetchOp op) {
- if (op.getNumOperands() != 1 + op.getMemRefType().getRank())
- return op.emitOpError("too few indices");
+LogicalResult PrefetchOp::verify() {
+ if (getNumOperands() != 1 + getMemRefType().getRank())
+ return emitOpError("too few indices");
return success();
}
@@ -1306,26 +1304,25 @@ void ReinterpretCastOp::build(OpBuilder &b, OperationState &result,
// TODO: ponder whether we want to allow missing trailing sizes/strides that are
// completed automatically, like we have for subview and extract_slice.
-static LogicalResult verify(ReinterpretCastOp op) {
+LogicalResult ReinterpretCastOp::verify() {
// The source and result memrefs should be in the same memory space.
- auto srcType = op.source().getType().cast<BaseMemRefType>();
- auto resultType = op.getType().cast<MemRefType>();
+ auto srcType = source().getType().cast<BaseMemRefType>();
+ auto resultType = getType().cast<MemRefType>();
if (srcType.getMemorySpace() != resultType.getMemorySpace())
- return op.emitError("different memory spaces specified for source type ")
+ return emitError("different memory spaces specified for source type ")
<< srcType << " and result memref type " << resultType;
if (srcType.getElementType() != resultType.getElementType())
- return op.emitError("different element types specified for source type ")
+ return emitError("different element types specified for source type ")
<< srcType << " and result memref type " << resultType;
// Match sizes in result memref type and in static_sizes attribute.
- for (auto &en :
- llvm::enumerate(llvm::zip(resultType.getShape(),
- extractFromI64ArrayAttr(op.static_sizes())))) {
+ for (auto &en : llvm::enumerate(llvm::zip(
+ resultType.getShape(), extractFromI64ArrayAttr(static_sizes())))) {
int64_t resultSize = std::get<0>(en.value());
int64_t expectedSize = std::get<1>(en.value());
if (!ShapedType::isDynamic(resultSize) &&
!ShapedType::isDynamic(expectedSize) && resultSize != expectedSize)
- return op.emitError("expected result type with size = ")
+ return emitError("expected result type with size = ")
<< expectedSize << " instead of " << resultSize
<< " in dim = " << en.index();
}
@@ -1336,27 +1333,26 @@ static LogicalResult verify(ReinterpretCastOp op) {
int64_t resultOffset;
SmallVector<int64_t, 4> resultStrides;
if (failed(getStridesAndOffset(resultType, resultStrides, resultOffset)))
- return op.emitError(
- "expected result type to have strided layout but found ")
+ return emitError("expected result type to have strided layout but found ")
<< resultType;
// Match offset in result memref type and in static_offsets attribute.
- int64_t expectedOffset = extractFromI64ArrayAttr(op.static_offsets()).front();
+ int64_t expectedOffset = extractFromI64ArrayAttr(static_offsets()).front();
if (!ShapedType::isDynamicStrideOrOffset(resultOffset) &&
!ShapedType::isDynamicStrideOrOffset(expectedOffset) &&
resultOffset != expectedOffset)
- return op.emitError("expected result type with offset = ")
+ return emitError("expected result type with offset = ")
<< resultOffset << " instead of " << expectedOffset;
// Match strides in result memref type and in static_strides attribute.
for (auto &en : llvm::enumerate(llvm::zip(
- resultStrides, extractFromI64ArrayAttr(op.static_strides())))) {
+ resultStrides, extractFromI64ArrayAttr(static_strides())))) {
int64_t resultStride = std::get<0>(en.value());
int64_t expectedStride = std::get<1>(en.value());
if (!ShapedType::isDynamicStrideOrOffset(resultStride) &&
!ShapedType::isDynamicStrideOrOffset(expectedStride) &&
resultStride != expectedStride)
- return op.emitError("expected result type with stride = ")
+ return emitError("expected result type with stride = ")
<< expectedStride << " instead of " << resultStride
<< " in dim = " << en.index();
}
@@ -1532,8 +1528,8 @@ static LogicalResult verifyReshapeOp(ReshapeOp op, MemRefType expandedType,
return success();
}
-static LogicalResult verify(ExpandShapeOp op) {
- return verifyReshapeOp(op, op.getResultType(), op.getSrcType());
+LogicalResult ExpandShapeOp::verify() {
+ return verifyReshapeOp(*this, getResultType(), getSrcType());
}
void ExpandShapeOp::getCanonicalizationPatterns(RewritePatternSet &results,
@@ -1542,8 +1538,8 @@ void ExpandShapeOp::getCanonicalizationPatterns(RewritePatternSet &results,
CollapseMixedReshapeOps<ExpandShapeOp, CollapseShapeOp>>(context);
}
-static LogicalResult verify(CollapseShapeOp op) {
- return verifyReshapeOp(op, op.getSrcType(), op.getResultType());
+LogicalResult CollapseShapeOp::verify() {
+ return verifyReshapeOp(*this, getSrcType(), getResultType());
}
struct CollapseShapeOpMemRefCastFolder
@@ -1593,32 +1589,30 @@ OpFoldResult CollapseShapeOp::fold(ArrayRef<Attribute> operands) {
// ReshapeOp
//===----------------------------------------------------------------------===//
-static LogicalResult verify(ReshapeOp op) {
- Type operandType = op.source().getType();
- Type resultType = op.result().getType();
+LogicalResult ReshapeOp::verify() {
+ Type operandType = source().getType();
+ Type resultType = result().getType();
Type operandElementType = operandType.cast<ShapedType>().getElementType();
Type resultElementType = resultType.cast<ShapedType>().getElementType();
if (operandElementType != resultElementType)
- return op.emitOpError("element types of source and destination memref "
- "types should be the same");
+ return emitOpError("element types of source and destination memref "
+ "types should be the same");
if (auto operandMemRefType = operandType.dyn_cast<MemRefType>())
if (!operandMemRefType.getLayout().isIdentity())
- return op.emitOpError(
- "source memref type should have identity affine map");
+ return emitOpError("source memref type should have identity affine map");
- int64_t shapeSize = op.shape().getType().cast<MemRefType>().getDimSize(0);
+ int64_t shapeSize = shape().getType().cast<MemRefType>().getDimSize(0);
auto resultMemRefType = resultType.dyn_cast<MemRefType>();
if (resultMemRefType) {
if (!resultMemRefType.getLayout().isIdentity())
- return op.emitOpError(
- "result memref type should have identity affine map");
+ return emitOpError("result memref type should have identity affine map");
if (shapeSize == ShapedType::kDynamicSize)
- return op.emitOpError("cannot use shape operand with dynamic length to "
- "reshape to statically-ranked memref type");
+ return emitOpError("cannot use shape operand with dynamic length to "
+ "reshape to statically-ranked memref type");
if (shapeSize != resultMemRefType.getRank())
- return op.emitOpError(
+ return emitOpError(
"length of shape operand differs from the result's memref rank");
}
return success();
@@ -1628,9 +1622,9 @@ static LogicalResult verify(ReshapeOp op) {
// StoreOp
//===----------------------------------------------------------------------===//
-static LogicalResult verify(StoreOp op) {
- if (op.getNumOperands() != 2 + op.getMemRefType().getRank())
- return op.emitOpError("store index operand count not equal to memref rank");
+LogicalResult StoreOp::verify() {
+ if (getNumOperands() != 2 + getMemRefType().getRank())
+ return emitOpError("store index operand count not equal to memref rank");
return success();
}
@@ -1951,29 +1945,29 @@ static LogicalResult produceSubViewErrorMsg(SliceVerificationResult result,
}
/// Verifier for SubViewOp.
-static LogicalResult verify(SubViewOp op) {
- MemRefType baseType = op.getSourceType();
- MemRefType subViewType = op.getType();
+LogicalResult SubViewOp::verify() {
+ MemRefType baseType = getSourceType();
+ MemRefType subViewType = getType();
// The base memref and the view memref should be in the same memory space.
if (baseType.getMemorySpace() != subViewType.getMemorySpace())
- return op.emitError("different memory spaces specified for base memref "
- "type ")
+ return emitError("different memory spaces specified for base memref "
+ "type ")
<< baseType << " and subview memref type " << subViewType;
// Verify that the base memref type has a strided layout map.
if (!isStrided(baseType))
- return op.emitError("base type ") << baseType << " is not strided";
+ return emitError("base type ") << baseType << " is not strided";
// Verify result type against inferred type.
auto expectedType = SubViewOp::inferResultType(
- baseType, extractFromI64ArrayAttr(op.static_offsets()),
- extractFromI64ArrayAttr(op.static_sizes()),
- extractFromI64ArrayAttr(op.static_strides()));
+ baseType, extractFromI64ArrayAttr(static_offsets()),
+ extractFromI64ArrayAttr(static_sizes()),
+ extractFromI64ArrayAttr(static_strides()));
auto result = isRankReducedMemRefType(expectedType.cast<MemRefType>(),
- subViewType, op.getMixedSizes());
- return produceSubViewErrorMsg(result, op, expectedType);
+ subViewType, getMixedSizes());
+ return produceSubViewErrorMsg(result, *this, expectedType);
}
raw_ostream &mlir::operator<<(raw_ostream &os, const Range &range) {
@@ -2278,18 +2272,17 @@ static ParseResult parseTransposeOp(OpAsmParser &parser,
return success();
}
-static LogicalResult verify(TransposeOp op) {
- if (!op.permutation().isPermutation())
- return op.emitOpError("expected a permutation map");
- if (op.permutation().getNumDims() != op.getShapedType().getRank())
- return op.emitOpError(
- "expected a permutation map of same rank as the input");
+LogicalResult TransposeOp::verify() {
+ if (!permutation().isPermutation())
+ return emitOpError("expected a permutation map");
+ if (permutation().getNumDims() != getShapedType().getRank())
+ return emitOpError("expected a permutation map of same rank as the input");
- auto srcType = op.in().getType().cast<MemRefType>();
- auto dstType = op.getType().cast<MemRefType>();
- auto transposedType = inferTransposeResultType(srcType, op.permutation());
+ auto srcType = in().getType().cast<MemRefType>();
+ auto dstType = getType().cast<MemRefType>();
+ auto transposedType = inferTransposeResultType(srcType, permutation());
if (dstType != transposedType)
- return op.emitOpError("output type ")
+ return emitOpError("output type ")
<< dstType << " does not match transposed input type " << srcType
<< ", " << transposedType;
return success();
@@ -2338,29 +2331,28 @@ static void print(OpAsmPrinter &p, ViewOp op) {
p << " : " << op.getOperand(0).getType() << " to " << op.getType();
}
-static LogicalResult verify(ViewOp op) {
- auto baseType = op.getOperand(0).getType().cast<MemRefType>();
- auto viewType = op.getType();
+LogicalResult ViewOp::verify() {
+ auto baseType = getOperand(0).getType().cast<MemRefType>();
+ auto viewType = getType();
// The base memref should have identity layout map (or none).
if (!baseType.getLayout().isIdentity())
- return op.emitError("unsupported map for base memref type ") << baseType;
+ return emitError("unsupported map for base memref type ") << baseType;
// The result memref should have identity layout map (or none).
if (!viewType.getLayout().isIdentity())
- return op.emitError("unsupported map for result memref type ") << viewType;
+ return emitError("unsupported map for result memref type ") << viewType;
// The base memref and the view memref should be in the same memory space.
if (baseType.getMemorySpace() != viewType.getMemorySpace())
- return op.emitError("different memory spaces specified for base memref "
- "type ")
+ return emitError("different memory spaces specified for base memref "
+ "type ")
<< baseType << " and view memref type " << viewType;
// Verify that we have the correct number of sizes for the result type.
unsigned numDynamicDims = viewType.getNumDynamicDims();
- if (op.sizes().size() != numDynamicDims)
- return op.emitError("incorrect number of size operands for type ")
- << viewType;
+ if (sizes().size() != numDynamicDims)
+ return emitError("incorrect number of size operands for type ") << viewType;
return success();
}
@@ -2467,19 +2459,19 @@ void ViewOp::getCanonicalizationPatterns(RewritePatternSet &results,
// AtomicRMWOp
//===----------------------------------------------------------------------===//
-static LogicalResult verify(AtomicRMWOp op) {
- if (op.getMemRefType().getRank() != op.getNumOperands() - 2)
- return op.emitOpError(
+LogicalResult AtomicRMWOp::verify() {
+ if (getMemRefType().getRank() != getNumOperands() - 2)
+ return emitOpError(
"expects the number of subscripts to be equal to memref rank");
- switch (op.kind()) {
+ switch (kind()) {
case arith::AtomicRMWKind::addf:
case arith::AtomicRMWKind::maxf:
case arith::AtomicRMWKind::minf:
case arith::AtomicRMWKind::mulf:
- if (!op.value().getType().isa<FloatType>())
- return op.emitOpError()
- << "with kind '" << arith::stringifyAtomicRMWKind(op.kind())
- << "' expects a floating-point type";
+ if (!value().getType().isa<FloatType>())
+ return emitOpError() << "with kind '"
+ << arith::stringifyAtomicRMWKind(kind())
+ << "' expects a floating-point type";
break;
case arith::AtomicRMWKind::addi:
case arith::AtomicRMWKind::maxs:
@@ -2489,10 +2481,10 @@ static LogicalResult verify(AtomicRMWOp op) {
case arith::AtomicRMWKind::muli:
case arith::AtomicRMWKind::ori:
case arith::AtomicRMWKind::andi:
- if (!op.value().getType().isa<IntegerType>())
- return op.emitOpError()
- << "with kind '" << arith::stringifyAtomicRMWKind(op.kind())
- << "' expects an integer type";
+ if (!value().getType().isa<IntegerType>())
+ return emitOpError() << "with kind '"
+ << arith::stringifyAtomicRMWKind(kind())
+ << "' expects an integer type";
break;
default:
break;
diff --git a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
index 5b0ee4656c491..ecbc989a2c141 100644
--- a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
+++ b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
@@ -209,53 +209,51 @@ static LogicalResult isMatchingWidth(Value result, unsigned width) {
return failure();
}
-static LogicalResult verify(NewOp op) {
- if (!getSparseTensorEncoding(op.result().getType()))
- return op.emitError("expected a sparse tensor result");
+LogicalResult NewOp::verify() {
+ if (!getSparseTensorEncoding(result().getType()))
+ return emitError("expected a sparse tensor result");
return success();
}
-static LogicalResult verify(InitOp op) {
- if (!getSparseTensorEncoding(op.result().getType()))
- return op.emitError("expected a sparse tensor result");
- RankedTensorType ttp = op.getType().cast<RankedTensorType>();
+LogicalResult InitOp::verify() {
+ if (!getSparseTensorEncoding(result().getType()))
+ return emitError("expected a sparse tensor result");
+ RankedTensorType ttp = getType().cast<RankedTensorType>();
unsigned rank = ttp.getRank();
- if (rank != op.sizes().size())
- return op.emitError("unexpected mismatch between tensor rank and sizes: ")
- << rank << " vs. " << op.sizes().size();
+ if (rank != sizes().size())
+ return emitError("unexpected mismatch between tensor rank and sizes: ")
+ << rank << " vs. " << sizes().size();
auto shape = ttp.getShape();
for (unsigned i = 0; i < rank; i++) {
if (shape[i] == ShapedType::kDynamicSize)
continue;
IntegerAttr constantAttr;
- if (!matchPattern(op.sizes()[i], m_Constant(&constantAttr)) ||
+ if (!matchPattern(sizes()[i], m_Constant(&constantAttr)) ||
constantAttr.getInt() != shape[i]) {
- return op.emitError("unexpected mismatch with static dimension size ")
+ return emitError("unexpected mismatch with static dimension size ")
<< shape[i];
}
}
return success();
}
-static LogicalResult verify(ConvertOp op) {
- if (auto tp1 = op.source().getType().dyn_cast<RankedTensorType>()) {
- if (auto tp2 = op.dest().getType().dyn_cast<RankedTensorType>()) {
+LogicalResult ConvertOp::verify() {
+ if (auto tp1 = source().getType().dyn_cast<RankedTensorType>()) {
+ if (auto tp2 = dest().getType().dyn_cast<RankedTensorType>()) {
if (tp1.getRank() != tp2.getRank())
- return op.emitError("unexpected conversion mismatch in rank");
+ return emitError("unexpected conversion mismatch in rank");
auto shape1 = tp1.getShape();
auto shape2 = tp2.getShape();
// Accept size matches between the source and the destination type
// (e.g. 10 vs. 10, 10 vs. ?, or ? vs. ?), but reject direct mismatches or
// matches that would need a runtime assert (e.g. 10 vs. 20 or ? vs. 10).
- for (unsigned d = 0, rank = tp1.getRank(); d < rank; d++) {
+ for (unsigned d = 0, rank = tp1.getRank(); d < rank; d++)
if (shape1[d] != shape2[d] && shape2[d] != ShapedType::kDynamicSize)
- return op.emitError("unexpected conversion mismatch in dimension ")
- << d;
- }
+ return emitError("unexpected conversion mismatch in dimension ") << d;
return success();
}
}
- return op.emitError("unexpected type in convert");
+ return emitError("unexpected type in convert");
}
OpFoldResult ConvertOp::fold(ArrayRef<Attribute> operands) {
@@ -264,35 +262,35 @@ OpFoldResult ConvertOp::fold(ArrayRef<Attribute> operands) {
return {};
}
-static LogicalResult verify(ToPointersOp op) {
- if (auto e = getSparseTensorEncoding(op.tensor().getType())) {
- if (failed(isInBounds(op.dim(), op.tensor())))
- return op.emitError("requested pointers dimension out of bounds");
- if (failed(isMatchingWidth(op.result(), e.getPointerBitWidth())))
- return op.emitError("unexpected type for pointers");
+LogicalResult ToPointersOp::verify() {
+ if (auto e = getSparseTensorEncoding(tensor().getType())) {
+ if (failed(isInBounds(dim(), tensor())))
+ return emitError("requested pointers dimension out of bounds");
+ if (failed(isMatchingWidth(result(), e.getPointerBitWidth())))
+ return emitError("unexpected type for pointers");
return success();
}
- return op.emitError("expected a sparse tensor to get pointers");
+ return emitError("expected a sparse tensor to get pointers");
}
-static LogicalResult verify(ToIndicesOp op) {
- if (auto e = getSparseTensorEncoding(op.tensor().getType())) {
- if (failed(isInBounds(op.dim(), op.tensor())))
- return op.emitError("requested indices dimension out of bounds");
- if (failed(isMatchingWidth(op.result(), e.getIndexBitWidth())))
- return op.emitError("unexpected type for indices");
+LogicalResult ToIndicesOp::verify() {
+ if (auto e = getSparseTensorEncoding(tensor().getType())) {
+ if (failed(isInBounds(dim(), tensor())))
+ return emitError("requested indices dimension out of bounds");
+ if (failed(isMatchingWidth(result(), e.getIndexBitWidth())))
+ return emitError("unexpected type for indices");
return success();
}
- return op.emitError("expected a sparse tensor to get indices");
+ return emitError("expected a sparse tensor to get indices");
}
-static LogicalResult verify(ToValuesOp op) {
- if (!getSparseTensorEncoding(op.tensor().getType()))
- return op.emitError("expected a sparse tensor to get values");
- RankedTensorType ttp = op.tensor().getType().cast<RankedTensorType>();
- MemRefType mtp = op.result().getType().cast<MemRefType>();
+LogicalResult ToValuesOp::verify() {
+ if (!getSparseTensorEncoding(tensor().getType()))
+ return emitError("expected a sparse tensor to get values");
+ RankedTensorType ttp = tensor().getType().cast<RankedTensorType>();
+ MemRefType mtp = result().getType().cast<MemRefType>();
if (ttp.getElementType() != mtp.getElementType())
- return op.emitError("unexpected mismatch in element types");
+ return emitError("unexpected mismatch in element types");
return success();
}
@@ -300,39 +298,39 @@ static LogicalResult verify(ToValuesOp op) {
// TensorDialect Management Operations.
//===----------------------------------------------------------------------===//
-static LogicalResult verify(LexInsertOp op) {
- if (!getSparseTensorEncoding(op.tensor().getType()))
- return op.emitError("expected a sparse tensor for insertion");
+LogicalResult LexInsertOp::verify() {
+ if (!getSparseTensorEncoding(tensor().getType()))
+ return emitError("expected a sparse tensor for insertion");
return success();
}
-static LogicalResult verify(ExpandOp op) {
- if (!getSparseTensorEncoding(op.tensor().getType()))
- return op.emitError("expected a sparse tensor for expansion");
+LogicalResult ExpandOp::verify() {
+ if (!getSparseTensorEncoding(tensor().getType()))
+ return emitError("expected a sparse tensor for expansion");
return success();
}
-static LogicalResult verify(CompressOp op) {
- if (!getSparseTensorEncoding(op.tensor().getType()))
- return op.emitError("expected a sparse tensor for compression");
+LogicalResult CompressOp::verify() {
+ if (!getSparseTensorEncoding(tensor().getType()))
+ return emitError("expected a sparse tensor for compression");
return success();
}
-static LogicalResult verify(LoadOp op) {
- if (!getSparseTensorEncoding(op.tensor().getType()))
- return op.emitError("expected a sparse tensor to materialize");
+LogicalResult LoadOp::verify() {
+ if (!getSparseTensorEncoding(tensor().getType()))
+ return emitError("expected a sparse tensor to materialize");
return success();
}
-static LogicalResult verify(ReleaseOp op) {
- if (!getSparseTensorEncoding(op.tensor().getType()))
- return op.emitError("expected a sparse tensor to release");
+LogicalResult ReleaseOp::verify() {
+ if (!getSparseTensorEncoding(tensor().getType()))
+ return emitError("expected a sparse tensor to release");
return success();
}
-static LogicalResult verify(OutOp op) {
- if (!getSparseTensorEncoding(op.tensor().getType()))
- return op.emitError("expected a sparse tensor for output");
+LogicalResult OutOp::verify() {
+ if (!getSparseTensorEncoding(tensor().getType()))
+ return emitError("expected a sparse tensor for output");
return success();
}
diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
index 2cc927555c3d6..91dfddb2dfe99 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -228,17 +228,17 @@ Optional<int64_t> DimOp::getConstantIndex() {
return {};
}
-static LogicalResult verify(DimOp op) {
+LogicalResult DimOp::verify() {
// Assume unknown index to be in range.
- Optional<int64_t> index = op.getConstantIndex();
+ Optional<int64_t> index = getConstantIndex();
if (!index.hasValue())
return success();
// Check that constant index is not knowingly out of range.
- auto type = op.source().getType();
+ auto type = source().getType();
if (auto tensorType = type.dyn_cast<RankedTensorType>()) {
if (index.getValue() >= tensorType.getRank())
- return op.emitOpError("index is out of range");
+ return emitOpError("index is out of range");
} else if (type.isa<UnrankedTensorType>()) {
// Assume index to be in range.
} else {
@@ -328,11 +328,11 @@ void DimOp::getCanonicalizationPatterns(RewritePatternSet &results,
// ExtractOp
//===----------------------------------------------------------------------===//
-static LogicalResult verify(ExtractOp op) {
+LogicalResult ExtractOp::verify() {
// Verify the # indices match if we have a ranked type.
- if (auto tensorType = op.tensor().getType().dyn_cast<RankedTensorType>())
- if (tensorType.getRank() != static_cast<int64_t>(op.indices().size()))
- return op.emitOpError("incorrect number of indices for extract_element");
+ if (auto tensorType = tensor().getType().dyn_cast<RankedTensorType>())
+ if (tensorType.getRank() != static_cast<int64_t>(indices().size()))
+ return emitOpError("incorrect number of indices for extract_element");
return success();
}
@@ -480,11 +480,11 @@ void FromElementsOp::getCanonicalizationPatterns(RewritePatternSet &results,
// InsertOp
//===----------------------------------------------------------------------===//
-static LogicalResult verify(InsertOp op) {
+LogicalResult InsertOp::verify() {
// Verify the # indices match if we have a ranked type.
- if (auto destType = op.dest().getType().dyn_cast<RankedTensorType>())
- if (destType.getRank() != static_cast<int64_t>(op.indices().size()))
- return op.emitOpError("incorrect number of indices");
+ if (auto destType = dest().getType().dyn_cast<RankedTensorType>())
+ if (destType.getRank() != static_cast<int64_t>(indices().size()))
+ return emitOpError("incorrect number of indices");
return success();
}
@@ -502,27 +502,26 @@ OpFoldResult InsertOp::fold(ArrayRef<Attribute> operands) {
// GenerateOp
//===----------------------------------------------------------------------===//
-static LogicalResult verify(GenerateOp op) {
+LogicalResult GenerateOp::verify() {
// Ensure that the tensor type has as many dynamic dimensions as are specified
// by the operands.
- RankedTensorType resultTy = op.getType().cast<RankedTensorType>();
- if (op.getNumOperands() != resultTy.getNumDynamicDims())
- return op.emitError("must have as many index operands as dynamic extents "
- "in the result type");
+ RankedTensorType resultTy = getType().cast<RankedTensorType>();
+ if (getNumOperands() != resultTy.getNumDynamicDims())
+ return emitError("must have as many index operands as dynamic extents "
+ "in the result type");
// Ensure that region arguments span the index space.
- if (!llvm::all_of(op.body().getArgumentTypes(),
+ if (!llvm::all_of(body().getArgumentTypes(),
[](Type ty) { return ty.isIndex(); }))
- return op.emitError("all body arguments must be index");
- if (op.body().getNumArguments() != resultTy.getRank())
- return op.emitError("must have one body argument per input dimension");
+ return emitError("all body arguments must be index");
+ if (body().getNumArguments() != resultTy.getRank())
+ return emitError("must have one body argument per input dimension");
// Ensure that the region yields an element of the right type.
- auto yieldOp =
- llvm::cast<YieldOp>(op.body().getBlocks().front().getTerminator());
+ auto yieldOp = cast<YieldOp>(body().getBlocks().front().getTerminator());
if (yieldOp.value().getType() != resultTy.getElementType())
- return op.emitOpError(
+ return emitOpError(
"body must be terminated with a `yield` operation of the tensor "
"element type");
@@ -686,16 +685,15 @@ static int64_t getNumElements(ShapedType type) {
return numElements;
}
-static LogicalResult verify(ReshapeOp op) {
- TensorType operandType = op.source().getType().cast<TensorType>();
- TensorType resultType = op.result().getType().cast<TensorType>();
+LogicalResult ReshapeOp::verify() {
+ TensorType operandType = source().getType().cast<TensorType>();
+ TensorType resultType = result().getType().cast<TensorType>();
if (operandType.getElementType() != resultType.getElementType())
- return op.emitOpError("element types of source and destination tensor "
- "types should be the same");
+ return emitOpError("element types of source and destination tensor "
+ "types should be the same");
- int64_t shapeSize =
- op.shape().getType().cast<RankedTensorType>().getDimSize(0);
+ int64_t shapeSize = shape().getType().cast<RankedTensorType>().getDimSize(0);
auto resultRankedType = resultType.dyn_cast<RankedTensorType>();
auto operandRankedType = operandType.dyn_cast<RankedTensorType>();
@@ -703,14 +701,14 @@ static LogicalResult verify(ReshapeOp op) {
if (operandRankedType && resultRankedType.hasStaticShape() &&
operandRankedType.hasStaticShape()) {
if (getNumElements(operandRankedType) != getNumElements(resultRankedType))
- return op.emitOpError("source and destination tensor should have the "
- "same number of elements");
+ return emitOpError("source and destination tensor should have the "
+ "same number of elements");
}
if (ShapedType::isDynamic(shapeSize))
- return op.emitOpError("cannot use shape operand with dynamic length to "
- "reshape to statically-ranked tensor type");
+ return emitOpError("cannot use shape operand with dynamic length to "
+ "reshape to statically-ranked tensor type");
if (shapeSize != resultRankedType.getRank())
- return op.emitOpError(
+ return emitOpError(
"length of shape operand differs from the result's tensor rank");
}
return success();
@@ -814,12 +812,12 @@ static LogicalResult verifyTensorReshapeOp(TensorReshapeOp op,
return success();
}
-static LogicalResult verify(ExpandShapeOp op) {
- return verifyTensorReshapeOp(op, op.getResultType(), op.getSrcType());
+LogicalResult ExpandShapeOp::verify() {
+ return verifyTensorReshapeOp(*this, getResultType(), getSrcType());
}
-static LogicalResult verify(CollapseShapeOp op) {
- return verifyTensorReshapeOp(op, op.getSrcType(), op.getResultType());
+LogicalResult CollapseShapeOp::verify() {
+ return verifyTensorReshapeOp(*this, getSrcType(), getResultType());
}
namespace {
@@ -1052,14 +1050,12 @@ static LogicalResult produceSliceErrorMsg(SliceVerificationResult result,
}
/// Verifier for ExtractSliceOp.
-static LogicalResult verify(ExtractSliceOp op) {
+LogicalResult ExtractSliceOp::verify() {
// Verify result type against inferred type.
- auto expectedType =
- ExtractSliceOp::inferResultType(op.getSourceType(), op.getMixedOffsets(),
- op.getMixedSizes(), op.getMixedStrides());
- auto result =
- isRankReducedType(expectedType.cast<ShapedType>(), op.getType());
- return produceSliceErrorMsg(result, op, expectedType);
+ auto expectedType = ExtractSliceOp::inferResultType(
+ getSourceType(), getMixedOffsets(), getMixedSizes(), getMixedStrides());
+ auto result = isRankReducedType(expectedType.cast<ShapedType>(), getType());
+ return produceSliceErrorMsg(result, *this, expectedType);
}
/// Infer the canonical type of the result of an extract_slice op. Returns a
@@ -1308,16 +1304,16 @@ void InsertSliceOp::build(OpBuilder &b, OperationState &result, Value source,
}
/// Verifier for InsertSliceOp.
-static LogicalResult verify(InsertSliceOp op) {
+LogicalResult InsertSliceOp::verify() {
// insert_slice is the inverse of extract_slice, use the same type inference.
auto expectedType = ExtractSliceOp::inferRankReducedResultType(
- op.getSourceType().getRank(), op.getType(),
- extractFromI64ArrayAttr(op.static_offsets()),
- extractFromI64ArrayAttr(op.static_sizes()),
- extractFromI64ArrayAttr(op.static_strides()));
+ getSourceType().getRank(), getType(),
+ extractFromI64ArrayAttr(static_offsets()),
+ extractFromI64ArrayAttr(static_sizes()),
+ extractFromI64ArrayAttr(static_strides()));
auto result =
- isRankReducedType(expectedType.cast<ShapedType>(), op.getSourceType());
- return produceSliceErrorMsg(result, op, expectedType);
+ isRankReducedType(expectedType.cast<ShapedType>(), getSourceType());
+ return produceSliceErrorMsg(result, *this, expectedType);
}
/// If we have two consecutive InsertSliceOp writing to the same slice, we
@@ -1569,40 +1565,40 @@ ParseResult parseInferType(OpAsmParser &parser,
return success();
}
-static LogicalResult verify(PadOp op) {
- auto sourceType = op.source().getType().cast<RankedTensorType>();
- auto resultType = op.result().getType().cast<RankedTensorType>();
- auto expectedType = PadOp::inferResultType(
- sourceType, extractFromI64ArrayAttr(op.static_low()),
- extractFromI64ArrayAttr(op.static_high()));
+LogicalResult PadOp::verify() {
+ auto sourceType = source().getType().cast<RankedTensorType>();
+ auto resultType = result().getType().cast<RankedTensorType>();
+ auto expectedType =
+ PadOp::inferResultType(sourceType, extractFromI64ArrayAttr(static_low()),
+ extractFromI64ArrayAttr(static_high()));
for (int i = 0, e = sourceType.getRank(); i < e; ++i) {
if (resultType.getDimSize(i) == expectedType.getDimSize(i))
continue;
if (expectedType.isDynamicDim(i))
continue;
- return op.emitError("specified type ")
+ return emitError("specified type ")
<< resultType << " does not match the inferred type "
<< expectedType;
}
- auto &region = op.region();
+ auto &region = getRegion();
unsigned rank = resultType.getRank();
Block &block = region.front();
if (block.getNumArguments() != rank)
- return op.emitError("expected the block to have ") << rank << " arguments";
+ return emitError("expected the block to have ") << rank << " arguments";
// Note: the number and type of yield values are checked in the YieldOp.
for (const auto &en : llvm::enumerate(block.getArgumentTypes())) {
if (!en.value().isIndex())
- return op.emitOpError("expected block argument ")
+ return emitOpError("expected block argument ")
<< (en.index() + 1) << " to be an index";
}
// Ensure that the region yields an element of the right type.
auto yieldOp = llvm::cast<YieldOp>(block.getTerminator());
if (yieldOp.value().getType() !=
- op.getType().cast<ShapedType>().getElementType())
- return op.emitOpError("expected yield type to match shape element type");
+ getType().cast<ShapedType>().getElementType())
+ return emitOpError("expected yield type to match shape element type");
return success();
}
More information about the Mlir-commits
mailing list