[Mlir-commits] [mlir] [mlir][vector] Add support for vector extract/insert_strided_slice in vector distribution. (PR #145421)
Charitha Saumya
llvmlistbot at llvm.org
Wed Jun 25 12:28:14 PDT 2025
https://github.com/charithaintc updated https://github.com/llvm/llvm-project/pull/145421
>From 7eda85a4ba0fbde42af5fc59cf7ea22f3146b048 Mon Sep 17 00:00:00 2001
From: Charitha Saumya <charitha.saumya.gusthinna.waduge at intel.com>
Date: Fri, 20 Jun 2025 16:43:18 +0000
Subject: [PATCH 1/6] save work
---
.../Vector/Transforms/VectorDistribute.cpp | 26 +++++++++++++++----
1 file changed, 21 insertions(+), 5 deletions(-)
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
index 045c192787f10..c25cece08b950 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
@@ -1076,6 +1076,21 @@ struct WarpOpCreateMask : public WarpDistributionPattern {
}
};
+struct WarpOpExtractStridedSlice : public WarpDistributionPattern {
+ using Base::Base;
+ LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
+ PatternRewriter &rewriter) const override {
+ OpOperand *operand =
+ getWarpResult(warpOp, llvm::IsaPred<vector::ExtractStridedSliceOp>);
+ if (!operand)
+ return failure();
+ unsigned int operandNumber = operand->getOperandNumber();
+ auto extractOp =
+ operand->get().getDefiningOp<vector::ExtractStridedSliceOp>();
+ return success();
+ }
+};
+
/// Pattern to move out vector.extract of single element vector. Those don't
/// need to be distributed and can just be propagated outside of the region.
struct WarpOpExtract : public WarpDistributionPattern {
@@ -1761,11 +1776,12 @@ void mlir::vector::populatePropagateWarpVectorDistributionPatterns(
const WarpShuffleFromIdxFn &warpShuffleFromIdxFn, PatternBenefit benefit,
PatternBenefit readBenefit) {
patterns.add<WarpOpTransferRead>(patterns.getContext(), readBenefit);
- patterns.add<WarpOpElementwise, WarpOpDeadResult, WarpOpBroadcast,
- WarpOpShapeCast, WarpOpExtract, WarpOpForwardOperand,
- WarpOpConstant, WarpOpExtractElement, WarpOpInsertElement,
- WarpOpInsertScalar, WarpOpInsert, WarpOpCreateMask>(
- patterns.getContext(), benefit);
+ patterns
+ .add<WarpOpElementwise, WarpOpDeadResult, WarpOpBroadcast,
+ WarpOpShapeCast, WarpOpExtract, WarpOpForwardOperand, WarpOpConstant,
+ WarpOpExtractElement, WarpOpInsertElement, WarpOpInsertScalar,
+ WarpOpInsert, WarpOpCreateMask, WarpOpExtractStridedSlice>(
+ patterns.getContext(), benefit);
patterns.add<WarpOpExtractScalar>(patterns.getContext(), warpShuffleFromIdxFn,
benefit);
patterns.add<WarpOpScfForOp>(patterns.getContext(), distributionMapFn,
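Patch 1 only stubs out the match step, but it already has the shape shared by every pattern in VectorDistribute.cpp: find a warp-op result whose yielded value is defined by the op being sunk, then rewrite around it. A minimal standalone C++ sketch of that lookup step, using plain structs as a hypothetical stand-in for the MLIR API (the real getWarpResult inspects the OpOperands of the gpu.yield terminator):
```
#include <cassert>
#include <optional>
#include <string>
#include <vector>

// Hypothetical stand-in for a value yielded by the warp region.
struct YieldedValue {
  std::string definingOp; // name of the op that produces the value
  unsigned operandNumber; // position in the gpu.yield operand list
};

// Analogue of getWarpResult(warpOp, llvm::IsaPred<OpTy>): return the first
// yielded value defined by the requested op kind, or nullopt (match failure).
static std::optional<YieldedValue>
getWarpResult(const std::vector<YieldedValue> &yielded,
              const std::string &opName) {
  for (const YieldedValue &value : yielded)
    if (value.definingOp == opName)
      return value;
  return std::nullopt;
}

int main() {
  std::vector<YieldedValue> yielded = {
      {"arith.addf", 0}, {"vector.extract_strided_slice", 1}};
  auto match = getWarpResult(yielded, "vector.extract_strided_slice");
  assert(match && match->operandNumber == 1);
}
```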
>From c999c2021f127cdde3bed32dfe91183cfab1dc02 Mon Sep 17 00:00:00 2001
From: Charitha Saumya <charitha.saumya.gusthinna.waduge at intel.com>
Date: Mon, 23 Jun 2025 21:59:27 +0000
Subject: [PATCH 2/6] add comments and test
---
.../Vector/Transforms/VectorDistribute.cpp | 214 ++++++++++++++++--
.../Vector/vector-warp-distribute.mlir | 80 +++++++
2 files changed, 279 insertions(+), 15 deletions(-)
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
index c25cece08b950..297bb40cbb334 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
@@ -15,9 +15,12 @@
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Dialect/Vector/Transforms/VectorDistribution.h"
#include "mlir/IR/AffineExpr.h"
+#include "mlir/IR/Attributes.h"
+#include "mlir/IR/BuiltinTypes.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"
#include "mlir/Transforms/RegionUtils.h"
#include "llvm/ADT/SetVector.h"
+#include "llvm/ADT/SmallVectorExtras.h"
#include "llvm/Support/FormatVariadic.h"
#include <utility>
@@ -52,6 +55,21 @@ static AffineMap calculateImplicitMap(VectorType sequentialType,
return map;
}
+static int getDistributedDim(VectorType origType, VectorType distributedType) {
+ assert(origType.getRank() == distributedType.getRank() &&
+ "sequential and distributed vector types must have the same rank");
+ int64_t distributedDim = -1;
+ for (int64_t i = 0; i < origType.getRank(); ++i) {
+ if (distributedType.getDimSize(i) != origType.getDimSize(i)) {
+ // Keep this assert here in case WarpExecuteOnLane0Op gets extended to
+ // support distributing multiple dimensions in the future.
+ assert(distributedDim == -1 && "found multiple distributed dims");
+ distributedDim = i;
+ }
+ }
+ return distributedDim;
+}
+
namespace {
/// Helper struct to create the load / store operations that permit transit
@@ -1076,6 +1094,123 @@ struct WarpOpCreateMask : public WarpDistributionPattern {
}
};
+/// Sink out insert_strided_slice op feeding into a warp op yield.
+/// ```
+/// %0 = gpu.warp_execute_on_lane_0(%arg0) -> (vector<8x1xf32>) {
+/// ...
+/// %src = ... : vector<4x16xf32>
+/// %dest = ... : vector<8x16xf32>
+/// %insert = vector.insert_strided_slice %src, %dest, offsets = [0, 0],
+/// strides = [1, 1] : vector<4x16xf32> into vector<8x16xf32>
+/// gpu.yield %insert : vector<8x16xf32>
+/// }
+/// ```
+/// To
+/// ```
+/// %0 = gpu.warp_execute_on_lane_0(%arg0) -> (vector<4x1xf32>,
+/// vector<8x1xf32>) {
+/// ...
+/// %src = ... : vector<4x16xf32>
+/// %dest = ... : vector<8x16xf32>
+/// gpu.yield %src, %dest : vector<4x16xf32>, vector<8x16xf32>
+/// }
+/// %insert = vector.insert_strided_slice %0#0, %0#1,
+/// offsets = [0, 0], strides = [1, 1] : vector<4x1xf32> into vector<8x1xf32>
+/// ```
+/// NOTE: Current support assume that both src and dest vectors are distributed
+/// to lanes and sinking the insert op does not require any cross lane
+/// communication.
+struct WarpOpInsertStridedSlice : public WarpDistributionPattern {
+ using Base::Base;
+ LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
+ PatternRewriter &rewriter) const override {
+ OpOperand *operand =
+ getWarpResult(warpOp, llvm::IsaPred<vector::InsertStridedSliceOp>);
+ if (!operand)
+ return failure();
+ unsigned int operandNumber = operand->getOperandNumber();
+ auto insertOp =
+ operand->get().getDefiningOp<vector::InsertStridedSliceOp>();
+ auto distributedType =
+ cast<VectorType>(warpOp.getResult(operandNumber).getType());
+ // Distributed type must be 2D or higher.
+ // TODO: Support 1D distributed types.
+ if (distributedType.getRank() < 2)
+ return rewriter.notifyMatchFailure(
+ insertOp, "result vector type must be 2D or higher");
+ // Find the distributed dimension of the dest vector. There should be
+ // exactly one.
+ auto yieldedType = cast<VectorType>(operand->get().getType());
+ int64_t destDistributedDim =
+ getDistributedDim(yieldedType, distributedType);
+ assert(destDistributedDim != -1 && "could not find distributed dimension");
+ (void)destDistributedDim;
+ VectorType srcType = insertOp.getSourceVectorType();
+ VectorType destType = insertOp.getDestVectorType();
+ // Currently we require that both source (kD) and dest (nD) vectors are
+ // distributed. This requires that distributedDim (d) is contained in the
+ // last k dims of the dest vector (d >= n - k).
+ // TODO: Add support for case where source vector is not distributed.
+ int64_t sourceDistributedDim =
+ destDistributedDim - (destType.getRank() - srcType.getRank());
+ if (sourceDistributedDim < 0)
+ return rewriter.notifyMatchFailure(
+ insertOp, "distributed dimension must be in the last k dims");
+ // Distributed dimension must be fully inserted.
+ if (srcType.getDimSize(sourceDistributedDim) !=
+ destType.getDimSize(destDistributedDim))
+ return rewriter.notifyMatchFailure(
+ insertOp, "distributed dimension must be fully inserted");
+ SmallVector<int64_t> newSourceDistShape(
+ insertOp.getSourceVectorType().getShape()),
+ newDestDistShape(insertOp.getDestVectorType().getShape());
+ newSourceDistShape[sourceDistributedDim] =
+ distributedType.getDimSize(destDistributedDim);
+ newDestDistShape[destDistributedDim] =
+ distributedType.getDimSize(destDistributedDim);
+ auto newSourceTy =
+ VectorType::get(newSourceDistShape, distributedType.getElementType());
+ auto newDestTy =
+ VectorType::get(newDestDistShape, distributedType.getElementType());
+ SmallVector<size_t> newRetIndices;
+ WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
+ rewriter, warpOp, {insertOp.getValueToStore(), insertOp.getDest()},
+ {newSourceTy, newDestTy}, newRetIndices);
+ rewriter.setInsertionPointAfter(newWarpOp);
+ auto distributedSource = newWarpOp->getResult(newRetIndices[0]);
+ auto distributedDest = newWarpOp->getResult(newRetIndices[1]);
+ // Create a new insert strided slice op that inserts distributed source into
+ // distributed dest.
+ Value newInsert = rewriter.create<vector::InsertStridedSliceOp>(
+ insertOp.getLoc(), distributedDest.getType(), distributedSource,
+ distributedDest, insertOp.getOffsets(), insertOp.getStrides());
+ rewriter.replaceAllUsesWith(newWarpOp->getResult(operandNumber), newInsert);
+ return success();
+ }
+};
+
+/// Sink out extract_strided_slice op feeding into a warp op yield.
+/// ```
+/// %0 = gpu.warp_execute_on_lane_0(%arg0) -> (vector<16x1xf32>) {
+/// ...
+/// %src = ... : vector<32x16xf32>
+/// %extract = vector.extract_strided_slice %src, offsets = [0], sizes = [16],
+/// strides = [1] : vector<32x16xf32> to vector<16x16xf32>
+/// gpu.yield %extract : vector<16x16xf32>
+/// }
+/// ```
+/// To
+/// ````
+/// %0 = gpu.warp_execute_on_lane_0(%arg0) -> (vector<32x1xf32>) {
+/// ...
+/// %src = ... : vector<32x16xf32>
+/// gpu.yield %src : vector<32x16xf32>
+/// }
+/// %extract = vector.extract_strided_slice %0, offsets = [0], sizes = [16],
+/// strides = [1] : vector<32x1xf32> to vector<16x1xf32>
+/// ```
+/// NOTE: Current support assumes that the extraction happens only on
+/// non-distributed dimensions (does not require cross lane communication).
struct WarpOpExtractStridedSlice : public WarpDistributionPattern {
using Base::Base;
LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
@@ -1087,6 +1222,63 @@ struct WarpOpExtractStridedSlice : public WarpDistributionPattern {
unsigned int operandNumber = operand->getOperandNumber();
auto extractOp =
operand->get().getDefiningOp<vector::ExtractStridedSliceOp>();
+ auto distributedType =
+ cast<VectorType>(warpOp.getResult(operandNumber).getType());
+ // Distributed type must be 2D or higher.
+ // TODO: Support 1D distributed types.
+ if (distributedType.getRank() < 2)
+ return rewriter.notifyMatchFailure(
+ extractOp, "result vector type must be 2D or higher");
+
+ // Find the distributed dimension. There should be exactly one.
+ auto yieldedType = cast<VectorType>(operand->get().getType());
+ int64_t distributedDim = getDistributedDim(yieldedType, distributedType);
+ assert(distributedDim != -1 && "could not find distributed dimension");
+ (void)distributedDim;
+
+ // Distributed dimension must be fully extracted.
+ // TODO: Partial extraction from distributed dimension requires cross lane
+ // communication.
+ if (distributedDim < static_cast<int64_t>(extractOp.getSizes().size())) {
+ int64_t distributedDimOffset =
+ llvm::cast<IntegerAttr>(extractOp.getOffsets()[distributedDim])
+ .getInt();
+ int64_t distributedDimSize =
+ llvm::cast<IntegerAttr>(extractOp.getSizes()[distributedDim])
+ .getInt();
+ if (distributedDimOffset != 0 ||
+ distributedDimSize != yieldedType.getDimSize(distributedDim))
+ return rewriter.notifyMatchFailure(
+ extractOp, "distributed dimension must be fully extracted");
+ }
+ SmallVector<int64_t> newDistributedShape(
+ extractOp.getSourceVectorType().getShape());
+ newDistributedShape[distributedDim] =
+ distributedType.getDimSize(distributedDim);
+ auto newDistributedType =
+ VectorType::get(newDistributedShape, distributedType.getElementType());
+ SmallVector<size_t> newRetIndices;
+ WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
+ rewriter, warpOp, {extractOp.getVector()}, {newDistributedType},
+ newRetIndices);
+ rewriter.setInsertionPointAfter(newWarpOp);
+ SmallVector<Attribute> distributedSizes = llvm::map_to_vector(
+ extractOp.getSizes(), [](Attribute attr) { return attr; });
+ // Update the distributed sizes to match the distributed type.
+ if (distributedDim < static_cast<int64_t>(distributedSizes.size()))
+ distributedSizes[distributedDim] = rewriter.getI64IntegerAttr(
+ distributedType.getDimSize(distributedDim));
+
+ // Create a new extract strided slice op that extracts from the
+ // distributed vector.
+ Value distributedVec = newWarpOp->getResult(newRetIndices[0]);
+ Value newExtract = rewriter.create<vector::ExtractStridedSliceOp>(
+ extractOp.getLoc(), distributedType, distributedVec,
+ extractOp.getOffsets(),
+ ArrayAttr::get(rewriter.getContext(), distributedSizes),
+ extractOp.getStrides());
+ rewriter.replaceAllUsesWith(newWarpOp->getResult(operandNumber),
+ newExtract);
return success();
}
};
@@ -1137,15 +1329,7 @@ struct WarpOpExtract : public WarpDistributionPattern {
auto distributedType =
cast<VectorType>(warpOp.getResult(operandNumber).getType());
auto yieldedType = cast<VectorType>(operand->get().getType());
- int64_t distributedDim = -1;
- for (int64_t i = 0; i < yieldedType.getRank(); ++i) {
- if (distributedType.getDimSize(i) != yieldedType.getDimSize(i)) {
- // Keep this assert here in case WarpExecuteOnLane0Op gets extended to
- // support distributing multiple dimensions in the future.
- assert(distributedDim == -1 && "found multiple distributed dims");
- distributedDim = i;
- }
- }
+ int64_t distributedDim = getDistributedDim(yieldedType, distributedType);
assert(distributedDim != -1 && "could not find distributed dimension");
(void)distributedDim;
@@ -1776,12 +1960,12 @@ void mlir::vector::populatePropagateWarpVectorDistributionPatterns(
const WarpShuffleFromIdxFn &warpShuffleFromIdxFn, PatternBenefit benefit,
PatternBenefit readBenefit) {
patterns.add<WarpOpTransferRead>(patterns.getContext(), readBenefit);
- patterns
- .add<WarpOpElementwise, WarpOpDeadResult, WarpOpBroadcast,
- WarpOpShapeCast, WarpOpExtract, WarpOpForwardOperand, WarpOpConstant,
- WarpOpExtractElement, WarpOpInsertElement, WarpOpInsertScalar,
- WarpOpInsert, WarpOpCreateMask, WarpOpExtractStridedSlice>(
- patterns.getContext(), benefit);
+ patterns.add<WarpOpElementwise, WarpOpDeadResult, WarpOpBroadcast,
+ WarpOpShapeCast, WarpOpExtract, WarpOpForwardOperand,
+ WarpOpConstant, WarpOpExtractElement, WarpOpInsertElement,
+ WarpOpInsertScalar, WarpOpInsert, WarpOpCreateMask,
+ WarpOpExtractStridedSlice, WarpOpInsertStridedSlice>(
+ patterns.getContext(), benefit);
patterns.add<WarpOpExtractScalar>(patterns.getContext(), warpShuffleFromIdxFn,
benefit);
patterns.add<WarpOpScfForOp>(patterns.getContext(), distributionMapFn,
diff --git a/mlir/test/Dialect/Vector/vector-warp-distribute.mlir b/mlir/test/Dialect/Vector/vector-warp-distribute.mlir
index 38771f2593449..8c3060c91f0d1 100644
--- a/mlir/test/Dialect/Vector/vector-warp-distribute.mlir
+++ b/mlir/test/Dialect/Vector/vector-warp-distribute.mlir
@@ -1296,6 +1296,86 @@ func.func @vector_insert_2d_broadcast(%laneid: index) -> (vector<4x96xf32>) {
return %r : vector<4x96xf32>
}
+// -----
+// CHECK-PROP-LABEL: func.func @vector_extract_strided_slice_2d_distr_outer(
+// CHECK-PROP-SAME: %[[LANEID:.*]]: index
+// CHECK-PROP: %[[W:.*]] = gpu.warp_execute_on_lane_0{{.*}} -> (vector<64x1xf32>) {
+// CHECK-PROP: %[[VEC:.*]] = "some_def"() : () -> vector<64x32xf32>
+// CHECK-PROP: gpu.yield %[[VEC]] : vector<64x32xf32>
+// CHECK-PROP: %[[EXTRACT:.*]] = vector.extract_strided_slice %[[W]]
+// CHECK-PROP-SAME: {offsets = [8], sizes = [24], strides = [1]} : vector<64x1xf32> to vector<24x1xf32>
+// CHECK-PROP: return %[[EXTRACT]] : vector<24x1xf32>
+func.func @vector_extract_strided_slice_2d_distr_outer(%laneid: index) -> (vector<24x1xf32>) {
+ %r = gpu.warp_execute_on_lane_0(%laneid)[32] -> (vector<24x1xf32>) {
+ %0 = "some_def"() : () -> (vector<64x32xf32>)
+ %1 = vector.extract_strided_slice %0 { offsets = [8], sizes = [24], strides = [1]}
+ : vector<64x32xf32> to vector<24x32xf32>
+ gpu.yield %1 : vector<24x32xf32>
+ }
+ return %r : vector<24x1xf32>
+}
+
+// -----
+// CHECK-PROP-LABEL: func.func @vector_extract_strided_slice_2d_distr_inner(
+// CHECK-PROP-SAME: %[[LANEID:.*]]: index
+// CHECK-PROP: %[[W:.*]] = gpu.warp_execute_on_lane_0{{.*}} -> (vector<1x64xf32>) {
+// CHECK-PROP: %[[VEC:.*]] = "some_def"() : () -> vector<32x64xf32>
+// CHECK-PROP: gpu.yield %[[VEC]] : vector<32x64xf32>
+// CHECK-PROP: %[[EXTRACT:.*]] = vector.extract_strided_slice %[[W]]
+// CHECK-PROP-SAME: {offsets = [0, 12], sizes = [1, 8], strides = [1, 1]} : vector<1x64xf32> to vector<1x8xf32>
+// CHECK-PROP: return %[[EXTRACT]] : vector<1x8xf32>
+func.func @vector_extract_strided_slice_2d_distr_inner(%laneid: index) -> (vector<1x8xf32>) {
+ %r = gpu.warp_execute_on_lane_0(%laneid)[32] -> (vector<1x8xf32>) {
+ %0 = "some_def"() : () -> (vector<32x64xf32>)
+ %1 = vector.extract_strided_slice %0 { offsets = [0, 12], sizes = [32, 8], strides = [1, 1]}
+ : vector<32x64xf32> to vector<32x8xf32>
+ gpu.yield %1 : vector<32x8xf32>
+ }
+ return %r : vector<1x8xf32>
+}
+
+// -----
+// CHECK-PROP-LABEL: func.func @vector_insert_strided_slice_1d_to_2d(
+// CHECK-PROP-SAME: %[[LANEID:.*]]: index)
+// CHECK-PROP: %[[W:.*]]:2 = gpu.warp_execute_on_lane_0({{.*}} -> (vector<1xf32>, vector<64x1xf32>) {
+// CHECK-PROP: %[[SRC:.*]] = "some_def"() : () -> vector<32xf32>
+// CHECK-PROP: %[[DEST:.*]] = "some_def"() : () -> vector<64x32xf32>
+// CHECK-PROP: gpu.yield %[[SRC]], %[[DEST]] : vector<32xf32>, vector<64x32xf32>
+// CHECK-PROP: %[[INSERT:.*]] = vector.insert_strided_slice %[[W]]#0, %[[W]]#1
+// CHECK-PROP-SAME: {offsets = [18, 0], strides = [1]} : vector<1xf32> into vector<64x1xf32>
+// CHECK-PROP: return %[[INSERT]] : vector<64x1xf32>
+func.func @vector_insert_strided_slice_1d_to_2d(%laneid: index) -> (vector<64x1xf32>) {
+ %r = gpu.warp_execute_on_lane_0(%laneid)[32] -> (vector<64x1xf32>) {
+ %0 = "some_def"() : () -> (vector<32xf32>)
+ %1 = "some_def"() : () -> (vector<64x32xf32>)
+ %2 = vector.insert_strided_slice %0, %1 { offsets = [18, 0], strides = [1]}
+ : vector<32xf32> into vector<64x32xf32>
+ gpu.yield %2 : vector<64x32xf32>
+ }
+ return %r : vector<64x1xf32>
+}
+
+// -----
+// CHECK-PROP-LABEL: func.func @vector_insert_strided_slice_2d_to_2d(
+// CHECK-PROP-SAME: %[[LANEID:.*]]: index)
+// CHECK-PROP: %[[W:.*]]:2 = gpu.warp_execute_on_lane_0{{.*}} -> (vector<16x1xf32>, vector<64x1xf32>) {
+// CHECK-PROP: %[[SRC:.*]] = "some_def"() : () -> vector<16x32xf32>
+// CHECK-PROP: %[[DEST:.*]] = "some_def"() : () -> vector<64x32xf32>
+// CHECK-PROP: gpu.yield %[[SRC]], %[[DEST]] : vector<16x32xf32>, vector<64x32xf32>
+// CHECK-PROP: %[[INSERT:.*]] = vector.insert_strided_slice %[[W]]#0, %[[W]]#1 {offsets = [36, 0], strides = [1, 1]} :
+// CHECK-PROP-SAME: vector<16x1xf32> into vector<64x1xf32>
+// CHECK-PROP: return %[[INSERT]] : vector<64x1xf32>
+func.func @vector_insert_strided_slice_2d_to_2d(%laneid: index) -> (vector<64x1xf32>) {
+ %r = gpu.warp_execute_on_lane_0(%laneid)[32] -> (vector<64x1xf32>) {
+ %0 = "some_def"() : () -> (vector<16x32xf32>)
+ %1 = "some_def"() : () -> (vector<64x32xf32>)
+ %2 = vector.insert_strided_slice %0, %1 { offsets = [36, 0], strides = [1, 1]}
+ : vector<16x32xf32> into vector<64x32xf32>
+ gpu.yield %2 : vector<64x32xf32>
+ }
+ return %r : vector<64x1xf32>
+}
+
// -----
// Make sure that all operands of the transfer_read op are properly propagated.
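The heart of both new patterns is shape arithmetic: locate the single distributed dimension, map it from the dest vector into the (possibly lower-rank) source vector, verify it is fully inserted, and shrink it to its per-lane size. A minimal sketch of that arithmetic, with plain std::vector<int64_t> shapes standing in for VectorType; the numbers mirror the vector_insert_strided_slice_2d_to_2d test above (32 lanes):
```
#include <cassert>
#include <cstdint>
#include <vector>

using Shape = std::vector<int64_t>;

// Analogue of getDistributedDim: the unique dim whose size differs between
// the sequential and distributed shapes.
static int64_t getDistributedDim(const Shape &sequential,
                                 const Shape &distributed) {
  assert(sequential.size() == distributed.size());
  int64_t dim = -1;
  for (size_t i = 0; i < sequential.size(); ++i)
    if (sequential[i] != distributed[i]) {
      assert(dim == -1 && "found multiple distributed dims");
      dim = static_cast<int64_t>(i);
    }
  return dim;
}

int main() {
  Shape src = {16, 32}, dest = {64, 32}, distributedDest = {64, 1};
  int64_t destDim = getDistributedDim(dest, distributedDest); // -> 1
  // Map the dest's distributed dim into the source's trailing dims
  // (the d >= n - k condition from the pattern; negative means unsupported).
  int64_t srcDim = destDim - static_cast<int64_t>(dest.size() - src.size());
  assert(srcDim >= 0 && "distributed dim not in the last k dims");
  // The distributed dim must be fully inserted.
  assert(src[srcDim] == dest[destDim]);
  // Shrink both shapes to the per-lane size.
  Shape newSrc = src, newDest = dest;
  newSrc[srcDim] = distributedDest[destDim];
  newDest[destDim] = distributedDest[destDim];
  assert(newSrc == Shape({16, 1}) && newDest == Shape({64, 1}));
}
```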
>From 0993ce901fd93b6b32cd98021f4317b1d806e7ed Mon Sep 17 00:00:00 2001
From: Charitha Saumya <charitha.saumya.gusthinna.waduge at intel.com>
Date: Wed, 25 Jun 2025 00:14:44 +0000
Subject: [PATCH 3/6] address comments
---
.../Dialect/Vector/Transforms/VectorDistribute.cpp | 13 ++++++++-----
1 file changed, 8 insertions(+), 5 deletions(-)
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
index 297bb40cbb334..4e91a14d74c61 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
@@ -55,12 +55,16 @@ static AffineMap calculateImplicitMap(VectorType sequentialType,
return map;
}
-static int getDistributedDim(VectorType origType, VectorType distributedType) {
- assert(origType.getRank() == distributedType.getRank() &&
+/// Given a sequential and distributed vector type, returns the distributed
+/// dimension. This function expects that only a single dimension is
+/// distributed.
+static int getDistributedDim(VectorType sequentialType,
+ VectorType distributedType) {
+ assert(sequentialType.getRank() == distributedType.getRank() &&
"sequential and distributed vector types must have the same rank");
int64_t distributedDim = -1;
- for (int64_t i = 0; i < origType.getRank(); ++i) {
- if (distributedType.getDimSize(i) != origType.getDimSize(i)) {
+ for (int64_t i = 0; i < sequentialType.getRank(); ++i) {
+ if (distributedType.getDimSize(i) != sequentialType.getDimSize(i)) {
// Keep this assert here in case WarpExecuteOnLane0Op gets extended to
// support distributing multiple dimensions in the future.
assert(distributedDim == -1 && "found multiple distributed dims");
@@ -1234,7 +1238,6 @@ struct WarpOpExtractStridedSlice : public WarpDistributionPattern {
auto yieldedType = cast<VectorType>(operand->get().getType());
int64_t distributedDim = getDistributedDim(yieldedType, distributedType);
assert(distributedDim != -1 && "could not find distributed dimension");
- (void)distributedDim;
// Distributed dimension must be fully extracted.
// TODO: Partial extraction from distributed dimension requires cross lane
>From 3ae71df69c19bc72242792cce0d24383c642f8b7 Mon Sep 17 00:00:00 2001
From: Charitha Saumya <charitha.saumya.gusthinna.waduge at intel.com>
Date: Wed, 25 Jun 2025 17:53:37 +0000
Subject: [PATCH 4/6] address comments
---
mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp | 6 ++++--
1 file changed, 4 insertions(+), 2 deletions(-)
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
index 4e91a14d74c61..10b6fd4ca564c 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
@@ -1148,7 +1148,7 @@ struct WarpOpInsertStridedSlice : public WarpDistributionPattern {
int64_t destDistributedDim =
getDistributedDim(yieldedType, distributedType);
assert(destDistributedDim != -1 && "could not find distributed dimension");
- (void)destDistributedDim;
+
VectorType srcType = insertOp.getSourceVectorType();
VectorType destType = insertOp.getDestVectorType();
// Currently we require that both source (kD) and dest (nD) vectors are
@@ -1242,7 +1242,9 @@ struct WarpOpExtractStridedSlice : public WarpDistributionPattern {
// Distributed dimension must be fully extracted.
// TODO: Partial extraction from distributed dimension requires cross lane
// communication.
- if (distributedDim < static_cast<int64_t>(extractOp.getSizes().size())) {
+ int64_t extractedDimsRank =
+ static_cast<int64_t>(extractOp.getSizes().size());
+ if (distributedDim < extractedDimsRank) {
int64_t distributedDimOffset =
llvm::cast<IntegerAttr>(extractOp.getOffsets()[distributedDim])
.getInt();
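The renaming above makes the subtlety of this guard easier to state: extract_strided_slice may specify offsets/sizes for only a leading prefix of the source dims, so the full-extraction check applies only when the distributed dim falls inside that prefix; any dim past the prefix is implicitly taken whole. A standalone sketch of the predicate under that reading (plain integer vectors in place of the IntegerAttr arrays, and checking sizes against the source shape):
```
#include <cassert>
#include <cstdint>
#include <vector>

// True if the distributed dim is fully extracted: offsets/sizes cover only
// the leading dims of the source vector, and any dim past that prefix is
// implicitly extracted in full.
static bool
isDistributedDimFullyExtracted(const std::vector<int64_t> &offsets,
                               const std::vector<int64_t> &sizes,
                               const std::vector<int64_t> &srcShape,
                               int64_t distributedDim) {
  int64_t numExtractedDims = static_cast<int64_t>(sizes.size());
  if (distributedDim >= numExtractedDims)
    return true; // Dim comes after all extracted dims: taken whole.
  return offsets[distributedDim] == 0 &&
         sizes[distributedDim] == srcShape[distributedDim];
}

int main() {
  // vector_extract_strided_slice_2d_distr_outer: offsets = [8], sizes = [24]
  // on a vector<64x32xf32> source with dim 1 distributed -> accepted.
  assert(isDistributedDimFullyExtracted({8}, {24}, {64, 32}, 1));
  // The same extract with dim 0 distributed would need cross-lane data.
  assert(!isDistributedDimFullyExtracted({8}, {24}, {64, 32}, 0));
}
```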
>From b087820a98b02aa04cc810c57a1823ca3d6fb33a Mon Sep 17 00:00:00 2001
From: Charitha Saumya <charitha.saumya.gusthinna.waduge at intel.com>
Date: Wed, 25 Jun 2025 19:16:32 +0000
Subject: [PATCH 5/6] address comments
---
.../Vector/Transforms/VectorDistribute.cpp | 33 ++++++++++---------
1 file changed, 17 insertions(+), 16 deletions(-)
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
index 10b6fd4ca564c..495d1a8865812 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
@@ -1102,11 +1102,11 @@ struct WarpOpCreateMask : public WarpDistributionPattern {
/// ```
/// %0 = gpu.warp_execute_on_lane_0(%arg0) -> (vector<8x1xf32>) {
/// ...
-/// %src = ... : vector<4x16xf32>
-/// %dest = ... : vector<8x16xf32>
+/// %src = ... : vector<4x32xf32>
+/// %dest = ... : vector<8x32xf32>
/// %insert = vector.insert_strided_slice %src, %dest, offsets = [0, 0],
-/// strides = [1, 1] : vector<4x16xf32> into vector<8x16xf32>
-/// gpu.yield %insert : vector<8x16xf32>
+/// strides = [1, 1] : vector<4x32xf32> into vector<8x32xf32>
+/// gpu.yield %insert : vector<8x32xf32>
/// }
/// ```
/// To
@@ -1114,14 +1114,14 @@ struct WarpOpCreateMask : public WarpDistributionPattern {
/// %0 = gpu.warp_execute_on_lane_0(%arg0) -> (vector<4x1xf32>,
/// vector<8x1xf32>) {
/// ...
-/// %src = ... : vector<4x16xf32>
-/// %dest = ... : vector<8x16xf32>
+/// %src = ... : vector<4x32xf32>
+/// %dest = ... : vector<8x32xf32>
-/// gpu.yield %src, %dest : vector<4x16xf32>, vector<8x16xf32>
+/// gpu.yield %src, %dest : vector<4x32xf32>, vector<8x32xf32>
/// }
/// %insert = vector.insert_strided_slice %0#0, %0#1,
/// offsets = [0, 0], strides = [1, 1] : vector<4x1xf32> into vector<8x1xf32>
/// ```
-/// NOTE: Current support assume that both src and dest vectors are distributed
+/// NOTE: Current support assumes that both src and dest vectors are distributed
/// to lanes and sinking the insert op does not require any cross lane
/// communication.
struct WarpOpInsertStridedSlice : public WarpDistributionPattern {
@@ -1159,7 +1159,8 @@ struct WarpOpInsertStridedSlice : public WarpDistributionPattern {
destDistributedDim - (destType.getRank() - srcType.getRank());
if (sourceDistributedDim < 0)
return rewriter.notifyMatchFailure(
- insertOp, "distributed dimension must be in the last k dims");
+ insertOp,
+ "distributed dimension must be in the last k dims of dest vector");
// Distributed dimension must be fully inserted.
if (srcType.getDimSize(sourceDistributedDim) !=
destType.getDimSize(destDistributedDim))
@@ -1197,21 +1198,21 @@ struct WarpOpInsertStridedSlice : public WarpDistributionPattern {
/// ```
/// %0 = gpu.warp_execute_on_lane_0(%arg0) -> (vector<16x1xf32>) {
/// ...
-/// %src = ... : vector<32x16xf32>
+/// %src = ... : vector<64x32xf32>
/// %extract = vector.extract_strided_slice %src, offsets = [0], sizes = [16],
-/// strides = [1] : vector<32x16xf32> to vector<16x16xf32>
-/// gpu.yield %extract : vector<16x16xf32>
+/// strides = [1] : vector<64x32xf32> to vector<16x32xf32>
+/// gpu.yield %extract : vector<16x32xf32>
/// }
/// ```
/// To
-/// ````
-/// %0 = gpu.warp_execute_on_lane_0(%arg0) -> (vector<32x1xf32>) {
+/// ```
+/// %0 = gpu.warp_execute_on_lane_0(%arg0) -> (vector<64x1xf32>) {
/// ...
-/// %src = ... : vector<32x16xf32>
-/// gpu.yield %src : vector<32x16xf32>
+/// %src = ... : vector<64x32xf32>
+/// gpu.yield %src : vector<64x32xf32>
/// }
/// %extract = vector.extract_strided_slice %0, offsets = [0], sizes = [16],
-/// strides = [1] : vector<32x1xf32> to vector<16x1xf32>
+/// strides = [1] : vector<64x1xf32> to vector<16x1xf32>
/// ```
/// NOTE: Current support assumes that the extraction happens only on
/// non-distributed dimensions (does not require cross lane communication).
>From 99cbbe5c827b91c64c68b0abc39a60d7a8ddb58f Mon Sep 17 00:00:00 2001
From: Charitha Saumya <charitha.saumya.gusthinna.waduge at intel.com>
Date: Wed, 25 Jun 2025 19:27:57 +0000
Subject: [PATCH 6/6] address comments
---
.../Dialect/Vector/Transforms/VectorDistribute.cpp | 11 +++++++----
1 file changed, 7 insertions(+), 4 deletions(-)
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
index 495d1a8865812..063a32f87e66a 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
@@ -1240,12 +1240,15 @@ struct WarpOpExtractStridedSlice : public WarpDistributionPattern {
int64_t distributedDim = getDistributedDim(yieldedType, distributedType);
assert(distributedDim != -1 && "could not find distributed dimension");
- // Distributed dimension must be fully extracted.
+ int64_t numOfExtractedDims =
+ static_cast<int64_t>(extractOp.getSizes().size());
+ // If the distributed dim is included in the extracted dims, then we make
+ // sure the distributed dim is fully extracted. If the distributed dim is
+ // not included in the extracted dims, it is guaranteed to be fully
+ // extracted (i.e. the distributed dim comes after all the extracted dims).
// TODO: Partial extraction from distributed dimension requires cross lane
// communication.
- int64_t extractedDimsRank =
- static_cast<int64_t>(extractOp.getSizes().size());
- if (distributedDim < extractedDimsRank) {
+ if (distributedDim < numOfExtractedDims) {
int64_t distributedDimOffset =
llvm::cast<IntegerAttr>(extractOp.getOffsets()[distributedDim])
.getInt();
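After the region is moved into the new warp op, the last step of the extract rewrite rebuilds the sizes list against the distributed type: every size stays as written except the distributed dim (when the sizes prefix covers it), which shrinks to its per-lane size. A plain-C++ sketch of that adjustment, mirroring the distributedSizes update in the pattern:
```
#include <cassert>
#include <cstdint>
#include <vector>

// Copy the original extract sizes and, if the distributed dim is covered by
// the sizes prefix, clamp it to the per-lane size of the distributed type.
static std::vector<int64_t>
distributeExtractSizes(std::vector<int64_t> sizes, int64_t distributedDim,
                       int64_t perLaneSize) {
  if (distributedDim < static_cast<int64_t>(sizes.size()))
    sizes[distributedDim] = perLaneSize;
  return sizes;
}

int main() {
  // vector_extract_strided_slice_2d_distr_inner: sizes = [32, 8]; the
  // distributed dim 0 shrinks from 32 to 1 (32 lanes), dim 1 is untouched.
  assert((distributeExtractSizes({32, 8}, 0, 1) ==
          std::vector<int64_t>{1, 8}));
  // distr_outer: sizes = [24] do not cover distributed dim 1 -> unchanged.
  assert((distributeExtractSizes({24}, 1, 1) == std::vector<int64_t>{24}));
}
```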