[Mlir-commits] [mlir] [MLIR][ROCDL] Add Scale Convert Packed (B)FP8 <-> (B)F16 Support for GFX950 (PR #130300)

Corbin Robeck llvmlistbot at llvm.org
Fri Mar 7 10:06:23 PST 2025


https://github.com/CRobeck updated https://github.com/llvm/llvm-project/pull/130300

>From 08e177d740eca366994bd7228cdd80d9dab66a88 Mon Sep 17 00:00:00 2001
From: Corbin Robeck <corbin.robeck at amd.com>
Date: Tue, 4 Mar 2025 21:01:28 +0000
Subject: [PATCH] add f16 convert instructions

---
 mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td | 180 ++++++++++++++++++-
 mlir/test/Dialect/LLVMIR/rocdl.mlir          |  30 +++-
 mlir/test/Target/LLVMIR/rocdl.mlir           |  34 +++-
 3 files changed, 235 insertions(+), 9 deletions(-)

diff --git a/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td b/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td
index 18fec95f700c4..f194e70ee275b 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td
@@ -652,6 +652,20 @@ def ROCDL_DPPUpdateOp : ROCDL_IntrOp<"update.dpp", [], [0],
   }];
 }
 
+def ROCDL_V2I16Type
+    : FixedVectorOfLengthAndType<[2], [I16]>,
+      BuildableType<"::mlir::VectorType::get({2},$_builder.getI16Type())">;
+
+def ROCDL_V2F16Type
+    : FixedVectorOfLengthAndType<[2], [F16]>,
+      BuildableType<"::mlir::VectorType::get({2},$_builder.getF16Type())">;
+
+def ROCDL_V2BF16Type
+    : FixedVectorOfLengthAndType<[2], [BF16]>,
+      BuildableType<"::mlir::VectorType::get({2},$_builder.getBF16Type())">;
+
+// TODO: The word and byte selectors are immarg in LLVM; update them to be
+// attributes in MLIR rather than SSA operands.
 //===---------------------------------------------------------------------===//
 // 16-bit float intrinsics
 //===---------------------------------------------------------------------===//
@@ -667,10 +681,168 @@ def ROCDL_CvtPkRtz:
   }];
 }
 
+def ROCDL_CvtScaleF32PkFp8F16 :
+    ROCDL_IntrOp<"cvt.scalef32.pk.fp8.f16", [], [], [Pure], 1>,
+    Arguments<(ins ROCDL_V2I16Type:$old, ROCDL_V2F16Type:$src, F32:$scale, I1:$wordSel)> {
+  let summary = "Scale and convert packed f16 to packed fp8";
+  let description = [{
+    Scale both elements of `src` by the exponent in `scale`, then convert them to fp8.
+    Store the two fp8 bytes in the low/high word of `old` selected by `wordSel`, preserving the other word.
+  }];
+  let assemblyFormat = [{
+    attr-dict $src `,` $scale `->` $old `[` $wordSel `]` `:` type($res)
+  }];
+}
+
+def ROCDL_CvtScaleF32PkFp8Bf16 :
+    ROCDL_IntrOp<"cvt.scalef32.pk.fp8.bf16", [], [], [Pure], 1>,
+    Arguments<(ins ROCDL_V2I16Type:$old, ROCDL_V2BF16Type:$src,
+                   F32:$scale, I1:$wordSel)> {
+  let summary = "Scale and convert packed bf16 to packed fp8";
+  let description = [{
+    Scale `src` by the exponent held in `scale` and convert the result to
+    packed fp8. The result is written to the low or high word of `old`
+    chosen by `wordSel`; the other word is preserved.
+  }];
+  let assemblyFormat = "attr-dict $src `,` $scale `->` $old `[` $wordSel `]` `:` type($res)";
+}
+
+
+def ROCDL_CvtScaleF32PkBf8F16 :
+    ROCDL_IntrOp<"cvt.scalef32.pk.bf8.f16", [], [], [Pure], 1>,
+    Arguments<(ins ROCDL_V2I16Type:$old, ROCDL_V2F16Type:$src, F32:$scale, I1:$wordSel)> {
+  let summary = "Scale and convert packed f16 to packed bf8";
+  let description = [{
+    Scale both elements of `src` by the exponent in `scale`, then convert them to bf8.
+    Store the two bf8 bytes in the low/high word of `old` selected by `wordSel`, preserving the other word.
+  }];
+  let assemblyFormat = [{
+    attr-dict $src `,` $scale `->` $old `[` $wordSel `]` `:` type($res)
+  }];
+}
+
+
+def ROCDL_CvtScaleF32PkBf8Bf16 :
+    ROCDL_IntrOp<"cvt.scalef32.pk.bf8.bf16", [], [], [Pure], 1>,
+    Arguments<(ins ROCDL_V2I16Type:$old, ROCDL_V2BF16Type:$src, F32:$scale, I1:$wordSel)> {
+  let summary = "Scale and convert packed bf16 to packed bf8";
+  let description = [{
+    Scale both elements of `src` by the exponent in `scale`, then convert them to bf8.
+    Store the two bf8 bytes in the low/high word of `old` selected by `wordSel`, preserving the other word.
+  }];
+  let assemblyFormat = [{
+    attr-dict $src `,` $scale `->` $old `[` $wordSel `]` `:` type($res)
+  }];
+}
+
+def ROCDL_CvtScaleF32SrFp8F16 :
+    ROCDL_IntrOp<"cvt.scalef32.sr.fp8.f16", [], [], [Pure], 1>,
+    Arguments<(ins I32:$old, F16:$src, I32:$seed, F32:$scale, I32:$byteSel)> {
+  let summary = "Scale and convert f16 to packed fp8 using stochastic rounding";
+  let description = [{
+    Scale `src` by the exponent in `scale`, then convert it to fp8 with
+    stochastic rounding using the seed data in `seed`. Store the result into
+    the `byteSel`th byte of `old`, preserving the other bytes.
+  }];
+  let assemblyFormat = [{
+    attr-dict $src `,` $seed `,` $scale `->` $old `[` $byteSel `]` `:` type($res)
+  }];
+}
+
+def ROCDL_CvtScaleF32SrBf8F16 :
+    ROCDL_IntrOp<"cvt.scalef32.sr.bf8.f16", [], [], [Pure], 1>,
+    Arguments<(ins I32:$old, F16:$src, I32:$seed,
+                   F32:$scale, I32:$byteSel)> {
+  let summary = "Scale and convert f16 to packed bf8 using stochastic rounding";
+  let description = [{
+    Scale `src` by the exponent held in `scale` and convert the value to bf8,
+    rounding stochastically with the seed data in `seed`. The converted byte
+    replaces the `byteSel`th byte of `old`; every other byte is kept.
+  }];
+  let assemblyFormat =
+    "attr-dict $src `,` $seed `,` $scale `->` $old `[` $byteSel `]` `:` type($res)";
+}
+
+def ROCDL_CvtScaleF32SrFp8Bf16 :
+    ROCDL_IntrOp<"cvt.scalef32.sr.fp8.bf16", [], [], [Pure], 1>,
+    Arguments<(ins I32:$old, BF16:$src, I32:$seed, F32:$scale, I32:$byteSel)> {
+  let summary = "Scale and convert bf16 to packed fp8 using stochastic rounding";
+  let description = [{
+    Scale `src` by the exponent in `scale`, then convert it to fp8 with
+    stochastic rounding using the seed data in `seed`. Store the result into
+    the `byteSel`th byte of `old`, preserving the other bytes.
+  }];
+  let assemblyFormat = [{
+    attr-dict $src `,` $seed `,` $scale `->` $old `[` $byteSel `]` `:` type($res)
+  }];
+}
+
+def ROCDL_CvtScaleF32SrBf8Bf16 :
+    ROCDL_IntrOp<"cvt.scalef32.sr.bf8.bf16", [], [], [Pure], 1>,
+    Arguments<(ins I32:$old, BF16:$src, I32:$seed, F32:$scale, I32:$byteSel)> {
+  let summary = "Scale and convert bf16 to packed bf8 using stochastic rounding";
+  let description = [{
+    Scale `src` by the exponent in `scale`, then convert it to bf8 with
+    stochastic rounding using the seed data in `seed`. Store the result into
+    the `byteSel`th byte of `old`, preserving the other bytes.
+  }];
+  let assemblyFormat = [{
+    attr-dict $src `,` $seed `,` $scale `->` $old `[` $byteSel `]` `:` type($res)
+  }];
+}
+
+def ROCDL_CvtScaleF32PkF16Fp8 :
+    ROCDL_IntrOp<"cvt.scalef32.pk.f16.fp8", [], [], [Pure], 1>,
+    Arguments<(ins I32:$src, F32:$scale, I1:$wordSel)> {
+  let summary = "Scale and convert packed fp8 to packed f16";
+  let description = [{ Take the two fp8 values in the low/high word of `src` selected by
+    `wordSel`, scale them by the exponent in `scale`, and convert them to packed f16.
+  }];
+  let assemblyFormat = [{
+    attr-dict $src `[` $wordSel `]` `,` $scale `:` type($res)
+  }];
+}
+
+def ROCDL_CvtScaleF32PkF16Bf8 :
+    ROCDL_IntrOp<"cvt.scalef32.pk.f16.bf8", [], [], [Pure], 1>,
+    Arguments<(ins I32:$src, F32:$scale, I1:$wordSel)> {
+  let summary = "Scale and convert packed bf8 to packed f16";
+  let description = [{ Take the two bf8 values in the low/high word of `src` selected by
+    `wordSel`, scale them by the exponent in `scale`, and convert them to packed f16.
+  }];
+  let assemblyFormat = [{
+    attr-dict $src `[` $wordSel `]` `,` $scale `:` type($res)
+  }];
+}
+
+def ROCDL_CvtScaleF16Fp8 :
+    ROCDL_IntrOp<"cvt.scalef32.f16.fp8", [], [], [Pure], 1>,
+    Arguments<(ins ROCDL_V2F16Type:$old, I32:$src, F32:$scale, I32:$byteSel, I1:$wordSel)> {
+  let summary = "Scale and convert fp8 to f16";
+  let description = [{ Convert the `byteSel`th fp8 byte of `src` to f16, scaling it by the
+    exponent in `scale`, and write it into the low/high half of `old` selected by `wordSel`,
+    preserving the other half. The intrinsic returns the updated `vector<2xf16>`. Note that the
+    assembly format prints `wordSel` after `src` and `byteSel` after `old`.
+  }];
+  let assemblyFormat = "attr-dict $src `[` $wordSel `]` `,` $scale `->` $old `[` $byteSel `]` `:` type($res)";
+}
+
+def ROCDL_CvtScaleF16Bf8 :
+    ROCDL_IntrOp<"cvt.scalef32.f16.bf8", [], [], [Pure], 1>,
+    Arguments<(ins ROCDL_V2F16Type:$old, I32:$src, F32:$scale, I32:$byteSel, I1:$wordSel)> {
+  let summary = "Scale and convert bf8 to f16";
+  let description = [{ Convert the `byteSel`th bf8 byte of `src` to f16, scaling it by the
+    exponent in `scale`, and write it into the low/high half of `old` selected by `wordSel`,
+    preserving the other half. The intrinsic returns the updated `vector<2xf16>`. Note that the
+    assembly format prints `wordSel` after `src` and `byteSel` after `old`.
+  }];
+  let assemblyFormat = "attr-dict $src `[` $wordSel `]` `,` $scale `->` $old `[` $byteSel `]` `:` type($res)";
+}
+
 //===---------------------------------------------------------------------===//
 // 32-bit float intrinsics
 //===---------------------------------------------------------------------===//
-def ROCDL_CvtScalePkF32Fp8 :
+def ROCDL_CvtScaleF32PkF32Fp8 :
     ROCDL_IntrOp<"cvt.scalef32.pk.f32.fp8", [], [], [Pure], 1>,
     Arguments<(ins I32:$src, F32: $scale, I1:$wordSel)> {
   let summary = "Scale and convert packed fp8 to packed f32";
@@ -682,7 +854,7 @@ def ROCDL_CvtScalePkF32Fp8 :
     attr-dict $src `[` $wordSel `]` `,` $scale `:` type($res)
   }];
 }
-def ROCDL_CvtScalePkF32Bf8 :
+def ROCDL_CvtScaleF32PkF32Bf8 :
     ROCDL_IntrOp<"cvt.scalef32.pk.f32.bf8", [], [], [Pure], 1>,
     Arguments<(ins I32:$src, F32: $scale, I1:$wordSel)> {
   let summary = "Scale and convert packed bf8 to packed f32";
@@ -697,10 +869,6 @@ def ROCDL_CvtScalePkF32Bf8 :
 //===---------------------------------------------------------------------===//
 // 8-bit float scale intrinsics
 //===---------------------------------------------------------------------===//
-def ROCDL_V2I16Type : FixedVectorOfLengthAndType<[2], [I16]>,
-                        BuildableType<"::mlir::VectorType::get("
-                          "{2},$_builder.getI16Type())">;
-
 def ROCDL_CvtScaleF32PkFp8F32:
     ROCDL_IntrOp<"cvt.scalef32.pk.fp8.f32", [], [], [Pure], 1>,
     Arguments<(ins ROCDL_V2I16Type:$old, F32:$srcA, F32:$srcB, F32:$scale, I1:$wordSel)> {
diff --git a/mlir/test/Dialect/LLVMIR/rocdl.mlir b/mlir/test/Dialect/LLVMIR/rocdl.mlir
index 05469b64a8083..bc917041998d8 100644
--- a/mlir/test/Dialect/LLVMIR/rocdl.mlir
+++ b/mlir/test/Dialect/LLVMIR/rocdl.mlir
@@ -759,19 +759,27 @@ llvm.func @rocdl.raw.buffer.i32(%rsrc : vector<4xi32>,
   llvm.return
 }
 
-llvm.func @rocdl_8bit_floats(%source: i32, %stoch: i32) -> i32 {
+llvm.func @rocdl_8bit_floats(%source: i32, %source_half: f16, %source_bfloat: bf16, %stoch: i32) -> i32 {
 // CHECK-LABEL: @rocdl_8bit_floats
 // CHECK: rocdl.cvt.f32.bf8
 // CHECK: rocdl.cvt.f32.fp8
 // CHECK: rocdl.cvt.scalef32.f32.bf8
 // CHECK: rocdl.cvt.scalef32.f32.fp8
+// CHECK: rocdl.cvt.scalef32.pk.f16.bf8
+// CHECK: rocdl.cvt.scalef32.pk.f16.fp8
+// CHECK: rocdl.cvt.scalef32.f16.fp8
+// CHECK: rocdl.cvt.scalef32.f16.bf8
 // CHECK: rocdl.cvt.pk.bf8.f32
 // CHECK: rocdl.cvt.pk.fp8.f32
 // CHECK: rocdl.cvt.sr.bf8.f32
 // CHECK: rocdl.cvt.sr.fp8.f32
 // CHECK: rocdl.cvt.scalef32.sr.fp8.f32
+// CHECK: rocdl.cvt.scalef32.sr.fp8.f16
+// CHECK: rocdl.cvt.scalef32.sr.fp8.bf16
 // CHECK: rocdl.cvt.sr.bf8.f32
 // CHECK: rocdl.cvt.scalef32.sr.bf8.f32
+// CHECK: rocdl.cvt.scalef32.sr.bf8.f16
+// CHECK: rocdl.cvt.scalef32.sr.bf8.bf16
 // CHECK: rocdl.cvt.scalef32.pk.f32.fp8
 // CHECK: rocdl.cvt.scalef32.pk.f32.bf8
   %c0 = llvm.mlir.constant(0 : i32) : i32
@@ -783,13 +791,21 @@ llvm.func @rocdl_8bit_floats(%source: i32, %stoch: i32) -> i32 {
   %v2 = rocdl.cvt.f32.fp8 %source[%c0] : f32
   %v1_scaled = rocdl.cvt.scalef32.f32.bf8 %source[%c0], %c4 : f32
   %v2_scaled = rocdl.cvt.scalef32.f32.fp8 %source[%c0], %c4 : f32
+  %v3_scaled = rocdl.cvt.scalef32.pk.f16.bf8 %source[%false], %c4 : vector<2xf16>
+  %v4_scaled = rocdl.cvt.scalef32.pk.f16.fp8 %source[%false], %c4 : vector<2xf16>
+  %v5 = rocdl.cvt.scalef32.f16.fp8 %source[%false], %c4 -> %v3_scaled[%c0] : vector<2xf16>
+  %v6 = rocdl.cvt.scalef32.f16.bf8 %source[%false], %c4 -> %v3_scaled[%c0] : vector<2xf16>
   %source2 = rocdl.cvt.pk.bf8.f32 %v1, %v2 -> %source[%false] : i32
   %source3 = rocdl.cvt.pk.fp8.f32 %v1, %v2 -> %source2[%false] : i32
   %source4 = rocdl.cvt.sr.bf8.f32 %v1, %stoch -> %source3[%c2] : i32
   %source5 = rocdl.cvt.sr.fp8.f32 %v2, %stoch -> %source4[%c3] : i32
   %source5_scaled = rocdl.cvt.scalef32.sr.fp8.f32 %v2, %stoch, %c4 -> %source4[%c3] : i32
+  %source5_scaled_half = rocdl.cvt.scalef32.sr.fp8.f16 %source_half, %stoch, %c4 -> %source4[%c3] : i32
+  %source5_scaled_bfloat = rocdl.cvt.scalef32.sr.fp8.bf16 %source_bfloat, %stoch, %c4 -> %source4[%c3] : i32
   %source6 = rocdl.cvt.sr.bf8.f32 %v1, %stoch -> %source3[%c3] : i32
   %source6_scaled  = rocdl.cvt.scalef32.sr.bf8.f32 %v2, %stoch, %c4 -> %source3[%c3] : i32
+  %source6_scaled_half = rocdl.cvt.scalef32.sr.bf8.f16 %source_half, %stoch, %c4 -> %source3[%c3] : i32
+  %source6_scaled_bfloat = rocdl.cvt.scalef32.sr.bf8.bf16 %source_bfloat, %stoch, %c4 -> %source3[%c3] : i32
   %source7_scaled = rocdl.cvt.scalef32.pk.f32.fp8 %source[%false], %c4 : f32
   %source8_scaled = rocdl.cvt.scalef32.pk.f32.bf8 %source[%false], %c4 : f32
   llvm.return %source5 : i32
@@ -805,6 +821,21 @@ llvm.func @rocdl_8bit_packed_v2i16(%sourceA: f32, %sourceB: f32, %old: vector<2x
   llvm.return %source_scaled : vector<2xi16>
 }

+llvm.func @rocdl_v2f16_v2i16(%source: vector<2xf16>, %source2: vector<2xbf16>, %old: vector<2xi16>) -> vector<2xi16> {
+// CHECK-LABEL: @rocdl_v2f16_v2i16
+// CHECK: rocdl.cvt.scalef32.pk.fp8.f16
+// CHECK: rocdl.cvt.scalef32.pk.fp8.bf16
+// CHECK: rocdl.cvt.scalef32.pk.bf8.f16
+// CHECK: rocdl.cvt.scalef32.pk.bf8.bf16
+  %scale = llvm.mlir.constant(1.0 : f32) : f32
+  %false = llvm.mlir.constant(false) : i1
+  %source_scaled = rocdl.cvt.scalef32.pk.fp8.f16 %source, %scale -> %old[%false] : vector<2xi16>
+  %source2_scaled = rocdl.cvt.scalef32.pk.fp8.bf16 %source2, %scale -> %old[%false] : vector<2xi16>
+  %source3_scaled = rocdl.cvt.scalef32.pk.bf8.f16 %source, %scale -> %old[%false] : vector<2xi16>
+  %source4_scaled = rocdl.cvt.scalef32.pk.bf8.bf16 %source2, %scale -> %old[%false] : vector<2xi16>
+  llvm.return %source_scaled : vector<2xi16>
+}
+
 llvm.func @rocdl.s.waitcnt() {
   // CHECK-LABEL: rocdl.s.waitcnt
   // CHECK: rocdl.s.waitcnt 0
diff --git a/mlir/test/Target/LLVMIR/rocdl.mlir b/mlir/test/Target/LLVMIR/rocdl.mlir
index 981ef4848535c..11f2faa2761ff 100644
--- a/mlir/test/Target/LLVMIR/rocdl.mlir
+++ b/mlir/test/Target/LLVMIR/rocdl.mlir
@@ -1032,22 +1032,29 @@ llvm.func @rocdl.raw.buffer.atomic.cmpswap(%rsrc : vector<4xi32>,
   llvm.return %val : i32
 }
 
-llvm.func @rocdl_8bit_floats(%source: i32, %stoch: i32) -> i32 {
+llvm.func @rocdl_8bit_floats(%source: i32, %source_half: f16, %source_bfloat: bf16, %source_packed: vector<2xf16>, %stoch: i32) -> i32 {
 // CHECK-LABEL: @rocdl_8bit_floats
 // CHECK: call float @llvm.amdgcn.cvt.f32.bf8(i32 %{{.+}}, i32 0)
 // CHECK: call float @llvm.amdgcn.cvt.scalef32.f32.bf8(i32 %{{.+}}, float 1.000000e+00, i32 0)
 // CHECK: call float @llvm.amdgcn.cvt.f32.fp8(i32 %{{.+}}, i32 0)
 // CHECK: call float @llvm.amdgcn.cvt.scalef32.f32.fp8(i32 %{{.+}}, float 1.000000e+00, i32 0)
+// CHECK: call <2 x half> @llvm.amdgcn.cvt.scalef32.pk.f16.bf8(i32 %{{.+}}, float 1.000000e+00, i1 false)
+// CHECK: call <2 x half> @llvm.amdgcn.cvt.scalef32.pk.f16.fp8(i32 %{{.+}}, float 1.000000e+00, i1 false)
+// CHECK: call <2 x half> @llvm.amdgcn.cvt.scalef32.f16.fp8(<2 x half> %{{.+}}, i32 %{{.+}}, float 1.000000e+00, i32 0, i1 false)
+// CHECK: call <2 x half> @llvm.amdgcn.cvt.scalef32.f16.bf8(<2 x half> %{{.+}}, i32 %{{.+}}, float 1.000000e+00, i32 0, i1 false)
 // CHECK: call i32 @llvm.amdgcn.cvt.pk.bf8.f32(float %{{.+}}, float %{{.+}}, i32 %{{.+}}, i1 false)
 // CHECK: call i32 @llvm.amdgcn.cvt.pk.fp8.f32(float %{{.+}}, float %{{.+}}, i32 %{{.+}}, i1 false)
 // CHECK: call i32 @llvm.amdgcn.cvt.sr.bf8.f32(float %{{.+}}, i32 %{{.+}}, i32 %{{.+}}, i32 2)
 // CHECK: call i32 @llvm.amdgcn.cvt.sr.fp8.f32(float %{{.+}}, i32 %{{.+}}, i32 %{{.+}}, i32 3)
 // CHECK: call i32 @llvm.amdgcn.cvt.scalef32.sr.fp8.f32(i32 %{{.+}}, float %{{.+}}, i32 %{{.+}}, float 1.000000e+00, i32 3)
+// CHECK: call i32 @llvm.amdgcn.cvt.scalef32.sr.fp8.f16(i32 %{{.+}}, half %{{.+}}, i32 %{{.+}}, float 1.000000e+00, i32 3)
+// CHECK: call i32 @llvm.amdgcn.cvt.scalef32.sr.fp8.bf16(i32 %{{.+}}, bfloat %{{.+}}, i32 %{{.+}}, float 1.000000e+00, i32 3)
 // CHECK: call i32 @llvm.amdgcn.cvt.sr.bf8.f32(float %{{.+}}, i32 %{{.+}}, i32 %{{.+}}, i32 3)
 // CHECK: call i32 @llvm.amdgcn.cvt.scalef32.sr.bf8.f32(i32 %{{.+}}, float %{{.+}}, i32 %{{.+}}, float 1.000000e+00, i32 3)
+// CHECK: call i32 @llvm.amdgcn.cvt.scalef32.sr.bf8.f16(i32 %{{.+}}, half %{{.+}}, i32 %{{.+}}, float 1.000000e+00, i32 3)
+// CHECK: call i32 @llvm.amdgcn.cvt.scalef32.sr.bf8.bf16(i32 %{{.+}}, bfloat %{{.+}}, i32 %{{.+}}, float 1.000000e+00, i32 3)
 // CHECK: call <2 x float> @llvm.amdgcn.cvt.scalef32.pk.f32.fp8(i32 %{{.+}}, float 1.000000e+00, i1 false)
 // CHECK: call <2 x float> @llvm.amdgcn.cvt.scalef32.pk.f32.bf8(i32 %{{.+}}, float 1.000000e+00, i1 false)
-
   %c0 = llvm.mlir.constant(0 : i32) : i32
   %c2 = llvm.mlir.constant(2 : i32) : i32
   %c3 = llvm.mlir.constant(3 : i32) : i32
@@ -1057,13 +1064,21 @@ llvm.func @rocdl_8bit_floats(%source: i32, %stoch: i32) -> i32 {
   %v1_scaled = rocdl.cvt.scalef32.f32.bf8 %source[%c0], %c4 : f32
   %v2 = rocdl.cvt.f32.fp8 %source[%c0] : f32
   %v2_scaled = rocdl.cvt.scalef32.f32.fp8 %source[%c0], %c4 : f32
+  %v3_scaled = rocdl.cvt.scalef32.pk.f16.bf8 %source[%false], %c4 : vector<2xf16>
+  %v4_scaled = rocdl.cvt.scalef32.pk.f16.fp8 %source[%false], %c4 : vector<2xf16>
+  %v5 = rocdl.cvt.scalef32.f16.fp8 %source[%false], %c4 -> %source_packed[%c0] : vector<2xf16>
+  %v6 = rocdl.cvt.scalef32.f16.bf8 %source[%false], %c4 -> %source_packed[%c0] : vector<2xf16>
   %source2 = rocdl.cvt.pk.bf8.f32 %v1, %v2 -> %source[%false] : i32
   %source3 = rocdl.cvt.pk.fp8.f32 %v1, %v2 -> %source2[%false] : i32
   %source4 = rocdl.cvt.sr.bf8.f32 %v1, %stoch -> %source3[%c2] : i32
   %source5 = rocdl.cvt.sr.fp8.f32 %v2, %stoch -> %source4[%c3] : i32
   %source5_scaled = rocdl.cvt.scalef32.sr.fp8.f32 %v2, %stoch, %c4 -> %source4[%c3] : i32
+  %source5_scaled_half = rocdl.cvt.scalef32.sr.fp8.f16 %source_half, %stoch, %c4 -> %source4[%c3] : i32
+  %source5_scaled_bfloat = rocdl.cvt.scalef32.sr.fp8.bf16 %source_bfloat, %stoch, %c4 -> %source4[%c3] : i32
   %source6 = rocdl.cvt.sr.bf8.f32 %v1, %stoch -> %source3[%c3] : i32
   %source6_scaled  = rocdl.cvt.scalef32.sr.bf8.f32 %v2, %stoch, %c4 -> %source3[%c3] : i32
+  %source6_scaled_half = rocdl.cvt.scalef32.sr.bf8.f16 %source_half, %stoch, %c4 -> %source3[%c3] : i32
+  %source6_scaled_bfloat = rocdl.cvt.scalef32.sr.bf8.bf16 %source_bfloat, %stoch, %c4 -> %source3[%c3] : i32
   %source7_scaled = rocdl.cvt.scalef32.pk.f32.fp8 %source[%false], %c4 : f32
   %source8_scaled = rocdl.cvt.scalef32.pk.f32.bf8 %source[%false], %c4 : f32
   llvm.return %source5 : i32
@@ -1080,6 +1095,21 @@ llvm.func @rocdl_8bit_packed_v2i16(%sourceA: f32, %sourceB: f32, %old: vector<2x
   llvm.return %source_scaled : vector<2xi16>
 }
 
+llvm.func @rocdl_v2f16_v2i16(%source: vector<2xf16>, %source2: vector<2xbf16>, %old: vector<2xi16>) -> vector<2xi16> {
+// CHECK-LABEL: @rocdl_v2f16_v2i16
+// CHECK: call <2 x i16> @llvm.amdgcn.cvt.scalef32.pk.fp8.f16(<2 x i16> %{{.+}}, <2 x half> %{{.+}}, float 1.000000e+00, i1 false)
+// CHECK: call <2 x i16> @llvm.amdgcn.cvt.scalef32.pk.fp8.bf16(<2 x i16> %{{.+}}, <2 x bfloat> %{{.+}}, float 1.000000e+00, i1 false)
+// CHECK: call <2 x i16> @llvm.amdgcn.cvt.scalef32.pk.bf8.f16(<2 x i16> %{{.+}}, <2 x half> %{{.+}}, float 1.000000e+00, i1 false)
+// CHECK: call <2 x i16> @llvm.amdgcn.cvt.scalef32.pk.bf8.bf16(<2 x i16> %{{.+}}, <2 x bfloat> %{{.+}}, float 1.000000e+00, i1 false)
+  %scale = llvm.mlir.constant(1.0 : f32) : f32
+  %false = llvm.mlir.constant(false) : i1
+  %source_scaled = rocdl.cvt.scalef32.pk.fp8.f16 %source, %scale -> %old[%false] : vector<2xi16>
+  %source2_scaled = rocdl.cvt.scalef32.pk.fp8.bf16 %source2, %scale -> %old[%false] : vector<2xi16>
+  %source3_scaled = rocdl.cvt.scalef32.pk.bf8.f16 %source, %scale -> %old[%false] : vector<2xi16>
+  %source4_scaled = rocdl.cvt.scalef32.pk.bf8.bf16 %source2, %scale -> %old[%false] : vector<2xi16>
+  llvm.return %source_scaled : vector<2xi16>
+}
+
 llvm.func @rocdl_16bit_packed_floats(%sourceA: f32, %sourceB: f32) -> vector<2xf16> {
   // CHECK-LABEL: @rocdl_16bit_packed_floats
   // CHECK: call <2 x half> @llvm.amdgcn.cvt.pkrtz(float {{.*}}, float {{.*}})



More information about the Mlir-commits mailing list