[Mlir-commits] [mlir] 05d2297 - [mlir][linalg] Always lower index operations during loop lowering.
Tobias Gysi
llvmlistbot at llvm.org
Tue May 4 07:31:17 PDT 2021
Author: Tobias Gysi
Date: 2021-05-04T14:30:59Z
New Revision: 05d2297b869444465134a17ce625b35f859958d0
URL: https://github.com/llvm/llvm-project/commit/05d2297b869444465134a17ce625b35f859958d0
DIFF: https://github.com/llvm/llvm-project/commit/05d2297b869444465134a17ce625b35f859958d0.diff
LOG: [mlir][linalg] Always lower index operations during loop lowering.
Ensure the index operations are lowered on all linalg loop lowering paths.
Differential Revision: https://reviews.llvm.org/D101827
Added:
Modified:
mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
mlir/lib/Dialect/Linalg/Transforms/Loops.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
index d76ccd91fdbf7..7c7ffb2dde577 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -337,18 +337,21 @@ Optional<LinalgOp> promoteSubViews(OpBuilder &b, LinalgOp op,
LogicalResult vectorizeLinalgOp(OpBuilder &builder, Operation *op,
SmallVectorImpl<Value> &newResults);
-/// Emits a loop nest of `LoopTy` with the proper body for `op`.
+/// Emits a loop nest of `LoopTy` with the proper body for `linalgOp`.
template <typename LoopTy>
-Optional<LinalgLoops> linalgLowerOpToLoops(OpBuilder &builder, Operation *op);
+Optional<LinalgLoops> linalgLowerOpToLoops(PatternRewriter &rewriter,
+ LinalgOp linalgOp);
-/// Emits a loop nest of `scf.for` with the proper body for `op`.
-LogicalResult linalgOpToLoops(OpBuilder &builder, Operation *op);
+/// Emits a loop nest of `scf.for` with the proper body for `linalgOp`.
+LogicalResult linalgOpToLoops(PatternRewriter &rewriter, LinalgOp linalgOp);
-/// Emits a loop nest of `scf.parallel` with the proper body for `op`.
-LogicalResult linalgOpToParallelLoops(OpBuilder &builder, Operation *op);
+/// Emits a loop nest of `scf.parallel` with the proper body for `linalgOp`.
+LogicalResult linalgOpToParallelLoops(PatternRewriter &rewriter,
+ LinalgOp linalgOp);
-/// Emits a loop nest of `affine.for` with the proper body for `op`.
-LogicalResult linalgOpToAffineLoops(OpBuilder &builder, Operation *op);
+/// Emits a loop nest of `affine.for` with the proper body for `linalgOp`.
+LogicalResult linalgOpToAffineLoops(PatternRewriter &rewriter,
+ LinalgOp linalgOp);
//===----------------------------------------------------------------------===//
// Preconditions that ensure the corresponding transformation succeeds and can
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Loops.cpp b/mlir/lib/Dialect/Linalg/Transforms/Loops.cpp
index aa0297fdab7d1..5d5dc2ac071bb 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Loops.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Loops.cpp
@@ -457,18 +457,17 @@ static void emitScalarImplementation(ArrayRef<Value> allIvs,
}
template <typename LoopTy>
-static Optional<LinalgLoops> linalgOpToLoopsImpl(Operation *op,
+static Optional<LinalgLoops> linalgOpToLoopsImpl(LinalgOp linalgOp,
OpBuilder &builder) {
using IndexedValueTy = typename GenerateLoopNest<LoopTy>::IndexedValueTy;
- ScopedContext scope(builder, op->getLoc());
+ ScopedContext scope(builder, linalgOp.getLoc());
// The flattened loopToOperandRangesMaps is expected to be an invertible
// permutation map (which is asserted in the inverse calculation).
- auto linalgOp = cast<LinalgOp>(op);
assert(linalgOp.hasBufferSemantics() &&
"expected linalg op with buffer semantics");
- auto loopRanges = linalgOp.createLoopRanges(builder, op->getLoc());
+ auto loopRanges = linalgOp.createLoopRanges(builder, linalgOp.getLoc());
auto iteratorTypes = llvm::to_vector<4>(linalgOp.iterator_types().getValue());
SmallVector<Value, 4> allIvs;
@@ -477,7 +476,7 @@ static Optional<LinalgLoops> linalgOpToLoopsImpl(Operation *op,
[&](ValueRange ivs, ValueRange iterArgs) -> scf::ValueVector {
assert(iterArgs.empty() && "unexpected iterArgs");
allIvs.append(ivs.begin(), ivs.end());
- llvm::TypeSwitch<Operation *>(op)
+ llvm::TypeSwitch<Operation *>(linalgOp)
.Case<ConvOp, PoolingMaxOp, PoolingMinOp, PoolingSumOp,
IndexedGenericOp, LinalgOp>([&](auto op) {
emitScalarImplementation<IndexedValueTy>(allIvs, op);
@@ -546,10 +545,8 @@ class LinalgRewritePattern : public RewritePattern {
auto linalgOp = dyn_cast<LinalgOp>(op);
if (!isa<LinalgOp>(op))
return failure();
- Optional<LinalgLoops> loopOps = linalgOpToLoopsImpl<LoopType>(op, rewriter);
- if (!loopOps.hasValue())
+ if (!linalgLowerOpToLoops<LoopType>(rewriter, linalgOp))
return failure();
- replaceIndexOpsByInductionVariables(linalgOp, rewriter, loopOps.getValue());
rewriter.eraseOp(op);
return success();
}
@@ -695,40 +692,48 @@ mlir::createConvertLinalgToAffineLoopsPass() {
return std::make_unique<LowerToAffineLoops>();
}
-/// Emits a loop nest with the proper body for `op`.
+/// Emits a loop nest with the proper body for `linalgOp`.
template <typename LoopTy>
-Optional<LinalgLoops> mlir::linalg::linalgLowerOpToLoops(OpBuilder &builder,
- Operation *op) {
- return linalgOpToLoopsImpl<LoopTy>(op, builder);
+Optional<LinalgLoops>
+mlir::linalg::linalgLowerOpToLoops(PatternRewriter &rewriter,
+ LinalgOp linalgOp) {
+ Optional<LinalgLoops> loopOps =
+ linalgOpToLoopsImpl<LoopTy>(linalgOp.getOperation(), rewriter);
+ if (loopOps.hasValue())
+ replaceIndexOpsByInductionVariables(linalgOp, rewriter, loopOps.getValue());
+ return loopOps;
}
template Optional<LinalgLoops>
-mlir::linalg::linalgLowerOpToLoops<AffineForOp>(OpBuilder &builder,
- Operation *op);
+mlir::linalg::linalgLowerOpToLoops<AffineForOp>(PatternRewriter &rewriter,
+ LinalgOp linalgOp);
template Optional<LinalgLoops>
-mlir::linalg::linalgLowerOpToLoops<scf::ForOp>(OpBuilder &builder,
- Operation *op);
+mlir::linalg::linalgLowerOpToLoops<scf::ForOp>(PatternRewriter &rewriter,
+ LinalgOp linalgOp);
template Optional<LinalgLoops>
-mlir::linalg::linalgLowerOpToLoops<scf::ParallelOp>(OpBuilder &builder,
- Operation *op);
+mlir::linalg::linalgLowerOpToLoops<scf::ParallelOp>(PatternRewriter &rewriter,
+ LinalgOp linalgOp);
-/// Emits a loop nest of `affine.for` with the proper body for `op`.
-LogicalResult mlir::linalg::linalgOpToAffineLoops(OpBuilder &builder,
- Operation *op) {
- Optional<LinalgLoops> loops = linalgLowerOpToLoops<AffineForOp>(builder, op);
+/// Emits a loop nest of `affine.for` with the proper body for `linalgOp`.
+LogicalResult mlir::linalg::linalgOpToAffineLoops(PatternRewriter &rewriter,
+ LinalgOp linalgOp) {
+ Optional<LinalgLoops> loops =
+ linalgLowerOpToLoops<AffineForOp>(rewriter, linalgOp);
return loops ? success() : failure();
}
-/// Emits a loop nest of `scf.for` with the proper body for `op`.
-LogicalResult mlir::linalg::linalgOpToLoops(OpBuilder &builder, Operation *op) {
- Optional<LinalgLoops> loops = linalgLowerOpToLoops<scf::ForOp>(builder, op);
+/// Emits a loop nest of `scf.for` with the proper body for `linalgOp`.
+LogicalResult mlir::linalg::linalgOpToLoops(PatternRewriter &rewriter,
+ LinalgOp linalgOp) {
+ Optional<LinalgLoops> loops =
+ linalgLowerOpToLoops<scf::ForOp>(rewriter, linalgOp);
return loops ? success() : failure();
}
-/// Emits a loop nest of `scf.parallel` with the proper body for `op`.
-LogicalResult mlir::linalg::linalgOpToParallelLoops(OpBuilder &builder,
- Operation *op) {
+/// Emits a loop nest of `scf.parallel` with the proper body for `linalgOp`.
+LogicalResult mlir::linalg::linalgOpToParallelLoops(PatternRewriter &rewriter,
+ LinalgOp linalgOp) {
Optional<LinalgLoops> loops =
- linalgLowerOpToLoops<scf::ParallelOp>(builder, op);
+ linalgLowerOpToLoops<scf::ParallelOp>(rewriter, linalgOp);
return loops ? success() : failure();
}
More information about the Mlir-commits mailing list