[Mlir-commits] [mlir] 1d6a724 - [MLIR] Change FunctionType::get() and TupleType::get() to use TypeRange

Rahul Joshi llvmlistbot at llvm.org
Tue Aug 4 12:44:07 PDT 2020


Author: Rahul Joshi
Date: 2020-08-04T12:43:40-07:00
New Revision: 1d6a724aa1c11a37ff083cf637f91852e96ce11f

URL: https://github.com/llvm/llvm-project/commit/1d6a724aa1c11a37ff083cf637f91852e96ce11f
DIFF: https://github.com/llvm/llvm-project/commit/1d6a724aa1c11a37ff083cf637f91852e96ce11f.diff

LOG: [MLIR] Change FunctionType::get() and TupleType::get() to use TypeRange

- Moved TypeRange into its own header/cpp file, and add hashing support.
- Change FunctionType::get() and TupleType::get() to use TypeRange

Differential Revision: https://reviews.llvm.org/D85075

Added: 
    mlir/include/mlir/IR/TypeRange.h
    mlir/lib/IR/TypeRange.cpp

Modified: 
    flang/lib/Lower/RTBuilder.h
    mlir/include/mlir/IR/Builders.h
    mlir/include/mlir/IR/OperationSupport.h
    mlir/include/mlir/IR/StandardTypes.h
    mlir/include/mlir/IR/Types.h
    mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
    mlir/lib/Dialect/StandardOps/IR/Ops.cpp
    mlir/lib/IR/Builders.cpp
    mlir/lib/IR/CMakeLists.txt
    mlir/lib/IR/OperationSupport.cpp
    mlir/lib/IR/StandardTypes.cpp
    mlir/lib/IR/TypeDetail.h
    mlir/lib/IR/Types.cpp

Removed: 
    


################################################################################
diff  --git a/flang/lib/Lower/RTBuilder.h b/flang/lib/Lower/RTBuilder.h
index 2f66fa8efac0..3855f6816d6e 100644
--- a/flang/lib/Lower/RTBuilder.h
+++ b/flang/lib/Lower/RTBuilder.h
@@ -168,7 +168,7 @@ constexpr TypeBuilderFunc getModel<const Fortran::runtime::NamelistGroup &>() {
   return [](mlir::MLIRContext *context) -> mlir::Type {
     // FIXME: a namelist group must be some well-defined data structure, use a
     // tuple as a proxy for the moment
-    return mlir::TupleType::get(llvm::None, context);
+    return mlir::TupleType::get(context);
   };
 }
 template <>

diff  --git a/mlir/include/mlir/IR/Builders.h b/mlir/include/mlir/IR/Builders.h
index 4256727905f5..c27585a6e343 100644
--- a/mlir/include/mlir/IR/Builders.h
+++ b/mlir/include/mlir/IR/Builders.h
@@ -73,8 +73,8 @@ class Builder {
   IntegerType getI64Type();
   IntegerType getIntegerType(unsigned width);
   IntegerType getIntegerType(unsigned width, bool isSigned);
-  FunctionType getFunctionType(ArrayRef<Type> inputs, ArrayRef<Type> results);
-  TupleType getTupleType(ArrayRef<Type> elementTypes);
+  FunctionType getFunctionType(TypeRange inputs, TypeRange results);
+  TupleType getTupleType(TypeRange elementTypes);
   NoneType getNoneType();
 
   /// Get or construct an instance of the type 'ty' with provided arguments.

diff  --git a/mlir/include/mlir/IR/OperationSupport.h b/mlir/include/mlir/IR/OperationSupport.h
index e3afaf316154..23a37cc2e2a9 100644
--- a/mlir/include/mlir/IR/OperationSupport.h
+++ b/mlir/include/mlir/IR/OperationSupport.h
@@ -17,6 +17,7 @@
 #include "mlir/IR/Attributes.h"
 #include "mlir/IR/Identifier.h"
 #include "mlir/IR/Location.h"
+#include "mlir/IR/TypeRange.h"
 #include "mlir/IR/Types.h"
 #include "mlir/IR/Value.h"
 #include "mlir/Support/InterfaceSupport.h"
@@ -624,104 +625,6 @@ class OpPrintingFlags {
 // Operation Value-Iterators
 //===----------------------------------------------------------------------===//
 
-//===----------------------------------------------------------------------===//
-// TypeRange
-
-/// This class provides an abstraction over the various different ranges of
-/// value types. In many cases, this prevents the need to explicitly materialize
-/// a SmallVector/std::vector. This class should be used in places that are not
-/// suitable for a more derived type (e.g. ArrayRef) or a template range
-/// parameter.
-class TypeRange
-    : public llvm::detail::indexed_accessor_range_base<
-          TypeRange,
-          llvm::PointerUnion<const Value *, const Type *, OpOperand *>, Type,
-          Type, Type> {
-public:
-  using RangeBaseT::RangeBaseT;
-  TypeRange(ArrayRef<Type> types = llvm::None);
-  explicit TypeRange(OperandRange values);
-  explicit TypeRange(ResultRange values);
-  explicit TypeRange(ValueRange values);
-  explicit TypeRange(ArrayRef<Value> values);
-  explicit TypeRange(ArrayRef<BlockArgument> values)
-      : TypeRange(ArrayRef<Value>(values.data(), values.size())) {}
-  template <typename ValueRangeT>
-  TypeRange(ValueTypeRange<ValueRangeT> values)
-      : TypeRange(ValueRangeT(values.begin().getCurrent(),
-                              values.end().getCurrent())) {}
-  template <typename Arg,
-            typename = typename std::enable_if_t<
-                std::is_constructible<ArrayRef<Type>, Arg>::value>>
-  TypeRange(Arg &&arg) : TypeRange(ArrayRef<Type>(std::forward<Arg>(arg))) {}
-  TypeRange(std::initializer_list<Type> types)
-      : TypeRange(ArrayRef<Type>(types)) {}
-
-private:
-  /// The owner of the range is either:
-  /// * A pointer to the first element of an array of values.
-  /// * A pointer to the first element of an array of types.
-  /// * A pointer to the first element of an array of operands.
-  using OwnerT = llvm::PointerUnion<const Value *, const Type *, OpOperand *>;
-
-  /// See `llvm::detail::indexed_accessor_range_base` for details.
-  static OwnerT offset_base(OwnerT object, ptrdiff_t index);
-  /// See `llvm::detail::indexed_accessor_range_base` for details.
-  static Type dereference_iterator(OwnerT object, ptrdiff_t index);
-
-  /// Allow access to `offset_base` and `dereference_iterator`.
-  friend RangeBaseT;
-};
-
-//===----------------------------------------------------------------------===//
-// ValueTypeRange
-
-/// This class implements iteration on the types of a given range of values.
-template <typename ValueIteratorT>
-class ValueTypeIterator final
-    : public llvm::mapped_iterator<ValueIteratorT, Type (*)(Value)> {
-  static Type unwrap(Value value) { return value.getType(); }
-
-public:
-  using reference = Type;
-
-  /// Provide a const dereference method.
-  Type operator*() const { return unwrap(*this->I); }
-
-  /// Initializes the type iterator to the specified value iterator.
-  ValueTypeIterator(ValueIteratorT it)
-      : llvm::mapped_iterator<ValueIteratorT, Type (*)(Value)>(it, &unwrap) {}
-};
-
-/// This class implements iteration on the types of a given range of values.
-template <typename ValueRangeT>
-class ValueTypeRange final
-    : public llvm::iterator_range<
-          ValueTypeIterator<typename ValueRangeT::iterator>> {
-public:
-  using llvm::iterator_range<
-      ValueTypeIterator<typename ValueRangeT::iterator>>::iterator_range;
-  template <typename Container>
-  ValueTypeRange(Container &&c) : ValueTypeRange(c.begin(), c.end()) {}
-
-  /// Compare this range with another.
-  template <typename OtherT>
-  bool operator==(const OtherT &other) const {
-    return llvm::size(*this) == llvm::size(other) &&
-           std::equal(this->begin(), this->end(), other.begin());
-  }
-  template <typename OtherT>
-  bool operator!=(const OtherT &other) const {
-    return !(*this == other);
-  }
-};
-
-template <typename RangeT>
-inline bool operator==(ArrayRef<Type> lhs, const ValueTypeRange<RangeT> &rhs) {
-  return lhs.size() == static_cast<size_t>(llvm::size(rhs)) &&
-         std::equal(lhs.begin(), lhs.end(), rhs.begin());
-}
-
 //===----------------------------------------------------------------------===//
 // OperandRange
 

diff  --git a/mlir/include/mlir/IR/StandardTypes.h b/mlir/include/mlir/IR/StandardTypes.h
index 1ac24359cbb6..3daf226603a8 100644
--- a/mlir/include/mlir/IR/StandardTypes.h
+++ b/mlir/include/mlir/IR/StandardTypes.h
@@ -632,10 +632,10 @@ class TupleType
 
   /// Get or create a new TupleType with the provided element types. Assumes the
   /// arguments define a well-formed type.
-  static TupleType get(ArrayRef<Type> elementTypes, MLIRContext *context);
+  static TupleType get(TypeRange elementTypes, MLIRContext *context);
 
   /// Get or create an empty tuple type.
-  static TupleType get(MLIRContext *context) { return get({}, context); }
+  static TupleType get(MLIRContext *context);
 
   /// Return the elements types for this tuple.
   ArrayRef<Type> getTypes() const;

diff  --git a/mlir/include/mlir/IR/TypeRange.h b/mlir/include/mlir/IR/TypeRange.h
new file mode 100644
index 000000000000..8e41ad1665f9
--- /dev/null
+++ b/mlir/include/mlir/IR/TypeRange.h
@@ -0,0 +1,181 @@
+//===- TypeRange.h ----------------------------------------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file defines the TypeRange and ValueTypeRange classes.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_IR_TYPERANGE_H
+#define MLIR_IR_TYPERANGE_H
+
+#include "mlir/IR/Types.h"
+#include "mlir/IR/Value.h"
+#include "llvm/ADT/PointerUnion.h"
+
+namespace mlir {
+class OperandRange;
+class ResultRange;
+class Type;
+class Value;
+class ValueRange;
+template <typename ValueRangeT>
+class ValueTypeRange;
+
+//===----------------------------------------------------------------------===//
+// TypeRange
+
+/// This class provides an abstraction over the various different ranges of
+/// value types. In many cases, this prevents the need to explicitly materialize
+/// a SmallVector/std::vector. This class should be used in places that are not
+/// suitable for a more derived type (e.g. ArrayRef) or a template range
+/// parameter.
+class TypeRange
+    : public llvm::detail::indexed_accessor_range_base<
+          TypeRange,
+          llvm::PointerUnion<const Value *, const Type *, OpOperand *>, Type,
+          Type, Type> {
+public:
+  using RangeBaseT::RangeBaseT;
+  TypeRange(ArrayRef<Type> types = llvm::None);
+  explicit TypeRange(OperandRange values);
+  explicit TypeRange(ResultRange values);
+  explicit TypeRange(ValueRange values);
+  explicit TypeRange(ArrayRef<Value> values);
+  explicit TypeRange(ArrayRef<BlockArgument> values)
+      : TypeRange(ArrayRef<Value>(values.data(), values.size())) {}
+  template <typename ValueRangeT>
+  TypeRange(ValueTypeRange<ValueRangeT> values)
+      : TypeRange(ValueRangeT(values.begin().getCurrent(),
+                              values.end().getCurrent())) {}
+  template <typename Arg,
+            typename = typename std::enable_if_t<
+                std::is_constructible<ArrayRef<Type>, Arg>::value>>
+  TypeRange(Arg &&arg) : TypeRange(ArrayRef<Type>(std::forward<Arg>(arg))) {}
+  TypeRange(std::initializer_list<Type> types)
+      : TypeRange(ArrayRef<Type>(types)) {}
+
+private:
+  /// The owner of the range is either:
+  /// * A pointer to the first element of an array of values.
+  /// * A pointer to the first element of an array of types.
+  /// * A pointer to the first element of an array of operands.
+  using OwnerT = llvm::PointerUnion<const Value *, const Type *, OpOperand *>;
+
+  /// See `llvm::detail::indexed_accessor_range_base` for details.
+  static OwnerT offset_base(OwnerT object, ptrdiff_t index);
+  /// See `llvm::detail::indexed_accessor_range_base` for details.
+  static Type dereference_iterator(OwnerT object, ptrdiff_t index);
+
+  /// Allow access to `offset_base` and `dereference_iterator`.
+  friend RangeBaseT;
+};
+
+/// Make TypeRange hashable.
+inline ::llvm::hash_code hash_value(TypeRange arg) {
+  return ::llvm::hash_combine_range(arg.begin(), arg.end());
+}
+
+//===----------------------------------------------------------------------===//
+// ValueTypeRange
+
+/// This class implements iteration on the types of a given range of values.
+template <typename ValueIteratorT>
+class ValueTypeIterator final
+    : public llvm::mapped_iterator<ValueIteratorT, Type (*)(Value)> {
+  static Type unwrap(Value value) { return value.getType(); }
+
+public:
+  using reference = Type;
+
+  /// Provide a const dereference method.
+  Type operator*() const { return unwrap(*this->I); }
+
+  /// Initializes the type iterator to the specified value iterator.
+  ValueTypeIterator(ValueIteratorT it)
+      : llvm::mapped_iterator<ValueIteratorT, Type (*)(Value)>(it, &unwrap) {}
+};
+
+/// This class implements iteration on the types of a given range of values.
+template <typename ValueRangeT>
+class ValueTypeRange final
+    : public llvm::iterator_range<
+          ValueTypeIterator<typename ValueRangeT::iterator>> {
+public:
+  using llvm::iterator_range<
+      ValueTypeIterator<typename ValueRangeT::iterator>>::iterator_range;
+  template <typename Container>
+  ValueTypeRange(Container &&c) : ValueTypeRange(c.begin(), c.end()) {}
+
+  /// Compare this range with another.
+  template <typename OtherT>
+  bool operator==(const OtherT &other) const {
+    return llvm::size(*this) == llvm::size(other) &&
+           std::equal(this->begin(), this->end(), other.begin());
+  }
+  template <typename OtherT>
+  bool operator!=(const OtherT &other) const {
+    return !(*this == other);
+  }
+};
+
+template <typename RangeT>
+inline bool operator==(ArrayRef<Type> lhs, const ValueTypeRange<RangeT> &rhs) {
+  return lhs.size() == static_cast<size_t>(llvm::size(rhs)) &&
+         std::equal(lhs.begin(), lhs.end(), rhs.begin());
+}
+
+} // namespace mlir
+
+namespace llvm {
+
+// Provide DenseMapInfo for TypeRange.
+template <>
+struct DenseMapInfo<mlir::TypeRange> {
+  static mlir::TypeRange getEmptyKey() {
+    return mlir::TypeRange(getEmptyKeyPointer(), 0);
+  }
+
+  static mlir::TypeRange getTombstoneKey() {
+    return mlir::TypeRange(getTombstoneKeyPointer(), 0);
+  }
+
+  static unsigned getHashValue(mlir::TypeRange val) { return hash_value(val); }
+
+  static bool isEqual(mlir::TypeRange lhs, mlir::TypeRange rhs) {
+    if (isEmptyKey(rhs))
+      return isEmptyKey(lhs);
+    if (isTombstoneKey(rhs))
+      return isTombstoneKey(lhs);
+    return lhs == rhs;
+  }
+
+private:
+  static const mlir::Type *getEmptyKeyPointer() {
+    return DenseMapInfo<mlir::Type *>::getEmptyKey();
+  }
+
+  static const mlir::Type *getTombstoneKeyPointer() {
+    return DenseMapInfo<mlir::Type *>::getTombstoneKey();
+  }
+
+  static bool isEmptyKey(mlir::TypeRange range) {
+    if (const auto *type = range.getBase().dyn_cast<const mlir::Type *>())
+      return type == getEmptyKeyPointer();
+    return false;
+  }
+
+  static bool isTombstoneKey(mlir::TypeRange range) {
+    if (const auto *type = range.getBase().dyn_cast<const mlir::Type *>())
+      return type == getTombstoneKeyPointer();
+    return false;
+  }
+};
+
+} // namespace llvm
+
+#endif // MLIR_IR_TYPERANGE_H

diff  --git a/mlir/include/mlir/IR/Types.h b/mlir/include/mlir/IR/Types.h
index 83636585c499..ed63f696a84c 100644
--- a/mlir/include/mlir/IR/Types.h
+++ b/mlir/include/mlir/IR/Types.h
@@ -21,6 +21,7 @@ class IndexType;
 class IntegerType;
 class MLIRContext;
 class TypeStorage;
+class TypeRange;
 
 namespace detail {
 struct FunctionTypeStorage;
@@ -259,21 +260,17 @@ class FunctionType
 public:
   using Base::Base;
 
-  static FunctionType get(ArrayRef<Type> inputs, ArrayRef<Type> results,
+  static FunctionType get(TypeRange inputs, TypeRange results,
                           MLIRContext *context);
 
   // Input types.
   unsigned getNumInputs() const { return getSubclassData(); }
-
   Type getInput(unsigned i) const { return getInputs()[i]; }
-
   ArrayRef<Type> getInputs() const;
 
   // Result types.
   unsigned getNumResults() const;
-
   Type getResult(unsigned i) const { return getResults()[i]; }
-
   ArrayRef<Type> getResults() const;
 
   /// Methods for support type inquiry through isa, cast, and dyn_cast.

diff  --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
index cc2200e84da5..47129d7bd615 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
@@ -549,15 +549,13 @@ static void printCallOp(OpAsmPrinter &p, CallOp &op) {
   else
     p << op.getOperand(0);
 
-  p << '(' << op.getOperands().drop_front(isDirect ? 0 : 1) << ')';
+  auto args = op.getOperands().drop_front(isDirect ? 0 : 1);
+  p << '(' << args << ')';
   p.printOptionalAttrDict(op.getAttrs(), {"callee"});
 
   // Reconstruct the function MLIR function type from operand and result types.
-  SmallVector<Type, 8> argTypes(
-      llvm::drop_begin(op.getOperandTypes(), isDirect ? 0 : 1));
-
   p << " : "
-    << FunctionType::get(argTypes, op.getResultTypes(), op.getContext());
+    << FunctionType::get(args.getTypes(), op.getResultTypes(), op.getContext());
 }
 
 // <operation> ::= `llvm.call` (function-id | ssa-use) `(` ssa-use-list `)`

diff  --git a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
index a67e79ac4a7c..a78e2427b2fe 100644
--- a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
+++ b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
@@ -749,8 +749,7 @@ static LogicalResult verify(CallOp op) {
 }
 
 FunctionType CallOp::getCalleeType() {
-  SmallVector<Type, 8> argTypes(getOperandTypes());
-  return FunctionType::get(argTypes, getResultTypes(), getContext());
+  return FunctionType::get(getOperandTypes(), getResultTypes(), getContext());
 }
 
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/lib/IR/Builders.cpp b/mlir/lib/IR/Builders.cpp
index d89158ea5d87..69b1a0efb58d 100644
--- a/mlir/lib/IR/Builders.cpp
+++ b/mlir/lib/IR/Builders.cpp
@@ -67,12 +67,11 @@ IntegerType Builder::getIntegerType(unsigned width, bool isSigned) {
       width, isSigned ? IntegerType::Signed : IntegerType::Unsigned, context);
 }
 
-FunctionType Builder::getFunctionType(ArrayRef<Type> inputs,
-                                      ArrayRef<Type> results) {
+FunctionType Builder::getFunctionType(TypeRange inputs, TypeRange results) {
   return FunctionType::get(inputs, results, context);
 }
 
-TupleType Builder::getTupleType(ArrayRef<Type> elementTypes) {
+TupleType Builder::getTupleType(TypeRange elementTypes) {
   return TupleType::get(elementTypes, context);
 }
 

diff  --git a/mlir/lib/IR/CMakeLists.txt b/mlir/lib/IR/CMakeLists.txt
index d90db0832f56..553408f6fb36 100644
--- a/mlir/lib/IR/CMakeLists.txt
+++ b/mlir/lib/IR/CMakeLists.txt
@@ -22,6 +22,7 @@ add_mlir_library(MLIRIR
   StandardTypes.cpp
   SymbolTable.cpp
   Types.cpp
+  TypeRange.cpp
   TypeUtilities.cpp
   Value.cpp
   Verifier.cpp

diff  --git a/mlir/lib/IR/OperationSupport.cpp b/mlir/lib/IR/OperationSupport.cpp
index ef2b377cb1f3..b477a8a23900 100644
--- a/mlir/lib/IR/OperationSupport.cpp
+++ b/mlir/lib/IR/OperationSupport.cpp
@@ -360,45 +360,6 @@ Operation *detail::TrailingOpResult::getOwner() {
 // Operation Value-Iterators
 //===----------------------------------------------------------------------===//
 
-//===----------------------------------------------------------------------===//
-// TypeRange
-
-TypeRange::TypeRange(ArrayRef<Type> types)
-    : TypeRange(types.data(), types.size()) {}
-TypeRange::TypeRange(OperandRange values)
-    : TypeRange(values.begin().getBase(), values.size()) {}
-TypeRange::TypeRange(ResultRange values)
-    : TypeRange(values.getBase()->getResultTypes().slice(values.getStartIndex(),
-                                                         values.size())) {}
-TypeRange::TypeRange(ArrayRef<Value> values)
-    : TypeRange(values.data(), values.size()) {}
-TypeRange::TypeRange(ValueRange values) : TypeRange(OwnerT(), values.size()) {
-  detail::ValueRangeOwner owner = values.begin().getBase();
-  if (auto *op = reinterpret_cast<Operation *>(owner.ptr.dyn_cast<void *>()))
-    this->base = op->getResultTypes().drop_front(owner.startIndex).data();
-  else if (auto *operand = owner.ptr.dyn_cast<OpOperand *>())
-    this->base = operand;
-  else
-    this->base = owner.ptr.get<const Value *>();
-}
-
-/// See `llvm::detail::indexed_accessor_range_base` for details.
-TypeRange::OwnerT TypeRange::offset_base(OwnerT object, ptrdiff_t index) {
-  if (auto *value = object.dyn_cast<const Value *>())
-    return {value + index};
-  if (auto *operand = object.dyn_cast<OpOperand *>())
-    return {operand + index};
-  return {object.dyn_cast<const Type *>() + index};
-}
-/// See `llvm::detail::indexed_accessor_range_base` for details.
-Type TypeRange::dereference_iterator(OwnerT object, ptrdiff_t index) {
-  if (auto *value = object.dyn_cast<const Value *>())
-    return (value + index)->getType();
-  if (auto *operand = object.dyn_cast<OpOperand *>())
-    return (operand + index)->get().getType();
-  return object.dyn_cast<const Type *>()[index];
-}
-
 //===----------------------------------------------------------------------===//
 // OperandRange
 

diff  --git a/mlir/lib/IR/StandardTypes.cpp b/mlir/lib/IR/StandardTypes.cpp
index 2d1f8d8eb6f0..70b00cf8963a 100644
--- a/mlir/lib/IR/StandardTypes.cpp
+++ b/mlir/lib/IR/StandardTypes.cpp
@@ -638,10 +638,13 @@ LogicalResult mlir::getStridesAndOffset(MemRefType t,
 
 /// Get or create a new TupleType with the provided element types. Assumes the
 /// arguments define a well-formed type.
-TupleType TupleType::get(ArrayRef<Type> elementTypes, MLIRContext *context) {
+TupleType TupleType::get(TypeRange elementTypes, MLIRContext *context) {
   return Base::get(context, StandardTypes::Tuple, elementTypes);
 }
 
+/// Get or create an empty tuple type.
+TupleType TupleType::get(MLIRContext *context) { return get({}, context); }
+
 /// Return the elements types for this tuple.
 ArrayRef<Type> TupleType::getTypes() const { return getImpl()->getTypes(); }
 

diff  --git a/mlir/lib/IR/TypeDetail.h b/mlir/lib/IR/TypeDetail.h
index 72f1585be2d0..783983473a38 100644
--- a/mlir/lib/IR/TypeDetail.h
+++ b/mlir/lib/IR/TypeDetail.h
@@ -15,7 +15,9 @@
 #include "mlir/IR/AffineMap.h"
 #include "mlir/IR/Identifier.h"
 #include "mlir/IR/MLIRContext.h"
+#include "mlir/IR/OperationSupport.h"
 #include "mlir/IR/StandardTypes.h"
+#include "mlir/IR/TypeRange.h"
 #include "llvm/ADT/bit.h"
 #include "llvm/Support/TrailingObjects.h"
 
@@ -105,7 +107,7 @@ struct FunctionTypeStorage : public TypeStorage {
         inputsAndResults(inputsAndResults) {}
 
   /// The hash key used for uniquing.
-  using KeyTy = std::pair<ArrayRef<Type>, ArrayRef<Type>>;
+  using KeyTy = std::pair<TypeRange, TypeRange>;
   bool operator==(const KeyTy &key) const {
     return key == KeyTy(getInputs(), getResults());
   }
@@ -113,7 +115,7 @@ struct FunctionTypeStorage : public TypeStorage {
   /// Construction.
   static FunctionTypeStorage *construct(TypeStorageAllocator &allocator,
                                         const KeyTy &key) {
-    ArrayRef<Type> inputs = key.first, results = key.second;
+    TypeRange inputs = key.first, results = key.second;
 
     // Copy the inputs and results into the bump pointer.
     SmallVector<Type, 16> types;
@@ -320,13 +322,13 @@ struct ComplexTypeStorage : public TypeStorage {
 struct TupleTypeStorage final
     : public TypeStorage,
       public llvm::TrailingObjects<TupleTypeStorage, Type> {
-  using KeyTy = ArrayRef<Type>;
+  using KeyTy = TypeRange;
 
   TupleTypeStorage(unsigned numTypes) : TypeStorage(numTypes) {}
 
   /// Construction.
   static TupleTypeStorage *construct(TypeStorageAllocator &allocator,
-                                     ArrayRef<Type> key) {
+                                     TypeRange key) {
     // Allocate a new storage instance.
     auto byteSize = TupleTypeStorage::totalSizeToAlloc<Type>(key.size());
     auto rawMem = allocator.allocate(byteSize, alignof(TupleTypeStorage));

diff  --git a/mlir/lib/IR/TypeRange.cpp b/mlir/lib/IR/TypeRange.cpp
new file mode 100644
index 000000000000..f3f6fb54c707
--- /dev/null
+++ b/mlir/lib/IR/TypeRange.cpp
@@ -0,0 +1,50 @@
+//===- TypeRange.cpp ------------------------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/IR/TypeRange.h"
+#include "mlir/IR/Operation.h"
+using namespace mlir;
+
+//===----------------------------------------------------------------------===//
+// TypeRange
+
+TypeRange::TypeRange(ArrayRef<Type> types)
+    : TypeRange(types.data(), types.size()) {}
+TypeRange::TypeRange(OperandRange values)
+    : TypeRange(values.begin().getBase(), values.size()) {}
+TypeRange::TypeRange(ResultRange values)
+    : TypeRange(values.getBase()->getResultTypes().slice(values.getStartIndex(),
+                                                         values.size())) {}
+TypeRange::TypeRange(ArrayRef<Value> values)
+    : TypeRange(values.data(), values.size()) {}
+TypeRange::TypeRange(ValueRange values) : TypeRange(OwnerT(), values.size()) {
+  detail::ValueRangeOwner owner = values.begin().getBase();
+  if (auto *op = reinterpret_cast<Operation *>(owner.ptr.dyn_cast<void *>()))
+    this->base = op->getResultTypes().drop_front(owner.startIndex).data();
+  else if (auto *operand = owner.ptr.dyn_cast<OpOperand *>())
+    this->base = operand;
+  else
+    this->base = owner.ptr.get<const Value *>();
+}
+
+/// See `llvm::detail::indexed_accessor_range_base` for details.
+TypeRange::OwnerT TypeRange::offset_base(OwnerT object, ptrdiff_t index) {
+  if (const auto *value = object.dyn_cast<const Value *>())
+    return {value + index};
+  if (auto *operand = object.dyn_cast<OpOperand *>())
+    return {operand + index};
+  return {object.dyn_cast<const Type *>() + index};
+}
+/// See `llvm::detail::indexed_accessor_range_base` for details.
+Type TypeRange::dereference_iterator(OwnerT object, ptrdiff_t index) {
+  if (const auto *value = object.dyn_cast<const Value *>())
+    return (value + index)->getType();
+  if (auto *operand = object.dyn_cast<OpOperand *>())
+    return (operand + index)->get().getType();
+  return object.dyn_cast<const Type *>()[index];
+}

diff  --git a/mlir/lib/IR/Types.cpp b/mlir/lib/IR/Types.cpp
index 25902c2863bb..fea2cc6648e3 100644
--- a/mlir/lib/IR/Types.cpp
+++ b/mlir/lib/IR/Types.cpp
@@ -34,7 +34,7 @@ void Type::setSubclassData(unsigned val) { impl->setSubclassData(val); }
 // FunctionType
 //===----------------------------------------------------------------------===//
 
-FunctionType FunctionType::get(ArrayRef<Type> inputs, ArrayRef<Type> results,
+FunctionType FunctionType::get(TypeRange inputs, TypeRange results,
                                MLIRContext *context) {
   return Base::get(context, Type::Kind::Function, inputs, results);
 }


        


More information about the Mlir-commits mailing list