[Mlir-commits] [mlir] [mlir][vector] Sink vector.extract/splat into load/store ops (PR #134389)
Ivan Butygin
llvmlistbot at llvm.org
Tue Apr 22 06:33:56 PDT 2025
https://github.com/Hardcode84 updated https://github.com/llvm/llvm-project/pull/134389
>From fc53309660aa43eeb4dac68500b9fcfe7c71aaf2 Mon Sep 17 00:00:00 2001
From: Ivan Butygin <ivan.butygin at gmail.com>
Date: Fri, 4 Apr 2025 15:39:48 +0200
Subject: [PATCH 1/9] [mlir][vector] Sink extract/splat into load/store ops
```
vector.load %arg0[%arg1] : memref<?xf32>, vector<4xf32>
vector.extract %0[1] : f32 from vector<4xf32>
```
Gets converted to:
```
%c1 = arith.constant 1 : index
%0 = arith.addi %arg1, %c1 overflow<nsw> : index
%1 = memref.load %arg0[%0] : memref<?xf32>
```
```
%0 = vector.splat %arg2 : vector<1xf32>
vector.store %0, %arg0[%arg1] : memref<?xf32>, vector<1xf32>
```
Gets converted to:
```
memref.store %arg2, %arg0[%arg1] : memref<?xf32>
```
---
.../Vector/TransformOps/VectorTransformOps.td | 24 ++-
.../Vector/Transforms/VectorRewritePatterns.h | 5 +
.../TransformOps/VectorTransformOps.cpp | 5 +
.../Vector/Transforms/VectorTransforms.cpp | 127 ++++++++++++
.../Dialect/Vector/vector-sink-transform.mlir | 1 +
mlir/test/Dialect/Vector/vector-sink.mlir | 189 ++++++++++++++++++
.../Dialect/Vector/TestVectorTransforms.cpp | 1 +
7 files changed, 350 insertions(+), 2 deletions(-)
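For readers who want to exercise the new patterns outside the test pass, the wiring is the same as in the TestVectorSinkPatterns change further down in this patch. A minimal sketch, assuming a standard greedy-rewrite setup; the helper function name here is illustrative, only the populate/apply calls come from the patch:
```
#include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

// Illustrative helper (not part of the patch): apply the existing vector sink
// patterns together with the new load/store sink patterns on `root`.
static void applySinkPatterns(mlir::Operation *root) {
  mlir::RewritePatternSet patterns(root->getContext());
  mlir::vector::populateSinkVectorOpsPatterns(patterns);
  mlir::vector::populateSinkVectorMemOpsPatterns(patterns);
  (void)mlir::applyPatternsGreedily(root, std::move(patterns));
}
```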
diff --git a/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td b/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td
index f46aa0428f12f..7fbb437908866 100644
--- a/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td
+++ b/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td
@@ -469,8 +469,28 @@ def ApplySinkVectorPatternsOp : Op<Transform_Dialect,
%0 = arith.addf %a, %b : vector<4x2xf32>
%r = vector.transpose %0, [1, 0] : vector<2x4xf32>
```
- At the moment, these patterns are limited to vector.broadcast and
- vector.transpose.
+ At the moment, these patterns are limited to vector.broadcast,
+ vector.transpose and vector.extract.
+ }];
+
+ let assemblyFormat = "attr-dict";
+}
+
+def ApplySinkVectorMemPatternsOp : Op<Transform_Dialect,
+ "apply_patterns.vector.sink_mem_ops",
+ [DeclareOpInterfaceMethods<PatternDescriptorOpInterface>]> {
+ let description = [{
+ Patterns that remove redundant Vector Ops by merging them with load/store
+ ops
+ ```
+ vector.load %arg0[%arg1] : memref<?xf32>, vector<4xf32>
+ vector.extract %0[1] : f32 from vector<4xf32>
+ ```
+ Gets converted to:
+ ```
+ %c1 = arith.constant 1 : index
+ %0 = arith.addi %arg1, %c1 overflow<nsw> : index
+ %1 = memref.load %arg0[%0] : memref<?xf32>
}];
let assemblyFormat = "attr-dict";
diff --git a/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h b/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
index ce97847172197..06919a5ea27f4 100644
--- a/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
+++ b/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
@@ -161,6 +161,11 @@ void populateVectorTransferCollapseInnerMostContiguousDimsPatterns(
void populateSinkVectorOpsPatterns(RewritePatternSet &patterns,
PatternBenefit benefit = 1);
+/// Patterns that remove redundant Vector Ops by re-ordering them with
+/// memory Ops:
+void populateSinkVectorMemOpsPatterns(RewritePatternSet &patterns,
+ PatternBenefit benefit = 1);
+
/// Patterns that fold chained vector reductions. These patterns assume that
/// elementwise operations (e.g., `arith.addf` with vector operands) are
/// cheaper than vector reduction.
diff --git a/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp b/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp
index 12dcf768dd928..a888d745be443 100644
--- a/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp
+++ b/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp
@@ -212,6 +212,11 @@ void transform::ApplySinkVectorPatternsOp::populatePatterns(
vector::populateSinkVectorOpsPatterns(patterns);
}
+void transform::ApplySinkVectorMemPatternsOp::populatePatterns(
+ RewritePatternSet &patterns) {
+ vector::populateSinkVectorMemOpsPatterns(patterns);
+}
+
//===----------------------------------------------------------------------===//
// Transform op registration
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
index 89839d0440d3c..f22276cd3d168 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
@@ -1043,6 +1043,127 @@ class ExtractOpFromElementwise final
}
};
+/// Pattern to rewrite vector.extract(vector.load) -> vector/memref.load.
+/// ```
+/// vector.load %arg0[%arg1] : memref<?xf32>, vector<4xf32>
+/// vector.extract %0[1] : f32 from vector<4xf32>
+/// ```
+/// Gets converted to:
+/// ```
+/// %c1 = arith.constant 1 : index
+/// %0 = arith.addi %arg1, %c1 overflow<nsw> : index
+/// %1 = memref.load %arg0[%0] : memref<?xf32>
+/// ```
+class ExtractOpFromLoad final : public OpRewritePattern<vector::ExtractOp> {
+public:
+ using OpRewritePattern::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(vector::ExtractOp op,
+ PatternRewriter &rewriter) const override {
+ auto loadOp = op.getVector().getDefiningOp<vector::LoadOp>();
+ if (!loadOp)
+ return rewriter.notifyMatchFailure(op, "not a load op");
+
+ if (!loadOp->hasOneUse())
+ return rewriter.notifyMatchFailure(op, "expected single op use");
+
+ VectorType memVecType = loadOp.getVectorType();
+ if (memVecType.isScalable())
+ return rewriter.notifyMatchFailure(op,
+ "scalable vectors are not supported");
+
+ MemRefType memType = loadOp.getMemRefType();
+ if (isa<VectorType>(memType.getElementType()))
+ return rewriter.notifyMatchFailure(
+ op, "memrefs of vectors are not supported");
+
+ int64_t rankOffset = memType.getRank() - memVecType.getRank();
+ if (rankOffset < 0)
+ return rewriter.notifyMatchFailure(op, "unsupported ranks combination");
+
+ auto resVecType = dyn_cast<VectorType>(op.getResult().getType());
+ int64_t finalRank = 0;
+ if (resVecType)
+ finalRank = resVecType.getRank();
+
+ SmallVector<Value> indices = loadOp.getIndices();
+ SmallVector<OpFoldResult> extractPos = op.getMixedPosition();
+
+ OpBuilder::InsertionGuard g(rewriter);
+ rewriter.setInsertionPoint(loadOp);
+ Location loc = loadOp.getLoc();
+ for (auto i : llvm::seq<int64_t>(rankOffset, indices.size() - finalRank)) {
+ OpFoldResult pos = extractPos[i - rankOffset];
+ if (isConstantIntValue(pos, 0))
+ continue;
+
+ Value offset = getValueOrCreateConstantIndexOp(rewriter, loc, pos);
+
+ auto ovf = arith::IntegerOverflowFlags::nsw;
+ indices[i] = rewriter.create<arith::AddIOp>(loc, indices[i], offset, ovf);
+ }
+
+ Value base = loadOp.getBase();
+ if (resVecType) {
+ rewriter.replaceOpWithNewOp<vector::LoadOp>(op, resVecType, base,
+ indices);
+ } else {
+ rewriter.replaceOpWithNewOp<memref::LoadOp>(op, base, indices);
+ }
+ rewriter.eraseOp(loadOp);
+ return success();
+ }
+};
+
+/// Pattern to rewrite vector.store(vector.splat) -> vector/memref.store.
+/// ```
+/// %0 = vector.splat %arg2 : vector<1xf32>
+/// vector.store %0, %arg0[%arg1] : memref<?xf32>, vector<1xf32>
+/// ```
+/// Gets converted to:
+/// ```
+/// memref.store %arg2, %arg0[%arg1] : memref<?xf32>
+/// ```
+class StoreFromSplat final : public OpRewritePattern<vector::StoreOp> {
+public:
+ using OpRewritePattern::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(vector::StoreOp op,
+ PatternRewriter &rewriter) const override {
+ VectorType vecType = op.getVectorType();
+ if (vecType.isScalable())
+ return rewriter.notifyMatchFailure(op,
+ "scalable vectors are not supported");
+
+ if (isa<VectorType>(op.getMemRefType().getElementType()))
+ return rewriter.notifyMatchFailure(
+ op, "memrefs of vectors are not supported");
+
+ if (vecType.getNumElements() != 1)
+ return rewriter.notifyMatchFailure(
+ op, "only 1-element, vectors are supported");
+
+ Operation *splat = op.getValueToStore().getDefiningOp();
+ if (!isa_and_present<vector::BroadcastOp, vector::SplatOp>(splat))
+ return rewriter.notifyMatchFailure(op, "not a splat");
+
+ if (!splat->hasOneUse())
+ return rewriter.notifyMatchFailure(op, "expected single op use");
+
+ Value source = splat->getOperand(0);
+ Value base = op.getBase();
+ ValueRange indices = op.getIndices();
+
+ if (isa<VectorType>(source.getType())) {
+ rewriter.replaceOpWithNewOp<vector::StoreOp>(op, source, base, indices);
+ } else {
+ rewriter.replaceOpWithNewOp<memref::StoreOp>(op, source, base, indices);
+ }
+ rewriter.eraseOp(splat);
+ return success();
+ }
+};
+
// Helper that returns a vector comparison that constructs a mask:
// mask = [0,1,..,n-1] + [o,o,..,o] < [b,b,..,b]
//
@@ -2109,6 +2230,12 @@ void mlir::vector::populateSinkVectorOpsPatterns(RewritePatternSet &patterns,
patterns.getContext(), benefit);
}
+void mlir::vector::populateSinkVectorMemOpsPatterns(RewritePatternSet &patterns,
+ PatternBenefit benefit) {
+ patterns.add<ExtractOpFromLoad, StoreFromSplat>(patterns.getContext(),
+ benefit);
+}
+
void mlir::vector::populateChainedVectorReductionFoldingPatterns(
RewritePatternSet &patterns, PatternBenefit benefit) {
patterns.add<ChainedReduction>(patterns.getContext(), benefit);
diff --git a/mlir/test/Dialect/Vector/vector-sink-transform.mlir b/mlir/test/Dialect/Vector/vector-sink-transform.mlir
index ef17b69b2444c..4d04276742164 100644
--- a/mlir/test/Dialect/Vector/vector-sink-transform.mlir
+++ b/mlir/test/Dialect/Vector/vector-sink-transform.mlir
@@ -7,6 +7,7 @@ module attributes {transform.with_named_sequence} {
%func = transform.structured.match ops{["func.func"]} in %module_op : (!transform.any_op) -> !transform.any_op
transform.apply_patterns to %func {
transform.apply_patterns.vector.sink_ops
+ transform.apply_patterns.vector.sink_mem_ops
} : !transform.any_op
transform.yield
}
diff --git a/mlir/test/Dialect/Vector/vector-sink.mlir b/mlir/test/Dialect/Vector/vector-sink.mlir
index 8c8f1797aaab6..ad4fdbe0a7b5a 100644
--- a/mlir/test/Dialect/Vector/vector-sink.mlir
+++ b/mlir/test/Dialect/Vector/vector-sink.mlir
@@ -513,3 +513,192 @@ func.func @negative_extract_vec_fma(%arg0: vector<4xf32>, %arg1: vector<4xf32>,
%1 = vector.extract %0[1] : f32 from vector<4xf32>
return %1 : f32
}
+
+//-----------------------------------------------------------------------------
+// [Pattern: ExtractOpFromLoad]
+//-----------------------------------------------------------------------------
+
+// CHECK-LABEL: @extract_load_scalar
+// CHECK-SAME: (%[[ARG0:.*]]: memref<?xf32>, %[[ARG1:.*]]: index)
+func.func @extract_load_scalar(%arg0: memref<?xf32>, %arg1: index) -> f32 {
+// CHECK: %[[RES:.*]] = memref.load %[[ARG0]][%[[ARG1]]] : memref<?xf32>
+// CHECK: return %[[RES]] : f32
+ %0 = vector.load %arg0[%arg1] : memref<?xf32>, vector<4xf32>
+ %1 = vector.extract %0[0] : f32 from vector<4xf32>
+ return %1 : f32
+}
+
+// CHECK-LABEL: @extract_load_scalar_non_zero_off
+// CHECK-SAME: (%[[ARG0:.*]]: memref<?xf32>, %[[ARG1:.*]]: index)
+func.func @extract_load_scalar_non_zero_off(%arg0: memref<?xf32>, %arg1: index) -> f32 {
+// CHECK: %[[C1:.*]] = arith.constant 1 : index
+// CHECK: %[[OFF:.*]] = arith.addi %[[ARG1]], %[[C1]] overflow<nsw> : index
+// CHECK: %[[RES:.*]] = memref.load %[[ARG0]][%[[OFF]]] : memref<?xf32>
+// CHECK: return %[[RES]] : f32
+ %0 = vector.load %arg0[%arg1] : memref<?xf32>, vector<4xf32>
+ %1 = vector.extract %0[1] : f32 from vector<4xf32>
+ return %1 : f32
+}
+
+// CHECK-LABEL: @extract_load_scalar_dyn_off
+// CHECK-SAME: (%[[ARG0:.*]]: memref<?xf32>, %[[ARG1:.*]]: index, %[[ARG2:.*]]: index)
+func.func @extract_load_scalar_dyn_off(%arg0: memref<?xf32>, %arg1: index, %arg2: index) -> f32 {
+// CHECK: %[[OFF:.*]] = arith.addi %[[ARG1]], %[[ARG2]] overflow<nsw> : index
+// CHECK: %[[RES:.*]] = memref.load %[[ARG0]][%[[OFF]]] : memref<?xf32>
+// CHECK: return %[[RES]] : f32
+ %0 = vector.load %arg0[%arg1] : memref<?xf32>, vector<4xf32>
+ %1 = vector.extract %0[%arg2] : f32 from vector<4xf32>
+ return %1 : f32
+}
+
+// CHECK-LABEL: @extract_load_vec
+// CHECK-SAME: (%[[ARG0:.*]]: memref<?x?xf32>, %[[ARG1:.*]]: index, %[[ARG2:.*]]: index)
+func.func @extract_load_vec(%arg0: memref<?x?xf32>, %arg1: index, %arg2: index) -> vector<4xf32> {
+// CHECK: %[[C1:.*]] = arith.constant 1 : index
+// CHECK: %[[OFF:.*]] = arith.addi %[[ARG1]], %[[C1]] overflow<nsw> : index
+// CHECK: %[[RES:.*]] = vector.load %[[ARG0]][%[[OFF]], %[[ARG2]]] : memref<?x?xf32>, vector<4xf32>
+// CHECK: return %[[RES]] : vector<4xf32>
+ %0 = vector.load %arg0[%arg1, %arg2] : memref<?x?xf32>, vector<2x4xf32>
+ %1 = vector.extract %0[1] : vector<4xf32> from vector<2x4xf32>
+ return %1 : vector<4xf32>
+}
+
+// CHECK-LABEL: @extract_load_scalar_high_rank
+// CHECK-SAME: (%[[ARG0:.*]]: memref<?x?xf32>, %[[ARG1:.*]]: index, %[[ARG2:.*]]: index)
+func.func @extract_load_scalar_high_rank(%arg0: memref<?x?xf32>, %arg1: index, %arg2: index) -> f32 {
+// CHECK: %[[C1:.*]] = arith.constant 1 : index
+// CHECK: %[[OFF:.*]] = arith.addi %[[ARG2]], %[[C1]] overflow<nsw> : index
+// CHECK: %[[RES:.*]] = memref.load %[[ARG0]][%[[ARG1]], %[[OFF]]] : memref<?x?xf32>
+// CHECK: return %[[RES]] : f32
+ %0 = vector.load %arg0[%arg1, %arg2] : memref<?x?xf32>, vector<4xf32>
+ %1 = vector.extract %0[1] : f32 from vector<4xf32>
+ return %1 : f32
+}
+
+// CHECK-LABEL: @extract_load_vec_high_rank
+// CHECK-SAME: (%[[ARG0:.*]]: memref<?x?x?xf32>, %[[ARG1:.*]]: index, %[[ARG2:.*]]: index, %[[ARG3:.*]]: index)
+func.func @extract_load_vec_high_rank(%arg0: memref<?x?x?xf32>, %arg1: index, %arg2: index, %arg3: index) -> vector<4xf32> {
+// CHECK: %[[C1:.*]] = arith.constant 1 : index
+// CHECK: %[[OFF:.*]] = arith.addi %[[ARG2]], %[[C1]] overflow<nsw> : index
+// CHECK: %[[RES:.*]] = vector.load %[[ARG0]][%[[ARG1]], %[[OFF]], %[[ARG3]]] : memref<?x?x?xf32>, vector<4xf32>
+// CHECK: return %[[RES]] : vector<4xf32>
+ %0 = vector.load %arg0[%arg1, %arg2, %arg3] : memref<?x?x?xf32>, vector<2x4xf32>
+ %1 = vector.extract %0[1] : vector<4xf32> from vector<2x4xf32>
+ return %1 : vector<4xf32>
+}
+
+// CHECK-LABEL: @negative_load_scalar_from_vec_memref
+// CHECK-SAME: (%[[ARG0:.*]]: memref<?xvector<4xf32>>, %[[ARG1:.*]]: index)
+func.func @negative_load_scalar_from_vec_memref(%arg0: memref<?xvector<4xf32>>, %arg1: index) -> f32 {
+// CHECK: %[[RES:.*]] = vector.load %[[ARG0]][%[[ARG1]]] : memref<?xvector<4xf32>>, vector<4xf32>
+// CHECK: %[[EXT:.*]] = vector.extract %[[RES]][0] : f32 from vector<4xf32>
+// CHECK: return %[[EXT]] : f32
+ %0 = vector.load %arg0[%arg1] : memref<?xvector<4xf32>>, vector<4xf32>
+ %1 = vector.extract %0[0] : f32 from vector<4xf32>
+ return %1 : f32
+}
+
+// CHECK-LABEL: @negative_extract_load_no_single_use
+// CHECK-SAME: (%[[ARG0:.*]]: memref<?xf32>, %[[ARG1:.*]]: index)
+func.func @negative_extract_load_no_single_use(%arg0: memref<?xf32>, %arg1: index) -> (f32, vector<4xf32>) {
+// CHECK: %[[RES:.*]] = vector.load %[[ARG0]][%[[ARG1]]] : memref<?xf32>, vector<4xf32>
+// CHECK: %[[EXT:.*]] = vector.extract %[[RES]][0] : f32 from vector<4xf32>
+// CHECK: return %[[EXT]], %[[RES]] : f32, vector<4xf32>
+ %0 = vector.load %arg0[%arg1] : memref<?xf32>, vector<4xf32>
+ %1 = vector.extract %0[0] : f32 from vector<4xf32>
+ return %1, %0 : f32, vector<4xf32>
+}
+
+// CHECK-LABEL: @negative_load_scalable
+// CHECK-SAME: (%[[ARG0:.*]]: memref<?xf32>, %[[ARG1:.*]]: index)
+func.func @negative_load_scalable(%arg0: memref<?xf32>, %arg1: index) -> f32 {
+// CHECK: %[[RES:.*]] = vector.load %[[ARG0]][%[[ARG1]]] : memref<?xf32>, vector<[1]xf32>
+// CHECK: %[[EXT:.*]] = vector.extract %[[RES]][0] : f32 from vector<[1]xf32>
+// CHECK: return %[[EXT]] : f32
+ %0 = vector.load %arg0[%arg1] : memref<?xf32>, vector<[1]xf32>
+ %1 = vector.extract %0[0] : f32 from vector<[1]xf32>
+ return %1 : f32
+}
+
+// CHECK-LABEL: @negative_extract_load_unsupported_ranks
+// CHECK-SAME: (%[[ARG0:.*]]: memref<?xf32>, %[[ARG1:.*]]: index)
+func.func @negative_extract_load_unsupported_ranks(%arg0: memref<?xf32>, %arg1: index) -> vector<4xf32> {
+// CHECK: %[[RES:.*]] = vector.load %[[ARG0]][%[[ARG1]]] : memref<?xf32>, vector<2x4xf32>
+// CHECK: %[[EXT:.*]] = vector.extract %[[RES]][1] : vector<4xf32> from vector<2x4xf32>
+// CHECK: return %[[EXT]] : vector<4xf32>
+ %0 = vector.load %arg0[%arg1] : memref<?xf32>, vector<2x4xf32>
+ %1 = vector.extract %0[1] : vector<4xf32> from vector<2x4xf32>
+ return %1 : vector<4xf32>
+}
+
+//-----------------------------------------------------------------------------
+// [Pattern: StoreFromSplat]
+//-----------------------------------------------------------------------------
+
+// CHECK-LABEL: @store_splat
+// CHECK-SAME: (%[[ARG0:.*]]: memref<?xf32>, %[[ARG1:.*]]: index, %[[ARG2:.*]]: f32)
+func.func @store_splat(%arg0: memref<?xf32>, %arg1: index, %arg2: f32) {
+// CHECK: memref.store %[[ARG2]], %[[ARG0]][%[[ARG1]]] : memref<?xf32>
+ %0 = vector.splat %arg2 : vector<1xf32>
+ vector.store %0, %arg0[%arg1] : memref<?xf32>, vector<1xf32>
+ return
+}
+
+// CHECK-LABEL: @store_broadcast
+// CHECK-SAME: (%[[ARG0:.*]]: memref<?xf32>, %[[ARG1:.*]]: index, %[[ARG2:.*]]: f32)
+func.func @store_broadcast(%arg0: memref<?xf32>, %arg1: index, %arg2: f32) {
+// CHECK: memref.store %[[ARG2]], %[[ARG0]][%[[ARG1]]] : memref<?xf32>
+ %0 = vector.broadcast %arg2 : f32 to vector<1xf32>
+ vector.store %0, %arg0[%arg1] : memref<?xf32>, vector<1xf32>
+ return
+}
+
+// CHECK-LABEL: @store_broadcast_1d_2d
+// CHECK-SAME: (%[[ARG0:.*]]: memref<?x?xf32>, %[[ARG1:.*]]: index, %[[ARG2:.*]]: index, %[[ARG3:.*]]: vector<1xf32>)
+func.func @store_broadcast_1d_2d(%arg0: memref<?x?xf32>, %arg1: index, %arg2: index, %arg3: vector<1xf32>) {
+// CHECK: vector.store %[[ARG3]], %[[ARG0]][%[[ARG1]], %[[ARG2]]] : memref<?x?xf32>, vector<1xf32>
+ %0 = vector.broadcast %arg3 : vector<1xf32> to vector<1x1xf32>
+ vector.store %0, %arg0[%arg1, %arg2] : memref<?x?xf32>, vector<1x1xf32>
+ return
+}
+
+// CHECK-LABEL: @negative_store_scalable
+// CHECK-SAME: (%[[ARG0:.*]]: memref<?xf32>, %[[ARG1:.*]]: index, %[[ARG2:.*]]: f32)
+func.func @negative_store_scalable(%arg0: memref<?xf32>, %arg1: index, %arg2: f32) {
+// CHECK: %[[RES:.*]] = vector.splat %[[ARG2]] : vector<[1]xf32>
+// CHECK: vector.store %[[RES]], %[[ARG0]][%[[ARG1]]] : memref<?xf32>, vector<[1]xf32>
+ %0 = vector.splat %arg2 : vector<[1]xf32>
+ vector.store %0, %arg0[%arg1] : memref<?xf32>, vector<[1]xf32>
+ return
+}
+
+// CHECK-LABEL: @negative_store_vec_memref
+// CHECK-SAME: (%[[ARG0:.*]]: memref<?xvector<1xf32>>, %[[ARG1:.*]]: index, %[[ARG2:.*]]: f32)
+func.func @negative_store_vec_memref(%arg0: memref<?xvector<1xf32>>, %arg1: index, %arg2: f32) {
+// CHECK: %[[RES:.*]] = vector.splat %[[ARG2]] : vector<1xf32>
+// CHECK: vector.store %[[RES]], %[[ARG0]][%[[ARG1]]] : memref<?xvector<1xf32>>, vector<1xf32>
+ %0 = vector.splat %arg2 : vector<1xf32>
+ vector.store %0, %arg0[%arg1] : memref<?xvector<1xf32>>, vector<1xf32>
+ return
+}
+
+// CHECK-LABEL: @negative_store_non_1
+// CHECK-SAME: (%[[ARG0:.*]]: memref<?xf32>, %[[ARG1:.*]]: index, %[[ARG2:.*]]: f32)
+func.func @negative_store_non_1(%arg0: memref<?xf32>, %arg1: index, %arg2: f32) {
+// CHECK: %[[RES:.*]] = vector.splat %[[ARG2]] : vector<4xf32>
+// CHECK: vector.store %[[RES]], %[[ARG0]][%[[ARG1]]] : memref<?xf32>, vector<4xf32>
+ %0 = vector.splat %arg2 : vector<4xf32>
+ vector.store %0, %arg0[%arg1] : memref<?xf32>, vector<4xf32>
+ return
+}
+
+// CHECK-LABEL: @negative_store_no_single_use
+// CHECK-SAME: (%[[ARG0:.*]]: memref<?xf32>, %[[ARG1:.*]]: index, %[[ARG2:.*]]: f32)
+func.func @negative_store_no_single_use(%arg0: memref<?xf32>, %arg1: index, %arg2: f32) -> vector<1xf32> {
+// CHECK: %[[RES:.*]] = vector.splat %[[ARG2]] : vector<1xf32>
+// CHECK: vector.store %[[RES]], %[[ARG0]][%[[ARG1]]] : memref<?xf32>, vector<1xf32>
+// CHECK: return %[[RES:.*]] : vector<1xf32>
+ %0 = vector.splat %arg2 : vector<1xf32>
+ vector.store %0, %arg0[%arg1] : memref<?xf32>, vector<1xf32>
+ return %0 : vector<1xf32>
+}
diff --git a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
index a54ae816570a8..03f907e46c2c6 100644
--- a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
+++ b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
@@ -395,6 +395,7 @@ struct TestVectorSinkPatterns
void runOnOperation() override {
RewritePatternSet patterns(&getContext());
populateSinkVectorOpsPatterns(patterns);
+ populateSinkVectorMemOpsPatterns(patterns);
(void)applyPatternsGreedily(getOperation(), std::move(patterns));
}
};
>From e2dd80afd793f75e7449a91ff970ad58de814cfb Mon Sep 17 00:00:00 2001
From: Ivan Butygin <ivan.butygin at gmail.com>
Date: Fri, 4 Apr 2025 16:38:13 +0200
Subject: [PATCH 2/9] comment
---
.../Vector/Transforms/VectorRewritePatterns.h | 13 +++++++++++--
1 file changed, 11 insertions(+), 2 deletions(-)
diff --git a/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h b/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
index 06919a5ea27f4..7a079dcc6affc 100644
--- a/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
+++ b/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
@@ -161,8 +161,17 @@ void populateVectorTransferCollapseInnerMostContiguousDimsPatterns(
void populateSinkVectorOpsPatterns(RewritePatternSet &patterns,
PatternBenefit benefit = 1);
-/// Patterns that remove redundant Vector Ops by re-ordering them with
-/// memory Ops:
+/// Patterns that remove redundant Vector Ops by merging them with load/store
+/// ops
+/// ```
+/// vector.load %arg0[%arg1] : memref<?xf32>, vector<4xf32>
+/// vector.extract %0[1] : f32 from vector<4xf32>
+/// ```
+/// Gets converted to:
+/// ```
+/// %c1 = arith.constant 1 : index
+/// %0 = arith.addi %arg1, %c1 overflow<nsw> : index
+/// %1 = memref.load %arg0[%0] : memref<?xf32>
void populateSinkVectorMemOpsPatterns(RewritePatternSet &patterns,
PatternBenefit benefit = 1);
>From c2ddc12ebd1b26887f427dc9bfb479c3b69d3be5 Mon Sep 17 00:00:00 2001
From: Ivan Butygin <ivan.butygin at gmail.com>
Date: Sun, 13 Apr 2025 02:04:12 +0200
Subject: [PATCH 3/9] review comments
---
.../Vector/Transforms/VectorTransforms.cpp | 24 +++++++++++++----
mlir/test/Dialect/Vector/vector-sink.mlir | 27 ++++++-------------
2 files changed, 27 insertions(+), 24 deletions(-)
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
index f22276cd3d168..cd49bc7dca8af 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
@@ -902,6 +902,8 @@ struct BreakDownVectorBitCast : public OpRewritePattern<vector::BitCastOp> {
};
/// Reorders elementwise(broadcast/splat) to broadcast(elementwise). Ex:
+///
+/// Example:
/// ```
/// %a = vector.broadcast %arg1 : index to vector<1x4xindex>
/// %b = vector.broadcast %arg2 : index to vector<1x4xindex>
@@ -987,6 +989,8 @@ struct ReorderElementwiseOpsOnBroadcast final
/// This may result in cleaner code when extracting a single value
/// from multi-element vector and also to help canonicalize 1-element vectors to
/// scalars.
+///
+/// Example:
/// ```
/// %0 = arith.addf %arg0, %arg1 : vector<4xf32>
/// %1 = vector.extract %0[1] : f32 from vector<4xf32>
@@ -1044,6 +1048,8 @@ class ExtractOpFromElementwise final
};
/// Pattern to rewrite vector.extract(vector.load) -> vector/memref.load.
+///
+/// Example:
/// ```
/// vector.load %arg0[%arg1] : memref<?xf32>, vector<4xf32>
/// vector.extract %0[1] : f32 from vector<4xf32>
@@ -1062,13 +1068,14 @@ class ExtractOpFromLoad final : public OpRewritePattern<vector::ExtractOp> {
PatternRewriter &rewriter) const override {
auto loadOp = op.getVector().getDefiningOp<vector::LoadOp>();
if (!loadOp)
- return rewriter.notifyMatchFailure(op, "not a load op");
+ return rewriter.notifyMatchFailure(op, "expected a load op");
+ // Checking for single use so we won't duplicate load ops.
if (!loadOp->hasOneUse())
return rewriter.notifyMatchFailure(op, "expected single op use");
- VectorType memVecType = loadOp.getVectorType();
- if (memVecType.isScalable())
+ VectorType loadVecType = loadOp.getVectorType();
+ if (loadVecType.isScalable())
return rewriter.notifyMatchFailure(op,
"scalable vectors are not supported");
@@ -1077,7 +1084,7 @@ class ExtractOpFromLoad final : public OpRewritePattern<vector::ExtractOp> {
return rewriter.notifyMatchFailure(
op, "memrefs of vectors are not supported");
- int64_t rankOffset = memType.getRank() - memVecType.getRank();
+ int64_t rankOffset = memType.getRank() - loadVecType.getRank();
if (rankOffset < 0)
return rewriter.notifyMatchFailure(op, "unsupported ranks combination");
@@ -1089,6 +1096,9 @@ class ExtractOpFromLoad final : public OpRewritePattern<vector::ExtractOp> {
SmallVector<Value> indices = loadOp.getIndices();
SmallVector<OpFoldResult> extractPos = op.getMixedPosition();
+ // There may be memory stores between the load and the extract op, so we
+ // need to make sure that the new load op is inserted at the same place as
+ // the original load op.
OpBuilder::InsertionGuard g(rewriter);
rewriter.setInsertionPoint(loadOp);
Location loc = loadOp.getLoc();
@@ -1110,12 +1120,15 @@ class ExtractOpFromLoad final : public OpRewritePattern<vector::ExtractOp> {
} else {
rewriter.replaceOpWithNewOp<memref::LoadOp>(op, base, indices);
}
+ // We checked for single use so we can safely erase the load op.
rewriter.eraseOp(loadOp);
return success();
}
};
/// Pattern to rewrite vector.store(vector.splat) -> vector/memref.store.
+///
+/// Example:
/// ```
/// %0 = vector.splat %arg2 : vector<1xf32>
/// vector.store %0, %arg0[%arg1] : memref<?xf32>, vector<1xf32>
@@ -1145,8 +1158,9 @@ class StoreFromSplat final : public OpRewritePattern<vector::StoreOp> {
Operation *splat = op.getValueToStore().getDefiningOp();
if (!isa_and_present<vector::BroadcastOp, vector::SplatOp>(splat))
- return rewriter.notifyMatchFailure(op, "not a splat");
+ return rewriter.notifyMatchFailure(op, "neither a splat nor a broadcast");
+ // Checking for single use so we can remove splat.
if (!splat->hasOneUse())
return rewriter.notifyMatchFailure(op, "expected single op use");
diff --git a/mlir/test/Dialect/Vector/vector-sink.mlir b/mlir/test/Dialect/Vector/vector-sink.mlir
index ad4fdbe0a7b5a..6b060984e9f63 100644
--- a/mlir/test/Dialect/Vector/vector-sink.mlir
+++ b/mlir/test/Dialect/Vector/vector-sink.mlir
@@ -587,9 +587,9 @@ func.func @extract_load_vec_high_rank(%arg0: memref<?x?x?xf32>, %arg1: index, %a
return %1 : vector<4xf32>
}
-// CHECK-LABEL: @negative_load_scalar_from_vec_memref
+// CHECK-LABEL: @negative_extract_load_scalar_from_vec_memref
// CHECK-SAME: (%[[ARG0:.*]]: memref<?xvector<4xf32>>, %[[ARG1:.*]]: index)
-func.func @negative_load_scalar_from_vec_memref(%arg0: memref<?xvector<4xf32>>, %arg1: index) -> f32 {
+func.func @negative_extract_load_scalar_from_vec_memref(%arg0: memref<?xvector<4xf32>>, %arg1: index) -> f32 {
// CHECK: %[[RES:.*]] = vector.load %[[ARG0]][%[[ARG1]]] : memref<?xvector<4xf32>>, vector<4xf32>
// CHECK: %[[EXT:.*]] = vector.extract %[[RES]][0] : f32 from vector<4xf32>
// CHECK: return %[[EXT]] : f32
@@ -609,9 +609,9 @@ func.func @negative_extract_load_no_single_use(%arg0: memref<?xf32>, %arg1: inde
return %1, %0 : f32, vector<4xf32>
}
-// CHECK-LABEL: @negative_load_scalable
+// CHECK-LABEL: @negative_extract_load_scalable
// CHECK-SAME: (%[[ARG0:.*]]: memref<?xf32>, %[[ARG1:.*]]: index)
-func.func @negative_load_scalable(%arg0: memref<?xf32>, %arg1: index) -> f32 {
+func.func @negative_extract_load_scalable(%arg0: memref<?xf32>, %arg1: index) -> f32 {
// CHECK: %[[RES:.*]] = vector.load %[[ARG0]][%[[ARG1]]] : memref<?xf32>, vector<[1]xf32>
// CHECK: %[[EXT:.*]] = vector.extract %[[RES]][0] : f32 from vector<[1]xf32>
// CHECK: return %[[EXT]] : f32
@@ -620,17 +620,6 @@ func.func @negative_load_scalable(%arg0: memref<?xf32>, %arg1: index) -> f32 {
return %1 : f32
}
-// CHECK-LABEL: @negative_extract_load_unsupported_ranks
-// CHECK-SAME: (%[[ARG0:.*]]: memref<?xf32>, %[[ARG1:.*]]: index)
-func.func @negative_extract_load_unsupported_ranks(%arg0: memref<?xf32>, %arg1: index) -> vector<4xf32> {
-// CHECK: %[[RES:.*]] = vector.load %[[ARG0]][%[[ARG1]]] : memref<?xf32>, vector<2x4xf32>
-// CHECK: %[[EXT:.*]] = vector.extract %[[RES]][1] : vector<4xf32> from vector<2x4xf32>
-// CHECK: return %[[EXT]] : vector<4xf32>
- %0 = vector.load %arg0[%arg1] : memref<?xf32>, vector<2x4xf32>
- %1 = vector.extract %0[1] : vector<4xf32> from vector<2x4xf32>
- return %1 : vector<4xf32>
-}
-
//-----------------------------------------------------------------------------
// [Pattern: StoreFromSplat]
//-----------------------------------------------------------------------------
@@ -653,9 +642,9 @@ func.func @store_broadcast(%arg0: memref<?xf32>, %arg1: index, %arg2: f32) {
return
}
-// CHECK-LABEL: @store_broadcast_1d_2d
+// CHECK-LABEL: @store_broadcast_1d_to_2d
// CHECK-SAME: (%[[ARG0:.*]]: memref<?x?xf32>, %[[ARG1:.*]]: index, %[[ARG2:.*]]: index, %[[ARG3:.*]]: vector<1xf32>)
-func.func @store_broadcast_1d_2d(%arg0: memref<?x?xf32>, %arg1: index, %arg2: index, %arg3: vector<1xf32>) {
+func.func @store_broadcast_1d_to_2d(%arg0: memref<?x?xf32>, %arg1: index, %arg2: index, %arg3: vector<1xf32>) {
// CHECK: vector.store %[[ARG3]], %[[ARG0]][%[[ARG1]], %[[ARG2]]] : memref<?x?xf32>, vector<1xf32>
%0 = vector.broadcast %arg3 : vector<1xf32> to vector<1x1xf32>
vector.store %0, %arg0[%arg1, %arg2] : memref<?x?xf32>, vector<1x1xf32>
@@ -682,9 +671,9 @@ func.func @negative_store_vec_memref(%arg0: memref<?xvector<1xf32>>, %arg1: inde
return
}
-// CHECK-LABEL: @negative_store_non_1
+// CHECK-LABEL: @negative_store_more_than_one_element
// CHECK-SAME: (%[[ARG0:.*]]: memref<?xf32>, %[[ARG1:.*]]: index, %[[ARG2:.*]]: f32)
-func.func @negative_store_non_1(%arg0: memref<?xf32>, %arg1: index, %arg2: f32) {
+func.func @negative_store_more_than_one_element(%arg0: memref<?xf32>, %arg1: index, %arg2: f32) {
// CHECK: %[[RES:.*]] = vector.splat %[[ARG2]] : vector<4xf32>
// CHECK: vector.store %[[RES]], %[[ARG0]][%[[ARG1]]] : memref<?xf32>, vector<4xf32>
%0 = vector.splat %arg2 : vector<4xf32>
>From abf51afcfd37cbc66c2c4c8ea40537985d520847 Mon Sep 17 00:00:00 2001
From: Ivan Butygin <ivan.butygin at gmail.com>
Date: Sun, 13 Apr 2025 02:24:20 +0200
Subject: [PATCH 4/9] ArithIndexingBuilder
---
mlir/include/mlir/Dialect/Arith/Utils/Utils.h | 14 +++++++++++++-
mlir/lib/Dialect/Arith/Utils/Utils.cpp | 6 +++---
.../Dialect/Vector/Transforms/VectorTransforms.cpp | 5 ++---
3 files changed, 18 insertions(+), 7 deletions(-)
diff --git a/mlir/include/mlir/Dialect/Arith/Utils/Utils.h b/mlir/include/mlir/Dialect/Arith/Utils/Utils.h
index d759299cbf762..34a8c3be0f3a8 100644
--- a/mlir/include/mlir/Dialect/Arith/Utils/Utils.h
+++ b/mlir/include/mlir/Dialect/Arith/Utils/Utils.h
@@ -101,7 +101,10 @@ Type getType(OpFoldResult ofr);
/// Helper struct to build simple arithmetic quantities with minimal type
/// inference support.
struct ArithBuilder {
- ArithBuilder(OpBuilder &b, Location loc) : b(b), loc(loc) {}
+ ArithBuilder(
+ OpBuilder &b, Location loc,
+ arith::IntegerOverflowFlags ovf = arith::IntegerOverflowFlags::none)
+ : b(b), loc(loc), ovf(ovf) {}
Value _and(Value lhs, Value rhs);
Value add(Value lhs, Value rhs);
@@ -114,6 +117,15 @@ struct ArithBuilder {
private:
OpBuilder &b;
Location loc;
+ arith::IntegerOverflowFlags ovf;
+};
+
+/// ArithBuilder specialized specifically for tensor/memref indexing
+/// calculations. Those calculations generally should never signed overflow, so
+/// we can set oveflow flags accordingly.
+struct ArithIndexingBuilder : public ArithBuilder {
+ ArithIndexingBuilder(OpBuilder &b, Location loc)
+ : ArithBuilder(b, loc, arith::IntegerOverflowFlags::nsw) {}
};
namespace arith {
diff --git a/mlir/lib/Dialect/Arith/Utils/Utils.cpp b/mlir/lib/Dialect/Arith/Utils/Utils.cpp
index 8dde9866b22b3..6b1074e454bd5 100644
--- a/mlir/lib/Dialect/Arith/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/Arith/Utils/Utils.cpp
@@ -315,17 +315,17 @@ Value ArithBuilder::_and(Value lhs, Value rhs) {
Value ArithBuilder::add(Value lhs, Value rhs) {
if (isa<FloatType>(lhs.getType()))
return b.create<arith::AddFOp>(loc, lhs, rhs);
- return b.create<arith::AddIOp>(loc, lhs, rhs);
+ return b.create<arith::AddIOp>(loc, lhs, rhs, ovf);
}
Value ArithBuilder::sub(Value lhs, Value rhs) {
if (isa<FloatType>(lhs.getType()))
return b.create<arith::SubFOp>(loc, lhs, rhs);
- return b.create<arith::SubIOp>(loc, lhs, rhs);
+ return b.create<arith::SubIOp>(loc, lhs, rhs, ovf);
}
Value ArithBuilder::mul(Value lhs, Value rhs) {
if (isa<FloatType>(lhs.getType()))
return b.create<arith::MulFOp>(loc, lhs, rhs);
- return b.create<arith::MulIOp>(loc, lhs, rhs);
+ return b.create<arith::MulIOp>(loc, lhs, rhs, ovf);
}
Value ArithBuilder::sgt(Value lhs, Value rhs) {
if (isa<FloatType>(lhs.getType()))
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
index cd49bc7dca8af..55d9b6bdff3a6 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
@@ -1102,15 +1102,14 @@ class ExtractOpFromLoad final : public OpRewritePattern<vector::ExtractOp> {
OpBuilder::InsertionGuard g(rewriter);
rewriter.setInsertionPoint(loadOp);
Location loc = loadOp.getLoc();
+ ArithIndexingBuilder idxBuilderf(rewriter, loc);
for (auto i : llvm::seq<int64_t>(rankOffset, indices.size() - finalRank)) {
OpFoldResult pos = extractPos[i - rankOffset];
if (isConstantIntValue(pos, 0))
continue;
Value offset = getValueOrCreateConstantIndexOp(rewriter, loc, pos);
-
- auto ovf = arith::IntegerOverflowFlags::nsw;
- indices[i] = rewriter.create<arith::AddIOp>(loc, indices[i], offset, ovf);
+ indices[i] = idxBuilderf.add(indices[i], offset);
}
Value base = loadOp.getBase();
>From 3668826c94d5bb562502b301acc41296cc43e5c6 Mon Sep 17 00:00:00 2001
From: Ivan Butygin <ivan.butygin at gmail.com>
Date: Sun, 13 Apr 2025 02:31:44 +0200
Subject: [PATCH 5/9] rename pattern
---
mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp | 7 ++++---
1 file changed, 4 insertions(+), 3 deletions(-)
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
index 55d9b6bdff3a6..6b57fb7c9fd9c 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
@@ -1136,7 +1136,8 @@ class ExtractOpFromLoad final : public OpRewritePattern<vector::ExtractOp> {
/// ```
/// memref.store %arg2, %arg0[%arg1] : memref<?xf32>
/// ```
-class StoreFromSplat final : public OpRewritePattern<vector::StoreOp> {
+class StoreFromSplatOrBroadcast final
+ : public OpRewritePattern<vector::StoreOp> {
public:
using OpRewritePattern::OpRewritePattern;
@@ -2245,8 +2246,8 @@ void mlir::vector::populateSinkVectorOpsPatterns(RewritePatternSet &patterns,
void mlir::vector::populateSinkVectorMemOpsPatterns(RewritePatternSet &patterns,
PatternBenefit benefit) {
- patterns.add<ExtractOpFromLoad, StoreFromSplat>(patterns.getContext(),
- benefit);
+ patterns.add<ExtractOpFromLoad, StoreFromSplatOrBroadcast>(
+ patterns.getContext(), benefit);
}
void mlir::vector::populateChainedVectorReductionFoldingPatterns(
>From 9b7af3ab491788d358b81e01033dc9f4dfd0bce2 Mon Sep 17 00:00:00 2001
From: Ivan Butygin <ivan.butygin at gmail.com>
Date: Sun, 13 Apr 2025 02:39:25 +0200
Subject: [PATCH 6/9] fix doc
---
.../mlir/Dialect/Vector/TransformOps/VectorTransformOps.td | 1 +
1 file changed, 1 insertion(+)
diff --git a/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td b/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td
index 7fbb437908866..2e3494d970230 100644
--- a/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td
+++ b/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td
@@ -491,6 +491,7 @@ def ApplySinkVectorMemPatternsOp : Op<Transform_Dialect,
%c1 = arith.constant 1 : index
%0 = arith.addi %arg1, %c1 overflow<nsw> : index
%1 = memref.load %arg0[%0] : memref<?xf32>
+ ```
}];
let assemblyFormat = "attr-dict";
>From 1b5b4082a7c273fcb1896460d354dd52e3df6990 Mon Sep 17 00:00:00 2001
From: Ivan Butygin <ivan.butygin at gmail.com>
Date: Thu, 17 Apr 2025 22:59:00 +0200
Subject: [PATCH 7/9] review comments
---
.../Vector/Transforms/VectorTransforms.cpp | 14 +++++++-------
mlir/test/Dialect/Vector/vector-sink.mlir | 18 +++++++++---------
2 files changed, 16 insertions(+), 16 deletions(-)
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
index 6b57fb7c9fd9c..029bbe631e258 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
@@ -1088,10 +1088,10 @@ class ExtractOpFromLoad final : public OpRewritePattern<vector::ExtractOp> {
if (rankOffset < 0)
return rewriter.notifyMatchFailure(op, "unsupported ranks combination");
- auto resVecType = dyn_cast<VectorType>(op.getResult().getType());
+ auto extractVecType = dyn_cast<VectorType>(op.getResult().getType());
int64_t finalRank = 0;
- if (resVecType)
- finalRank = resVecType.getRank();
+ if (extractVecType)
+ finalRank = extractVecType.getRank();
SmallVector<Value> indices = loadOp.getIndices();
SmallVector<OpFoldResult> extractPos = op.getMixedPosition();
@@ -1113,8 +1113,8 @@ class ExtractOpFromLoad final : public OpRewritePattern<vector::ExtractOp> {
}
Value base = loadOp.getBase();
- if (resVecType) {
- rewriter.replaceOpWithNewOp<vector::LoadOp>(op, resVecType, base,
+ if (extractVecType) {
+ rewriter.replaceOpWithNewOp<vector::LoadOp>(op, extractVecType, base,
indices);
} else {
rewriter.replaceOpWithNewOp<memref::LoadOp>(op, base, indices);
@@ -1136,7 +1136,7 @@ class ExtractOpFromLoad final : public OpRewritePattern<vector::ExtractOp> {
/// ```
/// memref.store %arg2, %arg0[%arg1] : memref<?xf32>
/// ```
-class StoreFromSplatOrBroadcast final
+class StoreOpFromSplatOrBroadcast final
: public OpRewritePattern<vector::StoreOp> {
public:
using OpRewritePattern::OpRewritePattern;
@@ -2246,7 +2246,7 @@ void mlir::vector::populateSinkVectorOpsPatterns(RewritePatternSet &patterns,
void mlir::vector::populateSinkVectorMemOpsPatterns(RewritePatternSet &patterns,
PatternBenefit benefit) {
- patterns.add<ExtractOpFromLoad, StoreFromSplatOrBroadcast>(
+ patterns.add<ExtractOpFromLoad, StoreOpFromSplatOrBroadcast>(
patterns.getContext(), benefit);
}
diff --git a/mlir/test/Dialect/Vector/vector-sink.mlir b/mlir/test/Dialect/Vector/vector-sink.mlir
index 6b060984e9f63..0bdfd1c8a83bf 100644
--- a/mlir/test/Dialect/Vector/vector-sink.mlir
+++ b/mlir/test/Dialect/Vector/vector-sink.mlir
@@ -551,9 +551,9 @@ func.func @extract_load_scalar_dyn_off(%arg0: memref<?xf32>, %arg1: index, %arg2
return %1 : f32
}
-// CHECK-LABEL: @extract_load_vec
+// CHECK-LABEL: @extract_load_vec_non_zero_off
// CHECK-SAME: (%[[ARG0:.*]]: memref<?x?xf32>, %[[ARG1:.*]]: index, %[[ARG2:.*]]: index)
-func.func @extract_load_vec(%arg0: memref<?x?xf32>, %arg1: index, %arg2: index) -> vector<4xf32> {
+func.func @extract_load_vec_non_zero_off(%arg0: memref<?x?xf32>, %arg1: index, %arg2: index) -> vector<4xf32> {
// CHECK: %[[C1:.*]] = arith.constant 1 : index
// CHECK: %[[OFF:.*]] = arith.addi %[[ARG1]], %[[C1]] overflow<nsw> : index
// CHECK: %[[RES:.*]] = vector.load %[[ARG0]][%[[OFF]], %[[ARG2]]] : memref<?x?xf32>, vector<4xf32>
@@ -563,9 +563,9 @@ func.func @extract_load_vec(%arg0: memref<?x?xf32>, %arg1: index, %arg2: index)
return %1 : vector<4xf32>
}
-// CHECK-LABEL: @extract_load_scalar_high_rank
+// CHECK-LABEL: @extract_load_scalar_non_zero_off_2d_src_memref
// CHECK-SAME: (%[[ARG0:.*]]: memref<?x?xf32>, %[[ARG1:.*]]: index, %[[ARG2:.*]]: index)
-func.func @extract_load_scalar_high_rank(%arg0: memref<?x?xf32>, %arg1: index, %arg2: index) -> f32 {
+func.func @extract_load_scalar_non_zero_off_2d_src_memref(%arg0: memref<?x?xf32>, %arg1: index, %arg2: index) -> f32 {
// CHECK: %[[C1:.*]] = arith.constant 1 : index
// CHECK: %[[OFF:.*]] = arith.addi %[[ARG2]], %[[C1]] overflow<nsw> : index
// CHECK: %[[RES:.*]] = memref.load %[[ARG0]][%[[ARG1]], %[[OFF]]] : memref<?x?xf32>
@@ -587,9 +587,9 @@ func.func @extract_load_vec_high_rank(%arg0: memref<?x?x?xf32>, %arg1: index, %a
return %1 : vector<4xf32>
}
-// CHECK-LABEL: @negative_extract_load_scalar_from_vec_memref
+// CHECK-LABEL: @negative_extract_load_scalar_from_memref_of_vec
// CHECK-SAME: (%[[ARG0:.*]]: memref<?xvector<4xf32>>, %[[ARG1:.*]]: index)
-func.func @negative_extract_load_scalar_from_vec_memref(%arg0: memref<?xvector<4xf32>>, %arg1: index) -> f32 {
+func.func @negative_extract_load_scalar_from_memref_of_vec(%arg0: memref<?xvector<4xf32>>, %arg1: index) -> f32 {
// CHECK: %[[RES:.*]] = vector.load %[[ARG0]][%[[ARG1]]] : memref<?xvector<4xf32>>, vector<4xf32>
// CHECK: %[[EXT:.*]] = vector.extract %[[RES]][0] : f32 from vector<4xf32>
// CHECK: return %[[EXT]] : f32
@@ -621,7 +621,7 @@ func.func @negative_extract_load_scalable(%arg0: memref<?xf32>, %arg1: index) ->
}
//-----------------------------------------------------------------------------
-// [Pattern: StoreFromSplat]
+// [Pattern: StoreOpFromSplatOrBroadcast]
//-----------------------------------------------------------------------------
// CHECK-LABEL: @store_splat
@@ -661,9 +661,9 @@ func.func @negative_store_scalable(%arg0: memref<?xf32>, %arg1: index, %arg2: f3
return
}
-// CHECK-LABEL: @negative_store_vec_memref
+// CHECK-LABEL: @negative_store_memref_of_vec
// CHECK-SAME: (%[[ARG0:.*]]: memref<?xvector<1xf32>>, %[[ARG1:.*]]: index, %[[ARG2:.*]]: f32)
-func.func @negative_store_vec_memref(%arg0: memref<?xvector<1xf32>>, %arg1: index, %arg2: f32) {
+func.func @negative_store_memref_of_vec(%arg0: memref<?xvector<1xf32>>, %arg1: index, %arg2: f32) {
// CHECK: %[[RES:.*]] = vector.splat %[[ARG2]] : vector<1xf32>
// CHECK: vector.store %[[RES]], %[[ARG0]][%[[ARG1]]] : memref<?xvector<1xf32>>, vector<1xf32>
%0 = vector.splat %arg2 : vector<1xf32>
>From cfaef9df2d3339ff7755ead57d13b37a62f12c54 Mon Sep 17 00:00:00 2001
From: Ivan Butygin <ivan.butygin at gmail.com>
Date: Thu, 17 Apr 2025 23:14:23 +0200
Subject: [PATCH 8/9] ignore non-byte-aligned types
---
.../Vector/Transforms/VectorTransforms.cpp | 13 ++++++++---
mlir/test/Dialect/Vector/vector-sink.mlir | 22 +++++++++++++++++++
2 files changed, 32 insertions(+), 3 deletions(-)
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
index 029bbe631e258..be3f1f0cafe1e 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
@@ -1047,6 +1047,14 @@ class ExtractOpFromElementwise final
}
};
+static bool isSupportedMemSinkElementType(Type type) {
+ if (isa<IndexType>(type))
+ return true;
+
+ // Non-byte-aligned types are tricky, skip them.
+ return type.isIntOrFloat() && type.getIntOrFloatBitWidth() % 8 == 0;
+}
+
/// Pattern to rewrite vector.extract(vector.load) -> vector/memref.load.
///
/// Example:
@@ -1080,9 +1088,8 @@ class ExtractOpFromLoad final : public OpRewritePattern<vector::ExtractOp> {
"scalable vectors are not supported");
MemRefType memType = loadOp.getMemRefType();
- if (isa<VectorType>(memType.getElementType()))
- return rewriter.notifyMatchFailure(
- op, "memrefs of vectors are not supported");
+ if (!isSupportedMemSinkElementType(memType.getElementType()))
+ return rewriter.notifyMatchFailure(op, "unsupported memref element type");
int64_t rankOffset = memType.getRank() - loadVecType.getRank();
if (rankOffset < 0)
diff --git a/mlir/test/Dialect/Vector/vector-sink.mlir b/mlir/test/Dialect/Vector/vector-sink.mlir
index 0bdfd1c8a83bf..900ad99bb4a4c 100644
--- a/mlir/test/Dialect/Vector/vector-sink.mlir
+++ b/mlir/test/Dialect/Vector/vector-sink.mlir
@@ -528,6 +528,16 @@ func.func @extract_load_scalar(%arg0: memref<?xf32>, %arg1: index) -> f32 {
return %1 : f32
}
+// CHECK-LABEL: @extract_load_index
+// CHECK-SAME: (%[[ARG0:.*]]: memref<?xindex>, %[[ARG1:.*]]: index)
+func.func @extract_load_index(%arg0: memref<?xindex>, %arg1: index) -> index {
+// CHECK: %[[RES:.*]] = memref.load %[[ARG0]][%[[ARG1]]] : memref<?xindex>
+// CHECK: return %[[RES]] : index
+ %0 = vector.load %arg0[%arg1] : memref<?xindex>, vector<4xindex>
+ %1 = vector.extract %0[0] : index from vector<4xindex>
+ return %1 : index
+}
+
// CHECK-LABEL: @extract_load_scalar_non_zero_off
// CHECK-SAME: (%[[ARG0:.*]]: memref<?xf32>, %[[ARG1:.*]]: index)
func.func @extract_load_scalar_non_zero_off(%arg0: memref<?xf32>, %arg1: index) -> f32 {
@@ -598,6 +608,18 @@ func.func @negative_extract_load_scalar_from_memref_of_vec(%arg0: memref<?xvecto
return %1 : f32
}
+// CHECK-LABEL: @negative_extract_load_scalar_from_memref_of_i1
+// CHECK-SAME: (%[[ARG0:.*]]: memref<?xi1>, %[[ARG1:.*]]: index)
+func.func @negative_extract_load_scalar_from_memref_of_i1(%arg0: memref<?xi1>, %arg1: index) -> i1 {
+// Subbyte types are tricky, ignore them for now.
+// CHECK: %[[RES:.*]] = vector.load %[[ARG0]][%[[ARG1]]] : memref<?xi1>, vector<8xi1>
+// CHECK: %[[EXT:.*]] = vector.extract %[[RES]][0] : i1 from vector<8xi1>
+// CHECK: return %[[EXT]] : i1
+ %0 = vector.load %arg0[%arg1] : memref<?xi1>, vector<8xi1>
+ %1 = vector.extract %0[0] : i1 from vector<8xi1>
+ return %1 : i1
+}
+
// CHECK-LABEL: @negative_extract_load_no_single_use
// CHECK-SAME: (%[[ARG0:.*]]: memref<?xf32>, %[[ARG1:.*]]: index)
func.func @negative_extract_load_no_single_use(%arg0: memref<?xf32>, %arg1: index) -> (f32, vector<4xf32>) {
>From a959b60b785bb9884db0d0343f7614c113621e52 Mon Sep 17 00:00:00 2001
From: Ivan Butygin <ivan.butygin at gmail.com>
Date: Tue, 22 Apr 2025 15:20:54 +0200
Subject: [PATCH 9/9] review comments
---
mlir/include/mlir/Dialect/Arith/Utils/Utils.h | 4 ++--
.../Vector/TransformOps/VectorTransformOps.td | 11 ++++++++---
.../Vector/Transforms/VectorTransforms.cpp | 15 +++++++++++----
3 files changed, 21 insertions(+), 9 deletions(-)
diff --git a/mlir/include/mlir/Dialect/Arith/Utils/Utils.h b/mlir/include/mlir/Dialect/Arith/Utils/Utils.h
index 34a8c3be0f3a8..c0b286494996b 100644
--- a/mlir/include/mlir/Dialect/Arith/Utils/Utils.h
+++ b/mlir/include/mlir/Dialect/Arith/Utils/Utils.h
@@ -121,8 +121,8 @@ struct ArithBuilder {
};
/// ArithBuilder specialized specifically for tensor/memref indexing
-/// calculations. Those calculations generally should never signed overflow, so
-/// we can set oveflow flags accordingly.
+/// calculations. Those calculations generally should never signed overflow and
+/// always use signed integers, so we can set oveflow flags accordingly.
struct ArithIndexingBuilder : public ArithBuilder {
ArithIndexingBuilder(OpBuilder &b, Location loc)
: ArithBuilder(b, loc, arith::IntegerOverflowFlags::nsw) {}
diff --git a/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td b/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td
index 2e3494d970230..14cbbac99d9ae 100644
--- a/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td
+++ b/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td
@@ -458,7 +458,9 @@ def ApplySinkVectorPatternsOp : Op<Transform_Dialect,
[DeclareOpInterfaceMethods<PatternDescriptorOpInterface>]> {
let description = [{
Patterns that remove redundant Vector Ops by re-ordering them with
- e.g. elementwise Ops:
+ e.g. elementwise Ops.
+
+ Example:
```
%at = vector.transpose %a, [1, 0]: vector<4x2xf32> to vector<2x4xf32>
%bt = vector.transpose %b, [1, 0]: vector<4x2xf32> to vector<2x4xf32>
@@ -480,8 +482,11 @@ def ApplySinkVectorMemPatternsOp : Op<Transform_Dialect,
"apply_patterns.vector.sink_mem_ops",
[DeclareOpInterfaceMethods<PatternDescriptorOpInterface>]> {
let description = [{
- Patterns that remove redundant Vector Ops by merging them with load/store
- ops
+ Patterns that replace redundant Vector Ops (followed by
+ `vector.load`/`vector.store`) with either vector.load/vector.store or
+ `memref.load`/`memref.store`. Currently limited to 1-element vectors.
+
+ Example:
```
vector.load %arg0[%arg1] : memref<?xf32>, vector<4xf32>
vector.extract %0[1] : f32 from vector<4xf32>
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
index be3f1f0cafe1e..b94c5fce64f83 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
@@ -1047,15 +1047,18 @@ class ExtractOpFromElementwise final
}
};
+/// Check if the element type is suitable for vector.load/store sinking.
+/// Element type must be index or byte-aligned integer or floating-point type.
static bool isSupportedMemSinkElementType(Type type) {
if (isa<IndexType>(type))
return true;
- // Non-byte-aligned types are tricky, skip them.
return type.isIntOrFloat() && type.getIntOrFloatBitWidth() % 8 == 0;
}
-/// Pattern to rewrite vector.extract(vector.load) -> vector/memref.load.
+/// Pattern to rewrite `vector.extract(vector.load) -> vector/memref.load.
+/// Only index and byte-aligned integer and floating-point element types are
+/// supported for now.
///
/// Example:
/// ```
@@ -1088,8 +1091,11 @@ class ExtractOpFromLoad final : public OpRewritePattern<vector::ExtractOp> {
"scalable vectors are not supported");
MemRefType memType = loadOp.getMemRefType();
+
+ // Non-byte-aligned types are tricky and may require special handling,
+ // ignore them for now.
if (!isSupportedMemSinkElementType(memType.getElementType()))
- return rewriter.notifyMatchFailure(op, "unsupported memref element type");
+ return rewriter.notifyMatchFailure(op, "unsupported element type");
int64_t rankOffset = memType.getRank() - loadVecType.getRank();
if (rankOffset < 0)
@@ -1161,7 +1167,7 @@ class StoreOpFromSplatOrBroadcast final
if (vecType.getNumElements() != 1)
return rewriter.notifyMatchFailure(
- op, "only 1-element, vectors are supported");
+ op, "only 1-element vectors are supported");
Operation *splat = op.getValueToStore().getDefiningOp();
if (!isa_and_present<vector::BroadcastOp, vector::SplatOp>(splat))
@@ -2253,6 +2259,7 @@ void mlir::vector::populateSinkVectorOpsPatterns(RewritePatternSet &patterns,
void mlir::vector::populateSinkVectorMemOpsPatterns(RewritePatternSet &patterns,
PatternBenefit benefit) {
+ // TODO: Consider converting these patterns to canonicalizations.
patterns.add<ExtractOpFromLoad, StoreOpFromSplatOrBroadcast>(
patterns.getContext(), benefit);
}