[Mlir-commits] [mlir] [mlir][linalg] Add linalg.transpose constant folding (PR #92589)

Ryan Holt llvmlistbot at llvm.org
Fri May 17 11:53:06 PDT 2024


https://github.com/ryan-holt-1 created https://github.com/llvm/llvm-project/pull/92589

Constant folding was already supported for a `linalg.generic` op that is actually a transpose. This commit extends that support to the named op, `linalg.transpose`, by matching the folding pattern on the `LinalgOp` interface instead of on `GenericOp`.

>From 0f644420e6627c17ff76ca42b29f3ccf867aa42d Mon Sep 17 00:00:00 2001
From: ryan-holt-1 <ryanholt at mathworks.com>
Date: Fri, 17 May 2024 14:35:15 -0400
Subject: [PATCH] [mlir][linalg] Add linalg.transpose constant folding

There was existing support for constant folding a `linalg.generic` that
was actually a transpose. This commit adds support for the named op,
`linalg.transpose`, as well by making use of the `LinalgOp` interface.
---
 .../Linalg/Transforms/ConstantFold.cpp        | 62 ++++++++++---------
 .../Linalg/fusion-elementwise-ops.mlir        | 12 ++++
 2 files changed, 44 insertions(+), 30 deletions(-)

diff --git a/mlir/lib/Dialect/Linalg/Transforms/ConstantFold.cpp b/mlir/lib/Dialect/Linalg/Transforms/ConstantFold.cpp
index 8fffabf11f3fd..2e6079e1402e1 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/ConstantFold.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/ConstantFold.cpp
@@ -23,21 +23,21 @@ using namespace mlir;
 using namespace mlir::linalg;
 
 namespace {
-/// Base class for constant folding linalg.generic ops with N inputs, 1 output,
-/// and permutation indexing maps.
+/// Base class for constant folding linalg structured ops with N inputs, 1
+/// output, and permutation indexing maps.
 ///
 /// `ConcreteType` should provide methods with signatures
 ///
 /// ```c++
-///   bool matchIndexingMaps(GenericOp genericOp) const;
-///   RegionComputationFn getRegionComputeFn(GenericOp) const;
+///   bool matchIndexingMaps(LinalgOp linalgOp) const;
+///   RegionComputationFn getRegionComputeFn(LinalgOp) const;
 /// ```
 ///
 /// The latter inspects the region and returns the computation inside as a
 /// functor. The functor will be invoked with constant elements for all inputs
 /// and should return the corresponding computed constant element for output.
 template <typename ConcreteType>
-class FoldConstantBase : public OpRewritePattern<GenericOp> {
+class FoldConstantBase : public OpInterfaceRewritePattern<LinalgOp> {
 public:
   struct APIntOrFloat {
     std::optional<APInt> apInt;
@@ -52,25 +52,26 @@ class FoldConstantBase : public OpRewritePattern<GenericOp> {
 
   FoldConstantBase(MLIRContext *context, const ControlFusionFn &controlFn,
                    PatternBenefit benefit = 1)
-      : OpRewritePattern<GenericOp>(context, benefit), controlFn(controlFn) {}
+      : OpInterfaceRewritePattern<LinalgOp>(context, benefit),
+        controlFn(controlFn) {}
 
-  LogicalResult matchAndRewrite(GenericOp genericOp,
+  LogicalResult matchAndRewrite(LinalgOp linalgOp,
                                 PatternRewriter &rewriter) const override {
     // Mixed and buffer sematics aren't supported.
-    if (!genericOp.hasPureTensorSemantics())
+    if (!linalgOp.hasPureTensorSemantics())
       return failure();
 
     // Only support ops generating one output for now.
-    if (genericOp.getNumDpsInits() != 1)
+    if (linalgOp.getNumDpsInits() != 1)
       return failure();
 
-    auto outputType = dyn_cast<ShapedType>(genericOp.getResultTypes().front());
+    auto outputType = dyn_cast<ShapedType>(linalgOp->getResultTypes().front());
     // Require the output types to be static given that we are generating
     // constants.
     if (!outputType || !outputType.hasStaticShape())
       return failure();
 
-    if (!llvm::all_of(genericOp.getInputs(), [](Value input) {
+    if (!llvm::all_of(linalgOp.getDpsInputs(), [](Value input) {
           return isa<ShapedType>(input.getType());
         }))
       return failure();
@@ -80,7 +81,7 @@ class FoldConstantBase : public OpRewritePattern<GenericOp> {
       return cast<ShapedType>(value.getType()).getElementType();
     };
     if (!llvm::all_equal(
-            llvm::map_range(genericOp->getOperands(), getOperandElementType)))
+            llvm::map_range(linalgOp->getOperands(), getOperandElementType)))
       return failure();
 
     // We can only handle the case where we have int/float elements.
@@ -93,30 +94,30 @@ class FoldConstantBase : public OpRewritePattern<GenericOp> {
     // entirely in the compiler, without needing to turn all indices into
     // Values, and then do affine apply on them, and then match back the
     // constant again.
-    if (!llvm::all_of(genericOp.getIndexingMapsArray(),
+    if (!llvm::all_of(linalgOp.getIndexingMapsArray(),
                       [](AffineMap map) { return map.isPermutation(); }))
       return failure();
 
-    for (OpOperand &operand : genericOp.getDpsInitsMutable()) {
-      if (genericOp.payloadUsesValueFromOperand(&operand))
+    for (OpOperand &operand : linalgOp.getDpsInitsMutable()) {
+      if (linalgOp.payloadUsesValueFromOperand(&operand))
         return failure();
     }
 
     // Further check the indexing maps are okay for the ConcreteType.
-    if (!static_cast<const ConcreteType *>(this)->matchIndexingMaps(genericOp))
+    if (!static_cast<const ConcreteType *>(this)->matchIndexingMaps(linalgOp))
       return failure();
 
     // Defer to the concrete type to check the region and discover the
     // computation inside.
     RegionComputationFn computeFn =
-        static_cast<const ConcreteType *>(this)->getRegionComputeFn(genericOp);
+        static_cast<const ConcreteType *>(this)->getRegionComputeFn(linalgOp);
     if (!computeFn)
       return failure();
 
     // All inputs should be constants.
-    int numInputs = genericOp.getNumDpsInputs();
+    int numInputs = linalgOp.getNumDpsInputs();
     SmallVector<DenseIntOrFPElementsAttr> inputValues(numInputs);
-    for (const auto &en : llvm::enumerate(genericOp.getDpsInputOperands())) {
+    for (const auto &en : llvm::enumerate(linalgOp.getDpsInputOperands())) {
       if (!matchPattern(en.value()->get(),
                         m_Constant(&inputValues[en.index()])))
         return failure();
@@ -124,12 +125,11 @@ class FoldConstantBase : public OpRewritePattern<GenericOp> {
 
     // Identified this as a potential candidate for folding. Now check the
     // policy to see whether we are allowed to proceed.
-    for (OpOperand *operand : genericOp.getDpsInputOperands()) {
+    for (OpOperand *operand : linalgOp.getDpsInputOperands()) {
       if (!controlFn(operand))
         return failure();
     }
 
-    auto linalgOp = cast<LinalgOp>(genericOp.getOperation());
     SmallVector<int64_t, 4> loopBounds = linalgOp.computeStaticLoopSizes();
     int64_t numElements = outputType.getNumElements();
 
@@ -155,8 +155,8 @@ class FoldConstantBase : public OpRewritePattern<GenericOp> {
 
     SmallVector<SmallVector<unsigned>> inputDims;
     for (int i = 0; i < numInputs; ++i)
-      inputDims.push_back(getDimPositions(genericOp.getIndexingMapsArray()[i]));
-    auto outputDims = getDimPositions(genericOp.getIndexingMapsArray().back());
+      inputDims.push_back(getDimPositions(linalgOp.getIndexingMapsArray()[i]));
+    auto outputDims = getDimPositions(linalgOp.getIndexingMapsArray().back());
     auto outputShape = outputType.getShape();
 
     // Allocate small vectors for index delinearization. Initial values do not
@@ -173,7 +173,7 @@ class FoldConstantBase : public OpRewritePattern<GenericOp> {
     APIntOrFloatArray computeFnInputs;
 
     auto inputShapes = llvm::to_vector<4>(
-        llvm::map_range(genericOp.getInputs(), [](Value value) {
+        llvm::map_range(linalgOp.getDpsInputs(), [](Value value) {
           return cast<ShapedType>(value.getType()).getShape();
         }));
 
@@ -254,7 +254,7 @@ class FoldConstantBase : public OpRewritePattern<GenericOp> {
         isFloat ? DenseElementsAttr::get(outputType, fpOutputValues)
                 : DenseElementsAttr::get(outputType, intOutputValues);
 
-    rewriter.replaceOpWithNewOp<arith::ConstantOp>(genericOp, outputAttr);
+    rewriter.replaceOpWithNewOp<arith::ConstantOp>(linalgOp, outputAttr);
     return success();
   }
 
@@ -262,18 +262,20 @@ class FoldConstantBase : public OpRewritePattern<GenericOp> {
   ControlFusionFn controlFn;
 };
 
-// Folds linalg.generic ops that are actually transposes on constant values.
+// Folds linalg.transpose (and linalg.generic ops that are actually transposes)
+// on constant values.
 struct FoldConstantTranspose : public FoldConstantBase<FoldConstantTranspose> {
+
   using FoldConstantBase::FoldConstantBase;
 
-  bool matchIndexingMaps(GenericOp genericOp) const {
+  bool matchIndexingMaps(LinalgOp linalgOp) const {
     // We should have one input and one output.
-    return genericOp.getIndexingMapsArray().size() == 2;
+    return linalgOp.getIndexingMapsArray().size() == 2;
   }
 
-  RegionComputationFn getRegionComputeFn(GenericOp genericOp) const {
+  RegionComputationFn getRegionComputeFn(LinalgOp linalgOp) const {
     // Make sure the region only contains a yield op.
-    Block &body = genericOp.getRegion().front();
+    Block &body = linalgOp->getRegion(0).front();
     if (!llvm::hasSingleElement(body))
       return nullptr;
     auto yieldOp = dyn_cast<linalg::YieldOp>(body.getTerminator());
diff --git a/mlir/test/Dialect/Linalg/fusion-elementwise-ops.mlir b/mlir/test/Dialect/Linalg/fusion-elementwise-ops.mlir
index 15a4f6cdd3bbe..70f43885712b7 100644
--- a/mlir/test/Dialect/Linalg/fusion-elementwise-ops.mlir
+++ b/mlir/test/Dialect/Linalg/fusion-elementwise-ops.mlir
@@ -910,6 +910,18 @@ func.func @transpose_nofold_multi_ops_in_region(%init: tensor<3x2xf32>) -> tenso
 
 // -----
 
+// CHECK-LABEL: @named_transpose_fold_2d_fp32
+func.func @named_transpose_fold_2d_fp32(%init: tensor<3x2xf32>) -> tensor<3x2xf32> {
+  %input = arith.constant dense<[[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]> : tensor<2x3xf32>
+  //               CHECK: %[[CST:.+]] = arith.constant
+  // CHECK-SAME{LITERAL}:   dense<[[0.000000e+00, 3.000000e+00], [1.000000e+00, 4.000000e+00], [2.000000e+00, 5.000000e+00]]> : tensor<3x2xf32>
+  %1 = linalg.transpose ins(%input : tensor<2x3xf32>) outs(%init : tensor<3x2xf32>) permutation = [1, 0]
+  // CHECK: return %[[CST]]
+  return %1 : tensor<3x2xf32>
+}
+
+// -----
+
 // Fusing the broadcast into a reduction would require to insert extra knowledge
 // about the size of the reduction dimension. As long, as this is not
 // implemented, we check that two linalg operations remain.



More information about the Mlir-commits mailing list