[Mlir-commits] [mlir] [AMD][ROCDL][AMDGPU] Support packed conversions fp8/bf8->bf16 and fp8/bf8->fp32 (PR #131850)

Yi Qian llvmlistbot at llvm.org
Tue Mar 18 21:46:15 PDT 2025


https://github.com/yiqian1 updated https://github.com/llvm/llvm-project/pull/131850

>From e7e14efa6df888889b576027305d3a083ca5fd5b Mon Sep 17 00:00:00 2001
From: Yi Qian <yi.qian at amd.com>
Date: Sat, 15 Mar 2025 05:07:21 +0000
Subject: [PATCH 1/2] [AMD][ROCDL][AMDGPU] Support packed conversions
 fp8/bf8->bf16 and fp8/bf8->fp32

Add packed conversions fp8/bf8->bf16 in gfx950 and fp8/bf8->fp32 in gfx942
Update amdgpu.ext_packed_fp8 lowering to use ROCDL CvtPkF32Fp8Op
---
 mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td |  13 +-
 mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td  | 145 ++++++++++++------
 .../AMDGPUToROCDL/AMDGPUToROCDL.cpp           |  10 +-
 .../ArithToAMDGPU/ArithToAMDGPU.cpp           |  32 ++--
 .../AMDGPUToROCDL/8-bit-floats-ocp.mlir       |  32 ++--
 .../AMDGPUToROCDL/8-bit-floats.mlir           |  32 ++--
 .../ArithToAMDGPU/8-bit-floats-ocp.mlir       |  90 +++++------
 .../ArithToAMDGPU/8-bit-floats.mlir           |  94 +++++-------
 mlir/test/Dialect/AMDGPU/ops.mlir             |   6 +-
 mlir/test/Dialect/LLVMIR/rocdl.mlir           |   8 +
 mlir/test/Target/LLVMIR/rocdl.mlir            |   4 +
 11 files changed, 251 insertions(+), 215 deletions(-)

diff --git a/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td b/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
index 3acc383923ca8..3ed6e84d19044 100644
--- a/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
+++ b/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
@@ -85,11 +85,12 @@ def AMDGPU_ExtPackedFp8Op :
     AMDGPU_Op<"ext_packed_fp8", [Pure]>,
     Arguments<(ins AnyTypeOf<[F8E5M2FNUZ, F8E4M3FNUZ, F8E5M2, F8E4M3FN,
         VectorOfLengthAndType<[1, 2, 3, 4], [F8E5M2FNUZ, F8E4M3FNUZ, F8E5M2, F8E4M3FN]>]>:$source,
-      ConfinedAttr<I32Attr, [IntNonNegative, IntMaxValue<3>]>:$index)>,
-    Results<(outs F32:$res)> {
-  let summary = "Extend one of a vector of packed fp8 values to a float";
+      ConfinedAttr<I32Attr, [IntNonNegative, IntMaxValue<1>]>:$wordIndex)>,
+    Results<(outs FixedVectorOfLengthAndType<[2], [F32]>:$res)> {
+  let summary = "Extend a vector of packed fp8 values to two floats";
+
   let description = [{
-    Extend the value `source[index]` to a 32-bit float and return it.
+    Extend the two 8-bit floats in `source[wordIndex]` to two 32-bit floats and return them.
 
     This rather unusual signature arises from the fact that AMD GPUs cannot
     easily work with sub 32-bit quantities, so the compiler intrinsics for
@@ -97,11 +98,11 @@ def AMDGPU_ExtPackedFp8Op :
     this operation) take packed vectors of 4 such floats.
 
     If the passed-in vector has fewer than four elements, or the input is scalar,
-    the remaining values in the <4 x i8> will be filled with with
+    the remaining values in the <4 x i8> will be filled with
     undefined values as needed.
   }];
   let assemblyFormat = [{
-    attr-dict $source `[` $index `]` `:` type($source) `to` type($res)
+    attr-dict $source `[` $wordIndex `]` `:` type($source) `to` type($res)
   }];
 }
 
diff --git a/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td b/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td
index f194e70ee275b..9a433202e3149 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td
@@ -681,26 +681,26 @@ def ROCDL_CvtPkRtz:
   }];
 }
 
-def ROCDL_CvtScaleF32PkFp8F16 :
+def ROCDL_CvtScaleF32PkFp8F16Op :
     ROCDL_IntrOp<"cvt.scalef32.pk.fp8.f16", [], [], [Pure], 1>,
     Arguments<(ins ROCDL_V2I16Type: $old, ROCDL_V2F16Type: $src, F32: $scale, I1:$wordSel)> {
     let summary = "Scale and convert f16 to packed fp8";
     let description = [{
-    Scale `src` by the exponent in `scale` then convert to packed fp8.
-    Store the result in low/high word based on $wordSel, preserving the other word.
+    Scale `src` by the exponent in `scale`, then convert to packed fp8.
+    Store the result in low/high word of `old` based on $wordSel, preserving the other word.
   }];
   let assemblyFormat = [{
     attr-dict $src `,` $scale `->` $old `[` $wordSel `]` `:` type($res)
   }];
 }
 
-def ROCDL_CvtScaleF32PkFp8Bf16 :
+def ROCDL_CvtScaleF32PkFp8Bf16Op :
     ROCDL_IntrOp<"cvt.scalef32.pk.fp8.bf16", [], [], [Pure], 1>,
     Arguments<(ins ROCDL_V2I16Type: $old, ROCDL_V2BF16Type: $src, F32: $scale, I1:$wordSel)> {
     let summary = "Scale and convert packed bf16 to packed fp8";
     let description = [{
-    Scale `src` by the exponent in `scale` then convert to packed fp8.
-    Store the result in low/high word based on $wordSel, preserving the other word.
+    Scale `src` by the exponent in `scale`, then convert to packed fp8.
+    Store the result in low/high word of `old` based on $wordSel, preserving the other word.
   }];
   let assemblyFormat = [{
     attr-dict $src `,` $scale `->` $old `[` $wordSel `]` `:` type($res)
@@ -708,13 +708,13 @@ def ROCDL_CvtScaleF32PkFp8Bf16 :
 }
 
 
-def ROCDL_CvtScaleF32PkBf8F16 :
+def ROCDL_CvtScaleF32PkBf8F16Op :
     ROCDL_IntrOp<"cvt.scalef32.pk.bf8.f16", [], [], [Pure], 1>,
     Arguments<(ins ROCDL_V2I16Type: $old, ROCDL_V2F16Type: $src, F32: $scale, I1:$wordSel)> {
     let summary = "Scale and convert f16 to packed bf8";
     let description = [{
-    Scale `src` by the exponent in `scale` then convert to packed bf8.
-    Store the result in low/high word based on $wordSel, preserving the other word.
+    Scale `src` by the exponent in `scale`, then convert to packed bf8.
+    Store the result in low/high word of `old` based on $wordSel, preserving the other word.
   }];
   let assemblyFormat = [{
     attr-dict $src `,` $scale `->` $old `[` $wordSel `]` `:` type($res)
@@ -722,26 +722,26 @@ def ROCDL_CvtScaleF32PkBf8F16 :
 }
 
 
-def ROCDL_CvtScaleF32PkBf8Bf16 :
+def ROCDL_CvtScaleF32PkBf8Bf16Op :
     ROCDL_IntrOp<"cvt.scalef32.pk.bf8.bf16", [], [], [Pure], 1>,
     Arguments<(ins ROCDL_V2I16Type: $old, ROCDL_V2BF16Type: $src, F32: $scale, I1:$wordSel)> {
     let summary = "Scale and convert bf16 to packed bf8";
     let description = [{
-    Scale `src` by the exponent in `scale` then convert to packed bf8.
-    Store the result in low/high word based on $wordSel, preserving the other word.
+    Scale `src` by the exponent in `scale`, then convert to packed bf8.
+    Store the result in low/high word of `old` based on $wordSel, preserving the other word.
   }];
   let assemblyFormat = [{
     attr-dict $src `,` $scale `->` $old `[` $wordSel `]` `:` type($res)
   }];
 }
 
-def ROCDL_CvtScaleF32SrFp8F16 :
+def ROCDL_CvtScaleF32SrFp8F16Op :
     ROCDL_IntrOp<"cvt.scalef32.sr.fp8.f16", [], [], [Pure], 1>,
     Arguments<(ins I32:$old, F16:$src, I32:$seed, F32: $scale, I32:$byteSel)> {
     let summary = "Scale and convert f16 to packed fp8 using stochastic rounding";
     let description = [{
-    Scale `src` by the exponent in `scale` then convert to packed p8 with stochastic rounding
-    using seed data in `seed`. store into the `byteSel`th byte of `old`, preserving the others.
+    Scale `src` by the exponent in `scale`, then convert to packed fp8 with stochastic rounding
+    using seed data in `seed`. Store into the `byteSel`th byte of `old`, preserving the others.
 
   }];
   let assemblyFormat = [{
@@ -749,13 +749,13 @@ def ROCDL_CvtScaleF32SrFp8F16 :
   }];
 }
 
-def ROCDL_CvtScaleF32SrBf8F16 :
+def ROCDL_CvtScaleF32SrBf8F16Op :
     ROCDL_IntrOp<"cvt.scalef32.sr.bf8.f16", [], [], [Pure], 1>,
     Arguments<(ins I32:$old, F16:$src, I32:$seed, F32: $scale, I32:$byteSel)> {
     let summary = "Scale and convert f16 to packed bf8 using stochastic rounding";
     let description = [{
-    Scale `src` by the exponent in `scale` then convert to packed bf8 with stochastic rounding
-    using seed data in `seed`. store into the `byteSel`th byte of `old`, preserving the others.
+    Scale `src` by the exponent in `scale`, then convert to packed bf8 with stochastic rounding
+    using seed data in `seed`. Store into the `byteSel`th byte of `old`, preserving the others.
 
   }];
   let assemblyFormat = [{
@@ -763,13 +763,13 @@ def ROCDL_CvtScaleF32SrBf8F16 :
   }];
 }
 
-def ROCDL_CvtScaleF32SrFp8Bf16 :
+def ROCDL_CvtScaleF32SrFp8Bf16Op :
     ROCDL_IntrOp<"cvt.scalef32.sr.fp8.bf16", [], [], [Pure], 1>,
     Arguments<(ins I32:$old, BF16:$src, I32:$seed, F32: $scale, I32:$byteSel)> {
     let summary = "Scale and convert packed bf16 to packed fp8 using stochastic rounding";
     let description = [{
-    Scale `src` by the exponent in `scale` then convert to packed fp8 with stochastic rounding
-    using seed data in `seed`. store into the `byteSel`th byte of `old`, preserving the others.
+    Scale `src` by the exponent in `scale`, then convert to packed fp8 with stochastic rounding
+    using seed data in `seed`. Store into the `byteSel`th byte of `old`, preserving the others.
 
   }];
   let assemblyFormat = [{
@@ -777,13 +777,13 @@ def ROCDL_CvtScaleF32SrFp8Bf16 :
   }];
 }
 
-def ROCDL_CvtScaleF32SrBf8Bf16:
+def ROCDL_CvtScaleF32SrBf8Bf16Op :
     ROCDL_IntrOp<"cvt.scalef32.sr.bf8.bf16", [], [], [Pure], 1>,
     Arguments<(ins I32:$old, BF16:$src, I32:$seed, F32: $scale, I32:$byteSel)> {
     let summary = "Scale and convert bf16 to packed fp8 using stochastic rounding";
     let description = [{
-    Scale `src` by the exponent in `scale` then convert to packed p8 with stochastic rounding
-    using seed data in `seed`. store into the `byteSel`th byte of `old`, preserving the others.
+    Scale `src` by the exponent in `scale`, then convert to packed bf8 with stochastic rounding
+    using seed data in `seed`. Store into the `byteSel`th byte of `old`, preserving the others.
 
   }];
   let assemblyFormat = [{
@@ -791,48 +791,74 @@ def ROCDL_CvtScaleF32SrBf8Bf16:
   }];
 }
 
-def ROCDL_CvtScaleF32PkF16Fp8 :
+def ROCDL_CvtScaleF32PkF16Fp8Op :
     ROCDL_IntrOp<"cvt.scalef32.pk.f16.fp8", [], [], [Pure], 1>,
     Arguments<(ins I32:$src, F32: $scale, I1:$wordSel)> {
-    let summary = "Scale and convert fp8 to packed f16";
-    let description = [{ Scale `src` based on $wordSel by the exponent in `scale` 
-    then convert to packed f16.
+    let summary = "Convert fp8 to packed f16 and scale";
+    let description = [{ Convert `src` based on $wordSel to packed f16, then scale
+    the packed values by the exponent in `scale`.
   }];
   let assemblyFormat = [{
     attr-dict $src `[` $wordSel `]` `,` $scale `:` type($res)
   }];
 }
 
-def ROCDL_CvtScaleF32PkF16Bf8 :
+def ROCDL_CvtScaleF32PkF16Bf8Op :
     ROCDL_IntrOp<"cvt.scalef32.pk.f16.bf8", [], [], [Pure], 1>,
     Arguments<(ins I32:$src, F32: $scale, I1:$wordSel)> {
-    let summary = "Scale and convert bf8 to packed f16";
-    let description = [{ Scale `src` based on $wordSel by the exponent in `scale`
-    then convert to packed f16.
+    let summary = "Convert bf8 to packed f16 and scale";
+    let description = [{ Convert `src` based on $wordSel to packed f16, then scale
+    the packed values by the exponent in `scale`.
   }];
   let assemblyFormat = [{
     attr-dict $src `[` $wordSel `]` `,` $scale `:` type($res)
   }];
 }
 
-def ROCDL_CvtScaleF16Fp8 :
+def ROCDL_CvtScaleF32PkBf16Fp8Op :
+    ROCDL_IntrOp<"cvt.scalef32.pk.bf16.fp8", [], [], [Pure], 1>,
+    Arguments<(ins I32:$src, F32: $scale, I1:$wordSel)> {
+    let summary = "Convert fp8 to packed bf16 and scale";
+    let description = [{ Convert `src` based on $wordSel to packed bf16, then scale
+    the packed values by the exponent in `scale`.
+  }];
+  let assemblyFormat = [{
+    attr-dict $src `[` $wordSel `]` `,` $scale `:` type($res)
+  }];
+}
+
+def ROCDL_CvtScaleF32PkBf16Bf8Op :
+    ROCDL_IntrOp<"cvt.scalef32.pk.bf16.bf8", [], [], [Pure], 1>,
+    Arguments<(ins I32:$src, F32: $scale, I1:$wordSel)> {
+    let summary = "Convert bf8 to packed bf16 and scale";
+    let description = [{ Convert `src` based on $wordSel to packed bf16, then scale
+    the packed values by the exponent in `scale`.
+  }];
+  let assemblyFormat = [{
+    attr-dict $src `[` $wordSel `]` `,` $scale `:` type($res)
+  }];
+}
+
+def ROCDL_CvtScaleF16Fp8Op :
     ROCDL_IntrOp<"cvt.scalef32.f16.fp8", [], [], [Pure], 1>,
     Arguments<(ins ROCDL_V2F16Type:$old, I32:$src, F32: $scale, I32:$byteSel, I1:$wordSel)> {
     let summary = "Scale and convert fp8 to f16";
-    let description = [{ Scale `src` based on $wordSel by the exponent in `scale`
-    then convert to f16 store into the `byteSel`th byte of `old`, preserving the others.
+    let description = [{ Convert `src` based on $wordSel to f16, then scale the value
+    by the exponent in `scale`. Store the result into the `byteSel`th byte of `old`,
+    preserving the others.
   }];
   let assemblyFormat = [{
     attr-dict $src `[` $wordSel `]` `,` $scale `->` $old `[` $byteSel `]` `:` type($res)
   }];
 }
 
-def ROCDL_CvtScaleF16Bf8 :
+def ROCDL_CvtScaleF16Bf8Op :
     ROCDL_IntrOp<"cvt.scalef32.f16.bf8", [], [], [Pure], 1>,
     Arguments<(ins ROCDL_V2F16Type:$old, I32:$src, F32: $scale, I32:$byteSel, I1:$wordSel)> {
     let summary = "Scale and convert fp8 to f16";
-    let description = [{ Scale `src` based on $wordSel by the exponent in `scale`
-    then convert to f16 store into the `byteSel`th byte of `old`, preserving the others.
+    let description = [{ Convert `src` based on $wordSel to f16, then scale the value
+    by the exponent in `scale`. Store the result into the `byteSel`th byte of `old`,
+    preserving the others.
   }];
   let assemblyFormat = [{
     attr-dict $src `[` $wordSel `]` `,` $scale `->` $old `[` $byteSel `]` `:` type($res)
@@ -842,25 +868,25 @@ def ROCDL_CvtScaleF16Bf8 :
 //===---------------------------------------------------------------------===//
 // 32-bit float intrinsics
 //===---------------------------------------------------------------------===//
-def ROCDL_CvtScale32PkF32Fp8 :
+def ROCDL_CvtScaleF32PkF32Fp8Op :
     ROCDL_IntrOp<"cvt.scalef32.pk.f32.fp8", [], [], [Pure], 1>,
     Arguments<(ins I32:$src, F32: $scale, I1:$wordSel)> {
   let summary = "Scale and convert packed fp8 to packed f32";
   let description = [{
-    Scale `src` by the exponent in `scale` then convert to packed fp32.
-    Store the result in low/high word based on $wordSel, preserving the other word.
+    Convert `src` based on $wordSel to packed fp32, then scale the packed values by
+    the exponent in `scale`. Store the result in a vector.
   }];
   let assemblyFormat = [{
     attr-dict $src `[` $wordSel `]` `,` $scale `:` type($res)
   }];
 }
-def ROCDL_CvtScale32PkF32Bf8 :
+def ROCDL_CvtScaleF32PkF32Bf8Op :
     ROCDL_IntrOp<"cvt.scalef32.pk.f32.bf8", [], [], [Pure], 1>,
     Arguments<(ins I32:$src, F32: $scale, I1:$wordSel)> {
   let summary = "Scale and convert packed bf8 to packed f32";
   let description = [{
-    Scale `src` by the exponent in `scale` then convert to packed fp32.
-    Store the result in low/high word based on $wordSel, preserving the other word.
+    Convert `src` based on $wordSel to packed fp32, then scale the packed values by
+    the exponent in `scale`. Store the result in a vector.
   }];
   let assemblyFormat = [{
     attr-dict $src `[` $wordSel `]` `,` $scale `:` type($res)
@@ -869,7 +895,7 @@ def ROCDL_CvtScale32PkF32Bf8 :
 //===---------------------------------------------------------------------===//
 // 8-bit float scale intrinsics
 //===---------------------------------------------------------------------===//
-def ROCDL_CvtScaleF32PkFp8F32:
+def ROCDL_CvtScaleF32PkFp8F32Op :
     ROCDL_IntrOp<"cvt.scalef32.pk.fp8.f32", [], [], [Pure], 1>,
     Arguments<(ins ROCDL_V2I16Type:$old, F32:$srcA, F32:$srcB, F32:$scale, I1:$wordSel)> {
   let summary = "Scale and convert two f32's to packed fp8";
@@ -882,7 +908,7 @@ def ROCDL_CvtScaleF32PkFp8F32:
   }];
 }
 
-def ROCDL_CvtScaleF32PkBf8F32:
+def ROCDL_CvtScaleF32PkBf8F32Op :
     ROCDL_IntrOp<"cvt.scalef32.pk.bf8.f32", [], [], [Pure], 1>,
     Arguments<(ins ROCDL_V2I16Type:$old, F32:$srcA, F32:$srcB, F32: $scale, I1:$wordSel)> {
   let summary = "Scale and convert two f32's to packed bf8";
@@ -895,7 +921,7 @@ def ROCDL_CvtScaleF32PkBf8F32:
   }];
 }
 
-def ROCDL_CvtScaleF32SrFp8F32:
+def ROCDL_CvtScaleF32SrFp8F32Op :
     ROCDL_IntrOp<"cvt.scalef32.sr.fp8.f32", [], [], [Pure], 1>,
     Arguments<(ins I32:$old, F32:$src, I32:$seed, F32: $scale, I32:$byteSel)> {
     let summary = "Scale and convert f32 to fp8 using stochastic rounding";
@@ -909,7 +935,7 @@ def ROCDL_CvtScaleF32SrFp8F32:
 }
 
 
-def ROCDL_CvtScaleF32SrBf8F32:
+def ROCDL_CvtScaleF32SrBf8F32Op :
     ROCDL_IntrOp<"cvt.scalef32.sr.bf8.f32", [], [], [Pure], 1>,
     Arguments<(ins I32:$old, F32:$src, I32:$seed, F32: $scale, I32:$byteSel)> {
     let summary = "Scale and convert f32 to bf8 using stochastic rounding";
@@ -978,6 +1004,29 @@ def ROCDL_CvtScaleF32Fp8Op :
   }];
 }
 
+def ROCDL_CvtPkF32Fp8Op :
+    ROCDL_IntrOp<"cvt.pk.f32.fp8", [], [], [Pure], 1>,
+    Arguments<(ins I32:$src, I1:$wordSel)> {
+  let summary = "Convert packed fp8 to packed f32";
+  let description = [{
+    Convert `src` based on $wordSel to packed fp32.
+  }];
+  let assemblyFormat = [{
+    attr-dict $src `[` $wordSel `]` `:` type($res)
+  }];
+}
+
+def ROCDL_CvtPkF32Bf8Op :
+    ROCDL_IntrOp<"cvt.pk.f32.bf8", [], [], [Pure], 1>,
+    Arguments<(ins I32:$src, I1:$wordSel)> {
+  let summary = "Convert packed bf8 to packed f32";
+  let description = [{
+    Convert `src` based on $wordSel to packed fp32.
+  }];
+  let assemblyFormat = [{
+    attr-dict $src `[` $wordSel `]` `:` type($res)
+  }];
+}
 
 def ROCDL_CvtPkBf8F32Op :
     ROCDL_IntrOp<"cvt.pk.bf8.f32", [], [], [Pure], 1>,
diff --git a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
index 949424db7c4d6..768d21384412d 100644
--- a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
+++ b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
@@ -977,13 +977,13 @@ LogicalResult ExtPackedFp8OpLowering::matchAndRewrite(
     source = longVec;
   }
   Value i32Source = rewriter.create<LLVM::BitcastOp>(loc, i32, source);
-  Value wordSel = createI32Constant(rewriter, loc, op.getIndex());
+  Value wordSel = createI1Constant(rewriter, loc, op.getWordIndex());
   if (typeIsExpectedBf8ForChipset(chipset, sourceElemType)) {
-    rewriter.replaceOpWithNewOp<ROCDL::CvtF32Bf8Op>(op, f32, i32Source,
-                                                    wordSel);
+    rewriter.replaceOpWithNewOp<ROCDL::CvtPkF32Bf8Op>(op, f32, i32Source,
+                                                      wordSel);
   } else if (typeIsExpectedFp8ForChipset(chipset, sourceElemType)) {
-    rewriter.replaceOpWithNewOp<ROCDL::CvtF32Fp8Op>(op, f32, i32Source,
-                                                    wordSel);
+    rewriter.replaceOpWithNewOp<ROCDL::CvtPkF32Fp8Op>(op, f32, i32Source,
+                                                      wordSel);
   }
   return success();
 }
diff --git a/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp b/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp
index 27be54728c1a1..f9b685d1e90f6 100644
--- a/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp
+++ b/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp
@@ -83,14 +83,15 @@ static bool isSupportedF8(Type elementType, Chipset chipset) {
   return false;
 }
 
-static Value castF32To(Type elementType, Value f32, Location loc,
+static Value castF32To(Type desType, Value f32, Location loc,
                        PatternRewriter &rewriter) {
+  Type elementType = getElementTypeOrSelf(desType);
   if (elementType.isF32())
     return f32;
   if (elementType.getIntOrFloatBitWidth() < 32)
-    return rewriter.create<arith::TruncFOp>(loc, elementType, f32);
+    return rewriter.create<arith::TruncFOp>(loc, desType, f32);
   if (elementType.getIntOrFloatBitWidth() > 32)
-    return rewriter.create<arith::ExtFOp>(loc, elementType, f32);
+    return rewriter.create<arith::ExtFOp>(loc, desType, f32);
   llvm_unreachable("The only 32-bit float type is f32");
 }
 
@@ -110,10 +111,12 @@ ExtFOnFloat8RewritePattern::matchAndRewrite(arith::ExtFOp op,
   Location loc = op.getLoc();
   Value in = op.getIn();
   Type outElemType = getElementTypeOrSelf(op.getOut().getType());
+  VectorType extResType = VectorType::get(2, rewriter.getF32Type());
   if (!inVecType) {
-    Value asFloat = rewriter.create<amdgpu::ExtPackedFp8Op>(
-        loc, rewriter.getF32Type(), in, 0);
-    Value result = castF32To(outElemType, asFloat, loc, rewriter);
+    Value asFloats =
+        rewriter.create<amdgpu::ExtPackedFp8Op>(loc, extResType, in, 0);
+    Value resFloat = rewriter.create<vector::ExtractOp>(loc, asFloats, 0);
+    Value result = castF32To(outElemType, resFloat, loc, rewriter);
     rewriter.replaceOp(op, result);
     return success();
   }
@@ -150,11 +153,18 @@ ExtFOnFloat8RewritePattern::matchAndRewrite(arith::ExtFOp op,
     int64_t elemsThisOp = std::min(numElements, i + 4) - i;
     Value inSlice = rewriter.create<vector::ExtractStridedSliceOp>(
         loc, in, i, elemsThisOp, 1);
-    for (int64_t j = 0; j < elemsThisOp; ++j) {
-      Value asFloat = rewriter.create<amdgpu::ExtPackedFp8Op>(
-          loc, rewriter.getF32Type(), inSlice, j);
-      Value asType = castF32To(outElemType, asFloat, loc, rewriter);
-      result = rewriter.create<vector::InsertOp>(loc, asType, result, i + j);
+    for (int64_t j = 0; j < elemsThisOp; j += 2) {
+      Value asFloats = rewriter.create<amdgpu::ExtPackedFp8Op>(loc, extResType,
+                                                               inSlice, j / 2);
+      Type desType = VectorType::get(2, outElemType);
+      Value asType = castF32To(desType, asFloats, loc, rewriter);
+      if (i + j + 1 < numElements)
+        result = rewriter.create<vector::InsertStridedSliceOp>(
+            loc, asType, result, i + j, 1);
+      else {
+        asType = rewriter.create<vector::ExtractOp>(loc, asType, 0);
+        result = rewriter.create<vector::InsertOp>(loc, asType, result, i + j);
+      }
     }
   }
 
diff --git a/mlir/test/Conversion/AMDGPUToROCDL/8-bit-floats-ocp.mlir b/mlir/test/Conversion/AMDGPUToROCDL/8-bit-floats-ocp.mlir
index 70775a603e54d..0fb03ff13b558 100644
--- a/mlir/test/Conversion/AMDGPUToROCDL/8-bit-floats-ocp.mlir
+++ b/mlir/test/Conversion/AMDGPUToROCDL/8-bit-floats-ocp.mlir
@@ -7,12 +7,12 @@
 // CHECK-DAG: [[C0_1:%.+]] = llvm.mlir.constant(0 : i32) : i32
 // CHECK: [[VEC:%.+]] = llvm.insertelement [[V]], [[UNDEF]]{{\[}}[[C0_1]] : i32] : vector<4xi8>
 // CHECK: [[CAST:%.+]] = llvm.bitcast [[VEC]] : vector<4xi8> to i32
-// CHECK: [[C0_2:%.+]] = llvm.mlir.constant(0 : i32) : i32
-// CHECK: [[EXT:%.+]] = rocdl.cvt.f32.bf8 [[CAST]]{{\[}}[[C0_2]]] : f32
+// CHECK: [[C0_2:%.+]] = llvm.mlir.constant(false) : i1
+// CHECK: [[EXT:%.+]] = rocdl.cvt.pk.f32.bf8 [[CAST]]{{\[}}[[C0_2]]] : vector<2xf32>
 // CHECK: return [[EXT]]
-func.func @ext_scalar(%v: f8E5M2) -> f32 {
-  %ret = amdgpu.ext_packed_fp8 %v[0] : f8E5M2 to f32
-  func.return %ret : f32
+func.func @ext_scalar(%v: f8E5M2) -> vector<2xf32> {
+  %ret = amdgpu.ext_packed_fp8 %v[0] : f8E5M2 to vector<2xf32>
+  func.return %ret : vector<2xf32>
 }
 
 // CHECK-LABEL: func @ext_short_vec
@@ -25,24 +25,24 @@ func.func @ext_scalar(%v: f8E5M2) -> f32 {
 // CHECK: [[ELEM_1:%.+]] = llvm.extractelement [[V]]{{\[}}[[C1_1]] : i32] : vector<2xi8>
 // CHECK: [[VEC_1:%.+]] = llvm.insertelement [[ELEM_1]], [[VEC_0]]{{\[}}[[C1_1]] : i32] : vector<4xi8>
 // CHECK: [[CAST:%.+]] = llvm.bitcast [[VEC_1]] : vector<4xi8> to i32
-// CHECK: [[C1_2:%.+]] = llvm.mlir.constant(1 : i32) : i32
-// CHECK: [[EXT:%.+]] = rocdl.cvt.f32.fp8 [[CAST]]{{\[}}[[C1_2]]] : f32
+// CHECK: [[C1_2:%.+]] = llvm.mlir.constant(false) : i1
+// CHECK: [[EXT:%.+]] = rocdl.cvt.pk.f32.fp8 [[CAST]]{{\[}}[[C1_2]]] : vector<2xf32>
 // CHECK: return [[EXT]]
-func.func @ext_short_vec(%v: vector<2xf8E4M3FN>) -> f32 {
-  %ret = amdgpu.ext_packed_fp8 %v[1] : vector<2xf8E4M3FN> to f32
-  func.return %ret : f32
+func.func @ext_short_vec(%v: vector<2xf8E4M3FN>) -> vector<2xf32> {
+  %ret = amdgpu.ext_packed_fp8 %v[0] : vector<2xf8E4M3FN> to vector<2xf32>
+  func.return %ret : vector<2xf32>
 }
 
 // CHECK-LABEL: func @ext_full_vec(
 // CHECK: [[V:%.+]] = builtin.unrealized_conversion_cast %{{.+}} : vector<4xf8E4M3FN> to vector<4xi8>
 // CHECK: [[CAST:%.+]] = llvm.bitcast [[V]] : vector<4xi8> to i32
-// CHECK: [[C3:%.+]] = llvm.mlir.constant(3 : i32) : i32
-// CHECK: [[EXT:%.+]] = rocdl.cvt.f32.fp8 [[CAST]]{{\[}}[[C3]]] : f32
-// CHECK: return [[EXT]] : f32
+// CHECK: [[C3:%.+]] = llvm.mlir.constant(true) : i1
+// CHECK: [[EXT:%.+]] = rocdl.cvt.pk.f32.fp8 [[CAST]]{{\[}}[[C3]]] : vector<2xf32>
+// CHECK: return [[EXT]] : vector<2xf32>
 
-func.func @ext_full_vec(%v: vector<4xf8E4M3FN>) -> f32 {
-  %ret = amdgpu.ext_packed_fp8 %v[3] : vector<4xf8E4M3FN> to f32
-  func.return %ret : f32
+func.func @ext_full_vec(%v: vector<4xf8E4M3FN>) -> vector<2xf32> {
+  %ret = amdgpu.ext_packed_fp8 %v[1] : vector<4xf8E4M3FN> to vector<2xf32>
+  func.return %ret : vector<2xf32>
 }
 
 // CHECK-LABEL: func @packed_trunc
diff --git a/mlir/test/Conversion/AMDGPUToROCDL/8-bit-floats.mlir b/mlir/test/Conversion/AMDGPUToROCDL/8-bit-floats.mlir
index a313aaffdf5cc..0a4a960d59ce8 100644
--- a/mlir/test/Conversion/AMDGPUToROCDL/8-bit-floats.mlir
+++ b/mlir/test/Conversion/AMDGPUToROCDL/8-bit-floats.mlir
@@ -6,12 +6,12 @@
 // CHECK-DAG: [[C0_1:%.+]] = llvm.mlir.constant(0 : i32) : i32
 // CHECK: [[VEC:%.+]] = llvm.insertelement [[V]], [[UNDEF]]{{\[}}[[C0_1]] : i32] : vector<4xi8>
 // CHECK: [[CAST:%.+]] = llvm.bitcast [[VEC]] : vector<4xi8> to i32
-// CHECK: [[C0_2:%.+]] = llvm.mlir.constant(0 : i32) : i32
-// CHECK: [[EXT:%.+]] = rocdl.cvt.f32.bf8 [[CAST]]{{\[}}[[C0_2]]] : f32
+// CHECK: [[C0_2:%.+]] = llvm.mlir.constant(false) : i1
+// CHECK: [[EXT:%.+]] = rocdl.cvt.pk.f32.bf8 [[CAST]]{{\[}}[[C0_2]]] : vector<2xf32>
 // CHECK: return [[EXT]]
-func.func @ext_scalar(%v: f8E5M2FNUZ) -> f32 {
-  %ret = amdgpu.ext_packed_fp8 %v[0] : f8E5M2FNUZ to f32
-  func.return %ret : f32
+func.func @ext_scalar(%v: f8E5M2FNUZ) -> vector<2xf32> {
+  %ret = amdgpu.ext_packed_fp8 %v[0] : f8E5M2FNUZ to vector<2xf32>
+  func.return %ret : vector<2xf32>
 }
 
 // CHECK-LABEL: func @ext_short_vec
@@ -24,24 +24,24 @@ func.func @ext_scalar(%v: f8E5M2FNUZ) -> f32 {
 // CHECK: [[ELEM_1:%.+]] = llvm.extractelement [[V]]{{\[}}[[C1_1]] : i32] : vector<2xi8>
 // CHECK: [[VEC_1:%.+]] = llvm.insertelement [[ELEM_1]], [[VEC_0]]{{\[}}[[C1_1]] : i32] : vector<4xi8>
 // CHECK: [[CAST:%.+]] = llvm.bitcast [[VEC_1]] : vector<4xi8> to i32
-// CHECK: [[C1_2:%.+]] = llvm.mlir.constant(1 : i32) : i32
-// CHECK: [[EXT:%.+]] = rocdl.cvt.f32.fp8 [[CAST]]{{\[}}[[C1_2]]] : f32
+// CHECK: [[C1_2:%.+]] = llvm.mlir.constant(false) : i1
+// CHECK: [[EXT:%.+]] = rocdl.cvt.pk.f32.fp8 [[CAST]]{{\[}}[[C1_2]]] : vector<2xf32>
 // CHECK: return [[EXT]]
-func.func @ext_short_vec(%v: vector<2xf8E4M3FNUZ>) -> f32 {
-  %ret = amdgpu.ext_packed_fp8 %v[1] : vector<2xf8E4M3FNUZ> to f32
-  func.return %ret : f32
+func.func @ext_short_vec(%v: vector<2xf8E4M3FNUZ>) -> vector<2xf32> {
+  %ret = amdgpu.ext_packed_fp8 %v[0] : vector<2xf8E4M3FNUZ> to vector<2xf32>
+  func.return %ret : vector<2xf32>
 }
 
 // CHECK-LABEL: func @ext_full_vec(
 // CHECK: [[V:%.+]] = builtin.unrealized_conversion_cast %{{.+}} : vector<4xf8E4M3FNUZ> to vector<4xi8>
 // CHECK: [[CAST:%.+]] = llvm.bitcast [[V]] : vector<4xi8> to i32
-// CHECK: [[C3:%.+]] = llvm.mlir.constant(3 : i32) : i32
-// CHECK: [[EXT:%.+]] = rocdl.cvt.f32.fp8 [[CAST]]{{\[}}[[C3]]] : f32
-// CHECK: return [[EXT]] : f32
+// CHECK: [[C3:%.+]] = llvm.mlir.constant(true) : i1
+// CHECK: [[EXT:%.+]] = rocdl.cvt.pk.f32.fp8 [[CAST]]{{\[}}[[C3]]] : vector<2xf32>
+// CHECK: return [[EXT]] : vector<2xf32>
 
-func.func @ext_full_vec(%v: vector<4xf8E4M3FNUZ>) -> f32 {
-  %ret = amdgpu.ext_packed_fp8 %v[3] : vector<4xf8E4M3FNUZ> to f32
-  func.return %ret : f32
+func.func @ext_full_vec(%v: vector<4xf8E4M3FNUZ>) -> vector<2xf32> {
+  %ret = amdgpu.ext_packed_fp8 %v[1] : vector<4xf8E4M3FNUZ> to vector<2xf32>
+  func.return %ret : vector<2xf32>
 }
 
 // CHECK-LABEL: func @packed_trunc
diff --git a/mlir/test/Conversion/ArithToAMDGPU/8-bit-floats-ocp.mlir b/mlir/test/Conversion/ArithToAMDGPU/8-bit-floats-ocp.mlir
index 0e7f58c9e6749..b75b69c1b5d27 100644
--- a/mlir/test/Conversion/ArithToAMDGPU/8-bit-floats-ocp.mlir
+++ b/mlir/test/Conversion/ArithToAMDGPU/8-bit-floats-ocp.mlir
@@ -1,10 +1,11 @@
 // RUN: mlir-opt --split-input-file %s -convert-arith-to-amdgpu="chipset=gfx950" | FileCheck %s
 // RUN: mlir-opt --split-input-file %s -convert-arith-to-amdgpu="chipset=gfx1200" | FileCheck %s
-
+
 // CHECK-LABEL: func.func @scalar_ext
 // CHECK-SAME: ([[V:%.+]]: f8E5M2)
-// CHECK: [[FLOAT:%.+]] = amdgpu.ext_packed_fp8 [[V]][0] : f8E5M2 to f32
-// CHECK: [[W:%.+]] = arith.truncf [[FLOAT]] : f32 to f16
+// CHECK: [[FLOAT:%.+]] = amdgpu.ext_packed_fp8 [[V]][0] : f8E5M2 to vector<2xf32>
+// CHECK: [[EXT:%.+]] = vector.extract [[FLOAT]][0] : f32 from vector<2xf32>
+// CHECK: [[W:%.+]] = arith.truncf [[EXT]] : f32 to f16
 // CHECK: return [[W]]
 func.func @scalar_ext(%v: f8E5M2) -> f16 {
   %w = arith.extf %v : f8E5M2 to f16
@@ -17,14 +18,9 @@ func.func @scalar_ext(%v: f8E5M2) -> f16 {
 
 // CHECK-LABEL: func.func @vector_ext_short
 // CHECK-SAME: ([[V:%.+]]: vector<2xf8E5M2>)
-// CHECK-DAG: [[ZEROES:%.+]] = arith.constant dense<0.000000e+00> : vector<2xf64>
-// CHECK: [[FLOAT0:%.+]] = amdgpu.ext_packed_fp8 [[V]][0] : vector<2xf8E5M2> to f32
-// CHECK: [[EXT0:%.+]] = arith.extf [[FLOAT0]] : f32 to f64
-// CHECK: [[W0:%.+]] = vector.insert [[EXT0]], [[ZEROES]] [0]
-// CHECK: [[FLOAT1:%.+]] = amdgpu.ext_packed_fp8 [[V]][1] : vector<2xf8E5M2> to f32
-// CHECK: [[EXT1:%.+]] = arith.extf [[FLOAT1]]
-// CHECK: [[W1:%.+]] = vector.insert [[EXT1]], [[W0]] [1]
-// CHECK: return [[W1]] : vector<2xf64>
+// CHECK: [[FLOAT0:%.+]] = amdgpu.ext_packed_fp8 [[V]][0] : vector<2xf8E5M2> to vector<2xf32>
+// CHECK: [[EXT:%.+]] = arith.extf [[FLOAT0]] : vector<2xf32> to vector<2xf64>
+// CHECK: return [[EXT]] : vector<2xf64>
 
 func.func @vector_ext_short(%v: vector<2xf8E5M2>) -> vector<2xf64> {
   %w = arith.extf %v : vector<2xf8E5M2> to vector<2xf64>
@@ -35,30 +31,22 @@ func.func @vector_ext_short(%v: vector<2xf8E5M2>) -> vector<2xf64> {
 
 // CHECK-LABEL: func.func @vector_ext_long
 // CHECK-SAME: ([[V:%.+]]: vector<9xf8E4M3FN>)
-// CHECK: [[V0:%.+]] = vector.extract_strided_slice [[V]] {offsets = [0], sizes = [4], strides = [1]}
-// CHECK: [[F0:%.+]] = amdgpu.ext_packed_fp8 [[V0]][0]
-// CHECK: [[W0:%.+]] = vector.insert [[F0]]
-// CHECK: [[F1:%.+]] = amdgpu.ext_packed_fp8 [[V0]][1]
-// CHECK: [[W1:%.+]] = vector.insert [[F1]], [[W0]]
-// CHECK: [[F2:%.+]] = amdgpu.ext_packed_fp8 [[V0]][2]
-// CHECK: [[W2:%.+]] = vector.insert [[F2]], [[W1]]
-// CHECK: [[F3:%.+]] = amdgpu.ext_packed_fp8 [[V0]][3]
-// CHECK: [[W3:%.+]] = vector.insert [[F3]], [[W2]]
-
-// CHECK: [[V1:%.+]] = vector.extract_strided_slice [[V]] {offsets = [4], sizes = [4], strides = [1]} : vector<9xf8E4M3FN> to vector<4xf8E4M3FN>
-// CHECK: [[F4:%.+]] = amdgpu.ext_packed_fp8 [[V1]][0]
-// CHECK: [[W4:%.+]] = vector.insert [[F4]], [[W3]]
-// CHECK: [[F5:%.+]] = amdgpu.ext_packed_fp8 [[V1]][1]
-// CHECK: [[W5:%.+]] = vector.insert [[F5]], [[W4]]
-// CHECK: [[F6:%.+]] = amdgpu.ext_packed_fp8 [[V1]][2]
-// CHECK: [[W6:%.+]] = vector.insert [[F6]], [[W5]]
-// CHECK: [[F7:%.+]] = amdgpu.ext_packed_fp8 [[V1]][3]
-// CHECK: [[W7:%.+]] = vector.insert [[F7]], [[W6]]
-
-// CHECK: [[V2:%.+]] = vector.extract_strided_slice [[V]] {offsets = [8], sizes = [1], strides = [1]} : vector<9xf8E4M3FN> to vector<1xf8E4M3FN>
-// CHECK: [[F8:%.+]] = amdgpu.ext_packed_fp8 [[V2]][0]
-// CHECK: [[W8:%.+]] = vector.insert [[F8]], [[W7]]
-// CHECK: return [[W8]]
+// CHECK: [[W0:%.+]] = arith.constant dense<0.000000e+00> : vector<9xf32>
+// CHECK: [[IN1:%.+]] = vector.extract_strided_slice [[V]] {offsets = [0], sizes = [4], strides = [1]} : vector<9xf8E4M3FN> to vector<4xf8E4M3FN>
+// CHECK: [[FLOAT1:%.+]] = amdgpu.ext_packed_fp8 [[IN1]][0] : vector<4xf8E4M3FN> to vector<2xf32>
+// CHECK: [[W1:%.+]] = vector.insert_strided_slice [[FLOAT1]], [[W0]] {offsets = [0], strides = [1]} : vector<2xf32> into vector<9xf32>
+// CHECK: [[FLOAT2:%.+]] = amdgpu.ext_packed_fp8 [[IN1]][1] : vector<4xf8E4M3FN> to vector<2xf32>
+// CHECK: [[W2:%.+]] = vector.insert_strided_slice [[FLOAT2]], [[W1]] {offsets = [2], strides = [1]} : vector<2xf32> into vector<9xf32>
+// CHECK: [[IN2:%.+]] = vector.extract_strided_slice [[V]] {offsets = [4], sizes = [4], strides = [1]} : vector<9xf8E4M3FN> to vector<4xf8E4M3FN>
+// CHECK: [[FLOAT3:%.+]] = amdgpu.ext_packed_fp8 [[IN2]][0] : vector<4xf8E4M3FN> to vector<2xf32>
+// CHECK: [[W3:%.+]] = vector.insert_strided_slice [[FLOAT3]], [[W2]] {offsets = [4], strides = [1]} : vector<2xf32> into vector<9xf32>
+// CHECK: [[FLOAT4:%.+]] = amdgpu.ext_packed_fp8 [[IN2]][1] : vector<4xf8E4M3FN> to vector<2xf32>
+// CHECK: [[W4:%.+]] = vector.insert_strided_slice [[FLOAT4]], [[W3]] {offsets = [6], strides = [1]} : vector<2xf32> into vector<9xf32>
+// CHECK: [[IN3:%.+]] = vector.extract_strided_slice [[V]] {offsets = [8], sizes = [1], strides = [1]} : vector<9xf8E4M3FN> to vector<1xf8E4M3FN>
+// CHECK: [[FLOAT5:%.+]] = amdgpu.ext_packed_fp8 [[IN3]][0] : vector<1xf8E4M3FN> to vector<2xf32>
+// CHECK: [[FLOAT6:%.+]] = vector.extract [[FLOAT5]][0] : f32 from vector<2xf32>
+// CHECK: [[W5:%.+]] = vector.insert [[FLOAT6]], [[W4]] [8] : f32 into vector<9xf32>
+// CHECK: return [[W5]]
 func.func @vector_ext_long(%v: vector<9xf8E4M3FN>) -> vector<9xf32> {
   %w = arith.extf %v : vector<9xf8E4M3FN> to vector<9xf32>
   return %w : vector<9xf32>
@@ -144,31 +132,25 @@ func.func @vector_trunc_long_2d(%v: vector<1x9xf32>) -> vector<1x9xf8E4M3FN> {
 
 // CHECK-LABEL: func.func @vector_ext_long_2d
 // CHECK-SAME: ([[V:%.+]]: vector<1x9xf8E4M3FN>)
+// CHECK: [[CST:%.+]] = arith.constant dense<0.000000e+00> : vector<9xf32>
 // CHECK: [[CAST:%.+]] = vector.shape_cast [[V]] : vector<1x9xf8E4M3FN> to vector<9xf8E4M3FN>
 // CHECK: [[V0:%.+]] = vector.extract_strided_slice [[CAST]] {offsets = [0], sizes = [4], strides = [1]}
-// CHECK: [[F0:%.+]] = amdgpu.ext_packed_fp8 [[V0]][0]
-// CHECK: [[W0:%.+]] = vector.insert [[F0]]
-// CHECK: [[F1:%.+]] = amdgpu.ext_packed_fp8 [[V0]][1]
-// CHECK: [[W1:%.+]] = vector.insert [[F1]], [[W0]]
-// CHECK: [[F2:%.+]] = amdgpu.ext_packed_fp8 [[V0]][2]
-// CHECK: [[W2:%.+]] = vector.insert [[F2]], [[W1]]
-// CHECK: [[F3:%.+]] = amdgpu.ext_packed_fp8 [[V0]][3]
-// CHECK: [[W3:%.+]] = vector.insert [[F3]], [[W2]]
+// CHECK: [[F0:%.+]] = amdgpu.ext_packed_fp8 [[V0]][0] : vector<4xf8E4M3FN> to vector<2xf32>
+// CHECK: [[W0:%.+]] = vector.insert_strided_slice [[F0]], [[CST]] {offsets = [0], strides = [1]} : vector<2xf32> into vector<9xf32>
+// CHECK: [[F1:%.+]] = amdgpu.ext_packed_fp8 [[V0]][1] : vector<4xf8E4M3FN> to vector<2xf32>
+// CHECK: [[W1:%.+]] = vector.insert_strided_slice [[F1]], [[W0]] {offsets = [2], strides = [1]} : vector<2xf32> into vector<9xf32>
 
 // CHECK: [[V1:%.+]] = vector.extract_strided_slice [[CAST]] {offsets = [4], sizes = [4], strides = [1]} : vector<9xf8E4M3FN> to vector<4xf8E4M3FN>
-// CHECK: [[F4:%.+]] = amdgpu.ext_packed_fp8 [[V1]][0]
-// CHECK: [[W4:%.+]] = vector.insert [[F4]], [[W3]]
-// CHECK: [[F5:%.+]] = amdgpu.ext_packed_fp8 [[V1]][1]
-// CHECK: [[W5:%.+]] = vector.insert [[F5]], [[W4]]
-// CHECK: [[F6:%.+]] = amdgpu.ext_packed_fp8 [[V1]][2]
-// CHECK: [[W6:%.+]] = vector.insert [[F6]], [[W5]]
-// CHECK: [[F7:%.+]] = amdgpu.ext_packed_fp8 [[V1]][3]
-// CHECK: [[W7:%.+]] = vector.insert [[F7]], [[W6]]
+// CHECK: [[F2:%.+]] = amdgpu.ext_packed_fp8 [[V1]][0] : vector<4xf8E4M3FN> to vector<2xf32>
+// CHECK: [[W2:%.+]] = vector.insert_strided_slice [[F2]], [[W1]] {offsets = [4], strides = [1]} : vector<2xf32> into vector<9xf32>
+// CHECK: [[F3:%.+]] = amdgpu.ext_packed_fp8 [[V1]][1] : vector<4xf8E4M3FN> to vector<2xf32>
+// CHECK: [[W3:%.+]] = vector.insert_strided_slice [[F3]], [[W2]] {offsets = [6], strides = [1]} : vector<2xf32> into vector<9xf32>
 
 // CHECK: [[V2:%.+]] = vector.extract_strided_slice [[CAST]] {offsets = [8], sizes = [1], strides = [1]} : vector<9xf8E4M3FN> to vector<1xf8E4M3FN>
-// CHECK: [[F8:%.+]] = amdgpu.ext_packed_fp8 [[V2]][0]
-// CHECK: [[W8:%.+]] = vector.insert [[F8]], [[W7]]
-// CHECK: [[CAST:%.+]] = vector.shape_cast [[W8]] : vector<9xf32> to vector<1x9xf32>
+// CHECK: [[F4:%.+]] = amdgpu.ext_packed_fp8 [[V2]][0] : vector<1xf8E4M3FN> to vector<2xf32>
+// CHECK: [[E0:%.+]] = vector.extract [[F4]][0] : f32 from vector<2xf32>
+// CHECK: [[W4:%.+]] = vector.insert [[E0]], [[W3]] [8] : f32 into vector<9xf32>
+// CHECK: [[CAST:%.+]] = vector.shape_cast [[W4]] : vector<9xf32> to vector<1x9xf32>
 // CHECK: return [[CAST]]
 func.func @vector_ext_long_2d(%v: vector<1x9xf8E4M3FN>) -> vector<1x9xf32> {
   %w = arith.extf %v : vector<1x9xf8E4M3FN> to vector<1x9xf32>
diff --git a/mlir/test/Conversion/ArithToAMDGPU/8-bit-floats.mlir b/mlir/test/Conversion/ArithToAMDGPU/8-bit-floats.mlir
index 6bb5b9771c015..2ed3f47e8ab73 100644
--- a/mlir/test/Conversion/ArithToAMDGPU/8-bit-floats.mlir
+++ b/mlir/test/Conversion/ArithToAMDGPU/8-bit-floats.mlir
@@ -2,8 +2,9 @@
 
 // CHECK-LABEL: func.func @scalar_ext
 // CHECK-SAME: ([[V:%.+]]: f8E5M2FNUZ)
-// CHECK: [[FLOAT:%.+]] = amdgpu.ext_packed_fp8 [[V]][0] : f8E5M2FNUZ to f32
-// CHECK: [[W:%.+]] = arith.truncf [[FLOAT]] : f32 to f16
+// CHECK: [[FLOAT:%.+]] = amdgpu.ext_packed_fp8 [[V]][0] : f8E5M2FNUZ to vector<2xf32>
+// CHECK: [[EXT:%.+]] = vector.extract [[FLOAT]][0] : f32 from vector<2xf32>
+// CHECK: [[W:%.+]] = arith.truncf [[EXT]] : f32 to f16
 // CHECK: return [[W]]
 func.func @scalar_ext(%v: f8E5M2FNUZ) -> f16 {
   %w = arith.extf %v : f8E5M2FNUZ to f16
@@ -16,8 +17,9 @@ func.func @scalar_ext(%v: f8E5M2FNUZ) -> f16 {
 // CHECK-SAME:   %[[ARG0:[a-zA-Z0-9_]+]]: vector<f8E5M2FNUZ>) -> vector<f32>
 // CHECK: %[[CONST:.+]] = arith.constant dense<0.000000e+00> : vector<f32>
 // CHECK: %[[EXTRACT:.+]] = vector.extract %[[ARG0]][] : f8E5M2FNUZ from vector<f8E5M2FNUZ>
-// CHECK: %[[CONVERT:.+]] = amdgpu.ext_packed_fp8 %[[EXTRACT]][0] : f8E5M2FNUZ to f32
-// CHECK: %[[RESULT:.+]] = vector.insert %[[CONVERT]], %[[CONST]] [] : f32 into vector<f32>
+// CHECK: %[[CONVERT:.+]] = amdgpu.ext_packed_fp8 %[[EXTRACT]][0] : f8E5M2FNUZ to vector<2xf32>
+// CHECK: %[[EXTRACT2:.+]] = vector.extract %[[CONVERT]][0] : f32 from vector<2xf32>
+// CHECK: %[[RESULT:.+]] = vector.insert %[[EXTRACT2]], %[[CONST]] [] : f32 into vector<f32>
 // CHECK: return %[[RESULT]] : vector<f32>
 func.func @vector_zero_d(%v: vector<f8E5M2FNUZ>) -> vector<f32> {
   %w = arith.extf %v : vector<f8E5M2FNUZ> to vector<f32>
@@ -28,15 +30,9 @@ func.func @vector_zero_d(%v: vector<f8E5M2FNUZ>) -> vector<f32> {
 
 // CHECK-LABEL: func.func @vector_ext_short
 // CHECK-SAME: ([[V:%.+]]: vector<2xf8E5M2FNUZ>)
-// CHECK-DAG: [[ZEROES:%.+]] = arith.constant dense<0.000000e+00> : vector<2xf64>
-// CHECK: [[FLOAT0:%.+]] = amdgpu.ext_packed_fp8 [[V]][0] : vector<2xf8E5M2FNUZ> to f32
-// CHECK: [[EXT0:%.+]] = arith.extf [[FLOAT0]] : f32 to f64
-// CHECK: [[W0:%.+]] = vector.insert [[EXT0]], [[ZEROES]] [0]
-// CHECK: [[FLOAT1:%.+]] = amdgpu.ext_packed_fp8 [[V]][1] : vector<2xf8E5M2FNUZ> to f32
-// CHECK: [[EXT1:%.+]] = arith.extf [[FLOAT1]]
-// CHECK: [[W1:%.+]] = vector.insert [[EXT1]], [[W0]] [1]
-// CHECK: return [[W1]] : vector<2xf64>
-
+// CHECK: [[FLOAT:%.+]] = amdgpu.ext_packed_fp8 [[V]][0] : vector<2xf8E5M2FNUZ> to vector<2xf32>
+// CHECK: [[EXT:%.+]] = arith.extf [[FLOAT]] : vector<2xf32> to vector<2xf64>
+// CHECK: return [[EXT]] : vector<2xf64>
 func.func @vector_ext_short(%v: vector<2xf8E5M2FNUZ>) -> vector<2xf64> {
   %w = arith.extf %v : vector<2xf8E5M2FNUZ> to vector<2xf64>
   return %w : vector<2xf64>
@@ -46,30 +42,22 @@ func.func @vector_ext_short(%v: vector<2xf8E5M2FNUZ>) -> vector<2xf64> {
 
 // CHECK-LABEL: func.func @vector_ext_long
 // CHECK-SAME: ([[V:%.+]]: vector<9xf8E4M3FNUZ>)
-// CHECK: [[V0:%.+]] = vector.extract_strided_slice [[V]] {offsets = [0], sizes = [4], strides = [1]}
-// CHECK: [[F0:%.+]] = amdgpu.ext_packed_fp8 [[V0]][0]
-// CHECK: [[W0:%.+]] = vector.insert [[F0]]
-// CHECK: [[F1:%.+]] = amdgpu.ext_packed_fp8 [[V0]][1]
-// CHECK: [[W1:%.+]] = vector.insert [[F1]], [[W0]]
-// CHECK: [[F2:%.+]] = amdgpu.ext_packed_fp8 [[V0]][2]
-// CHECK: [[W2:%.+]] = vector.insert [[F2]], [[W1]]
-// CHECK: [[F3:%.+]] = amdgpu.ext_packed_fp8 [[V0]][3]
-// CHECK: [[W3:%.+]] = vector.insert [[F3]], [[W2]]
-
-// CHECK: [[V1:%.+]] = vector.extract_strided_slice [[V]] {offsets = [4], sizes = [4], strides = [1]} : vector<9xf8E4M3FNUZ> to vector<4xf8E4M3FNUZ>
-// CHECK: [[F4:%.+]] = amdgpu.ext_packed_fp8 [[V1]][0]
-// CHECK: [[W4:%.+]] = vector.insert [[F4]], [[W3]]
-// CHECK: [[F5:%.+]] = amdgpu.ext_packed_fp8 [[V1]][1]
-// CHECK: [[W5:%.+]] = vector.insert [[F5]], [[W4]]
-// CHECK: [[F6:%.+]] = amdgpu.ext_packed_fp8 [[V1]][2]
-// CHECK: [[W6:%.+]] = vector.insert [[F6]], [[W5]]
-// CHECK: [[F7:%.+]] = amdgpu.ext_packed_fp8 [[V1]][3]
-// CHECK: [[W7:%.+]] = vector.insert [[F7]], [[W6]]
-
-// CHECK: [[V2:%.+]] = vector.extract_strided_slice [[V]] {offsets = [8], sizes = [1], strides = [1]} : vector<9xf8E4M3FNUZ> to vector<1xf8E4M3FNUZ>
-// CHECK: [[F8:%.+]] = amdgpu.ext_packed_fp8 [[V2]][0]
-// CHECK: [[W8:%.+]] = vector.insert [[F8]], [[W7]]
-// CHECK: return [[W8]]
+// CHECK: [[W0:%.+]] = arith.constant dense<0.000000e+00> : vector<9xf32>
+// CHECK: [[IN1:%.+]] = vector.extract_strided_slice [[V]] {offsets = [0], sizes = [4], strides = [1]} : vector<9xf8E4M3FNUZ> to vector<4xf8E4M3FNUZ>
+// CHECK: [[FLOAT1:%.+]] = amdgpu.ext_packed_fp8 [[IN1]][0] : vector<4xf8E4M3FNUZ> to vector<2xf32>
+// CHECK: [[W1:%.+]] = vector.insert_strided_slice [[FLOAT1]], [[W0]] {offsets = [0], strides = [1]} : vector<2xf32> into vector<9xf32>
+// CHECK: [[FLOAT2:%.+]] = amdgpu.ext_packed_fp8 [[IN1]][1] : vector<4xf8E4M3FNUZ> to vector<2xf32>
+// CHECK: [[W2:%.+]] = vector.insert_strided_slice [[FLOAT2]], [[W1]] {offsets = [2], strides = [1]} : vector<2xf32> into vector<9xf32>
+// CHECK: [[IN2:%.+]] = vector.extract_strided_slice [[V]] {offsets = [4], sizes = [4], strides = [1]} : vector<9xf8E4M3FNUZ> to vector<4xf8E4M3FNUZ>
+// CHECK: [[FLOAT3:%.+]] = amdgpu.ext_packed_fp8 [[IN2]][0] : vector<4xf8E4M3FNUZ> to vector<2xf32>
+// CHECK: [[W3:%.+]] = vector.insert_strided_slice [[FLOAT3]], [[W2]] {offsets = [4], strides = [1]} : vector<2xf32> into vector<9xf32>
+// CHECK: [[FLOAT4:%.+]] = amdgpu.ext_packed_fp8 [[IN2]][1] : vector<4xf8E4M3FNUZ> to vector<2xf32>
+// CHECK: [[W4:%.+]] = vector.insert_strided_slice [[FLOAT4]], [[W3]] {offsets = [6], strides = [1]} : vector<2xf32> into vector<9xf32>
+// CHECK: [[IN3:%.+]] = vector.extract_strided_slice [[V]] {offsets = [8], sizes = [1], strides = [1]} : vector<9xf8E4M3FNUZ> to vector<1xf8E4M3FNUZ>
+// CHECK: [[FLOAT5:%.+]] = amdgpu.ext_packed_fp8 [[IN3]][0] : vector<1xf8E4M3FNUZ> to vector<2xf32>
+// CHECK: [[FLOAT6:%.+]] = vector.extract [[FLOAT5]][0] : f32 from vector<2xf32>
+// CHECK: [[W5:%.+]] = vector.insert [[FLOAT6]], [[W4]] [8] : f32 into vector<9xf32>
+// CHECK: return [[W5]]
 func.func @vector_ext_long(%v: vector<9xf8E4M3FNUZ>) -> vector<9xf32> {
   %w = arith.extf %v : vector<9xf8E4M3FNUZ> to vector<9xf32>
   return %w : vector<9xf32>
@@ -155,31 +143,25 @@ func.func @vector_trunc_long_2d(%v: vector<1x9xf32>) -> vector<1x9xf8E4M3FNUZ> {
 
 // CHECK-LABEL: func.func @vector_ext_long_2d
 // CHECK-SAME: ([[V:%.+]]: vector<1x9xf8E4M3FNUZ>)
+// CHECK: [[CST:%.+]] = arith.constant dense<0.000000e+00> : vector<9xf32>
 // CHECK: [[CAST:%.+]] = vector.shape_cast [[V]] : vector<1x9xf8E4M3FNUZ> to vector<9xf8E4M3FNUZ>
 // CHECK: [[V0:%.+]] = vector.extract_strided_slice [[CAST]] {offsets = [0], sizes = [4], strides = [1]}
-// CHECK: [[F0:%.+]] = amdgpu.ext_packed_fp8 [[V0]][0]
-// CHECK: [[W0:%.+]] = vector.insert [[F0]]
-// CHECK: [[F1:%.+]] = amdgpu.ext_packed_fp8 [[V0]][1]
-// CHECK: [[W1:%.+]] = vector.insert [[F1]], [[W0]]
-// CHECK: [[F2:%.+]] = amdgpu.ext_packed_fp8 [[V0]][2]
-// CHECK: [[W2:%.+]] = vector.insert [[F2]], [[W1]]
-// CHECK: [[F3:%.+]] = amdgpu.ext_packed_fp8 [[V0]][3]
-// CHECK: [[W3:%.+]] = vector.insert [[F3]], [[W2]]
+// CHECK: [[F0:%.+]] = amdgpu.ext_packed_fp8 [[V0]][0] : vector<4xf8E4M3FNUZ> to vector<2xf32>
+// CHECK: [[W0:%.+]] = vector.insert_strided_slice [[F0]], [[CST]] {offsets = [0], strides = [1]} : vector<2xf32> into vector<9xf32>
+// CHECK: [[F1:%.+]] = amdgpu.ext_packed_fp8 [[V0]][1] : vector<4xf8E4M3FNUZ> to vector<2xf32>
+// CHECK: [[W1:%.+]] = vector.insert_strided_slice [[F1]], [[W0]] {offsets = [2], strides = [1]} : vector<2xf32> into vector<9xf32>
 
 // CHECK: [[V1:%.+]] = vector.extract_strided_slice [[CAST]] {offsets = [4], sizes = [4], strides = [1]} : vector<9xf8E4M3FNUZ> to vector<4xf8E4M3FNUZ>
-// CHECK: [[F4:%.+]] = amdgpu.ext_packed_fp8 [[V1]][0]
-// CHECK: [[W4:%.+]] = vector.insert [[F4]], [[W3]]
-// CHECK: [[F5:%.+]] = amdgpu.ext_packed_fp8 [[V1]][1]
-// CHECK: [[W5:%.+]] = vector.insert [[F5]], [[W4]]
-// CHECK: [[F6:%.+]] = amdgpu.ext_packed_fp8 [[V1]][2]
-// CHECK: [[W6:%.+]] = vector.insert [[F6]], [[W5]]
-// CHECK: [[F7:%.+]] = amdgpu.ext_packed_fp8 [[V1]][3]
-// CHECK: [[W7:%.+]] = vector.insert [[F7]], [[W6]]
+// CHECK: [[F2:%.+]] = amdgpu.ext_packed_fp8 [[V1]][0] : vector<4xf8E4M3FNUZ> to vector<2xf32>
+// CHECK: [[W2:%.+]] = vector.insert_strided_slice [[F2]], [[W1]] {offsets = [4], strides = [1]} : vector<2xf32> into vector<9xf32>
+// CHECK: [[F3:%.+]] = amdgpu.ext_packed_fp8 [[V1]][1] : vector<4xf8E4M3FNUZ> to vector<2xf32>
+// CHECK: [[W3:%.+]] = vector.insert_strided_slice [[F3]], [[W2]] {offsets = [6], strides = [1]} : vector<2xf32> into vector<9xf32>
 
 // CHECK: [[V2:%.+]] = vector.extract_strided_slice [[CAST]] {offsets = [8], sizes = [1], strides = [1]} : vector<9xf8E4M3FNUZ> to vector<1xf8E4M3FNUZ>
-// CHECK: [[F8:%.+]] = amdgpu.ext_packed_fp8 [[V2]][0]
-// CHECK: [[W8:%.+]] = vector.insert [[F8]], [[W7]]
-// CHECK: [[CAST:%.+]] = vector.shape_cast [[W8]] : vector<9xf32> to vector<1x9xf32>
+// CHECK: [[F4:%.+]] = amdgpu.ext_packed_fp8 [[V2]][0] : vector<1xf8E4M3FNUZ> to vector<2xf32>
+// CHECK: [[E0:%.+]] = vector.extract [[F4]][0] : f32 from vector<2xf32>
+// CHECK: [[W4:%.+]] = vector.insert [[E0]], [[W3]] [8] : f32 into vector<9xf32>
+// CHECK: [[CAST:%.+]] = vector.shape_cast [[W4]] : vector<9xf32> to vector<1x9xf32>
 // CHECK: return [[CAST]]
 func.func @vector_ext_long_2d(%v: vector<1x9xf8E4M3FNUZ>) -> vector<1x9xf32> {
   %w = arith.extf %v : vector<1x9xf8E4M3FNUZ> to vector<1x9xf32>
diff --git a/mlir/test/Dialect/AMDGPU/ops.mlir b/mlir/test/Dialect/AMDGPU/ops.mlir
index 567e6498330a3..bf312ead32712 100644
--- a/mlir/test/Dialect/AMDGPU/ops.mlir
+++ b/mlir/test/Dialect/AMDGPU/ops.mlir
@@ -6,9 +6,9 @@
 
 // CHECK-LABEL: func @ext_packed_fp8
 // CHECK: amdgpu.ext_packed_fp8
-func.func @ext_packed_fp8(%v: vector<4xf8E4M3FNUZ>) -> f32 {
-  %ret = amdgpu.ext_packed_fp8 %v[0] : vector<4xf8E4M3FNUZ> to f32
-  func.return %ret : f32
+func.func @ext_packed_fp8(%v: vector<4xf8E4M3FNUZ>) -> vector<2xf32> {
+  %ret = amdgpu.ext_packed_fp8 %v[0] : vector<4xf8E4M3FNUZ> to vector<2xf32>
+  func.return %ret : vector<2xf32>
 }
 
 // CHECK-LABEL: func @packed_trunc_2xfp8
diff --git a/mlir/test/Dialect/LLVMIR/rocdl.mlir b/mlir/test/Dialect/LLVMIR/rocdl.mlir
index bc917041998d8..cce2c0aee62f3 100644
--- a/mlir/test/Dialect/LLVMIR/rocdl.mlir
+++ b/mlir/test/Dialect/LLVMIR/rocdl.mlir
@@ -767,10 +767,14 @@ llvm.func @rocdl_8bit_floats(%source: i32, %source_half: f16, %source_bfloat: bf
 // CHECK: rocdl.cvt.scalef32.f32.fp8
 // CHECK: rocdl.cvt.scalef32.pk.f16.bf8 
 // CHECK: rocdl.cvt.scalef32.pk.f16.fp8 
+// CHECK: rocdl.cvt.scalef32.pk.bf16.bf8
+// CHECK: rocdl.cvt.scalef32.pk.bf16.fp8
 // CHECK: rocdl.cvt.scalef32.f16.fp8
 // CHECK: rocdl.cvt.scalef32.f16.bf8
 // CHECK: rocdl.cvt.pk.bf8.f32
 // CHECK: rocdl.cvt.pk.fp8.f32
+// CHECK: rocdl.cvt.pk.f32.bf8
+// CHECK: rocdl.cvt.pk.f32.fp8
 // CHECK: rocdl.cvt.sr.bf8.f32
 // CHECK: rocdl.cvt.sr.fp8.f32
 // CHECK: rocdl.cvt.scalef32.sr.fp8.f32
@@ -793,10 +797,14 @@ llvm.func @rocdl_8bit_floats(%source: i32, %source_half: f16, %source_bfloat: bf
   %v2_scaled = rocdl.cvt.scalef32.f32.fp8 %source[%c0], %c4 : f32
   %v3_scaled = rocdl.cvt.scalef32.pk.f16.bf8 %source[%false], %c4 : vector<2xf16>
   %v4_scaled = rocdl.cvt.scalef32.pk.f16.fp8 %source[%false], %c4 : vector<2xf16>
+  %v3_scaled_bf16 = rocdl.cvt.scalef32.pk.bf16.bf8 %source[%false], %c4 : vector<2xbf16>
+  %v4_scaled_bf16 = rocdl.cvt.scalef32.pk.bf16.fp8 %source[%false], %c4 : vector<2xbf16>
   %v5 = rocdl.cvt.scalef32.f16.fp8 %source[%false], %c4 -> %v3_scaled[%c0] : f16
   %v6  = rocdl.cvt.scalef32.f16.bf8 %source[%false], %c4 -> %v3_scaled[%c0] : f16
   %source2 = rocdl.cvt.pk.bf8.f32 %v1, %v2 -> %source[%false] : i32
   %source3 = rocdl.cvt.pk.fp8.f32 %v1, %v2 -> %source2[%false] : i32
+  %source2_ext = rocdl.cvt.pk.f32.bf8 %source[%false] : vector<2xf32>
+  %source3_ext = rocdl.cvt.pk.f32.fp8 %source[%false] : vector<2xf32>
   %source4 = rocdl.cvt.sr.bf8.f32 %v1, %stoch -> %source3[%c2] : i32
   %source5 = rocdl.cvt.sr.fp8.f32 %v2, %stoch -> %source4[%c3] : i32
   %source5_scaled = rocdl.cvt.scalef32.sr.fp8.f32 %v2, %stoch, %c4 -> %source4[%c3] : i32
diff --git a/mlir/test/Target/LLVMIR/rocdl.mlir b/mlir/test/Target/LLVMIR/rocdl.mlir
index 11f2faa2761ff..e70617bfff99e 100644
--- a/mlir/test/Target/LLVMIR/rocdl.mlir
+++ b/mlir/test/Target/LLVMIR/rocdl.mlir
@@ -1042,6 +1042,8 @@ llvm.func @rocdl_8bit_floats(%source: i32, %source_half: f16, %source_bfloat: bf
 // CHECK: call <2 x half> @llvm.amdgcn.cvt.scalef32.pk.f16.fp8(i32 %{{.+}}, float 1.000000e+00, i1 false)
 // CHECK: call <2 x half> @llvm.amdgcn.cvt.scalef32.f16.fp8(<2 x half> %{{.+}}, i32 %{{.+}}, float 1.000000e+00, i32 0, i1 false)
 // CHECK: call <2 x half> @llvm.amdgcn.cvt.scalef32.f16.bf8(<2 x half> %{{.+}}, i32 %{{.+}}, float 1.000000e+00, i32 0, i1 false)
+// CHECK: call <2 x bfloat> @llvm.amdgcn.cvt.scalef32.pk.bf16.bf8(i32 %{{.+}}, float 1.000000e+00, i1 false)
+// CHECK: call <2 x bfloat> @llvm.amdgcn.cvt.scalef32.pk.bf16.fp8(i32 %{{.+}}, float 1.000000e+00, i1 false)
 // CHECK: call i32 @llvm.amdgcn.cvt.pk.bf8.f32(float %{{.+}}, float %{{.+}}, i32 %{{.+}}, i1 false)
 // CHECK: call i32 @llvm.amdgcn.cvt.pk.fp8.f32(float %{{.+}}, float %{{.+}}, i32 %{{.+}}, i1 false)
 // CHECK: call i32 @llvm.amdgcn.cvt.sr.bf8.f32(float %{{.+}}, i32 %{{.+}}, i32 %{{.+}}, i32 2)
@@ -1068,6 +1070,8 @@ llvm.func @rocdl_8bit_floats(%source: i32, %source_half: f16, %source_bfloat: bf
   %v4_scaled = rocdl.cvt.scalef32.pk.f16.fp8 %source[%false], %c4 : i32
   %v5 = rocdl.cvt.scalef32.f16.fp8 %source[%false], %c4 -> %source_packed[%c0] : f16
   %v6 = rocdl.cvt.scalef32.f16.bf8 %source[%false], %c4 -> %source_packed[%c0] : f16
+  %v7 = rocdl.cvt.scalef32.pk.bf16.bf8 %source[%false], %c4 : i32
+  %v8 = rocdl.cvt.scalef32.pk.bf16.fp8 %source[%false], %c4 : i32
   %source2 = rocdl.cvt.pk.bf8.f32 %v1, %v2 -> %source[%false] : i32
   %source3 = rocdl.cvt.pk.fp8.f32 %v1, %v2 -> %source2[%false] : i32
   %source4 = rocdl.cvt.sr.bf8.f32 %v1, %stoch -> %source3[%c2] : i32

>From ae25d7357b40f38caa599af78f41db97516005b7 Mon Sep 17 00:00:00 2001
From: Yi Qian <yi.qian at amd.com>
Date: Wed, 19 Mar 2025 04:45:36 +0000
Subject: [PATCH 2/2] Allow amdgpu.ext_packed_fp8 to return a scalar or vector
 type

---
 mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td | 11 ++--
 .../AMDGPUToROCDL/AMDGPUToROCDL.cpp           | 26 +++++++---
 .../ArithToAMDGPU/ArithToAMDGPU.cpp           | 23 +++++----
 .../AMDGPUToROCDL/8-bit-floats-ocp.mlir       | 48 ++++++++++++++----
 .../AMDGPUToROCDL/8-bit-floats.mlir           | 48 ++++++++++++++----
 .../ArithToAMDGPU/8-bit-floats-ocp.mlir       | 45 ++++++++---------
 .../ArithToAMDGPU/8-bit-floats.mlir           | 50 +++++++++----------
 7 files changed, 159 insertions(+), 92 deletions(-)

diff --git a/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td b/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
index 3ed6e84d19044..c0b3e5540b1df 100644
--- a/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
+++ b/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
@@ -85,12 +85,13 @@ def AMDGPU_ExtPackedFp8Op :
     AMDGPU_Op<"ext_packed_fp8", [Pure]>,
     Arguments<(ins AnyTypeOf<[F8E5M2FNUZ, F8E4M3FNUZ, F8E5M2, F8E4M3FN,
         VectorOfLengthAndType<[1, 2, 3, 4], [F8E5M2FNUZ, F8E4M3FNUZ, F8E5M2, F8E4M3FN]>]>:$source,
-      ConfinedAttr<I32Attr, [IntNonNegative, IntMaxValue<1>]>:$wordIndex)>,
-    Results<(outs FixedVectorOfLengthAndType<[2], [F32]>:$res)> {
-  let summary = "Extend a vector of packed fp8 values to two floats";
+      ConfinedAttr<I32Attr, [IntNonNegative, IntMaxValue<3>]>:$index)>,
+    Results<(outs AnyTypeOf<[F32, FixedVectorOfLengthAndType<[2], [F32]>]>:$res)> {
+  let summary = "Extend an fp8 value to a float or a vector of packed fp8 values to two floats";
 
   let description = [{
-    Extend the two 8-bit floats in `source[wordrIndex]` to two 32-bit floats and return them.
+    Extend one or two 8-bit floats in `source[index]` to a 32-bit float or
+    to two 32-bit floats, respectively, and return them.
 
     This rather unusual signature arises from the fact that AMD GPUs cannot
     easily work with sub 32-bit quantities, so the compiler intrinsics for
@@ -102,7 +103,7 @@ def AMDGPU_ExtPackedFp8Op :
     undefined values as needed.
   }];
   let assemblyFormat = [{
-    attr-dict $source `[` $wordIndex `]` `:` type($source) `to` type($res)
+    attr-dict $source `[` $index `]` `:` type($source) `to` type($res)
   }];
 }
 
diff --git a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
index 768d21384412d..3acd470cff7f5 100644
--- a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
+++ b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
@@ -959,6 +959,7 @@ LogicalResult ExtPackedFp8OpLowering::matchAndRewrite(
 
   Value source = adaptor.getSource();
   auto sourceVecType = dyn_cast<VectorType>(op.getSource().getType());
+  auto resultVecType = dyn_cast<VectorType>(op.getResult().getType());
   Type sourceElemType = getElementTypeOrSelf(op.getSource());
   // Extend to a v4i8
   if (!sourceVecType || sourceVecType.getNumElements() < 4) {
@@ -977,13 +978,24 @@ LogicalResult ExtPackedFp8OpLowering::matchAndRewrite(
     source = longVec;
   }
   Value i32Source = rewriter.create<LLVM::BitcastOp>(loc, i32, source);
-  Value wordSel = createI1Constant(rewriter, loc, op.getWordIndex());
-  if (typeIsExpectedBf8ForChipset(chipset, sourceElemType)) {
-    rewriter.replaceOpWithNewOp<ROCDL::CvtPkF32Bf8Op>(op, f32, i32Source,
-                                                      wordSel);
-  } else if (typeIsExpectedFp8ForChipset(chipset, sourceElemType)) {
-    rewriter.replaceOpWithNewOp<ROCDL::CvtPkF32Fp8Op>(op, f32, i32Source,
-                                                      wordSel);
+  if (resultVecType) {
+    Value wordSel = createI1Constant(rewriter, loc, op.getIndex());
+    if (typeIsExpectedBf8ForChipset(chipset, sourceElemType)) {
+      rewriter.replaceOpWithNewOp<ROCDL::CvtPkF32Bf8Op>(op, f32, i32Source,
+                                                        wordSel);
+    } else if (typeIsExpectedFp8ForChipset(chipset, sourceElemType)) {
+      rewriter.replaceOpWithNewOp<ROCDL::CvtPkF32Fp8Op>(op, f32, i32Source,
+                                                        wordSel);
+    }
+  } else {
+    Value byteSel = createI32Constant(rewriter, loc, op.getIndex());
+    if (typeIsExpectedBf8ForChipset(chipset, sourceElemType)) {
+      rewriter.replaceOpWithNewOp<ROCDL::CvtF32Bf8Op>(op, f32, i32Source,
+                                                      byteSel);
+    } else if (typeIsExpectedFp8ForChipset(chipset, sourceElemType)) {
+      rewriter.replaceOpWithNewOp<ROCDL::CvtF32Fp8Op>(op, f32, i32Source,
+                                                      byteSel);
+    }
   }
   return success();
 }
diff --git a/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp b/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp
index f9b685d1e90f6..3596b3235a631 100644
--- a/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp
+++ b/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp
@@ -113,10 +113,9 @@ ExtFOnFloat8RewritePattern::matchAndRewrite(arith::ExtFOp op,
   Type outElemType = getElementTypeOrSelf(op.getOut().getType());
   VectorType extResType = VectorType::get(2, rewriter.getF32Type());
   if (!inVecType) {
-    Value asFloats =
-        rewriter.create<amdgpu::ExtPackedFp8Op>(loc, extResType, in, 0);
-    Value resFloat = rewriter.create<vector::ExtractOp>(loc, asFloats, 0);
-    Value result = castF32To(outElemType, resFloat, loc, rewriter);
+    Value asFloat = rewriter.create<amdgpu::ExtPackedFp8Op>(
+        loc, rewriter.getF32Type(), in, 0);
+    Value result = castF32To(outElemType, asFloat, loc, rewriter);
     rewriter.replaceOp(op, result);
     return success();
   }
@@ -154,15 +153,17 @@ ExtFOnFloat8RewritePattern::matchAndRewrite(arith::ExtFOp op,
     Value inSlice = rewriter.create<vector::ExtractStridedSliceOp>(
         loc, in, i, elemsThisOp, 1);
     for (int64_t j = 0; j < elemsThisOp; j += 2) {
-      Value asFloats = rewriter.create<amdgpu::ExtPackedFp8Op>(loc, extResType,
-                                                               inSlice, j / 2);
-      Type desType = VectorType::get(2, outElemType);
-      Value asType = castF32To(desType, asFloats, loc, rewriter);
-      if (i + j + 1 < numElements)
+      if (i + j + 1 < numElements) { // Convert two 8-bit elements
+        Value asFloats = rewriter.create<amdgpu::ExtPackedFp8Op>(
+            loc, extResType, inSlice, j / 2);
+        Type desType = VectorType::get(2, outElemType);
+        Value asType = castF32To(desType, asFloats, loc, rewriter);
         result = rewriter.create<vector::InsertStridedSliceOp>(
             loc, asType, result, i + j, 1);
-      else {
-        asType = rewriter.create<vector::ExtractOp>(loc, asType, 0);
+      } else { // Convert an 8-bit element
+        Value asFloat = rewriter.create<amdgpu::ExtPackedFp8Op>(
+            loc, rewriter.getF32Type(), inSlice, j / 2 * 2);
+        Value asType = castF32To(outElemType, asFloat, loc, rewriter);
         result = rewriter.create<vector::InsertOp>(loc, asType, result, i + j);
       }
     }
diff --git a/mlir/test/Conversion/AMDGPUToROCDL/8-bit-floats-ocp.mlir b/mlir/test/Conversion/AMDGPUToROCDL/8-bit-floats-ocp.mlir
index 0fb03ff13b558..eb483b0880294 100644
--- a/mlir/test/Conversion/AMDGPUToROCDL/8-bit-floats-ocp.mlir
+++ b/mlir/test/Conversion/AMDGPUToROCDL/8-bit-floats-ocp.mlir
@@ -7,12 +7,12 @@
 // CHECK-DAG: [[C0_1:%.+]] = llvm.mlir.constant(0 : i32) : i32
 // CHECK: [[VEC:%.+]] = llvm.insertelement [[V]], [[UNDEF]]{{\[}}[[C0_1]] : i32] : vector<4xi8>
 // CHECK: [[CAST:%.+]] = llvm.bitcast [[VEC]] : vector<4xi8> to i32
-// CHECK: [[C0_2:%.+]] = llvm.mlir.constant(false) : i1
-// CHECK: [[EXT:%.+]] = rocdl.cvt.pk.f32.bf8 [[CAST]]{{\[}}[[C0_2]]] : vector<2xf32>
-// CHECK: return [[EXT]]
-func.func @ext_scalar(%v: f8E5M2) -> vector<2xf32> {
-  %ret = amdgpu.ext_packed_fp8 %v[0] : f8E5M2 to vector<2xf32>
-  func.return %ret : vector<2xf32>
+// CHECK: [[C0_2:%.+]] = llvm.mlir.constant(0 : i32) : i32
+// CHECK: [[EXT:%.+]] = rocdl.cvt.f32.bf8 [[CAST]]{{\[}}[[C0_2]]] : f32
+// CHECK: return [[EXT]] : f32
+func.func @ext_scalar(%v: f8E5M2) -> f32 {
+  %ret = amdgpu.ext_packed_fp8 %v[0] : f8E5M2 to f32
+  func.return %ret : f32
 }
 
 // CHECK-LABEL: func @ext_short_vec
@@ -25,22 +25,50 @@ func.func @ext_scalar(%v: f8E5M2) -> vector<2xf32> {
 // CHECK: [[ELEM_1:%.+]] = llvm.extractelement [[V]]{{\[}}[[C1_1]] : i32] : vector<2xi8>
 // CHECK: [[VEC_1:%.+]] = llvm.insertelement [[ELEM_1]], [[VEC_0]]{{\[}}[[C1_1]] : i32] : vector<4xi8>
 // CHECK: [[CAST:%.+]] = llvm.bitcast [[VEC_1]] : vector<4xi8> to i32
+// CHECK: [[C1_2:%.+]] = llvm.mlir.constant(1 : i32) : i32
+// CHECK: [[EXT:%.+]] = rocdl.cvt.f32.fp8 [[CAST]]{{\[}}[[C1_2]]] : f32
+// CHECK: return [[EXT]] : f32
+func.func @ext_short_vec(%v: vector<2xf8E4M3FN>) -> f32 {
+  %ret = amdgpu.ext_packed_fp8 %v[1] : vector<2xf8E4M3FN> to f32
+  func.return %ret : f32
+}
+
+// CHECK-LABEL: func @ext_full_vec(
+// CHECK: [[V:%.+]] = builtin.unrealized_conversion_cast %{{.+}} : vector<4xf8E4M3FN> to vector<4xi8>
+// CHECK: [[CAST:%.+]] = llvm.bitcast [[V]] : vector<4xi8> to i32
+// CHECK: [[C3:%.+]] = llvm.mlir.constant(2 : i32) : i32
+// CHECK: [[EXT:%.+]] = rocdl.cvt.f32.fp8 [[CAST]]{{\[}}[[C3]]] : f32
+// CHECK: return [[EXT]] : f32
+func.func @ext_full_vec(%v: vector<4xf8E4M3FN>) -> f32 {
+  %ret = amdgpu.ext_packed_fp8 %v[2] : vector<4xf8E4M3FN> to f32
+  func.return %ret : f32
+}
+
+// CHECK-LABEL: func @ext_packed_2xfp8
+// CHECK: [[V:%.+]] = builtin.unrealized_conversion_cast %{{.+}} : vector<2xf8E4M3FN> to vector<2xi8>
+// CHECK-DAG: [[UNDEF:%.+]] = llvm.mlir.undef : vector<4xi8>
+// CHECK-DAG: [[C0:%.+]] = llvm.mlir.constant(0 : i32) : i32
+// CHECK: [[ELEM_0:%.+]] = llvm.extractelement [[V]]{{\[}}[[C0]] : i32] : vector<2xi8>
+// CHECK: [[VEC_0:%.+]] = llvm.insertelement [[ELEM_0]], [[UNDEF]]{{\[}}[[C0]] : i32] : vector<4xi8>
+// CHECK: [[C1_1:%.+]] = llvm.mlir.constant(1 : i32) : i32
+// CHECK: [[ELEM_1:%.+]] = llvm.extractelement [[V]]{{\[}}[[C1_1]] : i32] : vector<2xi8>
+// CHECK: [[VEC_1:%.+]] = llvm.insertelement [[ELEM_1]], [[VEC_0]]{{\[}}[[C1_1]] : i32] : vector<4xi8>
+// CHECK: [[CAST:%.+]] = llvm.bitcast [[VEC_1]] : vector<4xi8> to i32
 // CHECK: [[C1_2:%.+]] = llvm.mlir.constant(false) : i1
 // CHECK: [[EXT:%.+]] = rocdl.cvt.pk.f32.fp8 [[CAST]]{{\[}}[[C1_2]]] : vector<2xf32>
 // CHECK: return [[EXT]]
-func.func @ext_short_vec(%v: vector<2xf8E4M3FN>) -> vector<2xf32> {
+func.func @ext_packed_2xfp8(%v: vector<2xf8E4M3FN>) -> vector<2xf32> {
   %ret = amdgpu.ext_packed_fp8 %v[0] : vector<2xf8E4M3FN> to vector<2xf32>
   func.return %ret : vector<2xf32>
 }
 
-// CHECK-LABEL: func @ext_full_vec(
+// CHECK-LABEL: func @ext_packed_4xfp8
 // CHECK: [[V:%.+]] = builtin.unrealized_conversion_cast %{{.+}} : vector<4xf8E4M3FN> to vector<4xi8>
 // CHECK: [[CAST:%.+]] = llvm.bitcast [[V]] : vector<4xi8> to i32
 // CHECK: [[C3:%.+]] = llvm.mlir.constant(true) : i1
 // CHECK: [[EXT:%.+]] = rocdl.cvt.pk.f32.fp8 [[CAST]]{{\[}}[[C3]]] : vector<2xf32>
 // CHECK: return [[EXT]] : vector<2xf32>
-
-func.func @ext_full_vec(%v: vector<4xf8E4M3FN>) -> vector<2xf32> {
+func.func @ext_packed_4xfp8(%v: vector<4xf8E4M3FN>) -> vector<2xf32> {
   %ret = amdgpu.ext_packed_fp8 %v[1] : vector<4xf8E4M3FN> to vector<2xf32>
   func.return %ret : vector<2xf32>
 }
diff --git a/mlir/test/Conversion/AMDGPUToROCDL/8-bit-floats.mlir b/mlir/test/Conversion/AMDGPUToROCDL/8-bit-floats.mlir
index 0a4a960d59ce8..4029d14650d7f 100644
--- a/mlir/test/Conversion/AMDGPUToROCDL/8-bit-floats.mlir
+++ b/mlir/test/Conversion/AMDGPUToROCDL/8-bit-floats.mlir
@@ -6,12 +6,12 @@
 // CHECK-DAG: [[C0_1:%.+]] = llvm.mlir.constant(0 : i32) : i32
 // CHECK: [[VEC:%.+]] = llvm.insertelement [[V]], [[UNDEF]]{{\[}}[[C0_1]] : i32] : vector<4xi8>
 // CHECK: [[CAST:%.+]] = llvm.bitcast [[VEC]] : vector<4xi8> to i32
-// CHECK: [[C0_2:%.+]] = llvm.mlir.constant(false) : i1
-// CHECK: [[EXT:%.+]] = rocdl.cvt.pk.f32.bf8 [[CAST]]{{\[}}[[C0_2]]] : vector<2xf32>
-// CHECK: return [[EXT]]
-func.func @ext_scalar(%v: f8E5M2FNUZ) -> vector<2xf32> {
-  %ret = amdgpu.ext_packed_fp8 %v[0] : f8E5M2FNUZ to vector<2xf32>
-  func.return %ret : vector<2xf32>
+// CHECK: [[C0_2:%.+]] = llvm.mlir.constant(0 : i32) : i32
+// CHECK: [[EXT:%.+]] = rocdl.cvt.f32.bf8 [[CAST]]{{\[}}[[C0_2]]] : f32
+// CHECK: return [[EXT]] : f32
+func.func @ext_scalar(%v: f8E5M2FNUZ) -> f32 {
+  %ret = amdgpu.ext_packed_fp8 %v[0] : f8E5M2FNUZ to f32
+  func.return %ret : f32
 }
 
 // CHECK-LABEL: func @ext_short_vec
@@ -24,22 +24,50 @@ func.func @ext_scalar(%v: f8E5M2FNUZ) -> vector<2xf32> {
 // CHECK: [[ELEM_1:%.+]] = llvm.extractelement [[V]]{{\[}}[[C1_1]] : i32] : vector<2xi8>
 // CHECK: [[VEC_1:%.+]] = llvm.insertelement [[ELEM_1]], [[VEC_0]]{{\[}}[[C1_1]] : i32] : vector<4xi8>
 // CHECK: [[CAST:%.+]] = llvm.bitcast [[VEC_1]] : vector<4xi8> to i32
+// CHECK: [[C1_2:%.+]] = llvm.mlir.constant(1 : i32) : i32
+// CHECK: [[EXT:%.+]] = rocdl.cvt.f32.fp8 [[CAST]]{{\[}}[[C1_2]]] : f32
+// CHECK: return [[EXT]] : f32
+func.func @ext_short_vec(%v: vector<2xf8E4M3FNUZ>) -> f32 {
+  %ret = amdgpu.ext_packed_fp8 %v[1] : vector<2xf8E4M3FNUZ> to f32
+  func.return %ret : f32
+}
+
+// CHECK-LABEL: func @ext_full_vec
+// CHECK: [[V:%.+]] = builtin.unrealized_conversion_cast %{{.+}} : vector<4xf8E4M3FNUZ> to vector<4xi8>
+// CHECK: [[CAST:%.+]] = llvm.bitcast [[V]] : vector<4xi8> to i32
+// CHECK: [[C3:%.+]] = llvm.mlir.constant(2 : i32) : i32
+// CHECK: [[EXT:%.+]] = rocdl.cvt.f32.fp8 [[CAST]]{{\[}}[[C3]]] : f32
+// CHECK: return [[EXT]] : f32
+func.func @ext_full_vec(%v: vector<4xf8E4M3FNUZ>) -> f32 {
+  %ret = amdgpu.ext_packed_fp8 %v[2] : vector<4xf8E4M3FNUZ> to f32
+  func.return %ret : f32
+}
+
+// CHECK-LABEL: func @ext_packed_2xfp8
+// CHECK: [[V:%.+]] = builtin.unrealized_conversion_cast %{{.+}} : vector<2xf8E4M3FNUZ> to vector<2xi8>
+// CHECK-DAG: [[UNDEF:%.+]] = llvm.mlir.undef : vector<4xi8>
+// CHECK-DAG: [[C0:%.+]] = llvm.mlir.constant(0 : i32) : i32
+// CHECK: [[ELEM_0:%.+]] = llvm.extractelement [[V]]{{\[}}[[C0]] : i32] : vector<2xi8>
+// CHECK: [[VEC_0:%.+]] = llvm.insertelement [[ELEM_0]], [[UNDEF]]{{\[}}[[C0]] : i32] : vector<4xi8>
+// CHECK: [[C1_1:%.+]] = llvm.mlir.constant(1 : i32) : i32
+// CHECK: [[ELEM_1:%.+]] = llvm.extractelement [[V]]{{\[}}[[C1_1]] : i32] : vector<2xi8>
+// CHECK: [[VEC_1:%.+]] = llvm.insertelement [[ELEM_1]], [[VEC_0]]{{\[}}[[C1_1]] : i32] : vector<4xi8>
+// CHECK: [[CAST:%.+]] = llvm.bitcast [[VEC_1]] : vector<4xi8> to i32
 // CHECK: [[C1_2:%.+]] = llvm.mlir.constant(false) : i1
 // CHECK: [[EXT:%.+]] = rocdl.cvt.pk.f32.fp8 [[CAST]]{{\[}}[[C1_2]]] : vector<2xf32>
 // CHECK: return [[EXT]]
-func.func @ext_short_vec(%v: vector<2xf8E4M3FNUZ>) -> vector<2xf32> {
+func.func @ext_packed_2xfp8(%v: vector<2xf8E4M3FNUZ>) -> vector<2xf32> {
   %ret = amdgpu.ext_packed_fp8 %v[0] : vector<2xf8E4M3FNUZ> to vector<2xf32>
   func.return %ret : vector<2xf32>
 }
 
-// CHECK-LABEL: func @ext_full_vec(
+// CHECK-LABEL: func @ext_packed_4xfp8(
 // CHECK: [[V:%.+]] = builtin.unrealized_conversion_cast %{{.+}} : vector<4xf8E4M3FNUZ> to vector<4xi8>
 // CHECK: [[CAST:%.+]] = llvm.bitcast [[V]] : vector<4xi8> to i32
 // CHECK: [[C3:%.+]] = llvm.mlir.constant(true) : i1
 // CHECK: [[EXT:%.+]] = rocdl.cvt.pk.f32.fp8 [[CAST]]{{\[}}[[C3]]] : vector<2xf32>
 // CHECK: return [[EXT]] : vector<2xf32>
-
-func.func @ext_full_vec(%v: vector<4xf8E4M3FNUZ>) -> vector<2xf32> {
+func.func @ext_packed_4xfp8(%v: vector<4xf8E4M3FNUZ>) -> vector<2xf32> {
   %ret = amdgpu.ext_packed_fp8 %v[1] : vector<4xf8E4M3FNUZ> to vector<2xf32>
   func.return %ret : vector<2xf32>
 }
diff --git a/mlir/test/Conversion/ArithToAMDGPU/8-bit-floats-ocp.mlir b/mlir/test/Conversion/ArithToAMDGPU/8-bit-floats-ocp.mlir
index b75b69c1b5d27..7fb5fbfe0c89e 100644
--- a/mlir/test/Conversion/ArithToAMDGPU/8-bit-floats-ocp.mlir
+++ b/mlir/test/Conversion/ArithToAMDGPU/8-bit-floats-ocp.mlir
@@ -3,9 +3,8 @@
   
 // CHECK-LABEL: func.func @scalar_ext
 // CHECK-SAME: ([[V:%.+]]: f8E5M2)
-// CHECK: [[FLOAT:%.+]] = amdgpu.ext_packed_fp8 [[V]][0] : f8E5M2 to vector<2xf32>
-// CHECK: [[EXT:%.+]] = vector.extract [[FLOAT]][0] : f32 from vector<2xf32>
-// CHECK: [[W:%.+]] = arith.truncf [[EXT]] : f32 to f16
+// CHECK: [[FLOAT:%.+]] = amdgpu.ext_packed_fp8 [[V]][0] : f8E5M2 to f32
+// CHECK: [[W:%.+]] = arith.truncf [[FLOAT]] : f32 to f16
 // CHECK: return [[W]]
 func.func @scalar_ext(%v: f8E5M2) -> f16 {
   %w = arith.extf %v : f8E5M2 to f16
@@ -43,9 +42,8 @@ func.func @vector_ext_short(%v: vector<2xf8E5M2>) -> vector<2xf64> {
 // CHECK: [[FLOAT4:%.+]] = amdgpu.ext_packed_fp8 [[IN2]][1] : vector<4xf8E4M3FN> to vector<2xf32>
 // CHECK: [[W4:%.+]] = vector.insert_strided_slice [[FLOAT4]], [[W3]] {offsets = [6], strides = [1]} : vector<2xf32> into vector<9xf32>
 // CHECK: [[IN3:%.+]] = vector.extract_strided_slice [[V]] {offsets = [8], sizes = [1], strides = [1]} : vector<9xf8E4M3FN> to vector<1xf8E4M3FN>
-// CHECK: [[FLOAT5:%.+]] = amdgpu.ext_packed_fp8 [[IN3]][0] : vector<1xf8E4M3FN> to vector<2xf32>
-// CHECK: [[FLOAT6:%.+]] = vector.extract [[FLOAT5]][0] : f32 from vector<2xf32>
-// CHECK: [[W5:%.+]] = vector.insert [[FLOAT6]], [[W4]] [8] : f32 into vector<9xf32>
+// CHECK: [[FLOAT5:%.+]] = amdgpu.ext_packed_fp8 [[IN3]][0] : vector<1xf8E4M3FN> to f32
+// CHECK: [[W5:%.+]] = vector.insert [[FLOAT5]], [[W4]] [8] : f32 into vector<9xf32>
 // CHECK: return [[W5]]
 func.func @vector_ext_long(%v: vector<9xf8E4M3FN>) -> vector<9xf32> {
   %w = arith.extf %v : vector<9xf8E4M3FN> to vector<9xf32>
@@ -131,28 +129,29 @@ func.func @vector_trunc_long_2d(%v: vector<1x9xf32>) -> vector<1x9xf8E4M3FN> {
 // -----
 
 // CHECK-LABEL: func.func @vector_ext_long_2d
-// CHECK-SAME: ([[V:%.+]]: vector<1x9xf8E4M3FN>)
-// CHECK: [[CST:%.+]] = arith.constant dense<0.000000e+00> : vector<9xf32>
-// CHECK: [[CAST:%.+]] = vector.shape_cast [[V]] : vector<1x9xf8E4M3FN> to vector<9xf8E4M3FN>
+// CHECK-SAME: ([[V:%.+]]: vector<1x11xf8E4M3FN>)
+// CHECK: [[CST:%.+]] = arith.constant dense<0.000000e+00> : vector<11xf32>
+// CHECK: [[CAST:%.+]] = vector.shape_cast [[V]] : vector<1x11xf8E4M3FN> to vector<11xf8E4M3FN>
 // CHECK: [[V0:%.+]] = vector.extract_strided_slice [[CAST]] {offsets = [0], sizes = [4], strides = [1]}
 // CHECK: [[F0:%.+]] = amdgpu.ext_packed_fp8 [[V0]][0] : vector<4xf8E4M3FN> to vector<2xf32>
-// CHECK: [[W0:%.+]] = vector.insert_strided_slice [[F0]], [[CST]] {offsets = [0], strides = [1]} : vector<2xf32> into vector<9xf32>
+// CHECK: [[W0:%.+]] = vector.insert_strided_slice [[F0]], [[CST]] {offsets = [0], strides = [1]} : vector<2xf32> into vector<11xf32>
 // CHECK: [[F1:%.+]] = amdgpu.ext_packed_fp8 [[V0]][1] : vector<4xf8E4M3FN> to vector<2xf32>
-// CHECK: [[W1:%.+]] = vector.insert_strided_slice [[F1]], [[W0]] {offsets = [2], strides = [1]} : vector<2xf32> into vector<9xf32>
+// CHECK: [[W1:%.+]] = vector.insert_strided_slice [[F1]], [[W0]] {offsets = [2], strides = [1]} : vector<2xf32> into vector<11xf32>
 
-// CHECK: [[V1:%.+]] = vector.extract_strided_slice [[CAST]] {offsets = [4], sizes = [4], strides = [1]} : vector<9xf8E4M3FN> to vector<4xf8E4M3FN>
+// CHECK: [[V1:%.+]] = vector.extract_strided_slice [[CAST]] {offsets = [4], sizes = [4], strides = [1]} : vector<11xf8E4M3FN> to vector<4xf8E4M3FN>
 // CHECK: [[F2:%.+]] = amdgpu.ext_packed_fp8 [[V1]][0] : vector<4xf8E4M3FN> to vector<2xf32>
-// CHECK: [[W2:%.+]] = vector.insert_strided_slice [[F2]], [[W1]] {offsets = [4], strides = [1]} : vector<2xf32> into vector<9xf32>
+// CHECK: [[W2:%.+]] = vector.insert_strided_slice [[F2]], [[W1]] {offsets = [4], strides = [1]} : vector<2xf32> into vector<11xf32>
 // CHECK: [[F3:%.+]] = amdgpu.ext_packed_fp8 [[V1]][1] : vector<4xf8E4M3FN> to vector<2xf32>
-// CHECK: [[W3:%.+]] = vector.insert_strided_slice [[F3]], [[W2]] {offsets = [6], strides = [1]} : vector<2xf32> into vector<9xf32>
-
-// CHECK: [[V2:%.+]] = vector.extract_strided_slice [[CAST]] {offsets = [8], sizes = [1], strides = [1]} : vector<9xf8E4M3FN> to vector<1xf8E4M3FN>
-// CHECK: [[F4:%.+]] = amdgpu.ext_packed_fp8 [[V2]][0] : vector<1xf8E4M3FN> to vector<2xf32>
-// CHECK: [[E0:%.+]] = vector.extract [[F4]][0] : f32 from vector<2xf32>
-// CHECK: [[W4:%.+]] = vector.insert [[E0]], [[W3]] [8] : f32 into vector<9xf32>
-// CHECK: [[CAST:%.+]] = vector.shape_cast [[W4]] : vector<9xf32> to vector<1x9xf32>
+// CHECK: [[W3:%.+]] = vector.insert_strided_slice [[F3]], [[W2]] {offsets = [6], strides = [1]} : vector<2xf32> into vector<11xf32>
+
+// CHECK: [[V2:%.+]] = vector.extract_strided_slice [[CAST]] {offsets = [8], sizes = [3], strides = [1]} : vector<11xf8E4M3FN> to vector<3xf8E4M3FN>
+// CHECK: [[F4:%.+]] = amdgpu.ext_packed_fp8 [[V2]][0] : vector<3xf8E4M3FN> to vector<2xf32>
+// CHECK: [[W4:%.+]] = vector.insert_strided_slice [[F4]], [[W3]] {offsets = [8], strides = [1]} : vector<2xf32> into vector<11xf32>
+// CHECK: [[F5:%.+]] = amdgpu.ext_packed_fp8 [[V2]][2] : vector<3xf8E4M3FN> to f32
+// CHECK: [[W5:%.+]] = vector.insert [[F5]], [[W4]] [10] : f32 into vector<11xf32>
+// CHECK: [[CAST:%.+]] = vector.shape_cast [[W5]] : vector<11xf32> to vector<1x11xf32>
 // CHECK: return [[CAST]]
-func.func @vector_ext_long_2d(%v: vector<1x9xf8E4M3FN>) -> vector<1x9xf32> {
-  %w = arith.extf %v : vector<1x9xf8E4M3FN> to vector<1x9xf32>
-  return %w : vector<1x9xf32>
+func.func @vector_ext_long_2d(%v: vector<1x11xf8E4M3FN>) -> vector<1x11xf32> {
+  %w = arith.extf %v : vector<1x11xf8E4M3FN> to vector<1x11xf32>
+  return %w : vector<1x11xf32>
 }
diff --git a/mlir/test/Conversion/ArithToAMDGPU/8-bit-floats.mlir b/mlir/test/Conversion/ArithToAMDGPU/8-bit-floats.mlir
index 2ed3f47e8ab73..59ed6bd95ae8b 100644
--- a/mlir/test/Conversion/ArithToAMDGPU/8-bit-floats.mlir
+++ b/mlir/test/Conversion/ArithToAMDGPU/8-bit-floats.mlir
@@ -2,9 +2,8 @@
 
 // CHECK-LABEL: func.func @scalar_ext
 // CHECK-SAME: ([[V:%.+]]: f8E5M2FNUZ)
-// CHECK: [[FLOAT:%.+]] = amdgpu.ext_packed_fp8 [[V]][0] : f8E5M2FNUZ to vector<2xf32>
-// CHECK: [[EXT:%.+]] = vector.extract [[FLOAT]][0] : f32 from vector<2xf32>
-// CHECK: [[W:%.+]] = arith.truncf [[EXT]] : f32 to f16
+// CHECK: [[FLOAT:%.+]] = amdgpu.ext_packed_fp8 [[V]][0] : f8E5M2FNUZ to f32
+// CHECK: [[W:%.+]] = arith.truncf [[FLOAT]] : f32 to f16
 // CHECK: return [[W]]
 func.func @scalar_ext(%v: f8E5M2FNUZ) -> f16 {
   %w = arith.extf %v : f8E5M2FNUZ to f16
@@ -17,9 +16,8 @@ func.func @scalar_ext(%v: f8E5M2FNUZ) -> f16 {
 // CHECK-SAME:   %[[ARG0:[a-zA-Z0-9_]+]]: vector<f8E5M2FNUZ>) -> vector<f32>
 // CHECK: %[[CONST:.+]] = arith.constant dense<0.000000e+00> : vector<f32>
 // CHECK: %[[EXTRACT:.+]] = vector.extract %[[ARG0]][] : f8E5M2FNUZ from vector<f8E5M2FNUZ>
-// CHECK: %[[CONVERT:.+]] = amdgpu.ext_packed_fp8 %[[EXTRACT]][0] : f8E5M2FNUZ to vector<2xf32>
-// CHECK: %[[EXTRACT2:.+]] = vector.extract %[[CONVERT]][0] : f32 from vector<2xf32>
-// CHECK: %[[RESULT:.+]] = vector.insert %[[EXTRACT2]], %[[CONST]] [] : f32 into vector<f32>
+// CHECK: %[[CONVERT:.+]] = amdgpu.ext_packed_fp8 %[[EXTRACT]][0] : f8E5M2FNUZ to f32
+// CHECK: %[[RESULT:.+]] = vector.insert %[[CONVERT]], %[[CONST]] [] : f32 into vector<f32>
 // CHECK: return %[[RESULT]] : vector<f32>
 func.func @vector_zero_d(%v: vector<f8E5M2FNUZ>) -> vector<f32> {
   %w = arith.extf %v : vector<f8E5M2FNUZ> to vector<f32>
@@ -54,9 +52,8 @@ func.func @vector_ext_short(%v: vector<2xf8E5M2FNUZ>) -> vector<2xf64> {
 // CHECK: [[FLOAT4:%.+]] = amdgpu.ext_packed_fp8 [[IN2]][1] : vector<4xf8E4M3FNUZ> to vector<2xf32>
 // CHECK: [[W4:%.+]] = vector.insert_strided_slice [[FLOAT4]], [[W3]] {offsets = [6], strides = [1]} : vector<2xf32> into vector<9xf32>
 // CHECK: [[IN3:%.+]] = vector.extract_strided_slice [[V]] {offsets = [8], sizes = [1], strides = [1]} : vector<9xf8E4M3FNUZ> to vector<1xf8E4M3FNUZ>
-// CHECK: [[FLOAT5:%.+]] = amdgpu.ext_packed_fp8 [[IN3]][0] : vector<1xf8E4M3FNUZ> to vector<2xf32>
-// CHECK: [[FLOAT6:%.+]] = vector.extract [[FLOAT5]][0] : f32 from vector<2xf32>
-// CHECK: [[W5:%.+]] = vector.insert [[FLOAT6]], [[W4]] [8] : f32 into vector<9xf32>
+// CHECK: [[FLOAT5:%.+]] = amdgpu.ext_packed_fp8 [[IN3]][0] : vector<1xf8E4M3FNUZ> to f32
+// CHECK: [[W5:%.+]] = vector.insert [[FLOAT5]], [[W4]] [8] : f32 into vector<9xf32>
 // CHECK: return [[W5]]
 func.func @vector_ext_long(%v: vector<9xf8E4M3FNUZ>) -> vector<9xf32> {
   %w = arith.extf %v : vector<9xf8E4M3FNUZ> to vector<9xf32>
@@ -142,28 +139,29 @@ func.func @vector_trunc_long_2d(%v: vector<1x9xf32>) -> vector<1x9xf8E4M3FNUZ> {
 // -----
 
 // CHECK-LABEL: func.func @vector_ext_long_2d
-// CHECK-SAME: ([[V:%.+]]: vector<1x9xf8E4M3FNUZ>)
-// CHECK: [[CST:%.+]] = arith.constant dense<0.000000e+00> : vector<9xf32>
-// CHECK: [[CAST:%.+]] = vector.shape_cast [[V]] : vector<1x9xf8E4M3FNUZ> to vector<9xf8E4M3FNUZ>
+// CHECK-SAME: ([[V:%.+]]: vector<1x11xf8E4M3FNUZ>)
+// CHECK: [[CST:%.+]] = arith.constant dense<0.000000e+00> : vector<11xf32>
+// CHECK: [[CAST:%.+]] = vector.shape_cast [[V]] : vector<1x11xf8E4M3FNUZ> to vector<11xf8E4M3FNUZ>
 // CHECK: [[V0:%.+]] = vector.extract_strided_slice [[CAST]] {offsets = [0], sizes = [4], strides = [1]}
 // CHECK: [[F0:%.+]] = amdgpu.ext_packed_fp8 [[V0]][0] : vector<4xf8E4M3FNUZ> to vector<2xf32>
-// CHECK: [[W0:%.+]] = vector.insert_strided_slice [[F0]], [[CST]] {offsets = [0], strides = [1]} : vector<2xf32> into vector<9xf32>
+// CHECK: [[W0:%.+]] = vector.insert_strided_slice [[F0]], [[CST]] {offsets = [0], strides = [1]} : vector<2xf32> into vector<11xf32>
 // CHECK: [[F1:%.+]] = amdgpu.ext_packed_fp8 [[V0]][1] : vector<4xf8E4M3FNUZ> to vector<2xf32>
-// CHECK: [[W1:%.+]] = vector.insert_strided_slice [[F1]], [[W0]] {offsets = [2], strides = [1]} : vector<2xf32> into vector<9xf32>
+// CHECK: [[W1:%.+]] = vector.insert_strided_slice [[F1]], [[W0]] {offsets = [2], strides = [1]} : vector<2xf32> into vector<11xf32>
 
-// CHECK: [[V1:%.+]] = vector.extract_strided_slice [[CAST]] {offsets = [4], sizes = [4], strides = [1]} : vector<9xf8E4M3FNUZ> to vector<4xf8E4M3FNUZ>
+// CHECK: [[V1:%.+]] = vector.extract_strided_slice [[CAST]] {offsets = [4], sizes = [4], strides = [1]} : vector<11xf8E4M3FNUZ> to vector<4xf8E4M3FNUZ>
 // CHECK: [[F2:%.+]] = amdgpu.ext_packed_fp8 [[V1]][0] : vector<4xf8E4M3FNUZ> to vector<2xf32>
-// CHECK: [[W2:%.+]] = vector.insert_strided_slice [[F2]], [[W1]] {offsets = [4], strides = [1]} : vector<2xf32> into vector<9xf32>
+// CHECK: [[W2:%.+]] = vector.insert_strided_slice [[F2]], [[W1]] {offsets = [4], strides = [1]} : vector<2xf32> into vector<11xf32>
 // CHECK: [[F3:%.+]] = amdgpu.ext_packed_fp8 [[V1]][1] : vector<4xf8E4M3FNUZ> to vector<2xf32>
-// CHECK: [[W3:%.+]] = vector.insert_strided_slice [[F3]], [[W2]] {offsets = [6], strides = [1]} : vector<2xf32> into vector<9xf32>
-
-// CHECK: [[V2:%.+]] = vector.extract_strided_slice [[CAST]] {offsets = [8], sizes = [1], strides = [1]} : vector<9xf8E4M3FNUZ> to vector<1xf8E4M3FNUZ>
-// CHECK: [[F4:%.+]] = amdgpu.ext_packed_fp8 [[V2]][0] : vector<1xf8E4M3FNUZ> to vector<2xf32>
-// CHECK: [[E0:%.+]] = vector.extract [[F4]][0] : f32 from vector<2xf32>
-// CHECK: [[W4:%.+]] = vector.insert [[E0]], [[W3]] [8] : f32 into vector<9xf32>
-// CHECK: [[CAST:%.+]] = vector.shape_cast [[W4]] : vector<9xf32> to vector<1x9xf32>
+// CHECK: [[W3:%.+]] = vector.insert_strided_slice [[F3]], [[W2]] {offsets = [6], strides = [1]} : vector<2xf32> into vector<11xf32>
+
+// CHECK: [[V2:%.+]] = vector.extract_strided_slice [[CAST]] {offsets = [8], sizes = [3], strides = [1]} : vector<11xf8E4M3FNUZ> to vector<3xf8E4M3FNUZ>
+// CHECK: [[F4:%.+]] = amdgpu.ext_packed_fp8 [[V2]][0] : vector<3xf8E4M3FNUZ> to vector<2xf32>
+// CHECK: [[W4:%.+]] = vector.insert_strided_slice [[F4]], [[W3]] {offsets = [8], strides = [1]} : vector<2xf32> into vector<11xf32>
+// CHECK: [[F5:%.+]] = amdgpu.ext_packed_fp8 [[V2]][2] : vector<3xf8E4M3FNUZ> to f32
+// CHECK: [[W5:%.+]] = vector.insert [[F5]], [[W4]] [10] : f32 into vector<11xf32>
+// CHECK: [[CAST:%.+]] = vector.shape_cast [[W5]] : vector<11xf32> to vector<1x11xf32>
 // CHECK: return [[CAST]]
-func.func @vector_ext_long_2d(%v: vector<1x9xf8E4M3FNUZ>) -> vector<1x9xf32> {
-  %w = arith.extf %v : vector<1x9xf8E4M3FNUZ> to vector<1x9xf32>
-  return %w : vector<1x9xf32>
+func.func @vector_ext_long_2d(%v: vector<1x11xf8E4M3FNUZ>) -> vector<1x11xf32> {
+  %w = arith.extf %v : vector<1x11xf8E4M3FNUZ> to vector<1x11xf32>
+  return %w : vector<1x11xf32>
 }



More information about the Mlir-commits mailing list