[Mlir-commits] [mlir] [mlir][gpu] Support arith.truncf in subgroup MMA elementwise ops (PR #182499)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Fri Feb 20 05:53:06 PST 2026
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir-gpu
Author: Simone Pellegrini (simpel01)
<details>
<summary>Changes</summary>
This commit adds arith.truncf to the list of elementwise ops supported by subgroup MMA ops, and enables its lowering to SPIR-V.
---
Full diff: https://github.com/llvm/llvm-project/pull/182499.diff
5 Files Affected:
- (modified) mlir/include/mlir/Dialect/GPU/IR/GPUOps.td (+3-1)
- (modified) mlir/lib/Conversion/GPUToSPIRV/WmmaOpsToSPIRV.cpp (+1)
- (modified) mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp (+8-2)
- (modified) mlir/test/Conversion/GPUToSPIRV/wmma-ops-to-spirv-khr-coop-matrix.mlir (+7-3)
- (modified) mlir/test/Conversion/VectorToGPU/vector-to-mma-ops.mlir (+22)
``````````diff
diff --git a/mlir/include/mlir/Dialect/GPU/IR/GPUOps.td b/mlir/include/mlir/Dialect/GPU/IR/GPUOps.td
index 48de1a8bf118e..6b0fd1ed9080e 100644
--- a/mlir/include/mlir/Dialect/GPU/IR/GPUOps.td
+++ b/mlir/include/mlir/Dialect/GPU/IR/GPUOps.td
@@ -2121,6 +2121,7 @@ def GPU_ElementwiseOpDivU : I32EnumAttrCase<"DIVU", 10, "divu">;
def GPU_ElementwiseOpNEGF : I32EnumAttrCase<"NEGATEF", 11, "negatef">;
def GPU_ElementwiseOpNEGS : I32EnumAttrCase<"NEGATES", 12, "negates">;
def GPU_ElementwiseOpEXTF : I32EnumAttrCase<"EXTF", 13, "extf">;
+def GPU_ElementwiseOpTRUNCF : I32EnumAttrCase<"TRUNCF", 14, "truncf">;
def MMAElementWise : I32EnumAttr<"MMAElementwiseOp",
"elementwise operation to apply to mma matrix", [
@@ -2137,7 +2138,8 @@ def MMAElementWise : I32EnumAttr<"MMAElementwiseOp",
GPU_ElementwiseOpDivU,
GPU_ElementwiseOpNEGF,
GPU_ElementwiseOpNEGS,
- GPU_ElementwiseOpEXTF
+ GPU_ElementwiseOpEXTF,
+ GPU_ElementwiseOpTRUNCF
]> {
let genSpecializedAttr = 0;
let cppNamespace = "::mlir::gpu";
diff --git a/mlir/lib/Conversion/GPUToSPIRV/WmmaOpsToSPIRV.cpp b/mlir/lib/Conversion/GPUToSPIRV/WmmaOpsToSPIRV.cpp
index c4d9310874cc4..84c1febd600f6 100644
--- a/mlir/lib/Conversion/GPUToSPIRV/WmmaOpsToSPIRV.cpp
+++ b/mlir/lib/Conversion/GPUToSPIRV/WmmaOpsToSPIRV.cpp
@@ -74,6 +74,7 @@ static bool createElementwiseOp(ConversionPatternRewriter &builder,
builder.replaceOpWithNewOp<spirv::SNegateOp>(op, coopType, operands);
return true;
case gpu::MMAElementwiseOp::EXTF:
+ case gpu::MMAElementwiseOp::TRUNCF:
builder.replaceOpWithNewOp<spirv::FConvertOp>(op, coopType, operands);
return true;
default:
diff --git a/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp b/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp
index 53585fd34c504..499a127936522 100644
--- a/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp
+++ b/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp
@@ -239,6 +239,7 @@ static bool integerExtendSupportsMMAMatrixType(ExtOpTy extOp) {
}
static bool fpExtendSupportsMMAMatrixType(arith::ExtFOp extOp) { return true; }
+static bool fpTruncSupportsMMAMatrixType(arith::TruncFOp op) { return true; }
/// Return the MMA elementwise enum associated with `op` if it is supported.
/// Return `std::nullopt` otherwise.
@@ -270,6 +271,8 @@ convertElementwiseOpToMMA(Operation *op) {
return gpu::MMAElementwiseOp::NEGATEF;
if (isa<arith::ExtFOp>(op))
return gpu::MMAElementwiseOp::EXTF;
+ if (isa<arith::TruncFOp>(op))
+ return gpu::MMAElementwiseOp::TRUNCF;
return std::nullopt;
}
@@ -329,6 +332,8 @@ static bool supportsMMaMatrixType(Operation *op, bool useNvGpu) {
return integerExtendSupportsMMAMatrixType<arith::ExtUIOp>(unsignedExtend);
if (auto fpExtend = dyn_cast<arith::ExtFOp>(op))
return fpExtendSupportsMMAMatrixType(fpExtend);
+ if (auto fpTrunc = dyn_cast<arith::TruncFOp>(op))
+ return fpTruncSupportsMMAMatrixType(fpTrunc);
return elementwiseSupportsMMAMatrixType(op);
}
@@ -1246,8 +1251,9 @@ convertElementwiseOp(RewriterBase &rewriter, Operation *op,
matrixOperands.push_back(it->second);
}
auto resultType = cast<gpu::MMAMatrixType>(matrixOperands[0].getType());
- if (opType == gpu::MMAElementwiseOp::EXTF) {
- // The floating point extension case has a different result type.
+ if (opType == gpu::MMAElementwiseOp::EXTF ||
+ opType == gpu::MMAElementwiseOp::TRUNCF) {
+ // The floating point extension and truncation have a different result type.
auto vectorType = cast<VectorType>(op->getResultTypes()[0]);
resultType = gpu::MMAMatrixType::get(resultType.getShape(),
vectorType.getElementType(),
diff --git a/mlir/test/Conversion/GPUToSPIRV/wmma-ops-to-spirv-khr-coop-matrix.mlir b/mlir/test/Conversion/GPUToSPIRV/wmma-ops-to-spirv-khr-coop-matrix.mlir
index 6dba9c3486c7b..4b371118fde30 100644
--- a/mlir/test/Conversion/GPUToSPIRV/wmma-ops-to-spirv-khr-coop-matrix.mlir
+++ b/mlir/test/Conversion/GPUToSPIRV/wmma-ops-to-spirv-khr-coop-matrix.mlir
@@ -125,7 +125,7 @@ module attributes {
// CHECK-SAME: !spirv.coopmatrix<16x16xf16, Subgroup, MatrixAcc>
gpu.func @gpu_wmma_elementwise_op_default(%A: !gpu.mma_matrix<16x16xf16, "COp">,
%B: !gpu.mma_matrix<16x16xf16, "COp">,
- %ptr: memref<16x16xf32, #spirv.storage_class<StorageBuffer>>) kernel
+ %ptr: memref<16x16xf16, #spirv.storage_class<StorageBuffer>>) kernel
attributes {spirv.entry_point_abi = #spirv.entry_point_abi<workgroup_size = [32, 4, 1]>} {
// CHECK: {{%.*}} = spirv.FAdd {{%.*}}, {{%.*}} : !spirv.coopmatrix<16x16xf16, Subgroup, MatrixAcc>
%C = gpu.subgroup_mma_elementwise addf %A, %B :
@@ -143,11 +143,15 @@ module attributes {
// CHECK-SAME: !spirv.coopmatrix<16x16xf16, Subgroup, MatrixAcc> to !spirv.coopmatrix<16x16xf32, Subgroup, MatrixAcc>
%G = gpu.subgroup_mma_elementwise extf %F :
(!gpu.mma_matrix<16x16xf16, "COp">) -> !gpu.mma_matrix<16x16xf32, "COp">
+ // CHECK: {{%.*}} = spirv.FConvert {{%.*}} :
+ // CHECK-SAME: !spirv.coopmatrix<16x16xf32, Subgroup, MatrixAcc> to !spirv.coopmatrix<16x16xf16, Subgroup, MatrixAcc>
+ %H = gpu.subgroup_mma_elementwise truncf %G :
+ (!gpu.mma_matrix<16x16xf32, "COp">) -> !gpu.mma_matrix<16x16xf16, "COp">
%i = arith.constant 0 : index
// CHECK: spirv.KHR.CooperativeMatrixStore %{{.+}}, %{{.+}}, %{{.+}}, <RowMajor>
- gpu.subgroup_mma_store_matrix %G, %ptr[%i,%i] {leadDimension = 32 : index} :
- !gpu.mma_matrix<16x16xf32, "COp">, memref<16x16xf32, #spirv.storage_class<StorageBuffer>>
+ gpu.subgroup_mma_store_matrix %H, %ptr[%i,%i] {leadDimension = 32 : index} :
+ !gpu.mma_matrix<16x16xf16, "COp">, memref<16x16xf16, #spirv.storage_class<StorageBuffer>>
// CHECK: spirv.Return
gpu.return
}
diff --git a/mlir/test/Conversion/VectorToGPU/vector-to-mma-ops.mlir b/mlir/test/Conversion/VectorToGPU/vector-to-mma-ops.mlir
index bf858789c7e07..32065035b6f21 100644
--- a/mlir/test/Conversion/VectorToGPU/vector-to-mma-ops.mlir
+++ b/mlir/test/Conversion/VectorToGPU/vector-to-mma-ops.mlir
@@ -479,6 +479,28 @@ func.func @cast_f16_to_f32_write(%arg0: memref<16x16xf16>, %arg1: memref<16x16xf
// -----
+#map1 = affine_map<(d0, d1, d2) -> (d0, d2)>
+#map2 = affine_map<(d0, d1, d2) -> (d1, d2)>
+#map3 = affine_map<(d0, d1, d2) -> (d0, d1)>
+
+// CHECK-LABEL: func @cast_f32_to_f16_write
+// CHECK: %[[COMPUTE:.+]] = gpu.subgroup_mma_compute
+// CHECK: %[[TRUNC:.+]] = gpu.subgroup_mma_elementwise truncf %[[COMPUTE]] : (!gpu.mma_matrix<16x16xf32, "COp">) -> !gpu.mma_matrix<16x16xf16, "COp">
+// CHECK: gpu.subgroup_mma_store_matrix %[[TRUNC]]
+func.func @cast_f32_to_f16_write(%arg0: memref<16x16xf32>, %arg1: memref<16x16xf32>, %arg2: memref<16x16xf32>, %arg3: memref<16x16xf16>) {
+ %c0 = arith.constant 0 : index
+ %cst = arith.constant 0.000000e+00 : f32
+ %A = vector.transfer_read %arg0[%c0, %c0], %cst {in_bounds = [true, true]} : memref<16x16xf32>, vector<16x16xf32>
+ %B = vector.transfer_read %arg1[%c0, %c0], %cst {in_bounds = [true, true]} : memref<16x16xf32>, vector<16x16xf32>
+ %C = vector.transfer_read %arg2[%c0, %c0], %cst {in_bounds = [true, true]} : memref<16x16xf32>, vector<16x16xf32>
+ %D = vector.contract {indexing_maps = [#map1, #map2, #map3], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %A, %B, %C : vector<16x16xf32>, vector<16x16xf32> into vector<16x16xf32>
+ %cast = arith.truncf %D : vector<16x16xf32> to vector<16x16xf16>
+ vector.transfer_write %cast, %arg3[%c0, %c0] {in_bounds = [true, true]} : vector<16x16xf16>, memref<16x16xf16>
+ return
+}
+
+// -----
+
#map1 = affine_map<(d0, d1, d2) -> (d0, d2)>
#map2 = affine_map<(d0, d1, d2) -> (d2, d1)>
#map3 = affine_map<(d0, d1, d2) -> (d0, d1)>
``````````
</details>
https://github.com/llvm/llvm-project/pull/182499
More information about the Mlir-commits
mailing list