[Mlir-commits] [mlir] d0541b4 - [mlir] Add I1 support to DenseArrayAttr

Jeff Niu llvmlistbot at llvm.org
Thu Aug 4 07:24:52 PDT 2022


Author: Jeff Niu
Date: 2022-08-04T10:24:45-04:00
New Revision: d0541b47000739c68c540170c6b9790ec1ea3b77

URL: https://github.com/llvm/llvm-project/commit/d0541b47000739c68c540170c6b9790ec1ea3b77
DIFF: https://github.com/llvm/llvm-project/commit/d0541b47000739c68c540170c6b9790ec1ea3b77.diff

LOG: [mlir] Add I1 support to DenseArrayAttr

This patch adds a DenseBoolArrayAttr to support arrays of i1. Importantly,
the attribute is implemented as a simple `ArrayRef<bool>` rather than with
bit compression, which was problematic in DenseElementsAttr.

Reviewed By: rriddle

Differential Revision: https://reviews.llvm.org/D130957

Added: 
    

Modified: 
    mlir/include/mlir/IR/BuiltinAttributes.h
    mlir/include/mlir/IR/BuiltinAttributes.td
    mlir/include/mlir/IR/OpBase.td
    mlir/lib/AsmParser/AttributeParser.cpp
    mlir/lib/AsmParser/Parser.cpp
    mlir/lib/IR/AsmPrinter.cpp
    mlir/lib/IR/BuiltinAttributes.cpp
    mlir/test/IR/attribute.mlir
    mlir/test/IR/elements-attr-interface.mlir
    mlir/test/lib/Dialect/Test/TestOps.td
    mlir/test/lib/IR/TestBuiltinAttributeInterfaces.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/IR/BuiltinAttributes.h b/mlir/include/mlir/IR/BuiltinAttributes.h
index eb8f0ca8334ec..d73d087597f71 100644
--- a/mlir/include/mlir/IR/BuiltinAttributes.h
+++ b/mlir/include/mlir/IR/BuiltinAttributes.h
@@ -791,8 +791,11 @@ class DenseArrayAttr : public DenseArrayBaseAttr {
   static bool classof(Attribute attr);
 };
 template <>
+void DenseArrayAttr<bool>::printWithoutBraces(raw_ostream &os) const;
+template <>
 void DenseArrayAttr<int8_t>::printWithoutBraces(raw_ostream &os) const;
 
+extern template class DenseArrayAttr<bool>;
 extern template class DenseArrayAttr<int8_t>;
 extern template class DenseArrayAttr<int16_t>;
 extern template class DenseArrayAttr<int32_t>;
@@ -802,6 +805,7 @@ extern template class DenseArrayAttr<double>;
 } // namespace detail
 
 // Public name for all the supported DenseArrayAttr
+using DenseBoolArrayAttr = detail::DenseArrayAttr<bool>;
 using DenseI8ArrayAttr = detail::DenseArrayAttr<int8_t>;
 using DenseI16ArrayAttr = detail::DenseArrayAttr<int16_t>;
 using DenseI32ArrayAttr = detail::DenseArrayAttr<int32_t>;

diff  --git a/mlir/include/mlir/IR/BuiltinAttributes.td b/mlir/include/mlir/IR/BuiltinAttributes.td
index c053d7e7b5d25..e710d8d3ff0f7 100644
--- a/mlir/include/mlir/IR/BuiltinAttributes.td
+++ b/mlir/include/mlir/IR/BuiltinAttributes.td
@@ -180,7 +180,7 @@ def Builtin_DenseArrayBase : Builtin_Attr<
                         ArrayRefParameter<"char">:$elements);
   let extraClassDeclaration = [{
     // All possible supported element type.
-    enum class EltType { I8, I16, I32, I64, F32, F64 };
+    enum class EltType { I1, I8, I16, I32, I64, F32, F64 };
 
     /// Allow implicit conversion to ElementsAttr.
     operator ElementsAttr() const {
@@ -189,7 +189,8 @@ def Builtin_DenseArrayBase : Builtin_Attr<
 
     /// ElementsAttr implementation.
     using ContiguousIterableTypesT =
-        std::tuple<int8_t, int16_t, int32_t, int64_t, float, double>;
+        std::tuple<bool, int8_t, int16_t, int32_t, int64_t, float, double>;
+    const bool *value_begin_impl(OverloadToken<bool>) const;
     const int8_t *value_begin_impl(OverloadToken<int8_t>) const;
     const int16_t *value_begin_impl(OverloadToken<int16_t>) const;
     const int32_t *value_begin_impl(OverloadToken<int32_t>) const;

diff  --git a/mlir/include/mlir/IR/OpBase.td b/mlir/include/mlir/IR/OpBase.td
index 56a960f69131f..0202f829df325 100644
--- a/mlir/include/mlir/IR/OpBase.td
+++ b/mlir/include/mlir/IR/OpBase.td
@@ -1282,6 +1282,7 @@ class DenseArrayAttrBase<string denseAttrName, string cppType, string summaryNam
   let storageType = "::mlir::" # denseAttrName;
   let returnType = "::llvm::ArrayRef<" # cppType # ">";
 }
+def DenseBoolArrayAttr : DenseArrayAttrBase<"DenseBoolArrayAttr", "bool", "i1">;
 def DenseI8ArrayAttr : DenseArrayAttrBase<"DenseI8ArrayAttr", "int8_t", "i8">;
 def DenseI16ArrayAttr : DenseArrayAttrBase<"DenseI16ArrayAttr", "int16_t", "i16">;
 def DenseI32ArrayAttr : DenseArrayAttrBase<"DenseI32ArrayAttr", "int32_t", "i32">;

diff  --git a/mlir/lib/AsmParser/AttributeParser.cpp b/mlir/lib/AsmParser/AttributeParser.cpp
index e197f784d4605..a554ad2df6f5d 100644
--- a/mlir/lib/AsmParser/AttributeParser.cpp
+++ b/mlir/lib/AsmParser/AttributeParser.cpp
@@ -845,6 +845,12 @@ Attribute Parser::parseDenseArrayAttr() {
 
   if (auto intType = type.dyn_cast<IntegerType>()) {
     switch (type.getIntOrFloatBitWidth()) {
+    case 1:
+      if (isEmptyList)
+        result = DenseBoolArrayAttr::get(parser.getContext(), {});
+      else
+        result = DenseBoolArrayAttr::parseWithoutBraces(parser, Type{});
+      break;
     case 8:
       if (isEmptyList)
         result = DenseI8ArrayAttr::get(parser.getContext(), {});
@@ -870,7 +876,7 @@ Attribute Parser::parseDenseArrayAttr() {
         result = DenseI64ArrayAttr::parseWithoutBraces(parser, Type{});
       break;
     default:
-      emitError(typeLoc, "expected i8, i16, i32, or i64 but got: ") << type;
+      emitError(typeLoc, "expected i1, i8, i16, i32, or i64 but got: ") << type;
       return {};
     }
   } else if (auto floatType = type.dyn_cast<FloatType>()) {

diff  --git a/mlir/lib/AsmParser/Parser.cpp b/mlir/lib/AsmParser/Parser.cpp
index 6cc96e7cfd0ac..e30617b906cdc 100644
--- a/mlir/lib/AsmParser/Parser.cpp
+++ b/mlir/lib/AsmParser/Parser.cpp
@@ -238,6 +238,15 @@ ParseResult Parser::parseToken(Token::Kind expectedToken,
 
 /// Parse an optional integer value from the stream.
 OptionalParseResult Parser::parseOptionalInteger(APInt &result) {
+  // Parse `false` and `true` keywords as 0 and 1 respectively.
+  if (consumeIf(Token::kw_false)) {
+    result = false;
+    return success();
+  } else if (consumeIf(Token::kw_true)) {
+    result = true;
+    return success();
+  }
+
   Token curToken = getToken();
   if (curToken.isNot(Token::integer, Token::minus))
     return llvm::None;

diff  --git a/mlir/lib/IR/AsmPrinter.cpp b/mlir/lib/IR/AsmPrinter.cpp
index 44a41946acff1..42858ad53ebe9 100644
--- a/mlir/lib/IR/AsmPrinter.cpp
+++ b/mlir/lib/IR/AsmPrinter.cpp
@@ -1860,26 +1860,7 @@ void AsmPrinter::Impl::printAttribute(Attribute attr,
     }
   } else if (auto denseArrayAttr = attr.dyn_cast<DenseArrayBaseAttr>()) {
     typeElision = AttrTypeElision::Must;
-    switch (denseArrayAttr.getElementType()) {
-    case DenseArrayBaseAttr::EltType::I8:
-      os << "[:i8";
-      break;
-    case DenseArrayBaseAttr::EltType::I16:
-      os << "[:i16";
-      break;
-    case DenseArrayBaseAttr::EltType::I32:
-      os << "[:i32";
-      break;
-    case DenseArrayBaseAttr::EltType::I64:
-      os << "[:i64";
-      break;
-    case DenseArrayBaseAttr::EltType::F32:
-      os << "[:f32";
-      break;
-    case DenseArrayBaseAttr::EltType::F64:
-      os << "[:f64";
-      break;
-    }
+    os << "[:" << denseArrayAttr.getType().getElementType();
     if (denseArrayAttr.size())
       os << " ";
     denseArrayAttr.printWithoutBraces(os);

diff  --git a/mlir/lib/IR/BuiltinAttributes.cpp b/mlir/lib/IR/BuiltinAttributes.cpp
index ce7dc22accb43..334f20c2a446a 100644
--- a/mlir/lib/IR/BuiltinAttributes.cpp
+++ b/mlir/lib/IR/BuiltinAttributes.cpp
@@ -732,6 +732,9 @@ DenseArrayBaseAttr::EltType DenseArrayBaseAttr::getElementType() const {
 
 ShapedType DenseArrayBaseAttr::getType() const { return getImpl()->type; }
 
+const bool *DenseArrayBaseAttr::value_begin_impl(OverloadToken<bool>) const {
+  return cast<DenseBoolArrayAttr>().asArrayRef().begin();
+}
 const int8_t *
 DenseArrayBaseAttr::value_begin_impl(OverloadToken<int8_t>) const {
   return cast<DenseI8ArrayAttr>().asArrayRef().begin();
@@ -762,6 +765,9 @@ void DenseArrayBaseAttr::print(AsmPrinter &printer) const {
 
 void DenseArrayBaseAttr::printWithoutBraces(raw_ostream &os) const {
   switch (getElementType()) {
+  case DenseArrayBaseAttr::EltType::I1:
+    this->cast<DenseBoolArrayAttr>().printWithoutBraces(os);
+    return;
   case DenseArrayBaseAttr::EltType::I8:
     this->cast<DenseI8ArrayAttr>().printWithoutBraces(os);
     return;
@@ -797,15 +803,20 @@ void DenseArrayAttr<T>::print(AsmPrinter &printer) const {
 
 template <typename T>
 void DenseArrayAttr<T>::printWithoutBraces(raw_ostream &os) const {
-  ArrayRef<T> values{*this};
-  llvm::interleaveComma(values, os);
+  llvm::interleaveComma(asArrayRef(), os);
+}
+
+/// Specialization for bool to print `true` or `false`.
+template <>
+void DenseArrayAttr<bool>::printWithoutBraces(raw_ostream &os) const {
+  llvm::interleaveComma(asArrayRef(), os,
+                        [&](bool v) { os << (v ? "true" : "false"); });
 }
 
 /// Specialization for int8_t for forcing printing as number instead of chars.
 template <>
 void DenseArrayAttr<int8_t>::printWithoutBraces(raw_ostream &os) const {
-  ArrayRef<int8_t> values{*this};
-  llvm::interleaveComma(values, os, [&](int64_t v) { os << v; });
+  llvm::interleaveComma(asArrayRef(), os, [&](int64_t v) { os << v; });
 }
 
 template <typename T>
@@ -816,7 +827,7 @@ void DenseArrayAttr<T>::print(raw_ostream &os) const {
 }
 
 /// Parse a single element: generic template for int types, specialized for
-/// floating points below.
+/// floating point and boolean values below.
 template <typename T>
 static ParseResult parseDenseArrayAttrElt(AsmParser &parser, T &value) {
   return parser.parseInteger(value);
@@ -880,6 +891,14 @@ namespace {
 template <typename T>
 struct denseArrayAttrEltTypeBuilder;
 template <>
+struct denseArrayAttrEltTypeBuilder<bool> {
+  constexpr static auto eltType = DenseArrayBaseAttr::EltType::I1;
+  static ShapedType getShapedType(MLIRContext *context,
+                                  ArrayRef<int64_t> shape) {
+    return RankedTensorType::get(shape, IntegerType::get(context, 1));
+  }
+};
+template <>
 struct denseArrayAttrEltTypeBuilder<int8_t> {
   constexpr static auto eltType = DenseArrayBaseAttr::EltType::I8;
   static ShapedType getShapedType(MLIRContext *context,
@@ -953,6 +972,7 @@ bool DenseArrayAttr<T>::classof(Attribute attr) {
 namespace mlir {
 namespace detail {
 // Explicit instantiation for all the supported DenseArrayAttr.
+template class DenseArrayAttr<bool>;
 template class DenseArrayAttr<int8_t>;
 template class DenseArrayAttr<int16_t>;
 template class DenseArrayAttr<int32_t>;

diff  --git a/mlir/test/IR/attribute.mlir b/mlir/test/IR/attribute.mlir
index d0bcce65c7ccc..4b17194ec4f57 100644
--- a/mlir/test/IR/attribute.mlir
+++ b/mlir/test/IR/attribute.mlir
@@ -521,13 +521,15 @@ func.func @simple_scalar_example() {
 //===----------------------------------------------------------------------===//
 
 // CHECK-LABEL: func @dense_array_attr
-func.func @dense_array_attr() attributes{
+func.func @dense_array_attr() attributes {
 // CHECK-SAME: emptyf32attr = [:f32],
                emptyf32attr = [:f32],
 // CHECK-SAME: emptyf64attr = [:f64],
                emptyf64attr = [:f64],
 // CHECK-SAME: emptyi16attr = [:i16],
                emptyi16attr = [:i16],
+// CHECK-SAME: emptyi1attr = [:i1],
+               emptyi1attr = [:i1],
 // CHECK-SAME: emptyi32attr = [:i32],
                emptyi32attr = [:i32],
 // CHECK-SAME: emptyi64attr = [:i64],
@@ -540,6 +542,8 @@ func.func @dense_array_attr() attributes{
                f64attr = [:f64 -142.],
 // CHECK-SAME: i16attr = [:i16 3, 5, -4, 10],
                i16attr = [:i16 3, 5, -4, 10],
+// CHECK-SAME: i1attr = [:i1 true, false, true],
+               i1attr = [:i1 true, false, true],
 // CHECK-SAME: i32attr = [:i32 1024, 453, -6435],
                i32attr = [:i32 1024, 453, -6435],
 // CHECK-SAME: i64attr = [:i64 -142],
@@ -549,6 +553,8 @@ func.func @dense_array_attr() attributes{
  } {
 // CHECK:  test.dense_array_attr
   test.dense_array_attr
+// CHECK-SAME: i1attr = [true, false, true]
+               i1attr = [true, false, true]
 // CHECK-SAME: i8attr = [1, -2, 3]
                i8attr = [1, -2, 3]
 // CHECK-SAME: i16attr = [3, 5, -4, 10]

diff  --git a/mlir/test/IR/elements-attr-interface.mlir b/mlir/test/IR/elements-attr-interface.mlir
index 04a9bd6214d3b..d094edec24bfa 100644
--- a/mlir/test/IR/elements-attr-interface.mlir
+++ b/mlir/test/IR/elements-attr-interface.mlir
@@ -27,6 +27,8 @@ arith.constant dense<[10, 11, 12, 13, 14]> : tensor<5xi64>
 // expected-error @below {{Test iterating `IntegerAttr`: }}
 arith.constant dense<> : tensor<0xi64>
 
+// expected-error @below {{Test iterating `bool`: true, false, true, false, true, false}}
+arith.constant [:i1 true, false, true, false, true, false]
 // expected-error @below {{Test iterating `int8_t`: 10, 11, -12, 13, 14}}
 arith.constant [:i8 10, 11, -12, 13, 14]
 // expected-error @below {{Test iterating `int16_t`: 10, 11, -12, 13, 14}}

diff  --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td
index 3b89b188da49f..b500a5458d44c 100644
--- a/mlir/test/lib/Dialect/Test/TestOps.td
+++ b/mlir/test/lib/Dialect/Test/TestOps.td
@@ -272,6 +272,7 @@ def StringElementsAttrOp : TEST_Op<"string_elements_attr"> {
 
 def DenseArrayAttrOp : TEST_Op<"dense_array_attr"> {
   let arguments = (ins
+    DenseBoolArrayAttr:$i1attr,
     DenseI8ArrayAttr:$i8attr,
     DenseI16ArrayAttr:$i16attr,
     DenseI32ArrayAttr:$i32attr,
@@ -281,10 +282,9 @@ def DenseArrayAttrOp : TEST_Op<"dense_array_attr"> {
     DenseI32ArrayAttr:$emptyattr
   );
   let assemblyFormat = [{
-   `i8attr` `=` $i8attr `i16attr` `=` $i16attr `i32attr` `=` $i32attr
-   `i64attr` `=` $i64attr  `f32attr` `=` $f32attr `f64attr` `=` $f64attr
-   `emptyattr` `=` $emptyattr
-   attr-dict
+   `i1attr` `=` $i1attr `i8attr` `=` $i8attr `i16attr` `=` $i16attr
+   `i32attr` `=` $i32attr `i64attr` `=` $i64attr  `f32attr` `=` $f32attr
+   `f64attr` `=` $f64attr `emptyattr` `=` $emptyattr attr-dict
   }];
 }
 

diff  --git a/mlir/test/lib/IR/TestBuiltinAttributeInterfaces.cpp b/mlir/test/lib/IR/TestBuiltinAttributeInterfaces.cpp
index f32a49bd5bedb..453a0cecd473e 100644
--- a/mlir/test/lib/IR/TestBuiltinAttributeInterfaces.cpp
+++ b/mlir/test/lib/IR/TestBuiltinAttributeInterfaces.cpp
@@ -43,6 +43,9 @@ struct TestElementsAttrInterface
         if (auto concreteAttr =
                 attr.getValue().dyn_cast<DenseArrayBaseAttr>()) {
           switch (concreteAttr.getElementType()) {
+          case DenseArrayBaseAttr::EltType::I1:
+            testElementsAttrIteration<bool>(op, elementsAttr, "bool");
+            break;
           case DenseArrayBaseAttr::EltType::I8:
             testElementsAttrIteration<int8_t>(op, elementsAttr, "int8_t");
             break;


        


More information about the Mlir-commits mailing list