[Mlir-commits] [mlir] [MLIR] Fix affine.prefetch replaceAffineOp invoked during canonicalization (PR #88346)
llvmlistbot at llvm.org
Wed Apr 10 20:25:56 PDT 2024
llvmbot wrote:
@llvm/pr-subscribers-mlir
Author: Alexandre Eichenberger (AlexandreEichenberger)
<details>
<summary>Changes</summary>
There was an error in the canonicalization of `affine.prefetch`. When canonicalization rewrites an `affine.prefetch` op, its `isWrite` and `localityHint` values are swapped, which produces an incorrect prefetch. For example, the following test input
```mlir
func.func @main_graph(%arg0: memref<8x256x512xf32>) -> memref<8x256x512xf32> attributes {input_names = ["x"], llvm.emit_c_interface, output_names = ["output"]} {
  %alloc = memref.alloc() {alignment = 4096 : i64} : memref<8x256x512xf16, affine_map<(d0, d1, d2) -> (d0, d2 floordiv 64, 0, d1 floordiv 32, d1 mod 32, d2 mod 64)>>
  affine.parallel (%arg1) = (0) to (8) {
    affine.for %arg2 = 0 to 256 {
      affine.for %arg3 = 0 to 8 step 2 {
        affine.for %arg5 = 0 to 2 {
          %1 = affine.apply affine_map<()[s0, s1] -> (s0 * 64 + s1 * 64)>()[%arg3, %arg5]
          affine.prefetch %arg0[%arg1, %arg2, %1], read, locality<3>, data : memref<8x256x512xf32>
        }
      }
    }
  }
  return %arg0 : memref<8x256x512xf32>
}
```
was canonicalized into the following prefetch op:
```mlir
affine.prefetch %arg0[%arg1, %arg2, symbol(%arg3) * 64 + symbol(%arg4) * 64], write, locality<0>, data : memref<8x256x512xf32>
```
which is clearly wrong: the original op was a read with locality 3, but the canonicalized op is a write with locality 0.
The issue was that `replaceAffineOp` for `AffinePrefetchOp` passed the `localityHint` and `isWrite` arguments to the op builder in swapped order. No compile-time error was reported because the two types convert implicitly into each other (one is a `bool`, the other an integer).
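To see why the compiler stayed silent, here is a minimal C++ model of the call. It is a sketch, not MLIR code: `PrefetchAttrs`, `buildPrefetch`, `rebuildBuggy`, and `rebuildFixed` are hypothetical names, and the builder is assumed to take its parameters in the order used by the fixed call (`isWrite`, `localityHint`, `isDataCache`).

```cpp
#include <cassert>

// Hypothetical stand-in for the attributes of an affine.prefetch op;
// this is NOT MLIR's AffinePrefetchOp class.
struct PrefetchAttrs {
  bool isWrite;
  unsigned localityHint; // affine.prefetch uses locality<0> .. locality<3>
  bool isDataCache;
};

// Hypothetical builder whose parameter order matches the fixed call in the
// patch: isWrite, then localityHint, then isDataCache.
PrefetchAttrs buildPrefetch(bool isWrite, unsigned localityHint,
                            bool isDataCache) {
  assert(localityHint <= 3 && "locality hint must be in [0, 3]");
  return PrefetchAttrs{isWrite, localityHint, isDataCache};
}

// Mirrors the buggy call: the locality hint lands in the bool parameter
// (3 converts to true) and isWrite lands in the unsigned parameter
// (false converts to 0). Both conversions are implicit, so this compiles.
PrefetchAttrs rebuildBuggy(const PrefetchAttrs &p) {
  return buildPrefetch(p.localityHint, p.isWrite, p.isDataCache);
}

// Mirrors the fixed call: arguments in the builder's declared order.
PrefetchAttrs rebuildFixed(const PrefetchAttrs &p) {
  return buildPrefetch(p.isWrite, p.localityHint, p.isDataCache);
}
```

With a read at `locality<3>` as input, `rebuildBuggy` yields a write at `locality<0>`, matching the broken output above, while `rebuildFixed` preserves the original attributes.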
This patch fixes the issue; with it, the same input canonicalizes to:
```mlir
module {
  func.func @main_graph(%arg0: memref<8x256x512xf32>) -> memref<8x256x512xf32> attributes {input_names = ["x"], llvm.emit_c_interface, output_names = ["output"]} {
    affine.parallel (%arg1) = (0) to (8) {
      affine.for %arg2 = 0 to 256 {
        affine.for %arg3 = 0 to 8 step 2 {
          affine.for %arg4 = 0 to 2 {
            affine.prefetch %arg0[%arg1, %arg2, symbol(%arg3) * 64 + symbol(%arg4) * 64], read, locality<3>, data : memref<8x256x512xf32>
          }
        }
      }
    }
    return %arg0 : memref<8x256x512xf32>
  }
}
```
---
Full diff: https://github.com/llvm/llvm-project/pull/88346.diff
1 Files Affected:
- (modified) mlir/lib/Dialect/Affine/IR/AffineOps.cpp (+2-3)
``````````diff
diff --git a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
index c591e5056480ca..c9c0a7b4cc6860 100644
--- a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
+++ b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
@@ -1487,9 +1487,8 @@ void SimplifyAffineOp<AffinePrefetchOp>::replaceAffineOp(
     PatternRewriter &rewriter, AffinePrefetchOp prefetch, AffineMap map,
     ArrayRef<Value> mapOperands) const {
   rewriter.replaceOpWithNewOp<AffinePrefetchOp>(
-      prefetch, prefetch.getMemref(), map, mapOperands,
-      prefetch.getLocalityHint(), prefetch.getIsWrite(),
-      prefetch.getIsDataCache());
+      prefetch, prefetch.getMemref(), map, mapOperands, prefetch.getIsWrite(),
+      prefetch.getLocalityHint(), prefetch.getIsDataCache());
 }
 template <>
 void SimplifyAffineOp<AffineStoreOp>::replaceAffineOp(
``````````
</details>
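The patch restores the correct argument order, but the compiler still cannot flag such a swap while the two parameters are a `bool` and an integer. As a hedged illustration only (not part of this patch and not MLIR's API), distinct wrapper types would turn a swapped call into a compile-time error; `AccessKind`, `LocalityHint`, `TypedPrefetchAttrs`, and `buildTypedPrefetch` are hypothetical names.

```cpp
#include <cassert>

// Hypothetical strongly typed parameters (not part of MLIR or this patch).
// A scoped enum does not convert implicitly to or from bool/unsigned.
enum class AccessKind { Read, Write };

// Explicit constructor: a bare integer or bool cannot silently become a hint.
struct LocalityHint {
  explicit LocalityHint(unsigned v) : value(v) {
    assert(v <= 3 && "locality hint must be in [0, 3]");
  }
  unsigned value;
};

struct TypedPrefetchAttrs {
  AccessKind access;
  LocalityHint locality;
};

TypedPrefetchAttrs buildTypedPrefetch(AccessKind access, LocalityHint locality) {
  return TypedPrefetchAttrs{access, locality};
}

// Swapping the arguments no longer compiles:
// buildTypedPrefetch(LocalityHint(3), AccessKind::Read);
```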
https://github.com/llvm/llvm-project/pull/88346