[Mlir-commits] [mlir] 7c16f93 - [mlir][linalg] Remove template parameter from loop lowering.

Tobias Gysi llvmlistbot at llvm.org
Mon May 17 02:32:41 PDT 2021


Author: Tobias Gysi
Date: 2021-05-17T09:31:53Z
New Revision: 7c16f93c44caa341404ff78a14eba163cd243e5e

URL: https://github.com/llvm/llvm-project/commit/7c16f93c44caa341404ff78a14eba163cd243e5e
DIFF: https://github.com/llvm/llvm-project/commit/7c16f93c44caa341404ff78a14eba163cd243e5e.diff

LOG: [mlir][linalg] Remove template parameter from loop lowering.

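Replace the templated linalgLowerOpToLoops function with three specialized functions: linalgOpToLoops, linalgOpToParallelLoops, and linalgOpToAffineLoops. These functions now return the generated loops as Optional<LinalgLoops> instead of a LogicalResult.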

Differential Revision: https://reviews.llvm.org/D102324

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
    mlir/lib/Dialect/Linalg/Transforms/Loops.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
index 501c34f5c46b0..491c59d838c45 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -342,21 +342,17 @@ Optional<LinalgOp> promoteSubViews(OpBuilder &b, LinalgOp op,
 LogicalResult vectorizeLinalgOp(OpBuilder &builder, Operation *op,
                                 SmallVectorImpl<Value> &newResults);
 
-/// Emits a loop nest of `LoopTy` with the proper body for `linalgOp`.
-template <typename LoopTy>
-Optional<LinalgLoops> linalgLowerOpToLoops(PatternRewriter &rewriter,
-                                           LinalgOp linalgOp);
-
 /// Emits a loop nest of `scf.for` with the proper body for `linalgOp`.
-LogicalResult linalgOpToLoops(PatternRewriter &rewriter, LinalgOp linalgOp);
+Optional<LinalgLoops> linalgOpToLoops(PatternRewriter &rewriter,
+                                      LinalgOp linalgOp);
 
 /// Emits a loop nest of `scf.parallel` with the proper body for `linalgOp`.
-LogicalResult linalgOpToParallelLoops(PatternRewriter &rewriter,
-                                      LinalgOp linalgOp);
+Optional<LinalgLoops> linalgOpToParallelLoops(PatternRewriter &rewriter,
+                                              LinalgOp linalgOp);
 
 /// Emits a loop nest of `affine.for` with the proper body for `linalgOp`.
-LogicalResult linalgOpToAffineLoops(PatternRewriter &rewriter,
-                                    LinalgOp linalgOp);
+Optional<LinalgLoops> linalgOpToAffineLoops(PatternRewriter &rewriter,
+                                            LinalgOp linalgOp);
 
 //===----------------------------------------------------------------------===//
 // Preconditions that ensure the corresponding transformation succeeds and can
@@ -814,15 +810,15 @@ struct LinalgLoweringPattern : public RewritePattern {
       // TODO: Move lowering to library calls here.
       return failure();
     case LinalgLoweringType::Loops:
-      if (failed(linalgOpToLoops(rewriter, op)))
+      if (!linalgOpToLoops(rewriter, op))
         return failure();
       break;
     case LinalgLoweringType::AffineLoops:
-      if (failed(linalgOpToAffineLoops(rewriter, op)))
+      if (!linalgOpToAffineLoops(rewriter, op))
         return failure();
       break;
     case LinalgLoweringType::ParallelLoops:
-      if (failed(linalgOpToParallelLoops(rewriter, op)))
+      if (!linalgOpToParallelLoops(rewriter, op))
         return failure();
       break;
     }

diff  --git a/mlir/lib/Dialect/Linalg/Transforms/Loops.cpp b/mlir/lib/Dialect/Linalg/Transforms/Loops.cpp
index b1bf213e9cbb6..317a9864516ab 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Loops.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Loops.cpp
@@ -378,18 +378,54 @@ static void emitScalarImplementation(ArrayRef<Value> allIvs, PoolingSumOp op) {
       getPoolingInput<IndexedValueType>(op, indices.inputs);
 }
 
+/// Replace the index operations in the body of the loop nest by the matching
+/// induction variables.
+static void replaceIndexOpsByInductionVariables(LinalgOp linalgOp,
+                                                PatternRewriter &rewriter,
+                                                ArrayRef<Operation *> loopOps) {
+  // Extract the induction variables of the loop nest from outer to inner.
+  SmallVector<Value> allIvs;
+  for (Operation *loopOp : loopOps) {
+    llvm::TypeSwitch<Operation *>(loopOp)
+        .Case([&](scf::ParallelOp parallelOp) {
+          allIvs.append(parallelOp.getInductionVars().begin(),
+                        parallelOp.getInductionVars().end());
+        })
+        .Case([&](scf::ForOp forOp) {
+          allIvs.push_back(forOp.getInductionVar());
+        })
+        .Case([&](AffineForOp affineForOp) {
+          allIvs.push_back(affineForOp.getInductionVar());
+        })
+        .Default([&](Operation *op) { assert(false && "unexpected op"); });
+  }
+  assert(linalgOp.getNumLoops() == allIvs.size() &&
+         "expected the number of loops and induction variables to match");
+  // Replace the index operations in the body of the innermost loop op.
+  if (!loopOps.empty()) {
+    LoopLikeOpInterface loopOp = loopOps.back();
+    for (IndexOp indexOp :
+         llvm::make_early_inc_range(loopOp.getLoopBody().getOps<IndexOp>()))
+      rewriter.replaceOp(indexOp, allIvs[indexOp.dim()]);
+  }
+}
+
 template <typename LoopTy>
-static Optional<LinalgLoops> linalgOpToLoopsImpl(LinalgOp linalgOp,
-                                                 OpBuilder &builder) {
+static Optional<LinalgLoops> linalgOpToLoopsImpl(PatternRewriter &rewriter,
+                                                 LinalgOp linalgOp) {
   using IndexedValueTy = typename GenerateLoopNest<LoopTy>::IndexedValueTy;
-  ScopedContext scope(builder, linalgOp.getLoc());
+  ScopedContext scope(rewriter, linalgOp.getLoc());
+
+  // Canonicalize indexed_generic operations before lowering them to loops.
+  if (isa<IndexedGenericOp>(linalgOp))
+    return llvm::None;
 
   // The flattened loopToOperandRangesMaps is expected to be an invertible
   // permutation map (which is asserted in the inverse calculation).
   assert(linalgOp.hasBufferSemantics() &&
          "expected linalg op with buffer semantics");
 
-  auto loopRanges = linalgOp.createLoopRanges(builder, linalgOp.getLoc());
+  auto loopRanges = linalgOp.createLoopRanges(rewriter, linalgOp.getLoc());
   auto iteratorTypes = llvm::to_vector<4>(linalgOp.iterator_types().getValue());
 
   SmallVector<Value, 4> allIvs;
@@ -420,41 +456,11 @@ static Optional<LinalgLoops> linalgOpToLoopsImpl(LinalgOp linalgOp,
     loopSet.insert(ivVal.getOwner()->getParentOp());
   }
   LinalgLoops loops(loopSet.begin(), loopSet.end());
+  // Replace all index operations in the loop body.
+  replaceIndexOpsByInductionVariables(linalgOp, rewriter, loops);
   return loops;
 }
 
-/// Replace the index operations in the body of the loop nest by the matching
-/// induction variables.
-static void replaceIndexOpsByInductionVariables(LinalgOp linalgOp,
-                                                PatternRewriter &rewriter,
-                                                ArrayRef<Operation *> loopOps) {
-  // Extract the induction variables of the loop nest from outer to inner.
-  SmallVector<Value> allIvs;
-  for (Operation *loopOp : loopOps) {
-    llvm::TypeSwitch<Operation *>(loopOp)
-        .Case([&](scf::ParallelOp parallelOp) {
-          allIvs.append(parallelOp.getInductionVars().begin(),
-                        parallelOp.getInductionVars().end());
-        })
-        .Case([&](scf::ForOp forOp) {
-          allIvs.push_back(forOp.getInductionVar());
-        })
-        .Case([&](AffineForOp affineForOp) {
-          allIvs.push_back(affineForOp.getInductionVar());
-        })
-        .Default([&](Operation *op) { assert(false && "unexpected op"); });
-  }
-  assert(linalgOp.getNumLoops() == allIvs.size() &&
-         "expected the number of loops and induction variables to match");
-  // Replace the index operations in the body of the innermost loop op.
-  if (!loopOps.empty()) {
-    LoopLikeOpInterface loopOp = loopOps.back();
-    for (IndexOp indexOp :
-         llvm::make_early_inc_range(loopOp.getLoopBody().getOps<IndexOp>()))
-      rewriter.replaceOp(indexOp, allIvs[indexOp.dim()]);
-  }
-}
-
 namespace {
 template <typename LoopType>
 class LinalgRewritePattern : public RewritePattern {
@@ -467,7 +473,7 @@ class LinalgRewritePattern : public RewritePattern {
     auto linalgOp = dyn_cast<LinalgOp>(op);
     if (!isa<LinalgOp>(op))
       return failure();
-    if (!linalgLowerOpToLoops<LoopType>(rewriter, linalgOp))
+    if (!linalgOpToLoopsImpl<LoopType>(rewriter, linalgOp))
       return failure();
     rewriter.eraseOp(op);
     return success();
@@ -614,52 +620,22 @@ mlir::createConvertLinalgToAffineLoopsPass() {
   return std::make_unique<LowerToAffineLoops>();
 }
 
-/// Emits a loop nest with the proper body for `linalgOp`.
-template <typename LoopTy>
-Optional<LinalgLoops>
-mlir::linalg::linalgLowerOpToLoops(PatternRewriter &rewriter,
-                                   LinalgOp linalgOp) {
-  // Convert indexed_generic ops to generic ops before lowering them to loops.
-  if (isa<IndexedGenericOp>(linalgOp))
-    return llvm::None;
-
-  Optional<LinalgLoops> loopOps =
-      linalgOpToLoopsImpl<LoopTy>(linalgOp.getOperation(), rewriter);
-  if (loopOps.hasValue())
-    replaceIndexOpsByInductionVariables(linalgOp, rewriter, loopOps.getValue());
-  return loopOps;
-}
-
-template Optional<LinalgLoops>
-mlir::linalg::linalgLowerOpToLoops<AffineForOp>(PatternRewriter &rewriter,
-                                                LinalgOp linalgOp);
-template Optional<LinalgLoops>
-mlir::linalg::linalgLowerOpToLoops<scf::ForOp>(PatternRewriter &rewriter,
-                                               LinalgOp linalgOp);
-template Optional<LinalgLoops>
-mlir::linalg::linalgLowerOpToLoops<scf::ParallelOp>(PatternRewriter &rewriter,
-                                                    LinalgOp linalgOp);
-
 /// Emits a loop nest of `affine.for` with the proper body for `linalgOp`.
-LogicalResult mlir::linalg::linalgOpToAffineLoops(PatternRewriter &rewriter,
-                                                  LinalgOp linalgOp) {
-  Optional<LinalgLoops> loops =
-      linalgLowerOpToLoops<AffineForOp>(rewriter, linalgOp);
-  return loops ? success() : failure();
+Optional<LinalgLoops>
+mlir::linalg::linalgOpToAffineLoops(PatternRewriter &rewriter,
+                                    LinalgOp linalgOp) {
+  return linalgOpToLoopsImpl<AffineForOp>(rewriter, linalgOp);
 }
 
 /// Emits a loop nest of `scf.for` with the proper body for `linalgOp`.
-LogicalResult mlir::linalg::linalgOpToLoops(PatternRewriter &rewriter,
-                                            LinalgOp linalgOp) {
-  Optional<LinalgLoops> loops =
-      linalgLowerOpToLoops<scf::ForOp>(rewriter, linalgOp);
-  return loops ? success() : failure();
+Optional<LinalgLoops> mlir::linalg::linalgOpToLoops(PatternRewriter &rewriter,
+                                                    LinalgOp linalgOp) {
+  return linalgOpToLoopsImpl<scf::ForOp>(rewriter, linalgOp);
 }
 
 /// Emits a loop nest of `scf.parallel` with the proper body for `linalgOp`.
-LogicalResult mlir::linalg::linalgOpToParallelLoops(PatternRewriter &rewriter,
-                                                    LinalgOp linalgOp) {
-  Optional<LinalgLoops> loops =
-      linalgLowerOpToLoops<scf::ParallelOp>(rewriter, linalgOp);
-  return loops ? success() : failure();
+Optional<LinalgLoops>
+mlir::linalg::linalgOpToParallelLoops(PatternRewriter &rewriter,
+                                      LinalgOp linalgOp) {
+  return linalgOpToLoopsImpl<scf::ParallelOp>(rewriter, linalgOp);
 }


        


More information about the Mlir-commits mailing list