[Mlir-commits] [mlir] [mlir][Vector] Remove more special case uses for extractelement/insertelement (PR #130166)
Kunwar Grover
llvmlistbot at llvm.org
Thu Mar 6 11:44:36 PST 2025
https://github.com/Groverkss created https://github.com/llvm/llvm-project/pull/130166
None
From 58a09e2ab50b9895433e024e40f07a3968925173 Mon Sep 17 00:00:00 2001
From: Kunwar Grover <groverkss at gmail.com>
Date: Thu, 6 Mar 2025 19:43:27 +0000
Subject: [PATCH] [mlir][Vector] Remove more special case uses for extractelement/insertelement
---
.../mlir/Dialect/Vector/IR/VectorOps.td | 2 ++
mlir/lib/Dialect/Vector/IR/VectorOps.cpp | 32 +++++++++++--------
.../Transforms/LowerVectorBroadcast.cpp | 6 +---
.../Transforms/LowerVectorShapeCast.cpp | 22 ++-----------
.../Transforms/VectorTransferOpTransforms.cpp | 13 ++------
.../VectorToLLVM/vector-to-llvm.mlir | 2 +-
mlir/test/Dialect/Vector/canonicalize.mlir | 2 +-
.../scalar-vector-transfer-to-memref.mlir | 8 ++---
...vector-shape-cast-lowering-transforms.mlir | 4 +--
9 files changed, 32 insertions(+), 59 deletions(-)
diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
index fbbf817ecff98..2f5436f353539 100644
--- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
@@ -718,6 +718,7 @@ def Vector_ExtractOp :
let results = (outs AnyType:$result);
let builders = [
+ OpBuilder<(ins "Value":$source)>,
OpBuilder<(ins "Value":$source, "int64_t":$position)>,
OpBuilder<(ins "Value":$source, "OpFoldResult":$position)>,
OpBuilder<(ins "Value":$source, "ArrayRef<int64_t>":$position)>,
@@ -913,6 +914,7 @@ def Vector_InsertOp :
let results = (outs AnyVectorOfAnyRank:$result);
let builders = [
+ OpBuilder<(ins "Value":$source, "Value":$dest)>,
OpBuilder<(ins "Value":$source, "Value":$dest, "int64_t":$position)>,
OpBuilder<(ins "Value":$source, "Value":$dest, "OpFoldResult":$position)>,
OpBuilder<(ins "Value":$source, "Value":$dest, "ArrayRef<int64_t>":$position)>,
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 8e0e723cf4ed3..860778fc9db38 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -560,11 +560,9 @@ struct ElideUnitDimsInMultiDimReduction
} else {
// This means we are reducing all the dimensions, and all reduction
// dimensions are of size 1. So a simple extraction would do.
- SmallVector<int64_t> zeroIdx(shape.size(), 0);
if (mask)
- mask = rewriter.create<vector::ExtractOp>(loc, mask, zeroIdx);
- cast = rewriter.create<vector::ExtractOp>(loc, reductionOp.getSource(),
- zeroIdx);
+ mask = rewriter.create<vector::ExtractOp>(loc, mask);
+ cast = rewriter.create<vector::ExtractOp>(loc, reductionOp.getSource());
}
Value result =
@@ -698,16 +696,9 @@ struct ElideSingleElementReduction : public OpRewritePattern<ReductionOp> {
return failure();
Location loc = reductionOp.getLoc();
- Value result;
- if (vectorType.getRank() == 0) {
- if (mask)
- mask = rewriter.create<ExtractElementOp>(loc, mask);
- result = rewriter.create<ExtractElementOp>(loc, reductionOp.getVector());
- } else {
- if (mask)
- mask = rewriter.create<ExtractOp>(loc, mask, 0);
- result = rewriter.create<ExtractOp>(loc, reductionOp.getVector(), 0);
- }
+ if (mask)
+ mask = rewriter.create<ExtractOp>(loc, mask);
+ Value result = rewriter.create<ExtractOp>(loc, reductionOp.getVector());
if (Value acc = reductionOp.getAcc())
result = vector::makeArithReduction(rewriter, loc, reductionOp.getKind(),
@@ -1294,6 +1285,12 @@ void ExtractOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
setResultRanges(getResult(), argRanges.front());
}
+void vector::ExtractOp::build(OpBuilder &builder, OperationState &result,
+ Value source) {
+ auto vectorTy = cast<VectorType>(source.getType());
+ build(builder, result, source, SmallVector<int64_t>(vectorTy.getRank(), 0));
+}
+
void vector::ExtractOp::build(OpBuilder &builder, OperationState &result,
Value source, int64_t position) {
build(builder, result, source, ArrayRef<int64_t>{position});
@@ -2916,6 +2913,13 @@ void vector::InsertOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
setResultRanges(getResult(), argRanges[0].rangeUnion(argRanges[1]));
}
+void vector::InsertOp::build(OpBuilder &builder, OperationState &result,
+ Value source, Value dest) {
+ auto vectorTy = cast<VectorType>(dest.getType());
+ build(builder, result, source, dest,
+ SmallVector<int64_t>(vectorTy.getRank(), 0));
+}
+
void vector::InsertOp::build(OpBuilder &builder, OperationState &result,
Value source, Value dest, int64_t position) {
build(builder, result, source, dest, ArrayRef<int64_t>{position});
diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorBroadcast.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorBroadcast.cpp
index fec3c6c52e5e4..11dcfe421e0c4 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorBroadcast.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorBroadcast.cpp
@@ -52,11 +52,7 @@ class BroadcastOpLowering : public OpRewritePattern<vector::BroadcastOp> {
// Stretching scalar inside vector (e.g. vector<1xf32>) can use splat.
if (srcRank <= 1 && dstRank == 1) {
- Value ext;
- if (srcRank == 0)
- ext = rewriter.create<vector::ExtractElementOp>(loc, op.getSource());
- else
- ext = rewriter.create<vector::ExtractOp>(loc, op.getSource(), 0);
+ Value ext = rewriter.create<vector::ExtractOp>(loc, op.getSource());
rewriter.replaceOpWithNewOp<vector::SplatOp>(op, dstType, ext);
return success();
}
diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorShapeCast.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorShapeCast.cpp
index 9c1e5fcee91de..23324a007377e 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorShapeCast.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorShapeCast.cpp
@@ -189,25 +189,9 @@ class ShapeCastOpRewritePattern : public OpRewritePattern<vector::ShapeCastOp> {
incIdx(resIdx, resultVectorType);
}
- Value extract;
- if (srcRank == 0) {
- // 0-D vector special case
- assert(srcIdx.empty() && "Unexpected indices for 0-D vector");
- extract = rewriter.create<vector::ExtractElementOp>(
- loc, op.getSourceVectorType().getElementType(), op.getSource());
- } else {
- extract =
- rewriter.create<vector::ExtractOp>(loc, op.getSource(), srcIdx);
- }
-
- if (resRank == 0) {
- // 0-D vector special case
- assert(resIdx.empty() && "Unexpected indices for 0-D vector");
- result = rewriter.create<vector::InsertElementOp>(loc, extract, result);
- } else {
- result =
- rewriter.create<vector::InsertOp>(loc, extract, result, resIdx);
- }
+ Value extract =
+ rewriter.create<vector::ExtractOp>(loc, op.getSource(), srcIdx);
+ result = rewriter.create<vector::InsertOp>(loc, extract, result, resIdx);
}
rewriter.replaceOp(op, result);
return success();
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
index 2413a4126f3f7..074c2d5664f64 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
@@ -920,17 +920,8 @@ class RewriteScalarWrite : public OpRewritePattern<vector::TransferWriteOp> {
if (!xferOp.getPermutationMap().isMinorIdentity())
return failure();
// Only float and integer element types are supported.
- Value scalar;
- if (vecType.getRank() == 0) {
- // vector.extract does not support vector<f32> etc., so use
- // vector.extractelement instead.
- scalar = rewriter.create<vector::ExtractElementOp>(xferOp.getLoc(),
- xferOp.getVector());
- } else {
- SmallVector<int64_t> pos(vecType.getRank(), 0);
- scalar = rewriter.create<vector::ExtractOp>(xferOp.getLoc(),
- xferOp.getVector(), pos);
- }
+ Value scalar =
+ rewriter.create<vector::ExtractOp>(xferOp.getLoc(), xferOp.getVector());
// Construct a scalar store.
if (isa<MemRefType>(xferOp.getSource().getType())) {
rewriter.replaceOpWithNewOp<memref::StoreOp>(
diff --git a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
index 1ab28b9df2d19..52a2224b963f2 100644
--- a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
+++ b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
@@ -187,7 +187,7 @@ func.func @broadcast_vec2d_from_vec0d(%arg0: vector<f32>) -> vector<3x2xf32> {
// CHECK: %[[T0:.*]] = builtin.unrealized_conversion_cast %[[A]] : vector<f32> to vector<1xf32>
// CHECK: %[[T1:.*]] = ub.poison : vector<3x2xf32>
// CHECK: %[[T2:.*]] = builtin.unrealized_conversion_cast %[[T1]] : vector<3x2xf32> to !llvm.array<3 x vector<2xf32>>
-// CHECK: %[[T4:.*]] = llvm.mlir.constant(0 : index) : i64
+// CHECK: %[[T4:.*]] = llvm.mlir.constant(0 : i64) : i64
// CHECK: %[[T5:.*]] = llvm.extractelement %[[T0]][%[[T4]] : i64] : vector<1xf32>
// CHECK: %[[T6Insert:.*]] = llvm.insertelement %[[T5]]
// CHECK: %[[T6:.*]] = llvm.shufflevector %[[T6Insert]]
diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir
index bf755b466c7eb..8bb6593d99058 100644
--- a/mlir/test/Dialect/Vector/canonicalize.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize.mlir
@@ -2658,7 +2658,7 @@ func.func @fold_extractelement_of_broadcast(%f: f32) -> f32 {
// CHECK-LABEL: func.func @fold_0d_vector_reduction
func.func @fold_0d_vector_reduction(%arg0: vector<f32>) -> f32 {
- // CHECK-NEXT: %[[RES:.*]] = vector.extractelement %arg{{.*}}[] : vector<f32>
+ // CHECK-NEXT: %[[RES:.*]] = vector.extract %arg{{.*}}[] : f32 from vector<f32>
// CHECK-NEXT: return %[[RES]] : f32
%0 = vector.reduction <add>, %arg0 : vector<f32> into f32
return %0 : f32
diff --git a/mlir/test/Dialect/Vector/scalar-vector-transfer-to-memref.mlir b/mlir/test/Dialect/Vector/scalar-vector-transfer-to-memref.mlir
index b4ebb14b8829e..52b0fdee184f6 100644
--- a/mlir/test/Dialect/Vector/scalar-vector-transfer-to-memref.mlir
+++ b/mlir/test/Dialect/Vector/scalar-vector-transfer-to-memref.mlir
@@ -45,9 +45,7 @@ func.func @tensor_transfer_read_0d(%t: tensor<?x?x?xf32>, %idx: index) -> f32 {
// CHECK-LABEL: func @transfer_write_0d(
// CHECK-SAME: %[[m:.*]]: memref<?x?x?xf32>, %[[idx:.*]]: index, %[[f:.*]]: f32
-// CHECK: %[[bc:.*]] = vector.broadcast %[[f]] : f32 to vector<f32>
-// CHECK: %[[extract:.*]] = vector.extractelement %[[bc]][] : vector<f32>
-// CHECK: memref.store %[[extract]], %[[m]][%[[idx]], %[[idx]], %[[idx]]]
+// CHECK: memref.store %[[f]], %[[m]][%[[idx]], %[[idx]], %[[idx]]]
func.func @transfer_write_0d(%m: memref<?x?x?xf32>, %idx: index, %f: f32) {
%0 = vector.broadcast %f : f32 to vector<f32>
vector.transfer_write %0, %m[%idx, %idx, %idx] : vector<f32>, memref<?x?x?xf32>
@@ -69,9 +67,7 @@ func.func @transfer_write_1d(%m: memref<?x?x?xf32>, %idx: index, %f: f32) {
// CHECK-LABEL: func @tensor_transfer_write_0d(
// CHECK-SAME: %[[t:.*]]: tensor<?x?x?xf32>, %[[idx:.*]]: index, %[[f:.*]]: f32
-// CHECK: %[[bc:.*]] = vector.broadcast %[[f]] : f32 to vector<f32>
-// CHECK: %[[extract:.*]] = vector.extractelement %[[bc]][] : vector<f32>
-// CHECK: %[[r:.*]] = tensor.insert %[[extract]] into %[[t]][%[[idx]], %[[idx]], %[[idx]]]
+// CHECK: %[[r:.*]] = tensor.insert %[[f]] into %[[t]][%[[idx]], %[[idx]], %[[idx]]]
// CHECK: return %[[r]]
func.func @tensor_transfer_write_0d(%t: tensor<?x?x?xf32>, %idx: index, %f: f32) -> tensor<?x?x?xf32> {
%0 = vector.broadcast %f : f32 to vector<f32>
diff --git a/mlir/test/Dialect/Vector/vector-shape-cast-lowering-transforms.mlir b/mlir/test/Dialect/Vector/vector-shape-cast-lowering-transforms.mlir
index ab30acf68b30b..ef32f8c6a1cdb 100644
--- a/mlir/test/Dialect/Vector/vector-shape-cast-lowering-transforms.mlir
+++ b/mlir/test/Dialect/Vector/vector-shape-cast-lowering-transforms.mlir
@@ -117,7 +117,7 @@ func.func @shape_cast_1d3d(%arg0 : vector<6xf32>) -> vector<2x1x3xf32> {
// CHECK-LABEL: func.func @shape_cast_0d1d(
// CHECK-SAME: %[[ARG0:.*]]: vector<f32>) -> vector<1xf32> {
// CHECK: %[[UB:.*]] = ub.poison : vector<1xf32>
-// CHECK: %[[EXTRACT0:.*]] = vector.extractelement %[[ARG0]][] : vector<f32>
+// CHECK: %[[EXTRACT0:.*]] = vector.extract %[[ARG0]][] : f32 from vector<f32>
// CHECK: %[[RES:.*]] = vector.insert %[[EXTRACT0]], %[[UB]] [0] : f32 into vector<1xf32>
// CHECK: return %[[RES]] : vector<1xf32>
// CHECK: }
@@ -131,7 +131,7 @@ func.func @shape_cast_0d1d(%arg0 : vector<f32>) -> vector<1xf32> {
// CHECK-SAME: %[[ARG0:.*]]: vector<1xf32>) -> vector<f32> {
// CHECK: %[[UB:.*]] = ub.poison : vector<f32>
// CHECK: %[[EXTRACT0:.*]] = vector.extract %[[ARG0]][0] : f32 from vector<1xf32>
-// CHECK: %[[RES:.*]] = vector.insertelement %[[EXTRACT0]], %[[UB]][] : vector<f32>
+// CHECK: %[[RES:.*]] = vector.insert %[[EXTRACT0]], %[[UB]] [] : f32 into vector<f32>
// CHECK: return %[[RES]] : vector<f32>
// CHECK: }
More information about the Mlir-commits
mailing list