[Mlir-commits] [mlir] [mlir][vector] Handle corner cases in DropUnitDimsFromTransposeOp. (PR #102518)
Han-Chung Wang
llvmlistbot at llvm.org
Thu Aug 8 11:28:41 PDT 2024
https://github.com/hanhanW created https://github.com/llvm/llvm-project/pull/102518
https://github.com/llvm/llvm-project/commit/da8778e499d8049ac68c2e152941a38ff2bc9fb2 breaks the lowering of vector.transpose ops whose dimensions are all unit dimensions. This revision fixes the issue and adds a test.
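The corner case can be reproduced with a small standalone C++ sketch of the permutation bookkeeping. This is not the MLIR implementation. `dropUnitDims` and `permWithoutUnitDims` are hypothetical helpers that work on plain `std::vector<int64_t>` shapes, scalable dimensions are ignored, and `droppedDimsBefore` is assumed to count the unit dims that precede each index. Only the `newPerm` remapping and the vector<1xT> fixup are modeled on the patch. A rank-1 vector needs a one-entry permutation, so an empty `newPerm` cannot describe a transpose of vector<1xT>.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical helper (not an MLIR API): drops unit dims from `shape`, but
// keeps a single unit dim when every dim is 1, mirroring the vector<1xT>
// case described in the patch. Scalable dims are not modeled here.
std::vector<int64_t> dropUnitDims(const std::vector<int64_t> &shape) {
  std::vector<int64_t> result;
  for (int64_t size : shape)
    if (size != 1)
      result.push_back(size);
  if (result.empty())
    result.push_back(1); // All dims were unit dims: keep vector<1xT>.
  return result;
}

// Hypothetical helper: remaps `perm` onto the shape with unit dims removed.
std::vector<int64_t> permWithoutUnitDims(const std::vector<int64_t> &shape,
                                         const std::vector<int64_t> &perm) {
  assert(shape.size() == perm.size() && "perm must cover every source dim");
  // droppedDimsBefore[i] = number of unit dims in shape[0..i).
  std::vector<int64_t> droppedDimsBefore(shape.size(), 0);
  int64_t dropped = 0;
  for (std::size_t i = 0; i < shape.size(); ++i) {
    droppedDimsBefore[i] = dropped;
    if (shape[i] == 1)
      ++dropped;
  }
  std::vector<int64_t> newPerm;
  for (int64_t idx : perm) {
    if (shape[idx] == 1)
      continue; // Unit dims do not appear in the reduced permutation.
    newPerm.push_back(idx - droppedDimsBefore[idx]);
  }
  // Fixup modeled on the patch: the reduced type is vector<1xT>, so the
  // permutation needs one entry instead of being empty.
  std::vector<int64_t> reduced = dropUnitDims(shape);
  if (reduced.size() == 1 && reduced[0] == 1 && newPerm.empty())
    newPerm.push_back(0);
  return newPerm;
}
```

For the shape and permutation used in the new test, `permWithoutUnitDims({1, 1, 1}, {0, 2, 1})` yields `{0}`; without the fixup the loop skips every index and the result would be empty. For a mixed case such as shape `{1, 4, 1, 2}` with permutation `{3, 0, 2, 1}`, the reduced shape is `{4, 2}` and the remapped permutation is `{1, 0}`, so the fixup does not fire.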
From d96ccc2ccb677c2cc31cccb8b555834de5375231 Mon Sep 17 00:00:00 2001
From: hanhanW <hanhan0912 at gmail.com>
Date: Thu, 8 Aug 2024 11:26:17 -0700
Subject: [PATCH] [mlir][vector] Handle corner cases in
DropUnitDimsFromTransposeOp.
https://github.com/llvm/llvm-project/commit/da8778e499d8049ac68c2e152941a38ff2bc9fb2
breaks the lowering of vector.transpose ops whose dimensions are all unit
dimensions. This revision fixes the issue and adds a test.
Signed-off-by: hanhanW <hanhan0912 at gmail.com>
---
.../Dialect/Vector/Transforms/VectorTransforms.cpp | 8 ++++++++
.../test/Dialect/Vector/vector-transfer-flatten.mlir | 12 ++++++++++++
2 files changed, 20 insertions(+)
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
index bc0c96b32a80f5..b4ae9b319343ab 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
@@ -1771,6 +1771,14 @@ struct DropUnitDimsFromTransposeOp final
newPerm.push_back(idx - droppedDimsBefore[idx]);
}
+ // Fixup for `newPerm`. `sourceTypeWithoutUnitDims` can be vector<1xT>
+ // when all the dimensions are unit dimensions. In this case, `newPerm`
+ // should be [0].
+ if (sourceTypeWithoutUnitDims.getRank() == 1 &&
+ sourceTypeWithoutUnitDims.getShape()[0] == 1 && newPerm.empty()) {
+ newPerm.push_back(0);
+ }
+
Location loc = op.getLoc();
// Drop the unit dims via shape_cast.
auto dropDimsShapeCast = rewriter.create<vector::ShapeCastOp>(
diff --git a/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir b/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir
index 937dbf22bb713f..0d34d692393fd8 100644
--- a/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir
+++ b/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir
@@ -737,6 +737,18 @@ func.func @transpose_with_scalable_unit_dims(%vec: vector<[1]x1x2x4x1xf32>) -> v
// -----
+func.func @transpose_with_all_unit_dims(%vec: vector<1x1x1xf32>) -> vector<1x1x1xf32> {
+ %res = vector.transpose %vec, [0, 2, 1] : vector<1x1x1xf32> to vector<1x1x1xf32>
+ return %res : vector<1x1x1xf32>
+}
+// The `vec` is returned directly because other flattening patterns fold the
+// vector.shape_cast ops away.
+// CHECK-LABEL: func.func @transpose_with_all_unit_dims
+// CHECK-SAME: %[[VEC:.[a-zA-Z0-9]+]]
+// CHECK-NEXT: return %[[VEC]]
+
+// -----
+
func.func @negative_transpose_with_no_unit_dims(%vec: vector<4x2x3xf32>) -> vector<4x3x2xf32> {
%res = vector.transpose %vec, [0, 2, 1] : vector<4x2x3xf32> to vector<4x3x2xf32>
return %res : vector<4x3x2xf32>