[Mlir-commits] [mlir] [wip][vector][mlir] Canonicalize to shape_cast where possible (PR #140583)
James Newling
llvmlistbot at llvm.org
Mon May 19 11:40:55 PDT 2025
https://github.com/newling updated https://github.com/llvm/llvm-project/pull/140583
>From 0909b5bbe8e72aa5e12c6e38b58f48352f44d382 Mon Sep 17 00:00:00 2001
From: James Newling <james.newling at gmail.com>
Date: Mon, 19 May 2025 10:22:27 -0700
Subject: [PATCH 1/3] use shape_cast as canonical type for extract broadcast
and transpose
---
mlir/lib/Dialect/Vector/IR/VectorOps.cpp | 226 ++++++++++--------
mlir/test/Dialect/Vector/canonicalize.mlir | 27 +--
.../Dialect/Vector/canonicalize/playtime.mlir | 0
.../canonicalize/vector-shape-cast.mlir | 141 +++++++++++
.../Vector/canonicalize/vector-transpose.mlir | 60 -----
.../drop-unit-dims-with-shape-cast.mlir | 12 -
.../vector-transfer-to-vector-load-store.mlir | 12 +-
.../Vector/vector-warp-distribute.mlir | 8 +-
8 files changed, 291 insertions(+), 195 deletions(-)
create mode 100644 mlir/test/Dialect/Vector/canonicalize/playtime.mlir
create mode 100644 mlir/test/Dialect/Vector/canonicalize/vector-shape-cast.mlir
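For reference before the diff itself, here is a minimal, self-contained illustration (not part of the patch) of the inputs that the three new canonicalization patterns rewrite to vector.shape_cast. The ops and types are taken from the doc comments and tests added below; the function names are just for illustration.

```mlir
// ExtractToShapeCast: becomes
//   vector.shape_cast %arg0 : vector<1x4xf32> to vector<4xf32>
func.func @extract_example(%arg0 : vector<1x4xf32>) -> vector<4xf32> {
  %0 = vector.extract %arg0[0] : vector<4xf32> from vector<1x4xf32>
  return %0 : vector<4xf32>
}

// BroadcastToShapeCast: becomes
//   vector.shape_cast %arg0 : vector<4xi8> to vector<1x1x4xi8>
func.func @broadcast_example(%arg0 : vector<4xi8>) -> vector<1x1x4xi8> {
  %0 = vector.broadcast %arg0 : vector<4xi8> to vector<1x1x4xi8>
  return %0 : vector<1x1x4xi8>
}

// TransposeToShapeCast (only for order-preserving transposes): becomes
//   vector.shape_cast %arg0 : vector<2x1x2xf32> to vector<2x2x1xf32>
func.func @transpose_example(%arg0 : vector<2x1x2xf32>) -> vector<2x2x1xf32> {
  %0 = vector.transpose %arg0, [0, 2, 1] : vector<2x1x2xf32> to vector<2x2x1xf32>
  return %0 : vector<2x2x1xf32>
}
```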
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 7ae43b64a5deb..325870fcfaea7 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -2344,11 +2344,45 @@ LogicalResult foldExtractFromFromElements(ExtractOp extractOp,
return success();
}
+/// For example,
+/// ```
+/// %0 = vector.extract %arg0[0] : vector<4xf32> from vector<1x4xf32>
+/// ```
+/// becomes
+/// ```
+/// %0 = vector.shape_cast %arg0 : vector<1x4xf32> to vector<4xf32>
+/// ```
+struct ExtractToShapeCast final : public OpRewritePattern<vector::ExtractOp> {
+ using OpRewritePattern::OpRewritePattern;
+ LogicalResult matchAndRewrite(vector::ExtractOp extractOp,
+ PatternRewriter &rewriter) const override {
+ VectorType sourceType = extractOp.getSourceVectorType();
+ VectorType outType = dyn_cast<VectorType>(extractOp.getType());
+ if (!outType)
+ return failure();
+
+ // Negative values in `position` indicate poison; such an extract cannot be
+ // converted to a shape_cast.
+ if (llvm::any_of(extractOp.getMixedPosition(),
+ [](OpFoldResult v) { return !isConstantIntValue(v, 0); }))
+ return failure();
+
+ if (sourceType.getNumElements() != outType.getNumElements())
+ return failure();
+
+ rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(extractOp, outType,
+ extractOp.getVector());
+ return success();
+ }
+};
+
} // namespace
void ExtractOp::getCanonicalizationPatterns(RewritePatternSet &results,
MLIRContext *context) {
- results.add<ExtractOpFromBroadcast, ExtractOpFromCreateMask>(context);
+ results
+ .add<ExtractOpFromBroadcast, ExtractOpFromCreateMask, ExtractToShapeCast>(
+ context);
results.add(foldExtractFromShapeCastToShapeCast);
results.add(foldExtractFromFromElements);
}
@@ -2651,13 +2685,40 @@ struct BroadcastFolder : public OpRewritePattern<BroadcastOp> {
return success();
}
};
+
+/// For example,
+/// ```
+/// %0 = vector.broadcast %arg0 : vector<4xi8> to vector<1x1x4xi8>
+/// ```
+/// becomes
+/// ```
+/// %0 = vector.shape_cast %arg0 : vector<4xi8> to vector<1x1x4xi8>
+/// ```
+struct BroadcastToShapeCast final
+ : public OpRewritePattern<vector::BroadcastOp> {
+ using OpRewritePattern::OpRewritePattern;
+ LogicalResult matchAndRewrite(vector::BroadcastOp broadcast,
+ PatternRewriter &rewriter) const override {
+ auto sourceType = dyn_cast<VectorType>(broadcast.getSourceType());
+ if (!sourceType) {
+ return rewriter.notifyMatchFailure(
+ broadcast, "source is a scalar, shape_cast doesn't support scalar");
+ }
+
+ VectorType outType = broadcast.getType();
+ if (sourceType.getNumElements() != outType.getNumElements())
+ return failure();
+
+ rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(broadcast, outType,
+ broadcast.getSource());
+ return success();
+ }
+};
} // namespace
void BroadcastOp::getCanonicalizationPatterns(RewritePatternSet &results,
MLIRContext *context) {
- // BroadcastToShapeCast is not a default canonicalization, it is opt-in by
- // calling `populateCastAwayVectorLeadingOneDimPatterns`
- results.add<BroadcastFolder>(context);
+ results.add<BroadcastFolder, BroadcastToShapeCast>(context);
}
//===----------------------------------------------------------------------===//
@@ -5573,30 +5634,6 @@ LogicalResult ShapeCastOp::verify() {
return success();
}
-/// Return true if `transpose` does not permute a pair of non-unit dims.
-/// By `order preserving` we mean that the flattened versions of the input and
-/// output vectors are (numerically) identical. In other words `transpose` is
-/// effectively a shape cast.
-static bool isOrderPreserving(TransposeOp transpose) {
- ArrayRef<int64_t> permutation = transpose.getPermutation();
- VectorType sourceType = transpose.getSourceVectorType();
- ArrayRef<int64_t> inShape = sourceType.getShape();
- ArrayRef<bool> inDimIsScalable = sourceType.getScalableDims();
- auto isNonScalableUnitDim = [&](int64_t dim) {
- return inShape[dim] == 1 && !inDimIsScalable[dim];
- };
- int64_t current = 0;
- for (auto p : permutation) {
- if (!isNonScalableUnitDim(p)) {
- if (p < current) {
- return false;
- }
- current = p;
- }
- }
- return true;
-}
-
OpFoldResult ShapeCastOp::fold(FoldAdaptor adaptor) {
VectorType resultType = getType();
@@ -5611,33 +5648,6 @@ OpFoldResult ShapeCastOp::fold(FoldAdaptor adaptor) {
return getResult();
}
- // shape_cast(transpose(x)) -> shape_cast(x)
- if (auto transpose = getSource().getDefiningOp<TransposeOp>()) {
- // This folder does
- // shape_cast(transpose) -> shape_cast
- // But another pattern, ConvertIllegalShapeCastOpsToTransposes, does
- // shape_cast -> shape_cast(transpose)
- // i.e. the complete opposite. When paired, these 2 patterns can cause
- // infinite cycles in pattern rewriting.
- // ConvertIllegalShapeCastOpsToTransposes only matches on scalable
- // vectors, so by disabling this folder for scalable vectors the
- // cycle is avoided.
- // TODO: Check if ConvertIllegalShapeCastOpsToTransposes is
- // still needed. If it's not, then we can fold here.
- if (!transpose.getType().isScalable() && isOrderPreserving(transpose)) {
- setOperand(transpose.getVector());
- return getResult();
- }
- return {};
- }
-
- // Y = shape_cast(broadcast(X))
- // -> X, if X and Y have same type
- if (auto bcastOp = getSource().getDefiningOp<BroadcastOp>()) {
- if (bcastOp.getSourceType() == resultType)
- return bcastOp.getSource();
- }
-
// shape_cast(constant) -> constant
if (auto splatAttr =
llvm::dyn_cast_if_present<SplatElementsAttr>(adaptor.getSource()))
@@ -5993,21 +6003,6 @@ OpFoldResult vector::TransposeOp::fold(FoldAdaptor adaptor) {
if (llvm::dyn_cast_if_present<ub::PoisonAttr>(adaptor.getVector()))
return ub::PoisonAttr::get(getContext());
- // Eliminate identity transposes, and more generally any transposes that
- // preserves the shape without permuting elements.
- //
- // Examples of what to fold:
- // %0 = vector.transpose %arg, [0, 1] : vector<1x1xi8> to vector<1x1xi8>
- // %0 = vector.transpose %arg, [0, 1] : vector<2x2xi8> to vector<2x2xi8>
- // %0 = vector.transpose %arg, [1, 0] : vector<1x1xi8> to vector<1x1xi8>
- //
- // Example of what NOT to fold:
- // %0 = vector.transpose %arg, [1, 0] : vector<2x2xi8> to vector<2x2xi8>
- //
- if (getSourceVectorType() == getResultVectorType() &&
- isOrderPreserving(*this))
- return getVector();
-
return {};
}
@@ -6127,32 +6122,6 @@ class FoldTransposeCreateMask final : public OpRewritePattern<TransposeOp> {
}
};
-/// Folds transpose(shape_cast) into a new shape_cast.
-class FoldTransposeShapeCast final : public OpRewritePattern<TransposeOp> {
-public:
- using OpRewritePattern::OpRewritePattern;
-
- LogicalResult matchAndRewrite(TransposeOp transposeOp,
- PatternRewriter &rewriter) const override {
- auto shapeCastOp =
- transposeOp.getVector().getDefiningOp<vector::ShapeCastOp>();
- if (!shapeCastOp)
- return failure();
- if (!isOrderPreserving(transposeOp))
- return failure();
-
- VectorType resultType = transposeOp.getType();
-
- // We don't need to check isValidShapeCast at this point, because it is
- // guaranteed that merging the transpose into the the shape_cast is a valid
- // shape_cast, because the transpose just inserts/removes ones.
-
- rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(transposeOp, resultType,
- shapeCastOp.getSource());
- return success();
- }
-};
-
/// Folds transpose(broadcast(x)) to broadcast(x) if the transpose is
/// 'order preserving', where 'order preserving' means the flattened
/// inputs and outputs of the transpose have identical (numerical) values.
@@ -6248,12 +6217,73 @@ class FoldTransposeBroadcast : public OpRewritePattern<vector::TransposeOp> {
}
};
+/// Return true if `transpose` does not permute a pair of non-unit dims.
+/// By `order preserving` we mean that the flattened versions of the input and
+/// output vectors are (numerically) identical. In other words `transpose` is
+/// effectively a shape cast.
+static bool isOrderPreserving(TransposeOp transpose) {
+ ArrayRef<int64_t> permutation = transpose.getPermutation();
+ VectorType sourceType = transpose.getSourceVectorType();
+ ArrayRef<int64_t> inShape = sourceType.getShape();
+ ArrayRef<bool> inDimIsScalable = sourceType.getScalableDims();
+ auto isNonScalableUnitDim = [&](int64_t dim) {
+ return inShape[dim] == 1 && !inDimIsScalable[dim];
+ };
+ int64_t current = 0;
+ for (auto p : permutation) {
+ if (!isNonScalableUnitDim(p)) {
+ if (p < current) {
+ return false;
+ }
+ current = p;
+ }
+ }
+ return true;
+}
+
+/// For example,
+/// ```
+/// %0 = vector.transpose %arg0, [0, 2, 1] :
+/// vector<2x1x2xf32> to vector<2x2x1xf32>
+/// ```
+/// becomes
+/// ```
+/// %0 = vector.shape_cast %arg0 :
+/// vector<2x1x2xf32> to vector<2x2x1xf32>
+/// ```
+struct TransposeToShapeCast final
+ : public OpRewritePattern<vector::TransposeOp> {
+ using OpRewritePattern::OpRewritePattern;
+ LogicalResult matchAndRewrite(vector::TransposeOp transpose,
+ PatternRewriter &rewriter) const override {
+
+ // This pattern rewrites
+ // transpose -> shape_cast
+ // but another pattern, ConvertIllegalShapeCastOpsToTransposes, rewrites
+ // shape_cast -> shape_cast(transpose)
+ // i.e. the complete opposite. When paired, these 2 patterns can cause
+ // infinite cycles in pattern rewriting.
+ // ConvertIllegalShapeCastOpsToTransposes only matches on scalable
+ // vectors, so by disabling this pattern for scalable vectors the
+ // cycle is avoided.
+ // TODO: Check if ConvertIllegalShapeCastOpsToTransposes is still
+ // needed. If it's not, this pattern can also handle scalable vectors.
+ if (!isOrderPreserving(transpose) || transpose.getType().isScalable()) {
+ return rewriter.notifyMatchFailure(
+ transpose, "scalable, or not order preserving (not semantically a 'copy')");
+ }
+ rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(
+ transpose, transpose.getType(), transpose.getVector());
+ return success();
+ }
+};
+
} // namespace
void vector::TransposeOp::getCanonicalizationPatterns(
RewritePatternSet &results, MLIRContext *context) {
- results.add<FoldTransposeCreateMask, FoldTransposeShapeCast, TransposeFolder,
- FoldTransposeSplat, FoldTransposeBroadcast>(context);
+ results.add<FoldTransposeCreateMask, TransposeFolder, FoldTransposeSplat,
+ FoldTransposeBroadcast, TransposeToShapeCast>(context);
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir
index 974f4506a2ef0..45547946ce70a 100644
--- a/mlir/test/Dialect/Vector/canonicalize.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize.mlir
@@ -754,11 +754,11 @@ func.func @fold_extract_broadcast_0dvec_input_scalar_output(%a : vector<f32>,
// -----
// CHECK-LABEL: fold_extract_broadcast_negative
-// CHECK: vector.broadcast %{{.*}} : vector<1x1xf32> to vector<1x1x4xf32>
-// CHECK: vector.extract %{{.*}}[0, 0] : vector<4xf32> from vector<1x1x4xf32>
+// CHECK: vector.broadcast %{{.*}} : vector<1x1xf32> to vector<1x2x4xf32>
+// CHECK: vector.extract %{{.*}}[0, 0] : vector<4xf32> from vector<1x2x4xf32>
func.func @fold_extract_broadcast_negative(%a : vector<1x1xf32>) -> vector<4xf32> {
- %b = vector.broadcast %a : vector<1x1xf32> to vector<1x1x4xf32>
- %r = vector.extract %b[0, 0] : vector<4xf32> from vector<1x1x4xf32>
+ %b = vector.broadcast %a : vector<1x1xf32> to vector<1x2x4xf32>
+ %r = vector.extract %b[0, 0] : vector<4xf32> from vector<1x2x4xf32>
return %r : vector<4xf32>
}
@@ -797,8 +797,8 @@ func.func @fold_extract_broadcast_dim1_broadcasting(%a : vector<2x1xf32>,
// rank(extract_output) < rank(broadcast_input)
func.func @fold_extract_broadcast_to_lower_rank(%a : vector<2x4xf32>,
%idx0 : index, %idx1 : index) -> vector<4xf32> {
- %b = vector.broadcast %a : vector<2x4xf32> to vector<1x2x4xf32>
- %r = vector.extract %b[%idx0, %idx1] : vector<4xf32> from vector<1x2x4xf32>
+ %b = vector.broadcast %a : vector<2x4xf32> to vector<2x2x4xf32>
+ %r = vector.extract %b[%idx0, %idx1] : vector<4xf32> from vector<2x2x4xf32>
return %r : vector<4xf32>
}
@@ -1840,12 +1840,12 @@ func.func @extract_strided_splat(%arg0: f16) -> vector<2x4xf16> {
// -----
-// CHECK-LABEL: func @insert_extract_to_broadcast
+// CHECK-LABEL: func @insert_extract_to_shape_cast
// CHECK-SAME: (%[[ARG0:.*]]: vector<1x1x4xf32>, %[[ARG1:.*]]: vector<4xf32>)
-// CHECK: %[[V0:.*]] = vector.extract %[[ARG0]][0, 0] : vector<4xf32> from vector<1x1x4xf32>
-// CHECK: %[[V1:.*]] = vector.broadcast %[[ARG1]] : vector<4xf32> to vector<1x1x4xf32>
+// CHECK: %[[V0:.*]] = vector.shape_cast %[[ARG0]] : vector<1x1x4xf32> to vector<4xf32>
+// CHECK: %[[V1:.*]] = vector.shape_cast %[[ARG1]] : vector<4xf32> to vector<1x1x4xf32>
// CHECK: return %[[V0]], %[[V1]] : vector<4xf32>, vector<1x1x4xf32>
-func.func @insert_extract_to_broadcast(%arg0 : vector<1x1x4xf32>,
+func.func @insert_extract_to_shape_cast(%arg0 : vector<1x1x4xf32>,
%arg1 : vector<4xf32>) -> (vector<4xf32>, vector<1x1x4xf32>) {
%0 = vector.extract %arg0[0, 0] : vector<4xf32> from vector<1x1x4xf32>
%1 = vector.insert %arg1, %arg0 [0, 0] : vector<4xf32> into vector<1x1x4xf32>
@@ -2197,7 +2197,7 @@ func.func @shuffle_1d_rhs_poison() -> vector<4xi32> {
// CHECK-LABEL: func @shuffle_canonicalize_0d
func.func @shuffle_canonicalize_0d(%v0 : vector<i32>, %v1 : vector<i32>) -> vector<1xi32> {
- // CHECK: vector.broadcast %{{.*}} : vector<i32> to vector<1xi32>
+ // CHECK: vector.shape_cast %{{.*}} : vector<i32> to vector<1xi32>
%shuffle = vector.shuffle %v0, %v1 [0] : vector<i32>, vector<i32>
return %shuffle : vector<1xi32>
}
@@ -2684,9 +2684,8 @@ func.func @transfer_read_from_rank_reducing_extract_slice(%src: tensor<1x8x8x8xf
// CHECK-LABEL: func.func @extract_from_broadcast
func.func @extract_from_broadcast(%src: vector<1x1x1xf32>) -> vector<1xf32> {
%0 = vector.broadcast %src : vector<1x1x1xf32> to vector<1x1x32x1xf32>
-
- // CHECK-NEXT: %0 = vector.extract {{.*}}[0, 0] : vector<1xf32> from vector<1x1x1xf32>
- // CHECK-NEXT: return %0 : vector<1xf32>
+ // CHECK-NEXT: %[[RES:.*]] = vector.shape_cast{{.*}} vector<1x1x1xf32> to vector<1xf32>
+ // CHECK-NEXT: return %[[RES]] : vector<1xf32>
%1 = vector.extract %0[0, 0, 31] : vector<1xf32> from vector<1x1x32x1xf32>
return %1: vector<1xf32>
}
diff --git a/mlir/test/Dialect/Vector/canonicalize/playtime.mlir b/mlir/test/Dialect/Vector/canonicalize/playtime.mlir
new file mode 100644
index 0000000000000..e69de29bb2d1d
diff --git a/mlir/test/Dialect/Vector/canonicalize/vector-shape-cast.mlir b/mlir/test/Dialect/Vector/canonicalize/vector-shape-cast.mlir
new file mode 100644
index 0000000000000..5495a05d75944
--- /dev/null
+++ b/mlir/test/Dialect/Vector/canonicalize/vector-shape-cast.mlir
@@ -0,0 +1,141 @@
+// RUN: mlir-opt %s -split-input-file -canonicalize | FileCheck %s
+
+// +----------------------------------------
+// Tests of TransposeToShapeCast
+// +----------------------------------------
+
+// CHECK-LABEL: @transpose_to_shape_cast
+// CHECK-SAME: %[[ARG0:.*]]: vector<2x1x2xf32>
+// CHECK-NEXT: %[[SCAST:.*]] = vector.shape_cast %[[ARG0]]
+// CHECK-NEXT: return %[[SCAST]] : vector<2x2x1xf32>
+func.func @transpose_to_shape_cast(%arg0 : vector<2x1x2xf32>) -> vector<2x2x1xf32> {
+ %0 = vector.transpose %arg0, [0, 2, 1] : vector<2x1x2xf32> to vector<2x2x1xf32>
+ return %0 : vector<2x2x1xf32>
+}
+
+
+// -----
+
+// CHECK-LABEL: @negative_transpose_to_shape_cast
+// CHECK-SAME: %[[ARG0:.*]]: vector<2x1x2xf32>
+// CHECK-NEXT: %[[TRANSPOSE:.*]] = vector.transpose %[[ARG0]], [2, 0, 1]
+// CHECK-NEXT: return %[[TRANSPOSE]] : vector<2x2x1xf32>
+func.func @negative_transpose_to_shape_cast(%arg0 : vector<2x1x2xf32>) -> vector<2x2x1xf32> {
+ %0 = vector.transpose %arg0, [2, 0, 1] : vector<2x1x2xf32> to vector<2x2x1xf32>
+ return %0 : vector<2x2x1xf32>
+}
+
+// -----
+
+// +----------------------------------------
+// Tests of BroadcastToShapeCast
+// +----------------------------------------
+
+// CHECK-LABEL: @broadcast_to_shape_cast
+// CHECK-SAME: %[[ARG0:.*]]: vector<4xi8>
+// CHECK-NEXT: %[[SCAST:.*]] = vector.shape_cast %[[ARG0]]
+// CHECK-NEXT: return %[[SCAST]] : vector<1x1x4xi8>
+func.func @broadcast_to_shape_cast(%arg0 : vector<4xi8>) -> vector<1x1x4xi8> {
+ %0 = vector.broadcast %arg0 : vector<4xi8> to vector<1x1x4xi8>
+ return %0 : vector<1x1x4xi8>
+}
+
+// -----
+
+// CHECK-LABEL: @negative_broadcast_to_shape_cast
+// CHECK-NOT: shape_cast
+// CHECK: return
+func.func @negative_broadcast_to_shape_cast(%arg0 : vector<1x4xi8>) -> vector<2x3x4xi8> {
+ %0 = vector.broadcast %arg0 : vector<1x4xi8> to vector<2x3x4xi8>
+ return %0 : vector<2x3x4xi8>
+}
+
+// -----
+
+// CHECK-LABEL: @negative_broadcast_scalar_to_shape_cast
+// CHECK-NOT: shape_cast
+// CHECK: return
+func.func @negative_broadcast_scalar_to_shape_cast(%arg0 : i8) -> vector<1xi8> {
+ %0 = vector.broadcast %arg0 : i8 to vector<1xi8>
+ return %0 : vector<1xi8>
+}
+
+// -----
+
+// The conversion transpose(shape_cast) -> shape_cast is currently disabled for scalable
+// vectors.
+// CHECK-LABEL: @transpose_of_shape_cast_scalable
+// CHECK: vector.shape_cast
+// CHECK: vector.transpose
+func.func @transpose_of_shape_cast_scalable(%arg : vector<[4]xi8>) -> vector<[4]x1xi8> {
+ %0 = vector.shape_cast %arg : vector<[4]xi8> to vector<1x[4]xi8>
+ %1 = vector.transpose %0, [1, 0] : vector<1x[4]xi8> to vector<[4]x1xi8>
+ return %1 : vector<[4]x1xi8>
+}
+
+// -----
+
+// A transpose that is 'order preserving' can be treated like a shape_cast.
+// CHECK-LABEL: @transpose_of_shape_cast
+// CHECK-SAME: %[[ARG:.*]]: vector<2x3x1x1xi8>) -> vector<6x1x1xi8> {
+// CHECK: %[[SHAPE_CAST:.*]] = vector.shape_cast %[[ARG]] :
+// CHECK-SAME: vector<2x3x1x1xi8> to vector<6x1x1xi8>
+// CHECK: return %[[SHAPE_CAST]] : vector<6x1x1xi8>
+func.func @transpose_of_shape_cast(%arg : vector<2x3x1x1xi8>) -> vector<6x1x1xi8> {
+ %0 = vector.shape_cast %arg : vector<2x3x1x1xi8> to vector<6x1x1xi8>
+ %1 = vector.transpose %0, [0, 2, 1]
+ : vector<6x1x1xi8> to vector<6x1x1xi8>
+ return %1 : vector<6x1x1xi8>
+}
+
+// -----
+
+// Scalable dimensions should be treated as non-unit dimensions.
+// CHECK-LABEL: @transpose_of_shape_cast_scalable
+// CHECK: vector.shape_cast
+// CHECK: vector.transpose
+func.func @transpose_of_shape_cast_scalable_unit(%arg : vector<[1]x4x1xi8>) -> vector<4x[1]xi8> {
+ %0 = vector.shape_cast %arg : vector<[1]x4x1xi8> to vector<[1]x4xi8>
+ %1 = vector.transpose %0, [1, 0] : vector<[1]x4xi8> to vector<4x[1]xi8>
+ return %1 : vector<4x[1]xi8>
+}
+
+// -----
+
+// Test of shape_cast (not) folding.
+// CHECK-LABEL: @negative_transpose_of_shape_cast
+// CHECK-SAME: %[[ARG:.*]]: vector<6xi8>) -> vector<2x3xi8> {
+// CHECK: %[[SHAPE_CAST:.*]] = vector.shape_cast %[[ARG]] :
+// CHECK: %[[TRANSPOSE:.*]] = vector.transpose %[[SHAPE_CAST]]
+// CHECK: return %[[TRANSPOSE]] : vector<2x3xi8>
+func.func @negative_transpose_of_shape_cast(%arg : vector<6xi8>) -> vector<2x3xi8> {
+ %0 = vector.shape_cast %arg : vector<6xi8> to vector<3x2xi8>
+ %1 = vector.transpose %0, [1, 0] : vector<3x2xi8> to vector<2x3xi8>
+ return %1 : vector<2x3xi8>
+}
+
+// -----
+
+// +----------------------------------------
+// Tests of ExtractToShapeCast
+// +----------------------------------------
+
+// CHECK-LABEL: @extract_to_shape_cast
+// CHECK-SAME: %[[ARG0:.*]]: vector<1x4xf32>
+// CHECK-NEXT: %[[SCAST:.*]] = vector.shape_cast %[[ARG0]]
+// CHECK-NEXT: return %[[SCAST]] : vector<4xf32>
+func.func @extract_to_shape_cast(%arg0 : vector<1x4xf32>) -> vector<4xf32> {
+ %0 = vector.extract %arg0[0] : vector<4xf32> from vector<1x4xf32>
+ return %0 : vector<4xf32>
+}
+
+// -----
+
+// In this example, %arg1 might be negative, indicating poison.
+// CHECK-LABEL: @negative_extract_to_shape_cast
+// CHECK-NOT: shape_cast
+func.func @negative_extract_to_shape_cast(%arg0 : vector<1x4xf32>, %arg1 : index) -> vector<4xf32> {
+ %0 = vector.extract %arg0[%arg1] : vector<4xf32> from vector<1x4xf32>
+ return %0 : vector<4xf32>
+}
+
diff --git a/mlir/test/Dialect/Vector/canonicalize/vector-transpose.mlir b/mlir/test/Dialect/Vector/canonicalize/vector-transpose.mlir
index c84aea6609665..e6ef4530a0610 100644
--- a/mlir/test/Dialect/Vector/canonicalize/vector-transpose.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize/vector-transpose.mlir
@@ -195,66 +195,6 @@ func.func @negative_shape_cast_of_transpose_scalable(%arg : vector<[4]x1xi8>) ->
return %1 : vector<[4]xi8>
}
-// -----
-
-/// +--------------------------------------------------------------------------
-/// Tests of FoldTransposeShapeCast: transpose(shape_cast) -> shape_cast
-/// +--------------------------------------------------------------------------
-
-// The conversion transpose(shape_cast) -> shape_cast is not disabled for scalable
-// vectors.
-// CHECK-LABEL: @transpose_of_shape_cast_scalable
-// CHECK: vector.shape_cast
-// CHECK-SAME: vector<[4]xi8> to vector<[4]x1xi8>
-func.func @transpose_of_shape_cast_scalable(%arg : vector<[4]xi8>) -> vector<[4]x1xi8> {
- %0 = vector.shape_cast %arg : vector<[4]xi8> to vector<1x[4]xi8>
- %1 = vector.transpose %0, [1, 0] : vector<1x[4]xi8> to vector<[4]x1xi8>
- return %1 : vector<[4]x1xi8>
-}
-
-// -----
-
-// A transpose that is 'order preserving' can be treated like a shape_cast.
-// CHECK-LABEL: @transpose_of_shape_cast
-// CHECK-SAME: %[[ARG:.*]]: vector<2x3x1x1xi8>) -> vector<6x1x1xi8> {
-// CHECK: %[[SHAPE_CAST:.*]] = vector.shape_cast %[[ARG]] :
-// CHECK-SAME: vector<2x3x1x1xi8> to vector<6x1x1xi8>
-// CHECK: return %[[SHAPE_CAST]] : vector<6x1x1xi8>
-func.func @transpose_of_shape_cast(%arg : vector<2x3x1x1xi8>) -> vector<6x1x1xi8> {
- %0 = vector.shape_cast %arg : vector<2x3x1x1xi8> to vector<6x1x1xi8>
- %1 = vector.transpose %0, [0, 2, 1]
- : vector<6x1x1xi8> to vector<6x1x1xi8>
- return %1 : vector<6x1x1xi8>
-}
-
-// -----
-
-// Scalable dimensions should be treated as non-unit dimensions.
-// CHECK-LABEL: @transpose_of_shape_cast_scalable
-// CHECK: vector.shape_cast
-// CHECK: vector.transpose
-func.func @transpose_of_shape_cast_scalable_unit(%arg : vector<[1]x4x1xi8>) -> vector<4x[1]xi8> {
- %0 = vector.shape_cast %arg : vector<[1]x4x1xi8> to vector<[1]x4xi8>
- %1 = vector.transpose %0, [1, 0] : vector<[1]x4xi8> to vector<4x[1]xi8>
- return %1 : vector<4x[1]xi8>
-}
-
-// -----
-
-// Test of shape_cast (not) folding.
-// CHECK-LABEL: @negative_transpose_of_shape_cast
-// CHECK-SAME: %[[ARG:.*]]: vector<6xi8>) -> vector<2x3xi8> {
-// CHECK: %[[SHAPE_CAST:.*]] = vector.shape_cast %[[ARG]] :
-// CHECK: %[[TRANSPOSE:.*]] = vector.transpose %[[SHAPE_CAST]]
-// CHECK: return %[[TRANSPOSE]] : vector<2x3xi8>
-func.func @negative_transpose_of_shape_cast(%arg : vector<6xi8>) -> vector<2x3xi8> {
- %0 = vector.shape_cast %arg : vector<6xi8> to vector<3x2xi8>
- %1 = vector.transpose %0, [1, 0] : vector<3x2xi8> to vector<2x3xi8>
- return %1 : vector<2x3xi8>
-}
-
-// -----
-
// +-----------------------------------
// Tests of TransposeOp::fold
// +-----------------------------------
diff --git a/mlir/test/Dialect/Vector/drop-unit-dims-with-shape-cast.mlir b/mlir/test/Dialect/Vector/drop-unit-dims-with-shape-cast.mlir
index 34a155fbf2fc1..44abe2ac46fce 100644
--- a/mlir/test/Dialect/Vector/drop-unit-dims-with-shape-cast.mlir
+++ b/mlir/test/Dialect/Vector/drop-unit-dims-with-shape-cast.mlir
@@ -188,18 +188,6 @@ func.func @transpose_with_scalable_unit_dims(%vec: vector<[1]x1x2x4x1xf32>) -> v
// -----
-func.func @transpose_with_all_unit_dims(%vec: vector<1x1x1xf32>) -> vector<1x1x1xf32> {
- %res = vector.transpose %vec, [0, 2, 1] : vector<1x1x1xf32> to vector<1x1x1xf32>
- return %res : vector<1x1x1xf32>
-}
-// The `vec` is returned because there are other flattening patterns that fold
-// vector.shape_cast ops away.
-// CHECK-LABEL: func.func @transpose_with_all_unit_dims
-// CHECK-SAME: %[[VEC:.[a-zA-Z0-9]+]]
-// CHECK-NEXT: return %[[VEC]]
-
-// -----
-
func.func @negative_transpose_with_no_unit_dims(%vec: vector<4x2x3xf32>) -> vector<4x3x2xf32> {
%res = vector.transpose %vec, [0, 2, 1] : vector<4x2x3xf32> to vector<4x3x2xf32>
return %res : vector<4x3x2xf32>
diff --git a/mlir/test/Dialect/Vector/vector-transfer-to-vector-load-store.mlir b/mlir/test/Dialect/Vector/vector-transfer-to-vector-load-store.mlir
index 511ab70f35086..7886fba6c80c4 100644
--- a/mlir/test/Dialect/Vector/vector-transfer-to-vector-load-store.mlir
+++ b/mlir/test/Dialect/Vector/vector-transfer-to-vector-load-store.mlir
@@ -24,7 +24,7 @@ func.func @vector_transfer_ops_0d_tensor(%src: tensor<f32>) -> vector<1xf32> {
%f0 = arith.constant 0.0 : f32
// CHECK: %[[S:.*]] = vector.transfer_read %[[SRC]][]
-// CHECK: %[[V:.*]] = vector.broadcast %[[S]] : vector<f32> to vector<1xf32>
+// CHECK: %[[V:.*]] = vector.shape_cast %[[S]] : vector<f32> to vector<1xf32>
%res = vector.transfer_read %src[], %f0 {in_bounds = [true], permutation_map = affine_map<()->(0)>} :
tensor<f32>, vector<1xf32>
@@ -369,9 +369,8 @@ func.func @transfer_write_broadcast_unit_dim_tensor(
%c0 = arith.constant 0 : index
%res = vector.transfer_write %vec_0, %dst_0[%c0, %c0, %c0, %c0] {in_bounds = [false, false, true], permutation_map = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>} : vector<14x8x16xf32>, tensor<?x?x?x?xf32>
- // CHECK: %[[NEW_VEC0:.*]] = vector.broadcast %{{.*}} : vector<14x8x16xf32> to vector<1x14x8x16xf32>
- // CHECK: %[[NEW_VEC1:.*]] = vector.transpose %[[NEW_VEC0]], [1, 2, 0, 3] : vector<1x14x8x16xf32> to vector<14x8x1x16xf32>
- // CHECK: %[[NEW_RES0:.*]] = vector.transfer_write %[[NEW_VEC1]], %[[DST0]][%[[C0]], %[[C0]], %[[C0]], %[[C0]]] {in_bounds = [false, false, true, true]} : vector<14x8x1x16xf32>, tensor<?x?x?x?xf32>
+ // CHECK: %[[NEW_VEC0:.*]] = vector.shape_cast %{{.*}} : vector<14x8x16xf32> to vector<14x8x1x16xf32>
+ // CHECK: %[[NEW_RES0:.*]] = vector.transfer_write %[[NEW_VEC0]], %[[DST0]][%[[C0]], %[[C0]], %[[C0]], %[[C0]]] {in_bounds = [false, false, true, true]} : vector<14x8x1x16xf32>, tensor<?x?x?x?xf32>
return %res : tensor<?x?x?x?xf32>
}
@@ -385,9 +384,8 @@ func.func @transfer_write_broadcast_unit_dim_memref(
%c0 = arith.constant 0 : index
vector.transfer_write %vec_0, %mem_0[%c0, %c0, %c0, %c0] {permutation_map = affine_map<(d0, d1, d2, d3) -> (d1, d2)>} : vector<8x16xf32>, memref<?x?x?x?xf32>
- // CHECK: %[[NEW_VEC0:.*]] = vector.broadcast %{{.*}} : vector<8x16xf32> to vector<1x8x16xf32>
- // CHECK: %[[NEW_VEC1:.*]] = vector.transpose %[[NEW_VEC0]], [1, 2, 0] : vector<1x8x16xf32> to vector<8x16x1xf32>
- // CHECK: vector.transfer_write %[[NEW_VEC1]], %[[MEM0]][%[[C0]], %[[C0]], %[[C0]], %[[C0]]] {in_bounds = [false, false, true]} : vector<8x16x1xf32>, memref<?x?x?x?xf32>
+ // CHECK: %[[NEW_VEC0:.*]] = vector.shape_cast %{{.*}} : vector<8x16xf32> to vector<8x16x1xf32>
+ // CHECK: vector.transfer_write %[[NEW_VEC0]], %[[MEM0]][%[[C0]], %[[C0]], %[[C0]], %[[C0]]] {in_bounds = [false, false, true]} : vector<8x16x1xf32>, memref<?x?x?x?xf32>
return
}
diff --git a/mlir/test/Dialect/Vector/vector-warp-distribute.mlir b/mlir/test/Dialect/Vector/vector-warp-distribute.mlir
index 38771f2593449..ba47799729f1d 100644
--- a/mlir/test/Dialect/Vector/vector-warp-distribute.mlir
+++ b/mlir/test/Dialect/Vector/vector-warp-distribute.mlir
@@ -1311,8 +1311,8 @@ func.func @vector_insert_2d_broadcast(%laneid: index) -> (vector<4x96xf32>) {
// CHECK-PROP-DAG: %[[THREADID:.*]] = gpu.thread_id x
// CHECK-PROP: %[[W:.*]] = gpu.warp_execute_on_lane_0(%[[THREADID]])[32] args(%[[IN2]]
// CHECK-PROP: %[[GATHER:.*]] = vector.gather %[[AR1]][{{.*}}]
-// CHECK-PROP: %[[EXTRACT:.*]] = vector.extract %[[GATHER]][0] : vector<64xi32> from vector<1x64xi32>
-// CHECK-PROP: %[[CAST:.*]] = arith.index_cast %[[EXTRACT]] : vector<64xi32> to vector<64xindex>
+// CHECK-PROP: %[[SHAPE_CAST:.*]] = vector.shape_cast %[[GATHER]] : vector<1x64xi32> to vector<64xi32>
+// CHECK-PROP: %[[CAST:.*]] = arith.index_cast %[[SHAPE_CAST]] : vector<64xi32> to vector<64xindex>
// CHECK-PROP: %[[EXTRACTELT:.*]] = vector.extract %[[CAST]][{{.*}}] : index from vector<64xindex>
// CHECK-PROP: gpu.yield %[[EXTRACTELT]] : index
// CHECK-PROP: %[[APPLY:.*]] = affine.apply #[[$MAP]]()[%[[THREADID]]]
@@ -1348,8 +1348,8 @@ func.func @transfer_read_prop_operands(%in2: vector<1x2xindex>, %ar1 : memref<1
// CHECK-PROP-LABEL: func @dont_fold_vector_broadcast(
// CHECK-PROP: %[[r:.*]] = gpu.warp_execute_on_lane_0{{.*}} -> (vector<1x2xf32>)
// CHECK-PROP: %[[some_def:.*]] = "some_def"
-// CHECK-PROP: %[[broadcast:.*]] = vector.broadcast %[[some_def]] : vector<64xf32> to vector<1x64xf32>
-// CHECK-PROP: gpu.yield %[[broadcast]] : vector<1x64xf32>
+// CHECK-PROP: %[[shape_cast:.*]] = vector.shape_cast %[[some_def]] : vector<64xf32> to vector<1x64xf32>
+// CHECK-PROP: gpu.yield %[[shape_cast]] : vector<1x64xf32>
// CHECK-PROP: vector.print %[[r]] : vector<1x2xf32>
func.func @dont_fold_vector_broadcast(%laneid: index) {
%r = gpu.warp_execute_on_lane_0(%laneid)[32] -> (vector<1x2xf32>) {
>From 6e68537455463c1c22683dd60f4e185c244977e9 Mon Sep 17 00:00:00 2001
From: James Newling <james.newling at gmail.com>
Date: Mon, 19 May 2025 11:23:26 -0700
Subject: [PATCH 2/3] update testing
---
mlir/test/Dialect/Vector/canonicalize.mlir | 24 ---
.../canonicalize/vector-shape-cast.mlir | 147 ++++++++++--------
.../Vector/canonicalize/vector-transpose.mlir | 52 -------
3 files changed, 85 insertions(+), 138 deletions(-)
diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir
index 45547946ce70a..f46d7799b4db0 100644
--- a/mlir/test/Dialect/Vector/canonicalize.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize.mlir
@@ -1033,30 +1033,6 @@ func.func @canonicalize_broadcast_shapecast_to_broadcast_scalar(%arg0: f32) -> v
// -----
-// In this test, broadcast (2)->(1,2,1) is not legal, but shape_cast (2)->(1,2,1) is.
-// CHECK-LABEL: func @canonicalize_broadcast_shapecast_to_shapcast
-// CHECK-NOT: vector.broadcast
-// CHECK: vector.shape_cast {{.+}} : vector<2xf32> to vector<1x2x1xf32>
-func.func @canonicalize_broadcast_shapecast_to_shapcast(%arg0 : vector<2xf32>) -> vector<1x2x1xf32> {
- %0 = vector.broadcast %arg0 : vector<2xf32> to vector<1x2xf32>
- %1 = vector.shape_cast %0 : vector<1x2xf32> to vector<1x2x1xf32>
- return %1 : vector<1x2x1xf32>
-}
-
-// -----
-
-// In this test, broadcast (1)->(1,1) and shape_cast (1)->(1,1) are both legal. shape_cast is chosen.
-// CHECK-LABEL: func @canonicalize_broadcast_shapecast_both_possible
-// CHECK-NOT: vector.broadcast
-// CHECK: vector.shape_cast {{.+}} : vector<1xf32> to vector<1x1xf32>
-func.func @canonicalize_broadcast_shapecast_both_possible(%arg0: vector<1xf32>) -> vector<1x1xf32> {
- %0 = vector.broadcast %arg0 : vector<1xf32> to vector<1x1x1xf32>
- %1 = vector.shape_cast %0 : vector<1x1x1xf32> to vector<1x1xf32>
- return %1 : vector<1x1xf32>
-}
-
-// -----
-
// CHECK-LABEL: fold_vector_transfer_masks
func.func @fold_vector_transfer_masks(%A: memref<?x?xf32>) -> (vector<4x8xf32>, vector<4x[4]xf32>) {
// CHECK: %[[C0:.+]] = arith.constant 0 : index
diff --git a/mlir/test/Dialect/Vector/canonicalize/vector-shape-cast.mlir b/mlir/test/Dialect/Vector/canonicalize/vector-shape-cast.mlir
index 5495a05d75944..357df0f129a5e 100644
--- a/mlir/test/Dialect/Vector/canonicalize/vector-shape-cast.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize/vector-shape-cast.mlir
@@ -1,31 +1,5 @@
// RUN: mlir-opt %s -split-input-file -canonicalize | FileCheck %s
-// +----------------------------------------
-// Tests of TransposeToShapeCast
-// +----------------------------------------
-
-// CHECK-LABEL: @transpose_to_shape_cast
-// CHECK-SAME: %[[ARG0:.*]]: vector<2x1x2xf32>
-// CHECK-NEXT: %[[SCAST:.*]] = vector.shape_cast %[[ARG0]]
-// CHECK-NEXT: return %[[SCAST]] : vector<2x2x1xf32>
-func.func @transpose_to_shape_cast(%arg0 : vector<2x1x2xf32>) -> vector<2x2x1xf32> {
- %0 = vector.transpose %arg0, [0, 2, 1] : vector<2x1x2xf32> to vector<2x2x1xf32>
- return %0 : vector<2x2x1xf32>
-}
-
-
-// -----
-
-// CHECK-LABEL: @negative_transpose_to_shape_cast
-// CHECK-SAME: %[[ARG0:.*]]: vector<2x1x2xf32>
-// CHECK-NEXT: %[[TRANSPOSE:.*]] = vector.transpose %[[ARG0]], [2, 0, 1]
-// CHECK-NEXT: return %[[TRANSPOSE]] : vector<2x2x1xf32>
-func.func @negative_transpose_to_shape_cast(%arg0 : vector<2x1x2xf32>) -> vector<2x2x1xf32> {
- %0 = vector.transpose %arg0, [2, 0, 1] : vector<2x1x2xf32> to vector<2x2x1xf32>
- return %0 : vector<2x2x1xf32>
-}
-
-// -----
// +----------------------------------------
// Tests of BroadcastToShapeCast
@@ -42,16 +16,20 @@ func.func @broadcast_to_shape_cast(%arg0 : vector<4xi8>) -> vector<1x1x4xi8> {
// -----
-// CHECK-LABEL: @negative_broadcast_to_shape_cast
+// A broadcast can only be transformed to a shape_cast if the number of
+// elements is unchanged by the broadcast.
+// CHECK-LABEL: @negative_broadcast_increased_elements_to_shape_cast
// CHECK-NOT: shape_cast
// CHECK: return
-func.func @negative_broadcast_to_shape_cast(%arg0 : vector<1x4xi8>) -> vector<2x3x4xi8> {
+func.func @negative_broadcast_increased_elements_to_shape_cast(%arg0 : vector<1x4xi8>) -> vector<2x3x4xi8> {
%0 = vector.broadcast %arg0 : vector<1x4xi8> to vector<2x3x4xi8>
return %0 : vector<2x3x4xi8>
}
// -----
+// shape_cast does not support scalar inputs/outputs, so a broadcast of a scalar
+// cannot be transformed to a shape_cast.
// CHECK-LABEL: @negative_broadcast_scalar_to_shape_cast
// CHECK-NOT: shape_cast
// CHECK: return
@@ -62,56 +40,101 @@ func.func @negative_broadcast_scalar_to_shape_cast(%arg0 : i8) -> vector<1xi8> {
// -----
-// The conversion transpose(shape_cast) -> shape_cast is currently disabled for scalable
-// vectors.
-// CHECK-LABEL: @transpose_of_shape_cast_scalable
-// CHECK: vector.shape_cast
-// CHECK: vector.transpose
-func.func @transpose_of_shape_cast_scalable(%arg : vector<[4]xi8>) -> vector<[4]x1xi8> {
- %0 = vector.shape_cast %arg : vector<[4]xi8> to vector<1x[4]xi8>
- %1 = vector.transpose %0, [1, 0] : vector<1x[4]xi8> to vector<[4]x1xi8>
- return %1 : vector<[4]x1xi8>
+// +----------------------------------------
+// Tests of TransposeToShapeCast
+// +----------------------------------------
+
+// In this test, the permutation maps the non-unit dimensions (0 and 2) as follows:
+// 0 -> 0
+// 2 -> 1
+// Because 0 < 1, this permutation is order preserving and effectively a shape_cast.
+// CHECK-LABEL: @transpose_to_shape_cast
+// CHECK-SAME: %[[ARG0:.*]]: vector<2x1x2xf32>
+// CHECK-NEXT: %[[SCAST:.*]] = vector.shape_cast %[[ARG0]]
+// CHECK-NEXT: return %[[SCAST]] : vector<2x2x1xf32>
+func.func @transpose_to_shape_cast(%arg0 : vector<2x1x2xf32>) -> vector<2x2x1xf32> {
+ %0 = vector.transpose %arg0, [0, 2, 1] : vector<2x1x2xf32> to vector<2x2x1xf32>
+ return %0 : vector<2x2x1xf32>
}
// -----
-// A transpose that is 'order preserving' can be treated like a shape_cast.
-// CHECK-LABEL: @transpose_of_shape_cast
-// CHECK-SAME: %[[ARG:.*]]: vector<2x3x1x1xi8>) -> vector<6x1x1xi8> {
+// In this test, the permutation maps the non-unit dimensions (1 and 2) as follows:
+// 1 -> 0
+// 2 -> 4
+// Because 0 < 4, this permutation is order preserving and effectively a shape_cast.
+// CHECK-LABEL: @shape_cast_of_transpose
+// CHECK-SAME: %[[ARG:.*]]: vector<1x4x4x1x1xi8>)
// CHECK: %[[SHAPE_CAST:.*]] = vector.shape_cast %[[ARG]] :
-// CHECK-SAME: vector<2x3x1x1xi8> to vector<6x1x1xi8>
-// CHECK: return %[[SHAPE_CAST]] : vector<6x1x1xi8>
-func.func @transpose_of_shape_cast(%arg : vector<2x3x1x1xi8>) -> vector<6x1x1xi8> {
- %0 = vector.shape_cast %arg : vector<2x3x1x1xi8> to vector<6x1x1xi8>
- %1 = vector.transpose %0, [0, 2, 1]
- : vector<6x1x1xi8> to vector<6x1x1xi8>
- return %1 : vector<6x1x1xi8>
+// CHECK-SAME: vector<1x4x4x1x1xi8> to vector<4x1x1x1x4xi8>
+// CHECK: return %[[SHAPE_CAST]]
+func.func @shape_cast_of_transpose(%arg : vector<1x4x4x1x1xi8>) -> vector<4x1x1x1x4xi8> {
+ %0 = vector.transpose %arg, [1, 0, 3, 4, 2] : vector<1x4x4x1x1xi8> to vector<4x1x1x1x4xi8>
+ return %0 : vector<4x1x1x1x4xi8>
}
// -----
// Scalable dimensions should be treated as non-unit dimensions.
-// CHECK-LABEL: @transpose_of_shape_cast_scalable
+// CHECK-LABEL: @transpose_scalable_unit
+// CHECK-NOT: shape_cast
+func.func @transpose_scalable_unit(%arg : vector<[1]x4xi8>) -> vector<4x[1]xi8> {
+ %0 = vector.transpose %arg, [1, 0] : vector<[1]x4xi8> to vector<4x[1]xi8>
+ return %0 : vector<4x[1]xi8>
+}
+
+// -----
+
+// In this test, the mapping of non-unit dimensions (1 and 2) is as follows:
+// 1 -> 2
+// 2 -> 1
+// As this is not increasing (2 > 1), this transpose is not order
+// preserving and cannot be treated as a shape_cast.
+// CHECK-LABEL: @negative_transpose_to_shape_cast
+// CHECK-NOT: shape_cast
+func.func @negative_transpose_to_shape_cast(%arg : vector<1x4x4x1xi8>) -> vector<1x4x4x1xi8> {
+ %0 = vector.transpose %arg, [0, 2, 1, 3]
+ : vector<1x4x4x1xi8> to vector<1x4x4x1xi8>
+ return %0 : vector<1x4x4x1xi8>
+}
+
+// -----
+
+// The rewrite transpose -> shape_cast is currently disabled for scalable vectors
+// because of a bad interaction with ConvertIllegalShapeCastOpsToTransposes.
+// CHECK-LABEL: @negative_shape_cast_of_transpose_scalable
+// CHECK: vector.transpose
+// CHECK: vector.shape_cast
+func.func @negative_shape_cast_of_transpose_scalable(%arg : vector<[4]x1xi8>) -> vector<[4]xi8> {
+ %0 = vector.transpose %arg, [1, 0] : vector<[4]x1xi8> to vector<1x[4]xi8>
+ %1 = vector.shape_cast %0 : vector<1x[4]xi8> to vector<[4]xi8>
+ return %1 : vector<[4]xi8>
+}
+
+// -----
+
+// The rewrite transpose -> shape_cast is currently disabled for scalable
+// vectors.
+// CHECK-LABEL: @negative_transpose_of_shape_cast_scalable
// CHECK: vector.shape_cast
// CHECK: vector.transpose
-func.func @transpose_of_shape_cast_scalable_unit(%arg : vector<[1]x4x1xi8>) -> vector<4x[1]xi8> {
- %0 = vector.shape_cast %arg : vector<[1]x4x1xi8> to vector<[1]x4xi8>
- %1 = vector.transpose %0, [1, 0] : vector<[1]x4xi8> to vector<4x[1]xi8>
- return %1 : vector<4x[1]xi8>
+func.func @negative_transpose_of_shape_cast_scalable(%arg : vector<[4]xi8>) -> vector<[4]x1xi8> {
+ %0 = vector.shape_cast %arg : vector<[4]xi8> to vector<1x[4]xi8>
+ %1 = vector.transpose %0, [1, 0] : vector<1x[4]xi8> to vector<[4]x1xi8>
+ return %1 : vector<[4]x1xi8>
}
// -----
-// Test of shape_cast (not) folding.
-// CHECK-LABEL: @negative_transpose_of_shape_cast
-// CHECK-SAME: %[[ARG:.*]]: vector<6xi8>) -> vector<2x3xi8> {
-// CHECK: %[[SHAPE_CAST:.*]] = vector.shape_cast %[[ARG]] :
-// CHECK: %[[TRANSPOSE:.*]] = vector.transpose %[[SHAPE_CAST]]
-// CHECK: return %[[TRANSPOSE]] : vector<2x3xi8>
-func.func @negative_transpose_of_shape_cast(%arg : vector<6xi8>) -> vector<2x3xi8> {
- %0 = vector.shape_cast %arg : vector<6xi8> to vector<3x2xi8>
- %1 = vector.transpose %0, [1, 0] : vector<3x2xi8> to vector<2x3xi8>
- return %1 : vector<2x3xi8>
+// A test where a transpose cannot be transformed to a shape_cast because it is
+// not order preserving.
+// CHECK-LABEL: @negative_transpose_to_shape_cast
+// CHECK-SAME: %[[ARG0:.*]]: vector<2x1x2xf32>
+// CHECK-NEXT: %[[TRANSPOSE:.*]] = vector.transpose %[[ARG0]], [2, 0, 1]
+// CHECK-NEXT: return %[[TRANSPOSE]] : vector<2x2x1xf32>
+func.func @negative_transpose_to_shape_cast(%arg0 : vector<2x1x2xf32>) -> vector<2x2x1xf32> {
+ %0 = vector.transpose %arg0, [2, 0, 1] : vector<2x1x2xf32> to vector<2x2x1xf32>
+ return %0 : vector<2x2x1xf32>
}
// -----
diff --git a/mlir/test/Dialect/Vector/canonicalize/vector-transpose.mlir b/mlir/test/Dialect/Vector/canonicalize/vector-transpose.mlir
index e6ef4530a0610..5055b1e2c4862 100644
--- a/mlir/test/Dialect/Vector/canonicalize/vector-transpose.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize/vector-transpose.mlir
@@ -143,58 +143,6 @@ func.func @negative_broadcast_transpose_021(%arg0 : vector<3x1x3xi8>) -> vector<
// -----
-/// +--------------------------------------------------------------------------
-/// Tests of ShapeCastOp::fold: shape_cast(transpose) -> shape_cast
-/// +--------------------------------------------------------------------------
-
-// In this test, the permutation maps the non-unit dimensions (1 and 2) as follows:
-// 1 -> 0
-// 2 -> 4
-// Because 0 < 4, this permutation is order preserving and effectively a shape_cast.
-// CHECK-LABEL: @shape_cast_of_transpose
-// CHECK-SAME: %[[ARG:.*]]: vector<1x4x4x1x1xi8>) -> vector<4x4xi8> {
-// CHECK: %[[SHAPE_CAST:.*]] = vector.shape_cast %[[ARG]] :
-// CHECK-SAME: vector<1x4x4x1x1xi8> to vector<4x4xi8>
-// CHECK: return %[[SHAPE_CAST]] : vector<4x4xi8>
-func.func @shape_cast_of_transpose(%arg : vector<1x4x4x1x1xi8>) -> vector<4x4xi8> {
- %0 = vector.transpose %arg, [1, 0, 3, 4, 2]
- : vector<1x4x4x1x1xi8> to vector<4x1x1x1x4xi8>
- %1 = vector.shape_cast %0 : vector<4x1x1x1x4xi8> to vector<4x4xi8>
- return %1 : vector<4x4xi8>
-}
-
-// -----
-
-// In this test, the mapping of non-unit dimensions (1 and 2) is as follows:
-// 1 -> 2
-// 2 -> 1
-// As this is not increasing (2 > 1), this transpose is not order
-// preserving and cannot be treated as a shape_cast.
-// CHECK-LABEL: @negative_shape_cast_of_transpose
-// CHECK-SAME: %[[ARG:.*]]: vector<1x4x4x1xi8>) -> vector<4x4xi8> {
-// CHECK: %[[TRANSPOSE:.*]] = vector.transpose %[[ARG]]
-// CHECK: %[[SHAPE_CAST:.*]] = vector.shape_cast %[[TRANSPOSE]]
-// CHECK: return %[[SHAPE_CAST]] : vector<4x4xi8>
-func.func @negative_shape_cast_of_transpose(%arg : vector<1x4x4x1xi8>) -> vector<4x4xi8> {
- %0 = vector.transpose %arg, [0, 2, 1, 3]
- : vector<1x4x4x1xi8> to vector<1x4x4x1xi8>
- %1 = vector.shape_cast %0 : vector<1x4x4x1xi8> to vector<4x4xi8>
- return %1 : vector<4x4xi8>
-}
-
-// -----
-
-// Currently the conversion shape_cast(transpose) -> shape_cast is disabled for
-// scalable vectors because of bad interaction with ConvertIllegalShapeCastOpsToTransposes
-// CHECK-LABEL: @negative_shape_cast_of_transpose_scalable
-// CHECK: vector.transpose
-// CHECK: vector.shape_cast
-func.func @negative_shape_cast_of_transpose_scalable(%arg : vector<[4]x1xi8>) -> vector<[4]xi8> {
- %0 = vector.transpose %arg, [1, 0] : vector<[4]x1xi8> to vector<1x[4]xi8>
- %1 = vector.shape_cast %0 : vector<1x[4]xi8> to vector<[4]xi8>
- return %1 : vector<[4]xi8>
-}
-
// +-----------------------------------
// Tests of TransposeOp::fold
// +-----------------------------------
>From d546ab38ea525ae344f9114fe137310fc523046e Mon Sep 17 00:00:00 2001
From: James Newling <james.newling at gmail.com>
Date: Mon, 19 May 2025 11:41:13 -0700
Subject: [PATCH 3/3] further dup removal
---
mlir/lib/Dialect/Vector/IR/VectorOps.cpp | 21 +--------------------
1 file changed, 1 insertion(+), 20 deletions(-)
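For context (illustration only, not part of the patch): the duplicate removed here is the ShapeCast(Broadcast(X)) -> ShapeCast(X) case of ShapeCastBroadcastFolder. Based on the patch subject, the assumption is that this IR is now handled by the BroadcastToShapeCast pattern from the first patch together with the existing shape_cast-of-shape_cast folding. The example below is the one from the deleted comment.

```mlir
// Previously rewritten directly to
//   %1 = vector.shape_cast %in : vector<3x4xf32> to vector<12xf32>
// by ShapeCastBroadcastFolder; now the broadcast first becomes a shape_cast,
// which then folds with the second shape_cast.
func.func @shape_cast_of_broadcast(%in : vector<3x4xf32>) -> vector<12xf32> {
  %0 = vector.broadcast %in : vector<3x4xf32> to vector<1x3x4xf32>
  %1 = vector.shape_cast %0 : vector<1x3x4xf32> to vector<12xf32>
  return %1 : vector<12xf32>
}
```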
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 325870fcfaea7..007dffc7d0355 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -5769,10 +5769,7 @@ class ShapeCastCreateMaskFolderTrailingOneDim final
}
};
-/// Pattern to rewrite Y = ShapeCast(Broadcast(X)) as either
-/// i) Y = ShapeCast(X), or
-/// ii) Y = Broadcast(X)
-/// If both (i) and (ii) are possible, (i) is chosen.
+/// Pattern to rewrite Y = ShapeCast(Broadcast(X)) as Y = Broadcast(X).
class ShapeCastBroadcastFolder final : public OpRewritePattern<ShapeCastOp> {
public:
using OpRewritePattern::OpRewritePattern;
@@ -5787,22 +5784,6 @@ class ShapeCastBroadcastFolder final : public OpRewritePattern<ShapeCastOp> {
auto srcVectorType = dyn_cast<VectorType>(broadcastOp.getSourceType());
bool srcIsScalar = !srcVectorType;
- // Replace Y = ShapeCast(Broadcast(X)) with Y = ShapeCast(X).
- // Example:
- // %0 = vector.broadcast %in : vector<3x4xf32> to vector<1x3x4xf32>
- // %1 = vector.shape_cast %0 : vector<1x3x4xf32> to vector<12xf32>
- // to
- // %1 = vector.shape_cast %in : vector<3x4xf32> to vector<12xf32>
- if (srcVectorType) {
- if (srcVectorType.getNumElements() ==
- shapeCastOp.getResultVectorType().getNumElements()) {
- rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(
- shapeCastOp, shapeCastOp.getResultVectorType(),
- broadcastOp.getSource());
- return success();
- }
- }
-
// Replace Y = ShapeCast(Broadcast(X)) with Y = Broadcast(X)
// Example
// %0 = vector.broadcast %in : vector<3xf32> to vector<2x4x3xf32>