[Mlir-commits] [mlir] 53b3be7 - [mlir][spirv] Fix coop matrix load (#65712)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Fri Sep 8 10:29:16 PDT 2023
Author: Jakub Kuderski
Date: 2023-09-08T13:29:12-04:00
New Revision: 53b3be7ecb8fbd5c5a190fd972191cd94efab571
URL: https://github.com/llvm/llvm-project/commit/53b3be7ecb8fbd5c5a190fd972191cd94efab571
DIFF: https://github.com/llvm/llvm-project/commit/53b3be7ecb8fbd5c5a190fd972191cd94efab571.diff
LOG: [mlir][spirv] Fix coop matrix load (#65712)
- Fix order of operands/attributes
- Allow for stride to be any integer type
- Use ODS for parsing/printing
- Update examples and tests
- Fix a typo in SPIR-V tblgen code
Added:
Modified:
mlir/include/mlir/Dialect/SPIRV/IR/SPIRVCooperativeMatrixOps.td
mlir/lib/Dialect/SPIRV/IR/CooperativeMatrixOps.cpp
mlir/test/Dialect/SPIRV/IR/cooperative-matrix-ops.mlir
mlir/tools/mlir-tblgen/SPIRVUtilsGen.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVCooperativeMatrixOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVCooperativeMatrixOps.td
index 7060aa80dc113ed..31d26706faecb1d 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVCooperativeMatrixOps.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVCooperativeMatrixOps.td
@@ -101,20 +101,32 @@ def SPIRV_KHRCooperativeMatrixLoadOp : SPIRV_KhrVendorOp<"CooperativeMatrixLoad"
``` {.ebnf}
cooperative-matrix-load-op ::= ssa-id `=` `spirv.KHR.CooperativeMatrixLoad`
ssa-use `,` ssa-use `,`
- cooperative-matrix-layout `,`
- (`[` memory-operand `]`)? ` : `
- pointer-type `as` cooperative-matrix-type
+ `<` cooperative-matrix-layout `>`
+ (`,` `<` memory-operand `>`)? `:`
+ pointer-type `,` stride-type `->` cooperative-matrix-type
```
+ TODO: In the SPIR-V spec, `stride` is an optional argument. We should also
+ support this optionality in the SPIR-V dialect.
+
#### Example:
```
- %0 = spirv.KHR.CooperativeMatrixLoad %ptr, %stride, RowMajor
- : !spirv.ptr<i32, StorageBuffer>
- as !spirv.KHR.coopmatrix<16x8xi32, Workgroup, MatrixA>
+ %0 = spirv.KHR.CooperativeMatrixLoad %ptr, %stride, <RowMajor>
+ : !spirv.ptr<i32, StorageBuffer>, i32
+ -> !spirv.KHR.coopmatrix<16x8xi32, Workgroup, MatrixA>
+
+ %1 = spirv.KHR.CooperativeMatrixLoad %ptr, %stride, <ColumnMajor>, <Volatile>
+ : !spirv.ptr<f32, StorageBuffer>, i64
+ -> !spirv.KHR.coopmatrix<8x8xf32, Subgroup, MatrixAcc>
```
}];
+ let assemblyFormat = [{
+ $pointer `,` $stride `,` $matrix_layout ( `,` $memory_operand^ )? attr-dict `:`
+ type(operands) `->` type($result)
+ }];
+
let availability = [
MinVersion<SPIRV_V_1_6>,
MaxVersion<SPIRV_V_1_6>,
@@ -124,8 +136,8 @@ def SPIRV_KHRCooperativeMatrixLoadOp : SPIRV_KhrVendorOp<"CooperativeMatrixLoad"
let arguments = (ins
SPIRV_AnyPtr:$pointer,
- SPIRV_Integer:$stride,
SPIRV_KHR_CooperativeMatrixLayoutAttr:$matrix_layout,
+ SPIRV_Integer:$stride,
OptionalAttr<SPIRV_MemoryAccessAttr>:$memory_operand
);
diff --git a/mlir/lib/Dialect/SPIRV/IR/CooperativeMatrixOps.cpp b/mlir/lib/Dialect/SPIRV/IR/CooperativeMatrixOps.cpp
index bc1d30f55518300..4f986065d8d9cd3 100644
--- a/mlir/lib/Dialect/SPIRV/IR/CooperativeMatrixOps.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/CooperativeMatrixOps.cpp
@@ -37,49 +37,6 @@ LogicalResult KHRCooperativeMatrixLengthOp::verify() {
// spirv.KHR.CooperativeMatrixLoad
//===----------------------------------------------------------------------===//
-ParseResult KHRCooperativeMatrixLoadOp::parse(OpAsmParser &parser,
- OperationState &result) {
- std::array<OpAsmParser::UnresolvedOperand, 2> operandInfo = {};
- if (parser.parseOperand(operandInfo[0]) || parser.parseComma())
- return failure();
- if (parser.parseOperand(operandInfo[1]) || parser.parseComma())
- return failure();
-
- CooperativeMatrixLayoutKHR layout;
- if (parseEnumKeywordAttr<CooperativeMatrixLayoutKHRAttr>(
- layout, parser, result, kKhrCooperativeMatrixLayoutAttrName)) {
- return failure();
- }
-
- if (parseMemoryAccessAttributes(parser, result, kMemoryOperandAttrName))
- return failure();
-
- Type ptrType;
- Type elementType;
- if (parser.parseColon() || parser.parseType(ptrType) ||
- parser.parseKeywordType("as", elementType)) {
- return failure();
- }
- result.addTypes(elementType);
-
- Type strideType = parser.getBuilder().getIntegerType(32);
- if (parser.resolveOperands(operandInfo, {ptrType, strideType},
- parser.getNameLoc(), result.operands)) {
- return failure();
- }
-
- return success();
-}
-
-void KHRCooperativeMatrixLoadOp::print(OpAsmPrinter &printer) {
- printer << " " << getPointer() << ", " << getStride() << ", "
- << getMatrixLayout();
- // Print optional memory operand attribute.
- if (auto memOperand = getMemoryOperand())
- printer << " [\"" << memOperand << "\"]";
- printer << " : " << getPointer().getType() << " as " << getType();
-}
-
static LogicalResult verifyPointerAndCoopMatrixType(Operation *op, Type pointer,
Type coopMatrix) {
auto pointerType = cast<PointerType>(pointer);
diff --git a/mlir/test/Dialect/SPIRV/IR/cooperative-matrix-ops.mlir b/mlir/test/Dialect/SPIRV/IR/cooperative-matrix-ops.mlir
index aa6e072b03c5d3d..aad1e44bf8f7bd1 100644
--- a/mlir/test/Dialect/SPIRV/IR/cooperative-matrix-ops.mlir
+++ b/mlir/test/Dialect/SPIRV/IR/cooperative-matrix-ops.mlir
@@ -23,37 +23,46 @@ spirv.func @cooperative_matrix_length_wrong_matrix() -> i32 "None" {
// CHECK-LABEL: @cooperative_matrix_load
spirv.func @cooperative_matrix_load(%ptr : !spirv.ptr<i32, StorageBuffer>, %stride : i32) "None" {
- // CHECK: {{%.*}} = spirv.KHR.CooperativeMatrixLoad {{%.*}}, {{%.*}}, RowMajor :
- // CHECK-SAME: !spirv.ptr<i32, StorageBuffer> as !spirv.coopmatrix<16x8xi32, Workgroup, MatrixA>
- %0 = spirv.KHR.CooperativeMatrixLoad %ptr, %stride, RowMajor :
- !spirv.ptr<i32, StorageBuffer> as !spirv.coopmatrix<16x8xi32, Workgroup, MatrixA>
+ // CHECK: {{%.*}} = spirv.KHR.CooperativeMatrixLoad {{%.*}}, {{%.*}}, <RowMajor> :
+ // CHECK-SAME: !spirv.ptr<i32, StorageBuffer>, i32 -> !spirv.coopmatrix<16x8xi32, Workgroup, MatrixA>
+ %0 = spirv.KHR.CooperativeMatrixLoad %ptr, %stride, <RowMajor> :
+ !spirv.ptr<i32, StorageBuffer>, i32 -> !spirv.coopmatrix<16x8xi32, Workgroup, MatrixA>
spirv.Return
}
// CHECK-LABEL: @cooperative_matrix_load_memoperand
spirv.func @cooperative_matrix_load_memoperand(%ptr : !spirv.ptr<i32, StorageBuffer>, %stride : i32) "None" {
- // CHECK: {{%.*}} = spirv.KHR.CooperativeMatrixLoad {{%.*}}, {{%.*}}, ColumnMajor ["Volatile"] :
- // CHECK-SAME: !spirv.ptr<i32, StorageBuffer> as !spirv.coopmatrix<16x8xi32, Workgroup, MatrixA>
- %0 = spirv.KHR.CooperativeMatrixLoad %ptr, %stride, ColumnMajor ["Volatile"] :
- !spirv.ptr<i32, StorageBuffer> as !spirv.coopmatrix<16x8xi32, Workgroup, MatrixA>
+ // CHECK: {{%.*}} = spirv.KHR.CooperativeMatrixLoad {{%.*}}, {{%.*}}, <ColumnMajor>, <Volatile> :
+ // CHECK-SAME: !spirv.ptr<i32, StorageBuffer>, i32 -> !spirv.coopmatrix<16x8xi32, Workgroup, MatrixA>
+ %0 = spirv.KHR.CooperativeMatrixLoad %ptr, %stride, <ColumnMajor>, <Volatile> :
+ !spirv.ptr<i32, StorageBuffer>, i32 -> !spirv.coopmatrix<16x8xi32, Workgroup, MatrixA>
spirv.Return
}
// CHECK-LABEL: @cooperative_matrix_load_vector_ptr_type
spirv.func @cooperative_matrix_load_vector_ptr_type(%ptr : !spirv.ptr<vector<4xi32>, StorageBuffer>, %stride : i32) "None" {
- // CHECK: {{%.*}} = spirv.KHR.CooperativeMatrixLoad {{%.*}}, {{%.*}}, RowMajor ["Volatile"] :
- // CHECK-SAME: !spirv.ptr<vector<4xi32>, StorageBuffer> as !spirv.coopmatrix<8x16xi32, Subgroup, MatrixB>
- %0 = spirv.KHR.CooperativeMatrixLoad %ptr, %stride, RowMajor ["Volatile"] :
- !spirv.ptr<vector<4xi32>, StorageBuffer> as !spirv.coopmatrix<8x16xi32, Subgroup, MatrixB>
+ // CHECK: {{%.*}} = spirv.KHR.CooperativeMatrixLoad {{%.*}}, {{%.*}}, <RowMajor>, <Volatile> :
+ // CHECK-SAME: !spirv.ptr<vector<4xi32>, StorageBuffer>, i32 -> !spirv.coopmatrix<8x16xi32, Subgroup, MatrixB>
+ %0 = spirv.KHR.CooperativeMatrixLoad %ptr, %stride, <RowMajor>, <Volatile> :
+ !spirv.ptr<vector<4xi32>, StorageBuffer>, i32 -> !spirv.coopmatrix<8x16xi32, Subgroup, MatrixB>
spirv.Return
}
// CHECK-LABEL: @cooperative_matrix_load_function
spirv.func @cooperative_matrix_load_function(%ptr : !spirv.ptr<i32, Function>, %stride : i32) "None" {
- // CHECK: {{%.*}} = spirv.KHR.CooperativeMatrixLoad {{%.*}}, {{%.*}}, RowMajor :
- // CHECK-SAME: !spirv.ptr<i32, Function> as !spirv.coopmatrix<8x16xi32, Subgroup, MatrixAcc>
- %0 = spirv.KHR.CooperativeMatrixLoad %ptr, %stride, RowMajor :
- !spirv.ptr<i32, Function> as !spirv.coopmatrix<8x16xi32, Subgroup, MatrixAcc>
+ // CHECK: {{%.*}} = spirv.KHR.CooperativeMatrixLoad {{%.*}}, {{%.*}}, <RowMajor> :
+ // CHECK-SAME: !spirv.ptr<i32, Function>, i32 -> !spirv.coopmatrix<8x16xi32, Subgroup, MatrixAcc>
+ %0 = spirv.KHR.CooperativeMatrixLoad %ptr, %stride, <RowMajor> :
+ !spirv.ptr<i32, Function>, i32 -> !spirv.coopmatrix<8x16xi32, Subgroup, MatrixAcc>
+ spirv.Return
+}
+
+// CHECK-LABEL: @cooperative_matrix_load_stride_i16
+spirv.func @cooperative_matrix_load_stride_i16(%ptr : !spirv.ptr<i32, StorageBuffer>, %stride : i16) "None" {
+ // CHECK: {{%.*}} = spirv.KHR.CooperativeMatrixLoad {{%.*}}, {{%.*}}, <RowMajor> :
+ // CHECK-SAME: !spirv.ptr<i32, StorageBuffer>, i16 -> !spirv.coopmatrix<16x8xi32, Workgroup, MatrixA>
+ %0 = spirv.KHR.CooperativeMatrixLoad %ptr, %stride, <RowMajor> :
+ !spirv.ptr<i32, StorageBuffer>, i16 -> !spirv.coopmatrix<16x8xi32, Workgroup, MatrixA>
spirv.Return
}
@@ -82,8 +91,8 @@ spirv.func @cooperative_matrix_store_memoperand(%ptr : !spirv.ptr<i32, StorageBu
spirv.func @cooperative_matrix_load_bad_ptr(%ptr : !spirv.ptr<!spirv.struct<(f32 [0])>, StorageBuffer>, %stride : i32) "None" {
// expected-error @+1 {{Pointer must point to a scalar or vector type}}
- %0 = spirv.KHR.CooperativeMatrixLoad %ptr, %stride, ColumnMajor :
- !spirv.ptr<!spirv.struct<(f32 [0])>, StorageBuffer> as !spirv.coopmatrix<8x16xi32, Subgroup, MatrixA>
+ %0 = spirv.KHR.CooperativeMatrixLoad %ptr, %stride, <ColumnMajor> :
+ !spirv.ptr<!spirv.struct<(f32 [0])>, StorageBuffer>, i32 -> !spirv.coopmatrix<8x16xi32, Subgroup, MatrixA>
spirv.Return
}
@@ -92,16 +101,16 @@ spirv.func @cooperative_matrix_load_bad_ptr(%ptr : !spirv.ptr<!spirv.struct<(f32
spirv.func @cooperative_matrix_load_missing_attr(%ptr : !spirv.ptr<i32, StorageBuffer>, %stride : i32) "None" {
// expected-error @+1 {{expected ','}}
%0 = spirv.KHR.CooperativeMatrixLoad %ptr, %stride :
- !spirv.ptr<i32, StorageBuffer> as !spirv.coopmatrix<8x16xi32, Subgroup, MatrixA>
+ !spirv.ptr<i32, StorageBuffer>, i32 -> !spirv.coopmatrix<8x16xi32, Subgroup, MatrixA>
spirv.Return
}
// -----
spirv.func @cooperative_matrix_load_missing_attr(%ptr : !spirv.ptr<i32, StorageBuffer>, %stride : i32) "None" {
- // expected-error @+1 {{expected valid keyword}}
+ // expected-error @+1 {{expected '<'}}
%0 = spirv.KHR.CooperativeMatrixLoad %ptr, %stride, :
- !spirv.ptr<i32, StorageBuffer> as !spirv.NV.coopmatrix<8x16xi32, Subgroup, MatrixA>
+ !spirv.ptr<i32, StorageBuffer>, i32 -> !spirv.NV.coopmatrix<8x16xi32, Subgroup, MatrixA>
spirv.Return
}
@@ -109,8 +118,8 @@ spirv.func @cooperative_matrix_load_missing_attr(%ptr : !spirv.ptr<i32, StorageB
spirv.func @cooperative_matrix_load_bad_result(%ptr : !spirv.ptr<i32, StorageBuffer>, %stride : i32) "None" {
// expected-error @+1 {{op result #0 must be any SPIR-V cooperative matrix type}}
- %0 = spirv.KHR.CooperativeMatrixLoad %ptr, %stride, ColumnMajor :
- !spirv.ptr<i32, StorageBuffer> as !spirv.NV.coopmatrix<8x16xi32, Subgroup>
+ %0 = spirv.KHR.CooperativeMatrixLoad %ptr, %stride, <ColumnMajor> :
+ !spirv.ptr<i32, StorageBuffer>, i32 -> !spirv.NV.coopmatrix<8x16xi32, Subgroup>
spirv.Return
}
diff --git a/mlir/tools/mlir-tblgen/SPIRVUtilsGen.cpp b/mlir/tools/mlir-tblgen/SPIRVUtilsGen.cpp
index ccf4240f8e56089..8468f92600a44e9 100644
--- a/mlir/tools/mlir-tblgen/SPIRVUtilsGen.cpp
+++ b/mlir/tools/mlir-tblgen/SPIRVUtilsGen.cpp
@@ -929,9 +929,9 @@ static void emitOperandDeserialization(const Operator &op, ArrayRef<SMLoc> loc,
if (auto *valueArg = llvm::dyn_cast_if_present<NamedTypeConstraint *>(argument)) {
if (valueArg->isVariableLength()) {
if (i != e - 1) {
- PrintFatalError(loc, "SPIR-V ops can have Variadic<..> or "
- "std::optional<...> arguments only if "
- "it's the last argument");
+ PrintFatalError(
+ loc, "SPIR-V ops can have Variadic<..> or "
+ "Optional<...> arguments only if it's the last argument");
}
os << tabs
<< formatv("for (; {0} < {1}.size(); ++{0})", wordIndex, words);
More information about the Mlir-commits mailing list