[Mlir-commits] [mlir] [mlir][spirv] Fix coop matrix store (PR #65709)

Jakub Kuderski llvmlistbot at llvm.org
Thu Sep 7 21:45:55 PDT 2023


https://github.com/kuhar updated https://github.com/llvm/llvm-project/pull/65709:

>From b9c971d76bbee4e861aeccd1fd5581f3c601e16a Mon Sep 17 00:00:00 2001
From: Jakub Kuderski <jakub at nod-labs.com>
Date: Fri, 8 Sep 2023 00:10:59 -0400
Subject: [PATCH 1/2] [mlir][spirv] Fix coop matrix store

- Fix operand/attribute order
- Use ODS for parsing/printing
- Allow for stride to be any integer type
---
 .../SPIRV/IR/SPIRVCooperativeMatrixOps.td     | 15 ++++---
 .../Dialect/SPIRV/IR/CooperativeMatrixOps.cpp | 43 -------------------
 .../SPIRV/IR/cooperative-matrix-ops.mlir      | 35 +++++++++------
 3 files changed, 33 insertions(+), 60 deletions(-)

diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVCooperativeMatrixOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVCooperativeMatrixOps.td
index 7060aa80dc113ed..9da120d1322277c 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVCooperativeMatrixOps.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVCooperativeMatrixOps.td
@@ -171,10 +171,10 @@ def SPIRV_KHRCooperativeMatrixStoreOp : SPIRV_KhrVendorOp<"CooperativeMatrixStor
 
     ``` {.ebnf}
      coop-matrix-store-op ::= `spirv.KHR.CooperativeMatrixStore `
-                              ssa-use `, ` ssa-use `, `
-                              ssa-use `, ` cooperative-matrix-layout `, `
-                              (`[` memory-operand `]`)? `:`
-                              pointer-type `,` coop-matrix-type
+                              ssa-use `,` ssa-use `,`
+                              ssa-use `,` `<` cooperative-matrix-layout `>`
+                              (`,` `<` memory-operand `>`)? `:`
+                              pointer-type `,` coop-matrix-type `,` stride-type
     ```
 
     #### Example:
@@ -185,6 +185,11 @@ def SPIRV_KHRCooperativeMatrixStoreOp : SPIRV_KhrVendorOp<"CooperativeMatrixStor
     ```
   }];
 
+  let assemblyFormat = [{
+    $pointer `,` $object `,` $stride `,` $matrix_layout ( `,` $memory_operand^ )? attr-dict `:`
+      type($pointer) `,` type($object) `,` type($stride)
+  }];
+
   let availability = [
     MinVersion<SPIRV_V_1_6>,
     MaxVersion<SPIRV_V_1_6>,
@@ -195,8 +200,8 @@ def SPIRV_KHRCooperativeMatrixStoreOp : SPIRV_KhrVendorOp<"CooperativeMatrixStor
   let arguments = (ins
     SPIRV_AnyPtr:$pointer,
     SPIRV_AnyCooperativeMatrix:$object,
-    SPIRV_Integer:$stride,
     SPIRV_KHR_CooperativeMatrixLayoutAttr:$matrix_layout,
+    SPIRV_Integer:$stride,
     OptionalAttr<SPIRV_MemoryAccessAttr>:$memory_operand
   );
 
diff --git a/mlir/lib/Dialect/SPIRV/IR/CooperativeMatrixOps.cpp b/mlir/lib/Dialect/SPIRV/IR/CooperativeMatrixOps.cpp
index bc1d30f55518300..36aea151a87e380 100644
--- a/mlir/lib/Dialect/SPIRV/IR/CooperativeMatrixOps.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/CooperativeMatrixOps.cpp
@@ -106,49 +106,6 @@ LogicalResult KHRCooperativeMatrixLoadOp::verify() {
 // spirv.KHR.CooperativeMatrixStore
 //===----------------------------------------------------------------------===//
 
-ParseResult KHRCooperativeMatrixStoreOp::parse(OpAsmParser &parser,
-                                               OperationState &result) {
-  std::array<OpAsmParser::UnresolvedOperand, 3> operandInfo = {};
-  for (auto &op : operandInfo) {
-    if (parser.parseOperand(op) || parser.parseComma())
-      return failure();
-  }
-
-  CooperativeMatrixLayoutKHR layout;
-  if (parseEnumKeywordAttr<CooperativeMatrixLayoutKHRAttr>(
-          layout, parser, result, kKhrCooperativeMatrixLayoutAttrName)) {
-    return failure();
-  }
-
-  if (parseMemoryAccessAttributes(parser, result, kMemoryOperandAttrName))
-    return failure();
-
-  Type ptrType;
-  Type objectType;
-  if (parser.parseColon() || parser.parseType(ptrType) || parser.parseComma() ||
-      parser.parseType(objectType)) {
-    return failure();
-  }
-
-  Type strideType = parser.getBuilder().getIntegerType(32);
-  if (parser.resolveOperands(operandInfo, {ptrType, objectType, strideType},
-                             parser.getNameLoc(), result.operands)) {
-    return failure();
-  }
-
-  return success();
-}
-
-void KHRCooperativeMatrixStoreOp::print(OpAsmPrinter &printer) {
-  printer << " " << getPointer() << ", " << getObject() << ", " << getStride()
-          << ", " << getMatrixLayout();
-
-  // Print optional memory operand attribute.
-  if (auto memOperand = getMemoryOperand())
-    printer << " [\"" << *memOperand << "\"]";
-  printer << " : " << getPointer().getType() << ", " << getObject().getType();
-}
-
 LogicalResult KHRCooperativeMatrixStoreOp::verify() {
   return verifyPointerAndCoopMatrixType(*this, getPointer().getType(),
                                         getObject().getType());
diff --git a/mlir/test/Dialect/SPIRV/IR/cooperative-matrix-ops.mlir b/mlir/test/Dialect/SPIRV/IR/cooperative-matrix-ops.mlir
index aa6e072b03c5d3d..73be42aeeab90fe 100644
--- a/mlir/test/Dialect/SPIRV/IR/cooperative-matrix-ops.mlir
+++ b/mlir/test/Dialect/SPIRV/IR/cooperative-matrix-ops.mlir
@@ -60,10 +60,10 @@ spirv.func @cooperative_matrix_load_function(%ptr : !spirv.ptr<i32, Function>, %
 // CHECK-LABEL: @cooperative_matrix_store
 spirv.func @cooperative_matrix_store(%ptr : !spirv.ptr<i32, StorageBuffer>, %stride : i32,
                                      %m : !spirv.coopmatrix<8x16xi32, Workgroup, MatrixA>) "None" {
-  // CHECK:      spirv.KHR.CooperativeMatrixStore {{%.*}}, {{%.*}}, {{%.*}}, RowMajor :
-  // CHECK-SAME:   !spirv.ptr<i32, StorageBuffer>, !spirv.coopmatrix<8x16xi32, Workgroup, MatrixA>
-  spirv.KHR.CooperativeMatrixStore %ptr, %m, %stride, RowMajor :
-    !spirv.ptr<i32, StorageBuffer>, !spirv.coopmatrix<8x16xi32, Workgroup, MatrixA>
+  // CHECK:      spirv.KHR.CooperativeMatrixStore {{%.*}}, {{%.*}}, {{%.*}}, <RowMajor> :
+  // CHECK-SAME:   !spirv.ptr<i32, StorageBuffer>, !spirv.coopmatrix<8x16xi32, Workgroup, MatrixA>, i32
+  spirv.KHR.CooperativeMatrixStore %ptr, %m, %stride, <RowMajor> :
+    !spirv.ptr<i32, StorageBuffer>, !spirv.coopmatrix<8x16xi32, Workgroup, MatrixA>, i32
   spirv.Return
 }
 
@@ -71,10 +71,21 @@ spirv.func @cooperative_matrix_store(%ptr : !spirv.ptr<i32, StorageBuffer>, %str
 spirv.func @cooperative_matrix_store_memoperand(%ptr : !spirv.ptr<i32, StorageBuffer>,
                                                 %m : !spirv.coopmatrix<8x16xi32, Subgroup, MatrixB>,
                                                 %stride : i32) "None" {
-  // CHECK:       spirv.KHR.CooperativeMatrixStore {{%.*}}, {{%.*}}, {{%.*}}, ColumnMajor ["Volatile"] :
-  // CHECK-SAME:    !spirv.ptr<i32, StorageBuffer>, !spirv.coopmatrix<8x16xi32, Subgroup, MatrixB>
-  spirv.KHR.CooperativeMatrixStore %ptr, %m, %stride, ColumnMajor ["Volatile"] :
-    !spirv.ptr<i32, StorageBuffer>, !spirv.coopmatrix<8x16xi32, Subgroup, MatrixB>
+  // CHECK:       spirv.KHR.CooperativeMatrixStore {{%.*}}, {{%.*}}, {{%.*}}, <ColumnMajor>, <Volatile> :
+  // CHECK-SAME:    !spirv.ptr<i32, StorageBuffer>, !spirv.coopmatrix<8x16xi32, Subgroup, MatrixB>, i32
+  spirv.KHR.CooperativeMatrixStore %ptr, %m, %stride, <ColumnMajor>, <Volatile> :
+    !spirv.ptr<i32, StorageBuffer>, !spirv.coopmatrix<8x16xi32, Subgroup, MatrixB>, i32
+  spirv.Return
+}
+
+// CHECK-LABEL: @cooperative_matrix_store_stride_i16
+spirv.func @cooperative_matrix_store_stride_i16(%ptr : !spirv.ptr<i32, StorageBuffer>,
+                                                %m : !spirv.coopmatrix<8x16xi32, Subgroup, MatrixB>,
+                                                %stride : i16) "None" {
+  // CHECK:       spirv.KHR.CooperativeMatrixStore {{%.*}}, {{%.*}}, {{%.*}}, <ColumnMajor> :
+  // CHECK-SAME:    !spirv.ptr<i32, StorageBuffer>, !spirv.coopmatrix<8x16xi32, Subgroup, MatrixB>, i16
+  spirv.KHR.CooperativeMatrixStore %ptr, %m, %stride, <ColumnMajor> :
+    !spirv.ptr<i32, StorageBuffer>, !spirv.coopmatrix<8x16xi32, Subgroup, MatrixB>, i16
   spirv.Return
 }
 
@@ -128,9 +139,9 @@ spirv.func @cooperative_matrix_store_missing_attr(%ptr : !spirv.ptr<i32, Storage
 
 spirv.func @cooperative_matrix_store_missing_attr(%ptr : !spirv.ptr<i32, StorageBuffer>, %stride : i32,
                                                   %m : !spirv.coopmatrix<8x16xi32, Workgroup, MatrixA>) "None" {
-  // expected-error @+1 {{expected valid keyword}}
+  // expected-error @+1 {{expected '<'}}
   spirv.KHR.CooperativeMatrixStore %ptr, %m, %stride, :
-    !spirv.ptr<i32, StorageBuffer>, !spirv.coopmatrix<8x16xi32, Workgroup, MatrixA>
+    !spirv.ptr<i32, StorageBuffer>, !spirv.coopmatrix<8x16xi32, Workgroup, MatrixA>, i32
   spirv.Return
 }
 
@@ -139,8 +150,8 @@ spirv.func @cooperative_matrix_store_missing_attr(%ptr : !spirv.ptr<i32, Storage
 spirv.func @cooperative_matrix_store_bad_object_type(%ptr : !spirv.ptr<i32, StorageBuffer>,
                                                      %stride : i32) "None" {
   // expected-error @+1 {{op operand #1 must be any SPIR-V cooperative matrix type}}
-  spirv.KHR.CooperativeMatrixStore %ptr, %stride, %stride, RowMajor :
-    !spirv.ptr<i32, StorageBuffer>, i32
+  spirv.KHR.CooperativeMatrixStore %ptr, %stride, %stride, <RowMajor> :
+    !spirv.ptr<i32, StorageBuffer>, i32, i32
   spirv.Return
 }
 

>From c6d879d92346115dceb0e367f97787642c7a66ab Mon Sep 17 00:00:00 2001
From: Jakub Kuderski <jakub at nod-labs.com>
Date: Fri, 8 Sep 2023 00:45:39 -0400
Subject: [PATCH 2/2] Update examples

---
 .../mlir/Dialect/SPIRV/IR/SPIRVCooperativeMatrixOps.td     | 7 +++++--
 1 file changed, 5 insertions(+), 2 deletions(-)

diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVCooperativeMatrixOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVCooperativeMatrixOps.td
index 9da120d1322277c..2bf0eb6fd74bcec 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVCooperativeMatrixOps.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVCooperativeMatrixOps.td
@@ -180,8 +180,11 @@ def SPIRV_KHRCooperativeMatrixStoreOp : SPIRV_KhrVendorOp<"CooperativeMatrixStor
     #### Example:
 
     ```
-      spirv.KHR.CooperativeMatrixStore %ptr, %obj, %stride :
-        !spirv.ptr<i32, StorageBuffer>, !spirv.coopmatrix<16x8xi32, Workgroup, MatrixA>
+      spirv.KHR.CooperativeMatrixStore %ptr, %obj, %stride, <RowMajor> :
+        !spirv.ptr<i32, StorageBuffer>, !spirv.coopmatrix<16x8xi32, Workgroup, MatrixA>, i32
+
+      spirv.KHR.CooperativeMatrixStore %ptr, %obj, %stride, <ColumnMajor>, <Volatile> :
+        !spirv.ptr<f32, StorageBuffer>, !spirv.coopmatrix<8x8xf32, Subgroup, MatrixAcc>, i64
     ```
   }];
 



More information about the Mlir-commits mailing list