[Mlir-commits] [mlir] [MLIR][XeGPU] Support pointer/dynamic-memref sources in array-length optimization (PR #195872)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Tue May 5 09:18:02 PDT 2026
llvmorg-github-actions[bot] wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir-gpu
Author: Md Abdullah Shahneous Bari (mshahneo)
<details>
<summary>Changes</summary>
Extend `OptimizeCreateNdDescOp` to handle the two remaining `create_nd_tdesc` source forms — `i64` pointer and dynamic-shape memref — by forwarding the existing shape/strides operands through the general builder. The memory region is unchanged by the rewrite; only the `tensor_desc` view is narrowed along the fastest-changing dimension (FCD) and tagged with `array_length`.
---
Full diff: https://github.com/llvm/llvm-project/pull/195872.diff
2 Files Affected:
- (modified) mlir/lib/Dialect/XeGPU/Transforms/XeGPUArrayLengthOptimization.cpp (+14-11)
- (modified) mlir/test/Dialect/XeGPU/array-len-op-unit.mlir (+50)
``````````diff
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUArrayLengthOptimization.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUArrayLengthOptimization.cpp
index 97faeaf230695..c3b762b6494b5 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUArrayLengthOptimization.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUArrayLengthOptimization.cpp
@@ -92,10 +92,12 @@ static bool hasTransposeLaneLayout(xegpu::TensorDescType tdescType) {
}
/// Rewrite `xegpu.create_nd_tdesc` to fold an array_length attribute into the
-/// resulting tensor descriptor type. Only applies when the source is a static
-/// memref; dynamic-shape sources are left unchanged. Skipped if any consumer
-/// load_nd carries a non-identity transpose, since stacking the array blocks
-/// along the non-FCD dimension would invalidate that load.
+/// resulting tensor descriptor type. Supports static memref, dynamic-shape
+/// memref, and raw-pointer (integer) sources — the memory region described by
+/// `shape`/`strides` is unchanged; only the tensor_desc view is narrowed along
+/// the FCD and tagged with `array_length`. Skipped if any consumer load_nd
+/// carries a non-identity transpose, since stacking the array blocks along the
+/// non-FCD dimension would invalidate that load.
class OptimizeCreateNdDescOp : public OpRewritePattern<xegpu::CreateNdDescOp> {
public:
using OpRewritePattern<xegpu::CreateNdDescOp>::OpRewritePattern;
@@ -112,11 +114,8 @@ class OptimizeCreateNdDescOp : public OpRewritePattern<xegpu::CreateNdDescOp> {
if (hasTransposeLaneLayout(tdescType))
return failure();
- // Only static memref sources are supported for now.
- // TODO: extend to dynamic-shape memrefs and raw pointer sources by
- // rewriting the `shape`/`strides` operands of create_nd_tdesc.
- auto memrefSource = dyn_cast<TypedValue<MemRefType>>(op.getSource());
- if (!memrefSource || !memrefSource.getType().hasStaticShape())
+ Value source = op.getSource();
+ if (!isa<MemRefType, IntegerType>(source.getType()))
return failure();
// Bail out if any consumer is a transposing load_nd.
@@ -135,8 +134,12 @@ class OptimizeCreateNdDescOp : public OpRewritePattern<xegpu::CreateNdDescOp> {
tdescType.getBoundaryCheck(), tdescType.getMemorySpace(),
tdescType.getLayout());
- auto newOp = xegpu::CreateNdDescOp::create(rewriter, op.getLoc(),
- newTdescType, memrefSource);
+ // The memory region is unchanged; pass through the existing shape/strides.
+ // The general builder recognizes the static-memref case and drops the
+ // redundant attributes.
+ auto newOp = xegpu::CreateNdDescOp::create(
+ rewriter, op.getLoc(), newTdescType, source, op.getMixedSizes(),
+ op.getMixedStrides());
rewriter.replaceOp(op, newOp.getResult());
return success();
}
diff --git a/mlir/test/Dialect/XeGPU/array-len-op-unit.mlir b/mlir/test/Dialect/XeGPU/array-len-op-unit.mlir
index 4aad98508398e..340a10a99cb88 100644
--- a/mlir/test/Dialect/XeGPU/array-len-op-unit.mlir
+++ b/mlir/test/Dialect/XeGPU/array-len-op-unit.mlir
@@ -165,3 +165,53 @@ func.func @test_multiple_extracts(%arg0: memref<4096x4096xf16>) -> (vector<16x16
return %e0, %e1, %e2, %e3 : vector<16x16xf16>, vector<16x16xf16>, vector<16x16xf16>, vector<16x16xf16>
}
}
+
+// -----
+
+gpu.module @test {
+// Pointer-form (i64) source — shape/strides are given explicitly and must be
+// carried through to the rewritten create_nd_tdesc.
+// CHECK-LABEL: func.func @test_pointer_source
+// CHECK-SAME: (%[[ARG0:.*]]: i64)
+func.func @test_pointer_source(%arg0: i64) -> vector<16x16xf16> {
+ %c0 = arith.constant 0 : index
+
+ // CHECK: %[[TDESC:.*]] = xegpu.create_nd_tdesc %[[ARG0]], shape : [64, 64], strides : [64, 1]
+ // CHECK-SAME: i64 -> !xegpu.tensor_desc<32x16xf16, #xegpu.block_tdesc_attr<array_length = 2 : i64, boundary_check = false>>
+ %tdesc = xegpu.create_nd_tdesc %arg0, shape : [64, 64], strides : [64, 1] : i64 -> !xegpu.tensor_desc<32x32xf16, #xegpu.block_tdesc_attr<boundary_check = false>>
+
+ // CHECK: %[[LOAD:.*]] = xegpu.load_nd %[[TDESC]][%{{.*}}, %{{.*}}]
+ // CHECK-SAME: -> vector<64x16xf16>
+ %load = xegpu.load_nd %tdesc[%c0, %c0] : !xegpu.tensor_desc<32x32xf16, #xegpu.block_tdesc_attr<boundary_check = false>> -> vector<32x32xf16>
+
+ // CHECK: vector.extract_strided_slice %[[LOAD]]
+ // CHECK-SAME: {offsets = [32, 0], sizes = [16, 16], strides = [1, 1]}
+ %e = vector.extract_strided_slice %load {offsets = [0, 16], sizes = [16, 16], strides = [1, 1]} : vector<32x32xf16> to vector<16x16xf16>
+
+ return %e : vector<16x16xf16>
+}
+}
+
+// -----
+
+gpu.module @test {
+// Dynamic-shape memref source — shape/strides are given via operands and must
+// be carried through.
+// CHECK-LABEL: func.func @test_dynamic_memref_source
+// CHECK-SAME: (%[[ARG0:.*]]: memref<?x?xf16>, %[[H:.*]]: index, %[[W:.*]]: index)
+func.func @test_dynamic_memref_source(%arg0: memref<?x?xf16>, %h: index, %w: index) -> vector<16x16xf16> {
+ %c0 = arith.constant 0 : index
+ %c1 = arith.constant 1 : index
+
+ // CHECK: %[[TDESC:.*]] = xegpu.create_nd_tdesc %[[ARG0]], shape : [%[[H]], %[[W]]], strides : [%[[W]], %{{.*}}]
+ // CHECK-SAME: memref<?x?xf16> -> !xegpu.tensor_desc<32x16xf16, #xegpu.block_tdesc_attr<array_length = 2 : i64>>
+ %tdesc = xegpu.create_nd_tdesc %arg0, shape : [%h, %w], strides : [%w, %c1] : memref<?x?xf16> -> !xegpu.tensor_desc<32x32xf16>
+
+ // CHECK: %[[LOAD:.*]] = xegpu.load_nd %[[TDESC]][%{{.*}}, %{{.*}}]
+ // CHECK-SAME: -> vector<64x16xf16>
+ %load = xegpu.load_nd %tdesc[%c0, %c0] : !xegpu.tensor_desc<32x32xf16> -> vector<32x32xf16>
+
+ %e = vector.extract_strided_slice %load {offsets = [0, 0], sizes = [16, 16], strides = [1, 1]} : vector<32x32xf16> to vector<16x16xf16>
+ return %e : vector<16x16xf16>
+}
+}
``````````
</details>
https://github.com/llvm/llvm-project/pull/195872
More information about the Mlir-commits mailing list