[Mlir-commits] [mlir] b22397f - [mlir][sparse] properly record dimension level type and properties

Aart Bik llvmlistbot at llvm.org
Mon Sep 12 10:00:07 PDT 2022


Author: Aart Bik
Date: 2022-09-12T09:59:53-07:00
New Revision: b22397fee45c35fcdfb679bf27f976038a3542a0

URL: https://github.com/llvm/llvm-project/commit/b22397fee45c35fcdfb679bf27f976038a3542a0
DIFF: https://github.com/llvm/llvm-project/commit/b22397fee45c35fcdfb679bf27f976038a3542a0.diff

LOG: [mlir][sparse] properly record dimension level type and properties

This is the next step towards supporting the new dimension level types and
properties. This change properly records the properties in the
Merger, so that subsequent computations (lattice optimizations)
and code generation (during sparsification) can do the right thing.

https://github.com/llvm/llvm-project/issues/51658

Reviewed By: Peiming

Differential Revision: https://reviews.llvm.org/D133620

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h
    mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
    mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp
    mlir/unittests/Dialect/SparseTensor/MergerTest.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h b/mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h
index ea00b958f1b95..584df8476466b 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h
+++ b/mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h
@@ -21,7 +21,20 @@ namespace mlir {
 namespace sparse_tensor {
 
 /// Dimension level type for a tensor (undef means index does not appear).
-enum Dim { kSparse, kDense, kUndef };
+enum class DimLvlType { kDense, kCompressed, kSingleton, kUndef };
+
+/// Per-dimension level format (type and properties). Dense and undefined
+/// level types should always be marked ordered and unique.
+struct DimLevelFormat {
+  DimLevelFormat(DimLvlType tp, bool o = true, bool u = true)
+      : levelType(tp), isOrdered(o), isUnique(u) {
+    assert((tp == DimLvlType::kCompressed || tp == DimLvlType::kSingleton) ||
+           (o && u));
+  }
+  DimLvlType levelType;
+  bool isOrdered;
+  bool isUnique;
+};
 
 /// Tensor expression kind.
 enum Kind {
@@ -156,7 +169,9 @@ class Merger {
   /// invariant expressions in the kernel.
   Merger(unsigned t, unsigned l)
       : outTensor(t - 1), syntheticTensor(t), numTensors(t + 1), numLoops(l),
-        hasSparseOut(false), dims(t + 1, std::vector<Dim>(l, Dim::kUndef)) {}
+        hasSparseOut(false),
+        dims(t + 1, std::vector<DimLevelFormat>(
+                        l, DimLevelFormat(DimLvlType::kUndef))) {}
 
   /// Adds a tensor expression. Returns its index.
   unsigned addExp(Kind k, unsigned e0, unsigned e1 = -1u, Value v = Value(),
@@ -225,31 +240,40 @@ class Merger {
   unsigned tensor(unsigned b) const { return b % numTensors; }
   unsigned index(unsigned b) const { return b / numTensors; }
 
-  /// Returns true if bit corresponds to queried dim.
-  bool isDim(unsigned b, Dim d) const { return isDim(tensor(b), index(b), d); }
-
   /// Returns true if bit corresponds to index of output tensor.
   bool isOutTensor(unsigned b, unsigned i) const {
     return tensor(b) == outTensor && index(b) == i;
   }
 
-  /// Returns true if tensor access at given index has queried dim.
-  bool isDim(unsigned t, unsigned i, Dim d) const {
-    assert(t < numTensors && i < numLoops);
-    return dims[t][i] == d;
-  }
-
-  /// Returns true if any set bit corresponds to queried dim.
-  bool hasAnyDimOf(const BitVector &bits, Dim d) const;
-
   /// Returns true if given tensor iterates *only* in the given tensor
   /// expression. For the output tensor, this defines a "simply dynamic"
   /// operation [Bik96]. For instance: a(i) *= 2.0 or a(i) += a(i) for
   /// sparse vector a.
   bool isSingleCondition(unsigned t, unsigned e) const;
 
-  /// Dimension setter.
-  void setDim(unsigned t, unsigned i, Dim d) { dims[t][i] = d; }
+  /// Returns true if bit corresponds to given dimension level type.
+  bool isDimLevelType(unsigned b, DimLvlType tp) const {
+    return isDimLevelType(tensor(b), index(b), tp);
+  }
+
+  /// Returns true if tensor access at index has given dimension level type.
+  bool isDimLevelType(unsigned t, unsigned i, DimLvlType tp) const {
+    return getDimLevelFormat(t, i).levelType == tp;
+  }
+
+  /// Returns true if any set bit corresponds to given dimension level type.
+  bool hasAnyDimLevelTypeOf(const BitVector &bits, DimLvlType tp) const;
+
+  /// Dimension level format getter.
+  DimLevelFormat getDimLevelFormat(unsigned t, unsigned i) const {
+    assert(t < numTensors && i < numLoops);
+    return dims[t][i];
+  }
+
+  /// Dimension level format setter.
+  void setDimLevelFormat(unsigned t, unsigned i, DimLevelFormat d) {
+    dims[t][i] = d;
+  }
 
   // Has sparse output tensor setter.
   void setHasSparseOut(bool s) { hasSparseOut = s; }
@@ -298,7 +322,7 @@ class Merger {
   const unsigned numTensors;
   const unsigned numLoops;
   bool hasSparseOut;
-  std::vector<std::vector<Dim>> dims;
+  std::vector<std::vector<DimLevelFormat>> dims;
   llvm::SmallVector<TensorExp, 32> tensorExps;
   llvm::SmallVector<LatPoint, 16> latPoints;
   llvm::SmallVector<SmallVector<unsigned, 16>, 8> latSets;

diff  --git a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
index aa44a7f078785..642c66635a38d 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
@@ -141,40 +141,60 @@ static unsigned perm(const SparseTensorEncodingAttr &enc, unsigned d) {
   return d;
 }
 
-/// Helper method to translate dim level type to internal representation.
-static Dim toDim(const SparseTensorEncodingAttr &enc, unsigned d) {
+/// Helper method to obtain the dimension level format from the encoding.
+//
+//  TODO: note that we store, but currently completely *ignore* the properties
+//
+static DimLevelFormat toDimLevelFormat(const SparseTensorEncodingAttr &enc,
+                                       unsigned d) {
   if (enc) {
-    SparseTensorEncodingAttr::DimLevelType tp = enc.getDimLevelType()[d];
-    if (tp == SparseTensorEncodingAttr::DimLevelType::Compressed)
-      return Dim::kSparse;
+    switch (enc.getDimLevelType()[d]) {
+    case SparseTensorEncodingAttr::DimLevelType::Dense:
+      return DimLevelFormat(DimLvlType::kDense);
+    case SparseTensorEncodingAttr::DimLevelType::Compressed:
+      return DimLevelFormat(DimLvlType::kCompressed);
+    case SparseTensorEncodingAttr::DimLevelType::CompressedNu:
+      return DimLevelFormat(DimLvlType::kCompressed, true, false);
+    case SparseTensorEncodingAttr::DimLevelType::CompressedNo:
+      return DimLevelFormat(DimLvlType::kCompressed, false, true);
+    case SparseTensorEncodingAttr::DimLevelType::CompressedNuNo:
+      return DimLevelFormat(DimLvlType::kCompressed, false, false);
+    case SparseTensorEncodingAttr::DimLevelType::Singleton:
+      return DimLevelFormat(DimLvlType::kSingleton);
+    case SparseTensorEncodingAttr::DimLevelType::SingletonNu:
+      return DimLevelFormat(DimLvlType::kSingleton, true, false);
+    case SparseTensorEncodingAttr::DimLevelType::SingletonNo:
+      return DimLevelFormat(DimLvlType::kSingleton, false, true);
+    case SparseTensorEncodingAttr::DimLevelType::SingletonNuNo:
+      return DimLevelFormat(DimLvlType::kSingleton, false, false);
+    }
   }
-  return Dim::kDense;
+  return DimLevelFormat(DimLvlType::kDense);
 }
 
 /// Helper method to inspect affine expressions. Rejects cases where the
-/// same index is used more than once. Also rejects affine expressions
-/// that are not a direct index for annotated tensors.
-// TODO: accept more affine cases for sparse tensors
-static bool findAffine(Merger &merger, unsigned tensor, AffineExpr a, Dim dim,
-                       bool isDense) {
+/// same index is used more than once. Also rejects compound affine
+/// expressions in sparse dimensions.
+static bool findAffine(Merger &merger, unsigned tensor, AffineExpr a,
+                       DimLevelFormat dim) {
   switch (a.getKind()) {
   case AffineExprKind::DimId: {
     unsigned idx = a.cast<AffineDimExpr>().getPosition();
-    if (!merger.isDim(tensor, idx, Dim::kUndef))
+    if (!merger.isDimLevelType(tensor, idx, DimLvlType::kUndef))
       return false; // used more than once
-    merger.setDim(tensor, idx, dim);
+    merger.setDimLevelFormat(tensor, idx, dim);
     return true;
   }
   case AffineExprKind::Add:
   case AffineExprKind::Mul: {
-    if (!isDense)
-      return false;
+    if (dim.levelType != DimLvlType::kDense)
+      return false; // compound only in dense dim
     auto binOp = a.cast<AffineBinaryOpExpr>();
-    return findAffine(merger, tensor, binOp.getLHS(), dim, isDense) &&
-           findAffine(merger, tensor, binOp.getRHS(), dim, isDense);
+    return findAffine(merger, tensor, binOp.getLHS(), dim) &&
+           findAffine(merger, tensor, binOp.getRHS(), dim);
   }
   case AffineExprKind::Constant:
-    return isDense;
+    return dim.levelType == DimLvlType::kDense; // const only in dense dim
   default:
     return false;
   }
@@ -196,7 +216,7 @@ static bool findSparseAnnotations(Merger &merger, linalg::GenericOp op) {
     for (unsigned d = 0, rank = map.getNumResults(); d < rank; d++) {
       unsigned tensor = t->getOperandNumber();
       AffineExpr a = map.getResult(perm(enc, d));
-      if (!findAffine(merger, tensor, a, toDim(enc, d), !enc))
+      if (!findAffine(merger, tensor, a, toDimLevelFormat(enc, d)))
         return false; // inadmissable affine expression
     }
   }
@@ -286,13 +306,13 @@ static bool computeIterationGraph(Merger &merger, linalg::GenericOp op,
     if (mask & SortMask::kIncludeUndef) {
       unsigned tensor = t->getOperandNumber();
       for (unsigned i = 0; i < n; i++)
-        if (merger.isDim(tensor, i, Dim::kSparse))
+        if (merger.isDimLevelType(tensor, i, DimLvlType::kCompressed) ||
+            merger.isDimLevelType(tensor, i, DimLvlType::kSingleton))
           for (unsigned j = 0; j < n; j++)
-            if (merger.isDim(tensor, j, Dim::kUndef))
+            if (merger.isDimLevelType(tensor, j, DimLvlType::kUndef))
               adjM[i][j] = true;
     }
   }
-
   // Topologically sort the iteration graph to determine loop order.
   // Report failure for a cyclic iteration graph.
   topSort.clear();
@@ -334,7 +354,8 @@ static bool isAdmissableTensorExp(Merger &merger, linalg::GenericOp op,
   auto iteratorTypes = op.iterator_types().getValue();
   unsigned numLoops = iteratorTypes.size();
   for (unsigned i = 0; i < numLoops; i++)
-    if (merger.isDim(tensor, i, Dim::kSparse)) {
+    if (merger.isDimLevelType(tensor, i, DimLvlType::kCompressed) ||
+        merger.isDimLevelType(tensor, i, DimLvlType::kSingleton)) {
       allDense = false;
       break;
     }
@@ -519,7 +540,7 @@ static void genBuffers(Merger &merger, CodeGen &codegen, OpBuilder &builder,
         continue; // compound
       unsigned idx = a.cast<AffineDimExpr>().getPosition();
       // Handle sparse storage schemes.
-      if (merger.isDim(tensor, idx, Dim::kSparse)) {
+      if (merger.isDimLevelType(tensor, idx, DimLvlType::kCompressed)) {
         auto dynShape = {ShapedType::kDynamicSize};
         auto ptrTp =
             MemRefType::get(dynShape, getPointerOverheadType(builder, enc));
@@ -531,6 +552,8 @@ static void genBuffers(Merger &merger, CodeGen &codegen, OpBuilder &builder,
             builder.create<ToPointersOp>(loc, ptrTp, t->get(), dim);
         codegen.indices[tensor][idx] =
             builder.create<ToIndicesOp>(loc, indTp, t->get(), dim);
+      } else if (merger.isDimLevelType(tensor, idx, DimLvlType::kSingleton)) {
+        llvm_unreachable("TODO: not implemented yet");
       }
       // Find upper bound in current dimension.
       unsigned p = perm(enc, d);
@@ -543,7 +566,6 @@ static void genBuffers(Merger &merger, CodeGen &codegen, OpBuilder &builder,
     // Perform the required bufferization. Dense inputs materialize
     // from the input tensors. Dense outputs need special handling.
     // Sparse inputs use sparse primitives to obtain the values.
-    // We also accept in-place all-dense annotated "sparse" outputs.
     Type elementType = getElementTypeOrSelf(t->get().getType());
     if (!enc) {
       // Non-annotated dense tensors.
@@ -985,11 +1007,13 @@ static Value genExp(Merger &merger, CodeGen &codegen, RewriterBase &rewriter,
     return genInvariantValue(merger, codegen, rewriter, exp);
   if (merger.exp(exp).kind == Kind::kIndex)
     return genIndexValue(codegen, rewriter, merger.exp(exp).index, ldx);
+
   if (merger.exp(exp).kind == Kind::kReduce) {
     // Make custom reduction identity accessible for expanded access pattern.
     assert(codegen.redCustom == -1u);
     codegen.redCustom = exp;
   }
+
   Value v0 =
       genExp(merger, codegen, rewriter, op, merger.exp(exp).children.e0, ldx);
   Value v1 =
@@ -1000,8 +1024,12 @@ static Value genExp(Merger &merger, CodeGen &codegen, RewriterBase &rewriter,
              merger.exp(exp).kind == Kind::kBinaryBranch ||
              merger.exp(exp).kind == Kind::kReduce))
     ee = relinkBranch(codegen, rewriter, ee.getParentBlock(), ee, ldx);
-  if (merger.exp(exp).kind == Kind::kReduce)
+
+  if (merger.exp(exp).kind == Kind::kReduce) {
+    assert(codegen.redCustom != -1u);
     codegen.redCustom = -1u;
+  }
+
   return ee;
 }
 
@@ -1029,7 +1057,7 @@ static bool isInvariantAffine(const CodeGen &codegen, AffineExpr a,
 /// Hoists loop invariant tensor loads for which indices have been exhausted.
 static void genInvariants(Merger &merger, CodeGen &codegen, OpBuilder &builder,
                           linalg::GenericOp op, unsigned exp, unsigned ldx,
-                          bool atStart, unsigned last = 0) {
+                          bool atStart, unsigned last = -1u) {
   if (exp == -1u)
     return;
   if (merger.exp(exp).kind == Kind::kTensor) {
@@ -1131,7 +1159,7 @@ static bool genInit(Merger &merger, CodeGen &codegen, OpBuilder &builder,
     if (inits[b]) {
       unsigned tensor = merger.tensor(b);
       assert(idx == merger.index(b));
-      if (merger.isDim(b, Dim::kSparse)) {
+      if (merger.isDimLevelType(b, DimLvlType::kCompressed)) {
         // Initialize sparse index.
         unsigned pat = at;
         for (; pat != 0; pat--) {
@@ -1145,6 +1173,8 @@ static bool genInit(Merger &merger, CodeGen &codegen, OpBuilder &builder,
         codegen.pidxs[tensor][idx] = genLoad(codegen, builder, loc, ptr, p0);
         Value p1 = builder.create<arith::AddIOp>(loc, p0, one);
         codegen.highs[tensor][idx] = genLoad(codegen, builder, loc, ptr, p1);
+      } else if (merger.isDimLevelType(b, DimLvlType::kSingleton)) {
+        llvm_unreachable("TODO: not implemented yet");
       } else {
         // Dense index still in play.
         needsUniv = true;
@@ -1235,12 +1265,14 @@ static Operation *genFor(Merger &merger, CodeGen &codegen, OpBuilder &builder,
   assert(idx == merger.index(fb));
   auto iteratorTypes = op.iterator_types().getValue();
   bool isReduction = linalg::isReductionIterator(iteratorTypes[idx]);
-  bool isSparse = merger.isDim(fb, Dim::kSparse);
+  bool isSparse = merger.isDimLevelType(fb, DimLvlType::kCompressed);
   bool isVector = isVectorFor(codegen, isInner, isReduction, isSparse) &&
                   denseUnitStrides(merger, op, idx);
   bool isParallel =
       isParallelFor(codegen, isOuter, isReduction, isSparse, isVector);
 
+  assert(!merger.isDimLevelType(fb, DimLvlType::kSingleton) && "TODO: implement");
+
   // Prepare vector length.
   if (isVector)
     codegen.curVecLength = codegen.options.vectorLength;
@@ -1308,7 +1340,7 @@ static Operation *genWhile(Merger &merger, CodeGen &codegen, OpBuilder &builder,
   // Construct the while-loop with a parameter for each index.
   Type indexType = builder.getIndexType();
   for (unsigned b = 0, be = indices.size(); b < be; b++) {
-    if (indices[b] && merger.isDim(b, Dim::kSparse)) {
+    if (indices[b] && merger.isDimLevelType(b, DimLvlType::kCompressed)) {
       unsigned tensor = merger.tensor(b);
       assert(idx == merger.index(b));
       types.push_back(indexType);
@@ -1341,7 +1373,8 @@ static Operation *genWhile(Merger &merger, CodeGen &codegen, OpBuilder &builder,
   Value cond;
   unsigned o = 0;
   for (unsigned b = 0, be = indices.size(); b < be; b++) {
-    if (indices[b] && merger.isDim(b, Dim::kSparse)) {
+    // TODO: singleton
+    if (indices[b] && merger.isDimLevelType(b, DimLvlType::kCompressed)) {
       unsigned tensor = merger.tensor(b);
       assert(idx == merger.index(b));
       Value op1 = before->getArgument(o);
@@ -1389,7 +1422,8 @@ static void genLocals(Merger &merger, CodeGen &codegen, OpBuilder &builder,
   // Initialize sparse indices.
   Value min;
   for (unsigned b = 0, be = locals.size(); b < be; b++) {
-    if (locals[b] && merger.isDim(b, Dim::kSparse)) {
+    // TODO: singleton
+    if (locals[b] && merger.isDimLevelType(b, DimLvlType::kCompressed)) {
       unsigned tensor = merger.tensor(b);
       assert(idx == merger.index(b));
       Value ptr = codegen.indices[tensor][idx];
@@ -1419,7 +1453,7 @@ static void genLocals(Merger &merger, CodeGen &codegen, OpBuilder &builder,
   // but may be needed for linearized codegen.
   for (unsigned b = 0, be = locals.size(); b < be; b++) {
     if ((locals[b] || merger.isOutTensor(b, idx)) &&
-        merger.isDim(b, Dim::kDense)) {
+        merger.isDimLevelType(b, DimLvlType::kDense)) {
       unsigned tensor = merger.tensor(b);
       assert(idx == merger.index(b));
       unsigned pat = at;
@@ -1477,7 +1511,8 @@ static void genWhileInduction(Merger &merger, CodeGen &codegen,
   SmallVector<Value, 4> operands;
   Value one = constantIndex(builder, loc, 1);
   for (unsigned b = 0, be = induction.size(); b < be; b++) {
-    if (induction[b] && merger.isDim(b, Dim::kSparse)) {
+    // TODO: singleton
+    if (induction[b] && merger.isDimLevelType(b, DimLvlType::kCompressed)) {
       unsigned tensor = merger.tensor(b);
       assert(idx == merger.index(b));
       Value op1 = codegen.idxs[tensor][idx];
@@ -1541,7 +1576,8 @@ static scf::IfOp genIf(Merger &merger, CodeGen &codegen, OpBuilder &builder,
       unsigned tensor = merger.tensor(b);
       assert(idx == merger.index(b));
       Value clause;
-      if (merger.isDim(b, Dim::kSparse)) {
+      // TODO: singleton
+      if (merger.isDimLevelType(b, DimLvlType::kCompressed)) {
         Value op1 = codegen.idxs[tensor][idx];
         Value op2 = codegen.loops[idx];
         clause = builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq,
@@ -1605,7 +1641,8 @@ static bool startLoopSeq(Merger &merger, CodeGen &codegen, OpBuilder &builder,
     unsigned lsize = merger.set(lts).size();
     for (unsigned i = 1; i < lsize; i++) {
       unsigned li = merger.set(lts)[i];
-      if (!merger.hasAnyDimOf(merger.lat(li).simple, Dim::kSparse))
+      if (!merger.hasAnyDimLevelTypeOf(merger.lat(li).simple, DimLvlType::kCompressed) &&
+          !merger.hasAnyDimLevelTypeOf(merger.lat(li).simple, DimLvlType::kSingleton))
         return true;
     }
   }

diff  --git a/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp b/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp
index eeaaa2e8e2e9d..cd5d1347c1356 100644
--- a/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp
@@ -262,10 +262,17 @@ BitVector Merger::simplifyCond(unsigned s0, unsigned p0) {
     }
   }
   // Now apply the two basic rules.
+  //
+  // TODO: improve for singleton and properties
+  //
   BitVector simple = latPoints[p0].bits;
-  bool reset = isSingleton && hasAnyDimOf(simple, kSparse);
+  bool reset = isSingleton &&
+      (hasAnyDimLevelTypeOf(simple, DimLvlType::kCompressed) ||
+       hasAnyDimLevelTypeOf(simple, DimLvlType::kSingleton));
   for (unsigned b = 0, be = simple.size(); b < be; b++) {
-    if (simple[b] && !isDim(b, kSparse)) {
+    if (simple[b] &&
+        (!isDimLevelType(b, DimLvlType::kCompressed) &&
+         !isDimLevelType(b, DimLvlType::kSingleton))) {
       if (reset)
         simple.reset(b);
       reset = true;
@@ -290,14 +297,8 @@ bool Merger::latGT(unsigned i, unsigned j) const {
 bool Merger::onlyDenseDiff(unsigned i, unsigned j) {
   BitVector tmp = latPoints[j].bits;
   tmp ^= latPoints[i].bits;
-  return !hasAnyDimOf(tmp, kSparse);
-}
-
-bool Merger::hasAnyDimOf(const BitVector &bits, Dim d) const {
-  for (unsigned b = 0, be = bits.size(); b < be; b++)
-    if (bits[b] && isDim(b, d))
-      return true;
-  return false;
+  return !hasAnyDimLevelTypeOf(tmp, DimLvlType::kCompressed) &&
+         !hasAnyDimLevelTypeOf(tmp, DimLvlType::kSingleton);
 }
 
 bool Merger::isSingleCondition(unsigned t, unsigned e) const {
@@ -383,6 +384,13 @@ bool Merger::isSingleCondition(unsigned t, unsigned e) const {
   llvm_unreachable("unexpected kind");
 }
 
+bool Merger::hasAnyDimLevelTypeOf(const BitVector &bits, DimLvlType tp) const {
+  for (unsigned b = 0, be = bits.size(); b < be; b++)
+    if (bits[b] && isDimLevelType(b, tp))
+      return true;
+  return false;
+}
+
 #ifndef NDEBUG
 
 //===----------------------------------------------------------------------===//
@@ -591,18 +599,23 @@ void Merger::dumpBits(const BitVector &bits) const {
     if (bits[b]) {
       unsigned t = tensor(b);
       unsigned i = index(b);
+      DimLevelFormat f = dims[t][i];
       llvm::dbgs() << " i_" << t << "_" << i << "_";
-      switch (dims[t][i]) {
-      case kSparse:
-        llvm::dbgs() << "S";
-        break;
-      case kDense:
+      switch (f.levelType) {
+      case DimLvlType::kDense:
         llvm::dbgs() << "D";
         break;
-      case kUndef:
+      case DimLvlType::kCompressed:
+        llvm::dbgs() << "C";
+        break;
+      case DimLvlType::kSingleton:
+        llvm::dbgs() << "S";
+        break;
+      case DimLvlType::kUndef:
         llvm::dbgs() << "U";
         break;
       }
+      llvm::dbgs() << "[O=" << f.isOrdered << ",U=" << f.isUnique << "]";
     }
   }
 }
@@ -855,9 +868,8 @@ static bool isAdmissableBranchExp(Operation *op, Block *block, Value v) {
   if (isa<linalg::IndexOp>(def))
     return true;
   // Operation defined outside branch.
-  if (def->getBlock() != block) {
+  if (def->getBlock() != block)
     return def->getBlock() != op->getBlock(); // invariant?
-  }
   // Operation defined within branch. Anything is accepted,
   // as long as all subexpressions are admissable.
   for (unsigned i = 0, n = def->getNumOperands(); i < n; i++)
@@ -1038,7 +1050,6 @@ Optional<unsigned> Merger::buildTensorExp(linalg::GenericOp op, Value v) {
     if (x.has_value() && y.has_value() && z.has_value()) {
       unsigned e0 = x.value();
       unsigned e1 = y.value();
-      // unsigned e2 = z.getValue();
       if (auto redop = dyn_cast<sparse_tensor::ReduceOp>(def)) {
         if (isAdmissableBranch(redop, redop.getRegion()))
           return addExp(kReduce, e0, e1, Value(), def);

diff  --git a/mlir/unittests/Dialect/SparseTensor/MergerTest.cpp b/mlir/unittests/Dialect/SparseTensor/MergerTest.cpp
index 3bf424c40a30d..8d41558d31e98 100644
--- a/mlir/unittests/Dialect/SparseTensor/MergerTest.cpp
+++ b/mlir/unittests/Dialect/SparseTensor/MergerTest.cpp
@@ -310,15 +310,15 @@ class MergerTest3T1L : public MergerTestBase {
   MergerTest3T1L() : MergerTestBase(3, 1) {
     // Tensor 0: sparse input vector.
     merger.addExp(Kind::kTensor, t0, -1u);
-    merger.setDim(t0, l0, Dim::kSparse);
+    merger.setDimLevelFormat(t0, l0, DimLevelFormat(DimLvlType::kCompressed));
 
     // Tensor 1: sparse input vector.
     merger.addExp(Kind::kTensor, t1, -1u);
-    merger.setDim(t1, l0, Dim::kSparse);
+    merger.setDimLevelFormat(t1, l0, DimLevelFormat(DimLvlType::kCompressed));
 
     // Tensor 2: dense output vector.
     merger.addExp(Kind::kTensor, t2, -1u);
-    merger.setDim(t2, l0, Dim::kDense);
+    merger.setDimLevelFormat(t2, l0, DimLevelFormat(DimLvlType::kDense));
   }
 };
 
@@ -333,19 +333,19 @@ class MergerTest4T1L : public MergerTestBase {
   MergerTest4T1L() : MergerTestBase(4, 1) {
     // Tensor 0: sparse input vector.
     merger.addExp(Kind::kTensor, t0, -1u);
-    merger.setDim(t0, l0, Dim::kSparse);
+    merger.setDimLevelFormat(t0, l0, DimLevelFormat(DimLvlType::kCompressed));
 
     // Tensor 1: sparse input vector.
     merger.addExp(Kind::kTensor, t1, -1u);
-    merger.setDim(t1, l0, Dim::kSparse);
+    merger.setDimLevelFormat(t1, l0, DimLevelFormat(DimLvlType::kCompressed));
 
     // Tensor 2: sparse input vector
     merger.addExp(Kind::kTensor, t2, -1u);
-    merger.setDim(t2, l0, Dim::kSparse);
+    merger.setDimLevelFormat(t2, l0, DimLevelFormat(DimLvlType::kCompressed));
 
     // Tensor 3: dense output vector
     merger.addExp(Kind::kTensor, t3, -1u);
-    merger.setDim(t3, l0, Dim::kDense);
+    merger.setDimLevelFormat(t3, l0, DimLevelFormat(DimLvlType::kDense));
   }
 };
 
@@ -364,15 +364,15 @@ class MergerTest3T1LD : public MergerTestBase {
   MergerTest3T1LD() : MergerTestBase(3, 1) {
     // Tensor 0: sparse input vector.
     merger.addExp(Kind::kTensor, t0, -1u);
-    merger.setDim(t0, l0, Dim::kSparse);
+    merger.setDimLevelFormat(t0, l0, DimLevelFormat(DimLvlType::kCompressed));
 
     // Tensor 1: dense input vector.
     merger.addExp(Kind::kTensor, t1, -1u);
-    merger.setDim(t1, l0, Dim::kDense);
+    merger.setDimLevelFormat(t1, l0, DimLevelFormat(DimLvlType::kDense));
 
     // Tensor 2: dense output vector.
     merger.addExp(Kind::kTensor, t2, -1u);
-    merger.setDim(t2, l0, Dim::kDense);
+    merger.setDimLevelFormat(t2, l0, DimLevelFormat(DimLvlType::kDense));
   }
 };
 


        


More information about the Mlir-commits mailing list