[Mlir-commits] [mlir] [mlir][linalg] added some conditions for values being undefined in the documentation for `linalg.generic` (PR #96251)

Adam Siemieniuk llvmlistbot at llvm.org
Mon Jun 24 02:46:15 PDT 2024


================
@@ -82,6 +82,16 @@ def GenericOp : LinalgStructuredBase_Op<"generic", [
         types:
           parallel, reduction, window
 
+    Any element within the output operands that are not indexed by the 
----------------
adam-smnk wrote:

Looking at that example, I think it's a bufferization issue rather than a matter of op semantics.
In this case, the `arith.constant` is not writable, so bufferization replaces it with a raw `memref.alloc` without copying the constant's contents. As a result, only the indexed elements get updated.

With a bit of hand-holding, we can get the bufferization to happen in place:
```mlir
memref.global "private" @__wrapper_0 : memref<2x2x1xi64> = dense<
          [[[0], 
            [-2]], 
          [[-1], 
            [0]]]>
func.func @entry() -> tensor<2x2x1xi64> {
  %1 = arith.constant dense<
          [[[10, 20]]]> : tensor<1x1x2xi64>
  %2 = arith.constant dense<-1> : tensor<1xi64>
  // Replace with writable global
  // %3 = arith.constant dense<
  //         [[[0], 
  //           [-2]], 
  //         [[-1], 
  //           [0]]]> : tensor<2x2x1xi64>

  %g0 = memref.get_global @__wrapper_0 : memref<2x2x1xi64>
  %3 = bufferization.to_tensor %g0 restrict writable : memref<2x2x1xi64>

  %out = linalg.generic {indexing_maps = [
              affine_map<(d0, d1) -> (d1, d1, d0)>, // accesses %1[0,0,0], %1[0,0,1]
              affine_map<(d0, d1) -> (d1)>,         // accesses %2[0],     %2[0]
              affine_map<(d0, d1) -> (d0, d0, d1)>  // accesses %3[0,0,0], %3[1,1,0] 
              ], iterator_types = ["reduction", "reduction"]} 
          ins(%1, %2 : tensor<1x1x2xi64>, tensor<1xi64>) 
          outs(%3 : tensor<2x2x1xi64>) {
      ^bb0(%in: i64, %in_5: i64, %out: i64):
        linalg.yield %in : i64
      } -> tensor<2x2x1xi64>
      // expected:
      // [[[10], 
      //   [-2]],
      //  [[-1], 
      //   [20]]]
  return %out : tensor<2x2x1xi64>
}
```
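
To make the difference concrete, here is a rough Python model of the iteration space above (a sketch only, not MLIR; `run_generic` and the `UNINIT` marker are hypothetical names, and the loop bounds `d0 = 2`, `d1 = 1` are what I'd expect to be inferred from the operand shapes). It writes through the same output map into either a buffer that already holds the initial values or a fresh buffer with no copy:

```python
from itertools import product

# Stand-in for whatever a fresh allocation holds when nothing was copied into it.
UNINIT = None


def run_generic(in1, out, d0_size=2, d1_size=1):
    """Model the linalg.generic above: the body just yields %in into the output."""
    for d0, d1 in product(range(d0_size), range(d1_size)):
        # input map  (d0, d1) -> (d1, d1, d0)
        # output map (d0, d1) -> (d0, d0, d1)
        out[d0][d0][d1] = in1[d1][d1][d0]
    return out


in1 = [[[10, 20]]]  # %1 : tensor<1x1x2xi64>

# Case 1: in place on a buffer that already holds the initial values.
in_place = run_generic(in1, [[[0], [-2]], [[-1], [0]]])
print(in_place)  # [[[10], [-2]], [[-1], [20]]]

# Case 2: fresh allocation without a copy of the constant.
fresh = run_generic(in1, [[[UNINIT], [UNINIT]], [[UNINIT], [UNINIT]]])
print(fresh)  # [[[10], [None]], [[None], [20]]]
```

Only `[0][0][0]` and `[1][1][0]` are ever written in both cases; what differs is what the untouched elements hold beforehand.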

https://github.com/llvm/llvm-project/pull/96251

