[Mlir-commits] [mlir] 8e4f8d3 - [mlir][sparse] merge ifs in new sparse rewriting rules
Aart Bik
llvmlistbot at llvm.org
Fri Feb 25 15:06:56 PST 2022
Author: Aart Bik
Date: 2022-02-25T15:06:47-08:00
New Revision: 8e4f8d353263cf4b5505febcfed5fdeecb6d7d85
URL: https://github.com/llvm/llvm-project/commit/8e4f8d353263cf4b5505febcfed5fdeecb6d7d85
DIFF: https://github.com/llvm/llvm-project/commit/8e4f8d353263cf4b5505febcfed5fdeecb6d7d85.diff
LOG: [mlir][sparse] merge ifs in new sparse rewriting rules
Reviewed By: bixia
Differential Revision: https://reviews.llvm.org/D120500
Added:
Modified:
mlir/lib/Dialect/Linalg/Transforms/SparseTensorRewriting.cpp
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/Linalg/Transforms/SparseTensorRewriting.cpp b/mlir/lib/Dialect/Linalg/Transforms/SparseTensorRewriting.cpp
index 3958ab3baf178..597b9b38e1b71 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/SparseTensorRewriting.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/SparseTensorRewriting.cpp
@@ -45,15 +45,8 @@ static bool isSparseTensor(OpOperand *op) {
// Helper method to find zero or empty initialization.
static bool isEmptyInit(OpOperand *op) {
Value val = op->get();
- if (matchPattern(val, m_Zero()))
- return true;
- if (matchPattern(val, m_AnyZeroFloat()))
- return true;
- if (val.getDefiningOp<InitTensorOp>())
- return true;
- if (val.getDefiningOp<InitOp>())
- return true;
- return false;
+ return matchPattern(val, m_Zero()) || matchPattern(val, m_AnyZeroFloat()) ||
+ val.getDefiningOp<InitTensorOp>() || val.getDefiningOp<InitOp>();
}
// Helper to detect sampling operation.
@@ -123,11 +116,9 @@ struct FuseSparseMultiplyOverAdd : public OpRewritePattern<GenericOp> {
PatternRewriter &rewriter) const override {
// Check consumer.
if (!op.hasTensorSemantics() || op.getNumInputs() != 2 ||
- op.getNumResults() != 1)
- return failure();
- if (op.getNumParallelLoops() != op.getNumLoops())
- return failure();
- if (!op.getTiedIndexingMap(op.getOutputOperand(0)).isIdentity() ||
+ op.getNumResults() != 1 ||
+ op.getNumParallelLoops() != op.getNumLoops() ||
+ !op.getTiedIndexingMap(op.getOutputOperand(0)).isIdentity() ||
!op.getTiedIndexingMap(op.getInputOperand(0)).isIdentity() ||
!op.getTiedIndexingMap(op.getInputOperand(1)).isIdentity())
return failure();
@@ -143,15 +134,13 @@ struct FuseSparseMultiplyOverAdd : public OpRewritePattern<GenericOp> {
// Check producer.
auto prod = dyn_cast_or_null<GenericOp>(
op.getInputOperand(other)->get().getDefiningOp());
- if (!prod || !prod.hasTensorSemantics() || prod.getNumResults() != 1)
- return failure();
- if (!prod.getResult(0).hasOneUse())
+ if (!prod || !prod.hasTensorSemantics() || prod.getNumResults() != 1 ||
+ !prod.getResult(0).hasOneUse())
return failure();
// Sampling consumer and sum of multiplication chain producer.
if (!isEmptyInit(op.getOutputOperand(0)) ||
- !isEmptyInit(prod.getOutputOperand(0)))
- return failure();
- if (!isSampling(op) || !isSumOfMul(prod))
+ !isEmptyInit(prod.getOutputOperand(0)) || !isSampling(op) ||
+ !isSumOfMul(prod))
return failure();
// Modify operand structure of producer and consumer.
Location loc = prod.getLoc();
More information about the Mlir-commits mailing list