[Mlir-commits] [mlir] 0b17336 - [mlir][Vector] Make vector.shape_cast based size-1 foldings opt-in and separate.
Nicolas Vasilache
llvmlistbot at llvm.org
Mon Nov 15 13:21:48 PST 2021
Author: Nicolas Vasilache
Date: 2021-11-15T21:17:57Z
New Revision: 0b17336f793108a7b10c3fa913039144ef1d0f61
URL: https://github.com/llvm/llvm-project/commit/0b17336f793108a7b10c3fa913039144ef1d0f61
DIFF: https://github.com/llvm/llvm-project/commit/0b17336f793108a7b10c3fa913039144ef1d0f61.diff
LOG: [mlir][Vector] Make vector.shape_cast based size-1 foldings opt-in and separate.
This is in preparation for dropping them altogether in favor of insert/extract-based patterns.
Reviewed By: antiagainst
Differential Revision: https://reviews.llvm.org/D113928
Added:
mlir/test/Dialect/Vector/vector-dim-one-shape-cast.mlir
Modified:
mlir/include/mlir/Dialect/Vector/VectorOps.h
mlir/lib/Dialect/Vector/VectorOps.cpp
mlir/lib/Dialect/Vector/VectorTransforms.cpp
mlir/test/Dialect/Vector/canonicalize.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Vector/VectorOps.h b/mlir/include/mlir/Dialect/Vector/VectorOps.h
index 1671cc708110e..0ed4b8370cb1d 100644
--- a/mlir/include/mlir/Dialect/Vector/VectorOps.h
+++ b/mlir/include/mlir/Dialect/Vector/VectorOps.h
@@ -56,6 +56,9 @@ isBroadcastableTo(Type srcType, VectorType dstVectorType,
void populateVectorToVectorCanonicalizationPatterns(
RewritePatternSet &patterns);
+/// Collect the opt-in set of vector.shape_cast folding patterns.
+void populateShapeCastFoldingPatterns(RewritePatternSet &patterns);
+
/// Collect a set of leading one dimension removal patterns.
///
/// These patterns insert vector.shape_cast to remove leading one dimensions
diff --git a/mlir/lib/Dialect/Vector/VectorOps.cpp b/mlir/lib/Dialect/Vector/VectorOps.cpp
index 80b4e606c6ff2..5b4f1abc570ed 100644
--- a/mlir/lib/Dialect/Vector/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/VectorOps.cpp
@@ -1215,29 +1215,10 @@ OpFoldResult ExtractOp::fold(ArrayRef<Attribute>) {
-namespace {
-// If extractOp is only removing unit dimensions it can be transformed to a
-// shapecast.
-class ExtractToShapeCast final : public OpRewritePattern<ExtractOp> {
-public:
- using OpRewritePattern<ExtractOp>::OpRewritePattern;
-
- LogicalResult matchAndRewrite(ExtractOp extractOp,
- PatternRewriter &rewriter) const override {
- auto dstVecType = extractOp.getResult().getType().dyn_cast<VectorType>();
- if (!dstVecType || extractOp.getVectorType().getNumElements() !=
- dstVecType.getNumElements())
- return failure();
- rewriter.replaceOpWithNewOp<ShapeCastOp>(extractOp, dstVecType,
- extractOp.vector());
- return success();
- }
-};
-
-} // namespace
void ExtractOp::getCanonicalizationPatterns(RewritePatternSet &results,
MLIRContext *context) {
- results.add<ExtractToShapeCast>(context);
+ // ExtractToShapeCast is not a default canonicalization; it is opt-in via
+ // `populateCastAwayVectorLeadingOneDimPatterns`.
}
static void populateFromInt64AttrArray(ArrayAttr arrayAttr,
@@ -1401,27 +1384,6 @@ OpFoldResult BroadcastOp::fold(ArrayRef<Attribute> operands) {
namespace {
-// BroadcastOp can only add dimensions or broadcast a dimension from 1 to N. In
-// the degenerated case where the broadcast only adds dimensions of size 1 it
-// can be replaced by a ShapeCastOp. This canonicalization checks if the total
-// number of elements is the same before and after the broadcast to detect if
-// the only change in the vector type are new dimensions of size 1.
-class BroadcastToShapeCast final : public OpRewritePattern<BroadcastOp> {
-public:
- using OpRewritePattern<BroadcastOp>::OpRewritePattern;
-
- LogicalResult matchAndRewrite(BroadcastOp broadcastOp,
- PatternRewriter &rewriter) const override {
- auto srcVecType = broadcastOp.getSourceType().dyn_cast<VectorType>();
- if (!srcVecType || broadcastOp.getVectorType().getNumElements() !=
- srcVecType.getNumElements())
- return failure();
- rewriter.replaceOpWithNewOp<ShapeCastOp>(
- broadcastOp, broadcastOp.getVectorType(), broadcastOp.source());
- return success();
- }
-};
-
// Fold broadcast1(broadcast2(x)) into broadcast1(x).
struct BroadcastFolder : public OpRewritePattern<BroadcastOp> {
using OpRewritePattern<BroadcastOp>::OpRewritePattern;
@@ -1440,7 +1402,9 @@ struct BroadcastFolder : public OpRewritePattern<BroadcastOp> {
void BroadcastOp::getCanonicalizationPatterns(RewritePatternSet &results,
MLIRContext *context) {
- results.add<BroadcastToShapeCast, BroadcastFolder>(context);
+ // BroadcastToShapeCast is not a default canonicalization; it is opt-in via
+ // `populateCastAwayVectorLeadingOneDimPatterns`.
+ results.add<BroadcastFolder>(context);
}
//===----------------------------------------------------------------------===//
@@ -1605,31 +1569,10 @@ static LogicalResult verify(InsertOp op) {
return success();
}
-namespace {
-
-// If insertOp is only inserting unit dimensions it can be transformed to a
-// shapecast.
-class InsertToShapeCast final : public OpRewritePattern<InsertOp> {
-public:
- using OpRewritePattern<InsertOp>::OpRewritePattern;
-
- LogicalResult matchAndRewrite(InsertOp insertOp,
- PatternRewriter &rewriter) const override {
- auto srcVecType = insertOp.getSourceType().dyn_cast<VectorType>();
- if (!srcVecType || insertOp.getDestVectorType().getNumElements() !=
- srcVecType.getNumElements())
- return failure();
- rewriter.replaceOpWithNewOp<ShapeCastOp>(
- insertOp, insertOp.getDestVectorType(), insertOp.source());
- return success();
- }
-};
-
-} // namespace
-
void InsertOp::getCanonicalizationPatterns(RewritePatternSet &results,
MLIRContext *context) {
- results.add<InsertToShapeCast>(context);
+ // InsertToShapeCast is not a default canonicalization; it is opt-in via
+ // `populateCastAwayVectorLeadingOneDimPatterns`.
}
// Eliminates insert operations that produce values identical to their source
diff --git a/mlir/lib/Dialect/Vector/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/VectorTransforms.cpp
index 3fb6d4c50e9b4..cf40e4f272609 100644
--- a/mlir/lib/Dialect/Vector/VectorTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/VectorTransforms.cpp
@@ -1113,11 +1113,18 @@ class ShapeCastOpRewritePattern : public OpRewritePattern<vector::ShapeCastOp> {
Location loc = op.getLoc();
auto sourceVectorType = op.getSourceVectorType();
auto resultVectorType = op.getResultVectorType();
- // Intended 2D/1D lowerings with better implementations.
+
+ // 2D<->1D casts are special-cased by lowerings with better implementations.
+ // TODO: make it ND/1D to allow generic ND->1D->MD.
int64_t srcRank = sourceVectorType.getRank();
int64_t resRank = resultVectorType.getRank();
if ((srcRank == 2 && resRank == 1) || (srcRank == 1 && resRank == 2))
return failure();
+
+ // Generic ShapeCast lowering path goes all the way down to unrolled scalar
+ // extract/insert chains.
+ // TODO: consider evolving the semantics to only allow 1D source or dest and
+ // drop this potentially very expensive lowering.
// Compute number of elements involved in the reshape.
int64_t numElts = 1;
for (int64_t r = 0; r < srcRank; r++)
@@ -3177,6 +3184,63 @@ class CastAwayElementwiseLeadingOneDim : public RewritePattern {
}
};
+// If extractOp's result has as many elements as its source, it only drops
+// unit dimensions and can be rewritten as a shapecast.
+class ExtractToShapeCast final : public OpRewritePattern<ExtractOp> {
+public:
+ using OpRewritePattern<ExtractOp>::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(ExtractOp extractOp,
+ PatternRewriter &rewriter) const override {
+ auto dstVecType = extractOp.getResult().getType().dyn_cast<VectorType>();
+ if (!dstVecType || extractOp.getVectorType().getNumElements() !=
+ dstVecType.getNumElements())
+ return failure();
+ rewriter.replaceOpWithNewOp<ShapeCastOp>(extractOp, dstVecType,
+ extractOp.vector());
+ return success();
+ }
+};
+
+// If insertOp's source has as many elements as its destination, it overwrites
+// the whole destination and can be rewritten as a shapecast of the source.
+class InsertToShapeCast final : public OpRewritePattern<InsertOp> {
+public:
+ using OpRewritePattern<InsertOp>::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(InsertOp insertOp,
+ PatternRewriter &rewriter) const override {
+ auto srcVecType = insertOp.getSourceType().dyn_cast<VectorType>();
+ if (!srcVecType || insertOp.getDestVectorType().getNumElements() !=
+ srcVecType.getNumElements())
+ return failure();
+ rewriter.replaceOpWithNewOp<ShapeCastOp>(
+ insertOp, insertOp.getDestVectorType(), insertOp.source());
+ return success();
+ }
+};
+
+// BroadcastOp can only add dimensions or broadcast a dimension from 1 to N. In
+// the degenerate case where the broadcast only adds dimensions of size 1 it
+// can be replaced by a ShapeCastOp. This pattern checks whether the total
+// number of elements is the same before and after the broadcast to detect if
+// the only change in the vector type is the addition of size-1 dimensions.
+class BroadcastToShapeCast final : public OpRewritePattern<BroadcastOp> {
+public:
+ using OpRewritePattern<BroadcastOp>::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(BroadcastOp broadcastOp,
+ PatternRewriter &rewriter) const override {
+ auto srcVecType = broadcastOp.getSourceType().dyn_cast<VectorType>();
+ if (!srcVecType || broadcastOp.getVectorType().getNumElements() !=
+ srcVecType.getNumElements())
+ return failure();
+ rewriter.replaceOpWithNewOp<ShapeCastOp>(
+ broadcastOp, broadcastOp.getVectorType(), broadcastOp.source());
+ return success();
+ }
+};
+
// Returns the values in `arrayAttr` as an integer vector.
static SmallVector<int64_t, 4> getIntValueVector(ArrayAttr arrayAttr) {
return llvm::to_vector<4>(
@@ -3651,16 +3715,21 @@ void mlir::vector::populatePropagateVectorDistributionPatterns(
patterns.getContext());
}
+void mlir::vector::populateShapeCastFoldingPatterns(
+ RewritePatternSet &patterns) {
+ patterns.add<ShapeCastOpFolder>(patterns.getContext());
+}
+
void mlir::vector::populateCastAwayVectorLeadingOneDimPatterns(
RewritePatternSet &patterns) {
- patterns.add<CastAwayExtractStridedSliceLeadingOneDim,
- CastAwayInsertStridedSliceLeadingOneDim,
- CastAwayTransferReadLeadingOneDim,
- CastAwayTransferWriteLeadingOneDim,
- CastAwayBroadcastLeadingOneDim<vector::BroadcastOp>,
- CastAwayBroadcastLeadingOneDim<SplatOp>,
- CastAwayElementwiseLeadingOneDim, ShapeCastOpFolder>(
- patterns.getContext());
+ patterns.add<
+ BroadcastToShapeCast, CastAwayExtractStridedSliceLeadingOneDim,
+ CastAwayInsertStridedSliceLeadingOneDim,
+ CastAwayTransferReadLeadingOneDim, CastAwayTransferWriteLeadingOneDim,
+ CastAwayBroadcastLeadingOneDim<vector::BroadcastOp>,
+ CastAwayBroadcastLeadingOneDim<SplatOp>, CastAwayElementwiseLeadingOneDim,
+ ExtractToShapeCast, InsertToShapeCast>(patterns.getContext());
+ populateShapeCastFoldingPatterns(patterns);
}
void mlir::vector::populateBubbleVectorBitCastOpPatterns(
diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir
index cf05308e8129b..3d60745b8ccd7 100644
--- a/mlir/test/Dialect/Vector/canonicalize.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize.mlir
@@ -717,16 +717,6 @@ func @consecutive_shape_cast(%arg0: vector<16xf16>) -> vector<4x4xf16> {
// -----
-// CHECK-LABEL: broadcast_to_shapecast
-// CHECK: %[[C:.*]] = vector.shape_cast %{{.*}} : vector<4x4xf16> to vector<1x4x4xf16>
-// CHECK-NEXT: return %[[C]] : vector<1x4x4xf16>
-func @broadcast_to_shapecast(%arg0: vector<4x4xf16>) -> vector<1x4x4xf16> {
- %0 = vector.broadcast %arg0 : vector<4x4xf16> to vector<1x4x4xf16>
- return %0 : vector<1x4x4xf16>
-}
-
-// -----
-
// CHECK-LABEL: func @dead_transfer_op
// CHECK-NOT: vector.transfer_read
// CHECK-NOT: vector.transfer_write
@@ -971,20 +961,6 @@ func @dead_store_tensor_negative(%arg0 : tensor<4x4xf32>,
// -----
-// CHECK-LABEL: func @insert_extract_to_shapecast
-// CHECK-SAME: (%[[ARG0:.*]]: vector<1x1x4xf32>, %[[ARG1:.*]]: vector<4xf32>)
-// CHECK: %[[V0:.*]] = vector.shape_cast %[[ARG0]] : vector<1x1x4xf32> to vector<4xf32>
-// CHECK: %[[V1:.*]] = vector.shape_cast %[[ARG1]] : vector<4xf32> to vector<1x1x4xf32>
-// CHECK: return %[[V0]], %[[V1]] : vector<4xf32>, vector<1x1x4xf32>
-func @insert_extract_to_shapecast(%arg0 : vector<1x1x4xf32>,
- %arg1 : vector<4xf32>) -> (vector<4xf32>, vector<1x1x4xf32>) {
- %0 = vector.extract %arg0[0, 0] : vector<1x1x4xf32>
- %1 = vector.insert %arg1, %arg0 [0, 0] : vector<4xf32> into vector<1x1x4xf32>
- return %0, %1 : vector<4xf32>, vector<1x1x4xf32>
-}
-
-// -----
-
// CHECK-LABEL: func @transfer_read_of_extract_slice(
// CHECK-SAME: %[[t:.*]]: tensor<?x?xf32>, %[[s1:.*]]: index, %[[s2:.*]]: index
// CHECK-DAG: %[[c4:.*]] = arith.constant 4 : index
diff --git a/mlir/test/Dialect/Vector/vector-dim-one-shape-cast.mlir b/mlir/test/Dialect/Vector/vector-dim-one-shape-cast.mlir
new file mode 100644
index 0000000000000..5dd44d38ccb7e
--- /dev/null
+++ b/mlir/test/Dialect/Vector/vector-dim-one-shape-cast.mlir
@@ -0,0 +1,23 @@
+// RUN: mlir-opt %s -test-vector-to-vector-lowering -split-input-file | FileCheck %s
+
+// CHECK-LABEL: broadcast_to_shapecast
+// CHECK: %[[C:.*]] = vector.shape_cast %{{.*}} : vector<4x4xf16> to vector<1x4x4xf16>
+// CHECK-NEXT: return %[[C]] : vector<1x4x4xf16>
+func @broadcast_to_shapecast(%arg0: vector<4x4xf16>) -> vector<1x4x4xf16> {
+ %0 = vector.broadcast %arg0 : vector<4x4xf16> to vector<1x4x4xf16>
+ return %0 : vector<1x4x4xf16>
+}
+
+// -----
+
+// CHECK-LABEL: func @insert_extract_to_shapecast
+// CHECK-SAME: (%[[ARG0:.*]]: vector<1x1x4xf32>, %[[ARG1:.*]]: vector<4xf32>)
+// CHECK: %[[V0:.*]] = vector.shape_cast %[[ARG0]] : vector<1x1x4xf32> to vector<4xf32>
+// CHECK: %[[V1:.*]] = vector.shape_cast %[[ARG1]] : vector<4xf32> to vector<1x1x4xf32>
+// CHECK: return %[[V0]], %[[V1]] : vector<4xf32>, vector<1x1x4xf32>
+func @insert_extract_to_shapecast(%arg0 : vector<1x1x4xf32>,
+ %arg1 : vector<4xf32>) -> (vector<4xf32>, vector<1x1x4xf32>) {
+ %0 = vector.extract %arg0[0, 0] : vector<1x1x4xf32>
+ %1 = vector.insert %arg1, %arg0 [0, 0] : vector<4xf32> into vector<1x1x4xf32>
+ return %0, %1 : vector<4xf32>, vector<1x1x4xf32>
+}
More information about the Mlir-commits
mailing list