[Mlir-commits] [mlir] [MLIR][XeGPU] Clean up helpers in XeGPUPropagateLayout (PR #175857)

Nishant Patel llvmlistbot at llvm.org
Tue Jan 13 14:23:52 PST 2026


https://github.com/nbpatel created https://github.com/llvm/llvm-project/pull/175857

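Rename the VectorType and TensorDescType overloads of getDefaultSIMTLayoutInfo to getSIMTLayoutInfoFor2DBlockOp (a template over the operand type) and getSIMTLayoutInfoForScatterIOVector. This drops the isScattered flag, and the patch updates the call sites accordingly.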

>From f1666e4c7ba14c689b27a3ad94469037790a5561 Mon Sep 17 00:00:00 2001
From: nbpatel <nishant.b.patel at intel.com>
Date: Tue, 13 Jan 2026 22:22:11 +0000
Subject: [PATCH] Clean up helpers in XeGPUPropagateLayout

---
 .../XeGPU/Transforms/XeGPUPropagateLayout.cpp | 96 +++++++++----------
 1 file changed, 46 insertions(+), 50 deletions(-)

diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp
index 3ac23f348f8a9..48906be79a54b 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp
@@ -287,58 +287,56 @@ static LayoutInfo getDefaultSIMTLayoutInfo(mlir::MLIRContext *ctx,
   return LayoutInfo(xegpu::LayoutAttr::get(ctx, {1, subgroupSize}, {1, 1}));
 }
 
-/// Helper to get the default layout for a vector type.
-static LayoutInfo getDefaultSIMTLayoutInfo(VectorType vectorTy,
-                                           const xegpu::uArch::uArch *uArch,
-                                           unsigned packingSize,
-                                           bool isScattered = false) {
-  // Expecting a 1D or 2D vector.
-  assert((vectorTy.getRank() == 1 || vectorTy.getRank() == 2) &&
-         "Expected 1D or 2D vector.");
+/// Helper to get the default SIMT layout for 2D block operations (e.g.
+/// prefetch_nd, store_nd and the DPAS operand fallback). `Ty` is either a
+/// VectorType or an xegpu::TensorDescType. Rank-1 shapes get the default 1D
+/// layout; rank-2 shapes distribute lanes along the second dimension
+/// (lane_layout [1, subgroupSize]) with lane_data [1, packingFactor].
+template <typename Ty>
+static LayoutInfo
+getSIMTLayoutInfoFor2DBlockOp(Ty ty, const xegpu::uArch::uArch *uArch,
+                              unsigned packingSize) {
+  static_assert(llvm::is_one_of<Ty, VectorType, xegpu::TensorDescType>::value,
+                "Expected VectorType or xegpu::TensorDescType.");
+  // Expecting a 1D or 2D vector or tensor descriptor.
+  assert((ty.getRank() == 1 || ty.getRank() == 2) &&
+         "Expected 1D or 2D vector or tensor descriptor.");
   // Expecting int or float element type.
-  assert(vectorTy.getElementType().isIntOrFloat() &&
+  assert(ty.getElementType().isIntOrFloat() &&
          "Expected int or float element type.");
   // If the rank is 1, then return default layout for 1D vector.
-  if (vectorTy.getRank() == 1)
-    return getDefaultSIMTLayoutInfo(vectorTy.getContext(), 1, uArch);
+  if (ty.getRank() == 1)
+    return getDefaultSIMTLayoutInfo(ty.getContext(), 1, uArch);
   // Packing factor is determined by the element type bitwidth.
-  unsigned bitwidth = vectorTy.getElementType().getIntOrFloatBitWidth();
+  unsigned bitwidth = ty.getElementType().getIntOrFloatBitWidth();
   int packingFactor = bitwidth < packingSize ? packingSize / bitwidth : 1;
-  if (isScattered) {
-    return LayoutInfo(xegpu::LayoutAttr::get(vectorTy.getContext(),
-                                             {uArch->getSubgroupSize(), 1},
-                                             {1, packingFactor}));
-  }
-  return LayoutInfo(xegpu::LayoutAttr::get(vectorTy.getContext(),
-                                           {1, uArch->getSubgroupSize()},
-                                           {1, packingFactor}));
+  return LayoutInfo(xegpu::LayoutAttr::get(
+      ty.getContext(), {1, uArch->getSubgroupSize()}, {1, packingFactor}));
 }
 
-/// Helper to get the default layout for a vector type.
-static LayoutInfo getDefaultSIMTLayoutInfo(xegpu::TensorDescType tdescTy,
-                                           const xegpu::uArch::uArch *uArch,
-                                           unsigned packingSize,
-                                           bool isScattered = false) {
+/// Helper to get the default SIMT layout for the payload vector of scattered
+/// (gather/scatter) loads and stores. Rank-1 payloads get the default 1D
+/// layout; rank-2 payloads distribute lanes along the first dimension
+/// (lane_layout [subgroupSize, 1]) with lane_data [1, packingFactor].
+static LayoutInfo
+getSIMTLayoutInfoForScatterIOVector(VectorType vectorTy,
+                                    const xegpu::uArch::uArch *uArch,
+                                    unsigned packingSize) {
   // Expecting a 1D or 2D vector.
-  assert((tdescTy.getRank() == 1 || tdescTy.getRank() == 2) &&
-         "Expected 1D or 2D TensorDesc.");
+  assert((vectorTy.getRank() == 1 || vectorTy.getRank() == 2) &&
+         "Expected 1D or 2D vector.");
   // Expecting int or float element type.
-  assert(tdescTy.getElementType().isIntOrFloat() &&
+  assert(vectorTy.getElementType().isIntOrFloat() &&
          "Expected int or float element type.");
   // If the rank is 1, then return default layout for 1D vector.
-  if (tdescTy.getRank() == 1)
-    return getDefaultSIMTLayoutInfo(tdescTy.getContext(), 1, uArch);
+  if (vectorTy.getRank() == 1)
+    return getDefaultSIMTLayoutInfo(vectorTy.getContext(), 1, uArch);
   // Packing factor is determined by the element type bitwidth.
-  unsigned bitwidth = tdescTy.getElementType().getIntOrFloatBitWidth();
-  int subgroupSize = uArch->getSubgroupSize();
+  unsigned bitwidth = vectorTy.getElementType().getIntOrFloatBitWidth();
   int packingFactor = bitwidth < packingSize ? packingSize / bitwidth : 1;
-  if (isScattered) {
-    return LayoutInfo(xegpu::LayoutAttr::get(
-        tdescTy.getContext(), {subgroupSize, 1}, {1, packingFactor}));
-  }
-
-  return LayoutInfo(xegpu::LayoutAttr::get(
-      tdescTy.getContext(), {1, subgroupSize}, {1, packingFactor}));
+  return LayoutInfo(xegpu::LayoutAttr::get(vectorTy.getContext(),
+                                           {uArch->getSubgroupSize(), 1},
+                                           {1, packingFactor}));
 }
 
 /// Helper Function to get the expected layouts for DPAS operands. `lane_data`
@@ -365,7 +363,7 @@ getSIMTLayoutInfoForDPASOperand(VectorType vectorTy, unsigned operandNum,
         xegpu::LayoutAttr::get(vectorTy.getContext(), layout, data));
   }
   // Otherwise, return the default layout for the vector type.
-  return getDefaultSIMTLayoutInfo(vectorTy, uArch, packingSize);
+  return getSIMTLayoutInfoFor2DBlockOp(vectorTy, uArch, packingSize);
 }
 
 //===----------------------------------------------------------------------===//
@@ -587,7 +585,7 @@ void LayoutInfoPropagation::visitPrefetchNdOp(
       prefetchLayout =
           LayoutInfo(xegpu::LayoutAttr::get(tdescTy.getContext(), instData));
     else
-      prefetchLayout = getDefaultSIMTLayoutInfo(
+      prefetchLayout = getSIMTLayoutInfoFor2DBlockOp(
           tdescTy, uArch, uArchInstruction->getPackedFormatBitSize());
 
     prefetch.setLayoutAttr(
@@ -840,9 +838,9 @@ void LayoutInfoPropagation::visitStoreNdOp(
       storeLayout =
           LayoutInfo(xegpu::LayoutAttr::get(dataTy.getContext(), instData));
     else
-      storeLayout =
-          getDefaultSIMTLayoutInfo(store.getValueType(), uArch,
-                                   uArchInstruction->getPackedFormatBitSize());
+      storeLayout = getSIMTLayoutInfoFor2DBlockOp(
+          store.getValueType(), uArch,
+          uArchInstruction->getPackedFormatBitSize());
     store.setLayoutAttr(
         dyn_cast<xegpu::DistributeLayoutAttr>(storeLayout.get()));
   }
@@ -1000,9 +998,8 @@ void LayoutInfoPropagation::visitLoadGatherOp(
       loadLayout =
           LayoutInfo(xegpu::LayoutAttr::get(load.getContext(), instData));
     else
-      loadLayout = getDefaultSIMTLayoutInfo(
-          payloadTy, uArch, uArch->getGeneralPackedFormatBitSize(),
-          /*scattered*/ true);
+      loadLayout = getSIMTLayoutInfoForScatterIOVector(
+          payloadTy, uArch, uArch->getGeneralPackedFormatBitSize());
 
     // Mask operand should have 1D default layout.
     maskLayout = getDefaultSIMTLayoutInfo(load->getContext(), 1, subgroupSize);
@@ -1078,9 +1075,8 @@ void LayoutInfoPropagation::visitStoreScatterOp(
                "Expected the first dimension of 2D tensor descriptor to be "
                "equal to "
                "subgroup size.");
-      payloadLayout = getDefaultSIMTLayoutInfo(
-          payloadTy, uArch, uArch->getGeneralPackedFormatBitSize(),
-          /*scattered=*/true);
+      payloadLayout = getSIMTLayoutInfoForScatterIOVector(
+          payloadTy, uArch, uArch->getGeneralPackedFormatBitSize());
     }
 
     maskLayout =



More information about the Mlir-commits mailing list