[Mlir-commits] [mlir] [MLIR][XeGPU] Switch to 1D representation for SIMT code (PR #135116)

Chao Chen llvmlistbot at llvm.org
Thu Apr 17 08:45:20 PDT 2025


https://github.com/chencha3 updated https://github.com/llvm/llvm-project/pull/135116

>From 2a1d373a61ca10bca9064a2afa7ac1fb88a87fc8 Mon Sep 17 00:00:00 2001
From: Chao Chen <chao.chen at intel.com>
Date: Thu, 10 Apr 2025 18:45:30 +0000
Subject: [PATCH 1/6] Switch to 1D representation for SIMT

---
 .../include/mlir/Dialect/XeGPU/IR/XeGPUOps.td |  17 +-
 .../mlir/Dialect/XeGPU/IR/XeGPUTypes.td       |   3 +-
 mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp    |  26 +-
 mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp        | 227 +++++++++++-------
 mlir/test/Dialect/XeGPU/invalid.mlir          | 100 ++------
 mlir/test/Dialect/XeGPU/ops.mlir              | 162 ++++++-------
 6 files changed, 250 insertions(+), 285 deletions(-)

diff --git a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td
index 16a7f63d60c82..9af6eaf69aec3 100644
--- a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td
+++ b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td
@@ -833,16 +833,14 @@ def XeGPU_DpasOp : XeGPU_Op<"dpas", [Pure, AllElementTypesMatch<["lhs", "rhs"]>]
     data type, the matrices are `A: vector<8x16xf16>`, `B: vector<16x16xf16>`,
     and `C/D: vector<8x16xf32>`. Besides the matrix size requirements, DPAS
     also requires A and B to be loaded with the required data layout. Specially,
-
     VNNI layout is required for B operand. It is achieved via adding `packed`
     attribute to the `load_nd` operator.  Due to the VNNI transformation, B operands
     can be represented as a 3D vector, with the last dimension representing the VNNI
     factor, which is computed as `32/bit_width_of_elem_type`. Thus, `B: vector<16x16xf16>`
     can be represented as `B: vector<8x16x2xf16>`.
 
-    In SIMT mode, DpasOp expects layout attributes `a`, `b`, and `c` (only if acc is used)
-    which describe the data fragment owned by each work-item w.r.t. the tensor descriptor
-    these data are loaded from.
+    In SIMT code, each work-item from a subgroup holds a data fragment for A, B, C and the result,
+    which are represented as 1D vectors.
 
     Note: on PVC, the hardware can perform load with VNNI transformation when data
           element type is 16-bit or lower precision, taking 2 or 4 elements from
@@ -850,13 +848,10 @@ def XeGPU_DpasOp : XeGPU_Op<"dpas", [Pure, AllElementTypesMatch<["lhs", "rhs"]>]
   }];
 
   let arguments = (ins
-    XeGPU_DpasOpType : $lhs,
-    XeGPU_DpasOpType : $rhs,
-    Optional<XeGPU_Vector2DType>: $acc,
-    OptionalAttr<XeGPU_LayoutAttr>:$a_layout,
-    OptionalAttr<XeGPU_LayoutAttr>:$b_layout,
-    OptionalAttr<XeGPU_LayoutAttr>:$c_layout);
-  let results = (outs XeGPU_Vector2DType: $result);
+    XeGPU_DpasOprType : $lhs,
+    XeGPU_DpasOprType : $rhs,
+    Optional<XeGPU_DpasResType>: $acc);
+  let results = (outs XeGPU_DpasResType: $result);
 
   let extraClassDeclaration = [{
     VectorType getLhsType() {
diff --git a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td
index 173f1462fdd73..3cb71788a15ef 100644
--- a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td
+++ b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td
@@ -17,7 +17,8 @@ def XeGPU_IntType: AnyTypeOf<[I1, I8, I16, I32, I64, SI1, SI8, SI16, SI32, SI64,
 def XeGPU_FloatType: AnyTypeOf<[F16, F32, F64, BF16, TF32]>;
 def XeGPU_ScalarType: AnyTypeOf<[XeGPU_IntType, XeGPU_FloatType]>;
 def XeGPU_BaseAddrType: AnyTypeOf<[Non0RankedMemRefOf<[XeGPU_ScalarType]>, UI64, UI32, I64, I32]>;
-def XeGPU_DpasOpType: VectorOfRankAndType<[2, 3], [XeGPU_ScalarType]>;
+def XeGPU_DpasOprType: VectorOfRankAndType<[1, 2, 3], [XeGPU_ScalarType]>;
+def XeGPU_DpasResType: VectorOfRankAndType<[1, 2], [XeGPU_ScalarType]>;
 def XeGPU_OffsetType: VectorOfRankAndType<[1], [Index]>;
 def XeGPU_MaskType: AnyTypeOf<[VectorOfRankAndType<[1], [I1]>, I1]>;
 def XeGPU_ValueType: AnyTypeOf<[VectorOfRankAndType<[1,2,3,4], [XeGPU_ScalarType]>, XeGPU_ScalarType]>;
diff --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
index 171a15ce27b59..269e445c3790c 100644
--- a/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
+++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
@@ -10,6 +10,7 @@
 #include "mlir/IR/Builders.h"
 #include "mlir/IR/DialectImplementation.h"
 #include "llvm/ADT/TypeSwitch.h"
+#include <numeric>
 
 namespace mlir {
 namespace xegpu {
@@ -336,19 +337,20 @@ LogicalResult TensorDescType::verify(
 //        [n_distribution_units, lane_data_size]
 FailureOr<VectorType> TensorDescType::getDistributedVectorType() {
   auto layout = llvm::dyn_cast_if_present<LayoutAttr>(getLayout());
-  // If no layout is provided, tensor desc is not used in SIMT mode.
-  if (!layout)
+  // This only works for a subgroup-level layout, which has only lane_layout
+  // and lane_data, and is used to distribute SIMD code into SIMT code.
+  if (!layout || !layout.isSgLayout())
     return failure();
 
   SmallVector<int64_t> laneData(layout.getLaneData().asArrayRef());
   SmallVector<int64_t> laneLayout(layout.getLaneLayout().asArrayRef());
   auto tdescShape = getShape();
 
-  auto laneDataSize = 1, sgSize = 1;
-  for (auto [laneDim, laneDataDim] : llvm::zip_equal(laneLayout, laneData)) {
-    laneDataSize *= laneDataDim;
-    sgSize *= laneDim;
-  }
+  // compute sgSize by multiplying the elements of laneLayout
+  // e.g. for 2D layout, sgSize = laneLayout[0] * laneLayout[1]
+  // e.g. for 1D layout, sgSize = laneLayout[0]
+  auto sgSize = std::accumulate(laneLayout.begin(), laneLayout.end(), 1,
+                                std::multiplies<int64_t>());
 
   // Case 1: regular loads/stores
   auto scatterAttr = getEncodingAsScatterTensorDescAttr();
@@ -356,12 +358,9 @@ FailureOr<VectorType> TensorDescType::getDistributedVectorType() {
     auto chunkSize = scatterAttr.getChunkSize().getInt();
     // Verify if the first dimension of the tensor descriptor shape is
     // distributable.
-    assert(tdescShape[0] % (laneLayout[0]) == 0 &&
+    assert(tdescShape[0] == laneLayout[0] &&
            "tensor descriptor shape is not distributable");
-    if (chunkSize > 1)
-      return VectorType::get({chunkSize / laneDataSize, laneDataSize},
-                             getElementType());
-    return VectorType::get({laneDataSize}, getElementType());
+    return VectorType::get({chunkSize}, getElementType());
   }
 
   // Case 2: block loads/stores
@@ -376,8 +375,7 @@ FailureOr<VectorType> TensorDescType::getDistributedVectorType() {
   // tensorSize must be adjusted for array_length.
   tensorSize *= getArrayLength();
 
-  return VectorType::get({tensorSize / (sgSize * laneDataSize), laneDataSize},
-                         getElementType());
+  return VectorType::get({tensorSize / sgSize}, getElementType());
 }
 
 } // namespace xegpu
diff --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
index 0d67e3d70f945..fef39508c3bfe 100644
--- a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
+++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
@@ -73,38 +73,6 @@ static bool isWriteHintOrNone(const CachePolicyAttr &attr) {
          kind == CachePolicy::WRITE_BACK || kind == CachePolicy::WRITE_THROUGH;
 }
 
-// Helper to validate value shape of LoadNd and StoreNd ops.
-static LogicalResult
-isArgShapesValid(TensorDescType tdescTy, VectorType valueTy,
-                 ArrayRef<int64_t> adjustedTdescShape,
-                 function_ref<InFlightDiagnostic()> emitError) {
-  auto layout = tdescTy.getLayoutAttr();
-  auto valueShape = valueTy.getShape();
-  // layout not present means IR is in SIMD mode. In this case value shape must
-  // match adjusted tensor descriptor shape.
-  if (!layout)
-    return valueShape == adjustedTdescShape
-               ? success()
-               : emitError()
-                     << "Value shape " << makeString(valueShape)
-                     << " is not consistent with tensor descriptor " << tdescTy;
-
-  // layout present means IR is in SIMT mode. In this case layout determines the
-  // value shape.
-  auto expectedValueShapeOrFailure = tdescTy.getDistributedVectorType();
-  assert(succeeded(expectedValueShapeOrFailure) &&
-         "Failed to compute distributed vector shape for "
-         "tensor descriptor ");
-
-  return valueTy == expectedValueShapeOrFailure.value()
-             ? success()
-             : emitError()
-                   << "Result shape " << makeString(valueShape)
-                   << " is not consistent with distributed vector shape "
-                   << makeString(expectedValueShapeOrFailure.value().getShape())
-                   << " for tensor descriptor " << tdescTy;
-}
-
 // Checks if the given shape is evenly distributed based on the layout
 // and data factors provided by the LayoutAttr. The function ensures that
 // each dimension of the shape can be evenly divided by the corresponding
@@ -302,9 +270,35 @@ LogicalResult LoadNdOp::verify() {
   if (!isReadHintOrNone(getL3HintAttr()))
     return emitOpError("invalid l3_hint: ") << getL3HintAttr();
 
+  // Handling a 1D vector as the result can be complex. It may represent the
+  // outcome of a 1D block load in SIMD mode or a fragment of a block load
+  // result in SIMT mode. In the latter case, the tensor descriptor must be
+  // evenly distributed, with each lane holding an equally sized fragment of
+  // the result. Only subgroup size 8 or 16 is supported.
+  if (valueTy.getRank() == 1 &&
+      valueTy.getNumElements() < tdescTy.getNumElements()) {
+    // SIMT mode doesn't need LayoutAttr.
+    if (tdescTy.getLayoutAttr())
+      return emitOpError()
+             << "TensorDesc doesn't need LayoutAttr for SIMT code";
+
+    int tdescElems = tdescTy.getNumElements() * tdescTy.getArrayLength();
+    int valueElems = valueTy.getNumElements();
+
+    int lanes = tdescElems % valueElems == 0 ? tdescElems / valueElems : -1;
+    if (lanes != 16 && lanes != 8) {
+      return emitOpError()
+             << "Result shape " << makeString(getShapeOf(valueTy))
+             << " is not a valid distribution for tensor descriptor "
+             << tdescTy;
+    }
+    return success();
+  }
+
+  // Check SIMD mode.
   auto array_len = tdescTy.getArrayLength();
   // adjusted tensor descriptor shape tracks the expected shape of the result.
-  auto adjustedTdescShape = getShapeOf(tdescTy);
+  auto tdescShape = getShapeOf(tdescTy);
   auto valueShape = getShapeOf(valueTy);
 
   if (getTranspose()) {
@@ -316,7 +310,7 @@ LogicalResult LoadNdOp::verify() {
     });
 
     if (valid)
-      transpose(trans, adjustedTdescShape);
+      transpose(trans, tdescShape);
     else
       mlir::emitWarning(getLoc()) << "Invalid transpose attr. It is ignored.";
   }
@@ -325,8 +319,8 @@ LogicalResult LoadNdOp::verify() {
     if (tdescTy.getRank() == 2) {
       const int axis = 0;
       auto vnni_factor = valueShape.back();
-      adjustedTdescShape[axis] /= vnni_factor;
-      adjustedTdescShape.push_back(vnni_factor);
+      tdescShape[axis] /= vnni_factor;
+      tdescShape.push_back(vnni_factor);
     } else {
       mlir::emitWarning(getLoc())
           << "Invalid Packed Attr. It is ignored (available for 2D "
@@ -335,12 +329,16 @@ LogicalResult LoadNdOp::verify() {
   }
 
   if (array_len > 1) {
-    auto it = adjustedTdescShape.begin();
-    adjustedTdescShape.insert(it, array_len);
+    tdescShape.insert(tdescShape.begin(), array_len);
+  }
+
+  if (tdescShape != valueShape) {
+    return emitOpError() << "Result shape " << makeString(valueShape)
+                         << " is not consistent with tensor descriptor "
+                         << tdescTy;
   }
 
-  return isArgShapesValid(tdescTy, valueTy, adjustedTdescShape,
-                          [&]() { return emitOpError(); });
+  return success();
 }
 
 //===----------------------------------------------------------------------===//
@@ -371,8 +369,37 @@ LogicalResult StoreNdOp::verify() {
   auto tdescShape = getShapeOf(dstTy);
   auto valueShape = getShapeOf(valTy);
 
-  return isArgShapesValid(dstTy, valTy, tdescShape,
-                          [&]() { return emitOpError(); });
+  // Similar to LoadNdOp, handling a 1D vector as the value can be complex. It
+  // may represent the input of a 1D block store in SIMD mode or a fragment of
+  // a block store input in SIMT mode. In the latter case, the tensor descriptor
+  // must be evenly distributed, with each lane holding an equally sized
+  // fragment of the input. Only subgroup size 8 or 16 is supported.
+  if (valTy.getRank() == 1 && valTy.getNumElements() < dstTy.getNumElements()) {
+    // SIMT mode doesn't need LayoutAttr.
+    if (dstTy.getLayoutAttr())
+      return emitOpError()
+             << "TensorDesc doesn't need LayoutAttr for SIMT code";
+
+    int tdescElems = dstTy.getNumElements() * dstTy.getArrayLength();
+    int valueElems = valueShape[0];
+
+    int lanes = tdescElems % valueElems == 0 ? tdescElems / valueElems : -1;
+    if (lanes != 16 && lanes != 8) {
+      return emitOpError()
+             << "Value shape " << makeString(getShapeOf(valTy))
+             << " is not a valid distribution for tensor descriptor " << dstTy;
+    }
+    return success();
+  }
+
+  // SIMD code should have the same shape as the tensor descriptor.
+  if (tdescShape != valueShape) {
+    return emitOpError() << "Value shape " << makeString(valueShape)
+                         << " is not consistent with tensor descriptor "
+                         << dstTy;
+  }
+
+  return success();
 }
 
 //===----------------------------------------------------------------------===//
@@ -520,14 +547,41 @@ LogicalResult LoadGatherOp::verify() {
   if (tdescShape[0] != maskShape[0])
     return emitOpError("dim-0 of the Mask and TensorDesc should be the same.");
 
+  auto chunkSize = tdescTy.getChunkSize();
+  // for SIMT code, the value should be 1D vector with size of chunkSize.
+  if (valueTy.getRank() == 1 && valueTy.getNumElements() != tdescShape[0]) {
+    if (valueTy.getNumElements() != chunkSize) {
+      return emitOpError()
+             << "Result shape " << makeString(valueShape)
+             << " is not a valid distribution for tensor descriptor "
+             << tdescTy;
+    } else { // valid SIMT code doesn't need LayoutAttr and TransposeAttr.
+      if (tdescTy.getLayoutAttr())
+        return emitOpError()
+               << "TensorDesc doesn't need LayoutAttr for SIMT code";
+      if (getTransposeAttr())
+        return emitOpError() << "doesn't need TransposeAttr for SIMT code";
+    }
+    return success();
+  } else if (valueTy.getRank() == 1 && tdescShape[0] == chunkSize) {
+    // for 1D vector and valueTy.getNumElements() == tdescShape[0] case,
+    // it is a valid SIMT code if chunkSize happens to be the same as
+    // subgroup size, e.g., tensor_desc<16x16xf16, chunkSize = 16>
+    return success();
+  }
+
+  // For SIMD code verification.
   if (tdescTy.getRank() == 2) {
     if (!getTransposeAttr())
       return emitOpError("load of rank-2 tensor has to be transposed.");
     transpose({1, 0}, tdescShape);
   }
 
-  return isArgShapesValid(tdescTy, valueTy, tdescShape,
-                          [&]() { return emitOpError(); });
+  if (tdescShape != valueShape)
+    return emitOpError() << "Result shape " << makeString(valueShape)
+                         << " is not consistent with tensor descriptor "
+                         << tdescTy;
+  return success();
 }
 
 //===----------------------------------------------------------------------===//
@@ -559,14 +613,42 @@ LogicalResult StoreScatterOp::verify() {
   if (tdescShape[0] != maskShape[0])
     return emitOpError("dim-0 of the Mask and TensorDesc should be the same.");
 
+  auto chunkSize = tdescTy.getChunkSize();
+  // for SIMT code, the value should be 1D vector with size of chunkSize.
+  if (valueTy.getRank() == 1 && valueTy.getNumElements() != tdescShape[0]) {
+    if (valueTy.getNumElements() != chunkSize) {
+      return emitOpError()
+             << "Value shape " << makeString(valueShape)
+             << " is not a valid distribution for tensor descriptor "
+             << tdescTy;
+    } else { // valid SIMT code doesn't need LayoutAttr and TransposeAttr.
+      if (tdescTy.getLayoutAttr())
+        return emitOpError()
+               << "TensorDesc doesn't need LayoutAttr for SIMT code";
+      if (getTransposeAttr())
+        return emitOpError() << "doesn't need TransposeAttr for SIMT code";
+    }
+    return success();
+  } else if (valueTy.getRank() == 1 && tdescShape[0] == chunkSize) {
+    // for 1D vector and valueTy.getNumElements() == tdescShape[0] case,
+    // it is a valid SIMT code if chunkSize happens to be the same as
+    // subgroup size, e.g., tensor_desc<16x16xf16, chunkSize = 16>
+    return success();
+  }
+
+  // for SIMD code verification.
   if (tdescTy.getRank() == 2) {
     if (!getTransposeAttr())
       return emitOpError("Store of a rank-2 tensor has to be transposed.");
     transpose({1, 0}, tdescShape);
   }
 
-  return isArgShapesValid(tdescTy, valueTy, tdescShape,
-                          [&]() { return emitOpError(); });
+  if (tdescShape != valueShape)
+    return emitOpError() << "Value shape " << makeString(valueShape)
+                         << " is not consistent with tensor descriptor "
+                         << tdescTy;
+
+  return success();
 }
 
 //===----------------------------------------------------------------------===//
@@ -602,51 +684,16 @@ LogicalResult DpasOp::verify() {
   auto rhsShape = getRhsType().getShape();
   auto resShape = getResultType().getShape();
 
-  auto aLayout = getALayoutAttr();
-  auto bLayout = getBLayoutAttr();
-  auto cLayout = getCLayoutAttr();
-
-  // make sure the layout attribute is either set for every available
-  // operand or simply not set at all. C is special, since ACC is optional.
-  auto hasValidLayoutAttrs = [&]() {
-    bool result = (aLayout != nullptr) ^ (bLayout != nullptr);
-    if (hasAcc()) {
-      result |= (aLayout != nullptr) ^ (cLayout != nullptr);
-    }
-    return !result;
-  };
+  if (getAcc()) {
+    if (getAcc().getType() != getResultType())
+      return emitOpError("Expecting the acc type to be the same as result.");
+  }
 
-  if (!hasValidLayoutAttrs())
-    return emitOpError(
-        "layout attributes should be either set for all operands (for SIMT "
-        "code) or not set at all (for SIMD code).");
-
-  // query the scope from aLayout (a valid setting).
-  if (aLayout) {
-    // In SIMT mode, All data fragments must be 2D
-    if (lhsRank != 2 || rhsRank != 2 || resRank != 2)
-      return emitOpError("expecting lhs, rhs, and result to be a 2D vector.");
-
-    auto laneLayoutA = aLayout.getLaneLayout();
-    auto laneLayoutB = bLayout.getLaneLayout();
-    auto laneLayoutC = cLayout.getLaneLayout();
-    // Obtain the expanded shapes of the operands and result using lane_layout.
-    // NOTE: For B, get rid of the packed dimension for the expanded shape.
-    SmallVector<int64_t> expandedShapeA = {lhsShape[0] * laneLayoutA[0],
-                                           lhsShape[1] * laneLayoutA[1]};
-    SmallVector<int64_t> expandedShapeB = {
-        rhsShape[0] * rhsShape[1] * laneLayoutB[0], 1 * laneLayoutB[1]};
-    SmallVector<int64_t> expandedShapeC = {resShape[0] * laneLayoutC[0],
-                                           resShape[1] * laneLayoutC[1]};
-    auto bK = expandedShapeB[0];
-    if (bK != expandedShapeA[1])
-      return emitOpError("K-dimension mismatch.");
-    if (expandedShapeA[0] != expandedShapeC[0])
-      return emitOpError("M-dimension mismatch.");
-    if (expandedShapeB[1] != expandedShapeC[1])
-      return emitOpError("N-dimension mismatch.");
-  } else { // For other scopes, operands' shape should match the mxkxn
-           // semantics.
+  // SIMT code: skip the check since semantic info is lacking at this level.
+  // Users must ensure correctness themselves.
+  if (lhsRank == 1 && rhsRank == 1 && resRank == 1) {
+    return success();
+  } else { // SIMD code
     if (lhsRank != 2 || (rhsRank != 2 && rhsRank != 3) || resRank != 2)
       return emitOpError(
           "expecting lhs and result to be a 2D vector, and rhs to be either "
diff --git a/mlir/test/Dialect/XeGPU/invalid.mlir b/mlir/test/Dialect/XeGPU/invalid.mlir
index 48df33a591908..c0739d735dfec 100644
--- a/mlir/test/Dialect/XeGPU/invalid.mlir
+++ b/mlir/test/Dialect/XeGPU/invalid.mlir
@@ -79,25 +79,10 @@ func.func @test_load_nd_vc_3(%src: memref<8x16xf16>) {
 
 // -----
 func.func @test_load_nd_layout(%src: memref<24x32xf32>) {
-  %1 = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf32> ->
-    !xegpu.tensor_desc<8x16xf32,   #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
-  // expected-error@+1 {{Result shape [8, 2] is not consistent with distributed vector shape [8, 1]}}
-  %2 = xegpu.load_nd %1 <{l1_hint = #xegpu.cache_hint<cached>,
-      l2_hint = #xegpu.cache_hint<uncached>}>
-    : !xegpu.tensor_desc<8x16xf32,   #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
-    -> vector<8x2xf32>
-  return
-}
-
-// -----
-func.func @test_load_nd_layout(%src: memref<24x32xf32>) {
-  %1 = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf32> ->
-    !xegpu.tensor_desc<16xf32, #xegpu.layout<lane_layout = [16], lane_data = [1]>>
-  // expected-error@+1 {{Result shape [8] is not consistent with distributed vector shape [1, 1]}}
+  %1 = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf32> -> !xegpu.tensor_desc<16xf32>
+  // expected-error@+1 {{Result shape [8] is not a valid distribution for tensor descriptor}}
   %2 = xegpu.load_nd %1 <{l1_hint = #xegpu.cache_hint<cached>,
-      l2_hint = #xegpu.cache_hint<uncached>}>
-    : !xegpu.tensor_desc<16xf32, #xegpu.layout<lane_layout = [16], lane_data = [1]>>
-    -> vector<8xf32>
+      l2_hint = #xegpu.cache_hint<uncached>}> : !xegpu.tensor_desc<16xf32> -> vector<8xf32>
   return
 }
 
@@ -105,7 +90,7 @@ func.func @test_load_nd_layout(%src: memref<24x32xf32>) {
 func.func @test_load_nd_vc_6(%src: memref<24x32xf32>) {
   %1 = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf32> ->
     !xegpu.tensor_desc<8x16xf32>
-  // expected-error@+1 {{Value shape [8, 1] is not consistent with tensor descriptor}}
+  // expected-error@+1 {{Result shape [8, 1] is not consistent with tensor descriptor}}
   %2 = xegpu.load_nd %1 <{l1_hint = #xegpu.cache_hint<cached>,
       l2_hint = #xegpu.cache_hint<uncached>}>
     : !xegpu.tensor_desc<8x16xf32> -> vector<8x1xf32>
@@ -134,22 +119,10 @@ func.func @test_store_nd_vc_2(%dst: memref<16xf16>) {
 }
 
 // -----
-func.func @test_store_nd_layout(%dst: memref<24x32xf32>, %data: vector<8x2xf32>) {
-  %1 = xegpu.create_nd_tdesc %dst[0, 0] : memref<24x32xf32> ->
-    !xegpu.tensor_desc<8x16xf32,   #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
-  // expected-error@+1 {{Result shape [8, 2] is not consistent with distributed vector shape [8, 1] for tensor descriptor}}
-  xegpu.store_nd %data, %1
-    : vector<8x2xf32>, !xegpu.tensor_desc<8x16xf32,   #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
-  return
-}
-
-// -----
-func.func @test_store_nd_layout(%dst: memref<24x32xf32>, %data: vector<2xf32>) {
-  %1 = xegpu.create_nd_tdesc %dst[0, 0] : memref<24x32xf32> ->
-    !xegpu.tensor_desc<16xf32, #xegpu.layout<lane_layout = [16], lane_data = [1]>>
-  // expected-error@+1 {{Result shape [2] is not consistent with distributed vector shape [1, 1] for tensor descriptor}}
-  xegpu.store_nd %data, %1
-    : vector<2xf32>, !xegpu.tensor_desc<16xf32, #xegpu.layout<lane_layout = [16], lane_data = [1]>>
+func.func @test_store_nd_simt(%dst: memref<24x32xf32>, %data: vector<4xf32>) {
+  %1 = xegpu.create_nd_tdesc %dst[0, 0] : memref<24x32xf32> -> !xegpu.tensor_desc<16xf32>
+  // expected-error@+1 {{Value shape [4] is not a valid distribution for tensor descriptor}}
+  xegpu.store_nd %data, %1 : vector<4xf32>, !xegpu.tensor_desc<16xf32>
   return
 }
 
@@ -269,45 +242,23 @@ func.func @test_create_tdesc_layout_3(%src: ui64) {
 }
 
 // -----
-func.func @test_load_gather_layout_1(%src: ui64) {
+func.func @test_load_gather_simt_1(%src: ui64) {
   %0 = arith.constant dense<1>: vector<4xi1>
   %cst = arith.constant dense<[0, 8, 16, 24]> : vector<4xindex>
-  %1 = xegpu.create_tdesc %src, %cst : ui64, vector<4xindex> -> !xegpu.tensor_desc<4x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2>,   #xegpu.layout<lane_layout = [4, 1], lane_data = [1, 1]>>
-  // expected-error@+1 {{Result shape [1, 2] is not consistent with distributed vector shape [2, 1] for tensor descriptor}}
-  %2 = xegpu.load %1, %0 <{l1_hint = #xegpu.cache_hint<cached>, transpose}> : !xegpu.tensor_desc<4x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2>,  #xegpu.layout<lane_layout = [4, 1], lane_data = [1, 1]>>, vector<4xi1> -> vector<1x2xf32>
+  %1 = xegpu.create_tdesc %src, %cst : ui64, vector<4xindex> -> !xegpu.tensor_desc<4x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2>>
+  // expected-error@+1 {{Result shape [6] is not a valid distribution for tensor descriptor}}
+  %2 = xegpu.load %1, %0 <{l1_hint = #xegpu.cache_hint<cached>}> : !xegpu.tensor_desc<4x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2>>, vector<4xi1> -> vector<6xf32>
   return
 }
 
 // -----
-func.func @test_load_gather_layout_2(%src: ui64) {
+func.func @test_store_scatter_simt_1(%src: ui64) {
   %0 = arith.constant dense<1>: vector<4xi1>
   %cst = arith.constant dense<[0, 8, 16, 24]> : vector<4xindex>
-  %1 = xegpu.create_tdesc %src, %cst : ui64, vector<4xindex> -> !xegpu.tensor_desc<4x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2>,  #xegpu.layout<lane_layout = [4, 1], lane_data = [1, 1]>>
-  // expected-error@+1 {{esult shape [2] is not consistent with distributed vector shape [2, 1] for tensor descriptor}}
-  %2 = xegpu.load %1, %0 <{l1_hint = #xegpu.cache_hint<cached>, transpose}> : !xegpu.tensor_desc<4x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2>,  #xegpu.layout<lane_layout = [4, 1], lane_data = [1, 1]>>, vector<4xi1> -> vector<2xf32>
-  return
-}
-
-
-// -----
-func.func @test_store_scatter_layout_1(%src: ui64) {
-  %0 = arith.constant dense<1>: vector<4xi1>
-  %cst = arith.constant dense<[0, 8, 16, 24]> : vector<4xindex>
-  %val = arith.constant dense<2.9>: vector<1x2xf32>
-  %1 = xegpu.create_tdesc %src, %cst : ui64, vector<4xindex> -> !xegpu.tensor_desc<4x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2>,  #xegpu.layout<lane_layout = [4, 1], lane_data = [1, 1]>>
-  // expected-error@+1 {{Result shape [1, 2] is not consistent with distributed vector shape [2, 1] for tensor descriptor}}
-  xegpu.store %val, %1, %0 <{l1_hint = #xegpu.cache_hint<cached>, transpose}> : vector<1x2xf32>, !xegpu.tensor_desc<4x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2>,  #xegpu.layout<lane_layout = [4, 1], lane_data = [1, 1]>>, vector<4xi1>
-  return
-}
-
-// -----
-func.func @test_store_scatter_layout_2(%src: ui64) {
-  %0 = arith.constant dense<1>: vector<4xi1>
-  %cst = arith.constant dense<[0, 8, 16, 24]> : vector<4xindex>
-  %val = arith.constant dense<2.9>: vector<2xf32>
-  %1 = xegpu.create_tdesc %src, %cst : ui64, vector<4xindex> -> !xegpu.tensor_desc<4x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2>,  #xegpu.layout<lane_layout = [4, 1], lane_data = [1, 1]>>
-  // expected-error@+1 {{esult shape [2] is not consistent with distributed vector shape [2, 1] for tensor descriptor}}
-  xegpu.store %val, %1, %0 <{l1_hint = #xegpu.cache_hint<cached>, transpose}> : vector<2xf32>, !xegpu.tensor_desc<4x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2>,  #xegpu.layout<lane_layout = [4, 1], lane_data = [1, 1]>>, vector<4xi1>
+  %val = arith.constant dense<2.9>: vector<6xf32>
+  %1 = xegpu.create_tdesc %src, %cst : ui64, vector<4xindex> -> !xegpu.tensor_desc<4x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2>>
+  // expected-error@+1 {{Value shape [6] is not a valid distribution for tensor descriptor}}
+  xegpu.store %val, %1, %0 <{l1_hint = #xegpu.cache_hint<cached>}> : vector<6xf32>, !xegpu.tensor_desc<4x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2>>, vector<4xi1>
   return
 }
 
@@ -393,23 +344,6 @@ func.func @test_dpas_4(%a : vector<8x16xf16>, %b: vector<8x8x2xf16>) {
   return
 }
 
-// -----
-func.func @test_dpas_layout_1(%a : vector<8x1xf16>, %b: vector<8x2xf16>) {
-  // expected-error@+1 {{layout attributes should be either set for all operands (for SIMT code) or not set at all (for SIMD code)}}
-  %1 = xegpu.dpas %a, %b {a_layout =  #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>} : vector<8x1xf16>, vector<8x2xf16> -> vector<8x1xf32>
-  return
-}
-
-// -----
-func.func @test_dpas_layout_2(%a : vector<8x1xf16>, %b: vector<4x2xf16>) {
-  // expected-error@+1 {{K-dimension mismatch}}
-  %1 = xegpu.dpas %a, %b {a_layout =  #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>,
-                          b_layout =  #xegpu.layout<lane_layout = [1, 16], lane_data = [2, 1]>,
-                          c_layout =  #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>}
-                          : vector<8x1xf16>, vector<4x2xf16> -> vector<8x1xf32>
-  return
-}
-
 // -----
 func.func @test_atomic_rmw(%src: ui64, %value : vector<16x4xf32>, %mask : vector<16xi1>) {
   %0 = arith.constant dense<[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15]> : vector<16xindex>
diff --git a/mlir/test/Dialect/XeGPU/ops.mlir b/mlir/test/Dialect/XeGPU/ops.mlir
index e9895e0d0a71d..71e7e9bdda07d 100644
--- a/mlir/test/Dialect/XeGPU/ops.mlir
+++ b/mlir/test/Dialect/XeGPU/ops.mlir
@@ -125,11 +125,11 @@ gpu.func @test_load_nd_vc(%src: memref<8x16xf16>) {
 
 // CHECK: func @test_load_nd_simt(%[[arg0:.*]]: memref<8x16xf16>) {
 gpu.func @test_load_nd_simt(%src: memref<8x16xf16>) {
-  // CHECK: %[[R0:.*]] = xegpu.create_nd_tdesc %arg0[0, 0] : memref<8x16xf16> -> !xegpu.tensor_desc<8x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [2, 1]>>
-  %1 = xegpu.create_nd_tdesc %src[0, 0] : memref<8x16xf16> -> !xegpu.tensor_desc<8x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [2, 1]>>
-  // CHECK: %[[R1:.*]] = xegpu.load_nd %[[R0]] <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>, packed}> : !xegpu.tensor_desc<8x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [2, 1]>> -> vector<4x2xf16>
+  // CHECK: %[[R0:.*]] = xegpu.create_nd_tdesc %arg0[0, 0] : memref<8x16xf16> -> !xegpu.tensor_desc<8x16xf16>
+  %1 = xegpu.create_nd_tdesc %src[0, 0] : memref<8x16xf16> -> !xegpu.tensor_desc<8x16xf16>
+  // CHECK: %[[R1:.*]] = xegpu.load_nd %[[R0]] <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>, packed}> : !xegpu.tensor_desc<8x16xf16> -> vector<8xf16>
   %2 = xegpu.load_nd %1 <{packed, l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}>
-       : !xegpu.tensor_desc<8x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [2, 1]>> -> vector<4x2xf16>
+       : !xegpu.tensor_desc<8x16xf16> -> vector<8xf16>
   gpu.return
 }
 
@@ -144,10 +144,10 @@ gpu.func @test_load_nd_vc_2(%src: memref<8x16xf16>) {
 
 // CHECK: func @test_load_nd_simt_2(%[[arg0:.*]]: memref<8x16xf16>) {
 gpu.func @test_load_nd_simt_2(%src: memref<8x16xf16>) {
-  // CHECK: %[[R0:.*]] = xegpu.create_nd_tdesc %arg0[0, 0] : memref<8x16xf16> -> !xegpu.tensor_desc<16xf16, #xegpu.layout<lane_layout = [16], lane_data = [1]>>
-  %1 = xegpu.create_nd_tdesc %src[0, 0] : memref<8x16xf16> -> !xegpu.tensor_desc<16xf16, #xegpu.layout<lane_layout = [16], lane_data = [1]>>
-  // CHECK: %[[R1:.*]] = xegpu.load_nd %[[R0]] <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}> : !xegpu.tensor_desc<16xf16, #xegpu.layout<lane_layout = [16], lane_data = [1]>> -> vector<1x1xf16>
-  %2 = xegpu.load_nd %1 <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}> : !xegpu.tensor_desc<16xf16, #xegpu.layout<lane_layout = [16], lane_data = [1]>> -> vector<1x1xf16>
+  // CHECK: %[[R0:.*]] = xegpu.create_nd_tdesc %arg0[0, 0] : memref<8x16xf16> -> !xegpu.tensor_desc<16xf16>
+  %1 = xegpu.create_nd_tdesc %src[0, 0] : memref<8x16xf16> -> !xegpu.tensor_desc<16xf16>
+  // CHECK: %[[R1:.*]] = xegpu.load_nd %[[R0]] <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}> : !xegpu.tensor_desc<16xf16> -> vector<1xf16>
+  %2 = xegpu.load_nd %1 <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}> : !xegpu.tensor_desc<16xf16> -> vector<1xf16>
   gpu.return
 }
 
@@ -162,11 +162,10 @@ gpu.func @test_load_nd_vc_3(%src: memref<24x32xf32>) {
 
 // CHECK: func @test_load_nd_simt_3(%[[arg0:.*]]: memref<24x32xf32>) {
 gpu.func @test_load_nd_simt_3(%src: memref<24x32xf32>) {
-  // CHECK: %[[R0:.*]] = xegpu.create_nd_tdesc %arg0[0, 0] : memref<24x32xf32> -> !xegpu.tensor_desc<8x16xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
-  %1 = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf32> ->
-    !xegpu.tensor_desc<8x16xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
-  // CHECK: %[[R1:.*]] = xegpu.load_nd %[[R0]] <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}> : !xegpu.tensor_desc<8x16xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>> -> vector<8x1xf32>
-  %2 = xegpu.load_nd %1 <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}> : !xegpu.tensor_desc<8x16xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>> -> vector<8x1xf32>
+  // CHECK: %[[R0:.*]] = xegpu.create_nd_tdesc %arg0[0, 0] : memref<24x32xf32> -> !xegpu.tensor_desc<8x16xf32>
+  %1 = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf32> -> !xegpu.tensor_desc<8x16xf32>
+  // CHECK: %[[R1:.*]] = xegpu.load_nd %[[R0]] <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}> : !xegpu.tensor_desc<8x16xf32> -> vector<8xf32>
+  %2 = xegpu.load_nd %1 <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}> : !xegpu.tensor_desc<8x16xf32> -> vector<8xf32>
   gpu.return
 }
 
@@ -181,11 +180,10 @@ gpu.func @test_load_nd_vc_4(%src: memref<24x32xf16>) {
 
 // CHECK: func @test_load_nd_simt_4(%[[arg0:.*]]: memref<24x32xf16>) {
 gpu.func @test_load_nd_simt_4(%src: memref<24x32xf16>) {
-  // CHECK: %[[R0:.*]] = xegpu.create_nd_tdesc %arg0[0, 0] : memref<24x32xf16> -> !xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [2, 1]>>
-  %1 = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf16> ->
-    !xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [2, 1]>>
-  // CHECK: %[[R1:.*]] = xegpu.load_nd %[[R0]] <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}> : !xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [2, 1]>> -> vector<8x2xf16>
-  %2 = xegpu.load_nd %1 <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}> : !xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [2, 1]>> -> vector<8x2xf16>
+  // CHECK: %[[R0:.*]] = xegpu.create_nd_tdesc %arg0[0, 0] : memref<24x32xf16> -> !xegpu.tensor_desc<16x16xf16>
+  %1 = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf16> -> !xegpu.tensor_desc<16x16xf16>
+  // CHECK: %[[R1:.*]] = xegpu.load_nd %[[R0]] <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}> : !xegpu.tensor_desc<16x16xf16> -> vector<16xf16>
+  %2 = xegpu.load_nd %1 <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}> : !xegpu.tensor_desc<16x16xf16> -> vector<16xf16>
   gpu.return
 }
 
@@ -200,11 +198,10 @@ gpu.func @test_load_nd_vc_5(%src: memref<24x32xf32>) {
 
 // CHECK: func @test_load_nd_simt_5(%[[arg0:.*]]: memref<24x32xf32>) {
 gpu.func @test_load_nd_simt_5(%src: memref<24x32xf32>) {
-  // CHECK: %[[R0:.*]] = xegpu.create_nd_tdesc %arg0[0, 0] : memref<24x32xf32> -> !xegpu.tensor_desc<32xf32, #xegpu.layout<lane_layout = [16], lane_data = [1]>>
-  %1 = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf32> ->
-    !xegpu.tensor_desc<32xf32, #xegpu.layout<lane_layout = [16], lane_data = [1]>>
-  // CHECK: %[[R1:.*]] = xegpu.load_nd %[[R0]] <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}> : !xegpu.tensor_desc<32xf32, #xegpu.layout<lane_layout = [16], lane_data = [1]>> -> vector<2x1xf32>
-  %2 = xegpu.load_nd %1 <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}> : !xegpu.tensor_desc<32xf32, #xegpu.layout<lane_layout = [16], lane_data = [1]>> -> vector<2x1xf32>
+  // CHECK: %[[R0:.*]] = xegpu.create_nd_tdesc %arg0[0, 0] : memref<24x32xf32> -> !xegpu.tensor_desc<32xf32>
+  %1 = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf32> -> !xegpu.tensor_desc<32xf32>
+  // CHECK: %[[R1:.*]] = xegpu.load_nd %[[R0]] <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}> : !xegpu.tensor_desc<32xf32> -> vector<2xf32>
+  %2 = xegpu.load_nd %1 <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}> : !xegpu.tensor_desc<32xf32> -> vector<2xf32>
   gpu.return
 }
 
@@ -219,11 +216,11 @@ gpu.func @test_load_nd_vc_6(%src: memref<24x32xf16>) {
 
 // CHECK: func @test_load_nd_simt_6(%[[arg0:.*]]: memref<24x32xf16>) {
 gpu.func @test_load_nd_simt_6(%src: memref<24x32xf16>) {
-  // CHECK: %[[R0:.*]] = xegpu.create_nd_tdesc %arg0[0, 0] : memref<24x32xf16> -> !xegpu.tensor_desc<16x16xf16, #xegpu.block_tdesc_attr<array_length = 2 : i64>, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
-  %1 = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf16> -> !xegpu.tensor_desc<16x16xf16, #xegpu.block_tdesc_attr<array_length = 2>, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
-  // CHECK: %[[R1:.*]] = xegpu.load_nd %[[R0]] <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}> : !xegpu.tensor_desc<16x16xf16, #xegpu.block_tdesc_attr<array_length = 2 : i64>, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>> -> vector<32x1xf16>
+  // CHECK: %[[R0:.*]] = xegpu.create_nd_tdesc %arg0[0, 0] : memref<24x32xf16> -> !xegpu.tensor_desc<16x16xf16, #xegpu.block_tdesc_attr<array_length = 2 : i64>>
+  %1 = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf16> -> !xegpu.tensor_desc<16x16xf16, #xegpu.block_tdesc_attr<array_length = 2>>
+  // CHECK: %[[R1:.*]] = xegpu.load_nd %[[R0]] <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}> : !xegpu.tensor_desc<16x16xf16, #xegpu.block_tdesc_attr<array_length = 2 : i64>> -> vector<32xf16>
   %2 = xegpu.load_nd %1 <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}> :
-    !xegpu.tensor_desc<16x16xf16, #xegpu.block_tdesc_attr<array_length = 2>, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>> -> vector<32x1xf16>
+    !xegpu.tensor_desc<16x16xf16, #xegpu.block_tdesc_attr<array_length = 2>> -> vector<32xf16>
   gpu.return
 }
 
@@ -238,11 +235,11 @@ gpu.func @test_load_nd_vc_7(%src: memref<24x32xf16>) {
 
 // CHECK: func @test_load_nd_simt_7(%[[arg0:.*]]: memref<24x32xf16>) {
 gpu.func @test_load_nd_simt_7(%src: memref<24x32xf16>) {
-  // CHECK: %[[R0:.*]] = xegpu.create_nd_tdesc %arg0[0, 0] : memref<24x32xf16> -> !xegpu.tensor_desc<16x16xf16, #xegpu.block_tdesc_attr<array_length = 2 : i64>, #xegpu.layout<lane_layout = [1, 16], lane_data = [2, 1]>>
-  %1 = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf16> -> !xegpu.tensor_desc<16x16xf16, #xegpu.block_tdesc_attr<array_length = 2>, #xegpu.layout<lane_layout = [1, 16], lane_data = [2, 1]>>
-  // CHECK: %[[R1:.*]] = xegpu.load_nd %[[R0]] <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}> : !xegpu.tensor_desc<16x16xf16, #xegpu.block_tdesc_attr<array_length = 2 : i64>, #xegpu.layout<lane_layout = [1, 16], lane_data = [2, 1]>> -> vector<16x2xf16>
+  // CHECK: %[[R0:.*]] = xegpu.create_nd_tdesc %arg0[0, 0] : memref<24x32xf16> -> !xegpu.tensor_desc<16x16xf16, #xegpu.block_tdesc_attr<array_length = 2 : i64>>
+  %1 = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf16> -> !xegpu.tensor_desc<16x16xf16, #xegpu.block_tdesc_attr<array_length = 2>>
+  // CHECK: %[[R1:.*]] = xegpu.load_nd %[[R0]] <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}> : !xegpu.tensor_desc<16x16xf16, #xegpu.block_tdesc_attr<array_length = 2 : i64>> -> vector<32xf16>
   %2 = xegpu.load_nd %1 <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}> :
-    !xegpu.tensor_desc<16x16xf16, #xegpu.block_tdesc_attr<array_length = 2>, #xegpu.layout<lane_layout = [1, 16], lane_data = [2, 1]>> -> vector<16x2xf16>
+    !xegpu.tensor_desc<16x16xf16, #xegpu.block_tdesc_attr<array_length = 2>> -> vector<32xf16>
   gpu.return
 }
 
@@ -257,10 +254,10 @@ gpu.func @test_load_nd_vc_8(%src: memref<24x32xf32>) {
 
 // CHECK: func @test_load_nd_simt_8(%[[arg0:.*]]: memref<24x32xf32>) {
 gpu.func @test_load_nd_simt_8(%src: memref<24x32xf32>) {
-  // CHECK: %[[R0:.*]] = xegpu.create_nd_tdesc %arg0[0, 0] : memref<24x32xf32> -> !xegpu.tensor_desc<16x8xf32, #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 1]>>
-  %1 = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf32> -> !xegpu.tensor_desc<16x8xf32, #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 1]>>
-  // CHECK: %[[R1:.*]] = xegpu.load_nd %[[R0]] <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>, transpose = array<i64: 1, 0>}> : !xegpu.tensor_desc<16x8xf32, #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 1]>> -> vector<8x1xf32>
-  %2 = xegpu.load_nd %1 <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>, transpose = array<i64: 1, 0>}> : !xegpu.tensor_desc<16x8xf32, #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 1]>> -> vector<8x1xf32>
+  // CHECK: %[[R0:.*]] = xegpu.create_nd_tdesc %arg0[0, 0] : memref<24x32xf32> -> !xegpu.tensor_desc<16x8xf32>
+  %1 = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf32> -> !xegpu.tensor_desc<16x8xf32>
+  // CHECK: %[[R1:.*]] = xegpu.load_nd %[[R0]] <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>, transpose = array<i64: 1, 0>}> : !xegpu.tensor_desc<16x8xf32> -> vector<8xf32>
+  %2 = xegpu.load_nd %1 <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>, transpose = array<i64: 1, 0>}> : !xegpu.tensor_desc<16x8xf32> -> vector<8xf32>
   gpu.return
 }
 
@@ -277,13 +274,12 @@ gpu.func @test_store_nd_vc(%dst: memref<24x32xf16>) {
 
 // CHECK: func @test_store_nd_simt(%[[arg0:.*]]: memref<24x32xf16>) {
 gpu.func @test_store_nd_simt(%src: memref<24x32xf16>) {
-   // CHECK: %[[C:.*]] = arith.constant dense<1.000000e+00> : vector<48x1xf16>
-  %1 = arith.constant dense<1.0>: vector<48x1xf16>
-  // CHECK: %[[R0:.*]] = xegpu.create_nd_tdesc %arg0[0, 0] : memref<24x32xf16> -> !xegpu.tensor_desc<24x32xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
-  %2 = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf16> ->
-    !xegpu.tensor_desc<24x32xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
-  // CHECK: xegpu.store_nd %[[C]], %[[R0]] <{l1_hint = #xegpu.cache_hint<write_back>, l2_hint = #xegpu.cache_hint<uncached>}> : vector<48x1xf16>, !xegpu.tensor_desc<24x32xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
-  xegpu.store_nd %1, %2 <{l1_hint = #xegpu.cache_hint<write_back>, l2_hint = #xegpu.cache_hint<uncached>}>: vector<48x1xf16>, !xegpu.tensor_desc<24x32xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
+  // CHECK: %[[C:.*]] = arith.constant dense<1.000000e+00> : vector<48xf16>
+  %1 = arith.constant dense<1.0>: vector<48xf16>
+  // CHECK: %[[R0:.*]] = xegpu.create_nd_tdesc %arg0[0, 0] : memref<24x32xf16> -> !xegpu.tensor_desc<24x32xf16>
+  %2 = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf16> -> !xegpu.tensor_desc<24x32xf16>
+  // CHECK: xegpu.store_nd %[[C]], %[[R0]] <{l1_hint = #xegpu.cache_hint<write_back>, l2_hint = #xegpu.cache_hint<uncached>}> : vector<48xf16>, !xegpu.tensor_desc<24x32xf16>
+  xegpu.store_nd %1, %2 <{l1_hint = #xegpu.cache_hint<write_back>, l2_hint = #xegpu.cache_hint<uncached>}>: vector<48xf16>, !xegpu.tensor_desc<24x32xf16>
   gpu.return
 }
 
@@ -303,13 +299,12 @@ gpu.func @test_store_nd_vc_2(%dst: memref<24x32xf16>) {
 
 // CHECK: func @test_store_nd_simt_2(%[[arg0:.*]]: memref<24x32xf16>) {
 gpu.func @test_store_nd_simt_2(%src: memref<24x32xf16>) {
-   // CHECK: %[[C:.*]] = arith.constant dense<1.000000e+00> : vector<2x1xf16>
-  %1 = arith.constant dense<1.0>: vector<2x1xf16>
-  // CHECK: %[[R0:.*]] = xegpu.create_nd_tdesc %arg0[0, 0] : memref<24x32xf16> -> !xegpu.tensor_desc<32xf16, #xegpu.layout<lane_layout = [16], lane_data = [1]>>
-  %2 = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf16> ->
-    !xegpu.tensor_desc<32xf16, #xegpu.layout<lane_layout = [16], lane_data = [1]>>
-  // CHECK: xegpu.store_nd %[[C]], %[[R0]] <{l1_hint = #xegpu.cache_hint<write_back>, l2_hint = #xegpu.cache_hint<uncached>}> : vector<2x1xf16>, !xegpu.tensor_desc<32xf16, #xegpu.layout<lane_layout = [16], lane_data = [1]>>
-  xegpu.store_nd %1, %2 <{l1_hint = #xegpu.cache_hint<write_back>, l2_hint = #xegpu.cache_hint<uncached>}>: vector<2x1xf16>, !xegpu.tensor_desc<32xf16, #xegpu.layout<lane_layout = [16], lane_data = [1]>>
+  // CHECK: %[[C:.*]] = arith.constant dense<1.000000e+00> : vector<2xf16>
+  %1 = arith.constant dense<1.0>: vector<2xf16>
+  // CHECK: %[[R0:.*]] = xegpu.create_nd_tdesc %arg0[0, 0] : memref<24x32xf16> -> !xegpu.tensor_desc<32xf16>
+  %2 = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf16> -> !xegpu.tensor_desc<32xf16>
+  // CHECK: xegpu.store_nd %[[C]], %[[R0]] <{l1_hint = #xegpu.cache_hint<write_back>, l2_hint = #xegpu.cache_hint<uncached>}> : vector<2xf16>, !xegpu.tensor_desc<32xf16>
+  xegpu.store_nd %1, %2 <{l1_hint = #xegpu.cache_hint<write_back>, l2_hint = #xegpu.cache_hint<uncached>}>: vector<2xf16>, !xegpu.tensor_desc<32xf16>
   gpu.return
 }
 
@@ -425,10 +420,10 @@ gpu.func @test_load_simt(%src: ui64) {
   %0 = arith.constant dense<[0, 8, 16, 24]> : vector<4xindex>
   //CHECK: %[[cst1:.*]] = arith.constant dense<true> : vector<4xi1>
   %1 = arith.constant dense<1>: vector<4xi1>
-  //CHECK: %[[R0:.*]] = xegpu.create_tdesc %[[arg0]], %[[cst]] : ui64, vector<4xindex> -> !xegpu.tensor_desc<4x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2 : i64>, #xegpu.layout<lane_layout = [4, 1], lane_data = [1, 1]>>
-  %2 = xegpu.create_tdesc %src, %0 : ui64, vector<4xindex> -> !xegpu.tensor_desc<4x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2>, #xegpu.layout<lane_layout = [4, 1], lane_data = [1, 1]>>
-  //CHECK: %[[R1:.*]] = xegpu.load %[[R0]], %[[cst1]] <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>, transpose}> : !xegpu.tensor_desc<4x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2 : i64>, #xegpu.layout<lane_layout = [4, 1], lane_data = [1, 1]>>, vector<4xi1> -> vector<2x1xf32>
-  %3 = xegpu.load %2, %1 <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>, transpose}> : !xegpu.tensor_desc<4x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2>, #xegpu.layout<lane_layout = [4, 1], lane_data = [1, 1]>>, vector<4xi1> -> vector<2x1xf32>
+  //CHECK: %[[R0:.*]] = xegpu.create_tdesc %[[arg0]], %[[cst]] : ui64, vector<4xindex> -> !xegpu.tensor_desc<4x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2 : i64>>
+  %2 = xegpu.create_tdesc %src, %0 : ui64, vector<4xindex> -> !xegpu.tensor_desc<4x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2>>
+  //CHECK: %[[R1:.*]] = xegpu.load %[[R0]], %[[cst1]] <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}> : !xegpu.tensor_desc<4x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2 : i64>>, vector<4xi1> -> vector<2xf32>
+  %3 = xegpu.load %2, %1 <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}> : !xegpu.tensor_desc<4x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2>>, vector<4xi1> -> vector<2xf32>
   gpu.return
 }
 
@@ -451,10 +446,10 @@ gpu.func @test_load_simt_2(%src: ui64) {
   %0 = arith.constant dense<[0, 8, 16, 24]> : vector<4xindex>
   //CHECK: %[[cst1:.*]] = arith.constant dense<true> : vector<4xi1>
   %1 = arith.constant dense<1>: vector<4xi1>
-  //CHECK: %[[R0:.*]] = xegpu.create_tdesc %[[arg0]], %[[cst]] : ui64, vector<4xindex> -> !xegpu.tensor_desc<4xf32, #xegpu.scatter_tdesc_attr<>, #xegpu.layout<lane_layout = [4], lane_data = [1]>>
-  %2 = xegpu.create_tdesc %src, %0 : ui64, vector<4xindex> -> !xegpu.tensor_desc<4xf32, #xegpu.scatter_tdesc_attr<>, #xegpu.layout<lane_layout = [4], lane_data = [1]>>
-  //CHECK: %[[R1:.*]] = xegpu.load %[[R0]], %[[cst1]] <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}> : !xegpu.tensor_desc<4xf32, #xegpu.scatter_tdesc_attr<>, #xegpu.layout<lane_layout = [4], lane_data = [1]>>, vector<4xi1> -> vector<1xf32>
-  %3 = xegpu.load %2, %1 <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}> : !xegpu.tensor_desc<4xf32, #xegpu.scatter_tdesc_attr<>, #xegpu.layout<lane_layout = [4], lane_data = [1]>>, vector<4xi1> -> vector<1xf32>
+  //CHECK: %[[R0:.*]] = xegpu.create_tdesc %[[arg0]], %[[cst]] : ui64, vector<4xindex> -> !xegpu.tensor_desc<4xf32, #xegpu.scatter_tdesc_attr<>>
+  %2 = xegpu.create_tdesc %src, %0 : ui64, vector<4xindex> -> !xegpu.tensor_desc<4xf32, #xegpu.scatter_tdesc_attr<>>
+  //CHECK: %[[R1:.*]] = xegpu.load %[[R0]], %[[cst1]] <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}> : !xegpu.tensor_desc<4xf32, #xegpu.scatter_tdesc_attr<>>, vector<4xi1> -> vector<1xf32>
+  %3 = xegpu.load %2, %1 <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}> : !xegpu.tensor_desc<4xf32, #xegpu.scatter_tdesc_attr<>>, vector<4xi1> -> vector<1xf32>
   gpu.return
 }
 
@@ -477,10 +472,10 @@ gpu.func @test_load_simt_3(%src: ui64) {
   %0 = arith.constant dense<[0, 8, 16, 24]> : vector<4xindex>
   //CHECK: %[[cst1:.*]] = arith.constant dense<true> : vector<4xi1>
   %1 = arith.constant dense<1>: vector<4xi1>
-  //CHECK: %[[R0:.*]] = xegpu.create_tdesc %[[arg0]], %[[cst]] : ui64, vector<4xindex> -> !xegpu.tensor_desc<4x8xf16, #xegpu.scatter_tdesc_attr<chunk_size = 8 : i64>, #xegpu.layout<lane_layout = [4, 1], lane_data = [1, 2]>>
-  %2 = xegpu.create_tdesc %src, %0 : ui64, vector<4xindex> -> !xegpu.tensor_desc<4x8xf16, #xegpu.scatter_tdesc_attr<chunk_size = 8>, #xegpu.layout<lane_layout = [4, 1], lane_data = [1, 2]>>
-  //CHECK: %[[R1:.*]] = xegpu.load %[[R0]], %[[cst1]] <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>, transpose}> : !xegpu.tensor_desc<4x8xf16, #xegpu.scatter_tdesc_attr<chunk_size = 8 : i64>, #xegpu.layout<lane_layout = [4, 1], lane_data = [1, 2]>>, vector<4xi1> -> vector<4x2xf16>
-  %3 = xegpu.load %2, %1 <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>, transpose}> : !xegpu.tensor_desc<4x8xf16, #xegpu.scatter_tdesc_attr<chunk_size = 8>, #xegpu.layout<lane_layout = [4, 1], lane_data = [1, 2]>>, vector<4xi1> -> vector<4x2xf16>
+  //CHECK: %[[R0:.*]] = xegpu.create_tdesc %[[arg0]], %[[cst]] : ui64, vector<4xindex> -> !xegpu.tensor_desc<4x8xf16, #xegpu.scatter_tdesc_attr<chunk_size = 8 : i64>>
+  %2 = xegpu.create_tdesc %src, %0 : ui64, vector<4xindex> -> !xegpu.tensor_desc<4x8xf16, #xegpu.scatter_tdesc_attr<chunk_size = 8>>
+  //CHECK: %[[R1:.*]] = xegpu.load %[[R0]], %[[cst1]] <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}> : !xegpu.tensor_desc<4x8xf16, #xegpu.scatter_tdesc_attr<chunk_size = 8 : i64>>, vector<4xi1> -> vector<8xf16>
+  %3 = xegpu.load %2, %1 <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}> : !xegpu.tensor_desc<4x8xf16, #xegpu.scatter_tdesc_attr<chunk_size = 8>>, vector<4xi1> -> vector<8xf16>
   gpu.return
 }
 
@@ -507,12 +502,12 @@ gpu.func @test_store_simt(%src: ui64) {
   %0 = arith.constant dense<[0, 8, 16, 24]> : vector<4xindex>
   //CHECK: %[[cst1:.*]] = arith.constant dense<true> : vector<4xi1>
   %1 = arith.constant dense<1>: vector<4xi1>
-  //CHECK: %[[cst2:.*]] = arith.constant dense<2.900000e+00> : vector<2x1xf32>
-  %2 = arith.constant dense<2.9>: vector<2x1xf32>
-  //CHECK: %[[R0:.*]] = xegpu.create_tdesc %[[arg0]], %[[cst]] : ui64, vector<4xindex> -> !xegpu.tensor_desc<4x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2 : i64>, #xegpu.layout<lane_layout = [4, 1], lane_data = [1, 1]>>
-  %3 = xegpu.create_tdesc %src, %0 : ui64, vector<4xindex> -> !xegpu.tensor_desc<4x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2>, #xegpu.layout<lane_layout = [4, 1], lane_data = [1, 1]>>
-  //CHECK: xegpu.store %[[cst2]], %[[R0]], %[[cst1]] <{l1_hint = #xegpu.cache_hint<write_back>, l2_hint = #xegpu.cache_hint<uncached>, transpose}> : vector<2x1xf32>, !xegpu.tensor_desc<4x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2 : i64>, #xegpu.layout<lane_layout = [4, 1], lane_data = [1, 1]>>, vector<4xi1>
-  xegpu.store %2, %3, %1 <{l1_hint = #xegpu.cache_hint<write_back>, l2_hint = #xegpu.cache_hint<uncached>, transpose}> : vector<2x1xf32>, !xegpu.tensor_desc<4x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2>, #xegpu.layout<lane_layout = [4, 1], lane_data = [1, 1]>>, vector<4xi1>
+  //CHECK: %[[cst2:.*]] = arith.constant dense<2.900000e+00> : vector<2xf32>
+  %2 = arith.constant dense<2.9>: vector<2xf32>
+  //CHECK: %[[R0:.*]] = xegpu.create_tdesc %[[arg0]], %[[cst]] : ui64, vector<4xindex> -> !xegpu.tensor_desc<4x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2 : i64>>
+  %3 = xegpu.create_tdesc %src, %0 : ui64, vector<4xindex> -> !xegpu.tensor_desc<4x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2>>
+  //CHECK: xegpu.store %[[cst2]], %[[R0]], %[[cst1]] <{l1_hint = #xegpu.cache_hint<write_back>, l2_hint = #xegpu.cache_hint<uncached>}> : vector<2xf32>, !xegpu.tensor_desc<4x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2 : i64>>, vector<4xi1>
+  xegpu.store %2, %3, %1 <{l1_hint = #xegpu.cache_hint<write_back>, l2_hint = #xegpu.cache_hint<uncached>}> : vector<2xf32>, !xegpu.tensor_desc<4x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2>>, vector<4xi1>
   gpu.return
 }
 
@@ -539,12 +534,12 @@ gpu.func @test_store_simt_2(%src: ui64) {
   %0 = arith.constant dense<[0, 8, 16, 24]> : vector<4xindex>
   //CHECK: %[[cst1:.*]] = arith.constant dense<true> : vector<4xi1>
   %1 = arith.constant dense<1>: vector<4xi1>
-  //CHECK: %[[cst2:.*]] = arith.constant {{.*}} : vector<1x2xf16>
-  %2 = arith.constant dense<2.9>: vector<1x2xf16>
-  //CHECK: %[[R0:.*]] = xegpu.create_tdesc %[[arg0]], %[[cst]] : ui64, vector<4xindex> -> !xegpu.tensor_desc<4x2xf16, #xegpu.scatter_tdesc_attr<chunk_size = 2 : i64>, #xegpu.layout<lane_layout = [4, 1], lane_data = [1, 2]>>
-  %3 = xegpu.create_tdesc %src, %0 : ui64, vector<4xindex> -> !xegpu.tensor_desc<4x2xf16, #xegpu.scatter_tdesc_attr<chunk_size = 2>, #xegpu.layout<lane_layout = [4, 1], lane_data = [1, 2]>>
-  //CHECK: xegpu.store %[[cst2]], %[[R0]], %[[cst1]] <{l1_hint = #xegpu.cache_hint<write_back>, l2_hint = #xegpu.cache_hint<uncached>, transpose}> : vector<1x2xf16>, !xegpu.tensor_desc<4x2xf16, #xegpu.scatter_tdesc_attr<chunk_size = 2 : i64>, #xegpu.layout<lane_layout = [4, 1], lane_data = [1, 2]>>, vector<4xi1>
-  xegpu.store %2, %3, %1 <{l1_hint = #xegpu.cache_hint<write_back>, l2_hint = #xegpu.cache_hint<uncached>, transpose}> : vector<1x2xf16>, !xegpu.tensor_desc<4x2xf16, #xegpu.scatter_tdesc_attr<chunk_size = 2>, #xegpu.layout<lane_layout = [4, 1], lane_data = [1, 2]>>, vector<4xi1>
+  //CHECK: %[[cst2:.*]] = arith.constant {{.*}} : vector<2xf16>
+  %2 = arith.constant dense<2.9>: vector<2xf16>
+  //CHECK: %[[R0:.*]] = xegpu.create_tdesc %[[arg0]], %[[cst]] : ui64, vector<4xindex> -> !xegpu.tensor_desc<4x2xf16, #xegpu.scatter_tdesc_attr<chunk_size = 2 : i64>>
+  %3 = xegpu.create_tdesc %src, %0 : ui64, vector<4xindex> -> !xegpu.tensor_desc<4x2xf16, #xegpu.scatter_tdesc_attr<chunk_size = 2>>
+  //CHECK: xegpu.store %[[cst2]], %[[R0]], %[[cst1]] <{l1_hint = #xegpu.cache_hint<write_back>, l2_hint = #xegpu.cache_hint<uncached>}> : vector<2xf16>, !xegpu.tensor_desc<4x2xf16, #xegpu.scatter_tdesc_attr<chunk_size = 2 : i64>>, vector<4xi1>
+  xegpu.store %2, %3, %1 <{l1_hint = #xegpu.cache_hint<write_back>, l2_hint = #xegpu.cache_hint<uncached>}> : vector<2xf16>, !xegpu.tensor_desc<4x2xf16, #xegpu.scatter_tdesc_attr<chunk_size = 2>>, vector<4xi1>
   gpu.return
 }
 
@@ -572,10 +567,10 @@ gpu.func @test_store_simt_3(%src: ui64) {
   %1 = arith.constant dense<1>: vector<4xi1>
   //CHECK: %[[cst2:.*]] = arith.constant dense<2.900000e+00> : vector<1xf32>
   %2 = arith.constant dense<2.9>: vector<1xf32>
-  //CHECK: %[[R0:.*]] = xegpu.create_tdesc %[[arg0]], %[[cst]] : ui64, vector<4xindex> -> !xegpu.tensor_desc<4xf32, #xegpu.scatter_tdesc_attr<>, #xegpu.layout<lane_layout = [4], lane_data = [1]>>
-  %3 = xegpu.create_tdesc %src, %0 : ui64, vector<4xindex> -> !xegpu.tensor_desc<4xf32, #xegpu.scatter_tdesc_attr<>, #xegpu.layout<lane_layout = [4], lane_data = [1]>>
-  //CHECK: xegpu.store %[[cst2]], %[[R0]], %[[cst1]] <{l1_hint = #xegpu.cache_hint<write_back>, l2_hint = #xegpu.cache_hint<uncached>}> : vector<1xf32>, !xegpu.tensor_desc<4xf32, #xegpu.scatter_tdesc_attr<>, #xegpu.layout<lane_layout = [4], lane_data = [1]>>, vector<4xi1>
-  xegpu.store %2, %3, %1 <{l1_hint = #xegpu.cache_hint<write_back>, l2_hint = #xegpu.cache_hint<uncached>}> : vector<1xf32>, !xegpu.tensor_desc<4xf32, #xegpu.scatter_tdesc_attr<>, #xegpu.layout<lane_layout = [4], lane_data = [1]>>, vector<4xi1>
+  //CHECK: %[[R0:.*]] = xegpu.create_tdesc %[[arg0]], %[[cst]] : ui64, vector<4xindex> -> !xegpu.tensor_desc<4xf32, #xegpu.scatter_tdesc_attr<>>
+  %3 = xegpu.create_tdesc %src, %0 : ui64, vector<4xindex> -> !xegpu.tensor_desc<4xf32, #xegpu.scatter_tdesc_attr<>>
+  //CHECK: xegpu.store %[[cst2]], %[[R0]], %[[cst1]] <{l1_hint = #xegpu.cache_hint<write_back>, l2_hint = #xegpu.cache_hint<uncached>}> : vector<1xf32>, !xegpu.tensor_desc<4xf32, #xegpu.scatter_tdesc_attr<>>, vector<4xi1>
+  xegpu.store %2, %3, %1 <{l1_hint = #xegpu.cache_hint<write_back>, l2_hint = #xegpu.cache_hint<uncached>}> : vector<1xf32>, !xegpu.tensor_desc<4xf32, #xegpu.scatter_tdesc_attr<>>, vector<4xi1>
   gpu.return
 }
 
@@ -635,15 +630,10 @@ gpu.func @test_dpas_vc(%a : vector<8x16xf16>, %b: vector<16x16xf16>) {
   gpu.return
 }
 
-// CHECK: gpu.func @test_dpas_simt(%[[arg0:.*]]: vector<8x1xf16>, %[[arg1:.*]]: vector<8x2xf16>)
-gpu.func @test_dpas_simt(%a : vector<8x1xf16>, %b: vector<8x2xf16>) {
-  // CHECK: xegpu.dpas %[[arg0]], %[[arg1]] {a_layout = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>,
-  // CHECK: b_layout = #xegpu.layout<lane_layout = [1, 16], lane_data = [2, 1]>,
-  // CHECK: c_layout = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>} : vector<8x1xf16>, vector<8x2xf16> -> vector<8x1xf32>
-  %1 = xegpu.dpas %a, %b {a_layout = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>,
-                          b_layout = #xegpu.layout<lane_layout = [1, 16], lane_data = [2, 1]>,
-                          c_layout = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>}
-                          : vector<8x1xf16>, vector<8x2xf16> -> vector<8x1xf32>
+// CHECK: gpu.func @test_dpas_simt(%[[arg0:.*]]: vector<8xf16>, %[[arg1:.*]]: vector<16xf16>)
+gpu.func @test_dpas_simt(%a : vector<8xf16>, %b: vector<16xf16>) {
+  // CHECK: xegpu.dpas %[[arg0]], %[[arg1]] : vector<8xf16>, vector<16xf16> -> vector<8xf32>
+  %1 = xegpu.dpas %a, %b : vector<8xf16>, vector<16xf16> -> vector<8xf32>
   gpu.return
 }
 

>From 2159119977dfb62c11d808777529dd34ed0abd43 Mon Sep 17 00:00:00 2001
From: Chao Chen <chao.chen at intel.com>
Date: Thu, 10 Apr 2025 20:25:00 +0000
Subject: [PATCH 2/6] refine verifier for load_nd and store_nd

---
 .../include/mlir/Dialect/XeGPU/IR/XeGPUOps.td |  4 +-
 mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp        | 53 +++++++++----------
 mlir/test/Dialect/XeGPU/invalid.mlir          | 19 +++++--
 3 files changed, 43 insertions(+), 33 deletions(-)

diff --git a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td
index 9af6eaf69aec3..5fa18754305ca 100644
--- a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td
+++ b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td
@@ -840,7 +840,9 @@ def XeGPU_DpasOp : XeGPU_Op<"dpas", [Pure, AllElementTypesMatch<["lhs", "rhs"]>]
     can be represented as `B: vector<8x16x2xf16>`.
 
     In SIMT code, each work-item from a subgroup holds a data fragment for A, B, C and the result,
-    which are represented as 1D vectors.
+    which are represented as 1D vectors. Please refer to [OpenCL Intel extensions]
+    (https://registry.khronos.org/OpenCL/extensions/intel/cl_intel_subgroup_matrix_multiply_accumulate.html)
+    for more details about the fragment distribution.
 
     Note: on PVC, the hardware can perform load with VNNI transformation when data
           element type is 16-bit or lower precision, taking 2 or 4 elements from
diff --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
index fef39508c3bfe..1dafc9936107e 100644
--- a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
+++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
@@ -270,33 +270,31 @@ LogicalResult LoadNdOp::verify() {
   if (!isReadHintOrNone(getL3HintAttr()))
     return emitOpError("invalid l3_hint: ") << getL3HintAttr();
 
-  // Handling a 1D vector as the result can be complex. It may represent the
-  // outcome of a 1D block load in SIMD mode or a fragment of a block load
-  // result in SIMT mode. In the latter case, the tensor descriptor must be
-  // evenly distributed, with each lane holding an equally sized fragment of
-  // the result. Only subgroup size 8 or 16 is supported.
-  if (valueTy.getRank() == 1 &&
-      valueTy.getNumElements() < tdescTy.getNumElements()) {
+  int tdescElems = tdescTy.getNumElements() * tdescTy.getArrayLength();
+  int valueElems = valueTy.getNumElements();
+
+  // If the result vector is 1D and has fewer elements than the tensor
+  // descriptor, it is supposed to be a SIMT op. The layout attribute in
+  // tensor_desc is not needed.
+  if (valueElems < tdescElems && valueTy.getRank() == 1) {
     // SIMT mode doesn't need LayoutAttr.
     if (tdescTy.getLayoutAttr())
       return emitOpError()
              << "TensorDesc doesn't need LayoutAttr for SIMT code";
 
-    int tdescElems = tdescTy.getNumElements() * tdescTy.getArrayLength();
-    int valueElems = valueTy.getNumElements();
-
-    int lanes = tdescElems % valueElems == 0 ? tdescElems / valueElems : -1;
-    if (lanes != 16 && lanes != 8) {
+    // For SIMT code, the load is evenly distributed across all lanes in a
+    // subgroup. Since subgroup size is arch dependent, we only check even
+    // distribution here.
+    if (tdescElems % valueElems)
       return emitOpError()
              << "Result shape " << makeString(getShapeOf(valueTy))
              << " is not a valid distribution for tensor descriptor "
              << tdescTy;
-    }
+
     return success();
   }
 
   // Check SIMD mode.
-  auto array_len = tdescTy.getArrayLength();
   // adjusted tensor descriptor shape tracks the expected shape of the result.
   auto tdescShape = getShapeOf(tdescTy);
   auto valueShape = getShapeOf(valueTy);
@@ -328,6 +326,7 @@ LogicalResult LoadNdOp::verify() {
     }
   }
 
+  auto array_len = tdescTy.getArrayLength();
   if (array_len > 1) {
     tdescShape.insert(tdescShape.begin(), array_len);
   }
@@ -366,25 +365,23 @@ LogicalResult StoreNdOp::verify() {
   if (!isWriteHintOrNone(getL3HintAttr()))
     return emitOpError("invalid l3_hint: ") << getL3HintAttr();
 
-  auto tdescShape = getShapeOf(dstTy);
-  auto valueShape = getShapeOf(valTy);
+  auto array_len = dstTy.getArrayLength();
+  if (array_len > 1)
+    return emitOpError("array length is not supported by store_nd.\n");
+
+  auto tdescElems = dstTy.getNumElements();
+  auto valueElems = valTy.getNumElements();
 
-  // Similar to LoadNdOp, handling a 1D vector as the value can be complex. It
-  // may represent the input of a 1D block store in SIMD mode or a fragment of
-  // a block store input in SIMT mode. In the latter case, the tensor descriptor
-  // must be evenly distributed, with each lane holding an equally sized
-  // fragment of the input. Only subgroup size 8 or 16 is supported.
-  if (valTy.getRank() == 1 && valTy.getNumElements() < dstTy.getNumElements()) {
+  // Similar to LoadNdOp, if the value vector is 1D and has fewer elements than
+  // the tensor descriptor, it is supposed to be a SIMT op. The layout attribute
+  // in tensor_desc is not needed.
+  if (valTy.getRank() == 1 && valueElems < tdescElems) {
     // SIMT mode doesn't need LayoutAttr.
     if (dstTy.getLayoutAttr())
       return emitOpError()
              << "TensorDesc doesn't need LayoutAttr for SIMT code";
 
-    int tdescElems = dstTy.getNumElements() * dstTy.getArrayLength();
-    int valueElems = valueShape[0];
-
-    int lanes = tdescElems % valueElems == 0 ? tdescElems / valueElems : -1;
-    if (lanes != 16 && lanes != 8) {
+    if (tdescElems % valueElems) {
       return emitOpError()
              << "Value shape " << makeString(getShapeOf(valTy))
              << " is not a valid distribution for tensor descriptor " << dstTy;
@@ -393,6 +390,8 @@ LogicalResult StoreNdOp::verify() {
   }
 
   // SIMD code should have the same shape as the tensor descriptor.
+  auto tdescShape = getShapeOf(dstTy);
+  auto valueShape = getShapeOf(valTy);
   if (tdescShape != valueShape) {
     return emitOpError() << "Value shape " << makeString(valueShape)
                          << " is not consistent with tensor descriptor "
diff --git a/mlir/test/Dialect/XeGPU/invalid.mlir b/mlir/test/Dialect/XeGPU/invalid.mlir
index c0739d735dfec..a02427b6e317b 100644
--- a/mlir/test/Dialect/XeGPU/invalid.mlir
+++ b/mlir/test/Dialect/XeGPU/invalid.mlir
@@ -80,9 +80,9 @@ func.func @test_load_nd_vc_3(%src: memref<8x16xf16>) {
 // -----
 func.func @test_load_nd_layout(%src: memref<24x32xf32>) {
   %1 = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf32> -> !xegpu.tensor_desc<16xf32>
-  // expected-error at +1 {{Result shape [8] is not a valid distribution for tensor descriptor}}
+  // expected-error at +1 {{Result shape [3] is not a valid distribution for tensor descriptor}}
   %2 = xegpu.load_nd %1 <{l1_hint = #xegpu.cache_hint<cached>,
-      l2_hint = #xegpu.cache_hint<uncached>}> : !xegpu.tensor_desc<16xf32> -> vector<8xf32>
+      l2_hint = #xegpu.cache_hint<uncached>}> : !xegpu.tensor_desc<16xf32> -> vector<3xf32>
   return
 }
 
@@ -119,10 +119,19 @@ func.func @test_store_nd_vc_2(%dst: memref<16xf16>) {
 }
 
 // -----
-func.func @test_store_nd_simt(%dst: memref<24x32xf32>, %data: vector<4xf32>) {
+func.func @test_store_nd_vc_3(%dst: memref<24x32xf16>) {
+  %1 = arith.constant dense<1.0>: vector<2x24x32xf16>
+  %2 = xegpu.create_nd_tdesc %dst[0, 0] : memref<24x32xf16> -> !xegpu.tensor_desc<24x32xf16, #xegpu.block_tdesc_attr<array_length = 2>>
+  // expected-error at +1 {{array length is not supported by store_nd}}
+  xegpu.store_nd %1, %2: vector<2x24x32xf16>, !xegpu.tensor_desc<24x32xf16, #xegpu.block_tdesc_attr<array_length = 2>>
+  return
+}
+
+// -----
+func.func @test_store_nd_simt(%dst: memref<24x32xf32>, %data: vector<3xf32>) {
   %1 = xegpu.create_nd_tdesc %dst[0, 0] : memref<24x32xf32> -> !xegpu.tensor_desc<16xf32>
-  // expected-error at +1 {{Value shape [4] is not a valid distribution for tensor descriptor}}
-  xegpu.store_nd %data, %1 : vector<4xf32>, !xegpu.tensor_desc<16xf32>
+  // expected-error at +1 {{Value shape [3] is not a valid distribution for tensor descriptor}}
+  xegpu.store_nd %data, %1 : vector<3xf32>, !xegpu.tensor_desc<16xf32>
   return
 }
 

>From 775d039bb7a5ba9fd91939411e2d69312879f1e0 Mon Sep 17 00:00:00 2001
From: Chao Chen <chao.chen at intel.com>
Date: Tue, 15 Apr 2025 18:46:55 +0000
Subject: [PATCH 3/6] refine verifier for gather/scatter

---
 mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp | 62 +++++++++-----------------
 mlir/test/Dialect/XeGPU/invalid.mlir   |  4 +-
 2 files changed, 22 insertions(+), 44 deletions(-)

diff --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
index 1dafc9936107e..f5205c5e7e5bc 100644
--- a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
+++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
@@ -547,30 +547,18 @@ LogicalResult LoadGatherOp::verify() {
     return emitOpError("dim-0 of the Mask and TensorDesc should be the same.");
 
   auto chunkSize = tdescTy.getChunkSize();
-  // for SIMT code, the value should be 1D vector with size of chunkSize.
-  if (valueTy.getRank() == 1 && valueTy.getNumElements() != tdescShape[0]) {
-    if (valueTy.getNumElements() != chunkSize) {
+
+  // a valid shape for SIMT case
+  if (valueTy.getRank() == 1 && valueTy.getNumElements() == chunkSize) {
+    if (tdescTy.getLayoutAttr())
       return emitOpError()
-             << "Result shape " << makeString(valueShape)
-             << " is not a valid distribution for tensor descriptor "
-             << tdescTy;
-    } else { // valid SIMT code doesn't need LayoutAttr and TransposeAttr.
-      if (tdescTy.getLayoutAttr())
-        return emitOpError()
-               << "TensorDesc doesn't need LayoutAttr for SIMT code";
-      if (getTransposeAttr())
-        return emitOpError() << "doesn't need TransposeAttr for SIMT code";
-    }
-    return success();
-  } else if (valueTy.getRank() == 1 && tdescShape[0] == chunkSize) {
-    // for 1D vector and valueTy.getNumElements() == tdescShape[0] case,
-    // it is a valid SIMT code if chunkSize happens to be the same as
-    // subgroup size, e.g., tensor_desc<16x16xf16, chunkSize = 16>
+             << "TensorDesc doesn't need LayoutAttr for SIMT code";
+    if (getTransposeAttr())
+      return emitOpError() << "doesn't need TransposeAttr for SIMT code";
     return success();
   }
 
-  // For SIMD code verification.
-  if (tdescTy.getRank() == 2) {
+  if (tdescTy.getRank() == 2 && valueTy.getRank() == 2) {
     if (!getTransposeAttr())
       return emitOpError("load of rank-2 tensor has to be transposed.");
     transpose({1, 0}, tdescShape);
@@ -578,7 +566,8 @@ LogicalResult LoadGatherOp::verify() {
 
   if (tdescShape != valueShape)
     return emitOpError() << "Result shape " << makeString(valueShape)
-                         << " is not consistent with tensor descriptor "
+                         << " is neither a valid distribution for SIMT nor "
+                            "consistent with the tensor descriptor for SIMD "
                          << tdescTy;
   return success();
 }
@@ -613,30 +602,18 @@ LogicalResult StoreScatterOp::verify() {
     return emitOpError("dim-0 of the Mask and TensorDesc should be the same.");
 
   auto chunkSize = tdescTy.getChunkSize();
-  // for SIMT code, the value should be 1D vector with size of chunkSize.
-  if (valueTy.getRank() == 1 && valueTy.getNumElements() != tdescShape[0]) {
-    if (valueTy.getNumElements() != chunkSize) {
+
+  // a valid shape for SIMT case
+  if (valueTy.getRank() == 1 && valueTy.getNumElements() == chunkSize) {
+    if (tdescTy.getLayoutAttr())
       return emitOpError()
-             << "Value shape " << makeString(valueShape)
-             << " is not a valid distribution for tensor descriptor "
-             << tdescTy;
-    } else { // valid SIMT code doesn't need LayoutAttr and TransposeAttr.
-      if (tdescTy.getLayoutAttr())
-        return emitOpError()
-               << "TensorDesc doesn't need LayoutAttr for SIMT code";
-      if (getTransposeAttr())
-        return emitOpError() << "doesn't need TransposeAttr for SIMT code";
-    }
-    return success();
-  } else if (valueTy.getRank() == 1 && tdescShape[0] == chunkSize) {
-    // for 1D vector and valueTy.getNumElements() == tdescShape[0] case,
-    // it is a valid SIMT code if chunkSize happens to be the same as
-    // subgroup size, e.g., tensor_desc<16x16xf16, chunkSize = 16>
+             << "TensorDesc doesn't need LayoutAttr for SIMT code";
+    if (getTransposeAttr())
+      return emitOpError() << "doesn't need TransposeAttr for SIMT code";
     return success();
   }
 
-  // for SIMD code verification.
-  if (tdescTy.getRank() == 2) {
+  if (tdescTy.getRank() == 2 && valueTy.getRank() == 2) {
     if (!getTransposeAttr())
       return emitOpError("Store of a rank-2 tensor has to be transposed.");
     transpose({1, 0}, tdescShape);
@@ -644,7 +621,8 @@ LogicalResult StoreScatterOp::verify() {
 
   if (tdescShape != valueShape)
     return emitOpError() << "Value shape " << makeString(valueShape)
-                         << " is not consistent with tensor descriptor "
+                         << " is neither a valid distribution for SIMT nor "
+                            "consistent with the tensor descriptor for SIMD "
                          << tdescTy;
 
   return success();
diff --git a/mlir/test/Dialect/XeGPU/invalid.mlir b/mlir/test/Dialect/XeGPU/invalid.mlir
index a02427b6e317b..2a7436807f5f4 100644
--- a/mlir/test/Dialect/XeGPU/invalid.mlir
+++ b/mlir/test/Dialect/XeGPU/invalid.mlir
@@ -255,7 +255,7 @@ func.func @test_load_gather_simt_1(%src: ui64) {
   %0 = arith.constant dense<1>: vector<4xi1>
   %cst = arith.constant dense<[0, 8, 16, 24]> : vector<4xindex>
   %1 = xegpu.create_tdesc %src, %cst : ui64, vector<4xindex> -> !xegpu.tensor_desc<4x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2>>
-  // expected-error at +1 {{Result shape [6] is not a valid distribution for tensor descriptor}}
+  // expected-error at +1 {{Result shape [6] is neither a valid distribution for SIMT nor consistent with the tensor descriptor for SIMD}}
   %2 = xegpu.load %1, %0 <{l1_hint = #xegpu.cache_hint<cached>}> : !xegpu.tensor_desc<4x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2>>, vector<4xi1> -> vector<6xf32>
   return
 }
@@ -266,7 +266,7 @@ func.func @test_store_scatter_simt_1(%src: ui64) {
   %cst = arith.constant dense<[0, 8, 16, 24]> : vector<4xindex>
   %val = arith.constant dense<2.9>: vector<6xf32>
   %1 = xegpu.create_tdesc %src, %cst : ui64, vector<4xindex> -> !xegpu.tensor_desc<4x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2>>
-  // expected-error at +1 {{Value shape [6] is not a valid distribution for tensor descriptor}}
+  // expected-error at +1 {{Value shape [6] is neither a valid distribution for SIMT nor consistent with the tensor descriptor for SIMD}}
   xegpu.store %val, %1, %0 <{l1_hint = #xegpu.cache_hint<cached>}> : vector<6xf32>, !xegpu.tensor_desc<4x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2>>, vector<4xi1>
   return
 }

>From 5520ce18138b5153d7ecb874fe10be78127d719e Mon Sep 17 00:00:00 2001
From: Chao Chen <chao.chen at intel.com>
Date: Tue, 15 Apr 2025 18:59:39 +0000
Subject: [PATCH 4/6] update comments

---
 mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp | 13 +++++++------
 mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp     |  1 -
 2 files changed, 7 insertions(+), 7 deletions(-)

diff --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
index 269e445c3790c..b865b80f0075e 100644
--- a/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
+++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
@@ -320,21 +320,22 @@ LogicalResult TensorDescType::verify(
 // ---------------------------------------------------------------------
 // Case 1: Regular loads/stores.
 // ---------------------------------------------------------------------
-// Distributed vector shape must be:
-//        [chunk_size / lane_data_size, lane_data_size]
-// If the tensor descriptor shape is 1D, first dimension is ignored (set to 1).
-//        [lane_data_size]
+// The following conditions must be met:
+//        * tensor_desc[0] == lane_layout[0]
+// Distributed vector is a 1D vector with shape:
+//        [chunk_size]
 // ---------------------------------------------------------------------
 // Case 2: Block loads/stores
 // ---------------------------------------------------------------------
 // Additional definitions:
 //        tensor_size = tensor_desc[0] * .. * tensor_desc[r-1] * array_length
 //        n_distribution_units = tensor_size / distribution_unit_size
+//        fragment_size = n_distribution_units * lane_data_size
 // Given above definitions, the following conditions must be met:
 //        * tensor_desc[0] % (lane_layout[0] × lane_data[0]) == 0
 //        * tensor_desc[1] % (lane_layout[1] × lane_data[1]) == 0
-// Distributed vector shape must be:
-//        [n_distribution_units, lane_data_size]
+// Distributed vector is a 1D vector with shape:
+//        [fragment_size]
 FailureOr<VectorType> TensorDescType::getDistributedVectorType() {
   auto layout = llvm::dyn_cast_if_present<LayoutAttr>(getLayout());
   // It only works for subgroup level layout, which only has lane_layout
diff --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
index f5205c5e7e5bc..4305c0431cc7e 100644
--- a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
+++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
@@ -295,7 +295,6 @@ LogicalResult LoadNdOp::verify() {
   }
 
   // Check SIMD mode.
-  // adjusted tensor descriptor shape tracks the expected shape of the result.
   auto tdescShape = getShapeOf(tdescTy);
   auto valueShape = getShapeOf(valueTy);
 

>From 7072bc1bf5a36613adf5f0cdb201c4dbeb1b81f5 Mon Sep 17 00:00:00 2001
From: Chao Chen <chao.chen at intel.com>
Date: Thu, 17 Apr 2025 15:41:57 +0000
Subject: [PATCH 5/6] refactor verifiers for load_gather, store_scatter and dpas

---
 mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp | 144 ++++++++++---------------
 mlir/test/Dialect/XeGPU/invalid.mlir   |  11 +-
 2 files changed, 66 insertions(+), 89 deletions(-)

diff --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
index 4305c0431cc7e..b02490909e067 100644
--- a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
+++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
@@ -101,6 +101,48 @@ static bool isEvenDistributed(llvm::ArrayRef<int64_t> shape,
   return true;
 }
 
+static LogicalResult isValidGatherScatterParams(Type maskTy, VectorType valueTy, TensorDescType tdescTy, UnitAttr transposeAttr, function_ref<InFlightDiagnostic()> emitError) {
+
+  if (!tdescTy.isScattered())
+    return emitError() << "Expects a scattered TensorDesc.";
+
+  if (!valueTy)
+    return emitError() << "Expecting a vector type result.";
+
+  auto maskShape = getShapeOf(maskTy);
+  auto valueShape = getShapeOf(valueTy);
+  auto tdescShape = getShapeOf(tdescTy);
+  auto chunkSize = tdescTy.getChunkSize();
+
+  if (valueTy.getElementType() != tdescTy.getElementType())
+    return emitError() << "Value should have the same element type as TensorDesc.";
+
+  if (tdescShape[0] != maskShape[0])
+    return emitError() << "dim-0 of the Mask and TensorDesc should be the same.";
+
+  // a valid shape for SIMT case
+  if (valueTy.getRank() == 1 && valueTy.getNumElements() == chunkSize) {
+    if (tdescTy.getLayoutAttr())
+      return emitError() << "TensorDesc doesn't need LayoutAttr for SIMT code";
+    if (transposeAttr)
+      return emitError() << "doesn't need TransposeAttr for SIMT code";
+    return success();
+  }
+
+  if (tdescTy.getRank() == 2 && valueTy.getRank() == 2) {
+    if (!transposeAttr)
+      return emitError() << "rank-2 tensor has to be transposed.";
+    transpose({1, 0}, tdescShape);
+  }
+
+  if (tdescShape != valueShape)
+    return emitError() << "Value shape " << makeString(valueShape)
+                       << " is neither a valid distribution for SIMT nor "
+                          "consistent with the tensor descriptor for SIMD "
+                       << tdescTy;
+  return success();
+}
+
 //===----------------------------------------------------------------------===//
 // XeGPU_CreateNdDescOp
 //===----------------------------------------------------------------------===//
@@ -517,12 +559,6 @@ LogicalResult LoadGatherOp::verify() {
   auto maskTy = getMaskType();
   auto valueTy = getValueType();
 
-  if (!valueTy)
-    return emitOpError("Expecting a vector type result.\n");
-
-  if (!tdescTy.isScattered())
-    return emitOpError("Expects a scattered TensorDesc.\n");
-
   if (!isReadHintOrNone(getL1HintAttr()))
     return emitOpError("invalid l1_hint: ") << getL1HintAttr();
 
@@ -532,43 +568,8 @@ LogicalResult LoadGatherOp::verify() {
   if (!isReadHintOrNone(getL3HintAttr()))
     return emitOpError("invalid l3_hint: ") << getL3HintAttr();
 
-  auto tdescElemTy = tdescTy.getElementType();
-  auto valueElemTy = getElementType();
-  if (tdescElemTy != valueElemTy)
-    return emitOpError(
-        "Value should have the same element type as TensorDesc.");
-
-  auto maskShape = getShapeOf(maskTy);
-  auto valueShape = getShapeOf(valueTy);
-  auto tdescShape = getShapeOf(tdescTy);
-
-  if (tdescShape[0] != maskShape[0])
-    return emitOpError("dim-0 of the Mask and TensorDesc should be the same.");
-
-  auto chunkSize = tdescTy.getChunkSize();
-
-  // a valid shape for SIMT case
-  if (valueTy.getRank() == 1 && valueTy.getNumElements() == chunkSize) {
-    if (tdescTy.getLayoutAttr())
-      return emitOpError()
-             << "TensorDesc doesn't need LayoutAttr for SIMT code";
-    if (getTransposeAttr())
-      return emitOpError() << "doesn't need TransposeAttr for SIMT code";
-    return success();
-  }
-
-  if (tdescTy.getRank() == 2 && valueTy.getRank() == 2) {
-    if (!getTransposeAttr())
-      return emitOpError("load of rank-2 tensor has to be transposed.");
-    transpose({1, 0}, tdescShape);
-  }
-
-  if (tdescShape != valueShape)
-    return emitOpError() << "Result shape " << makeString(valueShape)
-                         << " is neither a valid distribution for SIMT nor "
-                            "consistent with the tensor descriptor for SIMD "
-                         << tdescTy;
-  return success();
+  return isValidGatherScatterParams(maskTy, valueTy, tdescTy, getTransposeAttr(),
+                         [&]() { return emitOpError(); });
 }
 
 //===----------------------------------------------------------------------===//
@@ -576,8 +577,8 @@ LogicalResult LoadGatherOp::verify() {
 //===----------------------------------------------------------------------===//
 LogicalResult StoreScatterOp::verify() {
   auto tdescTy = getTensorDescType();
-  if (!tdescTy.isScattered())
-    return emitOpError("Expects a scattered TensorDesc.\n");
+  auto maskTy = getMaskType();
+  auto valueTy = getValueType();
 
   if (!isWriteHintOrNone(getL1HintAttr()))
     return emitOpError("invalid l1_hint: ") << getL1HintAttr();
@@ -588,43 +589,8 @@ LogicalResult StoreScatterOp::verify() {
   if (!isWriteHintOrNone(getL3HintAttr()))
     return emitOpError("invalid l3_hint: ") << getL3HintAttr();
 
-  auto maskTy = getMaskType();
-  auto valueTy = getValueType();
-
-  if (!valueTy)
-    return emitOpError("Expecting a vector type for the value.\n");
-
-  auto maskShape = getShapeOf(maskTy);
-  auto tdescShape = getShapeOf(tdescTy);
-  auto valueShape = getShapeOf(valueTy);
-  if (tdescShape[0] != maskShape[0])
-    return emitOpError("dim-0 of the Mask and TensorDesc should be the same.");
-
-  auto chunkSize = tdescTy.getChunkSize();
-
-  // a valid shape for SIMT case
-  if (valueTy.getRank() == 1 && valueTy.getNumElements() == chunkSize) {
-    if (tdescTy.getLayoutAttr())
-      return emitOpError()
-             << "TensorDesc doesn't need LayoutAttr for SIMT code";
-    if (getTransposeAttr())
-      return emitOpError() << "doesn't need TransposeAttr for SIMT code";
-    return success();
-  }
-
-  if (tdescTy.getRank() == 2 && valueTy.getRank() == 2) {
-    if (!getTransposeAttr())
-      return emitOpError("Store of a rank-2 tensor has to be transposed.");
-    transpose({1, 0}, tdescShape);
-  }
-
-  if (tdescShape != valueShape)
-    return emitOpError() << "Value shape " << makeString(valueShape)
-                         << " is neither a valid distribution for SIMT nor "
-                            "consistent with the tensor descriptor for SIMD "
-                         << tdescTy;
-
-  return success();
+  return isValidGatherScatterParams(maskTy, valueTy, tdescTy, getTransposeAttr(),
+                         [&]() { return emitOpError(); });
 }
 
 //===----------------------------------------------------------------------===//
@@ -660,14 +626,18 @@ LogicalResult DpasOp::verify() {
   auto rhsShape = getRhsType().getShape();
   auto resShape = getResultType().getShape();
 
-  if (getAcc()) {
-    if (getAcc().getType() != getResultType())
-      return emitOpError("Expecting the acc type to be the same as result.");
-  }
+  if (getAcc() && getAcc().getType() != getResultType())
+    return emitOpError("Expecting the acc type to be the same as result.");
 
-  // SIMT code: skip the check since lack of semantic info at this level.
+  // SIMT code: the size of the B operand has to be a multiple of 32 bits.
+  // The semantic check is skipped due to lack of architecture information.
   // Users need to ensure the correctness.
   if (lhsRank == 1 && rhsRank == 1 && resRank == 1) {
+    auto numElems = getRhsType().getNumElements();
+    auto elemTy = getRhsType().getElementType();
+    auto factor = 32 / elemTy.getIntOrFloatBitWidth();
+    if (numElems % factor != 0)
+      return emitOpError("Expecting B operand to be a multiple of 32 bits.");
     return success();
   } else { // SIMD code
     if (lhsRank != 2 || (rhsRank != 2 && rhsRank != 3) || resRank != 2)
diff --git a/mlir/test/Dialect/XeGPU/invalid.mlir b/mlir/test/Dialect/XeGPU/invalid.mlir
index 2a7436807f5f4..67ed89e11b4c9 100644
--- a/mlir/test/Dialect/XeGPU/invalid.mlir
+++ b/mlir/test/Dialect/XeGPU/invalid.mlir
@@ -255,7 +255,7 @@ func.func @test_load_gather_simt_1(%src: ui64) {
   %0 = arith.constant dense<1>: vector<4xi1>
   %cst = arith.constant dense<[0, 8, 16, 24]> : vector<4xindex>
   %1 = xegpu.create_tdesc %src, %cst : ui64, vector<4xindex> -> !xegpu.tensor_desc<4x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2>>
-  // expected-error at +1 {{Result shape [6] is neither a valid distribution for SIMT nor consistent with the tensor descriptor for SIMD}}
+  // expected-error at +1 {{Value shape [6] is neither a valid distribution for SIMT nor consistent with the tensor descriptor for SIMD}}
   %2 = xegpu.load %1, %0 <{l1_hint = #xegpu.cache_hint<cached>}> : !xegpu.tensor_desc<4x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2>>, vector<4xi1> -> vector<6xf32>
   return
 }
@@ -347,12 +347,19 @@ func.func @test_dpas_4(%a : vector<16x16xf16>, %b: vector<8x16x2xf16>) {
 }
 
 // -----
-func.func @test_dpas_4(%a : vector<8x16xf16>, %b: vector<8x8x2xf16>) {
+func.func @test_dpas_5(%a : vector<8x16xf16>, %b: vector<8x8x2xf16>) {
   // expected-error at +1 {{N-dimension mismatch}}
   %1 = xegpu.dpas %a, %b : vector<8x16xf16>, vector<8x8x2xf16> -> vector<8x16xf32>
   return
 }
 
+// -----
+func.func @test_dpas_simt_1(%a : vector<8xf16>, %b: vector<15xf16>) {
+  // expected-error at +1 {{Expecting B operand to be a multiple of 32 bits}}
+  %1 = xegpu.dpas %a, %b : vector<8xf16>, vector<15xf16> -> vector<8xf32>
+  return
+}
+
 // -----
 func.func @test_atomic_rmw(%src: ui64, %value : vector<16x4xf32>, %mask : vector<16xi1>) {
   %0 = arith.constant dense<[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15]> : vector<16xindex>

>From 605c99ec7291f8a7f3dec4e7951e196b4f6f27b5 Mon Sep 17 00:00:00 2001
From: Chao Chen <chao.chen at intel.com>
Date: Thu, 17 Apr 2025 15:44:58 +0000
Subject: [PATCH 6/6] fix format

---
 mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp | 21 ++++++++++++++-------
 1 file changed, 14 insertions(+), 7 deletions(-)

diff --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
index b02490909e067..1da2752f44b99 100644
--- a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
+++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
@@ -101,7 +101,10 @@ static bool isEvenDistributed(llvm::ArrayRef<int64_t> shape,
   return true;
 }
 
-static LogicalResult isValidGatherScatterParams(Type maskTy, VectorType valueTy, TensorDescType tdescTy, UnitAttr transposeAttr, function_ref<InFlightDiagnostic()> emitError) {
+static LogicalResult
+isValidGatherScatterParams(Type maskTy, VectorType valueTy,
+                           TensorDescType tdescTy, UnitAttr transposeAttr,
+                           function_ref<InFlightDiagnostic()> emitError) {
 
   if (!tdescTy.isScattered())
     return emitError() << "Expects a scattered TensorDesc.";
@@ -115,10 +118,12 @@ static LogicalResult isValidGatherScatterParams(Type maskTy, VectorType valueTy,
   auto chunkSize = tdescTy.getChunkSize();
 
   if (valueTy.getElementType() != tdescTy.getElementType())
-    return emitError() << "Value should have the same element type as TensorDesc.";
+    return emitError()
+           << "Value should have the same element type as TensorDesc.";
 
   if (tdescShape[0] != maskShape[0])
-    return emitError() << "dim-0 of the Mask and TensorDesc should be the same.";
+    return emitError()
+           << "dim-0 of the Mask and TensorDesc should be the same.";
 
   // a valid shape for SIMT case
   if (valueTy.getRank() == 1 && valueTy.getNumElements() == chunkSize) {
@@ -568,8 +573,9 @@ LogicalResult LoadGatherOp::verify() {
   if (!isReadHintOrNone(getL3HintAttr()))
     return emitOpError("invalid l3_hint: ") << getL3HintAttr();
 
-  return isValidGatherScatterParams(maskTy, valueTy, tdescTy, getTransposeAttr(),
-                         [&]() { return emitOpError(); });
+  return isValidGatherScatterParams(maskTy, valueTy, tdescTy,
+                                    getTransposeAttr(),
+                                    [&]() { return emitOpError(); });
 }
 
 //===----------------------------------------------------------------------===//
@@ -589,8 +595,9 @@ LogicalResult StoreScatterOp::verify() {
   if (!isWriteHintOrNone(getL3HintAttr()))
     return emitOpError("invalid l3_hint: ") << getL3HintAttr();
 
-  return isValidGatherScatterParams(maskTy, valueTy, tdescTy, getTransposeAttr(),
-                         [&]() { return emitOpError(); });
+  return isValidGatherScatterParams(maskTy, valueTy, tdescTy,
+                                    getTransposeAttr(),
+                                    [&]() { return emitOpError(); });
 }
 
 //===----------------------------------------------------------------------===//



More information about the Mlir-commits mailing list