[Mlir-commits] [mlir] f6c79c6 - [mlir][Vector] Fix bug where vector::WarpExecuteOnLane0Op is created with 2 blocks in the region

Nicolas Vasilache llvmlistbot at llvm.org
Fri Jun 24 07:34:04 PDT 2022


Author: Nicolas Vasilache
Date: 2022-06-24T07:33:58-07:00
New Revision: f6c79c6ae49f3a642bebe32a2346186c38bb83d7

URL: https://github.com/llvm/llvm-project/commit/f6c79c6ae49f3a642bebe32a2346186c38bb83d7
DIFF: https://github.com/llvm/llvm-project/commit/f6c79c6ae49f3a642bebe32a2346186c38bb83d7.diff

LOG: [mlir][Vector] Fix bug where vector::WarpExecuteOnLane0Op is created with 2 blocks in the region

Differential Revision: https://reviews.llvm.org/D128534

Added: 
    

Modified: 
    mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
index 9f308d46900d5..08ea44225d4f0 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
@@ -46,8 +46,8 @@ rewriteWarpOpToScfFor(RewriterBase &rewriter, WarpExecuteOnLane0Op warpOp,
     Value bbArg = warpOpBody->getArgument(it.index());
 
     rewriter.setInsertionPoint(ifOp);
-    Value buffer = options.warpAllocationFn(warpOp->getLoc(), rewriter, warpOp,
-                                            bbArg.getType());
+    Value buffer =
+        options.warpAllocationFn(loc, rewriter, warpOp, bbArg.getType());
 
     // Store arg vector into buffer.
     rewriter.setInsertionPoint(ifOp);
@@ -68,7 +68,7 @@ rewriteWarpOpToScfFor(RewriterBase &rewriter, WarpExecuteOnLane0Op warpOp,
   // Insert sync after all the stores and before all the loads.
   if (!warpOp.getArgs().empty()) {
     rewriter.setInsertionPoint(ifOp);
-    options.warpSyncronizationFn(warpOp->getLoc(), rewriter, warpOp);
+    options.warpSyncronizationFn(loc, rewriter, warpOp);
   }
 
   // Move body of warpOp to ifOp.
@@ -82,8 +82,8 @@ rewriteWarpOpToScfFor(RewriterBase &rewriter, WarpExecuteOnLane0Op warpOp,
     Value val = it.value();
     Type resultType = warpOp->getResultTypes()[it.index()];
     rewriter.setInsertionPoint(ifOp);
-    Value buffer = options.warpAllocationFn(warpOp->getLoc(), rewriter, warpOp,
-                                            val.getType());
+    Value buffer =
+        options.warpAllocationFn(loc, rewriter, warpOp, val.getType());
 
     // Store yielded value into buffer.
     rewriter.setInsertionPoint(yieldOp);
@@ -121,7 +121,7 @@ rewriteWarpOpToScfFor(RewriterBase &rewriter, WarpExecuteOnLane0Op warpOp,
   // Insert sync after all the stores and before all the loads.
   if (!yieldOp.operands().empty()) {
     rewriter.setInsertionPointAfter(ifOp);
-    options.warpSyncronizationFn(warpOp->getLoc(), rewriter, warpOp);
+    options.warpSyncronizationFn(loc, rewriter, warpOp);
   }
 
   // Delete terminator and add empty scf.yield.
@@ -148,7 +148,12 @@ static WarpExecuteOnLane0Op moveRegionToNewWarpOpAndReplaceReturns(
 
   Region &opBody = warpOp.getBodyRegion();
   Region &newOpBody = newWarpOp.getBodyRegion();
+  Block &newOpFirstBlock = newOpBody.front();
   rewriter.inlineRegionBefore(opBody, newOpBody, newOpBody.begin());
+  rewriter.eraseBlock(&newOpFirstBlock);
+  assert(newWarpOp.getWarpRegion().hasOneBlock() &&
+         "expected WarpOp with single block");
+
   auto yield =
       cast<vector::YieldOp>(newOpBody.getBlocks().begin()->getTerminator());
 


        


More information about the Mlir-commits mailing list