[Mlir-commits] [mlir] 4cd7ff6 - [mlir][linalg] Constant fold linalg.generic that are transposes

Lei Zhang llvmlistbot at llvm.org
Fri Oct 8 05:12:37 PDT 2021


Author: Lei Zhang
Date: 2021-10-08T08:09:13-04:00
New Revision: 4cd7ff6728f440234def491be380e5af62f34b83

URL: https://github.com/llvm/llvm-project/commit/4cd7ff6728f440234def491be380e5af62f34b83
DIFF: https://github.com/llvm/llvm-project/commit/4cd7ff6728f440234def491be380e5af62f34b83.diff

LOG: [mlir][linalg] Constant fold linalg.generic that are transposes

This commit adds a pattern to perform constant folding on linalg
generic ops which are essentially transposes. We see real cases
where model importers may generate such patterns.

Reviewed By: mravishankar

Differential Revision: https://reviews.llvm.org/D110597

Added: 
    

Modified: 
    mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
    mlir/test/Dialect/Linalg/fusion-elementwise-ops.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
index 5f721e2098ee..96f7a1c3796e 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
@@ -1164,10 +1164,11 @@ struct FoldReshapeWithGenericOpByExpansion
 
 /// Pattern to fold a generic op with a splat constant/scalar constant. Does not
 /// handle cases where the constant is not single-valued.
-class FoldConstants : public OpRewritePattern<GenericOp> {
+class FoldScalarOrSplatConstant : public OpRewritePattern<GenericOp> {
 public:
-  FoldConstants(MLIRContext *context, ControlElementwiseOpsFusionFn &fun,
-                PatternBenefit benefit = 1)
+  FoldScalarOrSplatConstant(MLIRContext *context,
+                            ControlElementwiseOpsFusionFn &fun,
+                            PatternBenefit benefit = 1)
       : OpRewritePattern<GenericOp>(context, benefit), controlFn(fun) {}
 
   LogicalResult matchAndRewrite(GenericOp genericOp,
@@ -1268,6 +1269,237 @@ class FoldConstants : public OpRewritePattern<GenericOp> {
 private:
   ControlElementwiseOpsFusionFn controlFn;
 };
+
+/// Base class for constant folding linalg.generic ops with N inputs, 1 output,
+/// and permutation indexing maps.
+///
+/// `ConcreteType` should provide methods with signatures
+///
+/// ```c++
+///   bool matchIndexingMaps(GenericOp genericOp) const;
+///   RegionComputationFn getRegionComputeFn(GenericOp) const;
+/// ```
+///
+/// The latter inspects the region and returns the computation inside as a
+/// functor. The functor will be invoked with constant elements for all inputs
+/// and should return the corresponding computed constant element for output.
+template <typename ConcreteType>
+class FoldConstantBase : public OpRewritePattern<GenericOp> {
+public:
+  struct APIntOrFloatArray {
+    SmallVector<APInt> apInts;
+    SmallVector<APFloat> apFloats;
+  };
+  using RegionComputationFn =
+      std::function<APIntOrFloatArray(APIntOrFloatArray)>;
+
+  FoldConstantBase(MLIRContext *context,
+                   const ControlElementwiseOpsFusionFn &controlFn,
+                   PatternBenefit benefit = 1)
+      : OpRewritePattern<GenericOp>(context, benefit), controlFn(controlFn) {}
+
+  LogicalResult matchAndRewrite(GenericOp genericOp,
+                                PatternRewriter &rewriter) const override {
+    if (genericOp.hasBufferSemantics())
+      return failure();
+
+    // Only support ops generating one output for now.
+    if (genericOp.getNumOutputs() != 1)
+      return failure();
+
+    auto outputType = genericOp.getResultTypes().front().dyn_cast<ShapedType>();
+    // Require the output type to be static given we are generating constants.
+    if (!outputType || !outputType.hasStaticShape())
+      return failure();
+
+    if (!llvm::all_of(genericOp.getInputOperands(), [](OpOperand *operand) {
+          return operand->get().getType().isa<ShapedType>();
+        }))
+      return failure();
+
+    // Make sure all input and output element types are the same.
+    auto getOperandElementType = [](OpOperand *operand) {
+      return operand->get().getType().cast<ShapedType>().getElementType();
+    };
+    if (!llvm::is_splat(llvm::map_range(genericOp.getInputAndOutputOperands(),
+                                        getOperandElementType)))
+      return failure();
+
+    // We can only handle the case where we have int/float elements.
+    auto elementType = outputType.getElementType();
+    if (!elementType.isIntOrFloat())
+      return failure();
+
+    // Require all indexing maps to be permutations for now. This is common and
+    // it simplifies input/output access greatly: we can do the data shuffling
+    // entirely in the compiler, without needing to turn all indices into
+    // Values, and then do affine apply on them, and then match back the
+    // constant again.
+    if (!llvm::all_of(genericOp.getIndexingMaps(),
+                      [](AffineMap map) { return map.isPermutation(); }))
+      return failure();
+
+    for (OpOperand *operand : genericOp.getOutputOperands()) {
+      if (genericOp.payloadUsesValueFromOperand(operand))
+        return failure();
+    }
+
+    // Further check the indexing maps are okay for the ConcreteType.
+    if (!static_cast<const ConcreteType *>(this)->matchIndexingMaps(genericOp))
+      return failure();
+
+    // Defer to the concrete type to check the region and discover the
+    // computation inside; an empty functor means the region is unsupported.
+    RegionComputationFn computeFn =
+        static_cast<const ConcreteType *>(this)->getRegionComputeFn(genericOp);
+    if (!computeFn)
+      return failure();
+
+    // All inputs should be constants.
+    int numInputs = genericOp.getNumInputs();
+    SmallVector<DenseIntOrFPElementsAttr> inputValues(numInputs);
+    for (auto operand : llvm::enumerate(genericOp.getInputOperands())) {
+      if (!matchPattern(operand.value()->get(),
+                        m_Constant(&inputValues[operand.index()])))
+        return failure();
+    }
+
+    // Identified this as a potential candidate for folding. Now check the
+    // policy to see whether we are allowed to proceed.
+    for (int i = 0; i < numInputs; ++i) {
+      OpOperand *consumer = genericOp.getInputOperand(i);
+      OpResult producer = consumer->get().cast<OpResult>();
+      if (!controlFn(producer, *consumer))
+        return failure();
+    }
+
+    auto linalgOp = cast<LinalgOp>(genericOp.getOperation());
+    SmallVector<int64_t, 4> loopBounds = linalgOp.computeStaticLoopSizes();
+    int64_t numElements = outputType.getNumElements();
+
+    // Use APInt/APFloat instead of Attribute here for constructing the output.
+    // This helps to avoid blowing up compiler memory usage: Attributes would
+    // unify the following cases but they have the same lifetime as the MLIRContext.
+    SmallVector<APInt> intOutputValues;
+    SmallVector<APFloat> fpOutputValues;
+    if (elementType.template isa<FloatType>())
+      fpOutputValues.resize(numElements, APFloat(0.f));
+    else
+      intOutputValues.resize(numElements);
+
+    // Return the constant dim positions from the given permutation map.
+    auto getDimPositions = [](AffineMap map) {
+      SmallVector<unsigned> dims;
+      dims.reserve(map.getNumResults());
+      for (AffineExpr result : map.getResults()) {
+        dims.push_back(result.cast<AffineDimExpr>().getPosition());
+      }
+      return dims;
+    };
+
+    SmallVector<SmallVector<unsigned>> inputDims;
+    for (int i = 0; i < numInputs; ++i)
+      inputDims.push_back(getDimPositions(genericOp.getIndexingMaps()[i]));
+    auto outputDims = getDimPositions(genericOp.getIndexingMaps().back());
+    auto outputShape = outputType.getShape();
+
+    // Transpose the input constant. Because we don't know its rank in advance,
+    // we need to loop over the range [0, element count) and delinearize the
+    // index. NOTE(review): the counter is `int` but numElements is int64_t.
+    for (int linearIndex0 = 0; linearIndex0 < numElements; ++linearIndex0) {
+      SmallVector<uint64_t> indices(loopBounds.size(), 0);
+      int totalCount = linearIndex0;
+      for (int dim = loopBounds.size() - 1; dim >= 0; --dim) {
+        indices[dim] = totalCount % loopBounds[dim];
+        totalCount /= loopBounds[dim];
+      }
+
+      SmallVector<SmallVector<uint64_t>> srcIndices;
+      for (int i = 0; i < numInputs; ++i)
+        srcIndices.emplace_back(loopBounds.size(), 0);
+      SmallVector<uint64_t> dstIndices(loopBounds.size(), 0);
+
+      for (int dim = loopBounds.size() - 1; dim >= 0; --dim) {
+        for (int i = 0; i < numInputs; ++i)
+          srcIndices[i][dim] = indices[inputDims[i][dim]];
+        dstIndices[dim] = indices[outputDims[dim]];
+      }
+
+      uint64_t linearIndex1 = dstIndices.front();
+      for (int dim = 1; dim < outputType.getRank(); ++dim)
+        linearIndex1 = linearIndex1 * outputShape[dim] + dstIndices[dim];
+
+      // Collect constant elements for all inputs at this loop iteration.
+      SmallVector<APInt> intValues;
+      SmallVector<APFloat> fpValues;
+      if (elementType.isa<FloatType>()) {
+        for (int i = 0; i < numInputs; ++i)
+          fpValues.push_back(inputValues[i].getValue<APFloat>(srcIndices[i]));
+      } else {
+        for (int i = 0; i < numInputs; ++i)
+          intValues.push_back(inputValues[i].getValue<APInt>(srcIndices[i]));
+      }
+
+      // Invoke the computation to get the corresponding constant output
+      // element; only the first returned element is stored below.
+      APIntOrFloatArray inputs = {intValues, fpValues};
+      APIntOrFloatArray outputs = computeFn(inputs);
+
+      if (elementType.isa<FloatType>()) {
+        fpOutputValues[linearIndex1] = outputs.apFloats.front();
+      } else {
+        intOutputValues[linearIndex1] = outputs.apInts.front();
+      }
+    }
+
+    DenseIntOrFPElementsAttr outputAttr;
+    if (elementType.isa<FloatType>()) {
+      outputAttr = DenseFPElementsAttr::get(outputType, fpOutputValues);
+    } else {
+      outputAttr = DenseIntElementsAttr::get(outputType, intOutputValues);
+    }
+    rewriter.replaceOpWithNewOp<ConstantOp>(genericOp, outputAttr);
+    return success();
+  }
+
+private:
+  ControlElementwiseOpsFusionFn controlFn;
+};
+
+// Folds linalg.generic ops that are actually transposes on constant values.
+struct FoldConstantTranspose : public FoldConstantBase<FoldConstantTranspose> {
+  using FoldConstantBase::FoldConstantBase;
+
+  bool matchIndexingMaps(GenericOp genericOp) const {
+    // We should have one input and one output, i.e. exactly two indexing maps.
+    return genericOp.getIndexingMaps().size() == 2;
+  }
+
+  RegionComputationFn getRegionComputeFn(GenericOp genericOp) const {
+    // Make sure the region only contains a yield op.
+    Block &body = genericOp.region().front();
+    if (!llvm::hasSingleElement(body))
+      return nullptr;
+    auto yieldOp = dyn_cast<linalg::YieldOp>(body.getTerminator());
+    if (!yieldOp)
+      return nullptr;
+
+    // The yield op should return the block argument corresponding to the input.
+    for (Value yieldVal : yieldOp.values()) {
+      auto yieldArg = yieldVal.dyn_cast<BlockArgument>();
+      if (!yieldArg || yieldArg.getOwner() != &body)
+        return nullptr;
+      if (yieldArg.getArgNumber() != 0)
+        return nullptr;
+    }
+
+    // No computation; just return the original value.
+    return [](APIntOrFloatArray inputs) { return inputs; };
+  }
+  // NOTE(review): unused here; the base class holds the controlFn it calls.
+  ControlElementwiseOpsFusionFn controlFn;
+};
+
 } // namespace
 
 static Optional<SmallVector<Value>>
@@ -1442,8 +1674,9 @@ void mlir::linalg::populateFoldReshapeOpsByExpansionPatterns(
 void mlir::linalg::populateElementwiseOpsFusionPatterns(
     RewritePatternSet &patterns, LinalgElementwiseFusionOptions options) {
   auto *context = patterns.getContext();
-  patterns.add<FuseElementwiseOps, FoldConstants>(
-      context, options.controlElementwiseOpsFusionFn);
+  patterns.add<FuseElementwiseOps, FoldScalarOrSplatConstant,
+               FoldConstantTranspose>(context,
+                                      options.controlElementwiseOpsFusionFn);
   patterns.add<RemoveOutsDependency>(context);
   populateFoldReshapeOpsByExpansionPatterns(patterns,
                                             options.controlFoldingReshapesFn);

diff  --git a/mlir/test/Dialect/Linalg/fusion-elementwise-ops.mlir b/mlir/test/Dialect/Linalg/fusion-elementwise-ops.mlir
index 985335a5f952..d72af1253866 100644
--- a/mlir/test/Dialect/Linalg/fusion-elementwise-ops.mlir
+++ b/mlir/test/Dialect/Linalg/fusion-elementwise-ops.mlir
@@ -755,15 +755,15 @@ func @fuse_scalar_constant(%arg0 : tensor<?x?xf32>) -> (tensor<?x?xf32>, tensor<
   %2:2 = linalg.generic {
       indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>,
                        affine_map<(d0, d1) -> ()>,
-		       affine_map<(d0, d1) -> ()>,
-		       affine_map<(d0, d1) -> (d0, d1)>,
-		       affine_map<(d0, d1) -> (d0, d1)>],
+                       affine_map<(d0, d1) -> ()>,
+                       affine_map<(d0, d1) -> (d0, d1)>,
+                       affine_map<(d0, d1) -> (d0, d1)>],
       iterator_types = ["parallel", "parallel"]}
       ins(%arg0, %cst, %c42 : tensor<?x?xf32>, f32, i32)
       outs(%0, %1 : tensor<?x?xf32>, tensor<?x?xi32>) {
       ^bb0(%arg1 : f32, %arg2 : f32, %arg3 : i32, %arg4 : f32, %arg5 : i32) :
         %3 = addf %arg1, %arg2 : f32
-	linalg.yield %3, %arg3 : f32, i32
+        linalg.yield %3, %arg3 : f32, i32
       } -> (tensor<?x?xf32>, tensor<?x?xi32>)
   return %2#0, %2#1 : tensor<?x?xf32>, tensor<?x?xi32>
 }
@@ -774,3 +774,136 @@ func @fuse_scalar_constant(%arg0 : tensor<?x?xf32>) -> (tensor<?x?xf32>, tensor<
 //  CHECK-SAME:       ins(%{{.+}} : tensor<?x?xf32>)
 //       CHECK:     %[[YIELD:.+]] = addf %{{.+}}, %[[CST]] : f32
 //       CHECK:     linalg.yield %[[YIELD]], %[[C42]] : f32, i32
+
+// -----
+
+// CHECK-LABEL: @transpose_fold_2d_fp32
+func @transpose_fold_2d_fp32(%init: tensor<3x2xf32>) -> tensor<3x2xf32> {
+  %input = constant dense<[[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]> : tensor<2x3xf32>
+  //               CHECK: %[[CST:.+]] = constant
+  // CHECK-SAME{LITERAL}:   dense<[[0.000000e+00, 3.000000e+00], [1.000000e+00, 4.000000e+00], [2.000000e+00, 5.000000e+00]]> : tensor<3x2xf32>
+  %1 = linalg.generic {
+    indexing_maps = [affine_map<(d0, d1) -> (d1, d0)>, affine_map<(d0, d1) -> (d0, d1)>],
+    iterator_types = ["parallel", "parallel"]
+  } ins(%input : tensor<2x3xf32>) outs(%init : tensor<3x2xf32>) {
+  ^bb0(%arg1: f32, %arg2: f32):
+    linalg.yield %arg1 : f32
+  } -> tensor<3x2xf32>
+  // CHECK: return %[[CST]]
+  return %1 : tensor<3x2xf32>
+}
+
+// -----
+
+// CHECK-LABEL: @transpose_fold_2d_fp64
+func @transpose_fold_2d_fp64(%init: tensor<3x2xf64>) -> tensor<3x2xf64> {
+  %input = constant dense<[[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]> : tensor<2x3xf64>
+  //               CHECK: %[[CST:.+]] = constant
+  // CHECK-SAME{LITERAL}:   dense<[[0.000000e+00, 3.000000e+00], [1.000000e+00, 4.000000e+00], [2.000000e+00, 5.000000e+00]]> : tensor<3x2xf64>
+  %1 = linalg.generic {
+    indexing_maps = [affine_map<(d0, d1) -> (d1, d0)>, affine_map<(d0, d1) -> (d0, d1)>],
+    iterator_types = ["parallel", "parallel"]
+  } ins(%input : tensor<2x3xf64>) outs(%init : tensor<3x2xf64>) {
+  ^bb0(%arg1: f64, %arg2: f64):
+    linalg.yield %arg1 : f64
+  } -> tensor<3x2xf64>
+  // CHECK: return %[[CST]]
+  return %1 : tensor<3x2xf64>
+}
+
+// -----
+
+// CHECK-LABEL: @transpose_fold_4d_i32
+func @transpose_fold_4d_i32(%init: tensor<3x1x4x2xi32>) -> tensor<3x1x4x2xi32> {
+  %input = constant dense<[[
+    [[ 0,  1,  2,  3], [ 4,  5,  6,  7], [ 8,  9, 10, 11]],
+    [[12, 13, 14, 15], [16, 17, 18, 19], [20, 21, 22, 23]]
+  ]]> : tensor<1x2x3x4xi32>
+  //               CHECK: %[[CST:.+]] = constant dense<[
+  // CHECK-SAME{LITERAL}:   [[[0, 12], [1, 13], [2, 14], [3, 15]]],
+  // CHECK-SAME{LITERAL}:   [[[4, 16], [5, 17], [6, 18], [7, 19]]],
+  // CHECK-SAME{LITERAL}:   [[[8, 20], [9, 21], [10, 22], [11, 23]]]
+  // CHECK-SAME{LITERAL}: ]>
+  %1 = linalg.generic {
+    indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d2, d0, d3, d1)>],
+    iterator_types = ["parallel", "parallel", "parallel", "parallel"]
+  } ins(%input : tensor<1x2x3x4xi32>) outs(%init : tensor<3x1x4x2xi32>) {
+  ^bb0(%arg1: i32, %arg2: i32):
+    linalg.yield %arg1 : i32
+  } -> tensor<3x1x4x2xi32>
+  // CHECK: return %[[CST]]
+  return %1 : tensor<3x1x4x2xi32>
+}
+
+// -----
+
+// CHECK-LABEL: @transpose_fold_4d_i16
+func @transpose_fold_4d_i16(%init: tensor<3x1x4x2xi16>) -> tensor<3x1x4x2xi16> {
+  %input = constant dense<[[
+    [[ 0,  1,  2,  3], [ 4,  5,  6,  7], [ 8,  9, 10, 11]],
+    [[12, 13, 14, 15], [16, 17, 18, 19], [20, 21, 22, 23]]
+  ]]> : tensor<1x2x3x4xi16>
+  //               CHECK: %[[CST:.+]] = constant dense<[
+  // CHECK-SAME{LITERAL}:   [[[0, 12], [1, 13], [2, 14], [3, 15]]],
+  // CHECK-SAME{LITERAL}:   [[[4, 16], [5, 17], [6, 18], [7, 19]]],
+  // CHECK-SAME{LITERAL}:   [[[8, 20], [9, 21], [10, 22], [11, 23]]]
+  // CHECK-SAME{LITERAL}: ]>
+  %1 = linalg.generic {
+    indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d2, d0, d3, d1)>],
+    iterator_types = ["parallel", "parallel", "parallel", "parallel"]
+  } ins(%input : tensor<1x2x3x4xi16>) outs(%init : tensor<3x1x4x2xi16>) {
+  ^bb0(%arg1: i16, %arg2: i16):
+    linalg.yield %arg1 : i16
+  } -> tensor<3x1x4x2xi16>
+  // CHECK: return %[[CST]]
+  return %1 : tensor<3x1x4x2xi16>
+}
+
+// -----
+
+// CHECK-LABEL: @transpose_nofold_non_cst_input
+func @transpose_nofold_non_cst_input(%input: tensor<2x3xf32>, %init: tensor<3x2xf32>) -> tensor<3x2xf32> {
+  // CHECK: linalg.generic
+  %1 = linalg.generic {
+    indexing_maps = [affine_map<(d0, d1) -> (d1, d0)>, affine_map<(d0, d1) -> (d0, d1)>],
+    iterator_types = ["parallel", "parallel"]
+  } ins(%input : tensor<2x3xf32>) outs(%init : tensor<3x2xf32>) {
+  ^bb0(%arg1: f32, %arg2: f32):
+    linalg.yield %arg1 : f32
+  } -> tensor<3x2xf32>
+  return %1 : tensor<3x2xf32>
+}
+
+// -----
+
+// CHECK-LABEL: @transpose_nofold_yield_const
+func @transpose_nofold_yield_const(%init: tensor<3x2xf32>) -> tensor<3x2xf32> {
+  %input = constant dense<[[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]> : tensor<2x3xf32>
+  %cst = constant 8.0 : f32
+  // CHECK: linalg.generic
+  %1 = linalg.generic {
+    indexing_maps = [affine_map<(d0, d1) -> (d1, d0)>, affine_map<(d0, d1) -> (d0, d1)>],
+    iterator_types = ["parallel", "parallel"]
+  } ins(%input : tensor<2x3xf32>) outs(%init : tensor<3x2xf32>) {
+  ^bb0(%arg1: f32, %arg2: f32):
+    linalg.yield %cst : f32
+  } -> tensor<3x2xf32>
+  return %1 : tensor<3x2xf32>
+}
+
+// -----
+
+// CHECK-LABEL: @transpose_nofold_multi_ops_in_region
+func @transpose_nofold_multi_ops_in_region(%init: tensor<3x2xf32>) -> tensor<3x2xf32> {
+  %input = constant dense<[[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]> : tensor<2x3xf32>
+  // CHECK: linalg.generic
+  %1 = linalg.generic {
+    indexing_maps = [affine_map<(d0, d1) -> (d1, d0)>, affine_map<(d0, d1) -> (d0, d1)>],
+    iterator_types = ["parallel", "parallel"]
+  } ins(%input : tensor<2x3xf32>) outs(%init : tensor<3x2xf32>) {
+  ^bb0(%arg1: f32, %arg2: f32):
+    %add = addf %arg1, %arg1 : f32
+    linalg.yield %add : f32
+  } -> tensor<3x2xf32>
+  return %1 : tensor<3x2xf32>
+}


        


More information about the Mlir-commits mailing list