[Mlir-commits] [mlir] d425f58 - [mlir] Make ValueShapeRange a new class

Jacques Pienaar llvmlistbot at llvm.org
Mon Jul 26 17:08:44 PDT 2021


Author: Jacques Pienaar
Date: 2021-07-26T17:08:32-07:00
New Revision: d425f58939ad9ef88ee2d1578a87c25d4e121128

URL: https://github.com/llvm/llvm-project/commit/d425f58939ad9ef88ee2d1578a87c25d4e121128
DIFF: https://github.com/llvm/llvm-project/commit/d425f58939ad9ef88ee2d1578a87c25d4e121128.diff

LOG: [mlir] Make ValueShapeRange a new class

Retains the old interface and should be constructible as before. The change would have been NFC, except that the new class does not implicitly convert to the OpAdaptor generated under C++14, so call sites that construct an adaptor now pass `operands.getValues()`.

Differential Revision: https://reviews.llvm.org/D106772

Added: 
    

Modified: 
    mlir/include/mlir/Interfaces/InferTypeOpInterface.h
    mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
    mlir/test/lib/Dialect/Test/TestDialect.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Interfaces/InferTypeOpInterface.h b/mlir/include/mlir/Interfaces/InferTypeOpInterface.h
index 0d6755f211d59..0535ff1e714dc 100644
--- a/mlir/include/mlir/Interfaces/InferTypeOpInterface.h
+++ b/mlir/include/mlir/Interfaces/InferTypeOpInterface.h
@@ -78,11 +78,31 @@ class ShapedTypeComponents {
 
 /// Range of values and shapes (corresponding effectively to Shapes dialect's
 /// ValueShape type concept).
-using ValueShapeRange = ValueRange;
+class ValueShapeRange : public ValueRange::RangeBaseT {
+public:
+  ValueShapeRange(ValueRange values) : RangeBaseT(values) {}
+  template <typename Arg, typename = typename std::enable_if_t<
+                              std::is_constructible<ValueRange, Arg>::value>>
+  ValueShapeRange(Arg &&arg)
+      : ValueShapeRange(ValueRange(std::forward<Arg>(arg))) {}
+  ValueShapeRange(const std::initializer_list<Value> &values)
+      : ValueShapeRange(ValueRange(values)) {}
+
+  /// Returns the types of the values within this range.
+  /// Note: This returns only the types of Values in the ValueRange and not a
+  /// more refined type.
+  using type_iterator = ValueTypeIterator<iterator>;
+  using type_range = ValueTypeRange<ValueRange>;
+  type_range getTypes() const { return {begin(), end()}; }
+  auto getType() const { return getTypes(); }
+
+  /// Returns the Values in the ValueRange.
+  ValueRange getValues() const { return ValueRange(begin(), end()); };
+};
 
 namespace detail {
-// Helper function to infer return tensor returns types given element and shape
-// inference function.
+// Helper function to infer return tensor returns types given element and
+// shape inference function.
 //
 // TODO: Consider generating typedefs for trait member functions if this usage
 // becomes more common.

diff  --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
index 1a9d28834e3a5..5ae614f125f46 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
@@ -736,7 +736,8 @@ static LogicalResult ReduceInferReturnTypes(
 #define REDUCE_SHAPE_INFER(OP)                                                 \
   LogicalResult OP::inferReturnTypeComponents(                                 \
       MLIRContext *context, ::llvm::Optional<Location> location,               \
-      ValueRange operands, DictionaryAttr attributes, RegionRange regions,     \
+      ValueShapeRange operands, DictionaryAttr attributes,                     \
+      RegionRange regions,                                                     \
       SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {           \
     return ReduceInferReturnTypes(operands[0],                                 \
                                   attributes.get("axis").cast<IntegerAttr>(),  \
@@ -802,7 +803,8 @@ static LogicalResult NAryInferReturnTypes(
 #define NARY_SHAPE_INFER(OP)                                                   \
   LogicalResult OP::inferReturnTypeComponents(                                 \
       MLIRContext *context, ::llvm::Optional<Location> location,               \
-      ValueRange operands, DictionaryAttr attributes, RegionRange regions,     \
+      ValueShapeRange operands, DictionaryAttr attributes,                     \
+      RegionRange regions,                                                     \
       SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {           \
     return NAryInferReturnTypes(operands, inferredReturnShapes);               \
   }
@@ -892,7 +894,7 @@ LogicalResult Conv2DOp::inferReturnTypeComponents(
     ValueShapeRange operands, DictionaryAttr attributes, RegionRange regions,
     SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
   llvm::SmallVector<int64_t> outputShape(4, ShapedType::kDynamicSize);
-  Conv2DOp::Adaptor adaptor(operands);
+  Conv2DOp::Adaptor adaptor(operands.getValues());
 
   int32_t inputWidth = ShapedType::kDynamicSize;
   int32_t inputHeight = ShapedType::kDynamicSize;
@@ -953,7 +955,7 @@ LogicalResult Conv3DOp::inferReturnTypeComponents(
     ValueShapeRange operands, DictionaryAttr attributes, RegionRange regions,
     SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
   llvm::SmallVector<int64_t> outputShape(5, ShapedType::kDynamicSize);
-  Conv2DOp::Adaptor adaptor(operands);
+  Conv2DOp::Adaptor adaptor(operands.getValues());
 
   int32_t inputWidth = ShapedType::kDynamicSize;
   int32_t inputHeight = ShapedType::kDynamicSize;
@@ -1040,7 +1042,7 @@ LogicalResult DepthwiseConv2DOp::inferReturnTypeComponents(
     ValueShapeRange operands, DictionaryAttr attributes, RegionRange regions,
     SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
   llvm::SmallVector<int64_t> outputShape(4, ShapedType::kDynamicSize);
-  DepthwiseConv2DOp::Adaptor adaptor(operands);
+  DepthwiseConv2DOp::Adaptor adaptor(operands.getValues());
 
   int32_t inputWidth = ShapedType::kDynamicSize;
   int32_t inputHeight = ShapedType::kDynamicSize;
@@ -1114,7 +1116,7 @@ LogicalResult TransposeConv2DOp::inferReturnTypeComponents(
     MLIRContext *context, ::llvm::Optional<Location> location,
     ValueShapeRange operands, DictionaryAttr attributes, RegionRange regions,
     SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
-  TransposeConv2DOp::Adaptor adaptor(operands);
+  TransposeConv2DOp::Adaptor adaptor(operands.getValues());
   llvm::SmallVector<int64_t> outputShape;
   getI64Values(attributes.get("out_shape").cast<ArrayAttr>(), outputShape);
 

diff  --git a/mlir/test/lib/Dialect/Test/TestDialect.cpp b/mlir/test/lib/Dialect/Test/TestDialect.cpp
index 18be25a32cada..76a7f41fd186d 100644
--- a/mlir/test/lib/Dialect/Test/TestDialect.cpp
+++ b/mlir/test/lib/Dialect/Test/TestDialect.cpp
@@ -785,7 +785,7 @@ LogicalResult OpWithShapedTypeInferTypeInterfaceOp::inferReturnTypeComponents(
     DictionaryAttr attributes, RegionRange regions,
     SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
   // Create return type consisting of the last element of the first operand.
-  auto operandType = *operands.getTypes().begin();
+  auto operandType = operands.front().getType();
   auto sval = operandType.dyn_cast<ShapedType>();
   if (!sval) {
     return emitOptionalError(location, "only shaped type operands allowed");


        


More information about the Mlir-commits mailing list