[Mlir-commits] [mlir] f0c93fd - [mlir][vector] Merge accumulator/result transpose into contract

Lei Zhang llvmlistbot at llvm.org
Fri Oct 7 17:47:33 PDT 2022


Author: Lei Zhang
Date: 2022-10-08T00:43:45Z
New Revision: f0c93fd4cac76d9ae1b2ff99c40c08abeb940f21

URL: https://github.com/llvm/llvm-project/commit/f0c93fd4cac76d9ae1b2ff99c40c08abeb940f21
DIFF: https://github.com/llvm/llvm-project/commit/f0c93fd4cac76d9ae1b2ff99c40c08abeb940f21.diff

LOG: [mlir][vector] Merge accumulator/result transpose into contract

This commit adds a pattern to merge accumulator and result
`vector.transpose` ops into `vector.contract`. This kind of
pattern can be generated for NCHW convolution vectorization,
where we use transposes to convert the 1-D NCW convolution
into NWC during vectorization. Merging the transposes means
we can avoid materializing vector extract/insert ops for the
transposes, and it makes further vector-level transformations
easier.

Reviewed By: ThomasRaoux

Differential Revision: https://reviews.llvm.org/D135496

Added: 
    

Modified: 
    mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
    mlir/test/Dialect/Vector/vector-reduce-to-contract.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
index f1644d01fe70b..570b4b2ee001b 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
@@ -1010,7 +1010,7 @@ struct MultiReduceToContract
   }
 };
 
-/// Merge TransposeOp into ContractionOp user.
+/// Merge LHS/RHS (A/B) TransposeOp into ContractionOp user.
 /// Ex:
 /// ```
 ///   %0 = vector.transpose %arg0, [2, 0, 1]
@@ -1033,7 +1033,7 @@ struct MultiReduceToContract
 ///    kind = add} %arg0, %arg1, %cst_f0
 ///    : vector<8x32x16xf32>, vector<8x32x16xf32> into vector<8x32xf32>
 ///  ```
-struct CombineContractTranspose
+struct CombineContractABTranspose final
     : public OpRewritePattern<vector::ContractionOp> {
   using OpRewritePattern::OpRewritePattern;
 
@@ -1050,8 +1050,6 @@ struct CombineContractTranspose
       auto transposeOp = operand->getDefiningOp<vector::TransposeOp>();
       if (!transposeOp)
         continue;
-      SmallVector<int64_t> perm;
-      transposeOp.getTransp(perm);
       AffineMap permutationMap = AffineMap::getPermutationMap(
           extractVector<unsigned>(transposeOp.getTransp()),
           contractOp.getContext());
@@ -1068,6 +1066,81 @@ struct CombineContractTranspose
   }
 };
 
+/// Merges accumulator and result transposes into contract. Applies only when
+/// the result transpose is the inverse of the accumulator transpose and the
+/// contract has a single use; the combining kind is carried over.
+///
+/// For example:
+/// ```mlir
+/// %accT = vector.transpose %acc, [0, 2, 1]
+///   : vector<2x8x4xf32> to vector<2x4x8xf32>
+/// %contract = vector.contract {
+///   indexing_maps = [
+///     affine_map<(d0, d1, d2, d3) -> (d0, d3, d1)>,
+///     affine_map<(d0, d1, d2, d3) -> (d3, d2)>,
+///     affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
+///   ],
+///   iterator_types = ["parallel", "parallel", "parallel", "reduction"],
+///   kind = #vector.kind<add>
+/// } %lhs, %rhs, %accT
+///   : vector<2x4x4xf32>, vector<4x8xf32> into vector<2x4x8xf32>
+/// %0 = vector.transpose %contract, [0, 2, 1]
+///   : vector<2x4x8xf32> to vector<2x8x4xf32>
+/// ```
+/// Becomes:
+/// ```mlir
+/// %0 = vector.contract {
+///   indexing_maps = [
+///     affine_map<(d0, d1, d2, d3) -> (d0, d3, d1)>,
+///     affine_map<(d0, d1, d2, d3) -> (d3, d2)>,
+///     affine_map<(d0, d1, d2, d3) -> (d0, d2, d1)>
+///   ],
+///   iterator_types = ["parallel", "parallel", "parallel", "reduction"],
+///   kind = #vector.kind<add>
+/// } %lhs, %rhs, %acc
+///   : vector<2x4x4xf32>, vector<4x8xf32> into vector<2x8x4xf32>
+/// ```
+struct CombineContractResultTranspose final
+    : public OpRewritePattern<vector::TransposeOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(vector::TransposeOp resTOp,
+                                PatternRewriter &rewriter) const override {
+    auto contractOp = resTOp.getVector().getDefiningOp<vector::ContractionOp>();
+    if (!contractOp || !contractOp->hasOneUse())
+      return failure();
+
+    auto accTOp = contractOp.getAcc().getDefiningOp<vector::TransposeOp>();
+    if (!accTOp)
+      return failure();
+
+    MLIRContext *context = contractOp.getContext();
+    // Accumulator transpose performs f(A) -> B. Contract performs g(C) -> B.
+    // To index into A in contract, we need revert(f)(g(C)) -> A.
+    auto accTMap = AffineMap::getPermutationMap(
+        extractVector<unsigned>(accTOp.getTransp()), context);
+    // Contract performs g(C) -> D. Result transpose performs h(D) -> E.
+    // To index into E in contract, we need h(g(C)) -> E.
+    auto resTMap = AffineMap::getPermutationMap(
+        extractVector<unsigned>(resTOp.getTransp()), context);
+
+    // Acc and result share one indexing map, so revert(f)(g(C)) must equal
+    // h(g(C)), i.e. the result transpose must undo the accumulator transpose.
+    if (inversePermutation(accTMap) != resTMap)
+      return failure();
+
+    auto maps = llvm::to_vector<3>(contractOp.getIndexingMapsArray());
+    maps.back() = resTMap.compose(maps.back());
+    // Forward the combining kind; the builder overload without a kind argument
+    // would reset it to the default (`add`).
+    rewriter.replaceOpWithNewOp<vector::ContractionOp>(
+        resTOp, contractOp.getLhs(), contractOp.getRhs(), accTOp.getVector(),
+        rewriter.getAffineMapArrayAttr(maps), contractOp.getIteratorTypes(),
+        contractOp.getKind());
+    return success();
+  }
+};
+
 /// Merge BroadcastOp into ContractionOp user.
 /// Ex:
 /// ```
@@ -1233,7 +1306,7 @@ struct ReorderCastOpsOnBroadcast
 
 /// Reorders elementwise(transpose) to transpose(elementwise). This makes
 /// transpose ops and contraction ops closer, which kicks in
-/// CombineContractTranspose pattern when elementwise ops are between these
+/// CombineContractABTranspose pattern when elementwise ops are between these
 /// operations. Ex:
 /// ```
 /// %at = vector.transpose %a, [1, 0]: vector<4x2xf32> to vector<2x4xf32>
@@ -2939,9 +3012,9 @@ void mlir::vector::populateVectorTransposeLoweringPatterns(
 void mlir::vector::populateVectorReductionToContractPatterns(
     RewritePatternSet &patterns, PatternBenefit benefit) {
   patterns.add<MultiReduceToContract, CombineContractBroadcast,
-               CombineContractTranspose, ReorderCastOpsOnBroadcast,
-               ReorderElementwiseOpsOnTranspose>(patterns.getContext(),
-                                                 benefit);
+               CombineContractABTranspose, CombineContractResultTranspose,
+               ReorderCastOpsOnBroadcast, ReorderElementwiseOpsOnTranspose>(
+      patterns.getContext(), benefit);
 }
 
 void mlir::vector::

diff  --git a/mlir/test/Dialect/Vector/vector-reduce-to-contract.mlir b/mlir/test/Dialect/Vector/vector-reduce-to-contract.mlir
index 87b30e6116b03..23a44b7c03f8f 100644
--- a/mlir/test/Dialect/Vector/vector-reduce-to-contract.mlir
+++ b/mlir/test/Dialect/Vector/vector-reduce-to-contract.mlir
@@ -356,3 +356,32 @@ func.func @transpose_elementwise_diff_map(%a : vector<4x6x3x2xf32>, %b: vector<6
   %r = arith.addf %at, %bt : vector<6x4x2x3xf32>
   return %r : vector<6x4x2x3xf32>
 }
+
+// -----
+
+// CHECK-DAG: #[[$LHS_MAP:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d3, d1)>
+// CHECK-DAG: #[[$RHS_MAP:.+]] = affine_map<(d0, d1, d2, d3) -> (d3, d2)>
+// CHECK-DAG: #[[$ACC_MAP:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d2, d1)>
+
+// CHECK-LABEL: func.func @contract_result_transpose
+//  CHECK-SAME: (%[[LHS:.+]]: vector<2x4x4xf32>, %[[RHS:.+]]: vector<4x8xf32>, %[[ACC:.+]]: vector<2x8x4xf32>)
+//       CHECK:   %[[CONTRACT:.+]] = vector.contract
+//  CHECK-SAME:     indexing_maps = [#[[$LHS_MAP]], #[[$RHS_MAP]], #[[$ACC_MAP]]]
+//  CHECK-SAME:     iterator_types = ["parallel", "parallel", "parallel", "reduction"]
+//  CHECK-SAME:     kind = #vector.kind<add>
+//  CHECK-SAME:     %[[LHS]], %[[RHS]], %[[ACC]]
+//       CHECK:   return %[[CONTRACT]]
+func.func @contract_result_transpose(%lhs : vector<2x4x4xf32>, %rhs: vector<4x8xf32>, %acc: vector<2x8x4xf32>) -> vector<2x8x4xf32> {
+  %accT = vector.transpose %acc, [0, 2, 1] : vector<2x8x4xf32> to vector<2x4x8xf32>
+  %contract = vector.contract {
+    indexing_maps = [
+      affine_map<(d0, d1, d2, d3) -> (d0, d3, d1)>,
+      affine_map<(d0, d1, d2, d3) -> (d3, d2)>,
+      affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
+    ],
+    iterator_types = ["parallel", "parallel", "parallel", "reduction"],
+    kind = #vector.kind<add>
+  } %lhs, %rhs, %accT : vector<2x4x4xf32>, vector<4x8xf32> into vector<2x4x8xf32>
+  %resT = vector.transpose %contract, [0, 2, 1] : vector<2x4x8xf32> to vector<2x8x4xf32>
+  return %resT : vector<2x8x4xf32>
+}


        


More information about the Mlir-commits mailing list