[Mlir-commits] [mlir] 4abb9e5 - [mlir][vector] Clean up and generalize lowering of warp_execute to scf
Thomas Raoux
llvmlistbot at llvm.org
Wed Sep 14 10:36:29 PDT 2022
Author: Thomas Raoux
Date: 2022-09-14T17:36:16Z
New Revision: 4abb9e5d2054be9d1a9d2d859675aa9bb9c9a105
URL: https://github.com/llvm/llvm-project/commit/4abb9e5d2054be9d1a9d2d859675aa9bb9c9a105
DIFF: https://github.com/llvm/llvm-project/commit/4abb9e5d2054be9d1a9d2d859675aa9bb9c9a105.diff
LOG: [mlir][vector] Clean up and generalize lowering of warp_execute to scf
Simplify the lowering of warp_execute_on_lane_0 to scf.if by making the
logic more generic. Also remove the assumption that the innermost
dimension is the distributed dimension.
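For illustration, a minimal standalone sketch (plain C++, no MLIR types; the
helper name findDistributedDims is hypothetical) of how the distributed
dimensions can be inferred by comparing the sequential and distributed vector
shapes, mirroring the idea behind calculateImplicitMap in the patch below:

```cpp
#include <cassert>
#include <cstdint>
#include <iostream>
#include <vector>

// Hypothetical standalone sketch: the distributed dimensions are the
// positions where the yielded (sequential) shape and the per-lane
// (distributed) shape differ.
static std::vector<int64_t>
findDistributedDims(const std::vector<int64_t> &sequentialShape,
                    const std::vector<int64_t> &distributedShape) {
  assert(sequentialShape.size() == distributedShape.size() &&
         "shapes must have the same rank");
  std::vector<int64_t> dims;
  for (size_t i = 0, e = sequentialShape.size(); i < e; ++i)
    if (sequentialShape[i] != distributedShape[i])
      dims.push_back(static_cast<int64_t>(i));
  return dims;
}

int main() {
  // From the test added below: vector<1x64x128xf32> is yielded and each
  // lane receives vector<1x2x128xf32>, so dimension 1 is distributed.
  for (int64_t d : findDistributedDims({1, 64, 128}, {1, 2, 128}))
    std::cout << "distributed dim: " << d << "\n"; // prints 1
  return 0;
}
```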
Differential Revision: https://reviews.llvm.org/D133826
Added:
Modified:
mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
mlir/test/Dialect/Vector/vector-warp-distribute.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
index 3064bdf17bb79..7bc1799a1fe58 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
@@ -20,20 +20,33 @@
using namespace mlir;
using namespace mlir::vector;
-/// TODO: add an analysis step that determines which vector dimension should be
-/// used for distribution.
-static llvm::Optional<int64_t>
-getDistributedVectorDim(VectorType distributedVectorType) {
- return (distributedVectorType)
- ? llvm::Optional<int64_t>(distributedVectorType.getRank() - 1)
- : llvm::None;
-}
-
-static llvm::Optional<int64_t>
-getDistributedSize(VectorType distributedVectorType) {
- auto dim = getDistributedVectorDim(distributedVectorType);
- return (dim) ? llvm::Optional<int64_t>(distributedVectorType.getDimSize(*dim))
- : llvm::None;
+/// Currently the distribution map is implicit based on the vector shape. In the
+/// future it will be part of the op.
+/// Example:
+/// ```
+/// %0 = vector.warp_execute_on_lane_0(%arg0) -> (vector<1x16x2xf32>) {
+/// ...
+/// vector.yield %3 : vector<32x16x64xf32>
+/// }
+/// ```
+/// Would have an implicit map of:
+/// `(d0, d1, d2) -> (d0, d2)`
+static AffineMap calculateImplicitMap(VectorType sequentialType,
+ VectorType distributedType) {
+ SmallVector<AffineExpr> perm;
+ perm.reserve(1);
+ // Check which dimensions of the sequential type are different than the
+ // dimensions of the distributed type to know the distributed dimensions. Then
+ // associate each distributed dimension to an ID in order.
+ for (unsigned i = 0, e = sequentialType.getRank(); i < e; i++) {
+ if (sequentialType.getDimSize(i) != distributedType.getDimSize(i))
+ perm.push_back(getAffineDimExpr(i, distributedType.getContext()));
+ }
+ auto map = AffineMap::get(sequentialType.getRank(), 0, perm,
+ distributedType.getContext());
+ assert(map.getNumResults() <= 1 &&
+ "only support distribution along one dimension for now.");
+ return map;
}
namespace {
@@ -42,28 +55,23 @@ namespace {
/// through the parallel / sequential and the sequential / parallel boundaries
/// when performing `rewriteWarpOpToScfFor`.
///
-/// All this assumes the vector distribution occurs along the most minor
-/// distributed vector dimension.
-/// TODO: which is expected to be a multiple of the warp size ?
-/// TODO: add an analysis step that determines which vector dimension should
-/// be used for distribution.
+/// The vector distribution dimension is inferred from the vector types.
struct DistributedLoadStoreHelper {
DistributedLoadStoreHelper(Value sequentialVal, Value distributedVal,
Value laneId, Value zero)
: sequentialVal(sequentialVal), distributedVal(distributedVal),
laneId(laneId), zero(zero) {
- sequentialType = sequentialVal.getType();
- distributedType = distributedVal.getType();
- sequentialVectorType = sequentialType.dyn_cast<VectorType>();
- distributedVectorType = distributedType.dyn_cast<VectorType>();
+ sequentialVectorType = sequentialVal.getType().dyn_cast<VectorType>();
+ distributedVectorType = distributedVal.getType().dyn_cast<VectorType>();
+ if (sequentialVectorType && distributedVectorType)
+ distributionMap =
+ calculateImplicitMap(sequentialVectorType, distributedVectorType);
}
- Value buildDistributedOffset(RewriterBase &b, Location loc) {
- auto maybeDistributedSize = getDistributedSize(distributedVectorType);
- assert(maybeDistributedSize &&
- "at this point, a distributed size must be determined");
+ Value buildDistributedOffset(RewriterBase &b, Location loc, int64_t index) {
+ int64_t distributedSize = distributedVectorType.getDimSize(index);
AffineExpr tid = getAffineSymbolExpr(0, b.getContext());
- return b.createOrFold<AffineApplyOp>(loc, tid * (*maybeDistributedSize),
+ return b.createOrFold<AffineApplyOp>(loc, tid * distributedSize,
ArrayRef<Value>{laneId});
}
@@ -79,27 +87,24 @@ struct DistributedLoadStoreHelper {
assert((val == distributedVal || val == sequentialVal) &&
"Must store either the preregistered distributed or the "
"preregistered sequential value.");
+ // Scalar case can directly use memref.store.
+ if (!val.getType().isa<VectorType>())
+ return b.create<memref::StoreOp>(loc, val, buffer, zero);
+
// Vector case must use vector::TransferWriteOp which will later lower to
// vector.store of memref.store depending on further lowerings.
- if (val.getType().isa<VectorType>()) {
- int64_t rank = sequentialVectorType.getRank();
- if (rank == 0) {
- return b.create<vector::TransferWriteOp>(loc, val, buffer, ValueRange{},
- ArrayRef<bool>{});
+ int64_t rank = sequentialVectorType.getRank();
+ SmallVector<Value> indices(rank, zero);
+ if (val == distributedVal) {
+ for (auto dimExpr : distributionMap.getResults()) {
+ int64_t index = dimExpr.cast<AffineDimExpr>().getPosition();
+ indices[index] = buildDistributedOffset(b, loc, index);
}
- SmallVector<Value> indices(rank, zero);
- auto maybeDistributedDim = getDistributedVectorDim(distributedVectorType);
- assert(maybeDistributedDim && "must be able to deduce distributed dim");
- if (val == distributedVal)
- indices[*maybeDistributedDim] =
- (val == distributedVal) ? buildDistributedOffset(b, loc) : zero;
- SmallVector<bool> inBounds(indices.size(), true);
- return b.create<vector::TransferWriteOp>(
- loc, val, buffer, indices,
- ArrayRef<bool>(inBounds.begin(), inBounds.end()));
}
- // Scalar case can directly use memref.store.
- return b.create<memref::StoreOp>(loc, val, buffer, zero);
+ SmallVector<bool> inBounds(indices.size(), true);
+ return b.create<vector::TransferWriteOp>(
+ loc, val, buffer, indices,
+ ArrayRef<bool>(inBounds.begin(), inBounds.end()));
}
/// Create a load during the process of distributing the
@@ -122,36 +127,24 @@ struct DistributedLoadStoreHelper {
/// // Both types are f32. The constant %cst is broadcasted to all lanes.
/// ```
/// This behavior described in more detail in the documentation of the op.
- Value buildLoad(RewriterBase &b, Location loc, Type type, Value buffer,
- bool broadcastMode = false) {
- if (broadcastMode) {
- // Broadcast mode may occur for either scalar or vector operands.
- auto vectorType = type.dyn_cast<VectorType>();
- auto shape = buffer.getType().cast<MemRefType>();
- if (vectorType) {
- SmallVector<bool> inBounds(shape.getRank(), true);
- return b.create<vector::TransferReadOp>(
- loc, vectorType, buffer,
- /*indices=*/SmallVector<Value>(shape.getRank(), zero),
- ArrayRef<bool>(inBounds.begin(), inBounds.end()));
- }
+ Value buildLoad(RewriterBase &b, Location loc, Type type, Value buffer) {
+
+ // Scalar case can directly use memref.store.
+ if (!type.isa<VectorType>())
return b.create<memref::LoadOp>(loc, buffer, zero);
- }
// Other cases must be vector atm.
// Vector case must use vector::TransferReadOp which will later lower to
// vector.read of memref.read depending on further lowerings.
- assert(type.isa<VectorType>() && "must be a vector type");
assert((type == distributedVectorType || type == sequentialVectorType) &&
"Must store either the preregistered distributed or the "
"preregistered sequential type.");
- auto maybeDistributedDim = getDistributedVectorDim(distributedVectorType);
- assert(maybeDistributedDim && "must be able to deduce distributed dim");
SmallVector<Value> indices(sequentialVectorType.getRank(), zero);
if (type == distributedVectorType) {
- indices[*maybeDistributedDim] = buildDistributedOffset(b, loc);
- } else {
- indices[*maybeDistributedDim] = zero;
+ for (auto dimExpr : distributionMap.getResults()) {
+ int64_t index = dimExpr.cast<AffineDimExpr>().getPosition();
+ indices[index] = buildDistributedOffset(b, loc, index);
+ }
}
SmallVector<bool> inBounds(indices.size(), true);
return b.create<vector::TransferReadOp>(
@@ -160,8 +153,8 @@ struct DistributedLoadStoreHelper {
}
Value sequentialVal, distributedVal, laneId, zero;
- Type sequentialType, distributedType;
VectorType sequentialVectorType, distributedVectorType;
+ AffineMap distributionMap;
};
} // namespace
@@ -262,32 +255,6 @@ static Operation *cloneOpWithOperandsAndTypes(RewriterBase &rewriter,
return rewriter.create(res);
}
-/// Currently the distribution map is implicit based on the vector shape. In the
-/// future it will be part of the op.
-/// Example:
-/// ```
-/// %0 = vector.warp_execute_on_lane_0(%arg0) -> (vector<1x16x2xf32>) {
-/// ...
-/// vector.yield %3 : vector<32x16x64xf32>
-/// }
-/// ```
-/// Would have an implicit map of:
-/// `(d0, d1, d2) -> (d0, d2)`
-static AffineMap calculateImplicitMap(Value yield, Value ret) {
- auto srcType = yield.getType().cast<VectorType>();
- auto dstType = ret.getType().cast<VectorType>();
- SmallVector<AffineExpr> perm;
- // Check which dimensions of the yield value are different than the dimensions
- // of the result to know the distributed dimensions. Then associate each
- // distributed dimension to an ID in order.
- for (unsigned i = 0, e = srcType.getRank(); i < e; i++) {
- if (srcType.getDimSize(i) != dstType.getDimSize(i))
- perm.push_back(getAffineDimExpr(i, yield.getContext()));
- }
- auto map = AffineMap::get(srcType.getRank(), 0, perm, yield.getContext());
- return map;
-}
-
namespace {
/// Rewrite a WarpExecuteOnLane0Op into a predicated scf.if op where the single
@@ -318,13 +285,10 @@ namespace {
///
/// All this assumes the vector distribution occurs along the most minor
/// distributed vector dimension.
-/// TODO: which is expected to be a multiple of the warp size ?
-/// TODO: add an analysis step that determines which vector dimension should be
-/// used for distribution.
-struct WarpOpToScfForPattern : public OpRewritePattern<WarpExecuteOnLane0Op> {
- WarpOpToScfForPattern(MLIRContext *context,
- const WarpExecuteOnLane0LoweringOptions &options,
- PatternBenefit benefit = 1)
+struct WarpOpToScfIfPattern : public OpRewritePattern<WarpExecuteOnLane0Op> {
+ WarpOpToScfIfPattern(MLIRContext *context,
+ const WarpExecuteOnLane0LoweringOptions &options,
+ PatternBenefit benefit = 1)
: OpRewritePattern<WarpExecuteOnLane0Op>(context, benefit),
options(options) {}
@@ -364,10 +328,8 @@ struct WarpOpToScfForPattern : public OpRewritePattern<WarpExecuteOnLane0Op> {
helper.buildStore(rewriter, loc, distributedVal, buffer);
// Load sequential vector from buffer, inside the ifOp.
rewriter.setInsertionPointToStart(ifOp.thenBlock());
- bool broadcastMode =
- (sequentialVal.getType() == distributedVal.getType());
- bbArgReplacements.push_back(helper.buildLoad(
- rewriter, loc, sequentialVal.getType(), buffer, broadcastMode));
+ bbArgReplacements.push_back(
+ helper.buildLoad(rewriter, loc, sequentialVal.getType(), buffer));
}
// Step 3. Insert sync after all the stores and before all the loads.
@@ -404,8 +366,6 @@ struct WarpOpToScfForPattern : public OpRewritePattern<WarpExecuteOnLane0Op> {
// Load distributed value from buffer, after the warpOp.
rewriter.setInsertionPointAfter(ifOp);
- bool broadcastMode =
- (sequentialVal.getType() == distributedVal.getType());
// Result type and yielded value type are the same. This is a broadcast.
// E.g.:
// %r = vector.warp_execute_on_lane_0(...) -> (f32) {
@@ -413,8 +373,8 @@ struct WarpOpToScfForPattern : public OpRewritePattern<WarpExecuteOnLane0Op> {
// }
// Both types are f32. The constant %cst is broadcasted to all lanes.
// This is described in more detail in the documentation of the op.
- replacements.push_back(helper.buildLoad(
- rewriter, loc, distributedVal.getType(), buffer, broadcastMode));
+ replacements.push_back(
+ helper.buildLoad(rewriter, loc, distributedVal.getType(), buffer));
}
// Step 6. Insert sync after all the stores and before all the loads.
@@ -758,7 +718,9 @@ struct WarpOpTransferRead : public OpRewritePattern<WarpExecuteOnLane0Op> {
SmallVector<Value, 4> indices(read.getIndices().begin(),
read.getIndices().end());
- AffineMap map = calculateImplicitMap(read.getResult(), distributedVal);
+ auto sequentialType = read.getResult().getType().cast<VectorType>();
+ auto distributedType = distributedVal.getType().cast<VectorType>();
+ AffineMap map = calculateImplicitMap(sequentialType, distributedType);
AffineMap indexMap = map.compose(read.getPermutationMap());
OpBuilder::InsertionGuard g(rewriter);
rewriter.setInsertionPointAfter(warpOp);
@@ -1118,7 +1080,7 @@ struct WarpOpReduction : public OpRewritePattern<WarpExecuteOnLane0Op> {
void mlir::vector::populateWarpExecuteOnLane0OpToScfForPattern(
RewritePatternSet &patterns,
const WarpExecuteOnLane0LoweringOptions &options, PatternBenefit benefit) {
- patterns.add<WarpOpToScfForPattern>(patterns.getContext(), options, benefit);
+ patterns.add<WarpOpToScfIfPattern>(patterns.getContext(), options, benefit);
}
void mlir::vector::populateDistributeTransferWriteOpPatterns(
diff --git a/mlir/test/Dialect/Vector/vector-warp-distribute.mlir b/mlir/test/Dialect/Vector/vector-warp-distribute.mlir
index 20bc623064f9f..200033d608ec8 100644
--- a/mlir/test/Dialect/Vector/vector-warp-distribute.mlir
+++ b/mlir/test/Dialect/Vector/vector-warp-distribute.mlir
@@ -734,3 +734,45 @@ func.func @warp_execute_has_broadcast_semantics(%laneid: index, %s0: f32, %v0: v
// CHECK-SCF-IF: return %[[RS0]], %[[RV0]], %[[RV1]], %[[RV2]] : f32, vector<f32>, vector<1xf32>, vector<1x1xf32>
return %r#0, %r#1, %r#2, %r#3 : f32, vector<f32>, vector<1xf32>, vector<1x1xf32>
}
+
+// -----
+
+// CHECK-SCF-IF-DAG: #[[$TIMES2:.*]] = affine_map<()[s0] -> (s0 * 2)>
+
+// CHECK-SCF-IF: func @warp_execute_nd_distribute
+// CHECK-SCF-IF-SAME: (%[[LANEID:.*]]: index
+func.func @warp_execute_nd_distribute(%laneid: index, %v0: vector<1x64x1xf32>, %v1: vector<1x2x128xf32>)
+ -> (vector<1x64x1xf32>, vector<1x2x128xf32>) {
+ // CHECK-SCF-IF-DAG: %[[C0:.*]] = arith.constant 0 : index
+
+ // CHECK-SCF-IF: vector.transfer_write %{{.*}}, %{{.*}}[%[[LANEID]], %c0, %c0] {in_bounds = [true, true, true]} : vector<1x64x1xf32>, memref<32x64x1xf32, 3>
+ // CHECK-SCF-IF: %[[RID:.*]] = affine.apply #[[$TIMES2]]()[%[[LANEID]]]
+ // CHECK-SCF-IF: vector.transfer_write %{{.*}}, %{{.*}}[%[[C0]], %[[RID]], %[[C0]]] {in_bounds = [true, true, true]} : vector<1x2x128xf32>, memref<1x64x128xf32, 3>
+ // CHECK-SCF-IF: gpu.barrier
+
+ // CHECK-SCF-IF: scf.if{{.*}}{
+ %r:2 = vector.warp_execute_on_lane_0(%laneid)[32]
+ args(%v0, %v1 : vector<1x64x1xf32>, vector<1x2x128xf32>) -> (vector<1x64x1xf32>, vector<1x2x128xf32>) {
+ ^bb0(%arg0: vector<32x64x1xf32>, %arg1: vector<1x64x128xf32>):
+
+ // CHECK-SCF-IF-DAG: %[[SR0:.*]] = vector.transfer_read %{{.*}}[%[[C0]], %[[C0]], %[[C0]]], %{{.*}} {in_bounds = [true, true, true]} : memref<32x64x1xf32, 3>, vector<32x64x1xf32>
+ // CHECK-SCF-IF-DAG: %[[SR1:.*]] = vector.transfer_read %{{.*}}[%[[C0]], %[[C0]], %[[C0]]], %{{.*}} {in_bounds = [true, true, true]} : memref<1x64x128xf32, 3>, vector<1x64x128xf32>
+ // CHECK-SCF-IF: %[[W0:.*]] = "some_def_0"(%[[SR0]]) : (vector<32x64x1xf32>) -> vector<32x64x1xf32>
+ // CHECK-SCF-IF: %[[W1:.*]] = "some_def_1"(%[[SR1]]) : (vector<1x64x128xf32>) -> vector<1x64x128xf32>
+ // CHECK-SCF-IF-DAG: vector.transfer_write %[[W0]], %{{.*}}[%[[C0]], %[[C0]], %[[C0]]] {in_bounds = [true, true, true]} : vector<32x64x1xf32>, memref<32x64x1xf32, 3>
+ // CHECK-SCF-IF-DAG: vector.transfer_write %[[W1]], %{{.*}}[%[[C0]], %[[C0]], %[[C0]]] {in_bounds = [true, true, true]} : vector<1x64x128xf32>, memref<1x64x128xf32, 3>
+
+ %r0 = "some_def_0"(%arg0) : (vector<32x64x1xf32>) -> vector<32x64x1xf32>
+ %r1 = "some_def_1"(%arg1) : (vector<1x64x128xf32>) -> vector<1x64x128xf32>
+
+ // CHECK-SCF-IF-NOT: vector.yield
+ vector.yield %r0, %r1 : vector<32x64x1xf32>, vector<1x64x128xf32>
+ }
+
+ // CHECK-SCF-IF: gpu.barrier
+ // CHECK-SCF-IF: %[[WID:.*]] = affine.apply #[[$TIMES2]]()[%[[LANEID]]]
+ // CHECK-SCF-IF-DAG: %[[R0:.*]] = vector.transfer_read %{{.*}}[%[[LANEID]], %[[C0]], %[[C0]]], %cst {in_bounds = [true, true, true]} : memref<32x64x1xf32, 3>, vector<1x64x1xf32>
+ // CHECK-SCF-IF-DAG: %[[R1:.*]] = vector.transfer_read %{{.*}}[%[[C0]], %[[WID]], %[[C0]]], %cst {in_bounds = [true, true, true]} : memref<1x64x128xf32, 3>, vector<1x2x128xf32>
+ // CHECK-SCF-IF: return %[[R0]], %[[R1]] : vector<1x64x1xf32>, vector<1x2x128xf32>
+ return %r#0, %r#1 : vector<1x64x1xf32>, vector<1x2x128xf32>
+}
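The offsets checked in the test above follow from laneId * distributedSize on
the distributed dimension, which is what affine_map<()[s0] -> (s0 * 2)>
encodes for a distributed size of 2. A small sketch, assuming 32 lanes and the
shapes from warp_execute_nd_distribute (the helper name is hypothetical):

```cpp
#include <cstdint>
#include <iostream>

// Hypothetical sketch of the per-lane offset used on the distributed
// dimension, matching affine_map<()[s0] -> (s0 * 2)> for distributedSize = 2.
static int64_t distributedOffset(int64_t laneId, int64_t distributedSize) {
  return laneId * distributedSize;
}

int main() {
  // vector<1x64x128xf32> distributed as vector<1x2x128xf32> over 32 lanes:
  // lane 0 covers rows [0, 2), lane 1 rows [2, 4), ..., lane 31 rows [62, 64).
  for (int64_t lane = 0; lane < 32; ++lane)
    std::cout << "lane " << lane << " -> row offset "
              << distributedOffset(lane, /*distributedSize=*/2) << "\n";
  return 0;
}
```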