[Mlir-commits] [mlir] df11a2b - [mlir][sparse] admit un-sparsifiable operations if all its operands are loaded from dense input

Peiming Liu llvmlistbot at llvm.org
Wed Jun 28 14:27:55 PDT 2023


Author: Peiming Liu
Date: 2023-06-28T21:27:50Z
New Revision: df11a2b41ad490d0ac752c83f12893128c3632c2

URL: https://github.com/llvm/llvm-project/commit/df11a2b41ad490d0ac752c83f12893128c3632c2
DIFF: https://github.com/llvm/llvm-project/commit/df11a2b41ad490d0ac752c83f12893128c3632c2.diff

LOG: [mlir][sparse] admit un-sparsifiable operations if all their operands are loaded from dense inputs

Reviewed By: aartbik

Differential Revision: https://reviews.llvm.org/D153998

Added: 
    mlir/test/Dialect/SparseTensor/unsparsifiable_dense_op.mlir

Modified: 
    mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h
    mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp
    mlir/unittests/Dialect/SparseTensor/MergerTest.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h b/mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h
index e166da529c14d..3ea8ce721d6e5 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h
+++ b/mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h
@@ -70,6 +70,9 @@ struct TensorExp final {
   /// and kSelect, this holds the original operation with all regions. For
   /// kBinaryBranch, this holds the YieldOp for the left or right half
   /// to be merged into a nested scf loop.
+  ///
+  /// For kDenseOp, this holds the actual operation that cannot be sparsified
+  /// but whose operands do not depend on any sparse tensor.
   Operation *op;
 
   /// An optional attribute that is required to determine the semantics of the
@@ -157,8 +160,9 @@ enum class TensorExp::Kind {
   kShrS, // signed
   kShrU, // unsigned
   kShlI,
-  kBinary, // semiring binary op
-  kReduce, // semiring reduction op
+  kBinary,  // semiring binary op
+  kReduce,  // semiring reduction op
+  kDenseOp, // special category of operations requiring all dense operands
 };
 
 //===----------------------------------------------------------------------===//
@@ -645,7 +649,11 @@ class Merger {
   Type inferType(ExprId e, Value src) const;
 
   /// Traverses the SSA tree (possibly a DAG) to build a tensor expression.
-  std::optional<ExprId> buildTensorExp(linalg::GenericOp op, Value v);
+  /// The boolean value returned indicates whether the result of the current
+  /// operation being built depends on any value that is loaded from a sparse
+  /// tensor.
+  std::pair<std::optional<ExprId>, bool> buildTensorExp(linalg::GenericOp op,
+                                                        Value v);
 
   /// Merger data structures.
   const TensorId outTensor;

diff  --git a/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp b/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp
index fc87c8413c36f..f39a2069a57dd 100644
--- a/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp
@@ -92,6 +92,7 @@ static ExpArity getExpArity(TensorExp::Kind k) {
   case TensorExp::Kind::kSubI:
   case TensorExp::Kind::kCmpF:
   case TensorExp::Kind::kCmpI:
+  case TensorExp::Kind::kDenseOp: // kDenseOp can *at most* have two operands
     return ExpArity::kBinary;
   }
   llvm_unreachable("unexpected kind");
@@ -210,6 +211,11 @@ TensorExp::TensorExp(TensorExp::Kind k, unsigned x, ExprId y, Value v,
     children.e0 = x;
     children.e1 = y;
     return;
+  case TensorExp::Kind::kDenseOp:
+    assert(x != detail::kInvalidId && !v && o);
+    children.e0 = x;
+    children.e1 = y;
+    return;
   }
   llvm_unreachable("unexpected kind");
 }
@@ -393,7 +399,8 @@ LatSetId Merger::combiSet(ExprId e, LatSetId s0, LatSetId s1, Operation *orig,
 
 LatSetId Merger::mapSet(TensorExp::Kind kind, LatSetId s0, Value v,
                         Operation *op) {
-  assert(TensorExp::Kind::kAbsF <= kind && kind <= TensorExp::Kind::kSelect);
+  assert((TensorExp::Kind::kAbsF <= kind && kind <= TensorExp::Kind::kSelect) ||
+         TensorExp::Kind::kDenseOp == kind);
   const LatSetId sNew = addSet();
   auto &setNew = latSets[sNew];
   for (const LatPointId p : set(s0)) {
@@ -546,6 +553,12 @@ bool Merger::hasNegateOnOut(ExprId e) const {
   case TensorExp::Kind::kSubI:
     return expContainsTensor(expr.children.e1, outTensor) ||
            hasNegateOnOut(expr.children.e0);
+  case TensorExp::Kind::kDenseOp: {
+    bool lhsNeg = hasNegateOnOut(expr.children.e0);
+    if (!lhsNeg && expr.children.e1 != detail::kInvalidId)
+      return hasNegateOnOut(expr.children.e1);
+    return lhsNeg;
+  }
   default: {
     switch (getExpArity(expr.kind)) {
     case ExpArity::kNullary:
@@ -646,6 +659,10 @@ bool Merger::isSingleCondition(TensorId t, ExprId e) const {
   case TensorExp::Kind::kCmpI:
   case TensorExp::Kind::kBinary:
     return false;
+  case TensorExp::Kind::kDenseOp:
+    // Since Merger guarantees all the operands of the kDenseOp to be dense, the
+    // operation must be single-condition.
+    return true;
   }
   llvm_unreachable("unexpected kind");
 }
@@ -771,6 +788,8 @@ static const char *kindToOpSymbol(TensorExp::Kind kind) {
     return "binary";
   case TensorExp::Kind::kReduce:
     return "reduce";
+  case TensorExp::Kind::kDenseOp:
+    return "dense";
   }
   llvm_unreachable("unexpected kind for symbol");
 }
@@ -857,14 +876,19 @@ void Merger::dumpExp(ExprId e) const {
   case TensorExp::Kind::kCmpI:
   case TensorExp::Kind::kBinary:
   case TensorExp::Kind::kReduce:
+  case TensorExp::Kind::kDenseOp:
     llvm::dbgs() << "(";
     dumpExp(expr.children.e0);
     llvm::dbgs() << " " << kindToOpSymbol(expr.kind);
     if (expr.attr)
       llvm::dbgs() << "{" << expr.attr << "}";
-    llvm::dbgs() << " ";
-    dumpExp(expr.children.e1);
-    llvm::dbgs() << ")";
+    if (expr.children.e1 != detail::kInvalidId) {
+      llvm::dbgs() << " ";
+      dumpExp(expr.children.e1);
+    } else {
+      assert(expr.kind == TensorExp::Kind::kDenseOp);
+    }
+    llvm::dbgs() << ")"; // closes both the binary and the unary form
     break;
   }
 }
@@ -1142,6 +1166,21 @@ LatSetId Merger::buildLattices(ExprId e, LoopId i) {
       Operation *const op = expr.op;
       return conjSet(e, buildLattices(e0, i), buildLattices(e1, i), op);
     }
+  case TensorExp::Kind::kDenseOp: {
+    // Whether a conjunctive or disjunctive set is used here does not really
+    // matter: all operands of kDenseOp must be dense, so a disjunctive set
+    // would eventually be optimized into a conjunctive one anyway.
+    if (expr.children.e1 == detail::kInvalidId) {
+      const ExprId e0 = expr.children.e0;
+      Operation *const op = expr.op;
+      return mapSet(kind, buildLattices(e0, i), Value(), op);
+    }
+
+    const ExprId e0 = expr.children.e0;
+    const ExprId e1 = expr.children.e1;
+    Operation *const op = expr.op;
+    return conjSet(e, buildLattices(e0, i), buildLattices(e1, i), op);
+  }
   }
   llvm_unreachable("unexpected expression kind");
 }
@@ -1150,7 +1189,7 @@ std::optional<ExprId> Merger::buildTensorExpFromLinalg(linalg::GenericOp op) {
   // Build the linalg semantics backward from yield.
   Operation *yield = op.getRegion().front().getTerminator();
   assert(isa<linalg::YieldOp>(yield));
-  return buildTensorExp(op, yield->getOperand(0));
+  return buildTensorExp(op, yield->getOperand(0)).first;
 }
 
 /// Only returns false if we are certain this is a nonzero.
@@ -1210,7 +1249,9 @@ static bool isAdmissibleBranch(Operation *op, Region &region) {
   return isAdmissibleBranchExp(op, &region.front(), yield->getOperand(0));
 }
 
-std::optional<ExprId> Merger::buildTensorExp(linalg::GenericOp op, Value v) {
+std::pair<std::optional<ExprId>, bool>
+Merger::buildTensorExp(linalg::GenericOp op, Value v) {
+  // Recursion leaves.
   if (auto arg = dyn_cast<BlockArgument>(v)) {
     const TensorId tid = makeTensorId(arg.getArgNumber());
     // Any argument of the generic op that is not marked as a scalar
@@ -1218,96 +1259,98 @@ std::optional<ExprId> Merger::buildTensorExp(linalg::GenericOp op, Value v) {
     // bounds. This includes rank-0 tensor arguments.
     if (arg.getOwner()->getParentOp() == op) {
       OpOperand &t = op->getOpOperand(tid);
+      bool hasSpDep = getSparseTensorEncoding(t.get().getType()) != nullptr;
       if (!op.isScalar(&t))
-        return addTensorExp(tid);
+        return {addTensorExp(tid), hasSpDep};
       v = t.get(); // get scalar value
     }
     // Any other argument (marked as scalar argument for the generic op
     // or belonging to an enveloping op) is considered invariant.
-    return addInvariantExp(v);
+    return {addInvariantExp(v), /*hasSpDep=*/false};
   }
   // Something defined outside is invariant.
   Operation *def = v.getDefiningOp();
   if (def->getBlock() != &op.getRegion().front())
-    return addInvariantExp(v);
+    return {addInvariantExp(v), /*hasSpDep=*/false};
   // Construct index operations.
   if (def->getNumOperands() == 0) {
     if (auto indexOp = dyn_cast<linalg::IndexOp>(def))
-      return addLoopVarExp(makeLoopId(indexOp.getDim()));
+      return {addLoopVarExp(makeLoopId(indexOp.getDim())), /*hasSpDep=*/false};
   }
+
   // Construct unary operations if subexpression can be built.
   if (def->getNumOperands() == 1) {
-    const auto x = buildTensorExp(op, def->getOperand(0));
+    const auto [x, hasSpDep] = buildTensorExp(op, def->getOperand(0));
     if (x.has_value()) {
       const ExprId e = *x;
       if (isa<math::AbsFOp>(def))
-        return addExp(TensorExp::Kind::kAbsF, e);
+        return {addExp(TensorExp::Kind::kAbsF, e), hasSpDep};
       if (isa<complex::AbsOp>(def))
-        return addExp(TensorExp::Kind::kAbsC, e);
+        return {addExp(TensorExp::Kind::kAbsC, e), hasSpDep};
       if (isa<math::AbsIOp>(def))
-        return addExp(TensorExp::Kind::kAbsI, e);
+        return {addExp(TensorExp::Kind::kAbsI, e), hasSpDep};
       if (isa<math::CeilOp>(def))
-        return addExp(TensorExp::Kind::kCeilF, e);
+        return {addExp(TensorExp::Kind::kCeilF, e), hasSpDep};
       if (isa<math::FloorOp>(def))
-        return addExp(TensorExp::Kind::kFloorF, e);
+        return {addExp(TensorExp::Kind::kFloorF, e), hasSpDep};
       if (isa<math::SqrtOp>(def))
-        return addExp(TensorExp::Kind::kSqrtF, e);
+        return {addExp(TensorExp::Kind::kSqrtF, e), hasSpDep};
       if (isa<complex::SqrtOp>(def))
-        return addExp(TensorExp::Kind::kSqrtC, e);
+        return {addExp(TensorExp::Kind::kSqrtC, e), hasSpDep};
       if (isa<math::ExpM1Op>(def))
-        return addExp(TensorExp::Kind::kExpm1F, e);
+        return {addExp(TensorExp::Kind::kExpm1F, e), hasSpDep};
       if (isa<complex::Expm1Op>(def))
-        return addExp(TensorExp::Kind::kExpm1C, e);
+        return {addExp(TensorExp::Kind::kExpm1C, e), hasSpDep};
       if (isa<math::Log1pOp>(def))
-        return addExp(TensorExp::Kind::kLog1pF, e);
+        return {addExp(TensorExp::Kind::kLog1pF, e), hasSpDep};
       if (isa<complex::Log1pOp>(def))
-        return addExp(TensorExp::Kind::kLog1pC, e);
+        return {addExp(TensorExp::Kind::kLog1pC, e), hasSpDep};
       if (isa<math::SinOp>(def))
-        return addExp(TensorExp::Kind::kSinF, e);
+        return {addExp(TensorExp::Kind::kSinF, e), hasSpDep};
       if (isa<complex::SinOp>(def))
-        return addExp(TensorExp::Kind::kSinC, e);
+        return {addExp(TensorExp::Kind::kSinC, e), hasSpDep};
       if (isa<math::TanhOp>(def))
-        return addExp(TensorExp::Kind::kTanhF, e);
+        return {addExp(TensorExp::Kind::kTanhF, e), hasSpDep};
       if (isa<complex::TanhOp>(def))
-        return addExp(TensorExp::Kind::kTanhC, e);
+        return {addExp(TensorExp::Kind::kTanhC, e), hasSpDep};
       if (isa<arith::NegFOp>(def))
-        return addExp(TensorExp::Kind::kNegF, e); // no negi in std
+        return {addExp(TensorExp::Kind::kNegF, e), hasSpDep}; // no negi in std
       if (isa<complex::NegOp>(def))
-        return addExp(TensorExp::Kind::kNegC, e);
+        return {addExp(TensorExp::Kind::kNegC, e), hasSpDep};
       if (isa<arith::TruncFOp>(def))
-        return addExp(TensorExp::Kind::kTruncF, e, v);
+        return {addExp(TensorExp::Kind::kTruncF, e, v), hasSpDep};
       if (isa<arith::ExtFOp>(def))
-        return addExp(TensorExp::Kind::kExtF, e, v);
+        return {addExp(TensorExp::Kind::kExtF, e, v), hasSpDep};
       if (isa<arith::FPToSIOp>(def))
-        return addExp(TensorExp::Kind::kCastFS, e, v);
+        return {addExp(TensorExp::Kind::kCastFS, e, v), hasSpDep};
       if (isa<arith::FPToUIOp>(def))
-        return addExp(TensorExp::Kind::kCastFU, e, v);
+        return {addExp(TensorExp::Kind::kCastFU, e, v), hasSpDep};
       if (isa<arith::SIToFPOp>(def))
-        return addExp(TensorExp::Kind::kCastSF, e, v);
+        return {addExp(TensorExp::Kind::kCastSF, e, v), hasSpDep};
       if (isa<arith::UIToFPOp>(def))
-        return addExp(TensorExp::Kind::kCastUF, e, v);
+        return {addExp(TensorExp::Kind::kCastUF, e, v), hasSpDep};
       if (isa<arith::ExtSIOp>(def))
-        return addExp(TensorExp::Kind::kCastS, e, v);
+        return {addExp(TensorExp::Kind::kCastS, e, v), hasSpDep};
       if (isa<arith::ExtUIOp>(def))
-        return addExp(TensorExp::Kind::kCastU, e, v);
+        return {addExp(TensorExp::Kind::kCastU, e, v), hasSpDep};
       if (isa<arith::IndexCastOp>(def))
-        return addExp(TensorExp::Kind::kCastIdx, e, v);
+        return {addExp(TensorExp::Kind::kCastIdx, e, v), hasSpDep};
       if (isa<arith::TruncIOp>(def))
-        return addExp(TensorExp::Kind::kTruncI, e, v);
+        return {addExp(TensorExp::Kind::kTruncI, e, v), hasSpDep};
       if (isa<complex::ImOp>(def))
-        return addExp(TensorExp::Kind::kCIm, e);
+        return {addExp(TensorExp::Kind::kCIm, e), hasSpDep};
       if (isa<complex::ReOp>(def))
-        return addExp(TensorExp::Kind::kCRe, e);
+        return {addExp(TensorExp::Kind::kCRe, e), hasSpDep};
       if (isa<arith::BitcastOp>(def))
-        return addExp(TensorExp::Kind::kBitCast, e, v);
+        return {addExp(TensorExp::Kind::kBitCast, e, v), hasSpDep};
       if (auto unop = dyn_cast<sparse_tensor::UnaryOp>(def)) {
         if (isAdmissibleBranch(unop, unop.getPresentRegion()) &&
             isAdmissibleBranch(unop, unop.getAbsentRegion()))
-          return addExp(TensorExp::Kind::kUnary, e, Value(), def);
+          return {addExp(TensorExp::Kind::kUnary, e, Value(), def), hasSpDep};
       }
       if (auto selop = dyn_cast<sparse_tensor::SelectOp>(def)) {
         if (isAdmissibleBranch(selop, selop.getRegion()))
-          return addExp(TensorExp::Kind::kSelect, e, Value(), def);
+          return {addExp(TensorExp::Kind::kSelect, e, Value(), def), hasSpDep};
       }
     }
   }
@@ -1315,49 +1358,50 @@ std::optional<ExprId> Merger::buildTensorExp(linalg::GenericOp op, Value v) {
   // See buildLattices() for an explanation of rejecting certain
   // division and shift operations.
   if (def->getNumOperands() == 2) {
-    const auto x = buildTensorExp(op, def->getOperand(0));
-    const auto y = buildTensorExp(op, def->getOperand(1));
+    const auto [x, xDepSp] = buildTensorExp(op, def->getOperand(0));
+    const auto [y, yDepSp] = buildTensorExp(op, def->getOperand(1));
+    bool hasSpDep = xDepSp || yDepSp;
     if (x.has_value() && y.has_value()) {
       const ExprId e0 = *x;
       const ExprId e1 = *y;
       if (isa<arith::MulFOp>(def))
-        return addExp(TensorExp::Kind::kMulF, e0, e1);
+        return {addExp(TensorExp::Kind::kMulF, e0, e1), hasSpDep};
       if (isa<complex::MulOp>(def))
-        return addExp(TensorExp::Kind::kMulC, e0, e1);
+        return {addExp(TensorExp::Kind::kMulC, e0, e1), hasSpDep};
       if (isa<arith::MulIOp>(def))
-        return addExp(TensorExp::Kind::kMulI, e0, e1);
+        return {addExp(TensorExp::Kind::kMulI, e0, e1), hasSpDep};
       if (isa<arith::DivFOp>(def) && !maybeZero(e1))
-        return addExp(TensorExp::Kind::kDivF, e0, e1);
+        return {addExp(TensorExp::Kind::kDivF, e0, e1), hasSpDep};
       if (isa<complex::DivOp>(def) && !maybeZero(e1))
-        return addExp(TensorExp::Kind::kDivC, e0, e1);
+        return {addExp(TensorExp::Kind::kDivC, e0, e1), hasSpDep};
       if (isa<arith::DivSIOp>(def) && !maybeZero(e1))
-        return addExp(TensorExp::Kind::kDivS, e0, e1);
+        return {addExp(TensorExp::Kind::kDivS, e0, e1), hasSpDep};
       if (isa<arith::DivUIOp>(def) && !maybeZero(e1))
-        return addExp(TensorExp::Kind::kDivU, e0, e1);
+        return {addExp(TensorExp::Kind::kDivU, e0, e1), hasSpDep};
       if (isa<arith::AddFOp>(def))
-        return addExp(TensorExp::Kind::kAddF, e0, e1);
+        return {addExp(TensorExp::Kind::kAddF, e0, e1), hasSpDep};
       if (isa<complex::AddOp>(def))
-        return addExp(TensorExp::Kind::kAddC, e0, e1);
+        return {addExp(TensorExp::Kind::kAddC, e0, e1), hasSpDep};
       if (isa<arith::AddIOp>(def))
-        return addExp(TensorExp::Kind::kAddI, e0, e1);
+        return {addExp(TensorExp::Kind::kAddI, e0, e1), hasSpDep};
       if (isa<arith::SubFOp>(def))
-        return addExp(TensorExp::Kind::kSubF, e0, e1);
+        return {addExp(TensorExp::Kind::kSubF, e0, e1), hasSpDep};
       if (isa<complex::SubOp>(def))
-        return addExp(TensorExp::Kind::kSubC, e0, e1);
+        return {addExp(TensorExp::Kind::kSubC, e0, e1), hasSpDep};
       if (isa<arith::SubIOp>(def))
-        return addExp(TensorExp::Kind::kSubI, e0, e1);
+        return {addExp(TensorExp::Kind::kSubI, e0, e1), hasSpDep};
       if (isa<arith::AndIOp>(def))
-        return addExp(TensorExp::Kind::kAndI, e0, e1);
+        return {addExp(TensorExp::Kind::kAndI, e0, e1), hasSpDep};
       if (isa<arith::OrIOp>(def))
-        return addExp(TensorExp::Kind::kOrI, e0, e1);
+        return {addExp(TensorExp::Kind::kOrI, e0, e1), hasSpDep};
       if (isa<arith::XOrIOp>(def))
-        return addExp(TensorExp::Kind::kXorI, e0, e1);
+        return {addExp(TensorExp::Kind::kXorI, e0, e1), hasSpDep};
       if (isa<arith::ShRSIOp>(def) && isInvariant(e1))
-        return addExp(TensorExp::Kind::kShrS, e0, e1);
+        return {addExp(TensorExp::Kind::kShrS, e0, e1), hasSpDep};
       if (isa<arith::ShRUIOp>(def) && isInvariant(e1))
-        return addExp(TensorExp::Kind::kShrU, e0, e1);
+        return {addExp(TensorExp::Kind::kShrU, e0, e1), hasSpDep};
       if (isa<arith::ShLIOp>(def) && isInvariant(e1))
-        return addExp(TensorExp::Kind::kShlI, e0, e1);
+        return {addExp(TensorExp::Kind::kShlI, e0, e1), hasSpDep};
       if (auto ci = dyn_cast<arith::CmpIOp>(def)) {
         if (ci.getPredicate() == arith::CmpIPredicate::eq &&
             ci.getPredicate() == arith::CmpIPredicate::sle &&
@@ -1366,11 +1410,12 @@ std::optional<ExprId> Merger::buildTensorExp(linalg::GenericOp op, Value v) {
             ci.getPredicate() == arith::CmpIPredicate::uge) {
           // We can not sparsify comparison with equal, this is because 0 <= 0
           // yields true, and thus densifies the result.
-          return std::nullopt;
+          return {std::nullopt, false};
         }
 
-        return addExp(TensorExp::Kind::kCmpI, e0, e1, nullptr,
-                      ci.getPredicateAttr());
+        auto e = addExp(TensorExp::Kind::kCmpI, e0, e1, nullptr,
+                        ci.getPredicateAttr());
+        return {e, hasSpDep};
       }
       if (auto cf = dyn_cast<arith::CmpFOp>(def)) {
         if (cf.getPredicate() == arith::CmpFPredicate::OEQ &&
@@ -1384,10 +1429,11 @@ std::optional<ExprId> Merger::buildTensorExp(linalg::GenericOp op, Value v) {
             cf.getPredicate() == arith::CmpFPredicate::UNO) {
           // We can not sparsify comparison with equal, this is because 0 <= 0
           // yields true, and thus densifies the result.
-          return std::nullopt;
+          return {std::nullopt, false};
         }
-        return addExp(TensorExp::Kind::kCmpF, e0, e1, nullptr,
-                      cf.getPredicateAttr());
+        auto e = addExp(TensorExp::Kind::kCmpF, e0, e1, nullptr,
+                        cf.getPredicateAttr());
+        return {e, hasSpDep};
       }
       if (auto binop = dyn_cast<sparse_tensor::BinaryOp>(def)) {
         if (isAdmissibleBranch(binop, binop.getOverlapRegion()) &&
@@ -1395,26 +1441,54 @@ std::optional<ExprId> Merger::buildTensorExp(linalg::GenericOp op, Value v) {
              isAdmissibleBranch(binop, binop.getLeftRegion())) &&
             (binop.getRightIdentity() ||
              isAdmissibleBranch(binop, binop.getRightRegion())))
-          return addExp(TensorExp::Kind::kBinary, e0, e1, def);
+          return {addExp(TensorExp::Kind::kBinary, e0, e1, def), hasSpDep};
       }
     }
   }
   // Construct ternary operations if subexpressions can be built.
   if (def->getNumOperands() == 3) {
-    const auto x = buildTensorExp(op, def->getOperand(0));
-    const auto y = buildTensorExp(op, def->getOperand(1));
-    const auto z = buildTensorExp(op, def->getOperand(2));
+    const auto [x, xDepSp] = buildTensorExp(op, def->getOperand(0));
+    const auto [y, yDepSp] = buildTensorExp(op, def->getOperand(1));
+    const auto [z, zDepSp] = buildTensorExp(op, def->getOperand(2));
+    bool hasSpDep = xDepSp || yDepSp || zDepSp;
     if (x.has_value() && y.has_value() && z.has_value()) {
       const ExprId e0 = *x;
       const ExprId e1 = *y;
       if (auto redop = dyn_cast<sparse_tensor::ReduceOp>(def)) {
         if (isAdmissibleBranch(redop, redop.getRegion()))
-          return addExp(TensorExp::Kind::kReduce, e0, e1, def);
+          return {addExp(TensorExp::Kind::kReduce, e0, e1, def), hasSpDep};
       }
     }
   }
+
+  // If we reach here, the operation is not currently sparsifiable. We can
+  // still generate code for it if it is unary or binary and none of its
+  // operands depends on a value loaded from a sparse tensor (i.e., all of
+  // them derive from dense tensors, invariants, or loop indices).
+  if (def->getNumResults() != 1) // only handle single-result operations.
+    return {std::nullopt, false};
+
+  SmallVector<std::pair<std::optional<ExprId>, bool>, 2> subExp;
+  // Build all sub-expressions (operands visited above are built again here).
+  for (Value operand : def->getOperands())
+    subExp.push_back(buildTensorExp(op, operand));
+
+  if (llvm::all_of(subExp,
+                   [](auto e) { return e.first.has_value() && !e.second; })) {
+    // All the subexpressions can be built and have *no* sparse dependencies.
+    if (subExp.size() == 2) {
+      auto e = addExp(TensorExp::Kind::kDenseOp, *subExp[0].first,
+                      *subExp[1].first, def);
+      return {e, false};
+    }
+    if (subExp.size() == 1) {
+      auto e = addExp(TensorExp::Kind::kDenseOp, *subExp[0].first,
+                      detail::kInvalidId, def);
+      return {e, false};
+    }
+  }
   // Cannot build.
-  return std::nullopt;
+  return {std::nullopt, false};
 }
 
 static Value insertYieldOp(RewriterBase &rewriter, Location loc, Region &region,
@@ -1609,6 +1683,14 @@ Value Merger::buildExp(RewriterBase &rewriter, Location loc, ExprId e, Value v0,
     ReduceOp redOp = cast<ReduceOp>(expr.op);
     return insertYieldOp(rewriter, loc, redOp.getRegion(), {v0, v1});
   }
+  case TensorExp::Kind::kDenseOp: {
+    Operation *actualOp = expr.op;
+    IRMapping mapping;
+    mapping.map(actualOp->getOperand(0), v0);
+    if (actualOp->getNumOperands() == 2)
+      mapping.map(actualOp->getOperand(1), v1);
+    return rewriter.clone(*actualOp, mapping)->getResult(0);
+  }
   }
   llvm_unreachable("unexpected expression kind in build");
 }

diff  --git a/mlir/test/Dialect/SparseTensor/unsparsifiable_dense_op.mlir b/mlir/test/Dialect/SparseTensor/unsparsifiable_dense_op.mlir
new file mode 100644
index 0000000000000..4bb66642ffd61
--- /dev/null
+++ b/mlir/test/Dialect/SparseTensor/unsparsifiable_dense_op.mlir
@@ -0,0 +1,95 @@
+// RUN: mlir-opt %s -sparsification | FileCheck %s
+
+#trait = {
+  indexing_maps = [
+    affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>,
+    affine_map<(d0, d1, d2, d3) -> (d0, d1, 0)>,
+    affine_map<(d0, d1, d2, d3) -> (d0, d1, 0)>,
+    affine_map<(d0, d1, d2, d3) -> (d0, d1, 0)>,
+    affine_map<(d0, d1, d2, d3) -> (d3)>,
+    affine_map<(d0, d1, d2, d3) -> (d3)>,
+    affine_map<(d0, d1, d2, d3) -> (d2, d3)>,
+    affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
+  ],
+  iterator_types = ["parallel", "parallel", "parallel", "reduction"]
+}
+
+#VEC = #sparse_tensor.encoding<{ lvlTypes = [ "compressed" ], posWidth = 32, crdWidth = 32 }>
+#COO = #sparse_tensor.encoding<{ lvlTypes = [ "compressed-nu", "singleton" ], posWidth = 32, crdWidth = 32 }>
+#CCC = #sparse_tensor.encoding<{ lvlTypes = [ "compressed", "compressed", "compressed" ], posWidth = 32, crdWidth = 32 }>
+
+//
+// This kernel can be sparsified as all unsparsifiable operations'
+// operands are loaded from dense tensors.
+//
+// CHECK-LABEL: func @dense_op_without_sp_dep
+// CHECK-NOT:   linalg.generic {{.*}}
+func.func @dense_op_without_sp_dep(%169: tensor<2x10x8xf32>,
+                                   %expanded_54: tensor<2x10x1xf32>,
+                                   %expanded_56: tensor<2x10x1xf32>,
+                                   %expanded_57: tensor<2x10x1xf32>,
+                                   %176: tensor<8xf32, #VEC>,
+                                   %177: tensor<8xf32, #VEC>,
+                                   %9: tensor<100x8xf32, #COO>) ->  tensor<2x10x100xf32> {
+    %cst_13 = arith.constant -3.40282347E+38 : f32
+    %178 = tensor.empty() : tensor<2x10x100xf32>
+    %179 = linalg.generic #trait
+    ins(%169, %expanded_54, %expanded_56, %expanded_57, %176, %177, %9 :
+        tensor<2x10x8xf32>, tensor<2x10x1xf32>, tensor<2x10x1xf32>, tensor<2x10x1xf32>,
+        tensor<8xf32, #VEC>, tensor<8xf32, #VEC>, tensor<100x8xf32, #COO>)
+    outs(%178 : tensor<2x10x100xf32>) {
+    ^bb0(%in: f32, %in_58: f32, %in_59: f32, %in_60: f32, %in_61: f32, %in_62: f32, %in_63: f32, %out: f32):
+      %180 = arith.mulf %in_60, %in_60 : f32
+      %181 = arith.mulf %in_59, %cst_13 : f32
+      %182 = arith.subf %181, %180 : f32
+      %183 = arith.maxf %182, %cst_13 : f32
+      %184 = arith.addf %183, %cst_13 : f32
+      %185 = math.rsqrt %184 : f32 // operand depends only on dense values.
+      %186 = arith.mulf %185, %in_61 : f32
+      %187 = arith.subf %in, %in_58 : f32
+      %188 = arith.mulf %187, %186 : f32
+      %189 = arith.addf %188, %in_62 : f32
+      %190 = arith.mulf %189, %in_63 : f32
+      %191 = arith.addf %out, %190 : f32
+      linalg.yield %191 : f32
+    } -> tensor<2x10x100xf32>
+   return %179 : tensor<2x10x100xf32>
+}
+
+//
+// This kernel cannot be sparsified as some unsparsifiable operations'
+// operands are loaded from sparse tensors.
+//
+// CHECK-LABEL: func @dense_op_with_sp_dep
+// CHECK:       linalg.generic {{.*}}
+func.func @dense_op_with_sp_dep(%169: tensor<2x10x8xf32>,
+                                %expanded_54: tensor<2x10x1xf32, #CCC>,
+                                %expanded_56: tensor<2x10x1xf32, #CCC>,
+                                %expanded_57: tensor<2x10x1xf32, #CCC>,
+                                %176: tensor<8xf32, #VEC>,
+                                %177: tensor<8xf32, #VEC>,
+                                %9: tensor<100x8xf32, #COO>) ->  tensor<2x10x100xf32> {
+    %cst_13 = arith.constant -3.40282347E+38 : f32
+    %178 = tensor.empty() : tensor<2x10x100xf32>
+    %179 = linalg.generic #trait
+    ins(%169, %expanded_54, %expanded_56, %expanded_57, %176, %177, %9 :
+        tensor<2x10x8xf32>, tensor<2x10x1xf32, #CCC>, tensor<2x10x1xf32, #CCC>, tensor<2x10x1xf32, #CCC>,
+        tensor<8xf32, #VEC>, tensor<8xf32, #VEC>, tensor<100x8xf32, #COO>)
+    outs(%178 : tensor<2x10x100xf32>) {
+    ^bb0(%in: f32, %in_58: f32, %in_59: f32, %in_60: f32, %in_61: f32, %in_62: f32, %in_63: f32, %out: f32):
+      %180 = arith.mulf %in_60, %in_60 : f32
+      %181 = arith.mulf %in_59, %cst_13 : f32
+      %182 = arith.subf %181, %180 : f32
+      %183 = arith.maxf %182, %cst_13 : f32 // data dependent on sparse values.
+      %184 = arith.addf %183, %cst_13 : f32
+      %185 = math.rsqrt %184 : f32
+      %186 = arith.mulf %185, %in_61 : f32
+      %187 = arith.subf %in, %in_58 : f32
+      %188 = arith.mulf %187, %186 : f32
+      %189 = arith.addf %188, %in_62 : f32
+      %190 = arith.mulf %189, %in_63 : f32
+      %191 = arith.addf %out, %190 : f32
+      linalg.yield %191 : f32
+    } -> tensor<2x10x100xf32>
+   return %179 : tensor<2x10x100xf32>
+}

diff  --git a/mlir/unittests/Dialect/SparseTensor/MergerTest.cpp b/mlir/unittests/Dialect/SparseTensor/MergerTest.cpp
index b854ce8d7aa8a..00760c02bb63e 100644
--- a/mlir/unittests/Dialect/SparseTensor/MergerTest.cpp
+++ b/mlir/unittests/Dialect/SparseTensor/MergerTest.cpp
@@ -305,6 +305,12 @@ class MergerTestBase : public ::testing::Test {
     case TensorExp::Kind::kReduce:
       return compareExpression(tensorExp.children.e0, pattern.children.e0) &&
              compareExpression(tensorExp.children.e1, pattern.children.e1);
+    case TensorExp::Kind::kDenseOp: {
+      bool eq = compareExpression(tensorExp.children.e0, pattern.children.e0);
+      if (eq && tensorExp.children.e1 != sparse_tensor::detail::kInvalidId)
+        return compareExpression(tensorExp.children.e1, pattern.children.e1);
+      return eq;
+    }
     }
     llvm_unreachable("unexpected kind");
   }


        


More information about the Mlir-commits mailing list