[Mlir-commits] [mlir] 1353cbc - [mlir][Vector] NFC - Use matchAndRewrite in ContractionOp lowering patterns
Nicolas Vasilache
llvmlistbot at llvm.org
Thu Aug 6 06:04:49 PDT 2020
Author: Nicolas Vasilache
Date: 2020-08-06T09:02:25-04:00
New Revision: 1353cbc2570b2fe4b418a9acea9778eca5625fb7
URL: https://github.com/llvm/llvm-project/commit/1353cbc2570b2fe4b418a9acea9778eca5625fb7
DIFF: https://github.com/llvm/llvm-project/commit/1353cbc2570b2fe4b418a9acea9778eca5625fb7.diff
LOG: [mlir][Vector] NFC - Use matchAndRewrite in ContractionOp lowering patterns
Replace the use of separate `match` and `rewrite` methods, which unnecessarily duplicates logic.
Differential Revision: https://reviews.llvm.org/D85421
Added:
Modified:
mlir/include/mlir/Dialect/Vector/VectorTransforms.h
mlir/lib/Dialect/Vector/VectorTransforms.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Vector/VectorTransforms.h b/mlir/include/mlir/Dialect/Vector/VectorTransforms.h
index e6c7b7abebd5..35855b3b2137 100644
--- a/mlir/include/mlir/Dialect/Vector/VectorTransforms.h
+++ b/mlir/include/mlir/Dialect/Vector/VectorTransforms.h
@@ -208,9 +208,8 @@ class ContractionOpToMatmulOpLowering
: OpRewritePattern<vector::ContractionOp>(context),
vectorTransformsOptions(vectorTransformsOptions), filter(constraint) {}
- LogicalResult match(vector::ContractionOp op) const override;
- void rewrite(vector::ContractionOp op,
- PatternRewriter &rewriter) const override;
+ LogicalResult matchAndRewrite(vector::ContractionOp op,
+ PatternRewriter &rewriter) const override;
private:
/// Options to control the vector patterns.
@@ -250,9 +249,8 @@ class ContractionOpToOuterProductOpLowering
: OpRewritePattern<vector::ContractionOp>(context),
vectorTransformsOptions(vectorTransformsOptions), filter(constraint) {}
- LogicalResult match(vector::ContractionOp op) const override;
- void rewrite(vector::ContractionOp op,
- PatternRewriter &rewriter) const override;
+ LogicalResult matchAndRewrite(vector::ContractionOp op,
+ PatternRewriter &rewriter) const override;
private:
/// Options to control the vector patterns.
diff --git a/mlir/lib/Dialect/Vector/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/VectorTransforms.cpp
index 33fbed65ace6..922168947ccf 100644
--- a/mlir/lib/Dialect/Vector/VectorTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/VectorTransforms.cpp
@@ -1576,16 +1576,14 @@ namespace mlir {
//
/// This only kicks in when VectorTransformsOptions is set to Matmul and
/// the vector.contract op is a row-major matrix multiply.
-LogicalResult
-ContractionOpToMatmulOpLowering::match(vector::ContractionOp op) const {
+LogicalResult ContractionOpToMatmulOpLowering::matchAndRewrite(
+ vector::ContractionOp op, PatternRewriter &rewriter) const {
// TODO: implement masks
if (llvm::size(op.masks()) != 0)
return failure();
-
if (vectorTransformsOptions.vectorContractLowering !=
vector::VectorContractLowering::Matmul)
return failure();
-
if (failed(filter(op)))
return failure();
@@ -1598,11 +1596,10 @@ ContractionOpToMatmulOpLowering::match(vector::ContractionOp op) const {
if (!isRowMajorMatmul(op.indexing_maps()))
return failure();
- return success();
-}
+ Type elementType = op.getLhsType().getElementType();
+ if (!elementType.isIntOrFloat())
+ return failure();
-void ContractionOpToMatmulOpLowering::rewrite(vector::ContractionOp op,
- PatternRewriter &rewriter) const {
VectorType lhsType = op.getLhsType();
VectorType rhsType = op.getRhsType();
int64_t lhsRows = lhsType.getDimSize(0);
@@ -1622,12 +1619,12 @@ void ContractionOpToMatmulOpLowering::rewrite(vector::ContractionOp op,
lhsColumns, rhsColumns);
mul = rewriter.create<vector::ShapeCastOp>(op.getLoc(), op.acc().getType(),
mul);
- Type elementType = op.getLhsType().getElementType();
- assert(elementType.isIntOrFloat());
if (elementType.isa<IntegerType>())
rewriter.replaceOpWithNewOp<AddIOp>(op, op.acc(), mul);
else
rewriter.replaceOpWithNewOp<AddFOp>(op, op.acc(), mul);
+
+ return success();
}
/// Progressively lower a `vector.contract %a, %b, %c` with row-major matmul
@@ -1645,8 +1642,8 @@ void ContractionOpToMatmulOpLowering::rewrite(vector::ContractionOp op,
///
/// This only kicks in when VectorTransformsOptions is set to OuterProduct but
/// otherwise supports any layout permutation of the matrix-multiply.
-LogicalResult
-ContractionOpToOuterProductOpLowering::match(vector::ContractionOp op) const {
+LogicalResult ContractionOpToOuterProductOpLowering::matchAndRewrite(
+ vector::ContractionOp op, PatternRewriter &rewriter) const {
// TODO: implement masks
if (llvm::size(op.masks()) != 0)
return failure();
@@ -1658,50 +1655,6 @@ ContractionOpToOuterProductOpLowering::match(vector::ContractionOp op) const {
if (failed(filter(op)))
return failure();
- // Determine if the parallel/reduction structure matches something
- // that can be expressed a reduction_size unrolled sequence.
- using MapList = ArrayRef<ArrayRef<AffineExpr>>;
- auto infer = [](MapList m) { return AffineMap::inferFromExprList(m); };
- AffineExpr m, n, k;
- bindDims(op.getContext(), m, n, k);
- auto iteratorTypes = op.iterator_types().getValue();
- SmallVector<AffineMap, 4> maps = op.getIndexingMaps();
- if (isParallelIterator(iteratorTypes[0]) &&
- isParallelIterator(iteratorTypes[1]) &&
- isReductionIterator(iteratorTypes[2])) {
- //
- // Two outer parallel, one inner reduction (matmat flavor).
- // When lowering to outerproduct we can support all permutations.
- //
- if (maps != infer({{m, k}, {k, n}, {m, n}}) &&
- maps != infer({{m, k}, {n, k}, {m, n}}) &&
- maps != infer({{k, m}, {k, n}, {m, n}}) &&
- maps != infer({{k, m}, {n, k}, {m, n}}) &&
- maps != infer({{m, k}, {k, n}, {n, m}}) &&
- maps != infer({{m, k}, {n, k}, {n, m}}) &&
- maps != infer({{k, m}, {k, n}, {n, m}}) &&
- maps != infer({{k, m}, {n, k}, {n, m}}))
- return failure();
- return success();
- } else if (isParallelIterator(iteratorTypes[0]) &&
- isReductionIterator(iteratorTypes[1])) {
- //
- // One outer parallel, one inner reduction (matvec flavor)
- // See if a series of AXPY operations chained through FMA operations
- // could replace the default DOT implementation.
- //
- if (maps != infer({{m, n}, {n}, {m}}) && // mat-vec
- maps != infer({{n, m}, {n}, {m}}) && // mat-trans-vec
- maps != infer({{n}, {m, n}, {m}}) && // vec-mat
- maps != infer({{n}, {n, m}, {m}})) // vec-mat-trans
- return failure();
- return success();
- }
- return failure();
-}
-
-void ContractionOpToOuterProductOpLowering::rewrite(
- vector::ContractionOp op, PatternRewriter &rewriter) const {
Location loc = op.getLoc();
int64_t reductionSize = 0;
VectorType lhsType = op.getLhsType();
@@ -1759,13 +1712,14 @@ void ContractionOpToOuterProductOpLowering::rewrite(
Value tmp = lhs;
lhs = rewriter.create<vector::TransposeOp>(loc, rhs, perm);
rhs = tmp;
+ } else {
+ return failure();
}
- } else {
+ } else if (isParallelIterator(iteratorTypes[0]) &&
+ isReductionIterator(iteratorTypes[1])) {
//
// One outer parallel, one inner reduction (matvec flavor)
//
- assert(isParallelIterator(iteratorTypes[0]) &&
- isReductionIterator(iteratorTypes[1]));
if (maps == infer({{m, n}, {n}, {m}})) {
// Case mat-vec: transpose.
reductionSize = lhsType.getDimSize(1);
@@ -1782,7 +1736,11 @@ void ContractionOpToOuterProductOpLowering::rewrite(
// Case vec-mat-trans: swap and ready to go.
reductionSize = lhsType.getDimSize(0);
std::swap(lhs, rhs);
+ } else {
+ return failure();
}
+ } else {
+ return failure();
}
assert(reductionSize > 0);
@@ -1793,6 +1751,7 @@ void ContractionOpToOuterProductOpLowering::rewrite(
res = rewriter.create<vector::OuterProductOp>(op.getLoc(), a, b, res);
}
rewriter.replaceOp(op, res);
+ return success();
}
/// Progressive lowering of ContractionOp.
@@ -1815,7 +1774,6 @@ void ContractionOpToOuterProductOpLowering::rewrite(
LogicalResult
ContractionOpLowering::matchAndRewrite(vector::ContractionOp op,
PatternRewriter &rewriter) const {
-
// TODO: implement masks.
if (llvm::size(op.masks()) != 0)
return failure();
@@ -1832,11 +1790,11 @@ ContractionOpLowering::matchAndRewrite(vector::ContractionOp op,
// TODO: implement benefits, cost models.
MLIRContext *ctx = op.getContext();
ContractionOpToMatmulOpLowering pat1(vectorTransformsOptions, ctx);
- if (succeeded(pat1.match(op)))
- return failure();
+ if (succeeded(pat1.matchAndRewrite(op, rewriter)))
+ return success();
ContractionOpToOuterProductOpLowering pat2(vectorTransformsOptions, ctx);
- if (succeeded(pat2.match(op)))
- return failure();
+ if (succeeded(pat2.matchAndRewrite(op, rewriter)))
+ return success();
// Find first batch dimension in LHS/RHS, and lower when found.
std::vector<std::pair<int64_t, int64_t>> batchDimMap = op.getBatchDimMap();
More information about the Mlir-commits
mailing list