[Mlir-commits] [mlir] [mlir][spirv] Add support for Aligned memory operand in CoopMatrix memory operations (PR #145480)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Tue Jun 24 02:46:47 PDT 2025


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir-spirv

Author: Igor Wodiany (IgWod-IMG)

<details>
<summary>Changes</summary>

In the process of adding support for `Aligned`, I noticed that the support for `MakePointerAvailable` and `MakePointerVisible` is incomplete, as the operations neither accept a scope nor check for `NonPrivatePointer`. This PR does not address that, but a relevant issue has been created: #<!-- -->145485.

---
Full diff: https://github.com/llvm/llvm-project/pull/145480.diff


4 Files Affected:

- (modified) mlir/include/mlir/Dialect/SPIRV/IR/SPIRVCooperativeMatrixOps.td (+10-6) 
- (modified) mlir/lib/Dialect/SPIRV/IR/CooperativeMatrixOps.cpp (+18-10) 
- (modified) mlir/test/Dialect/SPIRV/IR/khr-cooperative-matrix-ops.mlir (+31-3) 
- (modified) mlir/test/Target/SPIRV/khr-cooperative-matrix-ops.mlir (+14) 


``````````diff
diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVCooperativeMatrixOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVCooperativeMatrixOps.td
index 46732ba19afed..fd75532ae3d70 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVCooperativeMatrixOps.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVCooperativeMatrixOps.td
@@ -112,7 +112,7 @@ def SPIRV_KHRCooperativeMatrixLoadOp : SPIRV_KhrVendorOp<"CooperativeMatrixLoad"
   }];
 
   let assemblyFormat = [{
-    $pointer `,` $stride `,` $matrix_layout ( `,` $memory_operand^ )? attr-dict `:`
+    $pointer `,` $stride `,` $matrix_layout ( `,` $memory_operand^ )? ( `,` $alignment^ )? attr-dict `:`
       type(operands) `->` type($result)
   }];
 
@@ -123,11 +123,13 @@ def SPIRV_KHRCooperativeMatrixLoadOp : SPIRV_KhrVendorOp<"CooperativeMatrixLoad"
     Capability<[SPIRV_C_CooperativeMatrixKHR]>
   ];
 
+  // TODO: Add scope operand for MakePointer*. See #145485.
   let arguments = (ins
     SPIRV_AnyPtr:$pointer,
     SPIRV_KHR_CooperativeMatrixLayoutAttr:$matrix_layout,
     SPIRV_Integer:$stride,
-    OptionalAttr<SPIRV_MemoryAccessAttr>:$memory_operand
+    OptionalAttr<SPIRV_MemoryAccessAttr>:$memory_operand,
+    OptionalAttr<I32Attr>:$alignment
   );
 
   let results = (outs
@@ -139,7 +141,7 @@ def SPIRV_KHRCooperativeMatrixLoadOp : SPIRV_KhrVendorOp<"CooperativeMatrixLoad"
                    "spirv::ConstantOp":$stride,
                    "spirv::CooperativeMatrixLayoutKHR":$layout), [{
       build($_builder, $_state, result, pointer, layout, stride,
-            spirv::MemoryAccessAttr{});
+            spirv::MemoryAccessAttr{}, IntegerAttr{});
     }]>
   ];
 }
@@ -194,7 +196,7 @@ def SPIRV_KHRCooperativeMatrixStoreOp : SPIRV_KhrVendorOp<"CooperativeMatrixStor
   }];
 
   let assemblyFormat = [{
-    $pointer `,` $object `,` $stride `,` $matrix_layout ( `,` $memory_operand^ )? attr-dict `:`
+    $pointer `,` $object `,` $stride `,` $matrix_layout ( `,` $memory_operand^ )? ( `,` $alignment^ )? attr-dict `:`
       type(operands)
   }];
 
@@ -205,12 +207,14 @@ def SPIRV_KHRCooperativeMatrixStoreOp : SPIRV_KhrVendorOp<"CooperativeMatrixStor
     Capability<[SPIRV_C_CooperativeMatrixKHR]>
   ];
 
+  // TODO: Add scope operand for MakePointer*. See #145485.
   let arguments = (ins
     SPIRV_AnyPtr:$pointer,
     SPIRV_AnyCooperativeMatrix:$object,
     SPIRV_KHR_CooperativeMatrixLayoutAttr:$matrix_layout,
     SPIRV_Integer:$stride,
-    OptionalAttr<SPIRV_MemoryAccessAttr>:$memory_operand
+    OptionalAttr<SPIRV_MemoryAccessAttr>:$memory_operand,
+    OptionalAttr<I32Attr>:$alignment
   );
 
   let results = (outs);
@@ -220,7 +224,7 @@ def SPIRV_KHRCooperativeMatrixStoreOp : SPIRV_KhrVendorOp<"CooperativeMatrixStor
                    "spirv::ConstantOp":$stride,
                    "spirv::CooperativeMatrixLayoutKHR":$layout), [{
       build($_builder, $_state, pointer, object, layout, stride,
-            spirv::MemoryAccessAttr{});
+            spirv::MemoryAccessAttr{}, IntegerAttr{});
     }]>
   ];
 }
diff --git a/mlir/lib/Dialect/SPIRV/IR/CooperativeMatrixOps.cpp b/mlir/lib/Dialect/SPIRV/IR/CooperativeMatrixOps.cpp
index 2ff3efdc96a7f..fa20cc179f892 100644
--- a/mlir/lib/Dialect/SPIRV/IR/CooperativeMatrixOps.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/CooperativeMatrixOps.cpp
@@ -23,7 +23,8 @@ namespace mlir::spirv {
 
 static LogicalResult
 verifyCoopMatrixAccess(Operation *op, Type pointer, Type coopMatrix,
-                       spirv::MemoryAccessAttr memoryOperand) {
+                       spirv::MemoryAccessAttr memoryOperand,
+                       IntegerAttr alignment) {
   auto pointerType = cast<PointerType>(pointer);
   Type pointeeType = pointerType.getPointeeType();
   if (!isa<ScalarType, VectorType>(pointeeType)) {
@@ -49,13 +50,18 @@ verifyCoopMatrixAccess(Operation *op, Type pointer, Type coopMatrix,
           "not compatible with memory operand 'MakePointerVisible'");
     }
 
-    // The 'Aligned' memory operand requires an alignment literal to follow,
-    // which needs to be implemented on the level of op parsing and
-    // (de-)serialization.
-    // TODO: Consider adding support for this attribute value.
-    if (spirv::bitEnumContainsAll(memoryOperand.getValue(),
-                                  spirv::MemoryAccess::Aligned)) {
-      return op->emitOpError("has unhandled memory operand 'Aligned'");
+    // TODO: Need to check that NonPrivatePointer is set for MakePointer*. See
+    // #145485.
+
+    if (spirv::bitEnumContainsAll(operandSet, spirv::MemoryAccess::Aligned) &&
+        !alignment) {
+      return op->emitOpError("missing value for the 'Aligned' memory operand");
+    }
+
+    if (!spirv::bitEnumContainsAll(operandSet, spirv::MemoryAccess::Aligned) &&
+        alignment) {
+      return op->emitOpError(
+          "found alignment attribute for non-'Aligned' memory operand");
     }
   }
 
@@ -72,7 +78,8 @@ verifyCoopMatrixAccess(Operation *op, Type pointer, Type coopMatrix,
 
 LogicalResult KHRCooperativeMatrixLoadOp::verify() {
   return verifyCoopMatrixAccess(*this, getPointer().getType(),
-                                getResult().getType(), getMemoryOperandAttr());
+                                getResult().getType(), getMemoryOperandAttr(),
+                                getAlignmentAttr());
 }
 
 //===----------------------------------------------------------------------===//
@@ -81,7 +88,8 @@ LogicalResult KHRCooperativeMatrixLoadOp::verify() {
 
 LogicalResult KHRCooperativeMatrixStoreOp::verify() {
   return verifyCoopMatrixAccess(*this, getPointer().getType(),
-                                getObject().getType(), getMemoryOperandAttr());
+                                getObject().getType(), getMemoryOperandAttr(),
+                                getAlignmentAttr());
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/SPIRV/IR/khr-cooperative-matrix-ops.mlir b/mlir/test/Dialect/SPIRV/IR/khr-cooperative-matrix-ops.mlir
index 8733ff93768ab..56d477cca97b7 100644
--- a/mlir/test/Dialect/SPIRV/IR/khr-cooperative-matrix-ops.mlir
+++ b/mlir/test/Dialect/SPIRV/IR/khr-cooperative-matrix-ops.mlir
@@ -58,6 +58,15 @@ spirv.func @cooperative_matrix_load_stride_i16(%ptr : !spirv.ptr<i32, StorageBuf
   spirv.Return
 }
 
+// CHECK-LABEL: @cooperative_matrix_load_aligned
+spirv.func @cooperative_matrix_load_aligned(%ptr : !spirv.ptr<i32, StorageBuffer>, %stride : i32) "None" {
+  // CHECK:      {{%.*}} = spirv.KHR.CooperativeMatrixLoad {{%.*}}, {{%.*}}, <RowMajor>, <Aligned>, 16 :
+  // CHECK-SAME:   !spirv.ptr<i32, StorageBuffer>, i32 -> !spirv.coopmatrix<16x8xi32, Workgroup, MatrixA>
+  %0 = spirv.KHR.CooperativeMatrixLoad %ptr, %stride, <RowMajor>, <Aligned>, 16 :
+    !spirv.ptr<i32, StorageBuffer>, i32 -> !spirv.coopmatrix<16x8xi32, Workgroup, MatrixA>
+  spirv.Return
+}
+
 // CHECK-LABEL: @cooperative_matrix_store
 spirv.func @cooperative_matrix_store(%ptr : !spirv.ptr<i32, StorageBuffer>, %stride : i32,
                                      %m : !spirv.coopmatrix<8x16xi32, Workgroup, MatrixA>) "None" {
@@ -90,6 +99,16 @@ spirv.func @cooperative_matrix_store_stride_i16(%ptr : !spirv.ptr<i32, StorageBu
   spirv.Return
 }
 
+// CHECK-LABEL: @cooperative_matrix_store_aligned
+spirv.func @cooperative_matrix_store_aligned(%ptr : !spirv.ptr<i32, StorageBuffer>, %stride : i32,
+                                     %m : !spirv.coopmatrix<8x16xi32, Workgroup, MatrixA>) "None" {
+  // CHECK:      spirv.KHR.CooperativeMatrixStore {{%.*}}, {{%.*}}, {{%.*}}, <RowMajor>, <Aligned>, 16 :
+  // CHECK-SAME:   !spirv.ptr<i32, StorageBuffer>, !spirv.coopmatrix<8x16xi32, Workgroup, MatrixA>, i32
+  spirv.KHR.CooperativeMatrixStore %ptr, %m, %stride, <RowMajor>, <Aligned>, 16 :
+    !spirv.ptr<i32, StorageBuffer>, !spirv.coopmatrix<8x16xi32, Workgroup, MatrixA>, i32
+  spirv.Return
+}
+
 // -----
 
 spirv.func @cooperative_matrix_load_bad_ptr(%ptr : !spirv.ptr<!spirv.struct<(f32 [0])>, StorageBuffer>, %stride : i32) "None" {
@@ -120,7 +139,7 @@ spirv.func @cooperative_matrix_load_bad_operad(%ptr : !spirv.ptr<i32, StorageBuf
 // -----
 
 spirv.func @cooperative_matrix_load_aligned(%ptr : !spirv.ptr<i32, StorageBuffer>, %stride : i32) "None" {
-  // expected-error @+1 {{op has unhandled memory operand 'Aligned'}}
+  // expected-error @+1 {{missing value for the 'Aligned' memory operand}}
   %0 = spirv.KHR.CooperativeMatrixLoad %ptr, %stride, <ColumnMajor>, <Aligned> :
     !spirv.ptr<i32, StorageBuffer>, i32 -> !spirv.coopmatrix<8x16xi32, Subgroup, MatrixA>
   spirv.Return
@@ -129,7 +148,7 @@ spirv.func @cooperative_matrix_load_aligned(%ptr : !spirv.ptr<i32, StorageBuffer
 // -----
 
 spirv.func @cooperative_matrix_load_aligned(%ptr : !spirv.ptr<i32, StorageBuffer>, %stride : i32) "None" {
-  // expected-error @+1 {{op has unhandled memory operand 'Aligned'}}
+  // expected-error @+1 {{missing value for the 'Aligned' memory operand}}
   %0 = spirv.KHR.CooperativeMatrixLoad %ptr, %stride, <ColumnMajor>, <Volatile|Aligned> :
     !spirv.ptr<i32, StorageBuffer>, i32 -> !spirv.coopmatrix<8x16xi32, Subgroup, MatrixA>
   spirv.Return
@@ -179,7 +198,7 @@ spirv.func @cooperative_matrix_store_bad_operand(%ptr : !spirv.ptr<i32, StorageB
 
 spirv.func @cooperative_matrix_store(%ptr : !spirv.ptr<i32, StorageBuffer>, %stride : i32,
                                      %m : !spirv.coopmatrix<8x16xi32, Workgroup, MatrixA>) "None" {
-  // expected-error @+1 {{op has unhandled memory operand 'Aligned'}}
+  // expected-error @+1 {{missing value for the 'Aligned' memory operand}}
   spirv.KHR.CooperativeMatrixStore %ptr, %m, %stride, <RowMajor>, <Aligned> :
     !spirv.ptr<i32, StorageBuffer>, !spirv.coopmatrix<8x16xi32, Workgroup, MatrixA>, i32
   spirv.Return
@@ -187,6 +206,15 @@ spirv.func @cooperative_matrix_store(%ptr : !spirv.ptr<i32, StorageBuffer>, %str
 
 // -----
 
+spirv.func @cooperative_matrix_store_bad_operand_arg(%ptr : !spirv.ptr<i32, StorageBuffer>, %stride : i32) "None" {
+  // expected-error @+1 {{found alignment attribute for non-'Aligned' memory operand}}
+  %0 = spirv.KHR.CooperativeMatrixLoad %ptr, %stride, <RowMajor>, <MakePointerVisible>, 16 :
+    !spirv.ptr<i32, StorageBuffer>, i32 -> !spirv.coopmatrix<16x8xi32, Workgroup, MatrixA>
+  spirv.Return
+}
+
+// -----
+
 spirv.func @cooperative_matrix_muladd(%a : !spirv.coopmatrix<8x16xi32, Subgroup, MatrixA>,
                                       %b : !spirv.coopmatrix<16x4xi32, Subgroup, MatrixB>,
                                       %c : !spirv.coopmatrix<8x4xi32, Subgroup, MatrixAcc>) "None" {
diff --git a/mlir/test/Target/SPIRV/khr-cooperative-matrix-ops.mlir b/mlir/test/Target/SPIRV/khr-cooperative-matrix-ops.mlir
index 153ff47937972..77949908e8883 100644
--- a/mlir/test/Target/SPIRV/khr-cooperative-matrix-ops.mlir
+++ b/mlir/test/Target/SPIRV/khr-cooperative-matrix-ops.mlir
@@ -30,6 +30,15 @@ spirv.module Logical GLSL450 requires
     spirv.Return
   }
 
+  // CHECK-LABEL: @cooperative_matrix_load_3
+  spirv.func @cooperative_matrix_load_3(%ptr : !spirv.ptr<i32, StorageBuffer>, %stride : i32) "None" {
+    // CHECK:      {{%.+}} = spirv.KHR.CooperativeMatrixLoad {{%.*}}, {{%.*}}, <RowMajor>, <Aligned>, 16
+    // CHECK-SAME:   : !spirv.ptr<i32, StorageBuffer>, i32 -> !spirv.coopmatrix<16x8xi32, Workgroup, MatrixA>
+    %0 = spirv.KHR.CooperativeMatrixLoad %ptr, %stride, <RowMajor>, <Aligned>, 16 :
+      !spirv.ptr<i32, StorageBuffer>, i32 -> !spirv.coopmatrix<16x8xi32, Workgroup, MatrixA>
+    spirv.Return
+  }
+
   // CHECK-LABEL: @cooperative_matrix_store_1
   spirv.func @cooperative_matrix_store_1(%ptr : !spirv.ptr<i32, StorageBuffer>, %stride : i32,
                                          %m : !spirv.coopmatrix<16x8xi32, Workgroup, MatrixA>) "None" {
@@ -38,6 +47,11 @@ spirv.module Logical GLSL450 requires
     spirv.KHR.CooperativeMatrixStore %ptr, %m, %stride, <RowMajor> :
       !spirv.ptr<i32, StorageBuffer>, !spirv.coopmatrix<16x8xi32, Workgroup, MatrixA>, i32
 
+    // CHECK:      spirv.KHR.CooperativeMatrixStore {{%.*}}, {{%.*}}, <RowMajor>, <Aligned>, 16
+    // CHECK-SAME:   : !spirv.ptr<i32, StorageBuffer>, !spirv.coopmatrix<16x8xi32, Workgroup, MatrixA>, i32
+    spirv.KHR.CooperativeMatrixStore %ptr, %m, %stride, <RowMajor>, <Aligned>, 16 :
+      !spirv.ptr<i32, StorageBuffer>, !spirv.coopmatrix<16x8xi32, Workgroup, MatrixA>, i32
+
     // CHECK-NEXT:  spirv.KHR.CooperativeMatrixStore {{%.*}}, {{%.*}}, <RowMajor>, <Volatile|Nontemporal>
     spirv.KHR.CooperativeMatrixStore %ptr, %m, %stride, <RowMajor>, <Volatile|Nontemporal> :
       !spirv.ptr<i32, StorageBuffer>, !spirv.coopmatrix<16x8xi32, Workgroup, MatrixA>, i32

``````````

</details>


https://github.com/llvm/llvm-project/pull/145480


More information about the Mlir-commits mailing list