[Mlir-commits] [mlir] 6a57d8f - [mlir][vector] Untangle TransferWriteDistribution and avoid crashing in the 0-D case.
Nicolas Vasilache
llvmlistbot at llvm.org
Fri Jul 1 00:15:41 PDT 2022
Author: Nicolas Vasilache
Date: 2022-07-01T00:15:34-07:00
New Revision: 6a57d8fba5b3c319d81eb91afbd02d22ab1c8662
URL: https://github.com/llvm/llvm-project/commit/6a57d8fba5b3c319d81eb91afbd02d22ab1c8662
DIFF: https://github.com/llvm/llvm-project/commit/6a57d8fba5b3c319d81eb91afbd02d22ab1c8662.diff
LOG: [mlir][vector] Untangle TransferWriteDistribution and avoid crashing in the 0-D case.
This revision avoids a crash in the 0-D case of distributing vector.transfer ops out of
vector.warp_execute_on_lane_0.
Due to the code's complexity and lack of documentation, the implementation had to be untangled
before it became clear that the simple fix was to fail in the 0-D case.
The rewrite is still very useful for understanding this code better.
Differential Revision: https://reviews.llvm.org/D128793
Added:
Modified:
mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
mlir/test/Dialect/Vector/vector-warp-distribute.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
index 08eced2bd935e..bf6e2225e6f20 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
@@ -262,6 +262,28 @@ struct WarpOpToScfForPattern : public OpRewritePattern<WarpExecuteOnLane0Op> {
const WarpExecuteOnLane0LoweringOptions &options;
};
+/// Clone `writeOp` assumed to be nested under `warpOp` into a new warp execute
+/// op with the proper return type.
+/// The new write op is updated to write the result of the new warp execute op.
+/// The old `writeOp` is deleted.
+static vector::TransferWriteOp cloneWriteOp(RewriterBase &rewriter,
+ WarpExecuteOnLane0Op warpOp,
+ vector::TransferWriteOp writeOp,
+ VectorType targetType) {
+ assert(writeOp->getParentOp() == warpOp &&
+ "write must be nested immediately under warp");
+ OpBuilder::InsertionGuard g(rewriter);
+ WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
+ rewriter, warpOp, ValueRange{{writeOp.getVector()}},
+ TypeRange{targetType});
+ rewriter.setInsertionPointAfter(newWarpOp);
+ auto newWriteOp =
+ cast<vector::TransferWriteOp>(rewriter.clone(*writeOp.getOperation()));
+ rewriter.eraseOp(writeOp);
+ newWriteOp.getVectorMutable().assign(newWarpOp.getResults().back());
+ return newWriteOp;
+}
+
/// Distribute transfer_write ops based on the affine map returned by
/// `distributionMapFn`.
/// Example:
@@ -290,11 +312,21 @@ struct WarpOpTransferWrite : public OpRewritePattern<vector::TransferWriteOp> {
LogicalResult tryDistributeOp(RewriterBase &rewriter,
vector::TransferWriteOp writeOp,
WarpExecuteOnLane0Op warpOp) const {
+ VectorType writtenVectorType = writeOp.getVectorType();
+
+ // 1. If the write is 0-D, we just clone it into a new WarpExecuteOnLane0Op
+ // to separate it from the rest.
+ if (writtenVectorType.getRank() == 0)
+ return failure();
+
+ // 2. Compute the distribution map.
AffineMap map = distributionMapFn(writeOp);
- SmallVector<int64_t> targetShape(writeOp.getVectorType().getShape().begin(),
- writeOp.getVectorType().getShape().end());
- assert(map.getNumResults() == 1 &&
- "multi-dim distribution not implemented yet");
+ if (map.getNumResults() != 1)
+ return writeOp->emitError("multi-dim distribution not implemented yet");
+
+ // 3. Compute the targetType using the distribution map.
+ SmallVector<int64_t> targetShape(writtenVectorType.getShape().begin(),
+ writtenVectorType.getShape().end());
for (unsigned i = 0, e = map.getNumResults(); i < e; i++) {
unsigned position = map.getDimPosition(i);
if (targetShape[position] % warpOp.getWarpSize() != 0)
@@ -302,20 +334,16 @@ struct WarpOpTransferWrite : public OpRewritePattern<vector::TransferWriteOp> {
targetShape[position] = targetShape[position] / warpOp.getWarpSize();
}
VectorType targetType =
- VectorType::get(targetShape, writeOp.getVectorType().getElementType());
-
- SmallVector<Value> yieldValues = {writeOp.getVector()};
- SmallVector<Type> retTypes = {targetType};
- WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
- rewriter, warpOp, yieldValues, retTypes);
- rewriter.setInsertionPointAfter(newWarpOp);
+ VectorType::get(targetShape, writtenVectorType.getElementType());
- // Move op outside of region: Insert clone at the insertion point and delete
- // the old op.
- auto newWriteOp =
- cast<vector::TransferWriteOp>(rewriter.clone(*writeOp.getOperation()));
- rewriter.eraseOp(writeOp);
+ // 4. clone the write into a new WarpExecuteOnLane0Op to separate it from
+ // the rest.
+ vector::TransferWriteOp newWriteOp =
+ cloneWriteOp(rewriter, warpOp, writeOp, targetType);
+ // 5. Reindex the write using the distribution map.
+ auto newWarpOp =
+ newWriteOp.getVector().getDefiningOp<WarpExecuteOnLane0Op>();
rewriter.setInsertionPoint(newWriteOp);
AffineMap indexMap = map.compose(newWriteOp.getPermutationMap());
Location loc = newWriteOp.getLoc();
@@ -329,13 +357,11 @@ struct WarpOpTransferWrite : public OpRewritePattern<vector::TransferWriteOp> {
continue;
unsigned indexPos = indexExpr.getPosition();
unsigned vectorPos = std::get<1>(it).cast<AffineDimExpr>().getPosition();
- auto scale =
- getAffineConstantExpr(targetShape[vectorPos], newWarpOp.getContext());
+ auto scale = rewriter.getAffineConstantExpr(targetShape[vectorPos]);
indices[indexPos] =
makeComposedAffineApply(rewriter, loc, d0 + scale * d1,
{indices[indexPos], newWarpOp.getLaneid()});
}
- newWriteOp.getVectorMutable().assign(newWarpOp.getResults().back());
newWriteOp.getIndicesMutable().assign(indices);
return success();
@@ -634,7 +660,6 @@ struct WarpOpBroadcast : public OpRewritePattern<WarpExecuteOnLane0Op> {
Value broadcasted = rewriter.create<vector::BroadcastOp>(
loc, destVecType, newWarpOp->getResults().back());
newWarpOp->getResult(operandNumber).replaceAllUsesWith(broadcasted);
-
return success();
}
};
diff --git a/mlir/test/Dialect/Vector/vector-warp-distribute.mlir b/mlir/test/Dialect/Vector/vector-warp-distribute.mlir
index 084a935ce23eb..718a7bf1c6b08 100644
--- a/mlir/test/Dialect/Vector/vector-warp-distribute.mlir
+++ b/mlir/test/Dialect/Vector/vector-warp-distribute.mlir
@@ -491,3 +491,23 @@ func.func @vector_reduction(%laneid: index) -> (f32) {
}
return %r : f32
}
+
+// -----
+
+func.func @vector_reduction(%laneid: index, %m0: memref<4x2x32xf32>, %m1: memref<f32>) {
+ %c0 = arith.constant 0: index
+ %f0 = arith.constant 0.0: f32
+ // CHECK-D: %[[R:.*]] = vector.warp_execute_on_lane_0(%{{.*}})[32] -> (vector<f32>) {
+ // CHECK-D: vector.warp_execute_on_lane_0(%{{.*}})[32] {
+ // CHECK-D: vector.transfer_write %[[R]], %{{.*}}[] : vector<f32>, memref<f32>
+ vector.warp_execute_on_lane_0(%laneid)[32] {
+ %0 = vector.transfer_read %m0[%c0, %c0, %c0], %f0 {in_bounds = [true]} : memref<4x2x32xf32>, vector<32xf32>
+ %1 = vector.transfer_read %m1[], %f0 : memref<f32>, vector<f32>
+ %2 = vector.extractelement %1[] : vector<f32>
+ %3 = vector.reduction <add>, %0 : vector<32xf32> into f32
+ %4 = arith.addf %3, %2 : f32
+ %5 = vector.broadcast %4 : f32 to vector<f32>
+ vector.transfer_write %5, %m1[] : vector<f32>, memref<f32>
+ }
+ return
+}
More information about the Mlir-commits
mailing list