[Mlir-commits] [mlir] 8e4f8d3 - [mlir][sparse] merge ifs in new sparse rewriting rules

Aart Bik llvmlistbot at llvm.org
Fri Feb 25 15:06:56 PST 2022


Author: Aart Bik
Date: 2022-02-25T15:06:47-08:00
New Revision: 8e4f8d353263cf4b5505febcfed5fdeecb6d7d85

URL: https://github.com/llvm/llvm-project/commit/8e4f8d353263cf4b5505febcfed5fdeecb6d7d85
DIFF: https://github.com/llvm/llvm-project/commit/8e4f8d353263cf4b5505febcfed5fdeecb6d7d85.diff

LOG: [mlir][sparse] merge ifs in new sparse rewriting rules

Reviewed By: bixia

Differential Revision: https://reviews.llvm.org/D120500

Added: 
    

Modified: 
    mlir/lib/Dialect/Linalg/Transforms/SparseTensorRewriting.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/Linalg/Transforms/SparseTensorRewriting.cpp b/mlir/lib/Dialect/Linalg/Transforms/SparseTensorRewriting.cpp
index 3958ab3baf178..597b9b38e1b71 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/SparseTensorRewriting.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/SparseTensorRewriting.cpp
@@ -45,15 +45,8 @@ static bool isSparseTensor(OpOperand *op) {
 // Helper method to find zero or empty initialization.
 static bool isEmptyInit(OpOperand *op) {
   Value val = op->get();
-  if (matchPattern(val, m_Zero()))
-    return true;
-  if (matchPattern(val, m_AnyZeroFloat()))
-    return true;
-  if (val.getDefiningOp<InitTensorOp>())
-    return true;
-  if (val.getDefiningOp<InitOp>())
-    return true;
-  return false;
+  return matchPattern(val, m_Zero()) || matchPattern(val, m_AnyZeroFloat()) ||
+         val.getDefiningOp<InitTensorOp>() || val.getDefiningOp<InitOp>();
 }
 
 // Helper to detect sampling operation.
@@ -123,11 +116,9 @@ struct FuseSparseMultiplyOverAdd : public OpRewritePattern<GenericOp> {
                                 PatternRewriter &rewriter) const override {
     // Check consumer.
     if (!op.hasTensorSemantics() || op.getNumInputs() != 2 ||
-        op.getNumResults() != 1)
-      return failure();
-    if (op.getNumParallelLoops() != op.getNumLoops())
-      return failure();
-    if (!op.getTiedIndexingMap(op.getOutputOperand(0)).isIdentity() ||
+        op.getNumResults() != 1 ||
+        op.getNumParallelLoops() != op.getNumLoops() ||
+        !op.getTiedIndexingMap(op.getOutputOperand(0)).isIdentity() ||
         !op.getTiedIndexingMap(op.getInputOperand(0)).isIdentity() ||
         !op.getTiedIndexingMap(op.getInputOperand(1)).isIdentity())
       return failure();
@@ -143,15 +134,13 @@ struct FuseSparseMultiplyOverAdd : public OpRewritePattern<GenericOp> {
     // Check producer.
     auto prod = dyn_cast_or_null<GenericOp>(
         op.getInputOperand(other)->get().getDefiningOp());
-    if (!prod || !prod.hasTensorSemantics() || prod.getNumResults() != 1)
-      return failure();
-    if (!prod.getResult(0).hasOneUse())
+    if (!prod || !prod.hasTensorSemantics() || prod.getNumResults() != 1 ||
+        !prod.getResult(0).hasOneUse())
       return failure();
     // Sampling consumer and sum of multiplication chain producer.
     if (!isEmptyInit(op.getOutputOperand(0)) ||
-        !isEmptyInit(prod.getOutputOperand(0)))
-      return failure();
-    if (!isSampling(op) || !isSumOfMul(prod))
+        !isEmptyInit(prod.getOutputOperand(0)) || !isSampling(op) ||
+        !isSumOfMul(prod))
       return failure();
     // Modify operand structure of producer and consumer.
     Location loc = prod.getLoc();


        


More information about the Mlir-commits mailing list