[Mlir-commits] [mlir] [mlir] fix infinite while in 1:N dialect conversion (PR #123122)
Maksim Levental
llvmlistbot at llvm.org
Wed Jan 15 13:16:50 PST 2025
https://github.com/makslevental created https://github.com/llvm/llvm-project/pull/123122
https://github.com/llvm/llvm-project/pull/116524 introduced a "subtle" bug. I can't give a standalone repro unless you're willing to build https://github.com/triton-lang/triton/pull/5329, but here's a sketch:
```
tt.func @conversion1(%arg0: !tt.ptr<f32>) -> tensor<1024xf32> {
...
}
```
```c++
class ConvertFuncOp : public OpConversionPattern<tt::FuncOp> {
public:
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(tt::FuncOp funcOp, OneToNOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    ...
    // REPLACE tt.ptr WITH (tt.ptr, arith.constant 0): the 1:N replacement
    // contains the SAME tt.ptr value that it replaces.
    ...
    return success();
  }
};
```
Then, in
```c++
// mlir/lib/Transforms/Utils/DialectConversion.cpp
LogicalResult ConversionPatternRewriterImpl::remapValues(...) {
  ...
  for (...) { // for each operand
    ...
    if (!currentTypeConverter) {
      // The current pattern does not have a type converter. I.e., it does not
      // distinguish between legal and illegal types. For each operand, simply
      // pass through the most recently mapped values.
      remapped.push_back(mapping.lookupOrDefault(operand));
      continue;
    }
    ...
  }
  ...
}
```
which calls
```c++
ValueVector
ConversionValueMapping::lookupOrDefault(...) {
  ...
  do {
    ...
    ValueVector next;
    for (Value v : current) {
      // ALWAYS FINDS tt.ptr AT INDEX 0
      auto it = mapping.find({v});
      if (it != mapping.end()) {
        // ALWAYS REPLACES tt.ptr AT INDEX 0
        // WITH (tt.ptr, arith.constant)
        llvm::append_range(next, it->second);
      } else {
        next.push_back(v);
      }
    }
    ... // the lookup continues from `next` whenever it differs from `current`
  } while (true);
  ...
}
```
Result: `current` grows without bound, e.g. `(tt.ptr, arith.constant, arith.constant, arith.constant, ...)`, and the `do`/`while` loop never terminates.
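To see why `current` grows, here is a minimal standalone model of the per-value replacement step from the snippet above. It is a sketch under stated assumptions, not the MLIR code: plain `int`s stand in for `mlir::Value`, a `std::map` stands in for the value mapping, the model stops only when a round replaces nothing, and `unguardedLookupSizes` and its `maxRounds` cap are hypothetical names added so the demo itself terminates.

```cpp
#include <cassert>
#include <cstddef>
#include <map>
#include <utility>
#include <vector>

using ValueVector = std::vector<int>; // ints stand in for mlir::Value

// Hypothetical model of the unguarded loop: returns the size of `current`
// after each round. It stops early only at a fixed point (a round that
// replaces nothing) and otherwise after `maxRounds` rounds.
std::vector<std::size_t>
unguardedLookupSizes(const std::map<ValueVector, ValueVector> &mapping,
                     int from, int maxRounds) {
  assert(maxRounds > 0 && "need at least one round");
  std::vector<std::size_t> sizes;
  ValueVector current{from};
  for (int round = 0; round < maxRounds; ++round) {
    ValueVector next;
    for (int v : current) {
      // Same shape as the MLIR loop: look up each value on its own.
      auto it = mapping.find(ValueVector{v});
      if (it != mapping.end())
        next.insert(next.end(), it->second.begin(), it->second.end());
      else
        next.push_back(v);
    }
    if (next == current)
      break; // fixed point: nothing was replaced this round
    current = std::move(next);
    sizes.push_back(current.size());
  }
  return sizes;
}
```

With `mapping[ValueVector{1}] = ValueVector{1, 2}` (1 playing `tt.ptr`, 2 playing the `arith.constant`), `unguardedLookupSizes(mapping, 1, 5)` returns `{2, 3, 4, 5, 6}`: every round re-expands value 1 and appends one more constant, so without the cap there is no fixed point.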
The fix is to check whether we're "deepening" on the same `Value`.
Note: I can't add a test for this because the failure mode is an infinite loop.
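One way to express "don't keep deepening on the same `Value`" is to expand each value through the mapping at most once per lookup. The sketch below models that idea with the same stand-ins (`int` for `mlir::Value`, `std::map` for the mapping); `guardedLookup` and its visited set are a hypothetical alternative, not the check in the patch. An aside on the diff below: `mapping.find(vv)` only returns an entry whose key compares equal to `vv`, so `it->first != vv` appears to be false for every found entry, which would skip single-value mappings entirely rather than only self-referencing ones.

```cpp
#include <cassert>
#include <map>
#include <set>
#include <utility>
#include <vector>

using ValueVector = std::vector<int>; // ints stand in for mlir::Value

// Hypothetical guarded variant of the lookup loop: each value is expanded
// through `mapping` at most once, so a mapping such as {ptr} -> {ptr, c0}
// is applied once instead of being re-applied to its own output forever.
// Terminates because every round that changes `current` adds at least one
// new key to `expanded`, and the mapping has finitely many keys.
ValueVector guardedLookup(const std::map<ValueVector, ValueVector> &mapping,
                          int from) {
  ValueVector current{from};
  std::set<int> expanded; // values already replaced in an earlier round
  while (true) {
    ValueVector next;
    std::set<int> expandedThisRound;
    for (int v : current) {
      auto it = mapping.find(ValueVector{v});
      if (it != mapping.end() && expanded.count(v) == 0) {
        next.insert(next.end(), it->second.begin(), it->second.end());
        expandedThisRound.insert(v);
      } else {
        next.push_back(v); // unmapped, or already expanded: keep as-is
      }
    }
    if (next == current)
      return current; // fixed point: nothing left to replace
    expanded.insert(expandedThisRound.begin(), expandedThisRound.end());
    current = std::move(next);
  }
}
```

With `mapping[ValueVector{1}] = ValueVector{1, 2}`, `guardedLookup(mapping, 1)` returns `{1, 2}`: the `tt.ptr` stand-in is replaced once and then left alone, while ordinary chains such as 1 -> 3 -> (4, 5) still resolve to their deepest values because each step introduces new values.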
From 0a3ef6a4d69602ea0c4913e7bd751ba7e2fa95e9 Mon Sep 17 00:00:00 2001
From: Maksim Levental <maksim.levental at gmail.com>
Date: Wed, 15 Jan 2025 16:07:57 -0500
Subject: [PATCH] [mlir] fix infinite while in 1:N dialect conversion
---
mlir/lib/Transforms/Utils/DialectConversion.cpp | 5 +++--
1 file changed, 3 insertions(+), 2 deletions(-)
diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp
index 403321d40d53c9..83d66dbe342d3f 100644
--- a/mlir/lib/Transforms/Utils/DialectConversion.cpp
+++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp
@@ -201,8 +201,9 @@ ConversionValueMapping::lookupOrDefault(Value from,
// If possible, Replace each value with (one or multiple) mapped values.
ValueVector next;
for (Value v : current) {
- auto it = mapping.find({v});
- if (it != mapping.end()) {
+ ValueVector vv{v};
+ auto it = mapping.find(vv);
+ if (it != mapping.end() && it->first != vv) {
llvm::append_range(next, it->second);
} else {
next.push_back(v);