[Mlir-commits] [mlir] [mlir][vector] Sink vector.extract/splat into load/store ops (PR #134389)

Ivan Butygin llvmlistbot at llvm.org
Sat Apr 12 17:26:26 PDT 2025


https://github.com/Hardcode84 updated https://github.com/llvm/llvm-project/pull/134389

>From 9b2c084ebff3cad132be861dc29b251cf0380645 Mon Sep 17 00:00:00 2001
From: Ivan Butygin <ivan.butygin at gmail.com>
Date: Fri, 4 Apr 2025 15:39:48 +0200
Subject: [PATCH 1/4] [mlir][vector] Sink extract/splat into load/store ops

```
%0 = vector.load %arg0[%arg1] : memref<?xf32>, vector<4xf32>
%1 = vector.extract %0[1] : f32 from vector<4xf32>
```
Gets converted to:
```
%c1 = arith.constant 1 : index
%0 = arith.addi %arg1, %c1 overflow<nsw> : index
%1 = memref.load %arg0[%0] : memref<?xf32>
```

```
%0 = vector.splat %arg2 : vector<1xf32>
vector.store %0, %arg0[%arg1] : memref<?xf32>, vector<1xf32>
```
Gets converted to:
```
memref.store %arg2, %arg0[%arg1] : memref<?xf32>
```
---
 .../Vector/TransformOps/VectorTransformOps.td |  24 ++-
 .../Vector/Transforms/VectorRewritePatterns.h |   5 +
 .../TransformOps/VectorTransformOps.cpp       |   5 +
 .../Vector/Transforms/VectorTransforms.cpp    | 127 ++++++++++++
 .../Dialect/Vector/vector-sink-transform.mlir |   1 +
 mlir/test/Dialect/Vector/vector-sink.mlir     | 189 ++++++++++++++++++
 .../Dialect/Vector/TestVectorTransforms.cpp   |   1 +
 7 files changed, 350 insertions(+), 2 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td b/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td
index f46aa0428f12f..7fbb437908866 100644
--- a/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td
+++ b/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td
@@ -469,8 +469,41 @@ def ApplySinkVectorPatternsOp : Op<Transform_Dialect,
     %0 = arith.addf %a, %b : vector<4x2xf32>
     %r = vector.transpose %0, [1, 0] : vector<2x4xf32>
     ```
-    At the moment, these patterns are limited to vector.broadcast and
-    vector.transpose.
+    At the moment, these patterns are limited to vector.broadcast,
+    vector.transpose and vector.extract.
+  }];
+
+  let assemblyFormat = "attr-dict";
+}
+
+def ApplySinkVectorMemPatternsOp : Op<Transform_Dialect,
+    "apply_patterns.vector.sink_mem_ops",
+    [DeclareOpInterfaceMethods<PatternDescriptorOpInterface>]> {
+  let description = [{
+    Patterns that remove redundant Vector Ops by merging them with load/store
+    ops.
+
+    Example:
+    ```
+    %0 = vector.load %arg0[%arg1] : memref<?xf32>, vector<4xf32>
+    %1 = vector.extract %0[1] : f32 from vector<4xf32>
+    ```
+    Gets converted to:
+    ```
+    %c1 = arith.constant 1 : index
+    %0 = arith.addi %arg1, %c1 overflow<nsw> : index
+    %1 = memref.load %arg0[%0] : memref<?xf32>
+    ```
+
+    Example:
+    ```
+    %0 = vector.splat %arg2 : vector<1xf32>
+    vector.store %0, %arg0[%arg1] : memref<?xf32>, vector<1xf32>
+    ```
+    Gets converted to:
+    ```
+    memref.store %arg2, %arg0[%arg1] : memref<?xf32>
+    ```
   }];
 
   let assemblyFormat = "attr-dict";
diff --git a/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h b/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
index 7de4a6a315750..7d3134bdae233 100644
--- a/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
+++ b/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
@@ -161,6 +161,11 @@ void populateVectorTransferCollapseInnerMostContiguousDimsPatterns(
 void populateSinkVectorOpsPatterns(RewritePatternSet &patterns,
                                    PatternBenefit benefit = 1);
 
+/// Patterns that remove redundant Vector Ops by re-ordering them with
+/// memory Ops:
+void populateSinkVectorMemOpsPatterns(RewritePatternSet &patterns,
+                                      PatternBenefit benefit = 1);
+
 /// Patterns that fold chained vector reductions. These patterns assume that
 /// elementwise operations (e.g., `arith.addf` with vector operands) are
 /// cheaper than vector reduction.
diff --git a/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp b/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp
index 12dcf768dd928..a888d745be443 100644
--- a/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp
+++ b/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp
@@ -212,6 +212,11 @@ void transform::ApplySinkVectorPatternsOp::populatePatterns(
   vector::populateSinkVectorOpsPatterns(patterns);
 }
 
+void transform::ApplySinkVectorMemPatternsOp::populatePatterns(
+    RewritePatternSet &patterns) {
+  vector::populateSinkVectorMemOpsPatterns(patterns);
+}
+
 //===----------------------------------------------------------------------===//
 // Transform op registration
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
index d50d5fe96f49a..bbd16d330041d 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
@@ -1103,6 +1103,134 @@ class ExtractOpFromElementwise final
   }
 };
 
+/// Pattern to rewrite vector.extract(vector.load) -> vector/memref.load.
+/// ```
+///  vector.load %arg0[%arg1] : memref<?xf32>, vector<4xf32>
+///  vector.extract %0[1] : f32 from vector<4xf32>
+/// ```
+/// Gets converted to:
+/// ```
+/// %c1 = arith.constant 1 : index
+/// %0 = arith.addi %arg1, %c1 overflow<nsw> : index
+/// %1 = memref.load %arg0[%0] : memref<?xf32>
+/// ```
+class ExtractOpFromLoad final : public OpRewritePattern<vector::ExtractOp> {
+public:
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(vector::ExtractOp op,
+                                PatternRewriter &rewriter) const override {
+    auto loadOp = op.getVector().getDefiningOp<vector::LoadOp>();
+    if (!loadOp)
+      return rewriter.notifyMatchFailure(op, "not a load op");
+
+    if (!loadOp->hasOneUse())
+      return rewriter.notifyMatchFailure(op, "expected single op use");
+
+    VectorType memVecType = loadOp.getVectorType();
+    if (memVecType.isScalable())
+      return rewriter.notifyMatchFailure(op,
+                                         "scalable vectors are not supported");
+
+    MemRefType memType = loadOp.getMemRefType();
+    if (isa<VectorType>(memType.getElementType()))
+      return rewriter.notifyMatchFailure(
+          op, "memrefs of vectors are not supported");
+
+    int64_t rankOffset = memType.getRank() - memVecType.getRank();
+    if (rankOffset < 0)
+      return rewriter.notifyMatchFailure(op, "unsupported ranks combination");
+
+    // Extracting at the poison position yields poison. Folding that position
+    // into the load indices would turn it into a real address offset and may
+    // introduce an out-of-bounds access, so leave such extracts untouched.
+    if (llvm::is_contained(op.getStaticPosition(),
+                           vector::ExtractOp::kPoisonIndex))
+      return rewriter.notifyMatchFailure(op, "poison extract position");
+
+    auto resVecType = dyn_cast<VectorType>(op.getResult().getType());
+    int64_t finalRank = 0;
+    if (resVecType)
+      finalRank = resVecType.getRank();
+
+    SmallVector<Value> indices = loadOp.getIndices();
+    SmallVector<OpFoldResult> extractPos = op.getMixedPosition();
+
+    OpBuilder::InsertionGuard g(rewriter);
+    rewriter.setInsertionPoint(loadOp);
+    Location loc = loadOp.getLoc();
+    for (auto i : llvm::seq<int64_t>(rankOffset, indices.size() - finalRank)) {
+      OpFoldResult pos = extractPos[i - rankOffset];
+      if (isConstantIntValue(pos, 0))
+        continue;
+
+      Value offset = getValueOrCreateConstantIndexOp(rewriter, loc, pos);
+
+      auto ovf = arith::IntegerOverflowFlags::nsw;
+      indices[i] = rewriter.create<arith::AddIOp>(loc, indices[i], offset, ovf);
+    }
+
+    Value base = loadOp.getBase();
+    if (resVecType) {
+      rewriter.replaceOpWithNewOp<vector::LoadOp>(op, resVecType, base,
+                                                  indices);
+    } else {
+      rewriter.replaceOpWithNewOp<memref::LoadOp>(op, base, indices);
+    }
+    rewriter.eraseOp(loadOp);
+    return success();
+  }
+};
+
+/// Pattern to rewrite vector.store(vector.splat) -> vector/memref.store.
+/// ```
+/// %0 = vector.splat %arg2 : vector<1xf32>
+/// vector.store %0, %arg0[%arg1] : memref<?xf32>, vector<1xf32>
+/// ```
+/// Gets converted to:
+/// ```
+/// memref.store %arg2, %arg0[%arg1] : memref<?xf32>
+/// ```
+class StoreFromSplat final : public OpRewritePattern<vector::StoreOp> {
+public:
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(vector::StoreOp op,
+                                PatternRewriter &rewriter) const override {
+    VectorType vecType = op.getVectorType();
+    if (vecType.isScalable())
+      return rewriter.notifyMatchFailure(op,
+                                         "scalable vectors are not supported");
+
+    if (isa<VectorType>(op.getMemRefType().getElementType()))
+      return rewriter.notifyMatchFailure(
+          op, "memrefs of vectors are not supported");
+
+    if (vecType.getNumElements() != 1)
+      return rewriter.notifyMatchFailure(
+          op, "only 1-element vectors are supported");
+
+    Operation *splat = op.getValueToStore().getDefiningOp();
+    if (!isa_and_present<vector::BroadcastOp, vector::SplatOp>(splat))
+      return rewriter.notifyMatchFailure(op, "not a splat");
+
+    if (!splat->hasOneUse())
+      return rewriter.notifyMatchFailure(op, "expected single op use");
+
+    Value source = splat->getOperand(0);
+    Value base = op.getBase();
+    ValueRange indices = op.getIndices();
+
+    if (isa<VectorType>(source.getType())) {
+      rewriter.replaceOpWithNewOp<vector::StoreOp>(op, source, base, indices);
+    } else {
+      rewriter.replaceOpWithNewOp<memref::StoreOp>(op, source, base, indices);
+    }
+    rewriter.eraseOp(splat);
+    return success();
+  }
+};
+
 // Helper that returns a vector comparison that constructs a mask:
 //     mask = [0,1,..,n-1] + [o,o,..,o] < [b,b,..,b]
 //
@@ -2175,6 +2296,12 @@ void mlir::vector::populateSinkVectorOpsPatterns(RewritePatternSet &patterns,
       patterns.getContext(), benefit);
 }
 
+void mlir::vector::populateSinkVectorMemOpsPatterns(RewritePatternSet &patterns,
+                                                    PatternBenefit benefit) {
+  patterns.add<ExtractOpFromLoad, StoreFromSplat>(patterns.getContext(),
+                                                  benefit);
+}
+
 void mlir::vector::populateChainedVectorReductionFoldingPatterns(
     RewritePatternSet &patterns, PatternBenefit benefit) {
   patterns.add<ChainedReduction>(patterns.getContext(), benefit);
diff --git a/mlir/test/Dialect/Vector/vector-sink-transform.mlir b/mlir/test/Dialect/Vector/vector-sink-transform.mlir
index ef17b69b2444c..4d04276742164 100644
--- a/mlir/test/Dialect/Vector/vector-sink-transform.mlir
+++ b/mlir/test/Dialect/Vector/vector-sink-transform.mlir
@@ -7,6 +7,7 @@ module attributes {transform.with_named_sequence} {
     %func = transform.structured.match ops{["func.func"]} in %module_op : (!transform.any_op) -> !transform.any_op
     transform.apply_patterns to %func {
       transform.apply_patterns.vector.sink_ops
+      transform.apply_patterns.vector.sink_mem_ops
     } : !transform.any_op
     transform.yield
   }
diff --git a/mlir/test/Dialect/Vector/vector-sink.mlir b/mlir/test/Dialect/Vector/vector-sink.mlir
index 8c8f1797aaab6..ad4fdbe0a7b5a 100644
--- a/mlir/test/Dialect/Vector/vector-sink.mlir
+++ b/mlir/test/Dialect/Vector/vector-sink.mlir
@@ -513,3 +513,192 @@ func.func @negative_extract_vec_fma(%arg0: vector<4xf32>, %arg1: vector<4xf32>,
   %1 = vector.extract %0[1] : f32 from vector<4xf32>
   return %1 : f32
 }
+
+//-----------------------------------------------------------------------------
+// [Pattern: ExtractOpFromLoad]
+//-----------------------------------------------------------------------------
+
+// CHECK-LABEL: @extract_load_scalar
+//  CHECK-SAME:   (%[[ARG0:.*]]: memref<?xf32>, %[[ARG1:.*]]: index)
+func.func @extract_load_scalar(%arg0: memref<?xf32>, %arg1: index) -> f32 {
+// CHECK:   %[[RES:.*]] = memref.load %[[ARG0]][%[[ARG1]]] : memref<?xf32>
+// CHECK:   return %[[RES]] : f32
+  %0 = vector.load %arg0[%arg1] : memref<?xf32>, vector<4xf32>
+  %1 = vector.extract %0[0] : f32 from vector<4xf32>
+  return %1 : f32
+}
+
+// CHECK-LABEL: @extract_load_scalar_non_zero_off
+//  CHECK-SAME:   (%[[ARG0:.*]]: memref<?xf32>, %[[ARG1:.*]]: index)
+func.func @extract_load_scalar_non_zero_off(%arg0: memref<?xf32>, %arg1: index) -> f32 {
+// CHECK:   %[[C1:.*]] = arith.constant 1 : index
+// CHECK:   %[[OFF:.*]] = arith.addi %[[ARG1]], %[[C1]] overflow<nsw> : index
+// CHECK:   %[[RES:.*]] = memref.load %[[ARG0]][%[[OFF]]] : memref<?xf32>
+// CHECK:   return %[[RES]] : f32
+  %0 = vector.load %arg0[%arg1] : memref<?xf32>, vector<4xf32>
+  %1 = vector.extract %0[1] : f32 from vector<4xf32>
+  return %1 : f32
+}
+
+// CHECK-LABEL: @extract_load_scalar_dyn_off
+//  CHECK-SAME:   (%[[ARG0:.*]]: memref<?xf32>, %[[ARG1:.*]]: index, %[[ARG2:.*]]: index)
+func.func @extract_load_scalar_dyn_off(%arg0: memref<?xf32>, %arg1: index, %arg2: index) -> f32 {
+// CHECK:   %[[OFF:.*]] = arith.addi %[[ARG1]], %[[ARG2]] overflow<nsw> : index
+// CHECK:   %[[RES:.*]] = memref.load %[[ARG0]][%[[OFF]]] : memref<?xf32>
+// CHECK:   return %[[RES]] : f32
+  %0 = vector.load %arg0[%arg1] : memref<?xf32>, vector<4xf32>
+  %1 = vector.extract %0[%arg2] : f32 from vector<4xf32>
+  return %1 : f32
+}
+
+// CHECK-LABEL: @extract_load_vec
+//  CHECK-SAME:   (%[[ARG0:.*]]: memref<?x?xf32>, %[[ARG1:.*]]: index, %[[ARG2:.*]]: index)
+func.func @extract_load_vec(%arg0: memref<?x?xf32>, %arg1: index, %arg2: index) -> vector<4xf32> {
+// CHECK:   %[[C1:.*]] = arith.constant 1 : index
+// CHECK:   %[[OFF:.*]] = arith.addi %[[ARG1]], %[[C1]] overflow<nsw> : index
+// CHECK:   %[[RES:.*]] = vector.load %[[ARG0]][%[[OFF]], %[[ARG2]]] : memref<?x?xf32>, vector<4xf32>
+// CHECK:   return %[[RES]] : vector<4xf32>
+  %0 = vector.load %arg0[%arg1, %arg2] : memref<?x?xf32>, vector<2x4xf32>
+  %1 = vector.extract %0[1] : vector<4xf32> from vector<2x4xf32>
+  return %1 : vector<4xf32>
+}
+
+// CHECK-LABEL: @extract_load_scalar_high_rank
+//  CHECK-SAME:   (%[[ARG0:.*]]: memref<?x?xf32>, %[[ARG1:.*]]: index, %[[ARG2:.*]]: index)
+func.func @extract_load_scalar_high_rank(%arg0: memref<?x?xf32>, %arg1: index, %arg2: index) -> f32 {
+// CHECK:   %[[C1:.*]] = arith.constant 1 : index
+// CHECK:   %[[OFF:.*]] = arith.addi %[[ARG2]], %[[C1]] overflow<nsw> : index
+// CHECK:   %[[RES:.*]] = memref.load %[[ARG0]][%[[ARG1]], %[[OFF]]] : memref<?x?xf32>
+// CHECK:   return %[[RES]] : f32
+  %0 = vector.load %arg0[%arg1, %arg2] : memref<?x?xf32>, vector<4xf32>
+  %1 = vector.extract %0[1] : f32 from vector<4xf32>
+  return %1 : f32
+}
+
+// CHECK-LABEL: @extract_load_vec_high_rank
+//  CHECK-SAME:   (%[[ARG0:.*]]: memref<?x?x?xf32>, %[[ARG1:.*]]: index, %[[ARG2:.*]]: index, %[[ARG3:.*]]: index)
+func.func @extract_load_vec_high_rank(%arg0: memref<?x?x?xf32>, %arg1: index, %arg2: index, %arg3: index) -> vector<4xf32> {
+// CHECK:   %[[C1:.*]] = arith.constant 1 : index
+// CHECK:   %[[OFF:.*]] = arith.addi %[[ARG2]], %[[C1]] overflow<nsw> : index
+// CHECK:   %[[RES:.*]] = vector.load %[[ARG0]][%[[ARG1]], %[[OFF]], %[[ARG3]]] : memref<?x?x?xf32>, vector<4xf32>
+// CHECK:   return %[[RES]] : vector<4xf32>
+  %0 = vector.load %arg0[%arg1, %arg2, %arg3] : memref<?x?x?xf32>, vector<2x4xf32>
+  %1 = vector.extract %0[1] : vector<4xf32> from vector<2x4xf32>
+  return %1 : vector<4xf32>
+}
+
+// CHECK-LABEL: @negative_load_scalar_from_vec_memref
+//  CHECK-SAME:   (%[[ARG0:.*]]: memref<?xvector<4xf32>>, %[[ARG1:.*]]: index)
+func.func @negative_load_scalar_from_vec_memref(%arg0: memref<?xvector<4xf32>>, %arg1: index) -> f32 {
+// CHECK:   %[[RES:.*]] = vector.load %[[ARG0]][%[[ARG1]]] : memref<?xvector<4xf32>>, vector<4xf32>
+// CHECK:   %[[EXT:.*]] = vector.extract %[[RES]][0] : f32 from vector<4xf32>
+// CHECK:   return %[[EXT]] : f32
+  %0 = vector.load %arg0[%arg1] : memref<?xvector<4xf32>>, vector<4xf32>
+  %1 = vector.extract %0[0] : f32 from vector<4xf32>
+  return %1 : f32
+}
+
+// CHECK-LABEL: @negative_extract_load_no_single_use
+//  CHECK-SAME:   (%[[ARG0:.*]]: memref<?xf32>, %[[ARG1:.*]]: index)
+func.func @negative_extract_load_no_single_use(%arg0: memref<?xf32>, %arg1: index) -> (f32, vector<4xf32>) {
+// CHECK:   %[[RES:.*]] = vector.load %[[ARG0]][%[[ARG1]]] : memref<?xf32>, vector<4xf32>
+// CHECK:   %[[EXT:.*]] = vector.extract %[[RES]][0] : f32 from vector<4xf32>
+// CHECK:   return %[[EXT]], %[[RES]] : f32, vector<4xf32>
+  %0 = vector.load %arg0[%arg1] : memref<?xf32>, vector<4xf32>
+  %1 = vector.extract %0[0] : f32 from vector<4xf32>
+  return %1, %0 : f32, vector<4xf32>
+}
+
+// CHECK-LABEL: @negative_load_scalable
+//  CHECK-SAME:   (%[[ARG0:.*]]: memref<?xf32>, %[[ARG1:.*]]: index)
+func.func @negative_load_scalable(%arg0: memref<?xf32>, %arg1: index) -> f32 {
+// CHECK:   %[[RES:.*]] = vector.load %[[ARG0]][%[[ARG1]]] : memref<?xf32>, vector<[1]xf32>
+// CHECK:   %[[EXT:.*]] = vector.extract %[[RES]][0] : f32 from vector<[1]xf32>
+// CHECK:   return %[[EXT]] : f32
+  %0 = vector.load %arg0[%arg1] : memref<?xf32>, vector<[1]xf32>
+  %1 = vector.extract %0[0] : f32 from vector<[1]xf32>
+  return %1 : f32
+}
+
+// CHECK-LABEL: @negative_extract_load_unsupported_ranks
+//  CHECK-SAME:   (%[[ARG0:.*]]: memref<?xf32>, %[[ARG1:.*]]: index)
+func.func @negative_extract_load_unsupported_ranks(%arg0: memref<?xf32>, %arg1: index) -> vector<4xf32> {
+// CHECK:   %[[RES:.*]] = vector.load %[[ARG0]][%[[ARG1]]] : memref<?xf32>, vector<2x4xf32>
+// CHECK:   %[[EXT:.*]] = vector.extract %[[RES]][1] : vector<4xf32> from vector<2x4xf32>
+// CHECK:   return %[[EXT]] : vector<4xf32>
+  %0 = vector.load %arg0[%arg1] : memref<?xf32>, vector<2x4xf32>
+  %1 = vector.extract %0[1] : vector<4xf32> from vector<2x4xf32>
+  return %1 : vector<4xf32>
+}
+
+//-----------------------------------------------------------------------------
+// [Pattern: StoreFromSplat]
+//-----------------------------------------------------------------------------
+
+// CHECK-LABEL: @store_splat
+//  CHECK-SAME:   (%[[ARG0:.*]]: memref<?xf32>, %[[ARG1:.*]]: index, %[[ARG2:.*]]: f32)
+func.func @store_splat(%arg0: memref<?xf32>, %arg1: index, %arg2: f32) {
+// CHECK:   memref.store %[[ARG2]], %[[ARG0]][%[[ARG1]]] : memref<?xf32>
+  %0 = vector.splat %arg2 : vector<1xf32>
+  vector.store %0, %arg0[%arg1] : memref<?xf32>, vector<1xf32>
+  return
+}
+
+// CHECK-LABEL: @store_broadcast
+//  CHECK-SAME:   (%[[ARG0:.*]]: memref<?xf32>, %[[ARG1:.*]]: index, %[[ARG2:.*]]: f32)
+func.func @store_broadcast(%arg0: memref<?xf32>, %arg1: index, %arg2: f32) {
+// CHECK:   memref.store %[[ARG2]], %[[ARG0]][%[[ARG1]]] : memref<?xf32>
+  %0 = vector.broadcast %arg2 : f32 to vector<1xf32>
+  vector.store %0, %arg0[%arg1] : memref<?xf32>, vector<1xf32>
+  return
+}
+
+// CHECK-LABEL: @store_broadcast_1d_2d
+//  CHECK-SAME:   (%[[ARG0:.*]]: memref<?x?xf32>, %[[ARG1:.*]]: index, %[[ARG2:.*]]: index, %[[ARG3:.*]]: vector<1xf32>)
+func.func @store_broadcast_1d_2d(%arg0: memref<?x?xf32>, %arg1: index, %arg2: index, %arg3: vector<1xf32>) {
+// CHECK:   vector.store %[[ARG3]], %[[ARG0]][%[[ARG1]], %[[ARG2]]] : memref<?x?xf32>, vector<1xf32>
+  %0 = vector.broadcast %arg3 : vector<1xf32> to vector<1x1xf32>
+  vector.store %0, %arg0[%arg1, %arg2] : memref<?x?xf32>, vector<1x1xf32>
+  return
+}
+
+// CHECK-LABEL: @negative_store_scalable
+//  CHECK-SAME:   (%[[ARG0:.*]]: memref<?xf32>, %[[ARG1:.*]]: index, %[[ARG2:.*]]: f32)
+func.func @negative_store_scalable(%arg0: memref<?xf32>, %arg1: index, %arg2: f32) {
+// CHECK:   %[[RES:.*]] = vector.splat %[[ARG2]] : vector<[1]xf32>
+// CHECK:   vector.store %[[RES]], %[[ARG0]][%[[ARG1]]] : memref<?xf32>, vector<[1]xf32>
+  %0 = vector.splat %arg2 : vector<[1]xf32>
+  vector.store %0, %arg0[%arg1] : memref<?xf32>, vector<[1]xf32>
+  return
+}
+
+// CHECK-LABEL: @negative_store_vec_memref
+//  CHECK-SAME:   (%[[ARG0:.*]]: memref<?xvector<1xf32>>, %[[ARG1:.*]]: index, %[[ARG2:.*]]: f32)
+func.func @negative_store_vec_memref(%arg0: memref<?xvector<1xf32>>, %arg1: index, %arg2: f32) {
+// CHECK:   %[[RES:.*]] = vector.splat %[[ARG2]] : vector<1xf32>
+// CHECK:   vector.store %[[RES]], %[[ARG0]][%[[ARG1]]] : memref<?xvector<1xf32>>, vector<1xf32>
+  %0 = vector.splat %arg2 : vector<1xf32>
+  vector.store %0, %arg0[%arg1] : memref<?xvector<1xf32>>, vector<1xf32>
+  return
+}
+
+// CHECK-LABEL: @negative_store_non_1
+//  CHECK-SAME:   (%[[ARG0:.*]]: memref<?xf32>, %[[ARG1:.*]]: index, %[[ARG2:.*]]: f32)
+func.func @negative_store_non_1(%arg0: memref<?xf32>, %arg1: index, %arg2: f32) {
+// CHECK:   %[[RES:.*]] = vector.splat %[[ARG2]] : vector<4xf32>
+// CHECK:   vector.store %[[RES]], %[[ARG0]][%[[ARG1]]] : memref<?xf32>, vector<4xf32>
+  %0 = vector.splat %arg2 : vector<4xf32>
+  vector.store %0, %arg0[%arg1] : memref<?xf32>, vector<4xf32>
+  return
+}
+
+// CHECK-LABEL: @negative_store_no_single_use
+//  CHECK-SAME:   (%[[ARG0:.*]]: memref<?xf32>, %[[ARG1:.*]]: index, %[[ARG2:.*]]: f32)
+func.func @negative_store_no_single_use(%arg0: memref<?xf32>, %arg1: index, %arg2: f32) -> vector<1xf32> {
+// CHECK:   %[[RES:.*]] = vector.splat %[[ARG2]] : vector<1xf32>
+// CHECK:   vector.store %[[RES]], %[[ARG0]][%[[ARG1]]] : memref<?xf32>, vector<1xf32>
+// CHECK:   return %[[RES]] : vector<1xf32>
+  %0 = vector.splat %arg2 : vector<1xf32>
+  vector.store %0, %arg0[%arg1] : memref<?xf32>, vector<1xf32>
+  return %0 : vector<1xf32>
+}
diff --git a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
index a54ae816570a8..03f907e46c2c6 100644
--- a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
+++ b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
@@ -395,6 +395,7 @@ struct TestVectorSinkPatterns
   void runOnOperation() override {
     RewritePatternSet patterns(&getContext());
     populateSinkVectorOpsPatterns(patterns);
+    populateSinkVectorMemOpsPatterns(patterns);
     (void)applyPatternsGreedily(getOperation(), std::move(patterns));
   }
 };

>From b228d288936863c6a9d5f5b16386fbe9d1d485fc Mon Sep 17 00:00:00 2001
From: Ivan Butygin <ivan.butygin at gmail.com>
Date: Fri, 4 Apr 2025 16:38:13 +0200
Subject: [PATCH 2/4] comment

---
 .../Vector/Transforms/VectorRewritePatterns.h       | 13 +++++++++++--
 1 file changed, 11 insertions(+), 2 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h b/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
index 7d3134bdae233..2d8b12c871be7 100644
--- a/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
+++ b/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
@@ -161,8 +161,30 @@ void populateVectorTransferCollapseInnerMostContiguousDimsPatterns(
 void populateSinkVectorOpsPatterns(RewritePatternSet &patterns,
                                    PatternBenefit benefit = 1);
 
-/// Patterns that remove redundant Vector Ops by re-ordering them with
-/// memory Ops:
+/// Patterns that remove redundant Vector Ops by merging them with load/store
+/// ops.
+///
+/// Example:
+/// ```
+/// %0 = vector.load %arg0[%arg1] : memref<?xf32>, vector<4xf32>
+/// %1 = vector.extract %0[1] : f32 from vector<4xf32>
+/// ```
+/// Gets converted to:
+/// ```
+/// %c1 = arith.constant 1 : index
+/// %0 = arith.addi %arg1, %c1 overflow<nsw> : index
+/// %1 = memref.load %arg0[%0] : memref<?xf32>
+/// ```
+///
+/// Example:
+/// ```
+/// %0 = vector.splat %arg2 : vector<1xf32>
+/// vector.store %0, %arg0[%arg1] : memref<?xf32>, vector<1xf32>
+/// ```
+/// Gets converted to:
+/// ```
+/// memref.store %arg2, %arg0[%arg1] : memref<?xf32>
+/// ```
 void populateSinkVectorMemOpsPatterns(RewritePatternSet &patterns,
                                       PatternBenefit benefit = 1);
 

>From c3fe42b1d9823f2b2b7a73051b1215b5c2d7a1ed Mon Sep 17 00:00:00 2001
From: Ivan Butygin <ivan.butygin at gmail.com>
Date: Sun, 13 Apr 2025 02:04:12 +0200
Subject: [PATCH 3/4] review comments

---
 .../Vector/Transforms/VectorTransforms.cpp    | 24 +++++++++++++----
 mlir/test/Dialect/Vector/vector-sink.mlir     | 27 ++++++-------------
 2 files changed, 27 insertions(+), 24 deletions(-)

diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
index bbd16d330041d..3b59f9fa694e7 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
@@ -962,6 +962,8 @@ struct BreakDownVectorBitCast : public OpRewritePattern<vector::BitCastOp> {
 };
 
-/// Reorders elementwise(broadcast/splat) to broadcast(elementwise). Ex:
+/// Reorders elementwise(broadcast/splat) to broadcast(elementwise).
+///
+/// Example:
 /// ```
 /// %a = vector.broadcast %arg1 : index to vector<1x4xindex>
 /// %b = vector.broadcast %arg2 : index to vector<1x4xindex>
@@ -1047,6 +1049,8 @@ struct ReorderElementwiseOpsOnBroadcast final
 /// This may result in cleaner code when extracting a single value
 /// from multi-element vector and also to help canonicalize 1-element vectors to
 /// scalars.
+///
+/// Example:
 /// ```
 ///  %0 = arith.addf %arg0, %arg1 : vector<4xf32>
 ///  %1 = vector.extract %0[1] : f32 from vector<4xf32>
@@ -1104,6 +1108,8 @@ class ExtractOpFromElementwise final
 };
 
 /// Pattern to rewrite vector.extract(vector.load) -> vector/memref.load.
+///
+/// Example:
 /// ```
 ///  vector.load %arg0[%arg1] : memref<?xf32>, vector<4xf32>
 ///  vector.extract %0[1] : f32 from vector<4xf32>
@@ -1122,13 +1128,14 @@ class ExtractOpFromLoad final : public OpRewritePattern<vector::ExtractOp> {
                                 PatternRewriter &rewriter) const override {
     auto loadOp = op.getVector().getDefiningOp<vector::LoadOp>();
     if (!loadOp)
-      return rewriter.notifyMatchFailure(op, "not a load op");
+      return rewriter.notifyMatchFailure(op, "expected a load op");
 
+    // Checking for single use so we won't duplicate load ops.
     if (!loadOp->hasOneUse())
       return rewriter.notifyMatchFailure(op, "expected single op use");
 
-    VectorType memVecType = loadOp.getVectorType();
-    if (memVecType.isScalable())
+    VectorType loadVecType = loadOp.getVectorType();
+    if (loadVecType.isScalable())
       return rewriter.notifyMatchFailure(op,
                                          "scalable vectors are not supported");
 
@@ -1137,7 +1144,7 @@ class ExtractOpFromLoad final : public OpRewritePattern<vector::ExtractOp> {
       return rewriter.notifyMatchFailure(
           op, "memrefs of vectors are not supported");
 
-    int64_t rankOffset = memType.getRank() - memVecType.getRank();
+    int64_t rankOffset = memType.getRank() - loadVecType.getRank();
     if (rankOffset < 0)
       return rewriter.notifyMatchFailure(op, "unsupported ranks combination");
 
@@ -1149,6 +1156,9 @@ class ExtractOpFromLoad final : public OpRewritePattern<vector::ExtractOp> {
     SmallVector<Value> indices = loadOp.getIndices();
     SmallVector<OpFoldResult> extractPos = op.getMixedPosition();
 
+    // There may be memory stores between the load and the extract op, so we
+    // need to make sure that the new load op is inserted at the same place as
+    // the original load op.
     OpBuilder::InsertionGuard g(rewriter);
     rewriter.setInsertionPoint(loadOp);
     Location loc = loadOp.getLoc();
@@ -1170,12 +1180,15 @@ class ExtractOpFromLoad final : public OpRewritePattern<vector::ExtractOp> {
     } else {
       rewriter.replaceOpWithNewOp<memref::LoadOp>(op, base, indices);
     }
+    // We checked for single use so we can safely erase the load op.
     rewriter.eraseOp(loadOp);
     return success();
   }
 };
 
 /// Pattern to rewrite vector.store(vector.splat) -> vector/memref.store.
+///
+/// Example:
 /// ```
 /// %0 = vector.splat %arg2 : vector<1xf32>
 /// vector.store %0, %arg0[%arg1] : memref<?xf32>, vector<1xf32>
@@ -1205,8 +1218,9 @@ class StoreFromSplat final : public OpRewritePattern<vector::StoreOp> {
 
     Operation *splat = op.getValueToStore().getDefiningOp();
     if (!isa_and_present<vector::BroadcastOp, vector::SplatOp>(splat))
-      return rewriter.notifyMatchFailure(op, "not a splat");
+      return rewriter.notifyMatchFailure(op, "neither a splat nor a broadcast");
 
+    // Checking for single use so we can remove splat.
     if (!splat->hasOneUse())
       return rewriter.notifyMatchFailure(op, "expected single op use");
 
diff --git a/mlir/test/Dialect/Vector/vector-sink.mlir b/mlir/test/Dialect/Vector/vector-sink.mlir
index ad4fdbe0a7b5a..6b060984e9f63 100644
--- a/mlir/test/Dialect/Vector/vector-sink.mlir
+++ b/mlir/test/Dialect/Vector/vector-sink.mlir
@@ -587,9 +587,9 @@ func.func @extract_load_vec_high_rank(%arg0: memref<?x?x?xf32>, %arg1: index, %a
   return %1 : vector<4xf32>
 }
 
-// CHECK-LABEL: @negative_load_scalar_from_vec_memref
+// CHECK-LABEL: @negative_extract_load_scalar_from_vec_memref
 //  CHECK-SAME:   (%[[ARG0:.*]]: memref<?xvector<4xf32>>, %[[ARG1:.*]]: index)
-func.func @negative_load_scalar_from_vec_memref(%arg0: memref<?xvector<4xf32>>, %arg1: index) -> f32 {
+func.func @negative_extract_load_scalar_from_vec_memref(%arg0: memref<?xvector<4xf32>>, %arg1: index) -> f32 {
 // CHECK:   %[[RES:.*]] = vector.load %[[ARG0]][%[[ARG1]]] : memref<?xvector<4xf32>>, vector<4xf32>
 // CHECK:   %[[EXT:.*]] = vector.extract %[[RES]][0] : f32 from vector<4xf32>
 // CHECK:   return %[[EXT]] : f32
@@ -609,9 +609,9 @@ func.func @negative_extract_load_no_single_use(%arg0: memref<?xf32>, %arg1: inde
   return %1, %0 : f32, vector<4xf32>
 }
 
-// CHECK-LABEL: @negative_load_scalable
+// CHECK-LABEL: @negative_extract_load_scalable
 //  CHECK-SAME:   (%[[ARG0:.*]]: memref<?xf32>, %[[ARG1:.*]]: index)
-func.func @negative_load_scalable(%arg0: memref<?xf32>, %arg1: index) -> f32 {
+func.func @negative_extract_load_scalable(%arg0: memref<?xf32>, %arg1: index) -> f32 {
 // CHECK:   %[[RES:.*]] = vector.load %[[ARG0]][%[[ARG1]]] : memref<?xf32>, vector<[1]xf32>
 // CHECK:   %[[EXT:.*]] = vector.extract %[[RES]][0] : f32 from vector<[1]xf32>
 // CHECK:   return %[[EXT]] : f32
@@ -620,17 +620,6 @@ func.func @negative_load_scalable(%arg0: memref<?xf32>, %arg1: index) -> f32 {
   return %1 : f32
 }
 
-// CHECK-LABEL: @negative_extract_load_unsupported_ranks
-//  CHECK-SAME:   (%[[ARG0:.*]]: memref<?xf32>, %[[ARG1:.*]]: index)
-func.func @negative_extract_load_unsupported_ranks(%arg0: memref<?xf32>, %arg1: index) -> vector<4xf32> {
-// CHECK:   %[[RES:.*]] = vector.load %[[ARG0]][%[[ARG1]]] : memref<?xf32>, vector<2x4xf32>
-// CHECK:   %[[EXT:.*]] = vector.extract %[[RES]][1] : vector<4xf32> from vector<2x4xf32>
-// CHECK:   return %[[EXT]] : vector<4xf32>
-  %0 = vector.load %arg0[%arg1] : memref<?xf32>, vector<2x4xf32>
-  %1 = vector.extract %0[1] : vector<4xf32> from vector<2x4xf32>
-  return %1 : vector<4xf32>
-}
-
 //-----------------------------------------------------------------------------
 // [Pattern: StoreFromSplat]
 //-----------------------------------------------------------------------------
@@ -653,9 +642,9 @@ func.func @store_broadcast(%arg0: memref<?xf32>, %arg1: index, %arg2: f32) {
   return
 }
 
-// CHECK-LABEL: @store_broadcast_1d_2d
+// CHECK-LABEL: @store_broadcast_1d_to_2d
 //  CHECK-SAME:   (%[[ARG0:.*]]: memref<?x?xf32>, %[[ARG1:.*]]: index, %[[ARG2:.*]]: index, %[[ARG3:.*]]: vector<1xf32>)
-func.func @store_broadcast_1d_2d(%arg0: memref<?x?xf32>, %arg1: index, %arg2: index, %arg3: vector<1xf32>) {
+func.func @store_broadcast_1d_to_2d(%arg0: memref<?x?xf32>, %arg1: index, %arg2: index, %arg3: vector<1xf32>) {
 // CHECK:   vector.store %[[ARG3]], %[[ARG0]][%[[ARG1]], %[[ARG2]]] : memref<?x?xf32>, vector<1xf32>
   %0 = vector.broadcast %arg3 : vector<1xf32> to vector<1x1xf32>
   vector.store %0, %arg0[%arg1, %arg2] : memref<?x?xf32>, vector<1x1xf32>
@@ -682,9 +671,9 @@ func.func @negative_store_vec_memref(%arg0: memref<?xvector<1xf32>>, %arg1: inde
   return
 }
 
-// CHECK-LABEL: @negative_store_non_1
+// CHECK-LABEL: @negative_store_more_than_one_element
 //  CHECK-SAME:   (%[[ARG0:.*]]: memref<?xf32>, %[[ARG1:.*]]: index, %[[ARG2:.*]]: f32)
-func.func @negative_store_non_1(%arg0: memref<?xf32>, %arg1: index, %arg2: f32) {
+func.func @negative_store_more_than_one_element(%arg0: memref<?xf32>, %arg1: index, %arg2: f32) {
 // CHECK:   %[[RES:.*]] = vector.splat %[[ARG2]] : vector<4xf32>
 // CHECK:   vector.store %[[RES]], %[[ARG0]][%[[ARG1]]] : memref<?xf32>, vector<4xf32>
   %0 = vector.splat %arg2 : vector<4xf32>

>From aaf353d52a2762c42fc842b474837257eb438c25 Mon Sep 17 00:00:00 2001
From: Ivan Butygin <ivan.butygin at gmail.com>
Date: Sun, 13 Apr 2025 02:24:20 +0200
Subject: [PATCH 4/4] ArithIndexingBuilder

---
 mlir/include/mlir/Dialect/Arith/Utils/Utils.h      | 14 +++++++++++++-
 mlir/lib/Dialect/Arith/Utils/Utils.cpp             |  6 +++---
 .../Dialect/Vector/Transforms/VectorTransforms.cpp |  5 ++---
 3 files changed, 18 insertions(+), 7 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Arith/Utils/Utils.h b/mlir/include/mlir/Dialect/Arith/Utils/Utils.h
index d759299cbf762..34a8c3be0f3a8 100644
--- a/mlir/include/mlir/Dialect/Arith/Utils/Utils.h
+++ b/mlir/include/mlir/Dialect/Arith/Utils/Utils.h
@@ -101,7 +101,10 @@ Type getType(OpFoldResult ofr);
 /// Helper struct to build simple arithmetic quantities with minimal type
 /// inference support.
 struct ArithBuilder {
-  ArithBuilder(OpBuilder &b, Location loc) : b(b), loc(loc) {}
+  ArithBuilder(
+      OpBuilder &b, Location loc,
+      arith::IntegerOverflowFlags ovf = arith::IntegerOverflowFlags::none)
+      : b(b), loc(loc), ovf(ovf) {}
 
   Value _and(Value lhs, Value rhs);
   Value add(Value lhs, Value rhs);
@@ -114,6 +117,15 @@ struct ArithBuilder {
 private:
   OpBuilder &b;
   Location loc;
+  arith::IntegerOverflowFlags ovf;
+};
+
+/// ArithBuilder specialized for tensor/memref indexing calculations. Such
+/// calculations are not expected to overflow in the signed sense, so the
+/// integer add/sub/mul ops it creates carry the `nsw` overflow flag.
+struct ArithIndexingBuilder : public ArithBuilder {
+  ArithIndexingBuilder(OpBuilder &b, Location loc)
+      : ArithBuilder(b, loc, arith::IntegerOverflowFlags::nsw) {}
 };
 
 namespace arith {
diff --git a/mlir/lib/Dialect/Arith/Utils/Utils.cpp b/mlir/lib/Dialect/Arith/Utils/Utils.cpp
index 8dde9866b22b3..6b1074e454bd5 100644
--- a/mlir/lib/Dialect/Arith/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/Arith/Utils/Utils.cpp
@@ -315,17 +315,17 @@ Value ArithBuilder::_and(Value lhs, Value rhs) {
 Value ArithBuilder::add(Value lhs, Value rhs) {
   if (isa<FloatType>(lhs.getType()))
     return b.create<arith::AddFOp>(loc, lhs, rhs);
-  return b.create<arith::AddIOp>(loc, lhs, rhs);
+  return b.create<arith::AddIOp>(loc, lhs, rhs, ovf);
 }
 Value ArithBuilder::sub(Value lhs, Value rhs) {
   if (isa<FloatType>(lhs.getType()))
     return b.create<arith::SubFOp>(loc, lhs, rhs);
-  return b.create<arith::SubIOp>(loc, lhs, rhs);
+  return b.create<arith::SubIOp>(loc, lhs, rhs, ovf);
 }
 Value ArithBuilder::mul(Value lhs, Value rhs) {
   if (isa<FloatType>(lhs.getType()))
     return b.create<arith::MulFOp>(loc, lhs, rhs);
-  return b.create<arith::MulIOp>(loc, lhs, rhs);
+  return b.create<arith::MulIOp>(loc, lhs, rhs, ovf);
 }
 Value ArithBuilder::sgt(Value lhs, Value rhs) {
   if (isa<FloatType>(lhs.getType()))
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
index 3b59f9fa694e7..cafe2d6fbb85c 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
@@ -1162,15 +1162,14 @@ class ExtractOpFromLoad final : public OpRewritePattern<vector::ExtractOp> {
     OpBuilder::InsertionGuard g(rewriter);
     rewriter.setInsertionPoint(loadOp);
     Location loc = loadOp.getLoc();
+    ArithIndexingBuilder idxBuilder(rewriter, loc);
     for (auto i : llvm::seq<int64_t>(rankOffset, indices.size() - finalRank)) {
       OpFoldResult pos = extractPos[i - rankOffset];
       if (isConstantIntValue(pos, 0))
         continue;
 
       Value offset = getValueOrCreateConstantIndexOp(rewriter, loc, pos);
-
-      auto ovf = arith::IntegerOverflowFlags::nsw;
-      indices[i] = rewriter.create<arith::AddIOp>(loc, indices[i], offset, ovf);
+      indices[i] = idxBuilder.add(indices[i], offset);
     }
 
     Value base = loadOp.getBase();



More information about the Mlir-commits mailing list