[Mlir-commits] [mlir] [mlir][gpu] Support arith.truncf in subgroup MMA elementwise ops (PR #182499)
Simone Pellegrini
llvmlistbot at llvm.org
Sun Feb 22 23:42:17 PST 2026
https://github.com/simpel01 updated https://github.com/llvm/llvm-project/pull/182499
>From 0374758a91a13a5b9831bb545dfcaefe1606f41e Mon Sep 17 00:00:00 2001
From: Simone Pellegrini <simone.pellegrini at arm.com>
Date: Thu, 19 Feb 2026 21:57:06 +0100
Subject: [PATCH] [mlir][gpu] Support arith.truncf in subgroup MMA elementwise
ops
This commit adds arith.truncf to the list of supported elementwise ops
for subgroup MMA ops and enables its lowering to SPIR-V.
---
mlir/include/mlir/Dialect/GPU/IR/GPUOps.td | 4 +-
.../Conversion/GPUToSPIRV/WmmaOpsToSPIRV.cpp | 1 +
.../Conversion/VectorToGPU/VectorToGPU.cpp | 52 ++++++++-----------
.../wmma-ops-to-spirv-khr-coop-matrix.mlir | 10 ++--
.../VectorToGPU/vector-to-mma-ops.mlir | 22 ++++++++
5 files changed, 56 insertions(+), 33 deletions(-)
diff --git a/mlir/include/mlir/Dialect/GPU/IR/GPUOps.td b/mlir/include/mlir/Dialect/GPU/IR/GPUOps.td
index 48de1a8bf118e..6b0fd1ed9080e 100644
--- a/mlir/include/mlir/Dialect/GPU/IR/GPUOps.td
+++ b/mlir/include/mlir/Dialect/GPU/IR/GPUOps.td
@@ -2121,6 +2121,7 @@ def GPU_ElementwiseOpDivU : I32EnumAttrCase<"DIVU", 10, "divu">;
def GPU_ElementwiseOpNEGF : I32EnumAttrCase<"NEGATEF", 11, "negatef">;
def GPU_ElementwiseOpNEGS : I32EnumAttrCase<"NEGATES", 12, "negates">;
def GPU_ElementwiseOpEXTF : I32EnumAttrCase<"EXTF", 13, "extf">;
+def GPU_ElementwiseOpTRUNCF : I32EnumAttrCase<"TRUNCF", 14, "truncf">;
def MMAElementWise : I32EnumAttr<"MMAElementwiseOp",
"elementwise operation to apply to mma matrix", [
@@ -2137,7 +2138,8 @@ def MMAElementWise : I32EnumAttr<"MMAElementwiseOp",
GPU_ElementwiseOpDivU,
GPU_ElementwiseOpNEGF,
GPU_ElementwiseOpNEGS,
- GPU_ElementwiseOpEXTF
+ GPU_ElementwiseOpEXTF,
+ GPU_ElementwiseOpTRUNCF
]> {
let genSpecializedAttr = 0;
let cppNamespace = "::mlir::gpu";
diff --git a/mlir/lib/Conversion/GPUToSPIRV/WmmaOpsToSPIRV.cpp b/mlir/lib/Conversion/GPUToSPIRV/WmmaOpsToSPIRV.cpp
index c4d9310874cc4..84c1febd600f6 100644
--- a/mlir/lib/Conversion/GPUToSPIRV/WmmaOpsToSPIRV.cpp
+++ b/mlir/lib/Conversion/GPUToSPIRV/WmmaOpsToSPIRV.cpp
@@ -74,6 +74,7 @@ static bool createElementwiseOp(ConversionPatternRewriter &builder,
builder.replaceOpWithNewOp<spirv::SNegateOp>(op, coopType, operands);
return true;
case gpu::MMAElementwiseOp::EXTF:
+ case gpu::MMAElementwiseOp::TRUNCF:
builder.replaceOpWithNewOp<spirv::FConvertOp>(op, coopType, operands);
return true;
default:
diff --git a/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp b/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp
index 53585fd34c504..115a00896d899 100644
--- a/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp
+++ b/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp
@@ -239,38 +239,29 @@ static bool integerExtendSupportsMMAMatrixType(ExtOpTy extOp) {
}
static bool fpExtendSupportsMMAMatrixType(arith::ExtFOp extOp) { return true; }
+static bool fpTruncSupportsMMAMatrixType(arith::TruncFOp extOp) { return true; }
/// Return the MMA elementwise enum associated with `op` if it is supported.
/// Return `std::nullopt` otherwise.
static std::optional<gpu::MMAElementwiseOp>
convertElementwiseOpToMMA(Operation *op) {
- if (isa<arith::AddFOp>(op))
- return gpu::MMAElementwiseOp::ADDF;
- if (isa<arith::MulFOp>(op))
- return gpu::MMAElementwiseOp::MULF;
- if (isa<arith::SubFOp>(op))
- return gpu::MMAElementwiseOp::SUBF;
- if (isa<arith::MaximumFOp>(op))
- return gpu::MMAElementwiseOp::MAXF;
- if (isa<arith::MinimumFOp>(op))
- return gpu::MMAElementwiseOp::MINF;
- if (isa<arith::DivFOp>(op))
- return gpu::MMAElementwiseOp::DIVF;
- if (isa<arith::AddIOp>(op))
- return gpu::MMAElementwiseOp::ADDI;
- if (isa<arith::MulIOp>(op))
- return gpu::MMAElementwiseOp::MULI;
- if (isa<arith::SubIOp>(op))
- return gpu::MMAElementwiseOp::SUBI;
- if (isa<arith::DivSIOp>(op))
- return gpu::MMAElementwiseOp::DIVS;
- if (isa<arith::DivUIOp>(op))
- return gpu::MMAElementwiseOp::DIVU;
- if (isa<arith::NegFOp>(op))
- return gpu::MMAElementwiseOp::NEGATEF;
- if (isa<arith::ExtFOp>(op))
- return gpu::MMAElementwiseOp::EXTF;
- return std::nullopt;
+ using MMAEwO = gpu::MMAElementwiseOp;
+ return TypeSwitch<Operation *, std::optional<MMAEwO>>(op)
+ .Case<arith::AddFOp>([](auto) { return MMAEwO::ADDF; })
+ .Case<arith::AddIOp>([](auto) { return MMAEwO::ADDI; })
+ .Case<arith::DivFOp>([](auto) { return MMAEwO::DIVF; })
+ .Case<arith::DivSIOp>([](auto) { return MMAEwO::DIVS; })
+ .Case<arith::DivUIOp>([](auto) { return MMAEwO::DIVU; })
+ .Case<arith::ExtFOp>([](auto) { return MMAEwO::EXTF; })
+ .Case<arith::MaximumFOp>([](auto) { return MMAEwO::MAXF; })
+ .Case<arith::MinimumFOp>([](auto) { return MMAEwO::MINF; })
+ .Case<arith::MulFOp>([](auto) { return MMAEwO::MULF; })
+ .Case<arith::MulIOp>([](auto) { return MMAEwO::MULI; })
+ .Case<arith::NegFOp>([](auto) { return MMAEwO::NEGATEF; })
+ .Case<arith::SubFOp>([](auto) { return MMAEwO::SUBF; })
+ .Case<arith::SubIOp>([](auto) { return MMAEwO::SUBI; })
+ .Case<arith::TruncFOp>([](auto) { return MMAEwO::TRUNCF; })
+ .Default(std::nullopt);
}
/// Return true if the op is supported as elementwise op on MMAMatrix type.
@@ -329,6 +320,8 @@ static bool supportsMMaMatrixType(Operation *op, bool useNvGpu) {
return integerExtendSupportsMMAMatrixType<arith::ExtUIOp>(unsignedExtend);
if (auto fpExtend = dyn_cast<arith::ExtFOp>(op))
return fpExtendSupportsMMAMatrixType(fpExtend);
+ if (auto fpTrunc = dyn_cast<arith::TruncFOp>(op))
+ return fpTruncSupportsMMAMatrixType(fpTrunc);
return elementwiseSupportsMMAMatrixType(op);
}
@@ -1246,8 +1239,9 @@ convertElementwiseOp(RewriterBase &rewriter, Operation *op,
matrixOperands.push_back(it->second);
}
auto resultType = cast<gpu::MMAMatrixType>(matrixOperands[0].getType());
- if (opType == gpu::MMAElementwiseOp::EXTF) {
- // The floating point extension case has a different result type.
+ if (opType == gpu::MMAElementwiseOp::EXTF ||
+ opType == gpu::MMAElementwiseOp::TRUNCF) {
+ // The floating point extension and truncation have a different result type.
auto vectorType = cast<VectorType>(op->getResultTypes()[0]);
resultType = gpu::MMAMatrixType::get(resultType.getShape(),
vectorType.getElementType(),
diff --git a/mlir/test/Conversion/GPUToSPIRV/wmma-ops-to-spirv-khr-coop-matrix.mlir b/mlir/test/Conversion/GPUToSPIRV/wmma-ops-to-spirv-khr-coop-matrix.mlir
index 6dba9c3486c7b..4b371118fde30 100644
--- a/mlir/test/Conversion/GPUToSPIRV/wmma-ops-to-spirv-khr-coop-matrix.mlir
+++ b/mlir/test/Conversion/GPUToSPIRV/wmma-ops-to-spirv-khr-coop-matrix.mlir
@@ -125,7 +125,7 @@ module attributes {
// CHECK-SAME: !spirv.coopmatrix<16x16xf16, Subgroup, MatrixAcc>
gpu.func @gpu_wmma_elementwise_op_default(%A: !gpu.mma_matrix<16x16xf16, "COp">,
%B: !gpu.mma_matrix<16x16xf16, "COp">,
- %ptr: memref<16x16xf32, #spirv.storage_class<StorageBuffer>>) kernel
+ %ptr: memref<16x16xf16, #spirv.storage_class<StorageBuffer>>) kernel
attributes {spirv.entry_point_abi = #spirv.entry_point_abi<workgroup_size = [32, 4, 1]>} {
// CHECK: {{%.*}} = spirv.FAdd {{%.*}}, {{%.*}} : !spirv.coopmatrix<16x16xf16, Subgroup, MatrixAcc>
%C = gpu.subgroup_mma_elementwise addf %A, %B :
@@ -143,11 +143,15 @@ module attributes {
// CHECK-SAME: !spirv.coopmatrix<16x16xf16, Subgroup, MatrixAcc> to !spirv.coopmatrix<16x16xf32, Subgroup, MatrixAcc>
%G = gpu.subgroup_mma_elementwise extf %F :
(!gpu.mma_matrix<16x16xf16, "COp">) -> !gpu.mma_matrix<16x16xf32, "COp">
+ // CHECK: {{%.*}} = spirv.FConvert {{%.*}} :
+ // CHECK-SAME: !spirv.coopmatrix<16x16xf32, Subgroup, MatrixAcc> to !spirv.coopmatrix<16x16xf16, Subgroup, MatrixAcc>
+ %H = gpu.subgroup_mma_elementwise truncf %G :
+ (!gpu.mma_matrix<16x16xf32, "COp">) -> !gpu.mma_matrix<16x16xf16, "COp">
%i = arith.constant 0 : index
// CHECK: spirv.KHR.CooperativeMatrixStore %{{.+}}, %{{.+}}, %{{.+}}, <RowMajor>
- gpu.subgroup_mma_store_matrix %G, %ptr[%i,%i] {leadDimension = 32 : index} :
- !gpu.mma_matrix<16x16xf32, "COp">, memref<16x16xf32, #spirv.storage_class<StorageBuffer>>
+ gpu.subgroup_mma_store_matrix %H, %ptr[%i,%i] {leadDimension = 32 : index} :
+ !gpu.mma_matrix<16x16xf16, "COp">, memref<16x16xf16, #spirv.storage_class<StorageBuffer>>
// CHECK: spirv.Return
gpu.return
}
diff --git a/mlir/test/Conversion/VectorToGPU/vector-to-mma-ops.mlir b/mlir/test/Conversion/VectorToGPU/vector-to-mma-ops.mlir
index bf858789c7e07..32065035b6f21 100644
--- a/mlir/test/Conversion/VectorToGPU/vector-to-mma-ops.mlir
+++ b/mlir/test/Conversion/VectorToGPU/vector-to-mma-ops.mlir
@@ -479,6 +479,28 @@ func.func @cast_f16_to_f32_write(%arg0: memref<16x16xf16>, %arg1: memref<16x16xf
// -----
+#map1 = affine_map<(d0, d1, d2) -> (d0, d2)>
+#map2 = affine_map<(d0, d1, d2) -> (d1, d2)>
+#map3 = affine_map<(d0, d1, d2) -> (d0, d1)>
+
+// CHECK-LABEL: func @cast_f32_to_f16_write
+// CHECK: %[[COMPUTE:.+]] = gpu.subgroup_mma_compute
+// CHECK: %[[EXT:.+]] = gpu.subgroup_mma_elementwise truncf %[[COMPUTE]] : (!gpu.mma_matrix<16x16xf32, "COp">) -> !gpu.mma_matrix<16x16xf16, "COp">
+// CHECK: gpu.subgroup_mma_store_matrix %[[EXT]]
+func.func @cast_f32_to_f16_write(%arg0: memref<16x16xf32>, %arg1: memref<16x16xf32>, %arg2: memref<16x16xf32>, %arg3: memref<16x16xf16>) {
+ %c0 = arith.constant 0 : index
+ %cst = arith.constant 0.000000e+00 : f32
+ %A = vector.transfer_read %arg0[%c0, %c0], %cst {in_bounds = [true, true]} : memref<16x16xf32>, vector<16x16xf32>
+ %B = vector.transfer_read %arg1[%c0, %c0], %cst {in_bounds = [true, true]} : memref<16x16xf32>, vector<16x16xf32>
+ %C = vector.transfer_read %arg2[%c0, %c0], %cst {in_bounds = [true, true]} : memref<16x16xf32>, vector<16x16xf32>
+ %D = vector.contract {indexing_maps = [#map1, #map2, #map3], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %A, %B, %C : vector<16x16xf32>, vector<16x16xf32> into vector<16x16xf32>
+ %cast = arith.truncf %D : vector<16x16xf32> to vector<16x16xf16>
+ vector.transfer_write %cast, %arg3[%c0, %c0] {in_bounds = [true, true]} : vector<16x16xf16>, memref<16x16xf16>
+ return
+}
+
+// -----
+
#map1 = affine_map<(d0, d1, d2) -> (d0, d2)>
#map2 = affine_map<(d0, d1, d2) -> (d2, d1)>
#map3 = affine_map<(d0, d1, d2) -> (d0, d1)>
More information about the Mlir-commits
mailing list