[Mlir-commits] [mlir] 1125c5c - [MLIR] Remove scf.if builder with explicit result types and callbacks
Frederik Gossen
llvmlistbot at llvm.org
Fri Jan 20 07:52:23 PST 2023
Author: Frederik Gossen
Date: 2023-01-20T10:52:08-05:00
New Revision: 1125c5c0b2cf13aa112a7531eb89fd1b771aa13b
URL: https://github.com/llvm/llvm-project/commit/1125c5c0b2cf13aa112a7531eb89fd1b771aa13b
DIFF: https://github.com/llvm/llvm-project/commit/1125c5c0b2cf13aa112a7531eb89fd1b771aa13b.diff
LOG: [MLIR] Remove scf.if builder with explicit result types and callbacks
Instead, use the callback builder that takes only the condition and infers the result types from the inner `yield` ops.
Also, fix call sites whose callbacks did not create the terminator that the callback builders require.
Differential Revision: https://reviews.llvm.org/D142056
Added:
Modified:
mlir/include/mlir/Dialect/SCF/IR/SCFOps.td
mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp
mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp
mlir/lib/Dialect/Async/Transforms/AsyncParallelFor.cpp
mlir/lib/Dialect/SCF/IR/SCF.cpp
mlir/lib/Dialect/Tensor/IR/TensorTilingInterfaceImpl.cpp
mlir/lib/Dialect/Tensor/Transforms/SplitPaddingPatterns.cpp
mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
mlir/lib/Dialect/Vector/Transforms/VectorTransferSplitRewritePatterns.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td b/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td
index 05adc85434778..b3b88250c3247 100644
--- a/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td
+++ b/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td
@@ -667,21 +667,15 @@ def IfOp : SCF_Op<"if", [DeclareOpInterfaceMethods<RegionBranchOpInterface, [
let skipDefaultBuilders = 1;
let builders = [
+ OpBuilder<(ins "TypeRange":$resultTypes, "Value":$cond)>,
OpBuilder<(ins "Value":$cond, "bool":$withElseRegion)>,
OpBuilder<(ins "TypeRange":$resultTypes, "Value":$cond,
"bool":$withElseRegion)>,
- // TODO: Remove builder when it is no longer used to create invalid `if` ops
- // (with a type mispatch between the op and it's inner `yield` op).
- OpBuilder<(ins "TypeRange":$resultTypes, "Value":$cond,
- CArg<"function_ref<void(OpBuilder &, Location)>",
- "buildTerminatedBody">:$thenBuilder,
- CArg<"function_ref<void(OpBuilder &, Location)>",
- "nullptr">:$elseBuilder)>,
OpBuilder<(ins "Value":$cond,
CArg<"function_ref<void(OpBuilder &, Location)>",
"buildTerminatedBody">:$thenBuilder,
CArg<"function_ref<void(OpBuilder &, Location)>",
- "nullptr">:$elseBuilder)>
+ "nullptr">:$elseBuilder)>,
];
let extraClassDeclaration = [{
diff --git a/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp b/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp
index 4e746ea561cb8..16cbfca3f3e2a 100644
--- a/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp
+++ b/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp
@@ -92,7 +92,7 @@ Value getBroadcastedDim(ImplicitLocOpBuilder lb, ValueRange extentTensors,
Type indexTy = lb.getIndexType();
broadcastedDim =
lb.create<IfOp>(
- TypeRange{indexTy}, outOfBounds,
+ outOfBounds,
[&](OpBuilder &b, Location loc) {
b.create<scf::YieldOp>(loc, broadcastedDim);
},
@@ -293,7 +293,7 @@ LogicalResult IsBroadcastableOpConverter::matchAndRewrite(
loc, arith::CmpIPredicate::ult, iv, rankDiff);
broadcastable =
b.create<IfOp>(
- loc, TypeRange{i1Ty}, outOfBounds,
+ loc, outOfBounds,
[&](OpBuilder &b, Location loc) {
// Non existent dimensions are always broadcastable
b.create<scf::YieldOp>(loc, broadcastable);
@@ -522,7 +522,7 @@ ShapeEqOpConverter::matchAndRewrite(ShapeEqOp op, OpAdaptor adaptor,
Value eqRank = rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq,
firstRank, rank);
auto same = rewriter.create<IfOp>(
- loc, i1Ty, eqRank,
+ loc, eqRank,
[&](OpBuilder &b, Location loc) {
Value one = b.create<arith::ConstantIndexOp>(loc, 1);
Value init =
diff --git a/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp b/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp
index dfaa8b35671e6..401e227629540 100644
--- a/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp
+++ b/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp
@@ -192,7 +192,7 @@ static Value generateInBoundsCheck(
// If the condition is non-empty, generate an SCF::IfOp.
if (cond) {
auto check = lb.create<scf::IfOp>(
- resultTypes, cond,
+ cond,
/*thenBuilder=*/
[&](OpBuilder &b, Location loc) {
maybeYieldValue(b, loc, hasRetVal, inBoundsCase(b, loc));
diff --git a/mlir/lib/Dialect/Async/Transforms/AsyncParallelFor.cpp b/mlir/lib/Dialect/Async/Transforms/AsyncParallelFor.cpp
index 3cd4677c2cb8b..880a8ca3663a3 100644
--- a/mlir/lib/Dialect/Async/Transforms/AsyncParallelFor.cpp
+++ b/mlir/lib/Dialect/Async/Transforms/AsyncParallelFor.cpp
@@ -645,7 +645,7 @@ static void doAsyncDispatch(ImplicitLocOpBuilder &b, PatternRewriter &rewriter,
};
// Dispatch either single block compute function, or launch async dispatch.
- b.create<scf::IfOp>(TypeRange(), isSingleBlock, syncDispatch, asyncDispatch);
+ b.create<scf::IfOp>(isSingleBlock, syncDispatch, asyncDispatch);
}
// Dispatch parallel compute functions by submitting all async compute tasks
@@ -910,8 +910,8 @@ AsyncParallelForRewrite::matchAndRewrite(scf::ParallelOp op,
Value useBlockAlignedComputeFn = b.create<arith::CmpIOp>(
arith::CmpIPredicate::sge, blockSize, numIters);
- b.create<scf::IfOp>(TypeRange(), useBlockAlignedComputeFn,
- dispatchBlockAligned, dispatchDefault);
+ b.create<scf::IfOp>(useBlockAlignedComputeFn, dispatchBlockAligned,
+ dispatchDefault);
b.create<scf::YieldOp>();
} else {
dispatchDefault(b, loc);
@@ -919,7 +919,7 @@ AsyncParallelForRewrite::matchAndRewrite(scf::ParallelOp op,
};
// Replace the `scf.parallel` operation with the parallel compute function.
- b.create<scf::IfOp>(TypeRange(), isZeroIterations, noOp, dispatch);
+ b.create<scf::IfOp>(isZeroIterations, noOp, dispatch);
// Parallel operation was replaced with a block iteration loop.
rewriter.eraseOp(op);
diff --git a/mlir/lib/Dialect/SCF/IR/SCF.cpp b/mlir/lib/Dialect/SCF/IR/SCF.cpp
index b870330250418..15e1a68d50ab4 100644
--- a/mlir/lib/Dialect/SCF/IR/SCF.cpp
+++ b/mlir/lib/Dialect/SCF/IR/SCF.cpp
@@ -1485,44 +1485,41 @@ IfOp::inferReturnTypes(MLIRContext *ctx, std::optional<Location> loc,
return success();
}
+void IfOp::build(OpBuilder &builder, OperationState &result,
+ TypeRange resultTypes, Value cond) {
+ result.addTypes(resultTypes);
+ result.addOperands(cond);
+
+ // Build regions.
+ OpBuilder::InsertionGuard guard(builder);
+ result.addRegion();
+ result.addRegion();
+}
+
void IfOp::build(OpBuilder &builder, OperationState &result, Value cond,
bool withElseRegion) {
- build(builder, result, /*resultTypes=*/std::nullopt, cond, withElseRegion);
+ build(builder, result, TypeRange{}, cond, withElseRegion);
}
void IfOp::build(OpBuilder &builder, OperationState &result,
TypeRange resultTypes, Value cond, bool withElseRegion) {
- auto addTerminator = [&](OpBuilder &nested, Location loc) {
- if (resultTypes.empty())
- IfOp::ensureTerminator(*nested.getInsertionBlock()->getParent(), nested,
- loc);
- };
-
- build(builder, result, resultTypes, cond, addTerminator,
- withElseRegion ? addTerminator
- : function_ref<void(OpBuilder &, Location)>());
-}
-
-void IfOp::build(OpBuilder &builder, OperationState &result,
- TypeRange resultTypes, Value cond,
- function_ref<void(OpBuilder &, Location)> thenBuilder,
- function_ref<void(OpBuilder &, Location)> elseBuilder) {
- assert(thenBuilder && "the builder callback for 'then' must be present");
- result.addOperands(cond);
result.addTypes(resultTypes);
+ result.addOperands(cond);
// Build then region.
OpBuilder::InsertionGuard guard(builder);
Region *thenRegion = result.addRegion();
builder.createBlock(thenRegion);
- thenBuilder(builder, result.location);
+ if (resultTypes.empty())
+ IfOp::ensureTerminator(*thenRegion, builder, result.location);
// Build else region.
Region *elseRegion = result.addRegion();
- if (!elseBuilder)
- return;
- builder.createBlock(elseRegion);
- elseBuilder(builder, result.location);
+ if (withElseRegion) {
+ builder.createBlock(elseRegion);
+ if (resultTypes.empty())
+ IfOp::ensureTerminator(*elseRegion, builder, result.location);
+ }
}
void IfOp::build(OpBuilder &builder, OperationState &result, Value cond,
@@ -1730,9 +1727,10 @@ struct RemoveUnusedResults : public OpRewritePattern<IfOp> {
[](OpResult result) { return result.getType(); });
// Create a replacement operation with empty then and else regions.
- auto emptyBuilder = [](OpBuilder &, Location) {};
- auto newOp = rewriter.create<IfOp>(op.getLoc(), newTypes, op.getCondition(),
- emptyBuilder, emptyBuilder);
+ auto newOp =
+ rewriter.create<IfOp>(op.getLoc(), newTypes, op.getCondition());
+ rewriter.createBlock(&newOp.getThenRegion());
+ rewriter.createBlock(&newOp.getElseRegion());
// Move the bodies and replace the terminators (note there is a then and
// an else region since the operation returns results).
@@ -1796,7 +1794,8 @@ struct ConvertTrivialIfToSelect : public OpRewritePattern<IfOp> {
if (nonHoistable.size() == op->getNumResults())
return failure();
- IfOp replacement = rewriter.create<IfOp>(op.getLoc(), nonHoistable, cond);
+ IfOp replacement = rewriter.create<IfOp>(op.getLoc(), nonHoistable, cond,
+ /*withElseRegion=*/false);
if (replacement.thenBlock())
rewriter.eraseBlock(replacement.thenBlock());
replacement.getThenRegion().takeBody(op.getThenRegion());
@@ -2249,6 +2248,7 @@ struct CombineNestedIfs : public OpRewritePattern<IfOp> {
Value newCondition = rewriter.create<arith::AndIOp>(
loc, op.getCondition(), nestedIf.getCondition());
auto newIf = rewriter.create<IfOp>(loc, op.getResultTypes(), newCondition);
+ Block *newIfBlock = rewriter.createBlock(&newIf.getThenRegion());
SmallVector<Value> results;
llvm::append_range(results, newIf.getResults());
@@ -2258,11 +2258,6 @@ struct CombineNestedIfs : public OpRewritePattern<IfOp> {
results[idx] = rewriter.create<arith::SelectOp>(
op.getLoc(), op.getCondition(), thenYield[idx], elseYield[idx]);
- Block *newIfBlock = newIf.thenBlock();
- if (newIfBlock)
- rewriter.eraseOp(newIfBlock->getTerminator());
- else
- newIfBlock = rewriter.createBlock(&newIf.getThenRegion());
rewriter.mergeBlocks(nestedIf.thenBlock(), newIfBlock);
rewriter.setInsertionPointToEnd(newIf.thenBlock());
rewriter.replaceOpWithNewOp<YieldOp>(newIf.thenYield(), thenYield);
diff --git a/mlir/lib/Dialect/Tensor/IR/TensorTilingInterfaceImpl.cpp b/mlir/lib/Dialect/Tensor/IR/TensorTilingInterfaceImpl.cpp
index 3fc470e1874ad..0ab41b6343f67 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorTilingInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorTilingInterfaceImpl.cpp
@@ -632,7 +632,7 @@ Operation *tensor::bubbleUpPadSlice(OpBuilder &b, tensor::PadOp padOp,
// creating SliceOps with result dimensions of size 0 at runtime.
if (generateZeroSliceGuard && dynHasZeroLenCond) {
auto result = b.create<scf::IfOp>(
- loc, resultType, dynHasZeroLenCond,
+ loc, dynHasZeroLenCond,
/*thenBuilder=*/
[&](OpBuilder &b, Location loc) {
b.create<scf::YieldOp>(loc, createGenerateOp()->getResult(0));
diff --git a/mlir/lib/Dialect/Tensor/Transforms/SplitPaddingPatterns.cpp b/mlir/lib/Dialect/Tensor/Transforms/SplitPaddingPatterns.cpp
index 662ba6c09987d..9536f3233b814 100644
--- a/mlir/lib/Dialect/Tensor/Transforms/SplitPaddingPatterns.cpp
+++ b/mlir/lib/Dialect/Tensor/Transforms/SplitPaddingPatterns.cpp
@@ -81,8 +81,8 @@ struct SplitPadding final : public OpRewritePattern<tensor::PadOp> {
Operation *newOp = builder.clone(*padOp);
builder.create<scf::YieldOp>(loc, newOp->getResults());
};
- rewriter.replaceOpWithNewOp<scf::IfOp>(padOp, padOp.getType(), ifCond,
- thenBuilder, elseBuilder);
+ rewriter.replaceOpWithNewOp<scf::IfOp>(padOp, ifCond, thenBuilder,
+ elseBuilder);
return success();
}
};
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
index c8b0fc48da06c..48995afa6876d 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
@@ -1126,7 +1126,7 @@ struct WarpOpInsertElement : public OpRewritePattern<WarpExecuteOnLane0Op> {
Value newResult =
rewriter
.create<scf::IfOp>(
- loc, distrType, isInsertingLane,
+ loc, isInsertingLane,
/*thenBuilder=*/
[&](OpBuilder &builder, Location loc) {
Value newInsert = builder.create<vector::InsertElementOp>(
@@ -1257,7 +1257,7 @@ struct WarpOpInsert : public OpRewritePattern<WarpExecuteOnLane0Op> {
builder.create<scf::YieldOp>(loc, distributedDest);
};
newResult = rewriter
- .create<scf::IfOp>(loc, distrDestType, isInsertingLane,
+ .create<scf::IfOp>(loc, isInsertingLane,
/*thenBuilder=*/insertingBuilder,
/*elseBuilder=*/nonInsertingBuilder)
.getResult(0);
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransferSplitRewritePatterns.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransferSplitRewritePatterns.cpp
index c4aad0f5cc97e..ee23b5494f707 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorTransferSplitRewritePatterns.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransferSplitRewritePatterns.cpp
@@ -252,7 +252,7 @@ createFullPartialLinalgCopy(RewriterBase &b, vector::TransferReadOp xferOp,
Value zero = b.create<arith::ConstantIndexOp>(loc, 0);
Value memref = xferOp.getSource();
return b.create<scf::IfOp>(
- loc, returnTypes, inBoundsCond,
+ loc, inBoundsCond,
[&](OpBuilder &b, Location loc) {
Value res = memref;
if (compatibleMemRefType != xferOp.getShapedType())
@@ -307,7 +307,7 @@ static scf::IfOp createFullPartialVectorTransferRead(
Value zero = b.create<arith::ConstantIndexOp>(loc, 0);
Value memref = xferOp.getSource();
return b.create<scf::IfOp>(
- loc, returnTypes, inBoundsCond,
+ loc, inBoundsCond,
[&](OpBuilder &b, Location loc) {
Value res = memref;
if (compatibleMemRefType != xferOp.getShapedType())
@@ -358,7 +358,7 @@ getLocationToWriteFullVec(RewriterBase &b, vector::TransferWriteOp xferOp,
Value memref = xferOp.getSource();
return b
.create<scf::IfOp>(
- loc, returnTypes, inBoundsCond,
+ loc, inBoundsCond,
[&](OpBuilder &b, Location loc) {
Value res = memref;
if (compatibleMemRefType != xferOp.getShapedType())
More information about the Mlir-commits mailing list