[Mlir-commits] [mlir] [MLIR] [Vector] Disable canonicalization for vector.scatter with tensor output (PR #168824)

Ryutaro Okada llvmlistbot at llvm.org
Wed Nov 19 20:58:43 PST 2025


https://github.com/sakupan102 created https://github.com/llvm/llvm-project/pull/168824

Commit https://github.com/llvm/llvm-project/commit/7e7ea9c5357efcdf9ba6bd7ea3669e607a9af400 added tensor support to `vector.scatter`, but the existing canonicalization patterns cause bugs when the base is a tensor. This patch makes `ScatterFolder` and `FoldContiguousScatter` return `failure()` unless the base is a memref, which disables the canonicalization for the tensor form.

Closes https://github.com/llvm/llvm-project/issues/168695
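
For readers unfamiliar with the patterns, the guard can be illustrated with a small standalone sketch. It does not use the MLIR API: `BaseKind`, `ToyScatter`, and `foldScatter` are hypothetical stand-ins, and the assumption that the all-false mask case erases the op comes from outside the hunks in this patch. The point is only that a pattern which erases the op is harmless for a memref base (no result) but not for a tensor base (the result may still have uses), so the pattern bails out early.

```cpp
// Hypothetical stand-ins for illustration only; not the real MLIR types.
enum class BaseKind { MemRef, Tensor };
enum class MaskFormat { AllTrue, AllFalse, Unknown };

struct ToyScatter {
  BaseKind base;   // kind of the scatter's base operand
  MaskFormat mask; // statically known shape of the mask, if any
};

// Returns true when the toy "pattern" would rewrite (here: erase) the op,
// and false when it declines to match, mirroring success()/failure().
constexpr bool foldScatter(const ToyScatter &op) {
  // Guard corresponding to the one added by the patch: only scatters on a
  // memref base are canonicalized; anything else is left untouched.
  if (op.base != BaseKind::MemRef)
    return false;

  switch (op.mask) {
  case MaskFormat::AllTrue:
    return false; // no unmasked equivalent
  case MaskFormat::AllFalse:
    return true; // assumed behavior: a fully masked-off scatter is erased
  case MaskFormat::Unknown:
    return false;
  }
  return false;
}
```

With this sketch, `foldScatter({BaseKind::MemRef, MaskFormat::AllFalse})` matches, while the same mask on a `BaseKind::Tensor` base does not.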

From f5b54e5f41c285b52d1e22b7fb9912c39088722e Mon Sep 17 00:00:00 2001
From: Ryutaro Okada <1015ryu88 at gmail.com>
Date: Thu, 20 Nov 2025 13:57:33 +0900
Subject: [PATCH] [MLIR] [Vector] Disable canonicalization for vector.scatter
 with tensor output

Commit https://github.com/llvm/llvm-project/commit/7e7ea9c5357efcdf9ba6bd7ea3669e607a9af400 added tensor support to vector.scatter, but the existing canonicalization patterns cause bugs when the base is a tensor. Make ScatterFolder and FoldContiguousScatter return failure() unless the base is a memref, which disables the canonicalization for the tensor form.

Closes https://github.com/llvm/llvm-project/issues/168695

Signed-off-by: Ryutaro Okada <1015ryu88 at gmail.com>
---
 mlir/lib/Dialect/Vector/IR/VectorOps.cpp | 6 ++++++
 1 file changed, 6 insertions(+)

diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index a97d0cd7f755b..c4d49334602db 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -6087,6 +6087,9 @@ class ScatterFolder final : public OpRewritePattern<ScatterOp> {
   using Base::Base;
   LogicalResult matchAndRewrite(ScatterOp scatter,
                                 PatternRewriter &rewriter) const override {
+    if (!isa<MemRefType>(scatter.getBase().getType()))
+      return failure();
+
     switch (getMaskFormat(scatter.getMask())) {
     case MaskFormat::AllTrue:
       return failure(); // no unmasked equivalent
@@ -6107,6 +6110,9 @@ class FoldContiguousScatter final : public OpRewritePattern<ScatterOp> {
   using Base::Base;
   LogicalResult matchAndRewrite(ScatterOp op,
                                 PatternRewriter &rewriter) const override {
+    if (!isa<MemRefType>(op.getBase().getType()))
+      return failure();
+
     if (failed(isZeroBasedContiguousSeq(op.getIndices())))
       return failure();
 


