[Mlir-commits] [mlir] [MLIR][ROCDL] Add Scale Convert Packed FP4 <-> F32/BF16/F16 Support for GFX950 (PR #140676)
Tim Gymnich
llvmlistbot at llvm.org
Mon May 19 22:55:44 PDT 2025
https://github.com/tgymnich updated https://github.com/llvm/llvm-project/pull/140676
>From 03e40d995506d18c43d255fa284cc48630b9e03e Mon Sep 17 00:00:00 2001
From: Tim Gymnich <tim at gymni.ch>
Date: Mon, 19 May 2025 12:52:26 +0000
Subject: [PATCH 1/2] [MLIR][ROCDL] Add Scale Convert Packed FP4 <->
F32/BF16/F16 Support for GFX950
---
mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td | 118 +++++++++++++++++++
mlir/test/Dialect/LLVMIR/rocdl.mlir | 19 +++
mlir/test/Target/LLVMIR/rocdl.mlir | 19 +++
3 files changed, 156 insertions(+)
diff --git a/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td b/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td
index 6fb9e3aba1f0a..95aed716d8ff8 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td
@@ -713,6 +713,10 @@ def ROCDL_V2I16Type : FixedVectorOfLengthAndType<[2], [I16]>,
BuildableType<"::mlir::VectorType::get("
"{2},$_builder.getI16Type())">;
+// Two-element f32 vector type, defined like ROCDL_V2I16Type/ROCDL_V2F16Type.
+def ROCDL_V2F32Type : FixedVectorOfLengthAndType<[2], [F32]>,
+                        BuildableType<"::mlir::VectorType::get("
+                                      "{2},$_builder.getF32Type())">;
+
def ROCDL_V2F16Type : FixedVectorOfLengthAndType<[2], [F16]>,
BuildableType<"::mlir::VectorType::get("
"{2},$_builder.getF16Type())">;
@@ -1005,6 +1009,120 @@ def ROCDL_CvtScaleF32SrBf8F32Op :
}];
}
+//===---------------------------------------------------------------------===//
+// 4-bit float scale intrinsics
+//===---------------------------------------------------------------------===//
+def ROCDL_CvtScaleF32PkFp4F32Op :
+    ROCDL_IntrOp<"cvt.scalef32.pk.fp4.f32", [], [], [Pure], 1>,
+    Arguments<(ins I32:$old, F32:$src0, F32:$src1, F32:$scale, I32:$byteSel)> {
+  let summary = "Scale and convert two f32 values to packed fp4";
+  let description = [{
+    Scale `src0` and `src1` by the exponent in `scale`, convert them to fp4,
+    and store the packed pair into the `byteSel`th byte of `old`, preserving
+    the other bytes. The updated word is returned.
+  }];
+  let assemblyFormat = [{
+    attr-dict $src0 `,` $src1 `,` $scale `->` $old `[` $byteSel `]` `:` type($res)
+  }];
+}
+
+def ROCDL_CvtScaleF32PkFp4F16Op :
+    ROCDL_IntrOp<"cvt.scalef32.pk.fp4.f16", [], [], [Pure], 1>,
+    Arguments<(ins I32:$old, ROCDL_V2F16Type:$src, F32:$scale, I32:$byteSel)> {
+  let summary = "Scale and convert packed f16 to packed fp4";
+  let description = [{
+    Scale the two f16 elements of `src` by the exponent in `scale`, convert
+    them to fp4, and store the packed pair into the `byteSel`th byte of `old`,
+    preserving the other bytes. The updated word is returned.
+  }];
+  let assemblyFormat = [{
+    attr-dict $src `,` $scale `->` $old `[` $byteSel `]` `:` type($res)
+  }];
+}
+
+def ROCDL_CvtScaleF32PkFp4Bf16Op :
+    ROCDL_IntrOp<"cvt.scalef32.pk.fp4.bf16", [], [], [Pure], 1>,
+    Arguments<(ins I32:$old, ROCDL_V2BF16Type:$src, F32:$scale, I32:$byteSel)> {
+  let summary = "Scale and convert packed bf16 to packed fp4";
+  let description = [{
+    Scale the two bf16 elements of `src` by the exponent in `scale`, convert
+    them to fp4, and store the packed pair into the `byteSel`th byte of `old`,
+    preserving the other bytes. The updated word is returned.
+  }];
+  let assemblyFormat = [{
+    attr-dict $src `,` $scale `->` $old `[` $byteSel `]` `:` type($res)
+  }];
+}
+
+def ROCDL_CvtScaleF32SrPkFp4F32Op :
+    ROCDL_IntrOp<"cvt.scalef32.sr.pk.fp4.f32", [], [], [Pure], 1>,
+    Arguments<(ins I32:$old, ROCDL_V2F32Type:$src, I32:$seed, F32:$scale,
+                   I32:$byteSel)> {
+  let summary = "Scale and convert f32 to packed fp4 using stochastic rounding";
+  let description = [{
+    Scale `src` by the exponent in `scale`, then convert it to packed fp4 with
+    stochastic rounding using the seed data in `seed`. Store the result into
+    the `byteSel`th byte of `old`, preserving the other bytes.
+  }];
+  let assemblyFormat = [{
+    attr-dict $src `,` $seed `,` $scale `->` $old `[` $byteSel `]` `:` type($res)
+  }];
+}
+
+def ROCDL_CvtScaleF32SrPkFp4F16Op :
+    ROCDL_IntrOp<"cvt.scalef32.sr.pk.fp4.f16", [], [], [Pure], 1>,
+    Arguments<(ins I32:$old, ROCDL_V2F16Type:$src, I32:$seed, F32:$scale,
+                   I32:$byteSel)> {
+  let summary = "Scale and convert f16 to packed fp4 using stochastic rounding";
+  let description = [{
+    Scale `src` by the exponent in `scale`, then convert it to packed fp4 with
+    stochastic rounding using the seed data in `seed`. Store the result into
+    the `byteSel`th byte of `old`, preserving the other bytes.
+  }];
+  let assemblyFormat = [{
+    attr-dict $src `,` $seed `,` $scale `->` $old `[` $byteSel `]` `:` type($res)
+  }];
+}
+
+def ROCDL_CvtScaleF32SrPkFp4Bf16Op :
+    ROCDL_IntrOp<"cvt.scalef32.sr.pk.fp4.bf16", [], [], [Pure], 1>,
+    Arguments<(ins I32:$old, ROCDL_V2BF16Type:$src, I32:$seed, F32:$scale,
+                   I32:$byteSel)> {
+  let summary = "Scale and convert bf16 to packed fp4 using stochastic rounding";
+  let description = [{
+    Scale `src` by the exponent in `scale`, then convert it to packed fp4 with
+    stochastic rounding using the seed data in `seed`. Store the result into
+    the `byteSel`th byte of `old`, preserving the other bytes.
+  }];
+  let assemblyFormat = [{
+    attr-dict $src `,` $seed `,` $scale `->` $old `[` $byteSel `]` `:` type($res)
+  }];
+}
+
+def ROCDL_CvtScaleF32PkF32Fp4Op :
+    ROCDL_IntrOp<"cvt.scalef32.pk.f32.fp4", [], [], [Pure], 1>,
+    Arguments<(ins I32:$src, F32:$scale, I32:$byteSel)> {
+  let summary = "Convert fp4 to packed f32 and scale";
+  let description = [{
+    Convert the pair of fp4 values selected by `byteSel` from `src` to f32,
+    then scale them by the exponent in `scale`. The two results are returned
+    as a 2-element vector.
+  }];
+  // NOTE(review): `res` is not constrained by this def; the translation test
+  // expects <2 x float>, so consider restricting it to ROCDL_V2F32Type.
+  let assemblyFormat = [{
+    attr-dict $src `[` $byteSel `]` `,` $scale `:` type($res)
+  }];
+}
+
+
+def ROCDL_CvtScaleF32PkF16Fp4Op :
+    ROCDL_IntrOp<"cvt.scalef32.pk.f16.fp4", [], [], [Pure], 1>,
+    Arguments<(ins I32:$src, F32:$scale, I32:$byteSel)> {
+  let summary = "Convert fp4 to packed f16 and scale";
+  let description = [{
+    Convert the pair of fp4 values selected by `byteSel` from `src` to f16,
+    then scale them by the exponent in `scale`. The two results are returned
+    as a 2-element vector.
+  }];
+  // NOTE(review): `res` is not constrained by this def; the translation test
+  // expects <2 x half>, so consider restricting it to ROCDL_V2F16Type.
+  let assemblyFormat = [{
+    attr-dict $src `[` $byteSel `]` `,` $scale `:` type($res)
+  }];
+}
+
+def ROCDL_CvtScaleF32PkBf16Fp4Op :
+    ROCDL_IntrOp<"cvt.scalef32.pk.bf16.fp4", [], [], [Pure], 1>,
+    Arguments<(ins I32:$src, F32:$scale, I32:$byteSel)> {
+  let summary = "Convert fp4 to packed bf16 and scale";
+  let description = [{
+    Convert the pair of fp4 values selected by `byteSel` from `src` to bf16,
+    then scale them by the exponent in `scale`. The two results are returned
+    as a 2-element vector.
+  }];
+  // NOTE(review): `res` is not constrained by this def; the translation test
+  // expects <2 x bfloat>, so consider restricting it to ROCDL_V2BF16Type.
+  let assemblyFormat = [{
+    attr-dict $src `[` $byteSel `]` `,` $scale `:` type($res)
+  }];
+}
//===---------------------------------------------------------------------===//
// 8-bit float intrinsics
//===---------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/LLVMIR/rocdl.mlir b/mlir/test/Dialect/LLVMIR/rocdl.mlir
index fbde993891342..5b12b23a9b130 100644
--- a/mlir/test/Dialect/LLVMIR/rocdl.mlir
+++ b/mlir/test/Dialect/LLVMIR/rocdl.mlir
@@ -844,6 +844,25 @@ llvm.func @rocdl_v2f16_v2i16(%source: vector<2xf16>, %source2: vector<2xbf16>, %
llvm.return %source_scaled : vector<2xi16>
}
+// Parse/print round-trip test for the fp4 scale-conversion ops. The encode ops
+// are chained: each feeds its i32 result into the next op's `old` operand.
+// Every op uses byteSel 0 and a scale of 1.0.
+llvm.func @rocdl_4bit_packed_floats(%old: i32, %source0: f32, %source1: f32, %source: vector<2xf32>, %source_half: vector<2xf16>, %source_bfloat: vector<2xbf16>, %stoch: i32) -> i32 {
+ // CHECK-LABEL: @rocdl_4bit_packed_floats
+ // CHECK: rocdl.cvt.scalef32.pk.fp4.f32
+ // CHECK: rocdl.cvt.scalef32.pk.fp4.f16
+ // CHECK: rocdl.cvt.scalef32.pk.fp4.bf16
+ // CHECK: rocdl.cvt.scalef32.sr.pk.fp4.f32
+ // CHECK: rocdl.cvt.scalef32.sr.pk.fp4.f16
+ // CHECK: rocdl.cvt.scalef32.sr.pk.fp4.bf16
+ %c0 = llvm.mlir.constant(0 : i32) : i32
+ %scale = llvm.mlir.constant(1.0 : f32) : f32
+ %pk1 = rocdl.cvt.scalef32.pk.fp4.f32 %source0, %source1, %scale -> %old[%c0] : i32
+ %pk2 = rocdl.cvt.scalef32.pk.fp4.f16 %source_half, %scale -> %pk1[%c0] : i32
+ %pk3 = rocdl.cvt.scalef32.pk.fp4.bf16 %source_bfloat, %scale -> %pk2[%c0] : i32
+ %sr1 = rocdl.cvt.scalef32.sr.pk.fp4.f32 %source, %stoch, %scale -> %pk3[%c0] : i32
+ %sr2 = rocdl.cvt.scalef32.sr.pk.fp4.f16 %source_half, %stoch, %scale -> %sr1[%c0] : i32
+ %sr3 = rocdl.cvt.scalef32.sr.pk.fp4.bf16 %source_bfloat, %stoch, %scale -> %sr2[%c0] : i32
+ llvm.return %sr3 : i32
+}
+
llvm.func @rocdl.s.waitcnt() {
// CHECK-LABEL: rocdl.s.waitcnt
// CHECK: rocdl.s.waitcnt 0
diff --git a/mlir/test/Target/LLVMIR/rocdl.mlir b/mlir/test/Target/LLVMIR/rocdl.mlir
index b37f0da361950..7057dc3b6f66d 100644
--- a/mlir/test/Target/LLVMIR/rocdl.mlir
+++ b/mlir/test/Target/LLVMIR/rocdl.mlir
@@ -1145,6 +1145,25 @@ llvm.func @rocdl_16bit_packed_floats(%sourceA: f32, %sourceB: f32) -> vector<2xf
llvm.return %source : vector<2xf16>
}
+// Translation test: each fp4 scale-conversion op must lower to the matching
+// llvm.amdgcn intrinsic call. The CHECK lines pin LLVM value numbering: %0-%6
+// are the function arguments in declaration order, and each encode call's
+// result feeds the next call's `old` operand. Reordering the arguments or ops
+// therefore requires updating those numbers.
+llvm.func @rocdl_4bit_packed_floats(%old: i32, %source0: f32, %source1: f32, %source: vector<2xf32>, %source_half: vector<2xf16>, %source_bfloat: vector<2xbf16>, %stoch: i32) -> i32 {
+ // CHECK-LABEL: @rocdl_4bit_packed_floats
+ // CHECK: call i32 @llvm.amdgcn.cvt.scalef32.pk.fp4.f32(i32 %0, float %1, float %2, float 1.000000e+00, i32 0)
+ // CHECK: call i32 @llvm.amdgcn.cvt.scalef32.pk.fp4.f16(i32 %8, <2 x half> %4, float 1.000000e+00, i32 0)
+ // CHECK: call i32 @llvm.amdgcn.cvt.scalef32.pk.fp4.bf16(i32 %9, <2 x bfloat> %5, float 1.000000e+00, i32 0)
+ // CHECK: call i32 @llvm.amdgcn.cvt.scalef32.sr.pk.fp4.f32(i32 %10, <2 x float> %3, i32 %6, float 1.000000e+00, i32 0)
+ // CHECK: call i32 @llvm.amdgcn.cvt.scalef32.sr.pk.fp4.f16(i32 %11, <2 x half> %4, i32 %6, float 1.000000e+00, i32 0)
+ // CHECK: call i32 @llvm.amdgcn.cvt.scalef32.sr.pk.fp4.bf16(i32 %12, <2 x bfloat> %5, i32 %6, float 1.000000e+00, i32 0)
+ %c0 = llvm.mlir.constant(0 : i32) : i32
+ %scale = llvm.mlir.constant(1.0 : f32) : f32
+ %pk1 = rocdl.cvt.scalef32.pk.fp4.f32 %source0, %source1, %scale -> %old[%c0] : i32
+ %pk2 = rocdl.cvt.scalef32.pk.fp4.f16 %source_half, %scale -> %pk1[%c0] : i32
+ %pk3 = rocdl.cvt.scalef32.pk.fp4.bf16 %source_bfloat, %scale -> %pk2[%c0] : i32
+ %sr1 = rocdl.cvt.scalef32.sr.pk.fp4.f32 %source, %stoch, %scale -> %pk3[%c0] : i32
+ %sr2 = rocdl.cvt.scalef32.sr.pk.fp4.f16 %source_half, %stoch, %scale -> %sr1[%c0] : i32
+ %sr3 = rocdl.cvt.scalef32.sr.pk.fp4.bf16 %source_bfloat, %stoch, %scale -> %sr2[%c0] : i32
+ llvm.return %sr3 : i32
+}
+
llvm.func @rocdl_atomic_attrs(%ptr: !llvm.ptr<1>, %data: f32) {
// CHECK-LABEL: @rocdl_atomic_attrs
// CHECK: atomicrmw
>From 6fc0d1316cb7366420b7f5eac27fd3d725f1edb5 Mon Sep 17 00:00:00 2001
From: Tim Gymnich <tim at gymni.ch>
Date: Tue, 20 May 2025 05:00:50 +0000
Subject: [PATCH 2/2] update tests
---
mlir/test/Dialect/LLVMIR/rocdl.mlir | 6 ++++++
mlir/test/Target/LLVMIR/rocdl.mlir | 6 ++++++
2 files changed, 12 insertions(+)
diff --git a/mlir/test/Dialect/LLVMIR/rocdl.mlir b/mlir/test/Dialect/LLVMIR/rocdl.mlir
index 5b12b23a9b130..84fa29ee2d8a1 100644
--- a/mlir/test/Dialect/LLVMIR/rocdl.mlir
+++ b/mlir/test/Dialect/LLVMIR/rocdl.mlir
@@ -852,6 +852,9 @@ llvm.func @rocdl_4bit_packed_floats(%old: i32, %source0: f32, %source1: f32, %so
// CHECK: rocdl.cvt.scalef32.sr.pk.fp4.f32
// CHECK: rocdl.cvt.scalef32.sr.pk.fp4.f16
// CHECK: rocdl.cvt.scalef32.sr.pk.fp4.bf16
+ // CHECK: rocdl.cvt.scalef32.pk.f32.fp4
+ // CHECK: rocdl.cvt.scalef32.pk.f16.fp4
+ // CHECK: rocdl.cvt.scalef32.pk.bf16.fp4
%c0 = llvm.mlir.constant(0 : i32) : i32
%scale = llvm.mlir.constant(1.0 : f32) : f32
%pk1 = rocdl.cvt.scalef32.pk.fp4.f32 %source0, %source1, %scale -> %old[%c0] : i32
@@ -860,6 +863,9 @@ llvm.func @rocdl_4bit_packed_floats(%old: i32, %source0: f32, %source1: f32, %so
%sr1 = rocdl.cvt.scalef32.sr.pk.fp4.f32 %source, %stoch, %scale -> %pk3[%c0] : i32
%sr2 = rocdl.cvt.scalef32.sr.pk.fp4.f16 %source_half, %stoch, %scale -> %sr1[%c0] : i32
%sr3 = rocdl.cvt.scalef32.sr.pk.fp4.bf16 %source_bfloat, %stoch, %scale -> %sr2[%c0] : i32
+ %pk4 = rocdl.cvt.scalef32.pk.f32.fp4 %old[%c0], %scale : vector<2xf32>
+ %pk5 = rocdl.cvt.scalef32.pk.f16.fp4 %old[%c0], %scale : vector<2xf16>
+ %pk6 = rocdl.cvt.scalef32.pk.bf16.fp4 %old[%c0], %scale : vector<2xbf16>
llvm.return %sr3 : i32
}
diff --git a/mlir/test/Target/LLVMIR/rocdl.mlir b/mlir/test/Target/LLVMIR/rocdl.mlir
index 7057dc3b6f66d..73862ac112da8 100644
--- a/mlir/test/Target/LLVMIR/rocdl.mlir
+++ b/mlir/test/Target/LLVMIR/rocdl.mlir
@@ -1153,6 +1153,9 @@ llvm.func @rocdl_4bit_packed_floats(%old: i32, %source0: f32, %source1: f32, %so
// CHECK: call i32 @llvm.amdgcn.cvt.scalef32.sr.pk.fp4.f32(i32 %10, <2 x float> %3, i32 %6, float 1.000000e+00, i32 0)
// CHECK: call i32 @llvm.amdgcn.cvt.scalef32.sr.pk.fp4.f16(i32 %11, <2 x half> %4, i32 %6, float 1.000000e+00, i32 0)
// CHECK: call i32 @llvm.amdgcn.cvt.scalef32.sr.pk.fp4.bf16(i32 %12, <2 x bfloat> %5, i32 %6, float 1.000000e+00, i32 0)
+ // CHECK: call <2 x float> @llvm.amdgcn.cvt.scalef32.pk.f32.fp4(i32 %0, float 1.000000e+00, i32 0)
+ // CHECK: call <2 x half> @llvm.amdgcn.cvt.scalef32.pk.f16.fp4(i32 %0, float 1.000000e+00, i32 0)
+ // CHECK: call <2 x bfloat> @llvm.amdgcn.cvt.scalef32.pk.bf16.fp4(i32 %0, float 1.000000e+00, i32 0)
%c0 = llvm.mlir.constant(0 : i32) : i32
%scale = llvm.mlir.constant(1.0 : f32) : f32
%pk1 = rocdl.cvt.scalef32.pk.fp4.f32 %source0, %source1, %scale -> %old[%c0] : i32
@@ -1161,6 +1164,9 @@ llvm.func @rocdl_4bit_packed_floats(%old: i32, %source0: f32, %source1: f32, %so
%sr1 = rocdl.cvt.scalef32.sr.pk.fp4.f32 %source, %stoch, %scale -> %pk3[%c0] : i32
%sr2 = rocdl.cvt.scalef32.sr.pk.fp4.f16 %source_half, %stoch, %scale -> %sr1[%c0] : i32
%sr3 = rocdl.cvt.scalef32.sr.pk.fp4.bf16 %source_bfloat, %stoch, %scale -> %sr2[%c0] : i32
+ %pk4 = rocdl.cvt.scalef32.pk.f32.fp4 %old[%c0], %scale : vector<2xf32>
+ %pk5 = rocdl.cvt.scalef32.pk.f16.fp4 %old[%c0], %scale : vector<2xf16>
+ %pk6 = rocdl.cvt.scalef32.pk.bf16.fp4 %old[%c0], %scale : vector<2xbf16>
llvm.return %sr3 : i32
}
More information about the Mlir-commits
mailing list