[llvm-branch-commits] [mlir] eff9b54 - [MLIR] Introduce IfOp in the xla_lhlo dialect

Uday Bondhugula via llvm-branch-commits llvm-branch-commits at lists.llvm.org
Fri Nov 5 03:30:17 PDT 2021


Author: Uday Bondhugula
Date: 2021-09-23T05:56:13+05:30
New Revision: eff9b542c7e3cbc44faac86046461a982ac9dcdc

URL: https://github.com/llvm/llvm-project/commit/eff9b542c7e3cbc44faac86046461a982ac9dcdc
DIFF: https://github.com/llvm/llvm-project/commit/eff9b542c7e3cbc44faac86046461a982ac9dcdc.diff

LOG: [MLIR] Introduce IfOp in the xla_lhlo dialect

Introduce LHLO IfOp to model conditionals in memref form.  This is
the lowered form of HLO IfOp (the latter operates on tensors).  Its
design is similar to that of the LHLO WhileOp, taking in tuples of
memrefs or elemental types. The true and the false bodies return a tuple
of the same type.

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/LHLO/IR/LHLOOps.td
    mlir/test/Dialect/LHLO/lhlo_ops.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/LHLO/IR/LHLOOps.td b/mlir/include/mlir/Dialect/LHLO/IR/LHLOOps.td
index 752992761485..711549940570 100644
--- a/mlir/include/mlir/Dialect/LHLO/IR/LHLOOps.td
+++ b/mlir/include/mlir/Dialect/LHLO/IR/LHLOOps.td
@@ -467,6 +467,53 @@ def LHLO_GatherOp: LHLO_Op<"gather", []>, BASE_HLO_GatherOp {
   );
 }
 
+def LHLO_IfOp : LHLO_Op<"if", [AffineScope, RecursiveSideEffects]> {
+  let summary = "If operator";
+
+  let description = [{
+    Returns the result of executing either the true or the false region,
+    depending on the value of the `pred` operand. In contrast to the HLO
+    version, the tuple operands for the true and false branches are tuples of
+    memrefs or int/fp types. Both branches also yield such a tuple type: they
+    yield the same type, which matches the result type of the op.
+
+    Example:
+
+    ```mlir
+    func @lhlo_if(%arg0: memref<1x1x10xf32>, %arg1: memref<1x1x10xf32>, %arg2: memref<i1>) {
+      %0 = "xla_lhlo.tuple"(%arg0, %arg1) : (memref<1x1x10xf32>, memref<1x1x10xf32>) -> tuple<memref<1x1x10xf32>, memref<1x1x10xf32>>
+      %1 = "xla_lhlo.if"(%arg2, %0, %0) ( {
+        ^bb0(%arg3: tuple<memref<1x1x10xf32>, memref<1x1x10xf32>>):
+        %2 = "xla_lhlo.get_tuple_element"(%arg3) {index = 0 : i32} : (tuple<memref<1x1x10xf32>, memref<1x1x10xf32>>) -> memref<1x1x10xf32>
+        %3 = "xla_lhlo.tuple"(%2) : (memref<1x1x10xf32>) -> tuple<memref<1x1x10xf32>>
+        "xla_lhlo.yield"(%3) : (tuple<memref<1x1x10xf32>>) -> ()
+      },  {
+        ^bb0(%arg3: tuple<memref<1x1x10xf32>, memref<1x1x10xf32>>):
+        %2 = "xla_lhlo.get_tuple_element"(%arg3) {index = 1 : i32} : (tuple<memref<1x1x10xf32>, memref<1x1x10xf32>>) -> memref<1x1x10xf32>
+        %3 = "xla_lhlo.tuple"(%2) : (memref<1x1x10xf32>) -> tuple<memref<1x1x10xf32>>
+        "xla_lhlo.yield"(%3) : (tuple<memref<1x1x10xf32>>) -> ()
+      }) : (memref<i1>, tuple<memref<1x1x10xf32>, memref<1x1x10xf32>>, tuple<memref<1x1x10xf32>, memref<1x1x10xf32>>) -> tuple<memref<1x1x10xf32>>
+      "xla_lhlo.terminator"() : () -> ()
+    }
+    ```
+
+    See https://www.tensorflow.org/xla/operation_semantics#conditional.
+  }];
+
+  // `pred` selects the region to execute; `true_arg` and `false_arg` are the
+  // tuple operands of the true and false branches, respectively.
+  let arguments = (ins
+    LHLO_PredBufferOrI1:$pred,
+    LHLO_TupleOfBufferOrIntOrFP:$true_arg,
+    LHLO_TupleOfBufferOrIntOrFP:$false_arg
+  );
+
+  let regions = (region AnyRegion:$true_branch,
+                        AnyRegion:$false_branch);
+
+  let results = (outs LHLO_TupleOfBufferOrIntOrFP);
+}
+
 def LHLO_MapOp : LHLO_Op<"map", [RecursiveSideEffects, SameOperandsShape]>,
                  BASE_HLO_MapOp {
   let description = [{

diff  --git a/mlir/test/Dialect/LHLO/lhlo_ops.mlir b/mlir/test/Dialect/LHLO/lhlo_ops.mlir
index 30a84c5b4bcb..59ccfc8113be 100644
--- a/mlir/test/Dialect/LHLO/lhlo_ops.mlir
+++ b/mlir/test/Dialect/LHLO/lhlo_ops.mlir
@@ -235,3 +235,40 @@ func @while_op(%arg0: memref<4x?x16xf32>, %arg1: memref<4x?x16xf32>) {
     }) : (tuple<i32, memref<4xi32>>) -> tuple<i32, memref<4xi32>>
     "xla_lhlo.terminator"() : () -> ()
 }
+
+// -----
+// CHECK-LABEL: func @lhlo_if(
+func @lhlo_if(%arg0: memref<1x1x10xf32>, %arg1: memref<1x1x10xf32>, %arg2: memref<i1>) {
+  %0 = "xla_lhlo.tuple"(%arg0, %arg1) : (memref<1x1x10xf32>, memref<1x1x10xf32>) -> tuple<memref<1x1x10xf32>, memref<1x1x10xf32>>
+  // CHECK: xla_lhlo.if
+  %1 = "xla_lhlo.if"(%arg2, %0, %0) ( {
+    ^bb0(%arg3: tuple<memref<1x1x10xf32>, memref<1x1x10xf32>>):  // Forwards tuple element 0.
+    %2 = "xla_lhlo.get_tuple_element"(%arg3) {index = 0 : i32} : (tuple<memref<1x1x10xf32>, memref<1x1x10xf32>>) -> memref<1x1x10xf32>
+    %3 = "xla_lhlo.tuple"(%2) : (memref<1x1x10xf32>) -> tuple<memref<1x1x10xf32>>
+    "xla_lhlo.yield"(%3) : (tuple<memref<1x1x10xf32>>) -> ()
+  },  {
+    ^bb0(%arg3: tuple<memref<1x1x10xf32>, memref<1x1x10xf32>>):  // Forwards tuple element 1.
+    %2 = "xla_lhlo.get_tuple_element"(%arg3) {index = 1 : i32} : (tuple<memref<1x1x10xf32>, memref<1x1x10xf32>>) -> memref<1x1x10xf32>
+    %3 = "xla_lhlo.tuple"(%2) : (memref<1x1x10xf32>) -> tuple<memref<1x1x10xf32>>
+    "xla_lhlo.yield"(%3) : (tuple<memref<1x1x10xf32>>) -> ()
+  }) : (memref<i1>, tuple<memref<1x1x10xf32>, memref<1x1x10xf32>>, tuple<memref<1x1x10xf32>, memref<1x1x10xf32>>) -> tuple<memref<1x1x10xf32>>
+  "xla_lhlo.terminator"() : () -> ()
+}
+
+// CHECK-LABEL: func @lhlo_if_empty_arg
+func @lhlo_if_empty_arg(%pred: memref<i1>) {
+  %one = constant 1.000000e+00 : f32
+  %zero = constant 0.000000e+00 : f32
+  %empty = "xla_lhlo.tuple"() : () -> tuple<>
+  // CHECK: xla_lhlo.if
+  %result = "xla_lhlo.if"(%pred, %empty, %empty) ( {
+    ^bb0(%true_in: tuple<>):  // Yields (1.0, 0.0).
+    %pair = "xla_lhlo.tuple"(%one, %zero) : (f32, f32) -> tuple<f32, f32>
+    "xla_lhlo.yield"(%pair) : (tuple<f32, f32>) -> ()
+  },  {
+    ^bb0(%false_in: tuple<>):  // Yields (0.0, 1.0).
+    %pair = "xla_lhlo.tuple"(%zero, %one) : (f32, f32) -> tuple<f32, f32>
+    "xla_lhlo.yield"(%pair) : (tuple<f32, f32>) -> ()
+  }) : (memref<i1>, tuple<>, tuple<>) -> tuple<f32, f32>
+  "xla_lhlo.terminator"() : () -> ()
+}


        


More information about the llvm-branch-commits mailing list