[Mlir-commits] [mlir] 46a384d - [mlir][sparse] Preliminary code changes for ExprId, LatPointId, LatSetId newtypes
wren romano
llvmlistbot at llvm.org
Wed Mar 29 18:02:07 PDT 2023
Author: wren romano
Date: 2023-03-29T18:01:56-07:00
New Revision: 46a384dfbe11545b200611e471cb4242d1295589
URL: https://github.com/llvm/llvm-project/commit/46a384dfbe11545b200611e471cb4242d1295589
DIFF: https://github.com/llvm/llvm-project/commit/46a384dfbe11545b200611e471cb4242d1295589.diff
LOG: [mlir][sparse] Preliminary code changes for ExprId, LatPointId, LatSetId newtypes
This commit contains several code changes that are ultimately required for converting the various `Merger` identifiers from typedefs to newtypes. The actual implementation of the newtypes themselves has been split off into separate commits, in hopes of simplifying the review process.
Depends On D146561
Reviewed By: aartbik
Differential Revision: https://reviews.llvm.org/D146684
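As background for the change described above: a "newtype" here means a thin wrapper around the underlying `unsigned`, so that the compiler keeps, for example, `ExprId` and `LatPointId` distinct even though both are plain integers today. The following is only a minimal illustrative sketch of that idiom, not the implementation that lands in the follow-up commits; the wrapper bodies and the helper names (`useExpr`, `main`) are assumptions made for the example.

// Illustrative sketch of the "newtype" idiom referred to in the commit
// message; the actual wrappers land in later commits and may differ.
#include <cassert>

class ExprId {
public:
  constexpr explicit ExprId(unsigned v) : value(v) {}
  constexpr unsigned get() const { return value; }

private:
  unsigned value;
};

class LatPointId {
public:
  constexpr explicit LatPointId(unsigned v) : value(v) {}
  constexpr unsigned get() const { return value; }

private:
  unsigned value;
};

// With `using ExprId = unsigned;` and `using LatPointId = unsigned;`,
// both calls below would compile and silently mix up the identifiers.
// With the wrappers, the second call is a compile-time error.
static void useExpr(ExprId e) { assert(e.get() != -1u); }

int main() {
  const ExprId e(3);
  const LatPointId p(7);
  useExpr(e); // OK: an ExprId where an ExprId is expected.
  // useExpr(p); // error: no implicit conversion from LatPointId to ExprId.
  (void)p;
  return 0;
}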
Added:
mlir/include/mlir/Dialect/SparseTensor/Utils/MergerNewtypes.h
Modified:
mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h
mlir/lib/Dialect/SparseTensor/Transforms/CodegenEnv.cpp
mlir/lib/Dialect/SparseTensor/Transforms/CodegenEnv.h
mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.cpp
mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.h
mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp
mlir/unittests/Dialect/SparseTensor/MergerTest.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h b/mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h
index 1a11010971f23..5a6ffeadfa999 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h
+++ b/mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h
@@ -13,6 +13,8 @@
#ifndef MLIR_DIALECT_SPARSETENSOR_UTILS_MERGER_H_
#define MLIR_DIALECT_SPARSETENSOR_UTILS_MERGER_H_
+#include "mlir/Dialect/SparseTensor/Utils/MergerNewtypes.h"
+
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/SparseTensor/IR/Enums.h"
#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
@@ -23,65 +25,6 @@
namespace mlir {
namespace sparse_tensor {
-// TODO: These type aliases currently only serve to make the code more
-// self-documenting, however because they are not type-checked they can
-// do nothing to prevent mixups. We should really change them from mere
-// aliases to actual struct definitions, so that we can type-check them.
-
-/// Tensor identifiers. The valid set of identifiers is defined by the
-/// first argument passed to the `Merger` ctor.
-using TensorId = unsigned;
-
-/// Loop identifiers. The valid set of identifiers is defined by the
-/// second two arguments to the `Merger` ctor.
-///
-/// These identifiers serve as proxies for the `$dim` argument to
-/// `linalg::IndexOp`, however the numerical value of a `LoopId` should
-/// not necessarily be equated with the numerical value of the corresponding
-/// `$dim` argument. The `$dim` arguments are De Bruijn indices: that
-/// is, they identify the loop which binds the loop-variable by counting
-/// the enclosing loops from innermost to outermost, starting from zero.
-/// Whereas `LoopId` are considered to be arbitrary names for identifying
-/// loops; since the `Merger` does not care about the actual ordering of
-/// loops, and leaves it up to the `LoopEmitter` to specify the actual
-/// loop ordering (`LoopOrd`).
-///
-/// TODO: Despite the above claim that `$dim` and `LoopId` need not be
-/// numerically equal, some code in the `Merger` class does equate them
-/// (e.g., `buildTensorExp`). So we need to explicate the exact relationship
-/// between `$dim`, `LoopId`, and `LoopOrd`; especially with regards to their
-/// provenance. If `LoopId` really is supposed to be equated with `$dim`,
-/// then we should change the name to `LoopIdx` or similar, to capture the
-/// fact that its numerical value is not invariant when entering/exiting
-/// loops (unlike `TensorId`, `ExprId`, `LatPointId`, and `LatSetId` which
-/// are invariant identifiers).
-using LoopId = unsigned;
-
-/// A compressed representation of `std::pair<TensorId, LoopId>`.
-/// The compression scheme is such that this also serves as an index
-/// into the bitvector stored in `LatPoint` (since that bitvector is
-/// just the implementation for a set of `TensorLoopId` values).
-using TensorLoopId = unsigned;
-
-/// `TensorExp` identifiers. These are allocated by `Merger::addExp`,
-/// and serve as unique identifiers for the corresponding `TensorExp` object.
-using ExprId = unsigned;
-
-/// `LatPoint` identifiers. These are allocated by `Merger::addLat`,
-/// and serve as unique identifiers for the corresponding `LatPoint` object.
-using LatPointId = unsigned;
-
-/// `LatSet` identifiers. These are allocated by `Merger::addSet` (and
-/// by other methods calling that one), and serve as unique identifiers
-/// for the corresponding `SmallVector<LatPointId>` object.
-using LatSetId = unsigned;
-
-namespace detail {
-/// A constant serving as the canonically invalid identifier, regardless
-/// of the identifier type.
-static constexpr unsigned kInvalidId = -1u;
-} // namespace detail
-
/// Tensor expression. Represents an MLIR expression in tensor index notation.
struct TensorExp final {
enum class Kind;
@@ -207,18 +150,17 @@ enum class TensorExp::Kind {
kReduce, // semiring reduction op
};
+//===----------------------------------------------------------------------===//
/// Lattice point. Each lattice point consists of a formal conjunction
/// of `TensorLoopId`s, together with the identifier of the corresponding
/// tensor expression. The formal conjunction is represented as a set of
/// `TensorLoopId`, where that set is implemented as a `BitVector`.
struct LatPoint final {
- /// Construct the lattice point from a given set of `TensorLoopId`s.
- LatPoint(const BitVector &bits, ExprId e);
+ /// Construct a lattice point with the empty set of `TensorLoopId`s.
+ LatPoint(unsigned size, ExprId e) : bits(size, false), exp(e) {}
- /// Construct a lattice point with `(t,i)` as the only `TensorLoopId`,
- /// where `(t,i) < (numTensors,numLoops)`.
- LatPoint(unsigned numTensors, unsigned numLoops, TensorId t, LoopId i,
- ExprId e);
+ /// Construct a lattice point from the given set of `TensorLoopId`s.
+ LatPoint(const BitVector &bits, ExprId e) : bits(bits), exp(e) {}
/// Conjunction of all `TensorLoopId`s involved in the tensor expression.
BitVector bits;
@@ -232,6 +174,7 @@ struct LatPoint final {
ExprId exp;
};
+//===----------------------------------------------------------------------===//
/// A class to handle all iteration lattice operations. This class abstracts
/// away from some implementation details of storing iteration lattices and
/// tensor expressions. This allows for fine-tuning performance characteristics
@@ -271,18 +214,46 @@ class Merger {
Merger(unsigned numInputOutputTensors, unsigned numNativeLoops,
unsigned numFilterLoops, unsigned maxLvlRank);
- /// Constructs a new tensor expression, and returns its identifier.
- /// The type of the `e0` argument varies according to the value of the
- /// `k` argument, as described by the `TensorExp` ctor.
- ExprId addExp(TensorExp::Kind k, unsigned e0, ExprId e1 = detail::kInvalidId,
- Value v = Value(), Operation *op = nullptr);
- ExprId addExp(TensorExp::Kind k, ExprId e, Value v, Operation *op = nullptr) {
- return addExp(k, e, detail::kInvalidId, v, op);
+ //
+ // Constructing valid tensor and loop identifiers.
+ //
+
+ /// Safely converts the argument to a tensor identifier.
+ constexpr TensorId makeTensorId(unsigned t) const {
+ assert(isValidTensorId(t));
+ return t;
}
- ExprId addExp(TensorExp::Kind k, Value v, Operation *op = nullptr) {
- return addExp(k, detail::kInvalidId, detail::kInvalidId, v, op);
+
+ /// Safely converts the argument to a loop identifier.
+ constexpr LoopId makeLoopId(unsigned i) const {
+ assert(isValidLoopId(i));
+ return i;
}
+ /// Safely converts the arguments to a pair of (tensor,loop) identifiers.
+ constexpr TensorLoopId makeTensorLoopId(unsigned t, unsigned i) const {
+ assert(isValidTensorId(t) && isValidLoopId(i));
+ return numTensors * i + t;
+ }
+
+ //
+ // Allocating new expressions, points, and sets.
+ //
+
+ /// Constructs a new tensor expression, and returns its identifier.
+ ExprId addTensorExp(TensorId t);
+ /// Constructs a new loop-variable expression, and returns its identifier.
+ ExprId addLoopVarExp(LoopId i);
+ /// Constructs a new invariant expression, and returns its identifier.
+ ExprId addInvariantExp(Value v);
+ /// Constructs a new unary or binary expression, and returns its identifier.
+ ExprId addExp(TensorExp::Kind k, ExprId e0, ExprId e1 = detail::kInvalidId,
+ Operation *op = nullptr);
+ /// Constructs a new sesquinary expression, and returns its identifier.
+ /// Currently no sesquinary `Kind` allows specifying the `op`, but we
+ /// allow it anyways because `mapSet` is designed to allow it.
+ ExprId addExp(TensorExp::Kind k, ExprId e, Value v, Operation *op = nullptr);
+
/// Constructs a new iteration lattice point, and returns its identifier.
LatPointId addLat(TensorId t, LoopId i, ExprId e);
LatPointId addLat(const BitVector &bits, ExprId e);
@@ -339,51 +310,47 @@ class Merger {
bool onlyDenseDiff(LatPointId p0, LatPointId p1) const;
/// Gets the tensor-identifier of the `TensorLoopId`.
- TensorId tensor(TensorLoopId b) const { return b % numTensors; }
+ constexpr TensorId tensor(TensorLoopId b) const { return b % numTensors; }
/// Gets the loop-identifier of the `TensorLoopId`.
- LoopId loop(TensorLoopId b) const { return b / numTensors; }
+ constexpr LoopId loop(TensorLoopId b) const { return b / numTensors; }
/// Get the total number of tensors (including the output-tensor and
- /// synthetic-tensor). The result is given the type `TensorId` since
- /// the result is primarily used as an upper bound for `TensorId`s.
- TensorId getNumTensors() const { return numTensors; }
+ /// synthetic-tensor).
+ constexpr unsigned getNumTensors() const { return numTensors; }
/// Get the total number of loops (native loops + filter loops).
- /// The result is given the type `LoopId` since the result will
- /// generally be used as a for-loop upper bound.
- LoopId getNumLoops() const { return numLoops; }
- /// Get the number of native loops. The result is given the type
- /// `LoopId` since the result will generally be used as a for-loop
- /// upper bound.
- LoopId getNumNativeLoops() const { return numNativeLoops; }
- /// Get the number of filter loops. The result is given the type
- /// `LoopId` since the result will generally be used as a for-loop
- /// upper bound.
- LoopId getNumFilterLoops() const { return numLoops - numNativeLoops; }
+ constexpr unsigned getNumLoops() const { return numLoops; }
+ /// Get the number of native loops.
+ constexpr unsigned getNumNativeLoops() const { return numNativeLoops; }
+ /// Get the number of filter loops.
+ constexpr unsigned getNumFilterLoops() const {
+ return numLoops - numNativeLoops;
+ }
/// Get the identifier of the first filter-loop.
- LoopId getStartingFilterLoopId() const { return getNumNativeLoops(); }
+ constexpr LoopId getStartingFilterLoopId() const {
+ return getNumNativeLoops();
+ }
/// Returns true if `b` is the `i`th loop of the output tensor.
- bool isOutTensor(TensorLoopId b, LoopId i) const {
- assert(i < numLoops);
- return b == numTensors * i + outTensor;
+ constexpr bool isOutTensor(TensorLoopId b, LoopId i) const {
+ return b == makeTensorLoopId(outTensor, i);
}
/// Get the output tensor's identifier.
- TensorId getOutTensorID() const { return outTensor; }
+ constexpr TensorId getOutTensorID() const { return outTensor; }
/// Get the synthetic tensor's identifier (used for all invariant
/// tensor expressions).
- TensorId getSynTensorID() const { return syntheticTensor; }
+ constexpr TensorId getSynTensorID() const { return syntheticTensor; }
- bool isFilterLoop(LoopId i) const {
- assert(i < numLoops);
+ constexpr bool isFilterLoop(LoopId i) const {
+ assert(isValidLoopId(i));
return i >= numNativeLoops;
}
/// Returns true if the expression is `(kTensor t)`.
bool expIsTensor(ExprId e, TensorId t) const {
- return tensorExps[e].kind == TensorExp::Kind::kTensor &&
- tensorExps[e].tensor == t;
+ const auto &expr = exp(e);
+ return expr.kind == TensorExp::Kind::kTensor && expr.tensor == t;
}
/// Returns true if the expression contains the tensor as an operand.
@@ -411,7 +378,7 @@ class Merger {
/// Gets the level-type of the `t`th tensor on `i`th loop.
DimLevelType getDimLevelType(TensorId t, LoopId i) const {
- assert(t < numTensors && i < numLoops);
+ assert(isValidTensorId(t) && isValidLoopId(i));
return lvlTypes[t][i];
}
@@ -422,13 +389,13 @@ class Merger {
/// Gets the loop identifier for the `lvl`th level of the `t`th tensor.
std::optional<LoopId> getLoopId(TensorId t, Level lvl) const {
- assert(t < numTensors && lvl < lvlToLoop[t].size());
+ assert(isValidLevel(t, lvl));
return lvlToLoop[t][lvl];
}
/// Gets the level number of the `t`th tensor on `i`th loop.
std::optional<Level> getLvl(TensorId t, LoopId i) const {
- assert(t < numTensors && i < numLoops);
+ assert(isValidTensorId(t) && isValidLoopId(i));
return loopToLvl[t][i];
}
std::optional<Level> getLvl(TensorLoopId b) const {
@@ -438,31 +405,41 @@ class Merger {
/// Sets the level number and level-type of the `t`th tensor on
/// `i`th loop.
void setLevelAndType(TensorId t, LoopId i, Level lvl, DimLevelType dlt) {
- assert(t < numTensors && i < numLoops && lvl < lvlToLoop[t].size() &&
- isValidDLT(dlt));
+ assert(isValidLevel(t, lvl) && isValidLoopId(i) && isValidDLT(dlt));
lvlTypes[t][i] = dlt;
loopToLvl[t][i] = lvl;
lvlToLoop[t][lvl] = i;
}
+ using ForeachTensorLoopIdCallback = function_ref<void(
+ TensorLoopId, TensorId, std::optional<Level>, DimLevelType, bool)>;
+
/// Iterates over a set of `TensorLoopId`s, invoking the callback
/// for each `TensorLoopId` and passing it the corresponding tensor
/// identifier, level, and level-type, following with a boolean value
/// indicating whether it is a dependent index reduction loop condition.
- void foreachTensorLoopId(
- LatPointId p, function_ref<void(TensorLoopId, TensorId,
- std::optional<Level>, DimLevelType, bool)>
- callback) {
- for (const TensorLoopId b : latPoints[p].bits.set_bits()) {
- TensorId t = tensor(b);
+ void foreachTensorLoopId(LatPointId p,
+ ForeachTensorLoopIdCallback callback) const {
+ // TODO: the default ought to be simple=true; but we'll need to make
+ // sure to update all the tests to make sure they do the right thing.
+ foreachTensorLoopId(p, /*simple=*/false, callback);
+ }
+ void foreachTensorLoopId(LatPointId p, bool simple,
+ ForeachTensorLoopIdCallback callback) const {
+ const auto &point = lat(p);
+ const auto &bits = simple ? point.simple : point.bits;
+ for (const TensorLoopId b : bits.set_bits()) {
+ const TensorId t = tensor(b);
+ const auto optLvl = getLvl(b);
+ const auto lvlTp = getDimLevelType(b);
if (isLvlWithNonTrivialIdxExp(b)) {
// This must be an undefined level.
- assert(!getLvl(b).has_value());
+ assert(!optLvl.has_value());
// Slice the tid along the dependent level to iterate current loop.
- callback(b, t, loopToDependencies[loop(b)][t], getDimLevelType(b),
+ callback(b, t, loopToDependencies[loop(b)][t], lvlTp,
/*isIdxReduc=*/true);
} else {
- callback(b, t, getLvl(b), getDimLevelType(b), /*isIdxReduc=*/false);
+ callback(b, t, optLvl, lvlTp, /*isIdxReduc=*/false);
}
}
}
@@ -472,31 +449,37 @@ class Merger {
/// Establishes the two-way map that i <-> <t, lvl>.
void setLoopDependentTensorLevel(LoopId i, TensorId t, Level lvl) {
- assert(lvl < numLoops);
+ assert(isValidLoopId(i) && isValidLevel(t, lvl));
loopToDependencies[i][t] = lvl;
levelToDependentIdx[t][lvl].push_back(i);
}
/// Whether the loop has dependent slice.
- bool hasDependentLvl(LoopId i, TensorId tid) {
- return loopToDependencies[i][tid].has_value();
+ bool hasDependentLvl(LoopId i, TensorId t) {
+ assert(isValidTensorId(t) && isValidLoopId(i));
+ return loopToDependencies[i][t].has_value();
}
/// Returns the list of loop indices which appear in the non-trivial index
/// expression on t_l, e.g., A[i+j] => {i, j}
std::vector<LoopId> &getDependentLoops(TensorId t, Level lvl) {
+ assert(isValidLevel(t, lvl));
return levelToDependentIdx[t][lvl];
}
/// Returns the defining [tid, lvl] for the loop.
std::pair<TensorId, Level> getLoopDefiningLvl(LoopId i) const {
+ assert(isValidLoopId(i));
return loopBounds[i];
}
/// Checks whether the TensorLoopId represents a tensor level with
/// non-trivial index expression on it.
bool isLvlWithNonTrivialIdxExp(TensorLoopId b) const {
- return loopToDependencies[loop(b)][tensor(b)].has_value();
+ const TensorId t = tensor(b);
+ const LoopId i = loop(b);
+ assert(isValidTensorId(t) && isValidLoopId(i));
+ return loopToDependencies[i][t].has_value();
}
/// Convenience getters to immediately access the stored nodes.
@@ -512,20 +495,28 @@ class Merger {
/// references, but also applies to the `ArrayRef`. In particular,
/// using `for (LatPointId p : merger.set(s))` will run into the same
/// dangling-reference problems if the loop body inserts new sets.
- const TensorExp &exp(ExprId e) const { return tensorExps[e]; }
- const LatPoint &lat(LatPointId p) const { return latPoints[p]; }
- ArrayRef<LatPointId> set(LatSetId s) const { return latSets[s]; }
+ const TensorExp &exp(ExprId e) const {
+ assert(isValidExprId(e));
+ return tensorExps[e];
+ }
+ const LatPoint &lat(LatPointId p) const {
+ assert(isValidLatPointId(p));
+ return latPoints[p];
+ }
+ ArrayRef<LatPointId> set(LatSetId s) const {
+ assert(isValidLatSetId(s));
+ return latSets[s];
+ }
/// Checks whether the given expression has an associated value.
- bool hasExprValue(ExprId e) const {
- return static_cast<bool>(tensorExps[e].val);
- }
+ bool hasExprValue(ExprId e) const { return static_cast<bool>(exp(e).val); }
/// Sets the expression to have the associated value. Asserts that
/// the new value is defined, and that the expression does not already
/// have a value. If you want to overwrite a previous associated value,
/// use `updateExprValue` instead.
void setExprValue(ExprId e, Value v) {
+ assert(isValidExprId(e));
assert(v && "Got an undefined value");
auto &val = tensorExps[e].val;
assert(!val && "Expression already has an associated value");
@@ -537,6 +528,7 @@ class Merger {
/// If you don't want to check for a previous associated value first,
/// then use `updateExprValue` instead.
void clearExprValue(ExprId e) {
+ assert(isValidExprId(e));
auto &val = tensorExps[e].val;
assert(val && "Expression does not have an associated value to clear");
val = Value();
@@ -553,7 +545,10 @@ class Merger {
// the semantics `{ clearExprValue(e); setExprValue(e, v); }` or
// `{ clearExprValue(e); if (v) setExprValue(e, v); }` since those
// provide better invariants.
- void updateExprValue(ExprId e, Value v) { tensorExps[e].val = v; }
+ void updateExprValue(ExprId e, Value v) {
+ assert(isValidExprId(e));
+ tensorExps[e].val = v;
+ }
#ifndef NDEBUG
/// Print methods (for debugging).
@@ -578,8 +573,26 @@ class Merger {
private:
/// Private helpers.
+ constexpr bool isValidTensorId(TensorId t) const { return t < numTensors; }
+ constexpr bool isValidLoopId(LoopId i) const {
+ return i != detail::kInvalidId && i < numLoops;
+ }
+ bool isValidLevel(TensorId t, Level lvl) const {
+ return isValidTensorId(t) && lvl < lvlToLoop[t].size();
+ }
+ bool isValidExprId(ExprId e) const {
+ return e != detail::kInvalidId && e < tensorExps.size();
+ }
+ bool isValidLatPointId(LatPointId p) const {
+ return p != detail::kInvalidId && p < latPoints.size();
+ }
+ bool isValidLatSetId(LatSetId s) const {
+ return s != detail::kInvalidId && s < latSets.size();
+ }
bool maybeZero(ExprId e) const;
- bool isInvariant(ExprId e) const;
+ bool isInvariant(ExprId e) const {
+ return exp(e).kind == TensorExp::Kind::kInvariant;
+ }
Type inferType(ExprId e, Value src) const;
/// Traverses the SSA tree (possibly a DAG) to build a tensor expression.
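A side note on the `makeTensorLoopId`, `tensor`, and `loop` members added in the Merger.h hunks above: a `TensorLoopId` packs a (tensor, loop) pair as `numTensors * i + t`, which is also the bit position used in a `LatPoint` bitvector. Below is a standalone sketch of that round-trip, with the same bounds checks as `isValidTensorId`/`isValidLoopId`; the free-function names are only for illustration.

// Illustrative only: mirrors Merger::makeTensorLoopId / tensor / loop.
#include <cassert>

static unsigned makeTensorLoopId(unsigned numTensors, unsigned numLoops,
                                 unsigned t, unsigned i) {
  assert(t < numTensors && i < numLoops); // isValidTensorId / isValidLoopId
  return numTensors * i + t;
}

static unsigned tensorOf(unsigned numTensors, unsigned b) { return b % numTensors; }
static unsigned loopOf(unsigned numTensors, unsigned b) { return b / numTensors; }

int main() {
  const unsigned numTensors = 4, numLoops = 3;
  for (unsigned i = 0; i < numLoops; ++i) {
    for (unsigned t = 0; t < numTensors; ++t) {
      const unsigned b = makeTensorLoopId(numTensors, numLoops, t, i);
      // The encoding is a bijection onto [0, numTensors * numLoops), so `b`
      // doubles as an index into a LatPoint bitvector of that size.
      assert(b < numTensors * numLoops);
      assert(tensorOf(numTensors, b) == t && loopOf(numTensors, b) == i);
    }
  }
  return 0;
}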
diff --git a/mlir/include/mlir/Dialect/SparseTensor/Utils/MergerNewtypes.h b/mlir/include/mlir/Dialect/SparseTensor/Utils/MergerNewtypes.h
new file mode 100644
index 0000000000000..c481db9511085
--- /dev/null
+++ b/mlir/include/mlir/Dialect/SparseTensor/Utils/MergerNewtypes.h
@@ -0,0 +1,97 @@
+//===- MergerNewtypes.h - Newtypes for the `Merger` class -------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// TODO: This header currently defines some typedefs to avoid confusion
+// between several different things which are all represented as `unsigned`.
+// Over the next few commits, these typedefs will be replaced with "newtypes"
+// (i.e., data types which are zero-cost abstractions for wrapping some
+// underlying type while ensuring that the compiler keeps the new type
+// distinct from the old type), along with related classes for iterating
+// over them, etc.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_SPARSETENSOR_UTILS_MERGERNEWTYPES_H_
+#define MLIR_DIALECT_SPARSETENSOR_UTILS_MERGERNEWTYPES_H_
+
+#include <cassert>
+#include <type_traits>
+
+namespace mlir {
+namespace sparse_tensor {
+
+namespace detail {
+/// A constant serving as the canonically invalid identifier,
+/// regardless of the identifier type.
+static constexpr unsigned kInvalidId = -1u;
+} // namespace detail
+
+//===----------------------------------------------------------------------===//
+/// Tensor identifiers.
+///
+/// Semantically, tensor identifiers could be chosen to be anything;
+/// but operationally, they must be chosen such that the `Merger`
+/// and `GenericOpSparsifier` agree. Therefore, the numeric values of
+/// tensor identifiers are chosen to be the `BlockArgument::getArgNumber`
+/// of the value passed to `Merger::buildTensorExp`, which ranges from
+/// zero to `linalg::GenericOp::getNumOperands` for the op passed to
+/// `GenericOpSparsifier::matchAndRewrite`.
+using TensorId = unsigned;
+
+//===----------------------------------------------------------------------===//
+/// Loop identifiers.
+///
+/// These identifiers serve as proxies for the `$dim` argument to
+/// `linalg::IndexOp`, however the numerical value of a `LoopId` should
+/// not necessarily be equated with the numerical value of the corresponding
+/// `$dim` argument. The `$dim` arguments are De Bruijn indices: that
+/// is, they identify the loop which binds the loop-variable by counting
+/// the enclosing loops from innermost to outermost, starting from zero.
+/// Whereas `LoopId` are considered to be arbitrary names for identifying
+/// loops; since the `Merger` does not care about the actual ordering of
+/// loops, and leaves it up to the `LoopEmitter` to specify the actual
+/// loop ordering (`LoopOrd`).
+///
+/// TODO: Despite the above claim that `$dim` and `LoopId` need not be
+/// numerically equal, some code in the `Merger` class does equate them
+/// (e.g., `buildTensorExp`). So we need to explicate the exact relationship
+/// between `$dim`, `LoopId`, and `LoopOrd`; especially with regards to their
+/// provenance. If `LoopId` really is supposed to be equated with `$dim`,
+/// then we should change the name to `LoopIdx` or similar, to capture the
+/// fact that its numerical value is not invariant when entering/exiting
+/// loops (unlike `TensorId`, `ExprId`, `LatPointId`, and `LatSetId` which
+/// are invariant identifiers).
+using LoopId = unsigned;
+
+//===----------------------------------------------------------------------===//
+/// A compressed representation of `std::pair<TensorId, LoopId>`.
+/// The compression scheme is such that this also serves as an index
+/// into the bitvector stored in `LatPoint` (since that bitvector is
+/// just the implementation for a set of `TensorLoopId` values).
+using TensorLoopId = unsigned;
+
+//===----------------------------------------------------------------------===//
+/// `TensorExp` identifiers. These are allocated by `Merger::addExp`,
+/// and serve as unique identifiers for the corresponding `TensorExp` object.
+using ExprId = unsigned;
+
+//===----------------------------------------------------------------------===//
+/// `LatPoint` identifiers. These are allocated by `Merger::addLat`,
+/// and serve as unique identifiers for the corresponding `LatPoint` object.
+using LatPointId = unsigned;
+
+//===----------------------------------------------------------------------===//
+/// `LatSet` identifiers. These are allocated by `Merger::addSet` (and
+/// by other methods calling that one), and serve as unique identifiers
+/// for the corresponding `SmallVector<LatPointId>` object.
+using LatSetId = unsigned;
+
+} // namespace sparse_tensor
+} // namespace mlir
+
+#endif // MLIR_DIALECT_SPARSETENSOR_UTILS_MERGERNEWTYPES_H_
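One more note on the `detail::kInvalidId` constant that moves into this header: because `-1u` wraps to the largest `unsigned` value, a single sentinel can stand in for "no identifier" across all of the typedefs above, and the `isValid*Id` helpers added to `Merger` reject it along with any out-of-range value. A small standalone illustration follows; the `numLoops` bound is a placeholder value, not taken from the commit.

// Illustrative only: how one sentinel serves every identifier typedef.
#include <cassert>
#include <limits>

static constexpr unsigned kInvalidId = -1u;
static_assert(kInvalidId == std::numeric_limits<unsigned>::max(),
              "-1u wraps to the largest representable unsigned value");

// Mirrors the shape of Merger::isValidLoopId; `numLoops` is a stand-in
// for the corresponding Merger field.
static bool isValidLoopId(unsigned i, unsigned numLoops) {
  return i != kInvalidId && i < numLoops;
}

int main() {
  const unsigned numLoops = 5;
  assert(isValidLoopId(0, numLoops));
  assert(isValidLoopId(numLoops - 1, numLoops));
  assert(!isValidLoopId(kInvalidId, numLoops)); // the sentinel is never valid
  assert(!isValidLoopId(numLoops, numLoops));   // out of range
  return 0;
}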
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenEnv.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenEnv.cpp
index 5631688096033..23b278317f589 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenEnv.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenEnv.cpp
@@ -87,11 +87,13 @@ void CodegenEnv::startEmit() {
SmallVector<Value> tensors; // input tensors passed to loop emitter
for (OpOperand &t : linalgOp->getOpOperands()) {
tensors.push_back(t.get());
- Level rank = linalgOp.getMatchingIndexingMap(&t).getNumResults();
- for (Level lvl = 0; lvl < rank; lvl++) {
- sortArrayBasedOnOrder(
- latticeMerger.getDependentLoops(t.getOperandNumber(), lvl), topSort);
- }
+ const TensorId tid = makeTensorId(t.getOperandNumber());
+ const Level lvlRank = linalgOp.getMatchingIndexingMap(&t).getNumResults();
+ const auto enc = getSparseTensorEncoding(t.get().getType());
+ (void)enc;
+ assert(!enc || lvlRank == enc.getLvlRank());
+ for (Level lvl = 0; lvl < lvlRank; lvl++)
+ sortArrayBasedOnOrder(latticeMerger.getDependentLoops(tid, lvl), topSort);
}
loopEmitter.initialize(
@@ -163,10 +165,7 @@ bool CodegenEnv::isAdmissibleTensorExp(ExprId exp) {
}
OpOperand *lhs = linalgOp.getDpsInitOperand(0);
- // That the operand number is a valid `TensorId` will be verified
- // by the call to `isSingleCondition` below; though we may want to add
- // assertions to check it here, in order to give better error messages.
- const TensorId tensor = lhs->getOperandNumber();
+ const TensorId tensor = makeTensorId(lhs->getOperandNumber());
// A non-annotated output tensor is assumed dense, and becomes a random
// access n-dim memref. Admissible since insertions cannot occur.
if (getSparseTensorType(lhs->get()).isAllDense())
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenEnv.h b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenEnv.h
index b544478801d66..1ee5c19b284fe 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenEnv.h
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenEnv.h
@@ -66,6 +66,15 @@ class CodegenEnv {
// Merger delegates.
//
+ constexpr TensorId makeTensorId(unsigned t) const {
+ return latticeMerger.makeTensorId(t);
+ }
+ constexpr LoopId makeLoopId(unsigned i) const {
+ return latticeMerger.makeLoopId(i);
+ }
+ constexpr TensorLoopId makeTensorLoopId(unsigned t, unsigned i) const {
+ return latticeMerger.makeTensorLoopId(t, i);
+ }
const TensorExp &exp(ExprId e) const { return latticeMerger.exp(e); }
const LatPoint &lat(LatPointId l) const { return latticeMerger.lat(l); }
ArrayRef<LatPointId> set(LatSetId s) const { return latticeMerger.set(s); }
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.cpp
index 7a5605346f508..51a9c65714f87 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.cpp
@@ -221,7 +221,7 @@ void LoopEmitter::initialize(ValueRange ts, StringAttr loopTag, bool hasOutput,
this->hasOutput = hasOutput;
this->isSparseOut = isSparseOut;
- const TensorId numTensors = ts.size();
+ const unsigned numTensors = ts.size();
this->tensors.assign(ts.begin(), ts.end());
this->lvlTypes.assign(numTensors, std::vector<DimLevelType>());
this->lvlSizes.assign(numTensors, std::vector<Value>());
@@ -420,8 +420,9 @@ Value LoopEmitter::genAffine(OpBuilder &builder, Location loc, AffineExpr a) {
// level-expression, the `getPosition` must in fact be a `Dimension`.
// However, elsewhere we have been led to expect that `loopIdToOrd`
// should be indexed by `LoopId`...
- const LoopId i = a.cast<AffineDimExpr>().getPosition();
- return loopStack[loopIdToOrd[i]].iv;
+ const auto loopId = a.cast<AffineDimExpr>().getPosition();
+ assert(loopId < loopIdToOrd.size());
+ return loopStack[loopIdToOrd[loopId]].iv;
}
case AffineExprKind::Add: {
auto binOp = a.cast<AffineBinaryOpExpr>();
@@ -692,7 +693,7 @@ Operation *LoopEmitter::enterCoIterationOverTensorsAtLvls(
Value cond;
unsigned o = 0;
for (auto [t, lvl] : llvm::zip(tids, lvls)) {
- unsigned tid = t; // Why `t` can not be captured by lambda?
+ const TensorId tid = t; // Why `t` can not be captured by lambda?
const auto lvlTp = lvlTypes[tid][lvl];
if (isCompressedDLT(lvlTp) || isSingletonDLT(lvlTp)) {
const auto reassoc = getCollapseReassociation(tid, lvl);
@@ -896,10 +897,11 @@ void LoopEmitter::prepareLoopOverTensorAtLvl(OpBuilder &builder, Location loc,
// guarantee that segHi is defined: because we only generate segHi
// whenever coiterating, in order to improve code quality for the
// non-coiterating cases.
- const auto theSegHi = segHi[tid][srcLvl - 1];
- highs[tid][srcLvl] = (!isUniqueDLT(lvlTypes[tid][srcLvl - 1]) && theSegHi)
- ? theSegHi
- : builder.create<arith::AddIOp>(loc, pLo, c1);
+ const auto parentSegHi = segHi[tid][srcLvl - 1];
+ highs[tid][srcLvl] =
+ (!isUniqueDLT(lvlTypes[tid][srcLvl - 1]) && parentSegHi)
+ ? parentSegHi
+ : builder.create<arith::AddIOp>(loc, pLo, c1);
return;
}
}
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.h b/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.h
index b5772d6f7a100..f3b5a619b06e7 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.h
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.h
@@ -255,10 +255,10 @@ class LoopEmitter {
Location loc, Value crd,
TensorId tid, Level lvl);
- TensorId getNumTensors() const { return tensors.size(); }
+ unsigned getNumTensors() const { return tensors.size(); }
bool isOutputTensor(TensorId tid) const {
- return hasOutput && tid == static_cast<TensorId>(getNumTensors() - 1);
+ return hasOutput && tid == getNumTensors() - 1;
}
bool isSparseOutput(TensorId tid) const {
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
index 4a3e62ffaca04..7827573b215a1 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
@@ -943,11 +943,12 @@ struct ForeachRewriter : public OpRewritePattern<ForeachOp> {
// one for loop?
// FIXME(wrengr): what is this "ld" supposed to be really?
const Level ld = op.getOrder() ? op.getOrder()->getDimPosition(l) : l;
- loopEmitter.enterNewLoopSeq(rewriter, loc, 0, ld);
+ const SmallVector<TensorId, 1> tids{0};
+ loopEmitter.enterNewLoopSeq(rewriter, loc, tids, ld);
// Note that reduc will be taken care of by loop emitter and get updated
// in place.
- loopEmitter.enterLoopOverTensorAtLvl(rewriter, loc, 0, l, reduc);
+ loopEmitter.enterLoopOverTensorAtLvl(rewriter, loc, tids, l, reduc);
}
SmallVector<Value> lcvs;
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
index 23e30351f220a..64a86aa1a8570 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
@@ -225,7 +225,7 @@ static bool findAffine(Merger &merger, TensorId tid, Level lvl, AffineExpr a,
bool setLvlFormat = true) {
switch (a.getKind()) {
case AffineExprKind::DimId: {
- const LoopId idx = a.cast<AffineDimExpr>().getPosition();
+ const LoopId idx = merger.makeLoopId(a.cast<AffineDimExpr>().getPosition());
if (!isUndefDLT(merger.getDimLevelType(tid, idx)))
return false; // used more than once
@@ -239,7 +239,8 @@ static bool findAffine(Merger &merger, TensorId tid, Level lvl, AffineExpr a,
if (!isDenseDLT(dlt) && setLvlFormat) {
assert(isUndefDLT(merger.getDimLevelType(tid, filterLdx)));
// Use a filter loop for sparse affine expression.
- merger.setLevelAndType(tid, filterLdx++, lvl, dlt);
+ merger.setLevelAndType(tid, filterLdx, lvl, dlt);
+ ++filterLdx;
}
if (auto binOp = a.dyn_cast<AffineBinaryOpExpr>()) {
@@ -279,7 +280,7 @@ static bool findDepIdxSet(Merger &merger, TensorId tensor, Level lvl,
bool isSubExp = false) {
switch (a.getKind()) {
case AffineExprKind::DimId: {
- LoopId ldx = a.cast<AffineDimExpr>().getPosition();
+ const LoopId ldx = merger.makeLoopId(a.cast<AffineDimExpr>().getPosition());
if (!isUndefDLT(merger.getDimLevelType(tensor, ldx)))
return false; // used more than once, e.g., A[i][i]
@@ -408,6 +409,7 @@ static bool findSparseAnnotations(CodegenEnv &env, bool idxReducBased) {
// `filterLdx` may be mutated by `findAffine`.
LoopId filterLdx = env.merger().getStartingFilterLoopId();
for (OpOperand &t : env.op()->getOpOperands()) {
+ const TensorId tid = env.makeTensorId(t.getOperandNumber());
const auto map = env.op().getMatchingIndexingMap(&t);
const auto enc = getSparseTensorEncoding(t.get().getType());
if (enc)
@@ -426,9 +428,9 @@ static bool findSparseAnnotations(CodegenEnv &env, bool idxReducBased) {
// If the current tensor being inspected requires an affine index, it needs
// to be sliced.
for (Level l = 0; l < lvlRank; l++) {
- const TensorId tid = t.getOperandNumber();
- AffineExpr a = map.getResult(toOrigDim(enc, l));
- DimLevelType dlt = enc.getLvlType(l);
+ // FIXME: `toOrigDim` is deprecated.
+ const AffineExpr a = map.getResult(toOrigDim(enc, l));
+ const DimLevelType dlt = enc.getLvlType(l);
if (idxReducBased && needIdxReduc) {
if (!findDepIdxSet(env.merger(), tid, l, a, dlt))
return false; // inadmissible affine expression
@@ -445,19 +447,19 @@ static bool findSparseAnnotations(CodegenEnv &env, bool idxReducBased) {
/// A helper to compute a topological sort. O(n^2) time complexity
/// as we use adj matrix for the graph.
/// The sorted result will put the first Reduction iterator to the
-/// latest possible index.
-/// FIXME(wrengr): correct the above "index"
+/// latest possible `LoopOrd`.
///
/// The `inDegree` is indexed by `LoopId`, and the `adjM` is indexed by
/// `(LoopId,LoopId)`.
-static bool topSortOptimal(CodegenEnv &env, LoopId n,
+static bool topSortOptimal(CodegenEnv &env,
ArrayRef<utils::IteratorType> iteratorTypes,
std::vector<unsigned> &inDegree,
std::vector<std::vector<bool>> &adjM) {
std::vector<LoopId> redIt; // reduce iterator with 0 degree
std::vector<LoopId> parIt; // parallel iterator with 0 degree
std::vector<LoopId> filterIt; // filter loop with 0 degree
- for (LoopId i = 0; i < n; i++) {
+ const LoopId numLoops = env.merger().getNumLoops();
+ for (LoopId i = 0; i < numLoops; i++) {
if (inDegree[i] == 0) {
if (env.merger().isFilterLoop(i))
filterIt.push_back(i);
@@ -493,7 +495,7 @@ static bool topSortOptimal(CodegenEnv &env, LoopId n,
env.topSortPushBack(src);
it.pop_back();
// Update in-degree, and push 0-degree node into worklist.
- for (LoopId dst = 0; dst < n; dst++) {
+ for (LoopId dst = 0; dst < numLoops; dst++) {
if (adjM[src][dst] && --inDegree[dst] == 0) {
if (env.merger().isFilterLoop(dst))
filterIt.push_back(dst);
@@ -504,7 +506,7 @@ static bool topSortOptimal(CodegenEnv &env, LoopId n,
}
}
}
- return env.topSortSize() == n;
+ return env.topSortSize() == numLoops;
}
/// Helper method to add all constraints from the indices in one affine
@@ -535,7 +537,8 @@ static void addAffineOrderings(std::vector<std::vector<bool>> &adjM,
const auto toExpand = a ? a : b;
switch (toExpand.getKind()) {
case AffineExprKind::DimId: {
- std::optional<LoopId> idx = toExpand.cast<AffineDimExpr>().getPosition();
+ const std::optional<LoopId> idx{
+ toExpand.cast<AffineDimExpr>().getPosition()};
if (toExpand == a)
addAffineOrderings(adjM, inDegree, AffineExpr(), b, idx, tidx);
else // toExpand == b
@@ -597,22 +600,26 @@ static void addFilterLoopBasedConstraints(CodegenEnv &env, OpOperand &t,
OpOperand *skip, SortMask mask,
std::vector<std::vector<bool>> &adjM,
std::vector<unsigned> &inDegree) {
- // Get map and encoding.
- auto map = env.op().getMatchingIndexingMap(&t);
- auto enc = getSparseTensorEncoding(t.get().getType());
+ // Get map, encoding, and tensor-identifier.
+ const auto map = env.op().getMatchingIndexingMap(&t);
+ const auto enc = getSparseTensorEncoding(t.get().getType());
+ const TensorId tid = env.makeTensorId(t.getOperandNumber());
// Each tensor expression and optional dimension ordering (row-major
// by default) puts an ordering constraint on the loop indices. For
// example, the tensor expression A_ijk forces the ordering i < j < k
// on the loop indices if no explicit dimension ordering is given.
- for (Level l = 0, rank = map.getNumResults(); l < rank; l++) {
- AffineExpr ta = map.getResult(toOrigDim(enc, l));
- std::optional<LoopId> tldx =
- env.merger().getLoopId(t.getOperandNumber(), l);
+ const Level lvlRank = map.getNumResults();
+ assert(!enc || lvlRank == enc.getLvlRank());
+ for (Level lvl = 0; lvl < lvlRank; lvl++) {
+ // FIXME: `toOrigDim` is deprecated.
+ AffineExpr ta = map.getResult(toOrigDim(enc, lvl));
+ std::optional<LoopId> tldx = env.merger().getLoopId(tid, lvl);
// Filter loops should be constructed after all the dependent loops,
// i.e., d0 + d1 < filter_loop(d0 + d1)
if (tldx && env.merger().isFilterLoop(*tldx)) {
- assert(!ta.isa<AffineDimExpr>() && !isDenseDLT(enc.getDimLevelType()[l]));
+ assert(!ta.isa<AffineDimExpr>() &&
+ !isDenseDLT(enc.getDimLevelType()[lvl]));
addAffineOrderings(adjM, inDegree, ta, AffineExpr(), std::nullopt, tldx);
// Now that the ordering of affine expression is captured by filter
// loop idx, we only need to ensure the affine ordering against filter
@@ -626,10 +633,10 @@ static void addFilterLoopBasedConstraints(CodegenEnv &env, OpOperand &t,
if (&t == skip)
continue;
- if (l > 0) {
- AffineExpr fa = map.getResult(toOrigDim(enc, l - 1));
- std::optional<LoopId> fldx =
- env.merger().getLoopId(t.getOperandNumber(), l - 1);
+ if (lvl > 0) {
+ // FIXME: `toOrigDim` is deprecated.
+ AffineExpr fa = map.getResult(toOrigDim(enc, lvl - 1));
+ std::optional<LoopId> fldx = env.merger().getLoopId(tid, lvl - 1);
// Applying order constraints on every pair of dimExpr between two
// compound affine expressions can sometimes be too strict:
@@ -657,8 +664,8 @@ static void addSliceBasedConstraints(CodegenEnv &env, OpOperand &t,
std::vector<std::vector<bool>> &adjM,
std::vector<unsigned> &inDegree) {
// Get map and encoding.
- auto map = env.op().getMatchingIndexingMap(&t);
- auto enc = getSparseTensorEncoding(t.get().getType());
+ const auto map = env.op().getMatchingIndexingMap(&t);
+ const auto enc = getSparseTensorEncoding(t.get().getType());
// No special treatment for simple indices.
if (getNumNonTrivialIdxExpOnSparseLvls(map, t.get()) == 0)
@@ -674,19 +681,22 @@ static void addSliceBasedConstraints(CodegenEnv &env, OpOperand &t,
// To compute iteration graph for tensor[d0 + d1 + d3, d4 + d5 + d6],
// we require there exist d_x \in {d0, d1, d3} and d_y \in {d4, d5, d6},
// and d_x > d_y && {d0, d1, d3} - d_x > {d4, d5, d6} - d_y
- for (Level lvl = 1, rank = map.getNumResults(); lvl < rank; lvl++) {
- AffineExpr fa = map.getResult(toOrigDim(enc, lvl - 1));
- AffineExpr ta = map.getResult(toOrigDim(enc, lvl));
+ const Level lvlRank = map.getNumResults();
+ assert(!enc || lvlRank == enc.getLvlRank());
+ for (Level lvl = 1; lvl < lvlRank; lvl++) {
+ // FIXME: `toOrigDim` is deprecated.
+ const AffineExpr fa = map.getResult(toOrigDim(enc, lvl - 1));
+ const AffineExpr ta = map.getResult(toOrigDim(enc, lvl));
// This is a heuristic: we pick an arbitrary reduction loop from lhs and
// rhs and use them as d_x and d_y.
finder.walkPostOrder(fa);
- AffineDimExpr fexp = finder.getDimExpr();
- LoopId fldx = fexp.getPosition();
+ const AffineDimExpr fexp = finder.getDimExpr();
+ const LoopId fldx = env.makeLoopId(fexp.getPosition());
finder.walkPostOrder(ta);
- AffineDimExpr texp = finder.getDimExpr();
- LoopId tldx = texp.getPosition();
+ const AffineDimExpr texp = finder.getDimExpr();
+ const LoopId tldx = env.makeLoopId(texp.getPosition());
// d_x > d_y
if (!adjM[fldx][tldx]) {
@@ -701,7 +711,7 @@ static void addSliceBasedConstraints(CodegenEnv &env, OpOperand &t,
// make sure dx and dy are the last;
for (auto fd : fCollector.dims) {
- LoopId f = fd.getPosition();
+ const LoopId f = env.makeLoopId(fd.getPosition());
if (f == fldx)
continue;
if (!adjM[f][fldx]) {
@@ -710,7 +720,7 @@ static void addSliceBasedConstraints(CodegenEnv &env, OpOperand &t,
}
}
for (auto td : tCollector.dims) {
- LoopId t = td.getPosition();
+ const LoopId t = env.makeLoopId(td.getPosition());
if (t == tldx)
continue;
if (!adjM[t][tldx]) {
@@ -728,12 +738,12 @@ static void addSliceBasedConstraints(CodegenEnv &env, OpOperand &t,
// TODO: the evaluation order needs to be ensured to
// support affine multiplication.
for (auto fd : fCollector.dims) {
- LoopId f = fd.getPosition();
+ const LoopId f = env.makeLoopId(fd.getPosition());
if (f == fldx) // skip d_x
continue;
for (auto td : tCollector.dims) {
- LoopId t = td.getPosition();
+ const LoopId t = env.makeLoopId(td.getPosition());
if (t == tldx) // skip d_y
continue;
if (!adjM[f][t]) {
@@ -755,9 +765,10 @@ static bool computeIterationGraph(CodegenEnv &env, SortMask mask,
OpOperand *skip, bool idxReducBased = false) {
// Set up an n x n from/to adjacency matrix of the iteration graph
// for the implicit loop indices i_0 .. i_n-1.
- const LoopId n = env.merger().getNumLoops();
- std::vector<std::vector<bool>> adjM(n, std::vector<bool>(n, false));
- std::vector<unsigned> inDegree(n, 0); // in-degree of each node.
+ const unsigned numLoops = env.merger().getNumLoops();
+ std::vector<std::vector<bool>> adjM(numLoops,
+ std::vector<bool>(numLoops, false));
+ std::vector<unsigned> inDegree(numLoops, 0); // in-degree of each node.
const auto iteratorTypes = env.op().getIteratorTypesArray();
// Iterate over the indexing maps of every tensor in the tensor expression.
for (OpOperand &t : env.op()->getOpOperands()) {
@@ -765,7 +776,7 @@ static bool computeIterationGraph(CodegenEnv &env, SortMask mask,
const auto enc = getSparseTensorEncoding(t.get().getType());
assert(env.op().getMatchingIndexingMap(&t).getNumDims() +
getNumNonTrivialIdxExpOnSparseLvls(env.op()) ==
- n);
+ numLoops);
// Skips dense inputs/outputs when not requested.
const bool isDenseInput = !enc && env.op().isDpsInput(&t);
@@ -778,12 +789,12 @@ static bool computeIterationGraph(CodegenEnv &env, SortMask mask,
// will be skipped more often.
// TODO: Do we really need this?
if (includesUndef(mask)) {
- const TensorId tensor = t.getOperandNumber();
- for (LoopId i = 0; i < n; i++) {
- const auto dltI = env.dlt(tensor, i);
+ const TensorId tid = env.makeTensorId(t.getOperandNumber());
+ for (LoopId i = 0; i < numLoops; i++) {
+ const auto dltI = env.dlt(tid, i);
if (isCompressedDLT(dltI) || isSingletonDLT(dltI)) {
- for (LoopId j = 0; j < n; j++)
- if (isUndefDLT(env.dlt(tensor, j))) {
+ for (LoopId j = 0; j < numLoops; j++)
+ if (isUndefDLT(env.dlt(tid, j))) {
adjM[i][j] = true;
inDegree[j]++;
}
@@ -801,8 +812,8 @@ static bool computeIterationGraph(CodegenEnv &env, SortMask mask,
}
// Topologically sort the iteration graph to determine loop order.
// Report failure for a cyclic iteration graph.
- env.topSortClear(n);
- return topSortOptimal(env, n, iteratorTypes, inDegree, adjM);
+ env.topSortClear(numLoops);
+ return topSortOptimal(env, iteratorTypes, inDegree, adjM);
}
//===----------------------------------------------------------------------===//
@@ -856,16 +867,16 @@ static void genBuffers(CodegenEnv &env, OpBuilder &builder) {
// a "coordinate", or "Ldx", or what). So the function should be renamed
// and/or the documentation expanded in order to clarify.
static Value genIndex(CodegenEnv &env, OpOperand *t) {
- auto map = env.op().getMatchingIndexingMap(t);
+ const auto map = env.op().getMatchingIndexingMap(t);
const auto stt = getSparseTensorType(t->get());
const Level lvlRank = stt.getLvlRank();
assert(static_cast<Level>(map.getNumResults()) == lvlRank);
// FIXME: `toOrigDim` is deprecated.
// FIXME: above we asserted that there are `lvlRank` many results,
// but this is assuming there are in fact `dimRank` many results instead.
- AffineExpr a = map.getResult(toOrigDim(stt.getEncoding(), lvlRank - 1));
+ const AffineExpr a = map.getResult(toOrigDim(stt.getEncoding(), lvlRank - 1));
assert(a.getKind() == AffineExprKind::DimId);
- const LoopId idx = a.cast<AffineDimExpr>().getPosition();
+ const LoopId idx = env.makeLoopId(a.cast<AffineDimExpr>().getPosition());
return env.getLoopVar(idx);
}
@@ -873,7 +884,7 @@ static Value genIndex(CodegenEnv &env, OpOperand *t) {
static Value genSubscript(CodegenEnv &env, OpBuilder &builder, OpOperand *t,
SmallVectorImpl<Value> &args) {
const Location loc = env.op().getLoc();
- const TensorId tid = t->getOperandNumber();
+ const TensorId tid = env.makeTensorId(t->getOperandNumber());
const auto map = env.op().getMatchingIndexingMap(t);
const auto stt = getSparseTensorType(t->get());
if (stt.hasEncoding()) {
@@ -1092,7 +1103,7 @@ static Value relinkBranch(CodegenEnv &env, RewriterBase &rewriter, Block *block,
Value e, LoopId ldx) {
if (Operation *def = e.getDefiningOp()) {
if (auto indexOp = dyn_cast<linalg::IndexOp>(def))
- return env.getLoopVar(indexOp.getDim());
+ return env.getLoopVar(env.makeLoopId(indexOp.getDim()));
if (def->getBlock() == block) {
for (unsigned i = 0, n = def->getNumOperands(); i < n; i++) {
rewriter.updateRootInPlace(def, [&]() {
@@ -1153,7 +1164,7 @@ static void genInvariants(CodegenEnv &env, OpBuilder &builder, ExprId exp,
bool isAtLoop = ldx == ::mlir::sparse_tensor::detail::kInvalidId;
linalg::GenericOp op = env.op();
OpOperand &t = op->getOpOperand(env.exp(exp).tensor);
- auto map = op.getMatchingIndexingMap(&t);
+ const auto map = op.getMatchingIndexingMap(&t);
const auto stt = getSparseTensorType(t.get());
const Level lvlRank = stt.getLvlRank();
assert(static_cast<Level>(map.getNumResults()) == lvlRank);
@@ -1161,8 +1172,9 @@ static void genInvariants(CodegenEnv &env, OpBuilder &builder, ExprId exp,
// FIXME: `toOrigDim` is deprecated.
// FIXME: above we asserted that there are `lvlRank` many results,
// but this is assuming there are in fact `dimRank` many results instead.
- AffineExpr a = map.getResult(toOrigDim(stt.getEncoding(), l));
- const auto sldx = env.merger().getLoopId(t.getOperandNumber(), l);
+ const AffineExpr a = map.getResult(toOrigDim(stt.getEncoding(), l));
+ const auto sldx =
+ env.merger().getLoopId(env.makeTensorId(t.getOperandNumber()), l);
if (sldx && env.merger().isFilterLoop(*sldx)) {
if (!env.getLoopVar(*sldx))
// The filter loop has not been constructed.
@@ -1386,29 +1398,28 @@ static void finalizeWhileOp(CodegenEnv &env, OpBuilder &builder, LoopId idx,
/// Generates a single if-statement within a while-loop.
static scf::IfOp genIf(CodegenEnv &env, OpBuilder &builder, LoopId ldx,
- const BitVector &conditions) {
+ LatPointId p) {
Location loc = env.op().getLoc();
SmallVector<Type> types;
Value cond;
- for (TensorLoopId b = 0, be = conditions.size(); b < be; b++) {
- if (!conditions[b])
- continue;
- const TensorId tid = env.merger().tensor(b);
- assert(ldx == env.merger().loop(b));
- Value clause;
- const auto dlt = env.dlt(b);
- if (isCompressedDLT(dlt) || isSingletonDLT(dlt)) {
- const Level lvl = *env.merger().getLvl(tid, ldx);
- const Value crd = env.emitter().getCoords()[tid][lvl];
- const Value lvar = env.getLoopVar(ldx);
- clause = builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq, crd,
- lvar);
- } else {
- assert(isDenseDLT(dlt) || isUndefDLT(dlt));
- clause = constantI1(builder, loc, true);
- }
- cond = cond ? builder.create<arith::AndIOp>(loc, cond, clause) : clause;
- }
+ env.merger().foreachTensorLoopId(
+ p, /*simple=*/true,
+ [&](TensorLoopId b, TensorId tid, std::optional<Level> lvl,
+ DimLevelType dlt, bool /*unused*/) {
+ assert(ldx == env.merger().loop(b));
+ Value clause;
+ if (isCompressedDLT(dlt) || isSingletonDLT(dlt)) {
+ assert(lvl.has_value());
+ const Value crd = env.emitter().getCoords()[tid][*lvl];
+ const Value lvar = env.getLoopVar(ldx);
+ clause = builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq,
+ crd, lvar);
+ } else {
+ assert(isDenseDLT(dlt) || isUndefDLT(dlt));
+ clause = constantI1(builder, loc, true);
+ }
+ cond = cond ? builder.create<arith::AndIOp>(loc, cond, clause) : clause;
+ });
if (env.isReduc()) {
types.push_back(env.getReduc().getType());
if (env.getValidLexInsert())
@@ -1505,7 +1516,7 @@ static void genConstantDenseAddressFromLevel(CodegenEnv &env,
const auto enc = getSparseTensorEncoding(input->get().getType());
if (enc) {
const Location loc = op.getLoc();
- const TensorId tid = input->getOperandNumber();
+ const TensorId tid = env.makeTensorId(input->getOperandNumber());
const Level lvlRank = enc.getLvlRank();
assert(lvlExprs.size() == static_cast<size_t>(lvlRank));
// FIXME: there is dim/lvl confusion here
@@ -1545,7 +1556,7 @@ static bool translateBitsToTidLvlPairs(
env.merger().foreachTensorLoopId(
li, [&, ldx](TensorLoopId b, TensorId tid, std::optional<Level> lvl,
DimLevelType dlt, bool isIdxReduc) {
- if (simple.test(b)) {
+ if (simple[b]) {
if (isIdxReduc) {
tids.push_back(tid);
lvls.push_back(*lvl);
@@ -1634,8 +1645,8 @@ static bool translateBitsToTidLvlPairs(
/// Starts a single loop in current sequence.
static std::pair<Operation *, bool> startLoop(CodegenEnv &env,
- OpBuilder &builder, unsigned at,
- unsigned li, bool needsUniv) {
+ OpBuilder &builder, LoopOrd at,
+ LatPointId li, bool needsUniv) {
// The set of tensors + lvls to generate loops on
SmallVector<TensorId> tids, affineTids;
SmallVector<Level> lvls, affineLvls;
@@ -1748,7 +1759,7 @@ static void genStmt(CodegenEnv &env, RewriterBase &rewriter, ExprId exp,
if (li == lj || env.merger().latGT(li, lj)) {
// Recurse into body of each branch.
if (!isSingleCond) {
- scf::IfOp ifOp = genIf(env, rewriter, idx, env.lat(lj).simple);
+ scf::IfOp ifOp = genIf(env, rewriter, idx, lj);
genStmt(env, rewriter, ej, at + 1);
endIf(env, rewriter, ifOp, loop, redInput, cntInput, insInput);
} else {
@@ -1899,7 +1910,7 @@ struct GenericOpSparsifier : public OpRewritePattern<linalg::GenericOp> {
// sparse input tensor in succession until an acyclic
// iteration graph results.
for (OpOperand *t : env.op().getDpsInputOperands()) {
- const TensorId tid = t->getOperandNumber();
+ const TensorId tid = env.makeTensorId(t->getOperandNumber());
Value tval = t->get();
auto srcEnc = getSparseTensorEncoding(tval.getType());
if (!srcEnc || !computeIterationGraph(env, SortMask::kSparseOnly, t))
diff --git a/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp b/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp
index 9b929a5dda25e..a79bfbb9c1080 100644
--- a/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp
@@ -98,7 +98,8 @@ static ExpArity getExpArity(TensorExp::Kind k) {
// Constructors.
//===----------------------------------------------------------------------===//
-TensorExp::TensorExp(Kind k, unsigned x, ExprId y, Value v, Operation *o)
+TensorExp::TensorExp(TensorExp::Kind k, unsigned x, ExprId y, Value v,
+ Operation *o)
: kind(k), val(v), op(o) {
switch (kind) {
// Leaf.
@@ -200,16 +201,6 @@ TensorExp::TensorExp(Kind k, unsigned x, ExprId y, Value v, Operation *o)
llvm_unreachable("unexpected kind");
}
-LatPoint::LatPoint(const BitVector &bits, ExprId e) : bits(bits), exp(e) {}
-
-LatPoint::LatPoint(unsigned numTensors, unsigned numLoops, TensorId t, LoopId i,
- ExprId e)
- : bits(numLoops * numTensors, false), exp(e) {
- assert(t < numTensors && i < numLoops);
- const TensorLoopId b = numTensors * i + t;
- bits.set(b);
-}
-
Merger::Merger(unsigned numInputOutputTensors, unsigned numNativeLoops,
unsigned numFilterLoops, unsigned maxLvlRank)
: outTensor(numInputOutputTensors - 1),
@@ -232,61 +223,92 @@ Merger::Merger(unsigned numInputOutputTensors, unsigned numNativeLoops,
// Lattice methods.
//===----------------------------------------------------------------------===//
-ExprId Merger::addExp(TensorExp::Kind k, unsigned x, ExprId y, Value v,
- Operation *op) {
- const ExprId e = tensorExps.size();
- assert((k != TensorExp::Kind::kTensor || x < numTensors) &&
- (k != TensorExp::Kind::kLoopVar || x < numLoops));
- tensorExps.emplace_back(k, x, y, v, op);
- return e;
+ExprId Merger::addTensorExp(TensorId t) {
+ assert(isValidTensorId(t));
+ const ExprId eNew(tensorExps.size());
+ tensorExps.emplace_back(TensorExp::Kind::kTensor, t, detail::kInvalidId,
+ Value(), nullptr);
+ return eNew;
+}
+
+ExprId Merger::addLoopVarExp(LoopId i) {
+ assert(isValidLoopId(i));
+ const ExprId eNew(tensorExps.size());
+ tensorExps.emplace_back(TensorExp::Kind::kLoopVar, i, detail::kInvalidId,
+ Value(), nullptr);
+ return eNew;
+}
+
+ExprId Merger::addInvariantExp(Value v) {
+ const ExprId eNew(tensorExps.size());
+ tensorExps.emplace_back(TensorExp::Kind::kInvariant, detail::kInvalidId,
+ detail::kInvalidId, v, nullptr);
+ return eNew;
+}
+
+ExprId Merger::addExp(TensorExp::Kind k, ExprId e0, ExprId e1, Operation *op) {
+ assert(k > TensorExp::Kind::kLoopVar);
+ const ExprId eNew(tensorExps.size());
+ tensorExps.emplace_back(k, e0, e1, Value(), op);
+ return eNew;
+}
+
+ExprId Merger::addExp(TensorExp::Kind k, ExprId e, Value v, Operation *op) {
+ assert(k > TensorExp::Kind::kLoopVar);
+ const ExprId eNew(tensorExps.size());
+ tensorExps.emplace_back(k, e, detail::kInvalidId, v, op);
+ return eNew;
}
LatPointId Merger::addLat(TensorId t, LoopId i, ExprId e) {
- assert(t < numTensors && i < numLoops);
- const LatPointId p = latPoints.size();
- latPoints.emplace_back(numTensors, numLoops, t, i, e);
- return p;
+ const LatPointId pNew(latPoints.size());
+ const unsigned size = numLoops * numTensors;
+ const TensorLoopId b = makeTensorLoopId(t, i);
+ latPoints.emplace_back(size, e);
+ latPoints[pNew].bits.set(b);
+ return pNew;
}
LatPointId Merger::addLat(const BitVector &bits, ExprId e) {
assert(bits.size() == numLoops * numTensors);
- const LatPointId p = latPoints.size();
+ const LatPointId pNew(latPoints.size());
latPoints.emplace_back(bits, e);
- return p;
+ return pNew;
}
LatSetId Merger::addSet() {
- const LatSetId s = latSets.size();
+ const LatSetId sNew(latSets.size());
latSets.emplace_back();
- return s;
+ return sNew;
}
LatPointId Merger::conjLat(TensorExp::Kind kind, LatPointId p0, LatPointId p1,
Operation *op) {
- const LatPointId p = latPoints.size();
- BitVector bits(latPoints[p0].bits);
- bits |= latPoints[p1].bits;
- const ExprId e =
- addExp(kind, latPoints[p0].exp, latPoints[p1].exp, Value(), op);
+ const LatPointId pNew(latPoints.size());
+ const auto &point0 = lat(p0);
+ const auto &point1 = lat(p1);
+ BitVector bits(point0.bits);
+ bits |= point1.bits;
+ const ExprId e = addExp(kind, point0.exp, point1.exp, op);
latPoints.emplace_back(bits, e);
- return p;
+ return pNew;
}
LatSetId Merger::conjSet(TensorExp::Kind kind, LatSetId s0, LatSetId s1,
Operation *op) {
- const LatSetId s = addSet();
- for (const LatPointId p0 : latSets[s0])
- for (const LatPointId p1 : latSets[s1])
- latSets[s].push_back(conjLat(kind, p0, p1, op));
- return s;
+ const LatSetId sNew = addSet();
+ auto &setNew = latSets[sNew];
+ for (const LatPointId p0 : set(s0))
+ for (const LatPointId p1 : set(s1))
+ setNew.push_back(conjLat(kind, p0, p1, op));
+ return sNew;
}
LatSetId Merger::disjSet(TensorExp::Kind kind, LatSetId s0, LatSetId s1,
Operation *op) {
- const LatSetId s = conjSet(kind, s0, s1, op);
+ const LatSetId sNew = conjSet(kind, s0, s1, op);
// Followed by all in s0.
- for (const LatPointId p : latSets[s0])
- latSets[s].push_back(p);
+ latSets[sNew].append(latSets[s0]);
// Map binary 0-y to unary -y.
// TODO: move this if-else logic into buildLattices
if (kind == TensorExp::Kind::kSubF)
@@ -296,9 +318,8 @@ LatSetId Merger::disjSet(TensorExp::Kind kind, LatSetId s0, LatSetId s1,
else if (kind == TensorExp::Kind::kSubI)
s1 = mapSet(TensorExp::Kind::kNegI, s1);
// Followed by all in s1.
- for (const LatPointId p : latSets[s1])
- latSets[s].push_back(p);
- return s;
+ latSets[sNew].append(latSets[s1]);
+ return sNew;
}
LatSetId Merger::combiSet(TensorExp::Kind kind, LatSetId s0, LatSetId s1,
@@ -306,48 +327,48 @@ LatSetId Merger::combiSet(TensorExp::Kind kind, LatSetId s0, LatSetId s1,
TensorExp::Kind ltrans, Operation *opleft,
bool includeRight, TensorExp::Kind rtrans,
Operation *opright) {
- const LatSetId s = conjSet(kind, s0, s1, orig);
+ const LatSetId sNew = conjSet(kind, s0, s1, orig);
// Left Region.
if (includeLeft) {
if (opleft)
s0 = mapSet(ltrans, s0, Value(), opleft);
- for (const LatPointId p : latSets[s0])
- latSets[s].push_back(p);
+ latSets[sNew].append(latSets[s0]);
}
// Right Region.
if (includeRight) {
if (opright)
s1 = mapSet(rtrans, s1, Value(), opright);
- for (const LatPointId p : latSets[s1])
- latSets[s].push_back(p);
+ latSets[sNew].append(latSets[s1]);
}
- return s;
+ return sNew;
}
LatSetId Merger::mapSet(TensorExp::Kind kind, LatSetId s0, Value v,
Operation *op) {
assert(TensorExp::Kind::kAbsF <= kind && kind <= TensorExp::Kind::kSelect);
- const LatSetId s = addSet();
- for (const LatPointId p : latSets[s0]) {
- const ExprId e = addExp(kind, latPoints[p].exp, v, op);
- latSets[s].push_back(addLat(latPoints[p].bits, e));
+ const LatSetId sNew = addSet();
+ auto &setNew = latSets[sNew];
+ for (const LatPointId p : set(s0)) {
+ const auto &point = latPoints[p];
+ setNew.push_back(addLat(point.bits, addExp(kind, point.exp, v, op)));
}
- return s;
+ return sNew;
}
LatSetId Merger::optimizeSet(LatSetId s0) {
- const LatSetId s = addSet();
- assert(!latSets[s0].empty());
- const LatPointId p0 = latSets[s0][0];
- for (const LatPointId p1 : latSets[s0]) {
+ const LatSetId sNew = addSet();
+ auto &setNew = latSets[sNew];
+ const auto &set0 = set(s0);
+ assert(!set0.empty());
+ const LatPointId p0 = set0[0];
+ for (const LatPointId p1 : set0) {
bool add = true;
if (p0 != p1) {
// Check whether this is a straightforward copy.
- const ExprId e = latPoints[p1].exp;
- if (expIsTensor(e, outTensor))
+ if (expIsTensor(latPoints[p1].exp, outTensor))
continue;
// Check whether this conjunction is already covered.
- for (const LatPointId p2 : latSets[s]) {
+ for (const LatPointId p2 : setNew) {
assert(!latGT(p1, p2)); // Lj => Li would be bad
if (onlyDenseDiff(p2, p1)) {
add = false;
@@ -357,34 +378,38 @@ LatSetId Merger::optimizeSet(LatSetId s0) {
assert(!add || latGT(p0, p1));
}
if (add)
- latSets[s].push_back(p1);
+ setNew.push_back(p1);
}
- for (const LatPointId p : latSets[s])
- latPoints[p].simple = simplifyCond(s, p);
- return s;
+ for (const LatPointId p : setNew)
+ latPoints[p].simple = simplifyCond(sNew, p);
+ return sNew;
}
BitVector Merger::simplifyCond(LatSetId s0, LatPointId p0) {
// First determine if this lattice point is a *singleton*, i.e.,
// the last point in a lattice, no other is less than this one.
bool isSingleton = true;
- for (const LatPointId p1 : latSets[s0]) {
+ for (const LatPointId p1 : set(s0)) {
if (p0 != p1 && latGT(p0, p1)) {
isSingleton = false;
break;
}
}
- BitVector simple(latPoints[p0].bits);
+ BitVector simple(lat(p0).bits);
bool reset =
isSingleton && (hasAnySparse(simple) || hasSparseIdxReduction(simple));
- const TensorLoopId be = simple.size();
- TensorLoopId offset = 0; // relative to the end
+ // `be`, `b`, and `offset` are `TensorLoopId` in spirit; but we avoid
+ // using that class in this function because we need to do a bunch of
+ // arithmetic on them, so using the newtype would introduce too much
+ // boilerplate.
+ const unsigned be = simple.size();
+ unsigned offset = 0; // relative to the end
if (!reset)
// Starts resetting from a dense level, so that the first bit (if kept)
// is not undefined level-type.
- for (TensorLoopId b = 0; b < be; b++) {
- if (simple[b] && isDenseDLT(getDimLevelType(b))) {
+ for (unsigned b = 0; b < be; b++) {
+ if (simple[b] && isDenseDLT(getDimLevelType(TensorLoopId{b}))) {
offset = be - b - 1; // relative to the end
break;
}
@@ -392,12 +417,12 @@ BitVector Merger::simplifyCond(LatSetId s0, LatPointId p0) {
// Now apply the two basic rules. We also iterate the bits reversely to always
// keep the rightmost bit (which could possibly be a synthetic tensor).
- for (TensorLoopId b = be - 1 - offset, i = 0; i < be;
+ for (unsigned b = be - 1 - offset, i = 0; i < be;
b = b == 0 ? be - 1 : b - 1, i++) {
// FIXME: better name? also slice on dense level has locate property as
// well. Handle it correctly!
- if (simple[b] && !isLvlWithNonTrivialIdxExp(b)) {
- const auto dlt = getDimLevelType(b);
+ if (simple[b] && !isLvlWithNonTrivialIdxExp(TensorLoopId{b})) {
+ const auto dlt = getDimLevelType(TensorLoopId{b});
if (!isCompressedDLT(dlt) && !isSingletonDLT(dlt)) {
if (reset)
simple.reset(b);
@@ -409,8 +434,8 @@ BitVector Merger::simplifyCond(LatSetId s0, LatPointId p0) {
}
bool Merger::latGT(LatPointId i, LatPointId j) const {
- const BitVector &bitsi = latPoints[i].bits;
- const BitVector &bitsj = latPoints[j].bits;
+ const BitVector &bitsi = lat(i).bits;
+ const BitVector &bitsj = lat(j).bits;
assert(bitsi.size() == bitsj.size());
if (bitsi.count() > bitsj.count()) {
for (TensorLoopId b = 0, be = bitsj.size(); b < be; b++)
@@ -422,27 +447,28 @@ bool Merger::latGT(LatPointId i, LatPointId j) const {
}
bool Merger::onlyDenseDiff(LatPointId i, LatPointId j) const {
- BitVector tmp(latPoints[j].bits);
- tmp ^= latPoints[i].bits;
+ BitVector tmp(lat(j).bits);
+ tmp ^= lat(i).bits;
return !hasAnySparse(tmp) && !hasSparseIdxReduction(tmp);
}
bool Merger::expContainsTensor(ExprId e, TensorId t) const {
- if (tensorExps[e].kind == TensorExp::Kind::kTensor)
- return tensorExps[e].tensor == t;
+ const auto &expr = exp(e);
+ if (expr.kind == TensorExp::Kind::kTensor)
+ return expr.tensor == t;
- switch (getExpArity(tensorExps[e].kind)) {
+ switch (getExpArity(expr.kind)) {
case ExpArity::kNullary:
return false;
case ExpArity::kUnary: {
- const ExprId e0 = tensorExps[e].children.e0;
+ const ExprId e0 = expr.children.e0;
if (expIsTensor(e0, t))
return true;
return expContainsTensor(e0, t);
}
case ExpArity::kBinary: {
- const ExprId e0 = tensorExps[e].children.e0;
- const ExprId e1 = tensorExps[e].children.e1;
+ const ExprId e0 = expr.children.e0;
+ const ExprId e1 = expr.children.e1;
if (expIsTensor(e0, t) || expIsTensor(e1, t))
return true;
return expContainsTensor(e0, t) || expContainsTensor(e1, t);
@@ -452,25 +478,26 @@ bool Merger::expContainsTensor(ExprId e, TensorId t) const {
}
bool Merger::hasNegateOnOut(ExprId e) const {
- switch (tensorExps[e].kind) {
+ const auto &expr = exp(e);
+ switch (expr.kind) {
case TensorExp::Kind::kNegF:
case TensorExp::Kind::kNegC:
case TensorExp::Kind::kNegI:
- return expContainsTensor(tensorExps[e].children.e0, outTensor);
+ return expContainsTensor(expr.children.e0, outTensor);
case TensorExp::Kind::kSubF:
case TensorExp::Kind::kSubC:
case TensorExp::Kind::kSubI:
- return expContainsTensor(tensorExps[e].children.e1, outTensor) ||
- hasNegateOnOut(tensorExps[e].children.e0);
+ return expContainsTensor(expr.children.e1, outTensor) ||
+ hasNegateOnOut(expr.children.e0);
default: {
- switch (getExpArity(tensorExps[e].kind)) {
+ switch (getExpArity(expr.kind)) {
case ExpArity::kNullary:
return false;
case ExpArity::kUnary:
- return hasNegateOnOut(tensorExps[e].children.e0);
+ return hasNegateOnOut(expr.children.e0);
case ExpArity::kBinary:
- return hasNegateOnOut(tensorExps[e].children.e0) ||
- hasNegateOnOut(tensorExps[e].children.e1);
+ return hasNegateOnOut(expr.children.e0) ||
+ hasNegateOnOut(expr.children.e1);
}
}
}
@@ -478,11 +505,12 @@ bool Merger::hasNegateOnOut(ExprId e) const {
}
bool Merger::isSingleCondition(TensorId t, ExprId e) const {
- assert(t < numTensors && e < tensorExps.size());
- switch (tensorExps[e].kind) {
+ assert(isValidTensorId(t));
+ const auto &expr = exp(e);
+ switch (expr.kind) {
// Leaf.
case TensorExp::Kind::kTensor:
- return tensorExps[e].tensor == t;
+ return expr.tensor == t;
case TensorExp::Kind::kInvariant:
case TensorExp::Kind::kLoopVar:
return false;
@@ -518,7 +546,7 @@ bool Merger::isSingleCondition(TensorId t, ExprId e) const {
case TensorExp::Kind::kCIm:
case TensorExp::Kind::kCRe:
case TensorExp::Kind::kBitCast:
- return isSingleCondition(t, tensorExps[e].children.e0);
+ return isSingleCondition(t, expr.children.e0);
case TensorExp::Kind::kBinaryBranch:
case TensorExp::Kind::kUnary:
case TensorExp::Kind::kSelect:
@@ -528,28 +556,28 @@ bool Merger::isSingleCondition(TensorId t, ExprId e) const {
case TensorExp::Kind::kDivC:
case TensorExp::Kind::kDivS:
case TensorExp::Kind::kDivU:
- assert(!maybeZero(tensorExps[e].children.e1));
- return isSingleCondition(t, tensorExps[e].children.e0);
+ assert(!maybeZero(expr.children.e1));
+ return isSingleCondition(t, expr.children.e0);
case TensorExp::Kind::kShrS: // note: x >> inv only
case TensorExp::Kind::kShrU:
case TensorExp::Kind::kShlI:
- assert(isInvariant(tensorExps[e].children.e1));
- return isSingleCondition(t, tensorExps[e].children.e0);
+ assert(isInvariant(expr.children.e1));
+ return isSingleCondition(t, expr.children.e0);
case TensorExp::Kind::kMulF:
case TensorExp::Kind::kMulC:
case TensorExp::Kind::kMulI:
case TensorExp::Kind::kAndI:
- if (isSingleCondition(t, tensorExps[e].children.e0))
- return isSingleCondition(t, tensorExps[e].children.e1) ||
- isInvariant(tensorExps[e].children.e1);
- if (isSingleCondition(t, tensorExps[e].children.e1))
- return isInvariant(tensorExps[e].children.e0);
+ if (isSingleCondition(t, expr.children.e0))
+ return isSingleCondition(t, expr.children.e1) ||
+ isInvariant(expr.children.e1);
+ if (isSingleCondition(t, expr.children.e1))
+ return isInvariant(expr.children.e0);
return false;
case TensorExp::Kind::kAddF:
case TensorExp::Kind::kAddC:
case TensorExp::Kind::kAddI:
- return isSingleCondition(t, tensorExps[e].children.e0) &&
- isSingleCondition(t, tensorExps[e].children.e1);
+ return isSingleCondition(t, expr.children.e0) &&
+ isSingleCondition(t, expr.children.e1);
case TensorExp::Kind::kSubF:
case TensorExp::Kind::kSubC:
case TensorExp::Kind::kSubI:
@@ -684,20 +712,21 @@ static const char *kindToOpSymbol(TensorExp::Kind kind) {
}
void Merger::dumpExp(ExprId e) const {
- switch (tensorExps[e].kind) {
+ const auto &expr = exp(e);
+ switch (expr.kind) {
// Leaf.
case TensorExp::Kind::kTensor:
- if (tensorExps[e].tensor == syntheticTensor)
+ if (expr.tensor == syntheticTensor)
llvm::dbgs() << "synthetic_";
- else if (tensorExps[e].tensor == outTensor)
+ else if (expr.tensor == outTensor)
llvm::dbgs() << "output_";
- llvm::dbgs() << "tensor_" << tensorExps[e].tensor;
+ llvm::dbgs() << "tensor_" << expr.tensor;
break;
case TensorExp::Kind::kInvariant:
llvm::dbgs() << "invariant";
break;
case TensorExp::Kind::kLoopVar:
- llvm::dbgs() << "loopvar_" << tensorExps[e].loop;
+ llvm::dbgs() << "loopvar_" << expr.loop;
break;
// Unary operations.
case TensorExp::Kind::kAbsF:
@@ -734,8 +763,8 @@ void Merger::dumpExp(ExprId e) const {
case TensorExp::Kind::kBinaryBranch:
case TensorExp::Kind::kUnary:
case TensorExp::Kind::kSelect:
- llvm::dbgs() << kindToOpSymbol(tensorExps[e].kind) << " ";
- dumpExp(tensorExps[e].children.e0);
+ llvm::dbgs() << kindToOpSymbol(expr.kind) << " ";
+ dumpExp(expr.children.e0);
break;
// Binary operations.
case TensorExp::Kind::kMulF:
@@ -760,26 +789,28 @@ void Merger::dumpExp(ExprId e) const {
case TensorExp::Kind::kBinary:
case TensorExp::Kind::kReduce:
llvm::dbgs() << "(";
- dumpExp(tensorExps[e].children.e0);
- llvm::dbgs() << " " << kindToOpSymbol(tensorExps[e].kind) << " ";
- dumpExp(tensorExps[e].children.e1);
+ dumpExp(expr.children.e0);
+ llvm::dbgs() << " " << kindToOpSymbol(expr.kind) << " ";
+ dumpExp(expr.children.e1);
llvm::dbgs() << ")";
}
}
void Merger::dumpLat(LatPointId p) const {
+ const auto &point = lat(p);
llvm::dbgs() << "lat(";
- dumpBits(latPoints[p].bits);
+ dumpBits(point.bits);
llvm::dbgs() << " :";
- dumpBits(latPoints[p].simple);
+ dumpBits(point.simple);
llvm::dbgs() << " : ";
- dumpExp(latPoints[p].exp);
+ dumpExp(point.exp);
llvm::dbgs() << " )\n";
}
void Merger::dumpSet(LatSetId s) const {
- llvm::dbgs() << "{ #" << latSets[s].size() << "\n";
- for (const LatPointId p : latSets[s]) {
+ const auto &ss = set(s);
+ llvm::dbgs() << "{ #" << ss.size() << "\n";
+ for (const LatPointId p : ss) {
llvm::dbgs() << " ";
dumpLat(p);
}
@@ -807,7 +838,12 @@ void Merger::dumpBits(const BitVector &bits) const {
//===----------------------------------------------------------------------===//
LatSetId Merger::buildLattices(ExprId e, LoopId i) {
- const TensorExp::Kind kind = tensorExps[e].kind;
+ // NOTE: The `expr` reference will be invalidated by recursive calls
+ // (and any other method that may add new expressions); therefore, the
+ // code below must make sure to copy fields of `expr` into local variables
+ // before making any recursive calls.
+ const auto &expr = exp(e);
+ const TensorExp::Kind kind = expr.kind;
switch (kind) {
// Leaf.
case TensorExp::Kind::kTensor:
@@ -821,7 +857,7 @@ LatSetId Merger::buildLattices(ExprId e, LoopId i) {
const LatSetId s = addSet();
TensorId t = syntheticTensor;
if (kind == TensorExp::Kind::kTensor) {
- t = tensorExps[e].tensor;
+ t = expr.tensor;
if (hasSparseOut && t == outTensor)
t = syntheticTensor;
}
@@ -866,14 +902,20 @@ LatSetId Merger::buildLattices(ExprId e, LoopId i) {
// -y|!y | y |
// --+---+---+
// | 0 |-y |
- return mapSet(kind, buildLattices(tensorExps[e].children.e0, i),
- tensorExps[e].val);
+ {
+ const ExprId e0 = expr.children.e0;
+ const Value v = expr.val;
+ return mapSet(kind, buildLattices(e0, i), v);
+ }
case TensorExp::Kind::kBinaryBranch:
case TensorExp::Kind::kSelect:
// The left or right half of a binary operation which has already
// been split into separate operations for each region.
- return mapSet(kind, buildLattices(tensorExps[e].children.e0, i), Value(),
- tensorExps[e].op);
+ {
+ const ExprId e0 = expr.children.e0;
+ Operation *const op = expr.op;
+ return mapSet(kind, buildLattices(e0, i), Value(), op);
+ }
case TensorExp::Kind::kUnary:
// A custom unary operation.
//
@@ -881,8 +923,9 @@ LatSetId Merger::buildLattices(ExprId e, LoopId i) {
// ----+----------+------------+
// | absent() | present(y) |
{
- const LatSetId child0 = buildLattices(tensorExps[e].children.e0, i);
- UnaryOp unop = cast<UnaryOp>(tensorExps[e].op);
+ const ExprId e0 = expr.children.e0;
+ UnaryOp unop = cast<UnaryOp>(expr.op);
+ const LatSetId child0 = buildLattices(e0, i);
Region &absentRegion = unop.getAbsentRegion();
if (absentRegion.empty()) {
@@ -892,8 +935,8 @@ LatSetId Merger::buildLattices(ExprId e, LoopId i) {
// invariant on the right.
Block &absentBlock = absentRegion.front();
YieldOp absentYield = cast<YieldOp>(absentBlock.getTerminator());
- Value absentVal = absentYield.getResult();
- const ExprId rhs = addExp(TensorExp::Kind::kInvariant, absentVal);
+ const Value absentVal = absentYield.getResult();
+ const ExprId rhs = addInvariantExp(absentVal);
return disjSet(kind, child0, buildLattices(rhs, i), unop);
}
// Binary operations.
@@ -910,8 +953,11 @@ LatSetId Merger::buildLattices(ExprId e, LoopId i) {
// x | 0 |x*y|
//
// Note even here, 0*NaN=NaN and 0*Inf=NaN, but that is ignored.
- return conjSet(kind, buildLattices(tensorExps[e].children.e0, i),
- buildLattices(tensorExps[e].children.e1, i));
+ {
+ const ExprId e0 = expr.children.e0;
+ const ExprId e1 = expr.children.e1;
+ return conjSet(kind, buildLattices(e0, i), buildLattices(e1, i));
+ }
case TensorExp::Kind::kDivF:
case TensorExp::Kind::kDivC:
case TensorExp::Kind::kDivS:
@@ -929,9 +975,12 @@ LatSetId Merger::buildLattices(ExprId e, LoopId i) {
// during expression building, so that the conjunction
 // rule applies (viz. x/c = x*(1/c) as far as lattice
// construction is concerned).
- assert(!maybeZero(tensorExps[e].children.e1));
- return conjSet(kind, buildLattices(tensorExps[e].children.e0, i),
- buildLattices(tensorExps[e].children.e1, i));
+ {
+ const ExprId e0 = expr.children.e0;
+ const ExprId e1 = expr.children.e1;
+ assert(!maybeZero(e1));
+ return conjSet(kind, buildLattices(e0, i), buildLattices(e1, i));
+ }
case TensorExp::Kind::kAddF:
case TensorExp::Kind::kAddC:
case TensorExp::Kind::kAddI:
@@ -947,17 +996,23 @@ LatSetId Merger::buildLattices(ExprId e, LoopId i) {
// ---+---+---+ ---+---+---+
// !x | 0 | y | !x | 0 |-y |
// x | x |x+y| x | x |x-y|
- return disjSet(kind, buildLattices(tensorExps[e].children.e0, i),
- buildLattices(tensorExps[e].children.e1, i));
+ {
+ const ExprId e0 = expr.children.e0;
+ const ExprId e1 = expr.children.e1;
+ return disjSet(kind, buildLattices(e0, i), buildLattices(e1, i));
+ }
case TensorExp::Kind::kShrS:
case TensorExp::Kind::kShrU:
case TensorExp::Kind::kShlI:
// A shift operation by an invariant amount (viz. tensor expressions
// can only occur at the left-hand-side of the operator) can be handled
 // with the conjunction rule.
- assert(isInvariant(tensorExps[e].children.e1));
- return conjSet(kind, buildLattices(tensorExps[e].children.e0, i),
- buildLattices(tensorExps[e].children.e1, i));
+ {
+ const ExprId e0 = expr.children.e0;
+ const ExprId e1 = expr.children.e1;
+ assert(isInvariant(e1));
+ return conjSet(kind, buildLattices(e0, i), buildLattices(e1, i));
+ }
case TensorExp::Kind::kBinary:
// A custom binary operation.
//
@@ -966,9 +1021,11 @@ LatSetId Merger::buildLattices(ExprId e, LoopId i) {
// !x | empty | right(y) |
// x | left(x) | overlap(x,y) |
{
- const LatSetId child0 = buildLattices(tensorExps[e].children.e0, i);
- const LatSetId child1 = buildLattices(tensorExps[e].children.e1, i);
- BinaryOp binop = cast<BinaryOp>(tensorExps[e].op);
+ const ExprId e0 = expr.children.e0;
+ const ExprId e1 = expr.children.e1;
+ BinaryOp binop = cast<BinaryOp>(expr.op);
+ const LatSetId child0 = buildLattices(e0, i);
+ const LatSetId child1 = buildLattices(e1, i);
Region &leftRegion = binop.getLeftRegion();
Region &rightRegion = binop.getRightRegion();
// Left Region.
@@ -991,9 +1048,12 @@ LatSetId Merger::buildLattices(ExprId e, LoopId i) {
}
case TensorExp::Kind::kReduce:
// A custom reduce operation.
- return conjSet(kind, buildLattices(tensorExps[e].children.e0, i),
- buildLattices(tensorExps[e].children.e1, i),
- tensorExps[e].op);
+ {
+ const ExprId e0 = expr.children.e0;
+ const ExprId e1 = expr.children.e1;
+ Operation *const op = expr.op;
+ return conjSet(kind, buildLattices(e0, i), buildLattices(e1, i), op);
+ }
}
llvm_unreachable("unexpected expression kind");
}
@@ -1007,27 +1067,24 @@ std::optional<ExprId> Merger::buildTensorExpFromLinalg(linalg::GenericOp op) {
/// Only returns false if we are certain this is a nonzero.
bool Merger::maybeZero(ExprId e) const {
- if (tensorExps[e].kind == TensorExp::Kind::kInvariant) {
- if (auto c = tensorExps[e].val.getDefiningOp<complex::ConstantOp>()) {
+ const auto &expr = exp(e);
+ if (expr.kind == TensorExp::Kind::kInvariant) {
+ if (auto c = expr.val.getDefiningOp<complex::ConstantOp>()) {
ArrayAttr arrayAttr = c.getValue();
return arrayAttr[0].cast<FloatAttr>().getValue().isZero() &&
arrayAttr[1].cast<FloatAttr>().getValue().isZero();
}
- if (auto c = tensorExps[e].val.getDefiningOp<arith::ConstantIntOp>())
+ if (auto c = expr.val.getDefiningOp<arith::ConstantIntOp>())
return c.value() == 0;
- if (auto c = tensorExps[e].val.getDefiningOp<arith::ConstantFloatOp>())
+ if (auto c = expr.val.getDefiningOp<arith::ConstantFloatOp>())
return c.value().isZero();
}
return true;
}
-bool Merger::isInvariant(ExprId e) const {
- return tensorExps[e].kind == TensorExp::Kind::kInvariant;
-}
-
Type Merger::inferType(ExprId e, Value src) const {
// Obtain the destination type from the cast node.
- Type dtp = tensorExps[e].val.getType();
+ Type dtp = exp(e).val.getType();
// Inspect source type. For vector types, apply the same
// vectorization to the destination type.
if (auto vtp = src.getType().dyn_cast<VectorType>())
@@ -1067,28 +1124,28 @@ static bool isAdmissibleBranch(Operation *op, Region ®ion) {
std::optional<ExprId> Merger::buildTensorExp(linalg::GenericOp op, Value v) {
if (auto arg = v.dyn_cast<BlockArgument>()) {
- const TensorId argN = arg.getArgNumber();
+ const TensorId tid = makeTensorId(arg.getArgNumber());
// Any argument of the generic op that is not marked as a scalar
// argument is considered a tensor, indexed by the implicit loop
// bounds. This includes rank-0 tensor arguments.
if (arg.getOwner()->getParentOp() == op) {
- OpOperand &t = op->getOpOperand(argN);
+ OpOperand &t = op->getOpOperand(tid);
if (!op.isScalar(&t))
- return addExp(TensorExp::Kind::kTensor, argN);
+ return addTensorExp(tid);
v = t.get(); // get scalar value
}
// Any other argument (marked as scalar argument for the generic op
// or belonging to an enveloping op) is considered invariant.
- return addExp(TensorExp::Kind::kInvariant, v);
+ return addInvariantExp(v);
}
// Something defined outside is invariant.
Operation *def = v.getDefiningOp();
if (def->getBlock() != &op.getRegion().front())
- return addExp(TensorExp::Kind::kInvariant, v);
+ return addInvariantExp(v);
// Construct index operations.
if (def->getNumOperands() == 0) {
if (auto indexOp = dyn_cast<linalg::IndexOp>(def))
- return addExp(TensorExp::Kind::kLoopVar, indexOp.getDim());
+ return addLoopVarExp(makeLoopId(indexOp.getDim()));
}
// Construct unary operations if subexpression can be built.
if (def->getNumOperands() == 1) {
@@ -1219,7 +1276,7 @@ std::optional<ExprId> Merger::buildTensorExp(linalg::GenericOp op, Value v) {
isAdmissibleBranch(binop, binop.getLeftRegion())) &&
(binop.getRightIdentity() ||
isAdmissibleBranch(binop, binop.getRightRegion())))
- return addExp(TensorExp::Kind::kBinary, e0, e1, Value(), def);
+ return addExp(TensorExp::Kind::kBinary, e0, e1, def);
}
}
}
@@ -1233,7 +1290,7 @@ std::optional<ExprId> Merger::buildTensorExp(linalg::GenericOp op, Value v) {
const ExprId e1 = *y;
if (auto redop = dyn_cast<sparse_tensor::ReduceOp>(def)) {
if (isAdmissibleBranch(redop, redop.getRegion()))
- return addExp(TensorExp::Kind::kReduce, e0, e1, Value(), def);
+ return addExp(TensorExp::Kind::kReduce, e0, e1, def);
}
}
}
@@ -1288,7 +1345,8 @@ static Value buildBinaryOverlap(RewriterBase &rewriter, Location loc,
Value Merger::buildExp(RewriterBase &rewriter, Location loc, ExprId e, Value v0,
Value v1) const {
- switch (tensorExps[e].kind) {
+ const auto &expr = exp(e);
+ switch (expr.kind) {
// Leaf.
case TensorExp::Kind::kTensor:
case TensorExp::Kind::kInvariant:
@@ -1410,17 +1468,17 @@ Value Merger::buildExp(RewriterBase &rewriter, Location loc, ExprId e, Value v0,
case TensorExp::Kind::kShlI:
return rewriter.create<arith::ShLIOp>(loc, v0, v1);
case TensorExp::Kind::kBinaryBranch: // semi-ring ops with custom logic.
- return insertYieldOp(rewriter, loc,
- *tensorExps[e].op->getBlock()->getParent(), {v0});
+ return insertYieldOp(rewriter, loc, *expr.op->getBlock()->getParent(),
+ {v0});
case TensorExp::Kind::kUnary:
- return buildUnaryPresent(rewriter, loc, tensorExps[e].op, v0);
+ return buildUnaryPresent(rewriter, loc, expr.op, v0);
case TensorExp::Kind::kSelect:
- return insertYieldOp(rewriter, loc,
- cast<SelectOp>(tensorExps[e].op).getRegion(), {v0});
+ return insertYieldOp(rewriter, loc, cast<SelectOp>(expr.op).getRegion(),
+ {v0});
case TensorExp::Kind::kBinary:
- return buildBinaryOverlap(rewriter, loc, tensorExps[e].op, v0, v1);
+ return buildBinaryOverlap(rewriter, loc, expr.op, v0, v1);
case TensorExp::Kind::kReduce: {
- ReduceOp redOp = cast<ReduceOp>(tensorExps[e].op);
+ ReduceOp redOp = cast<ReduceOp>(expr.op);
return insertYieldOp(rewriter, loc, redOp.getRegion(), {v0, v1});
}
}
diff --git a/mlir/unittests/Dialect/SparseTensor/MergerTest.cpp b/mlir/unittests/Dialect/SparseTensor/MergerTest.cpp
index f6cd3e05fe0fb..38c3b363b585f 100644
--- a/mlir/unittests/Dialect/SparseTensor/MergerTest.cpp
+++ b/mlir/unittests/Dialect/SparseTensor/MergerTest.cpp
@@ -140,21 +140,15 @@ class MergerTestBase : public ::testing::Test {
/*maxRank=*/numLoops) {
tensors.reserve(numTensors);
for (unsigned t = 0; t < numTensors; t++)
- tensors.push_back(merger.addExp(TensorExp::Kind::kTensor, tid(t)));
+ tensors.push_back(merger.addTensorExp(tid(t)));
}
///
/// Expression construction helpers.
///
- TensorId tid(unsigned t) const {
- assert(t < merger.getNumTensors());
- return t;
- }
- LoopId lid(unsigned i) const {
- assert(i < merger.getNumLoops());
- return i;
- }
+ TensorId tid(unsigned t) const { return merger.makeTensorId(t); }
+ LoopId lid(unsigned i) const { return merger.makeLoopId(i); }
ExprId tensor(unsigned t) const {
assert(t < tensors.size());
return tensors[t];
@@ -208,11 +202,9 @@ class MergerTestBase : public ::testing::Test {
/// Converts a vector of (loop, tensor) pairs to a bitvector with the
/// corresponding bits set.
BitVector loopsToBits(const std::vector<std::pair<LoopId, TensorId>> &loops) {
- // NOTE: this `numTensors` includes both the output- and synthetic-tensors.
- const auto numTensors = merger.getNumTensors();
- BitVector testBits = BitVector(numTensors, false);
+ BitVector testBits = BitVector(merger.getNumTensors(), false);
for (auto [loop, tensor] : loops)
- testBits.set(numTensors * loop + tensor);
+ testBits.set(merger.makeTensorLoopId(tensor, loop));
return testBits;
}