[Mlir-commits] [mlir] 83b7f01 - [mlir][sparse] fix crashes when the tensor that defines the loop bound can not be found
Peiming Liu
llvmlistbot at llvm.org
Wed Jun 14 13:27:57 PDT 2023
Author: Peiming Liu
Date: 2023-06-14T20:27:50Z
New Revision: 83b7f018fd704714fbb4c15081bce3dde48eb707
URL: https://github.com/llvm/llvm-project/commit/83b7f018fd704714fbb4c15081bce3dde48eb707
DIFF: https://github.com/llvm/llvm-project/commit/83b7f018fd704714fbb4c15081bce3dde48eb707.diff
LOG: [mlir][sparse] fix crashes when the tensor that defines the loop bound can not be found
Reviewed By: aartbik, K-Wu
Differential Revision: https://reviews.llvm.org/D152877
Added:
mlir/test/Dialect/SparseTensor/unused-tensor.mlir
Modified:
mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.cpp
mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.h
mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
Removed:
################################################################################
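In short, the crash happened for kernels in which an operand participates in the
indexing maps only to fix a loop extent (e.g., the otherwise unused sparse tensor
B in the new test) but never drives a loop condition, so the loop emitter had no
tensor level from which to derive the bound. The patch threads the dense loop
extents to the synthetic tensor via a new SynTensorBoundSetter callback on
LoopEmitter::initializeLoopEmit, and makes translateBitsToTidLvlPairs fall back
to iterating over the synthetic tensor when no other loop condition remains. The
sketch below is not part of the commit: the helper name
initializeWithSyntheticBounds is made up for illustration, and it indexes the
loop ranges directly, whereas the in-tree caller (genBuffers in
Sparsification.cpp) first re-orders them by the topological loop order and skips
filter loops.

// A minimal sketch (assumptions noted above) of wiring the new
// SynTensorBoundSetter callback into LoopEmitter::initializeLoopEmit.
#include "LoopEmitter.h"  // in-tree: mlir/lib/Dialect/SparseTensor/Transforms/

#include "mlir/Dialect/Arith/Utils/Utils.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"

using namespace mlir;
using namespace mlir::sparse_tensor;

static void initializeWithSyntheticBounds(OpBuilder &builder, Location loc,
                                          linalg::LinalgOp op,
                                          LoopEmitter &emitter) {
  // One Range per loop of the dense iteration space of the linalg op.
  SmallVector<Range, 4> loopRange = op.createLoopRanges(builder, loc);

  emitter.initializeLoopEmit(
      builder, loc,
      /*updater=*/nullptr,
      // The high bound for level lvl of the synthetic tensor is simply the
      // extent of the corresponding dense loop, materialized as an index Value.
      [&](OpBuilder &b, Location l, Level lvl) -> Value {
        return getValueOrCreateConstantIndexOp(b, l, loopRange[lvl].size);
      });
}

With the bounds of the synthetic tensor populated up front, every loop sequence
has a valid upper bound even when no manifest tensor defines one, which is
exactly the situation exercised by the new unused-tensor.mlir test.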
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.cpp
index d0884ca482de2..0a70cb0a970fc 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.cpp
@@ -376,8 +376,15 @@ void LoopEmitter::initialize(ValueRange ts, StringAttr loopTag, bool hasOutput,
loopIdToOrd[topSort[n]] = n;
}
-void LoopEmitter::initializeLoopEmit(OpBuilder &builder, Location loc,
- LoopEmitter::OutputUpdater updater) {
+void LoopEmitter::initializeLoopEmit(
+ OpBuilder &builder, Location loc, LoopEmitter::OutputUpdater updater,
+ LoopEmitter::SynTensorBoundSetter synSetter) {
+
+ // For each level of the synthetic tensor, set the high bound via the callback.
+ if (synSetter)
+ for (unsigned i = 0, e = highs[getSynTensorId()].size(); i < e; i++)
+ highs[getSynTensorId()][i] = synSetter(builder, loc, i);
+
// For every manifest tensor:
// * get the values buffer.
// * For every level:
@@ -534,27 +541,15 @@ void LoopEmitter::enterNewLoopSeq(OpBuilder &builder, Location loc,
// Prepares for all the tensors used in the current loop sequence.
std::vector<std::tuple<TensorId, Level, bool>> slicedTids;
- bool hasSynTensor = false;
- std::optional<std::pair<TensorId, Level>> loopBoundDefLevel = std::nullopt;
for (auto [tid, lvl] : unpackTensorLevelRange(tidLvls)) {
if (!dependentLvlMap[tid][lvl].empty()) {
bool fullyRed = genSliceBegin(builder, loc, tid, lvl);
slicedTids.emplace_back(tid, lvl, fullyRed);
- } else {
- if (isSynTensor(tid)) {
- hasSynTensor = true;
- } else {
- loopBoundDefLevel = std::make_pair(tid, lvl);
- prepareLoopOverTensorAtLvl(builder, loc, tid, lvl);
- }
+ } else if (!isSynTensor(tid)) {
+ prepareLoopOverTensorAtLvl(builder, loc, tid, lvl);
}
}
- if (hasSynTensor && loopBoundDefLevel.has_value()) {
- // TODO: compute the loopBound for index reduction by d - sum(unres_lvls).
- highs[getSynTensorId()][getCurrentDepth()] =
- lvlSizes[loopBoundDefLevel->first][loopBoundDefLevel->second];
- }
// Universal Index starts from 0.
loopSeqStack.emplace_back(C_IDX(0), std::move(slicedTids));
}
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.h b/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.h
index f178366a738a4..4f5100cc36104 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.h
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.h
@@ -78,6 +78,12 @@ class LoopEmitter {
/// initializing the loop emitter (e.g., to fill a dense output with zeros).
using OutputUpdater = function_ref<Value(OpBuilder &builder, Location loc,
Value memref, Value tensor)>;
+
+ /// Optional callback function to set the bound for the synthetic tensor,
+ /// which essentially is the dense loop bound.
+ using SynTensorBoundSetter =
+ function_ref<Value(OpBuilder &builder, Location loc, Level lvl)>;
+
// Map from [tid, dim] to a list of dependent [tid, dim] for affine expression
// index on sparse tensors.
// E.g., for affine index (d0 + d1), it depends on two [tid, dim] that defines
@@ -114,7 +120,8 @@ class LoopEmitter {
/// Starts a loop emitting session by generating all the buffers needed
/// for iterating over the tensors.
void initializeLoopEmit(OpBuilder &builder, Location loc,
- OutputUpdater updater = nullptr);
+ OutputUpdater updater = nullptr,
+ SynTensorBoundSetter synSetter = nullptr);
/// Generates code to compute an affine expression whose variables are
/// `LoopId`s (i.e., `a.cast<AffineDimExpr>().getPosition()` is a valid
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
index 637c16f92a293..7e69a737b0661 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
@@ -832,6 +832,21 @@ static void genBuffers(CodegenEnv &env, OpBuilder &builder) {
Location loc = op.getLoc();
assert(op.getNumOperands() == op.getNumDpsInputs() + 1);
+ SmallVector<Range, 4> loopRange =
+ llvm::cast<linalg::LinalgOp>(op.getOperation())
+ .createLoopRanges(builder, loc);
+
+ assert(loopRange.size() == env.merger().getStartingFilterLoopId());
+ SmallVector<Range, 4> sortedRange;
+ for (unsigned i = 0, e = env.topSortSize(); i < e; i++) {
+ LoopId ldx = env.topSortAt(i);
+ // FIXME: Get rid of filter loops since we have a better algorithm to deal
+ // with affine index expressions.
+ if (ldx < env.merger().getStartingFilterLoopId()) {
+ sortedRange.push_back(loopRange[ldx]);
+ }
+ }
+
env.emitter().initializeLoopEmit(
builder, loc,
/// Generates buffer for the output tensor.
@@ -865,6 +880,16 @@ static void genBuffers(CodegenEnv &env, OpBuilder &builder) {
ValueRange{init});
}
return init;
+ },
+ [&sortedRange, &env](OpBuilder &b, Location loc, Level l) {
+ assert(l < env.topSortSize());
+ // FIXME: Remove filter loops since we have a better algorithm to
+ // deal with affine index expressions.
+ if (l >= env.merger().getStartingFilterLoopId())
+ return Value();
+
+ return mlir::getValueOrCreateConstantIndexOp(b, loc,
+ sortedRange[l].size);
});
}
@@ -1594,7 +1619,9 @@ static bool translateBitsToTidLvlPairs(
// iterate based on the level of output tensor. E.g., this
// could be a synthetic tensor (for invariants and sparse
// output tensor).
- if (env.isReduc() && env.merger().getSynTensorID() == tid) {
+ auto itType = env.op().getIteratorTypesArray()[ldx];
+ if (linalg::isReductionIterator(itType) &&
+ env.merger().getSynTensorID() == tid) {
// Coiterating with an invariant, and this is a reduction loop
// e.g., out = prod(in[i][j] op invariant);
// In this case, we can not infer the loop bound from output
@@ -1669,7 +1696,14 @@ static bool translateBitsToTidLvlPairs(
tidLvls.push_back(env.makeTensorLevel(outTid, *outLvl));
}
- assert(numloopCond > 0);
+ if (numloopCond == 0) {
+ // Corner case where the loop bound is defined by an *unused* operand; in
+ // this case, we just generate a dense "fake" loop by iterating over the
+ // synthetic tensor.
+ tidLvls.push_back(env.makeTensorLevel(env.merger().getSynTensorID(),
+ env.emitter().getCurrentDepth()));
+ numloopCond++;
+ }
// If we just need one loop condition and the condition is not imposed on a
// non-unique level, the loop can be generated by a for loop.
return numloopCond == 1 && !hasNonUnique;
diff --git a/mlir/test/Dialect/SparseTensor/unused-tensor.mlir b/mlir/test/Dialect/SparseTensor/unused-tensor.mlir
new file mode 100644
index 0000000000000..05da6c455135c
--- /dev/null
+++ b/mlir/test/Dialect/SparseTensor/unused-tensor.mlir
@@ -0,0 +1,57 @@
+// RUN: mlir-opt %s -sparsification | FileCheck %s
+
+//
+// A contrived example where the sparse tensor B is only
+// used in the linalg op to determine the number of iterations
+// for the k-loop. This is included to make sure the sparse
+// compiler still generates the correct loop nest for this case.
+//
+
+#SM = #sparse_tensor.encoding<{ lvlTypes = [ "dense", "compressed" ] }>
+
+#trait = {
+ indexing_maps = [
+ affine_map<(i,j,k) -> (i,j)>, // A
+ affine_map<(i,j,k) -> (k,j)>, // B
+ affine_map<(i,j,k) -> (i,j)> // S_out
+ ],
+ iterator_types = ["parallel", "parallel", "reduction"],
+ doc = "C(i,j) = SUM_k A(i,j)"
+}
+
+// CHECK-LABEL: func.func @b_unused(
+// CHECK-SAME: %[[VAL_0:.*]]: tensor<2x4xf64>,
+// CHECK-SAME: %[[VAL_1:.*]]: tensor<8x4xf64, #sparse_tensor.encoding<{{.*}}>>,
+// CHECK-SAME: %[[VAL_2:.*]]: tensor<2x4xf64>) -> tensor<2x4xf64> {
+// CHECK-DAG: %[[VAL_3:.*]] = arith.constant 8 : index
+// CHECK-DAG: %[[VAL_4:.*]] = arith.constant 2 : index
+// CHECK-DAG: %[[VAL_5:.*]] = arith.constant 4 : index
+// CHECK-DAG: %[[VAL_6:.*]] = arith.constant 0 : index
+// CHECK-DAG: %[[VAL_7:.*]] = arith.constant 1 : index
+// CHECK-DAG: %[[VAL_8:.*]] = bufferization.to_memref %[[VAL_0]] : memref<2x4xf64>
+// CHECK-DAG: %[[VAL_9:.*]] = bufferization.to_memref %[[VAL_2]] : memref<2x4xf64>
+// CHECK: scf.for %[[VAL_10:.*]] = %[[VAL_6]] to %[[VAL_4]] step %[[VAL_7]] {
+// CHECK: scf.for %[[VAL_11:.*]] = %[[VAL_6]] to %[[VAL_3]] step %[[VAL_7]] {
+// CHECK: scf.for %[[VAL_12:.*]] = %[[VAL_6]] to %[[VAL_5]] step %[[VAL_7]] {
+// CHECK: %[[VAL_13:.*]] = memref.load %[[VAL_9]]{{\[}}%[[VAL_10]], %[[VAL_12]]] : memref<2x4xf64>
+// CHECK: %[[VAL_14:.*]] = memref.load %[[VAL_8]]{{\[}}%[[VAL_10]], %[[VAL_12]]] : memref<2x4xf64>
+// CHECK: %[[VAL_15:.*]] = arith.addf %[[VAL_13]], %[[VAL_14]] : f64
+// CHECK: memref.store %[[VAL_15]], %[[VAL_9]]{{\[}}%[[VAL_10]], %[[VAL_12]]] : memref<2x4xf64>
+// CHECK: }
+// CHECK: }
+// CHECK: }
+// CHECK: %[[VAL_16:.*]] = bufferization.to_tensor %[[VAL_9]] : memref<2x4xf64>
+// CHECK: return %[[VAL_16]] : tensor<2x4xf64>
+// CHECK: }
+func.func @b_unused(%argA: tensor<2x4xf64>,
+ %argB: tensor<8x4xf64, #SM>,
+ %argC: tensor<2x4xf64>) -> tensor<2x4xf64> {
+ %result = linalg.generic #trait
+ ins(%argA, %argB: tensor<2x4xf64>, tensor<8x4xf64, #SM>)
+ outs(%argC: tensor<2x4xf64>) {
+ ^bb(%a: f64, %b: f64, %c: f64):
+ %0 = arith.addf %c, %a : f64
+ linalg.yield %0 : f64
+ } -> tensor<2x4xf64>
+ return %result : tensor<2x4xf64>
+}