[Mlir-commits] [mlir] [MLIR][XeGPU] Clean up helpers in XeGPUPropagateLayout (PR #175857)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Tue Jan 13 14:24:25 PST 2026


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir

Author: Nishant Patel (nbpatel)

<details>
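<summary>Changes</summary>

Replaces the two `getDefaultSIMTLayoutInfo` overloads that took a `VectorType` or a `TensorDescType` plus an `isScattered` flag with two dedicated helpers: the templated `getSIMTLayoutInfoFor2DBlockOp` (lane layout `{1, subgroupSize}`) and `getSIMTLayoutInfoForScatterIOVector` (lane layout `{subgroupSize, 1}`). The DPAS-operand, `PrefetchNd`, `StoreNd`, `LoadGather`, and `StoreScatter` call sites are updated to call the matching helper.
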
---
Full diff: https://github.com/llvm/llvm-project/pull/175857.diff


1 Files Affected:

- (modified) mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp (+34-47) 


``````````diff
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp
index 3ac23f348f8a9..48906be79a54b 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp
@@ -287,58 +287,47 @@ static LayoutInfo getDefaultSIMTLayoutInfo(mlir::MLIRContext *ctx,
   return LayoutInfo(xegpu::LayoutAttr::get(ctx, {1, subgroupSize}, {1, 1}));
 }
 
-/// Helper to get the default layout for a vector type.
-static LayoutInfo getDefaultSIMTLayoutInfo(VectorType vectorTy,
-                                           const xegpu::uArch::uArch *uArch,
-                                           unsigned packingSize,
-                                           bool isScattered = false) {
+/// Helper to get the default layout for 2D block operations.
+template <typename Ty>
+static LayoutInfo
+getSIMTLayoutInfoFor2DBlockOp(Ty ty, const xegpu::uArch::uArch *uArch,
+                              unsigned packingSize) {
   // Expecting a 1D or 2D vector.
-  assert((vectorTy.getRank() == 1 || vectorTy.getRank() == 2) &&
+  assert((ty.getRank() == 1 || ty.getRank() == 2) &&
          "Expected 1D or 2D vector.");
   // Expecting int or float element type.
-  assert(vectorTy.getElementType().isIntOrFloat() &&
+  assert(ty.getElementType().isIntOrFloat() &&
          "Expected int or float element type.");
   // If the rank is 1, then return default layout for 1D vector.
-  if (vectorTy.getRank() == 1)
-    return getDefaultSIMTLayoutInfo(vectorTy.getContext(), 1, uArch);
+  if (ty.getRank() == 1)
+    return getDefaultSIMTLayoutInfo(ty.getContext(), 1, uArch);
   // Packing factor is determined by the element type bitwidth.
-  unsigned bitwidth = vectorTy.getElementType().getIntOrFloatBitWidth();
+  unsigned bitwidth = ty.getElementType().getIntOrFloatBitWidth();
   int packingFactor = bitwidth < packingSize ? packingSize / bitwidth : 1;
-  if (isScattered) {
-    return LayoutInfo(xegpu::LayoutAttr::get(vectorTy.getContext(),
-                                             {uArch->getSubgroupSize(), 1},
-                                             {1, packingFactor}));
-  }
-  return LayoutInfo(xegpu::LayoutAttr::get(vectorTy.getContext(),
-                                           {1, uArch->getSubgroupSize()},
-                                           {1, packingFactor}));
+  return LayoutInfo(xegpu::LayoutAttr::get(
+      ty.getContext(), {1, uArch->getSubgroupSize()}, {1, packingFactor}));
 }
 
 /// Helper to get the default layout for a vector type.
-static LayoutInfo getDefaultSIMTLayoutInfo(xegpu::TensorDescType tdescTy,
-                                           const xegpu::uArch::uArch *uArch,
-                                           unsigned packingSize,
-                                           bool isScattered = false) {
+static LayoutInfo
+getSIMTLayoutInfoForScatterIOVector(VectorType vectorTy,
+                                    const xegpu::uArch::uArch *uArch,
+                                    unsigned packingSize) {
   // Expecting a 1D or 2D vector.
-  assert((tdescTy.getRank() == 1 || tdescTy.getRank() == 2) &&
-         "Expected 1D or 2D TensorDesc.");
+  assert((vectorTy.getRank() == 1 || vectorTy.getRank() == 2) &&
+         "Expected 1D or 2D vector.");
   // Expecting int or float element type.
-  assert(tdescTy.getElementType().isIntOrFloat() &&
+  assert(vectorTy.getElementType().isIntOrFloat() &&
          "Expected int or float element type.");
   // If the rank is 1, then return default layout for 1D vector.
-  if (tdescTy.getRank() == 1)
-    return getDefaultSIMTLayoutInfo(tdescTy.getContext(), 1, uArch);
+  if (vectorTy.getRank() == 1)
+    return getDefaultSIMTLayoutInfo(vectorTy.getContext(), 1, uArch);
   // Packing factor is determined by the element type bitwidth.
-  unsigned bitwidth = tdescTy.getElementType().getIntOrFloatBitWidth();
-  int subgroupSize = uArch->getSubgroupSize();
+  unsigned bitwidth = vectorTy.getElementType().getIntOrFloatBitWidth();
   int packingFactor = bitwidth < packingSize ? packingSize / bitwidth : 1;
-  if (isScattered) {
-    return LayoutInfo(xegpu::LayoutAttr::get(
-        tdescTy.getContext(), {subgroupSize, 1}, {1, packingFactor}));
-  }
-
-  return LayoutInfo(xegpu::LayoutAttr::get(
-      tdescTy.getContext(), {1, subgroupSize}, {1, packingFactor}));
+  return LayoutInfo(xegpu::LayoutAttr::get(vectorTy.getContext(),
+                                           {uArch->getSubgroupSize(), 1},
+                                           {1, packingFactor}));
 }
 
 /// Helper Function to get the expected layouts for DPAS operands. `lane_data`
@@ -365,7 +354,7 @@ getSIMTLayoutInfoForDPASOperand(VectorType vectorTy, unsigned operandNum,
         xegpu::LayoutAttr::get(vectorTy.getContext(), layout, data));
   }
   // Otherwise, return the default layout for the vector type.
-  return getDefaultSIMTLayoutInfo(vectorTy, uArch, packingSize);
+  return getSIMTLayoutInfoFor2DBlockOp(vectorTy, uArch, packingSize);
 }
 
 //===----------------------------------------------------------------------===//
@@ -587,7 +576,7 @@ void LayoutInfoPropagation::visitPrefetchNdOp(
       prefetchLayout =
           LayoutInfo(xegpu::LayoutAttr::get(tdescTy.getContext(), instData));
     else
-      prefetchLayout = getDefaultSIMTLayoutInfo(
+      prefetchLayout = getSIMTLayoutInfoFor2DBlockOp(
           tdescTy, uArch, uArchInstruction->getPackedFormatBitSize());
 
     prefetch.setLayoutAttr(
@@ -840,9 +829,9 @@ void LayoutInfoPropagation::visitStoreNdOp(
       storeLayout =
           LayoutInfo(xegpu::LayoutAttr::get(dataTy.getContext(), instData));
     else
-      storeLayout =
-          getDefaultSIMTLayoutInfo(store.getValueType(), uArch,
-                                   uArchInstruction->getPackedFormatBitSize());
+      storeLayout = getSIMTLayoutInfoFor2DBlockOp(
+          store.getValueType(), uArch,
+          uArchInstruction->getPackedFormatBitSize());
     store.setLayoutAttr(
         dyn_cast<xegpu::DistributeLayoutAttr>(storeLayout.get()));
   }
@@ -1000,9 +989,8 @@ void LayoutInfoPropagation::visitLoadGatherOp(
       loadLayout =
           LayoutInfo(xegpu::LayoutAttr::get(load.getContext(), instData));
     else
-      loadLayout = getDefaultSIMTLayoutInfo(
-          payloadTy, uArch, uArch->getGeneralPackedFormatBitSize(),
-          /*scattered*/ true);
+      loadLayout = getSIMTLayoutInfoForScatterIOVector(
+          payloadTy, uArch, uArch->getGeneralPackedFormatBitSize());
 
     // Mask operand should have 1D default layout.
     maskLayout = getDefaultSIMTLayoutInfo(load->getContext(), 1, subgroupSize);
@@ -1078,9 +1066,8 @@ void LayoutInfoPropagation::visitStoreScatterOp(
                "Expected the first dimension of 2D tensor descriptor to be "
                "equal to "
                "subgroup size.");
-      payloadLayout = getDefaultSIMTLayoutInfo(
-          payloadTy, uArch, uArch->getGeneralPackedFormatBitSize(),
-          /*scattered=*/true);
+      payloadLayout = getSIMTLayoutInfoForScatterIOVector(
+          payloadTy, uArch, uArch->getGeneralPackedFormatBitSize());
     }
 
     maskLayout =

``````````

</details>


https://github.com/llvm/llvm-project/pull/175857


More information about the Mlir-commits mailing list