[Mlir-commits] [mlir] 68cd1db - [mlir][spirv] Add cooperative matrix store op

Jakub Kuderski llvmlistbot at llvm.org
Wed Jul 19 08:03:05 PDT 2023


Author: Jakub Kuderski
Date: 2023-07-19T11:01:09-04:00
New Revision: 68cd1dbc2ec97e20306694a7cdc480584295e62c

URL: https://github.com/llvm/llvm-project/commit/68cd1dbc2ec97e20306694a7cdc480584295e62c
DIFF: https://github.com/llvm/llvm-project/commit/68cd1dbc2ec97e20306694a7cdc480584295e62c.diff

LOG: [mlir][spirv] Add cooperative matrix store op

Implement cooperative matrix store for the `SPV_KHR_cooperative_matrix`
extension: https://github.com/KhronosGroup/SPIRV-Registry/blob/master/extensions/KHR/SPV_KHR_cooperative_matrix.html.

Reviewed By: antiagainst

Differential Revision: https://reviews.llvm.org/D155631

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
    mlir/include/mlir/Dialect/SPIRV/IR/SPIRVCooperativeMatrixOps.td
    mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
    mlir/test/Dialect/SPIRV/IR/cooperative-matrix-ops.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
index 7888d6e3aa7f0a..1e61aa747967d4 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
@@ -4447,6 +4447,7 @@ def SPIRV_OC_OpUDotAccSat                 : I32EnumAttrCase<"OpUDotAccSat", 4454
 def SPIRV_OC_OpSUDotAccSat                : I32EnumAttrCase<"OpSUDotAccSat", 4455>;
 def SPIRV_OC_OpTypeCooperativeMatrixKHR   : I32EnumAttrCase<"OpTypeCooperativeMatrixKHR", 4456>;
 def SPIRV_OC_OpCooperativeMatrixLoadKHR   : I32EnumAttrCase<"OpCooperativeMatrixLoadKHR", 4457>;
+def SPIRV_OC_OpCooperativeMatrixStoreKHR  : I32EnumAttrCase<"OpCooperativeMatrixStoreKHR", 4458>;
 def SPIRV_OC_OpCooperativeMatrixLengthKHR : I32EnumAttrCase<"OpCooperativeMatrixLengthKHR", 4460>;
 def SPIRV_OC_OpTypeCooperativeMatrixNV    : I32EnumAttrCase<"OpTypeCooperativeMatrixNV", 5358>;
 def SPIRV_OC_OpCooperativeMatrixLoadNV    : I32EnumAttrCase<"OpCooperativeMatrixLoadNV", 5359>;
@@ -4546,11 +4547,12 @@ def SPIRV_OpcodeAttr :
       SPIRV_OC_OpGroupNonUniformUMax, SPIRV_OC_OpGroupNonUniformFMax,
       SPIRV_OC_OpSubgroupBallotKHR, SPIRV_OC_OpSDot, SPIRV_OC_OpUDot,
       SPIRV_OC_OpSUDot, SPIRV_OC_OpSDotAccSat, SPIRV_OC_OpUDotAccSat,
-      SPIRV_OC_OpSUDotAccSat, SPIRV_OC_OpTypeCooperativeMatrixKHR,
-      SPIRV_OC_OpCooperativeMatrixLoadKHR, SPIRV_OC_OpCooperativeMatrixLengthKHR,
-      SPIRV_OC_OpTypeCooperativeMatrixNV,
-      SPIRV_OC_OpCooperativeMatrixLoadNV, SPIRV_OC_OpCooperativeMatrixStoreNV,
-      SPIRV_OC_OpCooperativeMatrixMulAddNV, SPIRV_OC_OpCooperativeMatrixLengthNV,
+      SPIRV_OC_OpSUDotAccSat,
+      SPIRV_OC_OpTypeCooperativeMatrixKHR, SPIRV_OC_OpCooperativeMatrixLoadKHR,
+      SPIRV_OC_OpCooperativeMatrixStoreKHR, SPIRV_OC_OpCooperativeMatrixLengthKHR,
+      SPIRV_OC_OpTypeCooperativeMatrixNV, SPIRV_OC_OpCooperativeMatrixLoadNV,
+      SPIRV_OC_OpCooperativeMatrixStoreNV, SPIRV_OC_OpCooperativeMatrixMulAddNV,
+      SPIRV_OC_OpCooperativeMatrixLengthNV,
       SPIRV_OC_OpSubgroupBlockReadINTEL, SPIRV_OC_OpSubgroupBlockWriteINTEL,
       SPIRV_OC_OpAssumeTrueKHR, SPIRV_OC_OpAtomicFAddEXT, SPIRV_OC_OpGroupIMulKHR,
       SPIRV_OC_OpGroupFMulKHR,

diff  --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVCooperativeMatrixOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVCooperativeMatrixOps.td
index c3c1e2cd042800..6de744039483b9 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVCooperativeMatrixOps.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVCooperativeMatrixOps.td
@@ -134,6 +134,75 @@ def SPIRV_KHRCooperativeMatrixLoadOp : SPIRV_KhrVendorOp<"CooperativeMatrixLoad"
   );
 }
 
+// -----
+
+def SPIRV_KHRCooperativeMatrixStoreOp : SPIRV_KhrVendorOp<"CooperativeMatrixStore", []> {
+  let summary = "Stores a cooperative matrix through a pointer";
+
+  let description = [{
+    Store a cooperative matrix through a pointer.
+    Pointer is a pointer. Its type must be an OpTypePointer whose Type operand
+    is a scalar or vector type. If the Shader capability was declared, Pointer
+    must point into an array and any ArrayStride decoration on Pointer is
+    ignored.
+
+    Object is the object to store. Its type must be an
+    OpTypeCooperativeMatrixKHR.
+
+    MemoryLayout specifies how matrix elements are laid out in memory. It must
+    come from a 32-bit integer constant instruction whose value corresponds to a
+    Cooperative Matrix Layout. See the Cooperative Matrix Layout table for a
+    description of the layouts and detailed layout-specific rules.
+
+    Stride further qualifies how matrix elements are laid out in memory. It must
+    be a scalar integer type and its exact semantics depend on MemoryLayout.
+
+    Memory Operand must be a Memory Operand literal. If not present, it is the
+    same as specifying None.
+
+    NOTE: In earlier versions of the SPIR-V spec, 'Memory Operand' was known
+    as 'Memory Access'.
+
+    For a given dynamic instance of this instruction, all operands of this
+    instruction must be the same for all invocations in a given scope instance
+    (where the scope is the scope the cooperative matrix type was created with).
+    All invocations in a given scope instance must be active or all must be
+    inactive.
+
+    ``` {.ebnf}
+     coop-matrix-store-op ::= `spirv.KHR.CooperativeMatrixStore `
+                              ssa-use `, ` ssa-use `, `
+                              ssa-use `, ` cooperative-matrix-layout `, `
+                              (`[` memory-operand `]`)? `:`
+                              pointer-type `,` coop-matrix-type
+    ```
+
+    #### Example:
+
+    ```
+      spirv.KHR.CooperativeMatrixStore %ptr, %obj, %stride, RowMajor :
+        !spirv.ptr<i32, StorageBuffer>, !spirv.coopmatrix<16x8xi32, Workgroup, MatrixA>
+    ```
+  }];
+
+  let availability = [
+    MinVersion<SPIRV_V_1_6>,
+    MaxVersion<SPIRV_V_1_6>,
+    Extension<[SPV_KHR_cooperative_matrix]>,
+    Capability<[SPIRV_C_CooperativeMatrixKHR]>
+  ];
+
+  let arguments = (ins
+    SPIRV_AnyPtr:$pointer,
+    SPIRV_AnyCooperativeMatrix:$object,
+    SPIRV_Integer:$stride,
+    SPIRV_KHR_CooperativeMatrixLayoutAttr:$matrix_layout,
+    OptionalAttr<SPIRV_MemoryAccessAttr>:$memory_operand
+  );
+
+  let results = (outs);
+}
+
 //===----------------------------------------------------------------------===//
 // SPV_NV_cooperative_matrix extension ops.
 //===----------------------------------------------------------------------===//
@@ -364,7 +433,7 @@ def SPIRV_NVCooperativeMatrixStoreOp : SPIRV_NvVendorOp<"CooperativeMatrixStore"
                               ssa-use `, ` ssa-use `, `
                               ssa-use `, ` ssa-use `, `
                               (`[` memory-access `]`)? `:`
-                              pointer-type `,` spirv-element-type
+                              pointer-type `,` coop-matrix-type
     ```
 
     For example:

diff  --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
index 61b084e6a56412..2516c29fbc58a8 100644
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
@@ -4111,6 +4111,64 @@ LogicalResult spirv::KHRCooperativeMatrixLoadOp::verify() {
                                          getResult().getType());
 }
 
+//===----------------------------------------------------------------------===//
+// spirv.KHR.CooperativeMatrixStore
+//===----------------------------------------------------------------------===//
+
+ParseResult spirv::KHRCooperativeMatrixStoreOp::parse(OpAsmParser &parser,
+                                                      OperationState &result) {
+  // Parses `%pointer, %object, %stride,`: every operand is followed by a comma
+  // because the matrix layout keyword comes next.
+  std::array<OpAsmParser::UnresolvedOperand, 3> operandInfo = {};
+  for (auto &op : operandInfo) {
+    if (parser.parseOperand(op) || parser.parseComma())
+      return failure();
+  }
+
+  spirv::CooperativeMatrixLayoutKHR layout;
+  if (::parseEnumKeywordAttr<spirv::CooperativeMatrixLayoutKHRAttr>(
+          layout, parser, result, kKhrCooperativeMatrixLayoutAttrName)) {
+    return failure();
+  }
+
+  if (parseMemoryAccessAttributes(parser, result, kMemoryOperandAttrName))
+    return failure();
+
+  Type ptrType;
+  Type objectType;
+  if (parser.parseColon() || parser.parseType(ptrType) || parser.parseComma() ||
+      parser.parseType(objectType)) {
+    return failure();
+  }
+
+  // The custom assembly format does not spell out the stride type (the printer
+  // below omits it too), so the parser always resolves the stride as i32. ODS
+  // accepts any SPIRV_Integer stride, so ops built with a non-i32 stride cannot
+  // round-trip through this textual form.
+  Type strideType = parser.getBuilder().getIntegerType(32);
+  if (parser.resolveOperands(operandInfo, {ptrType, objectType, strideType},
+                             parser.getNameLoc(), result.operands)) {
+    return failure();
+  }
+
+  return success();
+}
+
+void spirv::KHRCooperativeMatrixStoreOp::print(OpAsmPrinter &printer) {
+  printer << " " << getPointer() << ", " << getObject() << ", " << getStride()
+          << ", " << getMatrixLayout();
+
+  // Print optional memory operand attribute.
+  if (auto memOperand = getMemoryOperand())
+    printer << " [\"" << *memOperand << "\"]";
+  printer << " : " << getPointer().getType() << ", " << getObject().getType();
+}
+
+LogicalResult spirv::KHRCooperativeMatrixStoreOp::verify() {
+  return verifyPointerAndCoopMatrixType(*this, getPointer().getType(),
+                                        getObject().getType());
+}
+
 //===----------------------------------------------------------------------===//
 // spirv.NV.CooperativeMatrixLength
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/test/Dialect/SPIRV/IR/cooperative-matrix-ops.mlir b/mlir/test/Dialect/SPIRV/IR/cooperative-matrix-ops.mlir
index 7e38161aebae80..ce9a61a6277a04 100644
--- a/mlir/test/Dialect/SPIRV/IR/cooperative-matrix-ops.mlir
+++ b/mlir/test/Dialect/SPIRV/IR/cooperative-matrix-ops.mlir
@@ -57,6 +57,27 @@ spirv.func @cooperative_matrix_load_function(%ptr : !spirv.ptr<i32, Function>, %
   spirv.Return
 }
 
+// CHECK-LABEL: @cooperative_matrix_store
+spirv.func @cooperative_matrix_store(%ptr : !spirv.ptr<i32, StorageBuffer>, %stride : i32,
+                                     %m : !spirv.coopmatrix<8x16xi32, Workgroup, MatrixA>) "None" {
+  // CHECK:      spirv.KHR.CooperativeMatrixStore {{%.*}}, {{%.*}}, {{%.*}}, RowMajor :
+  // CHECK-SAME:   !spirv.ptr<i32, StorageBuffer>, !spirv.coopmatrix<8x16xi32, Workgroup, MatrixA>
+  spirv.KHR.CooperativeMatrixStore %ptr, %m, %stride, RowMajor :
+    !spirv.ptr<i32, StorageBuffer>, !spirv.coopmatrix<8x16xi32, Workgroup, MatrixA>
+  spirv.Return
+}
+
+// CHECK-LABEL: @cooperative_matrix_store_memoperand
+spirv.func @cooperative_matrix_store_memoperand(%ptr : !spirv.ptr<i32, StorageBuffer>,
+                                                %m : !spirv.coopmatrix<8x16xi32, Subgroup, MatrixB>,
+                                                %stride : i32) "None" {
+  // CHECK:       spirv.KHR.CooperativeMatrixStore {{%.*}}, {{%.*}}, {{%.*}}, ColumnMajor ["Volatile"] :
+  // CHECK-SAME:    !spirv.ptr<i32, StorageBuffer>, !spirv.coopmatrix<8x16xi32, Subgroup, MatrixB>
+  spirv.KHR.CooperativeMatrixStore %ptr, %m, %stride, ColumnMajor ["Volatile"] :
+    !spirv.ptr<i32, StorageBuffer>, !spirv.coopmatrix<8x16xi32, Subgroup, MatrixB>
+  spirv.Return
+}
+
 // -----
 
 spirv.func @cooperative_matrix_load_bad_ptr(%ptr : !spirv.ptr<!spirv.struct<(f32 [0])>, StorageBuffer>, %stride : i32) "None" {
@@ -95,6 +116,36 @@ spirv.func @cooperative_matrix_load_bad_result(%ptr : !spirv.ptr<i32, StorageBuf
 
 // -----
 
+spirv.func @cooperative_matrix_store_missing_attr(%ptr : !spirv.ptr<i32, StorageBuffer>, %stride : i32,
+                                                  %m : !spirv.coopmatrix<8x16xi32, Workgroup, MatrixA>) "None" {
+  // expected-error @+1 {{expected ','}}
+  spirv.KHR.CooperativeMatrixStore %ptr, %m, %stride :
+    !spirv.ptr<i32, StorageBuffer>, !spirv.coopmatrix<8x16xi32, Workgroup, MatrixA>
+  spirv.Return
+}
+
+// -----
+
+spirv.func @cooperative_matrix_store_missing_attr(%ptr : !spirv.ptr<i32, StorageBuffer>, %stride : i32,
+                                                  %m : !spirv.coopmatrix<8x16xi32, Workgroup, MatrixA>) "None" {
+  // expected-error @+1 {{expected valid keyword}}
+  spirv.KHR.CooperativeMatrixStore %ptr, %m, %stride, :
+    !spirv.ptr<i32, StorageBuffer>, !spirv.coopmatrix<8x16xi32, Workgroup, MatrixA>
+  spirv.Return
+}
+
+// -----
+
+spirv.func @cooperative_matrix_store_bad_object_type(%ptr : !spirv.ptr<i32, StorageBuffer>,
+                                                     %stride : i32) "None" {
+  // expected-error @+1 {{op operand #1 must be any SPIR-V cooperative matrix type}}
+  spirv.KHR.CooperativeMatrixStore %ptr, %stride, %stride, RowMajor :
+    !spirv.ptr<i32, StorageBuffer>, i32
+  spirv.Return
+}
+
+// -----
+
 //===----------------------------------------------------------------------===//
 // NV.CooperativeMatrix
 //===----------------------------------------------------------------------===//


        


More information about the Mlir-commits mailing list