[Mlir-commits] [mlir] [MLIR][Conversion] XeGPU to xevm 1d block (PR #165894)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Fri Oct 31 11:04:38 PDT 2025


llvmbot wrote:



@llvm/pr-subscribers-mlir

Author: Sang Ik Lee (silee2)

<details>
<summary>Changes</summary>

This patch extends `LoadStorePrefetchNdToXeVMPattern` in the XeGPU-to-XeVM conversion to also handle rank-1 (1D block) tensor descriptors. The existing 2D path is kept; for rank-1 descriptors, `xegpu.load_nd` and `xegpu.store_nd` are lowered to `xevm::BlockLoadOp` / `xevm::BlockStoreOp` on a pointer built from the base address plus a linearized byte offset, while `xegpu.prefetch_nd` with a rank-1 descriptor is reported as unsupported.

---
Full diff: https://github.com/llvm/llvm-project/pull/165894.diff


1 File Affected:

- (modified) mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp (+126-61) 


``````````diff
diff --git a/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp b/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp
index 33e8f2ed1f6ed..bc360c1a6929b 100644
--- a/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp
+++ b/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp
@@ -264,8 +264,6 @@ class LoadStorePrefetchNdToXeVMPattern : public OpConversionPattern<OpType> {
 
     auto tdesc = adaptor.getTensorDesc();
     auto tdescTy = op.getTensorDescType();
-    if (tdescTy.getRank() != 2)
-      return rewriter.notifyMatchFailure(op, "Expected 2D tensor descriptor.");
     auto elemType = tdescTy.getElementType();
     auto elemBitSize = elemType.getIntOrFloatBitWidth();
     if (elemBitSize % 8 != 0)
@@ -294,71 +292,138 @@ class LoadStorePrefetchNdToXeVMPattern : public OpConversionPattern<OpType> {
     // Get address space from tensor descriptor memory space.
     auto ptrTypeLLVM = LLVM::LLVMPointerType::get(
         ctxt, getNumericXeVMAddrSpace(tdescTy.getMemorySpace()));
-    // Convert base pointer (i64) to LLVM pointer type.
-    Value basePtrLLVM =
-        LLVM::IntToPtrOp::create(rewriter, loc, ptrTypeLLVM, basePtr);
-    // Compute element byte size and surface width in bytes.
+    // Compute element byte size.
     Value elemByteSize = arith::ConstantIntOp::create(
         rewriter, loc, rewriter.getI32Type(), elemBitSize / 8);
-    Value surfaceW =
-        arith::MulIOp::create(rewriter, loc, baseShapeW, elemByteSize);
-
-    // Get tile sizes and vblocks from the tensor descriptor type.
-    auto tileW = tdescTy.getDimSize(1);
-    auto tileH = tdescTy.getDimSize(0);
-    int32_t vblocks = tdescTy.getArrayLength();
-    if constexpr (std::is_same_v<OpType, xegpu::StoreNdOp>) {
-      Value src = adaptor.getValue();
-      // If store value is a scalar, get value from op instead of adaptor.
-      // Adaptor might have optimized away single element vector
-      if (src.getType().isIntOrFloat()) {
-        src = op.getValue();
-      }
-      VectorType srcVecTy = dyn_cast<VectorType>(src.getType());
-      if (!srcVecTy)
-        return rewriter.notifyMatchFailure(
-            op, "Expected store value to be a vector type.");
-      // Get flat vector type of integer type with matching element bit size.
-      VectorType newSrcVecTy =
-          encodeVectorTypeTo(srcVecTy, rewriter.getIntegerType(elemBitSize));
-      if (srcVecTy != newSrcVecTy)
-        src = vector::BitCastOp::create(rewriter, loc, newSrcVecTy, src);
-      auto storeCacheControl =
-          translateStoreXeGPUCacheHint(op.getL1Hint(), op.getL3Hint());
-      xevm::BlockStore2dOp::create(
-          rewriter, loc, basePtrLLVM, surfaceW, baseShapeH, surfaceW, offsetW,
-          offsetH, elemBitSize, tileW, tileH, src,
-          xevm::StoreCacheControlAttr::get(ctxt, storeCacheControl));
-      rewriter.eraseOp(op);
-    } else {
-      auto loadCacheControl =
-          translateLoadXeGPUCacheHint(op.getL1Hint(), op.getL3Hint());
-      if constexpr (std::is_same_v<OpType, xegpu::PrefetchNdOp>) {
-        xevm::BlockPrefetch2dOp::create(
+    auto tileRank = tdescTy.getRank();
+    // Get tile width from the tensor descriptor type.
+    auto tileW = tdescTy.getDimSize(tileRank - 1);
+    if (tileRank == 2) {
+      // Convert base pointer (i64) to LLVM pointer type.
+      Value basePtrLLVM =
+          LLVM::IntToPtrOp::create(rewriter, loc, ptrTypeLLVM, basePtr);
+      // Compute width in bytes.
+      Value elemByteSize = arith::ConstantIntOp::create(
+          rewriter, loc, rewriter.getI32Type(), elemBitSize / 8);
+      Value surfaceW =
+          arith::MulIOp::create(rewriter, loc, baseShapeW, elemByteSize);
+
+      // Get tile height from the tensor descriptor type.
+      auto tileH = tdescTy.getDimSize(0);
+      // Get vblocks from the tensor descriptor type.
+      int32_t vblocks = tdescTy.getArrayLength();
+      if constexpr (std::is_same_v<OpType, xegpu::StoreNdOp>) {
+        Value src = adaptor.getValue();
+        // If store value is a scalar, get value from op instead of adaptor.
+        // Adaptor might have optimized away single element vector
+        if (src.getType().isIntOrFloat()) {
+          src = op.getValue();
+        }
+        VectorType srcVecTy = dyn_cast<VectorType>(src.getType());
+        if (!srcVecTy)
+          return rewriter.notifyMatchFailure(
+              op, "Expected store value to be a vector type.");
+        // Get flat vector type of integer type with matching element bit size.
+        VectorType newSrcVecTy =
+            encodeVectorTypeTo(srcVecTy, rewriter.getIntegerType(elemBitSize));
+        if (srcVecTy != newSrcVecTy)
+          src = vector::BitCastOp::create(rewriter, loc, newSrcVecTy, src);
+        auto storeCacheControl =
+            translateStoreXeGPUCacheHint(op.getL1Hint(), op.getL3Hint());
+        xevm::BlockStore2dOp::create(
             rewriter, loc, basePtrLLVM, surfaceW, baseShapeH, surfaceW, offsetW,
-            offsetH, elemBitSize, tileW, tileH, vblocks,
-            xevm::LoadCacheControlAttr::get(ctxt, loadCacheControl));
+            offsetH, elemBitSize, tileW, tileH, src,
+            xevm::StoreCacheControlAttr::get(ctxt, storeCacheControl));
         rewriter.eraseOp(op);
       } else {
-        VectorType dstVecTy = cast<VectorType>(op.getValue().getType());
-        const bool vnni = op.getPacked().value_or(false);
-        auto transposeValue = op.getTranspose();
-        bool transpose =
-            transposeValue.has_value() && transposeValue.value()[0] == 1;
-        VectorType loadedTy = encodeVectorTypeTo(
-            dstVecTy, vnni ? rewriter.getI32Type()
-                           : rewriter.getIntegerType(elemBitSize));
-
-        Value resultFlatVec = xevm::BlockLoad2dOp::create(
-            rewriter, loc, loadedTy, basePtrLLVM, surfaceW, baseShapeH,
-            surfaceW, offsetW, offsetH, elemBitSize, tileW, tileH, vblocks,
-            transpose, vnni,
+        auto loadCacheControl =
+            translateLoadXeGPUCacheHint(op.getL1Hint(), op.getL3Hint());
+        if constexpr (std::is_same_v<OpType, xegpu::PrefetchNdOp>) {
+          xevm::BlockPrefetch2dOp::create(
+              rewriter, loc, basePtrLLVM, surfaceW, baseShapeH, surfaceW,
+              offsetW, offsetH, elemBitSize, tileW, tileH, vblocks,
+              xevm::LoadCacheControlAttr::get(ctxt, loadCacheControl));
+          rewriter.eraseOp(op);
+        } else {
+          VectorType dstVecTy = cast<VectorType>(op.getValue().getType());
+          const bool vnni = op.getPacked().value_or(false);
+          auto transposeValue = op.getTranspose();
+          bool transpose =
+              transposeValue.has_value() && transposeValue.value()[0] == 1;
+          VectorType loadedTy = encodeVectorTypeTo(
+              dstVecTy, vnni ? rewriter.getI32Type()
+                             : rewriter.getIntegerType(elemBitSize));
+
+          Value resultFlatVec = xevm::BlockLoad2dOp::create(
+              rewriter, loc, loadedTy, basePtrLLVM, surfaceW, baseShapeH,
+              surfaceW, offsetW, offsetH, elemBitSize, tileW, tileH, vblocks,
+              transpose, vnni,
+              xevm::LoadCacheControlAttr::get(ctxt, loadCacheControl));
+          resultFlatVec = vector::BitCastOp::create(
+              rewriter, loc,
+              encodeVectorTypeTo(loadedTy, dstVecTy.getElementType()),
+              resultFlatVec);
+          rewriter.replaceOp(op, resultFlatVec);
+        }
+      }
+    } else {
+      // Get address from base address and offsets.
+      // Offset in number of elements, need to multiply by element byte size.
+      // Compute linear offset.
+      // linearOffset = offsetH * baseShapeW + offsetW
+      Value offsetHInElems =
+          rewriter.createOrFold<arith::MulIOp>(loc, offsetH, baseShapeW);
+      Value linearOffset =
+          rewriter.createOrFold<arith::AddIOp>(loc, offsetHInElems, offsetW);
+      // Then compute byte offset by multiplying with element byte size.
+      // byteOffset = linearOffset * elemByteSize
+      Value byteOffset =
+          rewriter.createOrFold<arith::MulIOp>(loc, linearOffset, elemByteSize);
+      // Final address = basePtr + byteOffset
+      Value finalAddrI64 = rewriter.createOrFold<arith::AddIOp>(
+          loc, basePtr,
+          getValueOrCreateCastToIndexLike(rewriter, loc, rewriter.getI64Type(),
+                                          byteOffset));
+      // Convert base pointer (i64) to LLVM pointer type.
+      Value finalPtrLLVM =
+          LLVM::IntToPtrOp::create(rewriter, loc, ptrTypeLLVM, finalAddrI64);
+      if constexpr (std::is_same_v<OpType, xegpu::StoreNdOp>) {
+        Value src = adaptor.getValue();
+        // If store value is a scalar, get value from op instead of adaptor.
+        // Adaptor might have optimized away single element vector
+        if (src.getType().isIntOrFloat()) {
+          src = op.getValue();
+        }
+        VectorType srcVecTy = dyn_cast<VectorType>(src.getType());
+        if (!srcVecTy)
+          return rewriter.notifyMatchFailure(
+              op, "Expected store value to be a vector type.");
+        // Get flat vector type of integer type with matching element bit size.
+        VectorType newSrcVecTy =
+            encodeVectorTypeTo(srcVecTy, rewriter.getIntegerType(elemBitSize));
+        if (srcVecTy != newSrcVecTy)
+          src = vector::BitCastOp::create(rewriter, loc, newSrcVecTy, src);
+        auto storeCacheControl =
+            translateStoreXeGPUCacheHint(op.getL1Hint(), op.getL3Hint());
+        rewriter.replaceOpWithNewOp<xevm::BlockStoreOp>(
+            op, finalPtrLLVM, src,
+            xevm::StoreCacheControlAttr::get(ctxt, storeCacheControl));
+      } else if constexpr (std::is_same_v<OpType, xegpu::LoadNdOp>) {
+        auto loadCacheControl =
+            translateLoadXeGPUCacheHint(op.getL1Hint(), op.getL3Hint());
+        VectorType resTy = cast<VectorType>(op.getValue().getType());
+        VectorType loadedTy =
+            encodeVectorTypeTo(resTy, rewriter.getIntegerType(elemBitSize));
+        Value load = xevm::BlockLoadOp::create(
+            rewriter, loc, loadedTy, finalPtrLLVM,
             xevm::LoadCacheControlAttr::get(ctxt, loadCacheControl));
-        resultFlatVec = vector::BitCastOp::create(
-            rewriter, loc,
-            encodeVectorTypeTo(loadedTy, dstVecTy.getElementType()),
-            resultFlatVec);
-        rewriter.replaceOp(op, resultFlatVec);
+        if (loadedTy != resTy)
+          load = vector::BitCastOp::create(rewriter, loc, resTy, load);
+        rewriter.replaceOp(op, load);
+      } else {
+        return rewriter.notifyMatchFailure(
+            op, "Unsupported operation: xegpu.prefetch_nd with tensor "
+                "descriptor rank == 1");
       }
     }
     return success();

``````````
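
For readers skimming the diff: the new rank-1 branch (`tileRank != 2`) swaps the 2D block intrinsics for plain `xevm::BlockLoadOp` / `xevm::BlockStoreOp` on a single pointer, which the pattern builds with `arith` `createOrFold` ops followed by `LLVM::IntToPtrOp`. Below is a minimal standalone C++ sketch (not part of the patch; the function and variable names are illustrative) that restates the byte-offset arithmetic the lowering emits:

```cpp
// Standalone sketch (not from the patch): mirrors the address arithmetic that
// the new rank-1 path materializes with arith ops before creating
// xevm::BlockLoadOp / xevm::BlockStoreOp. All names below are illustrative.
#include <cassert>
#include <cstdint>

// Returns the final byte address for a 1D block access, given the i64 base
// address, the 2D offsets (in elements), the base-shape width (in elements),
// and the element bit width.
uint64_t blockAddress(uint64_t basePtr, uint32_t offsetW, uint32_t offsetH,
                      uint32_t baseShapeW, uint32_t elemBitSize) {
  assert(elemBitSize % 8 == 0 && "element type must be byte-aligned");
  uint32_t elemByteSize = elemBitSize / 8;
  // linearOffset = offsetH * baseShapeW + offsetW   (in elements)
  uint64_t linearOffset =
      static_cast<uint64_t>(offsetH) * baseShapeW + offsetW;
  // byteOffset = linearOffset * elemByteSize
  uint64_t byteOffset = linearOffset * elemByteSize;
  // finalAddr = basePtr + byteOffset; the pattern then turns this integer
  // into an LLVM pointer via LLVM::IntToPtrOp.
  return basePtr + byteOffset;
}

int main() {
  // Example: f16 elements (16 bits), row 3, column 5 of a surface that is
  // 128 elements wide, base address 0x1000.
  assert(blockAddress(0x1000, 5, 3, 128, 16) == 0x1000 + (3 * 128 + 5) * 2);
  return 0;
}
```

The lowered IR performs the same computation on SSA values, so constant operands can fold during `createOrFold` before the final `xevm` block op is created.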

</details>


https://github.com/llvm/llvm-project/pull/165894


More information about the Mlir-commits mailing list