[Mlir-commits] [mlir] 2ebd633 - [mlir][AMDGPU] Add packed 8-bit float conversion ops and lowering

Krzysztof Drewniak llvmlistbot at llvm.org
Thu Sep 28 07:44:22 PDT 2023


Author: Krzysztof Drewniak
Date: 2023-09-28T14:44:16Z
New Revision: 2ebd633f145615a42d7e8b1d07cbdad294c244aa

URL: https://github.com/llvm/llvm-project/commit/2ebd633f145615a42d7e8b1d07cbdad294c244aa
DIFF: https://github.com/llvm/llvm-project/commit/2ebd633f145615a42d7e8b1d07cbdad294c244aa.diff

LOG: [mlir][AMDGPU] Add packed 8-bit float conversion ops and lowering

Define operations that wrap gfx940's new instructions for converting
between f32 and registers containing packed sets of four 8-bit floats.

Define rocdl operations for the intrinsics and an AMDGPU dialect
wrapper around them (to account for the fact that MLIR distinguishes
the two float formats at the type level but that the LLVM IR does
not).

Define an ArithToAMDGPU pass, meant to run before conversion to LLVM,
that replaces relevant calls to arith.extf and arith.truncf with the
packed operations in the AMDGPU dialect. Note that the conversion
currently only handles scalars and vectors of rank <= 1, as we do not
have a use case for multi-dimensional vector support right now.

Reviewed By: jsjodin

Differential Revision: https://reviews.llvm.org/D152457

Added: 
    mlir/include/mlir/Conversion/ArithToAMDGPU/ArithToAMDGPU.h
    mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp
    mlir/lib/Conversion/ArithToAMDGPU/CMakeLists.txt
    mlir/test/Conversion/AMDGPUToROCDL/8-bit-floats.mlir
    mlir/test/Conversion/ArithToAMDGPU/8-bit-floats.mlir

Modified: 
    mlir/include/mlir/Conversion/Passes.h
    mlir/include/mlir/Conversion/Passes.td
    mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
    mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td
    mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
    mlir/lib/Conversion/CMakeLists.txt
    mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
    mlir/test/Dialect/AMDGPU/invalid.mlir
    mlir/test/Dialect/AMDGPU/ops.mlir
    mlir/test/Dialect/LLVMIR/rocdl.mlir
    mlir/test/Target/LLVMIR/rocdl.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Conversion/ArithToAMDGPU/ArithToAMDGPU.h b/mlir/include/mlir/Conversion/ArithToAMDGPU/ArithToAMDGPU.h
new file mode 100644
index 000000000000000..7f445fee5ba6b82
--- /dev/null
+++ b/mlir/include/mlir/Conversion/ArithToAMDGPU/ArithToAMDGPU.h
@@ -0,0 +1,27 @@
+//===- ArithToAMDGPU.h - Arith to AMDGPU dialect conversion ---*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_CONVERSION_ARITHTOAMDGPU_ARITHTOAMDGPU_H
+#define MLIR_CONVERSION_ARITHTOAMDGPU_ARITHTOAMDGPU_H
+
+#include <memory>
+
+namespace mlir {
+
+class RewritePatternSet;
+class Pass;
+
+// Pulls in the generated declarations for ArithToAMDGPUConversionPass,
+// which is defined in Passes.td.
+#define GEN_PASS_DECL_ARITHTOAMDGPUCONVERSIONPASS
+#include "mlir/Conversion/Passes.h.inc"
+
+namespace arith {
+/// Add patterns that rewrite `arith.extf` from, and `arith.truncf` to,
+/// f8E5M2FNUZ / f8E4M3FNUZ values (scalars and fixed-size vectors of rank
+/// <= 1) into the packed 8-bit float operations of the `amdgpu` dialect.
+void populateArithToAMDGPUConversionPatterns(RewritePatternSet &patterns);
+} // namespace arith
+} // namespace mlir
+
+#endif // MLIR_CONVERSION_ARITHTOAMDGPU_ARITHTOAMDGPU_H

diff  --git a/mlir/include/mlir/Conversion/Passes.h b/mlir/include/mlir/Conversion/Passes.h
index 41806004fc1dca8..e714f5070f23db8 100644
--- a/mlir/include/mlir/Conversion/Passes.h
+++ b/mlir/include/mlir/Conversion/Passes.h
@@ -11,6 +11,7 @@
 
 #include "mlir/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.h"
 #include "mlir/Conversion/AffineToStandard/AffineToStandard.h"
+#include "mlir/Conversion/ArithToAMDGPU/ArithToAMDGPU.h"
 #include "mlir/Conversion/ArithToLLVM/ArithToLLVM.h"
 #include "mlir/Conversion/ArithToSPIRV/ArithToSPIRV.h"
 #include "mlir/Conversion/ArmNeon2dToIntr/ArmNeon2dToIntr.h"

diff  --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td
index 9b7848d9288be43..38b05c792d405ad 100644
--- a/mlir/include/mlir/Conversion/Passes.td
+++ b/mlir/include/mlir/Conversion/Passes.td
@@ -112,6 +112,21 @@ def ConvertAMDGPUToROCDL : Pass<"convert-amdgpu-to-rocdl"> {
                         "Chipset that these operations will run on">];
 }
 
+//===----------------------------------------------------------------------===//
+// ArithToAMDGPU
+//===----------------------------------------------------------------------===//
+def ArithToAMDGPUConversionPass : Pass<"convert-arith-to-amdgpu"> {
+  let summary = "Convert Arith operations to AMDGPU-specific implementations";
+  let description = [{
+    Convert `arith` operations (currently extf and truncf on 8-bit floats)
+    to operations in the `amdgpu` dialect. This pass is meant to run before
+    the conversion to LLVM, so that the lowering happens in two steps instead
+    of running a notional arith-to-rocdl conversion simultaneously with
+    arith-to-llvm.
+  }];
+
+  let dependentDialects = ["amdgpu::AMDGPUDialect", "vector::VectorDialect"];
+}
+
 //===----------------------------------------------------------------------===//
 // ArithToLLVM
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td b/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
index 6d788e3a970108e..ffb302fcedd732c 100644
--- a/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
+++ b/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
@@ -38,6 +38,85 @@ def AMDGPU_Dialect : Dialect {
 class AMDGPU_Op<string mnemonic, list<Trait> traits = []> :
   Op<AMDGPU_Dialect, mnemonic, traits> {}
 
+def AMDGPU_ExtPackedFp8Op :
+    AMDGPU_Op<"ext_packed_fp8", [Pure]>,
+    Arguments<(ins AnyTypeOf<[F8E5M2FNUZ, F8E4M3FNUZ,
+        VectorOfLengthAndType<[1, 2, 3, 4], [F8E5M2FNUZ, F8E4M3FNUZ]>]>:$source,
+      ConfinedAttr<I32Attr, [IntNonNegative, IntMaxValue<3>]>:$index)>,
+    Results<(outs F32:$res)> {
+  let summary = "Extend one of a vector of packed fp8 values to a float";
+  let description = [{
+    Extend the value `source[index]` to a 32-bit float and return it.
+
+    This rather unusual signature arises from the fact that AMD GPUs cannot
+    easily work with sub-32-bit quantities, so the compiler intrinsics for
+    extending 8-bit floats (which are, currently, the only way to work with
+    these types) take packed vectors of 4 such floats.
+
+    If the passed-in vector has fewer than four elements, or the input is scalar,
+    the remaining values in the <4 x i8> will be filled with
+    undefined values as needed.
+  }];
+  let assemblyFormat = [{
+    attr-dict $source `[` $index `]` `:` type($source) `to` type($res)
+  }];
+}
+
+def AMDGPU_PackedTrunc2xFp8Op :
+    AMDGPU_Op<"packed_trunc_2xfp8", [Pure, AttrSizedOperandSegments]>,
+    Arguments<(ins F32:$sourceA,
+      Optional<F32>:$sourceB,
+      ConfinedAttr<I32Attr, [IntNonNegative, IntMaxValue<1>]>:$wordIndex,
+      Optional<FixedVectorOfLengthAndType<[4], [F8E4M3FNUZ, F8E5M2FNUZ]>>:$existing)>,
+    Results<(outs FixedVectorOfLengthAndType<[4], [F8E4M3FNUZ, F8E5M2FNUZ]>:$res)> {
+  let summary = "Round two floats into a packed vector of 8-bit floats";
+  let description = [{
+    Round the inputs `sourceA` and `sourceB` (which is undefined if not
+    specified) into the low or high word (bottom two or top two) elements
+    of the returned vector, keeping the other two elements of `existing`
+    unchanged if present (or undefined if it was not passed in).
+
+    The reason for this odd signature is that AMD GPUs cannot easily work with
+    sub-registers, and so the conversion intrinsics (which are currently the
+    only way to work with 8-bit float types) take packed vectors of 4 8-bit
+    values.
+  }];
+  let assemblyFormat = [{
+    attr-dict $sourceA `,` ($sourceB^):(`undef`)?
+    `into` ($existing^):(`undef`)? `[` `word` $wordIndex `]`
+    `:` type($sourceA) `to` type($res) (`into` type($existing)^)?
+  }];
+  let hasVerifier = 1;
+}
+
+def AMDGPU_PackedStochRoundFp8Op :
+    AMDGPU_Op<"packed_stoch_round_fp8", [Pure]>,
+    Arguments<(ins F32:$source,
+      I32:$stochiasticParam,
+      ConfinedAttr<I32Attr, [IntNonNegative, IntMaxValue<3>]>:$storeIndex,
+      Optional<FixedVectorOfLengthAndType<[4], [F8E4M3FNUZ, F8E5M2FNUZ]>>:$existing)>,
+    Results<(outs FixedVectorOfLengthAndType<[4], [F8E4M3FNUZ, F8E5M2FNUZ]>:$res)> {
+  let summary = "Round float stochastically into a packed vector of 8-bit floats";
+  let description = [{
+    Round the input `source`, adding in `stochiasticParam`, and place it into
+    the `storeIndex`th element of `res`.
+
+    If `existing` is passed in, elements of `res` other than the one at `storeIndex`
+    are copied from `existing`.
+
+    The reason for this odd signature is that AMD GPUs cannot easily work with
+    sub-registers, and so the conversion intrinsics (which are currently the
+    only way to work with 8-bit float types) take packed vectors of 4 8-bit
+    values.
+  }];
+  let assemblyFormat = [{
+    attr-dict $source `+` $stochiasticParam
+    `into` ($existing^):(`undef`)? `[` $storeIndex `]`
+    `:` type($source) `to` type($res) (`into` type($existing)^)?
+  }];
+  let hasVerifier = 1;
+}
+
 /// Raw buffer load
 def AMDGPU_RawBufferLoadOp :
     AMDGPU_Op<"raw_buffer_load", [AllElementTypesMatch<["value", "memref"]>,

diff  --git a/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td b/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td
index 08d36397dc31355..6c6419bf238b457 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td
@@ -116,7 +116,7 @@ class ROCDL_MbcntOp<string mnemonic> :
 def ROCDL_MbcntLoOp : ROCDL_MbcntOp<"lo">;
 def ROCDL_MbcntHiOp : ROCDL_MbcntOp<"hi">;
 
-def ROCDL_DsSwizzleOp : 
+def ROCDL_DsSwizzleOp :
 ROCDL_Op<"ds_swizzle">,
 Results<(outs I32:$res)>,
 Arguments<(ins I32:$src,
@@ -130,7 +130,7 @@ Arguments<(ins I32:$src,
    }];
 }
 
-def ROCDL_DsBpermuteOp : 
+def ROCDL_DsBpermuteOp :
 ROCDL_Op<"ds_bpermute">,
 Results<(outs I32:$res)>,
 Arguments<(ins I32:$index,
@@ -525,6 +525,85 @@ def ROCDL_RawBufferAtomicUMinOp :
   let hasCustomAssemblyFormat = 1;
 }
 
+//===---------------------------------------------------------------------===//
+// 8-bit float intrinsics
+//===---------------------------------------------------------------------===//
+// These operate on four 8-bit floats packed into a single i32 register.
+def ROCDL_CvtF32Bf8Op :
+    ROCDL_IntrOp<"cvt.f32.bf8", [], [], [Pure], 1>,
+    Arguments<(ins I32:$srcA, I32:$byteSel)> {
+  let summary = "Convert bf8 to f32";
+  let description = [{
+    Convert 8-bit bf8 value from the `byteSel`th byte of `srcA` to fp32.
+  }];
+  let assemblyFormat = [{
+    attr-dict $srcA `[` $byteSel `]` `:` type($res)
+  }];
+}
+
+def ROCDL_CvtF32Fp8Op :
+    ROCDL_IntrOp<"cvt.f32.fp8", [], [], [Pure], 1>,
+    Arguments<(ins I32:$srcA, I32:$byteSel)> {
+  let summary = "Convert fp8 to f32";
+  let description = [{
+    Convert 8-bit fp8 value from the `byteSel`th byte of `srcA` to fp32.
+  }];
+  let assemblyFormat = [{
+    attr-dict $srcA `[` $byteSel `]` `:` type($res)
+  }];
+}
+
+def ROCDL_CvtPkBf8F32Op :
+    ROCDL_IntrOp<"cvt.pk.bf8.f32", [], [], [Pure], 1>,
+    Arguments<(ins F32:$srcA, F32:$srcB, I32:$old, I1:$wordSel)> {
+  let summary = "Convert two f32 values to bf8";
+  let description = [{
+    Convert `srcA` and `srcB` to bf8 and store into the low/high word of
+    `old`, preserving the other word.
+  }];
+  let assemblyFormat = [{
+    attr-dict $srcA `,` $srcB `->` $old `[` $wordSel `]` `:` type($res)
+  }];
+}
+
+def ROCDL_CvtPkFp8F32Op :
+    ROCDL_IntrOp<"cvt.pk.fp8.f32", [], [], [Pure], 1>,
+    Arguments<(ins F32:$srcA, F32:$srcB, I32:$old, I1:$wordSel)> {
+  let summary = "Convert two f32 values to fp8";
+  let description = [{
+    Convert `srcA` and `srcB` to fp8 and store into the low/high word of
+    `old`, preserving the other word.
+  }];
+  let assemblyFormat = [{
+    attr-dict $srcA `,` $srcB `->` $old `[` $wordSel `]` `:` type($res)
+  }];
+}
+
+def ROCDL_CvtSrBf8F32Op :
+    ROCDL_IntrOp<"cvt.sr.bf8.f32", [], [], [Pure], 1>,
+    Arguments<(ins F32:$srcA, I32:$srcB, I32:$old, I32:$byteSel)> {
+  let summary = "Convert f32 to bf8, stochastic rounding";
+  let description = [{
+    Convert `srcA` to bf8, adding the rounding factor from `srcB`,
+    and store into the `byteSel`th byte of `old`, preserving the others.
+  }];
+  let assemblyFormat = [{
+    attr-dict $srcA `,` $srcB `->` $old `[` $byteSel `]` `:` type($res)
+  }];
+}
+
+def ROCDL_CvtSrFp8F32Op :
+    ROCDL_IntrOp<"cvt.sr.fp8.f32", [], [], [Pure], 1>,
+    Arguments<(ins F32:$srcA, I32:$srcB, I32:$old, I32:$byteSel)> {
+  let summary = "Convert f32 to fp8, stochastic rounding";
+  let description = [{
+    Convert `srcA` to fp8, adding the rounding factor from `srcB`,
+    and store into the `byteSel`th byte of `old`, preserving the others.
+  }];
+  let assemblyFormat = [{
+    attr-dict $srcA `,` $srcB `->` $old `[` $byteSel `]` `:` type($res)
+  }];
+}
+
 //===----------------------------------------------------------------------===//
 // ROCDL target attribute.
 //===----------------------------------------------------------------------===//
@@ -612,5 +691,4 @@ def ROCDL_TargettAttr :
     }
   }];
 }
-
 #endif // ROCDLIR_OPS

diff  --git a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
index eeed04049668420..9ed312cef744dca 100644
--- a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
+++ b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
@@ -10,6 +10,7 @@
 
 #include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
 #include "mlir/Conversion/LLVMCommon/Pattern.h"
+#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
 #include "mlir/Dialect/AMDGPU/IR/AMDGPUDialect.h"
 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
 #include "mlir/Dialect/LLVMIR/ROCDLDialect.h"
@@ -639,6 +640,161 @@ struct WMMAOpLowering : public ConvertOpToLLVMPattern<WMMAOp> {
   }
 };
 
+namespace {
+/// Lowers amdgpu.ext_packed_fp8 to rocdl.cvt.f32.{bf8,fp8}.
+struct ExtPackedFp8OpLowering final
+    : public ConvertOpToLLVMPattern<ExtPackedFp8Op> {
+  /// Target chipset; matchAndRewrite rejects unsupported ones.
+  Chipset chipset;
+
+  ExtPackedFp8OpLowering(LLVMTypeConverter &converter, Chipset chipset)
+      : ConvertOpToLLVMPattern<ExtPackedFp8Op>(converter), chipset(chipset) {}
+
+  LogicalResult
+  matchAndRewrite(ExtPackedFp8Op op, ExtPackedFp8OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override;
+};
+
+/// Lowers amdgpu.packed_trunc_2xfp8 to rocdl.cvt.pk.{bf8,fp8}.f32.
+struct PackedTrunc2xFp8OpLowering final
+    : public ConvertOpToLLVMPattern<PackedTrunc2xFp8Op> {
+  /// Target chipset; matchAndRewrite rejects unsupported ones.
+  Chipset chipset;
+
+  PackedTrunc2xFp8OpLowering(LLVMTypeConverter &converter, Chipset chipset)
+      : ConvertOpToLLVMPattern<PackedTrunc2xFp8Op>(converter),
+        chipset(chipset) {}
+
+  LogicalResult
+  matchAndRewrite(PackedTrunc2xFp8Op op, PackedTrunc2xFp8OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override;
+};
+
+/// Lowers amdgpu.packed_stoch_round_fp8 to rocdl.cvt.sr.{bf8,fp8}.f32.
+struct PackedStochRoundFp8OpLowering final
+    : public ConvertOpToLLVMPattern<PackedStochRoundFp8Op> {
+  /// Target chipset; matchAndRewrite rejects unsupported ones.
+  Chipset chipset;
+
+  PackedStochRoundFp8OpLowering(LLVMTypeConverter &converter, Chipset chipset)
+      : ConvertOpToLLVMPattern<PackedStochRoundFp8Op>(converter),
+        chipset(chipset) {}
+
+  LogicalResult
+  matchAndRewrite(PackedStochRoundFp8Op op,
+                  PackedStochRoundFp8OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override;
+};
+} // end namespace
+
+/// Lower amdgpu.ext_packed_fp8 by widening the source to a <4 x i8>,
+/// bitcasting it to i32 and calling the bf8 or fp8 extension intrinsic on
+/// byte `index`. Fails on unsupported chipsets or element types.
+LogicalResult ExtPackedFp8OpLowering::matchAndRewrite(
+    ExtPackedFp8Op op, ExtPackedFp8OpAdaptor adaptor,
+    ConversionPatternRewriter &rewriter) const {
+  Location loc = op.getLoc();
+  // Only gfx9 chipsets with minor version >= 0x40 (e.g. gfx940) are handled.
+  if (chipset.majorVersion != 9 || chipset.minorVersion < 0x40)
+    return rewriter.notifyMatchFailure(
+        loc, "Fp8 conversion instructions are not available on target "
+             "architecture and their emulation is not implemented");
+  // Decide on the intrinsic before creating any IR, so that an unsupported
+  // element type fails cleanly instead of returning success() with the op
+  // left unreplaced.
+  Type sourceElemType = getElementTypeOrSelf(op.getSource());
+  bool isBf8 = sourceElemType.isFloat8E5M2FNUZ();
+  if (!isBf8 && !sourceElemType.isFloat8E4M3FNUZ())
+    return rewriter.notifyMatchFailure(loc, "unsupported 8-bit float type");
+
+  Type v4i8 =
+      getTypeConverter()->convertType(VectorType::get(4, rewriter.getI8Type()));
+  Type i32 = getTypeConverter()->convertType(rewriter.getI32Type());
+  Type f32 = getTypeConverter()->convertType(op.getResult().getType());
+
+  Value source = adaptor.getSource();
+  auto sourceVecType = op.getSource().getType().dyn_cast<VectorType>();
+  // The intrinsics take an i32 operand, so widen scalars and short vectors
+  // to a <4 x i8>; lanes that are not written stay undefined.
+  if (!sourceVecType || sourceVecType.getNumElements() < 4) {
+    Value longVec = rewriter.create<LLVM::UndefOp>(loc, v4i8);
+    if (!sourceVecType) {
+      longVec = rewriter.create<LLVM::InsertElementOp>(
+          loc, longVec, source, createI32Constant(rewriter, loc, 0));
+    } else {
+      for (int32_t i = 0, e = sourceVecType.getNumElements(); i < e; ++i) {
+        Value idx = createI32Constant(rewriter, loc, i);
+        Value elem = rewriter.create<LLVM::ExtractElementOp>(loc, source, idx);
+        longVec =
+            rewriter.create<LLVM::InsertElementOp>(loc, longVec, elem, idx);
+      }
+    }
+    source = longVec;
+  }
+  Value i32Source = rewriter.create<LLVM::BitcastOp>(loc, i32, source);
+  // `index` selects which byte of the packed i32 is converted.
+  Value byteSel = createI32Constant(rewriter, loc, op.getIndex());
+  if (isBf8)
+    rewriter.replaceOpWithNewOp<ROCDL::CvtF32Bf8Op>(op, f32, i32Source,
+                                                    byteSel);
+  else
+    rewriter.replaceOpWithNewOp<ROCDL::CvtF32Fp8Op>(op, f32, i32Source,
+                                                    byteSel);
+  return success();
+}
+
+/// Lower amdgpu.packed_trunc_2xfp8 to rocdl.cvt.pk.{bf8,fp8}.f32 operating
+/// on the packed result viewed as an i32. Fails on unsupported chipsets or
+/// element types.
+LogicalResult PackedTrunc2xFp8OpLowering::matchAndRewrite(
+    PackedTrunc2xFp8Op op, PackedTrunc2xFp8OpAdaptor adaptor,
+    ConversionPatternRewriter &rewriter) const {
+  Location loc = op.getLoc();
+  // Only gfx9 chipsets with minor version >= 0x40 (e.g. gfx940) are handled.
+  if (chipset.majorVersion != 9 || chipset.minorVersion < 0x40)
+    return rewriter.notifyMatchFailure(
+        loc, "Fp8 conversion instructions are not available on target "
+             "architecture and their emulation is not implemented");
+  Type resultType = op.getResult().getType();
+  Type resultElemType = getElementTypeOrSelf(resultType);
+  // Validate the element type before creating any IR, so `result` below can
+  // never be left null and handed to the final bitcast.
+  bool isBf8 = resultElemType.isFloat8E5M2FNUZ();
+  if (!isBf8 && !resultElemType.isFloat8E4M3FNUZ())
+    return rewriter.notifyMatchFailure(loc, "unsupported 8-bit float type");
+
+  Type i32 = getTypeConverter()->convertType(rewriter.getI32Type());
+
+  Value sourceA = adaptor.getSourceA();
+  Value sourceB = adaptor.getSourceB();
+  // A missing second operand is undefined, per the op's semantics.
+  if (!sourceB)
+    sourceB = rewriter.create<LLVM::UndefOp>(loc, sourceA.getType());
+  // The intrinsic updates a packed i32; without `existing` the word that is
+  // not written is undefined.
+  Value existing = adaptor.getExisting();
+  if (existing)
+    existing = rewriter.create<LLVM::BitcastOp>(loc, i32, existing);
+  else
+    existing = rewriter.create<LLVM::UndefOp>(loc, i32);
+  // wordIndex is 0 (low two bytes) or 1 (high two bytes).
+  Value wordSel = createI1Constant(rewriter, loc, op.getWordIndex());
+
+  Value result;
+  if (isBf8)
+    result = rewriter.create<ROCDL::CvtPkBf8F32Op>(loc, i32, sourceA, sourceB,
+                                                   existing, wordSel);
+  else
+    result = rewriter.create<ROCDL::CvtPkFp8F32Op>(loc, i32, sourceA, sourceB,
+                                                   existing, wordSel);
+
+  rewriter.replaceOpWithNewOp<LLVM::BitcastOp>(
+      op, getTypeConverter()->convertType(resultType), result);
+  return success();
+}
+
+/// Lower amdgpu.packed_stoch_round_fp8 to rocdl.cvt.sr.{bf8,fp8}.f32
+/// operating on the packed result viewed as an i32. Fails on unsupported
+/// chipsets or element types.
+LogicalResult PackedStochRoundFp8OpLowering::matchAndRewrite(
+    PackedStochRoundFp8Op op, PackedStochRoundFp8OpAdaptor adaptor,
+    ConversionPatternRewriter &rewriter) const {
+  Location loc = op.getLoc();
+  // Only gfx9 chipsets with minor version >= 0x40 (e.g. gfx940) are handled.
+  if (chipset.majorVersion != 9 || chipset.minorVersion < 0x40)
+    return rewriter.notifyMatchFailure(
+        loc, "Fp8 conversion instructions are not available on target "
+             "architecture and their emulation is not implemented");
+  Type resultType = op.getResult().getType();
+  Type resultElemType = getElementTypeOrSelf(resultType);
+  // Validate the element type before creating any IR, so `result` below can
+  // never be left null and handed to the final bitcast.
+  bool isBf8 = resultElemType.isFloat8E5M2FNUZ();
+  if (!isBf8 && !resultElemType.isFloat8E4M3FNUZ())
+    return rewriter.notifyMatchFailure(loc, "unsupported 8-bit float type");
+
+  Type i32 = getTypeConverter()->convertType(rewriter.getI32Type());
+
+  Value source = adaptor.getSource();
+  Value stoch = adaptor.getStochiasticParam();
+  // The intrinsic updates a packed i32; without `existing` the bytes that
+  // are not written are undefined.
+  Value existing = adaptor.getExisting();
+  if (existing)
+    existing = rewriter.create<LLVM::BitcastOp>(loc, i32, existing);
+  else
+    existing = rewriter.create<LLVM::UndefOp>(loc, i32);
+  // storeIndex (0-3) selects the byte that receives the rounded value.
+  Value byteSel = createI32Constant(rewriter, loc, op.getStoreIndex());
+
+  Value result;
+  if (isBf8)
+    result = rewriter.create<ROCDL::CvtSrBf8F32Op>(loc, i32, source, stoch,
+                                                   existing, byteSel);
+  else
+    result = rewriter.create<ROCDL::CvtSrFp8F32Op>(loc, i32, source, stoch,
+                                                   existing, byteSel);
+
+  rewriter.replaceOpWithNewOp<LLVM::BitcastOp>(
+      op, getTypeConverter()->convertType(resultType), result);
+  return success();
+}
+
 struct ConvertAMDGPUToROCDLPass
     : public impl::ConvertAMDGPUToROCDLBase<ConvertAMDGPUToROCDLPass> {
   ConvertAMDGPUToROCDLPass() = default;
@@ -691,7 +847,9 @@ void mlir::populateAMDGPUToROCDLConversionPatterns(LLVMTypeConverter &converter,
                                ROCDL::RawPtrBufferAtomicUminOp>,
            RawBufferOpLowering<RawBufferAtomicCmpswapOp,
                                ROCDL::RawPtrBufferAtomicCmpSwap>,
-           MFMAOpLowering, WMMAOpLowering>(converter, chipset);
+           MFMAOpLowering, WMMAOpLowering, ExtPackedFp8OpLowering,
+           PackedTrunc2xFp8OpLowering, PackedStochRoundFp8OpLowering>(converter,
+                                                                      chipset);
 }
 
 std::unique_ptr<Pass> mlir::createConvertAMDGPUToROCDLPass() {

diff  --git a/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp b/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp
new file mode 100644
index 000000000000000..7785405eae67be3
--- /dev/null
+++ b/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp
@@ -0,0 +1,210 @@
+//===- ArithToAMDGPU.cpp - Arith to AMDGPU dialect conversion ---------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Conversion/ArithToAMDGPU/ArithToAMDGPU.h"
+
+#include "mlir/Dialect/AMDGPU/IR/AMDGPUDialect.h"
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/IR/TypeUtilities.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Support/LogicalResult.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+
+namespace mlir {
+#define GEN_PASS_DEF_ARITHTOAMDGPUCONVERSIONPASS
+#include "mlir/Conversion/Passes.h.inc"
+} // namespace mlir
+
+using namespace mlir;
+
+namespace {
+/// Pass that greedily applies the arith-to-amdgpu rewrite patterns.
+struct ArithToAMDGPUConversionPass final
+    : impl::ArithToAMDGPUConversionPassBase<ArithToAMDGPUConversionPass> {
+  using impl::ArithToAMDGPUConversionPassBase<
+      ArithToAMDGPUConversionPass>::ArithToAMDGPUConversionPassBase;
+
+  void runOnOperation() override;
+};
+
+/// Rewrites arith.extf whose input is f8E5M2FNUZ / f8E4M3FNUZ (scalar, or a
+/// fixed-size vector of rank <= 1) into amdgpu.ext_packed_fp8 operations.
+struct ExtfOnFloat8RewritePattern final : OpRewritePattern<arith::ExtFOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult match(arith::ExtFOp op) const override;
+  void rewrite(arith::ExtFOp op, PatternRewriter &rewriter) const override;
+};
+
+/// Rewrites arith.truncf whose output is f8E5M2FNUZ / f8E4M3FNUZ (scalar, or
+/// a fixed-size vector of rank <= 1) into amdgpu.packed_trunc_2xfp8 ops.
+struct TruncfToFloat8RewritePattern final
+    : OpRewritePattern<arith::TruncFOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult match(arith::TruncFOp op) const override;
+  void rewrite(arith::TruncFOp op, PatternRewriter &rewriter) const override;
+};
+} // end namespace
+
+/// Convert the f32 value `f32` to `elementType`: arith.truncf for narrower
+/// types, arith.extf for wider ones, and no conversion when it is f32.
+static Value castF32To(Type elementType, Value f32, Location loc,
+                       PatternRewriter &rewriter) {
+  if (elementType.isF32())
+    return f32;
+  unsigned targetWidth = elementType.getIntOrFloatBitWidth();
+  if (targetWidth > 32)
+    return rewriter.create<arith::ExtFOp>(loc, elementType, f32);
+  if (targetWidth < 32)
+    return rewriter.create<arith::TruncFOp>(loc, elementType, f32);
+  llvm_unreachable("The only 32-bit float type is f32");
+}
+
+/// Match arith.extf from an FNUZ 8-bit float scalar or fixed-size vector of
+/// rank <= 1.
+LogicalResult ExtfOnFloat8RewritePattern::match(arith::ExtFOp op) const {
+  Type elemType = op.getIn().getType();
+  if (auto vecType = elemType.dyn_cast<VectorType>()) {
+    // Scalable and multi-dimensional vectors are currently unsupported.
+    if (vecType.isScalable() || vecType.getRank() > 1)
+      return failure();
+    elemType = vecType.getElementType();
+  }
+  if (elemType.isFloat8E5M2FNUZ() || elemType.isFloat8E4M3FNUZ())
+    return success();
+  return failure();
+}
+
+/// Replace an 8-bit float arith.extf with amdgpu.ext_packed_fp8 ops (one per
+/// element), casting each f32 result to the requested output element type.
+void ExtfOnFloat8RewritePattern::rewrite(arith::ExtFOp op,
+                                         PatternRewriter &rewriter) const {
+  Location loc = op.getLoc();
+  Value in = op.getIn();
+  Type outElemType = getElementTypeOrSelf(op.getOut().getType());
+  // Scalar input: extend element 0 of the (implicitly padded) packed value.
+  if (!in.getType().isa<VectorType>()) {
+    Value asFloat = rewriter.create<amdgpu::ExtPackedFp8Op>(
+        loc, rewriter.getF32Type(), in, 0);
+    Value result = castF32To(outElemType, asFloat, loc, rewriter);
+    return rewriter.replaceOp(op, result);
+  }
+  VectorType inType = in.getType().cast<VectorType>();
+  int64_t numElements = inType.getNumElements();
+  Value zero = rewriter.createOrFold<arith::ConstantOp>(
+      loc, outElemType, rewriter.getFloatAttr(outElemType, 0.0));
+  Value result =
+      rewriter.createOrFold<vector::SplatOp>(loc, op.getOut().getType(), zero);
+  if (inType.getShape().empty()) {
+    // 0-D vector: extract the scalar and emit a scalar arith.extf, which this
+    // pattern then rewrites through the scalar path above.
+    Value scalarIn = rewriter.create<vector::ExtractElementOp>(loc, in);
+    Value scalarExt =
+        rewriter.create<arith::ExtFOp>(loc, outElemType, scalarIn);
+    // Insert into the 0-D result vector; the scalar `zero` is not a valid
+    // destination for vector.insertelement.
+    result = rewriter.create<vector::InsertElementOp>(loc, scalarExt, result);
+    return rewriter.replaceOp(op, result);
+  }
+  // Work in chunks of up to four elements, matching the packed operand of
+  // amdgpu.ext_packed_fp8.
+  for (int64_t i = 0; i < numElements; i += 4) {
+    int64_t elemsThisOp = std::min(numElements, i + 4) - i;
+    Value inSlice = rewriter.create<vector::ExtractStridedSliceOp>(
+        loc, in, i, elemsThisOp, 1);
+    for (int64_t j = 0; j < elemsThisOp; ++j) {
+      Value asFloat = rewriter.create<amdgpu::ExtPackedFp8Op>(
+          loc, rewriter.getF32Type(), inSlice, j);
+      Value asType = castF32To(outElemType, asFloat, loc, rewriter);
+      result = rewriter.create<vector::InsertElementOp>(
+          loc, asType, result,
+          rewriter.createOrFold<arith::ConstantIndexOp>(loc, i + j));
+    }
+  }
+  rewriter.replaceOp(op, result);
+}
+
+/// Bring a float `value` to f32: arith.extf from narrower types,
+/// arith.truncf from wider ones, and the value itself when already f32.
+static Value castToF32(Value value, Location loc, PatternRewriter &rewriter) {
+  Type srcType = value.getType();
+  if (srcType.isF32())
+    return value;
+  Type f32Type = rewriter.getF32Type();
+  unsigned srcWidth = srcType.getIntOrFloatBitWidth();
+  if (srcWidth > 32)
+    return rewriter.create<arith::TruncFOp>(loc, f32Type, value);
+  if (srcWidth < 32)
+    return rewriter.create<arith::ExtFOp>(loc, f32Type, value);
+  llvm_unreachable("The only 32-bit float type is f32");
+}
+
+/// Match arith.truncf producing an FNUZ 8-bit float scalar or fixed-size
+/// vector of rank <= 1.
+LogicalResult TruncfToFloat8RewritePattern::match(arith::TruncFOp op) const {
+  Type elemType = op.getOut().getType();
+  if (auto vecType = elemType.dyn_cast<VectorType>()) {
+    // Scalable and multi-dimensional vectors are currently unsupported.
+    if (vecType.isScalable() || vecType.getRank() > 1)
+      return failure();
+    elemType = vecType.getElementType();
+  }
+  bool isFp8Target =
+      elemType.isFloat8E5M2FNUZ() || elemType.isFloat8E4M3FNUZ();
+  return success(isFp8Target);
+}
+
+/// Replace an arith.truncf to an 8-bit float type with
+/// amdgpu.packed_trunc_2xfp8 ops, each filling two lanes of a packed
+/// <4 x fp8> vector.
+void TruncfToFloat8RewritePattern::rewrite(arith::TruncFOp op,
+                                           PatternRewriter &rewriter) const {
+  Location loc = op.getLoc();
+  Value in = op.getIn();
+  Type outElemType = getElementTypeOrSelf(op.getOut().getType());
+  VectorType truncResType = VectorType::get(4, outElemType);
+  // Scalar input: pack it into lane 0 of a fresh <4 x fp8> and extract it.
+  if (!in.getType().isa<VectorType>()) {
+    Value asFloat = castToF32(in, loc, rewriter);
+    Value asF8s = rewriter.create<amdgpu::PackedTrunc2xFp8Op>(
+        loc, truncResType, asFloat, /*sourceB=*/nullptr, 0,
+        /*existing=*/nullptr);
+    Value result = rewriter.create<vector::ExtractElementOp>(
+        loc, asF8s, rewriter.createOrFold<arith::ConstantIndexOp>(loc, 0));
+    return rewriter.replaceOp(op, result);
+  }
+  VectorType outType = op.getOut().getType().cast<VectorType>();
+  int64_t numElements = outType.getNumElements();
+  Value zero = rewriter.createOrFold<arith::ConstantOp>(
+      loc, outElemType, rewriter.getFloatAttr(outElemType, 0.0));
+  Value result = rewriter.createOrFold<vector::SplatOp>(loc, outType, zero);
+  if (outType.getShape().empty()) {
+    // 0-D vector: extract the scalar and emit a scalar arith.truncf, which
+    // this pattern then rewrites through the scalar path above.
+    Value scalarIn = rewriter.create<vector::ExtractElementOp>(loc, in);
+    Value scalarTrunc =
+        rewriter.create<arith::TruncFOp>(loc, outElemType, scalarIn);
+    // Insert into the 0-D result vector; the scalar `zero` is not a valid
+    // destination for vector.insertelement.
+    result = rewriter.create<vector::InsertElementOp>(loc, scalarTrunc, result);
+    return rewriter.replaceOp(op, result);
+  }
+
+  // Each group of four outputs is built with up to two packed_trunc_2xfp8
+  // ops (low word, then high word), chained through `existing`.
+  for (int64_t i = 0; i < numElements; i += 4) {
+    int64_t elemsThisOp = std::min(numElements, i + 4) - i;
+    Value thisResult = nullptr;
+    for (int64_t j = 0; j < elemsThisOp; j += 2) {
+      Value elemA = rewriter.create<vector::ExtractElementOp>(
+          loc, in, rewriter.createOrFold<arith::ConstantIndexOp>(loc, i + j));
+      Value asFloatA = castToF32(elemA, loc, rewriter);
+      // With an odd element count the final B operand is omitted (undef).
+      Value asFloatB = nullptr;
+      if (j + 1 < elemsThisOp) {
+        Value elemB = rewriter.create<vector::ExtractElementOp>(
+            loc, in,
+            rewriter.createOrFold<arith::ConstantIndexOp>(loc, i + j + 1));
+        asFloatB = castToF32(elemB, loc, rewriter);
+      }
+      thisResult = rewriter.create<amdgpu::PackedTrunc2xFp8Op>(
+          loc, truncResType, asFloatA, asFloatB, j / 2, thisResult);
+    }
+    // Drop the unused lanes of a trailing partial group.
+    if (elemsThisOp < 4)
+      thisResult = rewriter.create<vector::ExtractStridedSliceOp>(
+          loc, thisResult, 0, elemsThisOp, 1);
+    result = rewriter.create<vector::InsertStridedSliceOp>(loc, thisResult,
+                                                           result, i, 1);
+  }
+  rewriter.replaceOp(op, result);
+}
+
+/// Register the 8-bit float extf/truncf rewrite patterns into `patterns`.
+void mlir::arith::populateArithToAMDGPUConversionPatterns(
+    RewritePatternSet &patterns) {
+  MLIRContext *context = patterns.getContext();
+  patterns.add<ExtfOnFloat8RewritePattern>(context);
+  patterns.add<TruncfToFloat8RewritePattern>(context);
+}
+
+/// Greedily apply the arith-to-amdgpu patterns to the pass's root operation.
+void ArithToAMDGPUConversionPass::runOnOperation() {
+  Operation *root = getOperation();
+  RewritePatternSet patterns(root->getContext());
+  arith::populateArithToAMDGPUConversionPatterns(patterns);
+  // Any failure reported by the greedy driver fails the pass.
+  if (succeeded(applyPatternsAndFoldGreedily(root, std::move(patterns))))
+    return;
+  signalPassFailure();
+}

diff  --git a/mlir/lib/Conversion/ArithToAMDGPU/CMakeLists.txt b/mlir/lib/Conversion/ArithToAMDGPU/CMakeLists.txt
new file mode 100644
index 000000000000000..359015b6f86ad87
--- /dev/null
+++ b/mlir/lib/Conversion/ArithToAMDGPU/CMakeLists.txt
@@ -0,0 +1,19 @@
+# Library for the convert-arith-to-amdgpu pass (ArithToAMDGPU.cpp). It depends
+# on the generated conversion pass declarations (MLIRConversionPassIncGen) and
+# links the AMDGPU, Arith and Vector dialects whose ops the patterns use, plus
+# the pass and transform infrastructure.
+add_mlir_conversion_library(MLIRArithToAMDGPU
+  ArithToAMDGPU.cpp
+
+  ADDITIONAL_HEADER_DIRS
+  ${MLIR_MAIN_INCLUDE_DIR}/mlir/Conversion/ArithToAMDGPU
+
+  DEPENDS
+  MLIRConversionPassIncGen
+
+  LINK_COMPONENTS
+  Core
+
+  LINK_LIBS PUBLIC
+  MLIRAMDGPUDialect
+  MLIRArithDialect
+  MLIRVectorDialect
+  MLIRPass
+  MLIRTransforms
+  )

diff  --git a/mlir/lib/Conversion/CMakeLists.txt b/mlir/lib/Conversion/CMakeLists.txt
index 660e48768c4ff34..35790254be137be 100644
--- a/mlir/lib/Conversion/CMakeLists.txt
+++ b/mlir/lib/Conversion/CMakeLists.txt
@@ -1,6 +1,7 @@
 add_subdirectory(AffineToStandard)
 add_subdirectory(AMDGPUToROCDL)
 add_subdirectory(ArithCommon)
+add_subdirectory(ArithToAMDGPU)
 add_subdirectory(ArithToLLVM)
 add_subdirectory(ArithToSPIRV)
 add_subdirectory(ArmNeon2dToIntr)

diff  --git a/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp b/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
index ac34acc8307485c..2575ad4984814b5 100644
--- a/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
+++ b/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
@@ -43,6 +43,21 @@ void AMDGPUDialect::initialize() {
       >();
 }
 
+//===----------------------------------------------------------------------===//
+// 8-bit float ops
+//===----------------------------------------------------------------------===//
+LogicalResult PackedTrunc2xFp8Op::verify() {
+  if (!getExisting() || getExisting().getType() == getResult().getType())
+    return success(); // No `existing` operand, or it matches the result.
+  return emitOpError("existing values must have same type as result");
+}
+
+LogicalResult PackedStochRoundFp8Op::verify() {
+  if (!getExisting() || getExisting().getType() == getResult().getType())
+    return success(); // No `existing` operand, or it matches the result.
+  return emitOpError("existing values must have same type as result");
+}
+
 //===----------------------------------------------------------------------===//
 // RawBuffer*Op
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/test/Conversion/AMDGPUToROCDL/8-bit-floats.mlir b/mlir/test/Conversion/AMDGPUToROCDL/8-bit-floats.mlir
new file mode 100644
index 000000000000000..7818a525d17b537
--- /dev/null
+++ b/mlir/test/Conversion/AMDGPUToROCDL/8-bit-floats.mlir
@@ -0,0 +1,108 @@
+// RUN: mlir-opt %s -convert-amdgpu-to-rocdl=chipset=gfx940 | FileCheck %s
+// amdgpu 8-bit float ops -> rocdl intrinsics: f8E4M3FNUZ uses the fp8 forms, f8E5M2FNUZ the bf8 forms.
+// CHECK-LABEL: func @ext_scalar
+// CHECK: [[V:%.+]] = builtin.unrealized_conversion_cast %{{.+}} : f8E5M2FNUZ to i8
+// CHECK-DAG: [[UNDEF:%.+]] = llvm.mlir.undef : vector<4xi8>
+// CHECK-DAG: [[C0_1:%.+]] = llvm.mlir.constant(0 : i32) : i32
+// CHECK: [[VEC:%.+]] = llvm.insertelement [[V]], [[UNDEF]]{{\[}}[[C0_1]] : i32] : vector<4xi8>
+// CHECK: [[CAST:%.+]] = llvm.bitcast [[VEC]] : vector<4xi8> to i32
+// CHECK: [[C0_2:%.+]] = llvm.mlir.constant(0 : i32) : i32
+// CHECK: [[EXT:%.+]] = rocdl.cvt.f32.bf8 [[CAST]]{{\[}}[[C0_2]]] : f32
+// CHECK: return [[EXT]]
+func.func @ext_scalar(%v: f8E5M2FNUZ) -> f32 { // scalar: placed in lane 0 of an undef vector<4xi8>
+  %ret = amdgpu.ext_packed_fp8 %v[0] : f8E5M2FNUZ to f32
+  func.return %ret : f32
+}
+
+// CHECK-LABEL: func @ext_short_vec
+// CHECK: [[V:%.+]] = builtin.unrealized_conversion_cast %{{.+}} : vector<2xf8E4M3FNUZ> to vector<2xi8>
+// CHECK-DAG: [[UNDEF:%.+]] = llvm.mlir.undef : vector<4xi8>
+// CHECK-DAG: [[C0:%.+]] = llvm.mlir.constant(0 : i32) : i32
+// CHECK: [[ELEM_0:%.+]] = llvm.extractelement [[V]]{{\[}}[[C0]] : i32] : vector<2xi8>
+// CHECK: [[VEC_0:%.+]] = llvm.insertelement [[ELEM_0]], [[UNDEF]]{{\[}}[[C0]] : i32] : vector<4xi8>
+// CHECK: [[C1_1:%.+]] = llvm.mlir.constant(1 : i32) : i32
+// CHECK: [[ELEM_1:%.+]] = llvm.extractelement [[V]]{{\[}}[[C1_1]] : i32] : vector<2xi8>
+// CHECK: [[VEC_1:%.+]] = llvm.insertelement [[ELEM_1]], [[VEC_0]]{{\[}}[[C1_1]] : i32] : vector<4xi8>
+// CHECK: [[CAST:%.+]] = llvm.bitcast [[VEC_1]] : vector<4xi8> to i32
+// CHECK: [[C1_2:%.+]] = llvm.mlir.constant(1 : i32) : i32
+// CHECK: [[EXT:%.+]] = rocdl.cvt.f32.fp8 [[CAST]]{{\[}}[[C1_2]]] : f32
+// CHECK: return [[EXT]]
+func.func @ext_short_vec(%v: vector<2xf8E4M3FNUZ>) -> f32 { // <4 elements: copied into an undef vector<4xi8>
+  %ret = amdgpu.ext_packed_fp8 %v[1] : vector<2xf8E4M3FNUZ> to f32
+  func.return %ret : f32
+}
+
+// CHECK-LABEL: func @ext_full_vec(
+// CHECK: [[V:%.+]] = builtin.unrealized_conversion_cast %{{.+}} : vector<4xf8E4M3FNUZ> to vector<4xi8>
+// CHECK: [[CAST:%.+]] = llvm.bitcast [[V]] : vector<4xi8> to i32
+// CHECK: [[C3:%.+]] = llvm.mlir.constant(3 : i32) : i32
+// CHECK: [[EXT:%.+]] = rocdl.cvt.f32.fp8 [[CAST]]{{\[}}[[C3]]] : f32
+// CHECK: return [[EXT]] : f32
+
+func.func @ext_full_vec(%v: vector<4xf8E4M3FNUZ>) -> f32 { // 4 elements: bitcast straight to i32
+  %ret = amdgpu.ext_packed_fp8 %v[3] : vector<4xf8E4M3FNUZ> to f32
+  func.return %ret : f32
+}
+
+// CHECK-LABEL: func @packed_trunc
+// CHECK-SAME: ([[V:%.+]]: f32)
+// CHECK: [[V2:%.+]] = llvm.mlir.undef : f32
+// CHECK: [[EXISTING:%.+]] = llvm.mlir.undef : i32
+// CHECK: [[FALSE:%.+]] = llvm.mlir.constant(false) : i1
+// CHECK: [[PACKED:%.+]] = rocdl.cvt.pk.fp8.f32 [[V]], [[V2]] -> [[EXISTING]]{{\[}}[[FALSE]]] : i32
+// CHECK: [[CAST:%.+]] = llvm.bitcast [[PACKED]] : i32 to vector<4xi8>
+// CHECK: builtin.unrealized_conversion_cast [[CAST]] : vector<4xi8> to vector<4xf8E4M3FNUZ>
+func.func @packed_trunc(%v: f32) -> vector<4xf8E4M3FNUZ> { // undef operands lower to llvm.mlir.undef
+  %ret = amdgpu.packed_trunc_2xfp8 %v, undef into undef[word 0] : f32 to vector<4xf8E4M3FNUZ>
+  func.return %ret : vector<4xf8E4M3FNUZ>
+}
+
+// CHECK-LABEL: func @packed_truncx2
+// CHECK-SAME: ([[V:%.+]]: f32, [[W:%.+]]: f32)
+// CHECK: [[EXISTING:%.+]] = llvm.mlir.undef : i32
+// CHECK: [[FALSE:%.+]] = llvm.mlir.constant(false) : i1
+// CHECK: [[PACKED:%.+]] = rocdl.cvt.pk.fp8.f32 [[V]], [[W]] -> [[EXISTING]]{{\[}}[[FALSE]]] : i32
+// CHECK: [[CAST:%.+]] = llvm.bitcast [[PACKED]] : i32 to vector<4xi8>
+// CHECK: builtin.unrealized_conversion_cast [[CAST]] : vector<4xi8> to vector<4xf8E4M3FNUZ>
+func.func @packed_truncx2(%v: f32, %w: f32) -> vector<4xf8E4M3FNUZ> {
+  %ret = amdgpu.packed_trunc_2xfp8 %v, %w into undef[word 0] : f32 to vector<4xf8E4M3FNUZ>
+  func.return %ret : vector<4xf8E4M3FNUZ>
+}
+
+// CHECK-LABEL: func @packed_truncx2_into
+// CHECK-SAME: ([[V:%.+]]: f32, [[W:%.+]]: f32, [[EXISTING:%.+]]:  vector<4xf8E5M2FNUZ>)
+// CHECK: [[EXISTING_BYTES:%.+]] = builtin.unrealized_conversion_cast [[EXISTING]] : vector<4xf8E5M2FNUZ> to vector<4xi8>
+// CHECK: [[EXISTING_INT:%.+]] = llvm.bitcast [[EXISTING_BYTES]] : vector<4xi8> to i32
+// CHECK: [[TRUE:%.+]] = llvm.mlir.constant(true) : i1
+// CHECK: [[PACKED:%.+]] = rocdl.cvt.pk.bf8.f32 [[V]], [[W]] -> [[EXISTING_INT]]{{\[}}[[TRUE]]] : i32
+// CHECK: [[CAST:%.+]] = llvm.bitcast [[PACKED]] : i32 to vector<4xi8>
+// CHECK: builtin.unrealized_conversion_cast [[CAST]] : vector<4xi8> to vector<4xf8E5M2FNUZ>
+func.func @packed_truncx2_into(%v: f32, %w: f32, %existing: vector<4xf8E5M2FNUZ>) -> vector<4xf8E5M2FNUZ> { // [word 1] lowers to i1 true
+  %ret = amdgpu.packed_trunc_2xfp8 %v, %w into %existing[word 1] : f32 to vector<4xf8E5M2FNUZ> into vector<4xf8E5M2FNUZ>
+  func.return %ret : vector<4xf8E5M2FNUZ>
+}
+
+// CHECK-LABEL: func @packed_stoch_round
+// CHECK-SAME: ([[V:%.+]]: f32, [[S:%.+]]: i32)
+// CHECK: [[EXISTING:%.+]] = llvm.mlir.undef : i32
+// CHECK: [[C0:%.+]] = llvm.mlir.constant(0 : i32) : i32
+// CHECK: [[PACKED:%.+]] = rocdl.cvt.sr.fp8.f32 [[V]], [[S]] -> [[EXISTING]]{{\[}}[[C0]]] : i32
+// CHECK: [[CAST:%.+]] = llvm.bitcast [[PACKED]] : i32 to vector<4xi8>
+// CHECK:  builtin.unrealized_conversion_cast [[CAST]] : vector<4xi8> to vector<4xf8E4M3FNUZ>
+func.func @packed_stoch_round(%v: f32, %s: i32) -> vector<4xf8E4M3FNUZ> { // byte index lowers to an i32 constant
+  %ret = amdgpu.packed_stoch_round_fp8 %v + %s into undef[0] : f32 to vector<4xf8E4M3FNUZ>
+  func.return %ret : vector<4xf8E4M3FNUZ>
+}
+
+// CHECK-LABEL: func @packed_stoch_round_into
+// CHECK-SAME: ([[V:%.+]]: f32, [[S:%.+]]: i32, [[EXISTING:%.+]]:  vector<4xf8E5M2FNUZ>)
+// CHECK: [[EXISTING_BYTES:%.+]] = builtin.unrealized_conversion_cast [[EXISTING]] : vector<4xf8E5M2FNUZ> to vector<4xi8>
+// CHECK: [[EXISTING_INT:%.+]] = llvm.bitcast [[EXISTING_BYTES]] : vector<4xi8> to i32
+// CHECK: [[C1:%.+]] = llvm.mlir.constant(1 : i32) : i32
+// CHECK: [[PACKED:%.+]] = rocdl.cvt.sr.bf8.f32 [[V]], [[S]] -> [[EXISTING_INT]]{{\[}}[[C1]]] : i32
+// CHECK: [[CAST:%.+]] = llvm.bitcast [[PACKED]] : i32 to vector<4xi8>
+// CHECK: builtin.unrealized_conversion_cast [[CAST]] : vector<4xi8> to vector<4xf8E5M2FNUZ>
+func.func @packed_stoch_round_into(%v: f32, %s: i32, %existing: vector<4xf8E5M2FNUZ>) -> vector<4xf8E5M2FNUZ> {
+  %ret = amdgpu.packed_stoch_round_fp8 %v + %s into %existing[1] : f32 to vector<4xf8E5M2FNUZ> into vector<4xf8E5M2FNUZ>
+  func.return %ret : vector<4xf8E5M2FNUZ>
+}

diff  --git a/mlir/test/Conversion/ArithToAMDGPU/8-bit-floats.mlir b/mlir/test/Conversion/ArithToAMDGPU/8-bit-floats.mlir
new file mode 100644
index 000000000000000..a6c11d022e2c15f
--- /dev/null
+++ b/mlir/test/Conversion/ArithToAMDGPU/8-bit-floats.mlir
@@ -0,0 +1,122 @@
+// RUN: mlir-opt --split-input-file %s -convert-arith-to-amdgpu | FileCheck %s
+// 8-bit float arith.extf/arith.truncf are rewritten into amdgpu packed fp8 ops, going through f32.
+// CHECK-LABEL: func.func @scalar_ext
+// CHECK-SAME: ([[V:%.+]]: f8E5M2FNUZ)
+// CHECK: [[FLOAT:%.+]] = amdgpu.ext_packed_fp8 [[V]][0] : f8E5M2FNUZ to f32
+// CHECK: [[W:%.+]] = arith.truncf [[FLOAT]] : f32 to f16
+// CHECK: return [[W]]
+func.func @scalar_ext(%v: f8E5M2FNUZ) -> f16 { // extended to f32 first, then truncated to f16
+  %w = arith.extf %v : f8E5M2FNUZ to f16
+  return %w : f16
+}
+
+// No 0-D test because arith.extf hasn't been extended to support it.
+
+// -----
+
+// CHECK-LABEL: func.func @vector_ext_short
+// CHECK-SAME: ([[V:%.+]]: vector<2xf8E5M2FNUZ>)
+// CHECK-DAG: [[ZEROES:%.+]] = arith.constant dense<0.000000e+00> : vector<2xf64>
+// CHECK-DAG: [[C0:%.+]] = arith.constant 0 : index
+// CHECK-DAG: [[C1:%.+]] = arith.constant 1 : index
+// CHECK: [[FLOAT0:%.+]] = amdgpu.ext_packed_fp8 [[V]][0] : vector<2xf8E5M2FNUZ> to f32
+// CHECK: [[EXT0:%.+]] = arith.extf [[FLOAT0]] : f32 to f64
+// CHECK: [[W0:%.+]] = vector.insertelement [[EXT0]], [[ZEROES]]{{\[}}[[C0]]
+// CHECK: [[FLOAT1:%.+]] = amdgpu.ext_packed_fp8 [[V]][1] : vector<2xf8E5M2FNUZ> to f32
+// CHECK: [[EXT1:%.+]] = arith.extf [[FLOAT1]]
+// CHECK: [[W1:%.+]] = vector.insertelement [[EXT1]], [[W0]]{{\[}}[[C1]]
+// CHECK: return [[W1]] : vector<2xf64>
+
+func.func @vector_ext_short(%v: vector<2xf8E5M2FNUZ>) -> vector<2xf64> {
+  %w = arith.extf %v : vector<2xf8E5M2FNUZ> to vector<2xf64>
+  return %w : vector<2xf64>
+}
+
+// -----
+
+// CHECK-LABEL: func.func @vector_ext_long
+// CHECK-SAME: ([[V:%.+]]: vector<9xf8E4M3FNUZ>)
+// CHECK: [[V0:%.+]] = vector.extract_strided_slice [[V]] {offsets = [0], sizes = [4], strides = [1]}
+// CHECK: [[F0:%.+]] = amdgpu.ext_packed_fp8 [[V0]][0]
+// CHECK: [[W0:%.+]] = vector.insertelement [[F0]]
+// CHECK: [[F1:%.+]] = amdgpu.ext_packed_fp8 [[V0]][1]
+// CHECK: [[W1:%.+]] = vector.insertelement [[F1]], [[W0]]
+// CHECK: [[F2:%.+]] = amdgpu.ext_packed_fp8 [[V0]][2]
+// CHECK: [[W2:%.+]] = vector.insertelement [[F2]], [[W1]]
+// CHECK: [[F3:%.+]] = amdgpu.ext_packed_fp8 [[V0]][3]
+// CHECK: [[W3:%.+]] = vector.insertelement [[F3]], [[W2]]
+
+// CHECK: [[V1:%.+]] = vector.extract_strided_slice [[V]] {offsets = [4], sizes = [4], strides = [1]} : vector<9xf8E4M3FNUZ> to vector<4xf8E4M3FNUZ>
+// CHECK: [[F4:%.+]] = amdgpu.ext_packed_fp8 [[V1]][0]
+// CHECK: [[W4:%.+]] = vector.insertelement [[F4]], [[W3]]
+// CHECK: [[F5:%.+]] = amdgpu.ext_packed_fp8 [[V1]][1]
+// CHECK: [[W5:%.+]] = vector.insertelement [[F5]], [[W4]]
+// CHECK: [[F6:%.+]] = amdgpu.ext_packed_fp8 [[V1]][2]
+// CHECK: [[W6:%.+]] = vector.insertelement [[F6]], [[W5]]
+// CHECK: [[F7:%.+]] = amdgpu.ext_packed_fp8 [[V1]][3]
+// CHECK: [[W7:%.+]] = vector.insertelement [[F7]], [[W6]]
+
+// CHECK: [[V2:%.+]] = vector.extract_strided_slice [[V]] {offsets = [8], sizes = [1], strides = [1]} : vector<9xf8E4M3FNUZ> to vector<1xf8E4M3FNUZ>
+// CHECK: [[F8:%.+]] = amdgpu.ext_packed_fp8 [[V2]][0]
+// CHECK: [[W8:%.+]] = vector.insertelement [[F8]], [[W7]]
+// CHECK: return [[W8]]
+func.func @vector_ext_long(%v: vector<9xf8E4M3FNUZ>) -> vector<9xf32> { // 9 elements: slices of 4, 4, then 1
+  %w = arith.extf %v : vector<9xf8E4M3FNUZ> to vector<9xf32>
+  return %w : vector<9xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func.func @scalar_trunc
+// CHECK-SAME: ([[V:%.+]]: f16)
+// CHECK: [[C0:%.+]] = arith.constant 0 : index
+// CHECK: [[FLOAT:%.+]] = arith.extf [[V]] : f16 to f32
+// CHECK: [[TRUNCV:%.+]] = amdgpu.packed_trunc_2xfp8 [[FLOAT]], undef into undef[word 0] : f32 to vector<4xf8E5M2FNUZ>
+// CHECK: [[W:%.+]] = vector.extractelement [[TRUNCV]]{{\[}}[[C0]] : index] : vector<4xf8E5M2FNUZ>
+// CHECK: return [[W]] : f8E5M2FNUZ
+func.func @scalar_trunc(%v: f16) -> f8E5M2FNUZ { // result is element 0 of the packed vector
+  %w = arith.truncf %v : f16 to f8E5M2FNUZ
+  return %w : f8E5M2FNUZ
+}
+
+// No 0-D test because arith.truncf hasn't been extended to support it.
+
+// -----
+
+// CHECK-LABEL: func.func @vector_trunc_short
+// CHECK-SAME: ([[V:%.+]]: vector<2xf64>) -> vector<2xf8E5M2FNUZ> {
+// CHECK-DAG: [[C0:%.+]] = arith.constant 0 : index
+// CHECK-DAG: [[C1:%.+]] = arith.constant 1 : index
+// CHECK: [[V0:%.+]] = vector.extractelement [[V]]{{\[}}[[C0]] : index]
+// CHECK: [[F0:%.+]] = arith.truncf [[V0]] : f64 to f32
+// CHECK: [[V1:%.+]] = vector.extractelement [[V]]{{\[}}[[C1]] : index]
+// CHECK: [[F1:%.+]] = arith.truncf [[V1]] : f64 to f32
+// CHECK: [[W0:%.+]] = amdgpu.packed_trunc_2xfp8 [[F0]], [[F1]] into undef[word 0] : f32 to vector<4xf8E5M2FNUZ>
+// CHECK: [[W:%.+]] = vector.extract_strided_slice [[W0]] {offsets = [0], sizes = [2], strides = [1]} : vector<4xf8E5M2FNUZ> to vector<2xf8E5M2FNUZ>
+// CHECK: return [[W]] : vector<2xf8E5M2FNUZ>
+func.func @vector_trunc_short(%v: vector<2xf64>) -> vector<2xf8E5M2FNUZ> { // packed result sliced back to 2 elements
+  %w = arith.truncf %v : vector<2xf64> to vector<2xf8E5M2FNUZ>
+  return %w : vector<2xf8E5M2FNUZ>
+}
+
+// -----
+
+// CHECK-LABEL: func.func @vector_trunc_long
+// CHECK-SAME: ([[V:%.+]]: vector<9xf32>)
+// CHECK: [[ZEROES:%.+]] = arith.constant dense<0.000000e+00> : vector<9xf8E4M3FNUZ>
+// CHECK: [[T0:%.+]] = amdgpu.packed_trunc_2xfp8 %{{.+}}, %{{.+}} into undef[word 0]
+// CHECK: [[T1:%.+]] = amdgpu.packed_trunc_2xfp8 %{{.+}}, %{{.+}} into [[T0]][word 1]
+// CHECK: [[W0:%.+]] = vector.insert_strided_slice [[T1]], [[ZEROES]] {offsets = [0], strides = [1]}
+
+// CHECK: [[T2:%.+]] = amdgpu.packed_trunc_2xfp8 %{{.+}}, %{{.+}} into undef[word 0]
+// CHECK: [[T3:%.+]] = amdgpu.packed_trunc_2xfp8 %{{.+}}, %{{.+}} into [[T2]][word 1]
+// CHECK: [[W1:%.+]] = vector.insert_strided_slice [[T3]], [[W0]] {offsets = [4], strides = [1]}
+
+// CHECK: [[T4:%.+]] = amdgpu.packed_trunc_2xfp8 %{{.+}}, undef into undef[word 0]
+// CHECK: [[T4_SHORT:%.+]] = vector.extract_strided_slice [[T4]] {offsets = [0], sizes = [1], strides = [1]}
+// CHECK: [[W:%.+]] = vector.insert_strided_slice [[T4_SHORT]], [[W1]] {offsets = [8], strides = [1]}
+// CHECK: return [[W]]
+func.func @vector_trunc_long(%v: vector<9xf32>) -> vector<9xf8E4M3FNUZ> { // chunks of 4 (two packed words each), then 1
+  %w = arith.truncf %v : vector<9xf32> to vector<9xf8E4M3FNUZ>
+  return %w : vector<9xf8E4M3FNUZ>
+}

diff  --git a/mlir/test/Dialect/AMDGPU/invalid.mlir b/mlir/test/Dialect/AMDGPU/invalid.mlir
index 142224e59a95d7a..5e1ab79962d2f0f 100644
--- a/mlir/test/Dialect/AMDGPU/invalid.mlir
+++ b/mlir/test/Dialect/AMDGPU/invalid.mlir
@@ -1,5 +1,19 @@
 // RUN: mlir-opt %s -split-input-file -verify-diagnostics
 
+func.func @mixing_packed_trunc_types(%arg0: f32, %arg1: vector<4xf8E5M2FNUZ>) -> vector<4xf8E4M3FNUZ> { // E5M2 existing vs E4M3 result
+  // expected-error@+1 {{'amdgpu.packed_trunc_2xfp8' op existing values must have same type as result}}
+  %ret = amdgpu.packed_trunc_2xfp8 %arg0, undef into %arg1[word 0] : f32 to vector<4xf8E4M3FNUZ> into vector<4xf8E5M2FNUZ>
+  func.return %ret : vector<4xf8E4M3FNUZ>
+}
+
+// -----
+
+func.func @mixing_packed_stoch_round_types(%arg0: f32, %arg1: i32, %arg2: vector<4xf8E5M2FNUZ>) -> vector<4xf8E4M3FNUZ> { // E5M2 existing vs E4M3 result
+  // expected-error@+1 {{'amdgpu.packed_stoch_round_fp8' op existing values must have same type as result}}
+  %ret = amdgpu.packed_stoch_round_fp8 %arg0 + %arg1 into %arg2[0] : f32 to vector<4xf8E4M3FNUZ> into vector<4xf8E5M2FNUZ>
+  func.return %ret : vector<4xf8E4M3FNUZ>
+}
+
 // -----
 
 func.func @bad_source_types(%a: vector<2xf32>, %b: vector<4xf16>,

diff  --git a/mlir/test/Dialect/AMDGPU/ops.mlir b/mlir/test/Dialect/AMDGPU/ops.mlir
index 4088c6750c91b8d..744a096d757e02e 100644
--- a/mlir/test/Dialect/AMDGPU/ops.mlir
+++ b/mlir/test/Dialect/AMDGPU/ops.mlir
@@ -4,6 +4,27 @@
 // Verify the generic form can be parsed.
 // RUN: mlir-opt -allow-unregistered-dialect -mlir-print-op-generic %s | mlir-opt -allow-unregistered-dialect | FileCheck %s
 
+// CHECK-LABEL: func @ext_packed_fp8
+// CHECK: amdgpu.ext_packed_fp8
+func.func @ext_packed_fp8(%v: vector<4xf8E4M3FNUZ>) -> f32 { // element 0 of the packed vector to f32
+  %ret = amdgpu.ext_packed_fp8 %v[0] : vector<4xf8E4M3FNUZ> to f32
+  func.return %ret : f32
+}
+
+// CHECK-LABEL: func @packed_trunc_2xfp8
+// CHECK: amdgpu.packed_trunc_2xfp8
+func.func @packed_trunc_2xfp8(%v1: f32, %v2: f32, %others: vector<4xf8E5M2FNUZ>) -> vector<4xf8E5M2FNUZ> { // two f32s into word 1
+  %ret = amdgpu.packed_trunc_2xfp8 %v1, %v2 into %others[word 1] : f32 to vector<4xf8E5M2FNUZ> into vector<4xf8E5M2FNUZ>
+  func.return %ret : vector<4xf8E5M2FNUZ>
+}
+
+// CHECK-LABEL: func @packed_stoch_round_fp8
+// CHECK: amdgpu.packed_stoch_round_fp8
+func.func @packed_stoch_round_fp8(%v1: f32, %stoch: i32, %others: vector<4xf8E5M2FNUZ>) -> vector<4xf8E5M2FNUZ> { // one f32 into byte 2
+  %ret = amdgpu.packed_stoch_round_fp8 %v1 + %stoch into %others[2] : f32 to vector<4xf8E5M2FNUZ> into vector<4xf8E5M2FNUZ>
+  func.return %ret : vector<4xf8E5M2FNUZ>
+}
+
 // CHECK-LABEL: func @raw_buffer_load_f32_from_rank_1
 func.func @raw_buffer_load_f32_from_rank_1(%src : memref<128xf32>, %offset : i32, %idx0 : i32) -> f32 {
   // CHECK: amdgpu.raw_buffer_load {indexOffset = 1 : i32} %{{.*}}[{{.*}}] sgprOffset %{{.*}} : memref<128xf32>, i32 -> f32

diff  --git a/mlir/test/Dialect/LLVMIR/rocdl.mlir b/mlir/test/Dialect/LLVMIR/rocdl.mlir
index 26de6a50fee38b9..5a14df9ef9f8dc2 100644
--- a/mlir/test/Dialect/LLVMIR/rocdl.mlir
+++ b/mlir/test/Dialect/LLVMIR/rocdl.mlir
@@ -330,6 +330,27 @@ llvm.func @rocdl.raw.buffer.i32(%rsrc : vector<4xi32>,
   llvm.return
 }
 
+llvm.func @rocdl_8bit_floats(%source: i32, %stoch: i32) -> i32 { // one of each rocdl 8-bit float conversion op
+// CHECK-LABEL: @rocdl_8bit_floats
+// CHECK: rocdl.cvt.f32.bf8
+// CHECK: rocdl.cvt.f32.fp8
+// CHECK: rocdl.cvt.pk.bf8.f32
+// CHECK: rocdl.cvt.pk.fp8.f32
+// CHECK: rocdl.cvt.sr.bf8.f32
+// CHECK: rocdl.cvt.sr.fp8.f32
+  %c0 = llvm.mlir.constant(0 : i32) : i32 // byte index for the f32 extractions
+  %c2 = llvm.mlir.constant(2 : i32) : i32 // byte indices for the stochastic-rounding ops
+  %c3 = llvm.mlir.constant(3 : i32) : i32
+  %false = llvm.mlir.constant(false) : i1 // word select for the packed truncations (false = word 0)
+  %v1 = rocdl.cvt.f32.bf8 %source[%c0] : f32
+  %v2 = rocdl.cvt.f32.fp8 %source[%c0] : f32
+  %source2 = rocdl.cvt.pk.bf8.f32 %v1, %v2 -> %source[%false] : i32
+  %source3 = rocdl.cvt.pk.fp8.f32 %v1, %v2 -> %source2[%false] : i32
+  %source4 = rocdl.cvt.sr.bf8.f32 %v1, %stoch -> %source3[%c2] : i32
+  %source5 = rocdl.cvt.sr.fp8.f32 %v2, %stoch -> %source4[%c3] : i32
+  llvm.return %source5 : i32
+}
+
 // -----
 
 // expected-error@below {{attribute attached to unexpected op}}

diff  --git a/mlir/test/Target/LLVMIR/rocdl.mlir b/mlir/test/Target/LLVMIR/rocdl.mlir
index 777bef8fea5847d..8b37dfbe3c6e881 100644
--- a/mlir/test/Target/LLVMIR/rocdl.mlir
+++ b/mlir/test/Target/LLVMIR/rocdl.mlir
@@ -468,6 +468,27 @@ llvm.func @rocdl.raw.buffer.atomic.cmpswap(%rsrc : vector<4xi32>,
   llvm.return %val : i32
 }
 
+llvm.func @rocdl_8bit_floats(%source: i32, %stoch: i32) -> i32 { // each op becomes an llvm.amdgcn.cvt.* call
+// CHECK-LABEL: @rocdl_8bit_floats
+// CHECK: call float @llvm.amdgcn.cvt.f32.bf8(i32 %{{.+}}, i32 0)
+// CHECK: call float @llvm.amdgcn.cvt.f32.fp8(i32 %{{.+}}, i32 0)
+// CHECK: call i32 @llvm.amdgcn.cvt.pk.bf8.f32(float %{{.+}}, float %{{.+}}, i32 %{{.+}}, i1 false)
+// CHECK: call i32 @llvm.amdgcn.cvt.pk.fp8.f32(float %{{.+}}, float %{{.+}}, i32 %{{.+}}, i1 false)
+// CHECK: call i32 @llvm.amdgcn.cvt.sr.bf8.f32(float %{{.+}}, i32 %{{.+}}, i32 %{{.+}}, i32 2)
+// CHECK: call i32 @llvm.amdgcn.cvt.sr.fp8.f32(float %{{.+}}, i32 %{{.+}}, i32 %{{.+}}, i32 3)
+  %c0 = llvm.mlir.constant(0 : i32) : i32 // selector constants become immediate intrinsic args
+  %c2 = llvm.mlir.constant(2 : i32) : i32
+  %c3 = llvm.mlir.constant(3 : i32) : i32
+  %false = llvm.mlir.constant(false) : i1 // emitted as the trailing `i1 false` of the pk calls
+  %v1 = rocdl.cvt.f32.bf8 %source[%c0] : f32
+  %v2 = rocdl.cvt.f32.fp8 %source[%c0] : f32
+  %source2 = rocdl.cvt.pk.bf8.f32 %v1, %v2 -> %source[%false] : i32
+  %source3 = rocdl.cvt.pk.fp8.f32 %v1, %v2 -> %source2[%false] : i32
+  %source4 = rocdl.cvt.sr.bf8.f32 %v1, %stoch -> %source3[%c2] : i32
+  %source5 = rocdl.cvt.sr.fp8.f32 %v2, %stoch -> %source4[%c3] : i32
+  llvm.return %source5 : i32
+}
+
 // CHECK-DAG: attributes #[[$KERNEL_ATTRS]] = { "amdgpu-flat-work-group-size"="1,256" "amdgpu-implicitarg-num-bytes"="56" }
 // CHECK-DAG: attributes #[[$KERNEL_WORKGROUP_ATTRS]] = { "amdgpu-flat-work-group-size"="1,1024"
 // CHECK-DAG: attributes #[[$KNOWN_BLOCK_SIZE_ATTRS]] = { "amdgpu-flat-work-group-size"="128,128"


        


More information about the Mlir-commits mailing list