[Mlir-commits] [mlir] 3150b73 - [MLIR][XeGPU] Clean up helpers in XeGPUPropagateLayout (#175857)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Thu Jan 15 08:13:45 PST 2026


Author: Nishant Patel
Date: 2026-01-15T08:13:38-08:00
New Revision: 3150b73dec83d76321727a3c01a7907485bec6ce

URL: https://github.com/llvm/llvm-project/commit/3150b73dec83d76321727a3c01a7907485bec6ce
DIFF: https://github.com/llvm/llvm-project/commit/3150b73dec83d76321727a3c01a7907485bec6ce.diff

LOG: [MLIR][XeGPU] Clean up helpers in XeGPUPropagateLayout (#175857)

In XeGPUPropagateLayout.cpp, the helper getDefaultSIMTLayoutInfo is
implemented via multiple overloads that differ significantly in
semantics, not just in parameter types.
Reusing the same function name for these semantically different
behaviors makes call sites harder to read and reason about, and
increases the maintenance burden. This PR replaces the VectorType and
TensorDescType overloads (and their isScattered flag) with two
explicitly named helpers, getSIMTLayoutInforForBlockIO and
getSIMTLayoutInforForScatterIO, which improves the readability and
maintainability of the layout propagation logic.

Added: 
    

Modified: 
    mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp
index 3ac23f348f8a9..1341fc21e7fd4 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp
@@ -287,58 +287,47 @@ static LayoutInfo getDefaultSIMTLayoutInfo(mlir::MLIRContext *ctx,
   return LayoutInfo(xegpu::LayoutAttr::get(ctx, {1, subgroupSize}, {1, 1}));
 }
 
-/// Helper to get the default layout for a vector type.
-static LayoutInfo getDefaultSIMTLayoutInfo(VectorType vectorTy,
-                                           const xegpu::uArch::uArch *uArch,
-                                           unsigned packingSize,
-                                           bool isScattered = false) {
+/// Helper to get the default layout for 2D block operations.
+template <typename Ty>
+static LayoutInfo getSIMTLayoutInforForBlockIO(Ty ty,
+                                               const xegpu::uArch::uArch *uArch,
+                                               unsigned packingSize) {
   // Expecting a 1D or 2D vector.
-  assert((vectorTy.getRank() == 1 || vectorTy.getRank() == 2) &&
+  assert((ty.getRank() == 1 || ty.getRank() == 2) &&
          "Expected 1D or 2D vector.");
   // Expecting int or float element type.
-  assert(vectorTy.getElementType().isIntOrFloat() &&
+  assert(ty.getElementType().isIntOrFloat() &&
          "Expected int or float element type.");
   // If the rank is 1, then return default layout for 1D vector.
-  if (vectorTy.getRank() == 1)
-    return getDefaultSIMTLayoutInfo(vectorTy.getContext(), 1, uArch);
+  if (ty.getRank() == 1)
+    return getDefaultSIMTLayoutInfo(ty.getContext(), 1, uArch);
   // Packing factor is determined by the element type bitwidth.
-  unsigned bitwidth = vectorTy.getElementType().getIntOrFloatBitWidth();
+  unsigned bitwidth = ty.getElementType().getIntOrFloatBitWidth();
   int packingFactor = bitwidth < packingSize ? packingSize / bitwidth : 1;
-  if (isScattered) {
-    return LayoutInfo(xegpu::LayoutAttr::get(vectorTy.getContext(),
-                                             {uArch->getSubgroupSize(), 1},
-                                             {1, packingFactor}));
-  }
-  return LayoutInfo(xegpu::LayoutAttr::get(vectorTy.getContext(),
-                                           {1, uArch->getSubgroupSize()},
-                                           {1, packingFactor}));
+  return LayoutInfo(xegpu::LayoutAttr::get(
+      ty.getContext(), {1, uArch->getSubgroupSize()}, {1, packingFactor}));
 }
 
 /// Helper to get the default layout for a vector type.
-static LayoutInfo getDefaultSIMTLayoutInfo(xegpu::TensorDescType tdescTy,
-                                           const xegpu::uArch::uArch *uArch,
-                                           unsigned packingSize,
-                                           bool isScattered = false) {
+static LayoutInfo
+getSIMTLayoutInforForScatterIO(VectorType vectorTy,
+                               const xegpu::uArch::uArch *uArch,
+                               unsigned packingSize) {
   // Expecting a 1D or 2D vector.
-  assert((tdescTy.getRank() == 1 || tdescTy.getRank() == 2) &&
-         "Expected 1D or 2D TensorDesc.");
+  assert((vectorTy.getRank() == 1 || vectorTy.getRank() == 2) &&
+         "Expected 1D or 2D vector.");
   // Expecting int or float element type.
-  assert(tdescTy.getElementType().isIntOrFloat() &&
+  assert(vectorTy.getElementType().isIntOrFloat() &&
          "Expected int or float element type.");
   // If the rank is 1, then return default layout for 1D vector.
-  if (tdescTy.getRank() == 1)
-    return getDefaultSIMTLayoutInfo(tdescTy.getContext(), 1, uArch);
+  if (vectorTy.getRank() == 1)
+    return getDefaultSIMTLayoutInfo(vectorTy.getContext(), 1, uArch);
   // Packing factor is determined by the element type bitwidth.
-  unsigned bitwidth = tdescTy.getElementType().getIntOrFloatBitWidth();
-  int subgroupSize = uArch->getSubgroupSize();
+  unsigned bitwidth = vectorTy.getElementType().getIntOrFloatBitWidth();
   int packingFactor = bitwidth < packingSize ? packingSize / bitwidth : 1;
-  if (isScattered) {
-    return LayoutInfo(xegpu::LayoutAttr::get(
-        tdescTy.getContext(), {subgroupSize, 1}, {1, packingFactor}));
-  }
-
-  return LayoutInfo(xegpu::LayoutAttr::get(
-      tdescTy.getContext(), {1, subgroupSize}, {1, packingFactor}));
+  return LayoutInfo(xegpu::LayoutAttr::get(vectorTy.getContext(),
+                                           {uArch->getSubgroupSize(), 1},
+                                           {1, packingFactor}));
 }
 
 /// Helper Function to get the expected layouts for DPAS operands. `lane_data`
@@ -365,7 +354,7 @@ getSIMTLayoutInfoForDPASOperand(VectorType vectorTy, unsigned operandNum,
         xegpu::LayoutAttr::get(vectorTy.getContext(), layout, data));
   }
   // Otherwise, return the default layout for the vector type.
-  return getDefaultSIMTLayoutInfo(vectorTy, uArch, packingSize);
+  return getSIMTLayoutInforForBlockIO(vectorTy, uArch, packingSize);
 }
 
 //===----------------------------------------------------------------------===//
@@ -587,7 +576,7 @@ void LayoutInfoPropagation::visitPrefetchNdOp(
       prefetchLayout =
           LayoutInfo(xegpu::LayoutAttr::get(tdescTy.getContext(), instData));
     else
-      prefetchLayout = getDefaultSIMTLayoutInfo(
+      prefetchLayout = getSIMTLayoutInforForBlockIO(
           tdescTy, uArch, uArchInstruction->getPackedFormatBitSize());
 
     prefetch.setLayoutAttr(
@@ -840,9 +829,9 @@ void LayoutInfoPropagation::visitStoreNdOp(
       storeLayout =
           LayoutInfo(xegpu::LayoutAttr::get(dataTy.getContext(), instData));
     else
-      storeLayout =
-          getDefaultSIMTLayoutInfo(store.getValueType(), uArch,
-                                   uArchInstruction->getPackedFormatBitSize());
+      storeLayout = getSIMTLayoutInforForBlockIO(
+          store.getValueType(), uArch,
+          uArchInstruction->getPackedFormatBitSize());
     store.setLayoutAttr(
         dyn_cast<xegpu::DistributeLayoutAttr>(storeLayout.get()));
   }
@@ -1000,9 +989,8 @@ void LayoutInfoPropagation::visitLoadGatherOp(
       loadLayout =
           LayoutInfo(xegpu::LayoutAttr::get(load.getContext(), instData));
     else
-      loadLayout = getDefaultSIMTLayoutInfo(
-          payloadTy, uArch, uArch->getGeneralPackedFormatBitSize(),
-          /*scattered*/ true);
+      loadLayout = getSIMTLayoutInforForScatterIO(
+          payloadTy, uArch, uArch->getGeneralPackedFormatBitSize());
 
     // Mask operand should have 1D default layout.
     maskLayout = getDefaultSIMTLayoutInfo(load->getContext(), 1, subgroupSize);
@@ -1078,9 +1066,8 @@ void LayoutInfoPropagation::visitStoreScatterOp(
                "Expected the first dimension of 2D tensor descriptor to be "
                "equal to "
                "subgroup size.");
-      payloadLayout = getDefaultSIMTLayoutInfo(
-          payloadTy, uArch, uArch->getGeneralPackedFormatBitSize(),
-          /*scattered=*/true);
+      payloadLayout = getSIMTLayoutInforForScatterIO(
+          payloadTy, uArch, uArch->getGeneralPackedFormatBitSize());
     }
 
     maskLayout =


        


More information about the Mlir-commits mailing list