[Mlir-commits] [mlir] [MLIR][XeGPU] Remove the transpose attribute from LoadGatherOp and StoreScatterOp (PR #145389)
llvmlistbot at llvm.org
Mon Jun 23 12:00:42 PDT 2025
llvmbot wrote:
@llvm/pr-subscribers-mlir-gpu
Author: Chao Chen (chencha3)
As the title suggests, this PR removes the `transpose` attribute from the definitions of `LoadGatherOp` and `StoreScatterOp`. The attribute is meaningful in the SIMD lowering pipeline, but not in the SIMT lowering pipeline.
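For illustration, here is the op after the change, adapted from the updated `ops.mlir` test in this patch: with `chunk_size = 2` the loaded vector now keeps the descriptor's `4x2` shape instead of the previously transposed `2x4`. (The offset constant is an assumption for readability; the truncated diff does not show its value.)

```mlir
gpu.func @subgroup_load(%src: ui64) {
  // Per-lane offsets; the concrete values here are assumed, not from the patch.
  %0 = arith.constant dense<[0, 8, 16, 24]> : vector<4xindex>
  // All four lanes active.
  %1 = arith.constant dense<1> : vector<4xi1>
  %2 = xegpu.create_tdesc %src, %0 : ui64, vector<4xindex>
         -> !xegpu.tensor_desc<4x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2>>
  // No transpose attribute; the result shape follows the tensor_desc shape.
  %3 = xegpu.load %2, %1 <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}>
         : !xegpu.tensor_desc<4x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2>>, vector<4xi1> -> vector<4x2xf32>
  gpu.return
}
```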
---
Patch is 45.54 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/145389.diff
9 Files Affected:
- (modified) mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td (+12-22)
- (modified) mlir/include/mlir/Dialect/XeGPU/Utils/XeGPUUtils.h (+2)
- (modified) mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp (+3-24)
- (modified) mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp (+39-29)
- (modified) mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp (+5-10)
- (modified) mlir/test/Dialect/XeGPU/ops.mlir (+14-22)
- (modified) mlir/test/Dialect/XeGPU/propagate-layout.mlir (+18-27)
- (modified) mlir/test/Dialect/XeGPU/xegpu-blocking.mlir (+23-23)
- (modified) mlir/test/Dialect/XeGPU/xegpu-unroll-patterns.mlir (+28-28)
``````````diff
diff --git a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td
index e6c7efc47593f..ffc08e9b90b56 100644
--- a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td
+++ b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td
@@ -609,12 +609,8 @@ def XeGPU_LoadGatherOp : XeGPU_Op<"load", [
let description = [{ It (aka. load) load data per each work-item. The output
describes the data being loaded at the subgroup level, so its size is
consistent with the number of work-items in a subgroup. When the chunk size
- is larger than 2, the output vector is a 2D vector, with dim-1 correspoding
- to work-items, and dim-0 corresponding to the chunk size loaded by each work-item.
- Specially, there is a transpose effect on the result (as compared to the TensorDesc)
- due to the hardware implementation. Therefore, a transpose attribute is introduced
- on purpose, making sure users are aware of this implicit transformation.
-
+   is larger than 2, the output vector is a 2D vector, with dim-0 corresponding
+ to work-items, and dim-1 corresponding to the chunk size loaded by each work-item.
The mask operand masks out memory access so that it is safe to pass out-of-boundary
addresses/offsets as long as they are masked. It applies to slots of SIMD lanes.
@@ -634,8 +630,7 @@ def XeGPU_LoadGatherOp : XeGPU_Op<"load", [
Example 2:
```mlir
- %2 = xegpu.load %1, %0 {transpose,
- l1_hint = #xegpu.cache_hint<cached>,
+ %2 = xegpu.load %1, %0 {l1_hint = #xegpu.cache_hint<cached>,
l2_hint = #xegpu.cache_hint<uncached>,
l3_hint = #xegpu.cache_hint<uncached>}
: !xegpu.tensor_desc<16x8xf32, #xegpu.scatter_tdesc_attr<memory_space=global, chunk_size=8>>,
@@ -643,20 +638,18 @@ def XeGPU_LoadGatherOp : XeGPU_Op<"load", [
```
Example 3 (SIMT mode):
```mlir
- %2 = xegpu.load %1, %0 {transpose,
- l1_hint = #xegpu.cache_hint<cached>,
+ %2 = xegpu.load %1, %0 {l1_hint = #xegpu.cache_hint<cached>,
l2_hint = #xegpu.cache_hint<uncached>,
l3_hint = #xegpu.cache_hint<uncached>}
: !xegpu.tensor_desc<16x8xf32, #xegpu.scatter_tdesc_attr<memory_space=global, chunk_size=8>,
!xegpu.layout<lane_layout = [16, 1], lane_data = [1, 1]>>
- vector<16xi1> -> vector<8x1xf32>
+ vector<16xi1> -> vector<8xf32>
```
}];
let arguments = (ins XeGPU_TensorDesc: $TensorDesc,
XeGPU_MaskType: $mask,
- OptionalAttr<UnitAttr>: $transpose,
OptionalAttr<XeGPU_CacheHintAttr>: $l1_hint,
OptionalAttr<XeGPU_CacheHintAttr>: $l2_hint,
OptionalAttr<XeGPU_CacheHintAttr>: $l3_hint);
@@ -714,19 +707,17 @@ def XeGPU_StoreScatterOp : XeGPU_Op<"store", [
Example 2:
```mlir
- xegpu.store %0, %1, %2 {transpose,
- l1_hint = #xegpu.cache_hint<uncached>,
- l2_hint = #xegpu.cache_hint<write_back>,
- l3_hint = #xegpu.cache_hint<write_through>}
+ xegpu.store %0, %1, %2 {l1_hint = #xegpu.cache_hint<uncached>,
+ l2_hint = #xegpu.cache_hint<write_back>,
+ l3_hint = #xegpu.cache_hint<write_through>}
: vector<8x16xf32>, !xegpu.tensor_desc<16x8xf32, #xegpu.scattered_tdesc_attr<chunk_size=8>>, vector<16xi1>
```
Example 3 (SIMT mode):
```mlir
- xegpu.store %0, %1, %2 {transpose,
- l1_hint = #xegpu.cache_hint<uncached>,
- l2_hint = #xegpu.cache_hint<write_back>,
- l3_hint = #xegpu.cache_hint<write_through>}
- : vector<8x1xf32>, !xegpu.tensor_desc<16x8xf32, #xegpu.scattered_tdesc_attr<chunk_size=8>,
+ xegpu.store %0, %1, %2 {l1_hint = #xegpu.cache_hint<uncached>,
+ l2_hint = #xegpu.cache_hint<write_back>,
+ l3_hint = #xegpu.cache_hint<write_through>}
+ : vector<8xf32>, !xegpu.tensor_desc<16x8xf32, #xegpu.scattered_tdesc_attr<chunk_size=8>,
!xegpu.layout<lane_layout = [16, 1], lane_data = [1, 1]>> vector<16xi1>
```
@@ -736,7 +727,6 @@ def XeGPU_StoreScatterOp : XeGPU_Op<"store", [
XeGPU_ValueType: $value,
XeGPU_TensorDesc: $TensorDesc,
XeGPU_MaskType: $mask,
- OptionalAttr<UnitAttr>: $transpose,
OptionalAttr<XeGPU_CacheHintAttr>: $l1_hint,
OptionalAttr<XeGPU_CacheHintAttr>: $l2_hint,
OptionalAttr<XeGPU_CacheHintAttr>: $l3_hint);
diff --git a/mlir/include/mlir/Dialect/XeGPU/Utils/XeGPUUtils.h b/mlir/include/mlir/Dialect/XeGPU/Utils/XeGPUUtils.h
index 772cf73649646..09311e6017d0c 100644
--- a/mlir/include/mlir/Dialect/XeGPU/Utils/XeGPUUtils.h
+++ b/mlir/include/mlir/Dialect/XeGPU/Utils/XeGPUUtils.h
@@ -35,6 +35,8 @@ constexpr unsigned packedSizeInBitsForDefault =
16; // Minimum packing size per register for DPAS A.
constexpr unsigned packedSizeInBitsForDpasB =
32; // Minimum packing size per register for DPAS B.
+constexpr unsigned packedSizeInBitsForGatherScatter =
+ 32; // Minimum packing size per register for Gather and Scatter ops.
} // namespace targetinfo
} // namespace xegpu
diff --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
index 0afc502c026f7..f0fb03d4f1139 100644
--- a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
+++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
@@ -20,13 +20,6 @@
namespace mlir {
namespace xegpu {
-static void transpose(llvm::ArrayRef<int64_t> trans,
- SmallVector<int64_t> &shape) {
- SmallVector<int64_t> old = shape;
- for (size_t i = 0; i < trans.size(); i++)
- shape[i] = old[trans[i]];
-}
-
template <typename T>
static std::string makeString(T array, bool breakline = false) {
std::string buf;
@@ -76,7 +69,7 @@ static bool isWriteHintOrNone(const CachePolicyAttr &attr) {
static LogicalResult
isValidGatherScatterParams(Type maskTy, VectorType valueTy,
- TensorDescType tdescTy, UnitAttr transposeAttr,
+ TensorDescType tdescTy,
function_ref<InFlightDiagnostic()> emitError) {
if (!tdescTy.isScattered())
@@ -102,17 +95,9 @@ isValidGatherScatterParams(Type maskTy, VectorType valueTy,
if (valueTy.getRank() == 1 && valueTy.getNumElements() == chunkSize) {
if (tdescTy.getLayoutAttr())
return emitError() << "TensorDesc doesn't need LayoutAttr for SIMT code";
- if (transposeAttr)
- return emitError() << "doesn't need TransposeAttr for SIMT code";
return success();
}
- if (tdescTy.getRank() == 2 && valueTy.getRank() == 2) {
- if (!transposeAttr)
- return emitError() << "rank-2 tensor has to be transposed.";
- transpose({1, 0}, tdescShape);
- }
-
if (tdescShape != valueShape)
return emitError() << "Value shape " << makeString(valueShape)
<< " is neither a valid distribution for SIMT nor "
@@ -310,13 +295,9 @@ LogicalResult LoadNdOp::verify() {
if (getTranspose()) {
auto trans = getTranspose().value();
-
// Make sure the transpose value is valid.
- bool valid = llvm::all_of(
- trans, [&](int t) { return t >= 0 && t < tdescTy.getRank(); });
-
- if (valid)
- transpose(trans, tdescShape);
+ if (llvm::all_of(trans, [&](size_t s) { return s < tdescShape.size(); }))
+ tdescShape = applyPermutation(tdescShape, trans);
else
mlir::emitWarning(getLoc()) << "Invalid transpose attr. It is ignored.";
}
@@ -536,7 +517,6 @@ LogicalResult LoadGatherOp::verify() {
return emitOpError("invalid l3_hint: ") << getL3HintAttr();
return isValidGatherScatterParams(maskTy, valueTy, tdescTy,
- getTransposeAttr(),
[&]() { return emitOpError(); });
}
@@ -558,7 +538,6 @@ LogicalResult StoreScatterOp::verify() {
return emitOpError("invalid l3_hint: ") << getL3HintAttr();
return isValidGatherScatterParams(maskTy, valueTy, tdescTy,
- getTransposeAttr(),
[&]() { return emitOpError(); });
}
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp
index cc22d2bbd8c39..60ccd823775a5 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp
@@ -213,6 +213,35 @@ static LayoutInfo getDefaultSIMTLayoutInfo(VectorType vectorTy) {
LaneData({1, packingFactor}));
}
+/// Helper to get the default layout for a vector type.
+static LayoutInfo getDefaultSIMTLayoutInfo(xegpu::TensorDescType tdescTy) {
+ // Expecting a 1D or 2D vector.
+ assert((tdescTy.getRank() == 1 || tdescTy.getRank() == 2) &&
+ "Expected 1D or 2D TensorDesc.");
+ // Expecting int or float element type.
+ assert(tdescTy.getElementType().isIntOrFloat() &&
+ "Expected int or float element type.");
+ // If the rank is 1, then return default layout for 1D vector.
+ if (tdescTy.getRank() == 1)
+ return getDefaultSIMTLayoutInfo(1);
+ // Packing factor is determined by the element type bitwidth.
+ unsigned bitwidth = tdescTy.getElementType().getIntOrFloatBitWidth();
+
+ if (tdescTy.isScattered()) {
+ int packingFactor =
+ xegpu::targetinfo::packedSizeInBitsForGatherScatter / bitwidth;
+ return LayoutInfo(LaneLayout({xegpu::targetinfo::subgroupSize, 1}),
+ LaneData({1, packingFactor}));
+ }
+
+ int packingFactor =
+ (bitwidth < xegpu::targetinfo::packedSizeInBitsForDefault)
+ ? xegpu::targetinfo::packedSizeInBitsForDefault / bitwidth
+ : 1;
+ return LayoutInfo(LaneLayout({1, xegpu::targetinfo::subgroupSize}),
+ LaneData({1, packingFactor}));
+}
+
/// Helper Function to get the expected layouts for DPAS operands. `lane_data`
/// is set according to the following criteria:
/// * For A operand, the data must be packed in minimum
@@ -379,8 +408,7 @@ void LayoutInfoPropagation::visitPrefetchNdOp(
// Here we assign the default layout to the tensor descriptor operand of
// prefetch.
auto tdescTy = prefetch.getTensorDescType();
- auto prefetchLayout = getDefaultSIMTLayoutInfo(
- VectorType::get(tdescTy.getShape(), tdescTy.getElementType()));
+ auto prefetchLayout = getDefaultSIMTLayoutInfo(tdescTy);
// Propagate the layout to the source tensor descriptor.
propagateIfChanged(operands[0], operands[0]->meet(prefetchLayout));
}
@@ -516,24 +544,14 @@ void LayoutInfoPropagation::visitVectorBitcastOp(
void LayoutInfoPropagation::visitLoadGatherOp(
xegpu::LoadGatherOp load, ArrayRef<LayoutInfoLattice *> operands,
ArrayRef<const LayoutInfoLattice *> results) {
- LayoutInfo valueLayout = results[0]->getValue();
- // Need the layout of the value to propagate to the tensor descriptor.
- if (!valueLayout.isAssigned())
- return;
+ // The layout is strictly determined by the tensor descriptor type.
+ LayoutInfo layout = getDefaultSIMTLayoutInfo(load.getTensorDescType());
- LayoutInfo tensorDescLayout = valueLayout;
- if (load.getTranspose()) {
- // LoadGatherOp has the transpose effect. However, at the stage of this
- // analyis this effect is not expected and should be abstracted away. Emit
- // a warning.
- load.emitWarning("Transpose effect is not expected for LoadGatherOp at "
- "LayoutInfoPropagation stage.");
- tensorDescLayout = valueLayout.getTransposedLayout({1, 0});
- }
// Mask operand should have 1D default layout.
LayoutInfo maskLayout = getDefaultSIMTLayoutInfo(1);
+
// Propagate the new layout to the tensor descriptor operand.
- propagateIfChanged(operands[0], operands[0]->meet(tensorDescLayout));
+ propagateIfChanged(operands[0], operands[0]->meet(layout));
// Propagate the new layout to the mask operand.
propagateIfChanged(operands[1], operands[1]->meet(maskLayout));
}
@@ -567,21 +585,13 @@ void LayoutInfoPropagation::visitStoreScatterOp(
"Expected the first dimension of 2D tensor descriptor to be equal to "
"subgroup size.");
- LayoutInfo valueLayout =
- getDefaultSIMTLayoutInfo(storeScatter.getValueType());
- LayoutInfo storeScatterLayout = valueLayout;
- if (storeScatter.getTranspose()) {
- // StoreScatteOp allows transpose effect. However, at the stage of this
- // analyis this effect is not expected and should be abstracted away. Emit
- // a warning.
- storeScatter.emitWarning("Transpose effect is not expected for "
- "StoreScatterOp at LayoutInfoPropagation stage.");
- storeScatterLayout = valueLayout.getTransposedLayout({1, 0});
- }
+ LayoutInfo layout =
+ getDefaultSIMTLayoutInfo(storeScatter.getTensorDescType());
+
// Propagate the value layout.
- propagateIfChanged(operands[0], operands[0]->meet(valueLayout));
+ propagateIfChanged(operands[0], operands[0]->meet(layout));
// Propagate the tensor descriptor layout.
- propagateIfChanged(operands[1], operands[1]->meet(storeScatterLayout));
+ propagateIfChanged(operands[1], operands[1]->meet(layout));
// Use default 1D layout for mask operand.
LayoutInfo maskLayout = getDefaultSIMTLayoutInfo(1);
propagateIfChanged(operands[2], operands[2]->meet(maskLayout));
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp
index 0457f8128b908..be39ee1f0b53f 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp
@@ -507,8 +507,6 @@ struct UnrollLoadGatherOp : public UnrollPattern<xegpu::LoadGatherOp> {
for (int64_t i = 0; i < numNewChunks; ++i)
convertedMasks.push_back(mask);
}
- // This is to handle the transpose effect when chunkSize > 1.
- std::swap((*targetShape)[0], (*targetShape)[1]);
newValueTy = valueTy.cloneWith(*targetShape, elemTy);
} else {
convertedMaskTypes = getUnrolledTypes(maskTy, targetMaskShape);
@@ -519,8 +517,8 @@ struct UnrollLoadGatherOp : public UnrollPattern<xegpu::LoadGatherOp> {
SmallVector<Value> newOps;
for (auto [t, m] : llvm::zip(convertedTdescs, convertedMasks)) {
auto newOp = rewriter.create<xegpu::LoadGatherOp>(
- loc, newValueTy, t, m, op.getTransposeAttr(), op.getL1HintAttr(),
- op.getL2HintAttr(), op.getL3HintAttr());
+ loc, newValueTy, t, m, op.getL1HintAttr(), op.getL2HintAttr(),
+ op.getL3HintAttr());
newOps.push_back(newOp);
}
@@ -598,9 +596,6 @@ struct UnrollStoreScatterOp : public UnrollPattern<xegpu::StoreScatterOp> {
convertedMasks.push_back(mask);
}
}
- // This is to handle the transpose effect when chunkSize > 1.
- std::swap((*targetShape)[0], (*targetShape)[1]);
-
} else {
convertedMaskTypes = getUnrolledTypes(maskTy, *targetShape);
convertedMasks =
@@ -616,9 +611,9 @@ struct UnrollStoreScatterOp : public UnrollPattern<xegpu::StoreScatterOp> {
Value v = convertedValues[i];
Value t = convertedTdescs[i];
Value m = op.getMask() ? convertedMasks[i] : nullptr;
- rewriter.create<xegpu::StoreScatterOp>(
- loc, v, t, m, op.getTransposeAttr(), op.getL1HintAttr(),
- op.getL2HintAttr(), op.getL3HintAttr());
+ rewriter.create<xegpu::StoreScatterOp>(loc, v, t, m, op.getL1HintAttr(),
+ op.getL2HintAttr(),
+ op.getL3HintAttr());
}
rewriter.eraseOp(op);
diff --git a/mlir/test/Dialect/XeGPU/ops.mlir b/mlir/test/Dialect/XeGPU/ops.mlir
index 054c4d12fdb28..5ceb548221758 100644
--- a/mlir/test/Dialect/XeGPU/ops.mlir
+++ b/mlir/test/Dialect/XeGPU/ops.mlir
@@ -199,8 +199,8 @@ gpu.func @simt_load_nd_7(%src: memref<24x32xf16>) {
gpu.func @subgroup_load_nd_8(%src: memref<24x32xf32>) {
// CHECK: %[[R0:.*]] = xegpu.create_nd_tdesc %arg0[0, 0] : memref<24x32xf32> -> !xegpu.tensor_desc<16x8xf32>
%1 = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf32> -> !xegpu.tensor_desc<16x8xf32>
- // CHECK: %[[R1:.*]] = xegpu.load_nd %[[R0]] <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>, transpose = array<i64: 1, 0>}> : !xegpu.tensor_desc<16x8xf32> -> vector<8x16xf32>
- %2 = xegpu.load_nd %1 <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>, transpose = array<i64: 1, 0>}> : !xegpu.tensor_desc<16x8xf32> -> vector<8x16xf32>
+ // CHECK: %[[R1:.*]] = xegpu.load_nd %[[R0]] <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}> : !xegpu.tensor_desc<16x8xf32> -> vector<16x8xf32>
+ %2 = xegpu.load_nd %1 <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}> : !xegpu.tensor_desc<16x8xf32> -> vector<16x8xf32>
gpu.return
}
@@ -235,8 +235,6 @@ gpu.func @simt_store_nd(%src: memref<24x32xf16>) {
gpu.return
}
-
-
// CHECK: func @subgroup_store_nd_2(%[[arg0:.*]]: memref<24x32xf16>) {
gpu.func @subgroup_store_nd_2(%dst: memref<24x32xf16>) {
// CHECK: %[[C:.*]] = arith.constant dense<1.000000e+00> : vector<32xf16>
@@ -248,7 +246,6 @@ gpu.func @subgroup_store_nd_2(%dst: memref<24x32xf16>) {
gpu.return
}
-
// CHECK: func @simt_store_nd_2(%[[arg0:.*]]: memref<24x32xf16>) {
gpu.func @simt_store_nd_2(%src: memref<24x32xf16>) {
// CHECK: %[[C:.*]] = arith.constant dense<1.000000e+00> : vector<2xf16>
@@ -318,8 +315,8 @@ gpu.func @subgroup_load(%src: ui64) {
%1 = arith.constant dense<1>: vector<4xi1>
//CHECK: %[[R0:.*]] = xegpu.create_tdesc %[[arg0]], %[[cst]] : ui64, vector<4xindex> -> !xegpu.tensor_desc<4x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2 : i64>>
%2 = xegpu.create_tdesc %src, %0 : ui64, vector<4xindex> -> !xegpu.tensor_desc<4x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2>>
- //CHECK: %[[R1:.*]] = xegpu.load %[[R0]], %[[cst1]] <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>, transpose}> : !xegpu.tensor_desc<4x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2 : i64>>, vector<4xi1> -> vector<2x4xf32>
- %3 = xegpu.load %2, %1 <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>, transpose}> : !xegpu.tensor_desc<4x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2>>, vector<4xi1> -> vector<2x4xf32>
+ //CHECK: %[[R1:.*]] = xegpu.load %[[R0]], %[[cst1]] <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}> : !xegpu.tensor_desc<4x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2 : i64>>, vector<4xi1> -> vector<4x2xf32>
+ %3 = xegpu.load %2, %1 <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}> : !xegpu.tensor_desc<4x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2>>, vector<4xi1> -> vector<4x2xf32>
gpu.return
}
@@ -370,8 +367,8 @@ gpu.func @subgroup_load_3(%src: ui64) {
%1 = arith.constant dense<1>: vector<4xi1>
//CHECK: %[[R0:.*]] = xegpu.create_tdesc %[[arg0]], %[[cst]] : ui64, vector<4xindex> -> !xegpu.tensor_desc<4x8xf16, #xegpu.scatter_tdesc_attr<chunk_size = 8 : i64>>
%2 = xegpu.create_tdesc %src, %0 : ui64, vector<4xindex> -> !xegpu.tensor_desc<4x8xf16, #xegpu.scatter_tdesc_attr<chunk_size = 8>>
- //CHECK: %[[R1:.*]] = xegpu.load %[[R0]], %[[cst1]] <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>, transpose}> : !xegpu.tensor_desc<4x8xf16, #xegpu.scatter_tdesc_attr<chunk_size = 8 : i64>>, vector<4xi1> -> vector<8x4xf16>
- %3 = xegpu.load %2, %1 <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>, transpose}> : !xegpu.tensor_desc<4x8xf16, #xegpu.scatter_tdesc_attr<chunk_size = 8>>, vector<4xi1> -> vector<8x4xf16>
+ //CHECK: %[[R1:.*]] = xegpu.load %[[R0]], %[[cst1]] <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}> : !xegpu.tensor_desc<4x8xf16, #xegpu.scatter_tdesc_attr<chunk_size = 8 : i64>>, vector<4xi1> -> vector<4x8xf16>
+ %3 = xegpu.load %2, %1 <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}> : !xegpu.te...
[truncated]
``````````
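For readers skimming the truncated diff: the new `getDefaultSIMTLayoutInfo(TensorDescType)` overload derives the layout directly from the descriptor. For scattered descriptors the packing factor is `packedSizeInBitsForGatherScatter / bitwidth` with a `[subgroupSize, 1]` lane layout. Below is a minimal sketch of what that yields for a hypothetical f16 scattered descriptor (assumptions: `subgroupSize = 16`, the `!tdesc_f16` alias name and the `16x8` / `chunk_size = 8` shape are made up for illustration):

```mlir
// Hypothetical f16 example, not part of the patch.
// bitwidth = 16, so packingFactor = 32 / 16 = 2; scattered descriptors get
// lane_layout = [subgroupSize, 1] = [16, 1] and lane_data = [1, packingFactor] = [1, 2].
!tdesc_f16 = !xegpu.tensor_desc<16x8xf16, #xegpu.scatter_tdesc_attr<chunk_size = 8>,
  #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 2]>>
```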
https://github.com/llvm/llvm-project/pull/145389