[Mlir-commits] [mlir] [mlir][Vector] Remove `vector.extractelement/insertelement` from sparse vectorizer (PR #143270)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Sat Jun 7 07:32:49 PDT 2025


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir

Author: Diego Caballero (dcaballe)

<details>
<summary>Changes</summary>

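This PR is part of the final step toward removing the `vector.extractelement` and `vector.insertelement` ops.
It updates the sparse vectorizer to use `vector.extract` and `vector.insert` instead of
`vector.extractelement` and `vector.insertelement`. Because `vector.insert` exposes its scalar
operand as `valueToStore` rather than `source`, the templated `ReducChainRewriter` is replaced by
two patterns, `ReducChainInsertRewriter` and `ReducChainBroadcastRewriter`, which share a common
`cleanReducChain` helper. The affected FileCheck tests now expect a static `[0]` insert position.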

---
Full diff: https://github.com/llvm/llvm-project/pull/143270.diff


5 Files Affected:

- (modified) mlir/lib/Dialect/SparseTensor/Transforms/SparseVectorization.cpp (+37-20) 
- (modified) mlir/test/Dialect/SparseTensor/minipipeline_vector.mlir (+1-1) 
- (modified) mlir/test/Dialect/SparseTensor/sparse_vector.mlir (+3-3) 
- (modified) mlir/test/Dialect/SparseTensor/sparse_vector_chain.mlir (+1-1) 
- (modified) mlir/test/Dialect/SparseTensor/vectorize_reduction.mlir (+5-5) 


``````````diff
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseVectorization.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseVectorization.cpp
index 3d963dea2f572..482720ba8aed1 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseVectorization.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseVectorization.cpp
@@ -198,12 +198,12 @@ static Value genVectorReducInit(PatternRewriter &rewriter, Location loc,
   case vector::CombiningKind::ADD:
   case vector::CombiningKind::XOR:
     // Initialize reduction vector to: | 0 | .. | 0 | r |
-    return rewriter.create<vector::InsertElementOp>(
+    return rewriter.create<vector::InsertOp>(
         loc, r, constantZero(rewriter, loc, vtp),
         constantIndex(rewriter, loc, 0));
   case vector::CombiningKind::MUL:
     // Initialize reduction vector to: | 1 | .. | 1 | r |
-    return rewriter.create<vector::InsertElementOp>(
+    return rewriter.create<vector::InsertOp>(
         loc, r, constantOne(rewriter, loc, vtp),
         constantIndex(rewriter, loc, 0));
   case vector::CombiningKind::AND:
@@ -628,31 +628,48 @@ struct ForOpRewriter : public OpRewritePattern<scf::ForOp> {
   const VL vl;
 };
 
+static LogicalResult cleanReducChain(PatternRewriter &rewriter, Operation *op,
+                                     Value inp) {
+  if (auto redOp = inp.getDefiningOp<vector::ReductionOp>()) {
+    if (auto forOp = redOp.getVector().getDefiningOp<scf::ForOp>()) {
+      if (forOp->hasAttr(LoopEmitter::getLoopEmitterLoopAttrName())) {
+        rewriter.replaceOp(op, redOp.getVector());
+        return success();
+      }
+    }
+  }
+  return failure();
+}
+
 /// Reduction chain cleanup.
 ///   v = for { }
-///   s = vsum(v)               v = for { }
-///   u = expand(s)       ->    for (v) { }
+///   s = vsum(v)                  v = for { }
+///   u = broadcast(s)       ->    for (v) { }
 ///   for (u) { }
-template <typename VectorOp>
-struct ReducChainRewriter : public OpRewritePattern<VectorOp> {
+struct ReducChainBroadcastRewriter : public OpRewritePattern<vector::BroadcastOp> {
 public:
-  using OpRewritePattern<VectorOp>::OpRewritePattern;
+  using OpRewritePattern<vector::BroadcastOp>::OpRewritePattern;
 
-  LogicalResult matchAndRewrite(VectorOp op,
+  LogicalResult matchAndRewrite(vector::BroadcastOp op,
                                 PatternRewriter &rewriter) const override {
-    Value inp = op.getSource();
-    if (auto redOp = inp.getDefiningOp<vector::ReductionOp>()) {
-      if (auto forOp = redOp.getVector().getDefiningOp<scf::ForOp>()) {
-        if (forOp->hasAttr(LoopEmitter::getLoopEmitterLoopAttrName())) {
-          rewriter.replaceOp(op, redOp.getVector());
-          return success();
-        }
-      }
-    }
-    return failure();
+    return cleanReducChain(rewriter, op, op.getSource());
   }
 };
 
+/// Reduction chain cleanup.
+///   v = for { }
+///   s = vsum(v)               v = for { }
+///   u = insert(s)       ->    for (v) { }
+///   for (u) { }
+struct ReducChainInsertRewriter : public OpRewritePattern<vector::InsertOp> {
+public:
+  using OpRewritePattern<vector::InsertOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(vector::InsertOp op,
+                                PatternRewriter &rewriter) const override {
+    return cleanReducChain(rewriter, op, op.getValueToStore());
+  }
+};
 } // namespace
 
 //===----------------------------------------------------------------------===//
@@ -668,6 +685,6 @@ void mlir::populateSparseVectorizationPatterns(RewritePatternSet &patterns,
   vector::populateVectorStepLoweringPatterns(patterns);
   patterns.add<ForOpRewriter>(patterns.getContext(), vectorLength,
                               enableVLAVectorization, enableSIMDIndex32);
-  patterns.add<ReducChainRewriter<vector::InsertElementOp>,
-               ReducChainRewriter<vector::BroadcastOp>>(patterns.getContext());
+  patterns.add<ReducChainInsertRewriter, ReducChainBroadcastRewriter>(
+      patterns.getContext());
 }
diff --git a/mlir/test/Dialect/SparseTensor/minipipeline_vector.mlir b/mlir/test/Dialect/SparseTensor/minipipeline_vector.mlir
index 2475aa5139da4..b2dfbeb53fde8 100755
--- a/mlir/test/Dialect/SparseTensor/minipipeline_vector.mlir
+++ b/mlir/test/Dialect/SparseTensor/minipipeline_vector.mlir
@@ -22,7 +22,7 @@
 // CHECK-NOVEC:       }
 //
 // CHECK-VEC-LABEL: func.func @sum_reduction
-// CHECK-VEC:       vector.insertelement
+// CHECK-VEC:       vector.insert
 // CHECK-VEC:       scf.for
 // CHECK-VEC:         vector.create_mask
 // CHECK-VEC:         vector.maskedload
diff --git a/mlir/test/Dialect/SparseTensor/sparse_vector.mlir b/mlir/test/Dialect/SparseTensor/sparse_vector.mlir
index 364ba6e71ff3b..64235c7227800 100644
--- a/mlir/test/Dialect/SparseTensor/sparse_vector.mlir
+++ b/mlir/test/Dialect/SparseTensor/sparse_vector.mlir
@@ -241,7 +241,7 @@ func.func @mul_s(%arga: tensor<1024xf32, #SparseVector>,
 // CHECK-VEC16-DAG:   %[[c1024:.*]] = arith.constant 1024 : index
 // CHECK-VEC16-DAG:   %[[v0:.*]] = arith.constant dense<0.000000e+00> : vector<16xf32>
 // CHECK-VEC16:       %[[l:.*]] = memref.load %{{.*}}[] : memref<f32>
-// CHECK-VEC16:       %[[r:.*]] = vector.insertelement %[[l]], %[[v0]][%[[c0]] : index] : vector<16xf32>
+// CHECK-VEC16:       %[[r:.*]] = vector.insert %[[l]], %[[v0]] [0] : f32 into vector<16xf32>
 // CHECK-VEC16:       %[[red:.*]] = scf.for %[[i:.*]] = %[[c0]] to %[[c1024]] step %[[c16]] iter_args(%[[red_in:.*]] = %[[r]]) -> (vector<16xf32>) {
 // CHECK-VEC16:         %[[la:.*]] = vector.load %{{.*}}[%[[i]]] : memref<?xf32>, vector<16xf32>
 // CHECK-VEC16:         %[[lb:.*]] = vector.load %{{.*}}[%[[i]]] : memref<1024xf32>, vector<16xf32>
@@ -258,7 +258,7 @@ func.func @mul_s(%arga: tensor<1024xf32, #SparseVector>,
 // CHECK-VEC16-IDX32-DAG:   %[[c1024:.*]] = arith.constant 1024 : index
 // CHECK-VEC16-IDX32-DAG:   %[[v0:.*]] = arith.constant dense<0.000000e+00> : vector<16xf32>
 // CHECK-VEC16-IDX32:       %[[l:.*]] = memref.load %{{.*}}[] : memref<f32>
-// CHECK-VEC16-IDX32:       %[[r:.*]] = vector.insertelement %[[l]], %[[v0]][%[[c0]] : index] : vector<16xf32>
+// CHECK-VEC16-IDX32:       %[[r:.*]] = vector.insert %[[l]], %[[v0]] [0] : f32 into vector<16xf32>
 // CHECK-VEC16-IDX32:       %[[red:.*]] = scf.for %[[i:.*]] = %[[c0]] to %[[c1024]] step %[[c16]] iter_args(%[[red_in:.*]] = %[[r]]) -> (vector<16xf32>) {
 // CHECK-VEC16-IDX32:         %[[la:.*]] = vector.load %{{.*}}[%[[i]]] : memref<?xf32>, vector<16xf32>
 // CHECK-VEC16-IDX32:         %[[lb:.*]] = vector.load %{{.*}}[%[[i]]] : memref<1024xf32>, vector<16xf32>
@@ -278,7 +278,7 @@ func.func @mul_s(%arga: tensor<1024xf32, #SparseVector>,
 // CHECK-VEC4-SVE:       %[[l:.*]] = memref.load %{{.*}}[] : memref<f32>
 // CHECK-VEC4-SVE:       %[[vscale:.*]] = vector.vscale
 // CHECK-VEC4-SVE:       %[[step:.*]] = arith.muli %[[vscale]], %[[c4]] : index
-// CHECK-VEC4-SVE:       %[[r:.*]] = vector.insertelement %[[l]], %[[v0]][%[[c0]] : index] : vector<[4]xf32>
+// CHECK-VEC4-SVE:       %[[r:.*]] = vector.insert %[[l]], %[[v0]] [0] : f32 into vector<[4]xf32>
 // CHECK-VEC4-SVE:       %[[red:.*]] = scf.for %[[i:.*]] = %[[c0]] to %[[c1024]] step %[[step]] iter_args(%[[red_in:.*]] = %[[r]]) -> (vector<[4]xf32>) {
 // CHECK-VEC4-SVE:         %[[sub:.*]] = affine.min #[[$map]](%[[c1024]], %[[i]])[%[[step]]]
 // CHECK-VEC4-SVE:         %[[mask:.*]] = vector.create_mask %[[sub]] : vector<[4]xi1>
diff --git a/mlir/test/Dialect/SparseTensor/sparse_vector_chain.mlir b/mlir/test/Dialect/SparseTensor/sparse_vector_chain.mlir
index f4b565c7f9c8a..0ab72897d7bc3 100644
--- a/mlir/test/Dialect/SparseTensor/sparse_vector_chain.mlir
+++ b/mlir/test/Dialect/SparseTensor/sparse_vector_chain.mlir
@@ -82,7 +82,7 @@
 // CHECK:               %[[VAL_57:.*]] = arith.select %[[VAL_39]], %[[VAL_56]], %[[VAL_32]] : index
 // CHECK:               scf.yield %[[VAL_55]], %[[VAL_57]], %[[VAL_58:.*]] : index, index, f64
 // CHECK:             } attributes {"Emitted from" = "linalg.generic"}
-// CHECK:             %[[VAL_59:.*]] = vector.insertelement %[[VAL_60:.*]]#2, %[[VAL_4]]{{\[}}%[[VAL_6]] : index] : vector<8xf64>
+// CHECK:             %[[VAL_59:.*]] = vector.insert %[[VAL_60:.*]]#2, %[[VAL_4]] [0] : f64 into vector<8xf64>
 // CHECK:             %[[VAL_61:.*]] = scf.for %[[VAL_62:.*]] = %[[VAL_60]]#0 to %[[VAL_21]] step %[[VAL_3]] iter_args(%[[VAL_63:.*]] = %[[VAL_59]]) -> (vector<8xf64>) {
 // CHECK:               %[[VAL_64:.*]] = affine.min #map(%[[VAL_21]], %[[VAL_62]]){{\[}}%[[VAL_3]]]
 // CHECK:               %[[VAL_65:.*]] = vector.create_mask %[[VAL_64]] : vector<8xi1>
diff --git a/mlir/test/Dialect/SparseTensor/vectorize_reduction.mlir b/mlir/test/Dialect/SparseTensor/vectorize_reduction.mlir
index 01b717090e87a..6effbbf98abb7 100644
--- a/mlir/test/Dialect/SparseTensor/vectorize_reduction.mlir
+++ b/mlir/test/Dialect/SparseTensor/vectorize_reduction.mlir
@@ -172,7 +172,7 @@ func.func @sparse_reduction_ori_accumulator_on_rhs(%argx: tensor<i13>,
 // CHECK-ON:           %[[VAL_9:.*]] = memref.load %[[VAL_8]][] : memref<i32>
 // CHECK-ON:           %[[VAL_10:.*]] = memref.load %[[VAL_6]]{{\[}}%[[VAL_3]]] : memref<?xindex>
 // CHECK-ON:           %[[VAL_11:.*]] = memref.load %[[VAL_6]]{{\[}}%[[VAL_5]]] : memref<?xindex>
-// CHECK-ON:           %[[VAL_12:.*]] = vector.insertelement %[[VAL_9]], %[[VAL_4]]{{\[}}%[[VAL_3]] : index] : vector<8xi32>
+// CHECK-ON:           %[[VAL_12:.*]] = vector.insert %[[VAL_9]], %[[VAL_4]] [0] : i32 into vector<8xi32>
 // CHECK-ON:           %[[VAL_13:.*]] = scf.for %[[VAL_14:.*]] = %[[VAL_10]] to %[[VAL_11]] step %[[VAL_2]] iter_args(%[[VAL_15:.*]] = %[[VAL_12]]) -> (vector<8xi32>) {
 // CHECK-ON:             %[[VAL_16:.*]] = affine.min #map(%[[VAL_11]], %[[VAL_14]]){{\[}}%[[VAL_2]]]
 // CHECK-ON:             %[[VAL_17:.*]] = vector.create_mask %[[VAL_16]] : vector<8xi1>
@@ -247,7 +247,7 @@ func.func @sparse_reduction_subi(%argx: tensor<i32>,
 // CHECK-ON:  %[[VAL_9:.*]] = memref.load %[[VAL_8]][] : memref<i32>
 // CHECK-ON:  %[[VAL_10:.*]] = memref.load %[[VAL_6]]{{\[}}%[[VAL_4]]] : memref<?xindex>
 // CHECK-ON:  %[[VAL_11:.*]] = memref.load %[[VAL_6]]{{\[}}%[[VAL_5]]] : memref<?xindex>
-// CHECK-ON:  %[[VAL_12:.*]] = vector.insertelement %[[VAL_9]], %[[VAL_3]]{{\[}}%[[VAL_4]] : index] : vector<8xi32>
+// CHECK-ON:  %[[VAL_12:.*]] = vector.insert %[[VAL_9]], %[[VAL_3]] [0] : i32 into vector<8xi32>
 // CHECK-ON:  %[[VAL_13:.*]] = scf.for %[[VAL_14:.*]] = %[[VAL_10]] to %[[VAL_11]] step %[[VAL_2]] iter_args(%[[VAL_15:.*]] = %[[VAL_12]]) -> (vector<8xi32>) {
 // CHECK-ON:    %[[VAL_16:.*]] = affine.min #map(%[[VAL_11]], %[[VAL_14]]){{\[}}%[[VAL_2]]]
 // CHECK-ON:    %[[VAL_17:.*]] = vector.create_mask %[[VAL_16]] : vector<8xi1>
@@ -323,7 +323,7 @@ func.func @sparse_reduction_xor(%argx: tensor<i32>,
 // CHECK-ON:   %[[VAL_9:.*]] = memref.load %[[VAL_8]][] : memref<i32>
 // CHECK-ON:   %[[VAL_10:.*]] = memref.load %[[VAL_6]]{{\[}}%[[VAL_4]]] : memref<?xindex>
 // CHECK-ON:   %[[VAL_11:.*]] = memref.load %[[VAL_6]]{{\[}}%[[VAL_5]]] : memref<?xindex>
-// CHECK-ON:   %[[VAL_12:.*]] = vector.insertelement %[[VAL_9]], %[[VAL_3]]{{\[}}%[[VAL_4]] : index] : vector<8xi32>
+// CHECK-ON:   %[[VAL_12:.*]] = vector.insert %[[VAL_9]], %[[VAL_3]] [0] : i32 into vector<8xi32>
 // CHECK-ON:   %[[VAL_13:.*]] = scf.for %[[VAL_14:.*]] = %[[VAL_10]] to %[[VAL_11]] step %[[VAL_2]] iter_args(%[[VAL_15:.*]] = %[[VAL_12]]) -> (vector<8xi32>) {
 // CHECK-ON:     %[[VAL_16:.*]] = affine.min #map(%[[VAL_11]], %[[VAL_14]]){{\[}}%[[VAL_2]]]
 // CHECK-ON:     %[[VAL_17:.*]] = vector.create_mask %[[VAL_16]] : vector<8xi1>
@@ -399,7 +399,7 @@ func.func @sparse_reduction_addi(%argx: tensor<i32>,
 // CHECK-ON:   %[[VAL_9:.*]] = memref.load %[[VAL_8]][] : memref<f32>
 // CHECK-ON:   %[[VAL_10:.*]] = memref.load %[[VAL_6]]{{\[}}%[[VAL_4]]] : memref<?xindex>
 // CHECK-ON:   %[[VAL_11:.*]] = memref.load %[[VAL_6]]{{\[}}%[[VAL_5]]] : memref<?xindex>
-// CHECK-ON:   %[[VAL_12:.*]] = vector.insertelement %[[VAL_9]], %[[VAL_3]]{{\[}}%[[VAL_4]] : index] : vector<8xf32>
+// CHECK-ON:   %[[VAL_12:.*]] = vector.insert %[[VAL_9]], %[[VAL_3]] [0] : f32 into vector<8xf32>
 // CHECK-ON:   %[[VAL_13:.*]] = scf.for %[[VAL_14:.*]] = %[[VAL_10]] to %[[VAL_11]] step %[[VAL_2]] iter_args(%[[VAL_15:.*]] = %[[VAL_12]]) -> (vector<8xf32>) {
 // CHECK-ON:     %[[VAL_16:.*]] = affine.min #map(%[[VAL_11]], %[[VAL_14]]){{\[}}%[[VAL_2]]]
 // CHECK-ON:     %[[VAL_17:.*]] = vector.create_mask %[[VAL_16]] : vector<8xi1>
@@ -475,7 +475,7 @@ func.func @sparse_reduction_subf(%argx: tensor<f32>,
 // CHECK-ON:   %[[VAL_9:.*]] = memref.load %[[VAL_8]][] : memref<f32>
 // CHECK-ON:   %[[VAL_10:.*]] = memref.load %[[VAL_6]]{{\[}}%[[VAL_4]]] : memref<?xindex>
 // CHECK-ON:   %[[VAL_11:.*]] = memref.load %[[VAL_6]]{{\[}}%[[VAL_5]]] : memref<?xindex>
-// CHECK-ON:   %[[VAL_12:.*]] = vector.insertelement %[[VAL_9]], %[[VAL_3]]{{\[}}%[[VAL_4]] : index] : vector<8xf32>
+// CHECK-ON:   %[[VAL_12:.*]] = vector.insert %[[VAL_9]], %[[VAL_3]] [0] : f32 into vector<8xf32>
 // CHECK-ON:   %[[VAL_13:.*]] = scf.for %[[VAL_14:.*]] = %[[VAL_10]] to %[[VAL_11]] step %[[VAL_2]] iter_args(%[[VAL_15:.*]] = %[[VAL_12]]) -> (vector<8xf32>) {
 // CHECK-ON:     %[[VAL_16:.*]] = affine.min #map(%[[VAL_11]], %[[VAL_14]]){{\[}}%[[VAL_2]]]
 // CHECK-ON:     %[[VAL_17:.*]] = vector.create_mask %[[VAL_16]] : vector<8xi1>

``````````

</details>


https://github.com/llvm/llvm-project/pull/143270


More information about the Mlir-commits mailing list