[Mlir-commits] [mlir] [mlir][TilingInterface] Allow multiple results in PartialReductionOpInterface (PR #92624)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Wed May 22 09:48:48 PDT 2024


================
@@ -250,58 +250,66 @@ template <typename LinalgOpTy>
 struct LinalgOpPartialReductionInterface
     : public PartialReductionOpInterface::ExternalModel<
           LinalgOpPartialReductionInterface<LinalgOpTy>, LinalgOpTy> {
-  FailureOr<Operation *> generateInitialTensorForPartialReduction(
+  FailureOr<SmallVector<Value>> generateInitialTensorForPartialReduction(
       Operation *op, OpBuilder &b, Location loc, ArrayRef<OpFoldResult> sizes,
       ArrayRef<int> reductionDims) const {
     auto linalgOp = cast<LinalgOp>(op);
     OpBuilder::InsertionGuard guard(b);
 
     if (linalgOp.hasPureBufferSemantics())
       return op->emitOpError("expected operation to have tensor semantics");
-    // Insert the new parallel dimension based on the index of the reduction
-    // loops. This could be controlled by user for more flexibility.
 
-    SmallVector<Operation *, 4> combinerOps;
-    if (!matchReduction(linalgOp.getRegionOutputArgs(), 0, combinerOps) ||
-        combinerOps.size() != 1)
-      return op->emitOpError("Failed to anaysis the reduction operation.");
-
-    Operation *reductionOp = combinerOps[0];
-    std::optional<TypedAttr> identity = arith::getNeutralElement(reductionOp);
-    if (!identity.has_value())
-      return op->emitOpError(
-          "Failed to get an identity value for the reduction operation.");
-
-    ArrayRef<int64_t> oldShape =
-        linalgOp.getShape(linalgOp.getDpsInitOperand(0));
-
-    // Calculate the new shape, we insert the new dimensions based on the index
-    // of the reduction dimensions.
-    SmallVector<int64_t> newOutputShape;
-    SmallVector<Value> dynamicDims;
-    int64_t currReductionDims = 0;
-    DenseSet<int> reductionDimsSet(reductionDims.begin(), reductionDims.end());
-    for (int64_t idx :
-         llvm::seq<int64_t>(0, oldShape.size() + reductionDims.size())) {
-      if (reductionDimsSet.contains(idx)) {
-        dispatchIndexOpFoldResults(sizes[idx], dynamicDims, newOutputShape);
-        currReductionDims++;
-        continue;
+    SmallVector<Value> inits;
+    for (int initIdx = 0, e = linalgOp.getNumDpsInits(); initIdx < e;
----------------
MaheshRavishankar wrote:

Nit: Use `for (auto [index, init] : llvm::enumerate(linalgOp.getDpsInits()))`

https://github.com/llvm/llvm-project/pull/92624


More information about the Mlir-commits mailing list