[Mlir-commits] [mlir] [MLIR][XeGPU] Support subview memref: handling the base address during xegpu to xevm type conversion (PR #170541)

Charitha Saumya llvmlistbot at llvm.org
Wed Dec 3 15:52:57 PST 2025


================
@@ -991,27 +991,70 @@ struct ConvertXeGPUToXeVMPass
     });
 
     typeConverter.addConversion([&](MemRefType type) -> Type {
-      if (type.getMemorySpaceAsInt() == 3)
-        return IntegerType::get(&getContext(), 32);
-      return IntegerType::get(&getContext(), 64);
+      return IntegerType::get(&getContext(),
+                              (xegpu::isSharedMemRef(type) ? 32 : 64));
     });
 
     // LLVM type converter puts unrealized casts for the following cases:
     // add materialization casts to handle them.
 
-    // Materialization to convert memref to i64
+    // Materialization to convert memref to i64 or i32 depending on global/SLM
     auto memrefMaterializationCast = [](OpBuilder &builder, Type type,
                                         ValueRange inputs,
                                         Location loc) -> Value {
       if (inputs.size() != 1)
         return {};
       auto input = inputs.front();
       if (auto memrefTy = dyn_cast<MemRefType>(input.getType())) {
+        unsigned rank = memrefTy.getRank();
+        Type indexType = builder.getIndexType();
 
-        Value addr =
-            memref::ExtractAlignedPointerAsIndexOp::create(builder, loc, input);
-        return arith::IndexCastUIOp::create(builder, loc, type, addr)
-            .getResult();
+        int64_t intOffsets;
+        SmallVector<int64_t> intStrides;
+        Value addr;
+        Value offset;
+        if (failed(memrefTy.getStridesAndOffset(intStrides, intOffsets))) {
+
+          // Result types: [base_memref, offset, stride0, stride1, ...,
+          // strideN-1, size0, size1, ..., sizeN-1]
+          SmallVector<Type> resultTypes{
+              MemRefType::get({}, memrefTy.getElementType(),
+                              MemRefLayoutAttrInterface(),
+                              memrefTy.getMemorySpace()),
+              indexType};
+          // strides + sizes
+          resultTypes.append(2 * rank, indexType);
+
+          auto meta = memref::ExtractStridedMetadataOp::create(
+              builder, loc, resultTypes, input);
+
+          addr = memref::ExtractAlignedPointerAsIndexOp::create(
+              builder, loc, meta.getBaseBuffer());
+          offset = meta.getOffset();
+
+        } else {
+          addr = memref::ExtractAlignedPointerAsIndexOp::create(builder, loc,
+                                                                input);
+          offset = arith::ConstantOp::create(builder, loc,
+                                             builder.getIndexAttr(intOffsets));
+        }
+
+        auto addr_casted =
----------------
charithaintc wrote:

LLVM/MLIR style does not use snake_case for variable names (MLIR uses lowerCamelCase for locals). Please rename this to `addrCasted`.

https://github.com/llvm/llvm-project/pull/170541


More information about the Mlir-commits mailing list