<table border="1" cellspacing="0" cellpadding="8">
    <tr>
        <th>Issue</th>
        <td>
            <a href=https://github.com/llvm/llvm-project/issues/62906>62906</a>
        </td>
    </tr>

    <tr>
        <th>Summary</th>
        <td>
            [mlir][MemRef] Teach `fold-memref-alias-ops` about `nvgpu.ldmatrix`
        </td>
    </tr>

    <tr>
      <th>Labels</th>
      <td>
            mlir:memref
      </td>
    </tr>

    <tr>
      <th>Assignees</th>
      <td>
      </td>
    </tr>

    <tr>
      <th>Reporter</th>
      <td>
          qcolombet
      </td>
    </tr>
</table>

<pre>
    The `fold-memref-alias-ops` pass folds view-like operations in the address computation of the related memory operation.
This pass doesn't know about `nvgpu.ldmatrix`.

To reproduce:
```
mlir-opt -fold-memref-alias-ops test.mlir 
```

With `test.mlir`:
```mlir
#map = affine_map<()[s0] -> (-s0 + 4)>
#map1 = affine_map<()[s0] -> (-s0 + 32)>

  func.func @test_ldmatrix(%arg0: memref<4x32x32xf16, 3>, %arg1: index, %arg2: index, %arg3: index) -> vector<4x2xf16> {
    %c0 = arith.constant 0 : index
    %0 = affine.apply #map()[%arg1]
    %1 = affine.apply #map1()[%arg2]
    %2 = affine.apply #map1()[%arg3]
    %subview = memref.subview %arg0[%arg1, %arg2, %arg3] [%0, %1, %2] [1, 1, 1] : memref<4x32x32xf16, 3> to memref<?x?x?xf16, strided<[1024, 32, 1], offset: ?>, 3>
    %3 = nvgpu.ldmatrix %subview[%c0, %c0, %c0] {numTiles = 4 : i32, transpose = false} : memref<?x?x?xf16, strided<[1024, 32, 1], offset: ?>, 3> -> vector<4x2xf16>
    return %3 : vector<4x2xf16>
 }
```

Expected result:
```mlir
func.func @test_ldmatrix(%arg0 : memref<4x32x32xf16, 3>,
    %arg1 : index, %arg2: index, %arg3: index)
    -> vector<4x2xf16> {
  %loaded_val = nvgpu.ldmatrix%arg0[%arg1, %arg2, %arg3] {numTiles = 4 : i32, transpose = false}
      : memref<4x32x32xf16, 3> -> vector<4x2xf16>
  return %loaded_val : vector<4x2xf16>
}
```

Actual result: The IR is left unchanged.

Note: One can use `memref.load` to see how it works:
```mlir
#map = affine_map<()[s0] -> (-s0 + 4)>
#map1 = affine_map<()[s0] -> (-s0 + 32)>

  func.func @test_ld(%arg0: memref<4x32x32xf16, 3>, %arg1: index, %arg2: index, %arg3: index) -> f16 {
    %c0 = arith.constant 0 : index
    %0 = affine.apply #map()[%arg1]
    %1 = affine.apply #map1()[%arg2]
    %2 = affine.apply #map1()[%arg3]
    %subview = memref.subview %arg0[%arg1, %arg2, %arg3] [%0, %1, %2] [1, 1, 1] : memref<4x32x32xf16, 3> to memref<?x?x?xf16, strided<[1024, 32, 1], offset: ?>, 3>
    %3 = memref.load %subview[%c0, %c0, %c0] : memref<?x?x?xf16, strided<[1024, 32, 1], offset: ?>, 3>
    return %3 : f16
  }
```

After `fold-memref-alias-ops`:
```mlir
  func.func @test_ld(%arg0: memref<4x32x32xf16, 3>, %arg1: index, %arg2: index, %arg3: index) -> f16 {
    %0 = memref.load %arg0[%arg1, %arg2, %arg3] : memref<4x32x32xf16, 3>
    return %0 : f16
  }
```
</pre>
<img width="1px" height="1px" alt="" src="http://email.email.llvm.org/o/eJzsV1tv-jYU_zTm5YjIseMADzyUUqQ9bJP-qrTHyiQnxKsTZ7bTy7ef7IQCHYV22qZN-0vBIefm8zsXH1k6p3Yt4pKIFRHriex9bezyt8Jo02zRT7amfF3e1wgkp5XR5bTBxmI1lVpJNzWdIzmFTjoHgevgSeHzVKtHBNOhlV6Z1oFqwdcIsiwtOgeFabreRx6YKrIsaumxhAYbY18Pugmha0Jv7mvlhl1Kg64lbObhsTXPILem98G39mnX9YkuG-mteiE5HTVHfQMWO2vKvkDCRyLJ6fjEz0YrOzWdh-lZnODR-SQIwVn1Yf1F-Tq48yYc-H_YMDIGEuON7IDwNciqUi0-NLIj_JawOWELIlaOErGGKeF3QNh86igQtoIsMPndsY30y0Y4O7ESV4Cqb4skLECyCOPhLajBnJB2FxDBEB_Cb7MXzsJTpTlht8CDRXYLg2gaRFVb4suBxs7Q-BFtMXj6hIU3Nm4wGA_ez1Z7PyFoFnRAbZWvk8K0zsvWQyDurR1L06MQJbLr9CsMsXsL1N5psT5RTD9STN9psvea7LOa_L2m67ehl6L-EOvkjTRm4eDvUWyPQirWMMjQkboXZCMrfo9LoFxJK3hz4BO-edn_RhnnrSqxDEyxSinLoiLb2w9vU1UOfdiI8M1YKfytBkfoPII-7eijkAyYij2o4z8BxGzV9s290uiimWwohsENb2XrOuMwsiqpHZLZO-B_LbAPa_mA2KLvbbsHfnNBmMzWF86eu5cOi3CIWnS99peOnc80-dVyIOz2JG2hEuHP9PvBymcanzChjSyxfHiS-kylfKU7vl4rB1_her9cy_0h8yeILpTA5Qq4KXwv9SH_EAb3D99AOdBYeejbopbtDsuT4fiT8WEows8tQiFb6F2c9uOhEzwLM94bcIhQm2dQHp6NfXT__cH2D4-0Ks2_j7D_0Qg7aqHPz6-_bxZ9PHKC7f3xevmEqTzaS3eBS2fCv7YH6blsfbo-r3l9Jur0etQn5ZKXC76QE1ym-TxbpCLldFIvmdzms5zncr7gLEszOd_OmRCiSCUKzHCilowyTgXL6EzMBU2ylBa4yFLOxbyimJOMYiOVTrR-ahJjdxPlXI_LnC1oPtFyi9rFCyFjMXn8ZsTHQtdN7DLoTbf9zpGMauW8O1jyyut4m4yaYk3E6kdsvmEVQnWPsqgv3iQ_vs9NequXtfddHDtsQ9hmp3zdb5PCNIRtggfja9pZ8ysWnrBNBOYI20RsvwcAAP__JdkS5A">