[Mlir-commits] [mlir] [mlir][vector] Allow multi dim vectors in vector.scatter (PR #132217)
Quinn Dawkins
llvmlistbot at llvm.org
Thu Mar 20 11:58:18 PDT 2025
================
@@ -1719,6 +1719,40 @@ func.func @gather_with_zero_mask_scalable(%arg0: memref<?xf32>, %arg1: vector<2x
// -----
+//===----------------------------------------------------------------------===//
+// vector.scatter
+//===----------------------------------------------------------------------===//
+
+// Multi-dimensional scatters are not supported yet. Check that we do not lower
+// them.
+
+func.func @scatter_with_mask(%arg0: memref<?xf32>, %arg1: vector<2x3xi32>, %arg2: vector<2x3xf32>) {
+ %0 = arith.constant 0: index
+ %1 = vector.constant_mask [2, 2] : vector<2x3xi1>
+ vector.scatter %arg0[%0][%arg1], %1, %arg2 : memref<?xf32>, vector<2x3xi32>, vector<2x3xi1>, vector<2x3xf32>
+ return
+}
+
+// CHECK-LABEL: func @scatter_with_mask
+// CHECK: vector.scatter
+
+// -----
+
+func.func @scatter_with_mask_scalable(%arg0: memref<?xf32>, %arg1: vector<2x[3]xi32>, %arg2: vector<2x[3]xf32>) {
+ %0 = arith.constant 0: index
+ // vector.constant_mask only supports 'none set' or 'all set' scalable
+ // dimensions, hence [2, 3] rather than [2, 2] as in the example for
+ // fixed-width vectors above.
+ %1 = vector.constant_mask [2, 3] : vector<2x[3]xi1>
+ vector.scatter %arg0[%0][%arg1], %1, %arg2 : memref<?xf32>, vector<2x[3]xi32>, vector<2x[3]xi1>, vector<2x[3]xf32>
+ return
+}
+
+// CHECK-LABEL: func @scatter_with_mask_scalable
+// CHECK: vector.scatter
----------------
qedawkins wrote:
Since the above test is already a negative test for multi-dim scatters, we don't need to add this one too just yet.
https://github.com/llvm/llvm-project/pull/132217
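The scalable test uses `vector.constant_mask [2, 3]` where the fixed-width test uses `[2, 2]`. This is because a scalable mask dimension may only be "none set" (0) or "all set" (its base size). The Python sketch below is a hypothetical model of that rule, not an MLIR API. `constant_mask`, its parameters, and the `vscale` argument are illustrative assumptions. It assumes that a scalable dimension of base size `n` has `n * vscale` elements at runtime. It also assumes an element is set exactly when every index lies below that dimension's mask bound.

```python
from itertools import product


def constant_mask(mask_dim_sizes, base_shape, scalable_dims, vscale=1):
    """Hypothetical model of a vector.constant_mask value (not an MLIR API).

    Returns a dict mapping each runtime index tuple to True/False.
    A scalable dim of base size n is assumed to have n * vscale elements,
    and its mask size may only be 0 ("none set") or n ("all set").
    """
    bounds = []
    runtime_shape = []
    for size, dim, is_scalable in zip(mask_dim_sizes, base_shape, scalable_dims):
        if not 0 <= size <= dim:
            raise ValueError(f"mask size {size} out of bounds for dim {dim}")
        if is_scalable:
            if size not in (0, dim):
                raise ValueError("scalable dims must be 'none set' or 'all set'")
            # Both the runtime extent and the mask bound scale with vscale.
            runtime_shape.append(dim * vscale)
            bounds.append(size * vscale)
        else:
            runtime_shape.append(dim)
            bounds.append(size)
    # An element is set iff each of its indices is below that dim's bound.
    return {
        idx: all(i < b for i, b in zip(idx, bounds))
        for idx in product(*(range(n) for n in runtime_shape))
    }
```

Under these assumptions, `constant_mask([2, 2], [2, 3], [False, False])` sets the first two columns of a `2x3` mask. Calling `constant_mask([2, 2], [2, 3], [False, True])` raises `ValueError`, which mirrors why the scalable test falls back to the all-set `[2, 3]`.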