[Mlir-commits] [mlir] [MLIR] Introduce a SelectOpInterface (PR #104751)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Mon Aug 19 02:29:06 PDT 2024


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir

Author: Christian Ulmann (Dinistro)

<details>
<summary>Changes</summary>

This commit introduces a `SelectOpInterface` that can be used to handle select-like operations generically. Select operations are similar to control flow operations, as they forward operands depending on conditions. This is why the interface was placed alongside the existing control flow interfaces.

Note that this interface will also be very useful for the dataflow analysis.

---
Full diff: https://github.com/llvm/llvm-project/pull/104751.diff


7 Files Affected:

- (modified) mlir/include/mlir/Dialect/Arith/IR/Arith.h (+1) 
- (modified) mlir/include/mlir/Dialect/Arith/IR/ArithOps.td (+2) 
- (modified) mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td (+2-1) 
- (modified) mlir/include/mlir/Interfaces/ControlFlowInterfaces.td (+21) 
- (modified) mlir/lib/Analysis/SliceWalk.cpp (+4-2) 
- (modified) mlir/lib/Dialect/LLVMIR/Transforms/InlinerInterfaceImpl.cpp (-5) 
- (modified) mlir/test/Dialect/LLVMIR/inlining-alias-scopes.mlir (+48) 


``````````diff
diff --git a/mlir/include/mlir/Dialect/Arith/IR/Arith.h b/mlir/include/mlir/Dialect/Arith/IR/Arith.h
index 00cdb13feb29bb..77241319851e6c 100644
--- a/mlir/include/mlir/Dialect/Arith/IR/Arith.h
+++ b/mlir/include/mlir/Dialect/Arith/IR/Arith.h
@@ -14,6 +14,7 @@
 #include "mlir/IR/OpDefinition.h"
 #include "mlir/IR/OpImplementation.h"
 #include "mlir/Interfaces/CastInterfaces.h"
+#include "mlir/Interfaces/ControlFlowInterfaces.h"
 #include "mlir/Interfaces/InferIntRangeInterface.h"
 #include "mlir/Interfaces/InferTypeOpInterface.h"
 #include "mlir/Interfaces/SideEffectInterfaces.h"
diff --git a/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td b/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td
index 477478a4651cee..cddb3722c3ccff 100644
--- a/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td
+++ b/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td
@@ -12,6 +12,7 @@
 include "mlir/Dialect/Arith/IR/ArithBase.td"
 include "mlir/Dialect/Arith/IR/ArithOpsInterfaces.td"
 include "mlir/Interfaces/CastInterfaces.td"
+include "mlir/Interfaces/ControlFlowInterfaces.td"
 include "mlir/Interfaces/InferIntRangeInterface.td"
 include "mlir/Interfaces/InferTypeOpInterface.td"
 include "mlir/Interfaces/SideEffectInterfaces.td"
@@ -1578,6 +1579,7 @@ def SelectOp : Arith_Op<"select", [Pure,
     AllTypesMatch<["true_value", "false_value", "result"]>,
     BooleanConditionOrMatchingShape<"condition", "result">,
     DeclareOpInterfaceMethods<InferIntRangeInterface, ["inferResultRangesFromOptional"]>,
+    DeclareOpInterfaceMethods<SelectOpInterface>,
   ] # ElementwiseMappable.traits> {
   let summary = "select operation";
   let description = [{
diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
index 643522d5903fd0..6230f4d32994e5 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
@@ -835,7 +835,8 @@ def LLVM_ShuffleVectorOp : LLVM_Op<"shufflevector",
 def LLVM_SelectOp
     : LLVM_Op<"select",
           [Pure, AllTypesMatch<["trueValue", "falseValue", "res"]>,
-           DeclareOpInterfaceMethods<FastmathFlagsInterface>]>,
+           DeclareOpInterfaceMethods<FastmathFlagsInterface>,
+           DeclareOpInterfaceMethods<SelectOpInterface>]>,
       LLVM_Builder<
           "$res = builder.CreateSelect($condition, $trueValue, $falseValue);"> {
   let arguments = (ins LLVM_ScalarOrVectorOf<I1>:$condition,
diff --git a/mlir/include/mlir/Interfaces/ControlFlowInterfaces.td b/mlir/include/mlir/Interfaces/ControlFlowInterfaces.td
index 95ac5dea243aa4..7b6191c2332756 100644
--- a/mlir/include/mlir/Interfaces/ControlFlowInterfaces.td
+++ b/mlir/include/mlir/Interfaces/ControlFlowInterfaces.td
@@ -343,6 +343,27 @@ def RegionBranchTerminatorOpInterface :
   }];
 }
 
+def SelectOpInterface : OpInterface<"SelectOpInterface"> {
+  let description = [{
+    This interface provides information for select-like operations, i.e.,
+    operations that forward specific operands to the output, depending on a
+    condition.
+  }];
+  let cppNamespace = "::mlir";
+
+  let methods = [
+    InterfaceMethod<[{
+        Returns the operand that would be chosen for a false condition.
+      }], "::mlir::Value", "getFalseValue", (ins)>,
+    InterfaceMethod<[{
+        Returns the operand that would be chosen for a true condition.
+      }], "::mlir::Value", "getTrueValue", (ins)>,
+    InterfaceMethod<[{
+        Returns the condition operand.
+      }], "::mlir::Value", "getCondition", (ins)>
+  ];
+}
+
 //===----------------------------------------------------------------------===//
 // ControlFlow Traits
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Analysis/SliceWalk.cpp b/mlir/lib/Analysis/SliceWalk.cpp
index 9d770639dc53ca..6736f1b73e421f 100644
--- a/mlir/lib/Analysis/SliceWalk.cpp
+++ b/mlir/lib/Analysis/SliceWalk.cpp
@@ -104,9 +104,11 @@ getBlockPredecessorOperands(BlockArgument blockArg) {
 
 std::optional<SmallVector<Value>>
 mlir::getControlFlowPredecessors(Value value) {
-  SmallVector<Value> result;
   if (OpResult opResult = dyn_cast<OpResult>(value)) {
-    auto regionOp = dyn_cast<RegionBranchOpInterface>(opResult.getOwner());
+    if (auto selectOp = opResult.getDefiningOp<SelectOpInterface>())
+      return SmallVector<Value>(
+          {selectOp.getTrueValue(), selectOp.getFalseValue()});
+    auto regionOp = opResult.getDefiningOp<RegionBranchOpInterface>();
     // If the interface is not implemented, there are no control flow
     // predecessors to work with.
     if (!regionOp)
diff --git a/mlir/lib/Dialect/LLVMIR/Transforms/InlinerInterfaceImpl.cpp b/mlir/lib/Dialect/LLVMIR/Transforms/InlinerInterfaceImpl.cpp
index 1399d419735db9..031930dcfc2131 100644
--- a/mlir/lib/Dialect/LLVMIR/Transforms/InlinerInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/LLVMIR/Transforms/InlinerInterfaceImpl.cpp
@@ -235,11 +235,6 @@ getUnderlyingObjectSet(Value pointerValue) {
     if (auto addrCast = val.getDefiningOp<LLVM::AddrSpaceCastOp>())
       return WalkContinuation::advanceTo(addrCast.getOperand());
 
-    // TODO: Add a SelectLikeOpInterface and use it in the slicing utility.
-    if (auto selectOp = val.getDefiningOp<LLVM::SelectOp>())
-      return WalkContinuation::advanceTo(
-          {selectOp.getTrueValue(), selectOp.getFalseValue()});
-
     // Attempt to advance to control flow predecessors.
     std::optional<SmallVector<Value>> controlFlowPredecessors =
         getControlFlowPredecessors(val);
diff --git a/mlir/test/Dialect/LLVMIR/inlining-alias-scopes.mlir b/mlir/test/Dialect/LLVMIR/inlining-alias-scopes.mlir
index bd5e7aa996ada7..6b369c50121050 100644
--- a/mlir/test/Dialect/LLVMIR/inlining-alias-scopes.mlir
+++ b/mlir/test/Dialect/LLVMIR/inlining-alias-scopes.mlir
@@ -508,3 +508,51 @@ llvm.func @noalias_with_region(%arg0: !llvm.ptr) {
   llvm.call @region(%arg0) : (!llvm.ptr) -> ()
   llvm.return
 }
+
+// -----
+
+// CHECK-DAG: #[[DOMAIN:.*]] = #llvm.alias_scope_domain<{{.*}}>
+// CHECK-DAG: #[[$ARG_SCOPE:.*]] = #llvm.alias_scope<id = {{.*}}, domain = #[[DOMAIN]]{{(,.*)?}}>
+
+llvm.func @foo(%arg: i32)
+
+llvm.func @func(%arg0: !llvm.ptr {llvm.noalias}, %arg1: !llvm.ptr) {
+  %cond = llvm.load %arg1 : !llvm.ptr -> i1
+  %1 = llvm.getelementptr inbounds %arg0[1] : (!llvm.ptr) -> !llvm.ptr, f32
+  %selected = llvm.select %cond, %arg0, %1 : i1, !llvm.ptr
+  %2 = llvm.load %selected : !llvm.ptr -> i32
+  llvm.call @foo(%2) : (i32) -> ()
+  llvm.return
+}
+
+// CHECK-LABEL: llvm.func @selects
+// CHECK: llvm.load
+// CHECK-NOT: alias_scopes
+// CHECK-SAME: noalias_scopes = [#[[$ARG_SCOPE]]]
+// CHECK: llvm.load
+// CHECK-SAME: alias_scopes = [#[[$ARG_SCOPE]]]
+llvm.func @selects(%arg0: !llvm.ptr, %arg1: !llvm.ptr) {
+  llvm.call @func(%arg0, %arg1) : (!llvm.ptr, !llvm.ptr) -> ()
+  llvm.return
+}
+
+// -----
+
+llvm.func @foo(%arg: i32)
+
+llvm.func @func(%cond: i1, %arg0: !llvm.ptr {llvm.noalias}, %arg1: !llvm.ptr) {
+  %selected = llvm.select %cond, %arg0, %arg1 : i1, !llvm.ptr
+  %2 = llvm.load %selected : !llvm.ptr -> i32
+  llvm.call @foo(%2) : (i32) -> ()
+  llvm.return
+}
+
+// CHECK-LABEL: llvm.func @multi_ptr_select
+// CHECK: llvm.load
+// CHECK-NOT: alias_scopes
+// CHECK-NOT: noalias_scopes
+// CHECK: llvm.call @foo
+llvm.func @multi_ptr_select(%cond: i1, %arg0: !llvm.ptr, %arg1: !llvm.ptr) {
+  llvm.call @func(%cond, %arg0, %arg1) : (i1, !llvm.ptr, !llvm.ptr) -> ()
+  llvm.return
+}

``````````

</details>


https://github.com/llvm/llvm-project/pull/104751


More information about the Mlir-commits mailing list