[Mlir-commits] [mlir] 3a51920 - [mlir][VectorOps] Implement canonicalization for TransposeOp.
Alex Grosul
llvmlistbot at llvm.org
Thu Apr 2 18:37:28 PDT 2020
Author: Alex Grosul
Date: 2020-04-02T18:36:40-07:00
New Revision: 3a5192098c5e0d6d5ef8b74b233ef65996288c11
URL: https://github.com/llvm/llvm-project/commit/3a5192098c5e0d6d5ef8b74b233ef65996288c11
DIFF: https://github.com/llvm/llvm-project/commit/3a5192098c5e0d6d5ef8b74b233ef65996288c11.diff
LOG: [mlir][VectorOps] Implement canonicalization for TransposeOp.
Two back-to-back transpose operations are combined into a single transpose whose permutation vector is the composition of the two original permutation vectors.
Differential Revision: https://reviews.llvm.org/D77331
Added:
Modified:
mlir/include/mlir/Dialect/Vector/VectorOps.td
mlir/lib/Dialect/Vector/VectorOps.cpp
mlir/test/Dialect/Vector/canonicalize.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Vector/VectorOps.td b/mlir/include/mlir/Dialect/Vector/VectorOps.td
index a0ad88347bd9..aac766f38a9c 100644
--- a/mlir/include/mlir/Dialect/Vector/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/VectorOps.td
@@ -1320,6 +1320,7 @@ def Vector_TransposeOp :
let assemblyFormat = [{
$vector `,` $transp attr-dict `:` type($vector) `to` type($result)
}];
+ let hasCanonicalizer = 1;
let hasFolder = 1;
}
diff --git a/mlir/lib/Dialect/Vector/VectorOps.cpp b/mlir/lib/Dialect/Vector/VectorOps.cpp
index 5084bc6ece0c..19aa52f8a461 100644
--- a/mlir/lib/Dialect/Vector/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/VectorOps.cpp
@@ -1566,6 +1566,55 @@ static LogicalResult verify(TransposeOp op) {
return success();
}
+namespace {
+
+// Rewrites two back-to-back TransposeOp operations into a single TransposeOp.
+class TransposeFolder final : public OpRewritePattern<TransposeOp> {
+public:
+ using OpRewritePattern<TransposeOp>::OpRewritePattern;
+ // Matches transpose(transpose(x, p1), p2); emits transpose(x, p), p[i] = p1[p2[i]].
+ LogicalResult matchAndRewrite(TransposeOp transposeOp,
+ PatternRewriter &rewriter) const override {
+ // Returns the op's `transp` permutation by value (wraps TransposeOp::getTransp()).
+ auto getPermutation = [](TransposeOp transpose) {
+ SmallVector<int64_t, 4> permutation;
+ transpose.getTransp(permutation);
+ return permutation;
+ };
+
+ // Composes two permutations: result[i] = permutation1[permutation2[i]].
+ auto composePermutations = [](ArrayRef<int64_t> permutation1,
+ ArrayRef<int64_t> permutation2) {
+ SmallVector<int64_t, 4> result;
+ for (auto index : permutation2)
+ result.push_back(permutation1[index]);
+ return result;
+ };
+
+ // Fail to match unless the input of 'transposeOp' is produced by a TransposeOp.
+ TransposeOp parentTransposeOp =
+ dyn_cast_or_null<TransposeOp>(transposeOp.vector().getDefiningOp());
+ if (!parentTransposeOp)
+ return failure();
+
+ SmallVector<int64_t, 4> permutation = composePermutations(
+ getPermutation(parentTransposeOp), getPermutation(transposeOp));
+ // Replace 'transposeOp' with one transpose of the parent's input; the parent op is not erased here.
+ rewriter.replaceOpWithNewOp<TransposeOp>(
+ transposeOp, transposeOp.getResult().getType(),
+ parentTransposeOp.vector(),
+ vector::getVectorSubscriptAttr(rewriter, permutation));
+ return success();
+ }
+};
+
+} // end anonymous namespace
+// Registers TransposeFolder so canonicalization merges chained TransposeOps.
+void TransposeOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
+ MLIRContext *context) {
+ results.insert<TransposeFolder>(context);
+}
+
void TransposeOp::getTransp(SmallVectorImpl<int64_t> &results) {
populateFromInt64AttrArray(transp(), results);
}
@@ -1704,7 +1753,8 @@ void CreateMaskOp::getCanonicalizationPatterns(
void mlir::vector::populateVectorToVectorCanonicalizationPatterns(
OwningRewritePatternList &patterns, MLIRContext *context) {
- patterns.insert<CreateMaskFolder, StridedSliceConstantMaskFolder>(context);
+ patterns.insert<CreateMaskFolder, StridedSliceConstantMaskFolder,
+ TransposeFolder>(context); // Same pattern TransposeOp registers as its canonicalizer.
}
namespace mlir {
diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir
index 3ff4052a22b8..61163a704d64 100644
--- a/mlir/test/Dialect/Vector/canonicalize.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize.mlir
@@ -125,34 +125,37 @@ func @transpose_3D_identity(%arg : vector<4x3x2xf32>) -> vector<4x3x2xf32> {
// CHECK-LABEL: transpose_2D_sequence
// CHECK-SAME: ([[ARG:%.*]]: vector<4x3xf32>)
-func @transpose_2D_sequence(%arg : vector<4x3xf32>) -> vector<3x4xf32> {
+func @transpose_2D_sequence(%arg : vector<4x3xf32>) -> vector<4x3xf32> { // every transpose below folds away
// CHECK-NOT: transpose
- %0 = vector.transpose %arg, [0, 1] : vector<4x3xf32> to vector<4x3xf32>
- // CHECK: [[T1:%.*]] = vector.transpose [[ARG]], [1, 0]
- %1 = vector.transpose %0, [1, 0] : vector<4x3xf32> to vector<3x4xf32>
- // CHECK-NOT: transpose
- %2 = vector.transpose %1, [0, 1] : vector<3x4xf32> to vector<3x4xf32>
- // CHECK: [[ADD:%.*]] = addf [[T1]], [[T1]]
- %4 = addf %1, %2 : vector<3x4xf32>
+ %0 = vector.transpose %arg, [1, 0] : vector<4x3xf32> to vector<3x4xf32>
+ %1 = vector.transpose %0, [0, 1] : vector<3x4xf32> to vector<3x4xf32>
+ %2 = vector.transpose %1, [1, 0] : vector<3x4xf32> to vector<4x3xf32>
+ %3 = vector.transpose %2, [0, 1] : vector<4x3xf32> to vector<4x3xf32>
+ // CHECK: [[ADD:%.*]] = addf [[ARG]], [[ARG]]
+ %4 = addf %2, %3 : vector<4x3xf32>
// CHECK-NEXT: return [[ADD]]
- return %4 : vector<3x4xf32>
+ return %4 : vector<4x3xf32>
}
// -----
// CHECK-LABEL: transpose_3D_sequence
// CHECK-SAME: ([[ARG:%.*]]: vector<4x3x2xf32>)
-func @transpose_3D_sequence(%arg : vector<4x3x2xf32>) -> vector<2x3x4xf32> {
- // CHECK: [[T0:%.*]] = vector.transpose [[ARG]], [1, 2, 0]
+func @transpose_3D_sequence(%arg : vector<4x3x2xf32>) -> vector<4x3x2xf32> { // chained transposes compose; two remain
+ // CHECK: [[T0:%.*]] = vector.transpose [[ARG]], [2, 1, 0]
%0 = vector.transpose %arg, [1, 2, 0] : vector<4x3x2xf32> to vector<3x2x4xf32>
+ %1 = vector.transpose %0, [1, 0, 2] : vector<3x2x4xf32> to vector<2x3x4xf32>
// CHECK-NOT: transpose
- %1 = vector.transpose %0, [0, 1, 2] : vector<3x2x4xf32> to vector<3x2x4xf32>
- // CHECK: [[T2:%.*]] = vector.transpose [[T0]], [1, 0, 2]
- %2 = vector.transpose %1, [1, 0, 2] : vector<3x2x4xf32> to vector<2x3x4xf32>
- // CHECK: [[ADD:%.*]] = addf [[T2]], [[T2]]
- %3 = addf %2, %2 : vector<2x3x4xf32>
+ %2 = vector.transpose %1, [2, 1, 0] : vector<2x3x4xf32> to vector<4x3x2xf32>
+ %3 = vector.transpose %2, [2, 1, 0] : vector<4x3x2xf32> to vector<2x3x4xf32>
+ // CHECK: [[MUL:%.*]] = mulf [[T0]], [[T0]]
+ %4 = mulf %1, %3 : vector<2x3x4xf32>
+ // CHECK: [[T5:%.*]] = vector.transpose [[MUL]], [2, 1, 0]
+ %5 = vector.transpose %4, [2, 1, 0] : vector<2x3x4xf32> to vector<4x3x2xf32>
// CHECK-NOT: transpose
- %4 = vector.transpose %3, [0, 1, 2] : vector<2x3x4xf32> to vector<2x3x4xf32>
+ %6 = vector.transpose %3, [2, 1, 0] : vector<2x3x4xf32> to vector<4x3x2xf32>
+ // CHECK: [[ADD:%.*]] = addf [[T5]], [[ARG]]
+ %7 = addf %5, %6 : vector<4x3x2xf32>
// CHECK-NEXT: return [[ADD]]
- return %4 : vector<2x3x4xf32>
+ return %7 : vector<4x3x2xf32>
}
More information about the Mlir-commits mailing list