[Mlir-commits] [mlir] [MLIR][ROCDL] Add Scale Convert Packed FP4 <-> F32/BF16/F16 Support for GFX950 (PR #140676)

Tim Gymnich llvmlistbot at llvm.org
Mon May 19 22:55:44 PDT 2025


https://github.com/tgymnich updated https://github.com/llvm/llvm-project/pull/140676

>From 03e40d995506d18c43d255fa284cc48630b9e03e Mon Sep 17 00:00:00 2001
From: Tim Gymnich <tim at gymni.ch>
Date: Mon, 19 May 2025 12:52:26 +0000
Subject: [PATCH 1/2] [MLIR][ROCDL] Add Scale Convert Packed FP4 <->
 F32/BF16/F16 Support for GFX950

---
 mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td | 118 +++++++++++++++++++
 mlir/test/Dialect/LLVMIR/rocdl.mlir          |  19 +++
 mlir/test/Target/LLVMIR/rocdl.mlir           |  19 +++
 3 files changed, 156 insertions(+)

diff --git a/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td b/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td
index 6fb9e3aba1f0a..95aed716d8ff8 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td
@@ -713,6 +713,10 @@ def ROCDL_V2I16Type : FixedVectorOfLengthAndType<[2], [I16]>,
                         BuildableType<"::mlir::VectorType::get("
                           "{2},$_builder.getI16Type())">;
 
+def ROCDL_V2F32Type : FixedVectorOfLengthAndType<[2], [F32]>,
+                        BuildableType<"::mlir::VectorType::get("
+                          "{2},$_builder.getF32Type())">;
+
 def ROCDL_V2F16Type : FixedVectorOfLengthAndType<[2], [F16]>,
                         BuildableType<"::mlir::VectorType::get("
                           "{2},$_builder.getF16Type())">;
@@ -1005,6 +1009,120 @@ def ROCDL_CvtScaleF32SrBf8F32Op :
   }];
 }
 
+//===---------------------------------------------------------------------===//
+// 4-bit float scale intrinsics
+//===---------------------------------------------------------------------===//
+def ROCDL_CvtScaleF32PkFp4F32Op :
+    ROCDL_IntrOp<"cvt.scalef32.pk.fp4.f32", [], [], [Pure], 1>,
+    Arguments<(ins I32:$old, F32:$src0, F32:$src1, F32:$scale, I32:$byteSel)> {
+  let summary = "Convert f32 to packed fp4 and scale";
+  let description = [{ Convert `src0` and `src1` based on $byteSel to packed fp4, then scale
+    the packed values by the exponent in `scale`.
+  }];
+  let assemblyFormat = [{
+    attr-dict $src0 `,` $src1 `,` $scale `->` $old `[` $byteSel `]` `:` type($res)
+  }];
+}
+
+def ROCDL_CvtScaleF32PkFp4F16Op :
+    ROCDL_IntrOp<"cvt.scalef32.pk.fp4.f16", [], [], [Pure], 1>,
+    Arguments<(ins I32:$old, ROCDL_V2F16Type:$src, F32: $scale, I32:$byteSel)> {
+    let summary = "Convert f16 to packed fp4 and scale";
+    let description = [{ Convert `src` based on $byteSel to packed fp4, then scale
+    the packed values by the exponent in `scale`.
+  }];
+  let assemblyFormat = [{
+    attr-dict $src `,` $scale `->` $old `[` $byteSel `]` `:` type($res)
+  }];
+}
+
+def ROCDL_CvtScaleF32PkFp4Bf16Op :
+    ROCDL_IntrOp<"cvt.scalef32.pk.fp4.bf16", [], [], [Pure], 1>,
+    Arguments<(ins I32:$old, ROCDL_V2BF16Type:$src, F32: $scale, I32:$byteSel)> {
+    let summary = "Convert bf16 to packed fp4 and scale";
+    let description = [{ Convert `src` based on $byteSel to packed fp4, then scale
+    the packed values by the exponent in `scale`.
+  }];
+  let assemblyFormat = [{
+    attr-dict $src `,` $scale `->` $old `[` $byteSel `]` `:` type($res)
+  }];
+}
+
+def ROCDL_CvtScaleF32SrPkFp4F32Op :
+    ROCDL_IntrOp<"cvt.scalef32.sr.pk.fp4.f32", [], [], [Pure], 1>,
+    Arguments<(ins I32:$old, ROCDL_V2F32Type:$src, I32:$seed, F32: $scale, I32:$byteSel)> {
+    let summary = "Scale and convert f32 to packed fp4 using stochastic rounding";
+    let description = [{
+       Scale `src` by the exponent in `scale` then convert to packed fp4 with stochastic rounding
+       using seed data in `seed`. store into the `byteSel`th byte of `old`, preserving the others.
+    }];
+  let assemblyFormat = [{
+    attr-dict $src `,` $seed `,` $scale `->` $old `[` $byteSel `]` `:` type($res)
+  }];
+}
+
+def ROCDL_CvtScaleF32SrPkFp4F16Op :
+    ROCDL_IntrOp<"cvt.scalef32.sr.pk.fp4.f16", [], [], [Pure], 1>,
+    Arguments<(ins I32:$old, ROCDL_V2F16Type:$src, I32:$seed, F32: $scale, I32:$byteSel)> {
+    let summary = "Scale and convert f16 to packed fp4 using stochastic rounding";
+    let description = [{
+       Scale `src` by the exponent in `scale` then convert to packed fp4 with stochastic rounding
+       using seed data in `seed`. store into the `byteSel`th byte of `old`, preserving the others.
+    }];
+  let assemblyFormat = [{
+    attr-dict $src `,` $seed `,` $scale `->` $old `[` $byteSel `]` `:` type($res)
+  }];
+}
+
+def ROCDL_CvtScaleF32SrPkFp4Bf16Op :
+    ROCDL_IntrOp<"cvt.scalef32.sr.pk.fp4.bf16", [], [], [Pure], 1>,
+    Arguments<(ins I32:$old, ROCDL_V2BF16Type:$src, I32:$seed, F32: $scale, I32:$byteSel)> {
+    let summary = "Scale and convert bf16 to packed fp4 using stochastic rounding";
+    let description = [{
+       Scale `src` by the exponent in `scale` then convert to packed fp4 with stochastic rounding
+       using seed data in `seed`. store into the `byteSel`th byte of `old`, preserving the others.
+    }];
+  let assemblyFormat = [{
+    attr-dict $src `,` $seed `,` $scale `->` $old `[` $byteSel `]` `:` type($res)
+  }];
+}
+
+def ROCDL_CvtScaleF32PkF32Fp4Op :
+    ROCDL_IntrOp<"cvt.scalef32.pk.f32.fp4", [], [], [Pure], 1>,
+    Arguments<(ins I32:$src, F32:$scale, I32:$byteSel)> {
+    let summary = "Convert fp4 to packed f32 and scale";
+    let description = [{ Convert `src` based on $byteSel to packed f32, then scale
+    the packed values by the exponent in `scale`.
+  }];
+  let assemblyFormat = [{
+    attr-dict $src `[` $byteSel `]` `,` $scale `:` type($res)
+  }];
+}
+
+
+def ROCDL_CvtScaleF32PkF16Fp4Op :
+    ROCDL_IntrOp<"cvt.scalef32.pk.f16.fp4", [], [], [Pure], 1>,
+    Arguments<(ins I32:$src, F32:$scale, I32:$byteSel)> {
+    let summary = "Convert fp4 to packed f16 and scale";
+    let description = [{ Convert `src` based on $byteSel to packed f16, then scale
+    the packed values by the exponent in `scale`.
+  }];
+  let assemblyFormat = [{
+    attr-dict $src `[` $byteSel `]` `,` $scale `:` type($res)
+  }];
+}
+
+def ROCDL_CvtScaleF32PkBf16Fp4Op :
+    ROCDL_IntrOp<"cvt.scalef32.pk.bf16.fp4", [], [], [Pure], 1>,
+    Arguments<(ins I32:$src, F32:$scale, I32:$byteSel)> {
+    let summary = "Convert fp4 to packed bf16 and scale";
+    let description = [{ Convert `src` based on $byteSel to packed bf16, then scale
+    the packed values by the exponent in `scale`.
+  }];
+  let assemblyFormat = [{
+    attr-dict $src `[` $byteSel `]` `,` $scale `:` type($res)
+  }];
+}
 //===---------------------------------------------------------------------===//
 // 8-bit float intrinsics
 //===---------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/LLVMIR/rocdl.mlir b/mlir/test/Dialect/LLVMIR/rocdl.mlir
index fbde993891342..5b12b23a9b130 100644
--- a/mlir/test/Dialect/LLVMIR/rocdl.mlir
+++ b/mlir/test/Dialect/LLVMIR/rocdl.mlir
@@ -844,6 +844,25 @@ llvm.func @rocdl_v2f16_v2i16(%source: vector<2xf16>, %source2: vector<2xbf16>, %
   llvm.return %source_scaled : vector<2xi16>
 }
 
+llvm.func @rocdl_4bit_packed_floats(%old: i32, %source0: f32, %source1: f32, %source: vector<2xf32>, %source_half: vector<2xf16>, %source_bfloat: vector<2xbf16>, %stoch: i32) -> i32 {
+  // CHECK-LABEL: @rocdl_4bit_packed_floats
+  // CHECK: rocdl.cvt.scalef32.pk.fp4.f32
+  // CHECK: rocdl.cvt.scalef32.pk.fp4.f16
+  // CHECK: rocdl.cvt.scalef32.pk.fp4.bf16
+  // CHECK: rocdl.cvt.scalef32.sr.pk.fp4.f32
+  // CHECK: rocdl.cvt.scalef32.sr.pk.fp4.f16 
+  // CHECK: rocdl.cvt.scalef32.sr.pk.fp4.bf16
+  %c0 = llvm.mlir.constant(0 : i32) : i32
+  %scale = llvm.mlir.constant(1.0 : f32) : f32
+  %pk1 = rocdl.cvt.scalef32.pk.fp4.f32 %source0, %source1, %scale -> %old[%c0] : i32
+  %pk2 = rocdl.cvt.scalef32.pk.fp4.f16 %source_half, %scale -> %pk1[%c0] : i32
+  %pk3 = rocdl.cvt.scalef32.pk.fp4.bf16 %source_bfloat, %scale -> %pk2[%c0] : i32
+  %sr1 = rocdl.cvt.scalef32.sr.pk.fp4.f32 %source, %stoch, %scale -> %pk3[%c0]  : i32
+  %sr2 = rocdl.cvt.scalef32.sr.pk.fp4.f16 %source_half, %stoch, %scale -> %sr1[%c0] : i32
+  %sr3 = rocdl.cvt.scalef32.sr.pk.fp4.bf16 %source_bfloat, %stoch, %scale -> %sr2[%c0] : i32
+  llvm.return %sr3 : i32
+}
+
 llvm.func @rocdl.s.waitcnt() {
   // CHECK-LABEL: rocdl.s.waitcnt
   // CHECK: rocdl.s.waitcnt 0
diff --git a/mlir/test/Target/LLVMIR/rocdl.mlir b/mlir/test/Target/LLVMIR/rocdl.mlir
index b37f0da361950..7057dc3b6f66d 100644
--- a/mlir/test/Target/LLVMIR/rocdl.mlir
+++ b/mlir/test/Target/LLVMIR/rocdl.mlir
@@ -1145,6 +1145,25 @@ llvm.func @rocdl_16bit_packed_floats(%sourceA: f32, %sourceB: f32) -> vector<2xf
   llvm.return %source : vector<2xf16>
 }
 
+llvm.func @rocdl_4bit_packed_floats(%old: i32, %source0: f32, %source1: f32, %source: vector<2xf32>, %source_half: vector<2xf16>, %source_bfloat: vector<2xbf16>, %stoch: i32) -> i32 {
+  // CHECK-LABEL: @rocdl_4bit_packed_floats
+  // CHECK: call i32 @llvm.amdgcn.cvt.scalef32.pk.fp4.f32(i32 %0, float %1, float %2, float 1.000000e+00, i32 0)
+  // CHECK: call i32 @llvm.amdgcn.cvt.scalef32.pk.fp4.f16(i32 %8, <2 x half> %4, float 1.000000e+00, i32 0)
+  // CHECK: call i32 @llvm.amdgcn.cvt.scalef32.pk.fp4.bf16(i32 %9, <2 x bfloat> %5, float 1.000000e+00, i32 0)
+  // CHECK: call i32 @llvm.amdgcn.cvt.scalef32.sr.pk.fp4.f32(i32 %10, <2 x float> %3, i32 %6, float 1.000000e+00, i32 0)
+  // CHECK: call i32 @llvm.amdgcn.cvt.scalef32.sr.pk.fp4.f16(i32 %11, <2 x half> %4, i32 %6, float 1.000000e+00, i32 0)
+  // CHECK: call i32 @llvm.amdgcn.cvt.scalef32.sr.pk.fp4.bf16(i32 %12, <2 x bfloat> %5, i32 %6, float 1.000000e+00, i32 0)
+  %c0 = llvm.mlir.constant(0 : i32) : i32
+  %scale = llvm.mlir.constant(1.0 : f32) : f32
+  %pk1 = rocdl.cvt.scalef32.pk.fp4.f32 %source0, %source1, %scale -> %old[%c0] : i32
+  %pk2 = rocdl.cvt.scalef32.pk.fp4.f16 %source_half, %scale -> %pk1[%c0] : i32
+  %pk3 = rocdl.cvt.scalef32.pk.fp4.bf16 %source_bfloat, %scale -> %pk2[%c0] : i32
+  %sr1 = rocdl.cvt.scalef32.sr.pk.fp4.f32 %source, %stoch, %scale -> %pk3[%c0]  : i32
+  %sr2 = rocdl.cvt.scalef32.sr.pk.fp4.f16 %source_half, %stoch, %scale -> %sr1[%c0] : i32
+  %sr3 = rocdl.cvt.scalef32.sr.pk.fp4.bf16 %source_bfloat, %stoch, %scale -> %sr2[%c0] : i32
+  llvm.return %sr3 : i32
+}
+
 llvm.func @rocdl_atomic_attrs(%ptr: !llvm.ptr<1>, %data: f32) {
   // CHECK-LABEL: @rocdl_atomic_attrs
   // CHECK: atomicrmw

>From 6fc0d1316cb7366420b7f5eac27fd3d725f1edb5 Mon Sep 17 00:00:00 2001
From: Tim Gymnich <tim at gymni.ch>
Date: Tue, 20 May 2025 05:00:50 +0000
Subject: [PATCH 2/2] update tests

---
 mlir/test/Dialect/LLVMIR/rocdl.mlir | 6 ++++++
 mlir/test/Target/LLVMIR/rocdl.mlir  | 6 ++++++
 2 files changed, 12 insertions(+)

diff --git a/mlir/test/Dialect/LLVMIR/rocdl.mlir b/mlir/test/Dialect/LLVMIR/rocdl.mlir
index 5b12b23a9b130..84fa29ee2d8a1 100644
--- a/mlir/test/Dialect/LLVMIR/rocdl.mlir
+++ b/mlir/test/Dialect/LLVMIR/rocdl.mlir
@@ -852,6 +852,9 @@ llvm.func @rocdl_4bit_packed_floats(%old: i32, %source0: f32, %source1: f32, %so
   // CHECK: rocdl.cvt.scalef32.sr.pk.fp4.f32
   // CHECK: rocdl.cvt.scalef32.sr.pk.fp4.f16 
   // CHECK: rocdl.cvt.scalef32.sr.pk.fp4.bf16
+  // CHECK: rocdl.cvt.scalef32.pk.f32.fp4
+  // CHECK: rocdl.cvt.scalef32.pk.f16.fp4
+  // CHECK: rocdl.cvt.scalef32.pk.bf16.fp4
   %c0 = llvm.mlir.constant(0 : i32) : i32
   %scale = llvm.mlir.constant(1.0 : f32) : f32
   %pk1 = rocdl.cvt.scalef32.pk.fp4.f32 %source0, %source1, %scale -> %old[%c0] : i32
@@ -860,6 +863,9 @@ llvm.func @rocdl_4bit_packed_floats(%old: i32, %source0: f32, %source1: f32, %so
   %sr1 = rocdl.cvt.scalef32.sr.pk.fp4.f32 %source, %stoch, %scale -> %pk3[%c0]  : i32
   %sr2 = rocdl.cvt.scalef32.sr.pk.fp4.f16 %source_half, %stoch, %scale -> %sr1[%c0] : i32
   %sr3 = rocdl.cvt.scalef32.sr.pk.fp4.bf16 %source_bfloat, %stoch, %scale -> %sr2[%c0] : i32
+  %pk4 = rocdl.cvt.scalef32.pk.f32.fp4 %old[%c0], %scale : vector<2xf32>
+  %pk5 = rocdl.cvt.scalef32.pk.f16.fp4 %old[%c0], %scale : vector<2xf16>
+  %pk6 = rocdl.cvt.scalef32.pk.bf16.fp4 %old[%c0], %scale : vector<2xbf16>
   llvm.return %sr3 : i32
 }
 
diff --git a/mlir/test/Target/LLVMIR/rocdl.mlir b/mlir/test/Target/LLVMIR/rocdl.mlir
index 7057dc3b6f66d..73862ac112da8 100644
--- a/mlir/test/Target/LLVMIR/rocdl.mlir
+++ b/mlir/test/Target/LLVMIR/rocdl.mlir
@@ -1153,6 +1153,9 @@ llvm.func @rocdl_4bit_packed_floats(%old: i32, %source0: f32, %source1: f32, %so
   // CHECK: call i32 @llvm.amdgcn.cvt.scalef32.sr.pk.fp4.f32(i32 %10, <2 x float> %3, i32 %6, float 1.000000e+00, i32 0)
   // CHECK: call i32 @llvm.amdgcn.cvt.scalef32.sr.pk.fp4.f16(i32 %11, <2 x half> %4, i32 %6, float 1.000000e+00, i32 0)
   // CHECK: call i32 @llvm.amdgcn.cvt.scalef32.sr.pk.fp4.bf16(i32 %12, <2 x bfloat> %5, i32 %6, float 1.000000e+00, i32 0)
+  // CHECK: call <2 x float> @llvm.amdgcn.cvt.scalef32.pk.f32.fp4(i32 %0, float 1.000000e+00, i32 0)
+  // CHECK: call <2 x half> @llvm.amdgcn.cvt.scalef32.pk.f16.fp4(i32 %0, float 1.000000e+00, i32 0)
+  // CHECK: call <2 x bfloat> @llvm.amdgcn.cvt.scalef32.pk.bf16.fp4(i32 %0, float 1.000000e+00, i32 0)
   %c0 = llvm.mlir.constant(0 : i32) : i32
   %scale = llvm.mlir.constant(1.0 : f32) : f32
   %pk1 = rocdl.cvt.scalef32.pk.fp4.f32 %source0, %source1, %scale -> %old[%c0] : i32
@@ -1161,6 +1164,9 @@ llvm.func @rocdl_4bit_packed_floats(%old: i32, %source0: f32, %source1: f32, %so
   %sr1 = rocdl.cvt.scalef32.sr.pk.fp4.f32 %source, %stoch, %scale -> %pk3[%c0]  : i32
   %sr2 = rocdl.cvt.scalef32.sr.pk.fp4.f16 %source_half, %stoch, %scale -> %sr1[%c0] : i32
   %sr3 = rocdl.cvt.scalef32.sr.pk.fp4.bf16 %source_bfloat, %stoch, %scale -> %sr2[%c0] : i32
+  %pk4 = rocdl.cvt.scalef32.pk.f32.fp4 %old[%c0], %scale : vector<2xf32>
+  %pk5 = rocdl.cvt.scalef32.pk.f16.fp4 %old[%c0], %scale : vector<2xf16>
+  %pk6 = rocdl.cvt.scalef32.pk.bf16.fp4 %old[%c0], %scale : vector<2xbf16>
   llvm.return %sr3 : i32
 }
 



More information about the Mlir-commits mailing list