[Mlir-commits] [mlir] fcd4ee5 - [mlir] Make ShapedTypeComponents constructible from ShapeAdaptor

Chia-hung Duan llvmlistbot at llvm.org
Tue Mar 8 19:38:42 PST 2022


Author: Chia-hung Duan
Date: 2022-03-09T03:35:24Z
New Revision: fcd4ee52cd6dc80a8ae4b7af68d13a37fb761cfe

URL: https://github.com/llvm/llvm-project/commit/fcd4ee52cd6dc80a8ae4b7af68d13a37fb761cfe
DIFF: https://github.com/llvm/llvm-project/commit/fcd4ee52cd6dc80a8ae4b7af68d13a37fb761cfe.diff

LOG: [mlir] Make ShapedTypeComponents constructible from ShapeAdaptor

ValueShapeRange::getShape() returns a ShapeAdaptor rather than a
ShapedType, and ShapeAdaptor allows implicit conversion to bool. As a
result, ShapedTypeComponents could be constructed from a ShapeAdaptor
incorrectly. The type trait
  std::is_constructible<ShapeStorageT, Arg>::value
is satisfied because ShapeAdaptor converts to bool, and that bool is
accepted by SmallVector's element-count constructor. A valid adaptor
therefore silently became the one-dimensional shape [0]. No warning or
error was emitted for code such as
  inferredReturnShapes.emplace_back(valueShapeRange.getShape(0));
This change makes ShapeAdaptor's bool conversion explicit and adds a
dedicated ShapedTypeComponents(ShapeAdaptor) constructor.

Reviewed By: jpienaar

Differential Revision: https://reviews.llvm.org/D120845

Added: 
    

Modified: 
    mlir/include/mlir/Interfaces/InferTypeOpInterface.h

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Interfaces/InferTypeOpInterface.h b/mlir/include/mlir/Interfaces/InferTypeOpInterface.h
index 3ddd2c16667d9..eaf1f6ced34b8 100644
--- a/mlir/include/mlir/Interfaces/InferTypeOpInterface.h
+++ b/mlir/include/mlir/Interfaces/InferTypeOpInterface.h
@@ -25,67 +25,9 @@
 
 namespace mlir {
 
+class ShapedTypeComponents;
 using ReifiedRankedShapedTypeDims = SmallVector<SmallVector<Value>>;
 
-/// ShapedTypeComponents that represents the components of a ShapedType.
-/// The components consist of
-///  - A ranked or unranked shape with the dimension specification match those
-///    of ShapeType's getShape() (e.g., dynamic dimension represented using
-///    ShapedType::kDynamicSize)
-///  - A element type, may be unset (nullptr)
-///  - A attribute, may be unset (nullptr)
-/// Used by ShapedType type inferences.
-class ShapedTypeComponents {
-  /// Internal storage type for shape.
-  using ShapeStorageT = SmallVector<int64_t, 3>;
-
-public:
-  /// Default construction is an unranked shape.
-  ShapedTypeComponents() : elementType(nullptr), attr(nullptr){};
-  ShapedTypeComponents(Type elementType)
-      : elementType(elementType), attr(nullptr), ranked(false) {}
-  ShapedTypeComponents(ShapedType shapedType) : attr(nullptr) {
-    ranked = shapedType.hasRank();
-    elementType = shapedType.getElementType();
-    if (ranked)
-      dims = llvm::to_vector<4>(shapedType.getShape());
-  }
-  template <typename Arg, typename = typename std::enable_if_t<
-                              std::is_constructible<ShapeStorageT, Arg>::value>>
-  ShapedTypeComponents(Arg &&arg, Type elementType = nullptr,
-                       Attribute attr = nullptr)
-      : dims(std::forward<Arg>(arg)), elementType(elementType), attr(attr),
-        ranked(true) {}
-  ShapedTypeComponents(ArrayRef<int64_t> vec, Type elementType = nullptr,
-                       Attribute attr = nullptr)
-      : dims(vec.begin(), vec.end()), elementType(elementType), attr(attr),
-        ranked(true) {}
-
-  /// Return the dimensions of the shape.
-  /// Requires: shape is ranked.
-  ArrayRef<int64_t> getDims() const {
-    assert(ranked && "requires ranked shape");
-    return dims;
-  }
-
-  /// Return whether the shape has a rank.
-  bool hasRank() const { return ranked; };
-
-  /// Return the element type component.
-  Type getElementType() const { return elementType; };
-
-  /// Return the raw attribute component.
-  Attribute getAttribute() const { return attr; };
-
-private:
-  friend class ShapeAdaptor;
-
-  ShapeStorageT dims;
-  Type elementType;
-  Attribute attr;
-  bool ranked{false};
-};
-
 /// Adaptor class to abstract the differences between whether value is from
 /// a ShapedType or ShapedTypeComponents or DenseIntElementsAttribute.
 class ShapeAdaptor {
@@ -137,7 +79,7 @@ class ShapeAdaptor {
   int64_t getNumElements() const;
 
   /// Returns whether valid (non-null) shape.
-  operator bool() const { return !val.isNull(); }
+  explicit operator bool() const { return !val.isNull(); }
 
   /// Dumps textual repesentation to stderr.
   void dump() const;
@@ -148,6 +90,71 @@ class ShapeAdaptor {
   PointerUnion<ShapedTypeComponents *, Type, Attribute> val = nullptr;
 };
 
+/// ShapedTypeComponents that represents the components of a ShapedType.
+/// The components consist of
+///  - A ranked or unranked shape with the dimension specification match those
+///    of ShapeType's getShape() (e.g., dynamic dimension represented using
+///    ShapedType::kDynamicSize)
+///  - A element type, may be unset (nullptr)
+///  - A attribute, may be unset (nullptr)
+/// Used by ShapedType type inferences.
+class ShapedTypeComponents {
+  /// Internal storage type for shape.
+  using ShapeStorageT = SmallVector<int64_t, 3>;
+
+public:
+  /// Default construction is an unranked shape.
+  ShapedTypeComponents() : elementType(nullptr), attr(nullptr){};
+  ShapedTypeComponents(Type elementType)
+      : elementType(elementType), attr(nullptr), ranked(false) {}
+  ShapedTypeComponents(ShapedType shapedType) : attr(nullptr) {
+    ranked = shapedType.hasRank();
+    elementType = shapedType.getElementType();
+    if (ranked)
+      dims = llvm::to_vector<4>(shapedType.getShape());
+  }
+  ShapedTypeComponents(ShapeAdaptor adaptor) : attr(nullptr) {
+    ranked = adaptor.hasRank();
+    elementType = adaptor.getElementType();
+    if (ranked)
+      adaptor.getDims(*this);
+  }
+  template <typename Arg, typename = typename std::enable_if_t<
+                              std::is_constructible<ShapeStorageT, Arg>::value>>
+  ShapedTypeComponents(Arg &&arg, Type elementType = nullptr,
+                       Attribute attr = nullptr)
+      : dims(std::forward<Arg>(arg)), elementType(elementType), attr(attr),
+        ranked(true) {}
+  ShapedTypeComponents(ArrayRef<int64_t> vec, Type elementType = nullptr,
+                       Attribute attr = nullptr)
+      : dims(vec.begin(), vec.end()), elementType(elementType), attr(attr),
+        ranked(true) {}
+
+  /// Return the dimensions of the shape.
+  /// Requires: shape is ranked.
+  ArrayRef<int64_t> getDims() const {
+    assert(ranked && "requires ranked shape");
+    return dims;
+  }
+
+  /// Return whether the shape has a rank.
+  bool hasRank() const { return ranked; };
+
+  /// Return the element type component.
+  Type getElementType() const { return elementType; };
+
+  /// Return the raw attribute component.
+  Attribute getAttribute() const { return attr; };
+
+private:
+  friend class ShapeAdaptor;
+
+  ShapeStorageT dims;
+  Type elementType;
+  Attribute attr;
+  bool ranked{false};
+};
+
 /// Range of values and shapes (corresponding effectively to Shapes dialect's
 /// ValueShape type concept).
 // Currently this exposes the Value (of operands) and Type of the Value. This is


        


More information about the Mlir-commits mailing list