[Mlir-commits] [mlir] [mlir][vector] Add vector.transpose with unit-dim to vector.shape_cast pattern (PR #72105)

Cullen Rhodes llvmlistbot at llvm.org
Mon Nov 13 03:44:50 PST 2023


https://github.com/c-rhodes created https://github.com/llvm/llvm-project/pull/72105

This patch extends the vector.transpose lowering to replace:

  vector.transpose %0, [1, 0] : vector<nx1x<eltty>> to vector<1xnx<eltty>>

with:

  vector.shape_cast %0 : vector<nx1x<eltty>> to vector<1xnx<eltty>>

The inverse case, where the source has a leading unit dim, is also replaced. The unit dim must be fixed-size; the non-unit dim may be scalable.

A check is also added to bail out on scalable vectors before the unrolling lowering, since scalable dims cannot be unrolled.

>From 7c53ec4d129d9f7bc389b1a7229924a8ffe34d63 Mon Sep 17 00:00:00 2001
From: Cullen Rhodes <cullen.rhodes at arm.com>
Date: Thu, 9 Nov 2023 19:44:34 +0000
Subject: [PATCH] [mlir][vector] Add vector.transpose with unit-dim to
 vector.shape_cast pattern

This patch extends the vector.transpose lowering to replace:

  vector.transpose %0, [1, 0] : vector<nx1x<eltty>> to vector<1xnx<eltty>>

with:

  vector.shape_cast %0 : vector<nx1x<eltty>> to vector<1xnx<eltty>>

The inverse case, where the source has a leading unit dim, is also
replaced. The unit dim must be fixed-size; the non-unit dim may be
scalable.

A check is also added to bail out on scalable vectors before the
unrolling lowering, since scalable dims cannot be unrolled.
---
 .../Transforms/LowerVectorTranspose.cpp       | 21 ++++++
 .../Vector/vector-transpose-lowering.mlir     | 71 +++++++++++++++++++
 2 files changed, 92 insertions(+)

diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorTranspose.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorTranspose.cpp
index 7d804ddcfa42ffe..cf35d64c0c6268d 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorTranspose.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorTranspose.cpp
@@ -336,6 +336,27 @@ class TransposeOpLowering : public OpRewritePattern<vector::TransposeOp> {
       return rewriter.notifyMatchFailure(
           op, "Options specifies lowering to shuffle");
 
+    // Replace:
+    //   vector.transpose %0, [1, 0] : vector<nx1x<eltty>> to
+    //                                 vector<1xnx<eltty>>
+    // with:
+    //   vector.shape_cast %0 : vector<nx1x<eltty>> to vector<1xnx<eltty>>
+    //
+    // The inverse (source with a leading unit dim) is also replaced. The unit
+    // dim must be fixed-size; the non-unit dim may be scalable.
+    if (resType.getRank() == 2 &&
+        ((resType.getShape().front() == 1 &&
+          !resType.getScalableDims().front()) ||
+         (resType.getShape().back() == 1 &&
+          !resType.getScalableDims().back())) &&
+        transp[0] == 1 && transp[1] == 0) {
+      rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(op, resType, input);
+      return success();
+    }
+
+    if (inputType.isScalable())
+      return failure();
+
     // Handle a true 2-D matrix transpose differently when requested.
     if (vectorTransformOptions.vectorTransposeLowering ==
             vector::VectorTransposeLowering::Flat &&
diff --git a/mlir/test/Dialect/Vector/vector-transpose-lowering.mlir b/mlir/test/Dialect/Vector/vector-transpose-lowering.mlir
index 22d9224838c49c4..c0b44428d5bcf30 100644
--- a/mlir/test/Dialect/Vector/vector-transpose-lowering.mlir
+++ b/mlir/test/Dialect/Vector/vector-transpose-lowering.mlir
@@ -74,6 +74,17 @@ func.func @transpose1023_1x1x8x8xf32(%arg0: vector<1x1x8x8xf32>) -> vector<1x1x8
   return %0 : vector<1x1x8x8xf32>
 }
 
+/// A transpose with a scalable dim must not be unrolled into extract/insert.

+// CHECK-LABEL: func @transpose23_scalable
+// CHECK-NOT: vector.extract
+// CHECK-NOT: vector.insert
+// CHECK: vector.transpose
+func.func @transpose23_scalable(%arg0: vector<2x[3]xf32>) -> vector<[3]x2xf32> {
+  %0 = vector.transpose %arg0, [1, 0] : vector<2x[3]xf32> to vector<[3]x2xf32>
+  return %0 : vector<[3]x2xf32>
+}
+
 module attributes {transform.with_named_sequence} {
   transform.named_sequence @__transform_main(%func_op: !transform.op<"func.func"> {transform.readonly}) {
     transform.apply_patterns to %func_op {
@@ -778,3 +789,63 @@ module attributes {transform.with_named_sequence} {
     transform.yield
   }
 }
+
+// -----
+
+/// Rank-2 transpose with a fixed leading or trailing unit dim -> shape_cast.
+
+// CHECK-LABEL: func @transpose10_4x1xf32
+func.func @transpose10_4x1xf32(%arg0: vector<4x1xf32>) -> vector<1x4xf32> {
+  // CHECK-NEXT: vector.shape_cast %{{.*}} : vector<4x1xf32> to vector<1x4xf32>
+  %0 = vector.transpose %arg0, [1, 0] : vector<4x1xf32> to vector<1x4xf32>
+  return %0 : vector<1x4xf32>
+}
+
+// CHECK-LABEL: func @transpose10_nx4x1xf32
+func.func @transpose10_nx4x1xf32(%arg0: vector<[4]x1xf32>) -> vector<1x[4]xf32> {
+  // CHECK-NEXT: vector.shape_cast %{{.*}} : vector<[4]x1xf32> to vector<1x[4]xf32>
+  %0 = vector.transpose %arg0, [1, 0] : vector<[4]x1xf32> to vector<1x[4]xf32>
+  return %0 : vector<1x[4]xf32>
+}
+
+// CHECK-LABEL: func @transpose10_1x4xf32
+func.func @transpose10_1x4xf32(%arg0: vector<1x4xf32>) -> vector<4x1xf32> {
+  // CHECK-NEXT: vector.shape_cast %{{.*}} : vector<1x4xf32> to vector<4x1xf32>
+  %0 = vector.transpose %arg0, [1, 0] : vector<1x4xf32> to vector<4x1xf32>
+  return %0 : vector<4x1xf32>
+}
+
+// CHECK-LABEL: func @transpose10_1xnx4xf32
+func.func @transpose10_1xnx4xf32(%arg0: vector<1x[4]xf32>) -> vector<[4]x1xf32> {
+  // CHECK-NEXT: vector.shape_cast %{{.*}} : vector<1x[4]xf32> to vector<[4]x1xf32>
+  %0 = vector.transpose %arg0, [1, 0] : vector<1x[4]xf32> to vector<[4]x1xf32>
+  return %0 : vector<[4]x1xf32>
+}
+
+/// A scalable unit dim ([1]) must not be lowered to shape_cast.
+
+// CHECK-LABEL: func @transpose10_4xnx1xf32
+func.func @transpose10_4xnx1xf32(%arg0: vector<4x[1]xf32>) -> vector<[1]x4xf32> {
+  // CHECK-NOT: vector.shape_cast
+  // CHECK: vector.transpose %{{.*}} : vector<4x[1]xf32> to vector<[1]x4xf32>
+  %0 = vector.transpose %arg0, [1, 0] : vector<4x[1]xf32> to vector<[1]x4xf32>
+  return %0 : vector<[1]x4xf32>
+}
+
+// CHECK-LABEL: func @transpose10_nx4xnx1xf32
+func.func @transpose10_nx4xnx1xf32(%arg0: vector<[4]x[1]xf32>) -> vector<[1]x[4]xf32> {
+  // Both dims scalable: the transpose must be left untouched.
+  // CHECK-NOT: vector.shape_cast
+  // CHECK: vector.transpose %{{.*}} : vector<[4]x[1]xf32> to vector<[1]x[4]xf32>
+  %0 = vector.transpose %arg0, [1, 0] : vector<[4]x[1]xf32> to vector<[1]x[4]xf32>
+  return %0 : vector<[1]x[4]xf32>
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%func_op: !transform.op<"func.func"> {transform.readonly}) {
+    transform.apply_patterns to %func_op {
+      transform.apply_patterns.vector.lower_transpose
+    } : !transform.op<"func.func">
+    transform.yield
+  }
+}



More information about the Mlir-commits mailing list