[Mlir-commits] [mlir] 9d51b4e - [mlir][vector] Support vector.extractelement distribution of 1D vectors
Matthias Springer
llvmlistbot at llvm.org
Thu Nov 10 06:11:58 PST 2022
Author: Matthias Springer
Date: 2022-11-10T15:07:56+01:00
New Revision: 9d51b4e4e77691930fa837dd423c648acf4beb5e
URL: https://github.com/llvm/llvm-project/commit/9d51b4e4e77691930fa837dd423c648acf4beb5e
DIFF: https://github.com/llvm/llvm-project/commit/9d51b4e4e77691930fa837dd423c648acf4beb5e.diff
LOG: [mlir][vector] Support vector.extractelement distribution of 1D vectors
This adds support for ops such as `%1 = vector.extractelement %0[%pos : index] : vector<96xf32>`.
When extracting from a 1-D vector, the source vector is distributed across the warp. The lane into which the requested position falls extracts the element and shuffles it to all other lanes.
Differential Revision: https://reviews.llvm.org/D137336
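For illustration, the 1-D case rewrites roughly as sketched below. This is only a sketch of what the pattern produces, assuming the gpu.shuffle-based callback that the test pass registers, a warp size of 32, and 3 elements per lane for the vector<96xf32> source; SSA names are illustrative. It mirrors the CHECK lines of the @vector_extractelement_1d test added in this commit.

    #map = affine_map<()[s0] -> (s0 ceildiv 3)>
    #map1 = affine_map<()[s0] -> (s0 mod 3)>
    func.func @vector_extractelement_1d(%laneid: index, %pos: index) -> (f32) {
      %c32 = arith.constant 32 : i32
      // Yield the full source vector; each lane receives a vector<3xf32> slice.
      %w = vector.warp_execute_on_lane_0(%laneid)[32] -> (vector<3xf32>) {
        %0 = "some_def"() : () -> (vector<96xf32>)
        vector.yield %0 : vector<96xf32>
      }
      // Lane to broadcast from and position within the distributed vector.
      %from_lane = affine.apply #map()[%pos]
      %distr_pos = affine.apply #map1()[%pos]
      %extracted = vector.extractelement %w[%distr_pos : index] : vector<3xf32>
      // Broadcast the extracted scalar to all lanes via the shuffle callback.
      %from_lane_i32 = arith.index_cast %from_lane : index to i32
      %shuffled, %valid = gpu.shuffle idx %extracted, %from_lane_i32, %c32 : f32
      return %shuffled : f32
    }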
Added:
Modified:
mlir/include/mlir/Dialect/Vector/Transforms/VectorDistribution.h
mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
mlir/test/Dialect/Vector/vector-warp-distribute.mlir
mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Vector/Transforms/VectorDistribution.h b/mlir/include/mlir/Dialect/Vector/Transforms/VectorDistribution.h
index 49e34274f9891..a76a58eb5ec6d 100644
--- a/mlir/include/mlir/Dialect/Vector/Transforms/VectorDistribution.h
+++ b/mlir/include/mlir/Dialect/Vector/Transforms/VectorDistribution.h
@@ -67,11 +67,17 @@ void populateDistributeTransferWriteOpPatterns(
/// region.
void moveScalarUniformCode(WarpExecuteOnLane0Op op);
+/// Lambda signature to compute a warp shuffle of a given value of a given lane
+/// within a given warp size.
+using WarpShuffleFromIdxFn =
+ std::function<Value(Location, OpBuilder &b, Value, Value, int64_t)>;
+
/// Collect patterns to propagate warp distribution. `distributionMapFn` is used
/// to decide how a value should be distributed when this cannot be inferred
/// from its uses.
void populatePropagateWarpVectorDistributionPatterns(
RewritePatternSet &pattern, const DistributionMapFn &distributionMapFn,
+ const WarpShuffleFromIdxFn &warpShuffleFromIdxFn,
PatternBenefit benefit = 1);
/// Lambda signature to compute a reduction of a distributed value for the given
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
index a2916a57350ba..c56af3a7583c8 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
@@ -915,7 +915,10 @@ struct WarpOpExtract : public OpRewritePattern<WarpExecuteOnLane0Op> {
/// Pattern to move out vector.extractelement of 0-D tensors. Those don't
/// need to be distributed and can just be propagated outside of the region.
struct WarpOpExtractElement : public OpRewritePattern<WarpExecuteOnLane0Op> {
- using OpRewritePattern<WarpExecuteOnLane0Op>::OpRewritePattern;
+ WarpOpExtractElement(MLIRContext *ctx, WarpShuffleFromIdxFn fn,
+ PatternBenefit b = 1)
+ : OpRewritePattern<WarpExecuteOnLane0Op>(ctx, b),
+ warpShuffleFromIdxFn(std::move(fn)) {}
LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
PatternRewriter &rewriter) const override {
OpOperand *operand = getWarpResult(warpOp, [](Operation *op) {
@@ -925,19 +928,60 @@ struct WarpOpExtractElement : public OpRewritePattern<WarpExecuteOnLane0Op> {
return failure();
unsigned int operandNumber = operand->getOperandNumber();
auto extractOp = operand->get().getDefiningOp<vector::ExtractElementOp>();
- if (extractOp.getVectorType().getRank() != 0)
- return failure();
+ VectorType extractSrcType = extractOp.getVectorType();
+ bool is0dExtract = extractSrcType.getRank() == 0;
+ Type elType = extractSrcType.getElementType();
+ VectorType distributedVecType;
+ if (!is0dExtract) {
+ assert(extractSrcType.getRank() == 1 &&
+ "expected that extractelement src rank is 0 or 1");
+ int64_t elementsPerLane =
+ extractSrcType.getShape()[0] / warpOp.getWarpSize();
+ distributedVecType = VectorType::get({elementsPerLane}, elType);
+ } else {
+ distributedVecType = extractSrcType;
+ }
+
+ // Yield source vector from warp op.
Location loc = extractOp.getLoc();
SmallVector<size_t> newRetIndices;
WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
- rewriter, warpOp, {extractOp.getVector()}, {extractOp.getVectorType()},
+ rewriter, warpOp, {extractOp.getVector()}, {distributedVecType},
newRetIndices);
rewriter.setInsertionPointAfter(newWarpOp);
- Value newExtract = rewriter.create<vector::ExtractElementOp>(
- loc, newWarpOp->getResult(newRetIndices[0]));
- newWarpOp->getResult(operandNumber).replaceAllUsesWith(newExtract);
+ Value distributedVec = newWarpOp->getResult(newRetIndices[0]);
+
+ // 0d extract: The new warp op broadcasts the source vector to all lanes.
+ // All lanes extract the scalar.
+ if (is0dExtract) {
+ Value newExtract =
+ rewriter.create<vector::ExtractElementOp>(loc, distributedVec);
+ newWarpOp->getResult(operandNumber).replaceAllUsesWith(newExtract);
+ return success();
+ }
+
+ // 1d extract: Distribute the source vector. One lane extracts and shuffles
+ // the value to all other lanes.
+ int64_t elementsPerLane = distributedVecType.getShape()[0];
+ AffineExpr sym0 = getAffineSymbolExpr(0, rewriter.getContext());
+ // tid of extracting thread: pos / elementsPerLane
+ Value broadcastFromTid = rewriter.create<AffineApplyOp>(
+ loc, sym0.ceilDiv(elementsPerLane), extractOp.getPosition());
+ // Extract at position: pos % elementsPerLane
+ Value pos = rewriter.create<AffineApplyOp>(loc, sym0 % elementsPerLane,
+ extractOp.getPosition());
+ Value extracted =
+ rewriter.create<vector::ExtractElementOp>(loc, distributedVec, pos);
+
+ // Shuffle the extracted value to all lanes.
+ Value shuffled = warpShuffleFromIdxFn(
+ loc, rewriter, extracted, broadcastFromTid, newWarpOp.getWarpSize());
+ newWarpOp->getResult(operandNumber).replaceAllUsesWith(shuffled);
return success();
}
+
+private:
+ WarpShuffleFromIdxFn warpShuffleFromIdxFn;
};
/// Sink scf.for region out of WarpExecuteOnLane0Op. This can be done only if
@@ -1194,11 +1238,12 @@ void mlir::vector::populateDistributeTransferWriteOpPatterns(
void mlir::vector::populatePropagateWarpVectorDistributionPatterns(
RewritePatternSet &patterns, const DistributionMapFn &distributionMapFn,
- PatternBenefit benefit) {
+ const WarpShuffleFromIdxFn &warpShuffleFromIdxFn, PatternBenefit benefit) {
patterns.add<WarpOpElementwise, WarpOpTransferRead, WarpOpDeadResult,
- WarpOpBroadcast, WarpOpExtract, WarpOpExtractElement,
- WarpOpForwardOperand, WarpOpConstant>(patterns.getContext(),
- benefit);
+ WarpOpBroadcast, WarpOpExtract, WarpOpForwardOperand,
+ WarpOpConstant>(patterns.getContext(), benefit);
+ patterns.add<WarpOpExtractElement>(patterns.getContext(),
+ warpShuffleFromIdxFn, benefit);
patterns.add<WarpOpScfForOp>(patterns.getContext(), distributionMapFn,
benefit);
}
diff --git a/mlir/test/Dialect/Vector/vector-warp-distribute.mlir b/mlir/test/Dialect/Vector/vector-warp-distribute.mlir
index daebccd92008d..b69874542803d 100644
--- a/mlir/test/Dialect/Vector/vector-warp-distribute.mlir
+++ b/mlir/test/Dialect/Vector/vector-warp-distribute.mlir
@@ -666,14 +666,14 @@ func.func @vector_extract_simple(%laneid: index) -> (f32) {
// -----
-// CHECK-PROP-LABEL: func.func @vector_extractelement_simple(
+// CHECK-PROP-LABEL: func.func @vector_extractelement_0d(
// CHECK-PROP: %[[R:.*]] = vector.warp_execute_on_lane_0(%{{.*}})[32] -> (vector<f32>) {
// CHECK-PROP: %[[V:.*]] = "some_def"() : () -> vector<f32>
// CHECK-PROP: vector.yield %[[V]] : vector<f32>
// CHECK-PROP: }
// CHECK-PROP: %[[E:.*]] = vector.extractelement %[[R]][] : vector<f32>
// CHECK-PROP: return %[[E]] : f32
-func.func @vector_extractelement_simple(%laneid: index) -> (f32) {
+func.func @vector_extractelement_0d(%laneid: index) -> (f32) {
%r = vector.warp_execute_on_lane_0(%laneid)[32] -> (f32) {
%0 = "some_def"() : () -> (vector<f32>)
%1 = vector.extractelement %0[] : vector<f32>
@@ -684,6 +684,32 @@ func.func @vector_extractelement_simple(%laneid: index) -> (f32) {
// -----
+// CHECK-PROP: #[[$map:.*]] = affine_map<()[s0] -> (s0 ceildiv 3)>
+// CHECK-PROP: #[[$map1:.*]] = affine_map<()[s0] -> (s0 mod 3)>
+// CHECK-PROP-LABEL: func.func @vector_extractelement_1d(
+// CHECK-PROP-SAME: %[[LANEID:.*]]: index, %[[POS:.*]]: index
+// CHECK-PROP-DAG: %[[C32:.*]] = arith.constant 32 : i32
+// CHECK-PROP: %[[W:.*]] = vector.warp_execute_on_lane_0(%{{.*}})[32] -> (vector<3xf32>) {
+// CHECK-PROP: %[[V:.*]] = "some_def"
+// CHECK-PROP: vector.yield %[[V]] : vector<96xf32>
+// CHECK-PROP: }
+// CHECK-PROP: %[[FROM_LANE:.*]] = affine.apply #[[$map]]()[%[[POS]]]
+// CHECK-PROP: %[[DISTR_POS:.*]] = affine.apply #[[$map1]]()[%[[POS]]]
+// CHECK-PROP: %[[EXTRACTED:.*]] = vector.extractelement %[[W]][%[[DISTR_POS]] : index] : vector<3xf32>
+// CHECK-PROP: %[[FROM_LANE_I32:.*]] = arith.index_cast %[[FROM_LANE]] : index to i32
+// CHECK-PROP: %[[SHUFFLED:.*]], %{{.*}} = gpu.shuffle idx %[[EXTRACTED]], %[[FROM_LANE_I32]], %[[C32]] : f32
+// CHECK-PROP: return %[[SHUFFLED]]
+func.func @vector_extractelement_1d(%laneid: index, %pos: index) -> (f32) {
+ %r = vector.warp_execute_on_lane_0(%laneid)[32] -> (f32) {
+ %0 = "some_def"() : () -> (vector<96xf32>)
+ %1 = vector.extractelement %0[%pos : index] : vector<96xf32>
+ vector.yield %1 : f32
+ }
+ return %r : f32
+}
+
+// -----
+
// CHECK-PROP: func @lane_dependent_warp_propagate_read
// CHECK-PROP-SAME: %[[ID:.*]]: index
func.func @lane_dependent_warp_propagate_read(
diff --git a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
index 4f44d43cdb317..6b9afe3b42e71 100644
--- a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
+++ b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
@@ -759,6 +759,21 @@ struct TestVectorDistribution
return AffineMap::get(val.getContext());
return AffineMap::get(vecRank, 0, builder.getAffineDimExpr(vecRank - 1));
};
+ auto shuffleFn = [](Location loc, OpBuilder &builder, Value val,
+ Value srcIdx, int64_t warpSz) {
+ assert((val.getType().isF32() || val.getType().isInteger(32)) &&
+ "unsupported shuffle type");
+ Type i32Type = builder.getIntegerType(32);
+ Value srcIdxI32 =
+ builder.create<arith::IndexCastOp>(loc, i32Type, srcIdx);
+ Value warpSzI32 = builder.create<arith::ConstantOp>(
+ loc, builder.getIntegerAttr(i32Type, warpSz));
+ Value result = builder
+ .create<gpu::ShuffleOp>(loc, val, srcIdxI32, warpSzI32,
+ gpu::ShuffleMode::IDX)
+ .getResult(0);
+ return result;
+ };
if (distributeTransferWriteOps) {
RewritePatternSet patterns(ctx);
populateDistributeTransferWriteOpPatterns(patterns, distributionFn);
@@ -766,8 +781,8 @@ struct TestVectorDistribution
}
if (propagateDistribution) {
RewritePatternSet patterns(ctx);
- vector::populatePropagateWarpVectorDistributionPatterns(patterns,
- distributionFn);
+ vector::populatePropagateWarpVectorDistributionPatterns(
+ patterns, distributionFn, shuffleFn);
vector::populateDistributeReduction(patterns, warpReduction);
(void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
}