[Mlir-commits] [mlir] [mlir][spirv] Improve coop matrix attribute handling (PR #66020)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Mon Sep 11 15:16:39 PDT 2023
llvmbot wrote:
@llvm/pr-subscribers-mlir-spirv
<details>
<summary>Changes</summary>
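- Fix the values of the Cooperative Matrix Operands bit enum cases. Each case was declared one bit position too high (e.g. `ASigned` was 0x2 instead of 0x1), so the bit indices now start at 0.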
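- Add verification of the Memory Operand attribute on `spirv.KHR.CooperativeMatrixLoad` and `spirv.KHR.CooperativeMatrixStore`. The 'Aligned' enumerant is now rejected as not yet supported: it must be followed by an alignment literal, which op parsing and (de)serialization do not handle yet.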
The target test passes validation with `spirv-val`.
--
Full diff: https://github.com/llvm/llvm-project/pull/66020.diff
4 Files Affected:
- (modified) mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td (+5-5)
- (modified) mlir/lib/Dialect/SPIRV/IR/CooperativeMatrixOps.cpp (+18-7)
- (modified) mlir/test/Dialect/SPIRV/IR/cooperative-matrix-ops.mlir (+28)
- (modified) mlir/test/Target/SPIRV/khr-cooperative-matrix-ops.mlir (+19-2)
<pre>
diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
index 2ce3ad875fa45d1..1013cbc8ca562b7 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
@@ -4077,11 +4077,11 @@ def SPIRV_KHR_CooperativeMatrixLayoutAttr :
// Cooperative Matrix Operands for the SPV_KHR_cooperative_matrix extension.
def SPIRV_KHR_CMO_None : I32BitEnumAttrCaseNone<"None">;
-def SPIRV_KHR_CMO_MatrixA_Signed : I32BitEnumAttrCaseBit<"ASigned", 1>;
-def SPIRV_KHR_CMO_MatrixB_Signed : I32BitEnumAttrCaseBit<"BSigned", 2>;
-def SPIRV_KHR_CMO_MatrixC_Signed : I32BitEnumAttrCaseBit<"CSigned", 4>;
-def SPIRV_KHR_CMO_Result_Signed : I32BitEnumAttrCaseBit<"ResultSigned", 8>;
-def SPIRV_KHR_CMO_AccSat : I32BitEnumAttrCaseBit<"AccSat", 16>;
+def SPIRV_KHR_CMO_MatrixA_Signed : I32BitEnumAttrCaseBit<"ASigned", 0>;
+def SPIRV_KHR_CMO_MatrixB_Signed : I32BitEnumAttrCaseBit<"BSigned", 1>;
+def SPIRV_KHR_CMO_MatrixC_Signed : I32BitEnumAttrCaseBit<"CSigned", 2>;
+def SPIRV_KHR_CMO_Result_Signed : I32BitEnumAttrCaseBit<"ResultSigned", 3>;
+def SPIRV_KHR_CMO_AccSat : I32BitEnumAttrCaseBit<"AccSat", 4>;
def SPIRV_KHR_CooperativeMatrixOperandsAttr :
SPIRV_BitEnumAttr<"CooperativeMatrixOperandsKHR",
diff --git a/mlir/lib/Dialect/SPIRV/IR/CooperativeMatrixOps.cpp b/mlir/lib/Dialect/SPIRV/IR/CooperativeMatrixOps.cpp
index 77dbf130c777857..d43f7a1823e912b 100644
--- a/mlir/lib/Dialect/SPIRV/IR/CooperativeMatrixOps.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/CooperativeMatrixOps.cpp
@@ -11,6 +11,7 @@
//===----------------------------------------------------------------------===//
#include "SPIRVParsingUtils.h"
+#include "mlir/Dialect/SPIRV/IR/SPIRVAttributes.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVEnums.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
#include "llvm/ADT/STLExtras.h"
@@ -23,16 +24,26 @@ namespace mlir::spirv {
// spirv.KHR.CooperativeMatrixLoad
//===----------------------------------------------------------------------===//
-static LogicalResult verifyPointerAndCoopMatrixType(Operation *op, Type pointer,
- Type coopMatrix) {
+static LogicalResult
+verifyCoopMatrixAccess(Operation *op, Type pointer, Type coopMatrix,
+ spirv::MemoryAccessAttr memoryOperand) {
auto pointerType = cast<PointerType>(pointer);
Type pointeeType = pointerType.getPointeeType();
if (!isa<ScalarType, VectorType>(pointeeType)) {
- return op->emitError(
+ return op->emitOpError(
"Pointer must point to a scalar or vector type but provided ")
<< pointeeType;
}
+ // The 'Aligned' memory operand requires an alignment literal to follow, which
+ // needs to be implemented on the level of op parsing and (de-)serialization.
+ // TODO: Consider adding support for this attribute value.
+ if (memoryOperand &&
+ spirv::bitEnumContainsAll(memoryOperand.getValue(),
+ spirv::MemoryAccess::Aligned)) {
+ return op->emitOpError("has unhandled memory operand 'Aligned'");
+ }
+
// TODO: Verify the memory object behind the pointer:
// > If the Shader capability was declared, Pointer must point into an array
// > and any ArrayStride decoration on Pointer is ignored.
@@ -41,8 +52,8 @@ static LogicalResult verifyPointerAndCoopMatrixType(Operation *op, Type pointer,
}
LogicalResult KHRCooperativeMatrixLoadOp::verify() {
- return verifyPointerAndCoopMatrixType(*this, getPointer().getType(),
- getResult().getType());
+ return verifyCoopMatrixAccess(*this, getPointer().getType(),
+ getResult().getType(), getMemoryOperandAttr());
}
//===----------------------------------------------------------------------===//
@@ -50,8 +61,8 @@ LogicalResult KHRCooperativeMatrixLoadOp::verify() {
//===----------------------------------------------------------------------===//
LogicalResult KHRCooperativeMatrixStoreOp::verify() {
- return verifyPointerAndCoopMatrixType(*this, getPointer().getType(),
- getObject().getType());
+ return verifyCoopMatrixAccess(*this, getPointer().getType(),
+ getObject().getType(), getMemoryOperandAttr());
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/SPIRV/IR/cooperative-matrix-ops.mlir b/mlir/test/Dialect/SPIRV/IR/cooperative-matrix-ops.mlir
index 40736367520e843..3adcd711f74a8f8 100644
--- a/mlir/test/Dialect/SPIRV/IR/cooperative-matrix-ops.mlir
+++ b/mlir/test/Dialect/SPIRV/IR/cooperative-matrix-ops.mlir
@@ -136,6 +136,24 @@ spirv.func @cooperative_matrix_load_bad_result(%ptr : !spirv.ptr<i32, StorageBuf
// -----
+spirv.func @cooperative_matrix_load_aligned(%ptr : !spirv.ptr<i32, StorageBuffer>, %stride : i32) "None" {
+ // expected-error @+1 {{op has unhandled memory operand 'Aligned'}}
+ %0 = spirv.KHR.CooperativeMatrixLoad %ptr, %stride, <ColumnMajor>, <Aligned> :
+ !spirv.ptr<i32, StorageBuffer>, i32 -> !spirv.coopmatrix<8x16xi32, Subgroup, MatrixA>
+ spirv.Return
+}
+
+// -----
+
+spirv.func @cooperative_matrix_load_aligned(%ptr : !spirv.ptr<i32, StorageBuffer>, %stride : i32) "None" {
+ // expected-error @+1 {{op has unhandled memory operand 'Aligned'}}
+ %0 = spirv.KHR.CooperativeMatrixLoad %ptr, %stride, <ColumnMajor>, <Volatile|Aligned> :
+ !spirv.ptr<i32, StorageBuffer>, i32 -> !spirv.coopmatrix<8x16xi32, Subgroup, MatrixA>
+ spirv.Return
+}
+
+// -----
+
spirv.func @cooperative_matrix_store_missing_attr(%ptr : !spirv.ptr<i32, StorageBuffer>, %stride : i32,
%m : !spirv.coopmatrix<8x16xi32, Workgroup, MatrixA>) "None" {
// expected-error @+1 {{expected ','}}
@@ -166,6 +184,16 @@ spirv.func @cooperative_matrix_store_bad_object_type(%ptr : !spirv.ptr<i32, Stor
// -----
+spirv.func @cooperative_matrix_store(%ptr : !spirv.ptr<i32, StorageBuffer>, %stride : i32,
+ %m : !spirv.coopmatrix<8x16xi32, Workgroup, MatrixA>) "None" {
+ // expected-error @+1 {{op has unhandled memory operand 'Aligned'}}
+ spirv.KHR.CooperativeMatrixStore %ptr, %m, %stride, <RowMajor>, <Aligned> :
+ !spirv.ptr<i32, StorageBuffer>, !spirv.coopmatrix<8x16xi32, Workgroup, MatrixA>, i32
+ spirv.Return
+}
+
+// -----
+
spirv.func @cooperative_matrix_muladd(%a : !spirv.coopmatrix<8x16xi32, Subgroup, MatrixA>,
%b : !spirv.coopmatrix<16x4xi32, Subgroup, MatrixB>,
%c : !spirv.coopmatrix<8x4xi32, Subgroup, MatrixAcc>) "None" {
diff --git a/mlir/test/Target/SPIRV/khr-cooperative-matrix-ops.mlir b/mlir/test/Target/SPIRV/khr-cooperative-matrix-ops.mlir
index 8546172f4f797b5..153ff4793797267 100644
--- a/mlir/test/Target/SPIRV/khr-cooperative-matrix-ops.mlir
+++ b/mlir/test/Target/SPIRV/khr-cooperative-matrix-ops.mlir
@@ -37,6 +37,10 @@ spirv.module Logical GLSL450 requires
// CHECK-SAME: : !spirv.ptr<i32, StorageBuffer>, !spirv.coopmatrix<16x8xi32, Workgroup, MatrixA>, i32
spirv.KHR.CooperativeMatrixStore %ptr, %m, %stride, <RowMajor> :
!spirv.ptr<i32, StorageBuffer>, !spirv.coopmatrix<16x8xi32, Workgroup, MatrixA>, i32
+
+ // CHECK-NEXT: spirv.KHR.CooperativeMatrixStore {{%.*}}, {{%.*}}, <RowMajor>, <Volatile|Nontemporal>
+ spirv.KHR.CooperativeMatrixStore %ptr, %m, %stride, <RowMajor>, <Volatile|Nontemporal> :
+ !spirv.ptr<i32, StorageBuffer>, !spirv.coopmatrix<16x8xi32, Workgroup, MatrixA>, i32
spirv.Return
}
@@ -62,16 +66,29 @@ spirv.module Logical GLSL450 requires
!spirv.coopmatrix<16x8xi16, Subgroup, MatrixB>
-> !spirv.coopmatrix<8x8xi32, Subgroup, MatrixAcc>
- // CHECK-NEXT: {{%.+}} = spirv.KHR.CooperativeMatrixMulAdd {{%.*}}, {{%.*}}, {{%.*}}, <BSigned> :
+ // CHECK-NEXT: {{%.+}} = spirv.KHR.CooperativeMatrixMulAdd {{%.*}}, {{%.*}}, {{%.*}}, <ASigned> :
// CHECK-SAME: !spirv.coopmatrix<8x16xi8, Subgroup, MatrixA>,
// CHECK-SAME: !spirv.coopmatrix<16x8xi16, Subgroup, MatrixB>
// CHECK-SAME: -> !spirv.coopmatrix<8x8xi32, Subgroup, MatrixAcc>
%q = spirv.KHR.CooperativeMatrixMulAdd %a, %b, %c,
+ <ASigned> : !spirv.coopmatrix<8x16xi8, Subgroup, MatrixA>,
+ !spirv.coopmatrix<16x8xi16, Subgroup, MatrixB>
+ -> !spirv.coopmatrix<8x8xi32, Subgroup, MatrixAcc>
+
+ // CHECK-NEXT: {{%.+}} = spirv.KHR.CooperativeMatrixMulAdd {{%.*}}, {{%.*}}, {{%.*}}, <BSigned> :
+ %r = spirv.KHR.CooperativeMatrixMulAdd %a, %b, %c,
<BSigned> : !spirv.coopmatrix<8x16xi8, Subgroup, MatrixA>,
!spirv.coopmatrix<16x8xi16, Subgroup, MatrixB>
-> !spirv.coopmatrix<8x8xi32, Subgroup, MatrixAcc>
- // TODO: Handle multiple matrix operands and add relevant testcases here.
+ // CHECK-NEXT: {{%.+}} = spirv.KHR.CooperativeMatrixMulAdd
+ // CHECK-SAME: <ASigned|BSigned|ResultSigned|AccSat> :
+ %s = spirv.KHR.CooperativeMatrixMulAdd %a, %b, %c,
+ <ASigned|BSigned|ResultSigned|AccSat> :
+ !spirv.coopmatrix<8x16xi8, Subgroup, MatrixA>,
+ !spirv.coopmatrix<16x8xi16, Subgroup, MatrixB>
+ -> !spirv.coopmatrix<8x8xi32, Subgroup, MatrixAcc>
+
spirv.Return
}
</pre>
</details>
https://github.com/llvm/llvm-project/pull/66020
More information about the Mlir-commits
mailing list