[Mlir-commits] [mlir] 201da87 - [mlir][vector] Handle corner cases in DropUnitDimsFromTransposeOp. (#102518)

llvmlistbot at llvm.org
Thu Aug 8 17:29:14 PDT 2024


Author: Han-Chung Wang
Date: 2024-08-08T17:29:09-07:00
New Revision: 201da87c3f53f9450dd60ee5adbddf46fe19c430

URL: https://github.com/llvm/llvm-project/commit/201da87c3f53f9450dd60ee5adbddf46fe19c430
DIFF: https://github.com/llvm/llvm-project/commit/201da87c3f53f9450dd60ee5adbddf46fe19c430.diff

LOG: [mlir][vector] Handle corner cases in DropUnitDimsFromTransposeOp. (#102518)

https://github.com/llvm/llvm-project/commit/da8778e499d8049ac68c2e152941a38ff2bc9fb2
breaks the lowering of vector.transpose ops in which all the dimensions are unit
dimensions. This revision fixes the issue by using the permutation [0] when every
dimension is dropped, and adds a test.

---------

Signed-off-by: hanhanW <hanhan0912 at gmail.com>
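To make the corner case concrete, here is a minimal standalone C++ sketch of the permutation bookkeeping, assuming a simplified shape model instead of MLIR's `VectorType`. The names `VecShape`, `dropNonScalableUnitDims`, and `dropUnitDimsFromPerm` are hypothetical; they loosely mirror the logic visible in the diff (`droppedDimsBefore`, `newPerm`) but are not the MLIR API. For `vector<1x1x1xf32>` with permutation `[0, 2, 1]`, every permutation entry is skipped, so without the fixup the new permutation would be empty while the reduced type still has rank 1.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical, simplified model of a vector type: a shape plus one
// "scalable" flag per dimension (e.g. vector<[1]x1x2xf32>).
struct VecShape {
  std::vector<int64_t> dims;
  std::vector<bool> scalable;
};

// Drop every non-scalable unit dim. If nothing survives, keep a single
// unit dim so the result models vector<1xT>.
VecShape dropNonScalableUnitDims(const VecShape &in) {
  VecShape out;
  for (std::size_t i = 0; i < in.dims.size(); ++i) {
    if (in.dims[i] == 1 && !in.scalable[i])
      continue;
    out.dims.push_back(in.dims[i]);
    out.scalable.push_back(in.scalable[i]);
  }
  if (out.dims.empty()) {
    out.dims.push_back(1);
    out.scalable.push_back(false);
  }
  return out;
}

// Rewrite `perm` so that it indexes the shape with unit dims dropped.
std::vector<int64_t> dropUnitDimsFromPerm(const VecShape &src,
                                          const std::vector<int64_t> &perm) {
  assert(perm.size() == src.dims.size() && "perm must cover every dim");
  auto isDroppedUnit = [&](std::size_t i) {
    return src.dims[i] == 1 && !src.scalable[i];
  };
  // droppedBefore[i] = number of dropped unit dims with index < i.
  std::vector<int64_t> droppedBefore(src.dims.size());
  int64_t dropped = 0;
  for (std::size_t i = 0; i < src.dims.size(); ++i) {
    droppedBefore[i] = dropped;
    if (isDroppedUnit(i))
      ++dropped;
  }
  std::vector<int64_t> newPerm;
  for (int64_t idx : perm) {
    std::size_t i = static_cast<std::size_t>(idx);
    if (isDroppedUnit(i))
      continue; // unit dims disappear from the permutation
    newPerm.push_back(idx - droppedBefore[i]);
  }
  // The corner case: when every dim was a unit dim, the reduced type is
  // vector<1xT>, so the permutation must be [0] rather than [].
  if (newPerm.empty())
    newPerm.push_back(0);
  return newPerm;
}
```

For instance, a shape `{1, 1, 1}` with permutation `{0, 2, 1}` yields `{0}`, matching the rank-1 reduced shape `{1}`. For a shape modeled on `vector<[1]x1x2x4x1xf32>` with the illustrative permutation `{0, 3, 2, 1, 4}`, the scalable unit dim is kept and the result is `{0, 2, 1}` over the reduced shape `{1, 2, 4}`.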

Added: 
    

Modified: 
    mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
    mlir/test/Dialect/Vector/vector-transfer-flatten.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
index bc0c96b32a80f..7f59a378e0351 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
@@ -1771,6 +1771,13 @@ struct DropUnitDimsFromTransposeOp final
       newPerm.push_back(idx - droppedDimsBefore[idx]);
     }
 
+    // Fixup for `newPerm`. The `sourceTypeWithoutUnitDims` could be vector<1xT>
+    // type when the dimensions are unit dimensions. In this case, the newPerm
+    // should be [0].
+    if (newPerm.empty()) {
+      newPerm.push_back(0);
+    }
+
     Location loc = op.getLoc();
     // Drop the unit dims via shape_cast.
     auto dropDimsShapeCast = rewriter.create<vector::ShapeCastOp>(

diff  --git a/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir b/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir
index 937dbf22bb713..df0a5c5fa0ce8 100644
--- a/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir
+++ b/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir
@@ -737,6 +737,18 @@ func.func @transpose_with_scalable_unit_dims(%vec: vector<[1]x1x2x4x1xf32>) -> v
 
 // -----
 
+func.func @transpose_with_all_unit_dims(%vec: vector<1x1x1xf32>) -> vector<1x1x1xf32> {
+  %res = vector.transpose %vec, [0, 2, 1] : vector<1x1x1xf32> to vector<1x1x1xf32>
+  return %res : vector<1x1x1xf32>
+}
+// The `vec` is returned because there are other flattening patterns that fold
+// vector.shape_cast ops away.
+// CHECK-LABEL: func.func @transpose_with_all_unit_dims
+// CHECK-SAME:      %[[VEC:.[a-zA-Z0-9]+]]
+// CHECK-NEXT:    return %[[VEC]]
+
+// -----
+
 func.func @negative_transpose_with_no_unit_dims(%vec: vector<4x2x3xf32>) -> vector<4x3x2xf32> {
   %res = vector.transpose %vec, [0, 2, 1] : vector<4x2x3xf32> to vector<4x3x2xf32>
   return %res : vector<4x3x2xf32>

