[Mlir-commits] [mlir] b003ebd - [MLIR][Linalg] Generalize splat constant folding
Frederik Gossen
llvmlistbot at llvm.org
Tue Apr 27 00:09:01 PDT 2021
Author: Frederik Gossen
Date: 2021-04-27T09:08:34+02:00
New Revision: b003ebd603c9b16ad65527f89c1a9898598ce6ff
URL: https://github.com/llvm/llvm-project/commit/b003ebd603c9b16ad65527f89c1a9898598ce6ff
DIFF: https://github.com/llvm/llvm-project/commit/b003ebd603c9b16ad65527f89c1a9898598ce6ff.diff
LOG: [MLIR][Linalg] Generalize splat constant folding
Splat constant folding was limited to `std.constant` operations. Instead, use
the constant matcher and apply splat constant folding to any constant-like
operation that holds a splat attribute.
Differential Revision: https://reviews.llvm.org/D101301
Added:
Modified:
mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp b/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp
index 628b5969fe00..5af62dafe6d9 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp
@@ -18,6 +18,7 @@
#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/AffineMap.h"
+#include "mlir/IR/Matchers.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Support/LLVM.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
@@ -1311,10 +1312,12 @@ class FoldSplatConstants : public OpRewritePattern<LinalgOpTy> {
return failure();
LinalgOp linalgOp = cast<LinalgOp>(op.getOperation());
for (auto operand : llvm::enumerate(linalgOp.getInputOpOperands())) {
- ConstantOp constantOp = operand.value().get().getDefiningOp<ConstantOp>();
- if (!constantOp ||
- !constantOp.value().cast<DenseElementsAttr>().isSplat() ||
- !controlFn(constantOp->getResult(0), operand.value()))
+ Operation *def = operand.value().get().getDefiningOp();
+ DenseElementsAttr constantAttr;
+ if (!def ||
+ !matchPattern(def, m_Constant<DenseElementsAttr>(&constantAttr)) ||
+ !constantAttr.isSplat() ||
+ !controlFn(def->getResult(0), operand.value()))
continue;
// The indexing_maps for the operands of the fused operation are same as
@@ -1337,8 +1340,7 @@ class FoldSplatConstants : public OpRewritePattern<LinalgOpTy> {
// Create a constant scalar value from the splat constant.
Value scalarConstant = rewriter.create<ConstantOp>(
- constantOp.getLoc(),
- constantOp.value().cast<DenseElementsAttr>().getSplatValue());
+ def->getLoc(), constantAttr.getSplatValue());
LinalgOp fusedOp = createLinalgOpOfSameType(
linalgOp, rewriter, rewriter.getUnknownLoc(),
More information about the Mlir-commits mailing list