[Mlir-commits] [mlir] 68cd1db - [mlir][spirv] Add cooperative matrix store op
Jakub Kuderski
llvmlistbot at llvm.org
Wed Jul 19 08:03:05 PDT 2023
Author: Jakub Kuderski
Date: 2023-07-19T11:01:09-04:00
New Revision: 68cd1dbc2ec97e20306694a7cdc480584295e62c
URL: https://github.com/llvm/llvm-project/commit/68cd1dbc2ec97e20306694a7cdc480584295e62c
DIFF: https://github.com/llvm/llvm-project/commit/68cd1dbc2ec97e20306694a7cdc480584295e62c.diff
LOG: [mlir][spirv] Add cooperative matrix store op
Implement cooperative matrix store for the `SPV_KHR_cooperative_matrix`
extension: https://github.com/KhronosGroup/SPIRV-Registry/blob/master/extensions/KHR/SPV_KHR_cooperative_matrix.html.
Reviewed By: antiagainst
Differential Revision: https://reviews.llvm.org/D155631
Added:
Modified:
mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
mlir/include/mlir/Dialect/SPIRV/IR/SPIRVCooperativeMatrixOps.td
mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
mlir/test/Dialect/SPIRV/IR/cooperative-matrix-ops.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
index 7888d6e3aa7f0a..1e61aa747967d4 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
@@ -4447,6 +4447,7 @@ def SPIRV_OC_OpUDotAccSat : I32EnumAttrCase<"OpUDotAccSat", 4454
def SPIRV_OC_OpSUDotAccSat : I32EnumAttrCase<"OpSUDotAccSat", 4455>;
def SPIRV_OC_OpTypeCooperativeMatrixKHR : I32EnumAttrCase<"OpTypeCooperativeMatrixKHR", 4456>;
def SPIRV_OC_OpCooperativeMatrixLoadKHR : I32EnumAttrCase<"OpCooperativeMatrixLoadKHR", 4457>;
+def SPIRV_OC_OpCooperativeMatrixStoreKHR : I32EnumAttrCase<"OpCooperativeMatrixStoreKHR", 4458>; // Opcode 4458; backs SPIRV_KHRCooperativeMatrixStoreOp.
def SPIRV_OC_OpCooperativeMatrixLengthKHR : I32EnumAttrCase<"OpCooperativeMatrixLengthKHR", 4460>;
def SPIRV_OC_OpTypeCooperativeMatrixNV : I32EnumAttrCase<"OpTypeCooperativeMatrixNV", 5358>;
def SPIRV_OC_OpCooperativeMatrixLoadNV : I32EnumAttrCase<"OpCooperativeMatrixLoadNV", 5359>;
@@ -4546,11 +4547,12 @@ def SPIRV_OpcodeAttr :
SPIRV_OC_OpGroupNonUniformUMax, SPIRV_OC_OpGroupNonUniformFMax,
SPIRV_OC_OpSubgroupBallotKHR, SPIRV_OC_OpSDot, SPIRV_OC_OpUDot,
SPIRV_OC_OpSUDot, SPIRV_OC_OpSDotAccSat, SPIRV_OC_OpUDotAccSat,
- SPIRV_OC_OpSUDotAccSat, SPIRV_OC_OpTypeCooperativeMatrixKHR,
- SPIRV_OC_OpCooperativeMatrixLoadKHR, SPIRV_OC_OpCooperativeMatrixLengthKHR,
- SPIRV_OC_OpTypeCooperativeMatrixNV,
- SPIRV_OC_OpCooperativeMatrixLoadNV, SPIRV_OC_OpCooperativeMatrixStoreNV,
- SPIRV_OC_OpCooperativeMatrixMulAddNV, SPIRV_OC_OpCooperativeMatrixLengthNV,
+ SPIRV_OC_OpSUDotAccSat,
+ SPIRV_OC_OpTypeCooperativeMatrixKHR, SPIRV_OC_OpCooperativeMatrixLoadKHR,
+ SPIRV_OC_OpCooperativeMatrixStoreKHR, SPIRV_OC_OpCooperativeMatrixLengthKHR,
+ SPIRV_OC_OpTypeCooperativeMatrixNV, SPIRV_OC_OpCooperativeMatrixLoadNV,
+ SPIRV_OC_OpCooperativeMatrixStoreNV, SPIRV_OC_OpCooperativeMatrixMulAddNV,
+ SPIRV_OC_OpCooperativeMatrixLengthNV,
SPIRV_OC_OpSubgroupBlockReadINTEL, SPIRV_OC_OpSubgroupBlockWriteINTEL,
SPIRV_OC_OpAssumeTrueKHR, SPIRV_OC_OpAtomicFAddEXT, SPIRV_OC_OpGroupIMulKHR,
SPIRV_OC_OpGroupFMulKHR,
diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVCooperativeMatrixOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVCooperativeMatrixOps.td
index c3c1e2cd042800..6de744039483b9 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVCooperativeMatrixOps.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVCooperativeMatrixOps.td
@@ -134,6 +134,75 @@ def SPIRV_KHRCooperativeMatrixLoadOp : SPIRV_KhrVendorOp<"CooperativeMatrixLoad"
);
}
+// -----
+
+def SPIRV_KHRCooperativeMatrixStoreOp : SPIRV_KhrVendorOp<"CooperativeMatrixStore", []> {
+ let summary = "Stores a cooperative matrix through a pointer";
+
+ let description = [{
+ Store a cooperative matrix through a pointer.
+ Pointer is a pointer. Its type must be an OpTypePointer whose Type operand
+ is a scalar or vector type. If the Shader capability was declared, Pointer
+ must point into an array and any ArrayStride decoration on Pointer is
+ ignored.
+
+ Object is the object to store. Its type must be an
+ OpTypeCooperativeMatrixKHR.
+
+ MemoryLayout specifies how matrix elements are laid out in memory. It must
+ come from a 32-bit integer constant instruction whose value corresponds to a
+ Cooperative Matrix Layout. See the Cooperative Matrix Layout table for a
+ description of the layouts and detailed layout-specific rules.
+
+ Stride further qualifies how matrix elements are laid out in memory. It must
+ be a scalar integer type and its exact semantics depend on MemoryLayout.
+
+ Memory Operand must be a Memory Operand literal. If not present, it is the
+ same as specifying None.
+
+ NOTE: In earlier versions of the SPIR-V spec, 'Memory Operand' was known
+ as 'Memory Access'.
+
+ For a given dynamic instance of this instruction, all operands of this
+ instruction must be the same for all invocations in a given scope instance
+ (where the scope is the scope the cooperative matrix type was created with).
+ All invocations in a given scope instance must be active or all must be
+ inactive.
+
+ ``` {.ebnf}
+ coop-matrix-store-op ::= `spirv.KHR.CooperativeMatrixStore `
+ ssa-use `, ` ssa-use `, `
+ ssa-use `, ` cooperative-matrix-layout
+ (`[` memory-operand `]`)? `:`
+ pointer-type `,` coop-matrix-type
+ ```
+
+ #### Example:
+
+ ```
+ spirv.KHR.CooperativeMatrixStore %ptr, %obj, %stride, RowMajor :
+ !spirv.ptr<i32, StorageBuffer>, !spirv.coopmatrix<16x8xi32, Workgroup, MatrixA>
+ ```
+ }];
+
+ let availability = [
+ MinVersion<SPIRV_V_1_6>,
+ MaxVersion<SPIRV_V_1_6>,
+ Extension<[SPV_KHR_cooperative_matrix]>,
+ Capability<[SPIRV_C_CooperativeMatrixKHR]>
+ ];
+
+ let arguments = (ins
+ SPIRV_AnyPtr:$pointer,
+ SPIRV_AnyCooperativeMatrix:$object,
+ SPIRV_Integer:$stride,
+ SPIRV_KHR_CooperativeMatrixLayoutAttr:$matrix_layout,
+ OptionalAttr<SPIRV_MemoryAccessAttr>:$memory_operand
+ );
+
+ let results = (outs);
+}
+
//===----------------------------------------------------------------------===//
// SPV_NV_cooperative_matrix extension ops.
//===----------------------------------------------------------------------===//
@@ -364,7 +433,7 @@ def SPIRV_NVCooperativeMatrixStoreOp : SPIRV_NvVendorOp<"CooperativeMatrixStore"
ssa-use `, ` ssa-use `, `
ssa-use `, ` ssa-use `, `
(`[` memory-access `]`)? `:`
- pointer-type `,` spirv-element-type
+ pointer-type `,` coop-matrix-type
```
For example:
diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
index 61b084e6a56412..2516c29fbc58a8 100644
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
@@ -4111,6 +4111,58 @@ LogicalResult spirv::KHRCooperativeMatrixLoadOp::verify() {
getResult().getType());
}
+//===----------------------------------------------------------------------===//
+// spirv.KHR.CooperativeMatrixStore
+//===----------------------------------------------------------------------===//
+
+ParseResult spirv::KHRCooperativeMatrixStoreOp::parse(OpAsmParser &parser, // Parses: %ptr, %obj, %stride, Layout ["MemOp"]? : ptr-type, coopmatrix-type
+ OperationState &result) {
+ std::array<OpAsmParser::UnresolvedOperand, 3> operandInfo = {}; // Holds %pointer, %object, %stride, in that order.
+ for (auto &op : operandInfo) { // Every SSA operand, including the stride, is followed by a comma.
+ if (parser.parseOperand(op) || parser.parseComma())
+ return failure();
+ }
+
+ spirv::CooperativeMatrixLayoutKHR layout; // Layout keyword, e.g. RowMajor or ColumnMajor.
+ if (::parseEnumKeywordAttr<spirv::CooperativeMatrixLayoutKHRAttr>(
+ layout, parser, result, kKhrCooperativeMatrixLayoutAttrName)) {
+ return failure();
+ }
+
+ if (parseMemoryAccessAttributes(parser, result, kMemoryOperandAttrName)) // Optional bracketed memory operand, e.g. ["Volatile"]; no comma is parsed before it.
+ return failure();
+
+ Type ptrType;
+ Type objectType;
+ if (parser.parseColon() || parser.parseType(ptrType) || parser.parseComma() ||
+ parser.parseType(objectType)) {
+ return failure();
+ }
+
+ Type strideType = parser.getBuilder().getIntegerType(32); // NOTE(review): stride is always resolved as i32, but ODS accepts any SPIRV_Integer and print() omits the stride type, so a non-i32 stride cannot round-trip - confirm intended.
+ if (parser.resolveOperands(operandInfo, {ptrType, objectType, strideType}, // Type order must match operandInfo: pointer, object, stride.
+ parser.getNameLoc(), result.operands)) {
+ return failure();
+ }
+
+ return success();
+}
+
+void spirv::KHRCooperativeMatrixStoreOp::print(OpAsmPrinter &printer) { // Emits: %ptr, %obj, %stride, Layout ["MemOp"]? : ptr-type, coopmatrix-type
+ printer << " " << getPointer() << ", " << getObject() << ", " << getStride()
+ << ", " << getMatrixLayout();
+
+ // Print the optional memory operand as ["<value>"], with no preceding comma.
+ if (auto memOperand = getMemoryOperand())
+ printer << " [\"" << *memOperand << "\"]";
+ printer << " : " << getPointer().getType() << ", " << getObject().getType(); // The stride type is not printed; parse() assumes i32.
+}
+
+LogicalResult spirv::KHRCooperativeMatrixStoreOp::verify() { // Only the pointer and object types are checked here; stride and layout are not.
+ return verifyPointerAndCoopMatrixType(*this, getPointer().getType(), // NOTE(review): helper defined elsewhere in this file, presumably shared with the Load op verifier.
+ getObject().getType());
+}
+
//===----------------------------------------------------------------------===//
// spirv.NV.CooperativeMatrixLength
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/SPIRV/IR/cooperative-matrix-ops.mlir b/mlir/test/Dialect/SPIRV/IR/cooperative-matrix-ops.mlir
index 7e38161aebae80..ce9a61a6277a04 100644
--- a/mlir/test/Dialect/SPIRV/IR/cooperative-matrix-ops.mlir
+++ b/mlir/test/Dialect/SPIRV/IR/cooperative-matrix-ops.mlir
@@ -57,6 +57,27 @@ spirv.func @cooperative_matrix_load_function(%ptr : !spirv.ptr<i32, Function>, %
spirv.Return
}
+// CHECK-LABEL: @cooperative_matrix_store
+spirv.func @cooperative_matrix_store(%ptr : !spirv.ptr<i32, StorageBuffer>, %stride : i32, // Round-trips the form without a memory operand.
+ %m : !spirv.coopmatrix<8x16xi32, Workgroup, MatrixA>) "None" {
+ // CHECK: spirv.KHR.CooperativeMatrixStore {{%.*}}, {{%.*}}, {{%.*}}, RowMajor :
+ // CHECK-SAME: !spirv.ptr<i32, StorageBuffer>, !spirv.coopmatrix<8x16xi32, Workgroup, MatrixA>
+ spirv.KHR.CooperativeMatrixStore %ptr, %m, %stride, RowMajor :
+ !spirv.ptr<i32, StorageBuffer>, !spirv.coopmatrix<8x16xi32, Workgroup, MatrixA>
+ spirv.Return
+}
+
+// CHECK-LABEL: @cooperative_matrix_store_memoperand
+spirv.func @cooperative_matrix_store_memoperand(%ptr : !spirv.ptr<i32, StorageBuffer>, // Round-trips a bracketed memory operand placed right after the layout.
+ %m : !spirv.coopmatrix<8x16xi32, Subgroup, MatrixB>,
+ %stride : i32) "None" {
+ // CHECK: spirv.KHR.CooperativeMatrixStore {{%.*}}, {{%.*}}, {{%.*}}, ColumnMajor ["Volatile"] :
+ // CHECK-SAME: !spirv.ptr<i32, StorageBuffer>, !spirv.coopmatrix<8x16xi32, Subgroup, MatrixB>
+ spirv.KHR.CooperativeMatrixStore %ptr, %m, %stride, ColumnMajor ["Volatile"] :
+ !spirv.ptr<i32, StorageBuffer>, !spirv.coopmatrix<8x16xi32, Subgroup, MatrixB>
+ spirv.Return
+}
+
// -----
spirv.func @cooperative_matrix_load_bad_ptr(%ptr : !spirv.ptr<!spirv.struct<(f32 [0])>, StorageBuffer>, %stride : i32) "None" {
@@ -95,6 +116,36 @@ spirv.func @cooperative_matrix_load_bad_result(%ptr : !spirv.ptr<i32, StorageBuf
// -----
+spirv.func @cooperative_matrix_store_missing_attr(%ptr : !spirv.ptr<i32, StorageBuffer>, %stride : i32, // Both the layout and its leading comma are omitted.
+ %m : !spirv.coopmatrix<8x16xi32, Workgroup, MatrixA>) "None" {
+ // expected-error @+1 {{expected ','}}
+ spirv.KHR.CooperativeMatrixStore %ptr, %m, %stride :
+ !spirv.ptr<i32, StorageBuffer>, !spirv.coopmatrix<8x16xi32, Workgroup, MatrixA>
+ spirv.Return
+}
+
+// -----
+
+spirv.func @cooperative_matrix_store_missing_layout_after_comma(%ptr : !spirv.ptr<i32, StorageBuffer>, %stride : i32, // The comma is present but no layout keyword follows it.
+ %m : !spirv.coopmatrix<8x16xi32, Workgroup, MatrixA>) "None" {
+ // expected-error @+1 {{expected valid keyword}}
+ spirv.KHR.CooperativeMatrixStore %ptr, %m, %stride, :
+ !spirv.ptr<i32, StorageBuffer>, !spirv.coopmatrix<8x16xi32, Workgroup, MatrixA>
+ spirv.Return
+}
+
+// -----
+
+spirv.func @cooperative_matrix_store_bad_object_type(%ptr : !spirv.ptr<i32, StorageBuffer>, // Passes an i32 where the cooperative matrix object belongs.
+ %stride : i32) "None" {
+ // expected-error @+1 {{op operand #1 must be any SPIR-V cooperative matrix type}}
+ spirv.KHR.CooperativeMatrixStore %ptr, %stride, %stride, RowMajor :
+ !spirv.ptr<i32, StorageBuffer>, i32
+ spirv.Return
+}
+
+// -----
+
//===----------------------------------------------------------------------===//
// NV.CooperativeMatrix
//===----------------------------------------------------------------------===//
More information about the Mlir-commits
mailing list