[Mlir-commits] [mlir] 2a99e70 - [mlir][Linalg] NFC: Add utility function to tile, fuse and set marker to use loop.parallel.

llvmlistbot at llvm.org llvmlistbot at llvm.org
Mon Apr 13 13:25:23 PDT 2020


Author: MaheshRavishankar
Date: 2020-04-13T13:23:06-07:00
New Revision: 2a99e700e0f337c34c2d9d1cb5e4dc1d312fa248

URL: https://github.com/llvm/llvm-project/commit/2a99e700e0f337c34c2d9d1cb5e4dc1d312fa248
DIFF: https://github.com/llvm/llvm-project/commit/2a99e700e0f337c34c2d9d1cb5e4dc1d312fa248.diff

LOG: [mlir][Linalg] NFC: Add utility function to tile, fuse and set marker to use loop.parallel.

This change is NFC (no functional change), since the facility to tile and
generate loop.parallel loops already exists in Linalg; this commit only
exposes it through new utility functions.

Differential Revision: https://reviews.llvm.org/D77965

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Linalg/Transforms/LinalgTransforms.h
    mlir/lib/Dialect/Linalg/Transforms/LinalgTransforms.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Linalg/Transforms/LinalgTransforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/LinalgTransforms.h
index 3bff0f1bbf6e..4340366845c1 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/LinalgTransforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/LinalgTransforms.h
@@ -63,12 +63,18 @@ LogicalResult tileLinalgOpAndSetMarker(PatternRewriter &rewriter, Operation *op,
                                        ArrayRef<int64_t> sizes,
                                        StringRef linalgMarker,
                                        ArrayRef<unsigned> permutation);
+LogicalResult tileLinalgOpToParallelLoopsAndSetMarker(
+    PatternRewriter &rewriter, Operation *op, ArrayRef<int64_t> sizes,
+    StringRef linalgMarker, ArrayRef<unsigned> permutation);
 
 /// Tiles `op` by `sizes`, fuses the producers of `operandIndicesToFuse` and
 /// sets the attribute `kLinalgTransformMarker` to `linalgMarker`.
 LogicalResult tileAndFuseLinalgOpAndSetMarker(
     PatternRewriter &rewriter, Operation *op, ArrayRef<int64_t> sizes,
     ArrayRef<int64_t> operandIndicesToFuse, StringRef linalgMarker);
+LogicalResult tileAndFuseLinalgOpToParallelLoopsAndSetMarker(
+    PatternRewriter &rewriter, Operation *op, ArrayRef<int64_t> sizes,
+    ArrayRef<int64_t> operandIndicesToFuse, StringRef linalgMarker);
 
 using LinalgLoops = SmallVector<Operation *, 4>;
 

diff  --git a/mlir/lib/Dialect/Linalg/Transforms/LinalgTransforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/LinalgTransforms.cpp
index 8511c8efffce..2e7043d9f24a 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/LinalgTransforms.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/LinalgTransforms.cpp
@@ -40,11 +40,16 @@ using llvm::SetVector;
 const StringLiteral mlir::linalg::LinalgTransforms::kLinalgTransformMarker =
     "__internal_linalg_transform__";
 
-LogicalResult mlir::linalg::tileLinalgOpAndSetMarker(
-    PatternRewriter &rewriter, Operation *op, ArrayRef<int64_t> sizes,
-    StringRef linalgMarker, ArrayRef<unsigned> permutation) {
+using TileFn = Optional<TiledLinalgOp>(OpBuilder &, LinalgOp, ArrayRef<int64_t>,
+                                       ArrayRef<unsigned>, OperationFolder *);
+
+static LogicalResult
+tileLinalgOpAndSetMarkerImpl(TileFn tileFn, PatternRewriter &rewriter,
+                             Operation *op, ArrayRef<int64_t> sizes,
+                             StringRef linalgMarker,
+                             ArrayRef<unsigned> permutation) {
   assert(permutation.empty() || permutation.size() == sizes.size());
-  auto tileRes = tileLinalgOperation(rewriter, op, sizes, permutation);
+  auto tileRes = tileFn(rewriter, op, sizes, permutation, /*folder=*/nullptr);
   if (!tileRes)
     return failure();
   tileRes->op.setAttr(LinalgTransforms::kLinalgTransformMarker,
@@ -52,10 +57,26 @@ LogicalResult mlir::linalg::tileLinalgOpAndSetMarker(
   return success();
 }
 
-LogicalResult mlir::linalg::tileAndFuseLinalgOpAndSetMarker(
+LogicalResult mlir::linalg::tileLinalgOpAndSetMarker(
     PatternRewriter &rewriter, Operation *op, ArrayRef<int64_t> sizes,
-    ArrayRef<int64_t> operandIndicesToFuse, StringRef linalgMarker) {
-  auto tileRes = tileLinalgOperation(rewriter, op, sizes);
+    StringRef linalgMarker, ArrayRef<unsigned> permutation) {
+  return tileLinalgOpAndSetMarkerImpl(tileLinalgOp, rewriter, op, sizes,
+                                      linalgMarker, permutation);
+}
+LogicalResult mlir::linalg::tileLinalgOpToParallelLoopsAndSetMarker(
+    PatternRewriter &rewriter, Operation *op, ArrayRef<int64_t> sizes,
+    StringRef linalgMarker, ArrayRef<unsigned> permutation) {
+  return tileLinalgOpAndSetMarkerImpl(tileLinalgOpToParallelLoops, rewriter, op,
+                                      sizes, linalgMarker, permutation);
+}
+
+static LogicalResult
+tileAndFuseLinalgOpAndSetMarkerImpl(TileFn tileFn, PatternRewriter &rewriter,
+                                    Operation *op, ArrayRef<int64_t> sizes,
+                                    ArrayRef<int64_t> operandIndicesToFuse,
+                                    StringRef linalgMarker) {
+  auto tileRes =
+      tileFn(rewriter, op, sizes, /*permutation=*/{}, /*folder=*/nullptr);
   if (!tileRes)
     return failure();
   tileRes->op.setAttr(LinalgTransforms::kLinalgTransformMarker,
@@ -89,6 +110,20 @@ LogicalResult mlir::linalg::tileAndFuseLinalgOpAndSetMarker(
   return success();
 }
 
+LogicalResult mlir::linalg::tileAndFuseLinalgOpAndSetMarker(
+    PatternRewriter &rewriter, Operation *op, ArrayRef<int64_t> sizes,
+    ArrayRef<int64_t> operandIndicesToFuse, StringRef linalgMarker) {
+  return tileAndFuseLinalgOpAndSetMarkerImpl(
+      tileLinalgOp, rewriter, op, sizes, operandIndicesToFuse, linalgMarker);
+}
+LogicalResult mlir::linalg::tileAndFuseLinalgOpToParallelLoopsAndSetMarker(
+    PatternRewriter &rewriter, Operation *op, ArrayRef<int64_t> sizes,
+    ArrayRef<int64_t> operandIndicesToFuse, StringRef linalgMarker) {
+  return tileAndFuseLinalgOpAndSetMarkerImpl(
+      tileLinalgOpToParallelLoops, rewriter, op, sizes, operandIndicesToFuse,
+      linalgMarker);
+}
+
 bool mlir::linalg::detail::isProducedByOpOfTypeImpl(
     Operation *consumerOp, Value consumedView,
     function_ref<bool(Operation *)> isaOpType) {


        


More information about the Mlir-commits mailing list