[Mlir-commits] [mlir] [mlir][gpu] Support arith.truncf in subgroup MMA elementwise ops (PR #182499)

Simone Pellegrini llvmlistbot at llvm.org
Sun Feb 22 23:42:17 PST 2026


https://github.com/simpel01 updated https://github.com/llvm/llvm-project/pull/182499

>From 0374758a91a13a5b9831bb545dfcaefe1606f41e Mon Sep 17 00:00:00 2001
From: Simone Pellegrini <simone.pellegrini at arm.com>
Date: Thu, 19 Feb 2026 21:57:06 +0100
Subject: [PATCH] [mlir][gpu] Support arith.truncf in subgroup MMA elementwise
 ops

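This commit adds arith.truncf to the list of elementwise ops supported
on subgroup MMA matrices: VectorToGPU now converts it to
`gpu.subgroup_mma_elementwise truncf`, and GPUToSPIRV lowers that to
spirv.FConvert, as it already does for extf. convertElementwiseOpToMMA
is also rewritten as a TypeSwitch.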
---
 mlir/include/mlir/Dialect/GPU/IR/GPUOps.td    |  4 +-
 .../Conversion/GPUToSPIRV/WmmaOpsToSPIRV.cpp  |  1 +
 .../Conversion/VectorToGPU/VectorToGPU.cpp    | 52 ++++++++-----------
 .../wmma-ops-to-spirv-khr-coop-matrix.mlir    | 10 ++--
 .../VectorToGPU/vector-to-mma-ops.mlir        | 22 ++++++++
 5 files changed, 56 insertions(+), 33 deletions(-)

diff --git a/mlir/include/mlir/Dialect/GPU/IR/GPUOps.td b/mlir/include/mlir/Dialect/GPU/IR/GPUOps.td
index 48de1a8bf118e..6b0fd1ed9080e 100644
--- a/mlir/include/mlir/Dialect/GPU/IR/GPUOps.td
+++ b/mlir/include/mlir/Dialect/GPU/IR/GPUOps.td
@@ -2121,6 +2121,7 @@ def GPU_ElementwiseOpDivU : I32EnumAttrCase<"DIVU", 10, "divu">;
 def GPU_ElementwiseOpNEGF : I32EnumAttrCase<"NEGATEF", 11, "negatef">;
 def GPU_ElementwiseOpNEGS : I32EnumAttrCase<"NEGATES", 12, "negates">;
 def GPU_ElementwiseOpEXTF : I32EnumAttrCase<"EXTF", 13, "extf">;
+def GPU_ElementwiseOpTRUNCF : I32EnumAttrCase<"TRUNCF", 14, "truncf">;
 
 def MMAElementWise : I32EnumAttr<"MMAElementwiseOp",
   "elementwise operation to apply to mma matrix", [
@@ -2137,7 +2138,8 @@ def MMAElementWise : I32EnumAttr<"MMAElementwiseOp",
     GPU_ElementwiseOpDivU,
     GPU_ElementwiseOpNEGF,
     GPU_ElementwiseOpNEGS,
-    GPU_ElementwiseOpEXTF
+    GPU_ElementwiseOpEXTF,
+    GPU_ElementwiseOpTRUNCF
   ]> {
   let genSpecializedAttr = 0;
   let cppNamespace = "::mlir::gpu";
diff --git a/mlir/lib/Conversion/GPUToSPIRV/WmmaOpsToSPIRV.cpp b/mlir/lib/Conversion/GPUToSPIRV/WmmaOpsToSPIRV.cpp
index c4d9310874cc4..84c1febd600f6 100644
--- a/mlir/lib/Conversion/GPUToSPIRV/WmmaOpsToSPIRV.cpp
+++ b/mlir/lib/Conversion/GPUToSPIRV/WmmaOpsToSPIRV.cpp
@@ -74,6 +74,7 @@ static bool createElementwiseOp(ConversionPatternRewriter &builder,
     builder.replaceOpWithNewOp<spirv::SNegateOp>(op, coopType, operands);
     return true;
   case gpu::MMAElementwiseOp::EXTF:
+  case gpu::MMAElementwiseOp::TRUNCF:
     builder.replaceOpWithNewOp<spirv::FConvertOp>(op, coopType, operands);
     return true;
   default:
diff --git a/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp b/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp
index 53585fd34c504..115a00896d899 100644
--- a/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp
+++ b/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp
@@ -239,38 +239,29 @@ static bool integerExtendSupportsMMAMatrixType(ExtOpTy extOp) {
 }
 
 static bool fpExtendSupportsMMAMatrixType(arith::ExtFOp extOp) { return true; }
+static bool fpTruncSupportsMMAMatrixType(arith::TruncFOp op) { return true; }
 
 /// Return the MMA elementwise enum associated with `op` if it is supported.
 /// Return `std::nullopt` otherwise.
 static std::optional<gpu::MMAElementwiseOp>
 convertElementwiseOpToMMA(Operation *op) {
-  if (isa<arith::AddFOp>(op))
-    return gpu::MMAElementwiseOp::ADDF;
-  if (isa<arith::MulFOp>(op))
-    return gpu::MMAElementwiseOp::MULF;
-  if (isa<arith::SubFOp>(op))
-    return gpu::MMAElementwiseOp::SUBF;
-  if (isa<arith::MaximumFOp>(op))
-    return gpu::MMAElementwiseOp::MAXF;
-  if (isa<arith::MinimumFOp>(op))
-    return gpu::MMAElementwiseOp::MINF;
-  if (isa<arith::DivFOp>(op))
-    return gpu::MMAElementwiseOp::DIVF;
-  if (isa<arith::AddIOp>(op))
-    return gpu::MMAElementwiseOp::ADDI;
-  if (isa<arith::MulIOp>(op))
-    return gpu::MMAElementwiseOp::MULI;
-  if (isa<arith::SubIOp>(op))
-    return gpu::MMAElementwiseOp::SUBI;
-  if (isa<arith::DivSIOp>(op))
-    return gpu::MMAElementwiseOp::DIVS;
-  if (isa<arith::DivUIOp>(op))
-    return gpu::MMAElementwiseOp::DIVU;
-  if (isa<arith::NegFOp>(op))
-    return gpu::MMAElementwiseOp::NEGATEF;
-  if (isa<arith::ExtFOp>(op))
-    return gpu::MMAElementwiseOp::EXTF;
-  return std::nullopt;
+  using MMAEwO = gpu::MMAElementwiseOp;
+  return TypeSwitch<Operation *, std::optional<MMAEwO>>(op)
+      .Case<arith::AddFOp>([](auto) { return MMAEwO::ADDF; })
+      .Case<arith::AddIOp>([](auto) { return MMAEwO::ADDI; })
+      .Case<arith::DivFOp>([](auto) { return MMAEwO::DIVF; })
+      .Case<arith::DivSIOp>([](auto) { return MMAEwO::DIVS; })
+      .Case<arith::DivUIOp>([](auto) { return MMAEwO::DIVU; })
+      .Case<arith::ExtFOp>([](auto) { return MMAEwO::EXTF; })
+      .Case<arith::MaximumFOp>([](auto) { return MMAEwO::MAXF; })
+      .Case<arith::MinimumFOp>([](auto) { return MMAEwO::MINF; })
+      .Case<arith::MulFOp>([](auto) { return MMAEwO::MULF; })
+      .Case<arith::MulIOp>([](auto) { return MMAEwO::MULI; })
+      .Case<arith::NegFOp>([](auto) { return MMAEwO::NEGATEF; })
+      .Case<arith::SubFOp>([](auto) { return MMAEwO::SUBF; })
+      .Case<arith::SubIOp>([](auto) { return MMAEwO::SUBI; })
+      .Case<arith::TruncFOp>([](auto) { return MMAEwO::TRUNCF; })
+      .Default(std::nullopt);
 }
 
 /// Return true if the op is supported as elementwise op on MMAMatrix type.
@@ -329,6 +320,8 @@ static bool supportsMMaMatrixType(Operation *op, bool useNvGpu) {
     return integerExtendSupportsMMAMatrixType<arith::ExtUIOp>(unsignedExtend);
   if (auto fpExtend = dyn_cast<arith::ExtFOp>(op))
     return fpExtendSupportsMMAMatrixType(fpExtend);
+  if (auto fpTrunc = dyn_cast<arith::TruncFOp>(op))
+    return fpTruncSupportsMMAMatrixType(fpTrunc);
   return elementwiseSupportsMMAMatrixType(op);
 }
 
@@ -1246,8 +1239,9 @@ convertElementwiseOp(RewriterBase &rewriter, Operation *op,
     matrixOperands.push_back(it->second);
   }
   auto resultType = cast<gpu::MMAMatrixType>(matrixOperands[0].getType());
-  if (opType == gpu::MMAElementwiseOp::EXTF) {
-    // The floating point extension case has a different result type.
+  if (opType == gpu::MMAElementwiseOp::EXTF ||
+      opType == gpu::MMAElementwiseOp::TRUNCF) {
+    // Floating-point extension/truncation change the result element type.
     auto vectorType = cast<VectorType>(op->getResultTypes()[0]);
     resultType = gpu::MMAMatrixType::get(resultType.getShape(),
                                          vectorType.getElementType(),
diff --git a/mlir/test/Conversion/GPUToSPIRV/wmma-ops-to-spirv-khr-coop-matrix.mlir b/mlir/test/Conversion/GPUToSPIRV/wmma-ops-to-spirv-khr-coop-matrix.mlir
index 6dba9c3486c7b..4b371118fde30 100644
--- a/mlir/test/Conversion/GPUToSPIRV/wmma-ops-to-spirv-khr-coop-matrix.mlir
+++ b/mlir/test/Conversion/GPUToSPIRV/wmma-ops-to-spirv-khr-coop-matrix.mlir
@@ -125,7 +125,7 @@ module attributes {
     // CHECK-SAME:    !spirv.coopmatrix<16x16xf16, Subgroup, MatrixAcc>
     gpu.func @gpu_wmma_elementwise_op_default(%A: !gpu.mma_matrix<16x16xf16, "COp">,
                                               %B: !gpu.mma_matrix<16x16xf16, "COp">,
-                                              %ptr: memref<16x16xf32, #spirv.storage_class<StorageBuffer>>) kernel
+                                              %ptr: memref<16x16xf16, #spirv.storage_class<StorageBuffer>>) kernel
       attributes {spirv.entry_point_abi = #spirv.entry_point_abi<workgroup_size = [32, 4, 1]>} {
       // CHECK:  {{%.*}} = spirv.FAdd {{%.*}}, {{%.*}} : !spirv.coopmatrix<16x16xf16, Subgroup, MatrixAcc>
       %C = gpu.subgroup_mma_elementwise addf %A, %B :
@@ -143,11 +143,15 @@ module attributes {
       // CHECK-SAME: !spirv.coopmatrix<16x16xf16, Subgroup, MatrixAcc> to !spirv.coopmatrix<16x16xf32, Subgroup, MatrixAcc>
       %G = gpu.subgroup_mma_elementwise extf %F :
         (!gpu.mma_matrix<16x16xf16, "COp">) -> !gpu.mma_matrix<16x16xf32, "COp">
+      // CHECK:  {{%.*}} = spirv.FConvert {{%.*}} :
+      // CHECK-SAME: !spirv.coopmatrix<16x16xf32, Subgroup, MatrixAcc> to !spirv.coopmatrix<16x16xf16, Subgroup, MatrixAcc>
+      %H = gpu.subgroup_mma_elementwise truncf %G :
+        (!gpu.mma_matrix<16x16xf32, "COp">) -> !gpu.mma_matrix<16x16xf16, "COp">
 
       %i = arith.constant 0 : index
       // CHECK: spirv.KHR.CooperativeMatrixStore %{{.+}}, %{{.+}}, %{{.+}}, <RowMajor>
-      gpu.subgroup_mma_store_matrix %G, %ptr[%i,%i] {leadDimension = 32 : index} :
-        !gpu.mma_matrix<16x16xf32, "COp">, memref<16x16xf32, #spirv.storage_class<StorageBuffer>>
+      gpu.subgroup_mma_store_matrix %H, %ptr[%i,%i] {leadDimension = 32 : index} :
+        !gpu.mma_matrix<16x16xf16, "COp">, memref<16x16xf16, #spirv.storage_class<StorageBuffer>>
       // CHECK: spirv.Return
       gpu.return
     }
diff --git a/mlir/test/Conversion/VectorToGPU/vector-to-mma-ops.mlir b/mlir/test/Conversion/VectorToGPU/vector-to-mma-ops.mlir
index bf858789c7e07..32065035b6f21 100644
--- a/mlir/test/Conversion/VectorToGPU/vector-to-mma-ops.mlir
+++ b/mlir/test/Conversion/VectorToGPU/vector-to-mma-ops.mlir
@@ -479,6 +479,28 @@ func.func @cast_f16_to_f32_write(%arg0: memref<16x16xf16>, %arg1: memref<16x16xf
 
 // -----
 
+#map1 = affine_map<(d0, d1, d2) -> (d0, d2)>
+#map2 = affine_map<(d0, d1, d2) -> (d1, d2)>
+#map3 = affine_map<(d0, d1, d2) -> (d0, d1)>
+
+// CHECK-LABEL: func @cast_f32_to_f16_write
+//       CHECK:    %[[COMPUTE:.+]] = gpu.subgroup_mma_compute
+//       CHECK:    %[[TRUNC:.+]] = gpu.subgroup_mma_elementwise  truncf %[[COMPUTE]] : (!gpu.mma_matrix<16x16xf32, "COp">) -> !gpu.mma_matrix<16x16xf16, "COp">
+//       CHECK:    gpu.subgroup_mma_store_matrix %[[TRUNC]]
+func.func @cast_f32_to_f16_write(%arg0: memref<16x16xf32>, %arg1: memref<16x16xf32>, %arg2: memref<16x16xf32>, %arg3: memref<16x16xf16>) {
+  %c0 = arith.constant 0 : index
+  %cst = arith.constant 0.000000e+00 : f32
+  %A = vector.transfer_read %arg0[%c0, %c0], %cst {in_bounds = [true, true]} : memref<16x16xf32>, vector<16x16xf32>
+  %B = vector.transfer_read %arg1[%c0, %c0], %cst {in_bounds = [true, true]} : memref<16x16xf32>, vector<16x16xf32>
+  %C = vector.transfer_read %arg2[%c0, %c0], %cst {in_bounds = [true, true]} : memref<16x16xf32>, vector<16x16xf32>
+  %D = vector.contract {indexing_maps = [#map1, #map2, #map3], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %A, %B, %C : vector<16x16xf32>, vector<16x16xf32> into vector<16x16xf32>
+  %cast = arith.truncf %D : vector<16x16xf32> to vector<16x16xf16>
+  vector.transfer_write %cast, %arg3[%c0, %c0] {in_bounds = [true, true]} : vector<16x16xf16>, memref<16x16xf16>
+  return
+}
+
+// -----
+
 #map1 = affine_map<(d0, d1, d2) -> (d0, d2)>
 #map2 = affine_map<(d0, d1, d2) -> (d2, d1)>
 #map3 = affine_map<(d0, d1, d2) -> (d0, d1)>



More information about the Mlir-commits mailing list