[Mlir-commits] [mlir] c333f7d - [mlir][xegpu] Add layout based SIMT distribution support for `vector.extract/insert_strided_slice` (#168626)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Wed Nov 26 08:10:47 PST 2025
Author: Charitha Saumya
Date: 2025-11-26T10:10:36-06:00
New Revision: c333f7dab9f89734777f7d19bc7b68c86f393216
URL: https://github.com/llvm/llvm-project/commit/c333f7dab9f89734777f7d19bc7b68c86f393216
DIFF: https://github.com/llvm/llvm-project/commit/c333f7dab9f89734777f7d19bc7b68c86f393216.diff
LOG: [mlir][xegpu] Add layout based SIMT distribution support for `vector.extract/insert_strided_slice` (#168626)
This PR adds general SIMT distribution support for
`vector.extract/insert_strided_slice`. Vector distribution already
supports these operations, but with restrictions that avoid the need for
layouts in the distribution logic. For example, `extract_strided_slice`
requires that the distributed dimension is fully extracted. More complex
cases, however, may extract only part of the distributed dimension
(e.g. an 8x16xf16 extraction from 8x32xf16). Such cases need the layouts
to reason about how the data is spread across SIMT lanes.
Because vector distribution currently has no access to layouts, these
new patterns are placed on the XeGPU side. They are given a higher
pattern benefit so that they are tried before the regular vector
distribution patterns.
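As an illustration (not part of the commit), consider extracting the upper
8x16xf16 half of an 8x32xf16 value whose inner dimension is distributed
across a 16-lane subgroup with lane_layout = [1, 16] and unit lane data.
Each lane then owns an 8x2xf16 fragment of the source, and the new pattern
recreates the extract outside the warp region with the size and offset
along the distributed dimension divided by the subgroup size (size 16 -> 1,
offset 16 -> 1). A minimal sketch of the distributed IR, modeled after the
unit tests added below (the shapes, %laneid and the "some_def" producer are
illustrative):

    %r:2 = gpu.warp_execute_on_lane_0(%laneid)[16]
        -> (vector<8x1xf16>, vector<8x2xf16>) {
      %src = "some_def"() : () -> (vector<8x32xf16>)
      // The original op stays inside; its source is additionally yielded.
      %slice = vector.extract_strided_slice %src
        { offsets = [0, 16], sizes = [8, 16], strides = [1, 1],
          layout_operand_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>,
          layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]> }
        : vector<8x32xf16> to vector<8x16xf16>
      gpu.yield %slice, %src : vector<8x16xf16>, vector<8x32xf16>
    }
    // New extract created outside the warp op on the per-lane fragment.
    %0 = vector.extract_strided_slice %r#1
      { offsets = [0, 1], sizes = [8, 1], strides = [1, 1] }
      : vector<8x2xf16> to vector<8x1xf16>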
Added:
Modified:
mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
mlir/test/Dialect/XeGPU/subgroup-distribute-unit.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
index b64eb5b29ccb0..0d1c5eeeff711 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
@@ -174,6 +174,21 @@ static bool requireTranspose(const xegpu::LayoutAttr layout,
return laneLayout[0] == uArch->getSubgroupSize() && laneLayout[1] == 1;
}
+/// Given a vector type and its distributed vector type, return the list of
+/// dimensions that are distributed.
+static SmallVector<int64_t> getDistributedDims(VectorType originalType,
+ VectorType distributedType) {
+ assert(originalType.getRank() == distributedType.getRank() &&
+ "sequential and distributed vector types must have the same rank");
+ SmallVector<int64_t> distributedDims;
+ for (int64_t i = 0; i < originalType.getRank(); ++i) {
+ if (distributedType.getDimSize(i) != originalType.getDimSize(i)) {
+ distributedDims.push_back(i);
+ }
+ }
+ return distributedDims;
+}
+
/// Given a GPUFuncOp, this pattern creates a new GPUFuncOp and moves the body
/// of the original GPUFuncOp to the new GPUFuncOp such that entire body is
/// contained within a WarpExecuteOnLane0Op.
@@ -1469,6 +1484,226 @@ struct VectorShapeCastDistribution : public gpu::WarpDistributionPattern {
}
};
+/// Distribute a `vector.extract_strided_slice` op feeding into the yield op
+/// of an enclosing `gpu.warp_execute_on_lane_0` region. This pattern covers
+/// advanced cases where the distributed dimension is partially extracted,
+/// which the generic vector distribution patterns currently do not support.
+struct VectorExtractStridedSliceDistribution
+ : public gpu::WarpDistributionPattern {
+ using gpu::WarpDistributionPattern::WarpDistributionPattern;
+ LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
+ PatternRewriter &rewriter) const override {
+ OpOperand *operand =
+ getWarpResult(warpOp, llvm::IsaPred<vector::ExtractStridedSliceOp>);
+ if (!operand)
+ return failure();
+ auto extractOp =
+ cast<vector::ExtractStridedSliceOp>(operand->get().getDefiningOp());
+ unsigned operandIdx = operand->getOperandNumber();
+ auto distributedType =
+ cast<VectorType>(warpOp.getResult(operandIdx).getType());
+ // Find the distributed dimensions.
+ auto extractResultType = cast<VectorType>(operand->get().getType());
+ auto distributedDims =
+ getDistributedDims(extractResultType, distributedType);
+ // Collect updated source type, sizes and offsets. They may be adjusted
+ // later if the data is distributed to lanes (as opposed to being owned by
+ // all lanes uniformly).
+ VectorType updatedSourceType = extractOp.getSourceVectorType();
+ SmallVector<Attribute> updatedSizes = llvm::map_to_vector(
+ extractOp.getSizes(), [](Attribute attr) { return attr; });
+ SmallVector<Attribute> updatedOffsets = llvm::map_to_vector(
+ extractOp.getOffsets(), [](Attribute attr) { return attr; });
+ // If the result is distributed, it must be distributed in exactly one
+ // dimension. In this case, we adjust the updatedSourceType, updatedSizes
+ // and updatedOffsets accordingly.
+ if (distributedDims.size() > 0) {
+ if (distributedDims.size() != 1)
+ return rewriter.notifyMatchFailure(
+ warpOp, "Source can not be distributed in multiple dimensions.");
+ int64_t distributedDim = distributedDims[0];
+ int sourceDistrDimSize =
+ extractOp.getSourceVectorType().getShape()[distributedDim];
+ auto sourceLayout =
+ xegpu::getDistributeLayoutAttr(extractOp->getOpOperand(0));
+ if (!sourceLayout || sourceLayout.getEffectiveLaneLayoutAsInt().empty())
+ return rewriter.notifyMatchFailure(
+ warpOp, "the source of extract_strided_slice op lacks distribution "
+ "layout");
+ auto sourceLaneLayout = sourceLayout.getEffectiveLaneLayoutAsInt();
+ // Because only single dimension distribution is supported, lane layout
+ // size at the distributed dim must be the subgroup size.
+ int subgroupSize = sourceLaneLayout[distributedDim];
+ // Check if the source size in the distributed dimension is a multiple of
+ // subgroup size.
+ if (sourceDistrDimSize % subgroupSize != 0)
+ return rewriter.notifyMatchFailure(
+ warpOp,
+ "Source size along distributed dimension is not a multiple of "
+ "subgroup size.");
+ auto sourceLaneData = sourceLayout.getEffectiveLaneDataAsInt();
+ // We expect lane data to be all ones in this case.
+ if (!llvm::all_of(sourceLaneData, [](int64_t v) { return v == 1; }))
+ return rewriter.notifyMatchFailure(
+ warpOp, "Expecting unit lane data in source layout");
+ // The offsets in the distributed dimension must be a multiple of subgroup
+ // size.
+ int64_t distrDimOffset =
+ cast<IntegerAttr>(extractOp.getOffsets()[distributedDim]).getInt();
+ if (distrDimOffset % subgroupSize != 0)
+ return rewriter.notifyMatchFailure(
+ warpOp, "Offset along distributed dimension "
+ "is not a multiple of subgroup size.");
+ updatedSourceType = getDistVecTypeBasedOnLaneLayout(
+ sourceLayout, extractOp.getSourceVectorType())
+ .value();
+ // Update the distributed sizes to match the distributed type.
+ updatedSizes[distributedDim] = rewriter.getI64IntegerAttr(
+ distributedType.getDimSize(distributedDim));
+ // Update the distributed offsets to match round robin distribution (i.e.
+ // each lane owns data at `subgroupSize` stride given unit lane data).
+ updatedOffsets[distributedDim] =
+ rewriter.getI64IntegerAttr(distrDimOffset / subgroupSize);
+ }
+ // Do the distribution by yielding the source of the extract op from
+ // the warp op and creating a new extract op outside the warp op.
+ SmallVector<size_t> newRetIndices;
+ auto newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
+ rewriter, warpOp, {extractOp.getSource()}, {updatedSourceType},
+ newRetIndices);
+ rewriter.setInsertionPointAfter(newWarpOp);
+ Value source = newWarpOp.getResult(newRetIndices[0]);
+ // Create a new extract op outside the warp op.
+ Value newExtractOp = vector::ExtractStridedSliceOp::create(
+ rewriter, extractOp.getLoc(), distributedType, source,
+ ArrayAttr::get(rewriter.getContext(), updatedOffsets),
+ ArrayAttr::get(rewriter.getContext(), updatedSizes),
+ extractOp.getStrides());
+ rewriter.replaceAllUsesWith(newWarpOp.getResult(operandIdx), newExtractOp);
+ return success();
+ }
+};
+
+/// Distribute a `vector.insert_strided_slice` op feeding into the yield op
+/// of an enclosing `gpu.warp_execute_on_lane_0` region. This pattern covers
+/// advanced cases where the distributed dimension is partially inserted,
+/// which the generic vector distribution patterns currently do not support.
+struct VectorInsertStridedSliceDistribution
+ : public gpu::WarpDistributionPattern {
+ using gpu::WarpDistributionPattern::WarpDistributionPattern;
+ LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
+ PatternRewriter &rewriter) const override {
+ OpOperand *operand =
+ getWarpResult(warpOp, llvm::IsaPred<vector::InsertStridedSliceOp>);
+ if (!operand)
+ return failure();
+ unsigned int operandNumber = operand->getOperandNumber();
+ auto insertOp =
+ operand->get().getDefiningOp<vector::InsertStridedSliceOp>();
+ auto distributedType =
+ cast<VectorType>(warpOp.getResult(operandNumber).getType());
+ // Find the distributed dimensions of the dest vector.
+ auto insertResultType = cast<VectorType>(operand->get().getType());
+ auto destDistributedDims =
+ getDistributedDims(insertResultType, distributedType);
+ // Collect updated offsets, source type and dest type. They may be adjusted
+ // later if the data is distributed to lanes (as opposed to being owned by
+ // all lanes uniformly).
+ SmallVector<Attribute> updatedOffsets = llvm::map_to_vector(
+ insertOp.getOffsets(), [](Attribute attr) { return attr; });
+ VectorType updatedSourceType = insertOp.getSourceVectorType();
+ VectorType updatedDestType = insertOp.getDestVectorType();
+ if (destDistributedDims.size() > 0) {
+ // Only single dimension distribution is supported.
+ if (destDistributedDims.size() != 1)
+ return rewriter.notifyMatchFailure(
+ warpOp,
+ "Expecting source to be distributed in a single dimension.");
+ int64_t destDistributedDim = destDistributedDims[0];
+
+ VectorType srcType = insertOp.getSourceVectorType();
+ VectorType destType = insertOp.getDestVectorType();
+ // Currently we require that both source (kD) and dest (nD) vectors are
+ // distributed. This requires that distributedDim (d) is contained in the
+ // last k dims of the dest vector (d >= n - k).
+ int64_t sourceDistributedDim =
+ destDistributedDim - (destType.getRank() - srcType.getRank());
+ if (sourceDistributedDim < 0)
+ return rewriter.notifyMatchFailure(
+ insertOp,
+ "distributed dimension must be in the last k (i.e. source "
+ "rank) dims of dest vector");
+ int64_t srcDistrDimSize = srcType.getDimSize(sourceDistributedDim);
+ // Obtain the source and dest layouts.
+ auto destLayout =
+ xegpu::getDistributeLayoutAttr(insertOp->getOpOperand(1));
+ auto sourceLayout =
+ xegpu::getDistributeLayoutAttr(insertOp->getOpOperand(0));
+ if (!destLayout || !sourceLayout ||
+ destLayout.getEffectiveLaneLayoutAsInt().empty() ||
+ sourceLayout.getEffectiveLaneLayoutAsInt().empty())
+ return rewriter.notifyMatchFailure(
+ warpOp, "the source or dest of insert_strided_slice op lacks "
+ "distribution layout");
+ // Because only single dimension distribution is supported, lane layout
+ // size at the distributed dim must be the subgroup size.
+ int subgroupSize =
+ destLayout.getEffectiveLaneLayoutAsInt()[destDistributedDim];
+ // We require that source and dest lane data are all ones to ensure
+ // uniform round robin distribution.
+ auto destLaneData = destLayout.getEffectiveLaneDataAsInt();
+ auto sourceLaneData = sourceLayout.getEffectiveLaneDataAsInt();
+ if (!llvm::all_of(destLaneData, [](int64_t v) { return v == 1; }) ||
+ !llvm::all_of(sourceLaneData, [](int64_t v) { return v == 1; }))
+ return rewriter.notifyMatchFailure(
+ warpOp, "Expecting unit lane data in source and dest layouts");
+ // Source distributed dim size must be a multiple of subgroup size.
+ if (srcDistrDimSize % subgroupSize != 0)
+ return rewriter.notifyMatchFailure(
+ warpOp, "Distributed dimension size in source is not a multiple of "
+ "subgroup size.");
+ // Offsets in the distributed dimension must be multiples of subgroup
+ // size.
+ int64_t destDistrDimOffset =
+ cast<IntegerAttr>(insertOp.getOffsets()[destDistributedDim]).getInt();
+ if (destDistrDimOffset % subgroupSize != 0)
+ return rewriter.notifyMatchFailure(
+ warpOp,
+ "Offset along distributed dimension in dest is not a multiple of "
+ "subgroup size.");
+ // Update the source and dest types based on their layouts.
+ updatedSourceType = getDistVecTypeBasedOnLaneLayout(
+ sourceLayout, insertOp.getSourceVectorType())
+ .value();
+ updatedDestType = getDistVecTypeBasedOnLaneLayout(
+ destLayout, insertOp.getDestVectorType())
+ .value();
+ // Update the distributed offsets to match round robin distribution (i.e.
+ // each lane owns data at `subgroupSize` stride given unit lane data).
+ updatedOffsets[destDistributedDim] =
+ rewriter.getI64IntegerAttr(destDistrDimOffset / subgroupSize);
+ }
+ // Do the distribution by yielding the source and dest of the insert op
+ // from the warp op and creating a new insert op outside the warp op.
+ SmallVector<size_t> newRetIndices;
+ auto newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
+ rewriter, warpOp, {insertOp.getValueToStore(), insertOp.getDest()},
+ {updatedSourceType, updatedDestType}, newRetIndices);
+ rewriter.setInsertionPointAfter(newWarpOp);
+
+ Value valueToStore = newWarpOp.getResult(newRetIndices[0]);
+ Value dest = newWarpOp.getResult(newRetIndices[1]);
+ // Create a new insert op outside the warp op.
+ Value newInsertOp = vector::InsertStridedSliceOp::create(
+ rewriter, insertOp.getLoc(), updatedDestType, valueToStore, dest,
+ ArrayAttr::get(rewriter.getContext(), updatedOffsets),
+ insertOp.getStrides());
+ rewriter.replaceAllUsesWith(newWarpOp.getResult(operandNumber),
+ newInsertOp);
+ return success();
+ }
+};
+
/// Sink a memref::ExtractAlignedPointerAsIndex op feeding into yield op of an
/// enclosing `gpu.warp_execute_on_lane_0` region. This will simply move the op
/// outside of the warp op.
@@ -1626,9 +1861,13 @@ void xegpu::populateXeGPUSubgroupDistributePatterns(
MemrefExtractAlignedPointerAsIndexDistribution>(
patterns.getContext(),
/*pattern benefit=*/regularPatternBenefit);
- patterns.add<VectorShapeCastDistribution>(
- patterns.getContext(),
- /*pattern benefit=*/highPatternBenefit);
+ // The following patterns need to override the regular vector distribution
+ // patterns, so they are assigned a higher benefit.
+ patterns
+ .add<VectorShapeCastDistribution, VectorExtractStridedSliceDistribution,
+ VectorInsertStridedSliceDistribution>(
+ patterns.getContext(),
+ /*pattern benefit=*/highPatternBenefit);
}
void xegpu::populateXeGPUMoveFuncBodyToWarpOpPatterns(
diff --git a/mlir/test/Dialect/XeGPU/subgroup-distribute-unit.mlir b/mlir/test/Dialect/XeGPU/subgroup-distribute-unit.mlir
index f233dff609f2b..44ec21359593f 100644
--- a/mlir/test/Dialect/XeGPU/subgroup-distribute-unit.mlir
+++ b/mlir/test/Dialect/XeGPU/subgroup-distribute-unit.mlir
@@ -1,6 +1,6 @@
-// RUN: mlir-opt --xevm-attach-target='module=xevm_* chip=pvc' -test-xegpu-sg-distribute -allow-unregistered-dialect \
-// RUN: -canonicalize -cse -split-input-file %s | FileCheck %s
-
+// RUN: mlir-opt --xevm-attach-target='module=xevm_* chip=pvc' -test-xegpu-sg-distribute \
+// RUN: -allow-unregistered-dialect -canonicalize -cse %s | FileCheck %s
+gpu.module @xevm_module{
// CHECK-LABEL: gpu.func @store_nd_1d
// CHECK: (%[[ARG0:[0-9a-zA-Z]+]]: index) {
// CHECK: %[[W:.*]]:3 = gpu.warp_execute_on_lane_0(%[[ARG0]])[16]
@@ -11,20 +11,17 @@
// CHECK-NEXT: %[[T1:.*]] = builtin.unrealized_conversion_cast %[[W]]#1 : !xegpu.tensor_desc<16xf32,
// CHECK-SAME: #xegpu.layout<lane_layout = [16], lane_data = [1]>> to !xegpu.tensor_desc<16xf32> {resolve_simt_type_mismatch}
// CHECK-NEXT: xegpu.store_nd %[[W]]#0, %[[T1]][%[[W]]#2] : vector<1xf32>, !xegpu.tensor_desc<16xf32>
-gpu.module @xevm_module{
- gpu.func @store_nd_1d(%laneid: index) {
- %c0 = arith.constant 0 : index
- gpu.warp_execute_on_lane_0(%laneid)[16] {
- %0 = "some_op"() : () -> !xegpu.tensor_desc<16xf32, #xegpu.layout<lane_layout = [16], lane_data = [1]>>
- %cst = "some_op"() : () -> vector<16xf32>
- xegpu.store_nd %cst, %0 [%c0] {layout_operand_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>}
- : vector<16xf32>, !xegpu.tensor_desc<16xf32, #xegpu.layout<lane_layout = [16], lane_data = [1]>>
- }
- gpu.return
+gpu.func @store_nd_1d(%laneid: index) {
+ %c0 = arith.constant 0 : index
+ gpu.warp_execute_on_lane_0(%laneid)[16] {
+ %0 = "some_op"() : () -> !xegpu.tensor_desc<16xf32, #xegpu.layout<lane_layout = [16], lane_data = [1]>>
+ %cst = "some_op"() : () -> vector<16xf32>
+ xegpu.store_nd %cst, %0 [%c0] {layout_operand_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>}
+ : vector<16xf32>, !xegpu.tensor_desc<16xf32, #xegpu.layout<lane_layout = [16], lane_data = [1]>>
}
+ gpu.return
}
-// -----
// CHECK-LABEL: gpu.func @store_nd_2d
// CHECK: (%[[ARG0:[0-9a-zA-Z]+]]: index) {
// CHECK: %[[W:.*]]:4 = gpu.warp_execute_on_lane_0(%[[ARG0]])[16]
@@ -37,22 +34,18 @@ gpu.module @xevm_module{
// CHECK-NEXT: %[[T1:.*]] = builtin.unrealized_conversion_cast %[[W]]#1 : !xegpu.tensor_desc<16x16xf16,
// CHECK-SAME: #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>> to !xegpu.tensor_desc<16x16xf16> {resolve_simt_type_mismatch}
// CHECK-NEXT: xegpu.store_nd %[[CAST]], %[[T1]][%[[W]]#2, %[[W]]#3] : vector<16xf16>, !xegpu.tensor_desc<16x16xf16>
-gpu.module @xevm_module{
- gpu.func @store_nd_2d(%laneid : index) {
- %c0 = arith.constant 0 : index
- gpu.warp_execute_on_lane_0(%laneid)[16] {
- %0 = "some_op"() : () -> !xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
- %cst = "some_op"() : () -> vector<16x16xf16>
- xegpu.store_nd %cst, %0 [%c0, %c0] {layout_operand_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>}
- : vector<16x16xf16>, !xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
- }
- gpu.return
+gpu.func @store_nd_2d(%laneid : index) {
+ %c0 = arith.constant 0 : index
+ gpu.warp_execute_on_lane_0(%laneid)[16] {
+ %0 = "some_op"() : () -> !xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
+ %cst = "some_op"() : () -> vector<16x16xf16>
+ xegpu.store_nd %cst, %0 [%c0, %c0] {layout_operand_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>}
+ : vector<16x16xf16>, !xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
}
+ gpu.return
}
-
-// -----
// CHECK-LABEL: gpu.func @load_nd_1d
// CHECK: (%[[ARG0:[0-9a-zA-Z]+]]: index) {
// CHECK: %[[W:.*]]:3 = gpu.warp_execute_on_lane_0(%[[ARG0]])[16] -> (vector<1xf32>,
@@ -63,21 +56,19 @@ gpu.module @xevm_module{
// CHECK-NEXT: %[[T1:.*]] = builtin.unrealized_conversion_cast %[[W]]#1 : !xegpu.tensor_desc<16xf32,
// CHECK-SAME: #xegpu.layout<lane_layout = [16], lane_data = [1]>> to !xegpu.tensor_desc<16xf32> {resolve_simt_type_mismatch}
// CHECK-NEXT: xegpu.load_nd %[[T1]][%[[W]]#2] : !xegpu.tensor_desc<16xf32> -> vector<1xf32>
-gpu.module @xevm_module{
- gpu.func @load_nd_1d(%laneid: index) {
- %c0 = arith.constant 0 : index
- %r = gpu.warp_execute_on_lane_0(%laneid)[16] -> (vector<1xf32>) {
- %0 = "some_op"() : () -> !xegpu.tensor_desc<16xf32, #xegpu.layout<lane_layout = [16], lane_data = [1]>>
- %1 = xegpu.load_nd %0 [%c0] {layout_result_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>} :
- !xegpu.tensor_desc<16xf32, #xegpu.layout<lane_layout = [16], lane_data = [1]>> -> vector<16xf32>
- gpu.yield %1 : vector<16xf32>
- }
- "some_user_op"(%r) : (vector<1xf32>) -> ()
- gpu.return
+gpu.func @load_nd_1d(%laneid: index) {
+ %c0 = arith.constant 0 : index
+ %r = gpu.warp_execute_on_lane_0(%laneid)[16] -> (vector<1xf32>) {
+ %0 = "some_op"() : () -> !xegpu.tensor_desc<16xf32, #xegpu.layout<lane_layout = [16], lane_data = [1]>>
+ %1 = xegpu.load_nd %0 [%c0] {layout_result_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>} :
+ !xegpu.tensor_desc<16xf32, #xegpu.layout<lane_layout = [16], lane_data = [1]>> -> vector<16xf32>
+ gpu.yield %1 : vector<16xf32>
}
+ "some_user_op"(%r) : (vector<1xf32>) -> ()
+ gpu.return
}
-// -----
+
// CHECK-LABEL: gpu.func @load_nd_2d
// CHECK: (%[[ARG0:[0-9a-zA-Z]+]]: index) {
// CHECK: %[[W:.*]]:4 = gpu.warp_execute_on_lane_0(%[[ARG0]])[16] -> (vector<16x1xf16>, !xegpu.tensor_desc<16x16xf16,
@@ -89,21 +80,19 @@ gpu.module @xevm_module{
// CHECK-SAME: #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>> to !xegpu.tensor_desc<16x16xf16> {resolve_simt_type_mismatch}
// CHECK-NEXT: %[[T2:.*]] = xegpu.load_nd %[[T1]][%[[W]]#2, %[[W]]#3] : !xegpu.tensor_desc<16x16xf16> -> vector<16xf16>
// CHECK: vector.shape_cast %[[T2]] : vector<16xf16> to vector<16x1xf16>
-gpu.module @xevm_module{
- gpu.func @load_nd_2d(%laneid: index) {
- %c0 = arith.constant 0 : index
- %r = gpu.warp_execute_on_lane_0(%laneid)[16] -> (vector<16x1xf16>) {
- %0 = "some_op"() : () -> !xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
- %1 = xegpu.load_nd %0[%c0, %c0] {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>}
- : !xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>> -> vector<16x16xf16>
- gpu.yield %1 : vector<16x16xf16>
- }
- "some_user_op"(%r) : (vector<16x1xf16>) -> ()
- gpu.return
+gpu.func @load_nd_2d(%laneid: index) {
+ %c0 = arith.constant 0 : index
+ %r = gpu.warp_execute_on_lane_0(%laneid)[16] -> (vector<16x1xf16>) {
+ %0 = "some_op"() : () -> !xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
+ %1 = xegpu.load_nd %0[%c0, %c0] {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>}
+ : !xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>> -> vector<16x16xf16>
+ gpu.yield %1 : vector<16x16xf16>
}
+ "some_user_op"(%r) : (vector<16x1xf16>) -> ()
+ gpu.return
}
-// -----
+
// CHECK-LABEL: gpu.func @load_nd_array_length
// CHECK: (%[[ARG0:[0-9a-zA-Z]+]]: index) {
// CHECK: %[[W:.*]]:4 = gpu.warp_execute_on_lane_0(%[[ARG0]])[16] -> (vector<2x16x1xf16>,
@@ -118,23 +107,21 @@ gpu.module @xevm_module{
// CHECK-NEXT: %[[T2:.*]] = xegpu.load_nd %[[T1]][%[[W]]#2, %[[W]]#3] : !xegpu.tensor_desc<16x16xf16,
// CHECK-SAME: #xegpu.block_tdesc_attr<array_length = 2 : i64>> -> vector<32xf16>
// CHECK-NEXT: vector.shape_cast %[[T2]] : vector<32xf16> to vector<2x16x1xf16>
-gpu.module @xevm_module{
- gpu.func @load_nd_array_length(%laneid: index) {
- %c0 = arith.constant 0 : index
- %r = gpu.warp_execute_on_lane_0(%laneid)[16] -> (vector<2x16x1xf16>) {
- %0 = "some_op"() : () -> !xegpu.tensor_desc<16x16xf16, #xegpu.block_tdesc_attr<array_length = 2 : i64>,
- #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
- %1 = xegpu.load_nd %0[%c0, %c0] {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>}
- : !xegpu.tensor_desc<16x16xf16, #xegpu.block_tdesc_attr<array_length = 2 : i64>,
- #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>> -> vector<2x16x16xf16>
- gpu.yield %1 : vector<2x16x16xf16>
- }
- "some_user_op"(%r) : (vector<2x16x1xf16>) -> ()
- gpu.return
+gpu.func @load_nd_array_length(%laneid: index) {
+ %c0 = arith.constant 0 : index
+ %r = gpu.warp_execute_on_lane_0(%laneid)[16] -> (vector<2x16x1xf16>) {
+ %0 = "some_op"() : () -> !xegpu.tensor_desc<16x16xf16, #xegpu.block_tdesc_attr<array_length = 2 : i64>,
+ #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
+ %1 = xegpu.load_nd %0[%c0, %c0] {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>}
+ : !xegpu.tensor_desc<16x16xf16, #xegpu.block_tdesc_attr<array_length = 2 : i64>,
+ #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>> -> vector<2x16x16xf16>
+ gpu.yield %1 : vector<2x16x16xf16>
}
+ "some_user_op"(%r) : (vector<2x16x1xf16>) -> ()
+ gpu.return
}
-// -----
+
// CHECK-LABEL: gpu.func @dpas
// CHECK: (%[[ARG0:[0-9a-zA-Z]+]]: index) {
// CHECK: %[[W:.*]]:4 = gpu.warp_execute_on_lane_0(%[[ARG0]])[16] ->
@@ -146,29 +133,27 @@ gpu.module @xevm_module{
// CHECK-DAG: %[[T3:.*]] = vector.shape_cast %[[W]]#3 : vector<8x1xf32> to vector<8xf32>
// CHECK-NEXT: %[[T4:.*]] = xegpu.dpas %[[T1]], %[[T2]], %[[T3]] : vector<8xf16>, vector<16xf16>, vector<8xf32> -> vector<8xf32>
// CHECK-NEXT: vector.shape_cast %[[T4]] : vector<8xf32> to vector<8x1xf32>
-gpu.module @xevm_module{
- gpu.func @dpas(%laneid: index) {
- %r = gpu.warp_execute_on_lane_0(%laneid)[16] -> (vector<8x1xf32>) {
- %0 = "some_op"() : () -> vector<8x16xf16>
- %1 = "some_op"() : () -> vector<16x16xf16>
- %2 = "some_op"() : () -> vector<8x16xf32>
- %3 = xegpu.dpas %0, %1, %2
- {
- layout_operand_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>,
- layout_operand_1 = #xegpu.layout<lane_layout = [1, 16], lane_data = [2, 1]>,
- layout_operand_2 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>,
- layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>
- }
- : vector<8x16xf16>, vector<16x16xf16>, vector<8x16xf32> -> vector<8x16xf32>
- gpu.yield %3 : vector<8x16xf32>
- }
- "some_user_op"(%r) : (vector<8x1xf32>) -> ()
- gpu.return
+gpu.func @dpas(%laneid: index) {
+ %r = gpu.warp_execute_on_lane_0(%laneid)[16] -> (vector<8x1xf32>) {
+ %0 = "some_op"() : () -> vector<8x16xf16>
+ %1 = "some_op"() : () -> vector<16x16xf16>
+ %2 = "some_op"() : () -> vector<8x16xf32>
+ %3 = xegpu.dpas %0, %1, %2
+ {
+ layout_operand_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>,
+ layout_operand_1 = #xegpu.layout<lane_layout = [1, 16], lane_data = [2, 1]>,
+ layout_operand_2 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>,
+ layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>
+ }
+ : vector<8x16xf16>, vector<16x16xf16>, vector<8x16xf32> -> vector<8x16xf32>
+ gpu.yield %3 : vector<8x16xf32>
}
+ "some_user_op"(%r) : (vector<8x1xf32>) -> ()
+ gpu.return
}
-// -----
+
// CHECK-LABEL: gpu.func @create_nd_tdesc_non_memref
// CHECK: (%[[ARG0:[0-9a-zA-Z]+]]: ui64, %[[ARG1:[0-9a-zA-Z]+]]: index) {
// CHECK: %[[W:.*]]:2 = gpu.warp_execute_on_lane_0(%[[ARG1]])[16] -> (!xegpu.tensor_desc<16x16xf16,
@@ -178,21 +163,19 @@ gpu.module @xevm_module{
// CHECK-NEXT: %[[T1:.*]] = xegpu.create_nd_tdesc %[[W]]#1, shape : [64, 128], strides : [128, 1] : ui64 -> !xegpu.tensor_desc<16x16xf16>
// CHECK-NEXT: builtin.unrealized_conversion_cast %[[T1]] : !xegpu.tensor_desc<16x16xf16> to !xegpu.tensor_desc<16x16xf16,
// CHECK-SAME: #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>> {resolve_simt_type_mismatch}
-gpu.module @xevm_module{
- gpu.func @create_nd_tdesc_non_memref(%arg0: ui64, %laneid: index) {
- %c0 = arith.constant 0 : index
- %r = gpu.warp_execute_on_lane_0(%laneid)[16] -> (!xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>) {
- %0 = xegpu.create_nd_tdesc %arg0, shape:[64, 128], strides:[128, 1] : ui64 ->
- !xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
- gpu.yield %0 : !xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
- }
- "some_user_op"(%r)
- : (!xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>) -> ()
- gpu.return
+gpu.func @create_nd_tdesc_non_memref(%arg0: ui64, %laneid: index) {
+ %c0 = arith.constant 0 : index
+ %r = gpu.warp_execute_on_lane_0(%laneid)[16] -> (!xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>) {
+ %0 = xegpu.create_nd_tdesc %arg0, shape:[64, 128], strides:[128, 1] : ui64 ->
+ !xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
+ gpu.yield %0 : !xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
}
+ "some_user_op"(%r)
+ : (!xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>) -> ()
+ gpu.return
}
-// -----
+
// CHECK-LABEL: gpu.func @prefetch_2d
// CHECK: (%[[ARG0:[0-9a-zA-Z]+]]: index) {
// CHECK: %[[W:.*]]:3 = gpu.warp_execute_on_lane_0(%[[ARG0]])[16] -> (!xegpu.tensor_desc<16x16xf16,
@@ -204,21 +187,19 @@ gpu.module @xevm_module{
// CHECK-SAME: #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>> to !xegpu.tensor_desc<16x16xf16> {resolve_simt_type_mismatch}
// CHECK-NEXT: xegpu.prefetch_nd %[[T1]][%[[W]]#1, %[[W]]#2]
// CHECK-SAME: <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}> : !xegpu.tensor_desc<16x16xf16>
-gpu.module @xevm_module{
- gpu.func @prefetch_2d(%laneid: index) {
- %c0 = arith.constant 0 : index
- gpu.warp_execute_on_lane_0(%laneid)[16] {
- %0 = "some_op"() : ()
- -> !xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
- xegpu.prefetch_nd %0[%c0, %c0]
- <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}>
- : !xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
- }
- gpu.return
+gpu.func @prefetch_2d(%laneid: index) {
+ %c0 = arith.constant 0 : index
+ gpu.warp_execute_on_lane_0(%laneid)[16] {
+ %0 = "some_op"() : ()
+ -> !xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
+ xegpu.prefetch_nd %0[%c0, %c0]
+ <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}>
+ : !xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
}
+ gpu.return
}
-// -----
+
// CHECK-LABEL: gpu.func @prefetch_1d
// CHECK: (%[[ARG0:[0-9a-zA-Z]+]]: index) {
// CHECK: %[[W:.*]]:2 = gpu.warp_execute_on_lane_0(%[[ARG0]])[16] -> (!xegpu.tensor_desc<16xf16,
@@ -229,44 +210,40 @@ gpu.module @xevm_module{
// CHECK-SAME: #xegpu.layout<lane_layout = [16], lane_data = [1]>> to !xegpu.tensor_desc<16xf16> {resolve_simt_type_mismatch}
// CHECK-NEXT: xegpu.prefetch_nd %[[T1]][%[[W]]#1] <{l1_hint = #xegpu.cache_hint<cached>,
// CHECK-SAME: l2_hint = #xegpu.cache_hint<uncached>}> : !xegpu.tensor_desc<16xf16>
-gpu.module @xevm_module{
- gpu.func @prefetch_1d(%laneid: index) {
- %c0 = arith.constant 0 : index
- gpu.warp_execute_on_lane_0(%laneid)[16] {
- %0 = "some_op"() : ()
- -> !xegpu.tensor_desc<16xf16, #xegpu.layout<lane_layout = [16], lane_data = [1]>>
- xegpu.prefetch_nd %0[%c0]
- <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}>
- : !xegpu.tensor_desc<16xf16, #xegpu.layout<lane_layout = [16], lane_data = [1]>>
- }
- gpu.return
+gpu.func @prefetch_1d(%laneid: index) {
+ %c0 = arith.constant 0 : index
+ gpu.warp_execute_on_lane_0(%laneid)[16] {
+ %0 = "some_op"() : ()
+ -> !xegpu.tensor_desc<16xf16, #xegpu.layout<lane_layout = [16], lane_data = [1]>>
+ xegpu.prefetch_nd %0[%c0]
+ <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}>
+ : !xegpu.tensor_desc<16xf16, #xegpu.layout<lane_layout = [16], lane_data = [1]>>
}
+ gpu.return
}
-// -----
+
// CHECK-LABEL: gpu.func @gpu_barrier({{.*}}) {
// CHECK: gpu.warp_execute_on_lane_0(%{{.*}})[16] -> ({{.*}}) {
// CHECK: gpu.yield %{{.*}}
// CHECK: }
// CHECK: %{{.*}} = xegpu.load_nd %{{.*}} : !xegpu.tensor_desc<16xf16> -> vector<1xf16>
// CHECK: gpu.barrier
-gpu.module @xevm_module{
- gpu.func @gpu_barrier(%laneid: index) {
- %c0 = arith.constant 0 : index
- %r = gpu.warp_execute_on_lane_0(%laneid)[16] -> (vector<1xf16>) {
- %0 = "some_op"() : () -> !xegpu.tensor_desc<16xf16, #xegpu.layout<lane_layout = [16], lane_data = [1]>>
- %1 = xegpu.load_nd %0[%c0]
- {layout_result_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>}
- : !xegpu.tensor_desc<16xf16, #xegpu.layout<lane_layout = [16], lane_data = [1]>> -> vector<16xf16>
- gpu.barrier
- gpu.yield %1 : vector<16xf16>
- }
- "some_user_op"(%r) : (vector<1xf16>) -> ()
- gpu.return
+gpu.func @gpu_barrier(%laneid: index) {
+ %c0 = arith.constant 0 : index
+ %r = gpu.warp_execute_on_lane_0(%laneid)[16] -> (vector<1xf16>) {
+ %0 = "some_op"() : () -> !xegpu.tensor_desc<16xf16, #xegpu.layout<lane_layout = [16], lane_data = [1]>>
+ %1 = xegpu.load_nd %0[%c0]
+ {layout_result_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>}
+ : !xegpu.tensor_desc<16xf16, #xegpu.layout<lane_layout = [16], lane_data = [1]>> -> vector<16xf16>
+ gpu.barrier
+ gpu.yield %1 : vector<16xf16>
}
+ "some_user_op"(%r) : (vector<1xf16>) -> ()
+ gpu.return
}
-// -----
+
// CHECK-LABEL: gpu.func @vector_multi_reduction_dim1_distributed_dim0_reduction
// CHECK: %[[ACC:.*]] = arith.constant {{.*}} dense<0.000000e+00> : vector<32xf32>
// CHECK: %[[W:.*]]:3 = gpu.warp_execute_on_lane_0(%{{.*}})[16]
@@ -285,7 +262,6 @@ gpu.module @xevm_module{
// CHECK: %[[T7:.*]] = vector.extract %[[W]]#2[1] : f32 from vector<2xf32>
// CHECK: %[[T8:.*]] = vector.reduction <add>, %[[T6]], %[[T7]] : vector<16xf32> into f32
// CHECK: %[[T9:.*]] = vector.from_elements %[[T4]], %[[T8]] : vector<2xf32>
-gpu.module @xevm_module{
gpu.func @vector_multi_reduction_dim1_distributed_dim0_reduction(%laneid: index) {
%c0 = arith.constant 0 : index
%r = gpu.warp_execute_on_lane_0(%laneid)[16] -> (vector<2xf32>) {
@@ -307,9 +283,8 @@ gpu.func @vector_multi_reduction_dim1_distributed_dim0_reduction(%laneid: index)
"some_user_op"(%r) : (vector<2xf32>) -> ()
gpu.return
}
-}
-// -----
+
// CHECK-LABEL: gpu.func @vector_multi_reduction_dim1_distributed_dim1_reduction
// CHECK: %[[W:.*]] = gpu.warp_execute_on_lane_0(%{{.*}})[16] -> (vector<2xf32>) {
// CHECK-NEXT: %[[SRC:.*]] = "some_def"() {{.*}} : () -> vector<2x16xf32>
@@ -320,7 +295,6 @@ gpu.func @vector_multi_reduction_dim1_distributed_dim0_reduction(%laneid: index)
// CHECK-NEXT: %[[T6:.*]] = vector.from_elements %[[T3]], %[[T5]] : vector<2xf32>
// CHECK-NEXT: gpu.yield %[[T6]] : vector<2xf32>
// CHECK-NEXT: }
-gpu.module @xevm_module{
gpu.func @vector_multi_reduction_dim1_distributed_dim1_reduction(%laneid: index) {
%c0 = arith.constant 0 : index
%r = gpu.warp_execute_on_lane_0(%laneid)[16] -> (vector<2xf32>) {
@@ -342,9 +316,8 @@ gpu.func @vector_multi_reduction_dim1_distributed_dim1_reduction(%laneid: index)
"some_user_op"(%r) : (vector<2xf32>) -> ()
gpu.return
}
-}
-// -----
+
// CHECK-LABEL: gpu.func @vector_multi_reduction_dim0_distributed_dim1_reduction
// CHECK: %[[ACC:.*]] = arith.constant {{.*}} dense<0.000000e+00> : vector<32xf32>
// CHECK: %[[W:.*]]:3 = gpu.warp_execute_on_lane_0(%{{.*}})[16] -> (vector<2xf32>, vector<2x16xf32>, vector<2xf32>) {
@@ -358,7 +331,6 @@ gpu.func @vector_multi_reduction_dim1_distributed_dim1_reduction(%laneid: index)
// CHECK: %[[T5:.*]] = vector.extract %[[W]]#2[1] : f32 from vector<2xf32>
// CHECK: %[[T6:.*]] = vector.reduction <add>, %[[T4]], %[[T5]] : vector<16xf32> into f32
// CHECK: %[[T7:.*]] = vector.from_elements %[[T3]], %[[T6]] : vector<2xf32>
-gpu.module @xevm_module{
gpu.func @vector_multi_reduction_dim0_distributed_dim1_reduction(%laneid: index) {
%c0 = arith.constant 0 : index
%r = gpu.warp_execute_on_lane_0(%laneid)[16] -> (vector<2xf32>) {
@@ -380,9 +352,8 @@ gpu.func @vector_multi_reduction_dim0_distributed_dim1_reduction(%laneid: index)
"some_user_op"(%r) : (vector<2xf32>) -> ()
gpu.return
}
-}
-// -----
+
// CHECK-LABEL: gpu.func @vector_multi_reduction_dim0_distributed_dim0_reduction
// CHECK: %[[W:.*]] = gpu.warp_execute_on_lane_0(%{{.*}})[16] -> (vector<2xf32>) {
// CHECK: %[[SRC:.*]] = "some_def"() {{.*}} : () -> vector<16x2xf32>
@@ -397,7 +368,6 @@ gpu.func @vector_multi_reduction_dim0_distributed_dim1_reduction(%laneid: index)
// CHECK: %[[T7:.*]] = vector.from_elements %[[T3]], %[[T6]] : vector<2xf32>
// CHECK: gpu.yield %[[T7]] : vector<2xf32>
// CHECK: }
-gpu.module @xevm_module{
gpu.func @vector_multi_reduction_dim0_distributed_dim0_reduction(%laneid: index) {
%c0 = arith.constant 0 : index
%r = gpu.warp_execute_on_lane_0(%laneid)[16] -> (vector<2xf32>) {
@@ -419,9 +389,8 @@ gpu.func @vector_multi_reduction_dim0_distributed_dim0_reduction(%laneid: index)
"some_user_op"(%r) : (vector<2xf32>) -> ()
gpu.return
}
-}
-// -----
+
// CHECK-LABEL: gpu.func @scatter_ops_chunksize({{.*}}) {
// CHECK: %[[OFFSETS:.*]] = arith.constant {{.*}} dense<12> : vector<16xindex>
// CHECK: %[[MASKS:.*]] = arith.constant {{.*}} dense<true> : vector<16xi1>
@@ -434,35 +403,33 @@ gpu.func @vector_multi_reduction_dim0_distributed_dim0_reduction(%laneid: index)
// CHECK-SAME: : memref<256xf16>, vector<1xindex>, vector<1xi1> -> vector<8xf16>
// CHECK-NEXT: xegpu.store %[[T1]], %[[W]]#1[%[[W]]#2], %[[W]]#3 <{chunk_size = 8 : i64}>
// CHECK-SAME: : vector<8xf16>, memref<256xf16>, vector<1xindex>, vector<1xi1>
-gpu.module @xevm_module{
- gpu.func @scatter_ops_chunksize(%laneid: index, %src: memref<256xf16>) {
- gpu.warp_execute_on_lane_0(%laneid)[16] {
- %1 = arith.constant
- {layout_result_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>}
- dense<1>: vector<16xi1>
- %offset = arith.constant
- {layout_result_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>}
- dense<12> : vector<16xindex>
- %3 = xegpu.load %src[%offset], %1 <{chunk_size=8}>
- {
- layout_operand_1 = #xegpu.layout<lane_layout = [16], lane_data = [1]>,
- layout_operand_2 = #xegpu.layout<lane_layout = [16], lane_data = [1]>,
- layout_result_0 = #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 2]>
- }
- : memref<256xf16>, vector<16xindex>, vector<16xi1> -> vector<16x8xf16>
- xegpu.store %3, %src[%offset], %1 <{chunk_size=8}>
- {
- layout_operand_0 = #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 2]>,
- layout_operand_2 = #xegpu.layout<lane_layout = [16], lane_data = [1]>,
- layout_operand_3 = #xegpu.layout<lane_layout = [16], lane_data = [1]>
- }
- : vector<16x8xf16>, memref<256xf16>, vector<16xindex>, vector<16xi1>
- }
- gpu.return
+gpu.func @scatter_ops_chunksize(%laneid: index, %src: memref<256xf16>) {
+ gpu.warp_execute_on_lane_0(%laneid)[16] {
+ %1 = arith.constant
+ {layout_result_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>}
+ dense<1>: vector<16xi1>
+ %offset = arith.constant
+ {layout_result_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>}
+ dense<12> : vector<16xindex>
+ %3 = xegpu.load %src[%offset], %1 <{chunk_size=8}>
+ {
+ layout_operand_1 = #xegpu.layout<lane_layout = [16], lane_data = [1]>,
+ layout_operand_2 = #xegpu.layout<lane_layout = [16], lane_data = [1]>,
+ layout_result_0 = #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 2]>
+ }
+ : memref<256xf16>, vector<16xindex>, vector<16xi1> -> vector<16x8xf16>
+ xegpu.store %3, %src[%offset], %1 <{chunk_size=8}>
+ {
+ layout_operand_0 = #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 2]>,
+ layout_operand_2 = #xegpu.layout<lane_layout = [16], lane_data = [1]>,
+ layout_operand_3 = #xegpu.layout<lane_layout = [16], lane_data = [1]>
+ }
+ : vector<16x8xf16>, memref<256xf16>, vector<16xindex>, vector<16xi1>
}
+ gpu.return
}
-// -----
+
// CHECK-LABEL: gpu.func @scatter_ops({{.*}}) {
// CHECK: %[[OFFSETS:.*]] = arith.constant {{.*}} dense<12> : vector<16xindex>
// CHECK: %[[MASKS:.*]] = arith.constant {{.*}} dense<true> : vector<16xi1>
@@ -475,156 +442,144 @@ gpu.module @xevm_module{
// CHECK-SAME: : memref<256xf16>, vector<1xindex>, vector<1xi1> -> vector<1xf16>
// CHECK-NEXT: xegpu.store %[[T1]], %[[W]]#1[%[[W]]#2], %[[W]]#3
// CHECK-SAME: : vector<1xf16>, memref<256xf16>, vector<1xindex>, vector<1xi1>
-gpu.module @xevm_module{
- gpu.func @scatter_ops(%src: memref<256xf16>, %laneid: index) {
- gpu.warp_execute_on_lane_0(%laneid)[16] {
- %1 = arith.constant
- {layout_result_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>}
- dense<1> : vector<16xi1>
- %offset = arith.constant
- {layout_result_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>}
- dense<12> : vector<16xindex>
- %3 = xegpu.load %src[%offset], %1
- {
- layout_operand_1 = #xegpu.layout<lane_layout = [16], lane_data = [1]>,
- layout_operand_2 = #xegpu.layout<lane_layout = [16], lane_data = [1]>,
- layout_result_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>
- } : memref<256xf16>, vector<16xindex>, vector<16xi1> -> vector<16xf16>
- xegpu.store %3, %src[%offset], %1
- {
- layout_operand_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>,
- layout_operand_2 = #xegpu.layout<lane_layout = [16], lane_data = [1]>,
- layout_operand_3 = #xegpu.layout<lane_layout = [16], lane_data = [1]>
- }
- : vector<16xf16>, memref<256xf16>, vector<16xindex>, vector<16xi1>
+gpu.func @scatter_ops(%src: memref<256xf16>, %laneid: index) {
+ gpu.warp_execute_on_lane_0(%laneid)[16] {
+ %1 = arith.constant
+ {layout_result_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>}
+ dense<1> : vector<16xi1>
+ %offset = arith.constant
+ {layout_result_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>}
+ dense<12> : vector<16xindex>
+ %3 = xegpu.load %src[%offset], %1
+ {
+ layout_operand_1 = #xegpu.layout<lane_layout = [16], lane_data = [1]>,
+ layout_operand_2 = #xegpu.layout<lane_layout = [16], lane_data = [1]>,
+ layout_result_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>
+ } : memref<256xf16>, vector<16xindex>, vector<16xi1> -> vector<16xf16>
+ xegpu.store %3, %src[%offset], %1
+ {
+ layout_operand_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>,
+ layout_operand_2 = #xegpu.layout<lane_layout = [16], lane_data = [1]>,
+ layout_operand_3 = #xegpu.layout<lane_layout = [16], lane_data = [1]>
}
- gpu.return
+ : vector<16xf16>, memref<256xf16>, vector<16xindex>, vector<16xi1>
}
+ gpu.return
}
-// -----
+
// CHECK-LABEL: gpu.func @memref_extract_aligned_pointer_as_index(
// CHECK: %[[W:.*]]:2 = gpu.warp_execute_on_lane_0(%{{.*}})[16] -> (index, memref<256x256xf16>) {
// CHECK: gpu.yield %{{.*}}, %{{.*}} : index, memref<256x256xf16>
// CHECK-NEXT: }
// CHECK-NEXT: %[[INTPTR:.*]] = memref.extract_aligned_pointer_as_index %[[W]]#1 : memref<256x256xf16> -> index
// CHECK-NEXT: arith.index_cast %[[INTPTR]] : index to i64
-gpu.module @xevm_module{
- gpu.func @memref_extract_aligned_pointer_as_index(%arg0 : memref<256x256xf16>, %laneid: index) {
- %r = gpu.warp_execute_on_lane_0(%laneid)[16] -> (index) {
- %ptr = memref.extract_aligned_pointer_as_index %arg0 : memref<256x256xf16> -> index
- gpu.yield %ptr : index
- }
- %ptr_i64 = arith.index_cast %r : index to i64
- "some_user_op"(%ptr_i64) : (i64) -> ()
- gpu.return
+gpu.func @memref_extract_aligned_pointer_as_index(%arg0 : memref<256x256xf16>, %laneid: index) {
+ %r = gpu.warp_execute_on_lane_0(%laneid)[16] -> (index) {
+ %ptr = memref.extract_aligned_pointer_as_index %arg0 : memref<256x256xf16> -> index
+ gpu.yield %ptr : index
}
+ %ptr_i64 = arith.index_cast %r : index to i64
+ "some_user_op"(%ptr_i64) : (i64) -> ()
+ gpu.return
}
-// -----
+
// CHECK-LABEL: gpu.func @vector_transpose(
// CHECK: %[[W:.*]]:2 = gpu.warp_execute_on_lane_0(%{{.*}})[16] -> (vector<2x1xf32>, vector<1x2xf32>) {
// CHECK: %[[SRC:.*]] = "some_op"() {{.*}} : () -> vector<16x2xf32>
// CHECK: gpu.yield %{{.*}}, %[[SRC]] : vector<2x16xf32>, vector<16x2xf32>
// CHECK-NEXT: }
// CHECK-NEXT: %[[T1:.*]] = vector.transpose %[[W]]#1, [1, 0] : vector<1x2xf32> to vector<2x1xf32>
-gpu.module @xevm_module{
- gpu.func @vector_transpose(%laneid: index) {
- %r = gpu.warp_execute_on_lane_0(%laneid)[16] -> (vector<2x1xf32>) {
- %cst = "some_op"()
- {layout_result_0 = #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 1]>}
- : () -> (vector<16x2xf32>)
- %transpose = vector.transpose %cst, [1, 0]
- {
- layout_operand_0 = #xegpu.layout<lane_layout = [16 , 1], lane_data = [1, 1]>,
- layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>
- }
- : vector<16x2xf32> to vector<2x16xf32>
- gpu.yield %transpose : vector<2x16xf32>
- }
- "some_user_op"(%r) : (vector<2x1xf32>) -> ()
- gpu.return
+gpu.func @vector_transpose(%laneid: index) {
+ %r = gpu.warp_execute_on_lane_0(%laneid)[16] -> (vector<2x1xf32>) {
+ %cst = "some_op"()
+ {layout_result_0 = #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 1]>}
+ : () -> (vector<16x2xf32>)
+ %transpose = vector.transpose %cst, [1, 0]
+ {
+ layout_operand_0 = #xegpu.layout<lane_layout = [16 , 1], lane_data = [1, 1]>,
+ layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>
+ }
+ : vector<16x2xf32> to vector<2x16xf32>
+ gpu.yield %transpose : vector<2x16xf32>
}
+ "some_user_op"(%r) : (vector<2x1xf32>) -> ()
+ gpu.return
}
-// -----
+
// CHECK-LABEL: gpu.func @vector_bitcast(
// CHECK: %[[W:.*]]:2 = gpu.warp_execute_on_lane_0(%{{.*}})[16] -> (vector<4x1xi16>, vector<4x2xi8>) {
// CHECK: %[[SRC:.*]] = "some_op"() {{.*}} : () -> vector<4x32xi8>
// CHECK: gpu.yield %{{.*}}, %[[SRC]] : vector<4x16xi16>, vector<4x32xi8>
// CHECK: }
// CHECK: vector.bitcast %[[W]]#1 : vector<4x2xi8> to vector<4x1xi16>
-gpu.module @xevm_module{
- gpu.func @vector_bitcast(%laneid: index) {
- %r = gpu.warp_execute_on_lane_0(%laneid)[16] -> (vector<4x1xi16>) {
- %cst = "some_op"()
- {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 2]>}
- : () -> (vector<4x32xi8>)
- %bitcast = vector.bitcast %cst
- {
- layout_operand_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 2]>,
- layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>
- }
- : vector<4x32xi8> to vector<4x16xi16>
- gpu.yield %bitcast : vector<4x16xi16>
- }
- "some_user_op"(%r) : (vector<4x1xi16>) -> ()
- gpu.return
+gpu.func @vector_bitcast(%laneid: index) {
+ %r = gpu.warp_execute_on_lane_0(%laneid)[16] -> (vector<4x1xi16>) {
+ %cst = "some_op"()
+ {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 2]>}
+ : () -> (vector<4x32xi8>)
+ %bitcast = vector.bitcast %cst
+ {
+ layout_operand_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 2]>,
+ layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>
+ }
+ : vector<4x32xi8> to vector<4x16xi16>
+ gpu.yield %bitcast : vector<4x16xi16>
}
+ "some_user_op"(%r) : (vector<4x1xi16>) -> ()
+ gpu.return
}
-// -----
+
// CHECK-LABEL: gpu.func @vector_shapecast_rank_increasing
// CHECK: %{{.*}}:2 = gpu.warp_execute_on_lane_0(%{{.*}})[16] -> (vector<1x1xf32>, vector<1xf32>) {
// CHECK: gpu.yield %{{.*}} : vector<1x16xf32>, vector<16xf32>
// CHECK: }
// CHECK: %{{.*}} = vector.shape_cast %{{.*}}#1 : vector<1xf32> to vector<1x1xf32>
-gpu.module @xevm_module {
- gpu.func @vector_shapecast_rank_increasing(%laneid: index) {
- %r = gpu.warp_execute_on_lane_0(%laneid)[16] -> (vector<1x1xf32>) {
- %cst = "some_op"()
- {layout_result_0 = #xegpu.slice<#xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>, dims = [0]>}
- : () -> (vector<16xf32>)
- %cast = vector.shape_cast %cst
- {
- layout_operand_0 = #xegpu.slice<#xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>, dims = [0]>,
- layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>
- }
- : vector<16xf32> to vector<1x16xf32>
- gpu.yield %cast : vector<1x16xf32>
- }
- "some_user_op"(%r) : (vector<1x1xf32>) -> ()
- gpu.return
+gpu.func @vector_shapecast_rank_increasing(%laneid: index) {
+ %r = gpu.warp_execute_on_lane_0(%laneid)[16] -> (vector<1x1xf32>) {
+ %cst = "some_op"()
+ {layout_result_0 = #xegpu.slice<#xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>, dims = [0]>}
+ : () -> (vector<16xf32>)
+ %cast = vector.shape_cast %cst
+ {
+ layout_operand_0 = #xegpu.slice<#xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>, dims = [0]>,
+ layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>
+ }
+ : vector<16xf32> to vector<1x16xf32>
+ gpu.yield %cast : vector<1x16xf32>
}
+ "some_user_op"(%r) : (vector<1x1xf32>) -> ()
+ gpu.return
}
-// -----
+
// CHECK-LABEL: gpu.func @vector_shapecast_rank_reducing(
// CHECK: %{{.*}}:2 = gpu.warp_execute_on_lane_0(%{{.*}})[16] -> (vector<1xf32>, vector<1x1xf32>) {
// CHECK: gpu.yield %{{.*}} : vector<16xf32>, vector<1x16xf32>
// CHECK: }
// CHECK: %{{.*}} = vector.shape_cast %{{.*}}#1 : vector<1x1xf32> to vector<1xf32>
-gpu.module @xevm_module {
- gpu.func @vector_shapecast_rank_reducing(%laneid: index) {
- %r = gpu.warp_execute_on_lane_0(%laneid)[16] -> (vector<1xf32>) {
- %cst = "some_op"()
- {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>}
- : () -> (vector<1x16xf32>)
- %cast = vector.shape_cast %cst
- {
- layout_operand_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>,
- layout_result_0 = #xegpu.slice<#xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>, dims = [0]>
- }
- : vector<1x16xf32> to vector<16xf32>
- gpu.yield %cast : vector<16xf32>
- }
- "some_user_op"(%r) : (vector<1xf32>) -> ()
- gpu.return
+gpu.func @vector_shapecast_rank_reducing(%laneid: index) {
+ %r = gpu.warp_execute_on_lane_0(%laneid)[16] -> (vector<1xf32>) {
+ %cst = "some_op"()
+ {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>}
+ : () -> (vector<1x16xf32>)
+ %cast = vector.shape_cast %cst
+ {
+ layout_operand_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>,
+ layout_result_0 = #xegpu.slice<#xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>, dims = [0]>
+ }
+ : vector<1x16xf32> to vector<16xf32>
+ gpu.yield %cast : vector<16xf32>
}
+ "some_user_op"(%r) : (vector<1xf32>) -> ()
+ gpu.return
}
-// -----
+
// NOTE: Layouts are still valid, but distribution still requires a slice layout for the operand.
//
// CHECK-LABEL: gpu.func @vector_shapecast_unsupported
@@ -634,21 +589,335 @@ gpu.module @xevm_module {
// CHECK: }
// CHECK: "some_user_op"(%[[W]]) : (vector<1x1xf32>) -> ()
// CHECK: gpu.return
-gpu.module @xevm_module {
- gpu.func @vector_shapecast_unsupported(%laneid: index) {
- %r = gpu.warp_execute_on_lane_0(%laneid)[16] -> (vector<1x1xf32>) {
- %cst = "some_op"()
- {layout_result_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]> }
- : () -> (vector<16xf32>)
- %cast = vector.shape_cast %cst
- {
- layout_operand_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>,
- layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>
- }
- : vector<16xf32> to vector<1x16xf32>
- gpu.yield %cast : vector<1x16xf32>
+gpu.func @vector_shapecast_unsupported(%laneid: index) {
+ %r = gpu.warp_execute_on_lane_0(%laneid)[16] -> (vector<1x1xf32>) {
+ %cst = "some_op"()
+ {layout_result_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]> }
+ : () -> (vector<16xf32>)
+ %cast = vector.shape_cast %cst
+ {
+ layout_operand_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>,
+ layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>
+ }
+ : vector<16xf32> to vector<1x16xf32>
+ gpu.yield %cast : vector<1x16xf32>
+ }
+ "some_user_op"(%r) : (vector<1x1xf32>) -> ()
+ gpu.return
+}
+
+
+// CHECK-LABEL: gpu.func @vector_extract_strided_slice_distributed_dim_fully_extracted
+// CHECK-NEXT: %[[W:.*]]:2 = gpu.warp_execute_on_lane_0(%{{.*}})[16] -> (vector<8x1xf32>, vector<24x1xf32>) {
+// CHECK-NEXT: %[[S:.*]] = "some_def"() : () -> vector<24x16xf32>
+// CHECK: gpu.yield %{{.*}}, %[[S]] : vector<8x16xf32>, vector<24x16xf32>
+// CHECK-NEXT: }
+// CHECK-NEXT: %[[T1:.*]] = vector.extract_strided_slice %[[W]]#1
+// CHECK-SAME: {offsets = [8, 0], sizes = [8, 1], strides = [1, 1]} : vector<24x1xf32> to vector<8x1xf32>
+// CHECK-NEXT: "some_use"(%[[T1]]) : (vector<8x1xf32>) -> ()
+gpu.func @vector_extract_strided_slice_distributed_dim_fully_extracted(%laneid: index) {
+ %r = gpu.warp_execute_on_lane_0(%laneid)[16] -> (vector<8x1xf32>) {
+ %0 = "some_def"() : () -> (vector<24x16xf32>)
+ %1 = vector.extract_strided_slice %0 { offsets = [8, 0], sizes = [8, 16], strides = [1, 1],
+ layout_operand_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>,
+ layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>
+ }
+ : vector<24x16xf32> to vector<8x16xf32>
+ gpu.yield %1 : vector<8x16xf32>
+ }
+ "some_use"(%r) : (vector<8x1xf32>) -> ()
+ gpu.return
+}
+
+// CHECK-LABEL: gpu.func @vector_extract_strided_slice_non_distributed
+// CHECK-NEXT: %[[W:.*]]:2 = gpu.warp_execute_on_lane_0(%{{.*}})[16] -> (vector<8x1xf32>, vector<24x1xf32>) {
+// CHECK-NEXT: %[[S:.*]] = "some_def"() : () -> vector<24x1xf32>
+// CHECK: gpu.yield %{{.*}}, %[[S]] : vector<8x1xf32>, vector<24x1xf32>
+// CHECK-NEXT: }
+// CHECK-NEXT: %[[T1:.*]] = vector.extract_strided_slice %[[W]]#1
+// CHECK-SAME: {offsets = [8, 0], sizes = [8, 1], strides = [1, 1]} : vector<24x1xf32> to vector<8x1xf32>
+// CHECK-NEXT: "some_use"(%[[T1]]) : (vector<8x1xf32>) -> ()
+gpu.func @vector_extract_strided_slice_non_distributed(%laneid: index) {
+ %r = gpu.warp_execute_on_lane_0(%laneid)[16] -> (vector<8x1xf32>) {
+ %0 = "some_def"() : () -> (vector<24x1xf32>)
+ %1 = vector.extract_strided_slice %0 { offsets = [8, 0], sizes = [8, 1], strides = [1, 1],
+ layout_operand_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>,
+ layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>
+ }
+ : vector<24x1xf32> to vector<8x1xf32>
+ gpu.yield %1 : vector<8x1xf32>
+ }
+ "some_use"(%r) : (vector<8x1xf32>) -> ()
+ gpu.return
+}
+
+// CHECK-LABEL: gpu.func @vector_extract_strided_slice_inner_distributed
+// CHECK: %[[W:.*]]:2 = gpu.warp_execute_on_lane_0(%{{.*}})[16] -> (vector<8x1xf32>, vector<24x4xf32>) {
+// CHECK-NEXT: %[[S:.*]] = "some_def"() : () -> vector<24x64xf32>
+// CHECK: gpu.yield %{{.*}}, %[[S]] : vector<8x16xf32>, vector<24x64xf32>
+// CHECK-NEXT: }
+// CHECK-NEXT: %[[T1:.*]] = vector.extract_strided_slice %[[W]]#1
+// CHECK-SAME: {offsets = [8, 3], sizes = [8, 1], strides = [1, 1]} : vector<24x4xf32> to vector<8x1xf32>
+// CHECK-NEXT: "some_use"(%[[T1]]) : (vector<8x1xf32>) -> ()
+gpu.func @vector_extract_strided_slice_inner_distributed(%laneid: index) {
+ %r = gpu.warp_execute_on_lane_0(%laneid)[16] -> (vector<8x1xf32>) {
+ %0 = "some_def"() : () -> (vector<24x64xf32>)
+ %1 = vector.extract_strided_slice %0 { offsets = [8, 48], sizes = [8, 16], strides = [1, 1],
+ layout_operand_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>,
+ layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>
+ }
+ : vector<24x64xf32> to vector<8x16xf32>
+ gpu.yield %1 : vector<8x16xf32>
+ }
+ "some_use"(%r) : (vector<8x1xf32>) -> ()
+ gpu.return
+}
+
+// CHECK-LABEL: gpu.func @vector_extract_strided_slice_outer_distributed
+// CHECK: %[[W:.*]]:2 = gpu.warp_execute_on_lane_0(%{{.*}})[16] -> (vector<1x16xf32>, vector<2x16xf32>) {
+// CHECK-NEXT: %[[S:.*]] = "some_def"() : () -> vector<32x16xf32>
+// CHECK: gpu.yield %{{.*}}, %[[S]] : vector<16x16xf32>, vector<32x16xf32>
+// CHECK: }
+// CHECK-NEXT: %[[T1:.*]] = vector.extract %[[W]]#1[1] : vector<16xf32> from vector<2x16xf32>
+// CHECK-NEXT: %[[T2:.*]] = vector.shape_cast %[[T1]] : vector<16xf32> to vector<1x16xf32>
+// CHECK-NEXT: "some_use"(%[[T2]]) : (vector<1x16xf32>) -> ()
+gpu.func @vector_extract_strided_slice_outer_distributed(%laneid: index) {
+ %r = gpu.warp_execute_on_lane_0(%laneid)[16] -> (vector<1x16xf32>) {
+ %0 = "some_def"() : () -> (vector<32x16xf32>)
+ %1 = vector.extract_strided_slice %0 { offsets = [16], sizes = [16], strides = [1],
+ layout_operand_0 = #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 1]>,
+ layout_result_0 = #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 1]>
+ }
+ : vector<32x16xf32> to vector<16x16xf32>
+ gpu.yield %1 : vector<16x16xf32>
+ }
+ "some_use"(%r) : (vector<1x16xf32>) -> ()
+ gpu.return
+}
+
+// CHECK-LABEL: gpu.func @vector_extract_strided_slice_1d
+// CHECK: %[[W:.*]]:2 = gpu.warp_execute_on_lane_0(%{{.*}})[16] -> (vector<2xf32>, vector<4xf32>) {
+// CHECK: %[[S:.*]] = "some_def"() : () -> vector<64xf32>
+// CHECK: gpu.yield %{{.*}}, %[[S]] : vector<32xf32>, vector<64xf32>
+// CHECK-NEXT: }
+// CHECK-NEXT: %[[T1:.*]] = vector.extract_strided_slice %[[W]]#1
+// CHECK-SAME: {offsets = [1], sizes = [2], strides = [1]} : vector<4xf32> to vector<2xf32>
+// CHECK-NEXT: "some_use"(%[[T1]]) : (vector<2xf32>) -> ()
+gpu.func @vector_extract_strided_slice_1d(%laneid: index) {
+ %r = gpu.warp_execute_on_lane_0(%laneid)[16] -> (vector<2xf32>) {
+ %0 = "some_def"() : () -> (vector<64xf32>)
+ %1 = vector.extract_strided_slice %0 { offsets = [16], sizes = [32], strides = [1],
+ layout_operand_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>,
+ layout_result_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>
+ }
+ : vector<64xf32> to vector<32xf32>
+ gpu.yield %1 : vector<32xf32>
+ }
+ "some_use"(%r) : (vector<2xf32>) -> ()
+ gpu.return
+}
+
+// CHECK-LABEL: gpu.func @vector_extract_strided_slice_unsupported_offset
+// CHECK: %{{.*}} = gpu.warp_execute_on_lane_0(%{{.*}})[16] -> (vector<2xf32>) {
+// CHECK: }
+// CHECK-NOT: %{{.*}} = vector.extract_strided_slice
+gpu.func @vector_extract_strided_slice_unsupported_offset(%laneid: index) {
+ %r = gpu.warp_execute_on_lane_0(%laneid)[16] -> (vector<2xf32>) {
+ %0 = "some_def"() : () -> (vector<64xf32>)
+ %1 = vector.extract_strided_slice %0 { offsets = [3], sizes = [32], strides = [1],
+ layout_operand_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>,
+ layout_result_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>
+ }
+ : vector<64xf32> to vector<32xf32>
+ gpu.yield %1 : vector<32xf32>
+ }
+ "some_use"(%r) : (vector<2xf32>) -> ()
+ gpu.return
+}
+
+// CHECK-LABEL: gpu.func @vector_extract_strided_slice_unsupported_source
+// CHECK: %{{.*}} = gpu.warp_execute_on_lane_0(%{{.*}})[16] -> (vector<2xf32>) {
+// CHECK: }
+// CHECK-NOT: %{{.*}} = vector.extract_strided_slice
+gpu.func @vector_extract_strided_slice_unsupported_source(%laneid: index) {
+ %r = gpu.warp_execute_on_lane_0(%laneid)[16] -> (vector<2xf32>) {
+ %0 = "some_def"() : () -> (vector<54xf32>)
+ %1 = vector.extract_strided_slice %0 { offsets = [0], sizes = [32], strides = [1],
+ layout_operand_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>,
+ layout_result_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>
+ }
+ : vector<54xf32> to vector<32xf32>
+ gpu.yield %1 : vector<32xf32>
+ }
+ "some_use"(%r) : (vector<2xf32>) -> ()
+ gpu.return
+}
+
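+// The distributed (inner) dimension is fully inserted (16 into 16), so the
+// insertion offsets are unchanged after distribution.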
+// CHECK-LABEL: gpu.func @vector_insert_strided_slice_distributed_dim_fully_inserted
+// CHECK-NEXT: %[[W:.*]]:3 = gpu.warp_execute_on_lane_0(%{{.*}})[16] -> (vector<64x1xf32>, vector<16x1xf32>, vector<64x1xf32>) {
+// CHECK-NEXT: %[[S:.*]] = "some_def"() : () -> vector<16x16xf32>
+// CHECK-NEXT: %[[D:.*]] = "some_def"() : () -> vector<64x16xf32>
+// CHECK: gpu.yield %{{.*}}, %[[S]], %[[D]] : vector<64x16xf32>, vector<16x16xf32>, vector<64x16xf32>
+// CHECK-NEXT: }
+// CHECK-NEXT: %[[T1:.*]] = vector.insert_strided_slice %[[W]]#1, %[[W]]#2
+// CHECK-SAME: {offsets = [24, 0], strides = [1, 1]} : vector<16x1xf32> into vector<64x1xf32>
+// CHECK-NEXT: "some_use"(%[[T1]]) : (vector<64x1xf32>) -> ()
+gpu.func @vector_insert_strided_slice_distributed_dim_fully_inserted(%laneid: index) {
+ %r = gpu.warp_execute_on_lane_0(%laneid)[16] -> (vector<64x1xf32>) {
+ %0 = "some_def"() : () -> (vector<16x16xf32>)
+ %1 = "some_def"() : () -> (vector<64x16xf32>)
+ %2 = vector.insert_strided_slice %0, %1 { offsets = [24, 0], strides = [1, 1],
+ layout_operand_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>,
+ layout_operand_1 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>,
+ layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>
+ }
+ : vector<16x16xf32> into vector<64x16xf32>
+ gpu.yield %2 : vector<64x16xf32>
+ }
+ "some_use"(%r) : (vector<64x1xf32>) -> ()
+ gpu.return
+}
+
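+// The source and destination are already per-lane sized (the distributed dimension
+// has size 1), so distribution only sinks the op out of the warp region.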
+// CHECK-LABEL: gpu.func @vector_insert_strided_slice_non_distributed
+// CHECK-NEXT: %[[W:.*]]:3 = gpu.warp_execute_on_lane_0(%{{.*}})[16] -> (vector<64x1xf32>, vector<16x1xf32>, vector<64x1xf32>) {
+// CHECK-NEXT: %[[S:.*]] = "some_def"() : () -> vector<16x1xf32>
+// CHECK-NEXT: %[[D:.*]] = "some_def"() : () -> vector<64x1xf32>
+// CHECK: gpu.yield %{{.*}}, %[[S]], %[[D]] : vector<64x1xf32>, vector<16x1xf32>, vector<64x1xf32>
+// CHECK-NEXT: }
+// CHECK-NEXT: %[[T1:.*]] = vector.insert_strided_slice %[[W]]#1, %[[W]]#2
+// CHECK-SAME: {offsets = [24, 0], strides = [1, 1]} : vector<16x1xf32> into vector<64x1xf32>
+// CHECK-NEXT: "some_use"(%[[T1]]) : (vector<64x1xf32>) -> ()
+gpu.func @vector_insert_strided_slice_non_distributed(%laneid: index) {
+ %r = gpu.warp_execute_on_lane_0(%laneid)[16] -> (vector<64x1xf32>) {
+ %0 = "some_def"() : () -> (vector<16x1xf32>)
+ %1 = "some_def"() : () -> (vector<64x1xf32>)
+ %2 = vector.insert_strided_slice %0, %1 { offsets = [24, 0], strides = [1, 1],
+ layout_operand_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>,
+ layout_operand_1 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>,
+ layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>
}
- "some_user_op"(%r) : (vector<1x1xf32>) -> ()
- gpu.return
+ : vector<16x1xf32> into vector<64x1xf32>
+ gpu.yield %2 : vector<64x1xf32>
}
+ "some_use"(%r) : (vector<64x1xf32>) -> ()
+ gpu.return
+}
+
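+// The distributed (inner) dimension is partially inserted: vector<16x16xf32> into
+// vector<64x32xf32> becomes vector<16x1xf32> into vector<64x2xf32> per lane, with
+// the column offset 16 mapping to a per-lane offset of 1.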
+// CHECK-LABEL: gpu.func @vector_insert_strided_slice_inner_distributed
+// CHECK: %[[W:.*]]:3 = gpu.warp_execute_on_lane_0(%{{.*}})[16] -> (vector<64x2xf32>, vector<16x1xf32>, vector<64x2xf32>) {
+// CHECK-NEXT: %[[S:.*]] = "some_def"() : () -> vector<16x16xf32>
+// CHECK-NEXT: %[[D:.*]] = "some_def"() : () -> vector<64x32xf32>
+// CHECK: gpu.yield %{{.*}}, %[[S]], %[[D]] : vector<64x32xf32>, vector<16x16xf32>, vector<64x32xf32>
+// CHECK-NEXT: }
+// CHECK-NEXT: %[[T1:.*]] = vector.insert_strided_slice %[[W]]#1, %[[W]]#2
+// CHECK-SAME: {offsets = [24, 1], strides = [1, 1]} : vector<16x1xf32> into vector<64x2xf32>
+// CHECK-NEXT: "some_use"(%[[T1]]) : (vector<64x2xf32>) -> ()
+gpu.func @vector_insert_strided_slice_inner_distributed(%laneid: index) {
+ %r = gpu.warp_execute_on_lane_0(%laneid)[16] -> (vector<64x2xf32>) {
+ %0 = "some_def"() : () -> (vector<16x16xf32>)
+ %1 = "some_def"() : () -> (vector<64x32xf32>)
+ %2 = vector.insert_strided_slice %0, %1 { offsets = [24, 16], strides = [1, 1],
+ layout_operand_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>,
+ layout_operand_1 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>,
+ layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>
+ }
+ : vector<16x16xf32> into vector<64x32xf32>
+ gpu.yield %2 : vector<64x32xf32>
+ }
+ "some_use"(%r) : (vector<64x2xf32>) -> ()
+ gpu.return
+}
+
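+// The outer dimension is distributed (lane_layout = [16, 1]): the destination
+// distributes to vector<3x32xf32> per lane and the row offset 32 maps to a
+// per-lane offset of 32 / 16 = 2.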
+// CHECK-LABEL: gpu.func @vector_insert_strided_slice_outer_distributed
+// CHECK: %[[W:.*]]:3 = gpu.warp_execute_on_lane_0(%{{.*}})[16] -> (vector<3x32xf32>, vector<1x16xf32>, vector<3x32xf32>) {
+// CHECK-NEXT: %[[S:.*]] = "some_def"() : () -> vector<16x16xf32>
+// CHECK-NEXT: %[[D:.*]] = "some_def"() : () -> vector<48x32xf32>
+// CHECK: gpu.yield %{{.*}}, %[[S]], %[[D]] : vector<48x32xf32>, vector<16x16xf32>, vector<48x32xf32>
+// CHECK-NEXT: }
+// CHECK-NEXT: %[[T1:.*]] = vector.insert_strided_slice %[[W]]#1, %[[W]]#2
+// CHECK-SAME: {offsets = [2, 4], strides = [1, 1]} : vector<1x16xf32> into vector<3x32xf32>
+// CHECK-NEXT: "some_use"(%[[T1]]) : (vector<3x32xf32>) -> ()
+gpu.func @vector_insert_strided_slice_outer_distributed(%laneid: index) {
+ %r = gpu.warp_execute_on_lane_0(%laneid)[16] -> (vector<3x32xf32>) {
+ %0 = "some_def"() : () -> (vector<16x16xf32>)
+ %1 = "some_def"() : () -> (vector<48x32xf32>)
+ %2 = vector.insert_strided_slice %0, %1 { offsets = [32, 4], strides = [1, 1],
+ layout_operand_0 = #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 1]>,
+ layout_operand_1 = #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 1]>,
+ layout_result_0 = #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 1]>
+ }
+ : vector<16x16xf32> into vector<48x32xf32>
+ gpu.yield %2 : vector<48x32xf32>
+ }
+ "some_use"(%r) : (vector<3x32xf32>) -> ()
+ gpu.return
+}
+
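+// 1-D case: vector<16xf32> into vector<48xf32> becomes vector<1xf32> into
+// vector<3xf32> per lane, with the offset 16 mapping to a per-lane offset of 1.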
+// CHECK-LABEL: gpu.func @vector_insert_strided_slice_1d
+// CHECK: %[[W:.*]]:3 = gpu.warp_execute_on_lane_0(%{{.*}})[16] -> (vector<3xf32>, vector<1xf32>, vector<3xf32>) {
+// CHECK-NEXT: %[[S:.*]] = "some_def"() : () -> vector<16xf32>
+// CHECK-NEXT: %[[D:.*]] = "some_def"() : () -> vector<48xf32>
+// CHECK: gpu.yield %{{.*}}, %[[S]], %[[D]] : vector<48xf32>, vector<16xf32>, vector<48xf32>
+// CHECK-NEXT: }
+// CHECK-NEXT: %[[T1:.*]] = vector.insert_strided_slice %[[W]]#1, %[[W]]#2
+// CHECK-SAME: {offsets = [1], strides = [1]} : vector<1xf32> into vector<3xf32>
+// CHECK-NEXT: "some_use"(%[[T1]]) : (vector<3xf32>) -> ()
+gpu.func @vector_insert_strided_slice_1d(%laneid: index) {
+ %r = gpu.warp_execute_on_lane_0(%laneid)[16] -> (vector<3xf32>) {
+ %0 = "some_def"() : () -> (vector<16xf32>)
+ %1 = "some_def"() : () -> (vector<48xf32>)
+ %2 = vector.insert_strided_slice %0, %1 { offsets = [16], strides = [1],
+ layout_operand_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>,
+ layout_operand_1 = #xegpu.layout<lane_layout = [16], lane_data = [1]>,
+ layout_result_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>
+ }
+ : vector<16xf32> into vector<48xf32>
+ gpu.yield %2 : vector<48xf32>
+ }
+ "some_use"(%r) : (vector<3xf32>) -> ()
+ gpu.return
+}
+
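+// Negative case: the source length (8) is not divisible by the subgroup size, so
+// the op is expected to stay inside the warp region.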
+// CHECK-LABEL: gpu.func @vector_insert_strided_slice_unsupported_source
+// CHECK: %{{.*}} = gpu.warp_execute_on_lane_0(%{{.*}})[16] -> (vector<3xf32>) {
+// CHECK: }
+// CHECK-NOT: %{{.*}} = vector.insert_strided_slice
+gpu.func @vector_insert_strided_slice_unsupported_source(%laneid: index) {
+ %r = gpu.warp_execute_on_lane_0(%laneid)[16] -> (vector<3xf32>) {
+ %0 = "some_def"() : () -> (vector<8xf32>)
+ %1 = "some_def"() : () -> (vector<48xf32>)
+ %2 = vector.insert_strided_slice %0, %1 { offsets = [16], strides = [1],
+ layout_operand_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>,
+ layout_operand_1 = #xegpu.layout<lane_layout = [16], lane_data = [1]>,
+ layout_result_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>
+ }
+ : vector<8xf32> into vector<48xf32>
+ gpu.yield %2 : vector<48xf32>
+ }
+ "some_use"(%r) : (vector<3xf32>) -> ()
+ gpu.return
+}
+
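+// Negative case: the offset (3) is not a multiple of the subgroup size, so the op
+// is expected to stay inside the warp region.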
+// CHECK-LABEL: gpu.func @vector_insert_strided_slice_unsupported_offset
+// CHECK: %{{.*}} = gpu.warp_execute_on_lane_0(%{{.*}})[16] -> (vector<3xf32>) {
+// CHECK: }
+// CHECK-NOT: %{{.*}} = vector.insert_strided_slice
+gpu.func @vector_insert_strided_slice_unsupported_offset(%laneid: index) {
+ %r = gpu.warp_execute_on_lane_0(%laneid)[16] -> (vector<3xf32>) {
+ %0 = "some_def"() : () -> (vector<16xf32>)
+ %1 = "some_def"() : () -> (vector<48xf32>)
+ %2 = vector.insert_strided_slice %0, %1 { offsets = [3], strides = [1],
+ layout_operand_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>,
+ layout_operand_1 = #xegpu.layout<lane_layout = [16], lane_data = [1]>,
+ layout_result_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>
+ }
+ : vector<16xf32> into vector<48xf32>
+ gpu.yield %2 : vector<48xf32>
+ }
+ "some_use"(%r) : (vector<3xf32>) -> ()
+ gpu.return
+}
+
}