[Mlir-commits] [mlir] [mlir][vector] Refactor vector-transfer-flatten.mlir (nfc) (1/n) (PR #95743)

Andrzej Warzyński llvmlistbot at llvm.org
Wed Jun 19 08:40:52 PDT 2024


================
@@ -267,37 +361,76 @@ func.func @transfer_write_dims_mismatch_non_contiguous(
 
 // -----
 
-func.func @transfer_write_0d(%arg : memref<i8>, %vec : vector<i8>) {
-      vector.transfer_write %vec, %arg[] : vector<i8>, memref<i8>
-      return
+// The input memref has a dynamic trailing shape and hence is not flattened.
+// TODO: This case could be supported via memref.dim
+
+func.func @transfer_write_dims_mismatch_non_zero_indices_dynamic_shapes(
+    %idx_1: index,
+    %idx_2: index,
+    %vec : vector<1x2x6xi32>,
+    %arg: memref<1x?x4x6xi32>) {
+
+  %c0 = arith.constant 0 : index
+  %c0_i32 = arith.constant 0 : i32
+  vector.transfer_write %vec, %arg[%c0, %idx_1, %idx_2, %c0] {in_bounds = [true, true, true]} :
+    vector<1x2x6xi32>, memref<1x?x4x6xi32>
+  return
 }
 
-// CHECK-LABEL: func.func @transfer_write_0d
+// CHECK-LABEL: func.func @transfer_write_dims_mismatch_non_zero_indices_dynamic_shapes(
 // CHECK-NOT: memref.collapse_shape
 // CHECK-NOT: vector.shape_cast
 
-// CHECK-128B-LABEL: func @transfer_write_0d(
+// CHECK-128B-LABEL: func @transfer_write_dims_mismatch_non_zero_indices_dynamic_shapes(
 //   CHECK-128B-NOT:   memref.collapse_shape
-//   CHECK-128B-NOT:   vector.shape_cast
 
 // -----
 
-func.func @transfer_read_0d(%arg : memref<i8>) -> vector<i8> {
-      %cst = arith.constant 0 : i8
-      %0 = vector.transfer_read %arg[], %cst : memref<i8>, vector<i8>
-      return %0 : vector<i8>
+// The vector to be written represents a _non-contiguous_ slice of the output
+// memref.
+
+func.func @transfer_write_dims_mismatch_non_contiguous_slice(
+    %arg : memref<5x4x3x2xi8>,
+    %vec : vector<2x1x2x2xi8>) {
+
+  %c0 = arith.constant 0 : index
+  %cst = arith.constant 0 : i8
+  vector.transfer_write %vec, %arg[%c0, %c0, %c0, %c0] :
+    vector<2x1x2x2xi8>, memref<5x4x3x2xi8>
+  return
 }
 
-// CHECK-LABEL: func.func @transfer_read_0d
+// CHECK-LABEL: func.func @transfer_write_dims_mismatch_non_contiguous_slice(
 // CHECK-NOT: memref.collapse_shape
 // CHECK-NOT: vector.shape_cast
 
-// CHECK-128B-LABEL: func @transfer_read_0d(
+// CHECK-128B-LABEL: func @transfer_write_dims_mismatch_non_contiguous_slice(
+//   CHECK-128B-NOT:   memref.collapse_shape
+
+// -----
+
+func.func @transfer_write_0d(
+    %arg : memref<i8>,
+    %vec : vector<i8>) {
+
+  vector.transfer_write %vec, %arg[] : vector<i8>, memref<i8>
+  return
+}
+
+// CHECK-LABEL: func.func @transfer_write_0d
+// CHECK-NOT: memref.collapse_shape
+// CHECK-NOT: vector.shape_cast
+
+// CHECK-128B-LABEL: func @transfer_write_0d(
 //   CHECK-128B-NOT:   memref.collapse_shape
 //   CHECK-128B-NOT:   vector.shape_cast
 
 // -----
 
+///----------------------------------------------------------------------------------------
+/// TODO: Categorize + re-format
+///----------------------------------------------------------------------------------------
----------------
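The `transfer_write_dims_mismatch_non_contiguous_slice` test in the hunk above depends on whether a vector covers a contiguous slice of the destination memref. The sketch below is a minimal Python illustration of that idea, and it makes several assumptions. It handles static shapes only, assumes an identity (row-major) layout, and treats all indices as zero. `is_contiguous_slice` is a hypothetical helper written for this illustration. It is not MLIR's implementation, and the real check may cover more cases.

```python
from typing import Sequence


def is_contiguous_slice(memref_shape: Sequence[int],
                        vector_shape: Sequence[int]) -> bool:
    """Hypothetical helper (simplified approximation, not MLIR's code).

    Returns True if a vector of `vector_shape`, written at all-zero indices
    into a static, row-major memref of `memref_shape`, covers a single
    contiguous run of memory.
    """
    # A vector cannot have more dims than the memref it is written into.
    if len(vector_shape) > len(memref_shape):
        return False
    # Leading unit dims of the vector do not affect contiguity.
    dims = list(vector_shape)
    while len(dims) > 1 and dims[0] == 1:
        dims.pop(0)
    # Every remaining vector dim except the outermost one must span the full
    # corresponding trailing memref dim; otherwise consecutive rows of the
    # vector are separated by gaps in memory.
    inner = dims[1:]
    trailing = list(memref_shape[len(memref_shape) - len(inner):])
    return inner == trailing


if __name__ == "__main__":
    # Shapes from the non_contiguous_slice test: memref<5x4x3x2xi8>.
    print(is_contiguous_slice([5, 4, 3, 2], [2, 1, 2, 2]))  # False
    print(is_contiguous_slice([5, 4, 3, 2], [1, 1, 2, 2]))  # True
```

For `memref<5x4x3x2xi8>`, the sketch rejects `vector<2x1x2x2xi8>` because the inner `1x2x2` dims do not match the trailing `4x3x2` memref dims. It would accept `vector<1x1x2x2xi8>`, because each 2-element row sits directly after the previous one in memory.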
banach-space wrote:

It's for patch k/n :) I've started adding these because:
* I'm trying to avoid big patches so that each one is easier to review, so ...
* ... I'm adding TODOs here and there so that I don't forget about things 😅

https://github.com/llvm/llvm-project/pull/95743

