[Mlir-commits] [mlir] ee54401 - [mlir][spirv] Set signed coop matrix operands (#197932)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Sun May 17 23:56:09 PDT 2026
Author: Dhruv Chauhan
Date: 2026-05-18T07:56:04+01:00
New Revision: ee54401f2ab98924c64f59673441b1e1ede44e48
URL: https://github.com/llvm/llvm-project/commit/ee54401f2ab98924c64f59673441b1e1ede44e48
DIFF: https://github.com/llvm/llvm-project/commit/ee54401f2ab98924c64f59673441b1e1ede44e48.diff
LOG: [mlir][spirv] Set signed coop matrix operands (#197932)
Populate the CooperativeMatrixOperandsKHR attribute on KHR cooperative matrix
multiply-add ops based on the cooperative matrix element types. Signed
integer A, B, C, and result matrices require their corresponding signed
component bits (ASigned, BSigned, CSigned, ResultSigned); otherwise SPIR-V
treats those integer components as unsigned.
Added a lit test.
Co-authored-by: Hsiangkai Wang <hsiangkai.wang at arm.com>
Added:
Modified:
mlir/lib/Conversion/GPUToSPIRV/WmmaOpsToSPIRV.cpp
mlir/test/Conversion/GPUToSPIRV/wmma-ops-to-spirv-khr-coop-matrix.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Conversion/GPUToSPIRV/WmmaOpsToSPIRV.cpp b/mlir/lib/Conversion/GPUToSPIRV/WmmaOpsToSPIRV.cpp
index 84c1febd600f6..956134dfee55d 100644
--- a/mlir/lib/Conversion/GPUToSPIRV/WmmaOpsToSPIRV.cpp
+++ b/mlir/lib/Conversion/GPUToSPIRV/WmmaOpsToSPIRV.cpp
@@ -92,6 +92,30 @@ bool allOperandsHaveSameCoopMatrixType(ValueRange operands) {
return isa<spirv::CooperativeMatrixType>(operands.front().getType());
}
+static bool hasSignedIntegerElementType(spirv::CooperativeMatrixType type) { // True only for signed (siN) integer elements.
+ auto elementType = dyn_cast<IntegerType>(type.getElementType()); // Null for non-integer (e.g. float) elements.
+ return elementType && elementType.isSigned(); // Signless iN and unsigned uiN yield false.
+}
+
+static spirv::CooperativeMatrixOperandsKHR // Signedness bits for a KHR coop-matrix multiply-add.
+getSignedCoopMatrixOperands(spirv::CooperativeMatrixType aType,
+ spirv::CooperativeMatrixType bType,
+ spirv::CooperativeMatrixType cType,
+ spirv::CooperativeMatrixType resultType) {
+ using Operands = spirv::CooperativeMatrixOperandsKHR;
+
+ Operands operands = Operands::None; // Unflagged integer operands are treated as unsigned by SPIR-V.
+ if (hasSignedIntegerElementType(aType))
+ operands |= Operands::ASigned;
+ if (hasSignedIntegerElementType(bType))
+ operands |= Operands::BSigned;
+ if (hasSignedIntegerElementType(cType))
+ operands |= Operands::CSigned;
+ if (hasSignedIntegerElementType(resultType))
+ operands |= Operands::ResultSigned;
+ return operands; // Operands::None when no matrix has a signed integer element type.
+}
+
namespace {
/// Converts GPU MMA ConstantMatrixOp to constant SPIR-V KHR/NV cooperative
/// matrix ops.
@@ -342,9 +366,30 @@ struct WmmaMmaOpToSPIRVLowering final
matchAndRewrite(gpu::SubgroupMmaComputeOp subgroupMmaComputeOp,
OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
+ auto aType =
+ dyn_cast<spirv::CooperativeMatrixType>(adaptor.getOpA().getType());
+ auto bType =
+ dyn_cast<spirv::CooperativeMatrixType>(adaptor.getOpB().getType());
+ auto cType =
+ dyn_cast<spirv::CooperativeMatrixType>(adaptor.getOpC().getType());
+ auto resultType =
+ getTypeConverter()->convertType<spirv::CooperativeMatrixType>(
+ subgroupMmaComputeOp.getResult().getType());
+ if (!aType || !bType || !cType || !resultType)
+ return rewriter.notifyMatchFailure(subgroupMmaComputeOp,
+ "type conversion failed");
+
+ using Operands = spirv::CooperativeMatrixOperandsKHR;
+ Operands operands =
+ getSignedCoopMatrixOperands(aType, bType, cType, resultType);
+ spirv::CooperativeMatrixOperandsKHRAttr operandsAttr;
+ if (operands != Operands::None)
+ operandsAttr = spirv::CooperativeMatrixOperandsKHRAttr::get(
+ rewriter.getContext(), operands);
+
rewriter.replaceOpWithNewOp<spirv::KHRCooperativeMatrixMulAddOp>(
subgroupMmaComputeOp, adaptor.getOpA(), adaptor.getOpB(),
- adaptor.getOpC());
+ adaptor.getOpC(), operandsAttr);
return success();
}
};
diff --git a/mlir/test/Conversion/GPUToSPIRV/wmma-ops-to-spirv-khr-coop-matrix.mlir b/mlir/test/Conversion/GPUToSPIRV/wmma-ops-to-spirv-khr-coop-matrix.mlir
index 4b371118fde30..30ea121a54ead 100644
--- a/mlir/test/Conversion/GPUToSPIRV/wmma-ops-to-spirv-khr-coop-matrix.mlir
+++ b/mlir/test/Conversion/GPUToSPIRV/wmma-ops-to-spirv-khr-coop-matrix.mlir
@@ -4,7 +4,7 @@
module attributes {
gpu.container_module,
spirv.target_env = #spirv.target_env<#spirv.vce<v1.6,
- [Shader, CooperativeMatrixKHR, Float16],
+ [Shader, CooperativeMatrixKHR, Float16, Int8],
[SPV_KHR_storage_buffer_storage_class, SPV_KHR_cooperative_matrix]>,
#spirv.resource_limits<>>} {
@@ -76,6 +76,30 @@ module attributes {
gpu.return
}
+ // CHECK-LABEL: spirv.func @gpu_wmma_signed_i8_mma_op
+ // CHECK-SAME: !spirv.coopmatrix<16x16xsi8, Subgroup, MatrixA>
+ // CHECK-SAME: !spirv.coopmatrix<16x16xsi8, Subgroup, MatrixB>
+ // CHECK-SAME: !spirv.coopmatrix<16x16xi32, Subgroup, MatrixAcc>
+ gpu.func @gpu_wmma_signed_i8_mma_op( // Signed si8 A/B operands: expect exactly ASigned|BSigned.
+ %A: !gpu.mma_matrix<16x16xsi8, "AOp">,
+ %B: !gpu.mma_matrix<16x16xsi8, "BOp">,
+ %C: !gpu.mma_matrix<16x16xi32, "COp">, // Signless i32 accumulator/result: no CSigned or ResultSigned bits.
+ %ptr: memref<16x16xi32, #spirv.storage_class<StorageBuffer>>) kernel
+ attributes {spirv.entry_point_abi = #spirv.entry_point_abi<workgroup_size = [32, 4, 1]>} {
+ // CHECK: spirv.KHR.CooperativeMatrixMulAdd {{%.*}}, {{%.*}}, {{%.*}}, <ASigned|BSigned> :
+ // CHECK-SAME: !spirv.coopmatrix<16x16xsi8, Subgroup, MatrixA>,
+ // CHECK-SAME: !spirv.coopmatrix<16x16xsi8, Subgroup, MatrixB>
+ // CHECK-SAME: -> !spirv.coopmatrix<16x16xi32, Subgroup, MatrixAcc>
+ %D = gpu.subgroup_mma_compute %A, %B, %C : !gpu.mma_matrix<16x16xsi8, "AOp">,
+ !gpu.mma_matrix<16x16xsi8, "BOp">
+ -> !gpu.mma_matrix<16x16xi32, "COp">
+ %i = arith.constant 0 : index
+ // CHECK: spirv.KHR.CooperativeMatrixStore %{{.+}}, %{{.+}}, %{{.+}}, <RowMajor>
+ gpu.subgroup_mma_store_matrix %D, %ptr[%i, %i] {leadDimension = 32 : index} :
+ !gpu.mma_matrix<16x16xi32, "COp">, memref<16x16xi32, #spirv.storage_class<StorageBuffer>>
+ gpu.return
+ }
+
// CHECK-LABEL: spirv.func @gpu_wmma_constant_op
gpu.func @gpu_wmma_constant_op(%ptr: memref<16x16xf16, #spirv.storage_class<StorageBuffer>>) kernel
attributes {spirv.entry_point_abi = #spirv.entry_point_abi<workgroup_size = [32, 4, 1]>} {
More information about the Mlir-commits
mailing list