[Mlir-commits] [mlir] 357e380 - [mlir][vector] Prevent folding non memref-type gather into maskedload (#135371)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Fri Apr 11 18:15:54 PDT 2025


Author: Sagar Kulkarni
Date: 2025-04-12T04:15:51+03:00
New Revision: 357e3803bb94cc622c785f7eb60aa38d552bc5ef

URL: https://github.com/llvm/llvm-project/commit/357e3803bb94cc622c785f7eb60aa38d552bc5ef
DIFF: https://github.com/llvm/llvm-project/commit/357e3803bb94cc622c785f7eb60aa38d552bc5ef.diff

LOG: [mlir][vector] Prevent folding non memref-type gather into maskedload (#135371)

This patch fixes an issue in the FoldContiguousGather pattern which was
incorrectly folding vector.gather operations with contiguous indices
into vector.maskedload operations regardless of the base operand type.
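
Why is this fold legal for memrefs in the first place? The sketch below is a
1-D model in plain C++. It assumes the usual semantics: a gather lane i
reads base[offset + indices[i]], and a masked load lane i reads
base[offset + i]. When indices is 0, 1, ..., n-1 the two reads coincide,
so the two ops agree lane by lane. modelGather and modelMaskedLoad are
hypothetical helpers written for illustration only. They are not MLIR
APIs, and they ignore multi-dimensional bases and bounds checks.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical 1-D model of vector.gather: lane i reads base[offset + indices[i]]
// when mask[i] is set, otherwise it keeps passThru[i].
std::vector<float> modelGather(const std::vector<float> &base, std::size_t offset,
                               const std::vector<std::size_t> &indices,
                               const std::vector<bool> &mask,
                               const std::vector<float> &passThru) {
  assert(indices.size() == mask.size() && mask.size() == passThru.size());
  std::vector<float> result(passThru);
  for (std::size_t i = 0; i < indices.size(); ++i)
    if (mask[i])
      result[i] = base[offset + indices[i]];
  return result;
}

// Hypothetical 1-D model of vector.maskedload: lane i reads base[offset + i]
// when mask[i] is set, otherwise it keeps passThru[i].
std::vector<float> modelMaskedLoad(const std::vector<float> &base, std::size_t offset,
                                   const std::vector<bool> &mask,
                                   const std::vector<float> &passThru) {
  assert(mask.size() == passThru.size());
  std::vector<float> result(passThru);
  for (std::size_t i = 0; i < mask.size(); ++i)
    if (mask[i])
      result[i] = base[offset + i];
  return result;
}
```

With indices {0, 1, 2, 3} both models return the same vector for any mask
and offset. A permuted index vector such as {3, 2, 1, 0} generally does
not, which is why the pattern first checks for a zero-based contiguous
sequence.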

While vector.gather operations can work on both tensor and memref types,
vector.maskedload operations are only valid for memref types. The
pattern was nevertheless rewriting tensor-based gathers into
vector.maskedload operations, producing invalid IR.

This fix adds a type check to ensure the pattern only applies to
memref-based gather operations.
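
The order of the checks after this fix can be modelled in plain C++.
BaseKind, isZeroBasedContiguous and checkFoldContiguousGather are
hypothetical stand-ins for the MLIR types and for
isZeroBasedContiguousSeq. The real helper also recognises other index
producers such as vector.step, and the real pattern returns a bare
failure() for non-contiguous indices. The second message string below is
invented for illustration.

```cpp
#include <cassert>
#include <cstddef>
#include <optional>
#include <string>
#include <vector>

// Hypothetical stand-in for the type of the gather's base operand.
enum class BaseKind { MemRef, Tensor };

// Hypothetical model of isZeroBasedContiguousSeq for a constant index vector:
// true only for the sequence 0, 1, 2, ..., n-1.
bool isZeroBasedContiguous(const std::vector<long> &indices) {
  for (std::size_t i = 0; i < indices.size(); ++i)
    if (indices[i] != static_cast<long>(i))
      return false;
  return true;
}

// Returns std::nullopt when the fold may fire, otherwise a failure reason.
std::optional<std::string> checkFoldContiguousGather(BaseKind base,
                                                     const std::vector<long> &indices) {
  // The new guard: vector.maskedload only accepts memrefs, so bail out on
  // tensors before looking at the indices at all.
  if (base != BaseKind::MemRef)
    return "base must be of memref type";
  // Pre-existing check (message invented; the real pattern returns failure()).
  if (!isZeroBasedContiguous(indices))
    return "indices are not a zero-based contiguous sequence";
  return std::nullopt;
}
```

Placing the type check first means a tensor-based gather is rejected even
when its indices are perfectly contiguous, which is exactly the case the
new no_fold_contiguous_gather_tensor test exercises.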

Co-authored-by: Sagar Kulkarni <sagar at rain.ai>

Added: 
    

Modified: 
    mlir/lib/Dialect/Vector/IR/VectorOps.cpp
    mlir/test/Dialect/Vector/canonicalize.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 5324e38fa7d25..fdbdc72c057af 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -5348,6 +5348,9 @@ class FoldContiguousGather final : public OpRewritePattern<GatherOp> {
   using OpRewritePattern::OpRewritePattern;
   LogicalResult matchAndRewrite(GatherOp op,
                                 PatternRewriter &rewriter) const override {
+    if (!op.getBase().getType().isa<MemRefType>())
+      return rewriter.notifyMatchFailure(op, "base must be of memref type");
+
     if (failed(isZeroBasedContiguousSeq(op.getIndexVec())))
       return failure();
 

diff  --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir
index a6d82b85777b0..78b0ea78849e8 100644
--- a/mlir/test/Dialect/Vector/canonicalize.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize.mlir
@@ -3198,6 +3198,19 @@ func.func @contiguous_gather_step(%base: memref<?xf32>,
 
 // -----
 
+// CHECK-LABEL: @no_fold_contiguous_gather_tensor
+func.func @no_fold_contiguous_gather_tensor(%base: tensor<8xf32>, %mask: vector<4xi1>, %pass_thru: vector<4xf32>) -> vector<4xf32> {
+  %c0 = arith.constant 0 : index
+  %indices = arith.constant dense<[0, 1, 2, 3]> : vector<4xindex>
+  // CHECK: vector.gather
+  // CHECK-NOT: vector.maskedload
+  %0 = vector.gather %base[%c0][%indices], %mask, %pass_thru :
+    tensor<8xf32>, vector<4xindex>, vector<4xi1>, vector<4xf32> into vector<4xf32>
+  return %0 : vector<4xf32>
+}
+
+// -----
+
 // CHECK-LABEL: @gather_broadcast(
 // TODO: Broadcast is not supported yet
 //       CHECK:   %[[R:.*]] = vector.gather
