[Mlir-commits] [mlir] 4cd7ff6 - [mlir][linalg] Constant fold linalg.generic that are transposes
Lei Zhang
llvmlistbot at llvm.org
Fri Oct 8 05:12:37 PDT 2021
Author: Lei Zhang
Date: 2021-10-08T08:09:13-04:00
New Revision: 4cd7ff6728f440234def491be380e5af62f34b83
URL: https://github.com/llvm/llvm-project/commit/4cd7ff6728f440234def491be380e5af62f34b83
DIFF: https://github.com/llvm/llvm-project/commit/4cd7ff6728f440234def491be380e5af62f34b83.diff
LOG: [mlir][linalg] Constant fold linalg.generic that are transposes
This commit adds a pattern to perform constant folding on linalg
generic ops which are essentially transposes. We see real cases
where model importers may generate such patterns.
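For example, with %input defined as constant dense<[[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]> : tensor<2x3xf32>, a generic op like the one below (taken from the fp32 test case added in this commit) now folds away entirely:

  %1 = linalg.generic {
    indexing_maps = [affine_map<(d0, d1) -> (d1, d0)>,
                     affine_map<(d0, d1) -> (d0, d1)>],
    iterator_types = ["parallel", "parallel"]
  } ins(%input : tensor<2x3xf32>) outs(%init : tensor<3x2xf32>) {
  ^bb0(%arg1: f32, %arg2: f32):
    linalg.yield %arg1 : f32
  } -> tensor<3x2xf32>

It is replaced by a single constant dense<[[0.0, 3.0], [1.0, 4.0], [2.0, 5.0]]> : tensor<3x2xf32>, matching the values checked in the test.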
Reviewed By: mravishankar
Differential Revision: https://reviews.llvm.org/D110597
Added:
Modified:
mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
mlir/test/Dialect/Linalg/fusion-elementwise-ops.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
index 5f721e2098ee..96f7a1c3796e 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
@@ -1164,10 +1164,11 @@ struct FoldReshapeWithGenericOpByExpansion
/// Pattern to fold a generic op with a splat constant/scalar constant. Does not
/// handle cases where the constant is not single-valued.
-class FoldConstants : public OpRewritePattern<GenericOp> {
+class FoldScalarOrSplatConstant : public OpRewritePattern<GenericOp> {
public:
- FoldConstants(MLIRContext *context, ControlElementwiseOpsFusionFn &fun,
- PatternBenefit benefit = 1)
+ FoldScalarOrSplatConstant(MLIRContext *context,
+ ControlElementwiseOpsFusionFn &fun,
+ PatternBenefit benefit = 1)
: OpRewritePattern<GenericOp>(context, benefit), controlFn(fun) {}
LogicalResult matchAndRewrite(GenericOp genericOp,
@@ -1268,6 +1269,237 @@ class FoldConstants : public OpRewritePattern<GenericOp> {
private:
ControlElementwiseOpsFusionFn controlFn;
};
+
+/// Base class for constant folding linalg.generic ops with N inputs, 1 output,
+/// and permutation indexing maps.
+///
+/// `ConcreteType` should provide methods with signatures
+///
+/// ```c++
+/// bool matchIndexingMaps(GenericOp genericOp) const;
+/// RegionComputationFn getRegionComputeFn(GenericOp) const;
+/// ```
+///
+/// The latter inspects the region and returns the computation inside as a
+/// functor. The functor will be invoked with constant elements for all inputs
+/// and should return the corresponding computed constant element for the output.
+template <typename ConcreteType>
+class FoldConstantBase : public OpRewritePattern<GenericOp> {
+public:
+ struct APIntOrFloatArray {
+ SmallVector<APInt> apInts;
+ SmallVector<APFloat> apFloats;
+ };
+ using RegionComputationFn =
+ std::function<APIntOrFloatArray(APIntOrFloatArray)>;
+
+ FoldConstantBase(MLIRContext *context,
+ const ControlElementwiseOpsFusionFn &controlFn,
+ PatternBenefit benefit = 1)
+ : OpRewritePattern<GenericOp>(context, benefit), controlFn(controlFn) {}
+
+ LogicalResult matchAndRewrite(GenericOp genericOp,
+ PatternRewriter &rewriter) const override {
+ if (genericOp.hasBufferSemantics())
+ return failure();
+
+ // Only support ops generating one output for now.
+ if (genericOp.getNumOutputs() != 1)
+ return failure();
+
+ auto outputType = genericOp.getResultTypes().front().dyn_cast<ShapedType>();
+ // Require the output types to be static given we are generating constants.
+ if (!outputType || !outputType.hasStaticShape())
+ return failure();
+
+ if (!llvm::all_of(genericOp.getInputOperands(), [](OpOperand *operand) {
+ return operand->get().getType().isa<ShapedType>();
+ }))
+ return failure();
+
+ // Make sure all element types are the same.
+ auto getOperandElementType = [](OpOperand *operand) {
+ return operand->get().getType().cast<ShapedType>().getElementType();
+ };
+ if (!llvm::is_splat(llvm::map_range(genericOp.getInputAndOutputOperands(),
+ getOperandElementType)))
+ return failure();
+
+ // We can only handle the case where we have int/float elements.
+ auto elementType = outputType.getElementType();
+ if (!elementType.isIntOrFloat())
+ return failure();
+
+ // Require all indexing maps to be permutations for now. This is common and
+ // it simplifies input/output access greatly: we can do the data shuffling
+ // entirely in the compiler, without needing to turn all indices into
+ // Values, apply affine maps to them, and then match the results back to
+ // constants.
+ if (!llvm::all_of(genericOp.getIndexingMaps(),
+ [](AffineMap map) { return map.isPermutation(); }))
+ return failure();
+
+ for (OpOperand *operand : genericOp.getOutputOperands()) {
+ if (genericOp.payloadUsesValueFromOperand(operand))
+ return failure();
+ }
+
+ // Further check the indexing maps are okay for the ConcreteType.
+ if (!static_cast<const ConcreteType *>(this)->matchIndexingMaps(genericOp))
+ return failure();
+
+ // Defer to the concrete type to check the region and discover the
+ // computation inside.
+ RegionComputationFn computeFn =
+ static_cast<const ConcreteType *>(this)->getRegionComputeFn(genericOp);
+ if (!computeFn)
+ return failure();
+
+ // All inputs should be constants.
+ int numInputs = genericOp.getNumInputs();
+ SmallVector<DenseIntOrFPElementsAttr> inputValues(numInputs);
+ for (auto operand : llvm::enumerate(genericOp.getInputOperands())) {
+ if (!matchPattern(operand.value()->get(),
+ m_Constant(&inputValues[operand.index()])))
+ return failure();
+ }
+
+ // Identified this as a potential candidate for folding. Now check the
+ // policy to see whether we are allowed to proceed.
+ for (int i = 0; i < numInputs; ++i) {
+ OpOperand *consumer = genericOp.getInputOperand(i);
+ OpResult producer = consumer->get().cast<OpResult>();
+ if (!controlFn(producer, *consumer))
+ return failure();
+ }
+
+ auto linalgOp = cast<LinalgOp>(genericOp.getOperation());
+ SmallVector<int64_t, 4> loopBounds = linalgOp.computeStaticLoopSizes();
+ int64_t numElements = outputType.getNumElements();
+
+ // Use APInt/APFloat instead of Attribute here for constructing the output.
+ // This helps to avoid blowing up compiler memory usage: Attributes would
+ // unify the cases below, but they live as long as the MLIRContext.
+ SmallVector<APInt> intOutputValues;
+ SmallVector<APFloat> fpOutputValues;
+ if (elementType.template isa<FloatType>())
+ fpOutputValues.resize(numElements, APFloat(0.f));
+ else
+ intOutputValues.resize(numElements);
+
+ // Return the constant dim positions from the given permutation map.
+ auto getDimPositions = [](AffineMap map) {
+ SmallVector<unsigned> dims;
+ dims.reserve(map.getNumResults());
+ for (AffineExpr result : map.getResults()) {
+ dims.push_back(result.cast<AffineDimExpr>().getPosition());
+ }
+ return dims;
+ };
+
+ SmallVector<SmallVector<unsigned>> inputDims;
+ for (int i = 0; i < numInputs; ++i)
+ inputDims.push_back(getDimPositions(genericOp.getIndexingMaps()[i]));
+ auto outputDims = getDimPositions(genericOp.getIndexingMaps().back());
+ auto outputShape = outputType.getShape();
+
+ // Transpose the input constant. Because we don't know its rank in advance,
+ // we need to loop over the range [0, element count) and delinearize the
+ // index.
+ for (int linearIndex0 = 0; linearIndex0 < numElements; ++linearIndex0) {
+ SmallVector<uint64_t> indices(loopBounds.size(), 0);
+ int totalCount = linearIndex0;
+ for (int dim = loopBounds.size() - 1; dim >= 0; --dim) {
+ indices[dim] = totalCount % loopBounds[dim];
+ totalCount /= loopBounds[dim];
+ }
+
+ SmallVector<SmallVector<uint64_t>> srcIndices;
+ for (int i = 0; i < numInputs; ++i)
+ srcIndices.emplace_back(loopBounds.size(), 0);
+ SmallVector<uint64_t> dstIndices(loopBounds.size(), 0);
+
+ for (int dim = loopBounds.size() - 1; dim >= 0; --dim) {
+ for (int i = 0; i < numInputs; ++i)
+ srcIndices[i][dim] = indices[inputDims[i][dim]];
+ dstIndices[dim] = indices[outputDims[dim]];
+ }
+
+ uint64_t linearIndex1 = dstIndices.front();
+ for (int dim = 1; dim < outputType.getRank(); ++dim)
+ linearIndex1 = linearIndex1 * outputShape[dim] + dstIndices[dim];
+
+ // Collect constant elements for all inputs at this loop iteration.
+ SmallVector<APInt> intValues;
+ SmallVector<APFloat> fpValues;
+ if (elementType.isa<FloatType>()) {
+ for (int i = 0; i < numInputs; ++i)
+ fpValues.push_back(inputValues[i].getValue<APFloat>(srcIndices[i]));
+ } else {
+ for (int i = 0; i < numInputs; ++i)
+ intValues.push_back(inputValues[i].getValue<APInt>(srcIndices[i]));
+ }
+
+ // Invoke the computation to get the corresponding constant output
+ // element.
+ APIntOrFloatArray inputs = {intValues, fpValues};
+ APIntOrFloatArray outputs = computeFn(inputs);
+
+ if (elementType.isa<FloatType>()) {
+ fpOutputValues[linearIndex1] = outputs.apFloats.front();
+ } else {
+ intOutputValues[linearIndex1] = outputs.apInts.front();
+ }
+ }
+
+ DenseIntOrFPElementsAttr outputAttr;
+ if (elementType.isa<FloatType>()) {
+ outputAttr = DenseFPElementsAttr::get(outputType, fpOutputValues);
+ } else {
+ outputAttr = DenseIntElementsAttr::get(outputType, intOutputValues);
+ }
+ rewriter.replaceOpWithNewOp<ConstantOp>(genericOp, outputAttr);
+ return success();
+ }
+
+private:
+ ControlElementwiseOpsFusionFn controlFn;
+};
+
+// Folds linalg.generic ops that are actually transposes on constant values.
+struct FoldConstantTranspose : public FoldConstantBase<FoldConstantTranspose> {
+ using FoldConstantBase::FoldConstantBase;
+
+ bool matchIndexingMaps(GenericOp genericOp) const {
+ // We should have one input and one output.
+ return genericOp.getIndexingMaps().size() == 2;
+ }
+
+ RegionComputationFn getRegionComputeFn(GenericOp genericOp) const {
+ // Make sure the region only contains a yield op.
+ Block &body = genericOp.region().front();
+ if (!llvm::hasSingleElement(body))
+ return nullptr;
+ auto yieldOp = dyn_cast<linalg::YieldOp>(body.getTerminator());
+ if (!yieldOp)
+ return nullptr;
+
+ // The yield op should return the block argument corresponding to the input.
+ for (Value yieldVal : yieldOp.values()) {
+ auto yieldArg = yieldVal.dyn_cast<BlockArgument>();
+ if (!yieldArg || yieldArg.getOwner() != &body)
+ return nullptr;
+ if (yieldArg.getArgNumber() != 0)
+ return nullptr;
+ }
+
+ // No computation; just return the original value.
+ return [](APIntOrFloatArray inputs) { return inputs; };
+ }
+
+ ControlElementwiseOpsFusionFn controlFn;
+};
+
} // namespace
static Optional<SmallVector<Value>>
@@ -1442,8 +1674,9 @@ void mlir::linalg::populateFoldReshapeOpsByExpansionPatterns(
void mlir::linalg::populateElementwiseOpsFusionPatterns(
RewritePatternSet &patterns, LinalgElementwiseFusionOptions options) {
auto *context = patterns.getContext();
- patterns.add<FuseElementwiseOps, FoldConstants>(
- context, options.controlElementwiseOpsFusionFn);
+ patterns.add<FuseElementwiseOps, FoldScalarOrSplatConstant,
+ FoldConstantTranspose>(context,
+ options.controlElementwiseOpsFusionFn);
patterns.add<RemoveOutsDependency>(context);
populateFoldReshapeOpsByExpansionPatterns(patterns,
options.controlFoldingReshapesFn);
diff --git a/mlir/test/Dialect/Linalg/fusion-elementwise-ops.mlir b/mlir/test/Dialect/Linalg/fusion-elementwise-ops.mlir
index 985335a5f952..d72af1253866 100644
--- a/mlir/test/Dialect/Linalg/fusion-elementwise-ops.mlir
+++ b/mlir/test/Dialect/Linalg/fusion-elementwise-ops.mlir
@@ -755,15 +755,15 @@ func @fuse_scalar_constant(%arg0 : tensor<?x?xf32>) -> (tensor<?x?xf32>, tensor<
%2:2 = linalg.generic {
indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>,
affine_map<(d0, d1) -> ()>,
- affine_map<(d0, d1) -> ()>,
- affine_map<(d0, d1) -> (d0, d1)>,
- affine_map<(d0, d1) -> (d0, d1)>],
+ affine_map<(d0, d1) -> ()>,
+ affine_map<(d0, d1) -> (d0, d1)>,
+ affine_map<(d0, d1) -> (d0, d1)>],
iterator_types = ["parallel", "parallel"]}
ins(%arg0, %cst, %c42 : tensor<?x?xf32>, f32, i32)
outs(%0, %1 : tensor<?x?xf32>, tensor<?x?xi32>) {
^bb0(%arg1 : f32, %arg2 : f32, %arg3 : i32, %arg4 : f32, %arg5 : i32) :
%3 = addf %arg1, %arg2 : f32
- linalg.yield %3, %arg3 : f32, i32
+ linalg.yield %3, %arg3 : f32, i32
} -> (tensor<?x?xf32>, tensor<?x?xi32>)
return %2#0, %2#1 : tensor<?x?xf32>, tensor<?x?xi32>
}
@@ -774,3 +774,136 @@ func @fuse_scalar_constant(%arg0 : tensor<?x?xf32>) -> (tensor<?x?xf32>, tensor<
// CHECK-SAME: ins(%{{.+}} : tensor<?x?xf32>)
// CHECK: %[[YIELD:.+]] = addf %{{.+}}, %[[CST]] : f32
// CHECK: linalg.yield %[[YIELD]], %[[C42]] : f32, i32
+
+// -----
+
+// CHECK-LABEL: @transpose_fold_2d_fp32
+func @transpose_fold_2d_fp32(%init: tensor<3x2xf32>) -> tensor<3x2xf32> {
+ %input = constant dense<[[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]> : tensor<2x3xf32>
+ // CHECK: %[[CST:.+]] = constant
+ // CHECK-SAME{LITERAL}: dense<[[0.000000e+00, 3.000000e+00], [1.000000e+00, 4.000000e+00], [2.000000e+00, 5.000000e+00]]> : tensor<3x2xf32>
+ %1 = linalg.generic {
+ indexing_maps = [affine_map<(d0, d1) -> (d1, d0)>, affine_map<(d0, d1) -> (d0, d1)>],
+ iterator_types = ["parallel", "parallel"]
+ } ins(%input : tensor<2x3xf32>) outs(%init : tensor<3x2xf32>) {
+ ^bb0(%arg1: f32, %arg2: f32):
+ linalg.yield %arg1 : f32
+ } -> tensor<3x2xf32>
+ // CHECK: return %[[CST]]
+ return %1 : tensor<3x2xf32>
+}
+
+// -----
+
+// CHECK-LABEL: @transpose_fold_2d_fp64
+func @transpose_fold_2d_fp64(%init: tensor<3x2xf64>) -> tensor<3x2xf64> {
+ %input = constant dense<[[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]> : tensor<2x3xf64>
+ // CHECK: %[[CST:.+]] = constant
+ // CHECK-SAME{LITERAL}: dense<[[0.000000e+00, 3.000000e+00], [1.000000e+00, 4.000000e+00], [2.000000e+00, 5.000000e+00]]> : tensor<3x2xf64>
+ %1 = linalg.generic {
+ indexing_maps = [affine_map<(d0, d1) -> (d1, d0)>, affine_map<(d0, d1) -> (d0, d1)>],
+ iterator_types = ["parallel", "parallel"]
+ } ins(%input : tensor<2x3xf64>) outs(%init : tensor<3x2xf64>) {
+ ^bb0(%arg1: f64, %arg2: f64):
+ linalg.yield %arg1 : f64
+ } -> tensor<3x2xf64>
+ // CHECK: return %[[CST]]
+ return %1 : tensor<3x2xf64>
+}
+
+// -----
+
+// CHECK-LABEL: @transpose_fold_4d_i32
+func @transpose_fold_4d_i32(%init: tensor<3x1x4x2xi32>) -> tensor<3x1x4x2xi32> {
+ %input = constant dense<[[
+ [[ 0, 1, 2, 3], [ 4, 5, 6, 7], [ 8, 9, 10, 11]],
+ [[12, 13, 14, 15], [16, 17, 18, 19], [20, 21, 22, 23]]
+ ]]> : tensor<1x2x3x4xi32>
+ // CHECK: %[[CST:.+]] = constant dense<[
+ // CHECK-SAME{LITERAL}: [[[0, 12], [1, 13], [2, 14], [3, 15]]],
+ // CHECK-SAME{LITERAL}: [[[4, 16], [5, 17], [6, 18], [7, 19]]],
+ // CHECK-SAME{LITERAL}: [[[8, 20], [9, 21], [10, 22], [11, 23]]]
+ // CHECK-SAME{LITERAL}: ]>
+ %1 = linalg.generic {
+ indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d2, d0, d3, d1)>],
+ iterator_types = ["parallel", "parallel", "parallel", "parallel"]
+ } ins(%input : tensor<1x2x3x4xi32>) outs(%init : tensor<3x1x4x2xi32>) {
+ ^bb0(%arg1: i32, %arg2: i32):
+ linalg.yield %arg1 : i32
+ } -> tensor<3x1x4x2xi32>
+ // CHECK: return %[[CST]]
+ return %1 : tensor<3x1x4x2xi32>
+}
+
+// -----
+
+// CHECK-LABEL: @transpose_fold_4d_i16
+func @transpose_fold_4d_i16(%init: tensor<3x1x4x2xi16>) -> tensor<3x1x4x2xi16> {
+ %input = constant dense<[[
+ [[ 0, 1, 2, 3], [ 4, 5, 6, 7], [ 8, 9, 10, 11]],
+ [[12, 13, 14, 15], [16, 17, 18, 19], [20, 21, 22, 23]]
+ ]]> : tensor<1x2x3x4xi16>
+ // CHECK: %[[CST:.+]] = constant dense<[
+ // CHECK-SAME{LITERAL}: [[[0, 12], [1, 13], [2, 14], [3, 15]]],
+ // CHECK-SAME{LITERAL}: [[[4, 16], [5, 17], [6, 18], [7, 19]]],
+ // CHECK-SAME{LITERAL}: [[[8, 20], [9, 21], [10, 22], [11, 23]]]
+ // CHECK-SAME{LITERAL}: ]>
+ %1 = linalg.generic {
+ indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d2, d0, d3, d1)>],
+ iterator_types = ["parallel", "parallel", "parallel", "parallel"]
+ } ins(%input : tensor<1x2x3x4xi16>) outs(%init : tensor<3x1x4x2xi16>) {
+ ^bb0(%arg1: i16, %arg2: i16):
+ linalg.yield %arg1 : i16
+ } -> tensor<3x1x4x2xi16>
+ // CHECK: return %[[CST]]
+ return %1 : tensor<3x1x4x2xi16>
+}
+
+// -----
+
+// CHECK-LABEL: @transpose_nofold_non_cst_input
+func @transpose_nofold_non_cst_input(%input: tensor<2x3xf32>, %init: tensor<3x2xf32>) -> tensor<3x2xf32> {
+ // CHECK: linalg.generic
+ %1 = linalg.generic {
+ indexing_maps = [affine_map<(d0, d1) -> (d1, d0)>, affine_map<(d0, d1) -> (d0, d1)>],
+ iterator_types = ["parallel", "parallel"]
+ } ins(%input : tensor<2x3xf32>) outs(%init : tensor<3x2xf32>) {
+ ^bb0(%arg1: f32, %arg2: f32):
+ linalg.yield %arg1 : f32
+ } -> tensor<3x2xf32>
+ return %1 : tensor<3x2xf32>
+}
+
+// -----
+
+// CHECK-LABEL: @transpose_nofold_yield_const
+func @transpose_nofold_yield_const(%init: tensor<3x2xf32>) -> tensor<3x2xf32> {
+ %input = constant dense<[[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]> : tensor<2x3xf32>
+ %cst = constant 8.0 : f32
+ // CHECK: linalg.generic
+ %1 = linalg.generic {
+ indexing_maps = [affine_map<(d0, d1) -> (d1, d0)>, affine_map<(d0, d1) -> (d0, d1)>],
+ iterator_types = ["parallel", "parallel"]
+ } ins(%input : tensor<2x3xf32>) outs(%init : tensor<3x2xf32>) {
+ ^bb0(%arg1: f32, %arg2: f32):
+ linalg.yield %cst : f32
+ } -> tensor<3x2xf32>
+ return %1 : tensor<3x2xf32>
+}
+
+// -----
+
+// CHECK-LABEL: @transpose_nofold_multi_ops_in_region
+func @transpose_nofold_multi_ops_in_region(%init: tensor<3x2xf32>) -> tensor<3x2xf32> {
+ %input = constant dense<[[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]> : tensor<2x3xf32>
+ // CHECK: linalg.generic
+ %1 = linalg.generic {
+ indexing_maps = [affine_map<(d0, d1) -> (d1, d0)>, affine_map<(d0, d1) -> (d0, d1)>],
+ iterator_types = ["parallel", "parallel"]
+ } ins(%input : tensor<2x3xf32>) outs(%init : tensor<3x2xf32>) {
+ ^bb0(%arg1: f32, %arg2: f32):
+ %add = addf %arg1, %arg1 : f32
+ linalg.yield %add : f32
+ } -> tensor<3x2xf32>
+ return %1 : tensor<3x2xf32>
+}
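The core of FoldConstantBase above is plain index arithmetic: each linear index into the iteration space is delinearized against the static loop bounds, shuffled through the permutation maps attached to the inputs and the output, and re-linearized into the output buffer. Below is a minimal standalone sketch of that arithmetic; it is not part of the commit and the helper names are illustrative only:

  #include <cstdint>
  #include <vector>

  // Delinearize a linear iteration index into per-dimension loop indices,
  // assuming row-major order (innermost loop last), mirroring the loop in
  // FoldConstantBase::matchAndRewrite.
  std::vector<uint64_t> delinearize(int64_t linearIndex,
                                    const std::vector<int64_t> &loopBounds) {
    std::vector<uint64_t> indices(loopBounds.size(), 0);
    int64_t remaining = linearIndex;
    for (int dim = static_cast<int>(loopBounds.size()) - 1; dim >= 0; --dim) {
      indices[dim] = remaining % loopBounds[dim];
      remaining /= loopBounds[dim];
    }
    return indices;
  }

  // For a permutation indexing map, dimPositions[d] is the loop dimension read
  // by the map's d-th result (what getDimPositions extracts above); applying it
  // turns loop indices into operand element indices.
  std::vector<uint64_t> permute(const std::vector<uint64_t> &loopIndices,
                                const std::vector<unsigned> &dimPositions) {
    std::vector<uint64_t> indices(dimPositions.size());
    for (size_t d = 0; d < dimPositions.size(); ++d)
      indices[d] = loopIndices[dimPositions[d]];
    return indices;
  }

  // Re-linearize element indices against a static output shape, row-major.
  uint64_t linearize(const std::vector<uint64_t> &indices,
                     const std::vector<int64_t> &shape) {
    uint64_t linear = indices.front();
    for (size_t dim = 1; dim < shape.size(); ++dim)
      linear = linear * shape[dim] + indices[dim];
    return linear;
  }

Reading input i at permute(loopIndices, inputDims[i]) and writing the computed element at linearize(permute(loopIndices, outputDims), outputShape) reproduces the data shuffling the pattern performs, without ever materializing index Values.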