[Mlir-commits] [mlir] [MLIR] Fix affine.prefetch replaceAffineOp invoked during canonicalization (PR #88346)

llvmlistbot at llvm.org
Wed Apr 10 20:25:56 PDT 2024


llvmbot wrote:



@llvm/pr-subscribers-mlir

Author: Alexandre Eichenberger (AlexandreEichenberger)

<details>
<summary>Changes</summary>

There was an error in the canonicalization of `affine.prefetch`: when the pass rewrites a `prefetch` op, the `isWrite` and `localityHint` values are swapped, producing an incorrect prefetch. For example, this test input:

```mlir
func.func @main_graph(%arg0: memref<8x256x512xf32>) -> memref<8x256x512xf32> attributes {input_names = ["x"], llvm.emit_c_interface, output_names = ["output"]} {
  %alloc = memref.alloc() {alignment = 4096 : i64} : memref<8x256x512xf16, affine_map<(d0, d1, d2) -> (d0, d2 floordiv 64, 0, d1 floordiv 32, d1 mod 32, d2 mod 64)>>
  affine.parallel (%arg1) = (0) to (8) {
    affine.for %arg2 = 0 to 256 {
      affine.for %arg3 = 0 to 8 step 2 {
        affine.for %arg5 = 0 to 2 {
          %1 = affine.apply affine_map<()[s0, s1] -> (s0 * 64 + s1 * 64)>()[%arg3, %arg5]
          affine.prefetch %arg0[%arg1, %arg2, %1], read, locality<3>, data : memref<8x256x512xf32>
        }
      }
    }
  }
  return %arg0 : memref<8x256x512xf32>
}
```

was canonicalized into the following prefetch op:
```mlir
affine.prefetch %arg0[%arg1, %arg2, symbol(%arg3) * 64 + symbol(%arg4) * 64], write, locality<0>, data : memref<8x256x512xf32>
```
which is wrong: the original op was a read with locality 3, but the canonicalized op is a write with locality 0.

The issue was that `replaceAffineOp` for `AffinePrefetchOp` passed `localityHint` and `isWrite` to the builder in swapped positions. No compile-time error was reported because the two types convert implicitly into each other (one is a `bool`, the other an integer).
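The mechanism can be reproduced outside MLIR. The sketch below uses a hypothetical `PrefetchAttrs` struct and a hypothetical `buildPrefetch` function (neither is an MLIR API) that only mimic the `(isWrite, localityHint, isDataCache)` parameter order used by the fixed call site. It shows that the pre-patch argument order compiles and turns `read, locality<3>` into `write, locality<0>`, while the post-patch order preserves the attributes.

```cpp
#include <cstdint>

// Hypothetical stand-in for the three prefetch attributes. Not an MLIR type.
struct PrefetchAttrs {
  bool isWrite;
  std::uint32_t localityHint;
  bool isDataCache;
};

// Hypothetical builder that mimics only the parameter order
// (isWrite, localityHint, isDataCache) of the fixed call site.
constexpr PrefetchAttrs buildPrefetch(bool isWrite, std::uint32_t localityHint,
                                      bool isDataCache) {
  return PrefetchAttrs{isWrite, localityHint, isDataCache};
}

// Attributes of the prefetch in the example: read, locality<3>, data.
constexpr PrefetchAttrs original{false, 3, true};

// Pre-patch argument order: the hint (3) is implicitly converted to bool
// (true) and isWrite (false) to the integer 0. This compiles without error.
constexpr PrefetchAttrs buggy = buildPrefetch(
    original.localityHint, original.isWrite, original.isDataCache);

// Post-patch argument order.
constexpr PrefetchAttrs fixed = buildPrefetch(
    original.isWrite, original.localityHint, original.isDataCache);

// Pre-patch: "read, locality<3>" became "write, locality<0>".
static_assert(buggy.isWrite && buggy.localityHint == 0,
              "swapped arguments flip the access kind and zero the hint");
// Post-patch: the attributes are preserved.
static_assert(!fixed.isWrite && fixed.localityHint == 3,
              "correct order preserves the attributes");
```

Because every value is `constexpr`, the `static_assert`s check this behavior at compile time, so no `main` is needed.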

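This patch fixes the issue by passing the arguments in the order the builder expects (`isWrite`, then `localityHint`). With the fix, the test input canonicalizes to: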

```mlir
module {
  func.func @main_graph(%arg0: memref<8x256x512xf32>) -> memref<8x256x512xf32> attributes {input_names = ["x"], llvm.emit_c_interface, output_names = ["output"]} {
    affine.parallel (%arg1) = (0) to (8) {
      affine.for %arg2 = 0 to 256 {
        affine.for %arg3 = 0 to 8 step 2 {
          affine.for %arg4 = 0 to 2 {
            affine.prefetch %arg0[%arg1, %arg2, symbol(%arg3) * 64 + symbol(%arg4) * 64], read, locality<3>, data : memref<8x256x512xf32>
          }
        }
      }
    }
    return %arg0 : memref<8x256x512xf32>
  }
}
```
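The patch itself only reorders the arguments. As a hypothetical hardening that is not part of this patch and not an existing MLIR API, giving the access kind and the locality hint distinct types would turn this kind of swap into a compile-time error. A minimal C++ sketch, assuming made-up names `AccessKind`, `CacheKind`, `LocalityHint`, and `buildTypedPrefetch`:

```cpp
#include <cstdint>
#include <type_traits>

// Hypothetical strong types; these are NOT part of the MLIR API or this patch.
enum class AccessKind { Read, Write };
enum class CacheKind { Instr, Data };

// Wrapping the hint in a struct with an explicit constructor blocks
// implicit conversions from bool or from the enums above.
struct LocalityHint {
  std::uint32_t value;
  constexpr explicit LocalityHint(std::uint32_t v) : value(v) {}
};

struct TypedPrefetchAttrs {
  AccessKind access;
  LocalityHint locality;
  CacheKind cache;
};

// Hypothetical builder taking the strong types in the same order as the
// fixed call site passes (isWrite, localityHint, isDataCache).
constexpr TypedPrefetchAttrs buildTypedPrefetch(AccessKind access,
                                                LocalityHint locality,
                                                CacheKind cache) {
  return TypedPrefetchAttrs{access, locality, cache};
}

// The example's prefetch: read, locality<3>, data.
constexpr TypedPrefetchAttrs typed =
    buildTypedPrefetch(AccessKind::Read, LocalityHint(3), CacheKind::Data);

// With plain bool/integer parameters a swapped call still compiles, because
// an integer converts implicitly to bool...
static_assert(std::is_convertible<std::uint32_t, bool>::value,
              "integer -> bool is implicit");
// ...but neither strong type converts to the other, so a call such as
// buildTypedPrefetch(LocalityHint(3), AccessKind::Read, CacheKind::Data)
// would be rejected by the compiler.
static_assert(!std::is_convertible<LocalityHint, AccessKind>::value,
              "no LocalityHint -> AccessKind conversion");
static_assert(!std::is_convertible<AccessKind, LocalityHint>::value,
              "no AccessKind -> LocalityHint conversion");
```

The trade-off is extra wrapper types at every call site in exchange for the compiler catching argument-order mistakes like the one fixed here.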

---
Full diff: https://github.com/llvm/llvm-project/pull/88346.diff


1 File Affected:

- (modified) mlir/lib/Dialect/Affine/IR/AffineOps.cpp (+2-3) 


``````````diff
diff --git a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
index c591e5056480ca..c9c0a7b4cc6860 100644
--- a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
+++ b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
@@ -1487,9 +1487,8 @@ void SimplifyAffineOp<AffinePrefetchOp>::replaceAffineOp(
     PatternRewriter &rewriter, AffinePrefetchOp prefetch, AffineMap map,
     ArrayRef<Value> mapOperands) const {
   rewriter.replaceOpWithNewOp<AffinePrefetchOp>(
-      prefetch, prefetch.getMemref(), map, mapOperands,
-      prefetch.getLocalityHint(), prefetch.getIsWrite(),
-      prefetch.getIsDataCache());
+      prefetch, prefetch.getMemref(), map, mapOperands, prefetch.getIsWrite(),
+      prefetch.getLocalityHint(), prefetch.getIsDataCache());
 }
 template <>
 void SimplifyAffineOp<AffineStoreOp>::replaceAffineOp(

``````````

</details>


https://github.com/llvm/llvm-project/pull/88346

