[Mlir-commits] [mlir] b003ebd - [MLIR][Linalg] Generalize splat constant folding

Frederik Gossen llvmlistbot at llvm.org
Tue Apr 27 00:09:01 PDT 2021


Author: Frederik Gossen
Date: 2021-04-27T09:08:34+02:00
New Revision: b003ebd603c9b16ad65527f89c1a9898598ce6ff

URL: https://github.com/llvm/llvm-project/commit/b003ebd603c9b16ad65527f89c1a9898598ce6ff
DIFF: https://github.com/llvm/llvm-project/commit/b003ebd603c9b16ad65527f89c1a9898598ce6ff.diff

LOG: [MLIR][Linalg] Generalize splat constant folding

Previously, splat constant folding was limited to `std.constant` operations. Instead, use
the constant matcher (`m_Constant`) and apply splat constant folding to any constant-like
operation that holds a splat `DenseElementsAttr`.

Differential Revision: https://reviews.llvm.org/D101301

Added: 
    

Modified: 
    mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp b/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp
index 628b5969fe00..5af62dafe6d9 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp
@@ -18,6 +18,7 @@
 #include "mlir/Dialect/Linalg/Utils/Utils.h"
 #include "mlir/IR/AffineExpr.h"
 #include "mlir/IR/AffineMap.h"
+#include "mlir/IR/Matchers.h"
 #include "mlir/IR/PatternMatch.h"
 #include "mlir/Support/LLVM.h"
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
@@ -1311,10 +1312,12 @@ class FoldSplatConstants : public OpRewritePattern<LinalgOpTy> {
       return failure();
     LinalgOp linalgOp = cast<LinalgOp>(op.getOperation());
     for (auto operand : llvm::enumerate(linalgOp.getInputOpOperands())) {
-      ConstantOp constantOp = operand.value().get().getDefiningOp<ConstantOp>();
-      if (!constantOp ||
-          !constantOp.value().cast<DenseElementsAttr>().isSplat() ||
-          !controlFn(constantOp->getResult(0), operand.value()))
+      Operation *def = operand.value().get().getDefiningOp();
+      DenseElementsAttr constantAttr;
+      if (!def ||
+          !matchPattern(def, m_Constant<DenseElementsAttr>(&constantAttr)) ||
+          !constantAttr.isSplat() ||
+          !controlFn(def->getResult(0), operand.value()))
         continue;
 
       // The indexing_maps for the operands of the fused operation are same as
@@ -1337,8 +1340,7 @@ class FoldSplatConstants : public OpRewritePattern<LinalgOpTy> {
 
       // Create a constant scalar value from the splat constant.
       Value scalarConstant = rewriter.create<ConstantOp>(
-          constantOp.getLoc(),
-          constantOp.value().cast<DenseElementsAttr>().getSplatValue());
+          def->getLoc(), constantAttr.getSplatValue());
 
       LinalgOp fusedOp = createLinalgOpOfSameType(
           linalgOp, rewriter, rewriter.getUnknownLoc(),


        


More information about the Mlir-commits mailing list