[Mlir-commits] [mlir] [mlir][vector] Additional transpose folding (PR #138347)
James Newling
llvmlistbot at llvm.org
Wed May 14 12:26:02 PDT 2025
https://github.com/newling updated https://github.com/llvm/llvm-project/pull/138347
>From e5ca0d8297d0c2420f63945119be1ae9daf09cbf Mon Sep 17 00:00:00 2001
From: James Newling <james.newling at gmail.com>
Date: Fri, 2 May 2025 14:11:29 -0700
Subject: [PATCH 1/4] catch additional foldable case
---
mlir/lib/Dialect/Vector/IR/VectorOps.cpp | 32 +++++++++---------
mlir/test/Dialect/Vector/canonicalize.mlir | 22 -------------
.../Vector/canonicalize/vector-transpose.mlir | 33 +++++++++++++++++++
3 files changed, 49 insertions(+), 38 deletions(-)
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index f6c3c6a61afb6..79bf87ccd34af 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -5573,13 +5573,11 @@ LogicalResult ShapeCastOp::verify() {
return success();
}
-namespace {
-
/// Return true if `transpose` does not permute a pair of non-unit dims.
/// By `order preserving` we mean that the flattened versions of the input and
/// output vectors are (numerically) identical. In other words `transpose` is
/// effectively a shape cast.
-bool isOrderPreserving(TransposeOp transpose) {
+static bool isOrderPreserving(TransposeOp transpose) {
ArrayRef<int64_t> permutation = transpose.getPermutation();
VectorType sourceType = transpose.getSourceVectorType();
ArrayRef<int64_t> inShape = sourceType.getShape();
@@ -5599,8 +5597,6 @@ bool isOrderPreserving(TransposeOp transpose) {
return true;
}
-} // namespace
-
OpFoldResult ShapeCastOp::fold(FoldAdaptor adaptor) {
VectorType resultType = getType();
@@ -5997,18 +5993,22 @@ OpFoldResult vector::TransposeOp::fold(FoldAdaptor adaptor) {
if (llvm::dyn_cast_if_present<ub::PoisonAttr>(adaptor.getVector()))
return ub::PoisonAttr::get(getContext());
- // Eliminate identity transpose ops. This happens when the dimensions of the
- // input vector remain in their original order after the transpose operation.
- ArrayRef<int64_t> perm = getPermutation();
-
- // Check if the permutation of the dimensions contains sequential values:
- // {0, 1, 2, ...}.
- for (int64_t i = 0, e = perm.size(); i < e; i++) {
- if (perm[i] != i)
- return {};
- }
+ // Eliminate identity transposes, and more generally any transposes that
+ // preserve the shape without permuting elements.
+ //
+ // Examples of what to fold:
+ // %0 = vector.transpose %arg, [0, 1] : vector<1x1xi8> to vector<1x1xi8>
+ // %0 = vector.transpose %arg, [0, 1] : vector<2x2xi8> to vector<2x2xi8>
+ // %0 = vector.transpose %arg, [1, 0] : vector<1x1xi8> to vector<1x1xi8>
+ //
+ // Example of what NOT to fold:
+ // %0 = vector.transpose %arg, [1, 0] : vector<2x2xi8> to vector<2x2xi8>
+ //
+ if (getSourceVectorType() == getResultVectorType() &&
+ isOrderPreserving(*this))
+ return getVector();
- return getVector();
+ return {};
}
LogicalResult vector::TransposeOp::verify() {
diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir
index 99f0850000a16..974f4506a2ef0 100644
--- a/mlir/test/Dialect/Vector/canonicalize.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize.mlir
@@ -450,28 +450,6 @@ func.func @extract_strided_fold_insert(%a: vector<2x8xf32>, %b: vector<1x4xf32>,
// -----
-// CHECK-LABEL: transpose_1D_identity
-// CHECK-SAME: ([[ARG:%.*]]: vector<4xf32>)
-func.func @transpose_1D_identity(%arg : vector<4xf32>) -> vector<4xf32> {
- // CHECK-NOT: transpose
- %0 = vector.transpose %arg, [0] : vector<4xf32> to vector<4xf32>
- // CHECK-NEXT: return [[ARG]]
- return %0 : vector<4xf32>
-}
-
-// -----
-
-// CHECK-LABEL: transpose_2D_identity
-// CHECK-SAME: ([[ARG:%.*]]: vector<4x3xf32>)
-func.func @transpose_2D_identity(%arg : vector<4x3xf32>) -> vector<4x3xf32> {
- // CHECK-NOT: transpose
- %0 = vector.transpose %arg, [0, 1] : vector<4x3xf32> to vector<4x3xf32>
- // CHECK-NEXT: return [[ARG]]
- return %0 : vector<4x3xf32>
-}
-
-// -----
-
// CHECK-LABEL: transpose_3D_identity
// CHECK-SAME: ([[ARG:%.*]]: vector<4x3x2xf32>)
func.func @transpose_3D_identity(%arg : vector<4x3x2xf32>) -> vector<4x3x2xf32> {
diff --git a/mlir/test/Dialect/Vector/canonicalize/vector-transpose.mlir b/mlir/test/Dialect/Vector/canonicalize/vector-transpose.mlir
index 7d8daec4dcba7..0a1b4e05bb118 100644
--- a/mlir/test/Dialect/Vector/canonicalize/vector-transpose.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize/vector-transpose.mlir
@@ -248,3 +248,36 @@ func.func @negative_transpose_of_shape_cast(%arg : vector<6xi8>) -> vector<2x3xi
%1 = vector.transpose %0, [1, 0] : vector<3x2xi8> to vector<2x3xi8>
return %1 : vector<2x3xi8>
}
+
+// -----
+
+// Test of transpose folding
+// CHECK-LABEL: transpose_1D_identity
+// CHECK-SAME: [[ARG:%.*]]: vector<4xf32>
+// CHECK-NEXT: return [[ARG]]
+func.func @transpose_1D_identity(%arg : vector<4xf32>) -> vector<4xf32> {
+ %0 = vector.transpose %arg, [0] : vector<4xf32> to vector<4xf32>
+ return %0 : vector<4xf32>
+}
+
+// -----
+
+// Test of transpose folding
+// CHECK-LABEL: transpose_2D_identity
+// CHECK-SAME: [[ARG:%.*]]: vector<4x3xf32>
+// CHECK-NEXT: return [[ARG]]
+func.func @transpose_2D_identity(%arg : vector<4x3xf32>) -> vector<4x3xf32> {
+ %0 = vector.transpose %arg, [0, 1] : vector<4x3xf32> to vector<4x3xf32>
+ return %0 : vector<4x3xf32>
+}
+
+// -----
+
+// Test of transpose folding
+// CHECK-LABEL: transpose_shape_and_order_preserving
+// CHECK-SAME: [[ARG:%.*]]: vector<6x1x1x4xi8>
+// CHECK-NEXT: return [[ARG]]
+func.func @transpose_shape_and_order_preserving(%arg : vector<6x1x1x4xi8>) -> vector<6x1x1x4xi8> {
+ %0 = vector.transpose %arg, [0, 2, 1, 3] : vector<6x1x1x4xi8> to vector<6x1x1x4xi8>
+ return %0 : vector<6x1x1x4xi8>
+}
>From 1f21d4d71aa153f172b50978a4a00de698be1948 Mon Sep 17 00:00:00 2001
From: James Newling <james.newling at gmail.com>
Date: Mon, 5 May 2025 10:33:38 -0700
Subject: [PATCH 2/4] Update test with transpose that is now folded
Signed-off-by: James Newling <james.newling at gmail.com>
---
.../Vector/vector-transpose-lowering.mlir | 16 +++++++++-------
1 file changed, 9 insertions(+), 7 deletions(-)
diff --git a/mlir/test/Dialect/Vector/vector-transpose-lowering.mlir b/mlir/test/Dialect/Vector/vector-transpose-lowering.mlir
index 83395504e8c74..a730f217f027d 100644
--- a/mlir/test/Dialect/Vector/vector-transpose-lowering.mlir
+++ b/mlir/test/Dialect/Vector/vector-transpose-lowering.mlir
@@ -65,13 +65,15 @@ func.func @transpose102_8x1x8xf32(%arg0: vector<8x1x8xf32>) -> vector<1x8x8xf32>
return %0 : vector<1x8x8xf32>
}
-// CHECK-LABEL: func @transpose1023_1x1x8x8xf32(
-func.func @transpose1023_1x1x8x8xf32(%arg0: vector<1x1x8x8xf32>) -> vector<1x1x8x8xf32> {
- // Note the single 2-D extract/insert pair since 2 and 3 are not transposed!
- // CHECK: vector.extract {{.*}}[0, 0] : vector<8x8xf32> from vector<1x1x8x8xf32>
- // CHECK-NEXT: vector.insert {{.*}} [0, 0] : vector<8x8xf32> into vector<1x1x8x8xf32>
- %0 = vector.transpose %arg0, [1, 0, 2, 3] : vector<1x1x8x8xf32> to vector<1x1x8x8xf32>
- return %0 : vector<1x1x8x8xf32>
+// CHECK-LABEL: func @transpose1023_2x1x8x4xf32(
+func.func @transpose1023_2x1x8x4xf32(%arg0: vector<2x1x8x4xf32>) -> vector<1x2x8x4xf32> {
+ // Note the 2-D extract/insert pairs, since dimensions 2 and 3 are not transposed!
+ // CHECK: vector.extract {{.*}}[0, 0] : vector<8x4xf32> from vector<2x1x8x4xf32>
+ // CHECK-NEXT: vector.insert {{.*}} [0, 0] : vector<8x4xf32> into vector<1x2x8x4xf32>
+ // CHECK-NEXT: vector.extract {{.*}}[1, 0] : vector<8x4xf32> from vector<2x1x8x4xf32>
+ // CHECK-NEXT: vector.insert {{.*}} [0, 1] : vector<8x4xf32> into vector<1x2x8x4xf32>
+ %0 = vector.transpose %arg0, [1, 0, 2, 3] : vector<2x1x8x4xf32> to vector<1x2x8x4xf32>
+ return %0 : vector<1x2x8x4xf32>
}
/// Scalable dim should not be unrolled.
>From 2a2af7c77a4d7cd5e3eb4690ee468fcc074b586a Mon Sep 17 00:00:00 2001
From: James Newling <james.newling at gmail.com>
Date: Mon, 12 May 2025 11:11:23 -0700
Subject: [PATCH 3/4] better grouping of tests
---
.../Vector/canonicalize/vector-transpose.mlir | 23 +++++++++++++++----
1 file changed, 19 insertions(+), 4 deletions(-)
diff --git a/mlir/test/Dialect/Vector/canonicalize/vector-transpose.mlir b/mlir/test/Dialect/Vector/canonicalize/vector-transpose.mlir
index 0a1b4e05bb118..e65f05594bb2a 100644
--- a/mlir/test/Dialect/Vector/canonicalize/vector-transpose.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize/vector-transpose.mlir
@@ -1,6 +1,10 @@
// RUN: mlir-opt %s -canonicalize="test-convergence" -split-input-file -allow-unregistered-dialect | FileCheck %s
-// This file contains some canonicalizations tests involving vector.transpose.
+// This file contains some tests of canonicalizations and foldings involving vector.transpose.
+
+// +----------------------------------------
+// Tests of FoldTransposeBroadcast
+// +----------------------------------------
// CHECK-LABEL: func @transpose_scalar_broadcast1
// CHECK-SAME: (%[[ARG:.+]]: vector<1xf32>)
@@ -251,7 +255,10 @@ func.func @negative_transpose_of_shape_cast(%arg : vector<6xi8>) -> vector<2x3xi
// -----
-// Test of transpose folding
+// +-----------------------------------
+// Tests of transpose folding
+// +-----------------------------------
+
// CHECK-LABEL: transpose_1D_identity
// CHECK-SAME: [[ARG:%.*]]: vector<4xf32>
// CHECK-NEXT: return [[ARG]]
@@ -262,7 +269,6 @@ func.func @transpose_1D_identity(%arg : vector<4xf32>) -> vector<4xf32> {
// -----
-// Test of transpose folding
// CHECK-LABEL: transpose_2D_identity
// CHECK-SAME: [[ARG:%.*]]: vector<4x3xf32>
// CHECK-NEXT: return [[ARG]]
@@ -273,7 +279,6 @@ func.func @transpose_2D_identity(%arg : vector<4x3xf32>) -> vector<4x3xf32> {
// -----
-// Test of transpose folding
// CHECK-LABEL: transpose_shape_and_order_preserving
// CHECK-SAME: [[ARG:%.*]]: vector<6x1x1x4xi8>
// CHECK-NEXT: return [[ARG]]
@@ -281,3 +286,13 @@ func.func @transpose_shape_and_order_preserving(%arg : vector<6x1x1x4xi8>) -> ve
%0 = vector.transpose %arg, [0, 2, 1, 3] : vector<6x1x1x4xi8> to vector<6x1x1x4xi8>
return %0 : vector<6x1x1x4xi8>
}
+
+// -----
+
+// CHECK-LABEL: negative_transpose_fold
+// CHECK: [[TRANSP:%.*]] = vector.transpose
+// CHECK: return [[TRANSP]]
+func.func @negative_transpose_fold(%arg : vector<2x2xi8>) -> vector<2x2xi8> {
+ %0 = vector.transpose %arg, [1, 0] : vector<2x2xi8> to vector<2x2xi8>
+ return %0 : vector<2x2xi8>
+}
>From d92bf2c4a409c9e0b442d7aeeb7a8ec5783d4d55 Mon Sep 17 00:00:00 2001
From: James Newling <james.newling at gmail.com>
Date: Wed, 14 May 2025 12:26:07 -0700
Subject: [PATCH 4/4] test grouping polish
---
.../Dialect/Vector/canonicalize/vector-transpose.mlir | 8 ++++----
1 file changed, 4 insertions(+), 4 deletions(-)
diff --git a/mlir/test/Dialect/Vector/canonicalize/vector-transpose.mlir b/mlir/test/Dialect/Vector/canonicalize/vector-transpose.mlir
index e65f05594bb2a..c84aea6609665 100644
--- a/mlir/test/Dialect/Vector/canonicalize/vector-transpose.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize/vector-transpose.mlir
@@ -2,9 +2,9 @@
// This file contains some tests of canonicalizations and foldings involving vector.transpose.
-// +----------------------------------------
-// Tests of FoldTransposeBroadcast
-// +----------------------------------------
+// +---------------------------------------------------------------------------
+// Tests of FoldTransposeBroadcast: transpose(broadcast) -> broadcast
+// +---------------------------------------------------------------------------
// CHECK-LABEL: func @transpose_scalar_broadcast1
// CHECK-SAME: (%[[ARG:.+]]: vector<1xf32>)
@@ -256,7 +256,7 @@ func.func @negative_transpose_of_shape_cast(%arg : vector<6xi8>) -> vector<2x3xi
// -----
// +-----------------------------------
-// Tests of transpose folding
+// Tests of TransposeOp::fold
// +-----------------------------------
// CHECK-LABEL: transpose_1D_identity
More information about the Mlir-commits
mailing list