[Mlir-commits] [mlir] [ROCDL] Added rocdl.cvt.scale.sr.pk8 ops (PR #162244)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Tue Oct 7 01:58:17 PDT 2025
https://github.com/ravil-mobile created https://github.com/llvm/llvm-project/pull/162244
This patch adds missing FP conversion ops to the ROCDL dialect for the gfx1250 architecture.
Specifically:
- Downscaling 8 packed F16, BF16, or F32 values to packed FP8, BF8, or FP4 with stochastic rounding (`rocdl.cvt.scalef32.sr.pk8.*`)
Tests:
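Added lit tests checking that the new ops parse and print (mlir/test/Dialect/LLVMIR/rocdl.mlir) and translate from MLIR to the corresponding LLVM IR intrinsic calls (mlir/test/Target/LLVMIR/rocdl.mlir)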
>From 4f5d0ff81b6a149e8b90b9050079a71d81793896 Mon Sep 17 00:00:00 2001
From: ravil-mobile <ravil.aviva.com at gmail.com>
Date: Tue, 7 Oct 2025 08:53:40 +0000
Subject: [PATCH] [ROCDL] Added rocdl.cvt.scalef32.sr.pk8 ops
---
mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td | 18 +++++++++++
mlir/test/Dialect/LLVMIR/rocdl.mlir | 33 ++++++++++++++++++++
mlir/test/Target/LLVMIR/rocdl.mlir | 33 ++++++++++++++++++++
3 files changed, 84 insertions(+)
diff --git a/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td b/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td
index 29001e26eaaaf..db1b7e3af62fd 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td
@@ -1029,6 +1029,24 @@ foreach smallT = [
attr-dict $src `,` $scale `:` type($res)
}];
}
+
+
+ def ROCDL_CvtScaleF32SrPk8 # smallT.nameForOp # largeT.nameForOp # Op :
+ ROCDL_ConcreteNonMemIntrOp<"cvt.scalef32.sr.pk8." # smallT.name # "." # largeT.name,
+ [Pure], 1>,
+ Arguments<(ins largeT.type:$src, I32:$seed, F32:$scale)> {
+ let results = (outs smallT.type:$res);
+ let summary = "Scale and convert packed "
+ # largeT.name # " to packed " # smallT.name # " with stochastic rounding";
+ let description = [{
+ Convert 8 packed }] # largeT.name # [{ values to packed }]
+ # smallT.name # [{, multiplying by the exponent part of `scale`
+ before doing so and applying stochastic rounding seeded by `seed`. This op is for gfx1250+ arch.
+ }];
+ let assemblyFormat = [{
+ attr-dict $src `,` $seed `,` $scale `:` type($res)
+ }];
+ }
} // foreach largeT
} // foreach smallTOp
diff --git a/mlir/test/Dialect/LLVMIR/rocdl.mlir b/mlir/test/Dialect/LLVMIR/rocdl.mlir
index 6134695e9ced6..a88b59aeb61b2 100644
--- a/mlir/test/Dialect/LLVMIR/rocdl.mlir
+++ b/mlir/test/Dialect/LLVMIR/rocdl.mlir
@@ -1100,6 +1100,39 @@ llvm.func @rocdl.cvt.scalef32.pk8(%v8xf32: vector<8xf32>,
// -----
+// CHECK-LABEL: rocdl.cvt.scalef32.sr.pk8
+llvm.func @rocdl.cvt.scalef32.sr.pk8(%v8xf32: vector<8xf32>,
+ %v8xf16: vector<8xf16>,
+ %v8xbf16: vector<8xbf16>,
+ %seed: i32,
+ %scale: f32) {
+ // Round-trip check: printed form must keep the result type (vector<2xi32> for fp8/bf8, i32 for fp4).
+ // CHECK: rocdl.cvt.scalef32.sr.pk8.fp8.f32 {{.*}} : vector<2xi32>
+ %0 = rocdl.cvt.scalef32.sr.pk8.fp8.f32 %v8xf32, %seed, %scale : vector<2xi32>
+ // CHECK: rocdl.cvt.scalef32.sr.pk8.bf8.f32 {{.*}} : vector<2xi32>
+ %1 = rocdl.cvt.scalef32.sr.pk8.bf8.f32 %v8xf32, %seed, %scale : vector<2xi32>
+ // CHECK: rocdl.cvt.scalef32.sr.pk8.fp4.f32 {{.*}} : i32
+ %2 = rocdl.cvt.scalef32.sr.pk8.fp4.f32 %v8xf32, %seed, %scale : i32
+
+ // CHECK: rocdl.cvt.scalef32.sr.pk8.fp8.f16 {{.*}} : vector<2xi32>
+ %3 = rocdl.cvt.scalef32.sr.pk8.fp8.f16 %v8xf16, %seed, %scale : vector<2xi32>
+ // CHECK: rocdl.cvt.scalef32.sr.pk8.bf8.f16 {{.*}} : vector<2xi32>
+ %4 = rocdl.cvt.scalef32.sr.pk8.bf8.f16 %v8xf16, %seed, %scale : vector<2xi32>
+ // CHECK: rocdl.cvt.scalef32.sr.pk8.fp4.f16 {{.*}} : i32
+ %5 = rocdl.cvt.scalef32.sr.pk8.fp4.f16 %v8xf16, %seed, %scale : i32
+
+ // CHECK: rocdl.cvt.scalef32.sr.pk8.fp8.bf16 {{.*}} : vector<2xi32>
+ %6 = rocdl.cvt.scalef32.sr.pk8.fp8.bf16 %v8xbf16, %seed, %scale : vector<2xi32>
+ // CHECK: rocdl.cvt.scalef32.sr.pk8.bf8.bf16 {{.*}} : vector<2xi32>
+ %7 = rocdl.cvt.scalef32.sr.pk8.bf8.bf16 %v8xbf16, %seed, %scale : vector<2xi32>
+ // CHECK: rocdl.cvt.scalef32.sr.pk8.fp4.bf16 {{.*}} : i32
+ %8 = rocdl.cvt.scalef32.sr.pk8.fp4.bf16 %v8xbf16, %seed, %scale : i32
+
+ llvm.return
+}
+
+// -----
+
// CHECK-LABEL: rocdl.cvt.scale.pk16
llvm.func @rocdl.cvt.scale.pk16(%v3xi32: vector<3xi32>, %scale:i32) {
diff --git a/mlir/test/Target/LLVMIR/rocdl.mlir b/mlir/test/Target/LLVMIR/rocdl.mlir
index 00ee6b795c43a..1c0c2eba002aa 100644
--- a/mlir/test/Target/LLVMIR/rocdl.mlir
+++ b/mlir/test/Target/LLVMIR/rocdl.mlir
@@ -1368,6 +1368,39 @@ llvm.func @rocdl.cvt.scalef32.pk8(%v8xf32: vector<8xf32>, %v8xf16: vector<8xf16>
llvm.return
}
+// CHECK-LABEL: @rocdl.cvt.scalef32.sr.pk8
+// CHECK-SAME:(<8 x float> %[[V8F32:.+]], <8 x half> %[[V8F16:.+]], <8 x bfloat> %[[V8BF16:.+]], i32 %[[SEED:.+]], float %[[SCALE:.+]])
+llvm.func @rocdl.cvt.scalef32.sr.pk8(%v8xf32: vector<8xf32>,
+ %v8xf16: vector<8xf16>,
+ %v8xbf16: vector<8xbf16>,
+ %seed: i32,
+ %scale: f32) {
+ // Each op must become a call to the matching llvm.amdgcn.cvt.scalef32.sr.pk8.* intrinsic with operands in order.
+ // CHECK: call <2 x i32> @llvm.amdgcn.cvt.scalef32.sr.pk8.fp8.f32(<8 x float> %[[V8F32]], i32 %[[SEED]], float %[[SCALE]])
+ %0 = rocdl.cvt.scalef32.sr.pk8.fp8.f32 %v8xf32, %seed, %scale : vector<2xi32>
+ // CHECK: call <2 x i32> @llvm.amdgcn.cvt.scalef32.sr.pk8.bf8.f32(<8 x float> %[[V8F32]], i32 %[[SEED]], float %[[SCALE]])
+ %1 = rocdl.cvt.scalef32.sr.pk8.bf8.f32 %v8xf32, %seed, %scale : vector<2xi32>
+ // CHECK: call i32 @llvm.amdgcn.cvt.scalef32.sr.pk8.fp4.f32(<8 x float> %[[V8F32]], i32 %[[SEED]], float %[[SCALE]])
+ %2 = rocdl.cvt.scalef32.sr.pk8.fp4.f32 %v8xf32, %seed, %scale : i32
+
+ // CHECK: call <2 x i32> @llvm.amdgcn.cvt.scalef32.sr.pk8.fp8.f16(<8 x half> %[[V8F16]], i32 %[[SEED]], float %[[SCALE]])
+ %3 = rocdl.cvt.scalef32.sr.pk8.fp8.f16 %v8xf16, %seed, %scale : vector<2xi32>
+ // CHECK: call <2 x i32> @llvm.amdgcn.cvt.scalef32.sr.pk8.bf8.f16(<8 x half> %[[V8F16]], i32 %[[SEED]], float %[[SCALE]])
+ %4 = rocdl.cvt.scalef32.sr.pk8.bf8.f16 %v8xf16, %seed, %scale : vector<2xi32>
+ // CHECK: call i32 @llvm.amdgcn.cvt.scalef32.sr.pk8.fp4.f16(<8 x half> %[[V8F16]], i32 %[[SEED]], float %[[SCALE]])
+ %5 = rocdl.cvt.scalef32.sr.pk8.fp4.f16 %v8xf16, %seed, %scale : i32
+
+ // CHECK: call <2 x i32> @llvm.amdgcn.cvt.scalef32.sr.pk8.fp8.bf16(<8 x bfloat> %[[V8BF16]], i32 %[[SEED]], float %[[SCALE]])
+ %6 = rocdl.cvt.scalef32.sr.pk8.fp8.bf16 %v8xbf16, %seed, %scale : vector<2xi32>
+ // CHECK: call <2 x i32> @llvm.amdgcn.cvt.scalef32.sr.pk8.bf8.bf16(<8 x bfloat> %[[V8BF16]], i32 %[[SEED]], float %[[SCALE]])
+ %7 = rocdl.cvt.scalef32.sr.pk8.bf8.bf16 %v8xbf16, %seed, %scale : vector<2xi32>
+ // CHECK: call i32 @llvm.amdgcn.cvt.scalef32.sr.pk8.fp4.bf16(<8 x bfloat> %[[V8BF16]], i32 %[[SEED]], float %[[SCALE]])
+ %8 = rocdl.cvt.scalef32.sr.pk8.fp4.bf16 %v8xbf16, %seed, %scale : i32
+
+ llvm.return
+}
+
+
// CHECK-LABEL: @rocdl.cvt.scale.pk16
// CHECK-SAME:(<3 x i32> %[[SRC0:.+]], i32 %[[SCALE:.+]])
llvm.func @rocdl.cvt.scale.pk16(%v3xi32: vector<3xi32>, %scale:i32) {
More information about the Mlir-commits
mailing list