[Mlir-commits] [mlir] 744146f - [MLIR][Sparse] Refactor lattice code into its own file

Gus Smith llvmlistbot at llvm.org
Thu Jun 24 16:03:52 PDT 2021


Author: Gus Smith
Date: 2021-06-24T23:03:44Z
New Revision: 744146f60bbf74872039871dee771d18f69bff89

URL: https://github.com/llvm/llvm-project/commit/744146f60bbf74872039871dee771d18f69bff89
DIFF: https://github.com/llvm/llvm-project/commit/744146f60bbf74872039871dee771d18f69bff89.diff

LOG: [MLIR][Sparse] Refactor lattice code into its own file

Moves iteration lattice/merger code into new SparseTensor/Utils directory. A follow-up CL will add lattice/merger unit tests.

Reviewed By: aartbik

Differential Revision: https://reviews.llvm.org/D104757

Added: 
    mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h
    mlir/lib/Dialect/SparseTensor/Utils/CMakeLists.txt
    mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp

Modified: 
    mlir/lib/Dialect/SparseTensor/CMakeLists.txt
    mlir/lib/Dialect/SparseTensor/Transforms/CMakeLists.txt
    mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp

Removed: 
    


################################################################################
diff --git a/mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h b/mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h
new file mode 100644
index 0000000000000..0ffd00131fd44
--- /dev/null
+++ b/mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h
@@ -0,0 +1,163 @@
+//===- Merger.h - Utilities for defining lattices ---------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This header file defines utilities for dealing with iteration lattices.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_SPARSETENSOR_UTILS_MERGER_H_
+#define MLIR_DIALECT_SPARSETENSOR_UTILS_MERGER_H_
+
+#include "mlir/IR/Value.h"
+#include "llvm/ADT/BitVector.h"
+
+namespace mlir {
+namespace sparse_tensor {
+
+enum class Kind { kTensor, kInvariant, kMulF, kMulI, kAddF, kAddI };
+enum class Dim { kSparse, kDense, kSingle, kUndef };
+
+/// Tensor expression. Represents a MLIR expression in tensor index notation.
+/// For tensors, e0 denotes the tensor index. For invariants, the IR value is
+/// stored directly. For binary operations, e0 and e1 denote the index of the
+/// children tensor expressions.
+struct TensorExp {
+  TensorExp(Kind k, unsigned x, unsigned y, Value v)
+      : kind(k), e0(x), e1(y), val(v) {
+    assert((kind == Kind::kTensor && e0 != -1u && e1 == -1u && !val) ||
+           (kind == Kind::kInvariant && e0 == -1u && e1 == -1u && val) ||
+           (kind >= Kind::kMulF && e0 != -1u && e1 != -1u && !val));
+  }
+  Kind kind;
+  /// Indices of children expression(s).
+  unsigned e0;
+  unsigned e1;
+  /// Direct link to IR for an invariant. During code generation,
+  /// field is used to cache "hoisted" loop invariant tensor loads.
+  Value val;
+};
+
+/// Lattice point. Each lattice point consists of a conjunction of tensor
+/// loop indices (encoded in a bitvector) and the index of the corresponding
+/// tensor expression.
+struct LatPoint {
+  LatPoint(unsigned n, unsigned e, unsigned b) : bits(n, false), exp(e) {
+    bits.set(b);
+  }
+  LatPoint(const llvm::BitVector &b, unsigned e) : bits(b), exp(e) {}
+  /// Conjunction of tensor loop indices as bitvector. This represents
+  /// all indices involved in the tensor expression
+  llvm::BitVector bits;
+  /// Simplified conjunction of tensor loop indices as bitvector. This
+  /// represents a simplified condition under which this tensor expression
+  /// must execute. Pre-computed during codegen to avoid repeated eval.
+  llvm::BitVector simple;
+  /// Index of the tensor expresssion.
+  unsigned exp;
+};
+
+/// A class to handle all iteration lattice operations. This class abstracts
+/// away from some implementation details of storing iteration lattices and
+/// tensor expressions. This allows for fine-tuning performance characteristics
+/// independently from the basic algorithm if bottlenecks are identified.
+class Merger {
+public:
+  /// Constructs a merger for the given number of tensors and loops. The
+  /// user supplies the number of tensors involved in the kernel, with the
+  /// last tensor in this set denoting the output tensor. The merger adds an
+  /// additional synthetic tensor at the end of this set to represent all
+  /// invariant expressions in the kernel.
+  Merger(unsigned t, unsigned l)
+      : outTensor(t - 1), numTensors(t + 1), numLoops(l),
+        dims(t + 1, std::vector<Dim>(l, Dim::kUndef)) {}
+
+  /// Adds a tensor expression. Returns its index.
+  unsigned addExp(Kind k, unsigned e0, unsigned e1 = -1u, Value v = Value());
+  unsigned addExp(Kind k, Value v) { return addExp(k, -1u, -1u, v); }
+
+  /// Adds an iteration lattice point. Returns its index.
+  unsigned addLat(unsigned t, unsigned i, unsigned e);
+
+  /// Adds a new, initially empty, set. Returns its index.
+  unsigned addSet();
+
+  /// Computes a single conjunction of two lattice points by taking the "union"
+  /// of loop indices (effectively constructing a larger "intersection" of those
+  /// indices) with a newly constructed tensor (sub)expression of given kind.
+  /// Returns the index of the new lattice point.
+  unsigned conjLatPoint(Kind kind, unsigned p0, unsigned p1);
+
+  /// Conjunctive merge of two lattice sets L0 and L1 is conjunction of
+  /// cartesian product. Returns the index of the new set.
+  unsigned takeConj(Kind kind, unsigned s0, unsigned s1);
+
+  /// Disjunctive merge of two lattice sets L0 and L1 is (L0 /\_op L1, L0, L1).
+  /// Returns the index of the new set.
+  unsigned takeDisj(Kind kind, unsigned s0, unsigned s1);
+
+  /// Optimizes the iteration lattice points in the given set. This
+  /// method should be called right before code generation to avoid
+  /// generating redundant loops and conditions.
+  unsigned optimizeSet(unsigned s0);
+
+  /// Simplifies the conditions in a conjunction of a given lattice point
+  /// within the given set using just two basic rules:
+  /// (1) multiple dense conditions are reduced to single dense, and
+  /// (2) a *singleton* sparse/dense is reduced to sparse/random access.
+  llvm::BitVector simplifyCond(unsigned s, unsigned p0);
+
+  /// Returns true if Li > Lj.
+  bool latGT(unsigned i, unsigned j) const;
+
+  /// Returns true if Li and Lj only differ in dense.
+  bool onlyDenseDiff(unsigned i, unsigned j);
+
+  /// Bit translation.
+  unsigned tensor(unsigned b) const { return b % numTensors; }
+  unsigned index(unsigned b) const { return b / numTensors; }
+
+  /// Returns true if bit corresponds to queried dim.
+  bool isDim(unsigned b, Dim d) const { return isDim(tensor(b), index(b), d); }
+
+  /// Returns true if bit corresponds to index of output tensor.
+  bool isOutTensor(unsigned b, unsigned i) const {
+    return tensor(b) == outTensor && index(b) == i;
+  }
+
+  /// Returns true if tensor access at given index has queried dim.
+  bool isDim(unsigned t, unsigned i, Dim d) const {
+    assert(t < numTensors && i < numLoops);
+    return dims[t][i] == d;
+  }
+
+  /// Returns true if any set bit corresponds to queried dim.
+  bool hasAnyDimOf(const llvm::BitVector &bits, Dim d) const;
+
+  /// Setter
+  void setDim(unsigned t, unsigned i, Dim d) { dims[t][i] = d; }
+
+  /// Getters.
+  TensorExp &exp(unsigned e) { return tensorExps[e]; }
+  LatPoint &lat(unsigned l) { return latPoints[l]; }
+  SmallVector<unsigned, 16> &set(unsigned s) { return latSets[s]; }
+
+private:
+  const unsigned outTensor;
+  const unsigned numTensors;
+  const unsigned numLoops;
+
+  std::vector<std::vector<Dim>> dims;
+  llvm::SmallVector<TensorExp, 32> tensorExps;
+  llvm::SmallVector<LatPoint, 16> latPoints;
+  llvm::SmallVector<SmallVector<unsigned, 16>, 8> latSets;
+};
+
+} // namespace sparse_tensor
+} // namespace mlir
+
+#endif // MLIR_DIALECT_SPARSETENSOR_UTILS_MERGER_H_

diff --git a/mlir/lib/Dialect/SparseTensor/CMakeLists.txt b/mlir/lib/Dialect/SparseTensor/CMakeLists.txt
index 9f57627c321fb..31167e6af908b 100644
--- a/mlir/lib/Dialect/SparseTensor/CMakeLists.txt
+++ b/mlir/lib/Dialect/SparseTensor/CMakeLists.txt
@@ -1,2 +1,3 @@
 add_subdirectory(IR)
 add_subdirectory(Transforms)
+add_subdirectory(Utils)

diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/CMakeLists.txt b/mlir/lib/Dialect/SparseTensor/Transforms/CMakeLists.txt
index 68adb6fe1db18..24600aace642c 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/CMakeLists.txt
@@ -19,6 +19,7 @@ add_mlir_dialect_library(MLIRSparseTensorTransforms
   MLIRSCF
   MLIRStandard
   MLIRSparseTensor
+  MLIRSparseTensorUtils
   MLIRTensor
   MLIRTransforms
   MLIRVector

diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
index 9c406a36f0728..f12aaccb31692 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
@@ -47,6 +47,7 @@
 #include "mlir/Dialect/SCF/SCF.h"
 #include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
 #include "mlir/Dialect/SparseTensor/Transforms/Passes.h"
+#include "mlir/Dialect/SparseTensor/Utils/Merger.h"
 #include "mlir/Dialect/StandardOps/IR/Ops.h"
 #include "mlir/Dialect/Vector/VectorOps.h"
 #include "mlir/IR/Matchers.h"
@@ -58,245 +59,6 @@ using namespace mlir::sparse_tensor;
 
 namespace {
 
-enum class Kind { kTensor, kInvariant, kMulF, kMulI, kAddF, kAddI };
-enum class Dim { kSparse, kDense, kSingle, kUndef };
-
-/// Tensor expression. Represents a MLIR expression in tensor index notation.
-/// For tensors, e0 denotes the tensor index. For invariants, the IR value is
-/// stored directly. For binary operations, e0 and e1 denote the index of the
-/// children tensor expressions.
-struct TensorExp {
-  TensorExp(Kind k, unsigned x, unsigned y, Value v)
-      : kind(k), e0(x), e1(y), val(v) {
-    assert((kind == Kind::kTensor && e0 != -1u && e1 == -1u && !val) ||
-           (kind == Kind::kInvariant && e0 == -1u && e1 == -1u && val) ||
-           (kind >= Kind::kMulF && e0 != -1u && e1 != -1u && !val));
-  }
-  Kind kind;
-  /// Indices of children expression(s).
-  unsigned e0;
-  unsigned e1;
-  /// Direct link to IR for an invariant. During code generation,
-  /// field is used to cache "hoisted" loop invariant tensor loads.
-  Value val;
-};
-
-/// Lattice point. Each lattice point consists of a conjunction of tensor
-/// loop indices (encoded in a bitvector) and the index of the corresponding
-/// tensor expression.
-struct LatPoint {
-  LatPoint(unsigned n, unsigned e, unsigned b) : bits(n, false), exp(e) {
-    bits.set(b);
-  }
-  LatPoint(const llvm::BitVector &b, unsigned e) : bits(b), exp(e) {}
-  /// Conjunction of tensor loop indices as bitvector. This represents
-  /// all indices involved in the tensor expression
-  llvm::BitVector bits;
-  /// Simplified conjunction of tensor loop indices as bitvector. This
-  /// represents a simplified condition under which this tensor expression
-  /// must execute. Pre-computed during codegen to avoid repeated eval.
-  llvm::BitVector simple;
-  /// Index of the tensor expresssion.
-  unsigned exp;
-};
-
-/// A class to handle all iteration lattice operations. This class abstracts
-/// away from some implementation details of storing iteration lattices and
-/// tensor expressions. This allows for fine-tuning performance characteristics
-/// independently from the basic algorithm if bottlenecks are identified.
-class Merger {
-public:
-  /// Constructs a merger for the given number of tensors and loops. The
-  /// user supplies the number of tensors involved in the kernel, with the
-  /// last tensor in this set denoting the output tensor. The merger adds an
-  /// additional synthetic tensor at the end of this set to represent all
-  /// invariant expressions in the kernel.
-  Merger(unsigned t, unsigned l)
-      : outTensor(t - 1), numTensors(t + 1), numLoops(l),
-        dims(t + 1, std::vector<Dim>(l, Dim::kUndef)) {}
-
-  /// Adds a tensor expression. Returns its index.
-  unsigned addExp(Kind k, unsigned e0, unsigned e1 = -1u, Value v = Value()) {
-    unsigned e = tensorExps.size();
-    tensorExps.push_back(TensorExp(k, e0, e1, v));
-    return e;
-  }
-  unsigned addExp(Kind k, Value v) { return addExp(k, -1u, -1u, v); }
-
-  /// Adds an iteration lattice point. Returns its index.
-  unsigned addLat(unsigned t, unsigned i, unsigned e) {
-    assert(t < numTensors && i < numLoops);
-    unsigned p = latPoints.size();
-    latPoints.push_back(LatPoint(numLoops * numTensors, e, numTensors * i + t));
-    return p;
-  }
-
-  /// Adds a new, initially empty, set. Returns its index.
-  unsigned addSet() {
-    unsigned s = latSets.size();
-    latSets.emplace_back(SmallVector<unsigned, 16>());
-    return s;
-  }
-
-  /// Computes a single conjunction of two lattice points by taking the "union"
-  /// of loop indices (effectively constructing a larger "intersection" of those
-  /// indices) with a newly constructed tensor (sub)expression of given kind.
-  /// Returns the index of the new lattice point.
-  unsigned conjLatPoint(Kind kind, unsigned p0, unsigned p1) {
-    unsigned p = latPoints.size();
-    llvm::BitVector nb = llvm::BitVector(latPoints[p0].bits);
-    nb |= latPoints[p1].bits;
-    unsigned e = addExp(kind, latPoints[p0].exp, latPoints[p1].exp);
-    latPoints.push_back(LatPoint(nb, e));
-    return p;
-  }
-
-  /// Conjunctive merge of two lattice sets L0 and L1 is conjunction of
-  /// cartesian product. Returns the index of the new set.
-  unsigned takeConj(Kind kind, unsigned s0, unsigned s1) {
-    unsigned s = addSet();
-    for (unsigned p0 : latSets[s0])
-      for (unsigned p1 : latSets[s1])
-        latSets[s].push_back(conjLatPoint(kind, p0, p1));
-    return s;
-  }
-
-  /// Disjunctive merge of two lattice sets L0 and L1 is (L0 /\_op L1, L0, L1).
-  /// Returns the index of the new set.
-  unsigned takeDisj(Kind kind, unsigned s0, unsigned s1) {
-    unsigned s = takeConj(kind, s0, s1);
-    for (unsigned p : latSets[s0])
-      latSets[s].push_back(p);
-    for (unsigned p : latSets[s1])
-      latSets[s].push_back(p);
-    return s;
-  }
-
-  /// Optimizes the iteration lattice points in the given set. This
-  /// method should be called right before code generation to avoid
-  /// generating redundant loops and conditions.
-  unsigned optimizeSet(unsigned s0) {
-    unsigned s = addSet();
-    assert(latSets[s0].size() != 0);
-    unsigned p0 = latSets[s0][0];
-    for (unsigned p1 : latSets[s0]) {
-      bool add = true;
-      if (p0 != p1) {
-        // Is this a straightforward copy?
-        unsigned e = latPoints[p1].exp;
-        if (exp(e).kind == Kind::kTensor && exp(e).e0 == outTensor)
-          continue;
-        // Conjunction already covered?
-        for (unsigned p2 : latSets[s]) {
-          assert(!latGT(p1, p2)); // Lj => Li would be bad
-          if (onlyDenseDiff(p2, p1)) {
-            add = false;
-            break;
-          }
-        }
-        assert(!add || latGT(p0, p1));
-      }
-      if (add)
-        latSets[s].push_back(p1);
-    }
-    for (unsigned p : latSets[s])
-      latPoints[p].simple = simplifyCond(s, p);
-    return s;
-  }
-
-  /// Simplifies the conditions in a conjunction of a given lattice point
-  /// within the given set using just two basic rules:
-  /// (1) multiple dense conditions are reduced to single dense, and
-  /// (2) a *singleton* sparse/dense is reduced to sparse/random access.
-  llvm::BitVector simplifyCond(unsigned s, unsigned p0) {
-    // First determine if this lattice point is a *singleton*, i.e.,
-    // the last point in a lattice, no other is less than this one.
-    bool isSingleton = true;
-    for (unsigned p1 : latSets[s]) {
-      if (p0 != p1 && latGT(p0, p1)) {
-        isSingleton = false;
-        break;
-      }
-    }
-    // Now apply the two basic rules.
-    llvm::BitVector simple = latPoints[p0].bits;
-    bool reset = isSingleton && hasAnyDimOf(simple, Dim::kSparse);
-    for (unsigned b = 0, be = simple.size(); b < be; b++) {
-      if (simple[b] && !isDim(b, Dim::kSparse)) {
-        if (reset)
-          simple.reset(b);
-        reset = true;
-      }
-    }
-    return simple;
-  }
-
-  /// Returns true if Li > Lj.
-  bool latGT(unsigned i, unsigned j) const {
-    const llvm::BitVector &bitsi = latPoints[i].bits;
-    const llvm::BitVector &bitsj = latPoints[j].bits;
-    assert(bitsi.size() == bitsj.size());
-    if (bitsi.count() > bitsj.count()) {
-      for (unsigned b = 0, be = bitsj.size(); b < be; b++)
-        if (bitsj[b] && !bitsi[b])
-          return false;
-      return true;
-    }
-    return false;
-  }
-
-  /// Returns true if Li and Lj only differ in dense.
-  bool onlyDenseDiff(unsigned i, unsigned j) {
-    llvm::BitVector tmp = latPoints[j].bits;
-    tmp ^= latPoints[i].bits;
-    return !hasAnyDimOf(tmp, Dim::kSparse);
-  }
-
-  /// Bit translation.
-  unsigned tensor(unsigned b) const { return b % numTensors; }
-  unsigned index(unsigned b) const { return b / numTensors; }
-
-  /// Returns true if bit corresponds to queried dim.
-  bool isDim(unsigned b, Dim d) const { return isDim(tensor(b), index(b), d); }
-
-  /// Returns true if bit corresponds to index of output tensor.
-  bool isOutTensor(unsigned b, unsigned i) const {
-    return tensor(b) == outTensor && index(b) == i;
-  }
-
-  /// Returns true if tensor access at given index has queried dim.
-  bool isDim(unsigned t, unsigned i, Dim d) const {
-    assert(t < numTensors && i < numLoops);
-    return dims[t][i] == d;
-  }
-
-  /// Returns true if any set bit corresponds to queried dim.
-  bool hasAnyDimOf(const llvm::BitVector &bits, Dim d) const {
-    for (unsigned b = 0, be = bits.size(); b < be; b++)
-      if (bits[b] && isDim(b, d))
-        return true;
-    return false;
-  }
-
-  /// Setter
-  void setDim(unsigned t, unsigned i, Dim d) { dims[t][i] = d; }
-
-  /// Getters.
-  TensorExp &exp(unsigned e) { return tensorExps[e]; }
-  LatPoint &lat(unsigned l) { return latPoints[l]; }
-  SmallVector<unsigned, 16> &set(unsigned s) { return latSets[s]; }
-
-private:
-  const unsigned outTensor;
-  const unsigned numTensors;
-  const unsigned numLoops;
-
-  std::vector<std::vector<Dim>> dims;
-  llvm::SmallVector<TensorExp, 32> tensorExps;
-  llvm::SmallVector<LatPoint, 16> latPoints;
-  llvm::SmallVector<SmallVector<unsigned, 16>, 8> latSets;
-};
-
 // Code generation.
 struct CodeGen {
   CodeGen(SparsificationOptions o, unsigned numTensors, unsigned numLoops)

diff --git a/mlir/lib/Dialect/SparseTensor/Utils/CMakeLists.txt b/mlir/lib/Dialect/SparseTensor/Utils/CMakeLists.txt
new file mode 100644
index 0000000000000..bfd614cb8df4f
--- /dev/null
+++ b/mlir/lib/Dialect/SparseTensor/Utils/CMakeLists.txt
@@ -0,0 +1,9 @@
+add_mlir_dialect_library(MLIRSparseTensorUtils
+  Merger.cpp
+
+  ADDITIONAL_HEADER_DIRS
+  ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/SparseTensor
+
+  LINK_LIBS PUBLIC
+  MLIRIR
+)

diff --git a/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp b/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp
new file mode 100644
index 0000000000000..0d1d34597afcc
--- /dev/null
+++ b/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp
@@ -0,0 +1,138 @@
+//===- Merger.cpp - Implementation of iteration lattices ------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/SparseTensor/Utils/Merger.h"
+
+namespace mlir {
+namespace sparse_tensor {
+
+unsigned Merger::addExp(Kind k, unsigned e0, unsigned e1, Value v) {
+  unsigned e = tensorExps.size();
+  tensorExps.push_back(TensorExp(k, e0, e1, v));
+  return e;
+}
+
+unsigned Merger::addLat(unsigned t, unsigned i, unsigned e) {
+  assert(t < numTensors && i < numLoops);
+  unsigned p = latPoints.size();
+  latPoints.push_back(LatPoint(numLoops * numTensors, e, numTensors * i + t));
+  return p;
+}
+
+unsigned Merger::addSet() {
+  unsigned s = latSets.size();
+  latSets.emplace_back(SmallVector<unsigned, 16>());
+  return s;
+}
+
+unsigned Merger::conjLatPoint(Kind kind, unsigned p0, unsigned p1) {
+  unsigned p = latPoints.size();
+  llvm::BitVector nb = llvm::BitVector(latPoints[p0].bits);
+  nb |= latPoints[p1].bits;
+  unsigned e = addExp(kind, latPoints[p0].exp, latPoints[p1].exp);
+  latPoints.push_back(LatPoint(nb, e));
+  return p;
+}
+
+unsigned Merger::takeConj(Kind kind, unsigned s0, unsigned s1) {
+  unsigned s = addSet();
+  for (unsigned p0 : latSets[s0])
+    for (unsigned p1 : latSets[s1])
+      latSets[s].push_back(conjLatPoint(kind, p0, p1));
+  return s;
+}
+
+unsigned Merger::takeDisj(Kind kind, unsigned s0, unsigned s1) {
+  unsigned s = takeConj(kind, s0, s1);
+  for (unsigned p : latSets[s0])
+    latSets[s].push_back(p);
+  for (unsigned p : latSets[s1])
+    latSets[s].push_back(p);
+  return s;
+}
+
+unsigned Merger::optimizeSet(unsigned s0) {
+  unsigned s = addSet();
+  assert(latSets[s0].size() != 0);
+  unsigned p0 = latSets[s0][0];
+  for (unsigned p1 : latSets[s0]) {
+    bool add = true;
+    if (p0 != p1) {
+      // Is this a straightforward copy?
+      unsigned e = latPoints[p1].exp;
+      if (exp(e).kind == Kind::kTensor && exp(e).e0 == outTensor)
+        continue;
+      // Conjunction already covered?
+      for (unsigned p2 : latSets[s]) {
+        assert(!latGT(p1, p2)); // Lj => Li would be bad
+        if (onlyDenseDiff(p2, p1)) {
+          add = false;
+          break;
+        }
+      }
+      assert(!add || latGT(p0, p1));
+    }
+    if (add)
+      latSets[s].push_back(p1);
+  }
+  for (unsigned p : latSets[s])
+    latPoints[p].simple = simplifyCond(s, p);
+  return s;
+}
+
+llvm::BitVector Merger::simplifyCond(unsigned s, unsigned p0) {
+  // First determine if this lattice point is a *singleton*, i.e.,
+  // the last point in a lattice, no other is less than this one.
+  bool isSingleton = true;
+  for (unsigned p1 : latSets[s]) {
+    if (p0 != p1 && latGT(p0, p1)) {
+      isSingleton = false;
+      break;
+    }
+  }
+  // Now apply the two basic rules.
+  llvm::BitVector simple = latPoints[p0].bits;
+  bool reset = isSingleton && hasAnyDimOf(simple, Dim::kSparse);
+  for (unsigned b = 0, be = simple.size(); b < be; b++) {
+    if (simple[b] && !isDim(b, Dim::kSparse)) {
+      if (reset)
+        simple.reset(b);
+      reset = true;
+    }
+  }
+  return simple;
+}
+
+bool Merger::latGT(unsigned i, unsigned j) const {
+  const llvm::BitVector &bitsi = latPoints[i].bits;
+  const llvm::BitVector &bitsj = latPoints[j].bits;
+  assert(bitsi.size() == bitsj.size());
+  if (bitsi.count() > bitsj.count()) {
+    for (unsigned b = 0, be = bitsj.size(); b < be; b++)
+      if (bitsj[b] && !bitsi[b])
+        return false;
+    return true;
+  }
+  return false;
+}
+
+bool Merger::onlyDenseDiff(unsigned i, unsigned j) {
+  llvm::BitVector tmp = latPoints[j].bits;
+  tmp ^= latPoints[i].bits;
+  return !hasAnyDimOf(tmp, Dim::kSparse);
+}
+
+bool Merger::hasAnyDimOf(const llvm::BitVector &bits, Dim d) const {
+  for (unsigned b = 0, be = bits.size(); b < be; b++)
+    if (bits[b] && isDim(b, d))
+      return true;
+  return false;
+}
+
+} // namespace sparse_tensor
+} // namespace mlir


        


More information about the Mlir-commits mailing list