[Mlir-commits] [mlir] 4dbaef6 - [mlir][Linalg] Avoid doing op replacement in `linalg::dropUnitDims`. (#105749)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Fri Aug 23 13:43:37 PDT 2024


Author: MaheshRavishankar
Date: 2024-08-23T13:43:33-07:00
New Revision: 4dbaef6d5ea71fb183114a82da4028960906c42b

URL: https://github.com/llvm/llvm-project/commit/4dbaef6d5ea71fb183114a82da4028960906c42b
DIFF: https://github.com/llvm/llvm-project/commit/4dbaef6d5ea71fb183114a82da4028960906c42b.diff

LOG: [mlir][Linalg] Avoid doing op replacement in `linalg::dropUnitDims`. (#105749)

It is better to perform the replacement in the caller. This avoids a
footgun when the caller still needs the original operation. Instead,
return the newly produced operation and the replacement values.

Signed-off-by: MaheshRavishankar <mahesh.ravishankar at gmail.com>

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
    mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
    mlir/test/lib/Dialect/Linalg/TestLinalgDropUnitDims.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
index bee3452ebb685f..0208f854f799ec 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -488,8 +488,13 @@ struct ControlDropUnitDims {
     return SmallVector<unsigned>{};
   };
 };
-LogicalResult dropUnitDims(RewriterBase &rewriter, GenericOp genericOp,
-                           const ControlDropUnitDims &options);
+struct DropUnitDimsResult {
+  linalg::GenericOp resultOp;
+  SmallVector<Value> replacements;
+};
+FailureOr<DropUnitDimsResult> dropUnitDims(RewriterBase &rewriter,
+                                           GenericOp genericOp,
+                                           const ControlDropUnitDims &options);
 
 /// Fuse two `linalg.generic` operations that have a producer-consumer
 /// relationship captured through `fusedOperand`. The method expects

diff  --git a/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp b/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
index 36f8696bf1b274..88ef82fb38d67b 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
@@ -386,8 +386,9 @@ static UnitExtentReplacementInfo dropUnitExtentFromOperandMetadata(
   return info;
 }
 
-LogicalResult linalg::dropUnitDims(RewriterBase &rewriter, GenericOp genericOp,
-                                   const ControlDropUnitDims &options) {
+FailureOr<DropUnitDimsResult>
+linalg::dropUnitDims(RewriterBase &rewriter, GenericOp genericOp,
+                     const ControlDropUnitDims &options) {
   SmallVector<AffineMap> indexingMaps = genericOp.getIndexingMapsArray();
   if (indexingMaps.empty())
     return failure();
@@ -545,8 +546,7 @@ LogicalResult linalg::dropUnitDims(RewriterBase &rewriter, GenericOp genericOp,
     resultReplacements.push_back(expandedValue);
   }
 
-  rewriter.replaceOp(genericOp, resultReplacements);
-  return success();
+  return DropUnitDimsResult{replacementOp, resultReplacements};
 }
 
 namespace {
@@ -557,7 +557,13 @@ struct DropUnitDims : public OpRewritePattern<GenericOp> {
 
   LogicalResult matchAndRewrite(GenericOp genericOp,
                                 PatternRewriter &rewriter) const override {
-    return dropUnitDims(rewriter, genericOp, options);
+    FailureOr<DropUnitDimsResult> result =
+        dropUnitDims(rewriter, genericOp, options);
+    if (failed(result)) {
+      return failure();
+    }
+    rewriter.replaceOp(genericOp, result->replacements);
+    return success();
   }
 
 private:

diff  --git a/mlir/test/lib/Dialect/Linalg/TestLinalgDropUnitDims.cpp b/mlir/test/lib/Dialect/Linalg/TestLinalgDropUnitDims.cpp
index 85a6d5f9d9215c..402ce154c0848e 100644
--- a/mlir/test/lib/Dialect/Linalg/TestLinalgDropUnitDims.cpp
+++ b/mlir/test/lib/Dialect/Linalg/TestLinalgDropUnitDims.cpp
@@ -25,7 +25,13 @@ LogicalResult dropOutermostUnitDims(RewriterBase &rewriter,
                                     linalg::GenericOp genericOp) {
   linalg::ControlDropUnitDims options;
   options.controlFn = [](Operation *op) { return SmallVector<unsigned>{0}; };
-  return linalg::dropUnitDims(rewriter, genericOp, options);
+  FailureOr<linalg::DropUnitDimsResult> result =
+      linalg::dropUnitDims(rewriter, genericOp, options);
+  if (failed(result)) {
+    return failure();
+  }
+  rewriter.replaceOp(genericOp, result->replacements);
+  return success();
 }
 
 struct TestLinalgDropUnitDims


        


More information about the Mlir-commits mailing list