[Mlir-commits] [mlir] [MLIR] Introduce a SelectOpInterface (PR #104751)
Christian Ulmann
llvmlistbot at llvm.org
Mon Aug 19 03:42:44 PDT 2024
https://github.com/Dinistro updated https://github.com/llvm/llvm-project/pull/104751
>From b1d6cfdcf34dbb502c59e21646ea3aa32c2da7d6 Mon Sep 17 00:00:00 2001
From: Christian Ulmann <christian.ulmann at nextsilicon.com>
Date: Mon, 19 Aug 2024 09:14:34 +0000
Subject: [PATCH] [MLIR] Introduce a SelectOpInterface
This commit introduces a `SelectOpInterface` that can be used to handle
select-like operations generically. Select operations are similar to
control flow operations, as they forward operands depending on
conditions. This is why the interface was added to the existing control
flow interfaces.
---
mlir/include/mlir/Dialect/Arith/IR/Arith.h | 1 +
.../include/mlir/Dialect/Arith/IR/ArithOps.td | 2 +
mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td | 3 +-
.../mlir/Interfaces/ControlFlowInterfaces.td | 21 ++++++++
mlir/lib/Analysis/SliceWalk.cpp | 6 ++-
.../Transforms/InlinerInterfaceImpl.cpp | 5 --
.../Dialect/LLVMIR/inlining-alias-scopes.mlir | 48 +++++++++++++++++++
7 files changed, 78 insertions(+), 8 deletions(-)
diff --git a/mlir/include/mlir/Dialect/Arith/IR/Arith.h b/mlir/include/mlir/Dialect/Arith/IR/Arith.h
index 00cdb13feb29bb..77241319851e6c 100644
--- a/mlir/include/mlir/Dialect/Arith/IR/Arith.h
+++ b/mlir/include/mlir/Dialect/Arith/IR/Arith.h
@@ -14,6 +14,7 @@
#include "mlir/IR/OpDefinition.h"
#include "mlir/IR/OpImplementation.h"
#include "mlir/Interfaces/CastInterfaces.h"
+#include "mlir/Interfaces/ControlFlowInterfaces.h"
#include "mlir/Interfaces/InferIntRangeInterface.h"
#include "mlir/Interfaces/InferTypeOpInterface.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"
diff --git a/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td b/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td
index 477478a4651cee..cddb3722c3ccff 100644
--- a/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td
+++ b/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td
@@ -12,6 +12,7 @@
include "mlir/Dialect/Arith/IR/ArithBase.td"
include "mlir/Dialect/Arith/IR/ArithOpsInterfaces.td"
include "mlir/Interfaces/CastInterfaces.td"
+include "mlir/Interfaces/ControlFlowInterfaces.td"
include "mlir/Interfaces/InferIntRangeInterface.td"
include "mlir/Interfaces/InferTypeOpInterface.td"
include "mlir/Interfaces/SideEffectInterfaces.td"
@@ -1578,6 +1579,7 @@ def SelectOp : Arith_Op<"select", [Pure,
AllTypesMatch<["true_value", "false_value", "result"]>,
BooleanConditionOrMatchingShape<"condition", "result">,
DeclareOpInterfaceMethods<InferIntRangeInterface, ["inferResultRangesFromOptional"]>,
+ DeclareOpInterfaceMethods<SelectOpInterface>,
] # ElementwiseMappable.traits> {
let summary = "select operation";
let description = [{
diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
index 643522d5903fd0..6230f4d32994e5 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
@@ -835,7 +835,8 @@ def LLVM_ShuffleVectorOp : LLVM_Op<"shufflevector",
def LLVM_SelectOp
: LLVM_Op<"select",
[Pure, AllTypesMatch<["trueValue", "falseValue", "res"]>,
- DeclareOpInterfaceMethods<FastmathFlagsInterface>]>,
+ DeclareOpInterfaceMethods<FastmathFlagsInterface>,
+ DeclareOpInterfaceMethods<SelectOpInterface>]>,
LLVM_Builder<
"$res = builder.CreateSelect($condition, $trueValue, $falseValue);"> {
let arguments = (ins LLVM_ScalarOrVectorOf<I1>:$condition,
diff --git a/mlir/include/mlir/Interfaces/ControlFlowInterfaces.td b/mlir/include/mlir/Interfaces/ControlFlowInterfaces.td
index 95ac5dea243aa4..7b6191c2332756 100644
--- a/mlir/include/mlir/Interfaces/ControlFlowInterfaces.td
+++ b/mlir/include/mlir/Interfaces/ControlFlowInterfaces.td
@@ -343,6 +343,27 @@ def RegionBranchTerminatorOpInterface :
}];
}
+// Interface for select-like operations: operations that forward one of two
+// operands (the "true" or the "false" value) to their result depending on a
+// condition operand. Generic analyses can use it to look through such
+// operations without knowing the concrete dialect, e.g. by treating both
+// forwarded operands as possible sources of the result.
+//
+// NOTE(review): the implementing ops in this patch (arith.select,
+// llvm.select) also accept shaped/vector conditions, in which case the
+// choice is presumably made per element and the result may mix elements of
+// both operands. Consider stating this in `description` as well.
+def SelectOpInterface : OpInterface<"SelectOpInterface"> {
+  let description = [{
+    This interface provides information for select-like operations, i.e.,
+    operations that forward specific operands to the output, depending on a
+    condition.
+  }];
+  let cppNamespace = "::mlir";
+
+  // None of the methods below has a default implementation, so every
+  // implementing op must provide all three accessors.
+  let methods = [
+    InterfaceMethod<[{
+      Returns the operand that would be chosen for a false condition.
+    }], "::mlir::Value", "getFalseValue", (ins)>,
+    InterfaceMethod<[{
+      Returns the operand that would be chosen for a true condition.
+    }], "::mlir::Value", "getTrueValue", (ins)>,
+    InterfaceMethod<[{
+      Returns the condition operand.
+    }], "::mlir::Value", "getCondition", (ins)>
+  ];
+}
+
//===----------------------------------------------------------------------===//
// ControlFlow Traits
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Analysis/SliceWalk.cpp b/mlir/lib/Analysis/SliceWalk.cpp
index 9d770639dc53ca..6736f1b73e421f 100644
--- a/mlir/lib/Analysis/SliceWalk.cpp
+++ b/mlir/lib/Analysis/SliceWalk.cpp
@@ -104,9 +104,11 @@ getBlockPredecessorOperands(BlockArgument blockArg) {
std::optional<SmallVector<Value>>
mlir::getControlFlowPredecessors(Value value) {
- SmallVector<Value> result;
if (OpResult opResult = dyn_cast<OpResult>(value)) {
- auto regionOp = dyn_cast<RegionBranchOpInterface>(opResult.getOwner());
+ if (auto selectOp = opResult.getDefiningOp<SelectOpInterface>())
+ return SmallVector<Value>(
+ {selectOp.getTrueValue(), selectOp.getFalseValue()});
+ auto regionOp = opResult.getDefiningOp<RegionBranchOpInterface>();
// If the interface is not implemented, there are no control flow
// predecessors to work with.
if (!regionOp)
diff --git a/mlir/lib/Dialect/LLVMIR/Transforms/InlinerInterfaceImpl.cpp b/mlir/lib/Dialect/LLVMIR/Transforms/InlinerInterfaceImpl.cpp
index 1399d419735db9..031930dcfc2131 100644
--- a/mlir/lib/Dialect/LLVMIR/Transforms/InlinerInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/LLVMIR/Transforms/InlinerInterfaceImpl.cpp
@@ -235,11 +235,6 @@ getUnderlyingObjectSet(Value pointerValue) {
if (auto addrCast = val.getDefiningOp<LLVM::AddrSpaceCastOp>())
return WalkContinuation::advanceTo(addrCast.getOperand());
- // TODO: Add a SelectLikeOpInterface and use it in the slicing utility.
- if (auto selectOp = val.getDefiningOp<LLVM::SelectOp>())
- return WalkContinuation::advanceTo(
- {selectOp.getTrueValue(), selectOp.getFalseValue()});
-
// Attempt to advance to control flow predecessors.
std::optional<SmallVector<Value>> controlFlowPredecessors =
getControlFlowPredecessors(val);
diff --git a/mlir/test/Dialect/LLVMIR/inlining-alias-scopes.mlir b/mlir/test/Dialect/LLVMIR/inlining-alias-scopes.mlir
index bd5e7aa996ada7..6b369c50121050 100644
--- a/mlir/test/Dialect/LLVMIR/inlining-alias-scopes.mlir
+++ b/mlir/test/Dialect/LLVMIR/inlining-alias-scopes.mlir
@@ -508,3 +508,51 @@ llvm.func @noalias_with_region(%arg0: !llvm.ptr) {
llvm.call @region(%arg0) : (!llvm.ptr) -> ()
llvm.return
}
+
+// -----
+
+// Verifies that alias scope inference looks through `llvm.select`. Both
+// select operands derive from the noalias argument %arg0 (directly and via a
+// GEP). So after inlining, the load through the selected pointer is expected
+// to carry the argument's scope in `alias_scopes`. The load from the
+// unrelated %arg1 is expected to carry it only in `noalias_scopes`.
+// CHECK-DAG: #[[DOMAIN:.*]] = #llvm.alias_scope_domain<{{.*}}>
+// CHECK-DAG: #[[$ARG_SCOPE:.*]] = #llvm.alias_scope<id = {{.*}}, domain = #[[DOMAIN]]{{(,.*)?}}>
+
+llvm.func @foo(%arg: i32)
+
+llvm.func @func(%arg0: !llvm.ptr {llvm.noalias}, %arg1: !llvm.ptr) {
+  %cond = llvm.load %arg1 : !llvm.ptr -> i1
+  %1 = llvm.getelementptr inbounds %arg0[1] : (!llvm.ptr) -> !llvm.ptr, f32
+  %selected = llvm.select %cond, %arg0, %1 : i1, !llvm.ptr
+  %2 = llvm.load %selected : !llvm.ptr -> i32
+  llvm.call @foo(%2) : (i32) -> ()
+  llvm.return
+}
+
+// CHECK-LABEL: llvm.func @selects
+// CHECK: llvm.load
+// CHECK-NOT: alias_scopes
+// CHECK-SAME: noalias_scopes = [#[[$ARG_SCOPE]]]
+// CHECK: llvm.load
+// CHECK-SAME: alias_scopes = [#[[$ARG_SCOPE]]]
+llvm.func @selects(%arg0: !llvm.ptr, %arg1: !llvm.ptr) {
+  llvm.call @func(%arg0, %arg1) : (!llvm.ptr, !llvm.ptr) -> ()
+  llvm.return
+}
+
+// -----
+
+// Verifies that a select between the noalias argument %arg0 and another
+// pointer argument is not attributed to %arg0 alone. After inlining, the
+// load through the selected pointer is expected to carry neither alias nor
+// noalias scopes.
+llvm.func @foo(%arg: i32)
+
+llvm.func @func(%cond: i1, %arg0: !llvm.ptr {llvm.noalias}, %arg1: !llvm.ptr) {
+  %selected = llvm.select %cond, %arg0, %arg1 : i1, !llvm.ptr
+  %2 = llvm.load %selected : !llvm.ptr -> i32
+  llvm.call @foo(%2) : (i32) -> ()
+  llvm.return
+}
+
+// CHECK-LABEL: llvm.func @multi_ptr_select
+// CHECK: llvm.load
+// CHECK-NOT: alias_scopes
+// CHECK-NOT: noalias_scopes
+// CHECK: llvm.call @foo
+llvm.func @multi_ptr_select(%cond: i1, %arg0: !llvm.ptr, %arg1: !llvm.ptr) {
+  llvm.call @func(%cond, %arg0, %arg1) : (i1, !llvm.ptr, !llvm.ptr) -> ()
+  llvm.return
+}
More information about the Mlir-commits
mailing list