[Mlir-commits] [mlir] [mlir][SCF] Unify tileUsingFor and tileReductionUsingFor implementation (PR #120115)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Tue Dec 17 09:05:10 PST 2024


================
@@ -570,6 +570,146 @@ static LogicalResult generateLoopNest(
   return rewriter.notifyMatchFailure(loc, "unhandled loop type");
 }
 
+/// Fills `initTensors` with the initial tensors used when tiling `op`,
+/// dispatching on `options.reductionStrategy`:
+///  - FullReduction: forwards to `tensor::getOrCreateDestinations`.
+///  - PartialReductionOuterReduction: requires `op` to implement
+///    `PartialReductionOpInterface` and forwards to its
+///    `generateInitialTensorForPartialReduction`, passing `tileSizes` and the
+///    indices of all loops whose iterator type is `reduction`.
+/// Any other strategy, or a missing interface, is reported as a match failure.
+static LogicalResult
+createInitialTensorsForTiling(RewriterBase &rewriter, TilingInterface op,
+                              ArrayRef<OpFoldResult> tileSizes,
+                              SmallVector<Value> &initTensors,
+                              const scf::SCFTilingOptions &options) {
+  using Strategy = scf::SCFTilingOptions::ReductionTilingStrategy;
+  Location loc = op->getLoc();
+
+  if (options.reductionStrategy == Strategy::FullReduction)
+    return tensor::getOrCreateDestinations(rewriter, loc, op, initTensors);
+
+  if (options.reductionStrategy != Strategy::PartialReductionOuterReduction) {
+    return rewriter.notifyMatchFailure(op,
+                                       "unhandled reduction tiling strategy");
+  }
+
+  // Partial reduction is only available through the dedicated interface.
+  auto partialReductionOp =
+      dyn_cast<PartialReductionOpInterface>(op.getOperation());
+  if (!partialReductionOp) {
+    return rewriter.notifyMatchFailure(
+        op, "PartialReductionOuterReduction tiling strategy is only "
+            "supported for operations "
+            "implementing PartialReductionOpInterface");
+  }
+
+  // Collect the positions of the reduction loops.
+  // TODO: PartialReductionOpInterface should really query TilingInterface
+  // itself and find reduction dimensions.
+  auto loopKinds = op.getLoopIteratorTypes();
+  SmallVector<int> reductionDims;
+  for (size_t dim = 0, numDims = loopKinds.size(); dim < numDims; ++dim) {
+    if (loopKinds[dim] == utils::IteratorType::reduction)
+      reductionDims.push_back(dim);
+  }
+
+  FailureOr<SmallVector<Value>> partialInits =
+      partialReductionOp.generateInitialTensorForPartialReduction(
+          rewriter, loc, tileSizes, reductionDims);
+  if (failed(partialInits))
+    return failure();
+
+  initTensors = *partialInits;
+  return success();
+}
+
+/// Generates the tiled implementation of `op` for the tile described by
+/// `offsets` and `sizes`, dispatching on `options.reductionStrategy`:
+///  - FullReduction: forwards to `TilingInterface::getTiledImplementation`.
+///  - PartialReductionOuterReduction: requires `op` to implement
+///    `PartialReductionOpInterface` and forwards to its
+///    `tileToPartialReduction`, passing `regionIterArg` and the indices of all
+///    loops whose iterator type is `reduction`.
+/// Any other strategy, or a missing interface, is reported as a match failure.
+static FailureOr<TilingResult>
+getTiledImplementation(RewriterBase &rewriter, TilingInterface op,
+                       ValueRange regionIterArg, ArrayRef<OpFoldResult> offsets,
+                       ArrayRef<OpFoldResult> sizes,
+                       const scf::SCFTilingOptions &options) {
+  using Strategy = scf::SCFTilingOptions::ReductionTilingStrategy;
+
+  if (options.reductionStrategy == Strategy::FullReduction)
+    return op.getTiledImplementation(rewriter, offsets, sizes);
+
+  if (options.reductionStrategy != Strategy::PartialReductionOuterReduction) {
+    return rewriter.notifyMatchFailure(op,
+                                       "unhandled reduction tiling strategy");
+  }
+
+  // Partial reduction is only available through the dedicated interface.
+  auto partialReductionOp =
+      dyn_cast<PartialReductionOpInterface>(op.getOperation());
+  if (!partialReductionOp) {
+    return rewriter.notifyMatchFailure(
+        op, "PartialReductionOuterReduction tiling strategy is only "
+            "supported for operations "
+            "implementing PartialReductionOpInterface");
+  }
+
+  // Collect the positions of the reduction loops.
+  // TODO: PartialReductionOpInterface should really query TilingInterface
+  // itself and find reduction dimensions.
+  auto loopKinds = op.getLoopIteratorTypes();
+  SmallVector<int> reductionDims;
+  for (size_t dim = 0, numDims = loopKinds.size(); dim < numDims; ++dim) {
+    if (loopKinds[dim] == utils::IteratorType::reduction)
+      reductionDims.push_back(dim);
+  }
+
+  return partialReductionOp.tileToPartialReduction(
+      rewriter, op.getLoc(), regionIterArg, offsets, sizes, reductionDims);
+}
+
+static LogicalResult
+getResultTilePosition(RewriterBase &rewriter, int64_t index, Value tiledResult,
+                      TilingInterface op, ArrayRef<OpFoldResult> offsets,
+                      ArrayRef<OpFoldResult> sizes,
+                      SmallVector<OpFoldResult> &resultOffset,
+                      SmallVector<OpFoldResult> &resultSize,
+                      const scf::SCFTilingOptions &options) {
+
+  switch (options.reductionStrategy) {
+  case scf::SCFTilingOptions::ReductionTilingStrategy::FullReduction:
+    return op.getResultTilePosition(rewriter, index, offsets, sizes,
+                                    resultOffset, resultSize);
+  case scf::SCFTilingOptions::ReductionTilingStrategy::
----------------
MaheshRavishankar wrote:

I think this won't work for an operation like this:

```
%result = linalg.generic {
    iterator_types = ["parallel", "parallel", "reduction"],
    indexing_maps = [...., affine_map<(d0, d1, d2) -> (d1, d0)>]}
```


But, like you said, that is the problem with the current implementation anyway.

https://github.com/llvm/llvm-project/pull/120115


More information about the Mlir-commits mailing list