[Mlir-commits] [mlir] [mlir][vector] Add tests xfer-permute-lowering (nfc)(2/n) (PR #96033)

Andrzej Warzyński llvmlistbot at llvm.org
Thu Aug 1 02:14:39 PDT 2024


================
@@ -83,19 +110,44 @@ func.func @xfer_write_transposing_permutation_map_masked(
 ///   * vector.broadcast + vector.transpose + vector.transfer_write with a map
 ///     which _is_ a permutation of a minor identity
 
-// CHECK-LABEL: func @permutation_with_mask_xfer_write_fixed_width(
-//       CHECK:   %[[vec:.*]] = arith.constant dense<-2.000000e+00> : vector<7x1xf32>
-//       CHECK:   %[[mask:.*]] = arith.constant dense<[true, false, true, false, true, true, true]> : vector<7xi1>
-//       CHECK:   %[[b:.*]] = vector.broadcast %[[mask]] : vector<7xi1> to vector<1x7xi1>
-//       CHECK:   %[[tp:.*]] = vector.transpose %[[b]], [1, 0] : vector<1x7xi1> to vector<7x1xi1>
-//       CHECK:   vector.transfer_write %[[vec]], %{{.*}}[%{{.*}}, %{{.*}}], %[[tp]] {in_bounds = [false, true]} : vector<7x1xf32>, memref<?x?xf32>
-func.func @permutation_with_mask_xfer_write_fixed_width(%mem : memref<?x?xf32>, %base1 : index,
-                                                   %base2 : index) {
-
-  %fn1 = arith.constant -2.0 : f32
-  %vf0 = vector.splat %fn1 : vector<7xf32>
-  %mask = arith.constant dense<[1, 0, 1, 0, 1, 1, 1]> : vector<7xi1>
-  vector.transfer_write %vf0, %mem[%base1, %base2], %mask
+// CHECK-LABEL:   func.func @xfer_write_non_transposing_permutation_map(
+// CHECK-SAME:      %[[MEM:.*]]: memref<?x?xf32>,
+// CHECK-SAME:      %[[VEC:.*]]: vector<7xf32>,
+// CHECK-SAME:      %[[BASE_1:.*]]: index, %[[BASE_2:.*]]: index) {
+// CHECK:           %[[BC:.*]] = vector.broadcast %[[VEC]] : vector<7xf32> to vector<1x7xf32>
+// CHECK:           %[[TR:.*]] = vector.transpose %[[BC]], [1, 0] : vector<1x7xf32> to vector<7x1xf32>
+// CHECK:           vector.transfer_write %[[TR]], %[[MEM]]{{\[}}%[[BASE_1]], %[[BASE_2]]] {in_bounds = [false, true]} : vector<7x1xf32>, memref<?x?xf32>
+func.func @xfer_write_non_transposing_permutation_map(
+    %mem : memref<?x?xf32>,
+    %arg0 : vector<7xf32>,
+    %base1 : index,
+    %base2 : index) {
----------------
banach-space wrote:

If you don't mind me expanding this PR, I'll do it here. Ultimately, this needs to happen, and it's all about making the review process as smooth and easy as possible. If you are happy, then I'm happy too :)

> Also getting rid of %dim args and use %idx.

I'm a bit unsure about this one: `%dim` is used for the mask **dimension**, not for the xfer_{read|write} index 🤔
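To make the distinction concrete, here is a minimal sketch. The function name `@xfer_write_masked_example` and the shapes are hypothetical and are not taken from the test file. `%dim` only sizes the mask created by `vector.create_mask`, while `%idx` is the offset passed to `vector.transfer_write`. Renaming `%dim` to `%idx` would blur these two roles.

```mlir
// Hypothetical example (not from the test file): the two index args play different roles.
func.func @xfer_write_masked_example(
    %mem : memref<?x?xf32>,
    %vec : vector<4x8xf32>,
    %idx : index,
    %dim : index) {
  %c4 = arith.constant 4 : index
  // %dim is a mask *dimension*: the number of active lanes along the trailing dim.
  %mask = vector.create_mask %c4, %dim : vector<4x8xi1>
  // %idx is an xfer *index*: the base offset into %mem, unrelated to the mask size.
  vector.transfer_write %vec, %mem[%idx, %idx], %mask : vector<4x8xf32>, memref<?x?xf32>
  return
}
```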

https://github.com/llvm/llvm-project/pull/96033


More information about the Mlir-commits mailing list