[Mlir-commits] [mlir] d425f58 - [mlir] Make ValueShapeRange a new class
Jacques Pienaar
llvmlistbot at llvm.org
Mon Jul 26 17:08:44 PDT 2021
Author: Jacques Pienaar
Date: 2021-07-26T17:08:32-07:00
New Revision: d425f58939ad9ef88ee2d1578a87c25d4e121128
URL: https://github.com/llvm/llvm-project/commit/d425f58939ad9ef88ee2d1578a87c25d4e121128
DIFF: https://github.com/llvm/llvm-project/commit/d425f58939ad9ef88ee2d1578a87c25d4e121128.diff
LOG: [mlir] Make ValueShapeRange a new class
Retains the old interface and should be constructible as before; the change would have been NFC except that it doesn't implicitly work with the OpAdaptor generated in C++14.
Differential Revision: https://reviews.llvm.org/D106772
Added:
Modified:
mlir/include/mlir/Interfaces/InferTypeOpInterface.h
mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
mlir/test/lib/Dialect/Test/TestDialect.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/Interfaces/InferTypeOpInterface.h b/mlir/include/mlir/Interfaces/InferTypeOpInterface.h
index 0d6755f211d59..0535ff1e714dc 100644
--- a/mlir/include/mlir/Interfaces/InferTypeOpInterface.h
+++ b/mlir/include/mlir/Interfaces/InferTypeOpInterface.h
@@ -78,11 +78,31 @@ class ShapedTypeComponents {
/// Range of values and shapes (corresponding effectively to Shapes dialect's
/// ValueShape type concept).
-using ValueShapeRange = ValueRange;
+/// Implemented as a distinct class deriving from ValueRange::RangeBaseT
+/// rather than as an alias of ValueRange. The non-explicit constructors below
+/// keep it implicitly constructible from the arguments ValueRange accepts.
+/// It is not itself a ValueRange, so APIs that require one (such as the
+/// constructors of generated Op::Adaptor classes) need getValues().
+class ValueShapeRange : public ValueRange::RangeBaseT {
+public:
+ /// Wraps the elements of an existing ValueRange.
+ ValueShapeRange(ValueRange values) : RangeBaseT(values) {}
+ /// Forwarding constructor: participates only when ValueRange is
+ /// constructible from Arg; converts through ValueRange and delegates to the
+ /// ValueRange constructor above.
+ template <typename Arg, typename = typename std::enable_if_t<
+ std::is_constructible<ValueRange, Arg>::value>>
+ ValueShapeRange(Arg &&arg)
+ : ValueShapeRange(ValueRange(std::forward<Arg>(arg))) {}
+ /// Constructs from a braced list of Values, e.g. `{v0, v1}`, by converting
+ /// through ValueRange.
+ ValueShapeRange(const std::initializer_list<Value> &values)
+ : ValueShapeRange(ValueRange(values)) {}
+
+ /// Returns the types of the values within this range.
+ /// Note: This returns only the types of Values in the ValueRange and not a
+ /// more refined type.
+ using type_iterator = ValueTypeIterator<iterator>;
+ using type_range = ValueTypeRange<ValueRange>;
+ type_range getTypes() const { return {begin(), end()}; }
+ /// Identical to getTypes(); presumably kept so code written against
+ /// ValueRange's getType() spelling keeps compiling — TODO confirm.
+ auto getType() const { return getTypes(); }
+
+ /// Returns the Values in the ValueRange, as a plain ValueRange over the same
+ /// elements (no refined shape information). Use this wherever a ValueRange
+ /// is required, e.g. when constructing a generated Op::Adaptor.
+ /// NOTE(review): the trailing ';' after the function body is redundant and
+ /// triggers -Wextra-semi.
+ ValueRange getValues() const { return ValueRange(begin(), end()); };
+};
namespace detail {
-// Helper function to infer return tensor returns types given element and shape
-// inference function.
+// Helper function to infer return tensor returns types given element and
+// shape inference function.
//
// TODO: Consider generating typedefs for trait member functions if this usage
// becomes more common.
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
index 1a9d28834e3a5..5ae614f125f46 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
@@ -736,7 +736,8 @@ static LogicalResult ReduceInferReturnTypes(
#define REDUCE_SHAPE_INFER(OP) \
LogicalResult OP::inferReturnTypeComponents( \
MLIRContext *context, ::llvm::Optional<Location> location, \
- ValueRange operands, DictionaryAttr attributes, RegionRange regions, \
+ ValueShapeRange operands, DictionaryAttr attributes, \
+ RegionRange regions, \
SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) { \
return ReduceInferReturnTypes(operands[0], \
attributes.get("axis").cast<IntegerAttr>(), \
@@ -802,7 +803,8 @@ static LogicalResult NAryInferReturnTypes(
#define NARY_SHAPE_INFER(OP) \
LogicalResult OP::inferReturnTypeComponents( \
MLIRContext *context, ::llvm::Optional<Location> location, \
- ValueRange operands, DictionaryAttr attributes, RegionRange regions, \
+ ValueShapeRange operands, DictionaryAttr attributes, \
+ RegionRange regions, \
SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) { \
return NAryInferReturnTypes(operands, inferredReturnShapes); \
}
@@ -892,7 +894,7 @@ LogicalResult Conv2DOp::inferReturnTypeComponents(
ValueShapeRange operands, DictionaryAttr attributes, RegionRange regions,
SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
llvm::SmallVector<int64_t> outputShape(4, ShapedType::kDynamicSize);
- Conv2DOp::Adaptor adaptor(operands);
+ Conv2DOp::Adaptor adaptor(operands.getValues());
int32_t inputWidth = ShapedType::kDynamicSize;
int32_t inputHeight = ShapedType::kDynamicSize;
@@ -953,7 +955,7 @@ LogicalResult Conv3DOp::inferReturnTypeComponents(
ValueShapeRange operands, DictionaryAttr attributes, RegionRange regions,
SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
llvm::SmallVector<int64_t> outputShape(5, ShapedType::kDynamicSize);
- Conv2DOp::Adaptor adaptor(operands);
+ Conv2DOp::Adaptor adaptor(operands.getValues());
int32_t inputWidth = ShapedType::kDynamicSize;
int32_t inputHeight = ShapedType::kDynamicSize;
@@ -1040,7 +1042,7 @@ LogicalResult DepthwiseConv2DOp::inferReturnTypeComponents(
ValueShapeRange operands, DictionaryAttr attributes, RegionRange regions,
SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
llvm::SmallVector<int64_t> outputShape(4, ShapedType::kDynamicSize);
- DepthwiseConv2DOp::Adaptor adaptor(operands);
+ DepthwiseConv2DOp::Adaptor adaptor(operands.getValues());
int32_t inputWidth = ShapedType::kDynamicSize;
int32_t inputHeight = ShapedType::kDynamicSize;
@@ -1114,7 +1116,7 @@ LogicalResult TransposeConv2DOp::inferReturnTypeComponents(
MLIRContext *context, ::llvm::Optional<Location> location,
ValueShapeRange operands, DictionaryAttr attributes, RegionRange regions,
SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
- TransposeConv2DOp::Adaptor adaptor(operands);
+ TransposeConv2DOp::Adaptor adaptor(operands.getValues());
llvm::SmallVector<int64_t> outputShape;
getI64Values(attributes.get("out_shape").cast<ArrayAttr>(), outputShape);
diff --git a/mlir/test/lib/Dialect/Test/TestDialect.cpp b/mlir/test/lib/Dialect/Test/TestDialect.cpp
index 18be25a32cada..76a7f41fd186d 100644
--- a/mlir/test/lib/Dialect/Test/TestDialect.cpp
+++ b/mlir/test/lib/Dialect/Test/TestDialect.cpp
@@ -785,7 +785,7 @@ LogicalResult OpWithShapedTypeInferTypeInterfaceOp::inferReturnTypeComponents(
DictionaryAttr attributes, RegionRange regions,
SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
// Create return type consisting of the last element of the first operand.
- auto operandType = *operands.getTypes().begin();
+ auto operandType = operands.front().getType();
auto sval = operandType.dyn_cast<ShapedType>();
if (!sval) {
return emitOptionalError(location, "only shaped type operands allowed");
More information about the Mlir-commits
mailing list