[Mlir-commits] [mlir] [mlir][amdgpu] Add conversion from arith.scaling_extf to amdgpu (PR #146372)
Tim Gymnich
llvmlistbot at llvm.org
Mon Jun 30 08:41:45 PDT 2025
https://github.com/tgymnich updated https://github.com/llvm/llvm-project/pull/146372
>From 8e6274d30c9b643013efd5674293644702269ebb Mon Sep 17 00:00:00 2001
From: Tim Gymnich <tim at gymni.ch>
Date: Thu, 26 Jun 2025 09:54:55 +0000
Subject: [PATCH 1/3] [mlir][amdgpu] Add conversion for arith.scaling_extf to
amdgpu
---
.../ArithToAMDGPU/ArithToAMDGPU.cpp | 165 ++++++++++++++++++
1 file changed, 165 insertions(+)
diff --git a/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp b/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp
index 3596b3235a631..22cd4703c6005 100644
--- a/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp
+++ b/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp
@@ -15,11 +15,17 @@
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/LLVMIR/ROCDLDialect.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
+#include "mlir/Dialect/Vector/Utils/VectorUtils.h"
#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/Matchers.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+#include "mlir/Dialect/Utils/IndexingUtils.h"
+#include "llvm/IR/PatternMatch.h"
+#include "llvm/Support/LogicalResult.h"
namespace mlir {
#define GEN_PASS_DEF_ARITHTOAMDGPUCONVERSIONPASS
@@ -32,6 +38,7 @@ using namespace mlir::amdgpu;
namespace {
// Define commonly used chipsets versions for convenience.
constexpr Chipset kGfx942 = Chipset(9, 4, 2);
+constexpr Chipset kGfx950 = Chipset(9, 5, 0);
struct ArithToAMDGPUConversionPass final
: impl::ArithToAMDGPUConversionPassBase<ArithToAMDGPUConversionPass> {
@@ -73,6 +80,28 @@ struct TruncfToFloat16RewritePattern final
PatternRewriter &rewriter) const override;
};
+struct ScalingExtFRewritePattern final : OpRewritePattern<arith::ScalingExtFOp> {
+ using OpRewritePattern::OpRewritePattern;
+
+ Chipset chipset;
+ ScalingExtFRewritePattern(MLIRContext *ctx, Chipset chipset)
+ : OpRewritePattern::OpRewritePattern(ctx), chipset(chipset) {}
+
+ LogicalResult matchAndRewrite(arith::ScalingExtFOp op,
+ PatternRewriter &rewriter) const override;
+};
+
+struct ScalingTruncFRewritePattern final : OpRewritePattern<arith::ScalingTruncFOp> {
+ using OpRewritePattern::OpRewritePattern;
+
+ Chipset chipset;
+ ScalingTruncFRewritePattern(MLIRContext *ctx, Chipset chipset)
+ : OpRewritePattern::OpRewritePattern(ctx), chipset(chipset) {}
+
+ LogicalResult matchAndRewrite(arith::ScalingTruncFOp op,
+ PatternRewriter &rewriter) const override;
+};
+
} // end namespace
static bool isSupportedF8(Type elementType, Chipset chipset) {
@@ -395,6 +424,137 @@ LogicalResult TruncfToFloat16RewritePattern::matchAndRewrite(
return success();
}
+static Value getOriginalVectorValue(Value value) {
+ Value current = value;
+ while (Operation *definingOp = current.getDefiningOp()) {
+ bool skipOp = llvm::TypeSwitch<Operation *, bool>(definingOp)
+ .Case<vector::ShapeCastOp>(
+ [&current](auto op) {
+ current = op.getSource();
+ return true;
+ })
+ .Case<vector::BroadcastOp>(
+ [&current](auto op) {
+ current = op.getSource();
+ return false;
+ })
+ .Case<vector::SplatOp>(
+ [&current](auto op) {
+ current = op.getInput();
+ return false;
+ })
+ .Default([](Operation *) {
+ return false;
+ });
+
+ if (!skipOp) {
+ break;
+ }
+ }
+ return current;
+}
+
+LogicalResult
+ScalingExtFRewritePattern::matchAndRewrite(arith::ScalingExtFOp op,
+ PatternRewriter &rewriter) const {
+ Location loc = op.getLoc();
+ constexpr const int64_t opWidth = 2;
+
+ Value in = op.getIn();
+ Value scale = op.getScale();
+ Value out = op.getOut();
+
+ Type f32 = rewriter.getF32Type();
+ Type inType = getElementTypeOrSelf(in);
+ Type scaleType = getElementTypeOrSelf(scale);
+ Type outType = getElementTypeOrSelf(out);
+ VectorType scaleVecType = dyn_cast<VectorType>(scale.getType());
+ VectorType inVecType = dyn_cast<VectorType>(in.getType());
+ VectorType outVecType = dyn_cast<VectorType>(out.getType());
+
+ if (outVecType && outVecType.isScalable())
+ return failure();
+
+ Type scaleF32Type = scaleVecType ? VectorType::get(scaleVecType.getShape(), f32) : f32;
+ if (scaleType.getIntOrFloatBitWidth() < 32)
+ scale = rewriter.create<arith::ExtFOp>(loc, scaleF32Type, scale);
+ else if (scaleType.getIntOrFloatBitWidth() > 32)
+ scale = rewriter.create<arith::TruncFOp>(loc, scaleF32Type, scale);
+
+ VectorType extScaleResultType = VectorType::get(opWidth, outType);
+
+ if (!outVecType) {
+ Value inCast = rewriter.create<vector::SplatOp>(loc, VectorType::get(1, inType), in);
+ Value scaleExt = rewriter.create<amdgpu::ScaledExtPackedOp>(loc, extScaleResultType, inCast, scale, 0);
+ scaleExt = rewriter.replaceOpWithNewOp<vector::ExtractOp>(op, scaleExt, 0);
+ return success();
+ }
+
+ Value origScale = getOriginalVectorValue(scale);
+ Type origScaleType = origScale.getType();
+ VectorType origScaleVecType = isa<VectorType>(origScaleType) ? cast<VectorType>(origScaleType) : VectorType::get(1, origScaleType);
+
+ ArrayRef<int64_t> originalScaleShape = origScaleVecType.getShape();
+ ArrayRef<int64_t> inShape = inVecType.getShape();
+
+ SmallVector<int64_t> paddedScaleShape(originalScaleShape);
+ paddedScaleShape.insert(paddedScaleShape.end(), inShape.size() - originalScaleShape.size(),
+ 1);
+
+ auto ratio = computeShapeRatio(inShape, paddedScaleShape);
+ if (!ratio)
+ return failure();
+
+ const int64_t blockSize = computeProduct(*ratio);
+
+ Value zero = rewriter.create<arith::ConstantOp>(
+ loc, outType, rewriter.getFloatAttr(outType, 0.0));
+ Value result = rewriter.createOrFold<vector::SplatOp>(loc, outVecType, zero);
+
+ for (SmallVector<int64_t> offsets : StaticTileOffsetRange(inShape, *ratio)) {
+ SmallVector<int64_t> strides(offsets.size(), 1);
+ Value block = rewriter.create<vector::ExtractStridedSliceOp>(
+ loc, in, offsets, *ratio, strides);
+ VectorType block1DType = VectorType::get(blockSize, inType);
+ Value block1D =
+ rewriter.create<vector::ShapeCastOp>(loc, block1DType, block);
+ Value uniformScale =
+ rewriter.create<vector::ExtractOp>(loc, scale, offsets);
+
+ VectorType blockResultType = VectorType::get(blockSize, outType);
+ Value blockResult =
+ rewriter.createOrFold<vector::SplatOp>(loc, blockResultType, zero);
+
+ for (int64_t i = 0, sliceWidth = opWidth - blockSize % opWidth;
+ i < blockSize;
+ i += sliceWidth, sliceWidth = opWidth - blockSize % opWidth) {
+ Value slice = rewriter.create<vector::ExtractStridedSliceOp>(
+ loc, block1D, i, sliceWidth, 1);
+ Value scaleExt = rewriter.create<amdgpu::ScaledExtPackedOp>(
+ loc, extScaleResultType, slice, uniformScale, 0);
+ if (sliceWidth != opWidth)
+ scaleExt = rewriter.create<vector::ExtractStridedSliceOp>(
+ loc, scaleExt, 0, sliceWidth, 1);
+ blockResult = rewriter.create<vector::InsertStridedSliceOp>(
+ loc, scaleExt, blockResult, i, 1);
+ }
+
+ VectorType resultType = VectorType::get(*ratio, outType);
+ Value cast = rewriter.create<vector::ShapeCastOp>(loc, resultType,
+ blockResult);
+ result = rewriter.create<vector::InsertStridedSliceOp>(
+ loc, cast, result, offsets, strides);
+ }
+
+ rewriter.replaceOp(op, result);
+
+ return success();
+}
+
+LogicalResult ScalingTruncFRewritePattern::matchAndRewrite(arith::ScalingTruncFOp op, PatternRewriter &rewriter) const {
+ return success();
+}
+
void mlir::arith::populateArithToAMDGPUConversionPatterns(
RewritePatternSet &patterns, bool convertFP8Arithmetic,
bool saturateFP8Truncf, bool allowPackedF16Rtz, Chipset chipset) {
@@ -406,6 +566,11 @@ void mlir::arith::populateArithToAMDGPUConversionPatterns(
}
if (allowPackedF16Rtz)
patterns.add<TruncfToFloat16RewritePattern>(patterns.getContext());
+
+ if (chipset >= kGfx950) {
+ patterns.add<ScalingExtFRewritePattern>(patterns.getContext(), chipset);
+ patterns.add<ScalingTruncFRewritePattern>(patterns.getContext(), chipset);
+ }
}
void ArithToAMDGPUConversionPass::runOnOperation() {
>From 26c837729a19a1f9e3cd5e22c5db4f35a77f73cb Mon Sep 17 00:00:00 2001
From: Tim Gymnich <tim at gymni.ch>
Date: Mon, 30 Jun 2025 12:18:36 +0000
Subject: [PATCH 2/3] add tests
---
.../Conversion/ArithToAMDGPU/scale_ext.mlir | 553 ++++++++++++++++++
1 file changed, 553 insertions(+)
create mode 100644 mlir/test/Conversion/ArithToAMDGPU/scale_ext.mlir
diff --git a/mlir/test/Conversion/ArithToAMDGPU/scale_ext.mlir b/mlir/test/Conversion/ArithToAMDGPU/scale_ext.mlir
new file mode 100644
index 0000000000000..1669926ae48cc
--- /dev/null
+++ b/mlir/test/Conversion/ArithToAMDGPU/scale_ext.mlir
@@ -0,0 +1,553 @@
+// RUN: mlir-opt --split-input-file %s -convert-arith-to-amdgpu="chipset=gfx950" | FileCheck %s
+
+// CHECK-LABEL: @conversion_f8_fallback
+// CHECK: [[CST:%.+]] = arith.constant dense<0.000000e+00> : vector<2x2xf32>
+// CHECK-NEXT: [[SCALE_EXT:%.+]] = arith.extf %arg1 : vector<2x2xf8E8M0FNU> to vector<2x2xf32>
+// CHECK-NEXT: [[IN_SLICE_00:%.+]] = vector.extract_strided_slice %arg0
+// CHECK-NEXT: [[IN_SCALAR_00:%.+]] = vector.shape_cast [[IN_SLICE_00]]
+// CHECK-NEXT: [[SCALE_SCALAR_00:%.+]] = vector.extract [[SCALE_EXT]][0, 0]
+// CHECK-NEXT: [[PACKED_00:%.+]] = amdgpu.scaled_ext_packed [[IN_SCALAR_00]][0], [[SCALE_SCALAR_00]]
+// CHECK-NEXT: [[OUT_SLICE_00:%.+]] = vector.extract_strided_slice [[PACKED_00]]
+// CHECK-NEXT: [[OUT_SCALAR_00:%.+]] = vector.shape_cast [[OUT_SLICE_00]]
+// CHECK-NEXT: [[ACC_A:%.+]] = vector.insert_strided_slice [[OUT_SCALAR_00]], [[CST]]
+// CHECK-NEXT: [[IN_SLICE_01:%.+]] = vector.extract_strided_slice %arg0
+// CHECK-NEXT: [[IN_SCALAR_01:%.+]] = vector.shape_cast [[IN_SLICE_01]]
+// CHECK-NEXT: [[SCALE_SCALAR_01:%.+]] = vector.extract [[SCALE_EXT]][0, 1]
+// CHECK-NEXT: [[PACKED_01:%.+]] = amdgpu.scaled_ext_packed [[IN_SCALAR_01]][0], [[SCALE_SCALAR_01]]
+// CHECK-NEXT: [[OUT_SLICE_01:%.+]] = vector.extract_strided_slice [[PACKED_01]]
+// CHECK-NEXT: [[OUT_SCALAR_01:%.+]] = vector.shape_cast [[OUT_SLICE_01]]
+// CHECK-NEXT: [[ACC_B:%.+]] = vector.insert_strided_slice [[OUT_SCALAR_01]], [[ACC_A]]
+// CHECK-NEXT: [[IN_SLICE_10:%.+]] = vector.extract_strided_slice %arg0
+// CHECK-NEXT: [[IN_SCALAR_10:%.+]] = vector.shape_cast [[IN_SLICE_10]]
+// CHECK-NEXT: [[SCALE_SCALAR_10:%.+]] = vector.extract [[SCALE_EXT]][1, 0]
+// CHECK-NEXT: [[PACKED_10:%.+]] = amdgpu.scaled_ext_packed [[IN_SCALAR_10]][0], [[SCALE_SCALAR_10]]
+// CHECK-NEXT: [[OUT_SLICE_10:%.+]] = vector.extract_strided_slice [[PACKED_10]]
+// CHECK-NEXT: [[OUT_SCALAR_10:%.+]] = vector.shape_cast [[OUT_SLICE_10]]
+// CHECK-NEXT: [[ACC_A:%.+]] = vector.insert_strided_slice [[OUT_SCALAR_10]], [[ACC_B]]
+// CHECK-NEXT: [[IN_SLICE_11:%.+]] = vector.extract_strided_slice %arg0
+// CHECK-NEXT: [[IN_SCALAR_11:%.+]] = vector.shape_cast [[IN_SLICE_11]]
+// CHECK-NEXT: [[SCALE_SCALAR_11:%.+]] = vector.extract [[SCALE_EXT]][1, 1]
+// CHECK-NEXT: [[PACKED_11:%.+]] = amdgpu.scaled_ext_packed [[IN_SCALAR_11]][0], [[SCALE_SCALAR_11]]
+// CHECK-NEXT: [[OUT_SLICE_11:%.+]] = vector.extract_strided_slice [[PACKED_11]]
+// CHECK-NEXT: [[OUT_SCALAR_11:%.+]] = vector.shape_cast [[OUT_SLICE_11]]
+// CHECK-NEXT: [[ACC_B:%.+]] = vector.insert_strided_slice [[OUT_SCALAR_11]], [[ACC_A]]
+// CHECK-NEXT: return [[ACC_B]] : vector<2x2xf32>
+func.func @conversion_f8_fallback(%in: vector<2x2xf8E5M2>, %scale: vector<2x2xf8E8M0FNU>) -> vector<2x2xf32> {
+ %ext = arith.scaling_extf %in, %scale : vector<2x2xf8E5M2>, vector<2x2xf8E8M0FNU> to vector<2x2xf32>
+ return %ext : vector<2x2xf32>
+}
+
+// CHECK-LABEL: @conversion_f4_fallback
+// CHECK: [[CST:%.+]] = arith.constant dense<0.000000e+00> : vector<2x2xf32>
+// CHECK-NEXT: [[SCALE_EXT:%.+]] = arith.extf %arg1 : vector<2x2xf8E8M0FNU> to vector<2x2xf32>
+// CHECK-NEXT: [[IN_SLICE_00:%.+]] = vector.extract_strided_slice %arg0
+// CHECK-NEXT: [[IN_SCALAR_00:%.+]] = vector.shape_cast [[IN_SLICE_00]]
+// CHECK-NEXT: [[SCALE_SCALAR_00:%.+]] = vector.extract [[SCALE_EXT]][0, 0]
+// CHECK-NEXT: [[PACKED_00:%.+]] = amdgpu.scaled_ext_packed [[IN_SCALAR_00]][0], [[SCALE_SCALAR_00]]
+// CHECK-NEXT: [[OUT_SLICE_00:%.+]] = vector.extract_strided_slice [[PACKED_00]]
+// CHECK-NEXT: [[OUT_SCALAR_00:%.+]] = vector.shape_cast [[OUT_SLICE_00]]
+// CHECK-NEXT: [[ACC_A:%.+]] = vector.insert_strided_slice [[OUT_SCALAR_00]], [[CST]]
+// CHECK-NEXT: [[IN_SLICE_01:%.+]] = vector.extract_strided_slice %arg0
+// CHECK-NEXT: [[IN_SCALAR_01:%.+]] = vector.shape_cast [[IN_SLICE_01]]
+// CHECK-NEXT: [[SCALE_SCALAR_01:%.+]] = vector.extract [[SCALE_EXT]][0, 1]
+// CHECK-NEXT: [[PACKED_01:%.+]] = amdgpu.scaled_ext_packed [[IN_SCALAR_01]][0], [[SCALE_SCALAR_01]]
+// CHECK-NEXT: [[OUT_SLICE_01:%.+]] = vector.extract_strided_slice [[PACKED_01]]
+// CHECK-NEXT: [[OUT_SCALAR_01:%.+]] = vector.shape_cast [[OUT_SLICE_01]]
+// CHECK-NEXT: [[ACC_B:%.+]] = vector.insert_strided_slice [[OUT_SCALAR_01]], [[ACC_A]]
+// CHECK-NEXT: [[IN_SLICE_10:%.+]] = vector.extract_strided_slice %arg0
+// CHECK-NEXT: [[IN_SCALAR_10:%.+]] = vector.shape_cast [[IN_SLICE_10]]
+// CHECK-NEXT: [[SCALE_SCALAR_10:%.+]] = vector.extract [[SCALE_EXT]][1, 0]
+// CHECK-NEXT: [[PACKED_10:%.+]] = amdgpu.scaled_ext_packed [[IN_SCALAR_10]][0], [[SCALE_SCALAR_10]]
+// CHECK-NEXT: [[OUT_SLICE_10:%.+]] = vector.extract_strided_slice [[PACKED_10]]
+// CHECK-NEXT: [[OUT_SCALAR_10:%.+]] = vector.shape_cast [[OUT_SLICE_10]]
+// CHECK-NEXT: [[ACC_A:%.+]] = vector.insert_strided_slice [[OUT_SCALAR_10]], [[ACC_B]]
+// CHECK-NEXT: [[IN_SLICE_11:%.+]] = vector.extract_strided_slice %arg0
+// CHECK-NEXT: [[IN_SCALAR_11:%.+]] = vector.shape_cast [[IN_SLICE_11]]
+// CHECK-NEXT: [[SCALE_SCALAR_11:%.+]] = vector.extract [[SCALE_EXT]][1, 1]
+// CHECK-NEXT: [[PACKED_11:%.+]] = amdgpu.scaled_ext_packed [[IN_SCALAR_11]][0], [[SCALE_SCALAR_11]]
+// CHECK-NEXT: [[OUT_SLICE_11:%.+]] = vector.extract_strided_slice [[PACKED_11]]
+// CHECK-NEXT: [[OUT_SCALAR_11:%.+]] = vector.shape_cast [[OUT_SLICE_11]]
+// CHECK-NEXT: [[ACC_B:%.+]] = vector.insert_strided_slice [[OUT_SCALAR_11]], [[ACC_A]]
+// CHECK-NEXT: return [[ACC_B]] : vector<2x2xf32>
+func.func @conversion_f4_fallback(%in: vector<2x2xf4E2M1FN>, %scale: vector<2x2xf8E8M0FNU>) -> vector<2x2xf32> {
+ %ext = arith.scaling_extf %in, %scale : vector<2x2xf4E2M1FN>, vector<2x2xf8E8M0FNU> to vector<2x2xf32>
+ return %ext : vector<2x2xf32>
+}
+
+
+// CHECK-LABEL: @conversion_broadcast
+// CHECK: [[CST:%.+]] = arith.constant dense<0.000000e+00> : vector<8x2x4xf32>
+// CHECK-NEXT: [[BCAST:%.+]] = vector.broadcast %arg1 : vector<8x2xf8E8M0FNU> to vector<4x8x2xf8E8M0FNU>
+// CHECK-NEXT: [[IN_CAST:%.+]] = vector.shape_cast %arg0 : vector<8x8xf8E5M2> to vector<8x2x4xf8E5M2>
+// CHECK-NEXT: [[SCALE_CAST:%.+]] = vector.shape_cast [[BCAST]] : vector<4x8x2xf8E8M0FNU> to vector<8x2x4xf8E8M0FNU>
+// CHECK-NEXT: [[SCALE_EXT:%.+]] = arith.extf [[SCALE_CAST]] : vector<8x2x4xf8E8M0FNU> to vector<8x2x4xf32>
+// CHECK-NEXT: [[IN_SLICE_0:%.+]] = vector.extract_strided_slice [[IN_CAST]]
+// CHECK-NEXT: [[IN_SCALAR_0:%.+]] = vector.shape_cast [[IN_SLICE_0]]
+// CHECK-NEXT: [[SCALE_SCALAR_0:%.+]] = vector.extract [[SCALE_EXT]][0, 0, 0]
+// CHECK-NEXT: [[PACKED_0:%.+]] = amdgpu.scaled_ext_packed [[IN_SCALAR_0]][0], [[SCALE_SCALAR_0]]
+// CHECK-NEXT: [[OUT_SLICE_0:%.+]] = vector.extract_strided_slice [[PACKED_0]]
+// CHECK-NEXT: [[OUT_SCALAR_0:%.+]] = vector.shape_cast [[OUT_SLICE_0]]
+// CHECK-NEXT: [[ACC_A:%.+]] = vector.insert_strided_slice [[OUT_SCALAR_0]], [[CST]]
+// CHECK-NEXT: [[IN_SLICE_1:%.+]] = vector.extract_strided_slice [[IN_CAST]]
+// CHECK-NEXT: [[IN_SCALAR_1:%.+]] = vector.shape_cast [[IN_SLICE_1]]
+// CHECK-NEXT: [[SCALE_SCALAR_1:%.+]] = vector.extract [[SCALE_EXT]][0, 0, 1]
+// CHECK-NEXT: [[PACKED_1:%.+]] = amdgpu.scaled_ext_packed [[IN_SCALAR_1]][0], [[SCALE_SCALAR_1]]
+// CHECK-NEXT: [[OUT_SLICE_1:%.+]] = vector.extract_strided_slice [[PACKED_1]]
+// CHECK-NEXT: [[OUT_SCALAR_1:%.+]] = vector.shape_cast [[OUT_SLICE_1]]
+// CHECK-NEXT: [[ACC_B:%.+]] = vector.insert_strided_slice [[OUT_SCALAR_1]], [[ACC_A]]
+// CHECK-NEXT: [[IN_SLICE_2:%.+]] = vector.extract_strided_slice [[IN_CAST]]
+// CHECK-NEXT: [[IN_SCALAR_2:%.+]] = vector.shape_cast [[IN_SLICE_2]]
+// CHECK-NEXT: [[SCALE_SCALAR_2:%.+]] = vector.extract [[SCALE_EXT]][0, 0, 2]
+// CHECK-NEXT: [[PACKED_2:%.+]] = amdgpu.scaled_ext_packed [[IN_SCALAR_2]][0], [[SCALE_SCALAR_2]]
+// CHECK-NEXT: [[OUT_SLICE_2:%.+]] = vector.extract_strided_slice [[PACKED_2]]
+// CHECK-NEXT: [[OUT_SCALAR_2:%.+]] = vector.shape_cast [[OUT_SLICE_2]]
+// CHECK-NEXT: [[ACC_A:%.+]] = vector.insert_strided_slice [[OUT_SCALAR_2]], [[ACC_B]]
+// CHECK-NEXT: [[IN_SLICE_3:%.+]] = vector.extract_strided_slice [[IN_CAST]]
+// CHECK-NEXT: [[IN_SCALAR_3:%.+]] = vector.shape_cast [[IN_SLICE_3]]
+// CHECK-NEXT: [[SCALE_SCALAR_3:%.+]] = vector.extract [[SCALE_EXT]][0, 0, 3]
+// CHECK-NEXT: [[PACKED_3:%.+]] = amdgpu.scaled_ext_packed [[IN_SCALAR_3]][0], [[SCALE_SCALAR_3]]
+// CHECK-NEXT: [[OUT_SLICE_3:%.+]] = vector.extract_strided_slice [[PACKED_3]]
+// CHECK-NEXT: [[OUT_SCALAR_3:%.+]] = vector.shape_cast [[OUT_SLICE_3]]
+// CHECK-NEXT: [[ACC_B:%.+]] = vector.insert_strided_slice [[OUT_SCALAR_3]], [[ACC_A]]
+// CHECK-NEXT: [[IN_SLICE_4:%.+]] = vector.extract_strided_slice [[IN_CAST]]
+// CHECK-NEXT: [[IN_SCALAR_4:%.+]] = vector.shape_cast [[IN_SLICE_4]]
+// CHECK-NEXT: [[SCALE_SCALAR_4:%.+]] = vector.extract [[SCALE_EXT]][0, 1, 0]
+// CHECK-NEXT: [[PACKED_4:%.+]] = amdgpu.scaled_ext_packed [[IN_SCALAR_4]][0], [[SCALE_SCALAR_4]]
+// CHECK-NEXT: [[OUT_SLICE_4:%.+]] = vector.extract_strided_slice [[PACKED_4]]
+// CHECK-NEXT: [[OUT_SCALAR_4:%.+]] = vector.shape_cast [[OUT_SLICE_4]]
+// CHECK-NEXT: [[ACC_A:%.+]] = vector.insert_strided_slice [[OUT_SCALAR_4]], [[ACC_B]]
+// CHECK-NEXT: [[IN_SLICE_5:%.+]] = vector.extract_strided_slice [[IN_CAST]]
+// CHECK-NEXT: [[IN_SCALAR_5:%.+]] = vector.shape_cast [[IN_SLICE_5]]
+// CHECK-NEXT: [[SCALE_SCALAR_5:%.+]] = vector.extract [[SCALE_EXT]][0, 1, 1]
+// CHECK-NEXT: [[PACKED_5:%.+]] = amdgpu.scaled_ext_packed [[IN_SCALAR_5]][0], [[SCALE_SCALAR_5]]
+// CHECK-NEXT: [[OUT_SLICE_5:%.+]] = vector.extract_strided_slice [[PACKED_5]]
+// CHECK-NEXT: [[OUT_SCALAR_5:%.+]] = vector.shape_cast [[OUT_SLICE_5]]
+// CHECK-NEXT: [[ACC_B:%.+]] = vector.insert_strided_slice [[OUT_SCALAR_5]], [[ACC_A]]
+// CHECK-NEXT: [[IN_SLICE_6:%.+]] = vector.extract_strided_slice [[IN_CAST]]
+// CHECK-NEXT: [[IN_SCALAR_6:%.+]] = vector.shape_cast [[IN_SLICE_6]]
+// CHECK-NEXT: [[SCALE_SCALAR_6:%.+]] = vector.extract [[SCALE_EXT]][0, 1, 2]
+// CHECK-NEXT: [[PACKED_6:%.+]] = amdgpu.scaled_ext_packed [[IN_SCALAR_6]][0], [[SCALE_SCALAR_6]]
+// CHECK-NEXT: [[OUT_SLICE_6:%.+]] = vector.extract_strided_slice [[PACKED_6]]
+// CHECK-NEXT: [[OUT_SCALAR_6:%.+]] = vector.shape_cast [[OUT_SLICE_6]]
+// CHECK-NEXT: [[ACC_A:%.+]] = vector.insert_strided_slice [[OUT_SCALAR_6]], [[ACC_B]]
+// CHECK-NEXT: [[IN_SLICE_7:%.+]] = vector.extract_strided_slice [[IN_CAST]]
+// CHECK-NEXT: [[IN_SCALAR_7:%.+]] = vector.shape_cast [[IN_SLICE_7]]
+// CHECK-NEXT: [[SCALE_SCALAR_7:%.+]] = vector.extract [[SCALE_EXT]][0, 1, 3]
+// CHECK-NEXT: [[PACKED_7:%.+]] = amdgpu.scaled_ext_packed [[IN_SCALAR_7]][0], [[SCALE_SCALAR_7]]
+// CHECK-NEXT: [[OUT_SLICE_7:%.+]] = vector.extract_strided_slice [[PACKED_7]]
+// CHECK-NEXT: [[OUT_SCALAR_7:%.+]] = vector.shape_cast [[OUT_SLICE_7]]
+// CHECK-NEXT: [[ACC_B:%.+]] = vector.insert_strided_slice [[OUT_SCALAR_7]], [[ACC_A]]
+// CHECK-NEXT: [[IN_SLICE_8:%.+]] = vector.extract_strided_slice [[IN_CAST]]
+// CHECK-NEXT: [[IN_SCALAR_8:%.+]] = vector.shape_cast [[IN_SLICE_8]]
+// CHECK-NEXT: [[SCALE_SCALAR_8:%.+]] = vector.extract [[SCALE_EXT]][1, 0, 0]
+// CHECK-NEXT: [[PACKED_8:%.+]] = amdgpu.scaled_ext_packed [[IN_SCALAR_8]][0], [[SCALE_SCALAR_8]]
+// CHECK-NEXT: [[OUT_SLICE_8:%.+]] = vector.extract_strided_slice [[PACKED_8]]
+// CHECK-NEXT: [[OUT_SCALAR_8:%.+]] = vector.shape_cast [[OUT_SLICE_8]]
+// CHECK-NEXT: [[ACC_A:%.+]] = vector.insert_strided_slice [[OUT_SCALAR_8]], [[ACC_B]]
+// CHECK-NEXT: [[IN_SLICE_9:%.+]] = vector.extract_strided_slice [[IN_CAST]]
+// CHECK-NEXT: [[IN_SCALAR_9:%.+]] = vector.shape_cast [[IN_SLICE_9]]
+// CHECK-NEXT: [[SCALE_SCALAR_9:%.+]] = vector.extract [[SCALE_EXT]][1, 0, 1]
+// CHECK-NEXT: [[PACKED_9:%.+]] = amdgpu.scaled_ext_packed [[IN_SCALAR_9]][0], [[SCALE_SCALAR_9]]
+// CHECK-NEXT: [[OUT_SLICE_9:%.+]] = vector.extract_strided_slice [[PACKED_9]]
+// CHECK-NEXT: [[OUT_SCALAR_9:%.+]] = vector.shape_cast [[OUT_SLICE_9]]
+// CHECK-NEXT: [[ACC_B:%.+]] = vector.insert_strided_slice [[OUT_SCALAR_9]], [[ACC_A]]
+// CHECK-NEXT: [[IN_SLICE_10:%.+]] = vector.extract_strided_slice [[IN_CAST]]
+// CHECK-NEXT: [[IN_SCALAR_10:%.+]] = vector.shape_cast [[IN_SLICE_10]]
+// CHECK-NEXT: [[SCALE_SCALAR_10:%.+]] = vector.extract [[SCALE_EXT]][1, 0, 2]
+// CHECK-NEXT: [[PACKED_10:%.+]] = amdgpu.scaled_ext_packed [[IN_SCALAR_10]][0], [[SCALE_SCALAR_10]]
+// CHECK-NEXT: [[OUT_SLICE_10:%.+]] = vector.extract_strided_slice [[PACKED_10]]
+// CHECK-NEXT: [[OUT_SCALAR_10:%.+]] = vector.shape_cast [[OUT_SLICE_10]]
+// CHECK-NEXT: [[ACC_A:%.+]] = vector.insert_strided_slice [[OUT_SCALAR_10]], [[ACC_B]]
+// CHECK-NEXT: [[IN_SLICE_11:%.+]] = vector.extract_strided_slice [[IN_CAST]]
+// CHECK-NEXT: [[IN_SCALAR_11:%.+]] = vector.shape_cast [[IN_SLICE_11]]
+// CHECK-NEXT: [[SCALE_SCALAR_11:%.+]] = vector.extract [[SCALE_EXT]][1, 0, 3]
+// CHECK-NEXT: [[PACKED_11:%.+]] = amdgpu.scaled_ext_packed [[IN_SCALAR_11]][0], [[SCALE_SCALAR_11]]
+// CHECK-NEXT: [[OUT_SLICE_11:%.+]] = vector.extract_strided_slice [[PACKED_11]]
+// CHECK-NEXT: [[OUT_SCALAR_11:%.+]] = vector.shape_cast [[OUT_SLICE_11]]
+// CHECK-NEXT: [[ACC_B:%.+]] = vector.insert_strided_slice [[OUT_SCALAR_11]], [[ACC_A]]
+// CHECK-NEXT: [[IN_SLICE_12:%.+]] = vector.extract_strided_slice [[IN_CAST]]
+// CHECK-NEXT: [[IN_SCALAR_12:%.+]] = vector.shape_cast [[IN_SLICE_12]]
+// CHECK-NEXT: [[SCALE_SCALAR_12:%.+]] = vector.extract [[SCALE_EXT]][1, 1, 0]
+// CHECK-NEXT: [[PACKED_12:%.+]] = amdgpu.scaled_ext_packed [[IN_SCALAR_12]][0], [[SCALE_SCALAR_12]]
+// CHECK-NEXT: [[OUT_SLICE_12:%.+]] = vector.extract_strided_slice [[PACKED_12]]
+// CHECK-NEXT: [[OUT_SCALAR_12:%.+]] = vector.shape_cast [[OUT_SLICE_12]]
+// CHECK-NEXT: [[ACC_A:%.+]] = vector.insert_strided_slice [[OUT_SCALAR_12]], [[ACC_B]]
+// CHECK-NEXT: [[IN_SLICE_13:%.+]] = vector.extract_strided_slice [[IN_CAST]]
+// CHECK-NEXT: [[IN_SCALAR_13:%.+]] = vector.shape_cast [[IN_SLICE_13]]
+// CHECK-NEXT: [[SCALE_SCALAR_13:%.+]] = vector.extract [[SCALE_EXT]][1, 1, 1]
+// CHECK-NEXT: [[PACKED_13:%.+]] = amdgpu.scaled_ext_packed [[IN_SCALAR_13]][0], [[SCALE_SCALAR_13]]
+// CHECK-NEXT: [[OUT_SLICE_13:%.+]] = vector.extract_strided_slice [[PACKED_13]]
+// CHECK-NEXT: [[OUT_SCALAR_13:%.+]] = vector.shape_cast [[OUT_SLICE_13]]
+// CHECK-NEXT: [[ACC_B:%.+]] = vector.insert_strided_slice [[OUT_SCALAR_13]], [[ACC_A]]
+// CHECK-NEXT: [[IN_SLICE_14:%.+]] = vector.extract_strided_slice [[IN_CAST]]
+// CHECK-NEXT: [[IN_SCALAR_14:%.+]] = vector.shape_cast [[IN_SLICE_14]]
+// CHECK-NEXT: [[SCALE_SCALAR_14:%.+]] = vector.extract [[SCALE_EXT]][1, 1, 2]
+// CHECK-NEXT: [[PACKED_14:%.+]] = amdgpu.scaled_ext_packed [[IN_SCALAR_14]][0], [[SCALE_SCALAR_14]]
+// CHECK-NEXT: [[OUT_SLICE_14:%.+]] = vector.extract_strided_slice [[PACKED_14]]
+// CHECK-NEXT: [[OUT_SCALAR_14:%.+]] = vector.shape_cast [[OUT_SLICE_14]]
+// CHECK-NEXT: [[ACC_A:%.+]] = vector.insert_strided_slice [[OUT_SCALAR_14]], [[ACC_B]]
+// CHECK-NEXT: [[IN_SLICE_15:%.+]] = vector.extract_strided_slice [[IN_CAST]]
+// CHECK-NEXT: [[IN_SCALAR_15:%.+]] = vector.shape_cast [[IN_SLICE_15]]
+// CHECK-NEXT: [[SCALE_SCALAR_15:%.+]] = vector.extract [[SCALE_EXT]][1, 1, 3]
+// CHECK-NEXT: [[PACKED_15:%.+]] = amdgpu.scaled_ext_packed [[IN_SCALAR_15]][0], [[SCALE_SCALAR_15]]
+// CHECK-NEXT: [[OUT_SLICE_15:%.+]] = vector.extract_strided_slice [[PACKED_15]]
+// CHECK-NEXT: [[OUT_SCALAR_15:%.+]] = vector.shape_cast [[OUT_SLICE_15]]
+// CHECK-NEXT: [[ACC_B:%.+]] = vector.insert_strided_slice [[OUT_SCALAR_15]], [[ACC_A]]
+// CHECK-NEXT: [[IN_SLICE_16:%.+]] = vector.extract_strided_slice [[IN_CAST]]
+// CHECK-NEXT: [[IN_SCALAR_16:%.+]] = vector.shape_cast [[IN_SLICE_16]]
+// CHECK-NEXT: [[SCALE_SCALAR_16:%.+]] = vector.extract [[SCALE_EXT]][2, 0, 0]
+// CHECK-NEXT: [[PACKED_16:%.+]] = amdgpu.scaled_ext_packed [[IN_SCALAR_16]][0], [[SCALE_SCALAR_16]]
+// CHECK-NEXT: [[OUT_SLICE_16:%.+]] = vector.extract_strided_slice [[PACKED_16]]
+// CHECK-NEXT: [[OUT_SCALAR_16:%.+]] = vector.shape_cast [[OUT_SLICE_16]]
+// CHECK-NEXT: [[ACC_A:%.+]] = vector.insert_strided_slice [[OUT_SCALAR_16]], [[ACC_B]]
+// CHECK-NEXT: [[IN_SLICE_17:%.+]] = vector.extract_strided_slice [[IN_CAST]]
+// CHECK-NEXT: [[IN_SCALAR_17:%.+]] = vector.shape_cast [[IN_SLICE_17]]
+// CHECK-NEXT: [[SCALE_SCALAR_17:%.+]] = vector.extract [[SCALE_EXT]][2, 0, 1]
+// CHECK-NEXT: [[PACKED_17:%.+]] = amdgpu.scaled_ext_packed [[IN_SCALAR_17]][0], [[SCALE_SCALAR_17]]
+// CHECK-NEXT: [[OUT_SLICE_17:%.+]] = vector.extract_strided_slice [[PACKED_17]]
+// CHECK-NEXT: [[OUT_SCALAR_17:%.+]] = vector.shape_cast [[OUT_SLICE_17]]
+// CHECK-NEXT: [[ACC_B:%.+]] = vector.insert_strided_slice [[OUT_SCALAR_17]], [[ACC_A]]
+// CHECK-NEXT: [[IN_SLICE_18:%.+]] = vector.extract_strided_slice [[IN_CAST]]
+// CHECK-NEXT: [[IN_SCALAR_18:%.+]] = vector.shape_cast [[IN_SLICE_18]]
+// CHECK-NEXT: [[SCALE_SCALAR_18:%.+]] = vector.extract [[SCALE_EXT]][2, 0, 2]
+// CHECK-NEXT: [[PACKED_18:%.+]] = amdgpu.scaled_ext_packed [[IN_SCALAR_18]][0], [[SCALE_SCALAR_18]]
+// CHECK-NEXT: [[OUT_SLICE_18:%.+]] = vector.extract_strided_slice [[PACKED_18]]
+// CHECK-NEXT: [[OUT_SCALAR_18:%.+]] = vector.shape_cast [[OUT_SLICE_18]]
+// CHECK-NEXT: [[ACC_A:%.+]] = vector.insert_strided_slice [[OUT_SCALAR_18]], [[ACC_B]]
+// CHECK-NEXT: [[IN_SLICE_19:%.+]] = vector.extract_strided_slice [[IN_CAST]]
+// CHECK-NEXT: [[IN_SCALAR_19:%.+]] = vector.shape_cast [[IN_SLICE_19]]
+// CHECK-NEXT: [[SCALE_SCALAR_19:%.+]] = vector.extract [[SCALE_EXT]][2, 0, 3]
+// CHECK-NEXT: [[PACKED_19:%.+]] = amdgpu.scaled_ext_packed [[IN_SCALAR_19]][0], [[SCALE_SCALAR_19]]
+// CHECK-NEXT: [[OUT_SLICE_19:%.+]] = vector.extract_strided_slice [[PACKED_19]]
+// CHECK-NEXT: [[OUT_SCALAR_19:%.+]] = vector.shape_cast [[OUT_SLICE_19]]
+// CHECK-NEXT: [[ACC_B:%.+]] = vector.insert_strided_slice [[OUT_SCALAR_19]], [[ACC_A]]
+// CHECK-NEXT: [[IN_SLICE_20:%.+]] = vector.extract_strided_slice [[IN_CAST]]
+// CHECK-NEXT: [[IN_SCALAR_20:%.+]] = vector.shape_cast [[IN_SLICE_20]]
+// CHECK-NEXT: [[SCALE_SCALAR_20:%.+]] = vector.extract [[SCALE_EXT]][2, 1, 0]
+// CHECK-NEXT: [[PACKED_20:%.+]] = amdgpu.scaled_ext_packed [[IN_SCALAR_20]][0], [[SCALE_SCALAR_20]]
+// CHECK-NEXT: [[OUT_SLICE_20:%.+]] = vector.extract_strided_slice [[PACKED_20]]
+// CHECK-NEXT: [[OUT_SCALAR_20:%.+]] = vector.shape_cast [[OUT_SLICE_20]]
+// CHECK-NEXT: [[ACC_A:%.+]] = vector.insert_strided_slice [[OUT_SCALAR_20]], [[ACC_B]]
+// CHECK-NEXT: [[IN_SLICE_21:%.+]] = vector.extract_strided_slice [[IN_CAST]]
+// CHECK-NEXT: [[IN_SCALAR_21:%.+]] = vector.shape_cast [[IN_SLICE_21]]
+// CHECK-NEXT: [[SCALE_SCALAR_21:%.+]] = vector.extract [[SCALE_EXT]][2, 1, 1]
+// CHECK-NEXT: [[PACKED_21:%.+]] = amdgpu.scaled_ext_packed [[IN_SCALAR_21]][0], [[SCALE_SCALAR_21]]
+// CHECK-NEXT: [[OUT_SLICE_21:%.+]] = vector.extract_strided_slice [[PACKED_21]]
+// CHECK-NEXT: [[OUT_SCALAR_21:%.+]] = vector.shape_cast [[OUT_SLICE_21]]
+// CHECK-NEXT: [[ACC_B:%.+]] = vector.insert_strided_slice [[OUT_SCALAR_21]], [[ACC_A]]
+// CHECK-NEXT: [[IN_SLICE_22:%.+]] = vector.extract_strided_slice [[IN_CAST]]
+// CHECK-NEXT: [[IN_SCALAR_22:%.+]] = vector.shape_cast [[IN_SLICE_22]]
+// CHECK-NEXT: [[SCALE_SCALAR_22:%.+]] = vector.extract [[SCALE_EXT]][2, 1, 2]
+// CHECK-NEXT: [[PACKED_22:%.+]] = amdgpu.scaled_ext_packed [[IN_SCALAR_22]][0], [[SCALE_SCALAR_22]]
+// CHECK-NEXT: [[OUT_SLICE_22:%.+]] = vector.extract_strided_slice [[PACKED_22]]
+// CHECK-NEXT: [[OUT_SCALAR_22:%.+]] = vector.shape_cast [[OUT_SLICE_22]]
+// CHECK-NEXT: [[ACC_A:%.+]] = vector.insert_strided_slice [[OUT_SCALAR_22]], [[ACC_B]]
+// CHECK-NEXT: [[IN_SLICE_23:%.+]] = vector.extract_strided_slice [[IN_CAST]]
+// CHECK-NEXT: [[IN_SCALAR_23:%.+]] = vector.shape_cast [[IN_SLICE_23]]
+// CHECK-NEXT: [[SCALE_SCALAR_23:%.+]] = vector.extract [[SCALE_EXT]][2, 1, 3]
+// CHECK-NEXT: [[PACKED_23:%.+]] = amdgpu.scaled_ext_packed [[IN_SCALAR_23]][0], [[SCALE_SCALAR_23]]
+// CHECK-NEXT: [[OUT_SLICE_23:%.+]] = vector.extract_strided_slice [[PACKED_23]]
+// CHECK-NEXT: [[OUT_SCALAR_23:%.+]] = vector.shape_cast [[OUT_SLICE_23]]
+// CHECK-NEXT: [[ACC_B:%.+]] = vector.insert_strided_slice [[OUT_SCALAR_23]], [[ACC_A]]
+// CHECK-NEXT: [[IN_SLICE_24:%.+]] = vector.extract_strided_slice [[IN_CAST]]
+// CHECK-NEXT: [[IN_SCALAR_24:%.+]] = vector.shape_cast [[IN_SLICE_24]]
+// CHECK-NEXT: [[SCALE_SCALAR_24:%.+]] = vector.extract [[SCALE_EXT]][3, 0, 0]
+// CHECK-NEXT: [[PACKED_24:%.+]] = amdgpu.scaled_ext_packed [[IN_SCALAR_24]][0], [[SCALE_SCALAR_24]]
+// CHECK-NEXT: [[OUT_SLICE_24:%.+]] = vector.extract_strided_slice [[PACKED_24]]
+// CHECK-NEXT: [[OUT_SCALAR_24:%.+]] = vector.shape_cast [[OUT_SLICE_24]]
+// CHECK-NEXT: [[ACC_A:%.+]] = vector.insert_strided_slice [[OUT_SCALAR_24]], [[ACC_B]]
+// CHECK-NEXT: [[IN_SLICE_25:%.+]] = vector.extract_strided_slice [[IN_CAST]]
+// CHECK-NEXT: [[IN_SCALAR_25:%.+]] = vector.shape_cast [[IN_SLICE_25]]
+// CHECK-NEXT: [[SCALE_SCALAR_25:%.+]] = vector.extract [[SCALE_EXT]][3, 0, 1]
+// CHECK-NEXT: [[PACKED_25:%.+]] = amdgpu.scaled_ext_packed [[IN_SCALAR_25]][0], [[SCALE_SCALAR_25]]
+// CHECK-NEXT: [[OUT_SLICE_25:%.+]] = vector.extract_strided_slice [[PACKED_25]]
+// CHECK-NEXT: [[OUT_SCALAR_25:%.+]] = vector.shape_cast [[OUT_SLICE_25]]
+// CHECK-NEXT: [[ACC_B:%.+]] = vector.insert_strided_slice [[OUT_SCALAR_25]], [[ACC_A]]
+// CHECK-NEXT: [[IN_SLICE_26:%.+]] = vector.extract_strided_slice [[IN_CAST]]
+// CHECK-NEXT: [[IN_SCALAR_26:%.+]] = vector.shape_cast [[IN_SLICE_26]]
+// CHECK-NEXT: [[SCALE_SCALAR_26:%.+]] = vector.extract [[SCALE_EXT]][3, 0, 2]
+// CHECK-NEXT: [[PACKED_26:%.+]] = amdgpu.scaled_ext_packed [[IN_SCALAR_26]][0], [[SCALE_SCALAR_26]]
+// CHECK-NEXT: [[OUT_SLICE_26:%.+]] = vector.extract_strided_slice [[PACKED_26]]
+// CHECK-NEXT: [[OUT_SCALAR_26:%.+]] = vector.shape_cast [[OUT_SLICE_26]]
+// CHECK-NEXT: [[ACC_A:%.+]] = vector.insert_strided_slice [[OUT_SCALAR_26]], [[ACC_B]]
+// CHECK-NEXT: [[IN_SLICE_27:%.+]] = vector.extract_strided_slice [[IN_CAST]]
+// CHECK-NEXT: [[IN_SCALAR_27:%.+]] = vector.shape_cast [[IN_SLICE_27]]
+// CHECK-NEXT: [[SCALE_SCALAR_27:%.+]] = vector.extract [[SCALE_EXT]][3, 0, 3]
+// CHECK-NEXT: [[PACKED_27:%.+]] = amdgpu.scaled_ext_packed [[IN_SCALAR_27]][0], [[SCALE_SCALAR_27]]
+// CHECK-NEXT: [[OUT_SLICE_27:%.+]] = vector.extract_strided_slice [[PACKED_27]]
+// CHECK-NEXT: [[OUT_SCALAR_27:%.+]] = vector.shape_cast [[OUT_SLICE_27]]
+// CHECK-NEXT: [[ACC_B:%.+]] = vector.insert_strided_slice [[OUT_SCALAR_27]], [[ACC_A]]
+// CHECK-NEXT: [[IN_SLICE_28:%.+]] = vector.extract_strided_slice [[IN_CAST]]
+// CHECK-NEXT: [[IN_SCALAR_28:%.+]] = vector.shape_cast [[IN_SLICE_28]]
+// CHECK-NEXT: [[SCALE_SCALAR_28:%.+]] = vector.extract [[SCALE_EXT]][3, 1, 0]
+// CHECK-NEXT: [[PACKED_28:%.+]] = amdgpu.scaled_ext_packed [[IN_SCALAR_28]][0], [[SCALE_SCALAR_28]]
+// CHECK-NEXT: [[OUT_SLICE_28:%.+]] = vector.extract_strided_slice [[PACKED_28]]
+// CHECK-NEXT: [[OUT_SCALAR_28:%.+]] = vector.shape_cast [[OUT_SLICE_28]]
+// CHECK-NEXT: [[ACC_A:%.+]] = vector.insert_strided_slice [[OUT_SCALAR_28]], [[ACC_B]]
+// CHECK-NEXT: [[IN_SLICE_29:%.+]] = vector.extract_strided_slice [[IN_CAST]]
+// CHECK-NEXT: [[IN_SCALAR_29:%.+]] = vector.shape_cast [[IN_SLICE_29]]
+// CHECK-NEXT: [[SCALE_SCALAR_29:%.+]] = vector.extract [[SCALE_EXT]][3, 1, 1]
+// CHECK-NEXT: [[PACKED_29:%.+]] = amdgpu.scaled_ext_packed [[IN_SCALAR_29]][0], [[SCALE_SCALAR_29]]
+// CHECK-NEXT: [[OUT_SLICE_29:%.+]] = vector.extract_strided_slice [[PACKED_29]]
+// CHECK-NEXT: [[OUT_SCALAR_29:%.+]] = vector.shape_cast [[OUT_SLICE_29]]
+// CHECK-NEXT: [[ACC_B:%.+]] = vector.insert_strided_slice [[OUT_SCALAR_29]], [[ACC_A]]
+// CHECK-NEXT: [[IN_SLICE_30:%.+]] = vector.extract_strided_slice [[IN_CAST]]
+// CHECK-NEXT: [[IN_SCALAR_30:%.+]] = vector.shape_cast [[IN_SLICE_30]]
+// CHECK-NEXT: [[SCALE_SCALAR_30:%.+]] = vector.extract [[SCALE_EXT]][3, 1, 2]
+// CHECK-NEXT: [[PACKED_30:%.+]] = amdgpu.scaled_ext_packed [[IN_SCALAR_30]][0], [[SCALE_SCALAR_30]]
+// CHECK-NEXT: [[OUT_SLICE_30:%.+]] = vector.extract_strided_slice [[PACKED_30]]
+// CHECK-NEXT: [[OUT_SCALAR_30:%.+]] = vector.shape_cast [[OUT_SLICE_30]]
+// CHECK-NEXT: [[ACC_A:%.+]] = vector.insert_strided_slice [[OUT_SCALAR_30]], [[ACC_B]]
+// CHECK-NEXT: [[IN_SLICE_31:%.+]] = vector.extract_strided_slice [[IN_CAST]]
+// CHECK-NEXT: [[IN_SCALAR_31:%.+]] = vector.shape_cast [[IN_SLICE_31]]
+// CHECK-NEXT: [[SCALE_SCALAR_31:%.+]] = vector.extract [[SCALE_EXT]][3, 1, 3]
+// CHECK-NEXT: [[PACKED_31:%.+]] = amdgpu.scaled_ext_packed [[IN_SCALAR_31]][0], [[SCALE_SCALAR_31]]
+// CHECK-NEXT: [[OUT_SLICE_31:%.+]] = vector.extract_strided_slice [[PACKED_31]]
+// CHECK-NEXT: [[OUT_SCALAR_31:%.+]] = vector.shape_cast [[OUT_SLICE_31]]
+// CHECK-NEXT: [[ACC_B:%.+]] = vector.insert_strided_slice [[OUT_SCALAR_31]], [[ACC_A]]
+// CHECK-NEXT: [[IN_SLICE_32:%.+]] = vector.extract_strided_slice [[IN_CAST]]
+// CHECK-NEXT: [[IN_SCALAR_32:%.+]] = vector.shape_cast [[IN_SLICE_32]]
+// CHECK-NEXT: [[SCALE_SCALAR_32:%.+]] = vector.extract [[SCALE_EXT]][4, 0, 0]
+// CHECK-NEXT: [[PACKED_32:%.+]] = amdgpu.scaled_ext_packed [[IN_SCALAR_32]][0], [[SCALE_SCALAR_32]]
+// CHECK-NEXT: [[OUT_SLICE_32:%.+]] = vector.extract_strided_slice [[PACKED_32]]
+// CHECK-NEXT: [[OUT_SCALAR_32:%.+]] = vector.shape_cast [[OUT_SLICE_32]]
+// CHECK-NEXT: [[ACC_A:%.+]] = vector.insert_strided_slice [[OUT_SCALAR_32]], [[ACC_B]]
+// CHECK-NEXT: [[IN_SLICE_33:%.+]] = vector.extract_strided_slice [[IN_CAST]]
+// CHECK-NEXT: [[IN_SCALAR_33:%.+]] = vector.shape_cast [[IN_SLICE_33]]
+// CHECK-NEXT: [[SCALE_SCALAR_33:%.+]] = vector.extract [[SCALE_EXT]][4, 0, 1]
+// CHECK-NEXT: [[PACKED_33:%.+]] = amdgpu.scaled_ext_packed [[IN_SCALAR_33]][0], [[SCALE_SCALAR_33]]
+// CHECK-NEXT: [[OUT_SLICE_33:%.+]] = vector.extract_strided_slice [[PACKED_33]]
+// CHECK-NEXT: [[OUT_SCALAR_33:%.+]] = vector.shape_cast [[OUT_SLICE_33]]
+// CHECK-NEXT: [[ACC_B:%.+]] = vector.insert_strided_slice [[OUT_SCALAR_33]], [[ACC_A]]
+// CHECK-NEXT: [[IN_SLICE_34:%.+]] = vector.extract_strided_slice [[IN_CAST]]
+// CHECK-NEXT: [[IN_SCALAR_34:%.+]] = vector.shape_cast [[IN_SLICE_34]]
+// CHECK-NEXT: [[SCALE_SCALAR_34:%.+]] = vector.extract [[SCALE_EXT]][4, 0, 2]
+// CHECK-NEXT: [[PACKED_34:%.+]] = amdgpu.scaled_ext_packed [[IN_SCALAR_34]][0], [[SCALE_SCALAR_34]]
+// CHECK-NEXT: [[OUT_SLICE_34:%.+]] = vector.extract_strided_slice [[PACKED_34]]
+// CHECK-NEXT: [[OUT_SCALAR_34:%.+]] = vector.shape_cast [[OUT_SLICE_34]]
+// CHECK-NEXT: [[ACC_A:%.+]] = vector.insert_strided_slice [[OUT_SCALAR_34]], [[ACC_B]]
+// CHECK-NEXT: [[IN_SLICE_35:%.+]] = vector.extract_strided_slice [[IN_CAST]]
+// CHECK-NEXT: [[IN_SCALAR_35:%.+]] = vector.shape_cast [[IN_SLICE_35]]
+// CHECK-NEXT: [[SCALE_SCALAR_35:%.+]] = vector.extract [[SCALE_EXT]][4, 0, 3]
+// CHECK-NEXT: [[PACKED_35:%.+]] = amdgpu.scaled_ext_packed [[IN_SCALAR_35]][0], [[SCALE_SCALAR_35]]
+// CHECK-NEXT: [[OUT_SLICE_35:%.+]] = vector.extract_strided_slice [[PACKED_35]]
+// CHECK-NEXT: [[OUT_SCALAR_35:%.+]] = vector.shape_cast [[OUT_SLICE_35]]
+// CHECK-NEXT: [[ACC_B:%.+]] = vector.insert_strided_slice [[OUT_SCALAR_35]], [[ACC_A]]
+// CHECK-NEXT: [[IN_SLICE_36:%.+]] = vector.extract_strided_slice [[IN_CAST]]
+// CHECK-NEXT: [[IN_SCALAR_36:%.+]] = vector.shape_cast [[IN_SLICE_36]]
+// CHECK-NEXT: [[SCALE_SCALAR_36:%.+]] = vector.extract [[SCALE_EXT]][4, 1, 0]
+// CHECK-NEXT: [[PACKED_36:%.+]] = amdgpu.scaled_ext_packed [[IN_SCALAR_36]][0], [[SCALE_SCALAR_36]]
+// CHECK-NEXT: [[OUT_SLICE_36:%.+]] = vector.extract_strided_slice [[PACKED_36]]
+// CHECK-NEXT: [[OUT_SCALAR_36:%.+]] = vector.shape_cast [[OUT_SLICE_36]]
+// CHECK-NEXT: [[ACC_A:%.+]] = vector.insert_strided_slice [[OUT_SCALAR_36]], [[ACC_B]]
+// CHECK-NEXT: [[IN_SLICE_37:%.+]] = vector.extract_strided_slice [[IN_CAST]]
+// CHECK-NEXT: [[IN_SCALAR_37:%.+]] = vector.shape_cast [[IN_SLICE_37]]
+// CHECK-NEXT: [[SCALE_SCALAR_37:%.+]] = vector.extract [[SCALE_EXT]][4, 1, 1]
+// CHECK-NEXT: [[PACKED_37:%.+]] = amdgpu.scaled_ext_packed [[IN_SCALAR_37]][0], [[SCALE_SCALAR_37]]
+// CHECK-NEXT: [[OUT_SLICE_37:%.+]] = vector.extract_strided_slice [[PACKED_37]]
+// CHECK-NEXT: [[OUT_SCALAR_37:%.+]] = vector.shape_cast [[OUT_SLICE_37]]
+// CHECK-NEXT: [[ACC_B:%.+]] = vector.insert_strided_slice [[OUT_SCALAR_37]], [[ACC_A]]
+// CHECK-NEXT: [[IN_SLICE_38:%.+]] = vector.extract_strided_slice [[IN_CAST]]
+// CHECK-NEXT: [[IN_SCALAR_38:%.+]] = vector.shape_cast [[IN_SLICE_38]]
+// CHECK-NEXT: [[SCALE_SCALAR_38:%.+]] = vector.extract [[SCALE_EXT]][4, 1, 2]
+// CHECK-NEXT: [[PACKED_38:%.+]] = amdgpu.scaled_ext_packed [[IN_SCALAR_38]][0], [[SCALE_SCALAR_38]]
+// CHECK-NEXT: [[OUT_SLICE_38:%.+]] = vector.extract_strided_slice [[PACKED_38]]
+// CHECK-NEXT: [[OUT_SCALAR_38:%.+]] = vector.shape_cast [[OUT_SLICE_38]]
+// CHECK-NEXT: [[ACC_A:%.+]] = vector.insert_strided_slice [[OUT_SCALAR_38]], [[ACC_B]]
+// CHECK-NEXT: [[IN_SLICE_39:%.+]] = vector.extract_strided_slice [[IN_CAST]]
+// CHECK-NEXT: [[IN_SCALAR_39:%.+]] = vector.shape_cast [[IN_SLICE_39]]
+// CHECK-NEXT: [[SCALE_SCALAR_39:%.+]] = vector.extract [[SCALE_EXT]][4, 1, 3]
+// CHECK-NEXT: [[PACKED_39:%.+]] = amdgpu.scaled_ext_packed [[IN_SCALAR_39]][0], [[SCALE_SCALAR_39]]
+// CHECK-NEXT: [[OUT_SLICE_39:%.+]] = vector.extract_strided_slice [[PACKED_39]]
+// CHECK-NEXT: [[OUT_SCALAR_39:%.+]] = vector.shape_cast [[OUT_SLICE_39]]
+// CHECK-NEXT: [[ACC_B:%.+]] = vector.insert_strided_slice [[OUT_SCALAR_39]], [[ACC_A]]
+// CHECK-NEXT: [[IN_SLICE_40:%.+]] = vector.extract_strided_slice [[IN_CAST]]
+// CHECK-NEXT: [[IN_SCALAR_40:%.+]] = vector.shape_cast [[IN_SLICE_40]]
+// CHECK-NEXT: [[SCALE_SCALAR_40:%.+]] = vector.extract [[SCALE_EXT]][5, 0, 0]
+// CHECK-NEXT: [[PACKED_40:%.+]] = amdgpu.scaled_ext_packed [[IN_SCALAR_40]][0], [[SCALE_SCALAR_40]]
+// CHECK-NEXT: [[OUT_SLICE_40:%.+]] = vector.extract_strided_slice [[PACKED_40]]
+// CHECK-NEXT: [[OUT_SCALAR_40:%.+]] = vector.shape_cast [[OUT_SLICE_40]]
+// CHECK-NEXT: [[ACC_A:%.+]] = vector.insert_strided_slice [[OUT_SCALAR_40]], [[ACC_B]]
+// CHECK-NEXT: [[IN_SLICE_41:%.+]] = vector.extract_strided_slice [[IN_CAST]]
+// CHECK-NEXT: [[IN_SCALAR_41:%.+]] = vector.shape_cast [[IN_SLICE_41]]
+// CHECK-NEXT: [[SCALE_SCALAR_41:%.+]] = vector.extract [[SCALE_EXT]][5, 0, 1]
+// CHECK-NEXT: [[PACKED_41:%.+]] = amdgpu.scaled_ext_packed [[IN_SCALAR_41]][0], [[SCALE_SCALAR_41]]
+// CHECK-NEXT: [[OUT_SLICE_41:%.+]] = vector.extract_strided_slice [[PACKED_41]]
+// CHECK-NEXT: [[OUT_SCALAR_41:%.+]] = vector.shape_cast [[OUT_SLICE_41]]
+// CHECK-NEXT: [[ACC_B:%.+]] = vector.insert_strided_slice [[OUT_SCALAR_41]], [[ACC_A]]
+// CHECK-NEXT: [[IN_SLICE_42:%.+]] = vector.extract_strided_slice [[IN_CAST]]
+// CHECK-NEXT: [[IN_SCALAR_42:%.+]] = vector.shape_cast [[IN_SLICE_42]]
+// CHECK-NEXT: [[SCALE_SCALAR_42:%.+]] = vector.extract [[SCALE_EXT]][5, 0, 2]
+// CHECK-NEXT: [[PACKED_42:%.+]] = amdgpu.scaled_ext_packed [[IN_SCALAR_42]][0], [[SCALE_SCALAR_42]]
+// CHECK-NEXT: [[OUT_SLICE_42:%.+]] = vector.extract_strided_slice [[PACKED_42]]
+// CHECK-NEXT: [[OUT_SCALAR_42:%.+]] = vector.shape_cast [[OUT_SLICE_42]]
+// CHECK-NEXT: [[ACC_A:%.+]] = vector.insert_strided_slice [[OUT_SCALAR_42]], [[ACC_B]]
+// CHECK-NEXT: [[IN_SLICE_43:%.+]] = vector.extract_strided_slice [[IN_CAST]]
+// CHECK-NEXT: [[IN_SCALAR_43:%.+]] = vector.shape_cast [[IN_SLICE_43]]
+// CHECK-NEXT: [[SCALE_SCALAR_43:%.+]] = vector.extract [[SCALE_EXT]][5, 0, 3]
+// CHECK-NEXT: [[PACKED_43:%.+]] = amdgpu.scaled_ext_packed [[IN_SCALAR_43]][0], [[SCALE_SCALAR_43]]
+// CHECK-NEXT: [[OUT_SLICE_43:%.+]] = vector.extract_strided_slice [[PACKED_43]]
+// CHECK-NEXT: [[OUT_SCALAR_43:%.+]] = vector.shape_cast [[OUT_SLICE_43]]
+// CHECK-NEXT: [[ACC_B:%.+]] = vector.insert_strided_slice [[OUT_SCALAR_43]], [[ACC_A]]
+// CHECK-NEXT: [[IN_SLICE_44:%.+]] = vector.extract_strided_slice [[IN_CAST]]
+// CHECK-NEXT: [[IN_SCALAR_44:%.+]] = vector.shape_cast [[IN_SLICE_44]]
+// CHECK-NEXT: [[SCALE_SCALAR_44:%.+]] = vector.extract [[SCALE_EXT]][5, 1, 0]
+// CHECK-NEXT: [[PACKED_44:%.+]] = amdgpu.scaled_ext_packed [[IN_SCALAR_44]][0], [[SCALE_SCALAR_44]]
+// CHECK-NEXT: [[OUT_SLICE_44:%.+]] = vector.extract_strided_slice [[PACKED_44]]
+// CHECK-NEXT: [[OUT_SCALAR_44:%.+]] = vector.shape_cast [[OUT_SLICE_44]]
+// CHECK-NEXT: [[ACC_A:%.+]] = vector.insert_strided_slice [[OUT_SCALAR_44]], [[ACC_B]]
+// CHECK-NEXT: [[IN_SLICE_45:%.+]] = vector.extract_strided_slice [[IN_CAST]]
+// CHECK-NEXT: [[IN_SCALAR_45:%.+]] = vector.shape_cast [[IN_SLICE_45]]
+// CHECK-NEXT: [[SCALE_SCALAR_45:%.+]] = vector.extract [[SCALE_EXT]][5, 1, 1]
+// CHECK-NEXT: [[PACKED_45:%.+]] = amdgpu.scaled_ext_packed [[IN_SCALAR_45]][0], [[SCALE_SCALAR_45]]
+// CHECK-NEXT: [[OUT_SLICE_45:%.+]] = vector.extract_strided_slice [[PACKED_45]]
+// CHECK-NEXT: [[OUT_SCALAR_45:%.+]] = vector.shape_cast [[OUT_SLICE_45]]
+// CHECK-NEXT: [[ACC_B:%.+]] = vector.insert_strided_slice [[OUT_SCALAR_45]], [[ACC_A]]
+// CHECK-NEXT: [[IN_SLICE_46:%.+]] = vector.extract_strided_slice [[IN_CAST]]
+// CHECK-NEXT: [[IN_SCALAR_46:%.+]] = vector.shape_cast [[IN_SLICE_46]]
+// CHECK-NEXT: [[SCALE_SCALAR_46:%.+]] = vector.extract [[SCALE_EXT]][5, 1, 2]
+// CHECK-NEXT: [[PACKED_46:%.+]] = amdgpu.scaled_ext_packed [[IN_SCALAR_46]][0], [[SCALE_SCALAR_46]]
+// CHECK-NEXT: [[OUT_SLICE_46:%.+]] = vector.extract_strided_slice [[PACKED_46]]
+// CHECK-NEXT: [[OUT_SCALAR_46:%.+]] = vector.shape_cast [[OUT_SLICE_46]]
+// CHECK-NEXT: [[ACC_A:%.+]] = vector.insert_strided_slice [[OUT_SCALAR_46]], [[ACC_B]]
+// CHECK-NEXT: [[IN_SLICE_47:%.+]] = vector.extract_strided_slice [[IN_CAST]]
+// CHECK-NEXT: [[IN_SCALAR_47:%.+]] = vector.shape_cast [[IN_SLICE_47]]
+// CHECK-NEXT: [[SCALE_SCALAR_47:%.+]] = vector.extract [[SCALE_EXT]][5, 1, 3]
+// CHECK-NEXT: [[PACKED_47:%.+]] = amdgpu.scaled_ext_packed [[IN_SCALAR_47]][0], [[SCALE_SCALAR_47]]
+// CHECK-NEXT: [[OUT_SLICE_47:%.+]] = vector.extract_strided_slice [[PACKED_47]]
+// CHECK-NEXT: [[OUT_SCALAR_47:%.+]] = vector.shape_cast [[OUT_SLICE_47]]
+// CHECK-NEXT: [[ACC_B:%.+]] = vector.insert_strided_slice [[OUT_SCALAR_47]], [[ACC_A]]
+// CHECK-NEXT: [[IN_SLICE_48:%.+]] = vector.extract_strided_slice [[IN_CAST]]
+// CHECK-NEXT: [[IN_SCALAR_48:%.+]] = vector.shape_cast [[IN_SLICE_48]]
+// CHECK-NEXT: [[SCALE_SCALAR_48:%.+]] = vector.extract [[SCALE_EXT]][6, 0, 0]
+// CHECK-NEXT: [[PACKED_48:%.+]] = amdgpu.scaled_ext_packed [[IN_SCALAR_48]][0], [[SCALE_SCALAR_48]]
+// CHECK-NEXT: [[OUT_SLICE_48:%.+]] = vector.extract_strided_slice [[PACKED_48]]
+// CHECK-NEXT: [[OUT_SCALAR_48:%.+]] = vector.shape_cast [[OUT_SLICE_48]]
+// CHECK-NEXT: [[ACC_A:%.+]] = vector.insert_strided_slice [[OUT_SCALAR_48]], [[ACC_B]]
+// CHECK-NEXT: [[IN_SLICE_49:%.+]] = vector.extract_strided_slice [[IN_CAST]]
+// CHECK-NEXT: [[IN_SCALAR_49:%.+]] = vector.shape_cast [[IN_SLICE_49]]
+// CHECK-NEXT: [[SCALE_SCALAR_49:%.+]] = vector.extract [[SCALE_EXT]][6, 0, 1]
+// CHECK-NEXT: [[PACKED_49:%.+]] = amdgpu.scaled_ext_packed [[IN_SCALAR_49]][0], [[SCALE_SCALAR_49]]
+// CHECK-NEXT: [[OUT_SLICE_49:%.+]] = vector.extract_strided_slice [[PACKED_49]]
+// CHECK-NEXT: [[OUT_SCALAR_49:%.+]] = vector.shape_cast [[OUT_SLICE_49]]
+// CHECK-NEXT: [[ACC_B:%.+]] = vector.insert_strided_slice [[OUT_SCALAR_49]], [[ACC_A]]
+// CHECK-NEXT: [[IN_SLICE_50:%.+]] = vector.extract_strided_slice [[IN_CAST]]
+// CHECK-NEXT: [[IN_SCALAR_50:%.+]] = vector.shape_cast [[IN_SLICE_50]]
+// CHECK-NEXT: [[SCALE_SCALAR_50:%.+]] = vector.extract [[SCALE_EXT]][6, 0, 2]
+// CHECK-NEXT: [[PACKED_50:%.+]] = amdgpu.scaled_ext_packed [[IN_SCALAR_50]][0], [[SCALE_SCALAR_50]]
+// CHECK-NEXT: [[OUT_SLICE_50:%.+]] = vector.extract_strided_slice [[PACKED_50]]
+// CHECK-NEXT: [[OUT_SCALAR_50:%.+]] = vector.shape_cast [[OUT_SLICE_50]]
+// CHECK-NEXT: [[ACC_A:%.+]] = vector.insert_strided_slice [[OUT_SCALAR_50]], [[ACC_B]]
+// CHECK-NEXT: [[IN_SLICE_51:%.+]] = vector.extract_strided_slice [[IN_CAST]]
+// CHECK-NEXT: [[IN_SCALAR_51:%.+]] = vector.shape_cast [[IN_SLICE_51]]
+// CHECK-NEXT: [[SCALE_SCALAR_51:%.+]] = vector.extract [[SCALE_EXT]][6, 0, 3]
+// CHECK-NEXT: [[PACKED_51:%.+]] = amdgpu.scaled_ext_packed [[IN_SCALAR_51]][0], [[SCALE_SCALAR_51]]
+// CHECK-NEXT: [[OUT_SLICE_51:%.+]] = vector.extract_strided_slice [[PACKED_51]]
+// CHECK-NEXT: [[OUT_SCALAR_51:%.+]] = vector.shape_cast [[OUT_SLICE_51]]
+// CHECK-NEXT: [[ACC_B:%.+]] = vector.insert_strided_slice [[OUT_SCALAR_51]], [[ACC_A]]
+// CHECK-NEXT: [[IN_SLICE_52:%.+]] = vector.extract_strided_slice [[IN_CAST]]
+// CHECK-NEXT: [[IN_SCALAR_52:%.+]] = vector.shape_cast [[IN_SLICE_52]]
+// CHECK-NEXT: [[SCALE_SCALAR_52:%.+]] = vector.extract [[SCALE_EXT]][6, 1, 0]
+// CHECK-NEXT: [[PACKED_52:%.+]] = amdgpu.scaled_ext_packed [[IN_SCALAR_52]][0], [[SCALE_SCALAR_52]]
+// CHECK-NEXT: [[OUT_SLICE_52:%.+]] = vector.extract_strided_slice [[PACKED_52]]
+// CHECK-NEXT: [[OUT_SCALAR_52:%.+]] = vector.shape_cast [[OUT_SLICE_52]]
+// CHECK-NEXT: [[ACC_A:%.+]] = vector.insert_strided_slice [[OUT_SCALAR_52]], [[ACC_B]]
+// CHECK-NEXT: [[IN_SLICE_53:%.+]] = vector.extract_strided_slice [[IN_CAST]]
+// CHECK-NEXT: [[IN_SCALAR_53:%.+]] = vector.shape_cast [[IN_SLICE_53]]
+// CHECK-NEXT: [[SCALE_SCALAR_53:%.+]] = vector.extract [[SCALE_EXT]][6, 1, 1]
+// CHECK-NEXT: [[PACKED_53:%.+]] = amdgpu.scaled_ext_packed [[IN_SCALAR_53]][0], [[SCALE_SCALAR_53]]
+// CHECK-NEXT: [[OUT_SLICE_53:%.+]] = vector.extract_strided_slice [[PACKED_53]]
+// CHECK-NEXT: [[OUT_SCALAR_53:%.+]] = vector.shape_cast [[OUT_SLICE_53]]
+// CHECK-NEXT: [[ACC_B:%.+]] = vector.insert_strided_slice [[OUT_SCALAR_53]], [[ACC_A]]
+// CHECK-NEXT: [[IN_SLICE_54:%.+]] = vector.extract_strided_slice [[IN_CAST]]
+// CHECK-NEXT: [[IN_SCALAR_54:%.+]] = vector.shape_cast [[IN_SLICE_54]]
+// CHECK-NEXT: [[SCALE_SCALAR_54:%.+]] = vector.extract [[SCALE_EXT]][6, 1, 2]
+// CHECK-NEXT: [[PACKED_54:%.+]] = amdgpu.scaled_ext_packed [[IN_SCALAR_54]][0], [[SCALE_SCALAR_54]]
+// CHECK-NEXT: [[OUT_SLICE_54:%.+]] = vector.extract_strided_slice [[PACKED_54]]
+// CHECK-NEXT: [[OUT_SCALAR_54:%.+]] = vector.shape_cast [[OUT_SLICE_54]]
+// CHECK-NEXT: [[ACC_A:%.+]] = vector.insert_strided_slice [[OUT_SCALAR_54]], [[ACC_B]]
+// CHECK-NEXT: [[IN_SLICE_55:%.+]] = vector.extract_strided_slice [[IN_CAST]]
+// CHECK-NEXT: [[IN_SCALAR_55:%.+]] = vector.shape_cast [[IN_SLICE_55]]
+// CHECK-NEXT: [[SCALE_SCALAR_55:%.+]] = vector.extract [[SCALE_EXT]][6, 1, 3]
+// CHECK-NEXT: [[PACKED_55:%.+]] = amdgpu.scaled_ext_packed [[IN_SCALAR_55]][0], [[SCALE_SCALAR_55]]
+// CHECK-NEXT: [[OUT_SLICE_55:%.+]] = vector.extract_strided_slice [[PACKED_55]]
+// CHECK-NEXT: [[OUT_SCALAR_55:%.+]] = vector.shape_cast [[OUT_SLICE_55]]
+// CHECK-NEXT: [[ACC_B:%.+]] = vector.insert_strided_slice [[OUT_SCALAR_55]], [[ACC_A]]
+// CHECK-NEXT: [[IN_SLICE_56:%.+]] = vector.extract_strided_slice [[IN_CAST]]
+// CHECK-NEXT: [[IN_SCALAR_56:%.+]] = vector.shape_cast [[IN_SLICE_56]]
+// CHECK-NEXT: [[SCALE_SCALAR_56:%.+]] = vector.extract [[SCALE_EXT]][7, 0, 0]
+// CHECK-NEXT: [[PACKED_56:%.+]] = amdgpu.scaled_ext_packed [[IN_SCALAR_56]][0], [[SCALE_SCALAR_56]]
+// CHECK-NEXT: [[OUT_SLICE_56:%.+]] = vector.extract_strided_slice [[PACKED_56]]
+// CHECK-NEXT: [[OUT_SCALAR_56:%.+]] = vector.shape_cast [[OUT_SLICE_56]]
+// CHECK-NEXT: [[ACC_A:%.+]] = vector.insert_strided_slice [[OUT_SCALAR_56]], [[ACC_B]]
+// CHECK-NEXT: [[IN_SLICE_57:%.+]] = vector.extract_strided_slice [[IN_CAST]]
+// CHECK-NEXT: [[IN_SCALAR_57:%.+]] = vector.shape_cast [[IN_SLICE_57]]
+// CHECK-NEXT: [[SCALE_SCALAR_57:%.+]] = vector.extract [[SCALE_EXT]][7, 0, 1]
+// CHECK-NEXT: [[PACKED_57:%.+]] = amdgpu.scaled_ext_packed [[IN_SCALAR_57]][0], [[SCALE_SCALAR_57]]
+// CHECK-NEXT: [[OUT_SLICE_57:%.+]] = vector.extract_strided_slice [[PACKED_57]]
+// CHECK-NEXT: [[OUT_SCALAR_57:%.+]] = vector.shape_cast [[OUT_SLICE_57]]
+// CHECK-NEXT: [[ACC_B:%.+]] = vector.insert_strided_slice [[OUT_SCALAR_57]], [[ACC_A]]
+// CHECK-NEXT: [[IN_SLICE_58:%.+]] = vector.extract_strided_slice [[IN_CAST]]
+// CHECK-NEXT: [[IN_SCALAR_58:%.+]] = vector.shape_cast [[IN_SLICE_58]]
+// CHECK-NEXT: [[SCALE_SCALAR_58:%.+]] = vector.extract [[SCALE_EXT]][7, 0, 2]
+// CHECK-NEXT: [[PACKED_58:%.+]] = amdgpu.scaled_ext_packed [[IN_SCALAR_58]][0], [[SCALE_SCALAR_58]]
+// CHECK-NEXT: [[OUT_SLICE_58:%.+]] = vector.extract_strided_slice [[PACKED_58]]
+// CHECK-NEXT: [[OUT_SCALAR_58:%.+]] = vector.shape_cast [[OUT_SLICE_58]]
+// CHECK-NEXT: [[ACC_A:%.+]] = vector.insert_strided_slice [[OUT_SCALAR_58]], [[ACC_B]]
+// CHECK-NEXT: [[IN_SLICE_59:%.+]] = vector.extract_strided_slice [[IN_CAST]]
+// CHECK-NEXT: [[IN_SCALAR_59:%.+]] = vector.shape_cast [[IN_SLICE_59]]
+// CHECK-NEXT: [[SCALE_SCALAR_59:%.+]] = vector.extract [[SCALE_EXT]][7, 0, 3]
+// CHECK-NEXT: [[PACKED_59:%.+]] = amdgpu.scaled_ext_packed [[IN_SCALAR_59]][0], [[SCALE_SCALAR_59]]
+// CHECK-NEXT: [[OUT_SLICE_59:%.+]] = vector.extract_strided_slice [[PACKED_59]]
+// CHECK-NEXT: [[OUT_SCALAR_59:%.+]] = vector.shape_cast [[OUT_SLICE_59]]
+// CHECK-NEXT: [[ACC_B:%.+]] = vector.insert_strided_slice [[OUT_SCALAR_59]], [[ACC_A]]
+// CHECK-NEXT: [[IN_SLICE_60:%.+]] = vector.extract_strided_slice [[IN_CAST]]
+// CHECK-NEXT: [[IN_SCALAR_60:%.+]] = vector.shape_cast [[IN_SLICE_60]]
+// CHECK-NEXT: [[SCALE_SCALAR_60:%.+]] = vector.extract [[SCALE_EXT]][7, 1, 0]
+// CHECK-NEXT: [[PACKED_60:%.+]] = amdgpu.scaled_ext_packed [[IN_SCALAR_60]][0], [[SCALE_SCALAR_60]]
+// CHECK-NEXT: [[OUT_SLICE_60:%.+]] = vector.extract_strided_slice [[PACKED_60]]
+// CHECK-NEXT: [[OUT_SCALAR_60:%.+]] = vector.shape_cast [[OUT_SLICE_60]]
+// CHECK-NEXT: [[ACC_A:%.+]] = vector.insert_strided_slice [[OUT_SCALAR_60]], [[ACC_B]]
+// CHECK-NEXT: [[IN_SLICE_61:%.+]] = vector.extract_strided_slice [[IN_CAST]]
+// CHECK-NEXT: [[IN_SCALAR_61:%.+]] = vector.shape_cast [[IN_SLICE_61]]
+// CHECK-NEXT: [[SCALE_SCALAR_61:%.+]] = vector.extract [[SCALE_EXT]][7, 1, 1]
+// CHECK-NEXT: [[PACKED_61:%.+]] = amdgpu.scaled_ext_packed [[IN_SCALAR_61]][0], [[SCALE_SCALAR_61]]
+// CHECK-NEXT: [[OUT_SLICE_61:%.+]] = vector.extract_strided_slice [[PACKED_61]]
+// CHECK-NEXT: [[OUT_SCALAR_61:%.+]] = vector.shape_cast [[OUT_SLICE_61]]
+// CHECK-NEXT: [[ACC_B:%.+]] = vector.insert_strided_slice [[OUT_SCALAR_61]], [[ACC_A]]
+// CHECK-NEXT: [[IN_SLICE_62:%.+]] = vector.extract_strided_slice [[IN_CAST]]
+// CHECK-NEXT: [[IN_SCALAR_62:%.+]] = vector.shape_cast [[IN_SLICE_62]]
+// CHECK-NEXT: [[SCALE_SCALAR_62:%.+]] = vector.extract [[SCALE_EXT]][7, 1, 2]
+// CHECK-NEXT: [[PACKED_62:%.+]] = amdgpu.scaled_ext_packed [[IN_SCALAR_62]][0], [[SCALE_SCALAR_62]]
+// CHECK-NEXT: [[OUT_SLICE_62:%.+]] = vector.extract_strided_slice [[PACKED_62]]
+// CHECK-NEXT: [[OUT_SCALAR_62:%.+]] = vector.shape_cast [[OUT_SLICE_62]]
+// CHECK-NEXT: [[ACC_A:%.+]] = vector.insert_strided_slice [[OUT_SCALAR_62]], [[ACC_B]]
+// CHECK-NEXT: [[IN_SLICE_63:%.+]] = vector.extract_strided_slice [[IN_CAST]]
+// CHECK-NEXT: [[IN_SCALAR_63:%.+]] = vector.shape_cast [[IN_SLICE_63]]
+// CHECK-NEXT: [[SCALE_SCALAR_63:%.+]] = vector.extract [[SCALE_EXT]][7, 1, 3]
+// CHECK-NEXT: [[PACKED_63:%.+]] = amdgpu.scaled_ext_packed [[IN_SCALAR_63]][0], [[SCALE_SCALAR_63]]
+// CHECK-NEXT: [[OUT_SLICE_63:%.+]] = vector.extract_strided_slice [[PACKED_63]]
+// CHECK-NEXT: [[OUT_SCALAR_63:%.+]] = vector.shape_cast [[OUT_SLICE_63]]
+// CHECK-NEXT: [[ACC_B:%.+]] = vector.insert_strided_slice [[OUT_SCALAR_63]], [[ACC_A]]
+// CHECK-NEXT: [[FINAL_CAST:%.+]] = vector.shape_cast [[ACC_B]] : vector<8x2x4xf32> to vector<8x8xf32>
+// CHECK-NEXT: return [[FINAL_CAST]] : vector<8x8xf32>
+func.func @conversion_broadcast(%in: vector<8x8xf8E5M2>, %scale: vector<8x2xf8E8M0FNU>) -> vector<8x8xf32> {
+ %bc = vector.broadcast %scale : vector<8x2xf8E8M0FNU> to vector<4x8x2xf8E8M0FNU>
+ %cast1 = vector.shape_cast %in : vector<8x8xf8E5M2> to vector<8x2x4xf8E5M2>
+ %cast2 = vector.shape_cast %bc : vector<4x8x2xf8E8M0FNU> to vector<8x2x4xf8E8M0FNU>
+ %ext = arith.scaling_extf %cast1, %cast2 : vector<8x2x4xf8E5M2>, vector<8x2x4xf8E8M0FNU> to vector<8x2x4xf32>
+ %cast3 = vector.shape_cast %ext : vector<8x2x4xf32> to vector<8x8xf32>
+ return %cast3 : vector<8x8xf32>
+}
+
+
+// CHECK-LABEL: @conversion_scalar
+// CHECK: [[SCALE_F32:%.+]] = arith.extf %arg1 : f8E8M0FNU to f32
+// CHECK-NEXT: [[SPLAT_IN:%.+]] = vector.splat %arg0 : vector<1xf8E5M2>
+// CHECK-NEXT: [[PACKED_EXT:%.+]] = amdgpu.scaled_ext_packed [[SPLAT_IN]][0], [[SCALE_F32]] : vector<1xf8E5M2> to vector<2xf32>
+// CHECK-NEXT: [[RESULT:%.+]] = vector.extract [[PACKED_EXT]][0] : f32 from vector<2xf32>
+// CHECK-NEXT: return [[RESULT]] : f32
+func.func @conversion_scalar(%in: f8E5M2, %scale: f8E8M0FNU) -> f32 {
+ %ext = arith.scaling_extf %in, %scale : f8E5M2, f8E8M0FNU to f32
+ return %ext : f32
+}
>From 590e13d749838d70d9a47adfce82745fc19ac39a Mon Sep 17 00:00:00 2001
From: Tim Gymnich <tim at gymni.ch>
Date: Mon, 30 Jun 2025 15:41:33 +0000
Subject: [PATCH 3/3] clang format
---
.../ArithToAMDGPU/ArithToAMDGPU.cpp | 74 ++++++++++---------
1 file changed, 39 insertions(+), 35 deletions(-)
diff --git a/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp b/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp
index 22cd4703c6005..6e6e2a4d0890f 100644
--- a/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp
+++ b/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp
@@ -14,6 +14,7 @@
#include "mlir/Dialect/Arith/Utils/Utils.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/LLVMIR/ROCDLDialect.h"
+#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
#include "mlir/Dialect/Vector/Utils/VectorUtils.h"
@@ -23,7 +24,6 @@
#include "mlir/IR/TypeUtilities.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
-#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "llvm/IR/PatternMatch.h"
#include "llvm/Support/LogicalResult.h"
@@ -80,7 +80,8 @@ struct TruncfToFloat16RewritePattern final
PatternRewriter &rewriter) const override;
};
-struct ScalingExtFRewritePattern final : OpRewritePattern<arith::ScalingExtFOp> {
+struct ScalingExtFRewritePattern final
+ : OpRewritePattern<arith::ScalingExtFOp> {
using OpRewritePattern::OpRewritePattern;
Chipset chipset;
@@ -91,7 +92,8 @@ struct ScalingExtFRewritePattern final : OpRewritePattern<arith::ScalingExtFOp>
PatternRewriter &rewriter) const override;
};
-struct ScalingTruncFRewritePattern final : OpRewritePattern<arith::ScalingTruncFOp> {
+struct ScalingTruncFRewritePattern final
+ : OpRewritePattern<arith::ScalingTruncFOp> {
using OpRewritePattern::OpRewritePattern;
Chipset chipset;
@@ -428,24 +430,19 @@ static Value getOriginalVectorValue(Value value) {
Value current = value;
while (Operation *definingOp = current.getDefiningOp()) {
bool skipOp = llvm::TypeSwitch<Operation *, bool>(definingOp)
- .Case<vector::ShapeCastOp>(
- [&current](auto op) {
- current = op.getSource();
- return true;
- })
- .Case<vector::BroadcastOp>(
- [&current](auto op) {
- current = op.getSource();
- return false;
- })
- .Case<vector::SplatOp>(
- [&current](auto op) {
- current = op.getInput();
- return false;
- })
- .Default([](Operation *) {
- return false;
- });
+ .Case<vector::ShapeCastOp>([&current](auto op) {
+ current = op.getSource();
+ return true;
+ })
+ .Case<vector::BroadcastOp>([&current](auto op) {
+ current = op.getSource();
+ return false;
+ })
+ .Case<vector::SplatOp>([&current](auto op) {
+ current = op.getInput();
+ return false;
+ })
+ .Default([](Operation *) { return false; });
if (!skipOp) {
break;
@@ -475,7 +472,8 @@ ScalingExtFRewritePattern::matchAndRewrite(arith::ScalingExtFOp op,
if (outVecType && outVecType.isScalable())
return failure();
- Type scaleF32Type = scaleVecType ? VectorType::get(scaleVecType.getShape(), f32) : f32;
+ Type scaleF32Type =
+ scaleVecType ? VectorType::get(scaleVecType.getShape(), f32) : f32;
if (scaleType.getIntOrFloatBitWidth() < 32)
scale = rewriter.create<arith::ExtFOp>(loc, scaleF32Type, scale);
else if (scaleType.getIntOrFloatBitWidth() > 32)
@@ -484,22 +482,26 @@ ScalingExtFRewritePattern::matchAndRewrite(arith::ScalingExtFOp op,
VectorType extScaleResultType = VectorType::get(opWidth, outType);
if (!outVecType) {
- Value inCast = rewriter.create<vector::SplatOp>(loc, VectorType::get(1, inType), in);
- Value scaleExt = rewriter.create<amdgpu::ScaledExtPackedOp>(loc, extScaleResultType, inCast, scale, 0);
- scaleExt = rewriter.replaceOpWithNewOp<vector::ExtractOp>(op, scaleExt, 0);
- return success();
+ Value inCast =
+ rewriter.create<vector::SplatOp>(loc, VectorType::get(1, inType), in);
+ Value scaleExt = rewriter.create<amdgpu::ScaledExtPackedOp>(
+ loc, extScaleResultType, inCast, scale, 0);
+ scaleExt = rewriter.replaceOpWithNewOp<vector::ExtractOp>(op, scaleExt, 0);
+ return success();
}
Value origScale = getOriginalVectorValue(scale);
Type origScaleType = origScale.getType();
- VectorType origScaleVecType = isa<VectorType>(origScaleType) ? cast<VectorType>(origScaleType) : VectorType::get(1, origScaleType);
-
+ VectorType origScaleVecType = isa<VectorType>(origScaleType)
+ ? cast<VectorType>(origScaleType)
+ : VectorType::get(1, origScaleType);
+
ArrayRef<int64_t> originalScaleShape = origScaleVecType.getShape();
ArrayRef<int64_t> inShape = inVecType.getShape();
SmallVector<int64_t> paddedScaleShape(originalScaleShape);
- paddedScaleShape.insert(paddedScaleShape.end(), inShape.size() - originalScaleShape.size(),
- 1);
+ paddedScaleShape.insert(paddedScaleShape.end(),
+ inShape.size() - originalScaleShape.size(), 1);
auto ratio = computeShapeRatio(inShape, paddedScaleShape);
if (!ratio)
@@ -540,10 +542,10 @@ ScalingExtFRewritePattern::matchAndRewrite(arith::ScalingExtFOp op,
}
VectorType resultType = VectorType::get(*ratio, outType);
- Value cast = rewriter.create<vector::ShapeCastOp>(loc, resultType,
- blockResult);
- result = rewriter.create<vector::InsertStridedSliceOp>(
- loc, cast, result, offsets, strides);
+ Value cast =
+ rewriter.create<vector::ShapeCastOp>(loc, resultType, blockResult);
+ result = rewriter.create<vector::InsertStridedSliceOp>(loc, cast, result,
+ offsets, strides);
}
rewriter.replaceOp(op, result);
@@ -551,7 +553,9 @@ ScalingExtFRewritePattern::matchAndRewrite(arith::ScalingExtFOp op,
return success();
}
-LogicalResult ScalingTruncFRewritePattern::matchAndRewrite(arith::ScalingTruncFOp op, PatternRewriter &rewriter) const {
+LogicalResult
+ScalingTruncFRewritePattern::matchAndRewrite(arith::ScalingTruncFOp op,
+ PatternRewriter &rewriter) const {
return success();
}
More information about the Mlir-commits
mailing list