[Mlir-commits] [mlir] 6a57d8f - [mlir][vector] Untangle TransferWriteDistribution and avoid crashing in the 0-D case.

Nicolas Vasilache llvmlistbot at llvm.org
Fri Jul 1 00:15:41 PDT 2022


Author: Nicolas Vasilache
Date: 2022-07-01T00:15:34-07:00
New Revision: 6a57d8fba5b3c319d81eb91afbd02d22ab1c8662

URL: https://github.com/llvm/llvm-project/commit/6a57d8fba5b3c319d81eb91afbd02d22ab1c8662
DIFF: https://github.com/llvm/llvm-project/commit/6a57d8fba5b3c319d81eb91afbd02d22ab1c8662.diff

LOG: [mlir][vector] Untangle TransferWriteDistribution and avoid crashing in the 0-D case.

This revision avoids a crash in the 0-D case of distributing vector.transfer ops out of
vector.warp_execute_on_lane_0.
Due to the complexity and lack of documentation of this code, it took untangling the
implementation to realize that the simple fix is to fail in the 0-D case.
The rewrite is still useful for understanding this code better.
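
For reference, the problematic input is a 0-D transfer_write nested under
warp_execute_on_lane_0, as in this minimal sketch (distilled from the test added
below; %laneid, %f, and %m are illustrative values of type index, f32, and
memref<f32>):

  vector.warp_execute_on_lane_0(%laneid)[32] {
    %v = vector.broadcast %f : f32 to vector<f32>
    vector.transfer_write %v, %m[] : vector<f32>, memref<f32>
  }

With this change, the distribution pattern returns failure() on such 0-D writes
instead of attempting to distribute them.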

Differential Revision: https://reviews.llvm.org/D128793

Added: 
    

Modified: 
    mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
    mlir/test/Dialect/Vector/vector-warp-distribute.mlir

Removed: 
    


################################################################################
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
index 08eced2bd935e..bf6e2225e6f20 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
@@ -262,6 +262,28 @@ struct WarpOpToScfForPattern : public OpRewritePattern<WarpExecuteOnLane0Op> {
   const WarpExecuteOnLane0LoweringOptions &options;
 };
 
+/// Clone `writeOp`, assumed to be nested under `warpOp`, into a new warp
+/// execute op with the proper return type.
+/// The new write op is updated to write the result of the new warp execute op.
+/// The old `writeOp` is deleted.
+static vector::TransferWriteOp cloneWriteOp(RewriterBase &rewriter,
+                                            WarpExecuteOnLane0Op warpOp,
+                                            vector::TransferWriteOp writeOp,
+                                            VectorType targetType) {
+  assert(writeOp->getParentOp() == warpOp &&
+         "write must be nested immediately under warp");
+  OpBuilder::InsertionGuard g(rewriter);
+  WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
+      rewriter, warpOp, ValueRange{{writeOp.getVector()}},
+      TypeRange{targetType});
+  rewriter.setInsertionPointAfter(newWarpOp);
+  auto newWriteOp =
+      cast<vector::TransferWriteOp>(rewriter.clone(*writeOp.getOperation()));
+  rewriter.eraseOp(writeOp);
+  newWriteOp.getVectorMutable().assign(newWarpOp.getResults().back());
+  return newWriteOp;
+}
+
 /// Distribute transfer_write ops based on the affine map returned by
 /// `distributionMapFn`.
 /// Example:
@@ -290,11 +312,21 @@ struct WarpOpTransferWrite : public OpRewritePattern<vector::TransferWriteOp> {
   LogicalResult tryDistributeOp(RewriterBase &rewriter,
                                 vector::TransferWriteOp writeOp,
                                 WarpExecuteOnLane0Op warpOp) const {
+    VectorType writtenVectorType = writeOp.getVectorType();
+
+    // 1. If the write is 0-D, bail out: the 0-D case is not distributed here;
+    // the write is instead left in a WarpExecuteOnLane0Op of its own.
+    if (writtenVectorType.getRank() == 0)
+      return failure();
+
+    // 2. Compute the distribution map.
     AffineMap map = distributionMapFn(writeOp);
-    SmallVector<int64_t> targetShape(writeOp.getVectorType().getShape().begin(),
-                                     writeOp.getVectorType().getShape().end());
-    assert(map.getNumResults() == 1 &&
-           "multi-dim distribution not implemented yet");
+    if (map.getNumResults() != 1)
+      return writeOp->emitError("multi-dim distribution not implemented yet");
+
+    // 3. Compute the targetType using the distribution map.
+    SmallVector<int64_t> targetShape(writtenVectorType.getShape().begin(),
+                                     writtenVectorType.getShape().end());
     for (unsigned i = 0, e = map.getNumResults(); i < e; i++) {
       unsigned position = map.getDimPosition(i);
       if (targetShape[position] % warpOp.getWarpSize() != 0)
@@ -302,20 +334,16 @@ struct WarpOpTransferWrite : public OpRewritePattern<vector::TransferWriteOp> {
       targetShape[position] = targetShape[position] / warpOp.getWarpSize();
     }
     VectorType targetType =
-        VectorType::get(targetShape, writeOp.getVectorType().getElementType());
-
-    SmallVector<Value> yieldValues = {writeOp.getVector()};
-    SmallVector<Type> retTypes = {targetType};
-    WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
-        rewriter, warpOp, yieldValues, retTypes);
-    rewriter.setInsertionPointAfter(newWarpOp);
+        VectorType::get(targetShape, writtenVectorType.getElementType());
 
-    // Move op outside of region: Insert clone at the insertion point and delete
-    // the old op.
-    auto newWriteOp =
-        cast<vector::TransferWriteOp>(rewriter.clone(*writeOp.getOperation()));
-    rewriter.eraseOp(writeOp);
+    // 4. Clone the write into a new WarpExecuteOnLane0Op to separate it from
+    // the rest.
+    vector::TransferWriteOp newWriteOp =
+        cloneWriteOp(rewriter, warpOp, writeOp, targetType);
 
+    // 5. Reindex the write using the distribution map.
+    auto newWarpOp =
+        newWriteOp.getVector().getDefiningOp<WarpExecuteOnLane0Op>();
     rewriter.setInsertionPoint(newWriteOp);
     AffineMap indexMap = map.compose(newWriteOp.getPermutationMap());
     Location loc = newWriteOp.getLoc();
@@ -329,13 +357,11 @@ struct WarpOpTransferWrite : public OpRewritePattern<vector::TransferWriteOp> {
         continue;
       unsigned indexPos = indexExpr.getPosition();
       unsigned vectorPos = std::get<1>(it).cast<AffineDimExpr>().getPosition();
-      auto scale =
-          getAffineConstantExpr(targetShape[vectorPos], newWarpOp.getContext());
+      auto scale = rewriter.getAffineConstantExpr(targetShape[vectorPos]);
       indices[indexPos] =
           makeComposedAffineApply(rewriter, loc, d0 + scale * d1,
                                   {indices[indexPos], newWarpOp.getLaneid()});
     }
-    newWriteOp.getVectorMutable().assign(newWarpOp.getResults().back());
     newWriteOp.getIndicesMutable().assign(indices);
 
     return success();
@@ -634,7 +660,6 @@ struct WarpOpBroadcast : public OpRewritePattern<WarpExecuteOnLane0Op> {
     Value broadcasted = rewriter.create<vector::BroadcastOp>(
         loc, destVecType, newWarpOp->getResults().back());
     newWarpOp->getResult(operandNumber).replaceAllUsesWith(broadcasted);
-
     return success();
   }
 };
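
For illustration, here is a rough before/after sketch of what steps 2-5 above do
for a 1-D write distributed over a warp of 32, assuming the distribution map sends
the single vector dimension to the lane id (%dest, %v, %c0, and "some_def" are
illustrative placeholders, not taken from this commit):

  // Before:
  vector.warp_execute_on_lane_0(%laneid)[32] {
    %v = "some_def"() : () -> vector<64xf32>
    vector.transfer_write %v, %dest[%c0] : vector<64xf32>, memref<128xf32>
  }

  // After: the warp op yields the full vector but returns the per-lane type,
  // and the write is re-indexed by targetShape * laneid (here 64 / 32 = 2).
  %r = vector.warp_execute_on_lane_0(%laneid)[32] -> (vector<2xf32>) {
    %v = "some_def"() : () -> vector<64xf32>
    vector.yield %v : vector<64xf32>
  }
  %idx = affine.apply affine_map<(d0, d1) -> (d0 + d1 * 2)>(%c0, %laneid)
  vector.transfer_write %r, %dest[%idx] : vector<2xf32>, memref<128xf32>

In the actual transform the affine.apply is built with makeComposedAffineApply, so
the printed map may be composed and folded differently.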

diff --git a/mlir/test/Dialect/Vector/vector-warp-distribute.mlir b/mlir/test/Dialect/Vector/vector-warp-distribute.mlir
index 084a935ce23eb..718a7bf1c6b08 100644
--- a/mlir/test/Dialect/Vector/vector-warp-distribute.mlir
+++ b/mlir/test/Dialect/Vector/vector-warp-distribute.mlir
@@ -491,3 +491,23 @@ func.func @vector_reduction(%laneid: index) -> (f32) {
   }
   return %r : f32
 }
+
+// -----
+
+func.func @vector_reduction(%laneid: index, %m0: memref<4x2x32xf32>, %m1: memref<f32>) {
+  %c0 = arith.constant 0: index
+  %f0 = arith.constant 0.0: f32
+  //     CHECK-D: %[[R:.*]] = vector.warp_execute_on_lane_0(%{{.*}})[32] -> (vector<f32>) {
+  //     CHECK-D: vector.warp_execute_on_lane_0(%{{.*}})[32] {
+  //     CHECK-D:   vector.transfer_write %[[R]], %{{.*}}[] : vector<f32>, memref<f32>
+  vector.warp_execute_on_lane_0(%laneid)[32] {
+    %0 = vector.transfer_read %m0[%c0, %c0, %c0], %f0 {in_bounds = [true]} : memref<4x2x32xf32>, vector<32xf32>
+    %1 = vector.transfer_read %m1[], %f0 : memref<f32>, vector<f32>
+    %2 = vector.extractelement %1[] : vector<f32>
+    %3 = vector.reduction <add>, %0 : vector<32xf32> into f32
+    %4 = arith.addf %3, %2 : f32
+    %5 = vector.broadcast %4 : f32 to vector<f32>
+    vector.transfer_write %5, %m1[] : vector<f32>, memref<f32>
+  }
+  return 
+}

More information about the Mlir-commits mailing list