[Mlir-commits] [mlir] [mlir][linalg][Transform] Fix use-after-free in `SplitOp::apply` (PR #96390)

Matthias Springer llvmlistbot at llvm.org
Mon Jun 24 12:26:14 PDT 2024


https://github.com/matthias-springer updated https://github.com/llvm/llvm-project/pull/96390

From 87eb38226edc6650ca7bac0f48078cf34fee2faa Mon Sep 17 00:00:00 2001
From: Matthias Springer <mspringer at nvidia.com>
Date: Mon, 24 Jun 2024 21:25:47 +0200
Subject: [PATCH 1/2] Revert "[mlir] Fix use-after-free introduced in
 a9efcbf490d9b8f46ec37062ca8653b4068000e5."

This reverts commit 48cf6b6bbe7a22bfcd98f82dc7afd21c9decd22f.
---
 .../Dialect/Linalg/TransformOps/LinalgTransformOps.cpp | 10 +++++-----
 1 file changed, 5 insertions(+), 5 deletions(-)

diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index 4ef27b1d091274..37467db568c27d 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -2335,10 +2335,10 @@ SplitOp::apply(transform::TransformRewriter &rewriter,
   };
 
   auto checkFailureInSplitting =
-      [&](bool hasFailed, Operation *op) -> DiagnosedSilenceableFailure {
+      [&](bool hasFailed, Location loc) -> DiagnosedSilenceableFailure {
     if (hasFailed) {
       auto diag = emitDefiniteFailure() << "internal failure in splitting";
-      diag.attachNote(op->getLoc()) << "target op";
+      diag.attachNote(loc) << "target op";
       return diag;
     }
     return DiagnosedSilenceableFailure::success();
@@ -2376,7 +2376,7 @@ SplitOp::apply(transform::TransformRewriter &rewriter,
 
       // Propagate errors.
       DiagnosedSilenceableFailure diag =
-          checkFailureInSplitting(!head && !tail, target);
+          checkFailureInSplitting(!head && !tail, target->getLoc());
       if (diag.isDefiniteFailure())
         return diag;
 
@@ -2408,8 +2408,8 @@ SplitOp::apply(transform::TransformRewriter &rewriter,
           getDimension(), std::get<1>(pair));
 
       // Propagate errors.
-      DiagnosedSilenceableFailure diagSplit =
-          checkFailureInSplitting(!first.back() && !second.back(), target);
+      DiagnosedSilenceableFailure diagSplit = checkFailureInSplitting(
+          !first.back() && !second.back(), target->getLoc());
       if (diagSplit.isDefiniteFailure())
         return diag;
 

From de6b2cb1eca786378ddbc1f867f27343ca7de0ab Mon Sep 17 00:00:00 2001
From: Matthias Springer <mspringer at nvidia.com>
Date: Sat, 22 Jun 2024 14:55:35 +0200
Subject: [PATCH 2/2] [mlir][linalg][Transform] Fix use-after-free in
 `SplitOp::apply`

---
 .../TransformOps/LinalgTransformOps.cpp       | 22 ++++++++++---------
 1 file changed, 12 insertions(+), 10 deletions(-)

diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index 37467db568c27d..4eb334f8bbbfaf 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -2314,7 +2314,7 @@ SplitOp::apply(transform::TransformRewriter &rewriter,
     }
   } else {
     chunkSizes.resize(payload.size(),
-                       rewriter.getIndexAttr(getStaticChunkSizes()));
+                      rewriter.getIndexAttr(getStaticChunkSizes()));
   }
 
   auto checkStructuredOpAndDimensions =
@@ -2327,7 +2327,7 @@ SplitOp::apply(transform::TransformRewriter &rewriter,
 
     if (getDimension() >= linalgOp.getNumLoops()) {
       auto diag = emitSilenceableError() << "dimension " << getDimension()
-                                          << " does not exist in target op";
+                                         << " does not exist in target op";
       diag.attachNote(loc) << "target op";
       return diag;
     }
@@ -2368,6 +2368,7 @@ SplitOp::apply(transform::TransformRewriter &rewriter,
         break;
 
       linalgOp = cast<LinalgOp>(target);
+      Location loc = target->getLoc();
 
       rewriter.setInsertionPoint(linalgOp);
       std::tie(head, tail) = linalg::splitOp(
@@ -2376,7 +2377,7 @@ SplitOp::apply(transform::TransformRewriter &rewriter,
 
       // Propagate errors.
       DiagnosedSilenceableFailure diag =
-          checkFailureInSplitting(!head && !tail, target->getLoc());
+          checkFailureInSplitting(!head && !tail, loc);
       if (diag.isDefiniteFailure())
         return diag;
 
@@ -2395,6 +2396,7 @@ SplitOp::apply(transform::TransformRewriter &rewriter,
     Operation *noSecondPart = nullptr;
     for (const auto &pair : llvm::zip(payload, chunkSizes)) {
       Operation *target = std::get<0>(pair);
+      Location loc = target->getLoc();
       LinalgOp linalgOp = dyn_cast<LinalgOp>(target);
       DiagnosedSilenceableFailure diag =
           checkStructuredOpAndDimensions(linalgOp, target->getLoc());
@@ -2408,8 +2410,8 @@ SplitOp::apply(transform::TransformRewriter &rewriter,
           getDimension(), std::get<1>(pair));
 
       // Propagate errors.
-      DiagnosedSilenceableFailure diagSplit = checkFailureInSplitting(
-          !first.back() && !second.back(), target->getLoc());
+      DiagnosedSilenceableFailure diagSplit =
+          checkFailureInSplitting(!first.back() && !second.back(), loc);
       if (diagSplit.isDefiniteFailure())
         return diag;
 
@@ -2718,8 +2720,8 @@ transform::ContinuousTileSizesOp::apply(transform::TransformRewriter &rewriter,
 
     auto getI64AttrsFromI64 = [&](ArrayRef<int64_t> values) {
       return llvm::map_to_vector(values, [&](int64_t value) -> Attribute {
-            return builder.getI64IntegerAttr(value);
-          });
+        return builder.getI64IntegerAttr(value);
+      });
     };
     transformResults.setParams(cast<OpResult>(getTileSizes()),
                                getI64AttrsFromI64(spec->tileSizes));
@@ -2756,9 +2758,9 @@ transform::ContinuousTileSizesOp::apply(transform::TransformRewriter &rewriter,
   }
 
   auto getDefiningOps = [&](ArrayRef<Value> values) {
-        return llvm::map_to_vector(values, [&](Value value) -> Operation * {
-          return value.getDefiningOp();
-        });
+    return llvm::map_to_vector(values, [&](Value value) -> Operation * {
+      return value.getDefiningOp();
+    });
   };
 
   transformResults.set(cast<OpResult>(getTileSizes()),



More information about the Mlir-commits mailing list