[Mlir-commits] [mlir] 74c5420 - [mlir][sparse] moving kInvalidId into "detail" namespace
wren romano
llvmlistbot at llvm.org
Fri Mar 24 15:15:42 PDT 2023
Author: wren romano
Date: 2023-03-24T15:15:34-07:00
New Revision: 74c54206d7abdf2680c0b67265ab0a61bc053f5d
URL: https://github.com/llvm/llvm-project/commit/74c54206d7abdf2680c0b67265ab0a61bc053f5d
DIFF: https://github.com/llvm/llvm-project/commit/74c54206d7abdf2680c0b67265ab0a61bc053f5d.diff
LOG: [mlir][sparse] moving kInvalidId into "detail" namespace
In the next few commits I will convert the various Merger identifier typedefs into newtypes. Once that's done, the `kInvalidId` constant will only be used internally, so it no longer needs to be part of the public `mlir::sparse_tensor` namespace.
Depends On D146673
Reviewed By: aartbik
Differential Revision: https://reviews.llvm.org/D146674
Added:
Modified:
mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h
mlir/lib/Dialect/SparseTensor/Transforms/CodegenEnv.cpp
mlir/lib/Dialect/SparseTensor/Transforms/CodegenEnv.h
mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h b/mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h
index 2c13ad2d9238e..1a11010971f23 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h
+++ b/mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h
@@ -76,9 +76,11 @@ using LatPointId = unsigned;
/// for the corresponding `SmallVector<LatPointId>` object.
using LatSetId = unsigned;
+namespace detail {
/// A constant serving as the canonically invalid identifier, regardless
/// of the identifier type.
static constexpr unsigned kInvalidId = -1u;
+} // namespace detail
/// Tensor expression. Represents an MLIR expression in tensor index notation.
struct TensorExp final {
@@ -272,13 +274,13 @@ class Merger {
/// Constructs a new tensor expression, and returns its identifier.
/// The type of the `e0` argument varies according to the value of the
/// `k` argument, as described by the `TensorExp` ctor.
- ExprId addExp(TensorExp::Kind k, unsigned e0, ExprId e1 = kInvalidId,
+ ExprId addExp(TensorExp::Kind k, unsigned e0, ExprId e1 = detail::kInvalidId,
Value v = Value(), Operation *op = nullptr);
ExprId addExp(TensorExp::Kind k, ExprId e, Value v, Operation *op = nullptr) {
- return addExp(k, e, kInvalidId, v, op);
+ return addExp(k, e, detail::kInvalidId, v, op);
}
ExprId addExp(TensorExp::Kind k, Value v, Operation *op = nullptr) {
- return addExp(k, kInvalidId, kInvalidId, v, op);
+ return addExp(k, detail::kInvalidId, detail::kInvalidId, v, op);
}
/// Constructs a new iteration lattice point, and returns its identifier.
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenEnv.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenEnv.cpp
index 5d9c347b62327..5631688096033 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenEnv.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenEnv.cpp
@@ -56,7 +56,8 @@ CodegenEnv::CodegenEnv(linalg::GenericOp linop, SparsificationOptions opts,
latticeMerger(numTensors, numLoops, numFilterLoops, maxRank),
loopEmitter(), topSort(), sparseOut(nullptr), outerParNest(-1u),
insChain(), expValues(), expFilled(), expAdded(), expCount(), redVal(),
- redExp(kInvalidId), redCustom(kInvalidId), redValidLexInsert() {}
+ redExp(detail::kInvalidId), redCustom(detail::kInvalidId),
+ redValidLexInsert() {}
LogicalResult CodegenEnv::initTensorExp() {
// Builds the tensor expression for the Linalg operation in SSA form.
@@ -277,7 +278,7 @@ void CodegenEnv::endExpand() {
//===----------------------------------------------------------------------===//
void CodegenEnv::startReduc(ExprId exp, Value val) {
- assert(!isReduc() && exp != kInvalidId);
+ assert(!isReduc() && exp != detail::kInvalidId);
redExp = exp;
updateReduc(val);
}
@@ -296,7 +297,7 @@ Value CodegenEnv::endReduc() {
Value val = redVal;
redVal = val;
latticeMerger.clearExprValue(redExp);
- redExp = kInvalidId;
+ redExp = detail::kInvalidId;
return val;
}
@@ -311,7 +312,7 @@ void CodegenEnv::clearValidLexInsert() {
}
void CodegenEnv::startCustomReduc(ExprId exp) {
- assert(!isCustomReduc() && exp != kInvalidId);
+ assert(!isCustomReduc() && exp != detail::kInvalidId);
redCustom = exp;
}
@@ -322,5 +323,5 @@ Value CodegenEnv::getCustomRedId() {
void CodegenEnv::endCustomReduc() {
assert(isCustomReduc());
- redCustom = kInvalidId;
+ redCustom = detail::kInvalidId;
}
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenEnv.h b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenEnv.h
index e11e2428d86c9..b544478801d66 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenEnv.h
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenEnv.h
@@ -134,7 +134,7 @@ class CodegenEnv {
//
void startReduc(ExprId exp, Value val);
- bool isReduc() const { return redExp != kInvalidId; }
+ bool isReduc() const { return redExp != detail::kInvalidId; }
void updateReduc(Value val);
Value getReduc() const { return redVal; }
Value endReduc();
@@ -143,7 +143,7 @@ class CodegenEnv {
Value getValidLexInsert() const { return redValidLexInsert; }
void startCustomReduc(ExprId exp);
- bool isCustomReduc() const { return redCustom != kInvalidId; }
+ bool isCustomReduc() const { return redCustom != detail::kInvalidId; }
Value getCustomRedId();
void endCustomReduc();
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
index 3343a5103671e..23e30351f220a 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
@@ -1111,7 +1111,7 @@ static Value genExp(CodegenEnv &env, RewriterBase &rewriter, ExprId e,
linalg::GenericOp op = env.op();
Location loc = op.getLoc();
- if (e == kInvalidId)
+ if (e == ::mlir::sparse_tensor::detail::kInvalidId)
return Value();
const TensorExp &exp = env.exp(e);
const auto kind = exp.kind;
@@ -1146,11 +1146,11 @@ static Value genExp(CodegenEnv &env, RewriterBase &rewriter, ExprId e,
/// Hoists loop invariant tensor loads for which indices have been exhausted.
static void genInvariants(CodegenEnv &env, OpBuilder &builder, ExprId exp,
LoopId ldx, bool atStart) {
- if (exp == kInvalidId)
+ if (exp == ::mlir::sparse_tensor::detail::kInvalidId)
return;
if (env.exp(exp).kind == TensorExp::Kind::kTensor) {
// Inspect tensor indices.
- bool isAtLoop = ldx == kInvalidId;
+ bool isAtLoop = ldx == ::mlir::sparse_tensor::detail::kInvalidId;
linalg::GenericOp op = env.op();
OpOperand &t = op->getOpOperand(env.exp(exp).tensor);
auto map = op.getMatchingIndexingMap(&t);
@@ -1715,7 +1715,8 @@ static void genStmt(CodegenEnv &env, RewriterBase &rewriter, ExprId exp,
// Construct iteration lattices for current loop index, with L0 at top.
const LoopId idx = env.topSortAt(at);
- const LoopId ldx = at == 0 ? kInvalidId : env.topSortAt(at - 1);
+ const LoopId ldx = at == 0 ? ::mlir::sparse_tensor::detail::kInvalidId
+ : env.topSortAt(at - 1);
const LatSetId lts =
env.merger().optimizeSet(env.merger().buildLattices(exp, idx));
diff --git a/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp b/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp
index b3ff60882fbfe..9b929a5dda25e 100644
--- a/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp
@@ -103,14 +103,14 @@ TensorExp::TensorExp(Kind k, unsigned x, ExprId y, Value v, Operation *o)
switch (kind) {
// Leaf.
case TensorExp::Kind::kTensor:
- assert(x != kInvalidId && y == kInvalidId && !v && !o);
+ assert(x != detail::kInvalidId && y == detail::kInvalidId && !v && !o);
tensor = x;
return;
case TensorExp::Kind::kInvariant:
- assert(x == kInvalidId && y == kInvalidId && v && !o);
+ assert(x == detail::kInvalidId && y == detail::kInvalidId && v && !o);
return;
case TensorExp::Kind::kLoopVar:
- assert(x != kInvalidId && y == kInvalidId && !v && !o);
+ assert(x != detail::kInvalidId && y == detail::kInvalidId && !v && !o);
loop = x;
return;
// Unary operations.
@@ -134,7 +134,7 @@ TensorExp::TensorExp(Kind k, unsigned x, ExprId y, Value v, Operation *o)
case TensorExp::Kind::kNegI:
case TensorExp::Kind::kCIm:
case TensorExp::Kind::kCRe:
- assert(x != kInvalidId && y == kInvalidId && !v && !o);
+ assert(x != detail::kInvalidId && y == detail::kInvalidId && !v && !o);
children.e0 = x;
children.e1 = y;
return;
@@ -149,20 +149,20 @@ TensorExp::TensorExp(Kind k, unsigned x, ExprId y, Value v, Operation *o)
case TensorExp::Kind::kCastIdx:
case TensorExp::Kind::kTruncI:
case TensorExp::Kind::kBitCast:
- assert(x != kInvalidId && y == kInvalidId && v && !o);
+ assert(x != detail::kInvalidId && y == detail::kInvalidId && v && !o);
children.e0 = x;
children.e1 = y;
return;
case TensorExp::Kind::kBinaryBranch:
case TensorExp::Kind::kSelect:
- assert(x != kInvalidId && y == kInvalidId && !v && o);
+ assert(x != detail::kInvalidId && y == detail::kInvalidId && !v && o);
children.e0 = x;
children.e1 = y;
return;
case TensorExp::Kind::kUnary:
// No assertion on y can be made, as the branching paths involve both
// a unary (`mapSet`) and binary (`disjSet`) pathway.
- assert(x != kInvalidId && !v && o);
+ assert(x != detail::kInvalidId && !v && o);
children.e0 = x;
children.e1 = y;
return;
@@ -186,13 +186,13 @@ TensorExp::TensorExp(Kind k, unsigned x, ExprId y, Value v, Operation *o)
case TensorExp::Kind::kShrS:
case TensorExp::Kind::kShrU:
case TensorExp::Kind::kShlI:
- assert(x != kInvalidId && y != kInvalidId && !v && !o);
+ assert(x != detail::kInvalidId && y != detail::kInvalidId && !v && !o);
children.e0 = x;
children.e1 = y;
return;
case TensorExp::Kind::kBinary:
case TensorExp::Kind::kReduce:
- assert(x != kInvalidId && y != kInvalidId && !v && o);
+ assert(x != detail::kInvalidId && y != detail::kInvalidId && !v && o);
children.e0 = x;
children.e1 = y;
return;
More information about the Mlir-commits
mailing list