[Mlir-commits] [mlir] c788cad - [mlir][linalg] Fix FoldConstantTranspose execution inefficiency
Lei Zhang
llvmlistbot at llvm.org
Thu Oct 28 06:49:13 PDT 2021
Author: Lei Zhang
Date: 2021-10-28T09:45:14-04:00
New Revision: c788cad83b6b5c24f8160f9fc11a69dd7beafb8b
URL: https://github.com/llvm/llvm-project/commit/c788cad83b6b5c24f8160f9fc11a69dd7beafb8b
DIFF: https://github.com/llvm/llvm-project/commit/c788cad83b6b5c24f8160f9fc11a69dd7beafb8b.diff
LOG: [mlir][linalg] Fix FoldConstantTranspose execution inefficiency
* Move SmallVectors outside of inner loops to avoid frequent
allocations and deallocations
* Calculate linearized index and call flat range getters to
avoid internal shape querying behind `getValue`.
Reviewed By: mravishankar
Differential Revision: https://reviews.llvm.org/D112099
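The heart of both fixes, sketched in isolation: hoist the per-iteration scratch vectors out of the hot loop, and turn delinearized indices back into flat offsets so each element is read and written with a single linear access. Below is a minimal standalone C++ sketch of that remapping, not code from the patch: the 2x3 -> 3x2 shapes, the permutation, and all variable names are invented for illustration, and the loop runs over the output's linear index rather than the op's full iteration space.

// Standalone sketch of the hoisted-buffer, linear-index remapping pattern.
// Shapes, permutation, and names are invented; builds with any C++11 compiler.
#include <cassert>
#include <cstdint>
#include <vector>

int main() {
  const std::vector<int64_t> srcShape = {2, 3};  // input tensor shape
  const std::vector<int64_t> dstShape = {3, 2};  // transposed output shape
  const std::vector<int64_t> perm = {1, 0};      // output dim d reads input dim perm[d]
  const std::vector<float> src = {1, 2, 3, 4, 5, 6};
  std::vector<float> dst(src.size());

  // Scratch vectors allocated once, outside the hot loop (the first fix);
  // every iteration overwrites them in place.
  std::vector<int64_t> dstIndices(dstShape.size());
  std::vector<int64_t> srcIndices(srcShape.size());

  for (int64_t linearIndex = 0; linearIndex < (int64_t)src.size(); ++linearIndex) {
    // Delinearize the output's linear index into multi-dimensional indices.
    int64_t rest = linearIndex;
    for (int dim = (int)dstShape.size() - 1; dim >= 0; --dim) {
      dstIndices[dim] = rest % dstShape[dim];
      rest /= dstShape[dim];
    }
    // Remap output indices to input indices through the permutation.
    for (int dim = 0; dim < (int)dstShape.size(); ++dim)
      srcIndices[perm[dim]] = dstIndices[dim];
    // Relinearize against the input shape so the element can be fetched with
    // one flat access (the second fix: no per-element shape queries).
    int64_t srcLinearIndex = srcIndices.front();
    for (int dim = 1; dim < (int)srcShape.size(); ++dim)
      srcLinearIndex = srcLinearIndex * srcShape[dim] + srcIndices[dim];
    dst[linearIndex] = src[srcLinearIndex];
  }

  // 1 2 3 / 4 5 6 transposed is 1 4 / 2 5 / 3 6.
  assert(dst == (std::vector<float>{1, 4, 2, 5, 3, 6}));
  return 0;
}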
Added:
Modified:
mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
index 32ad335954b13..ee5622d2662db 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
@@ -1286,12 +1286,16 @@ class FoldScalarOrSplatConstant : public OpRewritePattern<GenericOp> {
template <typename ConcreteType>
class FoldConstantBase : public OpRewritePattern<GenericOp> {
public:
+ struct APIntOrFloat {
+ Optional<APInt> apInt;
+ Optional<APFloat> apFloat;
+ };
struct APIntOrFloatArray {
SmallVector<APInt> apInts;
SmallVector<APFloat> apFloats;
};
using RegionComputationFn =
- std::function<APIntOrFloatArray(APIntOrFloatArray)>;
+ std::function<APIntOrFloat(const APIntOrFloatArray &)>;
FoldConstantBase(MLIRContext *context,
const ControlElementwiseOpsFusionFn &controlFn,
@@ -1403,57 +1407,109 @@ class FoldConstantBase : public OpRewritePattern<GenericOp> {
auto outputDims = getDimPositions(genericOp.getIndexingMaps().back());
auto outputShape = outputType.getShape();
- // Transpose the input constant. Because we don't know its rank in advance,
- // we need to loop over the range [0, element count) and delinearize the
- // index.
- for (int linearIndex0 = 0; linearIndex0 < numElements; ++linearIndex0) {
- SmallVector<uint64_t> indices(loopBounds.size(), 0);
- int totalCount = linearIndex0;
+ // Allocate small vectors for index delinearization. Initial values do not
+ // matter here as they will be overwritten later.
+ SmallVector<uint64_t> indices(loopBounds.size(), 0);
+ SmallVector<uint64_t> dstIndices(loopBounds.size(), 0);
+ SmallVector<SmallVector<uint64_t>> srcIndices(
+ numInputs, SmallVector<uint64_t>(loopBounds.size(), 0));
+ SmallVector<uint64_t> srcLinearIndices(numInputs, 0);
+ uint64_t dstLinearIndex = 0;
+
+ // Allocate space for compute function inputs. Initial values do not matter
+ // here as they will be overwritten later.
+ APIntOrFloatArray computeFnInputs;
+
+ auto inputShapes = llvm::to_vector<4>(
+ llvm::map_range(genericOp.getInputOperands(), [](OpOperand *operand) {
+ return operand->get().getType().cast<ShapedType>().getShape();
+ }));
+
+ // Given a `linearIndex`, remap it to a linear index to access linalg op
+ // inputs/outputs. This mutates `indices`, `srcIndices`, `dstIndices`,
+ // `srcLinearIndices`, `dstLinearIndex` in place.
+ auto computeRemappedLinearIndex = [&](int linearIndex) {
+ int totalCount = linearIndex;
for (int dim = loopBounds.size() - 1; dim >= 0; --dim) {
indices[dim] = totalCount % loopBounds[dim];
totalCount /= loopBounds[dim];
}
- SmallVector<SmallVector<uint64_t>> srcIndices;
- for (int i = 0; i < numInputs; ++i)
- srcIndices.emplace_back(loopBounds.size(), 0);
- SmallVector<uint64_t> dstIndices(loopBounds.size(), 0);
-
for (int dim = loopBounds.size() - 1; dim >= 0; --dim) {
for (int i = 0; i < numInputs; ++i)
srcIndices[i][dim] = indices[inputDims[i][dim]];
dstIndices[dim] = indices[outputDims[dim]];
}
- uint64_t linearIndex1 = dstIndices.front();
- for (int dim = 1; dim < outputType.getRank(); ++dim)
- linearIndex1 = linearIndex1 * outputShape[dim] + dstIndices[dim];
+ dstLinearIndex = dstIndices.front();
+ for (int i = 0; i < numInputs; ++i)
+ srcLinearIndices[i] = srcIndices[i].front();
- // Collect constant elements for all inputs at this loop iteration.
- SmallVector<APInt> intValues;
- SmallVector<APFloat> fpValues;
- if (elementType.isa<FloatType>()) {
+ for (int dim = 1; dim < outputType.getRank(); ++dim) {
+ dstLinearIndex = dstLinearIndex * outputShape[dim] + dstIndices[dim];
for (int i = 0; i < numInputs; ++i)
- fpValues.push_back(inputValues[i].getValue<APFloat>(srcIndices[i]));
- } else {
- for (int i = 0; i < numInputs; ++i)
- intValues.push_back(inputValues[i].getValue<APInt>(srcIndices[i]));
+ srcLinearIndices[i] =
+ srcLinearIndices[i] * inputShapes[i][dim] + srcIndices[i][dim];
+ }
+ };
+
+ bool isFloat = elementType.isa<FloatType>();
+ if (isFloat) {
+ SmallVector<iterator_range<DenseElementsAttr::FloatElementIterator>>
+ inputFpIterators;
+ for (int i = 0; i < numInputs; ++i)
+ inputFpIterators.push_back(inputValues[i].getValues<APFloat>());
+
+ computeFnInputs.apFloats.resize(numInputs, APFloat(0.f));
+
+ // Transpose the input constant. Because we don't know its rank in
+ // advance, we need to loop over the range [0, element count) and
+ // delinearize the index.
+ for (int linearIndex = 0; linearIndex < numElements; ++linearIndex) {
+ computeRemappedLinearIndex(linearIndex);
+
+ // Collect constant elements for all inputs at this loop iteration.
+ for (int i = 0; i < numInputs; ++i) {
+ computeFnInputs.apFloats[i] =
+ *(inputFpIterators[i].begin() + srcLinearIndices[i]);
+ }
+
+ // Invoke the computation to get the corresponding constant output
+ // element.
+ APIntOrFloat outputs = computeFn(computeFnInputs);
+
+ fpOutputValues[dstLinearIndex] = outputs.apFloat.getValue();
}
+ } else {
+ SmallVector<iterator_range<DenseElementsAttr::IntElementIterator>>
+ inputIntIterators;
+ for (int i = 0; i < numInputs; ++i)
+ inputIntIterators.push_back(inputValues[i].getValues<APInt>());
+
+ computeFnInputs.apInts.resize(numInputs);
+
+ // Transpose the input constant. Because we don't know its rank in
+ // advance, we need to loop over the range [0, element count) and
+ // delinearize the index.
+ for (int linearIndex = 0; linearIndex < numElements; ++linearIndex) {
+ computeRemappedLinearIndex(linearIndex);
- // Invoke the computation to get the corresponding constant output
- // element.
- APIntOrFloatArray inputs = {intValues, fpValues};
- APIntOrFloatArray outputs = computeFn(inputs);
+ // Collect constant elements for all inputs at this loop iteration.
+ for (int i = 0; i < numInputs; ++i) {
+ computeFnInputs.apInts[i] =
+ *(inputIntIterators[i].begin() + srcLinearIndices[i]);
+ }
+
+ // Invoke the computation to get the corresponding constant output
+ // element.
+ APIntOrFloat outputs = computeFn(computeFnInputs);
- if (elementType.isa<FloatType>()) {
- fpOutputValues[linearIndex1] = outputs.apFloats.front();
- } else {
- intOutputValues[linearIndex1] = outputs.apInts.front();
+ intOutputValues[dstLinearIndex] = outputs.apInt.getValue();
}
}
DenseIntOrFPElementsAttr outputAttr;
- if (elementType.isa<FloatType>()) {
+ if (isFloat) {
outputAttr = DenseFPElementsAttr::get(outputType, fpOutputValues);
} else {
outputAttr = DenseIntElementsAttr::get(outputType, intOutputValues);
@@ -1494,7 +1550,11 @@ struct FoldConstantTranspose : public FoldConstantBase<FoldConstantTranspose> {
}
// No computation; just return the original value.
- return [](APIntOrFloatArray inputs) { return inputs; };
+ return [](const APIntOrFloatArray &inputs) {
+ if (inputs.apFloats.empty())
+ return APIntOrFloat{inputs.apInts.front(), llvm::None};
+ return APIntOrFloat{llvm::None, inputs.apFloats.front()};
+ };
}
ControlElementwiseOpsFusionFn controlFn;
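On the second bullet of the commit message: the removed code looked every element up with getValue<T>(multiDimIndex), which re-resolves the index against the attribute's shape on each call, whereas the new code fetches a flat value range once per input and then indexes it linearly. Below is a minimal sketch of the two access styles on a DenseElementsAttr, assuming the MLIR C++ API of roughly this revision; it is illustrative only and not part of the patch.

// Contrast per-element multi-dimensional lookup with one-time flat iteration.
// Assumes MLIR headers/libraries of roughly this revision are available.
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/MLIRContext.h"
#include "llvm/ADT/APFloat.h"
#include "llvm/ADT/ArrayRef.h"

using namespace mlir;

int main() {
  MLIRContext ctx;
  auto type = RankedTensorType::get({2, 3}, FloatType::getF32(&ctx));
  float data[] = {1.f, 2.f, 3.f, 4.f, 5.f, 6.f};
  DenseElementsAttr attr = DenseElementsAttr::get(type, llvm::makeArrayRef(data));

  // Old style: each call re-derives the element position from the
  // multi-dimensional index and the shape, which is wasteful in a hot loop.
  APFloat slow = attr.getValue<APFloat>({1, 2});

  // New style: materialize the flat range once, then use a precomputed
  // row-major linear index (1 * 3 + 2) for random access.
  auto range = attr.getValues<APFloat>();
  int64_t linearIndex = 1 * 3 + 2;
  APFloat fast = *(range.begin() + linearIndex);

  (void)slow;
  (void)fast;
  return 0;
}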