[Mlir-commits] [mlir] 22f9940 - [mlir][gpu] Support arith.truncf in subgroup MMA elementwise ops (#182499)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Wed Feb 25 02:05:06 PST 2026


Author: Simone Pellegrini
Date: 2026-02-25T10:05:01Z
New Revision: 22f9940d709247b2ea2e4d788974d5bd43c3dfba

URL: https://github.com/llvm/llvm-project/commit/22f9940d709247b2ea2e4d788974d5bd43c3dfba
DIFF: https://github.com/llvm/llvm-project/commit/22f9940d709247b2ea2e4d788974d5bd43c3dfba.diff

LOG: [mlir][gpu] Support arith.truncf in subgroup MMA elementwise ops (#182499)

This commit adds support for arith.truncf in the supported list of
elementwise ops for subgroup MMA ops, and enables lowering to SPIR-V.

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/GPU/IR/GPUOps.td
    mlir/lib/Conversion/GPUToSPIRV/WmmaOpsToSPIRV.cpp
    mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp
    mlir/test/Conversion/GPUToSPIRV/wmma-ops-to-spirv-khr-coop-matrix.mlir
    mlir/test/Conversion/VectorToGPU/vector-to-mma-ops.mlir

Removed: 
    


################################################################################
diff --git a/mlir/include/mlir/Dialect/GPU/IR/GPUOps.td b/mlir/include/mlir/Dialect/GPU/IR/GPUOps.td
index 48de1a8bf118e..6b0fd1ed9080e 100644
--- a/mlir/include/mlir/Dialect/GPU/IR/GPUOps.td
+++ b/mlir/include/mlir/Dialect/GPU/IR/GPUOps.td
@@ -2121,6 +2121,7 @@ def GPU_ElementwiseOpDivU : I32EnumAttrCase<"DIVU", 10, "divu">;
 def GPU_ElementwiseOpNEGF : I32EnumAttrCase<"NEGATEF", 11, "negatef">;
 def GPU_ElementwiseOpNEGS : I32EnumAttrCase<"NEGATES", 12, "negates">;
 def GPU_ElementwiseOpEXTF : I32EnumAttrCase<"EXTF", 13, "extf">;
+def GPU_ElementwiseOpTRUNCF : I32EnumAttrCase<"TRUNCF", 14, "truncf">;
 
 def MMAElementWise : I32EnumAttr<"MMAElementwiseOp",
   "elementwise operation to apply to mma matrix", [
@@ -2137,7 +2138,8 @@ def MMAElementWise : I32EnumAttr<"MMAElementwiseOp",
     GPU_ElementwiseOpDivU,
     GPU_ElementwiseOpNEGF,
     GPU_ElementwiseOpNEGS,
-    GPU_ElementwiseOpEXTF
+    GPU_ElementwiseOpEXTF,
+    GPU_ElementwiseOpTRUNCF
   ]> {
   let genSpecializedAttr = 0;
   let cppNamespace = "::mlir::gpu";

diff --git a/mlir/lib/Conversion/GPUToSPIRV/WmmaOpsToSPIRV.cpp b/mlir/lib/Conversion/GPUToSPIRV/WmmaOpsToSPIRV.cpp
index c4d9310874cc4..84c1febd600f6 100644
--- a/mlir/lib/Conversion/GPUToSPIRV/WmmaOpsToSPIRV.cpp
+++ b/mlir/lib/Conversion/GPUToSPIRV/WmmaOpsToSPIRV.cpp
@@ -74,6 +74,7 @@ static bool createElementwiseOp(ConversionPatternRewriter &builder,
     builder.replaceOpWithNewOp<spirv::SNegateOp>(op, coopType, operands);
     return true;
   case gpu::MMAElementwiseOp::EXTF:
+  case gpu::MMAElementwiseOp::TRUNCF:
     builder.replaceOpWithNewOp<spirv::FConvertOp>(op, coopType, operands);
     return true;
   default:

diff --git a/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp b/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp
index 53585fd34c504..975fe28399609 100644
--- a/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp
+++ b/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp
@@ -239,38 +239,29 @@ static bool integerExtendSupportsMMAMatrixType(ExtOpTy extOp) {
 }
 
 static bool fpExtendSupportsMMAMatrixType(arith::ExtFOp extOp) { return true; }
+static bool fpTruncSupportsMMAMatrixType(arith::TruncFOp extOp) { return true; }
 
 /// Return the MMA elementwise enum associated with `op` if it is supported.
 /// Return `std::nullopt` otherwise.
 static std::optional<gpu::MMAElementwiseOp>
 convertElementwiseOpToMMA(Operation *op) {
-  if (isa<arith::AddFOp>(op))
-    return gpu::MMAElementwiseOp::ADDF;
-  if (isa<arith::MulFOp>(op))
-    return gpu::MMAElementwiseOp::MULF;
-  if (isa<arith::SubFOp>(op))
-    return gpu::MMAElementwiseOp::SUBF;
-  if (isa<arith::MaximumFOp>(op))
-    return gpu::MMAElementwiseOp::MAXF;
-  if (isa<arith::MinimumFOp>(op))
-    return gpu::MMAElementwiseOp::MINF;
-  if (isa<arith::DivFOp>(op))
-    return gpu::MMAElementwiseOp::DIVF;
-  if (isa<arith::AddIOp>(op))
-    return gpu::MMAElementwiseOp::ADDI;
-  if (isa<arith::MulIOp>(op))
-    return gpu::MMAElementwiseOp::MULI;
-  if (isa<arith::SubIOp>(op))
-    return gpu::MMAElementwiseOp::SUBI;
-  if (isa<arith::DivSIOp>(op))
-    return gpu::MMAElementwiseOp::DIVS;
-  if (isa<arith::DivUIOp>(op))
-    return gpu::MMAElementwiseOp::DIVU;
-  if (isa<arith::NegFOp>(op))
-    return gpu::MMAElementwiseOp::NEGATEF;
-  if (isa<arith::ExtFOp>(op))
-    return gpu::MMAElementwiseOp::EXTF;
-  return std::nullopt;
+  using MMAEwO = gpu::MMAElementwiseOp;
+  return TypeSwitch<Operation *, std::optional<MMAEwO>>(op)
+      .Case([](arith::AddFOp) { return MMAEwO::ADDF; })
+      .Case([](arith::AddIOp) { return MMAEwO::ADDI; })
+      .Case([](arith::DivFOp) { return MMAEwO::DIVF; })
+      .Case([](arith::DivSIOp) { return MMAEwO::DIVS; })
+      .Case([](arith::DivUIOp) { return MMAEwO::DIVU; })
+      .Case([](arith::ExtFOp) { return MMAEwO::EXTF; })
+      .Case([](arith::MaximumFOp) { return MMAEwO::MAXF; })
+      .Case([](arith::MinimumFOp) { return MMAEwO::MINF; })
+      .Case([](arith::MulFOp) { return MMAEwO::MULF; })
+      .Case([](arith::MulIOp) { return MMAEwO::MULI; })
+      .Case([](arith::NegFOp) { return MMAEwO::NEGATEF; })
+      .Case([](arith::SubFOp) { return MMAEwO::SUBF; })
+      .Case([](arith::SubIOp) { return MMAEwO::SUBI; })
+      .Case([](arith::TruncFOp) { return MMAEwO::TRUNCF; })
+      .Default(std::nullopt);
 }
 
 /// Return true if the op is supported as elementwise op on MMAMatrix type.
@@ -329,6 +320,8 @@ static bool supportsMMaMatrixType(Operation *op, bool useNvGpu) {
     return integerExtendSupportsMMAMatrixType<arith::ExtUIOp>(unsignedExtend);
   if (auto fpExtend = dyn_cast<arith::ExtFOp>(op))
     return fpExtendSupportsMMAMatrixType(fpExtend);
+  if (auto fpTrunc = dyn_cast<arith::TruncFOp>(op))
+    return fpTruncSupportsMMAMatrixType(fpTrunc);
   return elementwiseSupportsMMAMatrixType(op);
 }
 
@@ -1246,8 +1239,9 @@ convertElementwiseOp(RewriterBase &rewriter, Operation *op,
     matrixOperands.push_back(it->second);
   }
   auto resultType = cast<gpu::MMAMatrixType>(matrixOperands[0].getType());
-  if (opType == gpu::MMAElementwiseOp::EXTF) {
-    // The floating point extension case has a different result type.
+  if (opType == gpu::MMAElementwiseOp::EXTF ||
+      opType == gpu::MMAElementwiseOp::TRUNCF) {
+    // The floating point extension and truncation has a different result type.
     auto vectorType = cast<VectorType>(op->getResultTypes()[0]);
     resultType = gpu::MMAMatrixType::get(resultType.getShape(),
                                          vectorType.getElementType(),

diff --git a/mlir/test/Conversion/GPUToSPIRV/wmma-ops-to-spirv-khr-coop-matrix.mlir b/mlir/test/Conversion/GPUToSPIRV/wmma-ops-to-spirv-khr-coop-matrix.mlir
index 6dba9c3486c7b..4b371118fde30 100644
--- a/mlir/test/Conversion/GPUToSPIRV/wmma-ops-to-spirv-khr-coop-matrix.mlir
+++ b/mlir/test/Conversion/GPUToSPIRV/wmma-ops-to-spirv-khr-coop-matrix.mlir
@@ -125,7 +125,7 @@ module attributes {
     // CHECK-SAME:    !spirv.coopmatrix<16x16xf16, Subgroup, MatrixAcc>
     gpu.func @gpu_wmma_elementwise_op_default(%A: !gpu.mma_matrix<16x16xf16, "COp">,
                                               %B: !gpu.mma_matrix<16x16xf16, "COp">,
-                                              %ptr: memref<16x16xf32, #spirv.storage_class<StorageBuffer>>) kernel
+                                              %ptr: memref<16x16xf16, #spirv.storage_class<StorageBuffer>>) kernel
       attributes {spirv.entry_point_abi = #spirv.entry_point_abi<workgroup_size = [32, 4, 1]>} {
       // CHECK:  {{%.*}} = spirv.FAdd {{%.*}}, {{%.*}} : !spirv.coopmatrix<16x16xf16, Subgroup, MatrixAcc>
       %C = gpu.subgroup_mma_elementwise addf %A, %B :
@@ -143,11 +143,15 @@ module attributes {
       // CHECK-SAME: !spirv.coopmatrix<16x16xf16, Subgroup, MatrixAcc> to !spirv.coopmatrix<16x16xf32, Subgroup, MatrixAcc>
       %G = gpu.subgroup_mma_elementwise extf %F :
         (!gpu.mma_matrix<16x16xf16, "COp">) -> !gpu.mma_matrix<16x16xf32, "COp">
+      // CHECK:  {{%.*}} = spirv.FConvert {{%.*}} :
+      // CHECK-SAME: !spirv.coopmatrix<16x16xf32, Subgroup, MatrixAcc> to !spirv.coopmatrix<16x16xf16, Subgroup, MatrixAcc>
+      %H = gpu.subgroup_mma_elementwise truncf %G :
+        (!gpu.mma_matrix<16x16xf32, "COp">) -> !gpu.mma_matrix<16x16xf16, "COp">
 
       %i = arith.constant 0 : index
       // CHECK: spirv.KHR.CooperativeMatrixStore %{{.+}}, %{{.+}}, %{{.+}}, <RowMajor>
-      gpu.subgroup_mma_store_matrix %G, %ptr[%i,%i] {leadDimension = 32 : index} :
-        !gpu.mma_matrix<16x16xf32, "COp">, memref<16x16xf32, #spirv.storage_class<StorageBuffer>>
+      gpu.subgroup_mma_store_matrix %H, %ptr[%i,%i] {leadDimension = 32 : index} :
+        !gpu.mma_matrix<16x16xf16, "COp">, memref<16x16xf16, #spirv.storage_class<StorageBuffer>>
       // CHECK: spirv.Return
       gpu.return
     }

diff --git a/mlir/test/Conversion/VectorToGPU/vector-to-mma-ops.mlir b/mlir/test/Conversion/VectorToGPU/vector-to-mma-ops.mlir
index bf858789c7e07..32065035b6f21 100644
--- a/mlir/test/Conversion/VectorToGPU/vector-to-mma-ops.mlir
+++ b/mlir/test/Conversion/VectorToGPU/vector-to-mma-ops.mlir
@@ -479,6 +479,28 @@ func.func @cast_f16_to_f32_write(%arg0: memref<16x16xf16>, %arg1: memref<16x16xf
 
 // -----
 
+#map1 = affine_map<(d0, d1, d2) -> (d0, d2)>
+#map2 = affine_map<(d0, d1, d2) -> (d1, d2)>
+#map3 = affine_map<(d0, d1, d2) -> (d0, d1)>
+
+// CHECK-LABEL: func @cast_f32_to_f16_write
+//       CHECK:    %[[COMPUTE:.+]] = gpu.subgroup_mma_compute
+//       CHECK:    %[[EXT:.+]] = gpu.subgroup_mma_elementwise  truncf %[[COMPUTE]] : (!gpu.mma_matrix<16x16xf32, "COp">) -> !gpu.mma_matrix<16x16xf16, "COp">
+//       CHECK:    gpu.subgroup_mma_store_matrix %[[EXT]]
+func.func @cast_f32_to_f16_write(%arg0: memref<16x16xf32>, %arg1: memref<16x16xf32>, %arg2: memref<16x16xf32>, %arg3: memref<16x16xf16>) {
+  %c0 = arith.constant 0 : index
+  %cst = arith.constant 0.000000e+00 : f32
+  %A = vector.transfer_read %arg0[%c0, %c0], %cst {in_bounds = [true, true]} : memref<16x16xf32>, vector<16x16xf32>
+  %B = vector.transfer_read %arg1[%c0, %c0], %cst {in_bounds = [true, true]} : memref<16x16xf32>, vector<16x16xf32>
+  %C = vector.transfer_read %arg2[%c0, %c0], %cst {in_bounds = [true, true]} : memref<16x16xf32>, vector<16x16xf32>
+  %D = vector.contract {indexing_maps = [#map1, #map2, #map3], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %A, %B, %C : vector<16x16xf32>, vector<16x16xf32> into vector<16x16xf32>
+  %cast = arith.truncf %D : vector<16x16xf32> to vector<16x16xf16>
+  vector.transfer_write %cast, %arg3[%c0, %c0] {in_bounds = [true, true]} : vector<16x16xf16>, memref<16x16xf16>
+  return
+}
+
+// -----
+
 #map1 = affine_map<(d0, d1, d2) -> (d0, d2)>
 #map2 = affine_map<(d0, d1, d2) -> (d2, d1)>
 #map3 = affine_map<(d0, d1, d2) -> (d0, d1)>


        


More information about the Mlir-commits mailing list