[Mlir-commits] [mlir] [mlir][Vector] Remove `vector.extractelement/insertelement` from sparse vectorizer (PR #143270)
llvmlistbot at llvm.org
Sat Jun 7 07:32:49 PDT 2025
llvmbot wrote:
@llvm/pr-subscribers-mlir
Author: Diego Caballero (dcaballe)
This PR is part of the final step in removing the `vector.extractelement` and `vector.insertelement` ops:
it updates the sparse vectorizer to use `vector.extract` and `vector.insert` instead.
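For reference, the change swaps the dynamic-index element ops for their static-index counterparts. A minimal before/after sketch (`%s` and `%v` are placeholder values), mirroring the updated test expectations below:

```mlir
// Before: the lane is a dynamic SSA index operand.
%c0 = arith.constant 0 : index
%r0 = vector.insertelement %s, %v[%c0 : index] : vector<8xf32>

// After: the lane is a static attribute and the scalar type is spelled out.
%r1 = vector.insert %s, %v [0] : f32 into vector<8xf32>
```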
---
Full diff: https://github.com/llvm/llvm-project/pull/143270.diff
5 Files Affected:
- (modified) mlir/lib/Dialect/SparseTensor/Transforms/SparseVectorization.cpp (+37-20)
- (modified) mlir/test/Dialect/SparseTensor/minipipeline_vector.mlir (+1-1)
- (modified) mlir/test/Dialect/SparseTensor/sparse_vector.mlir (+3-3)
- (modified) mlir/test/Dialect/SparseTensor/sparse_vector_chain.mlir (+1-1)
- (modified) mlir/test/Dialect/SparseTensor/vectorize_reduction.mlir (+5-5)
```diff
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseVectorization.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseVectorization.cpp
index 3d963dea2f572..482720ba8aed1 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseVectorization.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseVectorization.cpp
@@ -198,12 +198,12 @@ static Value genVectorReducInit(PatternRewriter &rewriter, Location loc,
case vector::CombiningKind::ADD:
case vector::CombiningKind::XOR:
// Initialize reduction vector to: | 0 | .. | 0 | r |
- return rewriter.create<vector::InsertElementOp>(
+ return rewriter.create<vector::InsertOp>(
loc, r, constantZero(rewriter, loc, vtp),
constantIndex(rewriter, loc, 0));
case vector::CombiningKind::MUL:
// Initialize reduction vector to: | 1 | .. | 1 | r |
- return rewriter.create<vector::InsertElementOp>(
+ return rewriter.create<vector::InsertOp>(
loc, r, constantOne(rewriter, loc, vtp),
constantIndex(rewriter, loc, 0));
case vector::CombiningKind::AND:
@@ -628,31 +628,48 @@ struct ForOpRewriter : public OpRewritePattern<scf::ForOp> {
const VL vl;
};
+static LogicalResult cleanReducChain(PatternRewriter &rewriter, Operation *op,
+ Value inp) {
+ if (auto redOp = inp.getDefiningOp<vector::ReductionOp>()) {
+ if (auto forOp = redOp.getVector().getDefiningOp<scf::ForOp>()) {
+ if (forOp->hasAttr(LoopEmitter::getLoopEmitterLoopAttrName())) {
+ rewriter.replaceOp(op, redOp.getVector());
+ return success();
+ }
+ }
+ }
+ return failure();
+}
+
/// Reduction chain cleanup.
/// v = for { }
-/// s = vsum(v) v = for { }
-/// u = expand(s) -> for (v) { }
+/// s = vsum(v) v = for { }
+/// u = broadcast(s) -> for (v) { }
/// for (u) { }
-template <typename VectorOp>
-struct ReducChainRewriter : public OpRewritePattern<VectorOp> {
+struct ReducChainBroadcastRewriter : public OpRewritePattern<vector::BroadcastOp> {
public:
- using OpRewritePattern<VectorOp>::OpRewritePattern;
+ using OpRewritePattern<vector::BroadcastOp>::OpRewritePattern;
- LogicalResult matchAndRewrite(VectorOp op,
+ LogicalResult matchAndRewrite(vector::BroadcastOp op,
PatternRewriter &rewriter) const override {
- Value inp = op.getSource();
- if (auto redOp = inp.getDefiningOp<vector::ReductionOp>()) {
- if (auto forOp = redOp.getVector().getDefiningOp<scf::ForOp>()) {
- if (forOp->hasAttr(LoopEmitter::getLoopEmitterLoopAttrName())) {
- rewriter.replaceOp(op, redOp.getVector());
- return success();
- }
- }
- }
- return failure();
+ return cleanReducChain(rewriter, op, op.getSource());
}
};
+/// Reduction chain cleanup.
+/// v = for { }
+/// s = vsum(v) v = for { }
+/// u = insert(s) -> for (v) { }
+/// for (u) { }
+struct ReducChainInsertRewriter : public OpRewritePattern<vector::InsertOp> {
+public:
+ using OpRewritePattern<vector::InsertOp>::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(vector::InsertOp op,
+ PatternRewriter &rewriter) const override {
+ return cleanReducChain(rewriter, op, op.getValueToStore());
+ }
+};
} // namespace
//===----------------------------------------------------------------------===//
@@ -668,6 +685,6 @@ void mlir::populateSparseVectorizationPatterns(RewritePatternSet &patterns,
vector::populateVectorStepLoweringPatterns(patterns);
patterns.add<ForOpRewriter>(patterns.getContext(), vectorLength,
enableVLAVectorization, enableSIMDIndex32);
- patterns.add<ReducChainRewriter<vector::InsertElementOp>,
- ReducChainRewriter<vector::BroadcastOp>>(patterns.getContext());
+ patterns.add<ReducChainInsertRewriter, ReducChainBroadcastRewriter>(
+ patterns.getContext());
}
diff --git a/mlir/test/Dialect/SparseTensor/minipipeline_vector.mlir b/mlir/test/Dialect/SparseTensor/minipipeline_vector.mlir
index 2475aa5139da4..b2dfbeb53fde8 100755
--- a/mlir/test/Dialect/SparseTensor/minipipeline_vector.mlir
+++ b/mlir/test/Dialect/SparseTensor/minipipeline_vector.mlir
@@ -22,7 +22,7 @@
// CHECK-NOVEC: }
//
// CHECK-VEC-LABEL: func.func @sum_reduction
-// CHECK-VEC: vector.insertelement
+// CHECK-VEC: vector.insert
// CHECK-VEC: scf.for
// CHECK-VEC: vector.create_mask
// CHECK-VEC: vector.maskedload
diff --git a/mlir/test/Dialect/SparseTensor/sparse_vector.mlir b/mlir/test/Dialect/SparseTensor/sparse_vector.mlir
index 364ba6e71ff3b..64235c7227800 100644
--- a/mlir/test/Dialect/SparseTensor/sparse_vector.mlir
+++ b/mlir/test/Dialect/SparseTensor/sparse_vector.mlir
@@ -241,7 +241,7 @@ func.func @mul_s(%arga: tensor<1024xf32, #SparseVector>,
// CHECK-VEC16-DAG: %[[c1024:.*]] = arith.constant 1024 : index
// CHECK-VEC16-DAG: %[[v0:.*]] = arith.constant dense<0.000000e+00> : vector<16xf32>
// CHECK-VEC16: %[[l:.*]] = memref.load %{{.*}}[] : memref<f32>
-// CHECK-VEC16: %[[r:.*]] = vector.insertelement %[[l]], %[[v0]][%[[c0]] : index] : vector<16xf32>
+// CHECK-VEC16: %[[r:.*]] = vector.insert %[[l]], %[[v0]] [0] : f32 into vector<16xf32>
// CHECK-VEC16: %[[red:.*]] = scf.for %[[i:.*]] = %[[c0]] to %[[c1024]] step %[[c16]] iter_args(%[[red_in:.*]] = %[[r]]) -> (vector<16xf32>) {
// CHECK-VEC16: %[[la:.*]] = vector.load %{{.*}}[%[[i]]] : memref<?xf32>, vector<16xf32>
// CHECK-VEC16: %[[lb:.*]] = vector.load %{{.*}}[%[[i]]] : memref<1024xf32>, vector<16xf32>
@@ -258,7 +258,7 @@ func.func @mul_s(%arga: tensor<1024xf32, #SparseVector>,
// CHECK-VEC16-IDX32-DAG: %[[c1024:.*]] = arith.constant 1024 : index
// CHECK-VEC16-IDX32-DAG: %[[v0:.*]] = arith.constant dense<0.000000e+00> : vector<16xf32>
// CHECK-VEC16-IDX32: %[[l:.*]] = memref.load %{{.*}}[] : memref<f32>
-// CHECK-VEC16-IDX32: %[[r:.*]] = vector.insertelement %[[l]], %[[v0]][%[[c0]] : index] : vector<16xf32>
+// CHECK-VEC16-IDX32: %[[r:.*]] = vector.insert %[[l]], %[[v0]] [0] : f32 into vector<16xf32>
// CHECK-VEC16-IDX32: %[[red:.*]] = scf.for %[[i:.*]] = %[[c0]] to %[[c1024]] step %[[c16]] iter_args(%[[red_in:.*]] = %[[r]]) -> (vector<16xf32>) {
// CHECK-VEC16-IDX32: %[[la:.*]] = vector.load %{{.*}}[%[[i]]] : memref<?xf32>, vector<16xf32>
// CHECK-VEC16-IDX32: %[[lb:.*]] = vector.load %{{.*}}[%[[i]]] : memref<1024xf32>, vector<16xf32>
@@ -278,7 +278,7 @@ func.func @mul_s(%arga: tensor<1024xf32, #SparseVector>,
// CHECK-VEC4-SVE: %[[l:.*]] = memref.load %{{.*}}[] : memref<f32>
// CHECK-VEC4-SVE: %[[vscale:.*]] = vector.vscale
// CHECK-VEC4-SVE: %[[step:.*]] = arith.muli %[[vscale]], %[[c4]] : index
-// CHECK-VEC4-SVE: %[[r:.*]] = vector.insertelement %[[l]], %[[v0]][%[[c0]] : index] : vector<[4]xf32>
+// CHECK-VEC4-SVE: %[[r:.*]] = vector.insert %[[l]], %[[v0]] [0] : f32 into vector<[4]xf32>
// CHECK-VEC4-SVE: %[[red:.*]] = scf.for %[[i:.*]] = %[[c0]] to %[[c1024]] step %[[step]] iter_args(%[[red_in:.*]] = %[[r]]) -> (vector<[4]xf32>) {
// CHECK-VEC4-SVE: %[[sub:.*]] = affine.min #[[$map]](%[[c1024]], %[[i]])[%[[step]]]
// CHECK-VEC4-SVE: %[[mask:.*]] = vector.create_mask %[[sub]] : vector<[4]xi1>
diff --git a/mlir/test/Dialect/SparseTensor/sparse_vector_chain.mlir b/mlir/test/Dialect/SparseTensor/sparse_vector_chain.mlir
index f4b565c7f9c8a..0ab72897d7bc3 100644
--- a/mlir/test/Dialect/SparseTensor/sparse_vector_chain.mlir
+++ b/mlir/test/Dialect/SparseTensor/sparse_vector_chain.mlir
@@ -82,7 +82,7 @@
// CHECK: %[[VAL_57:.*]] = arith.select %[[VAL_39]], %[[VAL_56]], %[[VAL_32]] : index
// CHECK: scf.yield %[[VAL_55]], %[[VAL_57]], %[[VAL_58:.*]] : index, index, f64
// CHECK: } attributes {"Emitted from" = "linalg.generic"}
-// CHECK: %[[VAL_59:.*]] = vector.insertelement %[[VAL_60:.*]]#2, %[[VAL_4]]{{\[}}%[[VAL_6]] : index] : vector<8xf64>
+// CHECK: %[[VAL_59:.*]] = vector.insert %[[VAL_60:.*]]#2, %[[VAL_4]] [0] : f64 into vector<8xf64>
// CHECK: %[[VAL_61:.*]] = scf.for %[[VAL_62:.*]] = %[[VAL_60]]#0 to %[[VAL_21]] step %[[VAL_3]] iter_args(%[[VAL_63:.*]] = %[[VAL_59]]) -> (vector<8xf64>) {
// CHECK: %[[VAL_64:.*]] = affine.min #map(%[[VAL_21]], %[[VAL_62]]){{\[}}%[[VAL_3]]]
// CHECK: %[[VAL_65:.*]] = vector.create_mask %[[VAL_64]] : vector<8xi1>
diff --git a/mlir/test/Dialect/SparseTensor/vectorize_reduction.mlir b/mlir/test/Dialect/SparseTensor/vectorize_reduction.mlir
index 01b717090e87a..6effbbf98abb7 100644
--- a/mlir/test/Dialect/SparseTensor/vectorize_reduction.mlir
+++ b/mlir/test/Dialect/SparseTensor/vectorize_reduction.mlir
@@ -172,7 +172,7 @@ func.func @sparse_reduction_ori_accumulator_on_rhs(%argx: tensor<i13>,
// CHECK-ON: %[[VAL_9:.*]] = memref.load %[[VAL_8]][] : memref<i32>
// CHECK-ON: %[[VAL_10:.*]] = memref.load %[[VAL_6]]{{\[}}%[[VAL_3]]] : memref<?xindex>
// CHECK-ON: %[[VAL_11:.*]] = memref.load %[[VAL_6]]{{\[}}%[[VAL_5]]] : memref<?xindex>
-// CHECK-ON: %[[VAL_12:.*]] = vector.insertelement %[[VAL_9]], %[[VAL_4]]{{\[}}%[[VAL_3]] : index] : vector<8xi32>
+// CHECK-ON: %[[VAL_12:.*]] = vector.insert %[[VAL_9]], %[[VAL_4]] [0] : i32 into vector<8xi32>
// CHECK-ON: %[[VAL_13:.*]] = scf.for %[[VAL_14:.*]] = %[[VAL_10]] to %[[VAL_11]] step %[[VAL_2]] iter_args(%[[VAL_15:.*]] = %[[VAL_12]]) -> (vector<8xi32>) {
// CHECK-ON: %[[VAL_16:.*]] = affine.min #map(%[[VAL_11]], %[[VAL_14]]){{\[}}%[[VAL_2]]]
// CHECK-ON: %[[VAL_17:.*]] = vector.create_mask %[[VAL_16]] : vector<8xi1>
@@ -247,7 +247,7 @@ func.func @sparse_reduction_subi(%argx: tensor<i32>,
// CHECK-ON: %[[VAL_9:.*]] = memref.load %[[VAL_8]][] : memref<i32>
// CHECK-ON: %[[VAL_10:.*]] = memref.load %[[VAL_6]]{{\[}}%[[VAL_4]]] : memref<?xindex>
// CHECK-ON: %[[VAL_11:.*]] = memref.load %[[VAL_6]]{{\[}}%[[VAL_5]]] : memref<?xindex>
-// CHECK-ON: %[[VAL_12:.*]] = vector.insertelement %[[VAL_9]], %[[VAL_3]]{{\[}}%[[VAL_4]] : index] : vector<8xi32>
+// CHECK-ON: %[[VAL_12:.*]] = vector.insert %[[VAL_9]], %[[VAL_3]] [0] : i32 into vector<8xi32>
// CHECK-ON: %[[VAL_13:.*]] = scf.for %[[VAL_14:.*]] = %[[VAL_10]] to %[[VAL_11]] step %[[VAL_2]] iter_args(%[[VAL_15:.*]] = %[[VAL_12]]) -> (vector<8xi32>) {
// CHECK-ON: %[[VAL_16:.*]] = affine.min #map(%[[VAL_11]], %[[VAL_14]]){{\[}}%[[VAL_2]]]
// CHECK-ON: %[[VAL_17:.*]] = vector.create_mask %[[VAL_16]] : vector<8xi1>
@@ -323,7 +323,7 @@ func.func @sparse_reduction_xor(%argx: tensor<i32>,
// CHECK-ON: %[[VAL_9:.*]] = memref.load %[[VAL_8]][] : memref<i32>
// CHECK-ON: %[[VAL_10:.*]] = memref.load %[[VAL_6]]{{\[}}%[[VAL_4]]] : memref<?xindex>
// CHECK-ON: %[[VAL_11:.*]] = memref.load %[[VAL_6]]{{\[}}%[[VAL_5]]] : memref<?xindex>
-// CHECK-ON: %[[VAL_12:.*]] = vector.insertelement %[[VAL_9]], %[[VAL_3]]{{\[}}%[[VAL_4]] : index] : vector<8xi32>
+// CHECK-ON: %[[VAL_12:.*]] = vector.insert %[[VAL_9]], %[[VAL_3]] [0] : i32 into vector<8xi32>
// CHECK-ON: %[[VAL_13:.*]] = scf.for %[[VAL_14:.*]] = %[[VAL_10]] to %[[VAL_11]] step %[[VAL_2]] iter_args(%[[VAL_15:.*]] = %[[VAL_12]]) -> (vector<8xi32>) {
// CHECK-ON: %[[VAL_16:.*]] = affine.min #map(%[[VAL_11]], %[[VAL_14]]){{\[}}%[[VAL_2]]]
// CHECK-ON: %[[VAL_17:.*]] = vector.create_mask %[[VAL_16]] : vector<8xi1>
@@ -399,7 +399,7 @@ func.func @sparse_reduction_addi(%argx: tensor<i32>,
// CHECK-ON: %[[VAL_9:.*]] = memref.load %[[VAL_8]][] : memref<f32>
// CHECK-ON: %[[VAL_10:.*]] = memref.load %[[VAL_6]]{{\[}}%[[VAL_4]]] : memref<?xindex>
// CHECK-ON: %[[VAL_11:.*]] = memref.load %[[VAL_6]]{{\[}}%[[VAL_5]]] : memref<?xindex>
-// CHECK-ON: %[[VAL_12:.*]] = vector.insertelement %[[VAL_9]], %[[VAL_3]]{{\[}}%[[VAL_4]] : index] : vector<8xf32>
+// CHECK-ON: %[[VAL_12:.*]] = vector.insert %[[VAL_9]], %[[VAL_3]] [0] : f32 into vector<8xf32>
// CHECK-ON: %[[VAL_13:.*]] = scf.for %[[VAL_14:.*]] = %[[VAL_10]] to %[[VAL_11]] step %[[VAL_2]] iter_args(%[[VAL_15:.*]] = %[[VAL_12]]) -> (vector<8xf32>) {
// CHECK-ON: %[[VAL_16:.*]] = affine.min #map(%[[VAL_11]], %[[VAL_14]]){{\[}}%[[VAL_2]]]
// CHECK-ON: %[[VAL_17:.*]] = vector.create_mask %[[VAL_16]] : vector<8xi1>
@@ -475,7 +475,7 @@ func.func @sparse_reduction_subf(%argx: tensor<f32>,
// CHECK-ON: %[[VAL_9:.*]] = memref.load %[[VAL_8]][] : memref<f32>
// CHECK-ON: %[[VAL_10:.*]] = memref.load %[[VAL_6]]{{\[}}%[[VAL_4]]] : memref<?xindex>
// CHECK-ON: %[[VAL_11:.*]] = memref.load %[[VAL_6]]{{\[}}%[[VAL_5]]] : memref<?xindex>
-// CHECK-ON: %[[VAL_12:.*]] = vector.insertelement %[[VAL_9]], %[[VAL_3]]{{\[}}%[[VAL_4]] : index] : vector<8xf32>
+// CHECK-ON: %[[VAL_12:.*]] = vector.insert %[[VAL_9]], %[[VAL_3]] [0] : f32 into vector<8xf32>
// CHECK-ON: %[[VAL_13:.*]] = scf.for %[[VAL_14:.*]] = %[[VAL_10]] to %[[VAL_11]] step %[[VAL_2]] iter_args(%[[VAL_15:.*]] = %[[VAL_12]]) -> (vector<8xf32>) {
// CHECK-ON: %[[VAL_16:.*]] = affine.min #map(%[[VAL_11]], %[[VAL_14]]){{\[}}%[[VAL_2]]]
// CHECK-ON: %[[VAL_17:.*]] = vector.create_mask %[[VAL_16]] : vector<8xi1>
```
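For context, the reduction-chain cleanup now factored into `cleanReducChain` targets IR of roughly the following shape (a minimal sketch with hypothetical values; the real input is produced by the sparse vectorizer's loop emitter):

```mlir
// %v feeds a reduction whose scalar result is immediately re-inserted into a
// zero vector to seed the next vectorized loop. The scf.for carries the loop
// emitter attribute that the pattern checks for.
%v = scf.for %i = %lo to %hi step %c8 iter_args(%a = %init) -> (vector<8xf64>) { ... }
%s = vector.reduction <add>, %v : vector<8xf64> into f64
%u = vector.insert %s, %zero [0] : f64 into vector<8xf64>
// The pattern replaces %u with %v, so the follow-up loop consumes %v
// directly and the scalar round trip folds away.
```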
https://github.com/llvm/llvm-project/pull/143270