[Mlir-commits] [mlir] 1a44f38 - [mlir][vector] Canonicalize/fold 'order preserving' transposes (#135841)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Thu May 1 15:50:58 PDT 2025


Author: James Newling
Date: 2025-05-01T15:50:55-07:00
New Revision: 1a44f38d2af8724e9819f03d4b76a50615217a8d

URL: https://github.com/llvm/llvm-project/commit/1a44f38d2af8724e9819f03d4b76a50615217a8d
DIFF: https://github.com/llvm/llvm-project/commit/1a44f38d2af8724e9819f03d4b76a50615217a8d.diff

LOG: [mlir][vector] Canonicalize/fold 'order preserving' transposes (#135841)

Handles special case where transpose doesn't permute any non-1
dimensions (and so is effectively a shape_cast) and is adjacent to a
shape_cast that it can fold into. For example

```
%1 = vector.transpose %0, [1, 0, 3, 2] : vector<4x1x1x6xf32> to vector<1x4x6x1xf32>
```

can be folded into an adjacent shape_cast. An alternative to this PR
would be to canonicalize such transposes to shape_casts directly, but I
think it'll be difficult getting consensus that shape_cast is 'more
canonical' than transpose, so this PR compromises with the less
opinionated claim that

1) shape_cast is more canonical than shape_cast(transpose)
2) shape_cast is more canonical than transpose(shape_cast)

The pattern `ConvertIllegalShapeCastOpsToTransposes` that is specific to
transposes with scalable dimensions reverses the canonicalization added
here, so I've disabled this canonicalization for scalable vectors.

Added: 
    

Modified: 
    mlir/lib/Dialect/Vector/IR/VectorOps.cpp
    mlir/test/Dialect/Vector/canonicalize.mlir
    mlir/test/Dialect/Vector/canonicalize/vector-transpose.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index f47e356d6fe14..0f96442bc3756 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -5575,6 +5575,34 @@ LogicalResult ShapeCastOp::verify() {
   return success();
 }
 
+namespace {
+
+/// Return true if `transpose` does not permute a pair of non-unit dims.
+/// By `order preserving` we mean that the flattened versions of the input and
+/// output vectors are (numerically) identical. In other words `transpose` is
+/// effectively a shape cast.
+bool isOrderPreserving(TransposeOp transpose) {
+  ArrayRef<int64_t> permutation = transpose.getPermutation();
+  VectorType sourceType = transpose.getSourceVectorType();
+  ArrayRef<int64_t> inShape = sourceType.getShape();
+  ArrayRef<bool> inDimIsScalable = sourceType.getScalableDims();
+  auto isNonScalableUnitDim = [&](int64_t dim) {
+    return inShape[dim] == 1 && !inDimIsScalable[dim];
+  };
+  int64_t current = 0;
+  for (auto p : permutation) {
+    if (!isNonScalableUnitDim(p)) {
+      if (p < current) {
+        return false;
+      }
+      current = p;
+    }
+  }
+  return true;
+}
+
+} // namespace
+
 OpFoldResult ShapeCastOp::fold(FoldAdaptor adaptor) {
 
   VectorType resultType = getType();
@@ -5583,17 +5611,32 @@ OpFoldResult ShapeCastOp::fold(FoldAdaptor adaptor) {
   if (getSource().getType() == resultType)
     return getSource();
 
-  // Y = shape_cast(shape_cast(X)))
-  //      -> X, if X and Y have same type
-  //      -> shape_cast(X) otherwise.
-  if (auto otherOp = getSource().getDefiningOp<ShapeCastOp>()) {
-    VectorType srcType = otherOp.getSource().getType();
-    if (resultType == srcType)
-      return otherOp.getSource();
-    setOperand(otherOp.getSource());
+  // shape_cast(shape_cast(x)) -> shape_cast(x)
+  if (auto precedingShapeCast = getSource().getDefiningOp<ShapeCastOp>()) {
+    setOperand(precedingShapeCast.getSource());
     return getResult();
   }
 
+  // shape_cast(transpose(x)) -> shape_cast(x)
+  if (auto transpose = getSource().getDefiningOp<TransposeOp>()) {
+    // This folder does
+    //    shape_cast(transpose) -> shape_cast
+    // But another pattern, ConvertIllegalShapeCastOpsToTransposes, does
+    //    shape_cast -> shape_cast(transpose)
+    // i.e. the complete opposite. When paired, these 2 patterns can cause
+    // infinite cycles in pattern rewriting.
+    // ConvertIllegalShapeCastOpsToTransposes only matches on scalable
+    // vectors, so by disabling this folder for scalable vectors the
+    // cycle is avoided.
+    // TODO: Check if ConvertIllegalShapeCastOpsToTransposes is
+    // still needed. If it's not, then we can fold here.
+    if (!transpose.getType().isScalable() && isOrderPreserving(transpose)) {
+      setOperand(transpose.getVector());
+      return getResult();
+    }
+    return {};
+  }
+
   // Y = shape_cast(broadcast(X))
   //      -> X, if X and Y have same type
   if (auto bcastOp = getSource().getDefiningOp<BroadcastOp>()) {
@@ -5619,7 +5662,7 @@ namespace {
 /// Helper function that computes a new vector type based on the input vector
 /// type by removing the trailing one dims:
 ///
-///   vector<4x1x1xi1> --> vector<4x1>
+///   vector<4x1x1xi1> --> vector<4x1xi1>
 ///
 static VectorType trimTrailingOneDims(VectorType oldType) {
   ArrayRef<int64_t> oldShape = oldType.getShape();
@@ -6086,6 +6129,32 @@ class FoldTransposeCreateMask final : public OpRewritePattern<TransposeOp> {
   }
 };
 
+/// Folds transpose(shape_cast) into a new shape_cast.
+class FoldTransposeShapeCast final : public OpRewritePattern<TransposeOp> {
+public:
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(TransposeOp transposeOp,
+                                PatternRewriter &rewriter) const override {
+    auto shapeCastOp =
+        transposeOp.getVector().getDefiningOp<vector::ShapeCastOp>();
+    if (!shapeCastOp)
+      return failure();
+    if (!isOrderPreserving(transposeOp))
+      return failure();
+
+    VectorType resultType = transposeOp.getType();
+
+    // We don't need to check isValidShapeCast at this point, because it is
+    // guaranteed that merging the transpose into the shape_cast is a valid
+    // shape_cast, because the transpose just inserts/removes ones.
+
+    rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(transposeOp, resultType,
+                                                     shapeCastOp.getSource());
+    return success();
+  }
+};
+
 /// Folds transpose(broadcast(x)) to broadcast(x) if the transpose is
 /// 'order preserving', where 'order preserving' means the flattened
 /// inputs and outputs of the transpose have identical (numerical) values.
@@ -6184,8 +6253,8 @@ class FoldTransposeBroadcast : public OpRewritePattern<vector::TransposeOp> {
 
 void vector::TransposeOp::getCanonicalizationPatterns(
     RewritePatternSet &results, MLIRContext *context) {
-  results.add<FoldTransposeCreateMask, TransposeFolder, FoldTransposeSplat,
-              FoldTransposeBroadcast>(context);
+  results.add<FoldTransposeCreateMask, FoldTransposeShapeCast, TransposeFolder,
+              FoldTransposeSplat, FoldTransposeBroadcast>(context);
 }
 
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir
index e0ec9c66d3a48..99f0850000a16 100644
--- a/mlir/test/Dialect/Vector/canonicalize.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize.mlir
@@ -8,6 +8,7 @@ func.func @create_vector_mask_to_constant_mask() -> (vector<4x3xi1>) {
   %0 = vector.create_mask %c3, %c2 : vector<4x3xi1>
   return %0 : vector<4x3xi1>
 }
+
 // -----
 
 // CHECK-LABEL: create_scalable_vector_mask_to_constant_mask
@@ -3061,7 +3062,6 @@ func.func @insert_vector_poison(%a: vector<4x8xf32>)
   return %1 : vector<4x8xf32>
 }
 
-
 // -----
 
 // CHECK-LABEL: @insert_scalar_poison_idx

diff  --git a/mlir/test/Dialect/Vector/canonicalize/vector-transpose.mlir b/mlir/test/Dialect/Vector/canonicalize/vector-transpose.mlir
index e97e147459de2..91ee0d335ecca 100644
--- a/mlir/test/Dialect/Vector/canonicalize/vector-transpose.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize/vector-transpose.mlir
@@ -137,3 +137,113 @@ func.func @negative_broadcast_transpose_021(%arg0 : vector<3x1x3xi8>) -> vector<
   return %1 : vector<3x3x3xi8>
 }
 
+
+// -----
+
+// Test of the ShapeCastOp folder for shape_cast(transpose).
+// In this test, the permutation maps the non-unit dimensions (1 and 2) as follows:
+// 1 -> 0
+// 2 -> 4
+// Because 0 < 4, this permutation is order preserving and effectively a shape_cast.
+// CHECK-LABEL: @transpose_shape_cast
+//  CHECK-SAME:   %[[ARG:.*]]: vector<1x4x4x1x1xi8>) -> vector<4x4xi8> {
+//       CHECK:   %[[SHAPE_CAST:.*]] = vector.shape_cast %[[ARG]] :
+//  CHECK-SAME:   vector<1x4x4x1x1xi8> to vector<4x4xi8>
+//       CHECK:   return %[[SHAPE_CAST]] : vector<4x4xi8>
+func.func @transpose_shape_cast(%arg : vector<1x4x4x1x1xi8>) -> vector<4x4xi8> {
+  %0 = vector.transpose %arg, [1, 0, 3, 4, 2]
+     : vector<1x4x4x1x1xi8> to vector<4x1x1x1x4xi8>
+  %1 = vector.shape_cast %0 : vector<4x1x1x1x4xi8> to vector<4x4xi8>
+  return %1 : vector<4x4xi8>
+}
+
+// -----
+
+// Test of the ShapeCastOp folder for shape_cast(transpose).
+// In this test, the mapping of non-unit dimensions (1 and 2) is as follows:
+// 1 -> 2
+// 2 -> 1
+// As this is not increasing (2 > 1), this transpose is not order
+// preserving and cannot be treated as a shape_cast.
+// CHECK-LABEL: @negative_transpose_shape_cast
+//  CHECK-SAME:   %[[ARG:.*]]: vector<1x4x4x1xi8>) -> vector<4x4xi8> {
+//       CHECK:   %[[TRANSPOSE:.*]] = vector.transpose %[[ARG]]
+//       CHECK:   %[[SHAPE_CAST:.*]] = vector.shape_cast %[[TRANSPOSE]]
+//       CHECK:   return %[[SHAPE_CAST]] : vector<4x4xi8>
+func.func @negative_transpose_shape_cast(%arg : vector<1x4x4x1xi8>) -> vector<4x4xi8> {
+  %0 = vector.transpose %arg, [0, 2, 1, 3]
+     : vector<1x4x4x1xi8> to vector<1x4x4x1xi8>
+  %1 = vector.shape_cast %0 : vector<1x4x4x1xi8> to vector<4x4xi8>
+  return %1 : vector<4x4xi8>
+}
+
+// -----
+
+// Test of the ShapeCastOp folder for shape_cast(transpose).
+// Currently the conversion shape_cast(transpose) -> shape_cast is disabled for
+// scalable vectors because of bad interaction with ConvertIllegalShapeCastOpsToTransposes
+// CHECK-LABEL: @negative_transpose_shape_cast_scalable
+//       CHECK:  vector.transpose
+//       CHECK:  vector.shape_cast
+func.func @negative_transpose_shape_cast_scalable(%arg : vector<[4]x1xi8>) -> vector<[4]xi8> {
+  %0 = vector.transpose %arg, [1, 0] : vector<[4]x1xi8> to vector<1x[4]xi8>
+  %1 = vector.shape_cast %0 : vector<1x[4]xi8> to vector<[4]xi8>
+  return %1 : vector<[4]xi8>
+}
+
+// -----
+
+// Test of shape_cast folding.
+// The conversion transpose(shape_cast) -> shape_cast is not disabled for scalable
+// vectors.
+// CHECK-LABEL: @shape_cast_transpose_scalable
+//       CHECK: vector.shape_cast
+//  CHECK-SAME: vector<[4]xi8> to vector<[4]x1xi8>
+func.func @shape_cast_transpose_scalable(%arg : vector<[4]xi8>) -> vector<[4]x1xi8> {
+  %0 = vector.shape_cast %arg : vector<[4]xi8> to vector<1x[4]xi8>
+  %1 = vector.transpose %0, [1, 0] : vector<1x[4]xi8> to vector<[4]x1xi8>
+  return %1 : vector<[4]x1xi8>
+}
+
+// -----
+
+// Test of shape_cast folding.
+// A transpose that is 'order preserving' can be treated like a shape_cast. 
+// CHECK-LABEL: @shape_cast_transpose
+//  CHECK-SAME:   %[[ARG:.*]]: vector<2x3x1x1xi8>) -> vector<6x1x1xi8> {
+//       CHECK:   %[[SHAPE_CAST:.*]] = vector.shape_cast %[[ARG]] :
+//  CHECK-SAME:   vector<2x3x1x1xi8> to vector<6x1x1xi8>
+//       CHECK:   return %[[SHAPE_CAST]] : vector<6x1x1xi8>
+func.func @shape_cast_transpose(%arg : vector<2x3x1x1xi8>) ->  vector<6x1x1xi8> {
+  %0 = vector.shape_cast %arg : vector<2x3x1x1xi8> to vector<6x1x1xi8>
+  %1 = vector.transpose %0, [0, 2, 1]
+     : vector<6x1x1xi8> to vector<6x1x1xi8>
+  return %1 : vector<6x1x1xi8>
+}
+
+// -----
+
+// Test of shape_cast folding.
+// Scalable dimensions should be treated as non-unit dimensions.
+// CHECK-LABEL: @shape_cast_transpose_scalable_unit
+//       CHECK: vector.shape_cast
+//       CHECK: vector.transpose
+func.func @shape_cast_transpose_scalable_unit(%arg : vector<[1]x4x1xi8>) -> vector<4x[1]xi8> {
+  %0 = vector.shape_cast %arg : vector<[1]x4x1xi8> to vector<[1]x4xi8>
+  %1 = vector.transpose %0, [1, 0] : vector<[1]x4xi8> to vector<4x[1]xi8>
+  return %1 : vector<4x[1]xi8>
+}
+
+// -----
+
+// Test of shape_cast (not) folding.
+// CHECK-LABEL: @negative_shape_cast_transpose
+//  CHECK-SAME:   %[[ARG:.*]]: vector<6xi8>) -> vector<2x3xi8> {
+//       CHECK:   %[[SHAPE_CAST:.*]] = vector.shape_cast %[[ARG]] :
+//       CHECK:   %[[TRANSPOSE:.*]] = vector.transpose %[[SHAPE_CAST]]
+//       CHECK:   return %[[TRANSPOSE]] : vector<2x3xi8>
+func.func @negative_shape_cast_transpose(%arg : vector<6xi8>) -> vector<2x3xi8> {
+  %0 = vector.shape_cast %arg : vector<6xi8> to vector<3x2xi8>
+  %1 = vector.transpose %0, [1, 0] : vector<3x2xi8> to vector<2x3xi8>
+  return %1 : vector<2x3xi8>
+}


        


More information about the Mlir-commits mailing list