[Mlir-commits] [mlir] [mlir][vector] Notify the rewriter when sinking out of warp ops (PR #71964)
Quinn Dawkins
llvmlistbot at llvm.org
Fri Nov 10 10:58:04 PST 2023
https://github.com/qedawkins updated https://github.com/llvm/llvm-project/pull/71964
>From 48a0036b5fbb38549e582df67152af8692b63308 Mon Sep 17 00:00:00 2001
From: Quinn Dawkins <quinn at nod-labs.com>
Date: Fri, 10 Nov 2023 12:31:23 -0500
Subject: [PATCH 1/3] [mlir][vector] NFC: Notify the rewriter when sinking out
of warp ops
---
.../Vector/Transforms/VectorDistribute.cpp | 30 +++++++++++++++++++
1 file changed, 30 insertions(+)
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
index 645caa9c1378821..ffe2a3e44fdec03 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
@@ -645,6 +645,10 @@ struct WarpOpElementwise : public OpRewritePattern<WarpExecuteOnLane0Op> {
});
if (!yieldOperand)
return failure();
+
+ // Notify the rewriter that the warp op is changing.
+ rewriter.startRootUpdate(warpOp);
+
Operation *elementWise = yieldOperand->get().getDefiningOp();
unsigned operandIndex = yieldOperand->getOperandNumber();
Value distributedVal = warpOp.getResult(operandIndex);
@@ -683,6 +687,7 @@ struct WarpOpElementwise : public OpRewritePattern<WarpExecuteOnLane0Op> {
{newWarpOp.getResult(operandIndex).getType()});
rewriter.replaceAllUsesWith(newWarpOp.getResult(operandIndex),
newOp->getResult(0));
+ rewriter.finalizeRootUpdate(warpOp);
return success();
}
};
@@ -713,6 +718,8 @@ struct WarpOpConstant : public OpRewritePattern<WarpExecuteOnLane0Op> {
auto dense = dyn_cast<SplatElementsAttr>(constantOp.getValue());
if (!dense)
return failure();
+ // Notify the rewriter that the warp op is changing.
+ rewriter.startRootUpdate(warpOp);
unsigned operandIndex = yieldOperand->getOperandNumber();
Attribute scalarAttr = dense.getSplatValue<Attribute>();
auto newAttr = DenseElementsAttr::get(
@@ -721,6 +728,7 @@ struct WarpOpConstant : public OpRewritePattern<WarpExecuteOnLane0Op> {
rewriter.setInsertionPointAfter(warpOp);
Value distConstant = rewriter.create<arith::ConstantOp>(loc, newAttr);
rewriter.replaceAllUsesWith(warpOp.getResult(operandIndex), distConstant);
+ rewriter.finalizeRootUpdate(warpOp);
return success();
}
};
@@ -823,7 +831,9 @@ struct WarpOpTransferRead : public OpRewritePattern<WarpExecuteOnLane0Op> {
OpBuilder::InsertionGuard g(rewriter);
WarpExecuteOnLane0Op newWarpOp = warpOp;
Value newMask = read.getMask();
+ bool hasMask = false;
if (read.getMask()) {
+ hasMask = true;
// TODO: Distribution of masked reads with non-trivial permutation maps
// requires the distribution of the mask to elementwise match the
// distribution of the permuted written vector. Currently the details
@@ -894,6 +904,11 @@ struct WarpOpTransferRead : public OpRewritePattern<WarpExecuteOnLane0Op> {
return failure();
rewriter.replaceAllUsesWith(distributedVal, newRead);
+ if (hasMask) {
+ // Notify the rewriter that the warp op is changing.
+ rewriter.startRootUpdate(warpOp);
+ rewriter.finalizeRootUpdate(warpOp);
+ }
return success();
}
};
@@ -996,7 +1011,10 @@ struct WarpOpForwardOperand : public OpRewritePattern<WarpExecuteOnLane0Op> {
}
if (!valForwarded)
return failure();
+ // Notify the rewriter that the warp op is changing.
+ rewriter.startRootUpdate(warpOp);
rewriter.replaceAllUsesWith(warpOp.getResult(resultIndex), valForwarded);
+ rewriter.finalizeRootUpdate(warpOp);
return success();
}
};
@@ -1024,6 +1042,8 @@ struct WarpOpBroadcast : public OpRewritePattern<WarpExecuteOnLane0Op> {
if (vector::isBroadcastableTo(broadcastSrcType, destVecType) !=
vector::BroadcastableToResult::Success)
return failure();
+ // Notify the rewriter that the warp op is changing.
+ rewriter.startRootUpdate(warpOp);
SmallVector<size_t> newRetIndices;
WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
rewriter, warpOp, {broadcastSrc}, {broadcastSrcType}, newRetIndices);
@@ -1032,6 +1052,7 @@ struct WarpOpBroadcast : public OpRewritePattern<WarpExecuteOnLane0Op> {
loc, destVecType, newWarpOp->getResult(newRetIndices[0]));
rewriter.replaceAllUsesWith(newWarpOp->getResult(operandNumber),
broadcasted);
+ rewriter.finalizeRootUpdate(warpOp);
return success();
}
};
@@ -1046,6 +1067,10 @@ struct WarpOpShapeCast : public OpRewritePattern<WarpExecuteOnLane0Op> {
warpOp, [](Operation *op) { return isa<vector::ShapeCastOp>(op); });
if (!operand)
return failure();
+
+ // Notify the rewriter that the warp op is changing.
+ rewriter.startRootUpdate(warpOp);
+
auto oldCastOp = operand->get().getDefiningOp<vector::ShapeCastOp>();
unsigned int operandNumber = operand->getOperandNumber();
@@ -1074,6 +1099,7 @@ struct WarpOpShapeCast : public OpRewritePattern<WarpExecuteOnLane0Op> {
oldCastOp.getLoc(), castResultType,
newWarpOp->getResult(newRetIndices[0]));
rewriter.replaceAllUsesWith(newWarpOp->getResult(operandNumber), newCast);
+ rewriter.finalizeRootUpdate(warpOp);
return success();
}
};
@@ -1133,6 +1159,9 @@ struct WarpOpCreateMask : public OpRewritePattern<WarpExecuteOnLane0Op> {
mask, "cannot delinearize lane ID for distribution");
assert(!delinearizedIds.empty());
+ // Notify the rewriter that the warp op is changing.
+ rewriter.startRootUpdate(warpOp);
+
AffineExpr s0, s1;
bindSymbols(rewriter.getContext(), s0, s1);
SmallVector<Value> newOperands;
@@ -1151,6 +1180,7 @@ struct WarpOpCreateMask : public OpRewritePattern<WarpExecuteOnLane0Op> {
auto newMask =
rewriter.create<vector::CreateMaskOp>(loc, distType, newOperands);
rewriter.replaceAllUsesWith(warpOp.getResult(operandIndex), newMask);
+ rewriter.finalizeRootUpdate(warpOp);
return success();
}
};
>From 3604a33a7121841c05fda2ba7c54f7fa0ce0b48f Mon Sep 17 00:00:00 2001
From: Quinn Dawkins <quinn at nod-labs.com>
Date: Fri, 10 Nov 2023 13:05:15 -0500
Subject: [PATCH 2/3] Address comments and fix shape cast
---
.../Vector/Transforms/VectorDistribute.cpp | 22 ++++++++++---------
.../Vector/vector-warp-distribute.mlir | 20 +++++++++++++++++
2 files changed, 32 insertions(+), 10 deletions(-)
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
index ffe2a3e44fdec03..482b24069868fe1 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
@@ -850,6 +850,9 @@ struct WarpOpTransferRead : public OpRewritePattern<WarpExecuteOnLane0Op> {
newRetIndices);
newMask = newWarpOp.getResult(newRetIndices[0]);
distributedVal = newWarpOp.getResult(operandIndex);
+ } else {
+ // Notify the rewriter that the warp op is changing.
+ rewriter.startRootUpdate(warpOp);
}
rewriter.setInsertionPointAfter(newWarpOp);
@@ -859,9 +862,12 @@ struct WarpOpTransferRead : public OpRewritePattern<WarpExecuteOnLane0Op> {
SmallVector<Value> delinearizedIds;
if (!delinearizeLaneId(rewriter, read.getLoc(), sequentialType.getShape(),
distributedType.getShape(), newWarpOp.getWarpSize(),
- newWarpOp.getLaneid(), delinearizedIds))
+ newWarpOp.getLaneid(), delinearizedIds)) {
+ if (!hasMask)
+ rewriter.cancelRootUpdate(warpOp);
return rewriter.notifyMatchFailure(
read, "cannot delinearize lane ID for distribution");
+ }
assert(!delinearizedIds.empty() || map.getNumResults() == 0);
for (auto it : llvm::zip_equal(indexMap.getResults(), map.getResults())) {
@@ -900,15 +906,15 @@ struct WarpOpTransferRead : public OpRewritePattern<WarpExecuteOnLane0Op> {
if (!llvm::all_of(newRead->getOperands(), [&](Value value) {
return (newRead.getMask() && value == newRead.getMask()) ||
newWarpOp.isDefinedOutsideOfRegion(value);
- }))
+ })) {
+ if (!hasMask)
+ rewriter.cancelRootUpdate(warpOp);
return failure();
+ }
rewriter.replaceAllUsesWith(distributedVal, newRead);
- if (hasMask) {
- // Notify the rewriter that the warp op is changing.
- rewriter.startRootUpdate(warpOp);
+ if (!hasMask)
rewriter.finalizeRootUpdate(warpOp);
- }
return success();
}
};
@@ -1068,9 +1074,6 @@ struct WarpOpShapeCast : public OpRewritePattern<WarpExecuteOnLane0Op> {
if (!operand)
return failure();
- // Notify the rewriter that the warp op is changing.
- rewriter.startRootUpdate(warpOp);
-
auto oldCastOp = operand->get().getDefiningOp<vector::ShapeCastOp>();
unsigned int operandNumber = operand->getOperandNumber();
@@ -1099,7 +1102,6 @@ struct WarpOpShapeCast : public OpRewritePattern<WarpExecuteOnLane0Op> {
oldCastOp.getLoc(), castResultType,
newWarpOp->getResult(newRetIndices[0]));
rewriter.replaceAllUsesWith(newWarpOp->getResult(operandNumber), newCast);
- rewriter.finalizeRootUpdate(warpOp);
return success();
}
};
diff --git a/mlir/test/Dialect/Vector/vector-warp-distribute.mlir b/mlir/test/Dialect/Vector/vector-warp-distribute.mlir
index 1821190c44e3af4..8056260f4610977 100644
--- a/mlir/test/Dialect/Vector/vector-warp-distribute.mlir
+++ b/mlir/test/Dialect/Vector/vector-warp-distribute.mlir
@@ -1351,6 +1351,26 @@ func.func @warp_propagate_masked_transfer_read(%laneid: index, %src: memref<4096
// -----
+func.func @warp_propagate_masked_transfer_read_shared_mask(%laneid: index, %src: memref<4096x4096xf32>, %index: index, %index2: index, %mask_ub: index) -> (vector<2xf32>, vector<2xf32>) {
+ %f0 = arith.constant 0.000000e+00 : f32
+ %c0 = arith.constant 0 : index
+ %r:2 = vector.warp_execute_on_lane_0(%laneid)[64] -> (vector<2xf32>, vector<2xf32>) {
+ %mask = vector.create_mask %mask_ub: vector<128xi1>
+ %0 = vector.transfer_read %src[%c0, %index], %f0, %mask {in_bounds = [true]} : memref<4096x4096xf32>, vector<128xf32>
+ %1 = vector.transfer_read %src[%c0, %index2], %f0, %mask {in_bounds = [true]} : memref<4096x4096xf32>, vector<128xf32>
+ vector.yield %0, %1 : vector<128xf32>, vector<128xf32>
+ }
+ return %r#0, %r#1 : vector<2xf32>, vector<2xf32>
+}
+
+// CHECK-PROP-LABEL: func.func @warp_propagate_masked_transfer_read_shared_mask
+// CHECK-PROP: vector.create_mask %{{.*}} : vector<2xi1>
+// CHECK-PROP: vector.transfer_read %{{.*}} : memref<4096x4096xf32>, vector<2xf32>
+// CHECK-PROP: vector.create_mask %{{.*}} : vector<2xi1>
+// CHECK-PROP: vector.transfer_read %{{.*}} : memref<4096x4096xf32>, vector<2xf32>
+
+// -----
+
func.func @warp_propagate_unconnected_read_write(%laneid: index, %buffer: memref<128xf32>, %f1: f32) -> (vector<2xf32>, vector<4xf32>) {
%f0 = arith.constant 0.000000e+00 : f32
%c0 = arith.constant 0 : index
>From 69701ab6a66a3f4259bb9e759863f343bb8d461c Mon Sep 17 00:00:00 2001
From: Quinn Dawkins <quinn at nod-labs.com>
Date: Fri, 10 Nov 2023 13:57:47 -0500
Subject: [PATCH 3/3] Add comment explaining why this is needed
---
.../Vector/Transforms/VectorDistribute.cpp | 24 ++++++++++++++-----
1 file changed, 18 insertions(+), 6 deletions(-)
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
index 482b24069868fe1..ac73cf07004ded8 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
@@ -646,7 +646,8 @@ struct WarpOpElementwise : public OpRewritePattern<WarpExecuteOnLane0Op> {
if (!yieldOperand)
return failure();
- // Notify the rewriter that the warp op is changing.
+ // Notify the rewriter that the warp op is changing (see the comment on
+ // the WarpOpTransferRead pattern).
rewriter.startRootUpdate(warpOp);
Operation *elementWise = yieldOperand->get().getDefiningOp();
@@ -718,7 +719,8 @@ struct WarpOpConstant : public OpRewritePattern<WarpExecuteOnLane0Op> {
auto dense = dyn_cast<SplatElementsAttr>(constantOp.getValue());
if (!dense)
return failure();
- // Notify the rewriter that the warp op is changing.
+ // Notify the rewriter that the warp op is changing (see the comment on
+ // the WarpOpTransferRead pattern).
rewriter.startRootUpdate(warpOp);
unsigned operandIndex = yieldOperand->getOperandNumber();
Attribute scalarAttr = dense.getSplatValue<Attribute>();
@@ -851,7 +853,14 @@ struct WarpOpTransferRead : public OpRewritePattern<WarpExecuteOnLane0Op> {
newMask = newWarpOp.getResult(newRetIndices[0]);
distributedVal = newWarpOp.getResult(operandIndex);
} else {
- // Notify the rewriter that the warp op is changing.
+ // This pattern does not actually change the warp op directly. Instead it
+ // just rewrites a new transfer read (when not masked) outside of the warp
+ // op and replaces the corresponding result. There are then follow up
+ // patterns to erase now dead results of the warp op. This erasure allows
+ // propagation to continue, but this pattern on its own never actually
+ // tells the pattern rewriter that the warp op "changed." Notify the
+ // rewriter here that the warp op is changing. Similar situations are
+ // noted in following patterns.
rewriter.startRootUpdate(warpOp);
}
@@ -1017,7 +1026,8 @@ struct WarpOpForwardOperand : public OpRewritePattern<WarpExecuteOnLane0Op> {
}
if (!valForwarded)
return failure();
- // Notify the rewriter that the warp op is changing.
+ // Notify the rewriter that the warp op is changing (see the comment on
+ // the WarpOpTransferRead pattern).
rewriter.startRootUpdate(warpOp);
rewriter.replaceAllUsesWith(warpOp.getResult(resultIndex), valForwarded);
rewriter.finalizeRootUpdate(warpOp);
@@ -1048,7 +1058,8 @@ struct WarpOpBroadcast : public OpRewritePattern<WarpExecuteOnLane0Op> {
if (vector::isBroadcastableTo(broadcastSrcType, destVecType) !=
vector::BroadcastableToResult::Success)
return failure();
- // Notify the rewriter that the warp op is changing.
+ // Notify the rewriter that the warp op is changing (see the comment on
+ // the WarpOpTransferRead pattern).
rewriter.startRootUpdate(warpOp);
SmallVector<size_t> newRetIndices;
WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
@@ -1161,7 +1172,8 @@ struct WarpOpCreateMask : public OpRewritePattern<WarpExecuteOnLane0Op> {
mask, "cannot delinearize lane ID for distribution");
assert(!delinearizedIds.empty());
- // Notify the rewriter that the warp op is changing.
+ // Notify the rewriter that the warp op is changing (see the comment on
+ // the WarpOpTransferRead pattern).
rewriter.startRootUpdate(warpOp);
AffineExpr s0, s1;
More information about the Mlir-commits
mailing list