[Mlir-commits] [mlir] 34c58c8 - [mlir][sparse] Include sparse emit strategy in wrapping iterator (#165611)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Wed Oct 29 13:40:09 PDT 2025
Author: Jordan Rupprecht
Date: 2025-10-29T15:40:05-05:00
New Revision: 34c58c8b7c3dbf831bb1d26f1624af3e6a56edc7
URL: https://github.com/llvm/llvm-project/commit/34c58c8b7c3dbf831bb1d26f1624af3e6a56edc7
DIFF: https://github.com/llvm/llvm-project/commit/34c58c8b7c3dbf831bb1d26f1624af3e6a56edc7.diff
LOG: [mlir][sparse] Include sparse emit strategy in wrapping iterator (#165611)
When we create a `SparseIterator`, we sometimes wrap it in a
`FilterIterator`, which delegates _some_ calls to the underlying
`SparseIterator`.
After construction, e.g. in `makeNonEmptySubSectIterator()`, we call
`setSparseEmitStrategy()`. This sets the strategy on only one of the two
iterators -- if we call `setSparseEmitStrategy()` immediately after
creating the `SparseIterator`, then the wrapped `SparseIterator` will
have the right strategy, and the `FilterIterator` strategy will be
uninitialized; if we call `setSparseEmitStrategy()` after wrapping the
iterator in `FilterIterator`, then the opposite happens.
If we make `setSparseEmitStrategy()` a virtual method so that it is
included in the `FilterIterator` delegation pattern, and also route all
reads of `emitStrategy` through a virtual getter
(`getSparseEmitStrategy()`), it is straightforward to ensure that
`emitStrategy` is set and read consistently and correctly.
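Without this change, the undefined behavior of reading the uninitialized
strategy manifests as a sporadic test failure in
mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_strided_conv_2d_nhwc_hwcf.mlir
when it is run downstream with particular flags (e.g. ASan enabled and
assertions disabled). The uninitialized value can happen to compare equal
to `SparseEmitStrategy::kDebugInterface`, in which case the iterator emits
generic debug-interface ops named `<prefix>.begin` (see `genInit`) instead
of real IR, so the test sometimes fails with
`'ne_sub<trivial<dense[0,1]>>.begin' op created with unregistered dialect`.
MSan can also directly confirm that this uninitialized read is the cause
of the failure, although MSan causes other problems with this test.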
Added:
Modified:
mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorIterator.cpp
mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorIterator.h
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorIterator.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorIterator.cpp
index 46d0baac58f06..61b5ad600a16e 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorIterator.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorIterator.cpp
@@ -504,6 +504,14 @@ class SimpleWrapIterator : public SparseIterator {
unsigned extraCursorVal = 0)
: SparseIterator(kind, *wrap, extraCursorVal), wrap(std::move(wrap)) {}
+ void setSparseEmitStrategy(SparseEmitStrategy strategy) override {
+ wrap->setSparseEmitStrategy(strategy);
+ }
+
+ SparseEmitStrategy getSparseEmitStrategy() const override {
+ return wrap->getSparseEmitStrategy();
+ }
+
SmallVector<Type> getCursorValTypes(OpBuilder &b) const override {
return wrap->getCursorValTypes(b);
}
@@ -979,7 +987,7 @@ class SubSectIterator : public SparseIterator {
void SparseIterator::genInit(OpBuilder &b, Location l,
const SparseIterator *p) {
- if (emitStrategy == SparseEmitStrategy::kDebugInterface) {
+ if (getSparseEmitStrategy() == SparseEmitStrategy::kDebugInterface) {
std::string prefix = getDebugInterfacePrefix();
Operation *begin = b.create(l, b.getStringAttr(prefix + ".begin"), {},
getCursorValTypes(b));
@@ -994,7 +1002,7 @@ void SparseIterator::genInit(OpBuilder &b, Location l,
}
Value SparseIterator::genNotEnd(OpBuilder &b, Location l) {
- if (emitStrategy == SparseEmitStrategy::kDebugInterface) {
+ if (getSparseEmitStrategy() == SparseEmitStrategy::kDebugInterface) {
std::string prefix = getDebugInterfacePrefix();
Operation *notEnd = b.create(l, b.getStringAttr(prefix + ".not_end"),
getCursor(), b.getI1Type());
@@ -1005,7 +1013,7 @@ Value SparseIterator::genNotEnd(OpBuilder &b, Location l) {
}
void SparseIterator::locate(OpBuilder &b, Location l, Value crd) {
- if (emitStrategy == SparseEmitStrategy::kDebugInterface) {
+ if (getSparseEmitStrategy() == SparseEmitStrategy::kDebugInterface) {
std::string prefix = getDebugInterfacePrefix();
SmallVector<Value> args = getCursor();
args.push_back(crd);
@@ -1019,7 +1027,7 @@ void SparseIterator::locate(OpBuilder &b, Location l, Value crd) {
}
Value SparseIterator::deref(OpBuilder &b, Location l) {
- if (emitStrategy == SparseEmitStrategy::kDebugInterface) {
+ if (getSparseEmitStrategy() == SparseEmitStrategy::kDebugInterface) {
std::string prefix = getDebugInterfacePrefix();
SmallVector<Value> args = getCursor();
Operation *deref = b.create(l, b.getStringAttr(prefix + ".deref"),
@@ -1032,7 +1040,7 @@ Value SparseIterator::deref(OpBuilder &b, Location l) {
ValueRange SparseIterator::forward(OpBuilder &b, Location l) {
assert(!randomAccessible());
- if (emitStrategy == SparseEmitStrategy::kDebugInterface) {
+ if (getSparseEmitStrategy() == SparseEmitStrategy::kDebugInterface) {
std::string prefix = getDebugInterfacePrefix();
Operation *next = b.create(l, b.getStringAttr(prefix + ".next"),
getCursor(), getCursorValTypes(b));
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorIterator.h b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorIterator.h
index 642cb1afa156b..3636f3f01adb5 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorIterator.h
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorIterator.h
@@ -177,10 +177,14 @@ class SparseIterator {
public:
virtual ~SparseIterator() = default;
- void setSparseEmitStrategy(SparseEmitStrategy strategy) {
+ virtual void setSparseEmitStrategy(SparseEmitStrategy strategy) {
emitStrategy = strategy;
}
+ virtual SparseEmitStrategy getSparseEmitStrategy() const {
+ return emitStrategy;
+ }
+
virtual std::string getDebugInterfacePrefix() const = 0;
virtual SmallVector<Type> getCursorValTypes(OpBuilder &b) const = 0;
More information about the Mlir-commits
mailing list