[Mlir-commits] [mlir] [MLIR][Vector] Allow any strided memref for one-element vector.load in lowering vector.gather (PR #122437)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Fri Jan 10 11:27:46 PST 2025


https://github.com/PragmaTwice updated https://github.com/llvm/llvm-project/pull/122437

>From c9853d621fee639df4e8b72e695f75caca0d0096 Mon Sep 17 00:00:00 2001
From: PragmaTwice <twice at apache.org>
Date: Fri, 10 Jan 2025 17:53:04 +0800
Subject: [PATCH 1/4] [MLIR][Vector] Allow strided memref for one-element
 vector.load in lowering vector.gather

Signed-off-by: PragmaTwice <twice at apache.org>
---
 mlir/lib/Dialect/Vector/Transforms/LowerVectorGather.cpp | 6 ++++--
 1 file changed, 4 insertions(+), 2 deletions(-)

diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorGather.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorGather.cpp
index f1a5aa7664d2f3..4aff565b81b453 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorGather.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorGather.cpp
@@ -205,11 +205,13 @@ struct Gather1DToConditionalLoads : OpRewritePattern<vector::GatherOp> {
     Value condMask = op.getMask();
     Value base = op.getBase();
 
-    // vector.load requires the most minor memref dim to have unit stride
+    // vector.load requires the most minor memref dim to have unit stride,
+    // or the result vector type to have only one element
     if (auto memType = dyn_cast<MemRefType>(base.getType())) {
       if (auto stridesAttr =
               dyn_cast_if_present<StridedLayoutAttr>(memType.getLayout())) {
-        if (stridesAttr.getStrides().back() != 1)
+        if (stridesAttr.getStrides().back() != 1 &&
+            resultTy.getNumElements() != 1)
           return failure();
       }
     }

>From 1933fbad58302814ccce5991a9320c0967f3571b Mon Sep 17 00:00:00 2001
From: PragmaTwice <twice at apache.org>
Date: Fri, 10 Jan 2025 21:49:16 +0800
Subject: [PATCH 2/4] add test case

---
 .../Dialect/Vector/vector-gather-lowering.mlir | 18 ++++++++++++++++++
 1 file changed, 18 insertions(+)

diff --git a/mlir/test/Dialect/Vector/vector-gather-lowering.mlir b/mlir/test/Dialect/Vector/vector-gather-lowering.mlir
index 5ad3a23e0ba15c..5d7aff6f8762ad 100644
--- a/mlir/test/Dialect/Vector/vector-gather-lowering.mlir
+++ b/mlir/test/Dialect/Vector/vector-gather-lowering.mlir
@@ -136,6 +136,24 @@ func.func @gather_tensor_1d(%base: tensor<?xf32>, %v: vector<2xindex>, %mask: ve
   return %0 : vector<2xf32>
 }
 
+// CHECK-LABEL: @gather_strided_memref_1d
+// CHECK: %[[MASK:.*]] = vector.extract %arg2[0] : i1 from vector<1xi1>
+// CHECK: %1 = vector.extract %arg1[0] : index from vector<1xindex>
+// CHECK: %[[RET:.*]] = scf.if %[[MASK]] -> (vector<1xf32>) {
+// CHECK:   %[[VEC:.*]] = vector.load %arg0[%1] : memref<4xf32, strided<[2]>>, vector<1xf32>
+// CHECK:   %[[VAL:.*]] = vector.extract %[[VEC]][0] : f32 from vector<1xf32>
+// CHECK:   %[[RES:.*]] = vector.insert %[[VAL]], %arg3 [0] : f32 into vector<1xf32>
+// CHECK:   scf.yield %[[RES]] : vector<1xf32>
+// CHECK: } else {
+// CHECK:    scf.yield %arg3 : vector<1xf32>
+// CHECK: }
+// CHECK: return %[[RET]] : vector<1xf32>
+func.func @gather_strided_memref_1d(%base: memref<4xf32, strided<[2]>>, %v: vector<1xindex>, %mask: vector<1xi1>, %pass_thru: vector<1xf32>) -> vector<1xf32> {
+  %c0 = arith.constant 0 : index
+  %0 = vector.gather %base[%c0][%v], %mask, %pass_thru : memref<4xf32, strided<[2]>>, vector<1xindex>, vector<1xi1>, vector<1xf32> into vector<1xf32>
+  return %0 : vector<1xf32>
+}
+
 // CHECK-LABEL: @gather_tensor_2d
 // CHECK:  scf.if
 // CHECK:    tensor.extract

>From 70f4b42c805b8a052141b0863a1c5134ebbffa0b Mon Sep 17 00:00:00 2001
From: Twice <twice at apache.org>
Date: Sat, 11 Jan 2025 03:21:21 +0800
Subject: [PATCH 3/4] Apply suggestions from code review
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Co-authored-by: Andrzej Warzyński <andrzej.warzynski at gmail.com>
---
 mlir/lib/Dialect/Vector/Transforms/LowerVectorGather.cpp | 4 ++--
 mlir/test/Dialect/Vector/vector-gather-lowering.mlir     | 2 +-
 2 files changed, 3 insertions(+), 3 deletions(-)

diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorGather.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorGather.cpp
index 4aff565b81b453..3b38505becd188 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorGather.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorGather.cpp
@@ -205,8 +205,8 @@ struct Gather1DToConditionalLoads : OpRewritePattern<vector::GatherOp> {
     Value condMask = op.getMask();
     Value base = op.getBase();
 
-    // vector.load requires the most minor memref dim to have unit stride,
-    // or the result vector type to have only one element
+    // vector.load requires the most minor memref dim to have unit stride
+    // (unless reading exactly 1 element)
     if (auto memType = dyn_cast<MemRefType>(base.getType())) {
       if (auto stridesAttr =
               dyn_cast_if_present<StridedLayoutAttr>(memType.getLayout())) {
diff --git a/mlir/test/Dialect/Vector/vector-gather-lowering.mlir b/mlir/test/Dialect/Vector/vector-gather-lowering.mlir
index 5d7aff6f8762ad..217cce54ce7c61 100644
--- a/mlir/test/Dialect/Vector/vector-gather-lowering.mlir
+++ b/mlir/test/Dialect/Vector/vector-gather-lowering.mlir
@@ -148,7 +148,7 @@ func.func @gather_tensor_1d(%base: tensor<?xf32>, %v: vector<2xindex>, %mask: ve
 // CHECK:    scf.yield %arg3 : vector<1xf32>
 // CHECK: }
 // CHECK: return %[[RET]] : vector<1xf32>
-func.func @gather_strided_memref_1d(%base: memref<4xf32, strided<[2]>>, %v: vector<1xindex>, %mask: vector<1xi1>, %pass_thru: vector<1xf32>) -> vector<1xf32> {
+func.func @gather_memref_non_unit_stride_read_1_element(%base: memref<4xf32, strided<[2]>>, %v: vector<1xindex>, %mask: vector<1xi1>, %pass_thru: vector<1xf32>) -> vector<1xf32> {
   %c0 = arith.constant 0 : index
   %0 = vector.gather %base[%c0][%v], %mask, %pass_thru : memref<4xf32, strided<[2]>>, vector<1xindex>, vector<1xi1>, vector<1xf32> into vector<1xf32>
   return %0 : vector<1xf32>

>From 6a76046075091e177431f70b168fec61c155bf99 Mon Sep 17 00:00:00 2001
From: PragmaTwice <twice at apache.org>
Date: Sat, 11 Jan 2025 03:27:32 +0800
Subject: [PATCH 4/4] fix test case and add a negative case

---
 .../Dialect/Vector/vector-gather-lowering.mlir   | 16 +++++++++++++---
 1 file changed, 13 insertions(+), 3 deletions(-)

diff --git a/mlir/test/Dialect/Vector/vector-gather-lowering.mlir b/mlir/test/Dialect/Vector/vector-gather-lowering.mlir
index 217cce54ce7c61..20e9400ed698d4 100644
--- a/mlir/test/Dialect/Vector/vector-gather-lowering.mlir
+++ b/mlir/test/Dialect/Vector/vector-gather-lowering.mlir
@@ -136,11 +136,11 @@ func.func @gather_tensor_1d(%base: tensor<?xf32>, %v: vector<2xindex>, %mask: ve
   return %0 : vector<2xf32>
 }
 
-// CHECK-LABEL: @gather_strided_memref_1d
+// CHECK-LABEL: @gather_memref_non_unit_stride_read_1_element
 // CHECK: %[[MASK:.*]] = vector.extract %arg2[0] : i1 from vector<1xi1>
-// CHECK: %1 = vector.extract %arg1[0] : index from vector<1xindex>
+// CHECK: %[[IDX:.*]] = vector.extract %arg1[0] : index from vector<1xindex>
 // CHECK: %[[RET:.*]] = scf.if %[[MASK]] -> (vector<1xf32>) {
-// CHECK:   %[[VEC:.*]] = vector.load %arg0[%1] : memref<4xf32, strided<[2]>>, vector<1xf32>
+// CHECK:   %[[VEC:.*]] = vector.load %arg0[%[[IDX]]] : memref<4xf32, strided<[2]>>, vector<1xf32>
 // CHECK:   %[[VAL:.*]] = vector.extract %[[VEC]][0] : f32 from vector<1xf32>
 // CHECK:   %[[RES:.*]] = vector.insert %[[VAL]], %arg3 [0] : f32 into vector<1xf32>
 // CHECK:   scf.yield %[[RES]] : vector<1xf32>
@@ -154,6 +154,16 @@ func.func @gather_memref_non_unit_stride_read_1_element(%base: memref<4xf32, str
   return %0 : vector<1xf32>
 }
 
+// CHECK-LABEL: @gather_memref_non_unit_stride_read_more_than_1_element
+// CHECK: %[[CONST:.*]] = arith.constant 0 : index
+// CHECK: %[[RET:.*]] = vector.gather %arg0[%[[CONST]]] [%arg1], %arg2, %arg3 : memref<4xf32, strided<[2]>>, vector<2xindex>, vector<2xi1>, vector<2xf32> into vector<2xf32>
+// CHECK: return %[[RET]] : vector<2xf32>
+func.func @gather_memref_non_unit_stride_read_more_than_1_element(%base: memref<4xf32, strided<[2]>>, %v: vector<2xindex>, %mask: vector<2xi1>, %pass_thru: vector<2xf32>) -> vector<2xf32> {
+  %c0 = arith.constant 0 : index
+  %0 = vector.gather %base[%c0][%v], %mask, %pass_thru : memref<4xf32, strided<[2]>>, vector<2xindex>, vector<2xi1>, vector<2xf32> into vector<2xf32>
+  return %0 : vector<2xf32>
+}
+
 // CHECK-LABEL: @gather_tensor_2d
 // CHECK:  scf.if
 // CHECK:    tensor.extract



More information about the Mlir-commits mailing list