[Mlir-commits] [mlir] 8db947d - [mlir] Make ElementsAttr inherit from TypedAttr

Rahul Kayaith llvmlistbot at llvm.org
Thu Apr 20 13:32:03 PDT 2023


Author: Rahul Kayaith
Date: 2023-04-20T16:31:53-04:00
New Revision: 8db947da0b377a13cd0e536b1bbc825abff3c001

URL: https://github.com/llvm/llvm-project/commit/8db947da0b377a13cd0e536b1bbc825abff3c001
DIFF: https://github.com/llvm/llvm-project/commit/8db947da0b377a13cd0e536b1bbc825abff3c001.diff

LOG: [mlir] Make ElementsAttr inherit from TypedAttr

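This allows implicit conversion from `ElementsAttr` to `TypedAttr`. Doing so
required renaming the `ElementsAttr::getType()` interface method (which returns
a `ShapedType`) to `getShapedType`, since it would otherwise clash with the
inherited `TypedAttr::getType()` method (which returns a plain `Type`).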

Reviewed By: rriddle

Differential Revision: https://reviews.llvm.org/D148492

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/CommonFolders.h
    mlir/include/mlir/IR/BuiltinAttributeInterfaces.td
    mlir/include/mlir/IR/BuiltinAttributes.td
    mlir/lib/IR/BuiltinAttributeInterfaces.cpp
    mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
    mlir/lib/Transforms/ViewOpGraph.cpp
    mlir/test/lib/Dialect/Test/TestAttrDefs.td

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/CommonFolders.h b/mlir/include/mlir/Dialect/CommonFolders.h
index 4f24580d02a70..5633b4ccd7cf6 100644
--- a/mlir/include/mlir/Dialect/CommonFolders.h
+++ b/mlir/include/mlir/Dialect/CommonFolders.h
@@ -189,7 +189,7 @@ Attribute constFoldUnaryOpConditional(ArrayRef<Attribute> operands,
         return {};
       elementResults.push_back(*elementResult);
     }
-    return DenseElementsAttr::get(op.getType(), elementResults);
+    return DenseElementsAttr::get(op.getShapedType(), elementResults);
   }
   return {};
 }

diff  --git a/mlir/include/mlir/IR/BuiltinAttributeInterfaces.td b/mlir/include/mlir/IR/BuiltinAttributeInterfaces.td
index 19336e0483860..39c871d7cab15 100644
--- a/mlir/include/mlir/IR/BuiltinAttributeInterfaces.td
+++ b/mlir/include/mlir/IR/BuiltinAttributeInterfaces.td
@@ -15,11 +15,30 @@
 
 include "mlir/IR/OpBase.td"
 
+//===----------------------------------------------------------------------===//
+// TypedAttrInterface
+//===----------------------------------------------------------------------===//
+
+def TypedAttrInterface : AttrInterface<"TypedAttr"> {
+  let cppNamespace = "::mlir";
+
+  let description = [{
+    This interface is used for attributes that have a type. The type of an
+    attribute is understood to represent the type of the data contained in the
+    attribute and is often used as the type of a value with this data.
+  }];
+
+  let methods = [InterfaceMethod<
+    "Get the attribute's type",
+    "::mlir::Type", "getType"
+  >];
+}
+
 //===----------------------------------------------------------------------===//
 // ElementsAttrInterface
 //===----------------------------------------------------------------------===//
 
-def ElementsAttrInterface : AttrInterface<"ElementsAttr"> {
+def ElementsAttrInterface : AttrInterface<"ElementsAttr", [TypedAttrInterface]> {
   let cppNamespace = "::mlir";
   let description = [{
     This interface is used for attributes that contain the constant elements of
@@ -78,7 +97,7 @@ def ElementsAttrInterface : AttrInterface<"ElementsAttr"> {
     }
     /// * Attribute
     auto value_begin_impl(OverloadToken<mlir::Attribute>) const {
-      mlir::Type elementType = getType().getElementType();
+      mlir::Type elementType = getShapedType().getElementType();
       auto it = llvm::map_range(getElements(), [=](uint64_t value) {
         return mlir::IntegerAttr::get(elementType,
                                       llvm::APInt(/*numBits=*/64, value));
@@ -154,13 +173,15 @@ def ElementsAttrInterface : AttrInterface<"ElementsAttr"> {
     InterfaceMethod<[{
       Returns true if the attribute elements correspond to a splat, i.e. that
       all elements of the attribute are the same value.
-    }], "bool", "isSplat", (ins), /*defaultImplementation=*/[{}], [{
+    }], "bool", "isSplat", (ins), [{}], /*defaultImplementation=*/[{
         // By default, only check for a single element splat.
         return $_attr.getNumElements() == 1;
     }]>,
     InterfaceMethod<[{
       Returns the shaped type of the elements attribute.
-    }], "::mlir::ShapedType", "getType">
+    }], "::mlir::ShapedType", "getShapedType", (ins), [{}], /*defaultImplementation=*/[{
+        return $_attr.getType();
+    }]>
   ];
 
   string ElementsAttrInterfaceAccessors = [{
@@ -325,7 +346,7 @@ def ElementsAttrInterface : AttrInterface<"ElementsAttr"> {
                                       ArrayRef<uint64_t> index);
     static uint64_t getFlattenedIndex(ElementsAttr elementsAttr,
                                       ArrayRef<uint64_t> index) {
-      return getFlattenedIndex(elementsAttr.getType(), index);
+      return getFlattenedIndex(elementsAttr.getShapedType(), index);
     }
 
     /// Returns the number of elements held by this attribute.
@@ -357,7 +378,7 @@ def ElementsAttrInterface : AttrInterface<"ElementsAttr"> {
     /// Return the elements of this attribute as a value of type 'T'.
     template <typename T>
     DefaultValueCheckT<T, iterator_range<T>> getValues() const {
-      return {getType(), value_begin<T>(), value_end<T>()};
+      return {getShapedType(), value_begin<T>(), value_end<T>()};
     }
     template <typename T>
     DefaultValueCheckT<T, iterator<T>> value_begin() const;
@@ -377,7 +398,7 @@ def ElementsAttrInterface : AttrInterface<"ElementsAttr"> {
     template <typename T, typename = DerivedAttrValueCheckT<T>>
     DerivedAttrValueIteratorRange<T> getValues() const {
       auto castFn = [](Attribute attr) { return attr.template cast<T>(); };
-      return {getType(), llvm::map_range(getValues<Attribute>(),
+      return {getShapedType(), llvm::map_range(getValues<Attribute>(),
               static_cast<T (*)(Attribute)>(castFn))};
     }
     template <typename T, typename = DerivedAttrValueCheckT<T>>
@@ -397,7 +418,7 @@ def ElementsAttrInterface : AttrInterface<"ElementsAttr"> {
     template <typename T>
     DefaultValueCheckT<T, std::optional<iterator_range<T>>> tryGetValues() const {
       if (std::optional<iterator<T>> beginIt = try_value_begin<T>())
-        return iterator_range<T>(getType(), *beginIt, value_end<T>());
+        return iterator_range<T>(getShapedType(), *beginIt, value_end<T>());
       return std::nullopt;
     }
     template <typename T>
@@ -413,7 +434,7 @@ def ElementsAttrInterface : AttrInterface<"ElementsAttr"> {
 
       auto castFn = [](Attribute attr) { return attr.template cast<T>(); };
       return DerivedAttrValueIteratorRange<T>(
-        getType(),
+        getShapedType(),
         llvm::map_range(*values, static_cast<T (*)(Attribute)>(castFn))
       );
     }
@@ -474,23 +495,4 @@ def MemRefLayoutAttrInterface : AttrInterface<"MemRefLayoutAttrInterface"> {
   ];
 }
 
-//===----------------------------------------------------------------------===//
-// TypedAttrInterface
-//===----------------------------------------------------------------------===//
-
-def TypedAttrInterface : AttrInterface<"TypedAttr"> {
-  let cppNamespace = "::mlir";
-
-  let description = [{
-    This interface is used for attributes that have a type. The type of an
-    attribute is understood to represent the type of the data contained in the
-    attribute and is often used as the type of a value with this data.
-  }];
-
-  let methods = [InterfaceMethod<
-    "Get the attribute's type",
-    "::mlir::Type", "getType"
-  >];
-}
-
 #endif // MLIR_IR_BUILTINATTRIBUTEINTERFACES_TD_

diff  --git a/mlir/include/mlir/IR/BuiltinAttributes.td b/mlir/include/mlir/IR/BuiltinAttributes.td
index 446bc375d3fa6..01ab1a765aae0 100644
--- a/mlir/include/mlir/IR/BuiltinAttributes.td
+++ b/mlir/include/mlir/IR/BuiltinAttributes.td
@@ -218,7 +218,7 @@ def Builtin_DenseArray : Builtin_Attr<"DenseArray"> {
 //===----------------------------------------------------------------------===//
 
 def Builtin_DenseIntOrFPElementsAttr : Builtin_Attr<
-    "DenseIntOrFPElements", [ElementsAttrInterface, TypedAttrInterface],
+    "DenseIntOrFPElements", [ElementsAttrInterface],
     "DenseElementsAttr"
   > {
   let summary = "An Attribute containing a dense multi-dimensional array of "
@@ -359,7 +359,7 @@ def Builtin_DenseIntOrFPElementsAttr : Builtin_Attr<
 //===----------------------------------------------------------------------===//
 
 def Builtin_DenseStringElementsAttr : Builtin_Attr<
-    "DenseStringElements", [ElementsAttrInterface, TypedAttrInterface],
+    "DenseStringElements", [ElementsAttrInterface],
     "DenseElementsAttr"
   > {
   let summary = "An Attribute containing a dense multi-dimensional array of "
@@ -430,7 +430,7 @@ def Builtin_DenseStringElementsAttr : Builtin_Attr<
 //===----------------------------------------------------------------------===//
 
 def Builtin_DenseResourceElementsAttr : Builtin_Attr<"DenseResourceElements", [
-    ElementsAttrInterface, TypedAttrInterface
+    ElementsAttrInterface
   ]> {
   let summary = "An Attribute containing a dense multi-dimensional array "
                 "backed by a resource";
@@ -804,7 +804,7 @@ def Builtin_OpaqueAttr : Builtin_Attr<"Opaque", [TypedAttrInterface]> {
 //===----------------------------------------------------------------------===//
 
 def Builtin_SparseElementsAttr : Builtin_Attr<
-    "SparseElements", [ElementsAttrInterface, TypedAttrInterface]
+    "SparseElements", [ElementsAttrInterface]
   > {
   let summary = "An opaque representation of a multi-dimensional array";
   let description = [{

diff  --git a/mlir/lib/IR/BuiltinAttributeInterfaces.cpp b/mlir/lib/IR/BuiltinAttributeInterfaces.cpp
index 6e35f120e22c4..ab216baede61c 100644
--- a/mlir/lib/IR/BuiltinAttributeInterfaces.cpp
+++ b/mlir/lib/IR/BuiltinAttributeInterfaces.cpp
@@ -25,11 +25,11 @@ using namespace mlir::detail;
 //===----------------------------------------------------------------------===//
 
 Type ElementsAttr::getElementType(ElementsAttr elementsAttr) {
-  return elementsAttr.getType().getElementType();
+  return elementsAttr.getShapedType().getElementType();
 }
 
 int64_t ElementsAttr::getNumElements(ElementsAttr elementsAttr) {
-  return elementsAttr.getType().getNumElements();
+  return elementsAttr.getShapedType().getNumElements();
 }
 
 bool ElementsAttr::isValidIndex(ShapedType type, ArrayRef<uint64_t> index) {
@@ -49,7 +49,7 @@ bool ElementsAttr::isValidIndex(ShapedType type, ArrayRef<uint64_t> index) {
 }
 bool ElementsAttr::isValidIndex(ElementsAttr elementsAttr,
                                 ArrayRef<uint64_t> index) {
-  return isValidIndex(elementsAttr.getType(), index);
+  return isValidIndex(elementsAttr.getShapedType(), index);
 }
 
 uint64_t ElementsAttr::getFlattenedIndex(Type type, ArrayRef<uint64_t> index) {

diff  --git a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
index a08bb372fc738..66891afaf5d3b 100644
--- a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
@@ -407,8 +407,8 @@ llvm::Constant *mlir::LLVM::detail::getLLVMConstant(
 
   // Fall back to element-by-element construction otherwise.
   if (auto elementsAttr = attr.dyn_cast<ElementsAttr>()) {
-    assert(elementsAttr.getType().hasStaticShape());
-    assert(!elementsAttr.getType().getShape().empty() &&
+    assert(elementsAttr.getShapedType().hasStaticShape());
+    assert(!elementsAttr.getShapedType().getShape().empty() &&
            "unexpected empty elements attribute shape");
 
     SmallVector<llvm::Constant *, 8> constants;
@@ -422,7 +422,7 @@ llvm::Constant *mlir::LLVM::detail::getLLVMConstant(
     }
     ArrayRef<llvm::Constant *> constantsRef = constants;
     llvm::Constant *result = buildSequentialConstant(
-        constantsRef, elementsAttr.getType().getShape(), llvmType, loc);
+        constantsRef, elementsAttr.getShapedType().getShape(), llvmType, loc);
     assert(constantsRef.empty() && "did not consume all elemental constants");
     return result;
   }

diff  --git a/mlir/lib/Transforms/ViewOpGraph.cpp b/mlir/lib/Transforms/ViewOpGraph.cpp
index bdefbe3d2abd5..4598b56a901d5 100644
--- a/mlir/lib/Transforms/ViewOpGraph.cpp
+++ b/mlir/lib/Transforms/ViewOpGraph.cpp
@@ -153,8 +153,8 @@ class PrintOpPass : public impl::ViewOpGraphBase<PrintOpPass> {
     // Elide "big" elements attributes.
     auto elements = attr.dyn_cast<ElementsAttr>();
     if (elements && elements.getNumElements() > largeAttrLimit) {
-      os << std::string(elements.getType().getRank(), '[') << "..."
-         << std::string(elements.getType().getRank(), ']') << " : "
+      os << std::string(elements.getShapedType().getRank(), '[') << "..."
+         << std::string(elements.getShapedType().getRank(), ']') << " : "
          << elements.getType();
       return;
     }

diff  --git a/mlir/test/lib/Dialect/Test/TestAttrDefs.td b/mlir/test/lib/Dialect/Test/TestAttrDefs.td
index ba9909d08ae30..ec0a5548a1603 100644
--- a/mlir/test/lib/Dialect/Test/TestAttrDefs.td
+++ b/mlir/test/lib/Dialect/Test/TestAttrDefs.td
@@ -81,9 +81,7 @@ def AttrWithTrait : Test_Attr<"AttrWithTrait", [TestAttrTrait]> {
 }
 
 // Test support for ElementsAttrInterface.
-def TestI64ElementsAttr : Test_Attr<"TestI64Elements", [
-    ElementsAttrInterface, TypedAttrInterface
-  ]> {
+def TestI64ElementsAttr : Test_Attr<"TestI64Elements", [ElementsAttrInterface]> {
   let mnemonic = "i64_elements";
   let parameters = (ins
     AttributeSelfTypeParameter<"", "::mlir::ShapedType">:$type,
@@ -269,9 +267,7 @@ def TestOverrideBuilderAttr : Test_Attr<"TestOverrideBuilder"> {
 }
 
 // Test simple extern 1D vector using ElementsAttrInterface.
-def TestExtern1DI64ElementsAttr : Test_Attr<"TestExtern1DI64Elements", [
-    ElementsAttrInterface, TypedAttrInterface
-  ]> {
+def TestExtern1DI64ElementsAttr : Test_Attr<"TestExtern1DI64Elements", [ElementsAttrInterface]> {
   let mnemonic = "e1di64_elements";
   let parameters = (ins
     AttributeSelfTypeParameter<"", "::mlir::ShapedType">:$type,


        


More information about the Mlir-commits mailing list