[Mlir-commits] [mlir] [mlir][ROCDL] Add fp4 and fp6 conversion intrinsics, fix fp8 immargs (PR #140801)
Krzysztof Drewniak
llvmlistbot at llvm.org
Tue May 20 14:11:10 PDT 2025
https://github.com/krzysz00 created https://github.com/llvm/llvm-project/pull/140801
This PR adds support for the scaled conversion intrinsics for fp4 and fp6 types so that they can be targeted by a future amdgpu dialect op or used directly.
Additionally, this patch refactors the copy-paste-heavy fp8 versions of these scaled conversion intrinsics with tablegen `foreach` loops, and fixes the fact that certain immargs weren't being stored as attributes.
Note that some of the MLIR-level tests for those scaled fp8 intrinsics had incorrect return types, which have been fixed.
(Note that while the operations have a known return type, the IR format still prints that type for clarity).
>From f5cc3663fb1045cc51a54f1c6ae3ff4a16ea4807 Mon Sep 17 00:00:00 2001
From: Krzysztof Drewniak <Krzysztof.Drewniak at amd.com>
Date: Tue, 20 May 2025 16:01:46 +0000
Subject: [PATCH] [mlir][ROCDL] Add fp4 and fp6 conversion intrinsics, fix fp8
immargs
This PR adds support for the scaled conversion intrinsics for fp4 and
fp6 types so that they can be targeted by a future amdgpu dialect op
or used directly.
Additionally, this patch refactors the copy-paste-heavy fp8 versions
of these scaled conversion intrinsics with tablegen `foreach` loops,
and fixes the fact that certain immargs weren't being stored as
attributes.
Note that some of the MLIR-level tests for those scaled fp8 intrinsics
had incorrect return types, which have been fixed.
(Note that while the operations have a known return type, the IR
format still prints that type for clarity).
---
mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td | 686 +++++++++---------
.../AMDGPUToROCDL/AMDGPUToROCDL.cpp | 24 +-
.../AMDGPUToROCDL/8-bit-floats-ocp.mlir | 30 +-
.../AMDGPUToROCDL/8-bit-floats.mlir | 30 +-
mlir/test/Dialect/LLVMIR/rocdl.mlir | 146 +++-
mlir/test/Target/LLVMIR/rocdl.mlir | 140 +++-
6 files changed, 606 insertions(+), 450 deletions(-)
diff --git a/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td b/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td
index 6fb9e3aba1f0a..1dadb7d9e8852 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td
@@ -709,20 +709,23 @@ def ROCDL_PermlaneX16Op : ROCDL_IntrOp<"permlanex16", [], [0],
}];
}
-def ROCDL_V2I16Type : FixedVectorOfLengthAndType<[2], [I16]>,
- BuildableType<"::mlir::VectorType::get("
- "{2},$_builder.getI16Type())">;
+class ROCDL_ConcreteVector<Type elem, int length> :
+ FixedVectorOfLengthAndType<[length], [elem]>,
+ BuildableType<
+ "::mlir::VectorType::get({" # length # "}, "
+ # elem.builderCall # ")">;
+// Concrete fixed-length vector types used as intrinsic operands and results.
+def ROCDL_V2I16Type : ROCDL_ConcreteVector<I16, 2>;
+def ROCDL_V2F16Type : ROCDL_ConcreteVector<F16, 2>;
+def ROCDL_V2BF16Type : ROCDL_ConcreteVector<BF16, 2>;
+def ROCDL_V2F32Type : ROCDL_ConcreteVector<F32, 2>;
+def ROCDL_V6I32Type : ROCDL_ConcreteVector<I32, 6>;
+def ROCDL_V8I32Type : ROCDL_ConcreteVector<I32, 8>;
+def ROCDL_V16F32Type : ROCDL_ConcreteVector<F32, 16>;
+def ROCDL_V32F16Type : ROCDL_ConcreteVector<F16, 32>;
+def ROCDL_V32BF16Type : ROCDL_ConcreteVector<BF16, 32>;
+def ROCDL_V32F32Type : ROCDL_ConcreteVector<F32, 32>;
-def ROCDL_V2F16Type : FixedVectorOfLengthAndType<[2], [F16]>,
- BuildableType<"::mlir::VectorType::get("
- "{2},$_builder.getF16Type())">;
-
-def ROCDL_V2BF16Type : FixedVectorOfLengthAndType<[2], [BF16]>,
- BuildableType<"::mlir::VectorType::get("
- "{2},$_builder.getBF16Type())">;
-
-// TODO: The word and byte selectors are immarg in LLVM
-// update to be attributes in MLIR
//===---------------------------------------------------------------------===//
// 16-bit float intrinsics
//===---------------------------------------------------------------------===//
@@ -738,279 +741,12 @@ def ROCDL_CvtPkRtz:
}];
}
-def ROCDL_CvtScaleF32PkFp8F16Op :
- ROCDL_IntrOp<"cvt.scalef32.pk.fp8.f16", [], [], [Pure], 1>,
- Arguments<(ins ROCDL_V2I16Type: $old, ROCDL_V2F16Type: $src, F32: $scale, I1:$wordSel)> {
- let summary = "Scale and convert f16 to packed fp8";
- let description = [{
- Scale `src` by the exponent in `scale`, then convert to packed fp8.
- Store the result in low/high word of `old` based on $wordSel, preserving the other word.
- }];
- let assemblyFormat = [{
- attr-dict $src `,` $scale `->` $old `[` $wordSel `]` `:` type($res)
- }];
-}
-
-def ROCDL_CvtScaleF32PkFp8Bf16Op :
- ROCDL_IntrOp<"cvt.scalef32.pk.fp8.bf16", [], [], [Pure], 1>,
- Arguments<(ins ROCDL_V2I16Type: $old, ROCDL_V2BF16Type: $src, F32: $scale, I1:$wordSel)> {
- let summary = "Scale and convert packed bf16 to packed fp8";
- let description = [{
- Scale `src` by the exponent in `scale`, then convert to packed fp8.
- Store the result in low/high word of `old` based on $wordSel, preserving the other word.
- }];
- let assemblyFormat = [{
- attr-dict $src `,` $scale `->` $old `[` $wordSel `]` `:` type($res)
- }];
-}
-
-
-def ROCDL_CvtScaleF32PkBf8F16Op :
- ROCDL_IntrOp<"cvt.scalef32.pk.bf8.f16", [], [], [Pure], 1>,
- Arguments<(ins ROCDL_V2I16Type: $old, ROCDL_V2F16Type: $src, F32: $scale, I1:$wordSel)> {
- let summary = "Scale and convert f16 to packed bf8";
- let description = [{
- Scale `src` by the exponent in `scale`, then convert to packed bf8.
- Store the result in low/high word of `old` based on $wordSel, preserving the other word.
- }];
- let assemblyFormat = [{
- attr-dict $src `,` $scale `->` $old `[` $wordSel `]` `:` type($res)
- }];
-}
-
-
-def ROCDL_CvtScaleF32PkBf8Bf16Op :
- ROCDL_IntrOp<"cvt.scalef32.pk.bf8.bf16", [], [], [Pure], 1>,
- Arguments<(ins ROCDL_V2I16Type: $old, ROCDL_V2BF16Type: $src, F32: $scale, I1:$wordSel)> {
- let summary = "Scale and convert bf16 to packed bf8";
- let description = [{
- Scale `src` by the exponent in `scale`, then convert to packed bf8.
- Store the result in low/high word of `old` based on $wordSel, preserving the other word.
- }];
- let assemblyFormat = [{
- attr-dict $src `,` $scale `->` $old `[` $wordSel `]` `:` type($res)
- }];
-}
-
-def ROCDL_CvtScaleF32SrFp8F16Op :
- ROCDL_IntrOp<"cvt.scalef32.sr.fp8.f16", [], [], [Pure], 1>,
- Arguments<(ins I32:$old, F16:$src, I32:$seed, F32: $scale, I32:$byteSel)> {
- let summary = "Scale and convert f16 to packed fp8 using stochastic rounding";
- let description = [{
- Scale `src` by the exponent in `scale`, then convert to packed p8 with stochastic rounding
- using seed data in `seed`. Store into the `byteSel`th byte of `old`, preserving the others.
-
- }];
- let assemblyFormat = [{
- attr-dict $src `,` $seed `,` $scale `->` $old `[` $byteSel `]` `:` type($res)
- }];
-}
-
-def ROCDL_CvtScaleF32SrBf8F16Op :
- ROCDL_IntrOp<"cvt.scalef32.sr.bf8.f16", [], [], [Pure], 1>,
- Arguments<(ins I32:$old, F16:$src, I32:$seed, F32: $scale, I32:$byteSel)> {
- let summary = "Scale and convert f16 to packed bf8 using stochastic rounding";
- let description = [{
- Scale `src` by the exponent in `scale`, then convert to packed bf8 with stochastic rounding
- using seed data in `seed`. Store into the `byteSel`th byte of `old`, preserving the others.
-
- }];
- let assemblyFormat = [{
- attr-dict $src `,` $seed `,` $scale `->` $old `[` $byteSel `]` `:` type($res)
- }];
-}
-
-def ROCDL_CvtScaleF32SrFp8Bf16Op :
- ROCDL_IntrOp<"cvt.scalef32.sr.fp8.bf16", [], [], [Pure], 1>,
- Arguments<(ins I32:$old, BF16:$src, I32:$seed, F32: $scale, I32:$byteSel)> {
- let summary = "Scale and convert packed bf16 to packed fp8 using stochastic rounding";
- let description = [{
- Scale `src` by the exponent in `scale`, then convert to packed fp8 with stochastic rounding
- using seed data in `seed`. Store into the `byteSel`th byte of `old`, preserving the others.
-
- }];
- let assemblyFormat = [{
- attr-dict $src `,` $seed `,` $scale `->` $old `[` $byteSel `]` `:` type($res)
- }];
-}
-
-def ROCDL_CvtScaleF32SrBf8Bf16Op :
- ROCDL_IntrOp<"cvt.scalef32.sr.bf8.bf16", [], [], [Pure], 1>,
- Arguments<(ins I32:$old, BF16:$src, I32:$seed, F32: $scale, I32:$byteSel)> {
- let summary = "Scale and convert bf16 to packed fp8 using stochastic rounding";
- let description = [{
- Scale `src` by the exponent in `scale`, then convert to packed p8 with stochastic rounding
- using seed data in `seed`. Store into the `byteSel`th byte of `old`, preserving the others.
-
- }];
- let assemblyFormat = [{
- attr-dict $src `,` $seed `,` $scale `->` $old `[` $byteSel `]` `:` type($res)
- }];
-}
-
-def ROCDL_CvtScaleF32PkF16Fp8Op :
- ROCDL_IntrOp<"cvt.scalef32.pk.f16.fp8", [], [], [Pure], 1>,
- Arguments<(ins I32:$src, F32: $scale, I1:$wordSel)> {
- let summary = "Convert fp8 to packed f16 and scale";
- let description = [{ Convert `src` based on $wordSel to packed f16, then scale
- the packed values by the exponent in `scale`.
- }];
- let assemblyFormat = [{
- attr-dict $src `[` $wordSel `]` `,` $scale `:` type($res)
- }];
-}
-
-def ROCDL_CvtScaleF32PkF16Bf8Op :
- ROCDL_IntrOp<"cvt.scalef32.pk.f16.bf8", [], [], [Pure], 1>,
- Arguments<(ins I32:$src, F32: $scale, I1:$wordSel)> {
- let summary = "convert bf8 to packed f16 and scale";
- let description = [{ Convert `src` based on $wordSel to packed f16, then scale
- the packed values by exponent in `scale`.
- }];
- let assemblyFormat = [{
- attr-dict $src `[` $wordSel `]` `,` $scale `:` type($res)
- }];
-}
-
-def ROCDL_CvtScaleF32PkBf16Fp8Op :
- ROCDL_IntrOp<"cvt.scalef32.pk.bf16.fp8", [], [], [Pure], 1>,
- Arguments<(ins I32:$src, F32: $scale, I1:$wordSel)> {
- let summary = "Convert fp8 to packed bf16 and scale";
- let description = [{ Convert `src` based on $wordSel to packed bf16, then scale
- the packed values by the exponent in `scale`.
- }];
- let assemblyFormat = [{
- attr-dict $src `[` $wordSel `]` `,` $scale `:` type($res)
- }];
-}
-
-def ROCDL_CvtScaleF32PkBf16Bf8Op :
- ROCDL_IntrOp<"cvt.scalef32.pk.bf16.bf8", [], [], [Pure], 1>,
- Arguments<(ins I32:$src, F32: $scale, I1:$wordSel)> {
- let summary = "Convert bf8 to packed bf16 and scale";
- let description = [{ Convert `src` based on $wordSel to packed bf16, then scale
- the packed values by the exponent in `scale`.
- }];
- let assemblyFormat = [{
- attr-dict $src `[` $wordSel `]` `,` $scale `:` type($res)
- }];
-}
-
-def ROCDL_CvtScaleF16Fp8Op :
- ROCDL_IntrOp<"cvt.scalef32.f16.fp8", [], [], [Pure], 1>,
- Arguments<(ins ROCDL_V2F16Type:$old, I32:$src, F32: $scale, I32:$byteSel, I1:$wordSel)> {
- let summary = "Scale and convert fp8 to f16";
- let description = [{ Convert `src` based on $wordSel to f16, then scale the value
- by the exponent in `scale`. Store the result into the `byteSel`th byte of `old`,
- preserving the others.
- }];
- let assemblyFormat = [{
- attr-dict $src `[` $wordSel `]` `,` $scale `->` $old `[` $byteSel `]` `:` type($res)
- }];
-}
-
-def ROCDL_CvtScaleF16Bf8Op :
- ROCDL_IntrOp<"cvt.scalef32.f16.bf8", [], [], [Pure], 1>,
- Arguments<(ins ROCDL_V2F16Type:$old, I32:$src, F32: $scale, I32:$byteSel, I1:$wordSel)> {
- let summary = "Scale and convert fp8 to f16";
- let description = [{ Convert `src` based on $wordSel to f16, then scale the value
- by the exponent in `scale`. Store the result into the `byteSel`th byte of `old`,
- preserving the others.
- }];
- let assemblyFormat = [{
- attr-dict $src `[` $wordSel `]` `,` $scale `->` $old `[` $byteSel `]` `:` type($res)
- }];
-}
-
-//===---------------------------------------------------------------------===//
-// 32-bit float intrinsics
-//===---------------------------------------------------------------------===//
-def ROCDL_CvtScaleF32PkF32Fp8Op :
- ROCDL_IntrOp<"cvt.scalef32.pk.f32.fp8", [], [], [Pure], 1>,
- Arguments<(ins I32:$src, F32: $scale, I1:$wordSel)> {
- let summary = "Scale and convert packed fp8 to packed f32";
- let description = [{
- Convert `src` based on $wordSel to packed fp32, then scale the packed values by
- the exponent in `scale`. Store the result in a vector.
- }];
- let assemblyFormat = [{
- attr-dict $src `[` $wordSel `]` `,` $scale `:` type($res)
- }];
-}
-def ROCDL_CvtScaleF32PkF32Bf8Op :
- ROCDL_IntrOp<"cvt.scalef32.pk.f32.bf8", [], [], [Pure], 1>,
- Arguments<(ins I32:$src, F32: $scale, I1:$wordSel)> {
- let summary = "Scale and convert packed bf8 to packed f32";
- let description = [{
- Convert `src` based on $wordSel to packed fp32, then scale the packed values by
- the exponent in `scale`. Store the result in a vector.
- }];
- let assemblyFormat = [{
- attr-dict $src `[` $wordSel `]` `,` $scale `:` type($res)
- }];
-}
-//===---------------------------------------------------------------------===//
-// 8-bit float scale intrinsics
-//===---------------------------------------------------------------------===//
-def ROCDL_CvtScaleF32PkFp8F32Op :
- ROCDL_IntrOp<"cvt.scalef32.pk.fp8.f32", [], [], [Pure], 1>,
- Arguments<(ins ROCDL_V2I16Type:$old, F32:$srcA, F32:$srcB, F32:$scale, I1:$wordSel)> {
- let summary = "Scale and convert two f32's to packed fp8";
- let description = [{
- Scale `srcA` and `srcB` by the exponent in `scale` then convert to packed fp8
- and store into the low/high word of `old`, preserving the other word.
- }];
- let assemblyFormat = [{
- attr-dict $srcA `,` $srcB `,` $scale `->` $old `[` $wordSel `]` `:` type($res)
- }];
-}
-
-def ROCDL_CvtScaleF32PkBf8F32Op :
- ROCDL_IntrOp<"cvt.scalef32.pk.bf8.f32", [], [], [Pure], 1>,
- Arguments<(ins ROCDL_V2I16Type:$old, F32:$srcA, F32:$srcB, F32: $scale, I1:$wordSel)> {
- let summary = "Scale and convert two f32's to packed bf8";
- let description = [{
- Scale `srcA` and `srcB` by the exponent in `scale` then convert to packed bf8
- and store into the low/high word of `old`, preserving the other word.
- }];
- let assemblyFormat = [{
- attr-dict $srcA `,` $srcB `,` $scale `->` $old `[` $wordSel `]` `:` type($res)
- }];
-}
-
-def ROCDL_CvtScaleF32SrFp8F32Op :
- ROCDL_IntrOp<"cvt.scalef32.sr.fp8.f32", [], [], [Pure], 1>,
- Arguments<(ins I32:$old, F32:$src, I32:$seed, F32: $scale, I32:$byteSel)> {
- let summary = "Scale and convert f32 to fp8 using stochastic rounding";
- let description = [{
- Scale `src` by the exponent in `scale` then convert to fp8 with stochastic rounding
- using seed data in `seed`. store into the `byteSel`th byte of `old`, preserving the others.
- }];
- let assemblyFormat = [{
- attr-dict $src `,` $seed `,` $scale `->` $old `[` $byteSel `]` `:` type($res)
- }];
-}
-
-
-def ROCDL_CvtScaleF32SrBf8F32Op :
- ROCDL_IntrOp<"cvt.scalef32.sr.bf8.f32", [], [], [Pure], 1>,
- Arguments<(ins I32:$old, F32:$src, I32:$seed, F32: $scale, I32:$byteSel)> {
- let summary = "Scale and convert f32 to bf8 using stochastic rounding";
- let description = [{
- Scale `src` by the exponent in `scale` then convert to bf8 with stochastic rounding
- using seed data in `seed`. store into the `byteSel`th byte of `old`, preserving the others.
- }];
- let assemblyFormat = [{
- attr-dict $src `,` $seed `,` $scale `->` $old `[` $byteSel `]` `:` type($res)
- }];
-}
-
//===---------------------------------------------------------------------===//
// 8-bit float intrinsics
//===---------------------------------------------------------------------===//
def ROCDL_CvtF32Bf8Op :
- ROCDL_IntrOp<"cvt.f32.bf8", [], [], [Pure], 1>,
- Arguments<(ins I32:$srcA, I32:$byteSel)> {
+ ROCDL_ConcreteNonMemIntrOp<"cvt.f32.bf8", [Pure], 1, [1], ["byteSel"]>,
+ Arguments<(ins I32:$srcA, I32Attr:$byteSel)> {
let summary = "Convert bf8 to f32";
let description = [{
Convert 8-bit bf8 value from the `byteSel`th bit of `srcA` to fp32.
@@ -1020,23 +756,9 @@ def ROCDL_CvtF32Bf8Op :
}];
}
-def ROCDL_CvtScaleF32Bf8Op :
- ROCDL_IntrOp<"cvt.scalef32.f32.bf8", [], [], [Pure], 1>,
- Arguments<(ins I32:$src, F32: $scale, I32:$byteSel)> {
- let summary = "Scale and convert bf8 to f32";
- let description = [{
- Scale `src` by the exponent in `scale` then convert 8-bit bf8 value
- from the `byteSel`th bit of `src` to fp32.
- }];
- let assemblyFormat = [{
- attr-dict $src `[` $byteSel `]` `,` $scale `:` type($res)
- }];
-}
-
-
def ROCDL_CvtF32Fp8Op :
- ROCDL_IntrOp<"cvt.f32.fp8", [], [], [Pure], 1>,
- Arguments<(ins I32:$srcA, I32:$byteSel)> {
+ ROCDL_ConcreteNonMemIntrOp<"cvt.f32.fp8", [Pure], 1, [1], ["byteSel"]>,
+ Arguments<(ins I32:$srcA, I32Attr:$byteSel)> {
let summary = "Convert fp8 to f32";
let description = [{
Convert 8-bit fp8 value from the `byteSel`th bit of `srcA` to fp32.
@@ -1046,24 +768,9 @@ def ROCDL_CvtF32Fp8Op :
}];
}
-
-def ROCDL_CvtScaleF32Fp8Op :
- ROCDL_IntrOp<"cvt.scalef32.f32.fp8", [], [], [Pure], 1>,
- Arguments<(ins I32:$src, F32: $scale, I32:$byteSel)> {
- let summary = "Scale and convert fp8 to f32";
- let description = [{
- Scale `src` by the exponent in `scale` then convert 8-bit fp8 value
- from the `byteSel`th bit of `src` to fp32.
-
- }];
- let assemblyFormat = [{
- attr-dict $src `[` $byteSel `]` `,` $scale `:` type($res)
- }];
-}
-
def ROCDL_CvtPkF32Fp8Op :
- ROCDL_IntrOp<"cvt.pk.f32.fp8", [], [], [Pure], 1>,
- Arguments<(ins I32:$src, I1:$wordSel)> {
+ ROCDL_ConcreteNonMemIntrOp<"cvt.pk.f32.fp8", [Pure], 1, [1], ["wordSel"]>,
+ Arguments<(ins I32:$src, I1Attr:$wordSel)> {
let summary = "Convert packed fp8 to packed f32";
let description = [{
Convert `src` based on $wordSel to packed fp32.
@@ -1074,8 +781,8 @@ def ROCDL_CvtPkF32Fp8Op :
}
def ROCDL_CvtPkF32Bf8Op :
- ROCDL_IntrOp<"cvt.pk.f32.bf8", [], [], [Pure], 1>,
- Arguments<(ins I32:$src, I1:$wordSel)> {
+ ROCDL_ConcreteNonMemIntrOp<"cvt.pk.f32.bf8", [Pure], 1, [1], ["wordSel"]>,
+ Arguments<(ins I32:$src, I1Attr:$wordSel)> {
let summary = "Convert packed bf8 to packed f32";
let description = [{
Convert `src` based on $wordSel to packed fp32,
@@ -1086,8 +793,8 @@ def ROCDL_CvtPkF32Bf8Op :
}
def ROCDL_CvtPkBf8F32Op :
- ROCDL_IntrOp<"cvt.pk.bf8.f32", [], [], [Pure], 1>,
- Arguments<(ins F32:$srcA, F32:$srcB, I32:$old, I1:$wordSel)> {
+ ROCDL_ConcreteNonMemIntrOp<"cvt.pk.bf8.f32", [Pure], 1, [3], ["wordSel"]>,
+ Arguments<(ins F32:$srcA, F32:$srcB, I32:$old, I1Attr:$wordSel)> {
let summary = "Convert two f32's to bf8";
let description = [{
Convert `srcA` and `srcB` to bf8 and store into the low/high word of
@@ -1099,8 +806,8 @@ def ROCDL_CvtPkBf8F32Op :
}
def ROCDL_CvtPkFp8F32Op :
- ROCDL_IntrOp<"cvt.pk.fp8.f32", [], [], [Pure], 1>,
- Arguments<(ins F32:$srcA, F32:$srcB, I32:$old, I1:$wordSel)> {
+ ROCDL_ConcreteNonMemIntrOp<"cvt.pk.fp8.f32", [Pure], 1, [3], ["wordSel"]>,
+ Arguments<(ins F32:$srcA, F32:$srcB, I32:$old, I1Attr:$wordSel)> {
let summary = "Convert two f32's to fp8";
let description = [{
Convert `srcA` and `srcB` to fp8 and store into the low/high word of
@@ -1112,8 +819,8 @@ def ROCDL_CvtPkFp8F32Op :
}
def ROCDL_CvtSrBf8F32Op :
- ROCDL_IntrOp<"cvt.sr.bf8.f32", [], [], [Pure], 1>,
- Arguments<(ins F32:$srcA, I32:$srcB, I32:$old, I32:$byteSel)> {
+ ROCDL_ConcreteNonMemIntrOp<"cvt.sr.bf8.f32", [Pure], 1, [3], ["byteSel"]>,
+ Arguments<(ins F32:$srcA, I32:$srcB, I32:$old, I32Attr:$byteSel)> {
let summary = "Convert f32 to bf8, stochiastic rounding";
let description = [{
Convert `srcA` to bf8, adding the rounding factor from `srcB`,
@@ -1125,8 +832,8 @@ def ROCDL_CvtSrBf8F32Op :
}
def ROCDL_CvtSrFp8F32Op :
- ROCDL_IntrOp<"cvt.sr.fp8.f32", [], [], [Pure], 1>,
- Arguments<(ins F32:$srcA, I32:$srcB, I32:$old, I32:$byteSel)> {
+ ROCDL_ConcreteNonMemIntrOp<"cvt.sr.fp8.f32", [Pure], 1, [3], ["byteSel"]>,
+ Arguments<(ins F32:$srcA, I32:$srcB, I32:$old, I32Attr:$byteSel)> {
let summary = "Convert f32 to fp8, stochiastic rounding";
let description = [{
Convert `srcA` to fp8, adding the rounding factor from `srcB`,
@@ -1137,6 +844,335 @@ def ROCDL_CvtSrFp8F32Op :
}];
}
+//===---------------------------------------------------------------------===//
+// Scaled float conversion intrinsics
+//
+// These are generated with tablegen `foreach` loops to avoid repetitive definitions
+//===---------------------------------------------------------------------===//
+
+// Bundles a type constraint with its name: lowercased for intrinsic names, as written for op names.
+class ScaleArgInfo<TypeConstraint argTyVal, string typeName> {
+ TypeConstraint type = argTyVal;
+ string name = !tolower(typeName);
+ string nameForOp = typeName;
+}
+
+//===---------------------------------------------------------------------===//
+// Scaled 32x6-bit float conversion intrinsics
+//===---------------------------------------------------------------------===//
+foreach smallT = [
+ // MLIR f6E2M3FN
+ ScaleArgInfo<ROCDL_V6I32Type, "Fp6">,
+ // MLIR f6E3M2FN
+ ScaleArgInfo<ROCDL_V6I32Type, "Bf6">
+] in {
+ foreach largeT = [
+ ScaleArgInfo<ROCDL_V32F16Type, "F16">,
+ ScaleArgInfo<ROCDL_V32BF16Type, "Bf16">,
+ ScaleArgInfo<ROCDL_V32F32Type, "F32">
+ ] in {
+ // Note: rounding down f32 values is a special case that takes
+ // two vector<16xf32> arguments (see the 2xpk16 op below).
+ if !ne(largeT.name, "f32") then {
+ def ROCDL_CvtScaleF32Pk32 # smallT.nameForOp # largeT.nameForOp # Op :
+ ROCDL_ConcreteNonMemIntrOp<"cvt.scalef32.pk32." # smallT.name # "." # largeT.name,
+ [Pure], 1>,
+ Arguments<(ins largeT.type:$src, F32:$scale)> {
+ let results = (outs smallT.type:$res);
+ let summary = "Scale and convert packed "
+ # largeT.name # " to packed " # smallT.name;
+ let description = [{
+ Convert 32 packed }] # largeT.name # [{ values to packed }]
+ # smallT.name # [{, dividing by the exponent part of `scale`
+ before doing so.
+ }];
+ let assemblyFormat = [{
+ attr-dict $src `,` $scale `:` type($res)
+ }];
+ }
+ } // if
+
+ def ROCDL_CvtScaleF32SrPk32 # smallT.nameForOp # largeT.nameForOp # Op :
+ ROCDL_ConcreteNonMemIntrOp<"cvt.scalef32.sr.pk32." # smallT.name # "." # largeT.name,
+ [Pure], 1>,
+ Arguments<(ins largeT.type:$src, I32:$seed, F32:$scale)> {
+ let results = (outs smallT.type:$res);
+ let summary = "Scale and convert packed "
+ # largeT.name # " to packed " # smallT.name
+ # " with stochastic rounding";
+ let description = [{
+ Convert 32 packed }] # largeT.name # [{ values to packed }]
+ # smallT.name # [{, dividing by the exponent part of `scale`
+ before doing so and applying random rounding derived from
+ `seed`.
+ }];
+ let assemblyFormat = [{
+ attr-dict $src `,` $seed `,` $scale `:` type($res)
+ }];
+ }
+
+ def ROCDL_CvtScaleF32Pk32 # largeT.nameForOp # smallT.nameForOp # Op :
+ ROCDL_ConcreteNonMemIntrOp<"cvt.scalef32.pk32." # largeT.name # "." # smallT.name,
+ [Pure], 1>,
+ Arguments<(ins smallT.type:$src, F32:$scale)> {
+ let results = (outs largeT.type:$res);
+ let summary = "Scale and convert packed "
+ # smallT.name # " to packed " # largeT.name;
+ let description = [{
+ Convert 32 packed }] # smallT.name # [{ values to packed }]
+ # largeT.name # [{, multiplying by the exponent part of `scale`
+ before doing so.
+ }];
+ let assemblyFormat = [{
+ attr-dict $src `,` $scale `:` type($res)
+ }];
+ }
+ } // foreach largeT
+
+ def ROCDL_CvtScaleF322xPk16 # smallT.nameForOp # F32Op :
+ ROCDL_ConcreteNonMemIntrOp<"cvt.scalef32.2xpk16." # smallT.name # ".f32",
+ [Pure], 1>,
+ Arguments<(ins ROCDL_V16F32Type:$src0, ROCDL_V16F32Type:$src1, F32:$scale)> {
+ let results = (outs smallT.type:$res);
+ let summary = "Scale and convert two vector<16xf32> to 32 packed " # smallT.name;
+ let description = [{
+ Convert 32 single-precision float values, packed into two length-16
+ vectors that will be logically concatenated, to packed }]
+ # smallT.name # [{, dividing by the exponent part of `scale`
+ before doing so.
+ }];
+ let assemblyFormat = [{
+ attr-dict $src0 `,` $src1 `,` $scale `:` type($res)
+ }];
+ }
+} // foreach smallT
+
+//===---------------------------------------------------------------------===//
+// Scaled conversions to/from fp8/bf8 (f8E4M3FN / f8E5M2)
+//===---------------------------------------------------------------------===//
+foreach smallTOp = ["Fp8", "Bf8"] in {
+ defvar smallT = !tolower(smallTOp);
+
+ def ROCDL_CvtScaleF32F16 # smallTOp # Op :
+ ROCDL_ConcreteNonMemIntrOp<"cvt.scalef32.f16." # smallT,
+ [Pure], 1, [3, 4], ["srcSelIndex", "dstLoHiSel"]>,
+ Arguments<(ins ROCDL_V2F16Type:$oldVdst, I32:$src, F32:$scale, I32Attr:$srcSelIndex, I1Attr:$dstLoHiSel)> {
+ let results = (outs ROCDL_V2F16Type:$res);
+ let summary = "Scaled convert " # smallT # " from packed vector to f16, updating tied result";
+ let description = [{
+ Convert a }] # smallT # [{ byte from `src`, selected by
+ `srcSelIndex`, to f16 while multiplying it by the exponent of `scale`,
+ and place it into element `dstLoHiSel`
+ of `oldVdst`, preserving the other element of that vector in
+ the return value.
+
+ The bytes are stored as an `i32` and not a `<4 x i8>`.
+ }];
+ let assemblyFormat = [{
+ attr-dict $src `[` $srcSelIndex `]` `,` $scale `->` $oldVdst `[` $dstLoHiSel `]` `:` type($res)
+ }];
+ }
+
+ def ROCDL_CvtScaleF32F32 # smallTOp # Op :
+ ROCDL_ConcreteNonMemIntrOp<"cvt.scalef32.f32." # smallT,
+ [Pure], 1, [2], ["srcSelIndex"]>,
+ Arguments<(ins I32:$src, F32:$scale, I32Attr:$srcSelIndex)> {
+ let results = (outs F32:$res);
+ let summary = "Scaled convert " # smallT # " from packed vector to f32";
+ let description = [{
+ Convert a }] # smallT # [{ byte from `src`, selected by
+ `srcSelIndex`, to f32, multiplying it by the exponent of `scale`.
+
+ The bytes are stored in an `i32`, not a `<4 x i8>`.
+ }];
+ let assemblyFormat = [{
+ attr-dict $src `[` $srcSelIndex `]` `,` $scale `:` type($res)
+ }];
+ }
+
+ def ROCDL_CvtScaleF32Pk # smallTOp # F32Op :
+ ROCDL_ConcreteNonMemIntrOp<"cvt.scalef32.pk." # smallT # ".f32",
+ [Pure], 1, [4], ["dstLoHiSel"]>,
+ Arguments<(ins ROCDL_V2I16Type:$oldVdst, F32:$src0, F32:$src1, F32:$scale, I1Attr:$dstLoHiSel)> {
+ let results = (outs ROCDL_V2I16Type:$res);
+ let summary = "Scaled convert two f32 to two " # smallT # ", updating packed vector";
+ let description = [{
+ Convert two f32 values in `src0` and `src1` to two }] # smallT # [{ bytes,
+ dividing by the exponent in `scale`. The bytes are packed into
+ a 16-bit value which is inserted into `oldVdst` at the `dstLoHiSel`
+ position, with the entire updated vector being returned.
+ }];
+ let assemblyFormat = [{
+ attr-dict $src0 `,` $src1 `,` $scale `->` $oldVdst `[` $dstLoHiSel `]` `:` type($res)
+ }];
+ }
+
+ foreach largeT = [
+ ScaleArgInfo<ROCDL_V2F16Type, "F16">,
+ ScaleArgInfo<ROCDL_V2BF16Type, "Bf16">
+ ] in {
+ def ROCDL_CvtScaleF32Pk # smallTOp # largeT.nameForOp # Op :
+ ROCDL_ConcreteNonMemIntrOp<"cvt.scalef32.pk." # smallT # "." # largeT.name,
+ [Pure], 1, [3], ["dstLoHiSel"]>,
+ Arguments<(ins ROCDL_V2I16Type:$oldVdst, largeT.type:$src0, F32:$scale, I1Attr:$dstLoHiSel)> {
+ let results = (outs ROCDL_V2I16Type:$res);
+ let summary = "Scaled convert two " # largeT.name # " to two " # smallT # ", updating packed vector";
+ let description = [{
+ Convert two }] # largeT.name # [{ values in `src0` to two }]
+ # smallT # [{ bytes, dividing by the exponent in `scale`. The bytes are
+ packed into a 16-bit value which is inserted into `oldVdst` at the
+ `dstLoHiSel` position, with the entire updated vector being returned.
+ }];
+ let assemblyFormat = [{
+ attr-dict $src0 `,` $scale `->` $oldVdst `[` $dstLoHiSel `]` `:` type($res)
+ }];
+ }
+ } // foreach largeT
+
+ foreach largeT = [
+ ScaleArgInfo<ROCDL_V2F16Type, "F16">,
+ ScaleArgInfo<ROCDL_V2BF16Type, "Bf16">,
+ ScaleArgInfo<ROCDL_V2F32Type, "F32">
+ ] in {
+ def ROCDL_CvtScaleF32Pk # largeT.nameForOp # smallTOp # Op :
+ ROCDL_ConcreteNonMemIntrOp<"cvt.scalef32.pk." # largeT.name # "." # smallT,
+ [Pure], 1, [2], ["srcLoHiSel"]>,
+ Arguments<(ins I32:$src, F32:$scale, I1Attr:$srcLoHiSel)> {
+ let results = (outs largeT.type:$res);
+ let summary = "Scaled convert two " # smallT # " to two " # largeT.name;
+ let description = [{
+ Convert two packed }] # smallT # [{ values in `src` to two }]
+ # largeT.name # [{ values, multiplying by the exponent in `scale`.
+ The two values to be converted are selected from the low or high half
+ of `src` (a packed vector represented as an `i32`)
+ on the basis of `srcLoHiSel`.
+ }];
+ let assemblyFormat = [{
+ attr-dict $src `[` $srcLoHiSel `]` `,` $scale `:` type($res)
+ }];
+ }
+ } // foreach largeT
+
+ foreach largeT = [
+ ScaleArgInfo<F32, "F32">,
+ ScaleArgInfo<F16, "F16">,
+ ScaleArgInfo<BF16, "Bf16">
+ ] in {
+ def ROCDL_CvtScaleF32Sr # smallTOp # largeT.nameForOp # Op :
+ ROCDL_ConcreteNonMemIntrOp<"cvt.scalef32.sr." # smallT # "." # largeT.name,
+ [Pure], 1, [4], ["dstSelIndex"]>,
+ Arguments<(ins I32:$oldVdst, largeT.type:$src0, I32:$seed, F32:$scale, I32Attr:$dstSelIndex)> {
+ let results = (outs I32:$res);
+ let summary = "Scaled convert " # largeT.name # " to " # smallT # " with stochastic rounding, updating packed vector";
+ let description = [{
+ Convert a }] # largeT.name # [{ value in `src0` to a }]
+ # smallT # [{ byte, dividing by the exponent in `scale` and using `seed`
+ for stochastic rounding. Place the resulting byte in the
+ `dstSelIndex`th byte of `oldVdst` and return the entire packed vector,
+ which is stored as an `i32`.
+ }];
+ let assemblyFormat = [{
+ attr-dict $src0 `,` $seed `,` $scale `->` $oldVdst `[` $dstSelIndex `]` `:` type($res)
+ }];
+ }
+ } // foreach largeT
+} // foreach smallTOp
+
+//===---------------------------------------------------------------------===//
+// Scaled conversions to/from fp4 (f4E2M1FN)
+//===---------------------------------------------------------------------===//
+
+foreach largeT = [
+ ScaleArgInfo<ROCDL_V2F16Type, "F16">,
+ ScaleArgInfo<ROCDL_V2BF16Type, "Bf16">,
+ ScaleArgInfo<ROCDL_V2F32Type, "F32">
+] in {
+ // Note: rounding down f32 values is a special case that takes
+ // two scalar f32 arguments (see ROCDL_CvtScaleF32PkFp4F32Op).
+ if !ne(largeT.name, "f32") then {
+ def ROCDL_CvtScaleF32PkFp4 # largeT.nameForOp # Op :
+ ROCDL_ConcreteNonMemIntrOp<"cvt.scalef32.pk.fp4." # largeT.name,
+ [Pure], 1, [3], ["dstSelIndex"]>,
+ Arguments<(ins I32:$oldVdst, largeT.type:$src, F32:$scale, I32Attr:$dstSelIndex)> {
+ let results = (outs I32:$res);
+ let summary = "Scale and convert two "
+ # largeT.name # " to packed fp4, updating tied vector";
+ let description = [{
+ Convert two packed }] # largeT.name # [{ values to packed
+ fp4, dividing by the exponent part of `scale`
+ before doing so.
+
+ The two scaled values are packed into a byte.
+ That byte is used to update the `dstSelIndex`th
+ byte of `oldVdst`, which is returned in its entirety.
+ }];
+ let assemblyFormat = [{
+ attr-dict $src `,` $scale `->` $oldVdst `[` $dstSelIndex `]` `:` type($res)
+ }];
+ }
+ } // if
+
+ def ROCDL_CvtScaleF32SrPkFp4 # largeT.nameForOp # Op :
+ ROCDL_ConcreteNonMemIntrOp<"cvt.scalef32.sr.pk.fp4." # largeT.name,
+ [Pure], 1, [4], ["dstSelIndex"]>,
+ Arguments<(ins I32:$oldVdst, largeT.type:$src, I32:$seed, F32:$scale, I32Attr:$dstSelIndex)> {
+ let results = (outs I32:$res);
+ let summary = "Scale and convert two "
+ # largeT.name # " to packed fp4 with stochastic rounding, updating tied vector";
+ let description = [{
+ Convert two packed }] # largeT.name # [{ values to packed
+ fp4, dividing by the exponent part of `scale`
+ before doing so and using `seed` as the random seed for
+ stochastic rounding.
+
+ The two scaled values are packed (little-endian)
+ into a byte. That byte is used to update the `dstSelIndex`th
+ byte of `oldVdst`, which is returned in its entirety.
+ }];
+ let assemblyFormat = [{
+ attr-dict $src `,` $seed `,` $scale `->` $oldVdst `[` $dstSelIndex `]` `:` type($res)
+ }];
+ }
+
+ def ROCDL_CvtScaleF32Pk # largeT.nameForOp # Fp4Op :
+ ROCDL_ConcreteNonMemIntrOp<"cvt.scalef32.pk." # largeT.name # ".fp4",
+ [Pure], 1, [2], ["srcSelIndex"]>,
+ Arguments<(ins I32:$src, F32:$scale, I32Attr:$srcSelIndex)> {
+ let results = (outs largeT.type:$res);
+ let summary = "Scale and convert two packed fp4 to packed " # largeT.name;
+ let description = [{
+ Convert two packed fp4 (f4E2M1) values stored as one byte of a 32-bit integer
+ to packed }] # largeT.name # [{, multiplying by the exponent part of `scale`
+ before doing so.
+
+ The byte to convert is chosen by `srcSelIndex`.
+ }];
+ let assemblyFormat = [{
+ attr-dict $src `[` $srcSelIndex `]` `,` $scale `:` type($res)
+ }];
+ }
+} // foreach largeT
+
+def ROCDL_CvtScaleF32PkFp4F32Op :
+ ROCDL_ConcreteNonMemIntrOp<"cvt.scalef32.pk.fp4.f32",
+ [Pure], 1, [4], ["dstSelIndex"]>,
+ Arguments<(ins I32:$oldVdst, F32:$src0, F32:$src1, F32:$scale, I32Attr:$dstSelIndex)> {
+ let results = (outs I32:$res);
+ let summary = "Scale and convert two f32 values to two packed fp4, updating tied vector";
+ let description = [{
+ Convert two single-precision float values, passed in `src0` and `src1`
+ into two fp4 values, dividing them by the exponent part of `scale`
+ before doing so.
+
+ The two scaled values are packed into a byte.
+ That byte is used to update the `dstSelIndex`th
+ byte of `oldVdst`, which is returned in its entirety.
+ }];
+ let assemblyFormat = [{
+ attr-dict $src0 `,` $src1 `,` $scale `->` $oldVdst `[` $dstSelIndex `]` `:` type($res)
+ }];
+}
+
//===----------------------------------------------------------------------===//
// ROCDL target attribute.
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
index d0093b8dc8c2a..b62c8e7d7eec5 100644
--- a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
+++ b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
@@ -1208,22 +1208,20 @@ LogicalResult ExtPackedFp8OpLowering::matchAndRewrite(
}
Value i32Source = rewriter.create<LLVM::BitcastOp>(loc, i32, source);
if (resultVecType) {
- Value wordSel = createI1Constant(rewriter, loc, op.getIndex());
if (typeIsExpectedBf8ForChipset(chipset, sourceElemType)) {
rewriter.replaceOpWithNewOp<ROCDL::CvtPkF32Bf8Op>(op, f32, i32Source,
- wordSel);
+ op.getIndex());
} else if (typeIsExpectedFp8ForChipset(chipset, sourceElemType)) {
rewriter.replaceOpWithNewOp<ROCDL::CvtPkF32Fp8Op>(op, f32, i32Source,
- wordSel);
+ op.getIndex());
}
} else {
- Value byteSel = createI32Constant(rewriter, loc, op.getIndex());
if (typeIsExpectedBf8ForChipset(chipset, sourceElemType)) {
rewriter.replaceOpWithNewOp<ROCDL::CvtF32Bf8Op>(op, f32, i32Source,
- byteSel);
+ op.getIndex());
} else if (typeIsExpectedFp8ForChipset(chipset, sourceElemType)) {
rewriter.replaceOpWithNewOp<ROCDL::CvtF32Fp8Op>(op, f32, i32Source,
- byteSel);
+ op.getIndex());
}
}
return success();
@@ -1251,15 +1249,14 @@ LogicalResult PackedTrunc2xFp8OpLowering::matchAndRewrite(
existing = rewriter.create<LLVM::BitcastOp>(loc, i32, existing);
else
existing = rewriter.create<LLVM::UndefOp>(loc, i32);
- Value wordSel = createI1Constant(rewriter, loc, op.getWordIndex());
Value result;
if (typeIsExpectedBf8ForChipset(chipset, resultElemType))
result = rewriter.create<ROCDL::CvtPkBf8F32Op>(loc, i32, sourceA, sourceB,
- existing, wordSel);
+ existing, op.getWordIndex());
else if (typeIsExpectedFp8ForChipset(chipset, resultElemType))
result = rewriter.create<ROCDL::CvtPkFp8F32Op>(loc, i32, sourceA, sourceB,
- existing, wordSel);
+ existing, op.getWordIndex());
result = rewriter.replaceOpWithNewOp<LLVM::BitcastOp>(
op, getTypeConverter()->convertType(resultType), result);
@@ -1286,15 +1283,14 @@ LogicalResult PackedStochRoundFp8OpLowering::matchAndRewrite(
existing = rewriter.create<LLVM::BitcastOp>(loc, i32, existing);
else
existing = rewriter.create<LLVM::UndefOp>(loc, i32);
- Value byteSel = createI32Constant(rewriter, loc, op.getStoreIndex());
Value result;
if (typeIsExpectedBf8ForChipset(chipset, resultElemType))
- result = rewriter.create<ROCDL::CvtSrBf8F32Op>(loc, i32, source, stoch,
- existing, byteSel);
+ result = rewriter.create<ROCDL::CvtSrBf8F32Op>(
+ loc, i32, source, stoch, existing, op.getStoreIndex());
else if (typeIsExpectedFp8ForChipset(chipset, resultElemType))
- result = rewriter.create<ROCDL::CvtSrFp8F32Op>(loc, i32, source, stoch,
- existing, byteSel);
+ result = rewriter.create<ROCDL::CvtSrFp8F32Op>(
+ loc, i32, source, stoch, existing, op.getStoreIndex());
result = rewriter.replaceOpWithNewOp<LLVM::BitcastOp>(
op, getTypeConverter()->convertType(resultType), result);
diff --git a/mlir/test/Conversion/AMDGPUToROCDL/8-bit-floats-ocp.mlir b/mlir/test/Conversion/AMDGPUToROCDL/8-bit-floats-ocp.mlir
index ea0c3afbd9021..464d47216c81b 100644
--- a/mlir/test/Conversion/AMDGPUToROCDL/8-bit-floats-ocp.mlir
+++ b/mlir/test/Conversion/AMDGPUToROCDL/8-bit-floats-ocp.mlir
@@ -7,8 +7,7 @@
// CHECK-DAG: [[C0_1:%.+]] = llvm.mlir.constant(0 : i32) : i32
// CHECK: [[VEC:%.+]] = llvm.insertelement [[V]], [[UNDEF]]{{\[}}[[C0_1]] : i32] : vector<4xi8>
// CHECK: [[CAST:%.+]] = llvm.bitcast [[VEC]] : vector<4xi8> to i32
-// CHECK: [[C0_2:%.+]] = llvm.mlir.constant(0 : i32) : i32
-// CHECK: [[EXT:%.+]] = rocdl.cvt.f32.bf8 [[CAST]]{{\[}}[[C0_2]]] : f32
+// CHECK: [[EXT:%.+]] = rocdl.cvt.f32.bf8 [[CAST]][0] : f32
// CHECK: return [[EXT]] : f32
func.func @ext_scalar(%v: f8E5M2) -> f32 {
%ret = amdgpu.ext_packed_fp8 %v[0] : f8E5M2 to f32
@@ -25,8 +24,7 @@ func.func @ext_scalar(%v: f8E5M2) -> f32 {
// CHECK: [[ELEM_1:%.+]] = llvm.extractelement [[V]]{{\[}}[[C1_1]] : i32] : vector<2xi8>
// CHECK: [[VEC_1:%.+]] = llvm.insertelement [[ELEM_1]], [[VEC_0]]{{\[}}[[C1_1]] : i32] : vector<4xi8>
// CHECK: [[CAST:%.+]] = llvm.bitcast [[VEC_1]] : vector<4xi8> to i32
-// CHECK: [[C1_2:%.+]] = llvm.mlir.constant(1 : i32) : i32
-// CHECK: [[EXT:%.+]] = rocdl.cvt.f32.fp8 [[CAST]]{{\[}}[[C1_2]]] : f32
+// CHECK: [[EXT:%.+]] = rocdl.cvt.f32.fp8 [[CAST]][1] : f32
// CHECK: return [[EXT]] : f32
func.func @ext_short_vec(%v: vector<2xf8E4M3FN>) -> f32 {
%ret = amdgpu.ext_packed_fp8 %v[1] : vector<2xf8E4M3FN> to f32
@@ -36,8 +34,7 @@ func.func @ext_short_vec(%v: vector<2xf8E4M3FN>) -> f32 {
// CHECK-LABEL: func @ext_full_vec(
// CHECK: [[V:%.+]] = builtin.unrealized_conversion_cast %{{.+}} : vector<4xf8E4M3FN> to vector<4xi8>
// CHECK: [[CAST:%.+]] = llvm.bitcast [[V]] : vector<4xi8> to i32
-// CHECK: [[C3:%.+]] = llvm.mlir.constant(3 : i32) : i32
-// CHECK: [[EXT:%.+]] = rocdl.cvt.f32.fp8 [[CAST]]{{\[}}[[C3]]] : f32
+// CHECK: [[EXT:%.+]] = rocdl.cvt.f32.fp8 [[CAST]][3] : f32
// CHECK: return [[EXT]] : f32
func.func @ext_full_vec(%v: vector<4xf8E4M3FN>) -> f32 {
%ret = amdgpu.ext_packed_fp8 %v[3] : vector<4xf8E4M3FN> to f32
@@ -54,8 +51,7 @@ func.func @ext_full_vec(%v: vector<4xf8E4M3FN>) -> f32 {
// CHECK: [[ELEM_1:%.+]] = llvm.extractelement [[V]]{{\[}}[[C1_1]] : i32] : vector<2xi8>
// CHECK: [[VEC_1:%.+]] = llvm.insertelement [[ELEM_1]], [[VEC_0]]{{\[}}[[C1_1]] : i32] : vector<4xi8>
// CHECK: [[CAST:%.+]] = llvm.bitcast [[VEC_1]] : vector<4xi8> to i32
-// CHECK: [[C1_2:%.+]] = llvm.mlir.constant(false) : i1
-// CHECK: [[EXT:%.+]] = rocdl.cvt.pk.f32.fp8 [[CAST]]{{\[}}[[C1_2]]] : vector<2xf32>
+// CHECK: [[EXT:%.+]] = rocdl.cvt.pk.f32.fp8 [[CAST]][false] : vector<2xf32>
// CHECK: return [[EXT]]
func.func @ext_packed_2xfp8(%v: vector<2xf8E4M3FN>) -> vector<2xf32> {
%ret = amdgpu.ext_packed_fp8 %v[0] : vector<2xf8E4M3FN> to vector<2xf32>
@@ -65,8 +61,7 @@ func.func @ext_packed_2xfp8(%v: vector<2xf8E4M3FN>) -> vector<2xf32> {
// CHECK-LABEL: func @ext_packed_4xfp8
// CHECK: [[V:%.+]] = builtin.unrealized_conversion_cast %{{.+}} : vector<4xf8E4M3FN> to vector<4xi8>
// CHECK: [[CAST:%.+]] = llvm.bitcast [[V]] : vector<4xi8> to i32
-// CHECK: [[C3:%.+]] = llvm.mlir.constant(true) : i1
-// CHECK: [[EXT:%.+]] = rocdl.cvt.pk.f32.fp8 [[CAST]]{{\[}}[[C3]]] : vector<2xf32>
+// CHECK: [[EXT:%.+]] = rocdl.cvt.pk.f32.fp8 [[CAST]][true] : vector<2xf32>
// CHECK: return [[EXT]] : vector<2xf32>
func.func @ext_packed_4xfp8(%v: vector<4xf8E4M3FN>) -> vector<2xf32> {
%ret = amdgpu.ext_packed_fp8 %v[1] : vector<4xf8E4M3FN> to vector<2xf32>
@@ -77,8 +72,7 @@ func.func @ext_packed_4xfp8(%v: vector<4xf8E4M3FN>) -> vector<2xf32> {
// CHECK-SAME: ([[V:%.+]]: f32)
// CHECK: [[V2:%.+]] = llvm.mlir.undef : f32
// CHECK: [[EXISTING:%.+]] = llvm.mlir.undef : i32
-// CHECK: [[FALSE:%.+]] = llvm.mlir.constant(false) : i1
-// CHECK: [[PACKED:%.+]] = rocdl.cvt.pk.fp8.f32 [[V]], [[V2]] -> [[EXISTING]]{{\[}}[[FALSE]]] : i32
+// CHECK: [[PACKED:%.+]] = rocdl.cvt.pk.fp8.f32 [[V]], [[V2]] -> [[EXISTING]][false] : i32
// CHECK: [[CAST:%.+]] = llvm.bitcast [[PACKED]] : i32 to vector<4xi8>
// CHECK: builtin.unrealized_conversion_cast [[CAST]] : vector<4xi8> to vector<4xf8E4M3FN>
func.func @packed_trunc(%v: f32) -> vector<4xf8E4M3FN> {
@@ -89,8 +83,7 @@ func.func @packed_trunc(%v: f32) -> vector<4xf8E4M3FN> {
// CHECK-LABEL: func @packed_truncx2
// CHECK-SAME: ([[V:%.+]]: f32, [[W:%.+]]: f32)
// CHECK: [[EXISTING:%.+]] = llvm.mlir.undef : i32
-// CHECK: [[FALSE:%.+]] = llvm.mlir.constant(false) : i1
-// CHECK: [[PACKED:%.+]] = rocdl.cvt.pk.fp8.f32 [[V]], [[W]] -> [[EXISTING]]{{\[}}[[FALSE]]] : i32
+// CHECK: [[PACKED:%.+]] = rocdl.cvt.pk.fp8.f32 [[V]], [[W]] -> [[EXISTING]][false] : i32
// CHECK: [[CAST:%.+]] = llvm.bitcast [[PACKED]] : i32 to vector<4xi8>
// CHECK: builtin.unrealized_conversion_cast [[CAST]] : vector<4xi8> to vector<4xf8E4M3FN>
func.func @packed_truncx2(%v: f32, %w: f32) -> vector<4xf8E4M3FN> {
@@ -102,8 +95,7 @@ func.func @packed_truncx2(%v: f32, %w: f32) -> vector<4xf8E4M3FN> {
// CHECK-SAME: ([[V:%.+]]: f32, [[W:%.+]]: f32, [[EXISTING:%.+]]: vector<4xf8E5M2>)
// CHECK: [[EXISTING_BYTES:%.+]] = builtin.unrealized_conversion_cast [[EXISTING]] : vector<4xf8E5M2> to vector<4xi8>
// CHECK: [[EXISTING_INT:%.+]] = llvm.bitcast [[EXISTING_BYTES]] : vector<4xi8> to i32
-// CHECK: [[TRUE:%.+]] = llvm.mlir.constant(true) : i1
-// CHECK: [[PACKED:%.+]] = rocdl.cvt.pk.bf8.f32 [[V]], [[W]] -> [[EXISTING_INT]]{{\[}}[[TRUE]]] : i32
+// CHECK: [[PACKED:%.+]] = rocdl.cvt.pk.bf8.f32 [[V]], [[W]] -> [[EXISTING_INT]][true] : i32
// CHECK: [[CAST:%.+]] = llvm.bitcast [[PACKED]] : i32 to vector<4xi8>
// CHECK: builtin.unrealized_conversion_cast [[CAST]] : vector<4xi8> to vector<4xf8E5M2>
func.func @packed_truncx2_into(%v: f32, %w: f32, %existing: vector<4xf8E5M2>) -> vector<4xf8E5M2> {
@@ -114,8 +106,7 @@ func.func @packed_truncx2_into(%v: f32, %w: f32, %existing: vector<4xf8E5M2>) ->
// CHECK-LABEL: func @packed_stoch_round
// CHECK-SAME: ([[V:%.+]]: f32, [[S:%.+]]: i32)
// CHECK: [[EXISTING:%.+]] = llvm.mlir.undef : i32
-// CHECK: [[C0:%.+]] = llvm.mlir.constant(0 : i32) : i32
-// CHECK: [[PACKED:%.+]] = rocdl.cvt.sr.fp8.f32 [[V]], [[S]] -> [[EXISTING]]{{\[}}[[C0]]] : i32
+// CHECK: [[PACKED:%.+]] = rocdl.cvt.sr.fp8.f32 [[V]], [[S]] -> [[EXISTING]][0] : i32
// CHECK: [[CAST:%.+]] = llvm.bitcast [[PACKED]] : i32 to vector<4xi8>
// CHECK: builtin.unrealized_conversion_cast [[CAST]] : vector<4xi8> to vector<4xf8E4M3FN>
func.func @packed_stoch_round(%v: f32, %s: i32) -> vector<4xf8E4M3FN> {
@@ -127,8 +118,7 @@ func.func @packed_stoch_round(%v: f32, %s: i32) -> vector<4xf8E4M3FN> {
// CHECK-SAME: ([[V:%.+]]: f32, [[S:%.+]]: i32, [[EXISTING:%.+]]: vector<4xf8E5M2>)
// CHECK: [[EXISTING_BYTES:%.+]] = builtin.unrealized_conversion_cast [[EXISTING]] : vector<4xf8E5M2> to vector<4xi8>
// CHECK: [[EXISTING_INT:%.+]] = llvm.bitcast [[EXISTING_BYTES]] : vector<4xi8> to i32
-// CHECK: [[C1:%.+]] = llvm.mlir.constant(1 : i32) : i32
-// CHECK: [[PACKED:%.+]] = rocdl.cvt.sr.bf8.f32 [[V]], [[S]] -> [[EXISTING_INT]]{{\[}}[[C1]]] : i32
+// CHECK: [[PACKED:%.+]] = rocdl.cvt.sr.bf8.f32 [[V]], [[S]] -> [[EXISTING_INT]][1] : i32
// CHECK: [[CAST:%.+]] = llvm.bitcast [[PACKED]] : i32 to vector<4xi8>
// CHECK: builtin.unrealized_conversion_cast [[CAST]] : vector<4xi8> to vector<4xf8E5M2>
func.func @packed_stoch_round_into(%v: f32, %s: i32, %existing: vector<4xf8E5M2>) -> vector<4xf8E5M2> {
diff --git a/mlir/test/Conversion/AMDGPUToROCDL/8-bit-floats.mlir b/mlir/test/Conversion/AMDGPUToROCDL/8-bit-floats.mlir
index 219f822ca9a1c..03fcb266a2e87 100644
--- a/mlir/test/Conversion/AMDGPUToROCDL/8-bit-floats.mlir
+++ b/mlir/test/Conversion/AMDGPUToROCDL/8-bit-floats.mlir
@@ -6,8 +6,7 @@
// CHECK-DAG: [[C0_1:%.+]] = llvm.mlir.constant(0 : i32) : i32
// CHECK: [[VEC:%.+]] = llvm.insertelement [[V]], [[UNDEF]]{{\[}}[[C0_1]] : i32] : vector<4xi8>
// CHECK: [[CAST:%.+]] = llvm.bitcast [[VEC]] : vector<4xi8> to i32
-// CHECK: [[C0_2:%.+]] = llvm.mlir.constant(0 : i32) : i32
-// CHECK: [[EXT:%.+]] = rocdl.cvt.f32.bf8 [[CAST]]{{\[}}[[C0_2]]] : f32
+// CHECK: [[EXT:%.+]] = rocdl.cvt.f32.bf8 [[CAST]][0] : f32
// CHECK: return [[EXT]] : f32
func.func @ext_scalar(%v: f8E5M2FNUZ) -> f32 {
%ret = amdgpu.ext_packed_fp8 %v[0] : f8E5M2FNUZ to f32
@@ -24,8 +23,7 @@ func.func @ext_scalar(%v: f8E5M2FNUZ) -> f32 {
// CHECK: [[ELEM_1:%.+]] = llvm.extractelement [[V]]{{\[}}[[C1_1]] : i32] : vector<2xi8>
// CHECK: [[VEC_1:%.+]] = llvm.insertelement [[ELEM_1]], [[VEC_0]]{{\[}}[[C1_1]] : i32] : vector<4xi8>
// CHECK: [[CAST:%.+]] = llvm.bitcast [[VEC_1]] : vector<4xi8> to i32
-// CHECK: [[C1_2:%.+]] = llvm.mlir.constant(1 : i32) : i32
-// CHECK: [[EXT:%.+]] = rocdl.cvt.f32.fp8 [[CAST]]{{\[}}[[C1_2]]] : f32
+// CHECK: [[EXT:%.+]] = rocdl.cvt.f32.fp8 [[CAST]][1] : f32
// CHECK: return [[EXT]] : f32
func.func @ext_short_vec(%v: vector<2xf8E4M3FNUZ>) -> f32 {
%ret = amdgpu.ext_packed_fp8 %v[1] : vector<2xf8E4M3FNUZ> to f32
@@ -35,8 +33,7 @@ func.func @ext_short_vec(%v: vector<2xf8E4M3FNUZ>) -> f32 {
// CHECK-LABEL: func @ext_full_vec
// CHECK: [[V:%.+]] = builtin.unrealized_conversion_cast %{{.+}} : vector<4xf8E4M3FNUZ> to vector<4xi8>
// CHECK: [[CAST:%.+]] = llvm.bitcast [[V]] : vector<4xi8> to i32
-// CHECK: [[C3:%.+]] = llvm.mlir.constant(3 : i32) : i32
-// CHECK: [[EXT:%.+]] = rocdl.cvt.f32.fp8 [[CAST]]{{\[}}[[C3]]] : f32
+// CHECK: [[EXT:%.+]] = rocdl.cvt.f32.fp8 [[CAST]][3] : f32
// CHECK: return [[EXT]] : f32
func.func @ext_full_vec(%v: vector<4xf8E4M3FNUZ>) -> f32 {
%ret = amdgpu.ext_packed_fp8 %v[3] : vector<4xf8E4M3FNUZ> to f32
@@ -53,8 +50,7 @@ func.func @ext_full_vec(%v: vector<4xf8E4M3FNUZ>) -> f32 {
// CHECK: [[ELEM_1:%.+]] = llvm.extractelement [[V]]{{\[}}[[C1_1]] : i32] : vector<2xi8>
// CHECK: [[VEC_1:%.+]] = llvm.insertelement [[ELEM_1]], [[VEC_0]]{{\[}}[[C1_1]] : i32] : vector<4xi8>
// CHECK: [[CAST:%.+]] = llvm.bitcast [[VEC_1]] : vector<4xi8> to i32
-// CHECK: [[C1_2:%.+]] = llvm.mlir.constant(false) : i1
-// CHECK: [[EXT:%.+]] = rocdl.cvt.pk.f32.fp8 [[CAST]]{{\[}}[[C1_2]]] : vector<2xf32>
+// CHECK: [[EXT:%.+]] = rocdl.cvt.pk.f32.fp8 [[CAST]][false] : vector<2xf32>
// CHECK: return [[EXT]]
func.func @ext_packed_2xfp8(%v: vector<2xf8E4M3FNUZ>) -> vector<2xf32> {
%ret = amdgpu.ext_packed_fp8 %v[0] : vector<2xf8E4M3FNUZ> to vector<2xf32>
@@ -64,8 +60,7 @@ func.func @ext_packed_2xfp8(%v: vector<2xf8E4M3FNUZ>) -> vector<2xf32> {
// CHECK-LABEL: func @ext_packed_4xfp8(
// CHECK: [[V:%.+]] = builtin.unrealized_conversion_cast %{{.+}} : vector<4xf8E4M3FNUZ> to vector<4xi8>
// CHECK: [[CAST:%.+]] = llvm.bitcast [[V]] : vector<4xi8> to i32
-// CHECK: [[C3:%.+]] = llvm.mlir.constant(true) : i1
-// CHECK: [[EXT:%.+]] = rocdl.cvt.pk.f32.fp8 [[CAST]]{{\[}}[[C3]]] : vector<2xf32>
+// CHECK: [[EXT:%.+]] = rocdl.cvt.pk.f32.fp8 [[CAST]][true] : vector<2xf32>
// CHECK: return [[EXT]] : vector<2xf32>
func.func @ext_packed_4xfp8(%v: vector<4xf8E4M3FNUZ>) -> vector<2xf32> {
%ret = amdgpu.ext_packed_fp8 %v[1] : vector<4xf8E4M3FNUZ> to vector<2xf32>
@@ -76,8 +71,7 @@ func.func @ext_packed_4xfp8(%v: vector<4xf8E4M3FNUZ>) -> vector<2xf32> {
// CHECK-SAME: ([[V:%.+]]: f32)
// CHECK: [[V2:%.+]] = llvm.mlir.undef : f32
// CHECK: [[EXISTING:%.+]] = llvm.mlir.undef : i32
-// CHECK: [[FALSE:%.+]] = llvm.mlir.constant(false) : i1
-// CHECK: [[PACKED:%.+]] = rocdl.cvt.pk.fp8.f32 [[V]], [[V2]] -> [[EXISTING]]{{\[}}[[FALSE]]] : i32
+// CHECK: [[PACKED:%.+]] = rocdl.cvt.pk.fp8.f32 [[V]], [[V2]] -> [[EXISTING]][false] : i32
// CHECK: [[CAST:%.+]] = llvm.bitcast [[PACKED]] : i32 to vector<4xi8>
// CHECK: builtin.unrealized_conversion_cast [[CAST]] : vector<4xi8> to vector<4xf8E4M3FNUZ>
func.func @packed_trunc(%v: f32) -> vector<4xf8E4M3FNUZ> {
@@ -88,8 +82,7 @@ func.func @packed_trunc(%v: f32) -> vector<4xf8E4M3FNUZ> {
// CHECK-LABEL: func @packed_truncx2
// CHECK-SAME: ([[V:%.+]]: f32, [[W:%.+]]: f32)
// CHECK: [[EXISTING:%.+]] = llvm.mlir.undef : i32
-// CHECK: [[FALSE:%.+]] = llvm.mlir.constant(false) : i1
-// CHECK: [[PACKED:%.+]] = rocdl.cvt.pk.fp8.f32 [[V]], [[W]] -> [[EXISTING]]{{\[}}[[FALSE]]] : i32
+// CHECK: [[PACKED:%.+]] = rocdl.cvt.pk.fp8.f32 [[V]], [[W]] -> [[EXISTING]][false] : i32
// CHECK: [[CAST:%.+]] = llvm.bitcast [[PACKED]] : i32 to vector<4xi8>
// CHECK: builtin.unrealized_conversion_cast [[CAST]] : vector<4xi8> to vector<4xf8E4M3FNUZ>
func.func @packed_truncx2(%v: f32, %w: f32) -> vector<4xf8E4M3FNUZ> {
@@ -101,8 +94,7 @@ func.func @packed_truncx2(%v: f32, %w: f32) -> vector<4xf8E4M3FNUZ> {
// CHECK-SAME: ([[V:%.+]]: f32, [[W:%.+]]: f32, [[EXISTING:%.+]]: vector<4xf8E5M2FNUZ>)
// CHECK: [[EXISTING_BYTES:%.+]] = builtin.unrealized_conversion_cast [[EXISTING]] : vector<4xf8E5M2FNUZ> to vector<4xi8>
// CHECK: [[EXISTING_INT:%.+]] = llvm.bitcast [[EXISTING_BYTES]] : vector<4xi8> to i32
-// CHECK: [[TRUE:%.+]] = llvm.mlir.constant(true) : i1
-// CHECK: [[PACKED:%.+]] = rocdl.cvt.pk.bf8.f32 [[V]], [[W]] -> [[EXISTING_INT]]{{\[}}[[TRUE]]] : i32
+// CHECK: [[PACKED:%.+]] = rocdl.cvt.pk.bf8.f32 [[V]], [[W]] -> [[EXISTING_INT]][true] : i32
// CHECK: [[CAST:%.+]] = llvm.bitcast [[PACKED]] : i32 to vector<4xi8>
// CHECK: builtin.unrealized_conversion_cast [[CAST]] : vector<4xi8> to vector<4xf8E5M2FNUZ>
func.func @packed_truncx2_into(%v: f32, %w: f32, %existing: vector<4xf8E5M2FNUZ>) -> vector<4xf8E5M2FNUZ> {
@@ -113,8 +105,7 @@ func.func @packed_truncx2_into(%v: f32, %w: f32, %existing: vector<4xf8E5M2FNUZ>
// CHECK-LABEL: func @packed_stoch_round
// CHECK-SAME: ([[V:%.+]]: f32, [[S:%.+]]: i32)
// CHECK: [[EXISTING:%.+]] = llvm.mlir.undef : i32
-// CHECK: [[C0:%.+]] = llvm.mlir.constant(0 : i32) : i32
-// CHECK: [[PACKED:%.+]] = rocdl.cvt.sr.fp8.f32 [[V]], [[S]] -> [[EXISTING]]{{\[}}[[C0]]] : i32
+// CHECK: [[PACKED:%.+]] = rocdl.cvt.sr.fp8.f32 [[V]], [[S]] -> [[EXISTING]][0] : i32
// CHECK: [[CAST:%.+]] = llvm.bitcast [[PACKED]] : i32 to vector<4xi8>
// CHECK: builtin.unrealized_conversion_cast [[CAST]] : vector<4xi8> to vector<4xf8E4M3FNUZ>
func.func @packed_stoch_round(%v: f32, %s: i32) -> vector<4xf8E4M3FNUZ> {
@@ -126,8 +117,7 @@ func.func @packed_stoch_round(%v: f32, %s: i32) -> vector<4xf8E4M3FNUZ> {
// CHECK-SAME: ([[V:%.+]]: f32, [[S:%.+]]: i32, [[EXISTING:%.+]]: vector<4xf8E5M2FNUZ>)
// CHECK: [[EXISTING_BYTES:%.+]] = builtin.unrealized_conversion_cast [[EXISTING]] : vector<4xf8E5M2FNUZ> to vector<4xi8>
// CHECK: [[EXISTING_INT:%.+]] = llvm.bitcast [[EXISTING_BYTES]] : vector<4xi8> to i32
-// CHECK: [[C1:%.+]] = llvm.mlir.constant(1 : i32) : i32
-// CHECK: [[PACKED:%.+]] = rocdl.cvt.sr.bf8.f32 [[V]], [[S]] -> [[EXISTING_INT]]{{\[}}[[C1]]] : i32
+// CHECK: [[PACKED:%.+]] = rocdl.cvt.sr.bf8.f32 [[V]], [[S]] -> [[EXISTING_INT]][1] : i32
// CHECK: [[CAST:%.+]] = llvm.bitcast [[PACKED]] : i32 to vector<4xi8>
// CHECK: builtin.unrealized_conversion_cast [[CAST]] : vector<4xi8> to vector<4xf8E5M2FNUZ>
func.func @packed_stoch_round_into(%v: f32, %s: i32, %existing: vector<4xf8E5M2FNUZ>) -> vector<4xf8E5M2FNUZ> {
diff --git a/mlir/test/Dialect/LLVMIR/rocdl.mlir b/mlir/test/Dialect/LLVMIR/rocdl.mlir
index fbde993891342..0503c2a15860b 100644
--- a/mlir/test/Dialect/LLVMIR/rocdl.mlir
+++ b/mlir/test/Dialect/LLVMIR/rocdl.mlir
@@ -789,36 +789,32 @@ llvm.func @rocdl_8bit_floats(%source: i32, %source_half: f16, %source_bfloat: bf
// CHECK: rocdl.cvt.scalef32.sr.bf8.bf16
// CHECK: rocdl.cvt.scalef32.pk.f32.fp8
// CHECK: rocdl.cvt.scalef32.pk.f32.bf8
- %c0 = llvm.mlir.constant(0 : i32) : i32
- %c2 = llvm.mlir.constant(2 : i32) : i32
- %c3 = llvm.mlir.constant(3 : i32) : i32
%c4 = llvm.mlir.constant(1.0 : f32) : f32
- %false = llvm.mlir.constant(false) : i1
- %v1 = rocdl.cvt.f32.bf8 %source[%c0] : f32
- %v2 = rocdl.cvt.f32.fp8 %source[%c0] : f32
- %v1_scaled = rocdl.cvt.scalef32.f32.bf8 %source[%c0], %c4 : f32
- %v2_scaled = rocdl.cvt.scalef32.f32.fp8 %source[%c0], %c4 : f32
- %v3_scaled = rocdl.cvt.scalef32.pk.f16.bf8 %source[%false], %c4 : vector<2xf16>
- %v4_scaled = rocdl.cvt.scalef32.pk.f16.fp8 %source[%false], %c4 : vector<2xf16>
- %v3_scaled_bf16 = rocdl.cvt.scalef32.pk.bf16.bf8 %source[%false], %c4 : vector<2xbf16>
- %v4_scaled_bf16 = rocdl.cvt.scalef32.pk.bf16.fp8 %source[%false], %c4 : vector<2xbf16>
- %v5 = rocdl.cvt.scalef32.f16.fp8 %source[%false], %c4 -> %v3_scaled[%c0] : f16
- %v6 = rocdl.cvt.scalef32.f16.bf8 %source[%false], %c4 -> %v3_scaled[%c0] : f16
- %source2 = rocdl.cvt.pk.bf8.f32 %v1, %v2 -> %source[%false] : i32
- %source3 = rocdl.cvt.pk.fp8.f32 %v1, %v2 -> %source2[%false] : i32
- %source2_ext = rocdl.cvt.pk.f32.bf8 %source[%false] : vector<2xf32>
- %source3_ext = rocdl.cvt.pk.f32.fp8 %source[%false] : vector<2xf32>
- %source4 = rocdl.cvt.sr.bf8.f32 %v1, %stoch -> %source3[%c2] : i32
- %source5 = rocdl.cvt.sr.fp8.f32 %v2, %stoch -> %source4[%c3] : i32
- %source5_scaled = rocdl.cvt.scalef32.sr.fp8.f32 %v2, %stoch, %c4 -> %source4[%c3] : i32
- %source5_scaled_half = rocdl.cvt.scalef32.sr.fp8.f16 %source_half, %stoch, %c4 -> %source4[%c3] : i32
- %source5_scaled_bfloat = rocdl.cvt.scalef32.sr.fp8.bf16 %source_bfloat, %stoch, %c4 -> %source4[%c3] : i32
- %source6 = rocdl.cvt.sr.bf8.f32 %v1, %stoch -> %source3[%c3] : i32
- %source6_scaled = rocdl.cvt.scalef32.sr.bf8.f32 %v2, %stoch, %c4 -> %source3[%c3] : i32
- %source6_scaled_half = rocdl.cvt.scalef32.sr.bf8.f16 %source_half, %stoch, %c4 -> %source3[%c3] : i32
- %source6_scaled_bfloat = rocdl.cvt.scalef32.sr.bf8.bf16 %source_bfloat, %stoch, %c4 -> %source3[%c3] : i32
- %source7_scaled = rocdl.cvt.scalef32.pk.f32.fp8 %source[%false], %c4 : f32
- %source8_scaled = rocdl.cvt.scalef32.pk.f32.bf8 %source[%false], %c4 : f32
+ %v1 = rocdl.cvt.f32.bf8 %source[0] : f32
+ %v2 = rocdl.cvt.f32.fp8 %source[0] : f32
+ %v1_scaled = rocdl.cvt.scalef32.f32.bf8 %source[0], %c4 : f32
+ %v2_scaled = rocdl.cvt.scalef32.f32.fp8 %source[0], %c4 : f32
+ %v3_scaled = rocdl.cvt.scalef32.pk.f16.bf8 %source[false], %c4 : vector<2xf16>
+ %v4_scaled = rocdl.cvt.scalef32.pk.f16.fp8 %source[false], %c4 : vector<2xf16>
+ %v3_scaled_bf16 = rocdl.cvt.scalef32.pk.bf16.bf8 %source[false], %c4 : vector<2xbf16>
+ %v4_scaled_bf16 = rocdl.cvt.scalef32.pk.bf16.fp8 %source[false], %c4 : vector<2xbf16>
+ %v5 = rocdl.cvt.scalef32.f16.fp8 %source[0], %c4 -> %v3_scaled[false] : vector<2xf16>
+ %v6 = rocdl.cvt.scalef32.f16.bf8 %source[0], %c4 -> %v3_scaled[false] : vector<2xf16>
+ %source2 = rocdl.cvt.pk.bf8.f32 %v1, %v2 -> %source[false] : i32
+ %source3 = rocdl.cvt.pk.fp8.f32 %v1, %v2 -> %source2[false] : i32
+ %source2_ext = rocdl.cvt.pk.f32.bf8 %source[false] : vector<2xf32>
+ %source3_ext = rocdl.cvt.pk.f32.fp8 %source[false] : vector<2xf32>
+ %source4 = rocdl.cvt.sr.bf8.f32 %v1, %stoch -> %source3[2] : i32
+ %source5 = rocdl.cvt.sr.fp8.f32 %v2, %stoch -> %source4[3] : i32
+ %source5_scaled = rocdl.cvt.scalef32.sr.fp8.f32 %v2, %stoch, %c4 -> %source4[3] : i32
+ %source5_scaled_half = rocdl.cvt.scalef32.sr.fp8.f16 %source_half, %stoch, %c4 -> %source4[3] : i32
+ %source5_scaled_bfloat = rocdl.cvt.scalef32.sr.fp8.bf16 %source_bfloat, %stoch, %c4 -> %source4[3] : i32
+ %source6 = rocdl.cvt.sr.bf8.f32 %v1, %stoch -> %source3[3] : i32
+ %source6_scaled = rocdl.cvt.scalef32.sr.bf8.f32 %v2, %stoch, %c4 -> %source3[3] : i32
+ %source6_scaled_half = rocdl.cvt.scalef32.sr.bf8.f16 %source_half, %stoch, %c4 -> %source3[3] : i32
+ %source6_scaled_bfloat = rocdl.cvt.scalef32.sr.bf8.bf16 %source_bfloat, %stoch, %c4 -> %source3[3] : i32
+ %source7_scaled = rocdl.cvt.scalef32.pk.f32.fp8 %source[false], %c4 : vector<2xf32>
+ %source8_scaled = rocdl.cvt.scalef32.pk.f32.bf8 %source[false], %c4 : vector<2xf32>
llvm.return %source5 : i32
}
@@ -826,9 +822,8 @@ llvm.func @rocdl_8bit_packed_v2i16(%sourceA: f32, %sourceB: f32, %old: vector<2x
// CHECK-LABEL: @rocdl_8bit_packed_v2i16
// CHECK: rocdl.cvt.scalef32.pk.fp8.f32
%c0 = llvm.mlir.constant(1.0 : f32) : f32
- %false = llvm.mlir.constant(false) : i1
- %source_scaled = rocdl.cvt.scalef32.pk.fp8.f32 %sourceA, %sourceB, %c0 -> %old[%false] : vector<2xi16>
- %source2_scaled = rocdl.cvt.scalef32.pk.bf8.f32 %sourceA, %sourceB, %c0 -> %old[%false] : vector<2xi16>
+ %source_scaled = rocdl.cvt.scalef32.pk.fp8.f32 %sourceA, %sourceB, %c0 -> %old[false] : vector<2xi16>
+ %source2_scaled = rocdl.cvt.scalef32.pk.bf8.f32 %sourceA, %sourceB, %c0 -> %old[false] : vector<2xi16>
llvm.return %source_scaled : vector<2xi16>
}
@@ -836,14 +831,91 @@ llvm.func @rocdl_v2f16_v2i16(%source: vector<2xf16>, %source2: vector<2xbf16>, %
// CHECK-LABEL: @rocdl_v2f16_v2i16
// CHECK: rocdl.cvt.scalef32.pk.fp8.f16
%c0 = llvm.mlir.constant(1.0 : f32) : f32
- %false = llvm.mlir.constant(false) : i1
- %source_scaled = rocdl.cvt.scalef32.pk.fp8.f16 %source, %c0 -> %old[%false] : vector<2xi16>
- %source2_scaled = rocdl.cvt.scalef32.pk.fp8.bf16 %source2, %c0 -> %old[%false] : vector<2xi16>
- %source3_scaled = rocdl.cvt.scalef32.pk.bf8.f16 %source, %c0 -> %old[%false] : vector<2xi16>
- %source4_scaled = rocdl.cvt.scalef32.pk.bf8.bf16 %source2, %c0 -> %old[%false] : vector<2xi16>
+ %source_scaled = rocdl.cvt.scalef32.pk.fp8.f16 %source, %c0 -> %old[false] : vector<2xi16>
+ %source2_scaled = rocdl.cvt.scalef32.pk.fp8.bf16 %source2, %c0 -> %old[false] : vector<2xi16>
+ %source3_scaled = rocdl.cvt.scalef32.pk.bf8.f16 %source, %c0 -> %old[false] : vector<2xi16>
+ %source4_scaled = rocdl.cvt.scalef32.pk.bf8.bf16 %source2, %c0 -> %old[false] : vector<2xi16>
llvm.return %source_scaled : vector<2xi16>
}
+// CHECK-LABEL: @rocdl_6_bit_floats
+// CHECK-SAME: (%[[V32F6:.+]]: vector<6xi32>, %[[V16F32:.+]]: vector<16xf32>, %[[V32F32:.+]]: vector<32xf32>, %[[V32F16:.+]]: vector<32xf16>, %[[V32BF16:.+]]: vector<32xbf16>, %[[SEED:.+]]: i32, %[[SCALE:.+]]: f32)
+llvm.func @rocdl_6_bit_floats(
+ %v32f6: vector<6xi32>, %v16f32: vector<16xf32>, %v32f32: vector<32xf32>,
+ %v32f16: vector<32xf16>, %v32bf16: vector<32xbf16>, %seed: i32,
+ %scale: f32) {
+ // CHECK-NEXT: rocdl.cvt.scalef32.2xpk16.bf6.f32 %[[V16F32]], %[[V16F32]], %[[SCALE]]
+ %f32_to_bf6 = rocdl.cvt.scalef32.2xpk16.bf6.f32 %v16f32, %v16f32, %scale : vector<6xi32>
+ // CHECK-NEXT: rocdl.cvt.scalef32.2xpk16.fp6.f32 %[[V16F32]], %[[V16F32]], %[[SCALE]]
+ %f32_to_fp6 = rocdl.cvt.scalef32.2xpk16.fp6.f32 %v16f32, %v16f32, %scale : vector<6xi32>
+ // CHECK-NEXT: rocdl.cvt.scalef32.pk32.bf6.f16 %[[V32F16]], %[[SCALE]]
+ %f16_to_bf6 = rocdl.cvt.scalef32.pk32.bf6.f16 %v32f16, %scale : vector<6xi32>
+ // CHECK-NEXT: rocdl.cvt.scalef32.pk32.fp6.f16 %[[V32F16]], %[[SCALE]]
+ %f16_to_fp6 = rocdl.cvt.scalef32.pk32.fp6.f16 %v32f16, %scale : vector<6xi32>
+ // CHECK-NEXT: rocdl.cvt.scalef32.pk32.bf6.bf16 %[[V32BF16]], %[[SCALE]]
+ %bf16_to_bf6 = rocdl.cvt.scalef32.pk32.bf6.bf16 %v32bf16, %scale : vector<6xi32>
+ // CHECK-NEXT: rocdl.cvt.scalef32.pk32.fp6.bf16 %[[V32BF16]], %[[SCALE]]
+ %bf16_to_fp6 = rocdl.cvt.scalef32.pk32.fp6.bf16 %v32bf16, %scale : vector<6xi32>
+ // Unpack 32 six-bit values (192 bits, carried in a vector<6xi32>) to 32 wider floats.
+ // CHECK-NEXT: rocdl.cvt.scalef32.pk32.f32.bf6 %[[V32F6]], %[[SCALE]]
+ %bf6_to_f32 = rocdl.cvt.scalef32.pk32.f32.bf6 %v32f6, %scale : vector<32xf32>
+ // CHECK-NEXT: rocdl.cvt.scalef32.pk32.f32.fp6 %[[V32F6]], %[[SCALE]]
+ %fp6_to_f32 = rocdl.cvt.scalef32.pk32.f32.fp6 %v32f6, %scale : vector<32xf32>
+ // CHECK-NEXT: rocdl.cvt.scalef32.pk32.f16.bf6 %[[V32F6]], %[[SCALE]]
+ %bf6_to_f16 = rocdl.cvt.scalef32.pk32.f16.bf6 %v32f6, %scale : vector<32xf16>
+ // CHECK-NEXT: rocdl.cvt.scalef32.pk32.f16.fp6 %[[V32F6]], %[[SCALE]]
+ %fp6_to_f16 = rocdl.cvt.scalef32.pk32.f16.fp6 %v32f6, %scale : vector<32xf16>
+ // CHECK-NEXT: rocdl.cvt.scalef32.pk32.bf16.bf6 %[[V32F6]], %[[SCALE]]
+ %bf6_to_bf16 = rocdl.cvt.scalef32.pk32.bf16.bf6 %v32f6, %scale : vector<32xbf16>
+ // CHECK-NEXT: rocdl.cvt.scalef32.pk32.bf16.fp6 %[[V32F6]], %[[SCALE]]
+ %fp6_to_bf16 = rocdl.cvt.scalef32.pk32.bf16.fp6 %v32f6, %scale : vector<32xbf16>
+ // The `sr` (stochastic-rounding) variants additionally take the i32 %seed operand.
+ // CHECK-NEXT: rocdl.cvt.scalef32.sr.pk32.bf6.f32 %[[V32F32]], %[[SEED]], %[[SCALE]]
+ %f32_to_bf6_sr = rocdl.cvt.scalef32.sr.pk32.bf6.f32 %v32f32, %seed, %scale : vector<6xi32>
+ // CHECK-NEXT: rocdl.cvt.scalef32.sr.pk32.fp6.f32 %[[V32F32]], %[[SEED]], %[[SCALE]]
+ %f32_to_fp6_sr = rocdl.cvt.scalef32.sr.pk32.fp6.f32 %v32f32, %seed, %scale : vector<6xi32>
+ // CHECK-NEXT: rocdl.cvt.scalef32.sr.pk32.bf6.f16 %[[V32F16]], %[[SEED]], %[[SCALE]]
+ %f16_to_bf6_sr = rocdl.cvt.scalef32.sr.pk32.bf6.f16 %v32f16, %seed, %scale : vector<6xi32>
+ // CHECK-NEXT: rocdl.cvt.scalef32.sr.pk32.fp6.f16 %[[V32F16]], %[[SEED]], %[[SCALE]]
+ %f16_to_fp6_sr = rocdl.cvt.scalef32.sr.pk32.fp6.f16 %v32f16, %seed, %scale : vector<6xi32>
+ // CHECK-NEXT: rocdl.cvt.scalef32.sr.pk32.bf6.bf16 %[[V32BF16]], %[[SEED]], %[[SCALE]]
+ %bf16_to_bf6_sr = rocdl.cvt.scalef32.sr.pk32.bf6.bf16 %v32bf16, %seed, %scale : vector<6xi32>
+ // CHECK-NEXT: rocdl.cvt.scalef32.sr.pk32.fp6.bf16 %[[V32BF16]], %[[SEED]], %[[SCALE]]
+ %bf16_to_fp6_sr = rocdl.cvt.scalef32.sr.pk32.fp6.bf16 %v32bf16, %seed, %scale : vector<6xi32>
+
+ llvm.return
+}
+
+// CHECK-LABEL: @rocdl_4_bit_floats
+// CHECK-SAME: (%[[V8F4:.+]]: i32, %[[F32:.+]]: f32, %[[V2F32:.+]]: vector<2xf32>, %[[V2F16:.+]]: vector<2xf16>, %[[V2BF16:.+]]: vector<2xbf16>, %[[SEED:.+]]: i32, %[[SCALE:.+]]: f32)
+llvm.func @rocdl_4_bit_floats(
+ %v8f4: i32, %f32: f32, %v2f32: vector<2xf32>, %v2f16: vector<2xf16>,
+ %v2bf16: vector<2xbf16>, %seed: i32, %scale: f32) {
+ // Packing: two scaled values become two fp4 values that update the byte of %v8f4 selected by [N].
+ // CHECK-NEXT: rocdl.cvt.scalef32.pk.fp4.f32 %[[F32]], %[[F32]], %[[SCALE]] -> %[[V8F4]][0]
+ %f32_to_fp4 = rocdl.cvt.scalef32.pk.fp4.f32 %f32, %f32, %scale -> %v8f4[0] : i32
+ // CHECK-NEXT: rocdl.cvt.scalef32.pk.fp4.f16 %[[V2F16]], %[[SCALE]] -> %[[V8F4]][1]
+ %f16_to_fp4 = rocdl.cvt.scalef32.pk.fp4.f16 %v2f16, %scale -> %v8f4[1] : i32
+ // CHECK-NEXT: rocdl.cvt.scalef32.pk.fp4.bf16 %[[V2BF16]], %[[SCALE]] -> %[[V8F4]][0]
+ %bf16_to_fp4 = rocdl.cvt.scalef32.pk.fp4.bf16 %v2bf16, %scale -> %v8f4[0] : i32
+ // Unpacking: the byte of %v8f4 selected by [N] (two fp4 values) becomes two wider floats.
+ // CHECK-NEXT: rocdl.cvt.scalef32.pk.f32.fp4 %[[V8F4]][0], %[[SCALE]]
+ %fp4_to_f32 = rocdl.cvt.scalef32.pk.f32.fp4 %v8f4[0], %scale : vector<2xf32>
+ // CHECK-NEXT: rocdl.cvt.scalef32.pk.f16.fp4 %[[V8F4]][1], %[[SCALE]]
+ %fp4_to_f16 = rocdl.cvt.scalef32.pk.f16.fp4 %v8f4[1], %scale : vector<2xf16>
+ // CHECK-NEXT: rocdl.cvt.scalef32.pk.bf16.fp4 %[[V8F4]][0], %[[SCALE]]
+ %fp4_to_bf16 = rocdl.cvt.scalef32.pk.bf16.fp4 %v8f4[0], %scale : vector<2xbf16>
+ // The `sr` variants additionally take %seed as the random seed for stochastic rounding.
+ // CHECK-NEXT: rocdl.cvt.scalef32.sr.pk.fp4.f32 %[[V2F32]], %[[SEED]], %[[SCALE]] -> %[[V8F4]][0]
+ %f32_to_fp4_sr = rocdl.cvt.scalef32.sr.pk.fp4.f32 %v2f32, %seed, %scale -> %v8f4[0] : i32
+ // CHECK-NEXT: rocdl.cvt.scalef32.sr.pk.fp4.f16 %[[V2F16]], %[[SEED]], %[[SCALE]] -> %[[V8F4]][1]
+ %f16_to_fp4_sr = rocdl.cvt.scalef32.sr.pk.fp4.f16 %v2f16, %seed, %scale -> %v8f4[1] : i32
+ // CHECK-NEXT: rocdl.cvt.scalef32.sr.pk.fp4.bf16 %[[V2BF16]], %[[SEED]], %[[SCALE]] -> %[[V8F4]][0]
+ %bf16_to_fp4_sr = rocdl.cvt.scalef32.sr.pk.fp4.bf16 %v2bf16, %seed, %scale -> %v8f4[0] : i32
+
+ llvm.return
+}
+
llvm.func @rocdl.s.waitcnt() {
// CHECK-LABEL: rocdl.s.waitcnt
// CHECK: rocdl.s.waitcnt 0
diff --git a/mlir/test/Target/LLVMIR/rocdl.mlir b/mlir/test/Target/LLVMIR/rocdl.mlir
index b37f0da361950..a6a03c586dd25 100644
--- a/mlir/test/Target/LLVMIR/rocdl.mlir
+++ b/mlir/test/Target/LLVMIR/rocdl.mlir
@@ -1081,34 +1081,31 @@ llvm.func @rocdl_8bit_floats(%source: i32, %source_half: f16, %source_bfloat: bf
// CHECK: call i32 @llvm.amdgcn.cvt.scalef32.sr.bf8.bf16(i32 %{{.+}}, bfloat %{{.+}}, i32 %{{.+}}, float 1.000000e+00, i32 3)
// CHECK: call <2 x float> @llvm.amdgcn.cvt.scalef32.pk.f32.fp8(i32 %{{.+}}, float 1.000000e+00, i1 false)
// CHECK: call <2 x float> @llvm.amdgcn.cvt.scalef32.pk.f32.bf8(i32 %{{.+}}, float 1.000000e+00, i1 false)
- %c0 = llvm.mlir.constant(0 : i32) : i32
- %c2 = llvm.mlir.constant(2 : i32) : i32
- %c3 = llvm.mlir.constant(3 : i32) : i32
%c4 = llvm.mlir.constant(1.0 : f32) : f32
%false = llvm.mlir.constant(false) : i1
- %v1 = rocdl.cvt.f32.bf8 %source[%c0] : f32
- %v1_scaled = rocdl.cvt.scalef32.f32.bf8 %source[%c0], %c4 : f32
- %v2 = rocdl.cvt.f32.fp8 %source[%c0] : f32
- %v2_scaled = rocdl.cvt.scalef32.f32.fp8 %source[%c0], %c4 : f32
- %v3_scaled = rocdl.cvt.scalef32.pk.f16.bf8 %source[%false], %c4 : i32
- %v4_scaled = rocdl.cvt.scalef32.pk.f16.fp8 %source[%false], %c4 : i32
- %v5 = rocdl.cvt.scalef32.f16.fp8 %source[%false], %c4 -> %source_packed[%c0] : f16
- %v6 = rocdl.cvt.scalef32.f16.bf8 %source[%false], %c4 -> %source_packed[%c0] : f16
- %v7 = rocdl.cvt.scalef32.pk.bf16.bf8 %source[%false], %c4 : i32
- %v8 = rocdl.cvt.scalef32.pk.bf16.fp8 %source[%false], %c4 : i32
- %source2 = rocdl.cvt.pk.bf8.f32 %v1, %v2 -> %source[%false] : i32
- %source3 = rocdl.cvt.pk.fp8.f32 %v1, %v2 -> %source2[%false] : i32
- %source4 = rocdl.cvt.sr.bf8.f32 %v1, %stoch -> %source3[%c2] : i32
- %source5 = rocdl.cvt.sr.fp8.f32 %v2, %stoch -> %source4[%c3] : i32
- %source5_scaled = rocdl.cvt.scalef32.sr.fp8.f32 %v2, %stoch, %c4 -> %source4[%c3] : i32
- %source5_scaled_half = rocdl.cvt.scalef32.sr.fp8.f16 %source_half, %stoch, %c4 -> %source4[%c3] : i32
- %source5_scaled_bfloat = rocdl.cvt.scalef32.sr.fp8.bf16 %source_bfloat, %stoch, %c4 -> %source4[%c3] : i32
- %source6 = rocdl.cvt.sr.bf8.f32 %v1, %stoch -> %source3[%c3] : i32
- %source6_scaled = rocdl.cvt.scalef32.sr.bf8.f32 %v2, %stoch, %c4 -> %source3[%c3] : i32
- %source6_scaled_half = rocdl.cvt.scalef32.sr.bf8.f16 %source_half, %stoch, %c4 -> %source3[%c3] : i32
- %source6_scaled_bfloat = rocdl.cvt.scalef32.sr.bf8.bf16 %source_bfloat, %stoch, %c4 -> %source3[%c3] : i32
- %source7_scaled = rocdl.cvt.scalef32.pk.f32.fp8 %source[%false], %c4 : f32
- %source8_scaled = rocdl.cvt.scalef32.pk.f32.bf8 %source[%false], %c4 : f32
+ %v1 = rocdl.cvt.f32.bf8 %source[0] : f32
+ %v1_scaled = rocdl.cvt.scalef32.f32.bf8 %source[0], %c4 : f32
+ %v2 = rocdl.cvt.f32.fp8 %source[0] : f32
+ %v2_scaled = rocdl.cvt.scalef32.f32.fp8 %source[0], %c4 : f32
+ %v3_scaled = rocdl.cvt.scalef32.pk.f16.bf8 %source[false], %c4 : vector<2xf16>
+ %v4_scaled = rocdl.cvt.scalef32.pk.f16.fp8 %source[false], %c4 : vector<2xf16>
+ %v5 = rocdl.cvt.scalef32.f16.fp8 %source[0], %c4 -> %source_packed[false] : vector<2xf16>
+ %v6 = rocdl.cvt.scalef32.f16.bf8 %source[0], %c4 -> %source_packed[false] : vector<2xf16>
+ %v7 = rocdl.cvt.scalef32.pk.bf16.bf8 %source[false], %c4 : vector<2xbf16>
+ %v8 = rocdl.cvt.scalef32.pk.bf16.fp8 %source[false], %c4 : vector<2xbf16>
+ %source2 = rocdl.cvt.pk.bf8.f32 %v1, %v2 -> %source[false] : i32
+ %source3 = rocdl.cvt.pk.fp8.f32 %v1, %v2 -> %source2[false] : i32
+ %source4 = rocdl.cvt.sr.bf8.f32 %v1, %stoch -> %source3[2] : i32
+ %source5 = rocdl.cvt.sr.fp8.f32 %v2, %stoch -> %source4[3] : i32
+ %source5_scaled = rocdl.cvt.scalef32.sr.fp8.f32 %v2, %stoch, %c4 -> %source4[3] : i32
+ %source5_scaled_half = rocdl.cvt.scalef32.sr.fp8.f16 %source_half, %stoch, %c4 -> %source4[3] : i32
+ %source5_scaled_bfloat = rocdl.cvt.scalef32.sr.fp8.bf16 %source_bfloat, %stoch, %c4 -> %source4[3] : i32
+ %source6 = rocdl.cvt.sr.bf8.f32 %v1, %stoch -> %source3[3] : i32
+ %source6_scaled = rocdl.cvt.scalef32.sr.bf8.f32 %v2, %stoch, %c4 -> %source3[3] : i32
+ %source6_scaled_half = rocdl.cvt.scalef32.sr.bf8.f16 %source_half, %stoch, %c4 -> %source3[3] : i32
+ %source6_scaled_bfloat = rocdl.cvt.scalef32.sr.bf8.bf16 %source_bfloat, %stoch, %c4 -> %source3[3] : i32
+ %source7_scaled = rocdl.cvt.scalef32.pk.f32.fp8 %source[false], %c4 : vector<2xf32>
+ %source8_scaled = rocdl.cvt.scalef32.pk.f32.bf8 %source[false], %c4 : vector<2xf32>
llvm.return %source5 : i32
}
@@ -1117,9 +1114,8 @@ llvm.func @rocdl_8bit_packed_v2i16(%sourceA: f32, %sourceB: f32, %old: vector<2x
// CHECK: call <2 x i16> @llvm.amdgcn.cvt.scalef32.pk.fp8.f32(<2 x i16> %{{.+}}, float %{{.+}}, float %{{.+}}, float 1.000000e+00, i1 false)
// CHECK: call <2 x i16> @llvm.amdgcn.cvt.scalef32.pk.bf8.f32(<2 x i16> %{{.+}}, float %{{.+}}, float %{{.+}}, float 1.000000e+00, i1 false)
%c0 = llvm.mlir.constant(1.0 : f32) : f32
- %false = llvm.mlir.constant(false) : i1
- %source_scaled = rocdl.cvt.scalef32.pk.fp8.f32 %sourceA, %sourceB, %c0 -> %old[%false] : vector<2xi16>
- %source2_scaled = rocdl.cvt.scalef32.pk.bf8.f32 %sourceA, %sourceB, %c0 -> %old[%false] : vector<2xi16>
+ %source_scaled = rocdl.cvt.scalef32.pk.fp8.f32 %sourceA, %sourceB, %c0 -> %old[false] : vector<2xi16>
+ %source2_scaled = rocdl.cvt.scalef32.pk.bf8.f32 %sourceA, %sourceB, %c0 -> %old[false] : vector<2xi16>
llvm.return %source_scaled : vector<2xi16>
}
@@ -1130,11 +1126,10 @@ llvm.func @rocdl_v2f16_v2i16(%source: vector<2xf16>, %source2: vector<2xbf16>, %
// CHECK: call <2 x i16> @llvm.amdgcn.cvt.scalef32.pk.bf8.f16(<2 x i16> %2, <2 x half> %0, float 1.000000e+00, i1 false)
// CHECK: call <2 x i16> @llvm.amdgcn.cvt.scalef32.pk.bf8.bf16(<2 x i16> %2, <2 x bfloat> %1, float 1.000000e+00, i1 false)
%c0 = llvm.mlir.constant(1.0 : f32) : f32
- %false = llvm.mlir.constant(false) : i1
- %source_scaled = rocdl.cvt.scalef32.pk.fp8.f16 %source, %c0 -> %old[%false] : vector<2xi16>
- %source2_scaled = rocdl.cvt.scalef32.pk.fp8.bf16 %source2, %c0 -> %old[%false] : vector<2xi16>
- %source3_scaled = rocdl.cvt.scalef32.pk.bf8.f16 %source, %c0 -> %old[%false] : vector<2xi16>
- %source4_scaled = rocdl.cvt.scalef32.pk.bf8.bf16 %source2, %c0 -> %old[%false] : vector<2xi16>
+ %source_scaled = rocdl.cvt.scalef32.pk.fp8.f16 %source, %c0 -> %old[false] : vector<2xi16>
+ %source2_scaled = rocdl.cvt.scalef32.pk.fp8.bf16 %source2, %c0 -> %old[false] : vector<2xi16>
+ %source3_scaled = rocdl.cvt.scalef32.pk.bf8.f16 %source, %c0 -> %old[false] : vector<2xi16>
+ %source4_scaled = rocdl.cvt.scalef32.pk.bf8.bf16 %source2, %c0 -> %old[false] : vector<2xi16>
llvm.return %source_scaled : vector<2xi16>
}
@@ -1145,6 +1140,83 @@ llvm.func @rocdl_16bit_packed_floats(%sourceA: f32, %sourceB: f32) -> vector<2xf
llvm.return %source : vector<2xf16>
}
+// CHECK-LABEL: @rocdl_6_bit_floats
+// CHECK-SAME: (<6 x i32> %[[V32F6:.+]], <16 x float> %[[V16F32:.+]], <32 x float> %[[V32F32:.+]], <32 x half> %[[V32F16:.+]], <32 x bfloat> %[[V32BF16:.+]], i32 %[[SEED:.+]], float %[[SCALE:.+]])
+llvm.func @rocdl_6_bit_floats(
+  %v32f6: vector<6xi32>, %v16f32: vector<16xf32>, %v32f32: vector<32xf32>,
+  %v32f16: vector<32xf16>, %v32bf16: vector<32xbf16>, %seed: i32,
+  %scale: f32) {
+  // CHECK-NEXT: call <6 x i32> @llvm.amdgcn.cvt.scalef32.2xpk16.bf6.f32(<16 x float> %[[V16F32]], <16 x float> %[[V16F32]], float %[[SCALE]])
+  %f32_to_bf6 = rocdl.cvt.scalef32.2xpk16.bf6.f32 %v16f32, %v16f32, %scale : vector<6xi32>
+  // CHECK-NEXT: call <6 x i32> @llvm.amdgcn.cvt.scalef32.2xpk16.fp6.f32(<16 x float> %[[V16F32]], <16 x float> %[[V16F32]], float %[[SCALE]])
+  %f32_to_fp6 = rocdl.cvt.scalef32.2xpk16.fp6.f32 %v16f32, %v16f32, %scale : vector<6xi32>
+  // CHECK-NEXT: call <6 x i32> @llvm.amdgcn.cvt.scalef32.pk32.bf6.f16(<32 x half> %[[V32F16]], float %[[SCALE]])
+  %f16_to_bf6 = rocdl.cvt.scalef32.pk32.bf6.f16 %v32f16, %scale : vector<6xi32>
+  // CHECK-NEXT: call <6 x i32> @llvm.amdgcn.cvt.scalef32.pk32.fp6.f16(<32 x half> %[[V32F16]], float %[[SCALE]])
+  %f16_to_fp6 = rocdl.cvt.scalef32.pk32.fp6.f16 %v32f16, %scale : vector<6xi32>
+  // CHECK-NEXT: call <6 x i32> @llvm.amdgcn.cvt.scalef32.pk32.bf6.bf16(<32 x bfloat> %[[V32BF16]], float %[[SCALE]])
+  %bf16_to_bf6 = rocdl.cvt.scalef32.pk32.bf6.bf16 %v32bf16, %scale : vector<6xi32>
+  // CHECK-NEXT: call <6 x i32> @llvm.amdgcn.cvt.scalef32.pk32.fp6.bf16(<32 x bfloat> %[[V32BF16]], float %[[SCALE]])
+  %bf16_to_fp6 = rocdl.cvt.scalef32.pk32.fp6.bf16 %v32bf16, %scale : vector<6xi32>
+  // Above: f32/f16/bf16 -> packed fp6/bf6 (vector<6xi32>). Below: the reverse direction.
+  // CHECK-NEXT: call <32 x float> @llvm.amdgcn.cvt.scalef32.pk32.f32.bf6(<6 x i32> %[[V32F6]], float %[[SCALE]])
+  %bf6_to_f32 = rocdl.cvt.scalef32.pk32.f32.bf6 %v32f6, %scale : vector<32xf32>
+  // CHECK-NEXT: call <32 x float> @llvm.amdgcn.cvt.scalef32.pk32.f32.fp6(<6 x i32> %[[V32F6]], float %[[SCALE]])
+  %fp6_to_f32 = rocdl.cvt.scalef32.pk32.f32.fp6 %v32f6, %scale : vector<32xf32>
+  // CHECK-NEXT: call <32 x half> @llvm.amdgcn.cvt.scalef32.pk32.f16.bf6(<6 x i32> %[[V32F6]], float %[[SCALE]])
+  %bf6_to_f16 = rocdl.cvt.scalef32.pk32.f16.bf6 %v32f6, %scale : vector<32xf16>
+  // CHECK-NEXT: call <32 x half> @llvm.amdgcn.cvt.scalef32.pk32.f16.fp6(<6 x i32> %[[V32F6]], float %[[SCALE]])
+  %fp6_to_f16 = rocdl.cvt.scalef32.pk32.f16.fp6 %v32f6, %scale : vector<32xf16>
+  // CHECK-NEXT: call <32 x bfloat> @llvm.amdgcn.cvt.scalef32.pk32.bf16.bf6(<6 x i32> %[[V32F6]], float %[[SCALE]])
+  %bf6_to_bf16 = rocdl.cvt.scalef32.pk32.bf16.bf6 %v32f6, %scale : vector<32xbf16>
+  // CHECK-NEXT: call <32 x bfloat> @llvm.amdgcn.cvt.scalef32.pk32.bf16.fp6(<6 x i32> %[[V32F6]], float %[[SCALE]])
+  %fp6_to_bf16 = rocdl.cvt.scalef32.pk32.bf16.fp6 %v32f6, %scale : vector<32xbf16>
+  // Seeded "sr" packing variants (presumably stochastic rounding -- TODO confirm); %seed precedes %scale.
+  // CHECK-NEXT: call <6 x i32> @llvm.amdgcn.cvt.scalef32.sr.pk32.bf6.f32(<32 x float> %[[V32F32]], i32 %[[SEED]], float %[[SCALE]])
+  %f32_to_bf6_sr = rocdl.cvt.scalef32.sr.pk32.bf6.f32 %v32f32, %seed, %scale : vector<6xi32>
+  // CHECK-NEXT: call <6 x i32> @llvm.amdgcn.cvt.scalef32.sr.pk32.fp6.f32(<32 x float> %[[V32F32]], i32 %[[SEED]], float %[[SCALE]])
+  %f32_to_fp6_sr = rocdl.cvt.scalef32.sr.pk32.fp6.f32 %v32f32, %seed, %scale : vector<6xi32>
+  // CHECK-NEXT: call <6 x i32> @llvm.amdgcn.cvt.scalef32.sr.pk32.bf6.f16(<32 x half> %[[V32F16]], i32 %[[SEED]], float %[[SCALE]])
+  %f16_to_bf6_sr = rocdl.cvt.scalef32.sr.pk32.bf6.f16 %v32f16, %seed, %scale : vector<6xi32>
+  // CHECK-NEXT: call <6 x i32> @llvm.amdgcn.cvt.scalef32.sr.pk32.fp6.f16(<32 x half> %[[V32F16]], i32 %[[SEED]], float %[[SCALE]])
+  %f16_to_fp6_sr = rocdl.cvt.scalef32.sr.pk32.fp6.f16 %v32f16, %seed, %scale : vector<6xi32>
+  // CHECK-NEXT: call <6 x i32> @llvm.amdgcn.cvt.scalef32.sr.pk32.bf6.bf16(<32 x bfloat> %[[V32BF16]], i32 %[[SEED]], float %[[SCALE]])
+  %bf16_to_bf6_sr = rocdl.cvt.scalef32.sr.pk32.bf6.bf16 %v32bf16, %seed, %scale : vector<6xi32>
+  // CHECK-NEXT: call <6 x i32> @llvm.amdgcn.cvt.scalef32.sr.pk32.fp6.bf16(<32 x bfloat> %[[V32BF16]], i32 %[[SEED]], float %[[SCALE]])
+  %bf16_to_fp6_sr = rocdl.cvt.scalef32.sr.pk32.fp6.bf16 %v32bf16, %seed, %scale : vector<6xi32>
+  // Results are left unused; only the emitted intrinsic calls are verified.
+  llvm.return
+}
+
+// CHECK-LABEL: @rocdl_4_bit_floats
+// CHECK-SAME: (i32 %[[V8F4:.+]], float %[[F32:.+]], <2 x float> %[[V2F32:.+]], <2 x half> %[[V2F16:.+]], <2 x bfloat> %[[V2BF16:.+]], i32 %[[SEED:.+]], float %[[SCALE:.+]])
+llvm.func @rocdl_4_bit_floats(
+  %v8f4: i32, %f32: f32, %v2f32: vector<2xf32>, %v2f16: vector<2xf16>,
+  %v2bf16: vector<2xbf16>, %seed: i32, %scale: f32) {
+  // The bracketed index on %v8f4 ([0] / [1]) lowers to the trailing i32 immediate of each call.
+  // CHECK-NEXT: call i32 @llvm.amdgcn.cvt.scalef32.pk.fp4.f32(i32 %[[V8F4]], float %[[F32]], float %[[F32]], float %[[SCALE]], i32 0)
+  %f32_to_fp4 = rocdl.cvt.scalef32.pk.fp4.f32 %f32, %f32, %scale -> %v8f4[0] : i32
+  // CHECK-NEXT: call i32 @llvm.amdgcn.cvt.scalef32.pk.fp4.f16(i32 %[[V8F4]], <2 x half> %[[V2F16]], float %[[SCALE]], i32 1)
+  %f16_to_fp4 = rocdl.cvt.scalef32.pk.fp4.f16 %v2f16, %scale -> %v8f4[1] : i32
+  // CHECK-NEXT: call i32 @llvm.amdgcn.cvt.scalef32.pk.fp4.bf16(i32 %[[V8F4]], <2 x bfloat> %[[V2BF16]], float %[[SCALE]], i32 0)
+  %bf16_to_fp4 = rocdl.cvt.scalef32.pk.fp4.bf16 %v2bf16, %scale -> %v8f4[0] : i32
+  // Unpacking fp4 from %v8f4 into vector<2xf32/f16/bf16>; the index is again the trailing immediate.
+  // CHECK-NEXT: call <2 x float> @llvm.amdgcn.cvt.scalef32.pk.f32.fp4(i32 %[[V8F4]], float %[[SCALE]], i32 0)
+  %fp4_to_f32 = rocdl.cvt.scalef32.pk.f32.fp4 %v8f4[0], %scale : vector<2xf32>
+  // CHECK-NEXT: call <2 x half> @llvm.amdgcn.cvt.scalef32.pk.f16.fp4(i32 %[[V8F4]], float %[[SCALE]], i32 1)
+  %fp4_to_f16 = rocdl.cvt.scalef32.pk.f16.fp4 %v8f4[1], %scale : vector<2xf16>
+  // CHECK-NEXT: call <2 x bfloat> @llvm.amdgcn.cvt.scalef32.pk.bf16.fp4(i32 %[[V8F4]], float %[[SCALE]], i32 0)
+  %fp4_to_bf16 = rocdl.cvt.scalef32.pk.bf16.fp4 %v8f4[0], %scale : vector<2xbf16>
+  // Seeded "sr" packing variants; %seed is passed between the source operand and %scale.
+  // CHECK-NEXT: call i32 @llvm.amdgcn.cvt.scalef32.sr.pk.fp4.f32(i32 %[[V8F4]], <2 x float> %[[V2F32]], i32 %[[SEED]], float %[[SCALE]], i32 0)
+  %f32_to_fp4_sr = rocdl.cvt.scalef32.sr.pk.fp4.f32 %v2f32, %seed, %scale -> %v8f4[0] : i32
+  // CHECK-NEXT: call i32 @llvm.amdgcn.cvt.scalef32.sr.pk.fp4.f16(i32 %[[V8F4]], <2 x half> %[[V2F16]], i32 %[[SEED]], float %[[SCALE]], i32 1)
+  %f16_to_fp4_sr = rocdl.cvt.scalef32.sr.pk.fp4.f16 %v2f16, %seed, %scale -> %v8f4[1] : i32
+  // CHECK-NEXT: call i32 @llvm.amdgcn.cvt.scalef32.sr.pk.fp4.bf16(i32 %[[V8F4]], <2 x bfloat> %[[V2BF16]], i32 %[[SEED]], float %[[SCALE]], i32 0)
+  %bf16_to_fp4_sr = rocdl.cvt.scalef32.sr.pk.fp4.bf16 %v2bf16, %seed, %scale -> %v8f4[0] : i32
+  // Results are left unused; only the emitted intrinsic calls are verified.
+  llvm.return
+}
llvm.func @rocdl_atomic_attrs(%ptr: !llvm.ptr<1>, %data: f32) {
// CHECK-LABEL: @rocdl_atomic_attrs
// CHECK: atomicrmw
More information about the Mlir-commits
mailing list