[Mlir-commits] [mlir] [MLIR][Vector] Allow any strided memref for one-element vector.load in lowering vector.gather (PR #122437)
llvmlistbot at llvm.org
Fri Jan 10 02:02:49 PST 2025
https://github.com/PragmaTwice created https://github.com/llvm/llvm-project/pull/122437
In `Gather1DToConditionalLoads`, we currently check that the stride of the most minor dim of the input memref is 1. However, according to the `vector.load` verifier here:
https://github.com/llvm/llvm-project/blob/4e32271e8b304eb018c69f74c16edd1668fcdaf3/mlir/lib/Dialect/Vector/IR/VectorOps.cpp#L4971-L4975
... if the result vector type of `vector.load` contains only one element, the stride requirement on the input memref does not apply.
So we can allow more cases when lowering `vector.gather`.
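As a rough illustration of the relaxed precondition, here is a minimal standalone sketch. It does not use MLIR; the helper name `strideAllowsVectorLoad` and the modelling of strides as a plain `std::vector` (with an empty list standing in for "no strided layout attribute") are assumptions made for this example only.

```cpp
#include <cstdint>
#include <vector>

// Hypothetical helper, not an MLIR API: models the stride check in the patch.
// `strides` holds the memref's per-dimension strides, most minor dim last.
// `numResultElements` stands in for resultTy.getNumElements().
bool strideAllowsVectorLoad(const std::vector<int64_t> &strides,
                            int64_t numResultElements) {
  // Assumption of this sketch: an empty list means there is no strided
  // layout attribute, in which case the pattern performs no stride check.
  if (strides.empty())
    return true;
  // Accept a unit innermost stride, or any stride when the result vector
  // has exactly one element.
  return strides.back() == 1 || numResultElements == 1;
}
```

Under these assumptions, strides `{4, 2}` are rejected for an 8-element result but accepted for a 1-element result.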
Test cases will be added soon : )
From e8f3a28d7b4a6f8b39c0b291bce77217c45b4389 Mon Sep 17 00:00:00 2001
From: PragmaTwice <twice at apache.org>
Date: Fri, 10 Jan 2025 17:53:04 +0800
Subject: [PATCH] [MLIR][Vector] Allow strided memref for one-element
vector.load in lowering vector.gather
Signed-off-by: PragmaTwice <twice at apache.org>
---
mlir/lib/Dialect/Vector/Transforms/LowerVectorGather.cpp | 5 +++--
1 file changed, 3 insertions(+), 2 deletions(-)
diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorGather.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorGather.cpp
index f1a5aa7664d2f3..4a8ad4eafdefd3 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorGather.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorGather.cpp
@@ -205,11 +205,12 @@ struct Gather1DToConditionalLoads : OpRewritePattern<vector::GatherOp> {
Value condMask = op.getMask();
Value base = op.getBase();
- // vector.load requires the most minor memref dim to have unit stride
+ // vector.load requires the most minor memref dim to have unit stride,
+ // or the result vector type to have only one element
if (auto memType = dyn_cast<MemRefType>(base.getType())) {
if (auto stridesAttr =
dyn_cast_if_present<StridedLayoutAttr>(memType.getLayout())) {
- if (stridesAttr.getStrides().back() != 1)
+ if (stridesAttr.getStrides().back() != 1 && resultTy.getNumElements() != 1)
return failure();
}
}