[Mlir-commits] [mlir] 0ea4fb9 - [AMD][ROCDL] Add packed conversions fp8/bf8->bf16 and fp8/bf8->fp32 in ROCDL dialect (#131850)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Fri Mar 21 07:49:53 PDT 2025
Author: Yi Qian
Date: 2025-03-21T14:49:50Z
New Revision: 0ea4fb92648b2aa7cbab486bb493e122b4dcc062
URL: https://github.com/llvm/llvm-project/commit/0ea4fb92648b2aa7cbab486bb493e122b4dcc062
DIFF: https://github.com/llvm/llvm-project/commit/0ea4fb92648b2aa7cbab486bb493e122b4dcc062.diff
LOG: [AMD][ROCDL] Add packed conversions fp8/bf8->bf16 and fp8/bf8->fp32 in ROCDL dialect (#131850)
- Add packed conversions fp8/bf8->bf16 for gfx950 and fp8/bf8->fp32 for
gfx942 in ROCDL dialect
- Update amdgpu.ext_packed_fp8 lowering to use ROCDL packed fp8/bf8->f32
conversions for vector target types and ROCDL scalar fp8/bf8->fp32 for
scalar target type.
---------
Co-authored-by: Jungwook Park <jungwook.park at amd.com>
Added:
Modified:
mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td
mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp
mlir/test/Conversion/AMDGPUToROCDL/8-bit-floats-ocp.mlir
mlir/test/Conversion/AMDGPUToROCDL/8-bit-floats.mlir
mlir/test/Conversion/ArithToAMDGPU/8-bit-floats-ocp.mlir
mlir/test/Conversion/ArithToAMDGPU/8-bit-floats.mlir
mlir/test/Dialect/AMDGPU/ops.mlir
mlir/test/Dialect/LLVMIR/rocdl.mlir
mlir/test/Target/LLVMIR/rocdl.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td b/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
index 3acc383923ca8..c0b3e5540b1df 100644
--- a/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
+++ b/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
@@ -86,10 +86,12 @@ def AMDGPU_ExtPackedFp8Op :
Arguments<(ins AnyTypeOf<[F8E5M2FNUZ, F8E4M3FNUZ, F8E5M2, F8E4M3FN,
VectorOfLengthAndType<[1, 2, 3, 4], [F8E5M2FNUZ, F8E4M3FNUZ, F8E5M2, F8E4M3FN]>]>:$source,
ConfinedAttr<I32Attr, [IntNonNegative, IntMaxValue<3>]>:$index)>,
- Results<(outs F32:$res)> {
- let summary = "Extend one of a vector of packed fp8 values to a float";
+ Results<(outs AnyTypeOf<[F32, FixedVectorOfLengthAndType<[2], [F32]>]>:$res)> {
+ let summary = "Extend an fp8 value to a float, or a pair of packed fp8 values to two floats";
+
let description = [{
- Extend the value `source[index]` to a 32-bit float and return it.
+ Extend `source[index]` to a 32-bit float or, for a `vector<2xf32>` result,
+ extend the two 8-bit floats in 16-bit half `index` of `source` to two floats.
This rather unusual signature arises from the fact that AMD GPUs cannot
easily work with sub 32-bit quantities, so the compiler intrinsics for
@@ -97,7 +99,7 @@ def AMDGPU_ExtPackedFp8Op :
this operation) take packed vectors of 4 such floats.
If the passed-in vector has fewer than four elements, or the input is scalar,
- the remaining values in the <4 x i8> will be filled with with
+ the remaining values in the <4 x i8> will be filled with
undefined values as needed.
}];
let assemblyFormat = [{
diff --git a/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td b/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td
index f194e70ee275b..9a433202e3149 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td
@@ -681,26 +681,26 @@ def ROCDL_CvtPkRtz:
}];
}
-def ROCDL_CvtScaleF32PkFp8F16 :
+def ROCDL_CvtScaleF32PkFp8F16Op :
ROCDL_IntrOp<"cvt.scalef32.pk.fp8.f16", [], [], [Pure], 1>,
Arguments<(ins ROCDL_V2I16Type: $old, ROCDL_V2F16Type: $src, F32: $scale, I1:$wordSel)> {
let summary = "Scale and convert f16 to packed fp8";
let description = [{
- Scale `src` by the exponent in `scale` then convert to packed fp8.
- Store the result in low/high word based on $wordSel, preserving the other word.
+ Scale `src` by the exponent in `scale`, then convert to packed fp8.
+ Store the result in low/high word of `old` based on $wordSel, preserving the other word.
}];
let assemblyFormat = [{
attr-dict $src `,` $scale `->` $old `[` $wordSel `]` `:` type($res)
}];
}
-def ROCDL_CvtScaleF32PkFp8Bf16 :
+def ROCDL_CvtScaleF32PkFp8Bf16Op :
ROCDL_IntrOp<"cvt.scalef32.pk.fp8.bf16", [], [], [Pure], 1>,
Arguments<(ins ROCDL_V2I16Type: $old, ROCDL_V2BF16Type: $src, F32: $scale, I1:$wordSel)> {
let summary = "Scale and convert packed bf16 to packed fp8";
let description = [{
- Scale `src` by the exponent in `scale` then convert to packed fp8.
- Store the result in low/high word based on $wordSel, preserving the other word.
+ Scale `src` by the exponent in `scale`, then convert to packed fp8.
+ Store the result in low/high word of `old` based on $wordSel, preserving the other word.
}];
let assemblyFormat = [{
attr-dict $src `,` $scale `->` $old `[` $wordSel `]` `:` type($res)
@@ -708,13 +708,13 @@ def ROCDL_CvtScaleF32PkFp8Bf16 :
}
-def ROCDL_CvtScaleF32PkBf8F16 :
+def ROCDL_CvtScaleF32PkBf8F16Op :
ROCDL_IntrOp<"cvt.scalef32.pk.bf8.f16", [], [], [Pure], 1>,
Arguments<(ins ROCDL_V2I16Type: $old, ROCDL_V2F16Type: $src, F32: $scale, I1:$wordSel)> {
let summary = "Scale and convert f16 to packed bf8";
let description = [{
- Scale `src` by the exponent in `scale` then convert to packed bf8.
- Store the result in low/high word based on $wordSel, preserving the other word.
+ Scale `src` by the exponent in `scale`, then convert to packed bf8.
+ Store the result in low/high word of `old` based on $wordSel, preserving the other word.
}];
let assemblyFormat = [{
attr-dict $src `,` $scale `->` $old `[` $wordSel `]` `:` type($res)
@@ -722,26 +722,26 @@ def ROCDL_CvtScaleF32PkBf8F16 :
}
-def ROCDL_CvtScaleF32PkBf8Bf16 :
+def ROCDL_CvtScaleF32PkBf8Bf16Op :
ROCDL_IntrOp<"cvt.scalef32.pk.bf8.bf16", [], [], [Pure], 1>,
Arguments<(ins ROCDL_V2I16Type: $old, ROCDL_V2BF16Type: $src, F32: $scale, I1:$wordSel)> {
let summary = "Scale and convert bf16 to packed bf8";
let description = [{
- Scale `src` by the exponent in `scale` then convert to packed bf8.
- Store the result in low/high word based on $wordSel, preserving the other word.
+ Scale `src` by the exponent in `scale`, then convert to packed bf8.
+ Store the result in low/high word of `old` based on $wordSel, preserving the other word.
}];
let assemblyFormat = [{
attr-dict $src `,` $scale `->` $old `[` $wordSel `]` `:` type($res)
}];
}
-def ROCDL_CvtScaleF32SrFp8F16 :
+def ROCDL_CvtScaleF32SrFp8F16Op :
ROCDL_IntrOp<"cvt.scalef32.sr.fp8.f16", [], [], [Pure], 1>,
Arguments<(ins I32:$old, F16:$src, I32:$seed, F32: $scale, I32:$byteSel)> {
let summary = "Scale and convert f16 to packed fp8 using stochastic rounding";
let description = [{
- Scale `src` by the exponent in `scale` then convert to packed p8 with stochastic rounding
- using seed data in `seed`. store into the `byteSel`th byte of `old`, preserving the others.
+ Scale `src` by the exponent in `scale`, then convert to packed fp8 with stochastic rounding
+ using seed data in `seed`. Store into the `byteSel`th byte of `old`, preserving the others.
}];
let assemblyFormat = [{
@@ -749,13 +749,13 @@ def ROCDL_CvtScaleF32SrFp8F16 :
}];
}
-def ROCDL_CvtScaleF32SrBf8F16 :
+def ROCDL_CvtScaleF32SrBf8F16Op :
ROCDL_IntrOp<"cvt.scalef32.sr.bf8.f16", [], [], [Pure], 1>,
Arguments<(ins I32:$old, F16:$src, I32:$seed, F32: $scale, I32:$byteSel)> {
let summary = "Scale and convert f16 to packed bf8 using stochastic rounding";
let description = [{
- Scale `src` by the exponent in `scale` then convert to packed bf8 with stochastic rounding
- using seed data in `seed`. store into the `byteSel`th byte of `old`, preserving the others.
+ Scale `src` by the exponent in `scale`, then convert to packed bf8 with stochastic rounding
+ using seed data in `seed`. Store into the `byteSel`th byte of `old`, preserving the others.
}];
let assemblyFormat = [{
@@ -763,13 +763,13 @@ def ROCDL_CvtScaleF32SrBf8F16 :
}];
}
-def ROCDL_CvtScaleF32SrFp8Bf16 :
+def ROCDL_CvtScaleF32SrFp8Bf16Op :
ROCDL_IntrOp<"cvt.scalef32.sr.fp8.bf16", [], [], [Pure], 1>,
Arguments<(ins I32:$old, BF16:$src, I32:$seed, F32: $scale, I32:$byteSel)> {
let summary = "Scale and convert packed bf16 to packed fp8 using stochastic rounding";
let description = [{
- Scale `src` by the exponent in `scale` then convert to packed fp8 with stochastic rounding
- using seed data in `seed`. store into the `byteSel`th byte of `old`, preserving the others.
+ Scale `src` by the exponent in `scale`, then convert to packed fp8 with stochastic rounding
+ using seed data in `seed`. Store into the `byteSel`th byte of `old`, preserving the others.
}];
let assemblyFormat = [{
@@ -777,13 +777,13 @@ def ROCDL_CvtScaleF32SrFp8Bf16 :
}];
}
-def ROCDL_CvtScaleF32SrBf8Bf16:
+def ROCDL_CvtScaleF32SrBf8Bf16Op :
ROCDL_IntrOp<"cvt.scalef32.sr.bf8.bf16", [], [], [Pure], 1>,
Arguments<(ins I32:$old, BF16:$src, I32:$seed, F32: $scale, I32:$byteSel)> {
let summary = "Scale and convert bf16 to packed fp8 using stochastic rounding";
let description = [{
- Scale `src` by the exponent in `scale` then convert to packed p8 with stochastic rounding
- using seed data in `seed`. store into the `byteSel`th byte of `old`, preserving the others.
+ Scale `src` by the exponent in `scale`, then convert to packed bf8 with stochastic rounding
+ using seed data in `seed`. Store into the `byteSel`th byte of `old`, preserving the others.
}];
let assemblyFormat = [{
@@ -791,48 +791,74 @@ def ROCDL_CvtScaleF32SrBf8Bf16:
}];
}
-def ROCDL_CvtScaleF32PkF16Fp8 :
+def ROCDL_CvtScaleF32PkF16Fp8Op :
ROCDL_IntrOp<"cvt.scalef32.pk.f16.fp8", [], [], [Pure], 1>,
Arguments<(ins I32:$src, F32: $scale, I1:$wordSel)> {
- let summary = "Scale and convert fp8 to packed f16";
- let description = [{ Scale `src` based on $wordSel by the exponent in `scale`
- then convert to packed f16.
+ let summary = "Convert fp8 to packed f16 and scale";
+ let description = [{ Convert `src` based on $wordSel to packed f16, then scale
+ the packed values by the exponent in `scale`.
}];
let assemblyFormat = [{
attr-dict $src `[` $wordSel `]` `,` $scale `:` type($res)
}];
}
-def ROCDL_CvtScaleF32PkF16Bf8 :
+def ROCDL_CvtScaleF32PkF16Bf8Op :
ROCDL_IntrOp<"cvt.scalef32.pk.f16.bf8", [], [], [Pure], 1>,
Arguments<(ins I32:$src, F32: $scale, I1:$wordSel)> {
- let summary = "Scale and convert bf8 to packed f16";
- let description = [{ Scale `src` based on $wordSel by the exponent in `scale`
- then convert to packed f16.
+ let summary = "Convert bf8 to packed f16 and scale";
+ let description = [{ Convert `src` based on $wordSel to packed f16, then scale
+ the packed values by the exponent in `scale`.
}];
let assemblyFormat = [{
attr-dict $src `[` $wordSel `]` `,` $scale `:` type($res)
}];
}
-def ROCDL_CvtScaleF16Fp8 :
+def ROCDL_CvtScaleF32PkBf16Fp8Op :
+ ROCDL_IntrOp<"cvt.scalef32.pk.bf16.fp8", [], [], [Pure], 1>,
+ Arguments<(ins I32:$src, F32: $scale, I1:$wordSel)> {
+ let summary = "Convert fp8 to packed bf16 and scale";
+ let description = [{ Convert `src` based on $wordSel to packed bf16, then scale
+ the packed values by the exponent in `scale`.
+ }];
+ let assemblyFormat = [{
+ attr-dict $src `[` $wordSel `]` `,` $scale `:` type($res)
+ }];
+}
+
+def ROCDL_CvtScaleF32PkBf16Bf8Op :
+ ROCDL_IntrOp<"cvt.scalef32.pk.bf16.bf8", [], [], [Pure], 1>,
+ Arguments<(ins I32:$src, F32: $scale, I1:$wordSel)> {
+ let summary = "Convert bf8 to packed bf16 and scale";
+ let description = [{ Convert `src` based on $wordSel to packed bf16, then scale
+ the packed values by the exponent in `scale`.
+ }];
+ let assemblyFormat = [{
+ attr-dict $src `[` $wordSel `]` `,` $scale `:` type($res)
+ }];
+}
+
+def ROCDL_CvtScaleF16Fp8Op :
ROCDL_IntrOp<"cvt.scalef32.f16.fp8", [], [], [Pure], 1>,
Arguments<(ins ROCDL_V2F16Type:$old, I32:$src, F32: $scale, I32:$byteSel, I1:$wordSel)> {
let summary = "Scale and convert fp8 to f16";
- let description = [{ Scale `src` based on $wordSel by the exponent in `scale`
- then convert to f16 store into the `byteSel`th byte of `old`, preserving the others.
+ let description = [{ Convert `src` based on $wordSel to f16, then scale the value
+ by the exponent in `scale`. Store the result into the `byteSel`th byte of `old`,
+ preserving the others.
}];
let assemblyFormat = [{
attr-dict $src `[` $wordSel `]` `,` $scale `->` $old `[` $byteSel `]` `:` type($res)
}];
}
-def ROCDL_CvtScaleF16Bf8 :
+def ROCDL_CvtScaleF16Bf8Op :
ROCDL_IntrOp<"cvt.scalef32.f16.bf8", [], [], [Pure], 1>,
Arguments<(ins ROCDL_V2F16Type:$old, I32:$src, F32: $scale, I32:$byteSel, I1:$wordSel)> {
let summary = "Scale and convert fp8 to f16";
- let description = [{ Scale `src` based on $wordSel by the exponent in `scale`
- then convert to f16 store into the `byteSel`th byte of `old`, preserving the others.
+ let description = [{ Convert `src` based on $wordSel to f16, then scale the value
+ by the exponent in `scale`. Store the result into the `byteSel`th byte of `old`,
+ preserving the others.
}];
let assemblyFormat = [{
attr-dict $src `[` $wordSel `]` `,` $scale `->` $old `[` $byteSel `]` `:` type($res)
@@ -842,25 +868,25 @@ def ROCDL_CvtScaleF16Bf8 :
//===---------------------------------------------------------------------===//
// 32-bit float intrinsics
//===---------------------------------------------------------------------===//
-def ROCDL_CvtScale32PkF32Fp8 :
+def ROCDL_CvtScaleF32PkF32Fp8Op :
ROCDL_IntrOp<"cvt.scalef32.pk.f32.fp8", [], [], [Pure], 1>,
Arguments<(ins I32:$src, F32: $scale, I1:$wordSel)> {
let summary = "Scale and convert packed fp8 to packed f32";
let description = [{
- Scale `src` by the exponent in `scale` then convert to packed fp32.
- Store the result in low/high word based on $wordSel, preserving the other word.
+ Convert `src` based on $wordSel to packed fp32, then scale the packed values by
+ the exponent in `scale`. Store the result in a vector.
}];
let assemblyFormat = [{
attr-dict $src `[` $wordSel `]` `,` $scale `:` type($res)
}];
}
-def ROCDL_CvtScale32PkF32Bf8 :
+def ROCDL_CvtScaleF32PkF32Bf8Op :
ROCDL_IntrOp<"cvt.scalef32.pk.f32.bf8", [], [], [Pure], 1>,
Arguments<(ins I32:$src, F32: $scale, I1:$wordSel)> {
let summary = "Scale and convert packed bf8 to packed f32";
let description = [{
- Scale `src` by the exponent in `scale` then convert to packed fp32.
- Store the result in low/high word based on $wordSel, preserving the other word.
+ Convert `src` based on $wordSel to packed fp32, then scale the packed values by
+ the exponent in `scale`. Store the result in a vector.
}];
let assemblyFormat = [{
attr-dict $src `[` $wordSel `]` `,` $scale `:` type($res)
@@ -869,7 +895,7 @@ def ROCDL_CvtScale32PkF32Bf8 :
//===---------------------------------------------------------------------===//
// 8-bit float scale intrinsics
//===---------------------------------------------------------------------===//
-def ROCDL_CvtScaleF32PkFp8F32:
+def ROCDL_CvtScaleF32PkFp8F32Op :
ROCDL_IntrOp<"cvt.scalef32.pk.fp8.f32", [], [], [Pure], 1>,
Arguments<(ins ROCDL_V2I16Type:$old, F32:$srcA, F32:$srcB, F32:$scale, I1:$wordSel)> {
let summary = "Scale and convert two f32's to packed fp8";
@@ -882,7 +908,7 @@ def ROCDL_CvtScaleF32PkFp8F32:
}];
}
-def ROCDL_CvtScaleF32PkBf8F32:
+def ROCDL_CvtScaleF32PkBf8F32Op :
ROCDL_IntrOp<"cvt.scalef32.pk.bf8.f32", [], [], [Pure], 1>,
Arguments<(ins ROCDL_V2I16Type:$old, F32:$srcA, F32:$srcB, F32: $scale, I1:$wordSel)> {
let summary = "Scale and convert two f32's to packed bf8";
@@ -895,7 +921,7 @@ def ROCDL_CvtScaleF32PkBf8F32:
}];
}
-def ROCDL_CvtScaleF32SrFp8F32:
+def ROCDL_CvtScaleF32SrFp8F32Op :
ROCDL_IntrOp<"cvt.scalef32.sr.fp8.f32", [], [], [Pure], 1>,
Arguments<(ins I32:$old, F32:$src, I32:$seed, F32: $scale, I32:$byteSel)> {
let summary = "Scale and convert f32 to fp8 using stochastic rounding";
@@ -909,7 +935,7 @@ def ROCDL_CvtScaleF32SrFp8F32:
}
-def ROCDL_CvtScaleF32SrBf8F32:
+def ROCDL_CvtScaleF32SrBf8F32Op :
ROCDL_IntrOp<"cvt.scalef32.sr.bf8.f32", [], [], [Pure], 1>,
Arguments<(ins I32:$old, F32:$src, I32:$seed, F32: $scale, I32:$byteSel)> {
let summary = "Scale and convert f32 to bf8 using stochastic rounding";
@@ -978,6 +1004,29 @@ def ROCDL_CvtScaleF32Fp8Op :
}];
}
+def ROCDL_CvtPkF32Fp8Op :
+ ROCDL_IntrOp<"cvt.pk.f32.fp8", [], [], [Pure], 1>,
+ Arguments<(ins I32:$src, I1:$wordSel)> {
+ let summary = "Convert packed fp8 to packed f32";
+ let description = [{
+ Convert `src` based on $wordSel to packed fp32.
+ }];
+ let assemblyFormat = [{
+ attr-dict $src `[` $wordSel `]` `:` type($res)
+ }];
+}
+
+def ROCDL_CvtPkF32Bf8Op :
+ ROCDL_IntrOp<"cvt.pk.f32.bf8", [], [], [Pure], 1>,
+ Arguments<(ins I32:$src, I1:$wordSel)> {
+ let summary = "Convert packed bf8 to packed f32";
+ let description = [{
+ Convert `src` based on $wordSel to packed fp32.
+ }];
+ let assemblyFormat = [{
+ attr-dict $src `[` $wordSel `]` `:` type($res)
+ }];
+}
def ROCDL_CvtPkBf8F32Op :
ROCDL_IntrOp<"cvt.pk.bf8.f32", [], [], [Pure], 1>,
diff --git a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
index 949424db7c4d6..3acd470cff7f5 100644
--- a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
+++ b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
@@ -959,6 +959,7 @@ LogicalResult ExtPackedFp8OpLowering::matchAndRewrite(
Value source = adaptor.getSource();
auto sourceVecType = dyn_cast<VectorType>(op.getSource().getType());
+ auto resultVecType = dyn_cast<VectorType>(op.getResult().getType());
Type sourceElemType = getElementTypeOrSelf(op.getSource());
// Extend to a v4i8
if (!sourceVecType || sourceVecType.getNumElements() < 4) {
@@ -977,13 +978,24 @@ LogicalResult ExtPackedFp8OpLowering::matchAndRewrite(
source = longVec;
}
Value i32Source = rewriter.create<LLVM::BitcastOp>(loc, i32, source);
- Value wordSel = createI32Constant(rewriter, loc, op.getIndex());
- if (typeIsExpectedBf8ForChipset(chipset, sourceElemType)) {
- rewriter.replaceOpWithNewOp<ROCDL::CvtF32Bf8Op>(op, f32, i32Source,
- wordSel);
- } else if (typeIsExpectedFp8ForChipset(chipset, sourceElemType)) {
- rewriter.replaceOpWithNewOp<ROCDL::CvtF32Fp8Op>(op, f32, i32Source,
- wordSel);
+ if (resultVecType) {
+ Value wordSel = createI1Constant(rewriter, loc, op.getIndex());
+ if (typeIsExpectedBf8ForChipset(chipset, sourceElemType)) {
+ rewriter.replaceOpWithNewOp<ROCDL::CvtPkF32Bf8Op>(op, f32, i32Source,
+ wordSel);
+ } else if (typeIsExpectedFp8ForChipset(chipset, sourceElemType)) {
+ rewriter.replaceOpWithNewOp<ROCDL::CvtPkF32Fp8Op>(op, f32, i32Source,
+ wordSel);
+ }
+ } else {
+ Value byteSel = createI32Constant(rewriter, loc, op.getIndex());
+ if (typeIsExpectedBf8ForChipset(chipset, sourceElemType)) {
+ rewriter.replaceOpWithNewOp<ROCDL::CvtF32Bf8Op>(op, f32, i32Source,
+ byteSel);
+ } else if (typeIsExpectedFp8ForChipset(chipset, sourceElemType)) {
+ rewriter.replaceOpWithNewOp<ROCDL::CvtF32Fp8Op>(op, f32, i32Source,
+ byteSel);
+ }
}
return success();
}
diff --git a/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp b/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp
index 27be54728c1a1..3596b3235a631 100644
--- a/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp
+++ b/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp
@@ -83,14 +83,15 @@ static bool isSupportedF8(Type elementType, Chipset chipset) {
return false;
}
-static Value castF32To(Type elementType, Value f32, Location loc,
+static Value castF32To(Type desType, Value f32, Location loc,
PatternRewriter &rewriter) {
+ Type elementType = getElementTypeOrSelf(desType);
if (elementType.isF32())
return f32;
if (elementType.getIntOrFloatBitWidth() < 32)
- return rewriter.create<arith::TruncFOp>(loc, elementType, f32);
+ return rewriter.create<arith::TruncFOp>(loc, desType, f32);
if (elementType.getIntOrFloatBitWidth() > 32)
- return rewriter.create<arith::ExtFOp>(loc, elementType, f32);
+ return rewriter.create<arith::ExtFOp>(loc, desType, f32);
llvm_unreachable("The only 32-bit float type is f32");
}
@@ -110,6 +111,7 @@ ExtFOnFloat8RewritePattern::matchAndRewrite(arith::ExtFOp op,
Location loc = op.getLoc();
Value in = op.getIn();
Type outElemType = getElementTypeOrSelf(op.getOut().getType());
+ VectorType extResType = VectorType::get(2, rewriter.getF32Type());
if (!inVecType) {
Value asFloat = rewriter.create<amdgpu::ExtPackedFp8Op>(
loc, rewriter.getF32Type(), in, 0);
@@ -150,11 +152,20 @@ ExtFOnFloat8RewritePattern::matchAndRewrite(arith::ExtFOp op,
int64_t elemsThisOp = std::min(numElements, i + 4) - i;
Value inSlice = rewriter.create<vector::ExtractStridedSliceOp>(
loc, in, i, elemsThisOp, 1);
- for (int64_t j = 0; j < elemsThisOp; ++j) {
- Value asFloat = rewriter.create<amdgpu::ExtPackedFp8Op>(
- loc, rewriter.getF32Type(), inSlice, j);
- Value asType = castF32To(outElemType, asFloat, loc, rewriter);
- result = rewriter.create<vector::InsertOp>(loc, asType, result, i + j);
+ for (int64_t j = 0; j < elemsThisOp; j += 2) {
+ if (i + j + 1 < numElements) { // Convert two 8-bit elements
+ Value asFloats = rewriter.create<amdgpu::ExtPackedFp8Op>(
+ loc, extResType, inSlice, j / 2);
+ Type desType = VectorType::get(2, outElemType);
+ Value asType = castF32To(desType, asFloats, loc, rewriter);
+ result = rewriter.create<vector::InsertStridedSliceOp>(
+ loc, asType, result, i + j, 1);
+ } else { // Convert a 8-bit element
+ Value asFloat = rewriter.create<amdgpu::ExtPackedFp8Op>(
+ loc, rewriter.getF32Type(), inSlice, j / 2 * 2);
+ Value asType = castF32To(outElemType, asFloat, loc, rewriter);
+ result = rewriter.create<vector::InsertOp>(loc, asType, result, i + j);
+ }
}
}
diff --git a/mlir/test/Conversion/AMDGPUToROCDL/8-bit-floats-ocp.mlir b/mlir/test/Conversion/AMDGPUToROCDL/8-bit-floats-ocp.mlir
index 70775a603e54d..ea0c3afbd9021 100644
--- a/mlir/test/Conversion/AMDGPUToROCDL/8-bit-floats-ocp.mlir
+++ b/mlir/test/Conversion/AMDGPUToROCDL/8-bit-floats-ocp.mlir
@@ -9,7 +9,7 @@
// CHECK: [[CAST:%.+]] = llvm.bitcast [[VEC]] : vector<4xi8> to i32
// CHECK: [[C0_2:%.+]] = llvm.mlir.constant(0 : i32) : i32
// CHECK: [[EXT:%.+]] = rocdl.cvt.f32.bf8 [[CAST]]{{\[}}[[C0_2]]] : f32
-// CHECK: return [[EXT]]
+// CHECK: return [[EXT]] : f32
func.func @ext_scalar(%v: f8E5M2) -> f32 {
%ret = amdgpu.ext_packed_fp8 %v[0] : f8E5M2 to f32
func.return %ret : f32
@@ -27,7 +27,7 @@ func.func @ext_scalar(%v: f8E5M2) -> f32 {
// CHECK: [[CAST:%.+]] = llvm.bitcast [[VEC_1]] : vector<4xi8> to i32
// CHECK: [[C1_2:%.+]] = llvm.mlir.constant(1 : i32) : i32
// CHECK: [[EXT:%.+]] = rocdl.cvt.f32.fp8 [[CAST]]{{\[}}[[C1_2]]] : f32
-// CHECK: return [[EXT]]
+// CHECK: return [[EXT]] : f32
func.func @ext_short_vec(%v: vector<2xf8E4M3FN>) -> f32 {
%ret = amdgpu.ext_packed_fp8 %v[1] : vector<2xf8E4M3FN> to f32
func.return %ret : f32
@@ -39,12 +39,40 @@ func.func @ext_short_vec(%v: vector<2xf8E4M3FN>) -> f32 {
// CHECK: [[C3:%.+]] = llvm.mlir.constant(3 : i32) : i32
// CHECK: [[EXT:%.+]] = rocdl.cvt.f32.fp8 [[CAST]]{{\[}}[[C3]]] : f32
// CHECK: return [[EXT]] : f32
-
func.func @ext_full_vec(%v: vector<4xf8E4M3FN>) -> f32 {
%ret = amdgpu.ext_packed_fp8 %v[3] : vector<4xf8E4M3FN> to f32
func.return %ret : f32
}
+// CHECK-LABEL: func @ext_packed_2xfp8
+// CHECK: [[V:%.+]] = builtin.unrealized_conversion_cast %{{.+}} : vector<2xf8E4M3FN> to vector<2xi8>
+// CHECK-DAG: [[UNDEF:%.+]] = llvm.mlir.undef : vector<4xi8>
+// CHECK-DAG: [[C0:%.+]] = llvm.mlir.constant(0 : i32) : i32
+// CHECK: [[ELEM_0:%.+]] = llvm.extractelement [[V]]{{\[}}[[C0]] : i32] : vector<2xi8>
+// CHECK: [[VEC_0:%.+]] = llvm.insertelement [[ELEM_0]], [[UNDEF]]{{\[}}[[C0]] : i32] : vector<4xi8>
+// CHECK: [[C1_1:%.+]] = llvm.mlir.constant(1 : i32) : i32
+// CHECK: [[ELEM_1:%.+]] = llvm.extractelement [[V]]{{\[}}[[C1_1]] : i32] : vector<2xi8>
+// CHECK: [[VEC_1:%.+]] = llvm.insertelement [[ELEM_1]], [[VEC_0]]{{\[}}[[C1_1]] : i32] : vector<4xi8>
+// CHECK: [[CAST:%.+]] = llvm.bitcast [[VEC_1]] : vector<4xi8> to i32
+// CHECK: [[C1_2:%.+]] = llvm.mlir.constant(false) : i1
+// CHECK: [[EXT:%.+]] = rocdl.cvt.pk.f32.fp8 [[CAST]]{{\[}}[[C1_2]]] : vector<2xf32>
+// CHECK: return [[EXT]]
+func.func @ext_packed_2xfp8(%v: vector<2xf8E4M3FN>) -> vector<2xf32> {
+ %ret = amdgpu.ext_packed_fp8 %v[0] : vector<2xf8E4M3FN> to vector<2xf32>
+ func.return %ret : vector<2xf32>
+}
+
+// CHECK-LABEL: func @ext_packed_4xfp8
+// CHECK: [[V:%.+]] = builtin.unrealized_conversion_cast %{{.+}} : vector<4xf8E4M3FN> to vector<4xi8>
+// CHECK: [[CAST:%.+]] = llvm.bitcast [[V]] : vector<4xi8> to i32
+// CHECK: [[C3:%.+]] = llvm.mlir.constant(true) : i1
+// CHECK: [[EXT:%.+]] = rocdl.cvt.pk.f32.fp8 [[CAST]]{{\[}}[[C3]]] : vector<2xf32>
+// CHECK: return [[EXT]] : vector<2xf32>
+func.func @ext_packed_4xfp8(%v: vector<4xf8E4M3FN>) -> vector<2xf32> {
+ %ret = amdgpu.ext_packed_fp8 %v[1] : vector<4xf8E4M3FN> to vector<2xf32>
+ func.return %ret : vector<2xf32>
+}
+
// CHECK-LABEL: func @packed_trunc
// CHECK-SAME: ([[V:%.+]]: f32)
// CHECK: [[V2:%.+]] = llvm.mlir.undef : f32
diff --git a/mlir/test/Conversion/AMDGPUToROCDL/8-bit-floats.mlir b/mlir/test/Conversion/AMDGPUToROCDL/8-bit-floats.mlir
index a313aaffdf5cc..219f822ca9a1c 100644
--- a/mlir/test/Conversion/AMDGPUToROCDL/8-bit-floats.mlir
+++ b/mlir/test/Conversion/AMDGPUToROCDL/8-bit-floats.mlir
@@ -8,7 +8,7 @@
// CHECK: [[CAST:%.+]] = llvm.bitcast [[VEC]] : vector<4xi8> to i32
// CHECK: [[C0_2:%.+]] = llvm.mlir.constant(0 : i32) : i32
// CHECK: [[EXT:%.+]] = rocdl.cvt.f32.bf8 [[CAST]]{{\[}}[[C0_2]]] : f32
-// CHECK: return [[EXT]]
+// CHECK: return [[EXT]] : f32
func.func @ext_scalar(%v: f8E5M2FNUZ) -> f32 {
%ret = amdgpu.ext_packed_fp8 %v[0] : f8E5M2FNUZ to f32
func.return %ret : f32
@@ -26,24 +26,52 @@ func.func @ext_scalar(%v: f8E5M2FNUZ) -> f32 {
// CHECK: [[CAST:%.+]] = llvm.bitcast [[VEC_1]] : vector<4xi8> to i32
// CHECK: [[C1_2:%.+]] = llvm.mlir.constant(1 : i32) : i32
// CHECK: [[EXT:%.+]] = rocdl.cvt.f32.fp8 [[CAST]]{{\[}}[[C1_2]]] : f32
-// CHECK: return [[EXT]]
+// CHECK: return [[EXT]] : f32
func.func @ext_short_vec(%v: vector<2xf8E4M3FNUZ>) -> f32 {
%ret = amdgpu.ext_packed_fp8 %v[1] : vector<2xf8E4M3FNUZ> to f32
func.return %ret : f32
}
-// CHECK-LABEL: func @ext_full_vec(
+// CHECK-LABEL: func @ext_full_vec
// CHECK: [[V:%.+]] = builtin.unrealized_conversion_cast %{{.+}} : vector<4xf8E4M3FNUZ> to vector<4xi8>
// CHECK: [[CAST:%.+]] = llvm.bitcast [[V]] : vector<4xi8> to i32
// CHECK: [[C3:%.+]] = llvm.mlir.constant(3 : i32) : i32
// CHECK: [[EXT:%.+]] = rocdl.cvt.f32.fp8 [[CAST]]{{\[}}[[C3]]] : f32
// CHECK: return [[EXT]] : f32
-
func.func @ext_full_vec(%v: vector<4xf8E4M3FNUZ>) -> f32 {
%ret = amdgpu.ext_packed_fp8 %v[3] : vector<4xf8E4M3FNUZ> to f32
func.return %ret : f32
}
+// CHECK-LABEL: func @ext_packed_2xfp8
+// CHECK: [[V:%.+]] = builtin.unrealized_conversion_cast %{{.+}} : vector<2xf8E4M3FNUZ> to vector<2xi8>
+// CHECK-DAG: [[UNDEF:%.+]] = llvm.mlir.undef : vector<4xi8>
+// CHECK-DAG: [[C0:%.+]] = llvm.mlir.constant(0 : i32) : i32
+// CHECK: [[ELEM_0:%.+]] = llvm.extractelement [[V]]{{\[}}[[C0]] : i32] : vector<2xi8>
+// CHECK: [[VEC_0:%.+]] = llvm.insertelement [[ELEM_0]], [[UNDEF]]{{\[}}[[C0]] : i32] : vector<4xi8>
+// CHECK: [[C1_1:%.+]] = llvm.mlir.constant(1 : i32) : i32
+// CHECK: [[ELEM_1:%.+]] = llvm.extractelement [[V]]{{\[}}[[C1_1]] : i32] : vector<2xi8>
+// CHECK: [[VEC_1:%.+]] = llvm.insertelement [[ELEM_1]], [[VEC_0]]{{\[}}[[C1_1]] : i32] : vector<4xi8>
+// CHECK: [[CAST:%.+]] = llvm.bitcast [[VEC_1]] : vector<4xi8> to i32
+// CHECK: [[C1_2:%.+]] = llvm.mlir.constant(false) : i1
+// CHECK: [[EXT:%.+]] = rocdl.cvt.pk.f32.fp8 [[CAST]]{{\[}}[[C1_2]]] : vector<2xf32>
+// CHECK: return [[EXT]]
+func.func @ext_packed_2xfp8(%v: vector<2xf8E4M3FNUZ>) -> vector<2xf32> {
+ %ret = amdgpu.ext_packed_fp8 %v[0] : vector<2xf8E4M3FNUZ> to vector<2xf32>
+ func.return %ret : vector<2xf32>
+}
+
+// CHECK-LABEL: func @ext_packed_4xfp8(
+// CHECK: [[V:%.+]] = builtin.unrealized_conversion_cast %{{.+}} : vector<4xf8E4M3FNUZ> to vector<4xi8>
+// CHECK: [[CAST:%.+]] = llvm.bitcast [[V]] : vector<4xi8> to i32
+// CHECK: [[C3:%.+]] = llvm.mlir.constant(true) : i1
+// CHECK: [[EXT:%.+]] = rocdl.cvt.pk.f32.fp8 [[CAST]]{{\[}}[[C3]]] : vector<2xf32>
+// CHECK: return [[EXT]] : vector<2xf32>
+func.func @ext_packed_4xfp8(%v: vector<4xf8E4M3FNUZ>) -> vector<2xf32> {
+ %ret = amdgpu.ext_packed_fp8 %v[1] : vector<4xf8E4M3FNUZ> to vector<2xf32>
+ func.return %ret : vector<2xf32>
+}
+
// CHECK-LABEL: func @packed_trunc
// CHECK-SAME: ([[V:%.+]]: f32)
// CHECK: [[V2:%.+]] = llvm.mlir.undef : f32
diff --git a/mlir/test/Conversion/ArithToAMDGPU/8-bit-floats-ocp.mlir b/mlir/test/Conversion/ArithToAMDGPU/8-bit-floats-ocp.mlir
index 0e7f58c9e6749..7fb5fbfe0c89e 100644
--- a/mlir/test/Conversion/ArithToAMDGPU/8-bit-floats-ocp.mlir
+++ b/mlir/test/Conversion/ArithToAMDGPU/8-bit-floats-ocp.mlir
@@ -1,6 +1,6 @@
// RUN: mlir-opt --split-input-file %s -convert-arith-to-amdgpu="chipset=gfx950" | FileCheck %s
// RUN: mlir-opt --split-input-file %s -convert-arith-to-amdgpu="chipset=gfx1200" | FileCheck %s
-
+
// CHECK-LABEL: func.func @scalar_ext
// CHECK-SAME: ([[V:%.+]]: f8E5M2)
// CHECK: [[FLOAT:%.+]] = amdgpu.ext_packed_fp8 [[V]][0] : f8E5M2 to f32
@@ -17,14 +17,9 @@ func.func @scalar_ext(%v: f8E5M2) -> f16 {
// CHECK-LABEL: func.func @vector_ext_short
// CHECK-SAME: ([[V:%.+]]: vector<2xf8E5M2>)
-// CHECK-DAG: [[ZEROES:%.+]] = arith.constant dense<0.000000e+00> : vector<2xf64>
-// CHECK: [[FLOAT0:%.+]] = amdgpu.ext_packed_fp8 [[V]][0] : vector<2xf8E5M2> to f32
-// CHECK: [[EXT0:%.+]] = arith.extf [[FLOAT0]] : f32 to f64
-// CHECK: [[W0:%.+]] = vector.insert [[EXT0]], [[ZEROES]] [0]
-// CHECK: [[FLOAT1:%.+]] = amdgpu.ext_packed_fp8 [[V]][1] : vector<2xf8E5M2> to f32
-// CHECK: [[EXT1:%.+]] = arith.extf [[FLOAT1]]
-// CHECK: [[W1:%.+]] = vector.insert [[EXT1]], [[W0]] [1]
-// CHECK: return [[W1]] : vector<2xf64>
+// CHECK: [[FLOAT0:%.+]] = amdgpu.ext_packed_fp8 [[V]][0] : vector<2xf8E5M2> to vector<2xf32>
+// CHECK: [[EXT:%.+]] = arith.extf [[FLOAT0]] : vector<2xf32> to vector<2xf64>
+// CHECK: return [[EXT]] : vector<2xf64>
func.func @vector_ext_short(%v: vector<2xf8E5M2>) -> vector<2xf64> {
%w = arith.extf %v : vector<2xf8E5M2> to vector<2xf64>
@@ -35,30 +30,21 @@ func.func @vector_ext_short(%v: vector<2xf8E5M2>) -> vector<2xf64> {
// CHECK-LABEL: func.func @vector_ext_long
// CHECK-SAME: ([[V:%.+]]: vector<9xf8E4M3FN>)
-// CHECK: [[V0:%.+]] = vector.extract_strided_slice [[V]] {offsets = [0], sizes = [4], strides = [1]}
-// CHECK: [[F0:%.+]] = amdgpu.ext_packed_fp8 [[V0]][0]
-// CHECK: [[W0:%.+]] = vector.insert [[F0]]
-// CHECK: [[F1:%.+]] = amdgpu.ext_packed_fp8 [[V0]][1]
-// CHECK: [[W1:%.+]] = vector.insert [[F1]], [[W0]]
-// CHECK: [[F2:%.+]] = amdgpu.ext_packed_fp8 [[V0]][2]
-// CHECK: [[W2:%.+]] = vector.insert [[F2]], [[W1]]
-// CHECK: [[F3:%.+]] = amdgpu.ext_packed_fp8 [[V0]][3]
-// CHECK: [[W3:%.+]] = vector.insert [[F3]], [[W2]]
-
-// CHECK: [[V1:%.+]] = vector.extract_strided_slice [[V]] {offsets = [4], sizes = [4], strides = [1]} : vector<9xf8E4M3FN> to vector<4xf8E4M3FN>
-// CHECK: [[F4:%.+]] = amdgpu.ext_packed_fp8 [[V1]][0]
-// CHECK: [[W4:%.+]] = vector.insert [[F4]], [[W3]]
-// CHECK: [[F5:%.+]] = amdgpu.ext_packed_fp8 [[V1]][1]
-// CHECK: [[W5:%.+]] = vector.insert [[F5]], [[W4]]
-// CHECK: [[F6:%.+]] = amdgpu.ext_packed_fp8 [[V1]][2]
-// CHECK: [[W6:%.+]] = vector.insert [[F6]], [[W5]]
-// CHECK: [[F7:%.+]] = amdgpu.ext_packed_fp8 [[V1]][3]
-// CHECK: [[W7:%.+]] = vector.insert [[F7]], [[W6]]
-
-// CHECK: [[V2:%.+]] = vector.extract_strided_slice [[V]] {offsets = [8], sizes = [1], strides = [1]} : vector<9xf8E4M3FN> to vector<1xf8E4M3FN>
-// CHECK: [[F8:%.+]] = amdgpu.ext_packed_fp8 [[V2]][0]
-// CHECK: [[W8:%.+]] = vector.insert [[F8]], [[W7]]
-// CHECK: return [[W8]]
+// CHECK: [[W0:%.+]] = arith.constant dense<0.000000e+00> : vector<9xf32>
+// CHECK: [[IN1:%.+]] = vector.extract_strided_slice [[V]] {offsets = [0], sizes = [4], strides = [1]} : vector<9xf8E4M3FN> to vector<4xf8E4M3FN>
+// CHECK: [[FLOAT1:%.+]] = amdgpu.ext_packed_fp8 [[IN1]][0] : vector<4xf8E4M3FN> to vector<2xf32>
+// CHECK: [[W1:%.+]] = vector.insert_strided_slice [[FLOAT1]], [[W0]] {offsets = [0], strides = [1]} : vector<2xf32> into vector<9xf32>
+// CHECK: [[FLOAT2:%.+]] = amdgpu.ext_packed_fp8 [[IN1]][1] : vector<4xf8E4M3FN> to vector<2xf32>
+// CHECK: [[W2:%.+]] = vector.insert_strided_slice [[FLOAT2]], [[W1]] {offsets = [2], strides = [1]} : vector<2xf32> into vector<9xf32>
+// CHECK: [[IN2:%.+]] = vector.extract_strided_slice [[V]] {offsets = [4], sizes = [4], strides = [1]} : vector<9xf8E4M3FN> to vector<4xf8E4M3FN>
+// CHECK: [[FLOAT3:%.+]] = amdgpu.ext_packed_fp8 [[IN2]][0] : vector<4xf8E4M3FN> to vector<2xf32>
+// CHECK: [[W3:%.+]] = vector.insert_strided_slice [[FLOAT3]], [[W2]] {offsets = [4], strides = [1]} : vector<2xf32> into vector<9xf32>
+// CHECK: [[FLOAT4:%.+]] = amdgpu.ext_packed_fp8 [[IN2]][1] : vector<4xf8E4M3FN> to vector<2xf32>
+// CHECK: [[W4:%.+]] = vector.insert_strided_slice [[FLOAT4]], [[W3]] {offsets = [6], strides = [1]} : vector<2xf32> into vector<9xf32>
+// CHECK: [[IN3:%.+]] = vector.extract_strided_slice [[V]] {offsets = [8], sizes = [1], strides = [1]} : vector<9xf8E4M3FN> to vector<1xf8E4M3FN>
+// CHECK: [[FLOAT5:%.+]] = amdgpu.ext_packed_fp8 [[IN3]][0] : vector<1xf8E4M3FN> to f32
+// CHECK: [[W5:%.+]] = vector.insert [[FLOAT5]], [[W4]] [8] : f32 into vector<9xf32>
+// CHECK: return [[W5]]
func.func @vector_ext_long(%v: vector<9xf8E4M3FN>) -> vector<9xf32> {
%w = arith.extf %v : vector<9xf8E4M3FN> to vector<9xf32>
return %w : vector<9xf32>
@@ -143,34 +129,29 @@ func.func @vector_trunc_long_2d(%v: vector<1x9xf32>) -> vector<1x9xf8E4M3FN> {
// -----
// CHECK-LABEL: func.func @vector_ext_long_2d
-// CHECK-SAME: ([[V:%.+]]: vector<1x9xf8E4M3FN>)
-// CHECK: [[CAST:%.+]] = vector.shape_cast [[V]] : vector<1x9xf8E4M3FN> to vector<9xf8E4M3FN>
+// CHECK-SAME: ([[V:%.+]]: vector<1x11xf8E4M3FN>)
+// CHECK: [[CST:%.+]] = arith.constant dense<0.000000e+00> : vector<11xf32>
+// CHECK: [[CAST:%.+]] = vector.shape_cast [[V]] : vector<1x11xf8E4M3FN> to vector<11xf8E4M3FN>
// CHECK: [[V0:%.+]] = vector.extract_strided_slice [[CAST]] {offsets = [0], sizes = [4], strides = [1]}
-// CHECK: [[F0:%.+]] = amdgpu.ext_packed_fp8 [[V0]][0]
-// CHECK: [[W0:%.+]] = vector.insert [[F0]]
-// CHECK: [[F1:%.+]] = amdgpu.ext_packed_fp8 [[V0]][1]
-// CHECK: [[W1:%.+]] = vector.insert [[F1]], [[W0]]
-// CHECK: [[F2:%.+]] = amdgpu.ext_packed_fp8 [[V0]][2]
-// CHECK: [[W2:%.+]] = vector.insert [[F2]], [[W1]]
-// CHECK: [[F3:%.+]] = amdgpu.ext_packed_fp8 [[V0]][3]
-// CHECK: [[W3:%.+]] = vector.insert [[F3]], [[W2]]
-
-// CHECK: [[V1:%.+]] = vector.extract_strided_slice [[CAST]] {offsets = [4], sizes = [4], strides = [1]} : vector<9xf8E4M3FN> to vector<4xf8E4M3FN>
-// CHECK: [[F4:%.+]] = amdgpu.ext_packed_fp8 [[V1]][0]
-// CHECK: [[W4:%.+]] = vector.insert [[F4]], [[W3]]
-// CHECK: [[F5:%.+]] = amdgpu.ext_packed_fp8 [[V1]][1]
-// CHECK: [[W5:%.+]] = vector.insert [[F5]], [[W4]]
-// CHECK: [[F6:%.+]] = amdgpu.ext_packed_fp8 [[V1]][2]
-// CHECK: [[W6:%.+]] = vector.insert [[F6]], [[W5]]
-// CHECK: [[F7:%.+]] = amdgpu.ext_packed_fp8 [[V1]][3]
-// CHECK: [[W7:%.+]] = vector.insert [[F7]], [[W6]]
-
-// CHECK: [[V2:%.+]] = vector.extract_strided_slice [[CAST]] {offsets = [8], sizes = [1], strides = [1]} : vector<9xf8E4M3FN> to vector<1xf8E4M3FN>
-// CHECK: [[F8:%.+]] = amdgpu.ext_packed_fp8 [[V2]][0]
-// CHECK: [[W8:%.+]] = vector.insert [[F8]], [[W7]]
-// CHECK: [[CAST:%.+]] = vector.shape_cast [[W8]] : vector<9xf32> to vector<1x9xf32>
+// CHECK: [[F0:%.+]] = amdgpu.ext_packed_fp8 [[V0]][0] : vector<4xf8E4M3FN> to vector<2xf32>
+// CHECK: [[W0:%.+]] = vector.insert_strided_slice [[F0]], [[CST]] {offsets = [0], strides = [1]} : vector<2xf32> into vector<11xf32>
+// CHECK: [[F1:%.+]] = amdgpu.ext_packed_fp8 [[V0]][1] : vector<4xf8E4M3FN> to vector<2xf32>
+// CHECK: [[W1:%.+]] = vector.insert_strided_slice [[F1]], [[W0]] {offsets = [2], strides = [1]} : vector<2xf32> into vector<11xf32>
+
+// CHECK: [[V1:%.+]] = vector.extract_strided_slice [[CAST]] {offsets = [4], sizes = [4], strides = [1]} : vector<11xf8E4M3FN> to vector<4xf8E4M3FN>
+// CHECK: [[F2:%.+]] = amdgpu.ext_packed_fp8 [[V1]][0] : vector<4xf8E4M3FN> to vector<2xf32>
+// CHECK: [[W2:%.+]] = vector.insert_strided_slice [[F2]], [[W1]] {offsets = [4], strides = [1]} : vector<2xf32> into vector<11xf32>
+// CHECK: [[F3:%.+]] = amdgpu.ext_packed_fp8 [[V1]][1] : vector<4xf8E4M3FN> to vector<2xf32>
+// CHECK: [[W3:%.+]] = vector.insert_strided_slice [[F3]], [[W2]] {offsets = [6], strides = [1]} : vector<2xf32> into vector<11xf32>
+
+// CHECK: [[V2:%.+]] = vector.extract_strided_slice [[CAST]] {offsets = [8], sizes = [3], strides = [1]} : vector<11xf8E4M3FN> to vector<3xf8E4M3FN>
+// CHECK: [[F4:%.+]] = amdgpu.ext_packed_fp8 [[V2]][0] : vector<3xf8E4M3FN> to vector<2xf32>
+// CHECK: [[W4:%.+]] = vector.insert_strided_slice [[F4]], [[W3]] {offsets = [8], strides = [1]} : vector<2xf32> into vector<11xf32>
+// CHECK: [[F5:%.+]] = amdgpu.ext_packed_fp8 [[V2]][2] : vector<3xf8E4M3FN> to f32
+// CHECK: [[W5:%.+]] = vector.insert [[F5]], [[W4]] [10] : f32 into vector<11xf32>
+// CHECK: [[CAST:%.+]] = vector.shape_cast [[W5]] : vector<11xf32> to vector<1x11xf32>
// CHECK: return [[CAST]]
-func.func @vector_ext_long_2d(%v: vector<1x9xf8E4M3FN>) -> vector<1x9xf32> {
- %w = arith.extf %v : vector<1x9xf8E4M3FN> to vector<1x9xf32>
- return %w : vector<1x9xf32>
+func.func @vector_ext_long_2d(%v: vector<1x11xf8E4M3FN>) -> vector<1x11xf32> {
+ %w = arith.extf %v : vector<1x11xf8E4M3FN> to vector<1x11xf32>
+ return %w : vector<1x11xf32>
}
diff --git a/mlir/test/Conversion/ArithToAMDGPU/8-bit-floats.mlir b/mlir/test/Conversion/ArithToAMDGPU/8-bit-floats.mlir
index 6bb5b9771c015..59ed6bd95ae8b 100644
--- a/mlir/test/Conversion/ArithToAMDGPU/8-bit-floats.mlir
+++ b/mlir/test/Conversion/ArithToAMDGPU/8-bit-floats.mlir
@@ -28,15 +28,9 @@ func.func @vector_zero_d(%v: vector<f8E5M2FNUZ>) -> vector<f32> {
// CHECK-LABEL: func.func @vector_ext_short
// CHECK-SAME: ([[V:%.+]]: vector<2xf8E5M2FNUZ>)
-// CHECK-DAG: [[ZEROES:%.+]] = arith.constant dense<0.000000e+00> : vector<2xf64>
-// CHECK: [[FLOAT0:%.+]] = amdgpu.ext_packed_fp8 [[V]][0] : vector<2xf8E5M2FNUZ> to f32
-// CHECK: [[EXT0:%.+]] = arith.extf [[FLOAT0]] : f32 to f64
-// CHECK: [[W0:%.+]] = vector.insert [[EXT0]], [[ZEROES]] [0]
-// CHECK: [[FLOAT1:%.+]] = amdgpu.ext_packed_fp8 [[V]][1] : vector<2xf8E5M2FNUZ> to f32
-// CHECK: [[EXT1:%.+]] = arith.extf [[FLOAT1]]
-// CHECK: [[W1:%.+]] = vector.insert [[EXT1]], [[W0]] [1]
-// CHECK: return [[W1]] : vector<2xf64>
-
+// CHECK: [[FLOAT:%.+]] = amdgpu.ext_packed_fp8 [[V]][0] : vector<2xf8E5M2FNUZ> to vector<2xf32>
+// CHECK: [[EXT:%.+]] = arith.extf [[FLOAT]] : vector<2xf32> to vector<2xf64>
+// CHECK: return [[EXT]] : vector<2xf64>
func.func @vector_ext_short(%v: vector<2xf8E5M2FNUZ>) -> vector<2xf64> {
%w = arith.extf %v : vector<2xf8E5M2FNUZ> to vector<2xf64>
return %w : vector<2xf64>
@@ -46,30 +40,21 @@ func.func @vector_ext_short(%v: vector<2xf8E5M2FNUZ>) -> vector<2xf64> {
// CHECK-LABEL: func.func @vector_ext_long
// CHECK-SAME: ([[V:%.+]]: vector<9xf8E4M3FNUZ>)
-// CHECK: [[V0:%.+]] = vector.extract_strided_slice [[V]] {offsets = [0], sizes = [4], strides = [1]}
-// CHECK: [[F0:%.+]] = amdgpu.ext_packed_fp8 [[V0]][0]
-// CHECK: [[W0:%.+]] = vector.insert [[F0]]
-// CHECK: [[F1:%.+]] = amdgpu.ext_packed_fp8 [[V0]][1]
-// CHECK: [[W1:%.+]] = vector.insert [[F1]], [[W0]]
-// CHECK: [[F2:%.+]] = amdgpu.ext_packed_fp8 [[V0]][2]
-// CHECK: [[W2:%.+]] = vector.insert [[F2]], [[W1]]
-// CHECK: [[F3:%.+]] = amdgpu.ext_packed_fp8 [[V0]][3]
-// CHECK: [[W3:%.+]] = vector.insert [[F3]], [[W2]]
-
-// CHECK: [[V1:%.+]] = vector.extract_strided_slice [[V]] {offsets = [4], sizes = [4], strides = [1]} : vector<9xf8E4M3FNUZ> to vector<4xf8E4M3FNUZ>
-// CHECK: [[F4:%.+]] = amdgpu.ext_packed_fp8 [[V1]][0]
-// CHECK: [[W4:%.+]] = vector.insert [[F4]], [[W3]]
-// CHECK: [[F5:%.+]] = amdgpu.ext_packed_fp8 [[V1]][1]
-// CHECK: [[W5:%.+]] = vector.insert [[F5]], [[W4]]
-// CHECK: [[F6:%.+]] = amdgpu.ext_packed_fp8 [[V1]][2]
-// CHECK: [[W6:%.+]] = vector.insert [[F6]], [[W5]]
-// CHECK: [[F7:%.+]] = amdgpu.ext_packed_fp8 [[V1]][3]
-// CHECK: [[W7:%.+]] = vector.insert [[F7]], [[W6]]
-
-// CHECK: [[V2:%.+]] = vector.extract_strided_slice [[V]] {offsets = [8], sizes = [1], strides = [1]} : vector<9xf8E4M3FNUZ> to vector<1xf8E4M3FNUZ>
-// CHECK: [[F8:%.+]] = amdgpu.ext_packed_fp8 [[V2]][0]
-// CHECK: [[W8:%.+]] = vector.insert [[F8]], [[W7]]
-// CHECK: return [[W8]]
+// CHECK: [[W0:%.+]] = arith.constant dense<0.000000e+00> : vector<9xf32>
+// CHECK: [[IN1:%.+]] = vector.extract_strided_slice [[V]] {offsets = [0], sizes = [4], strides = [1]} : vector<9xf8E4M3FNUZ> to vector<4xf8E4M3FNUZ>
+// CHECK: [[FLOAT1:%.+]] = amdgpu.ext_packed_fp8 [[IN1]][0] : vector<4xf8E4M3FNUZ> to vector<2xf32>
+// CHECK: [[W1:%.+]] = vector.insert_strided_slice [[FLOAT1]], [[W0]] {offsets = [0], strides = [1]} : vector<2xf32> into vector<9xf32>
+// CHECK: [[FLOAT2:%.+]] = amdgpu.ext_packed_fp8 [[IN1]][1] : vector<4xf8E4M3FNUZ> to vector<2xf32>
+// CHECK: [[W2:%.+]] = vector.insert_strided_slice [[FLOAT2]], [[W1]] {offsets = [2], strides = [1]} : vector<2xf32> into vector<9xf32>
+// CHECK: [[IN2:%.+]] = vector.extract_strided_slice [[V]] {offsets = [4], sizes = [4], strides = [1]} : vector<9xf8E4M3FNUZ> to vector<4xf8E4M3FNUZ>
+// CHECK: [[FLOAT3:%.+]] = amdgpu.ext_packed_fp8 [[IN2]][0] : vector<4xf8E4M3FNUZ> to vector<2xf32>
+// CHECK: [[W3:%.+]] = vector.insert_strided_slice [[FLOAT3]], [[W2]] {offsets = [4], strides = [1]} : vector<2xf32> into vector<9xf32>
+// CHECK: [[FLOAT4:%.+]] = amdgpu.ext_packed_fp8 [[IN2]][1] : vector<4xf8E4M3FNUZ> to vector<2xf32>
+// CHECK: [[W4:%.+]] = vector.insert_strided_slice [[FLOAT4]], [[W3]] {offsets = [6], strides = [1]} : vector<2xf32> into vector<9xf32>
+// CHECK: [[IN3:%.+]] = vector.extract_strided_slice [[V]] {offsets = [8], sizes = [1], strides = [1]} : vector<9xf8E4M3FNUZ> to vector<1xf8E4M3FNUZ>
+// CHECK: [[FLOAT5:%.+]] = amdgpu.ext_packed_fp8 [[IN3]][0] : vector<1xf8E4M3FNUZ> to f32
+// CHECK: [[W5:%.+]] = vector.insert [[FLOAT5]], [[W4]] [8] : f32 into vector<9xf32>
+// CHECK: return [[W5]]
func.func @vector_ext_long(%v: vector<9xf8E4M3FNUZ>) -> vector<9xf32> {
%w = arith.extf %v : vector<9xf8E4M3FNUZ> to vector<9xf32>
return %w : vector<9xf32>
@@ -154,34 +139,29 @@ func.func @vector_trunc_long_2d(%v: vector<1x9xf32>) -> vector<1x9xf8E4M3FNUZ> {
// -----
// CHECK-LABEL: func.func @vector_ext_long_2d
-// CHECK-SAME: ([[V:%.+]]: vector<1x9xf8E4M3FNUZ>)
-// CHECK: [[CAST:%.+]] = vector.shape_cast [[V]] : vector<1x9xf8E4M3FNUZ> to vector<9xf8E4M3FNUZ>
+// CHECK-SAME: ([[V:%.+]]: vector<1x11xf8E4M3FNUZ>)
+// CHECK: [[CST:%.+]] = arith.constant dense<0.000000e+00> : vector<11xf32>
+// CHECK: [[CAST:%.+]] = vector.shape_cast [[V]] : vector<1x11xf8E4M3FNUZ> to vector<11xf8E4M3FNUZ>
// CHECK: [[V0:%.+]] = vector.extract_strided_slice [[CAST]] {offsets = [0], sizes = [4], strides = [1]}
-// CHECK: [[F0:%.+]] = amdgpu.ext_packed_fp8 [[V0]][0]
-// CHECK: [[W0:%.+]] = vector.insert [[F0]]
-// CHECK: [[F1:%.+]] = amdgpu.ext_packed_fp8 [[V0]][1]
-// CHECK: [[W1:%.+]] = vector.insert [[F1]], [[W0]]
-// CHECK: [[F2:%.+]] = amdgpu.ext_packed_fp8 [[V0]][2]
-// CHECK: [[W2:%.+]] = vector.insert [[F2]], [[W1]]
-// CHECK: [[F3:%.+]] = amdgpu.ext_packed_fp8 [[V0]][3]
-// CHECK: [[W3:%.+]] = vector.insert [[F3]], [[W2]]
-
-// CHECK: [[V1:%.+]] = vector.extract_strided_slice [[CAST]] {offsets = [4], sizes = [4], strides = [1]} : vector<9xf8E4M3FNUZ> to vector<4xf8E4M3FNUZ>
-// CHECK: [[F4:%.+]] = amdgpu.ext_packed_fp8 [[V1]][0]
-// CHECK: [[W4:%.+]] = vector.insert [[F4]], [[W3]]
-// CHECK: [[F5:%.+]] = amdgpu.ext_packed_fp8 [[V1]][1]
-// CHECK: [[W5:%.+]] = vector.insert [[F5]], [[W4]]
-// CHECK: [[F6:%.+]] = amdgpu.ext_packed_fp8 [[V1]][2]
-// CHECK: [[W6:%.+]] = vector.insert [[F6]], [[W5]]
-// CHECK: [[F7:%.+]] = amdgpu.ext_packed_fp8 [[V1]][3]
-// CHECK: [[W7:%.+]] = vector.insert [[F7]], [[W6]]
-
-// CHECK: [[V2:%.+]] = vector.extract_strided_slice [[CAST]] {offsets = [8], sizes = [1], strides = [1]} : vector<9xf8E4M3FNUZ> to vector<1xf8E4M3FNUZ>
-// CHECK: [[F8:%.+]] = amdgpu.ext_packed_fp8 [[V2]][0]
-// CHECK: [[W8:%.+]] = vector.insert [[F8]], [[W7]]
-// CHECK: [[CAST:%.+]] = vector.shape_cast [[W8]] : vector<9xf32> to vector<1x9xf32>
+// CHECK: [[F0:%.+]] = amdgpu.ext_packed_fp8 [[V0]][0] : vector<4xf8E4M3FNUZ> to vector<2xf32>
+// CHECK: [[W0:%.+]] = vector.insert_strided_slice [[F0]], [[CST]] {offsets = [0], strides = [1]} : vector<2xf32> into vector<11xf32>
+// CHECK: [[F1:%.+]] = amdgpu.ext_packed_fp8 [[V0]][1] : vector<4xf8E4M3FNUZ> to vector<2xf32>
+// CHECK: [[W1:%.+]] = vector.insert_strided_slice [[F1]], [[W0]] {offsets = [2], strides = [1]} : vector<2xf32> into vector<11xf32>
+
+// CHECK: [[V1:%.+]] = vector.extract_strided_slice [[CAST]] {offsets = [4], sizes = [4], strides = [1]} : vector<11xf8E4M3FNUZ> to vector<4xf8E4M3FNUZ>
+// CHECK: [[F2:%.+]] = amdgpu.ext_packed_fp8 [[V1]][0] : vector<4xf8E4M3FNUZ> to vector<2xf32>
+// CHECK: [[W2:%.+]] = vector.insert_strided_slice [[F2]], [[W1]] {offsets = [4], strides = [1]} : vector<2xf32> into vector<11xf32>
+// CHECK: [[F3:%.+]] = amdgpu.ext_packed_fp8 [[V1]][1] : vector<4xf8E4M3FNUZ> to vector<2xf32>
+// CHECK: [[W3:%.+]] = vector.insert_strided_slice [[F3]], [[W2]] {offsets = [6], strides = [1]} : vector<2xf32> into vector<11xf32>
+
+// CHECK: [[V2:%.+]] = vector.extract_strided_slice [[CAST]] {offsets = [8], sizes = [3], strides = [1]} : vector<11xf8E4M3FNUZ> to vector<3xf8E4M3FNUZ>
+// CHECK: [[F4:%.+]] = amdgpu.ext_packed_fp8 [[V2]][0] : vector<3xf8E4M3FNUZ> to vector<2xf32>
+// CHECK: [[W4:%.+]] = vector.insert_strided_slice [[F4]], [[W3]] {offsets = [8], strides = [1]} : vector<2xf32> into vector<11xf32>
+// CHECK: [[F5:%.+]] = amdgpu.ext_packed_fp8 [[V2]][2] : vector<3xf8E4M3FNUZ> to f32
+// CHECK: [[W5:%.+]] = vector.insert [[F5]], [[W4]] [10] : f32 into vector<11xf32>
+// CHECK: [[CAST:%.+]] = vector.shape_cast [[W5]] : vector<11xf32> to vector<1x11xf32>
// CHECK: return [[CAST]]
-func.func @vector_ext_long_2d(%v: vector<1x9xf8E4M3FNUZ>) -> vector<1x9xf32> {
- %w = arith.extf %v : vector<1x9xf8E4M3FNUZ> to vector<1x9xf32>
- return %w : vector<1x9xf32>
+func.func @vector_ext_long_2d(%v: vector<1x11xf8E4M3FNUZ>) -> vector<1x11xf32> {
+ %w = arith.extf %v : vector<1x11xf8E4M3FNUZ> to vector<1x11xf32>
+ return %w : vector<1x11xf32>
}
diff --git a/mlir/test/Dialect/AMDGPU/ops.mlir b/mlir/test/Dialect/AMDGPU/ops.mlir
index 567e6498330a3..665674f2a7873 100644
--- a/mlir/test/Dialect/AMDGPU/ops.mlir
+++ b/mlir/test/Dialect/AMDGPU/ops.mlir
@@ -4,13 +4,20 @@
// Verify the generic form can be parsed.
// RUN: mlir-opt -allow-unregistered-dialect -mlir-print-op-generic %s | mlir-opt -allow-unregistered-dialect | FileCheck %s
-// CHECK-LABEL: func @ext_packed_fp8
-// CHECK: amdgpu.ext_packed_fp8
-func.func @ext_packed_fp8(%v: vector<4xf8E4M3FNUZ>) -> f32 {
+// CHECK-LABEL: func @ext_packed_fp8_s
+// CHECK: amdgpu.ext_packed_fp8 {{.*}} vector<4xf8E4M3FNUZ> to f32
+func.func @ext_packed_fp8_s(%v: vector<4xf8E4M3FNUZ>) -> f32 {
%ret = amdgpu.ext_packed_fp8 %v[0] : vector<4xf8E4M3FNUZ> to f32
func.return %ret : f32
}
+// CHECK-LABEL: func @ext_packed_fp8_v
+// CHECK: amdgpu.ext_packed_fp8 {{.*}} vector<4xf8E4M3FNUZ> to vector<2xf32>
+func.func @ext_packed_fp8_v(%v: vector<4xf8E4M3FNUZ>) -> vector<2xf32> {
+ %ret = amdgpu.ext_packed_fp8 %v[0] : vector<4xf8E4M3FNUZ> to vector<2xf32>
+ func.return %ret : vector<2xf32>
+}
+
// CHECK-LABEL: func @packed_trunc_2xfp8
// CHECK: amdgpu.packed_trunc_2xfp8
func.func @packed_trunc_2xfp8(%v1: f32, %v2: f32, %others: vector<4xf8E5M2FNUZ>, %stoch: i32) -> vector<4xf8E5M2FNUZ> {
diff --git a/mlir/test/Dialect/LLVMIR/rocdl.mlir b/mlir/test/Dialect/LLVMIR/rocdl.mlir
index bc917041998d8..cce2c0aee62f3 100644
--- a/mlir/test/Dialect/LLVMIR/rocdl.mlir
+++ b/mlir/test/Dialect/LLVMIR/rocdl.mlir
@@ -767,10 +767,14 @@ llvm.func @rocdl_8bit_floats(%source: i32, %source_half: f16, %source_bfloat: bf
// CHECK: rocdl.cvt.scalef32.f32.fp8
// CHECK: rocdl.cvt.scalef32.pk.f16.bf8
// CHECK: rocdl.cvt.scalef32.pk.f16.fp8
+// CHECK: rocdl.cvt.scalef32.pk.bf16.bf8
+// CHECK: rocdl.cvt.scalef32.pk.bf16.fp8
// CHECK: rocdl.cvt.scalef32.f16.fp8
// CHECK: rocdl.cvt.scalef32.f16.bf8
// CHECK: rocdl.cvt.pk.bf8.f32
// CHECK: rocdl.cvt.pk.fp8.f32
+// CHECK: rocdl.cvt.pk.f32.bf8
+// CHECK: rocdl.cvt.pk.f32.fp8
// CHECK: rocdl.cvt.sr.bf8.f32
// CHECK: rocdl.cvt.sr.fp8.f32
// CHECK: rocdl.cvt.scalef32.sr.fp8.f32
@@ -793,10 +797,14 @@ llvm.func @rocdl_8bit_floats(%source: i32, %source_half: f16, %source_bfloat: bf
%v2_scaled = rocdl.cvt.scalef32.f32.fp8 %source[%c0], %c4 : f32
%v3_scaled = rocdl.cvt.scalef32.pk.f16.bf8 %source[%false], %c4 : vector<2xf16>
%v4_scaled = rocdl.cvt.scalef32.pk.f16.fp8 %source[%false], %c4 : vector<2xf16>
+ %v3_scaled_bf16 = rocdl.cvt.scalef32.pk.bf16.bf8 %source[%false], %c4 : vector<2xbf16>
+ %v4_scaled_bf16 = rocdl.cvt.scalef32.pk.bf16.fp8 %source[%false], %c4 : vector<2xbf16>
%v5 = rocdl.cvt.scalef32.f16.fp8 %source[%false], %c4 -> %v3_scaled[%c0] : f16
%v6 = rocdl.cvt.scalef32.f16.bf8 %source[%false], %c4 -> %v3_scaled[%c0] : f16
%source2 = rocdl.cvt.pk.bf8.f32 %v1, %v2 -> %source[%false] : i32
%source3 = rocdl.cvt.pk.fp8.f32 %v1, %v2 -> %source2[%false] : i32
+ %source2_ext = rocdl.cvt.pk.f32.bf8 %source[%false] : vector<2xf32>
+ %source3_ext = rocdl.cvt.pk.f32.fp8 %source[%false] : vector<2xf32>
%source4 = rocdl.cvt.sr.bf8.f32 %v1, %stoch -> %source3[%c2] : i32
%source5 = rocdl.cvt.sr.fp8.f32 %v2, %stoch -> %source4[%c3] : i32
%source5_scaled = rocdl.cvt.scalef32.sr.fp8.f32 %v2, %stoch, %c4 -> %source4[%c3] : i32
diff --git a/mlir/test/Target/LLVMIR/rocdl.mlir b/mlir/test/Target/LLVMIR/rocdl.mlir
index 11f2faa2761ff..e70617bfff99e 100644
--- a/mlir/test/Target/LLVMIR/rocdl.mlir
+++ b/mlir/test/Target/LLVMIR/rocdl.mlir
@@ -1042,6 +1042,8 @@ llvm.func @rocdl_8bit_floats(%source: i32, %source_half: f16, %source_bfloat: bf
// CHECK: call <2 x half> @llvm.amdgcn.cvt.scalef32.pk.f16.fp8(i32 %{{.+}}, float 1.000000e+00, i1 false)
// CHECK: call <2 x half> @llvm.amdgcn.cvt.scalef32.f16.fp8(<2 x half> %{{.+}}, i32 %{{.+}}, float 1.000000e+00, i32 0, i1 false)
// CHECK: call <2 x half> @llvm.amdgcn.cvt.scalef32.f16.bf8(<2 x half> %{{.+}}, i32 %{{.+}}, float 1.000000e+00, i32 0, i1 false)
+// CHECK: call <2 x bfloat> @llvm.amdgcn.cvt.scalef32.pk.bf16.bf8(i32 %{{.+}}, float 1.000000e+00, i1 false)
+// CHECK: call <2 x bfloat> @llvm.amdgcn.cvt.scalef32.pk.bf16.fp8(i32 %{{.+}}, float 1.000000e+00, i1 false)
// CHECK: call i32 @llvm.amdgcn.cvt.pk.bf8.f32(float %{{.+}}, float %{{.+}}, i32 %{{.+}}, i1 false)
// CHECK: call i32 @llvm.amdgcn.cvt.pk.fp8.f32(float %{{.+}}, float %{{.+}}, i32 %{{.+}}, i1 false)
// CHECK: call i32 @llvm.amdgcn.cvt.sr.bf8.f32(float %{{.+}}, i32 %{{.+}}, i32 %{{.+}}, i32 2)
@@ -1068,6 +1070,8 @@ llvm.func @rocdl_8bit_floats(%source: i32, %source_half: f16, %source_bfloat: bf
%v4_scaled = rocdl.cvt.scalef32.pk.f16.fp8 %source[%false], %c4 : i32
%v5 = rocdl.cvt.scalef32.f16.fp8 %source[%false], %c4 -> %source_packed[%c0] : f16
%v6 = rocdl.cvt.scalef32.f16.bf8 %source[%false], %c4 -> %source_packed[%c0] : f16
+ %v7 = rocdl.cvt.scalef32.pk.bf16.bf8 %source[%false], %c4 : i32
+ %v8 = rocdl.cvt.scalef32.pk.bf16.fp8 %source[%false], %c4 : i32
%source2 = rocdl.cvt.pk.bf8.f32 %v1, %v2 -> %source[%false] : i32
%source3 = rocdl.cvt.pk.fp8.f32 %v1, %v2 -> %source2[%false] : i32
%source4 = rocdl.cvt.sr.bf8.f32 %v1, %stoch -> %source3[%c2] : i32
More information about the Mlir-commits mailing list