[Mlir-commits] [mlir] [MLIR][Vector] Allow any strided memref for one-element vector.load in lowering vector.gather (PR #122437)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Fri Jan 10 03:53:16 PST 2025


https://github.com/PragmaTwice updated https://github.com/llvm/llvm-project/pull/122437

>From c9853d621fee639df4e8b72e695f75caca0d0096 Mon Sep 17 00:00:00 2001
From: PragmaTwice <twice at apache.org>
Date: Fri, 10 Jan 2025 17:53:04 +0800
Subject: [PATCH] [MLIR][Vector] Allow strided memref for one-element
 vector.load in lowering vector.gather

Signed-off-by: PragmaTwice <twice at apache.org>
---
 mlir/lib/Dialect/Vector/Transforms/LowerVectorGather.cpp | 6 ++++--
 1 file changed, 4 insertions(+), 2 deletions(-)

diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorGather.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorGather.cpp
index f1a5aa7664d2f3..4aff565b81b453 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorGather.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorGather.cpp
@@ -205,11 +205,13 @@ struct Gather1DToConditionalLoads : OpRewritePattern<vector::GatherOp> {
     Value condMask = op.getMask();
     Value base = op.getBase();
 
-    // vector.load requires the most minor memref dim to have unit stride
+    // vector.load requires the most minor memref dim to have unit stride,
+    // or the result vector type to have only one element
     if (auto memType = dyn_cast<MemRefType>(base.getType())) {
       if (auto stridesAttr =
               dyn_cast_if_present<StridedLayoutAttr>(memType.getLayout())) {
-        if (stridesAttr.getStrides().back() != 1)
+        if (stridesAttr.getStrides().back() != 1 &&
+            resultTy.getNumElements() != 1)
           return failure();
       }
     }

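The guard in the hunk above can be restated outside of MLIR. The sketch below is a plain C++ model, not MLIR code, and it makes some assumptions. `guardAllowsVectorLoad`, `contiguousOffsets` and `stridedOffsets` are hypothetical helpers. `std::nullopt` stands for a memref whose layout is not a `StridedLayoutAttr`, where the pattern skips the check. `numElements` stands for `resultTy.getNumElements()`. The two offset helpers illustrate the intuition behind the exception: a one-lane load touches a single address, so the innermost stride cannot change which element is read.

```cpp
#include <cassert>
#include <cstdint>
#include <optional>
#include <vector>

// Hypothetical model of the guard in Gather1DToConditionalLoads after this
// patch. `innermostStride` is std::nullopt when the memref layout is not a
// StridedLayoutAttr (the pattern then does not bail out on strides);
// `numElements` plays the role of resultTy.getNumElements().
bool guardAllowsVectorLoad(std::optional<int64_t> innermostStride,
                           int64_t numElements) {
  if (!innermostStride)
    return true;
  // Reject only non-unit innermost strides with multi-element results.
  return *innermostStride == 1 || numElements == 1;
}

// Hypothetical helper: linear offsets (in elements) that a load assuming
// contiguous storage reads for `n` lanes starting at `base`.
std::vector<int64_t> contiguousOffsets(int64_t base, int64_t n) {
  assert(n >= 0 && "lane count must be non-negative");
  std::vector<int64_t> out;
  for (int64_t i = 0; i < n; ++i)
    out.push_back(base + i);
  return out;
}

// Hypothetical helper: linear offsets of the `n` logical elements starting
// at `base` along a dimension with the given stride.
std::vector<int64_t> stridedOffsets(int64_t base, int64_t stride, int64_t n) {
  assert(n >= 0 && "lane count must be non-negative");
  std::vector<int64_t> out;
  for (int64_t i = 0; i < n; ++i)
    out.push_back(base + i * stride);
  return out;
}
```

For `n == 1` both offset helpers return `{base}` regardless of the stride, whereas for `n > 1` and a stride other than 1 they diverge, which mirrors why the guard only needs to reject the multi-element case.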