[Mlir-commits] [mlir] [mlir][vector] Canonicalize/fold 'order preserving' transposes (PR #135841)
James Newling
llvmlistbot at llvm.org
Thu May 1 09:51:25 PDT 2025
https://github.com/newling updated https://github.com/llvm/llvm-project/pull/135841
>From f4ae20602111d93f734bfbcf99f9e76a56bb7a79 Mon Sep 17 00:00:00 2001
From: James Newling <james.newling at gmail.com>
Date: Mon, 21 Apr 2025 10:29:15 -0700
Subject: [PATCH 1/4] add transpose(shape_cast) and shape_cast(transpose)
folders, with tests
---
mlir/lib/Dialect/Vector/IR/VectorOps.cpp | 90 ++++++++++++++++---
mlir/test/Dialect/Vector/canonicalize.mlir | 2 +-
.../Vector/canonicalize/vector-transpose.mlir | 64 +++++++++++++
3 files changed, 145 insertions(+), 11 deletions(-)
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 368259b38b153..2e5fc70afa4f7 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -5594,6 +5594,29 @@ LogicalResult ShapeCastOp::verify() {
return success();
}
+namespace {
+
+/// Return true if `transpose` does not permute a pair of dimensions that are
+/// both not of size 1. By `order preserving` we mean that the flattened
+/// versions of the input and output vectors are (numerically) identical.
+/// In other words `transpose` is effectively a shape cast.
+bool isOrderPreserving(TransposeOp transpose) {
+ ArrayRef<int64_t> permutation = transpose.getPermutation();
+ ArrayRef<int64_t> inShape = transpose.getSourceVectorType().getShape();
+ int64_t current = 0;
+ for (auto p : permutation) {
+ if (inShape[p] != 1) {
+ if (p < current) {
+ return false;
+ }
+ current = p;
+ }
+ }
+ return true;
+}
+
+} // namespace
+
OpFoldResult ShapeCastOp::fold(FoldAdaptor adaptor) {
// No-op shape cast.
@@ -5602,13 +5625,15 @@ OpFoldResult ShapeCastOp::fold(FoldAdaptor adaptor) {
VectorType resultType = getType();
- // Canceling shape casts.
- if (auto otherOp = getSource().getDefiningOp<ShapeCastOp>()) {
-
- // Only allows valid transitive folding (expand/collapse dimensions).
- VectorType srcType = otherOp.getSource().getType();
+ // shape_cast(something(x)) -> x, or
+ // -> shape_cast(x).
+ //
+ // Confirms that a new shape_cast will have valid semantics (expands OR
+ // collapses dimensions).
+ auto maybeFold = [&](TypedValue<VectorType> source) -> OpFoldResult {
+ VectorType srcType = source.getType();
if (resultType == srcType)
- return otherOp.getSource();
+ return source;
if (srcType.getRank() < resultType.getRank()) {
if (!isValidShapeCast(srcType.getShape(), resultType.getShape()))
return {};
@@ -5618,8 +5643,25 @@ OpFoldResult ShapeCastOp::fold(FoldAdaptor adaptor) {
} else {
return {};
}
- setOperand(otherOp.getSource());
+ setOperand(source);
return getResult();
+ };
+
+ // Canceling shape casts.
+ if (auto otherOp = getSource().getDefiningOp<ShapeCastOp>()) {
+ TypedValue<VectorType> source = otherOp.getSource();
+ return maybeFold(source);
+ }
+
+ // shape_cast(transpose(x)) -> shape_cast(x)
+ if (auto transpose = getSource().getDefiningOp<TransposeOp>()) {
+ if (transpose.getType().isScalable())
+ return {};
+ if (isOrderPreserving(transpose)) {
+ TypedValue<VectorType> source = transpose.getVector();
+ return maybeFold(source);
+ }
+ return {};
}
// Cancelling broadcast and shape cast ops.
@@ -5646,7 +5688,7 @@ namespace {
/// Helper function that computes a new vector type based on the input vector
/// type by removing the trailing one dims:
///
-/// vector<4x1x1xi1> --> vector<4x1>
+/// vector<4x1x1xi1> --> vector<4x1xi1>
///
static VectorType trimTrailingOneDims(VectorType oldType) {
ArrayRef<int64_t> oldShape = oldType.getShape();
@@ -6113,6 +6155,34 @@ class FoldTransposeCreateMask final : public OpRewritePattern<TransposeOp> {
}
};
+/// Folds an order-preserving transpose(shape_cast(x)) into shape_cast(x).
+class FoldTransposeShapeCast final : public OpRewritePattern<TransposeOp> {
+public:
+ using OpRewritePattern::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(TransposeOp transposeOp,
+ PatternRewriter &rewriter) const override {
+ auto shapeCastOp =
+ transposeOp.getVector().getDefiningOp<vector::ShapeCastOp>();
+ if (!shapeCastOp)
+ return failure();
+ if (!isOrderPreserving(transposeOp))
+ return failure();
+ if (transposeOp.getType().isScalable())
+ return failure();
+
+ VectorType resultType = transposeOp.getType();
+
+ // We don't need to check isValidShapeCast at this point, because it is
+ // guaranteed that merging the transpose into the shape_cast is a valid
+ // shape_cast, because the transpose only relocates unit dims.
+
+ rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(transposeOp, resultType,
+ shapeCastOp.getSource());
+ return success();
+ }
+};
+
/// Folds transpose(broadcast(x)) to broadcast(x) if the transpose is
/// 'order preserving', where 'order preserving' means the flattened
/// inputs and outputs of the transpose have identical (numerical) values.
@@ -6211,8 +6281,8 @@ class FoldTransposeBroadcast : public OpRewritePattern<vector::TransposeOp> {
void vector::TransposeOp::getCanonicalizationPatterns(
RewritePatternSet &results, MLIRContext *context) {
- results.add<FoldTransposeCreateMask, TransposeFolder, FoldTransposeSplat,
- FoldTransposeBroadcast>(context);
+ results.add<FoldTransposeCreateMask, FoldTransposeShapeCast, TransposeFolder,
+ FoldTransposeSplat, FoldTransposeBroadcast>(context);
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir
index 2d365ac2b4287..943a9429574cd 100644
--- a/mlir/test/Dialect/Vector/canonicalize.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize.mlir
@@ -8,6 +8,7 @@ func.func @create_vector_mask_to_constant_mask() -> (vector<4x3xi1>) {
%0 = vector.create_mask %c3, %c2 : vector<4x3xi1>
return %0 : vector<4x3xi1>
}
+
// -----
// CHECK-LABEL: create_scalable_vector_mask_to_constant_mask
@@ -3035,7 +3036,6 @@ func.func @insert_vector_poison(%a: vector<4x8xf32>)
return %1 : vector<4x8xf32>
}
-
// -----
// CHECK-LABEL: @insert_scalar_poison_idx
diff --git a/mlir/test/Dialect/Vector/canonicalize/vector-transpose.mlir b/mlir/test/Dialect/Vector/canonicalize/vector-transpose.mlir
index e97e147459de2..322309a559aa0 100644
--- a/mlir/test/Dialect/Vector/canonicalize/vector-transpose.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize/vector-transpose.mlir
@@ -137,3 +137,67 @@ func.func @negative_broadcast_transpose_021(%arg0 : vector<3x1x3xi8>) -> vector<
return %1 : vector<3x3x3xi8>
}
+
+// -----
+
+// In this test, the permutation maps the non-one dimensions (1 and 2) as follows:
+// 1 -> 0
+// 2 -> 4
+// Because 0 < 4, this permutation is order preserving and effectively a shape_cast.
+// CHECK-LABEL: @transpose_shape_cast
+// CHECK-SAME: %[[ARG:.*]]: vector<1x4x4x1x1xi8>) -> vector<4x4xi8> {
+// CHECK: %[[SHAPE_CAST:.*]] = vector.shape_cast %[[ARG]] :
+// CHECK-SAME: vector<1x4x4x1x1xi8> to vector<4x4xi8>
+// CHECK: return %[[SHAPE_CAST]] : vector<4x4xi8>
+func.func @transpose_shape_cast(%arg : vector<1x4x4x1x1xi8>) -> vector<4x4xi8> {
+ %0 = vector.transpose %arg, [1, 0, 3, 4, 2]
+ : vector<1x4x4x1x1xi8> to vector<4x1x1x1x4xi8>
+ %1 = vector.shape_cast %0 : vector<4x1x1x1x4xi8> to vector<4x4xi8>
+ return %1 : vector<4x4xi8>
+}
+
+// -----
+
+// In this test, the mapping of non-one indices (1 and 2) is as follows:
+// 1 -> 2
+// 2 -> 1
+// As this is not increasing (2 > 1), this transpose is not order
+// preserving and cannot be treated as a shape_cast.
+// CHECK-LABEL: @negative_transpose_shape_cast
+// CHECK-SAME: %[[ARG:.*]]: vector<1x4x4x1xi8>) -> vector<4x4xi8> {
+// CHECK: %[[TRANSPOSE:.*]] = vector.transpose %[[ARG]]
+// CHECK: %[[SHAPE_CAST:.*]] = vector.shape_cast %[[TRANSPOSE]]
+// CHECK: return %[[SHAPE_CAST]] : vector<4x4xi8>
+func.func @negative_transpose_shape_cast(%arg : vector<1x4x4x1xi8>) -> vector<4x4xi8> {
+ %0 = vector.transpose %arg, [0, 2, 1, 3]
+ : vector<1x4x4x1xi8> to vector<1x4x4x1xi8>
+ %1 = vector.shape_cast %0 : vector<1x4x4x1xi8> to vector<4x4xi8>
+ return %1 : vector<4x4xi8>
+}
+
+// -----
+
+// CHECK-LABEL: @shape_cast_transpose
+// CHECK-SAME: %[[ARG:.*]]: vector<2x3x1x1xi8>) -> vector<6x1x1xi8> {
+// CHECK: %[[SHAPE_CAST:.*]] = vector.shape_cast %[[ARG]] :
+// CHECK-SAME: vector<2x3x1x1xi8> to vector<6x1x1xi8>
+// CHECK: return %[[SHAPE_CAST]] : vector<6x1x1xi8>
+func.func @shape_cast_transpose(%arg : vector<2x3x1x1xi8>) -> vector<6x1x1xi8> {
+ %0 = vector.shape_cast %arg : vector<2x3x1x1xi8> to vector<6x1x1xi8>
+ %1 = vector.transpose %0, [0, 2, 1]
+ : vector<6x1x1xi8> to vector<6x1x1xi8>
+ return %1 : vector<6x1x1xi8>
+}
+
+// -----
+
+// CHECK-LABEL: @negative_shape_cast_transpose
+// CHECK-SAME: %[[ARG:.*]]: vector<6xi8>) -> vector<2x3xi8> {
+// CHECK: %[[SHAPE_CAST:.*]] = vector.shape_cast %[[ARG]] :
+// CHECK: %[[TRANSPOSE:.*]] = vector.transpose %[[SHAPE_CAST]]
+// CHECK: return %[[TRANSPOSE]] : vector<2x3xi8>
+func.func @negative_shape_cast_transpose(%arg : vector<6xi8>) -> vector<2x3xi8> {
+ %0 = vector.shape_cast %arg : vector<6xi8> to vector<3x2xi8>
+ %1 = vector.transpose %0, [1, 0] : vector<3x2xi8> to vector<2x3xi8>
+ return %1 : vector<2x3xi8>
+}
>From 2e9768890c90f0d4e521c2c5f922f7b456c039c1 Mon Sep 17 00:00:00 2001
From: James Newling <james.newling at gmail.com>
Date: Mon, 28 Apr 2025 09:41:42 -0700
Subject: [PATCH 2/4] 'unit dim' is the canonical term for size-1 dims
---
mlir/lib/Dialect/Vector/IR/VectorOps.cpp | 8 ++++----
.../Dialect/Vector/canonicalize/vector-transpose.mlir | 4 ++--
2 files changed, 6 insertions(+), 6 deletions(-)
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 2e5fc70afa4f7..0ac08edeb1f6c 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -5596,10 +5596,10 @@ LogicalResult ShapeCastOp::verify() {
namespace {
-/// Return true if `transpose` does not permute a pair of dimensions that are
-/// both not of size 1. By `order preserving` we mean that the flattened
-/// versions of the input and output vectors are (numerically) identical.
-/// In other words `transpose` is effectively a shape cast.
+/// Return true if `transpose` does not permute a pair of non-unit dims.
+/// By `order preserving` we mean that the flattened versions of the input and
+/// output vectors are (numerically) identical. In other words `transpose` is
+/// effectively a shape cast.
bool isOrderPreserving(TransposeOp transpose) {
ArrayRef<int64_t> permutation = transpose.getPermutation();
ArrayRef<int64_t> inShape = transpose.getSourceVectorType().getShape();
diff --git a/mlir/test/Dialect/Vector/canonicalize/vector-transpose.mlir b/mlir/test/Dialect/Vector/canonicalize/vector-transpose.mlir
index 322309a559aa0..9dee9fdcd3b1d 100644
--- a/mlir/test/Dialect/Vector/canonicalize/vector-transpose.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize/vector-transpose.mlir
@@ -140,7 +140,7 @@ func.func @negative_broadcast_transpose_021(%arg0 : vector<3x1x3xi8>) -> vector<
// -----
-// In this test, the permutation maps the non-one dimensions (1 and 2) as follows:
+// In this test, the permutation maps the non-unit dimensions (1 and 2) as follows:
// 1 -> 0
// 2 -> 4
// Because 0 < 4, this permutation is order preserving and effectively a shape_cast.
@@ -158,7 +158,7 @@ func.func @transpose_shape_cast(%arg : vector<1x4x4x1x1xi8>) -> vector<4x4xi8> {
// -----
-// In this test, the mapping of non-one indices (1 and 2) is as follows:
+// In this test, the mapping of non-unit dimensions (1 and 2) is as follows:
// 1 -> 2
// 2 -> 1
// As this is not increasing (2 > 1), this transpose is not order
>From 13ba9a99944c7342fc5d316e69fd074962484a82 Mon Sep 17 00:00:00 2001
From: James Newling <james.newling at gmail.com>
Date: Wed, 30 Apr 2025 20:22:28 -0700
Subject: [PATCH 3/4] leave a TODO note, don't blanket avoid scalable vectors,
add scalable vector tests
---
mlir/lib/Dialect/Vector/IR/VectorOps.cpp | 14 ++++++++--
.../Vector/canonicalize/vector-transpose.mlir | 26 +++++++++++++++++++
2 files changed, 38 insertions(+), 2 deletions(-)
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 0ac08edeb1f6c..c77064ecaa359 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -5655,8 +5655,20 @@ OpFoldResult ShapeCastOp::fold(FoldAdaptor adaptor) {
// shape_cast(transpose(x)) -> shape_cast(x)
if (auto transpose = getSource().getDefiningOp<TransposeOp>()) {
+ // This folder does
+ // shape_cast(transpose) -> shape_cast
+ // But another pattern, ConvertIllegalShapeCastOpsToTransposes, does
+ // shape_cast -> shape_cast(transpose)
+ // i.e. the complete opposite. When paired, these 2 patterns can cause
+ // infinite cycles in pattern rewriting.
+ // ConvertIllegalShapeCastOpsToTransposes only matches on scalable
+ // vectors, so by disabling this folder for scalable vectors the
+ // cycle is avoided.
+ // TODO: Check if ConvertIllegalShapeCastOpsToTransposes is
+ // still needed. If it's not, then we can fold here.
if (transpose.getType().isScalable())
return {};
+
if (isOrderPreserving(transpose)) {
TypedValue<VectorType> source = transpose.getVector();
return maybeFold(source);
@@ -6168,8 +6180,6 @@ class FoldTransposeShapeCast final : public OpRewritePattern<TransposeOp> {
return failure();
if (!isOrderPreserving(transposeOp))
return failure();
- if (transposeOp.getType().isScalable())
- return failure();
VectorType resultType = transposeOp.getType();
diff --git a/mlir/test/Dialect/Vector/canonicalize/vector-transpose.mlir b/mlir/test/Dialect/Vector/canonicalize/vector-transpose.mlir
index 9dee9fdcd3b1d..5b1ed641b97b6 100644
--- a/mlir/test/Dialect/Vector/canonicalize/vector-transpose.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize/vector-transpose.mlir
@@ -177,6 +177,32 @@ func.func @negative_transpose_shape_cast(%arg : vector<1x4x4x1xi8>) -> vector<4x
// -----
+// Currently the conversion shape_cast(transpose) -> shape_cast) is disabled for
+// scalable vectors because of bad interaction with ConvertIllegalShapeCastOpsToTransposes
+// CHECK-LABEL: @negative_transpose_shape_cast_scalable
+// CHECK: vector.transpose
+// CHECK: vector.shape_cast
+func.func @negative_transpose_shape_cast_scalable(%arg : vector<[4]x1xi8>) -> vector<[4]xi8> {
+ %0 = vector.transpose %arg, [1, 0] : vector<[4]x1xi8> to vector<1x[4]xi8>
+ %1 = vector.shape_cast %0 : vector<1x[4]xi8> to vector<[4]xi8>
+ return %1 : vector<[4]xi8>
+}
+
+// -----
+
+// The conversion transpose(shape_cast) -> shape_cast is not disabled for scalable
+// vectors: here the transpose only moves a (non-scalable) unit dim.
+// CHECK-LABEL: @shape_cast_transpose_scalable
+// CHECK: vector.shape_cast
+// CHECK-SAME: vector<[4]xi8> to vector<[4]x1xi8>
+func.func @shape_cast_transpose_scalable(%arg : vector<[4]xi8>) -> vector<[4]x1xi8> {
+ %0 = vector.shape_cast %arg : vector<[4]xi8> to vector<1x[4]xi8>
+ %1 = vector.transpose %0, [1, 0] : vector<1x[4]xi8> to vector<[4]x1xi8>
+ return %1 : vector<[4]x1xi8>
+}
+
+// -----
+
// CHECK-LABEL: @shape_cast_transpose
// CHECK-SAME: %[[ARG:.*]]: vector<2x3x1x1xi8>) -> vector<6x1x1xi8> {
// CHECK: %[[SHAPE_CAST:.*]] = vector.shape_cast %[[ARG]] :
>From 16edc533e09d7cbf60fa1a892237a5e07ace3245 Mon Sep 17 00:00:00 2001
From: James Newling <james.newling at gmail.com>
Date: Thu, 1 May 2025 09:51:18 -0700
Subject: [PATCH 4/4] scalable dims are not unit dims
---
mlir/lib/Dialect/Vector/IR/VectorOps.cpp | 9 +++++++--
.../Vector/canonicalize/vector-transpose.mlir | 14 +++++++++++++-
2 files changed, 20 insertions(+), 3 deletions(-)
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index c77064ecaa359..ed014583b95c1 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -5602,10 +5602,15 @@ namespace {
/// effectively a shape cast.
bool isOrderPreserving(TransposeOp transpose) {
ArrayRef<int64_t> permutation = transpose.getPermutation();
- ArrayRef<int64_t> inShape = transpose.getSourceVectorType().getShape();
+ VectorType sourceType = transpose.getSourceVectorType();
+ ArrayRef<int64_t> inShape = sourceType.getShape();
+ ArrayRef<bool> inDimIsScalable = sourceType.getScalableDims();
+ auto isNonScalableUnitDim = [&](int64_t dim) {
+ return inShape[dim] == 1 && !inDimIsScalable[dim];
+ };
int64_t current = 0;
for (auto p : permutation) {
- if (inShape[p] != 1) {
+ if (!isNonScalableUnitDim(p)) {
if (p < current) {
return false;
}
diff --git a/mlir/test/Dialect/Vector/canonicalize/vector-transpose.mlir b/mlir/test/Dialect/Vector/canonicalize/vector-transpose.mlir
index 5b1ed641b97b6..604a633a84fd6 100644
--- a/mlir/test/Dialect/Vector/canonicalize/vector-transpose.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize/vector-transpose.mlir
@@ -177,7 +177,7 @@ func.func @negative_transpose_shape_cast(%arg : vector<1x4x4x1xi8>) -> vector<4x
// -----
-// Currently the conversion shape_cast(transpose) -> shape_cast) is disabled for
+// Currently the conversion shape_cast(transpose) -> shape_cast is disabled for
// scalable vectors because of bad interaction with ConvertIllegalShapeCastOpsToTransposes
// CHECK-LABEL: @negative_transpose_shape_cast_scalable
// CHECK: vector.transpose
@@ -203,6 +203,18 @@ func.func @shape_cast_transpose_scalable(%arg : vector<[4]xi8>) -> vector<[4]x1x
// -----
+// A scalable [1] dim is not a unit dim, so the transpose must not be folded.
+// CHECK-LABEL: @negative_shape_cast_transpose_scalable_unit
+// CHECK: vector.shape_cast
+// CHECK: vector.transpose
+func.func @negative_shape_cast_transpose_scalable_unit(%arg : vector<[1]x4x1xi8>) -> vector<4x[1]xi8> {
+ %0 = vector.shape_cast %arg : vector<[1]x4x1xi8> to vector<[1]x4xi8>
+ %1 = vector.transpose %0, [1, 0] : vector<[1]x4xi8> to vector<4x[1]xi8>
+ return %1 : vector<4x[1]xi8>
+}
+
+// -----
+
// CHECK-LABEL: @shape_cast_transpose
// CHECK-SAME: %[[ARG:.*]]: vector<2x3x1x1xi8>) -> vector<6x1x1xi8> {
// CHECK: %[[SHAPE_CAST:.*]] = vector.shape_cast %[[ARG]] :
More information about the Mlir-commits
mailing list