[Mlir-commits] [mlir] 7d9677a - [mlir][sparse] Make getNumTensors() consistent between LoopEmitter and Merger.
Peiming Liu
llvmlistbot at llvm.org
Mon Jun 5 10:56:13 PDT 2023
Author: Peiming Liu
Date: 2023-06-05T17:56:08Z
New Revision: 7d9677a9bd4fe147c50a753651342e287dcf2ab5
URL: https://github.com/llvm/llvm-project/commit/7d9677a9bd4fe147c50a753651342e287dcf2ab5
DIFF: https://github.com/llvm/llvm-project/commit/7d9677a9bd4fe147c50a753651342e287dcf2ab5.diff
LOG: [mlir][sparse] Make getNumTensors() consistent between LoopEmitter and Merger.
Reviewed By: aartbik
Differential Revision: https://reviews.llvm.org/D152178
Added:
Modified:
mlir/lib/Dialect/SparseTensor/Transforms/CodegenEnv.h
mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.cpp
mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.h
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenEnv.h b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenEnv.h
index 3186889b77293..dea9e740b8db6 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenEnv.h
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenEnv.h
@@ -89,9 +89,11 @@ class CodegenEnv {
TensorLevel makeTensorLevel(TensorId t, Level l) const {
// Make sure LoopEmitter, GenericOp, and Merger agree on the number of
- // tensors. Merger has one more synthetic tensor for loop invariants.
- assert(loopEmitter.getNumTensors() == linalgOp->getNumOperands() &&
- loopEmitter.getNumTensors() == latticeMerger.getNumTensors() - 1);
+ // tensors.
+ assert(loopEmitter.getNumManifestTensors() == linalgOp->getNumOperands() &&
+ loopEmitter.getNumTensors() == latticeMerger.getNumTensors() &&
+ loopEmitter.getOutTensorId() == latticeMerger.getOutTensorID() &&
+ loopEmitter.getSynTensorId() == latticeMerger.getSynTensorID());
return loopEmitter.makeTensorLevel(t, l);
}
std::pair<TensorId, Level> unpackTensorLevel(TensorLevel tl) const {
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.cpp
index 67f3c30eb4db1..6a639efb2b337 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.cpp
@@ -235,8 +235,9 @@ void LoopEmitter::initialize(ValueRange ts, StringAttr loopTag, bool hasOutput,
const unsigned numManifestTensors = ts.size();
const unsigned synTensorId = numManifestTensors;
const unsigned numTensors = numManifestTensors + 1;
-
+ // tensors array (len == numManifestTensor).
this->tensors.assign(ts.begin(), ts.end());
+ // Arrays with len == numTensor.
this->lvlTypes.assign(numTensors, std::vector<DimLevelType>());
this->lvlSizes.assign(numTensors, std::vector<Value>());
this->highs.assign(numTensors, std::vector<Value>());
@@ -355,13 +356,14 @@ void LoopEmitter::initialize(ValueRange ts, StringAttr loopTag, bool hasOutput,
void LoopEmitter::initializeLoopEmit(OpBuilder &builder, Location loc,
LoopEmitter::OutputUpdater updater) {
- // For every tensor:
+ // For every manifest tensor:
// * get the values buffer.
// * For every level:
// * get the positions and coordinates buffers
// * get/compute the level-size, which is also used as the upper-bound
// on positions.
- for (TensorId t = 0, numTensors = getNumTensors(); t < numTensors; t++) {
+ for (TensorId t = 0, numTensors = getNumManifestTensors(); t < numTensors;
+ t++) {
const Value tensor = tensors[t];
const auto rtp = dyn_cast<RankedTensorType>(tensor.getType());
if (!rtp)
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.h b/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.h
index ded58f2d4b01b..8fa79128889e2 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.h
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.h
@@ -191,21 +191,33 @@ class LoopEmitter {
return n < getCurrentDepth() ? loopStack[n].iv : Value();
}
+ /// Gets the total number of manifest tensors (excluding the synthetic
+ /// tensor).
+ unsigned getNumManifestTensors() const { return tensors.size(); }
+
/// Gets the total number of tensors that loopEmitter is operating on.
- unsigned getNumTensors() const { return tensors.size(); }
+ unsigned getNumTensors() const {
+ // Manifest tensors with one synthetic tensor at the end.
+ return getNumManifestTensors() + 1;
+ }
/// Gets the TensorId for synthetic tensor.
TensorId getSynTensorId() const { return tensors.size(); }
+ /// Gets the TensorId for output tensor.
+ TensorId getOutTensorId() const {
+ assert(hasOutput);
+ return getNumManifestTensors() - 1;
+ }
+
/// Compresses a TensorId and Level into a TensorLevel.
TensorLevel makeTensorLevel(TensorId t, Level l) const {
- // TODO: getNumTensor() should include synthetic tensor.
- return l * (getNumTensors() + 1) + t;
+ return l * getNumTensors() + t;
}
/// De-compresses a TensorLevel back to a pair of TensorId and Level.
std::pair<TensorId, Level> unpackTensorLevel(TensorLevel tidLvl) const {
- unsigned nt = getNumTensors() + 1;
+ unsigned nt = getNumTensors();
return std::make_pair(tidLvl % nt, tidLvl / nt);
}
@@ -323,10 +335,10 @@ class LoopEmitter {
Location loc, Value crd,
TensorId tid, Level lvl);
- bool isSynTensor(TensorId tid) const { return tid == getNumTensors(); }
+ bool isSynTensor(TensorId tid) const { return tid == getSynTensorId(); }
bool isOutputTensor(TensorId tid) const {
- return hasOutput && tid == getNumTensors() - 1;
+ return hasOutput && tid == getOutTensorId();
}
bool isSparseOutput(TensorId tid) const {
@@ -414,8 +426,8 @@ class LoopEmitter {
/// TODO: why not do this computation when we first store the reassoc,
/// instead of doing it every time we look it up?
SmallVector<Level, 2> getCollapseReassociation(TensorId tid, Level dstLvl) {
- assert(tid < getNumTensors() + 1 && "Invalid TensorId");
- assert(collapseReassoc.size() == getNumTensors() + 1);
+ assert(tid < getNumTensors() && "Invalid TensorId");
+ assert(collapseReassoc.size() == getNumTensors());
if (const auto reassoc = collapseReassoc[tid]) {
assert(!isSynTensor(tid) && !isOutputTensor(tid) &&
"Output/Synthetic tensor should not have reassociation");
More information about the Mlir-commits mailing list