[PATCH] D77727: Only insert memref_cast when needed during canonicalization.

Pierre Oechsel via Phabricator via llvm-commits llvm-commits at lists.llvm.org
Wed Apr 8 07:02:33 PDT 2020


poechsel created this revision.
Herald added subscribers: llvm-commits, frgossen, grosul1, Joonsoo, liufengdb, lucyrfox, mgester, arpith-jacob, nicolasvasilache, antiagainst, shauheen, burmako, jpienaar, rriddle, mehdi_amini.
Herald added a project: LLVM.

Repository:
  rG LLVM Github Monorepo

https://reviews.llvm.org/D77727

Files:
  mlir/lib/Dialect/StandardOps/IR/Ops.cpp


Index: mlir/lib/Dialect/StandardOps/IR/Ops.cpp
===================================================================
--- mlir/lib/Dialect/StandardOps/IR/Ops.cpp
+++ mlir/lib/Dialect/StandardOps/IR/Ops.cpp
@@ -309,6 +309,33 @@
 }
 
 namespace {
+// Replace `oldValue`, the single result of its defining op, with `newValue`,
+// a memref of a more static type produced while canonicalizing an alloc, view
+// or subview op. The memref_cast back to the old type is only omitted when
+// every user is known to accept any memref shape; all other users (calls,
+// returns, branches, selects, ops from other dialects, ...) keep the cast.
+void replaceOpInsertMemrefCastWhenNeeded(PatternRewriter &rewriter,
+                                         Value oldValue, Value newValue) {
+  Operation *oldOp = oldValue.getDefiningOp();
+  assert(oldOp && oldOp->getNumResults() == 1 && "expected single-result op");
+  // Allowlist: these ops' verifiers do not tie the memref's static shape.
+  auto isShapeAgnosticUse = [&](OpOperand &use) {
+    Operation *user = use.getOwner();
+    // Only the memref operand of a store is shape-agnostic, not the value.
+    if (auto store = dyn_cast<StoreOp>(user))
+      return store.getMemRef() == oldValue;
+    return isa<LoadOp>(user) || isa<DimOp>(user) || isa<DeallocOp>(user);
+  };
+  // No in-place setType on the old result: replaceOp performs the RAUW.
+  if (llvm::all_of(oldValue.getUses(), isShapeAgnosticUse)) {
+    rewriter.replaceOp(oldOp, {newValue});
+    return;
+  }
+  auto castOp = rewriter.create<MemRefCastOp>(oldValue.getLoc(), newValue,
+                                              oldValue.getType());
+  rewriter.replaceOp(oldOp, {castOp.getResult()});
+}
+
 /// Fold constant dimensions into an alloc operation.
 struct SimplifyAllocConst : public OpRewritePattern<AllocOp> {
   using OpRewritePattern<AllocOp>::OpRewritePattern;
@@ -359,11 +386,7 @@
     // Create and insert the alloc op for the new memref.
     auto newAlloc = rewriter.create<AllocOp>(alloc.getLoc(), newMemRefType,
                                              newOperands, IntegerAttr());
-    // Insert a cast so we have the same type as the old alloc.
-    auto resultCast = rewriter.create<MemRefCastOp>(alloc.getLoc(), newAlloc,
-                                                    alloc.getType());
-
-    rewriter.replaceOp(alloc, {resultCast});
+    replaceOpInsertMemrefCastWhenNeeded(rewriter, alloc, newAlloc);
     return success();
   }
 };
@@ -2025,9 +2048,7 @@
     auto newSubViewOp = rewriter.create<SubViewOp>(
         subViewOp.getLoc(), subViewOp.source(), subViewOp.offsets(),
         ArrayRef<Value>(), subViewOp.strides(), newMemRefType);
-    // Insert a memref_cast for compatibility of the uses of the op.
-    rewriter.replaceOpWithNewOp<MemRefCastOp>(subViewOp, newSubViewOp,
-                                              subViewOp.getType());
+    replaceOpInsertMemrefCastWhenNeeded(rewriter, subViewOp, newSubViewOp);
     return success();
   }
 };
@@ -2074,9 +2095,7 @@
     auto newSubViewOp = rewriter.create<SubViewOp>(
         subViewOp.getLoc(), subViewOp.source(), subViewOp.offsets(),
         subViewOp.sizes(), ArrayRef<Value>(), newMemRefType);
-    // Insert a memref_cast for compatibility of the uses of the op.
-    rewriter.replaceOpWithNewOp<MemRefCastOp>(subViewOp, newSubViewOp,
-                                              subViewOp.getType());
+    replaceOpInsertMemrefCastWhenNeeded(rewriter, subViewOp, newSubViewOp);
     return success();
   }
 };
@@ -2125,9 +2144,7 @@
     auto newSubViewOp = rewriter.create<SubViewOp>(
         subViewOp.getLoc(), subViewOp.source(), ArrayRef<Value>(),
         subViewOp.sizes(), subViewOp.strides(), newMemRefType);
-    // Insert a memref_cast for compatibility of the uses of the op.
-    rewriter.replaceOpWithNewOp<MemRefCastOp>(subViewOp, newSubViewOp,
-                                              subViewOp.getType());
+    replaceOpInsertMemrefCastWhenNeeded(rewriter, subViewOp, newSubViewOp);
     return success();
   }
 };
@@ -2442,8 +2459,7 @@
     auto newViewOp = rewriter.create<ViewOp>(viewOp.getLoc(), newMemRefType,
                                              viewOp.getOperand(0), newOperands);
     // Insert a cast so we have the same type as the old memref type.
-    rewriter.replaceOpWithNewOp<MemRefCastOp>(viewOp, newViewOp,
-                                              viewOp.getType());
+    replaceOpInsertMemrefCastWhenNeeded(rewriter, viewOp, newViewOp);
     return success();
   }
 };


-------------- next part --------------
A non-text attachment was scrubbed...
Name: D77727.256008.patch
Type: text/x-patch
Size: 4343 bytes
Desc: not available
URL: <http://lists.llvm.org/pipermail/llvm-commits/attachments/20200408/1f02d9e6/attachment.bin>


More information about the llvm-commits mailing list