[Mlir-commits] [mlir] 813bfff - [mlir][VectorOps] Adds canonicalization rewrite patterns for vector ShapeCastOp.
Andy Davis
llvmlistbot at llvm.org
Tue Feb 11 13:11:53 PST 2020
Author: Andy Davis
Date: 2020-02-11T13:11:45-08:00
New Revision: 813bfffec34b87d32c9c834718f660afb5275bc8
URL: https://github.com/llvm/llvm-project/commit/813bfffec34b87d32c9c834718f660afb5275bc8
DIFF: https://github.com/llvm/llvm-project/commit/813bfffec34b87d32c9c834718f660afb5275bc8.diff
LOG: [mlir][VectorOps] Adds canonicalization rewrite patterns for vector ShapeCastOp.
Summary:
Adds two rewrite patterns for the vector ShapeCastOp.
*) ShapeCastOp decomposer: decomposes a ShapeCastOp on a tuple of vectors into multiple ShapeCastOps, each on a single vector type.
*) ShapeCastOp folder: folds away cancelling shape cast ops (e.g. shape_cast A -> B followed by shape_cast B -> A).
Reviewers: nicolasvasilache, aartbik
Reviewed By: nicolasvasilache
Subscribers: mehdi_amini, rriddle, jpienaar, burmako, shauheen, antiagainst, arpith-jacob, mgester, lucyrfox, liufengdb, Joonsoo, llvm-commits
Tags: #llvm
Differential Revision: https://reviews.llvm.org/D74327
Added:
Modified:
mlir/lib/Dialect/VectorOps/VectorTransforms.cpp
mlir/test/Dialect/VectorOps/vector-transforms.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/VectorOps/VectorTransforms.cpp b/mlir/lib/Dialect/VectorOps/VectorTransforms.cpp
index 440c7707ce75..8bdeb92afe5e 100644
--- a/mlir/lib/Dialect/VectorOps/VectorTransforms.cpp
+++ b/mlir/lib/Dialect/VectorOps/VectorTransforms.cpp
@@ -646,6 +646,90 @@ struct SplitTransferWriteOp : public OpRewritePattern<vector::TransferWriteOp> {
}
};
+/// Rewrites a ShapeCastOp whose source and result are both tuples of vectors
+/// into one ShapeCastOp per tuple element. The per-element results are then
+/// re-packed with a TupleOp that replaces the original op.
+struct ShapeCastOpDecomposer : public OpRewritePattern<vector::ShapeCastOp> {
+  using OpRewritePattern<vector::ShapeCastOp>::OpRewritePattern;
+
+  PatternMatchResult matchAndRewrite(vector::ShapeCastOp shapeCastOp,
+                                     PatternRewriter &rewriter) const override {
+    Value src = shapeCastOp.source();
+
+    // Only tuple -> tuple casts are handled; vector casts are left untouched.
+    auto srcTuple = src.getType().dyn_cast_or_null<TupleType>();
+    if (!srcTuple)
+      return matchFailure();
+    auto dstTuple =
+        shapeCastOp.result().getType().dyn_cast_or_null<TupleType>();
+    if (!dstTuple)
+      return matchFailure();
+
+    unsigned numElements = srcTuple.size();
+    assert(numElements == dstTuple.size());
+
+    Location loc = shapeCastOp.getLoc();
+    SmallVector<Value, 8> castElements;
+    castElements.reserve(dstTuple.size());
+    for (unsigned idx = 0; idx < numElements; ++idx) {
+      // Pull element 'idx' out of the source tuple...
+      Value srcElement = rewriter.create<vector::TupleGetOp>(
+          loc, srcTuple.getType(idx), src, rewriter.getI64IntegerAttr(idx));
+      // ...and shape-cast it to the matching result element type.
+      Value castElement = rewriter.create<vector::ShapeCastOp>(
+          loc, dstTuple.getType(idx), srcElement);
+      castElements.push_back(castElement);
+    }
+
+    // The tuple of per-element casts takes the place of the original op.
+    rewriter.replaceOpWithNewOp<vector::TupleOp>(shapeCastOp, dstTuple,
+                                                 castElements);
+    return matchSuccess();
+  }
+};
+
+/// ShapeCastOpFolder folds cancelling ShapeCastOps away.
+//
+// Example:
+//
+// The following MLIR with cancelling ShapeCastOps:
+//
+// %0 = source : vector<5x4x2xf32>
+// %1 = shape_cast %0 : vector<5x4x2xf32> to vector<20x2xf32>
+// %2 = shape_cast %1 : vector<20x2xf32> to vector<5x4x2xf32>
+// %3 = user %2 : vector<5x4x2xf32>
+//
+// Should canonicalize to the following:
+//
+// %0 = source : vector<5x4x2xf32>
+// %1 = user %0 : vector<5x4x2xf32>
+//
+// The match compares types directly and does not inspect them. Any pair of
+// mutually inverse shape_casts is therefore folded, whether it acts on
+// vectors or on tuples of vectors.
+struct ShapeCastOpFolder : public OpRewritePattern<vector::ShapeCastOp> {
+  using OpRewritePattern<vector::ShapeCastOp>::OpRewritePattern;
+
+  PatternMatchResult matchAndRewrite(vector::ShapeCastOp shapeCastOp,
+                                     PatternRewriter &rewriter) const override {
+    // The source operand must itself be produced by a shape cast op.
+    auto sourceShapeCastOp = dyn_cast_or_null<vector::ShapeCastOp>(
+        shapeCastOp.source().getDefiningOp());
+    if (!sourceShapeCastOp)
+      return matchFailure();
+
+    // The producer's result *is* this op's source, so the two result/source
+    // types always agree. The casts cancel exactly when the producer's
+    // source type equals this op's result type.
+    if (sourceShapeCastOp.source().getType() != shapeCastOp.result().getType())
+      return matchFailure();
+
+    rewriter.replaceOp(shapeCastOp, sourceShapeCastOp.source());
+    return matchSuccess();
+  }
+};
+
// Patter rewrite which forward tuple elements to their users.
// User(TupleGetOp(ExtractSlicesOp(InsertSlicesOp(TupleOp(Producer)))))
// -> User(Producer)
@@ -784,8 +868,8 @@ class InsertSlicesOpLowering : public OpRewritePattern<vector::InsertSlicesOp> {
// TODO(andydavis) Add this as DRR pattern.
void mlir::vector::populateVectorToVectorTransformationPatterns(
OwningRewritePatternList &patterns, MLIRContext *context) {
- patterns.insert<SplitTransferReadOp, SplitTransferWriteOp, TupleGetFolderOp>(
- context);
+ // ShapeCastOpDecomposer splits tuple-of-vector shape_casts per element;
+ // ShapeCastOpFolder removes pairs of mutually inverse shape_casts.
+ patterns.insert<ShapeCastOpDecomposer, ShapeCastOpFolder, SplitTransferReadOp,
+ SplitTransferWriteOp, TupleGetFolderOp>(context);
}
void mlir::vector::populateVectorSlicesLoweringPatterns(
diff --git a/mlir/test/Dialect/VectorOps/vector-transforms.mlir b/mlir/test/Dialect/VectorOps/vector-transforms.mlir
index 1153ffb2999f..7582758384ce 100644
--- a/mlir/test/Dialect/VectorOps/vector-transforms.mlir
+++ b/mlir/test/Dialect/VectorOps/vector-transforms.mlir
@@ -346,3 +346,62 @@ func @vector_transfers_vector_element_type() {
return
}
+
+// Test that a ShapeCastOp on a tuple of vectors decomposes into multiple
+// ShapeCastOps on vectors. The surrounding vector.tuple / vector.tuple_get
+// ops are expected to fold away entirely. The directives below require the
+// return to directly follow the two per-vector shape_casts.
+// CHECK-LABEL: func @shape_cast_decomposition
+// CHECK: %[[V0:.*]] = vector.shape_cast %{{.*}} : vector<5x4x2xf32> to vector<20x2xf32>
+// CHECK-NEXT: %[[V1:.*]] = vector.shape_cast %{{.*}} : vector<3x4x2xf32> to vector<12x2xf32>
+// CHECK-NEXT: return %[[V0]], %[[V1]] : vector<20x2xf32>, vector<12x2xf32>
+
+func @shape_cast_decomposition(%arg0 : vector<5x4x2xf32>,
+ %arg1 : vector<3x4x2xf32>)
+ -> (vector<20x2xf32>, vector<12x2xf32>) {
+ %0 = vector.tuple %arg0, %arg1 : vector<5x4x2xf32>, vector<3x4x2xf32>
+ %1 = vector.shape_cast %0 : tuple<vector<5x4x2xf32>, vector<3x4x2xf32>> to
+ tuple<vector<20x2xf32>, vector<12x2xf32>>
+ %2 = vector.tuple_get %1, 0 : tuple<vector<20x2xf32>, vector<12x2xf32>>
+ %3 = vector.tuple_get %1, 1 : tuple<vector<20x2xf32>, vector<12x2xf32>>
+ return %2, %3 : vector<20x2xf32>, vector<12x2xf32>
+}
+
+// Test that cancelling ShapeCastOps are canonicalized away.
+// Example:
+//
+// The following MLIR with cancelling ShapeCastOps:
+//
+// %0 = source : vector<5x4x2xf32>
+// %1 = shape_cast %0 : vector<5x4x2xf32> to vector<20x2xf32>
+// %2 = shape_cast %1 : vector<20x2xf32> to vector<5x4x2xf32>
+// %3 = user %2 : vector<5x4x2xf32>
+//
+// Should canonicalize to the following:
+//
+// %0 = source : vector<5x4x2xf32>
+// %1 = user %0 : vector<5x4x2xf32>
+//
+// In the function below, the tuple-typed shape_casts are first decomposed
+// into per-vector shape_casts, which then cancel pairwise. The function
+// arguments must therefore be returned directly, with no shape_cast left.
+
+// CHECK-LABEL: func @shape_cast_fold
+// CHECK-SAME: %[[A0:.*]]: vector<5x4x2xf32>, %[[A1:.*]]: vector<3x4x2xf32>
+// CHECK-NOT: vector.shape_cast
+// CHECK: return %[[A0]], %[[A1]] : vector<5x4x2xf32>, vector<3x4x2xf32>
+
+func @shape_cast_fold(%arg0 : vector<5x4x2xf32>, %arg1 : vector<3x4x2xf32>)
+ -> (vector<5x4x2xf32>, vector<3x4x2xf32>) {
+ %0 = vector.tuple %arg0, %arg1 : vector<5x4x2xf32>, vector<3x4x2xf32>
+
+ %1 = vector.shape_cast %0 : tuple<vector<5x4x2xf32>, vector<3x4x2xf32>> to
+ tuple<vector<20x2xf32>, vector<12x2xf32>>
+
+ %2 = vector.tuple_get %1, 0 : tuple<vector<20x2xf32>, vector<12x2xf32>>
+ %3 = vector.tuple_get %1, 1 : tuple<vector<20x2xf32>, vector<12x2xf32>>
+
+ %4 = vector.tuple %2, %3 : vector<20x2xf32>, vector<12x2xf32>
+ %5 = vector.shape_cast %4 : tuple<vector<20x2xf32>, vector<12x2xf32>> to
+ tuple<vector<5x4x2xf32>, vector<3x4x2xf32>>
+
+ %6 = vector.tuple_get %5, 0 : tuple<vector<5x4x2xf32>, vector<3x4x2xf32>>
+ %7 = vector.tuple_get %5, 1 : tuple<vector<5x4x2xf32>, vector<3x4x2xf32>>
+
+ return %6, %7 : vector<5x4x2xf32>, vector<3x4x2xf32>
+}
More information about the Mlir-commits
mailing list