[Mlir-commits] [mlir] [AMD][ROCDL][AMDGPU] Support packed conversions fp8/bf8->bf16 and fp8/bf8->fp32 (PR #131850)
Yi Qian
llvmlistbot at llvm.org
Tue Mar 18 21:46:15 PDT 2025
https://github.com/yiqian1 updated https://github.com/llvm/llvm-project/pull/131850
>From e7e14efa6df888889b576027305d3a083ca5fd5b Mon Sep 17 00:00:00 2001
From: Yi Qian <yi.qian at amd.com>
Date: Sat, 15 Mar 2025 05:07:21 +0000
Subject: [PATCH 1/2] [AMD][ROCDL][AMDGPU] Support packed conversions
fp8/bf8->bf16 and fp8/bf8->fp32
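Add ROCDL ops for the packed conversions fp8/bf8->bf16 (gfx950) and fp8/bf8->fp32 (gfx942).
Change amdgpu.ext_packed_fp8 to extend both 8-bit floats of the selected 16-bit word into a
vector<2xf32>, and lower it to ROCDL CvtPkF32Fp8Op / CvtPkF32Bf8Op.
Also rename the ROCDL scaled-conversion op definitions to use a consistent `Op` suffix and
clarify their descriptions.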
---
mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td | 13 +-
mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td | 145 ++++++++++++------
.../AMDGPUToROCDL/AMDGPUToROCDL.cpp | 10 +-
.../ArithToAMDGPU/ArithToAMDGPU.cpp | 32 ++--
.../AMDGPUToROCDL/8-bit-floats-ocp.mlir | 32 ++--
.../AMDGPUToROCDL/8-bit-floats.mlir | 32 ++--
.../ArithToAMDGPU/8-bit-floats-ocp.mlir | 90 +++++------
.../ArithToAMDGPU/8-bit-floats.mlir | 94 +++++-------
mlir/test/Dialect/AMDGPU/ops.mlir | 6 +-
mlir/test/Dialect/LLVMIR/rocdl.mlir | 8 +
mlir/test/Target/LLVMIR/rocdl.mlir | 4 +
11 files changed, 251 insertions(+), 215 deletions(-)
diff --git a/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td b/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
index 3acc383923ca8..3ed6e84d19044 100644
--- a/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
+++ b/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
@@ -85,11 +85,12 @@ def AMDGPU_ExtPackedFp8Op :
AMDGPU_Op<"ext_packed_fp8", [Pure]>,
Arguments<(ins AnyTypeOf<[F8E5M2FNUZ, F8E4M3FNUZ, F8E5M2, F8E4M3FN,
VectorOfLengthAndType<[1, 2, 3, 4], [F8E5M2FNUZ, F8E4M3FNUZ, F8E5M2, F8E4M3FN]>]>:$source,
- ConfinedAttr<I32Attr, [IntNonNegative, IntMaxValue<3>]>:$index)>,
- Results<(outs F32:$res)> {
- let summary = "Extend one of a vector of packed fp8 values to a float";
+ ConfinedAttr<I32Attr, [IntNonNegative, IntMaxValue<1>]>:$wordIndex)>,
+ Results<(outs FixedVectorOfLengthAndType<[2], [F32]>:$res)> {
+ let summary = "Extend a vector of packed fp8 values to two floats";
+
let description = [{
- Extend the value `source[index]` to a 32-bit float and return it.
+ Extend the two 8-bit floats in 16-bit word `wordIndex` of `source` (elements `2 * wordIndex` and `2 * wordIndex + 1`) to two 32-bit floats and return them.
This rather unusual signature arises from the fact that AMD GPUs cannot
easily work with sub 32-bit quantities, so the compiler intrinsics for
@@ -97,11 +98,11 @@ def AMDGPU_ExtPackedFp8Op :
this operation) take packed vectors of 4 such floats.
If the passed-in vector has fewer than four elements, or the input is scalar,
- the remaining values in the <4 x i8> will be filled with with
+ the remaining values in the <4 x i8> will be filled with
undefined values as needed.
}];
let assemblyFormat = [{
- attr-dict $source `[` $index `]` `:` type($source) `to` type($res)
+ attr-dict $source `[` $wordIndex `]` `:` type($source) `to` type($res)
}];
}
diff --git a/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td b/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td
index f194e70ee275b..9a433202e3149 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td
@@ -681,26 +681,26 @@ def ROCDL_CvtPkRtz:
}];
}
-def ROCDL_CvtScaleF32PkFp8F16 :
+def ROCDL_CvtScaleF32PkFp8F16Op :
ROCDL_IntrOp<"cvt.scalef32.pk.fp8.f16", [], [], [Pure], 1>,
Arguments<(ins ROCDL_V2I16Type: $old, ROCDL_V2F16Type: $src, F32: $scale, I1:$wordSel)> {
let summary = "Scale and convert f16 to packed fp8";
let description = [{
- Scale `src` by the exponent in `scale` then convert to packed fp8.
- Store the result in low/high word based on $wordSel, preserving the other word.
+ Scale `src` by the exponent in `scale`, then convert to packed fp8.
+ Store the result in low/high word of `old` based on $wordSel, preserving the other word.
}];
let assemblyFormat = [{
attr-dict $src `,` $scale `->` $old `[` $wordSel `]` `:` type($res)
}];
}
-def ROCDL_CvtScaleF32PkFp8Bf16 :
+def ROCDL_CvtScaleF32PkFp8Bf16Op :
ROCDL_IntrOp<"cvt.scalef32.pk.fp8.bf16", [], [], [Pure], 1>,
Arguments<(ins ROCDL_V2I16Type: $old, ROCDL_V2BF16Type: $src, F32: $scale, I1:$wordSel)> {
let summary = "Scale and convert packed bf16 to packed fp8";
let description = [{
- Scale `src` by the exponent in `scale` then convert to packed fp8.
- Store the result in low/high word based on $wordSel, preserving the other word.
+ Scale `src` by the exponent in `scale`, then convert to packed fp8.
+ Store the result in low/high word of `old` based on $wordSel, preserving the other word.
}];
let assemblyFormat = [{
attr-dict $src `,` $scale `->` $old `[` $wordSel `]` `:` type($res)
@@ -708,13 +708,13 @@ def ROCDL_CvtScaleF32PkFp8Bf16 :
}
-def ROCDL_CvtScaleF32PkBf8F16 :
+def ROCDL_CvtScaleF32PkBf8F16Op :
ROCDL_IntrOp<"cvt.scalef32.pk.bf8.f16", [], [], [Pure], 1>,
Arguments<(ins ROCDL_V2I16Type: $old, ROCDL_V2F16Type: $src, F32: $scale, I1:$wordSel)> {
let summary = "Scale and convert f16 to packed bf8";
let description = [{
- Scale `src` by the exponent in `scale` then convert to packed bf8.
- Store the result in low/high word based on $wordSel, preserving the other word.
+ Scale `src` by the exponent in `scale`, then convert to packed bf8.
+ Store the result in low/high word of `old` based on $wordSel, preserving the other word.
}];
let assemblyFormat = [{
attr-dict $src `,` $scale `->` $old `[` $wordSel `]` `:` type($res)
@@ -722,26 +722,26 @@ def ROCDL_CvtScaleF32PkBf8F16 :
}
-def ROCDL_CvtScaleF32PkBf8Bf16 :
+def ROCDL_CvtScaleF32PkBf8Bf16Op :
ROCDL_IntrOp<"cvt.scalef32.pk.bf8.bf16", [], [], [Pure], 1>,
Arguments<(ins ROCDL_V2I16Type: $old, ROCDL_V2BF16Type: $src, F32: $scale, I1:$wordSel)> {
let summary = "Scale and convert bf16 to packed bf8";
let description = [{
- Scale `src` by the exponent in `scale` then convert to packed bf8.
- Store the result in low/high word based on $wordSel, preserving the other word.
+ Scale `src` by the exponent in `scale`, then convert to packed bf8.
+ Store the result in low/high word of `old` based on $wordSel, preserving the other word.
}];
let assemblyFormat = [{
attr-dict $src `,` $scale `->` $old `[` $wordSel `]` `:` type($res)
}];
}
-def ROCDL_CvtScaleF32SrFp8F16 :
+def ROCDL_CvtScaleF32SrFp8F16Op :
ROCDL_IntrOp<"cvt.scalef32.sr.fp8.f16", [], [], [Pure], 1>,
Arguments<(ins I32:$old, F16:$src, I32:$seed, F32: $scale, I32:$byteSel)> {
let summary = "Scale and convert f16 to packed fp8 using stochastic rounding";
let description = [{
- Scale `src` by the exponent in `scale` then convert to packed p8 with stochastic rounding
- using seed data in `seed`. store into the `byteSel`th byte of `old`, preserving the others.
+ Scale `src` by the exponent in `scale`, then convert to packed fp8 with stochastic rounding
+ using seed data in `seed`. Store into the `byteSel`th byte of `old`, preserving the others.
}];
let assemblyFormat = [{
@@ -749,13 +749,13 @@ def ROCDL_CvtScaleF32SrFp8F16 :
}];
}
-def ROCDL_CvtScaleF32SrBf8F16 :
+def ROCDL_CvtScaleF32SrBf8F16Op :
ROCDL_IntrOp<"cvt.scalef32.sr.bf8.f16", [], [], [Pure], 1>,
Arguments<(ins I32:$old, F16:$src, I32:$seed, F32: $scale, I32:$byteSel)> {
let summary = "Scale and convert f16 to packed bf8 using stochastic rounding";
let description = [{
- Scale `src` by the exponent in `scale` then convert to packed bf8 with stochastic rounding
- using seed data in `seed`. store into the `byteSel`th byte of `old`, preserving the others.
+ Scale `src` by the exponent in `scale`, then convert to packed bf8 with stochastic rounding
+ using seed data in `seed`. Store into the `byteSel`th byte of `old`, preserving the others.
}];
let assemblyFormat = [{
@@ -763,13 +763,13 @@ def ROCDL_CvtScaleF32SrBf8F16 :
}];
}
-def ROCDL_CvtScaleF32SrFp8Bf16 :
+def ROCDL_CvtScaleF32SrFp8Bf16Op :
ROCDL_IntrOp<"cvt.scalef32.sr.fp8.bf16", [], [], [Pure], 1>,
Arguments<(ins I32:$old, BF16:$src, I32:$seed, F32: $scale, I32:$byteSel)> {
let summary = "Scale and convert packed bf16 to packed fp8 using stochastic rounding";
let description = [{
- Scale `src` by the exponent in `scale` then convert to packed fp8 with stochastic rounding
- using seed data in `seed`. store into the `byteSel`th byte of `old`, preserving the others.
+ Scale `src` by the exponent in `scale`, then convert to packed fp8 with stochastic rounding
+ using seed data in `seed`. Store into the `byteSel`th byte of `old`, preserving the others.
}];
let assemblyFormat = [{
@@ -777,13 +777,13 @@ def ROCDL_CvtScaleF32SrFp8Bf16 :
}];
}
-def ROCDL_CvtScaleF32SrBf8Bf16:
+def ROCDL_CvtScaleF32SrBf8Bf16Op :
ROCDL_IntrOp<"cvt.scalef32.sr.bf8.bf16", [], [], [Pure], 1>,
Arguments<(ins I32:$old, BF16:$src, I32:$seed, F32: $scale, I32:$byteSel)> {
let summary = "Scale and convert bf16 to packed fp8 using stochastic rounding";
let description = [{
- Scale `src` by the exponent in `scale` then convert to packed p8 with stochastic rounding
- using seed data in `seed`. store into the `byteSel`th byte of `old`, preserving the others.
+ Scale `src` by the exponent in `scale`, then convert to packed bf8 with stochastic rounding
+ using seed data in `seed`. Store into the `byteSel`th byte of `old`, preserving the others.
}];
let assemblyFormat = [{
@@ -791,48 +791,74 @@ def ROCDL_CvtScaleF32SrBf8Bf16:
}];
}
-def ROCDL_CvtScaleF32PkF16Fp8 :
+def ROCDL_CvtScaleF32PkF16Fp8Op :
ROCDL_IntrOp<"cvt.scalef32.pk.f16.fp8", [], [], [Pure], 1>,
Arguments<(ins I32:$src, F32: $scale, I1:$wordSel)> {
- let summary = "Scale and convert fp8 to packed f16";
- let description = [{ Scale `src` based on $wordSel by the exponent in `scale`
- then convert to packed f16.
+ let summary = "Convert fp8 to packed f16 and scale";
+ let description = [{ Convert `src` based on $wordSel to packed f16, then scale
+ the packed values by the exponent in `scale`.
}];
let assemblyFormat = [{
attr-dict $src `[` $wordSel `]` `,` $scale `:` type($res)
}];
}
-def ROCDL_CvtScaleF32PkF16Bf8 :
+def ROCDL_CvtScaleF32PkF16Bf8Op :
ROCDL_IntrOp<"cvt.scalef32.pk.f16.bf8", [], [], [Pure], 1>,
Arguments<(ins I32:$src, F32: $scale, I1:$wordSel)> {
- let summary = "Scale and convert bf8 to packed f16";
- let description = [{ Scale `src` based on $wordSel by the exponent in `scale`
- then convert to packed f16.
+ let summary = "Convert bf8 to packed f16 and scale";
+ let description = [{ Convert `src` based on $wordSel to packed f16, then scale
+ the packed values by the exponent in `scale`.
}];
let assemblyFormat = [{
attr-dict $src `[` $wordSel `]` `,` $scale `:` type($res)
}];
}
-def ROCDL_CvtScaleF16Fp8 :
+def ROCDL_CvtScaleF32PkBf16Fp8Op :
+ ROCDL_IntrOp<"cvt.scalef32.pk.bf16.fp8", [], [], [Pure], 1>,
+ Arguments<(ins I32:$src, F32: $scale, I1:$wordSel)> {
+ let summary = "Convert fp8 to packed bf16 and scale";
+ let description = [{ Convert `src` based on $wordSel to packed bf16, then scale
+ the packed values by the exponent in `scale`.
+ }];
+ let assemblyFormat = [{
+ attr-dict $src `[` $wordSel `]` `,` $scale `:` type($res)
+ }];
+}
+
+def ROCDL_CvtScaleF32PkBf16Bf8Op :
+ ROCDL_IntrOp<"cvt.scalef32.pk.bf16.bf8", [], [], [Pure], 1>,
+ Arguments<(ins I32:$src, F32: $scale, I1:$wordSel)> {
+ let summary = "Convert bf8 to packed bf16 and scale";
+ let description = [{ Convert `src` based on $wordSel to packed bf16, then scale
+ the packed values by the exponent in `scale`.
+ }];
+ let assemblyFormat = [{
+ attr-dict $src `[` $wordSel `]` `,` $scale `:` type($res)
+ }];
+}
+
+def ROCDL_CvtScaleF16Fp8Op :
ROCDL_IntrOp<"cvt.scalef32.f16.fp8", [], [], [Pure], 1>,
Arguments<(ins ROCDL_V2F16Type:$old, I32:$src, F32: $scale, I32:$byteSel, I1:$wordSel)> {
let summary = "Scale and convert fp8 to f16";
- let description = [{ Scale `src` based on $wordSel by the exponent in `scale`
- then convert to f16 store into the `byteSel`th byte of `old`, preserving the others.
+ let description = [{ Convert `src` based on $wordSel to f16, then scale the value
+ by the exponent in `scale`. Store the result into the `byteSel`th byte of `old`,
+ preserving the others.
}];
let assemblyFormat = [{
attr-dict $src `[` $wordSel `]` `,` $scale `->` $old `[` $byteSel `]` `:` type($res)
}];
}
-def ROCDL_CvtScaleF16Bf8 :
+def ROCDL_CvtScaleF16Bf8Op :
ROCDL_IntrOp<"cvt.scalef32.f16.bf8", [], [], [Pure], 1>,
Arguments<(ins ROCDL_V2F16Type:$old, I32:$src, F32: $scale, I32:$byteSel, I1:$wordSel)> {
let summary = "Scale and convert fp8 to f16";
- let description = [{ Scale `src` based on $wordSel by the exponent in `scale`
- then convert to f16 store into the `byteSel`th byte of `old`, preserving the others.
+ let description = [{ Convert `src` based on $wordSel to f16, then scale the value
+ by the exponent in `scale`. Store the result into the `byteSel`th byte of `old`,
+ preserving the others.
}];
let assemblyFormat = [{
attr-dict $src `[` $wordSel `]` `,` $scale `->` $old `[` $byteSel `]` `:` type($res)
@@ -842,25 +868,25 @@ def ROCDL_CvtScaleF16Bf8 :
//===---------------------------------------------------------------------===//
// 32-bit float intrinsics
//===---------------------------------------------------------------------===//
-def ROCDL_CvtScale32PkF32Fp8 :
+def ROCDL_CvtScaleF32PkF32Fp8Op :
ROCDL_IntrOp<"cvt.scalef32.pk.f32.fp8", [], [], [Pure], 1>,
Arguments<(ins I32:$src, F32: $scale, I1:$wordSel)> {
let summary = "Scale and convert packed fp8 to packed f32";
let description = [{
- Scale `src` by the exponent in `scale` then convert to packed fp32.
- Store the result in low/high word based on $wordSel, preserving the other word.
+ Convert `src` based on $wordSel to packed fp32, then scale the packed values by
+ the exponent in `scale`. Store the result in a vector.
}];
let assemblyFormat = [{
attr-dict $src `[` $wordSel `]` `,` $scale `:` type($res)
}];
}
-def ROCDL_CvtScale32PkF32Bf8 :
+def ROCDL_CvtScaleF32PkF32Bf8Op :
ROCDL_IntrOp<"cvt.scalef32.pk.f32.bf8", [], [], [Pure], 1>,
Arguments<(ins I32:$src, F32: $scale, I1:$wordSel)> {
let summary = "Scale and convert packed bf8 to packed f32";
let description = [{
- Scale `src` by the exponent in `scale` then convert to packed fp32.
- Store the result in low/high word based on $wordSel, preserving the other word.
+ Convert `src` based on $wordSel to packed fp32, then scale the packed values by
+ the exponent in `scale`. Store the result in a vector.
}];
let assemblyFormat = [{
attr-dict $src `[` $wordSel `]` `,` $scale `:` type($res)
@@ -869,7 +895,7 @@ def ROCDL_CvtScale32PkF32Bf8 :
//===---------------------------------------------------------------------===//
// 8-bit float scale intrinsics
//===---------------------------------------------------------------------===//
-def ROCDL_CvtScaleF32PkFp8F32:
+def ROCDL_CvtScaleF32PkFp8F32Op :
ROCDL_IntrOp<"cvt.scalef32.pk.fp8.f32", [], [], [Pure], 1>,
Arguments<(ins ROCDL_V2I16Type:$old, F32:$srcA, F32:$srcB, F32:$scale, I1:$wordSel)> {
let summary = "Scale and convert two f32's to packed fp8";
@@ -882,7 +908,7 @@ def ROCDL_CvtScaleF32PkFp8F32:
}];
}
-def ROCDL_CvtScaleF32PkBf8F32:
+def ROCDL_CvtScaleF32PkBf8F32Op :
ROCDL_IntrOp<"cvt.scalef32.pk.bf8.f32", [], [], [Pure], 1>,
Arguments<(ins ROCDL_V2I16Type:$old, F32:$srcA, F32:$srcB, F32: $scale, I1:$wordSel)> {
let summary = "Scale and convert two f32's to packed bf8";
@@ -895,7 +921,7 @@ def ROCDL_CvtScaleF32PkBf8F32:
}];
}
-def ROCDL_CvtScaleF32SrFp8F32:
+def ROCDL_CvtScaleF32SrFp8F32Op :
ROCDL_IntrOp<"cvt.scalef32.sr.fp8.f32", [], [], [Pure], 1>,
Arguments<(ins I32:$old, F32:$src, I32:$seed, F32: $scale, I32:$byteSel)> {
let summary = "Scale and convert f32 to fp8 using stochastic rounding";
@@ -909,7 +935,7 @@ def ROCDL_CvtScaleF32SrFp8F32:
}
-def ROCDL_CvtScaleF32SrBf8F32:
+def ROCDL_CvtScaleF32SrBf8F32Op :
ROCDL_IntrOp<"cvt.scalef32.sr.bf8.f32", [], [], [Pure], 1>,
Arguments<(ins I32:$old, F32:$src, I32:$seed, F32: $scale, I32:$byteSel)> {
let summary = "Scale and convert f32 to bf8 using stochastic rounding";
@@ -978,6 +1004,29 @@ def ROCDL_CvtScaleF32Fp8Op :
}];
}
+def ROCDL_CvtPkF32Fp8Op :
+ ROCDL_IntrOp<"cvt.pk.f32.fp8", [], [], [Pure], 1>,
+ Arguments<(ins I32:$src, I1:$wordSel)> {
+ let summary = "Convert packed fp8 to packed f32";
+ let description = [{
+ Convert `src` based on $wordSel to packed fp32.
+ }];
+ let assemblyFormat = [{
+ attr-dict $src `[` $wordSel `]` `:` type($res)
+ }];
+}
+
+def ROCDL_CvtPkF32Bf8Op :
+ ROCDL_IntrOp<"cvt.pk.f32.bf8", [], [], [Pure], 1>,
+ Arguments<(ins I32:$src, I1:$wordSel)> {
+ let summary = "Convert packed bf8 to packed f32";
+ let description = [{
+ Convert `src` based on $wordSel to packed fp32.
+ }];
+ let assemblyFormat = [{
+ attr-dict $src `[` $wordSel `]` `:` type($res)
+ }];
+}
def ROCDL_CvtPkBf8F32Op :
ROCDL_IntrOp<"cvt.pk.bf8.f32", [], [], [Pure], 1>,
diff --git a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
index 949424db7c4d6..768d21384412d 100644
--- a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
+++ b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
@@ -977,13 +977,13 @@ LogicalResult ExtPackedFp8OpLowering::matchAndRewrite(
source = longVec;
}
Value i32Source = rewriter.create<LLVM::BitcastOp>(loc, i32, source);
- Value wordSel = createI32Constant(rewriter, loc, op.getIndex());
+ Value wordSel = createI1Constant(rewriter, loc, op.getWordIndex());
if (typeIsExpectedBf8ForChipset(chipset, sourceElemType)) {
- rewriter.replaceOpWithNewOp<ROCDL::CvtF32Bf8Op>(op, f32, i32Source,
- wordSel);
+ rewriter.replaceOpWithNewOp<ROCDL::CvtPkF32Bf8Op>(op, f32, i32Source,
+ wordSel);
} else if (typeIsExpectedFp8ForChipset(chipset, sourceElemType)) {
- rewriter.replaceOpWithNewOp<ROCDL::CvtF32Fp8Op>(op, f32, i32Source,
- wordSel);
+ rewriter.replaceOpWithNewOp<ROCDL::CvtPkF32Fp8Op>(op, f32, i32Source,
+ wordSel);
}
return success();
}
diff --git a/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp b/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp
index 27be54728c1a1..f9b685d1e90f6 100644
--- a/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp
+++ b/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp
@@ -83,14 +83,15 @@ static bool isSupportedF8(Type elementType, Chipset chipset) {
return false;
}
-static Value castF32To(Type elementType, Value f32, Location loc,
+static Value castF32To(Type desType, Value f32, Location loc,
PatternRewriter &rewriter) {
+ Type elementType = getElementTypeOrSelf(desType);
if (elementType.isF32())
return f32;
if (elementType.getIntOrFloatBitWidth() < 32)
- return rewriter.create<arith::TruncFOp>(loc, elementType, f32);
+ return rewriter.create<arith::TruncFOp>(loc, desType, f32);
if (elementType.getIntOrFloatBitWidth() > 32)
- return rewriter.create<arith::ExtFOp>(loc, elementType, f32);
+ return rewriter.create<arith::ExtFOp>(loc, desType, f32);
llvm_unreachable("The only 32-bit float type is f32");
}
@@ -110,10 +111,12 @@ ExtFOnFloat8RewritePattern::matchAndRewrite(arith::ExtFOp op,
Location loc = op.getLoc();
Value in = op.getIn();
Type outElemType = getElementTypeOrSelf(op.getOut().getType());
+ VectorType extResType = VectorType::get(2, rewriter.getF32Type());
if (!inVecType) {
- Value asFloat = rewriter.create<amdgpu::ExtPackedFp8Op>(
- loc, rewriter.getF32Type(), in, 0);
- Value result = castF32To(outElemType, asFloat, loc, rewriter);
+ Value asFloats =
+ rewriter.create<amdgpu::ExtPackedFp8Op>(loc, extResType, in, 0);
+ Value resFloat = rewriter.create<vector::ExtractOp>(loc, asFloats, 0);
+ Value result = castF32To(outElemType, resFloat, loc, rewriter);
rewriter.replaceOp(op, result);
return success();
}
@@ -150,11 +153,18 @@ ExtFOnFloat8RewritePattern::matchAndRewrite(arith::ExtFOp op,
int64_t elemsThisOp = std::min(numElements, i + 4) - i;
Value inSlice = rewriter.create<vector::ExtractStridedSliceOp>(
loc, in, i, elemsThisOp, 1);
- for (int64_t j = 0; j < elemsThisOp; ++j) {
- Value asFloat = rewriter.create<amdgpu::ExtPackedFp8Op>(
- loc, rewriter.getF32Type(), inSlice, j);
- Value asType = castF32To(outElemType, asFloat, loc, rewriter);
- result = rewriter.create<vector::InsertOp>(loc, asType, result, i + j);
+ Type desType = VectorType::get(2, outElemType); // Loop-invariant result type.
+ for (int64_t j = 0; j < elemsThisOp; j += 2) {
+ Value asFloats = rewriter.create<amdgpu::ExtPackedFp8Op>(loc, extResType,
+ inSlice, j / 2);
+ Value asType = castF32To(desType, asFloats, loc, rewriter);
+ if (j + 1 < elemsThisOp) { // Both converted values belong to the input.
+ result = rewriter.create<vector::InsertStridedSliceOp>(
+ loc, asType, result, i + j, 1);
+ } else { // Odd-length tail: only the first converted value is real.
+ asType = rewriter.create<vector::ExtractOp>(loc, asType, 0);
+ result = rewriter.create<vector::InsertOp>(loc, asType, result, i + j);
+ }
+ }
}
diff --git a/mlir/test/Conversion/AMDGPUToROCDL/8-bit-floats-ocp.mlir b/mlir/test/Conversion/AMDGPUToROCDL/8-bit-floats-ocp.mlir
index 70775a603e54d..0fb03ff13b558 100644
--- a/mlir/test/Conversion/AMDGPUToROCDL/8-bit-floats-ocp.mlir
+++ b/mlir/test/Conversion/AMDGPUToROCDL/8-bit-floats-ocp.mlir
@@ -7,12 +7,12 @@
// CHECK-DAG: [[C0_1:%.+]] = llvm.mlir.constant(0 : i32) : i32
// CHECK: [[VEC:%.+]] = llvm.insertelement [[V]], [[UNDEF]]{{\[}}[[C0_1]] : i32] : vector<4xi8>
// CHECK: [[CAST:%.+]] = llvm.bitcast [[VEC]] : vector<4xi8> to i32
-// CHECK: [[C0_2:%.+]] = llvm.mlir.constant(0 : i32) : i32
-// CHECK: [[EXT:%.+]] = rocdl.cvt.f32.bf8 [[CAST]]{{\[}}[[C0_2]]] : f32
+// CHECK: [[C0_2:%.+]] = llvm.mlir.constant(false) : i1
+// CHECK: [[EXT:%.+]] = rocdl.cvt.pk.f32.bf8 [[CAST]]{{\[}}[[C0_2]]] : vector<2xf32>
// CHECK: return [[EXT]]
-func.func @ext_scalar(%v: f8E5M2) -> f32 {
- %ret = amdgpu.ext_packed_fp8 %v[0] : f8E5M2 to f32
- func.return %ret : f32
+func.func @ext_scalar(%v: f8E5M2) -> vector<2xf32> {
+ %ret = amdgpu.ext_packed_fp8 %v[0] : f8E5M2 to vector<2xf32>
+ func.return %ret : vector<2xf32>
}
// CHECK-LABEL: func @ext_short_vec
@@ -25,24 +25,24 @@ func.func @ext_scalar(%v: f8E5M2) -> f32 {
// CHECK: [[ELEM_1:%.+]] = llvm.extractelement [[V]]{{\[}}[[C1_1]] : i32] : vector<2xi8>
// CHECK: [[VEC_1:%.+]] = llvm.insertelement [[ELEM_1]], [[VEC_0]]{{\[}}[[C1_1]] : i32] : vector<4xi8>
// CHECK: [[CAST:%.+]] = llvm.bitcast [[VEC_1]] : vector<4xi8> to i32
-// CHECK: [[C1_2:%.+]] = llvm.mlir.constant(1 : i32) : i32
-// CHECK: [[EXT:%.+]] = rocdl.cvt.f32.fp8 [[CAST]]{{\[}}[[C1_2]]] : f32
+// CHECK: [[C1_2:%.+]] = llvm.mlir.constant(false) : i1
+// CHECK: [[EXT:%.+]] = rocdl.cvt.pk.f32.fp8 [[CAST]]{{\[}}[[C1_2]]] : vector<2xf32>
// CHECK: return [[EXT]]
-func.func @ext_short_vec(%v: vector<2xf8E4M3FN>) -> f32 {
- %ret = amdgpu.ext_packed_fp8 %v[1] : vector<2xf8E4M3FN> to f32
- func.return %ret : f32
+func.func @ext_short_vec(%v: vector<2xf8E4M3FN>) -> vector<2xf32> {
+ %ret = amdgpu.ext_packed_fp8 %v[0] : vector<2xf8E4M3FN> to vector<2xf32>
+ func.return %ret : vector<2xf32>
}
// CHECK-LABEL: func @ext_full_vec(
// CHECK: [[V:%.+]] = builtin.unrealized_conversion_cast %{{.+}} : vector<4xf8E4M3FN> to vector<4xi8>
// CHECK: [[CAST:%.+]] = llvm.bitcast [[V]] : vector<4xi8> to i32
-// CHECK: [[C3:%.+]] = llvm.mlir.constant(3 : i32) : i32
-// CHECK: [[EXT:%.+]] = rocdl.cvt.f32.fp8 [[CAST]]{{\[}}[[C3]]] : f32
-// CHECK: return [[EXT]] : f32
+// CHECK: [[C3:%.+]] = llvm.mlir.constant(true) : i1
+// CHECK: [[EXT:%.+]] = rocdl.cvt.pk.f32.fp8 [[CAST]]{{\[}}[[C3]]] : vector<2xf32>
+// CHECK: return [[EXT]] : vector<2xf32>
-func.func @ext_full_vec(%v: vector<4xf8E4M3FN>) -> f32 {
- %ret = amdgpu.ext_packed_fp8 %v[3] : vector<4xf8E4M3FN> to f32
- func.return %ret : f32
+func.func @ext_full_vec(%v: vector<4xf8E4M3FN>) -> vector<2xf32> {
+ %ret = amdgpu.ext_packed_fp8 %v[1] : vector<4xf8E4M3FN> to vector<2xf32>
+ func.return %ret : vector<2xf32>
}
// CHECK-LABEL: func @packed_trunc
diff --git a/mlir/test/Conversion/AMDGPUToROCDL/8-bit-floats.mlir b/mlir/test/Conversion/AMDGPUToROCDL/8-bit-floats.mlir
index a313aaffdf5cc..0a4a960d59ce8 100644
--- a/mlir/test/Conversion/AMDGPUToROCDL/8-bit-floats.mlir
+++ b/mlir/test/Conversion/AMDGPUToROCDL/8-bit-floats.mlir
@@ -6,12 +6,12 @@
// CHECK-DAG: [[C0_1:%.+]] = llvm.mlir.constant(0 : i32) : i32
// CHECK: [[VEC:%.+]] = llvm.insertelement [[V]], [[UNDEF]]{{\[}}[[C0_1]] : i32] : vector<4xi8>
// CHECK: [[CAST:%.+]] = llvm.bitcast [[VEC]] : vector<4xi8> to i32
-// CHECK: [[C0_2:%.+]] = llvm.mlir.constant(0 : i32) : i32
-// CHECK: [[EXT:%.+]] = rocdl.cvt.f32.bf8 [[CAST]]{{\[}}[[C0_2]]] : f32
+// CHECK: [[C0_2:%.+]] = llvm.mlir.constant(false) : i1
+// CHECK: [[EXT:%.+]] = rocdl.cvt.pk.f32.bf8 [[CAST]]{{\[}}[[C0_2]]] : vector<2xf32>
// CHECK: return [[EXT]]
-func.func @ext_scalar(%v: f8E5M2FNUZ) -> f32 {
- %ret = amdgpu.ext_packed_fp8 %v[0] : f8E5M2FNUZ to f32
- func.return %ret : f32
+func.func @ext_scalar(%v: f8E5M2FNUZ) -> vector<2xf32> {
+ %ret = amdgpu.ext_packed_fp8 %v[0] : f8E5M2FNUZ to vector<2xf32>
+ func.return %ret : vector<2xf32>
}
// CHECK-LABEL: func @ext_short_vec
@@ -24,24 +24,24 @@ func.func @ext_scalar(%v: f8E5M2FNUZ) -> f32 {
// CHECK: [[ELEM_1:%.+]] = llvm.extractelement [[V]]{{\[}}[[C1_1]] : i32] : vector<2xi8>
// CHECK: [[VEC_1:%.+]] = llvm.insertelement [[ELEM_1]], [[VEC_0]]{{\[}}[[C1_1]] : i32] : vector<4xi8>
// CHECK: [[CAST:%.+]] = llvm.bitcast [[VEC_1]] : vector<4xi8> to i32
-// CHECK: [[C1_2:%.+]] = llvm.mlir.constant(1 : i32) : i32
-// CHECK: [[EXT:%.+]] = rocdl.cvt.f32.fp8 [[CAST]]{{\[}}[[C1_2]]] : f32
+// CHECK: [[C1_2:%.+]] = llvm.mlir.constant(false) : i1
+// CHECK: [[EXT:%.+]] = rocdl.cvt.pk.f32.fp8 [[CAST]]{{\[}}[[C1_2]]] : vector<2xf32>
// CHECK: return [[EXT]]
-func.func @ext_short_vec(%v: vector<2xf8E4M3FNUZ>) -> f32 {
- %ret = amdgpu.ext_packed_fp8 %v[1] : vector<2xf8E4M3FNUZ> to f32
- func.return %ret : f32
+func.func @ext_short_vec(%v: vector<2xf8E4M3FNUZ>) -> vector<2xf32> {
+ %ret = amdgpu.ext_packed_fp8 %v[0] : vector<2xf8E4M3FNUZ> to vector<2xf32>
+ func.return %ret : vector<2xf32>
}
// CHECK-LABEL: func @ext_full_vec(
// CHECK: [[V:%.+]] = builtin.unrealized_conversion_cast %{{.+}} : vector<4xf8E4M3FNUZ> to vector<4xi8>
// CHECK: [[CAST:%.+]] = llvm.bitcast [[V]] : vector<4xi8> to i32
-// CHECK: [[C3:%.+]] = llvm.mlir.constant(3 : i32) : i32
-// CHECK: [[EXT:%.+]] = rocdl.cvt.f32.fp8 [[CAST]]{{\[}}[[C3]]] : f32
-// CHECK: return [[EXT]] : f32
+// CHECK: [[C3:%.+]] = llvm.mlir.constant(true) : i1
+// CHECK: [[EXT:%.+]] = rocdl.cvt.pk.f32.fp8 [[CAST]]{{\[}}[[C3]]] : vector<2xf32>
+// CHECK: return [[EXT]] : vector<2xf32>
-func.func @ext_full_vec(%v: vector<4xf8E4M3FNUZ>) -> f32 {
- %ret = amdgpu.ext_packed_fp8 %v[3] : vector<4xf8E4M3FNUZ> to f32
- func.return %ret : f32
+func.func @ext_full_vec(%v: vector<4xf8E4M3FNUZ>) -> vector<2xf32> {
+ %ret = amdgpu.ext_packed_fp8 %v[1] : vector<4xf8E4M3FNUZ> to vector<2xf32>
+ func.return %ret : vector<2xf32>
}
// CHECK-LABEL: func @packed_trunc
diff --git a/mlir/test/Conversion/ArithToAMDGPU/8-bit-floats-ocp.mlir b/mlir/test/Conversion/ArithToAMDGPU/8-bit-floats-ocp.mlir
index 0e7f58c9e6749..b75b69c1b5d27 100644
--- a/mlir/test/Conversion/ArithToAMDGPU/8-bit-floats-ocp.mlir
+++ b/mlir/test/Conversion/ArithToAMDGPU/8-bit-floats-ocp.mlir
@@ -1,10 +1,11 @@
// RUN: mlir-opt --split-input-file %s -convert-arith-to-amdgpu="chipset=gfx950" | FileCheck %s
// RUN: mlir-opt --split-input-file %s -convert-arith-to-amdgpu="chipset=gfx1200" | FileCheck %s
-
+
// CHECK-LABEL: func.func @scalar_ext
// CHECK-SAME: ([[V:%.+]]: f8E5M2)
-// CHECK: [[FLOAT:%.+]] = amdgpu.ext_packed_fp8 [[V]][0] : f8E5M2 to f32
-// CHECK: [[W:%.+]] = arith.truncf [[FLOAT]] : f32 to f16
+// CHECK: [[FLOAT:%.+]] = amdgpu.ext_packed_fp8 [[V]][0] : f8E5M2 to vector<2xf32>
+// CHECK: [[EXT:%.+]] = vector.extract [[FLOAT]][0] : f32 from vector<2xf32>
+// CHECK: [[W:%.+]] = arith.truncf [[EXT]] : f32 to f16
// CHECK: return [[W]]
func.func @scalar_ext(%v: f8E5M2) -> f16 {
%w = arith.extf %v : f8E5M2 to f16
@@ -17,14 +18,9 @@ func.func @scalar_ext(%v: f8E5M2) -> f16 {
// CHECK-LABEL: func.func @vector_ext_short
// CHECK-SAME: ([[V:%.+]]: vector<2xf8E5M2>)
-// CHECK-DAG: [[ZEROES:%.+]] = arith.constant dense<0.000000e+00> : vector<2xf64>
-// CHECK: [[FLOAT0:%.+]] = amdgpu.ext_packed_fp8 [[V]][0] : vector<2xf8E5M2> to f32
-// CHECK: [[EXT0:%.+]] = arith.extf [[FLOAT0]] : f32 to f64
-// CHECK: [[W0:%.+]] = vector.insert [[EXT0]], [[ZEROES]] [0]
-// CHECK: [[FLOAT1:%.+]] = amdgpu.ext_packed_fp8 [[V]][1] : vector<2xf8E5M2> to f32
-// CHECK: [[EXT1:%.+]] = arith.extf [[FLOAT1]]
-// CHECK: [[W1:%.+]] = vector.insert [[EXT1]], [[W0]] [1]
-// CHECK: return [[W1]] : vector<2xf64>
+// CHECK: [[FLOAT0:%.+]] = amdgpu.ext_packed_fp8 [[V]][0] : vector<2xf8E5M2> to vector<2xf32>
+// CHECK: [[EXT:%.+]] = arith.extf [[FLOAT0]] : vector<2xf32> to vector<2xf64>
+// CHECK: return [[EXT]] : vector<2xf64>
func.func @vector_ext_short(%v: vector<2xf8E5M2>) -> vector<2xf64> {
%w = arith.extf %v : vector<2xf8E5M2> to vector<2xf64>
@@ -35,30 +31,22 @@ func.func @vector_ext_short(%v: vector<2xf8E5M2>) -> vector<2xf64> {
// CHECK-LABEL: func.func @vector_ext_long
// CHECK-SAME: ([[V:%.+]]: vector<9xf8E4M3FN>)
-// CHECK: [[V0:%.+]] = vector.extract_strided_slice [[V]] {offsets = [0], sizes = [4], strides = [1]}
-// CHECK: [[F0:%.+]] = amdgpu.ext_packed_fp8 [[V0]][0]
-// CHECK: [[W0:%.+]] = vector.insert [[F0]]
-// CHECK: [[F1:%.+]] = amdgpu.ext_packed_fp8 [[V0]][1]
-// CHECK: [[W1:%.+]] = vector.insert [[F1]], [[W0]]
-// CHECK: [[F2:%.+]] = amdgpu.ext_packed_fp8 [[V0]][2]
-// CHECK: [[W2:%.+]] = vector.insert [[F2]], [[W1]]
-// CHECK: [[F3:%.+]] = amdgpu.ext_packed_fp8 [[V0]][3]
-// CHECK: [[W3:%.+]] = vector.insert [[F3]], [[W2]]
-
-// CHECK: [[V1:%.+]] = vector.extract_strided_slice [[V]] {offsets = [4], sizes = [4], strides = [1]} : vector<9xf8E4M3FN> to vector<4xf8E4M3FN>
-// CHECK: [[F4:%.+]] = amdgpu.ext_packed_fp8 [[V1]][0]
-// CHECK: [[W4:%.+]] = vector.insert [[F4]], [[W3]]
-// CHECK: [[F5:%.+]] = amdgpu.ext_packed_fp8 [[V1]][1]
-// CHECK: [[W5:%.+]] = vector.insert [[F5]], [[W4]]
-// CHECK: [[F6:%.+]] = amdgpu.ext_packed_fp8 [[V1]][2]
-// CHECK: [[W6:%.+]] = vector.insert [[F6]], [[W5]]
-// CHECK: [[F7:%.+]] = amdgpu.ext_packed_fp8 [[V1]][3]
-// CHECK: [[W7:%.+]] = vector.insert [[F7]], [[W6]]
-
-// CHECK: [[V2:%.+]] = vector.extract_strided_slice [[V]] {offsets = [8], sizes = [1], strides = [1]} : vector<9xf8E4M3FN> to vector<1xf8E4M3FN>
-// CHECK: [[F8:%.+]] = amdgpu.ext_packed_fp8 [[V2]][0]
-// CHECK: [[W8:%.+]] = vector.insert [[F8]], [[W7]]
-// CHECK: return [[W8]]
+// CHECK: [[W0:%.+]] = arith.constant dense<0.000000e+00> : vector<9xf32>
+// CHECK: [[IN1:%.+]] = vector.extract_strided_slice [[V]] {offsets = [0], sizes = [4], strides = [1]} : vector<9xf8E4M3FN> to vector<4xf8E4M3FN>
+// CHECK: [[FLOAT1:%.+]] = amdgpu.ext_packed_fp8 [[IN1]][0] : vector<4xf8E4M3FN> to vector<2xf32>
+// CHECK: [[W1:%.+]] = vector.insert_strided_slice [[FLOAT1]], [[W0]] {offsets = [0], strides = [1]} : vector<2xf32> into vector<9xf32>
+// CHECK: [[FLOAT2:%.+]] = amdgpu.ext_packed_fp8 [[IN1]][1] : vector<4xf8E4M3FN> to vector<2xf32>
+// CHECK: [[W2:%.+]] = vector.insert_strided_slice [[FLOAT2]], [[W1]] {offsets = [2], strides = [1]} : vector<2xf32> into vector<9xf32>
+// CHECK: [[IN2:%.+]] = vector.extract_strided_slice [[V]] {offsets = [4], sizes = [4], strides = [1]} : vector<9xf8E4M3FN> to vector<4xf8E4M3FN>
+// CHECK: [[FLOAT3:%.+]] = amdgpu.ext_packed_fp8 [[IN2]][0] : vector<4xf8E4M3FN> to vector<2xf32>
+// CHECK: [[W3:%.+]] = vector.insert_strided_slice [[FLOAT3]], [[W2]] {offsets = [4], strides = [1]} : vector<2xf32> into vector<9xf32>
+// CHECK: [[FLOAT4:%.+]] = amdgpu.ext_packed_fp8 [[IN2]][1] : vector<4xf8E4M3FN> to vector<2xf32>
+// CHECK: [[W4:%.+]] = vector.insert_strided_slice [[FLOAT4]], [[W3]] {offsets = [6], strides = [1]} : vector<2xf32> into vector<9xf32>
+// CHECK: [[IN3:%.+]] = vector.extract_strided_slice [[V]] {offsets = [8], sizes = [1], strides = [1]} : vector<9xf8E4M3FN> to vector<1xf8E4M3FN>
+// CHECK: [[FLOAT5:%.+]] = amdgpu.ext_packed_fp8 [[IN3]][0] : vector<1xf8E4M3FN> to vector<2xf32>
+// CHECK: [[FLOAT6:%.+]] = vector.extract [[FLOAT5]][0] : f32 from vector<2xf32>
+// CHECK: [[W5:%.+]] = vector.insert [[FLOAT6]], [[W4]] [8] : f32 into vector<9xf32>
+// CHECK: return [[W5]]
func.func @vector_ext_long(%v: vector<9xf8E4M3FN>) -> vector<9xf32> {
%w = arith.extf %v : vector<9xf8E4M3FN> to vector<9xf32>
return %w : vector<9xf32>
@@ -144,31 +132,25 @@ func.func @vector_trunc_long_2d(%v: vector<1x9xf32>) -> vector<1x9xf8E4M3FN> {
// CHECK-LABEL: func.func @vector_ext_long_2d
// CHECK-SAME: ([[V:%.+]]: vector<1x9xf8E4M3FN>)
+// CHECK: [[CST:%.+]] = arith.constant dense<0.000000e+00> : vector<9xf32>
// CHECK: [[CAST:%.+]] = vector.shape_cast [[V]] : vector<1x9xf8E4M3FN> to vector<9xf8E4M3FN>
// CHECK: [[V0:%.+]] = vector.extract_strided_slice [[CAST]] {offsets = [0], sizes = [4], strides = [1]}
-// CHECK: [[F0:%.+]] = amdgpu.ext_packed_fp8 [[V0]][0]
-// CHECK: [[W0:%.+]] = vector.insert [[F0]]
-// CHECK: [[F1:%.+]] = amdgpu.ext_packed_fp8 [[V0]][1]
-// CHECK: [[W1:%.+]] = vector.insert [[F1]], [[W0]]
-// CHECK: [[F2:%.+]] = amdgpu.ext_packed_fp8 [[V0]][2]
-// CHECK: [[W2:%.+]] = vector.insert [[F2]], [[W1]]
-// CHECK: [[F3:%.+]] = amdgpu.ext_packed_fp8 [[V0]][3]
-// CHECK: [[W3:%.+]] = vector.insert [[F3]], [[W2]]
+// CHECK: [[F0:%.+]] = amdgpu.ext_packed_fp8 [[V0]][0] : vector<4xf8E4M3FN> to vector<2xf32>
+// CHECK: [[W0:%.+]] = vector.insert_strided_slice [[F0]], [[CST]] {offsets = [0], strides = [1]} : vector<2xf32> into vector<9xf32>
+// CHECK: [[F1:%.+]] = amdgpu.ext_packed_fp8 [[V0]][1] : vector<4xf8E4M3FN> to vector<2xf32>
+// CHECK: [[W1:%.+]] = vector.insert_strided_slice [[F1]], [[W0]] {offsets = [2], strides = [1]} : vector<2xf32> into vector<9xf32>
// CHECK: [[V1:%.+]] = vector.extract_strided_slice [[CAST]] {offsets = [4], sizes = [4], strides = [1]} : vector<9xf8E4M3FN> to vector<4xf8E4M3FN>
-// CHECK: [[F4:%.+]] = amdgpu.ext_packed_fp8 [[V1]][0]
-// CHECK: [[W4:%.+]] = vector.insert [[F4]], [[W3]]
-// CHECK: [[F5:%.+]] = amdgpu.ext_packed_fp8 [[V1]][1]
-// CHECK: [[W5:%.+]] = vector.insert [[F5]], [[W4]]
-// CHECK: [[F6:%.+]] = amdgpu.ext_packed_fp8 [[V1]][2]
-// CHECK: [[W6:%.+]] = vector.insert [[F6]], [[W5]]
-// CHECK: [[F7:%.+]] = amdgpu.ext_packed_fp8 [[V1]][3]
-// CHECK: [[W7:%.+]] = vector.insert [[F7]], [[W6]]
+// CHECK: [[F2:%.+]] = amdgpu.ext_packed_fp8 [[V1]][0] : vector<4xf8E4M3FN> to vector<2xf32>
+// CHECK: [[W2:%.+]] = vector.insert_strided_slice [[F2]], [[W1]] {offsets = [4], strides = [1]} : vector<2xf32> into vector<9xf32>
+// CHECK: [[F3:%.+]] = amdgpu.ext_packed_fp8 [[V1]][1] : vector<4xf8E4M3FN> to vector<2xf32>
+// CHECK: [[W3:%.+]] = vector.insert_strided_slice [[F3]], [[W2]] {offsets = [6], strides = [1]} : vector<2xf32> into vector<9xf32>
// CHECK: [[V2:%.+]] = vector.extract_strided_slice [[CAST]] {offsets = [8], sizes = [1], strides = [1]} : vector<9xf8E4M3FN> to vector<1xf8E4M3FN>
-// CHECK: [[F8:%.+]] = amdgpu.ext_packed_fp8 [[V2]][0]
-// CHECK: [[W8:%.+]] = vector.insert [[F8]], [[W7]]
-// CHECK: [[CAST:%.+]] = vector.shape_cast [[W8]] : vector<9xf32> to vector<1x9xf32>
+// CHECK: [[F4:%.+]] = amdgpu.ext_packed_fp8 [[V2]][0] : vector<1xf8E4M3FN> to vector<2xf32>
+// CHECK: [[E0:%.+]] = vector.extract [[F4]][0] : f32 from vector<2xf32>
+// CHECK: [[W4:%.+]] = vector.insert [[E0]], [[W3]] [8] : f32 into vector<9xf32>
+// CHECK: [[CAST:%.+]] = vector.shape_cast [[W4]] : vector<9xf32> to vector<1x9xf32>
// CHECK: return [[CAST]]
func.func @vector_ext_long_2d(%v: vector<1x9xf8E4M3FN>) -> vector<1x9xf32> {
%w = arith.extf %v : vector<1x9xf8E4M3FN> to vector<1x9xf32>
diff --git a/mlir/test/Conversion/ArithToAMDGPU/8-bit-floats.mlir b/mlir/test/Conversion/ArithToAMDGPU/8-bit-floats.mlir
index 6bb5b9771c015..2ed3f47e8ab73 100644
--- a/mlir/test/Conversion/ArithToAMDGPU/8-bit-floats.mlir
+++ b/mlir/test/Conversion/ArithToAMDGPU/8-bit-floats.mlir
@@ -2,8 +2,9 @@
// CHECK-LABEL: func.func @scalar_ext
// CHECK-SAME: ([[V:%.+]]: f8E5M2FNUZ)
-// CHECK: [[FLOAT:%.+]] = amdgpu.ext_packed_fp8 [[V]][0] : f8E5M2FNUZ to f32
-// CHECK: [[W:%.+]] = arith.truncf [[FLOAT]] : f32 to f16
+// CHECK: [[FLOAT:%.+]] = amdgpu.ext_packed_fp8 [[V]][0] : f8E5M2FNUZ to vector<2xf32>
+// CHECK: [[EXT:%.+]] = vector.extract [[FLOAT]][0] : f32 from vector<2xf32>
+// CHECK: [[W:%.+]] = arith.truncf [[EXT]] : f32 to f16
// CHECK: return [[W]]
func.func @scalar_ext(%v: f8E5M2FNUZ) -> f16 {
%w = arith.extf %v : f8E5M2FNUZ to f16
@@ -16,8 +17,9 @@ func.func @scalar_ext(%v: f8E5M2FNUZ) -> f16 {
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: vector<f8E5M2FNUZ>) -> vector<f32>
// CHECK: %[[CONST:.+]] = arith.constant dense<0.000000e+00> : vector<f32>
// CHECK: %[[EXTRACT:.+]] = vector.extract %[[ARG0]][] : f8E5M2FNUZ from vector<f8E5M2FNUZ>
-// CHECK: %[[CONVERT:.+]] = amdgpu.ext_packed_fp8 %[[EXTRACT]][0] : f8E5M2FNUZ to f32
-// CHECK: %[[RESULT:.+]] = vector.insert %[[CONVERT]], %[[CONST]] [] : f32 into vector<f32>
+// CHECK: %[[CONVERT:.+]] = amdgpu.ext_packed_fp8 %[[EXTRACT]][0] : f8E5M2FNUZ to vector<2xf32>
+// CHECK: %[[EXTRACT2:.+]] = vector.extract %[[CONVERT]][0] : f32 from vector<2xf32>
+// CHECK: %[[RESULT:.+]] = vector.insert %[[EXTRACT2]], %[[CONST]] [] : f32 into vector<f32>
// CHECK: return %[[RESULT]] : vector<f32>
func.func @vector_zero_d(%v: vector<f8E5M2FNUZ>) -> vector<f32> {
%w = arith.extf %v : vector<f8E5M2FNUZ> to vector<f32>
@@ -28,15 +30,9 @@ func.func @vector_zero_d(%v: vector<f8E5M2FNUZ>) -> vector<f32> {
// CHECK-LABEL: func.func @vector_ext_short
// CHECK-SAME: ([[V:%.+]]: vector<2xf8E5M2FNUZ>)
-// CHECK-DAG: [[ZEROES:%.+]] = arith.constant dense<0.000000e+00> : vector<2xf64>
-// CHECK: [[FLOAT0:%.+]] = amdgpu.ext_packed_fp8 [[V]][0] : vector<2xf8E5M2FNUZ> to f32
-// CHECK: [[EXT0:%.+]] = arith.extf [[FLOAT0]] : f32 to f64
-// CHECK: [[W0:%.+]] = vector.insert [[EXT0]], [[ZEROES]] [0]
-// CHECK: [[FLOAT1:%.+]] = amdgpu.ext_packed_fp8 [[V]][1] : vector<2xf8E5M2FNUZ> to f32
-// CHECK: [[EXT1:%.+]] = arith.extf [[FLOAT1]]
-// CHECK: [[W1:%.+]] = vector.insert [[EXT1]], [[W0]] [1]
-// CHECK: return [[W1]] : vector<2xf64>
-
+// CHECK: [[FLOAT:%.+]] = amdgpu.ext_packed_fp8 [[V]][0] : vector<2xf8E5M2FNUZ> to vector<2xf32>
+// CHECK: [[EXT:%.+]] = arith.extf [[FLOAT]] : vector<2xf32> to vector<2xf64>
+// CHECK: return [[EXT]] : vector<2xf64>
func.func @vector_ext_short(%v: vector<2xf8E5M2FNUZ>) -> vector<2xf64> {
%w = arith.extf %v : vector<2xf8E5M2FNUZ> to vector<2xf64>
return %w : vector<2xf64>
@@ -46,30 +42,22 @@ func.func @vector_ext_short(%v: vector<2xf8E5M2FNUZ>) -> vector<2xf64> {
// CHECK-LABEL: func.func @vector_ext_long
// CHECK-SAME: ([[V:%.+]]: vector<9xf8E4M3FNUZ>)
-// CHECK: [[V0:%.+]] = vector.extract_strided_slice [[V]] {offsets = [0], sizes = [4], strides = [1]}
-// CHECK: [[F0:%.+]] = amdgpu.ext_packed_fp8 [[V0]][0]
-// CHECK: [[W0:%.+]] = vector.insert [[F0]]
-// CHECK: [[F1:%.+]] = amdgpu.ext_packed_fp8 [[V0]][1]
-// CHECK: [[W1:%.+]] = vector.insert [[F1]], [[W0]]
-// CHECK: [[F2:%.+]] = amdgpu.ext_packed_fp8 [[V0]][2]
-// CHECK: [[W2:%.+]] = vector.insert [[F2]], [[W1]]
-// CHECK: [[F3:%.+]] = amdgpu.ext_packed_fp8 [[V0]][3]
-// CHECK: [[W3:%.+]] = vector.insert [[F3]], [[W2]]
-
-// CHECK: [[V1:%.+]] = vector.extract_strided_slice [[V]] {offsets = [4], sizes = [4], strides = [1]} : vector<9xf8E4M3FNUZ> to vector<4xf8E4M3FNUZ>
-// CHECK: [[F4:%.+]] = amdgpu.ext_packed_fp8 [[V1]][0]
-// CHECK: [[W4:%.+]] = vector.insert [[F4]], [[W3]]
-// CHECK: [[F5:%.+]] = amdgpu.ext_packed_fp8 [[V1]][1]
-// CHECK: [[W5:%.+]] = vector.insert [[F5]], [[W4]]
-// CHECK: [[F6:%.+]] = amdgpu.ext_packed_fp8 [[V1]][2]
-// CHECK: [[W6:%.+]] = vector.insert [[F6]], [[W5]]
-// CHECK: [[F7:%.+]] = amdgpu.ext_packed_fp8 [[V1]][3]
-// CHECK: [[W7:%.+]] = vector.insert [[F7]], [[W6]]
-
-// CHECK: [[V2:%.+]] = vector.extract_strided_slice [[V]] {offsets = [8], sizes = [1], strides = [1]} : vector<9xf8E4M3FNUZ> to vector<1xf8E4M3FNUZ>
-// CHECK: [[F8:%.+]] = amdgpu.ext_packed_fp8 [[V2]][0]
-// CHECK: [[W8:%.+]] = vector.insert [[F8]], [[W7]]
-// CHECK: return [[W8]]
+// CHECK: [[W0:%.+]] = arith.constant dense<0.000000e+00> : vector<9xf32>
+// CHECK: [[IN1:%.+]] = vector.extract_strided_slice [[V]] {offsets = [0], sizes = [4], strides = [1]} : vector<9xf8E4M3FNUZ> to vector<4xf8E4M3FNUZ>
+// CHECK: [[FLOAT1:%.+]] = amdgpu.ext_packed_fp8 [[IN1]][0] : vector<4xf8E4M3FNUZ> to vector<2xf32>
+// CHECK: [[W1:%.+]] = vector.insert_strided_slice [[FLOAT1]], [[W0]] {offsets = [0], strides = [1]} : vector<2xf32> into vector<9xf32>
+// CHECK: [[FLOAT2:%.+]] = amdgpu.ext_packed_fp8 [[IN1]][1] : vector<4xf8E4M3FNUZ> to vector<2xf32>
+// CHECK: [[W2:%.+]] = vector.insert_strided_slice [[FLOAT2]], [[W1]] {offsets = [2], strides = [1]} : vector<2xf32> into vector<9xf32>
+// CHECK: [[IN2:%.+]] = vector.extract_strided_slice [[V]] {offsets = [4], sizes = [4], strides = [1]} : vector<9xf8E4M3FNUZ> to vector<4xf8E4M3FNUZ>
+// CHECK: [[FLOAT3:%.+]] = amdgpu.ext_packed_fp8 [[IN2]][0] : vector<4xf8E4M3FNUZ> to vector<2xf32>
+// CHECK: [[W3:%.+]] = vector.insert_strided_slice [[FLOAT3]], [[W2]] {offsets = [4], strides = [1]} : vector<2xf32> into vector<9xf32>
+// CHECK: [[FLOAT4:%.+]] = amdgpu.ext_packed_fp8 [[IN2]][1] : vector<4xf8E4M3FNUZ> to vector<2xf32>
+// CHECK: [[W4:%.+]] = vector.insert_strided_slice [[FLOAT4]], [[W3]] {offsets = [6], strides = [1]} : vector<2xf32> into vector<9xf32>
+// CHECK: [[IN3:%.+]] = vector.extract_strided_slice [[V]] {offsets = [8], sizes = [1], strides = [1]} : vector<9xf8E4M3FNUZ> to vector<1xf8E4M3FNUZ>
+// CHECK: [[FLOAT5:%.+]] = amdgpu.ext_packed_fp8 [[IN3]][0] : vector<1xf8E4M3FNUZ> to vector<2xf32>
+// CHECK: [[FLOAT6:%.+]] = vector.extract [[FLOAT5]][0] : f32 from vector<2xf32>
+// CHECK: [[W5:%.+]] = vector.insert [[FLOAT6]], [[W4]] [8] : f32 into vector<9xf32>
+// CHECK: return [[W5]]
func.func @vector_ext_long(%v: vector<9xf8E4M3FNUZ>) -> vector<9xf32> {
%w = arith.extf %v : vector<9xf8E4M3FNUZ> to vector<9xf32>
return %w : vector<9xf32>
@@ -155,31 +143,25 @@ func.func @vector_trunc_long_2d(%v: vector<1x9xf32>) -> vector<1x9xf8E4M3FNUZ> {
// CHECK-LABEL: func.func @vector_ext_long_2d
// CHECK-SAME: ([[V:%.+]]: vector<1x9xf8E4M3FNUZ>)
+// CHECK: [[CST:%.+]] = arith.constant dense<0.000000e+00> : vector<9xf32>
// CHECK: [[CAST:%.+]] = vector.shape_cast [[V]] : vector<1x9xf8E4M3FNUZ> to vector<9xf8E4M3FNUZ>
// CHECK: [[V0:%.+]] = vector.extract_strided_slice [[CAST]] {offsets = [0], sizes = [4], strides = [1]}
-// CHECK: [[F0:%.+]] = amdgpu.ext_packed_fp8 [[V0]][0]
-// CHECK: [[W0:%.+]] = vector.insert [[F0]]
-// CHECK: [[F1:%.+]] = amdgpu.ext_packed_fp8 [[V0]][1]
-// CHECK: [[W1:%.+]] = vector.insert [[F1]], [[W0]]
-// CHECK: [[F2:%.+]] = amdgpu.ext_packed_fp8 [[V0]][2]
-// CHECK: [[W2:%.+]] = vector.insert [[F2]], [[W1]]
-// CHECK: [[F3:%.+]] = amdgpu.ext_packed_fp8 [[V0]][3]
-// CHECK: [[W3:%.+]] = vector.insert [[F3]], [[W2]]
+// CHECK: [[F0:%.+]] = amdgpu.ext_packed_fp8 [[V0]][0] : vector<4xf8E4M3FNUZ> to vector<2xf32>
+// CHECK: [[W0:%.+]] = vector.insert_strided_slice [[F0]], [[CST]] {offsets = [0], strides = [1]} : vector<2xf32> into vector<9xf32>
+// CHECK: [[F1:%.+]] = amdgpu.ext_packed_fp8 [[V0]][1] : vector<4xf8E4M3FNUZ> to vector<2xf32>
+// CHECK: [[W1:%.+]] = vector.insert_strided_slice [[F1]], [[W0]] {offsets = [2], strides = [1]} : vector<2xf32> into vector<9xf32>
// CHECK: [[V1:%.+]] = vector.extract_strided_slice [[CAST]] {offsets = [4], sizes = [4], strides = [1]} : vector<9xf8E4M3FNUZ> to vector<4xf8E4M3FNUZ>
-// CHECK: [[F4:%.+]] = amdgpu.ext_packed_fp8 [[V1]][0]
-// CHECK: [[W4:%.+]] = vector.insert [[F4]], [[W3]]
-// CHECK: [[F5:%.+]] = amdgpu.ext_packed_fp8 [[V1]][1]
-// CHECK: [[W5:%.+]] = vector.insert [[F5]], [[W4]]
-// CHECK: [[F6:%.+]] = amdgpu.ext_packed_fp8 [[V1]][2]
-// CHECK: [[W6:%.+]] = vector.insert [[F6]], [[W5]]
-// CHECK: [[F7:%.+]] = amdgpu.ext_packed_fp8 [[V1]][3]
-// CHECK: [[W7:%.+]] = vector.insert [[F7]], [[W6]]
+// CHECK: [[F2:%.+]] = amdgpu.ext_packed_fp8 [[V1]][0] : vector<4xf8E4M3FNUZ> to vector<2xf32>
+// CHECK: [[W2:%.+]] = vector.insert_strided_slice [[F2]], [[W1]] {offsets = [4], strides = [1]} : vector<2xf32> into vector<9xf32>
+// CHECK: [[F3:%.+]] = amdgpu.ext_packed_fp8 [[V1]][1] : vector<4xf8E4M3FNUZ> to vector<2xf32>
+// CHECK: [[W3:%.+]] = vector.insert_strided_slice [[F3]], [[W2]] {offsets = [6], strides = [1]} : vector<2xf32> into vector<9xf32>
// CHECK: [[V2:%.+]] = vector.extract_strided_slice [[CAST]] {offsets = [8], sizes = [1], strides = [1]} : vector<9xf8E4M3FNUZ> to vector<1xf8E4M3FNUZ>
-// CHECK: [[F8:%.+]] = amdgpu.ext_packed_fp8 [[V2]][0]
-// CHECK: [[W8:%.+]] = vector.insert [[F8]], [[W7]]
-// CHECK: [[CAST:%.+]] = vector.shape_cast [[W8]] : vector<9xf32> to vector<1x9xf32>
+// CHECK: [[F4:%.+]] = amdgpu.ext_packed_fp8 [[V2]][0] : vector<1xf8E4M3FNUZ> to vector<2xf32>
+// CHECK: [[E0:%.+]] = vector.extract [[F4]][0] : f32 from vector<2xf32>
+// CHECK: [[W4:%.+]] = vector.insert [[E0]], [[W3]] [8] : f32 into vector<9xf32>
+// CHECK: [[CAST:%.+]] = vector.shape_cast [[W4]] : vector<9xf32> to vector<1x9xf32>
// CHECK: return [[CAST]]
func.func @vector_ext_long_2d(%v: vector<1x9xf8E4M3FNUZ>) -> vector<1x9xf32> {
%w = arith.extf %v : vector<1x9xf8E4M3FNUZ> to vector<1x9xf32>
diff --git a/mlir/test/Dialect/AMDGPU/ops.mlir b/mlir/test/Dialect/AMDGPU/ops.mlir
index 567e6498330a3..bf312ead32712 100644
--- a/mlir/test/Dialect/AMDGPU/ops.mlir
+++ b/mlir/test/Dialect/AMDGPU/ops.mlir
@@ -6,9 +6,9 @@
// CHECK-LABEL: func @ext_packed_fp8
// CHECK: amdgpu.ext_packed_fp8
-func.func @ext_packed_fp8(%v: vector<4xf8E4M3FNUZ>) -> f32 {
- %ret = amdgpu.ext_packed_fp8 %v[0] : vector<4xf8E4M3FNUZ> to f32
- func.return %ret : f32
+func.func @ext_packed_fp8(%v: vector<4xf8E4M3FNUZ>) -> vector<2xf32> {
+ %ret = amdgpu.ext_packed_fp8 %v[0] : vector<4xf8E4M3FNUZ> to vector<2xf32>
+ func.return %ret : vector<2xf32>
}
// CHECK-LABEL: func @packed_trunc_2xfp8
diff --git a/mlir/test/Dialect/LLVMIR/rocdl.mlir b/mlir/test/Dialect/LLVMIR/rocdl.mlir
index bc917041998d8..cce2c0aee62f3 100644
--- a/mlir/test/Dialect/LLVMIR/rocdl.mlir
+++ b/mlir/test/Dialect/LLVMIR/rocdl.mlir
@@ -767,10 +767,14 @@ llvm.func @rocdl_8bit_floats(%source: i32, %source_half: f16, %source_bfloat: bf
// CHECK: rocdl.cvt.scalef32.f32.fp8
// CHECK: rocdl.cvt.scalef32.pk.f16.bf8
// CHECK: rocdl.cvt.scalef32.pk.f16.fp8
+// CHECK: rocdl.cvt.scalef32.pk.bf16.bf8
+// CHECK: rocdl.cvt.scalef32.pk.bf16.fp8
// CHECK: rocdl.cvt.scalef32.f16.fp8
// CHECK: rocdl.cvt.scalef32.f16.bf8
// CHECK: rocdl.cvt.pk.bf8.f32
// CHECK: rocdl.cvt.pk.fp8.f32
+// CHECK: rocdl.cvt.pk.f32.bf8
+// CHECK: rocdl.cvt.pk.f32.fp8
// CHECK: rocdl.cvt.sr.bf8.f32
// CHECK: rocdl.cvt.sr.fp8.f32
// CHECK: rocdl.cvt.scalef32.sr.fp8.f32
@@ -793,10 +797,14 @@ llvm.func @rocdl_8bit_floats(%source: i32, %source_half: f16, %source_bfloat: bf
%v2_scaled = rocdl.cvt.scalef32.f32.fp8 %source[%c0], %c4 : f32
%v3_scaled = rocdl.cvt.scalef32.pk.f16.bf8 %source[%false], %c4 : vector<2xf16>
%v4_scaled = rocdl.cvt.scalef32.pk.f16.fp8 %source[%false], %c4 : vector<2xf16>
+ %v3_scaled_bf16 = rocdl.cvt.scalef32.pk.bf16.bf8 %source[%false], %c4 : vector<2xbf16>
+ %v4_scaled_bf16 = rocdl.cvt.scalef32.pk.bf16.fp8 %source[%false], %c4 : vector<2xbf16>
%v5 = rocdl.cvt.scalef32.f16.fp8 %source[%false], %c4 -> %v3_scaled[%c0] : f16
%v6 = rocdl.cvt.scalef32.f16.bf8 %source[%false], %c4 -> %v3_scaled[%c0] : f16
%source2 = rocdl.cvt.pk.bf8.f32 %v1, %v2 -> %source[%false] : i32
%source3 = rocdl.cvt.pk.fp8.f32 %v1, %v2 -> %source2[%false] : i32
+ %source2_ext = rocdl.cvt.pk.f32.bf8 %source[%false] : vector<2xf32>
+ %source3_ext = rocdl.cvt.pk.f32.fp8 %source[%false] : vector<2xf32>
%source4 = rocdl.cvt.sr.bf8.f32 %v1, %stoch -> %source3[%c2] : i32
%source5 = rocdl.cvt.sr.fp8.f32 %v2, %stoch -> %source4[%c3] : i32
%source5_scaled = rocdl.cvt.scalef32.sr.fp8.f32 %v2, %stoch, %c4 -> %source4[%c3] : i32
diff --git a/mlir/test/Target/LLVMIR/rocdl.mlir b/mlir/test/Target/LLVMIR/rocdl.mlir
index 11f2faa2761ff..e70617bfff99e 100644
--- a/mlir/test/Target/LLVMIR/rocdl.mlir
+++ b/mlir/test/Target/LLVMIR/rocdl.mlir
@@ -1042,6 +1042,8 @@ llvm.func @rocdl_8bit_floats(%source: i32, %source_half: f16, %source_bfloat: bf
// CHECK: call <2 x half> @llvm.amdgcn.cvt.scalef32.pk.f16.fp8(i32 %{{.+}}, float 1.000000e+00, i1 false)
// CHECK: call <2 x half> @llvm.amdgcn.cvt.scalef32.f16.fp8(<2 x half> %{{.+}}, i32 %{{.+}}, float 1.000000e+00, i32 0, i1 false)
// CHECK: call <2 x half> @llvm.amdgcn.cvt.scalef32.f16.bf8(<2 x half> %{{.+}}, i32 %{{.+}}, float 1.000000e+00, i32 0, i1 false)
+// CHECK: call <2 x bfloat> @llvm.amdgcn.cvt.scalef32.pk.bf16.bf8(i32 %{{.+}}, float 1.000000e+00, i1 false)
+// CHECK: call <2 x bfloat> @llvm.amdgcn.cvt.scalef32.pk.bf16.fp8(i32 %{{.+}}, float 1.000000e+00, i1 false)
// CHECK: call i32 @llvm.amdgcn.cvt.pk.bf8.f32(float %{{.+}}, float %{{.+}}, i32 %{{.+}}, i1 false)
// CHECK: call i32 @llvm.amdgcn.cvt.pk.fp8.f32(float %{{.+}}, float %{{.+}}, i32 %{{.+}}, i1 false)
// CHECK: call i32 @llvm.amdgcn.cvt.sr.bf8.f32(float %{{.+}}, i32 %{{.+}}, i32 %{{.+}}, i32 2)
@@ -1068,6 +1070,8 @@ llvm.func @rocdl_8bit_floats(%source: i32, %source_half: f16, %source_bfloat: bf
%v4_scaled = rocdl.cvt.scalef32.pk.f16.fp8 %source[%false], %c4 : i32
%v5 = rocdl.cvt.scalef32.f16.fp8 %source[%false], %c4 -> %source_packed[%c0] : f16
%v6 = rocdl.cvt.scalef32.f16.bf8 %source[%false], %c4 -> %source_packed[%c0] : f16
+ %v7 = rocdl.cvt.scalef32.pk.bf16.bf8 %source[%false], %c4 : i32
+ %v8 = rocdl.cvt.scalef32.pk.bf16.fp8 %source[%false], %c4 : i32
%source2 = rocdl.cvt.pk.bf8.f32 %v1, %v2 -> %source[%false] : i32
%source3 = rocdl.cvt.pk.fp8.f32 %v1, %v2 -> %source2[%false] : i32
%source4 = rocdl.cvt.sr.bf8.f32 %v1, %stoch -> %source3[%c2] : i32
>From ae25d7357b40f38caa599af78f41db97516005b7 Mon Sep 17 00:00:00 2001
From: Yi Qian <yi.qian at amd.com>
Date: Wed, 19 Mar 2025 04:45:36 +0000
Subject: [PATCH 2/2] Allow amdgpu.ext_packed_fp8 to return a scalar or vector
type
---
mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td | 11 ++--
.../AMDGPUToROCDL/AMDGPUToROCDL.cpp | 26 +++++++---
.../ArithToAMDGPU/ArithToAMDGPU.cpp | 23 +++++----
.../AMDGPUToROCDL/8-bit-floats-ocp.mlir | 48 ++++++++++++++----
.../AMDGPUToROCDL/8-bit-floats.mlir | 48 ++++++++++++++----
.../ArithToAMDGPU/8-bit-floats-ocp.mlir | 45 ++++++++---------
.../ArithToAMDGPU/8-bit-floats.mlir | 50 +++++++++----------
7 files changed, 159 insertions(+), 92 deletions(-)
diff --git a/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td b/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
index 3ed6e84d19044..c0b3e5540b1df 100644
--- a/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
+++ b/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
@@ -85,12 +85,13 @@ def AMDGPU_ExtPackedFp8Op :
AMDGPU_Op<"ext_packed_fp8", [Pure]>,
Arguments<(ins AnyTypeOf<[F8E5M2FNUZ, F8E4M3FNUZ, F8E5M2, F8E4M3FN,
VectorOfLengthAndType<[1, 2, 3, 4], [F8E5M2FNUZ, F8E4M3FNUZ, F8E5M2, F8E4M3FN]>]>:$source,
- ConfinedAttr<I32Attr, [IntNonNegative, IntMaxValue<1>]>:$wordIndex)>,
- Results<(outs FixedVectorOfLengthAndType<[2], [F32]>:$res)> {
- let summary = "Extend a vector of packed fp8 values to two floats";
+ ConfinedAttr<I32Attr, [IntNonNegative, IntMaxValue<3>]>:$index)>,
+ Results<(outs AnyTypeOf<[F32, FixedVectorOfLengthAndType<[2], [F32]>]>:$res)> {
+ let summary = "Extend one or two packed fp8 values to a float or a vector of two floats";
 let description = [{
- Extend the two 8-bit floats in `source[wordrIndex]` to two 32-bit floats and return them.
+ Extend the 8-bit float at byte `index` of `source` to a 32-bit float, or, for a
+ vector<2xf32> result, the two 8-bit floats in 16-bit word `index` (0 or 1) to two floats.
This rather unusual signature arises from the fact that AMD GPUs cannot
easily work with sub 32-bit quantities, so the compiler intrinsics for
@@ -102,7 +103,7 @@ def AMDGPU_ExtPackedFp8Op :
undefined values as needed.
}];
let assemblyFormat = [{
- attr-dict $source `[` $wordIndex `]` `:` type($source) `to` type($res)
+ attr-dict $source `[` $index `]` `:` type($source) `to` type($res)
}];
}
diff --git a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
index 768d21384412d..3acd470cff7f5 100644
--- a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
+++ b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
@@ -959,6 +959,7 @@ LogicalResult ExtPackedFp8OpLowering::matchAndRewrite(
Value source = adaptor.getSource();
auto sourceVecType = dyn_cast<VectorType>(op.getSource().getType());
+ auto resultVecType = dyn_cast<VectorType>(op.getResult().getType());
Type sourceElemType = getElementTypeOrSelf(op.getSource());
// Extend to a v4i8
if (!sourceVecType || sourceVecType.getNumElements() < 4) {
@@ -977,13 +978,24 @@ LogicalResult ExtPackedFp8OpLowering::matchAndRewrite(
source = longVec;
}
Value i32Source = rewriter.create<LLVM::BitcastOp>(loc, i32, source);
- Value wordSel = createI1Constant(rewriter, loc, op.getWordIndex());
- if (typeIsExpectedBf8ForChipset(chipset, sourceElemType)) {
- rewriter.replaceOpWithNewOp<ROCDL::CvtPkF32Bf8Op>(op, f32, i32Source,
- wordSel);
- } else if (typeIsExpectedFp8ForChipset(chipset, sourceElemType)) {
- rewriter.replaceOpWithNewOp<ROCDL::CvtPkF32Fp8Op>(op, f32, i32Source,
- wordSel);
+ if (resultVecType) {
+ Value wordSel = createI1Constant(rewriter, loc, op.getIndex());
+ if (typeIsExpectedBf8ForChipset(chipset, sourceElemType)) {
+ rewriter.replaceOpWithNewOp<ROCDL::CvtPkF32Bf8Op>(op, f32, i32Source,
+ wordSel);
+ } else if (typeIsExpectedFp8ForChipset(chipset, sourceElemType)) {
+ rewriter.replaceOpWithNewOp<ROCDL::CvtPkF32Fp8Op>(op, f32, i32Source,
+ wordSel);
+ }
+ } else {
+ Value byteSel = createI32Constant(rewriter, loc, op.getIndex());
+ if (typeIsExpectedBf8ForChipset(chipset, sourceElemType)) {
+ rewriter.replaceOpWithNewOp<ROCDL::CvtF32Bf8Op>(op, f32, i32Source,
+ byteSel);
+ } else if (typeIsExpectedFp8ForChipset(chipset, sourceElemType)) {
+ rewriter.replaceOpWithNewOp<ROCDL::CvtF32Fp8Op>(op, f32, i32Source,
+ byteSel);
+ }
}
return success();
}
diff --git a/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp b/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp
index f9b685d1e90f6..3596b3235a631 100644
--- a/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp
+++ b/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp
@@ -113,10 +113,9 @@ ExtFOnFloat8RewritePattern::matchAndRewrite(arith::ExtFOp op,
Type outElemType = getElementTypeOrSelf(op.getOut().getType());
VectorType extResType = VectorType::get(2, rewriter.getF32Type());
if (!inVecType) {
- Value asFloats =
- rewriter.create<amdgpu::ExtPackedFp8Op>(loc, extResType, in, 0);
- Value resFloat = rewriter.create<vector::ExtractOp>(loc, asFloats, 0);
- Value result = castF32To(outElemType, resFloat, loc, rewriter);
+ Value asFloat = rewriter.create<amdgpu::ExtPackedFp8Op>(
+ loc, rewriter.getF32Type(), in, 0);
+ Value result = castF32To(outElemType, asFloat, loc, rewriter);
rewriter.replaceOp(op, result);
return success();
}
@@ -154,15 +153,17 @@ ExtFOnFloat8RewritePattern::matchAndRewrite(arith::ExtFOp op,
Value inSlice = rewriter.create<vector::ExtractStridedSliceOp>(
loc, in, i, elemsThisOp, 1);
for (int64_t j = 0; j < elemsThisOp; j += 2) {
- Value asFloats = rewriter.create<amdgpu::ExtPackedFp8Op>(loc, extResType,
- inSlice, j / 2);
- Type desType = VectorType::get(2, outElemType);
- Value asType = castF32To(desType, asFloats, loc, rewriter);
- if (i + j + 1 < numElements)
+ if (i + j + 1 < numElements) { // Convert two 8-bit elements
+ Value asFloats = rewriter.create<amdgpu::ExtPackedFp8Op>(
+ loc, extResType, inSlice, j / 2);
+ Type desType = VectorType::get(2, outElemType);
+ Value asType = castF32To(desType, asFloats, loc, rewriter);
result = rewriter.create<vector::InsertStridedSliceOp>(
loc, asType, result, i + j, 1);
- else {
- asType = rewriter.create<vector::ExtractOp>(loc, asType, 0);
+ } else { // Convert the single trailing 8-bit element (byte j of inSlice)
+ Value asFloat = rewriter.create<amdgpu::ExtPackedFp8Op>(
+ loc, rewriter.getF32Type(), inSlice, j);
+ Value asType = castF32To(outElemType, asFloat, loc, rewriter);
result = rewriter.create<vector::InsertOp>(loc, asType, result, i + j);
}
}
diff --git a/mlir/test/Conversion/AMDGPUToROCDL/8-bit-floats-ocp.mlir b/mlir/test/Conversion/AMDGPUToROCDL/8-bit-floats-ocp.mlir
index 0fb03ff13b558..eb483b0880294 100644
--- a/mlir/test/Conversion/AMDGPUToROCDL/8-bit-floats-ocp.mlir
+++ b/mlir/test/Conversion/AMDGPUToROCDL/8-bit-floats-ocp.mlir
@@ -7,12 +7,12 @@
// CHECK-DAG: [[C0_1:%.+]] = llvm.mlir.constant(0 : i32) : i32
// CHECK: [[VEC:%.+]] = llvm.insertelement [[V]], [[UNDEF]]{{\[}}[[C0_1]] : i32] : vector<4xi8>
// CHECK: [[CAST:%.+]] = llvm.bitcast [[VEC]] : vector<4xi8> to i32
-// CHECK: [[C0_2:%.+]] = llvm.mlir.constant(false) : i1
-// CHECK: [[EXT:%.+]] = rocdl.cvt.pk.f32.bf8 [[CAST]]{{\[}}[[C0_2]]] : vector<2xf32>
-// CHECK: return [[EXT]]
-func.func @ext_scalar(%v: f8E5M2) -> vector<2xf32> {
- %ret = amdgpu.ext_packed_fp8 %v[0] : f8E5M2 to vector<2xf32>
- func.return %ret : vector<2xf32>
+// CHECK: [[C0_2:%.+]] = llvm.mlir.constant(0 : i32) : i32
+// CHECK: [[EXT:%.+]] = rocdl.cvt.f32.bf8 [[CAST]]{{\[}}[[C0_2]]] : f32
+// CHECK: return [[EXT]] : f32
+func.func @ext_scalar(%v: f8E5M2) -> f32 {
+ %ret = amdgpu.ext_packed_fp8 %v[0] : f8E5M2 to f32
+ func.return %ret : f32
}
// CHECK-LABEL: func @ext_short_vec
@@ -25,22 +25,50 @@ func.func @ext_scalar(%v: f8E5M2) -> vector<2xf32> {
// CHECK: [[ELEM_1:%.+]] = llvm.extractelement [[V]]{{\[}}[[C1_1]] : i32] : vector<2xi8>
// CHECK: [[VEC_1:%.+]] = llvm.insertelement [[ELEM_1]], [[VEC_0]]{{\[}}[[C1_1]] : i32] : vector<4xi8>
// CHECK: [[CAST:%.+]] = llvm.bitcast [[VEC_1]] : vector<4xi8> to i32
+// CHECK: [[C1_2:%.+]] = llvm.mlir.constant(1 : i32) : i32
+// CHECK: [[EXT:%.+]] = rocdl.cvt.f32.fp8 [[CAST]]{{\[}}[[C1_2]]] : f32
+// CHECK: return [[EXT]] : f32
+func.func @ext_short_vec(%v: vector<2xf8E4M3FN>) -> f32 {
+ %ret = amdgpu.ext_packed_fp8 %v[1] : vector<2xf8E4M3FN> to f32
+ func.return %ret : f32
+}
+
+// CHECK-LABEL: func @ext_full_vec(
+// CHECK: [[V:%.+]] = builtin.unrealized_conversion_cast %{{.+}} : vector<4xf8E4M3FN> to vector<4xi8>
+// CHECK: [[CAST:%.+]] = llvm.bitcast [[V]] : vector<4xi8> to i32
+// CHECK: [[C2:%.+]] = llvm.mlir.constant(2 : i32) : i32
+// CHECK: [[EXT:%.+]] = rocdl.cvt.f32.fp8 [[CAST]]{{\[}}[[C2]]] : f32
+// CHECK: return [[EXT]] : f32
+func.func @ext_full_vec(%v: vector<4xf8E4M3FN>) -> f32 {
+ %ret = amdgpu.ext_packed_fp8 %v[2] : vector<4xf8E4M3FN> to f32
+ func.return %ret : f32
+}
+
+// CHECK-LABEL: func @ext_packed_2xfp8
+// CHECK: [[V:%.+]] = builtin.unrealized_conversion_cast %{{.+}} : vector<2xf8E4M3FN> to vector<2xi8>
+// CHECK-DAG: [[UNDEF:%.+]] = llvm.mlir.undef : vector<4xi8>
+// CHECK-DAG: [[C0:%.+]] = llvm.mlir.constant(0 : i32) : i32
+// CHECK: [[ELEM_0:%.+]] = llvm.extractelement [[V]]{{\[}}[[C0]] : i32] : vector<2xi8>
+// CHECK: [[VEC_0:%.+]] = llvm.insertelement [[ELEM_0]], [[UNDEF]]{{\[}}[[C0]] : i32] : vector<4xi8>
+// CHECK: [[C1_1:%.+]] = llvm.mlir.constant(1 : i32) : i32
+// CHECK: [[ELEM_1:%.+]] = llvm.extractelement [[V]]{{\[}}[[C1_1]] : i32] : vector<2xi8>
+// CHECK: [[VEC_1:%.+]] = llvm.insertelement [[ELEM_1]], [[VEC_0]]{{\[}}[[C1_1]] : i32] : vector<4xi8>
+// CHECK: [[CAST:%.+]] = llvm.bitcast [[VEC_1]] : vector<4xi8> to i32
// CHECK: [[C1_2:%.+]] = llvm.mlir.constant(false) : i1
// CHECK: [[EXT:%.+]] = rocdl.cvt.pk.f32.fp8 [[CAST]]{{\[}}[[C1_2]]] : vector<2xf32>
// CHECK: return [[EXT]]
-func.func @ext_short_vec(%v: vector<2xf8E4M3FN>) -> vector<2xf32> {
+func.func @ext_packed_2xfp8(%v: vector<2xf8E4M3FN>) -> vector<2xf32> {
%ret = amdgpu.ext_packed_fp8 %v[0] : vector<2xf8E4M3FN> to vector<2xf32>
func.return %ret : vector<2xf32>
}
-// CHECK-LABEL: func @ext_full_vec(
+// CHECK-LABEL: func @ext_packed_4xfp8
// CHECK: [[V:%.+]] = builtin.unrealized_conversion_cast %{{.+}} : vector<4xf8E4M3FN> to vector<4xi8>
// CHECK: [[CAST:%.+]] = llvm.bitcast [[V]] : vector<4xi8> to i32
// CHECK: [[C3:%.+]] = llvm.mlir.constant(true) : i1
// CHECK: [[EXT:%.+]] = rocdl.cvt.pk.f32.fp8 [[CAST]]{{\[}}[[C3]]] : vector<2xf32>
// CHECK: return [[EXT]] : vector<2xf32>
-
-func.func @ext_full_vec(%v: vector<4xf8E4M3FN>) -> vector<2xf32> {
+func.func @ext_packed_4xfp8(%v: vector<4xf8E4M3FN>) -> vector<2xf32> {
%ret = amdgpu.ext_packed_fp8 %v[1] : vector<4xf8E4M3FN> to vector<2xf32>
func.return %ret : vector<2xf32>
}
diff --git a/mlir/test/Conversion/AMDGPUToROCDL/8-bit-floats.mlir b/mlir/test/Conversion/AMDGPUToROCDL/8-bit-floats.mlir
index 0a4a960d59ce8..4029d14650d7f 100644
--- a/mlir/test/Conversion/AMDGPUToROCDL/8-bit-floats.mlir
+++ b/mlir/test/Conversion/AMDGPUToROCDL/8-bit-floats.mlir
@@ -6,12 +6,12 @@
// CHECK-DAG: [[C0_1:%.+]] = llvm.mlir.constant(0 : i32) : i32
// CHECK: [[VEC:%.+]] = llvm.insertelement [[V]], [[UNDEF]]{{\[}}[[C0_1]] : i32] : vector<4xi8>
// CHECK: [[CAST:%.+]] = llvm.bitcast [[VEC]] : vector<4xi8> to i32
-// CHECK: [[C0_2:%.+]] = llvm.mlir.constant(false) : i1
-// CHECK: [[EXT:%.+]] = rocdl.cvt.pk.f32.bf8 [[CAST]]{{\[}}[[C0_2]]] : vector<2xf32>
-// CHECK: return [[EXT]]
-func.func @ext_scalar(%v: f8E5M2FNUZ) -> vector<2xf32> {
- %ret = amdgpu.ext_packed_fp8 %v[0] : f8E5M2FNUZ to vector<2xf32>
- func.return %ret : vector<2xf32>
+// CHECK: [[C0_2:%.+]] = llvm.mlir.constant(0 : i32) : i32
+// CHECK: [[EXT:%.+]] = rocdl.cvt.f32.bf8 [[CAST]]{{\[}}[[C0_2]]] : f32
+// CHECK: return [[EXT]] : f32
+func.func @ext_scalar(%v: f8E5M2FNUZ) -> f32 {
+ %ret = amdgpu.ext_packed_fp8 %v[0] : f8E5M2FNUZ to f32
+ func.return %ret : f32
}
// CHECK-LABEL: func @ext_short_vec
@@ -24,22 +24,50 @@ func.func @ext_scalar(%v: f8E5M2FNUZ) -> vector<2xf32> {
// CHECK: [[ELEM_1:%.+]] = llvm.extractelement [[V]]{{\[}}[[C1_1]] : i32] : vector<2xi8>
// CHECK: [[VEC_1:%.+]] = llvm.insertelement [[ELEM_1]], [[VEC_0]]{{\[}}[[C1_1]] : i32] : vector<4xi8>
// CHECK: [[CAST:%.+]] = llvm.bitcast [[VEC_1]] : vector<4xi8> to i32
+// CHECK: [[C1_2:%.+]] = llvm.mlir.constant(1 : i32) : i32
+// CHECK: [[EXT:%.+]] = rocdl.cvt.f32.fp8 [[CAST]]{{\[}}[[C1_2]]] : f32
+// CHECK: return [[EXT]] : f32
+func.func @ext_short_vec(%v: vector<2xf8E4M3FNUZ>) -> f32 {
+ %ret = amdgpu.ext_packed_fp8 %v[1] : vector<2xf8E4M3FNUZ> to f32
+ func.return %ret : f32
+}
+
+// CHECK-LABEL: func @ext_full_vec
+// CHECK: [[V:%.+]] = builtin.unrealized_conversion_cast %{{.+}} : vector<4xf8E4M3FNUZ> to vector<4xi8>
+// CHECK: [[CAST:%.+]] = llvm.bitcast [[V]] : vector<4xi8> to i32
+// CHECK: [[C2:%.+]] = llvm.mlir.constant(2 : i32) : i32
+// CHECK: [[EXT:%.+]] = rocdl.cvt.f32.fp8 [[CAST]]{{\[}}[[C2]]] : f32
+// CHECK: return [[EXT]] : f32
+func.func @ext_full_vec(%v: vector<4xf8E4M3FNUZ>) -> f32 {
+ %ret = amdgpu.ext_packed_fp8 %v[2] : vector<4xf8E4M3FNUZ> to f32
+ func.return %ret : f32
+}
+
+// CHECK-LABEL: func @ext_packed_2xfp8
+// CHECK: [[V:%.+]] = builtin.unrealized_conversion_cast %{{.+}} : vector<2xf8E4M3FNUZ> to vector<2xi8>
+// CHECK-DAG: [[UNDEF:%.+]] = llvm.mlir.undef : vector<4xi8>
+// CHECK-DAG: [[C0:%.+]] = llvm.mlir.constant(0 : i32) : i32
+// CHECK: [[ELEM_0:%.+]] = llvm.extractelement [[V]]{{\[}}[[C0]] : i32] : vector<2xi8>
+// CHECK: [[VEC_0:%.+]] = llvm.insertelement [[ELEM_0]], [[UNDEF]]{{\[}}[[C0]] : i32] : vector<4xi8>
+// CHECK: [[C1_1:%.+]] = llvm.mlir.constant(1 : i32) : i32
+// CHECK: [[ELEM_1:%.+]] = llvm.extractelement [[V]]{{\[}}[[C1_1]] : i32] : vector<2xi8>
+// CHECK: [[VEC_1:%.+]] = llvm.insertelement [[ELEM_1]], [[VEC_0]]{{\[}}[[C1_1]] : i32] : vector<4xi8>
+// CHECK: [[CAST:%.+]] = llvm.bitcast [[VEC_1]] : vector<4xi8> to i32
// CHECK: [[C1_2:%.+]] = llvm.mlir.constant(false) : i1
// CHECK: [[EXT:%.+]] = rocdl.cvt.pk.f32.fp8 [[CAST]]{{\[}}[[C1_2]]] : vector<2xf32>
// CHECK: return [[EXT]]
-func.func @ext_short_vec(%v: vector<2xf8E4M3FNUZ>) -> vector<2xf32> {
+func.func @ext_packed_2xfp8(%v: vector<2xf8E4M3FNUZ>) -> vector<2xf32> {
%ret = amdgpu.ext_packed_fp8 %v[0] : vector<2xf8E4M3FNUZ> to vector<2xf32>
func.return %ret : vector<2xf32>
}
-// CHECK-LABEL: func @ext_full_vec(
+// CHECK-LABEL: func @ext_packed_4xfp8(
// CHECK: [[V:%.+]] = builtin.unrealized_conversion_cast %{{.+}} : vector<4xf8E4M3FNUZ> to vector<4xi8>
// CHECK: [[CAST:%.+]] = llvm.bitcast [[V]] : vector<4xi8> to i32
// CHECK: [[C3:%.+]] = llvm.mlir.constant(true) : i1
// CHECK: [[EXT:%.+]] = rocdl.cvt.pk.f32.fp8 [[CAST]]{{\[}}[[C3]]] : vector<2xf32>
// CHECK: return [[EXT]] : vector<2xf32>
-
-func.func @ext_full_vec(%v: vector<4xf8E4M3FNUZ>) -> vector<2xf32> {
+func.func @ext_packed_4xfp8(%v: vector<4xf8E4M3FNUZ>) -> vector<2xf32> {
%ret = amdgpu.ext_packed_fp8 %v[1] : vector<4xf8E4M3FNUZ> to vector<2xf32>
func.return %ret : vector<2xf32>
}
diff --git a/mlir/test/Conversion/ArithToAMDGPU/8-bit-floats-ocp.mlir b/mlir/test/Conversion/ArithToAMDGPU/8-bit-floats-ocp.mlir
index b75b69c1b5d27..7fb5fbfe0c89e 100644
--- a/mlir/test/Conversion/ArithToAMDGPU/8-bit-floats-ocp.mlir
+++ b/mlir/test/Conversion/ArithToAMDGPU/8-bit-floats-ocp.mlir
@@ -3,9 +3,8 @@
// CHECK-LABEL: func.func @scalar_ext
// CHECK-SAME: ([[V:%.+]]: f8E5M2)
-// CHECK: [[FLOAT:%.+]] = amdgpu.ext_packed_fp8 [[V]][0] : f8E5M2 to vector<2xf32>
-// CHECK: [[EXT:%.+]] = vector.extract [[FLOAT]][0] : f32 from vector<2xf32>
-// CHECK: [[W:%.+]] = arith.truncf [[EXT]] : f32 to f16
+// CHECK: [[FLOAT:%.+]] = amdgpu.ext_packed_fp8 [[V]][0] : f8E5M2 to f32
+// CHECK: [[W:%.+]] = arith.truncf [[FLOAT]] : f32 to f16
// CHECK: return [[W]]
func.func @scalar_ext(%v: f8E5M2) -> f16 {
%w = arith.extf %v : f8E5M2 to f16
@@ -43,9 +42,8 @@ func.func @vector_ext_short(%v: vector<2xf8E5M2>) -> vector<2xf64> {
// CHECK: [[FLOAT4:%.+]] = amdgpu.ext_packed_fp8 [[IN2]][1] : vector<4xf8E4M3FN> to vector<2xf32>
// CHECK: [[W4:%.+]] = vector.insert_strided_slice [[FLOAT4]], [[W3]] {offsets = [6], strides = [1]} : vector<2xf32> into vector<9xf32>
// CHECK: [[IN3:%.+]] = vector.extract_strided_slice [[V]] {offsets = [8], sizes = [1], strides = [1]} : vector<9xf8E4M3FN> to vector<1xf8E4M3FN>
-// CHECK: [[FLOAT5:%.+]] = amdgpu.ext_packed_fp8 [[IN3]][0] : vector<1xf8E4M3FN> to vector<2xf32>
-// CHECK: [[FLOAT6:%.+]] = vector.extract [[FLOAT5]][0] : f32 from vector<2xf32>
-// CHECK: [[W5:%.+]] = vector.insert [[FLOAT6]], [[W4]] [8] : f32 into vector<9xf32>
+// CHECK: [[FLOAT5:%.+]] = amdgpu.ext_packed_fp8 [[IN3]][0] : vector<1xf8E4M3FN> to f32
+// CHECK: [[W5:%.+]] = vector.insert [[FLOAT5]], [[W4]] [8] : f32 into vector<9xf32>
// CHECK: return [[W5]]
func.func @vector_ext_long(%v: vector<9xf8E4M3FN>) -> vector<9xf32> {
%w = arith.extf %v : vector<9xf8E4M3FN> to vector<9xf32>
@@ -131,28 +129,29 @@ func.func @vector_trunc_long_2d(%v: vector<1x9xf32>) -> vector<1x9xf8E4M3FN> {
// -----
// CHECK-LABEL: func.func @vector_ext_long_2d
-// CHECK-SAME: ([[V:%.+]]: vector<1x9xf8E4M3FN>)
-// CHECK: [[CST:%.+]] = arith.constant dense<0.000000e+00> : vector<9xf32>
-// CHECK: [[CAST:%.+]] = vector.shape_cast [[V]] : vector<1x9xf8E4M3FN> to vector<9xf8E4M3FN>
+// CHECK-SAME: ([[V:%.+]]: vector<1x11xf8E4M3FN>)
+// CHECK: [[CST:%.+]] = arith.constant dense<0.000000e+00> : vector<11xf32>
+// CHECK: [[CAST:%.+]] = vector.shape_cast [[V]] : vector<1x11xf8E4M3FN> to vector<11xf8E4M3FN>
// CHECK: [[V0:%.+]] = vector.extract_strided_slice [[CAST]] {offsets = [0], sizes = [4], strides = [1]}
// CHECK: [[F0:%.+]] = amdgpu.ext_packed_fp8 [[V0]][0] : vector<4xf8E4M3FN> to vector<2xf32>
-// CHECK: [[W0:%.+]] = vector.insert_strided_slice [[F0]], [[CST]] {offsets = [0], strides = [1]} : vector<2xf32> into vector<9xf32>
+// CHECK: [[W0:%.+]] = vector.insert_strided_slice [[F0]], [[CST]] {offsets = [0], strides = [1]} : vector<2xf32> into vector<11xf32>
// CHECK: [[F1:%.+]] = amdgpu.ext_packed_fp8 [[V0]][1] : vector<4xf8E4M3FN> to vector<2xf32>
-// CHECK: [[W1:%.+]] = vector.insert_strided_slice [[F1]], [[W0]] {offsets = [2], strides = [1]} : vector<2xf32> into vector<9xf32>
+// CHECK: [[W1:%.+]] = vector.insert_strided_slice [[F1]], [[W0]] {offsets = [2], strides = [1]} : vector<2xf32> into vector<11xf32>
-// CHECK: [[V1:%.+]] = vector.extract_strided_slice [[CAST]] {offsets = [4], sizes = [4], strides = [1]} : vector<9xf8E4M3FN> to vector<4xf8E4M3FN>
+// CHECK: [[V1:%.+]] = vector.extract_strided_slice [[CAST]] {offsets = [4], sizes = [4], strides = [1]} : vector<11xf8E4M3FN> to vector<4xf8E4M3FN>
// CHECK: [[F2:%.+]] = amdgpu.ext_packed_fp8 [[V1]][0] : vector<4xf8E4M3FN> to vector<2xf32>
-// CHECK: [[W2:%.+]] = vector.insert_strided_slice [[F2]], [[W1]] {offsets = [4], strides = [1]} : vector<2xf32> into vector<9xf32>
+// CHECK: [[W2:%.+]] = vector.insert_strided_slice [[F2]], [[W1]] {offsets = [4], strides = [1]} : vector<2xf32> into vector<11xf32>
// CHECK: [[F3:%.+]] = amdgpu.ext_packed_fp8 [[V1]][1] : vector<4xf8E4M3FN> to vector<2xf32>
-// CHECK: [[W3:%.+]] = vector.insert_strided_slice [[F3]], [[W2]] {offsets = [6], strides = [1]} : vector<2xf32> into vector<9xf32>
-
-// CHECK: [[V2:%.+]] = vector.extract_strided_slice [[CAST]] {offsets = [8], sizes = [1], strides = [1]} : vector<9xf8E4M3FN> to vector<1xf8E4M3FN>
-// CHECK: [[F4:%.+]] = amdgpu.ext_packed_fp8 [[V2]][0] : vector<1xf8E4M3FN> to vector<2xf32>
-// CHECK: [[E0:%.+]] = vector.extract [[F4]][0] : f32 from vector<2xf32>
-// CHECK: [[W4:%.+]] = vector.insert [[E0]], [[W3]] [8] : f32 into vector<9xf32>
-// CHECK: [[CAST:%.+]] = vector.shape_cast [[W4]] : vector<9xf32> to vector<1x9xf32>
+// CHECK: [[W3:%.+]] = vector.insert_strided_slice [[F3]], [[W2]] {offsets = [6], strides = [1]} : vector<2xf32> into vector<11xf32>
+
+// CHECK: [[V2:%.+]] = vector.extract_strided_slice [[CAST]] {offsets = [8], sizes = [3], strides = [1]} : vector<11xf8E4M3FN> to vector<3xf8E4M3FN>
+// CHECK: [[F4:%.+]] = amdgpu.ext_packed_fp8 [[V2]][0] : vector<3xf8E4M3FN> to vector<2xf32>
+// CHECK: [[W4:%.+]] = vector.insert_strided_slice [[F4]], [[W3]] {offsets = [8], strides = [1]} : vector<2xf32> into vector<11xf32>
+// CHECK: [[F5:%.+]] = amdgpu.ext_packed_fp8 [[V2]][2] : vector<3xf8E4M3FN> to f32
+// CHECK: [[W5:%.+]] = vector.insert [[F5]], [[W4]] [10] : f32 into vector<11xf32>
+// CHECK: [[CAST:%.+]] = vector.shape_cast [[W5]] : vector<11xf32> to vector<1x11xf32>
// CHECK: return [[CAST]]
-func.func @vector_ext_long_2d(%v: vector<1x9xf8E4M3FN>) -> vector<1x9xf32> {
- %w = arith.extf %v : vector<1x9xf8E4M3FN> to vector<1x9xf32>
- return %w : vector<1x9xf32>
+func.func @vector_ext_long_2d(%v: vector<1x11xf8E4M3FN>) -> vector<1x11xf32> {
+ %w = arith.extf %v : vector<1x11xf8E4M3FN> to vector<1x11xf32>
+ return %w : vector<1x11xf32>
}
diff --git a/mlir/test/Conversion/ArithToAMDGPU/8-bit-floats.mlir b/mlir/test/Conversion/ArithToAMDGPU/8-bit-floats.mlir
index 2ed3f47e8ab73..59ed6bd95ae8b 100644
--- a/mlir/test/Conversion/ArithToAMDGPU/8-bit-floats.mlir
+++ b/mlir/test/Conversion/ArithToAMDGPU/8-bit-floats.mlir
@@ -2,9 +2,8 @@
// CHECK-LABEL: func.func @scalar_ext
// CHECK-SAME: ([[V:%.+]]: f8E5M2FNUZ)
-// CHECK: [[FLOAT:%.+]] = amdgpu.ext_packed_fp8 [[V]][0] : f8E5M2FNUZ to vector<2xf32>
-// CHECK: [[EXT:%.+]] = vector.extract [[FLOAT]][0] : f32 from vector<2xf32>
-// CHECK: [[W:%.+]] = arith.truncf [[EXT]] : f32 to f16
+// CHECK: [[FLOAT:%.+]] = amdgpu.ext_packed_fp8 [[V]][0] : f8E5M2FNUZ to f32
+// CHECK: [[W:%.+]] = arith.truncf [[FLOAT]] : f32 to f16
// CHECK: return [[W]]
func.func @scalar_ext(%v: f8E5M2FNUZ) -> f16 {
%w = arith.extf %v : f8E5M2FNUZ to f16
@@ -17,9 +16,8 @@ func.func @scalar_ext(%v: f8E5M2FNUZ) -> f16 {
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: vector<f8E5M2FNUZ>) -> vector<f32>
// CHECK: %[[CONST:.+]] = arith.constant dense<0.000000e+00> : vector<f32>
// CHECK: %[[EXTRACT:.+]] = vector.extract %[[ARG0]][] : f8E5M2FNUZ from vector<f8E5M2FNUZ>
-// CHECK: %[[CONVERT:.+]] = amdgpu.ext_packed_fp8 %[[EXTRACT]][0] : f8E5M2FNUZ to vector<2xf32>
-// CHECK: %[[EXTRACT2:.+]] = vector.extract %[[CONVERT]][0] : f32 from vector<2xf32>
-// CHECK: %[[RESULT:.+]] = vector.insert %[[EXTRACT2]], %[[CONST]] [] : f32 into vector<f32>
+// CHECK: %[[CONVERT:.+]] = amdgpu.ext_packed_fp8 %[[EXTRACT]][0] : f8E5M2FNUZ to f32
+// CHECK: %[[RESULT:.+]] = vector.insert %[[CONVERT]], %[[CONST]] [] : f32 into vector<f32>
// CHECK: return %[[RESULT]] : vector<f32>
func.func @vector_zero_d(%v: vector<f8E5M2FNUZ>) -> vector<f32> {
%w = arith.extf %v : vector<f8E5M2FNUZ> to vector<f32>
@@ -54,9 +52,8 @@ func.func @vector_ext_short(%v: vector<2xf8E5M2FNUZ>) -> vector<2xf64> {
// CHECK: [[FLOAT4:%.+]] = amdgpu.ext_packed_fp8 [[IN2]][1] : vector<4xf8E4M3FNUZ> to vector<2xf32>
// CHECK: [[W4:%.+]] = vector.insert_strided_slice [[FLOAT4]], [[W3]] {offsets = [6], strides = [1]} : vector<2xf32> into vector<9xf32>
// CHECK: [[IN3:%.+]] = vector.extract_strided_slice [[V]] {offsets = [8], sizes = [1], strides = [1]} : vector<9xf8E4M3FNUZ> to vector<1xf8E4M3FNUZ>
-// CHECK: [[FLOAT5:%.+]] = amdgpu.ext_packed_fp8 [[IN3]][0] : vector<1xf8E4M3FNUZ> to vector<2xf32>
-// CHECK: [[FLOAT6:%.+]] = vector.extract [[FLOAT5]][0] : f32 from vector<2xf32>
-// CHECK: [[W5:%.+]] = vector.insert [[FLOAT6]], [[W4]] [8] : f32 into vector<9xf32>
+// CHECK: [[FLOAT5:%.+]] = amdgpu.ext_packed_fp8 [[IN3]][0] : vector<1xf8E4M3FNUZ> to f32
+// CHECK: [[W5:%.+]] = vector.insert [[FLOAT5]], [[W4]] [8] : f32 into vector<9xf32>
// CHECK: return [[W5]]
func.func @vector_ext_long(%v: vector<9xf8E4M3FNUZ>) -> vector<9xf32> {
%w = arith.extf %v : vector<9xf8E4M3FNUZ> to vector<9xf32>
@@ -142,28 +139,29 @@ func.func @vector_trunc_long_2d(%v: vector<1x9xf32>) -> vector<1x9xf8E4M3FNUZ> {
// -----
// CHECK-LABEL: func.func @vector_ext_long_2d
-// CHECK-SAME: ([[V:%.+]]: vector<1x9xf8E4M3FNUZ>)
-// CHECK: [[CST:%.+]] = arith.constant dense<0.000000e+00> : vector<9xf32>
-// CHECK: [[CAST:%.+]] = vector.shape_cast [[V]] : vector<1x9xf8E4M3FNUZ> to vector<9xf8E4M3FNUZ>
+// CHECK-SAME: ([[V:%.+]]: vector<1x11xf8E4M3FNUZ>)
+// CHECK: [[CST:%.+]] = arith.constant dense<0.000000e+00> : vector<11xf32>
+// CHECK: [[CAST:%.+]] = vector.shape_cast [[V]] : vector<1x11xf8E4M3FNUZ> to vector<11xf8E4M3FNUZ>
// CHECK: [[V0:%.+]] = vector.extract_strided_slice [[CAST]] {offsets = [0], sizes = [4], strides = [1]}
// CHECK: [[F0:%.+]] = amdgpu.ext_packed_fp8 [[V0]][0] : vector<4xf8E4M3FNUZ> to vector<2xf32>
-// CHECK: [[W0:%.+]] = vector.insert_strided_slice [[F0]], [[CST]] {offsets = [0], strides = [1]} : vector<2xf32> into vector<9xf32>
+// CHECK: [[W0:%.+]] = vector.insert_strided_slice [[F0]], [[CST]] {offsets = [0], strides = [1]} : vector<2xf32> into vector<11xf32>
// CHECK: [[F1:%.+]] = amdgpu.ext_packed_fp8 [[V0]][1] : vector<4xf8E4M3FNUZ> to vector<2xf32>
-// CHECK: [[W1:%.+]] = vector.insert_strided_slice [[F1]], [[W0]] {offsets = [2], strides = [1]} : vector<2xf32> into vector<9xf32>
+// CHECK: [[W1:%.+]] = vector.insert_strided_slice [[F1]], [[W0]] {offsets = [2], strides = [1]} : vector<2xf32> into vector<11xf32>
-// CHECK: [[V1:%.+]] = vector.extract_strided_slice [[CAST]] {offsets = [4], sizes = [4], strides = [1]} : vector<9xf8E4M3FNUZ> to vector<4xf8E4M3FNUZ>
+// CHECK: [[V1:%.+]] = vector.extract_strided_slice [[CAST]] {offsets = [4], sizes = [4], strides = [1]} : vector<11xf8E4M3FNUZ> to vector<4xf8E4M3FNUZ>
// CHECK: [[F2:%.+]] = amdgpu.ext_packed_fp8 [[V1]][0] : vector<4xf8E4M3FNUZ> to vector<2xf32>
-// CHECK: [[W2:%.+]] = vector.insert_strided_slice [[F2]], [[W1]] {offsets = [4], strides = [1]} : vector<2xf32> into vector<9xf32>
+// CHECK: [[W2:%.+]] = vector.insert_strided_slice [[F2]], [[W1]] {offsets = [4], strides = [1]} : vector<2xf32> into vector<11xf32>
// CHECK: [[F3:%.+]] = amdgpu.ext_packed_fp8 [[V1]][1] : vector<4xf8E4M3FNUZ> to vector<2xf32>
-// CHECK: [[W3:%.+]] = vector.insert_strided_slice [[F3]], [[W2]] {offsets = [6], strides = [1]} : vector<2xf32> into vector<9xf32>
-
-// CHECK: [[V2:%.+]] = vector.extract_strided_slice [[CAST]] {offsets = [8], sizes = [1], strides = [1]} : vector<9xf8E4M3FNUZ> to vector<1xf8E4M3FNUZ>
-// CHECK: [[F4:%.+]] = amdgpu.ext_packed_fp8 [[V2]][0] : vector<1xf8E4M3FNUZ> to vector<2xf32>
-// CHECK: [[E0:%.+]] = vector.extract [[F4]][0] : f32 from vector<2xf32>
-// CHECK: [[W4:%.+]] = vector.insert [[E0]], [[W3]] [8] : f32 into vector<9xf32>
-// CHECK: [[CAST:%.+]] = vector.shape_cast [[W4]] : vector<9xf32> to vector<1x9xf32>
+// CHECK: [[W3:%.+]] = vector.insert_strided_slice [[F3]], [[W2]] {offsets = [6], strides = [1]} : vector<2xf32> into vector<11xf32>
+
+// CHECK: [[V2:%.+]] = vector.extract_strided_slice [[CAST]] {offsets = [8], sizes = [3], strides = [1]} : vector<11xf8E4M3FNUZ> to vector<3xf8E4M3FNUZ>
+// CHECK: [[F4:%.+]] = amdgpu.ext_packed_fp8 [[V2]][0] : vector<3xf8E4M3FNUZ> to vector<2xf32>
+// CHECK: [[W4:%.+]] = vector.insert_strided_slice [[F4]], [[W3]] {offsets = [8], strides = [1]} : vector<2xf32> into vector<11xf32>
+// CHECK: [[F5:%.+]] = amdgpu.ext_packed_fp8 [[V2]][2] : vector<3xf8E4M3FNUZ> to f32
+// CHECK: [[W5:%.+]] = vector.insert [[F5]], [[W4]] [10] : f32 into vector<11xf32>
+// CHECK: [[CAST:%.+]] = vector.shape_cast [[W5]] : vector<11xf32> to vector<1x11xf32>
// CHECK: return [[CAST]]
-func.func @vector_ext_long_2d(%v: vector<1x9xf8E4M3FNUZ>) -> vector<1x9xf32> {
- %w = arith.extf %v : vector<1x9xf8E4M3FNUZ> to vector<1x9xf32>
- return %w : vector<1x9xf32>
+func.func @vector_ext_long_2d(%v: vector<1x11xf8E4M3FNUZ>) -> vector<1x11xf32> {
+ %w = arith.extf %v : vector<1x11xf8E4M3FNUZ> to vector<1x11xf32>
+ return %w : vector<1x11xf32>
}
More information about the Mlir-commits mailing list