[Mlir-commits] [mlir] e07149c - [mlir][linalg] Add option to generate rank-reducing slices in DropUnitDims

Matthias Springer llvmlistbot at llvm.org
Wed Dec 14 05:10:15 PST 2022


Author: Matthias Springer
Date: 2022-12-14T14:10:04+01:00
New Revision: e07149c91f4bb43c1413bc1fbe19dc6eff2fcde6

URL: https://github.com/llvm/llvm-project/commit/e07149c91f4bb43c1413bc1fbe19dc6eff2fcde6
DIFF: https://github.com/llvm/llvm-project/commit/e07149c91f4bb43c1413bc1fbe19dc6eff2fcde6.diff

LOG: [mlir][linalg] Add option to generate rank-reducing slices in DropUnitDims

This change extends the `ReplaceUnitExtents` pattern so that users can choose between two strategies for generating rank reductions (a rough sketch of both forms follows the list):
* CollapseShapeOp / ExpandShapeOp (already implemented; the code was cleaned up and this remains the default strategy)
* rank-reducing ExtractSliceOp / InsertSliceOp
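
To make the difference concrete, here is a rough sketch of the two forms for a tensor<?x1x?xf32> operand (not copied verbatim from the tests below; the SSA names %arg0, %res, %dest, %d0 and %d2 are hypothetical, with %d0/%d2 standing for tensor.dim values):

  // Default strategy: reassociative reshapes.
  %0 = tensor.collapse_shape %arg0 [[0, 1], [2]]
      : tensor<?x1x?xf32> into tensor<?x?xf32>
  // ... linalg.generic on the rank-reduced operands ...
  %1 = tensor.expand_shape %res [[0, 1], [2]]
      : tensor<?x?xf32> into tensor<?x1x?xf32>

  // use-rank-reducing-slices strategy: rank-reducing slices.
  %0 = tensor.extract_slice %arg0[0, 0, 0] [%d0, 1, %d2] [1, 1, 1]
      : tensor<?x1x?xf32> to tensor<?x?xf32>
  // ... linalg.generic on the rank-reduced operands ...
  %1 = tensor.insert_slice %res into %dest[0, 0, 0] [%d0, 1, %d2] [1, 1, 1]
      : tensor<?x?xf32> into tensor<?x1x?xf32>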

Also add helper functions to the memref dialect that already exist in the tensor dialect: `getMixedSizes`, `createCanonicalRankReducingSubViewOp`, `rankReduceIfNeeded` (see the sketch below).
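
A minimal usage sketch, assuming `using namespace mlir;` and the MemRef dialect headers; the wrapper function and its arguments are hypothetical, only the three helper signatures come from this patch:

  // Collapse all unit dims of `buffer` (e.g. memref<1x?x1xf32> -> memref<?xf32>)
  // with a rank-reducing subview.
  static FailureOr<Value> dropUnitDims(OpBuilder &b, Location loc, Value buffer) {
    auto type = buffer.getType().cast<MemRefType>();
    // Target shape with every unit dim dropped.
    SmallVector<int64_t> targetShape;
    for (int64_t size : type.getShape())
      if (size != 1)
        targetShape.push_back(size);
    // Returns `buffer` unchanged if the shape already matches, fails if
    // `targetShape` cannot be reached by dropping unit dims, and otherwise
    // builds a subview @[0..0] with strides [1..1] and the full sizes
    // (internally via createCanonicalRankReducingSubViewOp, which uses
    // memref::getMixedSizes to materialize memref.dim ops for dynamic dims).
    return memref::SubViewOp::rankReduceIfNeeded(b, loc, buffer, targetShape);
  }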

We already use ReassociationIndices instead of ReassociationExprs in many other places, and it makes the code easier to read. Also add a new test case (which already passed before this change).
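
For instance, folding the unit dims of a ?x1x?x1 shape groups the dimensions as [[0, 1], [2, 3]]. With ReassociationIndices that is written directly (a sketch; the variable name is illustrative):

  SmallVector<ReassociationIndices> reassociation = {{0, 1}, {2, 3}};

The previous code expressed the same grouping as an ArrayAttr of AffineMapAttrs, i.e. [affine_map<(d0, d1, d2, d3) -> (d0, d1)>, affine_map<(d0, d1, d2, d3) -> (d2, d3)>], and converted it back to reassociation expressions at each use.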

Differential Revision: https://reviews.llvm.org/D139947

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Linalg/Passes.td
    mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
    mlir/include/mlir/Dialect/MemRef/IR/MemRef.h
    mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
    mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
    mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
    mlir/test/Dialect/Linalg/drop-unit-extent-dims.mlir

Removed: 
    


################################################################################
diff --git a/mlir/include/mlir/Dialect/Linalg/Passes.td b/mlir/include/mlir/Dialect/Linalg/Passes.td
index 550feef376db5..05397c6775d35 100644
--- a/mlir/include/mlir/Dialect/Linalg/Passes.td
+++ b/mlir/include/mlir/Dialect/Linalg/Passes.td
@@ -31,7 +31,10 @@ def LinalgFoldUnitExtentDims : Pass<"linalg-fold-unit-extent-dims", ""> {
     Option<"foldOneTripLoopsOnly", "fold-one-trip-loops-only", "bool",
             /*default=*/"false",
            "Only folds the one-trip loops from Linalg ops on tensors "
-           "(for testing purposes only)">
+           "(for testing purposes only)">,
+    Option<"useRankReducingSlices", "use-rank-reducing-slices", "bool",
+           /*default=*/"false",
+           "Generate rank-reducing slices instead of reassociative reshapes">
   ];
   let dependentDialects = [
     "linalg::LinalgDialect", "AffineDialect", "memref::MemRefDialect"

diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
index 7b3e0726effc0..13a7e6f3f2aec 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -129,8 +129,12 @@ void populateFuseTensorPadWithProducerLinalgOpPatterns(
 void populateLinalgNamedOpConversionPatterns(RewritePatternSet &patterns);
 
 /// Patterns to fold unit-extent dimensions in operands/results of linalg ops on
-/// tensors.
-void populateFoldUnitExtentDimsPatterns(RewritePatternSet &patterns);
+/// tensors via reassociative reshape ops.
+void populateFoldUnitExtentDimsViaReshapesPatterns(RewritePatternSet &patterns);
+
+/// Patterns to fold unit-extent dimensions in operands/results of linalg ops on
+/// tensors via rank-reducing slices.
+void populateFoldUnitExtentDimsViaSlicesPatterns(RewritePatternSet &patterns);
 
 /// Patterns that are used to inline constant operands into linalg generic ops.
 void populateInlineConstantOperandsPatterns(RewritePatternSet &patterns);

diff --git a/mlir/include/mlir/Dialect/MemRef/IR/MemRef.h b/mlir/include/mlir/Dialect/MemRef/IR/MemRef.h
index 9932f36e5f1bf..bcf86b07db3b4 100644
--- a/mlir/include/mlir/Dialect/MemRef/IR/MemRef.h
+++ b/mlir/include/mlir/Dialect/MemRef/IR/MemRef.h
@@ -54,6 +54,16 @@ Type getTensorTypeFromMemRefType(Type type);
 /// single deallocate if it exists or nullptr.
 Optional<Operation *> findDealloc(Value allocValue);
 
+/// Return the dimensions of the given memref value.
+SmallVector<OpFoldResult> getMixedSizes(OpBuilder &builder, Location loc,
+                                        Value value);
+
+/// Create a rank-reducing SubViewOp @[0 .. 0] with strides [1 .. 1] and
+/// appropriate sizes (i.e. `memref.getSizes()`) to reduce the rank of `memref`
+/// to that of `targetShape`.
+Value createCanonicalRankReducingSubViewOp(OpBuilder &b, Location loc,
+                                           Value memref,
+                                           ArrayRef<int64_t> targetShape);
 } // namespace memref
 } // namespace mlir
 

diff --git a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
index 4a567b40e2e5e..5233badf0bcc1 100644
--- a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
+++ b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
@@ -1954,6 +1954,15 @@ def SubViewOp : MemRef_OpWithOffsetSizesAndStrides<"subview", [
     /// Return the dimensions of the source type that are dropped when
     /// the result is rank-reduced.
     llvm::SmallBitVector getDroppedDims();
+
+    /// Given a `value`, asserted to be of MemRefType, build a SubViewOp that
+    /// results in a rank reduction to the desired memref shape and return the
+    /// new value created.
+    /// If the shape of `value` is already the `desiredShape`, just return
+    /// `value`.
+    /// If the shape of `value` cannot be rank-reduced to `desiredShape`, fail.
+    static FailureOr<Value> rankReduceIfNeeded(
+      OpBuilder &b, Location loc, Value value, ArrayRef<int64_t> desiredShape);
   }];
 
   let hasCanonicalizer = 1;

diff --git a/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp b/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
index c0bf57cb705a6..ab288c0bb5f1a 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
@@ -42,6 +42,10 @@ namespace mlir {
 using namespace mlir;
 using namespace mlir::linalg;
 
+namespace {
+enum class RankReductionStrategy { ReassociativeReshape, ExtractInsertSlice };
+} // namespace
+
 /// Implements a pass that canonicalizes the uses of unit-extent dimensions for
 /// broadcasting. For example,
 ///
@@ -349,9 +353,9 @@ struct AddInitOperandsToInput : public OpRewritePattern<GenericOp> {
 };
 
 struct UnitExtentReplacementInfo {
-  Type type;
   AffineMap indexMap;
-  ArrayAttr reassociation;
+  SmallVector<ReassociationIndices> reassociation;
+  SmallVector<int64_t> targetShape;
 };
 } // namespace
 
@@ -371,8 +375,6 @@ replaceUnitExtents(GenericOp genericOp, OpOperand *opOperand,
   AffineMap indexingMap = genericOp.getMatchingIndexingMap(opOperand);
   ArrayRef<int64_t> shape = genericOp.getShape(opOperand);
   ArrayRef<AffineExpr> exprs = indexingMap.getResults();
-  SmallVector<AffineExpr> reassociations;
-  SmallVector<Attribute> reassociationMaps;
   SmallVector<AffineExpr> newIndexExprs;
   SmallVector<int64_t> newShape;
 
@@ -391,99 +393,110 @@ replaceUnitExtents(GenericOp genericOp, OpOperand *opOperand,
   }
 
   int64_t dim = 0;
+  SmallVector<ReassociationIndices> reassociation;
+  ReassociationIndices reassociationGroup;
   // Fold dimensions that are unit-extent at the beginning of the tensor.
   while (dim < origRank && isUnitExtent(dim))
-    reassociations.push_back(getAffineDimExpr(dim++, context));
+    reassociationGroup.push_back(dim++);
   while (dim < origRank) {
-    reassociations.push_back(getAffineDimExpr(dim, context));
+    assert(!isUnitExtent(dim) && "expected non unit-extent");
+    reassociationGroup.push_back(dim);
     newIndexExprs.push_back(exprs[dim]);
     newShape.push_back(shape[dim]);
-    // Fold all following dimensions that are unit-extent.
-    while (dim + 1 < origRank && isUnitExtent(dim + 1)) {
-      ++dim;
-      reassociations.push_back(getAffineDimExpr(dim, context));
-    }
-    reassociationMaps.push_back(AffineMapAttr::get(AffineMap::get(
-        origRank, /*symbolCount = */ 0, reassociations, context)));
-    reassociations.clear();
     ++dim;
+    // Fold all following dimensions that are unit-extent.
+    while (dim < origRank && isUnitExtent(dim))
+      reassociationGroup.push_back(dim++);
+    reassociation.push_back(reassociationGroup);
+    reassociationGroup.clear();
   }
 
-  // Compute the tensor or scalar replacement type.
-  Type elementType = getElementTypeOrSelf(opOperand->get());
-  Type replacementType;
-  if (elementType == opOperand->get().getType()) {
-    replacementType = elementType;
-  } else if (actualType.isa<RankedTensorType>()) {
-    replacementType = RankedTensorType::get(newShape, elementType);
-  } else {
-    auto memrefType = actualType.cast<MemRefType>();
-    replacementType = MemRefType::get(newShape, elementType, {},
-                                      memrefType.getMemorySpaceAsInt());
-  }
-  UnitExtentReplacementInfo info = {replacementType,
-                                    AffineMap::get(indexingMap.getNumDims(),
-                                                   indexingMap.getNumSymbols(),
-                                                   newIndexExprs, context),
-                                    ArrayAttr::get(context, reassociationMaps)};
+  // Return if the rank was not reduced.
+  if (origRank == static_cast<int64_t>(newShape.size()))
+    return std::nullopt;
+
+  UnitExtentReplacementInfo info = {
+      /*indexMap=*/AffineMap::get(indexingMap.getNumDims(),
+                                  indexingMap.getNumSymbols(), newIndexExprs,
+                                  context),
+      /*reassociation=*/reassociation, /*targetShape=*/newShape};
   return info;
 }
 
 namespace {
 
-SmallVector<ReassociationExprs, 2>
-convertAffineMapArrayToExprs(ArrayAttr affineMapArrayAttr) {
-  SmallVector<ReassociationExprs, 2> reassociationExprs;
-  for (auto attr : affineMapArrayAttr)
-    reassociationExprs.push_back(
-        llvm::to_vector<4>(attr.cast<AffineMapAttr>().getValue().getResults()));
-  return reassociationExprs;
-}
-
 /// Pattern to replace tensor/buffer operands/results that are unit extents.
 struct ReplaceUnitExtents : public OpRewritePattern<GenericOp> {
-  using OpRewritePattern<GenericOp>::OpRewritePattern;
-
-  // Return the original value if the type is unchanged, or reshape it. Return a
-  // nullptr if this is an unsupported type.
-  Value maybeExpand(Value result, Type origResultType,
-                    ArrayAttr reassociationMap, Location loc,
+  ReplaceUnitExtents(MLIRContext *ctx,
+                     RankReductionStrategy rankReductionStrategy)
+      : OpRewritePattern<GenericOp>(ctx),
+        rankReductionStrategy(rankReductionStrategy) {}
+
+  // Expand the given value.
+  Value expandValue(Value result, Value origOutput,
+                    ArrayRef<ReassociationIndices> reassociation, Location loc,
                     PatternRewriter &rewriter) const {
-    if (origResultType == result.getType())
-      return result;
-    if (origResultType.isa<RankedTensorType>()) {
-      return rewriter.create<tensor::ExpandShapeOp>(
-          loc, origResultType, result,
-          convertAffineMapArrayToExprs(reassociationMap));
-    }
-    if (origResultType.isa<MemRefType>()) {
-      return rewriter.create<memref::ExpandShapeOp>(
-          loc, origResultType, result,
-          convertAffineMapArrayToExprs(reassociationMap));
+    // There are no results for memref outputs.
+    auto origResultType = origOutput.getType().cast<RankedTensorType>();
+    if (rankReductionStrategy == RankReductionStrategy::ExtractInsertSlice) {
+      unsigned rank = origResultType.getRank();
+      SmallVector<OpFoldResult> offsets(rank, rewriter.getIndexAttr(0));
+      SmallVector<OpFoldResult> sizes =
+          tensor::getMixedSizes(rewriter, loc, origOutput);
+      SmallVector<OpFoldResult> strides(rank, rewriter.getIndexAttr(1));
+      return rewriter.createOrFold<tensor::InsertSliceOp>(
+          loc, result, origOutput, offsets, sizes, strides);
     }
-    return nullptr;
-  };
 
-  // Return the original value if the type is unchanged, or reshape it. Return a
-  // nullptr if this is an unsupported type.
-  Value maybeCollapse(Value operand, Type newInputOutputType,
-                      ArrayAttr reassociationMap, Location loc,
-                      PatternRewriter &rewriter) const {
-    auto operandType = operand.getType();
-    if (operandType == newInputOutputType)
-      return operand;
-    if (operandType.isa<MemRefType>()) {
-      return rewriter.create<memref::CollapseShapeOp>(
-          loc, newInputOutputType, operand,
-          convertAffineMapArrayToExprs(reassociationMap));
+    assert(rankReductionStrategy ==
+               RankReductionStrategy::ReassociativeReshape &&
+           "unknown rank reduction strategy");
+    return rewriter.create<tensor::ExpandShapeOp>(loc, origResultType, result,
+                                                  reassociation);
+  }
+
+  // Collapse the given value.
+  Value collapseValue(Value operand, ArrayRef<int64_t> targetShape,
+                      ArrayRef<ReassociationIndices> reassociation,
+                      Location loc, PatternRewriter &rewriter) const {
+    if (auto memrefType = operand.getType().dyn_cast<MemRefType>()) {
+      if (rankReductionStrategy == RankReductionStrategy::ExtractInsertSlice) {
+        FailureOr<Value> rankReducingExtract =
+            memref::SubViewOp::rankReduceIfNeeded(rewriter, loc, operand,
+                                                  targetShape);
+        assert(succeeded(rankReducingExtract) && "not a unit-extent collapse");
+        return *rankReducingExtract;
+      }
+
+      assert(rankReductionStrategy ==
+                 RankReductionStrategy::ReassociativeReshape &&
+             "unknown rank reduction strategy");
+      MemRefLayoutAttrInterface layout;
+      auto targetType =
+          MemRefType::get(targetShape, memrefType.getElementType(), layout,
+                          memrefType.getMemorySpace());
+      return rewriter.create<memref::CollapseShapeOp>(loc, targetType, operand,
+                                                      reassociation);
     }
-    if (operandType.isa<RankedTensorType>()) {
-      return rewriter.create<tensor::CollapseShapeOp>(
-          loc, newInputOutputType, operand,
-          convertAffineMapArrayToExprs(reassociationMap));
+    if (auto tensorType = operand.getType().dyn_cast<RankedTensorType>()) {
+      if (rankReductionStrategy == RankReductionStrategy::ExtractInsertSlice) {
+        FailureOr<Value> rankReducingExtract =
+            tensor::ExtractSliceOp::rankReduceIfNeeded(rewriter, loc, operand,
+                                                       targetShape);
+        assert(succeeded(rankReducingExtract) && "not a unit-extent collapse");
+        return *rankReducingExtract;
+      }
+
+      assert(rankReductionStrategy ==
+                 RankReductionStrategy::ReassociativeReshape &&
+             "unknown rank reduction strategy");
+      auto targetType =
+          RankedTensorType::get(targetShape, tensorType.getElementType());
+      return rewriter.create<tensor::CollapseShapeOp>(loc, targetType, operand,
+                                                      reassociation);
     }
-    return nullptr;
-  };
+    llvm_unreachable("unsupported operand type");
+  }
 
   LogicalResult matchAndRewrite(GenericOp genericOp,
                                 PatternRewriter &rewriter) const override {
@@ -495,71 +508,59 @@ struct ReplaceUnitExtents : public OpRewritePattern<GenericOp> {
       return failure();
     MLIRContext *context = rewriter.getContext();
     Location loc = genericOp.getLoc();
+    SmallVector<Value> oldOutputs(genericOp.getOutputs().begin(),
+                                  genericOp.getOutputs().end());
 
     SmallVector<AffineMap> newIndexingMaps;
-    SmallVector<ArrayAttr> reassociationMaps;
-    SmallVector<Type> newInputOutputTypes;
-    bool doCanonicalization = false;
+    SmallVector<SmallVector<ReassociationIndices>> reassociations;
+    SmallVector<SmallVector<int64_t>> targetShapes;
+    SmallVector<bool> collapsed;
     for (OpOperand &opOperand : genericOp->getOpOperands()) {
       auto replacementInfo = replaceUnitExtents(genericOp, &opOperand, context);
       if (replacementInfo) {
-        reassociationMaps.push_back(replacementInfo->reassociation);
+        reassociations.push_back(replacementInfo->reassociation);
         newIndexingMaps.push_back(replacementInfo->indexMap);
-        newInputOutputTypes.push_back(replacementInfo->type);
-        doCanonicalization |=
-            replacementInfo->type != opOperand.get().getType();
+        targetShapes.push_back(replacementInfo->targetShape);
+        collapsed.push_back(true);
       } else {
-        // If replaceUnitExtents cannot handle this case, maintain the same
-        // type, indexing map, and create a set of mappings representing an
-        // identity matrix.
-        newInputOutputTypes.push_back(opOperand.get().getType());
+        // If replaceUnitExtents cannot handle this case (or no unit dim was
+        // removed), maintain the same type, indexing map, and create a set of
+        // mappings representing an identity matrix.
         newIndexingMaps.push_back(genericOp.getMatchingIndexingMap(&opOperand));
-        int64_t origRank = genericOp.getRank(&opOperand);
-        auto maps = llvm::to_vector<8>(llvm::map_range(
-            llvm::seq<int64_t>(0, origRank), [&](int64_t dim) -> Attribute {
-              return AffineMapAttr::get(
-                  AffineMap::get(origRank, /*symbolCount = */ 0,
-                                 getAffineDimExpr(dim, context), context));
-            }));
-        reassociationMaps.push_back(ArrayAttr::get(context, maps));
+        reassociations.emplace_back();
+        targetShapes.emplace_back();
+        collapsed.push_back(false);
       }
     }
 
-    // If the indexing maps of the result operation are not invertible (i.e. not
-    // legal), abort.
-    if (!doCanonicalization ||
+    // Abort if the indexing maps of the result operation are not invertible
+    // (i.e. not legal) or if no dimension was reduced.
+    if (!llvm::any_of(collapsed, [](bool c) { return c; }) ||
         !inversePermutation(concatAffineMaps(newIndexingMaps)))
       return failure();
 
-    // If any operand type change, insert a reshape to convert from the original
-    // type to the new type.
-    // TODO: get rid of flattenedIdx which assumes operand order and contiguity.
-    unsigned flattenedIdx = 0;
-    auto insertReshapes = [&](ValueRange values) {
-      SmallVector<Value, 4> res;
-      res.reserve(values.size());
-      for (auto operand : values) {
-        auto reshapedValue =
-            maybeCollapse(operand, newInputOutputTypes[flattenedIdx],
-                          reassociationMaps[flattenedIdx], loc, rewriter);
-        assert(reshapedValue &&
-               "expected ranked MemRef or Tensor operand type");
-        res.push_back(reshapedValue);
-        ++flattenedIdx;
+    // Insert rank reductions.
+    SmallVector<Value> newOperands;
+    for (OpOperand &opOperand : genericOp->getOpOperands()) {
+      int64_t idx = opOperand.getOperandNumber();
+      if (!collapsed[idx]) {
+        newOperands.push_back(opOperand.get());
+        continue;
       }
-      return res;
-    };
-
-    SmallVector<Value, 4> newInputs = insertReshapes(genericOp.getInputs());
-    SmallVector<Value, 4> newOutputs = insertReshapes(genericOp.getOutputs());
+      newOperands.push_back(collapseValue(opOperand.get(), targetShapes[idx],
+                                          reassociations[idx], loc, rewriter));
+    }
 
     // If any result type changes, insert a reshape to convert from the original
     // type to the new type.
-    SmallVector<Type, 4> resultTypes;
+    ArrayRef<Value> newInputs =
+        ArrayRef<Value>(newOperands).take_front(genericOp.getNumDpsInputs());
+    ArrayRef<Value> newOutputs =
+        ArrayRef<Value>(newOperands).take_back(genericOp.getNumDpsInits());
+    SmallVector<Type> resultTypes;
     resultTypes.reserve(genericOp.getNumResults());
     for (unsigned i : llvm::seq<unsigned>(0, genericOp.getNumResults()))
-      resultTypes.push_back(
-          newInputOutputTypes[i + genericOp.getNumDpsInputs()]);
+      resultTypes.push_back(newOutputs[i].getType());
     GenericOp replacementOp = rewriter.create<GenericOp>(
         loc, resultTypes, newInputs, newOutputs, newIndexingMaps,
         genericOp.getIteratorTypesArray());
@@ -569,20 +570,24 @@ struct ReplaceUnitExtents : public OpRewritePattern<GenericOp> {
 
     // If any result tensor has a modified shape, then add reshape to recover
     // the original shape.
-    SmallVector<Value, 4> resultReplacements;
+    SmallVector<Value> resultReplacements;
     for (const auto &result : llvm::enumerate(replacementOp.getResults())) {
       unsigned index = result.index() + replacementOp.getNumDpsInputs();
-      auto origResultType = genericOp.getResult(result.index()).getType();
-
-      auto newResult = maybeExpand(result.value(), origResultType,
-                                   reassociationMaps[index], loc, rewriter);
-      assert(newResult &&
-             "unexpected output type other than ranked MemRef or Tensor");
-      resultReplacements.push_back(newResult);
+      Value origOutput = oldOutputs[result.index()];
+      if (!collapsed[result.index() + genericOp.getNumDpsInputs()]) {
+        resultReplacements.push_back(result.value());
+        continue;
+      }
+      resultReplacements.push_back(expandValue(
+          result.value(), origOutput, reassociations[index], loc, rewriter));
     }
+
     rewriter.replaceOp(genericOp, resultReplacements);
     return success();
   }
+
+private:
+  RankReductionStrategy rankReductionStrategy;
 };
 } // namespace
 
@@ -656,14 +661,16 @@ struct RankReducedInsertSliceOp : public OpRewritePattern<InsertOpTy> {
 
 /// Patterns that are used to canonicalize the use of unit-extent dims for
 /// broadcasting.
-void mlir::linalg::populateFoldUnitExtentDimsPatterns(
+void mlir::linalg::populateFoldUnitExtentDimsViaReshapesPatterns(
     RewritePatternSet &patterns) {
   auto *context = patterns.getContext();
-  patterns.add<FoldUnitDimLoops, AddInitOperandsToInput, ReplaceUnitExtents,
-               RankReducedExtractSliceOp,
-               RankReducedInsertSliceOp<tensor::InsertSliceOp>,
-               RankReducedInsertSliceOp<tensor::ParallelInsertSliceOp>>(
-      context);
+  patterns.add<ReplaceUnitExtents>(context,
+                                   RankReductionStrategy::ReassociativeReshape);
+  // TODO: Patterns unrelated to unit dim folding should be factored out.
+  patterns
+      .add<FoldUnitDimLoops, AddInitOperandsToInput, RankReducedExtractSliceOp,
+           RankReducedInsertSliceOp<tensor::InsertSliceOp>,
+           RankReducedInsertSliceOp<tensor::ParallelInsertSliceOp>>(context);
   linalg::FillOp::getCanonicalizationPatterns(patterns, context);
   tensor::CollapseShapeOp::getCanonicalizationPatterns(patterns, context);
   tensor::EmptyOp::getCanonicalizationPatterns(patterns, context);
@@ -673,6 +680,14 @@ void mlir::linalg::populateFoldUnitExtentDimsPatterns(
   memref::populateResolveShapedTypeResultDimsPatterns(patterns);
 }
 
+void mlir::linalg::populateFoldUnitExtentDimsViaSlicesPatterns(
+    RewritePatternSet &patterns) {
+  auto *context = patterns.getContext();
+  patterns.add<ReplaceUnitExtents>(context,
+                                   RankReductionStrategy::ExtractInsertSlice);
+  patterns.add<FoldUnitDimLoops>(context);
+}
+
 namespace {
 /// Pass that removes unit-extent dims within generic ops.
 struct LinalgFoldUnitExtentDimsPass
@@ -681,10 +696,13 @@ struct LinalgFoldUnitExtentDimsPass
     Operation *op = getOperation();
     MLIRContext *context = op->getContext();
     RewritePatternSet patterns(context);
-    if (foldOneTripLoopsOnly)
+    if (foldOneTripLoopsOnly) {
       patterns.add<FoldUnitDimLoops, AddInitOperandsToInput>(context);
-    else
-      populateFoldUnitExtentDimsPatterns(patterns);
+    } else if (useRankReducingSlices) {
+      populateFoldUnitExtentDimsViaSlicesPatterns(patterns);
+    } else {
+      populateFoldUnitExtentDimsViaReshapesPatterns(patterns);
+    }
     (void)applyPatternsAndFoldGreedily(op, std::move(patterns));
   }
 };

diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
index de9690fd2eb78..1d9011091c3a8 100644
--- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
@@ -109,6 +109,21 @@ Type mlir::memref::getTensorTypeFromMemRefType(Type type) {
   return NoneType::get(type.getContext());
 }
 
+SmallVector<OpFoldResult> memref::getMixedSizes(OpBuilder &builder,
+                                                Location loc, Value value) {
+  auto memrefType = value.getType().cast<MemRefType>();
+  SmallVector<OpFoldResult> result;
+  for (int64_t i = 0; i < memrefType.getRank(); ++i) {
+    if (memrefType.isDynamicDim(i)) {
+      Value size = builder.create<memref::DimOp>(loc, value, i);
+      result.push_back(size);
+    } else {
+      result.push_back(builder.getIndexAttr(memrefType.getDimSize(i)));
+    }
+  }
+  return result;
+}
+
 //===----------------------------------------------------------------------===//
 // Utility functions for propagating static information
 //===----------------------------------------------------------------------===//
@@ -2912,6 +2927,35 @@ static MemRefType getCanonicalSubViewResultType(
                                        mixedStrides);
 }
 
+Value mlir::memref::createCanonicalRankReducingSubViewOp(
+    OpBuilder &b, Location loc, Value memref, ArrayRef<int64_t> targetShape) {
+  auto memrefType = memref.getType().cast<MemRefType>();
+  unsigned rank = memrefType.getRank();
+  SmallVector<OpFoldResult> offsets(rank, b.getIndexAttr(0));
+  SmallVector<OpFoldResult> sizes = getMixedSizes(b, loc, memref);
+  SmallVector<OpFoldResult> strides(rank, b.getIndexAttr(1));
+  auto targetType = SubViewOp::inferRankReducedResultType(
+                        targetShape, memrefType, offsets, sizes, strides)
+                        .cast<MemRefType>();
+  return b.createOrFold<memref::SubViewOp>(loc, targetType, memref, offsets,
+                                           sizes, strides);
+}
+
+FailureOr<Value> SubViewOp::rankReduceIfNeeded(OpBuilder &b, Location loc,
+                                               Value value,
+                                               ArrayRef<int64_t> desiredShape) {
+  auto sourceMemrefType = value.getType().dyn_cast<MemRefType>();
+  assert(sourceMemrefType && "not a ranked memref type");
+  auto sourceShape = sourceMemrefType.getShape();
+  if (sourceShape.equals(desiredShape))
+    return value;
+  auto maybeRankReductionMask =
+      mlir::computeRankReductionMask(sourceShape, desiredShape);
+  if (!maybeRankReductionMask)
+    return failure();
+  return createCanonicalRankReducingSubViewOp(b, loc, value, desiredShape);
+}
+
 /// Helper method to check if a `subview` operation is trivially a no-op. This
 /// is the case if the all offsets are zero, all strides are 1, and the source
 /// shape is same as the size of the subview. In such cases, the subview can

diff --git a/mlir/test/Dialect/Linalg/drop-unit-extent-dims.mlir b/mlir/test/Dialect/Linalg/drop-unit-extent-dims.mlir
index f9f06b0b55464..cfce3b3eea70e 100644
--- a/mlir/test/Dialect/Linalg/drop-unit-extent-dims.mlir
+++ b/mlir/test/Dialect/Linalg/drop-unit-extent-dims.mlir
@@ -1,4 +1,5 @@
-// RUN: mlir-opt %s -split-input-file -pass-pipeline="builtin.module(func.func(linalg-fold-unit-extent-dims))" | FileCheck %s
+// RUN: mlir-opt %s -linalg-fold-unit-extent-dims -split-input-file | FileCheck %s
+// RUN: mlir-opt %s -linalg-fold-unit-extent-dims="use-rank-reducing-slices" -cse -split-input-file | FileCheck %s --check-prefix=CHECK-SLICES
 
 #accesses = [
   affine_map<(i, j, k, l, m) -> (i, k, m)>,
@@ -26,11 +27,57 @@ func.func @drop_one_trip_loops(%arg0 : tensor<?x1x?xf32>, %arg1 : f32, %shape: t
 //   CHECK-DAG: #[[$MAP3:.*]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
 // CHECK-LABEL: func @drop_one_trip_loops
 //       CHECK: tensor.collapse_shape %{{.*}} {{\[}}[0, 1], [2]]
+//       CHECK: tensor.collapse_shape %{{.*}} {{\[}}[0, 1], [2, 3], [4]]
 //       CHECK: linalg.generic
 //  CHECK-SAME:   indexing_maps = [#[[$MAP1]], #[[$MAP2]], #[[$MAP3]]]
 //  CHECK-SAME:   iterator_types = ["parallel", "parallel", "parallel"]
 //       CHECK: tensor.expand_shape %{{.*}} {{\[}}[0, 1], [2, 3], [4]]
 
+//   CHECK-SLICES-DAG: #[[$MAP1:.*]] = affine_map<(d0, d1, d2) -> (d0, d2)>
+//   CHECK-SLICES-DAG: #[[$MAP2:.*]] = affine_map<(d0, d1, d2) -> ()>
+//   CHECK-SLICES-DAG: #[[$MAP3:.*]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+// CHECK-SLICES-LABEL: func @drop_one_trip_loops
+//       CHECK-SLICES: tensor.extract_slice %{{.*}}[0, 0, 0] [%{{.*}}, 1, %{{.*}}] [1, 1, 1] : tensor<?x1x?xf32> to tensor<?x?xf32>
+//       CHECK-SLICES: tensor.extract_slice %{{.*}}[0, 0, 0, 0, 0] [%{{.*}}, 1, %{{.*}}, 1, %{{.*}}] [1, 1, 1, 1, 1] : tensor<?x1x?x1x?xf32> to tensor<?x?x?xf32>
+//       CHECK-SLICES: linalg.generic
+//  CHECK-SLICES-SAME:   indexing_maps = [#[[$MAP1]], #[[$MAP2]], #[[$MAP3]]]
+//  CHECK-SLICES-SAME:   iterator_types = ["parallel", "parallel", "parallel"]
+//       CHECK-SLICES: tensor.insert_slice %{{.*}} into %{{.*}}[0, 0, 0, 0, 0] [%{{.*}}, 1, %{{.*}}, 1, %{{.*}}] [1, 1, 1, 1, 1] : tensor<?x?x?xf32> into tensor<?x1x?x1x?xf32>
+
+
+// -----
+
+#accesses = [
+  affine_map<(i, j, k, l, m) -> (i, k, m)>,
+  affine_map<(i, j, k, l, m) -> ()>,
+  affine_map<(i, j, k, l, m) -> (i, k, j, l, m)>
+]
+
+#trait = {
+  iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel"],
+  indexing_maps = #accesses,
+  library_call = "some_external_func"
+}
+
+func.func @drop_one_trip_loops_all_ones(%arg0 : tensor<1x1x1xf32>, %arg1 : f32, %shape: tensor<?x1x?x1x?xf32>) -> tensor<?x1x?x1x?xf32> {
+  %0 = linalg.generic #trait
+     ins(%arg0, %arg1 : tensor<1x1x1xf32>, f32)
+    outs(%shape : tensor<?x1x?x1x?xf32>) {
+       ^bb0(%arg2 : f32, %arg3 : f32, %arg4 : f32) :
+         linalg.yield %arg3 : f32
+       } -> tensor<?x1x?x1x?xf32>
+  return %0 : tensor<?x1x?x1x?xf32>
+}
+//   CHECK-DAG: #[[$MAP1:.*]] = affine_map<(d0) -> ()>
+//   CHECK-DAG: #[[$MAP2:.*]] = affine_map<(d0) -> (0, d0, 0)>
+// CHECK-LABEL: func @drop_one_trip_loops_all_ones
+//       CHECK: tensor.collapse_shape %{{.*}} []
+//       CHECK: tensor.collapse_shape %{{.*}} {{\[}}[0, 1], [2, 3], [4]]
+//       CHECK: linalg.generic
+//  CHECK-SAME:   indexing_maps = [#[[$MAP1]], #[[$MAP1]], #[[$MAP2]]]
+//  CHECK-SAME:   iterator_types = ["parallel"]
+//       CHECK: tensor.expand_shape %{{.*}} {{\[}}[0, 1], [2, 3], [4]]
+
 // -----
 
 #accesses = [
@@ -871,3 +918,8 @@ func.func @drop_all_loops(%arg0 : memref<1x1xf32, 3>) -> memref<1x1xf32, 3>
 //       CHECK:   memref.collapse_shape
 //  CHECK-SAME:     [] : memref<1x1xf32, 3> into memref<f32, 3>
 //       CHECK:   linalg.generic{{.*}}memref<f32, 3>
+
+// CHECK-SLICES-LABEL: func @drop_all_loops
+//       CHECK-SLICES:   memref.subview %{{.*}}[0, 0] [1, 1] [1, 1] : memref<1x1xf32, 3> to memref<f32, strided<[]>, 3>
+//       CHECK-SLICES:   linalg.generic{{.*}}memref<f32, strided<[]>, 3>
+


        


More information about the Mlir-commits mailing list