[Mlir-commits] [mlir] 38abddd - [mlir][NFC] Update AMX/LLVM/NVVM/X86 vector operations to use `hasVerifier` instead of `verifier`
River Riddle
llvmlistbot at llvm.org
Wed Feb 2 13:35:41 PST 2022
Author: River Riddle
Date: 2022-02-02T13:34:29-08:00
New Revision: 38abdddf6f660c6d71d1c018ee1f2a1b46808f68
URL: https://github.com/llvm/llvm-project/commit/38abdddf6f660c6d71d1c018ee1f2a1b46808f68
DIFF: https://github.com/llvm/llvm-project/commit/38abdddf6f660c6d71d1c018ee1f2a1b46808f68.diff
LOG: [mlir][NFC] Update AMX/LLVM/NVVM/X86 vector operations to use `hasVerifier` instead of `verifier`
The verifier field is deprecated, and slated for removal.
Differential Revision: https://reviews.llvm.org/D118819
Added:
Modified:
mlir/include/mlir/Dialect/AMX/AMX.td
mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
mlir/include/mlir/Dialect/LLVMIR/NVVMDialect.h
mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
mlir/include/mlir/Dialect/X86Vector/X86Vector.td
mlir/lib/Conversion/GPUToNVVM/WmmaOpsToNvvm.cpp
mlir/lib/Dialect/AMX/IR/AMXDialect.cpp
mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
mlir/lib/Dialect/X86Vector/IR/X86VectorDialect.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/AMX/AMX.td b/mlir/include/mlir/Dialect/AMX/AMX.td
index 0d23cde90949f..16e2e14504a8b 100644
--- a/mlir/include/mlir/Dialect/AMX/AMX.td
+++ b/mlir/include/mlir/Dialect/AMX/AMX.td
@@ -91,7 +91,6 @@ def TileZeroOp : AMX_Op<"tile_zero", [NoSideEffect]> {
%0 = amx.tile_zero : vector<16x16xbf16>
```
}];
- let verifier = [{ return ::verify(*this); }];
let results = (outs
VectorOfRankAndType<[2], [F32, BF16, I32, I8]>:$res);
let extraClassDeclaration = [{
@@ -100,6 +99,7 @@ def TileZeroOp : AMX_Op<"tile_zero", [NoSideEffect]> {
}
}];
let assemblyFormat = "attr-dict `:` type($res)";
+ let hasVerifier = 1;
}
//
@@ -120,7 +120,6 @@ def TileLoadOp : AMX_Op<"tile_load", [NoSideEffect]> {
%0 = amx.tile_load %arg0[%c0, %c0] : memref<?x?xi8> into vector<16x64xi8>
```
}];
- let verifier = [{ return ::verify(*this); }];
let arguments = (ins Arg<AnyMemRef, "load base", [MemRead]>:$base,
Variadic<Index>:$indices);
let results = (outs
@@ -135,6 +134,7 @@ def TileLoadOp : AMX_Op<"tile_load", [NoSideEffect]> {
}];
let assemblyFormat = "$base `[` $indices `]` attr-dict `:` "
"type($base) `into` type($res)";
+ let hasVerifier = 1;
}
def TileStoreOp : AMX_Op<"tile_store"> {
@@ -151,7 +151,6 @@ def TileStoreOp : AMX_Op<"tile_store"> {
amx.tile_store %arg1[%c0, %c0], %0 : memref<?x?xi8>, vector<16x64xi8>
```
}];
- let verifier = [{ return ::verify(*this); }];
let arguments = (ins Arg<AnyMemRef, "store base", [MemWrite]>:$base,
Variadic<Index>:$indices,
VectorOfRankAndType<[2], [F32, BF16, I32, I8]>:$val);
@@ -165,6 +164,7 @@ def TileStoreOp : AMX_Op<"tile_store"> {
}];
let assemblyFormat = "$base `[` $indices `]` `,` $val attr-dict `:` "
"type($base) `,` type($val)";
+ let hasVerifier = 1;
}
//
@@ -186,7 +186,6 @@ def TileMulFOp : AMX_Op<"tile_mulf", [NoSideEffect, AllTypesMatch<["acc", "res"]
: vector<16x32xbf16>, vector<16x32xbf16>, vector<16x16xf32>
```
}];
- let verifier = [{ return ::verify(*this); }];
let arguments = (ins VectorOfRankAndType<[2], [F32, BF16]>:$lhs,
VectorOfRankAndType<[2], [F32, BF16]>:$rhs,
VectorOfRankAndType<[2], [F32, BF16]>:$acc);
@@ -204,6 +203,7 @@ def TileMulFOp : AMX_Op<"tile_mulf", [NoSideEffect, AllTypesMatch<["acc", "res"]
}];
let assemblyFormat = "$lhs `,` $rhs `,` $acc attr-dict `:` "
"type($lhs) `,` type($rhs) `,` type($acc) ";
+ let hasVerifier = 1;
}
def TileMulIOp : AMX_Op<"tile_muli", [NoSideEffect, AllTypesMatch<["acc", "res"]>]> {
@@ -224,7 +224,6 @@ def TileMulIOp : AMX_Op<"tile_muli", [NoSideEffect, AllTypesMatch<["acc", "res"]
: vector<16x64xi8>, vector<16x64xi8>, vector<16x16xi32>
```
}];
- let verifier = [{ return ::verify(*this); }];
let arguments = (ins VectorOfRankAndType<[2], [I32, I8]>:$lhs,
VectorOfRankAndType<[2], [I32, I8]>:$rhs,
VectorOfRankAndType<[2], [I32, I8]>:$acc,
@@ -245,6 +244,7 @@ def TileMulIOp : AMX_Op<"tile_muli", [NoSideEffect, AllTypesMatch<["acc", "res"]
}];
let assemblyFormat = "$lhs (`zext` $isZextLhs^)? `,` $rhs (`zext` $isZextRhs^)? `,` $acc attr-dict `:` "
"type($lhs) `,` type($rhs) `,` type($acc) ";
+ let hasVerifier = 1;
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
index b4cf3adc7bb88..66d28c713bcfa 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
@@ -351,9 +351,7 @@ def LLVM_GEPOp : LLVM_Op<"getelementptr", [NoSideEffect]> {
constexpr static int kDynamicIndex = std::numeric_limits<int32_t>::min();
}];
let hasFolder = 1;
- let verifier = [{
- return ::verify(*this);
- }];
+ let hasVerifier = 1;
}
def LLVM_LoadOp : LLVM_Op<"load">, MemoryOpWithAlignmentAndAttributes {
@@ -386,7 +384,7 @@ def LLVM_LoadOp : LLVM_Op<"load">, MemoryOpWithAlignmentAndAttributes {
CArg<"bool", "false">:$isNonTemporal)>];
let parser = [{ return parseLoadOp(parser, result); }];
let printer = [{ printLoadOp(p, *this); }];
- let verifier = [{ return ::verify(*this); }];
+ let hasVerifier = 1;
}
def LLVM_StoreOp : LLVM_Op<"store">, MemoryOpWithAlignmentAndAttributes {
@@ -410,7 +408,7 @@ def LLVM_StoreOp : LLVM_Op<"store">, MemoryOpWithAlignmentAndAttributes {
];
let parser = [{ return parseStoreOp(parser, result); }];
let printer = [{ printStoreOp(p, *this); }];
- let verifier = [{ return ::verify(*this); }];
+ let hasVerifier = 1;
}
// Casts.
@@ -494,18 +492,18 @@ def LLVM_InvokeOp : LLVM_Op<"invoke", [
build($_builder, $_state, tys, /*callee=*/FlatSymbolRefAttr(), ops, normalOps,
unwindOps, normal, unwind);
}]>];
- let verifier = [{ return ::verify(*this); }];
let parser = [{ return parseInvokeOp(parser, result); }];
let printer = [{ printInvokeOp(p, *this); }];
+ let hasVerifier = 1;
}
def LLVM_LandingpadOp : LLVM_Op<"landingpad"> {
let arguments = (ins UnitAttr:$cleanup, Variadic<LLVM_Type>);
let results = (outs LLVM_Type:$res);
let builders = [LLVM_OneResultOpBuilder];
- let verifier = [{ return ::verify(*this); }];
let parser = [{ return parseLandingpadOp(parser, result); }];
let printer = [{ printLandingpadOp(p, *this); }];
+ let hasVerifier = 1;
}
def LLVM_CallOp : LLVM_Op<"call",
@@ -562,9 +560,9 @@ def LLVM_CallOp : LLVM_Op<"call",
build($_builder, $_state, results,
StringAttr::get($_builder.getContext(), callee), operands);
}]>];
- let verifier = [{ return ::verify(*this); }];
let parser = [{ return parseCallOp(parser, result); }];
let printer = [{ printCallOp(p, *this); }];
+ let hasVerifier = 1;
}
def LLVM_ExtractElementOp : LLVM_Op<"extractelement", [NoSideEffect]> {
let arguments = (ins LLVM_AnyVector:$vector, AnyInteger:$position);
@@ -575,9 +573,9 @@ def LLVM_ExtractElementOp : LLVM_Op<"extractelement", [NoSideEffect]> {
let builders = [
OpBuilder<(ins "Value":$vector, "Value":$position,
CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs)>];
- let verifier = [{ return ::verify(*this); }];
let parser = [{ return parseExtractElementOp(parser, result); }];
let printer = [{ printExtractElementOp(p, *this); }];
+ let hasVerifier = 1;
}
def LLVM_ExtractValueOp : LLVM_Op<"extractvalue", [NoSideEffect]> {
let arguments = (ins LLVM_AnyAggregate:$container, ArrayAttr:$position);
@@ -586,10 +584,10 @@ def LLVM_ExtractValueOp : LLVM_Op<"extractvalue", [NoSideEffect]> {
$res = builder.CreateExtractValue($container, extractPosition($position));
}];
let builders = [LLVM_OneResultOpBuilder];
- let verifier = [{ return ::verify(*this); }];
let parser = [{ return parseExtractValueOp(parser, result); }];
let printer = [{ printExtractValueOp(p, *this); }];
let hasFolder = 1;
+ let hasVerifier = 1;
}
def LLVM_InsertElementOp : LLVM_Op<"insertelement", [NoSideEffect]> {
let arguments = (ins LLVM_AnyVector:$vector, LLVM_PrimitiveType:$value,
@@ -599,9 +597,9 @@ def LLVM_InsertElementOp : LLVM_Op<"insertelement", [NoSideEffect]> {
$res = builder.CreateInsertElement($vector, $value, $position);
}];
let builders = [LLVM_OneResultOpBuilder];
- let verifier = [{ return ::verify(*this); }];
let parser = [{ return parseInsertElementOp(parser, result); }];
let printer = [{ printInsertElementOp(p, *this); }];
+ let hasVerifier = 1;
}
def LLVM_InsertValueOp : LLVM_Op<"insertvalue", [NoSideEffect]> {
let arguments = (ins LLVM_AnyAggregate:$container, LLVM_PrimitiveType:$value,
@@ -616,9 +614,9 @@ def LLVM_InsertValueOp : LLVM_Op<"insertvalue", [NoSideEffect]> {
[{
build($_builder, $_state, container.getType(), container, value, position);
}]>];
- let verifier = [{ return ::verify(*this); }];
let parser = [{ return parseInsertValueOp(parser, result); }];
let printer = [{ printInsertValueOp(p, *this); }];
+ let hasVerifier = 1;
}
def LLVM_ShuffleVectorOp : LLVM_Op<"shufflevector", [NoSideEffect]> {
let arguments = (ins LLVM_AnyVector:$v1, LLVM_AnyVector:$v2, ArrayAttr:$mask);
@@ -631,16 +629,9 @@ def LLVM_ShuffleVectorOp : LLVM_Op<"shufflevector", [NoSideEffect]> {
let builders = [
OpBuilder<(ins "Value":$v1, "Value":$v2, "ArrayAttr":$mask,
CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs)>];
- let verifier = [{
- auto type1 = getV1().getType();
- auto type2 = getV2().getType();
- if (::mlir::LLVM::getVectorElementType(type1) !=
- ::mlir::LLVM::getVectorElementType(type2))
- return emitOpError("expected matching LLVM IR Dialect element types");
- return success();
- }];
let parser = [{ return parseShuffleVectorOp(parser, result); }];
let printer = [{ printShuffleVectorOp(p, *this); }];
+ let hasVerifier = 1;
}
// Misc operations.
@@ -718,27 +709,15 @@ def LLVM_ReturnOp : LLVM_TerminatorOp<"return", [NoSideEffect]> {
builder.CreateRetVoid();
}];
- let verifier = [{
- if (getNumOperands() > 1)
- return emitOpError("expects at most 1 operand");
- return success();
- }];
-
let parser = [{ return parseReturnOp(parser, result); }];
let printer = [{ printReturnOp(p, *this); }];
- let verifier = [{ return ::verify(*this); }];
+ let hasVerifier = 1;
}
def LLVM_ResumeOp : LLVM_TerminatorOp<"resume", []> {
let arguments = (ins LLVM_Type:$value);
string llvmBuilder = [{ builder.CreateResume($value); }];
- let verifier = [{
- if (!isa_and_nonnull<LandingpadOp>(getValue().getDefiningOp()))
- return emitOpError("expects landingpad value as operand");
- // No check for personality of function - landingpad op verifies it.
- return success();
- }];
-
let assemblyFormat = "$value attr-dict `:` type($value)";
+ let hasVerifier = 1;
}
def LLVM_UnreachableOp : LLVM_TerminatorOp<"unreachable", []> {
string llvmBuilder = [{ builder.CreateUnreachable(); }];
@@ -761,7 +740,6 @@ def LLVM_SwitchOp : LLVM_TerminatorOp<"switch",
VariadicSuccessor<AnySuccessor>:$caseDestinations
);
- let verifier = [{ return ::verify(*this); }];
let assemblyFormat = [{
$value `:` type($value) `,`
$defaultDestination (`(` $defaultOperands^ `:` type($defaultOperands) `)`)?
@@ -769,6 +747,7 @@ def LLVM_SwitchOp : LLVM_TerminatorOp<"switch",
$caseOperands, type($caseOperands)) `]`
attr-dict
}];
+ let hasVerifier = 1;
let builders = [
OpBuilder<(ins "Value":$value,
@@ -924,7 +903,7 @@ def LLVM_AddressOfOp : LLVM_Op<"mlir.addressof", [NoSideEffect]> {
}];
let assemblyFormat = "$global_name attr-dict `:` type($res)";
- let verifier = "return ::verify(*this);";
+ let hasVerifier = 1;
}
def LLVM_MetadataOp : LLVM_Op<"metadata", [
@@ -1175,7 +1154,7 @@ def LLVM_GlobalOp : LLVM_Op<"mlir.global",
let printer = "printGlobalOp(p, *this);";
let parser = "return parseGlobalOp(parser, result);";
- let verifier = "return ::verify(*this);";
+ let hasVerifier = 1;
}
def LLVM_GlobalCtorsOp : LLVM_Op<"mlir.global_ctors", [
@@ -1205,8 +1184,8 @@ def LLVM_GlobalCtorsOp : LLVM_Op<"mlir.global_ctors", [
```
}];
- let verifier = [{ return ::verify(*this); }];
let assemblyFormat = "attr-dict";
+ let hasVerifier = 1;
}
def LLVM_GlobalDtorsOp : LLVM_Op<"mlir.global_dtors", [
@@ -1234,8 +1213,8 @@ def LLVM_GlobalDtorsOp : LLVM_Op<"mlir.global_dtors", [
```
}];
- let verifier = [{ return ::verify(*this); }];
let assemblyFormat = "attr-dict";
+ let hasVerifier = 1;
}
def LLVM_LLVMFuncOp : LLVM_Op<"func", [
@@ -1310,9 +1289,9 @@ def LLVM_LLVMFuncOp : LLVM_Op<"func", [
LogicalResult verifyType();
}];
- let verifier = [{ return ::verify(*this); }];
let printer = [{ printLLVMFuncOp(p, *this); }];
let parser = [{ return parseLLVMFuncOp(parser, result); }];
+ let hasVerifier = 1;
}
def LLVM_NullOp
@@ -1402,8 +1381,8 @@ def LLVM_ConstantOp
let results = (outs LLVM_Type:$res);
let builders = [LLVM_OneResultOpBuilder];
let assemblyFormat = "`(` $value `)` attr-dict `:` type($res)";
- let verifier = [{ return ::verify(*this); }];
let hasFolder = 1;
+ let hasVerifier = 1;
}
// Operations that correspond to LLVM intrinsics. With MLIR operation set being
@@ -1848,7 +1827,7 @@ def LLVM_AtomicRMWOp : LLVM_Op<"atomicrmw"> {
}];
let parser = [{ return parseAtomicRMWOp(parser, result); }];
let printer = [{ printAtomicRMWOp(p, *this); }];
- let verifier = "return ::verify(*this);";
+ let hasVerifier = 1;
}
def LLVM_AtomicCmpXchgType : AnyTypeOf<[AnyInteger, LLVM_AnyPointer]>;
@@ -1878,7 +1857,7 @@ def LLVM_AtomicCmpXchgOp : LLVM_Op<"cmpxchg"> {
}];
let parser = [{ return parseAtomicCmpXchgOp(parser, result); }];
let printer = [{ printAtomicCmpXchgOp(p, *this); }];
- let verifier = "return ::verify(*this);";
+ let hasVerifier = 1;
}
def LLVM_AssumeOp : LLVM_Op<"intr.assume", []> {
@@ -1901,7 +1880,7 @@ def LLVM_FenceOp : LLVM_Op<"fence"> {
}];
let parser = [{ return parseFenceOp(parser, result); }];
let printer = [{ printFenceOp(p, *this); }];
- let verifier = "return ::verify(*this);";
+ let hasVerifier = 1;
}
def AsmATT : LLVM_EnumAttrCase<
diff --git a/mlir/include/mlir/Dialect/LLVMIR/NVVMDialect.h b/mlir/include/mlir/Dialect/LLVMIR/NVVMDialect.h
index 3fd7c5bc06609..de942f6fb4d31 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/NVVMDialect.h
+++ b/mlir/include/mlir/Dialect/LLVMIR/NVVMDialect.h
@@ -22,12 +22,16 @@
#include "mlir/Dialect/LLVMIR/NVVMOpsEnums.h.inc"
+namespace mlir {
+namespace NVVM {
/// Return the element type and number of elements associated with a wmma matrix
/// of given chracteristics. This matches the logic in IntrinsicsNVVM.td
/// WMMA_REGS structure.
std::pair<mlir::Type, unsigned> inferMMAType(mlir::NVVM::MMATypes type,
mlir::NVVM::MMAFrag frag,
mlir::MLIRContext *context);
+} // namespace NVVM
+} // namespace mlir
///// Ops /////
#define GET_ATTRDEF_CLASSES
diff --git a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
index d26a0b2c6f30b..4a55ddd96cb79 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
@@ -131,22 +131,11 @@ def NVVM_ShflOp :
$res = createIntrinsicCall(builder,
intId, {$dst, $val, $offset, $mask_and_clamp});
}];
- let verifier = [{
- if (!(*this)->getAttrOfType<UnitAttr>("return_value_and_is_valid"))
- return success();
- auto type = getType().dyn_cast<LLVM::LLVMStructType>();
- auto elementType = (type && type.getBody().size() == 2)
- ? type.getBody()[1].dyn_cast<IntegerType>()
- : nullptr;
- if (!elementType || elementType.getWidth() != 1)
- return emitError("expected return type to be a two-element struct with "
- "i1 as the second element");
- return success();
- }];
let assemblyFormat = [{
$kind $dst `,` $val `,` $offset `,` $mask_and_clamp attr-dict
`:` type($val) `->` type($res)
}];
+ let hasVerifier = 1;
}
def NVVM_VoteBallotOp :
@@ -183,12 +172,8 @@ def NVVM_CpAsyncOp : NVVM_Op<"cp.async.shared.global">,
}
createIntrinsicCall(builder, id, {$dst, $src});
}];
- let verifier = [{
- if (size() != 4 && size() != 8 && size() != 16)
- return emitError("expected byte size to be either 4, 8 or 16.");
- return success();
- }];
let assemblyFormat = "$dst `,` $src `,` $size attr-dict";
+ let hasVerifier = 1;
}
def NVVM_CpAsyncCommitGroupOp : NVVM_Op<"cp.async.commit.group"> {
@@ -220,7 +205,7 @@ def NVVM_MmaOp :
builder, llvm::Intrinsic::nvvm_mma_m8n8k4_row_col_f32_f32, $args);
}];
let assemblyFormat = "$args attr-dict `:` functional-type($args, $res)";
- let verifier = [{ return ::verify(*this); }];
+ let hasVerifier = 1;
}
/// Helpers to instantiate different version of wmma intrinsics.
@@ -538,7 +523,7 @@ def NVVM_WMMALoadOp: NVVM_Op<"wmma.load">,
}];
let assemblyFormat = "$ptr `,` $stride attr-dict `:` functional-type($ptr, $res)";
- let verifier = [{ return ::verify(*this); }];
+ let hasVerifier = 1;
}
def NVVM_WMMAStoreOp : NVVM_Op<"wmma.store">,
@@ -593,7 +578,7 @@ def NVVM_WMMAStoreOp : NVVM_Op<"wmma.store">,
}];
let assemblyFormat = "$ptr `,` $stride `,` $args attr-dict `:` type($ptr) `,` type($args)";
- let verifier = [{ return ::verify(*this); }];
+ let hasVerifier = 1;
}
// Base class for all the variants of WMMA mmaOps that may be defined.
@@ -647,7 +632,7 @@ def NVVM_WMMAMmaOp : NVVM_Op<"wmma.mma">,
}];
let assemblyFormat = "$args attr-dict `:` functional-type($args, $res)";
- let verifier = [{ return ::verify(*this); }];
+ let hasVerifier = 1;
}
#endif // NVVMIR_OPS
diff --git a/mlir/include/mlir/Dialect/X86Vector/X86Vector.td b/mlir/include/mlir/Dialect/X86Vector/X86Vector.td
index bda3440aa9740..03fa89ef899a6 100644
--- a/mlir/include/mlir/Dialect/X86Vector/X86Vector.td
+++ b/mlir/include/mlir/Dialect/X86Vector/X86Vector.td
@@ -76,7 +76,6 @@ def MaskCompressOp : AVX512_Op<"mask.compress", [NoSideEffect,
with their respective bit set in writemask `k`) to `dst`, and pass through the
remaining elements from `src`.
}];
- let verifier = [{ return ::verify(*this); }];
let arguments = (ins VectorOfLengthAndType<[16, 8],
[I1]>:$k,
VectorOfLengthAndType<[16, 8],
@@ -88,6 +87,7 @@ def MaskCompressOp : AVX512_Op<"mask.compress", [NoSideEffect,
[F32, I32, F64, I64]>:$dst);
let assemblyFormat = "$k `,` $a (`,` $src^)? attr-dict"
" `:` type($dst) (`,` type($src)^)?";
+ let hasVerifier = 1;
}
def MaskCompressIntrOp : AVX512_IntrOverloadedOp<"mask.compress", [
diff --git a/mlir/lib/Conversion/GPUToNVVM/WmmaOpsToNvvm.cpp b/mlir/lib/Conversion/GPUToNVVM/WmmaOpsToNvvm.cpp
index f52e589ad9dfe..4a3ba46233ff8 100644
--- a/mlir/lib/Conversion/GPUToNVVM/WmmaOpsToNvvm.cpp
+++ b/mlir/lib/Conversion/GPUToNVVM/WmmaOpsToNvvm.cpp
@@ -358,22 +358,19 @@ struct WmmaElementwiseOpToNVVMLowering
} // namespace
-namespace mlir {
-
/// Return the LLVMStructureType corresponding to the MMAMatrixType `type`.
-LLVM::LLVMStructType convertMMAToLLVMType(gpu::MMAMatrixType type) {
+LLVM::LLVMStructType mlir::convertMMAToLLVMType(gpu::MMAMatrixType type) {
NVVM::MMAFrag frag = convertOperand(type.getOperand());
NVVM::MMATypes eltType = getElementType(type);
std::pair<Type, unsigned> typeInfo =
- inferMMAType(eltType, frag, type.getContext());
+ NVVM::inferMMAType(eltType, frag, type.getContext());
return LLVM::LLVMStructType::getLiteral(
type.getContext(), SmallVector<Type, 8>(typeInfo.second, typeInfo.first));
}
-void populateGpuWMMAToNVVMConversionPatterns(LLVMTypeConverter &converter,
- RewritePatternSet &patterns) {
+void mlir::populateGpuWMMAToNVVMConversionPatterns(
+ LLVMTypeConverter &converter, RewritePatternSet &patterns) {
patterns.insert<WmmaLoadOpToNVVMLowering, WmmaMmaOpToNVVMLowering,
WmmaStoreOpToNVVMLowering, WmmaConstantOpToNVVMLowering,
WmmaElementwiseOpToNVVMLowering>(converter);
}
-} // namespace mlir
diff --git a/mlir/lib/Dialect/AMX/IR/AMXDialect.cpp b/mlir/lib/Dialect/AMX/IR/AMXDialect.cpp
index c5cf1f41d7098..9ea96791cef4b 100644
--- a/mlir/lib/Dialect/AMX/IR/AMXDialect.cpp
+++ b/mlir/lib/Dialect/AMX/IR/AMXDialect.cpp
@@ -52,53 +52,55 @@ static LogicalResult verifyMultShape(Operation *op, VectorType atp,
return success();
}
-static LogicalResult verify(amx::TileZeroOp op) {
- return verifyTileSize(op, op.getVectorType());
+LogicalResult amx::TileZeroOp::verify() {
+ return verifyTileSize(*this, getVectorType());
}
-static LogicalResult verify(amx::TileLoadOp op) {
- unsigned rank = op.getMemRefType().getRank();
- if (llvm::size(op.indices()) != rank)
- return op.emitOpError("requires ") << rank << " indices";
- return verifyTileSize(op, op.getVectorType());
+LogicalResult amx::TileLoadOp::verify() {
+ unsigned rank = getMemRefType().getRank();
+ if (indices().size() != rank)
+ return emitOpError("requires ") << rank << " indices";
+ return verifyTileSize(*this, getVectorType());
}
-static LogicalResult verify(amx::TileStoreOp op) {
- unsigned rank = op.getMemRefType().getRank();
- if (llvm::size(op.indices()) != rank)
- return op.emitOpError("requires ") << rank << " indices";
- return verifyTileSize(op, op.getVectorType());
+LogicalResult amx::TileStoreOp::verify() {
+ unsigned rank = getMemRefType().getRank();
+ if (indices().size() != rank)
+ return emitOpError("requires ") << rank << " indices";
+ return verifyTileSize(*this, getVectorType());
}
-static LogicalResult verify(amx::TileMulFOp op) {
- VectorType aType = op.getLhsVectorType();
- VectorType bType = op.getRhsVectorType();
- VectorType cType = op.getVectorType();
- if (failed(verifyTileSize(op, aType)) || failed(verifyTileSize(op, bType)) ||
- failed(verifyTileSize(op, cType)) ||
- failed(verifyMultShape(op, aType, bType, cType, 1)))
+LogicalResult amx::TileMulFOp::verify() {
+ VectorType aType = getLhsVectorType();
+ VectorType bType = getRhsVectorType();
+ VectorType cType = getVectorType();
+ if (failed(verifyTileSize(*this, aType)) ||
+ failed(verifyTileSize(*this, bType)) ||
+ failed(verifyTileSize(*this, cType)) ||
+ failed(verifyMultShape(*this, aType, bType, cType, 1)))
return failure();
Type ta = aType.getElementType();
Type tb = bType.getElementType();
Type tc = cType.getElementType();
if (!ta.isBF16() || !tb.isBF16() || !tc.isF32())
- return op.emitOpError("unsupported type combination");
+ return emitOpError("unsupported type combination");
return success();
}
-static LogicalResult verify(amx::TileMulIOp op) {
- VectorType aType = op.getLhsVectorType();
- VectorType bType = op.getRhsVectorType();
- VectorType cType = op.getVectorType();
- if (failed(verifyTileSize(op, aType)) || failed(verifyTileSize(op, bType)) ||
- failed(verifyTileSize(op, cType)) ||
- failed(verifyMultShape(op, aType, bType, cType, 2)))
+LogicalResult amx::TileMulIOp::verify() {
+ VectorType aType = getLhsVectorType();
+ VectorType bType = getRhsVectorType();
+ VectorType cType = getVectorType();
+ if (failed(verifyTileSize(*this, aType)) ||
+ failed(verifyTileSize(*this, bType)) ||
+ failed(verifyTileSize(*this, cType)) ||
+ failed(verifyMultShape(*this, aType, bType, cType, 2)))
return failure();
Type ta = aType.getElementType();
Type tb = bType.getElementType();
Type tc = cType.getElementType();
if (!ta.isInteger(8) || !tb.isInteger(8) || !tc.isInteger(32))
- return op.emitOpError("unsupported type combination");
+ return emitOpError("unsupported type combination");
return success();
}
diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
index 30eec4369dd19..fb2cd2546d5f8 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
@@ -334,18 +334,17 @@ static void printSwitchOpCases(OpAsmPrinter &p, SwitchOp op, Type flagType,
p.printNewline();
}
-static LogicalResult verify(SwitchOp op) {
- if ((!op.getCaseValues() && !op.getCaseDestinations().empty()) ||
- (op.getCaseValues() &&
- op.getCaseValues()->size() !=
- static_cast<int64_t>(op.getCaseDestinations().size())))
- return op.emitOpError("expects number of case values to match number of "
- "case destinations");
- if (op.getBranchWeights() &&
- op.getBranchWeights()->size() != op.getNumSuccessors())
- return op.emitError("expects number of branch weights to match number of "
- "successors: ")
- << op.getBranchWeights()->size() << " vs " << op.getNumSuccessors();
+LogicalResult SwitchOp::verify() {
+ if ((!getCaseValues() && !getCaseDestinations().empty()) ||
+ (getCaseValues() &&
+ getCaseValues()->size() !=
+ static_cast<int64_t>(getCaseDestinations().size())))
+ return emitOpError("expects number of case values to match number of "
+ "case destinations");
+ if (getBranchWeights() && getBranchWeights()->size() != getNumSuccessors())
+ return emitError("expects number of branch weights to match number of "
+ "successors: ")
+ << getBranchWeights()->size() << " vs " << getNumSuccessors();
return success();
}
@@ -518,11 +517,11 @@ static void printGEPIndices(OpAsmPrinter &printer, LLVM::GEPOp gepOp,
});
}
-LogicalResult verify(LLVM::GEPOp gepOp) {
+LogicalResult LLVM::GEPOp::verify() {
SmallVector<unsigned> indices;
SmallVector<unsigned> structSizes;
- findKnownStructIndices(gepOp.getBase().getType(), indices, &structSizes);
- DenseIntElementsAttr structIndices = gepOp.getStructIndices();
+ findKnownStructIndices(getBase().getType(), indices, &structSizes);
+ DenseIntElementsAttr structIndices = getStructIndices();
for (unsigned i : llvm::seq<unsigned>(0, indices.size())) {
unsigned index = indices[i];
// GEP may not be indexing as deep as some structs nested in the type.
@@ -531,11 +530,11 @@ LogicalResult verify(LLVM::GEPOp gepOp) {
int32_t staticIndex = structIndices.getValues<int32_t>()[index];
if (staticIndex == LLVM::GEPOp::kDynamicIndex)
- return gepOp.emitOpError() << "expected index " << index
- << " indexing a struct to be constant";
+ return emitOpError() << "expected index " << index
+ << " indexing a struct to be constant";
if (staticIndex < 0 || static_cast<unsigned>(staticIndex) >= structSizes[i])
- return gepOp.emitOpError()
- << "index " << index << " indexing a struct is out of bounds";
+ return emitOpError() << "index " << index
+ << " indexing a struct is out of bounds";
}
return success();
}
@@ -613,9 +612,7 @@ static LogicalResult verifyMemoryOpMetadata(Operation *op) {
return success();
}
-static LogicalResult verify(LoadOp op) {
- return verifyMemoryOpMetadata(op.getOperation());
-}
+LogicalResult LoadOp::verify() { return verifyMemoryOpMetadata(*this); }
void LoadOp::build(OpBuilder &builder, OperationState &result, Type t,
Value addr, unsigned alignment, bool isVolatile,
@@ -675,9 +672,7 @@ static ParseResult parseLoadOp(OpAsmParser &parser, OperationState &result) {
// Builder, printer and parser for LLVM::StoreOp.
//===----------------------------------------------------------------------===//
-static LogicalResult verify(StoreOp op) {
- return verifyMemoryOpMetadata(op.getOperation());
-}
+LogicalResult StoreOp::verify() { return verifyMemoryOpMetadata(*this); }
void StoreOp::build(OpBuilder &builder, OperationState &result, Value value,
Value addr, unsigned alignment, bool isVolatile,
@@ -739,19 +734,18 @@ InvokeOp::getMutableSuccessorOperands(unsigned index) {
: getUnwindDestOperandsMutable();
}
-static LogicalResult verify(InvokeOp op) {
- if (op.getNumResults() > 1)
- return op.emitOpError("must have 0 or 1 result");
+LogicalResult InvokeOp::verify() {
+ if (getNumResults() > 1)
+ return emitOpError("must have 0 or 1 result");
- Block *unwindDest = op.getUnwindDest();
+ Block *unwindDest = getUnwindDest();
if (unwindDest->empty())
- return op.emitError(
- "must have at least one operation in unwind destination");
+ return emitError("must have at least one operation in unwind destination");
// In unwind destination, first operation must be LandingpadOp
if (!isa<LandingpadOp>(unwindDest->front()))
- return op.emitError("first operation in unwind destination should be a "
- "llvm.landingpad operation");
+ return emitError("first operation in unwind destination should be a "
+ "llvm.landingpad operation");
return success();
}
@@ -880,20 +874,20 @@ static ParseResult parseInvokeOp(OpAsmParser &parser, OperationState &result) {
/// Verifying/Printing/Parsing for LLVM::LandingpadOp.
///===----------------------------------------------------------------------===//
-static LogicalResult verify(LandingpadOp op) {
+LogicalResult LandingpadOp::verify() {
Value value;
- if (LLVMFuncOp func = op->getParentOfType<LLVMFuncOp>()) {
+ if (LLVMFuncOp func = (*this)->getParentOfType<LLVMFuncOp>()) {
if (!func.getPersonality().hasValue())
- return op.emitError(
+ return emitError(
"llvm.landingpad needs to be in a function with a personality");
}
- if (!op.getCleanup() && op.getOperands().empty())
- return op.emitError("landingpad instruction expects at least one clause or "
- "cleanup attribute");
+ if (!getCleanup() && getOperands().empty())
+ return emitError("landingpad instruction expects at least one clause or "
+ "cleanup attribute");
- for (unsigned idx = 0, ie = op.getNumOperands(); idx < ie; idx++) {
- value = op.getOperand(idx);
+ for (unsigned idx = 0, ie = getNumOperands(); idx < ie; idx++) {
+ value = getOperand(idx);
bool isFilter = value.getType().isa<LLVMArrayType>();
if (isFilter) {
// FIXME: Verify filter clauses when arrays are appropriately handled
@@ -903,8 +897,7 @@ static LogicalResult verify(LandingpadOp op) {
if (auto bcOp = value.getDefiningOp<BitcastOp>()) {
if (auto addrOp = bcOp.getArg().getDefiningOp<AddressOfOp>())
continue;
- return op.emitError("constant clauses expected")
- .attachNote(bcOp.getLoc())
+ return emitError("constant clauses expected").attachNote(bcOp.getLoc())
<< "global addresses expected as operand to "
"bitcast used in clauses for landingpad";
}
@@ -913,7 +906,7 @@ static LogicalResult verify(LandingpadOp op) {
continue;
if (value.getDefiningOp<AddressOfOp>())
continue;
- return op.emitError("clause #")
+ return emitError("clause #")
<< idx << " is not a known constant - null, addressof, bitcast";
}
}
@@ -970,9 +963,9 @@ static ParseResult parseLandingpadOp(OpAsmParser &parser,
// Verifying/Printing/parsing for LLVM::CallOp.
//===----------------------------------------------------------------------===//
-static LogicalResult verify(CallOp &op) {
- if (op.getNumResults() > 1)
- return op.emitOpError("must have 0 or 1 result");
+LogicalResult CallOp::verify() {
+ if (getNumResults() > 1)
+ return emitOpError("must have 0 or 1 result");
// Type for the callee, we'll get it differently depending if it is a direct
// or indirect call.
@@ -981,75 +974,73 @@ static LogicalResult verify(CallOp &op) {
bool isIndirect = false;
// If this is an indirect call, the callee attribute is missing.
- FlatSymbolRefAttr calleeName = op.getCalleeAttr();
+ FlatSymbolRefAttr calleeName = getCalleeAttr();
if (!calleeName) {
isIndirect = true;
- if (!op.getNumOperands())
- return op.emitOpError(
+ if (!getNumOperands())
+ return emitOpError(
"must have either a `callee` attribute or at least an operand");
- auto ptrType = op.getOperand(0).getType().dyn_cast<LLVMPointerType>();
+ auto ptrType = getOperand(0).getType().dyn_cast<LLVMPointerType>();
if (!ptrType)
- return op.emitOpError("indirect call expects a pointer as callee: ")
+ return emitOpError("indirect call expects a pointer as callee: ")
<< ptrType;
fnType = ptrType.getElementType();
} else {
Operation *callee =
- SymbolTable::lookupNearestSymbolFrom(op, calleeName.getAttr());
+ SymbolTable::lookupNearestSymbolFrom(*this, calleeName.getAttr());
if (!callee)
- return op.emitOpError()
+ return emitOpError()
<< "'" << calleeName.getValue()
<< "' does not reference a symbol in the current scope";
auto fn = dyn_cast<LLVMFuncOp>(callee);
if (!fn)
- return op.emitOpError() << "'" << calleeName.getValue()
- << "' does not reference a valid LLVM function";
+ return emitOpError() << "'" << calleeName.getValue()
+ << "' does not reference a valid LLVM function";
fnType = fn.getType();
}
LLVMFunctionType funcType = fnType.dyn_cast<LLVMFunctionType>();
if (!funcType)
- return op.emitOpError("callee does not have a functional type: ") << fnType;
+ return emitOpError("callee does not have a functional type: ") << fnType;
// Verify that the operand and result types match the callee.
if (!funcType.isVarArg() &&
- funcType.getNumParams() != (op.getNumOperands() - isIndirect))
- return op.emitOpError()
- << "incorrect number of operands ("
- << (op.getNumOperands() - isIndirect)
- << ") for callee (expecting: " << funcType.getNumParams() << ")";
-
- if (funcType.getNumParams() > (op.getNumOperands() - isIndirect))
- return op.emitOpError() << "incorrect number of operands ("
- << (op.getNumOperands() - isIndirect)
- << ") for varargs callee (expecting at least: "
- << funcType.getNumParams() << ")";
+ funcType.getNumParams() != (getNumOperands() - isIndirect))
+ return emitOpError() << "incorrect number of operands ("
+ << (getNumOperands() - isIndirect)
+ << ") for callee (expecting: "
+ << funcType.getNumParams() << ")";
+
+ if (funcType.getNumParams() > (getNumOperands() - isIndirect))
+ return emitOpError() << "incorrect number of operands ("
+ << (getNumOperands() - isIndirect)
+ << ") for varargs callee (expecting at least: "
+ << funcType.getNumParams() << ")";
for (unsigned i = 0, e = funcType.getNumParams(); i != e; ++i)
- if (op.getOperand(i + isIndirect).getType() != funcType.getParamType(i))
- return op.emitOpError() << "operand type mismatch for operand " << i
- << ": " << op.getOperand(i + isIndirect).getType()
- << " != " << funcType.getParamType(i);
+ if (getOperand(i + isIndirect).getType() != funcType.getParamType(i))
+ return emitOpError() << "operand type mismatch for operand " << i << ": "
+ << getOperand(i + isIndirect).getType()
+ << " != " << funcType.getParamType(i);
- if (op.getNumResults() == 0 &&
+ if (getNumResults() == 0 &&
!funcType.getReturnType().isa<LLVM::LLVMVoidType>())
- return op.emitOpError() << "expected function call to produce a value";
+ return emitOpError() << "expected function call to produce a value";
- if (op.getNumResults() != 0 &&
+ if (getNumResults() != 0 &&
funcType.getReturnType().isa<LLVM::LLVMVoidType>())
- return op.emitOpError()
+ return emitOpError()
<< "calling function with void result must not produce values";
- if (op.getNumResults() > 1)
- return op.emitOpError()
+ if (getNumResults() > 1)
+ return emitOpError()
<< "expected LLVM function call to produce 0 or 1 result";
- if (op.getNumResults() &&
- op.getResult(0).getType() != funcType.getReturnType())
- return op.emitOpError()
- << "result type mismatch: " << op.getResult(0).getType()
- << " != " << funcType.getReturnType();
+ if (getNumResults() && getResult(0).getType() != funcType.getReturnType())
+ return emitOpError() << "result type mismatch: " << getResult(0).getType()
+ << " != " << funcType.getReturnType();
return success();
}
@@ -1200,17 +1191,17 @@ static ParseResult parseExtractElementOp(OpAsmParser &parser,
return success();
}
-static LogicalResult verify(ExtractElementOp op) {
- Type vectorType = op.getVector().getType();
+LogicalResult ExtractElementOp::verify() {
+ Type vectorType = getVector().getType();
if (!LLVM::isCompatibleVectorType(vectorType))
- return op->emitOpError("expected LLVM dialect-compatible vector type for "
- "operand #1, got")
+ return emitOpError("expected LLVM dialect-compatible vector type for "
+ "operand #1, got")
<< vectorType;
Type valueType = LLVM::getVectorElementType(vectorType);
- if (valueType != op.getRes().getType())
- return op.emitOpError() << "Type mismatch: extracting from " << vectorType
- << " should produce " << valueType
- << " but this op returns " << op.getRes().getType();
+ if (valueType != getRes().getType())
+ return emitOpError() << "Type mismatch: extracting from " << vectorType
+ << " should produce " << valueType
+ << " but this op returns " << getRes().getType();
return success();
}
@@ -1367,17 +1358,17 @@ OpFoldResult LLVM::ExtractValueOp::fold(ArrayRef<Attribute> operands) {
return {};
}
-static LogicalResult verify(ExtractValueOp op) {
- Type valueType = getInsertExtractValueElementType(op.getContainer().getType(),
- op.getPositionAttr(), op);
+LogicalResult ExtractValueOp::verify() {
+ Type valueType = getInsertExtractValueElementType(getContainer().getType(),
+ getPositionAttr(), *this);
if (!valueType)
return failure();
- if (op.getRes().getType() != valueType)
- return op.emitOpError()
- << "Type mismatch: extracting from " << op.getContainer().getType()
- << " should produce " << valueType << " but this op returns "
- << op.getRes().getType();
+ if (getRes().getType() != valueType)
+ return emitOpError() << "Type mismatch: extracting from "
+ << getContainer().getType() << " should produce "
+ << valueType << " but this op returns "
+ << getRes().getType();
return success();
}
@@ -1423,14 +1414,15 @@ static ParseResult parseInsertElementOp(OpAsmParser &parser,
return success();
}
-static LogicalResult verify(InsertElementOp op) {
- Type valueType = LLVM::getVectorElementType(op.getVector().getType());
- if (valueType != op.getValue().getType())
- return op.emitOpError()
- << "Type mismatch: cannot insert " << op.getValue().getType()
- << " into " << op.getVector().getType();
+LogicalResult InsertElementOp::verify() {
+ Type valueType = LLVM::getVectorElementType(getVector().getType());
+ if (valueType != getValue().getType())
+ return emitOpError() << "Type mismatch: cannot insert "
+ << getValue().getType() << " into "
+ << getVector().getType();
return success();
}
+
//===----------------------------------------------------------------------===//
// Printing/parsing for LLVM::InsertValueOp.
//===----------------------------------------------------------------------===//
@@ -1473,16 +1465,16 @@ static ParseResult parseInsertValueOp(OpAsmParser &parser,
return success();
}
-static LogicalResult verify(InsertValueOp op) {
- Type valueType = getInsertExtractValueElementType(op.getContainer().getType(),
- op.getPositionAttr(), op);
+LogicalResult InsertValueOp::verify() {
+ Type valueType = getInsertExtractValueElementType(getContainer().getType(),
+ getPositionAttr(), *this);
if (!valueType)
return failure();
- if (op.getValue().getType() != valueType)
- return op.emitOpError()
- << "Type mismatch: cannot insert " << op.getValue().getType()
- << " into " << op.getContainer().getType();
+ if (getValue().getType() != valueType)
+ return emitOpError() << "Type mismatch: cannot insert "
+ << getValue().getType() << " into "
+ << getContainer().getType();
return success();
}
@@ -1519,28 +1511,28 @@ static ParseResult parseReturnOp(OpAsmParser &parser, OperationState &result) {
return success();
}
-static LogicalResult verify(ReturnOp op) {
- if (op->getNumOperands() > 1)
- return op->emitOpError("expected at most 1 operand");
+LogicalResult ReturnOp::verify() {
+ if (getNumOperands() > 1)
+ return emitOpError("expected at most 1 operand");
- if (auto parent = op->getParentOfType<LLVMFuncOp>()) {
+ if (auto parent = (*this)->getParentOfType<LLVMFuncOp>()) {
Type expectedType = parent.getType().getReturnType();
if (expectedType.isa<LLVMVoidType>()) {
- if (op->getNumOperands() == 0)
+ if (getNumOperands() == 0)
return success();
- InFlightDiagnostic diag = op->emitOpError("expected no operands");
+ InFlightDiagnostic diag = emitOpError("expected no operands");
diag.attachNote(parent->getLoc()) << "when returning from function";
return diag;
}
- if (op->getNumOperands() == 0) {
+ if (getNumOperands() == 0) {
if (expectedType.isa<LLVMVoidType>())
return success();
- InFlightDiagnostic diag = op->emitOpError("expected 1 operand");
+ InFlightDiagnostic diag = emitOpError("expected 1 operand");
diag.attachNote(parent->getLoc()) << "when returning from function";
return diag;
}
- if (expectedType != op->getOperand(0).getType()) {
- InFlightDiagnostic diag = op->emitOpError("mismatching result types");
+ if (expectedType != getOperand(0).getType()) {
+ InFlightDiagnostic diag = emitOpError("mismatching result types");
diag.attachNote(parent->getLoc()) << "when returning from function";
return diag;
}
@@ -1548,6 +1540,17 @@ static LogicalResult verify(ReturnOp op) {
return success();
}
+//===----------------------------------------------------------------------===//
+// ResumeOp
+//===----------------------------------------------------------------------===//
+
+LogicalResult ResumeOp::verify() {
+ if (!getValue().getDefiningOp<LandingpadOp>())
+ return emitOpError("expects landingpad value as operand");
+ // No check for personality of function - landingpad op verifies it.
+ return success();
+}
+
//===----------------------------------------------------------------------===//
// Verifier for LLVM::AddressOfOp.
//===----------------------------------------------------------------------===//
@@ -1572,22 +1575,22 @@ LLVMFuncOp AddressOfOp::getFunction() {
getGlobalName());
}
-static LogicalResult verify(AddressOfOp op) {
- auto global = op.getGlobal();
- auto function = op.getFunction();
+LogicalResult AddressOfOp::verify() {
+ auto global = getGlobal();
+ auto function = getFunction();
if (!global && !function)
- return op.emitOpError(
+ return emitOpError(
"must reference a global defined by 'llvm.mlir.global' or 'llvm.func'");
if (global &&
LLVM::LLVMPointerType::get(global.getType(), global.getAddrSpace()) !=
- op.getResult().getType())
- return op.emitOpError(
+ getResult().getType())
+ return emitOpError(
"the type must be a pointer to the type of the referenced global");
- if (function && LLVM::LLVMPointerType::get(function.getType()) !=
- op.getResult().getType())
- return op.emitOpError(
+ if (function &&
+ LLVM::LLVMPointerType::get(function.getType()) != getResult().getType())
+ return emitOpError(
"the type must be a pointer to the type of the referenced function");
return success();
@@ -1791,60 +1794,60 @@ static bool isZeroAttribute(Attribute value) {
return false;
}
-static LogicalResult verify(GlobalOp op) {
- if (!LLVMPointerType::isValidElementType(op.getType()))
- return op.emitOpError(
+LogicalResult GlobalOp::verify() {
+ if (!LLVMPointerType::isValidElementType(getType()))
+ return emitOpError(
"expects type to be a valid element type for an LLVM pointer");
- if (op->getParentOp() && !satisfiesLLVMModule(op->getParentOp()))
- return op.emitOpError("must appear at the module level");
+ if ((*this)->getParentOp() && !satisfiesLLVMModule((*this)->getParentOp()))
+ return emitOpError("must appear at the module level");
- if (auto strAttr = op.getValueOrNull().dyn_cast_or_null<StringAttr>()) {
- auto type = op.getType().dyn_cast<LLVMArrayType>();
+ if (auto strAttr = getValueOrNull().dyn_cast_or_null<StringAttr>()) {
+ auto type = getType().dyn_cast<LLVMArrayType>();
IntegerType elementType =
type ? type.getElementType().dyn_cast<IntegerType>() : nullptr;
if (!elementType || elementType.getWidth() != 8 ||
type.getNumElements() != strAttr.getValue().size())
- return op.emitOpError(
+ return emitOpError(
"requires an i8 array type of the length equal to that of the string "
"attribute");
}
- if (Block *b = op.getInitializerBlock()) {
+ if (Block *b = getInitializerBlock()) {
ReturnOp ret = cast<ReturnOp>(b->getTerminator());
if (ret.operand_type_begin() == ret.operand_type_end())
- return op.emitOpError("initializer region cannot return void");
- if (*ret.operand_type_begin() != op.getType())
- return op.emitOpError("initializer region type ")
+ return emitOpError("initializer region cannot return void");
+ if (*ret.operand_type_begin() != getType())
+ return emitOpError("initializer region type ")
<< *ret.operand_type_begin() << " does not match global type "
- << op.getType();
+ << getType();
- if (op.getValueOrNull())
- return op.emitOpError("cannot have both initializer value and region");
+ if (getValueOrNull())
+ return emitOpError("cannot have both initializer value and region");
}
- if (op.getLinkage() == Linkage::Common) {
- if (Attribute value = op.getValueOrNull()) {
+ if (getLinkage() == Linkage::Common) {
+ if (Attribute value = getValueOrNull()) {
if (!isZeroAttribute(value)) {
- return op.emitOpError()
+ return emitOpError()
<< "expected zero value for '"
<< stringifyLinkage(Linkage::Common) << "' linkage";
}
}
}
- if (op.getLinkage() == Linkage::Appending) {
- if (!op.getType().isa<LLVMArrayType>()) {
- return op.emitOpError()
- << "expected array type for '"
- << stringifyLinkage(Linkage::Appending) << "' linkage";
+ if (getLinkage() == Linkage::Appending) {
+ if (!getType().isa<LLVMArrayType>()) {
+ return emitOpError() << "expected array type for '"
+ << stringifyLinkage(Linkage::Appending)
+ << "' linkage";
}
}
- Optional<uint64_t> alignAttr = op.getAlignment();
+ Optional<uint64_t> alignAttr = getAlignment();
if (alignAttr.hasValue()) {
uint64_t value = alignAttr.getValue();
if (!llvm::isPowerOf2_64(value))
- return op->emitError() << "alignment attribute is not a power of 2";
+ return emitError() << "alignment attribute is not a power of 2";
}
return success();
@@ -1864,9 +1867,9 @@ GlobalCtorsOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
return success();
}
-static LogicalResult verify(GlobalCtorsOp op) {
- if (op.getCtors().size() != op.getPriorities().size())
- return op.emitError(
+LogicalResult GlobalCtorsOp::verify() {
+ if (getCtors().size() != getPriorities().size())
+ return emitError(
"mismatch between the number of ctors and the number of priorities");
return success();
}
@@ -1885,9 +1888,9 @@ GlobalDtorsOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
return success();
}
-static LogicalResult verify(GlobalDtorsOp op) {
- if (op.getDtors().size() != op.getPriorities().size())
- return op.emitError(
+LogicalResult GlobalDtorsOp::verify() {
+ if (getDtors().size() != getPriorities().size())
+ return emitError(
"mismatch between the number of dtors and the number of priorities");
return success();
}
@@ -1940,6 +1943,14 @@ static ParseResult parseShuffleVectorOp(OpAsmParser &parser,
return success();
}
+LogicalResult ShuffleVectorOp::verify() {
+ Type type1 = getV1().getType();
+ Type type2 = getV2().getType();
+ if (LLVM::getVectorElementType(type1) != LLVM::getVectorElementType(type2))
+ return emitOpError("expected matching LLVM IR Dialect element types");
+ return success();
+}
+
//===----------------------------------------------------------------------===//
// Implementations for LLVM::LLVMFuncOp.
//===----------------------------------------------------------------------===//
@@ -2117,42 +2128,43 @@ LogicalResult LLVMFuncOp::verifyType() {
// - external functions have 'external' or 'extern_weak' linkage;
// - vararg is (currently) only supported for external functions;
// - entry block arguments are of LLVM types and match the function signature.
-static LogicalResult verify(LLVMFuncOp op) {
- if (op.getLinkage() == LLVM::Linkage::Common)
- return op.emitOpError()
- << "functions cannot have '"
- << stringifyLinkage(LLVM::Linkage::Common) << "' linkage";
+LogicalResult LLVMFuncOp::verify() {
+ if (getLinkage() == LLVM::Linkage::Common)
+ return emitOpError() << "functions cannot have '"
+ << stringifyLinkage(LLVM::Linkage::Common)
+ << "' linkage";
// Check to see if this function has a void return with a result attribute to
// it. It isn't clear what semantics we would assign to that.
- if (op.getType().getReturnType().isa<LLVMVoidType>() &&
- !op.getResultAttrs(0).empty()) {
- return op.emitOpError()
+ if (getType().getReturnType().isa<LLVMVoidType>() &&
+ !getResultAttrs(0).empty()) {
+ return emitOpError()
<< "cannot attach result attributes to functions with a void return";
}
- if (op.isExternal()) {
- if (op.getLinkage() != LLVM::Linkage::External &&
- op.getLinkage() != LLVM::Linkage::ExternWeak)
- return op.emitOpError()
- << "external functions must have '"
- << stringifyLinkage(LLVM::Linkage::External) << "' or '"
- << stringifyLinkage(LLVM::Linkage::ExternWeak) << "' linkage";
+ if (isExternal()) {
+ if (getLinkage() != LLVM::Linkage::External &&
+ getLinkage() != LLVM::Linkage::ExternWeak)
+ return emitOpError() << "external functions must have '"
+ << stringifyLinkage(LLVM::Linkage::External)
+ << "' or '"
+ << stringifyLinkage(LLVM::Linkage::ExternWeak)
+ << "' linkage";
return success();
}
- if (op.isVarArg())
- return op.emitOpError("only external functions can be variadic");
+ if (isVarArg())
+ return emitOpError("only external functions can be variadic");
- unsigned numArguments = op.getType().getNumParams();
- Block &entryBlock = op.front();
+ unsigned numArguments = getType().getNumParams();
+ Block &entryBlock = front();
for (unsigned i = 0; i < numArguments; ++i) {
Type argType = entryBlock.getArgument(i).getType();
if (!isCompatibleType(argType))
- return op.emitOpError("entry block argument #")
+ return emitOpError("entry block argument #")
<< i << " is not of LLVM type";
- if (op.getType().getParamType(i) != argType)
- return op.emitOpError("the type of entry block argument #")
+ if (getType().getParamType(i) != argType)
+ return emitOpError("the type of entry block argument #")
<< i << " does not match the function signature";
}
@@ -2163,42 +2175,42 @@ static LogicalResult verify(LLVMFuncOp op) {
// Verification for LLVM::ConstantOp.
//===----------------------------------------------------------------------===//
-static LogicalResult verify(LLVM::ConstantOp op) {
- if (StringAttr sAttr = op.getValue().dyn_cast<StringAttr>()) {
- auto arrayType = op.getType().dyn_cast<LLVMArrayType>();
+LogicalResult LLVM::ConstantOp::verify() {
+ if (StringAttr sAttr = getValue().dyn_cast<StringAttr>()) {
+ auto arrayType = getType().dyn_cast<LLVMArrayType>();
if (!arrayType || arrayType.getNumElements() != sAttr.getValue().size() ||
!arrayType.getElementType().isInteger(8)) {
- return op->emitOpError()
- << "expected array type of " << sAttr.getValue().size()
- << " i8 elements for the string constant";
+ return emitOpError() << "expected array type of "
+ << sAttr.getValue().size()
+ << " i8 elements for the string constant";
}
return success();
}
- if (auto structType = op.getType().dyn_cast<LLVMStructType>()) {
+ if (auto structType = getType().dyn_cast<LLVMStructType>()) {
if (structType.getBody().size() != 2 ||
structType.getBody()[0] != structType.getBody()[1]) {
- return op.emitError() << "expected struct type with two elements of the "
- "same type, the type of a complex constant";
+ return emitError() << "expected struct type with two elements of the "
+ "same type, the type of a complex constant";
}
- auto arrayAttr = op.getValue().dyn_cast<ArrayAttr>();
+ auto arrayAttr = getValue().dyn_cast<ArrayAttr>();
if (!arrayAttr || arrayAttr.size() != 2 ||
arrayAttr[0].getType() != arrayAttr[1].getType()) {
- return op.emitOpError() << "expected array attribute with two elements, "
- "representing a complex constant";
+ return emitOpError() << "expected array attribute with two elements, "
+ "representing a complex constant";
}
Type elementType = structType.getBody()[0];
if (!elementType
.isa<IntegerType, Float16Type, Float32Type, Float64Type>()) {
- return op.emitError()
+ return emitError()
<< "expected struct element types to be floating point type or "
"integer type";
}
return success();
}
- if (!op.getValue().isa<IntegerAttr, ArrayAttr, FloatAttr, ElementsAttr>())
- return op.emitOpError()
+ if (!getValue().isa<IntegerAttr, ArrayAttr, FloatAttr, ElementsAttr>())
+ return emitOpError()
<< "only supports integer, float, string or elements attributes";
return success();
}
@@ -2294,42 +2306,40 @@ static ParseResult parseAtomicRMWOp(OpAsmParser &parser,
return success();
}
-static LogicalResult verify(AtomicRMWOp op) {
- auto ptrType = op.getPtr().getType().cast<LLVM::LLVMPointerType>();
- auto valType = op.getVal().getType();
+LogicalResult AtomicRMWOp::verify() {
+ auto ptrType = getPtr().getType().cast<LLVM::LLVMPointerType>();
+ auto valType = getVal().getType();
if (valType != ptrType.getElementType())
- return op.emitOpError("expected LLVM IR element type for operand #0 to "
- "match type for operand #1");
- auto resType = op.getRes().getType();
+ return emitOpError("expected LLVM IR element type for operand #0 to "
+ "match type for operand #1");
+ auto resType = getRes().getType();
if (resType != valType)
- return op.emitOpError(
+ return emitOpError(
"expected LLVM IR result type to match type for operand #1");
- if (op.getBinOp() == AtomicBinOp::fadd ||
- op.getBinOp() == AtomicBinOp::fsub) {
+ if (getBinOp() == AtomicBinOp::fadd || getBinOp() == AtomicBinOp::fsub) {
if (!mlir::LLVM::isCompatibleFloatingPointType(valType))
- return op.emitOpError("expected LLVM IR floating point type");
- } else if (op.getBinOp() == AtomicBinOp::xchg) {
+ return emitOpError("expected LLVM IR floating point type");
+ } else if (getBinOp() == AtomicBinOp::xchg) {
auto intType = valType.dyn_cast<IntegerType>();
unsigned intBitWidth = intType ? intType.getWidth() : 0;
if (intBitWidth != 8 && intBitWidth != 16 && intBitWidth != 32 &&
intBitWidth != 64 && !valType.isa<BFloat16Type>() &&
!valType.isa<Float16Type>() && !valType.isa<Float32Type>() &&
!valType.isa<Float64Type>())
- return op.emitOpError("unexpected LLVM IR type for 'xchg' bin_op");
+ return emitOpError("unexpected LLVM IR type for 'xchg' bin_op");
} else {
auto intType = valType.dyn_cast<IntegerType>();
unsigned intBitWidth = intType ? intType.getWidth() : 0;
if (intBitWidth != 8 && intBitWidth != 16 && intBitWidth != 32 &&
intBitWidth != 64)
- return op.emitOpError("expected LLVM IR integer type");
+ return emitOpError("expected LLVM IR integer type");
}
- if (static_cast<unsigned>(op.getOrdering()) <
+ if (static_cast<unsigned>(getOrdering()) <
static_cast<unsigned>(AtomicOrdering::monotonic))
- return op.emitOpError()
- << "expected at least '"
- << stringifyAtomicOrdering(AtomicOrdering::monotonic)
- << "' ordering";
+ return emitOpError() << "expected at least '"
+ << stringifyAtomicOrdering(AtomicOrdering::monotonic)
+ << "' ordering";
return success();
}
@@ -2375,28 +2385,28 @@ static ParseResult parseAtomicCmpXchgOp(OpAsmParser &parser,
return success();
}
-static LogicalResult verify(AtomicCmpXchgOp op) {
- auto ptrType = op.getPtr().getType().cast<LLVM::LLVMPointerType>();
+LogicalResult AtomicCmpXchgOp::verify() {
+ auto ptrType = getPtr().getType().cast<LLVM::LLVMPointerType>();
if (!ptrType)
- return op.emitOpError("expected LLVM IR pointer type for operand #0");
- auto cmpType = op.getCmp().getType();
- auto valType = op.getVal().getType();
+ return emitOpError("expected LLVM IR pointer type for operand #0");
+ auto cmpType = getCmp().getType();
+ auto valType = getVal().getType();
if (cmpType != ptrType.getElementType() || cmpType != valType)
- return op.emitOpError("expected LLVM IR element type for operand #0 to "
- "match type for all other operands");
+ return emitOpError("expected LLVM IR element type for operand #0 to "
+ "match type for all other operands");
auto intType = valType.dyn_cast<IntegerType>();
unsigned intBitWidth = intType ? intType.getWidth() : 0;
if (!valType.isa<LLVMPointerType>() && intBitWidth != 8 &&
intBitWidth != 16 && intBitWidth != 32 && intBitWidth != 64 &&
!valType.isa<BFloat16Type>() && !valType.isa<Float16Type>() &&
!valType.isa<Float32Type>() && !valType.isa<Float64Type>())
- return op.emitOpError("unexpected LLVM IR type");
- if (op.getSuccessOrdering() < AtomicOrdering::monotonic ||
- op.getFailureOrdering() < AtomicOrdering::monotonic)
- return op.emitOpError("ordering must be at least 'monotonic'");
- if (op.getFailureOrdering() == AtomicOrdering::release ||
- op.getFailureOrdering() == AtomicOrdering::acq_rel)
- return op.emitOpError("failure ordering cannot be 'release' or 'acq_rel'");
+ return emitOpError("unexpected LLVM IR type");
+ if (getSuccessOrdering() < AtomicOrdering::monotonic ||
+ getFailureOrdering() < AtomicOrdering::monotonic)
+ return emitOpError("ordering must be at least 'monotonic'");
+ if (getFailureOrdering() == AtomicOrdering::release ||
+ getFailureOrdering() == AtomicOrdering::acq_rel)
+ return emitOpError("failure ordering cannot be 'release' or 'acq_rel'");
return success();
}
@@ -2432,12 +2442,12 @@ static void printFenceOp(OpAsmPrinter &p, FenceOp &op) {
p << stringifyAtomicOrdering(op.getOrdering());
}
-static LogicalResult verify(FenceOp &op) {
- if (op.getOrdering() == AtomicOrdering::not_atomic ||
- op.getOrdering() == AtomicOrdering::unordered ||
- op.getOrdering() == AtomicOrdering::monotonic)
- return op.emitOpError("can be given only acquire, release, acq_rel, "
- "and seq_cst orderings");
+LogicalResult FenceOp::verify() {
+ if (getOrdering() == AtomicOrdering::not_atomic ||
+ getOrdering() == AtomicOrdering::unordered ||
+ getOrdering() == AtomicOrdering::monotonic)
+ return emitOpError("can be given only acquire, release, acq_rel, "
+ "and seq_cst orderings");
return success();
}
diff --git a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
index 811a09aac173e..5d5e8f4012122 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
@@ -62,8 +62,14 @@ static ParseResult parseNVVMVoteBallotOp(OpAsmParser &parser,
parser.getNameLoc(), result.operands));
}
-static LogicalResult verify(MmaOp op) {
- MLIRContext *context = op.getContext();
+LogicalResult CpAsyncOp::verify() {
+ if (size() != 4 && size() != 8 && size() != 16)
+ return emitError("expected byte size to be either 4, 8 or 16.");
+ return success();
+}
+
+LogicalResult MmaOp::verify() {
+ MLIRContext *context = getContext();
auto f16Ty = Float16Type::get(context);
auto f16x2Ty = LLVM::getFixedVectorType(f16Ty, 2);
auto f32Ty = Float32Type::get(context);
@@ -72,44 +78,55 @@ static LogicalResult verify(MmaOp op) {
auto f32x8StructTy = LLVM::LLVMStructType::getLiteral(
context, {f32Ty, f32Ty, f32Ty, f32Ty, f32Ty, f32Ty, f32Ty, f32Ty});
- SmallVector<Type, 12> operandTypes(op.getOperandTypes().begin(),
- op.getOperandTypes().end());
+ auto operandTypes = getOperandTypes();
if (operandTypes != SmallVector<Type, 8>(8, f16x2Ty) &&
- operandTypes != SmallVector<Type, 12>{f16x2Ty, f16x2Ty, f16x2Ty, f16x2Ty,
- f32Ty, f32Ty, f32Ty, f32Ty, f32Ty,
- f32Ty, f32Ty, f32Ty}) {
- return op.emitOpError(
- "expected operands to be 4 <halfx2>s followed by either "
- "4 <halfx2>s or 8 floats");
+ operandTypes != ArrayRef<Type>{f16x2Ty, f16x2Ty, f16x2Ty, f16x2Ty, f32Ty,
+ f32Ty, f32Ty, f32Ty, f32Ty, f32Ty, f32Ty,
+ f32Ty}) {
+ return emitOpError("expected operands to be 4 <halfx2>s followed by either "
+ "4 <halfx2>s or 8 floats");
}
- if (op.getType() != f32x8StructTy && op.getType() != f16x2x4StructTy) {
- return op.emitOpError("expected result type to be a struct of either 4 "
- "<halfx2>s or 8 floats");
+ if (getType() != f32x8StructTy && getType() != f16x2x4StructTy) {
+ return emitOpError("expected result type to be a struct of either 4 "
+ "<halfx2>s or 8 floats");
}
- auto alayout = op->getAttrOfType<StringAttr>("alayout");
- auto blayout = op->getAttrOfType<StringAttr>("blayout");
+ auto alayout = (*this)->getAttrOfType<StringAttr>("alayout");
+ auto blayout = (*this)->getAttrOfType<StringAttr>("blayout");
if (!(alayout && blayout) ||
!(alayout.getValue() == "row" || alayout.getValue() == "col") ||
!(blayout.getValue() == "row" || blayout.getValue() == "col")) {
- return op.emitOpError(
- "alayout and blayout attributes must be set to either "
- "\"row\" or \"col\"");
+ return emitOpError("alayout and blayout attributes must be set to either "
+ "\"row\" or \"col\"");
}
- if (operandTypes == SmallVector<Type, 12>{f16x2Ty, f16x2Ty, f16x2Ty, f16x2Ty,
- f32Ty, f32Ty, f32Ty, f32Ty, f32Ty,
- f32Ty, f32Ty, f32Ty} &&
- op.getType() == f32x8StructTy && alayout.getValue() == "row" &&
+ if (operandTypes == ArrayRef<Type>{f16x2Ty, f16x2Ty, f16x2Ty, f16x2Ty, f32Ty,
+ f32Ty, f32Ty, f32Ty, f32Ty, f32Ty, f32Ty,
+ f32Ty} &&
+ getType() == f32x8StructTy && alayout.getValue() == "row" &&
blayout.getValue() == "col") {
return success();
}
- return op.emitOpError("unimplemented mma.sync variant");
+ return emitOpError("unimplemented mma.sync variant");
+}
+
+LogicalResult ShflOp::verify() {
+ if (!(*this)->getAttrOfType<UnitAttr>("return_value_and_is_valid"))
+ return success();
+ auto type = getType().dyn_cast<LLVM::LLVMStructType>();
+ auto elementType = (type && type.getBody().size() == 2)
+ ? type.getBody()[1].dyn_cast<IntegerType>()
+ : nullptr;
+ if (!elementType || elementType.getWidth() != 1)
+ return emitError("expected return type to be a two-element struct with "
+ "i1 as the second element");
+ return success();
}
-std::pair<mlir::Type, unsigned>
-inferMMAType(NVVM::MMATypes type, NVVM::MMAFrag frag, MLIRContext *context) {
+std::pair<mlir::Type, unsigned> NVVM::inferMMAType(NVVM::MMATypes type,
+ NVVM::MMAFrag frag,
+ MLIRContext *context) {
unsigned numberElements = 0;
Type elementType;
OpBuilder builder(context);
@@ -131,76 +148,72 @@ inferMMAType(NVVM::MMATypes type, NVVM::MMAFrag frag, MLIRContext *context) {
return std::make_pair(elementType, numberElements);
}
-static LogicalResult verify(NVVM::WMMALoadOp op) {
+LogicalResult NVVM::WMMALoadOp::verify() {
unsigned addressSpace =
- op.ptr().getType().cast<LLVM::LLVMPointerType>().getAddressSpace();
+ ptr().getType().cast<LLVM::LLVMPointerType>().getAddressSpace();
if (addressSpace != 0 && addressSpace != 1 && addressSpace != 3)
- return op.emitOpError("expected source pointer in memory "
- "space 0, 1, 3");
+ return emitOpError("expected source pointer in memory "
+ "space 0, 1, 3");
- if (NVVM::WMMALoadOp::getIntrinsicID(op.m(), op.n(), op.k(), op.layout(),
- op.eltype(), op.frag()) == 0)
- return op.emitOpError() << "invalid attribute combination";
+ if (NVVM::WMMALoadOp::getIntrinsicID(m(), n(), k(), layout(), eltype(),
+ frag()) == 0)
+ return emitOpError() << "invalid attribute combination";
std::pair<Type, unsigned> typeInfo =
- inferMMAType(op.eltype(), op.frag(), op.getContext());
+ inferMMAType(eltype(), frag(), getContext());
Type dstType = LLVM::LLVMStructType::getLiteral(
- op.getContext(), SmallVector<Type, 8>(typeInfo.second, typeInfo.first));
- if (op.getType() != dstType)
- return op.emitOpError("expected destination type is a structure of ")
+ getContext(), SmallVector<Type, 8>(typeInfo.second, typeInfo.first));
+ if (getType() != dstType)
+ return emitOpError("expected destination type is a structure of ")
<< typeInfo.second << " elements of type " << typeInfo.first;
return success();
}
-static LogicalResult verify(NVVM::WMMAStoreOp op) {
+LogicalResult NVVM::WMMAStoreOp::verify() {
unsigned addressSpace =
- op.ptr().getType().cast<LLVM::LLVMPointerType>().getAddressSpace();
+ ptr().getType().cast<LLVM::LLVMPointerType>().getAddressSpace();
if (addressSpace != 0 && addressSpace != 1 && addressSpace != 3)
- return op.emitOpError("expected operands to be a source pointer in memory "
- "space 0, 1, 3");
+ return emitOpError("expected operands to be a source pointer in memory "
+ "space 0, 1, 3");
- if (NVVM::WMMAStoreOp::getIntrinsicID(op.m(), op.n(), op.k(), op.layout(),
- op.eltype()) == 0)
- return op.emitOpError() << "invalid attribute combination";
+ if (NVVM::WMMAStoreOp::getIntrinsicID(m(), n(), k(), layout(), eltype()) == 0)
+ return emitOpError() << "invalid attribute combination";
std::pair<Type, unsigned> typeInfo =
- inferMMAType(op.eltype(), NVVM::MMAFrag::c, op.getContext());
- if (op.args().size() != typeInfo.second)
- return op.emitOpError()
- << "expected " << typeInfo.second << " data operands";
- if (llvm::any_of(op.args(), [&typeInfo](Value operands) {
+ inferMMAType(eltype(), NVVM::MMAFrag::c, getContext());
+ if (args().size() != typeInfo.second)
+ return emitOpError() << "expected " << typeInfo.second << " data operands";
+ if (llvm::any_of(args(), [&typeInfo](Value operands) {
return operands.getType() != typeInfo.first;
}))
- return op.emitOpError()
- << "expected data operands of type " << typeInfo.first;
+ return emitOpError() << "expected data operands of type " << typeInfo.first;
return success();
}
-static LogicalResult verify(NVVM::WMMAMmaOp op) {
- if (NVVM::WMMAMmaOp::getIntrinsicID(op.m(), op.n(), op.k(), op.layoutA(),
- op.layoutB(), op.eltypeA(),
- op.eltypeB()) == 0)
- return op.emitOpError() << "invalid attribute combination";
+LogicalResult NVVM::WMMAMmaOp::verify() {
+ if (NVVM::WMMAMmaOp::getIntrinsicID(m(), n(), k(), layoutA(), layoutB(),
+ eltypeA(), eltypeB()) == 0)
+ return emitOpError() << "invalid attribute combination";
std::pair<Type, unsigned> typeInfoA =
- inferMMAType(op.eltypeA(), NVVM::MMAFrag::a, op.getContext());
+ inferMMAType(eltypeA(), NVVM::MMAFrag::a, getContext());
std::pair<Type, unsigned> typeInfoB =
- inferMMAType(op.eltypeA(), NVVM::MMAFrag::b, op.getContext());
+ inferMMAType(eltypeA(), NVVM::MMAFrag::b, getContext());
std::pair<Type, unsigned> typeInfoC =
- inferMMAType(op.eltypeB(), NVVM::MMAFrag::c, op.getContext());
+ inferMMAType(eltypeB(), NVVM::MMAFrag::c, getContext());
SmallVector<Type, 32> arguments;
arguments.append(typeInfoA.second, typeInfoA.first);
arguments.append(typeInfoB.second, typeInfoB.first);
arguments.append(typeInfoC.second, typeInfoC.first);
unsigned numArgs = arguments.size();
- if (op.args().size() != numArgs)
- return op.emitOpError() << "expected " << numArgs << " arguments";
+ if (args().size() != numArgs)
+ return emitOpError() << "expected " << numArgs << " arguments";
for (unsigned i = 0; i < numArgs; i++) {
- if (op.args()[i].getType() != arguments[i])
- return op.emitOpError()
- << "expected argument " << i << " to be of type " << arguments[i];
+ if (args()[i].getType() != arguments[i])
+ return emitOpError() << "expected argument " << i << " to be of type "
+ << arguments[i];
}
Type dstType = LLVM::LLVMStructType::getLiteral(
- op.getContext(), SmallVector<Type, 8>(typeInfoC.second, typeInfoC.first));
- if (op.getType() != dstType)
- return op.emitOpError("expected destination type is a structure of ")
+ getContext(), SmallVector<Type, 8>(typeInfoC.second, typeInfoC.first));
+ if (getType() != dstType)
+ return emitOpError("expected destination type is a structure of ")
<< typeInfoC.second << " elements of type " << typeInfoC.first;
return success();
}
diff --git a/mlir/lib/Dialect/X86Vector/IR/X86VectorDialect.cpp b/mlir/lib/Dialect/X86Vector/IR/X86VectorDialect.cpp
index ee7b0580cc483..7b70e53a6e9c3 100644
--- a/mlir/lib/Dialect/X86Vector/IR/X86VectorDialect.cpp
+++ b/mlir/lib/Dialect/X86Vector/IR/X86VectorDialect.cpp
@@ -28,17 +28,15 @@ void x86vector::X86VectorDialect::initialize() {
>();
}
-static LogicalResult verify(x86vector::MaskCompressOp op) {
- if (op.src() && op.constant_src())
- return emitError(op.getLoc(), "cannot use both src and constant_src");
+LogicalResult x86vector::MaskCompressOp::verify() {
+ if (src() && constant_src())
+ return emitError("cannot use both src and constant_src");
- if (op.src() && (op.src().getType() != op.dst().getType()))
- return emitError(op.getLoc(),
- "failed to verify that src and dst have same type");
+ if (src() && (src().getType() != dst().getType()))
+ return emitError("failed to verify that src and dst have same type");
- if (op.constant_src() && (op.constant_src()->getType() != op.dst().getType()))
+ if (constant_src() && (constant_src()->getType() != dst().getType()))
return emitError(
- op.getLoc(),
"failed to verify that constant_src and dst have same type");
return success();
More information about the Mlir-commits
mailing list