[Mlir-commits] [mlir] [mlir][vector] Prevent folding non memref-type gather into maskedload (PR #135371)

llvmlistbot at llvm.org
Fri Apr 11 07:06:57 PDT 2025


llvmbot wrote:



@llvm/pr-subscribers-mlir-vector

Author: Sagar Kulkarni (sagarkulkarni19)

<details>
<summary>Changes</summary>

This patch fixes an issue in the FoldContiguousGather pattern, which folded vector.gather operations with contiguous indices into vector.maskedload operations regardless of the base operand type.

While vector.gather operations accept both tensor and memref bases, vector.maskedload operations are only valid for memref bases. The pattern therefore rewrote a tensor-based gather into a vector.maskedload, which produces invalid IR.

This fix adds a type check to ensure the pattern only applies to memref-based gather operations.
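
The guard logic can be illustrated with a small standalone model. This is only a sketch: `BaseKind`, `ToyGather`, `isZeroBasedContiguous`, and `canFoldToMaskedLoad` are hypothetical names invented for illustration and are not MLIR APIs, and the contiguity check is a guess at the intent of `isZeroBasedContiguousSeq` based on its name.

```cpp
#include <cstddef>
#include <cstdint>
#include <vector>

// Toy stand-ins for illustration only; these are NOT MLIR types.
enum class BaseKind { MemRef, Tensor };

struct ToyGather {
  BaseKind base;                 // kind of the gather's base operand
  std::vector<int64_t> indices;  // constant index vector
};

// Assumed meaning of "zero-based contiguous": indices are exactly 0, 1, 2, ...
bool isZeroBasedContiguous(const std::vector<int64_t> &indices) {
  for (std::size_t i = 0; i < indices.size(); ++i)
    if (indices[i] != static_cast<int64_t>(i))
      return false;
  return true;
}

// Returns true when the toy gather may be rewritten as a masked load.
bool canFoldToMaskedLoad(const ToyGather &op) {
  // The check added by the patch: bail out early for non-memref bases,
  // because a masked load is only valid on a memref.
  if (op.base != BaseKind::MemRef)
    return false;
  // The pre-existing condition: only contiguous indices are foldable.
  return isZeroBasedContiguous(op.indices);
}
```

In this model a memref base with indices `[0, 1, 2, 3]` is foldable, while the same indices on a tensor base are rejected before the contiguity check runs, which mirrors the case covered by the new test.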

---
Full diff: https://github.com/llvm/llvm-project/pull/135371.diff


2 Files Affected:

- (modified) mlir/lib/Dialect/Vector/IR/VectorOps.cpp (+3) 
- (modified) mlir/test/Dialect/Vector/canonicalize.mlir (+12) 


``````````diff
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 98d98f067de14..8955438b57343 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -5340,6 +5340,9 @@ class FoldContiguousGather final : public OpRewritePattern<GatherOp> {
   using OpRewritePattern::OpRewritePattern;
   LogicalResult matchAndRewrite(GatherOp op,
                                 PatternRewriter &rewriter) const override {
+    if (!op.getBase().getType().isa<MemRefType>())
+      return failure();
+
     if (failed(isZeroBasedContiguousSeq(op.getIndexVec())))
       return failure();
 
diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir
index b7db8ec834be7..7d9223696712d 100644
--- a/mlir/test/Dialect/Vector/canonicalize.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize.mlir
@@ -3149,6 +3149,18 @@ func.func @contiguous_gather_step(%base: memref<?xf32>,
 
 // -----
 
+// CHECK-LABEL: @dont_fold_tensor_type_contiguous_gather
+func.func @dont_fold_tensor_type_contiguous_gather(%base: tensor<8xf32>, %mask: vector<4xi1>, %pass_thru: vector<4xf32>) -> vector<4xf32> {
+  %c0 = arith.constant 0 : index
+  %indices = arith.constant dense<[0, 1, 2, 3]> : vector<4xindex>
+  // CHECK: vector.gather
+  // CHECK-NOT: vector.maskedload
+  %0 = vector.gather %base[%c0][%indices], %mask, %pass_thru : tensor<8xf32>, vector<4xindex>, vector<4xi1>, vector<4xf32> into vector<4xf32>
+  return %0 : vector<4xf32>
+}
+
+// -----
+
 // CHECK-LABEL: @gather_broadcast(
 // TODO: Broadcast is not supported yet
 //       CHECK:   %[[R:.*]] = vector.gather

``````````

</details>


https://github.com/llvm/llvm-project/pull/135371

