[Mlir-commits] [mlir] [MLIR][Vector] Allow any strided memref for one-element vector.load in lowering vector.gather (PR #122437)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Fri Jan 10 02:02:49 PST 2025


https://github.com/PragmaTwice created https://github.com/llvm/llvm-project/pull/122437

`Gather1DToConditionalLoads` currently checks whether the most minor dimension of the input memref has unit stride (stride 1). However, according to the verifier of `vector.load` here:
https://github.com/llvm/llvm-project/blob/4e32271e8b304eb018c69f74c16edd1668fcdaf3/mlir/lib/Dialect/Vector/IR/VectorOps.cpp#L4971-L4975

if the result vector type of `vector.load` has only one element, the unit-stride requirement on the input memref does not apply.

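The lowering of `vector.gather` can therefore accept more cases: a memref whose most minor stride is not 1 is no longer rejected when the gather result has exactly one element.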
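To make the before/after precondition concrete, here is a minimal C++ sketch that models only the `StridedLayoutAttr` branch of the check, using plain stride lists instead of MLIR types. All names (`Strides`, `hasUnitInnermostStride`, `canLowerBefore`, `canLowerAfter`, `checkExamples`) are hypothetical and not part of MLIR. `numResultElements` plays the role of `resultTy.getNumElements()` in the diff, and a dynamic stride is modeled as `std::nullopt`.

```cpp
#include <cassert>
#include <cstdint>
#include <optional>
#include <vector>

// Hypothetical model of a memref's strided layout: one stride per dimension,
// with std::nullopt standing in for a dynamic (compile-time unknown) stride.
using Strides = std::vector<std::optional<int64_t>>;

// True only if the most minor (last) stride is statically known to be 1.
// In this model, a dynamic stride or an empty stride list counts as non-unit.
bool hasUnitInnermostStride(const Strides &strides) {
  return !strides.empty() && strides.back() == int64_t{1};
}

// Precondition of the lowering before this patch: unit innermost stride only.
bool canLowerBefore(const Strides &strides) {
  return hasUnitInnermostStride(strides);
}

// Precondition after this patch: a non-unit innermost stride is also
// accepted when the gather result has exactly one element.
bool canLowerAfter(const Strides &strides, int64_t numResultElements) {
  return hasUnitInnermostStride(strides) || numResultElements == 1;
}

// Example cases with hypothetical layouts and result sizes.
void checkExamples() {
  // Row-major 4x8 buffer (strides {8, 1}): accepted before and after.
  assert(canLowerBefore({8, 1}) && canLowerAfter({8, 1}, 4));
  // Innermost stride 2 with a one-element result: newly accepted.
  assert(!canLowerBefore({16, 2}) && canLowerAfter({16, 2}, 1));
  // Innermost stride 2 with a four-element result: still rejected.
  assert(!canLowerAfter({16, 2}, 4));
  // Dynamic innermost stride with a one-element result: newly accepted.
  assert(!canLowerBefore({16, std::nullopt}) &&
         canLowerAfter({16, std::nullopt}, 1));
}
```

Treating a dynamic stride as non-unit mirrors comparing the stride against a static 1, assuming (as in MLIR's strided layouts) that dynamic strides are encoded with a sentinel value that never equals 1.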

Test cases will be added soon. :)



From e8f3a28d7b4a6f8b39c0b291bce77217c45b4389 Mon Sep 17 00:00:00 2001
From: PragmaTwice <twice at apache.org>
Date: Fri, 10 Jan 2025 17:53:04 +0800
Subject: [PATCH] [MLIR][Vector] Allow strided memref for one-element
 vector.load in lowering vector.gather

Signed-off-by: PragmaTwice <twice at apache.org>
---
 mlir/lib/Dialect/Vector/Transforms/LowerVectorGather.cpp | 5 +++--
 1 file changed, 3 insertions(+), 2 deletions(-)

diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorGather.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorGather.cpp
index f1a5aa7664d2f3..4a8ad4eafdefd3 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorGather.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorGather.cpp
@@ -205,11 +205,12 @@ struct Gather1DToConditionalLoads : OpRewritePattern<vector::GatherOp> {
     Value condMask = op.getMask();
     Value base = op.getBase();
 
-    // vector.load requires the most minor memref dim to have unit stride
+    // vector.load requires the most minor memref dim to have unit stride,
+    // or the result vector type to have only one element
     if (auto memType = dyn_cast<MemRefType>(base.getType())) {
       if (auto stridesAttr =
               dyn_cast_if_present<StridedLayoutAttr>(memType.getLayout())) {
-        if (stridesAttr.getStrides().back() != 1)
+        if (stridesAttr.getStrides().back() != 1 && resultTy.getNumElements() != 1)
           return failure();
       }
     }


