[Mlir-commits] [mlir] e54236d - [mlir][vector] Cast away leading one dims for insert ops
Lei Zhang
llvmlistbot at llvm.org
Thu Apr 14 06:00:51 PDT 2022
Author: Lei Zhang
Date: 2022-04-14T08:57:32-04:00
New Revision: e54236dfb5982bc8358bad62a27e6048f06a0272
URL: https://github.com/llvm/llvm-project/commit/e54236dfb5982bc8358bad62a27e6048f06a0272
DIFF: https://github.com/llvm/llvm-project/commit/e54236dfb5982bc8358bad62a27e6048f06a0272.diff
LOG: [mlir][vector] Cast away leading one dims for insert ops
Reviewed By: ThomasRaoux
Differential Revision: https://reviews.llvm.org/D123621
Added:
Modified:
mlir/lib/Dialect/Vector/Transforms/VectorDropLeadUnitDim.cpp
mlir/test/Dialect/Vector/vector-dropleadunitdim-transforms.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorDropLeadUnitDim.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorDropLeadUnitDim.cpp
index d555c60439f71..0688a405491ab 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorDropLeadUnitDim.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorDropLeadUnitDim.cpp
@@ -7,6 +7,7 @@
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
#include "mlir/Dialect/Vector/Utils/VectorUtils.h"
#include "mlir/IR/Builders.h"
@@ -37,7 +38,7 @@ static SmallVector<int64_t> splatZero(int64_t rank) {
namespace {
// Casts away leading one dimensions in vector.extract_strided_slice's vector
-// input by inserting vector.shape_cast.
+// input by inserting vector.broadcast.
struct CastAwayExtractStridedSliceLeadingOneDim
: public OpRewritePattern<vector::ExtractStridedSliceOp> {
using OpRewritePattern::OpRewritePattern;
@@ -84,8 +85,8 @@ struct CastAwayExtractStridedSliceLeadingOneDim
}
};
-// Casts away leading one dimensions in vector.extract_strided_slice's vector
-// inputs by inserting vector.shape_cast.
+// Casts away leading one dimensions in vector.insert_strided_slice's vector
+// inputs by inserting vector.broadcast.
struct CastAwayInsertStridedSliceLeadingOneDim
: public OpRewritePattern<vector::InsertStridedSliceOp> {
using OpRewritePattern::OpRewritePattern;
@@ -125,6 +126,61 @@ struct CastAwayInsertStridedSliceLeadingOneDim
}
};
+// Casts away leading one dimensions in vector.insert's source and destination
+// operands. Each vector operand is trimmed with a vector.extract at all-zero
+// positions, a smaller vector.insert is built on the trimmed values, and its
+// result is expanded back to the original destination type with a
+// vector.broadcast.
+struct CastAwayInsertLeadingOneDim : public OpRewritePattern<vector::InsertOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(vector::InsertOp insertOp,
+                                PatternRewriter &rewriter) const override {
+    // The source may be a scalar; it then counts as rank 0 with nothing to
+    // drop.
+    Type oldSrcType = insertOp.getSourceType();
+    Type newSrcType = oldSrcType;
+    int64_t oldSrcRank = 0, newSrcRank = 0;
+    if (auto type = oldSrcType.dyn_cast<VectorType>()) {
+      newSrcType = trimLeadingOneDims(type);
+      oldSrcRank = type.getRank();
+      newSrcRank = newSrcType.cast<VectorType>().getRank();
+    }
+
+    VectorType oldDstType = insertOp.getDestVectorType();
+    VectorType newDstType = trimLeadingOneDims(oldDstType);
+
+    int64_t srcDropCount = oldSrcRank - newSrcRank;
+    int64_t dstDropCount = oldDstType.getRank() - newDstType.getRank();
+    // Neither operand has leading unit dims: nothing to rewrite.
+    if (srcDropCount == 0 && dstDropCount == 0)
+      return failure();
+
+    // Trim leading one dimensions from both operands by extracting at
+    // all-zero positions.
+    Location loc = insertOp.getLoc();
+
+    Value newSrcVector = insertOp.getSource();
+    if (oldSrcRank != 0) {
+      newSrcVector = rewriter.create<vector::ExtractOp>(
+          loc, insertOp.getSource(), splatZero(srcDropCount));
+    }
+    Value newDstVector = rewriter.create<vector::ExtractOp>(
+        loc, insertOp.getDest(), splatZero(dstDropCount));
+
+    // The new position is computed in two steps:
+    //   1. The leading `dstDropCount` position entries index the dropped
+    //      destination unit dims, so they are removed. If `dstDropCount`
+    //      exceeds the position rank, the extra dropped dims are leading unit
+    //      dims of the source, and no old entry survives.
+    //   2. The dropped source unit dims are now indexed dims of the new
+    //      destination, so zeros are appended until the position has rank
+    //      newDstRank - newSrcRank.
+    // E.g. inserting vector<1x8xi1> into vector<1x1x8x1x8xi1> at [0, 0, 6]
+    // becomes inserting vector<8xi1> into vector<8x1x8xi1> at [6, 0]. Keeping
+    // the last newPosRank old entries instead would give the wrong [0, 6].
+    int64_t oldPosRank = insertOp.getPosition().getValue().size();
+    int64_t keptPosRank =
+        oldPosRank > dstDropCount ? oldPosRank - dstDropCount : 0;
+    int64_t newPosRank = newDstType.getRank() - newSrcRank;
+    SmallVector<Attribute> newPositions = llvm::to_vector(
+        insertOp.getPosition().getValue().take_back(keptPosRank));
+    // keptPosRank <= newPosRank always holds, so this only appends zeros.
+    newPositions.resize(newPosRank,
+                        rewriter.getZeroAttr(rewriter.getI64Type()));
+
+    auto newInsertOp = rewriter.create<vector::InsertOp>(
+        loc, newDstType, newSrcVector, newDstVector,
+        rewriter.getArrayAttr(newPositions));
+
+    // Broadcast back so users still see the original destination type.
+    rewriter.replaceOpWithNewOp<vector::BroadcastOp>(insertOp, oldDstType,
+                                                     newInsertOp);
+
+    return success();
+  }
+};
+
// Turns vector.transfer_read on vector with leading 1 dimensions into
// vector.shape_cast followed by vector.transfer_read on vector without leading
// 1 dimensions.
@@ -383,7 +439,7 @@ void mlir::vector::populateCastAwayVectorLeadingOneDimPatterns(
RewritePatternSet &patterns) {
patterns
.add<CastAwayExtractStridedSliceLeadingOneDim,
- CastAwayInsertStridedSliceLeadingOneDim,
+ CastAwayInsertStridedSliceLeadingOneDim, CastAwayInsertLeadingOneDim,
CastAwayTransferReadLeadingOneDim,
CastAwayTransferWriteLeadingOneDim, CastAwayElementwiseLeadingOneDim,
CastAwayContractionLeadingOneDim>(patterns.getContext());
diff --git a/mlir/test/Dialect/Vector/vector-dropleadunitdim-transforms.mlir b/mlir/test/Dialect/Vector/vector-dropleadunitdim-transforms.mlir
index 70beb0fe43c5f..9115eda439cf4 100644
--- a/mlir/test/Dialect/Vector/vector-dropleadunitdim-transforms.mlir
+++ b/mlir/test/Dialect/Vector/vector-dropleadunitdim-transforms.mlir
@@ -265,3 +265,42 @@ func @cast_away_elementwise_leading_one_dims(
return %0, %1, %2, %3: vector<1x1x8xf32>, vector<1x4xi1>, vector<1x4xf32>, vector<1x4xf32>
}
+// Scalar source into a destination with two leading unit dims. Only the
+// destination is trimmed (vector<1x1x4xf32> -> vector<4xf32>), the position
+// [0, 0, 0] shrinks to [0], and the result is broadcast back.
+// CHECK-LABEL: func @cast_away_insert_leading_one_dims_scalar
+// CHECK-SAME: (%[[S:.+]]: f32, %[[V:.+]]: vector<1x1x4xf32>)
+// CHECK: %[[EXTRACT:.+]] = vector.extract %[[V]][0, 0] : vector<1x1x4xf32>
+// CHECK: %[[INSERT:.+]] = vector.insert %[[S]], %[[EXTRACT]] [0] : f32 into vector<4xf32>
+// CHECK: %[[BCAST:.+]] = vector.broadcast %[[INSERT]] : vector<4xf32> to vector<1x1x4xf32>
+// CHECK: return %[[BCAST]]
+func @cast_away_insert_leading_one_dims_scalar(%s: f32, %v: vector<1x1x4xf32>) -> vector<1x1x4xf32> {
+ %0 = vector.insert %s, %v [0, 0, 0] : f32 into vector<1x1x4xf32>
+ return %0: vector<1x1x4xf32>
+}
+
+// Source vector<4xf32> covers the whole trimmed destination. The expected
+// output therefore has no vector.insert left; the source is broadcast straight
+// to vector<1x1x4xf32>.
+// CHECK-LABEL: func @cast_away_insert_leading_one_dims_rank1
+// CHECK-SAME: (%[[S:.+]]: vector<4xf32>, %[[V:.+]]: vector<1x1x4xf32>)
+// CHECK: %[[BCAST:.+]] = vector.broadcast %[[S]] : vector<4xf32> to vector<1x1x4xf32>
+// CHECK: return %[[BCAST]]
+func @cast_away_insert_leading_one_dims_rank1(%s: vector<4xf32>, %v: vector<1x1x4xf32>) -> vector<1x1x4xf32> {
+ %0 = vector.insert %s, %v [0, 0] : vector<4xf32> into vector<1x1x4xf32>
+ return %0: vector<1x1x4xf32>
+}
+
+// Both operands have leading unit dims. The source vector<1x4xf32> is trimmed
+// to vector<4xf32>, which again covers the whole trimmed destination, so only
+// an extract of the source followed by a broadcast is expected.
+// CHECK-LABEL: func @cast_away_insert_leading_one_dims_rank2
+// CHECK-SAME: (%[[S:.+]]: vector<1x4xf32>, %[[V:.+]]: vector<1x1x4xf32>)
+// CHECK: %[[EXTRACT:.+]] = vector.extract %[[S]][0] : vector<1x4xf32>
+// CHECK: %[[BCAST:.+]] = vector.broadcast %[[EXTRACT]] : vector<4xf32> to vector<1x1x4xf32>
+// CHECK: return %[[BCAST]]
+func @cast_away_insert_leading_one_dims_rank2(%s: vector<1x4xf32>, %v: vector<1x1x4xf32>) -> vector<1x1x4xf32> {
+ %0 = vector.insert %s, %v [0] : vector<1x4xf32> into vector<1x1x4xf32>
+ return %0: vector<1x1x4xf32>
+}
+
+// The destination's leading dim is 8, so the destination is not trimmed and no
+// broadcast is expected. Only the source's leading unit dim is dropped, and a
+// zero is appended to the position ([5] -> [5, 0]).
+// CHECK-LABEL: func @cast_away_insert_leading_one_dims_non_one_dest
+// CHECK-SAME: (%[[S:.+]]: vector<1x4xf32>, %[[V:.+]]: vector<8x1x4xf32>)
+// CHECK: %[[EXTRACT:.+]] = vector.extract %[[S]][0] : vector<1x4xf32>
+// CHECK: %[[INSERT:.+]] = vector.insert %[[EXTRACT]], %[[V]] [5, 0] : vector<4xf32> into vector<8x1x4xf32>
+// CHECK: return %[[INSERT]]
+func @cast_away_insert_leading_one_dims_non_one_dest(%s: vector<1x4xf32>, %v: vector<8x1x4xf32>) -> vector<8x1x4xf32> {
+ %0 = vector.insert %s, %v [5] : vector<1x4xf32> into vector<8x1x4xf32>
+ return %0: vector<8x1x4xf32>
+}
More information about the Mlir-commits
mailing list