[Mlir-commits] [mlir] [mlir][linalg] Drop unit dims on IndexingMapOpInterface (PR #150280)

Ian Wood llvmlistbot at llvm.org
Wed Jul 23 11:22:22 PDT 2025


https://github.com/IanWood1 updated https://github.com/llvm/llvm-project/pull/150280

>From cc8510076d54d4f520f75bab4d87dda7910509cc Mon Sep 17 00:00:00 2001
From: Kunwar Grover <groverkss at gmail.com>
Date: Mon, 21 Jul 2025 17:41:50 +0100
Subject: [PATCH 1/2] [mlir][linalg] Move dropUnitDims to work on
 IndexingMapOpInterface

---
 .../Dialect/Linalg/Transforms/Transforms.h    |  12 +-
 .../Linalg/Transforms/DropUnitDims.cpp        | 121 +++++++++++-------
 2 files changed, 88 insertions(+), 45 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
index 38e53648e7c34..8d4abb0d5810c 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -537,10 +537,20 @@ struct ControlDropUnitDims {
     return SmallVector<unsigned>{};
   };
 };
+
 struct DropUnitDimsResult {
-  linalg::GenericOp resultOp;
+  IndexingMapOpInterface resultOp;
   SmallVector<Value> replacements;
 };
+using DroppedUnitDimsBuilder = llvm::function_ref<IndexingMapOpInterface(
+    Location loc, OpBuilder &, IndexingMapOpInterface,
+    ArrayRef<Value> newOperands, ArrayRef<AffineMap> newIndexingMaps,
+    const llvm::SmallDenseSet<unsigned> &droppedDims)>;
+
+FailureOr<DropUnitDimsResult>
+dropUnitDims(RewriterBase &rewriter, IndexingMapOpInterface op,
+             DroppedUnitDimsBuilder &droppedUnitDimsBuilder,
+             const ControlDropUnitDims &options);
 FailureOr<DropUnitDimsResult> dropUnitDims(RewriterBase &rewriter,
                                            GenericOp genericOp,
                                            const ControlDropUnitDims &options);
diff --git a/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp b/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
index e0062d15e61ca..1312add2f9298 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
@@ -331,14 +331,14 @@ struct UnitExtentReplacementInfo {
   SmallVector<int64_t> targetShape;
 };
 static UnitExtentReplacementInfo dropUnitExtentFromOperandMetadata(
-    MLIRContext *context, GenericOp genericOp, OpOperand *opOperand,
+    MLIRContext *context, IndexingMapOpInterface op, OpOperand *opOperand,
     llvm::SmallDenseMap<unsigned, unsigned> &oldDimsToNewDimsMap,
     ArrayRef<AffineExpr> dimReplacements) {
   UnitExtentReplacementInfo info;
   ReassociationIndices reassociationGroup;
   SmallVector<AffineExpr> newIndexExprs;
-  AffineMap indexingMap = genericOp.getMatchingIndexingMap(opOperand);
-  ArrayRef<int64_t> operandShape = genericOp.getShape(opOperand);
+  AffineMap indexingMap = op.getMatchingIndexingMap(opOperand);
+  SmallVector<int64_t> operandShape = op.getStaticOperandShape(opOperand);
   ArrayRef<AffineExpr> exprs = indexingMap.getResults();
 
   auto isUnitDim = [&](unsigned dim) {
@@ -380,9 +380,16 @@ static UnitExtentReplacementInfo dropUnitExtentFromOperandMetadata(
 }
 
 FailureOr<DropUnitDimsResult>
-linalg::dropUnitDims(RewriterBase &rewriter, GenericOp genericOp,
+linalg::dropUnitDims(RewriterBase &rewriter, IndexingMapOpInterface op,
+                     DroppedUnitDimsBuilder &droppedUnitDimsBuilder,
                      const ControlDropUnitDims &options) {
-  SmallVector<AffineMap> indexingMaps = genericOp.getIndexingMapsArray();
+  auto dpsOp = dyn_cast<DestinationStyleOpInterface>(op.getOperation());
+  if (!dpsOp) {
+    return rewriter.notifyMatchFailure(
+        op, "op should implement DestinationStyleOpInterface");
+  }
+
+  SmallVector<AffineMap> indexingMaps = op.getIndexingMapsArray();
   if (indexingMaps.empty())
     return failure();
 
@@ -392,19 +399,19 @@ linalg::dropUnitDims(RewriterBase &rewriter, GenericOp genericOp,
   AffineMap invertedMap =
       inversePermutation(concatAffineMaps(indexingMaps, rewriter.getContext()));
   if (!invertedMap) {
-    return rewriter.notifyMatchFailure(genericOp,
+    return rewriter.notifyMatchFailure(op,
                                        "invalid indexing maps for operation");
   }
 
   SmallVector<int64_t> allShapesSizes;
-  for (OpOperand &opOperand : genericOp->getOpOperands())
-    llvm::append_range(allShapesSizes, genericOp.getShape(&opOperand));
+  for (OpOperand &opOperand : op->getOpOperands())
+    llvm::append_range(allShapesSizes, op.getStaticOperandShape(&opOperand));
 
   // 1a. Get the allowed list of dimensions to drop from the `options`.
-  SmallVector<unsigned> allowedUnitDims = options.controlFn(genericOp);
+  SmallVector<unsigned> allowedUnitDims = options.controlFn(op);
   if (allowedUnitDims.empty()) {
     return rewriter.notifyMatchFailure(
-        genericOp, "control function returns no allowed unit dims to prune");
+        op, "control function returns no allowed unit dims to prune");
   }
   llvm::SmallDenseSet<unsigned> unitDimsFilter(allowedUnitDims.begin(),
                                                allowedUnitDims.end());
@@ -417,19 +424,16 @@ linalg::dropUnitDims(RewriterBase &rewriter, GenericOp genericOp,
     }
   }
 
-  // 2. Compute the iterator types of the modified op by dropping the one-trip
+  // 2. Compute the new loops of the modified op by dropping the one-trip
   //    count loops.
-  SmallVector<utils::IteratorType> newIteratorTypes;
   llvm::SmallDenseMap<unsigned, unsigned> oldDimToNewDimMap;
   SmallVector<AffineExpr> dimReplacements;
   unsigned newDims = 0;
-  for (auto [index, attr] :
-       llvm::enumerate(genericOp.getIteratorTypesArray())) {
+  for (auto index : llvm::seq<int64_t>(op.getStaticLoopRanges().size())) {
     if (unitDims.count(index)) {
       dimReplacements.push_back(
           getAffineConstantExpr(0, rewriter.getContext()));
     } else {
-      newIteratorTypes.push_back(attr);
       oldDimToNewDimMap[index] = newDims;
       dimReplacements.push_back(
           getAffineDimExpr(newDims, rewriter.getContext()));
@@ -462,9 +466,9 @@ linalg::dropUnitDims(RewriterBase &rewriter, GenericOp genericOp,
     }
     return false;
   };
-  for (OpOperand &opOperand : genericOp->getOpOperands()) {
-    auto indexingMap = genericOp.getMatchingIndexingMap(&opOperand);
-    ArrayRef<int64_t> shape = genericOp.getShape(&opOperand);
+  for (OpOperand &opOperand : op->getOpOperands()) {
+    auto indexingMap = op.getMatchingIndexingMap(&opOperand);
+    SmallVector<int64_t> shape = op.getStaticOperandShape(&opOperand);
     if (!hasCollapsibleType(opOperand)) {
       AffineMap newIndexingMap = indexingMap.replaceDimsAndSymbols(
           dimReplacements, ArrayRef<AffineExpr>{}, oldDimToNewDimMap.size(), 0);
@@ -474,9 +478,9 @@ linalg::dropUnitDims(RewriterBase &rewriter, GenericOp genericOp,
       reassociations.push_back({});
       continue;
     }
-    auto replacementInfo = dropUnitExtentFromOperandMetadata(
-        rewriter.getContext(), genericOp, &opOperand, oldDimToNewDimMap,
-        dimReplacements);
+    auto replacementInfo =
+        dropUnitExtentFromOperandMetadata(rewriter.getContext(), op, &opOperand,
+                                          oldDimToNewDimMap, dimReplacements);
     reassociations.push_back(replacementInfo.reassociation);
     newIndexingMaps.push_back(replacementInfo.indexMap);
     targetShapes.push_back(replacementInfo.targetShape);
@@ -491,13 +495,13 @@ linalg::dropUnitDims(RewriterBase &rewriter, GenericOp genericOp,
           concatAffineMaps(newIndexingMaps, rewriter.getContext())))
     return failure();
 
-  Location loc = genericOp.getLoc();
+  Location loc = op.getLoc();
   // 4. For each of the operands, collapse the operand to convert
   //    from original shape to shape in the modified operation if needed,
   //    either through use of reshapes or rank-reducing slices as
   //    specified in `options`.
   SmallVector<Value> newOperands;
-  for (OpOperand &opOperand : genericOp->getOpOperands()) {
+  for (OpOperand &opOperand : op->getOpOperands()) {
     int64_t idx = opOperand.getOperandNumber();
     if (!collapsed[idx]) {
       newOperands.push_back(opOperand.get());
@@ -508,31 +512,15 @@ linalg::dropUnitDims(RewriterBase &rewriter, GenericOp genericOp,
                                         options.rankReductionStrategy));
   }
 
-  // 5. Create the `linalg.generic` operation with the new operands,
-  //    indexing maps, iterator types and result types.
-  ArrayRef<Value> newInputs =
-      ArrayRef<Value>(newOperands).take_front(genericOp.getNumDpsInputs());
-  ArrayRef<Value> newOutputs =
-      ArrayRef<Value>(newOperands).take_back(genericOp.getNumDpsInits());
-  SmallVector<Type> resultTypes;
-  resultTypes.reserve(genericOp.getNumResults());
-  for (unsigned i : llvm::seq<unsigned>(0, genericOp.getNumResults()))
-    resultTypes.push_back(newOutputs[i].getType());
-  GenericOp replacementOp =
-      rewriter.create<GenericOp>(loc, resultTypes, newInputs, newOutputs,
-                                 newIndexingMaps, newIteratorTypes);
-  rewriter.inlineRegionBefore(genericOp.getRegion(), replacementOp.getRegion(),
-                              replacementOp.getRegion().begin());
-  // 5a. Replace `linalg.index` operations that refer to the dropped unit
-  //     dimensions.
-  replaceUnitDimIndexOps(replacementOp, unitDims, rewriter);
+  IndexingMapOpInterface replacementOp = droppedUnitDimsBuilder(
+      loc, rewriter, op, newOperands, newIndexingMaps, unitDims);
 
   // 6. If any result type changes, insert a reshape/slice to convert from the
   //    original type to the new type.
   SmallVector<Value> resultReplacements;
-  for (auto [index, result] : llvm::enumerate(replacementOp.getResults())) {
-    unsigned opOperandIndex = index + replacementOp.getNumDpsInputs();
-    Value origDest = genericOp.getDpsInitOperand(index)->get();
+  for (auto [index, result] : llvm::enumerate(replacementOp->getResults())) {
+    unsigned opOperandIndex = index + dpsOp.getNumDpsInputs();
+    Value origDest = dpsOp.getDpsInitOperand(index)->get();
     if (!collapsed[opOperandIndex]) {
       resultReplacements.push_back(result);
       continue;
@@ -546,6 +534,51 @@ linalg::dropUnitDims(RewriterBase &rewriter, GenericOp genericOp,
   return DropUnitDimsResult{replacementOp, resultReplacements};
 }
 
+FailureOr<DropUnitDimsResult>
+linalg::dropUnitDims(RewriterBase &rewriter, GenericOp genericOp,
+                     const ControlDropUnitDims &options) {
+
+  DroppedUnitDimsBuilder build =
+      [](Location loc, OpBuilder &b, IndexingMapOpInterface op,
+         ArrayRef<Value> newOperands, ArrayRef<AffineMap> newIndexingMaps,
+         const llvm::SmallDenseSet<unsigned> &droppedDims)
+      -> IndexingMapOpInterface {
+    auto genericOp = cast<GenericOp>(op);
+    // Compute the iterator types of the modified op by keeping only the
+    // iterator types of loops that are not in `droppedDims`.
+    SmallVector<utils::IteratorType> newIteratorTypes;
+    for (auto [index, attr] :
+         llvm::enumerate(genericOp.getIteratorTypesArray())) {
+      if (!droppedDims.count(index))
+        newIteratorTypes.push_back(attr);
+    }
+
+    // Create the `linalg.generic` operation with the new operands,
+    // indexing maps, iterator types and result types.
+    ArrayRef<Value> newInputs =
+        ArrayRef<Value>(newOperands).take_front(genericOp.getNumDpsInputs());
+    ArrayRef<Value> newOutputs =
+        ArrayRef<Value>(newOperands).take_back(genericOp.getNumDpsInits());
+    SmallVector<Type> resultTypes;
+    resultTypes.reserve(genericOp.getNumResults());
+    for (unsigned i : llvm::seq<unsigned>(0, genericOp.getNumResults()))
+      resultTypes.push_back(newOutputs[i].getType());
+    GenericOp replacementOp =
+        b.create<GenericOp>(loc, resultTypes, newInputs, newOutputs,
+                            newIndexingMaps, newIteratorTypes);
+    b.cloneRegionBefore(genericOp.getRegion(), replacementOp.getRegion(),
+                        replacementOp.getRegion().begin());
+    // Replace `linalg.index` operations in the new op that refer to the
+    // dropped unit dimensions.
+    IRRewriter rewriter(b);
+    replaceUnitDimIndexOps(replacementOp, droppedDims, rewriter);
+
+    return replacementOp;
+  };
+
+  return dropUnitDims(rewriter, genericOp, build, options);
+}
+
 namespace {
 struct DropUnitDims : public OpRewritePattern<GenericOp> {
   DropUnitDims(MLIRContext *context, ControlDropUnitDims options = {},

>From 461a58c6a98463822cfe2ae982cd283f291e43f3 Mon Sep 17 00:00:00 2001
From: Ian Wood <ianwood at u.northwestern.edu>
Date: Wed, 23 Jul 2025 10:38:33 -0700
Subject: [PATCH 2/2] Fix Wdangling

Signed-off-by: Ian Wood <ianwood at u.northwestern.edu>
---
 mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
index 8d4abb0d5810c..e625eefac5f78 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -542,14 +542,14 @@ struct DropUnitDimsResult {
   IndexingMapOpInterface resultOp;
   SmallVector<Value> replacements;
 };
-using DroppedUnitDimsBuilder = llvm::function_ref<IndexingMapOpInterface(
+using DroppedUnitDimsBuilder = std::function<IndexingMapOpInterface(
     Location loc, OpBuilder &, IndexingMapOpInterface,
     ArrayRef<Value> newOperands, ArrayRef<AffineMap> newIndexingMaps,
     const llvm::SmallDenseSet<unsigned> &droppedDims)>;
 
 FailureOr<DropUnitDimsResult>
 dropUnitDims(RewriterBase &rewriter, IndexingMapOpInterface op,
-             DroppedUnitDimsBuilder &droppedUnitDimsBuilder,
+             const DroppedUnitDimsBuilder &droppedUnitDimsBuilder,
              const ControlDropUnitDims &options);
 FailureOr<DropUnitDimsResult> dropUnitDims(RewriterBase &rewriter,
                                            GenericOp genericOp,



More information about the Mlir-commits mailing list