[Mlir-commits] [mlir] [mlir][vector] Prevent folding non memref-type gather into maskedload (PR #135371)
Sagar Kulkarni
llvmlistbot at llvm.org
Fri Apr 11 13:53:28 PDT 2025
https://github.com/sagarkulkarni19 updated https://github.com/llvm/llvm-project/pull/135371
From 34b56da4b661b8fa40e916c79d7c7748d19244c4 Mon Sep 17 00:00:00 2001
From: Sagar Kulkarni <sagar at rain.ai>
Date: Fri, 11 Apr 2025 09:48:40 -0400
Subject: [PATCH] [mlir][vector] Prevent folding non memref-type gather into
maskedload
This patch fixes an issue in the FoldContiguousGather pattern, which
folded vector.gather operations with contiguous indices into
vector.maskedload operations regardless of the base operand type.
vector.gather accepts a base of either tensor or memref type, but
vector.maskedload only accepts a memref base. As a result, the pattern
rewrote tensor-based gathers into vector.maskedload operations,
producing invalid IR.
This fix adds a type check so that the pattern only applies to gathers
whose base is a memref.
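A minimal sketch of the match decision after this patch, in plain C++. `BaseKind`, `GatherLike`, `isZeroBasedContiguous`, and `canFoldToMaskedLoad` are hypothetical stand-ins, not MLIR APIs; the real pattern inspects `op.getBase().getType()` and calls `isZeroBasedContiguousSeq(op.getIndexVec())`, and it may check further conditions not shown in the hunk.

```cpp
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical stand-ins for the MLIR types involved; these are NOT the real
// MLIR classes, just enough structure to show the decision logic.
enum class BaseKind { Tensor, MemRef };

struct GatherLike {
  BaseKind base;                 // type of the gather's base operand
  std::vector<int64_t> indices;  // constant index vector
};

// Plays the role of isZeroBasedContiguousSeq: indices must be 0, 1, ..., n-1.
bool isZeroBasedContiguous(const std::vector<int64_t> &indices) {
  for (std::size_t i = 0; i < indices.size(); ++i)
    if (indices[i] != static_cast<int64_t>(i))
      return false;
  return true;
}

// Returns true only when the gather may be rewritten into a masked load.
bool canFoldToMaskedLoad(const GatherLike &op) {
  // The guard added by this patch: vector.maskedload needs a memref base,
  // so a tensor base is rejected before the indices are even examined.
  if (op.base != BaseKind::MemRef)
    return false;
  return isZeroBasedContiguous(op.indices);
}
```

Because the type guard runs first, a tensor-based gather with indices 0, 1, 2, 3 is rejected even though its indices are contiguous, which is the case the new `@no_fold_contiguous_gather_tensor` test exercises.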
---
mlir/lib/Dialect/Vector/IR/VectorOps.cpp | 3 +++
mlir/test/Dialect/Vector/canonicalize.mlir | 13 +++++++++++++
2 files changed, 16 insertions(+)
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 98d98f067de14..e4de9757254c9 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -5340,6 +5340,9 @@ class FoldContiguousGather final : public OpRewritePattern<GatherOp> {
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(GatherOp op,
PatternRewriter &rewriter) const override {
+ if (!isa<MemRefType>(op.getBase().getType()))
+ return rewriter.notifyMatchFailure(op, "base must be of memref type");
+
if (failed(isZeroBasedContiguousSeq(op.getIndexVec())))
return failure();
diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir
index b7db8ec834be7..2d2c7a227a6cb 100644
--- a/mlir/test/Dialect/Vector/canonicalize.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize.mlir
@@ -3149,6 +3149,19 @@ func.func @contiguous_gather_step(%base: memref<?xf32>,
// -----
+// CHECK-LABEL: @no_fold_contiguous_gather_tensor
+func.func @no_fold_contiguous_gather_tensor(%base: tensor<8xf32>, %mask: vector<4xi1>, %pass_thru: vector<4xf32>) -> vector<4xf32> {
+ %c0 = arith.constant 0 : index
+ %indices = arith.constant dense<[0, 1, 2, 3]> : vector<4xindex>
+ // CHECK: vector.gather
+ // CHECK-NOT: vector.maskedload
+ %0 = vector.gather %base[%c0][%indices], %mask, %pass_thru :
+ tensor<8xf32>, vector<4xindex>, vector<4xi1>, vector<4xf32> into vector<4xf32>
+ return %0 : vector<4xf32>
+}
+
+// -----
+
// CHECK-LABEL: @gather_broadcast(
// TODO: Broadcast is not supported yet
// CHECK: %[[R:.*]] = vector.gather