[Mlir-commits] [mlir] [MLIR] [Vector] Disable canonicalization for vector.scatter with tensor output (PR #168824)
llvmlistbot at llvm.org
Wed Nov 19 20:59:15 PST 2025
llvmbot wrote:
@llvm/pr-subscribers-mlir-vector
Author: Ryutaro Okada (sakupan102)
<details>
<summary>Changes</summary>
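Commit https://github.com/llvm/llvm-project/commit/7e7ea9c5357efcdf9ba6bd7ea3669e607a9af400 added tensor support for `vector.scatter`, but running the existing canonicalization patterns (`ScatterFolder` and `FoldContiguousScatter`) on tensors causes bugs. This change therefore makes both patterns bail out unless the base is a memref, i.e. whenever the scatter operates on a tensor and produces a tensor result.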
Closes https://github.com/llvm/llvm-project/issues/168695
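One plausible way to see why the tensor case breaks is the toy C++ model below. It uses plain `std::vector`s and hypothetical names (`BaseKind`, `classifyMask`, `toyScatter`, `canEraseScatter`); none of it is the MLIR API. It assumes, as the tensor-support commit suggests, that the tensor form of `vector.scatter` returns a new tensor value. Under that assumption an all-false mask makes the memref form a pure no-op that `ScatterFolder` can erase, while the tensor form still yields a value (equal to the base) that later ops may use, so simply erasing it is not a valid rewrite.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical stand-ins for MLIR concepts (NOT the MLIR C++ API).
enum class BaseKind { MemRef, Tensor };
enum class MaskFormat { AllTrue, AllFalse, Unknown };

// Classify a known (constant) 1-D mask, in the spirit of getMaskFormat.
MaskFormat classifyMask(const std::vector<bool> &mask) {
  if (std::all_of(mask.begin(), mask.end(), [](bool b) { return b; }))
    return MaskFormat::AllTrue;
  if (std::none_of(mask.begin(), mask.end(), [](bool b) { return b; }))
    return MaskFormat::AllFalse;
  return MaskFormat::Unknown;
}

// Toy scatter with value (tensor-like) semantics: returns a copy of `base`
// with base[idx[i]] = vals[i] for every lane i whose mask bit is set.
std::vector<int> toyScatter(std::vector<int> base, const std::vector<int> &idx,
                            const std::vector<bool> &mask,
                            const std::vector<int> &vals) {
  for (std::size_t i = 0; i < idx.size(); ++i)
    if (mask[i])
      base[static_cast<std::size_t>(idx[i])] = vals[i];
  return base;
}

// Models ScatterFolder's decision after this PR: erasing the op is only
// offered when the base is a memref and the mask is all false.
bool canEraseScatter(BaseKind base, const std::vector<bool> &mask) {
  if (base != BaseKind::MemRef)
    return false; // mirrors the new `!isa<MemRefType>(...)` early exit
  return classifyMask(mask) == MaskFormat::AllFalse;
}
```

With the guard, `canEraseScatter(BaseKind::Tensor, {false, false})` returns false, so a tensor scatter is left in place instead of being erased while its result may still have uses.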
---
Full diff: https://github.com/llvm/llvm-project/pull/168824.diff
1 File Affected:
- (modified) mlir/lib/Dialect/Vector/IR/VectorOps.cpp (+6)
``````````diff
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index a97d0cd7f755b..c4d49334602db 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -6087,6 +6087,9 @@ class ScatterFolder final : public OpRewritePattern<ScatterOp> {
   using Base::Base;
   LogicalResult matchAndRewrite(ScatterOp scatter,
                                 PatternRewriter &rewriter) const override {
+    if (!isa<MemRefType>(scatter.getBase().getType()))
+      return failure();
+
     switch (getMaskFormat(scatter.getMask())) {
     case MaskFormat::AllTrue:
       return failure(); // no unmasked equivalent
@@ -6107,6 +6110,9 @@ class FoldContiguousScatter final : public OpRewritePattern<ScatterOp> {
   using Base::Base;
   LogicalResult matchAndRewrite(ScatterOp op,
                                 PatternRewriter &rewriter) const override {
+    if (!isa<MemRefType>(op.getBase().getType()))
+      return failure();
+
     if (failed(isZeroBasedContiguousSeq(op.getIndices())))
       return failure();
``````````
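The second guarded pattern, `FoldContiguousScatter`, appears (judging from its name and the `isZeroBasedContiguousSeq` check; treat this as an assumption) to rewrite a scatter whose indices are `0, 1, ..., n-1` into a masked store at the same offset. The 1-D sketch below, again with hypothetical names and a `std::vector` standing in for a memref, shows why that rewrite is sound for in-place memory: with zero-based contiguous indices both forms write exactly the same lanes. It also models the new early exit that leaves tensor bases untouched.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical stand-in for the base type (NOT the MLIR C++ API).
enum class BaseTypeKind { MemRef, Tensor };

// Toy analogue of isZeroBasedContiguousSeq: true iff idx == {0, 1, ..., n-1}.
bool isZeroBasedContiguous(const std::vector<int> &idx) {
  for (std::size_t i = 0; i < idx.size(); ++i)
    if (idx[i] != static_cast<int>(i))
      return false;
  return true;
}

// In-place 1-D scatter: mem[offset + idx[i]] = vals[i] where mask[i] is set.
void toyScatterInto(std::vector<int> &mem, std::size_t offset,
                    const std::vector<int> &idx, const std::vector<bool> &mask,
                    const std::vector<int> &vals) {
  for (std::size_t i = 0; i < idx.size(); ++i)
    if (mask[i])
      mem[offset + static_cast<std::size_t>(idx[i])] = vals[i];
}

// In-place 1-D masked store: mem[offset + i] = vals[i] where mask[i] is set.
void toyMaskedStore(std::vector<int> &mem, std::size_t offset,
                    const std::vector<bool> &mask,
                    const std::vector<int> &vals) {
  for (std::size_t i = 0; i < vals.size(); ++i)
    if (mask[i])
      mem[offset + i] = vals[i];
}

// Models FoldContiguousScatter's precondition after this PR.
bool canFoldToMaskedStore(BaseTypeKind base, const std::vector<int> &idx) {
  if (base != BaseTypeKind::MemRef)
    return false; // new early exit: tensor scatters are left alone
  return isZeroBasedContiguous(idx);
}
```

Because the model is purely in-place, it says nothing about the value a tensor scatter returns; that gap is exactly what the new `isa<MemRefType>` check sidesteps.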
</details>
https://github.com/llvm/llvm-project/pull/168824