[Mlir-commits] [mlir] [MLIR][XeGPU] Clean up helpers in XeGPUPropagateLayout (PR #175857)

Nishant Patel llvmlistbot at llvm.org
Wed Jan 14 13:35:39 PST 2026


https://github.com/nbpatel updated https://github.com/llvm/llvm-project/pull/175857

>From f1666e4c7ba14c689b27a3ad94469037790a5561 Mon Sep 17 00:00:00 2001
From: nbpatel <nishant.b.patel at intel.com>
Date: Tue, 13 Jan 2026 22:22:11 +0000
Subject: [PATCH 1/2] Clean up helpers in XeGPUPropagateLayout

---
 .../XeGPU/Transforms/XeGPUPropagateLayout.cpp | 81 ++++++++-----------
 1 file changed, 34 insertions(+), 47 deletions(-)

diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp
index 3ac23f348f8a9..48906be79a54b 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp
@@ -287,58 +287,47 @@ static LayoutInfo getDefaultSIMTLayoutInfo(mlir::MLIRContext *ctx,
   return LayoutInfo(xegpu::LayoutAttr::get(ctx, {1, subgroupSize}, {1, 1}));
 }
 
-/// Helper to get the default layout for a vector type.
-static LayoutInfo getDefaultSIMTLayoutInfo(VectorType vectorTy,
-                                           const xegpu::uArch::uArch *uArch,
-                                           unsigned packingSize,
-                                           bool isScattered = false) {
+/// Helper to get the default layout for 2D block operations.
+template <typename Ty>
+static LayoutInfo
+getSIMTLayoutInfoFor2DBlockOp(Ty ty, const xegpu::uArch::uArch *uArch,
+                              unsigned packingSize) {
   // Expecting a 1D or 2D vector.
-  assert((vectorTy.getRank() == 1 || vectorTy.getRank() == 2) &&
+  assert((ty.getRank() == 1 || ty.getRank() == 2) &&
          "Expected 1D or 2D vector.");
   // Expecting int or float element type.
-  assert(vectorTy.getElementType().isIntOrFloat() &&
+  assert(ty.getElementType().isIntOrFloat() &&
          "Expected int or float element type.");
   // If the rank is 1, then return default layout for 1D vector.
-  if (vectorTy.getRank() == 1)
-    return getDefaultSIMTLayoutInfo(vectorTy.getContext(), 1, uArch);
+  if (ty.getRank() == 1)
+    return getDefaultSIMTLayoutInfo(ty.getContext(), 1, uArch);
   // Packing factor is determined by the element type bitwidth.
-  unsigned bitwidth = vectorTy.getElementType().getIntOrFloatBitWidth();
+  unsigned bitwidth = ty.getElementType().getIntOrFloatBitWidth();
   int packingFactor = bitwidth < packingSize ? packingSize / bitwidth : 1;
-  if (isScattered) {
-    return LayoutInfo(xegpu::LayoutAttr::get(vectorTy.getContext(),
-                                             {uArch->getSubgroupSize(), 1},
-                                             {1, packingFactor}));
-  }
-  return LayoutInfo(xegpu::LayoutAttr::get(vectorTy.getContext(),
-                                           {1, uArch->getSubgroupSize()},
-                                           {1, packingFactor}));
+  return LayoutInfo(xegpu::LayoutAttr::get(
+      ty.getContext(), {1, uArch->getSubgroupSize()}, {1, packingFactor}));
 }
 
 /// Helper to get the default layout for a vector type.
-static LayoutInfo getDefaultSIMTLayoutInfo(xegpu::TensorDescType tdescTy,
-                                           const xegpu::uArch::uArch *uArch,
-                                           unsigned packingSize,
-                                           bool isScattered = false) {
+static LayoutInfo
+getSIMTLayoutInfoForScatterIOVector(VectorType vectorTy,
+                                    const xegpu::uArch::uArch *uArch,
+                                    unsigned packingSize) {
   // Expecting a 1D or 2D vector.
-  assert((tdescTy.getRank() == 1 || tdescTy.getRank() == 2) &&
-         "Expected 1D or 2D TensorDesc.");
+  assert((vectorTy.getRank() == 1 || vectorTy.getRank() == 2) &&
+         "Expected 1D or 2D vector.");
   // Expecting int or float element type.
-  assert(tdescTy.getElementType().isIntOrFloat() &&
+  assert(vectorTy.getElementType().isIntOrFloat() &&
          "Expected int or float element type.");
   // If the rank is 1, then return default layout for 1D vector.
-  if (tdescTy.getRank() == 1)
-    return getDefaultSIMTLayoutInfo(tdescTy.getContext(), 1, uArch);
+  if (vectorTy.getRank() == 1)
+    return getDefaultSIMTLayoutInfo(vectorTy.getContext(), 1, uArch);
   // Packing factor is determined by the element type bitwidth.
-  unsigned bitwidth = tdescTy.getElementType().getIntOrFloatBitWidth();
-  int subgroupSize = uArch->getSubgroupSize();
+  unsigned bitwidth = vectorTy.getElementType().getIntOrFloatBitWidth();
   int packingFactor = bitwidth < packingSize ? packingSize / bitwidth : 1;
-  if (isScattered) {
-    return LayoutInfo(xegpu::LayoutAttr::get(
-        tdescTy.getContext(), {subgroupSize, 1}, {1, packingFactor}));
-  }
-
-  return LayoutInfo(xegpu::LayoutAttr::get(
-      tdescTy.getContext(), {1, subgroupSize}, {1, packingFactor}));
+  return LayoutInfo(xegpu::LayoutAttr::get(vectorTy.getContext(),
+                                           {uArch->getSubgroupSize(), 1},
+                                           {1, packingFactor}));
 }
 
 /// Helper Function to get the expected layouts for DPAS operands. `lane_data`
@@ -365,7 +354,7 @@ getSIMTLayoutInfoForDPASOperand(VectorType vectorTy, unsigned operandNum,
         xegpu::LayoutAttr::get(vectorTy.getContext(), layout, data));
   }
   // Otherwise, return the default layout for the vector type.
-  return getDefaultSIMTLayoutInfo(vectorTy, uArch, packingSize);
+  return getSIMTLayoutInfoFor2DBlockOp(vectorTy, uArch, packingSize);
 }
 
 //===----------------------------------------------------------------------===//
@@ -587,7 +576,7 @@ void LayoutInfoPropagation::visitPrefetchNdOp(
       prefetchLayout =
           LayoutInfo(xegpu::LayoutAttr::get(tdescTy.getContext(), instData));
     else
-      prefetchLayout = getDefaultSIMTLayoutInfo(
+      prefetchLayout = getSIMTLayoutInfoFor2DBlockOp(
           tdescTy, uArch, uArchInstruction->getPackedFormatBitSize());
 
     prefetch.setLayoutAttr(
@@ -840,9 +829,9 @@ void LayoutInfoPropagation::visitStoreNdOp(
       storeLayout =
           LayoutInfo(xegpu::LayoutAttr::get(dataTy.getContext(), instData));
     else
-      storeLayout =
-          getDefaultSIMTLayoutInfo(store.getValueType(), uArch,
-                                   uArchInstruction->getPackedFormatBitSize());
+      storeLayout = getSIMTLayoutInfoFor2DBlockOp(
+          store.getValueType(), uArch,
+          uArchInstruction->getPackedFormatBitSize());
     store.setLayoutAttr(
         dyn_cast<xegpu::DistributeLayoutAttr>(storeLayout.get()));
   }
@@ -1000,9 +989,8 @@ void LayoutInfoPropagation::visitLoadGatherOp(
       loadLayout =
           LayoutInfo(xegpu::LayoutAttr::get(load.getContext(), instData));
     else
-      loadLayout = getDefaultSIMTLayoutInfo(
-          payloadTy, uArch, uArch->getGeneralPackedFormatBitSize(),
-          /*scattered*/ true);
+      loadLayout = getSIMTLayoutInfoForScatterIOVector(
+          payloadTy, uArch, uArch->getGeneralPackedFormatBitSize());
 
     // Mask operand should have 1D default layout.
     maskLayout = getDefaultSIMTLayoutInfo(load->getContext(), 1, subgroupSize);
@@ -1078,9 +1066,8 @@ void LayoutInfoPropagation::visitStoreScatterOp(
                "Expected the first dimension of 2D tensor descriptor to be "
                "equal to "
                "subgroup size.");
-      payloadLayout = getDefaultSIMTLayoutInfo(
-          payloadTy, uArch, uArch->getGeneralPackedFormatBitSize(),
-          /*scattered=*/true);
+      payloadLayout = getSIMTLayoutInfoForScatterIOVector(
+          payloadTy, uArch, uArch->getGeneralPackedFormatBitSize());
     }
 
     maskLayout =

>From ac5a5d97bd47e9573232c059ae60b6282fe3baf4 Mon Sep 17 00:00:00 2001
From: nbpatel <nishant.b.patel at intel.com>
Date: Wed, 14 Jan 2026 21:34:39 +0000
Subject: [PATCH 2/2] PR feedback

---
 .../XeGPU/Transforms/XeGPUPropagateLayout.cpp | 22 +++++++++----------
 1 file changed, 11 insertions(+), 11 deletions(-)

diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp
index 48906be79a54b..1341fc21e7fd4 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp
@@ -289,9 +289,9 @@ static LayoutInfo getDefaultSIMTLayoutInfo(mlir::MLIRContext *ctx,
 
 /// Helper to get the default layout for 2D block operations.
 template <typename Ty>
-static LayoutInfo
-getSIMTLayoutInfoFor2DBlockOp(Ty ty, const xegpu::uArch::uArch *uArch,
-                              unsigned packingSize) {
+static LayoutInfo getSIMTLayoutInforForBlockIO(Ty ty,
+                                               const xegpu::uArch::uArch *uArch,
+                                               unsigned packingSize) {
   // Expecting a 1D or 2D vector.
   assert((ty.getRank() == 1 || ty.getRank() == 2) &&
          "Expected 1D or 2D vector.");
@@ -310,9 +310,9 @@ getSIMTLayoutInfoFor2DBlockOp(Ty ty, const xegpu::uArch::uArch *uArch,
 
 /// Helper to get the default layout for a vector type.
 static LayoutInfo
-getSIMTLayoutInfoForScatterIOVector(VectorType vectorTy,
-                                    const xegpu::uArch::uArch *uArch,
-                                    unsigned packingSize) {
+getSIMTLayoutInforForScatterIO(VectorType vectorTy,
+                               const xegpu::uArch::uArch *uArch,
+                               unsigned packingSize) {
   // Expecting a 1D or 2D vector.
   assert((vectorTy.getRank() == 1 || vectorTy.getRank() == 2) &&
          "Expected 1D or 2D vector.");
@@ -354,7 +354,7 @@ getSIMTLayoutInfoForDPASOperand(VectorType vectorTy, unsigned operandNum,
         xegpu::LayoutAttr::get(vectorTy.getContext(), layout, data));
   }
   // Otherwise, return the default layout for the vector type.
-  return getSIMTLayoutInfoFor2DBlockOp(vectorTy, uArch, packingSize);
+  return getSIMTLayoutInforForBlockIO(vectorTy, uArch, packingSize);
 }
 
 //===----------------------------------------------------------------------===//
@@ -576,7 +576,7 @@ void LayoutInfoPropagation::visitPrefetchNdOp(
       prefetchLayout =
           LayoutInfo(xegpu::LayoutAttr::get(tdescTy.getContext(), instData));
     else
-      prefetchLayout = getSIMTLayoutInfoFor2DBlockOp(
+      prefetchLayout = getSIMTLayoutInforForBlockIO(
           tdescTy, uArch, uArchInstruction->getPackedFormatBitSize());
 
     prefetch.setLayoutAttr(
@@ -829,7 +829,7 @@ void LayoutInfoPropagation::visitStoreNdOp(
       storeLayout =
           LayoutInfo(xegpu::LayoutAttr::get(dataTy.getContext(), instData));
     else
-      storeLayout = getSIMTLayoutInfoFor2DBlockOp(
+      storeLayout = getSIMTLayoutInforForBlockIO(
           store.getValueType(), uArch,
           uArchInstruction->getPackedFormatBitSize());
     store.setLayoutAttr(
@@ -989,7 +989,7 @@ void LayoutInfoPropagation::visitLoadGatherOp(
       loadLayout =
           LayoutInfo(xegpu::LayoutAttr::get(load.getContext(), instData));
     else
-      loadLayout = getSIMTLayoutInfoForScatterIOVector(
+      loadLayout = getSIMTLayoutInforForScatterIO(
           payloadTy, uArch, uArch->getGeneralPackedFormatBitSize());
 
     // Mask operand should have 1D default layout.
@@ -1066,7 +1066,7 @@ void LayoutInfoPropagation::visitStoreScatterOp(
                "Expected the first dimension of 2D tensor descriptor to be "
                "equal to "
                "subgroup size.");
-      payloadLayout = getSIMTLayoutInfoForScatterIOVector(
+      payloadLayout = getSIMTLayoutInforForScatterIO(
           payloadTy, uArch, uArch->getGeneralPackedFormatBitSize());
     }
 



More information about the Mlir-commits mailing list