[Mlir-commits] [mlir] [MLIR] Vector to XeGPU conversion: Use proper source variant for create_nd_tdesc op creation. (PR #171216)

Sang Ik Lee llvmlistbot at llvm.org
Wed Dec 10 11:46:38 PST 2025


================
@@ -102,18 +102,47 @@ static xegpu::CreateNdDescOp createNdDescriptor(PatternRewriter &rewriter,
                                                 xegpu::TensorDescType descType,
                                                 TypedValue<MemRefType> src) {
   MemRefType srcTy = src.getType();
+  assert(srcTy.isStrided() && "Expected strided memref type");
   auto [strides, offset] = srcTy.getStridesAndOffset();
+  bool isStatic = true;
+
+  // Memref is dynamic if any of its shape, offset or strides is dynamic.
+  if (!srcTy.hasStaticShape()) {
+    isStatic = false;
+  }
+
+  if (offset == ShapedType::kDynamic)
+    isStatic = false;
+
+  for (auto stride : strides) {
+    if (stride == ShapedType::kDynamic) {
+      isStatic = false;
+      break;
+    }
+  }
 
   xegpu::CreateNdDescOp ndDesc;
-  if (srcTy.hasStaticShape()) {
+  if (isStatic) {
     ndDesc = xegpu::CreateNdDescOp::create(rewriter, loc, descType, src);
   } else {
-    // In case of any dynamic shapes, source's shape and strides have to be
+    // For a ranked dynamic memref, pass the i64 base address and the source's
+    // offset, shape and strides instead of the memref itself; these have to be
     // explicitly provided.
     auto meta = memref::ExtractStridedMetadataOp::create(rewriter, loc, src);
-    ndDesc = xegpu::CreateNdDescOp::create(rewriter, loc, descType, src,
-                                           meta.getConstifiedMixedSizes(),
-                                           meta.getConstifiedMixedStrides());
+    auto baseAddrIndex = memref::ExtractAlignedPointerAsIndexOp::create(
+        rewriter, loc, meta.getBaseBuffer());
+    auto baseAddrI64 = arith::IndexCastOp::create(
+        rewriter, loc, rewriter.getI64Type(), baseAddrIndex.getResult());
+    // Strided metadata provides only a 1D offset, but create_nd_tdesc expects
+    // one offset per source memref dimension. Add leading zeros if rank > 1.
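The hunk above replaces the plain hasStaticShape() test with a stricter check: a memref only counts as static when its shape, offset and strides are all static. The sketch below models that check in plain C++ outside MLIR. It is a hypothetical illustration, not the PR's code. StridedLayout and isFullyStatic are made-up names, and kDynamic uses INT64_MIN as the sentinel to mirror how MLIR defines ShapedType::kDynamic. The sketch shows why the shape test alone is not enough. A type such as memref<8x16xf32, strided<[16, 1], offset: ?>> has a static shape but a dynamic offset, so it still has to go through the memref.extract_strided_metadata path.

```cpp
#include <cassert>
#include <cstdint>
#include <limits>
#include <vector>

// Hypothetical stand-in for ShapedType::kDynamic (INT64_MIN sentinel).
constexpr int64_t kDynamic = std::numeric_limits<int64_t>::min();

// Hypothetical plain-data view of a strided memref type's layout.
struct StridedLayout {
  std::vector<int64_t> shape;
  std::vector<int64_t> strides;
  int64_t offset;
};

// Returns true only if every size, every stride and the offset are static,
// i.e. the memref could be handed to create_nd_tdesc directly.
bool isFullyStatic(const StridedLayout &layout) {
  assert(layout.shape.size() == layout.strides.size() && "rank mismatch");
  for (int64_t dim : layout.shape)
    if (dim == kDynamic)
      return false;
  if (layout.offset == kDynamic)
    return false;
  for (int64_t stride : layout.strides)
    if (stride == kDynamic)
      return false;
  return true;
}
```

For example, StridedLayout{{8, 16}, {16, 1}, 0} is fully static. The same layout with kDynamic as its offset is not, even though its shape is static.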
----------------
silee2 wrote:

Updated the PR to generate an adjusted base address: base_addr + offset * element_size_in_bytes.
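The reply folds the memref's element offset into the base address instead of passing it as a rank-matched offset list. The sketch below is our own plain-C++ illustration of that arithmetic, not code from the PR. elementAddress, adjustedBase and paddedOffsetAddress are hypothetical names. The sketch assumes the descriptor computes addresses as base + sum((offset_i + idx_i) * stride_i) * element size. Under that assumption, putting the 1D offset in the innermost slot with leading zeros gives the right address only when the innermost stride is 1. The adjusted base is correct for any strides.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical model: byte address of the element at index `idx` in a strided
// view. `offset` and `strides` are in elements, like memref strided metadata.
uint64_t elementAddress(uint64_t base, int64_t offset,
                        const std::vector<int64_t> &strides,
                        const std::vector<int64_t> &idx, uint64_t elemBytes) {
  assert(strides.size() == idx.size() && "rank mismatch");
  int64_t linear = offset;
  for (std::size_t i = 0; i < strides.size(); ++i)
    linear += idx[i] * strides[i];
  return base + static_cast<uint64_t>(linear) * elemBytes;
}

// Approach from the reply: fold the element offset into the base address,
// so the descriptor itself can start at offset zero in every dimension.
uint64_t adjustedBase(uint64_t base, int64_t offset, uint64_t elemBytes) {
  return base + static_cast<uint64_t>(offset) * elemBytes;
}

// Earlier approach (assumed semantics): the 1D offset becomes the innermost
// per-dimension offset, with leading zeros, and is then scaled by that
// dimension's stride like any other index.
uint64_t paddedOffsetAddress(uint64_t base, int64_t offset,
                             const std::vector<int64_t> &strides,
                             const std::vector<int64_t> &idx,
                             uint64_t elemBytes) {
  assert(!idx.empty() && "need at least one dimension");
  std::vector<int64_t> shifted(idx);
  shifted.back() += offset; // leading dimensions keep a zero offset
  return elementAddress(base, 0, strides, shifted, elemBytes);
}
```

Take a base of 0x1000, f32 elements (4 bytes) and an element offset of 34. adjustedBase returns 0x1000 + 34 * 4 = 0x1088. Now take strides [32, 2] and an offset of 5. The padded form lands 40 bytes past the base instead of the correct 20, while the adjusted base yields the correct address.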

https://github.com/llvm/llvm-project/pull/171216

