[Mlir-commits] [mlir] Lower allreduce (PR #144716)
Anton Lydike
llvmlistbot at llvm.org
Wed Jun 18 08:03:21 PDT 2025
================
@@ -584,11 +584,11 @@ def Mesh_AllReduceOp : Mesh_CollectiveCommunicationOpBase<"all_reduce", [
```
}];
let arguments = !con(commonArgs, (ins
- AnyRankedTensor:$input,
+ AnyTypeOf<[AnyMemRef, AnyRankedTensor]>:$input,
----------------
AntonLydike wrote:
I wonder if this is the way to go. Isn't the sharding dialect supposed to be tensor-only? I would expect the lowering to MPI to insert some bufferization dialect ops to get to the memref. Admittedly, though, my knowledge here is very flaky.
https://github.com/llvm/llvm-project/pull/144716