[Mlir-commits] [mlir] 2e713af - [MLIR][XeGPU] refine verifier for TensorDescType (#137226)
llvmlistbot at llvm.org
Tue Apr 29 06:57:36 PDT 2025
Author: Chao Chen
Date: 2025-04-29T08:57:33-05:00
New Revision: 2e713af20e58a1ca005fcae9165fda3007c0e400
URL: https://github.com/llvm/llvm-project/commit/2e713af20e58a1ca005fcae9165fda3007c0e400
DIFF: https://github.com/llvm/llvm-project/commit/2e713af20e58a1ca005fcae9165fda3007c0e400.diff
LOG: [MLIR][XeGPU] refine verifier for TensorDescType (#137226)
This PR updates the verifier of TensorDescType to account for the extension of
LayoutAttr introduced in #132425.
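
For illustration, the refined check reduces to elementwise divisibility at each
layout level, which is why the updated tests below accept a 128x128 descriptor
with sg_layout = [4, 2] and sg_data = [32, 64] but reject sg_data = [24, 48].
The following is a minimal standalone sketch of that idea (the shapeRatio
helper and the driver are hypothetical; the dialect itself uses
mlir::computeShapeRatio and additionally handles a round-robin case):

// Standalone sketch of the per-level divisibility check behind
// XeGPUDialect::isEvenlyDistributable (simplified).
#include <cstdint>
#include <iostream>
#include <optional>
#include <vector>

using Shape = std::vector<int64_t>;

// Returns shape[i] / divisor[i] if every dimension divides evenly.
static std::optional<Shape> shapeRatio(const Shape &shape,
                                       const Shape &divisor) {
  if (shape.size() != divisor.size())
    return std::nullopt;
  Shape ratio;
  for (size_t i = 0; i < shape.size(); ++i) {
    if (divisor[i] == 0 || shape[i] % divisor[i] != 0)
      return std::nullopt;
    ratio.push_back(shape[i] / divisor[i]);
  }
  return ratio;
}

int main() {
  Shape tensor = {128, 128};
  // Each subgroup covers tensor[i] / sg_layout[i] = [32, 64] elements.
  Shape perSg = *shapeRatio(tensor, /*sg_layout=*/{4, 2});

  // sg_data = [32, 64] tiles the per-subgroup shape exactly -> accepted.
  std::cout << "sg_data [32, 64]: "
            << (shapeRatio(perSg, {32, 64}) ? "ok" : "reject") << "\n";
  // sg_data = [24, 48] neither divides [32, 64] nor is divided by it, so the
  // verifier now reports "cannot distribute [128, 128] using #xegpu.layout<...>".
  std::cout << "sg_data [24, 48]: "
            << (shapeRatio(perSg, {24, 48}) ? "ok" : "reject") << "\n";
}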
Added:
Modified:
mlir/include/mlir/Dialect/XeGPU/IR/XeGPU.h
mlir/include/mlir/Dialect/XeGPU/IR/XeGPUDialect.td
mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
mlir/test/Dialect/XeGPU/invalid.mlir
mlir/test/Dialect/XeGPU/ops.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPU.h b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPU.h
index d6c51d20571fd..8e2784f40ad39 100644
--- a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPU.h
+++ b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPU.h
@@ -25,12 +25,14 @@ class TensorDescType;
} // namespace xegpu
} // namespace mlir
-#include <mlir/Dialect/XeGPU/IR/XeGPUDialect.h.inc>
#include <mlir/Dialect/XeGPU/IR/XeGPUEnums.h.inc>
#define GET_ATTRDEF_CLASSES
#include <mlir/Dialect/XeGPU/IR/XeGPUAttrs.h.inc>
#define GET_TYPEDEF_CLASSES
#include <mlir/Dialect/XeGPU/IR/XeGPUTypes.h.inc>
+
+#include <mlir/Dialect/XeGPU/IR/XeGPUDialect.h.inc>
+
#define GET_OP_CLASSES
#include <mlir/Dialect/XeGPU/IR/XeGPU.h.inc>
diff --git a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUDialect.td b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUDialect.td
index fb5a1e6f1db0c..549018b61d6fb 100644
--- a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUDialect.td
+++ b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUDialect.td
@@ -36,6 +36,12 @@ def XeGPU_Dialect : Dialect {
let useDefaultTypePrinterParser = true;
let useDefaultAttributePrinterParser = true;
+
+ let extraClassDeclaration = [{
+ /// Checks if the given shape can be evenly distributed based on the layout
+ /// and data factors provided by the LayoutAttr.
+ static bool isEvenlyDistributable(llvm::ArrayRef<int64_t> shape, xegpu::LayoutAttr attr);
+ }];
}
#endif // MLIR_DIALECT_XEGPU_IR_XEGPUDIALECT_TD
diff --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
index b865b80f0075e..b2d217d192934 100644
--- a/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
+++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
@@ -6,12 +6,15 @@
//
//===----------------------------------------------------------------------===//
+#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/Dialect/XeGPU/IR/XeGPU.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/DialectImplementation.h"
#include "llvm/ADT/TypeSwitch.h"
#include <numeric>
+using std::optional;
+
namespace mlir {
namespace xegpu {
@@ -30,6 +33,71 @@ void XeGPUDialect::initialize() {
>();
}
+// Checks if the given shape can be evenly distributed based on the layout
+// and data factors provided by the LayoutAttr.
+bool XeGPUDialect::isEvenlyDistributable(llvm::ArrayRef<int64_t> shape,
+ xegpu::LayoutAttr attr) {
+ assert(attr && "Layout attribute is missing.");
+
+ // Checks whether the given shape can be evenly distributed using the
+ // specified layout and data attributes. If successful, it returns the work
+ // size for each compute unit; otherwise, it returns `std::nullopt`. The work
+ // size per compute unit is calculated as follows:
+ // - If `data` is null: newShape[i] = shape[i] / layout[i]
+ // - If `data` is not null: newShape[i] = data[i]
+ // When round-robin distribution (`rr`) is enabled, `shape[i]` can be
+ // smaller than `layout[i] * data[i]`, allowing multiple compute units to
+ // share the data.
+ auto tryDistribute = [&](llvm::ArrayRef<int64_t> shape,
+ DenseI32ArrayAttr layout, DenseI32ArrayAttr data,
+ bool rr = true) -> optional<SmallVector<int64_t>> {
+ llvm::SmallVector<int64_t> newShape(shape);
+ if (layout) {
+ auto vec = llvm::to_vector_of<int64_t>(layout.asArrayRef());
+ if (vec.size() != shape.size())
+ return std::nullopt;
+ auto ratio = computeShapeRatio(shape, vec);
+ if (!ratio.has_value())
+ return std::nullopt;
+ newShape = ratio.value();
+ }
+
+ if (data) {
+ auto vec = llvm::to_vector_of<int64_t>(data.asArrayRef());
+ if (vec.size() != shape.size())
+ return std::nullopt;
+ auto ratio = computeShapeRatio(newShape, vec);
+ if (!ratio.has_value() && rr)
+ ratio = computeShapeRatio(vec, newShape);
+ if (!ratio.has_value())
+ return std::nullopt;
+
+ // If data is not null, we always return it for the next phase.
+ newShape = vec;
+ }
+ return newShape;
+ };
+
+ // check the sgLayout and sgData
+ auto maybeSgShape =
+ tryDistribute(shape, attr.getSgLayout(), attr.getSgData());
+ if (!maybeSgShape)
+ return false;
+ auto sgShape = maybeSgShape.value();
+
+ // Check InstData; it has no layout and does not use round-robin distribution.
+ auto maybeInstShape =
+ tryDistribute(sgShape, nullptr, attr.getInstData(), false);
+ if (!maybeInstShape)
+ return false;
+ auto instShape = maybeInstShape.value();
+
+ // check LaneLayout and LaneData
+ auto maybeLaneShape =
+ tryDistribute(instShape, attr.getLaneLayout(), attr.getLaneData(), false);
+ return maybeLaneShape.has_value();
+}
+
//===----------------------------------------------------------------------===//
// XeGPU_BlockTensorDescAttr
//===----------------------------------------------------------------------===//
@@ -241,7 +309,7 @@ LogicalResult TensorDescType::verify(
llvm::ArrayRef<int64_t> shape, mlir::Type elementType,
mlir::Attribute encoding, mlir::Attribute layout) {
size_t rank = shape.size();
- // Low-pressure types are packed in 32-bit units.
+ // Low-precision types are packed in 32-bit units.
int32_t packingFactor = 32 / elementType.getIntOrFloatBitWidth();
if (rank != 1 && rank != 2)
return emitError() << "expected 1D or 2D tensor";
@@ -268,23 +336,21 @@ LogicalResult TensorDescType::verify(
}
}
- if (auto blockAttr =
- mlir::dyn_cast_if_present<BlockTensorDescAttr>(encoding)) {
+ auto blockAttr = mlir::dyn_cast_if_present<BlockTensorDescAttr>(encoding);
+ if (blockAttr) {
MemorySpaceAttr memorySpaceAttr = blockAttr.getMemorySpace();
if (rank == 2 && memorySpaceAttr &&
memorySpaceAttr.getValue() == MemorySpace::SLM)
return emitError() << "SLM is not supported for 2D block tensor";
}
- if (auto layoutAttr = llvm::dyn_cast_if_present<LayoutAttr>(layout)) {
-
+ auto layoutAttr = llvm::dyn_cast_if_present<LayoutAttr>(layout);
+ if (layoutAttr) {
if (rank != (size_t)layoutAttr.getRank())
return emitError() << "expected layout rank to match tensor rank";
- ArrayRef<int32_t> laneLayout = layoutAttr.getLaneLayout().asArrayRef();
- ArrayRef<int32_t> laneData = layoutAttr.getLaneData().asArrayRef();
-
- if (scatterAttr) {
+ auto laneData = layoutAttr.getLaneData();
+ if (scatterAttr && laneData) {
// Validate subgroup mapping rules for scattered tensors.
// A work-item's slice of the tensor with shape [sg_size] or
// [sg_size, chunk_size] will be [1] or [1, 32/element_ty_bit_width]
@@ -294,20 +360,19 @@ LogicalResult TensorDescType::verify(
if (rank > 1 && laneData[0] != 1)
return emitError()
<< "cannot map over non-contiguous scattered row elements";
- if (laneData.back() != packingFactor)
+ if (laneData[rank - 1] != packingFactor)
return emitError() << "work item data mapping must match the number of "
"contiguous elements";
}
- for (size_t i = 0; i < shape.size(); ++i) {
- uint32_t numElemPerWi = laneLayout[i] * laneData[i];
- if (shape[i] < numElemPerWi || shape[i] % numElemPerWi != 0)
- return emitError() << "cannot distribute " << shape[i] << " over "
- << laneLayout[i] << " work items with "
- << laneData[i] << " elements each";
+ if (!XeGPUDialect::isEvenlyDistributable(shape, layoutAttr)) {
+ std::string shapeStr;
+ llvm::raw_string_ostream stream(shapeStr);
+ llvm::interleaveComma(shape, stream);
+ return emitError() << "cannot distribute [" << shapeStr << "] using "
+ << layoutAttr;
}
}
-
return success();
}
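
As a hand-check of the three phases implemented above (subgroup, then
instruction, then lane), the valid layout from the ops.mlir test below
(sg_layout = [4, 2], sg_data = [32, 64], inst_data = [8, 16],
lane_layout = [1, 16], lane_data = [1, 1] on a 128x128 descriptor) can be
traced as follows. This is a hypothetical standalone walkthrough; the
divide/print helpers and the hard-coded factors are illustrative rather than
dialect code:

#include <cstdint>
#include <iostream>
#include <vector>

using Shape = std::vector<int64_t>;

// Elementwise division; even divisibility is guaranteed by construction here.
static Shape divide(const Shape &a, const Shape &b) {
  Shape r;
  for (size_t i = 0; i < a.size(); ++i)
    r.push_back(a[i] / b[i]);
  return r;
}

static void print(const char *label, const Shape &s) {
  std::cout << label << ": [" << s[0] << ", " << s[1] << "]\n";
}

int main() {
  Shape shape = {128, 128};

  // Phase 1 (subgroup): shape / sg_layout = [32, 64], which sg_data matches
  // exactly, so the per-subgroup work shape becomes sg_data = [32, 64].
  Shape sgShape = divide(shape, /*sg_layout=*/{4, 2});
  print("per-subgroup shape", sgShape);

  // Phase 2 (instruction): sg_data / inst_data = [4, 4] instructions; the
  // shape carried into the next phase is inst_data itself.
  Shape instShape = {8, 16};
  print("per-instruction shape", instShape);

  // Phase 3 (lane): inst_data / lane_layout = [8, 1], which lane_data = [1, 1]
  // divides evenly, so the whole descriptor verifies.
  Shape laneShape = divide(instShape, /*lane_layout=*/{1, 16});
  print("per-lane shape", laneShape);
}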
diff --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
index e0e25365220b5..f9d7e013826ed 100644
--- a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
+++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
@@ -73,34 +73,6 @@ static bool isWriteHintOrNone(const CachePolicyAttr &attr) {
kind == CachePolicy::WRITE_BACK || kind == CachePolicy::WRITE_THROUGH;
}
-// Checks if the given shape is evenly distributed based on the layout
-// and data factors provided by the LayoutAttr. The function ensures that
-// each dimension of the shape can be evenly divided by the corresponding
-// data factor, and the resulting quotient can be evenly divided by the
-// layout factor. Returns `true` if the shape is evenly distributed,
-// otherwise `false`.
-static bool isEvenDistributed(llvm::ArrayRef<int64_t> shape,
- xegpu::LayoutAttr attr) {
- assert(attr && "Layout attribute is missing.");
- llvm::SmallVector<int32_t> defaults(shape.size(), 1);
- llvm::ArrayRef<int32_t> layout, data;
- if (auto sg_layout = attr.getSgLayout()) {
- layout = sg_layout.asArrayRef();
- auto sg_data = attr.getSgData();
- data = sg_data ? sg_data.asArrayRef() : defaults;
- } else {
- layout = attr.getLaneLayout().asArrayRef();
- auto lane_data = attr.getLaneData();
- data = lane_data ? lane_data.asArrayRef() : defaults;
- }
- for (auto [dimSize, dataFactor, layoutFactor] :
- llvm::zip_equal(shape, data, layout)) {
- if (dimSize % dataFactor != 0 || (dimSize / dataFactor) % layoutFactor != 0)
- return false;
- }
- return true;
-}
-
static LogicalResult
isValidGatherScatterParams(Type maskTy, VectorType valueTy,
TensorDescType tdescTy, UnitAttr transposeAttr,
@@ -685,10 +657,10 @@ LogicalResult ConvertLayoutOp::verify() {
"expected srcMap and resMap be WgLayout or SgLayout at the same time.");
auto shape = getSource().getType().getShape();
- if (!isEvenDistributed(shape, srcMap))
+ if (!XeGPUDialect::isEvenlyDistributable(shape, srcMap))
return emitOpError("invalid srcMap, data cannot be evenly distributed.");
- if (!isEvenDistributed(shape, resMap))
+ if (!XeGPUDialect::isEvenlyDistributable(shape, resMap))
return emitOpError("invalid resMap, data cannot be evenly distributed.");
return mlir::success();
diff --git a/mlir/test/Dialect/XeGPU/invalid.mlir b/mlir/test/Dialect/XeGPU/invalid.mlir
index 67ed89e11b4c9..b05c317231ad9 100644
--- a/mlir/test/Dialect/XeGPU/invalid.mlir
+++ b/mlir/test/Dialect/XeGPU/invalid.mlir
@@ -29,6 +29,27 @@ func.func @test_create_nd_tdesc_vc_4(%src: memref<2x24x32xf32, 3>) {
return
}
+// -----
+func.func @test_create_nd_tdesc_subgroup_1(%src: memref<128x128xf32>) {
+ // expected-error@+1 {{cannot distribute [128, 128] using #xegpu.layout<sg_layout = [4, 2], sg_data = [24, 48]>}}
+ %1 = xegpu.create_nd_tdesc %src[0, 0] : memref<128x128xf32> -> !xegpu.tensor_desc<128x128xf32, #xegpu.layout<sg_layout = [4, 2], sg_data = [24, 48]>>
+ return
+}
+
+// -----
+func.func @test_create_nd_tdesc_subgroup_1(%src: memref<128x128xf32>) {
+ // expected-error@+1 {{cannot distribute [128, 128] using #xegpu.layout<sg_layout = [4, 2], sg_data = [32, 64], inst_data = [24, 48]>}}
+ %1 = xegpu.create_nd_tdesc %src[0, 0] : memref<128x128xf32> -> !xegpu.tensor_desc<128x128xf32, #xegpu.layout<sg_layout = [4, 2], sg_data = [32, 64], inst_data = [24, 48]>>
+ return
+}
+
+// -----
+func.func @test_create_nd_tdesc_subgroup_1(%src: memref<128x128xf32>) {
+ // expected-error@+1 {{cannot distribute [128, 128] using #xegpu.layout<sg_layout = [4, 2], sg_data = [32, 64], inst_data = [64, 32]>}}
+ %1 = xegpu.create_nd_tdesc %src[0, 0] : memref<128x128xf32> -> !xegpu.tensor_desc<128x128xf32, #xegpu.layout<sg_layout = [4, 2], sg_data = [32, 64], inst_data = [64, 32]>>
+ return
+}
+
// -----
func.func @test_prefetch_nd_vc_1(%src: memref<24x32xf16>) {
%1 = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf16> -> !xegpu.tensor_desc<8x16xf16>
@@ -77,6 +98,17 @@ func.func @test_load_nd_vc_3(%src: memref<8x16xf16>) {
return
}
+// -----
+func.func @test_load_nd_vc_4(%src: memref<24x32xf32>) {
+ %1 = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf32> ->
+ !xegpu.tensor_desc<8x16xf32>
+ // expected-error@+1 {{Result shape [8, 1] is not consistent with tensor descriptor}}
+ %2 = xegpu.load_nd %1 <{l1_hint = #xegpu.cache_hint<cached>,
+ l2_hint = #xegpu.cache_hint<uncached>}>
+ : !xegpu.tensor_desc<8x16xf32> -> vector<8x1xf32>
+ return
+}
+
// -----
func.func @test_load_nd_layout(%src: memref<24x32xf32>) {
%1 = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf32> -> !xegpu.tensor_desc<16xf32>
@@ -87,13 +119,10 @@ func.func @test_load_nd_layout(%src: memref<24x32xf32>) {
}
// -----
-func.func @test_load_nd_vc_6(%src: memref<24x32xf32>) {
- %1 = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf32> ->
- !xegpu.tensor_desc<8x16xf32>
- // expected-error@+1 {{Result shape [8, 1] is not consistent with tensor descriptor}}
- %2 = xegpu.load_nd %1 <{l1_hint = #xegpu.cache_hint<cached>,
- l2_hint = #xegpu.cache_hint<uncached>}>
- : !xegpu.tensor_desc<8x16xf32> -> vector<8x1xf32>
+func.func @test_load_nd_simt(%src: memref<24x32xf32>) {
+ %1 = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf32> -> !xegpu.tensor_desc<8x16xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
+ // expected-error@+1 {{TensorDesc doesn't need LayoutAttr for SIMT code}}
+ %2 = xegpu.load_nd %1 : !xegpu.tensor_desc<8x16xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>> -> vector<8xf32>
return
}
@@ -135,6 +164,14 @@ func.func @test_store_nd_simt(%dst: memref<24x32xf32>, %data: vector<3xf32>) {
return
}
+// -----
+func.func @test_store_nd_simt(%src: memref<24x32xf32>, %data: vector<8xf32>) {
+ %1 = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf32> -> !xegpu.tensor_desc<8x16xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
+ // expected-error@+1 {{TensorDesc doesn't need LayoutAttr for SIMT code}}
+ xegpu.store_nd %data, %1 : vector<8xf32>, !xegpu.tensor_desc<8x16xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
+ return
+}
+
// -----
func.func @test_store_nd_vc_5(%dst: memref<24x32xf32>, %data: vector<8x1xf32>) {
%1 = xegpu.create_nd_tdesc %dst[0, 0] : memref<24x32xf32> ->
@@ -404,7 +441,7 @@ func.func @tensor_desc_1D_invalid_map_data(%src: memref<24x32xf32>) {
// -----
func.func @tensor_desc_invalid_map_layout(%src: memref<24x32xf32>) {
%0 = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf32> ->
- // expected-error@+1 {{cannot distribute 8 over 16 work items with 1 elements each}}
+ // expected-error@+1 {{cannot distribute [4, 8] using #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>}}
!xegpu.tensor_desc<4x8xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
return
}
@@ -412,7 +449,7 @@ func.func @tensor_desc_invalid_map_layout(%src: memref<24x32xf32>) {
// -----
func.func @tensor_desc_invalid_map_layout_1(%src: memref<24x32xf32>) {
%0 = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf32> ->
- // expected-error@+1 {{cannot distribute 4 over 8 work items with 1 elements each}}
+ // expected-error@+1 {{cannot distribute [4, 8] using #xegpu.layout<lane_layout = [8, 2], lane_data = [1, 1]>}}
!xegpu.tensor_desc<4x8xf32, #xegpu.layout<lane_layout = [8, 2], lane_data = [1, 1]>>
return
}
@@ -420,7 +457,7 @@ func.func @tensor_desc_invalid_map_layout_1(%src: memref<24x32xf32>) {
// -----
func.func @tensor_desc_invalid_map_data(%src: memref<24x32xf32>) {
%0 = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf32> ->
- // expected-error@+1 {{cannot distribute 4 over 2 work items with 4 elements each}}
+ // expected-error@+1 {{cannot distribute [4, 8] using #xegpu.layout<lane_layout = [2, 8], lane_data = [4, 1]>}}
!xegpu.tensor_desc<4x8xf32, #xegpu.layout<lane_layout = [2, 8], lane_data = [4, 1]>>
return
}
@@ -428,7 +465,7 @@ func.func @tensor_desc_invalid_map_data(%src: memref<24x32xf32>) {
// -----
func.func @tensor_desc_invalid_map_data_1(%src: memref<24x32xf32>) {
%0 = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf32> ->
- // expected-error@+1 {{cannot distribute 4 over 8 work items with 1 elements each}}
+ // expected-error@+1 {{cannot distribute [4, 8] using #xegpu.layout<lane_layout = [8, 2], lane_data = [1, 2]>}}
!xegpu.tensor_desc<4x8xf32, #xegpu.layout<lane_layout = [8, 2], lane_data = [1, 2]>>
return
}
diff --git a/mlir/test/Dialect/XeGPU/ops.mlir b/mlir/test/Dialect/XeGPU/ops.mlir
index 71e7e9bdda07d..76af59d6aedc7 100644
--- a/mlir/test/Dialect/XeGPU/ops.mlir
+++ b/mlir/test/Dialect/XeGPU/ops.mlir
@@ -95,6 +95,27 @@ gpu.func @test_create_nd_tdesc_simt_6(%src: memref<24x32xf32>) {
gpu.return
}
+// CHECK: gpu.func @test_create_nd_tdesc_subgroup_1(%[[arg0:.*]]: memref<128x128xf32>) {
+gpu.func @test_create_nd_tdesc_subgroup_1(%src: memref<128x128xf32>) {
+ // CHECK: %[[REG:.*]] = xegpu.create_nd_tdesc %[[arg0]][0, 0] : memref<128x128xf32> -> !xegpu.tensor_desc<128x128xf32, #xegpu.layout<sg_layout = [4, 2], sg_data = [32, 64]>>
+ %1 = xegpu.create_nd_tdesc %src[0, 0] : memref<128x128xf32> -> !xegpu.tensor_desc<128x128xf32, #xegpu.layout<sg_layout = [4, 2], sg_data = [32, 64]>>
+ gpu.return
+}
+
+// CHECK: gpu.func @test_create_nd_tdesc_subgroup_2(%[[arg0:.*]]: memref<128x128xf32>) {
+gpu.func @test_create_nd_tdesc_subgroup_2(%src: memref<128x128xf32>) {
+ // CHECK: %[[REG:.*]] = xegpu.create_nd_tdesc %[[arg0]][0, 0] : memref<128x128xf32> -> !xegpu.tensor_desc<128x128xf32, #xegpu.layout<sg_layout = [4, 2], sg_data = [32, 64], inst_data = [8, 16]>>
+ %1 = xegpu.create_nd_tdesc %src[0, 0] : memref<128x128xf32> -> !xegpu.tensor_desc<128x128xf32, #xegpu.layout<sg_layout = [4, 2], sg_data = [32, 64], inst_data = [8, 16]>>
+ gpu.return
+}
+
+// CHECK: gpu.func @test_create_nd_tdesc_subgroup_3(%[[arg0:.*]]: memref<128x128xf32>) {
+gpu.func @test_create_nd_tdesc_subgroup_3(%src: memref<128x128xf32>) {
+ // CHECK: %[[REG:.*]] = xegpu.create_nd_tdesc %[[arg0]][0, 0] : memref<128x128xf32> -> !xegpu.tensor_desc<128x128xf32, #xegpu.layout<sg_layout = [4, 2], sg_data = [32, 64], inst_data = [8, 16], lane_layout = [1, 16], lane_data = [1, 1]>>
+ %1 = xegpu.create_nd_tdesc %src[0, 0] : memref<128x128xf32> -> !xegpu.tensor_desc<128x128xf32, #xegpu.layout<sg_layout = [4, 2], sg_data = [32, 64], inst_data = [8, 16], lane_layout = [1, 16], lane_data = [1, 1]>>
+ gpu.return
+}
+
// CHECK: gpu.func @test_prefetch_nd_vc(%[[arg0:.*]]: memref<24x32xf16>) {
gpu.func @test_prefetch_nd_vc(%src: memref<24x32xf16>) {
// CHECK: %[[R0:.*]] = xegpu.create_nd_tdesc %[[arg0]][0, 0] : memref<24x32xf16> -> !xegpu.tensor_desc<8x16xf16>
@@ -127,8 +148,8 @@ gpu.func @test_load_nd_vc(%src: memref<8x16xf16>) {
gpu.func @test_load_nd_simt(%src: memref<8x16xf16>) {
// CHECK: %[[R0:.*]] = xegpu.create_nd_tdesc %arg0[0, 0] : memref<8x16xf16> -> !xegpu.tensor_desc<8x16xf16>
%1 = xegpu.create_nd_tdesc %src[0, 0] : memref<8x16xf16> -> !xegpu.tensor_desc<8x16xf16>
- // CHECK: %[[R1:.*]] = xegpu.load_nd %[[R0]] <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>, packed}> : !xegpu.tensor_desc<8x16xf16> -> vector<8xf16>
- %2 = xegpu.load_nd %1 <{packed, l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}>
+ // CHECK: %[[R1:.*]] = xegpu.load_nd %[[R0]] <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}> : !xegpu.tensor_desc<8x16xf16> -> vector<8xf16>
+ %2 = xegpu.load_nd %1 <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}>
: !xegpu.tensor_desc<8x16xf16> -> vector<8xf16>
gpu.return
}