[Mlir-commits] [mlir] [mlir][ArithToAMDGPU] Use native packing support (PR #150342)

Krzysztof Drewniak llvmlistbot at llvm.org
Thu Jul 24 09:35:15 PDT 2025


https://github.com/krzysz00 updated https://github.com/llvm/llvm-project/pull/150342

From 0644bfaf4062abea138fcaae183c0c08d57d3c2b Mon Sep 17 00:00:00 2001
From: Krzysztof Drewniak <Krzysztof.Drewniak at amd.com>
Date: Thu, 24 Jul 2025 00:10:16 +0000
Subject: [PATCH 1/2] [mlir][ArithToAMDGPU] Use native packing support

The current arith-to-amdgpu patterns for scaling_extf and
scaling_truncf don't take full advantage of the native packing ability
of the intrinsics being targeted. Scaling extension takes the location
of the two elements to be extended as a constant argument (a byte index
for fp4, a half index for fp8), and scaling truncation takes a 32-bit
input register and a byte or half index to write the truncated values
to.
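
As a rough illustration (the SSA names below are placeholders, not
taken from the patch; the op forms match those exercised in the
updated tests):

  // [1] selects which packed pair of %src to extend.
  %ext = amdgpu.scaled_ext_packed %src[1], %scale : vector<4xf8E5M2> to vector<2xf32>
  // The truncated pair is written into half 1 of the existing %acc register.
  %tr = amdgpu.packed_scaled_trunc %vals into %acc[1], %scale : vector<2xf32> to vector<4xf8E5M2>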

Not using these features causes unnecessary register pressure; this PR
resolves that inefficiency.

It also adds a test for the expected use case of extending or
truncating a block of 32 values to/from fp4 with a uniform scale, to
ensure that this usage requires a minimal amount of vector shuffling.
---
 .../ArithToAMDGPU/ArithToAMDGPU.cpp           | 100 +++++++++++-------
 .../ArithToAMDGPU/scaling-extf.mlir           |  46 ++++----
 .../ArithToAMDGPU/scaling-truncf.mlir         |  60 +++++------
 3 files changed, 113 insertions(+), 93 deletions(-)

diff --git a/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp b/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp
index 8c68b57877c35..4cf80167b20c2 100644
--- a/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp
+++ b/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp
@@ -449,7 +449,8 @@ LogicalResult
 ScalingExtFRewritePattern::matchAndRewrite(arith::ScalingExtFOp op,
                                            PatternRewriter &rewriter) const {
   Location loc = op.getLoc();
-  constexpr int64_t opWidth = 2;
+  constexpr int64_t opOutWidth = 2;
+  constexpr int64_t opInWidth = 8;
 
   Value in = op.getIn();
   Value scale = op.getScale();
@@ -473,7 +474,7 @@ ScalingExtFRewritePattern::matchAndRewrite(arith::ScalingExtFOp op,
   else if (scaleType.getIntOrFloatBitWidth() > 32)
     scale = arith::TruncFOp::create(rewriter, loc, scaleF32Type, scale);
 
-  VectorType extScaleResultType = VectorType::get(opWidth, outType);
+  VectorType extScaleResultType = VectorType::get(opOutWidth, outType);
 
   if (!outVecType) {
     Value inCast = vector::BroadcastOp::create(rewriter, loc,
@@ -487,10 +488,11 @@ ScalingExtFRewritePattern::matchAndRewrite(arith::ScalingExtFOp op,
 
   VectorType inVecType = cast<VectorType>(in.getType());
   Value origScale = getOriginalVectorValue(op.getScale());
+  VectorType origScaleVecType = dyn_cast<VectorType>(origScale.getType());
 
   ArrayRef<int64_t> inShape = inVecType.getShape();
   SmallVector<int64_t> originalScaleShape;
-  if (auto origScaleVecType = dyn_cast<VectorType>(origScale.getType()))
+  if (origScaleVecType)
     llvm::append_range(originalScaleShape, origScaleVecType.getShape());
 
   originalScaleShape.insert(originalScaleShape.end(),
@@ -524,19 +526,26 @@ ScalingExtFRewritePattern::matchAndRewrite(arith::ScalingExtFOp op,
     Value blockResult =
         rewriter.createOrFold<vector::BroadcastOp>(loc, blockResultType, zero);
 
-    for (int64_t i = 0, sliceWidth = std::min(opWidth, blockSize - i);
+    for (int64_t i = 0, inSliceWidth = std::min(opInWidth, blockSize - i);
          i < blockSize;
-         i += sliceWidth, sliceWidth = std::min(opWidth, blockSize - i)) {
-      Value slice = vector::ExtractStridedSliceOp::create(
-          rewriter, loc, block1D, i, sliceWidth, 1);
-      // TODO: replace this with non-packed ScaledExtOp for sliceWidth == 1
-      Value scaleExt = amdgpu::ScaledExtPackedOp::create(
-          rewriter, loc, extScaleResultType, slice, uniformScale, 0);
-      if (sliceWidth != opWidth)
-        scaleExt = vector::ExtractStridedSliceOp::create(
-            rewriter, loc, scaleExt, 0, sliceWidth, 1);
-      blockResult = vector::InsertStridedSliceOp::create(
-          rewriter, loc, scaleExt, blockResult, i, 1);
+         i += inSliceWidth, inSliceWidth = std::min(opInWidth, blockSize - i)) {
+      Value inSlice = vector::ExtractStridedSliceOp::create(
+          rewriter, loc, block1D, i, inSliceWidth, 1);
+      for (int64_t j = 0,
+                   outSliceWidth = std::min(opOutWidth, inSliceWidth - j);
+           j < inSliceWidth; j += outSliceWidth,
+                   outSliceWidth = std::min(opOutWidth, inSliceWidth - j)) {
+        // TODO: replace this with non-packed ScaledExtOp for sliceWidth == 1
+        Value scaleExt = amdgpu::ScaledExtPackedOp::create(
+            rewriter, loc, extScaleResultType, inSlice, uniformScale,
+            j / opOutWidth);
+        if (outSliceWidth < opOutWidth) {
+          scaleExt = vector::ExtractStridedSliceOp::create(
+              rewriter, loc, scaleExt, 0, outSliceWidth, 1);
+        }
+        blockResult = vector::InsertStridedSliceOp::create(
+            rewriter, loc, scaleExt, blockResult, i + j, 1);
+      }
     }
 
     VectorType resultType = VectorType::get(ratio, outType);
@@ -555,7 +564,7 @@ LogicalResult
 ScalingTruncFRewritePattern::matchAndRewrite(arith::ScalingTruncFOp op,
                                              PatternRewriter &rewriter) const {
   Location loc = op.getLoc();
-  constexpr int64_t opWidth = 2;
+  constexpr int64_t opInWidth = 2;
 
   Value in = op.getIn();
   Value scale = op.getScale();
@@ -568,7 +577,6 @@ ScalingTruncFRewritePattern::matchAndRewrite(arith::ScalingTruncFOp op,
 
   VectorType outVecType = dyn_cast<VectorType>(out.getType());
   VectorType scaleVecType = dyn_cast<VectorType>(scale.getType());
-
   if (outVecType && outVecType.isScalable())
     return failure();
 
@@ -581,8 +589,8 @@ ScalingTruncFRewritePattern::matchAndRewrite(arith::ScalingTruncFOp op,
 
   Value zero = arith::ConstantOp::create(rewriter, loc, outType,
                                          rewriter.getFloatAttr(outType, 0.0));
-  unsigned numPackedElem = 32 / outType.getIntOrFloatBitWidth();
-  VectorType truncScaleResultType = VectorType::get(numPackedElem, outType);
+  int64_t opOutWidth = 32 / outType.getIntOrFloatBitWidth();
+  VectorType truncScaleResultType = VectorType::get(opOutWidth, outType);
 
   if (!outVecType) {
     Type inVecType = VectorType::get(1, inType);
@@ -598,16 +606,16 @@ ScalingTruncFRewritePattern::matchAndRewrite(arith::ScalingTruncFOp op,
 
   VectorType inVecType = cast<VectorType>(in.getType());
   Value origScale = getOriginalVectorValue(op.getScale());
+  VectorType origScaleVecType = dyn_cast<VectorType>(origScale.getType());
 
   ArrayRef<int64_t> inShape = inVecType.getShape();
-  SmallVector<int64_t> originalScaleShape;
-  if (auto origScaleVecType = dyn_cast<VectorType>(origScale.getType()))
-    llvm::append_range(originalScaleShape, origScaleVecType.getShape());
+  SmallVector<int64_t> scaleShape;
+  if (origScaleVecType)
+    llvm::append_range(scaleShape, origScaleVecType.getShape());
 
-  originalScaleShape.insert(originalScaleShape.end(),
-                            inShape.size() - originalScaleShape.size(), 1);
+  scaleShape.insert(scaleShape.end(), inShape.size() - scaleShape.size(), 1);
 
-  auto maybeRatio = computeShapeRatio(inShape, originalScaleShape);
+  auto maybeRatio = computeShapeRatio(inShape, scaleShape);
   assert(maybeRatio &&
          "failed to derive block size from broadcast or splat operation");
 
@@ -633,20 +641,36 @@ ScalingTruncFRewritePattern::matchAndRewrite(arith::ScalingTruncFOp op,
     Value blockResult =
         rewriter.createOrFold<vector::BroadcastOp>(loc, blockResultType, zero);
 
-    for (int64_t i = 0, sliceWidth = std::min(opWidth, blockSize - i);
-         i < blockSize;
-         i += sliceWidth, sliceWidth = std::min(opWidth, blockSize - i)) {
-      Value slice = vector::ExtractStridedSliceOp::create(
-          rewriter, loc, block1D, i, sliceWidth, 1);
-      // TODO: replace this with non-packed ScaledTruncOp for sliceWidth == 1
-      Value scaleTrunc = amdgpu::PackedScaledTruncOp::create(
-          rewriter, loc, truncScaleResultType, slice, uniformScale, 0,
-          /*existing=*/nullptr);
-      int64_t packedWidth =
-          cast<VectorType>(scaleTrunc.getType()).getNumElements();
-      if (packedWidth != opWidth)
+    for (int64_t i = 0, outSliceWidth = std::min(opOutWidth, blockSize - i);
+         i < blockSize; i += outSliceWidth,
+                 outSliceWidth = std::min(opOutWidth, blockSize - i)) {
+      Value scaleTrunc;
+      // Case where <= 2 elements are being truncated.
+      if (outSliceWidth <= opInWidth) {
+        Value slice = vector::ExtractStridedSliceOp::create(
+            rewriter, loc, block1D, i, outSliceWidth, 1);
+        // TODO: replace this with non-packed ScaledTruncOp for sliceWidth == 1
+        scaleTrunc = amdgpu::PackedScaledTruncOp::create(
+            rewriter, loc, truncScaleResultType, slice, uniformScale, 0,
+            /*existing=*/nullptr);
+      } else {
+        scaleTrunc = vector::BroadcastOp::create(rewriter, loc,
+                                                 truncScaleResultType, zero);
+        for (int64_t j = 0,
+                     inSliceWidth = std::min(opInWidth, outSliceWidth - j);
+             j < outSliceWidth; j += opInWidth,
+                     inSliceWidth = std::min(opInWidth, outSliceWidth - j)) {
+          Value slice = vector::ExtractStridedSliceOp::create(
+              rewriter, loc, block1D, i + j, inSliceWidth, 1);
+          scaleTrunc = amdgpu::PackedScaledTruncOp::create(
+              rewriter, loc, truncScaleResultType, slice, uniformScale,
+              j / opInWidth, scaleTrunc);
+        }
+      }
+      if (outSliceWidth != opOutWidth) {
         scaleTrunc = vector::ExtractStridedSliceOp::create(
-            rewriter, loc, scaleTrunc, 0, sliceWidth, 1);
+            rewriter, loc, scaleTrunc, 0, outSliceWidth, 1);
+      }
       blockResult = vector::InsertStridedSliceOp::create(
           rewriter, loc, scaleTrunc, blockResult, i, 1);
     }
diff --git a/mlir/test/Conversion/ArithToAMDGPU/scaling-extf.mlir b/mlir/test/Conversion/ArithToAMDGPU/scaling-extf.mlir
index b98045195f8cf..a837bdb8be4fa 100644
--- a/mlir/test/Conversion/ArithToAMDGPU/scaling-extf.mlir
+++ b/mlir/test/Conversion/ArithToAMDGPU/scaling-extf.mlir
@@ -163,27 +163,23 @@ func.func @conversion_f4_f16_fallback(%in: vector<2x2xf4E2M1FN>, %scale: vector<
 // CHECK-DAG:     %[[SCALE_CAST:.+]] = vector.shape_cast %[[BCAST]]
 // CHECK-DAG:     %[[SCALE_EXT:.+]] = arith.extf %[[SCALE_CAST]]
 // CHECK-DAG:     vector.extract_strided_slice %[[IN_CAST]] {offsets = [0, 0, 0], sizes = [1, 1, 4], strides = [1, 1, 1]}
-// CHECK-NEXT:    vector.shape_cast
+// CHECK-NEXT:    %[[IN_SLICE_CAST:.+]] = vector.shape_cast
 // CHECK-NEXT:    vector.extract %[[SCALE_EXT]][0, 0, 0]
-// CHECK-NEXT:    vector.extract_strided_slice %{{.+}} {offsets = [0], sizes = [2], strides = [1]}
-// CHECK-NEXT:    amdgpu.scaled_ext_packed
-// CHECK-NEXT:    vector.insert_strided_slice %{{.+}}, %{{.+}} {offsets = [0], strides = [1]}
-// CHECK-NEXT:    vector.extract_strided_slice %{{.+}} {offsets = [2], sizes = [2], strides = [1]}
-// CHECK-NEXT:    amdgpu.scaled_ext_packed
-// CHECK-NEXT:    vector.insert_strided_slice %{{.+}}, %{{.+}} {offsets = [2], strides = [1]}
+// CHECK-NEXT:    %[[LOWHALF:.+]] = amdgpu.scaled_ext_packed %[[IN_SLICE_CAST]][0]
+// CHECK-NEXT:    vector.insert_strided_slice %[[LOWHALF]], %{{.+}} {offsets = [0], strides = [1]}
+// CHECK-NEXT:    %[[HIGHHALF:.+]] = amdgpu.scaled_ext_packed %[[IN_SLICE_CAST]][1]
+// CHECK-NEXT:    vector.insert_strided_slice %[[HIGHHALF]], %{{.+}} {offsets = [2], strides = [1]}
 // CHECK-NEXT:    vector.shape_cast
 // CHECK-NEXT:    vector.insert_strided_slice %{{.+}} {offsets = [0, 0, 0], strides = [1, 1, 1]}
 // CHECK-NEXT:    vector.extract_strided_slice %[[IN_CAST]] {offsets = [0, 1, 0], sizes = [1, 1, 4], strides = [1, 1, 1]}
 // CHECK-NEXT:    vector.shape_cast
 // CHECK-NEXT:    vector.extract %[[SCALE_EXT]][0, 1, 0]
-// CHECK-NEXT:    vector.extract_strided_slice %{{.+}} {offsets = [0], sizes = [2], strides = [1]}
 // CHECK-NEXT:    amdgpu.scaled_ext_packed
 // CHECK-NEXT:    vector.insert_strided_slice %{{.+}}, %{{.+}} {offsets = [0], strides = [1]}
-// CHECK-NEXT:    vector.extract_strided_slice %{{.+}} {offsets = [2], sizes = [2], strides = [1]}
 // CHECK-NEXT:    amdgpu.scaled_ext_packed
 // CHECK-NEXT:    vector.insert_strided_slice %{{.+}}, %{{.+}} {offsets = [2], strides = [1]}
 // CHECK-NEXT:    vector.shape_cast
-// CHECK-NEXT:    vector.insert_strided_slice %{{.+}}, %{{.+}} {offsets = [0, 1, 0], strides = [1, 1, 1]} 
+// CHECK-NEXT:    vector.insert_strided_slice %{{.+}}, %{{.+}} {offsets = [0, 1, 0], strides = [1, 1, 1]}
 func.func @conversion_broadcast(%in: vector<8x8xf8E5M2>, %scale: vector<8x2xf8E8M0FNU>) -> vector<8x8xf32> {
     %bc = vector.broadcast %scale : vector<8x2xf8E8M0FNU> to vector<4x8x2xf8E8M0FNU>
     %cast1 = vector.shape_cast %in : vector<8x8xf8E5M2> to vector<8x2x4xf8E5M2>
@@ -203,21 +199,17 @@ func.func @conversion_broadcast(%in: vector<8x8xf8E5M2>, %scale: vector<8x2xf8E8
 // CHECK-NEXT:    %[[SCALE_EXT:.+]] = arith.extf %[[SCALE_FLAT]] : vector<6xf8E8M0FNU> to vector<6xf32>
 // CHECK-NEXT:    %[[IN_SLICE_0:.+]] = vector.extract_strided_slice %arg0 {offsets = [0], sizes = [3], strides = [1]} : vector<6xf8E5M2> to vector<3xf8E5M2>
 // CHECK-NEXT:    %[[SCALE_SCALAR_0:.+]] = vector.extract %[[SCALE_EXT]][0] : f32 from vector<6xf32>
-// CHECK-NEXT:    %[[IN_CHUNK_0A:.+]] = vector.extract_strided_slice %[[IN_SLICE_0]] {offsets = [0], sizes = [2], strides = [1]} : vector<3xf8E5M2> to vector<2xf8E5M2>
-// CHECK-NEXT:    %[[PACKED_0A:.+]] = amdgpu.scaled_ext_packed %[[IN_CHUNK_0A]][0], %[[SCALE_SCALAR_0]] : vector<2xf8E5M2> to vector<2xf32>
+// CHECK-NEXT:    %[[PACKED_0A:.+]] = amdgpu.scaled_ext_packed %[[IN_SLICE_0]][0], %[[SCALE_SCALAR_0]] : vector<3xf8E5M2> to vector<2xf32>
 // CHECK-NEXT:    %[[PARTIAL_ACC_0:.+]] = vector.insert_strided_slice %[[PACKED_0A]], %[[CST_PARTIAL]] {offsets = [0], strides = [1]} : vector<2xf32> into vector<3xf32>
-// CHECK-NEXT:    %[[IN_CHUNK_0B:.+]] = vector.extract_strided_slice %[[IN_SLICE_0]] {offsets = [2], sizes = [1], strides = [1]} : vector<3xf8E5M2> to vector<1xf8E5M2>
-// CHECK-NEXT:    %[[PACKED_0B_RAW:.+]] = amdgpu.scaled_ext_packed %[[IN_CHUNK_0B]][0], %[[SCALE_SCALAR_0]] : vector<1xf8E5M2> to vector<2xf32>
+// CHECK-NEXT:    %[[PACKED_0B_RAW:.+]] = amdgpu.scaled_ext_packed %[[IN_SLICE_0]][1], %[[SCALE_SCALAR_0]] : vector<3xf8E5M2> to vector<2xf32>
 // CHECK-NEXT:    %[[PACKED_0B:.+]] = vector.extract_strided_slice %[[PACKED_0B_RAW]] {offsets = [0], sizes = [1], strides = [1]} : vector<2xf32> to vector<1xf32>
 // CHECK-NEXT:    %[[OUT_SLICE_0:.+]] = vector.insert_strided_slice %[[PACKED_0B]], %[[PARTIAL_ACC_0]] {offsets = [2], strides = [1]} : vector<1xf32> into vector<3xf32>
 // CHECK-NEXT:    %[[FINAL_ACC_A:.+]] = vector.insert_strided_slice %[[OUT_SLICE_0]], %[[CST_FINAL]] {offsets = [0], strides = [1]} : vector<3xf32> into vector<6xf32>
 // CHECK-NEXT:    %[[IN_SLICE_1:.+]] = vector.extract_strided_slice %arg0 {offsets = [3], sizes = [3], strides = [1]} : vector<6xf8E5M2> to vector<3xf8E5M2>
 // CHECK-NEXT:    %[[SCALE_SCALAR_1:.+]] = vector.extract %[[SCALE_EXT]][3] : f32 from vector<6xf32>
-// CHECK-NEXT:    %[[IN_CHUNK_1A:.+]] = vector.extract_strided_slice %[[IN_SLICE_1]] {offsets = [0], sizes = [2], strides = [1]} : vector<3xf8E5M2> to vector<2xf8E5M2>
-// CHECK-NEXT:    %[[PACKED_1A:.+]] = amdgpu.scaled_ext_packed %[[IN_CHUNK_1A]][0], %[[SCALE_SCALAR_1]] : vector<2xf8E5M2> to vector<2xf32>
+// CHECK-NEXT:    %[[PACKED_1A:.+]] = amdgpu.scaled_ext_packed %[[IN_SLICE_1]][0], %[[SCALE_SCALAR_1]] : vector<3xf8E5M2> to vector<2xf32>
 // CHECK-NEXT:    %[[PARTIAL_ACC_1:.+]] = vector.insert_strided_slice %[[PACKED_1A]], %[[CST_PARTIAL]] {offsets = [0], strides = [1]} : vector<2xf32> into vector<3xf32>
-// CHECK-NEXT:    %[[IN_CHUNK_1B:.+]] = vector.extract_strided_slice %[[IN_SLICE_1]] {offsets = [2], sizes = [1], strides = [1]} : vector<3xf8E5M2> to vector<1xf8E5M2>
-// CHECK-NEXT:    %[[PACKED_1B_RAW:.+]] = amdgpu.scaled_ext_packed %[[IN_CHUNK_1B]][0], %[[SCALE_SCALAR_1]] : vector<1xf8E5M2> to vector<2xf32>
+// CHECK-NEXT:    %[[PACKED_1B_RAW:.+]] = amdgpu.scaled_ext_packed %[[IN_SLICE_1]][1], %[[SCALE_SCALAR_1]] : vector<3xf8E5M2> to vector<2xf32>
 // CHECK-NEXT:    %[[PACKED_1B:.+]] = vector.extract_strided_slice %[[PACKED_1B_RAW]] {offsets = [0], sizes = [1], strides = [1]} : vector<2xf32> to vector<1xf32>
 // CHECK-NEXT:    %[[OUT_SLICE_1:.+]] = vector.insert_strided_slice %[[PACKED_1B]], %[[PARTIAL_ACC_1]] {offsets = [2], strides = [1]} : vector<1xf32> into vector<3xf32>
 // CHECK-NEXT:    %[[RESULT:.+]] = vector.insert_strided_slice %[[OUT_SLICE_1]], %[[FINAL_ACC_A]] {offsets = [3], strides = [1]} : vector<3xf32> into vector<6xf32>
@@ -236,11 +228,9 @@ func.func @conversion_broadcast_odd(%in: vector<6xf8E5M2>, %scale: vector<2xf8E8
 // CHECK-DAG:     %[[SCALE_SPLAT:.+]] = vector.broadcast %arg1 : f8E8M0FNU to vector<4xf8E8M0FNU>
 // CHECK-DAG:     %[[SCALE_EXTF:.+]] = arith.extf %[[SCALE_SPLAT]] : vector<4xf8E8M0FNU> to vector<4xf32>
 // CHECK-DAG:     %[[SCALE_SCALAR:.+]] = vector.extract %[[SCALE_EXTF]][0] : f32 from vector<4xf32>
-// CHECK:         %[[IN_CHUNK0:.+]] = vector.extract_strided_slice %arg0 {offsets = [0], sizes = [2], strides = [1]} : vector<4xf8E5M2> to vector<2xf8E5M2>
-// CHECK-NEXT:    %[[OUT_CHUNK0:.+]] = amdgpu.scaled_ext_packed %[[IN_CHUNK0]][0], %[[SCALE_SCALAR]] : vector<2xf8E5M2> to vector<2xf32>
+// CHECK:         %[[OUT_CHUNK0:.+]] = amdgpu.scaled_ext_packed %arg0[0], %[[SCALE_SCALAR]] : vector<4xf8E5M2> to vector<2xf32>
 // CHECK-NEXT:    %[[ACCUM_A:.+]] = vector.insert_strided_slice %[[OUT_CHUNK0]], %[[CST]] {offsets = [0], strides = [1]} : vector<2xf32> into vector<4xf32>
-// CHECK-NEXT:    %[[IN_CHUNK1:.+]] = vector.extract_strided_slice %arg0 {offsets = [2], sizes = [2], strides = [1]} : vector<4xf8E5M2> to vector<2xf8E5M2>
-// CHECK-NEXT:    %[[OUT_CHUNK1:.+]] = amdgpu.scaled_ext_packed %[[IN_CHUNK1]][0], %[[SCALE_SCALAR]] : vector<2xf8E5M2> to vector<2xf32>
+// CHECK-NEXT:    %[[OUT_CHUNK1:.+]] = amdgpu.scaled_ext_packed %arg0[1], %[[SCALE_SCALAR]] : vector<4xf8E5M2> to vector<2xf32>
 // CHECK-NEXT:    %[[FINAL_RESULT:.+]] = vector.insert_strided_slice %[[OUT_CHUNK1]], %[[ACCUM_A]] {offsets = [2], strides = [1]} : vector<2xf32> into vector<4xf32>
 // CHECK-NEXT:    return %[[FINAL_RESULT]] : vector<4xf32>
 func.func @conversion_broadcast(%in: vector<4xf8E5M2>, %scale: f8E8M0FNU) -> vector<4xf32> {
@@ -261,3 +251,15 @@ func.func @conversion_scalar(%in: f8E5M2, %scale: f8E8M0FNU) -> f32 {
     %ext = arith.scaling_extf %in, %scale : f8E5M2, f8E8M0FNU to f32
     return %ext : f32
 }
+
+// -----
+
+// CHECK-LABEL: @long_fp4_broadcast
+// CHECK-COUNT-4: amdgpu.scaled_ext_packed %{{.+}}[3]
+// CHECK-NOT: amdgpu.scaled_ext_packed
+// CHECK: return
+func.func @long_fp4_broadcast(%in: vector<32xf4E2M1FN>, %scale: f32) -> vector<32xf32> {
+    %splat = vector.broadcast %scale : f32 to vector<32xf32>
+    %ext = arith.scaling_extf %in, %splat : vector<32xf4E2M1FN>, vector<32xf32> to vector<32xf32>
+    return %ext : vector<32xf32>
+}
diff --git a/mlir/test/Conversion/ArithToAMDGPU/scaling-truncf.mlir b/mlir/test/Conversion/ArithToAMDGPU/scaling-truncf.mlir
index 488e75cbb1843..6d6e1e28d2c2c 100644
--- a/mlir/test/Conversion/ArithToAMDGPU/scaling-truncf.mlir
+++ b/mlir/test/Conversion/ArithToAMDGPU/scaling-truncf.mlir
@@ -88,28 +88,20 @@ func.func @conversion_f4_fallback(%in: vector<2x2xf32>, %scale: vector<2x2xf8E8M
 // CHECK-NEXT:    vector.shape_cast
 // CHECK-NEXT:    vector.extract %[[SCALE_EXT]][0, 0, 0]
 // CHECK-NEXT:    vector.extract_strided_slice %{{.+}} {offsets = [0], sizes = [2], strides = [1]}
-// CHECK-NEXT:    amdgpu.packed_scaled_trunc
-// CHECK-NEXT:    vector.extract_strided_slice %{{.+}} {offsets = [0], sizes = [2], strides = [1]}
-// CHECK-NEXT:    vector.insert_strided_slice %{{.+}}, %{{.+}} {offsets = [0], strides = [1]}
+// CHECK-NEXT:    %[[P1:.+]] = amdgpu.packed_scaled_trunc
 // CHECK-NEXT:    vector.extract_strided_slice %{{.+}} {offsets = [2], sizes = [2], strides = [1]}
-// CHECK-NEXT:    amdgpu.packed_scaled_trunc
-// CHECK-NEXT:    vector.extract_strided_slice
-// CHECK-NEXT:    vector.insert_strided_slice %{{.+}}, %{{.+}} {offsets = [2], strides = [1]}
-// CHECK-NEXT:    vector.shape_cast
-// CHECK-NEXT:    vector.insert_strided_slice %{{.+}} {offsets = [0, 0, 0], strides = [1, 1, 1]}
+// CHECK-NEXT:    %[[P2:.+]] = amdgpu.packed_scaled_trunc {{.*}} into %[[P1]][1]
+// CHECK-NEXT:    %[[P2_CAST:.+]] = vector.shape_cast %[[P2]] : vector<4xf8E5M2> to vector<1x1x4xf8E5M2>
+// CHECK-NEXT:    vector.insert_strided_slice %[[P2_CAST]], %{{.+}} {offsets = [0, 0, 0], strides = [1, 1, 1]}
 // CHECK-NEXT:    vector.extract_strided_slice %[[IN_CAST]] {offsets = [0, 1, 0], sizes = [1, 1, 4], strides = [1, 1, 1]}
 // CHECK-NEXT:    vector.shape_cast
 // CHECK-NEXT:    vector.extract %[[SCALE_EXT]][0, 1, 0]
 // CHECK-NEXT:    vector.extract_strided_slice %{{.+}} {offsets = [0], sizes = [2], strides = [1]}
 // CHECK-NEXT:    amdgpu.packed_scaled_trunc
-// CHECK-NEXT:    vector.extract_strided_slice %{{.+}} {offsets = [0], sizes = [2], strides = [1]}
-// CHECK-NEXT:    vector.insert_strided_slice %{{.+}}, %{{.+}} {offsets = [0], strides = [1]}
 // CHECK-NEXT:    vector.extract_strided_slice %{{.+}} {offsets = [2], sizes = [2], strides = [1]}
 // CHECK-NEXT:    amdgpu.packed_scaled_trunc
-// CHECK-NEXT:    vector.extract_strided_slice %{{.+}} {offsets = [0], sizes = [2], strides = [1]}
-// CHECK-NEXT:    vector.insert_strided_slice %{{.+}}, %{{.+}} {offsets = [2], strides = [1]}
 // CHECK-NEXT:    vector.shape_cast
-// CHECK-NEXT:    vector.insert_strided_slice %{{.+}}, %{{.+}} {offsets = [0, 1, 0], strides = [1, 1, 1]} 
+// CHECK-NEXT:    vector.insert_strided_slice %{{.+}}, %{{.+}} {offsets = [0, 1, 0], strides = [1, 1, 1]}
 func.func @conversion_broadcast(%in: vector<8x8xf32>, %scale: vector<8x2xf8E8M0FNU>) -> vector<8x8xf8E5M2> {
     %bc = vector.broadcast %scale : vector<8x2xf8E8M0FNU> to vector<4x8x2xf8E8M0FNU>
     %cast1 = vector.shape_cast %in : vector<8x8xf32> to vector<8x2x4xf32>
@@ -122,7 +114,7 @@ func.func @conversion_broadcast(%in: vector<8x8xf32>, %scale: vector<8x2xf8E8M0F
 // -----
 
 // CHECK-LABEL: @conversion_broadcast_odd
-// CHECK-NEXT:    %[[CST3:.+]] = arith.constant dense<0.000000e+00> : vector<3xf8E5M2>
+// CHECK-NEXT:    %[[CST4:.+]] = arith.constant dense<0.000000e+00> : vector<4xf8E5M2>
 // CHECK-NEXT:    %[[CST6:.+]] = arith.constant dense<0.000000e+00> : vector<6xf8E5M2>
 // CHECK-NEXT:    %[[SCALE_BCAST:.+]] = vector.broadcast %arg1 : vector<2xf8E8M0FNU> to vector<3x2xf8E8M0FNU>
 // CHECK-NEXT:    %[[SCALE_FLAT:.+]] = vector.shape_cast %[[SCALE_BCAST]] : vector<3x2xf8E8M0FNU> to vector<6xf8E8M0FNU>
@@ -130,24 +122,18 @@ func.func @conversion_broadcast(%in: vector<8x8xf32>, %scale: vector<8x2xf8E8M0F
 // CHECK-NEXT:    %[[IN_CHUNK0:.+]] = vector.extract_strided_slice %arg0 {offsets = [0], sizes = [3], strides = [1]} : vector<6xf32> to vector<3xf32>
 // CHECK-NEXT:    %[[SCALE0:.+]] = vector.extract %[[SCALE_EXTF]][0] : f32 from vector<6xf32>
 // CHECK-NEXT:    %[[IN_CHUNK0_PART0:.+]] = vector.extract_strided_slice %[[IN_CHUNK0]] {offsets = [0], sizes = [2], strides = [1]} : vector<3xf32> to vector<2xf32>
-// CHECK-NEXT:    %[[PACKED0_PART0:.+]] = amdgpu.packed_scaled_trunc %[[IN_CHUNK0_PART0]] into undef[0], %[[SCALE0]] : vector<2xf32> to vector<4xf8E5M2>
-// CHECK-NEXT:    %[[OUT_CHUNK0_PART0:.+]] = vector.extract_strided_slice %[[PACKED0_PART0]] {offsets = [0], sizes = [2], strides = [1]} : vector<4xf8E5M2> to vector<2xf8E5M2>
-// CHECK-NEXT:    %[[ACCUM0_PART0:.+]] = vector.insert_strided_slice %[[OUT_CHUNK0_PART0]], %[[CST3]] {offsets = [0], strides = [1]} : vector<2xf8E5M2> into vector<3xf8E5M2>
+// CHECK-NEXT:    %[[PACKED0_PART0:.+]] = amdgpu.packed_scaled_trunc %[[IN_CHUNK0_PART0]] into %[[CST4]][0], %[[SCALE0]] : vector<2xf32> to vector<4xf8E5M2>
 // CHECK-NEXT:    %[[IN_CHUNK0_PART1:.+]] = vector.extract_strided_slice %[[IN_CHUNK0]] {offsets = [2], sizes = [1], strides = [1]} : vector<3xf32> to vector<1xf32>
-// CHECK-NEXT:    %[[PACKED0_PART1:.+]] = amdgpu.packed_scaled_trunc %[[IN_CHUNK0_PART1]] into undef[0], %[[SCALE0]] : vector<1xf32> to vector<4xf8E5M2>
-// CHECK-NEXT:    %[[OUT_CHUNK0_PART1:.+]] = vector.extract_strided_slice %[[PACKED0_PART1]] {offsets = [0], sizes = [1], strides = [1]} : vector<4xf8E5M2> to vector<1xf8E5M2>
-// CHECK-NEXT:    %[[CHUNK0_RES:.+]] = vector.insert_strided_slice %[[OUT_CHUNK0_PART1]], %[[ACCUM0_PART0]] {offsets = [2], strides = [1]} : vector<1xf8E5M2> into vector<3xf8E5M2>
+// CHECK-NEXT:    %[[PACKED0_PART1:.+]] = amdgpu.packed_scaled_trunc %[[IN_CHUNK0_PART1]] into %[[PACKED0_PART0]][1], %[[SCALE0]] : vector<1xf32> to vector<4xf8E5M2>
+// CHECK-NEXT:    %[[CHUNK0_RES:.+]] = vector.extract_strided_slice %[[PACKED0_PART1]] {offsets = [0], sizes = [3], strides = [1]} : vector<4xf8E5M2> to vector<3xf8E5M2>
 // CHECK-NEXT:    %[[FINAL_ACCUM_A:.+]] = vector.insert_strided_slice %[[CHUNK0_RES]], %[[CST6]] {offsets = [0], strides = [1]} : vector<3xf8E5M2> into vector<6xf8E5M2>
 // CHECK-NEXT:    %[[IN_CHUNK1:.+]] = vector.extract_strided_slice %arg0 {offsets = [3], sizes = [3], strides = [1]} : vector<6xf32> to vector<3xf32>
 // CHECK-NEXT:    %[[SCALE1:.+]] = vector.extract %[[SCALE_EXTF]][3] : f32 from vector<6xf32>
 // CHECK-NEXT:    %[[IN_CHUNK1_PART0:.+]] = vector.extract_strided_slice %[[IN_CHUNK1]] {offsets = [0], sizes = [2], strides = [1]} : vector<3xf32> to vector<2xf32>
-// CHECK-NEXT:    %[[PACKED1_PART0:.+]] = amdgpu.packed_scaled_trunc %[[IN_CHUNK1_PART0]] into undef[0], %[[SCALE1]] : vector<2xf32> to vector<4xf8E5M2>
-// CHECK-NEXT:    %[[OUT_CHUNK1_PART0:.+]] = vector.extract_strided_slice %[[PACKED1_PART0]] {offsets = [0], sizes = [2], strides = [1]} : vector<4xf8E5M2> to vector<2xf8E5M2>
-// CHECK-NEXT:    %[[ACCUM1_PART0:.+]] = vector.insert_strided_slice %[[OUT_CHUNK1_PART0]], %[[CST3]] {offsets = [0], strides = [1]} : vector<2xf8E5M2> into vector<3xf8E5M2>
+// CHECK-NEXT:    %[[PACKED1_PART0:.+]] = amdgpu.packed_scaled_trunc %[[IN_CHUNK1_PART0]] into %[[CST4]][0], %[[SCALE1]] : vector<2xf32> to vector<4xf8E5M2>
 // CHECK-NEXT:    %[[IN_CHUNK1_PART1:.+]] = vector.extract_strided_slice %[[IN_CHUNK1]] {offsets = [2], sizes = [1], strides = [1]} : vector<3xf32> to vector<1xf32>
-// CHECK-NEXT:    %[[PACKED1_PART1:.+]] = amdgpu.packed_scaled_trunc %[[IN_CHUNK1_PART1]] into undef[0], %[[SCALE1]] : vector<1xf32> to vector<4xf8E5M2>
-// CHECK-NEXT:    %[[OUT_CHUNK1_PART1:.+]] = vector.extract_strided_slice %[[PACKED1_PART1]] {offsets = [0], sizes = [1], strides = [1]} : vector<4xf8E5M2> to vector<1xf8E5M2>
-// CHECK-NEXT:    %[[CHUNK1_RES:.+]] = vector.insert_strided_slice %[[OUT_CHUNK1_PART1]], %[[ACCUM1_PART0]] {offsets = [2], strides = [1]} : vector<1xf8E5M2> into vector<3xf8E5M2>
+// CHECK-NEXT:    %[[PACKED1_PART1:.+]] = amdgpu.packed_scaled_trunc %[[IN_CHUNK1_PART1]] into %[[PACKED1_PART0]][1], %[[SCALE1]] : vector<1xf32> to vector<4xf8E5M2>
+// CHECK-NEXT:    %[[CHUNK1_RES:.+]] = vector.extract_strided_slice %[[PACKED1_PART1]] {offsets = [0], sizes = [3], strides = [1]} : vector<4xf8E5M2> to vector<3xf8E5M2>
 // CHECK-NEXT:    %[[FINAL_RESULT:.+]] = vector.insert_strided_slice %[[CHUNK1_RES]], %[[FINAL_ACCUM_A]] {offsets = [3], strides = [1]} : vector<3xf8E5M2> into vector<6xf8E5M2>
 // CHECK-NEXT:    return %[[FINAL_RESULT]] : vector<6xf8E5M2>
 func.func @conversion_broadcast_odd(%in: vector<6xf32>, %scale: vector<2xf8E8M0FNU>) -> vector<6xf8E5M2> {
@@ -165,14 +151,10 @@ func.func @conversion_broadcast_odd(%in: vector<6xf32>, %scale: vector<2xf8E8M0F
 // CHECK-DAG:     %[[SCALE_EXTF:.+]] = arith.extf %[[SCALE_SPLAT]] : vector<4xf8E8M0FNU> to vector<4xf32>
 // CHECK-DAG:     %[[SCALE_SCALAR:.+]] = vector.extract %[[SCALE_EXTF]][0] : f32 from vector<4xf32>
 // CHECK:         %[[IN_CHUNK0:.+]] = vector.extract_strided_slice %arg0 {offsets = [0], sizes = [2], strides = [1]} : vector<4xf32> to vector<2xf32>
-// CHECK-NEXT:    %[[PACKED0:.+]] = amdgpu.packed_scaled_trunc %[[IN_CHUNK0]] into undef[0], %[[SCALE_SCALAR]] : vector<2xf32> to vector<4xf8E5M2>
-// CHECK-NEXT:    %[[OUT_CHUNK0:.+]] = vector.extract_strided_slice %[[PACKED0]] {offsets = [0], sizes = [2], strides = [1]} : vector<4xf8E5M2> to vector<2xf8E5M2>
-// CHECK-NEXT:    %[[ACCUM_A:.+]] = vector.insert_strided_slice %[[OUT_CHUNK0]], %[[CST]] {offsets = [0], strides = [1]} : vector<2xf8E5M2> into vector<4xf8E5M2>
+// CHECK-NEXT:    %[[PACKED0:.+]] = amdgpu.packed_scaled_trunc %[[IN_CHUNK0]] into %[[CST]][0], %[[SCALE_SCALAR]] : vector<2xf32> to vector<4xf8E5M2>
 // CHECK-NEXT:    %[[IN_CHUNK1:.+]] = vector.extract_strided_slice %arg0 {offsets = [2], sizes = [2], strides = [1]} : vector<4xf32> to vector<2xf32>
-// CHECK-NEXT:    %[[PACKED1:.+]] = amdgpu.packed_scaled_trunc %[[IN_CHUNK1]] into undef[0], %[[SCALE_SCALAR]] : vector<2xf32> to vector<4xf8E5M2>
-// CHECK-NEXT:    %[[OUT_CHUNK1:.+]] = vector.extract_strided_slice %[[PACKED1]] {offsets = [0], sizes = [2], strides = [1]} : vector<4xf8E5M2> to vector<2xf8E5M2>
-// CHECK-NEXT:    %[[FINAL_RESULT:.+]] = vector.insert_strided_slice %[[OUT_CHUNK1]], %[[ACCUM_A]] {offsets = [2], strides = [1]} : vector<2xf8E5M2> into vector<4xf8E5M2>
-// CHECK-NEXT:    return %[[FINAL_RESULT]] : vector<4xf8E5M2>
+// CHECK-NEXT:    %[[PACKED1:.+]] = amdgpu.packed_scaled_trunc %[[IN_CHUNK1]] into %[[PACKED0]][1], %[[SCALE_SCALAR]] : vector<2xf32> to vector<4xf8E5M2>
+// CHECK-NEXT:    return %[[PACKED1]] : vector<4xf8E5M2>
 func.func @conversion_broadcast(%in: vector<4xf32>, %scale: f8E8M0FNU) -> vector<4xf8E5M2> {
     %splat = vector.broadcast %scale : f8E8M0FNU to vector<4xf8E8M0FNU>
     %ext = arith.scaling_truncf %in, %splat : vector<4xf32>, vector<4xf8E8M0FNU> to vector<4xf8E5M2>
@@ -191,3 +173,15 @@ func.func @conversion_scalar(%in: f32, %scale: f8E8M0FNU) -> f8E5M2 {
     %ext = arith.scaling_truncf %in, %scale : f32, f8E8M0FNU to f8E5M2
     return %ext : f8E5M2
 }
+
+// -----
+
+// CHECK-LABEL: @long_fp4_broadcast
+// CHECK-COUNT-4: amdgpu.packed_scaled_trunc %{{.*}} into %{{.+}}[3]
+// CHECK-NOT: amdgpu.packed_scaled_trunc
+// CHECK: return
+func.func @long_fp4_broadcast(%in: vector<32xf32>, %scale: f32) -> vector<32xf4E2M1FN> {
+    %splat = vector.broadcast %scale : f32 to vector<32xf32>
+    %trunc = arith.scaling_truncf %in, %splat : vector<32xf32>, vector<32xf32> to vector<32xf4E2M1FN>
+    return %trunc : vector<32xf4E2M1FN>
+}

From 8c712df75ed194c131bdd3e51998d3d83f339eb9 Mon Sep 17 00:00:00 2001
From: Krzysztof Drewniak <Krzysztof.Drewniak at amd.com>
Date: Thu, 24 Jul 2025 16:33:11 +0000
Subject: [PATCH 2/2] Handle inWidth correctly, extra test
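
For the scaling_extf path, opInWidth is now derived from the input
element width (32 / bitwidth, so 8 for fp4 but 4 for 8-bit floats)
instead of being hard-coded to 8. A minimal sketch of the fp8 case
(placeholder names, following the patterns checked in the updated
tests):

  // With 8-bit inputs, one 32-bit register packs 4 elements; both
  // packed pairs are extended from the same source slice.
  %lo = amdgpu.scaled_ext_packed %src[0], %scale : vector<4xf8E5M2> to vector<2xf32>
  %hi = amdgpu.scaled_ext_packed %src[1], %scale : vector<4xf8E5M2> to vector<2xf32>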

---
 mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp  |  3 ++-
 mlir/test/Conversion/ArithToAMDGPU/scaling-extf.mlir | 12 ++++++++++++
 .../Conversion/ArithToAMDGPU/scaling-truncf.mlir     | 12 ++++++++++++
 3 files changed, 26 insertions(+), 1 deletion(-)

diff --git a/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp b/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp
index 4cf80167b20c2..8230591123661 100644
--- a/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp
+++ b/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp
@@ -450,7 +450,6 @@ ScalingExtFRewritePattern::matchAndRewrite(arith::ScalingExtFOp op,
                                            PatternRewriter &rewriter) const {
   Location loc = op.getLoc();
   constexpr int64_t opOutWidth = 2;
-  constexpr int64_t opInWidth = 8;
 
   Value in = op.getIn();
   Value scale = op.getScale();
@@ -461,6 +460,8 @@ ScalingExtFRewritePattern::matchAndRewrite(arith::ScalingExtFOp op,
   Type scaleType = getElementTypeOrSelf(scale);
   Type outType = getElementTypeOrSelf(out);
 
+  int64_t opInWidth = 32 / inType.getIntOrFloatBitWidth();
+
   VectorType outVecType = dyn_cast<VectorType>(out.getType());
   VectorType scaleVecType = dyn_cast<VectorType>(scale.getType());
 
diff --git a/mlir/test/Conversion/ArithToAMDGPU/scaling-extf.mlir b/mlir/test/Conversion/ArithToAMDGPU/scaling-extf.mlir
index a837bdb8be4fa..1d36be1108d26 100644
--- a/mlir/test/Conversion/ArithToAMDGPU/scaling-extf.mlir
+++ b/mlir/test/Conversion/ArithToAMDGPU/scaling-extf.mlir
@@ -263,3 +263,15 @@ func.func @long_fp4_broadcast(%in: vector<32xf4E2M1FN>, %scale: f32) -> vector<3
     %ext = arith.scaling_extf %in, %splat : vector<32xf4E2M1FN>, vector<32xf32> to vector<32xf32>
     return %ext : vector<32xf32>
 }
+
+// -----
+
+// CHECK-LABEL: @long_fp8_broadcast
+// CHECK-COUNT-8: amdgpu.scaled_ext_packed %{{.+}}[1]
+// CHECK-NOT: amdgpu.scaled_ext_packed
+// CHECK: return
+func.func @long_fp8_broadcast(%in: vector<32xf8E4M3FN>, %scale: f32) -> vector<32xf32> {
+    %splat = vector.broadcast %scale : f32 to vector<32xf32>
+    %ext = arith.scaling_extf %in, %splat : vector<32xf8E4M3FN>, vector<32xf32> to vector<32xf32>
+    return %ext : vector<32xf32>
+}
diff --git a/mlir/test/Conversion/ArithToAMDGPU/scaling-truncf.mlir b/mlir/test/Conversion/ArithToAMDGPU/scaling-truncf.mlir
index 6d6e1e28d2c2c..90a86084ac93f 100644
--- a/mlir/test/Conversion/ArithToAMDGPU/scaling-truncf.mlir
+++ b/mlir/test/Conversion/ArithToAMDGPU/scaling-truncf.mlir
@@ -185,3 +185,15 @@ func.func @long_fp4_broadcast(%in: vector<32xf32>, %scale: f32) -> vector<32xf4E
     %trunc = arith.scaling_truncf %in, %splat : vector<32xf32>, vector<32xf32> to vector<32xf4E2M1FN>
     return %trunc : vector<32xf4E2M1FN>
 }
+
+// -----
+
+// CHECK-LABEL: @long_fp8_broadcast
+// CHECK-COUNT-8: amdgpu.packed_scaled_trunc %{{.*}} into %{{.+}}[1]
+// CHECK-NOT: amdgpu.packed_scaled_trunc
+// CHECK: return
+func.func @long_fp8_broadcast(%in: vector<32xf32>, %scale: f32) -> vector<32xf8E4M3FN> {
+    %splat = vector.broadcast %scale : f32 to vector<32xf32>
+    %trunc = arith.scaling_truncf %in, %splat : vector<32xf32>, vector<32xf32> to vector<32xf8E4M3FN>
+    return %trunc : vector<32xf8E4M3FN>
+}
