[Mlir-commits] [mlir] fcd4ee5 - [mlir] Make ShapedTypeComponents constructible from ShapeAdaptor
Chia-hung Duan
llvmlistbot at llvm.org
Tue Mar 8 19:38:42 PST 2022
Author: Chia-hung Duan
Date: 2022-03-09T03:35:24Z
New Revision: fcd4ee52cd6dc80a8ae4b7af68d13a37fb761cfe
URL: https://github.com/llvm/llvm-project/commit/fcd4ee52cd6dc80a8ae4b7af68d13a37fb761cfe
DIFF: https://github.com/llvm/llvm-project/commit/fcd4ee52cd6dc80a8ae4b7af68d13a37fb761cfe.diff
LOG: [mlir] Make ShapedTypeComponents constructible from ShapeAdaptor
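ValueShapeRange::getShape() returns a ShapeAdaptor rather than a ShapedType,
and ShapeAdaptor allowed implicit conversion to bool. As a result,
ShapedTypeComponents could be silently (and incorrectly) constructed from a
ShapeAdaptor. The templated constructor is guarded by the type trait
std::is_constructible<ShapeStorageT, Arg>::value
and that trait is satisfied: ShapeAdaptor implicitly converts to bool, and a
bool can in turn construct ShapeStorageT (it is accepted as the element count
of SmallVector's size constructor). Consequently, no warning or error is
emitted for code like
inferredReturnShapes.emplace_back(valueShapeRange.getShape(0));
This change makes ShapeAdaptor's operator bool explicit and adds a dedicated
ShapedTypeComponents(ShapeAdaptor) constructor that copies the rank, element
type, and (if ranked) dimensions from the adaptor.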
Reviewed By: jpienaar
Differential Revision: https://reviews.llvm.org/D120845
Added:
Modified:
mlir/include/mlir/Interfaces/InferTypeOpInterface.h
Removed:
################################################################################
diff --git a/mlir/include/mlir/Interfaces/InferTypeOpInterface.h b/mlir/include/mlir/Interfaces/InferTypeOpInterface.h
index 3ddd2c16667d9..eaf1f6ced34b8 100644
--- a/mlir/include/mlir/Interfaces/InferTypeOpInterface.h
+++ b/mlir/include/mlir/Interfaces/InferTypeOpInterface.h
@@ -25,67 +25,9 @@
namespace mlir {
+class ShapedTypeComponents;
using ReifiedRankedShapedTypeDims = SmallVector<SmallVector<Value>>;
-/// ShapedTypeComponents that represents the components of a ShapedType.
-/// The components consist of
-/// - A ranked or unranked shape with the dimension specification match those
-/// of ShapeType's getShape() (e.g., dynamic dimension represented using
-/// ShapedType::kDynamicSize)
-/// - A element type, may be unset (nullptr)
-/// - A attribute, may be unset (nullptr)
-/// Used by ShapedType type inferences.
-class ShapedTypeComponents {
- /// Internal storage type for shape.
- using ShapeStorageT = SmallVector<int64_t, 3>;
-
-public:
- /// Default construction is an unranked shape.
- ShapedTypeComponents() : elementType(nullptr), attr(nullptr){};
- ShapedTypeComponents(Type elementType)
- : elementType(elementType), attr(nullptr), ranked(false) {}
- ShapedTypeComponents(ShapedType shapedType) : attr(nullptr) {
- ranked = shapedType.hasRank();
- elementType = shapedType.getElementType();
- if (ranked)
- dims = llvm::to_vector<4>(shapedType.getShape());
- }
- template <typename Arg, typename = typename std::enable_if_t<
- std::is_constructible<ShapeStorageT, Arg>::value>>
- ShapedTypeComponents(Arg &&arg, Type elementType = nullptr,
- Attribute attr = nullptr)
- : dims(std::forward<Arg>(arg)), elementType(elementType), attr(attr),
- ranked(true) {}
- ShapedTypeComponents(ArrayRef<int64_t> vec, Type elementType = nullptr,
- Attribute attr = nullptr)
- : dims(vec.begin(), vec.end()), elementType(elementType), attr(attr),
- ranked(true) {}
-
- /// Return the dimensions of the shape.
- /// Requires: shape is ranked.
- ArrayRef<int64_t> getDims() const {
- assert(ranked && "requires ranked shape");
- return dims;
- }
-
- /// Return whether the shape has a rank.
- bool hasRank() const { return ranked; };
-
- /// Return the element type component.
- Type getElementType() const { return elementType; };
-
- /// Return the raw attribute component.
- Attribute getAttribute() const { return attr; };
-
-private:
- friend class ShapeAdaptor;
-
- ShapeStorageT dims;
- Type elementType;
- Attribute attr;
- bool ranked{false};
-};
-
/// Adaptor class to abstract the differences between whether value is from
/// a ShapedType or ShapedTypeComponents or DenseIntElementsAttribute.
class ShapeAdaptor {
@@ -137,7 +79,7 @@ class ShapeAdaptor {
int64_t getNumElements() const;
/// Returns whether valid (non-null) shape.
- operator bool() const { return !val.isNull(); }
+ explicit operator bool() const { return !val.isNull(); }
/// Dumps textual repesentation to stderr.
void dump() const;
@@ -148,6 +90,71 @@ class ShapeAdaptor {
PointerUnion<ShapedTypeComponents *, Type, Attribute> val = nullptr;
};
+/// ShapedTypeComponents that represents the components of a ShapedType.
+/// The components consist of
+/// - A ranked or unranked shape with the dimension specification match those
+/// of ShapeType's getShape() (e.g., dynamic dimension represented using
+/// ShapedType::kDynamicSize)
+/// - A element type, may be unset (nullptr)
+/// - A attribute, may be unset (nullptr)
+/// Used by ShapedType type inferences.
+class ShapedTypeComponents {
+ /// Internal storage type for shape.
+ using ShapeStorageT = SmallVector<int64_t, 3>;
+
+public:
+ /// Default construction is an unranked shape.
+ ShapedTypeComponents() : elementType(nullptr), attr(nullptr){};
+ ShapedTypeComponents(Type elementType)
+ : elementType(elementType), attr(nullptr), ranked(false) {}
+ ShapedTypeComponents(ShapedType shapedType) : attr(nullptr) {
+ ranked = shapedType.hasRank();
+ elementType = shapedType.getElementType();
+ if (ranked)
+ dims = llvm::to_vector<4>(shapedType.getShape());
+ }
+ ShapedTypeComponents(ShapeAdaptor adaptor) : attr(nullptr) {
+ ranked = adaptor.hasRank();
+ elementType = adaptor.getElementType();
+ if (ranked)
+ adaptor.getDims(*this);
+ }
+ template <typename Arg, typename = typename std::enable_if_t<
+ std::is_constructible<ShapeStorageT, Arg>::value>>
+ ShapedTypeComponents(Arg &&arg, Type elementType = nullptr,
+ Attribute attr = nullptr)
+ : dims(std::forward<Arg>(arg)), elementType(elementType), attr(attr),
+ ranked(true) {}
+ ShapedTypeComponents(ArrayRef<int64_t> vec, Type elementType = nullptr,
+ Attribute attr = nullptr)
+ : dims(vec.begin(), vec.end()), elementType(elementType), attr(attr),
+ ranked(true) {}
+
+ /// Return the dimensions of the shape.
+ /// Requires: shape is ranked.
+ ArrayRef<int64_t> getDims() const {
+ assert(ranked && "requires ranked shape");
+ return dims;
+ }
+
+ /// Return whether the shape has a rank.
+ bool hasRank() const { return ranked; };
+
+ /// Return the element type component.
+ Type getElementType() const { return elementType; };
+
+ /// Return the raw attribute component.
+ Attribute getAttribute() const { return attr; };
+
+private:
+ friend class ShapeAdaptor;
+
+ ShapeStorageT dims;
+ Type elementType;
+ Attribute attr;
+ bool ranked{false};
+};
+
/// Range of values and shapes (corresponding effectively to Shapes dialect's
/// ValueShape type concept).
// Currently this exposes the Value (of operands) and Type of the Value. This is
More information about the Mlir-commits
mailing list