[llvm-branch-commits] [mlir] c6b7f22 - [MLIR] Add xla_lhlo dialect from tensorflow

Uday Bondhugula via llvm-branch-commits llvm-branch-commits at lists.llvm.org
Fri Nov 5 03:30:08 PDT 2021


Author: Uday Bondhugula
Date: 2021-09-22T14:23:46+05:30
New Revision: c6b7f22240229f69af9ebb75d4b7a9b5f1dd0da1

URL: https://github.com/llvm/llvm-project/commit/c6b7f22240229f69af9ebb75d4b7a9b5f1dd0da1
DIFF: https://github.com/llvm/llvm-project/commit/c6b7f22240229f69af9ebb75d4b7a9b5f1dd0da1.diff

LOG: [MLIR] Add xla_lhlo dialect from tensorflow

Add xla_lhlo dialect from tensorflow (the version used by the right
branch of the private repo). This commit brings in only the LHLO ops
part of it, so LHLO transforms and any conversions out of LHLO are not
included. This is mainly to allow the use of xla_lhlo dialect ops
along with MLIR: since the xla_lhlo dialect is registered, users can
work with LHLO ops (and make use of their accessors). This is a
temporary arrangement until we can migrate to depending on mlir-hlo;
depending on tensorflow just for xla_lhlo isn't really feasible.

Added: 
    mlir/include/mlir/Dialect/LHLO/CMakeLists.txt
    mlir/include/mlir/Dialect/LHLO/IR/CMakeLists.txt
    mlir/include/mlir/Dialect/LHLO/IR/HLOOpsBase.td
    mlir/include/mlir/Dialect/LHLO/IR/LHLOOps.h
    mlir/include/mlir/Dialect/LHLO/IR/LHLOOps.td
    mlir/lib/Dialect/LHLO/CMakeLists.txt
    mlir/lib/Dialect/LHLO/IR/CMakeLists.txt
    mlir/lib/Dialect/LHLO/IR/LHLOOps.cc
    mlir/test/Dialect/LHLO/lhlo_ops.mlir

Modified: 
    mlir/include/mlir/Dialect/CMakeLists.txt
    mlir/include/mlir/InitAllDialects.h
    mlir/lib/Dialect/CMakeLists.txt

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/CMakeLists.txt b/mlir/include/mlir/Dialect/CMakeLists.txt
index 44a9249cef839..60e31ed7cac55 100644
--- a/mlir/include/mlir/Dialect/CMakeLists.txt
+++ b/mlir/include/mlir/Dialect/CMakeLists.txt
@@ -9,6 +9,7 @@ add_subdirectory(EmitC)
 add_subdirectory(GPU)
 add_subdirectory(Math)
 add_subdirectory(Linalg)
+add_subdirectory(LHLO)
 add_subdirectory(LLVMIR)
 add_subdirectory(MemRef)
 add_subdirectory(OpenACC)

diff  --git a/mlir/include/mlir/Dialect/LHLO/CMakeLists.txt b/mlir/include/mlir/Dialect/LHLO/CMakeLists.txt
new file mode 100644
index 0000000000000..f33061b2d87cf
--- /dev/null
+++ b/mlir/include/mlir/Dialect/LHLO/CMakeLists.txt
@@ -0,0 +1 @@
+add_subdirectory(IR)

diff  --git a/mlir/include/mlir/Dialect/LHLO/IR/CMakeLists.txt b/mlir/include/mlir/Dialect/LHLO/IR/CMakeLists.txt
new file mode 100644
index 0000000000000..f5c4548ccf534
--- /dev/null
+++ b/mlir/include/mlir/Dialect/LHLO/IR/CMakeLists.txt
@@ -0,0 +1,8 @@
+add_mlir_dialect(LHLOOps xla_lhlo)
+
+set(LLVM_TARGET_DEFINITIONS LHLOOps.td)
+mlir_tablegen(LHLOStructs.h.inc -gen-struct-attr-decls)
+mlir_tablegen(LHLOStructs.cpp.inc -gen-struct-attr-defs)
+add_public_tablegen_target(MLIRLHLOStructsGen)
+
+add_mlir_doc(LHLOOps -gen-op-doc LHLOOps Dialects/)

diff  --git a/mlir/include/mlir/Dialect/LHLO/IR/HLOOpsBase.td b/mlir/include/mlir/Dialect/LHLO/IR/HLOOpsBase.td
new file mode 100644
index 0000000000000..9312ed9e96f9f
--- /dev/null
+++ b/mlir/include/mlir/Dialect/LHLO/IR/HLOOpsBase.td
@@ -0,0 +1,1260 @@
+/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef HLO_OPS_BASE
+#define HLO_OPS_BASE
+
+include "mlir/IR/OpBase.td"
+
+def HLO_Pred : TypeAlias<I1, "pred (AKA boolean or 1-bit integer)">;
+
+// TODO(hinsu): Use signed integers instead of signless integer which is being
+// used for legacy reasons.
+def HLO_SInt : SignlessIntOfWidths<[8, 16, 32, 64]>;
+def HLO_UInt : UnsignedIntOfWidths<[8, 16, 32, 64]>;
+def HLO_Int : AnyTypeOf<[HLO_SInt, HLO_UInt]>;
+def HLO_I1 : SignlessIntOfWidths<[1]>;
+
+def HLO_Complex : Complex<AnyTypeOf<[F32, F64]>>;
+
+// The broadcasting dimensions correspond to a tuple that describes how a
+// smaller rank shape is broadcast into a larger rank shape. For example,
+// given a 2x3x4 cuboid and a 3x4 matrix, a broadcasting tuple (1,2) means
+// matching the matrix to dimensions 1 and 2 of the cuboid.
+defvar BroadcastDimAttr = I64ElementsAttr;
+
+//===----------------------------------------------------------------------===//
+// XLA on tensors type definitions.
+//===----------------------------------------------------------------------===//
+
+// Token type.
+def HLO_Token : Type<CPred<"$_self.isa<TokenType>()">, "token">;
+
+// Any integer tensor types
+def HLO_IntTensor : TensorOf<[HLO_Int]>;
+
+// Any integer tensor type with rank 0 (i.e. representing a single integer).
+def HLO_ScalarIntTensor : ShapedContainerType<
+  [HLO_Int], And<[IsTensorTypePred, HasAnyRankOfPred<[0]>]>,
+  "a 0-dim integer tensor">;
+
+// Any floating-point tensor types
+def HLO_FpTensor : TensorOf<[AnyFloat]>;
+
+def HLO_PredTensor : TensorOf<[HLO_Pred]>;
+
+def HLO_Tensor : TensorOf<[AnyFloat, HLO_Pred, HLO_Int, HLO_Complex]>;
+
+def HLO_ComplexTensor : TensorOf<[HLO_Complex]>;
+
+def HLO_Tuple : NestedTupleOf<[HLO_Tensor, HLO_Token]>;
+
+def HLO_TensorOrTuple : AnyTypeOf<[HLO_Tensor, HLO_Tuple]>;
+
+def HLO_DimensionValue : AnyTypeOf<[Index, HLO_Pred, HLO_Int]>;
+
+// Dynamic representation of a shape vector as a tensor.
+def HLO_DimensionTensor : ShapedContainerType<
+    [HLO_DimensionValue],
+    And<[IsTensorTypePred, HasAnyRankOfPred<[1]>]>,
+    "a 1D tensor of dimensions">;
+
+// In general, static shaped tensor constraints should be avoided unless
+// it is for a legacy op which is only correct with static shapes.
+def HLO_StaticShapeTensor : StaticShapeTensorOf<[
+      AnyFloat, HLO_Pred, HLO_Int, HLO_Complex]>;
+
+//===----------------------------------------------------------------------===//
+// XLA on tensors combined type definitions.
+//===----------------------------------------------------------------------===//
+
+// Any integer or floating-point tensor types
+def HLO_IntOrFpTensor : TensorOf<[HLO_Int, AnyFloat]>;
+
+// Any integer or predicate tensor types
+def HLO_PredOrIntTensor : TensorOf<[HLO_Pred, HLO_Int]>;
+
+// Any floating-point or complex tensor types
+def HLO_FpOrComplexTensor : TensorOf<[AnyFloat, HLO_Complex]>;
+
+// Any int, floating-point or complex tensor types
+def HLO_IntFpOrComplexTensor : TensorOf<[HLO_Int, AnyFloat, HLO_Complex]>;
+
+// Any pred, int or floating-point tensor types
+def HLO_PredIntOrFpTensor : TensorOf<[HLO_Pred, HLO_Int, AnyFloat]>;
+
+//===----------------------------------------------------------------------===//
+// XLA nullary op definitions.
+//===----------------------------------------------------------------------===//
+
+class BASE_HLO_ConstOp {
+  string summary = "Constant operator";
+
+  string description = [{
+    Represents a constant value.
+  }];
+}
+
+class BASE_HLO_IotaOp {
+  string summary = "Iota operator";
+
+  string description = [{
+    Creates a rank 1 array of values starting at zero and incrementing by one.
+  }];
+}
+
+//===----------------------------------------------------------------------===//
+// XLA unary elementwise op definitions.
+//===----------------------------------------------------------------------===//
+// See https://www.tensorflow.org/xla/operation_semantics#element-wise_unary_functions
+
+class BASE_HLO_AbsOp {
+  string summary = "Absolute value operator";
+
+  string description = [{
+    Returns `abs(operand)` element-wise.
+
+    See
+    https://www.tensorflow.org/xla/operation_semantics#element-wise_unary_functions.
+  }];
+}
+
+class BASE_HLO_CeilOp {
+  string summary = "Ceil operator";
+
+  string description = [{
+    Returns `Ceil(operand)` element-wise.
+
+    See
+    https://www.tensorflow.org/xla/operation_semantics#element-wise_unary_functions.
+  }];
+}
+
+class BASE_HLO_ClzOp {
+  string summary = "Count-leading-zeros (Clz) operator";
+
+  string description = [{
+    Returns the number of leading zeros in each operand element-wise.
+
+    See
+    https://www.tensorflow.org/xla/operation_semantics#element-wise_unary_functions.
+  }];
+}
+
+class BASE_HLO_ConvertOp {
+  string summary = "Convert operator";
+
+  string description = [{
+    Performs element-wise conversion of values from one type to another, e.g.
+    float to int.
+
+    See https://www.tensorflow.org/xla/operation_semantics#convertelementtype.
+  }];
+}
+
+class BASE_HLO_CosOp {
+  string summary = "Cos operator";
+
+  string description = [{
+    Returns `Cos(operand)` element-wise.
+
+    See
+    https://www.tensorflow.org/xla/operation_semantics#element-wise_unary_functions.
+  }];
+}
+
+class BASE_HLO_ExpOp {
+  string summary = "Exponential operator";
+
+  string description = [{
+    Returns `e^(operand)` element-wise.
+
+    See
+    https://www.tensorflow.org/xla/operation_semantics#element-wise_unary_functions.
+  }];
+}
+
+class BASE_HLO_Expm1Op {
+  string summary = "Exponential minus one operator";
+
+  string description = [{
+    Returns `e^(operand) - 1` element-wise.
+
+    See
+    https://www.tensorflow.org/xla/operation_semantics#element-wise_unary_functions.
+  }];
+}
+
+class BASE_HLO_FloorOp {
+  string summary = "Floor operator";
+
+  string description = [{
+    Returns `Floor(operand)` element-wise.
+
+    See
+    https://www.tensorflow.org/xla/operation_semantics#element-wise_unary_functions.
+  }];
+}
+
+class BASE_HLO_GetDimensionSizeOp {
+  string summary = "GetDimensionSize operator";
+
+  string description = [{
+    Returns the size of the given dimension of the operand.
+
+    See
+    https://www.tensorflow.org/xla/operation_semantics#getdimensionsize.
+  }];
+}
+
+class BASE_HLO_ImagOp {
+  string summary = "Imag operator";
+
+  string description = [{
+    Returns `Imag(operand)` element-wise.
+  }];
+}
+
+class BASE_HLO_IsFiniteOp {
+  string summary = "IsFinite operator";
+
+  string description = [{
+    Tests whether each element of operand is finite, i.e., is not positive or
+    negative infinity, and is not NaN. Returns a tensor of 1-bit integers with
+    the same shape as the input, where each element is nonzero (i.e. true) if
+    and only if the corresponding input element is finite.
+
+    See
+    https://www.tensorflow.org/xla/operation_semantics#element-wise_unary_functions.
+  }];
+}
+
+class BASE_HLO_LogOp {
+  string summary = "Logarithm operator";
+
+  string description = [{
+    Returns `log(operand)` element-wise.
+
+    See
+    https://www.tensorflow.org/xla/operation_semantics#element-wise_unary_functions.
+  }];
+}
+
+class BASE_HLO_Log1pOp {
+  string summary = "Log1p operator";
+
+  string description = [{
+    Returns `log(operand+1)` element-wise.
+
+    See
+    https://www.tensorflow.org/xla/operation_semantics#element-wise_unary_functions.
+  }];
+}
+
+class BASE_HLO_NegOp {
+  string summary = "Negation operator";
+
+  string description = [{
+    Returns `-operand` element-wise.
+
+    See
+    https://www.tensorflow.org/xla/operation_semantics#element-wise_unary_functions.
+  }];
+}
+
+class BASE_HLO_NotOp {
+  string summary = "Not operator";
+
+  string description = [{
+    Returns `!operand` element-wise.
+
+    See
+    https://www.tensorflow.org/xla/operation_semantics#element-wise_unary_functions.
+  }];
+}
+
+class BASE_HLO_PopulationCountOp {
+  string summary = "PopulationCount operator";
+
+  string description = [{
+    Returns the number of bits set in each operand element-wise.
+
+    See
+    https://www.tensorflow.org/xla/operation_semantics#element-wise_unary_functions.
+  }];
+}
+
+class BASE_HLO_RealOp {
+  string summary = "Real operator";
+
+  string description = [{
+    Returns `Real(operand)` element-wise.
+  }];
+}
+
+class BASE_HLO_RoundOp {
+  string summary = "Round operator";
+
+  string description = [{
+    Returns `Round(operand)` element-wise, rounding to nearest integer with
+    half-way cases rounding away from zero.
+
+    See
+    https://www.tensorflow.org/xla/operation_semantics#element-wise_unary_functions.
+  }];
+}
+
+class BASE_HLO_RsqrtOp {
+  string summary = "Reciprocal Square-root operator";
+
+  string description = [{
+    Returns `1.0 / sqrt(operand)` element-wise.
+
+    See
+    https://www.tensorflow.org/xla/operation_semantics#element-wise_unary_functions.
+  }];
+}
+
+class BASE_HLO_SignOp {
+  string summary = "Sign operator";
+
+  string description = [{
+    Returns `sign(operand)` element-wise, where
+
+    ```
+    sign(x) = -1  : x < 0
+            = -0  : x = -0
+            = NaN : x = NaN
+            = +0  : x = +0
+            = 1   : x > 0
+    ```
+
+    See
+    https://www.tensorflow.org/xla/operation_semantics#element-wise_unary_functions.
+  }];
+}
+
+class BASE_HLO_SinOp {
+  string summary = "Sin operator";
+
+  string description = [{
+    Returns `Sin(operand)` element-wise.
+
+    See
+    https://www.tensorflow.org/xla/operation_semantics#element-wise_unary_functions.
+  }];
+}
+
+class BASE_HLO_SqrtOp {
+  string summary = "Square-root operator";
+
+  string description = [{
+    Returns `sqrt(operand)` element-wise.
+
+    See
+    https://www.tensorflow.org/xla/operation_semantics#element-wise_unary_functions.
+  }];
+}
+
+class BASE_HLO_TanhOp {
+  string summary = "Tanh operator";
+
+  string description = [{
+    Returns `tanh(operand)` element-wise.
+
+    See
+    https://www.tensorflow.org/xla/operation_semantics#element-wise_unary_functions.
+  }];
+}
+
+//===----------------------------------------------------------------------===//
+// XLA binary elementwise op definitions.
+//===----------------------------------------------------------------------===//
+
+class BASE_HLO_AddOp {
+  string summary = "Addition operator";
+
+  string description = [{
+    Returns `lhs + rhs` element-wise.
+
+    See
+    https://www.tensorflow.org/xla/operation_semantics#element-wise_binary_arithmetic_operations.
+  }];
+}
+
+class BASE_HLO_ComplexOp {
+  string summary = "Complex operator";
+
+  string description = [{
+    Performs element-wise conversion of a pair of real and imaginary values to
+    a complex value.
+  }];
+}
+
+class BASE_HLO_DivOp {
+  string summary = "Division operator";
+
+  string description = [{
+    Returns `lhs / rhs` element-wise.
+
+    See
+    https://www.tensorflow.org/xla/operation_semantics#element-wise_binary_arithmetic_operations.
+  }];
+}
+
+class BASE_HLO_MaxOp {
+  string summary = "Maximum operator";
+
+  string description = [{
+    Returns `max(lhs, rhs)` element-wise.
+
+    See
+    https://www.tensorflow.org/xla/operation_semantics#element-wise_binary_arithmetic_operations.
+  }];
+}
+
+class BASE_HLO_MinOp {
+  string summary = "Minimum operator";
+
+  string description = [{
+    Returns `min(lhs, rhs)` element-wise.
+
+    See
+    https://www.tensorflow.org/xla/operation_semantics#element-wise_binary_arithmetic_operations.
+  }];
+}
+
+class BASE_HLO_MulOp {
+  string summary = "Multiplication operator";
+
+  string description = [{
+    Returns `lhs * rhs` element-wise.
+
+    See
+    https://www.tensorflow.org/xla/operation_semantics#element-wise_binary_arithmetic_operations.
+  }];
+}
+class BASE_HLO_PowOp {
+  string summary = "Power operator";
+
+  string description = [{
+    Returns `lhs ^ rhs` element-wise.
+
+    See
+    https://www.tensorflow.org/xla/operation_semantics#element-wise_binary_arithmetic_operations.
+  }];
+}
+
+class BASE_HLO_RemOp {
+  string summary = "Remainder operator";
+
+  string description = [{
+    Returns `lhs % rhs` element-wise.
+
+    See
+    https://www.tensorflow.org/xla/operation_semantics#element-wise_binary_arithmetic_operations.
+  }];
+}
+
+class BASE_HLO_SubOp {
+  string summary = "Subtraction operator";
+
+  string description = [{
+    Returns `lhs - rhs` element-wise.
+
+    See
+    https://www.tensorflow.org/xla/operation_semantics#element-wise_binary_arithmetic_operations.
+  }];
+}
+
+class BASE_HLO_ShiftLeftOp {
+  string summary = "Shift Left operator";
+
+  string description = [{
+    Returns `lhs << rhs` element-wise.
+
+    See
+    https://www.tensorflow.org/xla/operation_semantics#element-wise_binary_arithmetic_operations.
+  }];
+}
+
+class BASE_HLO_ShiftRightArithmeticOp {
+  string summary = "Shift right arithmetic operator";
+
+  string description = [{
+    Returns arithmetic `lhs >> rhs` element-wise.
+
+    See
+    https://www.tensorflow.org/xla/operation_semantics#element-wise_binary_arithmetic_operations.
+  }];
+}
+
+class BASE_HLO_ShiftRightLogicalOp {
+  string summary = "Shift right logical operator";
+
+  string description = [{
+    Returns logical `lhs >> rhs` element-wise.
+
+    See
+    https://www.tensorflow.org/xla/operation_semantics#element-wise_binary_arithmetic_operations.
+  }];
+}
+
+class BASE_HLO_Atan2Op {
+  string summary = "Atan2 operator";
+
+  string description = [{
+    Returns `atan2(lhs, rhs)` element-wise.
+
+    See
+    https://www.tensorflow.org/xla/operation_semantics#element-wise_binary_arithmetic_operations.
+  }];
+}
+
+class BASE_HLO_AndOp {
+  string summary = "Logical and";
+
+  string description = [{
+    Returns `logical_and(lhs, rhs)` element-wise.
+
+    See
+    https://www.tensorflow.org/xla/operation_semantics#element-wise_binary_arithmetic_operations.
+  }];
+}
+
+class BASE_HLO_OrOp {
+  string summary = "Logical or";
+
+  string description = [{
+    Returns `logical_or(lhs, rhs)` element-wise.
+
+    See
+    https://www.tensorflow.org/xla/operation_semantics#element-wise_binary_arithmetic_operations.
+  }];
+}
+
+class BASE_HLO_XorOp {
+  string summary = "Logical xor";
+
+  string description = [{
+    Returns `logical_xor(lhs, rhs)` element-wise.
+
+    See
+    https://www.tensorflow.org/xla/operation_semantics#element-wise_binary_arithmetic_operations.
+  }];
+}
+
+//===----------------------------------------------------------------------===//
+// XLA control flow related op definitions.
+//===----------------------------------------------------------------------===//
+
+class BASE_HLO_CaseOp {
+  string summary = "Switch-Case operator";
+
+  string description = [{
+    Returns the result of executing `branches[index]`. If
+    `index` is < 0 or >= N, then `branches[N-1]` is executed as
+    the default branch.
+
+    Each branch `branches[b]` must take in a single argument of same type as
+    `branch_operands[b]` and will be invoked with `branch_operands[b]`. The type
+    of the returned value of each branch must be the same.
+
+    Note that only one of the branches will be executed depending on the value
+    of index.
+    See https://www.tensorflow.org/xla/operation_semantics#conditional.
+  }];
+
+}
+
+//===----------------------------------------------------------------------===//
+// XLA parallelism related op definitions.
+//===----------------------------------------------------------------------===//
+
+class BASE_HLO_ReplicaIdOp {
+  string summary = "ReplicaId operator";
+
+  string description = [{
+    Returns the unique ID (int32 scalar) of the replica.
+
+    The unique ID of each replica is an unsigned integer in the interval [0, N),
+    where N is the number of replicas. Since all the replicas are running the
+    same program, a ReplicaId() call in the program will return a different
+    value on each replica.
+
+    See https://www.tensorflow.org/xla/operation_semantics#replicaid.
+  }];
+}
+
+
+class BASE_HLO_AllReduceOp {
+  string summary = "AllReduce operator";
+
+  string description = [{
+    Performs a custom reduction across replicas.
+
+    See https://www.tensorflow.org/xla/operation_semantics#allreduce.
+  }];
+}
+
+class BASE_HLO_ReduceOp {
+  string summary = "Reduce operator";
+
+  string description = [{
+    Returns the result of executing a reduction function on one or more arrays
+    in parallel.
+
+    See https://www.tensorflow.org/xla/operation_semantics#reduce.
+  }];
+}
+
+class BASE_HLO_ReduceWindowOp {
+  string summary = "ReduceWindow operator";
+
+  string description = [{
+    Returns the result of executing a reduction function over all elements in
+    each window of one or more arrays in parallel.
+
+    See https://www.tensorflow.org/xla/operation_semantics#reducewindow.
+  }];
+}
+
+//===----------------------------------------------------------------------===//
+// XLA tuple op definitions.
+//===----------------------------------------------------------------------===//
+class BASE_HLO_GetTupleElementOp {
+  string summary = "GetTupleElement operator";
+
+  string description = [{
+    Returns a member of a tuple specified by an index.
+
+    See https://www.tensorflow.org/xla/operation_semantics#gettupleelement.
+  }];
+}
+
+class BASE_HLO_TupleOp {
+   string summary = "XLA's tuple op";
+
+   string description = [{
+     Groups a set of tensor inputs into a single tuple object.
+
+     See https://www.tensorflow.org/xla/operation_semantics#tuple.
+   }];
+}
+
+//===----------------------------------------------------------------------===//
+// Precision Config enum definitions.
+//===----------------------------------------------------------------------===//
+
+// These mirror the XLA PrecisionConfig proto enum.
+def HLO_PRECISION_DEFAULT : StrEnumAttrCase<"DEFAULT">;
+def HLO_PRECISION_HIGH    : StrEnumAttrCase<"HIGH">;
+def HLO_PRECISION_HIGHEST : StrEnumAttrCase<"HIGHEST">;
+
+def HLO_PrecisionAttr : StrEnumAttr<"Precision",
+    "XLA precision for an operand. Has backend specific meaning.",
+    [HLO_PRECISION_DEFAULT,  HLO_PRECISION_HIGH, HLO_PRECISION_HIGHEST]>;
+
+// TODO(b/129153247) See if it's possible to also validate the size.
+def HLO_PrecisionConfigAttr:
+    OptionalAttr<
+          TypedArrayAttrBase<HLO_PrecisionAttr, "Precision Config attribute">>;
+
+//===----------------------------------------------------------------------===//
+// Fast Fourier Transform Type enum definitions.
+//===----------------------------------------------------------------------===//
+
+// These mirror the XLA FftType proto enum.
+def HLO_FFT_TYPE_FFT : StrEnumAttrCase<"FFT">;
+def HLO_FFT_TYPE_IFFT : StrEnumAttrCase<"IFFT">;
+def HLO_FFT_TYPE_RFFT : StrEnumAttrCase<"RFFT">;
+def HLO_FFT_TYPE_IRFFT : StrEnumAttrCase<"IRFFT">;
+
+def HLO_FftTypeAttr : StrEnumAttr<"FftType",
+    "XLA fast fourier transform type.",
+    [HLO_FFT_TYPE_FFT, HLO_FFT_TYPE_IFFT,
+     HLO_FFT_TYPE_RFFT, HLO_FFT_TYPE_IRFFT]>;
+
+//===----------------------------------------------------------------------===//
+// Comparison op definitions.
+//===----------------------------------------------------------------------===//
+
+// These mirror the XLA ComparisonDirection enum.
+def HLO_COMPARISON_DIRECTION_EQ : StrEnumAttrCase<"EQ">;
+def HLO_COMPARISON_DIRECTION_NE : StrEnumAttrCase<"NE">;
+def HLO_COMPARISON_DIRECTION_GE : StrEnumAttrCase<"GE">;
+def HLO_COMPARISON_DIRECTION_GT : StrEnumAttrCase<"GT">;
+def HLO_COMPARISON_DIRECTION_LE : StrEnumAttrCase<"LE">;
+def HLO_COMPARISON_DIRECTION_LT : StrEnumAttrCase<"LT">;
+
+def HLO_ComparisonDirectionAttr : StrEnumAttr<"ComparisonDirection",
+    "Which comparison operation to perform.",
+    [
+      HLO_COMPARISON_DIRECTION_EQ,
+      HLO_COMPARISON_DIRECTION_NE,
+      HLO_COMPARISON_DIRECTION_GE,
+      HLO_COMPARISON_DIRECTION_GT,
+      HLO_COMPARISON_DIRECTION_LE,
+      HLO_COMPARISON_DIRECTION_LT
+    ]>;
+
+class BASE_HLO_CompareOp {
+  string summary = "Comparison operator";
+
+  string description = [{
+    Compares `lhs` and `rhs` elementwise according to `comparison_direction`.
+
+    See
+    https://www.tensorflow.org/xla/operation_semantics#element-wise_comparison_operations.
+  }];
+}
+
+//===----------------------------------------------------------------------===//
+// Quantize op definitions.
+//===----------------------------------------------------------------------===//
+
+// Dequantization mode cases (only MIN_COMBINED is currently supported).
+def HLO_MIN_COMBINED : StrEnumAttrCase<"MIN_COMBINED">;
+
+def HLO_DequantizeModeAttr : StrEnumAttr<"DequantizeMode",
+  "Dequantization mode. Only MIN_COMBINED is supported.",
+  [HLO_MIN_COMBINED]>;
+
+class BASE_HLO_DequantizeOp {
+  string summary = "Dequantize operator";
+
+  string description = [{
+    Dequantize the quantized input of packed uint32 to bfloat16. Only uint8 or
+    uint16 is supported for the original unpacked input.
+
+    Returns a tensor of shape [d0,..., dn * unpack_size] if unpacked input shape
+    is [d0, ..., dn], where unpack_size = sizeof(uint32) / sizeof(T), where T is
+    the unpacked input type. If transpose_output is true, will return a tensor
+    of shape [dn * unpack_size, dn-1, ..., d1, d0]. transpose_output is faster
+    when the input's rank is higher than 1. The input needs to be transposed to
+    use the transpose_output feature.
+  }];
+}
+
+//===----------------------------------------------------------------------===//
+// XLA Slice definitions.
+//===----------------------------------------------------------------------===//
+
+class BASE_HLO_SliceOp {
+  string summary = "Slice operator";
+
+  string description = [{
+    Slices a portion of the `operand` into a new configuration.
+
+    See https://www.tensorflow.org/xla/operation_semantics#slice.
+  }];
+}
+
+class BASE_HLO_DynamicSliceOp {
+  string summary = "Dynamic Slice operator";
+
+  string description = [{
+    Extracts a sub-array from the input array at dynamic start_indices.
+
+    See https://www.tensorflow.org/xla/operation_semantics#dynamicslice.
+  }];
+}
+
+class BASE_HLO_DynamicUpdateSliceOp {
+  string summary = "Dynamic Update Slice operator";
+
+  string description = [{
+    DynamicUpdateSlice generates a result which is the value of the input array
+    operand, with a slice update overwritten at start_indices.
+
+    See https://www.tensorflow.org/xla/operation_semantics#dynamicupdateslice.
+  }];
+}
+
+//===----------------------------------------------------------------------===//
+// XLA Other op definitions.
+//===----------------------------------------------------------------------===//
+
+class BASE_HLO_AllToAllOp {
+  string summary = "AllToAll";
+
+  string description = [{
+    AllToAll is a collective operation that sends data from all cores to all
+    cores. It has two phases:
+    - The scatter phase. On each core, the operand is split into `split_count`
+      number of blocks along the `split_dimension`, and the blocks are
+      scattered to all cores, e.g., the i-th block is sent to the i-th core.
+    - The gather phase. Each core concatenates the received blocks along the
+      `concat_dimension`.
+
+    The participating cores can be configured by:
+    - replica_groups: each ReplicaGroup contains a list of replica id
+      participating in the computation (replica id for the current replica can
+      be retrieved using ReplicaId op). AllToAll will be applied within
+      subgroups in the specified order. For example,
+      `replica_groups` = {{1,2,3}, {4,5,0}} means that an AllToAll will be applied
+      within replicas {1, 2, 3}, and in the gather phase, the received blocks
+      will be concatenated in the same order of 1, 2, 3. Then, another AllToAll
+      will be applied within replicas 4, 5, 0, and the concatenation order is
+      also 4, 5, 0. If `replica_groups` is empty, all replicas belong to one
+      group, and the concatenation order is the numerical order (0, 1, 2, ...).
+
+    Prerequisites:
+    - The dimension size of the operand on the split_dimension is divisible by
+      `split_count`.
+    - The operand's shape is not tuple.
+
+    See https://www.tensorflow.org/xla/operation_semantics#alltoall
+  }];
+}
+
+class BASE_HLO_BatchNormGradOp {
+  string summary = "Batch Normalization Gradient";
+
+  string description = [{
+    Calculates gradients of batch norm.
+
+    See https://www.tensorflow.org/xla/operation_semantics#batchnormgrad
+  }];
+}
+
+class BASE_HLO_BatchNormInferenceOp {
+  string summary = "Batch Normalization for Inference";
+
+  string description = [{
+    Normalizes an array across batch and spatial dimensions.
+
+    See https://www.tensorflow.org/xla/operation_semantics#batchnorminference
+  }];
+}
+
+class BASE_HLO_BatchNormTrainingOp {
+  string summary = "Batch Normalization for Training";
+
+  string description = [{
+    Normalizes an array across batch and spatial dimensions.
+
+    See https://www.tensorflow.org/xla/operation_semantics#batchnormtraining
+  }];
+}
+
+class BASE_HLO_BitcastConvertOp {
+  string summary = "BitcastConvert operator";
+
+  string description = [{
+    Similar to a 'tf.bitcast' in TensorFlow, performs an element-wise bitcast
+    operation from a data shape to a target shape. The dimensions must match,
+    and the conversion is an element-wise one. Bitcast is implemented as a
+    low-level cast, so machines with different floating-point representations
+    will give different results.
+
+    See https://www.tensorflow.org/xla/operation_semantics#bitcastconverttype.
+  }];
+}
+
+class BASE_HLO_BroadcastOp  {
+  string summary = "Broadcast a tensor to a higher rank by prepending dimensions";
+
+  string description = [{
+    Broadcasts the operand tensor to a higher rank by prepending
+    `broadcast_sizes` to the dimensions. The current values of the operand are
+    copied into the other dimensions.
+
+    This is a more limited form of broadcasting, that corresponds to the XLA
+    client Broadcast method. For a more general form of broadcasting, see the
+    BroadcastInDimOp.
+
+    See https://www.tensorflow.org/xla/operation_semantics#broadcast.
+  }];
+}
+
+class BASE_HLO_BroadcastInDimOp  {
+  string summary = "Broadcast a tensor into the given shape by adding dimensions.";
+
+  string description = [{
+    Broadcasts the `operand` tensor to a higher rank. This is not the limited
+    form of broadcasting exposed as the XLA client broadcast op, but rather the
+    more powerful "InDim" broadcasting, which is closer to the HLO broadcast op
+    and exposed in the XLA client BroadcastInDim method.
+
+    `broadcast_dimensions` maps the operand dimension number to the target shape
+    dimension number. It must have the same size as the rank of the operand. The
+    mapped dimensions must either be the same size or the dimension being
+    broadcast from must be size 1 (degenerate broadcasting).
+
+    For a scalar (0D tensor) operand, `broadcast_dimensions` must be empty.
+    The scalar value will be broadcast to every element in the target shape.
+
+    See https://www.tensorflow.org/xla/broadcasting.
+  }];
+}
+
+// Documentation mixin for the Cholesky op: supplies only the `summary` and
+// `description` fields.
+class BASE_HLO_CholeskyOp {
+  string summary = "Cholesky operator";
+
+  string description = [{
+    Computes the Cholesky decomposition of a batch of symmetric (Hermitian)
+    positive definite matrices.
+
+    If lower is true, computes lower-triangular matrices l such that
+    `a = l . Transpose(l)`. If lower is false, computes upper-triangular
+    matrices u such that `a = Transpose(u) . u`.
+
+    Input data is read only from the lower/upper triangle of a, depending on the
+    value of lower. Values from the other triangle are ignored. Output data is
+    returned in the same triangle; the values in the other triangle are
+    implementation-defined and may be anything.
+
+    If the rank of a is greater than 2, a is treated as a batch of matrices,
+    where all except the minor 2 dimensions are batch dimensions.
+
+    If a is not symmetric (Hermitian) positive definite, the result is
+    implementation-defined.
+
+    See https://www.tensorflow.org/xla/operation_semantics#cholesky.
+  }];
+}
+
+// Documentation mixin for the clamp op: supplies only the `summary` and
+// `description` fields.
+class BASE_HLO_ClampOp  {
+  string summary = "Clamp operator";
+
+  string description = [{
+    Clamps an operand to within the range between a minimum and maximum value.
+
+    Note: All three arrays must be the same shape. Alternatively, as a
+          restricted form of broadcasting, min and/or max can be a scalar (0D
+          tensor) of the element type of the tensor operand.
+
+    See https://www.tensorflow.org/xla/operation_semantics#clamp.
+  }];
+}
+
+// Documentation mixin for the collective-permute op: supplies only the
+// `summary` and `description` fields.
+class BASE_HLO_CollectivePermuteOp {
+  string summary = "CollectivePermute operator";
+
+  string description = [{
+    CollectivePermute is a collective operation that sends and receives data
+    across replicas.
+
+    Note that there are the following restrictions on the source_target_pair:
+    - Any two pairs should not have the same target replica id, and they should
+      not have the same source replica id.
+    - If a replica id is not a target in any pair, then the output on that
+      replica is a tensor consisting of 0(s) with the same shape as the input.
+
+    See https://www.tensorflow.org/xla/operation_semantics#collectivepermute.
+  }];
+}
+
+// Documentation mixin for the concatenate op: supplies only the `summary` and
+// `description` fields (indentation aligned with the sibling classes).
+class BASE_HLO_ConcatenateOp {
+  string summary = "XLA's concatenate op";
+
+  string description = [{
+    Concatenates a set of tensors along the specified dimension.
+
+    See https://www.tensorflow.org/xla/operation_semantics#concatenate.
+  }];
+}
+
+// Documentation mixin for the convolution op: `summary` and `description`.
+class BASE_HLO_ConvOp {
+  string summary = "Convolution operator";
+
+  string description = [{
+    Computes a convolution of the kind used in neural networks.
+
+    See https://www.tensorflow.org/xla/operation_semantics#conv_convolution.
+  }];
+}
+
+// Documentation mixin for the copy op: `summary` and `description`.
+class BASE_HLO_CopyOp {
+  string summary = "Copy operator";
+
+  string description = [{
+    Returns a copy of `operand`.
+  }];
+}
+
+// Documentation mixin for the cross-replica sum op: supplies only the
+// `summary` and `description` fields.
+class BASE_HLO_CrossReplicaSumOp {
+  string summary = "Sums input across replicated instances.";
+
+  string description = [{
+    For each of the replica groups, operands of the group devices are summed
+    so that each device has the sum.
+
+    For example, suppose there are 8 TPU devices: `[A, B, C, D, E, F, G, H]`.
+    Passing group_assignment=`[[0,2,4,6],[1,3,5,7]]` sets `A, C, E, G` as
+    group 0, and `B, D, F, H` as group 1. Thus we get the outputs:
+    `[A+C+E+G, B+D+F+H, A+C+E+G, B+D+F+H, A+C+E+G, B+D+F+H, A+C+E+G, B+D+F+H]`.
+
+    See https://www.tensorflow.org/xla/operation_semantics#crossreplicasum.
+  }];
+}
+
+// Documentation mixin for the custom-call op: `summary` and `description`.
+class BASE_HLO_CustomCallOp {
+  string summary = "CustomCall operator";
+
+  string description = [{
+    A custom call invokes code external to XLA. The `args` are passed to the
+    external code, and the external code is expected to produce a result of the
+    given type. The exact mechanism is backend-specific. For example, in the CPU
+    backend, a call instruction is emitted which targets a symbol with the name
+    `call_target_name`.
+
+    `call_target_name` and `backend_config` can be arbitrary strings, but
+    `call_target_name` should be short as it may be used in labels.
+    `backend_config` can encode arbitrarily large amounts of information.
+
+    See https://www.tensorflow.org/xla/operation_semantics#customcall.
+  }];
+}
+
+// Documentation mixin for the dot op: `summary` and `description`.
+class BASE_HLO_DotOp {
+  string summary = "Dot operator";
+  string description = [{
+    Performs dot products between vectors, vector/matrix and matrix/matrix
+    multiplication.
+
+    See https://www.tensorflow.org/xla/operation_semantics#dot.
+  }];
+}
+
+// Documentation mixin for the general dot op: `summary` and `description`.
+class BASE_HLO_DotGeneralOp {
+  string summary = "General Dot operator";
+  string description = [{
+    Performs general dot products between vectors, vector/matrix and
+    matrix/matrix multiplication.
+
+    See https://www.tensorflow.org/xla/operation_semantics#dotgeneral.
+  }];
+}
+
+// Documentation mixin for the FFT op: supplies only the `summary` and
+// `description` fields.
+class BASE_HLO_FftOp {
+  string summary = "Fast Fourier transform operator";
+
+  string description = [{
+    Returns the fast Fourier transform of the input array.
+
+    See https://www.tensorflow.org/xla/operation_semantics#fft.
+  }];
+}
+
+// Documentation mixin for the gather op; it only sets `summary` and
+// `description`.
+class BASE_HLO_GatherOp {
+  string summary = "Gather operator";
+
+  string description = [{
+    Stitches together several slices of an input array.
+
+    See https://www.tensorflow.org/xla/operation_semantics#gather.
+  }];
+}
+
+// Documentation mixin for the map op: supplies only the `summary` and
+// `description` fields.
+class BASE_HLO_MapOp {
+  string summary = "Map operator";
+
+  string description = [{
+    Applies a scalar function over the given operand arrays, producing an array
+    of the same dimensions where each element is the result of the mapped
+    function applied to the corresponding elements in the input arrays.
+
+    The mapped function is an arbitrary computation with the restriction that it
+    has N inputs of scalar type T and a single output with type S. The output
+    has the same dimensions as the operands except that the element type T is
+    replaced with S.
+
+    See https://www.tensorflow.org/xla/operation_semantics#map.
+  }];
+}
+
+// Documentation mixin for the reshape op: `summary` and `description`.
+class BASE_HLO_ReshapeOp {
+  string summary = "Reshape operator";
+
+  string description = [{
+    Reshapes the dimensions of `operand` into a new configuration.
+
+    See https://www.tensorflow.org/xla/operation_semantics#reshape.
+  }];
+}
+
+// Documentation mixin for the scatter op: `summary` and `description`.
+class BASE_HLO_ScatterOp {
+  string summary = "Scatter operator";
+
+  string description = [{
+    Generates a result which is the value of the input array `operand`,
+    with several slices (at indices specified by `scatter_indices`)
+    updated with the values in `updates` using `update_computation`.
+
+    See https://www.tensorflow.org/xla/operation_semantics#scatter.
+  }];
+}
+
+// Documentation mixin for the select op: supplies only the `summary` and
+// `description` fields.
+class BASE_HLO_SelectOp {
+  string summary = "Select operator";
+
+  string description = [{
+    Constructs an output tensor from the elements of `on_true` and `on_false`
+    based on the values of `pred`.
+
+    `pred`, `on_true` and `on_false` must be broadcast compatible.
+
+    See https://www.tensorflow.org/xla/operation_semantics#select.
+  }];
+}
+
+// Documentation mixin for the select-and-scatter op: supplies only the
+// `summary` and `description` fields.
+class BASE_HLO_SelectAndScatterOp {
+  string summary = "SelectAndScatter operator";
+
+  string description = [{
+    Runs a windowed selection `select` function over `operand` with shape
+    `window_dimensions` and stride `window_strides`. This produces a set of
+    selected locations whose shape matches `source`. The elements of `source`
+    are then scattered to those selected locations in the output, which is
+    initialized with `init_value`.
+    Multiple scattered elements which land in the same output location are
+    combined using the `scatter` function.
+
+    See https://www.tensorflow.org/xla/operation_semantics#selectandscatter.
+  }];
+}
+
+// Documentation mixin for the set-dimension-size op: supplies only the
+// `summary` and `description` fields.
+class BASE_HLO_SetDimensionSizeOp {
+  string summary = "SetDimensionSize operator";
+
+  string description = [{
+    Sets the dynamic size of operand's given dimension. Passes the operand
+    through as the result, with the dynamic dimension tracked by the compiler.
+    Padded values will be ignored by downstream reduction ops.
+
+    See https://www.tensorflow.org/xla/operation_semantics#setdimensionsize.
+  }];
+}
+
+// Documentation mixin for the sort op: `summary` and `description`.
+class BASE_HLO_SortOp {
+  string summary = "Sort operator";
+
+  string description = [{
+    Sorts the given `operands` at the given `dimension` with the given
+    `comparator`.
+
+    See https://www.tensorflow.org/xla/operation_semantics#sort.
+  }];
+}
+
+// Documentation mixin for the reverse op: `summary` and `description`.
+class BASE_HLO_ReverseOp {
+  string summary = "Reverse operator";
+
+  string description = [{
+    Reverses the specified dimensions of `operand` according to the given
+    `dimensions`.
+
+    See https://www.tensorflow.org/xla/operation_semantics#rev_reverse.
+  }];
+}
+
+// Documentation mixin for the pad op: `summary` and `description`.
+class BASE_HLO_PadOp {
+  string summary = "Pad operator";
+
+  string description = [{
+    Pads the edges of `operand` with the `padding_value` and according to
+    the passed configuration.
+
+    See https://www.tensorflow.org/xla/operation_semantics#pad.
+  }];
+}
+
+// Documentation mixin for the trace op: `summary` and `description`.
+class BASE_HLO_TraceOp {
+  string summary = "Trace operator";
+
+  string description = [{
+    Emits a logging message `tag` with the `operand`.
+  }];
+}
+
+// Documentation mixin for the transpose op: `summary` and `description`.
+class BASE_HLO_TransposeOp {
+  string summary = "Transpose operator";
+
+  string description = [{
+    Permutes the dimensions of `operand` according to the given `permutation`.
+
+    `res_dimensions[i] = operand_dimensions[permutation[i]]`
+
+    See https://www.tensorflow.org/xla/operation_semantics#transpose.
+  }];
+}
+
+// Enum cases mirroring the XLA `Transpose` enum used by TriangularSolve options.
+def HLO_TRANSPOSE_INVALID : StrEnumAttrCase<"TRANSPOSE_INVALID">;
+def HLO_NO_TRANSPOSE      : StrEnumAttrCase<"NO_TRANSPOSE">;
+def HLO_TRANSPOSE         : StrEnumAttrCase<"TRANSPOSE">;
+def HLO_ADJOINT           : StrEnumAttrCase<"ADJOINT">;
+
+// String enum attribute accepting exactly the four cases above, in order.
+def HLO_TransposeAttr : StrEnumAttr<"Transpose", "Transpose options", [
+  HLO_TRANSPOSE_INVALID, HLO_NO_TRANSPOSE, HLO_TRANSPOSE, HLO_ADJOINT
+]>;
+
+// Documentation mixin for the triangular-solve op: supplies only the
+// `summary` and `description` fields.
+class BASE_HLO_TriangularSolveOp {
+  string summary = "TriangularSolve operator";
+
+  string description = [{
+    Solves systems of linear equations with lower or upper triangular
+    coefficient matrices by forward- or back-substitution. Broadcasting along
+    leading dimensions, this routine solves one of the matrix systems
+    op(a) * x = b, or x * op(a) = b, for the variable x, given a and b, where
+    op(a) is either op(a) = a, or op(a) = Transpose(a), or
+    op(a) = Conj(Transpose(a)).
+
+    Input data is read only from the lower/upper triangle of a, depending on the
+    value of lower. Values from the other triangle are ignored. Output data is
+    returned in the same triangle; the values in the other triangle are
+    implementation-defined and may be anything.
+
+    If the ranks of a and b are greater than 2, they are treated as batches of
+    matrices, where all except the minor 2 dimensions are batch dimensions. a
+    and b must have equal batch dimensions.
+
+    See https://www.tensorflow.org/xla/operation_semantics#triangularsolve.
+  }];
+}
+
+// Documentation mixin for the uniform RNG op: supplies only the `summary` and
+// `description` fields.
+class BASE_HLO_RngUniformOp {
+  string summary = "RNG with uniform distribution.";
+
+  string description = [{
+    Constructs an output of a given shape with random numbers generated
+    following the uniform distribution over the interval `[a,b)`. The parameters
+    and output element type have to be a boolean type, an integral type or a
+    floating point type, and the types have to be consistent.
+
+    See https://www.tensorflow.org/xla/operation_semantics#rnguniform.
+  }];
+}
+
+// Documentation mixin for the normal-distribution RNG op: `summary` and
+// `description`.
+class BASE_HLO_RngNormalOp {
+  string summary = "RNG with normal distribution.";
+
+  string description = [{
+    Constructs an output of a given shape with random numbers generated
+    following the normal distribution with parameters `mu` and `sigma`. The
+    parameters and output shape have to have a floating point elemental type.
+    The parameters furthermore have to be scalar valued.
+
+    See https://www.tensorflow.org/xla/operation_semantics#rngnormal.
+  }];
+}
+
+#endif // HLO_OPS_BASE

diff  --git a/mlir/include/mlir/Dialect/LHLO/IR/LHLOOps.h b/mlir/include/mlir/Dialect/LHLO/IR/LHLOOps.h
new file mode 100644
index 0000000000000..45309f5110ad2
--- /dev/null
+++ b/mlir/include/mlir/Dialect/LHLO/IR/LHLOOps.h
@@ -0,0 +1,47 @@
+/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+// This file declares the operations of the LHLO (xla_lhlo) dialect.
+
+#ifndef MLIR_DIALECT_LHLO_IR_LHLO_OPS_H_
+#define MLIR_DIALECT_LHLO_IR_LHLO_OPS_H_
+
+#include "llvm/ADT/StringRef.h"
+#include "mlir/IR/Attributes.h"  // from @llvm-project
+#include "mlir/IR/Dialect.h"  // from @llvm-project
+#include "mlir/IR/Location.h"  // from @llvm-project
+#include "mlir/IR/MLIRContext.h"  // from @llvm-project
+#include "mlir/IR/OpDefinition.h"  // from @llvm-project
+#include "mlir/IR/Operation.h"  // from @llvm-project
+#include "mlir/IR/StandardTypes.h"  // from @llvm-project
+#include "mlir/IR/Types.h"  // from @llvm-project
+#include "mlir/Interfaces/SideEffectInterfaces.h"  // from @llvm-project
+
+namespace mlir {
+class OpBuilder;
+
+// Generated struct-attribute declarations, presumably produced from the
+// StructAttr defs in LHLOOps.td (confirm against the tablegen rules in
+// CMakeLists.txt). Note: included in namespace `mlir`, not `xla_lhlo`.
+#include "mlir/Dialect/LHLO/IR/LHLOStructs.h.inc"
+
+namespace xla_lhlo {
+
+// Generated dialect class declaration for `xla_lhlo`.
+#include "mlir/Dialect/LHLO/IR/LHLOOpsDialect.h.inc"
+
+// GET_OP_CLASSES selects the op class declarations from the generated file;
+// it must be defined immediately before the include.
+#define GET_OP_CLASSES
+#include "mlir/Dialect/LHLO/IR/LHLOOps.h.inc"
+
+}  // namespace xla_lhlo
+}  // namespace mlir
+
+#endif  // MLIR_DIALECT_LHLO_IR_LHLO_OPS_H_

diff  --git a/mlir/include/mlir/Dialect/LHLO/IR/LHLOOps.td b/mlir/include/mlir/Dialect/LHLO/IR/LHLOOps.td
new file mode 100644
index 0000000000000..3ae266e82c3ce
--- /dev/null
+++ b/mlir/include/mlir/Dialect/LHLO/IR/LHLOOps.td
@@ -0,0 +1,672 @@
+/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+// This is the operation definition file for the LHLO (xla_lhlo) dialect.
+
+#ifndef LHLO_OPS
+#define LHLO_OPS
+
+include "mlir/IR/OpBase.td"
+include "mlir/Interfaces/SideEffectInterfaces.td"
+include "mlir/Dialect/LHLO/IR/HLOOpsBase.td"
+
+// The `xla_lhlo` dialect; its generated C++ is placed in namespace `xla_lhlo`.
+def LHLO_Dialect : Dialect {
+  let cppNamespace = "xla_lhlo";
+  let name = "xla_lhlo";
+}
+
+//===----------------------------------------------------------------------===//
+// XLA type definitions.
+//===----------------------------------------------------------------------===//
+
+// Any integer buffer (memref) types.
+def LHLO_IntBuffer : MemRefOf<[HLO_Int]>;
+
+// A buffer (memref) of HLO_I1 elements.
+def LHLO_I1Buffer : MemRefOf<[HLO_I1]>;
+
+// Any floating-point buffer (memref) types.
+def LHLO_FpBuffer : MemRefOf<[AnyFloat]>;
+
+// A predicate buffer: a memref of HLO_Pred elements.
+def LHLO_PredBuffer : MemRefOf<[HLO_Pred]>;
+
+// A predicate buffer or a scalar i1.
+def LHLO_PredBufferOrI1 : AnyTypeOf<[LHLO_PredBuffer, I1]>;
+
+// Any integer or floating-point buffer (memref) types.
+def LHLO_IntOrFpBuffer : MemRefOf<[HLO_Int, AnyFloat]>;
+
+// A memref whose elements are float, integer, or complex.
+def LHLO_Buffer : MemRefOf<[AnyFloat, AnyInteger, AnyComplex]>;
+
+// A (possibly nested) tuple of LHLO_Buffer.
+def LHLO_TupleBuffer : NestedTupleOf<[LHLO_Buffer]>;
+
+// Either a buffer or a (nested) tuple of buffers.
+def LHLO_BufferOrTuple : AnyTypeOf<[LHLO_Buffer, LHLO_TupleBuffer]>;
+
+// A (nested) tuple whose elements are buffers or scalar integers/floats.
+def LHLO_TupleOfBufferOrIntOrFP
+    : NestedTupleOf<[LHLO_Buffer, AnyInteger, AnyFloat]>;
+
+// A buffer, or a scalar integer or floating-point value.
+def LHLO_BufferOrIntOrFP : AnyTypeOf<[LHLO_Buffer, AnyInteger, AnyFloat]>;
+
+// Any integer memref type with rank 0 (i.e. representing a single integer).
+def LHLO_ScalarIntBuffer : ShapedContainerType<
+  [HLO_Int], And<[IsMemRefTypePred, HasAnyRankOfPred<[0]>]>,
+  "a 0-dim integer memref">;
+
+//===----------------------------------------------------------------------===//
+// XLA nullary op definitions.
+//===----------------------------------------------------------------------===//
+
+// Base class for LHLO ops. In addition to the caller-supplied `traits`, every
+// derived op carries an op-level MemoryEffects<[MemRead, MemWrite]> trait.
+class LHLO_Op<string mnemonic, list<OpTrait> traits> :
+  Op<LHLO_Dialect, mnemonic,
+    !listconcat([MemoryEffects<[MemRead, MemWrite]>], traits)>;
+
+/// A LHLO op that does not write to any of its memrefs.
+/// Same as LHLO_Op except that the op-level trait is MemoryEffects<[MemRead]>.
+class LHLO_ReadOnlyOp<string mnemonic, list<OpTrait> traits>
+    : Op<LHLO_Dialect, mnemonic,
+         !listconcat([MemoryEffects<[MemRead]>], traits)>;
+
+// Takes the constant `value` attribute and the buffer `output` it writes.
+def LHLO_ConstOp : LHLO_Op<"constant", []>, BASE_HLO_ConstOp {
+  let hasCanonicalizer = 1;
+
+  let arguments = (ins
+    ElementsAttr:$value,
+    Arg<LHLO_Buffer, "", [MemWrite]>:$output
+  );
+}
+
+// Writes into `output`, parameterized by `iota_dimension`.
+def LHLO_IotaOp : LHLO_Op<"iota", []>, BASE_HLO_IotaOp {
+  let arguments = (ins
+    I64Attr:$iota_dimension,
+    Arg<LHLO_Buffer, "", [MemWrite]>:$output
+  );
+}
+
+//===----------------------------------------------------------------------===//
+// XLA unary elementwise op definitions.
+//===----------------------------------------------------------------------===//
+// See https://www.tensorflow.org/xla/operation_semantics#element-wise_unary_functions
+
+// Shared base for unary element-wise LHLO ops: reads `input`, writes `output`.
+// `input` may be a plain scalar rather than a buffer when the result is a 0-d
+// memref, which is why `SameTypeOperands` is not used here.
+class LHLO_UnaryElementwiseOp<string mnemonic>
+    : LHLO_Op<mnemonic, [EquiMemRefAndEltTypeOperands]> {
+  let arguments = (ins
+    Arg<LHLO_BufferOrIntOrFP, "", [MemRead]>:$input,
+    Arg<LHLO_Buffer, "", [MemWrite]>:$output
+  );
+}
+
+def LHLO_AbsOp : LHLO_UnaryElementwiseOp<"abs">, BASE_HLO_AbsOp;
+
+def LHLO_CeilOp : LHLO_UnaryElementwiseOp<"ceil">, BASE_HLO_CeilOp;
+
+// Declared directly on LHLO_Op with buffer-only operands of the same shape.
+def LHLO_ConvertOp : LHLO_Op<"convert", [SameOperandsShape]>,
+                     BASE_HLO_ConvertOp {
+  let arguments = (ins
+    Arg<LHLO_Buffer, "", [MemRead]>:$input,
+    Arg<LHLO_Buffer, "", [MemWrite]>:$output
+  );
+}
+
+def LHLO_CosOp : LHLO_UnaryElementwiseOp<"cosine">, BASE_HLO_CosOp;
+
+def LHLO_ExpOp : LHLO_UnaryElementwiseOp<"exponential">, BASE_HLO_ExpOp;
+
+// Result-less LHLO op parameterized by `dimension`; writes its answer into the
+// `output` buffer (semantics presumably as in BASE_HLO_GetDimensionSizeOp).
+// It must NOT carry NoSideEffect: an effect-free op with no results is
+// trivially dead and can be erased, and NoSideEffect also conflicts with the
+// MemoryEffects<[MemRead, MemWrite]> trait that LHLO_Op already adds. Effects
+// are declared per operand instead, as for the other LHLO ops.
+def LHLO_GetDimensionSizeOp : LHLO_Op<"get_dimension_size", []>,
+                              BASE_HLO_GetDimensionSizeOp {
+  let arguments = (ins
+    Arg<LHLO_Buffer, "", [MemRead]>:$operand,
+    Arg<LHLO_Buffer, "", [MemWrite]>:$output,
+    I32Attr:$dimension
+  );
+}
+
+// Imag and Real are declared on LHLO_Op (not LHLO_UnaryElementwiseOp): both
+// operands are buffers of the same shape.
+def LHLO_ImagOp : LHLO_Op<"imag", [SameOperandsShape]>, BASE_HLO_ImagOp {
+  let arguments = (ins
+    Arg<LHLO_Buffer, "", [MemRead]>:$input,
+    Arg<LHLO_Buffer, "", [MemWrite]>:$output
+  );
+}
+
+def LHLO_LogOp : LHLO_UnaryElementwiseOp<"log">, BASE_HLO_LogOp;
+
+def LHLO_NegOp : LHLO_UnaryElementwiseOp<"negate">, BASE_HLO_NegOp;
+
+def LHLO_RealOp : LHLO_Op<"real", [SameOperandsShape]>, BASE_HLO_RealOp {
+  let arguments = (ins
+    Arg<LHLO_Buffer, "", [MemRead]>:$input,
+    Arg<LHLO_Buffer, "", [MemWrite]>:$output
+  );
+}
+
+def LHLO_RsqrtOp : LHLO_UnaryElementwiseOp<"rsqrt">, BASE_HLO_RsqrtOp;
+
+def LHLO_SqrtOp : LHLO_UnaryElementwiseOp<"sqrt">, BASE_HLO_SqrtOp;
+
+def LHLO_SignOp : LHLO_UnaryElementwiseOp<"sign">, BASE_HLO_SignOp;
+
+def LHLO_SinOp : LHLO_UnaryElementwiseOp<"sine">, BASE_HLO_SinOp;
+
+def LHLO_TanhOp : LHLO_UnaryElementwiseOp<"tanh">, BASE_HLO_TanhOp;
+
+//===----------------------------------------------------------------------===//
+// XLA binary elementwise op definitions.
+//===----------------------------------------------------------------------===//
+// See https://www.tensorflow.org/xla/operation_semantics#element-wise_binary_arithmetic_operations
+
+// Shared base for binary element-wise LHLO ops: reads `lhs` and `rhs`, writes
+// `out`, and accepts optional `broadcast_dimensions`.
+class LHLO_BinaryElementwiseOp<string mnemonic, list<OpTrait> traits>
+    : LHLO_Op<mnemonic, traits> {
+  let arguments = (ins
+    Arg<LHLO_Buffer, "", [MemRead]>:$lhs,
+    Arg<LHLO_Buffer, "", [MemRead]>:$rhs,
+    Arg<LHLO_Buffer, "", [MemWrite]>:$out,
+    OptionalAttr<BroadcastDimAttr>:$broadcast_dimensions
+  );
+}
+
+def LHLO_AddOp : LHLO_BinaryElementwiseOp<"add", []>, BASE_HLO_AddOp;
+
+// Declared on LHLO_Op directly (no `broadcast_dimensions`); `lhs` and `rhs`
+// must have the same shape.
+def LHLO_ComplexOp : LHLO_Op<"complex", [SameOperandsShape]>,
+                     BASE_HLO_ComplexOp {
+  let arguments = (ins
+    Arg<LHLO_Buffer, "", [MemRead]>:$lhs,
+    Arg<LHLO_Buffer, "", [MemRead]>:$rhs,
+    Arg<LHLO_Buffer, "", [MemWrite]>:$output
+  );
+}
+
+def LHLO_DivOp : LHLO_BinaryElementwiseOp<"divide", []>, BASE_HLO_DivOp;
+
+def LHLO_MaxOp : LHLO_BinaryElementwiseOp<"maximum", []>, BASE_HLO_MaxOp;
+
+def LHLO_MinOp : LHLO_BinaryElementwiseOp<"minimum", []>, BASE_HLO_MinOp;
+
+def LHLO_MulOp : LHLO_BinaryElementwiseOp<"multiply", []>, BASE_HLO_MulOp;
+
+def LHLO_PowOp : LHLO_BinaryElementwiseOp<"power", []>, BASE_HLO_PowOp;
+
+def LHLO_RemOp : LHLO_BinaryElementwiseOp<"remainder", []>, BASE_HLO_RemOp;
+
+def LHLO_SubOp : LHLO_BinaryElementwiseOp<"subtract", []>, BASE_HLO_SubOp;
+
+def LHLO_AndOp : LHLO_BinaryElementwiseOp<"and", []>, BASE_HLO_AndOp;
+
+def LHLO_OrOp : LHLO_BinaryElementwiseOp<"or", []>, BASE_HLO_OrOp;
+
+//===----------------------------------------------------------------------===//
+// XLA control flow op definitions.
+//===----------------------------------------------------------------------===//
+
+// TODO(b/139813999): specify required function signature in a type-safe way.
+// The variadic groups `operands`, `init_values` and `out` must all have the
+// same length (SameVariadicOperandSize); `body` is a single-block region.
+def LHLO_ReduceOp : LHLO_Op<"reduce", [SameVariadicOperandSize]>,
+    BASE_HLO_ReduceOp {
+  let regions = (region SizedRegion<1>:$body);
+
+  let arguments = (ins
+    Arg<Variadic<LHLO_BufferOrTuple>, "", [MemRead]>:$operands,
+    Arg<Variadic<LHLO_BufferOrIntOrFP>, "", [MemRead]>:$init_values,
+    Arg<Variadic<LHLO_BufferOrTuple>, "", [MemWrite]>:$out,
+    I64ElementsAttr:$dimensions
+  );
+}
+
+def LHLO_ReduceWindowOp : LHLO_Op<"reduce_window", []>,
+    BASE_HLO_ReduceWindowOp {
+  let regions = (region SizedRegion<1>:$body);
+
+  // If strides or dilations attributes are missing then the default value is
+  // one for each of the input dimensions. Similarly, padding values are zero
+  // for both low and high in each of the dimensions, if not specified.
+  let arguments = (ins
+    Arg<LHLO_Buffer, "", [MemRead]>:$operand,
+    Arg<LHLO_BufferOrIntOrFP, "", [MemRead]>:$init_value,
+    Arg<LHLO_Buffer, "", [MemWrite]>:$out,
+    I64ElementsAttr:$window_dimensions,
+    OptionalAttr<I64ElementsAttr>:$window_strides,
+    OptionalAttr<I64ElementsAttr>:$base_dilations,
+    OptionalAttr<I64ElementsAttr>:$window_dilations,
+    OptionalAttr<I64ElementsAttr>:$padding
+  );
+}
+
+// `branches` is a variadic list of single-block regions, each implicitly
+// terminated by TerminatorOp.
+def LHLO_CaseOp : LHLO_Op<"case",
+    [SingleBlockImplicitTerminator<"TerminatorOp">]>, BASE_HLO_CaseOp {
+  let regions = (region VariadicRegion<SizedRegion<1>>:$branches);
+
+  let arguments = (ins
+    Arg<LHLO_Buffer, "", [MemRead]>:$index,
+    Arg<Variadic<LHLO_BufferOrTuple>, "", [MemRead]>:$branch_operands,
+    Arg<LHLO_BufferOrTuple, "", [MemWrite]>:$out
+  );
+}
+
+//===----------------------------------------------------------------------===//
+// XLA tuple op definitions.
+//===----------------------------------------------------------------------===//
+
+// Reads the tuple `input` and, unlike most LHLO ops, produces an SSA result
+// rather than writing an output operand; `index` presumably selects the tuple
+// element (confirm against BASE_HLO_GetTupleElementOp).
+def LHLO_GetTupleElementOp: LHLO_Op<"get_tuple_element", []>, BASE_HLO_GetTupleElementOp {
+  let arguments = (ins
+    Arg<LHLO_TupleOfBufferOrIntOrFP, "", [MemRead]>:$input,
+    I32Attr:$index
+  );
+
+  // NOTE(review): the result is unnamed and wrapped in `Arg<...>` inside
+  // `outs`; `Res<...>` (or a bare type) is presumably what was intended —
+  // confirm before changing.
+  let results = (outs Arg<LHLO_BufferOrIntOrFP>);
+}
+
+// Compares `lhs` with `rhs` according to `comparison_direction` and writes
+// into the predicate buffer `out`.
+def LHLO_CompareOp : LHLO_Op<"compare", []>, BASE_HLO_CompareOp {
+  let arguments = (ins
+    Arg<LHLO_Buffer, "", [MemRead]>:$lhs,
+    Arg<LHLO_Buffer, "", [MemRead]>:$rhs,
+    Arg<LHLO_PredBuffer, "", [MemWrite]>:$out,
+    OptionalAttr<BroadcastDimAttr>:$broadcast_dimensions,
+    HLO_ComparisonDirectionAttr:$comparison_direction
+  );
+}
+
+//===----------------------------------------------------------------------===//
+// XLA Slice definitions.
+//===----------------------------------------------------------------------===//
+
+// The three index attributes are required to share one type (AllTypesMatch).
+def LHLO_SliceOp: LHLO_Op<
+      "slice",
+      [AllTypesMatch<["start_indices", "limit_indices", "strides"]>]> {
+  let summary = "Slice operator";
+
+  let description = [{
+    Extracts a sub-array from `operand` using the static `start_indices`
+    (inclusive), `limit_indices` (exclusive) and `strides`, and writes it into
+    `output`.
+
+    See https://www.tensorflow.org/xla/operation_semantics#slice.
+  }];
+
+  let arguments = (ins
+    Arg<LHLO_Buffer, "", [MemRead]>:$operand,
+    Arg<LHLO_Buffer, "", [MemWrite]>:$output,
+    I64ElementsAttr:$start_indices,
+    I64ElementsAttr:$limit_indices,
+    I64ElementsAttr:$strides
+  );
+}
+
+def LHLO_DynamicSliceOp: LHLO_Op<"dynamic-slice", []> {
+  let summary = "dynamic slice operator";
+
+  let description = [{
+    Extracts a sub-array from the input array at dynamic indices specified at
+    `start_indices`.
+
+    See https://www.tensorflow.org/xla/operation_semantics#dynamicslice.
+  }];
+
+  // `start_indices` may be 0-d integer memrefs that the op reads, so the read
+  // is declared explicitly, matching LHLO_DynamicUpdateSliceOp.
+  let arguments = (ins
+    Arg<LHLO_Buffer, "", [MemRead]>:$operand,
+    Arg<Variadic<AnyTypeOf<[LHLO_ScalarIntBuffer, AnyInteger]>>, "", [MemRead]>:$start_indices,
+    I64ElementsAttr:$slice_sizes,
+    Arg<LHLO_Buffer, "", [MemWrite]>:$output
+  );
+}
+
+def LHLO_DynamicUpdateSliceOp: LHLO_Op<"dynamic-update-slice", []> {
+  let summary = "dynamic update slice operator";
+
+  let description = [{
+    Writes into `output` the contents of `operand`, with the slice `update`
+    overwritten at `start_indices`.
+
+    See https://www.tensorflow.org/xla/operation_semantics#dynamicupdateslice.
+  }];
+
+  let arguments = (ins
+    Arg<LHLO_Buffer, "", [MemRead]>:$operand,
+    Arg<LHLO_Buffer, "", [MemRead]>:$update,
+    Arg<Variadic<AnyTypeOf<[LHLO_ScalarIntBuffer, AnyInteger]>>, "", [MemRead]>:$start_indices,
+    Arg<LHLO_Buffer, "", [MemWrite]>:$output
+  );
+}
+
+//===----------------------------------------------------------------------===//
+// XLA Other op definitions.
+//===----------------------------------------------------------------------===//
+
+// NOTE(review): this def uses the `HLO_` prefix while every other op in this
+// file uses `LHLO_`. The generated C++ class name is presumably unaffected
+// (ODS drops the prefix up to the first '_'), but the TableGen record could
+// collide with an HLO-dialect record of the same name if both .td files are
+// ever included together — consider renaming to LHLO_BatchNormInferenceOp.
+def HLO_BatchNormInferenceOp : LHLO_Op<"batch_norm_inference", []>,
+    BASE_HLO_BatchNormInferenceOp {
+
+  // Reads operand, scale, offset, mean and variance; writes output.
+  let arguments = (ins
+    Arg<LHLO_Buffer, "", [MemRead]>:$operand,
+    Arg<LHLO_Buffer, "", [MemRead]>:$scale,
+    Arg<LHLO_Buffer, "", [MemRead]>:$offset,
+    Arg<LHLO_Buffer, "", [MemRead]>:$mean,
+    Arg<LHLO_Buffer, "", [MemRead]>:$variance,
+    Arg<LHLO_Buffer, "", [MemWrite]>:$output,
+    F32Attr:$epsilon,
+    I64Attr:$feature_index
+  );
+}
+
+// Both broadcast ops read `operand` and write `output`; they differ only in
+// the attribute that describes the broadcast.
+def LHLO_BroadcastOp : LHLO_Op<"broadcast", []>, BASE_HLO_BroadcastOp {
+  let arguments = (ins
+    Arg<LHLO_Buffer, "", [MemRead]>:$operand,
+    Arg<LHLO_Buffer, "", [MemWrite]>:$output,
+    I64ElementsAttr:$broadcast_sizes
+  );
+}
+
+def LHLO_BroadcastInDimOp : LHLO_Op<"broadcast_in_dim", []>,
+    BASE_HLO_BroadcastInDimOp {
+  let arguments = (ins
+    Arg<LHLO_Buffer, "", [MemRead]>:$operand,
+    Arg<LHLO_Buffer, "", [MemWrite]>:$output,
+    BroadcastDimAttr:$broadcast_dimensions
+  );
+}
+
+def LHLO_ClampOp : LHLO_Op<"clamp", []>, BASE_HLO_ClampOp {
+  // Operand order: min, operand, max, then the written output buffer.
+  let arguments = (ins
+    Arg<LHLO_Buffer, "", [MemRead]>:$min,
+    Arg<LHLO_Buffer, "", [MemRead]>:$operand,
+    Arg<LHLO_Buffer, "", [MemRead]>:$max,
+    Arg<LHLO_Buffer, "", [MemWrite]>:$output
+  );
+}
+
+def LHLO_ConcatenateOp : LHLO_Op<"concatenate", []>, BASE_HLO_ConcatenateOp {
+  // Reads every `val` buffer and writes `output`; `dimension` is the axis.
+  let arguments = (ins
+    Arg<Variadic<LHLO_Buffer>, "", [MemRead]>:$val,
+    Arg<LHLO_Buffer, "", [MemWrite]>:$output,
+    I64Attr:$dimension
+  );
+}
+
+// TODO(bondhugula): Make this struct dialect independent so that it can be
+// shared between the HLO and LHLO dialects.
+def ConvDimensionNumbers : StructAttr<"ConvDimensionNumbers", LHLO_Dialect, [
+    StructFieldAttr<"input_batch_dimension", I64Attr>,
+    StructFieldAttr<"input_feature_dimension", I64Attr>,
+    StructFieldAttr<"input_spatial_dimensions", I64ElementsAttr>,
+    StructFieldAttr<"kernel_input_feature_dimension", I64Attr>,
+    StructFieldAttr<"kernel_output_feature_dimension", I64Attr>,
+    StructFieldAttr<"kernel_spatial_dimensions", I64ElementsAttr>,
+    StructFieldAttr<"output_batch_dimension", I64Attr>,
+    StructFieldAttr<"output_feature_dimension", I64Attr>,
+    StructFieldAttr<"output_spatial_dimensions", I64ElementsAttr>
+  ]> {
+  let description = "Structure of dimension information for conv op";
+}
+
+// Reads `lhs` and `rhs`, writes `output`. The defaults noted next to the
+// optional window attributes are documentation only; they are not encoded in
+// the attribute types (no DefaultValuedAttr is used).
+def LHLO_ConvOp : LHLO_Op<"convolution", []>, BASE_HLO_ConvOp {
+  let arguments = (ins
+    Arg<LHLO_Buffer, "", [MemRead]>:$lhs,
+    Arg<LHLO_Buffer, "", [MemRead]>:$rhs,
+    Arg<LHLO_Buffer, "", [MemWrite]>:$output,
+    // Default value: one for each of the spatial dimensions.
+    OptionalAttr<I64ElementsAttr>:$window_strides,
+    // Default value: zero for each of the spatial dimensions.
+    OptionalAttr<I64ElementsAttr>:$padding,
+    // Default value: one for each of the spatial dimensions.
+    OptionalAttr<I64ElementsAttr>:$lhs_dilation,
+    // Default value: one for each of the spatial dimensions.
+    OptionalAttr<I64ElementsAttr>:$rhs_dilation,
+    ConvDimensionNumbers:$dimension_numbers,
+    I64Attr:$feature_group_count,
+    I64Attr:$batch_group_count,
+    HLO_PrecisionConfigAttr:$precision_config
+  );
+}
+
+def LHLO_CopyOp : LHLO_Op<"copy", []>, BASE_HLO_CopyOp {
+  let arguments = (ins
+    Arg<LHLO_Buffer, "", [MemRead]>:$operand,
+    Arg<LHLO_Buffer, "", [MemWrite]>:$output
+  );
+}
+
+// Note the operand order: `precision_config` precedes the `output` buffer.
+def LHLO_DotOp : LHLO_Op<"dot", []>, BASE_HLO_DotOp {
+  let arguments = (ins
+    Arg<LHLO_Buffer, "", [MemRead]>:$lhs,
+    Arg<LHLO_Buffer, "", [MemRead]>:$rhs,
+    HLO_PrecisionConfigAttr:$precision_config,
+    Arg<LHLO_Buffer, "", [MemWrite]>:$output
+  );
+}
+
+// Batching and contracting dimensions for each side of a general dot product.
+def DotDimensionNumbers : StructAttr<"DotDimensionNumbers", LHLO_Dialect, [
+    StructFieldAttr<"lhs_batching_dimensions", I64ElementsAttr>,
+    StructFieldAttr<"rhs_batching_dimensions", I64ElementsAttr>,
+    StructFieldAttr<"lhs_contracting_dimensions", I64ElementsAttr>,
+    StructFieldAttr<"rhs_contracting_dimensions", I64ElementsAttr>
+  ]> {
+  let description = "Structure of dimension information for dot product";
+}
+
+def LHLO_DotGeneralOp : LHLO_Op<"dot_general", []>, BASE_HLO_DotGeneralOp {
+  let arguments = (ins
+    Arg<LHLO_Buffer, "", [MemRead]>:$lhs,
+    Arg<LHLO_Buffer, "", [MemRead]>:$rhs,
+    DotDimensionNumbers:$dot_dimension_numbers,
+    HLO_PrecisionConfigAttr:$precision_config,
+    Arg<LHLO_Buffer, "", [MemWrite]>:$output
+  );
+}
+
+// Reads a floating-point buffer `x` and writes an HLO_I1 buffer `output`.
+def LHLO_IsFiniteOp : LHLO_Op<"is_finite", []>, BASE_HLO_IsFiniteOp {
+  let arguments = (ins
+    Arg<LHLO_FpBuffer, "", [MemRead]>:$x,
+    Arg<LHLO_I1Buffer, "", [MemWrite]>:$output
+  );
+}
+
+def GatherDimensionNumbers : StructAttr<"GatherDimensionNumbers", LHLO_Dialect, [
+    StructFieldAttr<"offset_dims", I64ElementsAttr>,
+    StructFieldAttr<"collapsed_slice_dims", I64ElementsAttr>,
+    StructFieldAttr<"start_index_map", I64ElementsAttr>,
+    StructFieldAttr<"index_vector_dim", I64Attr>
+  ]> {
+  let description = "Structure of dimension information for gather";
+}
+
+// Reads `operand` and the integer `start_indices` buffer; writes `output`.
+def LHLO_GatherOp : LHLO_Op<"gather", []>, BASE_HLO_GatherOp {
+  let arguments = (ins
+    Arg<LHLO_Buffer, "", [MemRead]>:$operand,
+    Arg<LHLO_IntBuffer, "", [MemRead]>:$start_indices,
+    GatherDimensionNumbers:$dimension_numbers,
+    I64ElementsAttr:$slice_sizes,
+    Arg<LHLO_Buffer, "", [MemWrite]>:$output
+  );
+}
+
+// Applies the single-block `computation` region element-wise over the
+// same-shaped `operands` buffers and writes the result into `output`.
+def LHLO_MapOp : LHLO_Op<"map", [RecursiveSideEffects, SameOperandsShape]>,
+                 BASE_HLO_MapOp {
+  // Overrides the inherited description; the previous text had been mangled
+  // (stray leading quote, "arrays.See", and a broken "https:  // ..." URL).
+  let description = [{
+    Applies a scalar function over the given operand arrays, producing an
+    array of the same dimensions where each element is the result of the mapped
+    function applied to the corresponding elements in the input arrays.
+
+    See https://www.tensorflow.org/xla/operation_semantics#map.
+  }];
+  let arguments = (ins
+    Arg<Variadic<LHLO_Buffer>, "", [MemRead]>:$operands,
+    I64ElementsAttr:$dimensions,
+    Arg<LHLO_Buffer, "", [MemWrite]>:$output
+  );
+
+  let regions = (region SizedRegion<1>:$computation);
+}
+
+def ScatterDimensionNumbers : StructAttr<"ScatterDimensionNumbers", LHLO_Dialect, [
+    StructFieldAttr<"update_window_dims", I64ElementsAttr>,
+    StructFieldAttr<"inserted_window_dims", I64ElementsAttr>,
+    StructFieldAttr<"scatter_dims_to_operand_dims", I64ElementsAttr>,
+    StructFieldAttr<"index_vector_dim", I64Attr>
+  ]> {
+  let description = "Structure of dimension information for scatter";
+}
+
+// Reads `operand`, `scatter_indices` and `updates`, combines them through the
+// single-block `update_computation` region, and writes `output`. Both boolean
+// flags default to false when omitted.
+def LHLO_ScatterOp : LHLO_Op<"scatter", [RecursiveSideEffects]>,
+    BASE_HLO_ScatterOp {
+  let regions = (region SizedRegion<1>:$update_computation);
+
+  let arguments = (ins
+    Arg<LHLO_Buffer, "", [MemRead]>:$operand,
+    Arg<LHLO_Buffer, "", [MemRead]>:$scatter_indices,
+    Arg<LHLO_Buffer, "", [MemRead]>:$updates,
+    ScatterDimensionNumbers:$scatter_dimension_numbers,
+    DefaultValuedAttr<BoolAttr, "false">:$indices_are_sorted,
+    DefaultValuedAttr<BoolAttr, "false">:$unique_indices,
+    Arg<LHLO_Buffer, "", [MemWrite]>:$output
+  );
+}
+
+def LHLO_ReshapeOp : LHLO_Op<"reshape", []>, BASE_HLO_ReshapeOp {
+  let arguments = (ins
+    Arg<LHLO_BufferOrIntOrFP, "", [MemRead]>:$operand,  // buffer, int or float
+    Arg<LHLO_Buffer, "", [MemWrite]>:$output            // result buffer
+  );
+}
+
+//===----------------------------------------------------------------------===//
+// RngUniformOp: buffer counterpart of the HLO op xla_hlo::rng_uniform. It
+// takes scalar low ($a) and high ($b) values and populates the output memref.
+//===----------------------------------------------------------------------===//
+def LHLO_RngUniformOp : LHLO_Op<"rng_uniform", []>, BASE_HLO_RngUniformOp {
+  let arguments = (ins
+    AnyTypeOf<[AnyInteger, AnyFloat]>:$a,     // scalar low value
+    AnyTypeOf<[AnyInteger, AnyFloat]>:$b,     // scalar high value
+    Arg<LHLO_Buffer, "", [MemWrite]>:$output  // populated by the op
+  );
+}
+
+def LHLO_SelectOp : LHLO_Op<"select", []>, BASE_HLO_SelectOp {
+  let arguments = (ins
+    Arg<LHLO_PredBufferOrI1, "", [MemRead]>:$pred,       // predicate
+    Arg<LHLO_BufferOrIntOrFP, "", [MemRead]>:$on_true,   // value when true
+    Arg<LHLO_BufferOrIntOrFP, "", [MemRead]>:$on_false,  // value when false
+    Arg<LHLO_BufferOrIntOrFP, "", [MemWrite]>:$output
+  );
+}
+
+def LHLO_SelectAndScatterOp : LHLO_Op<"select_and_scatter", []>,
+                              BASE_HLO_SelectAndScatterOp {
+  let arguments = (ins
+    Arg<LHLO_Buffer, "", [MemRead]>:$operand,
+    Arg<LHLO_Buffer, "", [MemRead]>:$source,
+    Arg<LHLO_BufferOrIntOrFP, "", [MemRead]>:$init_value,
+    Arg<LHLO_Buffer, "", [MemWrite]>:$out,
+    // The three attributes below are all optional.
+    OptionalAttr<I64ElementsAttr>:$window_dimensions,
+    OptionalAttr<I64ElementsAttr>:$window_strides,
+    OptionalAttr<I64ElementsAttr>:$padding
+  );
+  let regions = (region SizedRegion<1>:$select, SizedRegion<1>:$scatter);
+}
+
+// The op writes $output (MemWrite), so it must not be marked NoSideEffect.
+def LHLO_TorchIndexSelectOp : LHLO_Op<"torch_index_select", []> {
+  let summary = "LHLO Torch Index Select Operation";
+  let description = [{
+      Gather slices from input axis dim according to indices.
+      See https://www.tensorflow.org/api_docs/cc/class/tensorflow/ops/gather-v2
+  }];
+  let arguments = (ins
+    Arg<LHLO_Buffer, "", [MemRead]>:$input,
+    Arg<LHLO_Buffer, "", [MemRead]>:$index,
+    I64Attr:$dim,
+    I64Attr:$batch_dims,
+    Arg<LHLO_Buffer, "", [MemWrite]>:$output
+  );
+}
+
+def LHLO_ReverseOp : LHLO_Op<"reverse", []>, BASE_HLO_ReverseOp {
+  let arguments = (ins
+    Arg<LHLO_Buffer, "", [MemRead]>:$operand,   // input buffer (read)
+    I64ElementsAttr:$dimensions,                // dimensions to reverse
+    Arg<LHLO_Buffer, "", [MemWrite]>:$output    // result buffer (written)
+  );
+}
+
+def LHLO_PadOp : LHLO_Op<"pad", []>, BASE_HLO_PadOp {
+  let arguments = (ins
+    Arg<LHLO_Buffer, "", [MemRead]>:$operand,        // buffer to pad (read)
+    Arg<LHLO_Buffer, "", [MemRead]>:$padding_value,  // fill value (read)
+    I64ElementsAttr:$edge_padding_low,
+    I64ElementsAttr:$edge_padding_high,
+    I64ElementsAttr:$interior_padding,
+    Arg<LHLO_Buffer, "", [MemWrite]>:$output         // padded result
+  );
+}
+
+def LHLO_TransposeOp : LHLO_Op<"transpose", []>, BASE_HLO_TransposeOp {
+  let arguments = (ins
+    Arg<LHLO_Buffer, "", [MemRead]>:$operand,   // input buffer (read)
+    I64ElementsAttr:$permutation,               // dimension permutation
+    Arg<LHLO_Buffer, "", [MemWrite]>:$output    // result buffer (written)
+  );
+}
+
+def LHLO_TupleOp : LHLO_ReadOnlyOp<"tuple", [NoSideEffect]>, BASE_HLO_TupleOp {
+  // Packs the read-only inputs into a (possibly nested) tuple result.
+  let arguments = (ins Arg<Variadic<LHLO_BufferOrIntOrFP>, "", [MemRead]>:$input);
+  let results = (outs NestedTupleOf<[LHLO_BufferOrIntOrFP]>);
+  // The result tuple type is derived from the operand types (TupleOp::build).
+  let builders = [OpBuilder<
+    "OpBuilder &builder, OperationState &results, ValueRange values">];
+}
+
+def LHLO_WhileOp
+    : LHLO_Op<"while",
+              [AffineScope, RecursiveSideEffects, SameOperandsAndResultType]> {
+  let summary = "While operator";
+
+  let description = [{
+    Returns the result of repeatedly executing the body region for as long as
+    the cond region yields true (i.e. until it yields false).
+
+    See https://www.tensorflow.org/xla/operation_semantics#while.
+  }];
+
+  let arguments = (ins Arg<LHLO_TupleOfBufferOrIntOrFP>:$init);
+
+  let regions = (region AnyRegion:$cond, AnyRegion:$body);
+
+  let results = (outs Arg<LHLO_TupleOfBufferOrIntOrFP>);
+}
+
+//===----------------------------------------------------------------------===//
+// Late operations
+//===----------------------------------------------------------------------===//
+
+def FusionOp : LHLO_Op<"fusion", [SingleBlockImplicitTerminator<"TerminatorOp">]> {
+  let summary = "Fusion operator";
+  let description = [{
+    Models the fusion instruction generated by the XLA compiler's fusion pass.
+
+    Fusion instructions are generated by the fusion pass of the XLA compiler.
+    They serve as a hint to the backend that it is beneficial to emit the
+    contained instructions into a single loop nest or kernel. The XLA fusion
+    pass is designed such that it only generates fusion nodes that can be
+    handled by the XLA compiler's backends.
+    The XLA runtime expects this hint to be followed, as it expects a single
+    kernel per HLO instruction. This restriction might be lifted in the future.
+  }];
+  let regions = (region SizedRegion<1>:$region);
+  // Default builders are suppressed; only the custom one below is generated.
+  let skipDefaultBuilders = 1;
+  let builders = [
+    OpBuilder<"OpBuilder &builder, OperationState &result, "
+              "ArrayRef<NamedAttribute> attributes">
+  ];
+}
+
+def TerminatorOp : LHLO_Op<"terminator", [Terminator]> {
+  let summary = "LHLO termination operation";
+  let description = [{
+    Terminator operation for the LHLO dialect.
+  }];
+  // Convenience builder: forwards to the generic one with None types/attrs.
+  let builders = [OpBuilder<
+    "OpBuilder &b, OperationState &result, ValueRange operands", [{
+      build(b, result, llvm::None, operands, llvm::None);
+    }]>];
+}
+
+def YieldOp : LHLO_Op<"yield", [Terminator]> {
+  let summary = "LHLO yield operation";
+  let description = [{
+    Yield operation for the LHLO dialect.
+  }];
+
+  // Accepts any number of values of any type.
+  let arguments = (ins Variadic<AnyType>:$operands);
+}
+
+#endif // LHLO_OPS

diff  --git a/mlir/include/mlir/InitAllDialects.h b/mlir/include/mlir/InitAllDialects.h
index 5cf0429942ca7..36dd26ba817ef 100644
--- a/mlir/include/mlir/InitAllDialects.h
+++ b/mlir/include/mlir/InitAllDialects.h
@@ -23,6 +23,7 @@
 #include "mlir/Dialect/DLTI/DLTI.h"
 #include "mlir/Dialect/EmitC/IR/EmitC.h"
 #include "mlir/Dialect/GPU/GPUDialect.h"
+#include "mlir/Dialect/LHLO/IR/LHLOOps.h"
 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
 #include "mlir/Dialect/LLVMIR/NVVMDialect.h"
 #include "mlir/Dialect/LLVMIR/ROCDLDialect.h"
@@ -80,6 +81,7 @@ inline void registerAllDialects(DialectRegistry &registry) {
                   tosa::TosaDialect,
-                  x86vector::X86VectorDialect>();
+                  x86vector::X86VectorDialect,
+                  xla_lhlo::LHLODialect>();
   // clang-format on
 }
 
 /// Append all the MLIR dialects to the registry contained in the given context.

diff  --git a/mlir/lib/Dialect/CMakeLists.txt b/mlir/lib/Dialect/CMakeLists.txt
index 8a6f08ab3b837..52676ad31da62 100644
--- a/mlir/lib/Dialect/CMakeLists.txt
+++ b/mlir/lib/Dialect/CMakeLists.txt
@@ -7,6 +7,7 @@ add_subdirectory(Complex)
 add_subdirectory(DLTI)
 add_subdirectory(EmitC)
 add_subdirectory(GPU)
+add_subdirectory(LHLO)
 add_subdirectory(Linalg)
 add_subdirectory(LLVMIR)
 add_subdirectory(Math)

diff  --git a/mlir/lib/Dialect/LHLO/CMakeLists.txt b/mlir/lib/Dialect/LHLO/CMakeLists.txt
new file mode 100644
index 0000000000000..f33061b2d87cf
--- /dev/null
+++ b/mlir/lib/Dialect/LHLO/CMakeLists.txt
@@ -0,0 +1 @@
+add_subdirectory(IR)

diff  --git a/mlir/lib/Dialect/LHLO/IR/CMakeLists.txt b/mlir/lib/Dialect/LHLO/IR/CMakeLists.txt
new file mode 100644
index 0000000000000..d99e31c89729f
--- /dev/null
+++ b/mlir/lib/Dialect/LHLO/IR/CMakeLists.txt
@@ -0,0 +1,14 @@
+add_mlir_dialect_library(MLIRLHLOOps
+  LHLOOps.cc
+
+  ADDITIONAL_HEADER_DIRS
+  ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/LHLO
+
+  DEPENDS
+  MLIRIR
+  MLIRLHLOStructsGen
+
+  LINK_LIBS PUBLIC
+  MLIRIR
+  MLIRStandardOps
+  )

diff  --git a/mlir/lib/Dialect/LHLO/IR/LHLOOps.cc b/mlir/lib/Dialect/LHLO/IR/LHLOOps.cc
new file mode 100644
index 0000000000000..871872a22f68c
--- /dev/null
+++ b/mlir/lib/Dialect/LHLO/IR/LHLOOps.cc
@@ -0,0 +1,120 @@
+/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+// This file defines the operations used in the LHLO (xla_lhlo) dialect.
+
+
+#include <assert.h>
+#include <stddef.h>
+#include <stdint.h>
+
+#include "llvm/ADT/APFloat.h"
+#include "llvm/ADT/APInt.h"
+#include "llvm/ADT/ArrayRef.h"
+#include "llvm/ADT/STLExtras.h"
+#include "llvm/ADT/SmallVector.h"
+#include "llvm/ADT/StringRef.h"
+#include "llvm/Support/FormatVariadic.h"
+#include "mlir/Dialect/StandardOps/IR/Ops.h"
+#include "mlir/Dialect/LHLO/IR/LHLOOps.h"
+#include "mlir/IR/Attributes.h"  // from @llvm-project
+#include "mlir/IR/Builders.h"  // from @llvm-project
+#include "mlir/IR/Dialect.h"  // from @llvm-project
+#include "mlir/IR/Location.h"  // from @llvm-project
+#include "mlir/IR/MLIRContext.h"  // from @llvm-project
+#include "mlir/IR/OpDefinition.h"  // from @llvm-project
+#include "mlir/IR/OpImplementation.h"  // from @llvm-project
+#include "mlir/IR/Operation.h"  // from @llvm-project
+#include "mlir/IR/OperationSupport.h"  // from @llvm-project
+#include "mlir/IR/PatternMatch.h"  // from @llvm-project
+#include "mlir/IR/StandardTypes.h"  // from @llvm-project
+#include "mlir/IR/TypeUtilities.h"  // from @llvm-project
+#include "mlir/IR/Types.h"  // from @llvm-project
+#include "mlir/IR/Value.h"  // from @llvm-project
+
+namespace mlir {
+#include "mlir/Dialect/LHLO/IR/LHLOStructs.cpp.inc"
+namespace xla_lhlo {
+
+/// Registers every op from the generated GET_OP_LIST with this dialect.
+LHLODialect::LHLODialect(MLIRContext *context)
+    : Dialect(getDialectNamespace(), context) {
+  addOperations<
+#define GET_OP_LIST
+#include "mlir/Dialect/LHLO/IR/LHLOOps.cpp.inc"
+      >();
+}
+
+#define GET_OP_CLASSES
+#include "mlir/Dialect/LHLO/IR/LHLOOps.cpp.inc"
+
+//===----------------------------------------------------------------------===//
+// ConstOp.
+//===----------------------------------------------------------------------===//
+
+namespace {
+
+/// A ConstOp that writes into a locally allocated memref whose only other
+/// users are deallocs can be erased: nothing ever reads the stored value.
+struct EraseConstOp : public OpRewritePattern<ConstOp> {
+  using OpRewritePattern<ConstOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(xla_lhlo::ConstOp op,
+                                PatternRewriter& rewriter) const override {
+    // Only buffers from a local alloc qualify; others may be read elsewhere.
+    Value memref = op.output();
+    if (!memref.getDefiningOp<AllocOp>()) return failure();
+
+    // Check that all uses of the memref are either DeallocOps or this op.
+    for (Operation* user : memref.getUsers())
+      if (user != op.getOperation() && !isa<DeallocOp>(user)) return failure();
+
+    rewriter.eraseOp(op);
+    return success();
+  }
+};
+
+}  // end anonymous namespace
+
+void xla_lhlo::ConstOp::getCanonicalizationPatterns(
+    OwningRewritePatternList& results, MLIRContext* context) {
+  results.insert<EraseConstOp>(context);
+}
+
+// TODO(cheshire): Support folding, reuse code from hlo_ops.cc.
+
+/// Builds a FusionOp with `attributes`; ensureTerminator adds its TerminatorOp.
+void FusionOp::build(OpBuilder &builder, OperationState &result,
+                     ArrayRef<NamedAttribute> attributes) {
+  result.addAttributes(attributes);
+  Region *bodyRegion = result.addRegion();
+  FusionOp::ensureTerminator(*bodyRegion, builder, result.location);
+}
+
+//===----------------------------------------------------------------------===//
+// TupleOp
+//===----------------------------------------------------------------------===//
+
+/// Builds a TupleOp whose result type is the tuple of the types of `values`.
+void TupleOp::build(OpBuilder& builder, OperationState& result,
+                    ValueRange values) {
+  SmallVector<Type, 4> types;
+  types.reserve(values.size());
+  for (Value val : values) types.push_back(val.getType());
+  build(builder, result, builder.getTupleType(types), values);
+}
+
+}  // namespace xla_lhlo
+}  // namespace mlir

diff  --git a/mlir/test/Dialect/LHLO/lhlo_ops.mlir b/mlir/test/Dialect/LHLO/lhlo_ops.mlir
new file mode 100644
index 0000000000000..a288ccffc3e34
--- /dev/null
+++ b/mlir/test/Dialect/LHLO/lhlo_ops.mlir
@@ -0,0 +1,237 @@
+// RUN: mlir-opt %s -verify-diagnostics -split-input-file | mlir-opt | FileCheck %s
+
+func @enforce_same_shape(%arg0: memref<1xf32>, %arg1: memref<2xf32>) -> () {
+  // expected-error@+1{{'xla_lhlo.tanh' op requires all operands to have the same or equivalent type}}
+  "xla_lhlo.tanh"(%arg0, %arg1) : (memref<1xf32>, memref<2xf32>) -> ()
+  return
+}
+
+// -----
+
+// CHECK-LABEL: func @add_memrefs
+func @add_memrefs(%a: memref<1xi32>, %b: memref<1xi32>, %c: memref<1xi32>) {
+  "xla_lhlo.add"(%a, %b, %c) : (memref<1xi32>, memref<1xi32>, memref<1xi32>) -> ()
+  return
+}
+
+// -----
+
+// CHECK-LABEL: func @abs_memref
+func @abs_memref(%src: memref<10xf32>, %dst: memref<10xf32>) {
+  "xla_lhlo.abs"(%src, %dst) : (memref<10xf32>, memref<10xf32>) -> ()
+  return
+}
+
+// -----
+
+// CHECK-LABEL: func @convert_memref
+func @convert_memref(%src: memref<10xf32>, %dst: memref<10xf32>) {
+  "xla_lhlo.convert"(%src, %dst) : (memref<10xf32>, memref<10xf32>) -> ()
+  return
+}
+
+// -----
+
+// CHECK-LABEL: func @exp_memref
+func @exp_memref(%src: memref<10xf32>, %dst: memref<10xf32>) {
+  "xla_lhlo.exponential"(%src, %dst) : (memref<10xf32>, memref<10xf32>) -> ()
+  return
+}
+
+// -----
+
+// CHECK-LABEL: func @log_memref
+func @log_memref(%src: memref<10xf32>, %dst: memref<10xf32>) {
+  "xla_lhlo.log"(%src, %dst) : (memref<10xf32>, memref<10xf32>) -> ()
+  return
+}
+
+// -----
+
+// CHECK-LABEL: func @neg_memref
+func @neg_memref(%src: memref<10xf32>, %dst: memref<10xf32>) {
+  "xla_lhlo.negate"(%src, %dst) : (memref<10xf32>, memref<10xf32>) -> ()
+  return
+}
+
+// -----
+
+// CHECK-LABEL: func @rsqrt_memref
+func @rsqrt_memref(%src: memref<10xf32>, %dst: memref<10xf32>) {
+  "xla_lhlo.rsqrt"(%src, %dst) : (memref<10xf32>, memref<10xf32>) -> ()
+  return
+}
+
+// -----
+
+// CHECK-LABEL: func @sign_memref
+func @sign_memref(%src: memref<10xf32>, %dst: memref<10xf32>) {
+  "xla_lhlo.sign"(%src, %dst) : (memref<10xf32>, memref<10xf32>) -> ()
+  return
+}
+
+// -----
+
+// CHECK-LABEL: func @tanh_memref
+func @tanh_memref(%src: memref<10xf32>, %dst: memref<10xf32>) {
+  "xla_lhlo.tanh"(%src, %dst) : (memref<10xf32>, memref<10xf32>) -> ()
+  return
+}
+
+// -----
+
+// CHECK-LABEL: func @add_memref
+func @add_memref(%a: memref<10xf32>, %b: memref<10xf32>, %c: memref<10xf32>) {
+  "xla_lhlo.add"(%a, %b, %c) : (memref<10xf32>, memref<10xf32>, memref<10xf32>) -> ()
+  return
+}
+
+// -----
+
+// CHECK-LABEL: func @div_memref
+func @div_memref(%a: memref<10xf32>, %b: memref<10xf32>, %c: memref<10xf32>) {
+  "xla_lhlo.divide"(%a, %b, %c) : (memref<10xf32>, memref<10xf32>, memref<10xf32>) -> ()
+  return
+}
+
+// -----
+
+// CHECK-LABEL: func @max_memref
+func @max_memref(%a: memref<10xf32>, %b: memref<10xf32>, %c: memref<10xf32>) {
+  "xla_lhlo.maximum"(%a, %b, %c) : (memref<10xf32>, memref<10xf32>, memref<10xf32>) -> ()
+  return
+}
+
+// -----
+
+// CHECK-LABEL: func @min_memref
+func @min_memref(%a: memref<10xf32>, %b: memref<10xf32>, %c: memref<10xf32>) {
+  "xla_lhlo.minimum"(%a, %b, %c) : (memref<10xf32>, memref<10xf32>, memref<10xf32>) -> ()
+  return
+}
+
+// -----
+
+// CHECK-LABEL: func @mul_memref
+func @mul_memref(%a: memref<10xf32>, %b: memref<10xf32>, %c: memref<10xf32>) {
+  "xla_lhlo.multiply"(%a, %b, %c) : (memref<10xf32>, memref<10xf32>, memref<10xf32>) -> ()
+  return
+}
+
+// -----
+
+// CHECK-LABEL: func @sub_memref
+func @sub_memref(%a: memref<10xf32>, %b: memref<10xf32>, %c: memref<10xf32>) {
+  "xla_lhlo.subtract"(%a, %b, %c) : (memref<10xf32>, memref<10xf32>, memref<10xf32>) -> ()
+  return
+}
+
+// -----
+
+// CHECK-LABEL: func @and_memref
+func @and_memref(%a: memref<10xf32>, %b: memref<10xf32>, %c: memref<10xf32>) {
+  "xla_lhlo.and"(%a, %b, %c) : (memref<10xf32>, memref<10xf32>, memref<10xf32>) -> ()
+  return
+}
+
+// -----
+
+// CHECK-LABEL: func @broadcast_in_dim_memref
+func @broadcast_in_dim_memref(%operand: memref<1x2xi32>, %result: memref<1x2x2xi32>) {
+  "xla_lhlo.broadcast_in_dim"(%operand, %result) {broadcast_dimensions = dense<[1, 2]> : tensor<2xi64>} : (memref<1x2xi32>, memref<1x2x2xi32>) -> ()
+  return
+}
+
+// -----
+
+// CHECK-LABEL: func @broadcast_in_dim_zero_rank_memref
+func @broadcast_in_dim_zero_rank_memref(%scalar: memref<i32>, %result: memref<1x2x3xi32>) {
+  "xla_lhlo.broadcast_in_dim"(%scalar, %result) {broadcast_dimensions = dense<[]> : tensor<0xi64>} : (memref<i32>, memref<1x2x3xi32>) -> ()
+  return
+}
+
+// -----
+
+
+// CHECK-LABEL: func @reduce_memref
+func @reduce_memref(%data: memref<10xf32>, %seed: memref<f32>, %dst: memref<1xf32>) {
+  "xla_lhlo.reduce"(%data, %seed, %dst) ( {
+  ^bb0(%lhs: memref<f32>, %rhs: memref<f32>, %acc: memref<f32>):
+    "xla_lhlo.add"(%lhs, %rhs, %acc) : (memref<f32>, memref<f32>, memref<f32>) -> ()
+    "xla_lhlo.terminator"() : () -> ()
+  } ) {dimensions = dense<[0]> : tensor<1xi64>} : (memref<10xf32>, memref<f32>, memref<1xf32>) -> ()
+  return
+}
+
+// -----
+
+// @bondhugula: Disabled when adding LHLO to MLIR; it needs xla_hlo ops (not imported).
+// XCHECK-LABEL: func @fusion_memref
+// func @fusion_memref(%input1: memref<10xf32>, %input2: memref<10xf32>, %input3: memref<10xf32>, %out: memref<10xf32>) -> () {
+//  "xla_lhlo.fusion"() ( {
+//    %0 = tensor_load %input1 : memref<10xf32>
+//    %1 = tensor_load %input2 : memref<10xf32>
+//    %2 = "xla_hlo.add"(%0, %1) {name = "add"} : (tensor<10xf32>, tensor<10xf32>) -> tensor<10xf32>
+//    %3 = tensor_load %input3 : memref<10xf32>
+//    %4 = "xla_hlo.multiply"(%2, %3) {name = "multiply"} : (tensor<10xf32>, tensor<10xf32>) -> tensor<10xf32>
+//    tensor_store %4, %out : memref<10xf32>
+//    "xla_lhlo.terminator"() : () -> ()
+//  } ) : () -> ()
+//  return
+//}
+
+// -----
+
+// CHECK-LABEL: func @case_memref
+func @case_memref(%selector: memref<i32>, %in0: memref<f32>, %in1: memref<f32>, %in2: memref<f32>, %dst: memref<f32>) {
+  "xla_lhlo.case"(%selector, %in0, %in1, %in2, %dst) ( {
+    ^bb0(%x: memref<f32>):
+      "xla_lhlo.negate"(%x, %dst) : (memref<f32>, memref<f32>) -> ()
+      "xla_lhlo.terminator"() : () -> ()
+    },  {
+    ^bb0(%x: memref<f32>):
+      "xla_lhlo.copy"(%x, %dst) : (memref<f32>, memref<f32>) -> ()
+      "xla_lhlo.terminator"() : () -> ()
+    },  {
+    ^bb0(%x: memref<f32>):
+      "xla_lhlo.add"(%x, %x, %dst) : (memref<f32>, memref<f32>, memref<f32>) -> ()
+      "xla_lhlo.terminator"() : () -> ()
+    }
+  ) : (memref<i32>, memref<f32>, memref<f32>, memref<f32>, memref<f32>) -> ()
+  return
+}
+
+// -----
+
+// Test the affine scope trait of xla_lhlo.while. The while op encapsulates a
+// functional form of control flow while still being able to hold affine loop
+// nests in its regions.
+// CHECK-LABEL: func @while_op
+func @while_op(%arg0: memref<4x?x16xf32>, %arg1: memref<4x?x16xf32>) {
+    %c0_i32 = constant 0 : i32
+    %c4_i32 = constant 4 : i32
+    %2 = alloc() : memref<4xi32>
+    "xla_lhlo.rng_uniform"(%c0_i32, %c4_i32, %2) : (i32, i32, memref<4xi32>) -> ()
+    %c0_i32_0 = constant 0 : i32
+    %3 = "xla_lhlo.tuple"(%c0_i32_0, %2) : (i32, memref<4xi32>) -> tuple<i32, memref<4xi32>>
+    dealloc %2 : memref<4xi32>
+    %4 = "xla_lhlo.while"(%3) ( {
+    ^bb0(%arg2: tuple<i32, memref<4xi32>>):  // no predecessors
+      %7 = "xla_lhlo.get_tuple_element"(%arg2) {index = 0 : i32} : (tuple<i32, memref<4xi32>>) -> i32
+      %c4_i32_1 = constant 4 : i32
+      %8 = cmpi "slt", %7, %c4_i32_1 : i32
+      "xla_lhlo.yield"(%8) : (i1) -> ()
+    },  {
+    ^bb0(%arg2: tuple<i32, memref<4xi32>>):  // no predecessors
+      %7 = "xla_lhlo.get_tuple_element"(%arg2) {index = 0 : i32} : (tuple<i32, memref<4xi32>>) -> i32
+      %8 = "xla_lhlo.get_tuple_element"(%arg2) {index = 1 : i32} : (tuple<i32, memref<4xi32>>) -> memref<4xi32>
+      %idx = index_cast %7 : i32 to index
+      affine.for %i = 0 to 4 {
+        cmpi "eq", %idx, %i : index
+        // There should be no error from this.
+        affine.store %c0_i32, %8[%idx] : memref<4xi32>
+      }
+      "xla_lhlo.yield"(%arg2) : (tuple<i32, memref<4xi32>>) -> ()
+    }) : (tuple<i32, memref<4xi32>>) -> tuple<i32, memref<4xi32>>
+    return
+}


        


More information about the llvm-branch-commits mailing list