[Mlir-commits] [mlir] [MLIR] [Vector] Fix canonicalization for vector.scatter with tensor output (PR #168824)
Ryutaro Okada
llvmlistbot at llvm.org
Tue Dec 9 04:08:26 PST 2025
================
@@ -3909,6 +3909,40 @@ func.func @contiguous_scatter_step(%base: memref<?xf32>,
// -----
+// No canonicalization should happen here as the base is a tensor.
+// CHECK-LABEL: @contiguous_scatter_tensor
+// CHECK-SAME: (%[[BASE:.*]]: tensor<16xf32>, %[[MASK:.*]]: vector<16xi1>, %[[VALUE:.*]]: vector<16xf32>) -> tensor<16xf32> {
+// CHECK: %[[C0:.*]] = arith.constant 0 : index
+// CHECK: %[[INDICES:.*]] = arith.constant dense<[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15]> : vector<16xi32>
+// CHECK: %[[SCATTER:.*]] = vector.scatter %[[BASE]][%[[C0]]] [%[[INDICES]]], %[[MASK]], %[[VALUE]] : tensor<16xf32>, vector<16xi32>, vector<16xi1>, vector<16xf32> -> tensor<16xf32>
+// CHECK: return %[[SCATTER]] : tensor<16xf32>
+func.func @contiguous_scatter_tensor(%base: tensor<16xf32>,
+ %mask: vector<16xi1>,
+ %value: vector<16xf32>) -> tensor<16xf32> {
+ %c0 = arith.constant 0 : index
+ %indices = arith.constant dense<[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15]> : vector<16xi32>
+ %0 = vector.scatter %base[%c0] [%indices], %mask, %value
+ : tensor<16xf32>, vector<16xi32>, vector<16xi1>, vector<16xf32> -> tensor<16xf32>
+ return %0 : tensor<16xf32>
+}
+
+// -----
+
+// CHECK-LABEL: @scatter_tensor_all_false
+// CHECK-SAME: (%[[BASE:.*]]: tensor<16xf32>, %[[INDEX:.*]]: vector<16xindex>, %[[VALUE:.*]]: vector<16xf32>) -> tensor<16xf32> {
+// CHECK: return %[[BASE]] : tensor<16xf32>
+func.func @scatter_tensor_all_false(%base: tensor<16xf32>,
+ %index: vector<16xindex>,
+ %value: vector<16xf32>) -> tensor<16xf32> {
+ %c0 = arith.constant 0 : index
+ %mask = arith.constant dense<false> : vector<16xi1>
+ %0 = vector.scatter %base[%c0][%index], %mask, %value
+ : tensor<16xf32>, vector<16xindex>, vector<16xi1>, vector<16xf32> -> tensor<16xf32>
+ return %0 : tensor<16xf32>
+}
----------------
sakupan102 wrote:
Added memref version of this test.
https://github.com/llvm/llvm-project/pull/168824
More information about the Mlir-commits
mailing list