[Mlir-commits] [mlir] 9085f00 - [mlir][vector] Support vector.extract distribution of >1D vectors
Matthias Springer
llvmlistbot at llvm.org
Mon Jan 9 07:39:58 PST 2023
Author: Matthias Springer
Date: 2023-01-09T16:39:50+01:00
New Revision: 9085f00b4d4526013e9ae98e31cef3d1e64b92f9
URL: https://github.com/llvm/llvm-project/commit/9085f00b4d4526013e9ae98e31cef3d1e64b92f9
DIFF: https://github.com/llvm/llvm-project/commit/9085f00b4d4526013e9ae98e31cef3d1e64b92f9.diff
LOG: [mlir][vector] Support vector.extract distribution of >1D vectors
Ops such as `%1 = vector.extract %0[2] : vector<5x96xf32>`.
Distribute the source vector, then extract from the distributed vector. In the case of a 1-D extract, rewrite to vector.extractelement.
Differential Revision: https://reviews.llvm.org/D137646
Added:
Modified:
mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
mlir/test/Dialect/Vector/vector-warp-distribute.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
index 08841e38eecb7..60ca036228dcf 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
@@ -897,16 +897,81 @@ struct WarpOpExtract : public OpRewritePattern<WarpExecuteOnLane0Op> {
return failure();
unsigned int operandNumber = operand->getOperandNumber();
auto extractOp = operand->get().getDefiningOp<vector::ExtractOp>();
- if (extractOp.getVectorType().getNumElements() != 1)
- return failure();
+ VectorType extractSrcType = extractOp.getVectorType();
Location loc = extractOp.getLoc();
+
+ // "vector.extract %v[] : vector<f32>" is an invalid op.
+ assert(extractSrcType.getRank() > 0 &&
+ "vector.extract does not support rank 0 sources");
+
+ // "vector.extract %v[] : vector<...xf32>" can be canonicalized to %v.
+ if (extractOp.getPosition().empty())
+ return failure();
+
+ // Rewrite vector.extract with 1d source to vector.extractelement.
+ if (extractSrcType.getRank() == 1) {
+ assert(extractOp.getPosition().size() == 1 && "expected 1 index");
+ int64_t pos = extractOp.getPosition()[0].cast<IntegerAttr>().getInt();
+ rewriter.setInsertionPoint(extractOp);
+ rewriter.replaceOpWithNewOp<vector::ExtractElementOp>(
+ extractOp, extractOp.getVector(),
+ rewriter.create<arith::ConstantIndexOp>(loc, pos));
+ return success();
+ }
+
+ // All following cases are 2d or higher dimensional source vectors.
+
+ if (warpOp.getResult(operandNumber).getType() == operand->get().getType()) {
+ // There is no distribution, this is a broadcast. Simply move the extract
+ // out of the warp op.
+ // TODO: This could be optimized. E.g., in case of a scalar result, let
+ // one lane extract and shuffle the result to all other lanes (same as
+ // the 1d case).
+ SmallVector<size_t> newRetIndices;
+ WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
+ rewriter, warpOp, {extractOp.getVector()},
+ {extractOp.getVectorType()}, newRetIndices);
+ rewriter.setInsertionPointAfter(newWarpOp);
+ Value distributedVec = newWarpOp->getResult(newRetIndices[0]);
+ // Extract from distributed vector.
+ Value newExtract = rewriter.create<vector::ExtractOp>(
+ loc, distributedVec, extractOp.getPosition());
+ newWarpOp->getResult(operandNumber).replaceAllUsesWith(newExtract);
+ return success();
+ }
+
+ // Find the distributed dimension. There should be exactly one.
+ auto distributedType =
+ warpOp.getResult(operandNumber).getType().cast<VectorType>();
+ auto yieldedType = operand->get().getType().cast<VectorType>();
+ int64_t distributedDim = -1;
+ for (int64_t i = 0; i < yieldedType.getRank(); ++i) {
+ if (distributedType.getDimSize(i) != yieldedType.getDimSize(i)) {
+ // Keep this assert here in case WarpExecuteOnLane0Op gets extended to
+ // support distributing multiple dimensions in the future.
+ assert(distributedDim == -1 && "found multiple distributed dims");
+ distributedDim = i;
+ }
+ }
+ assert(distributedDim != -1 && "could not find distributed dimension");
+
+ // Yield source vector from warp op.
+ SmallVector<int64_t> newDistributedShape(extractSrcType.getShape().begin(),
+ extractSrcType.getShape().end());
+ for (int i = 0; i < distributedType.getRank(); ++i)
+ newDistributedShape[i + extractOp.getPosition().size()] =
+ distributedType.getDimSize(i);
+ auto newDistributedType =
+ VectorType::get(newDistributedShape, distributedType.getElementType());
SmallVector<size_t> newRetIndices;
WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
- rewriter, warpOp, {extractOp.getVector()}, {extractOp.getVectorType()},
+ rewriter, warpOp, {extractOp.getVector()}, {newDistributedType},
newRetIndices);
rewriter.setInsertionPointAfter(newWarpOp);
+ Value distributedVec = newWarpOp->getResult(newRetIndices[0]);
+ // Extract from distributed vector.
Value newExtract = rewriter.create<vector::ExtractOp>(
- loc, newWarpOp->getResult(newRetIndices[0]), extractOp.getPosition());
+ loc, distributedVec, extractOp.getPosition());
newWarpOp->getResult(operandNumber).replaceAllUsesWith(newExtract);
return success();
}
diff --git a/mlir/test/Dialect/Vector/vector-warp-distribute.mlir b/mlir/test/Dialect/Vector/vector-warp-distribute.mlir
index 3489054407262..5a238c57c933a 100644
--- a/mlir/test/Dialect/Vector/vector-warp-distribute.mlir
+++ b/mlir/test/Dialect/Vector/vector-warp-distribute.mlir
@@ -648,17 +648,58 @@ func.func @warp_constant(%laneid: index) -> (vector<1xf32>) {
// -----
-// CHECK-PROP-LABEL: func.func @vector_extract_simple(
-// CHECK-PROP: %[[R:.*]] = vector.warp_execute_on_lane_0(%{{.*}})[32] -> (vector<1xf32>) {
-// CHECK-PROP: %[[V:.*]] = "some_def"() : () -> vector<1xf32>
-// CHECK-PROP: vector.yield %[[V]] : vector<1xf32>
+// TODO: We could use warp shuffles instead of broadcasting the entire vector.
+
+// CHECK-PROP-LABEL: func.func @vector_extract_1d(
+// CHECK-PROP-DAG: %[[C5_I32:.*]] = arith.constant 5 : i32
+// CHECK-PROP-DAG: %[[C1:.*]] = arith.constant 1 : index
+// CHECK-PROP: %[[R:.*]] = vector.warp_execute_on_lane_0(%{{.*}})[32] -> (vector<2xf32>) {
+// CHECK-PROP: %[[V:.*]] = "some_def"() : () -> vector<64xf32>
+// CHECK-PROP: vector.yield %[[V]] : vector<64xf32>
// CHECK-PROP: }
-// CHECK-PROP: %[[E:.*]] = vector.extract %[[R]][0] : vector<1xf32>
-// CHECK-PROP: return %[[E]] : f32
-func.func @vector_extract_simple(%laneid: index) -> (f32) {
+// CHECK-PROP: %[[E:.*]] = vector.extractelement %[[R]][%[[C1]] : index] : vector<2xf32>
+// CHECK-PROP: %[[SHUFFLED:.*]], %{{.*}} = gpu.shuffle idx %[[E]], %[[C5_I32]]
+// CHECK-PROP: return %[[SHUFFLED]] : f32
+func.func @vector_extract_1d(%laneid: index) -> (f32) {
%r = vector.warp_execute_on_lane_0(%laneid)[32] -> (f32) {
- %0 = "some_def"() : () -> (vector<1xf32>)
- %1 = vector.extract %0[0] : vector<1xf32>
+ %0 = "some_def"() : () -> (vector<64xf32>)
+ %1 = vector.extract %0[9] : vector<64xf32>
+ vector.yield %1 : f32
+ }
+ return %r : f32
+}
+
+// -----
+
+// CHECK-PROP-LABEL: func.func @vector_extract_2d(
+// CHECK-PROP: %[[W:.*]] = vector.warp_execute_on_lane_0(%{{.*}})[32] -> (vector<5x3xf32>) {
+// CHECK-PROP: %[[V:.*]] = "some_def"
+// CHECK-PROP: vector.yield %[[V]] : vector<5x96xf32>
+// CHECK-PROP: }
+// CHECK-PROP: %[[E:.*]] = vector.extract %[[W]][2] : vector<5x3xf32>
+// CHECK-PROP: return %[[E]]
+func.func @vector_extract_2d(%laneid: index) -> (vector<3xf32>) {
+ %r = vector.warp_execute_on_lane_0(%laneid)[32] -> (vector<3xf32>) {
+ %0 = "some_def"() : () -> (vector<5x96xf32>)
+ %1 = vector.extract %0[2] : vector<5x96xf32>
+ vector.yield %1 : vector<96xf32>
+ }
+ return %r : vector<3xf32>
+}
+
+// -----
+
+// CHECK-PROP-LABEL: func.func @vector_extract_2d_broadcast_scalar(
+// CHECK-PROP: %[[W:.*]] = vector.warp_execute_on_lane_0(%{{.*}})[32] -> (vector<5x96xf32>) {
+// CHECK-PROP: %[[V:.*]] = "some_def"
+// CHECK-PROP: vector.yield %[[V]] : vector<5x96xf32>
+// CHECK-PROP: }
+// CHECK-PROP: %[[E:.*]] = vector.extract %[[W]][1, 2] : vector<5x96xf32>
+// CHECK-PROP: return %[[E]]
+func.func @vector_extract_2d_broadcast_scalar(%laneid: index) -> (f32) {
+ %r = vector.warp_execute_on_lane_0(%laneid)[32] -> (f32) {
+ %0 = "some_def"() : () -> (vector<5x96xf32>)
+ %1 = vector.extract %0[1, 2] : vector<5x96xf32>
vector.yield %1 : f32
}
return %r : f32
@@ -666,6 +707,42 @@ func.func @vector_extract_simple(%laneid: index) -> (f32) {
// -----
+// CHECK-PROP-LABEL: func.func @vector_extract_2d_broadcast(
+// CHECK-PROP: %[[W:.*]] = vector.warp_execute_on_lane_0(%{{.*}})[32] -> (vector<5x96xf32>) {
+// CHECK-PROP: %[[V:.*]] = "some_def"
+// CHECK-PROP: vector.yield %[[V]] : vector<5x96xf32>
+// CHECK-PROP: }
+// CHECK-PROP: %[[E:.*]] = vector.extract %[[W]][2] : vector<5x96xf32>
+// CHECK-PROP: return %[[E]]
+func.func @vector_extract_2d_broadcast(%laneid: index) -> (vector<96xf32>) {
+ %r = vector.warp_execute_on_lane_0(%laneid)[32] -> (vector<96xf32>) {
+ %0 = "some_def"() : () -> (vector<5x96xf32>)
+ %1 = vector.extract %0[2] : vector<5x96xf32>
+ vector.yield %1 : vector<96xf32>
+ }
+ return %r : vector<96xf32>
+}
+
+// -----
+
+// CHECK-PROP-LABEL: func.func @vector_extract_3d(
+// CHECK-PROP: %[[W:.*]] = vector.warp_execute_on_lane_0(%{{.*}})[32] -> (vector<8x4x96xf32>) {
+// CHECK-PROP: %[[V:.*]] = "some_def"
+// CHECK-PROP: vector.yield %[[V]] : vector<8x128x96xf32>
+// CHECK-PROP: }
+// CHECK-PROP: %[[E:.*]] = vector.extract %[[W]][2] : vector<8x4x96xf32>
+// CHECK-PROP: return %[[E]]
+func.func @vector_extract_3d(%laneid: index) -> (vector<4x96xf32>) {
+ %r = vector.warp_execute_on_lane_0(%laneid)[32] -> (vector<4x96xf32>) {
+ %0 = "some_def"() : () -> (vector<8x128x96xf32>)
+ %1 = vector.extract %0[2] : vector<8x128x96xf32>
+ vector.yield %1 : vector<128x96xf32>
+ }
+ return %r : vector<4x96xf32>
+}
+
+// -----
+
// CHECK-PROP-LABEL: func.func @vector_extractelement_0d(
// CHECK-PROP: %[[R:.*]] = vector.warp_execute_on_lane_0(%{{.*}})[32] -> (vector<f32>) {
// CHECK-PROP: %[[V:.*]] = "some_def"() : () -> vector<f32>
More information about the Mlir-commits
mailing list