[Mlir-commits] [mlir] f7cd55f - [mlir][llvm] Improve LLVM IR constant import.

Tobias Gysi llvmlistbot at llvm.org
Thu May 11 05:52:52 PDT 2023


Author: Tobias Gysi
Date: 2023-05-11T12:45:45Z
New Revision: f7cd55f56e9413f37474c8c12fc0a6231bf5fca5

URL: https://github.com/llvm/llvm-project/commit/f7cd55f56e9413f37474c8c12fc0a6231bf5fca5
DIFF: https://github.com/llvm/llvm-project/commit/f7cd55f56e9413f37474c8c12fc0a6231bf5fca5.diff

LOG: [mlir][llvm] Improve LLVM IR constant import.

Improve the constant import to handle zeroinitializer as well as
additional float types such as quad floats. The logic was restructured
to avoid creating intermediate dense elements attributes when
constructing multi-dimensional arrays. Additionally, it leverages the
fact that splat constants do not require iterating over all elements.

Reviewed By: Dinistro

Differential Revision: https://reviews.llvm.org/D150274

Added: 
    

Modified: 
    mlir/include/mlir/Target/LLVMIR/ModuleImport.h
    mlir/lib/Target/LLVMIR/ModuleImport.cpp
    mlir/test/Target/LLVMIR/Import/global-variables.ll

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Target/LLVMIR/ModuleImport.h b/mlir/include/mlir/Target/LLVMIR/ModuleImport.h
index 61413c5ced201..47af0756a21e5 100644
--- a/mlir/include/mlir/Target/LLVMIR/ModuleImport.h
+++ b/mlir/include/mlir/Target/LLVMIR/ModuleImport.h
@@ -255,11 +255,14 @@ class ModuleImport {
   /// DictionaryAttr for the LLVM dialect.
   DictionaryAttr convertParameterAttribute(llvm::AttributeSet llvmParamAttrs,
                                            OpBuilder &builder);
-  /// Returns the builtin type equivalent to be used in attributes for the given
-  /// LLVM IR dialect type.
-  Type getStdTypeForAttr(Type type);
-  /// Returns `value` as an attribute to attach to a GlobalOp.
-  Attribute getConstantAsAttr(llvm::Constant *value);
+  /// Returns the builtin type equivalent to the given LLVM dialect type or
+  /// nullptr if there is no equivalent. The returned type can be used to create
+  /// an attribute for a GlobalOp or a ConstantOp.
+  Type getBuiltinTypeForAttr(Type type);
+  /// Returns `constant` as an attribute to attach to a GlobalOp or ConstantOp
+  /// or nullptr if the constant is not convertible. It supports scalar integer
+  /// and float constants as well as shaped types thereof including strings.
+  Attribute getConstantAsAttr(llvm::Constant *constant);
   /// Returns the topologically sorted set of transitive dependencies needed to
   /// convert the given constant.
   SetVector<llvm::Constant *> getConstantsToConvert(llvm::Constant *constant);

diff  --git a/mlir/lib/Target/LLVMIR/ModuleImport.cpp b/mlir/lib/Target/LLVMIR/ModuleImport.cpp
index 13ca331f632a4..ba378930863b0 100644
--- a/mlir/lib/Target/LLVMIR/ModuleImport.cpp
+++ b/mlir/lib/Target/LLVMIR/ModuleImport.cpp
@@ -621,144 +621,193 @@ void ModuleImport::setFastmathFlagsAttr(llvm::Instruction *inst,
   iface->setAttr(iface.getFastmathAttrName(), attr);
 }
 
-// We only need integers, floats, doubles, and vectors and tensors thereof for
-// attributes. Scalar and vector types are converted to the standard
-// equivalents. Array types are converted to ranked tensors; nested array types
-// are converted to multi-dimensional tensors or vectors, depending on the
-// innermost type being a scalar or a vector.
-Type ModuleImport::getStdTypeForAttr(Type type) {
-  if (!type)
-    return nullptr;
+/// Returns true if `type` is a scalar integer or floating-point type.
+static bool isScalarType(Type type) {
+  return type.isa<IntegerType, FloatType>();
+}
 
-  if (type.isa<IntegerType, FloatType>())
-    return type;
+/// Returns `type` if it is a builtin integer or floating-point vector type that
+/// can be used to create an attribute or nullptr otherwise. If provided,
+/// `arrayShape` is added to the shape of the vector to create an attribute that
+/// matches an array of vectors.
+static Type getVectorTypeForAttr(Type type, ArrayRef<int64_t> arrayShape = {}) {
+  if (!LLVM::isCompatibleVectorType(type))
+    return {};
 
-  // LLVM vectors can only contain scalars.
-  if (LLVM::isCompatibleVectorType(type)) {
-    llvm::ElementCount numElements = LLVM::getVectorNumElements(type);
-    if (numElements.isScalable()) {
-      emitError(UnknownLoc::get(context)) << "scalable vectors not supported";
-      return nullptr;
-    }
-    Type elementType = getStdTypeForAttr(LLVM::getVectorElementType(type));
-    if (!elementType)
-      return nullptr;
-    return VectorType::get(numElements.getKnownMinValue(), elementType);
+  llvm::ElementCount numElements = LLVM::getVectorNumElements(type);
+  if (numElements.isScalable()) {
+    emitError(UnknownLoc::get(type.getContext()))
+        << "scalable vectors not supported";
+    return {};
   }
 
-  // LLVM arrays can contain other arrays or vectors.
-  if (auto arrayType = type.dyn_cast<LLVMArrayType>()) {
-    // Recover the nested array shape.
-    SmallVector<int64_t, 4> shape;
-    shape.push_back(arrayType.getNumElements());
-    while (arrayType.getElementType().isa<LLVMArrayType>()) {
-      arrayType = arrayType.getElementType().cast<LLVMArrayType>();
-      shape.push_back(arrayType.getNumElements());
-    }
+  // An LLVM dialect vector can only contain scalars.
+  Type elementType = LLVM::getVectorElementType(type);
+  if (!isScalarType(elementType))
+    return {};
 
-    // If the innermost type is a vector, use the multi-dimensional vector as
-    // attribute type.
-    if (LLVM::isCompatibleVectorType(arrayType.getElementType())) {
-      llvm::ElementCount numElements =
-          LLVM::getVectorNumElements(arrayType.getElementType());
-      if (numElements.isScalable()) {
-        emitError(UnknownLoc::get(context)) << "scalable vectors not supported";
-        return nullptr;
-      }
-      shape.push_back(numElements.getKnownMinValue());
+  SmallVector<int64_t> shape(arrayShape.begin(), arrayShape.end());
+  shape.push_back(numElements.getKnownMinValue());
+  return VectorType::get(shape, elementType);
+}
 
-      Type elementType = getStdTypeForAttr(
-          LLVM::getVectorElementType(arrayType.getElementType()));
-      if (!elementType)
-        return nullptr;
-      return VectorType::get(shape, elementType);
-    }
+Type ModuleImport::getBuiltinTypeForAttr(Type type) {
+  if (!type)
+    return {};
 
-    // Otherwise use a tensor.
-    Type elementType = getStdTypeForAttr(arrayType.getElementType());
-    if (!elementType)
-      return nullptr;
-    return RankedTensorType::get(shape, elementType);
-  }
+  // Return builtin integer and floating-point types as is.
+  if (isScalarType(type))
+    return type;
+
+  // Return builtin vectors of integer and floating-point types as is.
+  if (Type vectorType = getVectorTypeForAttr(type))
+    return vectorType;
 
-  return nullptr;
+  // Multi-dimensional array types are converted to tensors or vectors,
+  // depending on the innermost type being a scalar or a vector.
+  SmallVector<int64_t> arrayShape;
+  while (auto arrayType = dyn_cast<LLVMArrayType>(type)) {
+    arrayShape.push_back(arrayType.getNumElements());
+    type = arrayType.getElementType();
+  }
+  if (isScalarType(type))
+    return RankedTensorType::get(arrayShape, type);
+  return getVectorTypeForAttr(type, arrayShape);
 }
 
-// Get the given constant as an attribute. Not all constants can be represented
-// as attributes.
-Attribute ModuleImport::getConstantAsAttr(llvm::Constant *value) {
-  if (auto *ci = dyn_cast<llvm::ConstantInt>(value))
+/// Returns an integer or float attribute for the provided scalar constant
+/// `constScalar` or nullptr if the conversion fails.
+static Attribute getScalarConstantAsAttr(OpBuilder &builder,
+                                         llvm::Constant *constScalar) {
+  MLIRContext *context = builder.getContext();
+
+  // Convert scalar integers.
+  if (auto *constInt = dyn_cast<llvm::ConstantInt>(constScalar)) {
     return builder.getIntegerAttr(
-        IntegerType::get(context, ci->getType()->getBitWidth()),
-        ci->getValue());
-  if (auto *c = dyn_cast<llvm::ConstantDataArray>(value))
-    if (c->isString())
-      return builder.getStringAttr(c->getAsString());
-  if (auto *c = dyn_cast<llvm::ConstantFP>(value)) {
-    llvm::Type *type = c->getType();
-    FloatType floatTy;
-    if (type->isBFloatTy())
-      floatTy = FloatType::getBF16(context);
-    else
-      floatTy = detail::getFloatType(context, type->getScalarSizeInBits());
-    assert(floatTy && "unsupported floating point type");
-    return builder.getFloatAttr(floatTy, c->getValueAPF());
+        IntegerType::get(context, constInt->getType()->getBitWidth()),
+        constInt->getValue());
   }
-  if (auto *f = dyn_cast<llvm::Function>(value))
-    return SymbolRefAttr::get(builder.getContext(), f->getName());
-
-  // Convert constant data to a dense elements attribute.
-  if (auto *cd = dyn_cast<llvm::ConstantDataSequential>(value)) {
-    Type type = convertType(cd->getElementType());
-    auto attrType = getStdTypeForAttr(convertType(cd->getType()))
-                        .dyn_cast_or_null<ShapedType>();
-    if (!attrType)
-      return nullptr;
-
-    if (type.isa<IntegerType>()) {
-      SmallVector<APInt, 8> values;
-      values.reserve(cd->getNumElements());
-      for (unsigned i = 0, e = cd->getNumElements(); i < e; ++i)
-        values.push_back(cd->getElementAsAPInt(i));
-      return DenseElementsAttr::get(attrType, values);
-    }
 
-    if (type.isa<Float32Type, Float64Type>()) {
-      SmallVector<APFloat, 8> values;
-      values.reserve(cd->getNumElements());
-      for (unsigned i = 0, e = cd->getNumElements(); i < e; ++i)
-        values.push_back(cd->getElementAsAPFloat(i));
-      return DenseElementsAttr::get(attrType, values);
+  // Convert scalar floats.
+  if (auto *constFloat = dyn_cast<llvm::ConstantFP>(constScalar)) {
+    llvm::Type *type = constFloat->getType();
+    FloatType floatType =
+        type->isBFloatTy()
+            ? FloatType::getBF16(context)
+            : LLVM::detail::getFloatType(context, type->getScalarSizeInBits());
+    if (!floatType) {
+      emitError(UnknownLoc::get(builder.getContext()))
+          << "unexpected floating-point type";
+      return {};
     }
+    return builder.getFloatAttr(floatType, constFloat->getValueAPF());
+  }
+  return {};
+}
 
-    return nullptr;
+/// Returns an integer or float attribute array for the provided constant
+/// sequence `constSequence`; elements that fail to convert are null.
+static SmallVector<Attribute>
+getSequenceConstantAsAttrs(OpBuilder &builder,
+                           llvm::ConstantDataSequential *constSequence) {
+  SmallVector<Attribute> elementAttrs;
+  elementAttrs.reserve(constSequence->getNumElements());
+  for (auto idx : llvm::seq<int64_t>(0, constSequence->getNumElements())) {
+    llvm::Constant *constElement = constSequence->getElementAsConstant(idx);
+    elementAttrs.push_back(getScalarConstantAsAttr(builder, constElement));
   }
+  return elementAttrs;
+}
+
+Attribute ModuleImport::getConstantAsAttr(llvm::Constant *constant) {
+  // Convert scalar constants.
+  if (Attribute scalarAttr = getScalarConstantAsAttr(builder, constant))
+    return scalarAttr;
+
+  // Convert function references.
+  if (auto *func = dyn_cast<llvm::Function>(constant))
+    return SymbolRefAttr::get(builder.getContext(), func->getName());
+
+  // Returns the static shape of the provided type if possible.
+  auto getConstantShape = [&](llvm::Type *type) {
+    return getBuiltinTypeForAttr(convertType(type))
+        .dyn_cast_or_null<ShapedType>();
+  };
 
-  // Unpack constant aggregates to create dense elements attribute whenever
-  // possible. Return nullptr (failure) otherwise.
-  if (isa<llvm::ConstantAggregate>(value)) {
-    auto outerType = getStdTypeForAttr(convertType(value->getType()))
-                         .dyn_cast_or_null<ShapedType>();
-    if (!outerType)
-      return nullptr;
-
-    SmallVector<Attribute, 8> values;
-    SmallVector<int64_t, 8> shape;
-
-    for (unsigned i = 0, e = value->getNumOperands(); i < e; ++i) {
-      auto nested = getConstantAsAttr(value->getAggregateElement(i))
-                        .dyn_cast_or_null<DenseElementsAttr>();
-      if (!nested)
-        return nullptr;
-
-      values.append(nested.value_begin<Attribute>(),
-                    nested.value_end<Attribute>());
+  // Convert one-dimensional constant arrays or vectors that store 1/2/4/8-byte
+  // integer or half/bfloat/float/double values.
+  if (auto *constArray = dyn_cast<llvm::ConstantDataSequential>(constant)) {
+    if (constArray->isString())
+      return builder.getStringAttr(constArray->getAsString());
+    auto shape = getConstantShape(constArray->getType());
+    if (!shape)
+      return {};
+    // Convert splat constants to splat elements attributes.
+    auto *constVector = dyn_cast<llvm::ConstantDataVector>(constant);
+    if (constVector && constVector->isSplat()) {
+      // A vector is guaranteed to have at least size one.
+      Attribute splatAttr = getScalarConstantAsAttr(
+          builder, constVector->getElementAsConstant(0));
+      return SplatElementsAttr::get(shape, splatAttr);
     }
+    // Convert non-splat constants to dense elements attributes.
+    SmallVector<Attribute> elementAttrs =
+        getSequenceConstantAsAttrs(builder, constArray);
+    return DenseElementsAttr::get(shape, elementAttrs);
+  }
 
-    return DenseElementsAttr::get(outerType, values);
+  // Convert multi-dimensional constant aggregates that store all kinds of
+  // integer and floating-point types.
+  if (auto *constAggregate = dyn_cast<llvm::ConstantAggregate>(constant)) {
+    auto shape = getConstantShape(constAggregate->getType());
+    if (!shape)
+      return {};
+    // Collect the aggregate elements in depth-first order.
+    SmallVector<Attribute> elementAttrs;
+    SmallVector<llvm::Constant *> workList = {constAggregate};
+    while (!workList.empty()) {
+      llvm::Constant *current = workList.pop_back_val();
+      // Append any nested aggregates in reverse order to ensure the head
+      // element of the nested aggregates is at the back of the work list.
+      if (auto *constAggregate = dyn_cast<llvm::ConstantAggregate>(current)) {
+        for (auto idx :
+             reverse(llvm::seq<int64_t>(0, constAggregate->getNumOperands())))
+          workList.push_back(constAggregate->getAggregateElement(idx));
+        continue;
+      }
+      // Append the elements of nested constant arrays or vectors that store
+      // 1/2/4/8-byte integer or half/bfloat/float/double values.
+      if (auto *constArray = dyn_cast<llvm::ConstantDataSequential>(current)) {
+        SmallVector<Attribute> attrs =
+            getSequenceConstantAsAttrs(builder, constArray);
+        elementAttrs.append(attrs.begin(), attrs.end());
+        continue;
+      }
+      // Append nested scalar constants that store all kinds of integer and
+      // floating-point types.
+      if (Attribute scalarAttr = getScalarConstantAsAttr(builder, current)) {
+        elementAttrs.push_back(scalarAttr);
+        continue;
+      }
+      // Bail if the aggregate contains an unsupported constant type such as a
+      // constant expression.
+      return {};
+    }
+    return DenseElementsAttr::get(shape, elementAttrs);
   }
 
-  return nullptr;
+  // Convert zero aggregates.
+  if (auto *constZero = dyn_cast<llvm::ConstantAggregateZero>(constant)) {
+    auto shape = getBuiltinTypeForAttr(convertType(constZero->getType()))
+                     .dyn_cast_or_null<ShapedType>();
+    if (!shape)
+      return {};
+    // Convert zero aggregates with a static shape to splat elements attributes.
+    Attribute splatAttr = builder.getZeroAttr(shape.getElementType());
+    assert(splatAttr && "expected non-null zero attribute for scalar types");
+    return SplatElementsAttr::get(shape, splatAttr);
+  }
+  return {};
 }
 
 LogicalResult ModuleImport::convertGlobal(llvm::GlobalVariable *globalVar) {

diff  --git a/mlir/test/Target/LLVMIR/Import/global-variables.ll b/mlir/test/Target/LLVMIR/Import/global-variables.ll
index 81aa4a0f152fd..daf20ea606d87 100644
--- a/mlir/test/Target/LLVMIR/Import/global-variables.ll
+++ b/mlir/test/Target/LLVMIR/Import/global-variables.ll
@@ -168,20 +168,58 @@
 @array_constant = internal constant [2 x float] [float 1., float 2.]
 
 ; CHECK: llvm.mlir.global internal constant @nested_array_constant
-; CHECK-SAME:  (dense<[{{\[}}1, 2], [3, 4]]> : tensor<2x2xi32>)
-; CHECK-SAME:  {addr_space = 0 : i32, dso_local} : !llvm.array<2 x array<2 x i32>>
+; CHECK-SAME-LITERAL:  (dense<[[1, 2], [3, 4]]> : tensor<2x2xi32>)
+; CHECK-SAME-LITERAL:  {addr_space = 0 : i32, dso_local} : !llvm.array<2 x array<2 x i32>>
 @nested_array_constant = internal constant [2 x [2 x i32]] [[2 x i32] [i32 1, i32 2], [2 x i32] [i32 3, i32 4]]
 
 ; CHECK: llvm.mlir.global internal constant @nested_array_constant3
-; CHECK-SAME:  (dense<[{{\[}}[1, 2], [3, 4]]]> : tensor<1x2x2xi32>)
-; CHECK-SAME:  {addr_space = 0 : i32, dso_local} : !llvm.array<1 x array<2 x array<2 x i32>>>
+; CHECK-SAME-LITERAL:  (dense<[[[1, 2], [3, 4]]]> : tensor<1x2x2xi32>)
+; CHECK-SAME-LITERAL:  {addr_space = 0 : i32, dso_local} : !llvm.array<1 x array<2 x array<2 x i32>>>
 @nested_array_constant3 = internal constant [1 x [2 x [2 x i32]]] [[2 x [2 x i32]] [[2 x i32] [i32 1, i32 2], [2 x i32] [i32 3, i32 4]]]
 
 ; CHECK: llvm.mlir.global internal constant @nested_array_vector
-; CHECK-SAME:  (dense<[{{\[}}[1, 2], [3, 4]]]> : vector<1x2x2xi32>)
-; CHECK-SAME:   {addr_space = 0 : i32, dso_local} : !llvm.array<1 x array<2 x vector<2xi32>>>
+; CHECK-SAME-LITERAL:  (dense<[[[1, 2], [3, 4]]]> : vector<1x2x2xi32>)
+; CHECK-SAME-LITERAL:  {addr_space = 0 : i32, dso_local} : !llvm.array<1 x array<2 x vector<2xi32>>>
 @nested_array_vector = internal constant [1 x [2 x <2 x i32>]] [[2 x <2 x i32>] [<2 x i32> <i32 1, i32 2>, <2 x i32> <i32 3, i32 4>]]
 
+; CHECK:  llvm.mlir.global internal constant @vector_constant_zero
+; CHECK-SAME:  (dense<0> : vector<2xi24>)
+; CHECK-SAME:  {addr_space = 0 : i32, dso_local} : vector<2xi24>
+@vector_constant_zero = internal constant <2 x i24> zeroinitializer
+
+; CHECK:  llvm.mlir.global internal constant @array_constant_zero
+; CHECK-SAME:  (dense<0.000000e+00> : tensor<2xbf16>)
+; CHECK-SAME:  {addr_space = 0 : i32, dso_local} : !llvm.array<2 x bf16>
+@array_constant_zero = internal constant [2 x bfloat] zeroinitializer
+
+; CHECK: llvm.mlir.global internal constant @nested_array_constant3_zero
+; CHECK-SAME:  (dense<0> : tensor<1x2x2xi32>)
+; CHECK-SAME:  {addr_space = 0 : i32, dso_local} : !llvm.array<1 x array<2 x array<2 x i32>>>
+@nested_array_constant3_zero = internal constant [1 x [2 x [2 x i32]]] zeroinitializer
+
+; CHECK: llvm.mlir.global internal constant @nested_array_vector_zero
+; CHECK-SAME:  (dense<0> : vector<1x2x2xi32>)
+; CHECK-SAME:  {addr_space = 0 : i32, dso_local} : !llvm.array<1 x array<2 x vector<2xi32>>>
+@nested_array_vector_zero = internal constant [1 x [2 x <2 x i32>]] zeroinitializer
+
+; CHECK: llvm.mlir.global internal constant @nested_bool_array_constant
+; CHECK-SAME-LITERAL:  (dense<[[true, false]]> : tensor<1x2xi1>)
+; CHECK-SAME-LITERAL:  {addr_space = 0 : i32, dso_local} : !llvm.array<1 x array<2 x i1>>
+@nested_bool_array_constant = internal constant [1 x [2 x i1]] [[2 x i1] [i1 1, i1 0]]
+
+; CHECK: llvm.mlir.global internal constant @quad_float_constant
+; CHECK-SAME:  dense<[
+; CHECK-SAME:    529.340000000000031832314562052488327
+; CHECK-SAME:    529.340000000001850821718107908964157
+; CHECK-SAME:  ]> : vector<2xf128>)
+; CHECK-SAME:  {addr_space = 0 : i32, dso_local} : vector<2xf128>
+@quad_float_constant = internal constant <2 x fp128> <fp128 0xLF000000000000000400808AB851EB851, fp128 0xLF000000000000000400808AB851EB852>
+
+; CHECK: llvm.mlir.global internal constant @quad_float_splat_constant
+; CHECK-SAME:  dense<529.340000000000031832314562052488327> : vector<2xf128>)
+; CHECK-SAME:  {addr_space = 0 : i32, dso_local} : vector<2xf128>
+@quad_float_splat_constant = internal constant <2 x fp128> <fp128 0xLF000000000000000400808AB851EB851, fp128 0xLF000000000000000400808AB851EB851>
+
 ; // -----
 
 ; CHECK: llvm.mlir.global_ctors {ctors = [@foo, @bar], priorities = [0 : i32, 42 : i32]}


        


More information about the Mlir-commits mailing list