[Mlir-commits] [mlir] [mlir][xegpu] Improve XeGPU op verification logic for SIMT flavor and update tests. (PR #127920)
Charitha Saumya
llvmlistbot at llvm.org
Thu Feb 20 19:32:46 PST 2025
https://github.com/charithaintc updated https://github.com/llvm/llvm-project/pull/127920
From f92c151bf530c3b28574f63a389d1eea74c0cf4c Mon Sep 17 00:00:00 2001
From: Charitha Saumya <charitha.saumya.gusthinna.waduge at intel.com>
Date: Mon, 17 Feb 2025 22:46:26 +0000
Subject: [PATCH 01/11] save work
---
.../mlir/Dialect/XeGPU/IR/XeGPUTypes.td | 7 +-
mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp | 53 ++++++++
mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp | 118 +++++++++---------
mlir/test/Dialect/XeGPU/XeGPUOps.mlir | 84 ++++++++++---
4 files changed, 185 insertions(+), 77 deletions(-)
diff --git a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td
index cc2e93fb19a70..ccd91a928e1dd 100644
--- a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td
+++ b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td
@@ -103,7 +103,7 @@ def XeGPU_TensorDesc: XeGPUTypeDef<"TensorDesc", "tensor_desc",
CArg<"xegpu::MemorySpace", "xegpu::MemorySpace::Global">:$memory_space,
CArg<"mlir::Attribute", "mlir::Attribute()">:$sg_map)>
];
-
+
let extraClassDeclaration = [{
using TensorType::clone;
using mlir::ShapedType::Trait<TensorDescType>::getElementTypeBitWidth;
@@ -176,6 +176,11 @@ def XeGPU_TensorDesc: XeGPUTypeDef<"TensorDesc", "tensor_desc",
return scatter_attr.getChunkSize().getInt();
return 1;
}
+
+ // Returns a vector type representing the fragment of data owned by a
+ // work item in SIMT mode when this tensor descriptor is used in an XeGPU
+ // load/store operation.
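+ //
+ // For example (illustrative, per the tests in this PR): a
+ // tensor_desc<8x16xf32> with #xegpu.sg_map<wi_layout = [1, 16],
+ // wi_data = [1, 1]> distributes to vector<8x1xf32> per work item.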
+ FailureOr<VectorType> getDistributedVectorType();
}];
let hasCustomAssemblyFormat = true;
diff --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
index 06fd03f3af3ad..768a3b4ef33c5 100644
--- a/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
+++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
@@ -8,8 +8,11 @@
#include "mlir/Dialect/XeGPU/IR/XeGPU.h"
#include "mlir/IR/Builders.h"
+#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/DialectImplementation.h"
#include "llvm/ADT/TypeSwitch.h"
+#include "llvm/Support/Casting.h"
+#include "llvm/Support/LogicalResult.h"
namespace mlir {
namespace xegpu {
@@ -307,6 +310,56 @@ LogicalResult TensorDescType::verify(
return success();
}
+// If tensor descriptor has a sg_map attribute it is used in SIMT mode. In this
+// mode, the distributed vector shape is given by the following criteria:
+// wi_data_size = wi_data[0] × wi_data[1]
+// subgroup_size = wi_layout[0] × wi_layout[1]
+// distribution_unit_size = subgroup_size × wi_data_size
+// tensor_size = tensor_desc[0] * .. * tensor_desc[r-1] * array_length
+// n_distribution_units = tensor_size / distribution_unit_size
+// Given above definitions, the following conditions must be met:
+// * tensor_desc[0] % (wi_layout[0] × wi_data[0]) == 0
+// * tensor_desc[1] % (wi_layout[1] × wi_data[1]) == 0
+// Distributed vector shape must be: n_distribution_units × wi_data_size
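+//
+// Worked example (illustrative): for tensor_desc<8x16xf32> with
+// sg_map<wi_layout = [1, 16], wi_data = [1, 1]>:
+//   wi_data_size = 1, subgroup_size = 16, distribution_unit_size = 16,
+//   tensor_size = 8 * 16 = 128, n_distribution_units = 128 / 16 = 8,
+// so the distributed vector type is vector<8x1xf32>.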
+FailureOr<VectorType> TensorDescType::getDistributedVectorType() {
+ auto sgMap = llvm::dyn_cast_if_present<SGMapAttr>(getSgMap());
+ // If no sg_map is provided, tensor desc is not used in SIMT mode.
+ if (!sgMap)
+ return failure();
+ // FIXME: Add support for scatter tensor descriptor.
+ auto scatterAttr =
+ llvm::dyn_cast_if_present<ScatterTensorDescAttr>(getEncoding());
+ if (scatterAttr)
+ return failure();
+
+ SmallVector<int64_t> wiData(sgMap.getWiData());
+ SmallVector<int64_t> wiLayout(sgMap.getWiLayout());
+ auto tdescShape = getShape();
+ // Tensor descriptor shape can be 1D. For the 1D case, outer dims of wiData
+ // and wiLayout must be 1.
+ if (tdescShape.size() == 1) {
+ if (wiData[0] != 1 || wiLayout[0] != 1)
+ return failure();
+ wiData = {wiData[1]};
+ wiLayout = {wiLayout[1]};
+ }
+ // Check if the tensor descriptor shape is distributable.
+ int64_t tensorSize = 1, sgSize = 1, wiDataSize = 1;
+ for (auto [tdescDim, wiDim, wiDataDim] :
+ llvm::zip_equal(tdescShape, wiLayout, wiData)) {
+ if (tdescDim % (wiDim * wiDataDim) != 0)
+ return failure();
+ tensorSize *= tdescDim;
+ sgSize *= wiDim;
+ wiDataSize *= wiDataDim;
+ }
+ // tensorSize must be adjusted for array_length.
+ tensorSize *= getArrayLength();
+
+ return VectorType::get({tensorSize / (sgSize * wiDataSize), wiDataSize},
+ getElementType());
+}
+
} // namespace xegpu
} // namespace mlir
diff --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
index 25dc1f22f0432..a47332ce9eb0c 100644
--- a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
+++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
@@ -10,9 +10,12 @@
#include "mlir/Dialect/Utils/StaticValueUtils.h"
#include "mlir/Dialect/XeGPU/IR/XeGPU.h"
#include "mlir/IR/Builders.h"
+#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/TypeUtilities.h"
+#include "mlir/Support/LLVM.h"
#include "llvm/Support/Debug.h"
+#include "llvm/Support/LogicalResult.h"
#define DEBUG_TYPE "xegpu"
@@ -73,43 +76,6 @@ static bool isWriteHintOrNone(const CachePolicyAttr &attr) {
kind == CachePolicy::WRITE_BACK || kind == CachePolicy::WRITE_THROUGH;
}
-// Validations for nd instruction arguments is successful if any of these are
-// true:
-// - tensor descriptor and the output vector shapes exactly match.
-// - tensor descriptor has a sg_map attribute and the distributed vector shape
-// matches the tensor descriptor shape when scaled using sg_map factors on
-// each dimension.
-static bool isArgShapesValid(ArrayRef<int64_t> descShape,
- ArrayRef<int64_t> valShape, SGMapAttr sgMap) {
- // Equal shapes with no distribution - no further verification needed.
- if (descShape == valShape && !sgMap)
- return true;
-
- // Unknown distribution - cannot perform operation on partial shape.
- if (!sgMap)
- return false;
-
- // Invalid rank or mixed rank usage.
- size_t descRank = descShape.size();
- if (descRank > 2 || valShape.size() != descRank)
- return false;
-
- // For 1D, SG map is guaranteed to be unit size in the outer dimension.
- // Only take the distribution over the innermost dimension for validation.
- ArrayRef<uint32_t> wiLayout = sgMap.getWiLayout();
- SmallVector<uint32_t> mapLayout(wiLayout.begin(), wiLayout.end());
- if (descRank == 1)
- mapLayout = {wiLayout.back()};
-
- for (const auto &[factor, dim, expected] :
- llvm::zip_equal(mapLayout, valShape, descShape)) {
- if (factor * dim != expected)
- return false;
- }
-
- return true;
-}
-
//===----------------------------------------------------------------------===//
// XeGPU_CreateNdDescOp
//===----------------------------------------------------------------------===//
@@ -280,7 +246,8 @@ LogicalResult LoadNdOp::verify() {
return emitOpError("invalid l3_hint: ") << getL3HintAttr();
auto array_len = tdescTy.getArrayLength();
- auto tdescShape = getShapeOf(tdescTy);
+ // The adjusted tensor descriptor shape tracks the expected shape of the result.
+ auto adjustedTdescShape = getShapeOf(tdescTy);
auto valueShape = getShapeOf(valueTy);
if (getTranspose()) {
@@ -292,7 +259,7 @@ LogicalResult LoadNdOp::verify() {
});
if (valid)
- transpose(trans, tdescShape);
+ transpose(trans, adjustedTdescShape);
else
mlir::emitWarning(getLoc()) << "Invalid transpose attr. It is ignored.";
}
@@ -301,8 +268,8 @@ LogicalResult LoadNdOp::verify() {
if (tdescTy.getRank() == 2) {
const int axis = 0;
auto vnni_factor = valueShape.back();
- tdescShape[axis] /= vnni_factor;
- tdescShape.push_back(vnni_factor);
+ adjustedTdescShape[axis] /= vnni_factor;
+ adjustedTdescShape.push_back(vnni_factor);
} else {
mlir::emitWarning(getLoc())
<< "Invalid Packed Attr. It is ignored (available for 2D "
@@ -311,17 +278,35 @@ LogicalResult LoadNdOp::verify() {
}
if (array_len > 1) {
- auto it = tdescShape.begin();
- tdescShape.insert(it, array_len);
+ auto it = adjustedTdescShape.begin();
+ adjustedTdescShape.insert(it, array_len);
}
- auto sgMap = tdescTy.getSGMapAttr();
- if (!isArgShapesValid(tdescShape, valueShape, sgMap))
- return emitOpError() << "Result shape doesn't match TensorDesc shape."
- << "The expected shape is " << makeString(tdescShape)
- << ". But the given shape is "
- << makeString(valueShape) << ".\n";
- return success();
+ auto sgMap = tdescTy.getSGMapAttr();
+ // If sg_map is not present, the IR is in VC mode. In this case, the value
+ // shape must match the adjusted tensor descriptor shape.
+ if (!sgMap)
+ return valueShape == adjustedTdescShape
+ ? success()
+ : emitOpError()
+ << "Result shape " << makeString(valueShape)
+ << " is not consistent with tensor descripter " << tdescTy;
+
+ // If sg_map is present, the IR is in SIMT mode, and sg_map determines the
+ // value shape.
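+ // (E.g., per the tests in this PR, loading from tensor_desc<8x16xf32,
+ // #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>> must produce
+ // vector<8x1xf32>.)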
+ auto expectedValueShapeOrFailure = tdescTy.getDistributedVectorType();
+ if (failed(expectedValueShapeOrFailure))
+ return emitOpError() << "Failed to compute distributed vector shape for "
+ "tensor descriptor "
+ << tdescTy;
+
+ return valueTy == expectedValueShapeOrFailure.value()
+ ? success()
+ : emitOpError()
+ << "Result shape " << makeString(valueShape)
+ << " is not consistent with distributed vector shape "
+ << makeString(expectedValueShapeOrFailure.value().getShape())
+ << " for tensor descriptor " << tdescTy;
}
//===----------------------------------------------------------------------===//
@@ -351,14 +336,33 @@ LogicalResult StoreNdOp::verify() {
auto tdescShape = getShapeOf(dstTy);
auto valueShape = getShapeOf(valTy);
- auto sgMap = dstTy.getSGMapAttr();
- if (!isArgShapesValid(tdescShape, valueShape, sgMap))
- return emitOpError() << "Result shape doesn't match TensorDesc shape."
- << "The expected shape is " << makeString(tdescShape)
- << ". But the given shape is "
- << makeString(valueShape) << ".\n";
- return success();
+ auto sgMap = dstTy.getSGMapAttr();
+ // If sg_map is not present, the IR is in VC mode. In this case, the value
+ // shape must match the tensor descriptor shape.
+ if (!sgMap)
+ return valueShape == tdescShape
+ ? success()
+ : emitOpError()
+ << "Result shape " << makeString(valueShape)
+ << " is not consistent with tensor descripter shape "
+ << makeString(tdescShape);
+
+ // If sg_map is present, the IR is in SIMT mode, and sg_map determines the
+ // value shape.
+ auto expectedValueShapeOrFailure = dstTy.getDistributedVectorType();
+ if (failed(expectedValueShapeOrFailure))
+ return emitOpError() << "Failed to compute distributed vector shape for "
+ "tensor descriptor "
+ << dstTy;
+
+ return valTy == expectedValueShapeOrFailure.value()
+ ? success()
+ : emitOpError()
+ << "Result shape " << makeString(valueShape)
+ << " is not consistent with distributed vector shape "
+ << makeString(expectedValueShapeOrFailure.value().getShape())
+ << " for tensor descriptor " << dstTy;
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/XeGPU/XeGPUOps.mlir b/mlir/test/Dialect/XeGPU/XeGPUOps.mlir
index 472176af72b19..131af2fbd75d5 100644
--- a/mlir/test/Dialect/XeGPU/XeGPUOps.mlir
+++ b/mlir/test/Dialect/XeGPU/XeGPUOps.mlir
@@ -13,8 +13,8 @@ gpu.func @test_create_nd_tdesc_vc_1(%src: memref<24x32xf32>) {
gpu.return
}
-// CHECK: gpu.func @test_create_nd_tdesc_with_sg_map(%[[arg0:.*]]: memref<24x32xf32>) {
-gpu.func @test_create_nd_tdesc_with_sg_map(%src: memref<24x32xf32>) {
+// CHECK: gpu.func @test_create_nd_tdesc_simt(%[[arg0:.*]]: memref<24x32xf32>) {
+gpu.func @test_create_nd_tdesc_simt(%src: memref<24x32xf32>) {
// CHECK: %[[REG:.*]] = xegpu.create_nd_tdesc %arg0[0, 0] : memref<24x32xf32> -> !xegpu.tensor_desc<8x16xf32, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>>
%1 = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf32> ->
!xegpu.tensor_desc<8x16xf32, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>>
@@ -30,6 +30,15 @@ gpu.func @test_create_nd_tdesc_vc_2(%src: ui64, %w : index, %h : index, %x : ind
gpu.return
}
+// CHECK: gpu.func @test_create_nd_tdesc_simt_2(%[[arg0:.*]]: ui64, %[[arg1:.*]]: index, %[[arg2:.*]]: index, %[[arg3:.*]]: index, %[[arg4:.*]]: index) {
+gpu.func @test_create_nd_tdesc_simt_2(%src: ui64, %w : index, %h : index, %x : index, %y : index) {
+ //CHECK: %[[C:.*]] = arith.constant 1 : index
+ %c1 = arith.constant 1 : index
+ // CHECK: %[[REG:.*]] = xegpu.create_nd_tdesc %[[arg0]][%[[arg3]], %[[arg4]]], [%[[arg2]], %[[arg1]]], [%[[arg1]], %[[C]]] : ui64 -> !xegpu.tensor_desc<8x16xf32, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>>
+ %1 = xegpu.create_nd_tdesc %src[%x, %y], [%h, %w], [%w, %c1] : ui64 -> !xegpu.tensor_desc<8x16xf32, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>>
+ gpu.return
+}
+
// CHECK: gpu.func @test_create_nd_tdesc_vc_3(%[[arg0:.*]]: memref<24x32xf32>) {
gpu.func @test_create_nd_tdesc_vc_3(%src: memref<24x32xf32>) {
// CHECK: %[[REG:.*]] = xegpu.create_nd_tdesc %[[arg0]][0, 0] : memref<24x32xf32> -> !xegpu.tensor_desc<24x16xf32, #xegpu.block_tdesc_attr<array_length = 2 : i64>
@@ -37,6 +46,13 @@ gpu.func @test_create_nd_tdesc_vc_3(%src: memref<24x32xf32>) {
gpu.return
}
+// CHECK: gpu.func @test_create_nd_tdesc_simt_3(%[[arg0:.*]]: memref<24x32xf32>) {
+gpu.func @test_create_nd_tdesc_simt_3(%src: memref<24x32xf32>) {
+ // CHECK: %[[REG:.*]] = xegpu.create_nd_tdesc %[[arg0]][0, 0] : memref<24x32xf32> -> !xegpu.tensor_desc<24x16xf32, #xegpu.block_tdesc_attr<array_length = 2 : i64>, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>>
+ %1 = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf32> -> !xegpu.tensor_desc<24x16xf32, #xegpu.block_tdesc_attr<array_length = 2>, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>>
+ gpu.return
+}
+
// CHECK: gpu.func @test_create_nd_tdesc_vc_4(%[[arg0:.*]]: memref<2x24x32xf32>) {
gpu.func @test_create_nd_tdesc_vc_4(%src: memref<2x24x32xf32>) {
// CHECK: %[[REG:.*]] = xegpu.create_nd_tdesc %arg0[0, 0, 0] : memref<2x24x32xf32> -> !xegpu.tensor_desc<8x16xf32>
@@ -44,6 +60,13 @@ gpu.func @test_create_nd_tdesc_vc_4(%src: memref<2x24x32xf32>) {
gpu.return
}
+// CHECK: gpu.func @test_create_nd_tdesc_simt_4(%[[arg0:.*]]: memref<2x24x32xf32>) {
+gpu.func @test_create_nd_tdesc_simt_4(%src: memref<2x24x32xf32>) {
+ // CHECK: %[[REG:.*]] = xegpu.create_nd_tdesc %arg0[0, 0, 0] : memref<2x24x32xf32> -> !xegpu.tensor_desc<8x16xf32, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>>
+ %1 = xegpu.create_nd_tdesc %src[0, 0, 0] : memref<2x24x32xf32> -> !xegpu.tensor_desc<8x16xf32, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>>
+ gpu.return
+}
+
// CHECK: gpu.func @test_create_nd_tdesc_vc_5(%[[arg0:.*]]: memref<2x24x32xf32, 3>) {
gpu.func @test_create_nd_tdesc_vc_5(%src: memref<2x24x32xf32, 3>) {
// CHECK: %[[REG:.*]] = xegpu.create_nd_tdesc %arg0[0, 0, 0] : memref<2x24x32xf32, 3> -> !xegpu.tensor_desc<16xf32, #xegpu.block_tdesc_attr<memory_space = slm>>
@@ -51,6 +74,13 @@ gpu.func @test_create_nd_tdesc_vc_5(%src: memref<2x24x32xf32, 3>) {
gpu.return
}
+// CHECK: gpu.func @test_create_nd_tdesc_simt_5(%[[arg0:.*]]: memref<2x24x32xf32, 3>) {
+gpu.func @test_create_nd_tdesc_simt_5(%src: memref<2x24x32xf32, 3>) {
+ // CHECK: %[[REG:.*]] = xegpu.create_nd_tdesc %arg0[0, 0, 0] : memref<2x24x32xf32, 3> -> !xegpu.tensor_desc<16xf32, #xegpu.block_tdesc_attr<memory_space = slm>, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>>
+ %1 = xegpu.create_nd_tdesc %src[0, 0, 0] : memref<2x24x32xf32, 3> -> !xegpu.tensor_desc<16xf32, #xegpu.block_tdesc_attr<memory_space = slm>, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>>
+ gpu.return
+}
+
// CHECK: gpu.func @test_create_nd_tdesc_vc_6(%[[arg0:.*]]: memref<24x32xf32>) {
gpu.func @test_create_nd_tdesc_vc_6(%src: memref<24x32xf32>) {
// CHECK: %[[REG:.*]] = xegpu.create_nd_tdesc %[[arg0]][0, 0] : memref<24x32xf32> -> !xegpu.tensor_desc<24x16xf32, #xegpu.block_tdesc_attr<array_length = 2 : i64>, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>
@@ -58,6 +88,13 @@ gpu.func @test_create_nd_tdesc_vc_6(%src: memref<24x32xf32>) {
gpu.return
}
+// CHECK: gpu.func @test_create_nd_tdesc_simt_6(%[[arg0:.*]]: memref<24x32xf32>) {
+gpu.func @test_create_nd_tdesc_simt_6(%src: memref<24x32xf32>) {
+ // CHECK: %[[REG:.*]] = xegpu.create_nd_tdesc %[[arg0]][0, 0] : memref<24x32xf32> -> !xegpu.tensor_desc<24x16xf32, #xegpu.block_tdesc_attr<array_length = 2 : i64>, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>
+ %1 = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf32> -> !xegpu.tensor_desc<24x16xf32, #xegpu.block_tdesc_attr<array_length = 2>, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>>
+ gpu.return
+}
+
// CHECK: gpu.func @test_prefetch_nd_vc(%[[arg0:.*]]: memref<24x32xf16>) {
gpu.func @test_prefetch_nd_vc(%src: memref<24x32xf16>) {
// CHECK: %[[R0:.*]] = xegpu.create_nd_tdesc %[[arg0]][0, 0] : memref<24x32xf16> -> !xegpu.tensor_desc<8x16xf16>
@@ -86,9 +123,8 @@ gpu.func @test_load_nd_vc_2(%src: memref<8x16xf16>) {
gpu.return
}
-// load_nd args may have different shapes, validated against sg_map
-// CHECK: func @test_load_nd_vc_3(%[[arg0:.*]]: memref<24x32xf32>) {
-gpu.func @test_load_nd_vc_3(%src: memref<24x32xf32>) {
+// CHECK: func @test_load_nd_simt(%[[arg0:.*]]: memref<24x32xf32>) {
+gpu.func @test_load_nd_simt(%src: memref<24x32xf32>) {
// CHECK: %[[R0:.*]] = xegpu.create_nd_tdesc %arg0[0, 0] : memref<24x32xf32> -> !xegpu.tensor_desc<8x16xf32, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>>
%1 = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf32> ->
!xegpu.tensor_desc<8x16xf32, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>>
@@ -97,13 +133,23 @@ gpu.func @test_load_nd_vc_3(%src: memref<24x32xf32>) {
gpu.return
}
+// CHECK: func @test_load_nd_simt_2(%[[arg0:.*]]: memref<24x32xf16>) {
+gpu.func @test_load_nd_simt_2(%src: memref<24x32xf16>) {
+ // CHECK: %[[R0:.*]] = xegpu.create_nd_tdesc %arg0[0, 0] : memref<24x32xf16> -> !xegpu.tensor_desc<16x16xf16, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [2, 1]>>
+ %1 = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf16> ->
+ !xegpu.tensor_desc<16x16xf16, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [2, 1]>>
+ // CHECK: %[[R1:.*]] = xegpu.load_nd %[[R0]] <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}> : !xegpu.tensor_desc<16x16xf16, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [2, 1]>> -> vector<8x2xf16>
+ %2 = xegpu.load_nd %1 <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}> : !xegpu.tensor_desc<16x16xf16, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [2, 1]>> -> vector<8x2xf16>
+ gpu.return
+}
+
// CHECK: func @test_load_nd_vc_4(%[[arg0:.*]]: memref<24x32xf32>) {
gpu.func @test_load_nd_vc_4(%src: memref<24x32xf32>) {
// CHECK: %[[R0:.*]] = xegpu.create_nd_tdesc %arg0[0, 0] : memref<24x32xf32> -> !xegpu.tensor_desc<32xf32, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>>
%1 = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf32> ->
!xegpu.tensor_desc<32xf32, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>>
- // CHECK: %[[R1:.*]] = xegpu.load_nd %[[R0]] <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}> : !xegpu.tensor_desc<32xf32, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>> -> vector<2xf32>
- %2 = xegpu.load_nd %1 <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}> : !xegpu.tensor_desc<32xf32, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>> -> vector<2xf32>
+ // CHECK: %[[R1:.*]] = xegpu.load_nd %[[R0]] <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}> : !xegpu.tensor_desc<32xf32, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>> -> vector<2x1xf32>
+ %2 = xegpu.load_nd %1 <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}> : !xegpu.tensor_desc<32xf32, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>> -> vector<2x1xf32>
gpu.return
}
@@ -132,25 +178,25 @@ gpu.func @test_store_nd_vc_2(%dst: memref<24x32xf16>) {
// store_nd args may have different shapes, validated against sg_map
// CHECK: func @test_store_nd_vc_3(%[[arg0:.*]]: memref<24x32xf16>) {
gpu.func @test_store_nd_vc_3(%src: memref<24x32xf16>) {
- // CHECK: %[[C:.*]] = arith.constant dense<1.000000e+00> : vector<24x2xf16>
- %1 = arith.constant dense<1.0>: vector<24x2xf16>
+ // CHECK: %[[C:.*]] = arith.constant dense<1.000000e+00> : vector<48x1xf16>
+ %1 = arith.constant dense<1.0>: vector<48x1xf16>
// CHECK: %[[R0:.*]] = xegpu.create_nd_tdesc %arg0[0, 0] : memref<24x32xf16> -> !xegpu.tensor_desc<24x32xf16, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>>
%2 = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf16> ->
!xegpu.tensor_desc<24x32xf16, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>>
- // CHECK: xegpu.store_nd %[[C]], %[[R0]] <{l1_hint = #xegpu.cache_hint<write_back>, l2_hint = #xegpu.cache_hint<uncached>}> : vector<24x2xf16>, !xegpu.tensor_desc<24x32xf16, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>>
- xegpu.store_nd %1, %2 <{l1_hint = #xegpu.cache_hint<write_back>, l2_hint = #xegpu.cache_hint<uncached>}>: vector<24x2xf16>, !xegpu.tensor_desc<24x32xf16, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>>
+ // CHECK: xegpu.store_nd %[[C]], %[[R0]] <{l1_hint = #xegpu.cache_hint<write_back>, l2_hint = #xegpu.cache_hint<uncached>}> : vector<48x1xf16>, !xegpu.tensor_desc<24x32xf16, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>>
+ xegpu.store_nd %1, %2 <{l1_hint = #xegpu.cache_hint<write_back>, l2_hint = #xegpu.cache_hint<uncached>}>: vector<48x1xf16>, !xegpu.tensor_desc<24x32xf16, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>>
gpu.return
}
// CHECK: func @test_store_nd_vc_4(%[[arg0:.*]]: memref<24x32xf16>) {
gpu.func @test_store_nd_vc_4(%src: memref<24x32xf16>) {
- // CHECK: %[[C:.*]] = arith.constant dense<1.000000e+00> : vector<2xf16>
- %1 = arith.constant dense<1.0>: vector<2xf16>
+ // CHECK: %[[C:.*]] = arith.constant dense<1.000000e+00> : vector<2x1xf16>
+ %1 = arith.constant dense<1.0>: vector<2x1xf16>
// CHECK: %[[R0:.*]] = xegpu.create_nd_tdesc %arg0[0, 0] : memref<24x32xf16> -> !xegpu.tensor_desc<32xf16, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>>
%2 = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf16> ->
!xegpu.tensor_desc<32xf16, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>>
- // CHECK: xegpu.store_nd %[[C]], %[[R0]] <{l1_hint = #xegpu.cache_hint<write_back>, l2_hint = #xegpu.cache_hint<uncached>}> : vector<2xf16>, !xegpu.tensor_desc<32xf16, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>>
- xegpu.store_nd %1, %2 <{l1_hint = #xegpu.cache_hint<write_back>, l2_hint = #xegpu.cache_hint<uncached>}>: vector<2xf16>, !xegpu.tensor_desc<32xf16, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>>
+ // CHECK: xegpu.store_nd %[[C]], %[[R0]] <{l1_hint = #xegpu.cache_hint<write_back>, l2_hint = #xegpu.cache_hint<uncached>}> : vector<2x1xf16>, !xegpu.tensor_desc<32xf16, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>>
+ xegpu.store_nd %1, %2 <{l1_hint = #xegpu.cache_hint<write_back>, l2_hint = #xegpu.cache_hint<uncached>}>: vector<2x1xf16>, !xegpu.tensor_desc<32xf16, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>>
gpu.return
}
@@ -207,7 +253,7 @@ gpu.func @test_load_with_sg_map(%src: ui64) {
%1 = arith.constant dense<1>: vector<4xi1>
//CHECK: %[[R0:.*]] = xegpu.create_tdesc %[[arg0]], %[[cst]] : ui64, vector<4xindex> -> !xegpu.tensor_desc<4x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2 : i64>, #xegpu.sg_map<wi_layout = [4, 1], wi_data = [1, 2]>>
%2 = xegpu.create_tdesc %src, %0 : ui64, vector<4xindex> -> !xegpu.tensor_desc<4x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2>, #xegpu.sg_map<wi_layout = [4, 1], wi_data = [1, 2]>>
- //CHECK: %[[R1:.*]] = xegpu.load %[[R0]], %[[cst1]] <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>, transpose}> : !xegpu.tensor_desc<4x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2 : i64>, #xegpu.sg_map<wi_layout = [4, 1], wi_data = [1, 2]>>, vector<4xi1> -> vector<2x1xf32>
+ //CHECK: %[[R1:.*]] = xegpu.load %[[R0]], %[[cst1]] <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>, transpose}> : !xegpu.tensor_desc<4x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2 : i64>, #xegpu.sg_map<wi_layout = [4, 1], wi_data = [1, 2]>>, vector<4xi1> -> vector<2x1xf32>
%3 = xegpu.load %2, %1 <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>, transpose}> : !xegpu.tensor_desc<4x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2>, #xegpu.sg_map<wi_layout = [4, 1], wi_data = [1, 2]>>, vector<4xi1> -> vector<2x1xf32>
gpu.return
}
@@ -220,7 +266,7 @@ gpu.func @test_load_with_sg_map_2(%src: ui64) {
%1 = arith.constant dense<1>: vector<4xi1>
//CHECK: %[[R0:.*]] = xegpu.create_tdesc %[[arg0]], %[[cst]] : ui64, vector<4xindex> -> !xegpu.tensor_desc<4xf32, #xegpu.scatter_tdesc_attr<>, #xegpu.sg_map<wi_layout = [1, 4], wi_data = [1, 1]>>
%2 = xegpu.create_tdesc %src, %0 : ui64, vector<4xindex> -> !xegpu.tensor_desc<4xf32, #xegpu.scatter_tdesc_attr<>, #xegpu.sg_map<wi_layout = [1, 4], wi_data = [1, 1]>>
- //CHECK: %[[R1:.*]] = xegpu.load %[[R0]], %[[cst1]] <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}> : !xegpu.tensor_desc<4xf32, #xegpu.scatter_tdesc_attr<>, #xegpu.sg_map<wi_layout = [1, 4], wi_data = [1, 1]>>, vector<4xi1> -> vector<1xf32>
+ //CHECK: %[[R1:.*]] = xegpu.load %[[R0]], %[[cst1]] <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}> : !xegpu.tensor_desc<4xf32, #xegpu.scatter_tdesc_attr<>, #xegpu.sg_map<wi_layout = [1, 4], wi_data = [1, 1]>>, vector<4xi1> -> vector<1xf32>
%3 = xegpu.load %2, %1 <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}> : !xegpu.tensor_desc<4xf32, #xegpu.scatter_tdesc_attr<>, #xegpu.sg_map<wi_layout = [1, 4], wi_data = [1, 1]>>, vector<4xi1> -> vector<1xf32>
gpu.return
}
@@ -233,7 +279,7 @@ gpu.func @test_store_with_sg_map(%src: ui64) {
%1 = arith.constant dense<1>: vector<4xi1>
//CHECK: %[[cst2:.*]] = arith.constant dense<2.900000e+00> : vector<2x1xf32>
%2 = arith.constant dense<2.9>: vector<2x1xf32>
- //CHECK: %[[R0:.*]] = xegpu.create_tdesc %[[arg0]], %[[cst]] : ui64, vector<4xindex> -> !xegpu.tensor_desc<4x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2 : i64>, #xegpu.sg_map<wi_layout = [4, 1], wi_data = [1, 2]>>
+ //CHECK: %[[R0:.*]] = xegpu.create_tdesc %[[arg0]], %[[cst]] : ui64, vector<4xindex> -> !xegpu.tensor_desc<4x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2 : i64>, #xegpu.sg_map<wi_layout = [4, 1], wi_data = [1, 2]>>
%3 = xegpu.create_tdesc %src, %0 : ui64, vector<4xindex> -> !xegpu.tensor_desc<4x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2>, #xegpu.sg_map<wi_layout = [4, 1], wi_data = [1, 2]>>
//CHECK: xegpu.store %[[cst2]], %[[R0]], %[[cst1]] <{l1_hint = #xegpu.cache_hint<write_back>, l2_hint = #xegpu.cache_hint<uncached>, transpose}> : vector<2x1xf32>, !xegpu.tensor_desc<4x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2 : i64>, #xegpu.sg_map<wi_layout = [4, 1], wi_data = [1, 2]>>, vector<4xi1>
xegpu.store %2, %3, %1 <{l1_hint = #xegpu.cache_hint<write_back>, l2_hint = #xegpu.cache_hint<uncached>, transpose}> : vector<2x1xf32>, !xegpu.tensor_desc<4x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2>, #xegpu.sg_map<wi_layout = [4, 1], wi_data = [1, 2]>>, vector<4xi1>
@@ -248,7 +294,7 @@ gpu.func @test_store_with_sg_map_2(%src: ui64) {
%1 = arith.constant dense<1>: vector<4xi1>
//CHECK: %[[cst2:.*]] = arith.constant dense<2.900000e+00> : vector<1xf32>
%2 = arith.constant dense<2.9>: vector<1xf32>
- //CHECK: %[[R0:.*]] = xegpu.create_tdesc %[[arg0]], %[[cst]] : ui64, vector<4xindex> -> !xegpu.tensor_desc<4xf32, #xegpu.scatter_tdesc_attr<>, #xegpu.sg_map<wi_layout = [1, 4], wi_data = [1, 1]>>
+ //CHECK: %[[R0:.*]] = xegpu.create_tdesc %[[arg0]], %[[cst]] : ui64, vector<4xindex> -> !xegpu.tensor_desc<4xf32, #xegpu.scatter_tdesc_attr<>, #xegpu.sg_map<wi_layout = [1, 4], wi_data = [1, 1]>>
%3 = xegpu.create_tdesc %src, %0 : ui64, vector<4xindex> -> !xegpu.tensor_desc<4xf32, #xegpu.scatter_tdesc_attr<>, #xegpu.sg_map<wi_layout = [1, 4], wi_data = [1, 1]>>
//CHECK: xegpu.store %[[cst2]], %[[R0]], %[[cst1]] <{l1_hint = #xegpu.cache_hint<write_back>, l2_hint = #xegpu.cache_hint<uncached>}> : vector<1xf32>, !xegpu.tensor_desc<4xf32, #xegpu.scatter_tdesc_attr<>, #xegpu.sg_map<wi_layout = [1, 4], wi_data = [1, 1]>>, vector<4xi1>
xegpu.store %2, %3, %1 <{l1_hint = #xegpu.cache_hint<write_back>, l2_hint = #xegpu.cache_hint<uncached>}> : vector<1xf32>, !xegpu.tensor_desc<4xf32, #xegpu.scatter_tdesc_attr<>, #xegpu.sg_map<wi_layout = [1, 4], wi_data = [1, 1]>>, vector<4xi1>
From 061ace2db7f782c47cc365222de800dbe8e6fbcb Mon Sep 17 00:00:00 2001
From: Charitha Saumya <charitha.saumya.gusthinna.waduge at intel.com>
Date: Tue, 18 Feb 2025 19:40:14 +0000
Subject: [PATCH 02/11] fix invalid tests
---
mlir/test/Dialect/XeGPU/invalid.mlir | 12 ++++++------
1 file changed, 6 insertions(+), 6 deletions(-)
diff --git a/mlir/test/Dialect/XeGPU/invalid.mlir b/mlir/test/Dialect/XeGPU/invalid.mlir
index 86356e09de57c..1ea452119fec9 100644
--- a/mlir/test/Dialect/XeGPU/invalid.mlir
+++ b/mlir/test/Dialect/XeGPU/invalid.mlir
@@ -81,7 +81,7 @@ func.func @test_load_nd_vc_3(%src: memref<8x16xf16>) {
func.func @test_load_nd_vc_4(%src: memref<24x32xf32>) {
%1 = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf32> ->
!xegpu.tensor_desc<8x16xf32, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>>
- // expected-error@+1 {{Result shape doesn't match TensorDesc shape.}}
+ // expected-error@+1 {{Result shape [8, 2] is not consistent with distributed vector shape [8, 1]}}
%2 = xegpu.load_nd %1 <{l1_hint = #xegpu.cache_hint<cached>,
l2_hint = #xegpu.cache_hint<uncached>}>
: !xegpu.tensor_desc<8x16xf32, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>>
@@ -93,7 +93,7 @@ func.func @test_load_nd_vc_4(%src: memref<24x32xf32>) {
func.func @test_load_nd_vc_5(%src: memref<24x32xf32>) {
%1 = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf32> ->
!xegpu.tensor_desc<16xf32, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>>
- // expected-error@+1 {{Result shape doesn't match TensorDesc shape.}}
+ // expected-error@+1 {{Result shape [8] is not consistent with distributed vector shape [1, 1]}}
%2 = xegpu.load_nd %1 <{l1_hint = #xegpu.cache_hint<cached>,
l2_hint = #xegpu.cache_hint<uncached>}>
: !xegpu.tensor_desc<16xf32, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>>
@@ -105,7 +105,7 @@ func.func @test_load_nd_vc_5(%src: memref<24x32xf32>) {
func.func @test_load_nd_vc_6(%src: memref<24x32xf32>) {
%1 = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf32> ->
!xegpu.tensor_desc<8x16xf32>
- // expected-error@+1 {{Result shape doesn't match TensorDesc shape.}}
+ // expected-error@+1 {{Result shape [8, 1] is not consistent with tensor descriptor}}
%2 = xegpu.load_nd %1 <{l1_hint = #xegpu.cache_hint<cached>,
l2_hint = #xegpu.cache_hint<uncached>}>
: !xegpu.tensor_desc<8x16xf32> -> vector<8x1xf32>
@@ -137,7 +137,7 @@ func.func @test_store_nd_vc_2(%dst: memref<16xf16>) {
func.func @test_store_nd_vc_3(%dst: memref<24x32xf32>, %data: vector<8x2xf32>) {
%1 = xegpu.create_nd_tdesc %dst[0, 0] : memref<24x32xf32> ->
!xegpu.tensor_desc<8x16xf32, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>>
- // expected-error@+1 {{Result shape doesn't match TensorDesc shape.}}
+ // expected-error@+1 {{Result shape [8, 2] is not consistent with distributed vector shape [8, 1] for tensor descriptor}}
xegpu.store_nd %data, %1
: vector<8x2xf32>, !xegpu.tensor_desc<8x16xf32, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>>
return
@@ -147,7 +147,7 @@ func.func @test_store_nd_vc_3(%dst: memref<24x32xf32>, %data: vector<8x2xf32>) {
func.func @test_store_nd_vc_4(%dst: memref<24x32xf32>, %data: vector<2xf32>) {
%1 = xegpu.create_nd_tdesc %dst[0, 0] : memref<24x32xf32> ->
!xegpu.tensor_desc<16xf32, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>>
- // expected-error@+1 {{Result shape doesn't match TensorDesc shape.}}
+ // expected-error@+1 {{Result shape [2] is not consistent with distributed vector shape [1, 1] for tensor descriptor}}
xegpu.store_nd %data, %1
: vector<2xf32>, !xegpu.tensor_desc<16xf32, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>>
return
@@ -157,7 +157,7 @@ func.func @test_store_nd_vc_4(%dst: memref<24x32xf32>, %data: vector<2xf32>) {
func.func @test_store_nd_vc_5(%dst: memref<24x32xf32>, %data: vector<8x1xf32>) {
%1 = xegpu.create_nd_tdesc %dst[0, 0] : memref<24x32xf32> ->
!xegpu.tensor_desc<8x16xf32>
- // expected-error@+1 {{Result shape doesn't match TensorDesc shape.}}
+ // expected-error@+1 {{Result shape [8, 1] is not consistent with tensor descriptor shape [8, 16]}}
xegpu.store_nd %data, %1 : vector<8x1xf32>, !xegpu.tensor_desc<8x16xf32>
return
}
From bd2c8be5a732464f5da1b436278352ab34bbf4de Mon Sep 17 00:00:00 2001
From: Charitha Saumya <charitha.saumya.gusthinna.waduge at intel.com>
Date: Tue, 18 Feb 2025 22:56:16 +0000
Subject: [PATCH 03/11] save work
---
.../Dialect/XeGPU/{XeGPUOps.mlir => ops.mlir} | 149 ++++++++++++++++--
1 file changed, 139 insertions(+), 10 deletions(-)
rename mlir/test/Dialect/XeGPU/{XeGPUOps.mlir => ops.mlir} (70%)
diff --git a/mlir/test/Dialect/XeGPU/XeGPUOps.mlir b/mlir/test/Dialect/XeGPU/ops.mlir
similarity index 70%
rename from mlir/test/Dialect/XeGPU/XeGPUOps.mlir
rename to mlir/test/Dialect/XeGPU/ops.mlir
index 131af2fbd75d5..aa4869afa2d86 100644
--- a/mlir/test/Dialect/XeGPU/XeGPUOps.mlir
+++ b/mlir/test/Dialect/XeGPU/ops.mlir
@@ -13,8 +13,8 @@ gpu.func @test_create_nd_tdesc_vc_1(%src: memref<24x32xf32>) {
gpu.return
}
-// CHECK: gpu.func @test_create_nd_tdesc_simt(%[[arg0:.*]]: memref<24x32xf32>) {
-gpu.func @test_create_nd_tdesc_simt(%src: memref<24x32xf32>) {
+// CHECK: gpu.func @test_create_nd_tdesc_simt_1(%[[arg0:.*]]: memref<24x32xf32>) {
+gpu.func @test_create_nd_tdesc_simt_1(%src: memref<24x32xf32>) {
// CHECK: %[[REG:.*]] = xegpu.create_nd_tdesc %arg0[0, 0] : memref<24x32xf32> -> !xegpu.tensor_desc<8x16xf32, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>>
%1 = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf32> ->
!xegpu.tensor_desc<8x16xf32, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>>
@@ -83,8 +83,8 @@ gpu.func @test_create_nd_tdesc_simt_5(%src: memref<2x24x32xf32, 3>) {
// CHECK: gpu.func @test_create_nd_tdesc_vc_6(%[[arg0:.*]]: memref<24x32xf32>) {
gpu.func @test_create_nd_tdesc_vc_6(%src: memref<24x32xf32>) {
- // CHECK: %[[REG:.*]] = xegpu.create_nd_tdesc %[[arg0]][0, 0] : memref<24x32xf32> -> !xegpu.tensor_desc<24x16xf32, #xegpu.block_tdesc_attr<array_length = 2 : i64>, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>
- %1 = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf32> -> !xegpu.tensor_desc<24x16xf32, #xegpu.block_tdesc_attr<array_length = 2>, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>>
+ // CHECK: %[[REG:.*]] = xegpu.create_nd_tdesc %[[arg0]][0, 0] : memref<24x32xf32> -> !xegpu.tensor_desc<24x16xf32, #xegpu.block_tdesc_attr<array_length = 2 : i64>>
+ %1 = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf32> -> !xegpu.tensor_desc<24x16xf32, #xegpu.block_tdesc_attr<array_length = 2>>
gpu.return
}
@@ -104,6 +104,15 @@ gpu.func @test_prefetch_nd_vc(%src: memref<24x32xf16>) {
gpu.return
}
+// CHECK: gpu.func @test_prefetch_nd_simt(%[[arg0:.*]]: memref<24x32xf16>) {
+gpu.func @test_prefetch_nd_simt(%src: memref<24x32xf16>) {
+ // CHECK: %[[R0:.*]] = xegpu.create_nd_tdesc %[[arg0]][0, 0] : memref<24x32xf16> -> !xegpu.tensor_desc<8x16xf16, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>>
+ %1 = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf16> -> !xegpu.tensor_desc<8x16xf16, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>>
+ // CHECK: xegpu.prefetch_nd %[[R0]] <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}> : !xegpu.tensor_desc<8x16xf16, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>>
+ xegpu.prefetch_nd %1 <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}>: !xegpu.tensor_desc<8x16xf16, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>>
+ gpu.return
+}
+
// CHECK: func @test_load_nd_vc(%[[arg0:.*]]: memref<8x16xf16>) {
gpu.func @test_load_nd_vc(%src: memref<8x16xf16>) {
// CHECK: %[[R0:.*]] = xegpu.create_nd_tdesc %arg0[0, 0] : memref<8x16xf16> -> !xegpu.tensor_desc<8x16xf16>
@@ -114,6 +123,16 @@ gpu.func @test_load_nd_vc(%src: memref<8x16xf16>) {
gpu.return
}
+// CHECK: func @test_load_nd_simt(%[[arg0:.*]]: memref<8x16xf16>) {
+gpu.func @test_load_nd_simt(%src: memref<8x16xf16>) {
+ // CHECK: %[[R0:.*]] = xegpu.create_nd_tdesc %arg0[0, 0] : memref<8x16xf16> -> !xegpu.tensor_desc<8x16xf16, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [2, 1]>>
+ %1 = xegpu.create_nd_tdesc %src[0, 0] : memref<8x16xf16> -> !xegpu.tensor_desc<8x16xf16, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [2, 1]>>
+ // CHECK: %[[R1:.*]] = xegpu.load_nd %[[R0]] <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>, packed}> : !xegpu.tensor_desc<8x16xf16, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [2, 1]>> -> vector<4x2xf16>
+ %2 = xegpu.load_nd %1 <{packed, l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}>
+ : !xegpu.tensor_desc<8x16xf16, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [2, 1]>> -> vector<4x2xf16>
+ gpu.return
+}
+
// CHECK: func @test_load_nd_vc_2(%[[arg0:.*]]: memref<8x16xf16>) {
gpu.func @test_load_nd_vc_2(%src: memref<8x16xf16>) {
// CHECK: %[[R0:.*]] = xegpu.create_nd_tdesc %arg0[0, 0] : memref<8x16xf16> -> !xegpu.tensor_desc<16xf16>
@@ -123,8 +142,26 @@ gpu.func @test_load_nd_vc_2(%src: memref<8x16xf16>) {
gpu.return
}
-// CHECK: func @test_load_nd_simt(%[[arg0:.*]]: memref<24x32xf32>) {
-gpu.func @test_load_nd_simt(%src: memref<24x32xf32>) {
+// CHECK: func @test_load_nd_simt_2(%[[arg0:.*]]: memref<8x16xf16>) {
+gpu.func @test_load_nd_simt_2(%src: memref<8x16xf16>) {
+ // CHECK: %[[R0:.*]] = xegpu.create_nd_tdesc %arg0[0, 0] : memref<8x16xf16> -> !xegpu.tensor_desc<16xf16, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>>
+ %1 = xegpu.create_nd_tdesc %src[0, 0] : memref<8x16xf16> -> !xegpu.tensor_desc<16xf16, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>>
+ // CHECK: %[[R1:.*]] = xegpu.load_nd %[[R0]] <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}> : !xegpu.tensor_desc<16xf16, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>> -> vector<1x1xf16>
+ %2 = xegpu.load_nd %1 <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}> : !xegpu.tensor_desc<16xf16, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>> -> vector<1x1xf16>
+ gpu.return
+}
+
+// CHECK: func @test_load_nd_vc_3(%[[arg0:.*]]: memref<24x32xf32>) {
+gpu.func @test_load_nd_vc_3(%src: memref<24x32xf32>) {
+ // CHECK: %[[R0:.*]] = xegpu.create_nd_tdesc %arg0[0, 0] : memref<24x32xf32> -> !xegpu.tensor_desc<8x16xf32>
+ %1 = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf32> -> !xegpu.tensor_desc<8x16xf32>
+ // CHECK: %[[R1:.*]] = xegpu.load_nd %[[R0]] <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}> : !xegpu.tensor_desc<8x16xf32> -> vector<8x16xf32>
+ %2 = xegpu.load_nd %1 <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}> : !xegpu.tensor_desc<8x16xf32> -> vector<8x16xf32>
+ gpu.return
+}
+
+// CHECK: func @test_load_nd_simt_3(%[[arg0:.*]]: memref<24x32xf32>) {
+gpu.func @test_load_nd_simt_3(%src: memref<24x32xf32>) {
// CHECK: %[[R0:.*]] = xegpu.create_nd_tdesc %arg0[0, 0] : memref<24x32xf32> -> !xegpu.tensor_desc<8x16xf32, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>>
%1 = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf32> ->
!xegpu.tensor_desc<8x16xf32, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>>
@@ -133,8 +170,17 @@ gpu.func @test_load_nd_simt(%src: memref<24x32xf32>) {
gpu.return
}
-// CHECK: func @test_load_nd_simt_2(%[[arg0:.*]]: memref<24x32xf16>) {
-gpu.func @test_load_nd_simt_2(%src: memref<24x32xf16>) {
+// CHECK: func @test_load_nd_vc_4(%[[arg0:.*]]: memref<24x32xf16>) {
+gpu.func @test_load_nd_vc_4(%src: memref<24x32xf16>) {
+ // CHECK: %[[R0:.*]] = xegpu.create_nd_tdesc %arg0[0, 0] : memref<24x32xf16> -> !xegpu.tensor_desc<16x16xf16>
+ %1 = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf16> -> !xegpu.tensor_desc<16x16xf16>
+ // CHECK: %[[R1:.*]] = xegpu.load_nd %[[R0]] <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>, packed}> : !xegpu.tensor_desc<16x16xf16> -> vector<8x16x2xf16>
+ %2 = xegpu.load_nd %1 <{packed, l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}> : !xegpu.tensor_desc<16x16xf16> -> vector<8x16x2xf16>
+ gpu.return
+}
+
+// CHECK: func @test_load_nd_simt_4(%[[arg0:.*]]: memref<24x32xf16>) {
+gpu.func @test_load_nd_simt_4(%src: memref<24x32xf16>) {
// CHECK: %[[R0:.*]] = xegpu.create_nd_tdesc %arg0[0, 0] : memref<24x32xf16> -> !xegpu.tensor_desc<16x16xf16, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [2, 1]>>
%1 = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf16> ->
!xegpu.tensor_desc<16x16xf16, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [2, 1]>>
@@ -143,8 +189,17 @@ gpu.func @test_load_nd_simt_2(%src: memref<24x32xf16>) {
gpu.return
}
-// CHECK: func @test_load_nd_vc_4(%[[arg0:.*]]: memref<24x32xf32>) {
-gpu.func @test_load_nd_vc_4(%src: memref<24x32xf32>) {
+// CHECK: func @test_load_nd_vc_5(%[[arg0:.*]]: memref<24x32xf32>) {
+gpu.func @test_load_nd_vc_5(%src: memref<24x32xf32>) {
+ // CHECK: %[[R0:.*]] = xegpu.create_nd_tdesc %arg0[0, 0] : memref<24x32xf32> -> !xegpu.tensor_desc<32xf32>
+ %1 = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf32> -> !xegpu.tensor_desc<32xf32>
+ // CHECK: %[[R1:.*]] = xegpu.load_nd %[[R0]] <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}> : !xegpu.tensor_desc<32xf32> -> vector<32xf32>
+ %2 = xegpu.load_nd %1 <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}> : !xegpu.tensor_desc<32xf32> -> vector<32xf32>
+ gpu.return
+}
+
+// CHECK: func @test_load_nd_simt_5(%[[arg0:.*]]: memref<24x32xf32>) {
+gpu.func @test_load_nd_simt_5(%src: memref<24x32xf32>) {
// CHECK: %[[R0:.*]] = xegpu.create_nd_tdesc %arg0[0, 0] : memref<24x32xf32> -> !xegpu.tensor_desc<32xf32, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>>
%1 = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf32> ->
!xegpu.tensor_desc<32xf32, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>>
@@ -153,6 +208,80 @@ gpu.func @test_load_nd_vc_4(%src: memref<24x32xf32>) {
gpu.return
}
+// CHECK: func @test_load_nd_vc_6(%[[arg0:.*]]: memref<24x32xf16>) {
+gpu.func @test_load_nd_vc_6(%src: memref<24x32xf16>) {
+ // CHECK: %[[R0:.*]] = xegpu.create_nd_tdesc %arg0[0, 0] : memref<24x32xf16> -> !xegpu.tensor_desc<16x16xf16, #xegpu.block_tdesc_attr<array_length = 2 : i64>>
+ %1 = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf16> -> !xegpu.tensor_desc<16x16xf16, #xegpu.block_tdesc_attr<array_length = 2>>
+ // CHECK: %[[R1:.*]] = xegpu.load_nd %[[R0]] <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}> : !xegpu.tensor_desc<16x16xf16, #xegpu.block_tdesc_attr<array_length = 2 : i64>> -> vector<2x16x16xf16>
+ %2 = xegpu.load_nd %1 <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}> : !xegpu.tensor_desc<16x16xf16, #xegpu.block_tdesc_attr<array_length = 2>> -> vector<2x16x16xf16>
+ gpu.return
+}
+
+// CHECK: func @test_load_nd_simt_6(%[[arg0:.*]]: memref<24x32xf16>) {
+gpu.func @test_load_nd_simt_6(%src: memref<24x32xf16>) {
+ // CHECK: %[[R0:.*]] = xegpu.create_nd_tdesc %arg0[0, 0] : memref<24x32xf16> -> !xegpu.tensor_desc<16x16xf16, #xegpu.block_tdesc_attr<array_length = 2 : i64>, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>>
+ %1 = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf16> -> !xegpu.tensor_desc<16x16xf16, #xegpu.block_tdesc_attr<array_length = 2>, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>>
+ // CHECK: %[[R1:.*]] = xegpu.load_nd %[[R0]] <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}> : !xegpu.tensor_desc<16x16xf16, #xegpu.block_tdesc_attr<array_length = 2 : i64>, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>> -> vector<32x1xf16>
+ %2 = xegpu.load_nd %1 <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}> :
+ !xegpu.tensor_desc<16x16xf16, #xegpu.block_tdesc_attr<array_length = 2>, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>> -> vector<32x1xf16>
+ gpu.return
+}
+
+// CHECK: func @test_load_nd_vc_7(%[[arg0:.*]]: memref<24x32xf16>) {
+gpu.func @test_load_nd_vc_7(%src: memref<24x32xf16>) {
+ // CHECK: %[[R0:.*]] = xegpu.create_nd_tdesc %arg0[0, 0] : memref<24x32xf16> -> !xegpu.tensor_desc<16x16xf16, #xegpu.block_tdesc_attr<array_length = 2 : i64>>
+ %1 = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf16> -> !xegpu.tensor_desc<16x16xf16, #xegpu.block_tdesc_attr<array_length = 2>>
+ // CHECK: %[[R1:.*]] = xegpu.load_nd %[[R0]] <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>, packed}> : !xegpu.tensor_desc<16x16xf16, #xegpu.block_tdesc_attr<array_length = 2 : i64>> -> vector<2x8x16x2xf16>
+ %2 = xegpu.load_nd %1 <{packed, l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}> : !xegpu.tensor_desc<16x16xf16, #xegpu.block_tdesc_attr<array_length = 2>> -> vector<2x8x16x2xf16>
+ gpu.return
+}
+
+// CHECK: func @test_load_nd_simt_7(%[[arg0:.*]]: memref<24x32xf16>) {
+gpu.func @test_load_nd_simt_7(%src: memref<24x32xf16>) {
+ // CHECK: %[[R0:.*]] = xegpu.create_nd_tdesc %arg0[0, 0] : memref<24x32xf16> -> !xegpu.tensor_desc<16x16xf16, #xegpu.block_tdesc_attr<array_length = 2 : i64>, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [2, 1]>>
+ %1 = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf16> -> !xegpu.tensor_desc<16x16xf16, #xegpu.block_tdesc_attr<array_length = 2>, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [2, 1]>>
+ // CHECK: %[[R1:.*]] = xegpu.load_nd %[[R0]] <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}> : !xegpu.tensor_desc<16x16xf16, #xegpu.block_tdesc_attr<array_length = 2 : i64>, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [2, 1]>> -> vector<16x2xf16>
+ %2 = xegpu.load_nd %1 <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}> :
+ !xegpu.tensor_desc<16x16xf16, #xegpu.block_tdesc_attr<array_length = 2>, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [2, 1]>> -> vector<16x2xf16>
+ gpu.return
+}
+
+// CHECK: func @test_load_nd_vc_8(%[[arg0:.*]]: memref<24x32xf32>) {
+gpu.func @test_load_nd_vc_8(%src: memref<24x32xf32>) {
+ // CHECK: %[[R0:.*]] = xegpu.create_nd_tdesc %arg0[0, 0] : memref<24x32xf32> -> !xegpu.tensor_desc<16x8xf32>
+ %1 = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf32> -> !xegpu.tensor_desc<16x8xf32>
+ // CHECK: %[[R1:.*]] = xegpu.load_nd %[[R0]] <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>, transpose = array<i64: 1, 0>}> : !xegpu.tensor_desc<16x8xf32> -> vector<8x16xf32>
+ %2 = xegpu.load_nd %1 <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>, transpose = array<i64: 1, 0>}> : !xegpu.tensor_desc<16x8xf32> -> vector<8x16xf32>
+ gpu.return
+}
+
+// CHECK: func @test_load_nd_simt_8(%[[arg0:.*]]: memref<24x32xf32>) {
+gpu.func @test_load_nd_simt_8(%src: memref<24x32xf32>) {
+ // CHECK: %[[R0:.*]] = xegpu.create_nd_tdesc %arg0[0, 0] : memref<24x32xf32> -> !xegpu.tensor_desc<16x8xf32, #xegpu.sg_map<wi_layout = [16, 1], wi_data = [1, 1]>>
+ %1 = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf32> -> !xegpu.tensor_desc<16x8xf32, #xegpu.sg_map<wi_layout = [16, 1], wi_data = [1, 1]>>
+ // CHECK: %[[R1:.*]] = xegpu.load_nd %[[R0]] <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>, transpose = array<i64: 1, 0>}> : !xegpu.tensor_desc<16x8xf32, #xegpu.sg_map<wi_layout = [16, 1], wi_data = [1, 1]>> -> vector<8x1xf32>
+ %2 = xegpu.load_nd %1 <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>, transpose = array<i64: 1, 0>}> : !xegpu.tensor_desc<16x8xf32, #xegpu.sg_map<wi_layout = [16, 1], wi_data = [1, 1]>> -> vector<8x1xf32>
+ gpu.return
+}
+
+// CHECK: func @test_load_nd_vc_9(%[[arg0:.*]]: memref<24x32xf32>) {
+gpu.func @test_load_nd_vc_9(%src: memref<24x32xf32>) {
+ // CHECK: %[[R0:.*]] = xegpu.create_nd_tdesc %arg0[0, 0] : memref<24x32xf32> -> !xegpu.tensor_desc<8x16xf32>
+ %1 = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf32> -> !xegpu.tensor_desc<8x16xf32>
+ // CHECK: %[[R1:.*]] = xegpu.load_nd %[[R0]] <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>, transpose = array<i64: 1, 0>}> : !xegpu.tensor_desc<8x16xf32> -> vector<16x8xf32>
+ %2 = xegpu.load_nd %1 <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>, transpose = array<i64: 1, 0>}> : !xegpu.tensor_desc<8x16xf32> -> vector<16x8xf32>
+ gpu.return
+}
+
+// CHECK: func @test_load_nd_simt_9(%[[arg0:.*]]: memref<24x32xf32>) {
+gpu.func @test_load_nd_simt_9(%src: memref<24x32xf32>) {
+ // CHECK: %[[R0:.*]] = xegpu.create_nd_tdesc %arg0[0, 0] : memref<24x32xf32> -> !xegpu.tensor_desc<8x16xf32, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>>
+ %1 = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf32> -> !xegpu.tensor_desc<8x16xf32, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>>
+ // CHECK: %[[R1:.*]] = xegpu.load_nd %[[R0]] <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>, transpose = array<i64: 1, 0>}> : !xegpu.tensor_desc<8x16xf32, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>> -> vector<8x1xf32>
+ %2 = xegpu.load_nd %1 <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>, transpose = array<i64: 1, 0>}> : !xegpu.tensor_desc<8x16xf32, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>> -> vector<8x1xf32>
+ gpu.return
+}
+
// CHECK: func @test_store_nd_vc(%[[arg0:.*]]: memref<24x32xf16>) {
gpu.func @test_store_nd_vc(%dst: memref<24x32xf16>) {
// CHECK: %[[C:.*]] = arith.constant dense<1.000000e+00> : vector<24x32xf16>
From 981c7d3656991ce3dca803f302be4fe188ecbc98 Mon Sep 17 00:00:00 2001
From: Charitha Saumya <charitha.saumya.gusthinna.waduge at intel.com>
Date: Wed, 19 Feb 2025 21:48:24 +0000
Subject: [PATCH 04/11] save work
---
mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp | 59 ++++++--
mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp | 68 +++++----
mlir/test/Dialect/XeGPU/ops.mlir | 167 +++++++++++++++++----
3 files changed, 216 insertions(+), 78 deletions(-)
diff --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
index 768a3b4ef33c5..af3faf141a66e 100644
--- a/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
+++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
@@ -10,6 +10,7 @@
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/DialectImplementation.h"
+#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/LogicalResult.h"
@@ -279,14 +280,13 @@ LogicalResult TensorDescType::verify(
if (scatterAttr) {
// Validate subgroup mapping rules for scattered tensors.
// A work-item's slice of the tensor with shape [sg_size] or
- // [sg_size, chunk_size] will be [1] or [1, chunks_size] respectively,
- // the mapping should reflect that.
+ // [sg_size, chunk_size] will be [1] or [1, 32/element_ty_bit_width]
+ // respectively; the mapping should reflect that.
if (wiData[0] != 1)
return emitError()
<< "cannot map over non-contiguous scattered row elements";
- unsigned chunkSize = scatterAttr.getChunkSize().getInt();
- if (wiData[1] != chunkSize)
+ if (wiData[1] != (32 / elementType.getIntOrFloatBitWidth()))
return emitError() << "work item data mapping must match the number of "
"contiguous elements";
}
@@ -310,31 +310,62 @@ LogicalResult TensorDescType::verify(
return success();
}
-// If tensor descriptor has a sg_map attribute it is used in SIMT mode. In this
-// mode, the distributed vector shape is given by the following criteria:
+// If a tensor descriptor has an sg_map attribute, it is used in SIMT mode.
+// In this mode, the distributed vector shape is determined as follows:
+// Definitions:
// wi_data_size = wi_data[0] × wi_data[1]
// subgroup_size = wi_layout[0] × wi_layout[1]
// distribution_unit_size = subgroup_size × wi_data_size
+// ---------------------------------------------------------------------
+// Case 1: Regular loads/stores.
+// ---------------------------------------------------------------------
+// Distributed vector shape must be:
+// [chunk_size / wi_data_size, wi_data_size]
+// If the tensor descriptor shape is 1D (i.e., chunk_size == 1), the first
+// dimension is dropped and the distributed vector shape is:
+// [wi_data_size]
+// ---------------------------------------------------------------------
+// Case 2: Block loads/stores
+// ---------------------------------------------------------------------
+// Additional definitions:
// tensor_size = tensor_desc[0] * .. * tensor_desc[r-1] * array_length
// n_distribution_units = tensor_size / distribution_unit_size
// Given above definitions, the following conditions must be met:
// * tensor_desc[0] % (wi_layout[0] × wi_data[0]) == 0
// * tensor_desc[1] % (wi_layout[1] × wi_data[1]) == 0
-// Distributed vector shape must be: n_distribution_units × wi_data_size
+// Distributed vector shape must be:
+// [n_distribution_units, wi_data_size]
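+//
+// Worked examples (illustrative, per the tests in this PR):
+// * Case 1: tensor_desc<4xf32, #xegpu.scatter_tdesc_attr<>> (chunk_size = 1)
+//   with sg_map<wi_layout = [1, 4], wi_data = [1, 1]> distributes to
+//   vector<1xf32>.
+// * Case 2: tensor_desc<16x16xf16, #xegpu.block_tdesc_attr<array_length = 2>>
+//   with sg_map<wi_layout = [1, 16], wi_data = [1, 1]> gives
+//   tensor_size = 16 * 16 * 2 = 512 and distribution_unit_size = 16, so the
+//   distributed vector type is vector<32x1xf16>.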
FailureOr<VectorType> TensorDescType::getDistributedVectorType() {
auto sgMap = llvm::dyn_cast_if_present<SGMapAttr>(getSgMap());
// If no sg_map is provided, tensor desc is not used in SIMT mode.
if (!sgMap)
return failure();
- // FIXME: Add support for scatter tensor descriptor.
- auto scatterAttr =
- llvm::dyn_cast_if_present<ScatterTensorDescAttr>(getEncoding());
- if (scatterAttr)
- return failure();
SmallVector<int64_t> wiData(sgMap.getWiData());
SmallVector<int64_t> wiLayout(sgMap.getWiLayout());
auto tdescShape = getShape();
+
+ auto wiDataSize = 1, sgSize = 1;
+ for (auto [wiDim, wiDataDim] : llvm::zip_equal(wiLayout, wiData)) {
+ wiDataSize *= wiDataDim;
+ sgSize *= wiDim;
+ }
+
+ // Case 1: regular loads/stores
+ auto scatterAttr =
+ llvm::dyn_cast_if_present<ScatterTensorDescAttr>(getEncoding());
+ if (scatterAttr) {
+ auto chunkSize = scatterAttr.getChunkSize().getInt();
+ // Check if the first dimension of the tensor descriptor shape is
+ // distributable.
+ if (tdescShape[0] % (wiLayout[0]) != 0)
+ return failure();
+ if (chunkSize > 1)
+ return VectorType::get({chunkSize / wiDataSize, wiDataSize},
+ getElementType());
+ return VectorType::get({wiDataSize}, getElementType());
+ }
+
+ // Case 2: block loads/stores
// Tensor descriptor shape can be 1D. For the 1D case, outer dims of wiData
// and wiLayout must be 1.
if (tdescShape.size() == 1) {
@@ -344,14 +375,12 @@ FailureOr<VectorType> TensorDescType::getDistributedVectorType() {
wiLayout = {wiLayout[1]};
}
// Check if the tensor descriptor shape is distributable.
- int64_t tensorSize = 1, sgSize = 1, wiDataSize = 1;
+ int64_t tensorSize = 1;
for (auto [tdescDim, wiDim, wiDataDim] :
llvm::zip_equal(tdescShape, wiLayout, wiData)) {
if (tdescDim % (wiDim * wiDataDim) != 0)
return failure();
tensorSize *= tdescDim;
- sgSize *= wiDim;
- wiDataSize *= wiDataDim;
}
// tensorSize must be adjusted for array_length.
tensorSize *= getArrayLength();
diff --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
index a47332ce9eb0c..1453c3c80a838 100644
--- a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
+++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
@@ -513,22 +513,25 @@ LogicalResult LoadGatherOp::verify() {
transpose({1, 0}, tdescShape);
}
- if (auto sgMap = tdescTy.getSGMapAttr()) {
- auto valueVecTy = cast<VectorType>(valueTy);
- const int32_t wiData =
- sgMap.getWiData()[0] > 1 ? sgMap.getWiData()[0] : sgMap.getWiData()[1];
- // All represent the same concept: a number of row elements to store.
- if (valueVecTy.getNumElements() != wiData ||
- valueVecTy.getNumElements() != tdescTy.getChunkSize()) {
- return emitOpError("Chunk size, vector size and wi_data must match.");
- }
- // Work-item's slice (i.e., vector shape to load) is [1] or [1, chunk_size].
- tdescShape[tdescTy.getRank() - 1] = 1;
- }
-
- if (valueShape != tdescShape)
+ auto sgMap = tdescTy.getSGMapAttr();
+ // In VC mode, sg_map is not present. In this case, value shape must match
+ // the tensor descriptor shape.
+ if (!sgMap)
+ return valueShape == tdescShape
+ ? success()
+ : emitOpError("Unexpected result shape")
+ << "(Expected shape: " << makeString(tdescShape)
+ << ", Given shape: " << makeString(valueShape) << ").\n";
+ // In SIMT mode, sg_map, wi_data, and chunk_size determine the value shape.
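+ // (E.g., per the tests in this PR, tensor_desc<4xf32,
+ // #xegpu.scatter_tdesc_attr<>, #xegpu.sg_map<wi_layout = [1, 4],
+ // wi_data = [1, 1]>> loads as vector<1xf32>.)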
+ auto distributedVectorShapeOrFailure = tdescTy.getDistributedVectorType();
+ if (failed(distributedVectorShapeOrFailure))
+ return emitOpError("Failed to compute distributed vector shape for "
+ "tensor descriptor ")
+ << tdescTy;
+ if (cast<VectorType>(valueTy) != distributedVectorShapeOrFailure.value())
return emitOpError("Unexpected result shape")
- << "(Expected shape: " << makeString(tdescShape)
+ << "(Expected shape: "
+ << makeString(distributedVectorShapeOrFailure.value().getShape())
<< ", Given shape: " << makeString(valueShape) << ").\n";
return success();
@@ -565,22 +568,25 @@ LogicalResult StoreScatterOp::verify() {
transpose({1, 0}, tdescShape);
}
- if (auto sgMap = tdescTy.getSGMapAttr()) {
- auto valueVecTy = cast<VectorType>(valueTy);
- const int32_t wiData =
- sgMap.getWiData()[0] > 1 ? sgMap.getWiData()[0] : sgMap.getWiData()[1];
- // All represent the same concept: a number of row elements to store.
- if (valueVecTy.getNumElements() != wiData ||
- valueVecTy.getNumElements() != tdescTy.getChunkSize()) {
- return emitOpError("Chunk size, vector size and wi_data must match.");
- }
- // Work-item's slice (i.e., vector to store) is [1] or [1, chunk_size].
- tdescShape[tdescTy.getRank() - 1] = 1;
- }
-
- if (valueShape != tdescShape)
- return emitOpError("Unexpected value shape")
- << "(Expected shape: " << makeString(tdescShape)
+ auto sgMap = tdescTy.getSGMapAttr();
+ // In VC mode, sg_map is not present. In this case, value shape must match
+ // the tensor descriptor shape.
+ if (!sgMap)
+ return valueShape == tdescShape
+ ? success()
+ : emitOpError("Unexpected result shape")
+ << "(Expected shape: " << makeString(tdescShape)
+ << ", Given shape: " << makeString(valueShape) << ").\n";
+ // In SIMT mode, sg_map, wi_data, and chunk_size determine the value shape.
+ auto distributedVectorShapeOrFailure = tdescTy.getDistributedVectorType();
+ if (failed(distributedVectorShapeOrFailure))
+ return emitOpError("Failed to compute distributed vector shape for "
+ "tensor descriptor ")
+ << tdescTy;
+ if (cast<VectorType>(valueTy) != distributedVectorShapeOrFailure.value())
+ return emitOpError("Unexpected result shape")
+ << "(Expected shape: "
+ << makeString(distributedVectorShapeOrFailure.value().getShape())
<< ", Given shape: " << makeString(valueShape) << ").\n";
return success();
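The two verifiers above are now structurally identical, and the final patch in
this series factors the shared logic into a helper (its comment is visible in
the truncated hunk at the end of this message). A rough sketch of what such a
helper could look like (the name and exact signature here are hypothetical,
not the committed API):

// Validates a value shape against either the (possibly transposed) tensor
// descriptor shape (no sg_map: SIMD/VC mode) or the per-work-item
// distributed fragment (sg_map present: SIMT mode).
static LogicalResult
validateValueShape(TensorDescType tdescTy, VectorType valueTy,
                   ArrayRef<int64_t> adjustedTdescShape,
                   function_ref<InFlightDiagnostic()> emitError) {
  if (!tdescTy.getSGMapAttr())
    return valueTy.getShape() == adjustedTdescShape
               ? success()
               : emitError() << "unexpected value shape";
  FailureOr<VectorType> expectedTy = tdescTy.getDistributedVectorType();
  if (failed(expectedTy))
    return emitError() << "failed to compute distributed vector type for "
                       << tdescTy;
  return valueTy == expectedTy.value()
             ? success()
             : emitError() << "unexpected value shape";
}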
diff --git a/mlir/test/Dialect/XeGPU/ops.mlir b/mlir/test/Dialect/XeGPU/ops.mlir
index aa4869afa2d86..7b5232621b73b 100644
--- a/mlir/test/Dialect/XeGPU/ops.mlir
+++ b/mlir/test/Dialect/XeGPU/ops.mlir
@@ -293,6 +293,20 @@ gpu.func @test_store_nd_vc(%dst: memref<24x32xf16>) {
gpu.return
}
+// CHECK: func @test_store_nd_simt(%[[arg0:.*]]: memref<24x32xf16>) {
+gpu.func @test_store_nd_simt(%src: memref<24x32xf16>) {
+ // CHECK: %[[C:.*]] = arith.constant dense<1.000000e+00> : vector<48x1xf16>
+ %1 = arith.constant dense<1.0>: vector<48x1xf16>
+ // CHECK: %[[R0:.*]] = xegpu.create_nd_tdesc %arg0[0, 0] : memref<24x32xf16> -> !xegpu.tensor_desc<24x32xf16, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>>
+ %2 = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf16> ->
+ !xegpu.tensor_desc<24x32xf16, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>>
+ // CHECK: xegpu.store_nd %[[C]], %[[R0]] <{l1_hint = #xegpu.cache_hint<write_back>, l2_hint = #xegpu.cache_hint<uncached>}> : vector<48x1xf16>, !xegpu.tensor_desc<24x32xf16, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>>
+ xegpu.store_nd %1, %2 <{l1_hint = #xegpu.cache_hint<write_back>, l2_hint = #xegpu.cache_hint<uncached>}>: vector<48x1xf16>, !xegpu.tensor_desc<24x32xf16, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>>
+ gpu.return
+}
+
// CHECK: func @test_store_nd_vc_2(%[[arg0:.*]]: memref<24x32xf16>) {
gpu.func @test_store_nd_vc_2(%dst: memref<24x32xf16>) {
// CHECK: %[[C:.*]] = arith.constant dense<1.000000e+00> : vector<32xf16>
@@ -304,21 +318,9 @@ gpu.func @test_store_nd_vc_2(%dst: memref<24x32xf16>) {
gpu.return
}
-// store_nd args may have different shapes, validated against sg_map
-// CHECK: func @test_store_nd_vc_3(%[[arg0:.*]]: memref<24x32xf16>) {
-gpu.func @test_store_nd_vc_3(%src: memref<24x32xf16>) {
- // CHECK: %[[C:.*]] = arith.constant dense<1.000000e+00> : vector<48x1xf16>
- %1 = arith.constant dense<1.0>: vector<48x1xf16>
- // CHECK: %[[R0:.*]] = xegpu.create_nd_tdesc %arg0[0, 0] : memref<24x32xf16> -> !xegpu.tensor_desc<24x32xf16, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>>
- %2 = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf16> ->
- !xegpu.tensor_desc<24x32xf16, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>>
- // CHECK: xegpu.store_nd %[[C]], %[[R0]] <{l1_hint = #xegpu.cache_hint<write_back>, l2_hint = #xegpu.cache_hint<uncached>}> : vector<48x1xf16>, !xegpu.tensor_desc<24x32xf16, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>>
- xegpu.store_nd %1, %2 <{l1_hint = #xegpu.cache_hint<write_back>, l2_hint = #xegpu.cache_hint<uncached>}>: vector<48x1xf16>, !xegpu.tensor_desc<24x32xf16, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>>
- gpu.return
-}
-// CHECK: func @test_store_nd_vc_4(%[[arg0:.*]]: memref<24x32xf16>) {
-gpu.func @test_store_nd_vc_4(%src: memref<24x32xf16>) {
+// CHECK: func @test_store_nd_simt_2(%[[arg0:.*]]: memref<24x32xf16>) {
+gpu.func @test_store_nd_simt_2(%src: memref<24x32xf16>) {
// CHECK: %[[C:.*]] = arith.constant dense<1.000000e+00> : vector<2x1xf16>
%1 = arith.constant dense<1.0>: vector<2x1xf16>
// CHECK: %[[R0:.*]] = xegpu.create_nd_tdesc %arg0[0, 0] : memref<24x32xf16> -> !xegpu.tensor_desc<32xf16, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>>
@@ -329,8 +331,8 @@ gpu.func @test_store_nd_vc_4(%src: memref<24x32xf16>) {
gpu.return
}
-// CHECK: gpu.func @test_create_update_nd_tdesc_vc(%[[arg0:.*]]: memref<24x32xf32>) {
-gpu.func @test_create_update_nd_tdesc_vc(%src: memref<24x32xf32>) {
+// CHECK: gpu.func @test_update_nd_tdesc_vc(%[[arg0:.*]]: memref<24x32xf32>) {
+gpu.func @test_update_nd_tdesc_vc(%src: memref<24x32xf32>) {
// CHECK: %[[REG:.*]] = xegpu.create_nd_tdesc %arg0[0, 0] : memref<24x32xf32> -> !xegpu.tensor_desc<8x16xf32>
%1 = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf32> -> !xegpu.tensor_desc<8x16xf32>
// CHECK: %[[R1:.*]] = xegpu.update_nd_offset %[[REG]], [0, 16] : !xegpu.tensor_desc<8x16xf32>
@@ -338,6 +340,15 @@ gpu.func @test_create_update_nd_tdesc_vc(%src: memref<24x32xf32>) {
gpu.return
}
+// CHECK: gpu.func @test_update_nd_tdesc_simt(%[[arg0:.*]]: memref<24x32xf32>) {
+gpu.func @test_update_nd_tdesc_simt(%src: memref<24x32xf32>) {
+ // CHECK: %[[REG:.*]] = xegpu.create_nd_tdesc %arg0[0, 0] : memref<24x32xf32> -> !xegpu.tensor_desc<8x16xf32, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>>
+ %1 = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf32> -> !xegpu.tensor_desc<8x16xf32, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>>
+ // CHECK: %[[R1:.*]] = xegpu.update_nd_offset %[[REG]], [0, 16] : !xegpu.tensor_desc<8x16xf32, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>>
+ %2 = xegpu.update_nd_offset %1, [0, 16]: !xegpu.tensor_desc<8x16xf32, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>>
+ gpu.return
+}
+
// CHECK: gpu.func @test_create_tdesc_vc(%[[arg0:.*]]: ui64) {
gpu.func @test_create_tdesc_vc(%src: ui64) {
//CHECK: %[[cst:.*]] = arith.constant dense<[0, 8, 16, 24]> : vector<4xindex>
@@ -347,6 +358,16 @@ gpu.func @test_create_tdesc_vc(%src: ui64) {
gpu.return
}
+// CHECK: gpu.func @test_create_tdesc_simt(%[[arg0:.*]]: ui64) {
+gpu.func @test_create_tdesc_simt(%src: ui64) {
+ //CHECK: %[[cst:.*]] = arith.constant dense<[0, 8, 16, 24]> : vector<4xindex>
+ %0 = arith.constant dense<[0, 8, 16, 24]> : vector<4xindex>
+ //CHECK: %[[R0:.*]] = xegpu.create_tdesc %[[arg0]], %[[cst]] : ui64, vector<4xindex> -> !xegpu.tensor_desc<4x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2 : i64>, #xegpu.sg_map<wi_layout = [4, 1], wi_data = [1, 1]>>
+ %1 = xegpu.create_tdesc %src, %0 : ui64, vector<4xindex> -> !xegpu.tensor_desc<4x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2>, #xegpu.sg_map<wi_layout = [4, 1], wi_data = [1, 1]>>
+ gpu.return
+}
+
+
// CHECK: gpu.func @test_create_tdesc_vc_1(%[[arg0:.*]]: memref<?xf32, 3>) {
gpu.func @test_create_tdesc_vc_1(%src: memref<?xf32, 3>) {
//CHECK: %[[cst:.*]] = arith.constant dense<[0, 8, 16, 24]> : vector<4xindex>
@@ -356,6 +377,16 @@ gpu.func @test_create_tdesc_vc_1(%src: memref<?xf32, 3>) {
gpu.return
}
+// CHECK: gpu.func @test_create_tdesc_simt_1(%[[arg0:.*]]: memref<?xf32, 3>) {
+gpu.func @test_create_tdesc_simt_1(%src: memref<?xf32, 3>) {
+ //CHECK: %[[cst:.*]] = arith.constant dense<[0, 8, 16, 24]> : vector<4xindex>
+ %0 = arith.constant dense<[0, 8, 16, 24]> : vector<4xindex>
+ //CHECK: %[[R0:.*]] = xegpu.create_tdesc %[[arg0]], %[[cst]] : memref<?xf32, 3>, vector<4xindex> -> !xegpu.tensor_desc<4x2xf32, #xegpu.scatter_tdesc_attr<memory_space = slm, chunk_size = 2 : i64>, #xegpu.sg_map<wi_layout = [4, 1], wi_data = [1, 1]>>
+ %1 = xegpu.create_tdesc %src, %0 : memref<?xf32, 3>, vector<4xindex> -> !xegpu.tensor_desc<4x2xf32, #xegpu.scatter_tdesc_attr<memory_space = slm, chunk_size = 2>, #xegpu.sg_map<wi_layout = [4, 1], wi_data = [1, 1]>>
+ gpu.return
+}
+
+
// CHECK: gpu.func @test_create_tdesc_vc_2(%[[arg0:.*]]: memref<?xf32>) {
gpu.func @test_create_tdesc_vc_2(%src: memref<?xf32>) {
//CHECK: %[[cst:.*]] = arith.constant dense<[0, 8, 16, 24]> : vector<4xindex>
@@ -365,30 +396,75 @@ gpu.func @test_create_tdesc_vc_2(%src: memref<?xf32>) {
gpu.return
}
-// CHECK: gpu.func @test_create_tdesc_vc_with_sg_map(%[[arg0:.*]]: ui64) {
-gpu.func @test_create_tdesc_vc_with_sg_map(%src: ui64) {
+// CHECK: gpu.func @test_create_tdesc_simt_2(%[[arg0:.*]]: memref<?xf32>) {
+gpu.func @test_create_tdesc_simt_2(%src: memref<?xf32>) {
+ //CHECK: %[[cst:.*]] = arith.constant dense<[0, 8, 16, 24]> : vector<4xindex>
+ %0 = arith.constant dense<[0, 8, 16, 24]> : vector<4xindex>
+ //CHECK: %[[R0:.*]] = xegpu.create_tdesc %[[arg0]], %[[cst]] : memref<?xf32>, vector<4xindex> -> !xegpu.tensor_desc<4xf32, #xegpu.scatter_tdesc_attr<>
+ %1 = xegpu.create_tdesc %src, %0 : memref<?xf32>, vector<4xindex> -> !xegpu.tensor_desc<4xf32, #xegpu.scatter_tdesc_attr<chunk_size = 1>, #xegpu.sg_map<wi_layout = [1, 4], wi_data = [1, 1]>>
+ gpu.return
+}
+
+// CHECK: gpu.func @test_create_tdesc_vc_3(%[[arg0:.*]]: ui64) {
+gpu.func @test_create_tdesc_vc_3(%src: ui64) {
+ //CHECK: %[[cst:.*]] = arith.constant dense<[0, 8, 16, 24]> : vector<4xindex>
+ %0 = arith.constant dense<[0, 8, 16, 24]> : vector<4xindex>
+ //CHECK: %[[R0:.*]] = xegpu.create_tdesc %[[arg0]], %[[cst]] : ui64, vector<4xindex> -> !xegpu.tensor_desc<4x2xf16, #xegpu.scatter_tdesc_attr<chunk_size = 2 : i64>>
+ %1 = xegpu.create_tdesc %src, %0 : ui64, vector<4xindex> -> !xegpu.tensor_desc<4x2xf16, #xegpu.scatter_tdesc_attr<chunk_size = 2>>
+ gpu.return
+}
+
+
+// CHECK: gpu.func @test_create_tdesc_simt_3(%[[arg0:.*]]: ui64) {
+gpu.func @test_create_tdesc_simt_3(%src: ui64) {
+ //CHECK: %[[cst:.*]] = arith.constant dense<[0, 8, 16, 24]> : vector<4xindex>
+ %0 = arith.constant dense<[0, 8, 16, 24]> : vector<4xindex>
+ //CHECK: %[[R0:.*]] = xegpu.create_tdesc %[[arg0]], %[[cst]] : ui64, vector<4xindex> -> !xegpu.tensor_desc<4x2xf16, #xegpu.scatter_tdesc_attr<chunk_size = 2 : i64>, #xegpu.sg_map<wi_layout = [4, 1], wi_data = [1, 2]>>
+ %1 = xegpu.create_tdesc %src, %0 : ui64, vector<4xindex> -> !xegpu.tensor_desc<4x2xf16, #xegpu.scatter_tdesc_attr<chunk_size = 2>, #xegpu.sg_map<wi_layout = [4, 1], wi_data = [1, 2]>>
+ gpu.return
+}
+
+// CHECK: gpu.func @test_load_vc(%[[arg0:.*]]: ui64) {
+gpu.func @test_load_vc(%src: ui64) {
+ //CHECK: %[[cst:.*]] = arith.constant dense<[0, 8, 16, 24]> : vector<4xindex>
+ %0 = arith.constant dense<[0, 8, 16, 24]> : vector<4xindex>
+ //CHECK: %[[cst1:.*]] = arith.constant dense<true> : vector<4xi1>
+ %1 = arith.constant dense<1>: vector<4xi1>
+ //CHECK: %[[R0:.*]] = xegpu.create_tdesc %[[arg0]], %[[cst]] : ui64, vector<4xindex> -> !xegpu.tensor_desc<4x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2 : i64>>
+ %2 = xegpu.create_tdesc %src, %0 : ui64, vector<4xindex> -> !xegpu.tensor_desc<4x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2>>
+ //CHECK: %[[R1:.*]] = xegpu.load %[[R0]], %[[cst1]] <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>, transpose}> : !xegpu.tensor_desc<4x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2 : i64>>, vector<4xi1> -> vector<2x4xf32>
+ %3 = xegpu.load %2, %1 <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>, transpose}> : !xegpu.tensor_desc<4x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2>>, vector<4xi1> -> vector<2x4xf32>
+ gpu.return
+}
+
+// CHECK: gpu.func @test_load_simt(%[[arg0:.*]]: ui64) {
+gpu.func @test_load_simt(%src: ui64) {
//CHECK: %[[cst:.*]] = arith.constant dense<[0, 8, 16, 24]> : vector<4xindex>
%0 = arith.constant dense<[0, 8, 16, 24]> : vector<4xindex>
- //CHECK: %[[R0:.*]] = xegpu.create_tdesc %[[arg0]], %[[cst]] : ui64, vector<4xindex> -> !xegpu.tensor_desc<4x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2 : i64>, #xegpu.sg_map<wi_layout = [4, 1], wi_data = [1, 2]>>
- %1 = xegpu.create_tdesc %src, %0 : ui64, vector<4xindex> -> !xegpu.tensor_desc<4x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2>, #xegpu.sg_map<wi_layout = [4, 1], wi_data = [1, 2]>>
+ //CHECK: %[[cst1:.*]] = arith.constant dense<true> : vector<4xi1>
+ %1 = arith.constant dense<1>: vector<4xi1>
+ //CHECK: %[[R0:.*]] = xegpu.create_tdesc %[[arg0]], %[[cst]] : ui64, vector<4xindex> -> !xegpu.tensor_desc<4x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2 : i64>, #xegpu.sg_map<wi_layout = [4, 1], wi_data = [1, 1]>>
+ %2 = xegpu.create_tdesc %src, %0 : ui64, vector<4xindex> -> !xegpu.tensor_desc<4x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2>, #xegpu.sg_map<wi_layout = [4, 1], wi_data = [1, 1]>>
+ //CHECK: %[[R1:.*]] = xegpu.load %[[R0]], %[[cst1]] <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>, transpose}> : !xegpu.tensor_desc<4x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2 : i64>, #xegpu.sg_map<wi_layout = [4, 1], wi_data = [1, 1]>>, vector<4xi1> -> vector<2x1xf32>
+ %3 = xegpu.load %2, %1 <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>, transpose}> : !xegpu.tensor_desc<4x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2>, #xegpu.sg_map<wi_layout = [4, 1], wi_data = [1, 1]>>, vector<4xi1> -> vector<2x1xf32>
gpu.return
}
-// CHECK: gpu.func @test_load_with_sg_map(%[[arg0:.*]]: ui64) {
-gpu.func @test_load_with_sg_map(%src: ui64) {
+// CHECK: gpu.func @test_load_vc_2(%[[arg0:.*]]: ui64) {
+gpu.func @test_load_vc_2(%src: ui64) {
//CHECK: %[[cst:.*]] = arith.constant dense<[0, 8, 16, 24]> : vector<4xindex>
%0 = arith.constant dense<[0, 8, 16, 24]> : vector<4xindex>
//CHECK: %[[cst1:.*]] = arith.constant dense<true> : vector<4xi1>
%1 = arith.constant dense<1>: vector<4xi1>
- //CHECK: %[[R0:.*]] = xegpu.create_tdesc %[[arg0]], %[[cst]] : ui64, vector<4xindex> -> !xegpu.tensor_desc<4x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2 : i64>, #xegpu.sg_map<wi_layout = [4, 1], wi_data = [1, 2]>>
- %2 = xegpu.create_tdesc %src, %0 : ui64, vector<4xindex> -> !xegpu.tensor_desc<4x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2>, #xegpu.sg_map<wi_layout = [4, 1], wi_data = [1, 2]>>
- //CHECK: %[[R1:.*]] = xegpu.load %[[R0]], %[[cst1]] <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>, transpose}> : !xegpu.tensor_desc<4x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2 : i64>, #xegpu.sg_map<wi_layout = [4, 1], wi_data = [1, 2]>>, vector<4xi1> -> vector<2x1xf32>
- %3 = xegpu.load %2, %1 <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>, transpose}> : !xegpu.tensor_desc<4x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2>, #xegpu.sg_map<wi_layout = [4, 1], wi_data = [1, 2]>>, vector<4xi1> -> vector<2x1xf32>
+ //CHECK: %[[R0:.*]] = xegpu.create_tdesc %[[arg0]], %[[cst]] : ui64, vector<4xindex> -> !xegpu.tensor_desc<4xf32, #xegpu.scatter_tdesc_attr<>>
+ %2 = xegpu.create_tdesc %src, %0 : ui64, vector<4xindex> -> !xegpu.tensor_desc<4xf32, #xegpu.scatter_tdesc_attr<>>
+ //CHECK: %[[R1:.*]] = xegpu.load %[[R0]], %[[cst1]] <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}> : !xegpu.tensor_desc<4xf32, #xegpu.scatter_tdesc_attr<>>, vector<4xi1> -> vector<4xf32>
+ %3 = xegpu.load %2, %1 <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}> : !xegpu.tensor_desc<4xf32, #xegpu.scatter_tdesc_attr<>>, vector<4xi1> -> vector<4xf32>
gpu.return
}
-// CHECK: gpu.func @test_load_with_sg_map_2(%[[arg0:.*]]: ui64) {
-gpu.func @test_load_with_sg_map_2(%src: ui64) {
+// CHECK: gpu.func @test_load_simt_2(%[[arg0:.*]]: ui64) {
+gpu.func @test_load_simt_2(%src: ui64) {
//CHECK: %[[cst:.*]] = arith.constant dense<[0, 8, 16, 24]> : vector<4xindex>
%0 = arith.constant dense<[0, 8, 16, 24]> : vector<4xindex>
//CHECK: %[[cst1:.*]] = arith.constant dense<true> : vector<4xi1>
@@ -400,6 +476,33 @@ gpu.func @test_load_with_sg_map_2(%src: ui64) {
gpu.return
}
+// CHECK: gpu.func @test_load_vc_3(%[[arg0:.*]]: ui64) {
+gpu.func @test_load_vc_3(%src: ui64) {
+ //CHECK: %[[cst:.*]] = arith.constant dense<[0, 8, 16, 24]> : vector<4xindex>
+ %0 = arith.constant dense<[0, 8, 16, 24]> : vector<4xindex>
+ //CHECK: %[[cst1:.*]] = arith.constant dense<true> : vector<4xi1>
+ %1 = arith.constant dense<1>: vector<4xi1>
+ //CHECK: %[[R0:.*]] = xegpu.create_tdesc %[[arg0]], %[[cst]] : ui64, vector<4xindex> -> !xegpu.tensor_desc<4x8xf16, #xegpu.scatter_tdesc_attr<chunk_size = 8 : i64>>
+ %2 = xegpu.create_tdesc %src, %0 : ui64, vector<4xindex> -> !xegpu.tensor_desc<4x8xf16, #xegpu.scatter_tdesc_attr<chunk_size = 8>>
+ //CHECK: %[[R1:.*]] = xegpu.load %[[R0]], %[[cst1]] <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>, transpose}> : !xegpu.tensor_desc<4x8xf16, #xegpu.scatter_tdesc_attr<chunk_size = 8 : i64>>, vector<4xi1> -> vector<8x4xf16>
+ %3 = xegpu.load %2, %1 <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>, transpose}> : !xegpu.tensor_desc<4x8xf16, #xegpu.scatter_tdesc_attr<chunk_size = 8>>, vector<4xi1> -> vector<8x4xf16>
+ gpu.return
+}
+
+// CHECK: gpu.func @test_load_simt_3(%[[arg0:.*]]: ui64) {
+gpu.func @test_load_simt_3(%src: ui64) {
+ //CHECK: %[[cst:.*]] = arith.constant dense<[0, 8, 16, 24]> : vector<4xindex>
+ %0 = arith.constant dense<[0, 8, 16, 24]> : vector<4xindex>
+ //CHECK: %[[cst1:.*]] = arith.constant dense<true> : vector<4xi1>
+ %1 = arith.constant dense<1>: vector<4xi1>
+ //CHECK: %[[R0:.*]] = xegpu.create_tdesc %[[arg0]], %[[cst]] : ui64, vector<4xindex> -> !xegpu.tensor_desc<4x8xf16, #xegpu.scatter_tdesc_attr<chunk_size = 8 : i64>, #xegpu.sg_map<wi_layout = [4, 1], wi_data = [1, 2]>>
+ %2 = xegpu.create_tdesc %src, %0 : ui64, vector<4xindex> -> !xegpu.tensor_desc<4x8xf16, #xegpu.scatter_tdesc_attr<chunk_size = 8>, #xegpu.sg_map<wi_layout = [4, 1], wi_data = [1, 2]>>
+ //CHECK: %[[R1:.*]] = xegpu.load %[[R0]], %[[cst1]] <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>, transpose}> : !xegpu.tensor_desc<4x8xf16, #xegpu.scatter_tdesc_attr<chunk_size = 8 : i64>, #xegpu.sg_map<wi_layout = [4, 1], wi_data = [1, 2]>>, vector<4xi1> -> vector<4x2xf16>
+ %3 = xegpu.load %2, %1 <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>, transpose}> : !xegpu.tensor_desc<4x8xf16, #xegpu.scatter_tdesc_attr<chunk_size = 8>, #xegpu.sg_map<wi_layout = [4, 1], wi_data = [1, 2]>>, vector<4xi1> -> vector<4x2xf16>
+ gpu.return
+}
+
+
// CHECK: gpu.func @test_store_with_sg_map(%[[arg0:.*]]: ui64) {
gpu.func @test_store_with_sg_map(%src: ui64) {
//CHECK: %[[cst:.*]] = arith.constant dense<[0, 8, 16, 24]> : vector<4xindex>
@@ -408,10 +511,10 @@ gpu.func @test_store_with_sg_map(%src: ui64) {
%1 = arith.constant dense<1>: vector<4xi1>
//CHECK: %[[cst2:.*]] = arith.constant dense<2.900000e+00> : vector<2x1xf32>
%2 = arith.constant dense<2.9>: vector<2x1xf32>
- //CHECK: %[[R0:.*]] = xegpu.create_tdesc %[[arg0]], %[[cst]] : ui64, vector<4xindex> -> !xegpu.tensor_desc<4x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2 : i64>, #xegpu.sg_map<wi_layout = [4, 1], wi_data = [1, 2]>>
- %3 = xegpu.create_tdesc %src, %0 : ui64, vector<4xindex> -> !xegpu.tensor_desc<4x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2>, #xegpu.sg_map<wi_layout = [4, 1], wi_data = [1, 2]>>
- //CHECK: xegpu.store %[[cst2]], %[[R0]], %[[cst1]] <{l1_hint = #xegpu.cache_hint<write_back>, l2_hint = #xegpu.cache_hint<uncached>, transpose}> : vector<2x1xf32>, !xegpu.tensor_desc<4x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2 : i64>, #xegpu.sg_map<wi_layout = [4, 1], wi_data = [1, 2]>>, vector<4xi1>
- xegpu.store %2, %3, %1 <{l1_hint = #xegpu.cache_hint<write_back>, l2_hint = #xegpu.cache_hint<uncached>, transpose}> : vector<2x1xf32>, !xegpu.tensor_desc<4x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2>, #xegpu.sg_map<wi_layout = [4, 1], wi_data = [1, 2]>>, vector<4xi1>
+ //CHECK: %[[R0:.*]] = xegpu.create_tdesc %[[arg0]], %[[cst]] : ui64, vector<4xindex> -> !xegpu.tensor_desc<4x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2 : i64>, #xegpu.sg_map<wi_layout = [4, 1], wi_data = [1, 1]>>
+ %3 = xegpu.create_tdesc %src, %0 : ui64, vector<4xindex> -> !xegpu.tensor_desc<4x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2>, #xegpu.sg_map<wi_layout = [4, 1], wi_data = [1, 1]>>
+ //CHECK: xegpu.store %[[cst2]], %[[R0]], %[[cst1]] <{l1_hint = #xegpu.cache_hint<write_back>, l2_hint = #xegpu.cache_hint<uncached>, transpose}> : vector<2x1xf32>, !xegpu.tensor_desc<4x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2 : i64>, #xegpu.sg_map<wi_layout = [4, 1], wi_data = [1, 1]>>, vector<4xi1>
+ xegpu.store %2, %3, %1 <{l1_hint = #xegpu.cache_hint<write_back>, l2_hint = #xegpu.cache_hint<uncached>, transpose}> : vector<2x1xf32>, !xegpu.tensor_desc<4x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2>, #xegpu.sg_map<wi_layout = [4, 1], wi_data = [1, 1]>>, vector<4xi1>
gpu.return
}
>From ddc8cba496eef9d790ab2b3d6d9fb710954ee98c Mon Sep 17 00:00:00 2001
From: Charitha Saumya <charitha.saumya.gusthinna.waduge at intel.com>
Date: Wed, 19 Feb 2025 22:16:20 +0000
Subject: [PATCH 05/11] save work
---
mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp | 4 +-
mlir/test/Dialect/XeGPU/ops.mlir | 72 ++++++++++++++++++++++++--
2 files changed, 70 insertions(+), 6 deletions(-)
diff --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
index 1453c3c80a838..2e50f4677298b 100644
--- a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
+++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
@@ -574,7 +574,7 @@ LogicalResult StoreScatterOp::verify() {
if (!sgMap)
return valueShape == tdescShape
? success()
- : emitOpError("Unexpected result shape")
+ : emitOpError("Unexpected value shape")
<< "(Expected shape: " << makeString(tdescShape)
<< ", Given shape: " << makeString(valueShape) << ").\n";
// In SIMT mode, sg_map, wi_data, and chunk_size determine the value shape.
@@ -584,7 +584,7 @@ LogicalResult StoreScatterOp::verify() {
"tensor descriptor ")
<< tdescTy;
if (cast<VectorType>(valueTy) != distributedVectorShapeOrFailure.value())
- return emitOpError("Unexpected result shape")
+ return emitOpError("Unexpected value shape")
<< "(Expected shape: "
<< makeString(distributedVectorShapeOrFailure.value().getShape())
<< ", Given shape: " << makeString(valueShape) << ").\n";
diff --git a/mlir/test/Dialect/XeGPU/ops.mlir b/mlir/test/Dialect/XeGPU/ops.mlir
index 7b5232621b73b..54210277237ef 100644
--- a/mlir/test/Dialect/XeGPU/ops.mlir
+++ b/mlir/test/Dialect/XeGPU/ops.mlir
@@ -502,9 +502,25 @@ gpu.func @test_load_simt_3(%src: ui64) {
gpu.return
}
+// CHECK: gpu.func @test_store_vc(%[[arg0:.*]]: ui64) {
+gpu.func @test_store_vc(%src: ui64) {
+ //CHECK: %[[cst:.*]] = arith.constant dense<[0, 8, 16, 24]> : vector<4xindex>
+ %0 = arith.constant dense<[0, 8, 16, 24]> : vector<4xindex>
+ //CHECK: %[[cst1:.*]] = arith.constant dense<true> : vector<4xi1>
+ %1 = arith.constant dense<1>: vector<4xi1>
+ //CHECK: %[[cst2:.*]] = arith.constant dense<2.900000e+00> : vector<2x4xf32>
+ %2 = arith.constant dense<2.9>: vector<2x4xf32>
+ //CHECK: %[[R0:.*]] = xegpu.create_tdesc %[[arg0]], %[[cst]] : ui64, vector<4xindex> -> !xegpu.tensor_desc<4x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2 : i64>>
+ %3 = xegpu.create_tdesc %src, %0 : ui64, vector<4xindex> -> !xegpu.tensor_desc<4x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2>>
+ //CHECK: xegpu.store %[[cst2]], %[[R0]], %[[cst1]] <{l1_hint = #xegpu.cache_hint<write_back>, l2_hint = #xegpu.cache_hint<uncached>, transpose}> : vector<2x4xf32>, !xegpu.tensor_desc<4x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2 : i64>>, vector<4xi1>
+ xegpu.store %2, %3, %1 <{l1_hint = #xegpu.cache_hint<write_back>, l2_hint = #xegpu.cache_hint<uncached>, transpose}> : vector<2x4xf32>, !xegpu.tensor_desc<4x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2>>, vector<4xi1>
+ gpu.return
+}
+
+
-// CHECK: gpu.func @test_store_with_sg_map(%[[arg0:.*]]: ui64) {
-gpu.func @test_store_with_sg_map(%src: ui64) {
+// CHECK: gpu.func @test_store_simt(%[[arg0:.*]]: ui64) {
+gpu.func @test_store_simt(%src: ui64) {
//CHECK: %[[cst:.*]] = arith.constant dense<[0, 8, 16, 24]> : vector<4xindex>
%0 = arith.constant dense<[0, 8, 16, 24]> : vector<4xindex>
//CHECK: %[[cst1:.*]] = arith.constant dense<true> : vector<4xi1>
@@ -518,8 +534,56 @@ gpu.func @test_store_with_sg_map(%src: ui64) {
gpu.return
}
-// CHECK: gpu.func @test_store_with_sg_map_2(%[[arg0:.*]]: ui64) {
-gpu.func @test_store_with_sg_map_2(%src: ui64) {
+// CHECK: gpu.func @test_store_vc_2(%[[arg0:.*]]: ui64) {
+gpu.func @test_store_vc_2(%src: ui64) {
+ //CHECK: %[[cst:.*]] = arith.constant dense<[0, 8, 16, 24]> : vector<4xindex>
+ %0 = arith.constant dense<[0, 8, 16, 24]> : vector<4xindex>
+ //CHECK: %[[cst1:.*]] = arith.constant dense<true> : vector<4xi1>
+ %1 = arith.constant dense<1>: vector<4xi1>
+ //CHECK: %[[cst2:.*]] = arith.constant {{.*}} : vector<2x4xf16>
+ %2 = arith.constant dense<2.9>: vector<2x4xf16>
+ //CHECK: %[[R0:.*]] = xegpu.create_tdesc %[[arg0]], %[[cst]] : ui64, vector<4xindex> -> !xegpu.tensor_desc<4x2xf16, #xegpu.scatter_tdesc_attr<chunk_size = 2 : i64>>
+ %3 = xegpu.create_tdesc %src, %0 : ui64, vector<4xindex> -> !xegpu.tensor_desc<4x2xf16, #xegpu.scatter_tdesc_attr<chunk_size = 2>>
+ //CHECK: xegpu.store %[[cst2]], %[[R0]], %[[cst1]] <{l1_hint = #xegpu.cache_hint<write_back>, l2_hint = #xegpu.cache_hint<uncached>, transpose}> : vector<2x4xf16>, !xegpu.tensor_desc<4x2xf16, #xegpu.scatter_tdesc_attr<chunk_size = 2 : i64>>, vector<4xi1>
+ xegpu.store %2, %3, %1 <{l1_hint = #xegpu.cache_hint<write_back>, l2_hint = #xegpu.cache_hint<uncached>, transpose}> : vector<2x4xf16>, !xegpu.tensor_desc<4x2xf16, #xegpu.scatter_tdesc_attr<chunk_size = 2>>, vector<4xi1>
+ gpu.return
+}
+
+// CHECK: gpu.func @test_store_simt_2(%[[arg0:.*]]: ui64) {
+gpu.func @test_store_simt_2(%src: ui64) {
+ //CHECK: %[[cst:.*]] = arith.constant dense<[0, 8, 16, 24]> : vector<4xindex>
+ %0 = arith.constant dense<[0, 8, 16, 24]> : vector<4xindex>
+ //CHECK: %[[cst1:.*]] = arith.constant dense<true> : vector<4xi1>
+ %1 = arith.constant dense<1>: vector<4xi1>
+ //CHECK: %[[cst2:.*]] = arith.constant {{.*}} : vector<1x2xf16>
+ %2 = arith.constant dense<2.9>: vector<1x2xf16>
+ //CHECK: %[[R0:.*]] = xegpu.create_tdesc %[[arg0]], %[[cst]] : ui64, vector<4xindex> -> !xegpu.tensor_desc<4x2xf16, #xegpu.scatter_tdesc_attr<chunk_size = 2 : i64>, #xegpu.sg_map<wi_layout = [4, 1], wi_data = [1, 2]>>
+ %3 = xegpu.create_tdesc %src, %0 : ui64, vector<4xindex> -> !xegpu.tensor_desc<4x2xf16, #xegpu.scatter_tdesc_attr<chunk_size = 2>, #xegpu.sg_map<wi_layout = [4, 1], wi_data = [1, 2]>>
+ //CHECK: xegpu.store %[[cst2]], %[[R0]], %[[cst1]] <{l1_hint = #xegpu.cache_hint<write_back>, l2_hint = #xegpu.cache_hint<uncached>, transpose}> : vector<1x2xf16>, !xegpu.tensor_desc<4x2xf16, #xegpu.scatter_tdesc_attr<chunk_size = 2 : i64>, #xegpu.sg_map<wi_layout = [4, 1], wi_data = [1, 2]>>, vector<4xi1>
+ xegpu.store %2, %3, %1 <{l1_hint = #xegpu.cache_hint<write_back>, l2_hint = #xegpu.cache_hint<uncached>, transpose}> : vector<1x2xf16>, !xegpu.tensor_desc<4x2xf16, #xegpu.scatter_tdesc_attr<chunk_size = 2>, #xegpu.sg_map<wi_layout = [4, 1], wi_data = [1, 2]>>, vector<4xi1>
+ gpu.return
+}
+
+// CHECK: gpu.func @test_store_vc_3(%[[arg0:.*]]: ui64) {
+gpu.func @test_store_vc_3(%src: ui64) {
+ //CHECK: %[[cst:.*]] = arith.constant dense<[0, 8, 16, 24]> : vector<4xindex>
+ %0 = arith.constant dense<[0, 8, 16, 24]> : vector<4xindex>
+ //CHECK: %[[cst1:.*]] = arith.constant dense<true> : vector<4xi1>
+ %1 = arith.constant dense<1>: vector<4xi1>
+ //CHECK: %[[cst2:.*]] = arith.constant {{.*}} : vector<4xf32>
+ %2 = arith.constant dense<2.9>: vector<4xf32>
+ //CHECK: %[[R0:.*]] = xegpu.create_tdesc %[[arg0]], %[[cst]] : ui64, vector<4xindex> -> !xegpu.tensor_desc<4xf32, #xegpu.scatter_tdesc_attr<>>
+ %3 = xegpu.create_tdesc %src, %0 : ui64, vector<4xindex> -> !xegpu.tensor_desc<4xf32, #xegpu.scatter_tdesc_attr<>>
+ //CHECK: xegpu.store %[[cst2]], %[[R0]], %[[cst1]] <{l1_hint = #xegpu.cache_hint<write_back>, l2_hint = #xegpu.cache_hint<uncached>}> : vector<4xf32>, !xegpu.tensor_desc<4xf32, #xegpu.scatter_tdesc_attr<>>, vector<4xi1>
+ xegpu.store %2, %3, %1 <{l1_hint = #xegpu.cache_hint<write_back>, l2_hint = #xegpu.cache_hint<uncached>}> : vector<4xf32>, !xegpu.tensor_desc<4xf32, #xegpu.scatter_tdesc_attr<>>, vector<4xi1>
+ gpu.return
+}
+
+
+// CHECK: gpu.func @test_store_simt_3(%[[arg0:.*]]: ui64) {
+gpu.func @test_store_simt_3(%src: ui64) {
//CHECK: %[[cst:.*]] = arith.constant dense<[0, 8, 16, 24]> : vector<4xindex>
%0 = arith.constant dense<[0, 8, 16, 24]> : vector<4xindex>
//CHECK: %[[cst1:.*]] = arith.constant dense<true> : vector<4xi1>
>From 2d79647ec06c482bb7de4701fa7fbcbea6df885f Mon Sep 17 00:00:00 2001
From: Charitha Saumya <charitha.saumya.gusthinna.waduge at intel.com>
Date: Wed, 19 Feb 2025 23:27:13 +0000
Subject: [PATCH 06/11] save work
---
mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp | 57 ++++++++++++++++++++----
mlir/test/Dialect/XeGPU/ops.mlir | 60 +++++++++++++-------------
2 files changed, 80 insertions(+), 37 deletions(-)
diff --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
index 2e50f4677298b..ad8b4bf3427cc 100644
--- a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
+++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
@@ -620,20 +620,61 @@ void UpdateOffsetOp::build(OpBuilder &builder, OperationState &state,
LogicalResult DpasOp::verify() {
int64_t lhsRank = getLhsType().getRank();
int64_t rhsRank = getRhsType().getRank();
-
- if (lhsRank != 2 || (rhsRank != 2 && rhsRank != 3))
- return emitOpError("expecting lhs to be a 2D vector, and rhs to be either "
- "2D or 3D (packed) vector.");
-
+ int64_t resultRank = getResultType().getRank();
auto lhsShape = getLhsType().getShape();
auto rhsShape = getRhsType().getShape();
- auto bK = rhsRank == 3 ? rhsShape[0] * rhsShape[2] : rhsShape[0];
- if (bK != lhsShape[1])
+ auto resultShape = getResultType().getShape();
+
+ auto sgMapA = (*this)->getAttrOfType<xegpu::SGMapAttr>("sg_map_a");
+ auto sgMapB = (*this)->getAttrOfType<xegpu::SGMapAttr>("sg_map_b");
+ auto sgMapC = (*this)->getAttrOfType<xegpu::SGMapAttr>("sg_map_c");
+
+ // If sg_maps are not present, then the operation is in VC mode.
+ if (!sgMapA && !sgMapB && !sgMapC) {
+ if (lhsRank != 2 || (rhsRank != 2 && rhsRank != 3) || resultRank != 2)
+ return emitOpError(
+ "expecting lhs and result to be a 2D vector, and rhs to be either "
+ "2D or 3D (packed) vector.");
+ auto bK = rhsRank == 3 ? rhsShape[0] * rhsShape[2] : rhsShape[0];
+ if (bK != lhsShape[1])
+ return emitOpError("K-dimension mismatch.");
+ if (lhsShape[0] != resultShape[0])
+ return emitOpError("M-dimension mismatch.");
+ if (rhsShape[1] != resultShape[1])
+ return emitOpError("N-dimension mismatch.");
+ return success();
+ }
+ // Otherwise, in SIMT mode we expect sg_map attributes for all operands and
+ // the result of the DPAS operation.
+ if (!sgMapA || !sgMapB || !sgMapC)
+ return emitOpError("sg_map attributes for all operands and outputs are "
+ "expected in SIMT xegpu::Dpas operation");
+
+ // In SIMT mode, all data fragments must be 2D.
+ if (lhsRank != 2 || rhsRank != 2 || resultRank != 2)
+ return emitOpError("expecting lhs, rhs, and result to be a 2D vector.");
+
+ auto wiLayoutA = sgMapA.getWiLayout();
+ auto wiLayoutB = sgMapB.getWiLayout();
+ auto wiLayoutC = sgMapC.getWiLayout();
+ // Obtain the expanded shapes of the operands and result using wi_layout.
+ // NOTE: For B, the packed dimension is folded back into the K dimension.
+ SmallVector<int64_t> expandedShapeA = {lhsShape[0] * wiLayoutA[0],
+ lhsShape[1] * wiLayoutA[1]};
+ SmallVector<int64_t> expandedShapeB = {
+ rhsShape[0] * rhsShape[1] * wiLayoutB[0], 1 * wiLayoutB[1]};
+ SmallVector<int64_t> expandedShapeC = {resultShape[0] * wiLayoutC[0],
+ resultShape[1] * wiLayoutC[1]};
+ auto bK = expandedShapeB[0];
+ if (bK != expandedShapeA[1])
return emitOpError("K-dimension mismatch.");
+ if (expandedShapeA[0] != expandedShapeC[0])
+ return emitOpError("M-dimension mismatch.");
+ if (expandedShapeB[1] != expandedShapeC[1])
+ return emitOpError("N-dimension mismatch.");
return success();
}
-
} // namespace xegpu
} // namespace mlir
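To make the expanded-shape check concrete with the fragment shapes used by the
test_dpas_simt test added below (wi_layout = [1, 16] for all operands), here
is a small standalone sketch of the same arithmetic:

#include <cassert>
#include <cstdint>

int main() {
  // A: vector<8x1xf16>, B: vector<8x2xf16> (packed), C: vector<8x1xf32>.
  int64_t lhs[2] = {8, 1}, rhs[2] = {8, 2}, res[2] = {8, 1};
  int64_t layout[2] = {1, 16}; // same wi_layout for A, B, and C here
  int64_t expandedA[2] = {lhs[0] * layout[0], lhs[1] * layout[1]}; // {8, 16}
  // For B the packed dimension is folded back into K.
  int64_t expandedB[2] = {rhs[0] * rhs[1] * layout[0], layout[1]}; // {16, 16}
  int64_t expandedC[2] = {res[0] * layout[0], res[1] * layout[1]}; // {8, 16}
  assert(expandedB[0] == expandedA[1]); // K = 16
  assert(expandedA[0] == expandedC[0]); // M = 8
  assert(expandedB[1] == expandedC[1]); // N = 16
  return 0;
}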
diff --git a/mlir/test/Dialect/XeGPU/ops.mlir b/mlir/test/Dialect/XeGPU/ops.mlir
index 54210277237ef..075830c4b6b7d 100644
--- a/mlir/test/Dialect/XeGPU/ops.mlir
+++ b/mlir/test/Dialect/XeGPU/ops.mlir
@@ -597,6 +597,16 @@ gpu.func @test_store_simt_3(%src: ui64) {
gpu.return
}
+// CHECK: gpu.func @test_prefetch_simt(%[[arg0:.*]]: ui64) {
+gpu.func @test_prefetch_simt(%src: ui64) {
+ //CHECK: %[[cst:.*]] = arith.constant dense<[0, 8, 16, 24]> : vector<4xindex>
+ %0 = arith.constant dense<[0, 8, 16, 24]> : vector<4xindex>
+ //CHECK: %[[R0:.*]] = xegpu.create_tdesc %[[arg0]], %[[cst]] : ui64, vector<4xindex> -> !xegpu.tensor_desc<4x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2 : i64>, #xegpu.sg_map<wi_layout = [4, 1], wi_data = [1, 1]>>
+ %1 = xegpu.create_tdesc %src, %0 : ui64, vector<4xindex> -> !xegpu.tensor_desc<4x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2>, #xegpu.sg_map<wi_layout = [4, 1], wi_data = [1, 1]>>
+ // CHECK: xegpu.prefetch %[[R0]] <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}> : !xegpu.tensor_desc<4x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2 : i64>, #xegpu.sg_map<wi_layout = [4, 1], wi_data = [1, 1]>>
+ xegpu.prefetch %1 <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}>: !xegpu.tensor_desc<4x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2>, #xegpu.sg_map<wi_layout = [4, 1], wi_data = [1, 1]>>
+ gpu.return
+}
// CHECK: gpu.func @test_prefetch_vc(%[[arg0:.*]]: ui64) {
@@ -610,35 +620,16 @@ gpu.func @test_prefetch_vc(%src: ui64) {
gpu.return
}
-// CHECK: gpu.func @test_load_gather_vc(%[[arg0:.*]]: ui64) {
-gpu.func @test_load_gather_vc(%src: ui64) {
- //CHECK: %[[cst:.*]] = arith.constant dense<true> : vector<4xi1>
- %0 = arith.constant dense<1>: vector<4xi1>
- //CHECK: %[[c2:.*]] = arith.constant dense<[0, 8, 16, 24]> : vector<4xindex>
- %c = arith.constant dense<[0, 8, 16, 24]> : vector<4xindex>
- //CHECK: %[[R0:.*]] = xegpu.create_tdesc %[[arg0]], %[[c2]] : ui64, vector<4xindex> -> !xegpu.tensor_desc<4x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2 : i64>>
- %1 = xegpu.create_tdesc %src, %c : ui64, vector<4xindex> -> !xegpu.tensor_desc<4x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2>>
- //CHECK: %[[R1:.*]] = xegpu.load %[[R0]], %[[cst]] <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>, transpose}>
- //CHECK-SAME: !xegpu.tensor_desc<4x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2 : i64>>, vector<4xi1> -> vector<2x4xf32>
- %2 = xegpu.load %1, %0 <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>, transpose}>
- : !xegpu.tensor_desc<4x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2>>, vector<4xi1> -> vector<2x4xf32>
- gpu.return
-}
-
-// CHECK: gpu.func @test_store_scatter_vc(%[[arg0:.*]]: ui64) {
-gpu.func @test_store_scatter_vc(%src: ui64) {
- //CHECK: %[[c0:.*]] = arith.constant dense<true> : vector<4xi1>
- %0 = arith.constant dense<1>: vector<4xi1>
- //CHECK: %[[c1:.*]] = arith.constant dense<2.900000e+00> : vector<2x4xf32>
- %1 = arith.constant dense<2.9>: vector<2x4xf32>
- //CHECK: %[[c2:.*]] = arith.constant dense<[0, 8, 16, 24]> : vector<4xindex>
- %c = arith.constant dense<[0, 8, 16, 24]> : vector<4xindex>
- //CHECK: %[[R0:.*]] = xegpu.create_tdesc %[[arg0]], %[[c2]] : ui64, vector<4xindex> -> !xegpu.tensor_desc<4x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2 : i64>>
- %2 = xegpu.create_tdesc %src, %c : ui64, vector<4xindex> -> !xegpu.tensor_desc<4x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2>>
- //CHECK: xegpu.store %[[c1]], %[[R0]], %[[c0]] <{l1_hint = #xegpu.cache_hint<write_back>, l2_hint = #xegpu.cache_hint<uncached>, transpose}>
- //CHECK-SAME: vector<2x4xf32>, !xegpu.tensor_desc<4x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2 : i64>>, vector<4xi1>
- xegpu.store %1, %2, %0 <{l1_hint = #xegpu.cache_hint<write_back>, l2_hint = #xegpu.cache_hint<uncached>, transpose}>
- : vector<2x4xf32>, !xegpu.tensor_desc<4x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2>>, vector<4xi1>
+// CHECK: gpu.func @test_create_update_tdesc_simt(%[[arg0:.*]]: ui64) {
+gpu.func @test_create_update_tdesc_simt(%src: ui64) {
+ //CHECK: %[[cst:.*]] = arith.constant dense<[0, 8, 16, 24]> : vector<4xindex>
+ //CHECK: %[[R0:.*]] = xegpu.create_tdesc %[[arg0]], %[[cst]] : ui64, vector<4xindex> -> !xegpu.tensor_desc<4x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2 : i64>, #xegpu.sg_map<wi_layout = [4, 1], wi_data = [1, 1]>>
+ //CHECK: %[[st:.*]] = arith.constant dense<32> : vector<4xindex>
+ //CHECK: %[[R1:.*]] = xegpu.update_offset %[[R0]], %[[st]] : !xegpu.tensor_desc<4x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2 : i64>, #xegpu.sg_map<wi_layout = [4, 1], wi_data = [1, 1]>>, vector<4xindex>
+ %0 = arith.constant dense<[0, 8, 16, 24]> : vector<4xindex>
+ %1 = xegpu.create_tdesc %src, %0 : ui64, vector<4xindex> -> !xegpu.tensor_desc<4x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2>, #xegpu.sg_map<wi_layout = [4, 1], wi_data = [1, 1]>>
+ %s = arith.constant dense<[32, 32, 32, 32]> : vector<4xindex>
+ %2 = xegpu.update_offset %1, %s : !xegpu.tensor_desc<4x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2>, #xegpu.sg_map<wi_layout = [4, 1], wi_data = [1, 1]>>, vector<4xindex>
gpu.return
}
@@ -662,6 +653,17 @@ gpu.func @test_dpas_vc(%a : vector<8x16xf16>, %b: vector<16x16xf16>) {
gpu.return
}
+// CHECK: gpu.func @test_dpas_simt(%[[arg0:.*]]: vector<8x1xf16>, %[[arg1:.*]]: vector<8x2xf16>)
+gpu.func @test_dpas_simt(%a : vector<8x1xf16>, %b: vector<8x2xf16>) {
+ // CHECK: xegpu.dpas %[[arg0]], %[[arg1]] {sg_map_a = #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>,
+ // CHECK: sg_map_b = #xegpu.sg_map<wi_layout = [1, 16], wi_data = [2, 1]>,
+ // CHECK: sg_map_c = #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>} : vector<8x1xf16>, vector<8x2xf16> -> vector<8x1xf32>
+ %1 = xegpu.dpas %a, %b {sg_map_a = #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>,
+ sg_map_b = #xegpu.sg_map<wi_layout = [1, 16], wi_data = [2, 1]>,
+ sg_map_c = #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>}
+ : vector<8x1xf16>, vector<8x2xf16> -> vector<8x1xf32>
+ gpu.return
+}
// CHECK: gpu.func @test_dpas_vc_with_packed_b(%[[arg0:.*]]: vector<8x16xf16>, %[[arg1:.*]]: vector<8x16x2xf16>)
gpu.func @test_dpas_vc_with_packed_b(%a : vector<8x16xf16>, %b: vector<8x16x2xf16>) {
>From 13edb3368518ed700a5fe8ca04f85d21865bf342 Mon Sep 17 00:00:00 2001
From: Charitha Saumya <charitha.saumya.gusthinna.waduge at intel.com>
Date: Thu, 20 Feb 2025 16:26:09 +0000
Subject: [PATCH 07/11] save work
---
mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td | 5 ++++-
mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp | 6 +++---
2 files changed, 7 insertions(+), 4 deletions(-)
diff --git a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td
index 7560ede058faa..78e236b4ca170 100644
--- a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td
+++ b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td
@@ -757,7 +757,10 @@ def XeGPU_DpasOp : XeGPU_Op<"dpas", [Pure, AllElementTypesMatch<["lhs", "rhs"]>]
let arguments = (ins
XeGPU_DpasOpType : $lhs,
XeGPU_DpasOpType : $rhs,
- Optional<XeGPU_Vector2DType>: $acc);
+ Optional<XeGPU_Vector2DType>: $acc,
+ OptionalAttr<XeGPU_SGMapAttr>:$sg_map_a,
+ OptionalAttr<XeGPU_SGMapAttr>:$sg_map_b,
+ OptionalAttr<XeGPU_SGMapAttr>:$sg_map_c);
let results = (outs XeGPU_Vector2DType: $result);
let extraClassDeclaration = [{
diff --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
index ad8b4bf3427cc..fd8e53ad16ad2 100644
--- a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
+++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
@@ -625,9 +625,9 @@ LogicalResult DpasOp::verify() {
auto rhsShape = getRhsType().getShape();
auto resultShape = getResultType().getShape();
- auto sgMapA = (*this)->getAttrOfType<xegpu::SGMapAttr>("sg_map_a");
- auto sgMapB = (*this)->getAttrOfType<xegpu::SGMapAttr>("sg_map_b");
- auto sgMapC = (*this)->getAttrOfType<xegpu::SGMapAttr>("sg_map_c");
+ auto sgMapA = getSgMapAAttr();
+ auto sgMapB = getSgMapBAttr();
+ auto sgMapC = getSgMapCAttr();
// If sg_maps are not present, then the operation is in VC mode.
if (!sgMapA && !sgMapB && !sgMapC) {
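For context on the accessor change above: with the OptionalAttr declarations
added in XeGPUOps.td, the ODS-generated getters return a null attribute when
the attribute is absent, so the null checks in the verifier keep working.
Roughly, the generated code is equivalent to the hand-written call it
replaces (a simplified sketch, not the exact tablegen output):

xegpu::SGMapAttr DpasOp::getSgMapAAttr() {
  // Null when sg_map_a is not set on the op.
  return (*this)->getAttrOfType<xegpu::SGMapAttr>("sg_map_a");
}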
>From 0d4214830d4651d5533d11617fc1ce7a427a77cc Mon Sep 17 00:00:00 2001
From: Charitha Saumya <charitha.saumya.gusthinna.waduge at intel.com>
Date: Thu, 20 Feb 2025 19:22:19 +0000
Subject: [PATCH 08/11] save work
---
mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp | 10 ++--
mlir/test/Dialect/XeGPU/invalid.mlir | 66 ++++++++++++++++++--------
2 files changed, 52 insertions(+), 24 deletions(-)
diff --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
index fd8e53ad16ad2..19eb665d599ee 100644
--- a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
+++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
@@ -283,7 +283,7 @@ LogicalResult LoadNdOp::verify() {
}
auto sgMap = tdescTy.getSGMapAttr();
- // sg_map not present means IR is in VC mode. In this case value shape must
+ // sg_map not present means IR is in SIMD mode. In this case value shape must
// match adjusted tensor descriptor shape.
if (!sgMap)
return valueShape == adjustedTdescShape
@@ -338,7 +338,7 @@ LogicalResult StoreNdOp::verify() {
auto valueShape = getShapeOf(valTy);
auto sgMap = dstTy.getSGMapAttr();
- // sg_map not present means IR is in VC mode. In this case value shape must
+ // sg_map not present means IR is in SIMD mode. In this case value shape must
// match adjusted tensor descriptor shape.
if (!sgMap)
return valueShape == tdescShape
@@ -514,7 +514,7 @@ LogicalResult LoadGatherOp::verify() {
}
auto sgMap = tdescTy.getSGMapAttr();
- // In VC mode, sg_map is not present. In this case, value shape must match
+ // In SIMD mode, sg_map is not present. In this case, value shape must match
// the tensor descriptor shape.
if (!sgMap)
return valueShape == tdescShape
@@ -569,7 +569,7 @@ LogicalResult StoreScatterOp::verify() {
}
auto sgMap = tdescTy.getSGMapAttr();
- // In VC mode, sg_map is not present. In this case, value shape must match
+ // In SIMD mode, sg_map is not present. In this case, value shape must match
// the tensor descriptor shape.
if (!sgMap)
return valueShape == tdescShape
@@ -629,7 +629,7 @@ LogicalResult DpasOp::verify() {
auto sgMapB = getSgMapBAttr();
auto sgMapC = getSgMapCAttr();
- // If sg_maps are not present, then the operation is in VC mode.
+ // If sg_maps are not present, then the operation is in SIMD mode.
if (!sgMapA && !sgMapB && !sgMapC) {
if (lhsRank != 2 || (rhsRank != 2 && rhsRank != 3) || resultRank != 2)
return emitOpError(
diff --git a/mlir/test/Dialect/XeGPU/invalid.mlir b/mlir/test/Dialect/XeGPU/invalid.mlir
index 1ea452119fec9..584dc23d61e2e 100644
--- a/mlir/test/Dialect/XeGPU/invalid.mlir
+++ b/mlir/test/Dialect/XeGPU/invalid.mlir
@@ -253,9 +253,9 @@ func.func @test_create_tdesc_sg_map_3(%src: ui64) {
func.func @test_load_gather_sg_map_1(%src: ui64) {
%0 = arith.constant dense<1>: vector<4xi1>
%cst = arith.constant dense<[0, 8, 16, 24]> : vector<4xindex>
- %1 = xegpu.create_tdesc %src, %cst : ui64, vector<4xindex> -> !xegpu.tensor_desc<4x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2>, #xegpu.sg_map<wi_layout = [4, 1], wi_data = [1, 2]>>
+ %1 = xegpu.create_tdesc %src, %cst : ui64, vector<4xindex> -> !xegpu.tensor_desc<4x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2>, #xegpu.sg_map<wi_layout = [4, 1], wi_data = [1, 1]>>
// expected-error@+1 {{Unexpected result shape(Expected shape: [2, 1], Given shape: [1, 2])}}
- %2 = xegpu.load %1, %0 <{l1_hint = #xegpu.cache_hint<cached>, transpose}> : !xegpu.tensor_desc<4x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2>, #xegpu.sg_map<wi_layout = [4, 1], wi_data = [1, 2]>>, vector<4xi1> -> vector<1x2xf32>
+ %2 = xegpu.load %1, %0 <{l1_hint = #xegpu.cache_hint<cached>, transpose}> : !xegpu.tensor_desc<4x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2>, #xegpu.sg_map<wi_layout = [4, 1], wi_data = [1, 1]>>, vector<4xi1> -> vector<1x2xf32>
return
}
@@ -263,19 +263,9 @@ func.func @test_load_gather_sg_map_1(%src: ui64) {
func.func @test_load_gather_sg_map_2(%src: ui64) {
%0 = arith.constant dense<1>: vector<4xi1>
%cst = arith.constant dense<[0, 8, 16, 24]> : vector<4xindex>
- %1 = xegpu.create_tdesc %src, %cst : ui64, vector<4xindex> -> !xegpu.tensor_desc<4x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2>, #xegpu.sg_map<wi_layout = [4, 1], wi_data = [1, 2]>>
+ %1 = xegpu.create_tdesc %src, %cst : ui64, vector<4xindex> -> !xegpu.tensor_desc<4x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2>, #xegpu.sg_map<wi_layout = [4, 1], wi_data = [1, 1]>>
// expected-error@+1 {{Unexpected result shape(Expected shape: [2, 1], Given shape: [2])}}
- %2 = xegpu.load %1, %0 <{l1_hint = #xegpu.cache_hint<cached>, transpose}> : !xegpu.tensor_desc<4x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2>, #xegpu.sg_map<wi_layout = [4, 1], wi_data = [1, 2]>>, vector<4xi1> -> vector<2xf32>
- return
-}
-
-// -----
-func.func @test_load_gather_sg_map_3(%src: ui64) {
- %0 = arith.constant dense<1>: vector<4xi1>
- %cst = arith.constant dense<[0, 8, 16, 24]> : vector<4xindex>
- %1 = xegpu.create_tdesc %src, %cst : ui64, vector<4xindex> -> !xegpu.tensor_desc<4x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2>, #xegpu.sg_map<wi_layout = [4, 1], wi_data = [1, 2]>>
- // expected-error@+1 {{Chunk size, vector size and wi_data must match}}
- %2 = xegpu.load %1, %0 <{l1_hint = #xegpu.cache_hint<cached>, transpose}> : !xegpu.tensor_desc<4x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2>, #xegpu.sg_map<wi_layout = [4, 1], wi_data = [1, 2]>>, vector<4xi1> -> vector<1xf32>
+ %2 = xegpu.load %1, %0 <{l1_hint = #xegpu.cache_hint<cached>, transpose}> : !xegpu.tensor_desc<4x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2>, #xegpu.sg_map<wi_layout = [4, 1], wi_data = [1, 1]>>, vector<4xi1> -> vector<2xf32>
return
}
@@ -285,9 +275,9 @@ func.func @test_store_scatter_sg_map_1(%src: ui64) {
%0 = arith.constant dense<1>: vector<4xi1>
%cst = arith.constant dense<[0, 8, 16, 24]> : vector<4xindex>
%val = arith.constant dense<2.9>: vector<1x2xf32>
- %1 = xegpu.create_tdesc %src, %cst : ui64, vector<4xindex> -> !xegpu.tensor_desc<4x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2>, #xegpu.sg_map<wi_layout = [4, 1], wi_data = [1, 2]>>
+ %1 = xegpu.create_tdesc %src, %cst : ui64, vector<4xindex> -> !xegpu.tensor_desc<4x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2>, #xegpu.sg_map<wi_layout = [4, 1], wi_data = [1, 1]>>
// expected-error@+1 {{Unexpected value shape(Expected shape: [2, 1], Given shape: [1, 2])}}
- xegpu.store %val, %1, %0 <{l1_hint = #xegpu.cache_hint<cached>, transpose}> : vector<1x2xf32>, !xegpu.tensor_desc<4x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2>, #xegpu.sg_map<wi_layout = [4, 1], wi_data = [1, 2]>>, vector<4xi1>
+ xegpu.store %val, %1, %0 <{l1_hint = #xegpu.cache_hint<cached>, transpose}> : vector<1x2xf32>, !xegpu.tensor_desc<4x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2>, #xegpu.sg_map<wi_layout = [4, 1], wi_data = [1, 1]>>, vector<4xi1>
return
}
@@ -296,9 +286,9 @@ func.func @test_store_scatter_sg_map_2(%src: ui64) {
%0 = arith.constant dense<1>: vector<4xi1>
%cst = arith.constant dense<[0, 8, 16, 24]> : vector<4xindex>
%val = arith.constant dense<2.9>: vector<2xf32>
- %1 = xegpu.create_tdesc %src, %cst : ui64, vector<4xindex> -> !xegpu.tensor_desc<4x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2>, #xegpu.sg_map<wi_layout = [4, 1], wi_data = [1, 2]>>
+ %1 = xegpu.create_tdesc %src, %cst : ui64, vector<4xindex> -> !xegpu.tensor_desc<4x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2>, #xegpu.sg_map<wi_layout = [4, 1], wi_data = [1, 1]>>
// expected-error@+1 {{Unexpected value shape(Expected shape: [2, 1], Given shape: [2])}}
- xegpu.store %val, %1, %0 <{l1_hint = #xegpu.cache_hint<cached>, transpose}> : vector<2xf32>, !xegpu.tensor_desc<4x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2>, #xegpu.sg_map<wi_layout = [4, 1], wi_data = [1, 2]>>, vector<4xi1>
+ xegpu.store %val, %1, %0 <{l1_hint = #xegpu.cache_hint<cached>, transpose}> : vector<2xf32>, !xegpu.tensor_desc<4x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2>, #xegpu.sg_map<wi_layout = [4, 1], wi_data = [1, 1]>>, vector<4xi1>
return
}
@@ -358,11 +348,49 @@ func.func @test_dpas_vc_1(%a : vector<8x8xf16>, %b: vector<8x16x2xf16>) {
// -----
func.func @test_dpas_vc_2(%a : vector<8x8x2xf16>, %b: vector<8x16x2xf16>) {
- // expected-error@+1 {{expecting lhs to be a 2D vector, and rhs to be either 2D or 3D (packed) vector}}
+ // expected-error@+1 {{expecting lhs and result to be a 2D vector, and rhs to be either 2D or 3D (packed) vector}}
%1 = xegpu.dpas %a, %b : vector<8x8x2xf16>, vector<8x16x2xf16> -> vector<8x16xf32>
return
}
+// -----
+func.func @test_dpas_3(%a : vector<8x8xf16>, %b: vector<8x16x2xf16>) {
+ // expected-error@+1 {{K-dimension mismatch}}
+ %1 = xegpu.dpas %a, %b : vector<8x8xf16>, vector<8x16x2xf16> -> vector<8x16xf32>
+ return
+}
+
+// -----
+func.func @test_dpas_4(%a : vector<16x16xf16>, %b: vector<8x16x2xf16>) {
+ // expected-error@+1 {{M-dimension mismatch}}
+ %1 = xegpu.dpas %a, %b : vector<16x16xf16>, vector<8x16x2xf16> -> vector<8x16xf32>
+ return
+}
+
+// -----
+func.func @test_dpas_5(%a : vector<8x16xf16>, %b: vector<8x8x2xf16>) {
+ // expected-error@+1 {{N-dimension mismatch}}
+ %1 = xegpu.dpas %a, %b : vector<8x16xf16>, vector<8x8x2xf16> -> vector<8x16xf32>
+ return
+}
+
+// -----
+func.func @test_dpas_sg_map_1(%a : vector<8x1xf16>, %b: vector<8x2xf16>) {
+ // expected-error@+1 {{sg_map attributes for all operands and outputs are expected in SIMT xegpu::Dpas operation}}
+ %1 = xegpu.dpas %a, %b {sg_map_a = #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>} : vector<8x1xf16>, vector<8x2xf16> -> vector<8x1xf32>
+ return
+}
+
+// -----
+func.func @test_dpas_sg_map_2(%a : vector<8x1xf16>, %b: vector<4x2xf16>) {
+ // expected-error@+1 {{K-dimension mismatch}}
+ %1 = xegpu.dpas %a, %b {sg_map_a = #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>,
+ sg_map_b = #xegpu.sg_map<wi_layout = [1, 16], wi_data = [2, 1]>,
+ sg_map_c = #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>}
+ : vector<8x1xf16>, vector<4x2xf16> -> vector<8x1xf32>
+ return
+}
+
// -----
func.func @test_atomic_rmw(%src: ui64, %value : vector<16x4xf32>, %mask : vector<16xi1>) {
%0 = arith.constant dense<[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15]> : vector<16xindex>
>From 4b5cffb1b3443c1e51d053ef795e2b03c61b1977 Mon Sep 17 00:00:00 2001
From: Charitha Saumya <charitha.saumya.gusthinna.waduge at intel.com>
Date: Thu, 20 Feb 2025 19:35:31 +0000
Subject: [PATCH 09/11] save work
---
mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp | 2 +-
mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp | 4 ++--
mlir/test/Dialect/XeGPU/invalid.mlir | 12 ++++++------
3 files changed, 9 insertions(+), 9 deletions(-)
diff --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
index af3faf141a66e..53fb89efd9226 100644
--- a/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
+++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
@@ -326,7 +326,7 @@ LogicalResult TensorDescType::verify(
// ---------------------------------------------------------------------
// Case 2: Block loads/stores
// ---------------------------------------------------------------------
-// Additionalm definitions:
+// Additional definitions:
// tensor_size = tensor_desc[0] * .. * tensor_desc[r-1] * array_length
// n_distribution_units = tensor_size / distribution_unit_size
// Given above definitions, the following conditions must be met:
diff --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
index 19eb665d599ee..2af9f7d846ff2 100644
--- a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
+++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
@@ -290,7 +290,7 @@ LogicalResult LoadNdOp::verify() {
? success()
: emitOpError()
<< "Result shape " << makeString(valueShape)
- << " is not consistent with tensor descripter " << tdescTy;
+ << " is not consistent with tensor descriptor " << tdescTy;
// sg_map present means IR is in SIMT mode. In this case sg_map determines the
// value shape.
@@ -345,7 +345,7 @@ LogicalResult StoreNdOp::verify() {
? success()
: emitOpError()
<< "Result shape " << makeString(valueShape)
- << " is not consistent with tensor descripter shape "
+ << " is not consistent with tensor descriptor shape "
<< makeString(tdescShape);
// sg_map present means IR is in SIMT mode. In this case sg_map determines the
diff --git a/mlir/test/Dialect/XeGPU/invalid.mlir b/mlir/test/Dialect/XeGPU/invalid.mlir
index 584dc23d61e2e..8f5a42968a2d6 100644
--- a/mlir/test/Dialect/XeGPU/invalid.mlir
+++ b/mlir/test/Dialect/XeGPU/invalid.mlir
@@ -78,7 +78,7 @@ func.func @test_load_nd_vc_3(%src: memref<8x16xf16>) {
}
// -----
-func.func @test_load_nd_vc_4(%src: memref<24x32xf32>) {
+func.func @test_load_nd_sg_map(%src: memref<24x32xf32>) {
%1 = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf32> ->
!xegpu.tensor_desc<8x16xf32, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>>
// expected-error@+1 {{Result shape [8, 2] is not consistent with distributed vector shape [8, 1]}}
@@ -90,7 +90,7 @@ func.func @test_load_nd_vc_4(%src: memref<24x32xf32>) {
}
// -----
-func.func @test_load_nd_vc_5(%src: memref<24x32xf32>) {
+func.func @test_load_nd_sg_map(%src: memref<24x32xf32>) {
%1 = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf32> ->
!xegpu.tensor_desc<16xf32, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>>
// expected-error@+1 {{Result shape [8] is not consistent with distributed vector shape [1, 1]}}
@@ -105,7 +105,7 @@ func.func @test_load_nd_vc_5(%src: memref<24x32xf32>) {
func.func @test_load_nd_vc_6(%src: memref<24x32xf32>) {
%1 = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf32> ->
!xegpu.tensor_desc<8x16xf32>
- // expected-error@+1 {{Result shape [8, 1] is not consistent with tensor descripter}}
+ // expected-error@+1 {{Result shape [8, 1] is not consistent with tensor descriptor}}
%2 = xegpu.load_nd %1 <{l1_hint = #xegpu.cache_hint<cached>,
l2_hint = #xegpu.cache_hint<uncached>}>
: !xegpu.tensor_desc<8x16xf32> -> vector<8x1xf32>
@@ -134,7 +134,7 @@ func.func @test_store_nd_vc_2(%dst: memref<16xf16>) {
}
// -----
-func.func @test_store_nd_vc_3(%dst: memref<24x32xf32>, %data: vector<8x2xf32>) {
+func.func @test_store_nd_sg_map(%dst: memref<24x32xf32>, %data: vector<8x2xf32>) {
%1 = xegpu.create_nd_tdesc %dst[0, 0] : memref<24x32xf32> ->
!xegpu.tensor_desc<8x16xf32, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>>
// expected-error@+1 {{Result shape [8, 2] is not consistent with distributed vector shape [8, 1] for tensor descriptor}}
@@ -144,7 +144,7 @@ func.func @test_store_nd_vc_3(%dst: memref<24x32xf32>, %data: vector<8x2xf32>) {
}
// -----
-func.func @test_store_nd_vc_4(%dst: memref<24x32xf32>, %data: vector<2xf32>) {
+func.func @test_store_nd_sg_map(%dst: memref<24x32xf32>, %data: vector<2xf32>) {
%1 = xegpu.create_nd_tdesc %dst[0, 0] : memref<24x32xf32> ->
!xegpu.tensor_desc<16xf32, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>>
// expected-error@+1 {{Result shape [2] is not consistent with distributed vector shape [1, 1] for tensor descriptor}}
@@ -157,7 +157,7 @@ func.func @test_store_nd_vc_4(%dst: memref<24x32xf32>, %data: vector<2xf32>) {
func.func @test_store_nd_vc_5(%dst: memref<24x32xf32>, %data: vector<8x1xf32>) {
%1 = xegpu.create_nd_tdesc %dst[0, 0] : memref<24x32xf32> ->
!xegpu.tensor_desc<8x16xf32>
- // expected-error@+1 {{Result shape [8, 1] is not consistent with tensor descripter shape [8, 16]}}
+ // expected-error@+1 {{Result shape [8, 1] is not consistent with tensor descriptor shape [8, 16]}}
xegpu.store_nd %data, %1 : vector<8x1xf32>, !xegpu.tensor_desc<8x16xf32>
return
}
>From ab59c4659af240293f6fe0b2dc6a2eff97b3c530 Mon Sep 17 00:00:00 2001
From: Charitha Saumya <charitha.saumya.gusthinna.waduge at intel.com>
Date: Thu, 20 Feb 2025 22:27:40 +0000
Subject: [PATCH 10/11] save work
---
mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp | 22 +++---
mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp | 89 +++++++++-------------
mlir/test/Dialect/XeGPU/invalid.mlir | 4 +-
3 files changed, 51 insertions(+), 64 deletions(-)
diff --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
index 53fb89efd9226..93bc9a7961d4d 100644
--- a/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
+++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
@@ -14,6 +14,7 @@
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/LogicalResult.h"
+#include <cassert>
namespace mlir {
namespace xegpu {
@@ -281,11 +282,11 @@ LogicalResult TensorDescType::verify(
// Validate subgroup mapping rules for scattered tensors.
// A work-item's slice of the tensor with shape [sg_size] or
// [sg_size, chunk_size] will be [1] or [1, 32/element_ty_bit_width]
- // respectively, the mapping should reflect that.
+ // respectively, the mapping should reflect that. This is because each
+ // work item accesses data at 32-bit granularity.
if (wiData[0] != 1)
return emitError()
<< "cannot map over non-contiguous scattered row elements";
-
if (wiData[1] != (32 / elementType.getIntOrFloatBitWidth()))
return emitError() << "work item data mapping must match the number of "
"contiguous elements";
@@ -351,14 +352,13 @@ FailureOr<VectorType> TensorDescType::getDistributedVectorType() {
}
// Case 1: regular loads/stores
- auto scatterAttr =
- llvm::dyn_cast_if_present<ScatterTensorDescAttr>(getEncoding());
+ auto scatterAttr = getEncodingAsScatterTensorDescAttr();
if (scatterAttr) {
auto chunkSize = scatterAttr.getChunkSize().getInt();
- // Check if the first dimension of the tensor descriptor shape is
+ // Verify that the first dimension of the tensor descriptor shape is
// distributable.
- if (tdescShape[0] % (wiLayout[0]) != 0)
- return failure();
+ assert(tdescShape[0] % (wiLayout[0]) == 0 &&
+ "tensor descriptor shape is not distributable");
if (chunkSize > 1)
return VectorType::get({chunkSize / wiDataSize, wiDataSize},
getElementType());
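As an illustration (values assumed here, not taken from the tests): an f16 scattered descriptor with chunk_size = 8 and wi_data = [1, 2] has wi_data_size = 2, so this branch produces vector<4x2xf16>, i.e. {chunkSize / wiDataSize, wiDataSize}.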
@@ -369,8 +369,8 @@ FailureOr<VectorType> TensorDescType::getDistributedVectorType() {
// Tensor descriptor shape can be 1D. For the 1D case, outer dims of wiData
// and wiLayout must be 1.
if (tdescShape.size() == 1) {
- if (wiData[0] != 1 || wiLayout[0] != 1)
- return failure();
+ assert((wiData[0] == 1 && wiLayout[0] == 1) &&
+ "wi_data[0] and wi_layout[0] must be 1 for 1D tensor descriptor");
wiData = {wiData[1]};
wiLayout = {wiLayout[1]};
}
@@ -378,8 +378,8 @@ FailureOr<VectorType> TensorDescType::getDistributedVectorType() {
int64_t tensorSize = 1;
for (auto [tdescDim, wiDim, wiDataDim] :
llvm::zip_equal(tdescShape, wiLayout, wiData)) {
- if (tdescDim % (wiDim * wiDataDim) != 0)
- return failure();
+ assert((tdescDim % (wiDim * wiDataDim) == 0) &&
+ "tensor descriptor shape is not distributable");
tensorSize *= tdescDim;
}
// tensorSize must be adjusted for array_length.
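A minimal standalone sketch of the computation above, assuming a tensor_desc<8x16xf32> with wi_layout = [1, 16], wi_data = [1, 1] and array_length = 1 (this mirrors the logic for the non-scattered 2D case only; it is not the MLIR implementation itself):

#include <array>
#include <cassert>
#include <cstdint>
#include <iostream>

int main() {
  std::array<int64_t, 2> tdescShape{8, 16}; // tensor_desc<8x16xf32>
  std::array<int64_t, 2> wiLayout{1, 16};   // sg_map wi_layout
  std::array<int64_t, 2> wiData{1, 1};      // sg_map wi_data
  int64_t arrayLength = 1;

  int64_t wiDataSize = wiData[0] * wiData[1];
  int64_t sgSize = wiLayout[0] * wiLayout[1];
  int64_t tensorSize = 1;
  for (int i = 0; i < 2; ++i) {
    // Each dim must be evenly divisible by wi_layout[i] * wi_data[i].
    assert(tdescShape[i] % (wiLayout[i] * wiData[i]) == 0 &&
           "tensor descriptor shape is not distributable");
    tensorSize *= tdescShape[i];
  }
  tensorSize *= arrayLength; // tensorSize must be adjusted for array_length.
  int64_t nUnits = tensorSize / (sgSize * wiDataSize);
  std::cout << "distributed shape: [" << nUnits << ", " << wiDataSize
            << "]\n"; // prints: distributed shape: [8, 1]
  return 0;
}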
diff --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
index 2af9f7d846ff2..af34fa03e24d6 100644
--- a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
+++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
@@ -11,6 +11,7 @@
#include "mlir/Dialect/XeGPU/IR/XeGPU.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/Diagnostics.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/Support/LLVM.h"
@@ -76,6 +77,39 @@ static bool isWriteHintOrNone(const CachePolicyAttr &attr) {
kind == CachePolicy::WRITE_BACK || kind == CachePolicy::WRITE_THROUGH;
}
+// Helper to validate the value shape of LoadNd and StoreNd ops.
+static LogicalResult
+isArgShapesValid(TensorDescType tdescTy, VectorType valueTy,
+ ArrayRef<int64_t> adjustedTdescShape,
+ function_ref<InFlightDiagnostic()> emitError) {
+ auto sgMap = tdescTy.getSGMapAttr();
+ auto valueShape = valueTy.getShape();
+ // sg_map not present means IR is in SIMD mode. In this case value shape must
+ // match adjusted tensor descriptor shape.
+ if (!sgMap)
+ return valueShape == adjustedTdescShape
+ ? success()
+ : emitError()
+ << "Value shape " << makeString(valueShape)
+ << " is not consistent with tensor descriptor " << tdescTy;
+
+ // sg_map present means IR is in SIMT mode. In this case sg_map determines the
+ // value shape.
+ auto expectedValueShapeOrFailure = tdescTy.getDistributedVectorType();
+ if (failed(expectedValueShapeOrFailure))
+ return emitError() << "Failed to compute distributed vector shape for "
+ "tensor descriptor "
+ << tdescTy;
+
+ return valueTy == expectedValueShapeOrFailure.value()
+ ? success()
+ : emitError()
+ << "Result shape " << makeString(valueShape)
+ << " is not consistent with distributed vector shape "
+ << makeString(expectedValueShapeOrFailure.value().getShape())
+ << " for tensor descriptor " << tdescTy;
+}
+
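Taking a function_ref here lets both LoadNdOp and StoreNdOp share the helper while keeping their diagnostics lazy; each verifier calls it along these lines (a sketch of the pattern applied further below):

  return isArgShapesValid(tdescTy, valueTy, adjustedTdescShape,
                          [&]() { return emitOpError(); });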
//===----------------------------------------------------------------------===//
// XeGPU_CreateNdDescOp
//===----------------------------------------------------------------------===//
@@ -282,31 +316,8 @@ LogicalResult LoadNdOp::verify() {
adjustedTdescShape.insert(it, array_len);
}
- auto sgMap = tdescTy.getSGMapAttr();
- // sg_map not present means IR is in SIMD mode. In this case value shape must
- // match adjusted tensor descriptor shape.
- if (!sgMap)
- return valueShape == adjustedTdescShape
- ? success()
- : emitOpError()
- << "Result shape " << makeString(valueShape)
- << " is not consistent with tensor descriptor " << tdescTy;
-
- // sg_map present means IR is in SIMT mode. In this case sg_map determines the
- // value shape.
- auto expectedValueShapeOrFailure = tdescTy.getDistributedVectorType();
- if (failed(expectedValueShapeOrFailure))
- return emitOpError() << "Failed to compute distributed vector shape for "
- "tensor descriptor "
- << tdescTy;
-
- return valueTy == expectedValueShapeOrFailure.value()
- ? success()
- : emitOpError()
- << "Result shape " << makeString(valueShape)
- << " is not consistent with distributed vector shape "
- << makeString(expectedValueShapeOrFailure.value().getShape())
- << " for tensor descriptor " << tdescTy;
+ return isArgShapesValid(tdescTy, valueTy, adjustedTdescShape,
+ [&]() { return emitOpError(); });
}
//===----------------------------------------------------------------------===//
@@ -337,32 +348,8 @@ LogicalResult StoreNdOp::verify() {
auto tdescShape = getShapeOf(dstTy);
auto valueShape = getShapeOf(valTy);
- auto sgMap = dstTy.getSGMapAttr();
- // sg_map not present means IR is in SIMD mode. In this case value shape must
- // match adjusted tensor descriptor shape.
- if (!sgMap)
- return valueShape == tdescShape
- ? success()
- : emitOpError()
- << "Result shape " << makeString(valueShape)
- << " is not consistent with tensor descriptor shape "
- << makeString(tdescShape);
-
- // sg_map present means IR is in SIMT mode. In this case sg_map determines the
- // value shape.
- auto expectedValueShapeOrFailure = dstTy.getDistributedVectorType();
- if (failed(expectedValueShapeOrFailure))
- return emitOpError() << "Failed to compute distributed vector shape for "
- "tensor descriptor "
- << dstTy;
-
- return valTy == expectedValueShapeOrFailure.value()
- ? success()
- : emitOpError()
- << "Result shape " << makeString(valueShape)
- << " is not consistent with distributed vector shape "
- << makeString(expectedValueShapeOrFailure.value().getShape())
- << " for tensor descriptor " << dstTy;
+ return isArgShapesValid(dstTy, valTy, tdescShape,
+ [&]() { return emitOpError(); });
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/XeGPU/invalid.mlir b/mlir/test/Dialect/XeGPU/invalid.mlir
index 8f5a42968a2d6..01b4e99d2245e 100644
--- a/mlir/test/Dialect/XeGPU/invalid.mlir
+++ b/mlir/test/Dialect/XeGPU/invalid.mlir
@@ -105,7 +105,7 @@ func.func @test_load_nd_sg_map(%src: memref<24x32xf32>) {
func.func @test_load_nd_vc_6(%src: memref<24x32xf32>) {
%1 = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf32> ->
!xegpu.tensor_desc<8x16xf32>
- // expected-error@+1 {{Result shape [8, 1] is not consistent with tensor descriptor}}
+ // expected-error@+1 {{Value shape [8, 1] is not consistent with tensor descriptor}}
%2 = xegpu.load_nd %1 <{l1_hint = #xegpu.cache_hint<cached>,
l2_hint = #xegpu.cache_hint<uncached>}>
: !xegpu.tensor_desc<8x16xf32> -> vector<8x1xf32>
@@ -157,7 +157,7 @@ func.func @test_store_nd_sg_map(%dst: memref<24x32xf32>, %data: vector<2xf32>) {
func.func @test_store_nd_vc_5(%dst: memref<24x32xf32>, %data: vector<8x1xf32>) {
%1 = xegpu.create_nd_tdesc %dst[0, 0] : memref<24x32xf32> ->
!xegpu.tensor_desc<8x16xf32>
- // expected-error@+1 {{Result shape [8, 1] is not consistent with tensor descriptor shape [8, 16]}}
+ // expected-error@+1 {{Value shape [8, 1] is not consistent with tensor descriptor}}
xegpu.store_nd %data, %1 : vector<8x1xf32>, !xegpu.tensor_desc<8x16xf32>
return
}
>From be1e7287c03bb74846e0e228eeec846bd43d04cf Mon Sep 17 00:00:00 2001
From: Charitha Saumya <charitha.saumya.gusthinna.waduge at intel.com>
Date: Fri, 21 Feb 2025 03:32:28 +0000
Subject: [PATCH 11/11] save work
---
mlir/test/Dialect/XeGPU/ops.mlir | 18 ------------------
1 file changed, 18 deletions(-)
diff --git a/mlir/test/Dialect/XeGPU/ops.mlir b/mlir/test/Dialect/XeGPU/ops.mlir
index 075830c4b6b7d..c32f1905454b6 100644
--- a/mlir/test/Dialect/XeGPU/ops.mlir
+++ b/mlir/test/Dialect/XeGPU/ops.mlir
@@ -264,24 +264,6 @@ gpu.func @test_load_nd_simt_8(%src: memref<24x32xf32>) {
gpu.return
}
-// CHECK: func @test_load_nd_vc_9(%[[arg0:.*]]: memref<24x32xf32>) {
-gpu.func @test_load_nd_vc_9(%src: memref<24x32xf32>) {
- // CHECK: %[[R0:.*]] = xegpu.create_nd_tdesc %arg0[0, 0] : memref<24x32xf32> -> !xegpu.tensor_desc<8x16xf32>
- %1 = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf32> -> !xegpu.tensor_desc<8x16xf32>
- // CHECK: %[[R1:.*]] = xegpu.load_nd %[[R0]] <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>, transpose = array<i64: 1, 0>}> : !xegpu.tensor_desc<8x16xf32> -> vector<16x8xf32>
- %2 = xegpu.load_nd %1 <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>, transpose = array<i64: 1, 0>}> : !xegpu.tensor_desc<8x16xf32> -> vector<16x8xf32>
- gpu.return
-}
-
-// CHECK: func @test_load_nd_simt_9(%[[arg0:.*]]: memref<24x32xf32>) {
-gpu.func @test_load_nd_simt_9(%src: memref<24x32xf32>) {
- // CHECK: %[[R0:.*]] = xegpu.create_nd_tdesc %arg0[0, 0] : memref<24x32xf32> -> !xegpu.tensor_desc<8x16xf32, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>>
- %1 = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf32> -> !xegpu.tensor_desc<8x16xf32, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>>
- // CHECK: %[[R1:.*]] = xegpu.load_nd %[[R0]] <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>, transpose = array<i64: 1, 0>}> : !xegpu.tensor_desc<8x16xf32, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>> -> vector<8x1xf32>
- %2 = xegpu.load_nd %1 <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>, transpose = array<i64: 1, 0>}> : !xegpu.tensor_desc<8x16xf32, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>> -> vector<8x1xf32>
- gpu.return
-}
-
// CHECK: func @test_store_nd_vc(%[[arg0:.*]]: memref<24x32xf16>) {
gpu.func @test_store_nd_vc(%dst: memref<24x32xf16>) {
// CHECK: %[[C:.*]] = arith.constant dense<1.000000e+00> : vector<24x32xf16>