[llvm-branch-commits] [mlir] 51de919 - Rename xla_lhlo dialect and namespace -> lmhlo

Uday Bondhugula via llvm-branch-commits llvm-branch-commits at lists.llvm.org
Fri Nov 5 03:30:18 PDT 2021


Author: Uday Bondhugula
Date: 2021-09-23T06:12:54+05:30
New Revision: 51de91957abebee3c6fdf870f1388b266ea011b1

URL: https://github.com/llvm/llvm-project/commit/51de91957abebee3c6fdf870f1388b266ea011b1
DIFF: https://github.com/llvm/llvm-project/commit/51de91957abebee3c6fdf870f1388b266ea011b1.diff

LOG: Rename xla_lhlo dialect and namespace -> lmhlo

This rename avoids a name clash in the Python bindings with the xla_lhlo
dialect in Monolith's TensorFlow. Although xla_lhlo was moved into MLIR
proper on our branches, Monolith's TensorFlow still contains it, so its
methods would be called instead of those in the MLIR library that Monolith links in.

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/LHLO/IR/CMakeLists.txt
    mlir/include/mlir/Dialect/LHLO/IR/LHLOOps.h
    mlir/include/mlir/Dialect/LHLO/IR/LHLOOps.td
    mlir/include/mlir/InitAllDialects.h
    mlir/lib/Dialect/LHLO/IR/LHLOOps.cc
    mlir/test/Dialect/LHLO/invalid.mlir
    mlir/test/Dialect/LHLO/lhlo_ops.mlir
    mlir/test/mlir-opt/commandline.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/LHLO/IR/CMakeLists.txt b/mlir/include/mlir/Dialect/LHLO/IR/CMakeLists.txt
index f5c4548ccf534..b852d42171bd7 100644
--- a/mlir/include/mlir/Dialect/LHLO/IR/CMakeLists.txt
+++ b/mlir/include/mlir/Dialect/LHLO/IR/CMakeLists.txt
@@ -1,4 +1,4 @@
-add_mlir_dialect(LHLOOps xla_lhlo)
+add_mlir_dialect(LHLOOps lmhlo)
 
 set(LLVM_TARGET_DEFINITIONS LHLOOps.td)
 mlir_tablegen(LHLOStructs.h.inc -gen-struct-attr-decls)

diff  --git a/mlir/include/mlir/Dialect/LHLO/IR/LHLOOps.h b/mlir/include/mlir/Dialect/LHLO/IR/LHLOOps.h
index c17ee6d63bd40..94a9de69cfd27 100644
--- a/mlir/include/mlir/Dialect/LHLO/IR/LHLOOps.h
+++ b/mlir/include/mlir/Dialect/LHLO/IR/LHLOOps.h
@@ -40,5 +40,4 @@ class OpBuilder;
 #define GET_OP_CLASSES
 #include "mlir/Dialect/LHLO/IR/LHLOOps.h.inc"
 
-
 #endif  // MLIR_DIALECT_LHLO_IR_LHLO_OPS_H_

diff  --git a/mlir/include/mlir/Dialect/LHLO/IR/LHLOOps.td b/mlir/include/mlir/Dialect/LHLO/IR/LHLOOps.td
index 711549940570c..0ee24446cad8b 100644
--- a/mlir/include/mlir/Dialect/LHLO/IR/LHLOOps.td
+++ b/mlir/include/mlir/Dialect/LHLO/IR/LHLOOps.td
@@ -23,8 +23,8 @@ include "mlir/Interfaces/SideEffectInterfaces.td"
 include "mlir/Dialect/LHLO/IR/HLOOpsBase.td"
 
 def LHLO_Dialect : Dialect {
-  let name = "xla_lhlo";
-  let cppNamespace = "::mlir::xla_lhlo";
+  let name = "lmhlo";
+  let cppNamespace = "::mlir::lmhlo";
 }
 
 //===----------------------------------------------------------------------===//
@@ -481,19 +481,19 @@ def LHLO_IfOp : LHLO_Op<"if", [AffineScope, RecursiveSideEffects]> {
 
     ```mlir
     func @lhlo_if(%arg0: memref<1x1x10xf32>, %arg1: memref<1x1x10xf32>, %arg2: memref<i1>) {
-      %0 = "xla_lhlo.tuple"(%arg0, %arg1) : (memref<1x1x10xf32>, memref<1x1x10xf32>) -> tuple<memref<1x1x10xf32>, memref<1x1x10xf32>>
-      %1 = "xla_lhlo.if"(%arg2, %0, %0) ( {
+      %0 = "lmhlo.tuple"(%arg0, %arg1) : (memref<1x1x10xf32>, memref<1x1x10xf32>) -> tuple<memref<1x1x10xf32>, memref<1x1x10xf32>>
+      %1 = "lmhlo.if"(%arg2, %0, %0) ( {
         ^bb0(%arg3: tuple<memref<1x1x10xf32>, memref<1x1x10xf32>>):
-        %2 = "xla_lhlo.get_tuple_element"(%arg3) {index = 0 : i32} : (tuple<memref<1x1x10xf32>, memref<1x1x10xf32>>) -> memref<1x1x10xf32>
-        %3 = "xla_lhlo.tuple"(%2) : (memref<1x1x10xf32>) -> tuple<memref<1x1x10xf32>>
-        "xla_lhlo.yield"(%3) : (tuple<memref<1x1x10xf32>>) -> ()
+        %2 = "lmhlo.get_tuple_element"(%arg3) {index = 0 : i32} : (tuple<memref<1x1x10xf32>, memref<1x1x10xf32>>) -> memref<1x1x10xf32>
+        %3 = "lmhlo.tuple"(%2) : (memref<1x1x10xf32>) -> tuple<memref<1x1x10xf32>>
+        "lmhlo.yield"(%3) : (tuple<memref<1x1x10xf32>>) -> ()
       },  {
         ^bb0(%arg3: tuple<memref<1x1x10xf32>, memref<1x1x10xf32>>):  // no predecessors
-        %2 = "xla_lhlo.get_tuple_element"(%arg3) {index = 1 : i32} : (tuple<memref<1x1x10xf32>, memref<1x1x10xf32>>) -> memref<1x1x10xf32>
-        %3 = "xla_lhlo.tuple"(%2) : (memref<1x1x10xf32>) -> tuple<memref<1x1x10xf32>>
-        "xla_lhlo.yield"(%3) : (tuple<memref<1x1x10xf32>>) -> ()
+        %2 = "lmhlo.get_tuple_element"(%arg3) {index = 1 : i32} : (tuple<memref<1x1x10xf32>, memref<1x1x10xf32>>) -> memref<1x1x10xf32>
+        %3 = "lmhlo.tuple"(%2) : (memref<1x1x10xf32>) -> tuple<memref<1x1x10xf32>>
+        "lmhlo.yield"(%3) : (tuple<memref<1x1x10xf32>>) -> ()
       }) : (memref<i1>, tuple<memref<1x1x10xf32>, memref<1x1x10xf32>>, tuple<memref<1x1x10xf32>, memref<1x1x10xf32>>) -> tuple<memref<1x1x10xf32>>
-      "xla_lhlo.terminator"() : () -> ()
+      "lmhlo.terminator"() : () -> ()
     }
     ```
 

diff  --git a/mlir/include/mlir/InitAllDialects.h b/mlir/include/mlir/InitAllDialects.h
index e9869365e54eb..43712b70d5362 100644
--- a/mlir/include/mlir/InitAllDialects.h
+++ b/mlir/include/mlir/InitAllDialects.h
@@ -80,7 +80,7 @@ inline void registerAllDialects(DialectRegistry &registry) {
                   tensor::TensorDialect,
                   tosa::TosaDialect,
                   x86vector::X86VectorDialect,
-                  xla_lhlo::LHLODialect>();
+                  lmhlo::LHLODialect>();
   // clang-format on
 }
 

diff  --git a/mlir/lib/Dialect/LHLO/IR/LHLOOps.cc b/mlir/lib/Dialect/LHLO/IR/LHLOOps.cc
index 27d2596f167e1..9824f70b004f6 100644
--- a/mlir/lib/Dialect/LHLO/IR/LHLOOps.cc
+++ b/mlir/lib/Dialect/LHLO/IR/LHLOOps.cc
@@ -47,7 +47,7 @@ limitations under the License.
 
 
 using namespace mlir;
-using namespace mlir::xla_lhlo;
+using namespace mlir::lmhlo;
 
 #include "mlir/Dialect/LHLO/IR/LHLOOpsDialect.cpp.inc"
 #include "mlir/Dialect/LHLO/IR/LHLOStructs.cpp.inc"
@@ -59,10 +59,10 @@ void LHLODialect::initialize() {
       >();
 }
 
-using xla_lhlo::ConstOp;
-using xla_lhlo::FusionOp;
-using xla_lhlo::GetTupleElementOp;
-using xla_lhlo::TupleOp;
+using lmhlo::ConstOp;
+using lmhlo::FusionOp;
+using lmhlo::GetTupleElementOp;
+using lmhlo::TupleOp;
 
 //===----------------------------------------------------------------------===//
 // ConstOp.
@@ -75,7 +75,7 @@ namespace {
 struct EraseConstOp : public OpRewritePattern<ConstOp> {
   using OpRewritePattern<ConstOp>::OpRewritePattern;
 
-  LogicalResult matchAndRewrite(xla_lhlo::ConstOp op,
+  LogicalResult matchAndRewrite(lmhlo::ConstOp op,
                                 PatternRewriter& rewriter) const override {
     Value memref = op.output();
     if (!memref.getDefiningOp<memref::AllocOp>()) {
@@ -94,7 +94,7 @@ struct EraseConstOp : public OpRewritePattern<ConstOp> {
 
 }  // end anonymous namespace
 
-void xla_lhlo::ConstOp::getCanonicalizationPatterns(
+void lmhlo::ConstOp::getCanonicalizationPatterns(
     OwningRewritePatternList& results, MLIRContext* context) {
   results.insert<EraseConstOp>(context);
 }

diff  --git a/mlir/test/Dialect/LHLO/invalid.mlir b/mlir/test/Dialect/LHLO/invalid.mlir
index 98e744c0167c1..5bff4c063f842 100644
--- a/mlir/test/Dialect/LHLO/invalid.mlir
+++ b/mlir/test/Dialect/LHLO/invalid.mlir
@@ -2,10 +2,10 @@
 
 func @passthrough(%arg : memref<8xi32>) {
   %c0_i32 = constant 0 : i32
-  %tuple = "xla_lhlo.tuple"(%c0_i32, %arg) : (i32, memref<8xi32>) -> (tuple<i32, memref<8xi32>>)
-  %elt = "xla_lhlo.get_tuple_element"(%tuple) {index = 0 : i32} : (tuple<i32, memref<8xi32>>) -> i32
-  // expected-error@+1{{'xla_lhlo.get_tuple_element' op has return type memref<8xi32>, but expected i32}}
-  %mem = "xla_lhlo.get_tuple_element"(%tuple) {index = 0 : i32} : (tuple<i32, memref<8xi32>>) -> memref<8xi32>
+  %tuple = "lmhlo.tuple"(%c0_i32, %arg) : (i32, memref<8xi32>) -> (tuple<i32, memref<8xi32>>)
+  %elt = "lmhlo.get_tuple_element"(%tuple) {index = 0 : i32} : (tuple<i32, memref<8xi32>>) -> i32
+  // expected-error@+1{{'lmhlo.get_tuple_element' op has return type memref<8xi32>, but expected i32}}
+  %mem = "lmhlo.get_tuple_element"(%tuple) {index = 0 : i32} : (tuple<i32, memref<8xi32>>) -> memref<8xi32>
   return
 }
 
@@ -14,16 +14,16 @@ func @passthrough(%arg : memref<8xi32>) {
 func @pass_wrong_number_of_arguments(%arg : memref<8xi32>){
     %c0 = constant 0 : i32
     %c1 = constant 1 : i32
-    // expected-error@+1{{'xla_lhlo.tuple' op has return type tuple<i32>, but expected tuple<i32, memref<8xi32>, i32>}}
-    %tuple = "xla_lhlo.tuple"(%c0, %arg, %c1) : (i32, memref<8xi32>, i32) -> (tuple<i32>)
+    // expected-error@+1{{'lmhlo.tuple' op has return type tuple<i32>, but expected tuple<i32, memref<8xi32>, i32>}}
+    %tuple = "lmhlo.tuple"(%c0, %arg, %c1) : (i32, memref<8xi32>, i32) -> (tuple<i32>)
 }
 
 // -----
 
 func @pass_wrong_type(%arg : i32){
     %c = constant 0 : i32
-    // expected-error@+1{{'xla_lhlo.tuple' op has return type tuple<i32, memref<8xi32>>, but expected tuple<i32, i32>}}
-    %tuple = "xla_lhlo.tuple"(%c, %arg) : (i32, i32) -> (tuple<i32, memref<8xi32>>)
+    // expected-error@+1{{'lmhlo.tuple' op has return type tuple<i32, memref<8xi32>>, but expected tuple<i32, i32>}}
+    %tuple = "lmhlo.tuple"(%c, %arg) : (i32, i32) -> (tuple<i32, memref<8xi32>>)
     return 
 }
 
@@ -31,7 +31,7 @@ func @pass_wrong_type(%arg : i32){
 
 func @pass_wrong_order(%arg : memref<8xi32>){
     %c = constant 0 : i32
-    // expected-error@+1{{'xla_lhlo.tuple' op has return type tuple<memref<8xi32>, i32>, but expected tuple<i32, memref<8xi32>>}}
-    %tuple = "xla_lhlo.tuple"(%c, %arg) : (i32, memref<8xi32>) -> (tuple<memref<8xi32>, i32>)
+    // expected-error@+1{{'lmhlo.tuple' op has return type tuple<memref<8xi32>, i32>, but expected tuple<i32, memref<8xi32>>}}
+    %tuple = "lmhlo.tuple"(%c, %arg) : (i32, memref<8xi32>) -> (tuple<memref<8xi32>, i32>)
     return 
 }
\ No newline at end of file

diff  --git a/mlir/test/Dialect/LHLO/lhlo_ops.mlir b/mlir/test/Dialect/LHLO/lhlo_ops.mlir
index 59ccfc8113be5..76cd68672a6d0 100644
--- a/mlir/test/Dialect/LHLO/lhlo_ops.mlir
+++ b/mlir/test/Dialect/LHLO/lhlo_ops.mlir
@@ -1,8 +1,8 @@
 // RUN: mlir-opt %s -verify-diagnostics -split-input-file | mlir-opt | FileCheck %s
 
 func @enforce_same_shape(%arg0: memref<1xf32>, %arg1: memref<2xf32>) -> () {
-  // expected-error@+1{{'xla_lhlo.tanh' op requires all operands to have the same or equivalent type}}
-  "xla_lhlo.tanh"(%arg0, %arg1) : (memref<1xf32>, memref<2xf32>) -> ()
+  // expected-error@+1{{'lmhlo.tanh' op requires all operands to have the same or equivalent type}}
+  "lmhlo.tanh"(%arg0, %arg1) : (memref<1xf32>, memref<2xf32>) -> ()
   return
 }
 
@@ -10,7 +10,7 @@ func @enforce_same_shape(%arg0: memref<1xf32>, %arg1: memref<2xf32>) -> () {
 
 // CHECK-LABEL: func @add_memrefs
 func @add_memrefs(%arg0: memref<1xi32>, %arg1: memref<1xi32>, %arg_out: memref<1xi32>) -> () {
-  "xla_lhlo.add"(%arg0, %arg1, %arg_out) : (memref<1xi32>, memref<1xi32>, memref<1xi32>) -> ()
+  "lmhlo.add"(%arg0, %arg1, %arg_out) : (memref<1xi32>, memref<1xi32>, memref<1xi32>) -> ()
   return
 }
 
@@ -18,7 +18,7 @@ func @add_memrefs(%arg0: memref<1xi32>, %arg1: memref<1xi32>, %arg_out: memref<1
 
 // CHECK-LABEL: func @abs_memref
 func @abs_memref(%in: memref<10xf32>, %out: memref<10xf32>) -> () {
-  "xla_lhlo.abs"(%in, %out) : (memref<10xf32>, memref<10xf32>) -> ()
+  "lmhlo.abs"(%in, %out) : (memref<10xf32>, memref<10xf32>) -> ()
   return
 }
 
@@ -26,7 +26,7 @@ func @abs_memref(%in: memref<10xf32>, %out: memref<10xf32>) -> () {
 
 // CHECK-LABEL: func @convert_memref
 func @convert_memref(%in: memref<10xf32>, %out: memref<10xf32>) -> () {
-  "xla_lhlo.convert"(%in, %out) : (memref<10xf32>, memref<10xf32>) -> ()
+  "lmhlo.convert"(%in, %out) : (memref<10xf32>, memref<10xf32>) -> ()
   return
 }
 
@@ -34,7 +34,7 @@ func @convert_memref(%in: memref<10xf32>, %out: memref<10xf32>) -> () {
 
 // CHECK-LABEL: func @exp_memref
 func @exp_memref(%in: memref<10xf32>, %out: memref<10xf32>) -> () {
-  "xla_lhlo.exponential"(%in, %out) : (memref<10xf32>, memref<10xf32>) -> ()
+  "lmhlo.exponential"(%in, %out) : (memref<10xf32>, memref<10xf32>) -> ()
   return
 }
 
@@ -42,7 +42,7 @@ func @exp_memref(%in: memref<10xf32>, %out: memref<10xf32>) -> () {
 
 // CHECK-LABEL: func @log_memref
 func @log_memref(%in: memref<10xf32>, %out: memref<10xf32>) -> () {
-  "xla_lhlo.log"(%in, %out) : (memref<10xf32>, memref<10xf32>) -> ()
+  "lmhlo.log"(%in, %out) : (memref<10xf32>, memref<10xf32>) -> ()
   return
 }
 
@@ -50,7 +50,7 @@ func @log_memref(%in: memref<10xf32>, %out: memref<10xf32>) -> () {
 
 // CHECK-LABEL: func @neg_memref
 func @neg_memref(%in: memref<10xf32>, %out: memref<10xf32>) -> () {
-  "xla_lhlo.negate"(%in, %out) : (memref<10xf32>, memref<10xf32>) -> ()
+  "lmhlo.negate"(%in, %out) : (memref<10xf32>, memref<10xf32>) -> ()
   return
 }
 
@@ -58,7 +58,7 @@ func @neg_memref(%in: memref<10xf32>, %out: memref<10xf32>) -> () {
 
 // CHECK-LABEL: func @rsqrt_memref
 func @rsqrt_memref(%in: memref<10xf32>, %out: memref<10xf32>) -> () {
-  "xla_lhlo.rsqrt"(%in, %out) : (memref<10xf32>, memref<10xf32>) -> ()
+  "lmhlo.rsqrt"(%in, %out) : (memref<10xf32>, memref<10xf32>) -> ()
   return
 }
 
@@ -66,7 +66,7 @@ func @rsqrt_memref(%in: memref<10xf32>, %out: memref<10xf32>) -> () {
 
 // CHECK-LABEL: func @sign_memref
 func @sign_memref(%in: memref<10xf32>, %out: memref<10xf32>) -> () {
-  "xla_lhlo.sign"(%in, %out) : (memref<10xf32>, memref<10xf32>) -> ()
+  "lmhlo.sign"(%in, %out) : (memref<10xf32>, memref<10xf32>) -> ()
   return
 }
 
@@ -74,7 +74,7 @@ func @sign_memref(%in: memref<10xf32>, %out: memref<10xf32>) -> () {
 
 // CHECK-LABEL: func @tanh_memref
 func @tanh_memref(%in: memref<10xf32>, %out: memref<10xf32>) -> () {
-  "xla_lhlo.tanh"(%in, %out) : (memref<10xf32>, memref<10xf32>) -> ()
+  "lmhlo.tanh"(%in, %out) : (memref<10xf32>, memref<10xf32>) -> ()
   return
 }
 
@@ -82,7 +82,7 @@ func @tanh_memref(%in: memref<10xf32>, %out: memref<10xf32>) -> () {
 
 // CHECK-LABEL: func @add_memref
 func @add_memref(%lhs: memref<10xf32>, %rhs: memref<10xf32>, %out: memref<10xf32>) -> () {
-  "xla_lhlo.add"(%lhs, %rhs, %out) : (memref<10xf32>, memref<10xf32>, memref<10xf32>) -> ()
+  "lmhlo.add"(%lhs, %rhs, %out) : (memref<10xf32>, memref<10xf32>, memref<10xf32>) -> ()
   return
 }
 
@@ -90,7 +90,7 @@ func @add_memref(%lhs: memref<10xf32>, %rhs: memref<10xf32>, %out: memref<10xf32
 
 // CHECK-LABEL: func @div_memref
 func @div_memref(%lhs: memref<10xf32>, %rhs: memref<10xf32>, %out: memref<10xf32>) -> () {
-  "xla_lhlo.divide"(%lhs, %rhs, %out) : (memref<10xf32>, memref<10xf32>, memref<10xf32>) -> ()
+  "lmhlo.divide"(%lhs, %rhs, %out) : (memref<10xf32>, memref<10xf32>, memref<10xf32>) -> ()
   return
 }
 
@@ -98,7 +98,7 @@ func @div_memref(%lhs: memref<10xf32>, %rhs: memref<10xf32>, %out: memref<10xf32
 
 // CHECK-LABEL: func @max_memref
 func @max_memref(%lhs: memref<10xf32>, %rhs: memref<10xf32>, %out: memref<10xf32>) -> () {
-  "xla_lhlo.maximum"(%lhs, %rhs, %out) : (memref<10xf32>, memref<10xf32>, memref<10xf32>) -> ()
+  "lmhlo.maximum"(%lhs, %rhs, %out) : (memref<10xf32>, memref<10xf32>, memref<10xf32>) -> ()
   return
 }
 
@@ -106,7 +106,7 @@ func @max_memref(%lhs: memref<10xf32>, %rhs: memref<10xf32>, %out: memref<10xf32
 
 // CHECK-LABEL: func @min_memref
 func @min_memref(%lhs: memref<10xf32>, %rhs: memref<10xf32>, %out: memref<10xf32>) -> () {
-  "xla_lhlo.minimum"(%lhs, %rhs, %out) : (memref<10xf32>, memref<10xf32>, memref<10xf32>) -> ()
+  "lmhlo.minimum"(%lhs, %rhs, %out) : (memref<10xf32>, memref<10xf32>, memref<10xf32>) -> ()
   return
 }
 
@@ -114,7 +114,7 @@ func @min_memref(%lhs: memref<10xf32>, %rhs: memref<10xf32>, %out: memref<10xf32
 
 // CHECK-LABEL: func @mul_memref
 func @mul_memref(%lhs: memref<10xf32>, %rhs: memref<10xf32>, %out: memref<10xf32>) -> () {
-  "xla_lhlo.multiply"(%lhs, %rhs, %out) : (memref<10xf32>, memref<10xf32>, memref<10xf32>) -> ()
+  "lmhlo.multiply"(%lhs, %rhs, %out) : (memref<10xf32>, memref<10xf32>, memref<10xf32>) -> ()
   return
 }
 
@@ -122,7 +122,7 @@ func @mul_memref(%lhs: memref<10xf32>, %rhs: memref<10xf32>, %out: memref<10xf32
 
 // CHECK-LABEL: func @sub_memref
 func @sub_memref(%lhs: memref<10xf32>, %rhs: memref<10xf32>, %out: memref<10xf32>) -> () {
-  "xla_lhlo.subtract"(%lhs, %rhs, %out) : (memref<10xf32>, memref<10xf32>, memref<10xf32>) -> ()
+  "lmhlo.subtract"(%lhs, %rhs, %out) : (memref<10xf32>, memref<10xf32>, memref<10xf32>) -> ()
   return
 }
 
@@ -130,7 +130,7 @@ func @sub_memref(%lhs: memref<10xf32>, %rhs: memref<10xf32>, %out: memref<10xf32
 
 // CHECK-LABEL: func @and_memref
 func @and_memref(%lhs: memref<10xf32>, %rhs: memref<10xf32>, %out: memref<10xf32>) -> () {
-  "xla_lhlo.and"(%lhs, %rhs, %out) : (memref<10xf32>, memref<10xf32>, memref<10xf32>) -> ()
+  "lmhlo.and"(%lhs, %rhs, %out) : (memref<10xf32>, memref<10xf32>, memref<10xf32>) -> ()
   return
 }
 
@@ -138,7 +138,7 @@ func @and_memref(%lhs: memref<10xf32>, %rhs: memref<10xf32>, %out: memref<10xf32
 
 // CHECK-LABEL: func @broadcast_in_dim_memref
 func @broadcast_in_dim_memref(%arg0: memref<1x2xi32>, %out: memref<1x2x2xi32>) -> () {
-  "xla_lhlo.broadcast_in_dim"(%arg0, %out) {broadcast_dimensions = dense<[1, 2]> : tensor<2xi64>} : (memref<1x2xi32>, memref<1x2x2xi32>) -> ()
+  "lmhlo.broadcast_in_dim"(%arg0, %out) {broadcast_dimensions = dense<[1, 2]> : tensor<2xi64>} : (memref<1x2xi32>, memref<1x2x2xi32>) -> ()
   return
 }
 
@@ -146,7 +146,7 @@ func @broadcast_in_dim_memref(%arg0: memref<1x2xi32>, %out: memref<1x2x2xi32>) -
 
 // CHECK-LABEL: func @broadcast_in_dim_zero_rank_memref
 func @broadcast_in_dim_zero_rank_memref(%arg0: memref<i32>, %out: memref<1x2x3xi32>) -> () {
-  "xla_lhlo.broadcast_in_dim"(%arg0, %out) {broadcast_dimensions = dense<[]> : tensor<0xi64>} : (memref<i32>, memref<1x2x3xi32>) -> ()
+  "lmhlo.broadcast_in_dim"(%arg0, %out) {broadcast_dimensions = dense<[]> : tensor<0xi64>} : (memref<i32>, memref<1x2x3xi32>) -> ()
   return
 }
 
@@ -155,10 +155,10 @@ func @broadcast_in_dim_zero_rank_memref(%arg0: memref<i32>, %out: memref<1x2x3xi
 
 // CHECK-LABEL: func @reduce_memref
 func @reduce_memref(%input: memref<10xf32>, %init: memref<f32>, %out: memref<1xf32>) -> () {
-  "xla_lhlo.reduce"(%input, %init, %out) ( {
+  "lmhlo.reduce"(%input, %init, %out) ( {
   ^bb0(%arg1: memref<f32>, %arg2: memref<f32>, %result: memref<f32>):
-    "xla_lhlo.add"(%arg1, %arg2, %result) : (memref<f32>, memref<f32>, memref<f32>) -> ()
-    "xla_lhlo.terminator"() : () -> ()
+    "lmhlo.add"(%arg1, %arg2, %result) : (memref<f32>, memref<f32>, memref<f32>) -> ()
+    "lmhlo.terminator"() : () -> ()
   } ) {dimensions = dense<[0]> : tensor<1xi64>} : (memref<10xf32>, memref<f32>, memref<1xf32>) -> ()
   return
 }
@@ -168,14 +168,14 @@ func @reduce_memref(%input: memref<10xf32>, %init: memref<f32>, %out: memref<1xf
 // @bondhugula: Disabled when adding LHLO to MLIR.
 // XCHECK-LABEL: func @fusion_memref
 // func @fusion_memref(%input1: memref<10xf32>, %input2: memref<10xf32>, %input3: memref<10xf32>, %out: memref<10xf32>) -> () {
-//  "xla_lhlo.fusion"() ( {
+//  "lmhlo.fusion"() ( {
 //    %0 = tensor_load %input1 : memref<10xf32>
 //    %1 = tensor_load %input2 : memref<10xf32>
 //    %2 = "xla_hlo.add"(%0, %1) {name = "add"} : (tensor<10xf32>, tensor<10xf32>) -> tensor<10xf32>
 //    %3 = tensor_load %input3 : memref<10xf32>
 //    %4 = "xla_hlo.multiply"(%2, %3) {name = "multiply"} : (tensor<10xf32>, tensor<10xf32>) -> tensor<10xf32>
 //    tensor_store %4, %out : memref<10xf32>
-//    "xla_lhlo.terminator"() : () -> ()
+//    "lmhlo.terminator"() : () -> ()
 //  } ) : () -> ()
 //  return
 //}
@@ -184,18 +184,18 @@ func @reduce_memref(%input: memref<10xf32>, %init: memref<f32>, %out: memref<1xf
 
 // CHECK-LABEL: func @case_memref
 func @case_memref(%index: memref<i32>, %operand_1: memref<f32>, %operand_2: memref<f32>, %operand_3: memref<f32>, %out: memref<f32>) -> () {
-  "xla_lhlo.case"(%index, %operand_1, %operand_2, %operand_3, %out) ( {
+  "lmhlo.case"(%index, %operand_1, %operand_2, %operand_3, %out) ( {
     ^bb0(%arg0: memref<f32>):
-      "xla_lhlo.negate"(%arg0, %out) : (memref<f32>, memref<f32>) -> ()
-      "xla_lhlo.terminator"() : () -> ()
+      "lmhlo.negate"(%arg0, %out) : (memref<f32>, memref<f32>) -> ()
+      "lmhlo.terminator"() : () -> ()
     },  {
     ^bb0(%arg0: memref<f32>):
-      "xla_lhlo.copy"(%arg0, %out) : (memref<f32>, memref<f32>) -> ()
-      "xla_lhlo.terminator"() : () -> ()
+      "lmhlo.copy"(%arg0, %out) : (memref<f32>, memref<f32>) -> ()
+      "lmhlo.terminator"() : () -> ()
     },  {
     ^bb0(%arg0: memref<f32>):
-      "xla_lhlo.add"(%arg0, %arg0, %out) : (memref<f32>, memref<f32>, memref<f32>) -> ()
-      "xla_lhlo.terminator"() : () -> ()
+      "lmhlo.add"(%arg0, %arg0, %out) : (memref<f32>, memref<f32>, memref<f32>) -> ()
+      "lmhlo.terminator"() : () -> ()
     }
   ) : (memref<i32>, memref<f32>, memref<f32>, memref<f32>, memref<f32>) -> ()
   return
@@ -203,7 +203,7 @@ func @case_memref(%index: memref<i32>, %operand_1: memref<f32>, %operand_2: memr
 
 // -----
 
-// Test xla_lhlo.while op's affine scope trait. The while op encapsutes a
+// Test lmhlo.while op's affine scope trait. The while op encapsutes a
 // functional form of control flow while being able to model affine loop nests
 // in their regions.
 
@@ -211,64 +211,64 @@ func @while_op(%arg0: memref<4x?x16xf32>, %arg1: memref<4x?x16xf32>) {
     %c0_i32 = constant 0 : i32
     %c4_i32 = constant 4 : i32
     %2 = memref.alloc() : memref<4xi32>
-    "xla_lhlo.rng_uniform"(%c0_i32, %c4_i32, %2) : (i32, i32, memref<4xi32>) -> ()
+    "lmhlo.rng_uniform"(%c0_i32, %c4_i32, %2) : (i32, i32, memref<4xi32>) -> ()
     %c0_i32_0 = constant 0 : i32
-    %3 = "xla_lhlo.tuple"(%c0_i32_0, %2) : (i32, memref<4xi32>) -> tuple<i32, memref<4xi32>>
+    %3 = "lmhlo.tuple"(%c0_i32_0, %2) : (i32, memref<4xi32>) -> tuple<i32, memref<4xi32>>
     memref.dealloc %2 : memref<4xi32>
-    %4 = "xla_lhlo.while"(%3) ( {
+    %4 = "lmhlo.while"(%3) ( {
     ^bb0(%arg2: tuple<i32, memref<4xi32>>):  // no predecessors
-      %7 = "xla_lhlo.get_tuple_element"(%arg2) {index = 0 : i32} : (tuple<i32, memref<4xi32>>) -> i32
+      %7 = "lmhlo.get_tuple_element"(%arg2) {index = 0 : i32} : (tuple<i32, memref<4xi32>>) -> i32
       %c4_i32_1 = constant 4 : i32
       %8 = cmpi "slt", %7, %c4_i32_1 : i32
-      "xla_lhlo.yield"(%8) : (i1) -> ()
+      "lmhlo.yield"(%8) : (i1) -> ()
     },  {
     ^bb0(%arg2: tuple<i32, memref<4xi32>>):  // no predecessors
-      %7 = "xla_lhlo.get_tuple_element"(%arg2) {index = 0 : i32} : (tuple<i32, memref<4xi32>>) -> i32
-      %8 = "xla_lhlo.get_tuple_element"(%arg2) {index = 1 : i32} : (tuple<i32, memref<4xi32>>) -> memref<4xi32>
+      %7 = "lmhlo.get_tuple_element"(%arg2) {index = 0 : i32} : (tuple<i32, memref<4xi32>>) -> i32
+      %8 = "lmhlo.get_tuple_element"(%arg2) {index = 1 : i32} : (tuple<i32, memref<4xi32>>) -> memref<4xi32>
       %idx = index_cast %7 : i32 to index
       affine.for %i = 0 to 4 {
         cmpi "eq", %idx, %i : index
         // There should be no error from this.
         affine.store %c0_i32, %8[%idx] : memref<4xi32>
       }
-      "xla_lhlo.yield"(%arg2) : (tuple<i32, memref<4xi32>>) -> ()
+      "lmhlo.yield"(%arg2) : (tuple<i32, memref<4xi32>>) -> ()
     }) : (tuple<i32, memref<4xi32>>) -> tuple<i32, memref<4xi32>>
-    "xla_lhlo.terminator"() : () -> ()
+    "lmhlo.terminator"() : () -> ()
 }
 
 // -----
 
 func @lhlo_if(%arg0: memref<1x1x10xf32>, %arg1: memref<1x1x10xf32>, %arg2: memref<i1>) {
-  %0 = "xla_lhlo.tuple"(%arg0, %arg1) : (memref<1x1x10xf32>, memref<1x1x10xf32>) -> tuple<memref<1x1x10xf32>, memref<1x1x10xf32>>
-  // CHECK: xla_lhlo.if
-  %1 = "xla_lhlo.if"(%arg2, %0, %0) ( {
+  %0 = "lmhlo.tuple"(%arg0, %arg1) : (memref<1x1x10xf32>, memref<1x1x10xf32>) -> tuple<memref<1x1x10xf32>, memref<1x1x10xf32>>
+  // CHECK: lmhlo.if
+  %1 = "lmhlo.if"(%arg2, %0, %0) ( {
     ^bb0(%arg3: tuple<memref<1x1x10xf32>, memref<1x1x10xf32>>):
-    %2 = "xla_lhlo.get_tuple_element"(%arg3) {index = 0 : i32} : (tuple<memref<1x1x10xf32>, memref<1x1x10xf32>>) -> memref<1x1x10xf32>
-    %3 = "xla_lhlo.tuple"(%2) : (memref<1x1x10xf32>) -> tuple<memref<1x1x10xf32>>
-    "xla_lhlo.yield"(%3) : (tuple<memref<1x1x10xf32>>) -> ()
+    %2 = "lmhlo.get_tuple_element"(%arg3) {index = 0 : i32} : (tuple<memref<1x1x10xf32>, memref<1x1x10xf32>>) -> memref<1x1x10xf32>
+    %3 = "lmhlo.tuple"(%2) : (memref<1x1x10xf32>) -> tuple<memref<1x1x10xf32>>
+    "lmhlo.yield"(%3) : (tuple<memref<1x1x10xf32>>) -> ()
   },  {
     ^bb0(%arg3: tuple<memref<1x1x10xf32>, memref<1x1x10xf32>>):  // no predecessors
-    %2 = "xla_lhlo.get_tuple_element"(%arg3) {index = 1 : i32} : (tuple<memref<1x1x10xf32>, memref<1x1x10xf32>>) -> memref<1x1x10xf32>
-    %3 = "xla_lhlo.tuple"(%2) : (memref<1x1x10xf32>) -> tuple<memref<1x1x10xf32>>
-    "xla_lhlo.yield"(%3) : (tuple<memref<1x1x10xf32>>) -> ()
+    %2 = "lmhlo.get_tuple_element"(%arg3) {index = 1 : i32} : (tuple<memref<1x1x10xf32>, memref<1x1x10xf32>>) -> memref<1x1x10xf32>
+    %3 = "lmhlo.tuple"(%2) : (memref<1x1x10xf32>) -> tuple<memref<1x1x10xf32>>
+    "lmhlo.yield"(%3) : (tuple<memref<1x1x10xf32>>) -> ()
   }) : (memref<i1>, tuple<memref<1x1x10xf32>, memref<1x1x10xf32>>, tuple<memref<1x1x10xf32>, memref<1x1x10xf32>>) -> tuple<memref<1x1x10xf32>>
-  "xla_lhlo.terminator"() : () -> ()
+  "lmhlo.terminator"() : () -> ()
 }
 
 // CHECK-LABEL: func @lhlo_if_empty_arg
 func @lhlo_if_empty_arg(%arg0: memref<i1>) {
   %cst = constant 1.000000e+00 : f32
   %cst_0 = constant 0.000000e+00 : f32
-  %0 = "xla_lhlo.tuple"() : () -> tuple<>
-  // CHECK: xla_lhlo.if
-  %1 = "xla_lhlo.if"(%arg0, %0, %0) ( {
+  %0 = "lmhlo.tuple"() : () -> tuple<>
+  // CHECK: lmhlo.if
+  %1 = "lmhlo.if"(%arg0, %0, %0) ( {
     ^bb0(%arg1: tuple<>):
-    %2 = "xla_lhlo.tuple"(%cst, %cst_0) : (f32, f32) -> tuple<f32, f32>
-    "xla_lhlo.yield"(%2) : (tuple<f32, f32>) -> ()
+    %2 = "lmhlo.tuple"(%cst, %cst_0) : (f32, f32) -> tuple<f32, f32>
+    "lmhlo.yield"(%2) : (tuple<f32, f32>) -> ()
   },  {
     ^bb0(%arg1: tuple<>):
-    %2 = "xla_lhlo.tuple"(%cst_0, %cst) : (f32, f32) -> tuple<f32, f32>
-    "xla_lhlo.yield"(%2) : (tuple<f32, f32>) -> ()
+    %2 = "lmhlo.tuple"(%cst_0, %cst) : (f32, f32) -> tuple<f32, f32>
+    "lmhlo.yield"(%2) : (tuple<f32, f32>) -> ()
   }) : (memref<i1>, tuple<>, tuple<>) -> tuple<f32, f32>
-  "xla_lhlo.terminator"() : () -> ()
+  "lmhlo.terminator"() : () -> ()
 }

diff  --git a/mlir/test/mlir-opt/commandline.mlir b/mlir/test/mlir-opt/commandline.mlir
index 75fcdf5314651..61c43dda22362 100644
--- a/mlir/test/mlir-opt/commandline.mlir
+++ b/mlir/test/mlir-opt/commandline.mlir
@@ -13,6 +13,7 @@
 // CHECK-NEXT: gpu
 // CHECK-NEXT: linalg
 // CHECK-NEXT: llvm
+// CHECK-NEXT: lmhlo
 // CHECK-NEXT: math
 // CHECK-NEXT: memref
 // CHECK-NEXT: nvvm


        


More information about the llvm-branch-commits mailing list