[Mlir-commits] [mlir] 21f1a61 - [mlir][vector] Additional transpose folding (#138347)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Wed May 14 14:56:39 PDT 2025


Author: James Newling
Date: 2025-05-14T14:56:35-07:00
New Revision: 21f1a616d60934953de3f304aafcad968e770b0a

URL: https://github.com/llvm/llvm-project/commit/21f1a616d60934953de3f304aafcad968e770b0a
DIFF: https://github.com/llvm/llvm-project/commit/21f1a616d60934953de3f304aafcad968e770b0a.diff

LOG: [mlir][vector] Additional transpose folding (#138347)

Fold transpose with unit-dimensions. Seen in the wild:
```
 %0 = vector.transpose %arg, [0, 2, 1, 3] : vector<6x1x1x4xi8> to vector<6x1x1x4xi8>
```

This transpose can be folded because (1) it preserves the shape and (2)
the shuffled dims have unit extent.

Also addresses a review comment about using `static` instead of an anonymous namespace:
https://github.com/llvm/llvm-project/pull/135841#discussion_r2071869067

---------

Signed-off-by: James Newling <james.newling at gmail.com>

Added: 
    

Modified: 
    mlir/lib/Dialect/Vector/IR/VectorOps.cpp
    mlir/test/Dialect/Vector/canonicalize.mlir
    mlir/test/Dialect/Vector/canonicalize/vector-transpose.mlir
    mlir/test/Dialect/Vector/vector-transpose-lowering.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index f6c3c6a61afb6..79bf87ccd34af 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -5573,13 +5573,11 @@ LogicalResult ShapeCastOp::verify() {
   return success();
 }
 
-namespace {
-
 /// Return true if `transpose` does not permute a pair of non-unit dims.
 /// By `order preserving` we mean that the flattened versions of the input and
 /// output vectors are (numerically) identical. In other words `transpose` is
 /// effectively a shape cast.
-bool isOrderPreserving(TransposeOp transpose) {
+static bool isOrderPreserving(TransposeOp transpose) {
   ArrayRef<int64_t> permutation = transpose.getPermutation();
   VectorType sourceType = transpose.getSourceVectorType();
   ArrayRef<int64_t> inShape = sourceType.getShape();
@@ -5599,8 +5597,6 @@ bool isOrderPreserving(TransposeOp transpose) {
   return true;
 }
 
-} // namespace
-
 OpFoldResult ShapeCastOp::fold(FoldAdaptor adaptor) {
 
   VectorType resultType = getType();
@@ -5997,18 +5993,22 @@ OpFoldResult vector::TransposeOp::fold(FoldAdaptor adaptor) {
   if (llvm::dyn_cast_if_present<ub::PoisonAttr>(adaptor.getVector()))
     return ub::PoisonAttr::get(getContext());
 
-  // Eliminate identity transpose ops. This happens when the dimensions of the
-  // input vector remain in their original order after the transpose operation.
-  ArrayRef<int64_t> perm = getPermutation();
-
-  // Check if the permutation of the dimensions contains sequential values:
-  // {0, 1, 2, ...}.
-  for (int64_t i = 0, e = perm.size(); i < e; i++) {
-    if (perm[i] != i)
-      return {};
-  }
+  // Eliminate identity transposes, and more generally any transpose that
+  // preserves the shape without permuting elements.
+  //
+  // Examples of what to fold:
+  // %0 = vector.transpose %arg, [0, 1] : vector<1x1xi8> to vector<1x1xi8>
+  // %0 = vector.transpose %arg, [0, 1] : vector<2x2xi8> to vector<2x2xi8>
+  // %0 = vector.transpose %arg, [1, 0] : vector<1x1xi8> to vector<1x1xi8>
+  //
+  // Example of what NOT to fold:
+  // %0 = vector.transpose %arg, [1, 0] : vector<2x2xi8> to vector<2x2xi8>
+  //
+  if (getSourceVectorType() == getResultVectorType() &&
+      isOrderPreserving(*this))
+    return getVector();
 
-  return getVector();
+  return {};
 }
 
 LogicalResult vector::TransposeOp::verify() {

diff  --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir
index 99f0850000a16..974f4506a2ef0 100644
--- a/mlir/test/Dialect/Vector/canonicalize.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize.mlir
@@ -450,28 +450,6 @@ func.func @extract_strided_fold_insert(%a: vector<2x8xf32>, %b: vector<1x4xf32>,
 
 // -----
 
-// CHECK-LABEL: transpose_1D_identity
-// CHECK-SAME: ([[ARG:%.*]]: vector<4xf32>)
-func.func @transpose_1D_identity(%arg : vector<4xf32>) -> vector<4xf32> {
-  // CHECK-NOT: transpose
-  %0 = vector.transpose %arg, [0] : vector<4xf32> to vector<4xf32>
-  // CHECK-NEXT: return [[ARG]]
-  return %0 : vector<4xf32>
-}
-
-// -----
-
-// CHECK-LABEL: transpose_2D_identity
-// CHECK-SAME: ([[ARG:%.*]]: vector<4x3xf32>)
-func.func @transpose_2D_identity(%arg : vector<4x3xf32>) -> vector<4x3xf32> {
-  // CHECK-NOT: transpose
-  %0 = vector.transpose %arg, [0, 1] : vector<4x3xf32> to vector<4x3xf32>
-  // CHECK-NEXT: return [[ARG]]
-  return %0 : vector<4x3xf32>
-}
-
-// -----
-
 // CHECK-LABEL: transpose_3D_identity
 // CHECK-SAME: ([[ARG:%.*]]: vector<4x3x2xf32>)
 func.func @transpose_3D_identity(%arg : vector<4x3x2xf32>) -> vector<4x3x2xf32> {

diff  --git a/mlir/test/Dialect/Vector/canonicalize/vector-transpose.mlir b/mlir/test/Dialect/Vector/canonicalize/vector-transpose.mlir
index 7d8daec4dcba7..c84aea6609665 100644
--- a/mlir/test/Dialect/Vector/canonicalize/vector-transpose.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize/vector-transpose.mlir
@@ -1,6 +1,10 @@
 // RUN: mlir-opt %s -canonicalize="test-convergence" -split-input-file -allow-unregistered-dialect | FileCheck %s
 
-// This file contains some canonicalizations tests involving vector.transpose.
+// This file contains some tests of canonicalizations and foldings involving vector.transpose.
+
+// +---------------------------------------------------------------------------
+//  Tests of FoldTransposeBroadcast: transpose(broadcast) -> broadcast
+// +---------------------------------------------------------------------------
 
 // CHECK-LABEL: func @transpose_scalar_broadcast1
 //  CHECK-SAME: (%[[ARG:.+]]: vector<1xf32>)
@@ -248,3 +252,47 @@ func.func @negative_transpose_of_shape_cast(%arg : vector<6xi8>) -> vector<2x3xi
   %1 = vector.transpose %0, [1, 0] : vector<3x2xi8> to vector<2x3xi8>
   return %1 : vector<2x3xi8>
 }
+
+// -----
+
+// +-----------------------------------
+//  Tests of TransposeOp::fold
+// +-----------------------------------
+
+//  CHECK-LABEL: transpose_1D_identity
+//   CHECK-SAME:   [[ARG:%.*]]: vector<4xf32>
+//   CHECK-NEXT:   return [[ARG]]
+func.func @transpose_1D_identity(%arg : vector<4xf32>) -> vector<4xf32> {
+  %0 = vector.transpose %arg, [0] : vector<4xf32> to vector<4xf32>
+  return %0 : vector<4xf32>
+}
+
+// -----
+
+// CHECK-LABEL: transpose_2D_identity
+//  CHECK-SAME:    [[ARG:%.*]]: vector<4x3xf32>
+//  CHECK-NEXT:    return [[ARG]]
+func.func @transpose_2D_identity(%arg : vector<4x3xf32>) -> vector<4x3xf32> {
+  %0 = vector.transpose %arg, [0, 1] : vector<4x3xf32> to vector<4x3xf32>
+  return %0 : vector<4x3xf32>
+}
+
+// -----
+
+// CHECK-LABEL: transpose_shape_and_order_preserving
+//  CHECK-SAME:    [[ARG:%.*]]: vector<6x1x1x4xi8>
+//  CHECK-NEXT:    return [[ARG]]
+func.func @transpose_shape_and_order_preserving(%arg : vector<6x1x1x4xi8>) -> vector<6x1x1x4xi8> {
+  %0 = vector.transpose %arg, [0, 2, 1, 3] : vector<6x1x1x4xi8> to vector<6x1x1x4xi8>
+  return %0 : vector<6x1x1x4xi8>
+}
+
+// -----
+
+// CHECK-LABEL: negative_transpose_fold
+//       CHECK: [[TRANSP:%.*]] = vector.transpose
+//       CHECK: return [[TRANSP]]
+func.func @negative_transpose_fold(%arg : vector<2x2xi8>) -> vector<2x2xi8> {
+  %0 = vector.transpose %arg, [1, 0] : vector<2x2xi8> to vector<2x2xi8>
+  return %0 : vector<2x2xi8>
+}

diff  --git a/mlir/test/Dialect/Vector/vector-transpose-lowering.mlir b/mlir/test/Dialect/Vector/vector-transpose-lowering.mlir
index 83395504e8c74..a730f217f027d 100644
--- a/mlir/test/Dialect/Vector/vector-transpose-lowering.mlir
+++ b/mlir/test/Dialect/Vector/vector-transpose-lowering.mlir
@@ -65,13 +65,15 @@ func.func @transpose102_8x1x8xf32(%arg0: vector<8x1x8xf32>) -> vector<1x8x8xf32>
   return %0 : vector<1x8x8xf32>
 }
 
-// CHECK-LABEL:   func @transpose1023_1x1x8x8xf32(
-func.func @transpose1023_1x1x8x8xf32(%arg0: vector<1x1x8x8xf32>) -> vector<1x1x8x8xf32> {
-  // Note the single 2-D extract/insert pair since 2 and 3 are not transposed!
-  //      CHECK: vector.extract {{.*}}[0, 0] : vector<8x8xf32> from vector<1x1x8x8xf32>
-  // CHECK-NEXT: vector.insert {{.*}} [0, 0] : vector<8x8xf32> into vector<1x1x8x8xf32>
-  %0 = vector.transpose %arg0, [1, 0, 2, 3] : vector<1x1x8x8xf32> to vector<1x1x8x8xf32>
-  return %0 : vector<1x1x8x8xf32>
+// CHECK-LABEL:   func @transpose1023_2x1x8x4xf32(
+func.func @transpose1023_2x1x8x4xf32(%arg0: vector<2x1x8x4xf32>) -> vector<1x2x8x4xf32> {
+  // Note the 2-D extract/insert pairs, since dimensions 2 and 3 are not transposed!
+  //      CHECK: vector.extract {{.*}}[0, 0] : vector<8x4xf32> from vector<2x1x8x4xf32>
+  // CHECK-NEXT: vector.insert {{.*}} [0, 0] : vector<8x4xf32> into vector<1x2x8x4xf32>
+  // CHECK-NEXT: vector.extract {{.*}}[1, 0] : vector<8x4xf32> from vector<2x1x8x4xf32>
+  // CHECK-NEXT: vector.insert {{.*}} [0, 1] : vector<8x4xf32> into vector<1x2x8x4xf32>
+  %0 = vector.transpose %arg0, [1, 0, 2, 3] : vector<2x1x8x4xf32> to vector<1x2x8x4xf32>
+  return %0 : vector<1x2x8x4xf32>
 }
 
 /// Scalable dim should not be unrolled.


        


More information about the Mlir-commits mailing list