[Mlir-commits] [mlir] [mlir][linalg] Add linalg.transpose constant folding (PR #92589)
Ryan Holt
llvmlistbot at llvm.org
Fri May 17 11:53:06 PDT 2024
https://github.com/ryan-holt-1 created https://github.com/llvm/llvm-project/pull/92589
There was existing support for constant folding a `linalg.generic` that was actually a transpose. This commit adds support for the named op `linalg.transpose` as well, by making use of the `LinalgOp` interface.
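For context, a minimal sketch of the two forms this pattern now covers (the `linalg.generic` form mirrors the existing tests in `fusion-elementwise-ops.mlir`; `%init` and the constant values are illustrative):

```mlir
%cst = arith.constant dense<[[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]> : tensor<2x3xf32>
%init = tensor.empty() : tensor<3x2xf32>

// Transpose written as a linalg.generic; folded before this patch.
%0 = linalg.generic {
    indexing_maps = [affine_map<(d0, d1) -> (d1, d0)>,
                     affine_map<(d0, d1) -> (d0, d1)>],
    iterator_types = ["parallel", "parallel"]}
    ins(%cst : tensor<2x3xf32>) outs(%init : tensor<3x2xf32>) {
  ^bb0(%in: f32, %out: f32):
    linalg.yield %in : f32
} -> tensor<3x2xf32>

// The equivalent named op; now folded as well via the LinalgOp interface.
%1 = linalg.transpose ins(%cst : tensor<2x3xf32>)
                      outs(%init : tensor<3x2xf32>) permutation = [1, 0]
```

Both forms fold to a single `arith.constant dense<[[0.0, 3.0], [1.0, 4.0], [2.0, 5.0]]> : tensor<3x2xf32>`, as the new test in this patch checks for the named op.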
From 0f644420e6627c17ff76ca42b29f3ccf867aa42d Mon Sep 17 00:00:00 2001
From: ryan-holt-1 <ryanholt at mathworks.com>
Date: Fri, 17 May 2024 14:35:15 -0400
Subject: [PATCH] [mlir][linalg] Add linalg.transpose constant folding
There was existing support for constant folding a `linalg.generic` that
was actually a transpose. This commit adds support for the named op
`linalg.transpose` as well, by making use of the `LinalgOp` interface.
---
.../Linalg/Transforms/ConstantFold.cpp | 62 ++++++++++---------
.../Linalg/fusion-elementwise-ops.mlir | 12 ++++
2 files changed, 44 insertions(+), 30 deletions(-)
diff --git a/mlir/lib/Dialect/Linalg/Transforms/ConstantFold.cpp b/mlir/lib/Dialect/Linalg/Transforms/ConstantFold.cpp
index 8fffabf11f3fd..2e6079e1402e1 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/ConstantFold.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/ConstantFold.cpp
@@ -23,21 +23,21 @@ using namespace mlir;
using namespace mlir::linalg;
namespace {
-/// Base class for constant folding linalg.generic ops with N inputs, 1 output,
-/// and permutation indexing maps.
+/// Base class for constant folding linalg structured ops with N inputs, 1
+/// output, and permutation indexing maps.
///
/// `ConcreteType` should provide methods with signatures
///
/// ```c++
-/// bool matchIndexingMaps(GenericOp genericOp) const;
-/// RegionComputationFn getRegionComputeFn(GenericOp) const;
+/// bool matchIndexingMaps(LinalgOp linalgOp) const;
+/// RegionComputationFn getRegionComputeFn(LinalgOp) const;
/// ```
///
/// The latter inspects the region and returns the computation inside as a
/// functor. The functor will be invoked with constant elements for all inputs
/// and should return the corresponding computed constant element for output.
template <typename ConcreteType>
-class FoldConstantBase : public OpRewritePattern<GenericOp> {
+class FoldConstantBase : public OpInterfaceRewritePattern<LinalgOp> {
public:
struct APIntOrFloat {
std::optional<APInt> apInt;
@@ -52,25 +52,26 @@ class FoldConstantBase : public OpRewritePattern<GenericOp> {
FoldConstantBase(MLIRContext *context, const ControlFusionFn &controlFn,
PatternBenefit benefit = 1)
- : OpRewritePattern<GenericOp>(context, benefit), controlFn(controlFn) {}
+ : OpInterfaceRewritePattern<LinalgOp>(context, benefit),
+ controlFn(controlFn) {}
- LogicalResult matchAndRewrite(GenericOp genericOp,
+ LogicalResult matchAndRewrite(LinalgOp linalgOp,
PatternRewriter &rewriter) const override {
// Mixed and buffer semantics aren't supported.
- if (!genericOp.hasPureTensorSemantics())
+ if (!linalgOp.hasPureTensorSemantics())
return failure();
// Only support ops generating one output for now.
- if (genericOp.getNumDpsInits() != 1)
+ if (linalgOp.getNumDpsInits() != 1)
return failure();
- auto outputType = dyn_cast<ShapedType>(genericOp.getResultTypes().front());
+ auto outputType = dyn_cast<ShapedType>(linalgOp->getResultTypes().front());
// Require the output types to be static given that we are generating
// constants.
if (!outputType || !outputType.hasStaticShape())
return failure();
- if (!llvm::all_of(genericOp.getInputs(), [](Value input) {
+ if (!llvm::all_of(linalgOp.getDpsInputs(), [](Value input) {
return isa<ShapedType>(input.getType());
}))
return failure();
@@ -80,7 +81,7 @@ class FoldConstantBase : public OpRewritePattern<GenericOp> {
return cast<ShapedType>(value.getType()).getElementType();
};
if (!llvm::all_equal(
- llvm::map_range(genericOp->getOperands(), getOperandElementType)))
+ llvm::map_range(linalgOp->getOperands(), getOperandElementType)))
return failure();
// We can only handle the case where we have int/float elements.
@@ -93,30 +94,30 @@ class FoldConstantBase : public OpRewritePattern<GenericOp> {
// entirely in the compiler, without needing to turn all indices into
// Values, and then do affine apply on them, and then match back the
// constant again.
- if (!llvm::all_of(genericOp.getIndexingMapsArray(),
+ if (!llvm::all_of(linalgOp.getIndexingMapsArray(),
[](AffineMap map) { return map.isPermutation(); }))
return failure();
- for (OpOperand &operand : genericOp.getDpsInitsMutable()) {
- if (genericOp.payloadUsesValueFromOperand(&operand))
+ for (OpOperand &operand : linalgOp.getDpsInitsMutable()) {
+ if (linalgOp.payloadUsesValueFromOperand(&operand))
return failure();
}
// Further check the indexing maps are okay for the ConcreteType.
- if (!static_cast<const ConcreteType *>(this)->matchIndexingMaps(genericOp))
+ if (!static_cast<const ConcreteType *>(this)->matchIndexingMaps(linalgOp))
return failure();
// Defer to the concrete type to check the region and discover the
// computation inside.
RegionComputationFn computeFn =
- static_cast<const ConcreteType *>(this)->getRegionComputeFn(genericOp);
+ static_cast<const ConcreteType *>(this)->getRegionComputeFn(linalgOp);
if (!computeFn)
return failure();
// All inputs should be constants.
- int numInputs = genericOp.getNumDpsInputs();
+ int numInputs = linalgOp.getNumDpsInputs();
SmallVector<DenseIntOrFPElementsAttr> inputValues(numInputs);
- for (const auto &en : llvm::enumerate(genericOp.getDpsInputOperands())) {
+ for (const auto &en : llvm::enumerate(linalgOp.getDpsInputOperands())) {
if (!matchPattern(en.value()->get(),
m_Constant(&inputValues[en.index()])))
return failure();
@@ -124,12 +125,11 @@ class FoldConstantBase : public OpRewritePattern<GenericOp> {
// Identified this as a potential candidate for folding. Now check the
// policy to see whether we are allowed to proceed.
- for (OpOperand *operand : genericOp.getDpsInputOperands()) {
+ for (OpOperand *operand : linalgOp.getDpsInputOperands()) {
if (!controlFn(operand))
return failure();
}
- auto linalgOp = cast<LinalgOp>(genericOp.getOperation());
SmallVector<int64_t, 4> loopBounds = linalgOp.computeStaticLoopSizes();
int64_t numElements = outputType.getNumElements();
@@ -155,8 +155,8 @@ class FoldConstantBase : public OpRewritePattern<GenericOp> {
SmallVector<SmallVector<unsigned>> inputDims;
for (int i = 0; i < numInputs; ++i)
- inputDims.push_back(getDimPositions(genericOp.getIndexingMapsArray()[i]));
- auto outputDims = getDimPositions(genericOp.getIndexingMapsArray().back());
+ inputDims.push_back(getDimPositions(linalgOp.getIndexingMapsArray()[i]));
+ auto outputDims = getDimPositions(linalgOp.getIndexingMapsArray().back());
auto outputShape = outputType.getShape();
// Allocate small vectors for index delinearization. Initial values do not
@@ -173,7 +173,7 @@ class FoldConstantBase : public OpRewritePattern<GenericOp> {
APIntOrFloatArray computeFnInputs;
auto inputShapes = llvm::to_vector<4>(
- llvm::map_range(genericOp.getInputs(), [](Value value) {
+ llvm::map_range(linalgOp.getDpsInputs(), [](Value value) {
return cast<ShapedType>(value.getType()).getShape();
}));
@@ -254,7 +254,7 @@ class FoldConstantBase : public OpRewritePattern<GenericOp> {
isFloat ? DenseElementsAttr::get(outputType, fpOutputValues)
: DenseElementsAttr::get(outputType, intOutputValues);
- rewriter.replaceOpWithNewOp<arith::ConstantOp>(genericOp, outputAttr);
+ rewriter.replaceOpWithNewOp<arith::ConstantOp>(linalgOp, outputAttr);
return success();
}
@@ -262,18 +262,20 @@ class FoldConstantBase : public OpRewritePattern<GenericOp> {
ControlFusionFn controlFn;
};
-// Folds linalg.generic ops that are actually transposes on constant values.
+// Folds linalg.transpose (and linalg.generic ops that are actually transposes)
+// on constant values.
struct FoldConstantTranspose : public FoldConstantBase<FoldConstantTranspose> {
+
using FoldConstantBase::FoldConstantBase;
- bool matchIndexingMaps(GenericOp genericOp) const {
+ bool matchIndexingMaps(LinalgOp linalgOp) const {
// We should have one input and one output.
- return genericOp.getIndexingMapsArray().size() == 2;
+ return linalgOp.getIndexingMapsArray().size() == 2;
}
- RegionComputationFn getRegionComputeFn(GenericOp genericOp) const {
+ RegionComputationFn getRegionComputeFn(LinalgOp linalgOp) const {
// Make sure the region only contains a yield op.
- Block &body = genericOp.getRegion().front();
+ Block &body = linalgOp->getRegion(0).front();
if (!llvm::hasSingleElement(body))
return nullptr;
auto yieldOp = dyn_cast<linalg::YieldOp>(body.getTerminator());
diff --git a/mlir/test/Dialect/Linalg/fusion-elementwise-ops.mlir b/mlir/test/Dialect/Linalg/fusion-elementwise-ops.mlir
index 15a4f6cdd3bbe..70f43885712b7 100644
--- a/mlir/test/Dialect/Linalg/fusion-elementwise-ops.mlir
+++ b/mlir/test/Dialect/Linalg/fusion-elementwise-ops.mlir
@@ -910,6 +910,18 @@ func.func @transpose_nofold_multi_ops_in_region(%init: tensor<3x2xf32>) -> tenso
// -----
+// CHECK-LABEL: @named_transpose_fold_2d_fp32
+func.func @named_transpose_fold_2d_fp32(%init: tensor<3x2xf32>) -> tensor<3x2xf32> {
+ %input = arith.constant dense<[[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]> : tensor<2x3xf32>
+ // CHECK: %[[CST:.+]] = arith.constant
+ // CHECK-SAME{LITERAL}: dense<[[0.000000e+00, 3.000000e+00], [1.000000e+00, 4.000000e+00], [2.000000e+00, 5.000000e+00]]> : tensor<3x2xf32>
+ %1 = linalg.transpose ins(%input : tensor<2x3xf32>) outs(%init : tensor<3x2xf32>) permutation = [1, 0]
+ // CHECK: return %[[CST]]
+ return %1 : tensor<3x2xf32>
+}
+
+// -----
+
// Fusing the broadcast into a reduction would require to insert extra knowledge
// about the size of the reduction dimension. As long, as this is not
// implemented, we check that two linalg operations remain.