[Mlir-commits] [mlir] [mlir][linalg] Add unit dim folding pattern for tensor.pad (PR #84684)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Mon Mar 11 10:57:46 PDT 2024


================
@@ -561,6 +561,126 @@ struct DropUnitDims : public OpRewritePattern<GenericOp> {
 };
 } // namespace
 
+//===---------------------------------------------------------------------===//
+// Drop unit-extent dimensions from tensor operations.
+//===---------------------------------------------------------------------===//
+
+namespace {
+struct DropPadUnitDims : public OpRewritePattern<tensor::PadOp> {
+  DropPadUnitDims(MLIRContext *context, ControlDropUnitDims options = {},
+                  PatternBenefit benefit = 1)
+      : OpRewritePattern(context, benefit), options(std::move(options)) {}
+
+  LogicalResult matchAndRewrite(tensor::PadOp padOp,
+                                PatternRewriter &rewriter) const override {
+    // 1a. Get the allowed list of dimensions to drop from the `options`.
+    SmallVector<unsigned> allowedUnitDims = options.controlFn(padOp);
+    if (allowedUnitDims.empty()) {
+      return rewriter.notifyMatchFailure(
+          padOp, "control function returns no allowed unit dims to prune");
+    }
+
+    if (padOp.getSourceType().getEncoding()) {
+      return rewriter.notifyMatchFailure(
+          padOp, "cannot collapse dims of tensor with encoding");
+    }
+
+    // Fail for non-constant padding values. The body of the pad could
+    // depend on the padding indices and/or properties of the padded
+    // tensor so for now we fail.
+    // TODO: Support non-constant padding values.
+    Value paddingVal = padOp.getConstantPaddingValue();
+    if (!paddingVal) {
+      return rewriter.notifyMatchFailure(
+          padOp, "unimplemented: non-constant padding value");
+    }
+
+    ArrayRef<int64_t> sourceShape = padOp.getSourceType().getShape();
+    int64_t padRank = sourceShape.size();
+
+    auto isStaticZero = [](OpFoldResult f) {
+      std::optional<int64_t> maybeInt = getConstantIntValue(f);
+      return maybeInt && *maybeInt == 0;
+    };
+
+    llvm::SmallDenseSet<unsigned> unitDimsFilter(allowedUnitDims.begin(),
+                                                 allowedUnitDims.end());
+    llvm::SmallDenseSet<unsigned> unitDims;
+    SmallVector<int64_t> newShape;
+    SmallVector<OpFoldResult> newLowPad;
+    SmallVector<OpFoldResult> newHighPad;
+    for (const auto [dim, size, low, high] :
+         zip_equal(llvm::seq(static_cast<int64_t>(0), padRank), sourceShape,
+                   padOp.getMixedLowPad(), padOp.getMixedHighPad())) {
+      if (unitDimsFilter.contains(dim) && size == 1 && isStaticZero(low) &&
+          isStaticZero(high)) {
+        unitDims.insert(dim);
+      } else {
+        newShape.push_back(size);
+        newLowPad.push_back(low);
+        newHighPad.push_back(high);
+      }
+    }
+
+    if (unitDims.empty()) {
+      return rewriter.notifyMatchFailure(padOp, "no unit dims to collapse");
+    }
+
+    ReassociationIndices reassociationGroup;
+    SmallVector<ReassociationIndices> reassociationMap;
+    int64_t dim = 0;
+    while (dim < padRank && unitDims.contains(dim))
+      reassociationGroup.push_back(dim++);
+    while (dim < padRank) {
+      assert(!unitDims.contains(dim) && "expected non unit-extent");
+      reassociationGroup.push_back(dim);
+      dim++;
+      // Fold all following dimensions that are unit-extent.
+      while (dim < padRank && unitDims.contains(dim))
+        reassociationGroup.push_back(dim++);
+      reassociationMap.push_back(reassociationGroup);
+      reassociationGroup.clear();
+    }
+
+    Value collapsedSource =
+        collapseValue(rewriter, padOp.getLoc(), padOp.getSource(), newShape,
+                      reassociationMap, options.rankReductionStrategy);
+
+    auto newPadOp = rewriter.create<tensor::PadOp>(
+        padOp.getLoc(), /*result=*/Type(), collapsedSource, newLowPad,
+        newHighPad, paddingVal, padOp.getNofold());
+
+    Value dest = padOp.getResult();
+    if (options.rankReductionStrategy ==
+        ControlDropUnitDims::RankReductionStrategy::ExtractInsertSlice) {
+      SmallVector<OpFoldResult> expandedSizes;
+      int64_t numUnitDims = 0;
+      for (auto dim : llvm::seq(static_cast<int64_t>(0), padRank)) {
+        if (unitDims.contains(dim)) {
+          expandedSizes.push_back(rewriter.getIndexAttr(1));
+          numUnitDims++;
+          continue;
+        }
+        expandedSizes.push_back(tensor::getMixedSize(
+            rewriter, padOp.getLoc(), newPadOp, dim - numUnitDims));
+      }
+      dest = rewriter.create<tensor::EmptyOp>(
----------------
MaheshRavishankar wrote:

You already added the most reasonable support for this, but I'd be OK with just dropping the support for the extract/insert slice path and landing only the reshape mode.

https://github.com/llvm/llvm-project/pull/84684


More information about the Mlir-commits mailing list