[Mlir-commits] [mlir] [mlir][vector] Allow multi dim vectors in vector.scatter (PR #132217)

Andrzej Warzyński llvmlistbot at llvm.org
Thu Mar 20 12:57:34 PDT 2025


================
@@ -1719,6 +1719,40 @@ func.func @gather_with_zero_mask_scalable(%arg0: memref<?xf32>, %arg1: vector<2x
 
 // -----
 
+//===----------------------------------------------------------------------===//
+// vector.scatter
+//===----------------------------------------------------------------------===//
+
+// Multi-Dimensional scatters are not supported yet. Check that we do not lower
+// them.
+
+func.func @scatter_with_mask(%arg0: memref<?xf32>, %arg1: vector<2x3xi32>, %arg2: vector<2x3xf32>) {
+  %0 = arith.constant 0: index
+  %1 = vector.constant_mask [2, 2] : vector<2x3xi1>
+  vector.scatter %arg0[%0][%arg1], %1, %arg2 : memref<?xf32>, vector<2x3xi32>, vector<2x3xi1>, vector<2x3xf32>
+  return
+}
+
+// CHECK-LABEL: func @scatter_with_mask
+// CHECK: vector.scatter
+
+// -----
+
+func.func @scatter_with_mask_scalable(%arg0: memref<?xf32>, %arg1: vector<2x[3]xi32>, %arg2: vector<2x[3]xf32>) {
+  %0 = arith.constant 0: index
+  // vector.constant_mask only supports 'none set' or 'all set' scalable
+  // dimensions, hence [1, 3] rather than [1, 2] as in the example for fixed
----------------
banach-space wrote:

[nit] Please make the comment consistent with the example.

https://github.com/llvm/llvm-project/pull/132217

