[Mlir-commits] [mlir] b86fef8 - [mlir][xegpu] Create a test pass for subgroup distribution. (#161592)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Fri Oct 3 12:35:17 PDT 2025
Author: Charitha Saumya
Date: 2025-10-03T12:35:13-07:00
New Revision: b86fef88c54b20812ed3f1a67869dcaa54a07882
URL: https://github.com/llvm/llvm-project/commit/b86fef88c54b20812ed3f1a67869dcaa54a07882
DIFF: https://github.com/llvm/llvm-project/commit/b86fef88c54b20812ed3f1a67869dcaa54a07882.diff
LOG: [mlir][xegpu] Create a test pass for subgroup distribution. (#161592)
Current subgroup distribution tests employ the entire
`xegpu-subgroup-distribute` pass, which includes multiple steps such as
layout propagation, moving the function body into a warp op, and
distributing to work items.
This makes it harder to test the xegpu subgroup distribution logic in
isolation, because certain corner cases may not yet be supported by the
other steps mentioned above.
This PR introduces a test pass for the subgroup distribution logic and
isolates the testing of that logic. We plan to add more corner cases
(that were not possible before) covering non-xegpu ops (like vector) in
follow-up PRs.
This PR also includes:
1. Minor bug fixes in gather/scatter distribution.
2. A bug fix in vector multi-reduction lowering, where it failed to retain
some layouts.
Added:
mlir/test/Dialect/XeGPU/subgroup-distribute-unit.mlir
Modified:
mlir/include/mlir/Dialect/XeGPU/Transforms/Passes.td
mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
mlir/test/Dialect/XeGPU/subgroup-distribute.mlir
mlir/test/lib/Dialect/XeGPU/TestXeGPUTransforms.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/XeGPU/Transforms/Passes.td b/mlir/include/mlir/Dialect/XeGPU/Transforms/Passes.td
index 83b128e2c7cbf..564d9c4d5422b 100644
--- a/mlir/include/mlir/Dialect/XeGPU/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/XeGPU/Transforms/Passes.td
@@ -27,10 +27,6 @@ def XeGPUSubgroupDistribute : Pass<"xegpu-subgroup-distribute"> {
}];
let dependentDialects = ["memref::MemRefDialect", "xegpu::XeGPUDialect",
"vector::VectorDialect"];
- let options = [Option<
- "enableSGReductions", "enable-sg-reductions", "bool",
- /*default=*/"true",
- "Enable subgroup reductions using subgroup shuffles.">];
}
def XeGPUPropagateLayout : Pass<"xegpu-propagate-layout"> {
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
index 882691fd19f58..f1dbc5ddb2022 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
@@ -875,14 +875,17 @@ struct StoreDistribution final : public gpu::WarpDistributionPattern {
storeScatterOp,
"Some vector operands have no layouts, using defaults instead.");
}
- VectorType distPayloadTy = distStoreVecByWarpOpOrFailure.value();
- VectorType expectedPayloadTy = VectorType::get(
- {distPayloadTy.getNumElements()}, distPayloadTy.getElementType());
+ // Distributed store payload type according to the lane layout.
+ VectorType distPayloadTyByWarpOp = distStoreVecByWarpOpOrFailure.value();
+ // Expected distributed payload type is always 1D.
+ VectorType expectedPayloadTy =
+ VectorType::get({distPayloadTyByWarpOp.getNumElements()},
+ distPayloadTyByWarpOp.getElementType());
SmallVector<size_t> newRetIndices;
SmallVector<Value> operands = storeScatterOp->getOperands();
SmallVector<Type> operandTypesToYield = {
- expectedPayloadTy, operands[1].getType(),
+ distPayloadTyByWarpOp, operands[1].getType(),
distOffsetsByWarpOpOrFailure.value(),
distMaskByWarpOpOrFailure.value()};
@@ -890,8 +893,11 @@ struct StoreDistribution final : public gpu::WarpDistributionPattern {
rewriter, warpOp, operands, operandTypesToYield, newRetIndices);
SmallVector<Value> newStoreScatterOpOperands = llvm::map_to_vector(
newRetIndices, [&](size_t idx) { return newWarpOp.getResult(idx); });
-
+ // The payload operand may need type adjustment due to mismatch between warp
+ // distributed type and expected SIMT type.
rewriter.setInsertionPointAfter(newWarpOp);
+ newStoreScatterOpOperands[0] = resolveDistributedTy(
+ newStoreScatterOpOperands[0], expectedPayloadTy, rewriter);
xegpu::StoreScatterOp newOp = xegpu::StoreScatterOp::create(
rewriter, newWarpOp.getLoc(), TypeRange{}, newStoreScatterOpOperands,
storeScatterOp->getAttrs());
@@ -976,8 +982,11 @@ struct LoadDistribution final : public gpu::WarpDistributionPattern {
distMaskByWarpOpOrFailure.value()};
const unsigned operandIdx = producedByLastLoad->getOperandNumber();
- VectorType loadVecTy =
+ VectorType distResultTy =
cast<VectorType>(warpOp.getResult(operandIdx).getType());
+ // Distributed load op will always be 1D.
+ VectorType loadVecTy = VectorType::get({distResultTy.getNumElements()},
+ distResultTy.getElementType());
gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
rewriter, warpOp, operands, operandTypesToYield, newRetIndices);
@@ -991,13 +1000,16 @@ struct LoadDistribution final : public gpu::WarpDistributionPattern {
loadGatherOp->getAttrs());
xegpu::removeLayoutAttrs(newOp);
Value distributedVal = newWarpOp.getResult(operandIdx);
- rewriter.replaceAllUsesWith(distributedVal, newOp->getResult(0));
+ // Resolve the output type and replace all uses.
+ rewriter.replaceAllUsesWith(
+ distributedVal,
+ resolveDistributedTy(newOp.getResult(), distResultTy, rewriter));
return success();
}
};
/// Helper to rewrite a 2D VectorMultiReductionOp into a sequence of 1D
-/// VectorReductionOps.
+/// VectorReductionOps. We also insert layouts for the newly created ops.
static Value lowerToVectorReductions(TypedValue<VectorType> src,
TypedValue<VectorType> acc,
vector::CombiningKind kind,
@@ -1014,6 +1026,9 @@ static Value lowerToVectorReductions(TypedValue<VectorType> src,
Value reductionResult = arith::ConstantOp::create(
rewriter, loc, acc.getType(),
DenseElementsAttr::get(acc.getType(), zeroAttr));
+ // Reduction result should have the same layout as the accumulator.
+ xegpu::setDistributeLayoutAttr(cast<OpResult>(reductionResult),
+ xegpu::getDistributeLayoutAttr(acc));
// For each slice of the source, extract the slice vector, do a reduction
// and, insert the reduced value back to the result vector.
for (int i = 0; i < nSlices; ++i) {
@@ -1029,13 +1044,23 @@ static Value lowerToVectorReductions(TypedValue<VectorType> src,
vector::ExtractStridedSliceOp::create(rewriter, loc, src, sliceOffsets,
sliceSizes, {1, 1});
int64_t nSliceElements = extractOp.getResult().getType().getNumElements();
- Value slice = vector::ShapeCastOp::create(
+ vector::ShapeCastOp slice = vector::ShapeCastOp::create(
rewriter, loc,
VectorType::get({nSliceElements}, sourceType.getElementType()),
extractOp.getResult());
+ // Shape cast is currently handled in xegpu side. So layouts must be
+ // retained during lowering. Shape cast output has the same layout as the
+ // accumulator. Shape cast source has the same layout as the original
+ // reduction source.
+ // TODO: other ops generated here may also need layout attributes.
+ xegpu::setDistributeLayoutAttr(slice->getOpOperand(0),
+ xegpu::getDistributeLayoutAttr(src));
+ xegpu::setDistributeLayoutAttr(slice->getOpResult(0),
+ xegpu::getDistributeLayoutAttr(acc));
+ // Extract and reduction results in scalars, so no result layout is needed.
Value accExtract = vector::ExtractOp::create(rewriter, loc, acc, i);
- Value reduction =
- vector::ReductionOp::create(rewriter, loc, kind, slice, accExtract);
+ Value reduction = vector::ReductionOp::create(
+ rewriter, loc, kind, slice.getResult(), accExtract);
reductionResult =
vector::InsertOp::create(rewriter, loc, reduction, reductionResult, i);
}
@@ -1107,7 +1132,7 @@ struct VectorMultiReductionDistribution : public gpu::WarpDistributionPattern {
return failure();
auto reductionOp =
cast<vector::MultiDimReductionOp>(yieldOperand->get().getDefiningOp());
- unsigned operandNumber = yieldOperand->getOperandNumber();
+ unsigned operandIdx = yieldOperand->getOperandNumber();
VectorType sourceType = reductionOp.getSourceVectorType();
// Only 2D vectors are supported.
if (sourceType.getRank() != 2)
@@ -1121,7 +1146,7 @@ struct VectorMultiReductionDistribution : public gpu::WarpDistributionPattern {
warpOp, "Only 1 reduction dimension is supported.");
int64_t reductionDim = reductionDims[0];
VectorType distributedResultType =
- cast<VectorType>(warpOp.getResult(operandNumber).getType());
+ cast<VectorType>(warpOp.getResult(operandIdx).getType());
VectorType resultType = cast<VectorType>(reductionOp.getType());
xegpu::DistributeLayoutAttr sourceLayout =
xegpu::getDistributeLayoutAttr(reductionOp.getSource());
@@ -1184,7 +1209,7 @@ struct VectorMultiReductionDistribution : public gpu::WarpDistributionPattern {
cast<TypedValue<VectorType>>(newWarpOp->getResult(newRetIndices[1])),
reductionOp.getKind(), reductionDim, reductionOp.getLoc(), rewriter);
// Replace the warp op result with the final result.
- rewriter.replaceAllUsesWith(reductionOp.getResult(), result);
+ rewriter.replaceAllUsesWith(newWarpOp.getResult(operandIdx), result);
return success();
}
// For non-lane-local case, we simply rewrite the MultiReductionOp in terms
@@ -1217,7 +1242,7 @@ struct VectorShapeCastDistribution : public gpu::WarpDistributionPattern {
auto resultDistTy =
cast<VectorType>(warpOp.getResult(operandNumber).getType());
xegpu::DistributeLayoutAttr sourceLayout =
- xegpu::getDistributeLayoutAttr(shapeCastOp.getSource());
+ xegpu::getDistributeLayoutAttr(shapeCastOp->getOpOperand(0));
xegpu::DistributeLayoutAttr resultLayout =
xegpu::getDistributeLayoutAttr(shapeCastOp.getResult());
if (!sourceLayout || !resultLayout)
@@ -1403,11 +1428,6 @@ namespace {
struct XeGPUSubgroupDistributePass final
: public xegpu::impl::XeGPUSubgroupDistributeBase<
XeGPUSubgroupDistributePass> {
- XeGPUSubgroupDistributePass() = default;
- XeGPUSubgroupDistributePass(const XeGPUSubgroupDistributePass &other) =
- default;
- XeGPUSubgroupDistributePass(xegpu::XeGPUSubgroupDistributeOptions options)
- : XeGPUSubgroupDistributeBase(options) {}
void runOnOperation() override;
};
} // namespace
@@ -1515,10 +1535,9 @@ void XeGPUSubgroupDistributePass::runOnOperation() {
return laneVal;
};
- if (enableSGReductions)
- vector::populateDistributeReduction(
- patterns, warpReduction,
- /*pattern benefit=*/regularPatternBenefit);
+ vector::populateDistributeReduction(
+ patterns, warpReduction,
+ /*pattern benefit=*/regularPatternBenefit);
vector::populatePropagateWarpVectorDistributionPatterns(
patterns, distributionFn, shuffleFn,
diff --git a/mlir/test/Dialect/XeGPU/subgroup-distribute-unit.mlir b/mlir/test/Dialect/XeGPU/subgroup-distribute-unit.mlir
new file mode 100644
index 0000000000000..40b66d18cc47f
--- /dev/null
+++ b/mlir/test/Dialect/XeGPU/subgroup-distribute-unit.mlir
@@ -0,0 +1,575 @@
+// RUN: mlir-opt --xevm-attach-target='module=xevm_* chip=pvc' -test-xegpu-sg-distribute -allow-unregistered-dialect \
+// RUN: -canonicalize -cse -split-input-file %s | FileCheck %s
+
+// CHECK-LABEL: gpu.func @store_nd_1d
+// CHECK: (%[[ARG0:[0-9a-zA-Z]+]]: index) {
+// CHECK: %[[W:.*]]:3 = gpu.warp_execute_on_lane_0(%[[ARG0]])[16]
+// CHECK-SAME: -> (vector<1xf32>, !xegpu.tensor_desc<16xf32, #xegpu.layout<lane_layout = [16], lane_data = [1]>>, index) {
+// CHECK: gpu.yield %{{.*}} : vector<16xf32>,
+// CHECK-SAME: !xegpu.tensor_desc<16xf32, #xegpu.layout<lane_layout = [16], lane_data = [1]>>, index
+// CHECK-NEXT: }
+// CHECK-NEXT: %[[T1:.*]] = builtin.unrealized_conversion_cast %[[W]]#1 : !xegpu.tensor_desc<16xf32,
+// CHECK-SAME: #xegpu.layout<lane_layout = [16], lane_data = [1]>> to !xegpu.tensor_desc<16xf32> {resolve_simt_type_mismatch}
+// CHECK-NEXT: xegpu.store_nd %[[W]]#0, %[[T1]][%[[W]]#2] : vector<1xf32>, !xegpu.tensor_desc<16xf32>
+gpu.module @xevm_module{
+ gpu.func @store_nd_1d(%laneid: index) {
+ %c0 = arith.constant 0 : index
+ gpu.warp_execute_on_lane_0(%laneid)[16] {
+ %0 = "some_op"() : () -> !xegpu.tensor_desc<16xf32, #xegpu.layout<lane_layout = [16], lane_data = [1]>>
+ %cst = "some_op"() : () -> vector<16xf32>
+ xegpu.store_nd %cst, %0 [%c0] {layout_operand_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>}
+ : vector<16xf32>, !xegpu.tensor_desc<16xf32, #xegpu.layout<lane_layout = [16], lane_data = [1]>>
+ }
+ gpu.return
+ }
+}
+
+// -----
+// CHECK-LABEL: gpu.func @store_nd_2d
+// CHECK: (%[[ARG0:[0-9a-zA-Z]+]]: index) {
+// CHECK: %[[W:.*]]:4 = gpu.warp_execute_on_lane_0(%[[ARG0]])[16]
+// CHECK-SAME: -> (vector<16x1xf16>, !xegpu.tensor_desc<16x16xf16,
+// CHECK-SAME: #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>, index, index) {
+// CHECK: gpu.yield %{{.*}} : vector<16x16xf16>, !xegpu.tensor_desc<16x16xf16,
+// CHECK-SAME: #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>, index, index
+// CHECK-NEXT: }
+// CHECK-NEXT: %[[CAST:.*]] = vector.shape_cast %[[W]]#0 : vector<16x1xf16> to vector<16xf16>
+// CHECK-NEXT: %[[T1:.*]] = builtin.unrealized_conversion_cast %[[W]]#1 : !xegpu.tensor_desc<16x16xf16,
+// CHECK-SAME: #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>> to !xegpu.tensor_desc<16x16xf16> {resolve_simt_type_mismatch}
+// CHECK-NEXT: xegpu.store_nd %[[CAST]], %[[T1]][%[[W]]#2, %[[W]]#3] : vector<16xf16>, !xegpu.tensor_desc<16x16xf16>
+gpu.module @xevm_module{
+ gpu.func @store_nd_2d(%laneid : index) {
+ %c0 = arith.constant 0 : index
+ gpu.warp_execute_on_lane_0(%laneid)[16] {
+ %0 = "some_op"() : () -> !xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
+ %cst = "some_op"() : () -> vector<16x16xf16>
+ xegpu.store_nd %cst, %0 [%c0, %c0] {layout_operand_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>}
+ : vector<16x16xf16>, !xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
+ }
+ gpu.return
+ }
+}
+
+
+
+// -----
+// CHECK-LABEL: gpu.func @load_nd_1d
+// CHECK: (%[[ARG0:[0-9a-zA-Z]+]]: index) {
+// CHECK: %[[W:.*]]:3 = gpu.warp_execute_on_lane_0(%[[ARG0]])[16] -> (vector<1xf32>,
+// CHECK-SAME: !xegpu.tensor_desc<16xf32, #xegpu.layout<lane_layout = [16], lane_data = [1]>>, index) {
+// CHECK: gpu.yield %{{.*}} : vector<16xf32>, !xegpu.tensor_desc<16xf32,
+// CHECK-SAME: #xegpu.layout<lane_layout = [16], lane_data = [1]>>, index
+// CHECK-NEXT: }
+// CHECK-NEXT: %[[T1:.*]] = builtin.unrealized_conversion_cast %[[W]]#1 : !xegpu.tensor_desc<16xf32,
+// CHECK-SAME: #xegpu.layout<lane_layout = [16], lane_data = [1]>> to !xegpu.tensor_desc<16xf32> {resolve_simt_type_mismatch}
+// CHECK-NEXT: xegpu.load_nd %[[T1]][%[[W]]#2] : !xegpu.tensor_desc<16xf32> -> vector<1xf32>
+gpu.module @xevm_module{
+ gpu.func @load_nd_1d(%laneid: index) {
+ %c0 = arith.constant 0 : index
+ %r = gpu.warp_execute_on_lane_0(%laneid)[16] -> (vector<1xf32>) {
+ %0 = "some_op"() : () -> !xegpu.tensor_desc<16xf32, #xegpu.layout<lane_layout = [16], lane_data = [1]>>
+ %1 = xegpu.load_nd %0 [%c0] {layout_result_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>} :
+ !xegpu.tensor_desc<16xf32, #xegpu.layout<lane_layout = [16], lane_data = [1]>> -> vector<16xf32>
+ gpu.yield %1 : vector<16xf32>
+ }
+ "some_user_op"(%r) : (vector<1xf32>) -> ()
+ gpu.return
+ }
+}
+
+// -----
+// CHECK-LABEL: gpu.func @load_nd_2d
+// CHECK: (%[[ARG0:[0-9a-zA-Z]+]]: index) {
+// CHECK: %[[W:.*]]:4 = gpu.warp_execute_on_lane_0(%[[ARG0]])[16] -> (vector<16x1xf16>, !xegpu.tensor_desc<16x16xf16,
+// CHECK-SAME: #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>, index, index) {
+// CHECK: gpu.yield %{{.*}} : vector<16x16xf16>, !xegpu.tensor_desc<16x16xf16,
+// CHECK-SAME: #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>, index, index
+// CHECK-NEXT: }
+// CHECK-NEXT: %[[T1:.*]] = builtin.unrealized_conversion_cast %[[W]]#1 : !xegpu.tensor_desc<16x16xf16,
+// CHECK-SAME: #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>> to !xegpu.tensor_desc<16x16xf16> {resolve_simt_type_mismatch}
+// CHECK-NEXT: %[[T2:.*]] = xegpu.load_nd %[[T1]][%[[W]]#2, %[[W]]#3] : !xegpu.tensor_desc<16x16xf16> -> vector<16xf16>
+// CHECK: vector.shape_cast %[[T2]] : vector<16xf16> to vector<16x1xf16>
+gpu.module @xevm_module{
+ gpu.func @load_nd_2d(%laneid: index) {
+ %c0 = arith.constant 0 : index
+ %r = gpu.warp_execute_on_lane_0(%laneid)[16] -> (vector<16x1xf16>) {
+ %0 = "some_op"() : () -> !xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
+ %1 = xegpu.load_nd %0[%c0, %c0] {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>}
+ : !xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>> -> vector<16x16xf16>
+ gpu.yield %1 : vector<16x16xf16>
+ }
+ "some_user_op"(%r) : (vector<16x1xf16>) -> ()
+ gpu.return
+ }
+}
+
+// -----
+// CHECK-LABEL: gpu.func @load_nd_array_length
+// CHECK: (%[[ARG0:[0-9a-zA-Z]+]]: index) {
+// CHECK: %[[W:.*]]:4 = gpu.warp_execute_on_lane_0(%[[ARG0]])[16] -> (vector<2x16x1xf16>,
+// CHECK-SAME: !xegpu.tensor_desc<16x16xf16, #xegpu.block_tdesc_attr<array_length = 2 : i64>,
+// CHECK-SAME: #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>, index, index) {
+// CHECK: gpu.yield %{{.*}} : vector<2x16x16xf16>, !xegpu.tensor_desc<16x16xf16, #xegpu.block_tdesc_attr<
+// CHECK-SAME: array_length = 2 : i64>, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>, index, index
+// CHECK-NEXT: }
+// CHECK-NEXT: %[[T1:.*]] = builtin.unrealized_conversion_cast %[[W]]#1 : !xegpu.tensor_desc<16x16xf16,
+// CHECK-SAME: #xegpu.block_tdesc_attr<array_length = 2 : i64>, #xegpu.layout<lane_layout = [1, 16],
+// CHECK-SAME: lane_data = [1, 1]>> to !xegpu.tensor_desc<16x16xf16, #xegpu.block_tdesc_attr<array_length = 2 : i64>>
+// CHECK-NEXT: %[[T2:.*]] = xegpu.load_nd %[[T1]][%[[W]]#2, %[[W]]#3] : !xegpu.tensor_desc<16x16xf16,
+// CHECK-SAME: #xegpu.block_tdesc_attr<array_length = 2 : i64>> -> vector<32xf16>
+// CHECK-NEXT: vector.shape_cast %[[T2]] : vector<32xf16> to vector<2x16x1xf16>
+gpu.module @xevm_module{
+ gpu.func @load_nd_array_length(%laneid: index) {
+ %c0 = arith.constant 0 : index
+ %r = gpu.warp_execute_on_lane_0(%laneid)[16] -> (vector<2x16x1xf16>) {
+ %0 = "some_op"() : () -> !xegpu.tensor_desc<16x16xf16, #xegpu.block_tdesc_attr<array_length = 2 : i64>,
+ #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
+ %1 = xegpu.load_nd %0[%c0, %c0] {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>}
+ : !xegpu.tensor_desc<16x16xf16, #xegpu.block_tdesc_attr<array_length = 2 : i64>,
+ #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>> -> vector<2x16x16xf16>
+ gpu.yield %1 : vector<2x16x16xf16>
+ }
+ "some_user_op"(%r) : (vector<2x16x1xf16>) -> ()
+ gpu.return
+ }
+}
+
+// -----
+// CHECK-LABEL: gpu.func @dpas
+// CHECK: (%[[ARG0:[0-9a-zA-Z]+]]: index) {
+// CHECK: %[[W:.*]]:4 = gpu.warp_execute_on_lane_0(%[[ARG0]])[16] ->
+// CHECK-SAME: (vector<8x1xf32>, vector<8x1xf16>, vector<16x1xf16>, vector<8x1xf32>) {
+// CHECK: gpu.yield %{{.*}} : vector<8x16xf32>, vector<8x16xf16>, vector<16x16xf16>, vector<8x16xf32>
+// CHECK-NEXT: }
+// CHECK-DAG: %[[T1:.*]] = vector.shape_cast %[[W]]#1 : vector<8x1xf16> to vector<8xf16>
+// CHECK-DAG: %[[T2:.*]] = vector.shape_cast %[[W]]#2 : vector<16x1xf16> to vector<16xf16>
+// CHECK-DAG: %[[T3:.*]] = vector.shape_cast %[[W]]#3 : vector<8x1xf32> to vector<8xf32>
+// CHECK-NEXT: %[[T4:.*]] = xegpu.dpas %[[T1]], %[[T2]], %[[T3]] : vector<8xf16>, vector<16xf16>, vector<8xf32> -> vector<8xf32>
+// CHECK-NEXT: vector.shape_cast %[[T4]] : vector<8xf32> to vector<8x1xf32>
+gpu.module @xevm_module{
+ gpu.func @dpas(%laneid: index) {
+ %r = gpu.warp_execute_on_lane_0(%laneid)[16] -> (vector<8x1xf32>) {
+ %0 = "some_op"() : () -> vector<8x16xf16>
+ %1 = "some_op"() : () -> vector<16x16xf16>
+ %2 = "some_op"() : () -> vector<8x16xf32>
+ %3 = xegpu.dpas %0, %1, %2
+ {
+ layout_operand_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>,
+ layout_operand_1 = #xegpu.layout<lane_layout = [1, 16], lane_data = [2, 1]>,
+ layout_operand_2 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>,
+ layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>
+ }
+ : vector<8x16xf16>, vector<16x16xf16>, vector<8x16xf32> -> vector<8x16xf32>
+ gpu.yield %3 : vector<8x16xf32>
+ }
+ "some_user_op"(%r) : (vector<8x1xf32>) -> ()
+ gpu.return
+ }
+}
+
+
+// -----
+// CHECK-LABEL: gpu.func @create_nd_tdesc_non_memref
+// CHECK: (%[[ARG0:[0-9a-zA-Z]+]]: ui64, %[[ARG1:[0-9a-zA-Z]+]]: index) {
+// CHECK: %[[W:.*]]:2 = gpu.warp_execute_on_lane_0(%[[ARG1]])[16] -> (!xegpu.tensor_desc<16x16xf16,
+// CHECK-SAME: #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>, ui64) {
+// CHECK: gpu.yield %{{.*}} : !xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>, ui64
+// CHECK-NEXT: }
+// CHECK-NEXT: %[[T1:.*]] = xegpu.create_nd_tdesc %[[W]]#1, shape : [64, 128], strides : [128, 1] : ui64 -> !xegpu.tensor_desc<16x16xf16>
+// CHECK-NEXT: builtin.unrealized_conversion_cast %[[T1]] : !xegpu.tensor_desc<16x16xf16> to !xegpu.tensor_desc<16x16xf16,
+// CHECK-SAME: #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>> {resolve_simt_type_mismatch}
+gpu.module @xevm_module{
+ gpu.func @create_nd_tdesc_non_memref(%arg0: ui64, %laneid: index) {
+ %c0 = arith.constant 0 : index
+ %r = gpu.warp_execute_on_lane_0(%laneid)[16] -> (!xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>) {
+ %0 = xegpu.create_nd_tdesc %arg0, shape:[64, 128], strides:[128, 1] : ui64 ->
+ !xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
+ gpu.yield %0 : !xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
+ }
+ "some_user_op"(%r)
+ : (!xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>) -> ()
+ gpu.return
+ }
+}
+
+// -----
+// CHECK-LABEL: gpu.func @prefetch_2d
+// CHECK: (%[[ARG0:[0-9a-zA-Z]+]]: index) {
+// CHECK: %[[W:.*]]:3 = gpu.warp_execute_on_lane_0(%[[ARG0]])[16] -> (!xegpu.tensor_desc<16x16xf16,
+// CHECK-SAME: #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>, index, index) {
+// CHECK: gpu.yield %{{.*}} : !xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
+// CHECK-SAME: , index, index
+// CHECK-NEXT: }
+// CHECK-NEXT: %[[T1:.*]] = builtin.unrealized_conversion_cast %[[W]]#0 : !xegpu.tensor_desc<16x16xf16,
+// CHECK-SAME: #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>> to !xegpu.tensor_desc<16x16xf16> {resolve_simt_type_mismatch}
+// CHECK-NEXT: xegpu.prefetch_nd %[[T1]][%[[W]]#1, %[[W]]#2]
+// CHECK-SAME: <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}> : !xegpu.tensor_desc<16x16xf16>
+gpu.module @xevm_module{
+ gpu.func @prefetch_2d(%laneid: index) {
+ %c0 = arith.constant 0 : index
+ gpu.warp_execute_on_lane_0(%laneid)[16] {
+ %0 = "some_op"() : ()
+ -> !xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
+ xegpu.prefetch_nd %0[%c0, %c0]
+ <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}>
+ : !xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
+ }
+ gpu.return
+ }
+}
+
+// -----
+// CHECK-LABEL: gpu.func @prefetch_1d
+// CHECK: (%[[ARG0:[0-9a-zA-Z]+]]: index) {
+// CHECK: %[[W:.*]]:2 = gpu.warp_execute_on_lane_0(%[[ARG0]])[16] -> (!xegpu.tensor_desc<16xf16,
+// CHECK-SAME: #xegpu.layout<lane_layout = [16], lane_data = [1]>>, index) {
+// CHECK: gpu.yield %{{.*}} : !xegpu.tensor_desc<16xf16, #xegpu.layout<lane_layout = [16], lane_data = [1]>>, index
+// CHECK-NEXT: }
+// CHECK-NEXT: %[[T1:.*]] = builtin.unrealized_conversion_cast %[[W]]#0 : !xegpu.tensor_desc<16xf16,
+// CHECK-SAME: #xegpu.layout<lane_layout = [16], lane_data = [1]>> to !xegpu.tensor_desc<16xf16> {resolve_simt_type_mismatch}
+// CHECK-NEXT: xegpu.prefetch_nd %[[T1]][%[[W]]#1] <{l1_hint = #xegpu.cache_hint<cached>,
+// CHECK-SAME: l2_hint = #xegpu.cache_hint<uncached>}> : !xegpu.tensor_desc<16xf16>
+gpu.module @xevm_module{
+ gpu.func @prefetch_1d(%laneid: index) {
+ %c0 = arith.constant 0 : index
+ gpu.warp_execute_on_lane_0(%laneid)[16] {
+ %0 = "some_op"() : ()
+ -> !xegpu.tensor_desc<16xf16, #xegpu.layout<lane_layout = [16], lane_data = [1]>>
+ xegpu.prefetch_nd %0[%c0]
+ <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}>
+ : !xegpu.tensor_desc<16xf16, #xegpu.layout<lane_layout = [16], lane_data = [1]>>
+ }
+ gpu.return
+ }
+}
+
+// -----
+// CHECK-LABEL: gpu.func @gpu_barrier({{.*}}) {
+// CHECK: gpu.warp_execute_on_lane_0(%{{.*}})[16] -> ({{.*}}) {
+// CHECK: gpu.yield %{{.*}}
+// CHECK: }
+// CHECK: %{{.*}} = xegpu.load_nd %{{.*}} : !xegpu.tensor_desc<16xf16> -> vector<1xf16>
+// CHECK: gpu.barrier
+gpu.module @xevm_module{
+ gpu.func @gpu_barrier(%laneid: index) {
+ %c0 = arith.constant 0 : index
+ %r = gpu.warp_execute_on_lane_0(%laneid)[16] -> (vector<1xf16>) {
+ %0 = "some_op"() : () -> !xegpu.tensor_desc<16xf16, #xegpu.layout<lane_layout = [16], lane_data = [1]>>
+ %1 = xegpu.load_nd %0[%c0]
+ {layout_result_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>}
+ : !xegpu.tensor_desc<16xf16, #xegpu.layout<lane_layout = [16], lane_data = [1]>> -> vector<16xf16>
+ gpu.barrier
+ gpu.yield %1 : vector<16xf16>
+ }
+ "some_user_op"(%r) : (vector<1xf16>) -> ()
+ gpu.return
+ }
+}
+
+// -----
+// CHECK-LABEL: gpu.func @vector_multi_reduction_dim1_distributed_dim0_reduction
+// CHECK: %[[ACC:.*]] = arith.constant {{.*}} dense<0.000000e+00> : vector<32xf32>
+// CHECK: %[[W:.*]]:3 = gpu.warp_execute_on_lane_0(%{{.*}})[16]
+// CHECK-SAME: -> (vector<2xf32>, vector<16x2xf32>, vector<2xf32>) {
+// CHECK: %[[SRC:.*]] = "some_def"() {{.*}} : () -> vector<16x32xf32>
+// CHECK: gpu.yield %{{.*}}, %[[SRC]], %[[ACC]] : vector<32xf32>, vector<16x32xf32>, vector<32xf32>
+// CHECK-NEXT: }
+// CHECK: %[[T1:.*]] = vector.extract_strided_slice %[[W]]#1
+// CHECK-SAME: {offsets = [0, 0], sizes = [16, 1], strides = [1, 1]} : vector<16x2xf32> to vector<16x1xf32>
+// CHECK: %[[T2:.*]] = vector.shape_cast %[[T1]] : vector<16x1xf32> to vector<16xf32>
+// CHECK: %[[T3:.*]] = vector.extract %[[W]]#2[0] : f32 from vector<2xf32>
+// CHECK: %[[T4:.*]] = vector.reduction <add>, %[[T2]], %[[T3]] : vector<16xf32> into f32
+// CHECK: %[[T5:.*]] = vector.extract_strided_slice %[[W]]#1
+// CHECK-SAME: {offsets = [0, 1], sizes = [16, 1], strides = [1, 1]} : vector<16x2xf32> to vector<16x1xf32>
+// CHECK: %[[T6:.*]] = vector.shape_cast %[[T5]] : vector<16x1xf32> to vector<16xf32>
+// CHECK: %[[T7:.*]] = vector.extract %[[W]]#2[1] : f32 from vector<2xf32>
+// CHECK: %[[T8:.*]] = vector.reduction <add>, %[[T6]], %[[T7]] : vector<16xf32> into f32
+// CHECK: %[[T9:.*]] = vector.from_elements %[[T4]], %[[T8]] : vector<2xf32>
+gpu.module @xevm_module{
+gpu.func @vector_multi_reduction_dim1_distributed_dim0_reduction(%laneid: index) {
+ %c0 = arith.constant 0 : index
+ %r = gpu.warp_execute_on_lane_0(%laneid)[16] -> (vector<2xf32>) {
+ %src = "some_def"()
+ {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>}
+ : () -> (vector<16x32xf32>)
+ %acc = arith.constant
+ {layout_result_0 = #xegpu.slice<#xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>, dims = [0]>}
+ dense<0.0> : vector<32xf32>
+ %1 = vector.multi_reduction <add>, %src, %acc
+ {
+ layout_operand_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>,
+ layout_operand_1 = #xegpu.slice<#xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>, dims = [0]>,
+ layout_result_0 = #xegpu.slice<#xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>, dims = [0]>
+ } [0]
+ : vector<16x32xf32> to vector<32xf32>
+ gpu.yield %1 : vector<32xf32>
+ }
+ "some_user_op"(%r) : (vector<2xf32>) -> ()
+ gpu.return
+}
+}
+
+// -----
+// CHECK-LABEL: gpu.func @vector_multi_reduction_dim1_distributed_dim1_reduction
+// CHECK: %[[W:.*]] = gpu.warp_execute_on_lane_0(%{{.*}})[16] -> (vector<2xf32>) {
+// CHECK-NEXT: %[[SRC:.*]] = "some_def"() {{.*}} : () -> vector<2x16xf32>
+// CHECK-NEXT: %[[T2:.*]] = vector.extract %[[SRC]][0] : vector<16xf32> from vector<2x16xf32>
+// CHECK-NEXT: %[[T3:.*]] = vector.reduction <add>, %[[T2]], %cst : vector<16xf32> into f32
+// CHECK-NEXT: %[[T4:.*]] = vector.extract %[[SRC]][1] : vector<16xf32> from vector<2x16xf32>
+// CHECK-NEXT: %[[T5:.*]] = vector.reduction <add>, %[[T4]], %cst : vector<16xf32> into f32
+// CHECK-NEXT: %[[T6:.*]] = vector.from_elements %[[T3]], %[[T5]] : vector<2xf32>
+// CHECK-NEXT: gpu.yield %[[T6]] : vector<2xf32>
+// CHECK-NEXT: }
+gpu.module @xevm_module{
+gpu.func @vector_multi_reduction_dim1_distributed_dim1_reduction(%laneid: index) {
+ %c0 = arith.constant 0 : index
+ %r = gpu.warp_execute_on_lane_0(%laneid)[16] -> (vector<2xf32>) {
+ %src = "some_def"()
+ {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>}
+ : () -> (vector<2x16xf32>)
+ %acc = arith.constant
+ {layout_result_0 = #xegpu.slice<#xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>, dims = [1]>}
+ dense<0.0> : vector<2xf32>
+ %1 = vector.multi_reduction <add>, %src, %acc
+ {
+ layout_operand_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>,
+ layout_operand_1 = #xegpu.slice<#xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>, dims = [1]>,
+ layout_result_0 = #xegpu.slice<#xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>, dims = [1]>
+ }
+ [1] : vector<2x16xf32> to vector<2xf32>
+ gpu.yield %1 : vector<2xf32>
+ }
+ "some_user_op"(%r) : (vector<2xf32>) -> ()
+ gpu.return
+}
+}
+
+// -----
+// CHECK-LABEL: gpu.func @vector_multi_reduction_dim0_distributed_dim1_reduction
+// CHECK: %[[ACC:.*]] = arith.constant {{.*}} dense<0.000000e+00> : vector<32xf32>
+// CHECK: %[[W:.*]]:3 = gpu.warp_execute_on_lane_0(%{{.*}})[16] -> (vector<2xf32>, vector<2x16xf32>, vector<2xf32>) {
+// CHECK: %[[SRC:.*]] = "some_def"() {{.*}} : () -> vector<32x16xf32>
+// CHECK: gpu.yield %9, %[[SRC]], %[[ACC]] : vector<32xf32>, vector<32x16xf32>, vector<32xf32>
+// CHECK: }
+// CHECK: %[[T1:.*]] = vector.extract %[[W]]#1[0] : vector<16xf32> from vector<2x16xf32>
+// CHECK: %[[T2:.*]] = vector.extract %[[W]]#2[0] : f32 from vector<2xf32>
+// CHECK: %[[T3:.*]] = vector.reduction <add>, %[[T1]], %[[T2]] : vector<16xf32> into f32
+// CHECK: %[[T4:.*]] = vector.extract %[[W]]#1[1] : vector<16xf32> from vector<2x16xf32>
+// CHECK: %[[T5:.*]] = vector.extract %[[W]]#2[1] : f32 from vector<2xf32>
+// CHECK: %[[T6:.*]] = vector.reduction <add>, %[[T4]], %[[T5]] : vector<16xf32> into f32
+// CHECK: %[[T7:.*]] = vector.from_elements %[[T3]], %[[T6]] : vector<2xf32>
+gpu.module @xevm_module{
+gpu.func @vector_multi_reduction_dim0_distributed_dim1_reduction(%laneid: index) {
+ %c0 = arith.constant 0 : index
+ %r = gpu.warp_execute_on_lane_0(%laneid)[16] -> (vector<2xf32>) {
+ %src = "some_def"()
+ {layout_result_0 = #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 1]>}
+ : () -> (vector<32x16xf32>)
+ %acc = arith.constant
+ {layout_result_0 = #xegpu.slice<#xegpu.layout<lane_layout = [16, 1], lane_data = [1, 1]>, dims = [1]>}
+ dense<0.0> : vector<32xf32>
+ %1 = vector.multi_reduction <add>, %src, %acc
+ {
+ layout_operand_0 = #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 1]>,
+ layout_operand_1 = #xegpu.slice<#xegpu.layout<lane_layout = [16, 1], lane_data = [1, 1]>, dims = [1]>,
+ layout_result_0 = #xegpu.slice<#xegpu.layout<lane_layout = [16, 1], lane_data = [1, 1]>, dims = [1]>
+ }
+ [1] : vector<32x16xf32> to vector<32xf32>
+ gpu.yield %1 : vector<32xf32>
+ }
+ "some_user_op"(%r) : (vector<2xf32>) -> ()
+ gpu.return
+}
+}
+
+// -----
+// CHECK-LABEL: gpu.func @vector_multi_reduction_dim0_distributed_dim0_reduction
+// CHECK: %[[W:.*]] = gpu.warp_execute_on_lane_0(%{{.*}})[16] -> (vector<2xf32>) {
+// CHECK: %[[SRC:.*]] = "some_def"() {{.*}} : () -> vector<16x2xf32>
+// CHECK: %[[T1:.*]] = vector.extract_strided_slice %[[SRC]]
+// CHECK-SAME: {offsets = [0, 0], sizes = [16, 1], strides = [1, 1]} : vector<16x2xf32> to vector<16x1xf32>
+// CHECK: %[[T2:.*]] = vector.shape_cast %[[T1]] {{.*}} : vector<16x1xf32> to vector<16xf32>
+// CHECK: %[[T3:.*]] = vector.reduction <add>, %[[T2]], %{{.*}} : vector<16xf32> into f32
+// CHECK: %[[T4:.*]] = vector.extract_strided_slice %[[SRC]]
+// CHECK-SAME: {offsets = [0, 1], sizes = [16, 1], strides = [1, 1]} : vector<16x2xf32> to vector<16x1xf32>
+// CHECK: %[[T5:.*]] = vector.shape_cast %[[T4]] {{.*}} : vector<16x1xf32> to vector<16xf32>
+// CHECK: %[[T6:.*]] = vector.reduction <add>, %[[T5]], %{{.*}} : vector<16xf32> into f32
+// CHECK: %[[T7:.*]] = vector.from_elements %[[T3]], %[[T6]] : vector<2xf32>
+// CHECK: gpu.yield %[[T7]] : vector<2xf32>
+// CHECK: }
+gpu.module @xevm_module{
+gpu.func @vector_multi_reduction_dim0_distributed_dim0_reduction(%laneid: index) {
+ %c0 = arith.constant 0 : index
+ %r = gpu.warp_execute_on_lane_0(%laneid)[16] -> (vector<2xf32>) {
+ %src = "some_def"()
+ {layout_result_0 = #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 1]>}
+ : () -> (vector<16x2xf32>)
+ %acc = arith.constant
+ {layout_result_0 = #xegpu.slice<#xegpu.layout<lane_layout = [16, 1], lane_data = [1, 1]>, dims = [0]>}
+ dense<0.0> : vector<2xf32>
+ %1 = vector.multi_reduction <add>, %src, %acc
+ {
+ layout_operand_0 = #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 1]>,
+ layout_operand_1 = #xegpu.slice<#xegpu.layout<lane_layout = [16, 1], lane_data = [1, 1]>, dims = [0]>,
+ layout_result_0 = #xegpu.slice<#xegpu.layout<lane_layout = [16, 1], lane_data = [1, 1]>, dims = [0]>
+ }
+ [0] : vector<16x2xf32> to vector<2xf32>
+ gpu.yield %1 : vector<2xf32>
+ }
+ "some_user_op"(%r) : (vector<2xf32>) -> ()
+ gpu.return
+}
+}
+
+// -----
+// CHECK-LABEL: gpu.func @scatter_ops_chunksize({{.*}}) {
+// CHECK: %[[OFFSETS:.*]] = arith.constant {{.*}} dense<12> : vector<16xindex>
+// CHECK: %[[MASKS:.*]] = arith.constant {{.*}} dense<true> : vector<16xi1>
+// CHECK: %[[W:.*]]:4 = gpu.warp_execute_on_lane_0(%{{.*}})[16]
+// CHECK-SAME: -> (vector<1x8xf16>, memref<256xf16>, vector<1xindex>, vector<1xi1>) {
+// CHECK: gpu.yield %{{.*}}, %{{.*}}, %[[OFFSETS]], %[[MASKS]] :
+// CHECK-SAME: vector<16x8xf16>, memref<256xf16>, vector<16xindex>, vector<16xi1>
+// CHECK-NEXT: }
+// CHECK-NEXT: %[[T1:.*]] = xegpu.load %[[W]]#1[%[[W]]#2], %[[W]]#3 <{chunk_size = 8 : i64}>
+// CHECK-SAME: : memref<256xf16>, vector<1xindex>, vector<1xi1> -> vector<8xf16>
+// CHECK-NEXT: xegpu.store %[[T1]], %[[W]]#1[%[[W]]#2], %[[W]]#3 <{chunk_size = 8 : i64}>
+// CHECK-SAME: : vector<8xf16>, memref<256xf16>, vector<1xindex>, vector<1xi1>
+gpu.module @xevm_module{
+ gpu.func @scatter_ops_chunksize(%laneid: index, %src: memref<256xf16>) {
+ gpu.warp_execute_on_lane_0(%laneid)[16] {
+ %1 = arith.constant
+ {layout_result_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>}
+ dense<1>: vector<16xi1>
+ %offset = arith.constant
+ {layout_result_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>}
+ dense<12> : vector<16xindex>
+ %3 = xegpu.load %src[%offset], %1 <{chunk_size=8}>
+ {
+ layout_operand_1 = #xegpu.layout<lane_layout = [16], lane_data = [1]>,
+ layout_operand_2 = #xegpu.layout<lane_layout = [16], lane_data = [1]>,
+ layout_result_0 = #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 2]>
+ }
+ : memref<256xf16>, vector<16xindex>, vector<16xi1> -> vector<16x8xf16>
+ xegpu.store %3, %src[%offset], %1 <{chunk_size=8}>
+ {
+ layout_operand_0 = #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 2]>,
+ layout_operand_2 = #xegpu.layout<lane_layout = [16], lane_data = [1]>,
+ layout_operand_3 = #xegpu.layout<lane_layout = [16], lane_data = [1]>
+ }
+ : vector<16x8xf16>, memref<256xf16>, vector<16xindex>, vector<16xi1>
+ }
+ gpu.return
+ }
+}
+
+// -----
+// CHECK-LABEL: gpu.func @scatter_ops({{.*}}) {
+// CHECK: %[[OFFSETS:.*]] = arith.constant {{.*}} dense<12> : vector<16xindex>
+// CHECK: %[[MASKS:.*]] = arith.constant {{.*}} dense<true> : vector<16xi1>
+// CHECK: %[[W:.*]]:4 = gpu.warp_execute_on_lane_0(%{{.*}})[16]
+// CHECK-SAME: -> (vector<1xf16>, memref<256xf16>, vector<1xindex>, vector<1xi1>) {
+// CHECK: gpu.yield %{{.*}}, %{{.*}}, %[[OFFSETS]], %[[MASKS]]
+// CHECK-SAME: : vector<16xf16>, memref<256xf16>, vector<16xindex>, vector<16xi1>
+// CHECK-NEXT: }
+// CHECK-NEXT: %[[T1:.*]] = xegpu.load %[[W]]#1[%[[W]]#2], %[[W]]#3
+// CHECK-SAME: : memref<256xf16>, vector<1xindex>, vector<1xi1> -> vector<1xf16>
+// CHECK-NEXT: xegpu.store %[[T1]], %[[W]]#1[%[[W]]#2], %[[W]]#3
+// CHECK-SAME: : vector<1xf16>, memref<256xf16>, vector<1xindex>, vector<1xi1>
+gpu.module @xevm_module{
+ gpu.func @scatter_ops(%src: memref<256xf16>, %laneid: index) {
+ gpu.warp_execute_on_lane_0(%laneid)[16] {
+ %1 = arith.constant
+ {layout_result_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>}
+ dense<1> : vector<16xi1>
+ %offset = arith.constant
+ {layout_result_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>}
+ dense<12> : vector<16xindex>
+ %3 = xegpu.load %src[%offset], %1
+ {
+ layout_operand_1 = #xegpu.layout<lane_layout = [16], lane_data = [1]>,
+ layout_operand_2 = #xegpu.layout<lane_layout = [16], lane_data = [1]>,
+ layout_result_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>
+ } : memref<256xf16>, vector<16xindex>, vector<16xi1> -> vector<16xf16>
+ xegpu.store %3, %src[%offset], %1
+ {
+ layout_operand_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>,
+ layout_operand_2 = #xegpu.layout<lane_layout = [16], lane_data = [1]>,
+ layout_operand_3 = #xegpu.layout<lane_layout = [16], lane_data = [1]>
+ }
+ : vector<16xf16>, memref<256xf16>, vector<16xindex>, vector<16xi1>
+ }
+ gpu.return
+ }
+}
+
+// -----
+// CHECK-LABEL: gpu.func @memref_extract_aligned_pointer_as_index(
+// CHECK: %[[W:.*]]:2 = gpu.warp_execute_on_lane_0(%{{.*}})[16] -> (index, memref<256x256xf16>) {
+// CHECK: gpu.yield %{{.*}}, %{{.*}} : index, memref<256x256xf16>
+// CHECK-NEXT: }
+// CHECK-NEXT: %[[INTPTR:.*]] = memref.extract_aligned_pointer_as_index %[[W]]#1 : memref<256x256xf16> -> index
+// CHECK-NEXT: arith.index_cast %[[INTPTR]] : index to i64
+gpu.module @xevm_module{
+ gpu.func @memref_extract_aligned_pointer_as_index(%arg0 : memref<256x256xf16>, %laneid: index) {
+ %r = gpu.warp_execute_on_lane_0(%laneid)[16] -> (index) {
+ %ptr = memref.extract_aligned_pointer_as_index %arg0 : memref<256x256xf16> -> index
+ gpu.yield %ptr : index
+ }
+ %ptr_i64 = arith.index_cast %r : index to i64
+ "some_user_op"(%ptr_i64) : (i64) -> ()
+ gpu.return
+ }
+}
+
+
+// -----
+// CHECK-LABEL: gpu.func @vector_transpose(
+// CHECK: %[[W:.*]]:2 = gpu.warp_execute_on_lane_0(%{{.*}})[16] -> (vector<2x1xf32>, vector<1x2xf32>) {
+// CHECK: %[[SRC:.*]] = "some_op"() {{.*}} : () -> vector<16x2xf32>
+// CHECK: gpu.yield %{{.*}}, %[[SRC]] : vector<2x16xf32>, vector<16x2xf32>
+// CHECK-NEXT: }
+// CHECK-NEXT: %[[T1:.*]] = vector.transpose %[[W]]#1, [1, 0] : vector<1x2xf32> to vector<2x1xf32>
+gpu.module @xevm_module{
+ gpu.func @vector_transpose(%arg0: memref<2x16xf32>, %laneid: index) {
+ %r = gpu.warp_execute_on_lane_0(%laneid)[16] -> (vector<2x1xf32>) {
+ %cst = "some_op"()
+ {layout_result_0 = #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 1]>}
+ : () -> (vector<16x2xf32>)
+ %transpose = vector.transpose %cst, [1, 0]
+ {
+ layout_operand_0 = #xegpu.layout<lane_layout = [16 , 1], lane_data = [1, 1]>,
+ layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>
+ }
+ : vector<16x2xf32> to vector<2x16xf32>
+ gpu.yield %transpose : vector<2x16xf32>
+ }
+ "some_user_op"(%r) : (vector<2x1xf32>) -> ()
+ gpu.return
+ }
+}
+
+// -----
+// CHECK-LABEL: gpu.func @vector_bitcast(
+// CHECK: %[[W:.*]]:2 = gpu.warp_execute_on_lane_0(%{{.*}})[16] -> (vector<4x1xi16>, vector<4x2xi8>) {
+// CHECK: %[[SRC:.*]] = "some_op"() {{.*}} : () -> vector<4x32xi8>
+// CHECK: gpu.yield %{{.*}}, %[[SRC]] : vector<4x16xi16>, vector<4x32xi8>
+// CHECK: }
+// CHECK: vector.bitcast %[[W]]#1 : vector<4x2xi8> to vector<4x1xi16>
+gpu.module @xevm_module{
+ gpu.func @vector_bitcast(%arg0: memref<4x16xi16>, %laneid: index) {
+ %r = gpu.warp_execute_on_lane_0(%laneid)[16] -> (vector<4x1xi16>) {
+ %cst = "some_op"()
+ {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 2]>}
+ : () -> (vector<4x32xi8>)
+ %bitcast = vector.bitcast %cst
+ {
+ layout_operand_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 2]>,
+ layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>
+ }
+ : vector<4x32xi8> to vector<4x16xi16>
+ gpu.yield %bitcast : vector<4x16xi16>
+ }
+ "some_user_op"(%r) : (vector<4x1xi16>) -> ()
+ gpu.return
+ }
+}
diff --git a/mlir/test/Dialect/XeGPU/subgroup-distribute.mlir b/mlir/test/Dialect/XeGPU/subgroup-distribute.mlir
index 59fac26d18cf4..0e1365aa64171 100644
--- a/mlir/test/Dialect/XeGPU/subgroup-distribute.mlir
+++ b/mlir/test/Dialect/XeGPU/subgroup-distribute.mlir
@@ -1,198 +1,76 @@
// RUN: mlir-opt --xevm-attach-target='module=xevm_* chip=pvc' -xegpu-subgroup-distribute \
// RUN: -allow-unregistered-dialect -canonicalize -cse -split-input-file %s | FileCheck %s
-// RUN: mlir-opt --xevm-attach-target='module=xevm_* chip=pvc' \
-// RUN: -xegpu-subgroup-distribute="enable-sg-reductions=false" -allow-unregistered-dialect \
-// RUN: -canonicalize -cse -split-input-file %s | FileCheck %s --check-prefix=CHECK-REDUCTION
-
-// CHECK-LABEL: gpu.func @store_nd_1d
-// CHECK: (%[[ARG0:[0-9a-zA-Z]+]]: memref<16xf32>) {
-// CHECK-DAG: %[[CST:.*]] = arith.constant dense<1.000000e+00> : vector<1xf32>
-// CHECK-DAG: %[[T0:.*]] = xegpu.create_nd_tdesc %[[ARG0]] : memref<16xf32> -> !xegpu.tensor_desc<16xf32>
-// CHECK: xegpu.store_nd %[[CST]], %[[T0]][%{{.*}}] : vector<1xf32>, !xegpu.tensor_desc<16xf32>
-// CHECK: gpu.return
-gpu.module @xevm_module{
- gpu.func @store_nd_1d(%arg0: memref<16xf32>) {
- %c0 = arith.constant 0 : index
- %cst = arith.constant {layout_result_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>} dense<1.000000e+00> : vector<16xf32>
- %0 = xegpu.create_nd_tdesc %arg0 : memref<16xf32> -> !xegpu.tensor_desc<16xf32, #xegpu.layout<lane_layout = [16], lane_data = [1]>>
- xegpu.store_nd %cst, %0 [%c0] : vector<16xf32>, !xegpu.tensor_desc<16xf32, #xegpu.layout<lane_layout = [16], lane_data = [1]>>
- gpu.return
- }
-}
-
-// -----
-// CHECK-LABEL: gpu.func @store_nd_2d
-// CHECK: (%[[ARG0:[0-9a-zA-Z]+]]: memref<16x16xf16>) {
-// CHECK-DAG: %[[CST:.*]] = arith.constant dense<1.000000e+00> : vector<16xf16>
-// CHECK-DAG: %[[T0:.*]] = xegpu.create_nd_tdesc %[[ARG0]] : memref<16x16xf16> -> !xegpu.tensor_desc<16x16xf16>
-// CHECK: xegpu.store_nd %[[CST]], %[[T0]][%{{.*}}] : vector<16xf16>, !xegpu.tensor_desc<16x16xf16>
-gpu.module @xevm_module{
- gpu.func @store_nd_2d(%arg0: memref<16x16xf16>) {
- %c0 = arith.constant 0 : index
- %cst = arith.constant {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>} dense<1.000000e+00> : vector<16x16xf16>
- %0 = xegpu.create_nd_tdesc %arg0 : memref<16x16xf16> -> !xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
- xegpu.store_nd %cst, %0 [%c0, %c0] : vector<16x16xf16>, !xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
- gpu.return
- }
-}
-
-
-
-// -----
-// CHECK-LABEL: gpu.func @load_nd_1d
-// CHECK: (%[[ARG0:[0-9a-zA-Z]+]]: memref<16xf32>, %[[ARG1:[0-9a-zA-Z]+]]: memref<16xf32>) {
-// CHECK: %[[T0:.*]] = xegpu.create_nd_tdesc %[[ARG0]] : memref<16xf32> -> !xegpu.tensor_desc<16xf32>
-// CHECK-DAG: %[[T1:.*]] = xegpu.load_nd %[[T0]][%{{.*}}] : !xegpu.tensor_desc<16xf32> -> vector<1xf32>
-// CHECK-DAG: %[[T2:.*]] = xegpu.create_nd_tdesc %[[ARG1]] : memref<16xf32> -> !xegpu.tensor_desc<16xf32>
-// CHECK: xegpu.store_nd %[[T1]], %[[T2]][%{{.*}}] : vector<1xf32>, !xegpu.tensor_desc<16xf32>
-gpu.module @xevm_module{
- gpu.func @load_nd_1d(%arg0: memref<16xf32>, %arg1: memref<16xf32>) {
- %c0 = arith.constant 0 : index
- %0 = xegpu.create_nd_tdesc %arg0 : memref<16xf32> -> !xegpu.tensor_desc<16xf32, #xegpu.layout<lane_layout = [16], lane_data = [1]>>
- %1 = xegpu.load_nd %0 [%c0] {layout_result_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>} : !xegpu.tensor_desc<16xf32, #xegpu.layout<lane_layout = [16], lane_data = [1]>> -> vector<16xf32>
- %2 = xegpu.create_nd_tdesc %arg1 : memref<16xf32> -> !xegpu.tensor_desc<16xf32, #xegpu.layout<lane_layout = [16], lane_data = [1]>>
- xegpu.store_nd %1, %2 [%c0] : vector<16xf32>, !xegpu.tensor_desc<16xf32, #xegpu.layout<lane_layout = [16], lane_data = [1]>>
- gpu.return
- }
-}
-
-// -----
-// CHECK-LABEL: gpu.func @load_nd_2d
-// CHECK: (%[[ARG0:[0-9a-zA-Z]+]]: memref<16x16xf16>, %[[ARG1:[0-9a-zA-Z]+]]: memref<16x16xf16>) {
-// CHECK: %[[T0:.*]] = xegpu.create_nd_tdesc %[[ARG0]] : memref<16x16xf16> -> !xegpu.tensor_desc<16x16xf16>
-// CHECK-DAG: %[[T1:.*]] = xegpu.load_nd %[[T0]][%{{.*}}] : !xegpu.tensor_desc<16x16xf16> -> vector<16xf16>
-// CHECK-DAG: %[[T2:.*]] = xegpu.create_nd_tdesc %[[ARG1]] : memref<16x16xf16> -> !xegpu.tensor_desc<16x16xf16>
-// CHECK: xegpu.store_nd %[[T1]], %[[T2]][%{{.*}}] : vector<16xf16>, !xegpu.tensor_desc<16x16xf16>
-gpu.module @xevm_module{
- gpu.func @load_nd_2d(%arg0: memref<16x16xf16>, %arg1: memref<16x16xf16>) {
- %c0 = arith.constant 0 : index
- %0 = xegpu.create_nd_tdesc %arg0 : memref<16x16xf16> -> !xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
- %1 = xegpu.load_nd %0[%c0, %c0] {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>} : !xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>> -> vector<16x16xf16>
- %2 = xegpu.create_nd_tdesc %arg1: memref<16x16xf16> -> !xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
- xegpu.store_nd %1, %2[%c0, %c0] : vector<16x16xf16>, !xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
- gpu.return
- }
-}
-
-// -----
-// CHECK-LABEL: gpu.func @load_nd_array_length
-// CHECK: (%[[ARG0:[0-9a-zA-Z]+]]: memref<16x16xf16>, %[[ARG1:[0-9a-zA-Z]+]]: memref<16x16xf16>) {
-// CHECK: %[[T0:.*]] = xegpu.create_nd_tdesc %[[ARG0]] : memref<16x16xf16> -> !xegpu.tensor_desc<16x16xf16, #xegpu.block_tdesc_attr<array_length = 2 : i64>>
-// CHECK: %[[T1:.*]] = xegpu.load_nd %[[T0]][%{{.*}}] : !xegpu.tensor_desc<16x16xf16, #xegpu.block_tdesc_attr<array_length = 2 : i64>> -> vector<32xf16>
-// CHECK: %[[T2:.*]] = vector.shape_cast %[[T1]] : vector<32xf16> to vector<2x16x1xf16>
-// CHECK: %[[T3:.*]] = vector.extract %[[T2]][0] : vector<16x1xf16> from vector<2x16x1xf16>
-// CHECK-DAG: %[[T4:.*]] = xegpu.create_nd_tdesc %[[ARG1]] : memref<16x16xf16> -> !xegpu.tensor_desc<16x16xf16>
-// CHECK-DAG: %[[T5:.*]] = vector.shape_cast %[[T3]] : vector<16x1xf16> to vector<16xf16>
-// CHECK: xegpu.store_nd %[[T5]], %[[T4]][%{{.*}}] : vector<16xf16>, !xegpu.tensor_desc<16x16xf16>
-gpu.module @xevm_module{
- gpu.func @load_nd_array_length(%arg0: memref<16x16xf16>, %arg1: memref<16x16xf16>) {
- %c0 = arith.constant 0 : index
- %0 = xegpu.create_nd_tdesc %arg0 : memref<16x16xf16> -> !xegpu.tensor_desc<16x16xf16, #xegpu.block_tdesc_attr<array_length = 2 : i64>, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
- %1 = xegpu.load_nd %0[%c0, %c0] {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>} : !xegpu.tensor_desc<16x16xf16, #xegpu.block_tdesc_attr<array_length = 2 : i64>, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>> -> vector<2x16x16xf16>
- %2 = vector.extract %1[%c0] {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>} : vector<16x16xf16> from vector<2x16x16xf16>
- %3 = xegpu.create_nd_tdesc %arg1 : memref<16x16xf16> -> !xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
- xegpu.store_nd %2, %3[%c0, %c0] : vector<16x16xf16>, !xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
- gpu.return
- }
-}
-
-// -----
-// CHECK-LABEL: gpu.func @load_dpas_store
-// CHECK: (%[[ARG0:[0-9a-zA-Z]+]]: memref<8x16xf16>, %[[ARG1:[0-9a-zA-Z]+]]: memref<16x16xf16>, %[[ARG2:[0-9a-zA-Z]+]]: memref<8x16xf32>) {
-// CHECK: %[[T2:.*]] = xegpu.create_nd_tdesc %[[ARG0]] : memref<8x16xf16> -> !xegpu.tensor_desc<8x16xf16>
-// CHECK: %[[T3:.*]] = xegpu.load_nd %[[T2]][%{{.*}}] : !xegpu.tensor_desc<8x16xf16> -> vector<8xf16>
-// CHECK: %[[T0:.*]] = xegpu.create_nd_tdesc %[[ARG1]] : memref<16x16xf16> -> !xegpu.tensor_desc<16x16xf16>
-// CHECK: %[[T1:.*]] = xegpu.load_nd %[[T0]][%{{.*}}] <{packed}> : !xegpu.tensor_desc<16x16xf16> -> vector<16xf16>
-// CHECK-DAG: %[[T4:.*]] = xegpu.dpas %[[T3]], %[[T1]] : vector<8xf16>, vector<16xf16> -> vector<8xf32>
-// CHECK-DAG: %[[T5:.*]] = xegpu.create_nd_tdesc %[[ARG2]] : memref<8x16xf32> -> !xegpu.tensor_desc<8x16xf32>
-// CHECK: xegpu.store_nd %[[T4]], %[[T5]][%{{.*}}] : vector<8xf32>, !xegpu.tensor_desc<8x16xf32>
-gpu.module @xevm_module{
- gpu.func @load_dpas_store(%arg0: memref<8x16xf16>, %arg1: memref<16x16xf16>, %arg2: memref<8x16xf32>) {
- %c0 = arith.constant 0 : index
- %0 = xegpu.create_nd_tdesc %arg0 : memref<8x16xf16> -> !xegpu.tensor_desc<8x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
- %1 = xegpu.load_nd %0[%c0, %c0] {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>} : !xegpu.tensor_desc<8x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>> -> vector<8x16xf16>
- %2 = xegpu.create_nd_tdesc %arg1: memref<16x16xf16> -> !xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [2, 1]>>
- %3 = xegpu.load_nd %2[%c0, %c0] {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [2, 1]>} : !xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [2, 1]>> -> vector<16x16xf16>
- %4 = xegpu.dpas %1, %3 {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>} : vector<8x16xf16>, vector<16x16xf16> -> vector<8x16xf32>
- %5 = xegpu.create_nd_tdesc %arg2 : memref<8x16xf32> -> !xegpu.tensor_desc<8x16xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
- xegpu.store_nd %4, %5[%c0, %c0] : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
- gpu.return
- }
-}
-
-
-// -----
// CHECK-LABEL: gpu.func @load_dpas_postop_store
-// CHECK: (%[[ARG0:[0-9a-zA-Z]+]]: memref<8x16xf16>, %[[ARG1:[0-9a-zA-Z]+]]: memref<16x16xf16>, %[[ARG2:[0-9a-zA-Z]+]]: memref<8x16xf32>) {
-// CHECK: %[[T2:.*]] = xegpu.create_nd_tdesc %[[ARG0]] : memref<8x16xf16> -> !xegpu.tensor_desc<8x16xf16>
-// CHECK: %[[T3:.*]] = xegpu.load_nd %[[T2]][%{{.*}}] : !xegpu.tensor_desc<8x16xf16> -> vector<8xf16>
-// CHECK: %[[T0:.*]] = xegpu.create_nd_tdesc %[[ARG1]] : memref<16x16xf16> -> !xegpu.tensor_desc<16x16xf16>
-// CHECK: %[[T1:.*]] = xegpu.load_nd %[[T0]][%{{.*}}] <{packed}> : !xegpu.tensor_desc<16x16xf16> -> vector<16xf16>
-// CHECK-DAG: %[[T4:.*]] = xegpu.dpas %[[T3]], %[[T1]] : vector<8xf16>, vector<16xf16> -> vector<8xf32>
-// CHECK: %[[T5:.*]] = vector.shape_cast %[[T4]] : vector<8xf32> to vector<8x1xf32>
-// CHECK: %[[T6:.*]] = math.exp %[[T5]] {{{.*}}} : vector<8x1xf32>
-// CHECK-DAG: %[[T8:.*]] = vector.shape_cast %[[T6]] : vector<8x1xf32> to vector<8xf32>
-// CHECK-DAG: %[[T7:.*]] = xegpu.create_nd_tdesc %[[ARG2]] : memref<8x16xf32> -> !xegpu.tensor_desc<8x16xf32>
-// CHECK: xegpu.store_nd %[[T8]], %[[T7]][{{.*}}] : vector<8xf32>, !xegpu.tensor_desc<8x16xf32>
+// CHECK: (%[[ARG0:[0-9a-zA-Z]+]]: memref<8x16xf16>, %[[ARG1:[0-9a-zA-Z]+]]: memref<16x16xf16>,
+// CHECK-SAME: %[[ARG2:[0-9a-zA-Z]+]]: memref<8x16xf32>) {
+// CHECK: %[[T2:.*]] = xegpu.create_nd_tdesc %[[ARG0]] : memref<8x16xf16> -> !xegpu.tensor_desc<8x16xf16>
+// CHECK: %[[T3:.*]] = xegpu.load_nd %[[T2]][%{{.*}}] : !xegpu.tensor_desc<8x16xf16> -> vector<8xf16>
+// CHECK: %[[T0:.*]] = xegpu.create_nd_tdesc %[[ARG1]] : memref<16x16xf16> -> !xegpu.tensor_desc<16x16xf16>
+// CHECK: %[[T1:.*]] = xegpu.load_nd %[[T0]][%{{.*}}] <{packed}> : !xegpu.tensor_desc<16x16xf16> -> vector<16xf16>
+// CHECK-DAG: %[[T4:.*]] = xegpu.dpas %[[T3]], %[[T1]] : vector<8xf16>, vector<16xf16> -> vector<8xf32>
+// CHECK: %[[T5:.*]] = vector.shape_cast %[[T4]] : vector<8xf32> to vector<8x1xf32>
+// CHECK: %[[T6:.*]] = math.exp %[[T5]] {{{.*}}} : vector<8x1xf32>
+// CHECK-DAG: %[[T8:.*]] = vector.shape_cast %[[T6]] : vector<8x1xf32> to vector<8xf32>
+// CHECK-DAG: %[[T7:.*]] = xegpu.create_nd_tdesc %[[ARG2]] : memref<8x16xf32> -> !xegpu.tensor_desc<8x16xf32>
+// CHECK: xegpu.store_nd %[[T8]], %[[T7]][{{.*}}] : vector<8xf32>, !xegpu.tensor_desc<8x16xf32>
gpu.module @xevm_module{
gpu.func @load_dpas_postop_store(%arg0: memref<8x16xf16>, %arg1: memref<16x16xf16>, %arg2: memref<8x16xf32>) {
%c0 = arith.constant 0 : index
- %0 = xegpu.create_nd_tdesc %arg0 : memref<8x16xf16> -> !xegpu.tensor_desc<8x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
- %1 = xegpu.load_nd %0[%c0, %c0] {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>} : !xegpu.tensor_desc<8x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>> -> vector<8x16xf16>
- %2 = xegpu.create_nd_tdesc %arg1: memref<16x16xf16> -> !xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [2, 1]>>
- %3 = xegpu.load_nd %2[%c0, %c0] {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [2, 1]>} : !xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [2, 1]>> -> vector<16x16xf16>
- %4 = xegpu.dpas %1, %3 {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>} : vector<8x16xf16>, vector<16x16xf16> -> vector<8x16xf32>
- %5 = math.exp %4 {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>} : vector<8x16xf32>
- %6 = xegpu.create_nd_tdesc %arg2 : memref<8x16xf32> -> !xegpu.tensor_desc<8x16xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
- xegpu.store_nd %5, %6[%c0, %c0] : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
- gpu.return
- }
-}
+ %0 = xegpu.create_nd_tdesc %arg0 : memref<8x16xf16>
+ -> !xegpu.tensor_desc<8x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
+ %1 = xegpu.load_nd %0[%c0, %c0]
+ {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>} :
+ !xegpu.tensor_desc<8x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>> -> vector<8x16xf16>
+
+ %2 = xegpu.create_nd_tdesc %arg1: memref<16x16xf16>
+ -> !xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [2, 1]>>
+ %3 = xegpu.load_nd %2[%c0, %c0]
+ {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [2, 1]>}
+ : !xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [2, 1]>>
+ -> vector<16x16xf16>
+
+ %4 = xegpu.dpas %1, %3
+ {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>}
+ : vector<8x16xf16>, vector<16x16xf16> -> vector<8x16xf32>
-// -----
-// CHECK-LABEL: gpu.func @create_nd_tdesc_non_memref
-// CHECK: (%[[ARG0:[0-9a-zA-Z]+]]: ui64, %[[ARG1:[0-9a-zA-Z]+]]: ui64, %[[ARG2:[0-9a-zA-Z]+]]: index,
-// CHECK-SAME: %[[ARG3:[0-9a-zA-Z]+]]: index, %[[ARG4:[0-9a-zA-Z]+]]: index,
-// CHECK-SAME: %[[ARG5:[0-9a-zA-Z]+]]: index, %[[ARG6:[0-9a-zA-Z]+]]: index, %[[ARG7:[0-9a-zA-Z]+]]: index) {
-// CHECK: %[[T0:.*]] = xegpu.create_nd_tdesc %[[ARG0]], shape : [%[[ARG2]], %[[ARG3]]], strides : [%[[ARG4]], %[[ARG5]]] : ui64 -> !xegpu.tensor_desc<16x16xf16>
-// CHECK: %[[T1:.*]] = xegpu.load_nd %[[T0]][{{.*}}] : !xegpu.tensor_desc<16x16xf16> -> vector<16xf16>
-// CHECK: %[[T2:.*]] = xegpu.create_nd_tdesc %[[ARG1]], shape : [%[[ARG2]], %[[ARG3]]], strides : [%[[ARG4]], %[[ARG5]]] : ui64 -> !xegpu.tensor_desc<16x16xf16>
-// CHECK: xegpu.store_nd %[[T1]], %[[T2]][{{.*}}] : vector<16xf16>, !xegpu.tensor_desc<16x16xf16>
-gpu.module @xevm_module{
- gpu.func @create_nd_tdesc_non_memref(%arg0: ui64, %arg1: ui64, %arg2: index, %arg3: index, %arg4: index, %arg5: index, %arg6: index, %arg7: index) {
- %c0 = arith.constant 0 : index
- %0 = xegpu.create_nd_tdesc %arg0, shape:[%arg2, %arg3], strides:[%arg4, %arg5] : ui64 -> !xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
- %1 = xegpu.load_nd %0[%c0, %c0] {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>} : !xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>> -> vector<16x16xf16>
- %2 = xegpu.create_nd_tdesc %arg1, shape:[%arg2, %arg3], strides:[%arg4, %arg5] : ui64 -> !xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
- xegpu.store_nd %1, %2[%c0, %c0] : vector<16x16xf16>, !xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
+ %5 = math.exp %4
+ {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>}
+ : vector<8x16xf32>
+
+ %6 = xegpu.create_nd_tdesc %arg2 : memref<8x16xf32> ->
+ !xegpu.tensor_desc<8x16xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
+ xegpu.store_nd %5, %6[%c0, %c0] : vector<8x16xf32>,
+ !xegpu.tensor_desc<8x16xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
gpu.return
}
}
// -----
-// TODO: gemm does not use update_nd_offset because of an issue in scf-for distribution.
// CHECK-LABEL: gpu.func @gemm
-// CHECK: (%[[ARG0:[0-9a-zA-Z]+]]: memref<1024x1024xbf16>, %[[ARG1:[0-9a-zA-Z]+]]: memref<1024x1024xbf16>, %[[ARG2:[0-9a-zA-Z]+]]: memref<1024x1024xf32>) {
-// CHECK-DAG: %[[BLOCK_ID_X:.*]] = gpu.block_id x
-// CHECK-DAG: %[[BLOCK_ID_Y:.*]] = gpu.block_id y
-// CHECK-DAG: %[[Y_COORD:.*]] = arith.muli %[[BLOCK_ID_Y]], %c16 : index
-// CHECK-DAG: %[[X_COORD:.*]] = arith.muli %[[BLOCK_ID_X]], %c8 : index
-// CHECK: %[[T2:.*]] = xegpu.create_nd_tdesc %[[ARG2]] : memref<1024x1024xf32> -> !xegpu.tensor_desc<8x16xf32>
-// CHECK-NEXT: %[[T3:.*]] = xegpu.load_nd %[[T2]][%[[X_COORD]], %[[Y_COORD]]] : !xegpu.tensor_desc<8x16xf32> -> vector<8xf32>
-// CHECK-NEXT: %[[T4:.*]] = vector.shape_cast %[[T3]] : vector<8xf32> to vector<8x1xf32>
-// CHECK: %[[T5:.*]] = scf.for %[[K:.*]] = %{{.*}} to %{{.*}} step %{{.*}} iter_args(%[[ARG4:.*]] = %[[T4]]) -> (vector<8x1xf32>) {
-// CHECK-DAG: %[[T10:.*]] = xegpu.create_nd_tdesc %[[ARG1]] : memref<1024x1024xbf16> -> !xegpu.tensor_desc<16x16xbf16>
-// CHECK-DAG: %[[T11:.*]] = xegpu.load_nd %[[T10]][%[[K]], %[[Y_COORD]]] <{packed}> : !xegpu.tensor_desc<16x16xbf16> -> vector<16xbf16>
-// CHECK-DAG: %[[T12:.*]] = xegpu.create_nd_tdesc %[[ARG0]] : memref<1024x1024xbf16> -> !xegpu.tensor_desc<8x16xbf16>
-// CHECK-DAG: %[[T13:.*]] = xegpu.load_nd %[[T12]][%[[X_COORD]], %[[K]]] : !xegpu.tensor_desc<8x16xbf16> -> vector<8xbf16>
-// CHECK-DAG: %[[T14:.*]] = vector.shape_cast %[[ARG4]] : vector<8x1xf32> to vector<8xf32>
-// CHECK-NEXT: %[[T15:.*]] = xegpu.dpas %[[T13]], %[[T11]], %[[T14]] : vector<8xbf16>, vector<16xbf16>, vector<8xf32> -> vector<8xf32>
-// CHECK-NEXT: %[[T16:.*]] = vector.shape_cast %[[T15]] : vector<8xf32> to vector<8x1xf32>
-// CHECK-NEXT: scf.yield %[[T16]] : vector<8x1xf32>
-// CHECK-NEXT: }
-// CHECK-NEXT: %[[T9:.*]] = vector.shape_cast %[[T5]] : vector<8x1xf32> to vector<8xf32>
-// CHECK-NEXT: xegpu.store_nd %[[T9]], %[[T2]][%[[X_COORD]], %[[Y_COORD]]] : vector<8xf32>, !xegpu.tensor_desc<8x16xf32>
+// CHECK: (%[[ARG0:[0-9a-zA-Z]+]]: memref<1024x1024xbf16>, %[[ARG1:[0-9a-zA-Z]+]]: memref<1024x1024xbf16>,
+// CHECK-SAME: %[[ARG2:[0-9a-zA-Z]+]]: memref<1024x1024xf32>) {
+// CHECK-DAG: %[[BLOCK_ID_X:.*]] = gpu.block_id x
+// CHECK-DAG: %[[BLOCK_ID_Y:.*]] = gpu.block_id y
+// CHECK-DAG: %[[Y_COORD:.*]] = arith.muli %[[BLOCK_ID_Y]], %c16 : index
+// CHECK-DAG: %[[X_COORD:.*]] = arith.muli %[[BLOCK_ID_X]], %c8 : index
+// CHECK: %[[T2:.*]] = xegpu.create_nd_tdesc %[[ARG2]] : memref<1024x1024xf32> -> !xegpu.tensor_desc<8x16xf32>
+// CHECK-NEXT: %[[T3:.*]] = xegpu.load_nd %[[T2]][%[[X_COORD]], %[[Y_COORD]]] : !xegpu.tensor_desc<8x16xf32> -> vector<8xf32>
+// CHECK-NEXT: %[[T4:.*]] = vector.shape_cast %[[T3]] : vector<8xf32> to vector<8x1xf32>
+// CHECK: %[[T5:.*]] = scf.for %[[K:.*]] = %{{.*}} to %{{.*}} step %{{.*}} iter_args(%[[ARG4:.*]] = %[[T4]])
+// CHECK-SAME: -> (vector<8x1xf32>) {
+// CHECK-DAG: %[[T10:.*]] = xegpu.create_nd_tdesc %[[ARG1]] : memref<1024x1024xbf16> -> !xegpu.tensor_desc<16x16xbf16>
+// CHECK-DAG: %[[T11:.*]] = xegpu.load_nd %[[T10]][%[[K]], %[[Y_COORD]]] <{packed}> : !xegpu.tensor_desc<16x16xbf16> -> vector<16xbf16>
+// CHECK-DAG: %[[T12:.*]] = xegpu.create_nd_tdesc %[[ARG0]] : memref<1024x1024xbf16> -> !xegpu.tensor_desc<8x16xbf16>
+// CHECK-DAG: %[[T13:.*]] = xegpu.load_nd %[[T12]][%[[X_COORD]], %[[K]]] : !xegpu.tensor_desc<8x16xbf16> -> vector<8xbf16>
+// CHECK-DAG: %[[T14:.*]] = vector.shape_cast %[[ARG4]] : vector<8x1xf32> to vector<8xf32>
+// CHECK-NEXT: %[[T15:.*]] = xegpu.dpas %[[T13]], %[[T11]], %[[T14]]
+// CHECK-SAME: : vector<8xbf16>, vector<16xbf16>, vector<8xf32> -> vector<8xf32>
+// CHECK-NEXT: %[[T16:.*]] = vector.shape_cast %[[T15]] : vector<8xf32> to vector<8x1xf32>
+// CHECK-NEXT: scf.yield %[[T16]] : vector<8x1xf32>
+// CHECK-NEXT: }
+// CHECK-NEXT: %[[T9:.*]] = vector.shape_cast %[[T5]] : vector<8x1xf32> to vector<8xf32>
+// CHECK-NEXT: xegpu.store_nd %[[T9]], %[[T2]][%[[X_COORD]], %[[Y_COORD]]] : vector<8xf32>, !xegpu.tensor_desc<8x16xf32>
gpu.module @xevm_module{
gpu.func @gemm(%arg0: memref<1024x1024xbf16>, %arg1: memref<1024x1024xbf16>, %arg2: memref<1024x1024xf32>){
%c0 = arith.constant 0 : index
@@ -203,213 +81,56 @@ gpu.func @gemm(%arg0: memref<1024x1024xbf16>, %arg1: memref<1024x1024xbf16>, %ar
%block_id_y = gpu.block_id y
%0 = arith.muli %block_id_x, %c8 : index
%1 = arith.muli %block_id_y, %c16 : index
- %2 = xegpu.create_nd_tdesc %arg2 : memref<1024x1024xf32> -> !xegpu.tensor_desc<8x16xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
- %3 = xegpu.load_nd %2[%0, %1] {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>} : !xegpu.tensor_desc<8x16xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>> -> vector<8x16xf32>
- %4 = scf.for %arg3 = %c0 to %c1024 step %c16 iter_args(%arg4 = %3) -> (vector<8x16xf32>) {
- %5 = xegpu.create_nd_tdesc %arg0: memref<1024x1024xbf16> -> !xegpu.tensor_desc<8x16xbf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
- %6 = xegpu.create_nd_tdesc %arg1 : memref<1024x1024xbf16> -> !xegpu.tensor_desc<16x16xbf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [2, 1]>>
- %7 = xegpu.load_nd %5[%0, %arg3] {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>} : !xegpu.tensor_desc<8x16xbf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>> -> vector<8x16xbf16>
- %8 = xegpu.load_nd %6[%arg3, %1] {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [2, 1]>} : !xegpu.tensor_desc<16x16xbf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [2, 1]>> -> vector<16x16xbf16>
- %9 = xegpu.dpas %7, %8, %arg4 {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>} : vector<8x16xbf16>, vector<16x16xbf16>, vector<8x16xf32> -> vector<8x16xf32>
- scf.yield %9 : vector<8x16xf32>
- } {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>}
- xegpu.store_nd %4, %2[%0, %1] : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
- gpu.return
-}
-}
+ %2 = xegpu.create_nd_tdesc %arg2 : memref<1024x1024xf32> ->
+ !xegpu.tensor_desc<8x16xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
+ %3 = xegpu.load_nd %2[%0, %1]
+ {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>}
+ : !xegpu.tensor_desc<8x16xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>> -> vector<8x16xf32>
-// -----
-// CHECK-LABEL: gpu.func @prefetch_2d
-// CHECK: (%[[ARG0:[0-9a-zA-Z]+]]: memref<256x256xf16>) {
-// CHECK: %[[T0:.*]] = xegpu.create_nd_tdesc %[[ARG0]] : memref<256x256xf16> -> !xegpu.tensor_desc<16x16xf16>
-// CHECK: xegpu.prefetch_nd %[[T0]][%{{.*}}] <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}> : !xegpu.tensor_desc<16x16xf16>
-gpu.module @xevm_module{
- gpu.func @prefetch_2d(%arg0: memref<256x256xf16>) {
- %c0 = arith.constant 0 : index
- %0 = xegpu.create_nd_tdesc %arg0 : memref<256x256xf16> -> !xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
- xegpu.prefetch_nd %0[%c0, %c0] <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}> : !xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
- gpu.return
- }
-}
-
-// -----
-// CHECK-LABEL: gpu.func @prefetch_1d
-// CHECK: (%[[ARG0:[0-9a-zA-Z]+]]: memref<256xf16>) {
-// CHECK: %[[T0:.*]] = xegpu.create_nd_tdesc %[[ARG0]] : memref<256xf16> -> !xegpu.tensor_desc<16xf16>
-// CHECK: xegpu.prefetch_nd %[[T0]][%{{.*}}] <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}> : !xegpu.tensor_desc<16xf16>
-gpu.module @xevm_module{
- gpu.func @prefetch_1d(%arg0: memref<256xf16>) {
- %c0 = arith.constant 0 : index
- %0 = xegpu.create_nd_tdesc %arg0: memref<256xf16> -> !xegpu.tensor_desc<16xf16, #xegpu.layout<lane_layout = [16], lane_data = [1]>>
- xegpu.prefetch_nd %0[%c0] <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}> : !xegpu.tensor_desc<16xf16, #xegpu.layout<lane_layout = [16], lane_data = [1]>>
- gpu.return
- }
-}
+ %4 = scf.for %arg3 = %c0 to %c1024 step %c16 iter_args(%arg4 = %3) -> (vector<8x16xf32>) {
-// -----
-// CHECK-LABEL: gpu.func @gpu_barrier({{.*}}) {
-// CHECK: %[[T0:.*]] = xegpu.create_nd_tdesc %{{.*}} : memref<256xf16> -> !xegpu.tensor_desc<16xf16>
-// CHECK-NEXT: %[[T1:.*]] = xegpu.load_nd %[[T0]][{{.*}}] : !xegpu.tensor_desc<16xf16> -> vector<1xf16>
-// CHECK-NEXT: gpu.barrier
-// CHECK-NEXT: %[[T2:.*]] = xegpu.create_nd_tdesc %{{.*}} : memref<256xf16> -> !xegpu.tensor_desc<16xf16>
-// CHECK-NEXT: xegpu.store_nd %[[T1]], %[[T2]][{{.*}}] : vector<1xf16>, !xegpu.tensor_desc<16xf16>
-gpu.module @xevm_module{
- gpu.func @gpu_barrier(%arg0: memref<256xf16>, %arg1: memref<256xf16>) {
- %c0 = arith.constant 0 : index
- %0 = xegpu.create_nd_tdesc %arg0 : memref<256xf16> -> !xegpu.tensor_desc<16xf16, #xegpu.layout<lane_layout = [16], lane_data = [1]>>
- %1 = xegpu.load_nd %0[%c0] {layout_result_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>} : !xegpu.tensor_desc<16xf16, #xegpu.layout<lane_layout = [16], lane_data = [1]>> -> vector<16xf16>
- gpu.barrier
- %2 = xegpu.create_nd_tdesc %arg1 : memref<256xf16> -> !xegpu.tensor_desc<16xf16, #xegpu.layout<lane_layout = [16], lane_data = [1]>>
- xegpu.store_nd %1, %2[%c0] : vector<16xf16>, !xegpu.tensor_desc<16xf16, #xegpu.layout<lane_layout = [16], lane_data = [1]>>
- gpu.return
- }
-}
+ %5 = xegpu.create_nd_tdesc %arg0: memref<1024x1024xbf16>
+ -> !xegpu.tensor_desc<8x16xbf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
+ %6 = xegpu.create_nd_tdesc %arg1 : memref<1024x1024xbf16>
+ -> !xegpu.tensor_desc<16x16xbf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [2, 1]>>
-// -----
-// CHECK-LABEL: gpu.func @vector_multi_reduction_dim1_distributed_dim0_reduction
-// CHECK: %[[W:.*]]:2 = gpu.warp_execute_on_lane_0(%{{.*}})[16] ->
-// CHECK-SAME: (!xegpu.tensor_desc<1x32xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>, vector<16x2xf32>) {
-// CHECK: %[[SRC:.*]] = "some_def"() {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>} : () -> vector<16x32xf32>
-// CHECK-NEXT: gpu.yield %{{.*}}, %[[SRC]] : !xegpu.tensor_desc<1x32xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>, vector<16x32xf32>
-// CHECK-NEXT: }
-// CHECK: %[[COL0:.*]] = vector.extract_strided_slice %[[W]]#1 {offsets = [0, 0], sizes = [16, 1], strides = [1, 1]} : vector<16x2xf32> to vector<16x1xf32>
-// CHECK-NEXT: %[[CAST0:.*]] = vector.shape_cast %[[COL0]] : vector<16x1xf32> to vector<16xf32>
-// CHECK-NEXT: %[[RED0:.*]] = vector.reduction <add>, %[[CAST0]], %{{.*}} : vector<16xf32> into f32
-// CHECK: %[[COL1:.*]] = vector.extract_strided_slice %[[W]]#1 {offsets = [0, 1], sizes = [16, 1], strides = [1, 1]} : vector<16x2xf32> to vector<16x1xf32>
-// CHECK-NEXT: %[[CAST1:.*]] = vector.shape_cast %[[COL1]] : vector<16x1xf32> to vector<16xf32>
-// CHECK-NEXT: %[[RED1:.*]] = vector.reduction <add>, %[[CAST1]], %{{.*}} : vector<16xf32> into f32
-// CHECK-NEXT: vector.from_elements %[[RED0]], %[[RED1]] : vector<2xf32>
-gpu.module @xevm_module{
-gpu.func @vector_multi_reduction_dim1_distributed_dim0_reduction() {
- %c0 = arith.constant 0 : index
- %0 = "some_def"() : () -> !xegpu.tensor_desc<1x32xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
- %src = "some_def"() {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>} : () -> (vector<16x32xf32>)
- %acc = arith.constant {layout_result_0 = #xegpu.slice<#xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>, dims = [0]>} dense<0.0> : vector<32xf32>
- %1 = vector.multi_reduction <add>, %src, %acc {layout_result_0 = #xegpu.slice<#xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>, dims = [0]>} [0]
- : vector<16x32xf32> to vector<32xf32>
- %3 = vector.shape_cast %1 {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>}
- : vector<32xf32> to vector<1x32xf32>
- xegpu.store_nd %3, %0[%c0, %c0] : vector<1x32xf32>, !xegpu.tensor_desc<1x32xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
- gpu.return
-}
-}
+ %7 = xegpu.load_nd %5[%0, %arg3]
+ {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>}
+ : !xegpu.tensor_desc<8x16xbf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>> -> vector<8x16xbf16>
+ %8 = xegpu.load_nd %6[%arg3, %1]
+ {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [2, 1]>}
+ : !xegpu.tensor_desc<16x16xbf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [2, 1]>> -> vector<16x16xbf16>
-// -----
-// CHECK-REDUCTION-LABEL: gpu.func @vector_multi_reduction_dim1_distributed_dim1_reduction
-// CHECK-REDUCTION: %[[W:.*]]:3 = gpu.warp_execute_on_lane_0(%{{.*}})[16] -> (!xegpu.tensor_desc<2x16xf32,
-// CHECK-REDUCTION-SAME: #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>, f32, f32) {
-// CHECK-REDUCTION: %[[SRC:.*]] = "some_def"() {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>} : () -> vector<2x16xf32>
-// CHECK-REDUCTION-NEXT: %[[ROW0:.*]] = vector.extract %[[SRC]][0] : vector<16xf32> from vector<2x16xf32>
-// CHECK-REDUCTION-NEXT: %[[R0:.*]] = vector.reduction <add>, %[[ROW0]], %{{.*}} : vector<16xf32> into f32
-// CHECK-REDUCTION-NEXT: %[[ROW1:.*]] = vector.extract %[[SRC]][1] : vector<16xf32> from vector<2x16xf32>
-// CHECK-REDUCTION-NEXT: %[[R1:.*]] = vector.reduction <add>, %[[ROW1]], %{{.*}} : vector<16xf32> into f32
-// CHECK-REDUCTION-NEXT: gpu.yield %4, %[[R1]], %[[R0]] : !xegpu.tensor_desc<2x16xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>, f32, f32
-// CHECK-REDUCTION-NEXT: }
-// CHECK-REDUCTION-NEXT: vector.from_elements %[[W]]#2, %[[W]]#1 : vector<2xf32>
-gpu.module @xevm_module{
-gpu.func @vector_multi_reduction_dim1_distributed_dim1_reduction() {
- %c0 = arith.constant 0 : index
- %0 = "some_def"() : () -> !xegpu.tensor_desc<2x16xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
- %src = "some_def"() {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>} : () -> (vector<2x16xf32>)
- %acc = arith.constant {layout_result_0 = #xegpu.slice<#xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>, dims = [1]>} dense<0.0> : vector<2xf32>
- %1 = vector.multi_reduction <add>, %src, %acc {layout_result_0 = #xegpu.slice<#xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>, dims = [1]>}
- [1] : vector<2x16xf32> to vector<2xf32>
- %3 = vector.shape_cast %1 {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>}
- : vector<2xf32> to vector<2x1xf32>
- %4 = vector.broadcast %3 {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>} : vector<2x1xf32> to vector<2x16xf32>
- xegpu.store_nd %4, %0[%c0, %c0] : vector<2x16xf32>, !xegpu.tensor_desc<2x16xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
- gpu.return
-}
-}
+ %9 = xegpu.dpas %7, %8, %arg4
+ {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>}
+ : vector<8x16xbf16>, vector<16x16xbf16>, vector<8x16xf32> -> vector<8x16xf32>
-// -----
-// CHECK-LABEL: gpu.func @vector_multi_reduction_dim0_distributed_dim1_reduction
-// CHECK: %[[W:.*]]:2 = gpu.warp_execute_on_lane_0(%0)[16] ->
-// CHECK-SAME: (!xegpu.tensor_desc<32x1xf32, #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 1]>>, vector<2x16xf32>) {
-// CHECK: %[[SRC:.*]] = "some_def"() {layout_result_0 = #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 1]>} : () -> vector<32x16xf32>
-// CHECK-NEXT: gpu.yield %{{.*}}, %[[SRC]] : !xegpu.tensor_desc<32x1xf32, #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 1]>>, vector<32x16xf32>
-// CHECK-NEXT: }
-// CHECK: %[[ROW0:.*]] = vector.extract %[[W]]#1[0] : vector<16xf32> from vector<2x16xf32>
-// CHECK-NEXT: %[[R0:.*]] = vector.reduction <add>, %[[ROW0]], %{{.*}} : vector<16xf32> into f32
-// CHECK: %[[ROW1:.*]] = vector.extract %[[W]]#1[1] : vector<16xf32> from vector<2x16xf32>
-// CHECK-NEXT: %[[R1:.*]] = vector.reduction <add>, %[[ROW1]], %{{.*}} : vector<16xf32> into f32
-// CHECK-NEXT: vector.from_elements %[[R0]], %[[R1]] : vector<2xf32>
-gpu.module @xevm_module{
-gpu.func @vector_multi_reduction_dim0_distributed_dim1_reduction() {
- %c0 = arith.constant 0 : index
- %0 = "some_def"() : () -> !xegpu.tensor_desc<32x1xf32, #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 1]>>
- %src = "some_def"() {layout_result_0 = #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 1]>} : () -> (vector<32x16xf32>)
- %acc = arith.constant {layout_result_0 = #xegpu.slice<#xegpu.layout<lane_layout = [16, 1], lane_data = [1, 1]>, dims = [1]>} dense<0.0> : vector<32xf32>
- %1 = vector.multi_reduction <add>, %src, %acc {layout_result_0 = #xegpu.slice<#xegpu.layout<lane_layout = [16, 1], lane_data = [1, 1]>, dims = [1]>} [1]
- : vector<32x16xf32> to vector<32xf32>
- %3 = vector.shape_cast %1 {layout_result_0 = #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 1]>}
- : vector<32xf32> to vector<32x1xf32>
- xegpu.store_nd %3, %0[%c0, %c0] : vector<32x1xf32>, !xegpu.tensor_desc<32x1xf32, #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 1]>>
- gpu.return
-}
-}
+ scf.yield %9 : vector<8x16xf32>
+ } {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>}
-// -----
-// CHECK-REDUCTION-LABEL: gpu.func @vector_multi_reduction_dim0_distributed_dim0_reduction
-// CHECK-REDUCTION: %[[W:.*]]:3 = gpu.warp_execute_on_lane_0(%{{.*}})[16] -> (!xegpu.tensor_desc<16x2xf32,
-// CHECK-REDUCTION-SAME: #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 1]>>, f32, f32) {
-// CHECK-REDUCTION: %[[SRC:.*]] = "some_def"() {layout_result_0 = #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 1]>} : () -> vector<16x2xf32>
-// CHECK-REDUCTION-NEXT: %[[COL0:.*]] = vector.extract_strided_slice %[[SRC]] {offsets = [0, 0], sizes = [16, 1], strides = [1, 1]} : vector<16x2xf32> to vector<16x1xf32>
-// CHECK-REDUCTION-NEXT: %[[CAST0:.*]] = vector.shape_cast %[[COL0]] : vector<16x1xf32> to vector<16xf32>
-// CHECK-REDUCTION-NEXT: %[[R0:.*]] = vector.reduction <add>, %[[CAST0]], %{{.*}} : vector<16xf32> into f32
-// CHECK-REDUCTION-NEXT: %[[COL1:.*]] = vector.extract_strided_slice %5 {offsets = [0, 1], sizes = [16, 1], strides = [1, 1]} : vector<16x2xf32> to vector<16x1xf32>
-// CHECK-REDUCTION-NEXT: %[[CAST1:.*]] = vector.shape_cast %[[COL1]] : vector<16x1xf32> to vector<16xf32>
-// CHECK-REDUCTION-NEXT: %[[R1:.*]] = vector.reduction <add>, %[[CAST1]], %cst : vector<16xf32> into f32
-// CHECK-REDUCTION-NEXT: gpu.yield %4, %[[R1]], %[[R0]] : !xegpu.tensor_desc<16x2xf32, #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 1]>>, f32, f32
-// CHECK-REDUCTION-NEXT: }
-// CHECK-REDUCTION-NEXT: vector.from_elements %[[W]]#2, %[[W]]#1 : vector<2xf32>
-gpu.module @xevm_module{
-gpu.func @vector_multi_reduction_dim0_distributed_dim0_reduction() {
- %c0 = arith.constant 0 : index
- %0 = "some_def"() : () -> !xegpu.tensor_desc<16x2xf32, #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 1]>>
- %src = "some_def"() {layout_result_0 = #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 1]>} : () -> (vector<16x2xf32>)
- %acc = arith.constant {layout_result_0 = #xegpu.slice<#xegpu.layout<lane_layout = [16, 1], lane_data = [1, 1]>, dims = [0]>} dense<0.0> : vector<2xf32>
- %1 = vector.multi_reduction <add>, %src, %acc {layout_result_0 = #xegpu.slice<#xegpu.layout<lane_layout = [16, 1], lane_data = [1, 1]>, dims = [0]>}
- [0] : vector<16x2xf32> to vector<2xf32>
- %3 = vector.shape_cast %1 {layout_result_0 = #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 1]>}
- : vector<2xf32> to vector<1x2xf32>
- %4 = vector.broadcast %3 {layout_result_0 = #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 1]>} : vector<1x2xf32> to vector<16x2xf32>
- xegpu.store_nd %4, %0[%c0, %c0] : vector<16x2xf32>, !xegpu.tensor_desc<16x2xf32, #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 1]>>
+ xegpu.store_nd %4, %2[%0, %1] : vector<8x16xf32>,
+ !xegpu.tensor_desc<8x16xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
gpu.return
}
}
// -----
-// CHECK-LABEL: gpu.func @scatter_ops_chunksize({{.*}}) {
-// CHECK: %[[MASK:.*]] = arith.constant dense<true> : vector<1xi1>
-// CHECK-NEXT: %[[LANE_OFFSET:.*]] = arith.constant dense<12> : vector<1xindex>
-// CHECK-NEXT: %[[LOADED:.*]] = xegpu.load %arg0[%[[LANE_OFFSET]]], %[[MASK]] <{chunk_size = 8 : i64}> : memref<256xf16>, vector<1xindex>, vector<1xi1> -> vector<8xf16>
-// CHECK-NEXT: xegpu.store %[[LOADED]], %arg0[%[[LANE_OFFSET]]], %[[MASK]] <{chunk_size = 8 : i64}> : vector<8xf16>, memref<256xf16>, vector<1xindex>, vector<1xi1>
-gpu.module @xevm_module{
- gpu.func @scatter_ops_chunksize(%src: memref<256xf16>) {
- %1 = arith.constant {layout_result_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>} dense<1>: vector<16xi1>
- %offset = arith.constant {layout_result_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>} dense<12> : vector<16xindex>
- %3 = xegpu.load %src[%offset], %1 <{chunk_size=8}> {
- layout_result_0 = #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 2]>
- } : memref<256xf16>, vector<16xindex>, vector<16xi1> -> vector<16x8xf16>
- xegpu.store %3, %src[%offset], %1 <{chunk_size=8}> : vector<16x8xf16>, memref<256xf16>, vector<16xindex>, vector<16xi1>
- gpu.return
- }
-}
-
-// -----
-// CHECK-LABEL: gpu.func @scatter_ops_scf_yield({{.*}},
-// CHECK-SAME: %[[PREDICATE:.*]]: i1) {
-// CHECK: %[[DEFAULT:.*]] = arith.constant dense<1.200000e+01> : vector<8xf16>
-// CHECK: %[[OFFSET:.*]] = arith.constant dense<12> : vector<1xindex>
-// CHECK: %[[MASK:.*]] = arith.constant dense<true> : vector<1xi1>
-// CHECK: %[[PREDICATED_LOAD:.*]] = scf.if %[[PREDICATE]] -> (vector<8xf16>) {
-// CHECK-NEXT: %[[LOADED:.*]] = xegpu.load %arg0[%[[OFFSET]]], %[[MASK]] <{chunk_size = 8 : i64}> : memref<256xf16>, vector<1xindex>, vector<1xi1> -> vector<8xf16>
-// CHECK-NEXT: scf.yield %[[LOADED]] : vector<8xf16>
-// CHECK-NEXT: } else {
-// CHECK-NEXT: scf.yield %[[DEFAULT]] : vector<8xf16>
-// CHECK-NEXT: }
-// CHECK-NEXT: xegpu.store %[[PREDICATED_LOAD]], %arg0[%[[OFFSET]]], %[[MASK]] <{chunk_size = 8 : i64}> : vector<8xf16>, memref<256xf16>, vector<1xindex>, vector<1xi1>
+// CHECK-LABEL: gpu.func @scatter_ops_scf_yield
+// CHECK: (%{{.*}}: memref<256xf16>, %[[PREDICATE:[a-zA-Z0-9]+]]: i1) {
+// CHECK-DAG: %[[CST:.*]] = arith.constant dense<1.200000e+01> : vector<1x8xf16>
+// CHECK-DAG: %[[OFFSET:.*]] = arith.constant dense<12> : vector<1xindex>
+// CHECK-DAG: %[[MASK:.*]] = arith.constant dense<true> : vector<1xi1>
+// CHECK: %[[IF:.*]] = scf.if %[[PREDICATE]] -> (vector<1x8xf16>) {
+// CHECK-NEXT: %[[LD:.*]] = xegpu.load %{{.*}}[%[[OFFSET]]], %[[MASK]] <{chunk_size = 8 : i64}>
+// CHECK-SAME: : memref<256xf16>, vector<1xindex>, vector<1xi1> -> vector<8xf16>
+// CHECK-NEXT: %[[LD_CAST:.*]] = vector.shape_cast %[[LD]] : vector<8xf16> to vector<1x8xf16>
+// CHECK-NEXT: scf.yield %[[LD_CAST]] : vector<1x8xf16>
+// CHECK-NEXT: } else {
+// CHECK-NEXT: scf.yield %[[CST]] : vector<1x8xf16>
+// CHECK-NEXT: }
+// CHECK-NEXT: %[[IF_CAST:.*]] = vector.shape_cast %[[IF]] : vector<1x8xf16> to vector<8xf16>
+// CHECK-NEXT: xegpu.store %[[IF_CAST]], %{{.*}}[%[[OFFSET]]], %[[MASK]] <{chunk_size = 8 : i64}>
+// CHECK-SAME: vector<8xf16>, memref<256xf16>, vector<1xindex>, vector<1xi1>
gpu.module @xevm_module{
gpu.func @scatter_ops_scf_yield(%src: memref<256xf16>, %pred : i1) {
%1 = arith.constant {layout_result_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>} dense<1>: vector<16xi1>
@@ -432,13 +153,15 @@ gpu.module @xevm_module{
// -----
// CHECK-LABEL: gpu.func @scatter_ops_scf_non_yield({{.*}}) {
-// CHECK: %[[OFFSET:.*]] = arith.constant dense<12> : vector<1xindex>
-// CHECK: %[[MASK:.*]] = arith.constant dense<true> : vector<1xi1>
-// CHECK: %[[PREDICATE:.*]] = llvm.mlir.poison : i1
-// CHECK: scf.if %[[PREDICATE]] {
-// CHECK-NEXT: %[[LOADED:.*]] = xegpu.load %arg0[%[[OFFSET]]], %[[MASK]] <{chunk_size = 8 : i64}> : memref<256xf16>, vector<1xindex>, vector<1xi1> -> vector<8xf16>
-// CHECK-NEXT: xegpu.store %[[LOADED]], %arg0[%[[OFFSET]]], %[[MASK]] <{chunk_size = 8 : i64}> : vector<8xf16>, memref<256xf16>, vector<1xindex>, vector<1xi1>
-// CHECK-NEXT: }
+// CHECK: %[[OFFSET:.*]] = arith.constant dense<12> : vector<1xindex>
+// CHECK: %[[MASK:.*]] = arith.constant dense<true> : vector<1xi1>
+// CHECK: %[[PREDICATE:.*]] = llvm.mlir.poison : i1
+// CHECK: scf.if %[[PREDICATE]] {
+// CHECK-NEXT: %[[LOADED:.*]] = xegpu.load %arg0[%[[OFFSET]]], %[[MASK]] <{chunk_size = 8 : i64}>
+// CHECK-SAME: memref<256xf16>, vector<1xindex>, vector<1xi1> -> vector<8xf16>
+// CHECK-NEXT: xegpu.store %[[LOADED]], %arg0[%[[OFFSET]]], %[[MASK]] <{chunk_size = 8 : i64}>
+// CHECK-SAME: vector<8xf16>, memref<256xf16>, vector<1xindex>, vector<1xi1>
+// CHECK-NEXT: }
gpu.module @xevm_module{
gpu.func @scatter_ops_scf_non_yield(%src: memref<256xf16>) {
%pred = llvm.mlir.poison : i1
@@ -454,89 +177,14 @@ gpu.module @xevm_module{
}
}
-// -----
-// CHECK-LABEL: gpu.func @scatter_ops({{.*}}) {
-// CHECK: %[[MASK:.*]] = arith.constant dense<true> : vector<1xi1>
-// CHECK-NEXT: %[[LANE_OFFSET:.*]] = arith.constant dense<12> : vector<1xindex>
-// CHECK-NEXT: %[[LOADED:.*]] = xegpu.load %arg0[%[[LANE_OFFSET]]], %[[MASK]] : memref<256xf16>, vector<1xindex>, vector<1xi1> -> vector<1xf16>
-// CHECK-NEXT: xegpu.store %[[LOADED]], %arg0[%[[LANE_OFFSET]]], %[[MASK]] : vector<1xf16>, memref<256xf16>, vector<1xindex>, vector<1xi1>
-gpu.module @xevm_module{
- gpu.func @scatter_ops(%src: memref<256xf16>) {
- %1 = arith.constant {layout_result_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>} dense<1>: vector<16xi1>
- %offset = arith.constant {layout_result_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>} dense<12> : vector<16xindex>
- %3 = xegpu.load %src[%offset], %1 {
- layout_result_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>
- } : memref<256xf16>, vector<16xindex>, vector<16xi1> -> vector<16xf16>
- xegpu.store %3, %src[%offset], %1 : vector<16xf16>, memref<256xf16>, vector<16xindex>, vector<16xi1>
- gpu.return
- }
-}
-
-// -----
-// CHECK-LABEL: gpu.func @memref_extract_aligned_pointer_as_index(
-// CHECK: %{{.*}} = memref.extract_aligned_pointer_as_index %{{.*}} : memref<256x256xf16> -> index
-gpu.module @xevm_module{
- gpu.func @memref_extract_aligned_pointer_as_index(%arg0 : memref<256x256xf16>) {
- %c0 = arith.constant 0 : index
- %cst = arith.constant {layout_result_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>} dense<1.000000e+00> : vector<16xf16>
- %ptr = memref.extract_aligned_pointer_as_index %arg0 : memref<256x256xf16> -> index
- %ptr_i64 = arith.index_cast %ptr : index to i64
- %tdesc = xegpu.create_nd_tdesc %ptr_i64, shape: [16], strides: [16] : i64
- -> !xegpu.tensor_desc<16xf16, #xegpu.layout<lane_layout = [16], lane_data = [1]>>
- xegpu.store_nd %cst, %tdesc[%c0] : vector<16xf16>, !xegpu.tensor_desc<16xf16, #xegpu.layout<lane_layout = [16], lane_data = [1]>>
- gpu.return
- }
-}
-
-
-// -----
-// CHECK-LABEL: gpu.func @vector_transpose(
-// CHECK: %[[CST:.*]] = arith.constant dense<1.000000e+00> : vector<2xf32>
-// CHECK: %[[DEST:.*]] = xegpu.create_nd_tdesc %{{.*}} : memref<2x16xf32> -> !xegpu.tensor_desc<2x16xf32>
-// CHECK: xegpu.store_nd %[[CST]], %[[DEST]][{{.*}}] : vector<2xf32>, !xegpu.tensor_desc<2x16xf32>
-gpu.module @xevm_module{
- gpu.func @vector_transpose(%arg0: memref<2x16xf32>) {
- %cst = arith.constant {layout_result_0 = #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 1]>} dense<1.000000e+00>
- : vector<16x2xf32>
- %c0 = arith.constant 0 : index
- %transpose = vector.transpose %cst, [1, 0] {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>}
- : vector<16x2xf32> to vector<2x16xf32>
- %0 = xegpu.create_nd_tdesc %arg0 : memref<2x16xf32>
- -> !xegpu.tensor_desc<2x16xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
- xegpu.store_nd %transpose, %0[%c0, %c0] : vector<2x16xf32>,
- !xegpu.tensor_desc<2x16xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
- gpu.return
- }
-}
-
-// -----
-// CHECK-LABEL: gpu.func @vector_bitcast(
-// CHECK: %[[CAST:.*]] = vector.bitcast %{{.*}} : vector<4x2xi8> to vector<4x1xi16>
-// CHECK-NEXT: %[[DEST:.*]] = xegpu.create_nd_tdesc %{{.*}} : memref<4x16xi16> -> !xegpu.tensor_desc<4x16xi16>
-// CHECK-NEXT: %[[T0:.*]] = vector.shape_cast %[[CAST]] : vector<4x1xi16> to vector<4xi16>
-// CHECK-NEXT: xegpu.store_nd %[[T0]], %[[DEST]][{{.*}}] : vector<4xi16>, !xegpu.tensor_desc<4x16xi16>
-gpu.module @xevm_module{
- gpu.func @vector_bitcast(%arg0: memref<4x16xi16>) {
- %cst = "some_op"() {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 2]>}
- : () -> (vector<4x32xi8>)
- %bitcast = vector.bitcast %cst {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>}
- : vector<4x32xi8> to vector<4x16xi16>
- %c0 = arith.constant 0 : index
- %0 = xegpu.create_nd_tdesc %arg0 : memref<4x16xi16>
- -> !xegpu.tensor_desc<4x16xi16, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
- xegpu.store_nd %bitcast, %0[%c0, %c0] : vector<4x16xi16>,
- !xegpu.tensor_desc<4x16xi16, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
- gpu.return
- }
-}
-
// -----
// CHECK-LABEL: gpu.func @mma_transpose_b(
// CHECK: %[[ARG0:[0-9a-zA-Z]+]]: memref<8x16xf16>, %[[ARG1:[0-9a-zA-Z]+]]: memref<16x8xi32>, %[[ARG2:[0-9a-zA-Z]+]]: memref<8x16xf32>) {
// CHECK-DAG: %[[ADESC:.*]] = xegpu.create_nd_tdesc %[[ARG0]] : memref<8x16xf16> -> !xegpu.tensor_desc<8x16xf16>
// CHECK-DAG: %[[BDESC:.*]] = xegpu.create_nd_tdesc %[[ARG1]] : memref<16x8xi32> -> !xegpu.tensor_desc<16x8xi32>
// CHECK-DAG: %[[A:.*]] = xegpu.load_nd %[[ADESC]][%{{.*}}] : !xegpu.tensor_desc<8x16xf16> -> vector<8xf16>
-// CHECK-DAG: %[[B:.*]] = xegpu.load_nd %[[BDESC]][%{{.*}}] <{transpose = array<i64: 1, 0>}> : !xegpu.tensor_desc<16x8xi32> -> vector<8xi32>
+// CHECK-DAG: %[[B:.*]] = xegpu.load_nd %[[BDESC]][%{{.*}}] <{transpose = array<i64: 1, 0>}>
+// CHECK-SAME: !xegpu.tensor_desc<16x8xi32> -> vector<8xi32>
// CHECK-NEXT: %[[BCAST0:.*]] = vector.shape_cast %[[B]] : vector<8xi32> to vector<1x8xi32>
// CHECK-NEXT: %[[BCAST1:.*]] = vector.bitcast %[[BCAST0]] : vector<1x8xi32> to vector<1x16xf16>
// CHECK-NEXT: %[[BCAST2:.*]] = vector.shape_cast %[[BCAST1]] : vector<1x16xf16> to vector<16xf16>
diff --git a/mlir/test/lib/Dialect/XeGPU/TestXeGPUTransforms.cpp b/mlir/test/lib/Dialect/XeGPU/TestXeGPUTransforms.cpp
index e51cac4286f0c..6ba7a004b7d31 100644
--- a/mlir/test/lib/Dialect/XeGPU/TestXeGPUTransforms.cpp
+++ b/mlir/test/lib/Dialect/XeGPU/TestXeGPUTransforms.cpp
@@ -218,6 +218,35 @@ class TestStepOpPattern : public OpConversionPattern<vector::StepOp> {
}
};
+struct TestXeGPUSGDistribute
+ : public PassWrapper<TestXeGPUSGDistribute,
+ OperationPass<gpu::GPUModuleOp>> {
+ MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestXeGPUSGDistribute)
+
+ StringRef getArgument() const final { return "test-xegpu-sg-distribute"; }
+
+ StringRef getDescription() const final {
+ return "Test the implementation of XeGPU Subgroup Distribution";
+ }
+
+ void getDependentDialects(::mlir::DialectRegistry &registry) const override {
+ registry.insert<arith::ArithDialect>();
+ registry.insert<memref::MemRefDialect>();
+ registry.insert<xegpu::XeGPUDialect>();
+ registry.insert<vector::VectorDialect>();
+ registry.insert<index::IndexDialect>();
+ }
+
+ TestXeGPUSGDistribute() = default;
+ TestXeGPUSGDistribute(const TestXeGPUSGDistribute &pass) = default;
+
+ void runOnOperation() override {
+ RewritePatternSet patterns(&getContext());
+ xegpu::populateXeGPUSubgroupDistributePatterns(patterns);
+ (void)applyPatternsGreedily(getOperation(), std::move(patterns));
+ }
+};
+
struct TestXeGPULayoutInterface
: public PassWrapper<TestXeGPULayoutInterface,
OperationPass<gpu::GPUModuleOp>> {
@@ -282,6 +311,7 @@ namespace test {
void registerTestXeGPULowerings() {
PassRegistration<TestXeGPUUnrollingPatterns>();
PassRegistration<TestXeGPULayoutInterface>();
+ PassRegistration<TestXeGPUSGDistribute>();
}
} // namespace test
} // namespace mlir
More information about the Mlir-commits
mailing list