[Mlir-commits] [mlir] [mlir][amdgpu] Add conversion from arith.scaling_extf to amdgpu (PR #146372)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Mon Jun 30 08:40:22 PDT 2025


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir-gpu

Author: Tim Gymnich (tgymnich)

<details>
<summary>Changes</summary>



---

Patch is 52.50 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/146372.diff


2 Files Affected:

- (modified) mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp (+165) 
- (added) mlir/test/Conversion/ArithToAMDGPU/scale_ext.mlir (+553) 


``````````diff
diff --git a/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp b/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp
index 3596b3235a631..22cd4703c6005 100644
--- a/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp
+++ b/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp
@@ -15,11 +15,17 @@
 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
 #include "mlir/Dialect/LLVMIR/ROCDLDialect.h"
 #include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
+#include "mlir/Dialect/Vector/Utils/VectorUtils.h"
 #include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/Matchers.h"
 #include "mlir/IR/PatternMatch.h"
 #include "mlir/IR/TypeUtilities.h"
 #include "mlir/Pass/Pass.h"
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+#include "mlir/Dialect/Utils/IndexingUtils.h"
+#include "llvm/IR/PatternMatch.h"
+#include "llvm/Support/LogicalResult.h"
 
 namespace mlir {
 #define GEN_PASS_DEF_ARITHTOAMDGPUCONVERSIONPASS
@@ -32,6 +38,7 @@ using namespace mlir::amdgpu;
 namespace {
 // Define commonly used chipsets versions for convenience.
 constexpr Chipset kGfx942 = Chipset(9, 4, 2);
+constexpr Chipset kGfx950 = Chipset(9, 5, 0);
 
 struct ArithToAMDGPUConversionPass final
     : impl::ArithToAMDGPUConversionPassBase<ArithToAMDGPUConversionPass> {
@@ -73,6 +80,28 @@ struct TruncfToFloat16RewritePattern final
                                 PatternRewriter &rewriter) const override;
 };
 
+struct ScalingExtFRewritePattern final : OpRewritePattern<arith::ScalingExtFOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  Chipset chipset;
+  ScalingExtFRewritePattern(MLIRContext *ctx, Chipset chipset)
+      : OpRewritePattern::OpRewritePattern(ctx), chipset(chipset) {}
+
+  LogicalResult matchAndRewrite(arith::ScalingExtFOp op,
+                                PatternRewriter &rewriter) const override;
+};
+
+struct ScalingTruncFRewritePattern final : OpRewritePattern<arith::ScalingTruncFOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  Chipset chipset;
+  ScalingTruncFRewritePattern(MLIRContext *ctx, Chipset chipset)
+      : OpRewritePattern::OpRewritePattern(ctx), chipset(chipset) {}
+
+  LogicalResult matchAndRewrite(arith::ScalingTruncFOp op,
+                                PatternRewriter &rewriter) const override;
+};
+
 } // end namespace
 
 static bool isSupportedF8(Type elementType, Chipset chipset) {
@@ -395,6 +424,137 @@ LogicalResult TruncfToFloat16RewritePattern::matchAndRewrite(
   return success();
 }
 
+static Value getOriginalVectorValue(Value value) {
+  Value current = value;
+  while (Operation *definingOp = current.getDefiningOp()) {
+    bool skipOp = llvm::TypeSwitch<Operation *, bool>(definingOp)
+      .Case<vector::ShapeCastOp>(
+        [&current](auto op) {
+          current = op.getSource();
+          return true;
+        })
+      .Case<vector::BroadcastOp>(
+        [&current](auto op) {
+          current = op.getSource();
+          return false;
+        })
+      .Case<vector::SplatOp>(
+        [&current](auto op) {
+          current = op.getInput();
+          return false;
+        })
+      .Default([](Operation *) {
+        return false;
+      });
+
+    if (!skipOp) {
+      break;
+    }
+  }
+  return current;
+}
+
+LogicalResult
+ScalingExtFRewritePattern::matchAndRewrite(arith::ScalingExtFOp op,
+                                           PatternRewriter &rewriter) const {
+  Location loc = op.getLoc();
+  constexpr const int64_t opWidth = 2;
+
+  Value in = op.getIn();
+  Value scale = op.getScale();
+  Value out = op.getOut();
+
+  Type f32 = rewriter.getF32Type();
+  Type inType = getElementTypeOrSelf(in);
+  Type scaleType = getElementTypeOrSelf(scale);
+  Type outType = getElementTypeOrSelf(out);
+  VectorType scaleVecType = dyn_cast<VectorType>(scale.getType());
+  VectorType inVecType = dyn_cast<VectorType>(in.getType());
+  VectorType outVecType = dyn_cast<VectorType>(out.getType());
+
+  if (outVecType && outVecType.isScalable())
+    return failure();
+
+  Type scaleF32Type = scaleVecType ? VectorType::get(scaleVecType.getShape(), f32) : f32;
+  if (scaleType.getIntOrFloatBitWidth() < 32)
+    scale = rewriter.create<arith::ExtFOp>(loc, scaleF32Type, scale);
+  else if (scaleType.getIntOrFloatBitWidth() > 32)
+    scale = rewriter.create<arith::TruncFOp>(loc, scaleF32Type, scale);
+
+  VectorType extScaleResultType = VectorType::get(opWidth, outType);
+
+  if (!outVecType) {
+      Value inCast = rewriter.create<vector::SplatOp>(loc, VectorType::get(1, inType), in);
+      Value scaleExt = rewriter.create<amdgpu::ScaledExtPackedOp>(loc, extScaleResultType, inCast, scale, 0);
+      scaleExt = rewriter.replaceOpWithNewOp<vector::ExtractOp>(op, scaleExt, 0);
+      return success();
+  }
+
+  Value origScale = getOriginalVectorValue(scale);
+  Type origScaleType = origScale.getType();
+  VectorType origScaleVecType = isa<VectorType>(origScaleType) ? cast<VectorType>(origScaleType) : VectorType::get(1, origScaleType);
+  
+  ArrayRef<int64_t> originalScaleShape = origScaleVecType.getShape();
+  ArrayRef<int64_t> inShape = inVecType.getShape();
+
+  SmallVector<int64_t> paddedScaleShape(originalScaleShape);
+  paddedScaleShape.insert(paddedScaleShape.end(), inShape.size() - originalScaleShape.size(),
+                       1);
+
+  auto ratio = computeShapeRatio(inShape, paddedScaleShape);
+  if (!ratio)
+    return failure();
+
+  const int64_t blockSize = computeProduct(*ratio);
+
+  Value zero = rewriter.create<arith::ConstantOp>(
+      loc, outType, rewriter.getFloatAttr(outType, 0.0));
+  Value result = rewriter.createOrFold<vector::SplatOp>(loc, outVecType, zero);
+
+  for (SmallVector<int64_t> offsets : StaticTileOffsetRange(inShape, *ratio)) {
+    SmallVector<int64_t> strides(offsets.size(), 1);
+    Value block = rewriter.create<vector::ExtractStridedSliceOp>(
+        loc, in, offsets, *ratio, strides);
+    VectorType block1DType = VectorType::get(blockSize, inType);
+    Value block1D =
+        rewriter.create<vector::ShapeCastOp>(loc, block1DType, block);
+    Value uniformScale =
+        rewriter.create<vector::ExtractOp>(loc, scale, offsets);
+
+    VectorType blockResultType = VectorType::get(blockSize, outType);
+    Value blockResult =
+        rewriter.createOrFold<vector::SplatOp>(loc, blockResultType, zero);
+
+    for (int64_t i = 0, sliceWidth = opWidth - blockSize % opWidth;
+         i < blockSize;
+         i += sliceWidth, sliceWidth = opWidth - blockSize % opWidth) {
+      Value slice = rewriter.create<vector::ExtractStridedSliceOp>(
+          loc, block1D, i, sliceWidth, 1);
+      Value scaleExt = rewriter.create<amdgpu::ScaledExtPackedOp>(
+          loc, extScaleResultType, slice, uniformScale, 0);
+      if (sliceWidth != opWidth)
+        scaleExt = rewriter.create<vector::ExtractStridedSliceOp>(
+            loc, scaleExt, 0, sliceWidth, 1);
+      blockResult = rewriter.create<vector::InsertStridedSliceOp>(
+          loc, scaleExt, blockResult, i, 1);
+    }
+
+    VectorType resultType = VectorType::get(*ratio, outType);
+    Value cast = rewriter.create<vector::ShapeCastOp>(loc, resultType,
+                                                            blockResult);
+    result = rewriter.create<vector::InsertStridedSliceOp>(
+        loc, cast, result, offsets, strides);
+  }
+
+  rewriter.replaceOp(op, result);
+
+  return success();
+}
+
+LogicalResult ScalingTruncFRewritePattern::matchAndRewrite(arith::ScalingTruncFOp op, PatternRewriter &rewriter) const {
+  return success();
+}
+
 void mlir::arith::populateArithToAMDGPUConversionPatterns(
     RewritePatternSet &patterns, bool convertFP8Arithmetic,
     bool saturateFP8Truncf, bool allowPackedF16Rtz, Chipset chipset) {
@@ -406,6 +566,11 @@ void mlir::arith::populateArithToAMDGPUConversionPatterns(
   }
   if (allowPackedF16Rtz)
     patterns.add<TruncfToFloat16RewritePattern>(patterns.getContext());
+
+  if (chipset >= kGfx950) {
+    patterns.add<ScalingExtFRewritePattern>(patterns.getContext(), chipset);
+    patterns.add<ScalingTruncFRewritePattern>(patterns.getContext(), chipset);
+  }
 }
 
 void ArithToAMDGPUConversionPass::runOnOperation() {
diff --git a/mlir/test/Conversion/ArithToAMDGPU/scale_ext.mlir b/mlir/test/Conversion/ArithToAMDGPU/scale_ext.mlir
new file mode 100644
index 0000000000000..1669926ae48cc
--- /dev/null
+++ b/mlir/test/Conversion/ArithToAMDGPU/scale_ext.mlir
@@ -0,0 +1,553 @@
+// RUN: mlir-opt --split-input-file %s -convert-arith-to-amdgpu="chipset=gfx950" | FileCheck %s
+
+// CHECK-LABEL: @conversion_f8_fallback
+// CHECK:         [[CST:%.+]] = arith.constant dense<0.000000e+00> : vector<2x2xf32>
+// CHECK-NEXT:    [[SCALE_EXT:%.+]] = arith.extf %arg1 : vector<2x2xf8E8M0FNU> to vector<2x2xf32>
+// CHECK-NEXT:    [[IN_SLICE_00:%.+]] = vector.extract_strided_slice %arg0
+// CHECK-NEXT:    [[IN_SCALAR_00:%.+]] = vector.shape_cast [[IN_SLICE_00]]
+// CHECK-NEXT:    [[SCALE_SCALAR_00:%.+]] = vector.extract [[SCALE_EXT]][0, 0]
+// CHECK-NEXT:    [[PACKED_00:%.+]] = amdgpu.scaled_ext_packed [[IN_SCALAR_00]][0], [[SCALE_SCALAR_00]]
+// CHECK-NEXT:    [[OUT_SLICE_00:%.+]] = vector.extract_strided_slice [[PACKED_00]]
+// CHECK-NEXT:    [[OUT_SCALAR_00:%.+]] = vector.shape_cast [[OUT_SLICE_00]]
+// CHECK-NEXT:    [[ACC_A:%.+]] = vector.insert_strided_slice [[OUT_SCALAR_00]], [[CST]]
+// CHECK-NEXT:    [[IN_SLICE_01:%.+]] = vector.extract_strided_slice %arg0
+// CHECK-NEXT:    [[IN_SCALAR_01:%.+]] = vector.shape_cast [[IN_SLICE_01]]
+// CHECK-NEXT:    [[SCALE_SCALAR_01:%.+]] = vector.extract [[SCALE_EXT]][0, 1]
+// CHECK-NEXT:    [[PACKED_01:%.+]] = amdgpu.scaled_ext_packed [[IN_SCALAR_01]][0], [[SCALE_SCALAR_01]]
+// CHECK-NEXT:    [[OUT_SLICE_01:%.+]] = vector.extract_strided_slice [[PACKED_01]]
+// CHECK-NEXT:    [[OUT_SCALAR_01:%.+]] = vector.shape_cast [[OUT_SLICE_01]]
+// CHECK-NEXT:    [[ACC_B:%.+]] = vector.insert_strided_slice [[OUT_SCALAR_01]], [[ACC_A]]
+// CHECK-NEXT:    [[IN_SLICE_10:%.+]] = vector.extract_strided_slice %arg0
+// CHECK-NEXT:    [[IN_SCALAR_10:%.+]] = vector.shape_cast [[IN_SLICE_10]]
+// CHECK-NEXT:    [[SCALE_SCALAR_10:%.+]] = vector.extract [[SCALE_EXT]][1, 0]
+// CHECK-NEXT:    [[PACKED_10:%.+]] = amdgpu.scaled_ext_packed [[IN_SCALAR_10]][0], [[SCALE_SCALAR_10]]
+// CHECK-NEXT:    [[OUT_SLICE_10:%.+]] = vector.extract_strided_slice [[PACKED_10]]
+// CHECK-NEXT:    [[OUT_SCALAR_10:%.+]] = vector.shape_cast [[OUT_SLICE_10]]
+// CHECK-NEXT:    [[ACC_A:%.+]] = vector.insert_strided_slice [[OUT_SCALAR_10]], [[ACC_B]]
+// CHECK-NEXT:    [[IN_SLICE_11:%.+]] = vector.extract_strided_slice %arg0
+// CHECK-NEXT:    [[IN_SCALAR_11:%.+]] = vector.shape_cast [[IN_SLICE_11]]
+// CHECK-NEXT:    [[SCALE_SCALAR_11:%.+]] = vector.extract [[SCALE_EXT]][1, 1]
+// CHECK-NEXT:    [[PACKED_11:%.+]] = amdgpu.scaled_ext_packed [[IN_SCALAR_11]][0], [[SCALE_SCALAR_11]]
+// CHECK-NEXT:    [[OUT_SLICE_11:%.+]] = vector.extract_strided_slice [[PACKED_11]]
+// CHECK-NEXT:    [[OUT_SCALAR_11:%.+]] = vector.shape_cast [[OUT_SLICE_11]]
+// CHECK-NEXT:    [[ACC_B:%.+]] = vector.insert_strided_slice [[OUT_SCALAR_11]], [[ACC_A]]
+// CHECK-NEXT:    return [[ACC_B]] : vector<2x2xf32>
+func.func @conversion_f8_fallback(%in: vector<2x2xf8E5M2>, %scale: vector<2x2xf8E8M0FNU>) -> vector<2x2xf32> {
+    %ext = arith.scaling_extf %in, %scale : vector<2x2xf8E5M2>, vector<2x2xf8E8M0FNU> to vector<2x2xf32>
+    return %ext : vector<2x2xf32>
+}
+
+// CHECK-LABEL: @conversion_f4_fallback
+// CHECK:         [[CST:%.+]] = arith.constant dense<0.000000e+00> : vector<2x2xf32>
+// CHECK-NEXT:    [[SCALE_EXT:%.+]] = arith.extf %arg1 : vector<2x2xf8E8M0FNU> to vector<2x2xf32>
+// CHECK-NEXT:    [[IN_SLICE_00:%.+]] = vector.extract_strided_slice %arg0
+// CHECK-NEXT:    [[IN_SCALAR_00:%.+]] = vector.shape_cast [[IN_SLICE_00]]
+// CHECK-NEXT:    [[SCALE_SCALAR_00:%.+]] = vector.extract [[SCALE_EXT]][0, 0]
+// CHECK-NEXT:    [[PACKED_00:%.+]] = amdgpu.scaled_ext_packed [[IN_SCALAR_00]][0], [[SCALE_SCALAR_00]]
+// CHECK-NEXT:    [[OUT_SLICE_00:%.+]] = vector.extract_strided_slice [[PACKED_00]]
+// CHECK-NEXT:    [[OUT_SCALAR_00:%.+]] = vector.shape_cast [[OUT_SLICE_00]]
+// CHECK-NEXT:    [[ACC_A:%.+]] = vector.insert_strided_slice [[OUT_SCALAR_00]], [[CST]]
+// CHECK-NEXT:    [[IN_SLICE_01:%.+]] = vector.extract_strided_slice %arg0
+// CHECK-NEXT:    [[IN_SCALAR_01:%.+]] = vector.shape_cast [[IN_SLICE_01]]
+// CHECK-NEXT:    [[SCALE_SCALAR_01:%.+]] = vector.extract [[SCALE_EXT]][0, 1]
+// CHECK-NEXT:    [[PACKED_01:%.+]] = amdgpu.scaled_ext_packed [[IN_SCALAR_01]][0], [[SCALE_SCALAR_01]]
+// CHECK-NEXT:    [[OUT_SLICE_01:%.+]] = vector.extract_strided_slice [[PACKED_01]]
+// CHECK-NEXT:    [[OUT_SCALAR_01:%.+]] = vector.shape_cast [[OUT_SLICE_01]]
+// CHECK-NEXT:    [[ACC_B:%.+]] = vector.insert_strided_slice [[OUT_SCALAR_01]], [[ACC_A]]
+// CHECK-NEXT:    [[IN_SLICE_10:%.+]] = vector.extract_strided_slice %arg0
+// CHECK-NEXT:    [[IN_SCALAR_10:%.+]] = vector.shape_cast [[IN_SLICE_10]]
+// CHECK-NEXT:    [[SCALE_SCALAR_10:%.+]] = vector.extract [[SCALE_EXT]][1, 0]
+// CHECK-NEXT:    [[PACKED_10:%.+]] = amdgpu.scaled_ext_packed [[IN_SCALAR_10]][0], [[SCALE_SCALAR_10]]
+// CHECK-NEXT:    [[OUT_SLICE_10:%.+]] = vector.extract_strided_slice [[PACKED_10]]
+// CHECK-NEXT:    [[OUT_SCALAR_10:%.+]] = vector.shape_cast [[OUT_SLICE_10]]
+// CHECK-NEXT:    [[ACC_A:%.+]] = vector.insert_strided_slice [[OUT_SCALAR_10]], [[ACC_B]]
+// CHECK-NEXT:    [[IN_SLICE_11:%.+]] = vector.extract_strided_slice %arg0
+// CHECK-NEXT:    [[IN_SCALAR_11:%.+]] = vector.shape_cast [[IN_SLICE_11]]
+// CHECK-NEXT:    [[SCALE_SCALAR_11:%.+]] = vector.extract [[SCALE_EXT]][1, 1]
+// CHECK-NEXT:    [[PACKED_11:%.+]] = amdgpu.scaled_ext_packed [[IN_SCALAR_11]][0], [[SCALE_SCALAR_11]]
+// CHECK-NEXT:    [[OUT_SLICE_11:%.+]] = vector.extract_strided_slice [[PACKED_11]]
+// CHECK-NEXT:    [[OUT_SCALAR_11:%.+]] = vector.shape_cast [[OUT_SLICE_11]]
+// CHECK-NEXT:    [[ACC_B:%.+]] = vector.insert_strided_slice [[OUT_SCALAR_11]], [[ACC_A]]
+// CHECK-NEXT:    return [[ACC_B]] : vector<2x2xf32>
+func.func @conversion_f4_fallback(%in: vector<2x2xf4E2M1FN>, %scale: vector<2x2xf8E8M0FNU>) -> vector<2x2xf32> {
+    %ext = arith.scaling_extf %in, %scale : vector<2x2xf4E2M1FN>, vector<2x2xf8E8M0FNU> to vector<2x2xf32>
+    return %ext : vector<2x2xf32>
+}
+
+
+// CHECK-LABEL: @conversion_broadcast
+// CHECK:       [[CST:%.+]] = arith.constant dense<0.000000e+00> : vector<8x2x4xf32>
+// CHECK-NEXT:  [[BCAST:%.+]] = vector.broadcast %arg1 : vector<8x2xf8E8M0FNU> to vector<4x8x2xf8E8M0FNU>
+// CHECK-NEXT:  [[IN_CAST:%.+]] = vector.shape_cast %arg0 : vector<8x8xf8E5M2> to vector<8x2x4xf8E5M2>
+// CHECK-NEXT:  [[SCALE_CAST:%.+]] = vector.shape_cast [[BCAST]] : vector<4x8x2xf8E8M0FNU> to vector<8x2x4xf8E8M0FNU>
+// CHECK-NEXT:  [[SCALE_EXT:%.+]] = arith.extf [[SCALE_CAST]] : vector<8x2x4xf8E8M0FNU> to vector<8x2x4xf32>
+// CHECK-NEXT:  [[IN_SLICE_0:%.+]] = vector.extract_strided_slice [[IN_CAST]]
+// CHECK-NEXT:  [[IN_SCALAR_0:%.+]] = vector.shape_cast [[IN_SLICE_0]]
+// CHECK-NEXT:  [[SCALE_SCALAR_0:%.+]] = vector.extract [[SCALE_EXT]][0, 0, 0]
+// CHECK-NEXT:  [[PACKED_0:%.+]] = amdgpu.scaled_ext_packed [[IN_SCALAR_0]][0], [[SCALE_SCALAR_0]]
+// CHECK-NEXT:  [[OUT_SLICE_0:%.+]] = vector.extract_strided_slice [[PACKED_0]]
+// CHECK-NEXT:  [[OUT_SCALAR_0:%.+]] = vector.shape_cast [[OUT_SLICE_0]]
+// CHECK-NEXT:  [[ACC_A:%.+]] = vector.insert_strided_slice [[OUT_SCALAR_0]], [[CST]]
+// CHECK-NEXT:  [[IN_SLICE_1:%.+]] = vector.extract_strided_slice [[IN_CAST]]
+// CHECK-NEXT:  [[IN_SCALAR_1:%.+]] = vector.shape_cast [[IN_SLICE_1]]
+// CHECK-NEXT:  [[SCALE_SCALAR_1:%.+]] = vector.extract [[SCALE_EXT]][0, 0, 1]
+// CHECK-NEXT:  [[PACKED_1:%.+]] = amdgpu.scaled_ext_packed [[IN_SCALAR_1]][0], [[SCALE_SCALAR_1]]
+// CHECK-NEXT:  [[OUT_SLICE_1:%.+]] = vector.extract_strided_slice [[PACKED_1]]
+// CHECK-NEXT:  [[OUT_SCALAR_1:%.+]] = vector.shape_cast [[OUT_SLICE_1]]
+// CHECK-NEXT:  [[ACC_B:%.+]] = vector.insert_strided_slice [[OUT_SCALAR_1]], [[ACC_A]]
+// CHECK-NEXT:  [[IN_SLICE_2:%.+]] = vector.extract_strided_slice [[IN_CAST]]
+// CHECK-NEXT:  [[IN_SCALAR_2:%.+]] = vector.shape_cast [[IN_SLICE_2]]
+// CHECK-NEXT:  [[SCALE_SCALAR_2:%.+]] = vector.extract [[SCALE_EXT]][0, 0, 2]
+// CHECK-NEXT:  [[PACKED_2:%.+]] = amdgpu.scaled_ext_packed [[IN_SCALAR_2]][0], [[SCALE_SCALAR_2]]
+// CHECK-NEXT:  [[OUT_SLICE_2:%.+]] = vector.extract_strided_slice [[PACKED_2]]
+// CHECK-NEXT:  [[OUT_SCALAR_2:%.+]] = vector.shape_cast [[OUT_SLICE_2]]
+// CHECK-NEXT:  [[ACC_A:%.+]] = vector.insert_strided_slice [[OUT_SCALAR_2]], [[ACC_B]]
+// CHECK-NEXT:  [[IN_SLICE_3:%.+]] = vector.extract_strided_slice [[IN_CAST]]
+// CHECK-NEXT:  [[IN_SCALAR_3:%.+]] = vector.shape_cast [[IN_SLICE_3]]
+// CHECK-NEXT:  [[SCALE_SCALAR_3:%.+]] = vector.extract [[SCALE_EXT]][0, 0, 3]
+// CHECK-NEXT:  [[PACKED_3:%.+]] = amdgpu.scaled_ext_packed [[IN_SCALAR_3]][0], [[SCALE_SCALAR_3]]
+// CHECK-NEXT:  [[OUT_SLICE_3:%.+]] = vector.extract_strided_slice [[PACKED_3]]
+// CHECK-NEXT:  [[OUT_SCALAR_3:%.+]] = vector.shape_cast [[OUT_SLICE_3]]
+// CHECK-NEXT:  [[ACC_B:%.+]] = vector.insert_strided_slice [[OUT_SCALAR_3]], [[ACC_A]]
+// CHECK-NEXT:  [[IN_SLICE_4:%.+]] = vector.extract_strided_slice [[IN_CAST]]
+// CHECK-NEXT:  [[IN_SCALAR_4:%.+]] = vector.shape_cast [[IN_SLICE_4]]
+// CHECK-NEXT:  [[SCALE_SCALAR_4:%.+]] = vector.extract [[SCALE_EXT]][0, 1, 0]
+// CHECK-NEXT:  [[PACKED_4:%.+]] = amdgpu.scaled_ext_packed [[IN_SCALAR_4]][0], [[SCALE_SCALAR_4]]
+// CHECK-NEXT:  [[OUT_SLICE_4:%.+]] = vector.extract_strided_slice [[PACKED_4]]
+// CHECK-NEXT:  [[OUT_SCALAR_4:%.+]] = vector.shape_cast [[OUT_SLICE_4]]
+// CHECK-NEXT:  [[ACC_A:%.+]] = vector.insert_strided_slice [[OUT_SCALAR_4]], [[ACC_B]]
+// CHECK-NEXT:  [[IN_SLICE_5:%.+]] = vector.extract_strided_slice [[IN_CAST]]
+// CHECK-NEXT:  [[IN_SCALAR_5:%.+]] = vector.shape_cast [[IN_SLICE_5]]
+// CHECK-NEXT:  [[SCALE_SCALAR_5:%.+]] = vector.extract [[SCALE_EXT]][0, 1, 1]
+// CHECK-NEXT:  [[PACKED_5:%.+]] = amdgpu.scaled_ext_packed [[IN_SCALAR_5]][0], [[SCALE_SCALAR_5]]
+// CHECK-NEXT:  [[OUT_SLICE_5:%.+]] = vector.extract_strided_slice [[PACKED_5]]
+// CHECK-NEXT:  [[OUT_SCALAR_5:%.+]] = vector.shape_cast [[OUT_SLICE_5]]
+// CHECK-NEXT:  [[ACC_B:%.+]] = vector.insert_strided_slice [[OUT_SCALAR_5]], [[ACC_A]]
+// CHECK-NEXT:  [[IN_SLICE_6:%.+]] = vector.extract_strided_slice [[IN_CAST]]
+// CHECK-NEXT:  [[IN_SCALAR_6:%.+]] = vector.shape_cast [[IN_SLICE_6]]
+// CHECK-NEXT:  [[SCALE_SCALAR_6:%.+]] = vector.extract [[SCALE_EXT]][0, 1, 2]
+// CHECK-NEXT:  [[PACKED_6:%.+]] = amdgpu.scaled_ext_packed [[IN_SCALAR_6]][0], [[SCALE_SCALAR_6]]
+// CHECK-NEXT:  [[OUT_SLICE_6:%.+]] = vector.extract_strided_slice [[PACKED_6]]
+// CHECK-NEXT:  [[OUT_SCALAR_6:%.+]] = vector.shape_cast [[OUT_SLICE_6]]
+// CHECK-NEXT:  [[ACC_A:%.+]] = vector.insert_strided_slice [[OUT_SCALAR_6]], [[ACC_B]]
+// CHECK-NEXT:  [[IN_SLICE_7:%.+]] = vector.extract_strided_slice [[IN_CAST]]
+// CHECK-NEXT:  [[IN_SCALAR_7:%.+]] = vector.shape_cast [[IN_SLICE_7]]
+// CHECK-NEXT:  [[SCALE_SCALAR_7:%.+]] = vector.extract [[SCALE_EXT]][0, 1, 3]
+// CHECK-NEXT:  [[PACKED_7:%.+]] = amdgpu.scaled_ext_packed [[IN_SCALAR_7]][0], [[SCALE_SCALAR_7]]
+// CHECK-NEXT:  [[OUT_SLICE_7:%.+]] = vector.extract_strided_slice [[PACKED_7]]
+// CHECK-NEXT:  [[OUT_SCALAR_7:%.+]] = vector.shape_cast [[OUT_SLICE_7]]
+// CHECK-NEXT:  [[ACC_B:%.+]] = vector.insert_strided_slice [[OUT_SCALAR_7]], [[ACC_A]]
+// CHECK-NEXT:  [[IN_SLICE_8:%.+]] = vector.extract_strided_slice [[IN_CAST]]
+// CHECK-NEXT:  [[IN_SCALAR_8:%.+]] = vector.shape_cast [[IN_SLICE_8]]
+// CHECK-NEXT:  [[SCALE_SCALAR_8:%.+]] = vector.extract [[SCALE_EXT]][1, 0, 0]
+// CHECK-NEXT:  [[PACKED_8:%.+]] = amdgpu.scaled_ext_packed [[IN_SCALAR_8]][0], [[SCALE_SCALAR_8]]
+// CHECK-NEXT:  [[OUT_SLICE_8:%.+]] = vector.extract_strided_slice [[PACKED_8]]
+// CHECK-NEXT:  [[OUT_SCALAR_8:%.+]] = vector.shape_cast [[OUT_SLICE_8]]
+// CHECK-NEXT:  [[ACC_A:%.+]] = vector.insert_strided_slice [[OUT_SCALAR_8]], [[ACC_B]]
+// CHECK-NEXT:  [[IN_SLICE_9:%.+]] = vector.extract_strided_slice [[IN_CAST]]
+// CHECK-NEXT:  [[IN_SCALAR_9:%.+]] = vector.shape_cast [[IN_SLICE_9]]
+// CHECK-NEXT:  [[SCALE_SCALAR_9:%.+]] = vector.extract [[SCALE_EXT]][1, 0, 1]
+// CHECK-NEXT:  [[PACKED_9:%.+]] = amdgpu.scaled_ext_packed [[IN_SCALAR_9]][0], [[SCALE_SCALAR_9]]
+// CHECK-NEXT:  [[OUT_SLICE_9:%.+]]...
[truncated]

``````````

</details>


https://github.com/llvm/llvm-project/pull/146372


More information about the Mlir-commits mailing list