[Mlir-commits] [mlir] [MLIR] Introduce a SelectLikeOpInterface (PR #104751)

Christian Ulmann llvmlistbot at llvm.org
Mon Aug 19 08:42:12 PDT 2024


https://github.com/Dinistro updated https://github.com/llvm/llvm-project/pull/104751

>From b1d6cfdcf34dbb502c59e21646ea3aa32c2da7d6 Mon Sep 17 00:00:00 2001
From: Christian Ulmann <christian.ulmann at nextsilicon.com>
Date: Mon, 19 Aug 2024 09:14:34 +0000
Subject: [PATCH 1/4] [MLIR] Introduce a SelectOpInterface

This commit introduces a `SelectOpInterface` that can be used to handle
select-like operations generically. Select operations are similar to
control flow operations, as they forward operands depending on
conditions. This is why it was placed alongside the already existing
control flow interfaces.
---
 mlir/include/mlir/Dialect/Arith/IR/Arith.h    |  1 +
 .../include/mlir/Dialect/Arith/IR/ArithOps.td |  2 +
 mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td   |  3 +-
 .../mlir/Interfaces/ControlFlowInterfaces.td  | 21 ++++++++
 mlir/lib/Analysis/SliceWalk.cpp               |  6 ++-
 .../Transforms/InlinerInterfaceImpl.cpp       |  5 --
 .../Dialect/LLVMIR/inlining-alias-scopes.mlir | 48 +++++++++++++++++++
 7 files changed, 78 insertions(+), 8 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Arith/IR/Arith.h b/mlir/include/mlir/Dialect/Arith/IR/Arith.h
index 00cdb13feb29bb..77241319851e6c 100644
--- a/mlir/include/mlir/Dialect/Arith/IR/Arith.h
+++ b/mlir/include/mlir/Dialect/Arith/IR/Arith.h
@@ -14,6 +14,7 @@
 #include "mlir/IR/OpDefinition.h"
 #include "mlir/IR/OpImplementation.h"
 #include "mlir/Interfaces/CastInterfaces.h"
+#include "mlir/Interfaces/ControlFlowInterfaces.h"
 #include "mlir/Interfaces/InferIntRangeInterface.h"
 #include "mlir/Interfaces/InferTypeOpInterface.h"
 #include "mlir/Interfaces/SideEffectInterfaces.h"
diff --git a/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td b/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td
index 477478a4651cee..cddb3722c3ccff 100644
--- a/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td
+++ b/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td
@@ -12,6 +12,7 @@
 include "mlir/Dialect/Arith/IR/ArithBase.td"
 include "mlir/Dialect/Arith/IR/ArithOpsInterfaces.td"
 include "mlir/Interfaces/CastInterfaces.td"
+include "mlir/Interfaces/ControlFlowInterfaces.td"
 include "mlir/Interfaces/InferIntRangeInterface.td"
 include "mlir/Interfaces/InferTypeOpInterface.td"
 include "mlir/Interfaces/SideEffectInterfaces.td"
@@ -1578,6 +1579,7 @@ def SelectOp : Arith_Op<"select", [Pure,
     AllTypesMatch<["true_value", "false_value", "result"]>,
     BooleanConditionOrMatchingShape<"condition", "result">,
     DeclareOpInterfaceMethods<InferIntRangeInterface, ["inferResultRangesFromOptional"]>,
+    DeclareOpInterfaceMethods<SelectOpInterface>,
   ] # ElementwiseMappable.traits> {
   let summary = "select operation";
   let description = [{
diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
index 643522d5903fd0..6230f4d32994e5 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
@@ -835,7 +835,8 @@ def LLVM_ShuffleVectorOp : LLVM_Op<"shufflevector",
 def LLVM_SelectOp
     : LLVM_Op<"select",
           [Pure, AllTypesMatch<["trueValue", "falseValue", "res"]>,
-           DeclareOpInterfaceMethods<FastmathFlagsInterface>]>,
+           DeclareOpInterfaceMethods<FastmathFlagsInterface>,
+           DeclareOpInterfaceMethods<SelectOpInterface>]>,
       LLVM_Builder<
           "$res = builder.CreateSelect($condition, $trueValue, $falseValue);"> {
   let arguments = (ins LLVM_ScalarOrVectorOf<I1>:$condition,
diff --git a/mlir/include/mlir/Interfaces/ControlFlowInterfaces.td b/mlir/include/mlir/Interfaces/ControlFlowInterfaces.td
index 95ac5dea243aa4..7b6191c2332756 100644
--- a/mlir/include/mlir/Interfaces/ControlFlowInterfaces.td
+++ b/mlir/include/mlir/Interfaces/ControlFlowInterfaces.td
@@ -343,6 +343,27 @@ def RegionBranchTerminatorOpInterface :
   }];
 }
 
+def SelectOpInterface : OpInterface<"SelectOpInterface"> {
+  let description = [{
+    This interface provides information for select-like operations, i.e.,
+    operations that forward specific operands to the output, depending on a
+    condition.
+  }];
+  let cppNamespace = "::mlir";
+
+  let methods = [
+    InterfaceMethod<[{
+        Returns the operand that would be chosen for a false condition.
+      }], "::mlir::Value", "getFalseValue", (ins)>,
+    InterfaceMethod<[{
+        Returns the operand that would be chosen for a true condition.
+      }], "::mlir::Value", "getTrueValue", (ins)>,
+    InterfaceMethod<[{
+        Returns the condition operand.
+      }], "::mlir::Value", "getCondition", (ins)>
+  ];
+}
+
 //===----------------------------------------------------------------------===//
 // ControlFlow Traits
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Analysis/SliceWalk.cpp b/mlir/lib/Analysis/SliceWalk.cpp
index 9d770639dc53ca..6736f1b73e421f 100644
--- a/mlir/lib/Analysis/SliceWalk.cpp
+++ b/mlir/lib/Analysis/SliceWalk.cpp
@@ -104,9 +104,11 @@ getBlockPredecessorOperands(BlockArgument blockArg) {
 
 std::optional<SmallVector<Value>>
 mlir::getControlFlowPredecessors(Value value) {
-  SmallVector<Value> result;
   if (OpResult opResult = dyn_cast<OpResult>(value)) {
-    auto regionOp = dyn_cast<RegionBranchOpInterface>(opResult.getOwner());
+    if (auto selectOp = opResult.getDefiningOp<SelectOpInterface>())
+      return SmallVector<Value>(
+          {selectOp.getTrueValue(), selectOp.getFalseValue()});
+    auto regionOp = opResult.getDefiningOp<RegionBranchOpInterface>();
     // If the interface is not implemented, there are no control flow
     // predecessors to work with.
     if (!regionOp)
diff --git a/mlir/lib/Dialect/LLVMIR/Transforms/InlinerInterfaceImpl.cpp b/mlir/lib/Dialect/LLVMIR/Transforms/InlinerInterfaceImpl.cpp
index 1399d419735db9..031930dcfc2131 100644
--- a/mlir/lib/Dialect/LLVMIR/Transforms/InlinerInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/LLVMIR/Transforms/InlinerInterfaceImpl.cpp
@@ -235,11 +235,6 @@ getUnderlyingObjectSet(Value pointerValue) {
     if (auto addrCast = val.getDefiningOp<LLVM::AddrSpaceCastOp>())
       return WalkContinuation::advanceTo(addrCast.getOperand());
 
-    // TODO: Add a SelectLikeOpInterface and use it in the slicing utility.
-    if (auto selectOp = val.getDefiningOp<LLVM::SelectOp>())
-      return WalkContinuation::advanceTo(
-          {selectOp.getTrueValue(), selectOp.getFalseValue()});
-
     // Attempt to advance to control flow predecessors.
     std::optional<SmallVector<Value>> controlFlowPredecessors =
         getControlFlowPredecessors(val);
diff --git a/mlir/test/Dialect/LLVMIR/inlining-alias-scopes.mlir b/mlir/test/Dialect/LLVMIR/inlining-alias-scopes.mlir
index bd5e7aa996ada7..6b369c50121050 100644
--- a/mlir/test/Dialect/LLVMIR/inlining-alias-scopes.mlir
+++ b/mlir/test/Dialect/LLVMIR/inlining-alias-scopes.mlir
@@ -508,3 +508,51 @@ llvm.func @noalias_with_region(%arg0: !llvm.ptr) {
   llvm.call @region(%arg0) : (!llvm.ptr) -> ()
   llvm.return
 }
+
+// -----
+
+// CHECK-DAG: #[[DOMAIN:.*]] = #llvm.alias_scope_domain<{{.*}}>
+// CHECK-DAG: #[[$ARG_SCOPE:.*]] = #llvm.alias_scope<id = {{.*}}, domain = #[[DOMAIN]]{{(,.*)?}}>
+
+llvm.func @foo(%arg: i32)
+
+llvm.func @func(%arg0: !llvm.ptr {llvm.noalias}, %arg1: !llvm.ptr) {
+  %cond = llvm.load %arg1 : !llvm.ptr -> i1
+  %1 = llvm.getelementptr inbounds %arg0[1] : (!llvm.ptr) -> !llvm.ptr, f32
+  %selected = llvm.select %cond, %arg0, %1 : i1, !llvm.ptr
+  %2 = llvm.load %selected : !llvm.ptr -> i32
+  llvm.call @foo(%2) : (i32) -> ()
+  llvm.return
+}
+
+// CHECK-LABEL: llvm.func @selects
+// CHECK: llvm.load
+// CHECK-NOT: alias_scopes
+// CHECK-SAME: noalias_scopes = [#[[$ARG_SCOPE]]]
+// CHECK: llvm.load
+// CHECK-SAME: alias_scopes = [#[[$ARG_SCOPE]]]
+llvm.func @selects(%arg0: !llvm.ptr, %arg1: !llvm.ptr) {
+  llvm.call @func(%arg0, %arg1) : (!llvm.ptr, !llvm.ptr) -> ()
+  llvm.return
+}
+
+// -----
+
+llvm.func @foo(%arg: i32)
+
+llvm.func @func(%cond: i1, %arg0: !llvm.ptr {llvm.noalias}, %arg1: !llvm.ptr) {
+  %selected = llvm.select %cond, %arg0, %arg1 : i1, !llvm.ptr
+  %2 = llvm.load %selected : !llvm.ptr -> i32
+  llvm.call @foo(%2) : (i32) -> ()
+  llvm.return
+}
+
+// CHECK-LABEL: llvm.func @multi_ptr_select
+// CHECK: llvm.load
+// CHECK-NOT: alias_scopes
+// CHECK-NOT: noalias_scopes
+// CHECK: llvm.call @foo
+llvm.func @multi_ptr_select(%cond: i1, %arg0: !llvm.ptr, %arg1: !llvm.ptr) {
+  llvm.call @func(%cond, %arg0, %arg1) : (i1, !llvm.ptr, !llvm.ptr) -> ()
+  llvm.return
+}

>From 0163dee10e5784a063c753f7c2877caeb948edcb Mon Sep 17 00:00:00 2001
From: Christian Ulmann <christian.ulmann at nextsilicon.com>
Date: Mon, 19 Aug 2024 11:12:41 +0000
Subject: [PATCH 2/4] rename to SelectLikeOpInterface to avoid name clashes

---
 mlir/include/mlir/Dialect/Arith/IR/ArithOps.td        | 2 +-
 mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td           | 2 +-
 mlir/include/mlir/Interfaces/ControlFlowInterfaces.td | 2 +-
 mlir/lib/Analysis/SliceWalk.cpp                       | 2 +-
 4 files changed, 4 insertions(+), 4 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td b/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td
index cddb3722c3ccff..19a5e13a5d755d 100644
--- a/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td
+++ b/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td
@@ -1579,7 +1579,7 @@ def SelectOp : Arith_Op<"select", [Pure,
     AllTypesMatch<["true_value", "false_value", "result"]>,
     BooleanConditionOrMatchingShape<"condition", "result">,
     DeclareOpInterfaceMethods<InferIntRangeInterface, ["inferResultRangesFromOptional"]>,
-    DeclareOpInterfaceMethods<SelectOpInterface>,
+    DeclareOpInterfaceMethods<SelectLikeOpInterface>,
   ] # ElementwiseMappable.traits> {
   let summary = "select operation";
   let description = [{
diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
index 6230f4d32994e5..71f249fa538ca9 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
@@ -836,7 +836,7 @@ def LLVM_SelectOp
     : LLVM_Op<"select",
           [Pure, AllTypesMatch<["trueValue", "falseValue", "res"]>,
            DeclareOpInterfaceMethods<FastmathFlagsInterface>,
-           DeclareOpInterfaceMethods<SelectOpInterface>]>,
+           DeclareOpInterfaceMethods<SelectLikeOpInterface>]>,
       LLVM_Builder<
           "$res = builder.CreateSelect($condition, $trueValue, $falseValue);"> {
   let arguments = (ins LLVM_ScalarOrVectorOf<I1>:$condition,
diff --git a/mlir/include/mlir/Interfaces/ControlFlowInterfaces.td b/mlir/include/mlir/Interfaces/ControlFlowInterfaces.td
index 7b6191c2332756..0cc80b4ae58dab 100644
--- a/mlir/include/mlir/Interfaces/ControlFlowInterfaces.td
+++ b/mlir/include/mlir/Interfaces/ControlFlowInterfaces.td
@@ -343,7 +343,7 @@ def RegionBranchTerminatorOpInterface :
   }];
 }
 
-def SelectOpInterface : OpInterface<"SelectOpInterface"> {
+def SelectLikeOpInterface : OpInterface<"SelectLikeOpInterface"> {
   let description = [{
     This interface provides information for select-like operations, i.e.,
     operations that forward specific operands to the output, depending on a
diff --git a/mlir/lib/Analysis/SliceWalk.cpp b/mlir/lib/Analysis/SliceWalk.cpp
index 6736f1b73e421f..817d71a3452caf 100644
--- a/mlir/lib/Analysis/SliceWalk.cpp
+++ b/mlir/lib/Analysis/SliceWalk.cpp
@@ -105,7 +105,7 @@ getBlockPredecessorOperands(BlockArgument blockArg) {
 std::optional<SmallVector<Value>>
 mlir::getControlFlowPredecessors(Value value) {
   if (OpResult opResult = dyn_cast<OpResult>(value)) {
-    if (auto selectOp = opResult.getDefiningOp<SelectOpInterface>())
+    if (auto selectOp = opResult.getDefiningOp<SelectLikeOpInterface>())
       return SmallVector<Value>(
           {selectOp.getTrueValue(), selectOp.getFalseValue()});
     auto regionOp = opResult.getDefiningOp<RegionBranchOpInterface>();

>From a3adb2a3fcf1c74dbd2e366627a5d35817c9c82a Mon Sep 17 00:00:00 2001
From: Christian Ulmann <christian.ulmann at nextsilicon.com>
Date: Mon, 19 Aug 2024 11:13:57 +0000
Subject: [PATCH 3/4] address review comments

---
 mlir/include/mlir/Analysis/SliceWalk.h | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/mlir/include/mlir/Analysis/SliceWalk.h b/mlir/include/mlir/Analysis/SliceWalk.h
index 481c5690c533ba..eb9ced2ff63b68 100644
--- a/mlir/include/mlir/Analysis/SliceWalk.h
+++ b/mlir/include/mlir/Analysis/SliceWalk.h
@@ -88,9 +88,9 @@ WalkContinuation walkSlice(mlir::ValueRange rootValues,
                            WalkCallback walkCallback);
 
 /// Computes a vector of all control predecessors of `value`. Relies on
-/// RegionBranchOpInterface and BranchOpInterface to determine predecessors.
-/// Returns nullopt if `value` has no predecessors or when the relevant
-/// operations are missing the interface implementations.
+/// RegionBranchOpInterface, BranchOpInterface, and SelectLikeOpInterface to
+/// determine predecessors. Returns nullopt if `value` has no predecessors or
+/// when the relevant operations are missing the interface implementations.
 std::optional<SmallVector<Value>> getControlFlowPredecessors(Value value);
 
 } // namespace mlir

>From bb63eb5ff62b81c193a405088d3e8b1fe8b9c49d Mon Sep 17 00:00:00 2001
From: Christian Ulmann <christian.ulmann at nextsilicon.com>
Date: Mon, 19 Aug 2024 15:41:59 +0000
Subject: [PATCH 4/4] improve comment

---
 .../mlir/Interfaces/ControlFlowInterfaces.td        | 13 ++++++++++++-
 1 file changed, 12 insertions(+), 1 deletion(-)

diff --git a/mlir/include/mlir/Interfaces/ControlFlowInterfaces.td b/mlir/include/mlir/Interfaces/ControlFlowInterfaces.td
index 0cc80b4ae58dab..69bce78e946c83 100644
--- a/mlir/include/mlir/Interfaces/ControlFlowInterfaces.td
+++ b/mlir/include/mlir/Interfaces/ControlFlowInterfaces.td
@@ -347,7 +347,18 @@ def SelectLikeOpInterface : OpInterface<"SelectLikeOpInterface"> {
   let description = [{
     This interface provides information for select-like operations, i.e.,
     operations that forward specific operands to the output, depending on a
-    condition.
+    binary condition.
+
+    If the value of the condition is 1, then the `true` operand is returned,
+    and the `false` operand is ignored, even if it is poison.
+
+    If the value of the condition is 0, then the `false` operand is returned,
+    and the `true` operand is ignored, even if it is poison.
+
+    If the condition is poison, then poison is returned.
+
+    Implementing operations can also accept shaped conditions, in which case
+    the operation works element-wise.
   }];
   let cppNamespace = "::mlir";
 



More information about the Mlir-commits mailing list