[Mlir-commits] [mlir] 04b1975 - [MLIR] [Vector] Fix canonicalization for vector.scatter with tensor output (#168824)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Fri Dec 12 04:24:41 PST 2025
Author: Ryutaro Okada
Date: 2025-12-12T12:24:37Z
New Revision: 04b197599e73f55624d3514480a7f829343e5b61
URL: https://github.com/llvm/llvm-project/commit/04b197599e73f55624d3514480a7f829343e5b61
DIFF: https://github.com/llvm/llvm-project/commit/04b197599e73f55624d3514480a7f829343e5b61.diff
LOG: [MLIR] [Vector] Fix canonicalization for vector.scatter with tensor output (#168824)
Commit
https://github.com/llvm/llvm-project/commit/7e7ea9c5357efcdf9ba6bd7ea3669e607a9af400
added tensor support for scatter, but applying the existing
canonicalization patterns to tensor-based scatters caused bugs, so this
change fixes the canonicalization for scatters with tensor output.
Closes https://github.com/llvm/llvm-project/issues/168695
---------
Signed-off-by: Ryutaro Okada <1015ryu88 at gmail.com>
Added:
Modified:
mlir/lib/Dialect/Vector/IR/VectorOps.cpp
mlir/test/Dialect/Vector/canonicalize.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 58b3fe02e5310..12bdc9646ee84 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -6100,11 +6100,22 @@ class ScatterFolder final : public OpRewritePattern<ScatterOp> {
using Base::Base;
LogicalResult matchAndRewrite(ScatterOp scatter,
PatternRewriter &rewriter) const override {
+ ShapedType baseType = scatter.getBaseType();
+ bool isMemRef = isa<MemRefType>(baseType);
+ if (!isMemRef && !isa<RankedTensorType>(baseType))
+ return failure();
+
+ // Memrefs have no result, so an all-false mask can simply erase the op.
+ // Tensors carry the updated value, so we must replace uses with the
+ // original base tensor instead of erasing.
switch (getMaskFormat(scatter.getMask())) {
case MaskFormat::AllTrue:
return failure(); // no unmasked equivalent
case MaskFormat::AllFalse:
- rewriter.eraseOp(scatter);
+ if (isMemRef)
+ rewriter.eraseOp(scatter);
+ else
+ rewriter.replaceOp(scatter, scatter.getBase());
return success();
case MaskFormat::Unknown:
return failure();
@@ -6120,6 +6131,11 @@ class FoldContiguousScatter final : public OpRewritePattern<ScatterOp> {
using Base::Base;
LogicalResult matchAndRewrite(ScatterOp op,
PatternRewriter &rewriter) const override {
+ // Fold only for memrefs: the replacement uses maskedstore, which does not
+ // support tensor bases. Tensor cases intentionally bail out.
+ if (!isa<MemRefType>(op.getBase().getType()))
+ return failure();
+
if (failed(isZeroBasedContiguousSeq(op.getIndices())))
return failure();
diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir
index 50d52c92296d0..e17b1cfbe5e0d 100644
--- a/mlir/test/Dialect/Vector/canonicalize.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize.mlir
@@ -3928,6 +3928,53 @@ func.func @contiguous_scatter_step(%base: memref<?xf32>,
// -----
+// No canonicalization should happen here as the base is a tensor.
+// CHECK-LABEL: @no_fold_contiguous_scatter_tensor
+// CHECK-NOT: vector.maskedstore
+// CHECK: %[[RES:.*]] = vector.scatter
+// CHECK: return %[[RES]]
+func.func @no_fold_contiguous_scatter_tensor(%base: tensor<16xf32>,
+ %mask: vector<16xi1>,
+ %value: vector<16xf32>) -> tensor<16xf32> {
+ %c0 = arith.constant 0 : index
+ %indices = arith.constant dense<[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15]> : vector<16xi32>
+ %0 = vector.scatter %base[%c0] [%indices], %mask, %value
+ : tensor<16xf32>, vector<16xi32>, vector<16xi1>, vector<16xf32> -> tensor<16xf32>
+ return %0 : tensor<16xf32>
+}
+
+// -----
+
+// CHECK-LABEL: @scatter_memref_all_false
+// CHECK-SAME: (%[[BASE:.*]]: memref<?xf32>, %[[INDEX:.*]]: vector<16xindex>, %[[VALUE:.*]]: vector<16xf32>)
+// CHECK-NEXT: return
+func.func @scatter_memref_all_false(%base: memref<?xf32>,
+ %index: vector<16xindex>,
+ %value: vector<16xf32>) {
+ %c0 = arith.constant 0 : index
+ %mask = arith.constant dense<false> : vector<16xi1>
+ vector.scatter %base[%c0][%index], %mask, %value
+ : memref<?xf32>, vector<16xindex>, vector<16xi1>, vector<16xf32>
+ return
+}
+
+// -----
+
+// CHECK-LABEL: @scatter_tensor_all_false
+// CHECK-SAME: (%[[BASE:.*]]: tensor<16xf32>, %[[INDEX:.*]]: vector<16xindex>, %[[VALUE:.*]]: vector<16xf32>) -> tensor<16xf32> {
+// CHECK: return %[[BASE]] : tensor<16xf32>
+func.func @scatter_tensor_all_false(%base: tensor<16xf32>,
+ %index: vector<16xindex>,
+ %value: vector<16xf32>) -> tensor<16xf32> {
+ %c0 = arith.constant 0 : index
+ %mask = arith.constant dense<false> : vector<16xi1>
+ %0 = vector.scatter %base[%c0][%index], %mask, %value
+ : tensor<16xf32>, vector<16xindex>, vector<16xi1>, vector<16xf32> -> tensor<16xf32>
+ return %0 : tensor<16xf32>
+}
+
+// -----
+
// CHECK-LABEL: @fold_extract_constant_indices
// CHECK-SAME: %[[ARG:.*]]: vector<32x1xi32>) -> i32 {
// CHECK: %[[RES:.*]] = vector.extract %[[ARG]][0, 0] : i32 from vector<32x1xi32>
More information about the Mlir-commits
mailing list