[Mlir-commits] [mlir] 7f66e18 - [MLIR] Add InferTypeOpInterface to scf.if op

Frederik Gossen llvmlistbot at llvm.org
Thu Jan 19 10:20:12 PST 2023


Author: Frederik Gossen
Date: 2023-01-19T13:19:50-05:00
New Revision: 7f66e1833f62d6f7269adc60ac18bbaa820f64ae

URL: https://github.com/llvm/llvm-project/commit/7f66e1833f62d6f7269adc60ac18bbaa820f64ae
DIFF: https://github.com/llvm/llvm-project/commit/7f66e1833f62d6f7269adc60ac18bbaa820f64ae.diff

LOG: [MLIR] Add InferTypeOpInterface to scf.if op

Differential Revision: https://reviews.llvm.org/D142049

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/SCF/IR/SCF.h
    mlir/include/mlir/Dialect/SCF/IR/SCFOps.td
    mlir/lib/Dialect/SCF/IR/SCF.cpp
    utils/bazel/llvm-project-overlay/mlir/BUILD.bazel

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/SCF/IR/SCF.h b/mlir/include/mlir/Dialect/SCF/IR/SCF.h
index c1e7bc33b4ef4..5453f3862e744 100644
--- a/mlir/include/mlir/Dialect/SCF/IR/SCF.h
+++ b/mlir/include/mlir/Dialect/SCF/IR/SCF.h
@@ -18,6 +18,7 @@
 #include "mlir/IR/BuiltinTypes.h"
 #include "mlir/IR/RegionKindInterface.h"
 #include "mlir/Interfaces/ControlFlowInterfaces.h"
+#include "mlir/Interfaces/InferTypeOpInterface.h"
 #include "mlir/Interfaces/LoopLikeInterface.h"
 #include "mlir/Interfaces/ParallelCombiningOpInterface.h"
 #include "mlir/Interfaces/SideEffectInterfaces.h"

diff  --git a/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td b/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td
index 9e1752b69174e..05adc85434778 100644
--- a/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td
+++ b/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td
@@ -17,6 +17,7 @@ include "mlir/Interfaces/ControlFlowInterfaces.td"
 include "mlir/Interfaces/LoopLikeInterface.td"
 include "mlir/IR/RegionKindInterface.td"
 include "mlir/Dialect/SCF/IR/DeviceMappingInterface.td"
+include "mlir/Interfaces/InferTypeOpInterface.td"
 include "mlir/Interfaces/ParallelCombiningOpInterface.td"
 include "mlir/Interfaces/SideEffectInterfaces.td"
 include "mlir/Interfaces/ViewLikeInterface.td"
@@ -610,12 +611,11 @@ def PerformConcurrentlyOp : SCF_Op<"foreach_thread.perform_concurrently", [
 // IfOp
 //===----------------------------------------------------------------------===//
 
-def IfOp : SCF_Op<"if",
-      [DeclareOpInterfaceMethods<RegionBranchOpInterface,
-                                 ["getNumRegionInvocations",
-                                  "getRegionInvocationBounds"]>,
-       SingleBlockImplicitTerminator<"scf::YieldOp">, RecursiveMemoryEffects,
-       NoRegionArguments]> {
+def IfOp : SCF_Op<"if", [DeclareOpInterfaceMethods<RegionBranchOpInterface, [
+    "getNumRegionInvocations", "getRegionInvocationBounds"]>,
+    DeclareOpInterfaceMethods<InferTypeOpInterface>,
+    SingleBlockImplicitTerminator<"scf::YieldOp">, RecursiveMemoryEffects,
+    NoRegionArguments]> {
   let summary = "if-then-else operation";
   let description = [{
     The `scf.if` operation represents an if-then-else construct for

diff  --git a/mlir/lib/Dialect/SCF/IR/SCF.cpp b/mlir/lib/Dialect/SCF/IR/SCF.cpp
index 8699f1d7b162d..af2adb994145a 100644
--- a/mlir/lib/Dialect/SCF/IR/SCF.cpp
+++ b/mlir/lib/Dialect/SCF/IR/SCF.cpp
@@ -1467,6 +1467,37 @@ bool mlir::scf::insideMutuallyExclusiveBranches(Operation *a, Operation *b) {
   return false;
 }
 
+/// Infers the result types of an `scf.if` from the `scf.yield` that
+/// terminates its then-region: one result per yielded value, with the same
+/// type. Returns failure without inferring anything when there is no
+/// then-region, when it has no block, or when that block does not end in an
+/// `scf.yield`, so callers such as `build` can skip result types instead of
+/// crashing on malformed or partially constructed IR.
+LogicalResult
+IfOp::inferReturnTypes(MLIRContext *ctx, std::optional<Location> loc,
+                       ValueRange operands, DictionaryAttr attrs,
+                       RegionRange regions,
+                       SmallVectorImpl<Type> &inferredReturnTypes) {
+  // The then-region is the first region of the op.
+  if (regions.empty())
+    return failure();
+  Region *thenRegion = regions.front();
+  if (thenRegion->empty())
+    return failure();
+  Block &thenBlock = thenRegion->front();
+  // Inspect the last op directly instead of calling getTerminator(), which
+  // asserts that the block is non-empty and ends in a terminator; the
+  // dyn_cast result must also be checked before it is used.
+  if (thenBlock.empty())
+    return failure();
+  auto yieldOp = llvm::dyn_cast<YieldOp>(&thenBlock.back());
+  if (!yieldOp)
+    return failure();
+  TypeRange types = yieldOp.getOperandTypes();
+  inferredReturnTypes.append(types.begin(), types.end());
+  return success();
+}
+
 void IfOp::build(OpBuilder &builder, OperationState &result, Value cond,
                  bool withElseRegion) {
   build(builder, result, /*resultTypes=*/std::nullopt, cond, withElseRegion);
@@ -1516,19 +1547,24 @@ void IfOp::build(OpBuilder &builder, OperationState &result, Value cond,
   // Build then region.
   OpBuilder::InsertionGuard guard(builder);
   Region *thenRegion = result.addRegion();
-  Block *thenBlock = builder.createBlock(thenRegion);
+  builder.createBlock(thenRegion);
   thenBuilder(builder, result.location);
 
-  // Infer types if there are any.
-  if (auto yieldOp = llvm::dyn_cast<YieldOp>(thenBlock->getTerminator()))
-    result.addTypes(yieldOp.getOperandTypes());
-
   // Build else region.
   Region *elseRegion = result.addRegion();
-  if (!elseBuilder)
-    return;
-  builder.createBlock(elseRegion);
-  elseBuilder(builder, result.location);
+  if (elseBuilder) {
+    builder.createBlock(elseRegion);
+    elseBuilder(builder, result.location);
+  }
+
+  // Infer result types.
+  SmallVector<Type> inferredReturnTypes;
+  MLIRContext *ctx = builder.getContext();
+  auto attrDict = DictionaryAttr::get(ctx, result.attributes);
+  if (succeeded(inferReturnTypes(ctx, std::nullopt, result.operands, attrDict,
+                                 result.regions, inferredReturnTypes))) {
+    result.addTypes(inferredReturnTypes);
+  }
 }
 
 LogicalResult IfOp::verify() {

diff  --git a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
index 31e5bd21537f3..0258832570882 100644
--- a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
+++ b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
@@ -1900,6 +1900,7 @@ td_library(
     includes = ["include"],
     deps = [
         ":ControlFlowInterfacesTdFiles",
+        ":InferTypeOpInterfaceTdFiles",
         ":LoopLikeInterfaceTdFiles",
         ":ParallelCombiningOpInterfaceTdFiles",
         ":SideEffectInterfacesTdFiles",
@@ -2929,6 +2930,7 @@ cc_library(
         ":ControlFlowInterfaces",
         ":FuncDialect",
         ":IR",
+        ":InferTypeOpInterface",
         ":LoopLikeInterface",
         ":MemRefDialect",
         ":ParallelCombiningOpInterface",


        


More information about the Mlir-commits mailing list