[Mlir-commits] [mlir] 3a51920 - [mlir][VectorOps] Implement canonicalization for TransposeOp.

Alex Grosul llvmlistbot at llvm.org
Thu Apr 2 18:37:28 PDT 2020


Author: Alex Grosul
Date: 2020-04-02T18:36:40-07:00
New Revision: 3a5192098c5e0d6d5ef8b74b233ef65996288c11

URL: https://github.com/llvm/llvm-project/commit/3a5192098c5e0d6d5ef8b74b233ef65996288c11
DIFF: https://github.com/llvm/llvm-project/commit/3a5192098c5e0d6d5ef8b74b233ef65996288c11.diff

LOG: [mlir][VectorOps] Implement canonicalization for TransposeOp.

Two back-to-back transpose operations are combined into a single transpose whose permutation vector is the composition of their two permutation vectors.

Differential Revision: https://reviews.llvm.org/D77331

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Vector/VectorOps.td
    mlir/lib/Dialect/Vector/VectorOps.cpp
    mlir/test/Dialect/Vector/canonicalize.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Vector/VectorOps.td b/mlir/include/mlir/Dialect/Vector/VectorOps.td
index a0ad88347bd9..aac766f38a9c 100644
--- a/mlir/include/mlir/Dialect/Vector/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/VectorOps.td
@@ -1320,6 +1320,7 @@ def Vector_TransposeOp :
   let assemblyFormat = [{
     $vector `,` $transp attr-dict `:` type($vector) `to` type($result)
   }];
+  let hasCanonicalizer = 1;
   let hasFolder = 1;
 }
 

diff  --git a/mlir/lib/Dialect/Vector/VectorOps.cpp b/mlir/lib/Dialect/Vector/VectorOps.cpp
index 5084bc6ece0c..19aa52f8a461 100644
--- a/mlir/lib/Dialect/Vector/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/VectorOps.cpp
@@ -1566,6 +1566,55 @@ static LogicalResult verify(TransposeOp op) {
   return success();
 }
 
+namespace {
+
+// Combines transpose(transpose(x, p1), p2) into a single transpose of x.
+class TransposeFolder final : public OpRewritePattern<TransposeOp> {
+public:
+  using OpRewritePattern<TransposeOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(TransposeOp transposeOp,
+                                PatternRewriter &rewriter) const override {
+    // Returns the 'transp' permutation of 'transpose' as a SmallVector.
+    auto getPermutation = [](TransposeOp transpose) {
+      SmallVector<int64_t, 4> permutation;
+      transpose.getTransp(permutation);
+      return permutation;
+    };
+
+    // Composes two permutations: result[i] = permutation1[permutation2[i]].
+    auto composePermutations = [](ArrayRef<int64_t> permutation1,
+                                  ArrayRef<int64_t> permutation2) {
+      SmallVector<int64_t, 4> result;
+      for (auto index : permutation2)
+        result.push_back(permutation1[index]);
+      return result;
+    };
+
+    // Bail out unless the input of 'transposeOp' is produced by a TransposeOp.
+    TransposeOp parentTransposeOp =
+        dyn_cast_or_null<TransposeOp>(transposeOp.vector().getDefiningOp());
+    if (!parentTransposeOp)
+      return failure();
+
+    SmallVector<int64_t, 4> permutation = composePermutations(
+        getPermutation(parentTransposeOp), getPermutation(transposeOp));
+    // Replace 'transposeOp' with one transpose of the parent's input vector.
+    rewriter.replaceOpWithNewOp<TransposeOp>(
+        transposeOp, transposeOp.getResult().getType(),
+        parentTransposeOp.vector(),
+        vector::getVectorSubscriptAttr(rewriter, permutation));
+    return success();
+  }
+};
+
+} // end anonymous namespace
+
+void TransposeOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
+                                              MLIRContext *context) {
+  results.insert<TransposeFolder>(context);
+}
+
 void TransposeOp::getTransp(SmallVectorImpl<int64_t> &results) {
   populateFromInt64AttrArray(transp(), results);
 }
@@ -1704,7 +1753,8 @@ void CreateMaskOp::getCanonicalizationPatterns(
 
 void mlir::vector::populateVectorToVectorCanonicalizationPatterns(
     OwningRewritePatternList &patterns, MLIRContext *context) {
-  patterns.insert<CreateMaskFolder, StridedSliceConstantMaskFolder>(context);
+  patterns.insert<CreateMaskFolder, StridedSliceConstantMaskFolder,
+                  TransposeFolder>(context);
 }
 
 namespace mlir {

diff  --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir
index 3ff4052a22b8..61163a704d64 100644
--- a/mlir/test/Dialect/Vector/canonicalize.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize.mlir
@@ -125,34 +125,37 @@ func @transpose_3D_identity(%arg : vector<4x3x2xf32>) -> vector<4x3x2xf32> {
 
 // CHECK-LABEL: transpose_2D_sequence
 // CHECK-SAME: ([[ARG:%.*]]: vector<4x3xf32>)
-func @transpose_2D_sequence(%arg : vector<4x3xf32>) -> vector<3x4xf32> {
+func @transpose_2D_sequence(%arg : vector<4x3xf32>) -> vector<4x3xf32> {
   // CHECK-NOT: transpose
-  %0 = vector.transpose %arg, [0, 1] : vector<4x3xf32> to vector<4x3xf32>
-  // CHECK: [[T1:%.*]] = vector.transpose [[ARG]], [1, 0]
-  %1 = vector.transpose %0, [1, 0] : vector<4x3xf32> to vector<3x4xf32>
-  // CHECK-NOT: transpose
-  %2 = vector.transpose %1, [0, 1] : vector<3x4xf32> to vector<3x4xf32>
-  // CHECK: [[ADD:%.*]] = addf [[T1]], [[T1]]
-  %4 = addf %1, %2 : vector<3x4xf32>
+  %0 = vector.transpose %arg, [1, 0] : vector<4x3xf32> to vector<3x4xf32>
+  %1 = vector.transpose %0, [0, 1] : vector<3x4xf32> to vector<3x4xf32>
+  %2 = vector.transpose %1, [1, 0] : vector<3x4xf32> to vector<4x3xf32>
+  %3 = vector.transpose %2, [0, 1] : vector<4x3xf32> to vector<4x3xf32>
+  // CHECK: [[ADD:%.*]] = addf [[ARG]], [[ARG]]
+  %4 = addf %2, %3 : vector<4x3xf32>
   // CHECK-NEXT: return [[ADD]]
-  return %4 : vector<3x4xf32>
+  return %4 : vector<4x3xf32>
 }
 
 // -----
 
 // CHECK-LABEL: transpose_3D_sequence
 // CHECK-SAME: ([[ARG:%.*]]: vector<4x3x2xf32>)
-func @transpose_3D_sequence(%arg : vector<4x3x2xf32>) -> vector<2x3x4xf32> {
-  // CHECK: [[T0:%.*]] = vector.transpose [[ARG]], [1, 2, 0]
+func @transpose_3D_sequence(%arg : vector<4x3x2xf32>) -> vector<4x3x2xf32> {
+  // CHECK: [[T0:%.*]] = vector.transpose [[ARG]], [2, 1, 0]
   %0 = vector.transpose %arg, [1, 2, 0] : vector<4x3x2xf32> to vector<3x2x4xf32>
+  %1 = vector.transpose %0, [1, 0, 2] : vector<3x2x4xf32> to vector<2x3x4xf32>
   // CHECK-NOT: transpose
-  %1 = vector.transpose %0, [0, 1, 2] : vector<3x2x4xf32> to vector<3x2x4xf32>
-  // CHECK: [[T2:%.*]] = vector.transpose [[T0]], [1, 0, 2]
-  %2 = vector.transpose %1, [1, 0, 2] : vector<3x2x4xf32> to vector<2x3x4xf32>
-  // CHECK: [[ADD:%.*]] = addf [[T2]], [[T2]]
-  %3 = addf %2, %2 : vector<2x3x4xf32>
+  %2 = vector.transpose %1, [2, 1, 0] : vector<2x3x4xf32> to vector<4x3x2xf32>
+  %3 = vector.transpose %2, [2, 1, 0] : vector<4x3x2xf32> to vector<2x3x4xf32>
+  // CHECK: [[MUL:%.*]] = mulf [[T0]], [[T0]]
+  %4 = mulf %1, %3 : vector<2x3x4xf32>
+  // CHECK: [[T5:%.*]] = vector.transpose [[MUL]], [2, 1, 0]
+  %5 = vector.transpose %4, [2, 1, 0] : vector<2x3x4xf32> to vector<4x3x2xf32>
   // CHECK-NOT: transpose
-  %4 = vector.transpose %3, [0, 1, 2] : vector<2x3x4xf32> to vector<2x3x4xf32>
+  %6 = vector.transpose %3, [2, 1, 0] : vector<2x3x4xf32> to vector<4x3x2xf32>
+  // CHECK: [[ADD:%.*]] = addf [[T5]], [[ARG]]
+  %7 = addf %5, %6 : vector<4x3x2xf32>
   // CHECK-NEXT: return [[ADD]]
-  return %4 : vector<2x3x4xf32>
+  return %7 : vector<4x3x2xf32>
 }


        


More information about the Mlir-commits mailing list