[Mlir-commits] [mlir] [mlir][ArithToAMDGPU] Use native packing support (PR #150342)
Krzysztof Drewniak
llvmlistbot at llvm.org
Wed Jul 23 17:15:23 PDT 2025
https://github.com/krzysz00 created https://github.com/llvm/llvm-project/pull/150342
The current arith-to-amdgpu patterns for scaling_extf and scaling_truncf don't take full advantage of the native packing ability of the intrinsics being targeted. Scaling extension takes the location of the two elements to be extended as a constant argument (a byte for fp4, a half for fp8), and scaling truncation takes a 32-bit input register and a byte or half to write the truncated values to.
Not using these features causes unneeded register pressure; this PR resolves that inefficiency.
It also adds a test for the expected use case of extending or truncating a block of 32 values to/from fp4 with a uniform scale, to ensure that this usage requires a minimal amount of vector shuffling.
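As a rough illustration (a minimal sketch mirroring the op syntax in the updated tests below; the value names %in, %scale, %vals, %acc and the concrete types are made up for the example), the packed extension form takes an element index, and the packed truncation form additionally threads an existing result register through the op:

  // Extend the second pair of packed fp8 values (index 1), using one scale.
  %hi = amdgpu.scaled_ext_packed %in[1], %scale : vector<4xf8E5M2> to vector<2xf32>
  // Truncate two f32 values into slot 1 of an existing packed register,
  // keeping whatever was already written into slot 0 of %acc.
  %packed = amdgpu.packed_scaled_trunc %vals into %acc[1], %scale : vector<2xf32> to vector<4xf8E5M2>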
>From 0644bfaf4062abea138fcaae183c0c08d57d3c2b Mon Sep 17 00:00:00 2001
From: Krzysztof Drewniak <Krzysztof.Drewniak at amd.com>
Date: Thu, 24 Jul 2025 00:10:16 +0000
Subject: [PATCH] [mlir][ArithToAMDGPU] Use native packing support
The current arith-to-amdgpu patterns for scaling_extf and
scaling_truncf don't take full advantage of the native packing ability
of the intrinsics being targeted. Scaling extension takes the
location of the two elements to be extended as a constant
argument (a byte for fp4, a half for fp8), and scaling truncation takes a
32-bit input register and a byte or half to write the truncated values
to.
Not using these features causes unneeded register pressure; this PR
resolves that inefficiency.
It also adds a test for the expected use case of extending or
truncating a block of 32 values to/from fp4 with a uniform scale, to
ensure that this usage requires a minimal amount of vector shuffling.
---
.../ArithToAMDGPU/ArithToAMDGPU.cpp | 100 +++++++++++-------
.../ArithToAMDGPU/scaling-extf.mlir | 46 ++++----
.../ArithToAMDGPU/scaling-truncf.mlir | 60 +++++------
3 files changed, 113 insertions(+), 93 deletions(-)
diff --git a/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp b/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp
index 8c68b57877c35..4cf80167b20c2 100644
--- a/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp
+++ b/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp
@@ -449,7 +449,8 @@ LogicalResult
ScalingExtFRewritePattern::matchAndRewrite(arith::ScalingExtFOp op,
PatternRewriter &rewriter) const {
Location loc = op.getLoc();
- constexpr int64_t opWidth = 2;
+ constexpr int64_t opOutWidth = 2;
+ constexpr int64_t opInWidth = 8;
Value in = op.getIn();
Value scale = op.getScale();
@@ -473,7 +474,7 @@ ScalingExtFRewritePattern::matchAndRewrite(arith::ScalingExtFOp op,
else if (scaleType.getIntOrFloatBitWidth() > 32)
scale = arith::TruncFOp::create(rewriter, loc, scaleF32Type, scale);
- VectorType extScaleResultType = VectorType::get(opWidth, outType);
+ VectorType extScaleResultType = VectorType::get(opOutWidth, outType);
if (!outVecType) {
Value inCast = vector::BroadcastOp::create(rewriter, loc,
@@ -487,10 +488,11 @@ ScalingExtFRewritePattern::matchAndRewrite(arith::ScalingExtFOp op,
VectorType inVecType = cast<VectorType>(in.getType());
Value origScale = getOriginalVectorValue(op.getScale());
+ VectorType origScaleVecType = dyn_cast<VectorType>(origScale.getType());
ArrayRef<int64_t> inShape = inVecType.getShape();
SmallVector<int64_t> originalScaleShape;
- if (auto origScaleVecType = dyn_cast<VectorType>(origScale.getType()))
+ if (origScaleVecType)
llvm::append_range(originalScaleShape, origScaleVecType.getShape());
originalScaleShape.insert(originalScaleShape.end(),
@@ -524,19 +526,26 @@ ScalingExtFRewritePattern::matchAndRewrite(arith::ScalingExtFOp op,
Value blockResult =
rewriter.createOrFold<vector::BroadcastOp>(loc, blockResultType, zero);
- for (int64_t i = 0, sliceWidth = std::min(opWidth, blockSize - i);
+ for (int64_t i = 0, inSliceWidth = std::min(opInWidth, blockSize - i);
i < blockSize;
- i += sliceWidth, sliceWidth = std::min(opWidth, blockSize - i)) {
- Value slice = vector::ExtractStridedSliceOp::create(
- rewriter, loc, block1D, i, sliceWidth, 1);
- // TODO: replace this with non-packed ScaledExtOp for sliceWidth == 1
- Value scaleExt = amdgpu::ScaledExtPackedOp::create(
- rewriter, loc, extScaleResultType, slice, uniformScale, 0);
- if (sliceWidth != opWidth)
- scaleExt = vector::ExtractStridedSliceOp::create(
- rewriter, loc, scaleExt, 0, sliceWidth, 1);
- blockResult = vector::InsertStridedSliceOp::create(
- rewriter, loc, scaleExt, blockResult, i, 1);
+ i += inSliceWidth, inSliceWidth = std::min(opInWidth, blockSize - i)) {
+ Value inSlice = vector::ExtractStridedSliceOp::create(
+ rewriter, loc, block1D, i, inSliceWidth, 1);
+ for (int64_t j = 0,
+ outSliceWidth = std::min(opOutWidth, inSliceWidth - j);
+ j < inSliceWidth; j += outSliceWidth,
+ outSliceWidth = std::min(opOutWidth, inSliceWidth - j)) {
+ // TODO: replace this with non-packed ScaledExtOp for sliceWidth == 1
+ Value scaleExt = amdgpu::ScaledExtPackedOp::create(
+ rewriter, loc, extScaleResultType, inSlice, uniformScale,
+ j / opOutWidth);
+ if (outSliceWidth < opOutWidth) {
+ scaleExt = vector::ExtractStridedSliceOp::create(
+ rewriter, loc, scaleExt, 0, outSliceWidth, 1);
+ }
+ blockResult = vector::InsertStridedSliceOp::create(
+ rewriter, loc, scaleExt, blockResult, i + j, 1);
+ }
}
VectorType resultType = VectorType::get(ratio, outType);
@@ -555,7 +564,7 @@ LogicalResult
ScalingTruncFRewritePattern::matchAndRewrite(arith::ScalingTruncFOp op,
PatternRewriter &rewriter) const {
Location loc = op.getLoc();
- constexpr int64_t opWidth = 2;
+ constexpr int64_t opInWidth = 2;
Value in = op.getIn();
Value scale = op.getScale();
@@ -568,7 +577,6 @@ ScalingTruncFRewritePattern::matchAndRewrite(arith::ScalingTruncFOp op,
VectorType outVecType = dyn_cast<VectorType>(out.getType());
VectorType scaleVecType = dyn_cast<VectorType>(scale.getType());
-
if (outVecType && outVecType.isScalable())
return failure();
@@ -581,8 +589,8 @@ ScalingTruncFRewritePattern::matchAndRewrite(arith::ScalingTruncFOp op,
Value zero = arith::ConstantOp::create(rewriter, loc, outType,
rewriter.getFloatAttr(outType, 0.0));
- unsigned numPackedElem = 32 / outType.getIntOrFloatBitWidth();
- VectorType truncScaleResultType = VectorType::get(numPackedElem, outType);
+ int64_t opOutWidth = 32 / outType.getIntOrFloatBitWidth();
+ VectorType truncScaleResultType = VectorType::get(opOutWidth, outType);
if (!outVecType) {
Type inVecType = VectorType::get(1, inType);
@@ -598,16 +606,16 @@ ScalingTruncFRewritePattern::matchAndRewrite(arith::ScalingTruncFOp op,
VectorType inVecType = cast<VectorType>(in.getType());
Value origScale = getOriginalVectorValue(op.getScale());
+ VectorType origScaleVecType = dyn_cast<VectorType>(origScale.getType());
ArrayRef<int64_t> inShape = inVecType.getShape();
- SmallVector<int64_t> originalScaleShape;
- if (auto origScaleVecType = dyn_cast<VectorType>(origScale.getType()))
- llvm::append_range(originalScaleShape, origScaleVecType.getShape());
+ SmallVector<int64_t> scaleShape;
+ if (origScaleVecType)
+ llvm::append_range(scaleShape, origScaleVecType.getShape());
- originalScaleShape.insert(originalScaleShape.end(),
- inShape.size() - originalScaleShape.size(), 1);
+ scaleShape.insert(scaleShape.end(), inShape.size() - scaleShape.size(), 1);
- auto maybeRatio = computeShapeRatio(inShape, originalScaleShape);
+ auto maybeRatio = computeShapeRatio(inShape, scaleShape);
assert(maybeRatio &&
"failed to derive block size from broadcast or splat operation");
@@ -633,20 +641,36 @@ ScalingTruncFRewritePattern::matchAndRewrite(arith::ScalingTruncFOp op,
Value blockResult =
rewriter.createOrFold<vector::BroadcastOp>(loc, blockResultType, zero);
- for (int64_t i = 0, sliceWidth = std::min(opWidth, blockSize - i);
- i < blockSize;
- i += sliceWidth, sliceWidth = std::min(opWidth, blockSize - i)) {
- Value slice = vector::ExtractStridedSliceOp::create(
- rewriter, loc, block1D, i, sliceWidth, 1);
- // TODO: replace this with non-packed ScaledTruncOp for sliceWidth == 1
- Value scaleTrunc = amdgpu::PackedScaledTruncOp::create(
- rewriter, loc, truncScaleResultType, slice, uniformScale, 0,
- /*existing=*/nullptr);
- int64_t packedWidth =
- cast<VectorType>(scaleTrunc.getType()).getNumElements();
- if (packedWidth != opWidth)
+ for (int64_t i = 0, outSliceWidth = std::min(opOutWidth, blockSize - i);
+ i < blockSize; i += outSliceWidth,
+ outSliceWidth = std::min(opOutWidth, blockSize - i)) {
+ Value scaleTrunc;
+ // Case where <= 2 elements are being truncated.
+ if (outSliceWidth <= opInWidth) {
+ Value slice = vector::ExtractStridedSliceOp::create(
+ rewriter, loc, block1D, i, outSliceWidth, 1);
+ // TODO: replace this with non-packed ScaledTruncOp for sliceWidth == 1
+ scaleTrunc = amdgpu::PackedScaledTruncOp::create(
+ rewriter, loc, truncScaleResultType, slice, uniformScale, 0,
+ /*existing=*/nullptr);
+ } else {
+ scaleTrunc = vector::BroadcastOp::create(rewriter, loc,
+ truncScaleResultType, zero);
+ for (int64_t j = 0,
+ inSliceWidth = std::min(opInWidth, outSliceWidth - j);
+ j < outSliceWidth; j += opInWidth,
+ inSliceWidth = std::min(opInWidth, outSliceWidth - j)) {
+ Value slice = vector::ExtractStridedSliceOp::create(
+ rewriter, loc, block1D, i + j, inSliceWidth, 1);
+ scaleTrunc = amdgpu::PackedScaledTruncOp::create(
+ rewriter, loc, truncScaleResultType, slice, uniformScale,
+ j / opInWidth, scaleTrunc);
+ }
+ }
+ if (outSliceWidth != opOutWidth) {
scaleTrunc = vector::ExtractStridedSliceOp::create(
- rewriter, loc, scaleTrunc, 0, sliceWidth, 1);
+ rewriter, loc, scaleTrunc, 0, outSliceWidth, 1);
+ }
blockResult = vector::InsertStridedSliceOp::create(
rewriter, loc, scaleTrunc, blockResult, i, 1);
}
diff --git a/mlir/test/Conversion/ArithToAMDGPU/scaling-extf.mlir b/mlir/test/Conversion/ArithToAMDGPU/scaling-extf.mlir
index b98045195f8cf..a837bdb8be4fa 100644
--- a/mlir/test/Conversion/ArithToAMDGPU/scaling-extf.mlir
+++ b/mlir/test/Conversion/ArithToAMDGPU/scaling-extf.mlir
@@ -163,27 +163,23 @@ func.func @conversion_f4_f16_fallback(%in: vector<2x2xf4E2M1FN>, %scale: vector<
// CHECK-DAG: %[[SCALE_CAST:.+]] = vector.shape_cast %[[BCAST]]
// CHECK-DAG: %[[SCALE_EXT:.+]] = arith.extf %[[SCALE_CAST]]
// CHECK-DAG: vector.extract_strided_slice %[[IN_CAST]] {offsets = [0, 0, 0], sizes = [1, 1, 4], strides = [1, 1, 1]}
-// CHECK-NEXT: vector.shape_cast
+// CHECK-NEXT: %[[IN_SLICE_CAST:.+]] = vector.shape_cast
// CHECK-NEXT: vector.extract %[[SCALE_EXT]][0, 0, 0]
-// CHECK-NEXT: vector.extract_strided_slice %{{.+}} {offsets = [0], sizes = [2], strides = [1]}
-// CHECK-NEXT: amdgpu.scaled_ext_packed
-// CHECK-NEXT: vector.insert_strided_slice %{{.+}}, %{{.+}} {offsets = [0], strides = [1]}
-// CHECK-NEXT: vector.extract_strided_slice %{{.+}} {offsets = [2], sizes = [2], strides = [1]}
-// CHECK-NEXT: amdgpu.scaled_ext_packed
-// CHECK-NEXT: vector.insert_strided_slice %{{.+}}, %{{.+}} {offsets = [2], strides = [1]}
+// CHECK-NEXT: %[[LOWHALF:.+]] = amdgpu.scaled_ext_packed %[[IN_SLICE_CAST]][0]
+// CHECK-NEXT: vector.insert_strided_slice %[[LOWHALF]], %{{.+}} {offsets = [0], strides = [1]}
+// CHECK-NEXT: %[[HIGHHALF:.+]] = amdgpu.scaled_ext_packed %[[IN_SLICE_CAST]][1]
+// CHECK-NEXT: vector.insert_strided_slice %[[HIGHHALF]], %{{.+}} {offsets = [2], strides = [1]}
// CHECK-NEXT: vector.shape_cast
// CHECK-NEXT: vector.insert_strided_slice %{{.+}} {offsets = [0, 0, 0], strides = [1, 1, 1]}
// CHECK-NEXT: vector.extract_strided_slice %[[IN_CAST]] {offsets = [0, 1, 0], sizes = [1, 1, 4], strides = [1, 1, 1]}
// CHECK-NEXT: vector.shape_cast
// CHECK-NEXT: vector.extract %[[SCALE_EXT]][0, 1, 0]
-// CHECK-NEXT: vector.extract_strided_slice %{{.+}} {offsets = [0], sizes = [2], strides = [1]}
// CHECK-NEXT: amdgpu.scaled_ext_packed
// CHECK-NEXT: vector.insert_strided_slice %{{.+}}, %{{.+}} {offsets = [0], strides = [1]}
-// CHECK-NEXT: vector.extract_strided_slice %{{.+}} {offsets = [2], sizes = [2], strides = [1]}
// CHECK-NEXT: amdgpu.scaled_ext_packed
// CHECK-NEXT: vector.insert_strided_slice %{{.+}}, %{{.+}} {offsets = [2], strides = [1]}
// CHECK-NEXT: vector.shape_cast
-// CHECK-NEXT: vector.insert_strided_slice %{{.+}}, %{{.+}} {offsets = [0, 1, 0], strides = [1, 1, 1]}
+// CHECK-NEXT: vector.insert_strided_slice %{{.+}}, %{{.+}} {offsets = [0, 1, 0], strides = [1, 1, 1]}
func.func @conversion_broadcast(%in: vector<8x8xf8E5M2>, %scale: vector<8x2xf8E8M0FNU>) -> vector<8x8xf32> {
%bc = vector.broadcast %scale : vector<8x2xf8E8M0FNU> to vector<4x8x2xf8E8M0FNU>
%cast1 = vector.shape_cast %in : vector<8x8xf8E5M2> to vector<8x2x4xf8E5M2>
@@ -203,21 +199,17 @@ func.func @conversion_broadcast(%in: vector<8x8xf8E5M2>, %scale: vector<8x2xf8E8
// CHECK-NEXT: %[[SCALE_EXT:.+]] = arith.extf %[[SCALE_FLAT]] : vector<6xf8E8M0FNU> to vector<6xf32>
// CHECK-NEXT: %[[IN_SLICE_0:.+]] = vector.extract_strided_slice %arg0 {offsets = [0], sizes = [3], strides = [1]} : vector<6xf8E5M2> to vector<3xf8E5M2>
// CHECK-NEXT: %[[SCALE_SCALAR_0:.+]] = vector.extract %[[SCALE_EXT]][0] : f32 from vector<6xf32>
-// CHECK-NEXT: %[[IN_CHUNK_0A:.+]] = vector.extract_strided_slice %[[IN_SLICE_0]] {offsets = [0], sizes = [2], strides = [1]} : vector<3xf8E5M2> to vector<2xf8E5M2>
-// CHECK-NEXT: %[[PACKED_0A:.+]] = amdgpu.scaled_ext_packed %[[IN_CHUNK_0A]][0], %[[SCALE_SCALAR_0]] : vector<2xf8E5M2> to vector<2xf32>
+// CHECK-NEXT: %[[PACKED_0A:.+]] = amdgpu.scaled_ext_packed %[[IN_SLICE_0]][0], %[[SCALE_SCALAR_0]] : vector<3xf8E5M2> to vector<2xf32>
// CHECK-NEXT: %[[PARTIAL_ACC_0:.+]] = vector.insert_strided_slice %[[PACKED_0A]], %[[CST_PARTIAL]] {offsets = [0], strides = [1]} : vector<2xf32> into vector<3xf32>
-// CHECK-NEXT: %[[IN_CHUNK_0B:.+]] = vector.extract_strided_slice %[[IN_SLICE_0]] {offsets = [2], sizes = [1], strides = [1]} : vector<3xf8E5M2> to vector<1xf8E5M2>
-// CHECK-NEXT: %[[PACKED_0B_RAW:.+]] = amdgpu.scaled_ext_packed %[[IN_CHUNK_0B]][0], %[[SCALE_SCALAR_0]] : vector<1xf8E5M2> to vector<2xf32>
+// CHECK-NEXT: %[[PACKED_0B_RAW:.+]] = amdgpu.scaled_ext_packed %[[IN_SLICE_0]][1], %[[SCALE_SCALAR_0]] : vector<3xf8E5M2> to vector<2xf32>
// CHECK-NEXT: %[[PACKED_0B:.+]] = vector.extract_strided_slice %[[PACKED_0B_RAW]] {offsets = [0], sizes = [1], strides = [1]} : vector<2xf32> to vector<1xf32>
// CHECK-NEXT: %[[OUT_SLICE_0:.+]] = vector.insert_strided_slice %[[PACKED_0B]], %[[PARTIAL_ACC_0]] {offsets = [2], strides = [1]} : vector<1xf32> into vector<3xf32>
// CHECK-NEXT: %[[FINAL_ACC_A:.+]] = vector.insert_strided_slice %[[OUT_SLICE_0]], %[[CST_FINAL]] {offsets = [0], strides = [1]} : vector<3xf32> into vector<6xf32>
// CHECK-NEXT: %[[IN_SLICE_1:.+]] = vector.extract_strided_slice %arg0 {offsets = [3], sizes = [3], strides = [1]} : vector<6xf8E5M2> to vector<3xf8E5M2>
// CHECK-NEXT: %[[SCALE_SCALAR_1:.+]] = vector.extract %[[SCALE_EXT]][3] : f32 from vector<6xf32>
-// CHECK-NEXT: %[[IN_CHUNK_1A:.+]] = vector.extract_strided_slice %[[IN_SLICE_1]] {offsets = [0], sizes = [2], strides = [1]} : vector<3xf8E5M2> to vector<2xf8E5M2>
-// CHECK-NEXT: %[[PACKED_1A:.+]] = amdgpu.scaled_ext_packed %[[IN_CHUNK_1A]][0], %[[SCALE_SCALAR_1]] : vector<2xf8E5M2> to vector<2xf32>
+// CHECK-NEXT: %[[PACKED_1A:.+]] = amdgpu.scaled_ext_packed %[[IN_SLICE_1]][0], %[[SCALE_SCALAR_1]] : vector<3xf8E5M2> to vector<2xf32>
// CHECK-NEXT: %[[PARTIAL_ACC_1:.+]] = vector.insert_strided_slice %[[PACKED_1A]], %[[CST_PARTIAL]] {offsets = [0], strides = [1]} : vector<2xf32> into vector<3xf32>
-// CHECK-NEXT: %[[IN_CHUNK_1B:.+]] = vector.extract_strided_slice %[[IN_SLICE_1]] {offsets = [2], sizes = [1], strides = [1]} : vector<3xf8E5M2> to vector<1xf8E5M2>
-// CHECK-NEXT: %[[PACKED_1B_RAW:.+]] = amdgpu.scaled_ext_packed %[[IN_CHUNK_1B]][0], %[[SCALE_SCALAR_1]] : vector<1xf8E5M2> to vector<2xf32>
+// CHECK-NEXT: %[[PACKED_1B_RAW:.+]] = amdgpu.scaled_ext_packed %[[IN_SLICE_1]][1], %[[SCALE_SCALAR_1]] : vector<3xf8E5M2> to vector<2xf32>
// CHECK-NEXT: %[[PACKED_1B:.+]] = vector.extract_strided_slice %[[PACKED_1B_RAW]] {offsets = [0], sizes = [1], strides = [1]} : vector<2xf32> to vector<1xf32>
// CHECK-NEXT: %[[OUT_SLICE_1:.+]] = vector.insert_strided_slice %[[PACKED_1B]], %[[PARTIAL_ACC_1]] {offsets = [2], strides = [1]} : vector<1xf32> into vector<3xf32>
// CHECK-NEXT: %[[RESULT:.+]] = vector.insert_strided_slice %[[OUT_SLICE_1]], %[[FINAL_ACC_A]] {offsets = [3], strides = [1]} : vector<3xf32> into vector<6xf32>
@@ -236,11 +228,9 @@ func.func @conversion_broadcast_odd(%in: vector<6xf8E5M2>, %scale: vector<2xf8E8
// CHECK-DAG: %[[SCALE_SPLAT:.+]] = vector.broadcast %arg1 : f8E8M0FNU to vector<4xf8E8M0FNU>
// CHECK-DAG: %[[SCALE_EXTF:.+]] = arith.extf %[[SCALE_SPLAT]] : vector<4xf8E8M0FNU> to vector<4xf32>
// CHECK-DAG: %[[SCALE_SCALAR:.+]] = vector.extract %[[SCALE_EXTF]][0] : f32 from vector<4xf32>
-// CHECK: %[[IN_CHUNK0:.+]] = vector.extract_strided_slice %arg0 {offsets = [0], sizes = [2], strides = [1]} : vector<4xf8E5M2> to vector<2xf8E5M2>
-// CHECK-NEXT: %[[OUT_CHUNK0:.+]] = amdgpu.scaled_ext_packed %[[IN_CHUNK0]][0], %[[SCALE_SCALAR]] : vector<2xf8E5M2> to vector<2xf32>
+// CHECK: %[[OUT_CHUNK0:.+]] = amdgpu.scaled_ext_packed %arg0[0], %[[SCALE_SCALAR]] : vector<4xf8E5M2> to vector<2xf32>
// CHECK-NEXT: %[[ACCUM_A:.+]] = vector.insert_strided_slice %[[OUT_CHUNK0]], %[[CST]] {offsets = [0], strides = [1]} : vector<2xf32> into vector<4xf32>
-// CHECK-NEXT: %[[IN_CHUNK1:.+]] = vector.extract_strided_slice %arg0 {offsets = [2], sizes = [2], strides = [1]} : vector<4xf8E5M2> to vector<2xf8E5M2>
-// CHECK-NEXT: %[[OUT_CHUNK1:.+]] = amdgpu.scaled_ext_packed %[[IN_CHUNK1]][0], %[[SCALE_SCALAR]] : vector<2xf8E5M2> to vector<2xf32>
+// CHECK-NEXT: %[[OUT_CHUNK1:.+]] = amdgpu.scaled_ext_packed %arg0[1], %[[SCALE_SCALAR]] : vector<4xf8E5M2> to vector<2xf32>
// CHECK-NEXT: %[[FINAL_RESULT:.+]] = vector.insert_strided_slice %[[OUT_CHUNK1]], %[[ACCUM_A]] {offsets = [2], strides = [1]} : vector<2xf32> into vector<4xf32>
// CHECK-NEXT: return %[[FINAL_RESULT]] : vector<4xf32>
func.func @conversion_broadcast(%in: vector<4xf8E5M2>, %scale: f8E8M0FNU) -> vector<4xf32> {
@@ -261,3 +251,15 @@ func.func @conversion_scalar(%in: f8E5M2, %scale: f8E8M0FNU) -> f32 {
%ext = arith.scaling_extf %in, %scale : f8E5M2, f8E8M0FNU to f32
return %ext : f32
}
+
+// -----
+
+// CHECK-LABEL: @long_fp4_broadcast
+// CHECK-COUNT-4: amdgpu.scaled_ext_packed %{{.+}}[3]
+// CHECK-NOT: amdgpu.scaled_ext_packed
+// CHECK: return
+func.func @long_fp4_broadcast(%in: vector<32xf4E2M1FN>, %scale: f32) -> vector<32xf32> {
+ %splat = vector.broadcast %scale : f32 to vector<32xf32>
+ %ext = arith.scaling_extf %in, %splat : vector<32xf4E2M1FN>, vector<32xf32> to vector<32xf32>
+ return %ext : vector<32xf32>
+}
diff --git a/mlir/test/Conversion/ArithToAMDGPU/scaling-truncf.mlir b/mlir/test/Conversion/ArithToAMDGPU/scaling-truncf.mlir
index 488e75cbb1843..6d6e1e28d2c2c 100644
--- a/mlir/test/Conversion/ArithToAMDGPU/scaling-truncf.mlir
+++ b/mlir/test/Conversion/ArithToAMDGPU/scaling-truncf.mlir
@@ -88,28 +88,20 @@ func.func @conversion_f4_fallback(%in: vector<2x2xf32>, %scale: vector<2x2xf8E8M
// CHECK-NEXT: vector.shape_cast
// CHECK-NEXT: vector.extract %[[SCALE_EXT]][0, 0, 0]
// CHECK-NEXT: vector.extract_strided_slice %{{.+}} {offsets = [0], sizes = [2], strides = [1]}
-// CHECK-NEXT: amdgpu.packed_scaled_trunc
-// CHECK-NEXT: vector.extract_strided_slice %{{.+}} {offsets = [0], sizes = [2], strides = [1]}
-// CHECK-NEXT: vector.insert_strided_slice %{{.+}}, %{{.+}} {offsets = [0], strides = [1]}
+// CHECK-NEXT: %[[P1:.+]] = amdgpu.packed_scaled_trunc
// CHECK-NEXT: vector.extract_strided_slice %{{.+}} {offsets = [2], sizes = [2], strides = [1]}
-// CHECK-NEXT: amdgpu.packed_scaled_trunc
-// CHECK-NEXT: vector.extract_strided_slice
-// CHECK-NEXT: vector.insert_strided_slice %{{.+}}, %{{.+}} {offsets = [2], strides = [1]}
-// CHECK-NEXT: vector.shape_cast
-// CHECK-NEXT: vector.insert_strided_slice %{{.+}} {offsets = [0, 0, 0], strides = [1, 1, 1]}
+// CHECK-NEXT: %[[P2:.+]] = amdgpu.packed_scaled_trunc {{.*}} into %[[P1]][1]
+// CHECK-NEXT: %[[P2_CAST:.+]] = vector.shape_cast %[[P2]] : vector<4xf8E5M2> to vector<1x1x4xf8E5M2>
+// CHECK-NEXT: vector.insert_strided_slice %[[P2_CAST]], %{{.+}} {offsets = [0, 0, 0], strides = [1, 1, 1]}
// CHECK-NEXT: vector.extract_strided_slice %[[IN_CAST]] {offsets = [0, 1, 0], sizes = [1, 1, 4], strides = [1, 1, 1]}
// CHECK-NEXT: vector.shape_cast
// CHECK-NEXT: vector.extract %[[SCALE_EXT]][0, 1, 0]
// CHECK-NEXT: vector.extract_strided_slice %{{.+}} {offsets = [0], sizes = [2], strides = [1]}
// CHECK-NEXT: amdgpu.packed_scaled_trunc
-// CHECK-NEXT: vector.extract_strided_slice %{{.+}} {offsets = [0], sizes = [2], strides = [1]}
-// CHECK-NEXT: vector.insert_strided_slice %{{.+}}, %{{.+}} {offsets = [0], strides = [1]}
// CHECK-NEXT: vector.extract_strided_slice %{{.+}} {offsets = [2], sizes = [2], strides = [1]}
// CHECK-NEXT: amdgpu.packed_scaled_trunc
-// CHECK-NEXT: vector.extract_strided_slice %{{.+}} {offsets = [0], sizes = [2], strides = [1]}
-// CHECK-NEXT: vector.insert_strided_slice %{{.+}}, %{{.+}} {offsets = [2], strides = [1]}
// CHECK-NEXT: vector.shape_cast
-// CHECK-NEXT: vector.insert_strided_slice %{{.+}}, %{{.+}} {offsets = [0, 1, 0], strides = [1, 1, 1]}
+// CHECK-NEXT: vector.insert_strided_slice %{{.+}}, %{{.+}} {offsets = [0, 1, 0], strides = [1, 1, 1]}
func.func @conversion_broadcast(%in: vector<8x8xf32>, %scale: vector<8x2xf8E8M0FNU>) -> vector<8x8xf8E5M2> {
%bc = vector.broadcast %scale : vector<8x2xf8E8M0FNU> to vector<4x8x2xf8E8M0FNU>
%cast1 = vector.shape_cast %in : vector<8x8xf32> to vector<8x2x4xf32>
@@ -122,7 +114,7 @@ func.func @conversion_broadcast(%in: vector<8x8xf32>, %scale: vector<8x2xf8E8M0F
// -----
// CHECK-LABEL: @conversion_broadcast_odd
-// CHECK-NEXT: %[[CST3:.+]] = arith.constant dense<0.000000e+00> : vector<3xf8E5M2>
+// CHECK-NEXT: %[[CST4:.+]] = arith.constant dense<0.000000e+00> : vector<4xf8E5M2>
// CHECK-NEXT: %[[CST6:.+]] = arith.constant dense<0.000000e+00> : vector<6xf8E5M2>
// CHECK-NEXT: %[[SCALE_BCAST:.+]] = vector.broadcast %arg1 : vector<2xf8E8M0FNU> to vector<3x2xf8E8M0FNU>
// CHECK-NEXT: %[[SCALE_FLAT:.+]] = vector.shape_cast %[[SCALE_BCAST]] : vector<3x2xf8E8M0FNU> to vector<6xf8E8M0FNU>
@@ -130,24 +122,18 @@ func.func @conversion_broadcast(%in: vector<8x8xf32>, %scale: vector<8x2xf8E8M0F
// CHECK-NEXT: %[[IN_CHUNK0:.+]] = vector.extract_strided_slice %arg0 {offsets = [0], sizes = [3], strides = [1]} : vector<6xf32> to vector<3xf32>
// CHECK-NEXT: %[[SCALE0:.+]] = vector.extract %[[SCALE_EXTF]][0] : f32 from vector<6xf32>
// CHECK-NEXT: %[[IN_CHUNK0_PART0:.+]] = vector.extract_strided_slice %[[IN_CHUNK0]] {offsets = [0], sizes = [2], strides = [1]} : vector<3xf32> to vector<2xf32>
-// CHECK-NEXT: %[[PACKED0_PART0:.+]] = amdgpu.packed_scaled_trunc %[[IN_CHUNK0_PART0]] into undef[0], %[[SCALE0]] : vector<2xf32> to vector<4xf8E5M2>
-// CHECK-NEXT: %[[OUT_CHUNK0_PART0:.+]] = vector.extract_strided_slice %[[PACKED0_PART0]] {offsets = [0], sizes = [2], strides = [1]} : vector<4xf8E5M2> to vector<2xf8E5M2>
-// CHECK-NEXT: %[[ACCUM0_PART0:.+]] = vector.insert_strided_slice %[[OUT_CHUNK0_PART0]], %[[CST3]] {offsets = [0], strides = [1]} : vector<2xf8E5M2> into vector<3xf8E5M2>
+// CHECK-NEXT: %[[PACKED0_PART0:.+]] = amdgpu.packed_scaled_trunc %[[IN_CHUNK0_PART0]] into %[[CST4]][0], %[[SCALE0]] : vector<2xf32> to vector<4xf8E5M2>
// CHECK-NEXT: %[[IN_CHUNK0_PART1:.+]] = vector.extract_strided_slice %[[IN_CHUNK0]] {offsets = [2], sizes = [1], strides = [1]} : vector<3xf32> to vector<1xf32>
-// CHECK-NEXT: %[[PACKED0_PART1:.+]] = amdgpu.packed_scaled_trunc %[[IN_CHUNK0_PART1]] into undef[0], %[[SCALE0]] : vector<1xf32> to vector<4xf8E5M2>
-// CHECK-NEXT: %[[OUT_CHUNK0_PART1:.+]] = vector.extract_strided_slice %[[PACKED0_PART1]] {offsets = [0], sizes = [1], strides = [1]} : vector<4xf8E5M2> to vector<1xf8E5M2>
-// CHECK-NEXT: %[[CHUNK0_RES:.+]] = vector.insert_strided_slice %[[OUT_CHUNK0_PART1]], %[[ACCUM0_PART0]] {offsets = [2], strides = [1]} : vector<1xf8E5M2> into vector<3xf8E5M2>
+// CHECK-NEXT: %[[PACKED0_PART1:.+]] = amdgpu.packed_scaled_trunc %[[IN_CHUNK0_PART1]] into %[[PACKED0_PART0]][1], %[[SCALE0]] : vector<1xf32> to vector<4xf8E5M2>
+// CHECK-NEXT: %[[CHUNK0_RES:.+]] = vector.extract_strided_slice %[[PACKED0_PART1]] {offsets = [0], sizes = [3], strides = [1]} : vector<4xf8E5M2> to vector<3xf8E5M2>
// CHECK-NEXT: %[[FINAL_ACCUM_A:.+]] = vector.insert_strided_slice %[[CHUNK0_RES]], %[[CST6]] {offsets = [0], strides = [1]} : vector<3xf8E5M2> into vector<6xf8E5M2>
// CHECK-NEXT: %[[IN_CHUNK1:.+]] = vector.extract_strided_slice %arg0 {offsets = [3], sizes = [3], strides = [1]} : vector<6xf32> to vector<3xf32>
// CHECK-NEXT: %[[SCALE1:.+]] = vector.extract %[[SCALE_EXTF]][3] : f32 from vector<6xf32>
// CHECK-NEXT: %[[IN_CHUNK1_PART0:.+]] = vector.extract_strided_slice %[[IN_CHUNK1]] {offsets = [0], sizes = [2], strides = [1]} : vector<3xf32> to vector<2xf32>
-// CHECK-NEXT: %[[PACKED1_PART0:.+]] = amdgpu.packed_scaled_trunc %[[IN_CHUNK1_PART0]] into undef[0], %[[SCALE1]] : vector<2xf32> to vector<4xf8E5M2>
-// CHECK-NEXT: %[[OUT_CHUNK1_PART0:.+]] = vector.extract_strided_slice %[[PACKED1_PART0]] {offsets = [0], sizes = [2], strides = [1]} : vector<4xf8E5M2> to vector<2xf8E5M2>
-// CHECK-NEXT: %[[ACCUM1_PART0:.+]] = vector.insert_strided_slice %[[OUT_CHUNK1_PART0]], %[[CST3]] {offsets = [0], strides = [1]} : vector<2xf8E5M2> into vector<3xf8E5M2>
+// CHECK-NEXT: %[[PACKED1_PART0:.+]] = amdgpu.packed_scaled_trunc %[[IN_CHUNK1_PART0]] into %[[CST4]][0], %[[SCALE1]] : vector<2xf32> to vector<4xf8E5M2>
// CHECK-NEXT: %[[IN_CHUNK1_PART1:.+]] = vector.extract_strided_slice %[[IN_CHUNK1]] {offsets = [2], sizes = [1], strides = [1]} : vector<3xf32> to vector<1xf32>
-// CHECK-NEXT: %[[PACKED1_PART1:.+]] = amdgpu.packed_scaled_trunc %[[IN_CHUNK1_PART1]] into undef[0], %[[SCALE1]] : vector<1xf32> to vector<4xf8E5M2>
-// CHECK-NEXT: %[[OUT_CHUNK1_PART1:.+]] = vector.extract_strided_slice %[[PACKED1_PART1]] {offsets = [0], sizes = [1], strides = [1]} : vector<4xf8E5M2> to vector<1xf8E5M2>
-// CHECK-NEXT: %[[CHUNK1_RES:.+]] = vector.insert_strided_slice %[[OUT_CHUNK1_PART1]], %[[ACCUM1_PART0]] {offsets = [2], strides = [1]} : vector<1xf8E5M2> into vector<3xf8E5M2>
+// CHECK-NEXT: %[[PACKED1_PART1:.+]] = amdgpu.packed_scaled_trunc %[[IN_CHUNK1_PART1]] into %[[PACKED1_PART0]][1], %[[SCALE1]] : vector<1xf32> to vector<4xf8E5M2>
+// CHECK-NEXT: %[[CHUNK1_RES:.+]] = vector.extract_strided_slice %[[PACKED1_PART1]] {offsets = [0], sizes = [3], strides = [1]} : vector<4xf8E5M2> to vector<3xf8E5M2>
// CHECK-NEXT: %[[FINAL_RESULT:.+]] = vector.insert_strided_slice %[[CHUNK1_RES]], %[[FINAL_ACCUM_A]] {offsets = [3], strides = [1]} : vector<3xf8E5M2> into vector<6xf8E5M2>
// CHECK-NEXT: return %[[FINAL_RESULT]] : vector<6xf8E5M2>
func.func @conversion_broadcast_odd(%in: vector<6xf32>, %scale: vector<2xf8E8M0FNU>) -> vector<6xf8E5M2> {
@@ -165,14 +151,10 @@ func.func @conversion_broadcast_odd(%in: vector<6xf32>, %scale: vector<2xf8E8M0F
// CHECK-DAG: %[[SCALE_EXTF:.+]] = arith.extf %[[SCALE_SPLAT]] : vector<4xf8E8M0FNU> to vector<4xf32>
// CHECK-DAG: %[[SCALE_SCALAR:.+]] = vector.extract %[[SCALE_EXTF]][0] : f32 from vector<4xf32>
// CHECK: %[[IN_CHUNK0:.+]] = vector.extract_strided_slice %arg0 {offsets = [0], sizes = [2], strides = [1]} : vector<4xf32> to vector<2xf32>
-// CHECK-NEXT: %[[PACKED0:.+]] = amdgpu.packed_scaled_trunc %[[IN_CHUNK0]] into undef[0], %[[SCALE_SCALAR]] : vector<2xf32> to vector<4xf8E5M2>
-// CHECK-NEXT: %[[OUT_CHUNK0:.+]] = vector.extract_strided_slice %[[PACKED0]] {offsets = [0], sizes = [2], strides = [1]} : vector<4xf8E5M2> to vector<2xf8E5M2>
-// CHECK-NEXT: %[[ACCUM_A:.+]] = vector.insert_strided_slice %[[OUT_CHUNK0]], %[[CST]] {offsets = [0], strides = [1]} : vector<2xf8E5M2> into vector<4xf8E5M2>
+// CHECK-NEXT: %[[PACKED0:.+]] = amdgpu.packed_scaled_trunc %[[IN_CHUNK0]] into %[[CST]][0], %[[SCALE_SCALAR]] : vector<2xf32> to vector<4xf8E5M2>
// CHECK-NEXT: %[[IN_CHUNK1:.+]] = vector.extract_strided_slice %arg0 {offsets = [2], sizes = [2], strides = [1]} : vector<4xf32> to vector<2xf32>
-// CHECK-NEXT: %[[PACKED1:.+]] = amdgpu.packed_scaled_trunc %[[IN_CHUNK1]] into undef[0], %[[SCALE_SCALAR]] : vector<2xf32> to vector<4xf8E5M2>
-// CHECK-NEXT: %[[OUT_CHUNK1:.+]] = vector.extract_strided_slice %[[PACKED1]] {offsets = [0], sizes = [2], strides = [1]} : vector<4xf8E5M2> to vector<2xf8E5M2>
-// CHECK-NEXT: %[[FINAL_RESULT:.+]] = vector.insert_strided_slice %[[OUT_CHUNK1]], %[[ACCUM_A]] {offsets = [2], strides = [1]} : vector<2xf8E5M2> into vector<4xf8E5M2>
-// CHECK-NEXT: return %[[FINAL_RESULT]] : vector<4xf8E5M2>
+// CHECK-NEXT: %[[PACKED1:.+]] = amdgpu.packed_scaled_trunc %[[IN_CHUNK1]] into %[[PACKED0]][1], %[[SCALE_SCALAR]] : vector<2xf32> to vector<4xf8E5M2>
+// CHECK-NEXT: return %[[PACKED1]] : vector<4xf8E5M2>
func.func @conversion_broadcast(%in: vector<4xf32>, %scale: f8E8M0FNU) -> vector<4xf8E5M2> {
%splat = vector.broadcast %scale : f8E8M0FNU to vector<4xf8E8M0FNU>
%ext = arith.scaling_truncf %in, %splat : vector<4xf32>, vector<4xf8E8M0FNU> to vector<4xf8E5M2>
@@ -191,3 +173,15 @@ func.func @conversion_scalar(%in: f32, %scale: f8E8M0FNU) -> f8E5M2 {
%ext = arith.scaling_truncf %in, %scale : f32, f8E8M0FNU to f8E5M2
return %ext : f8E5M2
}
+
+// -----
+
+// CHECK-LABEL: @long_fp4_broadcast
+// CHECK-COUNT-4: amdgpu.packed_scaled_trunc %{{.*}} into %{{.+}}[3]
+// CHECK-NOT: amdgpu.packed_scaled_trunc
+// CHECK: return
+func.func @long_fp4_broadcast(%in: vector<32xf32>, %scale: f32) -> vector<32xf4E2M1FN> {
+ %splat = vector.broadcast %scale : f32 to vector<32xf32>
+ %trunc = arith.scaling_truncf %in, %splat : vector<32xf32>, vector<32xf32> to vector<32xf4E2M1FN>
+ return %trunc : vector<32xf4E2M1FN>
+}