[Mlir-commits] [mlir] 05d2297 - [mlir][linalg] Always lower index operations during loop lowering.

Tobias Gysi llvmlistbot at llvm.org
Tue May 4 07:31:17 PDT 2021


Author: Tobias Gysi
Date: 2021-05-04T14:30:59Z
New Revision: 05d2297b869444465134a17ce625b35f859958d0

URL: https://github.com/llvm/llvm-project/commit/05d2297b869444465134a17ce625b35f859958d0
DIFF: https://github.com/llvm/llvm-project/commit/05d2297b869444465134a17ce625b35f859958d0.diff

LOG: [mlir][linalg] Always lower index operations during loop lowering.

Ensure the index operations are lowered on all linalg loop lowering paths.

Differential Revision: https://reviews.llvm.org/D101827

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
    mlir/lib/Dialect/Linalg/Transforms/Loops.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
index d76ccd91fdbf7..7c7ffb2dde577 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -337,18 +337,21 @@ Optional<LinalgOp> promoteSubViews(OpBuilder &b, LinalgOp op,
 LogicalResult vectorizeLinalgOp(OpBuilder &builder, Operation *op,
                                 SmallVectorImpl<Value> &newResults);
 
-/// Emits a loop nest of `LoopTy` with the proper body for `op`.
+/// Emits a loop nest of `LoopTy` with the proper body for `linalgOp`.
 template <typename LoopTy>
-Optional<LinalgLoops> linalgLowerOpToLoops(OpBuilder &builder, Operation *op);
+Optional<LinalgLoops> linalgLowerOpToLoops(PatternRewriter &rewriter,
+                                           LinalgOp linalgOp);
 
-/// Emits a loop nest of `scf.for` with the proper body for `op`.
-LogicalResult linalgOpToLoops(OpBuilder &builder, Operation *op);
+/// Emits a loop nest of `scf.for` with the proper body for `linalgOp`.
+LogicalResult linalgOpToLoops(PatternRewriter &rewriter, LinalgOp linalgOp);
 
-/// Emits a loop nest of `scf.parallel` with the proper body for `op`.
-LogicalResult linalgOpToParallelLoops(OpBuilder &builder, Operation *op);
+/// Emits a loop nest of `scf.parallel` with the proper body for `linalgOp`.
+LogicalResult linalgOpToParallelLoops(PatternRewriter &rewriter,
+                                      LinalgOp linalgOp);
 
-/// Emits a loop nest of `affine.for` with the proper body for `op`.
-LogicalResult linalgOpToAffineLoops(OpBuilder &builder, Operation *op);
+/// Emits a loop nest of `affine.for` with the proper body for `linalgOp`.
+LogicalResult linalgOpToAffineLoops(PatternRewriter &rewriter,
+                                    LinalgOp linalgOp);
 
 //===----------------------------------------------------------------------===//
 // Preconditions that ensure the corresponding transformation succeeds and can

diff  --git a/mlir/lib/Dialect/Linalg/Transforms/Loops.cpp b/mlir/lib/Dialect/Linalg/Transforms/Loops.cpp
index aa0297fdab7d1..5d5dc2ac071bb 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Loops.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Loops.cpp
@@ -457,18 +457,17 @@ static void emitScalarImplementation(ArrayRef<Value> allIvs,
 }
 
 template <typename LoopTy>
-static Optional<LinalgLoops> linalgOpToLoopsImpl(Operation *op,
+static Optional<LinalgLoops> linalgOpToLoopsImpl(LinalgOp linalgOp,
                                                  OpBuilder &builder) {
   using IndexedValueTy = typename GenerateLoopNest<LoopTy>::IndexedValueTy;
-  ScopedContext scope(builder, op->getLoc());
+  ScopedContext scope(builder, linalgOp.getLoc());
 
   // The flattened loopToOperandRangesMaps is expected to be an invertible
   // permutation map (which is asserted in the inverse calculation).
-  auto linalgOp = cast<LinalgOp>(op);
   assert(linalgOp.hasBufferSemantics() &&
          "expected linalg op with buffer semantics");
 
-  auto loopRanges = linalgOp.createLoopRanges(builder, op->getLoc());
+  auto loopRanges = linalgOp.createLoopRanges(builder, linalgOp.getLoc());
   auto iteratorTypes = llvm::to_vector<4>(linalgOp.iterator_types().getValue());
 
   SmallVector<Value, 4> allIvs;
@@ -477,7 +476,7 @@ static Optional<LinalgLoops> linalgOpToLoopsImpl(Operation *op,
       [&](ValueRange ivs, ValueRange iterArgs) -> scf::ValueVector {
         assert(iterArgs.empty() && "unexpected iterArgs");
         allIvs.append(ivs.begin(), ivs.end());
-        llvm::TypeSwitch<Operation *>(op)
+        llvm::TypeSwitch<Operation *>(linalgOp)
             .Case<ConvOp, PoolingMaxOp, PoolingMinOp, PoolingSumOp,
                   IndexedGenericOp, LinalgOp>([&](auto op) {
               emitScalarImplementation<IndexedValueTy>(allIvs, op);
@@ -546,10 +545,8 @@ class LinalgRewritePattern : public RewritePattern {
     auto linalgOp = dyn_cast<LinalgOp>(op);
     if (!isa<LinalgOp>(op))
       return failure();
-    Optional<LinalgLoops> loopOps = linalgOpToLoopsImpl<LoopType>(op, rewriter);
-    if (!loopOps.hasValue())
+    if (!linalgLowerOpToLoops<LoopType>(rewriter, linalgOp))
       return failure();
-    replaceIndexOpsByInductionVariables(linalgOp, rewriter, loopOps.getValue());
     rewriter.eraseOp(op);
     return success();
   }
@@ -695,40 +692,48 @@ mlir::createConvertLinalgToAffineLoopsPass() {
   return std::make_unique<LowerToAffineLoops>();
 }
 
-/// Emits a loop nest with the proper body for `op`.
+/// Emits a loop nest with the proper body for `linalgOp`.
 template <typename LoopTy>
-Optional<LinalgLoops> mlir::linalg::linalgLowerOpToLoops(OpBuilder &builder,
-                                                         Operation *op) {
-  return linalgOpToLoopsImpl<LoopTy>(op, builder);
+Optional<LinalgLoops>
+mlir::linalg::linalgLowerOpToLoops(PatternRewriter &rewriter,
+                                   LinalgOp linalgOp) {
+  Optional<LinalgLoops> loopOps =
+      linalgOpToLoopsImpl<LoopTy>(linalgOp.getOperation(), rewriter);
+  if (loopOps.hasValue())
+    replaceIndexOpsByInductionVariables(linalgOp, rewriter, loopOps.getValue());
+  return loopOps;
 }
 
 template Optional<LinalgLoops>
-mlir::linalg::linalgLowerOpToLoops<AffineForOp>(OpBuilder &builder,
-                                                Operation *op);
+mlir::linalg::linalgLowerOpToLoops<AffineForOp>(PatternRewriter &rewriter,
+                                                LinalgOp linalgOp);
 template Optional<LinalgLoops>
-mlir::linalg::linalgLowerOpToLoops<scf::ForOp>(OpBuilder &builder,
-                                               Operation *op);
+mlir::linalg::linalgLowerOpToLoops<scf::ForOp>(PatternRewriter &rewriter,
+                                               LinalgOp linalgOp);
 template Optional<LinalgLoops>
-mlir::linalg::linalgLowerOpToLoops<scf::ParallelOp>(OpBuilder &builder,
-                                                    Operation *op);
+mlir::linalg::linalgLowerOpToLoops<scf::ParallelOp>(PatternRewriter &rewriter,
+                                                    LinalgOp linalgOp);
 
-/// Emits a loop nest of `affine.for` with the proper body for `op`.
-LogicalResult mlir::linalg::linalgOpToAffineLoops(OpBuilder &builder,
-                                                  Operation *op) {
-  Optional<LinalgLoops> loops = linalgLowerOpToLoops<AffineForOp>(builder, op);
+/// Emits a loop nest of `affine.for` with the proper body for `linalgOp`.
+LogicalResult mlir::linalg::linalgOpToAffineLoops(PatternRewriter &rewriter,
+                                                  LinalgOp linalgOp) {
+  Optional<LinalgLoops> loops =
+      linalgLowerOpToLoops<AffineForOp>(rewriter, linalgOp);
   return loops ? success() : failure();
 }
 
-/// Emits a loop nest of `scf.for` with the proper body for `op`.
-LogicalResult mlir::linalg::linalgOpToLoops(OpBuilder &builder, Operation *op) {
-  Optional<LinalgLoops> loops = linalgLowerOpToLoops<scf::ForOp>(builder, op);
+/// Emits a loop nest of `scf.for` with the proper body for `linalgOp`.
+LogicalResult mlir::linalg::linalgOpToLoops(PatternRewriter &rewriter,
+                                            LinalgOp linalgOp) {
+  Optional<LinalgLoops> loops =
+      linalgLowerOpToLoops<scf::ForOp>(rewriter, linalgOp);
   return loops ? success() : failure();
 }
 
-/// Emits a loop nest of `scf.parallel` with the proper body for `op`.
-LogicalResult mlir::linalg::linalgOpToParallelLoops(OpBuilder &builder,
-                                                    Operation *op) {
+/// Emits a loop nest of `scf.parallel` with the proper body for `linalgOp`.
+LogicalResult mlir::linalg::linalgOpToParallelLoops(PatternRewriter &rewriter,
+                                                    LinalgOp linalgOp) {
   Optional<LinalgLoops> loops =
-      linalgLowerOpToLoops<scf::ParallelOp>(builder, op);
+      linalgLowerOpToLoops<scf::ParallelOp>(rewriter, linalgOp);
   return loops ? success() : failure();
 }


        


More information about the Mlir-commits mailing list