[Mlir-commits] [mlir] [mlir][Vector] Remove usage of `vector.insertelement/extractelement` from Vector (PR #144413)
Diego Caballero
llvmlistbot at llvm.org
Mon Jun 16 12:19:21 PDT 2025
https://github.com/dcaballe created https://github.com/llvm/llvm-project/pull/144413
This PR is part of the final step towards removing the `vector.extractelement` and `vector.insertelement` ops.
RFC: https://discourse.llvm.org/t/rfc-psa-remove-vector-extractelement-and-vector-insertelement-ops-in-favor-of-vector-extract-and-vector-insert-ops
It removes the uses of `vector.extractelement` and `vector.insertelement` from the Vector dialect layer, replacing them with `vector.extract` and `vector.insert`.
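For reference, the rewrite swaps the dedicated element ops for the dynamic-index forms of `vector.extract` and `vector.insert`. A minimal before/after sketch (illustrative SSA names and a hypothetical vector<4xf32> type, mirroring the test updates below):

    // Before:
    %e = vector.extractelement %v[%i : index] : vector<4xf32>
    %w = vector.insertelement %s, %v[%i : index] : vector<4xf32>

    // After:
    %e = vector.extract %v[%i] : f32 from vector<4xf32>
    %w = vector.insert %s, %v[%i] : f32 into vector<4xf32>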
From 75ae4a467d6282cd7098df66c0587bccfb885c62 Mon Sep 17 00:00:00 2001
From: Diego Caballero <dcaballero at nvidia.com>
Date: Sat, 7 Jun 2025 13:33:42 +0000
Subject: [PATCH] [mlir][Vector] Remove usage of
`vector.insertelement/extractelement` from Vector
This PR is part of the last step to remove `vector.extractelement` and `vector.insertelement` ops.
RFC: https://discourse.llvm.org/t/rfc-psa-remove-vector-extractelement-and-vector-insertelement-ops-in-favor-of-vector-extract-and-vector-insert-ops
It removes instances of `vector.extractelement` and `vector.insertelement`
from the Vector dialect layer.
---
.../Vector/Transforms/VectorRewritePatterns.h | 4 +-
.../Conversion/VectorToSCF/VectorToSCF.cpp | 13 +-
.../Vector/Transforms/VectorDistribute.cpp | 51 +-------
.../Transforms/VectorTransferOpTransforms.cpp | 119 ++++--------------
.../Conversion/VectorToSCF/vector-to-scf.mlir | 14 +--
.../scalar-vector-transfer-to-memref.mlir | 6 +-
.../Vector/vector-warp-distribute.mlir | 48 +++----
7 files changed, 71 insertions(+), 184 deletions(-)
diff --git a/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h b/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
index 34a94e6ea7051..ec0f856cb3f5a 100644
--- a/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
+++ b/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
@@ -233,8 +233,8 @@ void populateBreakDownVectorReductionPatterns(
///
/// [DecomposeNDExtractStridedSlice]
/// ================================
-/// For such cases, we can rewrite it to ExtractOp/ExtractElementOp + lower
-/// rank ExtractStridedSliceOp + InsertOp/InsertElementOp for the n-D case.
+/// For such cases, we can rewrite it to ExtractOp + lower rank
+/// ExtractStridedSliceOp + InsertOp for the n-D case.
void populateVectorInsertExtractStridedSliceDecompositionPatterns(
RewritePatternSet &patterns, PatternBenefit benefit = 1);
diff --git a/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp b/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp
index cc5623068ab10..45059f19a95c4 100644
--- a/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp
+++ b/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp
@@ -157,7 +157,7 @@ static Value generateMaskCheck(OpBuilder &b, OpTy xferOp, Value iv) {
return Value();
Location loc = xferOp.getLoc();
- return b.create<vector::ExtractElementOp>(loc, xferOp.getMask(), iv);
+ return b.create<vector::ExtractOp>(loc, xferOp.getMask(), iv);
}
/// Helper function TransferOpConversion and TransferOp1dConversion.
@@ -760,8 +760,7 @@ struct DecomposePrintOpConversion : public VectorToSCFPattern<vector::PrintOp> {
if (vectorType.getRank() != 1) {
// Flatten n-D vectors to 1D. This is done to allow indexing with a
- // non-constant value (which can currently only be done via
- // vector.extractelement for 1D vectors).
+ // non-constant value.
auto flatLength = std::accumulate(shape.begin(), shape.end(), 1,
std::multiplies<int64_t>());
auto flatVectorType =
@@ -824,8 +823,7 @@ struct DecomposePrintOpConversion : public VectorToSCFPattern<vector::PrintOp> {
}
// Print the scalar elements in the inner most loop.
- auto element =
- rewriter.create<vector::ExtractElementOp>(loc, value, flatIndex);
+ auto element = rewriter.create<vector::ExtractOp>(loc, value, flatIndex);
rewriter.create<vector::PrintOp>(loc, element,
vector::PrintPunctuation::NoPunctuation);
@@ -1567,7 +1565,7 @@ struct Strategy1d<TransferReadOp> {
/*inBoundsCase=*/
[&](OpBuilder &b, Location loc) {
Value val = b.create<memref::LoadOp>(loc, xferOp.getBase(), indices);
- return b.create<vector::InsertElementOp>(loc, val, vec, iv);
+ return b.create<vector::InsertOp>(loc, val, vec, iv);
},
/*outOfBoundsCase=*/
[&](OpBuilder & /*b*/, Location loc) { return vec; });
@@ -1595,8 +1593,7 @@ struct Strategy1d<TransferWriteOp> {
generateInBoundsCheck(
b, xferOp, iv, dim,
/*inBoundsCase=*/[&](OpBuilder &b, Location loc) {
- auto val =
- b.create<vector::ExtractElementOp>(loc, xferOp.getVector(), iv);
+ auto val = b.create<vector::ExtractOp>(loc, xferOp.getVector(), iv);
b.create<memref::StoreOp>(loc, val, xferOp.getBase(), indices);
});
b.create<scf::YieldOp>(loc);
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
index 045c192787f10..90970ae53defc 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
@@ -1255,27 +1255,6 @@ struct WarpOpExtractScalar : public WarpDistributionPattern {
WarpShuffleFromIdxFn warpShuffleFromIdxFn;
};
-/// Pattern to convert vector.extractelement to vector.extract.
-struct WarpOpExtractElement : public WarpDistributionPattern {
- using Base::Base;
- LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
- PatternRewriter &rewriter) const override {
- OpOperand *operand =
- getWarpResult(warpOp, llvm::IsaPred<vector::ExtractElementOp>);
- if (!operand)
- return failure();
- auto extractOp = operand->get().getDefiningOp<vector::ExtractElementOp>();
- SmallVector<OpFoldResult> indices;
- if (auto pos = extractOp.getPosition()) {
- indices.push_back(pos);
- }
- rewriter.setInsertionPoint(extractOp);
- rewriter.replaceOpWithNewOp<vector::ExtractOp>(
- extractOp, extractOp.getVector(), indices);
- return success();
- }
-};
-
/// Pattern to move out vector.insert with a scalar input.
/// Only supports 1-D and 0-D destinations for now.
struct WarpOpInsertScalar : public WarpDistributionPattern {
@@ -1483,26 +1462,6 @@ struct WarpOpInsert : public WarpDistributionPattern {
}
};
-struct WarpOpInsertElement : public WarpDistributionPattern {
- using Base::Base;
- LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
- PatternRewriter &rewriter) const override {
- OpOperand *operand =
- getWarpResult(warpOp, llvm::IsaPred<vector::InsertElementOp>);
- if (!operand)
- return failure();
- auto insertOp = operand->get().getDefiningOp<vector::InsertElementOp>();
- SmallVector<OpFoldResult> indices;
- if (auto pos = insertOp.getPosition()) {
- indices.push_back(pos);
- }
- rewriter.setInsertionPoint(insertOp);
- rewriter.replaceOpWithNewOp<vector::InsertOp>(
- insertOp, insertOp.getSource(), insertOp.getDest(), indices);
- return success();
- }
-};
-
/// Sink scf.for region out of WarpExecuteOnLane0Op. This can be done only if
/// the scf.ForOp is the last operation in the region so that it doesn't
/// change the order of execution. This creates a new scf.for region after the
@@ -1761,11 +1720,11 @@ void mlir::vector::populatePropagateWarpVectorDistributionPatterns(
const WarpShuffleFromIdxFn &warpShuffleFromIdxFn, PatternBenefit benefit,
PatternBenefit readBenefit) {
patterns.add<WarpOpTransferRead>(patterns.getContext(), readBenefit);
- patterns.add<WarpOpElementwise, WarpOpDeadResult, WarpOpBroadcast,
- WarpOpShapeCast, WarpOpExtract, WarpOpForwardOperand,
- WarpOpConstant, WarpOpExtractElement, WarpOpInsertElement,
- WarpOpInsertScalar, WarpOpInsert, WarpOpCreateMask>(
- patterns.getContext(), benefit);
+ patterns
+ .add<WarpOpElementwise, WarpOpDeadResult, WarpOpBroadcast,
+ WarpOpShapeCast, WarpOpExtract, WarpOpForwardOperand, WarpOpConstant,
+ WarpOpInsertScalar, WarpOpInsert, WarpOpCreateMask>(
+ patterns.getContext(), benefit);
patterns.add<WarpOpExtractScalar>(patterns.getContext(), warpShuffleFromIdxFn,
benefit);
patterns.add<WarpOpScfForOp>(patterns.getContext(), distributionMapFn,
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
index 384717aeca665..62e7f7cc61f6c 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
@@ -767,23 +767,26 @@ class FlattenContiguousRowMajorTransferWritePattern
unsigned targetVectorBitwidth;
};
-/// Base class for `vector.extract/vector.extract_element(vector.transfer_read)`
-/// to `memref.load` patterns. The `match` method is shared for both
-/// `vector.extract` and `vector.extract_element`.
-template <class VectorExtractOp>
-class RewriteScalarExtractOfTransferReadBase
- : public OpRewritePattern<VectorExtractOp> {
- using Base = OpRewritePattern<VectorExtractOp>;
-
+/// Rewrite `vector.extract(vector.transfer_read)` to `memref.load`.
+///
+/// All the users of the transfer op must be `vector.extract` ops. If
+/// `allowMultipleUses` is set to true, rewrite transfer ops with any number of
+/// users. Otherwise, rewrite only if the extract op is the single user of the
+/// transfer op. Rewriting a single vector load with multiple scalar loads may
+/// negatively affect performance.
+class RewriteScalarExtractOfTransferRead
+ : public OpRewritePattern<vector::ExtractOp> {
public:
- RewriteScalarExtractOfTransferReadBase(MLIRContext *context,
- PatternBenefit benefit,
- bool allowMultipleUses)
- : Base(context, benefit), allowMultipleUses(allowMultipleUses) {}
-
- LogicalResult match(VectorExtractOp extractOp) const {
- auto xferOp =
- extractOp.getVector().template getDefiningOp<vector::TransferReadOp>();
+ RewriteScalarExtractOfTransferRead(MLIRContext *context,
+ PatternBenefit benefit,
+ bool allowMultipleUses)
+ : OpRewritePattern(context, benefit),
+ allowMultipleUses(allowMultipleUses) {}
+
+ LogicalResult matchAndRewrite(vector::ExtractOp extractOp,
+ PatternRewriter &rewriter) const override {
+ // Match phase.
+ auto xferOp = extractOp.getVector().getDefiningOp<vector::TransferReadOp>();
if (!xferOp)
return failure();
// Check that we are extracting a scalar and not a sub-vector.
@@ -795,8 +798,7 @@ class RewriteScalarExtractOfTransferReadBase
// If multiple uses are allowed, check if all the xfer uses are extract ops.
if (allowMultipleUses &&
!llvm::all_of(xferOp->getUses(), [](OpOperand &use) {
- return isa<vector::ExtractOp, vector::ExtractElementOp>(
- use.getOwner());
+ return isa<vector::ExtractOp>(use.getOwner());
}))
return failure();
// Mask not supported.
@@ -808,81 +810,8 @@ class RewriteScalarExtractOfTransferReadBase
// Cannot rewrite if the indices may be out of bounds.
if (xferOp.hasOutOfBoundsDim())
return failure();
- return success();
- }
-
-private:
- bool allowMultipleUses;
-};
-
-/// Rewrite `vector.extractelement(vector.transfer_read)` to `memref.load`.
-///
-/// All the users of the transfer op must be either `vector.extractelement` or
-/// `vector.extract` ops. If `allowMultipleUses` is set to true, rewrite
-/// transfer ops with any number of users. Otherwise, rewrite only if the
-/// extract op is the single user of the transfer op. Rewriting a single
-/// vector load with multiple scalar loads may negatively affect performance.
-class RewriteScalarExtractElementOfTransferRead
- : public RewriteScalarExtractOfTransferReadBase<vector::ExtractElementOp> {
- using RewriteScalarExtractOfTransferReadBase::
- RewriteScalarExtractOfTransferReadBase;
-
- LogicalResult matchAndRewrite(vector::ExtractElementOp extractOp,
- PatternRewriter &rewriter) const override {
- if (failed(match(extractOp)))
- return failure();
-
- // Construct scalar load.
- auto loc = extractOp.getLoc();
- auto xferOp = extractOp.getVector().getDefiningOp<vector::TransferReadOp>();
- SmallVector<Value> newIndices(xferOp.getIndices().begin(),
- xferOp.getIndices().end());
- if (extractOp.getPosition()) {
- AffineExpr sym0, sym1;
- bindSymbols(extractOp.getContext(), sym0, sym1);
- OpFoldResult ofr = affine::makeComposedFoldedAffineApply(
- rewriter, loc, sym0 + sym1,
- {newIndices[newIndices.size() - 1], extractOp.getPosition()});
- if (auto value = dyn_cast<Value>(ofr)) {
- newIndices[newIndices.size() - 1] = value;
- } else {
- newIndices[newIndices.size() - 1] =
- rewriter.create<arith::ConstantIndexOp>(loc,
- *getConstantIntValue(ofr));
- }
- }
- if (isa<MemRefType>(xferOp.getBase().getType())) {
- rewriter.replaceOpWithNewOp<memref::LoadOp>(extractOp, xferOp.getBase(),
- newIndices);
- } else {
- rewriter.replaceOpWithNewOp<tensor::ExtractOp>(
- extractOp, xferOp.getBase(), newIndices);
- }
-
- return success();
- }
-};
-
-/// Rewrite `vector.extractelement(vector.transfer_read)` to `memref.load`.
-/// Rewrite `vector.extract(vector.transfer_read)` to `memref.load`.
-///
-/// All the users of the transfer op must be either `vector.extractelement` or
-/// `vector.extract` ops. If `allowMultipleUses` is set to true, rewrite
-/// transfer ops with any number of users. Otherwise, rewrite only if the
-/// extract op is the single user of the transfer op. Rewriting a single
-/// vector load with multiple scalar loads may negatively affect performance.
-class RewriteScalarExtractOfTransferRead
- : public RewriteScalarExtractOfTransferReadBase<vector::ExtractOp> {
- using RewriteScalarExtractOfTransferReadBase::
- RewriteScalarExtractOfTransferReadBase;
-
- LogicalResult matchAndRewrite(vector::ExtractOp extractOp,
- PatternRewriter &rewriter) const override {
- if (failed(match(extractOp)))
- return failure();
- // Construct scalar load.
- auto xferOp = extractOp.getVector().getDefiningOp<vector::TransferReadOp>();
+ // Rewrite phase: construct scalar load.
SmallVector<Value> newIndices(xferOp.getIndices().begin(),
xferOp.getIndices().end());
for (auto [i, pos] : llvm::enumerate(extractOp.getMixedPosition())) {
@@ -923,6 +852,9 @@ class RewriteScalarExtractOfTransferRead
return success();
}
+
+private:
+ bool allowMultipleUses;
};
/// Rewrite transfer_writes of vectors of size 1 (e.g., vector<1x1xf32>)
@@ -979,8 +911,7 @@ void mlir::vector::transferOpflowOpt(RewriterBase &rewriter,
void mlir::vector::populateScalarVectorTransferLoweringPatterns(
RewritePatternSet &patterns, PatternBenefit benefit,
bool allowMultipleUses) {
- patterns.add<RewriteScalarExtractElementOfTransferRead,
- RewriteScalarExtractOfTransferRead>(patterns.getContext(),
+ patterns.add<RewriteScalarExtractOfTransferRead>(patterns.getContext(),
benefit, allowMultipleUses);
patterns.add<RewriteScalarWrite>(patterns.getContext(), benefit);
}
diff --git a/mlir/test/Conversion/VectorToSCF/vector-to-scf.mlir b/mlir/test/Conversion/VectorToSCF/vector-to-scf.mlir
index 5a6da3a06387a..33177736eb5fe 100644
--- a/mlir/test/Conversion/VectorToSCF/vector-to-scf.mlir
+++ b/mlir/test/Conversion/VectorToSCF/vector-to-scf.mlir
@@ -37,7 +37,7 @@ func.func @materialize_read_1d() {
// Both accesses in the load must be clipped otherwise %i1 + 2 and %i1 + 3 will go out of bounds.
// CHECK: scf.if
// CHECK-NEXT: memref.load
- // CHECK-NEXT: vector.insertelement
+ // CHECK-NEXT: vector.insert
// CHECK-NEXT: scf.yield
// CHECK-NEXT: else
// CHECK-NEXT: scf.yield
@@ -103,7 +103,7 @@ func.func @materialize_read(%M: index, %N: index, %O: index, %P: index) {
// CHECK: %[[L0:.*]] = affine.apply #[[$ADD]](%[[I0]], %[[I6]])
// CHECK: scf.if {{.*}} -> (vector<3xf32>) {
// CHECK-NEXT: %[[SCAL:.*]] = memref.load %{{.*}}[%[[L0]], %[[I1]], %[[I2]], %[[L3]]] : memref<?x?x?x?xf32>
- // CHECK-NEXT: %[[RVEC:.*]] = vector.insertelement %[[SCAL]], %{{.*}}[%[[I6]] : index] : vector<3xf32>
+ // CHECK-NEXT: %[[RVEC:.*]] = vector.insert %[[SCAL]], %{{.*}} [%[[I6]]] : f32 into vector<3xf32>
// CHECK-NEXT: scf.yield
// CHECK-NEXT: } else {
// CHECK-NEXT: scf.yield
@@ -540,9 +540,9 @@ func.func @transfer_write_scalable(%arg0: memref<?xf32, strided<[?], offset: ?>>
// CHECK: %[[VSCALE:.*]] = vector.vscale
// CHECK: %[[UB:.*]] = arith.muli %[[VSCALE]], %[[C_16]] : index
// CHECK: scf.for %[[IDX:.*]] = %[[C_0]] to %[[UB]] step %[[STEP]] {
-// CHECK: %[[MASK_VAL:.*]] = vector.extractelement %[[MASK_VEC]][%[[IDX]] : index] : vector<[16]xi1>
+// CHECK: %[[MASK_VAL:.*]] = vector.extract %[[MASK_VEC]][%[[IDX]]] : i1 from vector<[16]xi1>
// CHECK: scf.if %[[MASK_VAL]] {
-// CHECK: %[[VAL_TO_STORE:.*]] = vector.extractelement %{{.*}}[%[[IDX]] : index] : vector<[16]xf32>
+// CHECK: %[[VAL_TO_STORE:.*]] = vector.extract %{{.*}}[%[[IDX]]] : f32 from vector<[16]xf32>
// CHECK: memref.store %[[VAL_TO_STORE]], %[[ARG_0]][%[[IDX]]] : memref<?xf32, strided<[?], offset: ?>>
// CHECK: } else {
// CHECK: }
@@ -561,7 +561,7 @@ func.func @vector_print_vector_0d(%arg0: vector<f32>) {
// CHECK: %[[FLAT_VEC:.*]] = vector.shape_cast %[[VEC]] : vector<f32> to vector<1xf32>
// CHECK: vector.print punctuation <open>
// CHECK: scf.for %[[IDX:.*]] = %[[C0]] to %[[C1]] step %[[C1]] {
-// CHECK: %[[EL:.*]] = vector.extractelement %[[FLAT_VEC]]{{\[}}%[[IDX]] : index] : vector<1xf32>
+// CHECK: %[[EL:.*]] = vector.extract %[[FLAT_VEC]][%[[IDX]]] : f32 from vector<1xf32>
// CHECK: vector.print %[[EL]] : f32 punctuation <no_punctuation>
// CHECK: %[[IS_NOT_LAST:.*]] = arith.cmpi ult, %[[IDX]], %[[C0]] : index
// CHECK: scf.if %[[IS_NOT_LAST]] {
@@ -591,7 +591,7 @@ func.func @vector_print_vector(%arg0: vector<2x2xf32>) {
// CHECK: scf.for %[[J:.*]] = %[[C0]] to %[[C2]] step %[[C1]] {
// CHECK: %[[OUTER_INDEX:.*]] = arith.muli %[[I]], %[[C2]] : index
// CHECK: %[[FLAT_INDEX:.*]] = arith.addi %[[J]], %[[OUTER_INDEX]] : index
-// CHECK: %[[EL:.*]] = vector.extractelement %[[FLAT_VEC]]{{\[}}%[[FLAT_INDEX]] : index] : vector<4xf32>
+// CHECK: %[[EL:.*]] = vector.extract %[[FLAT_VEC]][%[[FLAT_INDEX]]] : f32 from vector<4xf32>
// CHECK: vector.print %[[EL]] : f32 punctuation <no_punctuation>
// CHECK: %[[IS_NOT_LAST_J:.*]] = arith.cmpi ult, %[[J]], %[[C1]] : index
// CHECK: scf.if %[[IS_NOT_LAST_J]] {
@@ -625,7 +625,7 @@ func.func @vector_print_scalable_vector(%arg0: vector<[4]xi32>) {
// CHECK: %[[LAST_INDEX:.*]] = arith.subi %[[UPPER_BOUND]], %[[C1]] : index
// CHECK: vector.print punctuation <open>
// CHECK: scf.for %[[IDX:.*]] = %[[C0]] to %[[UPPER_BOUND]] step %[[C1]] {
-// CHECK: %[[EL:.*]] = vector.extractelement %[[VEC]]{{\[}}%[[IDX]] : index] : vector<[4]xi32>
+// CHECK: %[[EL:.*]] = vector.extract %[[VEC]][%[[IDX]]] : i32 from vector<[4]xi32>
// CHECK: vector.print %[[EL]] : i32 punctuation <no_punctuation>
// CHECK: %[[IS_NOT_LAST:.*]] = arith.cmpi ult, %[[IDX]], %[[LAST_INDEX]] : index
// CHECK: scf.if %[[IS_NOT_LAST]] {
diff --git a/mlir/test/Dialect/Vector/scalar-vector-transfer-to-memref.mlir b/mlir/test/Dialect/Vector/scalar-vector-transfer-to-memref.mlir
index 7a1d6b3a8344a..7fec1c6ba5642 100644
--- a/mlir/test/Dialect/Vector/scalar-vector-transfer-to-memref.mlir
+++ b/mlir/test/Dialect/Vector/scalar-vector-transfer-to-memref.mlir
@@ -8,7 +8,7 @@
func.func @transfer_read_0d(%m: memref<?x?x?xf32>, %idx: index) -> f32 {
%cst = arith.constant 0.0 : f32
%0 = vector.transfer_read %m[%idx, %idx, %idx], %cst : memref<?x?x?xf32>, vector<f32>
- %1 = vector.extractelement %0[] : vector<f32>
+ %1 = vector.extract %0[] : f32 from vector<f32>
return %1 : f32
}
@@ -24,7 +24,7 @@ func.func @transfer_read_1d(%m: memref<?x?x?xf32>, %idx: index, %idx2: index) ->
%cst = arith.constant 0.0 : f32
%c0 = arith.constant 0 : index
%0 = vector.transfer_read %m[%idx, %idx, %idx], %cst {in_bounds = [true]} : memref<?x?x?xf32>, vector<5xf32>
- %1 = vector.extractelement %0[%idx2 : index] : vector<5xf32>
+ %1 = vector.extract %0[%idx2] : f32 from vector<5xf32>
return %1 : f32
}
@@ -37,7 +37,7 @@ func.func @transfer_read_1d(%m: memref<?x?x?xf32>, %idx: index, %idx2: index) ->
func.func @tensor_transfer_read_0d(%t: tensor<?x?x?xf32>, %idx: index) -> f32 {
%cst = arith.constant 0.0 : f32
%0 = vector.transfer_read %t[%idx, %idx, %idx], %cst : tensor<?x?x?xf32>, vector<f32>
- %1 = vector.extractelement %0[] : vector<f32>
+ %1 = vector.extract %0[] : f32 from vector<f32>
return %1 : f32
}
diff --git a/mlir/test/Dialect/Vector/vector-warp-distribute.mlir b/mlir/test/Dialect/Vector/vector-warp-distribute.mlir
index 38771f2593449..1dc0d5c7f9e1c 100644
--- a/mlir/test/Dialect/Vector/vector-warp-distribute.mlir
+++ b/mlir/test/Dialect/Vector/vector-warp-distribute.mlir
@@ -663,7 +663,7 @@ func.func @vector_reduction(%laneid: index, %m0: memref<4x2x32xf32>, %m1: memref
gpu.warp_execute_on_lane_0(%laneid)[32] {
%0 = vector.transfer_read %m0[%c0, %c0, %c0], %f0 {in_bounds = [true]} : memref<4x2x32xf32>, vector<32xf32>
%1 = vector.transfer_read %m1[], %f0 : memref<f32>, vector<f32>
- %2 = vector.extractelement %1[] : vector<f32>
+ %2 = vector.extract %1[] : f32 from vector<f32>
%3 = vector.reduction <add>, %0 : vector<32xf32> into f32
%4 = arith.addf %3, %2 : f32
%5 = vector.broadcast %4 : f32 to vector<f32>
@@ -868,17 +868,17 @@ func.func @vector_extract_3d(%laneid: index) -> (vector<4x96xf32>) {
// -----
-// CHECK-PROP-LABEL: func.func @vector_extractelement_0d(
+// CHECK-PROP-LABEL: func.func @vector_extract_0d(
// CHECK-PROP: %[[R:.*]] = gpu.warp_execute_on_lane_0(%{{.*}})[32] -> (vector<f32>) {
// CHECK-PROP: %[[V:.*]] = "some_def"() : () -> vector<f32>
// CHECK-PROP: gpu.yield %[[V]] : vector<f32>
// CHECK-PROP: }
// CHECK-PROP: %[[E:.*]] = vector.extract %[[R]][] : f32 from vector<f32>
// CHECK-PROP: return %[[E]] : f32
-func.func @vector_extractelement_0d(%laneid: index) -> (f32) {
+func.func @vector_extract_0d(%laneid: index) -> (f32) {
%r = gpu.warp_execute_on_lane_0(%laneid)[32] -> (f32) {
%0 = "some_def"() : () -> (vector<f32>)
- %1 = vector.extractelement %0[] : vector<f32>
+ %1 = vector.extract %0[] : f32 from vector<f32>
gpu.yield %1 : f32
}
return %r : f32
@@ -886,18 +886,18 @@ func.func @vector_extractelement_0d(%laneid: index) -> (f32) {
// -----
-// CHECK-PROP-LABEL: func.func @vector_extractelement_1element(
+// CHECK-PROP-LABEL: func.func @vector_extract_1element(
// CHECK-PROP: %[[R:.*]] = gpu.warp_execute_on_lane_0(%{{.*}})[32] -> (vector<1xf32>) {
// CHECK-PROP: %[[V:.*]] = "some_def"() : () -> vector<1xf32>
// CHECK-PROP: gpu.yield %[[V]] : vector<1xf32>
// CHECK-PROP: }
// CHECK-PROP: %[[E:.*]] = vector.extract %[[R]][0] : f32 from vector<1xf32>
// CHECK-PROP: return %[[E]] : f32
-func.func @vector_extractelement_1element(%laneid: index) -> (f32) {
+func.func @vector_extract_1element(%laneid: index) -> (f32) {
%r = gpu.warp_execute_on_lane_0(%laneid)[32] -> (f32) {
%0 = "some_def"() : () -> (vector<1xf32>)
%c0 = arith.constant 0 : index
- %1 = vector.extractelement %0[%c0 : index] : vector<1xf32>
+ %1 = vector.extract %0[%c0] : f32 from vector<1xf32>
gpu.yield %1 : f32
}
return %r : f32
@@ -907,7 +907,7 @@ func.func @vector_extractelement_1element(%laneid: index) -> (f32) {
// CHECK-PROP: #[[$map:.*]] = affine_map<()[s0] -> (s0 ceildiv 3)>
// CHECK-PROP: #[[$map1:.*]] = affine_map<()[s0] -> (s0 mod 3)>
-// CHECK-PROP-LABEL: func.func @vector_extractelement_1d(
+// CHECK-PROP-LABEL: func.func @vector_extract_1d(
// CHECK-PROP-SAME: %[[LANEID:.*]]: index, %[[POS:.*]]: index
// CHECK-PROP-DAG: %[[C32:.*]] = arith.constant 32 : i32
// CHECK-PROP: %[[W:.*]] = gpu.warp_execute_on_lane_0(%{{.*}})[32] -> (vector<3xf32>) {
@@ -920,10 +920,10 @@ func.func @vector_extractelement_1element(%laneid: index) -> (f32) {
// CHECK-PROP: %[[FROM_LANE_I32:.*]] = arith.index_cast %[[FROM_LANE]] : index to i32
// CHECK-PROP: %[[SHUFFLED:.*]], %{{.*}} = gpu.shuffle idx %[[EXTRACTED]], %[[FROM_LANE_I32]], %[[C32]] : f32
// CHECK-PROP: return %[[SHUFFLED]]
-func.func @vector_extractelement_1d(%laneid: index, %pos: index) -> (f32) {
+func.func @vector_extract_1d(%laneid: index, %pos: index) -> (f32) {
%r = gpu.warp_execute_on_lane_0(%laneid)[32] -> (f32) {
%0 = "some_def"() : () -> (vector<96xf32>)
- %1 = vector.extractelement %0[%pos : index] : vector<96xf32>
+ %1 = vector.extract %0[%pos] : f32 from vector<96xf32>
gpu.yield %1 : f32
}
return %r : f32
@@ -933,16 +933,16 @@ func.func @vector_extractelement_1d(%laneid: index, %pos: index) -> (f32) {
// Index-typed values cannot be shuffled at the moment.
-// CHECK-PROP-LABEL: func.func @vector_extractelement_1d_index(
+// CHECK-PROP-LABEL: func.func @vector_extract_1d_index(
// CHECK-PROP: gpu.warp_execute_on_lane_0(%{{.*}})[32] -> (index) {
// CHECK-PROP: "some_def"
// CHECK-PROP: vector.extract
// CHECK-PROP: gpu.yield {{.*}} : index
// CHECK-PROP: }
-func.func @vector_extractelement_1d_index(%laneid: index, %pos: index) -> (index) {
+func.func @vector_extract_1d_index(%laneid: index, %pos: index) -> (index) {
%r = gpu.warp_execute_on_lane_0(%laneid)[32] -> (index) {
%0 = "some_def"() : () -> (vector<96xindex>)
- %1 = vector.extractelement %0[%pos : index] : vector<96xindex>
+ %1 = vector.extract %0[%pos] : index from vector<96xindex>
gpu.yield %1 : index
}
return %r : index
@@ -1142,7 +1142,7 @@ func.func @warp_execute_nd_distribute(%laneid: index, %v0: vector<1x64x1xf32>, %
// CHECK-PROP: #[[$MAP:.*]] = affine_map<()[s0] -> (s0 ceildiv 3)>
// CHECK-PROP: #[[$MAP1:.*]] = affine_map<()[s0] -> (s0 mod 3)>
-// CHECK-PROP-LABEL: func @vector_insertelement_1d(
+// CHECK-PROP-LABEL: func @vector_insert_1d(
// CHECK-PROP-SAME: %[[LANEID:.*]]: index, %[[POS:.*]]: index
// CHECK-PROP: %[[W:.*]]:2 = gpu.warp_execute_on_lane_0{{.*}} -> (vector<3xf32>, f32)
// CHECK-PROP: %[[INSERTING_LANE:.*]] = affine.apply #[[$MAP]]()[%[[POS]]]
@@ -1155,11 +1155,11 @@ func.func @warp_execute_nd_distribute(%laneid: index, %v0: vector<1x64x1xf32>, %
// CHECK-PROP: scf.yield %[[W]]#0
// CHECK-PROP: }
// CHECK-PROP: return %[[R]]
-func.func @vector_insertelement_1d(%laneid: index, %pos: index) -> (vector<3xf32>) {
+func.func @vector_insert_1d(%laneid: index, %pos: index) -> (vector<3xf32>) {
%r = gpu.warp_execute_on_lane_0(%laneid)[32] -> (vector<3xf32>) {
%0 = "some_def"() : () -> (vector<96xf32>)
%f = "another_def"() : () -> (f32)
- %1 = vector.insertelement %f, %0[%pos : index] : vector<96xf32>
+ %1 = vector.insert %f, %0[%pos] : f32 into vector<96xf32>
gpu.yield %1 : vector<96xf32>
}
return %r : vector<3xf32>
@@ -1167,18 +1167,18 @@ func.func @vector_insertelement_1d(%laneid: index, %pos: index) -> (vector<3xf32
// -----
-// CHECK-PROP-LABEL: func @vector_insertelement_1d_broadcast(
+// CHECK-PROP-LABEL: func @vector_insert_1d_broadcast(
// CHECK-PROP-SAME: %[[LANEID:.*]]: index, %[[POS:.*]]: index
// CHECK-PROP: %[[W:.*]]:2 = gpu.warp_execute_on_lane_0{{.*}} -> (vector<96xf32>, f32)
// CHECK-PROP: %[[VEC:.*]] = "some_def"
// CHECK-PROP: %[[VAL:.*]] = "another_def"
// CHECK-PROP: gpu.yield %[[VEC]], %[[VAL]]
// CHECK-PROP: vector.insert %[[W]]#1, %[[W]]#0 [%[[POS]]] : f32 into vector<96xf32>
-func.func @vector_insertelement_1d_broadcast(%laneid: index, %pos: index) -> (vector<96xf32>) {
+func.func @vector_insert_1d_broadcast(%laneid: index, %pos: index) -> (vector<96xf32>) {
%r = gpu.warp_execute_on_lane_0(%laneid)[32] -> (vector<96xf32>) {
%0 = "some_def"() : () -> (vector<96xf32>)
%f = "another_def"() : () -> (f32)
- %1 = vector.insertelement %f, %0[%pos : index] : vector<96xf32>
+ %1 = vector.insert %f, %0[%pos] : f32 into vector<96xf32>
gpu.yield %1 : vector<96xf32>
}
return %r : vector<96xf32>
@@ -1186,17 +1186,17 @@ func.func @vector_insertelement_1d_broadcast(%laneid: index, %pos: index) -> (ve
// -----
-// CHECK-PROP-LABEL: func @vector_insertelement_0d(
+// CHECK-PROP-LABEL: func @vector_insert_0d(
// CHECK-PROP: %[[W:.*]]:2 = gpu.warp_execute_on_lane_0{{.*}} -> (vector<f32>, f32)
// CHECK-PROP: %[[VEC:.*]] = "some_def"
// CHECK-PROP: %[[VAL:.*]] = "another_def"
// CHECK-PROP: gpu.yield %[[VEC]], %[[VAL]]
// CHECK-PROP: vector.insert %[[W]]#1, %[[W]]#0 [] : f32 into vector<f32>
-func.func @vector_insertelement_0d(%laneid: index) -> (vector<f32>) {
+func.func @vector_insert_0d(%laneid: index) -> (vector<f32>) {
%r = gpu.warp_execute_on_lane_0(%laneid)[32] -> (vector<f32>) {
%0 = "some_def"() : () -> (vector<f32>)
%f = "another_def"() : () -> (f32)
- %1 = vector.insertelement %f, %0[] : vector<f32>
+ %1 = vector.insert %f, %0[] : f32 into vector<f32>
gpu.yield %1 : vector<f32>
}
return %r : vector<f32>
@@ -1299,7 +1299,7 @@ func.func @vector_insert_2d_broadcast(%laneid: index) -> (vector<4x96xf32>) {
// -----
// Make sure that all operands of the transfer_read op are properly propagated.
-// The vector.extractelement op cannot be propagated because index-typed
+// The vector.extract op cannot be propagated because index-typed
// shuffles are not supported at the moment.
// CHECK-PROP: #[[$MAP:.*]] = affine_map<()[s0] -> (s0 * 2)>
@@ -1333,7 +1333,7 @@ func.func @transfer_read_prop_operands(%in2: vector<1x2xindex>, %ar1 : memref<1
%28 = vector.gather %ar1[%c0, %c0, %c0] [%arg4], %cst_0, %cst : memref<1x4x2xi32>, vector<1x64xindex>, vector<1x64xi1>, vector<1x64xi32> into vector<1x64xi32>
%29 = vector.extract %28[0] : vector<64xi32> from vector<1x64xi32>
%30 = arith.index_cast %29 : vector<64xi32> to vector<64xindex>
- %36 = vector.extractelement %30[%c0_i32 : index] : vector<64xindex>
+ %36 = vector.extract %30[%c0_i32] : index from vector<64xindex>
%37 = vector.transfer_read %ar2[%c0, %36, %c0], %cst_6 {in_bounds = [true]} : memref<1x4x1024xf32>, vector<64xf32>
gpu.yield %37 : vector<64xf32>
}