[Mlir-commits] [mlir] 7c16f93 - [mlir][linalg] Remove template parameter from loop lowering.
Tobias Gysi
llvmlistbot at llvm.org
Mon May 17 02:32:41 PDT 2021
Author: Tobias Gysi
Date: 2021-05-17T09:31:53Z
New Revision: 7c16f93c44caa341404ff78a14eba163cd243e5e
URL: https://github.com/llvm/llvm-project/commit/7c16f93c44caa341404ff78a14eba163cd243e5e
DIFF: https://github.com/llvm/llvm-project/commit/7c16f93c44caa341404ff78a14eba163cd243e5e.diff
LOG: [mlir][linalg] Remove template parameter from loop lowering.
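Replace the templated linalgLowerOpToLoops function with three specialized functions: linalgOpToLoops, linalgOpToParallelLoops, and linalgOpToAffineLoops. These functions now return Optional<LinalgLoops> instead of LogicalResult, so callers can access the generated loops.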
Differential Revision: https://reviews.llvm.org/D102324
Added:
Modified:
mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
mlir/lib/Dialect/Linalg/Transforms/Loops.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
index 501c34f5c46b0..491c59d838c45 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -342,21 +342,17 @@ Optional<LinalgOp> promoteSubViews(OpBuilder &b, LinalgOp op,
LogicalResult vectorizeLinalgOp(OpBuilder &builder, Operation *op,
SmallVectorImpl<Value> &newResults);
-/// Emits a loop nest of `LoopTy` with the proper body for `linalgOp`.
-template <typename LoopTy>
-Optional<LinalgLoops> linalgLowerOpToLoops(PatternRewriter &rewriter,
- LinalgOp linalgOp);
-
/// Emits a loop nest of `scf.for` with the proper body for `linalgOp`.
-LogicalResult linalgOpToLoops(PatternRewriter &rewriter, LinalgOp linalgOp);
+Optional<LinalgLoops> linalgOpToLoops(PatternRewriter &rewriter,
+ LinalgOp linalgOp);
/// Emits a loop nest of `scf.parallel` with the proper body for `linalgOp`.
-LogicalResult linalgOpToParallelLoops(PatternRewriter &rewriter,
- LinalgOp linalgOp);
+Optional<LinalgLoops> linalgOpToParallelLoops(PatternRewriter &rewriter,
+ LinalgOp linalgOp);
/// Emits a loop nest of `affine.for` with the proper body for `linalgOp`.
-LogicalResult linalgOpToAffineLoops(PatternRewriter &rewriter,
- LinalgOp linalgOp);
+Optional<LinalgLoops> linalgOpToAffineLoops(PatternRewriter &rewriter,
+ LinalgOp linalgOp);
//===----------------------------------------------------------------------===//
// Preconditions that ensure the corresponding transformation succeeds and can
@@ -814,15 +810,15 @@ struct LinalgLoweringPattern : public RewritePattern {
// TODO: Move lowering to library calls here.
return failure();
case LinalgLoweringType::Loops:
- if (failed(linalgOpToLoops(rewriter, op)))
+ if (!linalgOpToLoops(rewriter, op))
return failure();
break;
case LinalgLoweringType::AffineLoops:
- if (failed(linalgOpToAffineLoops(rewriter, op)))
+ if (!linalgOpToAffineLoops(rewriter, op))
return failure();
break;
case LinalgLoweringType::ParallelLoops:
- if (failed(linalgOpToParallelLoops(rewriter, op)))
+ if (!linalgOpToParallelLoops(rewriter, op))
return failure();
break;
}
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Loops.cpp b/mlir/lib/Dialect/Linalg/Transforms/Loops.cpp
index b1bf213e9cbb6..317a9864516ab 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Loops.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Loops.cpp
@@ -378,18 +378,54 @@ static void emitScalarImplementation(ArrayRef<Value> allIvs, PoolingSumOp op) {
getPoolingInput<IndexedValueType>(op, indices.inputs);
}
+/// Replace the index operations in the body of the loop nest by the matching
+/// induction variables.
+static void replaceIndexOpsByInductionVariables(LinalgOp linalgOp,
+ PatternRewriter &rewriter,
+ ArrayRef<Operation *> loopOps) {
+ // Extract the induction variables of the loop nest from outer to inner.
+ SmallVector<Value> allIvs;
+ for (Operation *loopOp : loopOps) {
+ llvm::TypeSwitch<Operation *>(loopOp)
+ .Case([&](scf::ParallelOp parallelOp) {
+ allIvs.append(parallelOp.getInductionVars().begin(),
+ parallelOp.getInductionVars().end());
+ })
+ .Case([&](scf::ForOp forOp) {
+ allIvs.push_back(forOp.getInductionVar());
+ })
+ .Case([&](AffineForOp affineForOp) {
+ allIvs.push_back(affineForOp.getInductionVar());
+ })
+ .Default([&](Operation *op) { assert(false && "unexpected op"); });
+ }
+ assert(linalgOp.getNumLoops() == allIvs.size() &&
+ "expected the number of loops and induction variables to match");
+ // Replace the index operations in the body of the innermost loop op.
+ if (!loopOps.empty()) {
+ LoopLikeOpInterface loopOp = loopOps.back();
+ for (IndexOp indexOp :
+ llvm::make_early_inc_range(loopOp.getLoopBody().getOps<IndexOp>()))
+ rewriter.replaceOp(indexOp, allIvs[indexOp.dim()]);
+ }
+}
+
template <typename LoopTy>
-static Optional<LinalgLoops> linalgOpToLoopsImpl(LinalgOp linalgOp,
- OpBuilder &builder) {
+static Optional<LinalgLoops> linalgOpToLoopsImpl(PatternRewriter &rewriter,
+ LinalgOp linalgOp) {
using IndexedValueTy = typename GenerateLoopNest<LoopTy>::IndexedValueTy;
- ScopedContext scope(builder, linalgOp.getLoc());
+ ScopedContext scope(rewriter, linalgOp.getLoc());
+
+ // Canonicalize indexed_generic operations before lowering them to loops.
+ if (isa<IndexedGenericOp>(linalgOp))
+ return llvm::None;
// The flattened loopToOperandRangesMaps is expected to be an invertible
// permutation map (which is asserted in the inverse calculation).
assert(linalgOp.hasBufferSemantics() &&
"expected linalg op with buffer semantics");
- auto loopRanges = linalgOp.createLoopRanges(builder, linalgOp.getLoc());
+ auto loopRanges = linalgOp.createLoopRanges(rewriter, linalgOp.getLoc());
auto iteratorTypes = llvm::to_vector<4>(linalgOp.iterator_types().getValue());
SmallVector<Value, 4> allIvs;
@@ -420,41 +456,11 @@ static Optional<LinalgLoops> linalgOpToLoopsImpl(LinalgOp linalgOp,
loopSet.insert(ivVal.getOwner()->getParentOp());
}
LinalgLoops loops(loopSet.begin(), loopSet.end());
+ // Replace all index operations in the loop body.
+ replaceIndexOpsByInductionVariables(linalgOp, rewriter, loops);
return loops;
}
-/// Replace the index operations in the body of the loop nest by the matching
-/// induction variables.
-static void replaceIndexOpsByInductionVariables(LinalgOp linalgOp,
- PatternRewriter &rewriter,
- ArrayRef<Operation *> loopOps) {
- // Extract the induction variables of the loop nest from outer to inner.
- SmallVector<Value> allIvs;
- for (Operation *loopOp : loopOps) {
- llvm::TypeSwitch<Operation *>(loopOp)
- .Case([&](scf::ParallelOp parallelOp) {
- allIvs.append(parallelOp.getInductionVars().begin(),
- parallelOp.getInductionVars().end());
- })
- .Case([&](scf::ForOp forOp) {
- allIvs.push_back(forOp.getInductionVar());
- })
- .Case([&](AffineForOp affineForOp) {
- allIvs.push_back(affineForOp.getInductionVar());
- })
- .Default([&](Operation *op) { assert(false && "unexpected op"); });
- }
- assert(linalgOp.getNumLoops() == allIvs.size() &&
- "expected the number of loops and induction variables to match");
- // Replace the index operations in the body of the innermost loop op.
- if (!loopOps.empty()) {
- LoopLikeOpInterface loopOp = loopOps.back();
- for (IndexOp indexOp :
- llvm::make_early_inc_range(loopOp.getLoopBody().getOps<IndexOp>()))
- rewriter.replaceOp(indexOp, allIvs[indexOp.dim()]);
- }
-}
-
namespace {
template <typename LoopType>
class LinalgRewritePattern : public RewritePattern {
@@ -467,7 +473,7 @@ class LinalgRewritePattern : public RewritePattern {
auto linalgOp = dyn_cast<LinalgOp>(op);
if (!isa<LinalgOp>(op))
return failure();
- if (!linalgLowerOpToLoops<LoopType>(rewriter, linalgOp))
+ if (!linalgOpToLoopsImpl<LoopType>(rewriter, linalgOp))
return failure();
rewriter.eraseOp(op);
return success();
@@ -614,52 +620,22 @@ mlir::createConvertLinalgToAffineLoopsPass() {
return std::make_unique<LowerToAffineLoops>();
}
-/// Emits a loop nest with the proper body for `linalgOp`.
-template <typename LoopTy>
-Optional<LinalgLoops>
-mlir::linalg::linalgLowerOpToLoops(PatternRewriter &rewriter,
- LinalgOp linalgOp) {
- // Convert indexed_generic ops to generic ops before lowering them to loops.
- if (isa<IndexedGenericOp>(linalgOp))
- return llvm::None;
-
- Optional<LinalgLoops> loopOps =
- linalgOpToLoopsImpl<LoopTy>(linalgOp.getOperation(), rewriter);
- if (loopOps.hasValue())
- replaceIndexOpsByInductionVariables(linalgOp, rewriter, loopOps.getValue());
- return loopOps;
-}
-
-template Optional<LinalgLoops>
-mlir::linalg::linalgLowerOpToLoops<AffineForOp>(PatternRewriter &rewriter,
- LinalgOp linalgOp);
-template Optional<LinalgLoops>
-mlir::linalg::linalgLowerOpToLoops<scf::ForOp>(PatternRewriter &rewriter,
- LinalgOp linalgOp);
-template Optional<LinalgLoops>
-mlir::linalg::linalgLowerOpToLoops<scf::ParallelOp>(PatternRewriter &rewriter,
- LinalgOp linalgOp);
-
/// Emits a loop nest of `affine.for` with the proper body for `linalgOp`.
-LogicalResult mlir::linalg::linalgOpToAffineLoops(PatternRewriter &rewriter,
- LinalgOp linalgOp) {
- Optional<LinalgLoops> loops =
- linalgLowerOpToLoops<AffineForOp>(rewriter, linalgOp);
- return loops ? success() : failure();
+Optional<LinalgLoops>
+mlir::linalg::linalgOpToAffineLoops(PatternRewriter &rewriter,
+ LinalgOp linalgOp) {
+ return linalgOpToLoopsImpl<AffineForOp>(rewriter, linalgOp);
}
/// Emits a loop nest of `scf.for` with the proper body for `linalgOp`.
-LogicalResult mlir::linalg::linalgOpToLoops(PatternRewriter &rewriter,
- LinalgOp linalgOp) {
- Optional<LinalgLoops> loops =
- linalgLowerOpToLoops<scf::ForOp>(rewriter, linalgOp);
- return loops ? success() : failure();
+Optional<LinalgLoops> mlir::linalg::linalgOpToLoops(PatternRewriter &rewriter,
+ LinalgOp linalgOp) {
+ return linalgOpToLoopsImpl<scf::ForOp>(rewriter, linalgOp);
}
/// Emits a loop nest of `scf.parallel` with the proper body for `linalgOp`.
-LogicalResult mlir::linalg::linalgOpToParallelLoops(PatternRewriter &rewriter,
- LinalgOp linalgOp) {
- Optional<LinalgLoops> loops =
- linalgLowerOpToLoops<scf::ParallelOp>(rewriter, linalgOp);
- return loops ? success() : failure();
+Optional<LinalgLoops>
+mlir::linalg::linalgOpToParallelLoops(PatternRewriter &rewriter,
+ LinalgOp linalgOp) {
+ return linalgOpToLoopsImpl<scf::ParallelOp>(rewriter, linalgOp);
}
More information about the Mlir-commits mailing list