[Mlir-commits] [mlir] [mlir] fix infinite while in 1:N dialect conversion (PR #123122)

Maksim Levental llvmlistbot at llvm.org
Wed Jan 15 13:16:50 PST 2025


https://github.com/makslevental created https://github.com/llvm/llvm-project/pull/123122

https://github.com/llvm/llvm-project/pull/116524 introduced a "subtle" bug. I can't give a reproducer unless you're willing to build https://github.com/triton-lang/triton/pull/5329, but here's a sketch:

```
tt.func @conversion1(%arg0: !tt.ptr<f32>) -> tensor<1024xf32> {
    ... 
}
```
```c++
class ConvertFuncOp : public OpConversionPattern<tt::FuncOp> {
public:
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(tt::FuncOp funcOp, OneToNOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    ...
    // Replace the tt.ptr argument with (tt.ptr, arith.constant 0), where the
    // first element is the SAME tt.ptr value, i.e. a 1:N mapping of a value
    // onto a vector that contains that value itself.
    ...
    return success();
  }
};
```

Then, in `ConversionPatternRewriterImpl::remapValues`:

```c++
// mlir/lib/Transforms/Utils/DialectConversion.cpp
ConversionPatternRewriterImpl::remapValues(...) {
  ...
  for (...) {
    Value operand = ...;
    ...
    if (!currentTypeConverter) {
      // The current pattern does not have a type converter. I.e., it does not
      // distinguish between legal and illegal types. For each operand, simply
      // pass through the most recently mapped values.
      remapped.push_back(mapping.lookupOrDefault(operand));
      continue;
    }
    ...
  }
}
```

Then, in `ConversionValueMapping::lookupOrDefault`:

```c++
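ConversionValueMapping::lookupOrDefault(Value from, ...) {
  ValueVector current{from};
  ...
  do {
    ValueVector next;
    for (Value v : current) {
      // For v == tt.ptr this ALWAYS finds the entry
      // tt.ptr -> (tt.ptr, arith.constant).
      auto it = mapping.find({v});
      if (it != mapping.end()) {
        // So tt.ptr is ALWAYS replaced with (tt.ptr, arith.constant), and the
        // new vector again starts with tt.ptr.
        llvm::append_range(next, it->second);
      } else {
        next.push_back(v);
      }
    }
    if (next != current) {
      // At least one value was replaced: keep deepening from `next`.
      current = std::move(next);
      continue;
    }
    ...
  } while (true);
  ...
}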
```

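Result: `current` grows without bound, looking like `(tt.ptr, arith.constant, arith.constant, arith.constant, ...)`, because the leading `tt.ptr` is re-expanded on every iteration. Each pass produces a longer vector, so the `do`/`while` loop never terminates.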
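The growth can be reproduced without MLIR in a small toy model. This is a sketch under stated assumptions, not MLIR code: integers stand in for `mlir::Value`, a `std::map` keyed by value vectors stands in for the conversion mapping, and `deepenOnce`/`growthSizes` are hypothetical helpers modeled on the loop above, assuming the real loop only stops once a step leaves `current` unchanged. The step count is capped so the model itself terminates.

```cpp
#include <cstddef>
#include <map>
#include <utility>
#include <vector>

// Hypothetical stand-ins: small integers play the role of mlir::Value, and a
// std::map keyed by value vectors plays the role of the conversion mapping.
using Value = int;
using ValueVector = std::vector<Value>;
using Mapping = std::map<ValueVector, ValueVector>;

// One "deepening" step, modeled on the loop body of lookupOrDefault: every
// value that has a single-value mapping is replaced by its mapped values.
ValueVector deepenOnce(const Mapping &mapping, const ValueVector &current) {
  ValueVector next;
  for (Value v : current) {
    auto it = mapping.find(ValueVector{v});
    if (it != mapping.end())
      next.insert(next.end(), it->second.begin(), it->second.end());
    else
      next.push_back(v);
  }
  return next;
}

// Runs at most `maxSteps` deepening steps starting from {ptr} with the
// self-referential entry ptr -> (ptr, cst) and records the size of `current`
// after each step. The real loop has no such cap.
std::vector<std::size_t> growthSizes(int maxSteps) {
  const Value ptr = 0, cst = 1;  // stand-ins for tt.ptr and arith.constant 0
  Mapping mapping;
  mapping.emplace(ValueVector{ptr}, ValueVector{ptr, cst});

  std::vector<std::size_t> sizes;
  ValueVector current{ptr};
  for (int step = 0; step < maxSteps; ++step) {
    ValueVector next = deepenOnce(mapping, current);
    if (next == current)
      break;  // fixed point: the lookup would stop here
    current = std::move(next);
    sizes.push_back(current.size());
  }
  return sizes;
}
```

`growthSizes(5)` yields sizes 2, 3, 4, 5, 6: the stand-in for `tt.ptr` at index 0 is expanded again on every step, so no fixed point is ever reached.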

The fix is to check whether we're "deepening" on the same `Value` and, if so, stop following its mapping.
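One way to express "don't keep deepening on the same `Value`" is sketched below in the same toy model. It is hypothetical and is not the change in the attached patch or MLIR's API: `Value`, `ValueVector`, `Mapping` and `lookupDeepest` are illustrative stand-ins. Each value is followed through the mapping at most once per lookup, so a self-referential entry expands exactly once and the loop then reaches a fixed point.

```cpp
#include <map>
#include <set>
#include <utility>
#include <vector>

// Hypothetical stand-ins, as in a toy model of the conversion mapping.
using Value = int;
using ValueVector = std::vector<Value>;
using Mapping = std::map<ValueVector, ValueVector>;

// Guarded deepening loop: each value is expanded through the mapping at most
// once per lookup, so an entry such as ptr -> (ptr, cst) cannot re-expand ptr
// forever. Terminates because the mapping has finitely many keys.
ValueVector lookupDeepest(const Mapping &mapping, Value from) {
  std::set<Value> expanded;  // values whose mapping has already been followed
  ValueVector current{from};
  while (true) {
    ValueVector next;
    for (Value v : current) {
      auto it = mapping.find(ValueVector{v});
      if (it != mapping.end() && !expanded.count(v)) {
        expanded.insert(v);
        next.insert(next.end(), it->second.begin(), it->second.end());
      } else {
        next.push_back(v);
      }
    }
    if (next == current)
      return current;  // no value changed: fixed point reached
    current = std::move(next);
  }
}
```

With the entry `0 -> (0, 1)`, `lookupDeepest(mapping, 0)` returns `(0, 1)`, keeping the 1:N result rather than dropping the mapping, while ordinary chains such as `2 -> 3 -> (4, 5)` still resolve to their deepest values. The cost is one set lookup per visited value.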

Note: I can't add a test for this because the failure mode is an infinite loop.

From 0a3ef6a4d69602ea0c4913e7bd751ba7e2fa95e9 Mon Sep 17 00:00:00 2001
From: Maksim Levental <maksim.levental at gmail.com>
Date: Wed, 15 Jan 2025 16:07:57 -0500
Subject: [PATCH] [mlir] fix infinite while in 1:N dialect conversion

---
 mlir/lib/Transforms/Utils/DialectConversion.cpp | 5 +++--
 1 file changed, 3 insertions(+), 2 deletions(-)

diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp
index 403321d40d53c9..83d66dbe342d3f 100644
--- a/mlir/lib/Transforms/Utils/DialectConversion.cpp
+++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp
@@ -201,8 +201,9 @@ ConversionValueMapping::lookupOrDefault(Value from,
     // If possible, Replace each value with (one or multiple) mapped values.
     ValueVector next;
     for (Value v : current) {
-      auto it = mapping.find({v});
-      if (it != mapping.end()) {
+      ValueVector vv{v};
+      auto it = mapping.find(vv);
+      if (it != mapping.end() && it->first != vv) {
         llvm::append_range(next, it->second);
       } else {
         next.push_back(v);


