[Mlir-commits] [mlir] [mlir][spirv] Add support for Aligned memory operand in CoopMatrix memory operations (PR #145480)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Tue Jun 24 02:46:47 PDT 2025
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir-spirv
Author: Igor Wodiany (IgWod-IMG)
<details>
<summary>Changes</summary>
In the process of adding support for `Aligned`, I noticed that the support for `MakePointerAvailable` and `MakePointerVisible` is incomplete, as the operations neither accept a scope nor check for `NonPrivatePointer`. This PR does not address that, but a relevant issue has been created: #145485.
---
Full diff: https://github.com/llvm/llvm-project/pull/145480.diff
4 Files Affected:
- (modified) mlir/include/mlir/Dialect/SPIRV/IR/SPIRVCooperativeMatrixOps.td (+10-6)
- (modified) mlir/lib/Dialect/SPIRV/IR/CooperativeMatrixOps.cpp (+18-10)
- (modified) mlir/test/Dialect/SPIRV/IR/khr-cooperative-matrix-ops.mlir (+31-3)
- (modified) mlir/test/Target/SPIRV/khr-cooperative-matrix-ops.mlir (+14)
``````````diff
diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVCooperativeMatrixOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVCooperativeMatrixOps.td
index 46732ba19afed..fd75532ae3d70 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVCooperativeMatrixOps.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVCooperativeMatrixOps.td
@@ -112,7 +112,7 @@ def SPIRV_KHRCooperativeMatrixLoadOp : SPIRV_KhrVendorOp<"CooperativeMatrixLoad"
}];
let assemblyFormat = [{
- $pointer `,` $stride `,` $matrix_layout ( `,` $memory_operand^ )? attr-dict `:`
+ $pointer `,` $stride `,` $matrix_layout ( `,` $memory_operand^ )? ( `,` $alignment^ )? attr-dict `:`
type(operands) `->` type($result)
}];
@@ -123,11 +123,13 @@ def SPIRV_KHRCooperativeMatrixLoadOp : SPIRV_KhrVendorOp<"CooperativeMatrixLoad"
Capability<[SPIRV_C_CooperativeMatrixKHR]>
];
+ // TODO: Add scope operand for MakePointer*. See #145485.
let arguments = (ins
SPIRV_AnyPtr:$pointer,
SPIRV_KHR_CooperativeMatrixLayoutAttr:$matrix_layout,
SPIRV_Integer:$stride,
- OptionalAttr<SPIRV_MemoryAccessAttr>:$memory_operand
+ OptionalAttr<SPIRV_MemoryAccessAttr>:$memory_operand,
+ OptionalAttr<I32Attr>:$alignment
);
let results = (outs
@@ -139,7 +141,7 @@ def SPIRV_KHRCooperativeMatrixLoadOp : SPIRV_KhrVendorOp<"CooperativeMatrixLoad"
"spirv::ConstantOp":$stride,
"spirv::CooperativeMatrixLayoutKHR":$layout), [{
build($_builder, $_state, result, pointer, layout, stride,
- spirv::MemoryAccessAttr{});
+ spirv::MemoryAccessAttr{}, IntegerAttr{});
}]>
];
}
@@ -194,7 +196,7 @@ def SPIRV_KHRCooperativeMatrixStoreOp : SPIRV_KhrVendorOp<"CooperativeMatrixStor
}];
let assemblyFormat = [{
- $pointer `,` $object `,` $stride `,` $matrix_layout ( `,` $memory_operand^ )? attr-dict `:`
+ $pointer `,` $object `,` $stride `,` $matrix_layout ( `,` $memory_operand^ )? ( `,` $alignment^ )? attr-dict `:`
type(operands)
}];
@@ -205,12 +207,14 @@ def SPIRV_KHRCooperativeMatrixStoreOp : SPIRV_KhrVendorOp<"CooperativeMatrixStor
Capability<[SPIRV_C_CooperativeMatrixKHR]>
];
+ // TODO: Add scope operand for MakePointer*. See #145485.
let arguments = (ins
SPIRV_AnyPtr:$pointer,
SPIRV_AnyCooperativeMatrix:$object,
SPIRV_KHR_CooperativeMatrixLayoutAttr:$matrix_layout,
SPIRV_Integer:$stride,
- OptionalAttr<SPIRV_MemoryAccessAttr>:$memory_operand
+ OptionalAttr<SPIRV_MemoryAccessAttr>:$memory_operand,
+ OptionalAttr<I32Attr>:$alignment
);
let results = (outs);
@@ -220,7 +224,7 @@ def SPIRV_KHRCooperativeMatrixStoreOp : SPIRV_KhrVendorOp<"CooperativeMatrixStor
"spirv::ConstantOp":$stride,
"spirv::CooperativeMatrixLayoutKHR":$layout), [{
build($_builder, $_state, pointer, object, layout, stride,
- spirv::MemoryAccessAttr{});
+ spirv::MemoryAccessAttr{}, IntegerAttr{});
}]>
];
}
diff --git a/mlir/lib/Dialect/SPIRV/IR/CooperativeMatrixOps.cpp b/mlir/lib/Dialect/SPIRV/IR/CooperativeMatrixOps.cpp
index 2ff3efdc96a7f..fa20cc179f892 100644
--- a/mlir/lib/Dialect/SPIRV/IR/CooperativeMatrixOps.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/CooperativeMatrixOps.cpp
@@ -23,7 +23,8 @@ namespace mlir::spirv {
static LogicalResult
verifyCoopMatrixAccess(Operation *op, Type pointer, Type coopMatrix,
- spirv::MemoryAccessAttr memoryOperand) {
+ spirv::MemoryAccessAttr memoryOperand,
+ IntegerAttr alignment) {
auto pointerType = cast<PointerType>(pointer);
Type pointeeType = pointerType.getPointeeType();
if (!isa<ScalarType, VectorType>(pointeeType)) {
@@ -49,13 +50,18 @@ verifyCoopMatrixAccess(Operation *op, Type pointer, Type coopMatrix,
"not compatible with memory operand 'MakePointerVisible'");
}
- // The 'Aligned' memory operand requires an alignment literal to follow,
- // which needs to be implemented on the level of op parsing and
- // (de-)serialization.
- // TODO: Consider adding support for this attribute value.
- if (spirv::bitEnumContainsAll(memoryOperand.getValue(),
- spirv::MemoryAccess::Aligned)) {
- return op->emitOpError("has unhandled memory operand 'Aligned'");
+ // TODO: Need to check that NonPrivatePointer is set for MakePointer*. See
+ // #145485.
+
+ if (spirv::bitEnumContainsAll(operandSet, spirv::MemoryAccess::Aligned) &&
+ !alignment) {
+ return op->emitOpError("missing value for the 'Aligned' memory operand");
+ }
+
+ if (!spirv::bitEnumContainsAll(operandSet, spirv::MemoryAccess::Aligned) &&
+ alignment) {
+ return op->emitOpError(
+ "found alignment attribute for non-'Aligned' memory operand");
}
}
@@ -72,7 +78,8 @@ verifyCoopMatrixAccess(Operation *op, Type pointer, Type coopMatrix,
LogicalResult KHRCooperativeMatrixLoadOp::verify() {
return verifyCoopMatrixAccess(*this, getPointer().getType(),
- getResult().getType(), getMemoryOperandAttr());
+ getResult().getType(), getMemoryOperandAttr(),
+ getAlignmentAttr());
}
//===----------------------------------------------------------------------===//
@@ -81,7 +88,8 @@ LogicalResult KHRCooperativeMatrixLoadOp::verify() {
LogicalResult KHRCooperativeMatrixStoreOp::verify() {
return verifyCoopMatrixAccess(*this, getPointer().getType(),
- getObject().getType(), getMemoryOperandAttr());
+ getObject().getType(), getMemoryOperandAttr(),
+ getAlignmentAttr());
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/SPIRV/IR/khr-cooperative-matrix-ops.mlir b/mlir/test/Dialect/SPIRV/IR/khr-cooperative-matrix-ops.mlir
index 8733ff93768ab..56d477cca97b7 100644
--- a/mlir/test/Dialect/SPIRV/IR/khr-cooperative-matrix-ops.mlir
+++ b/mlir/test/Dialect/SPIRV/IR/khr-cooperative-matrix-ops.mlir
@@ -58,6 +58,15 @@ spirv.func @cooperative_matrix_load_stride_i16(%ptr : !spirv.ptr<i32, StorageBuf
spirv.Return
}
+// CHECK-LABEL: @cooperative_matrix_load_aligned
+spirv.func @cooperative_matrix_load_aligned(%ptr : !spirv.ptr<i32, StorageBuffer>, %stride : i32) "None" {
+ // CHECK: {{%.*}} = spirv.KHR.CooperativeMatrixLoad {{%.*}}, {{%.*}}, <RowMajor>, <Aligned>, 16 :
+ // CHECK-SAME: !spirv.ptr<i32, StorageBuffer>, i32 -> !spirv.coopmatrix<16x8xi32, Workgroup, MatrixA>
+ %0 = spirv.KHR.CooperativeMatrixLoad %ptr, %stride, <RowMajor>, <Aligned>, 16 :
+ !spirv.ptr<i32, StorageBuffer>, i32 -> !spirv.coopmatrix<16x8xi32, Workgroup, MatrixA>
+ spirv.Return
+}
+
// CHECK-LABEL: @cooperative_matrix_store
spirv.func @cooperative_matrix_store(%ptr : !spirv.ptr<i32, StorageBuffer>, %stride : i32,
%m : !spirv.coopmatrix<8x16xi32, Workgroup, MatrixA>) "None" {
@@ -90,6 +99,16 @@ spirv.func @cooperative_matrix_store_stride_i16(%ptr : !spirv.ptr<i32, StorageBu
spirv.Return
}
+// CHECK-LABEL: @cooperative_matrix_store_aligned
+spirv.func @cooperative_matrix_store_aligned(%ptr : !spirv.ptr<i32, StorageBuffer>, %stride : i32,
+ %m : !spirv.coopmatrix<8x16xi32, Workgroup, MatrixA>) "None" {
+ // CHECK: spirv.KHR.CooperativeMatrixStore {{%.*}}, {{%.*}}, {{%.*}}, <RowMajor>, <Aligned>, 16 :
+ // CHECK-SAME: !spirv.ptr<i32, StorageBuffer>, !spirv.coopmatrix<8x16xi32, Workgroup, MatrixA>, i32
+ spirv.KHR.CooperativeMatrixStore %ptr, %m, %stride, <RowMajor>, <Aligned>, 16 :
+ !spirv.ptr<i32, StorageBuffer>, !spirv.coopmatrix<8x16xi32, Workgroup, MatrixA>, i32
+ spirv.Return
+}
+
// -----
spirv.func @cooperative_matrix_load_bad_ptr(%ptr : !spirv.ptr<!spirv.struct<(f32 [0])>, StorageBuffer>, %stride : i32) "None" {
@@ -120,7 +139,7 @@ spirv.func @cooperative_matrix_load_bad_operad(%ptr : !spirv.ptr<i32, StorageBuf
// -----
spirv.func @cooperative_matrix_load_aligned(%ptr : !spirv.ptr<i32, StorageBuffer>, %stride : i32) "None" {
- // expected-error @+1 {{op has unhandled memory operand 'Aligned'}}
+ // expected-error @+1 {{missing value for the 'Aligned' memory operand}}
%0 = spirv.KHR.CooperativeMatrixLoad %ptr, %stride, <ColumnMajor>, <Aligned> :
!spirv.ptr<i32, StorageBuffer>, i32 -> !spirv.coopmatrix<8x16xi32, Subgroup, MatrixA>
spirv.Return
@@ -129,7 +148,7 @@ spirv.func @cooperative_matrix_load_aligned(%ptr : !spirv.ptr<i32, StorageBuffer
// -----
spirv.func @cooperative_matrix_load_aligned(%ptr : !spirv.ptr<i32, StorageBuffer>, %stride : i32) "None" {
- // expected-error @+1 {{op has unhandled memory operand 'Aligned'}}
+ // expected-error @+1 {{missing value for the 'Aligned' memory operand}}
%0 = spirv.KHR.CooperativeMatrixLoad %ptr, %stride, <ColumnMajor>, <Volatile|Aligned> :
!spirv.ptr<i32, StorageBuffer>, i32 -> !spirv.coopmatrix<8x16xi32, Subgroup, MatrixA>
spirv.Return
@@ -179,7 +198,7 @@ spirv.func @cooperative_matrix_store_bad_operand(%ptr : !spirv.ptr<i32, StorageB
spirv.func @cooperative_matrix_store(%ptr : !spirv.ptr<i32, StorageBuffer>, %stride : i32,
%m : !spirv.coopmatrix<8x16xi32, Workgroup, MatrixA>) "None" {
- // expected-error @+1 {{op has unhandled memory operand 'Aligned'}}
+ // expected-error @+1 {{missing value for the 'Aligned' memory operand}}
spirv.KHR.CooperativeMatrixStore %ptr, %m, %stride, <RowMajor>, <Aligned> :
!spirv.ptr<i32, StorageBuffer>, !spirv.coopmatrix<8x16xi32, Workgroup, MatrixA>, i32
spirv.Return
@@ -187,6 +206,15 @@ spirv.func @cooperative_matrix_store(%ptr : !spirv.ptr<i32, StorageBuffer>, %str
// -----
+spirv.func @cooperative_matrix_store_bad_operand_arg(%ptr : !spirv.ptr<i32, StorageBuffer>, %stride : i32) "None" {
+ // expected-error @+1 {{found alignment attribute for non-'Aligned' memory operand}}
+ %0 = spirv.KHR.CooperativeMatrixLoad %ptr, %stride, <RowMajor>, <MakePointerVisible>, 16 :
+ !spirv.ptr<i32, StorageBuffer>, i32 -> !spirv.coopmatrix<16x8xi32, Workgroup, MatrixA>
+ spirv.Return
+}
+
+// -----
+
spirv.func @cooperative_matrix_muladd(%a : !spirv.coopmatrix<8x16xi32, Subgroup, MatrixA>,
%b : !spirv.coopmatrix<16x4xi32, Subgroup, MatrixB>,
%c : !spirv.coopmatrix<8x4xi32, Subgroup, MatrixAcc>) "None" {
diff --git a/mlir/test/Target/SPIRV/khr-cooperative-matrix-ops.mlir b/mlir/test/Target/SPIRV/khr-cooperative-matrix-ops.mlir
index 153ff47937972..77949908e8883 100644
--- a/mlir/test/Target/SPIRV/khr-cooperative-matrix-ops.mlir
+++ b/mlir/test/Target/SPIRV/khr-cooperative-matrix-ops.mlir
@@ -30,6 +30,15 @@ spirv.module Logical GLSL450 requires
spirv.Return
}
+ // CHECK-LABEL: @cooperative_matrix_load_3
+ spirv.func @cooperative_matrix_load_3(%ptr : !spirv.ptr<i32, StorageBuffer>, %stride : i32) "None" {
+ // CHECK: {{%.+}} = spirv.KHR.CooperativeMatrixLoad {{%.*}}, {{%.*}}, <RowMajor>, <Aligned>, 16
+ // CHECK-SAME: : !spirv.ptr<i32, StorageBuffer>, i32 -> !spirv.coopmatrix<16x8xi32, Workgroup, MatrixA>
+ %0 = spirv.KHR.CooperativeMatrixLoad %ptr, %stride, <RowMajor>, <Aligned>, 16 :
+ !spirv.ptr<i32, StorageBuffer>, i32 -> !spirv.coopmatrix<16x8xi32, Workgroup, MatrixA>
+ spirv.Return
+ }
+
// CHECK-LABEL: @cooperative_matrix_store_1
spirv.func @cooperative_matrix_store_1(%ptr : !spirv.ptr<i32, StorageBuffer>, %stride : i32,
%m : !spirv.coopmatrix<16x8xi32, Workgroup, MatrixA>) "None" {
@@ -38,6 +47,11 @@ spirv.module Logical GLSL450 requires
spirv.KHR.CooperativeMatrixStore %ptr, %m, %stride, <RowMajor> :
!spirv.ptr<i32, StorageBuffer>, !spirv.coopmatrix<16x8xi32, Workgroup, MatrixA>, i32
+ // CHECK: spirv.KHR.CooperativeMatrixStore {{%.*}}, {{%.*}}, <RowMajor>, <Aligned>, 16
+ // CHECK-SAME: : !spirv.ptr<i32, StorageBuffer>, !spirv.coopmatrix<16x8xi32, Workgroup, MatrixA>, i32
+ spirv.KHR.CooperativeMatrixStore %ptr, %m, %stride, <RowMajor>, <Aligned>, 16 :
+ !spirv.ptr<i32, StorageBuffer>, !spirv.coopmatrix<16x8xi32, Workgroup, MatrixA>, i32
+
// CHECK-NEXT: spirv.KHR.CooperativeMatrixStore {{%.*}}, {{%.*}}, <RowMajor>, <Volatile|Nontemporal>
spirv.KHR.CooperativeMatrixStore %ptr, %m, %stride, <RowMajor>, <Volatile|Nontemporal> :
!spirv.ptr<i32, StorageBuffer>, !spirv.coopmatrix<16x8xi32, Workgroup, MatrixA>, i32
``````````
</details>
https://github.com/llvm/llvm-project/pull/145480
More information about the Mlir-commits
mailing list