[Mlir-commits] [mlir] [mlir][xegpu] Add utilities for `xegpu::SliceAttr` (PR #157970)

Charitha Saumya llvmlistbot at llvm.org
Wed Sep 10 16:55:13 PDT 2025


https://github.com/charithaintc updated https://github.com/llvm/llvm-project/pull/157970

>From be1c00cc486c3b2fe69c13b5477df5be8bd1c70e Mon Sep 17 00:00:00 2001
From: Charitha Saumya <charitha.saumya.gusthinna.waduge at intel.com>
Date: Wed, 10 Sep 2025 17:51:57 +0000
Subject: [PATCH 1/3] add isTransposeOf to DistributeLayoutAttr

---
 .../mlir/Dialect/XeGPU/IR/XeGPUAttrs.td       | 46 ++++++++++++++++++-
 1 file changed, 45 insertions(+), 1 deletion(-)

diff --git a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td
index cfe3e800484ce..24756318e4339 100644
--- a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td
+++ b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td
@@ -231,7 +231,65 @@ def DistributeLayoutAttr: AttrInterface<"DistributeLayoutAttr"> {
                       multiple blocks according to round-robin distribution rules.}],
                     "FailureOr<SmallVector<SmallVector<Value>>>",
                     "getOffsets",
-                    (ins "OpBuilder &": $builder, "Location":$loc, "Value":$linearId, "ArrayRef<int64_t>":$shape)>
+                    (ins "OpBuilder &": $builder, "Location":$loc, "Value":$linearId, "ArrayRef<int64_t>":$shape)>,
+    InterfaceMethod</*desc=*/[{Check if this layout can be achieved by applying a transpose
+                     to some other layout according to the given permutation of (0...n-1).}],
+                    /*retTy=*/"bool",
+                    /*methodName=*/"isTransposeOf",
+                    /*args=*/(ins "const xegpu::DistributeLayoutAttr&": $other, "ArrayRef<int64_t>": $perm),
+                    /*methodBody=*/[{
+                      if (!other)
+                        return false;
+                      if ($_self.getRank() != other.getRank() || perm.size() != static_cast<size_t>($_self.getRank()))
+                        return false;
+                      // check if the permutation is valid
+                      int64_t rank = $_self.getRank();
+                      SmallVector<bool, 8> seen(rank, false);
+                      for (const auto &ta : llvm::enumerate(perm)) {
+                        if (ta.value() < 0 || ta.value() >= rank)
+                          return false;
+                        if (seen[ta.value()])
+                          return false;
+                        seen[ta.value()] = true;
+                      }
+                      // Compares one layout field of `other` (src) against the same
+                      // field of this layout (dst): src[i] must equal dst[perm[i]].
+                      // NOTE(review): vector.transpose uses dst[i] == src[perm[i]];
+                      // the two agree only for self-inverse permutations (e.g. any
+                      // rank-2 permutation). Confirm the intended direction.
+                      auto checkTranspose = [](ArrayRef<int64_t> dst, ArrayRef<int64_t> src, ArrayRef<int64_t> perm) {
+                        // Layout fields are optional; two unset fields (assumed to be
+                        // reported as empty arrays) trivially match.
+                        if (dst.empty() && src.empty())
+                          return true;
+                        // A field set on only one side, or sized differently from
+                        // `perm`, cannot match. This also keeps the indexing below
+                        // in bounds.
+                        if (dst.size() != perm.size() || src.size() != perm.size())
+                          return false;
+                        for (const auto &ta : llvm::enumerate(perm)) {
+                          if (src[ta.index()] != dst[ta.value()])
+                            return false;
+                        }
+                        return true;
+                      };
+                      // check sgLayout
+                      if (!checkTranspose($_self.getSgLayoutAsInt(), other.getSgLayoutAsInt(), perm))
+                        return false;
+                      // check sgData
+                      if (!checkTranspose($_self.getSgDataAsInt(), other.getSgDataAsInt(), perm))
+                        return false;
+                      // check instData
+                      if (!checkTranspose($_self.getInstDataAsInt(), other.getInstDataAsInt(), perm))
+                        return false;
+                      // check laneLayout
+                      if (!checkTranspose($_self.getLaneLayoutAsInt(), other.getLaneLayoutAsInt(), perm))
+                        return false;
+                      // check laneData
+                      if (!checkTranspose($_self.getLaneDataAsInt(), other.getLaneDataAsInt(), perm))
+                        return false;
+                      return true;
+                    }]>
   ];
 }
 

>From 916c75f12298f76b2f8c6e2b5645125e75d34a73 Mon Sep 17 00:00:00 2001
From: Charitha Saumya <charitha.saumya.gusthinna.waduge at intel.com>
Date: Wed, 10 Sep 2025 23:15:18 +0000
Subject: [PATCH 2/3] add isSliceOf utility for slice attributes

---
 .../mlir/Dialect/XeGPU/IR/XeGPUAttrs.td       | 12 ++++++++++-
 mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp    | 21 +++++++++++++++++++
 2 files changed, 32 insertions(+), 1 deletion(-)

diff --git a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td
index 24756318e4339..aa3e3c5cddc05 100644
--- a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td
+++ b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td
@@ -275,7 +275,11 @@ def DistributeLayoutAttr: AttrInterface<"DistributeLayoutAttr"> {
                       if (!checkTranspose($_self.getLaneDataAsInt(), other.getLaneDataAsInt(), perm))
                         return false;
                       return true;
-                    }]>
+                    }]>,
+    InterfaceMethod</*desc=*/[{Check if this layout is a slice of `other`, i.e. it can be derived from `other` by slicing away dimensions.}],
+                    /*retTy=*/"bool",
+                    /*methodName=*/"isSliceOf",
+                    /*args=*/(ins "const xegpu::DistributeLayoutAttr&": $other)>
   ];
 }
 
@@ -477,6 +481,9 @@ def XeGPU_LayoutAttr : XeGPUAttr<"Layout", "layout", [DistributeLayoutAttr]> {
     FailureOr<SmallVector<SmallVector<Value>>>
     getOffsets(OpBuilder &builder, Location loc, Value linearId, ArrayRef<int64_t> shape);
 
+    /// A LayoutAttr is never a slice of another layout; always returns false.
+    bool isSliceOf(const xegpu::DistributeLayoutAttr &other) { return false; }
+
   }];
 
   let assemblyFormat = "`<` struct(params) `>`";
@@ -638,6 +645,9 @@ def XeGPU_SliceAttr : XeGPUAttr<"Slice", "slice", [DistributeLayoutAttr]> {
     FailureOr<SmallVector<SmallVector<Value>>>
     getOffsets(OpBuilder &builder, Location loc, Value linearId, ArrayRef<int64_t> shape);
 
+    /// Check if this slice can be derived from `other` by slicing dimensions.
+    bool isSliceOf(const xegpu::DistributeLayoutAttr &other);
+
   }];
 
   let assemblyFormat = "`<` qualified($parent) `,` `dims` `=` $dims `>`";
diff --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
index 7f3be7f91c56b..a3783d5e05df6 100644
--- a/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
+++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
@@ -14,6 +14,7 @@
 #include "mlir/Dialect/XeGPU/IR/XeGPUTargetInfo.h"
 #include "mlir/IR/Builders.h"
 #include "mlir/IR/DialectImplementation.h"
+#include "llvm/ADT/STLExtras.h"
 #include "llvm/ADT/TypeSwitch.h"
 #include "llvm/Support/Debug.h"
 
@@ -409,6 +410,35 @@ SliceAttr::getOffsets(OpBuilder &builder, Location loc, Value linearId,
                                   shape);
 }
 
+/// Returns true if this slice can be obtained by slicing `other`. This slice
+/// is flattened first; a LayoutAttr `other` must equal its flattened parent,
+/// while a SliceAttr `other` is flattened too and must share that parent and
+/// slice only dims that this slice also slices. Any other kind of `other`,
+/// including a null attribute, yields false.
+bool SliceAttr::isSliceOf(const xegpu::DistributeLayoutAttr &other) {
+  // Reject null and unexpected attribute kinds up front: dyn_cast asserts on a
+  // null value, and the SliceAttr path below relies on `other` being a slice.
+  if (!isa_and_nonnull<xegpu::LayoutAttr, xegpu::SliceAttr>(other))
+    return false;
+  auto flattenedThis = flatten();
+  // If other is a LayoutAttr, just compare directly with parent of
+  // flattenedThis.
+  if (auto otherLayout = dyn_cast<xegpu::LayoutAttr>(other))
+    return flattenedThis.getParent() == otherLayout;
+  // If other is a SliceAttr, flatten it first before comparing.
+  auto otherFlattened = dyn_cast<xegpu::SliceAttr>(other).flatten();
+  // Both must have common parent LayoutAttr.
+  if (flattenedThis.getParent() != otherFlattened.getParent())
+    return false;
+  // otherFlattened's sliced dims must be a subset of flattenedThis's sliced
+  // dims.
+  llvm::SmallDenseSet<int64_t> thisDims(
+      flattenedThis.getDims().asArrayRef().begin(),
+      flattenedThis.getDims().asArrayRef().end());
+  return llvm::all_of(otherFlattened.getDims().asArrayRef(),
+                      [&](int64_t dim) { return thisDims.contains(dim); });
+}
+
 //===----------------------------------------------------------------------===//
 // XeGPU_RangeAttr
 //===----------------------------------------------------------------------===//

>From 77e8a9477dbd76bf95e5d142a0a6e6a4596ab3d2 Mon Sep 17 00:00:00 2001
From: Charitha Saumya <charitha.saumya.gusthinna.waduge at intel.com>
Date: Wed, 10 Sep 2025 23:54:57 +0000
Subject: [PATCH 3/3] rename otherFlattened to flattenedOther

---
 mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
index a3783d5e05df6..cc133b110c95a 100644
--- a/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
+++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
@@ -417,16 +417,16 @@ bool SliceAttr::isSliceOf(const xegpu::DistributeLayoutAttr &other) {
   if (auto otherLayout = dyn_cast<xegpu::LayoutAttr>(other))
     return flattenedThis.getParent() == otherLayout;
   // If other is a SliceAttr, flatten it first before comparing.
-  auto otherFlattened = dyn_cast<xegpu::SliceAttr>(other).flatten();
+  auto flattenedOther = dyn_cast<xegpu::SliceAttr>(other).flatten();
   // Both must have common parent LayoutAttr.
-  if (flattenedThis.getParent() != otherFlattened.getParent())
+  if (flattenedThis.getParent() != flattenedOther.getParent())
     return false;
-  // otherFlattened's sliced dims must be a subset of flattenedThis's sliced
+  // flattenedOther's sliced dims must be a subset of flattenedThis's sliced
   // dims.
   llvm::SmallDenseSet<int64_t> thisDims(
       flattenedThis.getDims().asArrayRef().begin(),
       flattenedThis.getDims().asArrayRef().end());
-  return llvm::all_of(otherFlattened.getDims().asArrayRef(),
+  return llvm::all_of(flattenedOther.getDims().asArrayRef(),
                       [&](int64_t dim) { return thisDims.contains(dim); });
 }
 



More information about the Mlir-commits mailing list