[Mlir-commits] [mlir] [MLIR][XeGPU] Remove the transpose attribute from LoadGatherOp and StoreScatterOp (PR #145389)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Mon Jun 23 12:00:43 PDT 2025


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir

Author: Chao Chen (chencha3)

<details>
<summary>Changes</summary>

As the title suggests, this PR removes the `transpose` attribute from the definitions of `LoadGatherOp` and `StoreScatterOp`. The attribute is meaningful in the context of the SIMD lowering pipeline, but not for the SIMT lowering pipeline.

---

Patch is 45.54 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/145389.diff


9 Files Affected:

- (modified) mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td (+12-22) 
- (modified) mlir/include/mlir/Dialect/XeGPU/Utils/XeGPUUtils.h (+2) 
- (modified) mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp (+3-24) 
- (modified) mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp (+39-29) 
- (modified) mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp (+5-10) 
- (modified) mlir/test/Dialect/XeGPU/ops.mlir (+14-22) 
- (modified) mlir/test/Dialect/XeGPU/propagate-layout.mlir (+18-27) 
- (modified) mlir/test/Dialect/XeGPU/xegpu-blocking.mlir (+23-23) 
- (modified) mlir/test/Dialect/XeGPU/xegpu-unroll-patterns.mlir (+28-28) 


``````````diff
diff --git a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td
index e6c7efc47593f..ffc08e9b90b56 100644
--- a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td
+++ b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td
@@ -609,12 +609,8 @@ def XeGPU_LoadGatherOp : XeGPU_Op<"load", [
   let description = [{ It (aka. load) load data per each work-item. The output
     describes the data being loaded at the subgroup level, so its size is
     consistent with the number of work-items in a subgroup. When the chunk size
-    is larger than 2, the output vector is a 2D vector, with dim-1 correspoding
-    to work-items, and dim-0 corresponding to the chunk size loaded by each work-item.
-    Specially, there is a transpose effect on the result (as compared to the TensorDesc)
-    due to the hardware implementation. Therefore, a transpose attribute is introduced
-    on purpose, making sure users are aware of this implicit transformation.
-
+    is larger than 2, the output vector is a 2D vector, with dim-0 corresponding
+    to work-items, and dim-1 corresponding to the chunk size loaded by each work-item.
     The mask operand masks out memory access so that it is safe to pass out-of-boundary
     addresses/offsets as long as they are masked. It applies to slots of SIMD lanes.
 
@@ -634,8 +630,7 @@ def XeGPU_LoadGatherOp : XeGPU_Op<"load", [
 
   Example 2:
   ```mlir
-    %2 = xegpu.load %1, %0 {transpose,
-                            l1_hint = #xegpu.cache_hint<cached>,
+    %2 = xegpu.load %1, %0 {l1_hint = #xegpu.cache_hint<cached>,
                             l2_hint = #xegpu.cache_hint<uncached>,
                             l3_hint = #xegpu.cache_hint<uncached>}
           : !xegpu.tensor_desc<16x8xf32, #xegpu.scatter_tdesc_attr<memory_space=global, chunk_size=8>>,
@@ -643,20 +638,18 @@ def XeGPU_LoadGatherOp : XeGPU_Op<"load", [
   ```
   Example 3 (SIMT mode):
   ```mlir
-    %2 = xegpu.load %1, %0 {transpose,
-                            l1_hint = #xegpu.cache_hint<cached>,
+    %2 = xegpu.load %1, %0 {l1_hint = #xegpu.cache_hint<cached>,
                             l2_hint = #xegpu.cache_hint<uncached>,
                             l3_hint = #xegpu.cache_hint<uncached>}
           : !xegpu.tensor_desc<16x8xf32, #xegpu.scatter_tdesc_attr<memory_space=global, chunk_size=8>,
             !xegpu.layout<lane_layout = [16, 1], lane_data = [1, 1]>>
-            vector<16xi1> -> vector<8x1xf32>
+            vector<16xi1> -> vector<8xf32>
   ```
 
   }];
 
   let arguments = (ins XeGPU_TensorDesc: $TensorDesc,
                        XeGPU_MaskType: $mask,
-                       OptionalAttr<UnitAttr>: $transpose,
                        OptionalAttr<XeGPU_CacheHintAttr>: $l1_hint,
                        OptionalAttr<XeGPU_CacheHintAttr>: $l2_hint,
                        OptionalAttr<XeGPU_CacheHintAttr>: $l3_hint);
@@ -714,19 +707,17 @@ def XeGPU_StoreScatterOp : XeGPU_Op<"store", [
 
   Example 2:
   ```mlir
-    xegpu.store %0, %1, %2 {transpose,
-                                 l1_hint = #xegpu.cache_hint<uncached>,
-                                 l2_hint = #xegpu.cache_hint<write_back>,
-                                 l3_hint = #xegpu.cache_hint<write_through>}
+    xegpu.store %0, %1, %2 {l1_hint = #xegpu.cache_hint<uncached>,
+                            l2_hint = #xegpu.cache_hint<write_back>,
+                            l3_hint = #xegpu.cache_hint<write_through>}
           : vector<8x16xf32>, !xegpu.tensor_desc<16x8xf32, #xegpu.scattered_tdesc_attr<chunk_size=8>>, vector<16xi1>
   ```
   Example 3 (SIMT mode):
   ```mlir
-    xegpu.store %0, %1, %2 {transpose,
-                                 l1_hint = #xegpu.cache_hint<uncached>,
-                                 l2_hint = #xegpu.cache_hint<write_back>,
-                                 l3_hint = #xegpu.cache_hint<write_through>}
-          : vector<8x1xf32>, !xegpu.tensor_desc<16x8xf32, #xegpu.scattered_tdesc_attr<chunk_size=8>,
+    xegpu.store %0, %1, %2 {l1_hint = #xegpu.cache_hint<uncached>,
+                            l2_hint = #xegpu.cache_hint<write_back>,
+                            l3_hint = #xegpu.cache_hint<write_through>}
+          : vector<8xf32>, !xegpu.tensor_desc<16x8xf32, #xegpu.scattered_tdesc_attr<chunk_size=8>,
             !xegpu.layout<lane_layout = [16, 1], lane_data = [1, 1]>> vector<16xi1>
   ```
 
@@ -736,7 +727,6 @@ def XeGPU_StoreScatterOp : XeGPU_Op<"store", [
     XeGPU_ValueType: $value,
     XeGPU_TensorDesc: $TensorDesc,
     XeGPU_MaskType: $mask,
-    OptionalAttr<UnitAttr>: $transpose,
     OptionalAttr<XeGPU_CacheHintAttr>: $l1_hint,
     OptionalAttr<XeGPU_CacheHintAttr>: $l2_hint,
     OptionalAttr<XeGPU_CacheHintAttr>: $l3_hint);
diff --git a/mlir/include/mlir/Dialect/XeGPU/Utils/XeGPUUtils.h b/mlir/include/mlir/Dialect/XeGPU/Utils/XeGPUUtils.h
index 772cf73649646..09311e6017d0c 100644
--- a/mlir/include/mlir/Dialect/XeGPU/Utils/XeGPUUtils.h
+++ b/mlir/include/mlir/Dialect/XeGPU/Utils/XeGPUUtils.h
@@ -35,6 +35,8 @@ constexpr unsigned packedSizeInBitsForDefault =
     16; // Minimum packing size per register for DPAS A.
 constexpr unsigned packedSizeInBitsForDpasB =
     32; // Minimum packing size per register for DPAS B.
+constexpr unsigned packedSizeInBitsForGatherScatter =
+    32; // Minimum packing size per register for Gather and Scatter ops.
 } // namespace targetinfo
 } // namespace xegpu
 
diff --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
index 0afc502c026f7..f0fb03d4f1139 100644
--- a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
+++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
@@ -20,13 +20,6 @@
 namespace mlir {
 namespace xegpu {
 
-static void transpose(llvm::ArrayRef<int64_t> trans,
-                      SmallVector<int64_t> &shape) {
-  SmallVector<int64_t> old = shape;
-  for (size_t i = 0; i < trans.size(); i++)
-    shape[i] = old[trans[i]];
-}
-
 template <typename T>
 static std::string makeString(T array, bool breakline = false) {
   std::string buf;
@@ -76,7 +69,7 @@ static bool isWriteHintOrNone(const CachePolicyAttr &attr) {
 
 static LogicalResult
 isValidGatherScatterParams(Type maskTy, VectorType valueTy,
-                           TensorDescType tdescTy, UnitAttr transposeAttr,
+                           TensorDescType tdescTy,
                            function_ref<InFlightDiagnostic()> emitError) {
 
   if (!tdescTy.isScattered())
@@ -102,17 +95,9 @@ isValidGatherScatterParams(Type maskTy, VectorType valueTy,
   if (valueTy.getRank() == 1 && valueTy.getNumElements() == chunkSize) {
     if (tdescTy.getLayoutAttr())
       return emitError() << "TensorDesc doesn't need LayoutAttr for SIMT code";
-    if (transposeAttr)
-      return emitError() << "doesn't need TransposeAttr for SIMT code";
     return success();
   }
 
-  if (tdescTy.getRank() == 2 && valueTy.getRank() == 2) {
-    if (!transposeAttr)
-      return emitError() << "rank-2 tensor has to be transposed.";
-    transpose({1, 0}, tdescShape);
-  }
-
   if (tdescShape != valueShape)
     return emitError() << "Value shape " << makeString(valueShape)
                        << " is neither a valid distribution for SIMT nor "
@@ -310,13 +295,9 @@ LogicalResult LoadNdOp::verify() {
 
   if (getTranspose()) {
     auto trans = getTranspose().value();
-
     // Make sure the transpose value is valid.
-    bool valid = llvm::all_of(
-        trans, [&](int t) { return t >= 0 && t < tdescTy.getRank(); });
-
-    if (valid)
-      transpose(trans, tdescShape);
+    if (llvm::all_of(trans, [&](size_t s) { return s < tdescShape.size(); }))
+      tdescShape = applyPermutation(tdescShape, trans);
     else
       mlir::emitWarning(getLoc()) << "Invalid transpose attr. It is ignored.";
   }
@@ -536,7 +517,6 @@ LogicalResult LoadGatherOp::verify() {
     return emitOpError("invalid l3_hint: ") << getL3HintAttr();
 
   return isValidGatherScatterParams(maskTy, valueTy, tdescTy,
-                                    getTransposeAttr(),
                                     [&]() { return emitOpError(); });
 }
 
@@ -558,7 +538,6 @@ LogicalResult StoreScatterOp::verify() {
     return emitOpError("invalid l3_hint: ") << getL3HintAttr();
 
   return isValidGatherScatterParams(maskTy, valueTy, tdescTy,
-                                    getTransposeAttr(),
                                     [&]() { return emitOpError(); });
 }
 
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp
index cc22d2bbd8c39..60ccd823775a5 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp
@@ -213,6 +213,35 @@ static LayoutInfo getDefaultSIMTLayoutInfo(VectorType vectorTy) {
                     LaneData({1, packingFactor}));
 }
 
+/// Helper to get the default layout for a tensor descriptor type.
+static LayoutInfo getDefaultSIMTLayoutInfo(xegpu::TensorDescType tdescTy) {
+  // Expecting a 1D or 2D tensor descriptor.
+  assert((tdescTy.getRank() == 1 || tdescTy.getRank() == 2) &&
+         "Expected 1D or 2D TensorDesc.");
+  // Expecting int or float element type.
+  assert(tdescTy.getElementType().isIntOrFloat() &&
+         "Expected int or float element type.");
+  // If the rank is 1, then return default layout for 1D vector.
+  if (tdescTy.getRank() == 1)
+    return getDefaultSIMTLayoutInfo(1);
+  // Packing factor is determined by the element type bitwidth.
+  unsigned bitwidth = tdescTy.getElementType().getIntOrFloatBitWidth();
+
+  if (tdescTy.isScattered()) {
+    int packingFactor =
+        xegpu::targetinfo::packedSizeInBitsForGatherScatter / bitwidth;
+    return LayoutInfo(LaneLayout({xegpu::targetinfo::subgroupSize, 1}),
+                      LaneData({1, packingFactor}));
+  }
+
+  int packingFactor =
+      (bitwidth < xegpu::targetinfo::packedSizeInBitsForDefault)
+          ? xegpu::targetinfo::packedSizeInBitsForDefault / bitwidth
+          : 1;
+  return LayoutInfo(LaneLayout({1, xegpu::targetinfo::subgroupSize}),
+                    LaneData({1, packingFactor}));
+}
+
 /// Helper Function to get the expected layouts for DPAS operands. `lane_data`
 /// is set according to the following criteria:
 /// * For A operand, the data must be packed in minimum
@@ -379,8 +408,7 @@ void LayoutInfoPropagation::visitPrefetchNdOp(
   // Here we assign the default layout to the tensor descriptor operand of
   // prefetch.
   auto tdescTy = prefetch.getTensorDescType();
-  auto prefetchLayout = getDefaultSIMTLayoutInfo(
-      VectorType::get(tdescTy.getShape(), tdescTy.getElementType()));
+  auto prefetchLayout = getDefaultSIMTLayoutInfo(tdescTy);
   // Propagate the layout to the source tensor descriptor.
   propagateIfChanged(operands[0], operands[0]->meet(prefetchLayout));
 }
@@ -516,24 +544,14 @@ void LayoutInfoPropagation::visitVectorBitcastOp(
 void LayoutInfoPropagation::visitLoadGatherOp(
     xegpu::LoadGatherOp load, ArrayRef<LayoutInfoLattice *> operands,
     ArrayRef<const LayoutInfoLattice *> results) {
-  LayoutInfo valueLayout = results[0]->getValue();
-  // Need the layout of the value to propagate to the tensor descriptor.
-  if (!valueLayout.isAssigned())
-    return;
+  // The layout is strictly determined by the tensor descriptor type.
+  LayoutInfo layout = getDefaultSIMTLayoutInfo(load.getTensorDescType());
 
-  LayoutInfo tensorDescLayout = valueLayout;
-  if (load.getTranspose()) {
-    // LoadGatherOp has the transpose effect. However, at the stage of this
-    // analyis this effect is not expected and should be abstracted away. Emit
-    // a warning.
-    load.emitWarning("Transpose effect is not expected for LoadGatherOp at "
-                     "LayoutInfoPropagation stage.");
-    tensorDescLayout = valueLayout.getTransposedLayout({1, 0});
-  }
   // Mask operand should have 1D default layout.
   LayoutInfo maskLayout = getDefaultSIMTLayoutInfo(1);
+
   // Propagate the new layout to the tensor descriptor operand.
-  propagateIfChanged(operands[0], operands[0]->meet(tensorDescLayout));
+  propagateIfChanged(operands[0], operands[0]->meet(layout));
   // Propagate the new layout to the mask operand.
   propagateIfChanged(operands[1], operands[1]->meet(maskLayout));
 }
@@ -567,21 +585,13 @@ void LayoutInfoPropagation::visitStoreScatterOp(
         "Expected the first dimension of 2D tensor descriptor to be equal to "
         "subgroup size.");
 
-  LayoutInfo valueLayout =
-      getDefaultSIMTLayoutInfo(storeScatter.getValueType());
-  LayoutInfo storeScatterLayout = valueLayout;
-  if (storeScatter.getTranspose()) {
-    // StoreScatteOp allows transpose effect. However, at the stage of this
-    // analyis this effect is not expected and should be abstracted away. Emit
-    // a warning.
-    storeScatter.emitWarning("Transpose effect is not expected for "
-                             "StoreScatterOp at LayoutInfoPropagation stage.");
-    storeScatterLayout = valueLayout.getTransposedLayout({1, 0});
-  }
+  LayoutInfo layout =
+      getDefaultSIMTLayoutInfo(storeScatter.getTensorDescType());
+
   // Propagate the value layout.
-  propagateIfChanged(operands[0], operands[0]->meet(valueLayout));
+  propagateIfChanged(operands[0], operands[0]->meet(layout));
   // Propagate the tensor descriptor layout.
-  propagateIfChanged(operands[1], operands[1]->meet(storeScatterLayout));
+  propagateIfChanged(operands[1], operands[1]->meet(layout));
   // Use default 1D layout for mask operand.
   LayoutInfo maskLayout = getDefaultSIMTLayoutInfo(1);
   propagateIfChanged(operands[2], operands[2]->meet(maskLayout));
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp
index 0457f8128b908..be39ee1f0b53f 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp
@@ -507,8 +507,6 @@ struct UnrollLoadGatherOp : public UnrollPattern<xegpu::LoadGatherOp> {
         for (int64_t i = 0; i < numNewChunks; ++i)
           convertedMasks.push_back(mask);
       }
-      // This is to handle the transpose effect when chunkSize > 1.
-      std::swap((*targetShape)[0], (*targetShape)[1]);
       newValueTy = valueTy.cloneWith(*targetShape, elemTy);
     } else {
       convertedMaskTypes = getUnrolledTypes(maskTy, targetMaskShape);
@@ -519,8 +517,8 @@ struct UnrollLoadGatherOp : public UnrollPattern<xegpu::LoadGatherOp> {
     SmallVector<Value> newOps;
     for (auto [t, m] : llvm::zip(convertedTdescs, convertedMasks)) {
       auto newOp = rewriter.create<xegpu::LoadGatherOp>(
-          loc, newValueTy, t, m, op.getTransposeAttr(), op.getL1HintAttr(),
-          op.getL2HintAttr(), op.getL3HintAttr());
+          loc, newValueTy, t, m, op.getL1HintAttr(), op.getL2HintAttr(),
+          op.getL3HintAttr());
       newOps.push_back(newOp);
     }
 
@@ -598,9 +596,6 @@ struct UnrollStoreScatterOp : public UnrollPattern<xegpu::StoreScatterOp> {
           convertedMasks.push_back(mask);
         }
       }
-      // This is to handle the transpose effect when chunkSize > 1.
-      std::swap((*targetShape)[0], (*targetShape)[1]);
-
     } else {
       convertedMaskTypes = getUnrolledTypes(maskTy, *targetShape);
       convertedMasks =
@@ -616,9 +611,9 @@ struct UnrollStoreScatterOp : public UnrollPattern<xegpu::StoreScatterOp> {
       Value v = convertedValues[i];
       Value t = convertedTdescs[i];
       Value m = op.getMask() ? convertedMasks[i] : nullptr;
-      rewriter.create<xegpu::StoreScatterOp>(
-          loc, v, t, m, op.getTransposeAttr(), op.getL1HintAttr(),
-          op.getL2HintAttr(), op.getL3HintAttr());
+      rewriter.create<xegpu::StoreScatterOp>(loc, v, t, m, op.getL1HintAttr(),
+                                             op.getL2HintAttr(),
+                                             op.getL3HintAttr());
     }
 
     rewriter.eraseOp(op);
diff --git a/mlir/test/Dialect/XeGPU/ops.mlir b/mlir/test/Dialect/XeGPU/ops.mlir
index 054c4d12fdb28..5ceb548221758 100644
--- a/mlir/test/Dialect/XeGPU/ops.mlir
+++ b/mlir/test/Dialect/XeGPU/ops.mlir
@@ -199,8 +199,8 @@ gpu.func @simt_load_nd_7(%src: memref<24x32xf16>) {
 gpu.func @subgroup_load_nd_8(%src: memref<24x32xf32>) {
   // CHECK: %[[R0:.*]] = xegpu.create_nd_tdesc %arg0[0, 0] : memref<24x32xf32> -> !xegpu.tensor_desc<16x8xf32>
   %1 = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf32> -> !xegpu.tensor_desc<16x8xf32>
-  // CHECK: %[[R1:.*]] = xegpu.load_nd %[[R0]] <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>, transpose = array<i64: 1, 0>}> : !xegpu.tensor_desc<16x8xf32> -> vector<8x16xf32>
-  %2 = xegpu.load_nd %1 <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>, transpose = array<i64: 1, 0>}> : !xegpu.tensor_desc<16x8xf32> -> vector<8x16xf32>
+  // CHECK: %[[R1:.*]] = xegpu.load_nd %[[R0]] <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}> : !xegpu.tensor_desc<16x8xf32> -> vector<16x8xf32>
+  %2 = xegpu.load_nd %1 <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}> : !xegpu.tensor_desc<16x8xf32> -> vector<16x8xf32>
   gpu.return
 }
 
@@ -235,8 +235,6 @@ gpu.func @simt_store_nd(%src: memref<24x32xf16>) {
   gpu.return
 }
 
-
-
 // CHECK: func @subgroup_store_nd_2(%[[arg0:.*]]: memref<24x32xf16>) {
 gpu.func @subgroup_store_nd_2(%dst: memref<24x32xf16>) {
   // CHECK: %[[C:.*]] = arith.constant dense<1.000000e+00> : vector<32xf16>
@@ -248,7 +246,6 @@ gpu.func @subgroup_store_nd_2(%dst: memref<24x32xf16>) {
   gpu.return
 }
 
-
 // CHECK: func @simt_store_nd_2(%[[arg0:.*]]: memref<24x32xf16>) {
 gpu.func @simt_store_nd_2(%src: memref<24x32xf16>) {
   // CHECK: %[[C:.*]] = arith.constant dense<1.000000e+00> : vector<2xf16>
@@ -318,8 +315,8 @@ gpu.func @subgroup_load(%src: ui64) {
   %1 = arith.constant dense<1>: vector<4xi1>
   //CHECK: %[[R0:.*]] = xegpu.create_tdesc %[[arg0]], %[[cst]] : ui64, vector<4xindex> -> !xegpu.tensor_desc<4x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2 : i64>>
   %2 = xegpu.create_tdesc %src, %0 : ui64, vector<4xindex> -> !xegpu.tensor_desc<4x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2>>
-  //CHECK: %[[R1:.*]] = xegpu.load %[[R0]], %[[cst1]] <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>, transpose}> : !xegpu.tensor_desc<4x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2 : i64>>, vector<4xi1> -> vector<2x4xf32>
-  %3 = xegpu.load %2, %1 <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>, transpose}> : !xegpu.tensor_desc<4x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2>>, vector<4xi1> -> vector<2x4xf32>
+  //CHECK: %[[R1:.*]] = xegpu.load %[[R0]], %[[cst1]] <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}> : !xegpu.tensor_desc<4x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2 : i64>>, vector<4xi1> -> vector<4x2xf32>
+  %3 = xegpu.load %2, %1 <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}> : !xegpu.tensor_desc<4x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2>>, vector<4xi1> -> vector<4x2xf32>
   gpu.return
 }
 
@@ -370,8 +367,8 @@ gpu.func @subgroup_load_3(%src: ui64) {
   %1 = arith.constant dense<1>: vector<4xi1>
   //CHECK: %[[R0:.*]] = xegpu.create_tdesc %[[arg0]], %[[cst]] : ui64, vector<4xindex> -> !xegpu.tensor_desc<4x8xf16, #xegpu.scatter_tdesc_attr<chunk_size = 8 : i64>>
   %2 = xegpu.create_tdesc %src, %0 : ui64, vector<4xindex> -> !xegpu.tensor_desc<4x8xf16, #xegpu.scatter_tdesc_attr<chunk_size = 8>>
-  //CHECK: %[[R1:.*]] = xegpu.load %[[R0]], %[[cst1]] <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>, transpose}> : !xegpu.tensor_desc<4x8xf16, #xegpu.scatter_tdesc_attr<chunk_size = 8 : i64>>, vector<4xi1> -> vector<8x4xf16>
-  %3 = xegpu.load %2, %1 <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>, transpose}> : !xegpu.tensor_desc<4x8xf16, #xegpu.scatter_tdesc_attr<chunk_size = 8>>, vector<4xi1> -> vector<8x4xf16>
+  //CHECK: %[[R1:.*]] = xegpu.load %[[R0]], %[[cst1]] <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}> : !xegpu.tensor_desc<4x8xf16, #xegpu.scatter_tdesc_attr<chunk_size = 8 : i64>>, vector<4xi1> -> vector<4x8xf16>
+  %3 = xegpu.load %2, %1 <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}> : !xegpu.te...
[truncated]

``````````

</details>


https://github.com/llvm/llvm-project/pull/145389


More information about the Mlir-commits mailing list