[Mlir-commits] [mlir] 444683a - [mlir] Remove the element type enum from DenseArrayAttr

Jeff Niu llvmlistbot at llvm.org
Wed Aug 24 17:02:40 PDT 2022


Author: Jeff Niu
Date: 2022-08-24T17:02:31-07:00
New Revision: 444683a9de4bc534b1154559cd537e7f6aa52847

URL: https://github.com/llvm/llvm-project/commit/444683a9de4bc534b1154559cd537e7f6aa52847
DIFF: https://github.com/llvm/llvm-project/commit/444683a9de4bc534b1154559cd537e7f6aa52847.diff

LOG: [mlir] Remove the element type enum from DenseArrayAttr

The element type enum is not needed to differentiate dense array kinds
because the element type of the shaped type can be used instead.

Reviewed By: mehdi_amini, rriddle

Differential Revision: https://reviews.llvm.org/D132535

Added: 
    

Modified: 
    mlir/include/mlir/IR/BuiltinAttributes.h
    mlir/include/mlir/IR/BuiltinAttributes.td
    mlir/lib/IR/BuiltinAttributes.cpp
    mlir/test/lib/IR/TestBuiltinAttributeInterfaces.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/IR/BuiltinAttributes.h b/mlir/include/mlir/IR/BuiltinAttributes.h
index 5a11a757f4ec7..70574b9fe9eb4 100644
--- a/mlir/include/mlir/IR/BuiltinAttributes.h
+++ b/mlir/include/mlir/IR/BuiltinAttributes.h
@@ -29,6 +29,7 @@ class IntegerSet;
 class IntegerType;
 class Location;
 class Operation;
+class RankedTensorType;
 
 //===----------------------------------------------------------------------===//
 // Elements Attributes
@@ -781,18 +782,14 @@ class DenseArrayAttr : public DenseArrayBaseAttr {
   void printWithoutBraces(raw_ostream &os) const;
 
   /// Parse the short form `[42, 100, -1]` without any type prefix.
-  static Attribute parse(AsmParser &parser, Type odsType);
+  static Attribute parse(AsmParser &parser, Type type);
 
   /// Parse the short form `42, 100, -1` without any type prefix or braces.
-  static Attribute parseWithoutBraces(AsmParser &parser, Type odsType);
+  static Attribute parseWithoutBraces(AsmParser &parser, Type type);
 
   /// Support for isa<>/cast<>.
   static bool classof(Attribute attr);
 };
-template <>
-void DenseArrayAttr<bool>::printWithoutBraces(raw_ostream &os) const;
-template <>
-void DenseArrayAttr<int8_t>::printWithoutBraces(raw_ostream &os) const;
 
 extern template class DenseArrayAttr<bool>;
 extern template class DenseArrayAttr<int8_t>;

diff  --git a/mlir/include/mlir/IR/BuiltinAttributes.td b/mlir/include/mlir/IR/BuiltinAttributes.td
index 467a0437d9493..cee4f0ef20673 100644
--- a/mlir/include/mlir/IR/BuiltinAttributes.td
+++ b/mlir/include/mlir/IR/BuiltinAttributes.td
@@ -187,13 +187,11 @@ def Builtin_DenseArrayBase : Builtin_Attr<
     [1, 2, 3]
     ```
   }];
-  let parameters = (ins AttributeSelfTypeParameter<"", "ShapedType">:$type,
-                        "DenseArrayBaseAttr::EltType":$elementType,
-                        Builtin_DenseArrayRawDataParameter:$rawData);
+  let parameters = (ins
+    AttributeSelfTypeParameter<"", "RankedTensorType">:$type,
+    Builtin_DenseArrayRawDataParameter:$rawData
+  );
   let extraClassDeclaration = [{
-    // All possible supported element type.
-    enum class EltType { I1, I8, I16, I32, I64, F32, F64 };
-
     /// Allow implicit conversion to ElementsAttr.
     operator ElementsAttr() const {
       return *this ? cast<ElementsAttr>() : nullptr;

diff  --git a/mlir/lib/IR/BuiltinAttributes.cpp b/mlir/lib/IR/BuiltinAttributes.cpp
index 20c1b0ad3539e..a7eedb25a524e 100644
--- a/mlir/lib/IR/BuiltinAttributes.cpp
+++ b/mlir/lib/IR/BuiltinAttributes.cpp
@@ -19,6 +19,7 @@
 #include "mlir/IR/Types.h"
 #include "llvm/ADT/APSInt.h"
 #include "llvm/ADT/Sequence.h"
+#include "llvm/ADT/TypeSwitch.h"
 #include "llvm/Support/Endian.h"
 
 using namespace mlir;
@@ -718,30 +719,10 @@ void DenseArrayBaseAttr::print(AsmPrinter &printer) const {
 }
 
 void DenseArrayBaseAttr::printWithoutBraces(raw_ostream &os) const {
-  switch (getElementType()) {
-  case DenseArrayBaseAttr::EltType::I1:
-    this->cast<DenseBoolArrayAttr>().printWithoutBraces(os);
-    return;
-  case DenseArrayBaseAttr::EltType::I8:
-    this->cast<DenseI8ArrayAttr>().printWithoutBraces(os);
-    return;
-  case DenseArrayBaseAttr::EltType::I16:
-    this->cast<DenseI16ArrayAttr>().printWithoutBraces(os);
-    return;
-  case DenseArrayBaseAttr::EltType::I32:
-    this->cast<DenseI32ArrayAttr>().printWithoutBraces(os);
-    return;
-  case DenseArrayBaseAttr::EltType::I64:
-    this->cast<DenseI64ArrayAttr>().printWithoutBraces(os);
-    return;
-  case DenseArrayBaseAttr::EltType::F32:
-    this->cast<DenseF32ArrayAttr>().printWithoutBraces(os);
-    return;
-  case DenseArrayBaseAttr::EltType::F64:
-    this->cast<DenseF64ArrayAttr>().printWithoutBraces(os);
-    return;
-  }
-  llvm_unreachable("<unknown DenseArrayBaseAttr>");
+  llvm::TypeSwitch<DenseArrayBaseAttr>(*this)
+      .Case<DenseBoolArrayAttr, DenseI8ArrayAttr, DenseI16ArrayAttr,
+            DenseI32ArrayAttr, DenseI64ArrayAttr, DenseF32ArrayAttr,
+            DenseF64ArrayAttr>([&](auto attr) { attr.printWithoutBraces(os); });
 }
 
 void DenseArrayBaseAttr::print(raw_ostream &os) const {
@@ -750,6 +731,89 @@ void DenseArrayBaseAttr::print(raw_ostream &os) const {
   os << "]";
 }
 
+namespace {
+/// Instantiations of this class provide utilities for interacting with native
+/// data types in the context of DenseArrayAttr.
+template <size_t width,
+          IntegerType::SignednessSemantics signedness = IntegerType::Signless>
+struct DenseArrayAttrIntUtil {
+  static bool checkElementType(Type eltType) {
+    auto type = eltType.dyn_cast<IntegerType>();
+    if (!type || type.getWidth() != width)
+      return false;
+    return type.getSignedness() == signedness;
+  }
+
+  static Type getElementType(MLIRContext *ctx) {
+    return IntegerType::get(ctx, width, signedness);
+  }
+
+  template <typename T>
+  static void printElement(raw_ostream &os, T value) {
+    os << value;
+  }
+
+  template <typename T>
+  static ParseResult parseElement(AsmParser &parser, T &value) {
+    return parser.parseInteger(value);
+  }
+};
+template <typename T>
+struct DenseArrayAttrUtil;
+
+/// Specialization for boolean elements so that they are printed as the
+/// literals 'true' and 'false'.
+template <>
+struct DenseArrayAttrUtil<bool> : public DenseArrayAttrIntUtil<1> {
+  static void printElement(raw_ostream &os, bool value) {
+    os << (value ? "true" : "false");
+  }
+};
+
+/// Specialization for 8-bit integers to ensure values are printed as integers
+/// and not characters.
+template <>
+struct DenseArrayAttrUtil<int8_t> : public DenseArrayAttrIntUtil<8> {
+  static void printElement(raw_ostream &os, int8_t value) {
+    os << static_cast<int>(value);
+  }
+};
+template <>
+struct DenseArrayAttrUtil<int16_t> : public DenseArrayAttrIntUtil<16> {};
+template <>
+struct DenseArrayAttrUtil<int32_t> : public DenseArrayAttrIntUtil<32> {};
+template <>
+struct DenseArrayAttrUtil<int64_t> : public DenseArrayAttrIntUtil<64> {};
+
+/// Specialization for 32-bit floats.
+template <>
+struct DenseArrayAttrUtil<float> {
+  static bool checkElementType(Type eltType) { return eltType.isF32(); }
+  static Type getElementType(MLIRContext *ctx) { return Float32Type::get(ctx); }
+  static void printElement(raw_ostream &os, float value) { os << value; }
+
+  /// Parse a double and cast it to a float.
+  static ParseResult parseElement(AsmParser &parser, float &value) {
+    double doubleVal;
+    if (parser.parseFloat(doubleVal))
+      return failure();
+    value = doubleVal;
+    return success();
+  }
+};
+
+/// Specialization for 64-bit floats; elements print and parse as `double`.
+template <>
+struct DenseArrayAttrUtil<double> {
+  static bool checkElementType(Type eltType) { return eltType.isF64(); }
+  static Type getElementType(MLIRContext *ctx) { return Float64Type::get(ctx); }
+  static void printElement(raw_ostream &os, double value) { os << value; }
+  static ParseResult parseElement(AsmParser &parser, double &value) {
+    return parser.parseFloat(value);
+  }
+};
+} // namespace
+
 template <typename T>
 void DenseArrayAttr<T>::print(AsmPrinter &printer) const {
   print(printer.getStream());
@@ -757,20 +821,9 @@ void DenseArrayAttr<T>::print(AsmPrinter &printer) const {
 
 template <typename T>
 void DenseArrayAttr<T>::printWithoutBraces(raw_ostream &os) const {
-  llvm::interleaveComma(asArrayRef(), os);
-}
-
-/// Specialization for bool to print `true` or `false`.
-template <>
-void DenseArrayAttr<bool>::printWithoutBraces(raw_ostream &os) const {
-  llvm::interleaveComma(asArrayRef(), os,
-                        [&](bool v) { os << (v ? "true" : "false"); });
-}
-
-/// Specialization for int8_t for forcing printing as number instead of chars.
-template <>
-void DenseArrayAttr<int8_t>::printWithoutBraces(raw_ostream &os) const {
-  llvm::interleaveComma(asArrayRef(), os, [&](int64_t v) { os << v; });
+  llvm::interleaveComma(asArrayRef(), os, [&](T value) {
+    DenseArrayAttrUtil<T>::printElement(os, value);
+  });
 }
 
 template <typename T>
@@ -780,27 +833,6 @@ void DenseArrayAttr<T>::print(raw_ostream &os) const {
   os << "]";
 }
 
-/// Parse a single element: generic template for int types, specialized for
-/// floating point and boolean values below.
-template <typename T>
-static ParseResult parseDenseArrayAttrElt(AsmParser &parser, T &value) {
-  return parser.parseInteger(value);
-}
-
-template <>
-ParseResult parseDenseArrayAttrElt<float>(AsmParser &parser, float &value) {
-  double doubleVal;
-  if (parser.parseFloat(doubleVal))
-    return failure();
-  value = doubleVal;
-  return success();
-}
-
-template <>
-ParseResult parseDenseArrayAttrElt<double>(AsmParser &parser, double &value) {
-  return parser.parseFloat(value);
-}
-
 /// Parse a DenseArrayAttr without the braces: `1, 2, 3`
 template <typename T>
 Attribute DenseArrayAttr<T>::parseWithoutBraces(AsmParser &parser,
@@ -808,7 +840,7 @@ Attribute DenseArrayAttr<T>::parseWithoutBraces(AsmParser &parser,
   SmallVector<T> data;
   if (failed(parser.parseCommaSeparatedList([&]() {
         T value;
-        if (parseDenseArrayAttrElt(parser, value))
+        if (DenseArrayAttrUtil<T>::parseElement(parser, value))
           return failure();
         data.push_back(value);
         return success();
@@ -840,87 +872,23 @@ DenseArrayAttr<T>::operator ArrayRef<T>() const {
                      raw.size() / sizeof(T));
 }
 
-namespace {
-/// Mapping from C++ element type to MLIR DenseArrayAttr internals.
-template <typename T>
-struct denseArrayAttrEltTypeBuilder;
-template <>
-struct denseArrayAttrEltTypeBuilder<bool> {
-  constexpr static auto eltType = DenseArrayBaseAttr::EltType::I1;
-  static ShapedType getShapedType(MLIRContext *context,
-                                  ArrayRef<int64_t> shape) {
-    return RankedTensorType::get(shape, IntegerType::get(context, 1));
-  }
-};
-template <>
-struct denseArrayAttrEltTypeBuilder<int8_t> {
-  constexpr static auto eltType = DenseArrayBaseAttr::EltType::I8;
-  static ShapedType getShapedType(MLIRContext *context,
-                                  ArrayRef<int64_t> shape) {
-    return RankedTensorType::get(shape, IntegerType::get(context, 8));
-  }
-};
-template <>
-struct denseArrayAttrEltTypeBuilder<int16_t> {
-  constexpr static auto eltType = DenseArrayBaseAttr::EltType::I16;
-  static ShapedType getShapedType(MLIRContext *context,
-                                  ArrayRef<int64_t> shape) {
-    return RankedTensorType::get(shape, IntegerType::get(context, 16));
-  }
-};
-template <>
-struct denseArrayAttrEltTypeBuilder<int32_t> {
-  constexpr static auto eltType = DenseArrayBaseAttr::EltType::I32;
-  static ShapedType getShapedType(MLIRContext *context,
-                                  ArrayRef<int64_t> shape) {
-    return RankedTensorType::get(shape, IntegerType::get(context, 32));
-  }
-};
-template <>
-struct denseArrayAttrEltTypeBuilder<int64_t> {
-  constexpr static auto eltType = DenseArrayBaseAttr::EltType::I64;
-  static ShapedType getShapedType(MLIRContext *context,
-                                  ArrayRef<int64_t> shape) {
-    return RankedTensorType::get(shape, IntegerType::get(context, 64));
-  }
-};
-template <>
-struct denseArrayAttrEltTypeBuilder<float> {
-  constexpr static auto eltType = DenseArrayBaseAttr::EltType::F32;
-  static ShapedType getShapedType(MLIRContext *context,
-                                  ArrayRef<int64_t> shape) {
-    return RankedTensorType::get(shape, Float32Type::get(context));
-  }
-};
-template <>
-struct denseArrayAttrEltTypeBuilder<double> {
-  constexpr static auto eltType = DenseArrayBaseAttr::EltType::F64;
-  static ShapedType getShapedType(MLIRContext *context,
-                                  ArrayRef<int64_t> shape) {
-    return RankedTensorType::get(shape, Float64Type::get(context));
-  }
-};
-} // namespace
-
 /// Builds a DenseArrayAttr<T> from an ArrayRef<T>.
 template <typename T>
 DenseArrayAttr<T> DenseArrayAttr<T>::get(MLIRContext *context,
                                          ArrayRef<T> content) {
-  auto size = static_cast<int64_t>(content.size());
-  auto shapedType =
-      denseArrayAttrEltTypeBuilder<T>::getShapedType(context, size);
-  auto eltType = denseArrayAttrEltTypeBuilder<T>::eltType;
+  auto shapedType = RankedTensorType::get(
+      content.size(), DenseArrayAttrUtil<T>::getElementType(context));
   auto rawArray = ArrayRef<char>(reinterpret_cast<const char *>(content.data()),
                                  content.size() * sizeof(T));
-  return Base::get(context, shapedType, eltType, rawArray)
+  return Base::get(context, shapedType, rawArray)
       .template cast<DenseArrayAttr<T>>();
 }
 
 template <typename T>
 bool DenseArrayAttr<T>::classof(Attribute attr) {
-  return attr.isa<DenseArrayBaseAttr>() &&
-         attr.cast<DenseArrayBaseAttr>().getElementType() ==
-             denseArrayAttrEltTypeBuilder<T>::eltType;
+  if (auto denseArray = attr.dyn_cast<DenseArrayBaseAttr>())
+    return DenseArrayAttrUtil<T>::checkElementType(denseArray.getElementType());
+  return false;
 }
 
 namespace mlir {

diff  --git a/mlir/test/lib/IR/TestBuiltinAttributeInterfaces.cpp b/mlir/test/lib/IR/TestBuiltinAttributeInterfaces.cpp
index 453a0cecd473e..6257799824c6f 100644
--- a/mlir/test/lib/IR/TestBuiltinAttributeInterfaces.cpp
+++ b/mlir/test/lib/IR/TestBuiltinAttributeInterfaces.cpp
@@ -9,6 +9,7 @@
 #include "TestAttributes.h"
 #include "mlir/IR/BuiltinOps.h"
 #include "mlir/Pass/Pass.h"
+#include "llvm/ADT/TypeSwitch.h"
 #include "llvm/Support/FormatVariadic.h"
 
 using namespace mlir;
@@ -42,29 +43,28 @@ struct TestElementsAttrInterface
           continue;
         if (auto concreteAttr =
                 attr.getValue().dyn_cast<DenseArrayBaseAttr>()) {
-          switch (concreteAttr.getElementType()) {
-          case DenseArrayBaseAttr::EltType::I1:
-            testElementsAttrIteration<bool>(op, elementsAttr, "bool");
-            break;
-          case DenseArrayBaseAttr::EltType::I8:
-            testElementsAttrIteration<int8_t>(op, elementsAttr, "int8_t");
-            break;
-          case DenseArrayBaseAttr::EltType::I16:
-            testElementsAttrIteration<int16_t>(op, elementsAttr, "int16_t");
-            break;
-          case DenseArrayBaseAttr::EltType::I32:
-            testElementsAttrIteration<int32_t>(op, elementsAttr, "int32_t");
-            break;
-          case DenseArrayBaseAttr::EltType::I64:
-            testElementsAttrIteration<int64_t>(op, elementsAttr, "int64_t");
-            break;
-          case DenseArrayBaseAttr::EltType::F32:
-            testElementsAttrIteration<float>(op, elementsAttr, "float");
-            break;
-          case DenseArrayBaseAttr::EltType::F64:
-            testElementsAttrIteration<double>(op, elementsAttr, "double");
-            break;
-          }
+          llvm::TypeSwitch<DenseArrayBaseAttr>(concreteAttr)
+              .Case([&](DenseBoolArrayAttr attr) {
+                testElementsAttrIteration<bool>(op, attr, "bool");
+              })
+              .Case([&](DenseI8ArrayAttr attr) {
+                testElementsAttrIteration<int8_t>(op, attr, "int8_t");
+              })
+              .Case([&](DenseI16ArrayAttr attr) {
+                testElementsAttrIteration<int16_t>(op, attr, "int16_t");
+              })
+              .Case([&](DenseI32ArrayAttr attr) {
+                testElementsAttrIteration<int32_t>(op, attr, "int32_t");
+              })
+              .Case([&](DenseI64ArrayAttr attr) {
+                testElementsAttrIteration<int64_t>(op, attr, "int64_t");
+              })
+              .Case([&](DenseF32ArrayAttr attr) {
+                testElementsAttrIteration<float>(op, attr, "float");
+              })
+              .Case([&](DenseF64ArrayAttr attr) {
+                testElementsAttrIteration<double>(op, attr, "double");
+              });
           continue;
         }
         testElementsAttrIteration<int64_t>(op, elementsAttr, "int64_t");


        


More information about the Mlir-commits mailing list