[Mlir-commits] [mlir] [mlir][vector] Add pattern to drop unit dims from vector.transpose (PR #102017)

Andrzej Warzyński llvmlistbot at llvm.org
Wed Aug 7 02:34:05 PDT 2024


================
@@ -700,3 +700,36 @@ func.func @negative_out_of_bound_transfer_write(
 }
 // CHECK:     func.func @negative_out_of_bound_transfer_write
 // CHECK-NOT:   memref.collapse_shape
+
+// -----
+
+///----------------------------------------------------------------------------------------
+/// [Pattern: DropUnitDimsFromTransposeOp]
+/// TODO: Move to a dedicated file - there's no "flattening" in the following tests
+///----------------------------------------------------------------------------------------
+
+func.func @transpose_with_internal_unit_dims(%vector: vector<1x1x4x[4]xf32>) -> vector<[4]x1x1x4xf32> {
+  %0 = vector.transpose %vector, [3, 0, 1, 2] : vector<1x1x4x[4]xf32> to vector<[4]x1x1x4xf32>
+  return %0 : vector<[4]x1x1x4xf32>
+}
+
+// CHECK-LABEL: func.func @transpose_with_internal_unit_dims(
+// CHECK-SAME:                                               %[[VEC:.*]]: vector<1x1x4x[4]xf32>)
+// CHECK-NEXT:    %[[DROP_DIMS:.*]] = vector.shape_cast %[[VEC]] : vector<1x1x4x[4]xf32> to vector<4x[4]xf32>
+// CHECK-NEXT:    %[[TRANSPOSE:.*]] = vector.transpose %[[DROP_DIMS]], [1, 0] : vector<4x[4]xf32> to vector<[4]x4xf32>
+// CHECK-NEXT:    %[[RESTORE_DIMS:.*]] = vector.shape_cast %[[TRANSPOSE]] : vector<[4]x4xf32> to vector<[4]x1x1x4xf32>
+// CHECK-NEXT:    return %[[RESTORE_DIMS]] : vector<[4]x1x1x4xf32>
+
+// -----
+
+func.func @transpose_with_unit_dims_before_and_after(%vector: vector<1x1x1x4x[4]x1xf32>) -> vector<[4]x1x1x1x4x1xf32> {
+  %0 = vector.transpose %vector, [4, 1, 0, 2, 3, 5] : vector<1x1x1x4x[4]x1xf32> to vector<[4]x1x1x1x4x1xf32>
+  return %0 : vector<[4]x1x1x1x4x1xf32>
+}
+
----------------
banach-space wrote:

From what I can tell, these two tests exercise identical code paths, right? If that's the case, let's reduce this to just one test case. I'd probably keep the second one, as it is the more complex case.

How about some negative tests, e.g. with scalable unit dims, `[1]`?
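To make the "identical code paths" point concrete, below is a minimal Python model of the rewrite that the CHECK lines describe: drop the fixed-size unit dims, compact the permutation, and transpose the reduced vector. It is a sketch under stated assumptions, not the C++ pattern from this PR. `Dim`, `is_droppable_unit_dim`, `drop_unit_dims_from_transpose` and `fmt` are hypothetical names; treating a scalable `[1]` as non-droppable (it holds `vscale` elements, not one) and returning `None` when there is nothing to drop are assumptions about the pattern's behaviour.

```python
from typing import List, Optional, Tuple

# Hypothetical model of a vector dim: (size, is_scalable).
# e.g. "[4]" -> (4, True), "1" -> (1, False).
Dim = Tuple[int, bool]


def is_droppable_unit_dim(dim: Dim) -> bool:
    """Assumption: only fixed-size unit dims are dropped; [1] is kept."""
    size, scalable = dim
    return size == 1 and not scalable


def drop_unit_dims_from_transpose(
    shape: List[Dim], perm: List[int]
) -> Optional[Tuple[List[Dim], List[int]]]:
    """Return (reduced_shape, reduced_perm), or None if no dim can be dropped."""
    kept = [i for i, d in enumerate(shape) if not is_droppable_unit_dim(d)]
    if len(kept) == len(shape):
        return None  # assumption: the pattern does not apply in this case
    # Renumber the surviving source dims 0..k-1 so the permutation stays dense.
    new_index = {old: new for new, old in enumerate(kept)}
    reduced_shape = [shape[i] for i in kept]
    # Keep permutation entries that refer to surviving dims, in their order.
    reduced_perm = [new_index[p] for p in perm if p in new_index]
    return reduced_shape, reduced_perm


def fmt(shape: List[Dim]) -> str:
    """Render a shape in MLIR style, e.g. 4x[4]."""
    return "x".join(f"[{s}]" if sc else str(s) for s, sc in shape)


if __name__ == "__main__":
    # First test from the diff: vector<1x1x4x[4]xf32>, permutation [3, 0, 1, 2].
    shape, perm = drop_unit_dims_from_transpose(
        [(1, False), (1, False), (4, False), (4, True)], [3, 0, 1, 2]
    )
    print(fmt(shape), perm)  # prints: 4x[4] [1, 0]
```

Fed both test cases from the diff, the model yields the same reduced transpose (`4x[4]` with permutation `[1, 0]`), in line with the observation that the two tests cover the same path. A `[1]` dim, by contrast, survives the reduction, which is exactly what a test with scalable unit dims would pin down.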

Also, input vectors in this file are called `%vec` and outputs are `%res`. I know this because I've updated it recently https://github.com/llvm/llvm-project/pull/101471 😅 Could you follow a similar format? Thanks!

https://github.com/llvm/llvm-project/pull/102017

