[Mlir-commits] [mlir] 1353cbc - [mlir][Vector] NFC - Use matchAndRewrite in ContractionOp lowering patterns

Nicolas Vasilache llvmlistbot at llvm.org
Thu Aug 6 06:04:49 PDT 2020


Author: Nicolas Vasilache
Date: 2020-08-06T09:02:25-04:00
New Revision: 1353cbc2570b2fe4b418a9acea9778eca5625fb7

URL: https://github.com/llvm/llvm-project/commit/1353cbc2570b2fe4b418a9acea9778eca5625fb7
DIFF: https://github.com/llvm/llvm-project/commit/1353cbc2570b2fe4b418a9acea9778eca5625fb7.diff

LOG: [mlir][Vector] NFC - Use matchAndRewrite in ContractionOp lowering patterns

Replace the separate `match` and `rewrite` overrides, which unnecessarily duplicated logic, with a single `matchAndRewrite`.

Differential Revision: https://reviews.llvm.org/D85421

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Vector/VectorTransforms.h
    mlir/lib/Dialect/Vector/VectorTransforms.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Vector/VectorTransforms.h b/mlir/include/mlir/Dialect/Vector/VectorTransforms.h
index e6c7b7abebd5..35855b3b2137 100644
--- a/mlir/include/mlir/Dialect/Vector/VectorTransforms.h
+++ b/mlir/include/mlir/Dialect/Vector/VectorTransforms.h
@@ -208,9 +208,8 @@ class ContractionOpToMatmulOpLowering
       : OpRewritePattern<vector::ContractionOp>(context),
         vectorTransformsOptions(vectorTransformsOptions), filter(constraint) {}
 
-  LogicalResult match(vector::ContractionOp op) const override;
-  void rewrite(vector::ContractionOp op,
-               PatternRewriter &rewriter) const override;
+  LogicalResult matchAndRewrite(vector::ContractionOp op,
+                                PatternRewriter &rewriter) const override;
 
 private:
   /// Options to control the vector patterns.
@@ -250,9 +249,8 @@ class ContractionOpToOuterProductOpLowering
       : OpRewritePattern<vector::ContractionOp>(context),
         vectorTransformsOptions(vectorTransformsOptions), filter(constraint) {}
 
-  LogicalResult match(vector::ContractionOp op) const override;
-  void rewrite(vector::ContractionOp op,
-               PatternRewriter &rewriter) const override;
+  LogicalResult matchAndRewrite(vector::ContractionOp op,
+                                PatternRewriter &rewriter) const override;
 
 private:
   /// Options to control the vector patterns.

diff  --git a/mlir/lib/Dialect/Vector/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/VectorTransforms.cpp
index 33fbed65ace6..922168947ccf 100644
--- a/mlir/lib/Dialect/Vector/VectorTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/VectorTransforms.cpp
@@ -1576,16 +1576,14 @@ namespace mlir {
 //
 /// This only kicks in when VectorTransformsOptions is set to OuterProduct and
 /// the vector.contract op is a row-major matrix multiply.
-LogicalResult
-ContractionOpToMatmulOpLowering::match(vector::ContractionOp op) const {
+LogicalResult ContractionOpToMatmulOpLowering::matchAndRewrite(
+    vector::ContractionOp op, PatternRewriter &rewriter) const {
   // TODO: implement masks
   if (llvm::size(op.masks()) != 0)
     return failure();
-
   if (vectorTransformsOptions.vectorContractLowering !=
       vector::VectorContractLowering::Matmul)
     return failure();
-
   if (failed(filter(op)))
     return failure();
 
@@ -1598,11 +1596,10 @@ ContractionOpToMatmulOpLowering::match(vector::ContractionOp op) const {
   if (!isRowMajorMatmul(op.indexing_maps()))
     return failure();
 
-  return success();
-}
+  Type elementType = op.getLhsType().getElementType();
+  if (!elementType.isIntOrFloat())
+    return failure();
 
-void ContractionOpToMatmulOpLowering::rewrite(vector::ContractionOp op,
-                                              PatternRewriter &rewriter) const {
   VectorType lhsType = op.getLhsType();
   VectorType rhsType = op.getRhsType();
   int64_t lhsRows = lhsType.getDimSize(0);
@@ -1622,12 +1619,12 @@ void ContractionOpToMatmulOpLowering::rewrite(vector::ContractionOp op,
                                                 lhsColumns, rhsColumns);
   mul = rewriter.create<vector::ShapeCastOp>(op.getLoc(), op.acc().getType(),
                                              mul);
-  Type elementType = op.getLhsType().getElementType();
-  assert(elementType.isIntOrFloat());
   if (elementType.isa<IntegerType>())
     rewriter.replaceOpWithNewOp<AddIOp>(op, op.acc(), mul);
   else
     rewriter.replaceOpWithNewOp<AddFOp>(op, op.acc(), mul);
+
+  return success();
 }
 
 /// Progressively lower a `vector.contract %a, %b, %c` with row-major matmul
@@ -1645,8 +1642,8 @@ void ContractionOpToMatmulOpLowering::rewrite(vector::ContractionOp op,
 ///
 /// This only kicks in when VectorTransformsOptions is set to OuterProduct but
 /// otherwise supports any layout permutation of the matrix-multiply.
-LogicalResult
-ContractionOpToOuterProductOpLowering::match(vector::ContractionOp op) const {
+LogicalResult ContractionOpToOuterProductOpLowering::matchAndRewrite(
+    vector::ContractionOp op, PatternRewriter &rewriter) const {
   // TODO: implement masks
   if (llvm::size(op.masks()) != 0)
     return failure();
@@ -1658,50 +1655,6 @@ ContractionOpToOuterProductOpLowering::match(vector::ContractionOp op) const {
   if (failed(filter(op)))
     return failure();
 
-  // Determine if the parallel/reduction structure matches something
-  // that can be expressed a reduction_size unrolled sequence.
-  using MapList = ArrayRef<ArrayRef<AffineExpr>>;
-  auto infer = [](MapList m) { return AffineMap::inferFromExprList(m); };
-  AffineExpr m, n, k;
-  bindDims(op.getContext(), m, n, k);
-  auto iteratorTypes = op.iterator_types().getValue();
-  SmallVector<AffineMap, 4> maps = op.getIndexingMaps();
-  if (isParallelIterator(iteratorTypes[0]) &&
-      isParallelIterator(iteratorTypes[1]) &&
-      isReductionIterator(iteratorTypes[2])) {
-    //
-    // Two outer parallel, one inner reduction (matmat flavor).
-    // When lowering to outerproduct we can support all permutations.
-    //
-    if (maps != infer({{m, k}, {k, n}, {m, n}}) &&
-        maps != infer({{m, k}, {n, k}, {m, n}}) &&
-        maps != infer({{k, m}, {k, n}, {m, n}}) &&
-        maps != infer({{k, m}, {n, k}, {m, n}}) &&
-        maps != infer({{m, k}, {k, n}, {n, m}}) &&
-        maps != infer({{m, k}, {n, k}, {n, m}}) &&
-        maps != infer({{k, m}, {k, n}, {n, m}}) &&
-        maps != infer({{k, m}, {n, k}, {n, m}}))
-      return failure();
-    return success();
-  } else if (isParallelIterator(iteratorTypes[0]) &&
-             isReductionIterator(iteratorTypes[1])) {
-    //
-    // One outer parallel, one inner reduction (matvec flavor)
-    // See if a series of AXPY operations chained through FMA operations
-    // could replace the default DOT implementation.
-    //
-    if (maps != infer({{m, n}, {n}, {m}}) && // mat-vec
-        maps != infer({{n, m}, {n}, {m}}) && // mat-trans-vec
-        maps != infer({{n}, {m, n}, {m}}) && // vec-mat
-        maps != infer({{n}, {n, m}, {m}}))   // vec-mat-trans
-      return failure();
-    return success();
-  }
-  return failure();
-}
-
-void ContractionOpToOuterProductOpLowering::rewrite(
-    vector::ContractionOp op, PatternRewriter &rewriter) const {
   Location loc = op.getLoc();
   int64_t reductionSize = 0;
   VectorType lhsType = op.getLhsType();
@@ -1759,13 +1712,14 @@ void ContractionOpToOuterProductOpLowering::rewrite(
       Value tmp = lhs;
       lhs = rewriter.create<vector::TransposeOp>(loc, rhs, perm);
       rhs = tmp;
+    } else {
+      return failure();
     }
-  } else {
+  } else if (isParallelIterator(iteratorTypes[0]) &&
+             isReductionIterator(iteratorTypes[1])) {
     //
     // One outer parallel, one inner reduction (matvec flavor)
     //
-    assert(isParallelIterator(iteratorTypes[0]) &&
-           isReductionIterator(iteratorTypes[1]));
     if (maps == infer({{m, n}, {n}, {m}})) {
       // Case mat-vec: transpose.
       reductionSize = lhsType.getDimSize(1);
@@ -1782,7 +1736,11 @@ void ContractionOpToOuterProductOpLowering::rewrite(
       // Case vec-mat-trans: swap and ready to go.
       reductionSize = lhsType.getDimSize(0);
       std::swap(lhs, rhs);
+    } else {
+      return failure();
     }
+  } else {
+    return failure();
   }
   assert(reductionSize > 0);
 
@@ -1793,6 +1751,7 @@ void ContractionOpToOuterProductOpLowering::rewrite(
     res = rewriter.create<vector::OuterProductOp>(op.getLoc(), a, b, res);
   }
   rewriter.replaceOp(op, res);
+  return success();
 }
 
 /// Progressive lowering of ContractionOp.
@@ -1815,7 +1774,6 @@ void ContractionOpToOuterProductOpLowering::rewrite(
 LogicalResult
 ContractionOpLowering::matchAndRewrite(vector::ContractionOp op,
                                        PatternRewriter &rewriter) const {
-
   // TODO: implement masks.
   if (llvm::size(op.masks()) != 0)
     return failure();
@@ -1832,11 +1790,11 @@ ContractionOpLowering::matchAndRewrite(vector::ContractionOp op,
   // TODO: implement benefits, cost models.
   MLIRContext *ctx = op.getContext();
   ContractionOpToMatmulOpLowering pat1(vectorTransformsOptions, ctx);
-  if (succeeded(pat1.match(op)))
-    return failure();
+  if (succeeded(pat1.matchAndRewrite(op, rewriter)))
+    return success();
   ContractionOpToOuterProductOpLowering pat2(vectorTransformsOptions, ctx);
-  if (succeeded(pat2.match(op)))
-    return failure();
+  if (succeeded(pat2.matchAndRewrite(op, rewriter)))
+    return success();
 
   // Find first batch dimension in LHS/RHS, and lower when found.
   std::vector<std::pair<int64_t, int64_t>> batchDimMap = op.getBatchDimMap();


        


More information about the Mlir-commits mailing list