[Mlir-commits] [mlir] 4569c14 - Refactor TensorExp parameters into a union

Gus Smith llvmlistbot at llvm.org
Fri Jul 2 07:46:04 PDT 2021


Author: Gus Smith
Date: 2021-07-02T14:45:56Z
New Revision: 4569c14ac347180d9514f43c45c6f52569ce8f8c

URL: https://github.com/llvm/llvm-project/commit/4569c14ac347180d9514f43c45c6f52569ce8f8c
DIFF: https://github.com/llvm/llvm-project/commit/4569c14ac347180d9514f43c45c6f52569ce8f8c.diff

LOG: Refactor TensorExp parameters into a union

To make TensorExp clearer, this change refactors the e0/e1 fields into a union: a Children struct holding e0/e1 for binary-op tensor expressions, and a tensor number for tensor-kinded tensor expressions.
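
As a rough illustration of the new layout, here is a minimal, standalone C++ sketch (not the actual MLIR sources): the trimmed Kind enum, the dump() helper, and the main() driver are invented for this example, while TensorExp, Children, tensor, children, e0, and e1 mirror the names used in the patch.

#include <cstdio>
#include <vector>

// Trimmed-down Kind enum; the real one also has kInvariant, kMulF, kAddF.
enum class Kind { kTensor, kMulI, kAddI };

// Children expressions of a binary TensorExp, as introduced by the patch.
struct Children {
  unsigned e0;
  unsigned e1;
};

// Cut-down TensorExp: `kind` discriminates the anonymous union, so callers
// must check it before reading `tensor` or `children`.
struct TensorExp {
  Kind kind;
  union {
    unsigned tensor;   // tensor number, valid only for kTensor
    Children children; // child expression indices, valid only for binary ops
  };
};

// Recursive printer in the spirit of Merger::dumpExp(): branch on the kind
// before touching the corresponding union member.
void dump(const std::vector<TensorExp> &exps, unsigned e) {
  const TensorExp &te = exps[e];
  switch (te.kind) {
  case Kind::kTensor:
    std::printf("tensor_%u", te.tensor);
    break;
  case Kind::kMulI:
  case Kind::kAddI:
    std::printf("(");
    dump(exps, te.children.e0);
    std::printf("%s", te.kind == Kind::kMulI ? " * " : " + ");
    dump(exps, te.children.e1);
    std::printf(")");
    break;
  }
}

int main() {
  std::vector<TensorExp> exps(3);
  exps[0].kind = Kind::kTensor;
  exps[0].tensor = 0;
  exps[1].kind = Kind::kTensor;
  exps[1].tensor = 1;
  exps[2].kind = Kind::kMulI;
  exps[2].children = {0, 1}; // multiply expressions 0 and 1
  dump(exps, 2);             // prints (tensor_0 * tensor_1)
  std::printf("\n");
}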

Reviewed By: aartbik

Differential Revision: https://reviews.llvm.org/D105303

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h
    mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
    mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp

Removed: 
    


################################################################################
diff --git a/mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h b/mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h
index d087e98ac42f3..4141c68a5e379 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h
+++ b/mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h
@@ -26,24 +26,39 @@ enum class Kind { kTensor, kInvariant, kMulF, kMulI, kAddF, kAddI };
 /// Dimension level type for a tensor (undef means index does not appear).
 enum class Dim { kSparse, kDense, kSingle, kUndef };
 
+/// Children expressions of a binary TensorExp.
+struct Children {
+  unsigned e0;
+  unsigned e1;
+};
+
 /// Tensor expression. Represents a MLIR expression in tensor index notation.
 /// For tensors, e0 denotes the tensor index. For invariants, the IR value is
 /// stored directly. For binary operations, e0 and e1 denote the index of the
 /// children tensor expressions.
 struct TensorExp {
-  TensorExp(Kind k, unsigned x, unsigned y, Value v)
-      : kind(k), e0(x), e1(y), val(v) {
-    assert((kind == Kind::kTensor && e0 != -1u && e1 == -1u && !val) ||
-           (kind == Kind::kInvariant && e0 == -1u && e1 == -1u && val) ||
-           (kind >= Kind::kMulF && e0 != -1u && e1 != -1u && !val));
+  TensorExp(Kind k, unsigned x, unsigned y, Value v) : kind(k), val(v) {
+    assert((kind == Kind::kTensor && x != -1u && y == -1u && !val) ||
+           (kind == Kind::kInvariant && x == -1u && y == -1u && val) ||
+           (kind >= Kind::kMulF && x != -1u && y != -1u && !val));
+    if (kind == Kind::kTensor) {
+      tensor = x;
+    } else if (kind >= Kind::kMulF) {
+      children.e0 = x;
+      children.e1 = y;
+    }
   }
 
   /// Tensor expression kind.
   Kind kind;
 
-  /// Indices of children expression(s).
-  unsigned e0;
-  unsigned e1;
+  union {
+    /// Expressions representing tensors simply have a tensor number.
+    unsigned tensor;
+
+    /// Binary operations hold the indices of their child expressions.
+    Children children;
+  };
 
   /// Direct link to IR for an invariant. During code generation,
   /// field is used to cache "hoisted" loop invariant tensor loads.

diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
index 0409a7eabdfb7..813fe683ae619 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
@@ -214,11 +214,11 @@ static bool computeIterationGraph(Merger &merger, linalg::GenericOp op,
 static unsigned isConjunction(Merger &merger, unsigned tensor, unsigned exp) {
   switch (merger.exp(exp).kind) {
   case Kind::kTensor:
-    return merger.exp(exp).e0 == tensor;
+    return merger.exp(exp).tensor == tensor;
   case Kind::kMulF:
   case Kind::kMulI:
-    return isConjunction(merger, tensor, merger.exp(exp).e0) ||
-           isConjunction(merger, tensor, merger.exp(exp).e1);
+    return isConjunction(merger, tensor, merger.exp(exp).children.e0) ||
+           isConjunction(merger, tensor, merger.exp(exp).children.e1);
   default:
     return false;
   }
@@ -455,7 +455,7 @@ static Value genTensorLoad(Merger &merger, CodeGen &codegen,
   }
   // Actual load.
   SmallVector<Value, 4> args;
-  OpOperand *t = op.getInputAndOutputOperands()[merger.exp(exp).e0];
+  OpOperand *t = op.getInputAndOutputOperands()[merger.exp(exp).tensor];
   unsigned tensor = t->getOperandNumber();
   auto map = op.getTiedIndexingMap(t);
   auto enc = getSparseTensorEncoding(t->get().getType());
@@ -628,8 +628,8 @@ static Value genExp(Merger &merger, CodeGen &codegen, PatternRewriter &rewriter,
     return genTensorLoad(merger, codegen, rewriter, op, exp);
   else if (merger.exp(exp).kind == Kind::kInvariant)
     return genInvariantValue(merger, codegen, rewriter, exp);
-  Value v0 = genExp(merger, codegen, rewriter, op, merger.exp(exp).e0);
-  Value v1 = genExp(merger, codegen, rewriter, op, merger.exp(exp).e1);
+  Value v0 = genExp(merger, codegen, rewriter, op, merger.exp(exp).children.e0);
+  Value v1 = genExp(merger, codegen, rewriter, op, merger.exp(exp).children.e1);
   switch (merger.exp(exp).kind) {
   case Kind::kTensor:
   case Kind::kInvariant:
@@ -653,7 +653,7 @@ static void genInvariants(Merger &merger, CodeGen &codegen,
   if (merger.exp(exp).kind == Kind::kTensor) {
     // Inspect tensor indices.
     bool atLevel = ldx == -1u;
-    OpOperand *t = op.getInputAndOutputOperands()[merger.exp(exp).e0];
+    OpOperand *t = op.getInputAndOutputOperands()[merger.exp(exp).tensor];
     auto map = op.getTiedIndexingMap(t);
     auto enc = getSparseTensorEncoding(t->get().getType());
     for (unsigned d = 0, rank = map.getNumResults(); d < rank; d++) {
@@ -675,8 +675,8 @@ static void genInvariants(Merger &merger, CodeGen &codegen,
     // Traverse into the binary operations. Note that we only hoist
     // tensor loads, since subsequent MLIR/LLVM passes know how to
     // deal with all other kinds of derived loop invariants.
-    unsigned e0 = merger.exp(exp).e0;
-    unsigned e1 = merger.exp(exp).e1;
+    unsigned e0 = merger.exp(exp).children.e0;
+    unsigned e1 = merger.exp(exp).children.e1;
     genInvariants(merger, codegen, rewriter, op, e0, ldx, hoist);
     genInvariants(merger, codegen, rewriter, op, e1, ldx, hoist);
   }

diff --git a/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp b/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp
index 0c869be07a125..6150c15a0ad18 100644
--- a/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp
@@ -72,7 +72,8 @@ unsigned Merger::optimizeSet(unsigned s0) {
     if (p0 != p1) {
       // Is this a straightforward copy?
       unsigned e = latPoints[p1].exp;
-      if (tensorExps[e].kind == Kind::kTensor && tensorExps[e].e0 == outTensor)
+      if (tensorExps[e].kind == Kind::kTensor &&
+          tensorExps[e].tensor == outTensor)
         continue;
       // Conjunction already covered?
       for (unsigned p2 : latSets[s]) {
@@ -150,11 +151,11 @@ bool Merger::hasAnyDimOf(const llvm::BitVector &bits, Dim d) const {
 void Merger::dumpExp(unsigned e) const {
   switch (tensorExps[e].kind) {
   case Kind::kTensor:
-    if (tensorExps[e].e0 == syntheticTensor)
+    if (tensorExps[e].tensor == syntheticTensor)
       llvm::dbgs() << "synthetic_";
-    else if (tensorExps[e].e0 == outTensor)
+    else if (tensorExps[e].tensor == outTensor)
       llvm::dbgs() << "output_";
-    llvm::dbgs() << "tensor_" << tensorExps[e].e0;
+    llvm::dbgs() << "tensor_" << tensorExps[e].tensor;
     break;
   case Kind::kInvariant:
     llvm::dbgs() << "invariant";
@@ -162,17 +163,17 @@ void Merger::dumpExp(unsigned e) const {
   default:
   case Kind::kMulI:
     llvm::dbgs() << "(";
-    dumpExp(tensorExps[e].e0);
+    dumpExp(tensorExps[e].children.e0);
     llvm::dbgs() << " * ";
-    dumpExp(tensorExps[e].e1);
+    dumpExp(tensorExps[e].children.e1);
     llvm::dbgs() << ")";
     break;
   case Kind::kAddF:
   case Kind::kAddI:
     llvm::dbgs() << "(";
-    dumpExp(tensorExps[e].e0);
+    dumpExp(tensorExps[e].children.e0);
     llvm::dbgs() << " + ";
-    dumpExp(tensorExps[e].e1);
+    dumpExp(tensorExps[e].children.e1);
     llvm::dbgs() << ")";
     break;
   }
@@ -234,12 +235,13 @@ unsigned Merger::buildLattices(unsigned e, unsigned idx) {
     // set to the undefined index in that dimension. An invariant expression
     // is set to a synthetic tensor with undefined indices only.
     unsigned s = addSet();
-    unsigned t = kind == Kind::kTensor ? tensorExps[e].e0 : syntheticTensor;
+    unsigned t =
+        kind == Kind::kTensor ? tensorExps[e].children.e0 : syntheticTensor;
     latSets[s].push_back(addLat(t, idx, e));
     return s;
   }
-  unsigned s0 = buildLattices(tensorExps[e].e0, idx);
-  unsigned s1 = buildLattices(tensorExps[e].e1, idx);
+  unsigned s0 = buildLattices(tensorExps[e].children.e0, idx);
+  unsigned s1 = buildLattices(tensorExps[e].children.e1, idx);
   switch (kind) {
   case Kind::kTensor:
   case Kind::kInvariant:

