[Mlir-commits] [mlir] [mlir][vector] Add support for vector extract/insert_strided_slice in vector distribution. (PR #145421)

Charitha Saumya llvmlistbot at llvm.org
Wed Jun 25 12:34:27 PDT 2025


https://github.com/charithaintc updated https://github.com/llvm/llvm-project/pull/145421

From 7eda85a4ba0fbde42af5fc59cf7ea22f3146b048 Mon Sep 17 00:00:00 2001
From: Charitha Saumya <charitha.saumya.gusthinna.waduge at intel.com>
Date: Fri, 20 Jun 2025 16:43:18 +0000
Subject: [PATCH 1/6] save work

---
 .../Vector/Transforms/VectorDistribute.cpp    | 26 +++++++++++++++----
 1 file changed, 21 insertions(+), 5 deletions(-)

diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
index 045c192787f10..c25cece08b950 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
@@ -1076,6 +1076,21 @@ struct WarpOpCreateMask : public WarpDistributionPattern {
   }
 };
 
+struct WarpOpExtractStridedSlice : public WarpDistributionPattern {
+  using Base::Base;
+  LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
+                                PatternRewriter &rewriter) const override {
+    OpOperand *operand =
+        getWarpResult(warpOp, llvm::IsaPred<vector::ExtractStridedSliceOp>);
+    if (!operand)
+      return failure();
+    unsigned int operandNumber = operand->getOperandNumber();
+    auto extractOp =
+        operand->get().getDefiningOp<vector::ExtractStridedSliceOp>();
+    return success();
+  }
+};
+
 /// Pattern to move out vector.extract of single element vector. Those don't
 /// need to be distributed and can just be propagated outside of the region.
 struct WarpOpExtract : public WarpDistributionPattern {
@@ -1761,11 +1776,12 @@ void mlir::vector::populatePropagateWarpVectorDistributionPatterns(
     const WarpShuffleFromIdxFn &warpShuffleFromIdxFn, PatternBenefit benefit,
     PatternBenefit readBenefit) {
   patterns.add<WarpOpTransferRead>(patterns.getContext(), readBenefit);
-  patterns.add<WarpOpElementwise, WarpOpDeadResult, WarpOpBroadcast,
-               WarpOpShapeCast, WarpOpExtract, WarpOpForwardOperand,
-               WarpOpConstant, WarpOpExtractElement, WarpOpInsertElement,
-               WarpOpInsertScalar, WarpOpInsert, WarpOpCreateMask>(
-      patterns.getContext(), benefit);
+  patterns
+      .add<WarpOpElementwise, WarpOpDeadResult, WarpOpBroadcast,
+           WarpOpShapeCast, WarpOpExtract, WarpOpForwardOperand, WarpOpConstant,
+           WarpOpExtractElement, WarpOpInsertElement, WarpOpInsertScalar,
+           WarpOpInsert, WarpOpCreateMask, WarpOpExtractStridedSlice>(
+          patterns.getContext(), benefit);
   patterns.add<WarpOpExtractScalar>(patterns.getContext(), warpShuffleFromIdxFn,
                                     benefit);
   patterns.add<WarpOpScfForOp>(patterns.getContext(), distributionMapFn,
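
The skeleton above only matches a yielded vector.extract_strided_slice and
bails out; the actual rewrite lands in the next patch. For orientation, a
sketch of the intended transformation, using the shapes the doc comment
settles on by the end of the series (warp of 32 lanes; "some_def" stands in
for any producer inside the region):

```mlir
// Before: the slice is computed inside the warp region, so the warp op
// returns the distributed slice type (dim 1 is distributed, 32 -> 1).
%r = gpu.warp_execute_on_lane_0(%laneid)[32] -> (vector<16x1xf32>) {
  %src = "some_def"() : () -> (vector<64x32xf32>)
  %e = vector.extract_strided_slice %src { offsets = [0], sizes = [16], strides = [1]}
    : vector<64x32xf32> to vector<16x32xf32>
  gpu.yield %e : vector<16x32xf32>
}

// After: the warp op yields the whole source and each lane extracts from its
// own fragment, so no cross lane communication is needed.
%w = gpu.warp_execute_on_lane_0(%laneid)[32] -> (vector<64x1xf32>) {
  %src = "some_def"() : () -> (vector<64x32xf32>)
  gpu.yield %src : vector<64x32xf32>
}
%e2 = vector.extract_strided_slice %w { offsets = [0], sizes = [16], strides = [1]}
  : vector<64x1xf32> to vector<16x1xf32>
```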

From c999c2021f127cdde3bed32dfe91183cfab1dc02 Mon Sep 17 00:00:00 2001
From: Charitha Saumya <charitha.saumya.gusthinna.waduge at intel.com>
Date: Mon, 23 Jun 2025 21:59:27 +0000
Subject: [PATCH 2/6] add comments and test

---
 .../Vector/Transforms/VectorDistribute.cpp    | 214 ++++++++++++++++--
 .../Vector/vector-warp-distribute.mlir        |  80 +++++++
 2 files changed, 279 insertions(+), 15 deletions(-)

diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
index c25cece08b950..297bb40cbb334 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
@@ -15,9 +15,12 @@
 #include "mlir/Dialect/Vector/IR/VectorOps.h"
 #include "mlir/Dialect/Vector/Transforms/VectorDistribution.h"
 #include "mlir/IR/AffineExpr.h"
+#include "mlir/IR/Attributes.h"
+#include "mlir/IR/BuiltinTypes.h"
 #include "mlir/Interfaces/SideEffectInterfaces.h"
 #include "mlir/Transforms/RegionUtils.h"
 #include "llvm/ADT/SetVector.h"
+#include "llvm/ADT/SmallVectorExtras.h"
 #include "llvm/Support/FormatVariadic.h"
 #include <utility>
 
@@ -52,6 +55,21 @@ static AffineMap calculateImplicitMap(VectorType sequentialType,
   return map;
 }
 
+static int getDistributedDim(VectorType origType, VectorType distributedType) {
+  assert(origType.getRank() == distributedType.getRank() &&
+         "sequential and distributed vector types must have the same rank");
+  int64_t distributedDim = -1;
+  for (int64_t i = 0; i < origType.getRank(); ++i) {
+    if (distributedType.getDimSize(i) != origType.getDimSize(i)) {
+      // Keep this assert here in case WarpExecuteOnLane0Op gets extended to
+      // support distributing multiple dimensions in the future.
+      assert(distributedDim == -1 && "found multiple distributed dims");
+      distributedDim = i;
+    }
+  }
+  return distributedDim;
+}
+
 namespace {
 
 /// Helper struct to create the load / store operations that permit transit
@@ -1076,6 +1094,123 @@ struct WarpOpCreateMask : public WarpDistributionPattern {
   }
 };
 
+/// Sink out insert_strided_slice op feeding into a warp op yield.
+/// ```
+/// %0 = gpu.warp_execute_on_lane_0(%arg0) -> (vector<8x1xf32>) {
+///   ...
+///   %src = ... : vector<4x16xf32>
+///   %dest = ... : vector<8x16xf32>
+///   %insert = vector.insert_strided_slice %src, %dest, offsets = [0, 0],
+///     strides = [1, 1] : vector<4x16xf32> into vector<8x16xf32>
+///   gpu.yield %insert : vector<8x16xf32>
+/// }
+/// ```
+/// To
+/// ```
+/// %0 = gpu.warp_execute_on_lane_0(%arg0) -> (vector<4x1xf32>,
+/// vector<8x1xf32>) {
+///   ...
+///   %src = ... : vector<4x16xf32>
+///   %dest = ... : vector<8x16xf32>
+///   gpu.yield %src, %dest : vector<4x16xf32>, vector<8x16xf32>
+/// }
+/// %insert = vector.insert_strided_slice %0#0, %0#1,
+///   offsets = [0, 0], strides = [1, 1] : vector<4x1xf32> into vector<8x1xf32>
+/// ```
+/// NOTE: Current support assume that both src and dest vectors are distributed
+/// to lanes and sinking the insert op does not require any cross lane
+/// communication.
+struct WarpOpInsertStridedSlice : public WarpDistributionPattern {
+  using Base::Base;
+  LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
+                                PatternRewriter &rewriter) const override {
+    OpOperand *operand =
+        getWarpResult(warpOp, llvm::IsaPred<vector::InsertStridedSliceOp>);
+    if (!operand)
+      return failure();
+    unsigned int operandNumber = operand->getOperandNumber();
+    auto insertOp =
+        operand->get().getDefiningOp<vector::InsertStridedSliceOp>();
+    auto distributedType =
+        cast<VectorType>(warpOp.getResult(operandNumber).getType());
+    // Distributed type must be 2D or higher.
+    // TODO: Support 1D distributed types.
+    if (distributedType.getRank() < 2)
+      return rewriter.notifyMatchFailure(
+          insertOp, "result vector type must be 2D or higher");
+    // Find the distributed dimension of the dest vector. There should be
+    // exactly one.
+    auto yieldedType = cast<VectorType>(operand->get().getType());
+    int64_t destDistributedDim =
+        getDistributedDim(yieldedType, distributedType);
+    assert(destDistributedDim != -1 && "could not find distributed dimension");
+    (void)destDistributedDim;
+    VectorType srcType = insertOp.getSourceVectorType();
+    VectorType destType = insertOp.getDestVectorType();
+    // Currently we require that both source (kD) and dest (nD) vectors are
+    // distributed. This requires that distributedDim (d) is contained in the
+    // last k dims of the dest vector (d >= n - k).
+    // TODO: Add support for case where source vector is not distributed.
+    int64_t sourceDistributedDim =
+        destDistributedDim - (destType.getRank() - srcType.getRank());
+    if (sourceDistributedDim < 0)
+      return rewriter.notifyMatchFailure(
+          insertOp, "distributed dimension must be in the last k dims");
+    // Distributed dimension must be fully inserted.
+    if (srcType.getDimSize(sourceDistributedDim) !=
+        destType.getDimSize(destDistributedDim))
+      return rewriter.notifyMatchFailure(
+          insertOp, "distributed dimension must be fully inserted");
+    SmallVector<int64_t> newSourceDistShape(
+        insertOp.getSourceVectorType().getShape()),
+        newDestDistShape(insertOp.getDestVectorType().getShape());
+    newSourceDistShape[sourceDistributedDim] =
+        distributedType.getDimSize(destDistributedDim);
+    newDestDistShape[destDistributedDim] =
+        distributedType.getDimSize(destDistributedDim);
+    auto newSourceTy =
+        VectorType::get(newSourceDistShape, distributedType.getElementType());
+    auto newDestTy =
+        VectorType::get(newDestDistShape, distributedType.getElementType());
+    SmallVector<size_t> newRetIndices;
+    WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
+        rewriter, warpOp, {insertOp.getValueToStore(), insertOp.getDest()},
+        {newSourceTy, newDestTy}, newRetIndices);
+    rewriter.setInsertionPointAfter(newWarpOp);
+    auto distributedSource = newWarpOp->getResult(newRetIndices[0]);
+    auto distributedDest = newWarpOp->getResult(newRetIndices[1]);
+    // Create a new insert strided slice op that inserts distributed source into
+    // distributed dest.
+    Value newInsert = rewriter.create<vector::InsertStridedSliceOp>(
+        insertOp.getLoc(), distributedDest.getType(), distributedSource,
+        distributedDest, insertOp.getOffsets(), insertOp.getStrides());
+    rewriter.replaceAllUsesWith(newWarpOp->getResult(operandNumber), newInsert);
+    return success();
+  }
+};
+
+/// Sink out extract_strided_slice op feeding into a warp op yield.
+/// ```
+/// %0 = gpu.warp_execute_on_lane_0(%arg0) -> (vector<16x1xf32>) {
+///   ...
+///   %src = ... : vector<32x16xf32>
+///   %extract = vector.extract_strided_slice %src, offsets = [0], sizes = [16],
+///     strides = [1] : vector<32x16xf32> to vector<16x16xf32>
+///   gpu.yield %extract : vector<16x16xf32>
+/// }
+/// ```
+/// To
+/// ````
+/// %0 = gpu.warp_execute_on_lane_0(%arg0) -> (vector<32x1xf32>) {
+///   ...
+///   %src = ... : vector<32x16xf32>
+///   gpu.yield %src : vector<32x16xf32>
+/// }
+/// %extract = vector.extract_strided_slice %0, offsets = [0], sizes = [16],
+///   strides = [1] : vector<32x1xf32> to vector<16x1xf32>
+/// ```
+/// NOTE: Current support assumes that the extraction happens only on
+/// non-distributed dimensions (does not require cross lane communication).
 struct WarpOpExtractStridedSlice : public WarpDistributionPattern {
   using Base::Base;
   LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
@@ -1087,6 +1222,63 @@ struct WarpOpExtractStridedSlice : public WarpDistributionPattern {
     unsigned int operandNumber = operand->getOperandNumber();
     auto extractOp =
         operand->get().getDefiningOp<vector::ExtractStridedSliceOp>();
+    auto distributedType =
+        cast<VectorType>(warpOp.getResult(operandNumber).getType());
+    // Distributed type must be 2D or higher.
+    // TODO: Support 1D distributed types.
+    if (distributedType.getRank() < 2)
+      return rewriter.notifyMatchFailure(
+          extractOp, "result vector type must be 2D or higher");
+
+    // Find the distributed dimension. There should be exactly one.
+    auto yieldedType = cast<VectorType>(operand->get().getType());
+    int64_t distributedDim = getDistributedDim(yieldedType, distributedType);
+    assert(distributedDim != -1 && "could not find distributed dimension");
+    (void)distributedDim;
+
+    // Distributed dimension must be fully extracted.
+    // TODO: Partial extraction from distributed dimension requires cross lane
+    // communication.
+    if (distributedDim < static_cast<int64_t>(extractOp.getSizes().size())) {
+      int64_t distributedDimOffset =
+          llvm::cast<IntegerAttr>(extractOp.getOffsets()[distributedDim])
+              .getInt();
+      int64_t distributedDimSize =
+          llvm::cast<IntegerAttr>(extractOp.getSizes()[distributedDim])
+              .getInt();
+      if (distributedDimOffset != 0 ||
+          distributedDimSize != yieldedType.getDimSize(distributedDim))
+        return rewriter.notifyMatchFailure(
+            extractOp, "distributed dimension must be fully extracted");
+    }
+    SmallVector<int64_t> newDistributedShape(
+        extractOp.getSourceVectorType().getShape());
+    newDistributedShape[distributedDim] =
+        distributedType.getDimSize(distributedDim);
+    auto newDistributedType =
+        VectorType::get(newDistributedShape, distributedType.getElementType());
+    SmallVector<size_t> newRetIndices;
+    WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
+        rewriter, warpOp, {extractOp.getVector()}, {newDistributedType},
+        newRetIndices);
+    rewriter.setInsertionPointAfter(newWarpOp);
+    SmallVector<Attribute> distributedSizes = llvm::map_to_vector(
+        extractOp.getSizes(), [](Attribute attr) { return attr; });
+    // Update the distributed sizes to match the distributed type.
+    if (distributedDim < static_cast<int64_t>(distributedSizes.size()))
+      distributedSizes[distributedDim] = rewriter.getI64IntegerAttr(
+          distributedType.getDimSize(distributedDim));
+
+    // Create a new extract strided slice op that extracts from the
+    // distributed vector.
+    Value distributedVec = newWarpOp->getResult(newRetIndices[0]);
+    Value newExtract = rewriter.create<vector::ExtractStridedSliceOp>(
+        extractOp.getLoc(), distributedType, distributedVec,
+        extractOp.getOffsets(),
+        ArrayAttr::get(rewriter.getContext(), distributedSizes),
+        extractOp.getStrides());
+    rewriter.replaceAllUsesWith(newWarpOp->getResult(operandNumber),
+                                newExtract);
     return success();
   }
 };
@@ -1137,15 +1329,7 @@ struct WarpOpExtract : public WarpDistributionPattern {
     auto distributedType =
         cast<VectorType>(warpOp.getResult(operandNumber).getType());
     auto yieldedType = cast<VectorType>(operand->get().getType());
-    int64_t distributedDim = -1;
-    for (int64_t i = 0; i < yieldedType.getRank(); ++i) {
-      if (distributedType.getDimSize(i) != yieldedType.getDimSize(i)) {
-        // Keep this assert here in case WarpExecuteOnLane0Op gets extended to
-        // support distributing multiple dimensions in the future.
-        assert(distributedDim == -1 && "found multiple distributed dims");
-        distributedDim = i;
-      }
-    }
+    int64_t distributedDim = getDistributedDim(yieldedType, distributedType);
     assert(distributedDim != -1 && "could not find distributed dimension");
     (void)distributedDim;
 
@@ -1776,12 +1960,12 @@ void mlir::vector::populatePropagateWarpVectorDistributionPatterns(
     const WarpShuffleFromIdxFn &warpShuffleFromIdxFn, PatternBenefit benefit,
     PatternBenefit readBenefit) {
   patterns.add<WarpOpTransferRead>(patterns.getContext(), readBenefit);
-  patterns
-      .add<WarpOpElementwise, WarpOpDeadResult, WarpOpBroadcast,
-           WarpOpShapeCast, WarpOpExtract, WarpOpForwardOperand, WarpOpConstant,
-           WarpOpExtractElement, WarpOpInsertElement, WarpOpInsertScalar,
-           WarpOpInsert, WarpOpCreateMask, WarpOpExtractStridedSlice>(
-          patterns.getContext(), benefit);
+  patterns.add<WarpOpElementwise, WarpOpDeadResult, WarpOpBroadcast,
+               WarpOpShapeCast, WarpOpExtract, WarpOpForwardOperand,
+               WarpOpConstant, WarpOpExtractElement, WarpOpInsertElement,
+               WarpOpInsertScalar, WarpOpInsert, WarpOpCreateMask,
+               WarpOpExtractStridedSlice, WarpOpInsertStridedSlice>(
+      patterns.getContext(), benefit);
   patterns.add<WarpOpExtractScalar>(patterns.getContext(), warpShuffleFromIdxFn,
                                     benefit);
   patterns.add<WarpOpScfForOp>(patterns.getContext(), distributionMapFn,
diff --git a/mlir/test/Dialect/Vector/vector-warp-distribute.mlir b/mlir/test/Dialect/Vector/vector-warp-distribute.mlir
index 38771f2593449..8c3060c91f0d1 100644
--- a/mlir/test/Dialect/Vector/vector-warp-distribute.mlir
+++ b/mlir/test/Dialect/Vector/vector-warp-distribute.mlir
@@ -1296,6 +1296,86 @@ func.func @vector_insert_2d_broadcast(%laneid: index) -> (vector<4x96xf32>) {
   return %r : vector<4x96xf32>
 }
 
+// -----
+// CHECK-PROP-LABEL: func.func @vector_extract_strided_slice_2d_distr_outer(
+//  CHECK-PROP-SAME: %[[LANEID:.*]]: index
+//       CHECK-PROP: %[[W:.*]] = gpu.warp_execute_on_lane_0{{.*}} -> (vector<64x1xf32>) {
+//       CHECK-PROP: %[[VEC:.*]] = "some_def"() : () -> vector<64x32xf32>
+//       CHECK-PROP: gpu.yield %[[VEC]] : vector<64x32xf32>
+//       CHECK-PROP: %[[EXTRACT:.*]] = vector.extract_strided_slice %[[W]]
+//  CHECK-PROP-SAME: {offsets = [8], sizes = [24], strides = [1]} : vector<64x1xf32> to vector<24x1xf32>
+//       CHECK-PROP: return %[[EXTRACT]] : vector<24x1xf32>
+func.func @vector_extract_strided_slice_2d_distr_outer(%laneid: index) -> (vector<24x1xf32>) {
+  %r = gpu.warp_execute_on_lane_0(%laneid)[32] -> (vector<24x1xf32>) {
+    %0 = "some_def"() : () -> (vector<64x32xf32>)
+    %1 = vector.extract_strided_slice %0 { offsets = [8], sizes = [24], strides = [1]}
+      : vector<64x32xf32> to vector<24x32xf32>
+    gpu.yield %1 : vector<24x32xf32>
+  }
+  return %r : vector<24x1xf32>
+}
+
+// -----
+// CHECK-PROP-LABEL: func.func @vector_extract_strided_slice_2d_distr_inner(
+//  CHECK-PROP-SAME: %[[LANEID:.*]]: index
+//       CHECK-PROP: %[[W:.*]] = gpu.warp_execute_on_lane_0{{.*}} -> (vector<1x64xf32>) {
+//       CHECK-PROP: %[[VEC:.*]] = "some_def"() : () -> vector<32x64xf32>
+//       CHECK-PROP: gpu.yield %[[VEC]] : vector<32x64xf32>
+//       CHECK-PROP: %[[EXTRACT:.*]] = vector.extract_strided_slice %[[W]]
+//  CHECK-PROP-SAME: {offsets = [0, 12], sizes = [1, 8], strides = [1, 1]} : vector<1x64xf32> to vector<1x8xf32>
+//       CHECK-PROP: return %[[EXTRACT]] : vector<1x8xf32>
+func.func @vector_extract_strided_slice_2d_distr_inner(%laneid: index) -> (vector<1x8xf32>) {
+  %r = gpu.warp_execute_on_lane_0(%laneid)[32] -> (vector<1x8xf32>) {
+    %0 = "some_def"() : () -> (vector<32x64xf32>)
+    %1 = vector.extract_strided_slice %0 { offsets = [0, 12], sizes = [32, 8], strides = [1, 1]}
+      : vector<32x64xf32> to vector<32x8xf32>
+    gpu.yield %1 : vector<32x8xf32>
+  }
+  return %r : vector<1x8xf32>
+}
+
+// -----
+// CHECK-PROP-LABEL: func.func @vector_insert_strided_slice_1d_to_2d(
+//  CHECK-PROP-SAME: %[[LANEID:.*]]: index)
+//       CHECK-PROP: %[[W:.*]]:2 = gpu.warp_execute_on_lane_0({{.*}} -> (vector<1xf32>, vector<64x1xf32>) {
+//       CHECK-PROP: %[[SRC:.*]] = "some_def"() : () -> vector<32xf32>
+//       CHECK-PROP: %[[DEST:.*]] = "some_def"() : () -> vector<64x32xf32>
+//       CHECK-PROP: gpu.yield %[[SRC]], %[[DEST]] : vector<32xf32>, vector<64x32xf32>
+//       CHECK-PROP: %[[INSERT:.*]] = vector.insert_strided_slice %[[W]]#0, %[[W]]#1
+//  CHECK-PROP-SAME: {offsets = [18, 0], strides = [1]} : vector<1xf32> into vector<64x1xf32>
+//       CHECK-PROP: return %[[INSERT]] : vector<64x1xf32>
+func.func @vector_insert_strided_slice_1d_to_2d(%laneid: index) -> (vector<64x1xf32>) {
+  %r = gpu.warp_execute_on_lane_0(%laneid)[32] -> (vector<64x1xf32>) {
+    %0 = "some_def"() : () -> (vector<32xf32>)
+    %1 = "some_def"() : () -> (vector<64x32xf32>)
+    %2 = vector.insert_strided_slice %0, %1 { offsets = [18, 0], strides = [1]}
+      : vector<32xf32> into vector<64x32xf32>
+    gpu.yield %2 : vector<64x32xf32>
+  }
+  return %r : vector<64x1xf32>
+}
+
+// -----
+// CHECK-PROP-LABEL: func.func @vector_insert_strided_slice_2d_to_2d(
+//  CHECK-PROP-SAME: %[[LANEID:.*]]: index)
+//       CHECK-PROP: %[[W:.*]]:2 = gpu.warp_execute_on_lane_0{{.*}} -> (vector<16x1xf32>, vector<64x1xf32>) {
+//       CHECK-PROP: %[[SRC:.*]] = "some_def"() : () -> vector<16x32xf32>
+//       CHECK-PROP: %[[DEST:.*]] = "some_def"() : () -> vector<64x32xf32>
+//       CHECK-PROP: gpu.yield %[[SRC]], %[[DEST]] : vector<16x32xf32>, vector<64x32xf32>
+//       CHECK-PROP: %[[INSERT:.*]] = vector.insert_strided_slice %[[W]]#0, %[[W]]#1 {offsets = [36, 0], strides = [1, 1]} :
+//  CHECK-PROP-SAME: vector<16x1xf32> into vector<64x1xf32>
+//       CHECK-PROP: return %[[INSERT]] : vector<64x1xf32>
+func.func @vector_insert_strided_slice_2d_to_2d(%laneid: index) -> (vector<64x1xf32>) {
+  %r = gpu.warp_execute_on_lane_0(%laneid)[32] -> (vector<64x1xf32>) {
+    %0 = "some_def"() : () -> (vector<16x32xf32>)
+    %1 = "some_def"() : () -> (vector<64x32xf32>)
+    %2 = vector.insert_strided_slice %0, %1 { offsets = [36, 0],  strides = [1, 1]}
+      : vector<16x32xf32> into vector<64x32xf32>
+    gpu.yield %2 : vector<64x32xf32>
+  }
+  return %r : vector<64x1xf32>
+}
+
 // -----
 
 // Make sure that all operands of the transfer_read op are properly propagated.
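
A restriction worth highlighting in the implementation above: when the slice
touches the distributed dimension, it must cover that dimension completely
(offset 0 and the full size), because a partial slice would need elements
owned by other lanes. A hypothetical shape the pattern therefore refuses to
rewrite, assuming dim 0 is the distributed dim:

```mlir
// Rejected: dim 0 of the yielded slice is the distributed dim (32 -> 1 per
// lane), but the slice takes rows 8..39 of the 64-row source (offset != 0),
// so the pattern reports "distributed dimension must be fully extracted".
%r = gpu.warp_execute_on_lane_0(%laneid)[32] -> (vector<1x16xf32>) {
  %src = "some_def"() : () -> (vector<64x16xf32>)
  %e = vector.extract_strided_slice %src { offsets = [8, 0], sizes = [32, 16], strides = [1, 1]}
    : vector<64x16xf32> to vector<32x16xf32>
  gpu.yield %e : vector<32x16xf32>
}
```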

From 0993ce901fd93b6b32cd98021f4317b1d806e7ed Mon Sep 17 00:00:00 2001
From: Charitha Saumya <charitha.saumya.gusthinna.waduge at intel.com>
Date: Wed, 25 Jun 2025 00:14:44 +0000
Subject: [PATCH 3/6] address comments

---
 .../Dialect/Vector/Transforms/VectorDistribute.cpp  | 13 ++++++++-----
 1 file changed, 8 insertions(+), 5 deletions(-)

diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
index 297bb40cbb334..4e91a14d74c61 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
@@ -55,12 +55,16 @@ static AffineMap calculateImplicitMap(VectorType sequentialType,
   return map;
 }
 
-static int getDistributedDim(VectorType origType, VectorType distributedType) {
-  assert(origType.getRank() == distributedType.getRank() &&
+/// Given a sequential and distributed vector type, returns the distributed
+/// dimension. This function expects that only a single dimension is
+/// distributed.
+static int getDistributedDim(VectorType sequentialType,
+                             VectorType distributedType) {
+  assert(sequentialType.getRank() == distributedType.getRank() &&
          "sequential and distributed vector types must have the same rank");
   int64_t distributedDim = -1;
-  for (int64_t i = 0; i < origType.getRank(); ++i) {
-    if (distributedType.getDimSize(i) != origType.getDimSize(i)) {
+  for (int64_t i = 0; i < sequentialType.getRank(); ++i) {
+    if (distributedType.getDimSize(i) != sequentialType.getDimSize(i)) {
       // Keep this assert here in case WarpExecuteOnLane0Op gets extended to
       // support distributing multiple dimensions in the future.
       assert(distributedDim == -1 && "found multiple distributed dims");
@@ -1234,7 +1238,6 @@ struct WarpOpExtractStridedSlice : public WarpDistributionPattern {
     auto yieldedType = cast<VectorType>(operand->get().getType());
     int64_t distributedDim = getDistributedDim(yieldedType, distributedType);
     assert(distributedDim != -1 && "could not find distributed dimension");
-    (void)distributedDim;
 
     // Distributed dimension must be fully extracted.
     // TODO: Partial extraction from distributed dimension requires cross lane
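
A worked instance of the documented helper, using the shapes from the
vector_extract_strided_slice_2d_distr_inner test (warp of 32 lanes):

```mlir
// getDistributedDim compares the yielded (sequential) type against the warp
// result (distributed) type dim by dim:
//   vector<32x64xf32> vs vector<1x64xf32>
//   dim 0: 32 != 1  -> distributed dim = 0
//   dim 1: 64 == 64 -> not distributed
%r = gpu.warp_execute_on_lane_0(%laneid)[32] -> (vector<1x64xf32>) {
  %v = "some_def"() : () -> (vector<32x64xf32>)
  gpu.yield %v : vector<32x64xf32>
}
```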

From 3ae71df69c19bc72242792cce0d24383c642f8b7 Mon Sep 17 00:00:00 2001
From: Charitha Saumya <charitha.saumya.gusthinna.waduge at intel.com>
Date: Wed, 25 Jun 2025 17:53:37 +0000
Subject: [PATCH 4/6] address comments

---
 mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp | 6 ++++--
 1 file changed, 4 insertions(+), 2 deletions(-)

diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
index 4e91a14d74c61..10b6fd4ca564c 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
@@ -1148,7 +1148,7 @@ struct WarpOpInsertStridedSlice : public WarpDistributionPattern {
     int64_t destDistributedDim =
         getDistributedDim(yieldedType, distributedType);
     assert(destDistributedDim != -1 && "could not find distributed dimension");
-    (void)destDistributedDim;
+
     VectorType srcType = insertOp.getSourceVectorType();
     VectorType destType = insertOp.getDestVectorType();
     // Currently we require that both source (kD) and dest (nD) vectors are
@@ -1242,7 +1242,9 @@ struct WarpOpExtractStridedSlice : public WarpDistributionPattern {
     // Distributed dimension must be fully extracted.
     // TODO: Partial extraction from distributed dimension requires cross lane
     // communication.
-    if (distributedDim < static_cast<int64_t>(extractOp.getSizes().size())) {
+    int64_t extractedDimsRank =
+        static_cast<int64_t>(extractOp.getSizes().size());
+    if (distributedDim < extractedDimsRank) {
       int64_t distributedDimOffset =
           llvm::cast<IntegerAttr>(extractOp.getOffsets()[distributedDim])
               .getInt();
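
For clarity on the guard being renamed here: vector.extract_strided_slice may
specify offsets/sizes for only the leading dimensions, and every trailing
dimension is implicitly taken in full. When the distributed dim lies past the
extracted dims, the full-extraction check can therefore be skipped, as in the
vector_extract_strided_slice_2d_distr_outer test:

```mlir
// offsets/sizes cover only dim 0; dim 1 (the distributed dim) is implicitly
// taken whole, so distributedDim (1) is not less than extractedDimsRank (1)
// and the offset/size check does not apply.
%0 = "some_def"() : () -> (vector<64x32xf32>)
%1 = vector.extract_strided_slice %0 { offsets = [8], sizes = [24], strides = [1]}
  : vector<64x32xf32> to vector<24x32xf32>
```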

From b087820a98b02aa04cc810c57a1823ca3d6fb33a Mon Sep 17 00:00:00 2001
From: Charitha Saumya <charitha.saumya.gusthinna.waduge at intel.com>
Date: Wed, 25 Jun 2025 19:16:32 +0000
Subject: [PATCH 5/6] address comments

---
 .../Vector/Transforms/VectorDistribute.cpp    | 33 ++++++++++---------
 1 file changed, 17 insertions(+), 16 deletions(-)

diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
index 10b6fd4ca564c..495d1a8865812 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
@@ -1102,11 +1102,11 @@ struct WarpOpCreateMask : public WarpDistributionPattern {
 /// ```
 /// %0 = gpu.warp_execute_on_lane_0(%arg0) -> (vector<8x1xf32>) {
 ///   ...
-///   %src = ... : vector<4x16xf32>
-///   %dest = ... : vector<8x16xf32>
+///   %src = ... : vector<4x32xf32>
+///   %dest = ... : vector<8x32xf32>
 ///   %insert = vector.insert_strided_slice %src, %dest, offsets = [0, 0],
-///     strides = [1, 1] : vector<4x16xf32> into vector<8x16xf32>
-///   gpu.yield %insert : vector<8x16xf32>
+///     strides = [1, 1] : vector<4x32xf32> into vector<8x32xf32>
+///   gpu.yield %insert : vector<8x32xf32>
 /// }
 /// ```
 /// To
@@ -1114,14 +1114,14 @@ struct WarpOpCreateMask : public WarpDistributionPattern {
 /// %0 = gpu.warp_execute_on_lane_0(%arg0) -> (vector<4x1xf32>,
 /// vector<8x1xf32>) {
 ///   ...
-///   %src = ... : vector<4x16xf32>
-///   %dest = ... : vector<8x16xf32>
+///   %src = ... : vector<4x32xf32>
+///   %dest = ... : vector<8x32xf32>
-///   gpu.yield %src, %dest : vector<4x16xf32>, vector<8x16xf32>
+///   gpu.yield %src, %dest : vector<4x32xf32>, vector<8x32xf32>
 /// }
 /// %insert = vector.insert_strided_slice %0#0, %0#1,
 ///   offsets = [0, 0], strides = [1, 1] : vector<4x1xf32> into vector<8x1xf32>
 /// ```
-/// NOTE: Current support assume that both src and dest vectors are distributed
+/// NOTE: Current support assumes that both src and dest vectors are distributed
 /// to lanes and sinking the insert op does not require any cross lane
 /// communication.
 struct WarpOpInsertStridedSlice : public WarpDistributionPattern {
@@ -1159,7 +1159,8 @@ struct WarpOpInsertStridedSlice : public WarpDistributionPattern {
         destDistributedDim - (destType.getRank() - srcType.getRank());
     if (sourceDistributedDim < 0)
       return rewriter.notifyMatchFailure(
-          insertOp, "distributed dimension must be in the last k dims");
+          insertOp,
+          "distributed dimension must be in the last k dims of dest vector");
     // Distributed dimension must be fully inserted.
     if (srcType.getDimSize(sourceDistributedDim) !=
         destType.getDimSize(destDistributedDim))
@@ -1197,21 +1198,21 @@ struct WarpOpInsertStridedSlice : public WarpDistributionPattern {
 /// ```
 /// %0 = gpu.warp_execute_on_lane_0(%arg0) -> (vector<16x1xf32>) {
 ///   ...
-///   %src = ... : vector<32x16xf32>
+///   %src = ... : vector<64x32xf32>
 ///   %extract = vector.extract_strided_slice %src, offsets = [0], sizes = [16],
-///     strides = [1] : vector<32x16xf32> to vector<16x16xf32>
-///   gpu.yield %extract : vector<16x16xf32>
+///     strides = [1] : vector<64x32xf32> to vector<16x32xf32>
+///   gpu.yield %extract : vector<16x32xf32>
 /// }
 /// ```
 /// To
-/// ````
-/// %0 = gpu.warp_execute_on_lane_0(%arg0) -> (vector<32x1xf32>) {
+/// ```
+/// %0 = gpu.warp_execute_on_lane_0(%arg0) -> (vector<64x1xf32>) {
 ///   ...
-///   %src = ... : vector<32x16xf32>
-///   gpu.yield %src : vector<32x16xf32>
+///   %src = ... : vector<64x32xf32>
+///   gpu.yield %src : vector<64x32xf32>
 /// }
 /// %extract = vector.extract_strided_slice %0, offsets = [0], sizes = [16],
-///   strides = [1] : vector<32x1xf32> to vector<16x1xf32>
+///   strides = [1] : vector<64x1xf32> to vector<16x1xf32>
 /// ```
 /// NOTE: Current support assumes that the extraction happens only on
 /// non-distributed dimensions (does not require cross lane communication).
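
The analogous arithmetic on the insert side: for a k-D source inserted into an
n-D dest whose distributed dim is d, the aligned source dim is d - (n - k),
and it must be non-negative for the source to be distributed at all. In the
vector_insert_strided_slice_1d_to_2d test, k = 1, n = 2, and d = 1, so the
source's distributed dim is 0:

```mlir
// dest dim 1 is distributed (32 -> 1 across 32 lanes); the 1-D source lines
// up with dest dim 1, so sourceDistributedDim = 1 - (2 - 1) = 0 and each
// lane keeps a vector<1xf32> fragment of the source.
%0 = "some_def"() : () -> (vector<32xf32>)
%1 = "some_def"() : () -> (vector<64x32xf32>)
%2 = vector.insert_strided_slice %0, %1 { offsets = [18, 0], strides = [1]}
  : vector<32xf32> into vector<64x32xf32>
```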

From 99cbbe5c827b91c64c68b0abc39a60d7a8ddb58f Mon Sep 17 00:00:00 2001
From: Charitha Saumya <charitha.saumya.gusthinna.waduge at intel.com>
Date: Wed, 25 Jun 2025 19:27:57 +0000
Subject: [PATCH 6/6] address comments

---
 .../Dialect/Vector/Transforms/VectorDistribute.cpp    | 11 +++++++----
 1 file changed, 7 insertions(+), 4 deletions(-)

diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
index 495d1a8865812..063a32f87e66a 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
@@ -1240,12 +1240,15 @@ struct WarpOpExtractStridedSlice : public WarpDistributionPattern {
     int64_t distributedDim = getDistributedDim(yieldedType, distributedType);
     assert(distributedDim != -1 && "could not find distributed dimension");
 
-    // Distributed dimension must be fully extracted.
+    int64_t numOfExtractedDims =
+        static_cast<int64_t>(extractOp.getSizes().size());
+    // If the distributed dim is included in the extracted dims, we make sure
+    // it is fully extracted. If the distributed dim is not among the
+    // extracted dims, it comes after all of them and is therefore guaranteed
+    // to be fully extracted.
     // TODO: Partial extraction from distributed dimension requires cross lane
     // communication.
-    int64_t extractedDimsRank =
-        static_cast<int64_t>(extractOp.getSizes().size());
-    if (distributedDim < extractedDimsRank) {
+    if (distributedDim < numOfExtractedDims) {
       int64_t distributedDimOffset =
           llvm::cast<IntegerAttr>(extractOp.getOffsets()[distributedDim])
               .getInt();
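
One limitation remains after this series: both new patterns bail on rank-1
distributed results (the rank < 2 checks with their TODOs). A hypothetical
case that is left untouched for now:

```mlir
// Not yet handled: the warp result is 1-D (vector<1xf32>), so the pattern
// reports "result vector type must be 2D or higher" and the extract stays
// inside the region.
%r = gpu.warp_execute_on_lane_0(%laneid)[32] -> (vector<1xf32>) {
  %src = "some_def"() : () -> (vector<64xf32>)
  %e = vector.extract_strided_slice %src { offsets = [0], sizes = [32], strides = [1]}
    : vector<64xf32> to vector<32xf32>
  gpu.yield %e : vector<32xf32>
}
```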



More information about the Mlir-commits mailing list