[Mlir-commits] [mlir] [MLIR] [Vector] Fix canonicalization for vector.scatter with tensor output (PR #168824)

Ryutaro Okada llvmlistbot at llvm.org
Tue Dec 9 04:05:37 PST 2025


https://github.com/sakupan102 updated https://github.com/llvm/llvm-project/pull/168824

>From 34567b757d55e695f1d916828145fb640e167f13 Mon Sep 17 00:00:00 2001
From: Ryutaro Okada <1015ryu88 at gmail.com>
Date: Thu, 20 Nov 2025 13:57:33 +0900
Subject: [PATCH 1/5] [MLIR] [Vector] Fix canonicalization for vector.scatter
 with tensor output

Commit https://github.com/llvm/llvm-project/commit/7e7ea9c5357efcdf9ba6bd7ea3669e607a9af400 added tensor support to vector.scatter, but the existing canonicalization patterns were written for memref bases and cause bugs when they run on a scatter with a tensor base. This patch fixes the canonicalization of vector.scatter with a tensor output.

Closes https://github.com/llvm/llvm-project/issues/168695

Signed-off-by: Ryutaro Okada <1015ryu88 at gmail.com>
---
 mlir/lib/Dialect/Vector/IR/VectorOps.cpp   | 13 ++++++++++++-
 mlir/test/Dialect/Vector/canonicalize.mlir | 15 +++++++++++++++
 2 files changed, 27 insertions(+), 1 deletion(-)

diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index a97d0cd7f755b..5ede0d008e997 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -6087,11 +6087,19 @@ class ScatterFolder final : public OpRewritePattern<ScatterOp> {
   using Base::Base;
   LogicalResult matchAndRewrite(ScatterOp scatter,
                                 PatternRewriter &rewriter) const override {
+    ShapedType baseType = scatter.getBaseType();
+    bool isMemRef = isa<MemRefType>(baseType);
+    if (!isMemRef && !isa<RankedTensorType>(baseType))
+      return failure();
+
     switch (getMaskFormat(scatter.getMask())) {
     case MaskFormat::AllTrue:
       return failure(); // no unmasked equivalent
     case MaskFormat::AllFalse:
-      rewriter.eraseOp(scatter);
+      if (isMemRef)
+        rewriter.eraseOp(scatter);
+      else
+        rewriter.replaceOp(scatter, scatter.getBase());
       return success();
     case MaskFormat::Unknown:
       return failure();
@@ -6107,6 +6115,9 @@ class FoldContiguousScatter final : public OpRewritePattern<ScatterOp> {
   using Base::Base;
   LogicalResult matchAndRewrite(ScatterOp op,
                                 PatternRewriter &rewriter) const override {
+    if (!isa<MemRefType>(op.getBase().getType()))
+      return failure();
+
     if (failed(isZeroBasedContiguousSeq(op.getIndices())))
       return failure();
 
diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir
index 084f49fca212f..fea6c39f05187 100644
--- a/mlir/test/Dialect/Vector/canonicalize.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize.mlir
@@ -3909,6 +3909,21 @@ func.func @contiguous_scatter_step(%base: memref<?xf32>,
 
 // -----
 
+// CHECK-LABEL: @scatter_tensor_all_false
+//  CHECK-SAME:   (%[[BASE:.*]]: tensor<16xf32>, %[[INDEX:.*]]: vector<16xindex>, %[[VALUE:.*]]: vector<16xf32>) -> tensor<16xf32> {
+//       CHECK:   return %[[BASE]] : tensor<16xf32>
+func.func @scatter_tensor_all_false(%base: tensor<16xf32>,
+                                    %index: vector<16xindex>,
+                                    %value: vector<16xf32>) -> tensor<16xf32> {
+  %c0 = arith.constant 0 : index
+  %mask = arith.constant dense<false> : vector<16xi1>
+  %0 = vector.scatter %base[%c0][%index], %mask, %value
+      : tensor<16xf32>, vector<16xindex>, vector<16xi1>, vector<16xf32> -> tensor<16xf32>
+  return %0 : tensor<16xf32>
+}
+
+// -----
+
 // CHECK-LABEL: @fold_extract_constant_indices
 //   CHECK-SAME:   %[[ARG:.*]]: vector<32x1xi32>) -> i32 {
 //        CHECK:   %[[RES:.*]] = vector.extract %[[ARG]][0, 0] : i32 from vector<32x1xi32>

>From e57730e282b5c3274278b4b9557ae604aaf7d722 Mon Sep 17 00:00:00 2001
From: Ryutaro Okada <1015ryu88 at gmail.com>
Date: Tue, 9 Dec 2025 00:27:15 +0900
Subject: [PATCH 2/5] Add test for contiguous scatter tensor

Signed-off-by: Ryutaro Okada <1015ryu88 at gmail.com>
---
 mlir/test/Dialect/Vector/canonicalize.mlir | 19 +++++++++++++++++++
 1 file changed, 19 insertions(+)

diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir
index fea6c39f05187..91fd2f65b630f 100644
--- a/mlir/test/Dialect/Vector/canonicalize.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize.mlir
@@ -3909,6 +3909,25 @@ func.func @contiguous_scatter_step(%base: memref<?xf32>,
 
 // -----
 
+// No canonicalization should happen here as the base is a tensor.
+// CHECK-LABEL: @contiguous_scatter_tensor
+//  CHECK-SAME:   (%[[BASE:.*]]: tensor<16xf32>, %[[MASK:.*]]: vector<16xi1>, %[[VALUE:.*]]: vector<16xf32>) -> tensor<16xf32> {
+//       CHECK:   %[[C0:.*]] = arith.constant 0 : index
+//       CHECK:   %[[INDICES:.*]] = arith.constant dense<[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15]> : vector<16xi32>
+//       CHECK:   %[[SCATTER:.*]] = vector.scatter %[[BASE]][%[[C0]]] [%[[INDICES]]], %[[MASK]], %[[VALUE]] : tensor<16xf32>, vector<16xi32>, vector<16xi1>, vector<16xf32> -> tensor<16xf32>
+//       CHECK:   return %[[SCATTER]] : tensor<16xf32>
+func.func @contiguous_scatter_tensor(%base: tensor<16xf32>,
+                                          %mask: vector<16xi1>,
+                                          %value: vector<16xf32>) -> tensor<16xf32> {
+  %c0 = arith.constant 0 : index
+  %indices = arith.constant dense<[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15]> : vector<16xi32>
+  %0 = vector.scatter %base[%c0] [%indices], %mask, %value
+      : tensor<16xf32>, vector<16xi32>, vector<16xi1>, vector<16xf32> -> tensor<16xf32>
+  return %0 : tensor<16xf32>
+}
+
+// -----
+
 // CHECK-LABEL: @scatter_tensor_all_false
 //  CHECK-SAME:   (%[[BASE:.*]]: tensor<16xf32>, %[[INDEX:.*]]: vector<16xindex>, %[[VALUE:.*]]: vector<16xf32>) -> tensor<16xf32> {
 //       CHECK:   return %[[BASE]] : tensor<16xf32>

>From e43659caee13f790afbbbc7b84f1edc5da123449 Mon Sep 17 00:00:00 2001
From: Ryutaro Okada <1015ryu88 at gmail.com>
Date: Tue, 9 Dec 2025 00:30:17 +0900
Subject: [PATCH 3/5] add comments

Signed-off-by: Ryutaro Okada <1015ryu88 at gmail.com>
---
 mlir/lib/Dialect/Vector/IR/VectorOps.cpp | 5 +++++
 1 file changed, 5 insertions(+)

diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 5ede0d008e997..97bbce8b9e683 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -6092,6 +6092,9 @@ class ScatterFolder final : public OpRewritePattern<ScatterOp> {
     if (!isMemRef && !isa<RankedTensorType>(baseType))
       return failure();
 
+    // Memrefs have no result, so an all-false mask can simply erase the op.
+    // Tensors carry the updated value, so we must replace uses with the
+    // original base tensor instead of erasing.
     switch (getMaskFormat(scatter.getMask())) {
     case MaskFormat::AllTrue:
       return failure(); // no unmasked equivalent
@@ -6115,6 +6118,8 @@ class FoldContiguousScatter final : public OpRewritePattern<ScatterOp> {
   using Base::Base;
   LogicalResult matchAndRewrite(ScatterOp op,
                                 PatternRewriter &rewriter) const override {
+    // Fold only for memrefs: the replacement uses maskedstore, which does not
+    // support tensor bases. Tensor cases intentionally bail out.
     if (!isa<MemRefType>(op.getBase().getType()))
       return failure();
 

>From 17b271fcc80d1ca74dd895b16e3333488a14f181 Mon Sep 17 00:00:00 2001
From: Ryutaro Okada <1015ryu88 at gmail.com>
Date: Tue, 9 Dec 2025 20:54:57 +0900
Subject: [PATCH 4/5] fix test case

Signed-off-by: Ryutaro Okada <1015ryu88 at gmail.com>
---
 mlir/test/Dialect/Vector/canonicalize.mlir | 12 +++++-------
 1 file changed, 5 insertions(+), 7 deletions(-)

diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir
index 91fd2f65b630f..a19e46f56c293 100644
--- a/mlir/test/Dialect/Vector/canonicalize.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize.mlir
@@ -3910,13 +3910,11 @@ func.func @contiguous_scatter_step(%base: memref<?xf32>,
 // -----
 
 // No canonicalization should happen here as the base is a tensor.
-// CHECK-LABEL: @contiguous_scatter_tensor
-//  CHECK-SAME:   (%[[BASE:.*]]: tensor<16xf32>, %[[MASK:.*]]: vector<16xi1>, %[[VALUE:.*]]: vector<16xf32>) -> tensor<16xf32> {
-//       CHECK:   %[[C0:.*]] = arith.constant 0 : index
-//       CHECK:   %[[INDICES:.*]] = arith.constant dense<[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15]> : vector<16xi32>
-//       CHECK:   %[[SCATTER:.*]] = vector.scatter %[[BASE]][%[[C0]]] [%[[INDICES]]], %[[MASK]], %[[VALUE]] : tensor<16xf32>, vector<16xi32>, vector<16xi1>, vector<16xf32> -> tensor<16xf32>
-//       CHECK:   return %[[SCATTER]] : tensor<16xf32>
-func.func @contiguous_scatter_tensor(%base: tensor<16xf32>,
+// CHECK-LABEL: @no_fold_contiguous_scatter_tensor
+// CHECK-NOT: vector.maskedstore
+//     CHECK:   %[[RES:.*]] = vector.scatter
+//     CHECK:   return %[[RES]]
+func.func @no_fold_contiguous_scatter_tensor(%base: tensor<16xf32>,
                                           %mask: vector<16xi1>,
                                           %value: vector<16xf32>) -> tensor<16xf32> {
   %c0 = arith.constant 0 : index

>From a237a23094699085e95decca2ce3602963401d83 Mon Sep 17 00:00:00 2001
From: Ryutaro Okada <1015ryu88 at gmail.com>
Date: Tue, 9 Dec 2025 21:05:15 +0900
Subject: [PATCH 5/5] add test for scatter memref

Signed-off-by: Ryutaro Okada <1015ryu88 at gmail.com>
---
 mlir/test/Dialect/Vector/canonicalize.mlir | 15 +++++++++++++++
 1 file changed, 15 insertions(+)

diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir
index a19e46f56c293..17746c4faac06 100644
--- a/mlir/test/Dialect/Vector/canonicalize.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize.mlir
@@ -3926,6 +3926,21 @@ func.func @no_fold_contiguous_scatter_tensor(%base: tensor<16xf32>,
 
 // -----
 
+// CHECK-LABEL: @scatter_memref_all_false
+//  CHECK-SAME:   (%[[BASE:.*]]: memref<?xf32>, %[[INDEX:.*]]: vector<16xindex>, %[[VALUE:.*]]: vector<16xf32>)
+//  CHECK-NEXT:   return
+func.func @scatter_memref_all_false(%base: memref<?xf32>,
+                                    %index: vector<16xindex>,
+                                    %value: vector<16xf32>) {
+  %c0 = arith.constant 0 : index
+  %mask = arith.constant dense<false> : vector<16xi1>
+  vector.scatter %base[%c0][%index], %mask, %value
+      : memref<?xf32>, vector<16xindex>, vector<16xi1>, vector<16xf32>
+  return
+}
+
+// -----
+
 // CHECK-LABEL: @scatter_tensor_all_false
 //  CHECK-SAME:   (%[[BASE:.*]]: tensor<16xf32>, %[[INDEX:.*]]: vector<16xindex>, %[[VALUE:.*]]: vector<16xf32>) -> tensor<16xf32> {
 //       CHECK:   return %[[BASE]] : tensor<16xf32>



More information about the Mlir-commits mailing list