[Mlir-commits] [mlir] [mlir][spirv] Improve coop matrix attribute handling (PR #66020)

Jakub Kuderski llvmlistbot at llvm.org
Mon Sep 11 15:15:44 PDT 2023


https://github.com/kuhar created https://github.com/llvm/llvm-project/pull/66020:

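- Fix the bit positions of the Cooperative Matrix Operand bit enum cases (they were given as mask values instead of bit indices)
- Add verification of the Memory Operand attribute on cooperative matrix load/store ops, and reject the 'Aligned' enumerant as not yet supported.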

The target test passes validation with `spirv-val`.

>From b39464c2640a67cf2f53949af0785e05af3fa7b3 Mon Sep 17 00:00:00 2001
From: Jakub Kuderski <jakub at nod-labs.com>
Date: Mon, 11 Sep 2023 18:12:57 -0400
Subject: [PATCH] [mlir][spirv] Improve coop matrix attribute handling

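- Fix the bit positions of the Cooperative Matrix Operand bit enum cases
  (they were given as mask values instead of bit indices)
- Add verification of the Memory Operand attribute on cooperative matrix
  load/store ops, and reject the 'Aligned' enumerant as not yet supported.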

The target test passes validation with `spirv-val`.
---
 .../mlir/Dialect/SPIRV/IR/SPIRVBase.td        | 10 +++----
 .../Dialect/SPIRV/IR/CooperativeMatrixOps.cpp | 25 ++++++++++++-----
 .../SPIRV/IR/cooperative-matrix-ops.mlir      | 28 +++++++++++++++++++
 .../SPIRV/khr-cooperative-matrix-ops.mlir     | 21 ++++++++++++--
 4 files changed, 70 insertions(+), 14 deletions(-)

diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
index 2ce3ad875fa45d1..1013cbc8ca562b7 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
@@ -4077,11 +4077,11 @@ def SPIRV_KHR_CooperativeMatrixLayoutAttr :
 
 // Cooperative Matrix Operands for the SPV_KHR_cooperative_matrix extension.
 def SPIRV_KHR_CMO_None           : I32BitEnumAttrCaseNone<"None">;
-def SPIRV_KHR_CMO_MatrixA_Signed : I32BitEnumAttrCaseBit<"ASigned", 1>;
-def SPIRV_KHR_CMO_MatrixB_Signed : I32BitEnumAttrCaseBit<"BSigned", 2>;
-def SPIRV_KHR_CMO_MatrixC_Signed : I32BitEnumAttrCaseBit<"CSigned", 4>;
-def SPIRV_KHR_CMO_Result_Signed  : I32BitEnumAttrCaseBit<"ResultSigned", 8>;
-def SPIRV_KHR_CMO_AccSat         : I32BitEnumAttrCaseBit<"AccSat", 16>;
+def SPIRV_KHR_CMO_MatrixA_Signed : I32BitEnumAttrCaseBit<"ASigned", 0>;
+def SPIRV_KHR_CMO_MatrixB_Signed : I32BitEnumAttrCaseBit<"BSigned", 1>;
+def SPIRV_KHR_CMO_MatrixC_Signed : I32BitEnumAttrCaseBit<"CSigned", 2>;
+def SPIRV_KHR_CMO_Result_Signed  : I32BitEnumAttrCaseBit<"ResultSigned", 3>;
+def SPIRV_KHR_CMO_AccSat         : I32BitEnumAttrCaseBit<"AccSat", 4>;
 
 def SPIRV_KHR_CooperativeMatrixOperandsAttr :
     SPIRV_BitEnumAttr<"CooperativeMatrixOperandsKHR",
diff --git a/mlir/lib/Dialect/SPIRV/IR/CooperativeMatrixOps.cpp b/mlir/lib/Dialect/SPIRV/IR/CooperativeMatrixOps.cpp
index 77dbf130c777857..d43f7a1823e912b 100644
--- a/mlir/lib/Dialect/SPIRV/IR/CooperativeMatrixOps.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/CooperativeMatrixOps.cpp
@@ -11,6 +11,7 @@
 //===----------------------------------------------------------------------===//
 
 #include "SPIRVParsingUtils.h"
+#include "mlir/Dialect/SPIRV/IR/SPIRVAttributes.h"
 #include "mlir/Dialect/SPIRV/IR/SPIRVEnums.h"
 #include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
 #include "llvm/ADT/STLExtras.h"
@@ -23,16 +24,26 @@ namespace mlir::spirv {
 // spirv.KHR.CooperativeMatrixLoad
 //===----------------------------------------------------------------------===//
 
-static LogicalResult verifyPointerAndCoopMatrixType(Operation *op, Type pointer,
-                                                    Type coopMatrix) {
+static LogicalResult
+verifyCoopMatrixAccess(Operation *op, Type pointer, Type coopMatrix,
+                       spirv::MemoryAccessAttr memoryOperand) {
   auto pointerType = cast<PointerType>(pointer);
   Type pointeeType = pointerType.getPointeeType();
   if (!isa<ScalarType, VectorType>(pointeeType)) {
-    return op->emitError(
+    return op->emitOpError(
                "Pointer must point to a scalar or vector type but provided ")
            << pointeeType;
   }
 
+  // The 'Aligned' memory operand requires an alignment literal to follow, which
+  // needs to be implemented on the level of op parsing and (de-)serialization.
+  // TODO: Consider adding support for this attribute value.
+  if (memoryOperand &&
+      spirv::bitEnumContainsAll(memoryOperand.getValue(),
+                                spirv::MemoryAccess::Aligned)) {
+    return op->emitOpError("has unhandled memory operand 'Aligned'");
+  }
+
   // TODO: Verify the memory object behind the pointer:
   // > If the Shader capability was declared, Pointer must point into an array
   // > and any ArrayStride decoration on Pointer is ignored.
@@ -41,8 +52,8 @@ static LogicalResult verifyPointerAndCoopMatrixType(Operation *op, Type pointer,
 }
 
 LogicalResult KHRCooperativeMatrixLoadOp::verify() {
-  return verifyPointerAndCoopMatrixType(*this, getPointer().getType(),
-                                        getResult().getType());
+  return verifyCoopMatrixAccess(*this, getPointer().getType(),
+                                getResult().getType(), getMemoryOperandAttr());
 }
 
 //===----------------------------------------------------------------------===//
@@ -50,8 +61,8 @@ LogicalResult KHRCooperativeMatrixLoadOp::verify() {
 //===----------------------------------------------------------------------===//
 
 LogicalResult KHRCooperativeMatrixStoreOp::verify() {
-  return verifyPointerAndCoopMatrixType(*this, getPointer().getType(),
-                                        getObject().getType());
+  return verifyCoopMatrixAccess(*this, getPointer().getType(),
+                                getObject().getType(), getMemoryOperandAttr());
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/SPIRV/IR/cooperative-matrix-ops.mlir b/mlir/test/Dialect/SPIRV/IR/cooperative-matrix-ops.mlir
index 40736367520e843..3adcd711f74a8f8 100644
--- a/mlir/test/Dialect/SPIRV/IR/cooperative-matrix-ops.mlir
+++ b/mlir/test/Dialect/SPIRV/IR/cooperative-matrix-ops.mlir
@@ -136,6 +136,24 @@ spirv.func @cooperative_matrix_load_bad_result(%ptr : !spirv.ptr<i32, StorageBuf
 
 // -----
 
+spirv.func @cooperative_matrix_load_aligned(%ptr : !spirv.ptr<i32, StorageBuffer>, %stride : i32) "None" {
+  // expected-error @+1 {{op has unhandled memory operand 'Aligned'}}
+  %0 = spirv.KHR.CooperativeMatrixLoad %ptr, %stride, <ColumnMajor>, <Aligned> :
+    !spirv.ptr<i32, StorageBuffer>, i32 -> !spirv.coopmatrix<8x16xi32, Subgroup, MatrixA>
+  spirv.Return
+}
+
+// -----
+
+spirv.func @cooperative_matrix_load_aligned(%ptr : !spirv.ptr<i32, StorageBuffer>, %stride : i32) "None" {
+  // expected-error @+1 {{op has unhandled memory operand 'Aligned'}}
+  %0 = spirv.KHR.CooperativeMatrixLoad %ptr, %stride, <ColumnMajor>, <Volatile|Aligned> :
+    !spirv.ptr<i32, StorageBuffer>, i32 -> !spirv.coopmatrix<8x16xi32, Subgroup, MatrixA>
+  spirv.Return
+}
+
+// -----
+
 spirv.func @cooperative_matrix_store_missing_attr(%ptr : !spirv.ptr<i32, StorageBuffer>, %stride : i32,
                                                   %m : !spirv.coopmatrix<8x16xi32, Workgroup, MatrixA>) "None" {
   // expected-error @+1 {{expected ','}}
@@ -166,6 +184,16 @@ spirv.func @cooperative_matrix_store_bad_object_type(%ptr : !spirv.ptr<i32, Stor
 
 // -----
 
+spirv.func @cooperative_matrix_store(%ptr : !spirv.ptr<i32, StorageBuffer>, %stride : i32,
+                                     %m : !spirv.coopmatrix<8x16xi32, Workgroup, MatrixA>) "None" {
+  // expected-error @+1 {{op has unhandled memory operand 'Aligned'}}
+  spirv.KHR.CooperativeMatrixStore %ptr, %m, %stride, <RowMajor>, <Aligned> :
+    !spirv.ptr<i32, StorageBuffer>, !spirv.coopmatrix<8x16xi32, Workgroup, MatrixA>, i32
+  spirv.Return
+}
+
+// -----
+
 spirv.func @cooperative_matrix_muladd(%a : !spirv.coopmatrix<8x16xi32, Subgroup, MatrixA>,
                                       %b : !spirv.coopmatrix<16x4xi32, Subgroup, MatrixB>,
                                       %c : !spirv.coopmatrix<8x4xi32, Subgroup, MatrixAcc>) "None" {
diff --git a/mlir/test/Target/SPIRV/khr-cooperative-matrix-ops.mlir b/mlir/test/Target/SPIRV/khr-cooperative-matrix-ops.mlir
index 8546172f4f797b5..153ff4793797267 100644
--- a/mlir/test/Target/SPIRV/khr-cooperative-matrix-ops.mlir
+++ b/mlir/test/Target/SPIRV/khr-cooperative-matrix-ops.mlir
@@ -37,6 +37,10 @@ spirv.module Logical GLSL450 requires
     // CHECK-SAME:   : !spirv.ptr<i32, StorageBuffer>, !spirv.coopmatrix<16x8xi32, Workgroup, MatrixA>, i32
     spirv.KHR.CooperativeMatrixStore %ptr, %m, %stride, <RowMajor> :
       !spirv.ptr<i32, StorageBuffer>, !spirv.coopmatrix<16x8xi32, Workgroup, MatrixA>, i32
+
+    // CHECK-NEXT:  spirv.KHR.CooperativeMatrixStore {{%.*}}, {{%.*}}, <RowMajor>, <Volatile|Nontemporal>
+    spirv.KHR.CooperativeMatrixStore %ptr, %m, %stride, <RowMajor>, <Volatile|Nontemporal> :
+      !spirv.ptr<i32, StorageBuffer>, !spirv.coopmatrix<16x8xi32, Workgroup, MatrixA>, i32
     spirv.Return
   }
 
@@ -62,16 +66,29 @@ spirv.module Logical GLSL450 requires
                                                         !spirv.coopmatrix<16x8xi16, Subgroup, MatrixB>
                                                         -> !spirv.coopmatrix<8x8xi32, Subgroup, MatrixAcc>
 
-    // CHECK-NEXT: {{%.+}} = spirv.KHR.CooperativeMatrixMulAdd {{%.*}}, {{%.*}}, {{%.*}}, <BSigned> :
+    // CHECK-NEXT: {{%.+}} = spirv.KHR.CooperativeMatrixMulAdd {{%.*}}, {{%.*}}, {{%.*}}, <ASigned> :
     // CHECK-SAME:   !spirv.coopmatrix<8x16xi8, Subgroup, MatrixA>,
     // CHECK-SAME:   !spirv.coopmatrix<16x8xi16, Subgroup, MatrixB>
     // CHECK-SAME:   -> !spirv.coopmatrix<8x8xi32, Subgroup, MatrixAcc>
     %q = spirv.KHR.CooperativeMatrixMulAdd %a, %b, %c,
+                                           <ASigned> : !spirv.coopmatrix<8x16xi8, Subgroup, MatrixA>,
+                                                       !spirv.coopmatrix<16x8xi16, Subgroup, MatrixB>
+                                                       -> !spirv.coopmatrix<8x8xi32, Subgroup, MatrixAcc>
+
+    // CHECK-NEXT: {{%.+}} = spirv.KHR.CooperativeMatrixMulAdd {{%.*}}, {{%.*}}, {{%.*}}, <BSigned> :
+    %r = spirv.KHR.CooperativeMatrixMulAdd %a, %b, %c,
                                            <BSigned> : !spirv.coopmatrix<8x16xi8, Subgroup, MatrixA>,
                                                        !spirv.coopmatrix<16x8xi16, Subgroup, MatrixB>
                                                        -> !spirv.coopmatrix<8x8xi32, Subgroup, MatrixAcc>
 
-    // TODO: Handle multiple matrix operands and add relevant testcases here.
+    // CHECK-NEXT: {{%.+}} = spirv.KHR.CooperativeMatrixMulAdd
+    // CHECK-SAME:           <ASigned|BSigned|ResultSigned|AccSat> :
+    %s = spirv.KHR.CooperativeMatrixMulAdd %a, %b, %c,
+                                           <ASigned|BSigned|ResultSigned|AccSat> :
+                                           !spirv.coopmatrix<8x16xi8, Subgroup, MatrixA>,
+                                           !spirv.coopmatrix<16x8xi16, Subgroup, MatrixB>
+                                           -> !spirv.coopmatrix<8x8xi32, Subgroup, MatrixAcc>
+
     spirv.Return
   }
 



More information about the Mlir-commits mailing list