[Mlir-commits] [mlir] 0d4e7fb - [mlir][sparse] minor zero test refactoring in rewriting

Aart Bik llvmlistbot at llvm.org
Wed Sep 7 10:07:22 PDT 2022


Author: Aart Bik
Date: 2022-09-07T10:07:11-07:00
New Revision: 0d4e7fba9a4722bef66d299d89a5546cfbf1790a

URL: https://github.com/llvm/llvm-project/commit/0d4e7fba9a4722bef66d299d89a5546cfbf1790a
DIFF: https://github.com/llvm/llvm-project/commit/0d4e7fba9a4722bef66d299d89a5546cfbf1790a.diff

LOG: [mlir][sparse] minor zero test refactoring in rewriting

Reviewed By: bixia

Differential Revision: https://reviews.llvm.org/D133382

Added: 
    

Modified: 
    mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp

Removed: 
    


################################################################################
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
index 9adfacebda0d7..a8c08831e2b4d 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
@@ -31,6 +31,11 @@ using namespace mlir::sparse_tensor;
 // Helper methods for the actual rewriting rules.
 //===---------------------------------------------------------------------===//
 
+// Helper method to match any typed zero.
+static bool isZeroValue(Value val) {
+  return matchPattern(val, m_Zero()) || matchPattern(val, m_AnyZeroFloat());
+}
+
 // Helper to detect a sparse tensor type operand.
 static bool isSparseTensor(OpOperand *op) {
   if (auto enc = getSparseTensorEncoding(op->get().getType())) {
@@ -47,8 +52,7 @@ static bool isAlloc(OpOperand *op, bool isZero) {
   if (auto alloc = val.getDefiningOp<AllocTensorOp>()) {
     Value copy = alloc.getCopy();
     if (isZero)
-      return copy && (matchPattern(copy, m_Zero()) ||
-                      matchPattern(copy, m_AnyZeroFloat()));
+      return copy && isZeroValue(copy);
     return !copy;
   }
   return false;
@@ -100,13 +104,10 @@ static bool isZeroYield(GenericOp op) {
   if (auto arg = yieldOp.getOperand(0).dyn_cast<BlockArgument>()) {
     if (arg.getOwner()->getParentOp() == op) {
       OpOperand *t = op.getInputAndOutputOperands()[arg.getArgNumber()];
-      return matchPattern(t->get(), m_Zero()) ||
-             matchPattern(t->get(), m_AnyZeroFloat());
+      return isZeroValue(t->get());
     }
-  } else if (auto *def = yieldOp.getOperand(0).getDefiningOp()) {
-    return matchPattern(def, m_Zero()) || matchPattern(def, m_AnyZeroFloat());
   }
-  return false;
+  return isZeroValue(yieldOp.getOperand(0));
 }
 
 //===---------------------------------------------------------------------===//


        


More information about the Mlir-commits mailing list