[Mlir-commits] [mlir] [MLIR][Vector] Allow any strided memref for one-element vector.load in lowering vector.gather (PR #122437)
Andrzej Warzyński
llvmlistbot at llvm.org
Fri Jan 10 08:51:35 PST 2025
================
@@ -136,6 +136,24 @@ func.func @gather_tensor_1d(%base: tensor<?xf32>, %v: vector<2xindex>, %mask: ve
return %0 : vector<2xf32>
}
+// CHECK-LABEL: @gather_strided_memref_1d
+// CHECK: %[[MASK:.*]] = vector.extract %arg2[0] : i1 from vector<1xi1>
+// CHECK: %1 = vector.extract %arg1[0] : index from vector<1xindex>
+// CHECK: %[[RET:.*]] = scf.if %[[MASK]] -> (vector<1xf32>) {
+// CHECK: %[[VEC:.*]] = vector.load %arg0[%1] : memref<4xf32, strided<[2]>>, vector<1xf32>
+// CHECK: %[[VAL:.*]] = vector.extract %[[VEC]][0] : f32 from vector<1xf32>
+// CHECK: %[[RES:.*]] = vector.insert %[[VAL]], %arg3 [0] : f32 into vector<1xf32>
+// CHECK: scf.yield %[[RES]] : vector<1xf32>
+// CHECK: } else {
+// CHECK: scf.yield %arg3 : vector<1xf32>
+// CHECK: }
+// CHECK: return %[[RET]] : vector<1xf32>
+func.func @gather_strided_memref_1d(%base: memref<4xf32, strided<[2]>>, %v: vector<1xindex>, %mask: vector<1xi1>, %pass_thru: vector<1xf32>) -> vector<1xf32> {
----------------
banach-space wrote:
```suggestion
func.func @gather_memref_non_unit_stride_read_1_element(%base: memref<4xf32, strided<[2]>>, %v: vector<1xindex>, %mask: vector<1xi1>, %pass_thru: vector<1xf32>) -> vector<1xf32> {
```
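For context, here is a sketch of the kind of input the CHECK lines above describe. It is a one-element `vector.gather` from a memref whose innermost stride is 2, written with the renamed function from the suggestion. This is a hypothetical body for illustration only. The quoted hunk cuts off after the signature, so this is not the PR's actual test body. It also assumes the usual `vector.gather` form with a constant base offset, which would explain why the index in the CHECK lines is used directly.

```mlir
// Hypothetical test body (illustration only; not taken from the PR).
// A 1-element gather from a memref with a non-unit innermost stride.
func.func @gather_memref_non_unit_stride_read_1_element(%base: memref<4xf32, strided<[2]>>, %v: vector<1xindex>, %mask: vector<1xi1>, %pass_thru: vector<1xf32>) -> vector<1xf32> {
  %c0 = arith.constant 0 : index
  // Read %base[%c0 + %v[0]] if %mask[0] is set; otherwise the result
  // takes its value from %pass_thru.
  %0 = vector.gather %base[%c0][%v], %mask, %pass_thru
      : memref<4xf32, strided<[2]>>, vector<1xindex>, vector<1xi1>, vector<1xf32> into vector<1xf32>
  return %0 : vector<1xf32>
}
```

Only a single element is read, so the stride of the innermost dimension does not change which memory location the resulting `vector.load` touches. This is presumably why the PR title limits the relaxation to one-element loads.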
https://github.com/llvm/llvm-project/pull/122437