[Mlir-commits] [mlir] ee54401 - [mlir][spirv] Set signed coop matrix operands (#197932)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Sun May 17 23:56:09 PDT 2026


Author: Dhruv Chauhan
Date: 2026-05-18T07:56:04+01:00
New Revision: ee54401f2ab98924c64f59673441b1e1ede44e48

URL: https://github.com/llvm/llvm-project/commit/ee54401f2ab98924c64f59673441b1e1ede44e48
DIFF: https://github.com/llvm/llvm-project/commit/ee54401f2ab98924c64f59673441b1e1ede44e48.diff

LOG: [mlir][spirv] Set signed coop matrix operands (#197932)

Populate CooperativeMatrixOperandsKHR on KHR cooperative matrix
multiply-add based on the cooperative matrix element types. Signed
integer A, B, C and result matrices require their corresponding signed
component bits; otherwise SPIR-V treats those integer components as
unsigned.

Added a lit test covering signed i8 A and B operands.

Co-authored-by: Hsiangkai Wang <hsiangkai.wang at arm.com>

Added: 
    

Modified: 
    mlir/lib/Conversion/GPUToSPIRV/WmmaOpsToSPIRV.cpp
    mlir/test/Conversion/GPUToSPIRV/wmma-ops-to-spirv-khr-coop-matrix.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Conversion/GPUToSPIRV/WmmaOpsToSPIRV.cpp b/mlir/lib/Conversion/GPUToSPIRV/WmmaOpsToSPIRV.cpp
index 84c1febd600f6..956134dfee55d 100644
--- a/mlir/lib/Conversion/GPUToSPIRV/WmmaOpsToSPIRV.cpp
+++ b/mlir/lib/Conversion/GPUToSPIRV/WmmaOpsToSPIRV.cpp
@@ -92,6 +92,30 @@ bool allOperandsHaveSameCoopMatrixType(ValueRange operands) {
   return isa<spirv::CooperativeMatrixType>(operands.front().getType());
 }
 
+static bool hasSignedIntegerElementType(spirv::CooperativeMatrixType type) {
+  auto elementType = dyn_cast<IntegerType>(type.getElementType());
+  return elementType && elementType.isSigned();
+}
+
+static spirv::CooperativeMatrixOperandsKHR
+getSignedCoopMatrixOperands(spirv::CooperativeMatrixType aType,
+                            spirv::CooperativeMatrixType bType,
+                            spirv::CooperativeMatrixType cType,
+                            spirv::CooperativeMatrixType resultType) {
+  using Operands = spirv::CooperativeMatrixOperandsKHR;
+
+  Operands operands = Operands::None;
+  if (hasSignedIntegerElementType(aType))
+    operands |= Operands::ASigned;
+  if (hasSignedIntegerElementType(bType))
+    operands |= Operands::BSigned;
+  if (hasSignedIntegerElementType(cType))
+    operands |= Operands::CSigned;
+  if (hasSignedIntegerElementType(resultType))
+    operands |= Operands::ResultSigned;
+  return operands;
+}
+
 namespace {
 /// Converts GPU MMA ConstantMatrixOp to constant SPIR-V KHR/NV cooperative
 /// matrix ops.
@@ -342,9 +366,30 @@ struct WmmaMmaOpToSPIRVLowering final
   matchAndRewrite(gpu::SubgroupMmaComputeOp subgroupMmaComputeOp,
                   OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
+    auto aType =
+        dyn_cast<spirv::CooperativeMatrixType>(adaptor.getOpA().getType());
+    auto bType =
+        dyn_cast<spirv::CooperativeMatrixType>(adaptor.getOpB().getType());
+    auto cType =
+        dyn_cast<spirv::CooperativeMatrixType>(adaptor.getOpC().getType());
+    auto resultType =
+        getTypeConverter()->convertType<spirv::CooperativeMatrixType>(
+            subgroupMmaComputeOp.getResult().getType());
+    if (!aType || !bType || !cType || !resultType)
+      return rewriter.notifyMatchFailure(subgroupMmaComputeOp,
+                                         "type conversion failed");
+
+    using Operands = spirv::CooperativeMatrixOperandsKHR;
+    Operands operands =
+        getSignedCoopMatrixOperands(aType, bType, cType, resultType);
+    spirv::CooperativeMatrixOperandsKHRAttr operandsAttr;
+    if (operands != Operands::None)
+      operandsAttr = spirv::CooperativeMatrixOperandsKHRAttr::get(
+          rewriter.getContext(), operands);
+
     rewriter.replaceOpWithNewOp<spirv::KHRCooperativeMatrixMulAddOp>(
         subgroupMmaComputeOp, adaptor.getOpA(), adaptor.getOpB(),
-        adaptor.getOpC());
+        adaptor.getOpC(), operandsAttr);
     return success();
   }
 };

diff  --git a/mlir/test/Conversion/GPUToSPIRV/wmma-ops-to-spirv-khr-coop-matrix.mlir b/mlir/test/Conversion/GPUToSPIRV/wmma-ops-to-spirv-khr-coop-matrix.mlir
index 4b371118fde30..30ea121a54ead 100644
--- a/mlir/test/Conversion/GPUToSPIRV/wmma-ops-to-spirv-khr-coop-matrix.mlir
+++ b/mlir/test/Conversion/GPUToSPIRV/wmma-ops-to-spirv-khr-coop-matrix.mlir
@@ -4,7 +4,7 @@
 module attributes {
   gpu.container_module,
   spirv.target_env = #spirv.target_env<#spirv.vce<v1.6,
-    [Shader, CooperativeMatrixKHR, Float16],
+    [Shader, CooperativeMatrixKHR, Float16, Int8],
     [SPV_KHR_storage_buffer_storage_class, SPV_KHR_cooperative_matrix]>,
     #spirv.resource_limits<>>} {
 
@@ -76,6 +76,30 @@ module attributes {
       gpu.return
     }
 
+    // CHECK-LABEL: spirv.func @gpu_wmma_signed_i8_mma_op
+    // CHECK-SAME:    !spirv.coopmatrix<16x16xsi8, Subgroup, MatrixA>
+    // CHECK-SAME:    !spirv.coopmatrix<16x16xsi8, Subgroup, MatrixB>
+    // CHECK-SAME:    !spirv.coopmatrix<16x16xi32, Subgroup, MatrixAcc>
+    gpu.func @gpu_wmma_signed_i8_mma_op(
+      %A: !gpu.mma_matrix<16x16xsi8, "AOp">,
+      %B: !gpu.mma_matrix<16x16xsi8, "BOp">,
+      %C: !gpu.mma_matrix<16x16xi32, "COp">,
+      %ptr: memref<16x16xi32, #spirv.storage_class<StorageBuffer>>) kernel
+      attributes {spirv.entry_point_abi = #spirv.entry_point_abi<workgroup_size = [32, 4, 1]>} {
+      // CHECK:      spirv.KHR.CooperativeMatrixMulAdd {{%.*}}, {{%.*}}, {{%.*}}, <ASigned|BSigned> :
+      // CHECK-SAME:   !spirv.coopmatrix<16x16xsi8, Subgroup, MatrixA>,
+      // CHECK-SAME:   !spirv.coopmatrix<16x16xsi8, Subgroup, MatrixB>
+      // CHECK-SAME:   -> !spirv.coopmatrix<16x16xi32, Subgroup, MatrixAcc>
+      %D = gpu.subgroup_mma_compute %A, %B, %C : !gpu.mma_matrix<16x16xsi8, "AOp">,
+                                                 !gpu.mma_matrix<16x16xsi8, "BOp">
+                                                 -> !gpu.mma_matrix<16x16xi32, "COp">
+      %i = arith.constant 0 : index
+      // CHECK:      spirv.KHR.CooperativeMatrixStore %{{.+}}, %{{.+}}, %{{.+}}, <RowMajor>
+      gpu.subgroup_mma_store_matrix %D, %ptr[%i, %i] {leadDimension = 32 : index} :
+        !gpu.mma_matrix<16x16xi32, "COp">, memref<16x16xi32, #spirv.storage_class<StorageBuffer>>
+      gpu.return
+    }
+
     // CHECK-LABEL: spirv.func @gpu_wmma_constant_op
     gpu.func @gpu_wmma_constant_op(%ptr: memref<16x16xf16, #spirv.storage_class<StorageBuffer>>) kernel
       attributes {spirv.entry_point_abi = #spirv.entry_point_abi<workgroup_size = [32, 4, 1]>} {


        


More information about the Mlir-commits mailing list