[Mlir-commits] [mlir] 88ff0f9 - [MLIR][XeGPU] Distribute create_nd_desc op without offset from Wg to Sg (#152351)

llvmlistbot at llvm.org
Mon Aug 11 21:58:27 PDT 2025


Author: Nishant Patel
Date: 2025-08-11T21:58:24-07:00
New Revision: 88ff0f955c8798ae1f4816776119e6796b507e34

URL: https://github.com/llvm/llvm-project/commit/88ff0f955c8798ae1f4816776119e6796b507e34
DIFF: https://github.com/llvm/llvm-project/commit/88ff0f955c8798ae1f4816776119e6796b507e34.diff

LOG: [MLIR][XeGPU] Distribute create_nd_desc op without offset from Wg to Sg (#152351)

This PR adds a pattern to distribute the create_nd_desc op without offsets
from workgroup (Wg) IR to subgroup (Sg) IR.
The round-robin distribution logic (which involves offset calculation) now
happens in the load/store/prefetch nd ops instead of in create_nd.
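
For illustration, a minimal before/after sketch of the new no-offset path,
adapted from the tests added in this PR (a 256x128 workgroup descriptor with
sg_layout = [8, 4] and sg_data = [32, 32]); the names come from the test, not
a prescribed usage:

    // Workgroup (Wg) IR: create_nd_tdesc carries no offsets.
    %tdesc = xegpu.create_nd_tdesc %src : memref<256x128xf32>
      -> !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4],
           sg_data = [32, 32], lane_layout = [1, 16], lane_data = [1, 1]>>

    // Subgroup (Sg) IR after --xegpu-wg-to-sg-distribute: each subgroup gets
    // a 32x32 descriptor; per-subgroup offsets are applied later by the
    // load/store/prefetch nd patterns.
    %tdesc = xegpu.create_nd_tdesc %src : memref<256x128xf32>
      -> !xegpu.tensor_desc<32x32xf32, #xegpu.layout<lane_layout = [1, 16],
           lane_data = [1, 1]>>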

Added: 
    mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops-rr.mlir
    mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir

Modified: 
    mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td
    mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
    mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp

Removed: 
    


################################################################################
diff --git a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td
index 1a6a34c8d775a..480b43e740736 100644
--- a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td
+++ b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td
@@ -142,11 +142,7 @@ def XeGPU_CreateNdDescOp: XeGPU_Op<"create_nd_tdesc", [Pure, ViewLikeOpInterface
   let builders = [
     OpBuilder<(ins "Type": $tdesc, "TypedValue<MemRefType>": $source)>,
 
-    OpBuilder<(ins "Type": $tdesc, "TypedValue<MemRefType> ": $source,
-                   "llvm::ArrayRef<OpFoldResult>": $shape,
-                   "llvm::ArrayRef<OpFoldResult>": $strides)>,
-
-    OpBuilder<(ins "Type": $tdesc, "TypedValue<IntegerType> ": $source,
+    OpBuilder<(ins "Type": $tdesc, "Value ": $source,
                    "llvm::ArrayRef<OpFoldResult>": $shape,
                    "llvm::ArrayRef<OpFoldResult>": $strides)>,
 

diff --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
index 2cd086feb5deb..4dd937eb5114d 100644
--- a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
+++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
@@ -156,41 +156,18 @@ void CreateNdDescOp::build(OpBuilder &builder, OperationState &state,
 }
 
 void CreateNdDescOp::build(OpBuilder &builder, OperationState &state,
-                           Type tdesc, TypedValue<MemRefType> source,
+                           Type tdesc, Value source,
                            llvm::ArrayRef<OpFoldResult> shape,
                            llvm::ArrayRef<OpFoldResult> strides) {
-  assert(shape.size() && strides.size() && shape.size() == strides.size() &&
-         "Shape and strides must be present and of equal size for ui64 "
-         "initialization.");
+  Type srcTy = source.getType();
+  assert((isa<IntegerType, MemRefType>(srcTy)) &&
+         "Source has to be either int or memref.");
 
-  llvm::SmallVector<int64_t> staticShape;
-  llvm::SmallVector<int64_t> staticStrides;
   llvm::SmallVector<Value> dynamicShape;
   llvm::SmallVector<Value> dynamicStrides;
 
-  dispatchIndexOpFoldResults(shape, dynamicShape, staticShape);
-  dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides);
-
-  auto staticShapeAttr = builder.getDenseI64ArrayAttr(staticShape);
-  auto staticStridesAttr = builder.getDenseI64ArrayAttr(staticStrides);
-
-  build(builder, state, tdesc, source, ValueRange({}), dynamicShape,
-        dynamicStrides, builder.getDenseI64ArrayAttr({}), staticShapeAttr,
-        staticStridesAttr);
-}
-
-void CreateNdDescOp::build(OpBuilder &builder, OperationState &state,
-                           Type tdesc, TypedValue<IntegerType> source,
-                           llvm::ArrayRef<OpFoldResult> shape,
-                           llvm::ArrayRef<OpFoldResult> strides) {
-  assert(shape.size() && strides.size() && shape.size() == strides.size() &&
-         "Shape and strides must be present and of equal size for ui64 "
-         "initialization.");
-
   llvm::SmallVector<int64_t> staticShape;
   llvm::SmallVector<int64_t> staticStrides;
-  llvm::SmallVector<Value> dynamicShape;
-  llvm::SmallVector<Value> dynamicStrides;
 
   dispatchIndexOpFoldResults(shape, dynamicShape, staticShape);
   dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides);
@@ -198,6 +175,18 @@ void CreateNdDescOp::build(OpBuilder &builder, OperationState &state,
   auto staticShapeAttr = builder.getDenseI64ArrayAttr(staticShape);
   auto staticStridesAttr = builder.getDenseI64ArrayAttr(staticStrides);
 
+  if (auto memrefTy = dyn_cast<MemRefType>(srcTy)) {
+    auto memrefShape = memrefTy.getShape();
+    auto [memrefStrides, _] = memrefTy.getStridesAndOffset();
+
+    // If the shape and strides match those of the memref type, we don't need
+    // attributes for them; this keeps the printed IR clean.
+    if (staticShape == memrefShape && staticStrides == memrefStrides) {
+      staticShapeAttr = DenseI64ArrayAttr();
+      staticStridesAttr = DenseI64ArrayAttr();
+    }
+  }
+
   build(builder, state, tdesc, source, ValueRange({}), dynamicShape,
         dynamicStrides, builder.getDenseI64ArrayAttr({}), staticShapeAttr,
         staticStridesAttr);
@@ -357,13 +346,10 @@ ParseResult parseOptionalDynamicIndexList(
 void printOptionalDynamicIndexList(OpAsmPrinter &printer, Operation *op,
                                    OperandRange values,
                                    DenseI64ArrayAttr integers) {
-
-  if (!integers)
+  if (!integers || integers.empty())
     return;
-
-  return printDynamicIndexList(printer, op, values, integers,
-                               /*scalableFlags=*/{}, {},
-                               AsmParser::Delimiter::Square);
+  printDynamicIndexList(printer, op, values, integers,
+                        /*scalableFlags=*/{}, {}, AsmParser::Delimiter::Square);
 }
 //===----------------------------------------------------------------------===//
 // XeGPU_PrefetchNdOp

diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
index 4a5525c8abb30..97c97ac3fd680 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
@@ -128,6 +128,12 @@ struct WgToSgCreateNdOp : public OpConversionPattern<xegpu::CreateNdDescOp> {
   LogicalResult
   matchAndRewrite(xegpu::CreateNdDescOp op, OneToNOpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
+
+    // Ensure that the op has explicit offsets specified (either dynamic or
+    // constant).
+    if (op.getMixedOffsets().empty())
+      return failure();
+
     Location loc = op.getLoc();
     MLIRContext *ctx = op.getContext();
     xegpu::TensorDescType tdescTy = op.getType();
@@ -199,6 +205,49 @@ struct WgToSgCreateNdOp : public OpConversionPattern<xegpu::CreateNdDescOp> {
   }
 };
 
+// This pattern transforms a CreateNdDescOp without offsets, creating subgroup
+// descriptors from the workgroup descriptor.
+struct WgToSgCreateNdOpNoOffset
+    : public OpConversionPattern<xegpu::CreateNdDescOp> {
+  using OpConversionPattern<xegpu::CreateNdDescOp>::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(xegpu::CreateNdDescOp op, OneToNOpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+
+    // Check no offsets are specified.
+    if (!op.getMixedOffsets().empty())
+      return failure();
+
+    Location loc = op.getLoc();
+    MLIRContext *ctx = op.getContext();
+    xegpu::TensorDescType tdescTy = op.getType();
+    auto layout = dyn_cast<xegpu::LayoutAttr>(tdescTy.getLayout());
+    if (!layout || !layout.isWgLayout())
+      return failure();
+
+    Type elemTy = tdescTy.getElementType();
+    ArrayRef<int64_t> wgShape = tdescTy.getShape();
+
+    SmallVector<int64_t> sgShape;
+    int count;
+    std::tie(sgShape, count) = getSgShapeAndCount(wgShape, layout);
+    xegpu::TensorDescType newTdescTy =
+        xegpu::TensorDescType::get(ctx, sgShape, elemTy, tdescTy.getEncoding(),
+                                   layout.dropSgLayoutAndData());
+
+    SmallVector<Value> newCreateNdOps(count);
+    std::generate(newCreateNdOps.begin(), newCreateNdOps.end(), [&]() {
+      return xegpu::CreateNdDescOp::create(rewriter, loc, newTdescTy,
+                                           op.getSource(), op.getMixedSizes(),
+                                           op.getMixedStrides());
+    });
+
+    rewriter.replaceOpWithMultiple(op, {newCreateNdOps});
+    return success();
+  }
+};
+
 /// This pattern transforms the LoadNdOp to load subgroup data.
 struct WgToSgLoadNdOp : public OpConversionPattern<xegpu::LoadNdOp> {
   using OpConversionPattern<xegpu::LoadNdOp>::OpConversionPattern;
@@ -603,11 +652,12 @@ struct UnrealizedConversionCastOpPattern
 namespace mlir {
 namespace xegpu {
 void populateXeGPUWgToSgDistributePatterns(RewritePatternSet &patterns) {
-  patterns.add<WgToSgCreateNdOp, WgToSgLoadNdOp, WgToSgStoreNdOp,
-               WgToSgUpdateNdOffsetOp, WgToSgDpasOp, WgToSgPrefetchNdOp,
-               UnrealizedConversionCastOpPattern, WgToSgElementwiseOp,
-               WgToSgVectorBroadcastOp, WgToSgConvertLayoutOp>(
-      patterns.getContext());
+  patterns
+      .add<WgToSgCreateNdOp, WgToSgCreateNdOpNoOffset, WgToSgLoadNdOp,
+           WgToSgStoreNdOp, WgToSgUpdateNdOffsetOp, WgToSgDpasOp,
+           WgToSgPrefetchNdOp, UnrealizedConversionCastOpPattern,
+           WgToSgElementwiseOp, WgToSgVectorBroadcastOp, WgToSgConvertLayoutOp>(
+          patterns.getContext());
 }
 } // namespace xegpu
 } // namespace mlir

diff --git a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops-rr.mlir b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops-rr.mlir
new file mode 100644
index 0000000000000..b6f44b5bc0b68
--- /dev/null
+++ b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops-rr.mlir
@@ -0,0 +1,14 @@
+// RUN: mlir-opt --xegpu-wg-to-sg-distribute -split-input-file %s | FileCheck %s
+
+gpu.module @test_distribution {
+  // CHECK-LABEL: create_nd_tdesc_no_offset
+  // CHECK-SAME: %[[ARG_0:.*]]: memref<256x128xf32>
+  gpu.func @create_nd_tdesc_no_offset(%src: memref<256x128xf32>) {
+      // CHECK-COUNT-4: xegpu.create_nd_tdesc %[[ARG_0]] : memref<256x128xf32>
+      // CHECK-SAME: -> !xegpu.tensor_desc<16x16xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
+      // CHECK-NOT: xegpu.create_nd_tdesc
+      %tdesc = xegpu.create_nd_tdesc %src: memref<256x128xf32>
+        -> !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [16, 16], lane_layout = [1, 16], lane_data = [1, 1]>>
+      gpu.return
+    }
+}

diff --git a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir
new file mode 100644
index 0000000000000..025d48e22307e
--- /dev/null
+++ b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir
@@ -0,0 +1,24 @@
+// RUN: mlir-opt --xegpu-wg-to-sg-distribute -split-input-file %s | FileCheck %s
+
+gpu.module @test_distribution {
+  // CHECK-LABEL: create_nd_tdesc_no_offset
+  // CHECK-SAME: %[[ARG_0:.*]]: memref<256x128xf32>
+  gpu.func @create_nd_tdesc_no_offset(%src: memref<256x128xf32>) {
+    // CHECK: xegpu.create_nd_tdesc %[[ARG_0]] : memref<256x128xf32>
+    // CHECK-SAME: -> !xegpu.tensor_desc<32x32xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
+    %tdesc = xegpu.create_nd_tdesc %src : memref<256x128xf32>
+        -> !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32], lane_layout = [1, 16], lane_data = [1, 1]>>
+      gpu.return
+  }
+
+  // CHECK-LABEL: create_nd_tdesc_with_ptr
+  // CHECK-SAME: %[[ARG_0:.*]]: ui64
+  gpu.func @create_nd_tdesc_with_ptr(%src: ui64, %w : index, %h : index, %x : index, %y : index) {
+    // CHECK: xegpu.create_nd_tdesc %[[ARG_0]], shape : [{{.*}}, {{.*}}], strides : [{{.*}}, {{.*}}] : ui64
+    // CHECK-SAME: -> !xegpu.tensor_desc<32x32xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
+    %c1 = arith.constant 1 : index
+    %tdesc = xegpu.create_nd_tdesc %src, shape:[%h, %w], strides: [%w, %c1] : ui64
+        -> !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32], lane_layout = [1, 16], lane_data = [1, 1]>>
+      gpu.return
+  }
+}

More information about the Mlir-commits mailing list