[llvm-branch-commits] [mlir] eff9b54 - [MLIR] Introduce IfOp in the xla_lhlo dialect
Uday Bondhugula via llvm-branch-commits
llvm-branch-commits at lists.llvm.org
Fri Nov 5 03:30:17 PDT 2021
Author: Uday Bondhugula
Date: 2021-09-23T05:56:13+05:30
New Revision: eff9b542c7e3cbc44faac86046461a982ac9dcdc
URL: https://github.com/llvm/llvm-project/commit/eff9b542c7e3cbc44faac86046461a982ac9dcdc
DIFF: https://github.com/llvm/llvm-project/commit/eff9b542c7e3cbc44faac86046461a982ac9dcdc.diff
LOG: [MLIR] Introduce IfOp in the xla_lhlo dialect
Introduce LHLO IfOp to model conditionals on the memref form. This is
the lowered form of the HLO IfOp (the latter operates on tensors). Its
design is similar to that of the LHLO WhileOp, taking in tuples of
memrefs or elemental types. The true and the false bodies return a tuple
of the same type.
Added:
Modified:
mlir/include/mlir/Dialect/LHLO/IR/LHLOOps.td
mlir/test/Dialect/LHLO/lhlo_ops.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/LHLO/IR/LHLOOps.td b/mlir/include/mlir/Dialect/LHLO/IR/LHLOOps.td
index 752992761485..711549940570 100644
--- a/mlir/include/mlir/Dialect/LHLO/IR/LHLOOps.td
+++ b/mlir/include/mlir/Dialect/LHLO/IR/LHLOOps.td
@@ -467,6 +467,53 @@ def LHLO_GatherOp: LHLO_Op<"gather", []>, BASE_HLO_GatherOp {
);
}
+def LHLO_IfOp : LHLO_Op<"if", [AffineScope, RecursiveSideEffects]> {
+ string summary = "If operator";
+
+ string description = [{
+ Returns the result of executing either a true or false function depending on
+ the result of a condition function. In contrast to the HLO version, the
+ tuple operands for the true or false branch are a tuple of memrefs or int/fp
+ types. Both the true and false branches also return such a tuple type: they
+ both return the same type and this matches the result type of the op.
+
+ Example:
+
+ ```mlir
+ func @lhlo_if(%arg0: memref<1x1x10xf32>, %arg1: memref<1x1x10xf32>, %arg2: memref<i1>) {
+ %0 = "xla_lhlo.tuple"(%arg0, %arg1) : (memref<1x1x10xf32>, memref<1x1x10xf32>) -> tuple<memref<1x1x10xf32>, memref<1x1x10xf32>>
+ %1 = "xla_lhlo.if"(%arg2, %0, %0) ( {
+ ^bb0(%arg3: tuple<memref<1x1x10xf32>, memref<1x1x10xf32>>):
+ %2 = "xla_lhlo.get_tuple_element"(%arg3) {index = 0 : i32} : (tuple<memref<1x1x10xf32>, memref<1x1x10xf32>>) -> memref<1x1x10xf32>
+ %3 = "xla_lhlo.tuple"(%2) : (memref<1x1x10xf32>) -> tuple<memref<1x1x10xf32>>
+ "xla_lhlo.yield"(%3) : (tuple<memref<1x1x10xf32>>) -> ()
+ }, {
+ ^bb0(%arg3: tuple<memref<1x1x10xf32>, memref<1x1x10xf32>>): // no predecessors
+ %2 = "xla_lhlo.get_tuple_element"(%arg3) {index = 1 : i32} : (tuple<memref<1x1x10xf32>, memref<1x1x10xf32>>) -> memref<1x1x10xf32>
+ %3 = "xla_lhlo.tuple"(%2) : (memref<1x1x10xf32>) -> tuple<memref<1x1x10xf32>>
+ "xla_lhlo.yield"(%3) : (tuple<memref<1x1x10xf32>>) -> ()
+ }) : (memref<i1>, tuple<memref<1x1x10xf32>, memref<1x1x10xf32>>, tuple<memref<1x1x10xf32>, memref<1x1x10xf32>>) -> tuple<memref<1x1x10xf32>>
+ "xla_lhlo.terminator"() : () -> ()
+ }
+ ```
+
+ See https://www.tensorflow.org/xla/operation_semantics#conditional.
+ }];
+
+ // Operands: the predicate, then the tuple handed to the true branch and the
+ // tuple handed to the false branch.
+ let arguments = (ins
+ LHLO_PredBufferOrI1:$pred,
+ LHLO_TupleOfBufferOrIntOrFP:$true_arg,
+ LHLO_TupleOfBufferOrIntOrFP:$false_arg
+ );
+
+ let regions = (region AnyRegion:$true_branch,
+ AnyRegion:$false_branch);
+
+ let results = (outs Arg<LHLO_TupleOfBufferOrIntOrFP>);
+}
+
def LHLO_MapOp : LHLO_Op<"map", [RecursiveSideEffects, SameOperandsShape]>,
BASE_HLO_MapOp {
let description = [{
diff --git a/mlir/test/Dialect/LHLO/lhlo_ops.mlir b/mlir/test/Dialect/LHLO/lhlo_ops.mlir
index 30a84c5b4bcb..59ccfc8113be 100644
--- a/mlir/test/Dialect/LHLO/lhlo_ops.mlir
+++ b/mlir/test/Dialect/LHLO/lhlo_ops.mlir
@@ -235,3 +235,40 @@ func @while_op(%arg0: memref<4x?x16xf32>, %arg1: memref<4x?x16xf32>) {
}) : (tuple<i32, memref<4xi32>>) -> tuple<i32, memref<4xi32>>
"xla_lhlo.terminator"() : () -> ()
}
+
+// -----
+// CHECK-LABEL: func @lhlo_if(
+func @lhlo_if(%arg0: memref<1x1x10xf32>, %arg1: memref<1x1x10xf32>, %arg2: memref<i1>) {
+ %0 = "xla_lhlo.tuple"(%arg0, %arg1) : (memref<1x1x10xf32>, memref<1x1x10xf32>) -> tuple<memref<1x1x10xf32>, memref<1x1x10xf32>>
+ // CHECK: xla_lhlo.if
+ %1 = "xla_lhlo.if"(%arg2, %0, %0) ( {
+ ^bb0(%arg3: tuple<memref<1x1x10xf32>, memref<1x1x10xf32>>):
+ %2 = "xla_lhlo.get_tuple_element"(%arg3) {index = 0 : i32} : (tuple<memref<1x1x10xf32>, memref<1x1x10xf32>>) -> memref<1x1x10xf32>
+ %3 = "xla_lhlo.tuple"(%2) : (memref<1x1x10xf32>) -> tuple<memref<1x1x10xf32>>
+ "xla_lhlo.yield"(%3) : (tuple<memref<1x1x10xf32>>) -> ()
+ }, {
+ ^bb0(%arg3: tuple<memref<1x1x10xf32>, memref<1x1x10xf32>>): // no predecessors
+ %2 = "xla_lhlo.get_tuple_element"(%arg3) {index = 1 : i32} : (tuple<memref<1x1x10xf32>, memref<1x1x10xf32>>) -> memref<1x1x10xf32>
+ %3 = "xla_lhlo.tuple"(%2) : (memref<1x1x10xf32>) -> tuple<memref<1x1x10xf32>>
+ "xla_lhlo.yield"(%3) : (tuple<memref<1x1x10xf32>>) -> ()
+ }) : (memref<i1>, tuple<memref<1x1x10xf32>, memref<1x1x10xf32>>, tuple<memref<1x1x10xf32>, memref<1x1x10xf32>>) -> tuple<memref<1x1x10xf32>>
+ "xla_lhlo.terminator"() : () -> ()
+}
+
+// CHECK-LABEL: func @lhlo_if_empty_arg
+func @lhlo_if_empty_arg(%arg0: memref<i1>) {
+ %cst = constant 1.000000e+00 : f32
+ %cst_0 = constant 0.000000e+00 : f32
+ %0 = "xla_lhlo.tuple"() : () -> tuple<> // empty tuple, used as the argument for both branches
+ // CHECK: xla_lhlo.if
+ %1 = "xla_lhlo.if"(%arg0, %0, %0) ( {
+ ^bb0(%arg1: tuple<>):
+ %2 = "xla_lhlo.tuple"(%cst, %cst_0) : (f32, f32) -> tuple<f32, f32> // first region yields (1.0, 0.0)
+ "xla_lhlo.yield"(%2) : (tuple<f32, f32>) -> ()
+ }, {
+ ^bb0(%arg1: tuple<>):
+ %2 = "xla_lhlo.tuple"(%cst_0, %cst) : (f32, f32) -> tuple<f32, f32> // second region yields (0.0, 1.0)
+ "xla_lhlo.yield"(%2) : (tuple<f32, f32>) -> ()
+ }) : (memref<i1>, tuple<>, tuple<>) -> tuple<f32, f32>
+ "xla_lhlo.terminator"() : () -> ()
+}
More information about the llvm-branch-commits
mailing list