[Mlir-commits] [mlir] f0c93fd - [mlir][vector] Merge accumulator/result transpose into contract
Lei Zhang
llvmlistbot at llvm.org
Fri Oct 7 17:47:33 PDT 2022
Author: Lei Zhang
Date: 2022-10-08T00:43:45Z
New Revision: f0c93fd4cac76d9ae1b2ff99c40c08abeb940f21
URL: https://github.com/llvm/llvm-project/commit/f0c93fd4cac76d9ae1b2ff99c40c08abeb940f21
DIFF: https://github.com/llvm/llvm-project/commit/f0c93fd4cac76d9ae1b2ff99c40c08abeb940f21.diff
LOG: [mlir][vector] Merge accumulator/result transpose into contract
This commit adds a pattern to merge accumulator and result
`vector.transpose` ops into `vector.contract`. This kind of
pattern can be generated for NCHW convolution vectorization,
where we use transposes to convert the 1-D NCW convolution
into NWC during vectorization. Merging the transposes means
we can avoid materializing vector extract/insert ops for the
transposes, and it makes further vector-level transformations
easier.
Reviewed By: ThomasRaoux
Differential Revision: https://reviews.llvm.org/D135496
Added:
Modified:
mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
mlir/test/Dialect/Vector/vector-reduce-to-contract.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
index f1644d01fe70b..570b4b2ee001b 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
@@ -1010,7 +1010,7 @@ struct MultiReduceToContract
}
};
-/// Merge TransposeOp into ContractionOp user.
+/// Merge LHS/RHS (A/B) TransposeOp into ContractionOp user.
/// Ex:
/// ```
/// %0 = vector.transpose %arg0, [2, 0, 1]
@@ -1033,7 +1033,7 @@ struct MultiReduceToContract
/// kind = add} %arg0, %arg1, %cst_f0
/// : vector<8x32x16xf32>, vector<8x32x16xf32> into vector<8x32xf32>
/// ```
-struct CombineContractTranspose
+struct CombineContractABTranspose final
: public OpRewritePattern<vector::ContractionOp> {
using OpRewritePattern::OpRewritePattern;
@@ -1050,8 +1050,6 @@ struct CombineContractTranspose
auto transposeOp = operand->getDefiningOp<vector::TransposeOp>();
if (!transposeOp)
continue;
- SmallVector<int64_t> perm;
- transposeOp.getTransp(perm);
AffineMap permutationMap = AffineMap::getPermutationMap(
extractVector<unsigned>(transposeOp.getTransp()),
contractOp.getContext());
@@ -1068,6 +1066,81 @@ struct CombineContractTranspose
}
};
+/// Merges accumulator and result transposes into contract.
+///
+/// Matches a `vector.transpose` whose operand is a single-use
+/// `vector.contract` whose accumulator is itself produced by a
+/// `vector.transpose`. When the result transpose undoes the accumulator
+/// transpose, both are folded away by permuting the contract's
+/// accumulator/result indexing map. The LHS/RHS maps, iterator types and
+/// combining kind of the original contract are preserved.
+///
+/// For example:
+/// ```mlir
+/// %accT = vector.transpose %acc, [0, 2, 1]
+///   : vector<2x8x4xf32> to vector<2x4x8xf32>
+/// %contract = vector.contract {
+///   indexing_maps = [
+///     affine_map<(d0, d1, d2, d3) -> (d0, d3, d1)>,
+///     affine_map<(d0, d1, d2, d3) -> (d3, d2)>,
+///     affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
+///   ],
+///   iterator_types = ["parallel", "parallel", "parallel", "reduction"],
+///   kind = #vector.kind<add>
+/// } %lhs, %rhs, %accT
+///   : vector<2x4x4xf32>, vector<4x8xf32> into vector<2x4x8xf32>
+/// %0 = vector.transpose %contract, [0, 2, 1]
+///   : vector<2x4x8xf32> to vector<2x8x4xf32>
+/// ```
+/// Becomes:
+/// ```mlir
+/// %0 = vector.contract {
+///   indexing_maps = [
+///     affine_map<(d0, d1, d2, d3) -> (d0, d3, d1)>,
+///     affine_map<(d0, d1, d2, d3) -> (d3, d2)>,
+///     affine_map<(d0, d1, d2, d3) -> (d0, d2, d1)>
+///   ],
+///   iterator_types = ["parallel", "parallel", "parallel", "reduction"],
+///   kind = #vector.kind<add>
+/// } %lhs, %rhs, %acc
+///   : vector<2x4x4xf32>, vector<4x8xf32> into vector<2x8x4xf32>
+/// ```
+struct CombineContractResultTranspose final
+    : public OpRewritePattern<vector::TransposeOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(vector::TransposeOp resTOp,
+                                PatternRewriter &rewriter) const override {
+    // The contract result must feed only this transpose; any other user would
+    // still need the untransposed value.
+    auto contractOp = resTOp.getVector().getDefiningOp<vector::ContractionOp>();
+    if (!contractOp || !contractOp->hasOneUse())
+      return failure();
+
+    auto accTOp = contractOp.getAcc().getDefiningOp<vector::TransposeOp>();
+    if (!accTOp)
+      return failure();
+
+    MLIRContext *context = contractOp.getContext();
+    auto maps = llvm::to_vector<3>(contractOp.getIndexingMapsArray());
+    AffineMap contractMap = maps.back();
+
+    // Accumulator transpose performs f(A) -> B. Contract performs g(C) -> B.
+    // To index into A in contract, we need inverse(f)(g(C)) -> A.
+    auto accTMap = AffineMap::getPermutationMap(
+        extractVector<unsigned>(accTOp.getTransp()), context);
+
+    // Contract performs g(C) -> D. Result transpose performs h(D) -> E.
+    // To index into E in contract, we need h(g(C)) -> E.
+    auto resTMap = AffineMap::getPermutationMap(
+        extractVector<unsigned>(resTOp.getTransp()), context);
+    auto combinedResMap = resTMap.compose(contractMap);
+
+    // The accumulator and result share the same indexing map, so the merged
+    // accumulator map inversePermutation(accTMap).compose(contractMap) must
+    // equal combinedResMap == resTMap.compose(contractMap). We require the
+    // result transpose to undo the accumulator transpose, which makes the two
+    // maps identical.
+    if (inversePermutation(accTMap) != resTMap)
+      return failure();
+    maps.back() = combinedResMap;
+
+    // Forward the combining kind explicitly. The builder overload without it
+    // uses the default kind (`add`), which would silently change the
+    // semantics of e.g. `mul`/`maxf` contractions.
+    rewriter.replaceOpWithNewOp<vector::ContractionOp>(
+        resTOp, contractOp.getLhs(), contractOp.getRhs(), accTOp.getVector(),
+        rewriter.getAffineMapArrayAttr(maps), contractOp.getIteratorTypes(),
+        contractOp.getKind());
+    return success();
+  }
+};
+
/// Merge BroadcastOp into ContractionOp user.
/// Ex:
/// ```
@@ -1233,7 +1306,7 @@ struct ReorderCastOpsOnBroadcast
/// Reorders elementwise(transpose) to transpose(elementwise). This makes
/// transpose ops and contraction ops closer, which kicks in
-/// CombineContractTranspose pattern when elementwise ops are between these
+/// CombineContractABTranspose pattern when elementwise ops are between these
/// operations. Ex:
/// ```
/// %at = vector.transpose %a, [1, 0]: vector<4x2xf32> to vector<2x4xf32>
@@ -2939,9 +3012,9 @@ void mlir::vector::populateVectorTransposeLoweringPatterns(
void mlir::vector::populateVectorReductionToContractPatterns(
RewritePatternSet &patterns, PatternBenefit benefit) {
patterns.add<MultiReduceToContract, CombineContractBroadcast,
- CombineContractTranspose, ReorderCastOpsOnBroadcast,
- ReorderElementwiseOpsOnTranspose>(patterns.getContext(),
- benefit);
+ CombineContractABTranspose, CombineContractResultTranspose,
+ ReorderCastOpsOnBroadcast, ReorderElementwiseOpsOnTranspose>(
+ patterns.getContext(), benefit);
}
void mlir::vector::
diff --git a/mlir/test/Dialect/Vector/vector-reduce-to-contract.mlir b/mlir/test/Dialect/Vector/vector-reduce-to-contract.mlir
index 87b30e6116b03..23a44b7c03f8f 100644
--- a/mlir/test/Dialect/Vector/vector-reduce-to-contract.mlir
+++ b/mlir/test/Dialect/Vector/vector-reduce-to-contract.mlir
@@ -356,3 +356,32 @@ func.func @transpose_elementwise_diff_map(%a : vector<4x6x3x2xf32>, %b: vector<6
%r = arith.addf %at, %bt : vector<6x4x2x3xf32>
return %r : vector<6x4x2x3xf32>
}
+
+// -----
+
+// CHECK-DAG: #[[$LHS_MAP:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d3, d1)>
+// CHECK-DAG: #[[$RHS_MAP:.+]] = affine_map<(d0, d1, d2, d3) -> (d3, d2)>
+// CHECK-DAG: #[[$ACC_MAP:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d2, d1)>
+
+// CHECK-LABEL: func.func @contract_result_transpose
+// CHECK-SAME: (%[[LHS:.+]]: vector<2x4x4xf32>, %[[RHS:.+]]: vector<4x8xf32>, %[[ACC:.+]]: vector<2x8x4xf32>)
+// CHECK: %[[CONTRACT:.+]] = vector.contract
+// CHECK-SAME: indexing_maps = [#[[$LHS_MAP]], #[[$RHS_MAP]], #[[$ACC_MAP]]]
+// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel", "reduction"]
+// CHECK-SAME: kind = #vector.kind<add>
+// CHECK-SAME: %[[LHS]], %[[RHS]], %[[ACC]]
+// CHECK: return %[[CONTRACT]]
+// The accumulator transpose and the result transpose both use the permutation
+// [0, 2, 1], which is its own inverse. Both transposes should therefore fold
+// into the contract, whose accumulator/result indexing map becomes
+// (d0, d1, d2, d3) -> (d0, d2, d1) with %acc consumed directly.
+func.func @contract_result_transpose(%lhs : vector<2x4x4xf32>, %rhs: vector<4x8xf32>, %acc: vector<2x8x4xf32>) -> vector<2x8x4xf32> {
+ %accT = vector.transpose %acc, [0, 2, 1] : vector<2x8x4xf32> to vector<2x4x8xf32>
+ %contract = vector.contract {
+ indexing_maps = [
+ affine_map<(d0, d1, d2, d3) -> (d0, d3, d1)>,
+ affine_map<(d0, d1, d2, d3) -> (d3, d2)>,
+ affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
+ ],
+ iterator_types = ["parallel", "parallel", "parallel", "reduction"],
+ kind = #vector.kind<add>
+ } %lhs, %rhs, %accT : vector<2x4x4xf32>, vector<4x8xf32> into vector<2x4x8xf32>
+ %resT = vector.transpose %contract, [0, 2, 1] : vector<2x4x8xf32> to vector<2x8x4xf32>
+ return %resT : vector<2x8x4xf32>
+}
More information about the Mlir-commits mailing list