[Mlir-commits] [mlir] [mlir][ArithToAMDGPU] Use native packing support (PR #150342)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Wed Jul 23 17:15:56 PDT 2025


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir

Author: Krzysztof Drewniak (krzysz00)

<details>
<summary>Changes</summary>

The current arith-to-amdgpu patterns for scaling_extf and scaling_truncf don't take full advantage of the native packing ability of the intrinsics being targeted. Scaling extension takes the location of the two elements to be extended as a constant argument (byte for fp4, half for fp8), and scaling truncation takes a 32-bit input register and a byte or half to write the truncated values to.

Not using these features causes unnecessary register pressure. This PR resolves the inefficiency.

It also adds a test for the expected use case of extending or truncating a block of 32 values to/from fp4 with a uniform scale to ensure that this usage has a minimal amount of vector shuffling.

---

Patch is 28.80 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/150342.diff


3 Files Affected:

- (modified) mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp (+62-38) 
- (modified) mlir/test/Conversion/ArithToAMDGPU/scaling-extf.mlir (+24-22) 
- (modified) mlir/test/Conversion/ArithToAMDGPU/scaling-truncf.mlir (+27-33) 


``````````diff
diff --git a/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp b/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp
index 8c68b57877c35..4cf80167b20c2 100644
--- a/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp
+++ b/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp
@@ -449,7 +449,8 @@ LogicalResult
 ScalingExtFRewritePattern::matchAndRewrite(arith::ScalingExtFOp op,
                                            PatternRewriter &rewriter) const {
   Location loc = op.getLoc();
-  constexpr int64_t opWidth = 2;
+  constexpr int64_t opOutWidth = 2;
+  constexpr int64_t opInWidth = 8;
 
   Value in = op.getIn();
   Value scale = op.getScale();
@@ -473,7 +474,7 @@ ScalingExtFRewritePattern::matchAndRewrite(arith::ScalingExtFOp op,
   else if (scaleType.getIntOrFloatBitWidth() > 32)
     scale = arith::TruncFOp::create(rewriter, loc, scaleF32Type, scale);
 
-  VectorType extScaleResultType = VectorType::get(opWidth, outType);
+  VectorType extScaleResultType = VectorType::get(opOutWidth, outType);
 
   if (!outVecType) {
     Value inCast = vector::BroadcastOp::create(rewriter, loc,
@@ -487,10 +488,11 @@ ScalingExtFRewritePattern::matchAndRewrite(arith::ScalingExtFOp op,
 
   VectorType inVecType = cast<VectorType>(in.getType());
   Value origScale = getOriginalVectorValue(op.getScale());
+  VectorType origScaleVecType = dyn_cast<VectorType>(origScale.getType());
 
   ArrayRef<int64_t> inShape = inVecType.getShape();
   SmallVector<int64_t> originalScaleShape;
-  if (auto origScaleVecType = dyn_cast<VectorType>(origScale.getType()))
+  if (origScaleVecType)
     llvm::append_range(originalScaleShape, origScaleVecType.getShape());
 
   originalScaleShape.insert(originalScaleShape.end(),
@@ -524,19 +526,26 @@ ScalingExtFRewritePattern::matchAndRewrite(arith::ScalingExtFOp op,
     Value blockResult =
         rewriter.createOrFold<vector::BroadcastOp>(loc, blockResultType, zero);
 
-    for (int64_t i = 0, sliceWidth = std::min(opWidth, blockSize - i);
+    for (int64_t i = 0, inSliceWidth = std::min(opInWidth, blockSize - i);
          i < blockSize;
-         i += sliceWidth, sliceWidth = std::min(opWidth, blockSize - i)) {
-      Value slice = vector::ExtractStridedSliceOp::create(
-          rewriter, loc, block1D, i, sliceWidth, 1);
-      // TODO: replace this with non-packed ScaledExtOp for sliceWidth == 1
-      Value scaleExt = amdgpu::ScaledExtPackedOp::create(
-          rewriter, loc, extScaleResultType, slice, uniformScale, 0);
-      if (sliceWidth != opWidth)
-        scaleExt = vector::ExtractStridedSliceOp::create(
-            rewriter, loc, scaleExt, 0, sliceWidth, 1);
-      blockResult = vector::InsertStridedSliceOp::create(
-          rewriter, loc, scaleExt, blockResult, i, 1);
+         i += inSliceWidth, inSliceWidth = std::min(opInWidth, blockSize - i)) {
+      Value inSlice = vector::ExtractStridedSliceOp::create(
+          rewriter, loc, block1D, i, inSliceWidth, 1);
+      for (int64_t j = 0,
+                   outSliceWidth = std::min(opOutWidth, inSliceWidth - j);
+           j < inSliceWidth; j += outSliceWidth,
+                   outSliceWidth = std::min(opOutWidth, inSliceWidth - j)) {
+        // TODO: replace this with non-packed ScaledExtOp for sliceWidth == 1
+        Value scaleExt = amdgpu::ScaledExtPackedOp::create(
+            rewriter, loc, extScaleResultType, inSlice, uniformScale,
+            j / opOutWidth);
+        if (outSliceWidth < opOutWidth) {
+          scaleExt = vector::ExtractStridedSliceOp::create(
+              rewriter, loc, scaleExt, 0, outSliceWidth, 1);
+        }
+        blockResult = vector::InsertStridedSliceOp::create(
+            rewriter, loc, scaleExt, blockResult, i + j, 1);
+      }
     }
 
     VectorType resultType = VectorType::get(ratio, outType);
@@ -555,7 +564,7 @@ LogicalResult
 ScalingTruncFRewritePattern::matchAndRewrite(arith::ScalingTruncFOp op,
                                              PatternRewriter &rewriter) const {
   Location loc = op.getLoc();
-  constexpr int64_t opWidth = 2;
+  constexpr int64_t opInWidth = 2;
 
   Value in = op.getIn();
   Value scale = op.getScale();
@@ -568,7 +577,6 @@ ScalingTruncFRewritePattern::matchAndRewrite(arith::ScalingTruncFOp op,
 
   VectorType outVecType = dyn_cast<VectorType>(out.getType());
   VectorType scaleVecType = dyn_cast<VectorType>(scale.getType());
-
   if (outVecType && outVecType.isScalable())
     return failure();
 
@@ -581,8 +589,8 @@ ScalingTruncFRewritePattern::matchAndRewrite(arith::ScalingTruncFOp op,
 
   Value zero = arith::ConstantOp::create(rewriter, loc, outType,
                                          rewriter.getFloatAttr(outType, 0.0));
-  unsigned numPackedElem = 32 / outType.getIntOrFloatBitWidth();
-  VectorType truncScaleResultType = VectorType::get(numPackedElem, outType);
+  int64_t opOutWidth = 32 / outType.getIntOrFloatBitWidth();
+  VectorType truncScaleResultType = VectorType::get(opOutWidth, outType);
 
   if (!outVecType) {
     Type inVecType = VectorType::get(1, inType);
@@ -598,16 +606,16 @@ ScalingTruncFRewritePattern::matchAndRewrite(arith::ScalingTruncFOp op,
 
   VectorType inVecType = cast<VectorType>(in.getType());
   Value origScale = getOriginalVectorValue(op.getScale());
+  VectorType origScaleVecType = dyn_cast<VectorType>(origScale.getType());
 
   ArrayRef<int64_t> inShape = inVecType.getShape();
-  SmallVector<int64_t> originalScaleShape;
-  if (auto origScaleVecType = dyn_cast<VectorType>(origScale.getType()))
-    llvm::append_range(originalScaleShape, origScaleVecType.getShape());
+  SmallVector<int64_t> scaleShape;
+  if (origScaleVecType)
+    llvm::append_range(scaleShape, origScaleVecType.getShape());
 
-  originalScaleShape.insert(originalScaleShape.end(),
-                            inShape.size() - originalScaleShape.size(), 1);
+  scaleShape.insert(scaleShape.end(), inShape.size() - scaleShape.size(), 1);
 
-  auto maybeRatio = computeShapeRatio(inShape, originalScaleShape);
+  auto maybeRatio = computeShapeRatio(inShape, scaleShape);
   assert(maybeRatio &&
          "failed to derive block size from broadcast or splat operation");
 
@@ -633,20 +641,36 @@ ScalingTruncFRewritePattern::matchAndRewrite(arith::ScalingTruncFOp op,
     Value blockResult =
         rewriter.createOrFold<vector::BroadcastOp>(loc, blockResultType, zero);
 
-    for (int64_t i = 0, sliceWidth = std::min(opWidth, blockSize - i);
-         i < blockSize;
-         i += sliceWidth, sliceWidth = std::min(opWidth, blockSize - i)) {
-      Value slice = vector::ExtractStridedSliceOp::create(
-          rewriter, loc, block1D, i, sliceWidth, 1);
-      // TODO: replace this with non-packed ScaledTruncOp for sliceWidth == 1
-      Value scaleTrunc = amdgpu::PackedScaledTruncOp::create(
-          rewriter, loc, truncScaleResultType, slice, uniformScale, 0,
-          /*existing=*/nullptr);
-      int64_t packedWidth =
-          cast<VectorType>(scaleTrunc.getType()).getNumElements();
-      if (packedWidth != opWidth)
+    for (int64_t i = 0, outSliceWidth = std::min(opOutWidth, blockSize - i);
+         i < blockSize; i += outSliceWidth,
+                 outSliceWidth = std::min(opOutWidth, blockSize - i)) {
+      Value scaleTrunc;
+      // Case where <= 2 elements are being truncated.
+      if (outSliceWidth <= opInWidth) {
+        Value slice = vector::ExtractStridedSliceOp::create(
+            rewriter, loc, block1D, i, outSliceWidth, 1);
+        // TODO: replace this with non-packed ScaledTruncOp for sliceWidth == 1
+        scaleTrunc = amdgpu::PackedScaledTruncOp::create(
+            rewriter, loc, truncScaleResultType, slice, uniformScale, 0,
+            /*existing=*/nullptr);
+      } else {
+        scaleTrunc = vector::BroadcastOp::create(rewriter, loc,
+                                                 truncScaleResultType, zero);
+        for (int64_t j = 0,
+                     inSliceWidth = std::min(opInWidth, outSliceWidth - j);
+             j < outSliceWidth; j += opInWidth,
+                     inSliceWidth = std::min(opInWidth, outSliceWidth - j)) {
+          Value slice = vector::ExtractStridedSliceOp::create(
+              rewriter, loc, block1D, i + j, inSliceWidth, 1);
+          scaleTrunc = amdgpu::PackedScaledTruncOp::create(
+              rewriter, loc, truncScaleResultType, slice, uniformScale,
+              j / opInWidth, scaleTrunc);
+        }
+      }
+      if (outSliceWidth != opOutWidth) {
         scaleTrunc = vector::ExtractStridedSliceOp::create(
-            rewriter, loc, scaleTrunc, 0, sliceWidth, 1);
+            rewriter, loc, scaleTrunc, 0, outSliceWidth, 1);
+      }
       blockResult = vector::InsertStridedSliceOp::create(
           rewriter, loc, scaleTrunc, blockResult, i, 1);
     }
diff --git a/mlir/test/Conversion/ArithToAMDGPU/scaling-extf.mlir b/mlir/test/Conversion/ArithToAMDGPU/scaling-extf.mlir
index b98045195f8cf..a837bdb8be4fa 100644
--- a/mlir/test/Conversion/ArithToAMDGPU/scaling-extf.mlir
+++ b/mlir/test/Conversion/ArithToAMDGPU/scaling-extf.mlir
@@ -163,27 +163,23 @@ func.func @conversion_f4_f16_fallback(%in: vector<2x2xf4E2M1FN>, %scale: vector<
 // CHECK-DAG:     %[[SCALE_CAST:.+]] = vector.shape_cast %[[BCAST]]
 // CHECK-DAG:     %[[SCALE_EXT:.+]] = arith.extf %[[SCALE_CAST]]
 // CHECK-DAG:     vector.extract_strided_slice %[[IN_CAST]] {offsets = [0, 0, 0], sizes = [1, 1, 4], strides = [1, 1, 1]}
-// CHECK-NEXT:    vector.shape_cast
+// CHECK-NEXT:    %[[IN_SLICE_CAST:.+]] = vector.shape_cast
 // CHECK-NEXT:    vector.extract %[[SCALE_EXT]][0, 0, 0]
-// CHECK-NEXT:    vector.extract_strided_slice %{{.+}} {offsets = [0], sizes = [2], strides = [1]}
-// CHECK-NEXT:    amdgpu.scaled_ext_packed
-// CHECK-NEXT:    vector.insert_strided_slice %{{.+}}, %{{.+}} {offsets = [0], strides = [1]}
-// CHECK-NEXT:    vector.extract_strided_slice %{{.+}} {offsets = [2], sizes = [2], strides = [1]}
-// CHECK-NEXT:    amdgpu.scaled_ext_packed
-// CHECK-NEXT:    vector.insert_strided_slice %{{.+}}, %{{.+}} {offsets = [2], strides = [1]}
+// CHECK-NEXT:    %[[LOWHALF:.+]] = amdgpu.scaled_ext_packed %[[IN_SLICE_CAST]][0]
+// CHECK-NEXT:    vector.insert_strided_slice %[[LOWHALF]], %{{.+}} {offsets = [0], strides = [1]}
+// CHECK-NEXT:    %[[HIGHHALF:.+]] = amdgpu.scaled_ext_packed %[[IN_SLICE_CAST]][1]
+// CHECK-NEXT:    vector.insert_strided_slice %[[HIGHHALF]], %{{.+}} {offsets = [2], strides = [1]}
 // CHECK-NEXT:    vector.shape_cast
 // CHECK-NEXT:    vector.insert_strided_slice %{{.+}} {offsets = [0, 0, 0], strides = [1, 1, 1]}
 // CHECK-NEXT:    vector.extract_strided_slice %[[IN_CAST]] {offsets = [0, 1, 0], sizes = [1, 1, 4], strides = [1, 1, 1]}
 // CHECK-NEXT:    vector.shape_cast
 // CHECK-NEXT:    vector.extract %[[SCALE_EXT]][0, 1, 0]
-// CHECK-NEXT:    vector.extract_strided_slice %{{.+}} {offsets = [0], sizes = [2], strides = [1]}
 // CHECK-NEXT:    amdgpu.scaled_ext_packed
 // CHECK-NEXT:    vector.insert_strided_slice %{{.+}}, %{{.+}} {offsets = [0], strides = [1]}
-// CHECK-NEXT:    vector.extract_strided_slice %{{.+}} {offsets = [2], sizes = [2], strides = [1]}
 // CHECK-NEXT:    amdgpu.scaled_ext_packed
 // CHECK-NEXT:    vector.insert_strided_slice %{{.+}}, %{{.+}} {offsets = [2], strides = [1]}
 // CHECK-NEXT:    vector.shape_cast
-// CHECK-NEXT:    vector.insert_strided_slice %{{.+}}, %{{.+}} {offsets = [0, 1, 0], strides = [1, 1, 1]} 
+// CHECK-NEXT:    vector.insert_strided_slice %{{.+}}, %{{.+}} {offsets = [0, 1, 0], strides = [1, 1, 1]}
 func.func @conversion_broadcast(%in: vector<8x8xf8E5M2>, %scale: vector<8x2xf8E8M0FNU>) -> vector<8x8xf32> {
     %bc = vector.broadcast %scale : vector<8x2xf8E8M0FNU> to vector<4x8x2xf8E8M0FNU>
     %cast1 = vector.shape_cast %in : vector<8x8xf8E5M2> to vector<8x2x4xf8E5M2>
@@ -203,21 +199,17 @@ func.func @conversion_broadcast(%in: vector<8x8xf8E5M2>, %scale: vector<8x2xf8E8
 // CHECK-NEXT:    %[[SCALE_EXT:.+]] = arith.extf %[[SCALE_FLAT]] : vector<6xf8E8M0FNU> to vector<6xf32>
 // CHECK-NEXT:    %[[IN_SLICE_0:.+]] = vector.extract_strided_slice %arg0 {offsets = [0], sizes = [3], strides = [1]} : vector<6xf8E5M2> to vector<3xf8E5M2>
 // CHECK-NEXT:    %[[SCALE_SCALAR_0:.+]] = vector.extract %[[SCALE_EXT]][0] : f32 from vector<6xf32>
-// CHECK-NEXT:    %[[IN_CHUNK_0A:.+]] = vector.extract_strided_slice %[[IN_SLICE_0]] {offsets = [0], sizes = [2], strides = [1]} : vector<3xf8E5M2> to vector<2xf8E5M2>
-// CHECK-NEXT:    %[[PACKED_0A:.+]] = amdgpu.scaled_ext_packed %[[IN_CHUNK_0A]][0], %[[SCALE_SCALAR_0]] : vector<2xf8E5M2> to vector<2xf32>
+// CHECK-NEXT:    %[[PACKED_0A:.+]] = amdgpu.scaled_ext_packed %[[IN_SLICE_0]][0], %[[SCALE_SCALAR_0]] : vector<3xf8E5M2> to vector<2xf32>
 // CHECK-NEXT:    %[[PARTIAL_ACC_0:.+]] = vector.insert_strided_slice %[[PACKED_0A]], %[[CST_PARTIAL]] {offsets = [0], strides = [1]} : vector<2xf32> into vector<3xf32>
-// CHECK-NEXT:    %[[IN_CHUNK_0B:.+]] = vector.extract_strided_slice %[[IN_SLICE_0]] {offsets = [2], sizes = [1], strides = [1]} : vector<3xf8E5M2> to vector<1xf8E5M2>
-// CHECK-NEXT:    %[[PACKED_0B_RAW:.+]] = amdgpu.scaled_ext_packed %[[IN_CHUNK_0B]][0], %[[SCALE_SCALAR_0]] : vector<1xf8E5M2> to vector<2xf32>
+// CHECK-NEXT:    %[[PACKED_0B_RAW:.+]] = amdgpu.scaled_ext_packed %[[IN_SLICE_0]][1], %[[SCALE_SCALAR_0]] : vector<3xf8E5M2> to vector<2xf32>
 // CHECK-NEXT:    %[[PACKED_0B:.+]] = vector.extract_strided_slice %[[PACKED_0B_RAW]] {offsets = [0], sizes = [1], strides = [1]} : vector<2xf32> to vector<1xf32>
 // CHECK-NEXT:    %[[OUT_SLICE_0:.+]] = vector.insert_strided_slice %[[PACKED_0B]], %[[PARTIAL_ACC_0]] {offsets = [2], strides = [1]} : vector<1xf32> into vector<3xf32>
 // CHECK-NEXT:    %[[FINAL_ACC_A:.+]] = vector.insert_strided_slice %[[OUT_SLICE_0]], %[[CST_FINAL]] {offsets = [0], strides = [1]} : vector<3xf32> into vector<6xf32>
 // CHECK-NEXT:    %[[IN_SLICE_1:.+]] = vector.extract_strided_slice %arg0 {offsets = [3], sizes = [3], strides = [1]} : vector<6xf8E5M2> to vector<3xf8E5M2>
 // CHECK-NEXT:    %[[SCALE_SCALAR_1:.+]] = vector.extract %[[SCALE_EXT]][3] : f32 from vector<6xf32>
-// CHECK-NEXT:    %[[IN_CHUNK_1A:.+]] = vector.extract_strided_slice %[[IN_SLICE_1]] {offsets = [0], sizes = [2], strides = [1]} : vector<3xf8E5M2> to vector<2xf8E5M2>
-// CHECK-NEXT:    %[[PACKED_1A:.+]] = amdgpu.scaled_ext_packed %[[IN_CHUNK_1A]][0], %[[SCALE_SCALAR_1]] : vector<2xf8E5M2> to vector<2xf32>
+// CHECK-NEXT:    %[[PACKED_1A:.+]] = amdgpu.scaled_ext_packed %[[IN_SLICE_1]][0], %[[SCALE_SCALAR_1]] : vector<3xf8E5M2> to vector<2xf32>
 // CHECK-NEXT:    %[[PARTIAL_ACC_1:.+]] = vector.insert_strided_slice %[[PACKED_1A]], %[[CST_PARTIAL]] {offsets = [0], strides = [1]} : vector<2xf32> into vector<3xf32>
-// CHECK-NEXT:    %[[IN_CHUNK_1B:.+]] = vector.extract_strided_slice %[[IN_SLICE_1]] {offsets = [2], sizes = [1], strides = [1]} : vector<3xf8E5M2> to vector<1xf8E5M2>
-// CHECK-NEXT:    %[[PACKED_1B_RAW:.+]] = amdgpu.scaled_ext_packed %[[IN_CHUNK_1B]][0], %[[SCALE_SCALAR_1]] : vector<1xf8E5M2> to vector<2xf32>
+// CHECK-NEXT:    %[[PACKED_1B_RAW:.+]] = amdgpu.scaled_ext_packed %[[IN_SLICE_1]][1], %[[SCALE_SCALAR_1]] : vector<3xf8E5M2> to vector<2xf32>
 // CHECK-NEXT:    %[[PACKED_1B:.+]] = vector.extract_strided_slice %[[PACKED_1B_RAW]] {offsets = [0], sizes = [1], strides = [1]} : vector<2xf32> to vector<1xf32>
 // CHECK-NEXT:    %[[OUT_SLICE_1:.+]] = vector.insert_strided_slice %[[PACKED_1B]], %[[PARTIAL_ACC_1]] {offsets = [2], strides = [1]} : vector<1xf32> into vector<3xf32>
 // CHECK-NEXT:    %[[RESULT:.+]] = vector.insert_strided_slice %[[OUT_SLICE_1]], %[[FINAL_ACC_A]] {offsets = [3], strides = [1]} : vector<3xf32> into vector<6xf32>
@@ -236,11 +228,9 @@ func.func @conversion_broadcast_odd(%in: vector<6xf8E5M2>, %scale: vector<2xf8E8
 // CHECK-DAG:     %[[SCALE_SPLAT:.+]] = vector.broadcast %arg1 : f8E8M0FNU to vector<4xf8E8M0FNU>
 // CHECK-DAG:     %[[SCALE_EXTF:.+]] = arith.extf %[[SCALE_SPLAT]] : vector<4xf8E8M0FNU> to vector<4xf32>
 // CHECK-DAG:     %[[SCALE_SCALAR:.+]] = vector.extract %[[SCALE_EXTF]][0] : f32 from vector<4xf32>
-// CHECK:         %[[IN_CHUNK0:.+]] = vector.extract_strided_slice %arg0 {offsets = [0], sizes = [2], strides = [1]} : vector<4xf8E5M2> to vector<2xf8E5M2>
-// CHECK-NEXT:    %[[OUT_CHUNK0:.+]] = amdgpu.scaled_ext_packed %[[IN_CHUNK0]][0], %[[SCALE_SCALAR]] : vector<2xf8E5M2> to vector<2xf32>
+// CHECK:         %[[OUT_CHUNK0:.+]] = amdgpu.scaled_ext_packed %arg0[0], %[[SCALE_SCALAR]] : vector<4xf8E5M2> to vector<2xf32>
 // CHECK-NEXT:    %[[ACCUM_A:.+]] = vector.insert_strided_slice %[[OUT_CHUNK0]], %[[CST]] {offsets = [0], strides = [1]} : vector<2xf32> into vector<4xf32>
-// CHECK-NEXT:    %[[IN_CHUNK1:.+]] = vector.extract_strided_slice %arg0 {offsets = [2], sizes = [2], strides = [1]} : vector<4xf8E5M2> to vector<2xf8E5M2>
-// CHECK-NEXT:    %[[OUT_CHUNK1:.+]] = amdgpu.scaled_ext_packed %[[IN_CHUNK1]][0], %[[SCALE_SCALAR]] : vector<2xf8E5M2> to vector<2xf32>
+// CHECK-NEXT:    %[[OUT_CHUNK1:.+]] = amdgpu.scaled_ext_packed %arg0[1], %[[SCALE_SCALAR]] : vector<4xf8E5M2> to vector<2xf32>
 // CHECK-NEXT:    %[[FINAL_RESULT:.+]] = vector.insert_strided_slice %[[OUT_CHUNK1]], %[[ACCUM_A]] {offsets = [2], strides = [1]} : vector<2xf32> into vector<4xf32>
 // CHECK-NEXT:    return %[[FINAL_RESULT]] : vector<4xf32>
 func.func @conversion_broadcast(%in: vector<4xf8E5M2>, %scale: f8E8M0FNU) -> vector<4xf32> {
@@ -261,3 +251,15 @@ func.func @conversion_scalar(%in: f8E5M2, %scale: f8E8M0FNU) -> f32 {
     %ext = arith.scaling_extf %in, %scale : f8E5M2, f8E8M0FNU to f32
     return %ext : f32
 }
+
+// -----
+
+// CHECK-LABEL: @long_fp4_broadcast
+// CHECK-COUNT-4: amdgpu.scaled_ext_packed %{{.+}}[3]
+// CHECK-NOT: amdgpu.scaled_ext_packed
+// CHECK: return
+func.func @long_fp4_broadcast(%in: vector<32xf4E2M1FN>, %scale: f32) -> vector<32xf32> {
+    %splat = vector.broadcast %scale : f32 to vector<32xf32>
+    %ext = arith.scaling_extf %in, %splat : vector<32xf4E2M1FN>, vector<32xf32> to vector<32xf32>
+    return %ext : vector<32xf32>
+}
diff --git a/mlir/test/Conversion/ArithToAMDGPU/scaling-truncf.mlir b/mlir/test/Conversion/ArithToAMDGPU/scaling-truncf.mlir
index 488e75cbb1843..6d6e1e28d2c2c 100644
--- a/mlir/test/Conversion/ArithToAMDGPU/scaling-truncf.mlir
+++ b/mlir/test/Conversion/ArithToAMDGPU/scaling-truncf.mlir
@@ -88,28 +88,20 @@ func.func @conversion_f4_fallback(%in: vector<2x2xf32>, %scale: vector<2x2xf8E8M
 // CHECK-NEXT:    vector.shape_cast
 // CHECK-NEXT:    vector.extract %[[SCALE_EXT]][0, 0, 0]
 // CHECK-NEXT:    vector.extract_strided_slice %{{.+}} {offsets = [0], sizes = [2], strides = [1]}
-// CHECK-NEXT:    amdgpu.packed_scaled_trunc
-// CHECK-NEXT:    vector.extract_strided_slice %{{.+}} {offsets = [0], sizes = [2], strides = [1]}
-// CHECK-NEXT:    vector.insert_strided_slice %{{.+}}, %{{.+}} {offsets = [0], strides = [1]}
+// CHECK-NEXT:    %[[P1:.+]] = amdgpu.packed_scaled_trunc
 // CHECK-NEXT:    vector.extract_strided_slice %{{.+}} {offsets = [2], sizes = [2], strides = [1]}
-// CHECK-NEXT:    amdgpu.packed_scaled_trunc
-// CHECK-NEXT:    vector.extract_strided_slice
-// CHECK-NEXT:    vector.insert_strided_slice %{{.+}}, %{{.+}} {offsets = [2], strides = [1]}
-// CHECK-NEXT:    vector.shape_cast
-// CHECK-NEXT:    vector.insert_strided_slice %{{.+}} {offsets = [0, 0, 0], strides = [1, 1, 1]}
+// CHECK-NEXT:    %[[P2:.+]] = amdgpu.packed_scaled_trunc {{.*}} into %[[P1]][1]
+// CHECK-NEXT:    %[[P2_CAST:.+]] = vector.shape_cast %[[P2]] : vector<4xf8E5M2> to vector<1x1x4xf8E5M2>
+// CHECK-NEXT:    vector.insert_strided_slice %[[P2_CAST]], %{{.+}} {offsets = [0, 0, 0], strides = [1, 1, 1]}
 // CHECK-NEXT:    vector.extract_strided_slice %[[IN_CAST]] {offsets = [0, 1, 0], sizes = [1, 1, 4], strides = [1, 1, 1]}
 // CHECK-NEXT:    vector.shape_cast
 // CHECK-NEXT:    vector.extract %[[SCALE_EXT]][0, 1, 0]
 // CHECK-NEXT:    vector.extract_strided_slice %{{.+}} {offsets = [0], sizes = [2], strides = [1]}
 // CHECK-NEXT:    amdgpu.packed_scaled_trunc
-// CHECK-NEXT:    vector.extract_strided_slice %{{.+}} {offsets = [0], sizes = [2], strides = [1]}
-// CHECK-NEXT:    vector.insert_strided_slice %{{.+}}, %{{.+}} {offsets = [0], strides = [1]}
 // CHECK-NEXT:    vector.extract_strided_slice %{{.+}} {offsets = [2], sizes = [2], strides = [1]}
 // CHECK-NEXT:    amdgpu.packed_scaled_trunc
-// CHECK-NEXT:    vector.extract_strided_slice %{{.+}} {offsets = [0], sizes = [2], strides = [1]}
-// CHECK-N...
[truncated]

``````````

</details>


https://github.com/llvm/llvm-project/pull/150342


More information about the Mlir-commits mailing list