[Mlir-commits] [mlir] [mlir][vector] Canonicalize gathers/scatters with trivial offsets (PR #117939)

Ivan Butygin llvmlistbot at llvm.org
Thu Nov 28 17:42:46 PST 2024


================
@@ -2826,3 +2826,34 @@ func.func @contiguous_extract_strided_slices_to_extract_failure_non_full_inner_s
   %1 = vector.extract_strided_slice %arg0 {offsets = [0, 0, 0, 0, 0, 0], sizes = [1, 1, 2, 1, 1, 1], strides = [1, 1, 1, 1, 1, 1]} : vector<8x1x2x1x1x4xi32> to vector<1x1x2x1x1x1xi32>
   return %1 : vector<1x1x2x1x1x1xi32>
 }
+
+// -----
+
+// CHECK-LABEL: @contiguous_gather
+//  CHECK-SAME:   (%[[BASE:.*]]: memref<?xf32>, %[[MASK:.*]]: vector<16xi1>, %[[PASSTHRU:.*]]: vector<16xf32>)
+//       CHECK:   %[[C0:.*]] = arith.constant 0 : index
+//       CHECK:   %[[R:.*]] = vector.maskedload %[[BASE]][%[[C0]]], %[[MASK]], %[[PASSTHRU]] : memref<?xf32>, vector<16xi1>, vector<16xf32> into vector<16xf32>
+//       CHECK:   return %[[R]]
+func.func @contiguous_gather(%base: memref<?xf32>,
+                             %mask: vector<16xi1>, %passthru: vector<16xf32>) -> vector<16xf32> {
+  %c0 = arith.constant 0 : index
+  %indices = arith.constant dense<[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15]> : vector<16xi32>
----------------
Hardcode84 wrote:

* We don't need any special handling for constant masks: the existing masked -> non-masked canonicalizations already cover them. I added a couple of tests.
* I can add support for a non-zero start, but broadcast is more involved:
  * For scatters, duplicated indices are undefined per the current spec.
  * For gathers, we need `reduce(mask)` + 1-element `vector.maskedload` + `extract` + `splat`, and I would rather not do this as part of this PR.
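
To make the two cases concrete, here is a rough reference model in plain Python, not MLIR and not code from the PR. It treats vectors as lists and models only the lane-wise semantics. The function names are hypothetical, and the final per-lane select against `passthru` in the broadcast case is an assumption added so that masked-off lanes keep their passthru values; the comment above lists only the first four steps.

```python
def gather(base, indices, mask, passthru, start=0):
    # Lane i reads base[start + indices[i]] when mask[i] is set, else passthru[i].
    return [base[start + idx] if m else p
            for idx, m, p in zip(indices, mask, passthru)]


def masked_load(base, start, mask, passthru):
    # Lane i reads base[start + i] when mask[i] is set, else passthru[i].
    return [base[start + i] if m else p
            for i, (m, p) in enumerate(zip(mask, passthru))]


def broadcast_gather(base, start, mask, passthru):
    # Every lane uses the same offset, so a single guarded load is enough.
    any_lane = any(mask)                                   # reduce(mask)
    loaded = masked_load(base, start, [any_lane], [0.0])   # 1-element maskedload
    scalar = loaded[0]                                     # extract
    splat = [scalar] * len(mask)                           # splat
    # Assumption: select per lane so masked-off lanes keep passthru.
    return [s if m else p for s, m, p in zip(splat, mask, passthru)]


base = [10.0, 11.0, 12.0, 13.0]
mask = [True, False, True, False]
passthru = [-1.0] * 4

# Contiguous indices [0, 1, 2, 3]: the gather matches a masked load from offset 0.
contiguous = gather(base, [0, 1, 2, 3], mask, passthru)

# Broadcast indices [0, 0, 0, 0] with a non-zero start.
broadcast = broadcast_gather(base, 1, mask, passthru)
```

In this model the guarded 1-element load is what keeps the all-false-mask case from touching memory at all, which is why the `reduce(mask)` step is needed.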

https://github.com/llvm/llvm-project/pull/117939

