[Mlir-commits] [mlir] [mlir][Vector] Add a rewrite pattern for gather over a strided memref (PR #72991)

Han-Chung Wang llvmlistbot at llvm.org
Mon Nov 27 15:46:33 PST 2023


================
@@ -151,3 +151,57 @@ func.func @gather_tensor_1d_none_set(%base: tensor<?xf32>, %v: vector<2xindex>,
   %0 = vector.gather %base[%c0][%v], %mask, %pass_thru : tensor<?xf32>, vector<2xindex>, vector<2xi1>, vector<2xf32> into vector<2xf32>
   return %0 : vector<2xf32>
 }
+
+// Check that vector.gather of a strided memref is replaced with a
+// vector.gather with indices encoding the original strides. Note that, with the
+// other patterns in this test also applied, the new gather is then unrolled
+// into per-lane conditional loads.
+#map = affine_map<()[s0] -> (s0 * 4096)>
+#map1 = affine_map<()[s0] -> (s0 * -4096 + 518400, 4096)>
+func.func @strided_gather(%M_in : memref<100x3xf32>, %M_out: memref<518400xf32>, %idxs : vector<4xindex>, %x : index, %y : index) {
+  %c0 = arith.constant 0 : index
+  %x_1 = affine.apply #map()[%x]
+  // Strided MemRef
+  %subview = memref.subview %M_in[0, 0] [100, 1] [1, 1] : memref<100x3xf32> to memref<100xf32, strided<[3]>>
+  %cst_0 = arith.constant dense<true> : vector<4xi1>
+  %cst = arith.constant dense<0.000000e+00> : vector<4xf32>
+  // Gather of a strided MemRef
+  %7 = vector.gather %subview[%c0] [%idxs], %cst_0, %cst : memref<100xf32, strided<[3]>>, vector<4xindex>, vector<4xi1>, vector<4xf32> into vector<4xf32>
+  %subview_1 = memref.subview %M_out[%x_1] [%y] [1] : memref<518400xf32> to memref<?xf32, strided<[1], offset: ?>>
+  vector.store %7, %subview_1[%c0] : memref<?xf32, strided<[1], offset: ?>>, vector<4xf32>
+  return
+}
+// CHECK-LABEL:   func.func @strided_gather(
+// CHECK-SAME:                         %[[M_in:.*]]: memref<100x3xf32>,
+// CHECK-SAME:                         %[[M_out:.*]]: memref<518400xf32>,
+// CHECK-SAME:                         %[[IDXS:.*]]: vector<4xindex>,
+// CHECK-SAME:                         %[[VAL_4:.*]]: index,
+// CHECK-SAME:                         %[[VAL_5:.*]]: index) {
+// CHECK:           %[[CST_3:.*]] = arith.constant dense<3> : vector<4xindex>
+// CHECK:           %[[MASK:.*]] = arith.constant dense<true> : vector<4xi1>
+
+// CHECK:           %[[COLLAPSED:.*]] = memref.collapse_shape %[[M_in]] {{\[\[}}0, 1]] : memref<100x3xf32> into memref<300xf32>
+// CHECK:           %[[NEW_IDXS:.*]] = arith.muli %[[IDXS]], %[[CST_3]] : vector<4xindex>
+
+// CHECK:           %[[MASK_0:.*]] = vector.extract %[[MASK]][0] : i1 from vector<4xi1>
+// CHECK:           %[[IDX_0:.*]] = vector.extract %[[NEW_IDXS]][0] : index from vector<4xindex>
+// CHECK:           scf.if %[[MASK_0]] -> (vector<4xf32>)
+// CHECK:             %[[M_0:.*]] = vector.load %[[COLLAPSED]]{{\[}}%[[IDX_0]]] : memref<300xf32>, vector<1xf32>
+// CHECK:             %[[V_0:.*]] = vector.extract %[[M_0]][0] : f32 from vector<1xf32>
+
+// CHECK:           %[[MASK_1:.*]] = vector.extract %[[MASK]][1] : i1 from vector<4xi1>
+// CHECK:           %[[IDX_1:.*]] = vector.extract %[[NEW_IDXS]][1] : index from vector<4xindex>
+// CHECK:           scf.if %[[MASK_1]] -> (vector<4xf32>)
+// CHECK:             %[[M_1:.*]] = vector.load %[[COLLAPSED]]{{\[}}%[[IDX_1]]] : memref<300xf32>, vector<1xf32>
----------------
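The rewrite checked above can be modelled in a few lines of Python. This is a minimal sketch, not the pattern's implementation. A flat list stands in for the row-major buffer of `memref<100x3xf32>`, and the helpers `make_buffer`, `gather_strided` and `gather_rewritten` are hypothetical. The model shows why collapsing the source to `memref<300xf32>` and multiplying the indices by the stride 3 reads the same elements as the original gather on `memref<100xf32, strided<[3]>>`. Masked-off lanes keep their pass-through value, mirroring the unrolled `scf.if`/`vector.load` sequence in the CHECK lines.

```python
# Simplified model of the rewrite: a flat Python list stands in for the
# row-major buffer of memref<100x3xf32>. All names here are hypothetical.
ROWS, COLS = 100, 3


def make_buffer():
    # Element (i, j) of the 100x3 memref lives at flat offset i * COLS + j.
    return [float(i * COLS + j) for i in range(ROWS) for j in range(COLS)]


def gather_strided(buf, idxs, mask, pass_thru):
    # Original op: gather from the subview memref<100xf32, strided<[3]>>,
    # i.e. column 0 of the 2-D memref, so logical element k sits at offset k * 3.
    return [buf[k * COLS] if m else p for k, m, p in zip(idxs, mask, pass_thru)]


def gather_rewritten(buf, idxs, mask, pass_thru):
    # After the rewrite: collapse_shape yields a contiguous memref<300xf32>
    # over the same buffer, and the indices are scaled by the stride (arith.muli).
    collapsed = buf
    new_idxs = [k * COLS for k in idxs]
    # Unrolled form: one conditional scalar load per lane; masked-off lanes
    # keep their pass-through value.
    result = list(pass_thru)
    for lane, idx in enumerate(new_idxs):
        if mask[lane]:
            result[lane] = collapsed[idx]
    return result


if __name__ == "__main__":
    buf = make_buffer()
    idxs, mask, pass_thru = [0, 5, 17, 99], [True, False, True, True], [0.0] * 4
    print(gather_rewritten(buf, idxs, mask, pass_thru))  # [0.0, 0.0, 51.0, 297.0]
```

The test in the PR uses an all-true mask. The mixed mask here only illustrates the pass-through behaviour of the unrolled form.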
hanhanW wrote:

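nit: we don't need to escape `[` in this case, because we don't capture the `%` in the variables. The literal `[` is therefore followed by `%` rather than by a `[[` substitution, so FileCheck does not misread it. It is a trick that I learned from other MLIR contributors.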

```suggestion
// CHECK:             %[[M_1:.*]] = vector.load %[[COLLAPSED]][%[[IDX_1]]] : memref<300xf32>, vector<1xf32>
```
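To see why the escape is only needed in some cases, here is a simplified Python model of how FileCheck expands `[[NAME]]` uses and `{{regex}}` blocks. It is a sketch under stated assumptions, not FileCheck's real parser, which also handles definitions such as `[[NAME:regex]]`. The helper `compile_check` is hypothetical. Because the captured names exclude the `%`, the literal `[` in `[%[[IDX_1]]]` is never directly followed by `[[`. If the `%` were part of the capture, the pattern would contain `[[[`, and the lone `[` would have to be written as `{{\[}}`.

```python
import re

NAME_RE = re.compile(r"[A-Za-z_][A-Za-z0-9_]*")


def compile_check(pattern, bound):
    """Translate a simplified FileCheck pattern into a Python regex.

    Hypothetical subset: {{regex}} blocks and [[NAME]] uses of already-bound
    variables; every other character is matched literally. Like FileCheck,
    the first '[[' always starts a substitution.
    """
    out = []
    i = 0
    while i < len(pattern):
        if pattern.startswith("{{", i):
            end = pattern.index("}}", i + 2)
            out.append(pattern[i + 2:end])  # raw regex, e.g. an escaped bracket
            i = end + 2
        elif pattern.startswith("[[", i):
            end = pattern.index("]]", i + 2)
            name = pattern[i + 2:end]
            if not NAME_RE.fullmatch(name):
                raise ValueError(f"invalid variable name {name!r}")
            out.append(re.escape(bound[name]))
            i = end + 2
        else:
            out.append(re.escape(pattern[i]))  # literal character
            i += 1
    return "".join(out)


if __name__ == "__main__":
    bound = {"COLLAPSED": "collapse_shape", "IDX_1": "5"}
    line = "%7 = vector.load %collapse_shape[%5] : memref<300xf32>, vector<1xf32>"
    plain = compile_check("vector.load %[[COLLAPSED]][%[[IDX_1]]] : memref<300xf32>", bound)
    print(bool(re.search(plain, line)))  # True
```

With the `%` kept outside the capture, the escaped and unescaped spellings compile to equivalent regexes.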

https://github.com/llvm/llvm-project/pull/72991

