[Mlir-commits] [mlir] ec62e37 - [mlir] [vector] Add an optional filter to vector contract lowering patterns.

Nicolas Vasilache llvmlistbot at llvm.org
Fri Jul 17 09:06:00 PDT 2020


Author: Pierre Oechsel
Date: 2020-07-17T12:03:13-04:00
New Revision: ec62e37c86fa67a40bc9e04b9112668deb003b9a

URL: https://github.com/llvm/llvm-project/commit/ec62e37c86fa67a40bc9e04b9112668deb003b9a
DIFF: https://github.com/llvm/llvm-project/commit/ec62e37c86fa67a40bc9e04b9112668deb003b9a.diff

LOG: [mlir] [vector] Add an optional filter to vector contract lowering patterns.

Summary: Vector contract patterns were only parameterized by a `vectorTransformsOptions`. As a result, even if an MLIR file contained several occurrences of `vector.contract`, all of them would be lowered in the same way. More granularity might be required. This diff adds a `constraint` argument to each of these patterns, which allows the user to specify more precisely which `vector.contract` ops each lowering should apply to.

Differential Revision: https://reviews.llvm.org/D83960

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Vector/VectorTransforms.h
    mlir/lib/Dialect/Vector/VectorTransforms.cpp
    mlir/test/Dialect/Vector/vector-contract-transforms.mlir
    mlir/test/lib/Transforms/TestVectorTransforms.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Vector/VectorTransforms.h b/mlir/include/mlir/Dialect/Vector/VectorTransforms.h
index e95329c3e505..0d18c5aa782d 100644
--- a/mlir/include/mlir/Dialect/Vector/VectorTransforms.h
+++ b/mlir/include/mlir/Dialect/Vector/VectorTransforms.h
@@ -127,12 +127,18 @@ class ContractionOpToMatmulOpLowering
     : public OpRewritePattern<vector::ContractionOp> {
 public:
   using OpRewritePattern<vector::ContractionOp>::OpRewritePattern;
+  using FilterConstraintType =
+      std::function<LogicalResult(vector::ContractionOp op)>;
+
+  static LogicalResult defaultFilter(vector::ContractionOp op) {
+    return success();
+  }
 
   ContractionOpToMatmulOpLowering(
       vector::VectorTransformsOptions vectorTransformsOptions,
-      MLIRContext *context)
+      MLIRContext *context, FilterConstraintType constraint = defaultFilter)
       : OpRewritePattern<vector::ContractionOp>(context),
-        vectorTransformsOptions(vectorTransformsOptions) {}
+        vectorTransformsOptions(vectorTransformsOptions), filter(constraint) {}
 
   LogicalResult match(vector::ContractionOp op) const override;
   void rewrite(vector::ContractionOp op,
@@ -141,6 +147,7 @@ class ContractionOpToMatmulOpLowering
 private:
   /// Options to control the vector patterns.
   vector::VectorTransformsOptions vectorTransformsOptions;
+  FilterConstraintType filter;
 };
 
 /// Progressive lowering of a `vector.contract %a, %b, %c` with row-major matmul
@@ -162,11 +169,18 @@ class ContractionOpToOuterProductOpLowering
     : public OpRewritePattern<vector::ContractionOp> {
 public:
   using OpRewritePattern<vector::ContractionOp>::OpRewritePattern;
+  using FilterConstraintType =
+      std::function<LogicalResult(vector::ContractionOp op)>;
+
+  static LogicalResult defaultFilter(vector::ContractionOp op) {
+    return success();
+  }
+
   ContractionOpToOuterProductOpLowering(
       vector::VectorTransformsOptions vectorTransformsOptions,
-      MLIRContext *context)
+      MLIRContext *context, FilterConstraintType constraint = defaultFilter)
       : OpRewritePattern<vector::ContractionOp>(context),
-        vectorTransformsOptions(vectorTransformsOptions) {}
+        vectorTransformsOptions(vectorTransformsOptions), filter(constraint) {}
 
   LogicalResult match(vector::ContractionOp op) const override;
   void rewrite(vector::ContractionOp op,
@@ -175,6 +189,7 @@ class ContractionOpToOuterProductOpLowering
 private:
   /// Options to control the vector patterns.
   vector::VectorTransformsOptions vectorTransformsOptions;
+  FilterConstraintType filter;
 };
 
 /// Progressive lowering of ContractionOp.
@@ -194,11 +209,18 @@ class ContractionOpToOuterProductOpLowering
 class ContractionOpLowering : public OpRewritePattern<vector::ContractionOp> {
 public:
   using OpRewritePattern<vector::ContractionOp>::OpRewritePattern;
+  using FilterConstraintType =
+      std::function<LogicalResult(vector::ContractionOp op)>;
+
+  static LogicalResult defaultFilter(vector::ContractionOp op) {
+    return success();
+  }
 
   ContractionOpLowering(vector::VectorTransformsOptions vectorTransformsOptions,
-                        MLIRContext *context)
+                        MLIRContext *context,
+                        FilterConstraintType constraint = defaultFilter)
       : OpRewritePattern<vector::ContractionOp>(context),
-        vectorTransformsOptions(vectorTransformsOptions) {}
+        vectorTransformsOptions(vectorTransformsOptions), filter(constraint) {}
 
   LogicalResult matchAndRewrite(vector::ContractionOp op,
                                 PatternRewriter &rewriter) const override;
@@ -206,6 +228,7 @@ class ContractionOpLowering : public OpRewritePattern<vector::ContractionOp> {
 private:
   /// Options to control the vector patterns.
   vector::VectorTransformsOptions vectorTransformsOptions;
+  FilterConstraintType filter;
   // Lower one parallel dimension.
   Value lowerParallel(vector::ContractionOp op, int64_t lhsIndex,
                       int64_t rhsIndex, PatternRewriter &rewriter) const;

diff  --git a/mlir/lib/Dialect/Vector/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/VectorTransforms.cpp
index 2f77fd5ff60a..a63862c1a4fe 100644
--- a/mlir/lib/Dialect/Vector/VectorTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/VectorTransforms.cpp
@@ -1581,6 +1581,9 @@ ContractionOpToMatmulOpLowering::match(vector::ContractionOp op) const {
       vector::VectorContractLowering::Matmul)
     return failure();
 
+  if (failed(filter(op)))
+    return failure();
+
   auto iteratorTypes = op.iterator_types().getValue();
   if (!isParallelIterator(iteratorTypes[0]) ||
       !isParallelIterator(iteratorTypes[1]) ||
@@ -1647,6 +1650,9 @@ ContractionOpToOuterProductOpLowering::match(vector::ContractionOp op) const {
       vector::VectorContractLowering::OuterProduct)
     return failure();
 
+  if (failed(filter(op)))
+    return failure();
+
   // Determine if the parallel/reduction structure matches something
   // that can be expressed a reduction_size unrolled sequence.
   using MapList = ArrayRef<ArrayRef<AffineExpr>>;
@@ -1808,6 +1814,10 @@ ContractionOpLowering::matchAndRewrite(vector::ContractionOp op,
   // TODO: implement masks.
   if (llvm::size(op.masks()) != 0)
     return failure();
+
+  if (failed(filter(op)))
+    return failure();
+
   // TODO: support mixed mode contract lowering.
   if (op.getLhsType().getElementType() !=
           getElementTypeOrSelf(op.getAccType()) ||

diff  --git a/mlir/test/Dialect/Vector/vector-contract-transforms.mlir b/mlir/test/Dialect/Vector/vector-contract-transforms.mlir
index 82faadf100e9..6dae907b8bb0 100644
--- a/mlir/test/Dialect/Vector/vector-contract-transforms.mlir
+++ b/mlir/test/Dialect/Vector/vector-contract-transforms.mlir
@@ -1,6 +1,7 @@
 // RUN: mlir-opt %s -test-vector-contraction-conversion | FileCheck %s
 // RUN: mlir-opt %s -test-vector-contraction-conversion=vector-lower-matrix-intrinsics=1 | FileCheck %s --check-prefix=MATRIX
 // RUN: mlir-opt %s -test-vector-contraction-conversion=vector-outerproduct=1 | FileCheck %s --check-prefix=OUTERPRODUCT
+// RUN: mlir-opt %s -test-vector-contraction-conversion=vector-filter-outerproduct=1 | FileCheck %s --check-prefix=FILTEROUTERPRODUCT
 
 #dotp_accesses = [
   affine_map<(i) -> (i)>,
@@ -1029,3 +1030,33 @@ func @matmul_7(%arg0: vector<2x1xf32>, %arg1: vector<1x3xf32>, %arg2: vector<3x2
     : vector<2x1xf32>, vector<1x3xf32> into vector<3x2xf32>
   return %0 : vector<3x2xf32>
 }
+
+// FILTEROUTERPRODUCT-LABEL: func @matmul_4_filtered
+// FILTEROUTERPRODUCT-SAME: %[[A:[a-zA-Z0-9]*]]: vector<4x4xf32>,
+// FILTEROUTERPRODUCT-SAME: %[[B:[a-zA-Z0-9]*]]: vector<4x4xf32>,
+// FILTEROUTERPRODUCT-SAME: %[[C:[a-zA-Z0-9]*]]: vector<4x4xf32>
+//      FILTEROUTERPRODUCT: %[[c0:.*]] = vector.contract {{{.*}}} %[[A]], %[[B]], %[[C]]
+func @matmul_4_filtered(%arg0: vector<4x4xf32>, %arg1: vector<4x4xf32>, %arg2: vector<4x4xf32>)
+-> vector<4x4xf32>
+{
+  %0 = vector.contract #matmat_trait_0 %arg0, %arg1, %arg2
+    : vector<4x4xf32>, vector<4x4xf32> into vector<4x4xf32>
+  return %0 : vector<4x4xf32>
+}
+
+// FILTEROUTERPRODUCT-LABEL: func @matmul_4_not_filtered
+// FILTEROUTERPRODUCT-SAME: %[[A:[a-zA-Z0-9]*]]: vector<3x4xf32>,
+// FILTEROUTERPRODUCT-SAME: %[[B:[a-zA-Z0-9]*]]: vector<4x4xf32>,
+// FILTEROUTERPRODUCT-SAME: %[[C:[a-zA-Z0-9]*]]: vector<3x4xf32>
+//      FILTEROUTERPRODUCT: %[[c0:.*]] = vector.contract {{{.*}}} %[[A]], %[[B]], %[[C]]
+func @matmul_4_not_filtered(%arg0: vector<3x4xf32>, %arg1: vector<4x4xf32>, %arg2: vector<3x4xf32>)
+-> vector<3x4xf32>
+{
+  %0 = vector.contract #matmat_trait_0 %arg0, %arg1, %arg2
+    : vector<3x4xf32>, vector<4x4xf32> into vector<3x4xf32>
+  return %0 : vector<3x4xf32>
+}
+
+
+
+

diff  --git a/mlir/test/lib/Transforms/TestVectorTransforms.cpp b/mlir/test/lib/Transforms/TestVectorTransforms.cpp
index 7e28ebbd9b72..2dffd88ed709 100644
--- a/mlir/test/lib/Transforms/TestVectorTransforms.cpp
+++ b/mlir/test/lib/Transforms/TestVectorTransforms.cpp
@@ -59,6 +59,11 @@ struct TestVectorContractionConversion
       *this, "vector-outerproduct",
       llvm::cl::desc("Lower vector.contract to vector.outerproduct"),
       llvm::cl::init(false)};
+  Option<bool> lowerToFilterOuterProduct{
+      *this, "vector-filter-outerproduct",
+      llvm::cl::desc("Lower vector.contract to vector.outerproduct but not for "
+                     "vectors of size 4."),
+      llvm::cl::init(false)};
 
   void runOnFunction() override {
     OwningRewritePatternList patterns;
@@ -73,6 +78,22 @@ struct TestVectorContractionConversion
       return;
     }
 
+    // Test on one pattern in isolation.
+    if (lowerToFilterOuterProduct) {
+      VectorContractLowering lowering = VectorContractLowering::OuterProduct;
+      VectorTransformsOptions options{lowering};
+      patterns.insert<ContractionOpToOuterProductOpLowering>(
+          options, &getContext(), [](vector::ContractionOp op) {
+            // Only lowers vector.contract where the rhs has a type
+            // vector<MxNx?> where M is not 4.
+            if (op.getRhsType().getShape()[0] == 4)
+              return failure();
+            return success();
+          });
+      applyPatternsAndFoldGreedily(getFunction(), patterns);
+      return;
+    }
+
     // Test on all contract lowering patterns.
     VectorContractLowering contractLowering = VectorContractLowering::Dot;
     if (lowerToFlatMatrix)


        


More information about the Mlir-commits mailing list