[Mlir-commits] [mlir] 0886ea9 - [mlir][Arith] Fix a use-after-free after rewriting ops to unsigned

Benjamin Kramer llvmlistbot at llvm.org
Wed Jun 15 01:29:12 PDT 2022


Author: Benjamin Kramer
Date: 2022-06-15T10:28:43+02:00
New Revision: 0886ea902b1417acc8fa31b7c9fbaa6a1ab40e8f

URL: https://github.com/llvm/llvm-project/commit/0886ea902b1417acc8fa31b7c9fbaa6a1ab40e8f
DIFF: https://github.com/llvm/llvm-project/commit/0886ea902b1417acc8fa31b7c9fbaa6a1ab40e8f.diff

LOG: [mlir][Arith] Fix a use-after-free after rewriting ops to unsigned

Just short-circuit as soon as a rewrite succeeds: the rewritten op has been
erased, so it is invalid for any remaining rewriters to inspect. Found by ASan.

This pass looks like it could use rewrite patterns instead, which don't
have this issue, but let's fix the ASan build first.

Added: 
    

Modified: 
    mlir/lib/Dialect/Arithmetic/Transforms/UnsignedWhenEquivalent.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/Arithmetic/Transforms/UnsignedWhenEquivalent.cpp b/mlir/lib/Dialect/Arithmetic/Transforms/UnsignedWhenEquivalent.cpp
index 30fb51725dcb0..5cecc69285bea 100644
--- a/mlir/lib/Dialect/Arithmetic/Transforms/UnsignedWhenEquivalent.cpp
+++ b/mlir/lib/Dialect/Arithmetic/Transforms/UnsignedWhenEquivalent.cpp
@@ -90,7 +90,7 @@ static OpList getMatching(Operation *root, IntRangeAnalysis &analysis) {
 }
 
 template <typename T, typename U>
-static void rewriteOp(Operation *op, OpBuilder &b) {
+static bool rewriteOp(Operation *op, OpBuilder &b) {
   if (isa<T>(op)) {
     OpBuilder::InsertionGuard guard(b);
     b.setInsertionPoint(op);
@@ -98,28 +98,31 @@ static void rewriteOp(Operation *op, OpBuilder &b) {
                                    op->getOperands(), op->getAttrs());
     op->replaceAllUsesWith(newOp->getResults());
     op->erase();
+    return true;
   }
+  return false;
 }
 
-static void rewriteCmpI(Operation *op, OpBuilder &b) {
+static bool rewriteCmpI(Operation *op, OpBuilder &b) {
   if (auto cmpOp = dyn_cast<CmpIOp>(op)) {
     cmpOp.setPredicateAttr(CmpIPredicateAttr::get(
         b.getContext(), toUnsignedPred(cmpOp.getPredicate())));
+    return true;
   }
+  return false;
 }
 
 static void rewrite(Operation *root, const OpList &toReplace) {
   OpBuilder b(root->getContext());
   b.setInsertionPoint(root);
   for (Operation *op : toReplace) {
-    rewriteOp<DivSIOp, DivUIOp>(op, b);
-    rewriteOp<CeilDivSIOp, CeilDivUIOp>(op, b);
-    rewriteOp<FloorDivSIOp, DivUIOp>(op, b);
-    rewriteOp<RemSIOp, RemUIOp>(op, b);
-    rewriteOp<MinSIOp, MinUIOp>(op, b);
-    rewriteOp<MaxSIOp, MaxUIOp>(op, b);
-    rewriteOp<ExtSIOp, ExtUIOp>(op, b);
-    rewriteCmpI(op, b);
+    rewriteOp<DivSIOp, DivUIOp>(op, b) ||
+        rewriteOp<CeilDivSIOp, CeilDivUIOp>(op, b) ||
+        rewriteOp<FloorDivSIOp, DivUIOp>(op, b) ||
+        rewriteOp<RemSIOp, RemUIOp>(op, b) ||
+        rewriteOp<MinSIOp, MinUIOp>(op, b) ||
+        rewriteOp<MaxSIOp, MaxUIOp>(op, b) ||
+        rewriteOp<ExtSIOp, ExtUIOp>(op, b) || rewriteCmpI(op, b);
   }
 }
 


        


More information about the Mlir-commits mailing list