[Mlir-commits] [mlir] af9dafe - [mlir][spirv] Fix coop matrix store (#65709)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Fri Sep 8 10:58:33 PDT 2023
Author: Jakub Kuderski
Date: 2023-09-08T13:58:29-04:00
New Revision: af9dafeb3810cff4276d9916ccf61158fe49be85
URL: https://github.com/llvm/llvm-project/commit/af9dafeb3810cff4276d9916ccf61158fe49be85
DIFF: https://github.com/llvm/llvm-project/commit/af9dafeb3810cff4276d9916ccf61158fe49be85.diff
LOG: [mlir][spirv] Fix coop matrix store (#65709)
- Fix operand/attribute order
- Use ODS for parsing/printing
- Allow for stride to be any integer type
Added:
Modified:
mlir/include/mlir/Dialect/SPIRV/IR/SPIRVCooperativeMatrixOps.td
mlir/lib/Dialect/SPIRV/IR/CooperativeMatrixOps.cpp
mlir/test/Dialect/SPIRV/IR/cooperative-matrix-ops.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVCooperativeMatrixOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVCooperativeMatrixOps.td
index 31d26706faecb1d..3ce43c7e2b1fcee 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVCooperativeMatrixOps.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVCooperativeMatrixOps.td
@@ -182,21 +182,32 @@ def SPIRV_KHRCooperativeMatrixStoreOp : SPIRV_KhrVendorOp<"CooperativeMatrixStor
inactive.
``` {.ebnf}
- coop-matrix-store-op ::= `spirv.KHR.CooperativeMatrixStore `
- ssa-use `, ` ssa-use `, `
- ssa-use `, ` cooperative-matrix-layout `, `
- (`[` memory-operand `]`)? `:`
- pointer-type `,` coop-matrix-type
+ coop-matrix-store-op ::= `spirv.KHR.CooperativeMatrixStore`
+ ssa-use `,` ssa-use `,`
+ ssa-use `,` `<` cooperative-matrix-layout `>`
+ (`,` `<` memory-operand `>`)? `:`
+ pointer-type `,` coop-matrix-type `,` stride-type
```
+ TODO: In the SPIR-V spec, `stride` is an optional argument. We should also
+ support this optionality in the SPIR-V dialect.
+
#### Example:
```
- spirv.KHR.CooperativeMatrixStore %ptr, %obj, %stride :
- !spirv.ptr<i32, StorageBuffer>, !spirv.coopmatrix<16x8xi32, Workgroup, MatrixA>
+ spirv.KHR.CooperativeMatrixStore %ptr, %obj, %stride, <RowMajor> :
+ !spirv.ptr<i32, StorageBuffer>, !spirv.coopmatrix<16x8xi32, Workgroup, MatrixA>, i32
+
+ spirv.KHR.CooperativeMatrixStore %ptr, %obj, %stride, <ColumnMajor>, <Volatile> :
+ !spirv.ptr<f32, StorageBuffer>, !spirv.coopmatrix<8x8xf32, Subgroup, MatrixAcc>, i64
```
}];
+ let assemblyFormat = [{
+ $pointer `,` $object `,` $stride `,` $matrix_layout ( `,` $memory_operand^ )? attr-dict `:`
+ type(operands)
+ }];
+
let availability = [
MinVersion<SPIRV_V_1_6>,
MaxVersion<SPIRV_V_1_6>,
@@ -207,8 +218,8 @@ def SPIRV_KHRCooperativeMatrixStoreOp : SPIRV_KhrVendorOp<"CooperativeMatrixStor
let arguments = (ins
SPIRV_AnyPtr:$pointer,
SPIRV_AnyCooperativeMatrix:$object,
- SPIRV_Integer:$stride,
SPIRV_KHR_CooperativeMatrixLayoutAttr:$matrix_layout,
+ SPIRV_Integer:$stride,
OptionalAttr<SPIRV_MemoryAccessAttr>:$memory_operand
);
diff --git a/mlir/lib/Dialect/SPIRV/IR/CooperativeMatrixOps.cpp b/mlir/lib/Dialect/SPIRV/IR/CooperativeMatrixOps.cpp
index 4f986065d8d9cd3..600813f361a4712 100644
--- a/mlir/lib/Dialect/SPIRV/IR/CooperativeMatrixOps.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/CooperativeMatrixOps.cpp
@@ -63,49 +63,6 @@ LogicalResult KHRCooperativeMatrixLoadOp::verify() {
// spirv.KHR.CooperativeMatrixStore
//===----------------------------------------------------------------------===//
-ParseResult KHRCooperativeMatrixStoreOp::parse(OpAsmParser &parser,
- OperationState &result) {
- std::array<OpAsmParser::UnresolvedOperand, 3> operandInfo = {};
- for (auto &op : operandInfo) {
- if (parser.parseOperand(op) || parser.parseComma())
- return failure();
- }
-
- CooperativeMatrixLayoutKHR layout;
- if (parseEnumKeywordAttr<CooperativeMatrixLayoutKHRAttr>(
- layout, parser, result, kKhrCooperativeMatrixLayoutAttrName)) {
- return failure();
- }
-
- if (parseMemoryAccessAttributes(parser, result, kMemoryOperandAttrName))
- return failure();
-
- Type ptrType;
- Type objectType;
- if (parser.parseColon() || parser.parseType(ptrType) || parser.parseComma() ||
- parser.parseType(objectType)) {
- return failure();
- }
-
- Type strideType = parser.getBuilder().getIntegerType(32);
- if (parser.resolveOperands(operandInfo, {ptrType, objectType, strideType},
- parser.getNameLoc(), result.operands)) {
- return failure();
- }
-
- return success();
-}
-
-void KHRCooperativeMatrixStoreOp::print(OpAsmPrinter &printer) {
- printer << " " << getPointer() << ", " << getObject() << ", " << getStride()
- << ", " << getMatrixLayout();
-
- // Print optional memory operand attribute.
- if (auto memOperand = getMemoryOperand())
- printer << " [\"" << *memOperand << "\"]";
- printer << " : " << getPointer().getType() << ", " << getObject().getType();
-}
-
LogicalResult KHRCooperativeMatrixStoreOp::verify() {
return verifyPointerAndCoopMatrixType(*this, getPointer().getType(),
getObject().getType());
diff --git a/mlir/test/Dialect/SPIRV/IR/cooperative-matrix-ops.mlir b/mlir/test/Dialect/SPIRV/IR/cooperative-matrix-ops.mlir
index aad1e44bf8f7bd1..03acb0c08b275a3 100644
--- a/mlir/test/Dialect/SPIRV/IR/cooperative-matrix-ops.mlir
+++ b/mlir/test/Dialect/SPIRV/IR/cooperative-matrix-ops.mlir
@@ -69,10 +69,10 @@ spirv.func @cooperative_matrix_load_stride_i16(%ptr : !spirv.ptr<i32, StorageBuf
// CHECK-LABEL: @cooperative_matrix_store
spirv.func @cooperative_matrix_store(%ptr : !spirv.ptr<i32, StorageBuffer>, %stride : i32,
%m : !spirv.coopmatrix<8x16xi32, Workgroup, MatrixA>) "None" {
- // CHECK: spirv.KHR.CooperativeMatrixStore {{%.*}}, {{%.*}}, {{%.*}}, RowMajor :
- // CHECK-SAME: !spirv.ptr<i32, StorageBuffer>, !spirv.coopmatrix<8x16xi32, Workgroup, MatrixA>
- spirv.KHR.CooperativeMatrixStore %ptr, %m, %stride, RowMajor :
- !spirv.ptr<i32, StorageBuffer>, !spirv.coopmatrix<8x16xi32, Workgroup, MatrixA>
+ // CHECK: spirv.KHR.CooperativeMatrixStore {{%.*}}, {{%.*}}, {{%.*}}, <RowMajor> :
+ // CHECK-SAME: !spirv.ptr<i32, StorageBuffer>, !spirv.coopmatrix<8x16xi32, Workgroup, MatrixA>, i32
+ spirv.KHR.CooperativeMatrixStore %ptr, %m, %stride, <RowMajor> :
+ !spirv.ptr<i32, StorageBuffer>, !spirv.coopmatrix<8x16xi32, Workgroup, MatrixA>, i32
spirv.Return
}
@@ -80,10 +80,21 @@ spirv.func @cooperative_matrix_store(%ptr : !spirv.ptr<i32, StorageBuffer>, %str
spirv.func @cooperative_matrix_store_memoperand(%ptr : !spirv.ptr<i32, StorageBuffer>,
%m : !spirv.coopmatrix<8x16xi32, Subgroup, MatrixB>,
%stride : i32) "None" {
- // CHECK: spirv.KHR.CooperativeMatrixStore {{%.*}}, {{%.*}}, {{%.*}}, ColumnMajor ["Volatile"] :
- // CHECK-SAME: !spirv.ptr<i32, StorageBuffer>, !spirv.coopmatrix<8x16xi32, Subgroup, MatrixB>
- spirv.KHR.CooperativeMatrixStore %ptr, %m, %stride, ColumnMajor ["Volatile"] :
- !spirv.ptr<i32, StorageBuffer>, !spirv.coopmatrix<8x16xi32, Subgroup, MatrixB>
+ // CHECK: spirv.KHR.CooperativeMatrixStore {{%.*}}, {{%.*}}, {{%.*}}, <ColumnMajor>, <Volatile> :
+ // CHECK-SAME: !spirv.ptr<i32, StorageBuffer>, !spirv.coopmatrix<8x16xi32, Subgroup, MatrixB>, i32
+ spirv.KHR.CooperativeMatrixStore %ptr, %m, %stride, <ColumnMajor>, <Volatile> :
+ !spirv.ptr<i32, StorageBuffer>, !spirv.coopmatrix<8x16xi32, Subgroup, MatrixB>, i32
+ spirv.Return
+}
+
+// CHECK-LABEL: @cooperative_matrix_store_stride_i16
+spirv.func @cooperative_matrix_store_stride_i16(%ptr : !spirv.ptr<i32, StorageBuffer>,
+ %m : !spirv.coopmatrix<8x16xi32, Subgroup, MatrixB>,
+ %stride : i16) "None" {
+ // CHECK: spirv.KHR.CooperativeMatrixStore {{%.*}}, {{%.*}}, {{%.*}}, <ColumnMajor> :
+ // CHECK-SAME: !spirv.ptr<i32, StorageBuffer>, !spirv.coopmatrix<8x16xi32, Subgroup, MatrixB>, i16
+ spirv.KHR.CooperativeMatrixStore %ptr, %m, %stride, <ColumnMajor> :
+ !spirv.ptr<i32, StorageBuffer>, !spirv.coopmatrix<8x16xi32, Subgroup, MatrixB>, i16
spirv.Return
}
@@ -137,9 +148,9 @@ spirv.func @cooperative_matrix_store_missing_attr(%ptr : !spirv.ptr<i32, Storage
spirv.func @cooperative_matrix_store_missing_attr(%ptr : !spirv.ptr<i32, StorageBuffer>, %stride : i32,
%m : !spirv.coopmatrix<8x16xi32, Workgroup, MatrixA>) "None" {
- // expected-error @+1 {{expected valid keyword}}
+ // expected-error @+1 {{expected '<'}}
spirv.KHR.CooperativeMatrixStore %ptr, %m, %stride, :
- !spirv.ptr<i32, StorageBuffer>, !spirv.coopmatrix<8x16xi32, Workgroup, MatrixA>
+ !spirv.ptr<i32, StorageBuffer>, !spirv.coopmatrix<8x16xi32, Workgroup, MatrixA>, i32
spirv.Return
}
@@ -148,8 +159,8 @@ spirv.func @cooperative_matrix_store_missing_attr(%ptr : !spirv.ptr<i32, Storage
spirv.func @cooperative_matrix_store_bad_object_type(%ptr : !spirv.ptr<i32, StorageBuffer>,
%stride : i32) "None" {
// expected-error @+1 {{op operand #1 must be any SPIR-V cooperative matrix type}}
- spirv.KHR.CooperativeMatrixStore %ptr, %stride, %stride, RowMajor :
- !spirv.ptr<i32, StorageBuffer>, i32
+ spirv.KHR.CooperativeMatrixStore %ptr, %stride, %stride, <RowMajor> :
+ !spirv.ptr<i32, StorageBuffer>, i32, i32
spirv.Return
}
More information about the Mlir-commits mailing list