[Mlir-commits] [mlir] 744146f - [MLIR][Sparse] Refactor lattice code into its own file
Gus Smith
llvmlistbot at llvm.org
Thu Jun 24 16:03:52 PDT 2021
Author: Gus Smith
Date: 2021-06-24T23:03:44Z
New Revision: 744146f60bbf74872039871dee771d18f69bff89
URL: https://github.com/llvm/llvm-project/commit/744146f60bbf74872039871dee771d18f69bff89
DIFF: https://github.com/llvm/llvm-project/commit/744146f60bbf74872039871dee771d18f69bff89.diff
LOG: [MLIR][Sparse] Refactor lattice code into its own file
Moves iteration lattice/merger code into new SparseTensor/Utils directory. A follow-up CL will add lattice/merger unit tests.
Reviewed By: aartbik
Differential Revision: https://reviews.llvm.org/D104757
Added:
mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h
mlir/lib/Dialect/SparseTensor/Utils/CMakeLists.txt
mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp
Modified:
mlir/lib/Dialect/SparseTensor/CMakeLists.txt
mlir/lib/Dialect/SparseTensor/Transforms/CMakeLists.txt
mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h b/mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h
new file mode 100644
index 0000000000000..0ffd00131fd44
--- /dev/null
+++ b/mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h
@@ -0,0 +1,163 @@
+//===- Merger.h - Utilities for defining lattices ---------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This header file defines utilities for dealing with iteration lattices.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_SPARSETENSOR_UTILS_MERGER_H_
+#define MLIR_DIALECT_SPARSETENSOR_UTILS_MERGER_H_
+
+#include "mlir/IR/Value.h"
+#include "llvm/ADT/BitVector.h"
+
+namespace mlir {
+namespace sparse_tensor {
+
+/// Kind of a tensor expression node. The enumerator order matters: every
+/// kind from kMulF onward is a binary operation, which the TensorExp
+/// constructor checks with `kind >= Kind::kMulF`. Keep binary kinds last.
+enum class Kind { kTensor, kInvariant, kMulF, kMulI, kAddF, kAddI };
+
+/// Dimension annotation recorded per (tensor, loop index) pair by the Merger.
+/// Every pair starts out as kUndef until setDim() is called.
+enum class Dim { kSparse, kDense, kSingle, kUndef };
+
+/// Tensor expression. Represents an MLIR expression in tensor index notation.
+/// For tensors, e0 denotes the tensor index. For invariants, the IR value is
+/// stored directly. For binary operations, e0 and e1 denote the index of the
+/// children tensor expressions. Operand slots that are unused hold -1u.
+struct TensorExp {
+  TensorExp(Kind k, unsigned x, unsigned y, Value v)
+      : kind(k), e0(x), e1(y), val(v) {
+    assert((kind == Kind::kTensor && e0 != -1u && e1 == -1u && !val) ||
+           (kind == Kind::kInvariant && e0 == -1u && e1 == -1u && val) ||
+           (kind >= Kind::kMulF && e0 != -1u && e1 != -1u && !val));
+  }
+  /// Kind of this expression node.
+  Kind kind;
+  /// Indices of children expression(s).
+  unsigned e0;
+  unsigned e1;
+  /// Direct link to IR for an invariant. During code generation,
+  /// field is used to cache "hoisted" loop invariant tensor loads.
+  Value val;
+};
+
+/// Lattice point. Each lattice point consists of a conjunction of tensor
+/// loop indices (encoded in a bitvector) and the index of the corresponding
+/// tensor expression.
+struct LatPoint {
+  /// Constructs a point over `n` bits in which only bit `b` is set, for
+  /// tensor expression `e`.
+  LatPoint(unsigned n, unsigned e, unsigned b) : bits(n, false), exp(e) {
+    bits.set(b);
+  }
+  /// Constructs a point with the given conjunction bits for expression `e`.
+  LatPoint(const llvm::BitVector &b, unsigned e) : bits(b), exp(e) {}
+  /// Conjunction of tensor loop indices as bitvector. This represents
+  /// all indices involved in the tensor expression.
+  llvm::BitVector bits;
+  /// Simplified conjunction of tensor loop indices as bitvector. This
+  /// represents a simplified condition under which this tensor expression
+  /// must execute. Pre-computed during codegen to avoid repeated eval.
+  llvm::BitVector simple;
+  /// Index of the tensor expression.
+  unsigned exp;
+};
+
+/// A class to handle all iteration lattice operations. This class abstracts
+/// away from some implementation details of storing iteration lattices and
+/// tensor expressions. This allows for fine-tuning performance characteristics
+/// independently from the basic algorithm if bottlenecks are identified.
+class Merger {
+public:
+  /// Constructs a merger for the given number of tensors and loops. The
+  /// user supplies the number of tensors involved in the kernel, with the
+  /// last tensor in this set denoting the output tensor. The merger adds an
+  /// additional synthetic tensor at the end of this set to represent all
+  /// invariant expressions in the kernel.
+  Merger(unsigned t, unsigned l)
+      : outTensor(t - 1), numTensors(t + 1), numLoops(l),
+        dims(t + 1, std::vector<Dim>(l, Dim::kUndef)) {}
+
+  /// Adds a tensor expression. Returns its index.
+  unsigned addExp(Kind k, unsigned e0, unsigned e1 = -1u, Value v = Value());
+  /// Adds an expression that carries only an IR value (both child indices
+  /// are -1u); TensorExp only accepts this shape for Kind::kInvariant.
+  unsigned addExp(Kind k, Value v) { return addExp(k, -1u, -1u, v); }
+
+  /// Adds an iteration lattice point for tensor `t` at loop index `i` with
+  /// tensor expression `e`. Returns its index.
+  unsigned addLat(unsigned t, unsigned i, unsigned e);
+
+  /// Adds a new, initially empty, set. Returns its index.
+  unsigned addSet();
+
+  /// Computes a single conjunction of two lattice points by taking the "union"
+  /// of loop indices (effectively constructing a larger "intersection" of those
+  /// indices) with a newly constructed tensor (sub)expression of given kind.
+  /// Returns the index of the new lattice point.
+  unsigned conjLatPoint(Kind kind, unsigned p0, unsigned p1);
+
+  /// Conjunctive merge of two lattice sets L0 and L1 is conjunction of
+  /// cartesian product. Returns the index of the new set.
+  unsigned takeConj(Kind kind, unsigned s0, unsigned s1);
+
+  /// Disjunctive merge of two lattice sets L0 and L1 is (L0 /\_op L1, L0, L1).
+  /// Returns the index of the new set.
+  unsigned takeDisj(Kind kind, unsigned s0, unsigned s1);
+
+  /// Optimizes the iteration lattice points in the given set. This
+  /// method should be called right before code generation to avoid
+  /// generating redundant loops and conditions. Returns the index of the
+  /// newly created, optimized set.
+  unsigned optimizeSet(unsigned s0);
+
+  /// Simplifies the conditions in a conjunction of a given lattice point
+  /// within the given set using just two basic rules:
+  /// (1) multiple dense conditions are reduced to single dense, and
+  /// (2) a *singleton* sparse/dense is reduced to sparse/random access.
+  llvm::BitVector simplifyCond(unsigned s, unsigned p0);
+
+  /// Returns true if Li > Lj, i.e. the bits of Li are a strict superset of
+  /// the bits of Lj.
+  bool latGT(unsigned i, unsigned j) const;
+
+  /// Returns true if Li and Lj only differ in dense, i.e. none of the bits
+  /// in which they differ is annotated as Dim::kSparse.
+  bool onlyDenseDiff(unsigned i, unsigned j);
+
+  /// Bit translation. A lattice bit encodes a (tensor, loop index) pair as
+  /// `numTensors * index + tensor`.
+  unsigned tensor(unsigned b) const { return b % numTensors; }
+  unsigned index(unsigned b) const { return b / numTensors; }
+
+  /// Returns true if bit corresponds to queried dim.
+  bool isDim(unsigned b, Dim d) const { return isDim(tensor(b), index(b), d); }
+
+  /// Returns true if bit `b` belongs to the output tensor at loop index `i`.
+  bool isOutTensor(unsigned b, unsigned i) const {
+    return tensor(b) == outTensor && index(b) == i;
+  }
+
+  /// Returns true if tensor access at given index has queried dim.
+  bool isDim(unsigned t, unsigned i, Dim d) const {
+    assert(t < numTensors && i < numLoops);
+    return dims[t][i] == d;
+  }
+
+  /// Returns true if any set bit corresponds to queried dim.
+  bool hasAnyDimOf(const llvm::BitVector &bits, Dim d) const;
+
+  /// Setter. Records dim `d` for tensor `t` at loop index `i`; uses the
+  /// same bounds check as isDim() since `dims` is a plain std::vector.
+  void setDim(unsigned t, unsigned i, Dim d) {
+    assert(t < numTensors && i < numLoops);
+    dims[t][i] = d;
+  }
+
+  /// Getters.
+  TensorExp &exp(unsigned e) { return tensorExps[e]; }
+  LatPoint &lat(unsigned l) { return latPoints[l]; }
+  llvm::SmallVector<unsigned, 16> &set(unsigned s) { return latSets[s]; }
+
+private:
+  /// Index of the output tensor (the last user-supplied tensor).
+  const unsigned outTensor;
+  /// Number of tensors, including the synthetic invariant tensor.
+  const unsigned numTensors;
+  /// Number of loops.
+  const unsigned numLoops;
+
+  /// Dim annotations, indexed as dims[tensor][loop index].
+  std::vector<std::vector<Dim>> dims;
+  llvm::SmallVector<TensorExp, 32> tensorExps;
+  llvm::SmallVector<LatPoint, 16> latPoints;
+  llvm::SmallVector<llvm::SmallVector<unsigned, 16>, 8> latSets;
+};
+
+} // namespace sparse_tensor
+} // namespace mlir
+
+#endif // MLIR_DIALECT_SPARSETENSOR_UTILS_MERGER_H_
diff --git a/mlir/lib/Dialect/SparseTensor/CMakeLists.txt b/mlir/lib/Dialect/SparseTensor/CMakeLists.txt
index 9f57627c321fb..31167e6af908b 100644
--- a/mlir/lib/Dialect/SparseTensor/CMakeLists.txt
+++ b/mlir/lib/Dialect/SparseTensor/CMakeLists.txt
@@ -1,2 +1,3 @@
add_subdirectory(IR)
add_subdirectory(Transforms)
+add_subdirectory(Utils)
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/CMakeLists.txt b/mlir/lib/Dialect/SparseTensor/Transforms/CMakeLists.txt
index 68adb6fe1db18..24600aace642c 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/CMakeLists.txt
@@ -19,6 +19,7 @@ add_mlir_dialect_library(MLIRSparseTensorTransforms
MLIRSCF
MLIRStandard
MLIRSparseTensor
+ MLIRSparseTensorUtils
MLIRTensor
MLIRTransforms
MLIRVector
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
index 9c406a36f0728..f12aaccb31692 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
@@ -47,6 +47,7 @@
#include "mlir/Dialect/SCF/SCF.h"
#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
#include "mlir/Dialect/SparseTensor/Transforms/Passes.h"
+#include "mlir/Dialect/SparseTensor/Utils/Merger.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/Dialect/Vector/VectorOps.h"
#include "mlir/IR/Matchers.h"
@@ -58,245 +59,6 @@ using namespace mlir::sparse_tensor;
namespace {
-enum class Kind { kTensor, kInvariant, kMulF, kMulI, kAddF, kAddI };
-enum class Dim { kSparse, kDense, kSingle, kUndef };
-
-/// Tensor expression. Represents a MLIR expression in tensor index notation.
-/// For tensors, e0 denotes the tensor index. For invariants, the IR value is
-/// stored directly. For binary operations, e0 and e1 denote the index of the
-/// children tensor expressions.
-struct TensorExp {
- TensorExp(Kind k, unsigned x, unsigned y, Value v)
- : kind(k), e0(x), e1(y), val(v) {
- assert((kind == Kind::kTensor && e0 != -1u && e1 == -1u && !val) ||
- (kind == Kind::kInvariant && e0 == -1u && e1 == -1u && val) ||
- (kind >= Kind::kMulF && e0 != -1u && e1 != -1u && !val));
- }
- Kind kind;
- /// Indices of children expression(s).
- unsigned e0;
- unsigned e1;
- /// Direct link to IR for an invariant. During code generation,
- /// field is used to cache "hoisted" loop invariant tensor loads.
- Value val;
-};
-
-/// Lattice point. Each lattice point consists of a conjunction of tensor
-/// loop indices (encoded in a bitvector) and the index of the corresponding
-/// tensor expression.
-struct LatPoint {
- LatPoint(unsigned n, unsigned e, unsigned b) : bits(n, false), exp(e) {
- bits.set(b);
- }
- LatPoint(const llvm::BitVector &b, unsigned e) : bits(b), exp(e) {}
- /// Conjunction of tensor loop indices as bitvector. This represents
- /// all indices involved in the tensor expression
- llvm::BitVector bits;
- /// Simplified conjunction of tensor loop indices as bitvector. This
- /// represents a simplified condition under which this tensor expression
- /// must execute. Pre-computed during codegen to avoid repeated eval.
- llvm::BitVector simple;
- /// Index of the tensor expresssion.
- unsigned exp;
-};
-
-/// A class to handle all iteration lattice operations. This class abstracts
-/// away from some implementation details of storing iteration lattices and
-/// tensor expressions. This allows for fine-tuning performance characteristics
-/// independently from the basic algorithm if bottlenecks are identified.
-class Merger {
-public:
- /// Constructs a merger for the given number of tensors and loops. The
- /// user supplies the number of tensors involved in the kernel, with the
- /// last tensor in this set denoting the output tensor. The merger adds an
- /// additional synthetic tensor at the end of this set to represent all
- /// invariant expressions in the kernel.
- Merger(unsigned t, unsigned l)
- : outTensor(t - 1), numTensors(t + 1), numLoops(l),
- dims(t + 1, std::vector<Dim>(l, Dim::kUndef)) {}
-
- /// Adds a tensor expression. Returns its index.
- unsigned addExp(Kind k, unsigned e0, unsigned e1 = -1u, Value v = Value()) {
- unsigned e = tensorExps.size();
- tensorExps.push_back(TensorExp(k, e0, e1, v));
- return e;
- }
- unsigned addExp(Kind k, Value v) { return addExp(k, -1u, -1u, v); }
-
- /// Adds an iteration lattice point. Returns its index.
- unsigned addLat(unsigned t, unsigned i, unsigned e) {
- assert(t < numTensors && i < numLoops);
- unsigned p = latPoints.size();
- latPoints.push_back(LatPoint(numLoops * numTensors, e, numTensors * i + t));
- return p;
- }
-
- /// Adds a new, initially empty, set. Returns its index.
- unsigned addSet() {
- unsigned s = latSets.size();
- latSets.emplace_back(SmallVector<unsigned, 16>());
- return s;
- }
-
- /// Computes a single conjunction of two lattice points by taking the "union"
- /// of loop indices (effectively constructing a larger "intersection" of those
- /// indices) with a newly constructed tensor (sub)expression of given kind.
- /// Returns the index of the new lattice point.
- unsigned conjLatPoint(Kind kind, unsigned p0, unsigned p1) {
- unsigned p = latPoints.size();
- llvm::BitVector nb = llvm::BitVector(latPoints[p0].bits);
- nb |= latPoints[p1].bits;
- unsigned e = addExp(kind, latPoints[p0].exp, latPoints[p1].exp);
- latPoints.push_back(LatPoint(nb, e));
- return p;
- }
-
- /// Conjunctive merge of two lattice sets L0 and L1 is conjunction of
- /// cartesian product. Returns the index of the new set.
- unsigned takeConj(Kind kind, unsigned s0, unsigned s1) {
- unsigned s = addSet();
- for (unsigned p0 : latSets[s0])
- for (unsigned p1 : latSets[s1])
- latSets[s].push_back(conjLatPoint(kind, p0, p1));
- return s;
- }
-
- /// Disjunctive merge of two lattice sets L0 and L1 is (L0 /\_op L1, L0, L1).
- /// Returns the index of the new set.
- unsigned takeDisj(Kind kind, unsigned s0, unsigned s1) {
- unsigned s = takeConj(kind, s0, s1);
- for (unsigned p : latSets[s0])
- latSets[s].push_back(p);
- for (unsigned p : latSets[s1])
- latSets[s].push_back(p);
- return s;
- }
-
- /// Optimizes the iteration lattice points in the given set. This
- /// method should be called right before code generation to avoid
- /// generating redundant loops and conditions.
- unsigned optimizeSet(unsigned s0) {
- unsigned s = addSet();
- assert(latSets[s0].size() != 0);
- unsigned p0 = latSets[s0][0];
- for (unsigned p1 : latSets[s0]) {
- bool add = true;
- if (p0 != p1) {
- // Is this a straightforward copy?
- unsigned e = latPoints[p1].exp;
- if (exp(e).kind == Kind::kTensor && exp(e).e0 == outTensor)
- continue;
- // Conjunction already covered?
- for (unsigned p2 : latSets[s]) {
- assert(!latGT(p1, p2)); // Lj => Li would be bad
- if (onlyDenseDiff(p2, p1)) {
- add = false;
- break;
- }
- }
- assert(!add || latGT(p0, p1));
- }
- if (add)
- latSets[s].push_back(p1);
- }
- for (unsigned p : latSets[s])
- latPoints[p].simple = simplifyCond(s, p);
- return s;
- }
-
- /// Simplifies the conditions in a conjunction of a given lattice point
- /// within the given set using just two basic rules:
- /// (1) multiple dense conditions are reduced to single dense, and
- /// (2) a *singleton* sparse/dense is reduced to sparse/random access.
- llvm::BitVector simplifyCond(unsigned s, unsigned p0) {
- // First determine if this lattice point is a *singleton*, i.e.,
- // the last point in a lattice, no other is less than this one.
- bool isSingleton = true;
- for (unsigned p1 : latSets[s]) {
- if (p0 != p1 && latGT(p0, p1)) {
- isSingleton = false;
- break;
- }
- }
- // Now apply the two basic rules.
- llvm::BitVector simple = latPoints[p0].bits;
- bool reset = isSingleton && hasAnyDimOf(simple, Dim::kSparse);
- for (unsigned b = 0, be = simple.size(); b < be; b++) {
- if (simple[b] && !isDim(b, Dim::kSparse)) {
- if (reset)
- simple.reset(b);
- reset = true;
- }
- }
- return simple;
- }
-
- /// Returns true if Li > Lj.
- bool latGT(unsigned i, unsigned j) const {
- const llvm::BitVector &bitsi = latPoints[i].bits;
- const llvm::BitVector &bitsj = latPoints[j].bits;
- assert(bitsi.size() == bitsj.size());
- if (bitsi.count() > bitsj.count()) {
- for (unsigned b = 0, be = bitsj.size(); b < be; b++)
- if (bitsj[b] && !bitsi[b])
- return false;
- return true;
- }
- return false;
- }
-
- /// Returns true if Li and Lj only differ in dense.
- bool onlyDenseDiff(unsigned i, unsigned j) {
- llvm::BitVector tmp = latPoints[j].bits;
- tmp ^= latPoints[i].bits;
- return !hasAnyDimOf(tmp, Dim::kSparse);
- }
-
- /// Bit translation.
- unsigned tensor(unsigned b) const { return b % numTensors; }
- unsigned index(unsigned b) const { return b / numTensors; }
-
- /// Returns true if bit corresponds to queried dim.
- bool isDim(unsigned b, Dim d) const { return isDim(tensor(b), index(b), d); }
-
- /// Returns true if bit corresponds to index of output tensor.
- bool isOutTensor(unsigned b, unsigned i) const {
- return tensor(b) == outTensor && index(b) == i;
- }
-
- /// Returns true if tensor access at given index has queried dim.
- bool isDim(unsigned t, unsigned i, Dim d) const {
- assert(t < numTensors && i < numLoops);
- return dims[t][i] == d;
- }
-
- /// Returns true if any set bit corresponds to queried dim.
- bool hasAnyDimOf(const llvm::BitVector &bits, Dim d) const {
- for (unsigned b = 0, be = bits.size(); b < be; b++)
- if (bits[b] && isDim(b, d))
- return true;
- return false;
- }
-
- /// Setter
- void setDim(unsigned t, unsigned i, Dim d) { dims[t][i] = d; }
-
- /// Getters.
- TensorExp &exp(unsigned e) { return tensorExps[e]; }
- LatPoint &lat(unsigned l) { return latPoints[l]; }
- SmallVector<unsigned, 16> &set(unsigned s) { return latSets[s]; }
-
-private:
- const unsigned outTensor;
- const unsigned numTensors;
- const unsigned numLoops;
-
- std::vector<std::vector<Dim>> dims;
- llvm::SmallVector<TensorExp, 32> tensorExps;
- llvm::SmallVector<LatPoint, 16> latPoints;
- llvm::SmallVector<SmallVector<unsigned, 16>, 8> latSets;
-};
-
// Code generation.
struct CodeGen {
CodeGen(SparsificationOptions o, unsigned numTensors, unsigned numLoops)
diff --git a/mlir/lib/Dialect/SparseTensor/Utils/CMakeLists.txt b/mlir/lib/Dialect/SparseTensor/Utils/CMakeLists.txt
new file mode 100644
index 0000000000000..bfd614cb8df4f
--- /dev/null
+++ b/mlir/lib/Dialect/SparseTensor/Utils/CMakeLists.txt
@@ -0,0 +1,9 @@
+# Utility library holding the sparse tensor iteration lattice code
+# (Merger.cpp); MLIRSparseTensorTransforms links against it.
+add_mlir_dialect_library(MLIRSparseTensorUtils
+ Merger.cpp
+
+ ADDITIONAL_HEADER_DIRS
+ ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/SparseTensor
+
+ LINK_LIBS PUBLIC
+ MLIRIR
+)
diff --git a/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp b/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp
new file mode 100644
index 0000000000000..0d1d34597afcc
--- /dev/null
+++ b/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp
@@ -0,0 +1,138 @@
+//===- Merger.cpp - Implementation of iteration lattices ------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/SparseTensor/Utils/Merger.h"
+
+namespace mlir {
+namespace sparse_tensor {
+
+// Appends a new tensor expression node; its index is the old vector size.
+unsigned Merger::addExp(Kind k, unsigned e0, unsigned e1, Value v) {
+  unsigned e = tensorExps.size();
+  tensorExps.emplace_back(k, e0, e1, v);
+  return e;
+}
+
+// Appends a lattice point whose only set bit encodes (tensor t, loop i).
+unsigned Merger::addLat(unsigned t, unsigned i, unsigned e) {
+  assert(t < numTensors && i < numLoops);
+  unsigned p = latPoints.size();
+  latPoints.emplace_back(numLoops * numTensors, e, numTensors * i + t);
+  return p;
+}
+
+// Appends a new, empty lattice set.
+unsigned Merger::addSet() {
+  unsigned s = latSets.size();
+  latSets.emplace_back();
+  return s;
+}
+
+// Builds the conjunction of points p0 and p1: the union of their bits,
+// paired with a new expression of the given kind over both expressions.
+unsigned Merger::conjLatPoint(Kind kind, unsigned p0, unsigned p1) {
+  unsigned p = latPoints.size();
+  // Copy the bits first: the emplace_back below may reallocate latPoints.
+  llvm::BitVector nb = latPoints[p0].bits;
+  nb |= latPoints[p1].bits;
+  unsigned e = addExp(kind, latPoints[p0].exp, latPoints[p1].exp);
+  latPoints.emplace_back(nb, e);
+  return p;
+}
+
+// New set holding the conjunction of every pair in L0 x L1.
+unsigned Merger::takeConj(Kind kind, unsigned s0, unsigned s1) {
+  unsigned s = addSet();
+  for (unsigned p0 : latSets[s0])
+    for (unsigned p1 : latSets[s1])
+      latSets[s].push_back(conjLatPoint(kind, p0, p1));
+  return s;
+}
+
+// New set holding the conjunctions, followed by all points of L0 and then
+// all points of L1.
+unsigned Merger::takeDisj(Kind kind, unsigned s0, unsigned s1) {
+  unsigned s = takeConj(kind, s0, s1);
+  latSets[s].append(latSets[s0].begin(), latSets[s0].end());
+  latSets[s].append(latSets[s1].begin(), latSets[s1].end());
+  return s;
+}
+
+// Copies the useful points of set s0 into a new set and pre-computes each
+// kept point's simplified condition.
+unsigned Merger::optimizeSet(unsigned s0) {
+  unsigned s = addSet();
+  assert(!latSets[s0].empty());
+  // The first point of the input set is always kept.
+  unsigned p0 = latSets[s0][0];
+  for (unsigned p1 : latSets[s0]) {
+    bool add = true;
+    if (p0 != p1) {
+      // Is this a straightforward copy?
+      unsigned e = latPoints[p1].exp;
+      if (exp(e).kind == Kind::kTensor && exp(e).e0 == outTensor)
+        continue;
+      // Conjunction already covered?
+      for (unsigned p2 : latSets[s]) {
+        assert(!latGT(p1, p2)); // Lj => Li would be bad
+        if (onlyDenseDiff(p2, p1)) {
+          add = false;
+          break;
+        }
+      }
+      assert(!add || latGT(p0, p1));
+    }
+    if (add)
+      latSets[s].push_back(p1);
+  }
+  for (unsigned p : latSets[s])
+    latPoints[p].simple = simplifyCond(s, p);
+  return s;
+}
+
+// Returns a copy of p0's bits with redundant non-sparse conditions removed.
+llvm::BitVector Merger::simplifyCond(unsigned s, unsigned p0) {
+  // First determine if this lattice point is a *singleton*, i.e.,
+  // the last point in a lattice, no other is less than this one.
+  bool isSingleton = true;
+  for (unsigned p1 : latSets[s]) {
+    if (p0 != p1 && latGT(p0, p1)) {
+      isSingleton = false;
+      break;
+    }
+  }
+  // Now apply the two basic rules. When `reset` starts out true (a singleton
+  // with some sparse bit), every non-sparse bit is dropped (rule 2).
+  // Otherwise the first non-sparse bit is kept and all later ones are
+  // dropped (rule 1).
+  llvm::BitVector simple = latPoints[p0].bits;
+  bool reset = isSingleton && hasAnyDimOf(simple, Dim::kSparse);
+  for (unsigned b = 0, be = simple.size(); b < be; b++) {
+    if (simple[b] && !isDim(b, Dim::kSparse)) {
+      if (reset)
+        simple.reset(b);
+      reset = true;
+    }
+  }
+  return simple;
+}
+
+// Li > Lj iff Li has strictly more bits and contains every bit of Lj.
+bool Merger::latGT(unsigned i, unsigned j) const {
+  const llvm::BitVector &bitsi = latPoints[i].bits;
+  const llvm::BitVector &bitsj = latPoints[j].bits;
+  assert(bitsi.size() == bitsj.size());
+  if (bitsi.count() <= bitsj.count())
+    return false;
+  // Only the set bits of Lj need checking.
+  for (unsigned b : bitsj.set_bits())
+    if (!bitsi[b])
+      return false;
+  return true;
+}
+
+// True iff no bit in which Li and Lj differ is annotated Dim::kSparse.
+bool Merger::onlyDenseDiff(unsigned i, unsigned j) {
+  llvm::BitVector tmp = latPoints[j].bits;
+  tmp ^= latPoints[i].bits;
+  return !hasAnyDimOf(tmp, Dim::kSparse);
+}
+
+// True iff some set bit of `bits` has dim `d`. Visits set bits only, since
+// this runs inside optimizeSet's nested loops.
+bool Merger::hasAnyDimOf(const llvm::BitVector &bits, Dim d) const {
+  for (unsigned b : bits.set_bits())
+    if (isDim(b, d))
+      return true;
+  return false;
+}
+
+} // namespace sparse_tensor
+} // namespace mlir
More information about the Mlir-commits mailing list