[Mlir-commits] [mlir] [mlir][amdgpu] Add conversion from arith.scaling_extf to amdgpu (PR #146372)

Tim Gymnich llvmlistbot at llvm.org
Tue Jul 8 07:44:18 PDT 2025


https://github.com/tgymnich updated https://github.com/llvm/llvm-project/pull/146372

>From bc38211280388cdc856c3e5cd8ae2219ea0f4786 Mon Sep 17 00:00:00 2001
From: Tim Gymnich <tim at gymni.ch>
Date: Thu, 26 Jun 2025 09:54:55 +0000
Subject: [PATCH 1/2] [mlir][amdgpu] Add conversion for arith.scaling_extf to
 amdgpu

---
 .../ArithToAMDGPU/ArithToAMDGPU.cpp           | 272 ++++++++++++++++++
 1 file changed, 272 insertions(+)
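
For orientation: the rewrite tiles the input into blocks that share a single scale
value, extracts each block with vector.extract_strided_slice, runs it through the
packed gfx950 conversion ops two elements at a time, and inserts the results back
into a zero-initialized accumulator. As a minimal sketch of the scalar case
(mirroring the @conversion_scalar test added in the second patch; the SSA names
here are illustrative), arith.scaling_extf lowers roughly to:

  %scale_f32 = arith.extf %scale : f8E8M0FNU to f32
  %in_vec = vector.splat %in : vector<1xf8E5M2>
  %packed = amdgpu.scaled_ext_packed %in_vec[0], %scale_f32 : vector<1xf8E5M2> to vector<2xf32>
  %res = vector.extract %packed[0] : f32 from vector<2xf32>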

diff --git a/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp b/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp
index 3596b3235a631..cf9bb3a000050 100644
--- a/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp
+++ b/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp
@@ -14,7 +14,10 @@
 #include "mlir/Dialect/Arith/Utils/Utils.h"
 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
 #include "mlir/Dialect/LLVMIR/ROCDLDialect.h"
+#include "mlir/Dialect/Utils/IndexingUtils.h"
 #include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
+#include "mlir/Dialect/Vector/Utils/VectorUtils.h"
 #include "mlir/IR/BuiltinTypes.h"
 #include "mlir/IR/PatternMatch.h"
 #include "mlir/IR/TypeUtilities.h"
@@ -32,6 +35,7 @@ using namespace mlir::amdgpu;
 namespace {
 // Define commonly used chipsets versions for convenience.
 constexpr Chipset kGfx942 = Chipset(9, 4, 2);
+constexpr Chipset kGfx950 = Chipset(9, 5, 0);
 
 struct ArithToAMDGPUConversionPass final
     : impl::ArithToAMDGPUConversionPassBase<ArithToAMDGPUConversionPass> {
@@ -73,6 +77,28 @@ struct TruncfToFloat16RewritePattern final
                                 PatternRewriter &rewriter) const override;
 };
 
+struct ScalingExtFRewritePattern final
+    : OpRewritePattern<arith::ScalingExtFOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(arith::ScalingExtFOp op,
+                                PatternRewriter &rewriter) const override;
+};
+
+struct ScalingTruncFRewritePattern final
+    : OpRewritePattern<arith::ScalingTruncFOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(arith::ScalingTruncFOp op,
+                                PatternRewriter &rewriter) const override;
+};
+
 } // end namespace
 
 static bool isSupportedF8(Type elementType, Chipset chipset) {
@@ -395,6 +421,247 @@ LogicalResult TruncfToFloat16RewritePattern::matchAndRewrite(
   return success();
 }
 
+/// Look through a chain of shape_cast / broadcast / splat ops to find the
+/// original value that is being broadcast or splatted.
+static Value getOriginalVectorValue(Value value) {
+  Value current = value;
+  while (Operation *definingOp = current.getDefiningOp()) {
+    bool skipOp = llvm::TypeSwitch<Operation *, bool>(definingOp)
+                      .Case<vector::ShapeCastOp>([&current](auto op) {
+                        current = op.getSource();
+                        return true;
+                      })
+                      .Case<vector::BroadcastOp>([&current](auto op) {
+                        current = op.getSource();
+                        return false;
+                      })
+                      .Case<vector::SplatOp>([&current](auto op) {
+                        current = op.getInput();
+                        return false;
+                      })
+                      .Default([](Operation *) { return false; });
+
+    if (!skipOp) {
+      break;
+    }
+  }
+  return current;
+}
+
+LogicalResult
+ScalingExtFRewritePattern::matchAndRewrite(arith::ScalingExtFOp op,
+                                           PatternRewriter &rewriter) const {
+  Location loc = op.getLoc();
+  constexpr int64_t opWidth = 2;
+
+  Value in = op.getIn();
+  Value scale = op.getScale();
+  Value out = op.getOut();
+
+  Type f32 = rewriter.getF32Type();
+  Type inType = getElementTypeOrSelf(in);
+  Type scaleType = getElementTypeOrSelf(scale);
+  Type outType = getElementTypeOrSelf(out);
+
+  VectorType outVecType = dyn_cast<VectorType>(out.getType());
+  VectorType scaleVecType = dyn_cast<VectorType>(scale.getType());
+
+  if (outVecType && outVecType.isScalable())
+    return failure();
+
+  Type scaleF32Type =
+      scaleVecType ? VectorType::get(scaleVecType.getShape(), f32) : f32;
+  if (scaleType.getIntOrFloatBitWidth() < 32)
+    scale = rewriter.create<arith::ExtFOp>(loc, scaleF32Type, scale);
+  else if (scaleType.getIntOrFloatBitWidth() > 32)
+    scale = rewriter.create<arith::TruncFOp>(loc, scaleF32Type, scale);
+
+  VectorType extScaleResultType = VectorType::get(opWidth, outType);
+
+  if (!outVecType) {
+    Value inCast =
+        rewriter.create<vector::SplatOp>(loc, VectorType::get(1, inType), in);
+    // TODO: replace this with non-packed ScaledExtOp
+    Value scaleExt = rewriter.create<amdgpu::ScaledExtPackedOp>(
+        loc, extScaleResultType, inCast, scale, 0);
+    rewriter.replaceOpWithNewOp<vector::ExtractOp>(op, scaleExt, 0);
+    return success();
+  }
+
+  VectorType inVecType = cast<VectorType>(in.getType());
+  Value origScale = getOriginalVectorValue(op.getScale());
+
+  ArrayRef<int64_t> inShape = inVecType.getShape();
+  SmallVector<int64_t> originalScaleShape;
+  if (auto origScaleVecType = dyn_cast<VectorType>(origScale.getType()))
+    llvm::append_range(originalScaleShape, origScaleVecType.getShape());
+
+  originalScaleShape.insert(originalScaleShape.end(),
+                            inShape.size() - originalScaleShape.size(), 1);
+
+  auto maybeRatio = computeShapeRatio(inShape, originalScaleShape);
+  assert(maybeRatio &&
+         "failed to derive block size from broadcast or splat operation");
+
+  SmallVector<int64_t> ratio =
+      maybeRatio.value_or(SmallVector<int64_t>(inShape.size(), 1));
+
+  int64_t blockSize = computeProduct(ratio);
+
+  Value zero = rewriter.create<arith::ConstantOp>(
+      loc, outType, rewriter.getFloatAttr(outType, 0.0));
+  Value result = rewriter.createOrFold<vector::SplatOp>(loc, outVecType, zero);
+
+  for (SmallVector<int64_t> offsets : StaticTileOffsetRange(inShape, ratio)) {
+    SmallVector<int64_t> strides(offsets.size(), 1);
+    Value block = rewriter.create<vector::ExtractStridedSliceOp>(
+        loc, in, offsets, ratio, strides);
+    VectorType block1DType = VectorType::get(blockSize, inType);
+    Value block1D =
+        rewriter.create<vector::ShapeCastOp>(loc, block1DType, block);
+    Value uniformScale =
+        rewriter.create<vector::ExtractOp>(loc, scale, offsets);
+
+    VectorType blockResultType = VectorType::get(blockSize, outType);
+    Value blockResult =
+        rewriter.createOrFold<vector::SplatOp>(loc, blockResultType, zero);
+
+    for (int64_t i = 0, sliceWidth = std::min(opWidth, blockSize - i);
+         i < blockSize;
+         i += sliceWidth, sliceWidth = std::min(opWidth, blockSize - i)) {
+      Value slice = rewriter.create<vector::ExtractStridedSliceOp>(
+          loc, block1D, i, sliceWidth, 1);
+      // TODO: replace this with non-packed ScaledExtOp for sliceWidth == 1
+      Value scaleExt = rewriter.create<amdgpu::ScaledExtPackedOp>(
+          loc, extScaleResultType, slice, uniformScale, 0);
+      if (sliceWidth != opWidth)
+        scaleExt = rewriter.create<vector::ExtractStridedSliceOp>(
+            loc, scaleExt, 0, sliceWidth, 1);
+      blockResult = rewriter.create<vector::InsertStridedSliceOp>(
+          loc, scaleExt, blockResult, i, 1);
+    }
+
+    VectorType resultType = VectorType::get(ratio, outType);
+    Value cast =
+        rewriter.create<vector::ShapeCastOp>(loc, resultType, blockResult);
+    result = rewriter.create<vector::InsertStridedSliceOp>(loc, cast, result,
+                                                           offsets, strides);
+  }
+
+  rewriter.replaceOp(op, result);
+
+  return success();
+}
+
+LogicalResult
+ScalingTruncFRewritePattern::matchAndRewrite(arith::ScalingTruncFOp op,
+                                             PatternRewriter &rewriter) const {
+  Location loc = op.getLoc();
+  constexpr int64_t opWidth = 2;
+
+  Value in = op.getIn();
+  Value scale = op.getScale();
+  Value out = op.getOut();
+
+  Type f32 = rewriter.getF32Type();
+  Type inType = getElementTypeOrSelf(in);
+  Type scaleType = getElementTypeOrSelf(scale);
+  Type outType = getElementTypeOrSelf(out);
+
+  VectorType outVecType = dyn_cast<VectorType>(out.getType());
+  VectorType scaleVecType = dyn_cast<VectorType>(scale.getType());
+
+  if (outVecType && outVecType.isScalable())
+    return failure();
+
+  Type scaleF32Type =
+      scaleVecType ? VectorType::get(scaleVecType.getShape(), f32) : f32;
+  if (scaleType.getIntOrFloatBitWidth() < 32)
+    scale = rewriter.create<arith::ExtFOp>(loc, scaleF32Type, scale);
+  else if (scaleType.getIntOrFloatBitWidth() > 32)
+    scale = rewriter.create<arith::TruncFOp>(loc, scaleF32Type, scale);
+
+  Value zero = rewriter.create<arith::ConstantOp>(
+      loc, outType, rewriter.getFloatAttr(outType, 0.0));
+  unsigned numPackedElem = 32 / outType.getIntOrFloatBitWidth();
+  VectorType truncScaleResultType = VectorType::get(numPackedElem, outType);
+
+  if (!outVecType) {
+    Type inVecType = VectorType::get(1, inType);
+    Value inCast = rewriter.create<vector::SplatOp>(loc, inVecType, in);
+    // TODO: replace this with non-packed ScaledTruncOp
+    Value scaleTrunc = rewriter.create<amdgpu::PackedScaledTruncOp>(
+        loc, truncScaleResultType, inCast, scale, 0, /*existing=*/nullptr);
+    rewriter.replaceOpWithNewOp<vector::ExtractOp>(op, scaleTrunc, 0);
+    return success();
+  }
+
+  VectorType inVecType = cast<VectorType>(in.getType());
+  Value origScale = getOriginalVectorValue(op.getScale());
+
+  ArrayRef<int64_t> inShape = inVecType.getShape();
+  SmallVector<int64_t> originalScaleShape;
+  if (auto origScaleVecType = dyn_cast<VectorType>(origScale.getType()))
+    llvm::append_range(originalScaleShape, origScaleVecType.getShape());
+
+  originalScaleShape.insert(originalScaleShape.end(),
+                            inShape.size() - originalScaleShape.size(), 1);
+
+  auto maybeRatio = computeShapeRatio(inShape, originalScaleShape);
+  assert(maybeRatio &&
+         "failed to derive block size from broadcast or splat operation");
+
+  SmallVector<int64_t> ratio =
+      maybeRatio.value_or(SmallVector<int64_t>(inShape.size(), 1));
+
+  int64_t blockSize = computeProduct(ratio);
+
+  Value result = rewriter.createOrFold<vector::SplatOp>(loc, outVecType, zero);
+
+  for (SmallVector<int64_t> offsets : StaticTileOffsetRange(inShape, ratio)) {
+    SmallVector<int64_t> strides(offsets.size(), 1);
+    Value block = rewriter.create<vector::ExtractStridedSliceOp>(
+        loc, in, offsets, ratio, strides);
+    VectorType block1DType = VectorType::get(blockSize, inType);
+    Value block1D =
+        rewriter.create<vector::ShapeCastOp>(loc, block1DType, block);
+    Value uniformScale =
+        rewriter.create<vector::ExtractOp>(loc, scale, offsets);
+
+    VectorType blockResultType = VectorType::get(blockSize, outType);
+    Value blockResult =
+        rewriter.createOrFold<vector::SplatOp>(loc, blockResultType, zero);
+
+    for (int64_t i = 0, sliceWidth = std::min(opWidth, blockSize - i);
+         i < blockSize;
+         i += sliceWidth, sliceWidth = std::min(opWidth, blockSize - i)) {
+      Value slice = rewriter.create<vector::ExtractStridedSliceOp>(
+          loc, block1D, i, sliceWidth, 1);
+      // TODO: replace this with non-packed ScaledTruncOp for sliceWidth == 1
+      Value scaleTrunc = rewriter.create<amdgpu::PackedScaledTruncOp>(
+          loc, truncScaleResultType, slice, uniformScale, 0,
+          /*existing=*/nullptr);
+      int64_t packedWidth =
+          cast<VectorType>(scaleTrunc.getType()).getNumElements();
+      if (packedWidth != opWidth)
+        scaleTrunc = rewriter.create<vector::ExtractStridedSliceOp>(
+            loc, scaleTrunc, 0, sliceWidth, 1);
+      blockResult = rewriter.create<vector::InsertStridedSliceOp>(
+          loc, scaleTrunc, blockResult, i, 1);
+    }
+
+    VectorType resultType = VectorType::get(ratio, outType);
+    Value cast =
+        rewriter.create<vector::ShapeCastOp>(loc, resultType, blockResult);
+    result = rewriter.create<vector::InsertStridedSliceOp>(loc, cast, result,
+                                                           offsets, strides);
+  }
+
+  rewriter.replaceOp(op, result);
+
+  return success();
+}
+
 void mlir::arith::populateArithToAMDGPUConversionPatterns(
     RewritePatternSet &patterns, bool convertFP8Arithmetic,
     bool saturateFP8Truncf, bool allowPackedF16Rtz, Chipset chipset) {
@@ -406,6 +673,11 @@ void mlir::arith::populateArithToAMDGPUConversionPatterns(
   }
   if (allowPackedF16Rtz)
     patterns.add<TruncfToFloat16RewritePattern>(patterns.getContext());
+
+  if (chipset >= kGfx950) {
+    patterns.add<ScalingExtFRewritePattern>(patterns.getContext());
+    patterns.add<ScalingTruncFRewritePattern>(patterns.getContext());
+  }
 }
 
 void ArithToAMDGPUConversionPass::runOnOperation() {

>From efc6194b7b664ed2cb9aa175ae92692b87c2fdfe Mon Sep 17 00:00:00 2001
From: Tim Gymnich <tim at gymni.ch>
Date: Mon, 30 Jun 2025 12:18:36 +0000
Subject: [PATCH 2/2] add tests

---
 .../ArithToAMDGPU/scaling-extf.mlir           | 262 ++++++++++++++++++
 .../ArithToAMDGPU/scaling-truncf.mlir         | 193 +++++++++++++
 2 files changed, 455 insertions(+)
 create mode 100644 mlir/test/Conversion/ArithToAMDGPU/scaling-extf.mlir
 create mode 100644 mlir/test/Conversion/ArithToAMDGPU/scaling-truncf.mlir
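
The tests cover the scalar, splat, broadcast, and odd (non-multiple-of-2) block-size
cases, with f8 and f4 inputs and f32/f16 results on the extf side and f8/f4 results
on the truncf side. The block-scaled cases feed the per-block scale in through
vector.broadcast + vector.shape_cast, which the patterns trace back to recover the
block size; excerpted from the @conversion_broadcast extf test below:

  %bc = vector.broadcast %scale : vector<8x2xf8E8M0FNU> to vector<4x8x2xf8E8M0FNU>
  %cast1 = vector.shape_cast %in : vector<8x8xf8E5M2> to vector<8x2x4xf8E5M2>
  %cast2 = vector.shape_cast %bc : vector<4x8x2xf8E8M0FNU> to vector<8x2x4xf8E8M0FNU>
  %ext = arith.scaling_extf %cast1, %cast2 : vector<8x2x4xf8E5M2>, vector<8x2x4xf8E8M0FNU> to vector<8x2x4xf32>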

diff --git a/mlir/test/Conversion/ArithToAMDGPU/scaling-extf.mlir b/mlir/test/Conversion/ArithToAMDGPU/scaling-extf.mlir
new file mode 100644
index 0000000000000..095f3e575eca8
--- /dev/null
+++ b/mlir/test/Conversion/ArithToAMDGPU/scaling-extf.mlir
@@ -0,0 +1,262 @@
+// RUN: mlir-opt --split-input-file %s -convert-arith-to-amdgpu="chipset=gfx950" | FileCheck %s
+
+// CHECK-LABEL: @conversion_f8_f32_fallback
+// CHECK:         %[[CST:.+]] = arith.constant dense<0.000000e+00> : vector<2x2xf32>
+// CHECK-NEXT:    %[[SCALE_EXT:.+]] = arith.extf %arg1 : vector<2x2xf8E8M0FNU> to vector<2x2xf32>
+// CHECK-NEXT:    %[[IN_SLICE_00:.+]] = vector.extract_strided_slice %arg0 {offsets = [0, 0], sizes = [1, 1], strides = [1, 1]} : vector<2x2xf8E5M2> to vector<1x1xf8E5M2>
+// CHECK-NEXT:    %[[IN_VEC_00:.+]] = vector.shape_cast %[[IN_SLICE_00]] : vector<1x1xf8E5M2> to vector<1xf8E5M2>
+// CHECK-NEXT:    %[[SCALE_SCALAR_00:.+]] = vector.extract %[[SCALE_EXT]][0, 0] : f32 from vector<2x2xf32>
+// CHECK-NEXT:    %[[PACKED_00:.+]] = amdgpu.scaled_ext_packed %[[IN_VEC_00]][0], %[[SCALE_SCALAR_00]] : vector<1xf8E5M2> to vector<2xf32>
+// CHECK-NEXT:    %[[OUT_SLICE_00:.+]] = vector.extract_strided_slice %[[PACKED_00]] {offsets = [0], sizes = [1], strides = [1]} : vector<2xf32> to vector<1xf32>
+// CHECK-NEXT:    %[[OUT_VEC_00:.+]] = vector.shape_cast %[[OUT_SLICE_00]] : vector<1xf32> to vector<1x1xf32>
+// CHECK-NEXT:    %[[ACC_A:.+]] = vector.insert_strided_slice %[[OUT_VEC_00]], %[[CST]] {offsets = [0, 0], strides = [1, 1]} : vector<1x1xf32> into vector<2x2xf32>
+// CHECK-NEXT:    %[[IN_SLICE_01:.+]] = vector.extract_strided_slice %arg0 {offsets = [0, 1], sizes = [1, 1], strides = [1, 1]} : vector<2x2xf8E5M2> to vector<1x1xf8E5M2>
+// CHECK-NEXT:    %[[IN_VEC_01:.+]] = vector.shape_cast %[[IN_SLICE_01]] : vector<1x1xf8E5M2> to vector<1xf8E5M2>
+// CHECK-NEXT:    %[[SCALE_SCALAR_01:.+]] = vector.extract %[[SCALE_EXT]][0, 1] : f32 from vector<2x2xf32>
+// CHECK-NEXT:    %[[PACKED_01:.+]] = amdgpu.scaled_ext_packed %[[IN_VEC_01]][0], %[[SCALE_SCALAR_01]] : vector<1xf8E5M2> to vector<2xf32>
+// CHECK-NEXT:    %[[OUT_SLICE_01:.+]] = vector.extract_strided_slice %[[PACKED_01]] {offsets = [0], sizes = [1], strides = [1]} : vector<2xf32> to vector<1xf32>
+// CHECK-NEXT:    %[[OUT_VEC_01:.+]] = vector.shape_cast %[[OUT_SLICE_01]] : vector<1xf32> to vector<1x1xf32>
+// CHECK-NEXT:    %[[ACC_B:.+]] = vector.insert_strided_slice %[[OUT_VEC_01]], %[[ACC_A]] {offsets = [0, 1], strides = [1, 1]} : vector<1x1xf32> into vector<2x2xf32>
+// CHECK-NEXT:    %[[IN_SLICE_10:.+]] = vector.extract_strided_slice %arg0 {offsets = [1, 0], sizes = [1, 1], strides = [1, 1]} : vector<2x2xf8E5M2> to vector<1x1xf8E5M2>
+// CHECK-NEXT:    %[[IN_VEC_10:.+]] = vector.shape_cast %[[IN_SLICE_10]] : vector<1x1xf8E5M2> to vector<1xf8E5M2>
+// CHECK-NEXT:    %[[SCALE_SCALAR_10:.+]] = vector.extract %[[SCALE_EXT]][1, 0] : f32 from vector<2x2xf32>
+// CHECK-NEXT:    %[[PACKED_10:.+]] = amdgpu.scaled_ext_packed %[[IN_VEC_10]][0], %[[SCALE_SCALAR_10]] : vector<1xf8E5M2> to vector<2xf32>
+// CHECK-NEXT:    %[[OUT_SLICE_10:.+]] = vector.extract_strided_slice %[[PACKED_10]] {offsets = [0], sizes = [1], strides = [1]} : vector<2xf32> to vector<1xf32>
+// CHECK-NEXT:    %[[OUT_VEC_10:.+]] = vector.shape_cast %[[OUT_SLICE_10]] : vector<1xf32> to vector<1x1xf32>
+// CHECK-NEXT:    %[[ACC_A:.+]] = vector.insert_strided_slice %[[OUT_VEC_10]], %[[ACC_B]] {offsets = [1, 0], strides = [1, 1]} : vector<1x1xf32> into vector<2x2xf32>
+// CHECK-NEXT:    %[[IN_SLICE_11:.+]] = vector.extract_strided_slice %arg0 {offsets = [1, 1], sizes = [1, 1], strides = [1, 1]} : vector<2x2xf8E5M2> to vector<1x1xf8E5M2>
+// CHECK-NEXT:    %[[IN_VEC_11:.+]] = vector.shape_cast %[[IN_SLICE_11]] : vector<1x1xf8E5M2> to vector<1xf8E5M2>
+// CHECK-NEXT:    %[[SCALE_SCALAR_11:.+]] = vector.extract %[[SCALE_EXT]][1, 1] : f32 from vector<2x2xf32>
+// CHECK-NEXT:    %[[PACKED_11:.+]] = amdgpu.scaled_ext_packed %[[IN_VEC_11]][0], %[[SCALE_SCALAR_11]] : vector<1xf8E5M2> to vector<2xf32>
+// CHECK-NEXT:    %[[OUT_SLICE_11:.+]] = vector.extract_strided_slice %[[PACKED_11]] {offsets = [0], sizes = [1], strides = [1]} : vector<2xf32> to vector<1xf32>
+// CHECK-NEXT:    %[[OUT_VEC_11:.+]] = vector.shape_cast %[[OUT_SLICE_11]] : vector<1xf32> to vector<1x1xf32>
+// CHECK-NEXT:    %[[ACC_B:.+]] = vector.insert_strided_slice %[[OUT_VEC_11]], %[[ACC_A]] {offsets = [1, 1], strides = [1, 1]} : vector<1x1xf32> into vector<2x2xf32>
+// CHECK-NEXT:    return %[[ACC_B]] : vector<2x2xf32>
+func.func @conversion_f8_f32_fallback(%in: vector<2x2xf8E5M2>, %scale: vector<2x2xf8E8M0FNU>) -> vector<2x2xf32> {
+    %ext = arith.scaling_extf %in, %scale : vector<2x2xf8E5M2>, vector<2x2xf8E8M0FNU> to vector<2x2xf32>
+    return %ext : vector<2x2xf32>
+}
+
+// -----
+
+// CHECK-LABEL: @conversion_f4_f32_fallback
+// CHECK:         %[[CST:.+]] = arith.constant dense<0.000000e+00> : vector<2x2xf32>
+// CHECK-NEXT:    %[[SCALE_EXT:.+]] = arith.extf %arg1 : vector<2x2xf8E8M0FNU> to vector<2x2xf32>
+// CHECK-NEXT:    %[[IN_SLICE_00:.+]] = vector.extract_strided_slice %arg0 {offsets = [0, 0], sizes = [1, 1], strides = [1, 1]} : vector<2x2xf4E2M1FN> to vector<1x1xf4E2M1FN>
+// CHECK-NEXT:    %[[IN_VEC_00:.+]] = vector.shape_cast %[[IN_SLICE_00]] : vector<1x1xf4E2M1FN> to vector<1xf4E2M1FN>
+// CHECK-NEXT:    %[[SCALE_SCALAR_00:.+]] = vector.extract %[[SCALE_EXT]][0, 0] : f32 from vector<2x2xf32>
+// CHECK-NEXT:    %[[PACKED_00:.+]] = amdgpu.scaled_ext_packed %[[IN_VEC_00]][0], %[[SCALE_SCALAR_00]] : vector<1xf4E2M1FN> to vector<2xf32>
+// CHECK-NEXT:    %[[OUT_SLICE_00:.+]] = vector.extract_strided_slice %[[PACKED_00]] {offsets = [0], sizes = [1], strides = [1]} : vector<2xf32> to vector<1xf32>
+// CHECK-NEXT:    %[[OUT_VEC_00:.+]] = vector.shape_cast %[[OUT_SLICE_00]] : vector<1xf32> to vector<1x1xf32>
+// CHECK-NEXT:    %[[ACC_A:.+]] = vector.insert_strided_slice %[[OUT_VEC_00]], %[[CST]] {offsets = [0, 0], strides = [1, 1]} : vector<1x1xf32> into vector<2x2xf32>
+// CHECK-NEXT:    %[[IN_SLICE_01:.+]] = vector.extract_strided_slice %arg0 {offsets = [0, 1], sizes = [1, 1], strides = [1, 1]} : vector<2x2xf4E2M1FN> to vector<1x1xf4E2M1FN>
+// CHECK-NEXT:    %[[IN_VEC_01:.+]] = vector.shape_cast %[[IN_SLICE_01]] : vector<1x1xf4E2M1FN> to vector<1xf4E2M1FN>
+// CHECK-NEXT:    %[[SCALE_SCALAR_01:.+]] = vector.extract %[[SCALE_EXT]][0, 1] : f32 from vector<2x2xf32>
+// CHECK-NEXT:    %[[PACKED_01:.+]] = amdgpu.scaled_ext_packed %[[IN_VEC_01]][0], %[[SCALE_SCALAR_01]] : vector<1xf4E2M1FN> to vector<2xf32>
+// CHECK-NEXT:    %[[OUT_SLICE_01:.+]] = vector.extract_strided_slice %[[PACKED_01]] {offsets = [0], sizes = [1], strides = [1]} : vector<2xf32> to vector<1xf32>
+// CHECK-NEXT:    %[[OUT_VEC_01:.+]] = vector.shape_cast %[[OUT_SLICE_01]] : vector<1xf32> to vector<1x1xf32>
+// CHECK-NEXT:    %[[ACC_B:.+]] = vector.insert_strided_slice %[[OUT_VEC_01]], %[[ACC_A]] {offsets = [0, 1], strides = [1, 1]} : vector<1x1xf32> into vector<2x2xf32>
+// CHECK-NEXT:    %[[IN_SLICE_10:.+]] = vector.extract_strided_slice %arg0 {offsets = [1, 0], sizes = [1, 1], strides = [1, 1]} : vector<2x2xf4E2M1FN> to vector<1x1xf4E2M1FN>
+// CHECK-NEXT:    %[[IN_VEC_10:.+]] = vector.shape_cast %[[IN_SLICE_10]] : vector<1x1xf4E2M1FN> to vector<1xf4E2M1FN>
+// CHECK-NEXT:    %[[SCALE_SCALAR_10:.+]] = vector.extract %[[SCALE_EXT]][1, 0] : f32 from vector<2x2xf32>
+// CHECK-NEXT:    %[[PACKED_10:.+]] = amdgpu.scaled_ext_packed %[[IN_VEC_10]][0], %[[SCALE_SCALAR_10]] : vector<1xf4E2M1FN> to vector<2xf32>
+// CHECK-NEXT:    %[[OUT_SLICE_10:.+]] = vector.extract_strided_slice %[[PACKED_10]] {offsets = [0], sizes = [1], strides = [1]} : vector<2xf32> to vector<1xf32>
+// CHECK-NEXT:    %[[OUT_VEC_10:.+]] = vector.shape_cast %[[OUT_SLICE_10]] : vector<1xf32> to vector<1x1xf32>
+// CHECK-NEXT:    %[[ACC_A:.+]] = vector.insert_strided_slice %[[OUT_VEC_10]], %[[ACC_B]] {offsets = [1, 0], strides = [1, 1]} : vector<1x1xf32> into vector<2x2xf32>
+// CHECK-NEXT:    %[[IN_SLICE_11:.+]] = vector.extract_strided_slice %arg0 {offsets = [1, 1], sizes = [1, 1], strides = [1, 1]} : vector<2x2xf4E2M1FN> to vector<1x1xf4E2M1FN>
+// CHECK-NEXT:    %[[IN_VEC_11:.+]] = vector.shape_cast %[[IN_SLICE_11]] : vector<1x1xf4E2M1FN> to vector<1xf4E2M1FN>
+// CHECK-NEXT:    %[[SCALE_SCALAR_11:.+]] = vector.extract %[[SCALE_EXT]][1, 1] : f32 from vector<2x2xf32>
+// CHECK-NEXT:    %[[PACKED_11:.+]] = amdgpu.scaled_ext_packed %[[IN_VEC_11]][0], %[[SCALE_SCALAR_11]] : vector<1xf4E2M1FN> to vector<2xf32>
+// CHECK-NEXT:    %[[OUT_SLICE_11:.+]] = vector.extract_strided_slice %[[PACKED_11]] {offsets = [0], sizes = [1], strides = [1]} : vector<2xf32> to vector<1xf32>
+// CHECK-NEXT:    %[[OUT_VEC_11:.+]] = vector.shape_cast %[[OUT_SLICE_11]] : vector<1xf32> to vector<1x1xf32>
+// CHECK-NEXT:    %[[ACC_B:.+]] = vector.insert_strided_slice %[[OUT_VEC_11]], %[[ACC_A]] {offsets = [1, 1], strides = [1, 1]} : vector<1x1xf32> into vector<2x2xf32>
+// CHECK-NEXT:    return %[[ACC_B]] : vector<2x2xf32>
+func.func @conversion_f4_f32_fallback(%in: vector<2x2xf4E2M1FN>, %scale: vector<2x2xf8E8M0FNU>) -> vector<2x2xf32> {
+    %ext = arith.scaling_extf %in, %scale : vector<2x2xf4E2M1FN>, vector<2x2xf8E8M0FNU> to vector<2x2xf32>
+    return %ext : vector<2x2xf32>
+}
+
+// -----
+
+// CHECK-LABEL: @conversion_f8_f16_fallback
+// CHECK:         %[[CST:.+]] = arith.constant dense<0.000000e+00> : vector<2x2xf16>
+// CHECK-NEXT:    %[[SCALE_EXT:.+]] = arith.extf %arg1 : vector<2x2xf8E8M0FNU> to vector<2x2xf32>
+// CHECK-NEXT:    %[[IN_SLICE_00:.+]] = vector.extract_strided_slice %arg0 {offsets = [0, 0], sizes = [1, 1], strides = [1, 1]} : vector<2x2xf8E5M2> to vector<1x1xf8E5M2>
+// CHECK-NEXT:    %[[IN_VEC_00:.+]] = vector.shape_cast %[[IN_SLICE_00]] : vector<1x1xf8E5M2> to vector<1xf8E5M2>
+// CHECK-NEXT:    %[[SCALE_SCALAR_00:.+]] = vector.extract %[[SCALE_EXT]][0, 0] : f32 from vector<2x2xf32>
+// CHECK-NEXT:    %[[PACKED_00:.+]] = amdgpu.scaled_ext_packed %[[IN_VEC_00]][0], %[[SCALE_SCALAR_00]] : vector<1xf8E5M2> to vector<2xf16>
+// CHECK-NEXT:    %[[OUT_SLICE_00:.+]] = vector.extract_strided_slice %[[PACKED_00]] {offsets = [0], sizes = [1], strides = [1]} : vector<2xf16> to vector<1xf16>
+// CHECK-NEXT:    %[[OUT_VEC_00:.+]] = vector.shape_cast %[[OUT_SLICE_00]] : vector<1xf16> to vector<1x1xf16>
+// CHECK-NEXT:    %[[ACC_A:.+]] = vector.insert_strided_slice %[[OUT_VEC_00]], %[[CST]] {offsets = [0, 0], strides = [1, 1]} : vector<1x1xf16> into vector<2x2xf16>
+// CHECK-NEXT:    %[[IN_SLICE_01:.+]] = vector.extract_strided_slice %arg0 {offsets = [0, 1], sizes = [1, 1], strides = [1, 1]} : vector<2x2xf8E5M2> to vector<1x1xf8E5M2>
+// CHECK-NEXT:    %[[IN_VEC_01:.+]] = vector.shape_cast %[[IN_SLICE_01]] : vector<1x1xf8E5M2> to vector<1xf8E5M2>
+// CHECK-NEXT:    %[[SCALE_SCALAR_01:.+]] = vector.extract %[[SCALE_EXT]][0, 1] : f32 from vector<2x2xf32>
+// CHECK-NEXT:    %[[PACKED_01:.+]] = amdgpu.scaled_ext_packed %[[IN_VEC_01]][0], %[[SCALE_SCALAR_01]] : vector<1xf8E5M2> to vector<2xf16>
+// CHECK-NEXT:    %[[OUT_SLICE_01:.+]] = vector.extract_strided_slice %[[PACKED_01]] {offsets = [0], sizes = [1], strides = [1]} : vector<2xf16> to vector<1xf16>
+// CHECK-NEXT:    %[[OUT_VEC_01:.+]] = vector.shape_cast %[[OUT_SLICE_01]] : vector<1xf16> to vector<1x1xf16>
+// CHECK-NEXT:    %[[ACC_B:.+]] = vector.insert_strided_slice %[[OUT_VEC_01]], %[[ACC_A]] {offsets = [0, 1], strides = [1, 1]} : vector<1x1xf16> into vector<2x2xf16>
+// CHECK-NEXT:    %[[IN_SLICE_10:.+]] = vector.extract_strided_slice %arg0 {offsets = [1, 0], sizes = [1, 1], strides = [1, 1]} : vector<2x2xf8E5M2> to vector<1x1xf8E5M2>
+// CHECK-NEXT:    %[[IN_VEC_10:.+]] = vector.shape_cast %[[IN_SLICE_10]] : vector<1x1xf8E5M2> to vector<1xf8E5M2>
+// CHECK-NEXT:    %[[SCALE_SCALAR_10:.+]] = vector.extract %[[SCALE_EXT]][1, 0] : f32 from vector<2x2xf32>
+// CHECK-NEXT:    %[[PACKED_10:.+]] = amdgpu.scaled_ext_packed %[[IN_VEC_10]][0], %[[SCALE_SCALAR_10]] : vector<1xf8E5M2> to vector<2xf16>
+// CHECK-NEXT:    %[[OUT_SLICE_10:.+]] = vector.extract_strided_slice %[[PACKED_10]] {offsets = [0], sizes = [1], strides = [1]} : vector<2xf16> to vector<1xf16>
+// CHECK-NEXT:    %[[OUT_VEC_10:.+]] = vector.shape_cast %[[OUT_SLICE_10]] : vector<1xf16> to vector<1x1xf16>
+// CHECK-NEXT:    %[[ACC_A:.+]] = vector.insert_strided_slice %[[OUT_VEC_10]], %[[ACC_B]] {offsets = [1, 0], strides = [1, 1]} : vector<1x1xf16> into vector<2x2xf16>
+// CHECK-NEXT:    %[[IN_SLICE_11:.+]] = vector.extract_strided_slice %arg0 {offsets = [1, 1], sizes = [1, 1], strides = [1, 1]} : vector<2x2xf8E5M2> to vector<1x1xf8E5M2>
+// CHECK-NEXT:    %[[IN_VEC_11:.+]] = vector.shape_cast %[[IN_SLICE_11]] : vector<1x1xf8E5M2> to vector<1xf8E5M2>
+// CHECK-NEXT:    %[[SCALE_SCALAR_11:.+]] = vector.extract %[[SCALE_EXT]][1, 1] : f32 from vector<2x2xf32>
+// CHECK-NEXT:    %[[PACKED_11:.+]] = amdgpu.scaled_ext_packed %[[IN_VEC_11]][0], %[[SCALE_SCALAR_11]] : vector<1xf8E5M2> to vector<2xf16>
+// CHECK-NEXT:    %[[OUT_SLICE_11:.+]] = vector.extract_strided_slice %[[PACKED_11]] {offsets = [0], sizes = [1], strides = [1]} : vector<2xf16> to vector<1xf16>
+// CHECK-NEXT:    %[[OUT_VEC_11:.+]] = vector.shape_cast %[[OUT_SLICE_11]] : vector<1xf16> to vector<1x1xf16>
+// CHECK-NEXT:    %[[ACC_B:.+]] = vector.insert_strided_slice %[[OUT_VEC_11]], %[[ACC_A]] {offsets = [1, 1], strides = [1, 1]} : vector<1x1xf16> into vector<2x2xf16>
+// CHECK-NEXT:    return %[[ACC_B]] : vector<2x2xf16>
+func.func @conversion_f8_f16_fallback(%in: vector<2x2xf8E5M2>, %scale: vector<2x2xf8E8M0FNU>) -> vector<2x2xf16> {
+    %ext = arith.scaling_extf %in, %scale : vector<2x2xf8E5M2>, vector<2x2xf8E8M0FNU> to vector<2x2xf16>
+    return %ext : vector<2x2xf16>
+}
+
+// -----
+
+// CHECK-LABEL: @conversion_f4_f16_fallback
+// CHECK:         %[[CST:.+]] = arith.constant dense<0.000000e+00> : vector<2x2xf16>
+// CHECK-NEXT:    %[[SCALE_EXT:.+]] = arith.extf %arg1 : vector<2x2xf8E8M0FNU> to vector<2x2xf32>
+// CHECK-NEXT:    %[[IN_SLICE_00:.+]] = vector.extract_strided_slice %arg0 {offsets = [0, 0], sizes = [1, 1], strides = [1, 1]} : vector<2x2xf4E2M1FN> to vector<1x1xf4E2M1FN>
+// CHECK-NEXT:    %[[IN_VEC_00:.+]] = vector.shape_cast %[[IN_SLICE_00]] : vector<1x1xf4E2M1FN> to vector<1xf4E2M1FN>
+// CHECK-NEXT:    %[[SCALE_SCALAR_00:.+]] = vector.extract %[[SCALE_EXT]][0, 0] : f32 from vector<2x2xf32>
+// CHECK-NEXT:    %[[PACKED_00:.+]] = amdgpu.scaled_ext_packed %[[IN_VEC_00]][0], %[[SCALE_SCALAR_00]] : vector<1xf4E2M1FN> to vector<2xf16>
+// CHECK-NEXT:    %[[OUT_SLICE_00:.+]] = vector.extract_strided_slice %[[PACKED_00]] {offsets = [0], sizes = [1], strides = [1]} : vector<2xf16> to vector<1xf16>
+// CHECK-NEXT:    %[[OUT_VEC_00:.+]] = vector.shape_cast %[[OUT_SLICE_00]] : vector<1xf16> to vector<1x1xf16>
+// CHECK-NEXT:    %[[ACC_A:.+]] = vector.insert_strided_slice %[[OUT_VEC_00]], %[[CST]] {offsets = [0, 0], strides = [1, 1]} : vector<1x1xf16> into vector<2x2xf16>
+// CHECK-NEXT:    %[[IN_SLICE_01:.+]] = vector.extract_strided_slice %arg0 {offsets = [0, 1], sizes = [1, 1], strides = [1, 1]} : vector<2x2xf4E2M1FN> to vector<1x1xf4E2M1FN>
+// CHECK-NEXT:    %[[IN_VEC_01:.+]] = vector.shape_cast %[[IN_SLICE_01]] : vector<1x1xf4E2M1FN> to vector<1xf4E2M1FN>
+// CHECK-NEXT:    %[[SCALE_SCALAR_01:.+]] = vector.extract %[[SCALE_EXT]][0, 1] : f32 from vector<2x2xf32>
+// CHECK-NEXT:    %[[PACKED_01:.+]] = amdgpu.scaled_ext_packed %[[IN_VEC_01]][0], %[[SCALE_SCALAR_01]] : vector<1xf4E2M1FN> to vector<2xf16>
+// CHECK-NEXT:    %[[OUT_SLICE_01:.+]] = vector.extract_strided_slice %[[PACKED_01]] {offsets = [0], sizes = [1], strides = [1]} : vector<2xf16> to vector<1xf16>
+// CHECK-NEXT:    %[[OUT_VEC_01:.+]] = vector.shape_cast %[[OUT_SLICE_01]] : vector<1xf16> to vector<1x1xf16>
+// CHECK-NEXT:    %[[ACC_B:.+]] = vector.insert_strided_slice %[[OUT_VEC_01]], %[[ACC_A]] {offsets = [0, 1], strides = [1, 1]} : vector<1x1xf16> into vector<2x2xf16>
+// CHECK-NEXT:    %[[IN_SLICE_10:.+]] = vector.extract_strided_slice %arg0 {offsets = [1, 0], sizes = [1, 1], strides = [1, 1]} : vector<2x2xf4E2M1FN> to vector<1x1xf4E2M1FN>
+// CHECK-NEXT:    %[[IN_VEC_10:.+]] = vector.shape_cast %[[IN_SLICE_10]] : vector<1x1xf4E2M1FN> to vector<1xf4E2M1FN>
+// CHECK-NEXT:    %[[SCALE_SCALAR_10:.+]] = vector.extract %[[SCALE_EXT]][1, 0] : f32 from vector<2x2xf32>
+// CHECK-NEXT:    %[[PACKED_10:.+]] = amdgpu.scaled_ext_packed %[[IN_VEC_10]][0], %[[SCALE_SCALAR_10]] : vector<1xf4E2M1FN> to vector<2xf16>
+// CHECK-NEXT:    %[[OUT_SLICE_10:.+]] = vector.extract_strided_slice %[[PACKED_10]] {offsets = [0], sizes = [1], strides = [1]} : vector<2xf16> to vector<1xf16>
+// CHECK-NEXT:    %[[OUT_VEC_10:.+]] = vector.shape_cast %[[OUT_SLICE_10]] : vector<1xf16> to vector<1x1xf16>
+// CHECK-NEXT:    %[[ACC_A:.+]] = vector.insert_strided_slice %[[OUT_VEC_10]], %[[ACC_B]] {offsets = [1, 0], strides = [1, 1]} : vector<1x1xf16> into vector<2x2xf16>
+// CHECK-NEXT:    %[[IN_SLICE_11:.+]] = vector.extract_strided_slice %arg0 {offsets = [1, 1], sizes = [1, 1], strides = [1, 1]} : vector<2x2xf4E2M1FN> to vector<1x1xf4E2M1FN>
+// CHECK-NEXT:    %[[IN_VEC_11:.+]] = vector.shape_cast %[[IN_SLICE_11]] : vector<1x1xf4E2M1FN> to vector<1xf4E2M1FN>
+// CHECK-NEXT:    %[[SCALE_SCALAR_11:.+]] = vector.extract %[[SCALE_EXT]][1, 1] : f32 from vector<2x2xf32>
+// CHECK-NEXT:    %[[PACKED_11:.+]] = amdgpu.scaled_ext_packed %[[IN_VEC_11]][0], %[[SCALE_SCALAR_11]] : vector<1xf4E2M1FN> to vector<2xf16>
+// CHECK-NEXT:    %[[OUT_SLICE_11:.+]] = vector.extract_strided_slice %[[PACKED_11]] {offsets = [0], sizes = [1], strides = [1]} : vector<2xf16> to vector<1xf16>
+// CHECK-NEXT:    %[[OUT_VEC_11:.+]] = vector.shape_cast %[[OUT_SLICE_11]] : vector<1xf16> to vector<1x1xf16>
+// CHECK-NEXT:    %[[ACC_B:.+]] = vector.insert_strided_slice %[[OUT_VEC_11]], %[[ACC_A]] {offsets = [1, 1], strides = [1, 1]} : vector<1x1xf16> into vector<2x2xf16>
+// CHECK-NEXT:    return %[[ACC_B]] : vector<2x2xf16>
+func.func @conversion_f4_f16_fallback(%in: vector<2x2xf4E2M1FN>, %scale: vector<2x2xf8E8M0FNU>) -> vector<2x2xf16> {
+    %ext = arith.scaling_extf %in, %scale : vector<2x2xf4E2M1FN>, vector<2x2xf8E8M0FNU> to vector<2x2xf16>
+    return %ext : vector<2x2xf16>
+}
+
+// -----
+
+// CHECK-LABEL: @conversion_broadcast
+// CHECK-DAG:     %[[CST:.+]] = arith.constant dense<0.000000e+00> : vector<8x2x4xf32>
+// CHECK-DAG:     %[[BCAST:.+]] = vector.broadcast %arg1
+// CHECK-DAG:     %[[IN_CAST:.+]] = vector.shape_cast %arg0
+// CHECK-DAG:     %[[SCALE_CAST:.+]] = vector.shape_cast %[[BCAST]]
+// CHECK-DAG:     %[[SCALE_EXT:.+]] = arith.extf %[[SCALE_CAST]]
+// CHECK-DAG:     vector.extract_strided_slice %[[IN_CAST]] {offsets = [0, 0, 0], sizes = [1, 1, 4], strides = [1, 1, 1]}
+// CHECK-NEXT:    vector.shape_cast
+// CHECK-NEXT:    vector.extract %[[SCALE_EXT]][0, 0, 0]
+// CHECK-NEXT:    vector.extract_strided_slice %{{.+}} {offsets = [0], sizes = [2], strides = [1]}
+// CHECK-NEXT:    amdgpu.scaled_ext_packed
+// CHECK-NEXT:    vector.insert_strided_slice %{{.+}}, %{{.+}} {offsets = [0], strides = [1]}
+// CHECK-NEXT:    vector.extract_strided_slice %{{.+}} {offsets = [2], sizes = [2], strides = [1]}
+// CHECK-NEXT:    amdgpu.scaled_ext_packed
+// CHECK-NEXT:    vector.insert_strided_slice %{{.+}}, %{{.+}} {offsets = [2], strides = [1]}
+// CHECK-NEXT:    vector.shape_cast
+// CHECK-NEXT:    vector.insert_strided_slice %{{.+}} {offsets = [0, 0, 0], strides = [1, 1, 1]}
+// CHECK-NEXT:    vector.extract_strided_slice %[[IN_CAST]] {offsets = [0, 1, 0], sizes = [1, 1, 4], strides = [1, 1, 1]}
+// CHECK-NEXT:    vector.shape_cast
+// CHECK-NEXT:    vector.extract %[[SCALE_EXT]][0, 1, 0]
+// CHECK-NEXT:    vector.extract_strided_slice %{{.+}} {offsets = [0], sizes = [2], strides = [1]}
+// CHECK-NEXT:    amdgpu.scaled_ext_packed
+// CHECK-NEXT:    vector.insert_strided_slice %{{.+}}, %{{.+}} {offsets = [0], strides = [1]}
+// CHECK-NEXT:    vector.extract_strided_slice %{{.+}} {offsets = [2], sizes = [2], strides = [1]}
+// CHECK-NEXT:    amdgpu.scaled_ext_packed
+// CHECK-NEXT:    vector.insert_strided_slice %{{.+}}, %{{.+}} {offsets = [2], strides = [1]}
+// CHECK-NEXT:    vector.shape_cast
+// CHECK-NEXT:    vector.insert_strided_slice %{{.+}}, %{{.+}} {offsets = [0, 1, 0], strides = [1, 1, 1]} 
+func.func @conversion_broadcast(%in: vector<8x8xf8E5M2>, %scale: vector<8x2xf8E8M0FNU>) -> vector<8x8xf32> {
+    %bc = vector.broadcast %scale : vector<8x2xf8E8M0FNU> to vector<4x8x2xf8E8M0FNU>
+    %cast1 = vector.shape_cast %in : vector<8x8xf8E5M2> to vector<8x2x4xf8E5M2>
+    %cast2 = vector.shape_cast %bc : vector<4x8x2xf8E8M0FNU> to vector<8x2x4xf8E8M0FNU>
+    %ext = arith.scaling_extf %cast1, %cast2 : vector<8x2x4xf8E5M2>, vector<8x2x4xf8E8M0FNU> to vector<8x2x4xf32>
+    %cast3 = vector.shape_cast %ext : vector<8x2x4xf32> to vector<8x8xf32>
+    return %cast3 : vector<8x8xf32>
+}
+
+// -----
+
+// CHECK-LABEL: @conversion_broadcast_odd
+// CHECK-NEXT:    %[[CST_PARTIAL:.+]] = arith.constant dense<0.000000e+00> : vector<3xf32>
+// CHECK-NEXT:    %[[CST_FINAL:.+]] = arith.constant dense<0.000000e+00> : vector<6xf32>
+// CHECK-NEXT:    %[[SCALE_BC:.+]] = vector.broadcast %arg1 : vector<2xf8E8M0FNU> to vector<3x2xf8E8M0FNU>
+// CHECK-NEXT:    %[[SCALE_FLAT:.+]] = vector.shape_cast %[[SCALE_BC]] : vector<3x2xf8E8M0FNU> to vector<6xf8E8M0FNU>
+// CHECK-NEXT:    %[[SCALE_EXT:.+]] = arith.extf %[[SCALE_FLAT]] : vector<6xf8E8M0FNU> to vector<6xf32>
+// CHECK-NEXT:    %[[IN_SLICE_0:.+]] = vector.extract_strided_slice %arg0 {offsets = [0], sizes = [3], strides = [1]} : vector<6xf8E5M2> to vector<3xf8E5M2>
+// CHECK-NEXT:    %[[SCALE_SCALAR_0:.+]] = vector.extract %[[SCALE_EXT]][0] : f32 from vector<6xf32>
+// CHECK-NEXT:    %[[IN_CHUNK_0A:.+]] = vector.extract_strided_slice %[[IN_SLICE_0]] {offsets = [0], sizes = [2], strides = [1]} : vector<3xf8E5M2> to vector<2xf8E5M2>
+// CHECK-NEXT:    %[[PACKED_0A:.+]] = amdgpu.scaled_ext_packed %[[IN_CHUNK_0A]][0], %[[SCALE_SCALAR_0]] : vector<2xf8E5M2> to vector<2xf32>
+// CHECK-NEXT:    %[[PARTIAL_ACC_0:.+]] = vector.insert_strided_slice %[[PACKED_0A]], %[[CST_PARTIAL]] {offsets = [0], strides = [1]} : vector<2xf32> into vector<3xf32>
+// CHECK-NEXT:    %[[IN_CHUNK_0B:.+]] = vector.extract_strided_slice %[[IN_SLICE_0]] {offsets = [2], sizes = [1], strides = [1]} : vector<3xf8E5M2> to vector<1xf8E5M2>
+// CHECK-NEXT:    %[[PACKED_0B_RAW:.+]] = amdgpu.scaled_ext_packed %[[IN_CHUNK_0B]][0], %[[SCALE_SCALAR_0]] : vector<1xf8E5M2> to vector<2xf32>
+// CHECK-NEXT:    %[[PACKED_0B:.+]] = vector.extract_strided_slice %[[PACKED_0B_RAW]] {offsets = [0], sizes = [1], strides = [1]} : vector<2xf32> to vector<1xf32>
+// CHECK-NEXT:    %[[OUT_SLICE_0:.+]] = vector.insert_strided_slice %[[PACKED_0B]], %[[PARTIAL_ACC_0]] {offsets = [2], strides = [1]} : vector<1xf32> into vector<3xf32>
+// CHECK-NEXT:    %[[FINAL_ACC_A:.+]] = vector.insert_strided_slice %[[OUT_SLICE_0]], %[[CST_FINAL]] {offsets = [0], strides = [1]} : vector<3xf32> into vector<6xf32>
+// CHECK-NEXT:    %[[IN_SLICE_1:.+]] = vector.extract_strided_slice %arg0 {offsets = [3], sizes = [3], strides = [1]} : vector<6xf8E5M2> to vector<3xf8E5M2>
+// CHECK-NEXT:    %[[SCALE_SCALAR_1:.+]] = vector.extract %[[SCALE_EXT]][3] : f32 from vector<6xf32>
+// CHECK-NEXT:    %[[IN_CHUNK_1A:.+]] = vector.extract_strided_slice %[[IN_SLICE_1]] {offsets = [0], sizes = [2], strides = [1]} : vector<3xf8E5M2> to vector<2xf8E5M2>
+// CHECK-NEXT:    %[[PACKED_1A:.+]] = amdgpu.scaled_ext_packed %[[IN_CHUNK_1A]][0], %[[SCALE_SCALAR_1]] : vector<2xf8E5M2> to vector<2xf32>
+// CHECK-NEXT:    %[[PARTIAL_ACC_1:.+]] = vector.insert_strided_slice %[[PACKED_1A]], %[[CST_PARTIAL]] {offsets = [0], strides = [1]} : vector<2xf32> into vector<3xf32>
+// CHECK-NEXT:    %[[IN_CHUNK_1B:.+]] = vector.extract_strided_slice %[[IN_SLICE_1]] {offsets = [2], sizes = [1], strides = [1]} : vector<3xf8E5M2> to vector<1xf8E5M2>
+// CHECK-NEXT:    %[[PACKED_1B_RAW:.+]] = amdgpu.scaled_ext_packed %[[IN_CHUNK_1B]][0], %[[SCALE_SCALAR_1]] : vector<1xf8E5M2> to vector<2xf32>
+// CHECK-NEXT:    %[[PACKED_1B:.+]] = vector.extract_strided_slice %[[PACKED_1B_RAW]] {offsets = [0], sizes = [1], strides = [1]} : vector<2xf32> to vector<1xf32>
+// CHECK-NEXT:    %[[OUT_SLICE_1:.+]] = vector.insert_strided_slice %[[PACKED_1B]], %[[PARTIAL_ACC_1]] {offsets = [2], strides = [1]} : vector<1xf32> into vector<3xf32>
+// CHECK-NEXT:    %[[RESULT:.+]] = vector.insert_strided_slice %[[OUT_SLICE_1]], %[[FINAL_ACC_A]] {offsets = [3], strides = [1]} : vector<3xf32> into vector<6xf32>
+// CHECK-NEXT:    return %[[RESULT]] : vector<6xf32>
+func.func @conversion_broadcast_odd(%in: vector<6xf8E5M2>, %scale: vector<2xf8E8M0FNU>) -> vector<6xf32> {
+    %bc = vector.broadcast %scale : vector<2xf8E8M0FNU> to vector<3x2xf8E8M0FNU>
+    %cast = vector.shape_cast %bc : vector<3x2xf8E8M0FNU> to vector<6xf8E8M0FNU>
+    %ext = arith.scaling_extf %in, %cast : vector<6xf8E5M2>, vector<6xf8E8M0FNU> to vector<6xf32>
+    return %ext : vector<6xf32>
+}
+
+// -----
+// CHECK-LABEL: @conversion_splat
+// CHECK-DAG:     %[[CST:.+]] = arith.constant dense<0.000000e+00> : vector<4xf32>
+// CHECK-DAG:     %[[SCALE_SPLAT:.+]] = vector.splat %arg1 : vector<4xf8E8M0FNU>
+// CHECK-DAG:     %[[SCALE_EXTF:.+]] = arith.extf %[[SCALE_SPLAT]] : vector<4xf8E8M0FNU> to vector<4xf32>
+// CHECK-DAG:     %[[SCALE_SCALAR:.+]] = vector.extract %[[SCALE_EXTF]][0] : f32 from vector<4xf32>
+// CHECK:         %[[IN_CHUNK0:.+]] = vector.extract_strided_slice %arg0 {offsets = [0], sizes = [2], strides = [1]} : vector<4xf8E5M2> to vector<2xf8E5M2>
+// CHECK-NEXT:    %[[OUT_CHUNK0:.+]] = amdgpu.scaled_ext_packed %[[IN_CHUNK0]][0], %[[SCALE_SCALAR]] : vector<2xf8E5M2> to vector<2xf32>
+// CHECK-NEXT:    %[[ACCUM_A:.+]] = vector.insert_strided_slice %[[OUT_CHUNK0]], %[[CST]] {offsets = [0], strides = [1]} : vector<2xf32> into vector<4xf32>
+// CHECK-NEXT:    %[[IN_CHUNK1:.+]] = vector.extract_strided_slice %arg0 {offsets = [2], sizes = [2], strides = [1]} : vector<4xf8E5M2> to vector<2xf8E5M2>
+// CHECK-NEXT:    %[[OUT_CHUNK1:.+]] = amdgpu.scaled_ext_packed %[[IN_CHUNK1]][0], %[[SCALE_SCALAR]] : vector<2xf8E5M2> to vector<2xf32>
+// CHECK-NEXT:    %[[FINAL_RESULT:.+]] = vector.insert_strided_slice %[[OUT_CHUNK1]], %[[ACCUM_A]] {offsets = [2], strides = [1]} : vector<2xf32> into vector<4xf32>
+// CHECK-NEXT:    return %[[FINAL_RESULT]] : vector<4xf32>
+func.func @conversion_splat(%in: vector<4xf8E5M2>, %scale: f8E8M0FNU) -> vector<4xf32> {
+    %splat = vector.splat %scale : vector<4xf8E8M0FNU>
+    %ext = arith.scaling_extf %in, %splat : vector<4xf8E5M2>, vector<4xf8E8M0FNU> to vector<4xf32>
+    return %ext : vector<4xf32>
+}
+
+// -----
+
+// CHECK-LABEL: @conversion_scalar
+// CHECK:         %[[SCALE_F32:.+]] = arith.extf %arg1 : f8E8M0FNU to f32
+// CHECK-NEXT:    %[[SPLAT_IN:.+]] = vector.splat %arg0 : vector<1xf8E5M2>
+// CHECK-NEXT:    %[[PACKED_EXT:.+]] = amdgpu.scaled_ext_packed %[[SPLAT_IN]][0], %[[SCALE_F32]] : vector<1xf8E5M2> to vector<2xf32>
+// CHECK-NEXT:    %[[RESULT:.+]] = vector.extract %[[PACKED_EXT]][0] : f32 from vector<2xf32>
+// CHECK-NEXT:    return %[[RESULT]] : f32
+func.func @conversion_scalar(%in: f8E5M2, %scale: f8E8M0FNU) -> f32 {
+    %ext = arith.scaling_extf %in, %scale : f8E5M2, f8E8M0FNU to f32
+    return %ext : f32
+}
diff --git a/mlir/test/Conversion/ArithToAMDGPU/scaling-truncf.mlir b/mlir/test/Conversion/ArithToAMDGPU/scaling-truncf.mlir
new file mode 100644
index 0000000000000..0519050c5ecc4
--- /dev/null
+++ b/mlir/test/Conversion/ArithToAMDGPU/scaling-truncf.mlir
@@ -0,0 +1,193 @@
+// RUN: mlir-opt --split-input-file %s -convert-arith-to-amdgpu="chipset=gfx950" | FileCheck %s
+
+// CHECK-LABEL: @conversion_f8_fallback
+// CHECK-DAG:     %[[CST:.+]] = arith.constant dense<0.000000e+00> : vector<2x2xf8E5M2>
+// CHECK-DAG:     %[[SCALE_EXT:.+]] = arith.extf %arg1 : vector<2x2xf8E8M0FNU> to vector<2x2xf32>
+// CHECK:         %[[IN_SLICE_00:.+]] = vector.extract_strided_slice %arg0 {offsets = [0, 0], sizes = [1, 1], strides = [1, 1]}
+// CHECK-NEXT:    %[[IN_SCALAR_00:.+]] = vector.shape_cast %[[IN_SLICE_00]]
+// CHECK-NEXT:    %[[SCALE_SCALAR_00:.+]] = vector.extract %[[SCALE_EXT]][0, 0]
+// CHECK-NEXT:    %[[PACKED_00:.+]] = amdgpu.packed_scaled_trunc %[[IN_SCALAR_00]] into undef[0], %[[SCALE_SCALAR_00]]
+// CHECK-NEXT:    %[[OUT_SLICE_00:.+]] = vector.extract_strided_slice %[[PACKED_00]]
+// CHECK-NEXT:    %[[OUT_SCALAR_00:.+]] = vector.shape_cast %[[OUT_SLICE_00]]
+// CHECK-NEXT:    %[[ACC_A:.+]] = vector.insert_strided_slice %[[OUT_SCALAR_00]], %[[CST]]
+// CHECK-NEXT:    %[[IN_SLICE_01:.+]] = vector.extract_strided_slice %arg0 {offsets = [0, 1], sizes = [1, 1], strides = [1, 1]}
+// CHECK-NEXT:    %[[IN_SCALAR_01:.+]] = vector.shape_cast %[[IN_SLICE_01]]
+// CHECK-NEXT:    %[[SCALE_SCALAR_01:.+]] = vector.extract %[[SCALE_EXT]][0, 1]
+// CHECK-NEXT:    %[[PACKED_01:.+]] = amdgpu.packed_scaled_trunc %[[IN_SCALAR_01]] into undef[0], %[[SCALE_SCALAR_01]]
+// CHECK-NEXT:    %[[OUT_SLICE_01:.+]] = vector.extract_strided_slice %[[PACKED_01]]
+// CHECK-NEXT:    %[[OUT_SCALAR_01:.+]] = vector.shape_cast %[[OUT_SLICE_01]]
+// CHECK-NEXT:    %[[ACC_B:.+]] = vector.insert_strided_slice %[[OUT_SCALAR_01]], %[[ACC_A]]
+// CHECK-NEXT:    %[[IN_SLICE_10:.+]] = vector.extract_strided_slice %arg0 {offsets = [1, 0], sizes = [1, 1], strides = [1, 1]}
+// CHECK-NEXT:    %[[IN_SCALAR_10:.+]] = vector.shape_cast %[[IN_SLICE_10]]
+// CHECK-NEXT:    %[[SCALE_SCALAR_10:.+]] = vector.extract %[[SCALE_EXT]][1, 0]
+// CHECK-NEXT:    %[[PACKED_10:.+]] = amdgpu.packed_scaled_trunc %[[IN_SCALAR_10]] into undef[0], %[[SCALE_SCALAR_10]]
+// CHECK-NEXT:    %[[OUT_SLICE_10:.+]] = vector.extract_strided_slice %[[PACKED_10]]
+// CHECK-NEXT:    %[[OUT_SCALAR_10:.+]] = vector.shape_cast %[[OUT_SLICE_10]]
+// CHECK-NEXT:    %[[ACC_A:.+]] = vector.insert_strided_slice %[[OUT_SCALAR_10]], %[[ACC_B]]
+// CHECK-NEXT:    %[[IN_SLICE_11:.+]] = vector.extract_strided_slice %arg0 {offsets = [1, 1], sizes = [1, 1], strides = [1, 1]}
+// CHECK-NEXT:    %[[IN_SCALAR_11:.+]] = vector.shape_cast %[[IN_SLICE_11]]
+// CHECK-NEXT:    %[[SCALE_SCALAR_11:.+]] = vector.extract %[[SCALE_EXT]][1, 1]
+// CHECK-NEXT:    %[[PACKED_11:.+]] = amdgpu.packed_scaled_trunc %[[IN_SCALAR_11]] into undef[0], %[[SCALE_SCALAR_11]]
+// CHECK-NEXT:    %[[OUT_SLICE_11:.+]] = vector.extract_strided_slice %[[PACKED_11]]
+// CHECK-NEXT:    %[[OUT_SCALAR_11:.+]] = vector.shape_cast %[[OUT_SLICE_11]]
+// CHECK-NEXT:    %[[ACC_B:.+]] = vector.insert_strided_slice %[[OUT_SCALAR_11]], %[[ACC_A]]
+// CHECK-NEXT:    return %[[ACC_B]] : vector<2x2xf8E5M2>
+func.func @conversion_f8_fallback(%in: vector<2x2xf32>, %scale: vector<2x2xf8E8M0FNU>) -> vector<2x2xf8E5M2> {
+    %ext = arith.scaling_truncf %in, %scale : vector<2x2xf32>, vector<2x2xf8E8M0FNU> to vector<2x2xf8E5M2>
+    return %ext : vector<2x2xf8E5M2>
+}
+
+// -----
+
+// CHECK-LABEL: @conversion_f4_fallback
+// CHECK-DAG:     %[[CST:.+]] = arith.constant dense<0.000000e+00> : vector<2x2xf4E2M1FN>
+// CHECK-DAG:     %[[SCALE_EXT:.+]] = arith.extf %arg1 : vector<2x2xf8E8M0FNU> to vector<2x2xf32>
+// CHECK:         %[[IN_SLICE_00:.+]] = vector.extract_strided_slice %arg0 {offsets = [0, 0], sizes = [1, 1], strides = [1, 1]}
+// CHECK-NEXT:    %[[IN_SCALAR_00:.+]] = vector.shape_cast %[[IN_SLICE_00]]
+// CHECK-NEXT:    %[[SCALE_SCALAR_00:.+]] = vector.extract %[[SCALE_EXT]][0, 0]
+// CHECK-NEXT:    %[[PACKED_00:.+]] = amdgpu.packed_scaled_trunc %[[IN_SCALAR_00]] into undef[0], %[[SCALE_SCALAR_00]]
+// CHECK-NEXT:    %[[OUT_SLICE_00:.+]] = vector.extract_strided_slice %[[PACKED_00]]
+// CHECK-NEXT:    %[[OUT_SCALAR_00:.+]] = vector.shape_cast %[[OUT_SLICE_00]]
+// CHECK-NEXT:    %[[ACC_A:.+]] = vector.insert_strided_slice %[[OUT_SCALAR_00]], %[[CST]]
+// CHECK-NEXT:    %[[IN_SLICE_01:.+]] = vector.extract_strided_slice %arg0 {offsets = [0, 1], sizes = [1, 1], strides = [1, 1]}
+// CHECK-NEXT:    %[[IN_SCALAR_01:.+]] = vector.shape_cast %[[IN_SLICE_01]]
+// CHECK-NEXT:    %[[SCALE_SCALAR_01:.+]] = vector.extract %[[SCALE_EXT]][0, 1]
+// CHECK-NEXT:    %[[PACKED_01:.+]] = amdgpu.packed_scaled_trunc %[[IN_SCALAR_01]] into undef[0], %[[SCALE_SCALAR_01]]
+// CHECK-NEXT:    %[[OUT_SLICE_01:.+]] = vector.extract_strided_slice %[[PACKED_01]]
+// CHECK-NEXT:    %[[OUT_SCALAR_01:.+]] = vector.shape_cast %[[OUT_SLICE_01]]
+// CHECK-NEXT:    %[[ACC_B:.+]] = vector.insert_strided_slice %[[OUT_SCALAR_01]], %[[ACC_A]]
+// CHECK-NEXT:    %[[IN_SLICE_10:.+]] = vector.extract_strided_slice %arg0 {offsets = [1, 0], sizes = [1, 1], strides = [1, 1]}
+// CHECK-NEXT:    %[[IN_SCALAR_10:.+]] = vector.shape_cast %[[IN_SLICE_10]]
+// CHECK-NEXT:    %[[SCALE_SCALAR_10:.+]] = vector.extract %[[SCALE_EXT]][1, 0]
+// CHECK-NEXT:    %[[PACKED_10:.+]] = amdgpu.packed_scaled_trunc %[[IN_SCALAR_10]] into undef[0], %[[SCALE_SCALAR_10]]
+// CHECK-NEXT:    %[[OUT_SLICE_10:.+]] = vector.extract_strided_slice %[[PACKED_10]]
+// CHECK-NEXT:    %[[OUT_SCALAR_10:.+]] = vector.shape_cast %[[OUT_SLICE_10]]
+// CHECK-NEXT:    %[[ACC_A:.+]] = vector.insert_strided_slice %[[OUT_SCALAR_10]], %[[ACC_B]]
+// CHECK-NEXT:    %[[IN_SLICE_11:.+]] = vector.extract_strided_slice %arg0 {offsets = [1, 1], sizes = [1, 1], strides = [1, 1]}
+// CHECK-NEXT:    %[[IN_SCALAR_11:.+]] = vector.shape_cast %[[IN_SLICE_11]]
+// CHECK-NEXT:    %[[SCALE_SCALAR_11:.+]] = vector.extract %[[SCALE_EXT]][1, 1]
+// CHECK-NEXT:    %[[PACKED_11:.+]] = amdgpu.packed_scaled_trunc %[[IN_SCALAR_11]] into undef[0], %[[SCALE_SCALAR_11]]
+// CHECK-NEXT:    %[[OUT_SLICE_11:.+]] = vector.extract_strided_slice %[[PACKED_11]]
+// CHECK-NEXT:    %[[OUT_SCALAR_11:.+]] = vector.shape_cast %[[OUT_SLICE_11]]
+// CHECK-NEXT:    %[[ACC_B:.+]] = vector.insert_strided_slice %[[OUT_SCALAR_11]], %[[ACC_A]]
+// CHECK-NEXT:    return %[[ACC_B]] : vector<2x2xf4E2M1FN>
+func.func @conversion_f4_fallback(%in: vector<2x2xf32>, %scale: vector<2x2xf8E8M0FNU>) -> vector<2x2xf4E2M1FN> {
+    %ext = arith.scaling_truncf %in, %scale : vector<2x2xf32>, vector<2x2xf8E8M0FNU> to vector<2x2xf4E2M1FN>
+    return %ext : vector<2x2xf4E2M1FN>
+}
+
+// -----
+
+// CHECK-LABEL: @conversion_broadcast
+// CHECK-DAG:     %[[CST:.+]] = arith.constant dense<0.000000e+00> : vector<8x2x4xf8E5M2>
+// CHECK-DAG:     %[[BCAST:.+]] = vector.broadcast %arg1
+// CHECK-DAG:     %[[IN_CAST:.+]] = vector.shape_cast %arg0
+// CHECK-DAG:     %[[SCALE_CAST:.+]] = vector.shape_cast %[[BCAST]]
+// CHECK-DAG:     %[[SCALE_EXT:.+]] = arith.extf %[[SCALE_CAST]]
+// CHECK-DAG:     vector.extract_strided_slice %[[IN_CAST]] {offsets = [0, 0, 0], sizes = [1, 1, 4], strides = [1, 1, 1]}
+// CHECK-NEXT:    vector.shape_cast
+// CHECK-NEXT:    vector.extract %[[SCALE_EXT]][0, 0, 0]
+// CHECK-NEXT:    vector.extract_strided_slice %{{.+}} {offsets = [0], sizes = [2], strides = [1]}
+// CHECK-NEXT:    amdgpu.packed_scaled_trunc
+// CHECK-NEXT:    vector.extract_strided_slice %{{.+}} {offsets = [0], sizes = [2], strides = [1]}
+// CHECK-NEXT:    vector.insert_strided_slice %{{.+}}, %{{.+}} {offsets = [0], strides = [1]}
+// CHECK-NEXT:    vector.extract_strided_slice %{{.+}} {offsets = [2], sizes = [2], strides = [1]}
+// CHECK-NEXT:    amdgpu.packed_scaled_trunc
+// CHECK-NEXT:    vector.extract_strided_slice
+// CHECK-NEXT:    vector.insert_strided_slice %{{.+}}, %{{.+}} {offsets = [2], strides = [1]}
+// CHECK-NEXT:    vector.shape_cast
+// CHECK-NEXT:    vector.insert_strided_slice %{{.+}} {offsets = [0, 0, 0], strides = [1, 1, 1]}
+// CHECK-NEXT:    vector.extract_strided_slice %[[IN_CAST]] {offsets = [0, 1, 0], sizes = [1, 1, 4], strides = [1, 1, 1]}
+// CHECK-NEXT:    vector.shape_cast
+// CHECK-NEXT:    vector.extract %[[SCALE_EXT]][0, 1, 0]
+// CHECK-NEXT:    vector.extract_strided_slice %{{.+}} {offsets = [0], sizes = [2], strides = [1]}
+// CHECK-NEXT:    amdgpu.packed_scaled_trunc
+// CHECK-NEXT:    vector.extract_strided_slice %{{.+}} {offsets = [0], sizes = [2], strides = [1]}
+// CHECK-NEXT:    vector.insert_strided_slice %{{.+}}, %{{.+}} {offsets = [0], strides = [1]}
+// CHECK-NEXT:    vector.extract_strided_slice %{{.+}} {offsets = [2], sizes = [2], strides = [1]}
+// CHECK-NEXT:    amdgpu.packed_scaled_trunc
+// CHECK-NEXT:    vector.extract_strided_slice %{{.+}} {offsets = [0], sizes = [2], strides = [1]}
+// CHECK-NEXT:    vector.insert_strided_slice %{{.+}}, %{{.+}} {offsets = [2], strides = [1]}
+// CHECK-NEXT:    vector.shape_cast
+// CHECK-NEXT:    vector.insert_strided_slice %{{.+}}, %{{.+}} {offsets = [0, 1, 0], strides = [1, 1, 1]} 
+func.func @conversion_broadcast(%in: vector<8x8xf32>, %scale: vector<8x2xf8E8M0FNU>) -> vector<8x8xf8E5M2> {
+    %bc = vector.broadcast %scale : vector<8x2xf8E8M0FNU> to vector<4x8x2xf8E8M0FNU>
+    %cast1 = vector.shape_cast %in : vector<8x8xf32> to vector<8x2x4xf32>
+    %cast2 = vector.shape_cast %bc : vector<4x8x2xf8E8M0FNU> to vector<8x2x4xf8E8M0FNU>
+    %ext = arith.scaling_truncf %cast1, %cast2 : vector<8x2x4xf32>, vector<8x2x4xf8E8M0FNU> to vector<8x2x4xf8E5M2>
+    %cast3 = vector.shape_cast %ext : vector<8x2x4xf8E5M2> to vector<8x8xf8E5M2>
+    return %cast3 : vector<8x8xf8E5M2>
+}
+
+// -----
+
+// CHECK-LABEL: @conversion_broadcast_odd
+// CHECK-NEXT:    %[[CST3:.+]] = arith.constant dense<0.000000e+00> : vector<3xf8E5M2>
+// CHECK-NEXT:    %[[CST6:.+]] = arith.constant dense<0.000000e+00> : vector<6xf8E5M2>
+// CHECK-NEXT:    %[[SCALE_BCAST:.+]] = vector.broadcast %arg1 : vector<2xf8E8M0FNU> to vector<3x2xf8E8M0FNU>
+// CHECK-NEXT:    %[[SCALE_FLAT:.+]] = vector.shape_cast %[[SCALE_BCAST]] : vector<3x2xf8E8M0FNU> to vector<6xf8E8M0FNU>
+// CHECK-NEXT:    %[[SCALE_EXTF:.+]] = arith.extf %[[SCALE_FLAT]] : vector<6xf8E8M0FNU> to vector<6xf32>
+// CHECK-NEXT:    %[[IN_CHUNK0:.+]] = vector.extract_strided_slice %arg0 {offsets = [0], sizes = [3], strides = [1]} : vector<6xf32> to vector<3xf32>
+// CHECK-NEXT:    %[[SCALE0:.+]] = vector.extract %[[SCALE_EXTF]][0] : f32 from vector<6xf32>
+// CHECK-NEXT:    %[[IN_CHUNK0_PART0:.+]] = vector.extract_strided_slice %[[IN_CHUNK0]] {offsets = [0], sizes = [2], strides = [1]} : vector<3xf32> to vector<2xf32>
+// CHECK-NEXT:    %[[PACKED0_PART0:.+]] = amdgpu.packed_scaled_trunc %[[IN_CHUNK0_PART0]] into undef[0], %[[SCALE0]] : vector<2xf32> to vector<4xf8E5M2>
+// CHECK-NEXT:    %[[OUT_CHUNK0_PART0:.+]] = vector.extract_strided_slice %[[PACKED0_PART0]] {offsets = [0], sizes = [2], strides = [1]} : vector<4xf8E5M2> to vector<2xf8E5M2>
+// CHECK-NEXT:    %[[ACCUM0_PART0:.+]] = vector.insert_strided_slice %[[OUT_CHUNK0_PART0]], %[[CST3]] {offsets = [0], strides = [1]} : vector<2xf8E5M2> into vector<3xf8E5M2>
+// CHECK-NEXT:    %[[IN_CHUNK0_PART1:.+]] = vector.extract_strided_slice %[[IN_CHUNK0]] {offsets = [2], sizes = [1], strides = [1]} : vector<3xf32> to vector<1xf32>
+// CHECK-NEXT:    %[[PACKED0_PART1:.+]] = amdgpu.packed_scaled_trunc %[[IN_CHUNK0_PART1]] into undef[0], %[[SCALE0]] : vector<1xf32> to vector<4xf8E5M2>
+// CHECK-NEXT:    %[[OUT_CHUNK0_PART1:.+]] = vector.extract_strided_slice %[[PACKED0_PART1]] {offsets = [0], sizes = [1], strides = [1]} : vector<4xf8E5M2> to vector<1xf8E5M2>
+// CHECK-NEXT:    %[[CHUNK0_RES:.+]] = vector.insert_strided_slice %[[OUT_CHUNK0_PART1]], %[[ACCUM0_PART0]] {offsets = [2], strides = [1]} : vector<1xf8E5M2> into vector<3xf8E5M2>
+// CHECK-NEXT:    %[[FINAL_ACCUM_A:.+]] = vector.insert_strided_slice %[[CHUNK0_RES]], %[[CST6]] {offsets = [0], strides = [1]} : vector<3xf8E5M2> into vector<6xf8E5M2>
+// CHECK-NEXT:    %[[IN_CHUNK1:.+]] = vector.extract_strided_slice %arg0 {offsets = [3], sizes = [3], strides = [1]} : vector<6xf32> to vector<3xf32>
+// CHECK-NEXT:    %[[SCALE1:.+]] = vector.extract %[[SCALE_EXTF]][3] : f32 from vector<6xf32>
+// CHECK-NEXT:    %[[IN_CHUNK1_PART0:.+]] = vector.extract_strided_slice %[[IN_CHUNK1]] {offsets = [0], sizes = [2], strides = [1]} : vector<3xf32> to vector<2xf32>
+// CHECK-NEXT:    %[[PACKED1_PART0:.+]] = amdgpu.packed_scaled_trunc %[[IN_CHUNK1_PART0]] into undef[0], %[[SCALE1]] : vector<2xf32> to vector<4xf8E5M2>
+// CHECK-NEXT:    %[[OUT_CHUNK1_PART0:.+]] = vector.extract_strided_slice %[[PACKED1_PART0]] {offsets = [0], sizes = [2], strides = [1]} : vector<4xf8E5M2> to vector<2xf8E5M2>
+// CHECK-NEXT:    %[[ACCUM1_PART0:.+]] = vector.insert_strided_slice %[[OUT_CHUNK1_PART0]], %[[CST3]] {offsets = [0], strides = [1]} : vector<2xf8E5M2> into vector<3xf8E5M2>
+// CHECK-NEXT:    %[[IN_CHUNK1_PART1:.+]] = vector.extract_strided_slice %[[IN_CHUNK1]] {offsets = [2], sizes = [1], strides = [1]} : vector<3xf32> to vector<1xf32>
+// CHECK-NEXT:    %[[PACKED1_PART1:.+]] = amdgpu.packed_scaled_trunc %[[IN_CHUNK1_PART1]] into undef[0], %[[SCALE1]] : vector<1xf32> to vector<4xf8E5M2>
+// CHECK-NEXT:    %[[OUT_CHUNK1_PART1:.+]] = vector.extract_strided_slice %[[PACKED1_PART1]] {offsets = [0], sizes = [1], strides = [1]} : vector<4xf8E5M2> to vector<1xf8E5M2>
+// CHECK-NEXT:    %[[CHUNK1_RES:.+]] = vector.insert_strided_slice %[[OUT_CHUNK1_PART1]], %[[ACCUM1_PART0]] {offsets = [2], strides = [1]} : vector<1xf8E5M2> into vector<3xf8E5M2>
+// CHECK-NEXT:    %[[FINAL_RESULT:.+]] = vector.insert_strided_slice %[[CHUNK1_RES]], %[[FINAL_ACCUM_A]] {offsets = [3], strides = [1]} : vector<3xf8E5M2> into vector<6xf8E5M2>
+// CHECK-NEXT:    return %[[FINAL_RESULT]] : vector<6xf8E5M2>
+func.func @conversion_broadcast_odd(%in: vector<6xf32>, %scale: vector<2xf8E8M0FNU>) -> vector<6xf8E5M2> {
+    %bc = vector.broadcast %scale : vector<2xf8E8M0FNU> to vector<3x2xf8E8M0FNU>
+    %cast = vector.shape_cast %bc : vector<3x2xf8E8M0FNU> to vector<6xf8E8M0FNU>
+    %ext = arith.scaling_truncf %in, %cast : vector<6xf32>, vector<6xf8E8M0FNU> to vector<6xf8E5M2>
+    return %ext : vector<6xf8E5M2>
+}
+
+// -----
+
+// CHECK-LABEL: @conversion_splat
+// CHECK-DAG:     %[[CST:.+]] = arith.constant dense<0.000000e+00> : vector<4xf8E5M2>
+// CHECK-DAG:     %[[SCALE_SPLAT:.+]] = vector.splat %arg1 : vector<4xf8E8M0FNU>
+// CHECK-DAG:     %[[SCALE_EXTF:.+]] = arith.extf %[[SCALE_SPLAT]] : vector<4xf8E8M0FNU> to vector<4xf32>
+// CHECK-DAG:     %[[SCALE_SCALAR:.+]] = vector.extract %[[SCALE_EXTF]][0] : f32 from vector<4xf32>
+// CHECK:         %[[IN_CHUNK0:.+]] = vector.extract_strided_slice %arg0 {offsets = [0], sizes = [2], strides = [1]} : vector<4xf32> to vector<2xf32>
+// CHECK-NEXT:    %[[PACKED0:.+]] = amdgpu.packed_scaled_trunc %[[IN_CHUNK0]] into undef[0], %[[SCALE_SCALAR]] : vector<2xf32> to vector<4xf8E5M2>
+// CHECK-NEXT:    %[[OUT_CHUNK0:.+]] = vector.extract_strided_slice %[[PACKED0]] {offsets = [0], sizes = [2], strides = [1]} : vector<4xf8E5M2> to vector<2xf8E5M2>
+// CHECK-NEXT:    %[[ACCUM_A:.+]] = vector.insert_strided_slice %[[OUT_CHUNK0]], %[[CST]] {offsets = [0], strides = [1]} : vector<2xf8E5M2> into vector<4xf8E5M2>
+// CHECK-NEXT:    %[[IN_CHUNK1:.+]] = vector.extract_strided_slice %arg0 {offsets = [2], sizes = [2], strides = [1]} : vector<4xf32> to vector<2xf32>
+// CHECK-NEXT:    %[[PACKED1:.+]] = amdgpu.packed_scaled_trunc %[[IN_CHUNK1]] into undef[0], %[[SCALE_SCALAR]] : vector<2xf32> to vector<4xf8E5M2>
+// CHECK-NEXT:    %[[OUT_CHUNK1:.+]] = vector.extract_strided_slice %[[PACKED1]] {offsets = [0], sizes = [2], strides = [1]} : vector<4xf8E5M2> to vector<2xf8E5M2>
+// CHECK-NEXT:    %[[FINAL_RESULT:.+]] = vector.insert_strided_slice %[[OUT_CHUNK1]], %[[ACCUM_A]] {offsets = [2], strides = [1]} : vector<2xf8E5M2> into vector<4xf8E5M2>
+// CHECK-NEXT:    return %[[FINAL_RESULT]] : vector<4xf8E5M2>
+func.func @conversion_splat(%in: vector<4xf32>, %scale: f8E8M0FNU) -> vector<4xf8E5M2> {
+    %splat = vector.splat %scale : vector<4xf8E8M0FNU>
+    %ext = arith.scaling_truncf %in, %splat : vector<4xf32>, vector<4xf8E8M0FNU> to vector<4xf8E5M2>
+    return %ext : vector<4xf8E5M2>
+}
+
+// -----
+
+// CHECK-LABEL: @conversion_scalar
+// CHECK:         %[[SCALE_F32:.+]] = arith.extf %arg1 : f8E8M0FNU to f32
+// CHECK-NEXT:    %[[SPLAT_IN:.+]] = vector.splat %arg0 : vector<1xf32>
+// CHECK-NEXT:    %[[PACKED_TRUNC:.+]] = amdgpu.packed_scaled_trunc %[[SPLAT_IN]] into undef[0], %[[SCALE_F32]]
+// CHECK-NEXT:    %[[RESULT:.+]] = vector.extract %[[PACKED_TRUNC]][0]
+// CHECK-NEXT:    return %[[RESULT]] : f8E5M2
+func.func @conversion_scalar(%in: f32, %scale: f8E8M0FNU) -> f8E5M2 {
+    %ext = arith.scaling_truncf %in, %scale : f32, f8E8M0FNU to f8E5M2
+    return %ext : f8E5M2
+}


