[Mlir-commits] [mlir] c788cad - [mlir][linalg] Fix FoldConstantTranspose execution inefficiency

Lei Zhang llvmlistbot at llvm.org
Thu Oct 28 06:49:13 PDT 2021


Author: Lei Zhang
Date: 2021-10-28T09:45:14-04:00
New Revision: c788cad83b6b5c24f8160f9fc11a69dd7beafb8b

URL: https://github.com/llvm/llvm-project/commit/c788cad83b6b5c24f8160f9fc11a69dd7beafb8b
DIFF: https://github.com/llvm/llvm-project/commit/c788cad83b6b5c24f8160f9fc11a69dd7beafb8b.diff

LOG: [mlir][linalg] Fix FoldConstantTranspose execution inefficiency

* Move SmallVectors out of the inner loops to avoid frequent
  allocations and deallocations.
* Compute linearized indices and use the flat-range element getters
  to avoid the per-element shape queries performed inside `getValue`.

Reviewed By: mravishankar

Differential Revision: https://reviews.llvm.org/D112099

Added: 
    

Modified: 
    mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
index 32ad335954b13..ee5622d2662db 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
@@ -1286,12 +1286,16 @@ class FoldScalarOrSplatConstant : public OpRewritePattern<GenericOp> {
 template <typename ConcreteType>
 class FoldConstantBase : public OpRewritePattern<GenericOp> {
 public:
+  struct APIntOrFloat {
+    Optional<APInt> apInt;
+    Optional<APFloat> apFloat;
+  };
   struct APIntOrFloatArray {
     SmallVector<APInt> apInts;
     SmallVector<APFloat> apFloats;
   };
   using RegionComputationFn =
-      std::function<APIntOrFloatArray(APIntOrFloatArray)>;
+      std::function<APIntOrFloat(const APIntOrFloatArray &)>;
 
   FoldConstantBase(MLIRContext *context,
                    const ControlElementwiseOpsFusionFn &controlFn,
@@ -1403,57 +1407,109 @@ class FoldConstantBase : public OpRewritePattern<GenericOp> {
     auto outputDims = getDimPositions(genericOp.getIndexingMaps().back());
     auto outputShape = outputType.getShape();
 
-    // Transpose the input constant. Because we don't know its rank in advance,
-    // we need to loop over the range [0, element count) and delinearize the
-    // index.
-    for (int linearIndex0 = 0; linearIndex0 < numElements; ++linearIndex0) {
-      SmallVector<uint64_t> indices(loopBounds.size(), 0);
-      int totalCount = linearIndex0;
+    // Allocate small vectors for index delinearization. Initial values do not
+    // matter here as they will be overwritten later.
+    SmallVector<uint64_t> indices(loopBounds.size(), 0);
+    SmallVector<uint64_t> dstIndices(loopBounds.size(), 0);
+    SmallVector<SmallVector<uint64_t>> srcIndices(
+        numInputs, SmallVector<uint64_t>(loopBounds.size(), 0));
+    SmallVector<uint64_t> srcLinearIndices(numInputs, 0);
+    uint64_t dstLinearIndex = 0;
+
+    // Allocate spaces for compute function inputs. Initial values do not matter
+    // here as they will be overwritten later.
+    APIntOrFloatArray computeFnInputs;
+
+    auto inputShapes = llvm::to_vector<4>(
+        llvm::map_range(genericOp.getInputOperands(), [](OpOperand *operand) {
+          return operand->get().getType().cast<ShapedType>().getShape();
+        }));
+
+    // Given a `linearIndex`, remap it to a linear index to access linalg op
+    // inputs/ouputs. This mutates `indices`, `srcIndices`, `dstIndices`,
+    // `srcLinearIndices`, `dstLinearIndex` in place.
+    auto computeRemappedLinearIndex = [&](int linearIndex) {
+      int totalCount = linearIndex;
       for (int dim = loopBounds.size() - 1; dim >= 0; --dim) {
         indices[dim] = totalCount % loopBounds[dim];
         totalCount /= loopBounds[dim];
       }
 
-      SmallVector<SmallVector<uint64_t>> srcIndices;
-      for (int i = 0; i < numInputs; ++i)
-        srcIndices.emplace_back(loopBounds.size(), 0);
-      SmallVector<uint64_t> dstIndices(loopBounds.size(), 0);
-
       for (int dim = loopBounds.size() - 1; dim >= 0; --dim) {
         for (int i = 0; i < numInputs; ++i)
           srcIndices[i][dim] = indices[inputDims[i][dim]];
         dstIndices[dim] = indices[outputDims[dim]];
       }
 
-      uint64_t linearIndex1 = dstIndices.front();
-      for (int dim = 1; dim < outputType.getRank(); ++dim)
-        linearIndex1 = linearIndex1 * outputShape[dim] + dstIndices[dim];
+      dstLinearIndex = dstIndices.front();
+      for (int i = 0; i < numInputs; ++i)
+        srcLinearIndices[i] = srcIndices[i].front();
 
-      // Collect constant elements for all inputs at this loop iteration.
-      SmallVector<APInt> intValues;
-      SmallVector<APFloat> fpValues;
-      if (elementType.isa<FloatType>()) {
+      for (int dim = 1; dim < outputType.getRank(); ++dim) {
+        dstLinearIndex = dstLinearIndex * outputShape[dim] + dstIndices[dim];
         for (int i = 0; i < numInputs; ++i)
-          fpValues.push_back(inputValues[i].getValue<APFloat>(srcIndices[i]));
-      } else {
-        for (int i = 0; i < numInputs; ++i)
-          intValues.push_back(inputValues[i].getValue<APInt>(srcIndices[i]));
+          srcLinearIndices[i] =
+              srcLinearIndices[i] * inputShapes[i][dim] + srcIndices[i][dim];
+      }
+    };
+
+    bool isFloat = elementType.isa<FloatType>();
+    if (isFloat) {
+      SmallVector<iterator_range<DenseElementsAttr::FloatElementIterator>>
+          inputFpIterators;
+      for (int i = 0; i < numInputs; ++i)
+        inputFpIterators.push_back(inputValues[i].getValues<APFloat>());
+
+      computeFnInputs.apFloats.resize(numInputs, APFloat(0.f));
+
+      // Transpose the input constant. Because we don't know its rank in
+      // advance, we need to loop over the range [0, element count) and
+      // delinearize the index.
+      for (int linearIndex = 0; linearIndex < numElements; ++linearIndex) {
+        computeRemappedLinearIndex(linearIndex);
+
+        // Collect constant elements for all inputs at this loop iteration.
+        for (int i = 0; i < numInputs; ++i) {
+          computeFnInputs.apFloats[i] =
+              *(inputFpIterators[i].begin() + srcLinearIndices[i]);
+        }
+
+        // Invoke the computation to get the corresponding constant output
+        // element.
+        APIntOrFloat outputs = computeFn(computeFnInputs);
+
+        fpOutputValues[dstLinearIndex] = outputs.apFloat.getValue();
       }
+    } else {
+      SmallVector<iterator_range<DenseElementsAttr::IntElementIterator>>
+          inputIntIterators;
+      for (int i = 0; i < numInputs; ++i)
+        inputIntIterators.push_back(inputValues[i].getValues<APInt>());
+
+      computeFnInputs.apInts.resize(numInputs);
+
+      // Transpose the input constant. Because we don't know its rank in
+      // advance, we need to loop over the range [0, element count) and
+      // delinearize the index.
+      for (int linearIndex = 0; linearIndex < numElements; ++linearIndex) {
+        computeRemappedLinearIndex(linearIndex);
 
-      // Invoke the computation to get the corresponding constant output
-      // element.
-      APIntOrFloatArray inputs = {intValues, fpValues};
-      APIntOrFloatArray outputs = computeFn(inputs);
+        // Collect constant elements for all inputs at this loop iteration.
+        for (int i = 0; i < numInputs; ++i) {
+          computeFnInputs.apInts[i] =
+              *(inputIntIterators[i].begin() + srcLinearIndices[i]);
+        }
+
+        // Invoke the computation to get the corresponding constant output
+        // element.
+        APIntOrFloat outputs = computeFn(computeFnInputs);
 
-      if (elementType.isa<FloatType>()) {
-        fpOutputValues[linearIndex1] = outputs.apFloats.front();
-      } else {
-        intOutputValues[linearIndex1] = outputs.apInts.front();
+        intOutputValues[dstLinearIndex] = outputs.apInt.getValue();
       }
     }
 
     DenseIntOrFPElementsAttr outputAttr;
-    if (elementType.isa<FloatType>()) {
+    if (isFloat) {
       outputAttr = DenseFPElementsAttr::get(outputType, fpOutputValues);
     } else {
       outputAttr = DenseIntElementsAttr::get(outputType, intOutputValues);
@@ -1494,7 +1550,11 @@ struct FoldConstantTranspose : public FoldConstantBase<FoldConstantTranspose> {
     }
 
     // No computation; just return the orginal value.
-    return [](APIntOrFloatArray inputs) { return inputs; };
+    return [](const APIntOrFloatArray &inputs) {
+      if (inputs.apFloats.empty())
+        return APIntOrFloat{inputs.apInts.front(), llvm::None};
+      return APIntOrFloat{llvm::None, inputs.apFloats.front()};
+    };
   }
 
   ControlElementwiseOpsFusionFn controlFn;


        


More information about the Mlir-commits mailing list