[Mlir-commits] [mlir] 4569c14 - Refactor TensorExp parameters into a union
Gus Smith
llvmlistbot at llvm.org
Fri Jul 2 07:46:04 PDT 2021
Author: Gus Smith
Date: 2021-07-02T14:45:56Z
New Revision: 4569c14ac347180d9514f43c45c6f52569ce8f8c
URL: https://github.com/llvm/llvm-project/commit/4569c14ac347180d9514f43c45c6f52569ce8f8c
DIFF: https://github.com/llvm/llvm-project/commit/4569c14ac347180d9514f43c45c6f52569ce8f8c.diff
LOG: Refactor TensorExp parameters into a union
To make TensorExp clearer, this change refactors the e0/e1 fields into a union: children.e0/children.e1 for a binary op tensor expression, and tensor for a tensor-kinded tensor expression.
Reviewed By: aartbik
Differential Revision: https://reviews.llvm.org/D105303
Added:
Modified:
mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h
mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h b/mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h
index d087e98ac42f3..4141c68a5e379 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h
+++ b/mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h
@@ -26,24 +26,39 @@ enum class Kind { kTensor, kInvariant, kMulF, kMulI, kAddF, kAddI };
/// Dimension level type for a tensor (undef means index does not appear).
enum class Dim { kSparse, kDense, kSingle, kUndef };
+/// Children expressions of a binary TensorExp.
+struct Children {
+ unsigned e0;
+ unsigned e1;
+};
+
/// Tensor expression. Represents a MLIR expression in tensor index notation.
-/// For tensors, e0 denotes the tensor index. For invariants, the IR value is
-/// stored directly. For binary operations, e0 and e1 denote the index of the
-/// children tensor expressions.
+/// For tensors, `tensor` denotes the tensor number. For invariants, the IR
+/// value is stored directly. For binary operations, `children.e0` and
+/// `children.e1` denote the indices of the children tensor expressions.
struct TensorExp {
- TensorExp(Kind k, unsigned x, unsigned y, Value v)
- : kind(k), e0(x), e1(y), val(v) {
- assert((kind == Kind::kTensor && e0 != -1u && e1 == -1u && !val) ||
- (kind == Kind::kInvariant && e0 == -1u && e1 == -1u && val) ||
- (kind >= Kind::kMulF && e0 != -1u && e1 != -1u && !val));
+ TensorExp(Kind k, unsigned x, unsigned y, Value v) : kind(k), val(v) {
+ assert((kind == Kind::kTensor && x != -1u && y == -1u && !val) ||
+ (kind == Kind::kInvariant && x == -1u && y == -1u && val) ||
+ (kind >= Kind::kMulF && x != -1u && y != -1u && !val));
+ if (kind == Kind::kTensor) {
+ tensor = x;
+ } else if (kind >= Kind::kMulF) {
+ children.e0 = x;
+ children.e1 = y;
+ }
}
/// Tensor expression kind.
Kind kind;
- /// Indices of children expression(s).
- unsigned e0;
- unsigned e1;
+ union {
+ /// Expressions representing tensors simply have a tensor number.
+ unsigned tensor;
+
+ /// Binary operations hold the indices of their child expressions.
+ Children children;
+ };
/// Direct link to IR for an invariant. During code generation,
/// field is used to cache "hoisted" loop invariant tensor loads.
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
index 0409a7eabdfb7..813fe683ae619 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
@@ -214,11 +214,11 @@ static bool computeIterationGraph(Merger &merger, linalg::GenericOp op,
static unsigned isConjunction(Merger &merger, unsigned tensor, unsigned exp) {
switch (merger.exp(exp).kind) {
case Kind::kTensor:
- return merger.exp(exp).e0 == tensor;
+ return merger.exp(exp).tensor == tensor;
case Kind::kMulF:
case Kind::kMulI:
- return isConjunction(merger, tensor, merger.exp(exp).e0) ||
- isConjunction(merger, tensor, merger.exp(exp).e1);
+ return isConjunction(merger, tensor, merger.exp(exp).children.e0) ||
+ isConjunction(merger, tensor, merger.exp(exp).children.e1);
default:
return false;
}
@@ -455,7 +455,7 @@ static Value genTensorLoad(Merger &merger, CodeGen &codegen,
}
// Actual load.
SmallVector<Value, 4> args;
- OpOperand *t = op.getInputAndOutputOperands()[merger.exp(exp).e0];
+ OpOperand *t = op.getInputAndOutputOperands()[merger.exp(exp).tensor];
unsigned tensor = t->getOperandNumber();
auto map = op.getTiedIndexingMap(t);
auto enc = getSparseTensorEncoding(t->get().getType());
@@ -628,8 +628,8 @@ static Value genExp(Merger &merger, CodeGen &codegen, PatternRewriter &rewriter,
return genTensorLoad(merger, codegen, rewriter, op, exp);
else if (merger.exp(exp).kind == Kind::kInvariant)
return genInvariantValue(merger, codegen, rewriter, exp);
- Value v0 = genExp(merger, codegen, rewriter, op, merger.exp(exp).e0);
- Value v1 = genExp(merger, codegen, rewriter, op, merger.exp(exp).e1);
+ Value v0 = genExp(merger, codegen, rewriter, op, merger.exp(exp).children.e0);
+ Value v1 = genExp(merger, codegen, rewriter, op, merger.exp(exp).children.e1);
switch (merger.exp(exp).kind) {
case Kind::kTensor:
case Kind::kInvariant:
@@ -653,7 +653,7 @@ static void genInvariants(Merger &merger, CodeGen &codegen,
if (merger.exp(exp).kind == Kind::kTensor) {
// Inspect tensor indices.
bool atLevel = ldx == -1u;
- OpOperand *t = op.getInputAndOutputOperands()[merger.exp(exp).e0];
+ OpOperand *t = op.getInputAndOutputOperands()[merger.exp(exp).tensor];
auto map = op.getTiedIndexingMap(t);
auto enc = getSparseTensorEncoding(t->get().getType());
for (unsigned d = 0, rank = map.getNumResults(); d < rank; d++) {
@@ -675,8 +675,8 @@ static void genInvariants(Merger &merger, CodeGen &codegen,
// Traverse into the binary operations. Note that we only hoist
// tensor loads, since subsequent MLIR/LLVM passes know how to
// deal with all other kinds of derived loop invariants.
- unsigned e0 = merger.exp(exp).e0;
- unsigned e1 = merger.exp(exp).e1;
+ unsigned e0 = merger.exp(exp).children.e0;
+ unsigned e1 = merger.exp(exp).children.e1;
genInvariants(merger, codegen, rewriter, op, e0, ldx, hoist);
genInvariants(merger, codegen, rewriter, op, e1, ldx, hoist);
}
diff --git a/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp b/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp
index 0c869be07a125..6150c15a0ad18 100644
--- a/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp
@@ -72,7 +72,8 @@ unsigned Merger::optimizeSet(unsigned s0) {
if (p0 != p1) {
// Is this a straightforward copy?
unsigned e = latPoints[p1].exp;
- if (tensorExps[e].kind == Kind::kTensor && tensorExps[e].e0 == outTensor)
+ if (tensorExps[e].kind == Kind::kTensor &&
+ tensorExps[e].tensor == outTensor)
continue;
// Conjunction already covered?
for (unsigned p2 : latSets[s]) {
@@ -150,11 +151,11 @@ bool Merger::hasAnyDimOf(const llvm::BitVector &bits, Dim d) const {
void Merger::dumpExp(unsigned e) const {
switch (tensorExps[e].kind) {
case Kind::kTensor:
- if (tensorExps[e].e0 == syntheticTensor)
+ if (tensorExps[e].tensor == syntheticTensor)
llvm::dbgs() << "synthetic_";
- else if (tensorExps[e].e0 == outTensor)
+ else if (tensorExps[e].tensor == outTensor)
llvm::dbgs() << "output_";
- llvm::dbgs() << "tensor_" << tensorExps[e].e0;
+ llvm::dbgs() << "tensor_" << tensorExps[e].tensor;
break;
case Kind::kInvariant:
llvm::dbgs() << "invariant";
@@ -162,17 +163,17 @@ void Merger::dumpExp(unsigned e) const {
default:
case Kind::kMulI:
llvm::dbgs() << "(";
- dumpExp(tensorExps[e].e0);
+ dumpExp(tensorExps[e].children.e0);
llvm::dbgs() << " * ";
- dumpExp(tensorExps[e].e1);
+ dumpExp(tensorExps[e].children.e1);
llvm::dbgs() << ")";
break;
case Kind::kAddF:
case Kind::kAddI:
llvm::dbgs() << "(";
- dumpExp(tensorExps[e].e0);
+ dumpExp(tensorExps[e].children.e0);
llvm::dbgs() << " + ";
- dumpExp(tensorExps[e].e1);
+ dumpExp(tensorExps[e].children.e1);
llvm::dbgs() << ")";
break;
}
@@ -234,12 +235,13 @@ unsigned Merger::buildLattices(unsigned e, unsigned idx) {
// set to the undefined index in that dimension. An invariant expression
// is set to a synthetic tensor with undefined indices only.
unsigned s = addSet();
- unsigned t = kind == Kind::kTensor ? tensorExps[e].e0 : syntheticTensor;
+ // A kTensor expression stores its tensor number in the `tensor` member.
+ unsigned t = kind == Kind::kTensor ? tensorExps[e].tensor : syntheticTensor;
latSets[s].push_back(addLat(t, idx, e));
return s;
}
- unsigned s0 = buildLattices(tensorExps[e].e0, idx);
- unsigned s1 = buildLattices(tensorExps[e].e1, idx);
+ unsigned s0 = buildLattices(tensorExps[e].children.e0, idx);
+ unsigned s1 = buildLattices(tensorExps[e].children.e1, idx);
switch (kind) {
case Kind::kTensor:
case Kind::kInvariant:
More information about the Mlir-commits
mailing list