[Mlir-commits] [mlir] f21896f - [DenseElementAttr] Simplify the public API for creating these.

Chris Lattner llvmlistbot at llvm.org
Thu May 12 08:25:09 PDT 2022


Author: Chris Lattner
Date: 2022-05-12T16:18:23+01:00
New Revision: f21896f2c6dc6f4c2c3d0f192f7fefd178f5d5f7

URL: https://github.com/llvm/llvm-project/commit/f21896f2c6dc6f4c2c3d0f192f7fefd178f5d5f7
DIFF: https://github.com/llvm/llvm-project/commit/f21896f2c6dc6f4c2c3d0f192f7fefd178f5d5f7.diff

LOG: [DenseElementAttr] Simplify the public API for creating these.

Instead of requiring the client to compute the "isSplat" bit,
compute it internally. This makes the logic more consistent
and eliminates many "elements.size() == 1" checks in the clients.

This addresses Issue #55185

Differential Revision: https://reviews.llvm.org/D125447

Added: 
    

Modified: 
    mlir/include/mlir/IR/BuiltinAttributes.h
    mlir/include/mlir/IR/BuiltinAttributes.td
    mlir/lib/CAPI/IR/BuiltinAttributes.cpp
    mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
    mlir/lib/IR/BuiltinAttributes.cpp
    mlir/lib/Parser/AttributeParser.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/IR/BuiltinAttributes.h b/mlir/include/mlir/IR/BuiltinAttributes.h
index 4371a1cb088f9..85f6d3f4e638e 100644
--- a/mlir/include/mlir/IR/BuiltinAttributes.h
+++ b/mlir/include/mlir/IR/BuiltinAttributes.h
@@ -193,13 +193,8 @@ class DenseElementsAttr : public Attribute {
   ///   - For bitwidth = 1: Packed into 8bit bytes with bits corresponding to
   ///     the linear order of the shape type from MSB to LSB, padded to on the
   ///     right.
-  ///
-  /// If `isSplatBuffer` is true, then the raw buffer should contain a
-  /// single element (or for the case of 1-bit, a single byte of 0 or 255),
-  /// which will be used to construct a splat.
   static DenseElementsAttr getFromRawBuffer(ShapedType type,
-                                            ArrayRef<char> rawBuffer,
-                                            bool isSplatBuffer);
+                                            ArrayRef<char> rawBuffer);
 
   /// Returns true if the given buffer is a valid raw buffer for the given type.
   /// `detectedSplat` is set if the buffer is valid and represents a splat

diff  --git a/mlir/include/mlir/IR/BuiltinAttributes.td b/mlir/include/mlir/IR/BuiltinAttributes.td
index 1fbc7eb18e5d4..19c8a07b94cd7 100644
--- a/mlir/include/mlir/IR/BuiltinAttributes.td
+++ b/mlir/include/mlir/IR/BuiltinAttributes.td
@@ -236,19 +236,27 @@ def Builtin_DenseIntOrFPElementsAttr : Builtin_Attr<
     /// values. Each APFloat value is expected to have the same bitwidth as the
     /// element type of 'type'. 'type' must be a vector or tensor with static
     /// shape.
+    ///
+    /// If the `values` array only has a single element, then this  constructs
+    /// splat of that value.
     static DenseElementsAttr getRaw(ShapedType type, size_t storageWidth,
-                                    ArrayRef<APFloat> values, bool isSplat);
+                                    ArrayRef<APFloat> values);
 
     /// Constructs a dense elements attribute from an array of raw APInt values.
     /// Each APInt value is expected to have the same bitwidth as the element
     /// type of 'type'. 'type' must be a vector or tensor with static shape.
+    ///
+    /// If the `values` array only has a single element, then this  constructs
+    /// splat of that value.
     static DenseElementsAttr getRaw(ShapedType type, size_t storageWidth,
-                                    ArrayRef<APInt> values, bool isSplat);
+                                    ArrayRef<APInt> values);
 
     /// Get or create a new dense elements attribute instance with the given raw
     /// data buffer. 'type' must be a vector or tensor with static shape.
-    static DenseElementsAttr getRaw(ShapedType type, ArrayRef<char> data,
-                                    bool isSplat);
+    ///
+    /// If the `values` array only has a single element, then this  constructs
+    /// splat of that value.
+    static DenseElementsAttr getRaw(ShapedType type, ArrayRef<char> data);
 
     /// Overload of the raw 'get' method that asserts that the given type is of
     /// complex type. This method is used to verify type invariants that the
@@ -264,7 +272,6 @@ def Builtin_DenseIntOrFPElementsAttr : Builtin_Attr<
                                               ArrayRef<char> data,
                                               int64_t dataEltSize, bool isInt,
                                               bool isSigned);
-
   public:
   }];
   let genAccessors = 0;
@@ -308,7 +315,7 @@ def Builtin_DenseStringElementsAttr : Builtin_Attr<
   let builders = [
     AttrBuilderWithInferredContext<(ins "ShapedType":$type,
                                         "ArrayRef<StringRef>":$values), [{
-      return $_get(type.getContext(), type, values,
+      return $_get(type.getContext(), type, values, 
                    /* isSplat */(values.size() == 1));
     }]>,
   ];

diff  --git a/mlir/lib/CAPI/IR/BuiltinAttributes.cpp b/mlir/lib/CAPI/IR/BuiltinAttributes.cpp
index aa498b2c1e183..759b708952e2f 100644
--- a/mlir/lib/CAPI/IR/BuiltinAttributes.cpp
+++ b/mlir/lib/CAPI/IR/BuiltinAttributes.cpp
@@ -348,11 +348,9 @@ MlirAttribute mlirDenseElementsAttrRawBufferGet(MlirType shapedType,
                               rawBufferSize);
   bool isSplat = false;
   if (!DenseElementsAttr::isValidRawBuffer(shapedTypeCpp, rawBufferCpp,
-                                           isSplat)) {
+                                           isSplat))
     return mlirAttributeGetNull();
-  }
-  return wrap(DenseElementsAttr::getFromRawBuffer(shapedTypeCpp, rawBufferCpp,
-                                                  isSplat));
+  return wrap(DenseElementsAttr::getFromRawBuffer(shapedTypeCpp, rawBufferCpp));
 }
 
 MlirAttribute mlirDenseElementsAttrSplatGet(MlirType shapedType,

diff  --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
index ee5acdc34dc05..932973a13c217 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -837,7 +837,7 @@ struct FoldReshapeWithConstant : OpRewritePattern<TensorReshapeOp> {
     if (!attr || !attr.isSplat())
       return failure();
     DenseElementsAttr newAttr = DenseElementsAttr::getFromRawBuffer(
-        reshapeOp.getResultType(), attr.getRawData(), true);
+        reshapeOp.getResultType(), attr.getRawData());
     rewriter.replaceOpWithNewOp<arith::ConstantOp>(reshapeOp, newAttr);
     return success();
   }

diff  --git a/mlir/lib/IR/BuiltinAttributes.cpp b/mlir/lib/IR/BuiltinAttributes.cpp
index 0004fe90fe87a..1ecdf183d9ec2 100644
--- a/mlir/lib/IR/BuiltinAttributes.cpp
+++ b/mlir/lib/IR/BuiltinAttributes.cpp
@@ -713,8 +713,12 @@ DenseElementsAttr DenseElementsAttr::get(ShapedType type,
            "expected value to have same bitwidth as element type");
     writeBits(data.data(), i * storageBitWidth, intVal);
   }
-  return DenseIntOrFPElementsAttr::getRaw(type, data,
-                                          /*isSplat=*/(values.size() == 1));
+
+  // Handle the special encoding of splat of bool.
+  if (values.size() == 1 && values[0].getType().isInteger(1))
+    data[0] = data[0] ? -1 : 0;
+
+  return DenseIntOrFPElementsAttr::getRaw(type, data);
 }
 
 DenseElementsAttr DenseElementsAttr::get(ShapedType type,
@@ -723,10 +727,22 @@ DenseElementsAttr DenseElementsAttr::get(ShapedType type,
   assert(type.getElementType().isInteger(1));
 
   std::vector<char> buff(llvm::divideCeil(values.size(), CHAR_BIT));
-  for (int i = 0, e = values.size(); i != e; ++i)
-    setBit(buff.data(), i, values[i]);
-  return DenseIntOrFPElementsAttr::getRaw(type, buff,
-                                          /*isSplat=*/(values.size() == 1));
+
+  if (!values.empty()) {
+    bool isSplat = true;
+    bool firstValue = values[0];
+    for (int i = 0, e = values.size(); i != e; ++i) {
+      isSplat &= values[i] == firstValue;
+      setBit(buff.data(), i, values[i]);
+    }
+
+    if (isSplat) { // special encoding for splat.
+      buff.resize(1);
+      buff[0] = values[0] ? -1 : 0;
+    }
+  }
+
+  return DenseIntOrFPElementsAttr::getRaw(type, buff);
 }
 
 DenseElementsAttr DenseElementsAttr::get(ShapedType type,
@@ -743,8 +759,7 @@ DenseElementsAttr DenseElementsAttr::get(ShapedType type,
   assert(type.getElementType().isIntOrIndex());
   assert(hasSameElementsOrSplat(type, values));
   size_t storageBitWidth = getDenseElementStorageWidth(type.getElementType());
-  return DenseIntOrFPElementsAttr::getRaw(type, storageBitWidth, values,
-                                          /*isSplat=*/(values.size() == 1));
+  return DenseIntOrFPElementsAttr::getRaw(type, storageBitWidth, values);
 }
 DenseElementsAttr DenseElementsAttr::get(ShapedType type,
                                          ArrayRef<std::complex<APInt>> values) {
@@ -754,8 +769,7 @@ DenseElementsAttr DenseElementsAttr::get(ShapedType type,
   size_t storageBitWidth = getDenseElementStorageWidth(complex) / 2;
   ArrayRef<APInt> intVals(reinterpret_cast<const APInt *>(values.data()),
                           values.size() * 2);
-  return DenseIntOrFPElementsAttr::getRaw(type, storageBitWidth, intVals,
-                                          /*isSplat=*/(values.size() == 1));
+  return DenseIntOrFPElementsAttr::getRaw(type, storageBitWidth, intVals);
 }
 
 // Constructs a dense float elements attribute from an array of APFloat
@@ -766,8 +780,7 @@ DenseElementsAttr DenseElementsAttr::get(ShapedType type,
   assert(type.getElementType().isa<FloatType>());
   assert(hasSameElementsOrSplat(type, values));
   size_t storageBitWidth = getDenseElementStorageWidth(type.getElementType());
-  return DenseIntOrFPElementsAttr::getRaw(type, storageBitWidth, values,
-                                          /*isSplat=*/(values.size() == 1));
+  return DenseIntOrFPElementsAttr::getRaw(type, storageBitWidth, values);
 }
 DenseElementsAttr
 DenseElementsAttr::get(ShapedType type,
@@ -778,17 +791,15 @@ DenseElementsAttr::get(ShapedType type,
   ArrayRef<APFloat> apVals(reinterpret_cast<const APFloat *>(values.data()),
                            values.size() * 2);
   size_t storageBitWidth = getDenseElementStorageWidth(complex) / 2;
-  return DenseIntOrFPElementsAttr::getRaw(type, storageBitWidth, apVals,
-                                          /*isSplat=*/(values.size() == 1));
+  return DenseIntOrFPElementsAttr::getRaw(type, storageBitWidth, apVals);
 }
 
 /// Construct a dense elements attribute from a raw buffer representing the
 /// data for this attribute. Users should generally not use this methods as
 /// the expected buffer format may not be a form the user expects.
-DenseElementsAttr DenseElementsAttr::getFromRawBuffer(ShapedType type,
-                                                      ArrayRef<char> rawBuffer,
-                                                      bool isSplatBuffer) {
-  return DenseIntOrFPElementsAttr::getRaw(type, rawBuffer, isSplatBuffer);
+DenseElementsAttr
+DenseElementsAttr::getFromRawBuffer(ShapedType type, ArrayRef<char> rawBuffer) {
+  return DenseIntOrFPElementsAttr::getRaw(type, rawBuffer);
 }
 
 /// Returns true if the given buffer is a valid raw buffer for the given type.
@@ -964,7 +975,7 @@ DenseElementsAttr DenseElementsAttr::reshape(ShapedType newType) {
          "expected the same element type");
   assert(newType.getNumElements() == curType.getNumElements() &&
          "expected the same number of elements");
-  return DenseIntOrFPElementsAttr::getRaw(newType, getRawData(), isSplat());
+  return DenseIntOrFPElementsAttr::getRaw(newType, getRawData());
 }
 
 DenseElementsAttr DenseElementsAttr::resizeSplat(ShapedType newType) {
@@ -976,7 +987,7 @@ DenseElementsAttr DenseElementsAttr::resizeSplat(ShapedType newType) {
 
   assert(newType.getElementType() == curType.getElementType() &&
          "expected the same element type");
-  return DenseIntOrFPElementsAttr::getRaw(newType, getRawData(), true);
+  return DenseIntOrFPElementsAttr::getRaw(newType, getRawData());
 }
 
 /// Return a new DenseElementsAttr that has the same data as the current
@@ -993,7 +1004,7 @@ DenseElementsAttr DenseElementsAttr::bitcast(Type newElType) {
              getDenseElementBitWidth(curElType) &&
          "expected element types with the same bitwidth");
   return DenseIntOrFPElementsAttr::getRaw(curType.clone(newElType),
-                                          getRawData(), isSplat());
+                                          getRawData());
 }
 
 DenseElementsAttr
@@ -1027,13 +1038,18 @@ int64_t DenseElementsAttr::getNumElements() const {
 template <typename APRangeT>
 static void writeAPIntsToBuffer(size_t storageWidth, std::vector<char> &data,
                                 APRangeT &&values) {
-  data.resize(llvm::divideCeil(storageWidth * llvm::size(values), CHAR_BIT));
+  size_t numValues = llvm::size(values);
+  data.resize(llvm::divideCeil(storageWidth * numValues, CHAR_BIT));
   size_t offset = 0;
   for (auto it = values.begin(), e = values.end(); it != e;
        ++it, offset += storageWidth) {
     assert((*it).getBitWidth() <= storageWidth);
     writeBits(data.data(), offset, *it);
   }
+
+  // Handle the special encoding of splat of a boolean.
+  if (numValues == 1 && (*values.begin()).getBitWidth() == 1)
+    data[0] = data[0] ? -1 : 0;
 }
 
 /// Constructs a dense elements attribute from an array of raw APFloat values.
@@ -1041,12 +1057,11 @@ static void writeAPIntsToBuffer(size_t storageWidth, std::vector<char> &data,
 /// type of 'type'. 'type' must be a vector or tensor with static shape.
 DenseElementsAttr DenseIntOrFPElementsAttr::getRaw(ShapedType type,
                                                    size_t storageWidth,
-                                                   ArrayRef<APFloat> values,
-                                                   bool isSplat) {
+                                                   ArrayRef<APFloat> values) {
   std::vector<char> data;
   auto unwrapFloat = [](const APFloat &val) { return val.bitcastToAPInt(); };
   writeAPIntsToBuffer(storageWidth, data, llvm::map_range(values, unwrapFloat));
-  return DenseIntOrFPElementsAttr::getRaw(type, data, isSplat);
+  return DenseIntOrFPElementsAttr::getRaw(type, data);
 }
 
 /// Constructs a dense elements attribute from an array of raw APInt values.
@@ -1054,19 +1069,21 @@ DenseElementsAttr DenseIntOrFPElementsAttr::getRaw(ShapedType type,
 /// of 'type'.
 DenseElementsAttr DenseIntOrFPElementsAttr::getRaw(ShapedType type,
                                                    size_t storageWidth,
-                                                   ArrayRef<APInt> values,
-                                                   bool isSplat) {
+                                                   ArrayRef<APInt> values) {
   std::vector<char> data;
   writeAPIntsToBuffer(storageWidth, data, values);
-  return DenseIntOrFPElementsAttr::getRaw(type, data, isSplat);
+  return DenseIntOrFPElementsAttr::getRaw(type, data);
 }
 
 DenseElementsAttr DenseIntOrFPElementsAttr::getRaw(ShapedType type,
-                                                   ArrayRef<char> data,
-                                                   bool isSplat) {
+                                                   ArrayRef<char> data) {
   assert((type.isa<RankedTensorType, VectorType>()) &&
          "type must be ranked tensor or vector");
   assert(type.hasStaticShape() && "type must have static shape");
+  bool isSplat = false;
+  bool isValid = isValidRawBuffer(type, data, isSplat);
+  assert(isValid);
+  (void)isValid;
   return Base::get(type.getContext(), type, data, isSplat);
 }
 
@@ -1084,7 +1101,7 @@ DenseElementsAttr DenseIntOrFPElementsAttr::getRawComplex(ShapedType type,
 
   int64_t numElements = data.size() / dataEltSize;
   assert(numElements == 1 || numElements == type.getNumElements());
-  return getRaw(type, data, /*isSplat=*/numElements == 1);
+  return getRaw(type, data);
 }
 
 /// Overload of the 'getRaw' method that asserts that the given type is of
@@ -1099,7 +1116,8 @@ DenseIntOrFPElementsAttr::getRawIntOrFloat(ShapedType type, ArrayRef<char> data,
 
   int64_t numElements = data.size() / dataEltSize;
   assert(numElements == 1 || numElements == type.getNumElements());
-  return getRaw(type, data, /*isSplat=*/numElements == 1);
+  (void)numElements;
+  return getRaw(type, data);
 }
 
 void DenseIntOrFPElementsAttr::convertEndianOfCharForBEmachine(
@@ -1212,7 +1230,7 @@ DenseElementsAttr DenseFPElementsAttr::mapValues(
   auto newArrayType =
       mappingHelper(mapping, *this, getType(), newElementType, elementData);
 
-  return getRaw(newArrayType, elementData, isSplat());
+  return getRaw(newArrayType, elementData);
 }
 
 /// Method for supporting type inquiry through isa, cast and dyn_cast.
@@ -1230,8 +1248,7 @@ DenseElementsAttr DenseIntElementsAttr::mapValues(
   llvm::SmallVector<char, 8> elementData;
   auto newArrayType =
       mappingHelper(mapping, *this, getType(), newElementType, elementData);
-
-  return getRaw(newArrayType, elementData, isSplat());
+  return getRaw(newArrayType, elementData);
 }
 
 /// Method for supporting type inquiry through isa, cast and dyn_cast.

diff  --git a/mlir/lib/Parser/AttributeParser.cpp b/mlir/lib/Parser/AttributeParser.cpp
index 0161a3c151209..3618962fa1d5f 100644
--- a/mlir/lib/Parser/AttributeParser.cpp
+++ b/mlir/lib/Parser/AttributeParser.cpp
@@ -717,11 +717,10 @@ DenseElementsAttr TensorLiteralParser::getHexAttr(SMLoc loc,
     MutableArrayRef<char> convRawData(outDataVec);
     DenseIntOrFPElementsAttr::convertEndianOfArrayRefForBEmachine(
         rawData, convRawData, type);
-    return DenseElementsAttr::getFromRawBuffer(type, convRawData,
-                                               detectedSplat);
+    return DenseElementsAttr::getFromRawBuffer(type, convRawData);
   }
 
-  return DenseElementsAttr::getFromRawBuffer(type, rawData, detectedSplat);
+  return DenseElementsAttr::getFromRawBuffer(type, rawData);
 }
 
 ParseResult TensorLiteralParser::parseElement() {


        


More information about the Mlir-commits mailing list