[Mlir-commits] [mlir] ff52ad7 - [mlir] Change DenseArrayAttr to TensorType

Jeff Niu llvmlistbot at llvm.org
Mon Aug 1 19:17:35 PDT 2022


Author: Jeff Niu
Date: 2022-08-01T22:17:28-04:00
New Revision: ff52ad796c971dbf805375a2140344f742db94a3

URL: https://github.com/llvm/llvm-project/commit/ff52ad796c971dbf805375a2140344f742db94a3
DIFF: https://github.com/llvm/llvm-project/commit/ff52ad796c971dbf805375a2140344f742db94a3.diff

LOG: [mlir] Change DenseArrayAttr to TensorType

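Previously, DenseArrayAttr used VectorType for its shaped type.
VectorType is problematic for arrays because it doesn't support
zero-sized dimensions, meaning that an empty array had to fall back to
the rank-0 type `vector<i32>`. A rank-0 shaped type holds exactly one
element, so ElementsAttr would think that an empty dense array has size
1, not 0. This patch switches over to TensorType (RankedTensorType),
which does support zero-sized dimensions, so an empty array now gets a
one-dimensional type of size 0 (e.g. `tensor<0xi32>`).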

Fixes #56860

Reviewed By: mehdi_amini

Differential Revision: https://reviews.llvm.org/D130921

Added: 
    

Modified: 
    mlir/lib/IR/AsmPrinter.cpp
    mlir/lib/IR/BuiltinAttributes.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/lib/IR/AsmPrinter.cpp b/mlir/lib/IR/AsmPrinter.cpp
index 2c9a4bd8dc6d..44a41946acff 100644
--- a/mlir/lib/IR/AsmPrinter.cpp
+++ b/mlir/lib/IR/AsmPrinter.cpp
@@ -1880,7 +1880,7 @@ void AsmPrinter::Impl::printAttribute(Attribute attr,
       os << "[:f64";
       break;
     }
-    if (denseArrayAttr.getType().getRank())
+    if (denseArrayAttr.size())
       os << " ";
     denseArrayAttr.printWithoutBraces(os);
     os << "]";

diff  --git a/mlir/lib/IR/BuiltinAttributes.cpp b/mlir/lib/IR/BuiltinAttributes.cpp
index 12926c5cc94d..ce7dc22accb4 100644
--- a/mlir/lib/IR/BuiltinAttributes.cpp
+++ b/mlir/lib/IR/BuiltinAttributes.cpp
@@ -884,7 +884,7 @@ struct denseArrayAttrEltTypeBuilder<int8_t> {
   constexpr static auto eltType = DenseArrayBaseAttr::EltType::I8;
   static ShapedType getShapedType(MLIRContext *context,
                                   ArrayRef<int64_t> shape) {
-    return VectorType::get(shape, IntegerType::get(context, 8));
+    return RankedTensorType::get(shape, IntegerType::get(context, 8));
   }
 };
 template <>
@@ -892,7 +892,7 @@ struct denseArrayAttrEltTypeBuilder<int16_t> {
   constexpr static auto eltType = DenseArrayBaseAttr::EltType::I16;
   static ShapedType getShapedType(MLIRContext *context,
                                   ArrayRef<int64_t> shape) {
-    return VectorType::get(shape, IntegerType::get(context, 16));
+    return RankedTensorType::get(shape, IntegerType::get(context, 16));
   }
 };
 template <>
@@ -900,7 +900,7 @@ struct denseArrayAttrEltTypeBuilder<int32_t> {
   constexpr static auto eltType = DenseArrayBaseAttr::EltType::I32;
   static ShapedType getShapedType(MLIRContext *context,
                                   ArrayRef<int64_t> shape) {
-    return VectorType::get(shape, IntegerType::get(context, 32));
+    return RankedTensorType::get(shape, IntegerType::get(context, 32));
   }
 };
 template <>
@@ -908,7 +908,7 @@ struct denseArrayAttrEltTypeBuilder<int64_t> {
   constexpr static auto eltType = DenseArrayBaseAttr::EltType::I64;
   static ShapedType getShapedType(MLIRContext *context,
                                   ArrayRef<int64_t> shape) {
-    return VectorType::get(shape, IntegerType::get(context, 64));
+    return RankedTensorType::get(shape, IntegerType::get(context, 64));
   }
 };
 template <>
@@ -916,7 +916,7 @@ struct denseArrayAttrEltTypeBuilder<float> {
   constexpr static auto eltType = DenseArrayBaseAttr::EltType::F32;
   static ShapedType getShapedType(MLIRContext *context,
                                   ArrayRef<int64_t> shape) {
-    return VectorType::get(shape, Float32Type::get(context));
+    return RankedTensorType::get(shape, Float32Type::get(context));
   }
 };
 template <>
@@ -924,7 +924,7 @@ struct denseArrayAttrEltTypeBuilder<double> {
   constexpr static auto eltType = DenseArrayBaseAttr::EltType::F64;
   static ShapedType getShapedType(MLIRContext *context,
                                   ArrayRef<int64_t> shape) {
-    return VectorType::get(shape, Float64Type::get(context));
+    return RankedTensorType::get(shape, Float64Type::get(context));
   }
 };
 } // namespace
@@ -934,8 +934,8 @@ template <typename T>
 DenseArrayAttr<T> DenseArrayAttr<T>::get(MLIRContext *context,
                                          ArrayRef<T> content) {
   auto size = static_cast<int64_t>(content.size());
-  auto shapedType = denseArrayAttrEltTypeBuilder<T>::getShapedType(
-      context, size ? ArrayRef<int64_t>{size} : ArrayRef<int64_t>{});
+  auto shapedType =
+      denseArrayAttrEltTypeBuilder<T>::getShapedType(context, size);
   auto eltType = denseArrayAttrEltTypeBuilder<T>::eltType;
   auto rawArray = ArrayRef<char>(reinterpret_cast<const char *>(content.data()),
                                  content.size() * sizeof(T));


        


More information about the Mlir-commits mailing list