[Mlir-commits] [mlir] 73ddc44 - [mlir][vector] Enable distribution over multiple dimensions
Lei Zhang
llvmlistbot at llvm.org
Wed Aug 16 12:08:59 PDT 2023
Author: Lei Zhang
Date: 2023-08-16T12:08:43-07:00
New Revision: 73ddc4474bc4c34d3b6f50cd7a6e88a12ca83f8d
URL: https://github.com/llvm/llvm-project/commit/73ddc4474bc4c34d3b6f50cd7a6e88a12ca83f8d
DIFF: https://github.com/llvm/llvm-project/commit/73ddc4474bc4c34d3b6f50cd7a6e88a12ca83f8d.diff
LOG: [mlir][vector] Enable distribution over multiple dimensions
This commit starts enabling vector distribution over multiple
dimensions. It requires delinearizing the lane ID to match the
expected rank. shape_cast and transfer_read can now properly
handle multiple dimensions.
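
For intuition, below is a minimal standalone C++ sketch of the
delinearization arithmetic, using plain integer math in place of the
affine.apply ops the patch actually builds; the function shape and the
test values are illustrative (mirroring the 3-D transfer_read test
further down), not code taken from the patch itself.

#include <cassert>
#include <cstdint>
#include <iostream>
#include <vector>

// Each distributed dimension i needs originalShape[i] / distributedShape[i]
// threads; the product over all dimensions must equal the warp size.
// Returns an empty vector when the shapes cannot be distributed.
std::vector<int64_t>
delinearizeLaneId(int64_t laneId, const std::vector<int64_t> &originalShape,
                  const std::vector<int64_t> &distributedShape,
                  int64_t warpSize) {
  assert(originalShape.size() == distributedShape.size());
  std::vector<int64_t> sizes;
  int64_t product = 1;
  for (size_t i = 0; i < originalShape.size(); ++i) {
    if (originalShape[i] % distributedShape[i] != 0)
      return {};
    sizes.push_back(originalShape[i] / distributedShape[i]);
    product *= sizes.back();
  }
  if (product != warpSize)
    return {};

  // Peel off dimensions from the fastest-varying (innermost) outward.
  std::vector<int64_t> ids(sizes.size());
  for (int i = static_cast<int>(sizes.size()) - 1; i >= 0; --i) {
    ids[i] = laneId % sizes[i];
    laneId /= sizes[i];
  }
  return ids;
}

int main() {
  // Distributing vector<32x4x32xf32> to vector<1x1x4xf32> over 1024 lanes
  // uses 32 x 4 x 8 threads; lane 42 = 1*32 + 1*8 + 2 maps to (1, 1, 2).
  std::vector<int64_t> ids =
      delinearizeLaneId(42, {32, 4, 32}, {1, 1, 4}, 1024);
  assert(ids.size() == 3 && ids[0] == 1 && ids[1] == 1 && ids[2] == 2);
  std::cout << ids[0] << ", " << ids[1] << ", " << ids[2] << "\n";
  return 0;
}

The loop in the patch performs the same innermost-to-outermost peeling
with affine maps, stopping early once the accumulated thread count
reaches the warp size.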
Reviewed By: hanchung
Differential Revision: https://reviews.llvm.org/D157931
Added:
Modified:
mlir/lib/Dialect/Vector/IR/VectorOps.cpp
mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
mlir/test/Dialect/Vector/invalid.mlir
mlir/test/Dialect/Vector/ops.mlir
mlir/test/Dialect/Vector/vector-warp-distribute.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 5f5909ec998105..1ad22cdf9788c1 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -5760,22 +5760,26 @@ static LogicalResult verifyDistributedType(Type expanded, Type distributed,
expandedVecType.getElementType() != distributedVecType.getElementType())
return op->emitOpError(
"expected distributed vectors to have same rank and element type.");
- bool foundDistributedDim = false;
+
+ SmallVector<int64_t> scales(expandedVecType.getRank(), 1);
for (int64_t i = 0, e = expandedVecType.getRank(); i < e; i++) {
- if (expandedVecType.getDimSize(i) == distributedVecType.getDimSize(i))
- continue;
- if (expandedVecType.getDimSize(i) ==
- distributedVecType.getDimSize(i) * warpSize) {
- if (foundDistributedDim)
- return op->emitOpError()
- << "expected only one dimension to be distributed from "
- << expandedVecType << " to " << distributedVecType;
- foundDistributedDim = true;
+ int64_t eDim = expandedVecType.getDimSize(i);
+ int64_t dDim = distributedVecType.getDimSize(i);
+ if (eDim == dDim)
continue;
- }
- return op->emitOpError() << "incompatible distribution dimensions from "
- << expandedVecType << " to " << distributedVecType;
+ if (eDim % dDim != 0)
+ return op->emitOpError()
+ << "expected expanded vector dimension #" << i << " (" << eDim
+           << ") to be a multiple of the distributed vector dimension ("
+ << dDim << ")";
+ scales[i] = eDim / dDim;
}
+ if (std::accumulate(scales.begin(), scales.end(), 1,
+ std::multiplies<int64_t>()) != warpSize)
+ return op->emitOpError()
+ << "incompatible distribution dimensions from " << expandedVecType
+ << " to " << distributedVecType << " with warp size = " << warpSize;
+
return success();
}
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
index 9d6c45b4bceaec..2182a2ebf7f081 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
@@ -16,6 +16,7 @@
#include "mlir/Interfaces/SideEffectInterfaces.h"
#include "mlir/Transforms/RegionUtils.h"
#include "llvm/ADT/SetVector.h"
+#include <numeric>
#include <utility>
using namespace mlir;
@@ -45,8 +46,6 @@ static AffineMap calculateImplicitMap(VectorType sequentialType,
}
auto map = AffineMap::get(sequentialType.getRank(), 0, perm,
distributedType.getContext());
- assert(map.getNumResults() <= 1 &&
- "only support distribution along one dimension for now.");
return map;
}
@@ -702,6 +701,49 @@ struct WarpOpConstant : public OpRewritePattern<WarpExecuteOnLane0Op> {
}
};
+/// Delinearize the given `laneId` into multiple dimensions, where each
+/// dimension's size is determined by `originalShape` and `distributedShape`
+/// together. This function expects the total number of threads needed for
+/// distribution to equal `warpSize`. Returns true and updates
+/// `delinearizedIds` if so.
+bool delinearizeLaneId(OpBuilder &builder, Location loc,
+ ArrayRef<int64_t> originalShape,
+ ArrayRef<int64_t> distributedShape, int64_t warpSize,
+ Value laneId, SmallVectorImpl<Value> &delinearizedIds) {
+ SmallVector<int64_t> sizes;
+ for (auto [large, small] : llvm::zip_equal(originalShape, distributedShape)) {
+ if (large % small != 0)
+ return false;
+ sizes.push_back(large / small);
+ }
+ if (std::accumulate(sizes.begin(), sizes.end(), 1,
+ std::multiplies<int64_t>()) != warpSize)
+ return false;
+
+ AffineExpr s0, s1;
+ bindSymbols(builder.getContext(), s0, s1);
+
+ int64_t usedThreads = 1;
+
+ Value zero = builder.create<arith::ConstantIndexOp>(loc, 0);
+ delinearizedIds.assign(sizes.size(), zero);
+
+ for (int i = sizes.size() - 1; i >= 0; --i) {
+ usedThreads *= sizes[i];
+ if (usedThreads == warpSize) {
+      // We've used up all available threads. No modulo is needed anymore,
+      // and we can stop computing the remaining dimensions.
+ delinearizedIds[i] = laneId;
+ break;
+ }
+ delinearizedIds[i] =
+ affine::makeComposedAffineApply(builder, loc, s0 % sizes[i], {laneId});
+    laneId = affine::makeComposedAffineApply(
+        builder, loc, s0.floorDiv(sizes[i]), {laneId});
+ }
+ return true;
+}
+
/// Sink out transfer_read op feeding into a warp op yield.
/// ```
/// %0 = vector.warp_execute_on_lane_0(%arg0) -> (vector<1xf32>) {
@@ -743,6 +785,16 @@ struct WarpOpTransferRead : public OpRewritePattern<WarpExecuteOnLane0Op> {
AffineMap indexMap = map.compose(read.getPermutationMap());
OpBuilder::InsertionGuard g(rewriter);
rewriter.setInsertionPointAfter(warpOp);
+
+ // Try to delinearize the lane ID to match the rank expected for
+ // distribution.
+ SmallVector<Value> delinearizedIds;
+ if (!delinearizeLaneId(rewriter, read.getLoc(), sequentialType.getShape(),
+ distributedType.getShape(), warpOp.getWarpSize(),
+ warpOp.getLaneid(), delinearizedIds))
+ return rewriter.notifyMatchFailure(
+ read, "cannot delinearize lane ID for distribution");
+
for (auto it : llvm::zip(indexMap.getResults(), map.getResults())) {
AffineExpr d0, d1;
bindDims(read.getContext(), d0, d1);
@@ -751,11 +803,10 @@ struct WarpOpTransferRead : public OpRewritePattern<WarpExecuteOnLane0Op> {
continue;
unsigned indexPos = indexExpr.getPosition();
unsigned vectorPos = std::get<1>(it).cast<AffineDimExpr>().getPosition();
- int64_t scale =
- cast<VectorType>(distributedVal.getType()).getDimSize(vectorPos);
+ int64_t scale = distributedType.getDimSize(vectorPos);
indices[indexPos] = affine::makeComposedAffineApply(
rewriter, read.getLoc(), d0 + scale * d1,
- {indices[indexPos], warpOp.getLaneid()});
+ {indices[indexPos], delinearizedIds[vectorPos]});
}
auto newRead = rewriter.create<vector::TransferReadOp>(
read.getLoc(), distributedVal.getType(), read.getSource(), indices,
@@ -918,6 +969,48 @@ struct WarpOpBroadcast : public OpRewritePattern<WarpExecuteOnLane0Op> {
}
};
+/// Pattern to move a shape cast out of the warp op. A shape cast is basically
+/// a no-op for warp distribution; only the result shape needs adjusting.
+struct WarpOpShapeCast : public OpRewritePattern<WarpExecuteOnLane0Op> {
+ using OpRewritePattern<WarpExecuteOnLane0Op>::OpRewritePattern;
+ LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
+ PatternRewriter &rewriter) const override {
+ OpOperand *operand = getWarpResult(
+ warpOp, [](Operation *op) { return isa<vector::ShapeCastOp>(op); });
+ if (!operand)
+ return failure();
+ auto oldCastOp = operand->get().getDefiningOp<vector::ShapeCastOp>();
+
+ unsigned int operandNumber = operand->getOperandNumber();
+ auto castDistributedType =
+ cast<VectorType>(warpOp->getResultTypes()[operandNumber]);
+ VectorType castOriginalType = oldCastOp.getSourceVectorType();
+ VectorType castResultType = castDistributedType;
+
+ // We expect the distributed type to have a smaller rank than the original
+ // type. Prepend with size-one dimensions to make them the same.
+ unsigned castDistributedRank = castDistributedType.getRank();
+ unsigned castOriginalRank = castOriginalType.getRank();
+ if (castDistributedRank < castOriginalRank) {
+ SmallVector<int64_t> shape(castOriginalRank - castDistributedRank, 1);
+ llvm::append_range(shape, castDistributedType.getShape());
+ castDistributedType =
+ VectorType::get(shape, castDistributedType.getElementType());
+ }
+
+ SmallVector<size_t> newRetIndices;
+ WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
+ rewriter, warpOp, {oldCastOp.getSource()}, {castDistributedType},
+ newRetIndices);
+ rewriter.setInsertionPointAfter(newWarpOp);
+ Value newCast = rewriter.create<vector::ShapeCastOp>(
+ oldCastOp.getLoc(), castResultType,
+ newWarpOp->getResult(newRetIndices[0]));
+ rewriter.replaceAllUsesWith(newWarpOp->getResult(operandNumber), newCast);
+ return success();
+ }
+};
+
/// Pattern to move out vector.extract of single element vector. Those don't
/// need to be distributed and can just be propagated outside of the region.
struct WarpOpExtract : public OpRewritePattern<WarpExecuteOnLane0Op> {
@@ -1557,9 +1650,9 @@ void mlir::vector::populatePropagateWarpVectorDistributionPatterns(
RewritePatternSet &patterns, const DistributionMapFn &distributionMapFn,
const WarpShuffleFromIdxFn &warpShuffleFromIdxFn, PatternBenefit benefit) {
patterns.add<WarpOpElementwise, WarpOpTransferRead, WarpOpDeadResult,
- WarpOpBroadcast, WarpOpExtract, WarpOpForwardOperand,
- WarpOpConstant, WarpOpInsertElement, WarpOpInsert>(
- patterns.getContext(), benefit);
+ WarpOpBroadcast, WarpOpShapeCast, WarpOpExtract,
+ WarpOpForwardOperand, WarpOpConstant, WarpOpInsertElement,
+ WarpOpInsert>(patterns.getContext(), benefit);
patterns.add<WarpOpExtractElement>(patterns.getContext(),
warpShuffleFromIdxFn, benefit);
patterns.add<WarpOpScfForOp>(patterns.getContext(), distributionMapFn,
diff --git a/mlir/test/Dialect/Vector/invalid.mlir b/mlir/test/Dialect/Vector/invalid.mlir
index 16fb631af25834..50119c2b4a3626 100644
--- a/mlir/test/Dialect/Vector/invalid.mlir
+++ b/mlir/test/Dialect/Vector/invalid.mlir
@@ -1593,7 +1593,7 @@ func.func @warp_wrong_arg_distribution(%laneid: index, %v0 : vector<4xi32>) {
// -----
func.func @warp_2_distributed_dims(%laneid: index) {
-  // expected-error @+1 {{'vector.warp_execute_on_lane_0' op expected only one dimension to be distributed from 'vector<128x128xi32>' to 'vector<4x4xi32>'}}
+  // expected-error @+1 {{incompatible distribution dimensions from 'vector<128x128xi32>' to 'vector<4x4xi32>' with warp size = 32}}
%2 = vector.warp_execute_on_lane_0(%laneid)[32] -> (vector<4x4xi32>) {
%0 = arith.constant dense<2>: vector<128x128xi32>
vector.yield %0 : vector<128x128xi32>
@@ -1603,6 +1603,17 @@ func.func @warp_2_distributed_dims(%laneid: index) {
// -----
+func.func @warp_2_distributed_dims(%laneid: index) {
+  // expected-error @+1 {{expected expanded vector dimension #1 (8) to be a multiple of the distributed vector dimension (3)}}
+ %2 = vector.warp_execute_on_lane_0(%laneid)[32] -> (vector<1x3xi32>) {
+ %0 = arith.constant dense<2>: vector<4x8xi32>
+ vector.yield %0 : vector<4x8xi32>
+ }
+ return
+}
+
+// -----
+
func.func @warp_mismatch_rank(%laneid: index) {
  // expected-error @+1 {{'vector.warp_execute_on_lane_0' op expected distributed vectors to have same rank and element type.}}
%2 = vector.warp_execute_on_lane_0(%laneid)[32] -> (vector<4x4xi32>) {
diff --git a/mlir/test/Dialect/Vector/ops.mlir b/mlir/test/Dialect/Vector/ops.mlir
index 90dc9a954a2fdb..2154304965a5d0 100644
--- a/mlir/test/Dialect/Vector/ops.mlir
+++ b/mlir/test/Dialect/Vector/ops.mlir
@@ -849,6 +849,17 @@ func.func @warp_execute_on_lane_0(%laneid: index) {
return
}
+// CHECK-LABEL: func.func @warp_execute_on_lane_0_2d
+func.func @warp_execute_on_lane_0_2d(%laneid: index) {
+ // CHECK: vector.warp_execute_on_lane_0(%{{.*}})[32] -> (vector<1x4xi32>)
+ %2 = vector.warp_execute_on_lane_0(%laneid)[32] -> (vector<1x4xi32>) {
+ %0 = arith.constant dense<2>: vector<4x32xi32>
+ // CHECK: vector.yield %{{.+}} : vector<4x32xi32>
+ vector.yield %0 : vector<4x32xi32>
+ }
+ return
+}
+
// CHECK-LABEL: func @warp_operand_result(
func.func @warp_operand_result(%laneid: index, %v0 : vector<4xi32>) -> (vector<4xi32>) {
// CHECK-NEXT: %{{.*}} = vector.warp_execute_on_lane_0(%{{.*}})[32] args(%{{.*}} : vector<4xi32>) -> (vector<4xi32>) {
diff --git a/mlir/test/Dialect/Vector/vector-warp-distribute.mlir b/mlir/test/Dialect/Vector/vector-warp-distribute.mlir
index 28efd5721524ee..cd0a14eb5f7211 100644
--- a/mlir/test/Dialect/Vector/vector-warp-distribute.mlir
+++ b/mlir/test/Dialect/Vector/vector-warp-distribute.mlir
@@ -827,6 +827,50 @@ func.func @lane_dependent_warp_propagate_read(
// -----
+func.func @warp_propagate_read_3d(%laneid: index, %src: memref<32x4x32xf32>) -> vector<1x1x4xf32> {
+ %c0 = arith.constant 0 : index
+ %cst = arith.constant 0.000000e+00 : f32
+ %r = vector.warp_execute_on_lane_0(%laneid)[1024] -> (vector<1x1x4xf32>) {
+ %2 = vector.transfer_read %src[%c0, %c0, %c0], %cst : memref<32x4x32xf32>, vector<32x4x32xf32>
+ vector.yield %2 : vector<32x4x32xf32>
+ }
+ return %r : vector<1x1x4xf32>
+}
+
+// CHECK-PROP-DAG: #[[$ID0MAP:.+]] = affine_map<()[s0] -> (s0 * 4 - (s0 floordiv 8) * 32)>
+// CHECK-PROP-DAG: #[[$ID1MAP:.+]] = affine_map<()[s0] -> ((s0 floordiv 8) mod 4)>
+// CHECK-PROP-DAG:   #[[$ID2MAP:.+]] = affine_map<()[s0] -> ((s0 floordiv 8) floordiv 4)>
+// CHECK-PROP-LABEL: func.func @warp_propagate_read_3d
+// CHECK-PROP-SAME: (%[[LANE:.+]]: index, %[[SRC:.+]]: memref<32x4x32xf32>)
+// CHECK-PROP-DAG: %[[ID0:.+]] = affine.apply #[[$ID0MAP]]()[%[[LANE]]]
+// CHECK-PROP-DAG: %[[ID1:.+]] = affine.apply #[[$ID1MAP]]()[%[[LANE]]]
+// CHECK-PROP-DAG: %[[ID2:.+]] = affine.apply #[[$ID2MAP]]()[%[[LANE]]]
+// CHECK-PROP: %[[READ:.+]] = vector.transfer_read %[[SRC]][%[[ID2]], %[[ID1]], %[[ID0]]], %{{.+}} : memref<32x4x32xf32>, vector<1x1x4xf32>
+// CHECK-PROP: return %[[READ]] : vector<1x1x4xf32>
+
+// -----
+
+func.func @warp_propagate_read_broadcast(%laneid: index, %src: memref<32x1xf32>) -> vector<1x4xf32> {
+ %c0 = arith.constant 0 : index
+ %cst = arith.constant 0.000000e+00 : f32
+ %r = vector.warp_execute_on_lane_0(%laneid)[512] -> (vector<1x4xf32>) {
+ %2 = vector.transfer_read %src[%c0, %c0], %cst {in_bounds = [true, true], permutation_map = affine_map<(d0, d1) -> (d0, 0)>} : memref<32x1xf32>, vector<32x64xf32>
+ vector.yield %2 : vector<32x64xf32>
+ }
+ return %r : vector<1x4xf32>
+}
+
+// CHECK-PROP-DAG: #[[$MAP:.+]] = affine_map<()[s0] -> (s0 floordiv 16)>
+// CHECK-PROP-DAG: #[[$READMAP:.+]] = affine_map<(d0, d1) -> (d0, 0)>
+// CHECK-PROP-LABEL: func.func @warp_propagate_read_broadcast
+// CHECK-PROP-SAME: (%[[LANE:.+]]: index, %[[SRC:.+]]: memref<32x1xf32>)
+// CHECK-PROP: %[[C0:.+]] = arith.constant 0 : index
+// CHECK-PROP: %[[ID:.+]] = affine.apply #[[$MAP]]()[%[[LANE]]]
+// CHECK-PROP: %[[READ:.+]] = vector.transfer_read %[[SRC]][%[[ID]], %[[C0]]], %{{.+}} {in_bounds = [true, true], permutation_map = #[[$READMAP]]} : memref<32x1xf32>, vector<1x4xf32>
+// CHECK-PROP: return %[[READ]] : vector<1x4xf32>
+
+// -----
+
// CHECK-PROP: func @dont_duplicate_read
func.func @dont_duplicate_read(
%laneid: index, %src: memref<1024xf32>) -> vector<1xf32> {
@@ -1173,3 +1217,22 @@ func.func @dont_fold_vector_broadcast(%laneid: index) {
vector.print %r : vector<1x2xf32>
return
}
+
+// -----
+
+func.func @warp_propagate_shape_cast(%laneid: index, %src: memref<32x4x32xf32>) -> vector<4xf32> {
+ %c0 = arith.constant 0 : index
+ %cst = arith.constant 0.000000e+00 : f32
+ %r = vector.warp_execute_on_lane_0(%laneid)[1024] -> (vector<4xf32>) {
+ %2 = vector.transfer_read %src[%c0, %c0, %c0], %cst : memref<32x4x32xf32>, vector<32x4x32xf32>
+ %3 = vector.shape_cast %2 : vector<32x4x32xf32> to vector<4096xf32>
+ vector.yield %3 : vector<4096xf32>
+ }
+ return %r : vector<4xf32>
+}
+
+// CHECK-PROP-LABEL: func.func @warp_propagate_shape_cast
+// CHECK-PROP: %[[READ:.+]] = vector.transfer_read {{.+}} : memref<32x4x32xf32>, vector<1x1x4xf32>
+// CHECK-PROP: %[[CAST:.+]] = vector.shape_cast %[[READ]] : vector<1x1x4xf32> to vector<4xf32>
+// CHECK-PROP: return %[[CAST]] : vector<4xf32>
+