[Mlir-commits] [mlir] b8cf7af - [mlir][sparse] Cleaning up names in {Merger, LoopEmitter, CodegenEnv}.{h, cpp}
wren romano
llvmlistbot at llvm.org
Tue Mar 14 11:51:06 PDT 2023
Author: wren romano
Date: 2023-03-14T11:50:56-07:00
New Revision: b8cf7af9090b0a4d98f47b0b1bf3e6e9041ab59c
URL: https://github.com/llvm/llvm-project/commit/b8cf7af9090b0a4d98f47b0b1bf3e6e9041ab59c
DIFF: https://github.com/llvm/llvm-project/commit/b8cf7af9090b0a4d98f47b0b1bf3e6e9041ab59c.diff
LOG: [mlir][sparse] Cleaning up names in {Merger,LoopEmitter,CodegenEnv}.{h,cpp}
This change does a bunch of renaming to clear up confusions in these files. In particular, this change:
* Renames variables and methods to clarify the "dim"/"lvl" distinction, and changes them to use the `Dimension`/`Level` types as appropriate.
* Introduces new typedefs:
  * `ExprId`, `LatPointId`, `LatSetId`: to clarify the interning design of the Merger.
  * `LoopId`, `LoopOrd`: to clarify the distinction between arbitrary names for loop-variables vs numeric identifiers based on the actual order of loop generation.
  * `TensorId`
  * (Future CLs will change these from typedefs to structs/classes, so that the typechecker can help avoid mixups; a sketch of that idea appears after this commit message.)
* Updates documentation to match the new terminology
* Adds additional assertions
* Adds `const` to local variables along the way
Reviewed By: aartbik
Differential Revision: https://reviews.llvm.org/D145756
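As the parenthetical above anticipates, these typedefs are plain `unsigned` aliases for now. Below is a minimal sketch of the kind of tagged wrapper a future CL could introduce; the `StrongId` name and the tag types are hypothetical and not part of this commit.

// Hypothetical sketch only; the real follow-up CLs may look different.
template <typename Tag>
class StrongId {
public:
  constexpr explicit StrongId(unsigned value) : value(value) {}
  constexpr unsigned get() const { return value; }
  friend constexpr bool operator==(StrongId a, StrongId b) {
    return a.value == b.value;
  }

private:
  unsigned value;
};

// One tag type per identifier kind gives each alias a distinct C++ type.
using TensorId = StrongId<struct TensorIdTag>;
using LoopId = StrongId<struct LoopIdTag>;

int main() {
  const TensorId t(3);
  const LoopId i(3);
  (void)t.get();
  (void)i.get();
  // t == i;  // would no longer compile: TensorId and LoopId are distinct types
  return 0;
}

With distinct tag types, passing a `LoopId` where a `TensorId` is expected becomes a compile error instead of a silent mixup.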
Added:
Modified:
mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorType.h
mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h
mlir/lib/Dialect/SparseTensor/Transforms/CodegenEnv.cpp
mlir/lib/Dialect/SparseTensor/Transforms/CodegenEnv.h
mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.cpp
mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.h
mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp
mlir/unittests/Dialect/SparseTensor/MergerTest.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorType.h b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorType.h
index 888eb7f325ac..bdcd3632b5d3 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorType.h
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorType.h
@@ -61,6 +61,7 @@ class SparseTensorType {
// Copy-assignment would be implicitly deleted (because our fields
// are const), so we explicitly delete it for clarity.
SparseTensorType &operator=(const SparseTensorType &) = delete;
+ // So we must explicitly define the copy-ctor to silence -Wdeprecated-copy.
SparseTensorType(const SparseTensorType &) = default;
/// Constructs a new `SparseTensorType` with the same dimension-shape
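For context on the comment added in this hunk: a user-declared copy-assignment operator (even a deleted one) makes implicit generation of the copy constructor deprecated, which is what `-Wdeprecated-copy` reports; explicitly defaulting the copy constructor silences it. A minimal reproduction with a made-up class name:

// Compile with e.g.: clang++ -std=c++17 -Wdeprecated-copy -c deprecated_copy.cpp
struct Widget {                                 // made-up example class
  Widget() = default;
  Widget &operator=(const Widget &) = delete;   // user-declared copy-assignment
  // Widget(const Widget &) = default;          // uncommenting this, as the commit
                                                // does for SparseTensorType,
                                                // silences the warning
};

void copyIt(const Widget &w) {
  Widget other(w);  // relies on the deprecated implicitly-generated copy ctor
  (void)other;
}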
diff --git a/mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h b/mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h
index 4ad069d18fd8..3c5d2d37e3e0 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h
+++ b/mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h
@@ -15,6 +15,7 @@
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/SparseTensor/IR/Enums.h"
+#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
#include "mlir/IR/Value.h"
#include "llvm/ADT/BitVector.h"
#include <optional>
@@ -23,11 +24,27 @@ namespace mlir {
namespace sparse_tensor {
/// Tensor expression kind.
+///
+/// The `kLoopVar` leaf kind is for representing `linalg::IndexOp`.
+/// That is, its argument is a `LoopId` identifying the loop-variable
+/// in question, and its value will be the current iteration's value
+/// of that loop-variable. See the `LoopId` documentation for more details.
+//
+// TODO: make this an `enum class` nested in the `TensorExp` class;
+// to improve namespacing, and match the pattern used by other "Kind"
+// enums in MLIR.
+//
+// TODO: Modify this definition so that the numeric values already encode
+// the `ExpArity` (while extending the notion of "arity" to include not
+// just the number of `ExprId` children the node has, but also whether the
+// node has a `Value` and/or `Operation*`). Doing this will avoid needing
+// to enumerate all the kinds in `getExpArity` and in the `TensorExp` ctor,
+// and should help clean up a few other places as well.
enum Kind {
// Leaf.
kTensor = 0,
kInvariant,
- kIndex,
+ kLoopVar,
// Unary operations.
kAbsF,
kAbsC,
@@ -87,27 +104,94 @@ enum Kind {
kReduce, // semiring reduction op
};
+// TODO: These type aliases currently only serve to make the code more
+// self-documenting, however because they are not type-checked they can
+// do nothing to prevent mixups. We should really change them from mere
+// aliases to actual struct definitions, so that we can type-check them.
+
+/// Tensor identifiers. The valid set of identifiers is defined by the
+/// first argument passed to the `Merger` ctor.
+using TensorId = unsigned;
+
+/// Loop identifiers. The valid set of identifiers is defined by the
+/// second two arguments to the `Merger` ctor.
+///
+/// These identifiers serve as proxies for the `$dim` argument to
+/// `linalg::IndexOp`, however the numerical value of a `LoopId` should
+/// not necessarily be equated with the numerical value of the corresponding
+/// `$dim` argument. The `$dim` arguments are De Bruijn indices: that
+/// is, they identify the loop which binds the loop-variable by counting
+/// the enclosing loops from innermost to outermost, starting from zero.
+/// Whereas `LoopId` are considered to be arbitrary names for identifying
+/// loops; since the `Merger` does not care about the actual ordering of
+/// loops, and leaves it up to the `LoopEmitter` to specify the actual
+/// loop ordering (`LoopOrd`).
+///
+/// TODO: Despite the above claim that `$dim` and `LoopId` need not be
+/// numerically equal, some code in the `Merger` class does equate them
+/// (e.g., `buildTensorExp`). So we need to explicate the exact relationship
+/// between `$dim`, `LoopId`, and `LoopOrd`; especially with regards to their
+/// provenance. If `LoopId` really is supposed to be equated with `$dim`,
+/// then we should change the name to `LoopIdx` or similar, to capture the
+/// fact that its numerical value is not invariant when entering/exiting
+/// loops (unlike `TensorId`, `ExprId`, `LatPointId`, and `LatSetId` which
+/// are invariant identifiers).
+using LoopId = unsigned;
+
+/// A compressed representation of `std::pair<TensorId, LoopId>`.
+/// The compression scheme is such that this also serves as an index
+/// into the bitvector stored in `LatPoint` (since that bitvector is
+/// just the implementation for a set of `TensorLoopId` values).
+using TensorLoopId = unsigned;
+
+/// `TensorExp` identifiers. These are allocated by `Merger::addExp`,
+/// and serve as unique identifiers for the corresponding `TensorExp` object.
+using ExprId = unsigned;
+
+/// `LatPoint` identifiers. These are allocated by `Merger::addLat`,
+/// and serve as unique identifiers for the corresponding `LatPoint` object.
+using LatPointId = unsigned;
+
+/// `LatSet` identifiers. These are allocated by `Merger::addSet` (and
+/// by other methods calling that one), and serve as unique identifiers
+/// for the corresponding `SmallVector<LatPointId>` object.
+using LatSetId = unsigned;
+
+/// A constant serving as the canonically invalid identifier, regardless
+/// of the identifier type.
+static constexpr unsigned kInvalidId = -1u;
+
/// Children subexpressions of tensor operations.
struct Children {
- unsigned e0;
- unsigned e1;
+ ExprId e0;
+ ExprId e1;
};
/// Tensor expression. Represents a MLIR expression in tensor index notation.
struct TensorExp {
- TensorExp(Kind k, unsigned x, unsigned y, Value v, Operation *operation);
+ // The `x` parameter has different types depending on the value of the
+ // `k` parameter. The correspondences are:
+ // * `kTensor` -> `TensorId`
+ // * `kInvariant` -> `kInvalidId`
+ // * `kLoopVar` -> `LoopId`
+ // * else -> `ExprId`
+ //
+ // The `y`, `v`, and `op` parameters either must or must not be
+ // `kInvalidId`/`nullptr`, depending on the value of the `k` parameter;
+ // however, they have uniform C++ types regardless of the value of `k`.
+ TensorExp(Kind k, unsigned x, ExprId y, Value v, Operation *op);
/// Tensor expression kind.
Kind kind;
union {
- /// Expressions representing tensors simply have a tensor number.
- unsigned tensor;
+ /// `kTensor` expressions simply have a tensor identifier.
+ TensorId tensor;
- /// Indices hold the index number.
- unsigned index;
+ /// `kLoopVar` expressions simply have a loop identifier.
+ LoopId loop;
- /// Tensor operations hold the indices of their children.
+ /// All other expressions hold the `ExprId`s of their children.
Children children;
};
@@ -123,24 +207,29 @@ struct TensorExp {
Operation *op;
};
-/// Lattice point. Each lattice point consists of a conjunction of tensor
-/// loop indices (encoded in a bitvector) and the index of the corresponding
-/// tensor expression.
+/// Lattice point. Each lattice point consists of a formal conjunction
+/// of `TensorLoopId`s, together with the identifier of the corresponding
+/// tensor expression. The formal conjunction is represented as a set of
+/// `TensorLoopId`, where that set is implemented as a `BitVector`.
struct LatPoint {
- LatPoint(unsigned n, unsigned e, unsigned b);
- LatPoint(const BitVector &b, unsigned e);
+ /// Construct the lattice point from a given set of `TensorLoopId`s.
+ LatPoint(const BitVector &bits, ExprId e);
- /// Conjunction of tensor loop indices as bitvector. This represents
- /// all indices involved in the tensor expression
+ /// Construct a lattice point with `(t,i)` as the only `TensorLoopId`,
+ /// where `(t,i) < (numTensors,numLoops)`.
+ LatPoint(unsigned numTensors, unsigned numLoops, TensorId t, LoopId i,
+ ExprId e);
+
+ /// Conjunction of all `TensorLoopId`s involved in the tensor expression.
BitVector bits;
- /// Simplified conjunction of tensor loop indices as bitvector. This
+ /// Simplified conjunction of `TensorLoopId` as bitvector. This
/// represents a simplified condition under which this tensor expression
/// must execute. Pre-computed during codegen to avoid repeated eval.
BitVector simple;
- /// Index of the tensor expression.
- unsigned exp;
+ /// Identifier of the tensor expression.
+ ExprId exp;
};
/// A class to handle all iteration lattice operations. This class abstracts
@@ -157,251 +246,274 @@ class Merger {
///
/// In addition to native loops (which are specified by the GenericOp),
/// extra filter loops are needed in order to handle affine expressions on
- /// sparse dimensions. E.g., (d0, d1, d2) => (d0 + d1, d2), a naive
+ /// sparse levels. E.g., (d0, d1, d2) => (d0 + d1, d2), a naive
/// implementation of the filter loop could be generated as
///
- /// for (coord : sparse_dim[0])
- /// if (coord == d0 + d1) {
+ /// for (const auto c0 : coordinates[0]) {
+ /// if (c0 == d0 + d1) {
/// generated_code;
/// }
/// }
///
/// to filter out coordinates that are not equal to the affine expression.
- ///
- /// TODO: we want to make the filter loop more efficient in the future, e.g.,
- /// by avoiding scanning the full stored index sparse (keeping the last
- /// position in ordered list) or even apply binary search to find the index.
- ///
- Merger(unsigned t, unsigned l, unsigned fl);
-
- /// Adds a tensor expression. Returns its index.
- unsigned addExp(Kind k, unsigned e0, unsigned e1 = -1u, Value v = Value(),
- Operation *op = nullptr);
- unsigned addExp(Kind k, unsigned e, Value v, Operation *op = nullptr) {
- return addExp(k, e, -1u, v, op);
+ //
+ // TODO: we want to make the filter loop more efficient in the future,
+ // e.g., by avoiding scanning the full list of stored coordinates (keeping
+ // the last position in ordered list) or even apply binary search to find
+ // the coordinate.
+ //
+ // TODO: would be cleaner to understand/document if the first argument
+ // gave the number of input tensors, instead of the current number of
+ // input+output tensors.
+ Merger(unsigned numInputOutputTensors, unsigned numNativeLoops,
+ unsigned numFilterLoops);
+
+ /// Constructs a new tensor expression, and returns its identifier.
+ /// The type of the `e0` argument varies according to the value of the
+ /// `k` argument, as described by the `TensorExp` ctor.
+ ExprId addExp(Kind k, unsigned e0, ExprId e1 = kInvalidId, Value v = Value(),
+ Operation *op = nullptr);
+ ExprId addExp(Kind k, ExprId e, Value v, Operation *op = nullptr) {
+ return addExp(k, e, kInvalidId, v, op);
}
- unsigned addExp(Kind k, Value v, Operation *op = nullptr) {
- return addExp(k, -1u, -1u, v, op);
+ ExprId addExp(Kind k, Value v, Operation *op = nullptr) {
+ return addExp(k, kInvalidId, kInvalidId, v, op);
}
- /// Adds an iteration lattice point. Returns its index.
- unsigned addLat(unsigned t, unsigned i, unsigned e);
+ /// Constructs a new iteration lattice point, and returns its identifier.
+ LatPointId addLat(TensorId t, LoopId i, ExprId e);
- /// Adds a new, initially empty, set. Returns its index.
- unsigned addSet();
+ /// Constructs a new (initially empty) set, and returns its identifier.
+ LatSetId addSet();
/// Computes a single conjunction of two lattice points by taking the "union"
- /// of loop indices (effectively constructing a larger "intersection" of those
- /// indices) with a newly constructed tensor (sub)expression of given kind.
- /// Returns the index of the new lattice point.
- unsigned conjLatPoint(Kind kind, unsigned p0, unsigned p1,
- Operation *op = nullptr);
-
- /// Conjunctive merge of two lattice sets L0 and L1 is conjunction of
- /// cartesian product. Returns the index of the new set.
- unsigned takeConj(Kind kind, unsigned s0, unsigned s1,
- Operation *op = nullptr);
-
- /// Disjunctive merge of two lattice sets L0 and L1 is (L0 /\_op L1, L0, L1).
- /// Returns the index of the new set.
- unsigned takeDisj(Kind kind, unsigned s0, unsigned s1,
- Operation *op = nullptr);
-
- /// Disjunctive merge of two lattice sets L0 and L1 with custom handling of
- /// the overlap, left, and right regions. Any region may be left missing in
- /// the output. Returns the index of the new set.
- unsigned takeCombi(Kind kind, unsigned s0, unsigned s1, Operation *orig,
- bool includeLeft, Kind ltrans, Operation *opleft,
- bool includeRight, Kind rtrans, Operation *opright);
+ /// of `LoopId` (effectively constructing a larger "intersection" of those
+ /// loops) with a newly constructed tensor (sub)expression of given kind.
+ /// Returns the identifier of the new lattice point.
+ LatPointId conjLat(Kind kind, LatPointId p0, LatPointId p1,
+ Operation *op = nullptr);
+
+ /// Conjunctive merge of two lattice sets: `(s0 /\_op s1)`.
+ /// Returns the identifier of the new set.
+ LatSetId conjSet(Kind kind, LatSetId s0, LatSetId s1,
+ Operation *op = nullptr);
+
+ /// Disjunctive merge of two lattice sets: `(s0 /\_op s1, s0, s1)`.
+ /// Returns the identifier of the new set.
+ LatSetId disjSet(Kind kind, LatSetId s0, LatSetId s1,
+ Operation *op = nullptr);
+
+ /// Disjunctive merge of two lattice sets with custom handling of the
+ /// overlap, left, and right regions. Any region may be left missing
+ /// in the output. Returns the identifier of the new set.
+ LatSetId combiSet(Kind kind, LatSetId s0, LatSetId s1, Operation *orig,
+ bool includeLeft, Kind ltrans, Operation *opleft,
+ bool includeRight, Kind rtrans, Operation *opright);
/// Maps the unary operator over the lattice set of the operand, i.e. each
/// lattice point on an expression E is simply copied over, but with OP E
- /// as new expression. Returns the index of the new set.
- unsigned mapSet(Kind kind, unsigned s0, Value v = Value(),
+ /// as new expression. Returns the identifier of the new set.
+ LatSetId mapSet(Kind kind, LatSetId s, Value v = Value(),
Operation *op = nullptr);
/// Optimizes the iteration lattice points in the given set. This
/// method should be called right before code generation to avoid
/// generating redundant loops and conditions.
- unsigned optimizeSet(unsigned s0);
+ LatSetId optimizeSet(LatSetId s);
/// Simplifies the conditions in a conjunction of a given lattice point
/// within the given set using just two basic rules:
/// (1) multiple dense conditions are reduced to single dense, and
/// (2) a *singleton* sparse/dense is reduced to sparse/random access.
- BitVector simplifyCond(unsigned s0, unsigned p0);
-
- /// Returns true if Li > Lj.
- bool latGT(unsigned i, unsigned j) const;
-
- /// Returns true if Li and Lj only
diff er in dense.
- bool onlyDenseDiff(unsigned i, unsigned j);
-
- /// Bit translation (get tensor ID).
- unsigned tensor(unsigned b) const { return b % numTensors; }
- /// Bit translation (get loop index).
- unsigned index(unsigned b) const { return b / numTensors; }
-
- /// Get the number of total loops (native loops + filter loops).
- unsigned getNumLoops() const { return numLoops; }
- /// Get the number of native loops.
- unsigned getNumNativeLoops() const { return numNativeLoops; }
- /// Get the number of filter loops.
- unsigned getNumFilterLoops() const { return numLoops - numNativeLoops; }
- /// Get the starting filter loop index.
- unsigned getFilterLoopStartingIdx() const { return getNumNativeLoops(); }
-
- /// Returns true if bit corresponds to index of output tensor.
- bool isOutTensor(unsigned b, unsigned i) const {
- return tensor(b) == outTensor && index(b) == i;
+ BitVector simplifyCond(LatSetId s, LatPointId p);
+
+ /// Returns true if p0 > p1.
+ bool latGT(LatPointId p0, LatPointId p1) const;
+
+ /// Returns true if p0 and p1 only differ in dense.
+ bool onlyDenseDiff(LatPointId p0, LatPointId p1) const;
+
+ /// Gets the tensor-identifier of the `TensorLoopId`.
+ TensorId tensor(TensorLoopId b) const { return b % numTensors; }
+ /// Gets the loop-identifier of the `TensorLoopId`.
+ LoopId loop(TensorLoopId b) const { return b / numTensors; }
+
+ /// Get the total number of tensors (including the output-tensor and
+ /// synthetic-tensor). The result is given the type `TensorId` since
+ /// the result is primarily used as an upper bound for `TensorId`s.
+ TensorId getNumTensors() const { return numTensors; }
+
+ /// Get the total number of loops (native loops + filter loops).
+ /// The result is given the type `LoopId` since the result will
+ /// generally be used as a for-loop upper bound.
+ LoopId getNumLoops() const { return numLoops; }
+ /// Get the number of native loops. The result is given the type
+ /// `LoopId` since the result will generally be used as a for-loop
+ /// upper bound.
+ LoopId getNumNativeLoops() const { return numNativeLoops; }
+ /// Get the number of filter loops. The result is given the type
+ /// `LoopId` since the result will generally be used as a for-loop
+ /// upper bound.
+ LoopId getNumFilterLoops() const { return numLoops - numNativeLoops; }
+ /// Get the identifier of the first filter-loop.
+ LoopId getStartingFilterLoopId() const { return getNumNativeLoops(); }
+
+ /// Returns true if `b` is the `i`th loop of the output tensor.
+ bool isOutTensor(TensorLoopId b, LoopId i) const {
+ assert(i < numLoops);
+ return b == numTensors * i + outTensor;
}
- /// Gets tensor ID for the output tensor.
- unsigned getOutTensorID() const { return outTensor; }
- /// Gets tensor ID for the synthetic tensor (used for all invariant tensor
- /// expressions).
- unsigned getSynTensorID() const { return syntheticTensor; }
+ /// Get the output tensor's identifier.
+ TensorId getOutTensorID() const { return outTensor; }
+ /// Get the synthetic tensor's identifier (used for all invariant
+ /// tensor expressions).
+ TensorId getSynTensorID() const { return syntheticTensor; }
- bool isFilterLoop(unsigned ldx) const {
- assert(ldx < numLoops);
- return ldx >= numNativeLoops;
+ bool isFilterLoop(LoopId i) const {
+ assert(i < numLoops);
+ return i >= numNativeLoops;
}
/// Returns true if the expression is `(kTensor t)`.
- bool expIsTensor(unsigned e, unsigned t) const {
+ bool expIsTensor(ExprId e, TensorId t) const {
return tensorExps[e].kind == kTensor && tensorExps[e].tensor == t;
}
- /// Returns true if the expression contains the `t` as an operand.
- bool expContainsTensor(unsigned e, unsigned t) const;
+ /// Returns true if the expression contains the tensor as an operand.
+ bool expContainsTensor(ExprId e, TensorId t) const;
/// Returns true if the expression contains a negation on output tensor.
/// I.e., `- outTensor` or `exp - outputTensor`
/// NOTE: this is a trivial test in that it does not handle recursive
/// negation, i.e., it returns true when the expression is `-(-tensor)`.
- bool hasNegateOnOut(unsigned e) const;
+ bool hasNegateOnOut(ExprId e) const;
/// Returns true if given tensor iterates *only* in the given tensor
/// expression. For the output tensor, this defines a "simply dynamic"
/// operation [Bik96]. For instance: a(i) *= 2.0 or a(i) += a(i) for
/// sparse vector a.
- bool isSingleCondition(unsigned t, unsigned e) const;
+ bool isSingleCondition(TensorId t, ExprId e) const;
- /// Returns true if any set bit corresponds to sparse dimension level type.
+ /// Returns true if any `TensorLoopId` in the bitvector corresponds
+ /// to sparse level-type.
bool hasAnySparse(const BitVector &bits) const;
- /// Gets the dimension level type of the `t`th tensor on `i`th loop.
- DimLevelType getDimLevelType(unsigned t, unsigned i) const {
+ /// Gets the level-type of the `t`th tensor on `i`th loop.
+ DimLevelType getDimLevelType(TensorId t, LoopId i) const {
assert(t < numTensors && i < numLoops);
- return dimTypes[t][i];
+ return lvlTypes[t][i];
}
-
- /// Gets the dimension level type of `b`.
- DimLevelType getDimLevelType(unsigned b) const {
- return getDimLevelType(tensor(b), index(b));
+ DimLevelType getDimLevelType(TensorLoopId b) const {
+ return getDimLevelType(tensor(b), loop(b));
}
- std::optional<unsigned> getLoopIdx(unsigned t, unsigned dim) const {
- assert(t < numTensors && dim < numLoops);
- return dimToLoopIdx[t][dim];
+ /// Gets the loop identifier for the `lvl`th level of the `t`th tensor.
+ std::optional<LoopId> getLoopId(TensorId t, Level lvl) const {
+ assert(t < numTensors && lvl < lvlToLoop[t].size());
+ return lvlToLoop[t][lvl];
}
- /// Gets the dimension number of the the `t`th tensor on `i`th loop.
- std::optional<unsigned> getDimNum(unsigned t, unsigned i) const {
+ /// Gets the level number of the `t`th tensor on `i`th loop.
+ std::optional<Level> getLvl(TensorId t, LoopId i) const {
assert(t < numTensors && i < numLoops);
- return loopIdxToDim[t][i];
+ return loopToLvl[t][i];
}
-
- /// Gets the dimension number of `b`.
- std::optional<unsigned> getDimNum(unsigned b) const {
- return getDimNum(tensor(b), index(b));
+ std::optional<Level> getLvl(TensorLoopId b) const {
+ return getLvl(tensor(b), loop(b));
}
- /// Sets the dimension and dimension level type of the `t`th tensor on `i`th
- /// loop.
- void setDimAndDimLevelType(unsigned t, unsigned i, unsigned dim,
- DimLevelType dlt) {
- assert(isValidDLT(dlt));
- dimTypes[t][i] = dlt;
- loopIdxToDim[t][i] = dim;
- assert(dim < numLoops);
- dimToLoopIdx[t][dim] = i;
+ /// Sets the level number and level-type of the `t`th tensor on
+ /// `i`th loop.
+ void setLevelAndType(TensorId t, LoopId i, Level lvl, DimLevelType dlt) {
+ assert(t < numTensors && i < numLoops && lvl < lvlToLoop[t].size() &&
+ isValidDLT(dlt));
+ lvlTypes[t][i] = dlt;
+ loopToLvl[t][i] = lvl;
+ lvlToLoop[t][lvl] = i;
}
- // Iterates the bits of a lattice, for each set bit, converts it into the
- // corresponding tensor dimension and invokes the callback.
- void foreachTidDimPairInBits(
- const BitVector &bits,
- function_ref<void(unsigned b, unsigned tid, std::optional<unsigned> dim,
- DimLevelType dlt)>
- cb) {
- for (unsigned b : bits.set_bits())
- cb(b, tensor(b), getDimNum(b), getDimLevelType(b));
+ /// Iterates over a set of `TensorLoopId`s, invoking the callback
+ /// for each `TensorLoopId` and passing it the corresponding tensor
+ /// identifier, level, and level-type.
+ void
+ foreachTensorLoopId(const BitVector &bits,
+ function_ref<void(TensorLoopId, TensorId,
+ std::optional<Level>, DimLevelType)>
+ callback) const {
+ for (const TensorLoopId b : bits.set_bits())
+ callback(b, tensor(b), getLvl(b), getDimLevelType(b));
}
- // Has sparse output tensor setter.
+ /// Sets whether the output tensor is sparse or not.
void setHasSparseOut(bool s) { hasSparseOut = s; }
/// Convenience getters to immediately access the stored nodes.
/// Typically it is inadvisable to keep the reference around, as in
- /// "TensorExpr &te = merger.exp(e))", since insertions into the merger
+ /// `TensorExpr &te = merger.exp(e)`, since insertions into the merger
/// may cause data movement and invalidate the underlying memory address.
- TensorExp &exp(unsigned e) { return tensorExps[e]; }
- LatPoint &lat(unsigned l) { return latPoints[l]; }
- SmallVector<unsigned> &set(unsigned s) { return latSets[s]; }
+ TensorExp &exp(ExprId e) { return tensorExps[e]; }
+ LatPoint &lat(LatPointId p) { return latPoints[p]; }
+ SmallVector<LatPointId> &set(LatSetId s) { return latSets[s]; }
#ifndef NDEBUG
/// Print methods (for debugging).
- void dumpExp(unsigned e) const;
- void dumpLat(unsigned p) const;
- void dumpSet(unsigned s) const;
+ void dumpExp(ExprId e) const;
+ void dumpLat(LatPointId p) const;
+ void dumpSet(LatSetId s) const;
void dumpBits(const BitVector &bits) const;
#endif
/// Builds the iteration lattices in a bottom-up traversal given the
- /// remaining tensor (sub)expression and the next loop index in the
- /// iteration graph. Returns index of the root expression.
- unsigned buildLattices(unsigned e, unsigned i);
+ /// remaining tensor (sub)expression and the next loop in the iteration
+ /// graph. Returns the identifier of the root set.
+ LatSetId buildLattices(ExprId e, LoopId i);
/// Builds a tensor expression from the given Linalg operation.
- /// Returns index of the root expression on success.
- std::optional<unsigned> buildTensorExpFromLinalg(linalg::GenericOp op);
+ /// On success, returns the identifier of the root expression.
+ std::optional<ExprId> buildTensorExpFromLinalg(linalg::GenericOp op);
/// Rebuilds SSA format from a tensor expression.
- Value buildExp(RewriterBase &rewriter, Location loc, unsigned e, Value v0,
- Value v1);
+ Value buildExp(RewriterBase &rewriter, Location loc, ExprId e, Value v0,
+ Value v1) const;
private:
/// Private helpers.
- bool maybeZero(unsigned e) const;
- bool isInvariant(unsigned e) const;
- Type inferType(unsigned e, Value src);
+ bool maybeZero(ExprId e) const;
+ bool isInvariant(ExprId e) const;
+ Type inferType(ExprId e, Value src) const;
/// Traverses the SSA tree (possibly a DAG) to build a tensor expression.
- std::optional<unsigned> buildTensorExp(linalg::GenericOp op, Value v);
+ std::optional<ExprId> buildTensorExp(linalg::GenericOp op, Value v);
/// Merger data structures.
- const unsigned outTensor;
- const unsigned syntheticTensor;
+ const TensorId outTensor;
+ const TensorId syntheticTensor;
const unsigned numTensors;
const unsigned numNativeLoops;
const unsigned numLoops;
bool hasSparseOut;
- // Map that converts pair<tensor id, loop id> to the corresponding dimension
- // level type.
- std::vector<std::vector<DimLevelType>> dimTypes;
+ // Below we use `std::vector` for things which have a priori fixed
+ // sizes, whereas we use `llvm::SmallVector` for things with variable
+ // size. Do beware that these two classes differ in the semantics of
+ // `operator[]`: `SmallVector` performs OOB checks, whereas `std::vector`
+ // does not.
+
+ // Map that converts pair<TensorId, LoopId> to the corresponding
+ // level-type.
+ std::vector<std::vector<DimLevelType>> lvlTypes;
- // Map that converts pair<tensor id, loop id> to the corresponding
- // dimension.
- std::vector<std::vector<std::optional<unsigned>>> loopIdxToDim;
+ // Map that converts pair<TensorId, LoopId> to the corresponding
+ // level.
+ std::vector<std::vector<std::optional<Level>>> loopToLvl;
- // Map that converts pair<tensor id, dim> to the corresponding loop id.
- std::vector<std::vector<std::optional<unsigned>>> dimToLoopIdx;
+ // Map that converts pair<TensorId, Level> to the corresponding LoopId.
+ std::vector<std::vector<std::optional<LoopId>>> lvlToLoop;
llvm::SmallVector<TensorExp> tensorExps;
llvm::SmallVector<LatPoint> latPoints;
- llvm::SmallVector<SmallVector<unsigned>> latSets;
+ llvm::SmallVector<SmallVector<LatPointId>> latSets;
};
} // namespace sparse_tensor
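To make the `TensorLoopId` compression above concrete: given the accessors `tensor(b) = b % numTensors` and `loop(b) = b / numTensors`, together with the `b == numTensors * i + outTensor` check in `isOutTensor`, the packed value is `b = i * numTensors + t`. A small standalone round-trip follows; the helper names are illustrative only, not Merger methods.

#include <cassert>

// Illustrative helpers mirroring the compression scheme implied by
// `tensor`, `loop`, and `isOutTensor` above:  b = i * numTensors + t.
using TensorId = unsigned;
using LoopId = unsigned;
using TensorLoopId = unsigned;

static TensorLoopId makeTensorLoopId(unsigned numTensors, TensorId t, LoopId i) {
  return i * numTensors + t;
}
static TensorId tensorOf(unsigned numTensors, TensorLoopId b) {
  return b % numTensors;
}
static LoopId loopOf(unsigned numTensors, TensorLoopId b) {
  return b / numTensors;
}

int main() {
  const unsigned numTensors = 3;  // e.g., two inputs plus the output tensor
  const TensorLoopId b = makeTensorLoopId(numTensors, /*t=*/2, /*i=*/4);
  assert(tensorOf(numTensors, b) == 2);
  assert(loopOf(numTensors, b) == 4);
  // The same packed value also serves as the bit index into LatPoint::bits.
  return 0;
}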
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenEnv.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenEnv.cpp
index 62acac20cd96..8e4904ad3a59 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenEnv.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenEnv.cpp
@@ -38,12 +38,12 @@ CodegenEnv::CodegenEnv(linalg::GenericOp linop, SparsificationOptions opts,
: linalgOp(linop), sparseOptions(opts),
latticeMerger(numTensors, numLoops, numFilterLoops), loopEmitter(),
topSort(), sparseOut(nullptr), outerParNest(-1u), insChain(), expValues(),
- expFilled(), expAdded(), expCount(), redVal(), redExp(-1u),
- redCustom(-1u), redValidLexInsert() {}
+ expFilled(), expAdded(), expCount(), redVal(), redExp(kInvalidId),
+ redCustom(kInvalidId), redValidLexInsert() {}
LogicalResult CodegenEnv::initTensorExp() {
// Builds the tensor expression for the Linalg operation in SSA form.
- std::optional<unsigned> optExp = latticeMerger.buildTensorExpFromLinalg(op());
+ std::optional<ExprId> optExp = latticeMerger.buildTensorExpFromLinalg(op());
if (!optExp || !isAdmissibleTensorExp(*optExp))
return failure();
@@ -101,7 +101,7 @@ std::optional<Operation *> CodegenEnv::genLoopBoundary(
// Code generation environment verify functions.
//===----------------------------------------------------------------------===//
-bool CodegenEnv::isAdmissibleTensorExp(unsigned exp) {
+bool CodegenEnv::isAdmissibleTensorExp(ExprId exp) {
// We reject any expression that makes a reduction from `-outTensor`, as those
// expressions create a dependency between the current iteration (i) and the
// previous iteration (i-1). It would require iterating over the whole
@@ -115,7 +115,10 @@ bool CodegenEnv::isAdmissibleTensorExp(unsigned exp) {
}
OpOperand *lhs = linalgOp.getDpsInitOperand(0);
- unsigned tensor = lhs->getOperandNumber();
+ // That the operand number is a valid `TensorId` will be verified
+ // by the call to `isSingleCondition` below; though we may want to add
+ // assertions to check it here, in order to give better error messages.
+ const TensorId tensor = lhs->getOperandNumber();
// A non-annotated output tensor is assumed dense, and becomes a random
// access n-dim memref. Admissible since insertions cannot occur.
if (getSparseTensorType(lhs->get()).isAllDense())
@@ -140,13 +143,14 @@ bool CodegenEnv::isAdmissibleTopoOrder() {
OpOperand *lhs = linalgOp.getDpsInitOperand(0);
// Accept "truly dynamic" if the output tensor materializes uninitialized
// into the computation and insertions occur in lexicographic index order.
- unsigned nest = 0;
- auto iteratorTypes = linalgOp.getIteratorTypesArray();
- for (unsigned i = 0, e = latticeMerger.getNumLoops(); i < e; i++) {
- if (!latticeMerger.isFilterLoop(topSortAt(i))) {
+ LoopOrd nest = 0;
+ const auto iteratorTypes = linalgOp.getIteratorTypesArray();
+ assert(topSortSize() == latticeMerger.getNumLoops());
+ for (const LoopId i : topSort) {
+ if (!latticeMerger.isFilterLoop(i)) {
// We only count non-filter loops as filter loops should be considered
- // as a special type of parallel loops.
- if (linalg::isReductionIterator(iteratorTypes[topSortAt(i)]))
+ // a special type of parallel loops.
+ if (linalg::isReductionIterator(iteratorTypes[i]))
break; // terminate at first reduction
nest++;
}
@@ -154,7 +158,7 @@ bool CodegenEnv::isAdmissibleTopoOrder() {
// Determine admissible dynamic insertion situations:
// (1) fully injective, since there are no reductions,
// (2) admissible 1-d expansion in innermost dimension.
- if (nest >= linalgOp.getRank(lhs) - 1) {
+ if (static_cast<int64_t>(nest) >= linalgOp.getRank(lhs) - 1) {
outerParNest = nest;
return true;
}
@@ -165,19 +169,26 @@ bool CodegenEnv::isAdmissibleTopoOrder() {
// Code generation environment topological sort methods
//===----------------------------------------------------------------------===//
-ArrayRef<unsigned> CodegenEnv::getTopSortSlice(size_t n, size_t m) const {
- return ArrayRef<unsigned>(topSort).slice(n, m);
+ArrayRef<LoopId> CodegenEnv::getTopSortSlice(LoopOrd n, LoopOrd m) const {
+ return ArrayRef<LoopId>(topSort).slice(n, m);
}
-ArrayRef<unsigned> CodegenEnv::getLoopCurStack() const {
- return getTopSortSlice(0, loopEmitter.getCurrentDepth());
+ArrayRef<LoopId> CodegenEnv::getLoopStackUpTo(LoopOrd n) const {
+ return ArrayRef<LoopId>(topSort).take_front(n);
}
-Value CodegenEnv::getLoopIdxValue(size_t loopIdx) const {
- for (unsigned lv = 0, lve = topSort.size(); lv < lve; lv++)
- if (topSort[lv] == loopIdx)
- return loopEmitter.getLoopIV(lv);
- llvm_unreachable("invalid loop index");
+ArrayRef<LoopId> CodegenEnv::getCurrentLoopStack() const {
+ return getLoopStackUpTo(loopEmitter.getCurrentDepth());
+}
+
+Value CodegenEnv::getLoopVar(LoopId i) const {
+ // TODO: this class should store the inverse of `topSort` so that
+ // it can do this conversion directly, instead of searching through
+ // `topSort` every time. (Or else, `LoopEmitter` should handle this.)
+ for (LoopOrd n = 0, numLoops = topSortSize(); n < numLoops; n++)
+ if (topSort[n] == i)
+ return loopEmitter.getLoopIV(n);
+ llvm_unreachable("invalid loop identifier");
}
//===----------------------------------------------------------------------===//
@@ -189,8 +200,10 @@ void CodegenEnv::updateInsertionChain(Value chain) {
insChain = chain;
}
-bool CodegenEnv::atExpandLevel(OpOperand *o, unsigned rank, unsigned lv) const {
- return sparseOut == o && outerParNest == rank - 1 && outerParNest == lv;
+// FIXME: clarify what this "rank" is really supposed to mean/be.
+bool CodegenEnv::atExpandLevel(OpOperand *o, unsigned rank, LoopOrd n) const {
+ return sparseOut == o && outerParNest == static_cast<LoopOrd>(rank - 1) &&
+ outerParNest == n;
}
void CodegenEnv::startExpand(Value values, Value filled, Value added,
@@ -216,21 +229,21 @@ void CodegenEnv::endExpand() {
// Code generation environment reduction methods
//===----------------------------------------------------------------------===//
-void CodegenEnv::startReduc(unsigned exp, Value val) {
- assert(redExp == -1u && exp != -1u);
+void CodegenEnv::startReduc(ExprId exp, Value val) {
+ assert(!isReduc() && exp != kInvalidId);
redExp = exp;
updateReduc(val);
}
void CodegenEnv::updateReduc(Value val) {
- assert(redExp != -1u);
+ assert(isReduc());
redVal = exp(redExp).val = val;
}
Value CodegenEnv::endReduc() {
Value val = redVal;
updateReduc(Value());
- redExp = -1u;
+ redExp = kInvalidId;
return val;
}
@@ -244,17 +257,17 @@ void CodegenEnv::clearValidLexInsert() {
redValidLexInsert = Value();
}
-void CodegenEnv::startCustomReduc(unsigned exp) {
- assert(redCustom == -1u && exp != -1u);
+void CodegenEnv::startCustomReduc(ExprId exp) {
+ assert(!isCustomReduc() && exp != kInvalidId);
redCustom = exp;
}
Value CodegenEnv::getCustomRedId() {
- assert(redCustom != -1u);
+ assert(isCustomReduc());
return dyn_cast<sparse_tensor::ReduceOp>(exp(redCustom).op).getIdentity();
}
void CodegenEnv::endCustomReduc() {
- assert(redCustom != -1u);
- redCustom = -1u;
+ assert(isCustomReduc());
+ redCustom = kInvalidId;
}
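On the `LoopOrd`/`LoopId` distinction used above: `topSort` maps a `LoopOrd` (position in the generated loop nest) to a `LoopId` (the arbitrary loop name used by the Merger). `CodegenEnv::getLoopVar` inverts that mapping by linear search, while `LoopEmitter::initialize` further below precomputes the inverse as `loopIdToOrd`. A small sketch of both directions, assuming `topSort` is a permutation of the loop identifiers:

#include <cassert>
#include <vector>

using LoopId = unsigned;   // arbitrary loop name used by the Merger
using LoopOrd = unsigned;  // position in the actual loop order

int main() {
  // topSort[n] is the LoopId of the loop generated at nesting depth n.
  const std::vector<LoopId> topSort = {2, 0, 1};

  // Precomputed inverse, as LoopEmitter::initialize builds loopIdToOrd.
  std::vector<LoopOrd> loopIdToOrd(topSort.size(), 0);
  for (LoopOrd n = 0; n < topSort.size(); n++)
    loopIdToOrd[topSort[n]] = n;

  // Linear-search inverse, as CodegenEnv::getLoopVar currently does.
  const LoopId i = 1;
  LoopOrd ord = topSort.size();
  for (LoopOrd n = 0; n < topSort.size(); n++)
    if (topSort[n] == i)
      ord = n;

  assert(ord == 2 && ord == loopIdToOrd[i]);
  return 0;
}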
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenEnv.h b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenEnv.h
index b210ca873520..8c6a7bd6433d 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenEnv.h
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenEnv.h
@@ -45,7 +45,7 @@ class CodegenEnv {
//
LogicalResult initTensorExp();
- unsigned getTensorExp() const { return tensorExp; }
+ ExprId getExprId() const { return tensorExp; }
linalg::GenericOp op() const { return linalgOp; }
const SparsificationOptions &options() const { return sparseOptions; }
@@ -65,13 +65,13 @@ class CodegenEnv {
// Merger delegates.
//
- TensorExp &exp(unsigned e) { return latticeMerger.exp(e); }
- LatPoint &lat(unsigned l) { return latticeMerger.lat(l); }
- SmallVector<unsigned> &set(unsigned s) { return latticeMerger.set(s); }
- DimLevelType dlt(unsigned t, unsigned i) const {
+ TensorExp &exp(ExprId e) { return latticeMerger.exp(e); }
+ LatPoint &lat(LatPointId l) { return latticeMerger.lat(l); }
+ SmallVector<LatPointId> &set(LatSetId s) { return latticeMerger.set(s); }
+ DimLevelType dlt(TensorId t, LoopId i) const {
return latticeMerger.getDimLevelType(t, i);
}
- DimLevelType dlt(unsigned b) const {
+ DimLevelType dlt(TensorLoopId b) const {
return latticeMerger.getDimLevelType(b);
}
@@ -81,7 +81,7 @@ class CodegenEnv {
/// Whether the tensor expression is admissible for codegen.
/// It also sets the sparseOut if the output tensor is sparse.
- bool isAdmissibleTensorExp(unsigned exp);
+ bool isAdmissibleTensorExp(ExprId e);
/// Whether the iteration graph is sorted in admissible topoOrder.
/// Sets outerParNest on success with sparse output
@@ -91,17 +91,21 @@ class CodegenEnv {
// Topological delegate and sort methods.
//
- size_t topSortSize() const { return topSort.size(); }
- unsigned topSortAt(unsigned i) const { return topSort.at(i); }
- void topSortPushBack(unsigned i) { topSort.push_back(i); }
- void topSortClear(unsigned capacity = 0) {
+ LoopOrd topSortSize() const { return topSort.size(); }
+ LoopId topSortAt(LoopOrd n) const { return topSort.at(n); }
+ void topSortPushBack(LoopId i) { topSort.push_back(i); }
+ void topSortClear(size_t capacity = 0) {
topSort.clear();
topSort.reserve(capacity);
}
- ArrayRef<unsigned> getTopSortSlice(size_t n, size_t m) const;
- ArrayRef<unsigned> getLoopCurStack() const;
- Value getLoopIdxValue(size_t loopIdx) const;
+ ArrayRef<LoopId> getTopSortSlice(LoopOrd n, LoopOrd m) const;
+ ArrayRef<LoopId> getLoopStackUpTo(LoopOrd n) const;
+ ArrayRef<LoopId> getCurrentLoopStack() const;
+ /// Returns the induction-variable for the loop identified by the given
+ /// `LoopId`. This method handles application of the topological sort
+ /// in order to convert the `LoopId` into the corresponding `LoopOrd`.
+ Value getLoopVar(LoopId i) const;
//
// Sparse tensor output and expansion methods.
@@ -113,7 +117,8 @@ class CodegenEnv {
Value getInsertionChain() const { return insChain; }
void updateInsertionChain(Value chain);
- bool atExpandLevel(OpOperand *o, unsigned rank, unsigned lv) const;
+ // FIXME: clarify what this "rank" is really supposed to mean/be.
+ bool atExpandLevel(OpOperand *o, unsigned rank, LoopOrd n) const;
void startExpand(Value values, Value filled, Value added, Value count);
bool isExpand() const { return expValues != nullptr; }
void updateExpandCount(Value count);
@@ -127,8 +132,8 @@ class CodegenEnv {
// Reduction methods.
//
- void startReduc(unsigned exp, Value val);
- bool isReduc() const { return redExp != -1u; }
+ void startReduc(ExprId exp, Value val);
+ bool isReduc() const { return redExp != kInvalidId; }
void updateReduc(Value val);
Value getReduc() const { return redVal; }
Value endReduc();
@@ -136,8 +141,8 @@ class CodegenEnv {
void clearValidLexInsert();
Value getValidLexInsert() const { return redValidLexInsert; }
- void startCustomReduc(unsigned exp);
- bool isCustomReduc() const { return redCustom != -1u; }
+ void startCustomReduc(ExprId exp);
+ bool isCustomReduc() const { return redCustom != kInvalidId; }
Value getCustomRedId();
void endCustomReduc();
@@ -154,14 +159,16 @@ class CodegenEnv {
// Loop emitter helper class.
LoopEmitter loopEmitter;
- // Topological sort.
- std::vector<unsigned> topSort;
+ // Topological sort. This serves as a mapping from `LoopOrd` to `LoopId`
+ // (cf., `getLoopVar` and `topSortAt`).
+ std::vector<LoopId> topSort;
// Sparse tensor as output. Implemented either through direct injective
// insertion in lexicographic index order or through access pattern
// expansion in the innermost loop nest (`expValues` through `expCount`).
OpOperand *sparseOut;
- unsigned outerParNest;
+ // The count of outer non-filter loops, as defined by `isAdmissibleTopoOrder`.
+ LoopOrd outerParNest;
Value insChain;
Value expValues;
Value expFilled;
@@ -172,8 +179,8 @@ class CodegenEnv {
// into the merger's expression tree. When the indices of a tensor reduction
// expression are exhausted, all inner loops can use a scalarized reduction.
Value redVal;
- unsigned redExp;
- unsigned redCustom;
+ ExprId redExp;
+ ExprId redCustom;
// Bookkeeping for lex insertion during reductions. Holds the runtime boolean
// value of whether any reduction occurred. This is only set during a
@@ -181,7 +188,7 @@ class CodegenEnv {
Value redValidLexInsert;
// The root tensor expression of the kernel.
- unsigned tensorExp;
+ ExprId tensorExp;
};
} // namespace sparse_tensor
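The LoopEmitter.cpp changes that follow rename the slice-coordinate helpers; the arithmetic itself is unchanged: `tensorCrd = sliceCrd * stride + offset` in one direction, and a division with a remainder check in the other. A scalar sketch of that round-trip, with plain integers standing in for the generated `arith` ops and the names kept from the patch (confusing as the FIXMEs note they are):

#include <cassert>
#include <utility>

// Scalar stand-ins for the IR built by toSliceCrd / fromSliceCrd.
static unsigned toSliceCrd(unsigned sliceCrd, unsigned offset, unsigned stride) {
  // tensorCrd = sliceCrd * stride + offset
  return sliceCrd * stride + offset;
}

static std::pair<unsigned, unsigned> fromSliceCrd(unsigned tensorCrd,
                                                  unsigned offset,
                                                  unsigned stride) {
  // sliceCrd = (tensorCrd - offset) / stride; the remainder tells whether
  // tensorCrd actually lies on the slice (it must be zero).
  const unsigned shifted = tensorCrd - offset;
  return {shifted / stride, shifted % stride};
}

int main() {
  const unsigned offset = 3, stride = 2;
  const unsigned tensorCrd = toSliceCrd(/*sliceCrd=*/5, offset, stride);  // 13
  const auto [sliceCrd, rem] = fromSliceCrd(tensorCrd, offset, stride);
  assert(sliceCrd == 5 && rem == 0);  // rem != 0 would mean "not in the slice"
  return 0;
}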
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.cpp
index f48520b2286b..c3823c0f204d 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.cpp
@@ -15,6 +15,7 @@
#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
+#include "mlir/Dialect/SparseTensor/IR/SparseTensorType.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
using namespace mlir;
@@ -44,53 +45,56 @@ static Value genIndexLoad(OpBuilder &builder, Location loc, Value mem,
}
static Value genSliceOffset(OpBuilder &builder, Location loc, Value tensor,
- unsigned lvl) {
+ Level lvl) {
auto enc = getSparseTensorEncoding(tensor.getType());
// FIXME: `toOrigDim` is deprecated
return createOrFoldSliceOffsetOp(builder, loc, tensor, toOrigDim(enc, lvl));
}
static Value genSliceStride(OpBuilder &builder, Location loc, Value tensor,
- unsigned lvl) {
+ Level lvl) {
auto enc = getSparseTensorEncoding(tensor.getType());
// FIXME: `toOrigDim` is deprecated
return createOrFoldSliceStrideOp(builder, loc, tensor, toOrigDim(enc, lvl));
}
-// Converts a coordinate relative to the slice to the coordinate relative
-// to the underlying tensor.
-static Value toSliceCoord(OpBuilder &builder, Location loc, Value v,
- Value offset, Value stride, Value tensor,
- unsigned lvl) {
- // iv = iv * stride + offset
- v = builder.create<arith::MulIOp>(loc, v, stride);
- v = builder.create<arith::AddIOp>(loc, v, offset);
- return v;
+/// Converts a coordinate relative to the slice to the coordinate relative
+/// to the underlying tensor.
+// FIXME: that description says "sliceCrd -> tensorCrd"; but the function
+// name suggests it should be "tensorCrd -> sliceCrd".
+static Value toSliceCrd(OpBuilder &builder, Location loc, Value crd,
+ Value offset, Value stride, Value tensor, Level lvl) {
+ // tensorCrd = sliceCrd * stride + offset
+ crd = builder.create<arith::MulIOp>(loc, crd, stride);
+ crd = builder.create<arith::AddIOp>(loc, crd, offset);
+ return crd;
}
-// Converts a coordinate relative to the underlying tensor to the coordinate
-// relative to the slice, returns a extra reminder value
+/// Converts a coordinate relative to the underlying tensor to the coordinate
+/// relative to the slice, returns an extra remainder value
+// FIXME: that description says "tensorCrd -> sliceCrd"; but the function
+// name suggests it should be "sliceCrd -> tensorCrd".
static std::pair<Value, Value> fromSliceCrd(OpBuilder &builder, Location loc,
- Value iv, Value offset,
+ Value crd, Value offset,
Value stride, Value tensor,
- unsigned lvl) {
- // iv = (iv - offset) / stride
- iv = builder.create<arith::SubIOp>(loc, iv, offset);
- Value rem = builder.create<arith::RemUIOp>(loc, iv, stride);
- iv = builder.create<arith::DivUIOp>(loc, iv, stride);
- return std::make_pair(iv, rem);
+ Level lvl) {
+ // sliceCrd = (tensorCrd - offset) / stride
+ crd = builder.create<arith::SubIOp>(loc, crd, offset);
+ Value rem = builder.create<arith::RemUIOp>(loc, crd, stride);
+ crd = builder.create<arith::DivUIOp>(loc, crd, stride);
+ return std::make_pair(crd, rem);
}
std::pair<Value, Value>
LoopEmitter::genSliceLegitPredicate(OpBuilder &builder, Location loc, Value crd,
- unsigned tid, unsigned lvl) {
+ TensorId tid, Level lvl) {
assert(isSparseSlices[tid]);
Value slice = tensors[tid];
Value offset = sliceOffsets[tid][lvl];
Value stride = sliceStrides[tid][lvl];
auto enc = getSparseTensorEncoding(slice.getType());
- std::pair<Value, Value> transformedCrd =
+ const auto [newCrd, crdRem] =
fromSliceCrd(builder, loc, crd, offset, stride, slice, lvl);
SmallVector<Value, 3> conds; // at most 3 conditions
@@ -104,16 +108,15 @@ LoopEmitter::genSliceLegitPredicate(OpBuilder &builder, Location loc, Value crd,
}
// Second, coord_in_slice < length
- auto ltLength = builder.create<arith::CmpIOp>(
- loc, arith::CmpIPredicate::ult, transformedCrd.first, lvlSizes[tid][lvl]);
+ auto ltLength = builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ult,
+ newCrd, lvlSizes[tid][lvl]);
conds.push_back(ltLength);
// Third, rem == 0 (skip the check if stride is known to be 1).
if (auto staticStride = enc.getStaticLvlSliceStride(lvl);
!(staticStride.has_value() && *staticStride == 1)) {
auto fitStride = builder.create<arith::CmpIOp>(
- loc, arith::CmpIPredicate::eq, transformedCrd.second,
- constantIndex(builder, loc, 0));
+ loc, arith::CmpIPredicate::eq, crdRem, constantIndex(builder, loc, 0));
conds.push_back(fitStride);
}
@@ -122,81 +125,81 @@ LoopEmitter::genSliceLegitPredicate(OpBuilder &builder, Location loc, Value crd,
for (auto cond : ValueRange(conds).drop_front())
pred = builder.create<arith::AndIOp>(loc, pred, cond);
- return {transformedCrd.first, pred};
+ return {newCrd, pred};
}
//===----------------------------------------------------------------------===//
// Sparse tensor loop emitter class implementations
//===----------------------------------------------------------------------===//
-Value LoopEmitter::genAddress(OpBuilder &builder, Location loc, size_t tid,
- size_t dim, Value iv) {
- Value p = dim == 0 ? constantIndex(builder, loc, 0) : pidxs[tid][dim - 1];
- Value mul = builder.create<arith::MulIOp>(loc, highs[tid][dim], p);
+Value LoopEmitter::genAddress(OpBuilder &builder, Location loc, TensorId tid,
+ Level lvl, Value crd) {
+ Value pos = lvl == 0 ? constantIndex(builder, loc, 0) : posits[tid][lvl - 1];
+ Value mul = builder.create<arith::MulIOp>(loc, highs[tid][lvl], pos);
if (isSparseSlices[tid])
- iv = toSliceCoord(builder, loc, iv, sliceOffsets[tid][dim],
- sliceStrides[tid][dim], tensors[tid], dim);
- Value add = builder.create<arith::AddIOp>(loc, mul, iv);
+ crd = toSliceCrd(builder, loc, crd, sliceOffsets[tid][lvl],
+ sliceStrides[tid][lvl], tensors[tid], lvl);
+ Value add = builder.create<arith::AddIOp>(loc, mul, crd);
return add;
}
-Value LoopEmitter::genSegmentHigh(OpBuilder &builder, Location loc, size_t tid,
- size_t lvl, Value pos, Value pHi) {
- Value prevCrd = genIndexLoad(builder, loc, crdBuffer[tid][lvl], pos);
- // De-duplicates repeated elements.
- //
- // while (pos < pHi && coord[pos] == prev_coord)
- // pos++;
- // return pos;
+Value LoopEmitter::genSegmentHigh(OpBuilder &builder, Location loc,
+ TensorId tid, Level lvl, Value pLo,
+ Value pHi) {
+ const auto coordinates = coordinatesBuffers[tid][lvl];
+ const auto sameCrd = genIndexLoad(builder, loc, coordinates, pLo);
auto whileOp = builder.create<scf::WhileOp>(
- loc, builder.getIndexType(), pos,
+ loc, builder.getIndexType(), pLo,
/*beforeBuilder=*/
- [this, tid, lvl, pHi, prevCrd](OpBuilder &builder, Location loc,
- ValueRange ivs) {
+ [pHi, coordinates, sameCrd](OpBuilder &builder, Location loc,
+ ValueRange ivs) {
+ const auto pos = ivs[0];
Value inBound = builder.create<arith::CmpIOp>(
- loc, arith::CmpIPredicate::ult, ivs[0], pHi);
- auto ifOp =
+ loc, arith::CmpIPredicate::ult, pos, pHi);
+ auto ifInBound =
builder.create<scf::IfOp>(loc, builder.getI1Type(), inBound, true);
{
OpBuilder::InsertionGuard guard(builder);
// Load the next coordinates only when inbound (to avoid OOB
// accesses).
- builder.setInsertionPointToStart(ifOp.thenBlock());
- Value nxCrd = genIndexLoad(builder, loc, crdBuffer[tid][lvl], ivs[0]);
- Value cont = builder.create<arith::CmpIOp>(
- loc, arith::CmpIPredicate::eq, nxCrd, prevCrd);
- builder.create<scf::YieldOp>(loc, cont);
+ builder.setInsertionPointToStart(ifInBound.thenBlock());
+ Value crd = genIndexLoad(builder, loc, coordinates, pos);
+ Value isSameCrd = builder.create<arith::CmpIOp>(
+ loc, arith::CmpIPredicate::eq, crd, sameCrd);
+ builder.create<scf::YieldOp>(loc, isSameCrd);
// Else, the position is out of bound, yield false to terminate the
// loop.
- builder.setInsertionPointToStart(ifOp.elseBlock());
+ builder.setInsertionPointToStart(ifInBound.elseBlock());
builder.create<scf::YieldOp>(loc, constantI1(builder, loc, false));
}
- builder.create<scf::ConditionOp>(loc, ifOp.getResults()[0], ivs);
+ builder.create<scf::ConditionOp>(loc, ifInBound.getResults()[0], ivs);
},
/*afterBuilder=*/
[](OpBuilder &builder, Location loc, ValueRange ivs) {
// pos ++
- Value nxPos = builder.create<arith::AddIOp>(
+ Value nextPos = builder.create<arith::AddIOp>(
loc, ivs[0], constantIndex(builder, loc, 1));
- builder.create<scf::YieldOp>(loc, nxPos);
+ builder.create<scf::YieldOp>(loc, nextPos);
});
// Return the segment high.
return whileOp.getResult(0);
}
-Value LoopEmitter::genSparseCrd(OpBuilder &builder, Location loc, size_t tid,
- size_t dstLvl) {
+Value LoopEmitter::genSparseCrd(OpBuilder &builder, Location loc, TensorId tid,
+ Level dstLvl) {
Value crd = constantIndex(builder, loc, 0);
const auto reassoc = getCollapseReassociation(tid, dstLvl);
- for (unsigned i = 0; i < reassoc.size(); i++) {
- const auto srcLvl = reassoc[i];
+ const unsigned reassocSize = reassoc.size();
+ for (unsigned i = 0; i < reassocSize; i++) {
+ const Level srcLvl = reassoc[i];
// A load on the coordinates array yields the coordinate.
- const Value mem = crdBuffer[tid][srcLvl];
- const Value pos = pidxs[tid][dstLvl];
+ const Value mem = coordinatesBuffers[tid][srcLvl];
+ /// FIXME: See the [CLARIFY_POSITS_LVL] note in the header.
+ const Value pos = posits[tid][dstLvl];
const Value off = genIndexLoad(builder, loc, mem, pos);
// Linearized the coordinates within the same collapse reassociation.
crd = builder.create<arith::AddIOp>(loc, crd, off);
- if (i != reassoc.size() - 1) {
+ if (i != reassocSize - 1) {
crd = builder.create<arith::MulIOp>(loc, crd,
this->lvlSizes[tid][reassoc[i + 1]]);
}
@@ -205,35 +208,43 @@ Value LoopEmitter::genSparseCrd(OpBuilder &builder, Location loc, size_t tid,
}
LoopEmitter::LoopEmitter(ValueRange tensors, StringAttr loopTag, bool hasOutput,
- bool isSparseOut, ArrayRef<unsigned> topSort) {
+ bool isSparseOut, ArrayRef<LoopId> topSort) {
initialize(tensors, loopTag, hasOutput, isSparseOut, topSort);
}
void LoopEmitter::initialize(ValueRange ts, StringAttr loopTag, bool hasOutput,
- bool isSparseOut, ArrayRef<unsigned> topSort) {
- // First initializes fields.
+ bool isSparseOut, ArrayRef<LoopId> topSort) {
+ // First initialize the top-level type of the fields.
this->loopTag = loopTag;
this->hasOutput = hasOutput;
this->isSparseOut = isSparseOut;
+
+ const TensorId numTensors = ts.size();
this->tensors.assign(ts.begin(), ts.end());
- this->isSparseSlices.assign(tensors.size(), false);
- this->sliceOffsets.assign(tensors.size(), std::vector<Value>());
- this->sliceStrides.assign(tensors.size(), std::vector<Value>());
- this->dimTypes.assign(tensors.size(), std::vector<DimLevelType>());
- this->pidxs.assign(tensors.size(), std::vector<Value>());
- this->segHi.assign(tensors.size(), std::vector<Value>());
- this->coord.assign(tensors.size(), std::vector<Value>());
- this->highs.assign(tensors.size(), std::vector<Value>());
- this->lvlSizes.assign(tensors.size(), std::vector<Value>());
- this->posBuffer.assign(tensors.size(), std::vector<Value>());
- this->crdBuffer.assign(tensors.size(), std::vector<Value>());
- this->valBuffer.assign(tensors.size(), nullptr);
- this->loopStack.reserve(topSort.size());
- this->sparsiferLoopLvlMap.assign(topSort.size(), 0);
- this->collapseReassoc.assign(tensors.size(), nullptr);
-
- for (size_t tid = 0, e = tensors.size(); tid < e; tid++) {
- auto t = tensors[tid];
+ this->lvlTypes.assign(numTensors, std::vector<DimLevelType>());
+ this->lvlSizes.assign(numTensors, std::vector<Value>());
+ this->highs.assign(numTensors, std::vector<Value>());
+ this->segHi.assign(numTensors, std::vector<Value>());
+ this->posits.assign(numTensors, std::vector<Value>());
+ this->coords.assign(numTensors, std::vector<Value>());
+ this->positionsBuffers.assign(numTensors, std::vector<Value>());
+ this->coordinatesBuffers.assign(numTensors, std::vector<Value>());
+ this->valBuffer.assign(numTensors, nullptr);
+ this->collapseReassoc.assign(numTensors, nullptr);
+ this->isSparseSlices.assign(numTensors, false);
+ this->sliceOffsets.assign(numTensors, std::vector<Value>());
+ this->sliceStrides.assign(numTensors, std::vector<Value>());
+
+ const LoopOrd numLoops = topSort.size();
+ // These zeros will be overwritten below, but we need to initialize
+ // them to something since we'll need random-access assignment.
+ this->loopIdToOrd.assign(numLoops, 0);
+ this->loopStack.reserve(numLoops);
+ this->loopSeqStack.reserve(numLoops);
+
+ // Initialize nested types of `TensorId`-indexed fields.
+ for (TensorId tid = 0; tid < numTensors; tid++) {
+ const Value t = tensors[tid];
// a scalar or 0-dimension tensor
if (isZeroRankedTensorOrScalar(t.getType()))
continue;
@@ -247,46 +258,51 @@ void LoopEmitter::initialize(ValueRange ts, StringAttr loopTag, bool hasOutput,
collapseReassoc[tid] = reshape.getReassociation();
rtp = reshape.getSrcType();
// Overwrites the tensor to the source tensor of reshape operations.
- tensors[tid] = t = reshape.getSrc();
+ tensors[tid] = reshape.getSrc();
}
- auto rank = static_cast<size_t>(rtp.getRank());
- auto enc = getSparseTensorEncoding(rtp);
+ const SparseTensorType stt(rtp);
+ const Level lvlRank = stt.getLvlRank();
// We always treat sparse output tensor as dense so that we always iterate
- // it based on dim size.
- if (enc && !(isOutputTensor(tid) && isSparseOut)) {
+ // it based on lvl size.
+ if (stt.hasEncoding() && !(isOutputTensor(tid) && isSparseOut)) {
+ const auto enc = stt.getEncoding();
isSparseSlices[tid] = enc.isSlice();
- for (auto dimTp : enc.getDimLevelType())
- dimTypes[tid].push_back(dimTp);
- } else
- dimTypes[tid].assign(rank, DimLevelType::Dense);
+ for (auto lvlTp : enc.getDimLevelType())
+ lvlTypes[tid].push_back(lvlTp);
+ } else {
+ lvlTypes[tid].assign(lvlRank, DimLevelType::Dense);
+ }
// Initialize using empty value.
- sliceOffsets[tid].assign(rank, Value());
- sliceStrides[tid].assign(rank, Value());
- pidxs[tid].assign(rank, Value());
- segHi[tid].assign(rank, Value());
- coord[tid].assign(rank, Value());
- highs[tid].assign(rank, Value());
- lvlSizes[tid].assign(rank, Value());
- posBuffer[tid].assign(rank, Value());
- crdBuffer[tid].assign(rank, Value());
+ lvlSizes[tid].assign(lvlRank, Value());
+ highs[tid].assign(lvlRank, Value());
+ segHi[tid].assign(lvlRank, Value());
+ posits[tid].assign(lvlRank, Value());
+ coords[tid].assign(lvlRank, Value());
+ positionsBuffers[tid].assign(lvlRank, Value());
+ coordinatesBuffers[tid].assign(lvlRank, Value());
+ sliceOffsets[tid].assign(lvlRank, Value());
+ sliceStrides[tid].assign(lvlRank, Value());
}
+ // Construct the inverse of the `topSort` from the sparsifier.
+ // This is needed to map `AffineDimExpr`s back to the `LoopOrd`
+ // used in loop emitter.
// FIXME: This map should be maintained outside loop emitter.
- for (unsigned i = 0, e = topSort.size(); i < e; i++) {
- // This is an inverse map of the topologically sorted loop index from
- // sparsifier. This is needed to map the AffineDimExpr back to the loopStack
- // index used in loop emitter.
- sparsiferLoopLvlMap[topSort[i]] = i;
- }
+ for (LoopOrd n = 0; n < numLoops; n++)
+ loopIdToOrd[topSort[n]] = n;
}
void LoopEmitter::initializeLoopEmit(OpBuilder &builder, Location loc,
LoopEmitter::OutputUpdater updater) {
- // For every tensor, find lower and upper bound on dimensions, set the
- // same bounds on loop indices, and obtain dense or sparse buffer(s).
- for (size_t t = 0, e = tensors.size(); t < e; t++) {
- const auto tensor = tensors[t];
+ // For every tensor:
+ // * get the values buffer.
+ // * For every level:
+ // * get the positions and coordinates buffers
+ // * get/compute the level-size, which is also used as the upper-bound
+ // on positions.
+ for (TensorId t = 0, numTensors = getNumTensors(); t < numTensors; t++) {
+ const Value tensor = tensors[t];
const auto rtp = tensor.getType().dyn_cast<RankedTensorType>();
if (!rtp)
// Skips only scalar, zero ranked tensor still need to be bufferized and
@@ -302,24 +318,27 @@ void LoopEmitter::initializeLoopEmit(OpBuilder &builder, Location loc,
// Scan all levels of current tensor.
for (Level l = 0; l < lvlRank; l++) {
// This should be called only once at beginning.
- assert(!posBuffer[t][l] && !crdBuffer[t][l] && !highs[t][l]);
- const auto dlt = dimTypes[t][l];
+ assert(!positionsBuffers[t][l] && !coordinatesBuffers[t][l] &&
+ !highs[t][l]);
+ const auto lvlTp = lvlTypes[t][l];
// Handle sparse storage schemes.
- if (isCompressedDLT(dlt)) {
- // Generate sparse primitives to obtains positions and coordinates.
- posBuffer[t][l] = genToPositions(builder, loc, tensor, l);
- crdBuffer[t][l] = genToCoordinates(builder, loc, tensor, l, cooStart);
- } else if (isSingletonDLT(dlt)) {
+ if (isCompressedDLT(lvlTp)) {
+ // Generate sparse primitives to obtain positions and coordinates.
+ positionsBuffers[t][l] = genToPositions(builder, loc, tensor, l);
+ coordinatesBuffers[t][l] =
+ genToCoordinates(builder, loc, tensor, l, cooStart);
+ } else if (isSingletonDLT(lvlTp)) {
// Singleton level, fetch coordinates.
- crdBuffer[t][l] = genToCoordinates(builder, loc, tensor, l, cooStart);
+ coordinatesBuffers[t][l] =
+ genToCoordinates(builder, loc, tensor, l, cooStart);
} else {
// Dense level, nothing to fetch.
- assert(isDenseDLT(dlt));
+ assert(isDenseDLT(lvlTp));
}
- // FIXME: `toOrigDim` is deprecated
- // Since we do not have HigherOrdering now, we can always rely on the 1:1
- // mapping from level to dimension to retrieve the level size.
+ // FIXME: `toOrigDim` is deprecated. For now this relies on the
+ // 1:1 mapping between levels and dimensions, since nothing else
+ // in the code supports HigherOrdering yet either.
Value lvlSz = mlir::linalg::createOrFoldDimOp(builder, loc, tensor,
toOrigDim(enc, l));
// Find upper bound in current dimension.
@@ -355,44 +374,49 @@ void LoopEmitter::initializeLoopEmit(OpBuilder &builder, Location loc,
valBuffer[t] = denseVal;
} else {
// Annotated sparse tensors.
- // We also need the value buffer for annotated all dense `sparse` tensor.
+ // We also need the value buffer for all-dense annotated "sparse" tensors.
valBuffer[t] = genToValues(builder, loc, tensor);
}
- // NOTE: we can also prepare for 0 dim here in advance, this will hosit
+ // NOTE: we can also prepare for 0 lvl here in advance, this will hoist
// some loop preparation from tensor iteration, but will also (undesirably)
- // hosit the code ouside if conditions.
+ // hoist the code outside if-conditions.
}
}
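
In other words, after `initializeLoopEmit` a compressed level has both a positions and a coordinates buffer, a singleton level has only a coordinates buffer, a dense level has neither, and each annotated tensor additionally gets one values buffer. A small sketch of that dispatch (plain C++ over a toy level-type enum; not the actual `genToPositions`/`genToCoordinates` calls):

  #include <cstdio>
  #include <vector>

  enum class LevelType { Dense, Compressed, Singleton };

  // Which per-level buffers the emitter fetches for a given level-type.
  struct LevelBuffers {
    bool positions;
    bool coordinates;
  };

  LevelBuffers buffersFor(LevelType lt) {
    switch (lt) {
    case LevelType::Compressed:
      return {true, true};   // positions + coordinates
    case LevelType::Singleton:
      return {false, true};  // coordinates only
    case LevelType::Dense:
      return {false, false}; // nothing to fetch
    }
    return {false, false};
  }

  int main() {
    // A CSR-like tensor: level 0 dense, level 1 compressed.
    const std::vector<LevelType> lvlTypes = {LevelType::Dense,
                                             LevelType::Compressed};
    for (size_t l = 0; l < lvlTypes.size(); l++) {
      const LevelBuffers b = buffersFor(lvlTypes[l]);
      std::printf("level %zu: positions=%d coordinates=%d\n", l, b.positions,
                  b.coordinates);
    }
    return 0;
  }
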
void LoopEmitter::enterNewLoopSeq(OpBuilder &builder, Location loc,
- ArrayRef<size_t> tids,
- ArrayRef<size_t> dims) {
+ ArrayRef<TensorId> tids,
+ ArrayRef<Level> lvls) {
// TODO: sort
assert(loopSeqStack.size() == loopStack.size());
// Universal Index starts from 0.
loopSeqStack.emplace_back(constantIndex(builder, loc, 0));
// Prepares for all the tensors used in the current loop sequence.
- for (auto [tid, dim] : llvm::zip(tids, dims))
- prepareLoopOverTensorAtDim(builder, loc, tid, dim);
+ assert(tids.size() == lvls.size());
+ for (auto [tid, lvl] : llvm::zip(tids, lvls))
+ prepareLoopOverTensorAtLvl(builder, loc, tid, lvl);
}
-Value LoopEmitter::genAffine(OpBuilder &builder, AffineExpr a, Location loc) {
+Value LoopEmitter::genAffine(OpBuilder &builder, Location loc, AffineExpr a) {
switch (a.getKind()) {
case AffineExprKind::DimId: {
- unsigned idx = a.cast<AffineDimExpr>().getPosition();
- return loopStack[sparsiferLoopLvlMap[idx]].iv;
+ // FIXME: since the one callsite in Sparsification passes in a
+ // level-expression, the `getPosition` must in fact be a `Dimension`.
+ // However, elsewhere we have been led to expect that `loopIdToOrd`
+ // should be indexed by `LoopId`...
+ const LoopId i = a.cast<AffineDimExpr>().getPosition();
+ return loopStack[loopIdToOrd[i]].iv;
}
case AffineExprKind::Add: {
auto binOp = a.cast<AffineBinaryOpExpr>();
return builder.create<arith::AddIOp>(
- loc, genAffine(builder, binOp.getLHS(), loc),
- genAffine(builder, binOp.getRHS(), loc));
+ loc, genAffine(builder, loc, binOp.getLHS()),
+ genAffine(builder, loc, binOp.getRHS()));
}
case AffineExprKind::Mul: {
auto binOp = a.cast<AffineBinaryOpExpr>();
return builder.create<arith::MulIOp>(
- loc, genAffine(builder, binOp.getLHS(), loc),
- genAffine(builder, binOp.getRHS(), loc));
+ loc, genAffine(builder, loc, binOp.getLHS()),
+ genAffine(builder, loc, binOp.getRHS()));
}
case AffineExprKind::Constant: {
int64_t c = a.cast<AffineConstantExpr>().getValue();
@@ -403,40 +427,44 @@ Value LoopEmitter::genAffine(OpBuilder &builder, AffineExpr a, Location loc) {
}
}
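
The recursion in `genAffine` is an ordinary expression evaluator, except that each leaf yields either a loop induction variable (looked up through `loopIdToOrd`) or a constant. A standalone sketch that evaluates a toy expression tree with concrete integers instead of building `arith` ops (all names here are hypothetical, not MLIR APIs):

  #include <cassert>
  #include <vector>

  // Toy stand-in for AffineExpr: a loop-variable (DimId), a constant, or a
  // binary Add/Mul node.
  struct Expr {
    enum Kind { DimId, Constant, Add, Mul } kind;
    int value = 0;             // LoopId for DimId, the constant otherwise
    const Expr *lhs = nullptr; // for Add/Mul
    const Expr *rhs = nullptr;
  };

  // Mirrors the shape of LoopEmitter::genAffine, but computes an integer
  // instead of emitting arith ops: a DimId leaf looks up the induction value
  // of its loop through the LoopId -> LoopOrd inverse map.
  int evalAffine(const Expr &e, const std::vector<int> &loopIdToOrd,
                 const std::vector<int> &ivs) {
    switch (e.kind) {
    case Expr::DimId:
      return ivs[loopIdToOrd[e.value]];
    case Expr::Constant:
      return e.value;
    case Expr::Add:
      return evalAffine(*e.lhs, loopIdToOrd, ivs) +
             evalAffine(*e.rhs, loopIdToOrd, ivs);
    case Expr::Mul:
      return evalAffine(*e.lhs, loopIdToOrd, ivs) *
             evalAffine(*e.rhs, loopIdToOrd, ivs);
    }
    return 0;
  }

  int main() {
    // d0 + 2 * d1, with loops generated in the order d1, d0 and induction
    // values iv(d1) = 3 and iv(d0) = 5.
    const Expr d0{Expr::DimId, 0};
    const Expr d1{Expr::DimId, 1};
    const Expr c2{Expr::Constant, 2};
    const Expr mul{Expr::Mul, 0, &c2, &d1};
    const Expr add{Expr::Add, 0, &d0, &mul};
    const std::vector<int> loopIdToOrd = {1, 0}; // d1 runs first, then d0
    const std::vector<int> ivs = {3, 5};         // indexed by LoopOrd
    assert(evalAffine(add, loopIdToOrd, ivs) == 5 + 2 * 3);
    return 0;
  }
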
-Operation *LoopEmitter::enterLoopOverTensorAtDim(
- OpBuilder &builder, Location loc, ArrayRef<size_t> tids,
- ArrayRef<size_t> dims, MutableArrayRef<Value> reduc, bool isParallel) {
+Operation *LoopEmitter::enterLoopOverTensorAtLvl(
+ OpBuilder &builder, Location loc, ArrayRef<TensorId> tids,
+ ArrayRef<Level> lvls, MutableArrayRef<Value> reduc, bool isParallel) {
// TODO: support multiple return on parallel for?
assert(!isParallel || reduc.size() <= 1);
bool isSparseInput = false;
- size_t tid = tids.front(), dim = dims.front();
- for (auto [t, d] : llvm::zip(tids, dims)) {
- assert(dimTypes[t].size() > d); // Must be a valid tid, dim pair
- assert(!coord[t][d]); // We cannot re-enter the same level
- auto dimType = dimTypes[t][d];
- // Must be a recognizable DLT.
- assert(isDenseDLT(dimType) || isCompressedDLT(dimType) ||
- isSingletonDLT(dimType));
- bool isSparse = isCompressedDLT(dimType) || isSingletonDLT(dimType);
+ TensorId tid = tids.front();
+ Level dstLvl = lvls.front();
+ assert(tids.size() == lvls.size());
+ for (auto [t, l] : llvm::zip(tids, lvls)) {
+ // TODO: this check for validity of the (t,l) pairs should be
+ // checked/enforced at the callsites, if possible.
+ assert(t < lvlTypes.size() && l < lvlTypes[t].size());
+ assert(!coords[t][l]); // We cannot re-enter the same level
+ const auto lvlTp = lvlTypes[t][l];
+ const bool isSparse = isCompressedDLT(lvlTp) || isSingletonDLT(lvlTp);
+ // Must be a recognizable level-type.
+ assert(isSparse || isDenseDLT(lvlTp));
// We can have at most one sparse input; otherwise, a while loop is required
// to co-iterate multiple sparse tensors.
assert(!isSparseInput || !isSparse);
if (isSparse) {
tid = t;
- dim = d;
+ dstLvl = l;
}
isSparseInput = isSparseInput || isSparse;
}
- const auto reassoc = getCollapseReassociation(tid, dim);
+ const auto reassoc = getCollapseReassociation(tid, dstLvl);
// TODO: support dynamic slices.
- // Uses the first dimension here to build the loop bound (which is also the
- // biggest range).
- const auto fdim = reassoc.front();
- Value step = constantIndex(builder, loc, 1);
- Value lo = isSparseInput ? pidxs[tid][fdim] // current offset
- : loopSeqStack.back(); // universal index
- Value hi = highs[tid][fdim];
+ // Use the first source-level here to build the loop bound (which is
+ // also the biggest range).
+ const Level srcLvl = reassoc.front();
+ const Value step = constantIndex(builder, loc, 1);
+ /// FIXME: See the [CLARIFY_POSITS_LVL] note in the header.
+ const Value lo = isSparseInput ? posits[tid][srcLvl] // current position
+ : loopSeqStack.back(); // universal index
+ const Value hi = highs[tid][srcLvl];
Operation *loop = nullptr;
Value iv;
@@ -450,7 +478,7 @@ Operation *LoopEmitter::enterLoopOverTensorAtDim(
// In-place update on the reduction variable vector.
// Note that the init vals are not the actual reduction variables but instead
- // used as a `special handle` to (temporarily) represent them. The
+ // used as a "special handle" to (temporarily) represent them. The
// expression on init vals will be moved into scf.reduce and replaced with
// the block arguments when exiting the loop (see exitForLoop). This is
// needed as we can not build the actual reduction block and get the actual
@@ -475,9 +503,10 @@ Operation *LoopEmitter::enterLoopOverTensorAtDim(
if (isSparseInput) {
assert(reassoc.size() == 1 || isUniqueCOOType(tensors[tid].getType()));
// For COO, the position is the same across consecutive levels.
+ /// FIXME: See the [CLARIFY_POSITS_LVL] note in the header.
llvm::for_each(reassoc,
- [this, tid, iv](Level lvl) { pidxs[tid][lvl] = iv; });
- crd = genSparseCrd(builder, loc, tid, dim);
+ [this, tid, iv](Level srcLvl) { posits[tid][srcLvl] = iv; });
+ crd = genSparseCrd(builder, loc, tid, dstLvl);
} else {
// Dense tensor: the coordinate is the induction variable.
crd = iv;
@@ -490,7 +519,7 @@ Operation *LoopEmitter::enterLoopOverTensorAtDim(
for (Value red : reduc)
types.push_back(red.getType());
- auto [trans, pred] = genSliceLegitPredicate(builder, loc, crd, tid, dim);
+ auto [trans, pred] = genSliceLegitPredicate(builder, loc, crd, tid, srcLvl);
bool hasReduc = !types.empty();
scf::IfOp ifOp = builder.create<scf::IfOp>(loc, types, pred,
/*else*/ hasReduc);
@@ -512,35 +541,33 @@ Operation *LoopEmitter::enterLoopOverTensorAtDim(
}
assert(crd);
- coord[tid][dim] = crd;
- // NOTE: we can also prepare for next dim here in advance
+ coords[tid][srcLvl] = crd;
+ // NOTE: we can also prepare for next level here in advance
// Push the loop into stack
- loopStack.emplace_back(ArrayRef<size_t>(tid), ArrayRef<size_t>(dim), loop,
- builder.getInsertionBlock(), coord[tid][dim], loopTag);
+ loopStack.emplace_back(ArrayRef<TensorId>(tid), ArrayRef<Level>(srcLvl), loop,
+ builder.getInsertionBlock(), crd, loopTag);
// Emit extra locals.
- emitExtraLocalsForTensorsAtDenseDims(builder, loc, tids, dims);
+ emitExtraLocalsForTensorsAtDenseLvls(builder, loc, tids, lvls);
return loop;
}
-Operation *LoopEmitter::enterFilterLoopOverTensorAtDim(
- OpBuilder &builder, Location loc, size_t tid, size_t dim, AffineExpr affine,
- MutableArrayRef<Value> reduc) {
- assert(!affine.isa<AffineDimExpr>() && !isDenseDLT(dimTypes[tid][dim]));
- assert(dimTypes[tid].size() > dim);
+Operation *LoopEmitter::enterFilterLoopOverTensorAtLvl(
+ OpBuilder &builder, Location loc, TensorId tid, Level lvl,
+ AffineExpr affine, MutableArrayRef<Value> reduc) {
+ assert(tid < lvlTypes.size() && lvl < lvlTypes[tid].size());
+ assert(!affine.isa<AffineDimExpr>() && !isDenseDLT(lvlTypes[tid][lvl]));
// We cannot re-enter the same level.
- assert(!coord[tid][dim]);
-
- Value step = constantIndex(builder, loc, 1);
-
- Value lo = pidxs[tid][dim];
- Value hi = highs[tid][dim];
+ assert(!coords[tid][lvl]);
// TODO: We should instead use a whileOp for filter loop to allow early
- // break when exceeding (for ordered dimensions).
+ // break when exceeding (for ordered levels).
// TODO: There are many other potential opportunities that we might apply in
- // the future. E.g., we could use binary search to located the position index.
- scf::ForOp forOp = builder.create<scf::ForOp>(loc, lo, hi, step, reduc);
+ // the future. E.g., we could use binary search to locate positions.
+ const Value step = constantIndex(builder, loc, 1);
+ const Value pLo = posits[tid][lvl];
+ const Value pHi = highs[tid][lvl];
+ scf::ForOp forOp = builder.create<scf::ForOp>(loc, pLo, pHi, step, reduc);
// In-place update on the reduction variable vector.
assert(forOp.getNumRegionIterArgs() == reduc.size());
@@ -548,18 +575,19 @@ Operation *LoopEmitter::enterFilterLoopOverTensorAtDim(
reduc[i] = forOp.getRegionIterArg(i);
builder.setInsertionPointToStart(forOp.getBody());
- Value iv = forOp.getInductionVar();
-
- pidxs[tid][dim] = iv;
- // Generating a load on the coordinates array yields the coordinate.
- Value mem = crdBuffer[tid][dim];
- coord[tid][dim] = genIndexLoad(builder, loc, mem, iv);
+ // The induction variable gives the position.
+ const Value pos = forOp.getInductionVar();
+ posits[tid][lvl] = pos;
+ // Generating a load on the coordinates array yields the crd.
+ const Value mem = coordinatesBuffers[tid][lvl];
+ const Value crd = genIndexLoad(builder, loc, mem, pos);
+ coords[tid][lvl] = crd;
// Generate an if-condition to filter out coordinates that are not
// equal to the result of the affine expression.
- Value expected = genAffine(builder, affine, loc);
- auto pred = builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq,
- coord[tid][dim], expected);
+ Value expected = genAffine(builder, loc, affine);
+ auto pred = builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq, crd,
+ expected);
SmallVector<Type> types;
for (Value red : reduc) {
types.push_back(red.getType());
@@ -583,43 +611,45 @@ Operation *LoopEmitter::enterFilterLoopOverTensorAtDim(
// Set the insert point to matched branch.
builder.setInsertionPointToStart(&ifOp.getThenRegion().front());
- // NOTE: we can also prepare for next dim here in advance
+ // NOTE: we can also prepare for next lvl here in advance
// Push the loop into stack
- loopStack.emplace_back(ArrayRef<size_t>(tid), ArrayRef<size_t>(dim), forOp,
- builder.getInsertionBlock(), coord[tid][dim], nullptr);
+ loopStack.emplace_back(ArrayRef<TensorId>(tid), ArrayRef<Level>(lvl), forOp,
+ builder.getInsertionBlock(), crd, nullptr);
return forOp;
}
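
In effect, the filter loop scans every position of the level and runs the body only for entries whose coordinate matches the affine expression. A standalone sketch over plain arrays (assuming CSR-style coordinates buffers; this is plain C++, not the generated scf ops):

  #include <cstdio>
  #include <vector>

  // Scan positions [pLo, pHi) of a sparse level and invoke the body only for
  // entries whose coordinate equals `expected` (the value of the affine
  // expression). This mirrors the scf.for wrapping an scf.if emitted by
  // enterFilterLoopOverTensorAtLvl.
  template <typename Body>
  void filterLoop(const std::vector<int> &coordinates, int pLo, int pHi,
                  int expected, Body body) {
    for (int pos = pLo; pos < pHi; pos++) {
      const int crd = coordinates[pos];
      if (crd == expected) // the generated arith.cmpi + scf.if
        body(pos, crd);
    }
  }

  int main() {
    const std::vector<int> coordinates = {0, 2, 2, 5, 7};
    filterLoop(coordinates, /*pLo=*/0, /*pHi=*/5, /*expected=*/2,
               [](int pos, int crd) {
                 std::printf("visiting pos=%d crd=%d\n", pos, crd);
               });
    return 0;
  }
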
-void LoopEmitter::genDenseAffineAddressAtCurLevel(OpBuilder &builder,
- Location loc, size_t tid,
- size_t dim,
- AffineExpr affine) {
- Value affineV = genAffine(builder, affine, loc);
- pidxs[tid][dim] = genAddress(builder, loc, tid, dim, affineV);
+void LoopEmitter::genDenseAffineAddress(OpBuilder &builder, Location loc,
+ TensorId tid, Level lvl,
+ AffineExpr lvlExpr) {
+ assert(isDenseDLT(lvlTypes[tid][lvl]));
+ // For dense levels, the level-coordinate also serves as the position.
+ Value lvlCrd = genAffine(builder, loc, lvlExpr);
+ posits[tid][lvl] = genAddress(builder, loc, tid, lvl, lvlCrd);
}
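
A dense level needs no loads at all: its "position" is simply the linearized address p = parentPos * lvlSize + crd, which is what `genAddress` builds. A tiny sketch with plain integers (the helper name here is hypothetical):

  #include <cassert>

  // Row-major linearization for a dense level of size lvlSize: the position
  // of coordinate `crd` under parent position `parentPos`.
  inline long denseAddress(long parentPos, long lvlSize, long crd) {
    return parentPos * lvlSize + crd;
  }

  int main() {
    // A 4x5 dense tensor: element (2, 3) lives at 2*5 + 3 = 13.
    assert(denseAddress(/*parentPos=*/2, /*lvlSize=*/5, /*crd=*/3) == 13);
    return 0;
  }
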
-Operation *LoopEmitter::enterCoIterationOverTensorsAtDims(
- OpBuilder &builder, Location loc, ArrayRef<size_t> tids,
- ArrayRef<size_t> dims, bool needsUniv, MutableArrayRef<Value> reduc) {
- assert(tids.size() == dims.size());
+Operation *LoopEmitter::enterCoIterationOverTensorsAtLvls(
+ OpBuilder &builder, Location loc, ArrayRef<TensorId> tids,
+ ArrayRef<Level> lvls, bool needsUniv, MutableArrayRef<Value> reduc) {
+ assert(tids.size() == lvls.size());
SmallVector<Type> types;
SmallVector<Value> operands;
// Construct the while-loop with a parameter for each coordinate.
- Type indexType = builder.getIndexType();
- for (auto [tid, dim] : llvm::zip(tids, dims)) {
- if (isCompressedDLT(dimTypes[tid][dim]) ||
- isSingletonDLT(dimTypes[tid][dim])) {
- const auto reassoc = getCollapseReassociation(tid, dim);
+ const Type indexType = builder.getIndexType();
+ for (auto [tid, lvl] : llvm::zip(tids, lvls)) {
+ const auto lvlTp = lvlTypes[tid][lvl];
+ if (isCompressedDLT(lvlTp) || isSingletonDLT(lvlTp)) {
+ const auto reassoc = getCollapseReassociation(tid, lvl);
for (unsigned i = 0, e = reassoc.size() - 1; i < e; i++) {
- if (!isUniqueDLT(dimTypes[tid][reassoc[i]])) {
+ if (!isUniqueDLT(lvlTypes[tid][reassoc[i]])) {
// This is the segment high for each non-unique level.
types.push_back(indexType);
operands.push_back(constantIndex(builder, loc, 0));
}
}
- assert(pidxs[tid][dim]);
+ const auto pos = posits[tid][reassoc.front()];
+ assert(pos);
types.push_back(indexType);
- operands.push_back(pidxs[tid][reassoc.front()]);
+ operands.push_back(pos);
}
}
// The position where user-supplied reduction variable starts.
@@ -644,14 +674,14 @@ Operation *LoopEmitter::enterCoIterationOverTensorsAtDims(
builder.setInsertionPointToStart(&whileOp.getBefore().front());
Value cond;
unsigned o = 0;
- for (auto [t, lvl] : llvm::zip(tids, dims)) {
+ for (auto [t, lvl] : llvm::zip(tids, lvls)) {
unsigned tid = t; // Why `t` can not be captured by lambda?
- if (isCompressedDLT(dimTypes[tid][lvl]) ||
- isSingletonDLT(dimTypes[tid][lvl])) {
+ const auto lvlTp = lvlTypes[tid][lvl];
+ if (isCompressedDLT(lvlTp) || isSingletonDLT(lvlTp)) {
const auto reassoc = getCollapseReassociation(tid, lvl);
assert(reassoc.size() == 1 || isUniqueCOOType(tensors[tid].getType()));
for (unsigned i = 0, e = reassoc.size() - 1; i < e; i++) {
- if (!isUniqueDLT(dimTypes[tid][reassoc[i]])) {
+ if (!isUniqueDLT(lvlTypes[tid][reassoc[i]])) {
// Links the SSA chain for segHi.
segHi[tid][reassoc[i]] = after->getArgument(o++);
}
@@ -665,8 +695,10 @@ Operation *LoopEmitter::enterCoIterationOverTensorsAtDims(
// Update positions
Value pos = after->getArgument(o++);
// For COO, the position is the same across consecutive levels.
- llvm::for_each(reassoc,
- [this, tid, pos](Level lvl) { pidxs[tid][lvl] = pos; });
+ /// FIXME: See the [CLARIFY_POSITS_LVL] note in the header.
+ llvm::for_each(reassoc, [this, tid, pos](Level srcLvl) {
+ posits[tid][srcLvl] = pos;
+ });
}
}
builder.create<scf::ConditionOp>(loc, cond, before->getArguments());
@@ -676,17 +708,17 @@ Operation *LoopEmitter::enterCoIterationOverTensorsAtDims(
SmallVector<std::pair<Value, unsigned>> slicesPreds;
unsigned i = 0;
- for (auto [tid, dim] : llvm::zip(tids, dims)) {
+ for (auto [tid, lvl] : llvm::zip(tids, lvls)) {
// Prepares for next level.
- if (isCompressedDLT(dimTypes[tid][dim]) ||
- isSingletonDLT(dimTypes[tid][dim])) {
- coord[tid][dim] = genSparseCrd(builder, loc, tid, dim);
+ const auto lvlTp = lvlTypes[tid][lvl];
+ if (isCompressedDLT(lvlTp) || isSingletonDLT(lvlTp)) {
+ coords[tid][lvl] = genSparseCrd(builder, loc, tid, lvl);
if (isSparseSlices[tid]) {
auto [trans, pred] =
- genSliceLegitPredicate(builder, loc, coord[tid][dim], tid, dim);
+ genSliceLegitPredicate(builder, loc, coords[tid][lvl], tid, lvl);
slicesPreds.emplace_back(pred, i);
// Update to the coordinate relative to the slice.
- coord[tid][dim] = trans;
+ coords[tid][lvl] = trans;
}
i++;
}
@@ -696,14 +728,16 @@ Operation *LoopEmitter::enterCoIterationOverTensorsAtDims(
// Skips invalid loop iteration when slice coordinate is inapplicable.
SmallVector<Value> yields(after->getArguments());
// Generates a list of if statements
- // pidx = in_slice ? pidx : pidx + 1
- // TODO: instead of always picking pidx + 1, we should set pidx = high to
- // break to loop if the coordinates is larger than the slice size.
+ // pos = in_slice ? pos : pos + 1
+ // TODO: instead of always picking pos + 1, we should set pos = high to
+ // break the loop if the coordinate is larger than the slice size.
+ //
+ // This "idx" is the index into `llvm::zip(tids, lvls)`
for (auto [pred, idx] : slicesPreds) {
- Value nextPidx = builder.create<arith::AddIOp>(
+ Value nextPos = builder.create<arith::AddIOp>(
loc, yields[idx], constantIndex(builder, loc, 1));
yields[idx] =
- builder.create<arith::SelectOp>(loc, pred, yields[idx], nextPidx);
+ builder.create<arith::SelectOp>(loc, pred, yields[idx], nextPos);
}
Value pred = slicesPreds.front().first;
@@ -726,31 +760,32 @@ Operation *LoopEmitter::enterCoIterationOverTensorsAtDims(
Value min;
// Finds the minimum coordinate
if (!needsUniv) {
- for (auto [tid, dim] : llvm::zip(tids, dims)) {
- if (isCompressedDLT(dimTypes[tid][dim]) ||
- isSingletonDLT(dimTypes[tid][dim])) {
+ for (auto [tid, lvl] : llvm::zip(tids, lvls)) {
+ const auto lvlTp = lvlTypes[tid][lvl];
+ if (isCompressedDLT(lvlTp) || isSingletonDLT(lvlTp)) {
+ const auto crd = coords[tid][lvl];
if (min) {
Value cmp = builder.create<arith::CmpIOp>(
- loc, arith::CmpIPredicate::ult, coord[tid][dim], min);
- min = builder.create<arith::SelectOp>(loc, cmp, coord[tid][dim], min);
+ loc, arith::CmpIPredicate::ult, crd, min);
+ min = builder.create<arith::SelectOp>(loc, cmp, crd, min);
} else {
- min = coord[tid][dim];
+ min = crd;
}
}
}
} else {
assert(!min);
- // Otherwise, universal index is the minimal pidx.
+ // Otherwise, the universal index yields the minimal coordinate.
min = after->getArguments().back();
}
// Sets up the loop stack.
- loopStack.emplace_back(tids, dims, whileOp, builder.getInsertionBlock(), min,
+ loopStack.emplace_back(tids, lvls, whileOp, builder.getInsertionBlock(), min,
loopTag);
assert(loopStack.size() == loopSeqStack.size());
- for (auto [tid, dim] : llvm::zip(tids, dims)) {
- const auto reassoc = getCollapseReassociation(tid, dim);
+ for (auto [tid, dstLvl] : llvm::zip(tids, lvls)) {
+ const auto reassoc = getCollapseReassociation(tid, dstLvl);
assert(reassoc.size() == 1 || isUniqueCOOType(tensors[tid].getType()));
// TODO: Refactors this into smaller functions.
// NOTE: For all the collapsed level (except for the last one, that is why
@@ -767,36 +802,37 @@ Operation *LoopEmitter::enterCoIterationOverTensorsAtDims(
// segHi[lvl=2] = 1,
// the first iteration does not invalidate segHi[0] and segHi[1]
for (unsigned i = 0, e = reassoc.size() - 1; i < e; i++) {
- const auto lvl = reassoc[i];
- if (!isUniqueDLT(dimTypes[tid][lvl])) {
- Value pos = pidxs[tid][lvl];
- assert(segHi[tid][lvl]);
+ const Level srcLvl = reassoc[i];
+ if (!isUniqueDLT(lvlTypes[tid][srcLvl])) {
+ const Value pos = posits[tid][srcLvl];
+ const auto oldSegHi = segHi[tid][srcLvl];
+ assert(oldSegHi);
Value newSegHi = builder.create<arith::CmpIOp>(
- loc, arith::CmpIPredicate::uge, pos, segHi[tid][lvl]);
- auto ifOp = builder.create<scf::IfOp>(loc, builder.getIndexType(),
- newSegHi, true);
+ loc, arith::CmpIPredicate::uge, pos, oldSegHi);
+ auto ifNewSegHi = builder.create<scf::IfOp>(loc, builder.getIndexType(),
+ newSegHi, true);
{
OpBuilder::InsertionGuard guard(builder);
- builder.setInsertionPointToStart(ifOp.thenBlock());
- builder.create<scf::YieldOp>(
- loc,
- genSegmentHigh(builder, loc, tid, lvl, pos, highs[tid][lvl]));
+ builder.setInsertionPointToStart(ifNewSegHi.thenBlock());
+ builder.create<scf::YieldOp>(loc,
+ genSegmentHigh(builder, loc, tid, srcLvl,
+ pos, highs[tid][srcLvl]));
// Else, reuses the same segment high.
- builder.setInsertionPointToStart(ifOp.elseBlock());
- builder.create<scf::YieldOp>(loc, segHi[tid][lvl]);
+ builder.setInsertionPointToStart(ifNewSegHi.elseBlock());
+ builder.create<scf::YieldOp>(loc, oldSegHi);
}
- highs[tid][lvl + 1] = segHi[tid][lvl] = ifOp.getResult(0);
+ highs[tid][srcLvl + 1] = segHi[tid][srcLvl] = ifNewSegHi.getResult(0);
}
};
- const auto lvl = reassoc.back();
- if (!isUniqueDLT(dimTypes[tid][lvl])) {
- segHi[tid][lvl] = genSegmentHigh(builder, loc, tid, lvl, pidxs[tid][lvl],
- highs[tid][lvl]);
+ const auto srcLvl = reassoc.back();
+ if (!isUniqueDLT(lvlTypes[tid][srcLvl])) {
+ segHi[tid][srcLvl] = genSegmentHigh(
+ builder, loc, tid, srcLvl, posits[tid][srcLvl], highs[tid][srcLvl]);
}
}
// Emits extra locals
- emitExtraLocalsForTensorsAtDenseDims(builder, loc, tids, dims);
+ emitExtraLocalsForTensorsAtDenseLvls(builder, loc, tids, lvls);
// Updates reduction variables
assert(after->getNumArguments() == o + reduc.size() + (needsUniv ? 1 : 0));
@@ -807,74 +843,75 @@ Operation *LoopEmitter::enterCoIterationOverTensorsAtDims(
return whileOp;
}
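
The structure built here is the classic sparse co-iteration loop: each round visits the minimum coordinate and advances exactly the tensors whose current coordinate equals that minimum. A standalone sketch over two sorted coordinate arrays (plain C++; no universal index, no slices, no non-unique levels):

  #include <algorithm>
  #include <cstdio>
  #include <vector>

  // Co-iterate two sparse levels given as sorted coordinate arrays. This is
  // the pattern the generated scf.while + arith.select chain implements.
  void coIterate(const std::vector<int> &a, const std::vector<int> &b) {
    size_t posA = 0, posB = 0;
    while (posA < a.size() && posB < b.size()) { // the before-region condition
      const int crdA = a[posA], crdB = b[posB];
      const int min = std::min(crdA, crdB);
      std::printf("coord %d: inA=%d inB=%d\n", min, crdA == min, crdB == min);
      // Forward only the positions whose coordinate matched the minimum.
      if (crdA == min)
        posA++;
      if (crdB == min)
        posB++;
    }
  }

  int main() {
    coIterate({0, 3, 5, 9}, {3, 4, 9});
    return 0;
  }
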
-void LoopEmitter::prepareLoopOverTensorAtDim(OpBuilder &builder, Location loc,
- size_t tid, size_t dim) {
- assert(dimTypes[tid].size() > dim);
- auto dimType = dimTypes[tid][dim];
+void LoopEmitter::prepareLoopOverTensorAtLvl(OpBuilder &builder, Location loc,
+ TensorId tid, Level dstLvl) {
+ assert(tid < lvlTypes.size() && dstLvl < lvlTypes[tid].size());
+ const auto lvlTp = lvlTypes[tid][dstLvl];
- if (isDenseDLT(dimType))
+ if (isDenseDLT(lvlTp))
return;
- for (auto lvl : getCollapseReassociation(tid, dim)) {
+ const Value c0 = constantIndex(builder, loc, 0);
+ const Value c1 = constantIndex(builder, loc, 1);
+ for (const Level srcLvl : getCollapseReassociation(tid, dstLvl)) {
// Either the first level, or the previous level has been set.
- assert(lvl == 0 || pidxs[tid][lvl - 1]);
- Value c0 = constantIndex(builder, loc, 0);
- Value c1 = constantIndex(builder, loc, 1);
- if (isCompressedDLT(dimType)) {
- Value mem = posBuffer[tid][lvl];
+ /// FIXME: See the [CLARIFY_POSITS_LVL] note in the header.
+ assert(srcLvl == 0 || posits[tid][srcLvl - 1]);
+ if (!isCompressedDLT(lvlTp) && !isSingletonDLT(lvlTp))
+ continue;
+ if (isCompressedDLT(lvlTp)) {
+ const Value mem = positionsBuffers[tid][srcLvl];
- Value pLo = lvl == 0 ? c0 : pidxs[tid][lvl - 1];
- pidxs[tid][lvl] = genIndexLoad(builder, loc, mem, pLo);
+ const Value pLo = srcLvl == 0 ? c0 : posits[tid][srcLvl - 1];
+ posits[tid][srcLvl] = genIndexLoad(builder, loc, mem, pLo);
- Value pHi = builder.create<arith::AddIOp>(loc, pLo, c1);
- highs[tid][lvl] = genIndexLoad(builder, loc, mem, pHi);
+ const Value pHi = builder.create<arith::AddIOp>(loc, pLo, c1);
+ highs[tid][srcLvl] = genIndexLoad(builder, loc, mem, pHi);
return;
}
- if (isSingletonDLT(dimType)) {
- Value pLo = lvl == 0 ? c0 : pidxs[tid][lvl - 1];
- Value pHi;
- // If this is non-unique, the pHi is bound by the segment high of the
- // previous level.
- if (!isUniqueDLT(dimTypes[tid][lvl - 1]))
- pHi = segHi[tid][lvl - 1];
-
- // If pHi is still uninitialized, we set it to one as it is a singleton
- // level.
- // NOTE: Even if the level is non-unique, the pHi might not have been set
- // in the previous statement, as we only compute segment high when we are
- // coiterating non-unique levels.
- if (!pHi)
- pHi = builder.create<arith::AddIOp>(loc, pLo, c1);
- pidxs[tid][lvl] = pLo;
- highs[tid][lvl] = pHi;
+ if (isSingletonDLT(lvlTp)) {
+ const Value pLo = srcLvl == 0 ? c0 : posits[tid][srcLvl - 1];
+ posits[tid][srcLvl] = pLo;
+
+ // If we are coiterating non-unique levels, then use pHi=segHi;
+ // otherwise use pHi=pLo+1.
+ // NOTE: Even when the level is non-unique, segHi is not guaranteed
+ // to be defined: we only generate segHi when coiterating, in order
+ // to improve code quality for the non-coiterating cases.
+ const auto theSegHi = segHi[tid][srcLvl - 1];
+ highs[tid][srcLvl] = (!isUniqueDLT(lvlTypes[tid][srcLvl - 1]) && theSegHi)
+ ? theSegHi
+ : builder.create<arith::AddIOp>(loc, pLo, c1);
return;
}
}
- llvm_unreachable("Unrecognizable dimesion type!");
+ llvm_unreachable("Unrecognized level-type!");
}
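
To make the bounds concrete: for a compressed level the [pLo, pHi) range comes from two adjacent entries of the positions buffer, while a singleton level spans a single entry past the parent position. A sketch over plain vectors (assuming CSR-style storage and ignoring the segment-high refinement for non-unique levels; these are not the actual genIndexLoad calls):

  #include <cassert>
  #include <utility>
  #include <vector>

  // Compressed level: the children of parent position `parentPos` occupy
  // [positions[parentPos], positions[parentPos + 1]).
  std::pair<int, int> compressedBounds(const std::vector<int> &positions,
                                       int parentPos) {
    return {positions[parentPos], positions[parentPos + 1]};
  }

  // Singleton level: exactly one child, at the parent position itself.
  std::pair<int, int> singletonBounds(int parentPos) {
    return {parentPos, parentPos + 1};
  }

  int main() {
    // CSR-style positions for a 3-row matrix with 2, 0, and 3 stored entries.
    const std::vector<int> positions = {0, 2, 2, 5};
    assert(compressedBounds(positions, 0) == std::make_pair(0, 2));
    assert(compressedBounds(positions, 1) == std::make_pair(2, 2)); // empty row
    assert(singletonBounds(4) == std::make_pair(4, 5));
    return 0;
  }
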
-void LoopEmitter::emitExtraLocalsForTensorsAtDenseDims(OpBuilder &builder,
+void LoopEmitter::emitExtraLocalsForTensorsAtDenseLvls(OpBuilder &builder,
Location loc,
- ArrayRef<size_t> tids,
- ArrayRef<size_t> dims) {
+ ArrayRef<TensorId> tids,
+ ArrayRef<Level> lvls) {
// Initialize dense positions. Note that we generate dense coordinates of the
// output tensor unconditionally, since they may not appear in the lattice,
// but may be needed for linearized codegen.
- for (auto [tid, dim] : llvm::zip(tids, dims)) {
- if (isDenseDLT(dimTypes[tid][dim])) {
+ assert(tids.size() == lvls.size());
+ for (auto [tid, lvl] : llvm::zip(tids, lvls)) {
+ if (isDenseDLT(lvlTypes[tid][lvl])) {
auto enc = getSparseTensorEncoding(tensors[tid].getType());
if (enc && !isSparseOutput(tid)) {
- bool validPidx = dim == 0 || pidxs[tid][dim - 1];
- if (!validPidx) {
- // We might not find the pidx for the sparse output tensor as it is
+ bool validPos = lvl == 0 || posits[tid][lvl - 1];
+ if (!validPos) {
+ // We might not find the pos for the sparse output tensor as it is
// unconditionally required by the sparsification.
assert(isOutputTensor(tid));
continue;
}
- pidxs[tid][dim] =
- genAddress(builder, loc, tid, dim, loopStack.back().iv);
- // NOTE: we can also prepare for next dim here in advance
+ posits[tid][lvl] =
+ genAddress(builder, loc, tid, lvl, loopStack.back().iv);
+ // NOTE: we can also prepare for next lvl here in advance
}
}
}
@@ -882,12 +919,9 @@ void LoopEmitter::emitExtraLocalsForTensorsAtDenseDims(OpBuilder &builder,
void LoopEmitter::exitForLoop(RewriterBase &rewriter, Location loc,
MutableArrayRef<Value> reduc) {
- LoopLevelInfo &loopInfo = loopStack.back();
+ const LoopInfo &loopInfo = loopStack.back();
rewriter.setInsertionPointToEnd(loopInfo.userCodeBlock);
- auto &dims = loopStack.back().dims;
- auto &tids = loopStack.back().tids;
- auto forOp = llvm::dyn_cast<scf::ForOp>(loopInfo.loop);
- if (forOp) {
+ if (auto forOp = llvm::dyn_cast<scf::ForOp>(loopInfo.loop)) {
if (!reduc.empty()) {
assert(reduc.size() == forOp.getNumResults());
rewriter.create<scf::YieldOp>(loc, reduc);
@@ -918,6 +952,7 @@ void LoopEmitter::exitForLoop(RewriterBase &rewriter, Location loc,
// One of the operands must be the init value (which is also the
// previous reduction value).
assert(curVal);
+#ifndef NDEBUG
// The reduction expression should be the only user of the reduction val
// inside the parallel for.
unsigned numUsers = 0;
@@ -926,7 +961,7 @@ void LoopEmitter::exitForLoop(RewriterBase &rewriter, Location loc,
numUsers++;
}
assert(numUsers == 1);
- (void)numUsers; // to silence unused variable warning in release build
+#endif // NDEBUG
rewriter.setInsertionPointAfter(redExp);
auto redOp = rewriter.create<scf::ReduceOp>(loc, curVal);
@@ -952,23 +987,21 @@ void LoopEmitter::exitForLoop(RewriterBase &rewriter, Location loc,
// Finished iterating a tensor; clean up.
// We only do the clean-up on for loops, as while loops do not necessarily
// finish the iteration on a sparse tensor.
- for (auto [tid, dim] : llvm::zip(tids, dims)) {
+ for (auto [tid, lvl] : llvm::zip(loopInfo.tids, loopInfo.lvls)) {
// Reset to null.
- coord[tid][dim] = Value();
- pidxs[tid][dim] = Value();
- // Dense dimension, high is fixed.
- if (!isDenseDLT(dimTypes[tid][dim]))
- highs[tid][dim] = Value();
+ coords[tid][lvl] = Value();
+ posits[tid][lvl] = Value();
+ // Dense level, high is fixed.
+ if (!isDenseDLT(lvlTypes[tid][lvl]))
+ highs[tid][lvl] = Value();
}
}
void LoopEmitter::exitCoIterationLoop(OpBuilder &builder, Location loc,
MutableArrayRef<Value> reduc) {
- const LoopLevelInfo &loopInfo = loopStack.back();
+ const LoopInfo &loopInfo = loopStack.back();
auto whileOp = llvm::cast<scf::WhileOp>(loopInfo.loop);
builder.setInsertionPointToEnd(loopInfo.userCodeBlock);
- auto &dims = loopInfo.dims;
- auto &tids = loopInfo.tids;
Value iv = loopInfo.iv;
// Finalize the induction. Note that the induction could be performed
// in the individual if-branches to avoid re-evaluating the conditions.
@@ -978,41 +1011,44 @@ void LoopEmitter::exitCoIterationLoop(OpBuilder &builder, Location loc,
unsigned o = 0;
SmallVector<Value> operands;
Value one = constantIndex(builder, loc, 1);
- for (auto [tid, dim] : llvm::zip(tids, dims)) {
- if (isCompressedDLT(dimTypes[tid][dim]) ||
- isSingletonDLT(dimTypes[tid][dim])) {
- const auto reassoc = getCollapseReassociation(tid, dim);
+ for (auto [tid, dstLvl] : llvm::zip(loopInfo.tids, loopInfo.lvls)) {
+ const auto lvlTp = lvlTypes[tid][dstLvl];
+ if (isCompressedDLT(lvlTp) || isSingletonDLT(lvlTp)) {
+ const auto reassoc = getCollapseReassociation(tid, dstLvl);
assert(reassoc.size() == 1 || isUniqueCOOType(tensors[tid].getType()));
for (unsigned i = 0, e = reassoc.size() - 1; i < e; i++) {
- const auto lvl = reassoc[i];
- if (!isUniqueDLT(dimTypes[tid][lvl])) {
- operands.push_back(segHi[tid][lvl]);
+ const Level srcLvl = reassoc[i];
+ if (!isUniqueDLT(lvlTypes[tid][srcLvl])) {
+ operands.push_back(segHi[tid][srcLvl]);
o++;
}
}
- Value op1 = coord[tid][dim];
- Value op3 = pidxs[tid][dim];
+ const Value crd = coords[tid][dstLvl];
+ const Value pos = posits[tid][dstLvl];
Value cmp =
- builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq, op1, iv);
+ builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq, crd, iv);
// If the loop contains a coiteration with a non-unique level, we fast
// forward all the duplicated coords by setting the position to the
// segment high.
- // If this is a collapsed dim, we forward pidx based on the last level in
- // the collapsed level set.
- Value add = !isUniqueDLT(dimTypes[tid][reassoc.back()])
+ Value add = !isUniqueDLT(lvlTypes[tid][reassoc.back()])
? segHi[tid][reassoc.back()]
- : builder.create<arith::AddIOp>(loc, op3, one);
+ : builder.create<arith::AddIOp>(loc, pos, one);
- operands.push_back(builder.create<arith::SelectOp>(loc, cmp, add, op3));
+ operands.push_back(builder.create<arith::SelectOp>(loc, cmp, add, pos));
// Following loops continue iteration from the break point of the
// current while loop.
- Value pos = whileOp->getResult(o++);
- const auto t = tid;
- llvm::for_each(reassoc, [this, t, pos](Level l) { pidxs[t][l] = pos; });
- // The coordinates are invalid now.
- coord[tid][dim] = nullptr;
- // The segment high are invalid now
- segHi[tid][dim] = nullptr;
+ const Value newPos = whileOp->getResult(o++);
+ // We need to define a new local variable for `tid` to avoid
+ // warnings about "captured structured bindings are a C++20 extension".
+ // FIXME(wrengr): define a helper function to capture this idiom!
+ const TensorId newTid = tid;
+ llvm::for_each(reassoc, [this, newTid, newPos](Level srcLvl) {
+ posits[newTid][srcLvl] = newPos;
+ });
+ // The coordinate is invalid now.
+ coords[tid][dstLvl] = nullptr;
+ // The segment high is invalid now.
+ segHi[tid][dstLvl] = nullptr;
// highs remains unchanged.
}
}
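
The yielded operands implement "advance a tensor only if its coordinate matched the loop's coordinate; for a non-unique level, jump over the whole duplicate segment". A small sketch in plain C++ (assuming `segHi` was already computed for the non-unique case; the helper name is illustrative only):

  #include <cassert>

  // Mirrors the arith.cmpi + arith.select pair emitted when exiting the
  // co-iteration loop: only tensors whose coordinate `crd` equals the loop
  // coordinate `iv` advance; unique levels step by one, non-unique levels
  // jump to the segment high.
  int nextPosition(int pos, int crd, int iv, bool isUnique, int segHi) {
    const int advanced = isUnique ? pos + 1 : segHi;
    return (crd == iv) ? advanced : pos;
  }

  int main() {
    assert(nextPosition(/*pos=*/3, /*crd=*/7, /*iv=*/7, true, /*segHi=*/0) == 4);
    assert(nextPosition(3, 7, 7, /*isUnique=*/false, /*segHi=*/6) == 6);
    assert(nextPosition(3, 8, 7, true, 0) == 3); // did not participate
    return 0;
  }
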
@@ -1042,8 +1078,8 @@ void LoopEmitter::exitCurrentLoop(RewriterBase &rewriter, Location loc,
MutableArrayRef<Value> reduc) {
// Clean up the values; this helps us discover potential bugs at an
// earlier stage (instead of silently using a wrong value).
- LoopLevelInfo &loopInfo = loopStack.back();
- assert(loopInfo.tids.size() == loopInfo.dims.size());
+ const LoopInfo &loopInfo = loopStack.back();
+ assert(loopInfo.tids.size() == loopInfo.lvls.size());
SmallVector<Value> red;
if (llvm::isa<scf::WhileOp>(loopInfo.loop)) {
exitCoIterationLoop(rewriter, loc, reduc);
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.h b/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.h
index 8bc5da077c5c..8e6c65fd96c9 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.h
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.h
@@ -12,11 +12,36 @@
#include <vector>
#include "mlir/Dialect/SparseTensor/IR/Enums.h"
+#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
+#include "mlir/Dialect/SparseTensor/Utils/Merger.h"
#include "mlir/IR/PatternMatch.h"
namespace mlir {
namespace sparse_tensor {
+//===----------------------------------------------------------------------===//
+/// The position of a loop in the loop-stack, or the position of a
+/// `LoopId` in a topologically-sorted list of `LoopId`s.
+///
+/// Although this type may have the same cardinality as `LoopId`, it must
+/// not be confused with that type. The `LoopId` type is used by the `Merger`
+/// as a unique identifier for loop-variables, regardless of the ordering
+/// of those loops. Whereas the `LoopOrd` type is used by the `LoopEmitter`
+/// (and `CodegenEnv`) to refer to the actual order in which loops are
+/// generated.
+///
+/// TODO: further explicate the correspondences between these various
+/// types. In particular, since the `$dim` argument to `linalg::IndexOp`
+/// is a De Bruijn index, it seems like that should correspond to `LoopOrd`,
+/// and yet the `Merger` has that correspond with `LoopId` instead.
+/// In addition `LoopEmitter::genAffine` has `AffineDimExpr::position`
+/// correspond to `LoopId`; however, it is unclear what the provenance
+/// of those `AffineDimExpr` is.
+//
+// TODO: use a struct/class rather than a typedef, so that we can actually
+// typecheck this to avoid mixups in the code.
+using LoopOrd = unsigned;
+
//===----------------------------------------------------------------------===//
// The LoopEmitter class manages sparse tensors and helps to
// generate loop structures to (co)-iterate sparse tensors.
@@ -33,13 +58,13 @@ namespace sparse_tensor {
//
// One can use
//
-// SparseTensorLoopEmiter loopEmiter({T1, T1});
+// LoopEmitter loopEmiter({T1, T2});
// loopEmiter.initializeLoopEmit();
-// loopEmiter.enterLoopOverTensorAtDim(T1, 0);
-// loopEmiter.enterLoopOverTensorAtDim(T2, 0);
-// loopEmiter.enterLoopOverTensorAtDim(T1, 1);
+// loopEmiter.enterLoopOverTensorAtLvl(T1, 0);
+// loopEmiter.enterLoopOverTensorAtLvl(T2, 0);
+// loopEmiter.enterLoopOverTensorAtLvl(T1, 1);
// loopEmiter.exitCurrentLoop();
-// loopEmiter.enterLoopOverTensorAtDim(T2, 1);
+// loopEmiter.enterLoopOverTensorAtLvl(T2, 1);
// loopEmiter.exitCurrentLoop(); // exit k
// loopEmiter.exitCurrentLoop(); // exit j
// loopEmiter.exitCurrentLoop(); // exit i
@@ -54,30 +79,31 @@ class LoopEmitter {
LoopEmitter() = default;
- /// Takes an array of tensors inputs, on which the generated loops will
- /// iterate on. The index of the tensor in the array is also the tensor id
- /// (tid) used in related functions. If isSparseOut is set, loop emitter
- /// assume that the sparse output tensor is empty, and will always generate
- /// loops on it based on the dim sizes. An optional array could be provided
- /// (by sparsification) to indicate the loop id sequence that will be
- /// generated. It is used to establish the mapping between affineDimExpr to
- /// the corresponding loop index in the loop stack that are maintained by the
- /// loop emitter.
+ /// Takes an array of input tensors, which the generated loops will
+ /// iterate over. Each tensor is given a `TensorId` (numerically equal
+ /// to the position of that tensor `Value` in the array). Setting
+ /// `isSparseOut` indicates that the sparse output tensor is empty,
+ /// so the loop emitter will generate loops over it according to the
+ /// level-sizes. The `topSort` array specifies the actual order in
+ /// which loops are generated, thus providing a mapping from `LoopOrd`
+ /// to `LoopId`.
void initialize(ValueRange tensors, StringAttr loopTag = nullptr,
bool hasOutput = false, bool isSparseOut = false,
- ArrayRef<unsigned> topSort = {});
+ ArrayRef<LoopId> topSort = {});
explicit LoopEmitter(ValueRange tensors, StringAttr loopTag = nullptr,
bool hasOutput = false, bool isSparseOut = false,
- ArrayRef<unsigned> topSort = {});
+ ArrayRef<LoopId> topSort = {});
- /// Starts a loop emitting session by generating all the buffers needed to
- /// iterate tensors.
+ /// Starts a loop emitting session by generating all the buffers needed
+ /// for iterating over the tensors.
void initializeLoopEmit(OpBuilder &builder, Location loc,
OutputUpdater updater = nullptr);
- /// Generates a list of operations to compute the affine expression.
- Value genAffine(OpBuilder &builder, AffineExpr a, Location loc);
+ /// Generates code to compute an affine expression whose variables are
+ /// `LoopId`s (i.e., `a.cast<AffineDimExpr>().getPosition()` is a valid
+ /// `LoopId`).
+ Value genAffine(OpBuilder &builder, Location loc, AffineExpr a);
/// Enters a new loop sequence; the loops within the same sequence start
/// from the break points of the previous loop instead of starting over from 0.
@@ -93,73 +119,77 @@ class LoopEmitter {
/// ...
/// // loop sequence end.
/// }
- void enterNewLoopSeq(OpBuilder &builder, Location loc, ArrayRef<size_t> tids,
- ArrayRef<size_t> dims);
+ void enterNewLoopSeq(OpBuilder &builder, Location loc,
+ ArrayRef<TensorId> tids, ArrayRef<Level> lvls);
- // exit the current loop sequence, this will reset universal index to 0.
+ /// Exits the current loop sequence; this will reset the universal index to 0.
void exitCurrentLoopSeq() {
assert(loopSeqStack.size() == loopStack.size() + 1);
loopSeqStack.pop_back();
}
- // TODO: Gets rid of `dim` in the argument list? Track the dimension we
- // are currently at internally. Then it would be enterNextDimForTensor.
- // Still need a way to specify the dim for non annoated dense tensor though,
- // as it can be accessed out of order.
- /// Emits loop over tensor_tid_dim, it assumes that loops between
- /// tensor_tid_[0, dim - 1] have already been generated.
+ // TODO: Get rid of `lvls` in the argument list? Track the level we
+ // are currently at internally. Then it would be enterNextLvlForTensor.
+ // Still need a way to specify the lvl for non-annotated tensors though,
+ // as those can be accessed out of order.
+ //
+ /// Emits a loop over tensor_tid_lvl; it assumes that the loops over
+ /// tensor_tid_[0, lvl - 1] have already been generated.
/// The function will also perform in-place update on the `reduc` vector to
/// return the reduction variable used inside the generated loop.
- Operation *enterLoopOverTensorAtDim(OpBuilder &builder, Location loc,
- ArrayRef<size_t> tids,
- ArrayRef<size_t> dims,
+ Operation *enterLoopOverTensorAtLvl(OpBuilder &builder, Location loc,
+ ArrayRef<TensorId> tids,
+ ArrayRef<Level> lvls,
MutableArrayRef<Value> reduc = {},
bool isParallel = false);
- Operation *enterFilterLoopOverTensorAtDim(OpBuilder &builder, Location loc,
- size_t tid, size_t dim,
+ Operation *enterFilterLoopOverTensorAtLvl(OpBuilder &builder, Location loc,
+ TensorId tid, Level lvl,
AffineExpr affine,
MutableArrayRef<Value> reduc = {});
- void genDenseAffineAddressAtCurLevel(OpBuilder &builder, Location loc,
- size_t tid, size_t dim,
- AffineExpr affine);
+ void genDenseAffineAddress(OpBuilder &builder, Location loc, TensorId tid,
+ Level lvl, AffineExpr lvlExpr);
/// Emits a co-iteration loop over a set of tensors.
- Operation *enterCoIterationOverTensorsAtDims(
- OpBuilder &builder, Location loc, ArrayRef<size_t> tids,
- ArrayRef<size_t> dims, bool needsUniv, MutableArrayRef<Value> reduc = {});
+ Operation *enterCoIterationOverTensorsAtLvls(
+ OpBuilder &builder, Location loc, ArrayRef<TensorId> tids,
+ ArrayRef<Level> lvls, bool needsUniv, MutableArrayRef<Value> reduc = {});
void exitCurrentLoop(RewriterBase &rewriter, Location loc,
MutableArrayRef<Value> reduc = {});
- /// Returns the array of coordinate for all the loop generated till now.
- void getCoordinateArray(SmallVectorImpl<Value> &coords) const {
+ /// Fills the out-parameter with the loop induction variables for all
+ /// loops in the current loop-stack. The variables are given in the
+ /// same order as the loop-stack, hence `ivs` should be indexed into
+ /// by `LoopOrd` (not `LoopId`).
+ void getLoopIVs(SmallVectorImpl<Value> &ivs) const {
+ ivs.clear();
+ ivs.reserve(getCurrentDepth());
for (auto &l : loopStack)
- coords.push_back(l.iv);
+ ivs.push_back(l.iv);
}
- /// Gets loop induction variable at the given level.
- unsigned getCurrentDepth() const { return loopStack.size(); }
+ /// Gets the current depth of the loop-stack. The result is given
+ /// the type `LoopOrd` for the same reason as one-past-the-end iterators.
+ LoopOrd getCurrentDepth() const { return loopStack.size(); }
- /// Gets loop induction variable at the given level.
- Value getLoopIV(size_t level) const {
- if (level < loopStack.size())
- return loopStack[level].iv;
- return nullptr;
+ /// Gets loop induction variable for the given `LoopOrd`.
+ Value getLoopIV(LoopOrd n) const {
+ return n < getCurrentDepth() ? loopStack[n].iv : Value();
}
///
/// Getters.
///
- const std::vector<std::vector<Value>> &getPidxs() const { return pidxs; };
- const std::vector<std::vector<Value>> &getCoord() const { return coord; };
+ const std::vector<std::vector<Value>> &getPosits() const { return posits; };
+ const std::vector<std::vector<Value>> &getCoords() const { return coords; };
const std::vector<std::vector<Value>> &getHighs() const { return highs; };
- const std::vector<std::vector<Value>> &getPosBuffer() const {
- return posBuffer;
+ const std::vector<std::vector<Value>> &getPositionBuffers() const {
+ return positionsBuffers;
};
- const std::vector<std::vector<Value>> &getCrdBuffer() const {
- return crdBuffer;
+ const std::vector<std::vector<Value>> &getCoordinateBuffers() const {
+ return coordinatesBuffers;
};
const std::vector<Value> &getValBuffer() const { return valBuffer; };
@@ -168,64 +198,74 @@ class LoopEmitter {
}
private:
- struct LoopLevelInfo {
- LoopLevelInfo(ArrayRef<size_t> tids, ArrayRef<size_t> dims, Operation *loop,
- Block *userBlock, Value iv, StringAttr loopTag)
- : tids(tids), dims(dims), loop(loop), userCodeBlock(userBlock), iv(iv) {
+ struct LoopInfo {
+ LoopInfo(ArrayRef<TensorId> tids, ArrayRef<Level> lvls, Operation *loop,
+ Block *userBlock, Value iv, StringAttr loopTag)
+ : tids(tids), lvls(lvls), loop(loop), userCodeBlock(userBlock), iv(iv) {
// Attach a special tag to the loop generated by the loop emitter.
if (loopTag)
loop->setAttr(LoopEmitter::getLoopEmitterLoopAttrName(), loopTag);
}
- // TODO: maybe use a vector<pair> for tid and dim?
+ // TODO: maybe use a vector<pair> for tid and lvl?
+ // (Better yet, compress them together a la `TensorLoopId`.)
// The set of tensors that the loop is operating on
- const llvm::SmallVector<size_t> tids;
- // The corresponding dims for the tensors
- const llvm::SmallVector<size_t> dims;
+ const llvm::SmallVector<TensorId> tids;
+ // The corresponding levels for the tensors
+ const llvm::SmallVector<Level> lvls;
const Operation *loop; // the loop operation
Block *const userCodeBlock; // the block holding users' generated code.
const Value iv; // the induction variable for the loop
};
- /// Linearizes address for dense dimension (i.e., p = (i * d0) + j).
- Value genAddress(OpBuilder &builder, Location loc, size_t tid, size_t dim,
+ /// Linearizes address for dense level (i.e., p = (i * d0) + j).
+ Value genAddress(OpBuilder &builder, Location loc, TensorId tid, Level lvl,
Value iv);
/// Generates the segment high for a non-unique level (to fast forward
- /// duplicated coordinates).
- Value genSegmentHigh(OpBuilder &builder, Location loc, size_t tid, size_t lvl,
- Value pos, Value pHi);
+ /// duplicated coordinates). That is, it generates the code:
+ ///
+ /// crd = coordinates_tid_lvl[pos]
+ /// while (pos < pHi && coordinates_tid_lvl[pos] == crd)
+ /// pos++;
+ /// <return pos>;
+ Value genSegmentHigh(OpBuilder &builder, Location loc, TensorId tid,
+ Level lvl, Value pos, Value pHi);
/// Generates instructions to compute the coordinate of tensors[tid][lvl]
/// under the current loop context. The final argument is the
/// collapsed-output level, whereas this function handles converting
/// that to the uncollapsed-input level
- Value genSparseCrd(OpBuilder &builder, Location loc, size_t tid,
- size_t dstLvl);
+ Value genSparseCrd(OpBuilder &builder, Location loc, TensorId tid,
+ Level dstLvl);
/// Generates a predicate to determine whether the transformed coordinates are
/// in the given slice.
/// Returns std::pair<Transformed coordinates, Predicate>
std::pair<Value, Value> genSliceLegitPredicate(OpBuilder &builder,
Location loc, Value crd,
- unsigned tid, unsigned lvl);
+ TensorId tid, Level lvl);
+
+ TensorId getNumTensors() const { return tensors.size(); }
- bool isOutputTensor(size_t tid) {
- return hasOutput && tid == tensors.size() - 1;
+ bool isOutputTensor(TensorId tid) const {
+ return hasOutput && tid == static_cast<TensorId>(getNumTensors() - 1);
}
- bool isSparseOutput(size_t tid) { return isOutputTensor(tid) && isSparseOut; }
+ bool isSparseOutput(TensorId tid) const {
+ return isOutputTensor(tid) && isSparseOut;
+ }
- /// Setups [lo, hi] for iterating tensor[dim], it assumes that tensor[0
- /// ...dims-1] has already been setup.
- void prepareLoopOverTensorAtDim(OpBuilder &builder, Location loc, size_t tid,
- size_t dim);
+ /// Prepares loop for iterating over `tensor[lvl]`, under the assumption
+ /// that `tensor[0...lvl-1]` loops have already been set up.
+ void prepareLoopOverTensorAtLvl(OpBuilder &builder, Location loc,
+ TensorId tid, Level lvl);
/// Emits extra locals, since the locals might not be in the simplified
- /// point used to generate the loops, but are still required to generates
+ /// lattice points used to generate the loops, but are still required to generate
/// expressions.
- void emitExtraLocalsForTensorsAtDenseDims(OpBuilder &builder, Location loc,
- ArrayRef<size_t> tids,
- ArrayRef<size_t> dims);
+ void emitExtraLocalsForTensorsAtDenseLvls(OpBuilder &builder, Location loc,
+ ArrayRef<TensorId> tids,
+ ArrayRef<Level> lvls);
/// Exits a for loop and returns the reduction results, e.g.,
/// For sequential for loops:
@@ -258,6 +298,38 @@ class LoopEmitter {
void exitCoIterationLoop(OpBuilder &builder, Location loc,
MutableArrayRef<Value> reduc);
+ //
+ // View-based-reshape methods.
+ //
+
+ /// Get the collapse reassociation for `tensors[tid][dstLvl]`.
+ /// For unreshaped operands, the reassociation is simply an identity
+ /// transformation.
+ ///
+ /// NOTE: the result uses `Level` rather than the `int64_t` of
+ /// `ReassociationIndices`, since the former gives clarity to what
+ /// the values actually mean.
+ ///
+ /// TODO: why not do this computation when we first store the reassoc,
+ /// instead of doing it every time we look it up?
+ SmallVector<Level, 2> getCollapseReassociation(TensorId tid, Level dstLvl) {
+ assert(tid < getNumTensors() && "Invalid TensorId");
+ assert(collapseReassoc.size() == getNumTensors());
+ if (const auto reassoc = collapseReassoc[tid]) {
+ // TODO: store the dstLvlRank in the LoopEmitter so that we can
+ // check `dstLvl < dstLvlRank` at the top; and only here need to
+ // assert that `reassoc.size() == dstLvlRank`.
+ assert(dstLvl < reassoc.size() && "Level is out-of-bounds");
+ const auto srcLvls = reassoc[dstLvl].cast<ArrayAttr>();
+ return llvm::to_vector<2>(
+ llvm::map_range(srcLvls, [&](Attribute srcLvl) -> Level {
+ // TODO: replace this with the converter for `LevelAttr`.
+ return srcLvl.cast<IntegerAttr>().getValue().getZExtValue();
+ }));
+ }
+ return {dstLvl};
+ }
+
/// An optional string attribute that should be attached to the loop
/// generated by the loop emitter; it might help subsequent passes to identify
/// loops that operate on sparse tensors more easily.
@@ -266,22 +338,41 @@ class LoopEmitter {
/// tensor.
bool hasOutput;
bool isSparseOut;
+
+ //
+ // Fields which have `numTensor` many entries.
+ //
+ // TODO: switch to an AOS style to avoid any possible mismatches.
+ //
+
/// Input and (optional) output tensors.
std::vector<Value> tensors;
- /// The dim type array for each tensor.
- std::vector<std::vector<DimLevelType>> dimTypes;
- /// Sparse iteration information (by tensor and dim). These arrays
- /// are updated to remain current within the current loop.
- // TODO: we may want to rename "pidx(s)" to `posCursor(s)` or similar.
- std::vector<std::vector<Value>> pidxs;
+ /// Level-types for each `(TensorId, Level)` pair.
+ std::vector<std::vector<DimLevelType>> lvlTypes;
+ // Sparse iteration information for each `(TensorId, Level)` pair.
+ // These arrays are updated to remain current within the current loop.
+ // TODO: Clarify which of these are indexed by dstLvl vs srcLvl.
+ //
+ /// The collection of positions for a given element (one such collection
+ /// for each tensor). This is the position analogue of the "coords"
+ /// naming convention.
+ ///
+ /// FIXME: [CLARIFY_POSITS_LVL] It's unclear which levels are used
+ /// to index the `posits` array. On the one hand `genSparseCrd`
+ /// uses dstLvl; on the other hand `enterLoopOverTensorAtLvl`,
+ /// `prepareLoopOverTensorAtLvl`, and `enterCoIterationOverTensorsAtLvls`
+ /// use srcLvl. So which is it?
+ std::vector<std::vector<Value>> posits;
+ /// The collection of coordinates for a given element (one such
+ /// collection for each tensor).
+ std::vector<std::vector<Value>> coords;
// The segment upper bound for non-unique levels after de-duplication.
std::vector<std::vector<Value>> segHi;
- std::vector<std::vector<Value>> coord;
std::vector<std::vector<Value>> highs;
std::vector<std::vector<Value>> lvlSizes;
- std::vector<std::vector<Value>> posBuffer; // to_positions
- std::vector<std::vector<Value>> crdBuffer; // to_coordinates
- std::vector<Value> valBuffer; // to_value
+ std::vector<std::vector<Value>> positionsBuffers; // to_positions
+ std::vector<std::vector<Value>> coordinatesBuffers; // to_coordinates
+ std::vector<Value> valBuffer; // to_value
/// Whether the sparse input is a slice.
std::vector<bool> isSparseSlices;
@@ -289,44 +380,30 @@ class LoopEmitter {
std::vector<std::vector<Value>> sliceOffsets;
std::vector<std::vector<Value>> sliceStrides;
+ /// Collapse reassociations related to a specific tensor.
+ // TODO: support expand.
+ std::vector<ArrayAttr> collapseReassoc;
+
+ /// TODO: not yet used; it should track the current level for each tensor
+ /// to help eliminate the `lvls` parameters from the above APIs.
+ /// std::vector<Level> curLvl;
+
+ //
+ // Fields which have at most `numLoops` many entries.
+ //
+
/// Loop Stack, stores the information of all the nested loops that are
/// alive.
- std::vector<LoopLevelInfo> loopStack;
+ std::vector<LoopInfo> loopStack;
- /// Loop Sequence Stack, stores the unversial index for the current loop
+ /// Loop Sequence Stack, stores the universal index for the current loop
/// sequence.
std::vector<Value> loopSeqStack;
- /// Maps AffineDimExpr to the index of the loop in loopStack.
+ /// Maps `LoopId` (used by `AffineDimExpr`) to `LoopOrd` (in the `loopStack`).
/// TODO: We should probably use a callback function here to make it more
/// general.
- std::vector<unsigned> sparsiferLoopLvlMap;
-
- //
- // View based reshape related-fields and methods
- //
-
- /// Collapse Reassociations related to a specific tensor
- // TODO: support expand.
- std::vector<ArrayAttr> collapseReassoc;
-
- /// Get the collapse reassociation for tensors[tid] on l. For unreshaped
- /// operands, the reassociation is simply an identity transformation.
- SmallVector<int64_t, 2> getCollapseReassociation(unsigned tid, unsigned l) {
- // Returns for SmallVector<int64_t, 2> just like `ReassociaionIndices`
- if (auto reass = collapseReassoc[tid]) {
- auto attr = reass[l];
- return llvm::to_vector<2>(
- llvm::map_range(attr.cast<ArrayAttr>(), [&](Attribute indexAttr) {
- return indexAttr.cast<IntegerAttr>().getInt();
- }));
- }
- return {l};
- }
-
- /// TODO: not yet used, it should track the current level for each tensor
- /// to help eliminate `dim` paramters from above APIs.
- /// std::vector<size_t> curLv;
+ std::vector<LoopOrd> loopIdToOrd;
};
} // namespace sparse_tensor
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
index d5e604d05c00..93a53ec877da 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
@@ -938,20 +938,21 @@ struct ForeachRewriter : public OpRewritePattern<ForeachOp> {
ValueRange{input},
StringAttr::get(getContext(), ForeachOp::getOperationName()));
loopEmitter.initializeLoopEmit(rewriter, loc);
- for (Dimension d = 0; d < dimRank; d++) {
+ for (Level l = 0; l < lvlRank; l++) {
// TODO: provide a utility function for loop sequences that only contain
// one for loop?
- const Level l = op.getOrder() ? op.getOrder()->getDimPosition(d) : d;
- loopEmitter.enterNewLoopSeq(rewriter, loc, 0, static_cast<size_t>(l));
+ // FIXME(wrengr): what is this "ld" supposed to be really?
+ const Level ld = op.getOrder() ? op.getOrder()->getDimPosition(l) : l;
+ loopEmitter.enterNewLoopSeq(rewriter, loc, 0, ld);
// Note that reduc will be taken care of by loop emitter and get updated
// in place.
- loopEmitter.enterLoopOverTensorAtDim(rewriter, loc, 0, l, reduc);
+ loopEmitter.enterLoopOverTensorAtLvl(rewriter, loc, 0, l, reduc);
}
SmallVector<Value> lcvs;
lcvs.reserve(lvlRank);
- loopEmitter.getCoordinateArray(lcvs);
+ loopEmitter.getLoopIVs(lcvs);
if (op.getOrder()) {
// FIXME: There is some dim/lvl confusion here since `dimRank != lvlRank`
@@ -962,10 +963,10 @@ struct ForeachRewriter : public OpRewritePattern<ForeachOp> {
}
}
Value vals = loopEmitter.getValBuffer()[0];
- Value pidx = loopEmitter.getPidxs()[0].back();
+ Value pos = loopEmitter.getPosits()[0].back();
// Loads the value from sparse tensor using position-index;
// loads the value from dense tensor using coords.
- Value val = enc ? rewriter.create<memref::LoadOp>(loc, vals, pidx)
+ Value val = enc ? rewriter.create<memref::LoadOp>(loc, vals, pos)
: rewriter.create<memref::LoadOp>(loc, vals, lcvs);
// 2. Inline the block in the foreach operator.
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
index 29900e55b2bb..9fedd5a78658 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
@@ -99,7 +99,9 @@ class AffineDimFinder : public AffineExprVisitor<AffineDimFinder> {
AffineDimExpr getDimExpr() const { return pickedDim.cast<AffineDimExpr>(); }
private:
- /// The picked AffineDimExpr after visit.
+ /// The picked AffineDimExpr after visit. This must be stored as
+ /// `AffineExpr` rather than `AffineDimExpr`, because the latter
+ /// doesn't have a default ctor.
AffineExpr pickedDim;
/// The iterator type that we want.
utils::IteratorType pickIterType;
@@ -113,20 +115,25 @@ class AffineDimFinder : public AffineExprVisitor<AffineDimFinder> {
// Sparse compiler analysis methods.
//===----------------------------------------------------------------------===//
+// TODO: the "idx"-vs-"ldx" naming convention is not self-explanatory,
+// and those letters are too easy to confuse visually. We should switch
+// to a more self-explanatory naming convention like "curLoop"-vs-"prevLoop"
+// (assuming that's the actual meaning behind the "idx"-vs-"ldx" convention).
+
/// Determines if affine expression is invariant.
-static bool isInvariantAffine(AffineExpr a, ArrayRef<unsigned> loopStack,
- unsigned ldx, bool &atLevel) {
+static bool isInvariantAffine(AffineExpr a, ArrayRef<LoopId> loopStack,
+ LoopId ldx, bool &isAtLoop) {
switch (a.getKind()) {
case AffineExprKind::DimId: {
- unsigned idx = a.cast<AffineDimExpr>().getPosition();
- if (idx == ldx) {
- atLevel = true;
- // Must be invariant if we are at the level.
+ const LoopId i = a.cast<AffineDimExpr>().getPosition();
+ if (i == ldx) {
+ isAtLoop = true;
+ // Must be invariant if we are at the given loop.
return true;
}
bool isInvariant = false;
- for (unsigned loop : loopStack) {
- isInvariant = (loop == idx);
+ for (LoopId l : loopStack) {
+ isInvariant = (l == i);
if (isInvariant)
break;
}
@@ -135,8 +142,8 @@ static bool isInvariantAffine(AffineExpr a, ArrayRef<unsigned> loopStack,
case AffineExprKind::Add:
case AffineExprKind::Mul: {
auto binOp = a.cast<AffineBinaryOpExpr>();
- return isInvariantAffine(binOp.getLHS(), loopStack, ldx, atLevel) &&
- isInvariantAffine(binOp.getRHS(), loopStack, ldx, atLevel);
+ return isInvariantAffine(binOp.getLHS(), loopStack, ldx, isAtLoop) &&
+ isInvariantAffine(binOp.getRHS(), loopStack, ldx, isAtLoop);
}
default: {
assert(a.isa<AffineConstantExpr>());
@@ -146,34 +153,42 @@ static bool isInvariantAffine(AffineExpr a, ArrayRef<unsigned> loopStack,
}
/// Determines if affine expression is invariant.
-static bool isInvariantAffine(CodegenEnv &env, AffineExpr a, unsigned ldx,
- bool &atLevel) {
- return isInvariantAffine(a, env.getLoopCurStack(), ldx, atLevel);
+static bool isInvariantAffine(CodegenEnv &env, AffineExpr a, LoopId ldx,
+ bool &isAtLoop) {
+ return isInvariantAffine(a, env.getCurrentLoopStack(), ldx, isAtLoop);
}
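For reference, the invariance test above boils down to a simple property once the affine expression is reduced to the set of LoopIds it mentions: every mentioned loop must already be on the loop stack, and `isAtLoop` records whether the distinguished loop `ldx` itself appears. A simplified standalone sketch, with the expression modelled as a plain list of loop ids rather than an `AffineExpr` (the name `isInvariantModel` is hypothetical):

    #include <algorithm>
    #include <vector>

    using LoopId = unsigned;

    bool isInvariantModel(const std::vector<LoopId> &mentionedLoops,
                          const std::vector<LoopId> &loopStack, LoopId ldx,
                          bool &isAtLoop) {
      for (LoopId i : mentionedLoops) {
        if (i == ldx) {
          isAtLoop = true; // invariant by definition at the given loop
          continue;
        }
        if (std::find(loopStack.begin(), loopStack.end(), i) == loopStack.end())
          return false;    // mentions a loop that has not been entered yet
      }
      return true;
    }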
/// Helper method to construct a permuted dimension ordering
/// that adheres to the given topological sort.
+//
+// FIXME: does the above actually mean "dimensions", or should it say
+// "level ordering"? The same dim/lvl confusion applies to all the code
+// and comments in the definition below.
static AffineMap permute(CodegenEnv &env, AffineMap m) {
assert(m.getNumDims() + env.merger().getNumFilterLoops() ==
env.topSortSize() &&
"size mismatch");
// Construct the inverse of `m`; to avoid the asymptotic complexity
// of calling `m.getPermutedPosition` repeatedly.
+ //
+ // The variable `perm` must use `unsigned` rather than `Dimension`/`Level`,
+ // because that's what `AffineMap::getPermutationMap` requires.
+ // TODO: however, `perm` should be renamed to make clear what exactly
+ // it's storing a permutation of.
SmallVector<unsigned> perm;
- unsigned numResults = m.getNumResults();
+ const unsigned numResults = m.getNumResults();
BitVector worklist(numResults, true);
- unsigned loopDepth = 1;
+ LoopOrd loopDepth = 1;
// Construct the permutation.
while (worklist.any() && loopDepth <= env.topSortSize()) {
- unsigned preSize = perm.size();
- for (auto dim : worklist.set_bits()) {
- bool atLevel = false;
+ const unsigned preSize = perm.size();
+ for (unsigned dim : worklist.set_bits()) {
+ bool isAtLoop = false;
if (m.getResult(dim).isa<AffineConstantExpr>() ||
- (isInvariantAffine(m.getResult(dim),
- env.getTopSortSlice(0, loopDepth),
- env.topSortAt(loopDepth - 1), atLevel) &&
- atLevel)) {
+ (isInvariantAffine(m.getResult(dim), env.getLoopStackUpTo(loopDepth),
+ env.topSortAt(loopDepth - 1), isAtLoop) &&
+ isAtLoop)) {
// If the matching affine is constant expression or just become
// invariant. We can visit the dimension now without breaking the
// topSort constraint.
@@ -185,8 +200,8 @@ static AffineMap permute(CodegenEnv &env, AffineMap m) {
for (unsigned i = preSize, e = perm.size(); i < e; i++)
worklist.reset(perm[i]);
- // Tries to entering the next loop level.
- loopDepth += 1;
+ // Try entering the next loop in the stack.
+ loopDepth++;
}
assert(perm.size() == numResults);
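A simplified model of the permutation construction above: each result of the map is emitted as soon as every loop it depends on appears in the prefix of the sorted loop order. The sketch below is illustrative only (hypothetical `buildPermutation`/`resultDeps`; the real code queries the `AffineMap` and `isInvariantAffine` instead of precomputed dependence sets):

    #include <set>
    #include <vector>

    std::vector<unsigned>
    buildPermutation(const std::vector<std::set<unsigned>> &resultDeps,
                     const std::vector<unsigned> &loopOrder) {
      const unsigned numResults = static_cast<unsigned>(resultDeps.size());
      std::vector<bool> placed(numResults, false);
      std::vector<unsigned> perm;
      std::set<unsigned> seenLoops;
      for (unsigned loop : loopOrder) {
        seenLoops.insert(loop);
        for (unsigned r = 0; r < numResults; ++r) {
          if (placed[r])
            continue;
          bool ready = true;
          for (unsigned dep : resultDeps[r])
            if (!seenLoops.count(dep)) {
              ready = false;
              break;
            }
          if (ready) {
            perm.push_back(r); // result r is now invariant; emit it
            placed[r] = true;
          }
        }
      }
      return perm; // has numResults entries iff every result became ready
    }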
@@ -199,26 +214,26 @@ static AffineMap permute(CodegenEnv &env, AffineMap m) {
/// filterIdx stores the current filter loop idx should be used for the next
/// compound affine sparse level, and it will be incremented by one when
/// used.
-static bool findAffine(Merger &merger, unsigned tensor, unsigned dim,
- AffineExpr a, DimLevelType dlt, unsigned &filterLdx,
+static bool findAffine(Merger &merger, TensorId tid, Level lvl, AffineExpr a,
+ DimLevelType dlt, LoopId &filterLdx,
bool setLvlFormat = true) {
switch (a.getKind()) {
case AffineExprKind::DimId: {
- unsigned idx = a.cast<AffineDimExpr>().getPosition();
- if (!isUndefDLT(merger.getDimLevelType(tensor, idx)))
+ const LoopId idx = a.cast<AffineDimExpr>().getPosition();
+ if (!isUndefDLT(merger.getDimLevelType(tid, idx)))
return false; // used more than once
if (setLvlFormat)
- merger.setDimAndDimLevelType(tensor, idx, dim, dlt);
+ merger.setLevelAndType(tid, idx, lvl, dlt);
return true;
}
case AffineExprKind::Add:
case AffineExprKind::Mul:
case AffineExprKind::Constant: {
if (!isDenseDLT(dlt) && setLvlFormat) {
- assert(isUndefDLT(merger.getDimLevelType(tensor, filterLdx)));
+ assert(isUndefDLT(merger.getDimLevelType(tid, filterLdx)));
// Use a filter loop for sparse affine expression.
- merger.setDimAndDimLevelType(tensor, filterLdx++, dim, dlt);
+ merger.setLevelAndType(tid, filterLdx++, lvl, dlt);
}
if (auto binOp = a.dyn_cast<AffineBinaryOpExpr>()) {
@@ -226,9 +241,9 @@ static bool findAffine(Merger &merger, unsigned tensor, unsigned dim,
// either loop index at d0 or d1.
// We continue the recursion merely to check whether current affine is
// admissible or not.
- return findAffine(merger, tensor, dim, binOp.getLHS(), dlt, filterLdx,
+ return findAffine(merger, tid, lvl, binOp.getLHS(), dlt, filterLdx,
false) &&
- findAffine(merger, tensor, dim, binOp.getRHS(), dlt, filterLdx,
+ findAffine(merger, tid, lvl, binOp.getRHS(), dlt, filterLdx,
false);
}
// Falls through when it is a constant Affine
@@ -239,40 +254,61 @@ static bool findAffine(Merger &merger, unsigned tensor, unsigned dim,
}
}
-/// Get the total number of compound affine expressions in affineMap that are
-/// attached to the given tensor. For the following inputs:
+/// Get the total number of compound affine expressions in the
+/// `getMatchingIndexingMap` for the given tensor. For the following inputs:
///
-/// affineMap = (d0, d1, d2) => (d0 + d1, d2)
-/// tensor = ["compressed", "compressed"]
+/// map = (d0, d1, d2) => (d0 + d1, d2)
+/// lvlTypes = ["compressed", "compressed"]
///
/// Returns 1 (because the first level is compressed and its corresponding
-/// affineMap is d0 + d1)
-static unsigned getNumCompoundAffineOnSparseDims(AffineMap affineMap,
- Value tensor) {
+/// indexing-expression is `d0 + d1`)
+static unsigned getNumCompoundAffineOnSparseLvls(AffineMap map, Value tensor) {
+ // The `tensor` is not guaranteed to have `RankedTensorType`, therefore
+ // we can't use `getRankedTensorType`/`getSparseTensorType` here.
+ // However, we don't need to handle `StorageSpecifierType`, so we
+ // can use `SparseTensorType` once we guard against non-tensors.
+ const auto rtp = tensor.getType().dyn_cast<RankedTensorType>();
+ if (!rtp)
+ return 0;
+ const SparseTensorType stt(rtp);
+
+ // FIXME: There's some dim/lvl confusion here. The previous version of
+ // the code asserted that there are `lvlRank`-many expressions, but then
+ // the `exprs[d]` expression assumes there are in fact `dimRank`-many
+ // expressions. Even though `ArrayRef::operator[]` will check for OOB,
+ // the mismatch between the assertion and the usage belies that this code
+ // cannot support non-permutations.
+ //
+ // Elsewhere in this file the maps returned by
+ // `linalg::GenericOp::getMatchingIndexingMap` are inconsistent about
+ // whether they're expected to have `lvlRank`-many or `dimRank`-many
+ // expressions (cf., `genSubscript` vs `findSparseAnnotations`);
+ // so those are no help in determining which is actually intended.
+ //
+ // For now we work around this problem by asserting the two ranks agree.
+ const Dimension dimRank = stt.getDimRank();
+ const Level lvlRank = stt.getLvlRank();
+ assert(dimRank == lvlRank && "Non-permutations not currently supported");
+ const auto exprs = map.getResults();
+ assert(static_cast<Dimension>(exprs.size()) == dimRank &&
+ "AffineMap does not have dimension-rank many results");
+ (void)dimRank;
unsigned num = 0;
- const auto enc = getSparseTensorEncoding(tensor.getType());
- if (enc) {
- const ArrayRef<AffineExpr> exps = affineMap.getResults();
- const Level lvlRank = enc.getLvlRank();
- assert(static_cast<Level>(exps.size()) == lvlRank);
- for (Level l = 0; l < lvlRank; l++) {
- // FIXME: `toOrigDim` is deprecated.
- const Dimension d = toOrigDim(enc, l);
- // FIXME: there's some dim/lvl confusion here; since `d` isn't
- // guaranteed to be in bounds (for non-permutations).
- if (!exps[d].isa<AffineDimExpr>() && !enc.isDenseLvl(l))
- num++;
- }
+ for (Level l = 0; l < lvlRank; l++) {
+ // FIXME: `toOrigDim` is deprecated.
+ const Dimension d = toOrigDim(stt.getEncoding(), l);
+ if (!exprs[d].isa<AffineDimExpr>() && !stt.isDenseLvl(l))
+ num++;
}
return num;
}
-/// Get the total number of compound affine expressions attached on a sparse
-/// level in the given GenericOp.
-static unsigned getNumCompoundAffineOnSparseDims(linalg::GenericOp op) {
+/// Get the total number of sparse levels with compound affine
+/// expressions, summed over all operands of the `GenericOp`.
+static unsigned getNumCompoundAffineOnSparseLvls(linalg::GenericOp op) {
unsigned num = 0;
for (OpOperand &t : op->getOpOperands())
- num += getNumCompoundAffineOnSparseDims(op.getMatchingIndexingMap(&t),
+ num += getNumCompoundAffineOnSparseLvls(op.getMatchingIndexingMap(&t),
t.get());
return num;
}
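The counting rule implemented above can be stated compactly: a level contributes iff its indexing expression is compound (not a plain `AffineDimExpr`) and the level is not dense. A standalone sketch under that assumption (the names `LvlInfo` and `countCompoundSparseLvls` are hypothetical):

    #include <vector>

    struct LvlInfo {
      bool exprIsPlainDim; // e.g. `d2`, as opposed to `d0 + d1`
      bool isDenseLvl;
    };

    unsigned countCompoundSparseLvls(const std::vector<LvlInfo> &lvls) {
      unsigned num = 0;
      for (const LvlInfo &l : lvls)
        if (!l.exprIsPlainDim && !l.isDenseLvl)
          ++num;
      return num;
    }

For the documented example, map = (d0, d1, d2) => (d0 + d1, d2) with lvlTypes ["compressed", "compressed"], only the first level is both compound and non-dense, so the count is 1.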
@@ -281,7 +317,7 @@ static bool hasCompoundAffineOnSparseOut(linalg::GenericOp op) {
OpOperand *out = op.getDpsInitOperand(0);
if (getSparseTensorType(out->get()).isAllDense())
return false;
- return getNumCompoundAffineOnSparseDims(op.getMatchingIndexingMap(out),
+ return getNumCompoundAffineOnSparseLvls(op.getMatchingIndexingMap(out),
out->get());
}
@@ -292,7 +328,8 @@ static bool hasCompoundAffineOnSparseOut(linalg::GenericOp op) {
/// no annotations are found or inadmissible constructs occur.
static bool findSparseAnnotations(CodegenEnv &env) {
bool annotated = false;
- unsigned filterLdx = env.merger().getFilterLoopStartingIdx();
+ // `filterLdx` may be mutated by `findAffine`.
+ LoopId filterLdx = env.merger().getStartingFilterLoopId();
for (OpOperand &t : env.op()->getOpOperands()) {
const auto map = env.op().getMatchingIndexingMap(&t);
const auto enc = getSparseTensorEncoding(t.get().getType());
@@ -302,10 +339,12 @@ static bool findSparseAnnotations(CodegenEnv &env) {
assert(!enc || lvlRank == enc.getLvlRank());
assert(static_cast<Level>(env.op().getRank(&t)) == lvlRank);
for (Level l = 0; l < lvlRank; l++) {
- const unsigned tensor = t.getOperandNumber();
+ const TensorId tid = t.getOperandNumber();
// FIXME: `toOrigDim` is deprecated.
+ // FIXME: above we asserted that there are `lvlRank` many results,
+ // but this is assuming there are in fact `dimRank` many results instead.
const AffineExpr a = map.getResult(toOrigDim(enc, l));
- if (!findAffine(env.merger(), tensor, l, a, enc.getLvlType(l), filterLdx))
+ if (!findAffine(env.merger(), tid, l, a, enc.getLvlType(l), filterLdx))
return false; // inadmissible affine expression
}
}
@@ -317,14 +356,18 @@ static bool findSparseAnnotations(CodegenEnv &env) {
/// as we use adj matrix for the graph.
/// The sorted result will put the first Reduction iterator to the
/// latest possible index.
-static bool topSortOptimal(CodegenEnv &env, unsigned n,
+/// FIXME(wrengr): correct the above "index"
+///
+/// The `inDegree` is indexed by `LoopId`, and the `adjM` is indexed by
+/// `(LoopId,LoopId)`.
+static bool topSortOptimal(CodegenEnv &env, LoopId n,
ArrayRef<utils::IteratorType> iteratorTypes,
std::vector<unsigned> &inDegree,
std::vector<std::vector<bool>> &adjM) {
- std::vector<unsigned> redIt; // reduce iterator with 0 degree
- std::vector<unsigned> parIt; // parallel iterator with 0 degree
- std::vector<unsigned> filterIt; // filter loop with 0 degree
- for (unsigned i = 0; i < n; i++) {
+ std::vector<LoopId> redIt; // reduce iterator with 0 degree
+ std::vector<LoopId> parIt; // parallel iterator with 0 degree
+ std::vector<LoopId> filterIt; // filter loop with 0 degree
+ for (LoopId i = 0; i < n; i++) {
if (inDegree[i] == 0) {
if (env.merger().isFilterLoop(i))
filterIt.push_back(i);
@@ -360,7 +403,7 @@ static bool topSortOptimal(CodegenEnv &env, unsigned n,
env.topSortPushBack(src);
it.pop_back();
// Update in-degree, and push 0-degree node into worklist.
- for (unsigned dst = 0; dst < n; dst++) {
+ for (LoopId dst = 0; dst < n; dst++) {
if (adjM[src][dst] && --inDegree[dst] == 0) {
if (env.merger().isFilterLoop(dst))
filterIt.push_back(dst);
@@ -381,14 +424,17 @@ static bool topSortOptimal(CodegenEnv &env, unsigned n,
/// b = (i0 + i1) < fidx => i0 < fidx, i1 < fidx.
/// The affine expression `b` is empty iff `tidx` have a value, leading to
/// tidx < a = (i0 + i1) => tidx < i0, tidx < i1.
+///
+/// The `inDegree` is indexed by `LoopId`, and the `adjM` is indexed by
+/// `(LoopId,LoopId)`.
static void addAffineOrderings(std::vector<std::vector<bool>> &adjM,
std::vector<unsigned> &inDegree, AffineExpr a,
- AffineExpr b, std::optional<unsigned> fidx,
- std::optional<unsigned> tidx) {
+ AffineExpr b, std::optional<LoopId> fidx,
+ std::optional<LoopId> tidx) {
if (!a && !b) {
// Recursion leaf.
assert(fidx && tidx);
- unsigned f = *fidx, t = *tidx;
+ const LoopId f = *fidx, t = *tidx;
if (!adjM[f][t]) {
adjM[f][t] = true;
inDegree[t]++;
@@ -396,10 +442,10 @@ static void addAffineOrderings(std::vector<std::vector<bool>> &adjM,
return;
}
// Picks an affine expression and expand (recurse into) it.
- auto toExpand = a ? a : b;
+ const auto toExpand = a ? a : b;
switch (toExpand.getKind()) {
case AffineExprKind::DimId: {
- auto idx = toExpand.cast<AffineDimExpr>().getPosition();
+ std::optional<LoopId> idx = toExpand.cast<AffineDimExpr>().getPosition();
if (toExpand == a)
addAffineOrderings(adjM, inDegree, AffineExpr(), b, idx, tidx);
else // toExpand == b
@@ -424,9 +470,9 @@ static void addAffineOrderings(std::vector<std::vector<bool>> &adjM,
}
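Taken together, `addAffineOrderings` and `topSortOptimal` build and then linearize the iteration graph: ordering constraints become edges in an adjacency matrix with an in-degree count per loop, and the topological sort drains zero-in-degree loops from three worklists so that filter loops come as early, and reduction loops as late, as possible. A simplified standalone sketch (hypothetical `addEdge`/`addOrderings`/`topoSortPrioritized`; the real `topSortOptimal` applies additional tie-breaking that is omitted here):

    #include <vector>

    using LoopId = unsigned;

    // Record the constraint "loop f must be generated before loop t".
    void addEdge(std::vector<std::vector<bool>> &adjM,
                 std::vector<unsigned> &inDegree, LoopId f, LoopId t) {
      if (!adjM[f][t]) { // avoid counting the same constraint twice
        adjM[f][t] = true;
        ++inDegree[t];
      }
    }

    // Once the recursion bottoms out, every loop mentioned by expression `a`
    // is ordered before every loop mentioned by expression `b`.
    void addOrderings(std::vector<std::vector<bool>> &adjM,
                      std::vector<unsigned> &inDegree,
                      const std::vector<LoopId> &loopsOfA,
                      const std::vector<LoopId> &loopsOfB) {
      for (LoopId f : loopsOfA)
        for (LoopId t : loopsOfB)
          addEdge(adjM, inDegree, f, t);
    }

    // Kahn's algorithm with three worklists: prefer filter loops, then
    // parallel loops, and take reduction loops only as a last resort.
    bool topoSortPrioritized(unsigned n,
                             const std::vector<std::vector<bool>> &adjM,
                             const std::vector<bool> &isFilter,
                             const std::vector<bool> &isParallel,
                             std::vector<LoopId> &order) {
      std::vector<unsigned> inDegree(n, 0);
      for (LoopId i = 0; i < n; ++i)
        for (LoopId j = 0; j < n; ++j)
          if (adjM[i][j])
            ++inDegree[j];
      std::vector<LoopId> filterIt, parIt, redIt;
      auto enqueue = [&](LoopId i) {
        if (isFilter[i])
          filterIt.push_back(i);
        else if (isParallel[i])
          parIt.push_back(i);
        else
          redIt.push_back(i);
      };
      for (LoopId i = 0; i < n; ++i)
        if (inDegree[i] == 0)
          enqueue(i);
      while (!filterIt.empty() || !parIt.empty() || !redIt.empty()) {
        auto &it = !filterIt.empty() ? filterIt : !parIt.empty() ? parIt : redIt;
        const LoopId src = it.back();
        it.pop_back();
        order.push_back(src);
        for (LoopId dst = 0; dst < n; ++dst)
          if (adjM[src][dst] && --inDegree[dst] == 0)
            enqueue(dst);
      }
      return order.size() == n; // false iff the iteration graph has a cycle
    }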
static void tryLoosenAffineDenseConstraints(linalg::GenericOp op,
- std::optional<unsigned> &fldx,
+ std::optional<LoopId> &fldx,
AffineExpr &fa,
- std::optional<unsigned> &tldx,
+ std::optional<LoopId> &tldx,
AffineExpr &ta) {
// We use a heuristic here to only pick one dim expression from each
// compound affine expression to establish the order between two dense
@@ -467,7 +513,7 @@ static bool computeIterationGraph(CodegenEnv &env, SortMask mask,
OpOperand *skip = nullptr) {
// Set up an n x n from/to adjacency matrix of the iteration graph
// for the implicit loop indices i_0 .. i_n-1.
- const unsigned n = env.merger().getNumLoops();
+ const LoopId n = env.merger().getNumLoops();
std::vector<std::vector<bool>> adjM(n, std::vector<bool>(n, false));
std::vector<unsigned> inDegree(n, 0); // in-degree of each node.
const auto iteratorTypes = env.op().getIteratorTypesArray();
@@ -476,7 +522,7 @@ static bool computeIterationGraph(CodegenEnv &env, SortMask mask,
// Get map and encoding.
const auto map = env.op().getMatchingIndexingMap(&t);
const auto enc = getSparseTensorEncoding(t.get().getType());
- assert(map.getNumDims() + getNumCompoundAffineOnSparseDims(env.op()) == n);
+ assert(map.getNumDims() + getNumCompoundAffineOnSparseLvls(env.op()) == n);
// Skips dense inputs/outputs when not requested.
const bool isDenseInput = !enc && env.op().isDpsInput(&t);
@@ -489,18 +535,17 @@ static bool computeIterationGraph(CodegenEnv &env, SortMask mask,
// will be skipped more often.
// TODO: Do we really need this?
if (includesUndef(mask)) {
- unsigned tensor = t.getOperandNumber();
- for (unsigned i = 0; i < n; i++) {
- if (isCompressedDLT(env.dlt(tensor, i)) ||
- isSingletonDLT(env.dlt(tensor, i))) {
- for (unsigned j = 0; j < n; j++)
+ const TensorId tensor = t.getOperandNumber();
+ for (LoopId i = 0; i < n; i++) {
+ const auto dltI = env.dlt(tensor, i);
+ if (isCompressedDLT(dltI) || isSingletonDLT(dltI)) {
+ for (LoopId j = 0; j < n; j++)
if (isUndefDLT(env.dlt(tensor, j))) {
adjM[i][j] = true;
inDegree[j]++;
}
} else {
- assert(isDenseDLT(env.dlt(tensor, i)) ||
- isUndefDLT(env.dlt(tensor, i)));
+ assert(isDenseDLT(dltI) || isUndefDLT(dltI));
}
}
}
@@ -513,9 +558,11 @@ static bool computeIterationGraph(CodegenEnv &env, SortMask mask,
assert(!enc || lvlRank == enc.getLvlRank());
for (Level l = 0; l < lvlRank; l++) {
// FIXME: `toOrigDim` is deprecated.
+ // FIXME: above we asserted that there are `lvlRank` many results,
+ // but this is assuming there are in fact `dimRank` many results instead.
AffineExpr ta = map.getResult(toOrigDim(enc, l));
- std::optional<unsigned> tldx =
- env.merger().getLoopIdx(t.getOperandNumber(), l);
+ std::optional<LoopId> tldx =
+ env.merger().getLoopId(t.getOperandNumber(), l);
// Filter loops should be constructed after all the dependent loops,
// i.e., d0 + d1 < filter_loop(d0 + d1)
@@ -537,9 +584,11 @@ static bool computeIterationGraph(CodegenEnv &env, SortMask mask,
if (l > 0) {
// FIXME: `toOrigDim` is deprecated.
+ // FIXME: above we asserted that there are `lvlRank` many results,
+ // but this is assuming there are in fact `dimRank` many results.
AffineExpr fa = map.getResult(toOrigDim(enc, l - 1));
- std::optional<unsigned> fldx =
- env.merger().getLoopIdx(t.getOperandNumber(), l - 1);
+ std::optional<LoopId> fldx =
+ env.merger().getLoopId(t.getOperandNumber(), l - 1);
// Applying order constraints on every pair of dimExpr between two
// compound affine expressions can sometime too strict:
@@ -620,32 +669,37 @@ static Value genIndex(CodegenEnv &env, OpOperand *t) {
const Level lvlRank = stt.getLvlRank();
assert(static_cast<Level>(map.getNumResults()) == lvlRank);
// FIXME: `toOrigDim` is deprecated.
+ // FIXME: above we asserted that there are `lvlRank` many results,
+ // but this is assuming there are in fact `dimRank` many results instead.
AffineExpr a = map.getResult(toOrigDim(stt.getEncoding(), lvlRank - 1));
assert(a.getKind() == AffineExprKind::DimId);
- unsigned idx = a.cast<AffineDimExpr>().getPosition();
- return env.getLoopIdxValue(idx);
+ const LoopId idx = a.cast<AffineDimExpr>().getPosition();
+ return env.getLoopVar(idx);
}
/// Generates subscript for load/store on a dense or sparse tensor.
static Value genSubscript(CodegenEnv &env, OpBuilder &builder, OpOperand *t,
SmallVectorImpl<Value> &args) {
- linalg::GenericOp op = env.op();
- unsigned tensor = t->getOperandNumber();
- auto map = op.getMatchingIndexingMap(t);
+ const Location loc = env.op().getLoc();
+ const TensorId tid = t->getOperandNumber();
+ const auto map = env.op().getMatchingIndexingMap(t);
const auto stt = getSparseTensorType(t->get());
if (stt.hasEncoding()) {
- Value pidx = env.emitter().getPidxs()[tensor].back();
- assert(pidx);
- args.push_back(pidx); // position index
+ // For sparse tensors we only push the last-level's position onto `args`.
+ const auto pos = env.emitter().getPosits()[tid].back();
+ assert(pos);
+ args.push_back(pos);
} else {
+ // For dense tensors we push all level's coordinates onto `args`.
const Level lvlRank = stt.getLvlRank();
assert(static_cast<Level>(map.getNumResults()) == lvlRank);
for (Level l = 0; l < lvlRank; l++) {
- AffineExpr a = map.getResult(l);
- args.push_back(env.emitter().genAffine(builder, a, op.getLoc()));
+ const auto lvlExpr = map.getResult(l);
+ const auto lvlCrd = env.emitter().genAffine(builder, loc, lvlExpr);
+ args.push_back(lvlCrd);
}
}
- return env.emitter().getValBuffer()[tensor];
+ return env.emitter().getValBuffer()[tid];
}
/// Generates insertion code to implement dynamic tensor load.
@@ -688,19 +742,21 @@ static void genInsertionStore(CodegenEnv &env, OpBuilder &builder, OpOperand *t,
Location loc = op.getLoc();
// Direct insertion in lexicographic coordinate order.
if (!env.isExpand()) {
- unsigned rank = op.getRank(t);
- // FIXME: It's not entirely clear what "indices" means here (i.e.,
- // are they "coordinates"? and if so, then are they level-coords or
- // dim-coords?)
- SmallVector<Value> indices;
- for (unsigned i = 0; i < rank; i++) {
- assert(env.emitter().getLoopIV(i));
- indices.push_back(env.emitter().getLoopIV(i));
+ const LoopOrd numLoops = op.getRank(t);
+ // TODO: rewrite this to use `env.emitter().getLoopIVs(ivs)`
+ // instead. We just need to either assert that `numLoops ==
+ // env.emitter().getCurrentDepth()`, or else update the `getLoopIVs`
+ // method to take an optional parameter to restrict to a smaller depth.
+ SmallVector<Value> ivs;
+ ivs.reserve(numLoops);
+ for (LoopOrd n = 0; n < numLoops; n++) {
+ const auto iv = env.emitter().getLoopIV(n);
+ assert(iv);
+ ivs.push_back(iv);
}
Value chain = env.getInsertionChain();
if (!env.getValidLexInsert()) {
- env.updateInsertionChain(
- builder.create<InsertOp>(loc, rhs, chain, indices));
+ env.updateInsertionChain(builder.create<InsertOp>(loc, rhs, chain, ivs));
} else {
// Generates runtime check for a valid lex during reduction,
// to avoid inserting the identity value for empty reductions.
@@ -714,7 +770,7 @@ static void genInsertionStore(CodegenEnv &env, OpBuilder &builder, OpOperand *t,
/*else=*/true);
// True branch.
builder.setInsertionPointToStart(ifValidLexInsert.thenBlock());
- Value res = builder.create<InsertOp>(loc, rhs, chain, indices);
+ Value res = builder.create<InsertOp>(loc, rhs, chain, ivs);
builder.create<scf::YieldOp>(loc, res);
// False branch.
builder.setInsertionPointToStart(ifValidLexInsert.elseBlock());
@@ -761,7 +817,7 @@ static void genInsertionStore(CodegenEnv &env, OpBuilder &builder, OpOperand *t,
}
/// Generates a load on a dense or sparse tensor.
-static Value genTensorLoad(CodegenEnv &env, OpBuilder &builder, unsigned exp) {
+static Value genTensorLoad(CodegenEnv &env, OpBuilder &builder, ExprId exp) {
// Test if the load was hoisted to a higher loop nest.
Value val = env.exp(exp).val;
if (val)
@@ -782,7 +838,7 @@ static Value genTensorLoad(CodegenEnv &env, OpBuilder &builder, unsigned exp) {
}
/// Generates a store on a dense or sparse tensor.
-static void genTensorStore(CodegenEnv &env, OpBuilder &builder, unsigned exp,
+static void genTensorStore(CodegenEnv &env, OpBuilder &builder, ExprId exp,
Value rhs) {
linalg::GenericOp op = env.op();
Location loc = op.getLoc();
@@ -830,7 +886,7 @@ static void genTensorStore(CodegenEnv &env, OpBuilder &builder, unsigned exp,
}
/// Generates an invariant value.
-inline static Value genInvariantValue(CodegenEnv &env, unsigned exp) {
+inline static Value genInvariantValue(CodegenEnv &env, ExprId exp) {
return env.exp(exp).val;
}
@@ -840,10 +896,10 @@ inline static Value genInvariantValue(CodegenEnv &env, unsigned exp) {
/// exception of index computations, which need to be relinked to actual
/// inlined cloned code.
static Value relinkBranch(CodegenEnv &env, RewriterBase &rewriter, Block *block,
- Value e, unsigned ldx) {
+ Value e, LoopId ldx) {
if (Operation *def = e.getDefiningOp()) {
if (auto indexOp = dyn_cast<linalg::IndexOp>(def))
- return env.getLoopIdxValue(indexOp.getDim());
+ return env.getLoopVar(indexOp.getDim());
if (def->getBlock() == block) {
for (unsigned i = 0, n = def->getNumOperands(); i < n; i++) {
rewriter.updateRootInPlace(def, [&]() {
@@ -857,52 +913,52 @@ static Value relinkBranch(CodegenEnv &env, RewriterBase &rewriter, Block *block,
}
/// Recursively generates tensor expression.
-static Value genExp(CodegenEnv &env, RewriterBase &rewriter, unsigned exp,
- unsigned ldx) {
+static Value genExp(CodegenEnv &env, RewriterBase &rewriter, ExprId e,
+ LoopId ldx) {
linalg::GenericOp op = env.op();
Location loc = op.getLoc();
- if (exp == -1u)
+ if (e == kInvalidId)
return Value();
- if (env.exp(exp).kind == Kind::kTensor)
- return genTensorLoad(env, rewriter, exp);
- if (env.exp(exp).kind == Kind::kInvariant)
- return genInvariantValue(env, exp);
- if (env.exp(exp).kind == Kind::kIndex)
- return env.getLoopIdxValue(env.exp(exp).index);
-
- if (env.exp(exp).kind == Kind::kReduce)
- env.startCustomReduc(exp); // enter custom
-
- Value v0 = genExp(env, rewriter, env.exp(exp).children.e0, ldx);
- Value v1 = genExp(env, rewriter, env.exp(exp).children.e1, ldx);
- Value ee = env.merger().buildExp(rewriter, loc, exp, v0, v1);
- if (ee && (env.exp(exp).kind == Kind::kUnary ||
- env.exp(exp).kind == Kind::kBinary ||
- env.exp(exp).kind == Kind::kBinaryBranch ||
- env.exp(exp).kind == Kind::kReduce ||
- env.exp(exp).kind == Kind::kSelect))
+ const TensorExp &exp = env.exp(e);
+ const auto kind = exp.kind;
+ if (kind == Kind::kTensor)
+ return genTensorLoad(env, rewriter, e);
+ if (kind == Kind::kInvariant)
+ return genInvariantValue(env, e);
+ if (kind == Kind::kLoopVar)
+ return env.getLoopVar(exp.loop);
+
+ if (kind == Kind::kReduce)
+ env.startCustomReduc(e); // enter custom
+
+ Value v0 = genExp(env, rewriter, exp.children.e0, ldx);
+ Value v1 = genExp(env, rewriter, exp.children.e1, ldx);
+ Value ee = env.merger().buildExp(rewriter, loc, e, v0, v1);
+ if (ee && (kind == Kind::kUnary || kind == Kind::kBinary ||
+ kind == Kind::kBinaryBranch || kind == Kind::kReduce ||
+ kind == Kind::kSelect))
ee = relinkBranch(env, rewriter, ee.getParentBlock(), ee, ldx);
- if (env.exp(exp).kind == Kind::kReduce)
+ if (kind == Kind::kReduce)
env.endCustomReduc(); // exit custom
- if (env.exp(exp).kind == kSelect) {
- assert(!env.exp(exp).val);
- env.exp(exp).val = v0; // Preserve value for later use.
+ if (kind == kSelect) {
+ assert(!exp.val);
+ env.exp(e).val = v0; // Preserve value for later use.
}
return ee;
}
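The rewritten `genExp` above leans on the interning scheme behind the new `ExprId` typedef: expressions live in a side table, children are referred to by id, and `kInvalidId` terminates the recursion. Below is a minimal standalone model of such a table (hypothetical `ExprTable`/`kNoExpr`, with only a handful of kinds; the real `TensorExp` also carries values, operations, and many more kinds):

    #include <cassert>
    #include <vector>

    using ExprId = unsigned;
    constexpr ExprId kNoExpr = ~0u; // plays the role of the patch's kInvalidId

    enum class ExprKind { kTensor, kInvariant, kLoopVar, kAdd, kMul };

    struct Expr {
      ExprKind kind;
      unsigned x; // tensor id, loop id, or left child id, depending on kind
      ExprId y;   // right child id, or kNoExpr for leaves and unary ops
    };

    class ExprTable {
    public:
      ExprId addLeaf(ExprKind k, unsigned x) {
        exprs.push_back({k, x, kNoExpr});
        return static_cast<ExprId>(exprs.size() - 1);
      }
      ExprId addBinary(ExprKind k, ExprId e0, ExprId e1) {
        assert(e0 != kNoExpr && e1 != kNoExpr);
        exprs.push_back({k, e0, e1});
        return static_cast<ExprId>(exprs.size() - 1);
      }
      const Expr &operator[](ExprId e) const { return exprs[e]; }

    private:
      std::vector<Expr> exprs; // every expression is stored exactly once
    };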
/// Hoists loop invariant tensor loads for which indices have been exhausted.
-static void genInvariants(CodegenEnv &env, OpBuilder &builder, unsigned exp,
- unsigned ldx, bool atStart) {
- if (exp == -1u)
+static void genInvariants(CodegenEnv &env, OpBuilder &builder, ExprId exp,
+ LoopId ldx, bool atStart) {
+ if (exp == kInvalidId)
return;
if (env.exp(exp).kind == Kind::kTensor) {
// Inspect tensor indices.
- bool atLevel = ldx == -1u;
+ bool isAtLoop = ldx == kInvalidId;
linalg::GenericOp op = env.op();
OpOperand &t = op->getOpOperand(env.exp(exp).tensor);
auto map = op.getMatchingIndexingMap(&t);
@@ -911,20 +967,21 @@ static void genInvariants(CodegenEnv &env, OpBuilder &builder, unsigned exp,
assert(static_cast<Level>(map.getNumResults()) == lvlRank);
for (Level l = 0; l < lvlRank; l++) {
// FIXME: `toOrigDim` is deprecated.
+ // FIXME: above we asserted that there are `lvlRank` many results,
+ // but this is assuming there are in fact `dimRank` many results instead.
AffineExpr a = map.getResult(toOrigDim(stt.getEncoding(), l));
- std::optional<unsigned> sldx =
- env.merger().getLoopIdx(t.getOperandNumber(), l);
+ const auto sldx = env.merger().getLoopId(t.getOperandNumber(), l);
if (sldx && env.merger().isFilterLoop(*sldx)) {
- if (!env.getLoopIdxValue(*sldx))
+ if (!env.getLoopVar(*sldx))
// The filter loops has not been constructed.
return;
if (*sldx == ldx)
- atLevel = true;
- } else if (!isInvariantAffine(env, a, ldx, atLevel))
+ isAtLoop = true;
+ } else if (!isInvariantAffine(env, a, ldx, isAtLoop))
return; // still in play
}
- // All exhausted at this level (atLevel denotes exactly at this level).
- if (!atLevel)
+ // All exhausted at this level (isAtLoop denotes exactly at this LoopId).
+ if (!isAtLoop)
return;
OpOperand *lhs = op.getDpsInitOperand(0);
if (lhs == &t) {
@@ -944,14 +1001,14 @@ static void genInvariants(CodegenEnv &env, OpBuilder &builder, unsigned exp,
env.exp(exp).val = atStart ? genTensorLoad(env, builder, exp) : Value();
}
} else if (env.exp(exp).kind != Kind::kInvariant &&
- env.exp(exp).kind != Kind::kIndex) {
+ env.exp(exp).kind != Kind::kLoopVar) {
// Traverse into the binary operations. Note that we only hoist
// tensor loads, since subsequent MLIR/LLVM passes know how to
// deal with all other kinds of derived loop invariants.
if (env.exp(exp).kind == Kind::kReduce)
env.startCustomReduc(exp); // enter custom
- unsigned e0 = env.exp(exp).children.e0;
- unsigned e1 = env.exp(exp).children.e1;
+ const ExprId e0 = env.exp(exp).children.e0;
+ const ExprId e1 = env.exp(exp).children.e1;
genInvariants(env, builder, e0, ldx, atStart);
genInvariants(env, builder, e1, ldx, atStart);
if (env.exp(exp).kind == Kind::kReduce)
@@ -960,7 +1017,7 @@ static void genInvariants(CodegenEnv &env, OpBuilder &builder, unsigned exp,
}
/// Generates an expanded access pattern in innermost dimension.
-static void genExpand(CodegenEnv &env, OpBuilder &builder, unsigned at,
+static void genExpand(CodegenEnv &env, OpBuilder &builder, LoopOrd at,
bool atStart) {
linalg::GenericOp op = env.op();
OpOperand *lhs = op.getDpsInitOperand(0);
@@ -987,7 +1044,7 @@ static void genExpand(CodegenEnv &env, OpBuilder &builder, unsigned at,
r.getResult(3));
} else {
SmallVector<Value> indices;
- for (unsigned i = 0; i < at; i++)
+ for (LoopOrd i = 0; i < at; i++)
indices.push_back(env.emitter().getLoopIV(i));
Value values = env.getExpandValues();
Value filled = env.getExpandFilled();
@@ -1029,34 +1086,35 @@ static bool isParallelFor(CodegenEnv &env, bool isOuter, bool isSparse) {
/// Generates a for-loop on a single index.
static Operation *genFor(CodegenEnv &env, OpBuilder &builder, bool isOuter,
- bool isInner, unsigned idx, ArrayRef<size_t> tids,
- ArrayRef<size_t> dims) {
+ bool isInner, LoopId ldx, ArrayRef<TensorId> tids,
+ ArrayRef<Level> lvls) {
linalg::GenericOp op = env.op();
Location loc = op.getLoc();
auto iteratorTypes = op.getIteratorTypesArray();
- bool isSparse = llvm::any_of(tids, [idx, &env](size_t tid) {
- return isCompressedDLT(env.dlt(tid, idx)) ||
- isSingletonDLT(env.dlt(tid, idx));
+ bool isSparse = llvm::any_of(tids, [ldx, &env](TensorId tid) {
+ const auto dlt = env.dlt(tid, ldx);
+ return isCompressedDLT(dlt) || isSingletonDLT(dlt);
});
bool isParallel = isParallelFor(env, isOuter, isSparse);
Operation *loop = *env.genLoopBoundary([&](MutableArrayRef<Value> reduc) {
- if (env.merger().isFilterLoop(idx)) {
- size_t tid = tids.front(), dim = dims.front();
- // tids/dims must only have one value because filter loops only
+ if (env.merger().isFilterLoop(ldx)) {
+ const TensorId tid = tids.front();
+ const Level lvl = lvls.front();
+ // tids/lvls must only have one value because filter loops only
// corresponding to the one and only sparse tensor level.
- assert(isSparse && tids.size() == 1 && dims.size() == 1);
+ assert(isSparse && tids.size() == 1 && lvls.size() == 1);
OpOperand *t = &op->getOpOperand(tid);
auto enc = getSparseTensorEncoding(t->get().getType());
// Retrieves the affine expression for the filter loop.
// FIXME: `toOrigDim` is deprecated.
AffineExpr a =
- op.getMatchingIndexingMap(t).getResult(toOrigDim(enc, dim));
- return env.emitter().enterFilterLoopOverTensorAtDim(builder, loc, tid,
- dim, a, reduc);
+ op.getMatchingIndexingMap(t).getResult(toOrigDim(enc, lvl));
+ return env.emitter().enterFilterLoopOverTensorAtLvl(builder, loc, tid,
+ lvl, a, reduc);
}
- return env.emitter().enterLoopOverTensorAtDim(builder, loc, tids, dims,
+ return env.emitter().enterLoopOverTensorAtLvl(builder, loc, tids, lvls,
reduc, isParallel);
});
assert(loop);
@@ -1064,14 +1122,14 @@ static Operation *genFor(CodegenEnv &env, OpBuilder &builder, bool isOuter,
}
/// Emit a while-loop for co-iteration over multiple indices.
-static Operation *genWhile(CodegenEnv &env, OpBuilder &builder, unsigned idx,
- bool needsUniv, ArrayRef<size_t> tids,
- ArrayRef<size_t> dims) {
+static Operation *genWhile(CodegenEnv &env, OpBuilder &builder, LoopId idx,
+ bool needsUniv, ArrayRef<TensorId> tids,
+ ArrayRef<Level> lvls) {
Operation *loop = *env.genLoopBoundary([&](MutableArrayRef<Value> reduc) {
// Construct the while-loop with a parameter for each
// index.
- return env.emitter().enterCoIterationOverTensorsAtDims(
- builder, env.op().getLoc(), tids, dims, needsUniv, reduc);
+ return env.emitter().enterCoIterationOverTensorsAtLvls(
+ builder, env.op().getLoc(), tids, lvls, needsUniv, reduc);
});
assert(loop);
return loop;
@@ -1079,21 +1137,21 @@ static Operation *genWhile(CodegenEnv &env, OpBuilder &builder, unsigned idx,
/// Generates a for-loop or a while-loop, depending on whether it implements
/// singleton iteration or co-iteration over the given conjunction.
-static Operation *genLoop(CodegenEnv &env, OpBuilder &builder, unsigned at,
- bool needsUniv, ArrayRef<size_t> tids,
- ArrayRef<size_t> dims, bool isFor) {
- assert(tids.size() == dims.size());
- unsigned idx = env.topSortAt(at);
+static Operation *genLoop(CodegenEnv &env, OpBuilder &builder, LoopOrd at,
+ bool needsUniv, ArrayRef<TensorId> tids,
+ ArrayRef<Level> lvls, bool isFor) {
+ assert(tids.size() == lvls.size());
+ const LoopId idx = env.topSortAt(at);
if (isFor) {
bool isOuter = at == 0;
bool isInner = at == env.topSortSize() - 1;
- return genFor(env, builder, isOuter, isInner, idx, tids, dims);
+ return genFor(env, builder, isOuter, isInner, idx, tids, lvls);
}
- return genWhile(env, builder, idx, needsUniv, tids, dims);
+ return genWhile(env, builder, idx, needsUniv, tids, lvls);
}
/// Generates the induction structure for a while-loop.
-static void finalizeWhileOp(CodegenEnv &env, OpBuilder &builder, unsigned idx,
+static void finalizeWhileOp(CodegenEnv &env, OpBuilder &builder, LoopId idx,
bool needsUniv, BitVector &induction,
scf::WhileOp whileOp) {
Location loc = env.op().getLoc();
@@ -1133,26 +1191,26 @@ static void finalizeWhileOp(CodegenEnv &env, OpBuilder &builder, unsigned idx,
}
/// Generates a single if-statement within a while-loop.
-static scf::IfOp genIf(CodegenEnv &env, OpBuilder &builder, unsigned idx,
+static scf::IfOp genIf(CodegenEnv &env, OpBuilder &builder, LoopId ldx,
BitVector &conditions) {
Location loc = env.op().getLoc();
SmallVector<Type> types;
Value cond;
- for (unsigned b = 0, be = conditions.size(); b < be; b++) {
+ for (TensorLoopId b = 0, be = conditions.size(); b < be; b++) {
if (!conditions[b])
continue;
- unsigned tensor = env.merger().tensor(b);
- assert(idx == env.merger().index(b));
+ const TensorId tid = env.merger().tensor(b);
+ assert(ldx == env.merger().loop(b));
Value clause;
- if (isCompressedDLT(env.dlt(b)) || isSingletonDLT(env.dlt(b))) {
- auto dim = *env.merger().getDimNum(tensor, idx);
- Value op1 = env.emitter().getCoord()[tensor][dim];
- Value op2 = env.getLoopIdxValue(idx);
- clause = builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq, op1,
- op2);
+ const auto dlt = env.dlt(b);
+ if (isCompressedDLT(dlt) || isSingletonDLT(dlt)) {
+ const Level lvl = *env.merger().getLvl(tid, ldx);
+ const Value crd = env.emitter().getCoords()[tid][lvl];
+ const Value lvar = env.getLoopVar(ldx);
+ clause = builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq, crd,
+ lvar);
} else {
- assert(isDenseDLT(env.merger().getDimLevelType(b)) ||
- isUndefDLT(env.merger().getDimLevelType(b)));
+ assert(isDenseDLT(dlt) || isUndefDLT(dlt));
clause = constantI1(builder, loc, true);
}
cond = cond ? builder.create<arith::AndIOp>(loc, cond, clause) : clause;
@@ -1202,41 +1260,40 @@ static void endIf(CodegenEnv &env, OpBuilder &builder, scf::IfOp ifOp,
/// Starts a loop sequence at given level. Returns true if
/// the universal loop index must be maintained at this level.
-static bool startLoopSeq(CodegenEnv &env, OpBuilder &builder, unsigned exp,
- unsigned at, unsigned idx, unsigned ldx,
- unsigned lts) {
- assert(!env.getLoopIdxValue(idx));
+static bool startLoopSeq(CodegenEnv &env, OpBuilder &builder, ExprId exp,
+ LoopOrd at, LoopId idx, LoopId ldx, LatSetId lts) {
+ assert(!env.getLoopVar(idx));
// Emit invariants at this loop sequence level.
genInvariants(env, builder, exp, ldx, /*atStart=*/true);
// Emit access pattern expansion for sparse tensor output.
genExpand(env, builder, at, /*atStart=*/true);
// Emit further intitialization at this loop sequence level.
- unsigned l0 = env.set(lts)[0];
+ const LatPointId l0 = env.set(lts)[0];
bool needsUniv = false;
- SmallVector<size_t> tids;
- SmallVector<size_t> dims;
- env.merger().foreachTidDimPairInBits(
- env.lat(l0).bits, [&](unsigned b, unsigned tid,
- std::optional<unsigned> dim, DimLevelType dlt) {
- assert(env.merger().index(b) == idx);
+ SmallVector<TensorId> tids;
+ SmallVector<Level> lvls;
+ env.merger().foreachTensorLoopId(
+ env.lat(l0).bits, [&](TensorLoopId b, TensorId tid,
+ std::optional<Level> lvl, DimLevelType dlt) {
+ assert(env.merger().loop(b) == idx);
if (isDenseDLT(dlt) || isUndefDLT(dlt)) {
needsUniv = true;
} else {
- // sparse/singleton dim levels.
+ // sparse/singleton levels.
tids.push_back(tid);
- dims.push_back(*dim);
+ lvls.push_back(*lvl);
}
});
- env.emitter().enterNewLoopSeq(builder, env.op().getLoc(), tids, dims);
+ env.emitter().enterNewLoopSeq(builder, env.op().getLoc(), tids, lvls);
// Maintain the universal index only if it is actually
// consumed by a subsequent lattice point.
if (needsUniv) {
unsigned lsize = env.set(lts).size();
for (unsigned i = 1; i < lsize; i++) {
- unsigned li = env.set(lts)[i];
+ const LatPointId li = env.set(lts)[i];
if (!env.merger().hasAnySparse(env.lat(li).simple))
return true;
}
@@ -1245,23 +1302,25 @@ static bool startLoopSeq(CodegenEnv &env, OpBuilder &builder, unsigned exp,
}
static void genConstantDenseAddressFromLevel(CodegenEnv &env,
- OpBuilder &builder, unsigned tid,
- Level lvl) {
+ OpBuilder &builder, TensorId tid,
+ Level startLvl) {
// TODO: Handle affine expression on output tensor.
linalg::GenericOp op = env.op();
assert(tid < op.getNumDpsInputs());
OpOperand *input = op.getDpsInputOperands()[tid];
- ArrayRef<AffineExpr> affines = op.getMatchingIndexingMap(input).getResults();
+ const auto lvlExprs = op.getMatchingIndexingMap(input).getResults();
const auto enc = getSparseTensorEncoding(input->get().getType());
if (enc) {
+ const Location loc = op.getLoc();
+ const TensorId tid = input->getOperandNumber();
const Level lvlRank = enc.getLvlRank();
- assert(affines.size() == static_cast<size_t>(lvlRank));
- for (Level l = lvl; l < lvlRank; l++) {
+ assert(lvlExprs.size() == static_cast<size_t>(lvlRank));
+ // FIXME: there is dim/lvl confusion here
+ for (Level l = startLvl; l < lvlRank; l++) {
// FIXME: `toOrigDim` is deprecated.
- AffineExpr affine = affines[toOrigDim(enc, l)];
- if (enc.isDenseLvl(l) && affine.isa<AffineConstantExpr>())
- env.emitter().genDenseAffineAddressAtCurLevel(
- builder, op.getLoc(), input->getOperandNumber(), l, affine);
+ AffineExpr lvlExpr = lvlExprs[toOrigDim(enc, l)];
+ if (enc.isDenseLvl(l) && lvlExpr.isa<AffineConstantExpr>())
+ env.emitter().genDenseAffineAddress(builder, loc, tid, l, lvlExpr);
else
return; // break on first non-dense non-constant level
}
@@ -1274,45 +1333,46 @@ static void genInitConstantDenseAddress(CodegenEnv &env,
// starting from the first level as they do not depend on any thing.
// E.g., [Dense, Dense, Sparse] -> (1, 2, d0), the addresses for the first two
// levels can be determined before loops.
- for (unsigned tid = 0, e = env.op().getNumDpsInputs(); tid < e; tid++)
+ for (TensorId tid = 0, e = env.op().getNumDpsInputs(); tid < e; tid++)
genConstantDenseAddressFromLevel(env, rewriter, tid, 0);
}
/// Return true if the lattices bit can be iterated by a for loop.
-static bool translateBitsToTidDimPairs(
- CodegenEnv &env, unsigned li, unsigned idx, SmallVectorImpl<size_t> &tids,
- SmallVectorImpl<size_t> &dims, SmallVectorImpl<size_t> &affineTids,
- SmallVectorImpl<size_t> &affineDims, SmallVectorImpl<AffineExpr> &exps) {
+static bool translateBitsToTidLvlPairs(
+ CodegenEnv &env, LatPointId li, LoopId ldx, SmallVectorImpl<TensorId> &tids,
+ SmallVectorImpl<Level> &lvls, SmallVectorImpl<TensorId> &affineTids,
+ SmallVectorImpl<Level> &affineLvls, SmallVectorImpl<AffineExpr> &exps) {
const BitVector &all = env.lat(li).bits;
const BitVector &simple = env.lat(li).simple;
+ const TensorId outTid = env.merger().getOutTensorID();
+ const std::optional<Level> outLvl = env.merger().getLvl(outTid, ldx);
unsigned numloopCond = 0;
bool hasNonUnique = false;
- // Converts bits to array + dim pair
- env.merger().foreachTidDimPairInBits(
- all, [&, idx](unsigned b, unsigned tid, std::optional<unsigned> dim,
+ env.merger().foreachTensorLoopId(
+ all, [&, ldx](TensorLoopId b, TensorId tid, std::optional<Level> lvl,
DimLevelType dlt) {
if (simple.test(b)) {
if (isUndefDLT(dlt)) {
- // An undefined dlt in the lattices, we probably mean to iterate
- // based on the dim of output tensor.
- // E.g., this could be a synthetic tensor (for invariants and sparse
+ // An undefined dlt in the lattices, we probably mean to
+ // iterate based on the level of output tensor. E.g., this
+ // could be a synthetic tensor (for invariants and sparse
// output tensor).
// out[i][j] = invariant; or a broadcast
// out[i][j] = in[i] (j is undef for input)
- tid = env.merger().getOutTensorID();
- dim = env.merger().getDimNum(tid, idx);
- // Skips invalid dim (e.g., when this is a zero ranked tensor).
- if (!dim)
+ tid = outTid;
+ lvl = outLvl;
+ // Skips invalid lvl (e.g., when this is a zero ranked tensor).
+ if (!lvl)
return;
}
hasNonUnique = !isUniqueDLT(dlt) || hasNonUnique;
tids.push_back(tid);
- dims.push_back(*dim);
+ lvls.push_back(*lvl);
numloopCond++;
} else if (isDenseDLT(dlt)) {
tids.push_back(tid);
- dims.push_back(*dim);
+ lvls.push_back(*lvl);
} else {
assert(isUndefDLT(dlt));
linalg::GenericOp op = env.op();
@@ -1332,15 +1392,15 @@ static bool translateBitsToTidDimPairs(
for (Level l = 0; l < lvlRank; l++) {
// FIXME: `toOrigDim` is deprecated.
AffineExpr exp = affines[toOrigDim(stt.getEncoding(), l)];
- // Skip simple affine expression and non dense dimensions (which has
- // it own filter loop).
+ // Skip simple affine expression and non-dense levels (which
+ // have their own filter loop).
if (exp.isa<AffineDimExpr>() || !stt.isDenseLvl(l))
continue;
// Constant affine expression are handled in genLoop
if (!exp.isa<AffineConstantExpr>()) {
- bool atLevel = false;
- if (isInvariantAffine(env, exp, idx, atLevel) && atLevel) {
+ bool isAtLoop = false;
+ if (isInvariantAffine(env, exp, ldx, isAtLoop) && isAtLoop) {
// If the compound affine is invariant and we are right at the
// level. We need to generate the address according to the
// affine expression. This is also the best place we can do it
@@ -1351,7 +1411,7 @@ static bool translateBitsToTidDimPairs(
// might be accepting out-of-order access between consecutive
// dense levels.
affineTids.push_back(tid);
- affineDims.push_back(l);
+ affineLvls.push_back(l);
exps.push_back(exp);
}
}
@@ -1359,13 +1419,12 @@ static bool translateBitsToTidDimPairs(
}
});
- if (isDenseDLT(env.dlt(env.merger().getOutTensorID(), idx))) {
+ if (isDenseDLT(env.dlt(outTid, ldx))) {
// Note that we generate dense indices of the output tensor
// unconditionally, since they may not appear in the lattice, but may be
// needed for linearized env.
- auto dim = *env.merger().getDimNum(env.merger().getOutTensorID(), idx);
- tids.push_back(env.merger().getOutTensorID());
- dims.push_back(dim);
+ tids.push_back(outTid);
+ lvls.push_back(*outLvl);
}
assert(numloopCond > 0);
@@ -1375,33 +1434,33 @@ static bool translateBitsToTidDimPairs(
}
/// Starts a single loop in current sequence.
-static Operation *startLoop(CodegenEnv &env, OpBuilder &builder, unsigned at,
- unsigned li, bool needsUniv) {
- // The set of tensors + dims to generate loops on
- SmallVector<size_t> tids, dims;
+static Operation *startLoop(CodegenEnv &env, OpBuilder &builder, LoopOrd at,
+ LatPointId li, bool needsUniv) {
+ // The set of tensors + lvls to generate loops on
+ SmallVector<TensorId> tids, affineTids;
+ SmallVector<Level> lvls, affineLvls;
// The set of dense tensors with non-trivial affine expression that just
// becomes invariant and the address shall now be generated at the current
// level.
- SmallVector<size_t> affineTids, affineDims;
SmallVector<AffineExpr> affines;
- bool isFor = translateBitsToTidDimPairs(
- env, li, env.topSortAt(at), tids, dims, affineTids, affineDims, affines);
+ bool isFor = translateBitsToTidLvlPairs(
+ env, li, env.topSortAt(at), tids, lvls, affineTids, affineLvls, affines);
// Emit the for/while-loop control.
- Operation *loop = genLoop(env, builder, at, needsUniv, tids, dims, isFor);
- for (auto [tid, dim, exp] : llvm::zip(affineTids, affineDims, affines)) {
- env.emitter().genDenseAffineAddressAtCurLevel(builder, env.op().getLoc(),
- tid, dim, exp);
+ Operation *loop = genLoop(env, builder, at, needsUniv, tids, lvls, isFor);
+ Location loc = env.op().getLoc();
+ for (auto [tid, lvl, exp] : llvm::zip(affineTids, affineLvls, affines)) {
+ env.emitter().genDenseAffineAddress(builder, loc, tid, lvl, exp);
}
- // Until now, we have entered every <tid, dim> pair in {cond, extra,
- // affine}Tids/Dims. The addresses of the upcoming levels which are dependent
+ // Until now, we have entered every <tid, lvl> pair in {cond, extra,
+ // affine}Tids/Lvls. The addresses of the upcoming levels which are dependent
// on constant affines expression may now be determined.
- auto allTids = llvm::concat<size_t>(tids, affineTids);
- auto allDims = llvm::concat<size_t>(dims, affineDims);
- for (auto [tid, dim] : llvm::zip(allTids, allDims)) {
+ auto allTids = llvm::concat<TensorId>(tids, affineTids);
+ auto allLvls = llvm::concat<Level>(lvls, affineLvls);
+ for (auto [tid, lvl] : llvm::zip(allTids, allLvls)) {
if (tid != env.merger().getOutTensorID())
- genConstantDenseAddressFromLevel(env, builder, tid, dim + 1);
+ genConstantDenseAddressFromLevel(env, builder, tid, lvl + 1);
}
return loop;
@@ -1409,7 +1468,7 @@ static Operation *startLoop(CodegenEnv &env, OpBuilder &builder, unsigned at,
/// Ends a single loop in current sequence. Returns new values for needsUniv.
static bool endLoop(CodegenEnv &env, RewriterBase &rewriter, Operation *loop,
- unsigned idx, unsigned li, bool needsUniv) {
+ LoopId idx, LatPointId li, bool needsUniv) {
// End a while-loop.
if (auto whileOp = dyn_cast<scf::WhileOp>(loop)) {
finalizeWhileOp(env, rewriter, idx, needsUniv, env.lat(li).bits, whileOp);
@@ -1430,9 +1489,9 @@ static bool endLoop(CodegenEnv &env, RewriterBase &rewriter, Operation *loop,
}
/// Ends a loop sequence at given level.
-static void endLoopSeq(CodegenEnv &env, OpBuilder &builder, unsigned exp,
- unsigned at, unsigned idx, unsigned ldx) {
- assert(env.getLoopIdxValue(idx) == nullptr);
+static void endLoopSeq(CodegenEnv &env, OpBuilder &builder, ExprId exp,
+ LoopOrd at, LoopId idx, LoopId ldx) {
+ assert(!env.getLoopVar(idx));
env.emitter().exitCurrentLoopSeq();
// Unmark bookkeeping of invariants and loop index.
genInvariants(env, builder, exp, ldx, /*atStart=*/false);
@@ -1443,20 +1502,21 @@ static void endLoopSeq(CodegenEnv &env, OpBuilder &builder, unsigned exp,
/// Recursively generates code while computing iteration lattices in order
/// to manage the complexity of implementing co-iteration over unions
/// and intersections of sparse iterations spaces.
-static void genStmt(CodegenEnv &env, RewriterBase &rewriter, unsigned exp,
- unsigned at) {
+static void genStmt(CodegenEnv &env, RewriterBase &rewriter, ExprId exp,
+ LoopOrd at) {
// At each leaf, assign remaining tensor (sub)expression to output tensor.
if (at == env.topSortSize()) {
- unsigned ldx = env.topSortAt(at - 1);
+ const LoopId ldx = env.topSortAt(at - 1);
Value rhs = genExp(env, rewriter, exp, ldx);
genTensorStore(env, rewriter, exp, rhs);
return;
}
// Construct iteration lattices for current loop index, with L0 at top.
- unsigned idx = env.topSortAt(at);
- unsigned ldx = at == 0 ? -1u : env.topSortAt(at - 1);
- unsigned lts = env.merger().optimizeSet(env.merger().buildLattices(exp, idx));
+ const LoopId idx = env.topSortAt(at);
+ const LoopId ldx = at == 0 ? kInvalidId : env.topSortAt(at - 1);
+ const LatSetId lts =
+ env.merger().optimizeSet(env.merger().buildLattices(exp, idx));
// Start a loop sequence.
bool needsUniv = startLoopSeq(env, rewriter, exp, at, idx, ldx, lts);
@@ -1465,7 +1525,7 @@ static void genStmt(CodegenEnv &env, RewriterBase &rewriter, unsigned exp,
unsigned lsize = env.set(lts).size();
for (unsigned i = 0; i < lsize; i++) {
// Start a loop.
- unsigned li = env.set(lts)[i];
+ const LatPointId li = env.set(lts)[i];
Operation *loop = startLoop(env, rewriter, at, li, needsUniv);
// Visit all lattices points with Li >= Lj to generate the
@@ -1475,8 +1535,8 @@ static void genStmt(CodegenEnv &env, RewriterBase &rewriter, unsigned exp,
Value insInput = env.getInsertionChain();
bool isWhile = dyn_cast<scf::WhileOp>(loop) != nullptr;
for (unsigned j = 0; j < lsize; j++) {
- unsigned lj = env.set(lts)[j];
- unsigned ej = env.lat(lj).exp;
+ const LatPointId lj = env.set(lts)[j];
+ const ExprId ej = env.lat(lj).exp;
if (li == lj || env.merger().latGT(li, lj)) {
// Recurse into body of each branch.
if (isWhile) {
@@ -1541,12 +1601,12 @@ struct GenericOpSparsifier : public OpRewritePattern<linalg::GenericOp> {
return failure();
// Sets up a code generation environment.
- unsigned numTensors = op->getNumOperands();
- unsigned numLoops = op.getNumLoops();
- unsigned numFilterLoops = getNumCompoundAffineOnSparseDims(op);
+ const unsigned numTensors = op->getNumOperands();
+ const unsigned numLoops = op.getNumLoops();
+ const unsigned numFilterLoops = getNumCompoundAffineOnSparseLvls(op);
CodegenEnv env(op, options, numTensors, numLoops, numFilterLoops);
- // Detects sparse annotations and translates the per-dimension sparsity
+ // Detects sparse annotations and translates the per-level sparsity
// information for all tensors to loop indices in the kernel.
if (!findSparseAnnotations(env))
return failure();
@@ -1568,11 +1628,11 @@ struct GenericOpSparsifier : public OpRewritePattern<linalg::GenericOp> {
// computation. Must be ordered from more strict to less strict.
// Ideally (though might not be guaranteed), the eariler a constraint mask
// can be satisfied, the faster the generated kernel will be.
- const auto allMask = {
+ const auto allMasks = {
SortMask::kIncludeAll, SortMask::kIncludeDense,
SortMask::kIncludeDenseInput, SortMask::kIncludeDenseOutput,
SortMask::kIncludeUndef, SortMask::kSparseOnly};
- for (auto mask : allMask) {
+ for (const SortMask mask : allMasks) {
if (computeIterationGraph(env, mask)) {
hasCycle = false;
if (env.isAdmissibleTopoOrder()) {
@@ -1591,7 +1651,7 @@ struct GenericOpSparsifier : public OpRewritePattern<linalg::GenericOp> {
env.startEmit();
genBuffers(env, rewriter);
genInitConstantDenseAddress(env, rewriter);
- genStmt(env, rewriter, env.getTensorExp(), 0);
+ genStmt(env, rewriter, env.getExprId(), 0);
genResult(env, rewriter);
return success();
}
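For context on the `allMasks` loop above: the masks are tried from the most to the least constrained, and the first one that yields an acyclic iteration graph (and, in the real code, an admissible topological order for the sparse output) wins. A simplified sketch of that selection strategy (hypothetical `pickFirstAcyclicMask`/`isAcyclic`; the enumerator names are the ones listed in the patch):

    #include <functional>
    #include <optional>
    #include <vector>

    enum class SortMask { kIncludeAll, kIncludeDense, kIncludeDenseInput,
                          kIncludeDenseOutput, kIncludeUndef, kSparseOnly };

    std::optional<SortMask>
    pickFirstAcyclicMask(const std::function<bool(SortMask)> &isAcyclic) {
      const std::vector<SortMask> allMasks = {
          SortMask::kIncludeAll,        SortMask::kIncludeDense,
          SortMask::kIncludeDenseInput, SortMask::kIncludeDenseOutput,
          SortMask::kIncludeUndef,      SortMask::kSparseOnly};
      for (const SortMask mask : allMasks)
        if (isAcyclic(mask)) // the real code also checks admissibility here
          return mask;
      return std::nullopt;   // every mask produced a cyclic iteration graph
    }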
@@ -1603,7 +1663,7 @@ struct GenericOpSparsifier : public OpRewritePattern<linalg::GenericOp> {
// sparse input tensor in succession until an acylic
// iteration graph results.
for (OpOperand *t : env.op().getDpsInputOperands()) {
- unsigned tensor = t->getOperandNumber();
+ const TensorId tid = t->getOperandNumber();
Value tval = t->get();
auto srcEnc = getSparseTensorEncoding(tval.getType());
if (!srcEnc || !computeIterationGraph(env, SortMask::kSparseOnly, t))
@@ -1624,8 +1684,8 @@ struct GenericOpSparsifier : public OpRewritePattern<linalg::GenericOp> {
auto dstTp = RankedTensorType::get(srcTp.getShape(),
srcTp.getElementType(), dstEnc);
auto convert = rewriter.create<ConvertOp>(tval.getLoc(), dstTp, tval);
- rewriter.updateRootInPlace(
- env.op(), [&]() { env.op()->setOperand(tensor, convert); });
+ rewriter.updateRootInPlace(env.op(),
+ [&]() { env.op()->setOperand(tid, convert); });
rewriter.setInsertionPointAfter(env.op());
rewriter.create<bufferization::DeallocTensorOp>(tval.getLoc(), convert);
return success();
diff --git a/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp b/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp
index 0fd75476db25..029ce3f3f91e 100644
--- a/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp
@@ -30,7 +30,7 @@ static ExpArity getExpArity(Kind k) {
// Leaf.
case kTensor:
case kInvariant:
- case kIndex:
+ case kLoopVar:
return ExpArity::kNullary;
case kAbsF:
case kAbsC:
@@ -98,20 +98,20 @@ static ExpArity getExpArity(Kind k) {
// Constructors.
//===----------------------------------------------------------------------===//
-TensorExp::TensorExp(Kind k, unsigned x, unsigned y, Value v, Operation *o)
+TensorExp::TensorExp(Kind k, unsigned x, ExprId y, Value v, Operation *o)
: kind(k), val(v), op(o) {
switch (kind) {
// Leaf.
case kTensor:
- assert(x != -1u && y == -1u && !v && !o);
+ assert(x != kInvalidId && y == kInvalidId && !v && !o);
tensor = x;
break;
case kInvariant:
- assert(x == -1u && y == -1u && v && !o);
+ assert(x == kInvalidId && y == kInvalidId && v && !o);
break;
- case kIndex:
- assert(x != -1u && y == -1u && !v && !o);
- index = x;
+ case kLoopVar:
+ assert(x != kInvalidId && y == kInvalidId && !v && !o);
+ loop = x;
break;
// Unary operations.
case kAbsF:
@@ -134,7 +134,7 @@ TensorExp::TensorExp(Kind k, unsigned x, unsigned y, Value v, Operation *o)
case kNegI:
case kCIm:
case kCRe:
- assert(x != -1u && y == -1u && !v && !o);
+ assert(x != kInvalidId && y == kInvalidId && !v && !o);
children.e0 = x;
children.e1 = y;
break;
@@ -149,20 +149,20 @@ TensorExp::TensorExp(Kind k, unsigned x, unsigned y, Value v, Operation *o)
case kCastIdx:
case kTruncI:
case kBitCast:
- assert(x != -1u && y == -1u && v && !o);
+ assert(x != kInvalidId && y == kInvalidId && v && !o);
children.e0 = x;
children.e1 = y;
break;
case kBinaryBranch:
case kSelect:
- assert(x != -1u && y == -1u && !v && o);
+ assert(x != kInvalidId && y == kInvalidId && !v && o);
children.e0 = x;
children.e1 = y;
break;
case kUnary:
// No assertion on y can be made, as the branching paths involve both
- // a unary (mapSet) and binary (takeDisj) pathway.
- assert(x != -1u && !v && o);
+ // a unary (`mapSet`) and binary (`disjSet`) pathway.
+ assert(x != kInvalidId && !v && o);
children.e0 = x;
children.e1 = y;
break;
@@ -186,82 +186,89 @@ TensorExp::TensorExp(Kind k, unsigned x, unsigned y, Value v, Operation *o)
case kShrS:
case kShrU:
case kShlI:
- assert(x != -1u && y != -1u && !v && !o);
+ assert(x != kInvalidId && y != kInvalidId && !v && !o);
children.e0 = x;
children.e1 = y;
break;
case kBinary:
case kReduce:
- assert(x != -1u && y != -1u && !v && o);
+ assert(x != kInvalidId && y != kInvalidId && !v && o);
children.e0 = x;
children.e1 = y;
break;
}
}
-LatPoint::LatPoint(unsigned n, unsigned e, unsigned b)
- : bits(n, false), exp(e) {
+LatPoint::LatPoint(const BitVector &bits, ExprId e) : bits(bits), exp(e) {}
+
+LatPoint::LatPoint(unsigned numTensors, unsigned numLoops, TensorId t, LoopId i,
+ ExprId e)
+ : bits(numLoops * numTensors, false), exp(e) {
+ assert(t < numTensors && i < numLoops);
+ const TensorLoopId b = numTensors * i + t;
bits.set(b);
}
-LatPoint::LatPoint(const BitVector &b, unsigned e) : bits(b), exp(e) {}
-
-Merger::Merger(unsigned t, unsigned l, unsigned fl)
- : outTensor(t - 1), syntheticTensor(t), numTensors(t + 1),
- numNativeLoops(l), numLoops(l + fl), hasSparseOut(false),
- dimTypes(numTensors,
+Merger::Merger(unsigned numInputOutputTensors, unsigned numNativeLoops,
+ unsigned numFilterLoops)
+ : outTensor(numInputOutputTensors - 1),
+ syntheticTensor(numInputOutputTensors),
+ numTensors(numInputOutputTensors + 1), numNativeLoops(numNativeLoops),
+ numLoops(numNativeLoops + numFilterLoops), hasSparseOut(false),
+ lvlTypes(numTensors,
std::vector<DimLevelType>(numLoops, DimLevelType::Undef)),
- loopIdxToDim(numTensors, std::vector<std::optional<unsigned>>(
- numLoops, std::nullopt)),
- dimToLoopIdx(numTensors, std::vector<std::optional<unsigned>>(
- numLoops, std::nullopt)) {}
+ loopToLvl(numTensors,
+ std::vector<std::optional<Level>>(numLoops, std::nullopt)),
+ lvlToLoop(numTensors,
+ std::vector<std::optional<LoopId>>(numLoops, std::nullopt)) {}
//===----------------------------------------------------------------------===//
// Lattice methods.
//===----------------------------------------------------------------------===//
-unsigned Merger::addExp(Kind k, unsigned e0, unsigned e1, Value v,
- Operation *op) {
- unsigned e = tensorExps.size();
- tensorExps.push_back(TensorExp(k, e0, e1, v, op));
+ExprId Merger::addExp(Kind k, unsigned x, ExprId y, Value v, Operation *op) {
+ const ExprId e = tensorExps.size();
+ assert((k != kTensor || x < numTensors) && (k != kLoopVar || x < numLoops));
+ tensorExps.emplace_back(k, x, y, v, op);
return e;
}
-unsigned Merger::addLat(unsigned t, unsigned i, unsigned e) {
+LatPointId Merger::addLat(TensorId t, LoopId i, ExprId e) {
assert(t < numTensors && i < numLoops);
- unsigned p = latPoints.size();
- latPoints.push_back(LatPoint(numLoops * numTensors, e, numTensors * i + t));
+ const LatPointId p = latPoints.size();
+ latPoints.emplace_back(numTensors, numLoops, t, i, e);
return p;
}
-unsigned Merger::addSet() {
- unsigned s = latSets.size();
+LatSetId Merger::addSet() {
+ const LatSetId s = latSets.size();
latSets.emplace_back();
return s;
}
-unsigned Merger::conjLatPoint(Kind kind, unsigned p0, unsigned p1,
- Operation *op) {
- unsigned p = latPoints.size();
- BitVector nb = BitVector(latPoints[p0].bits);
- nb |= latPoints[p1].bits;
- unsigned e = addExp(kind, latPoints[p0].exp, latPoints[p1].exp, Value(), op);
- latPoints.push_back(LatPoint(nb, e));
+LatPointId Merger::conjLat(Kind kind, LatPointId p0, LatPointId p1,
+ Operation *op) {
+ const LatPointId p = latPoints.size();
+ BitVector bits(latPoints[p0].bits);
+ bits |= latPoints[p1].bits;
+ const ExprId e =
+ addExp(kind, latPoints[p0].exp, latPoints[p1].exp, Value(), op);
+ latPoints.emplace_back(bits, e);
return p;
}
-unsigned Merger::takeConj(Kind kind, unsigned s0, unsigned s1, Operation *op) {
- unsigned s = addSet();
- for (unsigned p0 : latSets[s0])
- for (unsigned p1 : latSets[s1])
- latSets[s].push_back(conjLatPoint(kind, p0, p1, op));
+LatSetId Merger::conjSet(Kind kind, LatSetId s0, LatSetId s1, Operation *op) {
+ const LatSetId s = addSet();
+ for (const LatPointId p0 : latSets[s0])
+ for (const LatPointId p1 : latSets[s1])
+ latSets[s].push_back(conjLat(kind, p0, p1, op));
return s;
}
-unsigned Merger::takeDisj(Kind kind, unsigned s0, unsigned s1, Operation *op) {
- unsigned s = takeConj(kind, s0, s1, op);
+LatSetId Merger::disjSet(Kind kind, LatSetId s0, LatSetId s1, Operation *op) {
+ const LatSetId s = conjSet(kind, s0, s1, op);
// Followed by all in s0.
- for (unsigned p : latSets[s0])
+ for (const LatPointId p : latSets[s0])
latSets[s].push_back(p);
// Map binary 0-y to unary -y.
// TODO: move this if-else logic into buildLattices
@@ -272,56 +279,56 @@ unsigned Merger::takeDisj(Kind kind, unsigned s0, unsigned s1, Operation *op) {
else if (kind == kSubI)
s1 = mapSet(kNegI, s1);
// Followed by all in s1.
- for (unsigned p : latSets[s1])
+ for (const LatPointId p : latSets[s1])
latSets[s].push_back(p);
return s;
}
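For intuition on the renamed conjSet/disjSet, here is a toy counting model (not the Merger API, just the shape implied by the loops above): the conjunction is the cross product of the two point sets, and the disjunction additionally keeps every point of both operands.

  #include <cassert>
  #include <cstddef>

  // Sketch only: the number of lattice points produced for |s0| = m, |s1| = n.
  std::size_t conjSetSize(std::size_t m, std::size_t n) { return m * n; }
  std::size_t disjSetSize(std::size_t m, std::size_t n) {
    return conjSetSize(m, n) + m + n; // cross product, then all of s0 and s1
  }

  int main() {
    assert(conjSetSize(2, 3) == 6);
    assert(disjSetSize(2, 3) == 11);
    return 0;
  }

For subtraction the s1 points are first mapped through a negation (as the if-else above shows), but that does not change the count.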
-unsigned Merger::takeCombi(Kind kind, unsigned s0, unsigned s1, Operation *orig,
- bool includeLeft, Kind ltrans, Operation *opleft,
- bool includeRight, Kind rtrans, Operation *opright) {
- unsigned s = takeConj(kind, s0, s1, orig);
+LatSetId Merger::combiSet(Kind kind, LatSetId s0, LatSetId s1, Operation *orig,
+ bool includeLeft, Kind ltrans, Operation *opleft,
+ bool includeRight, Kind rtrans, Operation *opright) {
+ const LatSetId s = conjSet(kind, s0, s1, orig);
// Left Region.
if (includeLeft) {
if (opleft)
s0 = mapSet(ltrans, s0, Value(), opleft);
- for (unsigned p : latSets[s0])
+ for (const LatPointId p : latSets[s0])
latSets[s].push_back(p);
}
// Right Region.
if (includeRight) {
if (opright)
s1 = mapSet(rtrans, s1, Value(), opright);
- for (unsigned p : latSets[s1])
+ for (const LatPointId p : latSets[s1])
latSets[s].push_back(p);
}
return s;
}
-unsigned Merger::mapSet(Kind kind, unsigned s0, Value v, Operation *op) {
+LatSetId Merger::mapSet(Kind kind, LatSetId s0, Value v, Operation *op) {
assert(kAbsF <= kind && kind <= kSelect);
- unsigned s = addSet();
- for (unsigned p : latSets[s0]) {
- unsigned e = addExp(kind, latPoints[p].exp, v, op);
- latPoints.push_back(LatPoint(latPoints[p].bits, e));
+ const LatSetId s = addSet();
+ for (const LatPointId p : latSets[s0]) {
+ const ExprId e = addExp(kind, latPoints[p].exp, v, op);
+ latPoints.emplace_back(latPoints[p].bits, e);
latSets[s].push_back(latPoints.size() - 1);
}
return s;
}
-unsigned Merger::optimizeSet(unsigned s0) {
- unsigned s = addSet();
+LatSetId Merger::optimizeSet(LatSetId s0) {
+ const LatSetId s = addSet();
assert(!latSets[s0].empty());
- unsigned p0 = latSets[s0][0];
- for (unsigned p1 : latSets[s0]) {
+ const LatPointId p0 = latSets[s0][0];
+ for (const LatPointId p1 : latSets[s0]) {
bool add = true;
if (p0 != p1) {
- // Is this a straightforward copy?
- unsigned e = latPoints[p1].exp;
+ // Check whether this is a straightforward copy.
+ const ExprId e = latPoints[p1].exp;
if (expIsTensor(e, outTensor))
continue;
- // Conjunction already covered?
- for (unsigned p2 : latSets[s]) {
+ // Check whether this conjunction is already covered.
+ for (const LatPointId p2 : latSets[s]) {
assert(!latGT(p1, p2)); // Lj => Li would be bad
if (onlyDenseDiff(p2, p1)) {
add = false;
@@ -333,30 +340,30 @@ unsigned Merger::optimizeSet(unsigned s0) {
if (add)
latSets[s].push_back(p1);
}
- for (unsigned p : latSets[s])
+ for (const LatPointId p : latSets[s])
latPoints[p].simple = simplifyCond(s, p);
return s;
}
-BitVector Merger::simplifyCond(unsigned s0, unsigned p0) {
+BitVector Merger::simplifyCond(LatSetId s0, LatPointId p0) {
// First determine if this lattice point is a *singleton*, i.e.,
// the last point in a lattice; no other point is less than this one.
bool isSingleton = true;
- for (unsigned p1 : latSets[s0]) {
+ for (const LatPointId p1 : latSets[s0]) {
if (p0 != p1 && latGT(p0, p1)) {
isSingleton = false;
break;
}
}
- BitVector simple = latPoints[p0].bits;
+ BitVector simple(latPoints[p0].bits);
bool reset = isSingleton && hasAnySparse(simple);
- unsigned be = simple.size();
- unsigned offset = 0; // relative to the end
+ const TensorLoopId be = simple.size();
+ TensorLoopId offset = 0; // relative to the end
if (!reset)
- // Starts resetting from a dense dimension, so that the first bit (if kept)
- // is not undefined dimension type.
- for (unsigned b = 0; b < be; b++) {
+ // Starts resetting from a dense level, so that the first bit (if kept)
+ // does not have an undefined level-type.
+ for (TensorLoopId b = 0; b < be; b++) {
if (simple[b] && isDenseDLT(getDimLevelType(b))) {
offset = be - b - 1; // relative to the end
break;
@@ -365,24 +372,26 @@ BitVector Merger::simplifyCond(unsigned s0, unsigned p0) {
// Now apply the two basic rules. We also iterate over the bits in reverse to
// always keep the rightmost bit (which could possibly be a synthetic tensor).
- for (unsigned b = be - 1 - offset, i = 0; i < be;
+ for (TensorLoopId b = be - 1 - offset, i = 0; i < be;
b = b == 0 ? be - 1 : b - 1, i++) {
- if (simple[b] && (!isCompressedDLT(getDimLevelType(b)) &&
- !isSingletonDLT(getDimLevelType(b)))) {
- if (reset)
- simple.reset(b);
- reset = true;
+ if (simple[b]) {
+ const auto dlt = getDimLevelType(b);
+ if (!isCompressedDLT(dlt) && !isSingletonDLT(dlt)) {
+ if (reset)
+ simple.reset(b);
+ reset = true;
+ }
}
}
return simple;
}
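The rewritten loop body preserves the original circular, reverse visit order over the bits. A standalone sketch of that order, with be and offset chosen arbitrarily for illustration:

  #include <cstdio>

  int main() {
    // Same wrap-around walk as in simplifyCond: start `offset` positions from
    // the end and visit every bit exactly once.
    const unsigned be = 4, offset = 1;
    for (unsigned b = be - 1 - offset, i = 0; i < be;
         b = (b == 0 ? be - 1 : b - 1), i++)
      std::printf("%u ", b); // prints: 2 1 0 3
    std::printf("\n");
    return 0;
  }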
-bool Merger::latGT(unsigned i, unsigned j) const {
+bool Merger::latGT(LatPointId i, LatPointId j) const {
const BitVector &bitsi = latPoints[i].bits;
const BitVector &bitsj = latPoints[j].bits;
assert(bitsi.size() == bitsj.size());
if (bitsi.count() > bitsj.count()) {
- for (unsigned b = 0, be = bitsj.size(); b < be; b++)
+ for (TensorLoopId b = 0, be = bitsj.size(); b < be; b++)
if (bitsj[b] && !bitsi[b])
return false;
return true;
@@ -390,13 +399,13 @@ bool Merger::latGT(unsigned i, unsigned j) const {
return false;
}
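As a reading aid for latGT, a toy model using std::bitset in place of llvm::BitVector (hypothetical helper, not the real signature): point i is "greater than" point j exactly when i's bit set strictly contains j's.

  #include <bitset>
  #include <cassert>

  // Sketch only: strict-superset test on 8-bit sets.
  bool latGTModel(std::bitset<8> bi, std::bitset<8> bj) {
    return bi.count() > bj.count() && (bj & ~bi).none();
  }

  int main() {
    assert(latGTModel(0b0111, 0b0011));  // strict superset
    assert(!latGTModel(0b0011, 0b0011)); // equal sets are not "greater"
    assert(!latGTModel(0b0101, 0b0011)); // incomparable sets
    return 0;
  }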
-bool Merger::onlyDenseDiff(unsigned i, unsigned j) {
- BitVector tmp = latPoints[j].bits;
+bool Merger::onlyDenseDiff(LatPointId i, LatPointId j) const {
+ BitVector tmp(latPoints[j].bits);
tmp ^= latPoints[i].bits;
return !hasAnySparse(tmp);
}
-bool Merger::expContainsTensor(unsigned e, unsigned t) const {
+bool Merger::expContainsTensor(ExprId e, TensorId t) const {
if (tensorExps[e].kind == kTensor)
return tensorExps[e].tensor == t;
@@ -404,23 +413,23 @@ bool Merger::expContainsTensor(unsigned e, unsigned t) const {
case ExpArity::kNullary:
return false;
case ExpArity::kUnary: {
- unsigned op = tensorExps[e].children.e0;
- if (expIsTensor(op, t))
+ const ExprId e0 = tensorExps[e].children.e0;
+ if (expIsTensor(e0, t))
return true;
- return expContainsTensor(op, t);
+ return expContainsTensor(e0, t);
}
case ExpArity::kBinary: {
- unsigned op1 = tensorExps[e].children.e0;
- unsigned op2 = tensorExps[e].children.e1;
- if (expIsTensor(op1, t) || expIsTensor(op2, t))
+ const ExprId e0 = tensorExps[e].children.e0;
+ const ExprId e1 = tensorExps[e].children.e1;
+ if (expIsTensor(e0, t) || expIsTensor(e1, t))
return true;
- return expContainsTensor(op1, t) || expContainsTensor(op2, t);
+ return expContainsTensor(e0, t) || expContainsTensor(e1, t);
}
}
llvm_unreachable("unexpected arity");
}
-bool Merger::hasNegateOnOut(unsigned e) const {
+bool Merger::hasNegateOnOut(ExprId e) const {
switch (tensorExps[e].kind) {
case kNegF:
case kNegC:
@@ -446,13 +455,14 @@ bool Merger::hasNegateOnOut(unsigned e) const {
llvm_unreachable("unexpected kind");
}
-bool Merger::isSingleCondition(unsigned t, unsigned e) const {
+bool Merger::isSingleCondition(TensorId t, ExprId e) const {
+ assert(t < numTensors && e < tensorExps.size());
switch (tensorExps[e].kind) {
// Leaf.
case kTensor:
return tensorExps[e].tensor == t;
case kInvariant:
- case kIndex:
+ case kLoopVar:
return false;
// Unary operations.
case kAbsF:
@@ -531,10 +541,12 @@ bool Merger::isSingleCondition(unsigned t, unsigned e) const {
}
bool Merger::hasAnySparse(const BitVector &bits) const {
- for (unsigned b = 0, be = bits.size(); b < be; b++)
- if (bits[b] && (isCompressedDLT(getDimLevelType(b)) ||
- isSingletonDLT(getDimLevelType(b))))
- return true;
+ for (TensorLoopId b = 0, be = bits.size(); b < be; b++)
+ if (bits[b]) {
+ const auto dlt = getDimLevelType(b);
+ if (isCompressedDLT(dlt) || isSingletonDLT(dlt))
+ return true;
+ }
return false;
}
@@ -551,7 +563,7 @@ static const char *kindToOpSymbol(Kind kind) {
return "tensor";
case kInvariant:
return "invariant";
- case kIndex:
+ case kLoopVar:
return "index";
// Unary operations.
case kAbsF:
@@ -641,7 +653,7 @@ static const char *kindToOpSymbol(Kind kind) {
llvm_unreachable("unexpected kind for symbol");
}
-void Merger::dumpExp(unsigned e) const {
+void Merger::dumpExp(ExprId e) const {
switch (tensorExps[e].kind) {
// Leaf.
case kTensor:
@@ -654,8 +666,8 @@ void Merger::dumpExp(unsigned e) const {
case kInvariant:
llvm::dbgs() << "invariant";
break;
- case kIndex:
- llvm::dbgs() << "index_" << tensorExps[e].index;
+ case kLoopVar:
+ llvm::dbgs() << "loopvar_" << tensorExps[e].loop;
break;
// Unary operations.
case kAbsF:
@@ -725,7 +737,7 @@ void Merger::dumpExp(unsigned e) const {
}
}
-void Merger::dumpLat(unsigned p) const {
+void Merger::dumpLat(LatPointId p) const {
llvm::dbgs() << "lat(";
dumpBits(latPoints[p].bits);
llvm::dbgs() << " :";
@@ -735,9 +747,9 @@ void Merger::dumpLat(unsigned p) const {
llvm::dbgs() << " )\n";
}
-void Merger::dumpSet(unsigned s) const {
+void Merger::dumpSet(LatSetId s) const {
llvm::dbgs() << "{ #" << latSets[s].size() << "\n";
- for (unsigned p : latSets[s]) {
+ for (const LatPointId p : latSets[s]) {
llvm::dbgs() << " ";
dumpLat(p);
}
@@ -745,11 +757,11 @@ void Merger::dumpSet(unsigned s) const {
}
void Merger::dumpBits(const BitVector &bits) const {
- for (unsigned b = 0, be = bits.size(); b < be; b++) {
+ for (TensorLoopId b = 0, be = bits.size(); b < be; b++) {
if (bits[b]) {
- unsigned t = tensor(b);
- unsigned i = index(b);
- DimLevelType dlt = dimTypes[t][i];
+ const TensorId t = tensor(b);
+ const LoopId i = loop(b);
+ const auto dlt = lvlTypes[t][i];
llvm::dbgs() << " i_" << t << "_" << i << "_" << toMLIRString(dlt);
}
}
@@ -761,20 +773,20 @@ void Merger::dumpBits(const BitVector &bits) const {
// Builder methods.
//===----------------------------------------------------------------------===//
-unsigned Merger::buildLattices(unsigned e, unsigned i) {
- Kind kind = tensorExps[e].kind;
+LatSetId Merger::buildLattices(ExprId e, LoopId i) {
+ const Kind kind = tensorExps[e].kind;
switch (kind) {
// Leaf.
case kTensor:
case kInvariant:
- case kIndex: {
- // Either the index is really used in the tensor expression, or it is
- // set to the undefined index in that dimension. An invariant expression,
+ case kLoopVar: {
+ // Either the loop-var is really used in the tensor expression, or it is
+ // set to the undefined loop-var in that level. An invariant expression,
// a proper index value, and a truly dynamic sparse output tensor are set
// to a synthetic tensor with undefined indices only to ensure the
// iteration space is not skipped as a result of their contents.
- unsigned s = addSet();
- unsigned t = syntheticTensor;
+ const LatSetId s = addSet();
+ TensorId t = syntheticTensor;
if (kind == kTensor) {
t = tensorExps[e].tensor;
if (hasSparseOut && t == outTensor)
@@ -836,7 +848,7 @@ unsigned Merger::buildLattices(unsigned e, unsigned i) {
// ----+----------+------------+
// | absent() | present(y) |
{
- unsigned child0 = buildLattices(tensorExps[e].children.e0, i);
+ const LatSetId child0 = buildLattices(tensorExps[e].children.e0, i);
UnaryOp unop = cast<UnaryOp>(tensorExps[e].op);
Region &absentRegion = unop.getAbsentRegion();
@@ -848,8 +860,8 @@ unsigned Merger::buildLattices(unsigned e, unsigned i) {
Block &absentBlock = absentRegion.front();
YieldOp absentYield = cast<YieldOp>(absentBlock.getTerminator());
Value absentVal = absentYield.getResult();
- unsigned rhs = addExp(kInvariant, absentVal);
- return takeDisj(kind, child0, buildLattices(rhs, i), unop);
+ const ExprId rhs = addExp(kInvariant, absentVal);
+ return disjSet(kind, child0, buildLattices(rhs, i), unop);
}
// Binary operations.
case kMulF:
@@ -865,9 +877,8 @@ unsigned Merger::buildLattices(unsigned e, unsigned i) {
// x | 0 |x*y|
//
// Note even here, 0*NaN=NaN and 0*Inf=NaN, but that is ignored.
- return takeConj(kind, // take binary conjunction
- buildLattices(tensorExps[e].children.e0, i),
- buildLattices(tensorExps[e].children.e1, i));
+ return conjSet(kind, buildLattices(tensorExps[e].children.e0, i),
+ buildLattices(tensorExps[e].children.e1, i));
case kDivF:
case kDivC:
case kDivS:
@@ -886,9 +897,8 @@ unsigned Merger::buildLattices(unsigned e, unsigned i) {
// rules applies (viz. x/c = x*(1/c) as far as lattice
// construction is concerned).
assert(!maybeZero(tensorExps[e].children.e1));
- return takeConj(kind, // take binary conjunction
- buildLattices(tensorExps[e].children.e0, i),
- buildLattices(tensorExps[e].children.e1, i));
+ return conjSet(kind, buildLattices(tensorExps[e].children.e0, i),
+ buildLattices(tensorExps[e].children.e1, i));
case kAddF:
case kAddC:
case kAddI:
@@ -904,9 +914,8 @@ unsigned Merger::buildLattices(unsigned e, unsigned i) {
// ---+---+---+ ---+---+---+
// !x | 0 | y | !x | 0 |-y |
// x | x |x+y| x | x |x-y|
- return takeDisj(kind, // take binary disjunction
- buildLattices(tensorExps[e].children.e0, i),
- buildLattices(tensorExps[e].children.e1, i));
+ return disjSet(kind, buildLattices(tensorExps[e].children.e0, i),
+ buildLattices(tensorExps[e].children.e1, i));
case kShrS:
case kShrU:
case kShlI:
@@ -914,9 +923,8 @@ unsigned Merger::buildLattices(unsigned e, unsigned i) {
// can only occur at the left-hand-side of the operator) can be handled
// with the conjunction rule.
assert(isInvariant(tensorExps[e].children.e1));
- return takeConj(kind, // take binary conjunction
- buildLattices(tensorExps[e].children.e0, i),
- buildLattices(tensorExps[e].children.e1, i));
+ return conjSet(kind, buildLattices(tensorExps[e].children.e0, i),
+ buildLattices(tensorExps[e].children.e1, i));
case kBinary:
// A custom binary operation.
//
@@ -925,8 +933,8 @@ unsigned Merger::buildLattices(unsigned e, unsigned i) {
// !x | empty | right(y) |
// x | left(x) | overlap(x,y) |
{
- unsigned child0 = buildLattices(tensorExps[e].children.e0, i);
- unsigned child1 = buildLattices(tensorExps[e].children.e1, i);
+ const LatSetId child0 = buildLattices(tensorExps[e].children.e0, i);
+ const LatSetId child1 = buildLattices(tensorExps[e].children.e1, i);
BinaryOp binop = cast<BinaryOp>(tensorExps[e].op);
Region &leftRegion = binop.getLeftRegion();
Region &rightRegion = binop.getRightRegion();
@@ -944,20 +952,20 @@ unsigned Merger::buildLattices(unsigned e, unsigned i) {
}
bool includeLeft = binop.getLeftIdentity() || !leftRegion.empty();
bool includeRight = binop.getRightIdentity() || !rightRegion.empty();
- return takeCombi(kBinary, child0, child1, binop, includeLeft,
- kBinaryBranch, leftYield, includeRight, kBinaryBranch,
- rightYield);
+ return combiSet(kBinary, child0, child1, binop, includeLeft,
+ kBinaryBranch, leftYield, includeRight, kBinaryBranch,
+ rightYield);
}
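The combiSet call above composes the earlier pieces: the overlap cross product, optionally followed by the mapped left-only and right-only point sets. A toy counting model in the same spirit as the conjSet/disjSet sketch earlier (hypothetical, not the Merger API):

  #include <cassert>
  #include <cstddef>

  std::size_t combiSetSize(std::size_t m, std::size_t n, bool includeLeft,
                           bool includeRight) {
    std::size_t s = m * n; // overlap region: the conjSet cross product
    if (includeLeft)
      s += m;              // left-only points (possibly mapped through ltrans)
    if (includeRight)
      s += n;              // right-only points (possibly mapped through rtrans)
    return s;
  }

  int main() {
    assert(combiSetSize(2, 3, true, true) == 11); // same shape as disjSet
    assert(combiSetSize(2, 3, true, false) == 8);
    return 0;
  }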
case kReduce:
// A custom reduce operation.
- return takeConj(kind, buildLattices(tensorExps[e].children.e0, i),
- buildLattices(tensorExps[e].children.e1, i),
- tensorExps[e].op);
+ return conjSet(kind, buildLattices(tensorExps[e].children.e0, i),
+ buildLattices(tensorExps[e].children.e1, i),
+ tensorExps[e].op);
}
llvm_unreachable("unexpected expression kind");
}
-std::optional<unsigned> Merger::buildTensorExpFromLinalg(linalg::GenericOp op) {
+std::optional<ExprId> Merger::buildTensorExpFromLinalg(linalg::GenericOp op) {
// Build the linalg semantics backward from yield.
Operation *yield = op.getRegion().front().getTerminator();
assert(isa<linalg::YieldOp>(yield));
@@ -965,7 +973,7 @@ std::optional<unsigned> Merger::buildTensorExpFromLinalg(linalg::GenericOp op) {
}
/// Only returns false if we are certain this is a nonzero.
-bool Merger::maybeZero(unsigned e) const {
+bool Merger::maybeZero(ExprId e) const {
if (tensorExps[e].kind == kInvariant) {
if (auto c = tensorExps[e].val.getDefiningOp<complex::ConstantOp>()) {
ArrayAttr arrayAttr = c.getValue();
@@ -980,11 +988,11 @@ bool Merger::maybeZero(unsigned e) const {
return true;
}
-bool Merger::isInvariant(unsigned e) const {
+bool Merger::isInvariant(ExprId e) const {
return tensorExps[e].kind == kInvariant;
}
-Type Merger::inferType(unsigned e, Value src) {
+Type Merger::inferType(ExprId e, Value src) const {
// Obtain the destination type from the cast node.
Type dtp = tensorExps[e].val.getType();
// Inspect source type. For vector types, apply the same
@@ -997,7 +1005,7 @@ Type Merger::inferType(unsigned e, Value src) {
/// Ensures that the sparse compiler can generate code for the expression.
static bool isAdmissibleBranchExp(Operation *op, Block *block, Value v) {
// Arguments are always admissible.
- if (auto arg = v.dyn_cast<BlockArgument>())
+ if (v.isa<BlockArgument>())
return true;
// Accept index anywhere.
Operation *def = v.getDefiningOp();
@@ -1024,9 +1032,9 @@ static bool isAdmissibleBranch(Operation *op, Region &region) {
return isAdmissibleBranchExp(op, &region.front(), yield->getOperand(0));
}
-std::optional<unsigned> Merger::buildTensorExp(linalg::GenericOp op, Value v) {
+std::optional<ExprId> Merger::buildTensorExp(linalg::GenericOp op, Value v) {
if (auto arg = v.dyn_cast<BlockArgument>()) {
- unsigned argN = arg.getArgNumber();
+ const TensorId argN = arg.getArgNumber();
// Any argument of the generic op that is not marked as a scalar
// argument is considered a tensor, indexed by the implicit loop
// bounds. This includes rank-0 tensor arguments.
@@ -1047,13 +1055,13 @@ std::optional<unsigned> Merger::buildTensorExp(linalg::GenericOp op, Value v) {
// Construct index operations.
if (def->getNumOperands() == 0) {
if (auto indexOp = dyn_cast<linalg::IndexOp>(def))
- return addExp(kIndex, indexOp.getDim());
+ return addExp(kLoopVar, indexOp.getDim());
}
// Construct unary operations if subexpression can be built.
if (def->getNumOperands() == 1) {
- auto x = buildTensorExp(op, def->getOperand(0));
+ const auto x = buildTensorExp(op, def->getOperand(0));
if (x.has_value()) {
- unsigned e = *x;
+ const ExprId e = *x;
if (isa<math::AbsFOp>(def))
return addExp(kAbsF, e);
if (isa<complex::AbsOp>(def))
@@ -1129,11 +1137,11 @@ std::optional<unsigned> Merger::buildTensorExp(linalg::GenericOp op, Value v) {
// See buildLattices() for an explanation of rejecting certain
// division and shift operations.
if (def->getNumOperands() == 2) {
- auto x = buildTensorExp(op, def->getOperand(0));
- auto y = buildTensorExp(op, def->getOperand(1));
+ const auto x = buildTensorExp(op, def->getOperand(0));
+ const auto y = buildTensorExp(op, def->getOperand(1));
if (x.has_value() && y.has_value()) {
- unsigned e0 = *x;
- unsigned e1 = *y;
+ const ExprId e0 = *x;
+ const ExprId e1 = *y;
if (isa<arith::MulFOp>(def))
return addExp(kMulF, e0, e1);
if (isa<complex::MulOp>(def))
@@ -1184,12 +1192,12 @@ std::optional<unsigned> Merger::buildTensorExp(linalg::GenericOp op, Value v) {
}
// Construct ternary operations if subexpressions can be built.
if (def->getNumOperands() == 3) {
- auto x = buildTensorExp(op, def->getOperand(0));
- auto y = buildTensorExp(op, def->getOperand(1));
- auto z = buildTensorExp(op, def->getOperand(2));
+ const auto x = buildTensorExp(op, def->getOperand(0));
+ const auto y = buildTensorExp(op, def->getOperand(1));
+ const auto z = buildTensorExp(op, def->getOperand(2));
if (x.has_value() && y.has_value() && z.has_value()) {
- unsigned e0 = *x;
- unsigned e1 = *y;
+ const ExprId e0 = *x;
+ const ExprId e1 = *y;
if (auto redop = dyn_cast<sparse_tensor::ReduceOp>(def)) {
if (isAdmissibleBranch(redop, redop.getRegion()))
return addExp(kReduce, e0, e1, Value(), def);
@@ -1245,13 +1253,13 @@ static Value buildBinaryOverlap(RewriterBase &rewriter, Location loc,
return insertYieldOp(rewriter, loc, overlapRegion, {v0, v1});
}
-Value Merger::buildExp(RewriterBase &rewriter, Location loc, unsigned e,
- Value v0, Value v1) {
+Value Merger::buildExp(RewriterBase &rewriter, Location loc, ExprId e, Value v0,
+ Value v1) const {
switch (tensorExps[e].kind) {
// Leaf.
case kTensor:
case kInvariant:
- case kIndex:
+ case kLoopVar:
llvm_unreachable("unexpected non-op");
// Unary operations.
case kAbsF:
diff --git a/mlir/unittests/Dialect/SparseTensor/MergerTest.cpp b/mlir/unittests/Dialect/SparseTensor/MergerTest.cpp
index 929900142d27..10d350f7c6b9 100644
--- a/mlir/unittests/Dialect/SparseTensor/MergerTest.cpp
+++ b/mlir/unittests/Dialect/SparseTensor/MergerTest.cpp
@@ -225,7 +225,7 @@ class MergerTestBase : public ::testing::Test {
case kTensor:
return tensorExp.tensor == pattern->tensorNum;
case kInvariant:
- case kIndex:
+ case kLoopVar:
llvm_unreachable("invariant not handled yet");
// Unary operations.
case kAbsF:
@@ -313,15 +313,15 @@ class MergerTest3T1L : public MergerTestBase {
// Tensor 0: sparse input vector.
merger.addExp(Kind::kTensor, t0, -1u);
- merger.setDimAndDimLevelType(t0, l0, 0, DimLevelType::Compressed);
+ merger.setLevelAndType(t0, l0, 0, DimLevelType::Compressed);
// Tensor 1: sparse input vector.
merger.addExp(Kind::kTensor, t1, -1u);
- merger.setDimAndDimLevelType(t1, l0, 0, DimLevelType::Compressed);
+ merger.setLevelAndType(t1, l0, 0, DimLevelType::Compressed);
// Tensor 2: dense output vector.
merger.addExp(Kind::kTensor, t2, -1u);
- merger.setDimAndDimLevelType(t2, l0, 0, DimLevelType::Dense);
+ merger.setLevelAndType(t2, l0, 0, DimLevelType::Dense);
}
};
@@ -338,19 +338,19 @@ class MergerTest4T1L : public MergerTestBase {
// Tensor 0: sparse input vector.
merger.addExp(Kind::kTensor, t0, -1u);
- merger.setDimAndDimLevelType(t0, l0, 0, DimLevelType::Compressed);
+ merger.setLevelAndType(t0, l0, 0, DimLevelType::Compressed);
// Tensor 1: sparse input vector.
merger.addExp(Kind::kTensor, t1, -1u);
- merger.setDimAndDimLevelType(t1, l0, 0, DimLevelType::Compressed);
+ merger.setLevelAndType(t1, l0, 0, DimLevelType::Compressed);
// Tensor 2: sparse input vector
merger.addExp(Kind::kTensor, t2, -1u);
- merger.setDimAndDimLevelType(t2, l0, 0, DimLevelType::Compressed);
+ merger.setLevelAndType(t2, l0, 0, DimLevelType::Compressed);
// Tensor 3: dense output vector
merger.addExp(Kind::kTensor, t3, -1u);
- merger.setDimAndDimLevelType(t3, l0, 0, DimLevelType::Dense);
+ merger.setLevelAndType(t3, l0, 0, DimLevelType::Dense);
}
};
@@ -371,15 +371,15 @@ class MergerTest3T1LD : public MergerTestBase {
// Tensor 0: sparse input vector.
merger.addExp(Kind::kTensor, t0, -1u);
- merger.setDimAndDimLevelType(t0, l0, 0, DimLevelType::Compressed);
+ merger.setLevelAndType(t0, l0, 0, DimLevelType::Compressed);
// Tensor 1: dense input vector.
merger.addExp(Kind::kTensor, t1, -1u);
- merger.setDimAndDimLevelType(t1, l0, 0, DimLevelType::Dense);
+ merger.setLevelAndType(t1, l0, 0, DimLevelType::Dense);
// Tensor 2: dense output vector.
merger.addExp(Kind::kTensor, t2, -1u);
- merger.setDimAndDimLevelType(t2, l0, 0, DimLevelType::Dense);
+ merger.setLevelAndType(t2, l0, 0, DimLevelType::Dense);
}
};
@@ -400,19 +400,19 @@ class MergerTest4T1LU : public MergerTestBase {
// Tensor 0: undef input vector.
merger.addExp(Kind::kTensor, t0, -1u);
- merger.setDimAndDimLevelType(t0, l0, 0, DimLevelType::Undef);
+ merger.setLevelAndType(t0, l0, 0, DimLevelType::Undef);
// Tensor 1: dense input vector.
merger.addExp(Kind::kTensor, t1, -1u);
- merger.setDimAndDimLevelType(t1, l0, 0, DimLevelType::Dense);
+ merger.setLevelAndType(t1, l0, 0, DimLevelType::Dense);
// Tensor 2: undef input vector.
merger.addExp(Kind::kTensor, t2, -1u);
- merger.setDimAndDimLevelType(t2, l0, 0, DimLevelType::Undef);
+ merger.setLevelAndType(t2, l0, 0, DimLevelType::Undef);
// Tensor 3: dense output vector.
merger.addExp(Kind::kTensor, t3, -1u);
- merger.setDimAndDimLevelType(t3, l0, 0, DimLevelType::Dense);
+ merger.setLevelAndType(t3, l0, 0, DimLevelType::Dense);
}
};
@@ -436,15 +436,15 @@ class MergerTest3T1LSo : public MergerTestBase {
// Tensor 0: undef input vector.
merger.addExp(Kind::kTensor, t0, -1u);
- merger.setDimAndDimLevelType(t0, l0, 0, DimLevelType::Undef);
+ merger.setLevelAndType(t0, l0, 0, DimLevelType::Undef);
// Tensor 1: undef input vector.
merger.addExp(Kind::kTensor, t1, -1u);
- merger.setDimAndDimLevelType(t1, l0, 0, DimLevelType::Undef);
+ merger.setLevelAndType(t1, l0, 0, DimLevelType::Undef);
// Tensor 2: sparse output vector.
merger.addExp(Kind::kTensor, t2, -1u);
- merger.setDimAndDimLevelType(t2, l0, 0, DimLevelType::Compressed);
+ merger.setLevelAndType(t2, l0, 0, DimLevelType::Compressed);
}
};