[Mlir-commits] [mlir] a5c46bf - [mlir] Allow DenseElementsAttr to use any shaped type

Jeff Niu llvmlistbot at llvm.org
Fri Sep 30 23:25:23 PDT 2022


Author: Jeff Niu
Date: 2022-09-30T23:25:14-07:00
New Revision: a5c46bf9521e34d7f8c6fa048014912afb910020

URL: https://github.com/llvm/llvm-project/commit/a5c46bf9521e34d7f8c6fa048014912afb910020
DIFF: https://github.com/llvm/llvm-project/commit/a5c46bf9521e34d7f8c6fa048014912afb910020.diff

LOG: [mlir] Allow DenseElementsAttr to use any shaped type

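This patch allows the type of a DenseElementsAttr to be any ShapedType, rather than only a ranked tensor or vector type. The parser diagnostic for a non-shaped type now reads "elements literal must be a shaped type".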

Reviewed By: rriddle

Differential Revision: https://reviews.llvm.org/D135002

Added: 
    

Modified: 
    mlir/lib/AsmParser/AttributeParser.cpp
    mlir/lib/IR/BuiltinAttributes.cpp
    mlir/test/IR/invalid-builtin-attributes.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/AsmParser/AttributeParser.cpp b/mlir/lib/AsmParser/AttributeParser.cpp
index 819c86c997f3c..1e79b44322616 100644
--- a/mlir/lib/AsmParser/AttributeParser.cpp
+++ b/mlir/lib/AsmParser/AttributeParser.cpp
@@ -1066,12 +1066,12 @@ ShapedType Parser::parseElementsLiteralType(Type type) {
       return nullptr;
   }
 
-  if (!type.isa<RankedTensorType, VectorType>()) {
-    emitError("elements literal must be a ranked tensor or vector type");
+  auto sType = type.dyn_cast<ShapedType>();
+  if (!sType) {
+    emitError("elements literal must be a shaped type");
     return nullptr;
   }
 
-  auto sType = type.cast<ShapedType>();
   if (!sType.hasStaticShape())
     return (emitError("elements literal type must have static shape"), nullptr);
 

diff  --git a/mlir/lib/IR/BuiltinAttributes.cpp b/mlir/lib/IR/BuiltinAttributes.cpp
index 70c2b47f41721..ed22134d1dcc8 100644
--- a/mlir/lib/IR/BuiltinAttributes.cpp
+++ b/mlir/lib/IR/BuiltinAttributes.cpp
@@ -1381,8 +1381,6 @@ DenseElementsAttr DenseIntOrFPElementsAttr::getRaw(ShapedType type,
 
 DenseElementsAttr DenseIntOrFPElementsAttr::getRaw(ShapedType type,
                                                    ArrayRef<char> data) {
-  assert((type.isa<RankedTensorType, VectorType>()) &&
-         "type must be ranked tensor or vector");
   assert(type.hasStaticShape() && "type must have static shape");
   bool isSplat = false;
   bool isValid = isValidRawBuffer(type, data, isSplat);
@@ -1498,16 +1496,7 @@ static ShapedType mappingHelper(Fn mapping, Attr &attr, ShapedType inType,
   size_t bitWidth = getDenseElementBitWidth(newElementType);
   size_t storageBitWidth = getDenseElementStorageWidth(bitWidth);
 
-  ShapedType newArrayType;
-  if (inType.isa<RankedTensorType>())
-    newArrayType = RankedTensorType::get(inType.getShape(), newElementType);
-  else if (inType.isa<UnrankedTensorType>())
-    newArrayType = RankedTensorType::get(inType.getShape(), newElementType);
-  else if (auto vType = inType.dyn_cast<VectorType>())
-    newArrayType = VectorType::get(vType.getShape(), newElementType,
-                                   vType.getNumScalableDims());
-  else
-    assert(newArrayType && "Unhandled tensor type");
+  ShapedType newArrayType = inType.cloneWith(inType.getShape(), newElementType);
 
   size_t numRawElements = attr.isSplat() ? 1 : newArrayType.getNumElements();
   data.resize(llvm::divideCeil(storageBitWidth * numRawElements, CHAR_BIT));

diff  --git a/mlir/test/IR/invalid-builtin-attributes.mlir b/mlir/test/IR/invalid-builtin-attributes.mlir
index 49acce2cf1187..8e57afa41ba88 100644
--- a/mlir/test/IR/invalid-builtin-attributes.mlir
+++ b/mlir/test/IR/invalid-builtin-attributes.mlir
@@ -1,7 +1,7 @@
 // RUN: mlir-opt -allow-unregistered-dialect %s -split-input-file -verify-diagnostics
 
 func.func @elementsattr_non_tensor_type() -> () {
-  "foo"(){bar = dense<[4]> : i32} : () -> () // expected-error {{elements literal must be a ranked tensor or vector type}}
+  "foo"(){bar = dense<[4]> : i32} : () -> () // expected-error {{elements literal must be a shaped type}}
 }
 
 // -----


        


More information about the Mlir-commits mailing list