[Mlir-commits] [mlir] [mlir][linalg] Add unit dim folding pattern for tensor.pad (PR #84684)
Quinn Dawkins
llvmlistbot at llvm.org
Mon Mar 11 11:21:25 PDT 2024
================
@@ -561,6 +561,126 @@ struct DropUnitDims : public OpRewritePattern<GenericOp> {
};
} // namespace
+//===---------------------------------------------------------------------===//
+// Drop dimensions that are unit-extents within tensor operations.
+//===---------------------------------------------------------------------===//
+
+namespace {
+struct DropPadUnitDims : public OpRewritePattern<tensor::PadOp> {
+ DropPadUnitDims(MLIRContext *context, ControlDropUnitDims options = {},
+ PatternBenefit benefit = 1)
+ : OpRewritePattern(context, benefit), options(std::move(options)) {}
+
+ LogicalResult matchAndRewrite(tensor::PadOp padOp,
+ PatternRewriter &rewriter) const override {
+ // 1a. Get the allowed list of dimensions to drop from the `options`.
+ SmallVector<unsigned> allowedUnitDims = options.controlFn(padOp);
+ if (allowedUnitDims.empty()) {
+ return rewriter.notifyMatchFailure(
+ padOp, "control function returns no allowed unit dims to prune");
+ }
+
+ if (padOp.getSourceType().getEncoding()) {
+ return rewriter.notifyMatchFailure(
+ padOp, "cannot collapse dims of tensor with encoding");
+ }
+
+ // Fail for non-constant padding values. The body of the pad could
+ // depend on the padding indices and/or properties of the padded
+ // tensor so for now we fail.
+ // TODO: Support non-constant padding values.
+ Value paddingVal = padOp.getConstantPaddingValue();
+ if (!paddingVal) {
+ return rewriter.notifyMatchFailure(
+ padOp, "unimplemented: non-constant padding value");
+ }
+
+ ArrayRef<int64_t> sourceShape = padOp.getSourceType().getShape();
+ int64_t padRank = sourceShape.size();
+
+ auto isStaticZero = [](OpFoldResult f) {
+ std::optional<int64_t> maybeInt = getConstantIntValue(f);
+ return maybeInt && *maybeInt == 0;
+ };
+
+ llvm::SmallDenseSet<unsigned> unitDimsFilter(allowedUnitDims.begin(),
+ allowedUnitDims.end());
+ llvm::SmallDenseSet<unsigned> unitDims;
+ SmallVector<int64_t> newShape;
+ SmallVector<OpFoldResult> newLowPad;
+ SmallVector<OpFoldResult> newHighPad;
+ for (const auto [dim, size, low, high] :
+ zip_equal(llvm::seq(static_cast<int64_t>(0), padRank), sourceShape,
+ padOp.getMixedLowPad(), padOp.getMixedHighPad())) {
+ if (unitDimsFilter.contains(dim) && size == 1 && isStaticZero(low) &&
+ isStaticZero(high)) {
+ unitDims.insert(dim);
+ } else {
+ newShape.push_back(size);
+ newLowPad.push_back(low);
+ newHighPad.push_back(high);
+ }
+ }
+
+ if (unitDims.empty()) {
+ return rewriter.notifyMatchFailure(padOp, "no unit dims to collapse");
+ }
+
+ ReassociationIndices reassociationGroup;
+ SmallVector<ReassociationIndices> reassociationMap;
+ int64_t dim = 0;
+ while (dim < padRank && unitDims.contains(dim))
+ reassociationGroup.push_back(dim++);
+ while (dim < padRank) {
+ assert(!unitDims.contains(dim) && "expected non unit-extent");
+ reassociationGroup.push_back(dim);
+ dim++;
+ // Fold all following dimensions that are unit-extent.
+ while (dim < padRank && unitDims.contains(dim))
+ reassociationGroup.push_back(dim++);
+ reassociationMap.push_back(reassociationGroup);
+ reassociationGroup.clear();
+ }
+
+ Value collapsedSource =
+ collapseValue(rewriter, padOp.getLoc(), padOp.getSource(), newShape,
+ reassociationMap, options.rankReductionStrategy);
+
+ auto newPadOp = rewriter.create<tensor::PadOp>(
+ padOp.getLoc(), /*result=*/Type(), collapsedSource, newLowPad,
+ newHighPad, paddingVal, padOp.getNofold());
+
+ Value dest = padOp.getResult();
+ if (options.rankReductionStrategy ==
+ ControlDropUnitDims::RankReductionStrategy::ExtractInsertSlice) {
+ SmallVector<OpFoldResult> expandedSizes;
+ int64_t numUnitDims = 0;
+ for (auto dim : llvm::seq(static_cast<int64_t>(0), padRank)) {
+ if (unitDims.contains(dim)) {
+ expandedSizes.push_back(rewriter.getIndexAttr(1));
+ numUnitDims++;
+ continue;
+ }
+ expandedSizes.push_back(tensor::getMixedSize(
+ rewriter, padOp.getLoc(), newPadOp, dim - numUnitDims));
+ }
+ dest = rewriter.create<tensor::EmptyOp>(
----------------
qedawkins wrote:
OK, I'll drop it because I didn't like doing this either.
https://github.com/llvm/llvm-project/pull/84684
More information about the Mlir-commits
mailing list