[Mlir-commits] [mlir] [mlir][linalg] do not break outs from block argument (PR #73572)

Okwan Kwon llvmlistbot at llvm.org
Tue Nov 28 16:55:00 PST 2023


================
@@ -1818,6 +1818,11 @@ struct RemoveOutsDependency : public OpRewritePattern<GenericOp> {
         if (sparse_tensor::getSparseTensorEncoding(operandVal.getType()))
           continue;
 
+        // If outs is wired from a block argument, keep the dependency to
+        // prevent the argument from being optimized away.
----------------
okkwon wrote:

Let me explain the context, which may give you more insight into what is happening here.

At 10,000 feet, I'd like to use a tensor operand as the output. Here is an example:

```mlir
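#map = affine_map<(d0) -> (d0)>
func.func @my_add(%a: tensor<2xf32>, %b: tensor<2xf32>,
                  %out: tensor<2xf32> {iree.abi.output = 0 : index}) -> tensor<2xf32> {
  %r = linalg.generic {indexing_maps = [#map, #map, #map],
                       iterator_types = ["parallel"]}
      ins(%a, %b : tensor<2xf32>, tensor<2xf32>)
      outs(%out : tensor<2xf32>) {
  ^bb0(%x: f32, %y: f32, %o: f32):
    // Element-wise add; the payload never reads %o (the current value of %out).
    %s = arith.addf %x, %y : f32
    linalg.yield %s : f32
  } -> tensor<2xf32>
  return %r : tensor<2xf32>
}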
```

Here, `iree.abi.output` declares that `%out` provides the storage for the function's result. The function therefore does not allocate a buffer for `%r`; it writes the result into the caller-allocated storage of `%out`.

However, the `RemoveOutsDependency` pattern breaks the existing dependency and unconditionally replaces `%out` with a `tensor.empty()`. This leaves `%out` unused, so it gets removed later.
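For illustration, here is a hand-written sketch (not captured compiler output) of roughly what `@my_add` looks like after the pattern fires. Because the payload never reads the `outs` block argument, the `outs` operand is swapped for a fresh `tensor.empty()`, and nothing in the function body refers to `%out` any more. The `#map`, payload, and attribute spelling are assumptions carried over from the example above.

```mlir
#map = affine_map<(d0) -> (d0)>
func.func @my_add(%a: tensor<2xf32>, %b: tensor<2xf32>,
                  %out: tensor<2xf32> {iree.abi.output = 0 : index}) -> tensor<2xf32> {
  // Sketch of the rewritten IR: a new empty tensor replaces %out as the init.
  %e = tensor.empty() : tensor<2xf32>
  %r = linalg.generic {indexing_maps = [#map, #map, #map],
                       iterator_types = ["parallel"]}
      ins(%a, %b : tensor<2xf32>, tensor<2xf32>)
      outs(%e : tensor<2xf32>) {
  ^bb0(%x: f32, %y: f32, %o: f32):
    %s = arith.addf %x, %y : f32
    linalg.yield %s : f32
  } -> tensor<2xf32>
  // %out is now dead, so later cleanups are free to drop it.
  return %r : tensor<2xf32>
}
```

Once `%out` has no uses, the link between the caller-provided buffer and `%r` is gone, which is exactly the dependency the block-argument check is meant to preserve.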

Hope this helps.


https://github.com/llvm/llvm-project/pull/73572

