[Mlir-commits] [mlir] 266a741 - [mlir][sparse] move tensor expression builder into Merger utility

Aart Bik llvmlistbot at llvm.org
Thu Jul 1 09:27:49 PDT 2021


Author: Aart Bik
Date: 2021-07-01T09:27:40-07:00
New Revision: 266a7414d8f2643be2b1dad86693b12a9f1246fa

URL: https://github.com/llvm/llvm-project/commit/266a7414d8f2643be2b1dad86693b12a9f1246fa
DIFF: https://github.com/llvm/llvm-project/commit/266a7414d8f2643be2b1dad86693b12a9f1246fa.diff

LOG: [mlir][sparse] move tensor expression builder into Merger utility

Rationale:
Follow-up on migrating lattice and tensor expression related methods into the new utility.
This also prepares the next step of generalizing the op kinds that are handled.

Reviewed By: gussmith23

Differential Revision: https://reviews.llvm.org/D105219

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h
    mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
    mlir/lib/Dialect/SparseTensor/Utils/CMakeLists.txt
    mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h b/mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h
index cbb0aede83f81..d087e98ac42f3 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h
+++ b/mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h
@@ -13,6 +13,7 @@
 #ifndef MLIR_DIALECT_SPARSETENSOR_UTILS_MERGER_H_
 #define MLIR_DIALECT_SPARSETENSOR_UTILS_MERGER_H_
 
+#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
 #include "mlir/IR/Value.h"
 #include "llvm/ADT/BitVector.h"
 
@@ -148,11 +149,6 @@ class Merger {
   /// Returns true if any set bit corresponds to queried dim.
   bool hasAnyDimOf(const llvm::BitVector &bits, Dim d) const;
 
-  /// Builds the iteration lattices in a bottom-up traversal given the remaining
-  /// tensor (sub)expression and the next loop index in the iteration graph.
-  /// Returns index of the root expression.
-  unsigned buildLattices(unsigned exp, unsigned idx);
-
   /// Setter
   void setDim(unsigned t, unsigned i, Dim d) { dims[t][i] = d; }
 
@@ -169,7 +165,19 @@ class Merger {
   void dumpBits(const llvm::BitVector &bits) const;
 #endif
 
+  /// Builds the iteration lattices in a bottom-up traversal given the remaining
+  /// tensor (sub)expression and the next loop index in the iteration graph.
+  /// Returns index of the root expression.
+  unsigned buildLattices(unsigned exp, unsigned idx);
+
+  /// Builds a tensor expression from the given Linalg operation.
+  /// Returns index of the root expression on success.
+  Optional<unsigned> buildTensorExpFromLinalg(linalg::GenericOp op);
+
 private:
+  /// Traverses the SSA tree (possibly a DAG) to build a tensor expression.
+  Optional<unsigned> buildTensorExp(linalg::GenericOp op, Value val);
+
   const unsigned outTensor;
   const unsigned syntheticTensor;
   const unsigned numTensors;

diff  --git a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
index dd8d4967f1325..0409a7eabdfb7 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
@@ -208,51 +208,6 @@ static bool computeIterationGraph(Merger &merger, linalg::GenericOp op,
   return true;
 }
 
-/// Traverses the SSA tree (possibly a DAG) to build a tensor expression.
-/// This simplifies constructing (sub)expressions during iteration lattice
-/// building (compared to using the SSA representation everywhere).
-static Optional<unsigned> buildTensorExp(Merger &merger, linalg::GenericOp op,
-                                         Value val) {
-  if (auto arg = val.dyn_cast<BlockArgument>()) {
-    unsigned argN = arg.getArgNumber();
-    // Any argument of the generic op that is not marked as a scalar
-    // argument is considered a tensor, indexed by the implicit loop
-    // bounds. This includes rank-0 tensor arguments.
-    if (arg.getOwner()->getParentOp() == op) {
-      OpOperand *t = op.getInputAndOutputOperands()[argN];
-      if (!op.isScalar(t))
-        return merger.addExp(Kind::kTensor, argN);
-      val = t->get(); // get scalar value
-    }
-    // Any other argument (marked as scalar argument for the generic op
-    // or belonging to an enveloping op) is considered invariant.
-    return merger.addExp(Kind::kInvariant, val);
-  }
-  Operation *def = val.getDefiningOp();
-  if (def->getBlock() != &op.region().front()) {
-    // Something defined outside is invariant.
-    return merger.addExp(Kind::kInvariant, val);
-  } else if (def->getNumOperands() == 2) {
-    // Construct binary operations if subexpressions could be built.
-    auto x = buildTensorExp(merger, op, def->getOperand(0));
-    auto y = buildTensorExp(merger, op, def->getOperand(1));
-    if (x.hasValue() && y.hasValue()) {
-      unsigned e0 = x.getValue();
-      unsigned e1 = y.getValue();
-      if (isa<MulFOp>(def))
-        return merger.addExp(Kind::kMulF, e0, e1);
-      if (isa<MulIOp>(def))
-        return merger.addExp(Kind::kMulI, e0, e1);
-      if (isa<AddFOp>(def))
-        return merger.addExp(Kind::kAddF, e0, e1);
-      if (isa<AddIOp>(def))
-        return merger.addExp(Kind::kAddI, e0, e1);
-    }
-  }
-  // Cannot build (yet).
-  return None;
-}
-
 /// Returns true if given tensor co-iterates with conjunction only.
 /// For the output tensor, this defines a "simply dynamic" operation.
 /// For instance: A(I) = A(I) * B(I) * C(I)
@@ -1224,14 +1179,12 @@ struct GenericOpSparsifier : public OpRewritePattern<linalg::GenericOp> {
         !computeIterationGraph(merger, op, topSort, /*sparseOnly=*/true))
       return failure();
 
-    // Finds the terminating yield statement and builds the tensor
-    // expression for the Linalg operation in SSA form.
-    Operation *yield = op.region().front().getTerminator();
-    Optional<unsigned> exp = buildTensorExp(merger, op, yield->getOperand(0));
+    // Builds the tensor expression for the Linalg operation in SSA form.
+    Optional<unsigned> exp = merger.buildTensorExpFromLinalg(op);
     if (!exp.hasValue())
-      return failure(); // build failure
+      return failure();
 
-    // Reject an inadmissable tensor expression.
+    // Rejects an inadmissable tensor expression.
     if (!isAdmissableTensorExp(merger, op, exp.getValue()))
       return failure();
 

diff  --git a/mlir/lib/Dialect/SparseTensor/Utils/CMakeLists.txt b/mlir/lib/Dialect/SparseTensor/Utils/CMakeLists.txt
index bfd614cb8df4f..cbb82cb83d72c 100644
--- a/mlir/lib/Dialect/SparseTensor/Utils/CMakeLists.txt
+++ b/mlir/lib/Dialect/SparseTensor/Utils/CMakeLists.txt
@@ -6,4 +6,5 @@ add_mlir_dialect_library(MLIRSparseTensorUtils
 
   LINK_LIBS PUBLIC
   MLIRIR
+  MLIRLinalg
 )

diff  --git a/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp b/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp
index 3d63246e950fa..0c869be07a125 100644
--- a/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp
@@ -14,6 +14,10 @@
 namespace mlir {
 namespace sparse_tensor {
 
+//
+// Lattice methods.
+//
+
 unsigned Merger::addExp(Kind k, unsigned e0, unsigned e1, Value v) {
   unsigned e = tensorExps.size();
   tensorExps.push_back(TensorExp(k, e0, e1, v));
@@ -68,7 +72,7 @@ unsigned Merger::optimizeSet(unsigned s0) {
     if (p0 != p1) {
       // Is this a straightforward copy?
       unsigned e = latPoints[p1].exp;
-      if (exp(e).kind == Kind::kTensor && exp(e).e0 == outTensor)
+      if (tensorExps[e].kind == Kind::kTensor && tensorExps[e].e0 == outTensor)
         continue;
       // Conjunction already covered?
       for (unsigned p2 : latSets[s]) {
@@ -137,33 +141,6 @@ bool Merger::hasAnyDimOf(const llvm::BitVector &bits, Dim d) const {
   return false;
 }
 
-unsigned Merger::buildLattices(unsigned e, unsigned idx) {
-  Kind kind = exp(e).kind;
-  if (kind == Kind::kTensor || kind == Kind::kInvariant) {
-    // Either the index is really used in the tensor expression, or it is
-    // set to the undefined index in that dimension. An invariant expression
-    // is set to a synthetic tensor with undefined indices only.
-    unsigned s = addSet();
-    unsigned t = kind == Kind::kTensor ? exp(e).e0 : syntheticTensor;
-    set(s).push_back(addLat(t, idx, e));
-    return s;
-  }
-  unsigned s0 = buildLattices(exp(e).e0, idx);
-  unsigned s1 = buildLattices(exp(e).e1, idx);
-  switch (kind) {
-  case Kind::kTensor:
-  case Kind::kInvariant:
-    llvm_unreachable("handled above");
-  case Kind::kMulF:
-  case Kind::kMulI:
-    return takeConj(kind, s0, s1);
-  case Kind::kAddF:
-  case Kind::kAddI:
-    return takeDisj(kind, s0, s1);
-  }
-  llvm_unreachable("unexpected expression kind");
-}
-
 #ifndef NDEBUG
 
 //
@@ -173,6 +150,10 @@ unsigned Merger::buildLattices(unsigned e, unsigned idx) {
 void Merger::dumpExp(unsigned e) const {
   switch (tensorExps[e].kind) {
   case Kind::kTensor:
+    if (tensorExps[e].e0 == syntheticTensor)
+      llvm::dbgs() << "synthetic_";
+    else if (tensorExps[e].e0 == outTensor)
+      llvm::dbgs() << "output_";
     llvm::dbgs() << "tensor_" << tensorExps[e].e0;
     break;
   case Kind::kInvariant:
@@ -242,5 +223,82 @@ void Merger::dumpBits(const llvm::BitVector &bits) const {
 
 #endif // NDEBUG
 
+//
+// Builder methods.
+//
+
+unsigned Merger::buildLattices(unsigned e, unsigned idx) {
+  Kind kind = tensorExps[e].kind;
+  if (kind == Kind::kTensor || kind == Kind::kInvariant) {
+    // Either the index is really used in the tensor expression, or it is
+    // set to the undefined index in that dimension. An invariant expression
+    // is set to a synthetic tensor with undefined indices only.
+    unsigned s = addSet();
+    unsigned t = kind == Kind::kTensor ? tensorExps[e].e0 : syntheticTensor;
+    latSets[s].push_back(addLat(t, idx, e));
+    return s;
+  }
+  unsigned s0 = buildLattices(tensorExps[e].e0, idx);
+  unsigned s1 = buildLattices(tensorExps[e].e1, idx);
+  switch (kind) {
+  case Kind::kTensor:
+  case Kind::kInvariant:
+    llvm_unreachable("handled above");
+  case Kind::kMulF:
+  case Kind::kMulI:
+    return takeConj(kind, s0, s1);
+  case Kind::kAddF:
+  case Kind::kAddI:
+    return takeDisj(kind, s0, s1);
+  }
+  llvm_unreachable("unexpected expression kind");
+}
+
+Optional<unsigned> Merger::buildTensorExpFromLinalg(linalg::GenericOp op) {
+  Operation *yield = op.region().front().getTerminator();
+  return buildTensorExp(op, yield->getOperand(0));
+}
+
+Optional<unsigned> Merger::buildTensorExp(linalg::GenericOp op, Value val) {
+  if (auto arg = val.dyn_cast<BlockArgument>()) {
+    unsigned argN = arg.getArgNumber();
+    // Any argument of the generic op that is not marked as a scalar
+    // argument is considered a tensor, indexed by the implicit loop
+    // bounds. This includes rank-0 tensor arguments.
+    if (arg.getOwner()->getParentOp() == op) {
+      OpOperand *t = op.getInputAndOutputOperands()[argN];
+      if (!op.isScalar(t))
+        return addExp(Kind::kTensor, argN);
+      val = t->get(); // get scalar value
+    }
+    // Any other argument (marked as scalar argument for the generic op
+    // or belonging to an enveloping op) is considered invariant.
+    return addExp(Kind::kInvariant, val);
+  }
+  // Something defined outside is invariant.
+  Operation *def = val.getDefiningOp();
+  if (def->getBlock() != &op.region().front())
+    return addExp(Kind::kInvariant, val);
+  // Construct binary operations if subexpressions could be built.
+  if (def->getNumOperands() == 2) {
+    auto x = buildTensorExp(op, def->getOperand(0));
+    auto y = buildTensorExp(op, def->getOperand(1));
+    if (x.hasValue() && y.hasValue()) {
+      unsigned e0 = x.getValue();
+      unsigned e1 = y.getValue();
+      if (isa<MulFOp>(def))
+        return addExp(Kind::kMulF, e0, e1);
+      if (isa<MulIOp>(def))
+        return addExp(Kind::kMulI, e0, e1);
+      if (isa<AddFOp>(def))
+        return addExp(Kind::kAddF, e0, e1);
+      if (isa<AddIOp>(def))
+        return addExp(Kind::kAddI, e0, e1);
+    }
+  }
+  // Cannot build.
+  return None;
+}
+
 } // namespace sparse_tensor
 } // namespace mlir


        


More information about the Mlir-commits mailing list