[Mlir-commits] [mlir] f9be7a7 - [mlir] speed up construction of LLVM IR constants when possible
Alex Zinenko
llvmlistbot at llvm.org
Thu Sep 2 14:07:38 PDT 2021
Author: Alex Zinenko
Date: 2021-09-02T23:07:30+02:00
New Revision: f9be7a7afda3c90b99c9f50e5eff1624da5a6511
URL: https://github.com/llvm/llvm-project/commit/f9be7a7afda3c90b99c9f50e5eff1624da5a6511
DIFF: https://github.com/llvm/llvm-project/commit/f9be7a7afda3c90b99c9f50e5eff1624da5a6511.diff
LOG: [mlir] speed up construction of LLVM IR constants when possible
The translation to LLVM IR used to construct sequential constants by recursing
down to individual elements, creating constant values for them, and wrapping
them into aggregate constants in post-order. This is highly inefficient for
large constants with known data such as DenseElementsAttr. Use LLVM's
ConstantData for the innermost dimension instead. LLVM does not seem to support
data constants for nested sequential constants, so the outer dimensions are
still handled recursively. Nevertheless, this speeds up the translation of
large constants with equal dimensions by up to 30x.
Users are advised to rewrite large constants to use flat types before
translating to LLVM IR if more efficiency in translation is necessary. This is
not done automatically as the translation is not aware of the expectations of
the overall compilation flow about type changes and indexing, in particular for
global constants with external linkage.
Reviewed By: silvas
Differential Revision: https://reviews.llvm.org/D109152
Added:
Modified:
mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
mlir/test/Target/LLVMIR/llvmir.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
index 64c7970cc71d9..cadc2f0e32fbc 100644
--- a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
@@ -101,6 +101,92 @@ static llvm::Type *getInnermostElementType(llvm::Type *type) {
} while (true);
}
+/// Convert a dense elements attribute to an LLVM IR constant using its raw data
+/// storage if possible. This supports elements attributes of tensor or vector
+/// type and avoids constructing separate objects for individual values of the
+/// innermost dimension. Constants for other dimensions are still constructed
+/// recursively. Returns null, letting the caller fall back to per-element
+/// construction, if raw data cannot be used: e.g., the element type is not a
+/// power-of-two-sized primitive, the attribute is empty or 0-D, or its storage
+/// width differs from the LLVM element width. Reports other errors at `loc`.
+static llvm::Constant *
+convertDenseElementsAttr(Location loc, DenseElementsAttr denseElementsAttr,
+                         llvm::Type *llvmType,
+                         const ModuleTranslation &moduleTranslation) {
+  // Empty and 0-D attributes have no usable innermost dimension, which is used
+  // as a divisor below; leave them to the generic path.
+  if (!denseElementsAttr || denseElementsAttr.getNumElements() == 0 ||
+      denseElementsAttr.getType().getRank() == 0)
+    return nullptr;
+
+  llvm::Type *innermostLLVMType = getInnermostElementType(llvmType);
+  if (!llvm::ConstantDataSequential::isElementTypeCompatible(innermostLLVMType))
+    return nullptr;
+
+  // Compute the shape of all dimensions but the innermost. Note that the
+  // innermost dimension may be that of the (1-D only) vector element type.
+  ShapedType type = denseElementsAttr.getType();
+  auto vectorElementType = type.getElementType().dyn_cast<VectorType>();
+  if (vectorElementType && vectorElementType.getRank() != 1)
+    return nullptr;
+  int64_t innermostSize = (vectorElementType ? vectorElementType.getShape()
+                                             : type.getShape()).back();
+  int64_t numAggregates = denseElementsAttr.getNumElements() /
+                          (vectorElementType ? 1 : innermostSize);
+  ArrayRef<int64_t> outerShape = type.getShape();
+  if (!vectorElementType)
+    outerShape = outerShape.drop_back();
+
+  // Handle vector splats: each innermost vector holds `innermostSize` copies.
+  if (denseElementsAttr.isSplat() &&
+      (type.isa<VectorType>() || vectorElementType)) {
+    llvm::Constant *splatValue = LLVM::detail::getLLVMConstant(
+        innermostLLVMType, denseElementsAttr.getSplatValue(), loc,
+        moduleTranslation, /*isTopLevel=*/false);
+    if (!splatValue)
+      return nullptr;
+    llvm::Constant *splatVector =
+        llvm::ConstantDataVector::getSplat(innermostSize, splatValue);
+    SmallVector<llvm::Constant *> constants(numAggregates, splatVector);
+    ArrayRef<llvm::Constant *> constantsRef = constants;
+    return buildSequentialConstant(constantsRef, outerShape, llvmType, loc);
+  }
+  if (denseElementsAttr.isSplat())
+    return nullptr;
+
+  // In case of non-splat, build each innermost constant from a slice of the raw
+  // data: an array for tensors of scalars, a vector otherwise.
+  if (!type.isa<TensorType, VectorType>())
+    return nullptr;
+  bool buildArrays = type.isa<TensorType>() && !vectorElementType;
+
+  // The raw storage must hold `numAggregates` chunks of `innermostSize` scalars
+  // of the LLVM width. It does not if, e.g., the attribute element type is
+  // narrower than the LLVM one; slicing it would then read out of bounds.
+  ArrayRef<char> rawData = denseElementsAttr.getRawData();
+  int64_t aggregateSize =
+      innermostSize * (innermostLLVMType->getScalarSizeInBits() / 8);
+  if (static_cast<int64_t>(rawData.size()) != numAggregates * aggregateSize)
+    return nullptr;
+
+  // Create innermost constants and defer to the default constant creation
+  // mechanism for other dimensions.
+  SmallVector<llvm::Constant *> constants;
+  constants.reserve(numAggregates);
+  for (int64_t i = 0; i < numAggregates; ++i) {
+    StringRef data(rawData.data() + i * aggregateSize, aggregateSize);
+    llvm::Constant *cst =
+        buildArrays ? llvm::ConstantDataArray::getRaw(data, innermostSize,
+                                                      innermostLLVMType)
+                    : llvm::ConstantDataVector::getRaw(data, innermostSize,
+                                                       innermostLLVMType);
+    constants.push_back(cst);
+  }
+
+  ArrayRef<llvm::Constant *> constantsRef = constants;
+  return buildSequentialConstant(constantsRef, outerShape, llvmType, loc);
+}
+
/// Create an LLVM IR constant of `llvmType` from the MLIR attribute `attr`.
/// This currently supports integer, floating point, splat and dense element
/// attributes and combinations thereof. Also, an array attribute with two
@@ -178,6 +264,14 @@ llvm::Constant *mlir::LLVM::detail::getLLVMConstant(
}
}
+ // Try using raw elements data if possible.
+ if (llvm::Constant *result =
+ convertDenseElementsAttr(loc, attr.dyn_cast<DenseElementsAttr>(),
+ llvmType, moduleTranslation)) {
+ return result;
+ }
+
+ // Fall back to element-by-element construction otherwise.
if (auto elementsAttr = attr.dyn_cast<ElementsAttr>()) {
assert(elementsAttr.getType().hasStaticShape());
assert(!elementsAttr.getType().getShape().empty() &&
diff --git a/mlir/test/Target/LLVMIR/llvmir.mlir b/mlir/test/Target/LLVMIR/llvmir.mlir
index e6795e5199b26..2455175268fdc 100644
--- a/mlir/test/Target/LLVMIR/llvmir.mlir
+++ b/mlir/test/Target/LLVMIR/llvmir.mlir
@@ -50,6 +50,36 @@ llvm.mlir.global internal constant @int_gep() : !llvm.ptr<i32> {
llvm.return %gepinit : !llvm.ptr<i32>
}
+// CHECK{LITERAL}: @dense_float_vector = internal global <3 x float> <float 1.000000e+00, float 2.000000e+00, float 3.000000e+00>
+llvm.mlir.global internal @dense_float_vector(dense<[1.0, 2.0, 3.0]> : vector<3xf32>) : vector<3xf32>
+
+// CHECK{LITERAL}: @splat_float_vector = internal global <3 x float> <float 4.200000e+01, float 4.200000e+01, float 4.200000e+01>
+llvm.mlir.global internal @splat_float_vector(dense<42.0> : vector<3xf32>) : vector<3xf32>
+
+// CHECK{LITERAL}: @dense_double_vector = internal global <3 x double> <double 1.000000e+00, double 2.000000e+00, double 3.000000e+00>
+llvm.mlir.global internal @dense_double_vector(dense<[1.0, 2.0, 3.0]> : vector<3xf64>) : vector<3xf64>
+
+// CHECK{LITERAL}: @splat_double_vector = internal global <3 x double> <double 4.200000e+01, double 4.200000e+01, double 4.200000e+01>
+llvm.mlir.global internal @splat_double_vector(dense<42.0> : vector<3xf64>) : vector<3xf64>
+
+// CHECK{LITERAL}: @dense_i64_vector = internal global <3 x i64> <i64 1, i64 2, i64 3>
+llvm.mlir.global internal @dense_i64_vector(dense<[1, 2, 3]> : vector<3xi64>) : vector<3xi64>
+
+// CHECK{LITERAL}: @splat_i64_vector = internal global <3 x i64> <i64 42, i64 42, i64 42>
+llvm.mlir.global internal @splat_i64_vector(dense<42> : vector<3xi64>) : vector<3xi64>
+
+// CHECK{LITERAL}: @dense_float_vector_2d = internal global [2 x <2 x float>] [<2 x float> <float 1.000000e+00, float 2.000000e+00>, <2 x float> <float 3.000000e+00, float 4.000000e+00>]
+llvm.mlir.global internal @dense_float_vector_2d(dense<[[1.0, 2.0], [3.0, 4.0]]> : vector<2x2xf32>) : !llvm.array<2 x vector<2xf32>>
+
+// CHECK{LITERAL}: @splat_float_vector_2d = internal global [2 x <2 x float>] [<2 x float> <float 4.200000e+01, float 4.200000e+01>, <2 x float> <float 4.200000e+01, float 4.200000e+01>]
+llvm.mlir.global internal @splat_float_vector_2d(dense<42.0> : vector<2x2xf32>) : !llvm.array<2 x vector<2xf32>>
+
+// CHECK{LITERAL}: @dense_float_vector_3d = internal global [2 x [2 x <2 x float>]] [[2 x <2 x float>] [<2 x float> <float 1.000000e+00, float 2.000000e+00>, <2 x float> <float 3.000000e+00, float 4.000000e+00>], [2 x <2 x float>] [<2 x float> <float 5.000000e+00, float 6.000000e+00>, <2 x float> <float 7.000000e+00, float 8.000000e+00>]]
+llvm.mlir.global internal @dense_float_vector_3d(dense<[[[1.0, 2.0], [3.0, 4.0]], [[5.0, 6.0], [7.0, 8.0]]]> : vector<2x2x2xf32>) : !llvm.array<2 x !llvm.array<2 x vector<2xf32>>>
+
+// CHECK{LITERAL}: @splat_float_vector_3d = internal global [2 x [2 x <2 x float>]] [[2 x <2 x float>] [<2 x float> <float 4.200000e+01, float 4.200000e+01>, <2 x float> <float 4.200000e+01, float 4.200000e+01>], [2 x <2 x float>] [<2 x float> <float 4.200000e+01, float 4.200000e+01>, <2 x float> <float 4.200000e+01, float 4.200000e+01>]]
+llvm.mlir.global internal @splat_float_vector_3d(dense<42.0> : vector<2x2x2xf32>) : !llvm.array<2 x !llvm.array<2 x vector<2xf32>>>
+
//
// Linkage attribute.
//
@@ -67,7 +97,7 @@ llvm.mlir.global weak @weak(42 : i32) : i32
// CHECK: @common = common global i32 0
llvm.mlir.global common @common(0 : i32) : i32
// CHECK: @appending = appending global [3 x i32] [i32 1, i32 2, i32 3]
-llvm.mlir.global appending @appending(dense<[1,2,3]> : vector<3xi32>) : !llvm.array<3xi32>
+llvm.mlir.global appending @appending(dense<[1,2,3]> : tensor<3xi32>) : !llvm.array<3xi32>
// CHECK: @extern_weak = extern_weak global i32
llvm.mlir.global extern_weak @extern_weak() : i32
// CHECK: @linkonce_odr = linkonce_odr global i32 42
More information about the Mlir-commits
mailing list