[Mlir-commits] [mlir] 6739993 - [mlir][Linalg] Cleanup the drop unit dims pass in Linalg.

Mahesh Ravishankar llvmlistbot at llvm.org
Wed Jul 19 10:47:38 PDT 2023


Author: Mahesh Ravishankar
Date: 2023-07-19T17:47:18Z
New Revision: 67399932c767f0a64c83a500dc6f7806c09d9401

URL: https://github.com/llvm/llvm-project/commit/67399932c767f0a64c83a500dc6f7806c09d9401
DIFF: https://github.com/llvm/llvm-project/commit/67399932c767f0a64c83a500dc6f7806c09d9401.diff

LOG: [mlir][Linalg] Cleanup the drop unit dims pass in Linalg.

TL;DR: the following API functions have been merged

```
void populateFoldUnitExtentDimsViaReshapesPatterns(RewritePatternSet &patterns);
void populateFoldUnitExtentDimsViaSlicesPatterns(RewritePatternSet &patterns);
```

into

```
void populateFoldUnitExtentDimsPatterns(RewritePatternSet &patterns,
                                        ControlDropUnitDims &options);
```

To get the previous behavior, use

```
ControlDropUnitDims options;
// By default options.rankReductionStrategy is
// ControlDropUnitDims::RankReductionStrategy::ReassociativeReshape.
populateFoldUnitExtentDimsPatterns(patterns, options);
```

and

```
ControlDropUnitDims options;
options.rankReductionStrategy = ControlDropUnitDims::RankReductionStrategy::ExtractInsertSlice;
populateFoldUnitExtentDimsPatterns(patterns, options);

```

This pass is quite old and needed to be updated to match the current
approach to transformations in Linalg:

- Instead of two patterns (one that only removes unit-extent loop
  dimensions, replacing their uses in the indexing maps with 0, and another
  that drops the unit extents from the operand shapes), use a single
  transformation. This avoids creating an intermediate step whose
  indexing maps contain 0s in their expressions.

- Expose the core transformation as a utility function and add a
  pattern that calls this transformation.

This is mostly an NFC change, apart from the API change and the removal of
the patterns/tests that only dropped unit-extent loops.

Differential Revision: https://reviews.llvm.org/D155518

Added: 
    mlir/test/Dialect/Linalg/test-drop-unit-dims.mlir
    mlir/test/lib/Dialect/Linalg/TestLinalgDropUnitDims.cpp

Modified: 
    mlir/include/mlir/Dialect/Linalg/Passes.td
    mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
    mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
    mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
    mlir/test/Dialect/Linalg/drop-unit-extent-dims.mlir
    mlir/test/lib/Dialect/Linalg/CMakeLists.txt
    mlir/tools/mlir-opt/mlir-opt.cpp

Removed: 
    mlir/test/Dialect/Linalg/fold-unit-trip-loops.mlir


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Linalg/Passes.td b/mlir/include/mlir/Dialect/Linalg/Passes.td
index 1ed867b3a0df75..3093604af63e33 100644
--- a/mlir/include/mlir/Dialect/Linalg/Passes.td
+++ b/mlir/include/mlir/Dialect/Linalg/Passes.td
@@ -28,10 +28,6 @@ def LinalgFoldUnitExtentDims : Pass<"linalg-fold-unit-extent-dims", ""> {
   let summary = "Remove unit-extent dimension in Linalg ops on tensors";
   let constructor = "mlir::createLinalgFoldUnitExtentDimsPass()";
   let options = [
-    Option<"foldOneTripLoopsOnly", "fold-one-trip-loops-only", "bool",
-            /*default=*/"false",
-           "Only folds the one-trip loops from Linalg ops on tensors "
-           "(for testing purposes only)">,
     Option<"useRankReducingSlices", "use-rank-reducing-slices", "bool",
            /*default=*/"false",
            "Generate rank-reducing slices instead of reassociative reshapes">

diff  --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
index a78dc1e1e571bc..68fce05c0d2221 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -419,6 +419,25 @@ LogicalResult vectorizeOpPrecondition(Operation *op,
 
 using LinalgLoops = SmallVector<Operation *, 4>;
 
+/// Options for `dropUnitDims`, which drops unit-extent loop dimensions (and
+/// the corresponding operand dimensions) from `linalg.generic` operations.
+struct ControlDropUnitDims {
+  enum class RankReductionStrategy { ReassociativeReshape, ExtractInsertSlice };
+  // How operands are rank-reduced: reshapes (default) or slices.
+  RankReductionStrategy rankReductionStrategy =
+      RankReductionStrategy::ReassociativeReshape;
+  // Loops allowed to be dropped if unit-extent (default: all GenericOp loops).
+  using ControlFnTy = std::function<SmallVector<unsigned>(Operation *)>;
+  ControlFnTy controlFn = [](Operation *op) {
+    if (auto genericOp = dyn_cast_or_null<GenericOp>(op)) {
+      return llvm::to_vector(llvm::seq<unsigned>(0, genericOp.getNumLoops()));
+    }
+    return SmallVector<unsigned>{};
+  };
+};
+LogicalResult dropUnitDims(RewriterBase &rewriter, GenericOp genericOp,
+                           const ControlDropUnitDims &options);
+
 /// Fuse two `linalg.generic` operations that have a producer-consumer
 /// relationship captured through `fusedOperand`. The method expects
 /// that `areElementwiseOpsFusable` returns true for the given `fusedOperand`.
@@ -1496,11 +1515,8 @@ void populateLinalgNamedOpConversionPatterns(RewritePatternSet &patterns);
 
 /// Patterns to fold unit-extent dimensions in operands/results of linalg ops on
 /// tensors via reassociative reshape ops.
-void populateFoldUnitExtentDimsViaReshapesPatterns(RewritePatternSet &patterns);
-
-/// Patterns to fold unit-extent dimensions in operands/results of linalg ops on
-/// tensors via rank-reducing slices.
-void populateFoldUnitExtentDimsViaSlicesPatterns(RewritePatternSet &patterns);
+void populateFoldUnitExtentDimsPatterns(RewritePatternSet &patterns,
+                                        ControlDropUnitDims &options);
 
 /// A pattern that converts init operands to input operands.
 void populateMoveInitOperandsToInputPattern(RewritePatternSet &patterns);

diff  --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index f6e0f27548a20c..40b602ffd4f80b 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -156,12 +156,16 @@ void transform::ApplyEraseUnnecessaryInputsPatternsOp::populatePatterns(
 
 void transform::ApplyFoldUnitExtentDimsViaReshapesPatternsOp::populatePatterns(
     RewritePatternSet &patterns) {
-  linalg::populateFoldUnitExtentDimsViaReshapesPatterns(patterns);
+  linalg::ControlDropUnitDims options;
+  linalg::populateFoldUnitExtentDimsPatterns(patterns, options);
 }
 
 void transform::ApplyFoldUnitExtentDimsViaSlicesPatternsOp::populatePatterns(
     RewritePatternSet &patterns) {
-  linalg::populateFoldUnitExtentDimsViaSlicesPatterns(patterns);
+  linalg::ControlDropUnitDims options;
+  options.rankReductionStrategy =
+      linalg::ControlDropUnitDims::RankReductionStrategy::ExtractInsertSlice;
+  linalg::populateFoldUnitExtentDimsPatterns(patterns, options);
 }
 
 void transform::ApplyTilingCanonicalizationPatternsOp::populatePatterns(

diff  --git a/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp b/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
index 894036a535302e..b33c75ca94fc0b 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
@@ -43,196 +43,6 @@ using namespace mlir;
 using namespace mlir::linalg;
 
 namespace {
-enum class RankReductionStrategy { ReassociativeReshape, ExtractInsertSlice };
-} // namespace
-
-/// Implements a pass that canonicalizes the uses of unit-extent dimensions for
-/// broadcasting. For example,
-///
-/// ```mlir
-/// #accesses = [
-///   affine_map<(d0, d1) -> (0, d1)>,
-///   affine_map<(d0, d1) -> (d0, 0)>,
-///   affine_map<(d0, d1) -> (d0, d1)>
-/// ]
-///
-/// #trait = {
-///   args_in = 2,
-///   args_out = 1,
-///   indexing_maps = #accesses,
-///   iterator_types = ["parallel", "parallel"],
-///   library_call = "some_external_fn"
-/// }
-///
-/// func @broadcast_test(%arg0 : tensor<5xf32>, %arg1 : tensor<5xf32>) ->
-/// tensor<5x5xf32>
-/// {
-///   %0 = linalg.tensor_reshape %arg0 [affine_map<(d0, d1) -> (d0, d1)>] :
-///        tensor<5xf32> into tensor<1x5xf32>
-///   %1 = linalg.tensor_reshape %arg1 [affine_map<(d0, d1) -> (d0, d1)>] :
-///        tensor<5xf32> into tensor<5x1xf32>
-///   %2 = linalg.generic #trait %0, %1 {
-///        ^bb0(%arg2: f32, %arg3: f32):
-///          %3 = arith.addf %arg2, %arg3 : f32
-///          linalg.yield %3 : f32
-///        } : tensor<1x5xf32>, tensor<5x1xf32> -> tensor<5x5xf32>
-///   return %2 : tensor<5x5xf32>
-/// }
-///
-/// would canonicalize to
-///
-/// ```mlir
-/// #accesses = [
-///   affine_map<(d0, d1) -> (d1)>,
-///   affine_map<(d0, d1) -> (d0)>,
-///   affine_map<(d0, d1) -> (d0, d1)>
-/// ]
-///
-/// #trait = {
-///   args_in = 2,
-///   args_out = 1,
-///   indexing_maps = #accesses,
-///   iterator_types = ["parallel", "parallel"],
-///   library_call = "some_external_fn"
-/// }
-///
-/// func @broadcast_test(%arg0 : tensor<5xf32>, %arg1 : tensor<5xf32>) ->
-/// tensor<5x5xf32>
-/// {
-///   %0 = linalg.generic #trait %arg0, %arg1 {
-///        ^bb0(%arg2: f32, %arg3: f32):
-///          %3 = arith.addf %arg2, %arg3 : f32
-///          linalg.yield %3 : f32
-///        } : tensor<5xf32>, tensor<5xf32> -> tensor<5x5xf32>
-///   return %0 : tensor<5x5xf32>
-/// }
-
-/// Given dims of the iteration space of a structured op that are known to be
-/// single trip count (`unitDims`), return the indexing maps to use in the
-/// canonicalized op with these dims removed, given the original `indexingMaps`.
-static ArrayAttr replaceUnitDims(DenseSet<unsigned> &unitDims,
-                                 ArrayRef<AffineMap> indexingMaps,
-                                 MLIRContext *context) {
-  if (indexingMaps.empty())
-    return nullptr;
-  unsigned numIterationDims = indexingMaps.front().getNumDims();
-  unsigned numSymbols = indexingMaps.front().getNumSymbols();
-
-  // Compute the replacement for each dim expr.
-  SmallVector<AffineExpr, 4> dimReplacements;
-  dimReplacements.reserve(numIterationDims);
-  unsigned numKeptDims = 0;
-  for (unsigned dim : llvm::seq<unsigned>(0, numIterationDims)) {
-    if (unitDims.count(dim))
-      dimReplacements.push_back(getAffineConstantExpr(0, context));
-    else
-      dimReplacements.push_back(getAffineDimExpr(numKeptDims++, context));
-  }
-
-  // Symbols remain the same.
-  SmallVector<AffineExpr, 4> symReplacements;
-  symReplacements.reserve(numSymbols);
-  for (unsigned symbol : llvm::seq<unsigned>(0, numSymbols))
-    symReplacements.push_back(getAffineSymbolExpr(symbol, context));
-
-  SmallVector<AffineMap, 4> newIndexingMaps;
-  newIndexingMaps.reserve(indexingMaps.size());
-  for (AffineMap operandMap : indexingMaps) {
-    // Expected indexing maps to have no symbols.
-    if (operandMap.getNumSymbols())
-      return nullptr;
-    newIndexingMaps.push_back(simplifyAffineMap(
-        operandMap.replaceDimsAndSymbols(dimReplacements, symReplacements,
-                                         numIterationDims - unitDims.size(),
-                                         numSymbols)));
-  }
-
-  // Check that the new index maps are invertible. If not, something went
-  // wrong, so abort.
-  if (!inversePermutation(concatAffineMaps(newIndexingMaps)))
-    return nullptr;
-  return ArrayAttr::get(context,
-                        llvm::to_vector<4>(llvm::map_range(
-                            newIndexingMaps, [](AffineMap map) -> Attribute {
-                              return AffineMapAttr::get(map);
-                            })));
-}
-
-/// Update the index accesses of linalg operations having index semantics.
-static void replaceUnitDimIndexOps(GenericOp genericOp,
-                                   const DenseSet<unsigned> &unitDims,
-                                   PatternRewriter &rewriter) {
-  for (IndexOp indexOp :
-       llvm::make_early_inc_range(genericOp.getBody()->getOps<IndexOp>())) {
-    OpBuilder::InsertionGuard guard(rewriter);
-    rewriter.setInsertionPoint(indexOp);
-    if (unitDims.count(indexOp.getDim()) != 0) {
-      rewriter.replaceOpWithNewOp<arith::ConstantIndexOp>(indexOp, 0);
-    } else {
-      // Update the dimension of the index operation if needed.
-      unsigned droppedDims = llvm::count_if(
-          unitDims, [&](unsigned dim) { return dim < indexOp.getDim(); });
-      if (droppedDims != 0)
-        rewriter.replaceOpWithNewOp<IndexOp>(indexOp,
-                                             indexOp.getDim() - droppedDims);
-    }
-  }
-}
-
-namespace {
-/// Pattern to fold unit-trip count loops in GenericOps.
-struct FoldUnitDimLoops : public OpRewritePattern<GenericOp> {
-  using OpRewritePattern<GenericOp>::OpRewritePattern;
-  LogicalResult matchAndRewrite(GenericOp genericOp,
-                                PatternRewriter &rewriter) const override {
-    SmallVector<AffineMap, 4> indexingMaps = genericOp.getIndexingMapsArray();
-    if (indexingMaps.empty())
-      return failure();
-
-    // Check if any of the iteration dimensions are unit-trip count. They will
-    // end up being unit-trip count if they are used to index into a unit-dim
-    // tensor/memref.
-    AffineMap invertedMap = inversePermutation(concatAffineMaps(indexingMaps));
-    if (!invertedMap)
-      return failure();
-    SmallVector<int64_t> dims = genericOp.getStaticShape();
-
-    DenseSet<unsigned> unitDims;
-    SmallVector<unsigned, 4> unitDimsReductionLoops;
-    ArrayAttr iteratorTypes = genericOp.getIteratorTypes();
-    for (const auto &expr : enumerate(invertedMap.getResults())) {
-      if (AffineDimExpr dimExpr = expr.value().dyn_cast<AffineDimExpr>())
-        if (dims[dimExpr.getPosition()] == 1)
-          unitDims.insert(expr.index());
-    }
-
-    if (unitDims.empty())
-      return failure();
-
-    // Compute the modified indexing maps.
-    MLIRContext *context = rewriter.getContext();
-    ArrayAttr newIndexingMapAttr =
-        replaceUnitDims(unitDims, indexingMaps, context);
-    if (!newIndexingMapAttr)
-      return genericOp.emitError("unable to compute modified indexing_maps");
-
-    // Compute the iterator types of the modified op by dropping the one-trip
-    // count loops.
-    SmallVector<Attribute, 4> newIteratorTypes;
-    for (const auto &attr : llvm::enumerate(iteratorTypes)) {
-      if (!unitDims.count(attr.index()))
-        newIteratorTypes.push_back(attr.value());
-    }
-
-    rewriter.startRootUpdate(genericOp);
-    genericOp.setIndexingMapsAttr(newIndexingMapAttr);
-    genericOp.setIteratorTypesAttr(ArrayAttr::get(context, newIteratorTypes));
-    replaceUnitDimIndexOps(genericOp, unitDims, rewriter);
-    rewriter.finalizeRootUpdate(genericOp);
-    return success();
-  }
-};
-
 /// Pattern to move init operands to ins when all the loops are parallel and
 /// blockArgument corresponding to init is used in the region. This is a fix-up
 /// when unit reduction dimensions are all folded away. In this context, it
@@ -351,243 +161,405 @@ struct MoveInitOperandsToInput : public OpRewritePattern<GenericOp> {
     return success();
   }
 };
+} // namespace
+
+//===---------------------------------------------------------------------===//
+// Drop loops that are unit-extents within Linalg operations.
+//===---------------------------------------------------------------------===//
+
+/// Canonicalizes uses of unit-extent dimensions (e.g. for broadcasting) by
+/// dropping them from the loops and operands of a `linalg.generic`. Example:
+///
+/// ```mlir
+/// #accesses = [
+///   affine_map<(d0, d1) -> (0, d1)>,
+///   affine_map<(d0, d1) -> (d0, 0)>,
+///   affine_map<(d0, d1) -> (d0, d1)>
+/// ]
+///
+/// #trait = {
+///   args_in = 2,
+///   args_out = 1,
+///   indexing_maps = #accesses,
+///   iterator_types = ["parallel", "parallel"],
+///   library_call = "some_external_fn"
+/// }
+///
+/// func @broadcast_test(%arg0 : tensor<5xf32>, %arg1 : tensor<5xf32>) ->
+/// tensor<5x5xf32>
+/// {
+///   %0 = linalg.tensor_reshape %arg0 [affine_map<(d0, d1) -> (d0, d1)>] :
+///        tensor<5xf32> into tensor<1x5xf32>
+///   %1 = linalg.tensor_reshape %arg1 [affine_map<(d0, d1) -> (d0, d1)>] :
+///        tensor<5xf32> into tensor<5x1xf32>
+///   %2 = linalg.generic #trait %0, %1 {
+///        ^bb0(%arg2: f32, %arg3: f32):
+///          %3 = arith.addf %arg2, %arg3 : f32
+///          linalg.yield %3 : f32
+///        } : tensor<1x5xf32>, tensor<5x1xf32> -> tensor<5x5xf32>
+///   return %2 : tensor<5x5xf32>
+/// }
+/// ```
+/// would canonicalize to
+///
+/// ```mlir
+/// #accesses = [
+///   affine_map<(d0, d1) -> (d1)>,
+///   affine_map<(d0, d1) -> (d0)>,
+///   affine_map<(d0, d1) -> (d0, d1)>
+/// ]
+///
+/// #trait = {
+///   args_in = 2,
+///   args_out = 1,
+///   indexing_maps = #accesses,
+///   iterator_types = ["parallel", "parallel"],
+///   library_call = "some_external_fn"
+/// }
+/// func @broadcast_test(%arg0 : tensor<5xf32>, %arg1 : tensor<5xf32>) ->
+/// tensor<5x5xf32>
+/// {
+///   %0 = linalg.generic #trait %arg0, %arg1 {
+///        ^bb0(%arg2: f32, %arg3: f32):
+///          %3 = arith.addf %arg2, %arg3 : f32
+///          linalg.yield %3 : f32
+///        } : tensor<5xf32>, tensor<5xf32> -> tensor<5x5xf32>
+///   return %0 : tensor<5x5xf32>
+/// }
+/// ```
 
+/// Replace `linalg.index` of dropped dims with 0; renumber the remaining dims.
+static void
+replaceUnitDimIndexOps(GenericOp genericOp,
+                       const llvm::SmallDenseSet<unsigned> &unitDims,
+                       RewriterBase &rewriter) {
+  for (IndexOp indexOp :
+       llvm::make_early_inc_range(genericOp.getBody()->getOps<IndexOp>())) {
+    OpBuilder::InsertionGuard guard(rewriter);
+    rewriter.setInsertionPoint(indexOp);
+    if (unitDims.count(indexOp.getDim()) != 0) {
+      rewriter.replaceOpWithNewOp<arith::ConstantIndexOp>(indexOp, 0);
+    } else {
+      // Shift the dim down by the number of dropped dims that precede it.
+      unsigned droppedDims = llvm::count_if(
+          unitDims, [&](unsigned dim) { return dim < indexOp.getDim(); });
+      if (droppedDims != 0)
+        rewriter.replaceOpWithNewOp<IndexOp>(indexOp,
+                                             indexOp.getDim() - droppedDims);
+    }
+  }
+}
+
+/// Expand `result` back to the tensor type of `origDest`: a
+/// `tensor.expand_shape` using `reassociation` (ReassociativeReshape), or a
+/// `tensor.insert_slice` of `result` into `origDest` (ExtractInsertSlice).
+static Value
+expandValue(RewriterBase &rewriter, Location loc, Value result, Value origDest,
+            ArrayRef<ReassociationIndices> reassociation,
+            ControlDropUnitDims::RankReductionStrategy rankReductionStrategy) {
+  // Memref outputs have no results, so `origDest` must be a tensor.
+  auto origResultType = cast<RankedTensorType>(origDest.getType());
+  if (rankReductionStrategy ==
+      ControlDropUnitDims::RankReductionStrategy::ExtractInsertSlice) {
+    unsigned rank = origResultType.getRank();
+    SmallVector<OpFoldResult> offsets(rank, rewriter.getIndexAttr(0));
+    SmallVector<OpFoldResult> sizes =
+        tensor::getMixedSizes(rewriter, loc, origDest);
+    SmallVector<OpFoldResult> strides(rank, rewriter.getIndexAttr(1));
+    return rewriter.createOrFold<tensor::InsertSliceOp>(
+        loc, result, origDest, offsets, sizes, strides);
+  }
+  // ReassociativeReshape: expand_shape back to the original result type.
+  assert(rankReductionStrategy ==
+             ControlDropUnitDims::RankReductionStrategy::ReassociativeReshape &&
+         "unknown rank reduction strategy");
+  return rewriter.create<tensor::ExpandShapeOp>(loc, origResultType, result,
+                                                reassociation);
+}
+
+/// Collapse `operand` (memref or ranked tensor) to the shape `targetShape`,
+/// via collapse_shape with `reassociation` (ReassociativeReshape) or via a
+/// rank-reducing subview/extract_slice (ExtractInsertSlice).
+static Value collapseValue(
+    RewriterBase &rewriter, Location loc, Value operand,
+    ArrayRef<int64_t> targetShape, ArrayRef<ReassociationIndices> reassociation,
+    ControlDropUnitDims::RankReductionStrategy rankReductionStrategy) {
+  if (auto memrefType = dyn_cast<MemRefType>(operand.getType())) {
+    if (rankReductionStrategy ==
+        ControlDropUnitDims::RankReductionStrategy::ExtractInsertSlice) {
+      FailureOr<Value> rankReducingExtract =
+          memref::SubViewOp::rankReduceIfNeeded(rewriter, loc, operand,
+                                                targetShape);
+      assert(succeeded(rankReducingExtract) && "not a unit-extent collapse");
+      return *rankReducingExtract;
+    }
+
+    assert(
+        rankReductionStrategy ==
+            ControlDropUnitDims::RankReductionStrategy::ReassociativeReshape &&
+        "unknown rank reduction strategy");
+    MemRefLayoutAttrInterface layout;
+    auto targetType = MemRefType::get(targetShape, memrefType.getElementType(),
+                                      layout, memrefType.getMemorySpace());
+    return rewriter.create<memref::CollapseShapeOp>(loc, targetType, operand,
+                                                    reassociation);
+  }
+  if (auto tensorType = dyn_cast<RankedTensorType>(operand.getType())) {
+    if (rankReductionStrategy ==
+        ControlDropUnitDims::RankReductionStrategy::ExtractInsertSlice) {
+      FailureOr<Value> rankReducingExtract =
+          tensor::ExtractSliceOp::rankReduceIfNeeded(rewriter, loc, operand,
+                                                     targetShape);
+      assert(succeeded(rankReducingExtract) && "not a unit-extent collapse");
+      return *rankReducingExtract;
+    }
+
+    assert(
+        rankReductionStrategy ==
+            ControlDropUnitDims::RankReductionStrategy::ReassociativeReshape &&
+        "unknown rank reduction strategy");
+    auto targetType =
+        RankedTensorType::get(targetShape, tensorType.getElementType());
+    return rewriter.create<tensor::CollapseShapeOp>(loc, targetType, operand,
+                                                    reassociation);
+  }
+  llvm_unreachable("unsupported operand type");
+}
+
+/// Metadata for one operand of an operation whose unit dims are being
+/// dropped: the new indexing map to use, the shape of the operand in the
+/// replacement op, and the `reassociation` needed to go from the original
+/// operand shape to the modified operand shape (the reassociation is only
+/// needed by the ReassociativeReshape rank-reduction strategy).
 struct UnitExtentReplacementInfo {
   AffineMap indexMap;
   SmallVector<ReassociationIndices> reassociation;
   SmallVector<int64_t> targetShape;
 };
-} // namespace
-
-/// Utility function for replacing operands/results to a linalg generic
-/// operation with unit-extent dimensions. These can be replaced with
-/// an operand/result with the unit-extent dimension removed. This is only done
-/// if the indexing map used to access that dimension has a
-/// AffineConstantExpr of value 0. Given the `type` of an result/operand of a
-/// Linalg op, and its `indexMap` the utility function returns:
-/// - the new type with dimensions of size 1 removed.
-/// - modified index map that can be used to access the replaced result/operand
-/// - the reassociation that converts from the original tensor type to the
-///   modified tensor type.
-static std::optional<UnitExtentReplacementInfo>
-replaceUnitExtents(GenericOp genericOp, OpOperand *opOperand,
-                   MLIRContext *context) {
+static UnitExtentReplacementInfo dropUnitExtentFromOperandMetadata(
+    MLIRContext *context, GenericOp genericOp, OpOperand *opOperand,
+    llvm::SmallDenseMap<unsigned, unsigned> &oldDimsToNewDimsMap,
+    ArrayRef<AffineExpr> dimReplacements) {
+  UnitExtentReplacementInfo info;
+  ReassociationIndices reassociationGroup;
+  SmallVector<AffineExpr> newIndexExprs;
   AffineMap indexingMap = genericOp.getMatchingIndexingMap(opOperand);
-  ArrayRef<int64_t> shape = genericOp.getShape(opOperand);
+  ArrayRef<int64_t> operandShape = genericOp.getShape(opOperand);
   ArrayRef<AffineExpr> exprs = indexingMap.getResults();
-  SmallVector<AffineExpr> newIndexExprs;
-  SmallVector<int64_t> newShape;
 
-  int64_t origRank = genericOp.getRank(opOperand);
-  AffineExpr zeroExpr = getAffineConstantExpr(0, context);
-  auto isUnitExtent = [&](int64_t dim) -> bool {
-    return shape[dim] == 1 && exprs[dim] == zeroExpr;
+  auto isUnitDim = [&](unsigned dim) {
+    if (auto dimExpr = exprs[dim].dyn_cast<AffineDimExpr>()) {
+      unsigned oldPosition = dimExpr.getPosition();
+      return !oldDimsToNewDimsMap.count(oldPosition);
+    }
+    // Handle the other case where the shape is 1, and is accessed using a
+    // constant 0.
+    if (operandShape[dim] == 1) {
+      auto constAffineExpr = exprs[dim].dyn_cast<AffineConstantExpr>();
+      return constAffineExpr && constAffineExpr.getValue() == 0;
+    }
+    return false;
   };
 
-  // Early return for memrefs with affine maps to represent that we will always
-  // leave them unchanged.
-  Type actualType = opOperand->get().getType();
-  if (auto memref = dyn_cast<MemRefType>(actualType)) {
-    if (!memref.getLayout().isIdentity())
-      return std::nullopt;
-  }
-
   int64_t dim = 0;
-  SmallVector<ReassociationIndices> reassociation;
-  ReassociationIndices reassociationGroup;
-  // Fold dimensions that are unit-extent at the beginning of the tensor.
-  while (dim < origRank && isUnitExtent(dim))
+  while (dim < operandShape.size() && isUnitDim(dim))
     reassociationGroup.push_back(dim++);
-  while (dim < origRank) {
-    assert(!isUnitExtent(dim) && "expected non unit-extent");
+  while (dim < operandShape.size()) {
+    assert(!isUnitDim(dim) && "expected non unit-extent");
     reassociationGroup.push_back(dim);
-    newIndexExprs.push_back(exprs[dim]);
-    newShape.push_back(shape[dim]);
+    AffineExpr newExpr = exprs[dim].replaceDims(dimReplacements);
+    newIndexExprs.push_back(newExpr);
+    info.targetShape.push_back(operandShape[dim]);
     ++dim;
     // Fold all following dimensions that are unit-extent.
-    while (dim < origRank && isUnitExtent(dim))
+    while (dim < operandShape.size() && isUnitDim(dim)) {
       reassociationGroup.push_back(dim++);
-    reassociation.push_back(reassociationGroup);
+    }
+    info.reassociation.push_back(reassociationGroup);
     reassociationGroup.clear();
   }
-
-  // Return if the rank was not reduced.
-  if (origRank == static_cast<int64_t>(newShape.size()))
-    return std::nullopt;
-
-  UnitExtentReplacementInfo info = {
-      /*indexMap=*/AffineMap::get(indexingMap.getNumDims(),
-                                  indexingMap.getNumSymbols(), newIndexExprs,
-                                  context),
-      /*reassociation=*/reassociation, /*targetShape=*/newShape};
+  info.indexMap =
+      AffineMap::get(oldDimsToNewDimsMap.size(), indexingMap.getNumSymbols(),
+                     newIndexExprs, context);
   return info;
 }
 
-namespace {
+LogicalResult linalg::dropUnitDims(RewriterBase &rewriter, GenericOp genericOp,
+                                   const ControlDropUnitDims &options) {
+  SmallVector<AffineMap> indexingMaps = genericOp.getIndexingMapsArray();
+  if (indexingMaps.empty())
+    return failure();
+
+  // 1. Check if any of the iteration dimensions are unit-trip count. They will
+  //    end up being unit-trip count if they are used to index into a unit-dim
+  //    tensor/memref.
+  AffineMap invertedMap = inversePermutation(concatAffineMaps(indexingMaps));
+  if (!invertedMap) {
+    return rewriter.notifyMatchFailure(genericOp,
+                                       "invalid indexing maps for operation");
+  }
+  SmallVector<int64_t> dims = genericOp.getStaticShape();
 
-/// Pattern to replace tensor/buffer operands/results that are unit extents.
-struct ReplaceUnitExtents : public OpRewritePattern<GenericOp> {
-  ReplaceUnitExtents(MLIRContext *ctx,
-                     RankReductionStrategy rankReductionStrategy)
-      : OpRewritePattern<GenericOp>(ctx),
-        rankReductionStrategy(rankReductionStrategy) {}
-
-  // Expand the given value.
-  Value expandValue(Value result, Value origOutput,
-                    ArrayRef<ReassociationIndices> reassociation, Location loc,
-                    PatternRewriter &rewriter) const {
-    // There are no results for memref outputs.
-    auto origResultType = cast<RankedTensorType>(origOutput.getType());
-    if (rankReductionStrategy == RankReductionStrategy::ExtractInsertSlice) {
-      unsigned rank = origResultType.getRank();
-      SmallVector<OpFoldResult> offsets(rank, rewriter.getIndexAttr(0));
-      SmallVector<OpFoldResult> sizes =
-          tensor::getMixedSizes(rewriter, loc, origOutput);
-      SmallVector<OpFoldResult> strides(rank, rewriter.getIndexAttr(1));
-      return rewriter.createOrFold<tensor::InsertSliceOp>(
-          loc, result, origOutput, offsets, sizes, strides);
+  // 1a. Get the allowed list of dimensions to drop from the `options`.
+  SmallVector<unsigned> allowedUnitDims = options.controlFn(genericOp);
+  if (allowedUnitDims.empty()) {
+    return rewriter.notifyMatchFailure(
+        genericOp, "control function returns no allowed unit dims to prune");
+  }
+  llvm::SmallDenseSet<unsigned> unitDimsFilter(allowedUnitDims.begin(),
+                                               allowedUnitDims.end());
+  llvm::SmallDenseSet<unsigned> unitDims;
+  ArrayAttr iteratorTypes = genericOp.getIteratorTypes();
+  for (const auto &expr : enumerate(invertedMap.getResults())) {
+    if (AffineDimExpr dimExpr = expr.value().dyn_cast<AffineDimExpr>()) {
+      if (dims[dimExpr.getPosition()] == 1 &&
+          unitDimsFilter.count(expr.index()))
+        unitDims.insert(expr.index());
     }
+  }
 
-    assert(rankReductionStrategy ==
-               RankReductionStrategy::ReassociativeReshape &&
-           "unknown rank reduction strategy");
-    return rewriter.create<tensor::ExpandShapeOp>(loc, origResultType, result,
-                                                  reassociation);
+  // 2. Compute the iterator types of the modified op by dropping the one-trip
+  //    count loops.
+  SmallVector<utils::IteratorType> newIteratorTypes;
+  llvm::SmallDenseMap<unsigned, unsigned> oldDimToNewDimMap;
+  SmallVector<AffineExpr> dimReplacements;
+  unsigned newDims = 0;
+  for (auto [index, attr] :
+       llvm::enumerate(genericOp.getIteratorTypesArray())) {
+    if (unitDims.count(index)) {
+      dimReplacements.push_back(
+          getAffineConstantExpr(0, rewriter.getContext()));
+    } else {
+      newIteratorTypes.push_back(attr);
+      oldDimToNewDimMap[index] = newDims;
+      dimReplacements.push_back(
+          getAffineDimExpr(newDims, rewriter.getContext()));
+      newDims++;
+    }
   }
 
-  // Collapse the given value.
-  Value collapseValue(Value operand, ArrayRef<int64_t> targetShape,
-                      ArrayRef<ReassociationIndices> reassociation,
-                      Location loc, PatternRewriter &rewriter) const {
-    if (auto memrefType = dyn_cast<MemRefType>(operand.getType())) {
-      if (rankReductionStrategy == RankReductionStrategy::ExtractInsertSlice) {
-        FailureOr<Value> rankReducingExtract =
-            memref::SubViewOp::rankReduceIfNeeded(rewriter, loc, operand,
-                                                  targetShape);
-        assert(succeeded(rankReducingExtract) && "not a unit-extent collapse");
-        return *rankReducingExtract;
-      }
-
-      assert(rankReductionStrategy ==
-                 RankReductionStrategy::ReassociativeReshape &&
-             "unknown rank reduction strategy");
-      MemRefLayoutAttrInterface layout;
-      auto targetType =
-          MemRefType::get(targetShape, memrefType.getElementType(), layout,
-                          memrefType.getMemorySpace());
-      return rewriter.create<memref::CollapseShapeOp>(loc, targetType, operand,
-                                                      reassociation);
+  // 3. For each of the operands, find the
+  //    - modified affine map to use.
+  //    - shape of the operands after the unit-dims are dropped.
+  //    - the reassociation indices used to convert from the original
+  //      operand type to modified operand (needed only when using reshapes
+  //      for rank reduction strategy)
+  // Note that the indexing maps might need changing even if there are no
+  // unit dimensions that are dropped to handle cases where `0` is used to
+  // access a unit-extent tensor. Consider moving this out of this specific
+  // transformation as a stand-alone transformation. Kept here right now due
+  // to legacy.
+  SmallVector<AffineMap> newIndexingMaps;
+  SmallVector<SmallVector<ReassociationIndices>> reassociations;
+  SmallVector<SmallVector<int64_t>> targetShapes;
+  SmallVector<bool> collapsed;
+  auto hasCollapsibleType = [](OpOperand &operand) {
+    Type operandType = operand.get().getType();
+    if (auto memrefOperandType = dyn_cast_or_null<MemRefType>(operandType)) {
+      return memrefOperandType.getLayout().isIdentity();
+    } else if (auto tensorOperandType =
+                   dyn_cast<RankedTensorType>(operandType)) {
+      return tensorOperandType.getEncoding() == nullptr;
     }
-    if (auto tensorType = dyn_cast<RankedTensorType>(operand.getType())) {
-      if (rankReductionStrategy == RankReductionStrategy::ExtractInsertSlice) {
-        FailureOr<Value> rankReducingExtract =
-            tensor::ExtractSliceOp::rankReduceIfNeeded(rewriter, loc, operand,
-                                                       targetShape);
-        assert(succeeded(rankReducingExtract) && "not a unit-extent collapse");
-        return *rankReducingExtract;
-      }
-
-      assert(rankReductionStrategy ==
-                 RankReductionStrategy::ReassociativeReshape &&
-             "unknown rank reduction strategy");
-      auto targetType =
-          RankedTensorType::get(targetShape, tensorType.getElementType());
-      return rewriter.create<tensor::CollapseShapeOp>(loc, targetType, operand,
-                                                      reassociation);
+    return false;
+  };
+  for (OpOperand &opOperand : genericOp->getOpOperands()) {
+    auto indexingMap = genericOp.getMatchingIndexingMap(&opOperand);
+    ArrayRef<int64_t> shape = genericOp.getShape(&opOperand);
+    if (!hasCollapsibleType(opOperand)) {
+      AffineMap newIndexingMap = indexingMap.replaceDimsAndSymbols(
+          dimReplacements, ArrayRef<AffineExpr>{}, oldDimToNewDimMap.size(), 0);
+      newIndexingMaps.push_back(newIndexingMap);
+      targetShapes.push_back(llvm::to_vector(shape));
+      collapsed.push_back(false);
+      reassociations.push_back({});
+      continue;
     }
-    llvm_unreachable("unsupported operand type");
+    auto replacementInfo = dropUnitExtentFromOperandMetadata(
+        rewriter.getContext(), genericOp, &opOperand, oldDimToNewDimMap,
+        dimReplacements);
+    reassociations.push_back(replacementInfo.reassociation);
+    newIndexingMaps.push_back(replacementInfo.indexMap);
+    targetShapes.push_back(replacementInfo.targetShape);
+    collapsed.push_back(!(replacementInfo.indexMap.getNumResults() ==
+                          indexingMap.getNumResults()));
   }
 
-  LogicalResult matchAndRewrite(GenericOp genericOp,
-                                PatternRewriter &rewriter) const override {
-    // Skip the pattern if the op has any tensor with special encoding.
-    if (llvm::any_of(genericOp->getOperandTypes(), [](Type type) {
-          auto tensorType = dyn_cast<RankedTensorType>(type);
-          return tensorType && tensorType.getEncoding() != nullptr;
-        }))
-      return failure();
-    MLIRContext *context = rewriter.getContext();
-    Location loc = genericOp.getLoc();
-    SmallVector<Value> oldOutputs(genericOp.getOutputs().begin(),
-                                  genericOp.getOutputs().end());
-
-    SmallVector<AffineMap> newIndexingMaps;
-    SmallVector<SmallVector<ReassociationIndices>> reassociations;
-    SmallVector<SmallVector<int64_t>> targetShapes;
-    SmallVector<bool> collapsed;
-    for (OpOperand &opOperand : genericOp->getOpOperands()) {
-      auto replacementInfo = replaceUnitExtents(genericOp, &opOperand, context);
-      if (replacementInfo) {
-        reassociations.push_back(replacementInfo->reassociation);
-        newIndexingMaps.push_back(replacementInfo->indexMap);
-        targetShapes.push_back(replacementInfo->targetShape);
-        collapsed.push_back(true);
-      } else {
-        // If replaceUnitExtents cannot handle this case (or no unit dim was
-        // removed), maintain the same type, indexing map, and create a set of
-        // mappings representing an identity matrix.
-        newIndexingMaps.push_back(genericOp.getMatchingIndexingMap(&opOperand));
-        reassociations.emplace_back();
-        targetShapes.emplace_back();
-        collapsed.push_back(false);
-      }
+  // Abort if the indexing maps of the result operation are not invertible
+  // (i.e. not legal) or if no dimension was reduced.
+  if (newIndexingMaps == indexingMaps ||
+      !inversePermutation(concatAffineMaps(newIndexingMaps)))
+    return failure();
+
+  Location loc = genericOp.getLoc();
+  // 4. For each of the operands, collapse the operand to convert
+  //    from original shape to shape in the modified operation if needed,
+  //    either through use of reshapes or rank-reducing slices as
+  //    specified in `options`.
+  SmallVector<Value> newOperands;
+  for (OpOperand &opOperand : genericOp->getOpOperands()) {
+    int64_t idx = opOperand.getOperandNumber();
+    if (!collapsed[idx]) {
+      newOperands.push_back(opOperand.get());
+      continue;
     }
+    newOperands.push_back(collapseValue(rewriter, loc, opOperand.get(),
+                                        targetShapes[idx], reassociations[idx],
+                                        options.rankReductionStrategy));
+  }
 
-    // Abort if the indexing maps of the result operation are not invertible
-    // (i.e. not legal) or if no dimension was reduced.
-    if (!llvm::any_of(collapsed, [](bool c) { return c; }) ||
-        !inversePermutation(concatAffineMaps(newIndexingMaps)))
-      return failure();
-
-    // Insert rank reductions.
-    SmallVector<Value> newOperands;
-    for (OpOperand &opOperand : genericOp->getOpOperands()) {
-      int64_t idx = opOperand.getOperandNumber();
-      if (!collapsed[idx]) {
-        newOperands.push_back(opOperand.get());
-        continue;
-      }
-      newOperands.push_back(collapseValue(opOperand.get(), targetShapes[idx],
-                                          reassociations[idx], loc, rewriter));
+  // 5. Create the `linalg.generic` operation with the new operands,
+  //    indexing maps, iterator types and result types.
+  ArrayRef<Value> newInputs =
+      ArrayRef<Value>(newOperands).take_front(genericOp.getNumDpsInputs());
+  ArrayRef<Value> newOutputs =
+      ArrayRef<Value>(newOperands).take_back(genericOp.getNumDpsInits());
+  SmallVector<Type> resultTypes;
+  resultTypes.reserve(genericOp.getNumResults());
+  for (unsigned i : llvm::seq<unsigned>(0, genericOp.getNumResults()))
+    resultTypes.push_back(newOutputs[i].getType());
+  GenericOp replacementOp =
+      rewriter.create<GenericOp>(loc, resultTypes, newInputs, newOutputs,
+                                 newIndexingMaps, newIteratorTypes);
+  rewriter.inlineRegionBefore(genericOp.getRegion(), replacementOp.getRegion(),
+                              replacementOp.getRegion().begin());
+  // 5a. Replace `linalg.index` operations that refer to the dropped unit
+  // dimensions.
+  replaceUnitDimIndexOps(replacementOp, unitDims, rewriter);
+
+  // 6. If any result type changes, insert a reshape/slice to convert from the
+  // original
+  //    type to the new type.
+  SmallVector<Value> resultReplacements;
+  for (auto [index, result] : llvm::enumerate(replacementOp.getResults())) {
+    unsigned opOperandIndex = index + replacementOp.getNumDpsInputs();
+    Value origDest = genericOp.getDpsInitOperand(index)->get();
+    if (!collapsed[opOperandIndex]) {
+      resultReplacements.push_back(result);
+      continue;
     }
+    resultReplacements.push_back(expandValue(rewriter, loc, result, origDest,
+                                             reassociations[opOperandIndex],
+                                             options.rankReductionStrategy));
+  }
 
-    // If any result type changes, insert a reshape to convert from the original
-    // type to the new type.
-    ArrayRef<Value> newInputs =
-        ArrayRef<Value>(newOperands).take_front(genericOp.getNumDpsInputs());
-    ArrayRef<Value> newOutputs =
-        ArrayRef<Value>(newOperands).take_back(genericOp.getNumDpsInits());
-    SmallVector<Type> resultTypes;
-    resultTypes.reserve(genericOp.getNumResults());
-    for (unsigned i : llvm::seq<unsigned>(0, genericOp.getNumResults()))
-      resultTypes.push_back(newOutputs[i].getType());
-    GenericOp replacementOp = rewriter.create<GenericOp>(
-        loc, resultTypes, newInputs, newOutputs, newIndexingMaps,
-        genericOp.getIteratorTypesArray());
-    rewriter.inlineRegionBefore(genericOp.getRegion(),
-                                replacementOp.getRegion(),
-                                replacementOp.getRegion().begin());
-
-    // If any result tensor has a modified shape, then add reshape to recover
-    // the original shape.
-    SmallVector<Value> resultReplacements;
-    for (const auto &result : llvm::enumerate(replacementOp.getResults())) {
-      unsigned index = result.index() + replacementOp.getNumDpsInputs();
-      Value origOutput = oldOutputs[result.index()];
-      if (!collapsed[result.index() + genericOp.getNumDpsInputs()]) {
-        resultReplacements.push_back(result.value());
-        continue;
-      }
-      resultReplacements.push_back(expandValue(
-          result.value(), origOutput, reassociations[index], loc, rewriter));
-    }
+  rewriter.replaceOp(genericOp, resultReplacements);
+  return success();
+}
 
-    rewriter.replaceOp(genericOp, resultReplacements);
-    return success();
+namespace {
+/// Pattern that rewrites each `linalg.generic` through `dropUnitDims`.
+struct DropUnitDims : public OpRewritePattern<GenericOp> {
+  DropUnitDims(MLIRContext *context, ControlDropUnitDims options = {},
+               PatternBenefit benefit = 1)
+      : OpRewritePattern(context, benefit), options(std::move(options)) {}
+  LogicalResult matchAndRewrite(GenericOp genericOp,
+                                PatternRewriter &rewriter) const override {
+    return dropUnitDims(rewriter, genericOp, options);
   }
 
 private:
-  RankReductionStrategy rankReductionStrategy;
+  ControlDropUnitDims options;
 };
 } // namespace
 
@@ -641,8 +613,8 @@ struct RankReducedInsertSliceOp : public OpRewritePattern<InsertOpTy> {
     tensor::CollapseShapeOp reshapedSource;
     {
       OpBuilder::InsertionGuard g(rewriter);
-      // The only difference between InsertSliceOp and ParallelInsertSliceOp is
-      // the insertion point is just before the ParallelCombiningOp in the
+      // The only difference between InsertSliceOp and ParallelInsertSliceOp
+      // is the insertion point is just before the ParallelCombiningOp in the
       // parallel case.
       if (std::is_same<InsertOpTy, tensor::ParallelInsertSliceOp>::value)
         rewriter.setInsertionPoint(insertSliceOp->getParentOp());
@@ -660,13 +632,13 @@ struct RankReducedInsertSliceOp : public OpRewritePattern<InsertOpTy> {
 
 /// Patterns that are used to canonicalize the use of unit-extent dims for
 /// broadcasting.
-void mlir::linalg::populateFoldUnitExtentDimsViaReshapesPatterns(
-    RewritePatternSet &patterns) {
+static void
+populateFoldUnitExtentDimsViaReshapesPatterns(RewritePatternSet &patterns,
+                                              ControlDropUnitDims &options) {
   auto *context = patterns.getContext();
-  patterns.add<ReplaceUnitExtents>(context,
-                                   RankReductionStrategy::ReassociativeReshape);
+  patterns.add<DropUnitDims>(context, options);
   // TODO: Patterns unrelated to unit dim folding should be factored out.
-  patterns.add<FoldUnitDimLoops, RankReducedExtractSliceOp,
+  patterns.add<RankReducedExtractSliceOp,
                RankReducedInsertSliceOp<tensor::InsertSliceOp>,
                RankReducedInsertSliceOp<tensor::ParallelInsertSliceOp>>(
       context);
@@ -679,12 +651,13 @@ void mlir::linalg::populateFoldUnitExtentDimsViaReshapesPatterns(
   memref::populateResolveShapedTypeResultDimsPatterns(patterns);
 }
 
-void mlir::linalg::populateFoldUnitExtentDimsViaSlicesPatterns(
-    RewritePatternSet &patterns) {
+static void
+populateFoldUnitExtentDimsViaSlicesPatterns(RewritePatternSet &patterns,
+                                            ControlDropUnitDims &options) {
   auto *context = patterns.getContext();
-  patterns.add<ReplaceUnitExtents>(context,
-                                   RankReductionStrategy::ExtractInsertSlice);
-  patterns.add<FoldUnitDimLoops>(context);
+  options.rankReductionStrategy =
+      ControlDropUnitDims::RankReductionStrategy::ExtractInsertSlice;
+  patterns.add<DropUnitDims>(context, options);
   // TODO: Patterns unrelated to unit dim folding should be factored out.
   linalg::FillOp::getCanonicalizationPatterns(patterns, context);
   tensor::EmptyOp::getCanonicalizationPatterns(patterns, context);
@@ -693,6 +666,18 @@ void mlir::linalg::populateFoldUnitExtentDimsViaSlicesPatterns(
   memref::populateResolveShapedTypeResultDimsPatterns(patterns);
 }
 
+/// Populate the unit-dim folding patterns for the configured strategy.
+void mlir::linalg::populateFoldUnitExtentDimsPatterns(
+    RewritePatternSet &patterns, linalg::ControlDropUnitDims &options) {
+  using Strategy = linalg::ControlDropUnitDims::RankReductionStrategy;
+  switch (options.rankReductionStrategy) {
+  case Strategy::ExtractInsertSlice:
+    return populateFoldUnitExtentDimsViaSlicesPatterns(patterns, options);
+  case Strategy::ReassociativeReshape:
+    return populateFoldUnitExtentDimsViaReshapesPatterns(patterns, options);
+  }
+}
+
 void mlir::linalg::populateMoveInitOperandsToInputPattern(
     RewritePatternSet &patterns) {
   patterns.add<MoveInitOperandsToInput>(patterns.getContext());
@@ -706,15 +691,13 @@ struct LinalgFoldUnitExtentDimsPass
     Operation *op = getOperation();
     MLIRContext *context = op->getContext();
     RewritePatternSet patterns(context);
-    if (foldOneTripLoopsOnly) {
-      patterns.add<FoldUnitDimLoops, MoveInitOperandsToInput>(context);
-    } else if (useRankReducingSlices) {
-      populateFoldUnitExtentDimsViaSlicesPatterns(patterns);
-      populateMoveInitOperandsToInputPattern(patterns);
-    } else {
-      populateFoldUnitExtentDimsViaReshapesPatterns(patterns);
-      populateMoveInitOperandsToInputPattern(patterns);
+    ControlDropUnitDims options;
+    if (useRankReducingSlices) {
+      options.rankReductionStrategy = linalg::ControlDropUnitDims::
+          RankReductionStrategy::ExtractInsertSlice;
     }
+    linalg::populateFoldUnitExtentDimsPatterns(patterns, options);
+    populateMoveInitOperandsToInputPattern(patterns);
     (void)applyPatternsAndFoldGreedily(op, std::move(patterns));
   }
 };

diff  --git a/mlir/test/Dialect/Linalg/drop-unit-extent-dims.mlir b/mlir/test/Dialect/Linalg/drop-unit-extent-dims.mlir
index b48a1d6e0cf0b4..88659f8628ae70 100644
--- a/mlir/test/Dialect/Linalg/drop-unit-extent-dims.mlir
+++ b/mlir/test/Dialect/Linalg/drop-unit-extent-dims.mlir
@@ -59,24 +59,24 @@ func.func @drop_one_trip_loops(%arg0 : tensor<?x1x?xf32>, %arg1 : f32, %shape: t
   library_call = "some_external_func"
 }
 
-func.func @drop_one_trip_loops_all_ones(%arg0 : tensor<1x1x1xf32>, %arg1 : f32, %shape: tensor<?x1x?x1x?xf32>) -> tensor<?x1x?x1x?xf32> {
+func.func @drop_one_trip_loops_all_ones(%arg0 : tensor<1x1x1xf32>, %arg1 : f32, %shape: tensor<1x1x?x1x1xf32>) -> tensor<1x1x?x1x1xf32> {
   %0 = linalg.generic #trait
      ins(%arg0, %arg1 : tensor<1x1x1xf32>, f32)
-    outs(%shape : tensor<?x1x?x1x?xf32>) {
+    outs(%shape : tensor<1x1x?x1x1xf32>) {
        ^bb0(%arg2 : f32, %arg3 : f32, %arg4 : f32) :
          linalg.yield %arg3 : f32
-       } -> tensor<?x1x?x1x?xf32>
-  return %0 : tensor<?x1x?x1x?xf32>
+       } -> tensor<1x1x?x1x1xf32>
+  return %0 : tensor<1x1x?x1x1xf32>
 }
 //   CHECK-DAG: #[[$MAP1:.*]] = affine_map<(d0) -> ()>
-//   CHECK-DAG: #[[$MAP2:.*]] = affine_map<(d0) -> (0, d0, 0)>
+//   CHECK-DAG: #[[$MAP2:.*]] = affine_map<(d0) -> (d0)>
 // CHECK-LABEL: func @drop_one_trip_loops_all_ones
 //       CHECK: tensor.collapse_shape %{{.*}} []
-//       CHECK: tensor.collapse_shape %{{.*}} {{\[}}[0, 1], [2, 3], [4]]
+//       CHECK: tensor.collapse_shape %{{.*}} {{\[}}[0, 1, 2, 3, 4]]
 //       CHECK: linalg.generic
 //  CHECK-SAME:   indexing_maps = [#[[$MAP1]], #[[$MAP1]], #[[$MAP2]]]
 //  CHECK-SAME:   iterator_types = ["parallel"]
-//       CHECK: tensor.expand_shape %{{.*}} {{\[}}[0, 1], [2, 3], [4]]
+//       CHECK: tensor.expand_shape %{{.*}} {{\[}}[0, 1, 2, 3, 4]]
 
 // -----
 
@@ -922,4 +922,3 @@ func.func @drop_all_loops(%arg0 : memref<1x1xf32, 3>) -> memref<1x1xf32, 3>
 // CHECK-SLICES-LABEL: func @drop_all_loops
 //       CHECK-SLICES:   memref.subview %{{.*}}[0, 0] [1, 1] [1, 1] : memref<1x1xf32, 3> to memref<f32, strided<[]>, 3>
 //       CHECK-SLICES:   linalg.generic{{.*}}memref<f32, strided<[]>, 3>
-

diff  --git a/mlir/test/Dialect/Linalg/fold-unit-trip-loops.mlir b/mlir/test/Dialect/Linalg/fold-unit-trip-loops.mlir
deleted file mode 100644
index 2f265a72fd7bf7..00000000000000
--- a/mlir/test/Dialect/Linalg/fold-unit-trip-loops.mlir
+++ /dev/null
@@ -1,110 +0,0 @@
-// RUN: mlir-opt %s -split-input-file -pass-pipeline="builtin.module(func.func(linalg-fold-unit-extent-dims{fold-one-trip-loops-only}))" | FileCheck %s
-
-#accesses = [
-  affine_map<(i, j, k, l, m) -> (i, k, m)>,
-  affine_map<(i, j, k, l, m) -> (i, k, j, l, m)>
-]
-
-#trait = {
-  iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel"],
-  indexing_maps = #accesses,
-  library_call = "some_external_func"
-}
-
-func.func @drop_one_trip_loops(%arg0 : tensor<?x1x?xf32>, %shape: tensor<?x1x?x1x?xf32>) -> tensor<?x1x?x1x?xf32>
-{
-  %0 = linalg.generic #trait
-    ins(%arg0 : tensor<?x1x?xf32>)
-    outs(%shape : tensor<?x1x?x1x?xf32>) {
-       ^bb0(%arg1 : f32, %arg2 : f32) :
-         linalg.yield %arg1 : f32
-       } -> tensor<?x1x?x1x?xf32>
-  return %0 : tensor<?x1x?x1x?xf32>
-}
-//   CHECK-DAG: #[[$MAP0:.*]] = affine_map<(d0, d1, d2) -> (d0, 0, d2)>
-//   CHECK-DAG: #[[$MAP1:.*]] = affine_map<(d0, d1, d2) -> (d0, 0, d1, 0, d2)>
-// CHECK-LABEL: func @drop_one_trip_loops
-//       CHECK:   linalg.generic
-//  CHECK-SAME:     indexing_maps = [#[[$MAP0]], #[[$MAP1]]]
-//  CHECK-SAME:     iterator_types = ["parallel", "parallel", "parallel"]
-
-// -----
-
-#map0 = affine_map<(i, j) -> (i, j)>
-#access = [#map0, #map0]
-#trait = {
-  iterator_types = ["parallel", "parallel"],
-  indexing_maps = #access,
-  library_call = "some_external_func"
-}
-
-func.func @drop_all_loops(%arg0 : tensor<1x1xf32>) -> tensor<1x1xf32>
-{
-  %0 = linalg.generic #trait
-     ins(%arg0 : tensor<1x1xf32>)
-    outs(%arg0 : tensor<1x1xf32>) {
-       ^bb0(%arg1: f32, %arg2: f32) :
-         linalg.yield %arg1 : f32
-       } -> tensor<1x1xf32>
-  return %0 : tensor<1x1xf32>
-}
-//   CHECK-DAG: #[[$MAP0:.*]] = affine_map<() -> (0, 0)>
-// CHECK-LABEL: func @drop_all_loops
-//       CHECK:   linalg.generic
-//  CHECK-SAME:     indexing_maps = [#[[$MAP0]], #[[$MAP0]]]
-//  CHECK-SAME:     iterator_types = []
-
-// -----
-
-#map0 = affine_map<(i, j) -> (i, j)>
-#access = [#map0, #map0]
-#trait = {
-  iterator_types = ["parallel", "parallel"],
-  indexing_maps = #access,
-  library_call = "some_external_func"
-}
-
-func.func @drop_all_loops(%arg0 : memref<1x1xf32>, %arg1 : memref<1x1xf32>)
-{
-  linalg.generic #trait
-     ins(%arg0 : memref<1x1xf32>)
-    outs(%arg1 : memref<1x1xf32>) {
-    ^bb0(%arg2: f32, %arg3 : f32) :
-      linalg.yield %arg2 : f32
-    }
-  return
-}
-//   CHECK-DAG: #[[$MAP0:.*]] = affine_map<() -> (0, 0)>
-// CHECK-LABEL: func @drop_all_loops
-//       CHECK:   linalg.generic
-//  CHECK-SAME:     indexing_maps = [#[[$MAP0]], #[[$MAP0]]]
-//  CHECK-SAME:     iterator_types = []
-
-// -----
-
-#accesses = [
-  affine_map<(d0, d1) -> (d0, d1)>,
-  affine_map<(d0, d1) -> (d1)>
-]
-
-#trait = {
-  indexing_maps = #accesses,
-  iterator_types = ["parallel", "parallel"],
-  library_call = "some_external_fn"
-}
-
-func.func @leading_dim_1_canonicalization(%arg0: tensor<1x5xf32>, %shape: tensor<5xf32>) -> tensor<5xf32> {
-  %0 = linalg.generic #trait
-       ins(%arg0 : tensor<1x5xf32>)
-      outs(%shape : tensor<5xf32>) {
-    ^bb0(%arg2: f32, %arg3: f32):
-      linalg.yield %arg2 : f32
-  } -> tensor<5xf32>
-  return %0 : tensor<5xf32>
-}
-//   CHECK-DAG: #[[$MAP0:.*]] = affine_map<(d0) -> (0, d0)>
-//   CHECK-DAG: #[[$MAP1:.*]] = affine_map<(d0) -> (d0)>
-// CHECK-LABEL: func @leading_dim_1_canonicalization
-//       CHECK:   linalg.generic
-//  CHECK-SAME:     indexing_maps = [#[[$MAP0]], #[[$MAP1]]]
-//  CHECK-SAME:     iterator_types = ["parallel"]

diff  --git a/mlir/test/Dialect/Linalg/test-drop-unit-dims.mlir b/mlir/test/Dialect/Linalg/test-drop-unit-dims.mlir
new file mode 100644
index 00000000000000..35eeffc1f99532
--- /dev/null
+++ b/mlir/test/Dialect/Linalg/test-drop-unit-dims.mlir
@@ -0,0 +1,26 @@
+// RUN: mlir-opt -test-linalg-drop-unit-dims --split-input-file %s | FileCheck %s
+
+// Drop only the outermost unit dimension (controlled using a control function)
+func.func @drop_outermost_unit_dims(%arg0: tensor<1x1x42xf32>) -> tensor<1x1x42xf32> {
+  %0 = tensor.empty() : tensor<1x1x42xf32>
+  %1 = linalg.generic {
+    indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
+                     affine_map<(d0, d1, d2) -> (d0, d1, d2)>],
+    iterator_types = ["parallel", "parallel", "parallel"]}
+    ins(%arg0 : tensor<1x1x42xf32>) outs(%0 : tensor<1x1x42xf32>) {
+      ^bb0(%b0: f32, %b1 : f32):
+        %2 = arith.addf %b0, %b1 : f32
+        linalg.yield %2 : f32
+    } -> tensor<1x1x42xf32>
+  return %1 : tensor<1x1x42xf32>
+}
+// CHECK-LABEL: func @drop_outermost_unit_dims
+//  CHECK-SAME:     %[[ARG0:.+]]: tensor<1x1x42xf32>
+//       CHECK:   %[[OUTS:.+]] = tensor.empty()
+//       CHECK:   %[[ARG0_RESHAPE:.+]] = tensor.collapse_shape %[[ARG0]] {{\[}}[0, 1], [2]{{\]}}
+//       CHECK:   %[[OUTS_RESHAPE:.+]] = tensor.collapse_shape %[[OUTS]] {{\[}}[0, 1], [2]{{\]}}
+//       CHECK:   %[[GENERIC:.+]] = linalg.generic
+//  CHECK-SAME:       ins(%[[ARG0_RESHAPE]] :
+//  CHECK-SAME:       outs(%[[OUTS_RESHAPE]] :
+//       CHECK:   %[[EXPAND_SHAPE:.+]] = tensor.expand_shape %[[GENERIC]] {{\[}}[0, 1], [2]{{\]}}
+//       CHECK:   return %[[EXPAND_SHAPE]]

diff  --git a/mlir/test/lib/Dialect/Linalg/CMakeLists.txt b/mlir/test/lib/Dialect/Linalg/CMakeLists.txt
index 03aefc8d7117e0..b28f2b3564662a 100644
--- a/mlir/test/lib/Dialect/Linalg/CMakeLists.txt
+++ b/mlir/test/lib/Dialect/Linalg/CMakeLists.txt
@@ -2,6 +2,7 @@
 add_mlir_library(MLIRLinalgTestPasses
   TestDataLayoutPropagation.cpp
   TestLinalgDecomposeOps.cpp
+  TestLinalgDropUnitDims.cpp
   TestLinalgElementwiseFusion.cpp
   TestLinalgFusionTransforms.cpp
   TestLinalgTransforms.cpp

diff  --git a/mlir/test/lib/Dialect/Linalg/TestLinalgDropUnitDims.cpp b/mlir/test/lib/Dialect/Linalg/TestLinalgDropUnitDims.cpp
new file mode 100644
index 00000000000000..a3a6a49d64b003
--- /dev/null
+++ b/mlir/test/lib/Dialect/Linalg/TestLinalgDropUnitDims.cpp
@@ -0,0 +1,82 @@
+//===- TestLinalgDropUnitDims.cpp - Test Linalg drop unit dims -----------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements a pass for testing the transformation to drop unit
+// extent dimensions from `linalg.generic` operations.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/Linalg/IR/Linalg.h"
+#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+
+using namespace mlir;
+
+/// Runs `linalg::dropUnitDims` on `genericOp` with a control function that
+/// only offers loop dimension 0 (the outermost loop) as a candidate to drop.
+/// Returns whatever `linalg::dropUnitDims` returns.
+static LogicalResult dropOutermostUnitDims(RewriterBase &rewriter,
+                                           linalg::GenericOp genericOp) {
+  linalg::ControlDropUnitDims options;
+  options.controlFn = [](Operation *) { return SmallVector<unsigned>{0}; };
+  return linalg::dropUnitDims(rewriter, genericOp, options);
+}
+
+namespace {
+
+/// Test pass that applies `dropOutermostUnitDims` to every `linalg.generic`
+/// nested in a `func.func`.
+struct TestLinalgDropUnitDims
+    : public PassWrapper<TestLinalgDropUnitDims, OperationPass<func::FuncOp>> {
+
+  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestLinalgDropUnitDims)
+
+  TestLinalgDropUnitDims() = default;
+  TestLinalgDropUnitDims(const TestLinalgDropUnitDims &pass)
+      : PassWrapper(pass) {}
+
+  void getDependentDialects(DialectRegistry &registry) const override {
+    registry.insert<linalg::LinalgDialect>();
+  }
+
+  StringRef getArgument() const final { return "test-linalg-drop-unit-dims"; }
+
+  // `final` makes a misspelled name a compile error instead of silently
+  // leaving `Pass::getDescription` un-overridden.
+  StringRef getDescription() const final {
+    return "Test transformation to drop unit-extent dims from Linalg "
+           "operations";
+  }
+
+  void runOnOperation() override {
+    func::FuncOp funcOp = getOperation();
+    IRRewriter rewriter(&getContext());
+
+    // Collect the ops up front so rewrites do not mutate the IR mid-walk.
+    SmallVector<linalg::GenericOp> genericOps;
+    funcOp.walk(
+        [&](linalg::GenericOp genericOp) { genericOps.push_back(genericOp); });
+
+    for (linalg::GenericOp genericOp : genericOps) {
+      rewriter.setInsertionPoint(genericOp);
+      // Ops that `dropUnitDims` declines to rewrite are simply skipped.
+      (void)dropOutermostUnitDims(rewriter, genericOp);
+    }
+  }
+};
+} // namespace
+
+namespace mlir {
+namespace test {
+void registerTestLinalgDropUnitDims() {
+  PassRegistration<TestLinalgDropUnitDims>();
+}
+} // namespace test
+} // namespace mlir

diff  --git a/mlir/tools/mlir-opt/mlir-opt.cpp b/mlir/tools/mlir-opt/mlir-opt.cpp
index e4ba8ab36393bc..08bb6445fb0bc8 100644
--- a/mlir/tools/mlir-opt/mlir-opt.cpp
+++ b/mlir/tools/mlir-opt/mlir-opt.cpp
@@ -100,6 +100,7 @@ void registerTestGenericIRVisitorsInterruptPass();
 void registerTestInterfaces();
 void registerTestLastModifiedPass();
 void registerTestLinalgDecomposeOps();
+void registerTestLinalgDropUnitDims();
 void registerTestLinalgElementwiseFusion();
 void registerTestLinalgGreedyFusion();
 void registerTestLinalgTransforms();
@@ -222,6 +223,7 @@ void registerTestPasses() {
   mlir::test::registerTestInterfaces();
   mlir::test::registerTestLastModifiedPass();
   mlir::test::registerTestLinalgDecomposeOps();
+  mlir::test::registerTestLinalgDropUnitDims();
   mlir::test::registerTestLinalgElementwiseFusion();
   mlir::test::registerTestLinalgGreedyFusion();
   mlir::test::registerTestLinalgTransforms();


        


More information about the Mlir-commits mailing list