[Mlir-commits] [mlir] [mlir][spirv] Improve coop matrix attribute handling (PR #66020)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Mon Sep 11 15:16:40 PDT 2023


llvmbot wrote:

@llvm/pr-subscribers-mlir-core

<details>
<summary>Changes</summary>

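- Fix the Cooperative Matrix Operands bit enum cases: `I32BitEnumAttrCaseBit` takes a bit position, but the cases were given the mask values (1, 2, 4, 8, 16) instead of the bit positions (0–4).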
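- Add verification of the Memory Operand attribute on `spirv.KHR.CooperativeMatrixLoad` and `spirv.KHR.CooperativeMatrixStore`, and reject the 'Aligned' enumerant as not yet supported. 'Aligned' requires a trailing alignment literal, which op parsing and (de-)serialization do not handle yet.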

The updated target test (`mlir/test/Target/SPIRV/khr-cooperative-matrix-ops.mlir`) passes validation with `spirv-val`.
--
Full diff: https://github.com/llvm/llvm-project/pull/66020.diff

4 Files Affected:

- (modified) mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td (+5-5) 
- (modified) mlir/lib/Dialect/SPIRV/IR/CooperativeMatrixOps.cpp (+18-7) 
- (modified) mlir/test/Dialect/SPIRV/IR/cooperative-matrix-ops.mlir (+28) 
- (modified) mlir/test/Target/SPIRV/khr-cooperative-matrix-ops.mlir (+19-2) 


<pre>
diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
index 2ce3ad875fa45d1..1013cbc8ca562b7 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
@@ -4077,11 +4077,11 @@ def SPIRV_KHR_CooperativeMatrixLayoutAttr :
 
 // Cooperative Matrix Operands for the SPV_KHR_cooperative_matrix extension.
 def SPIRV_KHR_CMO_None           : I32BitEnumAttrCaseNone<"None">;
-def SPIRV_KHR_CMO_MatrixA_Signed : I32BitEnumAttrCaseBit<"ASigned", 1>;
-def SPIRV_KHR_CMO_MatrixB_Signed : I32BitEnumAttrCaseBit<"BSigned", 2>;
-def SPIRV_KHR_CMO_MatrixC_Signed : I32BitEnumAttrCaseBit<"CSigned", 4>;
-def SPIRV_KHR_CMO_Result_Signed  : I32BitEnumAttrCaseBit<"ResultSigned", 8>;
-def SPIRV_KHR_CMO_AccSat         : I32BitEnumAttrCaseBit<"AccSat", 16>;
+def SPIRV_KHR_CMO_MatrixA_Signed : I32BitEnumAttrCaseBit<"ASigned", 0>;
+def SPIRV_KHR_CMO_MatrixB_Signed : I32BitEnumAttrCaseBit<"BSigned", 1>;
+def SPIRV_KHR_CMO_MatrixC_Signed : I32BitEnumAttrCaseBit<"CSigned", 2>;
+def SPIRV_KHR_CMO_Result_Signed  : I32BitEnumAttrCaseBit<"ResultSigned", 3>;
+def SPIRV_KHR_CMO_AccSat         : I32BitEnumAttrCaseBit<"AccSat", 4>;
 
 def SPIRV_KHR_CooperativeMatrixOperandsAttr :
     SPIRV_BitEnumAttr<"CooperativeMatrixOperandsKHR",
diff --git a/mlir/lib/Dialect/SPIRV/IR/CooperativeMatrixOps.cpp b/mlir/lib/Dialect/SPIRV/IR/CooperativeMatrixOps.cpp
index 77dbf130c777857..d43f7a1823e912b 100644
--- a/mlir/lib/Dialect/SPIRV/IR/CooperativeMatrixOps.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/CooperativeMatrixOps.cpp
@@ -11,6 +11,7 @@
 //===----------------------------------------------------------------------===//
 
 #include "SPIRVParsingUtils.h"
+#include "mlir/Dialect/SPIRV/IR/SPIRVAttributes.h"
 #include "mlir/Dialect/SPIRV/IR/SPIRVEnums.h"
 #include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
 #include "llvm/ADT/STLExtras.h"
@@ -23,16 +24,26 @@ namespace mlir::spirv {
 // spirv.KHR.CooperativeMatrixLoad
 //===----------------------------------------------------------------------===//
 
-static LogicalResult verifyPointerAndCoopMatrixType(Operation *op, Type pointer,
-                                                    Type coopMatrix) {
+static LogicalResult
+verifyCoopMatrixAccess(Operation *op, Type pointer, Type coopMatrix,
+                       spirv::MemoryAccessAttr memoryOperand) {
   auto pointerType = cast<PointerType>(pointer);
   Type pointeeType = pointerType.getPointeeType();
   if (!isa<ScalarType, VectorType>(pointeeType)) {
-    return op->emitError(
+    return op->emitOpError(
                "Pointer must point to a scalar or vector type but provided ")
            << pointeeType;
   }
 
+  // The 'Aligned' memory operand requires an alignment literal to follow, which
+  // needs to be implemented on the level of op parsing and (de-)serialization.
+  // TODO: Consider adding support for this attribute value.
+  if (memoryOperand &&
+      spirv::bitEnumContainsAll(memoryOperand.getValue(),
+                                spirv::MemoryAccess::Aligned)) {
+    return op->emitOpError("has unhandled memory operand 'Aligned'");
+  }
+
   // TODO: Verify the memory object behind the pointer:
   // > If the Shader capability was declared, Pointer must point into an array
   // > and any ArrayStride decoration on Pointer is ignored.
@@ -41,8 +52,8 @@ static LogicalResult verifyPointerAndCoopMatrixType(Operation *op, Type pointer,
 }
 
 LogicalResult KHRCooperativeMatrixLoadOp::verify() {
-  return verifyPointerAndCoopMatrixType(*this, getPointer().getType(),
-                                        getResult().getType());
+  return verifyCoopMatrixAccess(*this, getPointer().getType(),
+                                getResult().getType(), getMemoryOperandAttr());
 }
 
 //===----------------------------------------------------------------------===//
@@ -50,8 +61,8 @@ LogicalResult KHRCooperativeMatrixLoadOp::verify() {
 //===----------------------------------------------------------------------===//
 
 LogicalResult KHRCooperativeMatrixStoreOp::verify() {
-  return verifyPointerAndCoopMatrixType(*this, getPointer().getType(),
-                                        getObject().getType());
+  return verifyCoopMatrixAccess(*this, getPointer().getType(),
+                                getObject().getType(), getMemoryOperandAttr());
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/SPIRV/IR/cooperative-matrix-ops.mlir b/mlir/test/Dialect/SPIRV/IR/cooperative-matrix-ops.mlir
index 40736367520e843..3adcd711f74a8f8 100644
--- a/mlir/test/Dialect/SPIRV/IR/cooperative-matrix-ops.mlir
+++ b/mlir/test/Dialect/SPIRV/IR/cooperative-matrix-ops.mlir
@@ -136,6 +136,24 @@ spirv.func @cooperative_matrix_load_bad_result(%ptr : !spirv.ptr<i32, StorageBuf
 
 // -----
 
+spirv.func @cooperative_matrix_load_aligned(%ptr : !spirv.ptr<i32, StorageBuffer>, %stride : i32) "None" {
+  // expected-error @+1 {{op has unhandled memory operand 'Aligned'}}
+  %0 = spirv.KHR.CooperativeMatrixLoad %ptr, %stride, <ColumnMajor>, <Aligned> :
+    !spirv.ptr<i32, StorageBuffer>, i32 -> !spirv.coopmatrix<8x16xi32, Subgroup, MatrixA>
+  spirv.Return
+}
+
+// -----
+
+spirv.func @cooperative_matrix_load_aligned(%ptr : !spirv.ptr<i32, StorageBuffer>, %stride : i32) "None" {
+  // expected-error @+1 {{op has unhandled memory operand 'Aligned'}}
+  %0 = spirv.KHR.CooperativeMatrixLoad %ptr, %stride, <ColumnMajor>, <Volatile|Aligned> :
+    !spirv.ptr<i32, StorageBuffer>, i32 -> !spirv.coopmatrix<8x16xi32, Subgroup, MatrixA>
+  spirv.Return
+}
+
+// -----
+
 spirv.func @cooperative_matrix_store_missing_attr(%ptr : !spirv.ptr<i32, StorageBuffer>, %stride : i32,
                                                   %m : !spirv.coopmatrix<8x16xi32, Workgroup, MatrixA>) "None" {
   // expected-error @+1 {{expected ','}}
@@ -166,6 +184,16 @@ spirv.func @cooperative_matrix_store_bad_object_type(%ptr : !spirv.ptr<i32, Stor
 
 // -----
 
+spirv.func @cooperative_matrix_store(%ptr : !spirv.ptr<i32, StorageBuffer>, %stride : i32,
+                                     %m : !spirv.coopmatrix<8x16xi32, Workgroup, MatrixA>) "None" {
+  // expected-error @+1 {{op has unhandled memory operand 'Aligned'}}
+  spirv.KHR.CooperativeMatrixStore %ptr, %m, %stride, <RowMajor>, <Aligned> :
+    !spirv.ptr<i32, StorageBuffer>, !spirv.coopmatrix<8x16xi32, Workgroup, MatrixA>, i32
+  spirv.Return
+}
+
+// -----
+
 spirv.func @cooperative_matrix_muladd(%a : !spirv.coopmatrix<8x16xi32, Subgroup, MatrixA>,
                                       %b : !spirv.coopmatrix<16x4xi32, Subgroup, MatrixB>,
                                       %c : !spirv.coopmatrix<8x4xi32, Subgroup, MatrixAcc>) "None" {
diff --git a/mlir/test/Target/SPIRV/khr-cooperative-matrix-ops.mlir b/mlir/test/Target/SPIRV/khr-cooperative-matrix-ops.mlir
index 8546172f4f797b5..153ff4793797267 100644
--- a/mlir/test/Target/SPIRV/khr-cooperative-matrix-ops.mlir
+++ b/mlir/test/Target/SPIRV/khr-cooperative-matrix-ops.mlir
@@ -37,6 +37,10 @@ spirv.module Logical GLSL450 requires
     // CHECK-SAME:   : !spirv.ptr<i32, StorageBuffer>, !spirv.coopmatrix<16x8xi32, Workgroup, MatrixA>, i32
     spirv.KHR.CooperativeMatrixStore %ptr, %m, %stride, <RowMajor> :
       !spirv.ptr<i32, StorageBuffer>, !spirv.coopmatrix<16x8xi32, Workgroup, MatrixA>, i32
+
+    // CHECK-NEXT:  spirv.KHR.CooperativeMatrixStore {{%.*}}, {{%.*}}, <RowMajor>, <Volatile|Nontemporal>
+    spirv.KHR.CooperativeMatrixStore %ptr, %m, %stride, <RowMajor>, <Volatile|Nontemporal> :
+      !spirv.ptr<i32, StorageBuffer>, !spirv.coopmatrix<16x8xi32, Workgroup, MatrixA>, i32
     spirv.Return
   }
 
@@ -62,16 +66,29 @@ spirv.module Logical GLSL450 requires
                                                         !spirv.coopmatrix<16x8xi16, Subgroup, MatrixB>
                                                         -> !spirv.coopmatrix<8x8xi32, Subgroup, MatrixAcc>
 
-    // CHECK-NEXT: {{%.+}} = spirv.KHR.CooperativeMatrixMulAdd {{%.*}}, {{%.*}}, {{%.*}}, <BSigned> :
+    // CHECK-NEXT: {{%.+}} = spirv.KHR.CooperativeMatrixMulAdd {{%.*}}, {{%.*}}, {{%.*}}, <ASigned> :
     // CHECK-SAME:   !spirv.coopmatrix<8x16xi8, Subgroup, MatrixA>,
     // CHECK-SAME:   !spirv.coopmatrix<16x8xi16, Subgroup, MatrixB>
     // CHECK-SAME:   -> !spirv.coopmatrix<8x8xi32, Subgroup, MatrixAcc>
     %q = spirv.KHR.CooperativeMatrixMulAdd %a, %b, %c,
+                                           <ASigned> : !spirv.coopmatrix<8x16xi8, Subgroup, MatrixA>,
+                                                       !spirv.coopmatrix<16x8xi16, Subgroup, MatrixB>
+                                                       -> !spirv.coopmatrix<8x8xi32, Subgroup, MatrixAcc>
+
+    // CHECK-NEXT: {{%.+}} = spirv.KHR.CooperativeMatrixMulAdd {{%.*}}, {{%.*}}, {{%.*}}, <BSigned> :
+    %r = spirv.KHR.CooperativeMatrixMulAdd %a, %b, %c,
                                            <BSigned> : !spirv.coopmatrix<8x16xi8, Subgroup, MatrixA>,
                                                        !spirv.coopmatrix<16x8xi16, Subgroup, MatrixB>
                                                        -> !spirv.coopmatrix<8x8xi32, Subgroup, MatrixAcc>
 
-    // TODO: Handle multiple matrix operands and add relevant testcases here.
+    // CHECK-NEXT: {{%.+}} = spirv.KHR.CooperativeMatrixMulAdd
+    // CHECK-SAME:           <ASigned|BSigned|ResultSigned|AccSat> :
+    %s = spirv.KHR.CooperativeMatrixMulAdd %a, %b, %c,
+                                           <ASigned|BSigned|ResultSigned|AccSat> :
+                                           !spirv.coopmatrix<8x16xi8, Subgroup, MatrixA>,
+                                           !spirv.coopmatrix<16x8xi16, Subgroup, MatrixB>
+                                           -> !spirv.coopmatrix<8x8xi32, Subgroup, MatrixAcc>
+
     spirv.Return
   }
 
</pre>

</details>

https://github.com/llvm/llvm-project/pull/66020


More information about the Mlir-commits mailing list