[Mlir-commits] [mlir] [mlir][ArmSME] Remove `ConvertIllegalShapeCastOpsToTransposes` (PR #139706)
Andrzej Warzyński
llvmlistbot at llvm.org
Wed Jun 18 01:18:07 PDT 2025
https://github.com/banach-space updated https://github.com/llvm/llvm-project/pull/139706
From de24b856bf974522258b28a258f610fecad605e0 Mon Sep 17 00:00:00 2001
From: Andrzej Warzynski <andrzej.warzynski at arm.com>
Date: Tue, 13 May 2025 09:31:35 +0000
Subject: [PATCH 1/2] [mlir][ArmSME] Remove
`ConvertIllegalShapeCastOpsToTransposes`
As a follow-up to PR #135841 (see discussion for context), this patch
removes `ConvertIllegalShapeCastOpsToTransposes` from the SME legalization
pass and unblocks `ShapeCastOp::fold` for scalable vectors.
AFAIK, `ConvertIllegalShapeCastOpsToTransposes` was originally needed
because we were generating `vector.shape_cast` ops that couldn't be
lowered otherwise. To confirm it's no longer required, I tested this
patch locally using end-to-end tests.
Notably, this also removes a special case from `ShapeCastOp::fold`.
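To illustrate the newly enabled fold, here is a sketch mirroring the
`@shape_cast_of_transpose_scalable` test added below (an illustration, not
part of the patch itself):
```mlir
// BEFORE: an order-preserving transpose feeding a shape_cast.
%0 = vector.transpose %arg, [1, 0, 3, 4, 2]
  : vector<1x[4]x4x1x1xi8> to vector<[4]x1x1x1x4xi8>
%1 = vector.shape_cast %0 : vector<[4]x1x1x1x4xi8> to vector<[4]x4xi8>

// AFTER: with the scalable-vector special case removed,
// shape_cast(transpose(x)) -> shape_cast(x) now also applies here.
%1 = vector.shape_cast %arg : vector<1x[4]x4x1x1xi8> to vector<[4]x4xi8>
```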
---
.../ArmSME/Transforms/VectorLegalization.cpp | 54 ----------------
mlir/lib/Dialect/Vector/IR/VectorOps.cpp | 13 +---
.../Dialect/ArmSME/vector-legalization.mlir | 45 -------------
.../Vector/canonicalize/vector-transpose.mlir | 64 +++++++++++--------
4 files changed, 37 insertions(+), 139 deletions(-)
diff --git a/mlir/lib/Dialect/ArmSME/Transforms/VectorLegalization.cpp b/mlir/lib/Dialect/ArmSME/Transforms/VectorLegalization.cpp
index 95965872f4098..51750f0bb9694 100644
--- a/mlir/lib/Dialect/ArmSME/Transforms/VectorLegalization.cpp
+++ b/mlir/lib/Dialect/ArmSME/Transforms/VectorLegalization.cpp
@@ -724,59 +724,6 @@ struct LiftIllegalVectorTransposeToMemory
}
};
-/// A rewrite to turn unit dim transpose-like vector.shape_casts into
-/// vector.transposes. The shape_cast has to be from an illegal vector type to a
-/// legal one (as defined by isLegalVectorType).
-///
-/// The reasoning for this is that if we've reached this pass and still have
-/// shape_casts of illegal types, then they likely will not cancel out. Turning
-/// them into transposes gives LiftIllegalVectorTransposeToMemory a chance to
-/// eliminate them.
-///
-/// Example:
-///
-/// BEFORE:
-/// ```mlir
-/// %0 = vector.shape_cast %a : vector<[4]x1xf32> to vector<1x[4]xf32>
-/// ```
-///
-/// AFTER:
-/// ```mlir
-/// %0 = vector.transpose %a, [1, 0] : vector<[4]x1xf32> to vector<1x[4]xf32>
-/// ```
-struct ConvertIllegalShapeCastOpsToTransposes
- : public OpRewritePattern<vector::ShapeCastOp> {
- using OpRewritePattern<vector::ShapeCastOp>::OpRewritePattern;
-
- LogicalResult matchAndRewrite(vector::ShapeCastOp shapeCastOp,
- PatternRewriter &rewriter) const override {
- auto sourceType = shapeCastOp.getSourceVectorType();
- auto resultType = shapeCastOp.getResultVectorType();
- if (isLegalVectorType(sourceType) || !isLegalVectorType(resultType))
- return rewriter.notifyMatchFailure(shapeCastOp,
- kMatchFailureNotIllegalToLegal);
-
- // Note: If we know that `sourceType` is an illegal vector type (and 2D)
- // then dim 0 is scalable and dim 1 is fixed.
- if (sourceType.getRank() != 2 || sourceType.getDimSize(1) != 1)
- return rewriter.notifyMatchFailure(
- shapeCastOp, "expected source to be a 2D scalable vector with a "
- "trailing unit dim");
-
- auto loc = shapeCastOp.getLoc();
- auto transpose = rewriter.create<vector::TransposeOp>(
- loc, shapeCastOp.getSource(), ArrayRef<int64_t>{1, 0});
-
- if (resultType.getRank() == 1)
- rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(shapeCastOp, resultType,
- transpose);
- else
- rewriter.replaceOp(shapeCastOp, transpose);
-
- return success();
- }
-};
-
/// Rewrites an illegal/unsupported SVE transfer_write(transpose) to instead use
/// the ZA state. This is a workaround to support these transposes when ZA is
/// available.
@@ -943,7 +890,6 @@ struct VectorLegalizationPass
RewritePatternSet rewritePatterns(context);
rewritePatterns.add<FoldExtractFromVectorOfSMELikeCreateMasks,
LiftIllegalVectorTransposeToMemory,
- ConvertIllegalShapeCastOpsToTransposes,
LowerIllegalTransposeStoreViaZA>(context);
if (failed(
applyPatternsGreedily(getOperation(), std::move(rewritePatterns))))
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 2a2357319bd23..887773172339f 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -5758,18 +5758,7 @@ OpFoldResult ShapeCastOp::fold(FoldAdaptor adaptor) {
// shape_cast(transpose(x)) -> shape_cast(x)
if (auto transpose = getSource().getDefiningOp<TransposeOp>()) {
- // This folder does
- // shape_cast(transpose) -> shape_cast
- // But another pattern, ConvertIllegalShapeCastOpsToTransposes, does
- // shape_cast -> shape_cast(transpose)
- // i.e. the complete opposite. When paired, these 2 patterns can cause
- // infinite cycles in pattern rewriting.
- // ConvertIllegalShapeCastOpsToTransposes only matches on scalable
- // vectors, so by disabling this folder for scalable vectors the
- // cycle is avoided.
- // TODO: Check if ConvertIllegalShapeCastOpsToTransposes is
- // still needed. If it's not, then we can fold here.
- if (!transpose.getType().isScalable() && isOrderPreserving(transpose)) {
+ if (isOrderPreserving(transpose)) {
setOperand(transpose.getVector());
return getResult();
}
diff --git a/mlir/test/Dialect/ArmSME/vector-legalization.mlir b/mlir/test/Dialect/ArmSME/vector-legalization.mlir
index d56df9814f173..6e6615c243d2a 100644
--- a/mlir/test/Dialect/ArmSME/vector-legalization.mlir
+++ b/mlir/test/Dialect/ArmSME/vector-legalization.mlir
@@ -491,51 +491,6 @@ func.func @illegal_transpose_no_defining_source_op(%vec: vector<[4]x1xf32>) -> v
// -----
-// CHECK-LABEL: @illegal_shape_cast_to_transpose_2d(
-// CHECK-SAME: %[[VEC:.*]]: vector<[4]x1xf32>)
-func.func @illegal_shape_cast_to_transpose_2d(%vec: vector<[4]x1xf32>) -> vector<1x[4]xf32> {
- // CHECK: vector.transpose %[[VEC]], [1, 0] : vector<[4]x1xf32> to vector<1x[4]xf32>
- %0 = vector.shape_cast %vec : vector<[4]x1xf32> to vector<1x[4]xf32>
- return %0 : vector<1x[4]xf32>
-}
-
-// -----
-
-// CHECK-LABEL: @illegal_shape_cast_to_transpose_1d(
-// CHECK-SAME: %[[VEC:.*]]: vector<[4]x1xf32>)
-func.func @illegal_shape_cast_to_transpose_1d(%vec: vector<[4]x1xf32>) -> vector<[4]xf32> {
- // CHECK: %[[TRANSPOSE:.*]] = vector.transpose %[[VEC]], [1, 0] : vector<[4]x1xf32> to vector<1x[4]xf32>
- // CHECK: vector.shape_cast %[[TRANSPOSE]] : vector<1x[4]xf32> to vector<[4]xf32>
- %0 = vector.shape_cast %vec : vector<[4]x1xf32> to vector<[4]xf32>
- return %0 : vector<[4]xf32>
-}
-
-// -----
-
-// CHECK-LABEL: @lift_illegal_2d_shape_cast_to_memory
-func.func @lift_illegal_2d_shape_cast_to_memory(%a: index, %b: index, %memref: memref<?x?xf32>) -> vector<1x[4]xf32> {
- // CHECK: vector.transfer_read {{.*}} : memref<?x?xf32, {{.*}}>, vector<1x[4]xf32>
- // CHECK-NOT: vector.shape_cast
- %pad = arith.constant 0.0 : f32
- %illegalRead = vector.transfer_read %memref[%a, %b], %pad {in_bounds = [false, true]}: memref<?x?xf32>, vector<[4]x1xf32>
- %cast = vector.shape_cast %illegalRead : vector<[4]x1xf32> to vector<1x[4]xf32>
- return %cast : vector<1x[4]xf32>
-}
-
-// -----
-
-// CHECK-LABEL: @lift_illegal_1d_shape_cast_to_memory
-func.func @lift_illegal_1d_shape_cast_to_memory(%a: index, %b: index, %memref: memref<?x?xf32>) -> vector<[4]xf32> {
- // CHECK: vector.transfer_read {{.*}} : memref<?x?xf32, {{.*}}>, vector<1x[4]xf32>
- // CHECK-NOT: vector.shape_cast {{.*}} : vector<[4]x1xf32> to vector<[4]xf32>
- %pad = arith.constant 0.0 : f32
- %illegalRead = vector.transfer_read %memref[%a, %b], %pad {in_bounds = [false, true]}: memref<?x?xf32>, vector<[4]x1xf32>
- %cast = vector.shape_cast %illegalRead : vector<[4]x1xf32> to vector<[4]xf32>
- return %cast : vector<[4]xf32>
-}
-
-// -----
-
// CHECK-LABEL: @multi_tile_splat
func.func @multi_tile_splat() -> vector<[8]x[8]xi32>
{
diff --git a/mlir/test/Dialect/Vector/canonicalize/vector-transpose.mlir b/mlir/test/Dialect/Vector/canonicalize/vector-transpose.mlir
index c84aea6609665..f1e1c5e896c66 100644
--- a/mlir/test/Dialect/Vector/canonicalize/vector-transpose.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize/vector-transpose.mlir
@@ -165,6 +165,25 @@ func.func @shape_cast_of_transpose(%arg : vector<1x4x4x1x1xi8>) -> vector<4x4xi8
// -----
+// In this test, the permutation maps the non-unit dimensions (1 and 2) as follows:
+// 1 -> 0
+// 2 -> 4
+// Because 0 < 4, this permutation is order preserving and effectively a shape_cast.
+// (same as the example above, but one of the dims is scalable)
+// CHECK-LABEL: @shape_cast_of_transpose_scalable
+// CHECK-SAME: %[[ARG:.*]]: vector<1x[4]x4x1x1xi8>) -> vector<[4]x4xi8> {
+// CHECK: %[[SHAPE_CAST:.*]] = vector.shape_cast %[[ARG]] :
+// CHECK-SAME: vector<1x[4]x4x1x1xi8> to vector<[4]x4xi8>
+// CHECK: return %[[SHAPE_CAST]] : vector<[4]x4xi8>
+func.func @shape_cast_of_transpose_scalable(%arg : vector<1x[4]x4x1x1xi8>) -> vector<[4]x4xi8> {
+ %0 = vector.transpose %arg, [1, 0, 3, 4, 2]
+ : vector<1x[4]x4x1x1xi8> to vector<[4]x1x1x1x4xi8>
+ %1 = vector.shape_cast %0 : vector<[4]x1x1x1x4xi8> to vector<[4]x4xi8>
+ return %1 : vector<[4]x4xi8>
+}
+
+// -----
+
// In this test, the mapping of non-unit dimensions (1 and 2) is as follows:
// 1 -> 2
// 2 -> 1
@@ -184,36 +203,10 @@ func.func @negative_shape_cast_of_transpose(%arg : vector<1x4x4x1xi8>) -> vector
// -----
-// Currently the conversion shape_cast(transpose) -> shape_cast is disabled for
-// scalable vectors because of a bad interaction with ConvertIllegalShapeCastOpsToTransposes.
-// CHECK-LABEL: @negative_shape_cast_of_transpose_scalable
-// CHECK: vector.transpose
-// CHECK: vector.shape_cast
-func.func @negative_shape_cast_of_transpose_scalable(%arg : vector<[4]x1xi8>) -> vector<[4]xi8> {
- %0 = vector.transpose %arg, [1, 0] : vector<[4]x1xi8> to vector<1x[4]xi8>
- %1 = vector.shape_cast %0 : vector<1x[4]xi8> to vector<[4]xi8>
- return %1 : vector<[4]xi8>
-}
-
-// -----
-
/// +--------------------------------------------------------------------------
/// Tests of FoldTransposeShapeCast: transpose(shape_cast) -> shape_cast
/// +--------------------------------------------------------------------------
-// The conversion transpose(shape_cast) -> shape_cast is not disabled for scalable
-// vectors.
-// CHECK-LABEL: @transpose_of_shape_cast_scalable
-// CHECK: vector.shape_cast
-// CHECK-SAME: vector<[4]xi8> to vector<[4]x1xi8>
-func.func @transpose_of_shape_cast_scalable(%arg : vector<[4]xi8>) -> vector<[4]x1xi8> {
- %0 = vector.shape_cast %arg : vector<[4]xi8> to vector<1x[4]xi8>
- %1 = vector.transpose %0, [1, 0] : vector<1x[4]xi8> to vector<[4]x1xi8>
- return %1 : vector<[4]x1xi8>
-}
-
-// -----
-
// A transpose that is 'order preserving' can be treated like a shape_cast.
// CHECK-LABEL: @transpose_of_shape_cast
// CHECK-SAME: %[[ARG:.*]]: vector<2x3x1x1xi8>) -> vector<6x1x1xi8> {
@@ -229,11 +222,26 @@ func.func @transpose_of_shape_cast(%arg : vector<2x3x1x1xi8>) -> vector<6x1x1xi
// -----
-// Scalable dimensions should be treated as non-unit dimensions.
// CHECK-LABEL: @transpose_of_shape_cast_scalable
+// CHECK-SAME: %[[ARG:.*]]: vector<[2]x3x1x1xi8>) -> vector<[6]x1x1xi8> {
+// CHECK: %[[SHAPE_CAST:.*]] = vector.shape_cast %[[ARG]] :
+// CHECK-SAME: vector<[2]x3x1x1xi8> to vector<[6]x1x1xi8>
+// CHECK: return %[[SHAPE_CAST]] : vector<[6]x1x1xi8>
+func.func @transpose_of_shape_cast_scalable(%arg : vector<[2]x3x1x1xi8>) -> vector<[6]x1x1xi8> {
+ %0 = vector.shape_cast %arg : vector<[2]x3x1x1xi8> to vector<[6]x1x1xi8>
+ %1 = vector.transpose %0, [0, 2, 1]
+ : vector<[6]x1x1xi8> to vector<[6]x1x1xi8>
+ return %1 : vector<[6]x1x1xi8>
+}
+
+// -----
+
+// Scalable unit dimensions (i.e. [1]) should be treated as non-unit dimensions
+// (hence no folding).
+// CHECK-LABEL: @negative_transpose_of_shape_cast_scalable_unit
// CHECK: vector.shape_cast
// CHECK: vector.transpose
-func.func @transpose_of_shape_cast_scalable_unit(%arg : vector<[1]x4x1xi8>) -> vector<4x[1]xi8> {
+func.func @negative_transpose_of_shape_cast_scalable_unit(%arg : vector<[1]x4x1xi8>) -> vector<4x[1]xi8> {
%0 = vector.shape_cast %arg : vector<[1]x4x1xi8> to vector<[1]x4xi8>
%1 = vector.transpose %0, [1, 0] : vector<[1]x4xi8> to vector<4x[1]xi8>
return %1 : vector<4x[1]xi8>
From d04d335a2ef48a433cb661665edd3e1ef47e7a04 Mon Sep 17 00:00:00 2001
From: Andrzej Warzynski <andrzej.warzynski at arm.com>
Date: Wed, 18 Jun 2025 09:09:52 +0100
Subject: [PATCH 2/2] fixup! [mlir][ArmSME] Remove
`ConvertIllegalShapeCastOpsToTransposes`
Add LowerColumnTransferReadToLoops. Note: this is to address Ben's
comment here:
* https://github.com/llvm/llvm-project/pull/139706/files#r2088605443
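For reference, the new pattern targets column reads like the one below (taken
from the `@xfer_read_scalable_column` test added in this commit), i.e. a 2D
transfer_read with a scalable leading dim and a fixed trailing unit dim:
```mlir
%read = vector.transfer_read %src[%a, %b], %pad
  : memref<?x?xf32>, vector<[4]x1xf32>
```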
---
.../ArmSME/Transforms/VectorLegalization.cpp | 117 +++++++++++++++++-
.../Dialect/ArmSME/vector-legalization.mlir | 56 +++++++++
2 files changed, 170 insertions(+), 3 deletions(-)
diff --git a/mlir/lib/Dialect/ArmSME/Transforms/VectorLegalization.cpp b/mlir/lib/Dialect/ArmSME/Transforms/VectorLegalization.cpp
index 51750f0bb9694..1e8e1265affa0 100644
--- a/mlir/lib/Dialect/ArmSME/Transforms/VectorLegalization.cpp
+++ b/mlir/lib/Dialect/ArmSME/Transforms/VectorLegalization.cpp
@@ -867,6 +867,116 @@ struct LowerIllegalTransposeStoreViaZA
}
};
+/// Lower `vector.transfer_read` of a scalable column to `scf::for`
+///
+/// Lowers a "read" of a scalable column from a MemRef, for which there is no
+/// hardware operation that we could use, to a loop over the rows that loads
+/// one element at a time.
+///
+/// BEFORE:
+/// ```
+/// %res = vector.transfer_read %mem[%a, %b] (...)
+/// : memref<?x?xf32>, vector<[4]x1xf32>
+/// ```
+///
+/// AFTER:
+/// ```
+/// %cst = arith.constant (...) : vector<[4]xf32>
+/// %vscale = vector.vscale
+/// %c4_vscale = arith.muli %vscale, %c4 : index
+/// %scf = scf.for %i = %c0 to %c4_vscale step %c1 iter_args(%acc = %cst)
+///   -> (vector<[4]xf32>) {
+///
+/// %load = memref.load %mem[%i + %a, %b] : memref<?x?xf32>
+/// %vec = vector.insert %load, %acc [%i] : f32 into vector<[4]xf32>
+/// scf.yield %vec : vector<[4]xf32>
+/// }
+/// %res = vector.shape_cast %scf : vector<[4]xf32> to vector<[4]x1xf32>
+/// ```
+///
+/// TODO: This transformation isn't specific to SME - move it to the SVE
+/// dialect.
+/// TODO: Check the in_bounds attribute and generate vector.maskedload if
+/// required.
+struct LowerColumnTransferReadToLoops
+ : public OpRewritePattern<vector::TransferReadOp> {
+ using OpRewritePattern::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(vector::TransferReadOp readOp,
+ PatternRewriter &rewriter) const override {
+ // NOTE: This is a fairly low-level transformation, so we shouldn't be
+ // adding support for Tensors without good rationale.
+ if (readOp.hasPureTensorSemantics())
+ return rewriter.notifyMatchFailure(
+ readOp, "Tensor semantics are unsupported (either bufferize or "
+ "extend this pattern)");
+
+ auto resType = readOp.getVectorType();
+
+ if (resType.getRank() != 2)
+ return rewriter.notifyMatchFailure(readOp,
+ "Only 2D vectors are supported!");
+
+ if (resType.getShape()[1] != 1)
+ return rewriter.notifyMatchFailure(
+ readOp, "The trailing output dim is != 1 (not supported ATM)");
+
+ if (!resType.getScalableDims()[0] || resType.getScalableDims()[1])
+ return rewriter.notifyMatchFailure(
+ readOp, "Expected the leading dim to be scalable and the trailing "
+ "dim to be fixed.");
+
+ // Create new result type - similar to the original vector with the
+ // trailing unit dim collapsed.
+ int64_t numRows = resType.getShape()[0];
+ VectorType newResType = VectorType::get(numRows, resType.getElementType(),
+ /*scalableDims=*/{true});
+
+ // Create a loop over all rows and load one element at a time.
+ auto loc = readOp.getLoc();
+ auto lowerBound = rewriter.create<arith::ConstantIndexOp>(loc, 0);
+ auto createVscaleMultiple =
+ vector::makeVscaleConstantBuilder(rewriter, loc);
+ auto upperBound = createVscaleMultiple(numRows);
+ auto step = rewriter.create<arith::ConstantIndexOp>(loc, 1);
+ Value init = rewriter.create<arith::ConstantOp>(
+ loc, newResType, DenseElementsAttr::get(newResType, 0.0f));
+
+ scf::ForOp loadLoop;
+ {
+ OpBuilder::InsertionGuard g(rewriter);
+ loadLoop = rewriter.create<scf::ForOp>(loc, lowerBound, upperBound, step,
+ ValueRange{init});
+ rewriter.setInsertionPointToStart(loadLoop.getBody());
+
+ auto tileSliceIndex = loadLoop.getInductionVar();
+
+ auto idx0 = rewriter.create<arith::AddIOp>(loc, tileSliceIndex,
+ readOp.getIndices()[0]);
+ auto idx1 = readOp.getIndices()[1];
+
+ Value scalar = rewriter.create<memref::LoadOp>(
+ loc, readOp.getBase(), SmallVector<Value>({idx0, idx1}));
+
+ Operation *updateInit = rewriter.create<vector::InsertOp>(
+ loc, scalar, loadLoop.getRegionIterArg(0), tileSliceIndex);
+
+ rewriter.create<scf::YieldOp>(loc, updateInit->getResult(0));
+ }
+
+ // The read operation has been "legalized", but since the original result
+ // type was a 2D vector, we need to cast before returning the result. This
+    // ShapeCast should cancel out with some other ShapeCast (i.e. it's a
+ // no-op).
+ auto sc = rewriter.create<vector::ShapeCastOp>(
+ loc, readOp.getResult().getType(), loadLoop.getResult(0));
+
+ rewriter.replaceOp(readOp, sc);
+
+ return success();
+ }
+};
+
struct VectorLegalizationPass
: public arm_sme::impl::VectorLegalizationBase<VectorLegalizationPass> {
void runOnOperation() override {
@@ -888,9 +998,10 @@ struct VectorLegalizationPass
// Apply preprocessing patterns.
RewritePatternSet rewritePatterns(context);
- rewritePatterns.add<FoldExtractFromVectorOfSMELikeCreateMasks,
- LiftIllegalVectorTransposeToMemory,
- LowerIllegalTransposeStoreViaZA>(context);
+ rewritePatterns
+ .add<FoldExtractFromVectorOfSMELikeCreateMasks,
+ LowerColumnTransferReadToLoops, LiftIllegalVectorTransposeToMemory,
+ LowerIllegalTransposeStoreViaZA>(context);
if (failed(
applyPatternsGreedily(getOperation(), std::move(rewritePatterns))))
return signalPassFailure();
diff --git a/mlir/test/Dialect/ArmSME/vector-legalization.mlir b/mlir/test/Dialect/ArmSME/vector-legalization.mlir
index 6e6615c243d2a..6cdf576272ebc 100644
--- a/mlir/test/Dialect/ArmSME/vector-legalization.mlir
+++ b/mlir/test/Dialect/ArmSME/vector-legalization.mlir
@@ -611,3 +611,59 @@ func.func @vector_mask_without_maskable_op(%mask: vector<16x2xi1>, %vec: vector<
%0 = vector.mask %mask { vector.yield %vec : vector<16x16xf32> } : vector<16x2xi1> -> vector<16x16xf32>
return %0 : vector<16x16xf32>
}
+
+// -----
+
+//=============================================================================
+// 1D examples - to be moved to the SVE dialect
+//=============================================================================
+
+/// TODO: Handle in_bounds
+
+// CHECK-LABEL: func.func @xfer_read_scalable_column(
+// CHECK-SAME: %[[IDX_0:[a-zA-Z0-9]+]]: index,
+// CHECK-SAME: %[[IDX_1:[a-zA-Z0-9]+]]: index,
+// CHECK-SAME: %[[PAD:.*]]: f32,
+// CHECK-SAME: %[[SRC:.*]]: memref<?x?xf32>) -> vector<[4]x1xf32> {
+func.func @xfer_read_scalable_column(%a: index, %b: index, %pad: f32, %src: memref<?x?xf32>) -> (vector<[4]x1xf32>) {
+ // CHECK: %[[INIT:.*]] = arith.constant dense<0.000000e+00> : vector<[4]xf32>
+ // CHECK: %[[STEP:.*]] = arith.constant 1 : index
+ // CHECK: %[[C4:.*]] = arith.constant 4 : index
+ // CHECK: %[[LB:.*]] = arith.constant 0 : index
+ // CHECK: %[[VSCALE:.*]] = vector.vscale
+ // CHECK: %[[C4_VSCALE:.*]] = arith.muli %[[VSCALE]], %[[C4]] : index
+
+ // <scf.for>
+ // CHECK: %[[SCF:.*]] = scf.for %[[IND_VAR:.*]] = %[[LB]] to %[[C4_VSCALE]] step %[[STEP]] iter_args(%[[SCF_RES:.*]] = %[[INIT]]) -> (vector<[4]xf32>) {
+ // CHECK: %[[IDX_0_UPDATED:.*]] = arith.addi %[[IND_VAR]], %[[IDX_0]] : index
+ // CHECK: %[[VAL_10:.*]] = memref.load %[[SRC]][%[[IDX_0_UPDATED]], %[[IDX_1]]] : memref<?x?xf32>
+ // CHECK: %[[RES_UPDATED:.*]] = vector.insert %[[VAL_10]], %[[SCF_RES]] [%[[IND_VAR]]] : f32 into vector<[4]xf32>
+ // CHECK: scf.yield %[[RES_UPDATED]] : vector<[4]xf32>
+ // CHECK: }
+
+ // <shape-cast>
+ // CHECK: %[[SC:.*]] = vector.shape_cast %[[SCF]] : vector<[4]xf32> to vector<[4]x1xf32>
+ // CHECK: return %[[SC]]
+ %read = vector.transfer_read %src[%a, %b], %pad : memref<?x?xf32>, vector<[4]x1xf32>
+ return %read : vector<[4]x1xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func.func @negative_xfer_read_scalable_column_x2
+func.func @negative_xfer_read_scalable_column_x2(%a: index, %b: index, %pad: f32, %src: memref<?x?xf32>) -> (vector<[4]x2xf32>) {
+ // CHECK-NOT: scf.for
+ // CHECK-NOT: memref.load
+ %read = vector.transfer_read %src[%a, %b], %pad : memref<?x?xf32>, vector<[4]x2xf32>
+ return %read : vector<[4]x2xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func.func @negative_xfer_read_scalable_column_scalable_trailing_dim
+func.func @negative_xfer_read_scalable_column_scalable_trailing_dim(%a: index, %b: index, %pad: f32, %src: memref<?x?xf32>) -> (vector<4x[1]xf32>) {
+ // CHECK-NOT: scf.for
+ // CHECK-NOT: memref.load
+ %read = vector.transfer_read %src[%a, %b], %pad : memref<?x?xf32>, vector<4x[1]xf32>
+ return %read : vector<4x[1]xf32>
+}