[llvm-branch-commits] [mlir] 28070f1 - xla_lhlo dialect fixes to work with upstream

Uday Bondhugula via llvm-branch-commits llvm-branch-commits at lists.llvm.org
Fri Nov 5 03:30:09 PDT 2021


Author: Uday Bondhugula
Date: 2021-09-22T14:28:52+05:30
New Revision: 28070f1be700f812dd228e51ecbe1efb3b4a73ea

URL: https://github.com/llvm/llvm-project/commit/28070f1be700f812dd228e51ecbe1efb3b4a73ea
DIFF: https://github.com/llvm/llvm-project/commit/28070f1be700f812dd228e51ecbe1efb3b4a73ea.diff

LOG: xla_lhlo dialect fixes to work with upstream

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/LHLO/IR/LHLOOps.h
    mlir/include/mlir/Dialect/LHLO/IR/LHLOOps.td
    mlir/include/mlir/InitAllDialects.h
    mlir/lib/Dialect/LHLO/IR/CMakeLists.txt
    mlir/lib/Dialect/LHLO/IR/LHLOOps.cc
    mlir/test/Dialect/LHLO/lhlo_ops.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/LHLO/IR/LHLOOps.h b/mlir/include/mlir/Dialect/LHLO/IR/LHLOOps.h
index 45309f5110ad..c17ee6d63bd4 100644
--- a/mlir/include/mlir/Dialect/LHLO/IR/LHLOOps.h
+++ b/mlir/include/mlir/Dialect/LHLO/IR/LHLOOps.h
@@ -20,28 +20,25 @@ limitations under the License.
 
 #include "llvm/ADT/StringRef.h"
 #include "mlir/IR/Attributes.h"  // from @llvm-project
+#include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
 #include "mlir/IR/Dialect.h"  // from @llvm-project
 #include "mlir/IR/Location.h"  // from @llvm-project
 #include "mlir/IR/MLIRContext.h"  // from @llvm-project
 #include "mlir/IR/OpDefinition.h"  // from @llvm-project
 #include "mlir/IR/Operation.h"  // from @llvm-project
-#include "mlir/IR/StandardTypes.h"  // from @llvm-project
 #include "mlir/IR/Types.h"  // from @llvm-project
 #include "mlir/Interfaces/SideEffectInterfaces.h"  // from @llvm-project
 
 namespace mlir {
 class OpBuilder;
+}  // end namespace mlir
 
 #include "mlir/Dialect/LHLO/IR/LHLOStructs.h.inc"
 
-namespace xla_lhlo {
-
 #include "mlir/Dialect/LHLO/IR/LHLOOpsDialect.h.inc"
 
 #define GET_OP_CLASSES
 #include "mlir/Dialect/LHLO/IR/LHLOOps.h.inc"
 
-}  // namespace xla_lhlo
-}  // end namespace mlir
 
 #endif  // MLIR_DIALECT_LHLO_IR_LHLO_OPS_H_

diff  --git a/mlir/include/mlir/Dialect/LHLO/IR/LHLOOps.td b/mlir/include/mlir/Dialect/LHLO/IR/LHLOOps.td
index 3ae266e82c3c..7f83425f138c 100644
--- a/mlir/include/mlir/Dialect/LHLO/IR/LHLOOps.td
+++ b/mlir/include/mlir/Dialect/LHLO/IR/LHLOOps.td
@@ -24,7 +24,7 @@ include "mlir/Dialect/LHLO/IR/HLOOpsBase.td"
 
 def LHLO_Dialect : Dialect {
   let name = "xla_lhlo";
-  let cppNamespace = "xla_lhlo";
+  let cppNamespace = "::mlir::xla_lhlo";
 }
 
 //===----------------------------------------------------------------------===//
@@ -382,7 +382,7 @@ def ConvDimensionNumbers : StructAttr<"ConvDimensionNumbers", LHLO_Dialect, [
   StructFieldAttr<"output_feature_dimension", I64Attr>,
   StructFieldAttr<"output_spatial_dimensions", I64ElementsAttr>] > {
 
-  let description = "Structure of dimension information for conv op";
+  let summary = "Structure of dimension information for conv op";
 }
 
 def LHLO_ConvOp : LHLO_Op<"convolution", []>, BASE_HLO_ConvOp {
@@ -427,7 +427,7 @@ def DotDimensionNumbers : StructAttr<"DotDimensionNumbers", LHLO_Dialect, [
                 StructFieldAttr<"lhs_contracting_dimensions", I64ElementsAttr>,
                 StructFieldAttr<"rhs_contracting_dimensions", I64ElementsAttr>
   ]> {
-  let description = "Structure of dimension information for dot product";
+  let summary = "Structure of dimension information for dot product";
 }
 
 def LHLO_DotGeneralOp : LHLO_Op<"dot_general", []>, BASE_HLO_DotGeneralOp {
@@ -452,7 +452,7 @@ def GatherDimensionNumbers: StructAttr<"GatherDimensionNumbers", LHLO_Dialect,
       StructFieldAttr<"collapsed_slice_dims", I64ElementsAttr>,
       StructFieldAttr<"start_index_map", I64ElementsAttr>,
       StructFieldAttr<"index_vector_dim", I64Attr>]> {
-  let description = "Structure of dimension information for gather";
+  let summary = "Structure of dimension information for gather";
 }
 
 def LHLO_GatherOp: LHLO_Op<"gather", []>, BASE_HLO_GatherOp {
@@ -489,7 +489,7 @@ def ScatterDimensionNumbers
       StructFieldAttr<"scatter_dims_to_operand_dims", I64ElementsAttr>,
       StructFieldAttr<"index_vector_dim", I64Attr>
     ]> {
-  let description = "Structure of dimension information for scatter";
+  let summary = "Structure of dimension information for scatter";
 }
 
 def LHLO_ScatterOp: LHLO_Op<"scatter", [RecursiveSideEffects]>,
@@ -597,9 +597,7 @@ def LHLO_TupleOp : LHLO_ReadOnlyOp<"tuple", [NoSideEffect]>, BASE_HLO_TupleOp {
   let arguments = (ins Arg<Variadic<LHLO_BufferOrIntOrFP>, "", [MemRead]>:$input);
   let results = (outs NestedTupleOf<[LHLO_BufferOrIntOrFP]>);
 
-  let builders = [OpBuilder<
-                  "OpBuilder &builder, OperationState &results, "
-                  "ValueRange values">];
+  let builders = [OpBuilder<(ins "ValueRange":$values)>];
 }
 
 def LHLO_WhileOp
@@ -642,8 +640,7 @@ def FusionOp : LHLO_Op<"fusion", [SingleBlockImplicitTerminator<"TerminatorOp">]
 
   let skipDefaultBuilders = 1;
   let builders = [
-     OpBuilder<"OpBuilder &builder, OperationState &result, "
-               "ArrayRef<NamedAttribute> attributes">
+     OpBuilder<(ins "ArrayRef<NamedAttribute>":$attributes)>
    ];
 }
 
@@ -653,9 +650,8 @@ def TerminatorOp :
   let description = [{
     Terminator operation for the LHLO dialect.
   }];
-  let builders = [OpBuilder<
-    "OpBuilder &b, OperationState &result, ValueRange operands",
-    [{ build(b, result, llvm::None, operands, llvm::None); }]
+  let builders = [OpBuilder<(ins "ValueRange":$operands),
+    [{ build($_builder, $_state, llvm::None, operands, llvm::None); }]
   >];
 }
 

diff  --git a/mlir/include/mlir/InitAllDialects.h b/mlir/include/mlir/InitAllDialects.h
index 36dd26ba817e..e9869365e54e 100644
--- a/mlir/include/mlir/InitAllDialects.h
+++ b/mlir/include/mlir/InitAllDialects.h
@@ -48,7 +48,6 @@
 
 namespace mlir {
 
-<<<<<<< HEAD
 /// Add all the MLIR dialects to the provided registry.
 inline void registerAllDialects(DialectRegistry &registry) {
   // clang-format off
@@ -80,34 +79,9 @@ inline void registerAllDialects(DialectRegistry &registry) {
                   sparse_tensor::SparseTensorDialect,
                   tensor::TensorDialect,
                   tosa::TosaDialect,
-                  x86vector::X86VectorDialect>();
+                  x86vector::X86VectorDialect,
+                  xla_lhlo::LHLODialect>();
   // clang-format on
-=======
-// This function should be called before creating any MLIRContext if one expect
-// all the possible dialects to be made available to the context automatically.
-inline void registerAllDialects() {
-  static bool init_once = []() {
-    registerDialect<AffineDialect>();
-    registerDialect<xla_lhlo::LHLODialect>();
-    registerDialect<avx512::AVX512Dialect>();
-    registerDialect<gpu::GPUDialect>();
-    registerDialect<LLVM::LLVMAVX512Dialect>();
-    registerDialect<LLVM::LLVMDialect>();
-    registerDialect<linalg::LinalgDialect>();
-    registerDialect<scf::SCFDialect>();
-    registerDialect<omp::OpenMPDialect>();
-    registerDialect<quant::QuantizationDialect>();
-    registerDialect<spirv::SPIRVDialect>();
-    registerDialect<StandardOpsDialect>();
-    registerDialect<vector::VectorDialect>();
-    registerDialect<NVVM::NVVMDialect>();
-    registerDialect<ROCDL::ROCDLDialect>();
-    registerDialect<SDBMDialect>();
-    registerDialect<shape::ShapeDialect>();
-    return true;
-  }();
-  (void)init_once;
->>>>>>> f0d77094085b... [MLIR] Add xla_lhlo dialect from tensorflow
 }
 
 /// Append all the MLIR dialects to the registry contained in the given context.

diff  --git a/mlir/lib/Dialect/LHLO/IR/CMakeLists.txt b/mlir/lib/Dialect/LHLO/IR/CMakeLists.txt
index d99e31c89729..32d8597aaa8c 100644
--- a/mlir/lib/Dialect/LHLO/IR/CMakeLists.txt
+++ b/mlir/lib/Dialect/LHLO/IR/CMakeLists.txt
@@ -10,5 +10,6 @@ add_mlir_dialect_library(MLIRLHLOOps
 
   LINK_LIBS PUBLIC
   MLIRIR
-  MLIRStandardOps
+  MLIRStandard
+  MLIRMemRef
   )

diff  --git a/mlir/lib/Dialect/LHLO/IR/LHLOOps.cc b/mlir/lib/Dialect/LHLO/IR/LHLOOps.cc
index 871872a22f68..26251379be8f 100644
--- a/mlir/lib/Dialect/LHLO/IR/LHLOOps.cc
+++ b/mlir/lib/Dialect/LHLO/IR/LHLOOps.cc
@@ -20,17 +20,12 @@ limitations under the License.
 #include <stddef.h>
 #include <stdint.h>
 
-#include "llvm/ADT/APFloat.h"
-#include "llvm/ADT/APInt.h"
-#include "llvm/ADT/ArrayRef.h"
-#include "llvm/ADT/STLExtras.h"
-#include "llvm/ADT/SmallVector.h"
-#include "llvm/ADT/StringRef.h"
-#include "llvm/Support/FormatVariadic.h"
-#include "mlir/Dialect/StandardOps/IR/Ops.h"
 #include "mlir/Dialect/LHLO/IR/LHLOOps.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/StandardOps/IR/Ops.h"
 #include "mlir/IR/Attributes.h"  // from @llvm-project
 #include "mlir/IR/Builders.h"  // from @llvm-project
+#include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
 #include "mlir/IR/Dialect.h"  // from @llvm-project
 #include "mlir/IR/Location.h"  // from @llvm-project
 #include "mlir/IR/MLIRContext.h"  // from @llvm-project
@@ -39,17 +34,25 @@ limitations under the License.
 #include "mlir/IR/Operation.h"  // from @llvm-project
 #include "mlir/IR/OperationSupport.h"  // from @llvm-project
 #include "mlir/IR/PatternMatch.h"  // from @llvm-project
-#include "mlir/IR/StandardTypes.h"  // from @llvm-project
 #include "mlir/IR/TypeUtilities.h"  // from @llvm-project
 #include "mlir/IR/Types.h"  // from @llvm-project
 #include "mlir/IR/Value.h"  // from @llvm-project
+#include "llvm/ADT/APFloat.h"
+#include "llvm/ADT/APInt.h"
+#include "llvm/ADT/ArrayRef.h"
+#include "llvm/ADT/STLExtras.h"
+#include "llvm/ADT/SmallVector.h"
+#include "llvm/ADT/StringRef.h"
+#include "llvm/Support/FormatVariadic.h"
 
-namespace mlir {
+
+using namespace mlir;
+using namespace mlir::xla_lhlo;
+
+#include "mlir/Dialect/LHLO/IR/LHLOOpsDialect.cpp.inc"
 #include "mlir/Dialect/LHLO/IR/LHLOStructs.cpp.inc"
-namespace xla_lhlo {
 
-LHLODialect::LHLODialect(MLIRContext *context)
-    : Dialect(getDialectNamespace(), context) {
+void LHLODialect::initialize() {
   addOperations<
 #define GET_OP_LIST
 #include "mlir/Dialect/LHLO/IR/LHLOOps.cpp.inc"
@@ -59,6 +62,9 @@ LHLODialect::LHLODialect(MLIRContext *context)
 #define GET_OP_CLASSES
 #include "mlir/Dialect/LHLO/IR/LHLOOps.cpp.inc"
 
+namespace mlir {
+namespace xla_lhlo {
+
 //===----------------------------------------------------------------------===//
 // ConstOp.
 //===----------------------------------------------------------------------===//
@@ -73,13 +79,14 @@ struct EraseConstOp : public OpRewritePattern<ConstOp> {
   LogicalResult matchAndRewrite(xla_lhlo::ConstOp op,
                                 PatternRewriter& rewriter) const override {
     Value memref = op.output();
-    if (!memref.getDefiningOp<AllocOp>()) {
+    if (!memref.getDefiningOp<memref::AllocOp>()) {
       return failure();
     }
 
     // Check that all uses of the memref are either DeallocOps or this op.
-    for (Operation* user : memref.getUsers())
-      if (user != op.getOperation() && !isa<DeallocOp>(user)) return failure();
+    for (Operation *user : memref.getUsers())
+      if (user != op.getOperation() && !isa<memref::DeallocOp>(user))
+        return failure();
 
     rewriter.eraseOp(op);
     return success();

diff  --git a/mlir/test/Dialect/LHLO/lhlo_ops.mlir b/mlir/test/Dialect/LHLO/lhlo_ops.mlir
index a288ccffc3e3..30a84c5b4bcb 100644
--- a/mlir/test/Dialect/LHLO/lhlo_ops.mlir
+++ b/mlir/test/Dialect/LHLO/lhlo_ops.mlir
@@ -210,11 +210,11 @@ func @case_memref(%index: memref<i32>, %operand_1: memref<f32>, %operand_2: memr
 func @while_op(%arg0: memref<4x?x16xf32>, %arg1: memref<4x?x16xf32>) {
     %c0_i32 = constant 0 : i32
     %c4_i32 = constant 4 : i32
-    %2 = alloc() : memref<4xi32>
+    %2 = memref.alloc() : memref<4xi32>
     "xla_lhlo.rng_uniform"(%c0_i32, %c4_i32, %2) : (i32, i32, memref<4xi32>) -> ()
     %c0_i32_0 = constant 0 : i32
     %3 = "xla_lhlo.tuple"(%c0_i32_0, %2) : (i32, memref<4xi32>) -> tuple<i32, memref<4xi32>>
-    dealloc %2 : memref<4xi32>
+    memref.dealloc %2 : memref<4xi32>
     %4 = "xla_lhlo.while"(%3) ( {
     ^bb0(%arg2: tuple<i32, memref<4xi32>>):  // no predecessors
       %7 = "xla_lhlo.get_tuple_element"(%arg2) {index = 0 : i32} : (tuple<i32, memref<4xi32>>) -> i32


        


More information about the llvm-branch-commits mailing list