[Mlir-commits] [mlir] 7f66e18 - [MLIR] Add InferTypeOpInterface to scf.if op
Frederik Gossen
llvmlistbot at llvm.org
Thu Jan 19 10:20:12 PST 2023
Author: Frederik Gossen
Date: 2023-01-19T13:19:50-05:00
New Revision: 7f66e1833f62d6f7269adc60ac18bbaa820f64ae
URL: https://github.com/llvm/llvm-project/commit/7f66e1833f62d6f7269adc60ac18bbaa820f64ae
DIFF: https://github.com/llvm/llvm-project/commit/7f66e1833f62d6f7269adc60ac18bbaa820f64ae.diff
LOG: [MLIR] Add InferTypeOpInterface to scf.if op
Differential Revision: https://reviews.llvm.org/D142049
Added:
Modified:
mlir/include/mlir/Dialect/SCF/IR/SCF.h
mlir/include/mlir/Dialect/SCF/IR/SCFOps.td
mlir/lib/Dialect/SCF/IR/SCF.cpp
utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/SCF/IR/SCF.h b/mlir/include/mlir/Dialect/SCF/IR/SCF.h
index c1e7bc33b4ef4..5453f3862e744 100644
--- a/mlir/include/mlir/Dialect/SCF/IR/SCF.h
+++ b/mlir/include/mlir/Dialect/SCF/IR/SCF.h
@@ -18,6 +18,7 @@
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/RegionKindInterface.h"
#include "mlir/Interfaces/ControlFlowInterfaces.h"
+#include "mlir/Interfaces/InferTypeOpInterface.h"
#include "mlir/Interfaces/LoopLikeInterface.h"
#include "mlir/Interfaces/ParallelCombiningOpInterface.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"
diff --git a/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td b/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td
index 9e1752b69174e..05adc85434778 100644
--- a/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td
+++ b/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td
@@ -17,6 +17,7 @@ include "mlir/Interfaces/ControlFlowInterfaces.td"
include "mlir/Interfaces/LoopLikeInterface.td"
include "mlir/IR/RegionKindInterface.td"
include "mlir/Dialect/SCF/IR/DeviceMappingInterface.td"
+include "mlir/Interfaces/InferTypeOpInterface.td"
include "mlir/Interfaces/ParallelCombiningOpInterface.td"
include "mlir/Interfaces/SideEffectInterfaces.td"
include "mlir/Interfaces/ViewLikeInterface.td"
@@ -610,12 +611,11 @@ def PerformConcurrentlyOp : SCF_Op<"foreach_thread.perform_concurrently", [
// IfOp
//===----------------------------------------------------------------------===//
-def IfOp : SCF_Op<"if",
- [DeclareOpInterfaceMethods<RegionBranchOpInterface,
- ["getNumRegionInvocations",
- "getRegionInvocationBounds"]>,
- SingleBlockImplicitTerminator<"scf::YieldOp">, RecursiveMemoryEffects,
- NoRegionArguments]> {
+def IfOp : SCF_Op<"if", [DeclareOpInterfaceMethods<RegionBranchOpInterface, [
+ "getNumRegionInvocations", "getRegionInvocationBounds"]>,
+ DeclareOpInterfaceMethods<InferTypeOpInterface>,
+ SingleBlockImplicitTerminator<"scf::YieldOp">, RecursiveMemoryEffects,
+ NoRegionArguments]> {
let summary = "if-then-else operation";
let description = [{
The `scf.if` operation represents an if-then-else construct for
diff --git a/mlir/lib/Dialect/SCF/IR/SCF.cpp b/mlir/lib/Dialect/SCF/IR/SCF.cpp
index 8699f1d7b162d..af2adb994145a 100644
--- a/mlir/lib/Dialect/SCF/IR/SCF.cpp
+++ b/mlir/lib/Dialect/SCF/IR/SCF.cpp
@@ -1467,6 +1467,36 @@ bool mlir::scf::insideMutuallyExclusiveBranches(Operation *a, Operation *b) {
   return false;
 }
 
+/// Infers the result types of an `scf.if` from the `scf.yield` terminating
+/// its first ("then") region: the op produces exactly the yielded types.
+/// Returns failure, instead of asserting or dereferencing a null op, when no
+/// regions are given, the "then" region has no block, or that block does not
+/// end in an `scf.yield`. `IfOp::build` relies on this to add no result types
+/// when a then-builder callback did not create a yield.
+LogicalResult
+IfOp::inferReturnTypes(MLIRContext *ctx, std::optional<Location> loc,
+                       ValueRange operands, DictionaryAttr attrs,
+                       RegionRange regions,
+                       SmallVectorImpl<Type> &inferredReturnTypes) {
+  if (regions.empty())
+    return failure();
+  Region *r = regions.front();
+  // An `assert` here would vanish in release builds, leaving `front()` on an
+  // empty region as undefined behavior; report failure instead.
+  if (r->empty())
+    return failure();
+  Block &b = r->front();
+  // `Block::getTerminator()` asserts on an empty or unterminated block, and
+  // `dyn_cast` yields a null op for a non-yield terminator: reject both.
+  if (b.empty())
+    return failure();
+  auto yieldOp = llvm::dyn_cast<YieldOp>(&b.back());
+  if (!yieldOp)
+    return failure();
+  llvm::append_range(inferredReturnTypes, yieldOp.getOperandTypes());
+  return success();
+}
+
 void IfOp::build(OpBuilder &builder, OperationState &result, Value cond,
                  bool withElseRegion) {
   build(builder, result, /*resultTypes=*/std::nullopt, cond, withElseRegion);
@@ -1516,19 +1546,24 @@ void IfOp::build(OpBuilder &builder, OperationState &result, Value cond,
   // Build then region.
   OpBuilder::InsertionGuard guard(builder);
   Region *thenRegion = result.addRegion();
-  Block *thenBlock = builder.createBlock(thenRegion);
+  builder.createBlock(thenRegion);
   thenBuilder(builder, result.location);
 
-  // Infer types if there are any.
-  if (auto yieldOp = llvm::dyn_cast<YieldOp>(thenBlock->getTerminator()))
-    result.addTypes(yieldOp.getOperandTypes());
-
   // Build else region.
   Region *elseRegion = result.addRegion();
-  if (!elseBuilder)
-    return;
-  builder.createBlock(elseRegion);
-  elseBuilder(builder, result.location);
+  if (elseBuilder) {
+    builder.createBlock(elseRegion);
+    elseBuilder(builder, result.location);
+  }
+
+  // Infer result types.
+  SmallVector<Type> inferredReturnTypes;
+  MLIRContext *ctx = builder.getContext();
+  auto attrDict = DictionaryAttr::get(ctx, result.attributes);
+  if (succeeded(inferReturnTypes(ctx, std::nullopt, result.operands, attrDict,
+                                 result.regions, inferredReturnTypes))) {
+    result.addTypes(inferredReturnTypes);
+  }
 }
 
 LogicalResult IfOp::verify() {
diff --git a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
index 31e5bd21537f3..0258832570882 100644
--- a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
+++ b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
@@ -1900,6 +1900,7 @@ td_library(
includes = ["include"],
deps = [
":ControlFlowInterfacesTdFiles",
+ ":InferTypeOpInterfaceTdFiles",
":LoopLikeInterfaceTdFiles",
":ParallelCombiningOpInterfaceTdFiles",
":SideEffectInterfacesTdFiles",
@@ -2929,6 +2930,7 @@ cc_library(
":ControlFlowInterfaces",
":FuncDialect",
":IR",
+ ":InferTypeOpInterface",
":LoopLikeInterface",
":MemRefDialect",
":ParallelCombiningOpInterface",
More information about the Mlir-commits
mailing list