[Mlir-commits] [mlir] [mlir][Vector] Support `vector.extract(xfer_read)` folding with dynamic indices (PR #143269)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Sat Jun 7 07:29:28 PDT 2025


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-vector

Author: Diego Caballero (dcaballe)

<details>
<summary>Changes</summary>

This PR is part of the last step to remove the `vector.extractelement` and `vector.insertelement` ops. It adds support for folding `vector.extract(vector.transfer_read) -> memref.load` with dynamic indices, a case that `vector.extractelement` already supports.

---
Full diff: https://github.com/llvm/llvm-project/pull/143269.diff


2 Files Affected:

- (modified) mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp (+19-5) 
- (modified) mlir/test/Dialect/Vector/scalar-vector-transfer-to-memref.mlir (+30) 


``````````diff
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
index 7dbb7a334fe62..36197eb1caeb1 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
@@ -886,12 +886,26 @@ class RewriteScalarExtractOfTransferRead
     SmallVector<Value> newIndices(xferOp.getIndices().begin(),
                                   xferOp.getIndices().end());
     for (auto [i, pos] : llvm::enumerate(extractOp.getMixedPosition())) {
-      assert(isa<Attribute>(pos) && "Unexpected non-constant index");
-      int64_t offset = cast<IntegerAttr>(cast<Attribute>(pos)).getInt();
       int64_t idx = newIndices.size() - extractOp.getNumIndices() + i;
-      OpFoldResult ofr = affine::makeComposedFoldedAffineApply(
-          rewriter, extractOp.getLoc(),
-          rewriter.getAffineSymbolExpr(0) + offset, {newIndices[idx]});
+
+      // Compute affine expression `newIndices[idx] + pos` where `pos` can be
+      // either a constant or a value.
+      OpFoldResult ofr;
+      if (auto attr = dyn_cast<Attribute>(pos)) {
+        int64_t offset = cast<IntegerAttr>(attr).getInt();
+        ofr = affine::makeComposedFoldedAffineApply(
+            rewriter, extractOp.getLoc(),
+            rewriter.getAffineSymbolExpr(0) + offset, {newIndices[idx]});
+      } else {
+        Value dynamicOffset = cast<Value>(pos);
+        AffineExpr sym0, sym1;
+        bindSymbols(rewriter.getContext(), sym0, sym1);
+        ofr = affine::makeComposedFoldedAffineApply(
+            rewriter, extractOp.getLoc(), sym0 + sym1,
+            {newIndices[idx], dynamicOffset});
+      }
+
+      // Update the corresponding index with the folded result.
       if (auto value = dyn_cast<Value>(ofr)) {
         newIndices[idx] = value;
       } else {
diff --git a/mlir/test/Dialect/Vector/scalar-vector-transfer-to-memref.mlir b/mlir/test/Dialect/Vector/scalar-vector-transfer-to-memref.mlir
index 52b0fdee184f6..9f10063a75092 100644
--- a/mlir/test/Dialect/Vector/scalar-vector-transfer-to-memref.mlir
+++ b/mlir/test/Dialect/Vector/scalar-vector-transfer-to-memref.mlir
@@ -148,3 +148,33 @@ func.func @subvector_extract(%m: memref<?x?xf32>, %idx: index) -> vector<16xf32>
   return %1 : vector<16xf32>
 }
 
+// -----
+
+//       CHECK: #[[$MAP:.*]] = affine_map<()[s0, s1] -> (s0 + s1)>
+// CHECK-LABEL: func @transfer_read_1d_extract_dynamic(
+//  CHECK-SAME:     %[[MEMREF:.*]]: memref<?xf32>, %[[M_IDX:.*]]: index, %[[E_IDX:.*]]: index
+//       CHECK:   %[[APPLY:.*]] = affine.apply #[[$MAP]]()[%[[M_IDX]], %[[E_IDX]]]
+//       CHECK:   %[[RES:.*]] = memref.load %[[MEMREF]][%[[APPLY]]]
+func.func @transfer_read_1d_extract_dynamic(%m: memref<?xf32>, %idx: index,
+                                            %offset: index) -> f32 {
+  %cst = arith.constant 0.0 : f32
+  %vec = vector.transfer_read %m[%idx], %cst {in_bounds = [true]} : memref<?xf32>, vector<5xf32>
+  %elem = vector.extract %vec[%offset] : f32 from vector<5xf32>
+  return %elem : f32
+}
+
+// -----
+
+//       CHECK: #[[$MAP:.*]] = affine_map<()[s0, s1] -> (s0 + s1)>
+// CHECK-LABEL: func @transfer_read_2d_extract_dynamic(
+//  CHECK-SAME:     %[[MEMREF:.*]]: memref<?x?xf32>, %[[M_IDX:.*]]: index, %[[ROW:.*]]: index, %[[COL:.*]]: index
+//       CHECK:   %[[ROW_APPLY:.*]] = affine.apply #[[$MAP]]()[%[[M_IDX]], %[[ROW]]]
+//       CHECK:   %[[COL_APPLY:.*]] = affine.apply #[[$MAP]]()[%[[M_IDX]], %[[COL]]]
+//       CHECK:   %[[RES:.*]] = memref.load %[[MEMREF]][%[[ROW_APPLY]], %[[COL_APPLY]]]
+func.func @transfer_read_2d_extract_dynamic(%m: memref<?x?xf32>, %idx: index,
+                                            %row_offset: index, %col_offset: index) -> f32 {
+  %cst = arith.constant 0.0 : f32
+  %vec = vector.transfer_read %m[%idx, %idx], %cst {in_bounds = [true, true]} : memref<?x?xf32>, vector<10x5xf32>
+  %elem = vector.extract %vec[%row_offset, %col_offset] : f32 from vector<10x5xf32>
+  return %elem : f32
+}

``````````

</details>


https://github.com/llvm/llvm-project/pull/143269


More information about the Mlir-commits mailing list