[Mlir-commits] [mlir] [mlir][vector] Fix vector.gather lowering for strided memrefs. (PR #184706)

Andrzej Warzyński llvmlistbot at llvm.org
Fri Mar 13 04:08:44 PDT 2026


================
@@ -289,3 +291,60 @@ func.func @scalable_gather_1d(%base: tensor<?xf32>, %v: vector<[2]xindex>, %mask
   %0 = vector.gather %base[%c0][%v], %mask, %pass_thru : tensor<?xf32>, vector<[2]xindex>, vector<[2]xi1>, vector<[2]xf32> into vector<[2]xf32>
   return %0 : vector<[2]xf32>
 }
+
+// Verify that gather on a strided 2D memref with zero base offsets
+// delinearizes the gather index directly (linearize and addi fold away).
+// CHECK-LABEL: @gather_strided_memref_2d
+// CHECK-SAME:    (%[[BASE:.+]]: memref<4x2xf32, strided<[4, 1]>>,
+// CHECK-SAME:     %[[IDXVEC:.+]]: vector<4xi32>,
+// CHECK-SAME:     %[[MASK:.+]]: vector<4xi1>,
+// CHECK-SAME:     %[[PASS:.+]]: vector<4xf32>)
+// CHECK-DAG:     %[[IDXS:.+]] = arith.index_cast %[[IDXVEC]]
+// CHECK-DAG:     %[[IDX0:.+]] = vector.extract %[[IDXS]][0]
+// CHECK:         %[[DL0:.+]]:2 = affine.delinearize_index %[[IDX0]] into (4, 2)
+// CHECK:         scf.if
+// CHECK:           vector.load %[[BASE]][%[[DL0]]#0, %[[DL0]]#1] : memref<4x2xf32, strided<[4, 1]>>, vector<1xf32>
+// CHECK:         affine.delinearize_index %{{.+}} into (4, 2)
+// CHECK:         vector.load %[[BASE]][%{{.+}}, %{{.+}}] : memref<4x2xf32, strided<[4, 1]>>, vector<1xf32>
+// CHECK:         affine.delinearize_index %{{.+}} into (4, 2)
+// CHECK:         vector.load %[[BASE]][%{{.+}}, %{{.+}}] : memref<4x2xf32, strided<[4, 1]>>, vector<1xf32>
+// CHECK:         affine.delinearize_index %{{.+}} into (4, 2)
+// CHECK:         vector.load %[[BASE]][%{{.+}}, %{{.+}}] : memref<4x2xf32, strided<[4, 1]>>, vector<1xf32>
----------------
banach-space wrote:

Missing `scf.if` (and other ops) is quite confusing to me. Could you reformat a bit?

```suggestion
/// idx == 0
// CHECK-DAG:     %[[IDX0:.+]] = vector.extract %[[IDXS]][0]
// CHECK:         %[[DL0:.+]]:2 = affine.delinearize_index %[[IDX0]] into (4, 2)
// CHECK:         scf.if
// CHECK:           vector.load %[[BASE]][%[[DL0]]#0, %[[DL0]]#1] : memref<4x2xf32, strided<[4, 1]>>, vector<1xf32>

/// idx == 1
// CHECK-DAG:     %[[IDX1:.+]] = vector.extract %[[IDXS]][1]
// CHECK:         %[[DL1:.+]]:2 = affine.delinearize_index %[[IDX1]] into (4, 2)
// CHECK:         scf.if
// CHECK:           vector.load %[[BASE]][%[[DL1]]#0, %[[DL1]]#1] : memref<4x2xf32, strided<[4, 1]>>, vector<1xf32>

/// idx == 2
// CHECK-DAG:     %[[IDX2:.+]] = vector.extract %[[IDXS]][2]
// CHECK:         %[[DL2:.+]]:2 = affine.delinearize_index %[[IDX2]] into (4, 2)
// CHECK:         scf.if
// CHECK:           vector.load %[[BASE]][%[[DL2]]#0, %[[DL2]]#1] : memref<4x2xf32, strided<[4, 1]>>, vector<1xf32>

/// idx == 3
// CHECK-DAG:     %[[IDX3:.+]] = vector.extract %[[IDXS]][3]
// CHECK:         %[[DL3:.+]]:2 = affine.delinearize_index %[[IDX3]] into (4, 2)
// CHECK:         scf.if
// CHECK:           vector.load %[[BASE]][%[[DL3]]#0, %[[DL3]]#1] : memref<4x2xf32, strided<[4, 1]>>, vector<1xf32>
```
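The CHECK lines encode a specific index mapping. As a hedged illustration, here is a Python reference model that is not part of the PR; `delinearize`, `physical_offset` and `gather` are hypothetical names. It mimics the unrolled lowering lane by lane:

1. Take the lane's index and delinearize it into (4, 2).
2. Load from the strided buffer only if the mask bit is set.

It assumes zero base offsets, as in the test. It also assumes that each gather index is a 1-D offset into the logical row-major shape of the base.

```python
# Hypothetical reference model (not MLIR code): emulates the per-lane
# vector.gather lowering checked above for memref<4x2xf32, strided<[4, 1]>>
# with zero base offsets.

SHAPE = (4, 2)    # logical memref shape
STRIDES = (4, 1)  # physical strides: rows are 4 elements apart, not 2


def delinearize(idx, shape):
    """Row-major linear index -> coordinates, like affine.delinearize_index."""
    coords = []
    for dim in reversed(shape[1:]):
        coords.append(idx % dim)
        idx //= dim
    coords.append(idx)  # outermost coordinate is not reduced modulo its bound
    return tuple(reversed(coords))


def physical_offset(coords, strides):
    """Offset into the underlying buffer for a multi-dimensional index."""
    return sum(c * s for c, s in zip(coords, strides))


def gather(buffer, indices, mask, pass_thru):
    result = list(pass_thru)
    for lane, (idx, enabled) in enumerate(zip(indices, mask)):
        coords = delinearize(idx, SHAPE)  # affine.delinearize_index ... into (4, 2)
        if enabled:                       # scf.if on the lane's mask bit
            # vector.load of a single element at the delinearized coordinates
            result[lane] = buffer[physical_offset(coords, STRIDES)]
    return result


if __name__ == "__main__":
    # 14 elements cover the largest offset 3*4 + 1*1 = 13; value == offset.
    buffer = [float(i) for i in range(14)]
    print(gather(buffer, [0, 1, 2, 3], [True, True, False, True], [-1.0] * 4))
    # prints [0.0, 1.0, -1.0, 5.0]
```

With strides [4, 1], index 3 maps to coordinates (1, 1) and therefore to physical offset 5. A lowering that treated the memref as contiguous would read `buffer[3]` instead, which falls in the gap between rows.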

https://github.com/llvm/llvm-project/pull/184706

