[Mlir-commits] [mlir] 266a741 - [mlir][sparse] move tensor expression builder into Merger utility
Aart Bik
llvmlistbot at llvm.org
Thu Jul 1 09:27:49 PDT 2021
Author: Aart Bik
Date: 2021-07-01T09:27:40-07:00
New Revision: 266a7414d8f2643be2b1dad86693b12a9f1246fa
URL: https://github.com/llvm/llvm-project/commit/266a7414d8f2643be2b1dad86693b12a9f1246fa
DIFF: https://github.com/llvm/llvm-project/commit/266a7414d8f2643be2b1dad86693b12a9f1246fa.diff
LOG: [mlir][sparse] move tensor expression builder into Merger utility
Rationale:
Follow-up to migrating the lattice- and tensor-expression-related methods into the new Merger utility.
This also prepares for the next step: generalizing the kinds of ops that are handled.
Reviewed By: gussmith23
Differential Revision: https://reviews.llvm.org/D105219
Added:
Modified:
mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h
mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
mlir/lib/Dialect/SparseTensor/Utils/CMakeLists.txt
mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h b/mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h
index cbb0aede83f81..d087e98ac42f3 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h
+++ b/mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h
@@ -13,6 +13,7 @@
#ifndef MLIR_DIALECT_SPARSETENSOR_UTILS_MERGER_H_
#define MLIR_DIALECT_SPARSETENSOR_UTILS_MERGER_H_
+#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
#include "mlir/IR/Value.h"
#include "llvm/ADT/BitVector.h"
@@ -148,11 +149,6 @@ class Merger {
/// Returns true if any set bit corresponds to queried dim.
bool hasAnyDimOf(const llvm::BitVector &bits, Dim d) const;
- /// Builds the iteration lattices in a bottom-up traversal given the remaining
- /// tensor (sub)expression and the next loop index in the iteration graph.
- /// Returns index of the root expression.
- unsigned buildLattices(unsigned exp, unsigned idx);
-
/// Setter
void setDim(unsigned t, unsigned i, Dim d) { dims[t][i] = d; }
@@ -169,7 +165,19 @@ class Merger {
void dumpBits(const llvm::BitVector &bits) const;
#endif
+ /// Builds the iteration lattices in a bottom-up traversal given the remaining
+ /// tensor (sub)expression and the next loop index in the iteration graph.
+ /// Returns the index of the lattice set built for the root expression.
+ unsigned buildLattices(unsigned exp, unsigned idx);
+
+ /// Builds a tensor expression from the given Linalg operation.
+ /// Returns index of the root expression on success, None otherwise.
+ Optional<unsigned> buildTensorExpFromLinalg(linalg::GenericOp op);
+
private:
+ /// Traverses the SSA tree (possibly a DAG) to build a tensor expression (None if unsupported).
+ Optional<unsigned> buildTensorExp(linalg::GenericOp op, Value val);
+
const unsigned outTensor;
const unsigned syntheticTensor;
const unsigned numTensors;
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
index dd8d4967f1325..0409a7eabdfb7 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
@@ -208,51 +208,6 @@ static bool computeIterationGraph(Merger &merger, linalg::GenericOp op,
return true;
}
-/// Traverses the SSA tree (possibly a DAG) to build a tensor expression.
-/// This simplifies constructing (sub)expressions during iteration lattice
-/// building (compared to using the SSA representation everywhere).
-static Optional<unsigned> buildTensorExp(Merger &merger, linalg::GenericOp op,
- Value val) {
- if (auto arg = val.dyn_cast<BlockArgument>()) {
- unsigned argN = arg.getArgNumber();
- // Any argument of the generic op that is not marked as a scalar
- // argument is considered a tensor, indexed by the implicit loop
- // bounds. This includes rank-0 tensor arguments.
- if (arg.getOwner()->getParentOp() == op) {
- OpOperand *t = op.getInputAndOutputOperands()[argN];
- if (!op.isScalar(t))
- return merger.addExp(Kind::kTensor, argN);
- val = t->get(); // get scalar value
- }
- // Any other argument (marked as scalar argument for the generic op
- // or belonging to an enveloping op) is considered invariant.
- return merger.addExp(Kind::kInvariant, val);
- }
- Operation *def = val.getDefiningOp();
- if (def->getBlock() != &op.region().front()) {
- // Something defined outside is invariant.
- return merger.addExp(Kind::kInvariant, val);
- } else if (def->getNumOperands() == 2) {
- // Construct binary operations if subexpressions could be built.
- auto x = buildTensorExp(merger, op, def->getOperand(0));
- auto y = buildTensorExp(merger, op, def->getOperand(1));
- if (x.hasValue() && y.hasValue()) {
- unsigned e0 = x.getValue();
- unsigned e1 = y.getValue();
- if (isa<MulFOp>(def))
- return merger.addExp(Kind::kMulF, e0, e1);
- if (isa<MulIOp>(def))
- return merger.addExp(Kind::kMulI, e0, e1);
- if (isa<AddFOp>(def))
- return merger.addExp(Kind::kAddF, e0, e1);
- if (isa<AddIOp>(def))
- return merger.addExp(Kind::kAddI, e0, e1);
- }
- }
- // Cannot build (yet).
- return None;
-}
-
/// Returns true if given tensor co-iterates with conjunction only.
/// For the output tensor, this defines a "simply dynamic" operation.
/// For instance: A(I) = A(I) * B(I) * C(I)
@@ -1224,14 +1179,12 @@ struct GenericOpSparsifier : public OpRewritePattern<linalg::GenericOp> {
!computeIterationGraph(merger, op, topSort, /*sparseOnly=*/true))
return failure();
- // Finds the terminating yield statement and builds the tensor
- // expression for the Linalg operation in SSA form.
- Operation *yield = op.region().front().getTerminator();
- Optional<unsigned> exp = buildTensorExp(merger, op, yield->getOperand(0));
+ // Builds the tensor expression for the Linalg operation in SSA form.
+ Optional<unsigned> exp = merger.buildTensorExpFromLinalg(op);
if (!exp.hasValue())
- return failure(); // build failure
+ return failure(); // tensor expression could not be built
- // Reject an inadmissable tensor expression.
+ // Rejects an inadmissible tensor expression.
if (!isAdmissableTensorExp(merger, op, exp.getValue()))
return failure();
diff --git a/mlir/lib/Dialect/SparseTensor/Utils/CMakeLists.txt b/mlir/lib/Dialect/SparseTensor/Utils/CMakeLists.txt
index bfd614cb8df4f..cbb82cb83d72c 100644
--- a/mlir/lib/Dialect/SparseTensor/Utils/CMakeLists.txt
+++ b/mlir/lib/Dialect/SparseTensor/Utils/CMakeLists.txt
@@ -6,4 +6,5 @@ add_mlir_dialect_library(MLIRSparseTensorUtils
LINK_LIBS PUBLIC
MLIRIR
+ MLIRLinalg
)
diff --git a/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp b/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp
index 3d63246e950fa..0c869be07a125 100644
--- a/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp
@@ -14,6 +14,10 @@
namespace mlir {
namespace sparse_tensor {
+//
+// Lattice methods.
+//
+
unsigned Merger::addExp(Kind k, unsigned e0, unsigned e1, Value v) {
unsigned e = tensorExps.size();
tensorExps.push_back(TensorExp(k, e0, e1, v));
@@ -68,7 +72,7 @@ unsigned Merger::optimizeSet(unsigned s0) {
if (p0 != p1) {
// Is this a straightforward copy?
unsigned e = latPoints[p1].exp;
- if (exp(e).kind == Kind::kTensor && exp(e).e0 == outTensor)
+ if (tensorExps[e].kind == Kind::kTensor && tensorExps[e].e0 == outTensor)
continue;
// Conjunction already covered?
for (unsigned p2 : latSets[s]) {
@@ -137,33 +141,6 @@ bool Merger::hasAnyDimOf(const llvm::BitVector &bits, Dim d) const {
return false;
}
-unsigned Merger::buildLattices(unsigned e, unsigned idx) {
- Kind kind = exp(e).kind;
- if (kind == Kind::kTensor || kind == Kind::kInvariant) {
- // Either the index is really used in the tensor expression, or it is
- // set to the undefined index in that dimension. An invariant expression
- // is set to a synthetic tensor with undefined indices only.
- unsigned s = addSet();
- unsigned t = kind == Kind::kTensor ? exp(e).e0 : syntheticTensor;
- set(s).push_back(addLat(t, idx, e));
- return s;
- }
- unsigned s0 = buildLattices(exp(e).e0, idx);
- unsigned s1 = buildLattices(exp(e).e1, idx);
- switch (kind) {
- case Kind::kTensor:
- case Kind::kInvariant:
- llvm_unreachable("handled above");
- case Kind::kMulF:
- case Kind::kMulI:
- return takeConj(kind, s0, s1);
- case Kind::kAddF:
- case Kind::kAddI:
- return takeDisj(kind, s0, s1);
- }
- llvm_unreachable("unexpected expression kind");
-}
-
#ifndef NDEBUG
//
@@ -173,6 +150,10 @@ unsigned Merger::buildLattices(unsigned e, unsigned idx) {
void Merger::dumpExp(unsigned e) const {
switch (tensorExps[e].kind) {
case Kind::kTensor:
+ if (tensorExps[e].e0 == syntheticTensor)
+ llvm::dbgs() << "synthetic_";
+ else if (tensorExps[e].e0 == outTensor)
+ llvm::dbgs() << "output_";
llvm::dbgs() << "tensor_" << tensorExps[e].e0;
break;
case Kind::kInvariant:
@@ -242,5 +223,82 @@ void Merger::dumpBits(const llvm::BitVector &bits) const {
#endif // NDEBUG
+//
+// Builder methods.
+//
+
+unsigned Merger::buildLattices(unsigned e, unsigned idx) { // returns a lattice set index
+ Kind kind = tensorExps[e].kind;
+ if (kind == Kind::kTensor || kind == Kind::kInvariant) {
+ // Either the index is really used in the tensor expression, or it is
+ // set to the undefined index in that dimension. An invariant expression
+ // is set to a synthetic tensor with undefined indices only.
+ unsigned s = addSet(); // fresh set holding a single lattice point
+ unsigned t = kind == Kind::kTensor ? tensorExps[e].e0 : syntheticTensor;
+ latSets[s].push_back(addLat(t, idx, e));
+ return s;
+ }
+ unsigned s0 = buildLattices(tensorExps[e].e0, idx); // lattice set of the left operand
+ unsigned s1 = buildLattices(tensorExps[e].e1, idx); // lattice set of the right operand
+ switch (kind) {
+ case Kind::kTensor:
+ case Kind::kInvariant:
+ llvm_unreachable("handled above");
+ case Kind::kMulF:
+ case Kind::kMulI:
+ return takeConj(kind, s0, s1); // multiplication: conjunction of operand sets
+ case Kind::kAddF:
+ case Kind::kAddI:
+ return takeDisj(kind, s0, s1); // addition: disjunction of operand sets
+ }
+ llvm_unreachable("unexpected expression kind");
+}
+
+Optional<unsigned> Merger::buildTensorExpFromLinalg(linalg::GenericOp op) {
+ Operation *yield = op.region().front().getTerminator(); // terminator of the op's body block
+ return buildTensorExp(op, yield->getOperand(0)); // NOTE(review): only the first yielded value is used; assumes at least one is yielded — confirm
+}
+
+Optional<unsigned> Merger::buildTensorExp(linalg::GenericOp op, Value val) {
+ if (auto arg = val.dyn_cast<BlockArgument>()) {
+ unsigned argN = arg.getArgNumber();
+ // Any argument of the generic op that is not marked as a scalar
+ // argument is considered a tensor, indexed by the implicit loop
+ // bounds. This includes rank-0 tensor arguments.
+ if (arg.getOwner()->getParentOp() == op) {
+ OpOperand *t = op.getInputAndOutputOperands()[argN];
+ if (!op.isScalar(t))
+ return addExp(Kind::kTensor, argN);
+ val = t->get(); // use the scalar operand value itself as the invariant
+ }
+ // Any other argument (marked as scalar argument for the generic op
+ // or belonging to an enveloping op) is considered invariant.
+ return addExp(Kind::kInvariant, val);
+ }
+ // Something defined outside the op's body block is invariant.
+ Operation *def = val.getDefiningOp();
+ if (def->getBlock() != &op.region().front())
+ return addExp(Kind::kInvariant, val);
+ // Construct binary operations if subexpressions could be built (operand subexpressions are added before the op kind is checked).
+ if (def->getNumOperands() == 2) {
+ auto x = buildTensorExp(op, def->getOperand(0));
+ auto y = buildTensorExp(op, def->getOperand(1));
+ if (x.hasValue() && y.hasValue()) {
+ unsigned e0 = x.getValue();
+ unsigned e1 = y.getValue();
+ if (isa<MulFOp>(def))
+ return addExp(Kind::kMulF, e0, e1);
+ if (isa<MulIOp>(def))
+ return addExp(Kind::kMulI, e0, e1);
+ if (isa<AddFOp>(def))
+ return addExp(Kind::kAddF, e0, e1);
+ if (isa<AddIOp>(def))
+ return addExp(Kind::kAddI, e0, e1);
+ }
+ }
+ // Cannot build: not a supported binary op, or an operand failed to build.
+ return None;
+}
+
} // namespace sparse_tensor
} // namespace mlir
More information about the Mlir-commits
mailing list