[Mlir-commits] [mlir] [MLIR][memref] Fix normalization issue in memref.load (PR #107771)

Kai Sasaki llvmlistbot at llvm.org
Wed Oct 2 19:26:44 PDT 2024


================
@@ -363,3 +363,33 @@ func.func @memref_with_strided_offset(%arg0: tensor<128x512xf32>, %arg1: index,
   %1 = bufferization.to_tensor %cast : memref<16x512xf32, strided<[?, ?], offset: ?>>
   return %1 : tensor<16x512xf32>
 }
+
+#map0 = affine_map<(i,k) -> (2 * (i mod 2) + (k mod 2) + 4 * (i floordiv 2) + 8 * (k floordiv 2))>
+#map1 = affine_map<(k,j) -> ((k mod 2) + 2 * (j mod 2) + 8 * (k floordiv 2) + 4 * (j floordiv 2))>
+#map2 = affine_map<(i,j) -> (4 * i + j)>
+// CHECK-LABEL: func @memref_load_with_reduction_map
+func.func @memref_load_with_reduction_map(%arg0 :  memref<4x4xf32,#map2>) -> () {
+  %0 = memref.alloc() : memref<4x8xf32,#map0>
+  %1 = memref.alloc() : memref<8x4xf32,#map1>
+  %2 = memref.alloc() : memref<4x4xf32,#map2>
+  // CHECK-NOT:  memref<4x8xf32>
+  // CHECK-NOT:  memref<8x4xf32>
+  // CHECK-NOT:  memref<4x4xf32>
+  %cst = arith.constant 3.0 : f32
+  %cst0 = arith.constant 0 : index
+  affine.for %i = 0 to 4 {
+    affine.for %j = 0 to 8 {
+      affine.for %k = 0 to 8 {
+        // CHECK: affine.apply #map{{.*}}(%{{.*}}, %{{.*}})
+        // CHECK: memref.load %alloc[%{{.*}}] : memref<32xf32>
----------------
Lewuathe wrote:

I experimented a bit to make the test pass.

I found we can put the affine map `CHECK-DAG` definitions at the top of the test file and reference them in the function after the `CHECK-LABEL`. That's because the affine maps are defined at the top of the module, I guess.

```
// CHECK-DAG: #[[$REDUCE_MAP1:.*]] = affine_map<(d0, d1) -> ((d0 mod 2) * 2 + d1 mod 2 + (d0 floordiv 2) * 4 + (d1 floordiv 2) * 8)>
// CHECK-DAG: #[[$REDUCE_MAP2:.*]] = affine_map<(d0, d1) -> (d0 mod 2 + (d1 mod 2) * 2 + (d0 floordiv 2) * 8 + (d1 floordiv 2) * 4)>
// CHECK-DAG: #[[$REDUCE_MAP3:.*]] = affine_map<(d0, d1) -> (d0 * 4 + d1)>

...

// CHECK-LABEL: func @memref_load_with_reduction_map
func.func @memref_load_with_reduction_map(%arg0 : memref<4x4xf32,#map2>) -> () {
  %0 = memref.alloc() : memref<4x8xf32,#map0>
  %1 = memref.alloc() : memref<8x4xf32,#map1>
  %2 = memref.alloc() : memref<4x4xf32,#map2>
  // CHECK-NOT:  memref<4x8xf32>
  // CHECK-NOT:  memref<8x4xf32>
  // CHECK-NOT:  memref<4x4xf32>
  %cst = arith.constant 3.0 : f32
  %cst0 = arith.constant 0 : index
  affine.for %i = 0 to 4 {
    affine.for %j = 0 to 8 {
      affine.for %k = 0 to 8 {
        // CHECK: %[[INDEX0:.*]] = affine.apply #[[$REDUCE_MAP1]](%{{.*}}, %{{.*}})
        // CHECK: memref.load %alloc[%[[INDEX0]]] : memref<32xf32>
        %a = memref.load %0[%i, %k] : memref<4x8xf32,#map0>
        // CHECK: %[[INDEX1:.*]] = affine.apply #[[$REDUCE_MAP2]](%{{.*}}, %{{.*}})
        // CHECK: memref.load %alloc_0[%[[INDEX1]]] : memref<32xf32>
        %b = memref.load %1[%k, %j] : memref<8x4xf32,#map1>
        // CHECK: %[[INDEX2:.*]] = affine.apply #[[$REDUCE_MAP3]](%{{.*}}, %{{.*}})
        // CHECK: memref.load %alloc_1[%[[INDEX2]]] : memref<16xf32>
        %c = memref.load %2[%i, %j] : memref<4x4xf32,#map2>
        %3 = arith.mulf %a, %b : f32
        %4 = arith.addf %3, %c : f32
        affine.store %4, %arg0[%i, %j] : memref<4x4xf32,#map2>
      }
    }
  }
  return
}
```
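
To make it concrete why those CHECK lines look the way they do, here is a rough hand-written sketch (not the exact pass output; the names `@before`, `@after`, and `#linear` are just for illustration) of the rewrite that `-normalize-memrefs` performs on the `#map2` load: the layout map is folded away, the memref is linearized, and the index is computed by an explicit `affine.apply` using a map alias printed at the top of the module, which is what the `CHECK-DAG` lines capture.

```
#linear = affine_map<(d0, d1) -> (d0 * 4 + d1)>

// Before normalization: the load goes through a memref with a non-identity layout map.
func.func @before(%i: index, %j: index) -> f32 {
  %m = memref.alloc() : memref<4x4xf32, #linear>
  %v = memref.load %m[%i, %j] : memref<4x4xf32, #linear>
  return %v : f32
}

// After normalization (hand-written approximation): the memref becomes a flat
// memref<16xf32> and the former layout map reappears as an affine.apply on the index.
func.func @after(%i: index, %j: index) -> f32 {
  %m = memref.alloc() : memref<16xf32>
  %idx = affine.apply #linear(%i, %j)
  %v = memref.load %m[%idx] : memref<16xf32>
  return %v : f32
}
```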

https://github.com/llvm/llvm-project/pull/107771