[Mlir-commits] [mlir] [mlir][xegpu] Improve XeGPU op verification logic for SIMT flavor and update tests. (PR #127920)

Charitha Saumya llvmlistbot at llvm.org
Thu Feb 20 11:28:50 PST 2025


https://github.com/charithaintc updated https://github.com/llvm/llvm-project/pull/127920

>From f92c151bf530c3b28574f63a389d1eea74c0cf4c Mon Sep 17 00:00:00 2001
From: Charitha Saumya <charitha.saumya.gusthinna.waduge at intel.com>
Date: Mon, 17 Feb 2025 22:46:26 +0000
Subject: [PATCH 1/8] save work

---
 .../mlir/Dialect/XeGPU/IR/XeGPUTypes.td       |   7 +-
 mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp    |  53 ++++++++
 mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp        | 118 +++++++++---------
 mlir/test/Dialect/XeGPU/XeGPUOps.mlir         |  84 ++++++++++---
 4 files changed, 185 insertions(+), 77 deletions(-)

diff --git a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td
index cc2e93fb19a70..ccd91a928e1dd 100644
--- a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td
+++ b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td
@@ -103,7 +103,7 @@ def XeGPU_TensorDesc: XeGPUTypeDef<"TensorDesc", "tensor_desc",
       CArg<"xegpu::MemorySpace", "xegpu::MemorySpace::Global">:$memory_space,
       CArg<"mlir::Attribute", "mlir::Attribute()">:$sg_map)>
   ];
-  
+
   let extraClassDeclaration = [{
     using TensorType::clone;
     using mlir::ShapedType::Trait<TensorDescType>::getElementTypeBitWidth;
@@ -176,6 +176,11 @@ def XeGPU_TensorDesc: XeGPUTypeDef<"TensorDesc", "tensor_desc",
         return scatter_attr.getChunkSize().getInt();
       return 1;
     }
+
+    // This returns a vector type that represents the fragment of data owned by
+    // a work item in SIMT mode if this tensor descriptor is used in a XeGPU
+    // load/store operation.
+    FailureOr<VectorType> getDistributedVectorType();
   }];
 
   let hasCustomAssemblyFormat = true;
diff --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
index 06fd03f3af3ad..768a3b4ef33c5 100644
--- a/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
+++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
@@ -8,8 +8,11 @@
 
 #include "mlir/Dialect/XeGPU/IR/XeGPU.h"
 #include "mlir/IR/Builders.h"
+#include "mlir/IR/BuiltinTypes.h"
 #include "mlir/IR/DialectImplementation.h"
 #include "llvm/ADT/TypeSwitch.h"
+#include "llvm/Support/Casting.h"
+#include "llvm/Support/LogicalResult.h"
 
 namespace mlir {
 namespace xegpu {
@@ -307,6 +310,56 @@ LogicalResult TensorDescType::verify(
   return success();
 }
 
+// Returns the per-work-item vector type of this tensor descriptor in SIMT mode
+// (i.e. when an sg_map attribute is present), computed as follows:
+//        wi_data_size = wi_data[0] × wi_data[1]
+//        subgroup_size = wi_layout[0] × wi_layout[1]
+//        distribution_unit_size = subgroup_size × wi_data_size
+//        tensor_size = tensor_desc[0] * .. * tensor_desc[r-1] * array_length
+//        n_distribution_units = tensor_size / distribution_unit_size
+// Returns failure() unless sg_map is set, encoding is not scatter, and:
+//        * tensor_desc[0] % (wi_layout[0] × wi_data[0]) == 0
+//        * tensor_desc[1] % (wi_layout[1] × wi_data[1]) == 0
+// Result: [n_distribution_units, wi_data_size] of the tdesc element type.
+FailureOr<VectorType> TensorDescType::getDistributedVectorType() {
+  auto sgMap = llvm::dyn_cast_if_present<SGMapAttr>(getSgMap());
+  // Without an sg_map the descriptor is not in SIMT mode; nothing to compute.
+  if (!sgMap)
+    return failure();
+  // FIXME: Scatter tensor descriptors are not handled yet; report failure.
+  auto scatterAttr =
+      llvm::dyn_cast_if_present<ScatterTensorDescAttr>(getEncoding());
+  if (scatterAttr)
+    return failure();
+
+  SmallVector<int64_t> wiData(sgMap.getWiData());
+  SmallVector<int64_t> wiLayout(sgMap.getWiLayout());
+  auto tdescShape = getShape();
+  // 1D descriptor: wi_data/wi_layout are indexed as 2D, so require unit outer
+  // dims and keep only the inner entries for the per-dimension loop below.
+  if (tdescShape.size() == 1) {
+    if (wiData[0] != 1 || wiLayout[0] != 1)
+      return failure();
+    wiData = {wiData[1]};
+    wiLayout = {wiLayout[1]};
+  }
+  // Every dim must divide evenly by wi_layout × wi_data; accumulate the sizes.
+  int64_t tensorSize = 1, sgSize = 1, wiDataSize = 1;
+  for (auto [tdescDim, wiDim, wiDataDim] :
+       llvm::zip_equal(tdescShape, wiLayout, wiData)) {
+    if (tdescDim % (wiDim * wiDataDim) != 0)
+      return failure();
+    tensorSize *= tdescDim;
+    sgSize *= wiDim;
+    wiDataSize *= wiDataDim;
+  }
+  // Per the formula above, tensor_size also includes the array_length factor.
+  tensorSize *= getArrayLength();
+
+  return VectorType::get({tensorSize / (sgSize * wiDataSize), wiDataSize},
+                         getElementType());
+}
+
 } // namespace xegpu
 } // namespace mlir
 
diff --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
index 25dc1f22f0432..a47332ce9eb0c 100644
--- a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
+++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
@@ -10,9 +10,12 @@
 #include "mlir/Dialect/Utils/StaticValueUtils.h"
 #include "mlir/Dialect/XeGPU/IR/XeGPU.h"
 #include "mlir/IR/Builders.h"
+#include "mlir/IR/BuiltinTypes.h"
 #include "mlir/IR/TypeUtilities.h"
+#include "mlir/Support/LLVM.h"
 
 #include "llvm/Support/Debug.h"
+#include "llvm/Support/LogicalResult.h"
 
 #define DEBUG_TYPE "xegpu"
 
@@ -73,43 +76,6 @@ static bool isWriteHintOrNone(const CachePolicyAttr &attr) {
          kind == CachePolicy::WRITE_BACK || kind == CachePolicy::WRITE_THROUGH;
 }
 
-// Validations for nd instruction arguments is successful if any of these are
-// true:
-// - tensor descriptor and the output vector shapes exactly match.
-// - tensor descriptor has a sg_map attribute and the distributed vector shape
-//   matches the tensor descriptor shape when scaled using sg_map factors on
-//   each dimension.
-static bool isArgShapesValid(ArrayRef<int64_t> descShape,
-                             ArrayRef<int64_t> valShape, SGMapAttr sgMap) {
-  // Equal shapes with no distribution - no further verification needed.
-  if (descShape == valShape && !sgMap)
-    return true;
-
-  // Unknown distribution - cannot perform operation on partial shape.
-  if (!sgMap)
-    return false;
-
-  // Invalid rank or mixed rank usage.
-  size_t descRank = descShape.size();
-  if (descRank > 2 || valShape.size() != descRank)
-    return false;
-
-  // For 1D, SG map is guaranteed to be unit size in the outer dimension.
-  // Only take the distribution over the innermost dimension for validation.
-  ArrayRef<uint32_t> wiLayout = sgMap.getWiLayout();
-  SmallVector<uint32_t> mapLayout(wiLayout.begin(), wiLayout.end());
-  if (descRank == 1)
-    mapLayout = {wiLayout.back()};
-
-  for (const auto &[factor, dim, expected] :
-       llvm::zip_equal(mapLayout, valShape, descShape)) {
-    if (factor * dim != expected)
-      return false;
-  }
-
-  return true;
-}
-
 //===----------------------------------------------------------------------===//
 // XeGPU_CreateNdDescOp
 //===----------------------------------------------------------------------===//
@@ -280,7 +246,8 @@ LogicalResult LoadNdOp::verify() {
     return emitOpError("invalid l3_hint: ") << getL3HintAttr();
 
   auto array_len = tdescTy.getArrayLength();
-  auto tdescShape = getShapeOf(tdescTy);
+  // adjusted tensor descriptor shape tracks the expected shape of the result.
+  auto adjustedTdescShape = getShapeOf(tdescTy);
   auto valueShape = getShapeOf(valueTy);
 
   if (getTranspose()) {
@@ -292,7 +259,7 @@ LogicalResult LoadNdOp::verify() {
     });
 
     if (valid)
-      transpose(trans, tdescShape);
+      transpose(trans, adjustedTdescShape);
     else
       mlir::emitWarning(getLoc()) << "Invalid transpose attr. It is ignored.";
   }
@@ -301,8 +268,8 @@ LogicalResult LoadNdOp::verify() {
     if (tdescTy.getRank() == 2) {
       const int axis = 0;
       auto vnni_factor = valueShape.back();
-      tdescShape[axis] /= vnni_factor;
-      tdescShape.push_back(vnni_factor);
+      adjustedTdescShape[axis] /= vnni_factor;
+      adjustedTdescShape.push_back(vnni_factor);
     } else {
       mlir::emitWarning(getLoc())
           << "Invalid Packed Attr. It is ignored (available for 2D "
@@ -311,17 +278,35 @@ LogicalResult LoadNdOp::verify() {
   }
 
   if (array_len > 1) {
-    auto it = tdescShape.begin();
-    tdescShape.insert(it, array_len);
+    auto it = adjustedTdescShape.begin();
+    adjustedTdescShape.insert(it, array_len);
   }
-  auto sgMap = tdescTy.getSGMapAttr();
 
-  if (!isArgShapesValid(tdescShape, valueShape, sgMap))
-    return emitOpError() << "Result shape doesn't match TensorDesc shape."
-                         << "The expected shape is " << makeString(tdescShape)
-                         << ". But the given shape is "
-                         << makeString(valueShape) << ".\n";
-  return success();
+  auto sgMap = tdescTy.getSGMapAttr();
+  // sg_map not present means IR is in VC mode. In this case value shape must
+  // match adjusted tensor descriptor shape.
+  if (!sgMap)
+    return valueShape == adjustedTdescShape
+               ? success()
+               : emitOpError()
+                     << "Result shape " << makeString(valueShape)
+                     << " is not consistent with tensor descripter " << tdescTy;
+
+  // sg_map present means IR is in SIMT mode. In this case sg_map determines the
+  // value shape.
+  auto expectedValueShapeOrFailure = tdescTy.getDistributedVectorType();
+  if (failed(expectedValueShapeOrFailure))
+    return emitOpError() << "Failed to compute distributed vector shape for "
+                            "tensor descriptor "
+                         << tdescTy;
+
+  return valueTy == expectedValueShapeOrFailure.value()
+             ? success()
+             : emitOpError()
+                   << "Result shape " << makeString(valueShape)
+                   << " is not consistent with distributed vector shape "
+                   << makeString(expectedValueShapeOrFailure.value().getShape())
+                   << " for tensor descriptor " << tdescTy;
 }
 
 //===----------------------------------------------------------------------===//
@@ -351,14 +336,33 @@ LogicalResult StoreNdOp::verify() {
 
   auto tdescShape = getShapeOf(dstTy);
   auto valueShape = getShapeOf(valTy);
-  auto sgMap = dstTy.getSGMapAttr();
 
-  if (!isArgShapesValid(tdescShape, valueShape, sgMap))
-    return emitOpError() << "Result shape doesn't match TensorDesc shape."
-                         << "The expected shape is " << makeString(tdescShape)
-                         << ". But the given shape is "
-                         << makeString(valueShape) << ".\n";
-  return success();
+  auto sgMap = dstTy.getSGMapAttr();
+  // sg_map not present means IR is in VC mode. In this case value shape must
+  // match adjusted tensor descriptor shape.
+  if (!sgMap)
+    return valueShape == tdescShape
+               ? success()
+               : emitOpError()
+                     << "Result shape " << makeString(valueShape)
+                     << " is not consistent with tensor descripter shape "
+                     << makeString(tdescShape);
+
+  // sg_map present means IR is in SIMT mode. In this case sg_map determines the
+  // value shape.
+  auto expectedValueShapeOrFailure = dstTy.getDistributedVectorType();
+  if (failed(expectedValueShapeOrFailure))
+    return emitOpError() << "Failed to compute distributed vector shape for "
+                            "tensor descriptor "
+                         << dstTy;
+
+  return valTy == expectedValueShapeOrFailure.value()
+             ? success()
+             : emitOpError()
+                   << "Result shape " << makeString(valueShape)
+                   << " is not consistent with distributed vector shape "
+                   << makeString(expectedValueShapeOrFailure.value().getShape())
+                   << " for tensor descriptor " << dstTy;
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/XeGPU/XeGPUOps.mlir b/mlir/test/Dialect/XeGPU/XeGPUOps.mlir
index 472176af72b19..131af2fbd75d5 100644
--- a/mlir/test/Dialect/XeGPU/XeGPUOps.mlir
+++ b/mlir/test/Dialect/XeGPU/XeGPUOps.mlir
@@ -13,8 +13,8 @@ gpu.func @test_create_nd_tdesc_vc_1(%src: memref<24x32xf32>) {
   gpu.return
 }
 
-// CHECK: gpu.func @test_create_nd_tdesc_with_sg_map(%[[arg0:.*]]: memref<24x32xf32>) {
-gpu.func @test_create_nd_tdesc_with_sg_map(%src: memref<24x32xf32>) {
+// CHECK: gpu.func @test_create_nd_tdesc_simt(%[[arg0:.*]]: memref<24x32xf32>) {
+gpu.func @test_create_nd_tdesc_simt(%src: memref<24x32xf32>) {
   // CHECK: %[[REG:.*]] = xegpu.create_nd_tdesc %arg0[0, 0] : memref<24x32xf32> -> !xegpu.tensor_desc<8x16xf32, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>>
   %1 = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf32> ->
     !xegpu.tensor_desc<8x16xf32, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>>
@@ -30,6 +30,15 @@ gpu.func @test_create_nd_tdesc_vc_2(%src: ui64, %w : index, %h : index, %x : ind
   gpu.return
 }
 
+// CHECK: gpu.func @test_create_nd_tdesc_simt_2(%[[arg0:.*]]: ui64, %[[arg1:.*]]: index, %[[arg2:.*]]: index, %[[arg3:.*]]: index, %[[arg4:.*]]: index) {
+gpu.func @test_create_nd_tdesc_simt_2(%src: ui64, %w : index, %h : index, %x : index, %y : index) {
+  //CHECK: %[[C:.*]] = arith.constant 1 : index
+  %c1 = arith.constant 1 : index
+  // CHECK: %[[REG:.*]] = xegpu.create_nd_tdesc %[[arg0]][%[[arg3]], %[[arg4]]], [%[[arg2]], %[[arg1]]], [%[[arg1]], %[[C]]] : ui64 -> !xegpu.tensor_desc<8x16xf32, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>>
+  %1 = xegpu.create_nd_tdesc %src[%x, %y], [%h, %w], [%w, %c1] : ui64 -> !xegpu.tensor_desc<8x16xf32, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>>
+  gpu.return
+}
+
 // CHECK: gpu.func @test_create_nd_tdesc_vc_3(%[[arg0:.*]]: memref<24x32xf32>) {
 gpu.func @test_create_nd_tdesc_vc_3(%src: memref<24x32xf32>) {
   // CHECK: %[[REG:.*]] = xegpu.create_nd_tdesc %[[arg0]][0, 0] : memref<24x32xf32> -> !xegpu.tensor_desc<24x16xf32, #xegpu.block_tdesc_attr<array_length = 2 : i64>
@@ -37,6 +46,13 @@ gpu.func @test_create_nd_tdesc_vc_3(%src: memref<24x32xf32>) {
   gpu.return
 }
 
+// CHECK: gpu.func @test_create_nd_tdesc_simt_3(%[[arg0:.*]]: memref<24x32xf32>) {
+gpu.func @test_create_nd_tdesc_simt_3(%src: memref<24x32xf32>) {
+  // CHECK: %[[REG:.*]] = xegpu.create_nd_tdesc %[[arg0]][0, 0] : memref<24x32xf32> -> !xegpu.tensor_desc<24x16xf32, #xegpu.block_tdesc_attr<array_length = 2 : i64>, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>>
+  %1 = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf32> -> !xegpu.tensor_desc<24x16xf32, #xegpu.block_tdesc_attr<array_length = 2>, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>>
+  gpu.return
+}
+
 // CHECK: gpu.func @test_create_nd_tdesc_vc_4(%[[arg0:.*]]: memref<2x24x32xf32>) {
 gpu.func @test_create_nd_tdesc_vc_4(%src: memref<2x24x32xf32>) {
   // CHECK: %[[REG:.*]] = xegpu.create_nd_tdesc %arg0[0, 0, 0] : memref<2x24x32xf32> -> !xegpu.tensor_desc<8x16xf32>
@@ -44,6 +60,13 @@ gpu.func @test_create_nd_tdesc_vc_4(%src: memref<2x24x32xf32>) {
   gpu.return
 }
 
+// CHECK: gpu.func @test_create_nd_tdesc_simt_4(%[[arg0:.*]]: memref<2x24x32xf32>) {
+gpu.func @test_create_nd_tdesc_simt_4(%src: memref<2x24x32xf32>) {
+  // CHECK: %[[REG:.*]] = xegpu.create_nd_tdesc %arg0[0, 0, 0] : memref<2x24x32xf32> -> !xegpu.tensor_desc<8x16xf32, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>>
+  %1 = xegpu.create_nd_tdesc %src[0, 0, 0] : memref<2x24x32xf32> -> !xegpu.tensor_desc<8x16xf32, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>>
+  gpu.return
+}
+
 // CHECK: gpu.func @test_create_nd_tdesc_vc_5(%[[arg0:.*]]: memref<2x24x32xf32, 3>) {
 gpu.func @test_create_nd_tdesc_vc_5(%src: memref<2x24x32xf32, 3>) {
   // CHECK: %[[REG:.*]] = xegpu.create_nd_tdesc %arg0[0, 0, 0] : memref<2x24x32xf32, 3> -> !xegpu.tensor_desc<16xf32, #xegpu.block_tdesc_attr<memory_space = slm>>
@@ -51,6 +74,13 @@ gpu.func @test_create_nd_tdesc_vc_5(%src: memref<2x24x32xf32, 3>) {
   gpu.return
 }
 
+// CHECK: gpu.func @test_create_nd_tdesc_simt_5(%[[arg0:.*]]: memref<2x24x32xf32, 3>) {
+gpu.func @test_create_nd_tdesc_simt_5(%src: memref<2x24x32xf32, 3>) {
+  // CHECK: %[[REG:.*]] = xegpu.create_nd_tdesc %arg0[0, 0, 0] : memref<2x24x32xf32, 3> -> !xegpu.tensor_desc<16xf32, #xegpu.block_tdesc_attr<memory_space = slm>, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>>
+  %1 = xegpu.create_nd_tdesc %src[0, 0, 0] : memref<2x24x32xf32, 3> -> !xegpu.tensor_desc<16xf32, #xegpu.block_tdesc_attr<memory_space = slm>, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>>
+  gpu.return
+}
+
 // CHECK: gpu.func @test_create_nd_tdesc_vc_6(%[[arg0:.*]]: memref<24x32xf32>) {
 gpu.func @test_create_nd_tdesc_vc_6(%src: memref<24x32xf32>) {
   // CHECK: %[[REG:.*]] = xegpu.create_nd_tdesc %[[arg0]][0, 0] : memref<24x32xf32> -> !xegpu.tensor_desc<24x16xf32, #xegpu.block_tdesc_attr<array_length = 2 : i64>, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>
@@ -58,6 +88,13 @@ gpu.func @test_create_nd_tdesc_vc_6(%src: memref<24x32xf32>) {
   gpu.return
 }
 
+// CHECK: gpu.func @test_create_nd_tdesc_simt_6(%[[arg0:.*]]: memref<24x32xf32>) {
+gpu.func @test_create_nd_tdesc_simt_6(%src: memref<24x32xf32>) {
+  // CHECK: %[[REG:.*]] = xegpu.create_nd_tdesc %[[arg0]][0, 0] : memref<24x32xf32> -> !xegpu.tensor_desc<24x16xf32, #xegpu.block_tdesc_attr<array_length = 2 : i64>, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>
+  %1 = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf32> -> !xegpu.tensor_desc<24x16xf32, #xegpu.block_tdesc_attr<array_length = 2>, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>>
+  gpu.return
+}
+
 // CHECK: gpu.func @test_prefetch_nd_vc(%[[arg0:.*]]: memref<24x32xf16>) {
 gpu.func @test_prefetch_nd_vc(%src: memref<24x32xf16>) {
   // CHECK: %[[R0:.*]] = xegpu.create_nd_tdesc %[[arg0]][0, 0] : memref<24x32xf16> -> !xegpu.tensor_desc<8x16xf16>
@@ -86,9 +123,8 @@ gpu.func @test_load_nd_vc_2(%src: memref<8x16xf16>) {
   gpu.return
 }
 
-// load_nd args may have different shapes, validated against sg_map
-// CHECK: func @test_load_nd_vc_3(%[[arg0:.*]]: memref<24x32xf32>) {
-gpu.func @test_load_nd_vc_3(%src: memref<24x32xf32>) {
+// CHECK: func @test_load_nd_simt(%[[arg0:.*]]: memref<24x32xf32>) {
+gpu.func @test_load_nd_simt(%src: memref<24x32xf32>) {
   // CHECK: %[[R0:.*]] = xegpu.create_nd_tdesc %arg0[0, 0] : memref<24x32xf32> -> !xegpu.tensor_desc<8x16xf32, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>>
   %1 = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf32> ->
     !xegpu.tensor_desc<8x16xf32, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>>
@@ -97,13 +133,23 @@ gpu.func @test_load_nd_vc_3(%src: memref<24x32xf32>) {
   gpu.return
 }
 
+// CHECK: func @test_load_nd_simt_2(%[[arg0:.*]]: memref<24x32xf16>) {
+gpu.func @test_load_nd_simt_2(%src: memref<24x32xf16>) {
+  // CHECK: %[[R0:.*]] = xegpu.create_nd_tdesc %arg0[0, 0] : memref<24x32xf16> -> !xegpu.tensor_desc<16x16xf16, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [2, 1]>>
+  %1 = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf16> ->
+    !xegpu.tensor_desc<16x16xf16, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [2, 1]>>
+  // CHECK: %[[R1:.*]] = xegpu.load_nd %[[R0]] <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}> : !xegpu.tensor_desc<16x16xf16, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [2, 1]>> -> vector<8x2xf16>
+  %2 = xegpu.load_nd %1 <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}> : !xegpu.tensor_desc<16x16xf16, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [2, 1]>> -> vector<8x2xf16>
+  gpu.return
+}
+
 // CHECK: func @test_load_nd_vc_4(%[[arg0:.*]]: memref<24x32xf32>) {
 gpu.func @test_load_nd_vc_4(%src: memref<24x32xf32>) {
   // CHECK: %[[R0:.*]] = xegpu.create_nd_tdesc %arg0[0, 0] : memref<24x32xf32> -> !xegpu.tensor_desc<32xf32, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>>
   %1 = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf32> ->
     !xegpu.tensor_desc<32xf32, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>>
-  // CHECK: %[[R1:.*]] = xegpu.load_nd %[[R0]] <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}> : !xegpu.tensor_desc<32xf32, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>> -> vector<2xf32>
-  %2 = xegpu.load_nd %1 <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}> : !xegpu.tensor_desc<32xf32, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>> -> vector<2xf32>
+  // CHECK: %[[R1:.*]] = xegpu.load_nd %[[R0]] <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}> : !xegpu.tensor_desc<32xf32, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>> -> vector<2x1xf32>
+  %2 = xegpu.load_nd %1 <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}> : !xegpu.tensor_desc<32xf32, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>> -> vector<2x1xf32>
   gpu.return
 }
 
@@ -132,25 +178,25 @@ gpu.func @test_store_nd_vc_2(%dst: memref<24x32xf16>) {
 // store_nd args may have different shapes, validated against sg_map
 // CHECK: func @test_store_nd_vc_3(%[[arg0:.*]]: memref<24x32xf16>) {
 gpu.func @test_store_nd_vc_3(%src: memref<24x32xf16>) {
-   // CHECK: %[[C:.*]] = arith.constant dense<1.000000e+00> : vector<24x2xf16>
-  %1 = arith.constant dense<1.0>: vector<24x2xf16>
+   // CHECK: %[[C:.*]] = arith.constant dense<1.000000e+00> : vector<48x1xf16>
+  %1 = arith.constant dense<1.0>: vector<48x1xf16>
   // CHECK: %[[R0:.*]] = xegpu.create_nd_tdesc %arg0[0, 0] : memref<24x32xf16> -> !xegpu.tensor_desc<24x32xf16, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>>
   %2 = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf16> ->
     !xegpu.tensor_desc<24x32xf16, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>>
-  // CHECK: xegpu.store_nd %[[C]], %[[R0]] <{l1_hint = #xegpu.cache_hint<write_back>, l2_hint = #xegpu.cache_hint<uncached>}> : vector<24x2xf16>, !xegpu.tensor_desc<24x32xf16, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>>
-  xegpu.store_nd %1, %2 <{l1_hint = #xegpu.cache_hint<write_back>, l2_hint = #xegpu.cache_hint<uncached>}>: vector<24x2xf16>, !xegpu.tensor_desc<24x32xf16, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>>
+  // CHECK: xegpu.store_nd %[[C]], %[[R0]] <{l1_hint = #xegpu.cache_hint<write_back>, l2_hint = #xegpu.cache_hint<uncached>}> : vector<48x1xf16>, !xegpu.tensor_desc<24x32xf16, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>>
+  xegpu.store_nd %1, %2 <{l1_hint = #xegpu.cache_hint<write_back>, l2_hint = #xegpu.cache_hint<uncached>}>: vector<48x1xf16>, !xegpu.tensor_desc<24x32xf16, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>>
   gpu.return
 }
 
 // CHECK: func @test_store_nd_vc_4(%[[arg0:.*]]: memref<24x32xf16>) {
 gpu.func @test_store_nd_vc_4(%src: memref<24x32xf16>) {
-   // CHECK: %[[C:.*]] = arith.constant dense<1.000000e+00> : vector<2xf16>
-  %1 = arith.constant dense<1.0>: vector<2xf16>
+   // CHECK: %[[C:.*]] = arith.constant dense<1.000000e+00> : vector<2x1xf16>
+  %1 = arith.constant dense<1.0>: vector<2x1xf16>
   // CHECK: %[[R0:.*]] = xegpu.create_nd_tdesc %arg0[0, 0] : memref<24x32xf16> -> !xegpu.tensor_desc<32xf16, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>>
   %2 = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf16> ->
     !xegpu.tensor_desc<32xf16, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>>
-  // CHECK: xegpu.store_nd %[[C]], %[[R0]] <{l1_hint = #xegpu.cache_hint<write_back>, l2_hint = #xegpu.cache_hint<uncached>}> : vector<2xf16>, !xegpu.tensor_desc<32xf16, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>>
-  xegpu.store_nd %1, %2 <{l1_hint = #xegpu.cache_hint<write_back>, l2_hint = #xegpu.cache_hint<uncached>}>: vector<2xf16>, !xegpu.tensor_desc<32xf16, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>>
+  // CHECK: xegpu.store_nd %[[C]], %[[R0]] <{l1_hint = #xegpu.cache_hint<write_back>, l2_hint = #xegpu.cache_hint<uncached>}> : vector<2x1xf16>, !xegpu.tensor_desc<32xf16, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>>
+  xegpu.store_nd %1, %2 <{l1_hint = #xegpu.cache_hint<write_back>, l2_hint = #xegpu.cache_hint<uncached>}>: vector<2x1xf16>, !xegpu.tensor_desc<32xf16, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>>
   gpu.return
 }
 
@@ -207,7 +253,7 @@ gpu.func @test_load_with_sg_map(%src: ui64) {
   %1 = arith.constant dense<1>: vector<4xi1>
   //CHECK: %[[R0:.*]] = xegpu.create_tdesc %[[arg0]], %[[cst]] : ui64, vector<4xindex> -> !xegpu.tensor_desc<4x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2 : i64>, #xegpu.sg_map<wi_layout = [4, 1], wi_data = [1, 2]>>
   %2 = xegpu.create_tdesc %src, %0 : ui64, vector<4xindex> -> !xegpu.tensor_desc<4x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2>, #xegpu.sg_map<wi_layout = [4, 1], wi_data = [1, 2]>>
-  //CHECK: %[[R1:.*]] = xegpu.load %[[R0]], %[[cst1]] <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>, transpose}> : !xegpu.tensor_desc<4x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2 : i64>, #xegpu.sg_map<wi_layout = [4, 1], wi_data = [1, 2]>>, vector<4xi1> -> vector<2x1xf32> 
+  //CHECK: %[[R1:.*]] = xegpu.load %[[R0]], %[[cst1]] <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>, transpose}> : !xegpu.tensor_desc<4x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2 : i64>, #xegpu.sg_map<wi_layout = [4, 1], wi_data = [1, 2]>>, vector<4xi1> -> vector<2x1xf32>
   %3 = xegpu.load %2, %1 <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>, transpose}> : !xegpu.tensor_desc<4x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2>, #xegpu.sg_map<wi_layout = [4, 1], wi_data = [1, 2]>>, vector<4xi1> -> vector<2x1xf32>
   gpu.return
 }
@@ -220,7 +266,7 @@ gpu.func @test_load_with_sg_map_2(%src: ui64) {
   %1 = arith.constant dense<1>: vector<4xi1>
   //CHECK: %[[R0:.*]] = xegpu.create_tdesc %[[arg0]], %[[cst]] : ui64, vector<4xindex> -> !xegpu.tensor_desc<4xf32, #xegpu.scatter_tdesc_attr<>, #xegpu.sg_map<wi_layout = [1, 4], wi_data = [1, 1]>>
   %2 = xegpu.create_tdesc %src, %0 : ui64, vector<4xindex> -> !xegpu.tensor_desc<4xf32, #xegpu.scatter_tdesc_attr<>, #xegpu.sg_map<wi_layout = [1, 4], wi_data = [1, 1]>>
-  //CHECK: %[[R1:.*]] = xegpu.load %[[R0]], %[[cst1]] <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}> : !xegpu.tensor_desc<4xf32, #xegpu.scatter_tdesc_attr<>, #xegpu.sg_map<wi_layout = [1, 4], wi_data = [1, 1]>>, vector<4xi1> -> vector<1xf32> 
+  //CHECK: %[[R1:.*]] = xegpu.load %[[R0]], %[[cst1]] <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}> : !xegpu.tensor_desc<4xf32, #xegpu.scatter_tdesc_attr<>, #xegpu.sg_map<wi_layout = [1, 4], wi_data = [1, 1]>>, vector<4xi1> -> vector<1xf32>
   %3 = xegpu.load %2, %1 <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}> : !xegpu.tensor_desc<4xf32, #xegpu.scatter_tdesc_attr<>, #xegpu.sg_map<wi_layout = [1, 4], wi_data = [1, 1]>>, vector<4xi1> -> vector<1xf32>
   gpu.return
 }
@@ -233,7 +279,7 @@ gpu.func @test_store_with_sg_map(%src: ui64) {
   %1 = arith.constant dense<1>: vector<4xi1>
   //CHECK: %[[cst2:.*]] = arith.constant dense<2.900000e+00> : vector<2x1xf32>
   %2 = arith.constant dense<2.9>: vector<2x1xf32>
-  //CHECK: %[[R0:.*]] = xegpu.create_tdesc %[[arg0]], %[[cst]] : ui64, vector<4xindex> -> !xegpu.tensor_desc<4x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2 : i64>, #xegpu.sg_map<wi_layout = [4, 1], wi_data = [1, 2]>> 
+  //CHECK: %[[R0:.*]] = xegpu.create_tdesc %[[arg0]], %[[cst]] : ui64, vector<4xindex> -> !xegpu.tensor_desc<4x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2 : i64>, #xegpu.sg_map<wi_layout = [4, 1], wi_data = [1, 2]>>
   %3 = xegpu.create_tdesc %src, %0 : ui64, vector<4xindex> -> !xegpu.tensor_desc<4x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2>, #xegpu.sg_map<wi_layout = [4, 1], wi_data = [1, 2]>>
   //CHECK: xegpu.store %[[cst2]], %[[R0]], %[[cst1]] <{l1_hint = #xegpu.cache_hint<write_back>, l2_hint = #xegpu.cache_hint<uncached>, transpose}> : vector<2x1xf32>, !xegpu.tensor_desc<4x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2 : i64>, #xegpu.sg_map<wi_layout = [4, 1], wi_data = [1, 2]>>, vector<4xi1>
   xegpu.store %2, %3, %1 <{l1_hint = #xegpu.cache_hint<write_back>, l2_hint = #xegpu.cache_hint<uncached>, transpose}> : vector<2x1xf32>, !xegpu.tensor_desc<4x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2>, #xegpu.sg_map<wi_layout = [4, 1], wi_data = [1, 2]>>, vector<4xi1>
@@ -248,7 +294,7 @@ gpu.func @test_store_with_sg_map_2(%src: ui64) {
   %1 = arith.constant dense<1>: vector<4xi1>
   //CHECK: %[[cst2:.*]] = arith.constant dense<2.900000e+00> : vector<1xf32>
   %2 = arith.constant dense<2.9>: vector<1xf32>
-  //CHECK: %[[R0:.*]] = xegpu.create_tdesc %[[arg0]], %[[cst]] : ui64, vector<4xindex> -> !xegpu.tensor_desc<4xf32, #xegpu.scatter_tdesc_attr<>, #xegpu.sg_map<wi_layout = [1, 4], wi_data = [1, 1]>> 
+  //CHECK: %[[R0:.*]] = xegpu.create_tdesc %[[arg0]], %[[cst]] : ui64, vector<4xindex> -> !xegpu.tensor_desc<4xf32, #xegpu.scatter_tdesc_attr<>, #xegpu.sg_map<wi_layout = [1, 4], wi_data = [1, 1]>>
   %3 = xegpu.create_tdesc %src, %0 : ui64, vector<4xindex> -> !xegpu.tensor_desc<4xf32, #xegpu.scatter_tdesc_attr<>, #xegpu.sg_map<wi_layout = [1, 4], wi_data = [1, 1]>>
   //CHECK: xegpu.store %[[cst2]], %[[R0]], %[[cst1]] <{l1_hint = #xegpu.cache_hint<write_back>, l2_hint = #xegpu.cache_hint<uncached>}> : vector<1xf32>, !xegpu.tensor_desc<4xf32, #xegpu.scatter_tdesc_attr<>, #xegpu.sg_map<wi_layout = [1, 4], wi_data = [1, 1]>>, vector<4xi1>
   xegpu.store %2, %3, %1 <{l1_hint = #xegpu.cache_hint<write_back>, l2_hint = #xegpu.cache_hint<uncached>}> : vector<1xf32>, !xegpu.tensor_desc<4xf32, #xegpu.scatter_tdesc_attr<>, #xegpu.sg_map<wi_layout = [1, 4], wi_data = [1, 1]>>, vector<4xi1>

>From 061ace2db7f782c47cc365222de800dbe8e6fbcb Mon Sep 17 00:00:00 2001
From: Charitha Saumya <charitha.saumya.gusthinna.waduge at intel.com>
Date: Tue, 18 Feb 2025 19:40:14 +0000
Subject: [PATCH 2/8] fix invalid tests

---
 mlir/test/Dialect/XeGPU/invalid.mlir | 12 ++++++------
 1 file changed, 6 insertions(+), 6 deletions(-)

diff --git a/mlir/test/Dialect/XeGPU/invalid.mlir b/mlir/test/Dialect/XeGPU/invalid.mlir
index 86356e09de57c..1ea452119fec9 100644
--- a/mlir/test/Dialect/XeGPU/invalid.mlir
+++ b/mlir/test/Dialect/XeGPU/invalid.mlir
@@ -81,7 +81,7 @@ func.func @test_load_nd_vc_3(%src: memref<8x16xf16>) {
 func.func @test_load_nd_vc_4(%src: memref<24x32xf32>) {
   %1 = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf32> ->
     !xegpu.tensor_desc<8x16xf32, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>>
-  // expected-error at +1 {{Result shape doesn't match TensorDesc shape.}}
+  // expected-error at +1 {{Result shape [8, 2] is not consistent with distributed vector shape [8, 1]}}
   %2 = xegpu.load_nd %1 <{l1_hint = #xegpu.cache_hint<cached>,
       l2_hint = #xegpu.cache_hint<uncached>}>
     : !xegpu.tensor_desc<8x16xf32, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>>
@@ -93,7 +93,7 @@ func.func @test_load_nd_vc_4(%src: memref<24x32xf32>) {
 func.func @test_load_nd_vc_5(%src: memref<24x32xf32>) {
   %1 = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf32> ->
     !xegpu.tensor_desc<16xf32, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>>
-  // expected-error at +1 {{Result shape doesn't match TensorDesc shape.}}
+  // expected-error at +1 {{Result shape [8] is not consistent with distributed vector shape [1, 1]}}
   %2 = xegpu.load_nd %1 <{l1_hint = #xegpu.cache_hint<cached>,
       l2_hint = #xegpu.cache_hint<uncached>}>
     : !xegpu.tensor_desc<16xf32, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>>
@@ -105,7 +105,7 @@ func.func @test_load_nd_vc_5(%src: memref<24x32xf32>) {
 func.func @test_load_nd_vc_6(%src: memref<24x32xf32>) {
   %1 = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf32> ->
     !xegpu.tensor_desc<8x16xf32>
-  // expected-error at +1 {{Result shape doesn't match TensorDesc shape.}}
+  // expected-error at +1 {{Result shape [8, 1] is not consistent with tensor descripter}}
   %2 = xegpu.load_nd %1 <{l1_hint = #xegpu.cache_hint<cached>,
       l2_hint = #xegpu.cache_hint<uncached>}>
     : !xegpu.tensor_desc<8x16xf32> -> vector<8x1xf32>
@@ -137,7 +137,7 @@ func.func @test_store_nd_vc_2(%dst: memref<16xf16>) {
 func.func @test_store_nd_vc_3(%dst: memref<24x32xf32>, %data: vector<8x2xf32>) {
   %1 = xegpu.create_nd_tdesc %dst[0, 0] : memref<24x32xf32> ->
     !xegpu.tensor_desc<8x16xf32, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>>
-  // expected-error at +1 {{Result shape doesn't match TensorDesc shape.}}
+  // expected-error at +1 {{Result shape [8, 2] is not consistent with distributed vector shape [8, 1] for tensor descriptor}}
   xegpu.store_nd %data, %1
     : vector<8x2xf32>, !xegpu.tensor_desc<8x16xf32, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>>
   return
@@ -147,7 +147,7 @@ func.func @test_store_nd_vc_3(%dst: memref<24x32xf32>, %data: vector<8x2xf32>) {
 func.func @test_store_nd_vc_4(%dst: memref<24x32xf32>, %data: vector<2xf32>) {
   %1 = xegpu.create_nd_tdesc %dst[0, 0] : memref<24x32xf32> ->
     !xegpu.tensor_desc<16xf32, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>>
-  // expected-error at +1 {{Result shape doesn't match TensorDesc shape.}}
+  // expected-error at +1 {{Result shape [2] is not consistent with distributed vector shape [1, 1] for tensor descriptor}}
   xegpu.store_nd %data, %1
     : vector<2xf32>, !xegpu.tensor_desc<16xf32, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>>
   return
@@ -157,7 +157,7 @@ func.func @test_store_nd_vc_4(%dst: memref<24x32xf32>, %data: vector<2xf32>) {
 func.func @test_store_nd_vc_5(%dst: memref<24x32xf32>, %data: vector<8x1xf32>) {
   %1 = xegpu.create_nd_tdesc %dst[0, 0] : memref<24x32xf32> ->
     !xegpu.tensor_desc<8x16xf32>
-  // expected-error at +1 {{Result shape doesn't match TensorDesc shape.}}
+  // expected-error at +1 {{Result shape [8, 1] is not consistent with tensor descripter shape [8, 16]}}
   xegpu.store_nd %data, %1 : vector<8x1xf32>, !xegpu.tensor_desc<8x16xf32>
   return
 }

>From bd2c8be5a732464f5da1b436278352ab34bbf4de Mon Sep 17 00:00:00 2001
From: Charitha Saumya <charitha.saumya.gusthinna.waduge at intel.com>
Date: Tue, 18 Feb 2025 22:56:16 +0000
Subject: [PATCH 3/8] save work

---
 .../Dialect/XeGPU/{XeGPUOps.mlir => ops.mlir} | 149 ++++++++++++++++--
 1 file changed, 139 insertions(+), 10 deletions(-)
 rename mlir/test/Dialect/XeGPU/{XeGPUOps.mlir => ops.mlir} (70%)

diff --git a/mlir/test/Dialect/XeGPU/XeGPUOps.mlir b/mlir/test/Dialect/XeGPU/ops.mlir
similarity index 70%
rename from mlir/test/Dialect/XeGPU/XeGPUOps.mlir
rename to mlir/test/Dialect/XeGPU/ops.mlir
index 131af2fbd75d5..aa4869afa2d86 100644
--- a/mlir/test/Dialect/XeGPU/XeGPUOps.mlir
+++ b/mlir/test/Dialect/XeGPU/ops.mlir
@@ -13,8 +13,8 @@ gpu.func @test_create_nd_tdesc_vc_1(%src: memref<24x32xf32>) {
   gpu.return
 }
 
-// CHECK: gpu.func @test_create_nd_tdesc_simt(%[[arg0:.*]]: memref<24x32xf32>) {
-gpu.func @test_create_nd_tdesc_simt(%src: memref<24x32xf32>) {
+// CHECK: gpu.func @test_create_nd_tdesc_simt_1(%[[arg0:.*]]: memref<24x32xf32>) {
+gpu.func @test_create_nd_tdesc_simt_1(%src: memref<24x32xf32>) {
   // CHECK: %[[REG:.*]] = xegpu.create_nd_tdesc %arg0[0, 0] : memref<24x32xf32> -> !xegpu.tensor_desc<8x16xf32, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>>
   %1 = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf32> ->
     !xegpu.tensor_desc<8x16xf32, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>>
@@ -83,8 +83,8 @@ gpu.func @test_create_nd_tdesc_simt_5(%src: memref<2x24x32xf32, 3>) {
 
 // CHECK: gpu.func @test_create_nd_tdesc_vc_6(%[[arg0:.*]]: memref<24x32xf32>) {
 gpu.func @test_create_nd_tdesc_vc_6(%src: memref<24x32xf32>) {
-  // CHECK: %[[REG:.*]] = xegpu.create_nd_tdesc %[[arg0]][0, 0] : memref<24x32xf32> -> !xegpu.tensor_desc<24x16xf32, #xegpu.block_tdesc_attr<array_length = 2 : i64>, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>
-  %1 = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf32> -> !xegpu.tensor_desc<24x16xf32, #xegpu.block_tdesc_attr<array_length = 2>, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>>
+  // CHECK: %[[REG:.*]] = xegpu.create_nd_tdesc %[[arg0]][0, 0] : memref<24x32xf32> -> !xegpu.tensor_desc<24x16xf32, #xegpu.block_tdesc_attr<array_length = 2 : i64>>
+  %1 = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf32> -> !xegpu.tensor_desc<24x16xf32, #xegpu.block_tdesc_attr<array_length = 2>>
   gpu.return
 }
 
@@ -104,6 +104,15 @@ gpu.func @test_prefetch_nd_vc(%src: memref<24x32xf16>) {
   gpu.return
 }
 
+// CHECK: gpu.func @test_prefetch_nd_simt(%[[arg0:.*]]: memref<24x32xf16>) {
+gpu.func @test_prefetch_nd_simt(%src: memref<24x32xf16>) {
+  // CHECK: %[[R0:.*]] = xegpu.create_nd_tdesc %[[arg0]][0, 0] : memref<24x32xf16> -> !xegpu.tensor_desc<8x16xf16, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>>
+  %1 = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf16> -> !xegpu.tensor_desc<8x16xf16, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>>
+  // CHECK: xegpu.prefetch_nd %[[R0]] <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}> : !xegpu.tensor_desc<8x16xf16, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>>
+  xegpu.prefetch_nd %1 <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}>: !xegpu.tensor_desc<8x16xf16, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>>
+  gpu.return
+}
+
 // CHECK: func @test_load_nd_vc(%[[arg0:.*]]: memref<8x16xf16>) {
 gpu.func @test_load_nd_vc(%src: memref<8x16xf16>) {
   // CHECK: %[[R0:.*]] = xegpu.create_nd_tdesc %arg0[0, 0] : memref<8x16xf16> -> !xegpu.tensor_desc<8x16xf16>
@@ -114,6 +123,16 @@ gpu.func @test_load_nd_vc(%src: memref<8x16xf16>) {
   gpu.return
 }
 
+// CHECK: func @test_load_nd_simt(%[[arg0:.*]]: memref<8x16xf16>) {
+gpu.func @test_load_nd_simt(%src: memref<8x16xf16>) {
+  // CHECK: %[[R0:.*]] = xegpu.create_nd_tdesc %arg0[0, 0] : memref<8x16xf16> -> !xegpu.tensor_desc<8x16xf16, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [2, 1]>>
+  %1 = xegpu.create_nd_tdesc %src[0, 0] : memref<8x16xf16> -> !xegpu.tensor_desc<8x16xf16, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [2, 1]>>
+  // CHECK: %[[R1:.*]] = xegpu.load_nd %[[R0]] <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>, packed}> : !xegpu.tensor_desc<8x16xf16, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [2, 1]>> -> vector<4x2xf16>
+  %2 = xegpu.load_nd %1 <{packed, l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}>
+       : !xegpu.tensor_desc<8x16xf16, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [2, 1]>> -> vector<4x2xf16>
+  gpu.return
+}
+
 // CHECK: func @test_load_nd_vc_2(%[[arg0:.*]]: memref<8x16xf16>) {
 gpu.func @test_load_nd_vc_2(%src: memref<8x16xf16>) {
   // CHECK: %[[R0:.*]] = xegpu.create_nd_tdesc %arg0[0, 0] : memref<8x16xf16> -> !xegpu.tensor_desc<16xf16>
@@ -123,8 +142,26 @@ gpu.func @test_load_nd_vc_2(%src: memref<8x16xf16>) {
   gpu.return
 }
 
-// CHECK: func @test_load_nd_simt(%[[arg0:.*]]: memref<24x32xf32>) {
-gpu.func @test_load_nd_simt(%src: memref<24x32xf32>) {
+// CHECK: func @test_load_nd_simt_2(%[[arg0:.*]]: memref<8x16xf16>) {
+gpu.func @test_load_nd_simt_2(%src: memref<8x16xf16>) {
+  // CHECK: %[[R0:.*]] = xegpu.create_nd_tdesc %arg0[0, 0] : memref<8x16xf16> -> !xegpu.tensor_desc<16xf16, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>>
+  %1 = xegpu.create_nd_tdesc %src[0, 0] : memref<8x16xf16> -> !xegpu.tensor_desc<16xf16, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>>
+  // CHECK: %[[R1:.*]] = xegpu.load_nd %[[R0]] <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}> : !xegpu.tensor_desc<16xf16, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>> -> vector<1x1xf16>
+  %2 = xegpu.load_nd %1 <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}> : !xegpu.tensor_desc<16xf16, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>> -> vector<1x1xf16>
+  gpu.return
+}
+
+// CHECK: func @test_load_nd_vc_3(%[[arg0:.*]]: memref<24x32xf32>) {
+gpu.func @test_load_nd_vc_3(%src: memref<24x32xf32>) {
+  // CHECK: %[[R0:.*]] = xegpu.create_nd_tdesc %arg0[0, 0] : memref<24x32xf32> -> !xegpu.tensor_desc<8x16xf32>
+  %1 = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf32> -> !xegpu.tensor_desc<8x16xf32>
+  // CHECK: %[[R1:.*]] = xegpu.load_nd %[[R0]] <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}> : !xegpu.tensor_desc<8x16xf32> -> vector<8x16xf32>
+  %2 = xegpu.load_nd %1 <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}> : !xegpu.tensor_desc<8x16xf32> -> vector<8x16xf32>
+  gpu.return
+}
+
+// CHECK: func @test_load_nd_simt_3(%[[arg0:.*]]: memref<24x32xf32>) {
+gpu.func @test_load_nd_simt_3(%src: memref<24x32xf32>) {
   // CHECK: %[[R0:.*]] = xegpu.create_nd_tdesc %arg0[0, 0] : memref<24x32xf32> -> !xegpu.tensor_desc<8x16xf32, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>>
   %1 = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf32> ->
     !xegpu.tensor_desc<8x16xf32, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>>
@@ -133,8 +170,17 @@ gpu.func @test_load_nd_simt(%src: memref<24x32xf32>) {
   gpu.return
 }
 
-// CHECK: func @test_load_nd_simt_2(%[[arg0:.*]]: memref<24x32xf16>) {
-gpu.func @test_load_nd_simt_2(%src: memref<24x32xf16>) {
+// CHECK: func @test_load_nd_vc_4(%[[arg0:.*]]: memref<24x32xf16>) {
+gpu.func @test_load_nd_vc_4(%src: memref<24x32xf16>) {
+  // CHECK: %[[R0:.*]] = xegpu.create_nd_tdesc %arg0[0, 0] : memref<24x32xf16> -> !xegpu.tensor_desc<16x16xf16>
+  %1 = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf16> -> !xegpu.tensor_desc<16x16xf16>
+  // CHECK: %[[R1:.*]] = xegpu.load_nd %[[R0]] <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>, packed}> : !xegpu.tensor_desc<16x16xf16> -> vector<8x16x2xf16>
+  %2 = xegpu.load_nd %1 <{packed, l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}> : !xegpu.tensor_desc<16x16xf16> -> vector<8x16x2xf16>
+  gpu.return
+}
+
+// CHECK: func @test_load_nd_simt_4(%[[arg0:.*]]: memref<24x32xf16>) {
+gpu.func @test_load_nd_simt_4(%src: memref<24x32xf16>) {
   // CHECK: %[[R0:.*]] = xegpu.create_nd_tdesc %arg0[0, 0] : memref<24x32xf16> -> !xegpu.tensor_desc<16x16xf16, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [2, 1]>>
   %1 = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf16> ->
     !xegpu.tensor_desc<16x16xf16, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [2, 1]>>
@@ -143,8 +189,17 @@ gpu.func @test_load_nd_simt_2(%src: memref<24x32xf16>) {
   gpu.return
 }
 
-// CHECK: func @test_load_nd_vc_4(%[[arg0:.*]]: memref<24x32xf32>) {
-gpu.func @test_load_nd_vc_4(%src: memref<24x32xf32>) {
+// CHECK: func @test_load_nd_vc_5(%[[arg0:.*]]: memref<24x32xf32>) {
+gpu.func @test_load_nd_vc_5(%src: memref<24x32xf32>) {
+  // CHECK: %[[R0:.*]] = xegpu.create_nd_tdesc %arg0[0, 0] : memref<24x32xf32> -> !xegpu.tensor_desc<32xf32>
+  %1 = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf32> -> !xegpu.tensor_desc<32xf32>
+  // CHECK: %[[R1:.*]] = xegpu.load_nd %[[R0]] <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}> : !xegpu.tensor_desc<32xf32> -> vector<32xf32>
+  %2 = xegpu.load_nd %1 <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}> : !xegpu.tensor_desc<32xf32> -> vector<32xf32>
+  gpu.return
+}
+
+// CHECK: func @test_load_nd_simt_5(%[[arg0:.*]]: memref<24x32xf32>) {
+gpu.func @test_load_nd_simt_5(%src: memref<24x32xf32>) {
   // CHECK: %[[R0:.*]] = xegpu.create_nd_tdesc %arg0[0, 0] : memref<24x32xf32> -> !xegpu.tensor_desc<32xf32, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>>
   %1 = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf32> ->
     !xegpu.tensor_desc<32xf32, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>>
@@ -153,6 +208,80 @@ gpu.func @test_load_nd_vc_4(%src: memref<24x32xf32>) {
   gpu.return
 }
 
+// CHECK: func @test_load_nd_vc_6(%[[arg0:.*]]: memref<24x32xf16>) {
+gpu.func @test_load_nd_vc_6(%src: memref<24x32xf16>) {
+  // CHECK: %[[R0:.*]] = xegpu.create_nd_tdesc %arg0[0, 0] : memref<24x32xf16> -> !xegpu.tensor_desc<16x16xf16, #xegpu.block_tdesc_attr<array_length = 2 : i64>>
+  %1 = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf16> -> !xegpu.tensor_desc<16x16xf16, #xegpu.block_tdesc_attr<array_length = 2>>
+  // CHECK: %[[R1:.*]] = xegpu.load_nd %[[R0]] <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}> : !xegpu.tensor_desc<16x16xf16, #xegpu.block_tdesc_attr<array_length = 2 : i64>> -> vector<2x16x16xf16>
+  %2 = xegpu.load_nd %1 <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}> : !xegpu.tensor_desc<16x16xf16, #xegpu.block_tdesc_attr<array_length = 2>> -> vector<2x16x16xf16>
+  gpu.return
+}
+
+// CHECK: func @test_load_nd_simt_6(%[[arg0:.*]]: memref<24x32xf16>) {
+gpu.func @test_load_nd_simt_6(%src: memref<24x32xf16>) {
+  // CHECK: %[[R0:.*]] = xegpu.create_nd_tdesc %arg0[0, 0] : memref<24x32xf16> -> !xegpu.tensor_desc<16x16xf16, #xegpu.block_tdesc_attr<array_length = 2 : i64>, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>>
+  %1 = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf16> -> !xegpu.tensor_desc<16x16xf16, #xegpu.block_tdesc_attr<array_length = 2>, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>>
+  // CHECK: %[[R1:.*]] = xegpu.load_nd %[[R0]] <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}> : !xegpu.tensor_desc<16x16xf16, #xegpu.block_tdesc_attr<array_length = 2 : i64>, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>> -> vector<32x1xf16>
+  %2 = xegpu.load_nd %1 <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}> :
+    !xegpu.tensor_desc<16x16xf16, #xegpu.block_tdesc_attr<array_length = 2>, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>> -> vector<32x1xf16>
+  gpu.return
+}
+
+// CHECK: func @test_load_nd_vc_7(%[[arg0:.*]]: memref<24x32xf16>) {
+gpu.func @test_load_nd_vc_7(%src: memref<24x32xf16>) {
+  // CHECK: %[[R0:.*]] = xegpu.create_nd_tdesc %arg0[0, 0] : memref<24x32xf16> -> !xegpu.tensor_desc<16x16xf16, #xegpu.block_tdesc_attr<array_length = 2 : i64>>
+  %1 = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf16> -> !xegpu.tensor_desc<16x16xf16, #xegpu.block_tdesc_attr<array_length = 2>>
+  // CHECK: %[[R1:.*]] = xegpu.load_nd %[[R0]] <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>, packed}> : !xegpu.tensor_desc<16x16xf16, #xegpu.block_tdesc_attr<array_length = 2 : i64>> -> vector<2x8x16x2xf16>
+  %2 = xegpu.load_nd %1 <{packed, l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}> : !xegpu.tensor_desc<16x16xf16, #xegpu.block_tdesc_attr<array_length = 2>> -> vector<2x8x16x2xf16>
+  gpu.return
+}
+
+// CHECK: func @test_load_nd_simt_7(%[[arg0:.*]]: memref<24x32xf16>) {
+gpu.func @test_load_nd_simt_7(%src: memref<24x32xf16>) {
+  // CHECK: %[[R0:.*]] = xegpu.create_nd_tdesc %arg0[0, 0] : memref<24x32xf16> -> !xegpu.tensor_desc<16x16xf16, #xegpu.block_tdesc_attr<array_length = 2 : i64>, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [2, 1]>>
+  %1 = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf16> -> !xegpu.tensor_desc<16x16xf16, #xegpu.block_tdesc_attr<array_length = 2>, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [2, 1]>>
+  // CHECK: %[[R1:.*]] = xegpu.load_nd %[[R0]] <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}> : !xegpu.tensor_desc<16x16xf16, #xegpu.block_tdesc_attr<array_length = 2 : i64>, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [2, 1]>> -> vector<16x2xf16>
+  %2 = xegpu.load_nd %1 <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}> :
+    !xegpu.tensor_desc<16x16xf16, #xegpu.block_tdesc_attr<array_length = 2>, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [2, 1]>> -> vector<16x2xf16>
+  gpu.return
+}
+
+// CHECK: func @test_load_nd_vc_8(%[[arg0:.*]]: memref<24x32xf32>) {
+gpu.func @test_load_nd_vc_8(%src: memref<24x32xf32>) {
+  // CHECK: %[[R0:.*]] = xegpu.create_nd_tdesc %arg0[0, 0] : memref<24x32xf32> -> !xegpu.tensor_desc<16x8xf32>
+  %1 = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf32> -> !xegpu.tensor_desc<16x8xf32>
+  // CHECK: %[[R1:.*]] = xegpu.load_nd %[[R0]] <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>, transpose = array<i64: 1, 0>}> : !xegpu.tensor_desc<16x8xf32> -> vector<8x16xf32>
+  %2 = xegpu.load_nd %1 <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>, transpose = array<i64: 1, 0>}> : !xegpu.tensor_desc<16x8xf32> -> vector<8x16xf32>
+  gpu.return
+}
+
+// CHECK: func @test_load_nd_simt_8(%[[arg0:.*]]: memref<24x32xf32>) {
+gpu.func @test_load_nd_simt_8(%src: memref<24x32xf32>) {
+  // CHECK: %[[R0:.*]] = xegpu.create_nd_tdesc %arg0[0, 0] : memref<24x32xf32> -> !xegpu.tensor_desc<16x8xf32, #xegpu.sg_map<wi_layout = [16, 1], wi_data = [1, 1]>>
+  %1 = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf32> -> !xegpu.tensor_desc<16x8xf32, #xegpu.sg_map<wi_layout = [16, 1], wi_data = [1, 1]>>
+  // CHECK: %[[R1:.*]] = xegpu.load_nd %[[R0]] <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>, transpose = array<i64: 1, 0>}> : !xegpu.tensor_desc<16x8xf32, #xegpu.sg_map<wi_layout = [16, 1], wi_data = [1, 1]>> -> vector<8x1xf32>
+  %2 = xegpu.load_nd %1 <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>, transpose = array<i64: 1, 0>}> : !xegpu.tensor_desc<16x8xf32, #xegpu.sg_map<wi_layout = [16, 1], wi_data = [1, 1]>> -> vector<8x1xf32>
+  gpu.return
+}
+
+// CHECK: func @test_load_nd_vc_9(%[[arg0:.*]]: memref<24x32xf32>) {
+gpu.func @test_load_nd_vc_9(%src: memref<24x32xf32>) {
+  // CHECK: %[[R0:.*]] = xegpu.create_nd_tdesc %arg0[0, 0] : memref<24x32xf32> -> !xegpu.tensor_desc<8x16xf32>
+  %1 = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf32> -> !xegpu.tensor_desc<8x16xf32>
+  // CHECK: %[[R1:.*]] = xegpu.load_nd %[[R0]] <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>, transpose = array<i64: 1, 0>}> : !xegpu.tensor_desc<8x16xf32> -> vector<16x8xf32>
+  %2 = xegpu.load_nd %1 <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>, transpose = array<i64: 1, 0>}> : !xegpu.tensor_desc<8x16xf32> -> vector<16x8xf32>
+  gpu.return
+}
+
+// CHECK: func @test_load_nd_simt_9(%[[arg0:.*]]: memref<24x32xf32>) {
+gpu.func @test_load_nd_simt_9(%src: memref<24x32xf32>) {
+  // CHECK: %[[R0:.*]] = xegpu.create_nd_tdesc %arg0[0, 0] : memref<24x32xf32> -> !xegpu.tensor_desc<8x16xf32, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>>
+  %1 = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf32> -> !xegpu.tensor_desc<8x16xf32, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>>
+  // CHECK: %[[R1:.*]] = xegpu.load_nd %[[R0]] <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>, transpose = array<i64: 1, 0>}> : !xegpu.tensor_desc<8x16xf32, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>> -> vector<8x1xf32>
+  %2 = xegpu.load_nd %1 <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>, transpose = array<i64: 1, 0>}> : !xegpu.tensor_desc<8x16xf32, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>> -> vector<8x1xf32>
+  gpu.return
+}
+
 // CHECK: func @test_store_nd_vc(%[[arg0:.*]]: memref<24x32xf16>) {
 gpu.func @test_store_nd_vc(%dst: memref<24x32xf16>) {
   // CHECK: %[[C:.*]] = arith.constant dense<1.000000e+00> : vector<24x32xf16>

>From 981c7d3656991ce3dca803f302be4fe188ecbc98 Mon Sep 17 00:00:00 2001
From: Charitha Saumya <charitha.saumya.gusthinna.waduge at intel.com>
Date: Wed, 19 Feb 2025 21:48:24 +0000
Subject: [PATCH 4/8] save work

---
 mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp |  59 ++++++--
 mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp     |  68 +++++----
 mlir/test/Dialect/XeGPU/ops.mlir           | 167 +++++++++++++++++----
 3 files changed, 216 insertions(+), 78 deletions(-)

diff --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
index 768a3b4ef33c5..af3faf141a66e 100644
--- a/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
+++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
@@ -10,6 +10,7 @@
 #include "mlir/IR/Builders.h"
 #include "mlir/IR/BuiltinTypes.h"
 #include "mlir/IR/DialectImplementation.h"
+#include "llvm/ADT/STLExtras.h"
 #include "llvm/ADT/TypeSwitch.h"
 #include "llvm/Support/Casting.h"
 #include "llvm/Support/LogicalResult.h"
@@ -279,14 +280,13 @@ LogicalResult TensorDescType::verify(
     if (scatterAttr) {
       // Validate subgroup mapping rules for scattered tensors.
       // A work-item's slice of the tensor with shape [sg_size] or
-      // [sg_size, chunk_size] will be [1] or [1, chunks_size] respectively,
-      // the mapping should reflect that.
+      // [sg_size, chunk_size] will be [1] or [1, 32/element_ty_bit_width]
+      // respectively, the mapping should reflect that.
       if (wiData[0] != 1)
         return emitError()
                << "cannot map over non-contiguous scattered row elements";
 
-      unsigned chunkSize = scatterAttr.getChunkSize().getInt();
-      if (wiData[1] != chunkSize)
+      if (wiData[1] != (32 / elementType.getIntOrFloatBitWidth()))
         return emitError() << "work item data mapping must match the number of "
                               "contiguous elements";
     }
@@ -310,31 +310,62 @@ LogicalResult TensorDescType::verify(
   return success();
 }
 
-// If tensor descriptor has a sg_map attribute it is used in SIMT mode. In this
-// mode, the distributed vector shape is given by the following criteria:
+// If tensor descriptor has a sg_map attribute it is used in SIMT mode.
+// In this mode, the distributed vector shape is determined as follows:
+// Definitions:
 //        wi_data_size = wi_data[0] × wi_data[1]
 //        subgroup_size = wi_layout[0] × wi_layout[1]
 //        distribution_unit_size = subgroup_size × wi_data_size
+// ---------------------------------------------------------------------
+// Case 1: Regular loads/stores.
+// ---------------------------------------------------------------------
+// Distributed vector shape must be:
+//        [chunk_size / wi_data_size, wi_data_size]
+// If the tensor descriptor shape is 1D, first dimension is ignored (set to 1).
+//        [wi_data_size]
+// ---------------------------------------------------------------------
+// Case 2: Block loads/stores
+// ---------------------------------------------------------------------
+// Additional definitions:
 //        tensor_size = tensor_desc[0] * .. * tensor_desc[r-1] * array_length
 //        n_distribution_units = tensor_size / distribution_unit_size
 // Given above definitions, the following conditions must be met:
 //        * tensor_desc[0] % (wi_layout[0] × wi_data[0]) == 0
 //        * tensor_desc[1] % (wi_layout[1] × wi_data[1]) == 0
-// Distributed vector shape must be: n_distribution_units × wi_data_size
+// Distributed vector shape must be:
+//        [n_distribution_units, wi_data_size]
 FailureOr<VectorType> TensorDescType::getDistributedVectorType() {
   auto sgMap = llvm::dyn_cast_if_present<SGMapAttr>(getSgMap());
   // If no sg_map is provided, tensor desc is not used in SIMT mode.
   if (!sgMap)
     return failure();
-  // FIXME: Add support for scatter tensor descriptor.
-  auto scatterAttr =
-      llvm::dyn_cast_if_present<ScatterTensorDescAttr>(getEncoding());
-  if (scatterAttr)
-    return failure();
 
   SmallVector<int64_t> wiData(sgMap.getWiData());
   SmallVector<int64_t> wiLayout(sgMap.getWiLayout());
   auto tdescShape = getShape();
+
+  auto wiDataSize = 1, sgSize = 1;
+  for (auto [wiDim, wiDataDim] : llvm::zip_equal(wiLayout, wiData)) {
+    wiDataSize *= wiDataDim;
+    sgSize *= wiDim;
+  }
+
+  // Case 1: regular loads/stores
+  auto scatterAttr =
+      llvm::dyn_cast_if_present<ScatterTensorDescAttr>(getEncoding());
+  if (scatterAttr) {
+    auto chunkSize = scatterAttr.getChunkSize().getInt();
+    // Check if the first dimension of the tensor descriptor shape is
+    // distributable.
+    if (tdescShape[0] % (wiLayout[0]) != 0)
+      return failure();
+    if (chunkSize > 1)
+      return VectorType::get({chunkSize / wiDataSize, wiDataSize},
+                             getElementType());
+    return VectorType::get({wiDataSize}, getElementType());
+  }
+
+  // Case 2: block loads/stores
   // Tensor descriptor shape can be 1D. For the 1D case, outer dims of wiData
   // and wiLayout must be 1.
   if (tdescShape.size() == 1) {
@@ -344,14 +375,12 @@ FailureOr<VectorType> TensorDescType::getDistributedVectorType() {
     wiLayout = {wiLayout[1]};
   }
   // Check if the tensor descriptor shape is distributable.
-  int64_t tensorSize = 1, sgSize = 1, wiDataSize = 1;
+  int64_t tensorSize = 1;
   for (auto [tdescDim, wiDim, wiDataDim] :
        llvm::zip_equal(tdescShape, wiLayout, wiData)) {
     if (tdescDim % (wiDim * wiDataDim) != 0)
       return failure();
     tensorSize *= tdescDim;
-    sgSize *= wiDim;
-    wiDataSize *= wiDataDim;
   }
   // tensorSize must be adjusted for array_length.
   tensorSize *= getArrayLength();
diff --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
index a47332ce9eb0c..1453c3c80a838 100644
--- a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
+++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
@@ -513,22 +513,25 @@ LogicalResult LoadGatherOp::verify() {
     transpose({1, 0}, tdescShape);
   }
 
-  if (auto sgMap = tdescTy.getSGMapAttr()) {
-    auto valueVecTy = cast<VectorType>(valueTy);
-    const int32_t wiData =
-        sgMap.getWiData()[0] > 1 ? sgMap.getWiData()[0] : sgMap.getWiData()[1];
-    // All represent the same concept: a number of row elements to store.
-    if (valueVecTy.getNumElements() != wiData ||
-        valueVecTy.getNumElements() != tdescTy.getChunkSize()) {
-      return emitOpError("Chunk size, vector size and wi_data must match.");
-    }
-    // Work-item's slice (i.e., vector shape to load) is [1] or [1, chunk_size].
-    tdescShape[tdescTy.getRank() - 1] = 1;
-  }
-
-  if (valueShape != tdescShape)
+  auto sgMap = tdescTy.getSGMapAttr();
+  // In VC mode, sg_map is not present. In this case, value shape must match
+  // the tensor descriptor shape.
+  if (!sgMap)
+    return valueShape == tdescShape
+               ? success()
+               : emitOpError("Unexpected result shape")
+                     << "(Expected shape: " << makeString(tdescShape)
+                     << ", Given shape: " << makeString(valueShape) << ").\n";
+  // In SIMT mode, sg_map, wi_data, and chunk_size determine the value shape.
+  auto distributedVectorShapeOrFailure = tdescTy.getDistributedVectorType();
+  if (failed(distributedVectorShapeOrFailure))
+    return emitOpError("Failed to compute distributed vector shape for "
+                       "tensor descriptor ")
+           << tdescTy;
+  if (cast<VectorType>(valueTy) != distributedVectorShapeOrFailure.value())
     return emitOpError("Unexpected result shape")
-           << "(Expected shape: " << makeString(tdescShape)
+           << "(Expected shape: "
+           << makeString(distributedVectorShapeOrFailure.value().getShape())
            << ", Given shape: " << makeString(valueShape) << ").\n";
 
   return success();
@@ -565,22 +568,25 @@ LogicalResult StoreScatterOp::verify() {
     transpose({1, 0}, tdescShape);
   }
 
-  if (auto sgMap = tdescTy.getSGMapAttr()) {
-    auto valueVecTy = cast<VectorType>(valueTy);
-    const int32_t wiData =
-        sgMap.getWiData()[0] > 1 ? sgMap.getWiData()[0] : sgMap.getWiData()[1];
-    // All represent the same concept: a number of row elements to store.
-    if (valueVecTy.getNumElements() != wiData ||
-        valueVecTy.getNumElements() != tdescTy.getChunkSize()) {
-      return emitOpError("Chunk size, vector size and wi_data must match.");
-    }
-    // Work-item's slice (i.e., vector to store) is [1] or [1, chunk_size].
-    tdescShape[tdescTy.getRank() - 1] = 1;
-  }
-
-  if (valueShape != tdescShape)
-    return emitOpError("Unexpected value shape")
-           << "(Expected shape: " << makeString(tdescShape)
+  auto sgMap = tdescTy.getSGMapAttr();
+  // In VC mode, sg_map is not present. In this case, value shape must match
+  // the tensor descriptor shape.
+  if (!sgMap)
+    return valueShape == tdescShape
+               ? success()
+               : emitOpError("Unexpected result shape")
+                     << "(Expected shape: " << makeString(tdescShape)
+                     << ", Given shape: " << makeString(valueShape) << ").\n";
+  // In SIMT mode, sg_map, wi_data, and chunk_size determine the value shape.
+  auto distributedVectorShapeOrFailure = tdescTy.getDistributedVectorType();
+  if (failed(distributedVectorShapeOrFailure))
+    return emitOpError("Failed to compute distributed vector shape for "
+                       "tensor descriptor ")
+           << tdescTy;
+  if (cast<VectorType>(valueTy) != distributedVectorShapeOrFailure.value())
+    return emitOpError("Unexpected result shape")
+           << "(Expected shape: "
+           << makeString(distributedVectorShapeOrFailure.value().getShape())
            << ", Given shape: " << makeString(valueShape) << ").\n";
 
   return success();
diff --git a/mlir/test/Dialect/XeGPU/ops.mlir b/mlir/test/Dialect/XeGPU/ops.mlir
index aa4869afa2d86..7b5232621b73b 100644
--- a/mlir/test/Dialect/XeGPU/ops.mlir
+++ b/mlir/test/Dialect/XeGPU/ops.mlir
@@ -293,6 +293,20 @@ gpu.func @test_store_nd_vc(%dst: memref<24x32xf16>) {
   gpu.return
 }
 
+// CHECK: func @test_store_nd_simt(%[[arg0:.*]]: memref<24x32xf16>) {
+gpu.func @test_store_nd_simt(%src: memref<24x32xf16>) {
+   // CHECK: %[[C:.*]] = arith.constant dense<1.000000e+00> : vector<48x1xf16>
+  %1 = arith.constant dense<1.0>: vector<48x1xf16>
+  // CHECK: %[[R0:.*]] = xegpu.create_nd_tdesc %arg0[0, 0] : memref<24x32xf16> -> !xegpu.tensor_desc<24x32xf16, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>>
+  %2 = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf16> ->
+    !xegpu.tensor_desc<24x32xf16, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>>
+  // CHECK: xegpu.store_nd %[[C]], %[[R0]] <{l1_hint = #xegpu.cache_hint<write_back>, l2_hint = #xegpu.cache_hint<uncached>}> : vector<48x1xf16>, !xegpu.tensor_desc<24x32xf16, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>>
+  xegpu.store_nd %1, %2 <{l1_hint = #xegpu.cache_hint<write_back>, l2_hint = #xegpu.cache_hint<uncached>}>: vector<48x1xf16>, !xegpu.tensor_desc<24x32xf16, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>>
+  gpu.return
+}
+
+
+
 // CHECK: func @test_store_nd_vc_2(%[[arg0:.*]]: memref<24x32xf16>) {
 gpu.func @test_store_nd_vc_2(%dst: memref<24x32xf16>) {
   // CHECK: %[[C:.*]] = arith.constant dense<1.000000e+00> : vector<32xf16>
@@ -304,21 +318,9 @@ gpu.func @test_store_nd_vc_2(%dst: memref<24x32xf16>) {
   gpu.return
 }
 
-// store_nd args may have different shapes, validated against sg_map
-// CHECK: func @test_store_nd_vc_3(%[[arg0:.*]]: memref<24x32xf16>) {
-gpu.func @test_store_nd_vc_3(%src: memref<24x32xf16>) {
-   // CHECK: %[[C:.*]] = arith.constant dense<1.000000e+00> : vector<48x1xf16>
-  %1 = arith.constant dense<1.0>: vector<48x1xf16>
-  // CHECK: %[[R0:.*]] = xegpu.create_nd_tdesc %arg0[0, 0] : memref<24x32xf16> -> !xegpu.tensor_desc<24x32xf16, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>>
-  %2 = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf16> ->
-    !xegpu.tensor_desc<24x32xf16, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>>
-  // CHECK: xegpu.store_nd %[[C]], %[[R0]] <{l1_hint = #xegpu.cache_hint<write_back>, l2_hint = #xegpu.cache_hint<uncached>}> : vector<48x1xf16>, !xegpu.tensor_desc<24x32xf16, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>>
-  xegpu.store_nd %1, %2 <{l1_hint = #xegpu.cache_hint<write_back>, l2_hint = #xegpu.cache_hint<uncached>}>: vector<48x1xf16>, !xegpu.tensor_desc<24x32xf16, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>>
-  gpu.return
-}
 
-// CHECK: func @test_store_nd_vc_4(%[[arg0:.*]]: memref<24x32xf16>) {
-gpu.func @test_store_nd_vc_4(%src: memref<24x32xf16>) {
+// CHECK: func @test_store_nd_simt_2(%[[arg0:.*]]: memref<24x32xf16>) {
+gpu.func @test_store_nd_simt_2(%src: memref<24x32xf16>) {
    // CHECK: %[[C:.*]] = arith.constant dense<1.000000e+00> : vector<2x1xf16>
   %1 = arith.constant dense<1.0>: vector<2x1xf16>
   // CHECK: %[[R0:.*]] = xegpu.create_nd_tdesc %arg0[0, 0] : memref<24x32xf16> -> !xegpu.tensor_desc<32xf16, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>>
@@ -329,8 +331,8 @@ gpu.func @test_store_nd_vc_4(%src: memref<24x32xf16>) {
   gpu.return
 }
 
-// CHECK: gpu.func @test_create_update_nd_tdesc_vc(%[[arg0:.*]]: memref<24x32xf32>) {
-gpu.func @test_create_update_nd_tdesc_vc(%src: memref<24x32xf32>) {
+// CHECK: gpu.func @test_update_nd_tdesc_vc(%[[arg0:.*]]: memref<24x32xf32>) {
+gpu.func @test_update_nd_tdesc_vc(%src: memref<24x32xf32>) {
   // CHECK: %[[REG:.*]] = xegpu.create_nd_tdesc %arg0[0, 0] : memref<24x32xf32> -> !xegpu.tensor_desc<8x16xf32>
   %1 = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf32> -> !xegpu.tensor_desc<8x16xf32>
   // CHECK: %[[R1:.*]] = xegpu.update_nd_offset %[[REG]], [0, 16] : !xegpu.tensor_desc<8x16xf32>
@@ -338,6 +340,15 @@ gpu.func @test_create_update_nd_tdesc_vc(%src: memref<24x32xf32>) {
   gpu.return
 }
 
+// CHECK: gpu.func @test_update_nd_tdesc_simt(%[[arg0:.*]]: memref<24x32xf32>) {
+gpu.func @test_update_nd_tdesc_simt(%src: memref<24x32xf32>) {
+  // CHECK: %[[REG:.*]] = xegpu.create_nd_tdesc %arg0[0, 0] : memref<24x32xf32> -> !xegpu.tensor_desc<8x16xf32, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>>
+  %1 = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf32> -> !xegpu.tensor_desc<8x16xf32, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>>
+  // CHECK: %[[R1:.*]] = xegpu.update_nd_offset %[[REG]], [0, 16] : !xegpu.tensor_desc<8x16xf32, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>>
+  %2 = xegpu.update_nd_offset %1, [0, 16]: !xegpu.tensor_desc<8x16xf32, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>>
+  gpu.return
+}
+
 // CHECK: gpu.func @test_create_tdesc_vc(%[[arg0:.*]]: ui64) {
 gpu.func @test_create_tdesc_vc(%src: ui64) {
   //CHECK: %[[cst:.*]] = arith.constant dense<[0, 8, 16, 24]> : vector<4xindex>
@@ -347,6 +358,16 @@ gpu.func @test_create_tdesc_vc(%src: ui64) {
   gpu.return
 }
 
+// CHECK: gpu.func @test_create_tdesc_simt(%[[arg0:.*]]: ui64) {
+gpu.func @test_create_tdesc_simt(%src: ui64) {
+  //CHECK: %[[cst:.*]] = arith.constant dense<[0, 8, 16, 24]> : vector<4xindex>
+  %0 = arith.constant dense<[0, 8, 16, 24]> : vector<4xindex>
+  //CHECK: %[[R0:.*]] = xegpu.create_tdesc %[[arg0]], %[[cst]] : ui64, vector<4xindex> -> !xegpu.tensor_desc<4x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2 : i64>, #xegpu.sg_map<wi_layout = [4, 1], wi_data = [1, 1]>>
+  %1 = xegpu.create_tdesc %src, %0 : ui64, vector<4xindex> -> !xegpu.tensor_desc<4x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2>, #xegpu.sg_map<wi_layout = [4, 1], wi_data = [1, 1]>>
+  gpu.return
+}
+
+
 // CHECK: gpu.func @test_create_tdesc_vc_1(%[[arg0:.*]]: memref<?xf32, 3>) {
 gpu.func @test_create_tdesc_vc_1(%src: memref<?xf32, 3>) {
   //CHECK: %[[cst:.*]] = arith.constant dense<[0, 8, 16, 24]> : vector<4xindex>
@@ -356,6 +377,16 @@ gpu.func @test_create_tdesc_vc_1(%src: memref<?xf32, 3>) {
   gpu.return
 }
 
+// CHECK: gpu.func @test_create_tdesc_simt_1(%[[arg0:.*]]: memref<?xf32, 3>) {
+gpu.func @test_create_tdesc_simt_1(%src: memref<?xf32, 3>) {
+  //CHECK: %[[cst:.*]] = arith.constant dense<[0, 8, 16, 24]> : vector<4xindex>
+  %0 = arith.constant dense<[0, 8, 16, 24]> : vector<4xindex>
+  //CHECK: %[[R0:.*]] = xegpu.create_tdesc %[[arg0]], %[[cst]] : memref<?xf32, 3>, vector<4xindex> -> !xegpu.tensor_desc<4x2xf32, #xegpu.scatter_tdesc_attr<memory_space =  slm, chunk_size = 2 : i64>, #xegpu.sg_map<wi_layout = [4, 1], wi_data = [1, 1]>>
+  %1 = xegpu.create_tdesc %src, %0 : memref<?xf32, 3>, vector<4xindex>  -> !xegpu.tensor_desc<4x2xf32, #xegpu.scatter_tdesc_attr<memory_space = slm, chunk_size = 2>, #xegpu.sg_map<wi_layout = [4, 1], wi_data = [1, 1]>>
+  gpu.return
+}
+
+
 // CHECK: gpu.func @test_create_tdesc_vc_2(%[[arg0:.*]]: memref<?xf32>) {
 gpu.func @test_create_tdesc_vc_2(%src: memref<?xf32>) {
   //CHECK: %[[cst:.*]] = arith.constant dense<[0, 8, 16, 24]> : vector<4xindex>
@@ -365,30 +396,75 @@ gpu.func @test_create_tdesc_vc_2(%src: memref<?xf32>) {
   gpu.return
 }
 
-// CHECK: gpu.func @test_create_tdesc_vc_with_sg_map(%[[arg0:.*]]: ui64) {
-gpu.func @test_create_tdesc_vc_with_sg_map(%src: ui64) {
+// CHECK: gpu.func @test_create_tdesc_simt_2(%[[arg0:.*]]: memref<?xf32>) {
+gpu.func @test_create_tdesc_simt_2(%src: memref<?xf32>) {
+  //CHECK: %[[cst:.*]] = arith.constant dense<[0, 8, 16, 24]> : vector<4xindex>
+  %0 = arith.constant dense<[0, 8, 16, 24]> : vector<4xindex>
+  //CHECK: %[[R0:.*]] = xegpu.create_tdesc %[[arg0]], %[[cst]] : memref<?xf32>, vector<4xindex> -> !xegpu.tensor_desc<4xf32, #xegpu.scatter_tdesc_attr<>, #xegpu.sg_map<wi_layout = [1, 4], wi_data = [1, 1]>>
+  %1 = xegpu.create_tdesc %src, %0 : memref<?xf32>, vector<4xindex>  -> !xegpu.tensor_desc<4xf32, #xegpu.scatter_tdesc_attr<chunk_size = 1>, #xegpu.sg_map<wi_layout = [1, 4], wi_data = [1, 1]>>
+  gpu.return
+}
+
+// CHECK: gpu.func @test_create_tdesc_vc_3(%[[arg0:.*]]: ui64) {
+gpu.func @test_create_tdesc_vc_3(%src: ui64) {
+  //CHECK: %[[cst:.*]] = arith.constant dense<[0, 8, 16, 24]> : vector<4xindex>
+  %0 = arith.constant dense<[0, 8, 16, 24]> : vector<4xindex>
+  //CHECK: %[[R0:.*]] = xegpu.create_tdesc %[[arg0]], %[[cst]] : ui64, vector<4xindex> -> !xegpu.tensor_desc<4x2xf16, #xegpu.scatter_tdesc_attr<chunk_size = 2 : i64>>
+  %1 = xegpu.create_tdesc %src, %0 : ui64, vector<4xindex> -> !xegpu.tensor_desc<4x2xf16, #xegpu.scatter_tdesc_attr<chunk_size = 2>>
+  gpu.return
+}
+
+
+// CHECK: gpu.func @test_create_tdesc_simt_3(%[[arg0:.*]]: ui64) {
+gpu.func @test_create_tdesc_simt_3(%src: ui64) {
+  //CHECK: %[[cst:.*]] = arith.constant dense<[0, 8, 16, 24]> : vector<4xindex>
+  %0 = arith.constant dense<[0, 8, 16, 24]> : vector<4xindex>
+  //CHECK: %[[R0:.*]] = xegpu.create_tdesc %[[arg0]], %[[cst]] : ui64, vector<4xindex> -> !xegpu.tensor_desc<4x2xf16, #xegpu.scatter_tdesc_attr<chunk_size = 2 : i64>, #xegpu.sg_map<wi_layout = [4, 1], wi_data = [1, 2]>>
+  %1 = xegpu.create_tdesc %src, %0 : ui64, vector<4xindex> -> !xegpu.tensor_desc<4x2xf16, #xegpu.scatter_tdesc_attr<chunk_size = 2>, #xegpu.sg_map<wi_layout = [4, 1], wi_data = [1, 2]>>
+  gpu.return
+}
+
+// CHECK: gpu.func @test_load_vc(%[[arg0:.*]]: ui64) {
+gpu.func @test_load_vc(%src: ui64) {
+  //CHECK: %[[cst:.*]] = arith.constant dense<[0, 8, 16, 24]> : vector<4xindex>
+  %0 = arith.constant dense<[0, 8, 16, 24]> : vector<4xindex>
+  //CHECK: %[[cst1:.*]] = arith.constant dense<true> : vector<4xi1>
+  %1 = arith.constant dense<1>: vector<4xi1>
+  //CHECK: %[[R0:.*]] = xegpu.create_tdesc %[[arg0]], %[[cst]] : ui64, vector<4xindex> -> !xegpu.tensor_desc<4x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2 : i64>>
+  %2 = xegpu.create_tdesc %src, %0 : ui64, vector<4xindex> -> !xegpu.tensor_desc<4x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2>>
+  //CHECK: %[[R1:.*]] = xegpu.load %[[R0]], %[[cst1]] <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>, transpose}> : !xegpu.tensor_desc<4x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2 : i64>>, vector<4xi1> -> vector<2x4xf32>
+  %3 = xegpu.load %2, %1 <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>, transpose}> : !xegpu.tensor_desc<4x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2>>, vector<4xi1> -> vector<2x4xf32>
+  gpu.return
+}
+
+// CHECK: gpu.func @test_load_simt(%[[arg0:.*]]: ui64) {
+gpu.func @test_load_simt(%src: ui64) {
   //CHECK: %[[cst:.*]] = arith.constant dense<[0, 8, 16, 24]> : vector<4xindex>
   %0 = arith.constant dense<[0, 8, 16, 24]> : vector<4xindex>
-  //CHECK: %[[R0:.*]] = xegpu.create_tdesc %[[arg0]], %[[cst]] : ui64, vector<4xindex> -> !xegpu.tensor_desc<4x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2 : i64>, #xegpu.sg_map<wi_layout = [4, 1], wi_data = [1, 2]>>
-  %1 = xegpu.create_tdesc %src, %0 : ui64, vector<4xindex> -> !xegpu.tensor_desc<4x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2>, #xegpu.sg_map<wi_layout = [4, 1], wi_data = [1, 2]>>
+  //CHECK: %[[cst1:.*]] = arith.constant dense<true> : vector<4xi1>
+  %1 = arith.constant dense<1>: vector<4xi1>
+  //CHECK: %[[R0:.*]] = xegpu.create_tdesc %[[arg0]], %[[cst]] : ui64, vector<4xindex> -> !xegpu.tensor_desc<4x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2 : i64>, #xegpu.sg_map<wi_layout = [4, 1], wi_data = [1, 1]>>
+  %2 = xegpu.create_tdesc %src, %0 : ui64, vector<4xindex> -> !xegpu.tensor_desc<4x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2>, #xegpu.sg_map<wi_layout = [4, 1], wi_data = [1, 1]>>
+  //CHECK: %[[R1:.*]] = xegpu.load %[[R0]], %[[cst1]] <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>, transpose}> : !xegpu.tensor_desc<4x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2 : i64>, #xegpu.sg_map<wi_layout = [4, 1], wi_data = [1, 1]>>, vector<4xi1> -> vector<2x1xf32>
+  %3 = xegpu.load %2, %1 <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>, transpose}> : !xegpu.tensor_desc<4x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2>, #xegpu.sg_map<wi_layout = [4, 1], wi_data = [1, 1]>>, vector<4xi1> -> vector<2x1xf32>
   gpu.return
 }
 
-// CHECK: gpu.func @test_load_with_sg_map(%[[arg0:.*]]: ui64) {
-gpu.func @test_load_with_sg_map(%src: ui64) {
+// CHECK: gpu.func @test_load_vc_2(%[[arg0:.*]]: ui64) {
+gpu.func @test_load_vc_2(%src: ui64) {
   //CHECK: %[[cst:.*]] = arith.constant dense<[0, 8, 16, 24]> : vector<4xindex>
   %0 = arith.constant dense<[0, 8, 16, 24]> : vector<4xindex>
   //CHECK: %[[cst1:.*]] = arith.constant dense<true> : vector<4xi1>
   %1 = arith.constant dense<1>: vector<4xi1>
-  //CHECK: %[[R0:.*]] = xegpu.create_tdesc %[[arg0]], %[[cst]] : ui64, vector<4xindex> -> !xegpu.tensor_desc<4x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2 : i64>, #xegpu.sg_map<wi_layout = [4, 1], wi_data = [1, 2]>>
-  %2 = xegpu.create_tdesc %src, %0 : ui64, vector<4xindex> -> !xegpu.tensor_desc<4x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2>, #xegpu.sg_map<wi_layout = [4, 1], wi_data = [1, 2]>>
-  //CHECK: %[[R1:.*]] = xegpu.load %[[R0]], %[[cst1]] <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>, transpose}> : !xegpu.tensor_desc<4x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2 : i64>, #xegpu.sg_map<wi_layout = [4, 1], wi_data = [1, 2]>>, vector<4xi1> -> vector<2x1xf32>
-  %3 = xegpu.load %2, %1 <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>, transpose}> : !xegpu.tensor_desc<4x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2>, #xegpu.sg_map<wi_layout = [4, 1], wi_data = [1, 2]>>, vector<4xi1> -> vector<2x1xf32>
+  //CHECK: %[[R0:.*]] = xegpu.create_tdesc %[[arg0]], %[[cst]] : ui64, vector<4xindex> -> !xegpu.tensor_desc<4xf32, #xegpu.scatter_tdesc_attr<>>
+  %2 = xegpu.create_tdesc %src, %0 : ui64, vector<4xindex> -> !xegpu.tensor_desc<4xf32, #xegpu.scatter_tdesc_attr<>>
+  //CHECK: %[[R1:.*]] = xegpu.load %[[R0]], %[[cst1]] <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}> : !xegpu.tensor_desc<4xf32, #xegpu.scatter_tdesc_attr<>>, vector<4xi1> -> vector<4xf32>
+  %3 = xegpu.load %2, %1 <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}> : !xegpu.tensor_desc<4xf32, #xegpu.scatter_tdesc_attr<>>, vector<4xi1> -> vector<4xf32>
   gpu.return
 }
 
-// CHECK: gpu.func @test_load_with_sg_map_2(%[[arg0:.*]]: ui64) {
-gpu.func @test_load_with_sg_map_2(%src: ui64) {
+// CHECK: gpu.func @test_load_simt_2(%[[arg0:.*]]: ui64) {
+gpu.func @test_load_simt_2(%src: ui64) {
   //CHECK: %[[cst:.*]] = arith.constant dense<[0, 8, 16, 24]> : vector<4xindex>
   %0 = arith.constant dense<[0, 8, 16, 24]> : vector<4xindex>
   //CHECK: %[[cst1:.*]] = arith.constant dense<true> : vector<4xi1>
@@ -400,6 +476,33 @@ gpu.func @test_load_with_sg_map_2(%src: ui64) {
   gpu.return
 }
 
+// CHECK: gpu.func @test_load_vc_3(%[[arg0:.*]]: ui64) {
+gpu.func @test_load_vc_3(%src: ui64) {
+  //CHECK: %[[cst:.*]] = arith.constant dense<[0, 8, 16, 24]> : vector<4xindex>
+  %0 = arith.constant dense<[0, 8, 16, 24]> : vector<4xindex>
+  //CHECK: %[[cst1:.*]] = arith.constant dense<true> : vector<4xi1>
+  %1 = arith.constant dense<1>: vector<4xi1>
+  //CHECK: %[[R0:.*]] = xegpu.create_tdesc %[[arg0]], %[[cst]] : ui64, vector<4xindex> -> !xegpu.tensor_desc<4x8xf16, #xegpu.scatter_tdesc_attr<chunk_size = 8 : i64>>
+  %2 = xegpu.create_tdesc %src, %0 : ui64, vector<4xindex> -> !xegpu.tensor_desc<4x8xf16, #xegpu.scatter_tdesc_attr<chunk_size = 8>>
+  //CHECK: %[[R1:.*]] = xegpu.load %[[R0]], %[[cst1]] <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>, transpose}> : !xegpu.tensor_desc<4x8xf16, #xegpu.scatter_tdesc_attr<chunk_size = 8 : i64>>, vector<4xi1> -> vector<8x4xf16>
+  %3 = xegpu.load %2, %1 <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>, transpose}> : !xegpu.tensor_desc<4x8xf16, #xegpu.scatter_tdesc_attr<chunk_size = 8>>, vector<4xi1> -> vector<8x4xf16>
+  gpu.return
+}
+
+// CHECK: gpu.func @test_load_simt_3(%[[arg0:.*]]: ui64) {
+gpu.func @test_load_simt_3(%src: ui64) {
+  //CHECK: %[[cst:.*]] = arith.constant dense<[0, 8, 16, 24]> : vector<4xindex>
+  %0 = arith.constant dense<[0, 8, 16, 24]> : vector<4xindex>
+  //CHECK: %[[cst1:.*]] = arith.constant dense<true> : vector<4xi1>
+  %1 = arith.constant dense<1>: vector<4xi1>
+  //CHECK: %[[R0:.*]] = xegpu.create_tdesc %[[arg0]], %[[cst]] : ui64, vector<4xindex> -> !xegpu.tensor_desc<4x8xf16, #xegpu.scatter_tdesc_attr<chunk_size = 8 : i64>, #xegpu.sg_map<wi_layout = [4, 1], wi_data = [1, 2]>>
+  %2 = xegpu.create_tdesc %src, %0 : ui64, vector<4xindex> -> !xegpu.tensor_desc<4x8xf16, #xegpu.scatter_tdesc_attr<chunk_size = 8>, #xegpu.sg_map<wi_layout = [4, 1], wi_data = [1, 2]>>
+  //CHECK: %[[R1:.*]] = xegpu.load %[[R0]], %[[cst1]] <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>, transpose}> : !xegpu.tensor_desc<4x8xf16, #xegpu.scatter_tdesc_attr<chunk_size = 8 : i64>, #xegpu.sg_map<wi_layout = [4, 1], wi_data = [1, 2]>>, vector<4xi1> -> vector<4x2xf16>
+  %3 = xegpu.load %2, %1 <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>, transpose}> : !xegpu.tensor_desc<4x8xf16, #xegpu.scatter_tdesc_attr<chunk_size = 8>, #xegpu.sg_map<wi_layout = [4, 1], wi_data = [1, 2]>>, vector<4xi1> -> vector<4x2xf16>
+  gpu.return
+}
+
+
 // CHECK: gpu.func @test_store_with_sg_map(%[[arg0:.*]]: ui64) {
 gpu.func @test_store_with_sg_map(%src: ui64) {
   //CHECK: %[[cst:.*]] = arith.constant dense<[0, 8, 16, 24]> : vector<4xindex>
@@ -408,10 +511,10 @@ gpu.func @test_store_with_sg_map(%src: ui64) {
   %1 = arith.constant dense<1>: vector<4xi1>
   //CHECK: %[[cst2:.*]] = arith.constant dense<2.900000e+00> : vector<2x1xf32>
   %2 = arith.constant dense<2.9>: vector<2x1xf32>
-  //CHECK: %[[R0:.*]] = xegpu.create_tdesc %[[arg0]], %[[cst]] : ui64, vector<4xindex> -> !xegpu.tensor_desc<4x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2 : i64>, #xegpu.sg_map<wi_layout = [4, 1], wi_data = [1, 2]>>
-  %3 = xegpu.create_tdesc %src, %0 : ui64, vector<4xindex> -> !xegpu.tensor_desc<4x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2>, #xegpu.sg_map<wi_layout = [4, 1], wi_data = [1, 2]>>
-  //CHECK: xegpu.store %[[cst2]], %[[R0]], %[[cst1]] <{l1_hint = #xegpu.cache_hint<write_back>, l2_hint = #xegpu.cache_hint<uncached>, transpose}> : vector<2x1xf32>, !xegpu.tensor_desc<4x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2 : i64>, #xegpu.sg_map<wi_layout = [4, 1], wi_data = [1, 2]>>, vector<4xi1>
-  xegpu.store %2, %3, %1 <{l1_hint = #xegpu.cache_hint<write_back>, l2_hint = #xegpu.cache_hint<uncached>, transpose}> : vector<2x1xf32>, !xegpu.tensor_desc<4x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2>, #xegpu.sg_map<wi_layout = [4, 1], wi_data = [1, 2]>>, vector<4xi1>
+  //CHECK: %[[R0:.*]] = xegpu.create_tdesc %[[arg0]], %[[cst]] : ui64, vector<4xindex> -> !xegpu.tensor_desc<4x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2 : i64>, #xegpu.sg_map<wi_layout = [4, 1], wi_data = [1, 1]>>
+  %3 = xegpu.create_tdesc %src, %0 : ui64, vector<4xindex> -> !xegpu.tensor_desc<4x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2>, #xegpu.sg_map<wi_layout = [4, 1], wi_data = [1, 1]>>
+  //CHECK: xegpu.store %[[cst2]], %[[R0]], %[[cst1]] <{l1_hint = #xegpu.cache_hint<write_back>, l2_hint = #xegpu.cache_hint<uncached>, transpose}> : vector<2x1xf32>, !xegpu.tensor_desc<4x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2 : i64>, #xegpu.sg_map<wi_layout = [4, 1], wi_data = [1, 1]>>, vector<4xi1>
+  xegpu.store %2, %3, %1 <{l1_hint = #xegpu.cache_hint<write_back>, l2_hint = #xegpu.cache_hint<uncached>, transpose}> : vector<2x1xf32>, !xegpu.tensor_desc<4x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2>, #xegpu.sg_map<wi_layout = [4, 1], wi_data = [1, 1]>>, vector<4xi1>
   gpu.return
 }
 

>From ddc8cba496eef9d790ab2b3d6d9fb710954ee98c Mon Sep 17 00:00:00 2001
From: Charitha Saumya <charitha.saumya.gusthinna.waduge at intel.com>
Date: Wed, 19 Feb 2025 22:16:20 +0000
Subject: [PATCH 5/8] [mlir][xegpu] Report "value shape" in StoreScatterOp errors; add VC/SIMT store tests

---
 mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp |  4 +-
 mlir/test/Dialect/XeGPU/ops.mlir       | 72 ++++++++++++++++++++++++--
 2 files changed, 70 insertions(+), 6 deletions(-)

diff --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
index 1453c3c80a838..2e50f4677298b 100644
--- a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
+++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
@@ -574,7 +574,7 @@ LogicalResult StoreScatterOp::verify() {
   if (!sgMap)
     return valueShape == tdescShape
                ? success()
-               : emitOpError("Unexpected result shape")
+               : emitOpError("Unexpected value shape")
                      << "(Expected shape: " << makeString(tdescShape)
                      << ", Given shape: " << makeString(valueShape) << ").\n";
   // In SIMT mode, sg_map, wi_data, and chunk_size determine the value shape.
@@ -584,7 +584,7 @@ LogicalResult StoreScatterOp::verify() {
                        "tensor descriptor ")
            << tdescTy;
   if (cast<VectorType>(valueTy) != distributedVectorShapeOrFailure.value())
-    return emitOpError("Unexpected result shape")
+    return emitOpError("Unexpected value shape")
            << "(Expected shape: "
            << makeString(distributedVectorShapeOrFailure.value().getShape())
            << ", Given shape: " << makeString(valueShape) << ").\n";
diff --git a/mlir/test/Dialect/XeGPU/ops.mlir b/mlir/test/Dialect/XeGPU/ops.mlir
index 7b5232621b73b..54210277237ef 100644
--- a/mlir/test/Dialect/XeGPU/ops.mlir
+++ b/mlir/test/Dialect/XeGPU/ops.mlir
@@ -502,9 +502,25 @@ gpu.func @test_load_simt_3(%src: ui64) {
   gpu.return
 }
 
+// CHECK: gpu.func @test_store_vc(%[[arg0:.*]]: ui64) {
+gpu.func @test_store_vc(%src: ui64) {
+  //CHECK: %[[cst:.*]] = arith.constant dense<[0, 8, 16, 24]> : vector<4xindex>
+  %0 = arith.constant dense<[0, 8, 16, 24]> : vector<4xindex>
+  //CHECK: %[[cst1:.*]] = arith.constant dense<true> : vector<4xi1>
+  %1 = arith.constant dense<1>: vector<4xi1>
+  //CHECK: %[[cst2:.*]] = arith.constant dense<2.900000e+00> : vector<2x4xf32>
+  %2 = arith.constant dense<2.9>: vector<2x4xf32>
+  //CHECK: %[[R0:.*]] = xegpu.create_tdesc %[[arg0]], %[[cst]] : ui64, vector<4xindex> -> !xegpu.tensor_desc<4x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2 : i64>>
+  %3 = xegpu.create_tdesc %src, %0 : ui64, vector<4xindex> -> !xegpu.tensor_desc<4x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2>>
+  //CHECK: xegpu.store %[[cst2]], %[[R0]], %[[cst1]] <{l1_hint = #xegpu.cache_hint<write_back>, l2_hint = #xegpu.cache_hint<uncached>, transpose}> : vector<2x4xf32>, !xegpu.tensor_desc<4x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2 : i64>>, vector<4xi1>
+  xegpu.store %2, %3, %1 <{l1_hint = #xegpu.cache_hint<write_back>, l2_hint = #xegpu.cache_hint<uncached>, transpose}> : vector<2x4xf32>, !xegpu.tensor_desc<4x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2>>, vector<4xi1>
+  gpu.return
+}
+
+
 
-// CHECK: gpu.func @test_store_with_sg_map(%[[arg0:.*]]: ui64) {
-gpu.func @test_store_with_sg_map(%src: ui64) {
+// CHECK: gpu.func @test_store_simt(%[[arg0:.*]]: ui64) {
+gpu.func @test_store_simt(%src: ui64) {
   //CHECK: %[[cst:.*]] = arith.constant dense<[0, 8, 16, 24]> : vector<4xindex>
   %0 = arith.constant dense<[0, 8, 16, 24]> : vector<4xindex>
   //CHECK: %[[cst1:.*]] = arith.constant dense<true> : vector<4xi1>
@@ -518,8 +534,56 @@ gpu.func @test_store_with_sg_map(%src: ui64) {
   gpu.return
 }
 
-// CHECK: gpu.func @test_store_with_sg_map_2(%[[arg0:.*]]: ui64) {
-gpu.func @test_store_with_sg_map_2(%src: ui64) {
+// CHECK: gpu.func @test_store_vc_2(%[[arg0:.*]]: ui64) {
+gpu.func @test_store_vc_2(%src: ui64) {
+  //CHECK: %[[cst:.*]] = arith.constant dense<[0, 8, 16, 24]> : vector<4xindex>
+  %0 = arith.constant dense<[0, 8, 16, 24]> : vector<4xindex>
+  //CHECK: %[[cst1:.*]] = arith.constant dense<true> : vector<4xi1>
+  %1 = arith.constant dense<1>: vector<4xi1>
+  //CHECK: %[[cst2:.*]] = arith.constant {{.*}} : vector<2x4xf16>
+  %2 = arith.constant dense<2.9>: vector<2x4xf16>
+  //CHECK: %[[R0:.*]] = xegpu.create_tdesc %[[arg0]], %[[cst]] : ui64, vector<4xindex> -> !xegpu.tensor_desc<4x2xf16, #xegpu.scatter_tdesc_attr<chunk_size = 2 : i64>>
+  %3 = xegpu.create_tdesc %src, %0 : ui64, vector<4xindex> -> !xegpu.tensor_desc<4x2xf16, #xegpu.scatter_tdesc_attr<chunk_size = 2>>
+  //CHECK: xegpu.store %[[cst2]], %[[R0]], %[[cst1]] <{l1_hint = #xegpu.cache_hint<write_back>, l2_hint = #xegpu.cache_hint<uncached>, transpose}> : vector<2x4xf16>, !xegpu.tensor_desc<4x2xf16, #xegpu.scatter_tdesc_attr<chunk_size = 2 : i64>>, vector<4xi1>
+  xegpu.store %2, %3, %1 <{l1_hint = #xegpu.cache_hint<write_back>, l2_hint = #xegpu.cache_hint<uncached>, transpose}> : vector<2x4xf16>, !xegpu.tensor_desc<4x2xf16, #xegpu.scatter_tdesc_attr<chunk_size = 2>>, vector<4xi1>
+  gpu.return
+}
+
+
+
+// CHECK: gpu.func @test_store_simt_2(%[[arg0:.*]]: ui64) {
+gpu.func @test_store_simt_2(%src: ui64) {
+  //CHECK: %[[cst:.*]] = arith.constant dense<[0, 8, 16, 24]> : vector<4xindex>
+  %0 = arith.constant dense<[0, 8, 16, 24]> : vector<4xindex>
+  //CHECK: %[[cst1:.*]] = arith.constant dense<true> : vector<4xi1>
+  %1 = arith.constant dense<1>: vector<4xi1>
+  //CHECK: %[[cst2:.*]] = arith.constant {{.*}} : vector<1x2xf16>
+  %2 = arith.constant dense<2.9>: vector<1x2xf16>
+  //CHECK: %[[R0:.*]] = xegpu.create_tdesc %[[arg0]], %[[cst]] : ui64, vector<4xindex> -> !xegpu.tensor_desc<4x2xf16, #xegpu.scatter_tdesc_attr<chunk_size = 2 : i64>, #xegpu.sg_map<wi_layout = [4, 1], wi_data = [1, 2]>>
+  %3 = xegpu.create_tdesc %src, %0 : ui64, vector<4xindex> -> !xegpu.tensor_desc<4x2xf16, #xegpu.scatter_tdesc_attr<chunk_size = 2>, #xegpu.sg_map<wi_layout = [4, 1], wi_data = [1, 2]>>
+  //CHECK: xegpu.store %[[cst2]], %[[R0]], %[[cst1]] <{l1_hint = #xegpu.cache_hint<write_back>, l2_hint = #xegpu.cache_hint<uncached>, transpose}> : vector<1x2xf16>, !xegpu.tensor_desc<4x2xf16, #xegpu.scatter_tdesc_attr<chunk_size = 2 : i64>, #xegpu.sg_map<wi_layout = [4, 1], wi_data = [1, 2]>>, vector<4xi1>
+  xegpu.store %2, %3, %1 <{l1_hint = #xegpu.cache_hint<write_back>, l2_hint = #xegpu.cache_hint<uncached>, transpose}> : vector<1x2xf16>, !xegpu.tensor_desc<4x2xf16, #xegpu.scatter_tdesc_attr<chunk_size = 2>, #xegpu.sg_map<wi_layout = [4, 1], wi_data = [1, 2]>>, vector<4xi1>
+  gpu.return
+}
+
+// CHECK: gpu.func @test_store_vc_3(%[[arg0:.*]]: ui64) {
+gpu.func @test_store_vc_3(%src: ui64) {
+  //CHECK: %[[cst:.*]] = arith.constant dense<[0, 8, 16, 24]> : vector<4xindex>
+  %0 = arith.constant dense<[0, 8, 16, 24]> : vector<4xindex>
+  //CHECK: %[[cst1:.*]] = arith.constant dense<true> : vector<4xi1>
+  %1 = arith.constant dense<1>: vector<4xi1>
+  //CHECK: %[[cst2:.*]] = arith.constant {{.*}} : vector<4xf32>
+  %2 = arith.constant dense<2.9>: vector<4xf32>
+  //CHECK: %[[R0:.*]] = xegpu.create_tdesc %[[arg0]], %[[cst]] : ui64, vector<4xindex> -> !xegpu.tensor_desc<4xf32, #xegpu.scatter_tdesc_attr<>>
+  %3 = xegpu.create_tdesc %src, %0 : ui64, vector<4xindex> -> !xegpu.tensor_desc<4xf32, #xegpu.scatter_tdesc_attr<>>
+  //CHECK: xegpu.store %[[cst2]], %[[R0]], %[[cst1]] <{l1_hint = #xegpu.cache_hint<write_back>, l2_hint = #xegpu.cache_hint<uncached>}> : vector<4xf32>, !xegpu.tensor_desc<4xf32, #xegpu.scatter_tdesc_attr<>>, vector<4xi1>
+  xegpu.store %2, %3, %1 <{l1_hint = #xegpu.cache_hint<write_back>, l2_hint = #xegpu.cache_hint<uncached>}> : vector<4xf32>, !xegpu.tensor_desc<4xf32, #xegpu.scatter_tdesc_attr<>>, vector<4xi1>
+  gpu.return
+}
+
+
+// CHECK: gpu.func @test_store_simt_3(%[[arg0:.*]]: ui64) {
+gpu.func @test_store_simt_3(%src: ui64) {
   //CHECK: %[[cst:.*]] = arith.constant dense<[0, 8, 16, 24]> : vector<4xindex>
   %0 = arith.constant dense<[0, 8, 16, 24]> : vector<4xindex>
   //CHECK: %[[cst1:.*]] = arith.constant dense<true> : vector<4xi1>

>From 2d79647ec06c482bb7de4701fa7fbcbea6df885f Mon Sep 17 00:00:00 2001
From: Charitha Saumya <charitha.saumya.gusthinna.waduge at intel.com>
Date: Wed, 19 Feb 2025 23:27:13 +0000
Subject: [PATCH 6/8] [mlir][xegpu] Verify DpasOp shapes in SIMT mode using sg_map; update tests

---
 mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp | 57 ++++++++++++++++++++----
 mlir/test/Dialect/XeGPU/ops.mlir       | 60 +++++++++++++-------------
 2 files changed, 80 insertions(+), 37 deletions(-)

diff --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
index 2e50f4677298b..ad8b4bf3427cc 100644
--- a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
+++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
@@ -620,20 +620,61 @@ void UpdateOffsetOp::build(OpBuilder &builder, OperationState &state,
 LogicalResult DpasOp::verify() {
   int64_t lhsRank = getLhsType().getRank();
   int64_t rhsRank = getRhsType().getRank();
-
-  if (lhsRank != 2 || (rhsRank != 2 && rhsRank != 3))
-    return emitOpError("expecting lhs to be a 2D vector, and rhs to be either "
-                       "2D or 3D (packed) vector.");
-
+  int64_t resultRank = getResultType().getRank();
   auto lhsShape = getLhsType().getShape();
   auto rhsShape = getRhsType().getShape();
-  auto bK = rhsRank == 3 ? rhsShape[0] * rhsShape[2] : rhsShape[0];
-  if (bK != lhsShape[1])
+  auto resultShape = getResultType().getShape();
+  // Optional per-operand sg_map attrs; their presence selects VC vs SIMT mode.
+  auto sgMapA = (*this)->getAttrOfType<xegpu::SGMapAttr>("sg_map_a");
+  auto sgMapB = (*this)->getAttrOfType<xegpu::SGMapAttr>("sg_map_b");
+  auto sgMapC = (*this)->getAttrOfType<xegpu::SGMapAttr>("sg_map_c");
+
+  // If sg_maps are not present, then the operation is in VC mode.
+  if (!sgMapA && !sgMapB && !sgMapC) {
+    if (lhsRank != 2 || (rhsRank != 2 && rhsRank != 3) || resultRank != 2)
+      return emitOpError(
+          "expecting lhs and result to be a 2D vector, and rhs to be either "
+          "2D or 3D (packed) vector.");
+    auto bK = rhsRank == 3 ? rhsShape[0] * rhsShape[2] : rhsShape[0];
+    if (bK != lhsShape[1])
+      return emitOpError("K-dimension mismatch.");
+    if (lhsShape[0] != resultShape[0])
+      return emitOpError("M-dimension mismatch.");
+    if (rhsShape[1] != resultShape[1])
+      return emitOpError("N-dimension mismatch.");
+    return success();
+  }
+  // Otherwise, in SIMT mode we expect sg_map attributes for all operands and
+  // result of DPAS operation.
+  if (!sgMapA || !sgMapB || !sgMapC)
+    return emitOpError("sg_map attributes for all operands and outputs are "
+                       "expected in SIMT xegpu::Dpas operation");
+
+  // In SIMT mode, all data fragments (lhs, rhs and result) must be 2D.
+  if (lhsRank != 2 || rhsRank != 2 || resultRank != 2)
+    return emitOpError("expecting lhs, rhs, and result to be a 2D vector.");
+
+  auto wiLayoutA = sgMapA.getWiLayout();
+  auto wiLayoutB = sgMapB.getWiLayout();
+  auto wiLayoutC = sgMapC.getWiLayout();
+  // Scale each per-work-item fragment shape by wi_layout to get full shapes.
+  // NOTE: B's two fragment dims fold into K; each work item owns one N col.
+  SmallVector<int64_t> expandedShapeA = {lhsShape[0] * wiLayoutA[0],
+                                         lhsShape[1] * wiLayoutA[1]};
+  SmallVector<int64_t> expandedShapeB = {
+      rhsShape[0] * rhsShape[1] * wiLayoutB[0], 1 * wiLayoutB[1]};
+  SmallVector<int64_t> expandedShapeC = {resultShape[0] * wiLayoutC[0],
+                                         resultShape[1] * wiLayoutC[1]};
+  auto bK = expandedShapeB[0];
+  if (bK != expandedShapeA[1])
     return emitOpError("K-dimension mismatch.");
+  if (expandedShapeA[0] != expandedShapeC[0])
+    return emitOpError("M-dimension mismatch.");
+  if (expandedShapeB[1] != expandedShapeC[1])
+    return emitOpError("N-dimension mismatch.");
 
   return success();
 }
-
 } // namespace xegpu
 } // namespace mlir
 
diff --git a/mlir/test/Dialect/XeGPU/ops.mlir b/mlir/test/Dialect/XeGPU/ops.mlir
index 54210277237ef..075830c4b6b7d 100644
--- a/mlir/test/Dialect/XeGPU/ops.mlir
+++ b/mlir/test/Dialect/XeGPU/ops.mlir
@@ -597,6 +597,16 @@ gpu.func @test_store_simt_3(%src: ui64) {
   gpu.return
 }
 
+// CHECK: gpu.func @test_prefetch_simt(%[[arg0:.*]]: ui64) {
+gpu.func @test_prefetch_simt(%src: ui64) {
+  //CHECK: %[[cst:.*]] = arith.constant dense<[0, 8, 16, 24]> : vector<4xindex>
+  %0 = arith.constant dense<[0, 8, 16, 24]> : vector<4xindex>
+  //CHECK: %[[R0:.*]] = xegpu.create_tdesc %[[arg0]], %[[cst]] : ui64, vector<4xindex> -> !xegpu.tensor_desc<4x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2 : i64>, #xegpu.sg_map<wi_layout = [4, 1], wi_data = [1, 1]>>
+  %1 = xegpu.create_tdesc %src, %0 : ui64, vector<4xindex> -> !xegpu.tensor_desc<4x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2>, #xegpu.sg_map<wi_layout = [4, 1], wi_data = [1, 1]>>
+  // CHECK: xegpu.prefetch %[[R0]] <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}> : !xegpu.tensor_desc<4x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2 : i64>, #xegpu.sg_map<wi_layout = [4, 1], wi_data = [1, 1]>>
+  xegpu.prefetch %1 <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}>: !xegpu.tensor_desc<4x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2>, #xegpu.sg_map<wi_layout = [4, 1], wi_data = [1, 1]>>
+  gpu.return
+}
 
 
 // CHECK: gpu.func @test_prefetch_vc(%[[arg0:.*]]: ui64) {
@@ -610,35 +620,16 @@ gpu.func @test_prefetch_vc(%src: ui64) {
   gpu.return
 }
 
-// CHECK: gpu.func @test_load_gather_vc(%[[arg0:.*]]: ui64) {
-gpu.func @test_load_gather_vc(%src: ui64) {
-  //CHECK: %[[cst:.*]] = arith.constant dense<true> : vector<4xi1>
-  %0 = arith.constant dense<1>: vector<4xi1>
-  //CHECK: %[[c2:.*]] = arith.constant dense<[0, 8, 16, 24]> : vector<4xindex>
-  %c = arith.constant dense<[0, 8, 16, 24]> : vector<4xindex>
-  //CHECK: %[[R0:.*]] = xegpu.create_tdesc %[[arg0]], %[[c2]] : ui64, vector<4xindex> -> !xegpu.tensor_desc<4x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2 : i64>>
-  %1 = xegpu.create_tdesc %src, %c : ui64, vector<4xindex>  -> !xegpu.tensor_desc<4x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2>>
-  //CHECK: %[[R1:.*]] = xegpu.load %[[R0]], %[[cst]] <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>, transpose}>
-  //CHECK-SAME: !xegpu.tensor_desc<4x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2 : i64>>, vector<4xi1> -> vector<2x4xf32>
-  %2 = xegpu.load %1, %0 <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>, transpose}>
-        : !xegpu.tensor_desc<4x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2>>, vector<4xi1> -> vector<2x4xf32>
-  gpu.return
-}
-
-// CHECK: gpu.func @test_store_scatter_vc(%[[arg0:.*]]: ui64) {
-gpu.func @test_store_scatter_vc(%src: ui64) {
-  //CHECK: %[[c0:.*]] = arith.constant dense<true> : vector<4xi1>
-  %0 = arith.constant dense<1>: vector<4xi1>
-  //CHECK: %[[c1:.*]] = arith.constant dense<2.900000e+00> : vector<2x4xf32>
-  %1 = arith.constant dense<2.9>: vector<2x4xf32>
-  //CHECK: %[[c2:.*]] = arith.constant dense<[0, 8, 16, 24]> : vector<4xindex>
-  %c = arith.constant dense<[0, 8, 16, 24]> : vector<4xindex>
-  //CHECK: %[[R0:.*]] = xegpu.create_tdesc %[[arg0]], %[[c2]] : ui64, vector<4xindex> -> !xegpu.tensor_desc<4x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2 : i64>>
-  %2 = xegpu.create_tdesc %src, %c : ui64, vector<4xindex> -> !xegpu.tensor_desc<4x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2>>
-  //CHECK: xegpu.store %[[c1]], %[[R0]], %[[c0]] <{l1_hint = #xegpu.cache_hint<write_back>, l2_hint = #xegpu.cache_hint<uncached>, transpose}>
-  //CHECK-SAME: vector<2x4xf32>, !xegpu.tensor_desc<4x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2 : i64>>, vector<4xi1>
-  xegpu.store %1, %2, %0 <{l1_hint = #xegpu.cache_hint<write_back>, l2_hint = #xegpu.cache_hint<uncached>, transpose}>
-        : vector<2x4xf32>, !xegpu.tensor_desc<4x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2>>, vector<4xi1>
+// CHECK: gpu.func @test_create_update_tdesc_simt(%[[arg0:.*]]: ui64) {
+gpu.func @test_create_update_tdesc_simt(%src: ui64) {
+  //CHECK: %[[cst:.*]] = arith.constant dense<[0, 8, 16, 24]> : vector<4xindex>
+  //CHECK: %[[R0:.*]] = xegpu.create_tdesc %[[arg0]], %[[cst]] : ui64, vector<4xindex> -> !xegpu.tensor_desc<4x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2 : i64>, #xegpu.sg_map<wi_layout = [4, 1], wi_data = [1, 1]>>
+  //CHECK: %[[st:.*]] = arith.constant dense<32> : vector<4xindex>
+  //CHECK: %[[R1:.*]] = xegpu.update_offset %[[R0]], %[[st]] : !xegpu.tensor_desc<4x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2 : i64>, #xegpu.sg_map<wi_layout = [4, 1], wi_data = [1, 1]>>, vector<4xindex>
+  %0 = arith.constant dense<[0, 8, 16, 24]> : vector<4xindex>
+  %1 = xegpu.create_tdesc %src, %0 : ui64, vector<4xindex> -> !xegpu.tensor_desc<4x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2>, #xegpu.sg_map<wi_layout = [4, 1], wi_data = [1, 1]>>
+  %s = arith.constant dense<[32, 32, 32, 32]> : vector<4xindex>
+  %2 = xegpu.update_offset %1, %s : !xegpu.tensor_desc<4x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2>, #xegpu.sg_map<wi_layout = [4, 1], wi_data = [1, 1]>>, vector<4xindex>
   gpu.return
 }
 
@@ -662,6 +653,17 @@ gpu.func @test_dpas_vc(%a : vector<8x16xf16>, %b: vector<16x16xf16>) {
   gpu.return
 }
 
+// CHECK: gpu.func @test_dpas_simt(%[[arg0:.*]]: vector<8x1xf16>, %[[arg1:.*]]: vector<8x2xf16>)
+gpu.func @test_dpas_simt(%a : vector<8x1xf16>, %b: vector<8x2xf16>) {
+  // CHECK: xegpu.dpas %[[arg0]], %[[arg1]] {sg_map_a = #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>,
+  // CHECK-SAME: sg_map_b = #xegpu.sg_map<wi_layout = [1, 16], wi_data = [2, 1]>,
+  // CHECK-SAME: sg_map_c = #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>} : vector<8x1xf16>, vector<8x2xf16> -> vector<8x1xf32>
+  %1 = xegpu.dpas %a, %b {sg_map_a = #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>,
+                          sg_map_b = #xegpu.sg_map<wi_layout = [1, 16], wi_data = [2, 1]>,
+                          sg_map_c = #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>}
+                          : vector<8x1xf16>, vector<8x2xf16> -> vector<8x1xf32>
+  gpu.return
+}
 
 // CHECK: gpu.func @test_dpas_vc_with_packed_b(%[[arg0:.*]]: vector<8x16xf16>, %[[arg1:.*]]: vector<8x16x2xf16>)
 gpu.func @test_dpas_vc_with_packed_b(%a : vector<8x16xf16>, %b: vector<8x16x2xf16>) {

>From 13edb3368518ed700a5fe8ca04f85d21865bf342 Mon Sep 17 00:00:00 2001
From: Charitha Saumya <charitha.saumya.gusthinna.waduge at intel.com>
Date: Thu, 20 Feb 2025 16:26:09 +0000
Subject: [PATCH 7/8] save work

---
 mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td | 5 ++++-
 mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp         | 6 +++---
 2 files changed, 7 insertions(+), 4 deletions(-)

diff --git a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td
index 7560ede058faa..78e236b4ca170 100644
--- a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td
+++ b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td
@@ -757,7 +757,10 @@ def XeGPU_DpasOp : XeGPU_Op<"dpas", [Pure, AllElementTypesMatch<["lhs", "rhs"]>]
   let arguments = (ins
     XeGPU_DpasOpType : $lhs,
     XeGPU_DpasOpType : $rhs,
-    Optional<XeGPU_Vector2DType>: $acc);
+    Optional<XeGPU_Vector2DType>: $acc,
+    OptionalAttr<XeGPU_SGMapAttr>:$sg_map_a,
+    OptionalAttr<XeGPU_SGMapAttr>:$sg_map_b,
+    OptionalAttr<XeGPU_SGMapAttr>:$sg_map_c);
   let results = (outs XeGPU_Vector2DType: $result);
 
   let extraClassDeclaration = [{
diff --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
index ad8b4bf3427cc..fd8e53ad16ad2 100644
--- a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
+++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
@@ -625,9 +625,9 @@ LogicalResult DpasOp::verify() {
   auto rhsShape = getRhsType().getShape();
   auto resultShape = getResultType().getShape();
 
-  auto sgMapA = (*this)->getAttrOfType<xegpu::SGMapAttr>("sg_map_a");
-  auto sgMapB = (*this)->getAttrOfType<xegpu::SGMapAttr>("sg_map_b");
-  auto sgMapC = (*this)->getAttrOfType<xegpu::SGMapAttr>("sg_map_c");
+  auto sgMapA = getSgMapAAttr();
+  auto sgMapB = getSgMapBAttr();
+  auto sgMapC = getSgMapCAttr();
 
   // If sg_maps are not present, then the operation is in VC mode.
   if (!sgMapA && !sgMapB && !sgMapC) {

>From 0d4214830d4651d5533d11617fc1ce7a427a77cc Mon Sep 17 00:00:00 2001
From: Charitha Saumya <charitha.saumya.gusthinna.waduge at intel.com>
Date: Thu, 20 Feb 2025 19:22:19 +0000
Subject: [PATCH 8/8] save work

---
 mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp | 10 ++--
 mlir/test/Dialect/XeGPU/invalid.mlir   | 66 ++++++++++++++++++--------
 2 files changed, 52 insertions(+), 24 deletions(-)

diff --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
index fd8e53ad16ad2..19eb665d599ee 100644
--- a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
+++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
@@ -283,7 +283,7 @@ LogicalResult LoadNdOp::verify() {
   }
 
   auto sgMap = tdescTy.getSGMapAttr();
-  // sg_map not present means IR is in VC mode. In this case value shape must
+  // sg_map not present means IR is in SIMD mode. In this case value shape must
   // match adjusted tensor descriptor shape.
   if (!sgMap)
     return valueShape == adjustedTdescShape
@@ -338,7 +338,7 @@ LogicalResult StoreNdOp::verify() {
   auto valueShape = getShapeOf(valTy);
 
   auto sgMap = dstTy.getSGMapAttr();
-  // sg_map not present means IR is in VC mode. In this case value shape must
+  // sg_map not present means IR is in SIMD mode. In this case value shape must
   // match adjusted tensor descriptor shape.
   if (!sgMap)
     return valueShape == tdescShape
@@ -514,7 +514,7 @@ LogicalResult LoadGatherOp::verify() {
   }
 
   auto sgMap = tdescTy.getSGMapAttr();
-  // In VC mode, sg_map is not present. In this case, value shape must match
+  // In SIMD mode, sg_map is not present. In this case, value shape must match
   // the tensor descriptor shape.
   if (!sgMap)
     return valueShape == tdescShape
@@ -569,7 +569,7 @@ LogicalResult StoreScatterOp::verify() {
   }
 
   auto sgMap = tdescTy.getSGMapAttr();
-  // In VC mode, sg_map is not present. In this case, value shape must match
+  // In SIMD mode, sg_map is not present. In this case, value shape must match
   // the tensor descriptor shape.
   if (!sgMap)
     return valueShape == tdescShape
@@ -629,7 +629,7 @@ LogicalResult DpasOp::verify() {
   auto sgMapB = getSgMapBAttr();
   auto sgMapC = getSgMapCAttr();
 
-  // If sg_maps are not present, then the operation is in VC mode.
+  // If sg_maps are not present, then the operation is in SIMD mode.
   if (!sgMapA && !sgMapB && !sgMapC) {
     if (lhsRank != 2 || (rhsRank != 2 && rhsRank != 3) || resultRank != 2)
       return emitOpError(
diff --git a/mlir/test/Dialect/XeGPU/invalid.mlir b/mlir/test/Dialect/XeGPU/invalid.mlir
index 1ea452119fec9..584dc23d61e2e 100644
--- a/mlir/test/Dialect/XeGPU/invalid.mlir
+++ b/mlir/test/Dialect/XeGPU/invalid.mlir
@@ -253,9 +253,9 @@ func.func @test_create_tdesc_sg_map_3(%src: ui64) {
 func.func @test_load_gather_sg_map_1(%src: ui64) {
   %0 = arith.constant dense<1>: vector<4xi1>
   %cst = arith.constant dense<[0, 8, 16, 24]> : vector<4xindex>
-  %1 = xegpu.create_tdesc %src, %cst : ui64, vector<4xindex> -> !xegpu.tensor_desc<4x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2>, #xegpu.sg_map<wi_layout = [4, 1], wi_data = [1, 2]>>
+  %1 = xegpu.create_tdesc %src, %cst : ui64, vector<4xindex> -> !xegpu.tensor_desc<4x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2>, #xegpu.sg_map<wi_layout = [4, 1], wi_data = [1, 1]>>
   // expected-error@+1 {{Unexpected result shape(Expected shape: [2, 1], Given shape: [1, 2])}}
-  %2 = xegpu.load %1, %0 <{l1_hint = #xegpu.cache_hint<cached>, transpose}> : !xegpu.tensor_desc<4x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2>, #xegpu.sg_map<wi_layout = [4, 1], wi_data = [1, 2]>>, vector<4xi1> -> vector<1x2xf32>
+  %2 = xegpu.load %1, %0 <{l1_hint = #xegpu.cache_hint<cached>, transpose}> : !xegpu.tensor_desc<4x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2>, #xegpu.sg_map<wi_layout = [4, 1], wi_data = [1, 1]>>, vector<4xi1> -> vector<1x2xf32>
   return
 }
 
@@ -263,19 +263,9 @@ func.func @test_load_gather_sg_map_1(%src: ui64) {
 func.func @test_load_gather_sg_map_2(%src: ui64) {
   %0 = arith.constant dense<1>: vector<4xi1>
   %cst = arith.constant dense<[0, 8, 16, 24]> : vector<4xindex>
-  %1 = xegpu.create_tdesc %src, %cst : ui64, vector<4xindex> -> !xegpu.tensor_desc<4x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2>, #xegpu.sg_map<wi_layout = [4, 1], wi_data = [1, 2]>>
+  %1 = xegpu.create_tdesc %src, %cst : ui64, vector<4xindex> -> !xegpu.tensor_desc<4x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2>, #xegpu.sg_map<wi_layout = [4, 1], wi_data = [1, 1]>>
   // expected-error@+1 {{Unexpected result shape(Expected shape: [2, 1], Given shape: [2])}}
-  %2 = xegpu.load %1, %0 <{l1_hint = #xegpu.cache_hint<cached>, transpose}> : !xegpu.tensor_desc<4x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2>, #xegpu.sg_map<wi_layout = [4, 1], wi_data = [1, 2]>>, vector<4xi1> -> vector<2xf32>
-  return
-}
-
-// -----
-func.func @test_load_gather_sg_map_3(%src: ui64) {
-  %0 = arith.constant dense<1>: vector<4xi1>
-  %cst = arith.constant dense<[0, 8, 16, 24]> : vector<4xindex>
-  %1 = xegpu.create_tdesc %src, %cst : ui64, vector<4xindex> -> !xegpu.tensor_desc<4x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2>, #xegpu.sg_map<wi_layout = [4, 1], wi_data = [1, 2]>>
-  // expected-error@+1 {{Chunk size, vector size and wi_data must match}}
-  %2 = xegpu.load %1, %0 <{l1_hint = #xegpu.cache_hint<cached>, transpose}> : !xegpu.tensor_desc<4x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2>, #xegpu.sg_map<wi_layout = [4, 1], wi_data = [1, 2]>>, vector<4xi1> -> vector<1xf32>
+  %2 = xegpu.load %1, %0 <{l1_hint = #xegpu.cache_hint<cached>, transpose}> : !xegpu.tensor_desc<4x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2>, #xegpu.sg_map<wi_layout = [4, 1], wi_data = [1, 1]>>, vector<4xi1> -> vector<2xf32>
   return
 }
 
@@ -285,9 +275,9 @@ func.func @test_store_scatter_sg_map_1(%src: ui64) {
   %0 = arith.constant dense<1>: vector<4xi1>
   %cst = arith.constant dense<[0, 8, 16, 24]> : vector<4xindex>
   %val = arith.constant dense<2.9>: vector<1x2xf32>
-  %1 = xegpu.create_tdesc %src, %cst : ui64, vector<4xindex> -> !xegpu.tensor_desc<4x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2>, #xegpu.sg_map<wi_layout = [4, 1], wi_data = [1, 2]>>
+  %1 = xegpu.create_tdesc %src, %cst : ui64, vector<4xindex> -> !xegpu.tensor_desc<4x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2>, #xegpu.sg_map<wi_layout = [4, 1], wi_data = [1, 1]>>
   // expected-error@+1 {{Unexpected value shape(Expected shape: [2, 1], Given shape: [1, 2])}}
-  xegpu.store %val, %1, %0 <{l1_hint = #xegpu.cache_hint<cached>, transpose}> : vector<1x2xf32>, !xegpu.tensor_desc<4x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2>, #xegpu.sg_map<wi_layout = [4, 1], wi_data = [1, 2]>>, vector<4xi1>
+  xegpu.store %val, %1, %0 <{l1_hint = #xegpu.cache_hint<cached>, transpose}> : vector<1x2xf32>, !xegpu.tensor_desc<4x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2>, #xegpu.sg_map<wi_layout = [4, 1], wi_data = [1, 1]>>, vector<4xi1>
   return
 }
 
@@ -296,9 +286,9 @@ func.func @test_store_scatter_sg_map_2(%src: ui64) {
   %0 = arith.constant dense<1>: vector<4xi1>
   %cst = arith.constant dense<[0, 8, 16, 24]> : vector<4xindex>
   %val = arith.constant dense<2.9>: vector<2xf32>
-  %1 = xegpu.create_tdesc %src, %cst : ui64, vector<4xindex> -> !xegpu.tensor_desc<4x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2>, #xegpu.sg_map<wi_layout = [4, 1], wi_data = [1, 2]>>
+  %1 = xegpu.create_tdesc %src, %cst : ui64, vector<4xindex> -> !xegpu.tensor_desc<4x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2>, #xegpu.sg_map<wi_layout = [4, 1], wi_data = [1, 1]>>
   // expected-error@+1 {{Unexpected value shape(Expected shape: [2, 1], Given shape: [2])}}
-  xegpu.store %val, %1, %0 <{l1_hint = #xegpu.cache_hint<cached>, transpose}> : vector<2xf32>, !xegpu.tensor_desc<4x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2>, #xegpu.sg_map<wi_layout = [4, 1], wi_data = [1, 2]>>, vector<4xi1>
+  xegpu.store %val, %1, %0 <{l1_hint = #xegpu.cache_hint<cached>, transpose}> : vector<2xf32>, !xegpu.tensor_desc<4x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2>, #xegpu.sg_map<wi_layout = [4, 1], wi_data = [1, 1]>>, vector<4xi1>
   return
 }
 
@@ -358,11 +348,49 @@ func.func @test_dpas_vc_1(%a : vector<8x8xf16>, %b: vector<8x16x2xf16>) {
 
 // -----
 func.func @test_dpas_vc_2(%a : vector<8x8x2xf16>, %b: vector<8x16x2xf16>) {
-  // expected-error@+1 {{expecting lhs to be a 2D vector, and rhs to be either 2D or 3D (packed) vector}}
+  // expected-error@+1 {{expecting lhs and result to be a 2D vector, and rhs to be either 2D or 3D (packed) vector}}
   %1 = xegpu.dpas %a, %b : vector<8x8x2xf16>, vector<8x16x2xf16> -> vector<8x16xf32>
   return
 }
 
+// -----
+func.func @test_dpas_3(%a : vector<8x8xf16>, %b: vector<8x16x2xf16>) {
+  // expected-error@+1 {{K-dimension mismatch}}
+  %1 = xegpu.dpas %a, %b : vector<8x8xf16>, vector<8x16x2xf16> -> vector<8x16xf32>
+  return
+}
+
+// -----
+func.func @test_dpas_4(%a : vector<16x16xf16>, %b: vector<8x16x2xf16>) {
+  // expected-error@+1 {{M-dimension mismatch}}
+  %1 = xegpu.dpas %a, %b : vector<16x16xf16>, vector<8x16x2xf16> -> vector<8x16xf32>
+  return
+}
+
+// -----
+func.func @test_dpas_5(%a : vector<8x16xf16>, %b: vector<8x8x2xf16>) {
+  // expected-error@+1 {{N-dimension mismatch}}
+  %1 = xegpu.dpas %a, %b : vector<8x16xf16>, vector<8x8x2xf16> -> vector<8x16xf32>
+  return
+}
+
+// -----
+func.func @test_dpas_sg_map_1(%a : vector<8x1xf16>, %b: vector<8x2xf16>) {
+  // expected-error@+1 {{sg_map attributes for all operands and outputs are expected in SIMT xegpu::Dpas operation}}
+  %1 = xegpu.dpas %a, %b {sg_map_a = #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>} : vector<8x1xf16>, vector<8x2xf16> -> vector<8x1xf32>
+  return
+}
+
+// -----
+func.func @test_dpas_sg_map_2(%a : vector<8x1xf16>, %b: vector<4x2xf16>) {
+  // expected-error@+1 {{K-dimension mismatch}}
+  %1 = xegpu.dpas %a, %b {sg_map_a = #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>,
+                          sg_map_b = #xegpu.sg_map<wi_layout = [1, 16], wi_data = [2, 1]>,
+                          sg_map_c = #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>}
+                          : vector<8x1xf16>, vector<4x2xf16> -> vector<8x1xf32>
+  return
+}
+
 // -----
 func.func @test_atomic_rmw(%src: ui64, %value : vector<16x4xf32>, %mask : vector<16xi1>) {
   %0 = arith.constant dense<[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15]> : vector<16xindex>



More information about the Mlir-commits mailing list