[Mlir-commits] [mlir] [mlir][Vector] Remove more special case uses for extractelement/insertelement (PR #130166)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Thu Mar 6 12:12:56 PST 2025


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir

Author: Kunwar Grover (Groverkss)

<details>
<summary>Changes</summary>

A number of places in our codebase special-case 0-D vectors to use vector.extractelement/vector.insertelement, because vector.extract/vector.insert previously did not support 0-D vectors. Now that vector.extract/vector.insert support 0-D vectors, use them directly instead of special-casing.

---
Full diff: https://github.com/llvm/llvm-project/pull/130166.diff


9 Files Affected:

- (modified) mlir/include/mlir/Dialect/Vector/IR/VectorOps.td (+2) 
- (modified) mlir/lib/Dialect/Vector/IR/VectorOps.cpp (+18-14) 
- (modified) mlir/lib/Dialect/Vector/Transforms/LowerVectorBroadcast.cpp (+1-5) 
- (modified) mlir/lib/Dialect/Vector/Transforms/LowerVectorShapeCast.cpp (+3-19) 
- (modified) mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp (+2-11) 
- (modified) mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir (+1-1) 
- (modified) mlir/test/Dialect/Vector/canonicalize.mlir (+1-1) 
- (modified) mlir/test/Dialect/Vector/scalar-vector-transfer-to-memref.mlir (+2-6) 
- (modified) mlir/test/Dialect/Vector/vector-shape-cast-lowering-transforms.mlir (+2-2) 


``````````diff
diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
index fbbf817ecff98..2f5436f353539 100644
--- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
@@ -718,6 +718,7 @@ def Vector_ExtractOp :
   let results = (outs AnyType:$result);
 
   let builders = [
+    OpBuilder<(ins "Value":$source)>,
     OpBuilder<(ins "Value":$source, "int64_t":$position)>,
     OpBuilder<(ins "Value":$source, "OpFoldResult":$position)>,
     OpBuilder<(ins "Value":$source, "ArrayRef<int64_t>":$position)>,
@@ -913,6 +914,7 @@ def Vector_InsertOp :
   let results = (outs AnyVectorOfAnyRank:$result);
 
   let builders = [
+    OpBuilder<(ins "Value":$source, "Value":$dest)>,
     OpBuilder<(ins "Value":$source, "Value":$dest, "int64_t":$position)>,
     OpBuilder<(ins "Value":$source, "Value":$dest, "OpFoldResult":$position)>,
     OpBuilder<(ins "Value":$source, "Value":$dest, "ArrayRef<int64_t>":$position)>,
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 8e0e723cf4ed3..860778fc9db38 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -560,11 +560,9 @@ struct ElideUnitDimsInMultiDimReduction
     } else {
       // This means we are reducing all the dimensions, and all reduction
       // dimensions are of size 1. So a simple extraction would do.
-      SmallVector<int64_t> zeroIdx(shape.size(), 0);
       if (mask)
-        mask = rewriter.create<vector::ExtractOp>(loc, mask, zeroIdx);
-      cast = rewriter.create<vector::ExtractOp>(loc, reductionOp.getSource(),
-                                                zeroIdx);
+        mask = rewriter.create<vector::ExtractOp>(loc, mask);
+      cast = rewriter.create<vector::ExtractOp>(loc, reductionOp.getSource());
     }
 
     Value result =
@@ -698,16 +696,9 @@ struct ElideSingleElementReduction : public OpRewritePattern<ReductionOp> {
       return failure();
 
     Location loc = reductionOp.getLoc();
-    Value result;
-    if (vectorType.getRank() == 0) {
-      if (mask)
-        mask = rewriter.create<ExtractElementOp>(loc, mask);
-      result = rewriter.create<ExtractElementOp>(loc, reductionOp.getVector());
-    } else {
-      if (mask)
-        mask = rewriter.create<ExtractOp>(loc, mask, 0);
-      result = rewriter.create<ExtractOp>(loc, reductionOp.getVector(), 0);
-    }
+    if (mask)
+      mask = rewriter.create<ExtractOp>(loc, mask);
+    Value result = rewriter.create<ExtractOp>(loc, reductionOp.getVector());
 
     if (Value acc = reductionOp.getAcc())
       result = vector::makeArithReduction(rewriter, loc, reductionOp.getKind(),
@@ -1294,6 +1285,12 @@ void ExtractOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
   setResultRanges(getResult(), argRanges.front());
 }
 
+void vector::ExtractOp::build(OpBuilder &builder, OperationState &result,
+                              Value source) {
+  auto vectorTy = cast<VectorType>(source.getType()); // Asserts on non-vector.
+  build(builder, result, source, SmallVector<int64_t>(vectorTy.getRank(), 0));
+} // All-zero full-rank position ([] for 0-D): extracts the element at 0s.
+
 void vector::ExtractOp::build(OpBuilder &builder, OperationState &result,
                               Value source, int64_t position) {
   build(builder, result, source, ArrayRef<int64_t>{position});
@@ -2916,6 +2913,13 @@ void vector::InsertOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
   setResultRanges(getResult(), argRanges[0].rangeUnion(argRanges[1]));
 }
 
+void vector::InsertOp::build(OpBuilder &builder, OperationState &result,
+                             Value source, Value dest) {
+  auto vectorTy = cast<VectorType>(dest.getType()); // Asserts on non-vector.
+  build(builder, result, source, dest,
+        SmallVector<int64_t>(vectorTy.getRank(), 0)); // 0s; [] for 0-D.
+}
+
 void vector::InsertOp::build(OpBuilder &builder, OperationState &result,
                              Value source, Value dest, int64_t position) {
   build(builder, result, source, dest, ArrayRef<int64_t>{position});
diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorBroadcast.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorBroadcast.cpp
index fec3c6c52e5e4..11dcfe421e0c4 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorBroadcast.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorBroadcast.cpp
@@ -52,11 +52,7 @@ class BroadcastOpLowering : public OpRewritePattern<vector::BroadcastOp> {
 
     // Stretching scalar inside vector (e.g. vector<1xf32>) can use splat.
     if (srcRank <= 1 && dstRank == 1) {
-      Value ext;
-      if (srcRank == 0)
-        ext = rewriter.create<vector::ExtractElementOp>(loc, op.getSource());
-      else
-        ext = rewriter.create<vector::ExtractOp>(loc, op.getSource(), 0);
+      Value ext = rewriter.create<vector::ExtractOp>(loc, op.getSource());
       rewriter.replaceOpWithNewOp<vector::SplatOp>(op, dstType, ext);
       return success();
     }
diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorShapeCast.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorShapeCast.cpp
index 9c1e5fcee91de..23324a007377e 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorShapeCast.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorShapeCast.cpp
@@ -189,25 +189,9 @@ class ShapeCastOpRewritePattern : public OpRewritePattern<vector::ShapeCastOp> {
         incIdx(resIdx, resultVectorType);
       }
 
-      Value extract;
-      if (srcRank == 0) {
-        // 0-D vector special case
-        assert(srcIdx.empty() && "Unexpected indices for 0-D vector");
-        extract = rewriter.create<vector::ExtractElementOp>(
-            loc, op.getSourceVectorType().getElementType(), op.getSource());
-      } else {
-        extract =
-            rewriter.create<vector::ExtractOp>(loc, op.getSource(), srcIdx);
-      }
-
-      if (resRank == 0) {
-        // 0-D vector special case
-        assert(resIdx.empty() && "Unexpected indices for 0-D vector");
-        result = rewriter.create<vector::InsertElementOp>(loc, extract, result);
-      } else {
-        result =
-            rewriter.create<vector::InsertOp>(loc, extract, result, resIdx);
-      }
+      Value extract =
+          rewriter.create<vector::ExtractOp>(loc, op.getSource(), srcIdx);
+      result = rewriter.create<vector::InsertOp>(loc, extract, result, resIdx);
     }
     rewriter.replaceOp(op, result);
     return success();
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
index 2413a4126f3f7..074c2d5664f64 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
@@ -920,17 +920,8 @@ class RewriteScalarWrite : public OpRewritePattern<vector::TransferWriteOp> {
     if (!xferOp.getPermutationMap().isMinorIdentity())
       return failure();
     // Only float and integer element types are supported.
-    Value scalar;
-    if (vecType.getRank() == 0) {
-      // vector.extract does not support vector<f32> etc., so use
-      // vector.extractelement instead.
-      scalar = rewriter.create<vector::ExtractElementOp>(xferOp.getLoc(),
-                                                         xferOp.getVector());
-    } else {
-      SmallVector<int64_t> pos(vecType.getRank(), 0);
-      scalar = rewriter.create<vector::ExtractOp>(xferOp.getLoc(),
-                                                  xferOp.getVector(), pos);
-    }
+    Value scalar =
+        rewriter.create<vector::ExtractOp>(xferOp.getLoc(), xferOp.getVector());
     // Construct a scalar store.
     if (isa<MemRefType>(xferOp.getSource().getType())) {
       rewriter.replaceOpWithNewOp<memref::StoreOp>(
diff --git a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
index 1ab28b9df2d19..52a2224b963f2 100644
--- a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
+++ b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
@@ -187,7 +187,7 @@ func.func @broadcast_vec2d_from_vec0d(%arg0: vector<f32>) -> vector<3x2xf32> {
 //       CHECK: %[[T0:.*]] = builtin.unrealized_conversion_cast %[[A]] : vector<f32> to vector<1xf32>
 //       CHECK: %[[T1:.*]] = ub.poison : vector<3x2xf32>
 //       CHECK: %[[T2:.*]] = builtin.unrealized_conversion_cast %[[T1]] : vector<3x2xf32> to !llvm.array<3 x vector<2xf32>>
-//       CHECK: %[[T4:.*]] = llvm.mlir.constant(0 : index) : i64
+//       CHECK: %[[T4:.*]] = llvm.mlir.constant(0 : i64) : i64
 //       CHECK: %[[T5:.*]] = llvm.extractelement %[[T0]][%[[T4]] : i64] : vector<1xf32>
 //       CHECK: %[[T6Insert:.*]] = llvm.insertelement %[[T5]]
 //       CHECK: %[[T6:.*]] = llvm.shufflevector %[[T6Insert]]
diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir
index bf755b466c7eb..8bb6593d99058 100644
--- a/mlir/test/Dialect/Vector/canonicalize.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize.mlir
@@ -2658,7 +2658,7 @@ func.func @fold_extractelement_of_broadcast(%f: f32) -> f32 {
 
 // CHECK-LABEL: func.func @fold_0d_vector_reduction
 func.func @fold_0d_vector_reduction(%arg0: vector<f32>) -> f32 {
-  // CHECK-NEXT: %[[RES:.*]] = vector.extractelement %arg{{.*}}[] : vector<f32>
+  // CHECK-NEXT: %[[RES:.*]] = vector.extract %arg{{.*}}[] : f32 from vector<f32>
   // CHECK-NEXT: return %[[RES]] : f32
   %0 = vector.reduction <add>, %arg0 : vector<f32> into f32
   return %0 : f32
diff --git a/mlir/test/Dialect/Vector/scalar-vector-transfer-to-memref.mlir b/mlir/test/Dialect/Vector/scalar-vector-transfer-to-memref.mlir
index b4ebb14b8829e..52b0fdee184f6 100644
--- a/mlir/test/Dialect/Vector/scalar-vector-transfer-to-memref.mlir
+++ b/mlir/test/Dialect/Vector/scalar-vector-transfer-to-memref.mlir
@@ -45,9 +45,7 @@ func.func @tensor_transfer_read_0d(%t: tensor<?x?x?xf32>, %idx: index) -> f32 {
 
 // CHECK-LABEL: func @transfer_write_0d(
 //  CHECK-SAME:     %[[m:.*]]: memref<?x?x?xf32>, %[[idx:.*]]: index, %[[f:.*]]: f32
-//       CHECK:   %[[bc:.*]] = vector.broadcast %[[f]] : f32 to vector<f32>
-//       CHECK:   %[[extract:.*]] = vector.extractelement %[[bc]][] : vector<f32>
-//       CHECK:   memref.store %[[extract]], %[[m]][%[[idx]], %[[idx]], %[[idx]]]
+//       CHECK:   memref.store %[[f]], %[[m]][%[[idx]], %[[idx]], %[[idx]]]
 func.func @transfer_write_0d(%m: memref<?x?x?xf32>, %idx: index, %f: f32) {
   %0 = vector.broadcast %f : f32 to vector<f32>
   vector.transfer_write %0, %m[%idx, %idx, %idx] : vector<f32>, memref<?x?x?xf32>
@@ -69,9 +67,7 @@ func.func @transfer_write_1d(%m: memref<?x?x?xf32>, %idx: index, %f: f32) {
 
 // CHECK-LABEL: func @tensor_transfer_write_0d(
 //  CHECK-SAME:     %[[t:.*]]: tensor<?x?x?xf32>, %[[idx:.*]]: index, %[[f:.*]]: f32
-//       CHECK:   %[[bc:.*]] = vector.broadcast %[[f]] : f32 to vector<f32>
-//       CHECK:   %[[extract:.*]] = vector.extractelement %[[bc]][] : vector<f32>
-//       CHECK:   %[[r:.*]] = tensor.insert %[[extract]] into %[[t]][%[[idx]], %[[idx]], %[[idx]]]
+//       CHECK:   %[[r:.*]] = tensor.insert %[[f]] into %[[t]][%[[idx]], %[[idx]], %[[idx]]]
 //       CHECK:   return %[[r]]
 func.func @tensor_transfer_write_0d(%t: tensor<?x?x?xf32>, %idx: index, %f: f32) -> tensor<?x?x?xf32> {
   %0 = vector.broadcast %f : f32 to vector<f32>
diff --git a/mlir/test/Dialect/Vector/vector-shape-cast-lowering-transforms.mlir b/mlir/test/Dialect/Vector/vector-shape-cast-lowering-transforms.mlir
index ab30acf68b30b..ef32f8c6a1cdb 100644
--- a/mlir/test/Dialect/Vector/vector-shape-cast-lowering-transforms.mlir
+++ b/mlir/test/Dialect/Vector/vector-shape-cast-lowering-transforms.mlir
@@ -117,7 +117,7 @@ func.func @shape_cast_1d3d(%arg0 : vector<6xf32>) -> vector<2x1x3xf32> {
 // CHECK-LABEL:   func.func @shape_cast_0d1d(
 // CHECK-SAME:                               %[[ARG0:.*]]: vector<f32>) -> vector<1xf32> {
 // CHECK:           %[[UB:.*]] = ub.poison : vector<1xf32>
-// CHECK:           %[[EXTRACT0:.*]] = vector.extractelement %[[ARG0]][] : vector<f32>
+// CHECK:           %[[EXTRACT0:.*]] = vector.extract %[[ARG0]][] : f32 from vector<f32>
 // CHECK:           %[[RES:.*]] = vector.insert %[[EXTRACT0]], %[[UB]] [0] : f32 into vector<1xf32>
 // CHECK:           return %[[RES]] : vector<1xf32>
 // CHECK:         }
@@ -131,7 +131,7 @@ func.func @shape_cast_0d1d(%arg0 : vector<f32>) -> vector<1xf32> {
 // CHECK-SAME:                               %[[ARG0:.*]]: vector<1xf32>) -> vector<f32> {
 // CHECK:           %[[UB:.*]] = ub.poison : vector<f32>
 // CHECK:           %[[EXTRACT0:.*]] = vector.extract %[[ARG0]][0] : f32 from vector<1xf32>
-// CHECK:           %[[RES:.*]] = vector.insertelement %[[EXTRACT0]], %[[UB]][] : vector<f32>
+// CHECK:           %[[RES:.*]] = vector.insert %[[EXTRACT0]], %[[UB]] [] : f32 into vector<f32>
 // CHECK:           return %[[RES]] : vector<f32>
 // CHECK:         }
 

``````````

</details>


https://github.com/llvm/llvm-project/pull/130166


More information about the Mlir-commits mailing list