[Mlir-commits] [mlir] [MLIR][Conversion][XeGPU][XeVM] Add XeGPUToXeVM conversion pass and tests. (PR #154556)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Wed Aug 20 08:19:45 PDT 2025


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir

Author: Sang Ik Lee (silee2)

<details>
<summary>Changes</summary>

Add XeGPU to XeVM conversion pass and tests.

---

Patch is 80.54 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/154556.diff


13 Files Affected:

- (modified) mlir/include/mlir/Conversion/Passes.h (+1) 
- (modified) mlir/include/mlir/Conversion/Passes.td (+12) 
- (added) mlir/include/mlir/Conversion/XeGPUToXeVM/XeGPUToXeVM.h (+27) 
- (modified) mlir/lib/Conversion/CMakeLists.txt (+1) 
- (added) mlir/lib/Conversion/XeGPUToXeVM/CMakeLists.txt (+25) 
- (added) mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp (+932) 
- (added) mlir/test/Conversion/XeGPUToXeVM/create_nd_tdesc.mlir (+48) 
- (added) mlir/test/Conversion/XeGPUToXeVM/dpas.mlir (+17) 
- (added) mlir/test/Conversion/XeGPUToXeVM/fence.mlir (+15) 
- (added) mlir/test/Conversion/XeGPUToXeVM/loadstore_nd.mlir (+71) 
- (added) mlir/test/Conversion/XeGPUToXeVM/loadstoreprefetch.mlir (+357) 
- (added) mlir/test/Conversion/XeGPUToXeVM/prefetch_nd.mlir (+40) 
- (added) mlir/test/Conversion/XeGPUToXeVM/update_offset.mlir (+25) 


``````````diff
diff --git a/mlir/include/mlir/Conversion/Passes.h b/mlir/include/mlir/Conversion/Passes.h
index 91b2ecf8922a3..da061b269daf7 100644
--- a/mlir/include/mlir/Conversion/Passes.h
+++ b/mlir/include/mlir/Conversion/Passes.h
@@ -82,6 +82,7 @@
 #include "mlir/Conversion/VectorToSCF/VectorToSCF.h"
 #include "mlir/Conversion/VectorToSPIRV/VectorToSPIRVPass.h"
 #include "mlir/Conversion/VectorToXeGPU/VectorToXeGPU.h"
+#include "mlir/Conversion/XeGPUToXeVM/XeGPUToXeVM.h"
 #include "mlir/Conversion/XeVMToLLVM/XeVMToLLVM.h"
 
 namespace mlir {
diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td
index 2058aba7f9e37..323af3e97e2d4 100644
--- a/mlir/include/mlir/Conversion/Passes.td
+++ b/mlir/include/mlir/Conversion/Passes.td
@@ -1555,4 +1555,16 @@ def ConvertXeVMToLLVMPass : Pass<"convert-xevm-to-llvm"> {
   let dependentDialects = ["LLVM::LLVMDialect"];
 }
 
+//===----------------------------------------------------------------------===//
+// XeGPUToXeVM
+//===----------------------------------------------------------------------===//
+
+def ConvertXeGPUToXeVMPass : Pass<"convert-xegpu-to-xevm"> {
+  let summary = "Convert XeGPU to XeVM dialect";
+  let dependentDialects = ["xevm::XeVMDialect", "vector::VectorDialect",
+                           "memref::MemRefDialect", "arith::ArithDialect",
+                           "LLVM::LLVMDialect", "index::IndexDialect",
+                           "gpu::GPUDialect", "scf::SCFDialect"];
+}
+
 #endif // MLIR_CONVERSION_PASSES
diff --git a/mlir/include/mlir/Conversion/XeGPUToXeVM/XeGPUToXeVM.h b/mlir/include/mlir/Conversion/XeGPUToXeVM/XeGPUToXeVM.h
new file mode 100644
index 0000000000000..fb23d24b0161b
--- /dev/null
+++ b/mlir/include/mlir/Conversion/XeGPUToXeVM/XeGPUToXeVM.h
@@ -0,0 +1,27 @@
+//===-- XeGPUToXeVM.h - Convert XeGPU to XeVM dialect -----------*- C++ -*-===//
+//
+// This file is licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+#ifndef MLIR_CONVERSION_XEGPUTOXEVM_XEGPUTOXEVMPASS_H_
+#define MLIR_CONVERSION_XEGPUTOXEVM_XEGPUTOXEVMPASS_H_
+
+#include <memory>
+
+namespace mlir {
+class DialectRegistry;
+class LLVMTypeConverter;
+class RewritePatternSet;
+class Pass;
+
+#define GEN_PASS_DECL_CONVERTXEGPUTOXEVMPASS
+#include "mlir/Conversion/Passes.h.inc"
+
+void populateXeGPUToXeVMConversionPatterns(
+    mlir::RewritePatternSet &patterns, mlir::LLVMTypeConverter &typeConverter);
+
+} // namespace mlir
+
+#endif // MLIR_CONVERSION_XEGPUTOXEVM_XEGPUTOXEVMPASS_H_
diff --git a/mlir/lib/Conversion/CMakeLists.txt b/mlir/lib/Conversion/CMakeLists.txt
index 171f7169fd41d..134fe8e14ca38 100644
--- a/mlir/lib/Conversion/CMakeLists.txt
+++ b/mlir/lib/Conversion/CMakeLists.txt
@@ -76,3 +76,4 @@ add_subdirectory(VectorToSCF)
 add_subdirectory(VectorToSPIRV)
 add_subdirectory(VectorToXeGPU)
 add_subdirectory(XeVMToLLVM)
+add_subdirectory(XeGPUToXeVM)
diff --git a/mlir/lib/Conversion/XeGPUToXeVM/CMakeLists.txt b/mlir/lib/Conversion/XeGPUToXeVM/CMakeLists.txt
new file mode 100644
index 0000000000000..ed54b0bb5ee81
--- /dev/null
+++ b/mlir/lib/Conversion/XeGPUToXeVM/CMakeLists.txt
@@ -0,0 +1,25 @@
+add_mlir_conversion_library(MLIRXeGPUToXeVM
+  XeGPUToXeVM.cpp
+
+  ADDITIONAL_HEADER_DIRS
+  ${MLIR_MAIN_INCLUDE_DIR}/mlir/Conversion/XeGPUToXeVM
+
+  DEPENDS
+  MLIRConversionPassIncGen
+
+  LINK_COMPONENTS
+  Core
+
+  LINK_LIBS PUBLIC
+  MLIRFuncDialect
+  MLIRGPUDialect
+  MLIRLLVMCommonConversion
+  MLIRLLVMDialect
+  MLIRXeVMDialect
+  MLIRVectorDialect
+  MLIRArithDialect
+  MLIRIndexDialect
+  MLIRXeGPUDialect
+  MLIRPass
+  MLIRTransforms
+)
diff --git a/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp b/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp
new file mode 100644
index 0000000000000..380409afbc62e
--- /dev/null
+++ b/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp
@@ -0,0 +1,932 @@
+//===-- XeGPUToXeVM.cpp - XeGPU to XeVM dialect conversion ------*- C++ -*-===//
+//
+// This file is licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Conversion/XeGPUToXeVM/XeGPUToXeVM.h"
+#include "mlir/Dialect/LLVMIR/XeVMDialect.h"
+
+#include "mlir/Conversion/LLVMCommon/Pattern.h"
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/GPU/IR/GPUDialect.h"
+#include "mlir/Dialect/Index/IR/IndexDialect.h"
+#include "mlir/Dialect/Index/IR/IndexOps.h"
+#include "mlir/Dialect/LLVMIR/FunctionCallUtils.h"
+#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/SCF/IR/SCF.h"
+#include "mlir/Dialect/SCF/Transforms/Patterns.h"
+#include "mlir/Dialect/XeGPU/IR/XeGPU.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Support/LLVM.h"
+#include "llvm/Support/FormatVariadic.h"
+
+#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/Types.h"
+
+#include "llvm/ADT/TypeSwitch.h"
+
+namespace mlir {
+#define GEN_PASS_DEF_CONVERTXEGPUTOXEVMPASS
+#include "mlir/Conversion/Passes.h.inc"
+} // namespace mlir
+
+using namespace mlir;
+
+namespace {
+
+enum class NdDescI32Layout : uint32_t {
+  BasePtr = 0,
+  BaseShapeW = 2,
+  BaseShapeH = 3,
+  TensorOffsetW = 4,
+  TensorOffsetH = 5
+};
+
+static int32_t getNumericXeVMAddrSpace(xegpu::MemorySpace xeGpuMemspace) {
+  switch (xeGpuMemspace) {
+  case xegpu::MemorySpace::Global:
+    return static_cast<int>(xevm::AddrSpace::GLOBAL);
+  case xegpu::MemorySpace::SLM:
+    return static_cast<int>(xevm::AddrSpace::SHARED);
+  }
+  llvm_unreachable("Unknown XeGPU memory space.");
+}
+
+template <typename T>
+std::tuple<bool, int32_t, int32_t> checkAllLinear(SmallVector<T> denseAttr) {
+  assert(!denseAttr.empty());
+  const int32_t intercept{static_cast<int32_t>(denseAttr[0])};
+  if (denseAttr.size() < 2)
+    return {true, 0, intercept};
+  const T slope{denseAttr[1] - denseAttr[0]};
+  for (size_t i = 1; i < denseAttr.size(); ++i)
+    if (denseAttr[i] - denseAttr[i - 1] != slope)
+      return {false, 0, 0};
+  return {true, static_cast<int32_t>(slope), intercept};
+}
+
+VectorType encodeVectorTypeTo(VectorType currentVecType, Type toElemType) {
+  auto elemType = currentVecType.getElementType();
+  auto currentBitWidth = elemType.getIntOrFloatBitWidth();
+  auto newBitWidth = toElemType.getIntOrFloatBitWidth();
+  const int size =
+      currentVecType.getNumElements() * currentBitWidth / newBitWidth;
+  return VectorType::get(size, toElemType);
+}
+
+xevm::LoadCacheControl
+translateLoadXeGPUCacheHint(std::optional<xegpu::CachePolicy> L1hint,
+                            std::optional<xegpu::CachePolicy> L3hint) {
+  auto L1hintVal =
+      L1hint.has_value() ? L1hint.value() : xegpu::CachePolicy::UNCACHED;
+  auto L3hintVal =
+      L3hint.has_value() ? L3hint.value() : xegpu::CachePolicy::UNCACHED;
+  switch (L1hintVal) {
+  case xegpu::CachePolicy::CACHED:
+    if (L3hintVal == xegpu::CachePolicy::CACHED)
+      return xevm::LoadCacheControl::L1C_L2UC_L3C;
+    else if (L3hintVal == xegpu::CachePolicy::UNCACHED)
+      return xevm::LoadCacheControl::L1C_L2UC_L3UC;
+    else
+      llvm_unreachable("Unsupported cache control.");
+  case xegpu::CachePolicy::UNCACHED:
+    if (L3hintVal == xegpu::CachePolicy::CACHED)
+      return xevm::LoadCacheControl::L1UC_L2UC_L3C;
+    else if (L3hintVal == xegpu::CachePolicy::UNCACHED)
+      return xevm::LoadCacheControl::L1UC_L2UC_L3UC;
+    else
+      llvm_unreachable("Unsupported cache control.");
+  case xegpu::CachePolicy::STREAMING:
+    if (L3hintVal == xegpu::CachePolicy::CACHED)
+      return xevm::LoadCacheControl::L1S_L2UC_L3C;
+    else if (L3hintVal == xegpu::CachePolicy::UNCACHED)
+      return xevm::LoadCacheControl::L1S_L2UC_L3UC;
+    else
+      llvm_unreachable("Unsupported cache control.");
+  case xegpu::CachePolicy::READ_INVALIDATE:
+    return xevm::LoadCacheControl::INVALIDATE_READ;
+  default:
+    llvm_unreachable("Unsupported cache control.");
+  }
+}
+
+xevm::StoreCacheControl
+translateStoreXeGPUCacheHint(std::optional<xegpu::CachePolicy> L1hint,
+                             std::optional<xegpu::CachePolicy> L3hint) {
+  auto L1hintVal =
+      L1hint.has_value() ? L1hint.value() : xegpu::CachePolicy::UNCACHED;
+  auto L3hintVal =
+      L3hint.has_value() ? L3hint.value() : xegpu::CachePolicy::UNCACHED;
+  switch (L1hintVal) {
+  case xegpu::CachePolicy::UNCACHED:
+    if (L3hintVal == xegpu::CachePolicy::UNCACHED)
+      return xevm::StoreCacheControl::L1UC_L2UC_L3UC;
+    else if (L3hintVal == xegpu::CachePolicy::WRITE_BACK)
+      return xevm::StoreCacheControl::L1UC_L2UC_L3WB;
+    else
+      llvm_unreachable("Unsupported cache control.");
+  case xegpu::CachePolicy::STREAMING:
+    if (L3hintVal == xegpu::CachePolicy::UNCACHED)
+      return xevm::StoreCacheControl::L1S_L2UC_L3UC;
+    else if (L3hintVal == xegpu::CachePolicy::WRITE_BACK)
+      return xevm::StoreCacheControl::L1S_L2UC_L3WB;
+    else
+      llvm_unreachable("Unsupported cache control.");
+  case xegpu::CachePolicy::WRITE_BACK:
+    if (L3hintVal == xegpu::CachePolicy::UNCACHED)
+      return xevm::StoreCacheControl::L1WB_L2UC_L3UC;
+    else if (L3hintVal == xegpu::CachePolicy::WRITE_BACK)
+      return xevm::StoreCacheControl::L1WB_L2UC_L3WB;
+    else
+      llvm_unreachable("Unsupported cache control.");
+  case xegpu::CachePolicy::WRITE_THROUGH:
+    if (L3hintVal == xegpu::CachePolicy::UNCACHED)
+      return xevm::StoreCacheControl::L1WT_L2UC_L3UC;
+    else if (L3hintVal == xegpu::CachePolicy::WRITE_BACK)
+      return xevm::StoreCacheControl::L1WT_L2UC_L3WB;
+    else
+      llvm_unreachable("Unsupported cache control.");
+  default:
+    llvm_unreachable("Unsupported cache control.");
+  }
+}
+
+class CreateNdDescToXeVMPattern
+    : public OpConversionPattern<xegpu::CreateNdDescOp> {
+  using OpConversionPattern::OpConversionPattern;
+  LogicalResult
+  matchAndRewrite(xegpu::CreateNdDescOp op,
+                  xegpu::CreateNdDescOp::Adaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    auto loc = op.getLoc();
+    auto source = op.getSource();
+    Type payloadElemTy = rewriter.getI32Type();
+    Type i64Ty = rewriter.getI64Type();
+    VectorType payloadTy = VectorType::get(8, payloadElemTy);
+    VectorType payloadI64Ty = VectorType::get(4, i64Ty);
+    Value payload = arith::ConstantOp::create(
+        rewriter, loc,
+        DenseElementsAttr::get(payloadTy, IntegerAttr::get(payloadElemTy, 0)));
+
+    Value baseAddr;
+    Value baseShapeW;
+    Value baseShapeH;
+    Value offsetW;
+    Value offsetH;
+
+    bool sourceIsMemref = false;
+    auto sourceTy = source.getType();
+    int64_t rank;
+    if (isa<MemRefType>(sourceTy)) {
+      sourceIsMemref = true;
+      baseAddr =
+          memref::ExtractAlignedPointerAsIndexOp::create(rewriter, loc, source);
+      auto sourceMemrefTy = cast<MemRefType>(sourceTy);
+      if (!sourceMemrefTy.hasStaticShape()) {
+        op.emitError() << "Expected static memref shape.";
+        return failure();
+      }
+      rank = sourceMemrefTy.getRank();
+      if (rank != 2) {
+        op.emitError() << "Expected a 2D memref.";
+        return failure();
+      }
+    } else if (sourceTy == rewriter.getIntegerType(64, false)) {
+      rank = op.getMixedSizes().size();
+    } else {
+      op.emitError() << "Expected source to be a 2D memref or ui64.";
+      return failure();
+    }
+    auto createOffset = [&](unsigned idx) -> Value {
+      Value val;
+      OpFoldResult ofr = op.getMixedOffsets()[idx];
+      if (auto v = llvm::dyn_cast_if_present<Value>(ofr)) {
+        val = arith::IndexCastOp::create(rewriter, loc, i64Ty, v);
+        val = arith::TruncIOp::create(rewriter, loc, payloadElemTy, val);
+      } else {
+        int32_t off = llvm::cast<IntegerAttr>(cast<Attribute>(ofr)).getInt();
+        val = arith::ConstantIntOp::create(rewriter, loc, payloadElemTy, off);
+      }
+      return val;
+    };
+    auto offsets = op.getMixedOffsets();
+    if (offsets.size() == 2) {
+      offsetW = createOffset(rank - 1);
+      offsetH = createOffset(rank - 2);
+    } else {
+      offsetW = arith::ConstantIntOp::create(rewriter, loc, payloadElemTy, 0);
+      offsetH = arith::ConstantIntOp::create(rewriter, loc, payloadElemTy, 0);
+    }
+    auto createShape = [&](unsigned idx) -> Value {
+      Value val;
+      OpFoldResult ofr = op.getMixedSizes()[idx];
+      if (auto v = llvm::dyn_cast_if_present<Value>(ofr)) {
+        val = arith::IndexCastOp::create(rewriter, loc, i64Ty, v);
+        val = arith::TruncIOp::create(rewriter, loc, payloadElemTy, val);
+      } else {
+        int32_t off = llvm::cast<IntegerAttr>(cast<Attribute>(ofr)).getInt();
+        val = arith::ConstantIntOp::create(rewriter, loc, payloadElemTy, off);
+      }
+      return val;
+    };
+    if (sourceIsMemref) {
+      auto sourceMemrefTy = cast<MemRefType>(sourceTy);
+      baseShapeW = arith::ConstantIntOp::create(
+          rewriter, loc, payloadElemTy, sourceMemrefTy.getDimSize(rank - 1));
+      baseShapeH = arith::ConstantIntOp::create(
+          rewriter, loc, payloadElemTy, sourceMemrefTy.getDimSize(rank - 2));
+      baseAddr = arith::IndexCastUIOp::create(rewriter, loc, i64Ty, baseAddr);
+    } else {
+      baseShapeW = createShape(rank - 1);
+      baseShapeH = createShape(rank - 2);
+      baseAddr = adaptor.getSource();
+    }
+    Value payLoadAsI64 =
+        vector::BitCastOp::create(rewriter, loc, payloadI64Ty, payload);
+    payLoadAsI64 =
+        vector::InsertOp::create(rewriter, loc, baseAddr, payLoadAsI64,
+                                 static_cast<int>(NdDescI32Layout::BasePtr));
+    payload = vector::BitCastOp::create(rewriter, loc, payloadTy, payLoadAsI64);
+    payload =
+        vector::InsertOp::create(rewriter, loc, baseShapeW, payload,
+                                 static_cast<int>(NdDescI32Layout::BaseShapeW));
+    payload =
+        vector::InsertOp::create(rewriter, loc, baseShapeH, payload,
+                                 static_cast<int>(NdDescI32Layout::BaseShapeH));
+    payload = vector::InsertOp::create(
+        rewriter, loc, offsetW, payload,
+        static_cast<int>(NdDescI32Layout::TensorOffsetW));
+    payload = vector::InsertOp::create(
+        rewriter, loc, offsetH, payload,
+        static_cast<int>(NdDescI32Layout::TensorOffsetH));
+    rewriter.replaceOp(op, payload);
+    return success();
+  }
+};
+
+class UpdateNdOffsetToXeVMPattern
+    : public OpConversionPattern<xegpu::UpdateNdOffsetOp> {
+  using OpConversionPattern::OpConversionPattern;
+  LogicalResult
+  matchAndRewrite(xegpu::UpdateNdOffsetOp op,
+                  xegpu::UpdateNdOffsetOp::Adaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    auto loc = op.getLoc();
+    auto offsets = op.getOffsets();
+    auto tdesc = adaptor.getTensorDesc();
+    for (size_t offsetDim = 0; offsetDim < offsets.size(); offsetDim++) {
+      auto offset = offsets[offsetDim];
+      if (auto cst =
+              dyn_cast_if_present<arith::ConstantOp>(offset.getDefiningOp()))
+        if (auto attr = dyn_cast_if_present<IntegerAttr>(cst.getValue());
+            attr && !attr.getInt())
+          continue;
+      const int offsetPos =
+          static_cast<int>(offsetDim ? NdDescI32Layout::TensorOffsetW
+                                     : NdDescI32Layout::TensorOffsetH);
+      auto oldOffset =
+          vector::ExtractOp::create(rewriter, loc, tdesc, offsetPos);
+      offset = arith::IndexCastUIOp::create(rewriter, loc,
+                                            rewriter.getI32Type(), offset);
+      auto newOffset = arith::AddIOp::create(rewriter, loc, oldOffset, offset);
+      tdesc =
+          vector::InsertOp::create(rewriter, loc, newOffset, tdesc, offsetPos);
+    }
+    rewriter.replaceOp(op, tdesc);
+    return success();
+  }
+};
+
+template <
+    typename OpType,
+    typename = std::enable_if_t<llvm::is_one_of<
+        OpType, xegpu::LoadNdOp, xegpu::StoreNdOp, xegpu::PrefetchNdOp>::value>>
+class LoadStorePrefetchNdToXeVMPattern : public OpConversionPattern<OpType> {
+  using OpConversionPattern<OpType>::OpConversionPattern;
+  LogicalResult
+  matchAndRewrite(OpType op, typename OpType::Adaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    auto loc = op.getLoc();
+    auto ctxt = rewriter.getContext();
+
+    auto tdesc = adaptor.getTensorDesc();
+    auto tdescTy = op.getTensorDescType();
+    if (tdescTy.getRank() != 2) {
+      return rewriter.notifyMatchFailure(op, "Expected 2D tensor descriptor.");
+    }
+
+    VectorType payloadI64Ty = VectorType::get(4, rewriter.getI64Type());
+    Value payLoadAsI64 =
+        vector::BitCastOp::create(rewriter, loc, payloadI64Ty, tdesc);
+    Value basePtr =
+        vector::ExtractOp::create(rewriter, loc, payLoadAsI64,
+                                  static_cast<int>(NdDescI32Layout::BasePtr));
+    Value baseShapeW = vector::ExtractOp::create(
+        rewriter, loc, tdesc, static_cast<int>(NdDescI32Layout::BaseShapeW));
+    Value baseShapeH = vector::ExtractOp::create(
+        rewriter, loc, tdesc, static_cast<int>(NdDescI32Layout::BaseShapeH));
+    // Offsets can come from three sources:
+    // 1. Constant offsets, which are provided by the op.
+    // 2. Offsets as operands, which are provided by the op.
+    // 3. Offsets extracted from the tensor descriptor.
+    Value offsetW;
+    Value offsetH;
+    auto cOffsets = op.getConstOffsets();
+    auto offsets = op.getOffsets();
+    if (cOffsets) {
+      offsetW = arith::ConstantIntOp::create(
+          rewriter, loc, rewriter.getI32Type(), (*cOffsets)[0]);
+      offsetH = arith::ConstantIntOp::create(
+          rewriter, loc, rewriter.getI32Type(), (*cOffsets)[1]);
+    } else if (offsets.size() != 0) {
+      // offsets are provided as operands
+      if (offsets[0].getType() != rewriter.getI32Type()) {
+        if (offsets[0].getType() != rewriter.getIndexType()) {
+          return rewriter.notifyMatchFailure(
+              op, "Expected offsets to be of type i32 or index.");
+        }
+        offsetW = arith::IndexCastUIOp::create(
+            rewriter, loc, rewriter.getI32Type(), offsets[0]);
+      } else {
+        offsetW = offsets[0];
+      }
+      if (offsets[1].getType() != rewriter.getI32Type()) {
+        if (offsets[1].getType() != rewriter.getIndexType()) {
+          return rewriter.notifyMatchFailure(
+              op, "Expected offsets to be of type i32 or index.");
+        }
+        offsetH = arith::IndexCastUIOp::create(
+            rewriter, loc, rewriter.getI32Type(), offsets[1]);
+      } else {
+        offsetH = offsets[1];
+      }
+    } else {
+      // If offsets are not available, we need to extract them from the tensor
+      // descriptor.
+      offsetW = vector::ExtractOp::create(
+          rewriter, loc, tdesc,
+          static_cast<int>(NdDescI32Layout::TensorOffsetW));
+      offsetH = vector::ExtractOp::create(
+          rewriter, loc, tdesc,
+          static_cast<int>(NdDescI32Layout::TensorOffsetH));
+    }
+    auto ptrTypeLLVM = LLVM::LLVMPointerType::get(
+        ctxt, getNumericXeVMAddrSpace(tdescTy.getMemorySpace()));
+    Value basePtrLLVM =
+        LLVM::IntToPtrOp::create(rewriter, loc, ptrTypeLLVM, basePtr);
+    auto elemType = tdescTy.getElementType();
+    auto elemBitSize = elemType.getIntOrFloatBitWidth();
+    // auto elemBitSizeAttr = rewriter.getIntegerAttr(rewriter.getI32Type(),
+    // elemBitSize);
+    Value elemByteSize = arith::ConstantIntOp::create(
+        rewriter, loc, rewriter.getI32Type(), elemBitSize / 8);
+    Value surfaceW =
+        arith::MulIOp::create(rewriter, loc, baseShapeW, elemByteSize);
+
+    auto tileW = tdescTy.getDimSize(1);
+    auto tileH = tdescTy.getDimSize(0);
+    int32_t vblocks = tdescTy.getArrayLength();
+    if constexpr (std::is_same_v<OpType, xegpu::StoreNdOp>) {
+      VectorType srcVecTy = cast<VectorType>(op.getValue().getType());
+      auto storeCacheControl =
+          translateStoreXeGPUCacheHint(op.get...
[truncated]

``````````

</details>


https://github.com/llvm/llvm-project/pull/154556


More information about the Mlir-commits mailing list