[Mlir-commits] [mlir] ec62e37 - [mlir] [vector] Add an optional filter to vector contract lowering patterns.
Nicolas Vasilache
llvmlistbot at llvm.org
Fri Jul 17 09:06:00 PDT 2020
Author: Pierre Oechsel
Date: 2020-07-17T12:03:13-04:00
New Revision: ec62e37c86fa67a40bc9e04b9112668deb003b9a
URL: https://github.com/llvm/llvm-project/commit/ec62e37c86fa67a40bc9e04b9112668deb003b9a
DIFF: https://github.com/llvm/llvm-project/commit/ec62e37c86fa67a40bc9e04b9112668deb003b9a.diff
LOG: [mlir] [vector] Add an optional filter to vector contract lowering patterns.
Summary: Vector contract patterns were only parameterized by a `vectorTransformsOptions`. As a result, even if an MLIR file contained several occurrences of `vector.contract`, all of them would be lowered in the same way. More granularity might be required. This diff adds a `constraint` argument to each of these patterns, which allows the user to specify more precisely which `vector.contract` ops each lowering should apply to.
Differential Revision: https://reviews.llvm.org/D83960
Added:
Modified:
mlir/include/mlir/Dialect/Vector/VectorTransforms.h
mlir/lib/Dialect/Vector/VectorTransforms.cpp
mlir/test/Dialect/Vector/vector-contract-transforms.mlir
mlir/test/lib/Transforms/TestVectorTransforms.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Vector/VectorTransforms.h b/mlir/include/mlir/Dialect/Vector/VectorTransforms.h
index e95329c3e505..0d18c5aa782d 100644
--- a/mlir/include/mlir/Dialect/Vector/VectorTransforms.h
+++ b/mlir/include/mlir/Dialect/Vector/VectorTransforms.h
@@ -127,12 +127,18 @@ class ContractionOpToMatmulOpLowering
: public OpRewritePattern<vector::ContractionOp> {
public:
using OpRewritePattern<vector::ContractionOp>::OpRewritePattern;
+ using FilterConstraintType =
+ std::function<LogicalResult(vector::ContractionOp op)>;
+
+ static LogicalResult defaultFilter(vector::ContractionOp op) {
+ return success();
+ }
ContractionOpToMatmulOpLowering(
vector::VectorTransformsOptions vectorTransformsOptions,
- MLIRContext *context)
+ MLIRContext *context, FilterConstraintType constraint = defaultFilter)
: OpRewritePattern<vector::ContractionOp>(context),
- vectorTransformsOptions(vectorTransformsOptions) {}
+ vectorTransformsOptions(vectorTransformsOptions), filter(constraint) {}
LogicalResult match(vector::ContractionOp op) const override;
void rewrite(vector::ContractionOp op,
@@ -141,6 +147,7 @@ class ContractionOpToMatmulOpLowering
private:
/// Options to control the vector patterns.
vector::VectorTransformsOptions vectorTransformsOptions;
+ FilterConstraintType filter;
};
/// Progressive lowering of a `vector.contract %a, %b, %c` with row-major matmul
@@ -162,11 +169,18 @@ class ContractionOpToOuterProductOpLowering
: public OpRewritePattern<vector::ContractionOp> {
public:
using OpRewritePattern<vector::ContractionOp>::OpRewritePattern;
+ using FilterConstraintType =
+ std::function<LogicalResult(vector::ContractionOp op)>;
+
+ static LogicalResult defaultFilter(vector::ContractionOp op) {
+ return success();
+ }
+
ContractionOpToOuterProductOpLowering(
vector::VectorTransformsOptions vectorTransformsOptions,
- MLIRContext *context)
+ MLIRContext *context, FilterConstraintType constraint = defaultFilter)
: OpRewritePattern<vector::ContractionOp>(context),
- vectorTransformsOptions(vectorTransformsOptions) {}
+ vectorTransformsOptions(vectorTransformsOptions), filter(constraint) {}
LogicalResult match(vector::ContractionOp op) const override;
void rewrite(vector::ContractionOp op,
@@ -175,6 +189,7 @@ class ContractionOpToOuterProductOpLowering
private:
/// Options to control the vector patterns.
vector::VectorTransformsOptions vectorTransformsOptions;
+ FilterConstraintType filter;
};
/// Progressive lowering of ContractionOp.
@@ -194,11 +209,18 @@ class ContractionOpToOuterProductOpLowering
class ContractionOpLowering : public OpRewritePattern<vector::ContractionOp> {
public:
using OpRewritePattern<vector::ContractionOp>::OpRewritePattern;
+ using FilterConstraintType =
+ std::function<LogicalResult(vector::ContractionOp op)>;
+
+ static LogicalResult defaultFilter(vector::ContractionOp op) {
+ return success();
+ }
ContractionOpLowering(vector::VectorTransformsOptions vectorTransformsOptions,
- MLIRContext *context)
+ MLIRContext *context,
+ FilterConstraintType constraint = defaultFilter)
: OpRewritePattern<vector::ContractionOp>(context),
- vectorTransformsOptions(vectorTransformsOptions) {}
+ vectorTransformsOptions(vectorTransformsOptions), filter(constraint) {}
LogicalResult matchAndRewrite(vector::ContractionOp op,
PatternRewriter &rewriter) const override;
@@ -206,6 +228,7 @@ class ContractionOpLowering : public OpRewritePattern<vector::ContractionOp> {
private:
/// Options to control the vector patterns.
vector::VectorTransformsOptions vectorTransformsOptions;
+ FilterConstraintType filter;
// Lower one parallel dimension.
Value lowerParallel(vector::ContractionOp op, int64_t lhsIndex,
int64_t rhsIndex, PatternRewriter &rewriter) const;
diff --git a/mlir/lib/Dialect/Vector/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/VectorTransforms.cpp
index 2f77fd5ff60a..a63862c1a4fe 100644
--- a/mlir/lib/Dialect/Vector/VectorTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/VectorTransforms.cpp
@@ -1581,6 +1581,9 @@ ContractionOpToMatmulOpLowering::match(vector::ContractionOp op) const {
vector::VectorContractLowering::Matmul)
return failure();
+ if (failed(filter(op)))
+ return failure();
+
auto iteratorTypes = op.iterator_types().getValue();
if (!isParallelIterator(iteratorTypes[0]) ||
!isParallelIterator(iteratorTypes[1]) ||
@@ -1647,6 +1650,9 @@ ContractionOpToOuterProductOpLowering::match(vector::ContractionOp op) const {
vector::VectorContractLowering::OuterProduct)
return failure();
+ if (failed(filter(op)))
+ return failure();
+
// Determine if the parallel/reduction structure matches something
// that can be expressed a reduction_size unrolled sequence.
using MapList = ArrayRef<ArrayRef<AffineExpr>>;
@@ -1808,6 +1814,10 @@ ContractionOpLowering::matchAndRewrite(vector::ContractionOp op,
// TODO: implement masks.
if (llvm::size(op.masks()) != 0)
return failure();
+
+ if (failed(filter(op)))
+ return failure();
+
// TODO: support mixed mode contract lowering.
if (op.getLhsType().getElementType() !=
getElementTypeOrSelf(op.getAccType()) ||
diff --git a/mlir/test/Dialect/Vector/vector-contract-transforms.mlir b/mlir/test/Dialect/Vector/vector-contract-transforms.mlir
index 82faadf100e9..6dae907b8bb0 100644
--- a/mlir/test/Dialect/Vector/vector-contract-transforms.mlir
+++ b/mlir/test/Dialect/Vector/vector-contract-transforms.mlir
@@ -1,6 +1,7 @@
// RUN: mlir-opt %s -test-vector-contraction-conversion | FileCheck %s
// RUN: mlir-opt %s -test-vector-contraction-conversion=vector-lower-matrix-intrinsics=1 | FileCheck %s --check-prefix=MATRIX
// RUN: mlir-opt %s -test-vector-contraction-conversion=vector-outerproduct=1 | FileCheck %s --check-prefix=OUTERPRODUCT
+// RUN: mlir-opt %s -test-vector-contraction-conversion=vector-filter-outerproduct=1 | FileCheck %s --check-prefix=FILTEROUTERPRODUCT
#dotp_accesses = [
affine_map<(i) -> (i)>,
@@ -1029,3 +1030,33 @@ func @matmul_7(%arg0: vector<2x1xf32>, %arg1: vector<1x3xf32>, %arg2: vector<3x2
: vector<2x1xf32>, vector<1x3xf32> into vector<3x2xf32>
return %0 : vector<3x2xf32>
}
+
+// FILTEROUTERPRODUCT-LABEL: func @matmul_4_filtered
+// FILTEROUTERPRODUCT-SAME: %[[A:[a-zA-Z0-9]*]]: vector<4x4xf32>,
+// FILTEROUTERPRODUCT-SAME: %[[B:[a-zA-Z0-9]*]]: vector<4x4xf32>,
+// FILTEROUTERPRODUCT-SAME: %[[C:[a-zA-Z0-9]*]]: vector<4x4xf32>
+// FILTEROUTERPRODUCT: %[[c0:.*]] = vector.contract {{{.*}}} %[[A]], %[[B]], %[[C]]
+func @matmul_4_filtered(%arg0: vector<4x4xf32>, %arg1: vector<4x4xf32>, %arg2: vector<4x4xf32>)
+-> vector<4x4xf32>
+{
+ %0 = vector.contract #matmat_trait_0 %arg0, %arg1, %arg2
+ : vector<4x4xf32>, vector<4x4xf32> into vector<4x4xf32>
+ return %0 : vector<4x4xf32>
+}
+
+// FILTEROUTERPRODUCT-LABEL: func @matmul_4_not_filtered
+// FILTEROUTERPRODUCT-SAME: %[[A:[a-zA-Z0-9]*]]: vector<3x4xf32>,
+// FILTEROUTERPRODUCT-SAME: %[[B:[a-zA-Z0-9]*]]: vector<4x4xf32>,
+// FILTEROUTERPRODUCT-SAME: %[[C:[a-zA-Z0-9]*]]: vector<3x4xf32>
+// FILTEROUTERPRODUCT: %[[c0:.*]] = vector.contract {{{.*}}} %[[A]], %[[B]], %[[C]]
+func @matmul_4_not_filtered(%arg0: vector<3x4xf32>, %arg1: vector<4x4xf32>, %arg2: vector<3x4xf32>)
+-> vector<3x4xf32>
+{
+ %0 = vector.contract #matmat_trait_0 %arg0, %arg1, %arg2
+ : vector<3x4xf32>, vector<4x4xf32> into vector<3x4xf32>
+ return %0 : vector<3x4xf32>
+}
+
+
+
+
diff --git a/mlir/test/lib/Transforms/TestVectorTransforms.cpp b/mlir/test/lib/Transforms/TestVectorTransforms.cpp
index 7e28ebbd9b72..2dffd88ed709 100644
--- a/mlir/test/lib/Transforms/TestVectorTransforms.cpp
+++ b/mlir/test/lib/Transforms/TestVectorTransforms.cpp
@@ -59,6 +59,11 @@ struct TestVectorContractionConversion
*this, "vector-outerproduct",
llvm::cl::desc("Lower vector.contract to vector.outerproduct"),
llvm::cl::init(false)};
+ Option<bool> lowerToFilterOuterProduct{
+ *this, "vector-filter-outerproduct",
+ llvm::cl::desc("Lower vector.contract to vector.outerproduct but not for "
+ "vectors of size 4."),
+ llvm::cl::init(false)};
void runOnFunction() override {
OwningRewritePatternList patterns;
@@ -73,6 +78,22 @@ struct TestVectorContractionConversion
return;
}
+ // Test on one pattern in isolation.
+ if (lowerToFilterOuterProduct) {
+ VectorContractLowering lowering = VectorContractLowering::OuterProduct;
+ VectorTransformsOptions options{lowering};
+ patterns.insert<ContractionOpToOuterProductOpLowering>(
+ options, &getContext(), [](vector::ContractionOp op) {
+ // Only lowers vector.contract ops whose rhs has a type vector<Mx...>
+ // where M (the leading rhs dimension) is not 4.
+ if (op.getRhsType().getShape()[0] == 4)
+ return failure();
+ return success();
+ });
+ applyPatternsAndFoldGreedily(getFunction(), patterns);
+ return;
+ }
+
// Test on all contract lowering patterns.
VectorContractLowering contractLowering = VectorContractLowering::Dot;
if (lowerToFlatMatrix)
More information about the Mlir-commits
mailing list