[Mlir-commits] [mlir] df11a2b - [mlir][sparse] admit un-sparsifiable operations if all their operands are loaded from dense inputs
Peiming Liu
llvmlistbot at llvm.org
Wed Jun 28 14:27:55 PDT 2023
Author: Peiming Liu
Date: 2023-06-28T21:27:50Z
New Revision: df11a2b41ad490d0ac752c83f12893128c3632c2
URL: https://github.com/llvm/llvm-project/commit/df11a2b41ad490d0ac752c83f12893128c3632c2
DIFF: https://github.com/llvm/llvm-project/commit/df11a2b41ad490d0ac752c83f12893128c3632c2.diff
LOG: [mlir][sparse] admit un-sparsifiable operations if all their operands are loaded from dense inputs
Reviewed By: aartbik
Differential Revision: https://reviews.llvm.org/D153998
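In essence: when Merger::buildTensorExp fails to match any known sparsifiable
pattern, the operation is no longer rejected outright. If it produces a single
result and none of its operands transitively depends on a value loaded from a
sparse tensor, it is admitted under the new expression kind kDenseOp and later
cloned verbatim during code generation. A minimal sketch of how such an
expression is encoded, assuming a Merger instance and already-built child
expression ids e0/e1 (illustrative only, not the verbatim patch):

  // Unary case: the second child slot carries detail::kInvalidId.
  ExprId un = merger.addExp(TensorExp::Kind::kDenseOp, e0, detail::kInvalidId, def);
  // Binary case: both child slots hold valid subexpressions; kDenseOp
  // supports at most two operands.
  ExprId bin = merger.addExp(TensorExp::Kind::kDenseOp, e0, e1, def);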
Added:
mlir/test/Dialect/SparseTensor/unsparsifiable_dense_op.mlir
Modified:
mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h
mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp
mlir/unittests/Dialect/SparseTensor/MergerTest.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h b/mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h
index e166da529c14d..3ea8ce721d6e5 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h
+++ b/mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h
@@ -70,6 +70,9 @@ struct TensorExp final {
/// and kSelect, this holds the original operation with all regions. For
/// kBinaryBranch, this holds the YieldOp for the left or right half
/// to be merged into a nested scf loop.
+ ///
+ /// For kDenseOp, this holds the operation that cannot be sparsified but
+ /// whose operands are all loaded from dense inputs.
Operation *op;
/// An optional attribute that is required to determine the semantics of the
@@ -157,8 +160,9 @@ enum class TensorExp::Kind {
kShrS, // signed
kShrU, // unsigned
kShlI,
- kBinary, // semiring binary op
- kReduce, // semiring reduction op
+ kBinary, // semiring binary op
+ kReduce, // semiring reduction op
+ kDenseOp, // special category of operations requiring all dense operands
};
//===----------------------------------------------------------------------===//
@@ -645,7 +649,11 @@ class Merger {
Type inferType(ExprId e, Value src) const;
/// Traverses the SSA tree (possibly a DAG) to build a tensor expression.
- std::optional<ExprId> buildTensorExp(linalg::GenericOp op, Value v);
+ /// The returned boolean indicates whether the result of the operation
+ /// being built depends on any value that is loaded from a sparse
+ /// tensor.
+ std::pair<std::optional<ExprId>, bool> buildTensorExp(linalg::GenericOp op,
+ Value v);
/// Merger data structures.
const TensorId outTensor;
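For readers skimming the header change: a caller now unpacks both the
expression id and a sparse-dependence flag. A minimal sketch, mirroring the
(internal) call site in buildTensorExpFromLinalg below; the merger, genericOp,
and yieldedValue names are assumed for illustration:

  const auto [expId, hasSpDep] = merger.buildTensorExp(genericOp, yieldedValue);
  if (expId.has_value() && !hasSpDep) {
    // The subtree was fully built and touches only dense inputs, so even an
    // otherwise un-sparsifiable op inside it can be admitted as kDenseOp.
  }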
diff --git a/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp b/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp
index fc87c8413c36f..f39a2069a57dd 100644
--- a/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp
@@ -92,6 +92,7 @@ static ExpArity getExpArity(TensorExp::Kind k) {
case TensorExp::Kind::kSubI:
case TensorExp::Kind::kCmpF:
case TensorExp::Kind::kCmpI:
+ case TensorExp::Kind::kDenseOp: // kDenseOp has *at most* two operands
return ExpArity::kBinary;
}
llvm_unreachable("unexpected kind");
@@ -210,6 +211,11 @@ TensorExp::TensorExp(TensorExp::Kind k, unsigned x, ExprId y, Value v,
children.e0 = x;
children.e1 = y;
return;
+ case TensorExp::Kind::kDenseOp:
+ assert(x != detail::kInvalidId && !v && o);
+ children.e0 = x;
+ children.e1 = y;
+ return;
}
llvm_unreachable("unexpected kind");
}
@@ -393,7 +399,8 @@ LatSetId Merger::combiSet(ExprId e, LatSetId s0, LatSetId s1, Operation *orig,
LatSetId Merger::mapSet(TensorExp::Kind kind, LatSetId s0, Value v,
Operation *op) {
- assert(TensorExp::Kind::kAbsF <= kind && kind <= TensorExp::Kind::kSelect);
+ assert((TensorExp::Kind::kAbsF <= kind && kind <= TensorExp::Kind::kSelect) ||
+ TensorExp::Kind::kDenseOp == kind);
const LatSetId sNew = addSet();
auto &setNew = latSets[sNew];
for (const LatPointId p : set(s0)) {
@@ -546,6 +553,12 @@ bool Merger::hasNegateOnOut(ExprId e) const {
case TensorExp::Kind::kSubI:
return expContainsTensor(expr.children.e1, outTensor) ||
hasNegateOnOut(expr.children.e0);
+ case TensorExp::Kind::kDenseOp: {
+ bool lhsNeg = hasNegateOnOut(expr.children.e0);
+ if (!lhsNeg && expr.children.e1 != detail::kInvalidId)
+ return hasNegateOnOut(expr.children.e1);
+ return lhsNeg;
+ }
default: {
switch (getExpArity(expr.kind)) {
case ExpArity::kNullary:
@@ -646,6 +659,10 @@ bool Merger::isSingleCondition(TensorId t, ExprId e) const {
case TensorExp::Kind::kCmpI:
case TensorExp::Kind::kBinary:
return false;
+ case TensorExp::Kind::kDenseOp:
+ // Since the Merger guarantees that all operands of a kDenseOp are dense,
+ // the operation must be single-condition.
+ return true;
}
llvm_unreachable("unexpected kind");
}
@@ -771,6 +788,8 @@ static const char *kindToOpSymbol(TensorExp::Kind kind) {
return "binary";
case TensorExp::Kind::kReduce:
return "reduce";
+ case TensorExp::Kind::kDenseOp:
+ return "dense";
}
llvm_unreachable("unexpected kind for symbol");
}
@@ -857,14 +876,19 @@ void Merger::dumpExp(ExprId e) const {
case TensorExp::Kind::kCmpI:
case TensorExp::Kind::kBinary:
case TensorExp::Kind::kReduce:
+ case TensorExp::Kind::kDenseOp:
llvm::dbgs() << "(";
dumpExp(expr.children.e0);
llvm::dbgs() << " " << kindToOpSymbol(expr.kind);
if (expr.attr)
llvm::dbgs() << "{" << expr.attr << "}";
- llvm::dbgs() << " ";
- dumpExp(expr.children.e1);
- llvm::dbgs() << ")";
+ if (expr.children.e1 != detail::kInvalidId) {
+ llvm::dbgs() << " ";
+ dumpExp(expr.children.e1);
+ } else {
+ assert(expr.kind == TensorExp::Kind::kDenseOp);
+ }
+ llvm::dbgs() << ")";
break;
}
}
@@ -1142,6 +1166,21 @@ LatSetId Merger::buildLattices(ExprId e, LoopId i) {
Operation *const op = expr.op;
return conjSet(e, buildLattices(e0, i), buildLattices(e1, i), op);
}
+ case TensorExp::Kind::kDenseOp: {
+ // It does not really matter whether we use a conjunctive or a disjunctive
+ // set here: since all operands of a kDenseOp must be dense, a disjunctive
+ // set would eventually be optimized into a conjunctive set anyway.
+ if (expr.children.e1 == detail::kInvalidId) {
+ const ExprId e0 = expr.children.e0;
+ Operation *const op = expr.op;
+ return mapSet(kind, buildLattices(e0, i), Value(), op);
+ }
+
+ const ExprId e0 = expr.children.e0;
+ const ExprId e1 = expr.children.e1;
+ Operation *const op = expr.op;
+ return conjSet(e, buildLattices(e0, i), buildLattices(e1, i), op);
+ }
}
llvm_unreachable("unexpected expression kind");
}
@@ -1150,7 +1189,7 @@ std::optional<ExprId> Merger::buildTensorExpFromLinalg(linalg::GenericOp op) {
// Build the linalg semantics backward from yield.
Operation *yield = op.getRegion().front().getTerminator();
assert(isa<linalg::YieldOp>(yield));
- return buildTensorExp(op, yield->getOperand(0));
+ return buildTensorExp(op, yield->getOperand(0)).first;
}
/// Only returns false if we are certain this is a nonzero.
@@ -1210,7 +1249,9 @@ static bool isAdmissibleBranch(Operation *op, Region &region) {
return isAdmissibleBranchExp(op, &region.front(), yield->getOperand(0));
}
-std::optional<ExprId> Merger::buildTensorExp(linalg::GenericOp op, Value v) {
+std::pair<std::optional<ExprId>, bool>
+Merger::buildTensorExp(linalg::GenericOp op, Value v) {
+ // Recursion leaves.
if (auto arg = dyn_cast<BlockArgument>(v)) {
const TensorId tid = makeTensorId(arg.getArgNumber());
// Any argument of the generic op that is not marked as a scalar
@@ -1218,96 +1259,98 @@ std::optional<ExprId> Merger::buildTensorExp(linalg::GenericOp op, Value v) {
// bounds. This includes rank-0 tensor arguments.
if (arg.getOwner()->getParentOp() == op) {
OpOperand &t = op->getOpOperand(tid);
+ bool hasSpDep = getSparseTensorEncoding(t.get().getType()) != nullptr;
if (!op.isScalar(&t))
- return addTensorExp(tid);
+ return {addTensorExp(tid), hasSpDep};
v = t.get(); // get scalar value
}
// Any other argument (marked as scalar argument for the generic op
// or belonging to an enveloping op) is considered invariant.
- return addInvariantExp(v);
+ return {addInvariantExp(v), /*hasSpDep=*/false};
}
// Something defined outside is invariant.
Operation *def = v.getDefiningOp();
if (def->getBlock() != &op.getRegion().front())
- return addInvariantExp(v);
+ return {addInvariantExp(v), /*hasSpDep=*/false};
// Construct index operations.
if (def->getNumOperands() == 0) {
if (auto indexOp = dyn_cast<linalg::IndexOp>(def))
- return addLoopVarExp(makeLoopId(indexOp.getDim()));
+ return {addLoopVarExp(makeLoopId(indexOp.getDim())), /*hasSpDep=*/false};
}
+
// Construct unary operations if subexpression can be built.
if (def->getNumOperands() == 1) {
- const auto x = buildTensorExp(op, def->getOperand(0));
+ const auto [x, hasSpDep] = buildTensorExp(op, def->getOperand(0));
if (x.has_value()) {
const ExprId e = *x;
if (isa<math::AbsFOp>(def))
- return addExp(TensorExp::Kind::kAbsF, e);
+ return {addExp(TensorExp::Kind::kAbsF, e), hasSpDep};
if (isa<complex::AbsOp>(def))
- return addExp(TensorExp::Kind::kAbsC, e);
+ return {addExp(TensorExp::Kind::kAbsC, e), hasSpDep};
if (isa<math::AbsIOp>(def))
- return addExp(TensorExp::Kind::kAbsI, e);
+ return {addExp(TensorExp::Kind::kAbsI, e), hasSpDep};
if (isa<math::CeilOp>(def))
- return addExp(TensorExp::Kind::kCeilF, e);
+ return {addExp(TensorExp::Kind::kCeilF, e), hasSpDep};
if (isa<math::FloorOp>(def))
- return addExp(TensorExp::Kind::kFloorF, e);
+ return {addExp(TensorExp::Kind::kFloorF, e), hasSpDep};
if (isa<math::SqrtOp>(def))
- return addExp(TensorExp::Kind::kSqrtF, e);
+ return {addExp(TensorExp::Kind::kSqrtF, e), hasSpDep};
if (isa<complex::SqrtOp>(def))
- return addExp(TensorExp::Kind::kSqrtC, e);
+ return {addExp(TensorExp::Kind::kSqrtC, e), hasSpDep};
if (isa<math::ExpM1Op>(def))
- return addExp(TensorExp::Kind::kExpm1F, e);
+ return {addExp(TensorExp::Kind::kExpm1F, e), hasSpDep};
if (isa<complex::Expm1Op>(def))
- return addExp(TensorExp::Kind::kExpm1C, e);
+ return {addExp(TensorExp::Kind::kExpm1C, e), hasSpDep};
if (isa<math::Log1pOp>(def))
- return addExp(TensorExp::Kind::kLog1pF, e);
+ return {addExp(TensorExp::Kind::kLog1pF, e), hasSpDep};
if (isa<complex::Log1pOp>(def))
- return addExp(TensorExp::Kind::kLog1pC, e);
+ return {addExp(TensorExp::Kind::kLog1pC, e), hasSpDep};
if (isa<math::SinOp>(def))
- return addExp(TensorExp::Kind::kSinF, e);
+ return {addExp(TensorExp::Kind::kSinF, e), hasSpDep};
if (isa<complex::SinOp>(def))
- return addExp(TensorExp::Kind::kSinC, e);
+ return {addExp(TensorExp::Kind::kSinC, e), hasSpDep};
if (isa<math::TanhOp>(def))
- return addExp(TensorExp::Kind::kTanhF, e);
+ return {addExp(TensorExp::Kind::kTanhF, e), hasSpDep};
if (isa<complex::TanhOp>(def))
- return addExp(TensorExp::Kind::kTanhC, e);
+ return {addExp(TensorExp::Kind::kTanhC, e), hasSpDep};
if (isa<arith::NegFOp>(def))
- return addExp(TensorExp::Kind::kNegF, e); // no negi in std
+ return {addExp(TensorExp::Kind::kNegF, e), hasSpDep}; // no negi in std
if (isa<complex::NegOp>(def))
- return addExp(TensorExp::Kind::kNegC, e);
+ return {addExp(TensorExp::Kind::kNegC, e), hasSpDep};
if (isa<arith::TruncFOp>(def))
- return addExp(TensorExp::Kind::kTruncF, e, v);
+ return {addExp(TensorExp::Kind::kTruncF, e, v), hasSpDep};
if (isa<arith::ExtFOp>(def))
- return addExp(TensorExp::Kind::kExtF, e, v);
+ return {addExp(TensorExp::Kind::kExtF, e, v), hasSpDep};
if (isa<arith::FPToSIOp>(def))
- return addExp(TensorExp::Kind::kCastFS, e, v);
+ return {addExp(TensorExp::Kind::kCastFS, e, v), hasSpDep};
if (isa<arith::FPToUIOp>(def))
- return addExp(TensorExp::Kind::kCastFU, e, v);
+ return {addExp(TensorExp::Kind::kCastFU, e, v), hasSpDep};
if (isa<arith::SIToFPOp>(def))
- return addExp(TensorExp::Kind::kCastSF, e, v);
+ return {addExp(TensorExp::Kind::kCastSF, e, v), hasSpDep};
if (isa<arith::UIToFPOp>(def))
- return addExp(TensorExp::Kind::kCastUF, e, v);
+ return {addExp(TensorExp::Kind::kCastUF, e, v), hasSpDep};
if (isa<arith::ExtSIOp>(def))
- return addExp(TensorExp::Kind::kCastS, e, v);
+ return {addExp(TensorExp::Kind::kCastS, e, v), hasSpDep};
if (isa<arith::ExtUIOp>(def))
- return addExp(TensorExp::Kind::kCastU, e, v);
+ return {addExp(TensorExp::Kind::kCastU, e, v), hasSpDep};
if (isa<arith::IndexCastOp>(def))
- return addExp(TensorExp::Kind::kCastIdx, e, v);
+ return {addExp(TensorExp::Kind::kCastIdx, e, v), hasSpDep};
if (isa<arith::TruncIOp>(def))
- return addExp(TensorExp::Kind::kTruncI, e, v);
+ return {addExp(TensorExp::Kind::kTruncI, e, v), hasSpDep};
if (isa<complex::ImOp>(def))
- return addExp(TensorExp::Kind::kCIm, e);
+ return {addExp(TensorExp::Kind::kCIm, e), hasSpDep};
if (isa<complex::ReOp>(def))
- return addExp(TensorExp::Kind::kCRe, e);
+ return {addExp(TensorExp::Kind::kCRe, e), hasSpDep};
if (isa<arith::BitcastOp>(def))
- return addExp(TensorExp::Kind::kBitCast, e, v);
+ return {addExp(TensorExp::Kind::kBitCast, e, v), hasSpDep};
if (auto unop = dyn_cast<sparse_tensor::UnaryOp>(def)) {
if (isAdmissibleBranch(unop, unop.getPresentRegion()) &&
isAdmissibleBranch(unop, unop.getAbsentRegion()))
- return addExp(TensorExp::Kind::kUnary, e, Value(), def);
+ return {addExp(TensorExp::Kind::kUnary, e, Value(), def), hasSpDep};
}
if (auto selop = dyn_cast<sparse_tensor::SelectOp>(def)) {
if (isAdmissibleBranch(selop, selop.getRegion()))
- return addExp(TensorExp::Kind::kSelect, e, Value(), def);
+ return {addExp(TensorExp::Kind::kSelect, e, Value(), def), hasSpDep};
}
}
}
@@ -1315,49 +1358,50 @@ std::optional<ExprId> Merger::buildTensorExp(linalg::GenericOp op, Value v) {
// See buildLattices() for an explanation of rejecting certain
// division and shift operations.
if (def->getNumOperands() == 2) {
- const auto x = buildTensorExp(op, def->getOperand(0));
- const auto y = buildTensorExp(op, def->getOperand(1));
+ const auto [x, xDepSp] = buildTensorExp(op, def->getOperand(0));
+ const auto [y, yDepSp] = buildTensorExp(op, def->getOperand(1));
+ bool hasSpDep = xDepSp || yDepSp;
if (x.has_value() && y.has_value()) {
const ExprId e0 = *x;
const ExprId e1 = *y;
if (isa<arith::MulFOp>(def))
- return addExp(TensorExp::Kind::kMulF, e0, e1);
+ return {addExp(TensorExp::Kind::kMulF, e0, e1), hasSpDep};
if (isa<complex::MulOp>(def))
- return addExp(TensorExp::Kind::kMulC, e0, e1);
+ return {addExp(TensorExp::Kind::kMulC, e0, e1), hasSpDep};
if (isa<arith::MulIOp>(def))
- return addExp(TensorExp::Kind::kMulI, e0, e1);
+ return {addExp(TensorExp::Kind::kMulI, e0, e1), hasSpDep};
if (isa<arith::DivFOp>(def) && !maybeZero(e1))
- return addExp(TensorExp::Kind::kDivF, e0, e1);
+ return {addExp(TensorExp::Kind::kDivF, e0, e1), hasSpDep};
if (isa<complex::DivOp>(def) && !maybeZero(e1))
- return addExp(TensorExp::Kind::kDivC, e0, e1);
+ return {addExp(TensorExp::Kind::kDivC, e0, e1), hasSpDep};
if (isa<arith::DivSIOp>(def) && !maybeZero(e1))
- return addExp(TensorExp::Kind::kDivS, e0, e1);
+ return {addExp(TensorExp::Kind::kDivS, e0, e1), hasSpDep};
if (isa<arith::DivUIOp>(def) && !maybeZero(e1))
- return addExp(TensorExp::Kind::kDivU, e0, e1);
+ return {addExp(TensorExp::Kind::kDivU, e0, e1), hasSpDep};
if (isa<arith::AddFOp>(def))
- return addExp(TensorExp::Kind::kAddF, e0, e1);
+ return {addExp(TensorExp::Kind::kAddF, e0, e1), hasSpDep};
if (isa<complex::AddOp>(def))
- return addExp(TensorExp::Kind::kAddC, e0, e1);
+ return {addExp(TensorExp::Kind::kAddC, e0, e1), hasSpDep};
if (isa<arith::AddIOp>(def))
- return addExp(TensorExp::Kind::kAddI, e0, e1);
+ return {addExp(TensorExp::Kind::kAddI, e0, e1), hasSpDep};
if (isa<arith::SubFOp>(def))
- return addExp(TensorExp::Kind::kSubF, e0, e1);
+ return {addExp(TensorExp::Kind::kSubF, e0, e1), hasSpDep};
if (isa<complex::SubOp>(def))
- return addExp(TensorExp::Kind::kSubC, e0, e1);
+ return {addExp(TensorExp::Kind::kSubC, e0, e1), hasSpDep};
if (isa<arith::SubIOp>(def))
- return addExp(TensorExp::Kind::kSubI, e0, e1);
+ return {addExp(TensorExp::Kind::kSubI, e0, e1), hasSpDep};
if (isa<arith::AndIOp>(def))
- return addExp(TensorExp::Kind::kAndI, e0, e1);
+ return {addExp(TensorExp::Kind::kAndI, e0, e1), hasSpDep};
if (isa<arith::OrIOp>(def))
- return addExp(TensorExp::Kind::kOrI, e0, e1);
+ return {addExp(TensorExp::Kind::kOrI, e0, e1), hasSpDep};
if (isa<arith::XOrIOp>(def))
- return addExp(TensorExp::Kind::kXorI, e0, e1);
+ return {addExp(TensorExp::Kind::kXorI, e0, e1), hasSpDep};
if (isa<arith::ShRSIOp>(def) && isInvariant(e1))
- return addExp(TensorExp::Kind::kShrS, e0, e1);
+ return {addExp(TensorExp::Kind::kShrS, e0, e1), hasSpDep};
if (isa<arith::ShRUIOp>(def) && isInvariant(e1))
- return addExp(TensorExp::Kind::kShrU, e0, e1);
+ return {addExp(TensorExp::Kind::kShrU, e0, e1), hasSpDep};
if (isa<arith::ShLIOp>(def) && isInvariant(e1))
- return addExp(TensorExp::Kind::kShlI, e0, e1);
+ return {addExp(TensorExp::Kind::kShlI, e0, e1), hasSpDep};
if (auto ci = dyn_cast<arith::CmpIOp>(def)) {
if (ci.getPredicate() == arith::CmpIPredicate::eq &&
ci.getPredicate() == arith::CmpIPredicate::sle &&
@@ -1366,11 +1410,12 @@ std::optional<ExprId> Merger::buildTensorExp(linalg::GenericOp op, Value v) {
ci.getPredicate() == arith::CmpIPredicate::uge) {
// We cannot sparsify comparisons that include equality, because 0 <= 0
// yields true and thus densifies the result.
- return std::nullopt;
+ return {std::nullopt, false};
}
- return addExp(TensorExp::Kind::kCmpI, e0, e1, nullptr,
- ci.getPredicateAttr());
+ auto e = addExp(TensorExp::Kind::kCmpI, e0, e1, nullptr,
+ ci.getPredicateAttr());
+ return {e, hasSpDep};
}
if (auto cf = dyn_cast<arith::CmpFOp>(def)) {
if (cf.getPredicate() == arith::CmpFPredicate::OEQ &&
@@ -1384,10 +1429,11 @@ std::optional<ExprId> Merger::buildTensorExp(linalg::GenericOp op, Value v) {
cf.getPredicate() == arith::CmpFPredicate::UNO) {
// We cannot sparsify comparisons that include equality, because 0 <= 0
// yields true and thus densifies the result.
- return std::nullopt;
+ return {std::nullopt, false};
}
- return addExp(TensorExp::Kind::kCmpF, e0, e1, nullptr,
- cf.getPredicateAttr());
+ auto e = addExp(TensorExp::Kind::kCmpF, e0, e1, nullptr,
+ cf.getPredicateAttr());
+ return {e, hasSpDep};
}
if (auto binop = dyn_cast<sparse_tensor::BinaryOp>(def)) {
if (isAdmissibleBranch(binop, binop.getOverlapRegion()) &&
@@ -1395,26 +1441,54 @@ std::optional<ExprId> Merger::buildTensorExp(linalg::GenericOp op, Value v) {
isAdmissibleBranch(binop, binop.getLeftRegion())) &&
(binop.getRightIdentity() ||
isAdmissibleBranch(binop, binop.getRightRegion())))
- return addExp(TensorExp::Kind::kBinary, e0, e1, def);
+ return {addExp(TensorExp::Kind::kBinary, e0, e1, def), hasSpDep};
}
}
}
// Construct ternary operations if subexpressions can be built.
if (def->getNumOperands() == 3) {
- const auto x = buildTensorExp(op, def->getOperand(0));
- const auto y = buildTensorExp(op, def->getOperand(1));
- const auto z = buildTensorExp(op, def->getOperand(2));
+ const auto [x, xDepSp] = buildTensorExp(op, def->getOperand(0));
+ const auto [y, yDepSp] = buildTensorExp(op, def->getOperand(1));
+ const auto [z, zDepSp] = buildTensorExp(op, def->getOperand(2));
+ bool hasSpDep = xDepSp || yDepSp || zDepSp;
if (x.has_value() && y.has_value() && z.has_value()) {
const ExprId e0 = *x;
const ExprId e1 = *y;
if (auto redop = dyn_cast<sparse_tensor::ReduceOp>(def)) {
if (isAdmissibleBranch(redop, redop.getRegion()))
- return addExp(TensorExp::Kind::kReduce, e0, e1, def);
+ return {addExp(TensorExp::Kind::kReduce, e0, e1, def), hasSpDep};
}
}
}
+
+ // If we reach here, we are dealing with an operation that is not currently
+ // sparsifiable. We can still generate code for it if all its operands only
+ // have dense dependencies (i.e., all the values are loaded from dense
+ // tensors).
+ if (def->getNumResults() != 1) // only handle single-result operations.
+ return {std::nullopt, false};
+
+ SmallVector<std::pair<std::optional<ExprId>, bool>, 2> subExp;
+ // Build all the subexpressions.
+ for (Value operand : def->getOperands())
+ subExp.push_back(buildTensorExp(op, operand));
+
+ if (llvm::all_of(subExp,
+ [](auto e) { return e.first.has_value() && !e.second; })) {
+ // All the subexpressions can be built and have *no* sparse dependencies.
+ if (subExp.size() == 2) {
+ auto e = addExp(TensorExp::Kind::kDenseOp, *subExp[0].first,
+ *subExp[1].first, def);
+ return {e, false};
+ }
+ if (subExp.size() == 1) {
+ auto e = addExp(TensorExp::Kind::kDenseOp, *subExp[0].first,
+ detail::kInvalidId, def);
+ return {e, false};
+ }
+ }
// Cannot build.
- return std::nullopt;
+ return {std::nullopt, false};
}
static Value insertYieldOp(RewriterBase &rewriter, Location loc, Region &region,
@@ -1609,6 +1683,14 @@ Value Merger::buildExp(RewriterBase &rewriter, Location loc, ExprId e, Value v0,
ReduceOp redOp = cast<ReduceOp>(expr.op);
return insertYieldOp(rewriter, loc, redOp.getRegion(), {v0, v1});
}
+ case TensorExp::Kind::kDenseOp: {
+ Operation *actualOp = expr.op;
+ IRMapping mapping;
+ mapping.map(actualOp->getOperand(0), v0);
+ if (actualOp->getNumOperands() == 2)
+ mapping.map(actualOp->getOperand(1), v1);
+ return rewriter.clone(*actualOp, mapping)->getResult(0);
+ }
}
llvm_unreachable("unexpected expression kind in build");
}
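The code-generation side (the kDenseOp case in buildExp above) simply
re-materializes the original dense operation on the current loop-body values.
A compressed sketch of that cloning idiom, assuming a RewriterBase and the
already-remapped operand values v0/v1:

  IRMapping mapping; // maps the op's old SSA operands to the new values
  mapping.map(actualOp->getOperand(0), v0);
  if (actualOp->getNumOperands() == 2)
    mapping.map(actualOp->getOperand(1), v1);
  // clone() copies the op (attributes and result type included) while
  // substituting the mapped operands.
  Value result = rewriter.clone(*actualOp, mapping)->getResult(0);

Cloning through IRMapping is what makes the approach generic: any single-result
operation is reproduced unchanged, with only its operands rewired.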
diff --git a/mlir/test/Dialect/SparseTensor/unsparsifiable_dense_op.mlir b/mlir/test/Dialect/SparseTensor/unsparsifiable_dense_op.mlir
new file mode 100644
index 0000000000000..4bb66642ffd61
--- /dev/null
+++ b/mlir/test/Dialect/SparseTensor/unsparsifiable_dense_op.mlir
@@ -0,0 +1,95 @@
+// RUN: mlir-opt %s -sparsification | FileCheck %s
+
+#trait = {
+ indexing_maps = [
+ affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>,
+ affine_map<(d0, d1, d2, d3) -> (d0, d1, 0)>,
+ affine_map<(d0, d1, d2, d3) -> (d0, d1, 0)>,
+ affine_map<(d0, d1, d2, d3) -> (d0, d1, 0)>,
+ affine_map<(d0, d1, d2, d3) -> (d3)>,
+ affine_map<(d0, d1, d2, d3) -> (d3)>,
+ affine_map<(d0, d1, d2, d3) -> (d2, d3)>,
+ affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
+ ],
+ iterator_types = ["parallel", "parallel", "parallel", "reduction"]
+}
+
+#VEC = #sparse_tensor.encoding<{ lvlTypes = [ "compressed" ], posWidth = 32, crdWidth = 32 }>
+#COO = #sparse_tensor.encoding<{ lvlTypes = [ "compressed-nu", "singleton" ], posWidth = 32, crdWidth = 32 }>
+#CCC = #sparse_tensor.encoding<{ lvlTypes = [ "compressed", "compressed", "compressed" ], posWidth = 32, crdWidth = 32 }>
+
+//
+// This kernel can be sparsified because the operands of all
+// un-sparsifiable operations are loaded from dense tensors.
+//
+// CHECK-LABEL: func @dense_op_without_sp_dep
+// CHECK-NOT: linalg.generic {{.*}}
+func.func @dense_op_without_sp_dep(%169: tensor<2x10x8xf32>,
+ %expanded_54: tensor<2x10x1xf32>,
+ %expanded_56: tensor<2x10x1xf32>,
+ %expanded_57: tensor<2x10x1xf32>,
+ %176: tensor<8xf32, #VEC>,
+ %177: tensor<8xf32, #VEC>,
+ %9: tensor<100x8xf32, #COO>) -> tensor<2x10x100xf32> {
+ %cst_13 = arith.constant -3.40282347E+38 : f32
+ %178 = tensor.empty() : tensor<2x10x100xf32>
+ %179 = linalg.generic #trait
+ ins(%169, %expanded_54, %expanded_56, %expanded_57, %176, %177, %9 :
+ tensor<2x10x8xf32>, tensor<2x10x1xf32>, tensor<2x10x1xf32>, tensor<2x10x1xf32>,
+ tensor<8xf32, #VEC>, tensor<8xf32, #VEC>, tensor<100x8xf32, #COO>)
+ outs(%178 : tensor<2x10x100xf32>) {
+ ^bb0(%in: f32, %in_58: f32, %in_59: f32, %in_60: f32, %in_61: f32, %in_62: f32, %in_63: f32, %out: f32):
+ %180 = arith.mulf %in_60, %in_60 : f32
+ %181 = arith.mulf %in_59, %cst_13 : f32
+ %182 = arith.subf %181, %180 : f32
+ %183 = arith.maxf %182, %cst_13 : f32
+ %184 = arith.addf %183, %cst_13 : f32
+ %185 = math.rsqrt %184 : f32 // un-sparsifiable op; operands derive only from dense inputs.
+ %186 = arith.mulf %185, %in_61 : f32
+ %187 = arith.subf %in, %in_58 : f32
+ %188 = arith.mulf %187, %186 : f32
+ %189 = arith.addf %188, %in_62 : f32
+ %190 = arith.mulf %189, %in_63 : f32
+ %191 = arith.addf %out, %190 : f32
+ linalg.yield %191 : f32
+ } -> tensor<2x10x100xf32>
+ return %179 : tensor<2x10x100xf32>
+}
+
+//
+// This kernel cannot be sparsified because some un-sparsifiable operations
+// have operands that are loaded from sparse tensors.
+//
+// CHECK-LABEL: func @dense_op_with_sp_dep
+// CHECK: linalg.generic {{.*}}
+func.func @dense_op_with_sp_dep(%169: tensor<2x10x8xf32>,
+ %expanded_54: tensor<2x10x1xf32, #CCC>,
+ %expanded_56: tensor<2x10x1xf32, #CCC>,
+ %expanded_57: tensor<2x10x1xf32, #CCC>,
+ %176: tensor<8xf32, #VEC>,
+ %177: tensor<8xf32, #VEC>,
+ %9: tensor<100x8xf32, #COO>) -> tensor<2x10x100xf32> {
+ %cst_13 = arith.constant -3.40282347E+38 : f32
+ %178 = tensor.empty() : tensor<2x10x100xf32>
+ %179 = linalg.generic #trait
+ ins(%169, %expanded_54, %expanded_56, %expanded_57, %176, %177, %9 :
+ tensor<2x10x8xf32>, tensor<2x10x1xf32, #CCC>, tensor<2x10x1xf32, #CCC>, tensor<2x10x1xf32, #CCC>,
+ tensor<8xf32, #VEC>, tensor<8xf32, #VEC>, tensor<100x8xf32, #COO>)
+ outs(%178 : tensor<2x10x100xf32>) {
+ ^bb0(%in: f32, %in_58: f32, %in_59: f32, %in_60: f32, %in_61: f32, %in_62: f32, %in_63: f32, %out: f32):
+ %180 = arith.mulf %in_60, %in_60 : f32
+ %181 = arith.mulf %in_59, %cst_13 : f32
+ %182 = arith.subf %181, %180 : f32
+ %183 = arith.maxf %182, %cst_13 : f32
+ %184 = arith.addf %183, %cst_13 : f32
+ %185 = math.rsqrt %184 : f32 // data dependent on sparse values.
+ %186 = arith.mulf %185, %in_61 : f32
+ %187 = arith.subf %in, %in_58 : f32
+ %188 = arith.mulf %187, %186 : f32
+ %189 = arith.addf %188, %in_62 : f32
+ %190 = arith.mulf %189, %in_63 : f32
+ %191 = arith.addf %out, %190 : f32
+ linalg.yield %191 : f32
+ } -> tensor<2x10x100xf32>
+ return %179 : tensor<2x10x100xf32>
+}
diff --git a/mlir/unittests/Dialect/SparseTensor/MergerTest.cpp b/mlir/unittests/Dialect/SparseTensor/MergerTest.cpp
index b854ce8d7aa8a..00760c02bb63e 100644
--- a/mlir/unittests/Dialect/SparseTensor/MergerTest.cpp
+++ b/mlir/unittests/Dialect/SparseTensor/MergerTest.cpp
@@ -305,6 +305,12 @@ class MergerTestBase : public ::testing::Test {
case TensorExp::Kind::kReduce:
return compareExpression(tensorExp.children.e0, pattern.children.e0) &&
compareExpression(tensorExp.children.e1, pattern.children.e1);
+ case TensorExp::Kind::kDenseOp: {
+ bool eq = compareExpression(tensorExp.children.e0, pattern.children.e0);
+ if (eq && tensorExp.children.e1 != sparse_tensor::detail::kInvalidId)
+ return compareExpression(tensorExp.children.e1, pattern.children.e1);
+ return eq;
+ }
}
llvm_unreachable("unexpected kind");
}