[Mlir-commits] [mlir] [mlir][TilingInterface] Allow multiple results in PartialReductionOpInterface (PR #92624)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Wed May 22 09:48:48 PDT 2024
================
@@ -250,58 +250,66 @@ template <typename LinalgOpTy>
struct LinalgOpPartialReductionInterface
: public PartialReductionOpInterface::ExternalModel<
LinalgOpPartialReductionInterface<LinalgOpTy>, LinalgOpTy> {
- FailureOr<Operation *> generateInitialTensorForPartialReduction(
+ FailureOr<SmallVector<Value>> generateInitialTensorForPartialReduction(
Operation *op, OpBuilder &b, Location loc, ArrayRef<OpFoldResult> sizes,
ArrayRef<int> reductionDims) const {
auto linalgOp = cast<LinalgOp>(op);
OpBuilder::InsertionGuard guard(b);
if (linalgOp.hasPureBufferSemantics())
return op->emitOpError("expected operation to have tensor semantics");
- // Insert the new parallel dimension based on the index of the reduction
- // loops. This could be controlled by user for more flexibility.
- SmallVector<Operation *, 4> combinerOps;
- if (!matchReduction(linalgOp.getRegionOutputArgs(), 0, combinerOps) ||
- combinerOps.size() != 1)
- return op->emitOpError("Failed to anaysis the reduction operation.");
-
- Operation *reductionOp = combinerOps[0];
- std::optional<TypedAttr> identity = arith::getNeutralElement(reductionOp);
- if (!identity.has_value())
- return op->emitOpError(
- "Failed to get an identity value for the reduction operation.");
-
- ArrayRef<int64_t> oldShape =
- linalgOp.getShape(linalgOp.getDpsInitOperand(0));
-
- // Calculate the new shape, we insert the new dimensions based on the index
- // of the reduction dimensions.
- SmallVector<int64_t> newOutputShape;
- SmallVector<Value> dynamicDims;
- int64_t currReductionDims = 0;
- DenseSet<int> reductionDimsSet(reductionDims.begin(), reductionDims.end());
- for (int64_t idx :
- llvm::seq<int64_t>(0, oldShape.size() + reductionDims.size())) {
- if (reductionDimsSet.contains(idx)) {
- dispatchIndexOpFoldResults(sizes[idx], dynamicDims, newOutputShape);
- currReductionDims++;
- continue;
+ SmallVector<Value> inits;
+ for (int initIdx = 0, e = linalgOp.getNumDpsInits(); initIdx < e;
----------------
MaheshRavishankar wrote:
Nit: Use `for (auto [index, init] : llvm::enumerate(linalgOp.getDpsInits()))` instead of the manual index loop.
https://github.com/llvm/llvm-project/pull/92624
More information about the Mlir-commits mailing list