[Mlir-commits] [mlir] [mlir][spirv] Fix coop matrix load (PR #65712)

Jakub Kuderski llvmlistbot at llvm.org
Fri Sep 8 08:41:36 PDT 2023


https://github.com/kuhar updated https://github.com/llvm/llvm-project/pull/65712:

>From 7d64a7d9dbc2543b0001af705a5383260f41ff2f Mon Sep 17 00:00:00 2001
From: Jakub Kuderski <jakub at nod-labs.com>
Date: Fri, 8 Sep 2023 00:41:50 -0400
Subject: [PATCH 1/2] [mlir][spirv] Fix coop matrix load

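- Allow the stride operand to be any integer type (the old custom parser hard-coded i32)
- Use an ODS declarative assembly format for parsing/printing instead of the hand-written C++ parser/printer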
- Update examples and tests
---
 .../SPIRV/IR/SPIRVCooperativeMatrixOps.td     | 19 +++++--
 .../Dialect/SPIRV/IR/CooperativeMatrixOps.cpp | 43 ---------------
 .../SPIRV/IR/cooperative-matrix-ops.mlir      | 53 +++++++++++--------
 3 files changed, 45 insertions(+), 70 deletions(-)

diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVCooperativeMatrixOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVCooperativeMatrixOps.td
index 7060aa80dc113ed..5d6c5f057ae6acc 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVCooperativeMatrixOps.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVCooperativeMatrixOps.td
@@ -101,20 +101,29 @@ def SPIRV_KHRCooperativeMatrixLoadOp : SPIRV_KhrVendorOp<"CooperativeMatrixLoad"
     ``` {.ebnf}
     cooperative-matrix-load-op ::= ssa-id `=` `spirv.KHR.CooperativeMatrixLoad`
                               ssa-use `,` ssa-use `,`
-                              cooperative-matrix-layout `,`
-                              (`[` memory-operand `]`)? ` : `
-                                pointer-type `as` cooperative-matrix-type
+                              `<` cooperative-matrix-layout `>`
+                              (`,` `<` memory-operand `>`)? `:`
+                                pointer-type `as` cooperative-matrix-type `,` stride-type
     ```
 
     #### Example:
 
     ```
-    %0 = spirv.KHR.CooperativeMatrixLoad %ptr, %stride, RowMajor
+    %0 = spirv.KHR.CooperativeMatrixLoad %ptr, %stride, <RowMajor>
          : !spirv.ptr<i32, StorageBuffer>
-             as !spirv.KHR.coopmatrix<16x8xi32, Workgroup, MatrixA>
+             as !spirv.KHR.coopmatrix<16x8xi32, Workgroup, MatrixA>, i32
+
+    %1 = spirv.KHR.CooperativeMatrixLoad %ptr, %stride, <ColumnMajor>, <Volatile>
+         : !spirv.ptr<f32, StorageBuffer>
+             as !spirv.KHR.coopmatrix<8x8xf32, Subgroup, MatrixAcc>, i64
     ```
   }];
 
+  let assemblyFormat = [{
+    $pointer `,` $stride `,` $matrix_layout ( `,` $memory_operand^ )? attr-dict `:`
+      type($pointer) `as` type($result) `,` type($stride)
+  }];
+
   let availability = [
     MinVersion<SPIRV_V_1_6>,
     MaxVersion<SPIRV_V_1_6>,
diff --git a/mlir/lib/Dialect/SPIRV/IR/CooperativeMatrixOps.cpp b/mlir/lib/Dialect/SPIRV/IR/CooperativeMatrixOps.cpp
index bc1d30f55518300..4f986065d8d9cd3 100644
--- a/mlir/lib/Dialect/SPIRV/IR/CooperativeMatrixOps.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/CooperativeMatrixOps.cpp
@@ -37,49 +37,6 @@ LogicalResult KHRCooperativeMatrixLengthOp::verify() {
 // spirv.KHR.CooperativeMatrixLoad
 //===----------------------------------------------------------------------===//
 
-ParseResult KHRCooperativeMatrixLoadOp::parse(OpAsmParser &parser,
-                                              OperationState &result) {
-  std::array<OpAsmParser::UnresolvedOperand, 2> operandInfo = {};
-  if (parser.parseOperand(operandInfo[0]) || parser.parseComma())
-    return failure();
-  if (parser.parseOperand(operandInfo[1]) || parser.parseComma())
-    return failure();
-
-  CooperativeMatrixLayoutKHR layout;
-  if (parseEnumKeywordAttr<CooperativeMatrixLayoutKHRAttr>(
-          layout, parser, result, kKhrCooperativeMatrixLayoutAttrName)) {
-    return failure();
-  }
-
-  if (parseMemoryAccessAttributes(parser, result, kMemoryOperandAttrName))
-    return failure();
-
-  Type ptrType;
-  Type elementType;
-  if (parser.parseColon() || parser.parseType(ptrType) ||
-      parser.parseKeywordType("as", elementType)) {
-    return failure();
-  }
-  result.addTypes(elementType);
-
-  Type strideType = parser.getBuilder().getIntegerType(32);
-  if (parser.resolveOperands(operandInfo, {ptrType, strideType},
-                             parser.getNameLoc(), result.operands)) {
-    return failure();
-  }
-
-  return success();
-}
-
-void KHRCooperativeMatrixLoadOp::print(OpAsmPrinter &printer) {
-  printer << " " << getPointer() << ", " << getStride() << ", "
-          << getMatrixLayout();
-  // Print optional memory operand attribute.
-  if (auto memOperand = getMemoryOperand())
-    printer << " [\"" << memOperand << "\"]";
-  printer << " : " << getPointer().getType() << " as " << getType();
-}
-
 static LogicalResult verifyPointerAndCoopMatrixType(Operation *op, Type pointer,
                                                     Type coopMatrix) {
   auto pointerType = cast<PointerType>(pointer);
diff --git a/mlir/test/Dialect/SPIRV/IR/cooperative-matrix-ops.mlir b/mlir/test/Dialect/SPIRV/IR/cooperative-matrix-ops.mlir
index aa6e072b03c5d3d..4a34674b810309d 100644
--- a/mlir/test/Dialect/SPIRV/IR/cooperative-matrix-ops.mlir
+++ b/mlir/test/Dialect/SPIRV/IR/cooperative-matrix-ops.mlir
@@ -23,37 +23,46 @@ spirv.func @cooperative_matrix_length_wrong_matrix() -> i32 "None" {
 
 // CHECK-LABEL: @cooperative_matrix_load
 spirv.func @cooperative_matrix_load(%ptr : !spirv.ptr<i32, StorageBuffer>, %stride : i32) "None" {
-  // CHECK:      {{%.*}} = spirv.KHR.CooperativeMatrixLoad {{%.*}}, {{%.*}}, RowMajor :
-  // CHECK-SAME:   !spirv.ptr<i32, StorageBuffer> as !spirv.coopmatrix<16x8xi32, Workgroup, MatrixA>
-  %0 = spirv.KHR.CooperativeMatrixLoad %ptr, %stride, RowMajor :
-    !spirv.ptr<i32, StorageBuffer> as !spirv.coopmatrix<16x8xi32, Workgroup, MatrixA>
+  // CHECK:      {{%.*}} = spirv.KHR.CooperativeMatrixLoad {{%.*}}, {{%.*}}, <RowMajor> :
+  // CHECK-SAME:   !spirv.ptr<i32, StorageBuffer> as !spirv.coopmatrix<16x8xi32, Workgroup, MatrixA>, i32
+  %0 = spirv.KHR.CooperativeMatrixLoad %ptr, %stride, <RowMajor> :
+    !spirv.ptr<i32, StorageBuffer> as !spirv.coopmatrix<16x8xi32, Workgroup, MatrixA>, i32
   spirv.Return
 }
 
 // CHECK-LABEL: @cooperative_matrix_load_memoperand
 spirv.func @cooperative_matrix_load_memoperand(%ptr : !spirv.ptr<i32, StorageBuffer>, %stride : i32) "None" {
-  // CHECK:      {{%.*}} = spirv.KHR.CooperativeMatrixLoad {{%.*}}, {{%.*}}, ColumnMajor ["Volatile"] :
-  // CHECK-SAME:   !spirv.ptr<i32, StorageBuffer> as !spirv.coopmatrix<16x8xi32, Workgroup, MatrixA>
-  %0 = spirv.KHR.CooperativeMatrixLoad %ptr, %stride, ColumnMajor ["Volatile"] :
-    !spirv.ptr<i32, StorageBuffer> as !spirv.coopmatrix<16x8xi32, Workgroup, MatrixA>
+  // CHECK:      {{%.*}} = spirv.KHR.CooperativeMatrixLoad {{%.*}}, {{%.*}}, <ColumnMajor>, <Volatile> :
+  // CHECK-SAME:   !spirv.ptr<i32, StorageBuffer> as !spirv.coopmatrix<16x8xi32, Workgroup, MatrixA>, i32
+  %0 = spirv.KHR.CooperativeMatrixLoad %ptr, %stride, <ColumnMajor>, <Volatile> :
+    !spirv.ptr<i32, StorageBuffer> as !spirv.coopmatrix<16x8xi32, Workgroup, MatrixA>, i32
   spirv.Return
 }
 
 // CHECK-LABEL: @cooperative_matrix_load_vector_ptr_type
 spirv.func @cooperative_matrix_load_vector_ptr_type(%ptr : !spirv.ptr<vector<4xi32>, StorageBuffer>, %stride : i32) "None" {
-  // CHECK:      {{%.*}} = spirv.KHR.CooperativeMatrixLoad {{%.*}}, {{%.*}}, RowMajor ["Volatile"] :
-  // CHECK-SAME:   !spirv.ptr<vector<4xi32>, StorageBuffer> as !spirv.coopmatrix<8x16xi32, Subgroup, MatrixB>
-  %0 = spirv.KHR.CooperativeMatrixLoad %ptr, %stride, RowMajor ["Volatile"] :
-    !spirv.ptr<vector<4xi32>, StorageBuffer> as !spirv.coopmatrix<8x16xi32, Subgroup, MatrixB>
+  // CHECK:      {{%.*}} = spirv.KHR.CooperativeMatrixLoad {{%.*}}, {{%.*}}, <RowMajor>, <Volatile> :
+  // CHECK-SAME:   !spirv.ptr<vector<4xi32>, StorageBuffer> as !spirv.coopmatrix<8x16xi32, Subgroup, MatrixB>, i32
+  %0 = spirv.KHR.CooperativeMatrixLoad %ptr, %stride, <RowMajor>, <Volatile> :
+    !spirv.ptr<vector<4xi32>, StorageBuffer> as !spirv.coopmatrix<8x16xi32, Subgroup, MatrixB>, i32
   spirv.Return
 }
 
 // CHECK-LABEL: @cooperative_matrix_load_function
 spirv.func @cooperative_matrix_load_function(%ptr : !spirv.ptr<i32, Function>, %stride : i32) "None" {
-  // CHECK:      {{%.*}} = spirv.KHR.CooperativeMatrixLoad {{%.*}}, {{%.*}}, RowMajor :
+  // CHECK:      {{%.*}} = spirv.KHR.CooperativeMatrixLoad {{%.*}}, {{%.*}}, <RowMajor> :
   // CHECK-SAME:   !spirv.ptr<i32, Function> as !spirv.coopmatrix<8x16xi32, Subgroup, MatrixAcc>
-  %0 = spirv.KHR.CooperativeMatrixLoad %ptr, %stride, RowMajor :
-    !spirv.ptr<i32, Function> as !spirv.coopmatrix<8x16xi32, Subgroup, MatrixAcc>
+  %0 = spirv.KHR.CooperativeMatrixLoad %ptr, %stride, <RowMajor> :
+    !spirv.ptr<i32, Function> as !spirv.coopmatrix<8x16xi32, Subgroup, MatrixAcc>, i32
+  spirv.Return
+}
+
+// CHECK-LABEL: @cooperative_matrix_load_stride_i16
+spirv.func @cooperative_matrix_load_stride_i16(%ptr : !spirv.ptr<i32, StorageBuffer>, %stride : i16) "None" {
+  // CHECK:      {{%.*}} = spirv.KHR.CooperativeMatrixLoad {{%.*}}, {{%.*}}, <RowMajor> :
+  // CHECK-SAME:   !spirv.ptr<i32, StorageBuffer> as !spirv.coopmatrix<16x8xi32, Workgroup, MatrixA>, i16
+  %0 = spirv.KHR.CooperativeMatrixLoad %ptr, %stride, <RowMajor> :
+    !spirv.ptr<i32, StorageBuffer> as !spirv.coopmatrix<16x8xi32, Workgroup, MatrixA>, i16
   spirv.Return
 }
 
@@ -82,8 +91,8 @@ spirv.func @cooperative_matrix_store_memoperand(%ptr : !spirv.ptr<i32, StorageBu
 
 spirv.func @cooperative_matrix_load_bad_ptr(%ptr : !spirv.ptr<!spirv.struct<(f32 [0])>, StorageBuffer>, %stride : i32) "None" {
   // expected-error @+1 {{Pointer must point to a scalar or vector type}}
-  %0 = spirv.KHR.CooperativeMatrixLoad %ptr, %stride, ColumnMajor :
-    !spirv.ptr<!spirv.struct<(f32 [0])>, StorageBuffer> as !spirv.coopmatrix<8x16xi32, Subgroup, MatrixA>
+  %0 = spirv.KHR.CooperativeMatrixLoad %ptr, %stride, <ColumnMajor> :
+    !spirv.ptr<!spirv.struct<(f32 [0])>, StorageBuffer> as !spirv.coopmatrix<8x16xi32, Subgroup, MatrixA>, i32
   spirv.Return
 }
 
@@ -92,16 +101,16 @@ spirv.func @cooperative_matrix_load_bad_ptr(%ptr : !spirv.ptr<!spirv.struct<(f32
 spirv.func @cooperative_matrix_load_missing_attr(%ptr : !spirv.ptr<i32, StorageBuffer>, %stride : i32) "None" {
   // expected-error @+1 {{expected ','}}
   %0 = spirv.KHR.CooperativeMatrixLoad %ptr, %stride :
-    !spirv.ptr<i32, StorageBuffer> as !spirv.coopmatrix<8x16xi32, Subgroup, MatrixA>
+    !spirv.ptr<i32, StorageBuffer> as !spirv.coopmatrix<8x16xi32, Subgroup, MatrixA>, i32
   spirv.Return
 }
 
 // -----
 
 spirv.func @cooperative_matrix_load_missing_attr(%ptr : !spirv.ptr<i32, StorageBuffer>, %stride : i32) "None" {
-  // expected-error @+1 {{expected valid keyword}}
+  // expected-error @+1 {{expected '<'}}
   %0 = spirv.KHR.CooperativeMatrixLoad %ptr, %stride, :
-    !spirv.ptr<i32, StorageBuffer> as !spirv.NV.coopmatrix<8x16xi32, Subgroup, MatrixA>
+    !spirv.ptr<i32, StorageBuffer> as !spirv.NV.coopmatrix<8x16xi32, Subgroup, MatrixA>, i32
   spirv.Return
 }
 
@@ -109,8 +118,8 @@ spirv.func @cooperative_matrix_load_missing_attr(%ptr : !spirv.ptr<i32, StorageB
 
 spirv.func @cooperative_matrix_load_bad_result(%ptr : !spirv.ptr<i32, StorageBuffer>, %stride : i32) "None" {
   // expected-error @+1 {{op result #0 must be any SPIR-V cooperative matrix type}}
-  %0 = spirv.KHR.CooperativeMatrixLoad %ptr, %stride, ColumnMajor :
-    !spirv.ptr<i32, StorageBuffer> as !spirv.NV.coopmatrix<8x16xi32, Subgroup>
+  %0 = spirv.KHR.CooperativeMatrixLoad %ptr, %stride, <ColumnMajor> :
+    !spirv.ptr<i32, StorageBuffer> as !spirv.NV.coopmatrix<8x16xi32, Subgroup>, i32
   spirv.Return
 }
 

>From da404cb9e9d4971d5b976335780a8d7cd54c1d8c Mon Sep 17 00:00:00 2001
From: Jakub Kuderski <jakub at nod-labs.com>
Date: Fri, 8 Sep 2023 11:41:12 -0400
Subject: [PATCH 2/2] Use functional type for consistency

---
 .../SPIRV/IR/SPIRVCooperativeMatrixOps.td     | 17 ++++++-----
 .../SPIRV/IR/cooperative-matrix-ops.mlir      | 28 +++++++++----------
 mlir/tools/mlir-tblgen/SPIRVUtilsGen.cpp      |  6 ++--
 3 files changed, 27 insertions(+), 24 deletions(-)

diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVCooperativeMatrixOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVCooperativeMatrixOps.td
index 5d6c5f057ae6acc..31d26706faecb1d 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVCooperativeMatrixOps.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVCooperativeMatrixOps.td
@@ -103,25 +103,28 @@ def SPIRV_KHRCooperativeMatrixLoadOp : SPIRV_KhrVendorOp<"CooperativeMatrixLoad"
                               ssa-use `,` ssa-use `,`
                               `<` cooperative-matrix-layout `>`
                               (`,` `<` memory-operand `>`)? `:`
-                                pointer-type `as` cooperative-matrix-type `,` stride-type
+                                pointer-type `,` stride-type `->` cooperative-matrix-type
     ```
 
+    TODO: In the SPIR-V spec, `stride` is an optional argument. We should also
+    support this optionality in the SPIR-V dialect.
+
     #### Example:
 
     ```
     %0 = spirv.KHR.CooperativeMatrixLoad %ptr, %stride, <RowMajor>
-         : !spirv.ptr<i32, StorageBuffer>
-             as !spirv.KHR.coopmatrix<16x8xi32, Workgroup, MatrixA>, i32
+         : !spirv.ptr<i32, StorageBuffer>, i32
+             -> !spirv.KHR.coopmatrix<16x8xi32, Workgroup, MatrixA>
 
     %1 = spirv.KHR.CooperativeMatrixLoad %ptr, %stride, <ColumnMajor>, <Volatile>
-         : !spirv.ptr<f32, StorageBuffer>
-             as !spirv.KHR.coopmatrix<8x8xf32, Subgroup, MatrixAcc>, i64
+         : !spirv.ptr<f32, StorageBuffer>, i64
+             -> !spirv.KHR.coopmatrix<8x8xf32, Subgroup, MatrixAcc>
     ```
   }];
 
   let assemblyFormat = [{
     $pointer `,` $stride `,` $matrix_layout ( `,` $memory_operand^ )? attr-dict `:`
-      type($pointer) `as` type($result) `,` type($stride)
+      type(operands) `->` type($result)
   }];
 
   let availability = [
@@ -133,8 +136,8 @@ def SPIRV_KHRCooperativeMatrixLoadOp : SPIRV_KhrVendorOp<"CooperativeMatrixLoad"
 
   let arguments = (ins
     SPIRV_AnyPtr:$pointer,
-    SPIRV_Integer:$stride,
     SPIRV_KHR_CooperativeMatrixLayoutAttr:$matrix_layout,
+    SPIRV_Integer:$stride,
     OptionalAttr<SPIRV_MemoryAccessAttr>:$memory_operand
   );
 
diff --git a/mlir/test/Dialect/SPIRV/IR/cooperative-matrix-ops.mlir b/mlir/test/Dialect/SPIRV/IR/cooperative-matrix-ops.mlir
index 4a34674b810309d..aad1e44bf8f7bd1 100644
--- a/mlir/test/Dialect/SPIRV/IR/cooperative-matrix-ops.mlir
+++ b/mlir/test/Dialect/SPIRV/IR/cooperative-matrix-ops.mlir
@@ -24,45 +24,45 @@ spirv.func @cooperative_matrix_length_wrong_matrix() -> i32 "None" {
 // CHECK-LABEL: @cooperative_matrix_load
 spirv.func @cooperative_matrix_load(%ptr : !spirv.ptr<i32, StorageBuffer>, %stride : i32) "None" {
   // CHECK:      {{%.*}} = spirv.KHR.CooperativeMatrixLoad {{%.*}}, {{%.*}}, <RowMajor> :
-  // CHECK-SAME:   !spirv.ptr<i32, StorageBuffer> as !spirv.coopmatrix<16x8xi32, Workgroup, MatrixA>, i32
+  // CHECK-SAME:   !spirv.ptr<i32, StorageBuffer>, i32 -> !spirv.coopmatrix<16x8xi32, Workgroup, MatrixA>
   %0 = spirv.KHR.CooperativeMatrixLoad %ptr, %stride, <RowMajor> :
-    !spirv.ptr<i32, StorageBuffer> as !spirv.coopmatrix<16x8xi32, Workgroup, MatrixA>, i32
+    !spirv.ptr<i32, StorageBuffer>, i32 -> !spirv.coopmatrix<16x8xi32, Workgroup, MatrixA>
   spirv.Return
 }
 
 // CHECK-LABEL: @cooperative_matrix_load_memoperand
 spirv.func @cooperative_matrix_load_memoperand(%ptr : !spirv.ptr<i32, StorageBuffer>, %stride : i32) "None" {
   // CHECK:      {{%.*}} = spirv.KHR.CooperativeMatrixLoad {{%.*}}, {{%.*}}, <ColumnMajor>, <Volatile> :
-  // CHECK-SAME:   !spirv.ptr<i32, StorageBuffer> as !spirv.coopmatrix<16x8xi32, Workgroup, MatrixA>, i32
+  // CHECK-SAME:   !spirv.ptr<i32, StorageBuffer>, i32 -> !spirv.coopmatrix<16x8xi32, Workgroup, MatrixA>
   %0 = spirv.KHR.CooperativeMatrixLoad %ptr, %stride, <ColumnMajor>, <Volatile> :
-    !spirv.ptr<i32, StorageBuffer> as !spirv.coopmatrix<16x8xi32, Workgroup, MatrixA>, i32
+    !spirv.ptr<i32, StorageBuffer>, i32 -> !spirv.coopmatrix<16x8xi32, Workgroup, MatrixA>
   spirv.Return
 }
 
 // CHECK-LABEL: @cooperative_matrix_load_vector_ptr_type
 spirv.func @cooperative_matrix_load_vector_ptr_type(%ptr : !spirv.ptr<vector<4xi32>, StorageBuffer>, %stride : i32) "None" {
   // CHECK:      {{%.*}} = spirv.KHR.CooperativeMatrixLoad {{%.*}}, {{%.*}}, <RowMajor>, <Volatile> :
-  // CHECK-SAME:   !spirv.ptr<vector<4xi32>, StorageBuffer> as !spirv.coopmatrix<8x16xi32, Subgroup, MatrixB>, i32
+  // CHECK-SAME:   !spirv.ptr<vector<4xi32>, StorageBuffer>, i32 -> !spirv.coopmatrix<8x16xi32, Subgroup, MatrixB>
   %0 = spirv.KHR.CooperativeMatrixLoad %ptr, %stride, <RowMajor>, <Volatile> :
-    !spirv.ptr<vector<4xi32>, StorageBuffer> as !spirv.coopmatrix<8x16xi32, Subgroup, MatrixB>, i32
+    !spirv.ptr<vector<4xi32>, StorageBuffer>, i32 -> !spirv.coopmatrix<8x16xi32, Subgroup, MatrixB>
   spirv.Return
 }
 
 // CHECK-LABEL: @cooperative_matrix_load_function
 spirv.func @cooperative_matrix_load_function(%ptr : !spirv.ptr<i32, Function>, %stride : i32) "None" {
   // CHECK:      {{%.*}} = spirv.KHR.CooperativeMatrixLoad {{%.*}}, {{%.*}}, <RowMajor> :
-  // CHECK-SAME:   !spirv.ptr<i32, Function> as !spirv.coopmatrix<8x16xi32, Subgroup, MatrixAcc>
+  // CHECK-SAME:   !spirv.ptr<i32, Function>, i32 -> !spirv.coopmatrix<8x16xi32, Subgroup, MatrixAcc>
   %0 = spirv.KHR.CooperativeMatrixLoad %ptr, %stride, <RowMajor> :
-    !spirv.ptr<i32, Function> as !spirv.coopmatrix<8x16xi32, Subgroup, MatrixAcc>, i32
+    !spirv.ptr<i32, Function>, i32 -> !spirv.coopmatrix<8x16xi32, Subgroup, MatrixAcc>
   spirv.Return
 }
 
 // CHECK-LABEL: @cooperative_matrix_load_stride_i16
 spirv.func @cooperative_matrix_load_stride_i16(%ptr : !spirv.ptr<i32, StorageBuffer>, %stride : i16) "None" {
   // CHECK:      {{%.*}} = spirv.KHR.CooperativeMatrixLoad {{%.*}}, {{%.*}}, <RowMajor> :
-  // CHECK-SAME:   !spirv.ptr<i32, StorageBuffer> as !spirv.coopmatrix<16x8xi32, Workgroup, MatrixA>, i16
+  // CHECK-SAME:   !spirv.ptr<i32, StorageBuffer>, i16 -> !spirv.coopmatrix<16x8xi32, Workgroup, MatrixA>
   %0 = spirv.KHR.CooperativeMatrixLoad %ptr, %stride, <RowMajor> :
-    !spirv.ptr<i32, StorageBuffer> as !spirv.coopmatrix<16x8xi32, Workgroup, MatrixA>, i16
+    !spirv.ptr<i32, StorageBuffer>, i16 -> !spirv.coopmatrix<16x8xi32, Workgroup, MatrixA>
   spirv.Return
 }
 
@@ -92,7 +92,7 @@ spirv.func @cooperative_matrix_store_memoperand(%ptr : !spirv.ptr<i32, StorageBu
 spirv.func @cooperative_matrix_load_bad_ptr(%ptr : !spirv.ptr<!spirv.struct<(f32 [0])>, StorageBuffer>, %stride : i32) "None" {
   // expected-error @+1 {{Pointer must point to a scalar or vector type}}
   %0 = spirv.KHR.CooperativeMatrixLoad %ptr, %stride, <ColumnMajor> :
-    !spirv.ptr<!spirv.struct<(f32 [0])>, StorageBuffer> as !spirv.coopmatrix<8x16xi32, Subgroup, MatrixA>, i32
+    !spirv.ptr<!spirv.struct<(f32 [0])>, StorageBuffer>, i32 -> !spirv.coopmatrix<8x16xi32, Subgroup, MatrixA>
   spirv.Return
 }
 
@@ -101,7 +101,7 @@ spirv.func @cooperative_matrix_load_bad_ptr(%ptr : !spirv.ptr<!spirv.struct<(f32
 spirv.func @cooperative_matrix_load_missing_attr(%ptr : !spirv.ptr<i32, StorageBuffer>, %stride : i32) "None" {
   // expected-error @+1 {{expected ','}}
   %0 = spirv.KHR.CooperativeMatrixLoad %ptr, %stride :
-    !spirv.ptr<i32, StorageBuffer> as !spirv.coopmatrix<8x16xi32, Subgroup, MatrixA>, i32
+    !spirv.ptr<i32, StorageBuffer>, i32 -> !spirv.coopmatrix<8x16xi32, Subgroup, MatrixA>
   spirv.Return
 }
 
@@ -110,7 +110,7 @@ spirv.func @cooperative_matrix_load_missing_attr(%ptr : !spirv.ptr<i32, StorageB
 spirv.func @cooperative_matrix_load_missing_attr(%ptr : !spirv.ptr<i32, StorageBuffer>, %stride : i32) "None" {
   // expected-error @+1 {{expected '<'}}
   %0 = spirv.KHR.CooperativeMatrixLoad %ptr, %stride, :
-    !spirv.ptr<i32, StorageBuffer> as !spirv.NV.coopmatrix<8x16xi32, Subgroup, MatrixA>, i32
+    !spirv.ptr<i32, StorageBuffer>, i32 -> !spirv.NV.coopmatrix<8x16xi32, Subgroup, MatrixA>
   spirv.Return
 }
 
@@ -119,7 +119,7 @@ spirv.func @cooperative_matrix_load_missing_attr(%ptr : !spirv.ptr<i32, StorageB
 spirv.func @cooperative_matrix_load_bad_result(%ptr : !spirv.ptr<i32, StorageBuffer>, %stride : i32) "None" {
   // expected-error @+1 {{op result #0 must be any SPIR-V cooperative matrix type}}
   %0 = spirv.KHR.CooperativeMatrixLoad %ptr, %stride, <ColumnMajor> :
-    !spirv.ptr<i32, StorageBuffer> as !spirv.NV.coopmatrix<8x16xi32, Subgroup>, i32
+    !spirv.ptr<i32, StorageBuffer>, i32 -> !spirv.NV.coopmatrix<8x16xi32, Subgroup>
   spirv.Return
 }
 
diff --git a/mlir/tools/mlir-tblgen/SPIRVUtilsGen.cpp b/mlir/tools/mlir-tblgen/SPIRVUtilsGen.cpp
index ccf4240f8e56089..8468f92600a44e9 100644
--- a/mlir/tools/mlir-tblgen/SPIRVUtilsGen.cpp
+++ b/mlir/tools/mlir-tblgen/SPIRVUtilsGen.cpp
@@ -929,9 +929,9 @@ static void emitOperandDeserialization(const Operator &op, ArrayRef<SMLoc> loc,
     if (auto *valueArg = llvm::dyn_cast_if_present<NamedTypeConstraint *>(argument)) {
       if (valueArg->isVariableLength()) {
         if (i != e - 1) {
-          PrintFatalError(loc, "SPIR-V ops can have Variadic<..> or "
-                               "std::optional<...> arguments only if "
-                               "it's the last argument");
+          PrintFatalError(
+              loc, "SPIR-V ops can have Variadic<..> or "
+                   "Optional<...> arguments only if it's the last argument");
         }
         os << tabs
            << formatv("for (; {0} < {1}.size(); ++{0})", wordIndex, words);



More information about the Mlir-commits mailing list