[Mlir-commits] [mlir] [mlir][vector] Canonicalize/fold 'order preserving' transposes (PR #135841)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Tue Apr 15 18:01:17 PDT 2025


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir-vector

Author: James Newling (newling)

<details>
<summary>Changes</summary>

Handles the special case where a transpose doesn't permute any non-1 dimensions (and so is effectively a shape_cast) and is adjacent to a shape_cast that it can fold into. For example,

```
%1 = vector.transpose %0, [1, 0, 3, 2] : vector<4x1x1x6xf32> to vector<1x4x6x1xf32>
```

can be folded into an adjacent shape_cast. An alternative to this PR would be to canonicalize such transposes to shape_casts directly, but I think it'll be difficult getting consensus that shape_cast is 'more canonical' than transpose, so this PR compromises with the less opinionated claim that

1) shape_cast is more canonical than shape_cast(transpose)
2) shape_cast is more canonical than transpose(shape_cast)

The pattern `ConvertIllegalShapeCastOpsToTransposes`, which is specific to transposes with scalable dimensions, reverses the canonicalization added here, so I've disabled this canonicalization for scalable vectors.

---
Full diff: https://github.com/llvm/llvm-project/pull/135841.diff


2 Files Affected:

- (modified) mlir/lib/Dialect/Vector/IR/VectorOps.cpp (+79-9) 
- (modified) mlir/test/Dialect/Vector/canonicalize.mlir (+64) 


``````````diff
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index bee5c1fd6ed58..5da0ef0af032f 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -5621,6 +5621,29 @@ LogicalResult ShapeCastOp::verify() {
   return success();
 }
 
+namespace {
+
+/// Return true if `transpose` does not permute a pair of dimensions that are
+/// both not of size 1. By `order preserving` we mean that the flattened
+/// versions of the input and output vectors are (numerically) identical.
+/// In other words `transpose` is effectively a shape cast.
+bool isOrderPreserving(TransposeOp transpose) {
+  ArrayRef<int64_t> permutation = transpose.getPermutation();
+  ArrayRef<int64_t> inShape = transpose.getSourceVectorType().getShape();
+  int64_t current = 0;
+  for (auto p : permutation) {
+    if (inShape[p] != 1) {
+      if (p < current) {
+        return false;
+      }
+      current = p;
+    }
+  }
+  return true;
+}
+
+} // namespace
+
 OpFoldResult ShapeCastOp::fold(FoldAdaptor adaptor) {
 
   // No-op shape cast.
@@ -5629,13 +5652,15 @@ OpFoldResult ShapeCastOp::fold(FoldAdaptor adaptor) {
 
   VectorType resultType = getType();
 
-  // Canceling shape casts.
-  if (auto otherOp = getSource().getDefiningOp<ShapeCastOp>()) {
-
-    // Only allows valid transitive folding (expand/collapse dimensions).
-    VectorType srcType = otherOp.getSource().getType();
+  // shape_cast(something(x)) -> x, or
+  //                          -> shape_cast(x).
+  //
+  // Confirms that a new shape_cast will have valid semantics (expands OR
+  // collapses dimensions).
+  auto maybeFold = [&](TypedValue<VectorType> source) -> OpFoldResult {
+    VectorType srcType = source.getType();
     if (resultType == srcType)
-      return otherOp.getSource();
+      return source;
     if (srcType.getRank() < resultType.getRank()) {
       if (!isValidShapeCast(srcType.getShape(), resultType.getShape()))
         return {};
@@ -5645,8 +5670,25 @@ OpFoldResult ShapeCastOp::fold(FoldAdaptor adaptor) {
     } else {
       return {};
     }
-    setOperand(otherOp.getSource());
+    setOperand(source);
     return getResult();
+  };
+
+  // Canceling shape casts.
+  if (auto otherOp = getSource().getDefiningOp<ShapeCastOp>()) {
+    TypedValue<VectorType> source = otherOp.getSource();
+    return maybeFold(source);
+  }
+
+  // shape_cast(transpose(x)) -> shape_cast(x)
+  if (auto transpose = getSource().getDefiningOp<TransposeOp>()) {
+    if (transpose.getType().isScalable())
+      return {};
+    if (isOrderPreserving(transpose)) {
+      TypedValue<VectorType> source = transpose.getVector();
+      return maybeFold(source);
+    }
+    return {};
   }
 
   // Cancelling broadcast and shape cast ops.
@@ -5675,7 +5717,7 @@ namespace {
 /// Helper function that computes a new vector type based on the input vector
 /// type by removing the trailing one dims:
 ///
-///   vector<4x1x1xi1> --> vector<4x1>
+///   vector<4x1x1xi1> --> vector<4x1xi1>
 ///
 static VectorType trimTrailingOneDims(VectorType oldType) {
   ArrayRef<int64_t> oldShape = oldType.getShape();
@@ -6161,12 +6203,40 @@ class FoldTransposeCreateMask final : public OpRewritePattern<TransposeOp> {
   }
 };
 
+/// Folds transpose(shape_cast) into a new shape_cast.
+class FoldTransposeShapeCast final : public OpRewritePattern<TransposeOp> {
+public:
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(TransposeOp transposeOp,
+                                PatternRewriter &rewriter) const override {
+    auto shapeCastOp =
+        transposeOp.getVector().getDefiningOp<vector::ShapeCastOp>();
+    if (!shapeCastOp)
+      return failure();
+    if (!isOrderPreserving(transposeOp))
+      return failure();
+    if (transposeOp.getType().isScalable())
+      return failure();
+
+    VectorType resultType = transposeOp.getType();
+
+    // We don't need to check isValidShapeCast at this point, because it is
+    // guaranteed that merging the transpose into the the shape_cast is a valid
+    // shape_cast, because the transpose just inserts/removes ones.
+
+    rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(transposeOp, resultType,
+                                                     shapeCastOp.getSource());
+    return success();
+  }
+};
 } // namespace
 
 void vector::TransposeOp::getCanonicalizationPatterns(
     RewritePatternSet &results, MLIRContext *context) {
   results.add<FoldTransposeCreateMask, FoldTransposedScalarBroadcast,
-              TransposeFolder, FoldTransposeSplat>(context);
+              FoldTransposeShapeCast, TransposeFolder, FoldTransposeSplat>(
+      context);
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir
index 78b0ea78849e8..10144cb9034e4 100644
--- a/mlir/test/Dialect/Vector/canonicalize.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize.mlir
@@ -3295,3 +3295,67 @@ func.func @fold_insert_constant_indices(%arg : vector<4x1xi32>) -> vector<4x1xi3
   %res = vector.insert %1, %arg[%0, %0] : i32 into vector<4x1xi32>
   return %res : vector<4x1xi32>
 }
+
+// -----
+
+// In this test, the permutation maps the non-one dimensions (1 and 2) as follows:
+// 1 -> 0
+// 2 -> 4
+// Because 0 < 4, this permutation is order preserving and effectively a shape_cast.
+// CHECK-LABEL: @transpose_shape_cast
+//  CHECK-SAME:   %[[ARG:.*]]: vector<1x4x4x1x1xi8>) -> vector<4x4xi8> {
+//       CHECK:   %[[SHAPE_CAST:.*]] = vector.shape_cast %[[ARG]] :
+//  CHECK-SAME:   vector<1x4x4x1x1xi8> to vector<4x4xi8>
+//       CHECK:   return %[[SHAPE_CAST]] : vector<4x4xi8>
+func.func @transpose_shape_cast(%arg : vector<1x4x4x1x1xi8>) -> vector<4x4xi8> {
+  %0 = vector.transpose %arg, [1, 0, 3, 4, 2]
+     : vector<1x4x4x1x1xi8> to vector<4x1x1x1x4xi8>
+  %1 = vector.shape_cast %0 : vector<4x1x1x1x4xi8> to vector<4x4xi8>
+  return %1 : vector<4x4xi8>
+}
+
+// -----
+
+// In this test, the mapping of non-one indices (1 and 2) is as follows:
+// 1 -> 2
+// 2 -> 1
+// As this is not increasing (2 > 1), this transpose is not order
+// preserving and cannot be treated as a shape_cast.
+// CHECK-LABEL: @negative_transpose_shape_cast
+//  CHECK-SAME:   %[[ARG:.*]]: vector<1x4x4x1xi8>) -> vector<4x4xi8> {
+//       CHECK:   %[[TRANSPOSE:.*]] = vector.transpose %[[ARG]]
+//       CHECK:   %[[SHAPE_CAST:.*]] = vector.shape_cast %[[TRANSPOSE]]
+//       CHECK:   return %[[SHAPE_CAST]] : vector<4x4xi8>
+func.func @negative_transpose_shape_cast(%arg : vector<1x4x4x1xi8>) -> vector<4x4xi8> {
+  %0 = vector.transpose %arg, [0, 2, 1, 3]
+     : vector<1x4x4x1xi8> to vector<1x4x4x1xi8>
+  %1 = vector.shape_cast %0 : vector<1x4x4x1xi8> to vector<4x4xi8>
+  return %1 : vector<4x4xi8>
+}
+
+// -----
+
+// CHECK-LABEL: @shape_cast_transpose
+//  CHECK-SAME:   %[[ARG:.*]]: vector<2x3x1x1xi8>) -> vector<6x1x1xi8> {
+//       CHECK:   %[[SHAPE_CAST:.*]] = vector.shape_cast %[[ARG]] :
+//  CHECK-SAME:   vector<2x3x1x1xi8> to vector<6x1x1xi8>
+//       CHECK:   return %[[SHAPE_CAST]] : vector<6x1x1xi8>
+func.func @shape_cast_transpose(%arg : vector<2x3x1x1xi8>) ->  vector<6x1x1xi8> {
+  %0 = vector.shape_cast %arg : vector<2x3x1x1xi8> to vector<6x1x1xi8>
+  %1 = vector.transpose %0, [0, 2, 1]
+     : vector<6x1x1xi8> to vector<6x1x1xi8>
+  return %1 : vector<6x1x1xi8>
+}
+
+// -----
+
+// CHECK-LABEL: @negative_shape_cast_transpose
+//  CHECK-SAME:   %[[ARG:.*]]: vector<6xi8>) -> vector<2x3xi8> {
+//       CHECK:   %[[SHAPE_CAST:.*]] = vector.shape_cast %[[ARG]] :
+//       CHECK:   %[[TRANSPOSE:.*]] = vector.transpose %[[SHAPE_CAST]]
+//       CHECK:   return %[[TRANSPOSE]] : vector<2x3xi8>
+func.func @negative_shape_cast_transpose(%arg : vector<6xi8>) -> vector<2x3xi8> {
+  %0 = vector.shape_cast %arg : vector<6xi8> to vector<3x2xi8>
+  %1 = vector.transpose %0, [1, 0] : vector<3x2xi8> to vector<2x3xi8>
+  return %1 : vector<2x3xi8>
+}

``````````

</details>


https://github.com/llvm/llvm-project/pull/135841


More information about the Mlir-commits mailing list