[Mlir-commits] [mlir] 0d4e7fb - [mlir][sparse] minor zero test refactoring in rewriting
Aart Bik
llvmlistbot at llvm.org
Wed Sep 7 10:07:22 PDT 2022
Author: Aart Bik
Date: 2022-09-07T10:07:11-07:00
New Revision: 0d4e7fba9a4722bef66d299d89a5546cfbf1790a
URL: https://github.com/llvm/llvm-project/commit/0d4e7fba9a4722bef66d299d89a5546cfbf1790a
DIFF: https://github.com/llvm/llvm-project/commit/0d4e7fba9a4722bef66d299d89a5546cfbf1790a.diff
LOG: [mlir][sparse] minor zero test refactoring in rewriting
Reviewed By: bixia
Differential Revision: https://reviews.llvm.org/D133382
Added:
Modified:
mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
index 9adfacebda0d7..a8c08831e2b4d 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
@@ -31,6 +31,11 @@ using namespace mlir::sparse_tensor;
// Helper methods for the actual rewriting rules.
//===---------------------------------------------------------------------===//
+// Helper method to match any typed zero (integer via m_Zero, float via m_AnyZeroFloat).
+static bool isZeroValue(Value val) {
+ return matchPattern(val, m_Zero()) || matchPattern(val, m_AnyZeroFloat());
+}
+
// Helper to detect a sparse tensor type operand.
static bool isSparseTensor(OpOperand *op) {
if (auto enc = getSparseTensorEncoding(op->get().getType())) {
@@ -47,8 +52,7 @@ static bool isAlloc(OpOperand *op, bool isZero) {
if (auto alloc = val.getDefiningOp<AllocTensorOp>()) {
Value copy = alloc.getCopy();
if (isZero)
- return copy && (matchPattern(copy, m_Zero()) ||
- matchPattern(copy, m_AnyZeroFloat()));
+ return copy && isZeroValue(copy);
return !copy;
}
return false;
@@ -100,13 +104,10 @@ static bool isZeroYield(GenericOp op) {
if (auto arg = yieldOp.getOperand(0).dyn_cast<BlockArgument>()) {
if (arg.getOwner()->getParentOp() == op) {
OpOperand *t = op.getInputAndOutputOperands()[arg.getArgNumber()];
- return matchPattern(t->get(), m_Zero()) ||
- matchPattern(t->get(), m_AnyZeroFloat());
+ return isZeroValue(t->get());
}
- } else if (auto *def = yieldOp.getOperand(0).getDefiningOp()) {
- return matchPattern(def, m_Zero()) || matchPattern(def, m_AnyZeroFloat());
}
- return false;
+ return isZeroValue(yieldOp.getOperand(0));
}
//===---------------------------------------------------------------------===//
More information about the Mlir-commits mailing list