[Mlir-commits] [mlir] [MLIR][Conversion] Convert XeGPU to XeVM pass: Remove lowering support for tensor descriptor with offsets. (PR #157550)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Mon Sep 8 13:50:18 PDT 2025


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir

Author: Sang Ik Lee (silee2)

<details>
<summary>Changes</summary>

Also update the load/store/prefetch test cases to use direct offsets,
because tensor descriptors with offsets are being deprecated.

---

Patch is 38.15 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/157550.diff


5 Files Affected:

- (modified) mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp (+34-136) 
- (modified) mlir/test/Conversion/XeGPUToXeVM/create_nd_tdesc.mlir (-32) 
- (modified) mlir/test/Conversion/XeGPUToXeVM/loadstoreprefetch.mlir (+38-206) 
- (modified) mlir/test/Conversion/XeGPUToXeVM/materializecast.mlir (+13-25) 
- (removed) mlir/test/Conversion/XeGPUToXeVM/update_offset.mlir (-25) 


``````````diff
diff --git a/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp b/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp
index a7f2dc2d6a43e..7d756620ecd16 100644
--- a/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp
+++ b/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp
@@ -154,6 +154,9 @@ class CreateNdDescToXeVMPattern
   matchAndRewrite(xegpu::CreateNdDescOp op,
                   xegpu::CreateNdDescOp::Adaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
+    SmallVector<OpFoldResult> mixedOffsets = op.getMixedOffsets();
+    if (mixedOffsets.size() != 0)
+      return rewriter.notifyMatchFailure(op, "Offsets not supported.");
     auto loc = op.getLoc();
     auto source = op.getSource();
     // Op is lowered to a code sequence that populates payload.
@@ -177,7 +180,6 @@ class CreateNdDescToXeVMPattern
 
     // Source can be a memref or a pointer (ui64, ui32, i64 or i32).
     SmallVector<OpFoldResult> mixedSizes = op.getMixedSizes();
-    SmallVector<OpFoldResult> mixedOffsets = op.getMixedOffsets();
     // Descriptor shape is expected to be 2D.
     int64_t rank = mixedSizes.size();
     if (rank != 2)
@@ -202,17 +204,9 @@ class CreateNdDescToXeVMPattern
       val = getValueOrCreateCastToIndexLike(rewriter, loc, payloadElemTy, val);
       return val;
     };
-    // Offsets can be either 2D or not provided (0 is used).
-    if (mixedOffsets.size() == 2) {
-      offsetW = createOffset(mixedOffsets, 1);
-      offsetH = createOffset(mixedOffsets, 0);
-    } else if (mixedOffsets.size() == 0) {
-      offsetW = arith::ConstantIntOp::create(rewriter, loc, payloadElemTy, 0);
-      offsetH = arith::ConstantIntOp::create(rewriter, loc, payloadElemTy, 0);
-    } else {
-      return rewriter.notifyMatchFailure(op,
-                                         "Expected 2D offsets or no offsets.");
-    }
+    // Offsets are not supported not (0 is used).
+    offsetW = arith::ConstantIntOp::create(rewriter, loc, payloadElemTy, 0);
+    offsetH = arith::ConstantIntOp::create(rewriter, loc, payloadElemTy, 0);
     // Get shape values from op fold results.
     baseShapeW = createOffset(mixedSizes, 1);
     baseShapeH = createOffset(mixedSizes, 0);
@@ -247,39 +241,6 @@ class CreateNdDescToXeVMPattern
   }
 };
 
-class UpdateNdOffsetToXeVMPattern
-    : public OpConversionPattern<xegpu::UpdateNdOffsetOp> {
-  using OpConversionPattern::OpConversionPattern;
-  LogicalResult
-  matchAndRewrite(xegpu::UpdateNdOffsetOp op,
-                  xegpu::UpdateNdOffsetOp::Adaptor adaptor,
-                  ConversionPatternRewriter &rewriter) const override {
-    auto loc = op.getLoc();
-    auto mixedOffsets = op.getMixedOffsets();
-    // Only 2D offsets are supported for now.
-    if (mixedOffsets.size() != 2)
-      return rewriter.notifyMatchFailure(op, "Expected 2D offsets.");
-    auto payload = adaptor.getTensorDesc();
-    // Utility for updating payload offset values from op fold result.
-    auto updateOffset = [&](unsigned idx, int payloadPos) -> Value {
-      Value offset =
-          getValueOrCreateConstantIntOp(rewriter, loc, mixedOffsets[idx]);
-      offset = getValueOrCreateCastToIndexLike(rewriter, loc,
-                                               rewriter.getI32Type(), offset);
-      Value oldOffset =
-          vector::ExtractOp::create(rewriter, loc, payload, payloadPos);
-      Value newOffset = arith::AddIOp::create(rewriter, loc, oldOffset, offset);
-      return vector::InsertOp::create(rewriter, loc, newOffset, payload,
-                                      payloadPos);
-    };
-    // Update offsets in the payload.
-    payload = updateOffset(0, static_cast<int>(NdTdescOffset::TensorOffsetH));
-    payload = updateOffset(1, static_cast<int>(NdTdescOffset::TensorOffsetW));
-    rewriter.replaceOp(op, payload);
-    return success();
-  }
-};
-
 template <
     typename OpType,
     typename = std::enable_if_t<llvm::is_one_of<
@@ -289,6 +250,10 @@ class LoadStorePrefetchNdToXeVMPattern : public OpConversionPattern<OpType> {
   LogicalResult
   matchAndRewrite(OpType op, typename OpType::Adaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
+    auto mixedOffsets = op.getMixedOffsets();
+    int64_t opOffsetsSize = mixedOffsets.size();
+    if (opOffsetsSize != 2)
+      return rewriter.notifyMatchFailure(op, "Expected 2D offsets.");
     auto loc = op.getLoc();
     auto ctxt = rewriter.getContext();
 
@@ -311,32 +276,14 @@ class LoadStorePrefetchNdToXeVMPattern : public OpConversionPattern<OpType> {
         rewriter, loc, tdesc, static_cast<int>(NdTdescOffset::BaseShapeW));
     Value baseShapeH = vector::ExtractOp::create(
         rewriter, loc, tdesc, static_cast<int>(NdTdescOffset::BaseShapeH));
-    // Offsets provided in two ways:
-    // 1. Offsets are extracted from the tensor descriptor.
-    // 2. (Mixed) offsets which are provided by the op.
-    Value offsetW;
-    Value offsetH;
-    auto mixedOffsets = op.getMixedOffsets();
-    int64_t opOffsetsSize = mixedOffsets.size();
-    if (opOffsetsSize != 0 && opOffsetsSize != 2)
-      return rewriter.notifyMatchFailure(op,
-                                         "Expected 2D offsets or no offsets.");
-    if (opOffsetsSize) {
-      // If mixed offsets are provided by the op convert them to i32.
-      offsetW = getValueOrCreateConstantIntOp(rewriter, loc, mixedOffsets[1]);
-      offsetW = getValueOrCreateCastToIndexLike(rewriter, loc,
-                                                rewriter.getI32Type(), offsetW);
-      offsetH = getValueOrCreateConstantIntOp(rewriter, loc, mixedOffsets[0]);
-      offsetH = getValueOrCreateCastToIndexLike(rewriter, loc,
-                                                rewriter.getI32Type(), offsetH);
-    } else {
-      // If offsets are not available, we need to extract them from the tensor
-      // descriptor.
-      offsetW = vector::ExtractOp::create(
-          rewriter, loc, tdesc, static_cast<int>(NdTdescOffset::TensorOffsetW));
-      offsetH = vector::ExtractOp::create(
-          rewriter, loc, tdesc, static_cast<int>(NdTdescOffset::TensorOffsetH));
-    }
+    // Offsets are provided by the op.
+    // convert them to i32.
+    Value offsetW = getValueOrCreateConstantIntOp(rewriter, loc, mixedOffsets[1]);
+    offsetW = getValueOrCreateCastToIndexLike(rewriter, loc,
+                                             rewriter.getI32Type(), offsetW);
+    Value offsetH = getValueOrCreateConstantIntOp(rewriter, loc, mixedOffsets[0]);
+    offsetH = getValueOrCreateCastToIndexLike(rewriter, loc,
+                                              rewriter.getI32Type(), offsetH);
     // Get address space from tensor descriptor memory space.
     auto ptrTypeLLVM = LLVM::LLVMPointerType::get(
         ctxt, getNumericXeVMAddrSpace(tdescTy.getMemorySpace()));
@@ -422,54 +369,6 @@ static Value addOffset(ConversionPatternRewriter &rewriter, Location loc,
   return newAddr;
 }
 
-class CreateDescToXeVMPattern
-    : public OpConversionPattern<xegpu::CreateDescOp> {
-  using OpConversionPattern::OpConversionPattern;
-  LogicalResult
-  matchAndRewrite(xegpu::CreateDescOp op, xegpu::CreateDescOp::Adaptor adaptor,
-                  ConversionPatternRewriter &rewriter) const override {
-    auto eTy = op.getTensorDescType().getElementType();
-    auto eBw = eTy.getIntOrFloatBitWidth();
-    if (eBw % 8 != 0)
-      return rewriter.notifyMatchFailure(
-          op, "Expected element type bit width to be multiple of 8.");
-    auto loc = op.getLoc();
-    // Offsets are provided as scalar i64 by type converter.
-    auto offsets = adaptor.getOffsets();
-    // Source type can be a 1D memref or pointer type (ui64, ui32, i64 or i32).
-    // But type converter will convert them to integer types.
-    Value addr = adaptor.getSource();
-    // ui32 or i32 are passed as i32 so they need to be casted to i64.
-    if (addr.getType() != rewriter.getI64Type())
-      addr = arith::ExtUIOp::create(rewriter, loc, rewriter.getI64Type(), addr);
-    auto laneAddr = addOffset(rewriter, loc, addr, offsets, eBw / 8);
-    rewriter.replaceOp(op, laneAddr);
-    return success();
-  }
-};
-
-class UpdateOffsetToXeVMPattern
-    : public OpConversionPattern<xegpu::UpdateOffsetOp> {
-  using OpConversionPattern::OpConversionPattern;
-  LogicalResult
-  matchAndRewrite(xegpu::UpdateOffsetOp op,
-                  xegpu::UpdateOffsetOp::Adaptor adaptor,
-                  ConversionPatternRewriter &rewriter) const override {
-    auto eTy = op.getTensorDescType().getElementType();
-    auto eBw = eTy.getIntOrFloatBitWidth();
-    if (eBw % 8 != 0)
-      return rewriter.notifyMatchFailure(
-          op, "Expected element type bit width to be multiple of 8.");
-    auto loc = op.getLoc();
-    // Scatter descriptor is provided as scalar i64 by type converter.
-    // Offsets are provided as scalar i64 by type converter.
-    Value newOffset = addOffset(rewriter, loc, adaptor.getTensorDesc(),
-                                adaptor.getOffsets(), eBw / 8);
-    rewriter.replaceOp(op, newOffset);
-    return success();
-  }
-};
-
 template <typename OpType,
           typename = std::enable_if_t<llvm::is_one_of<
               OpType, xegpu::LoadGatherOp, xegpu::StoreScatterOp>::value>>
@@ -478,6 +377,9 @@ class LoadStoreToXeVMPattern : public OpConversionPattern<OpType> {
   LogicalResult
   matchAndRewrite(OpType op, typename OpType::Adaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
+    Value offsets = adaptor.getOffsets();
+    if (!offsets)
+      return rewriter.notifyMatchFailure(op, "Expected offsets to be provided.");
     auto loc = op.getLoc();
     auto ctxt = rewriter.getContext();
     auto tdescTy = op.getTensorDescType();
@@ -527,21 +429,18 @@ class LoadStoreToXeVMPattern : public OpConversionPattern<OpType> {
       basePtrI64 = arith::ExtUIOp::create(rewriter, loc, rewriter.getI64Type(),
                                           basePtrI64);
     }
-    Value offsets = adaptor.getOffsets();
     Value mask = adaptor.getMask();
-    if (offsets) {
-      if (dyn_cast<VectorType>(offsets.getType())) {
-        // Offset needs be scalar. Single element vector is converted to scalar
-        // by type converter.
-        return rewriter.notifyMatchFailure(op,
-                                           "Expected offsets to be a scalar.");
-      } else {
-        // If offsets are provided, we add them to the base pointer.
-        // Offsets are in number of elements, we need to multiply by
-        // element byte size.
-        basePtrI64 =
-            addOffset(rewriter, loc, basePtrI64, offsets, elemByteSize);
-      }
+    if (dyn_cast<VectorType>(offsets.getType())) {
+      // Offset needs be scalar. Single element vector is converted to scalar
+      // by type converter.
+      return rewriter.notifyMatchFailure(op,
+                                         "Expected offsets to be a scalar.");
+    } else {
+      // If offsets are provided, we add them to the base pointer.
+      // Offsets are in number of elements, we need to multiply by
+      // element byte size.
+      basePtrI64 =
+          addOffset(rewriter, loc, basePtrI64, offsets, elemByteSize);
     }
     // Convert base pointer (i64) to LLVM pointer type.
     Value basePtrLLVM =
@@ -1011,13 +910,12 @@ struct ConvertXeGPUToXeVMPass
 //===----------------------------------------------------------------------===//
 void mlir::populateXeGPUToXeVMConversionPatterns(
     const LLVMTypeConverter &typeConverter, RewritePatternSet &patterns) {
-  patterns.add<CreateNdDescToXeVMPattern, UpdateNdOffsetToXeVMPattern,
+  patterns.add<CreateNdDescToXeVMPattern,
                LoadStorePrefetchNdToXeVMPattern<xegpu::LoadNdOp>,
                LoadStorePrefetchNdToXeVMPattern<xegpu::StoreNdOp>,
                LoadStorePrefetchNdToXeVMPattern<xegpu::PrefetchNdOp>>(
       typeConverter, patterns.getContext());
-  patterns.add<CreateDescToXeVMPattern, UpdateOffsetToXeVMPattern,
-               AtomicRMWToXeVMPattern, PrefetchToXeVMPattern,
+  patterns.add<AtomicRMWToXeVMPattern, PrefetchToXeVMPattern,
                LoadStoreToXeVMPattern<xegpu::LoadGatherOp>,
                LoadStoreToXeVMPattern<xegpu::StoreScatterOp>>(
       typeConverter, patterns.getContext());
diff --git a/mlir/test/Conversion/XeGPUToXeVM/create_nd_tdesc.mlir b/mlir/test/Conversion/XeGPUToXeVM/create_nd_tdesc.mlir
index ed664a739d134..d6e36fa73bf04 100644
--- a/mlir/test/Conversion/XeGPUToXeVM/create_nd_tdesc.mlir
+++ b/mlir/test/Conversion/XeGPUToXeVM/create_nd_tdesc.mlir
@@ -43,38 +43,6 @@ gpu.module @create_nd_tdesc {
         // CHECK: %[[VAR19:.*]] = vector.insert %[[OFFSET_W2]], %[[VAR18]] [4] : i32 into vector<8xi32>
         // CHECK: %[[PAYLOAD:.*]] = vector.insert %[[OFFSET_H2]], %[[VAR19]] [5] : i32 into vector<8xi32>
         %src_tdesc = xegpu.create_nd_tdesc %srcce : memref<16x32xf32> -> !xegpu.tensor_desc<8x16xf32>
-
-        // CHECK: %[[CST_4:.*]] = arith.constant dense<0> : vector<8xi32>
-        // CHECK: %[[INTPTR_2:.*]] = memref.extract_aligned_pointer_as_index %[[MEMSPACECAST]] : memref<16x32xf32> -> index
-        // CHECK: %[[OFFSET_W3:.*]] = arith.index_cast %[[ARG7]] : index to i32
-        // CHECK: %[[OFFSET_H3:.*]] = arith.index_cast %[[ARG6]] : index to i32
-        // CHECK: %[[C32_I64_6:.*]] = arith.constant 32 : i64
-        // CHECK: %[[SHAPE_W3:.*]] = arith.trunci %[[C32_I64_6]] : i64 to i32
-        // CHECK: %[[C16_I64_7:.*]] = arith.constant 16 : i64
-        // CHECK: %[[SHAPE_H3:.*]] = arith.trunci %[[C16_I64_7]] : i64 to i32
-        // CHECK: %[[BASE_ADDR3:.*]] = arith.index_castui %[[INTPTR_2]] : index to i64
-        // CHECK: %[[VAR26:.*]] = vector.bitcast %[[CST_4]] : vector<8xi32> to vector<4xi64>
-        // CHECK: %[[VAR27:.*]] = vector.insert %[[BASE_ADDR3]], %[[VAR26]] [0] : i64 into vector<4xi64>
-        // CHECK: %[[VAR28:.*]] = vector.bitcast %[[VAR27]] : vector<4xi64> to vector<8xi32>
-        // CHECK: %[[VAR29:.*]] = vector.insert %[[SHAPE_W3]], %[[VAR28]] [2] : i32 into vector<8xi32>
-        // CHECK: %[[VAR30:.*]] = vector.insert %[[SHAPE_H3]], %[[VAR29]] [3] : i32 into vector<8xi32>
-        // CHECK: %[[VAR31:.*]] = vector.insert %[[OFFSET_W3]], %[[VAR30]] [4] : i32 into vector<8xi32>
-        // CHECK: %[[VAR32:.*]] = vector.insert %[[OFFSET_H3]], %[[VAR31]] [5] : i32 into vector<8xi32>
-        %src_tdesc2 = xegpu.create_nd_tdesc %srcce[%offset1, %offset2] : memref<16x32xf32> -> !xegpu.tensor_desc<8x16xf32>
-
-        // CHECK: %[[C8:.*]] = arith.constant 8 : index
-        %c8 = arith.constant 8 : index
-        // CHECK: %[[C16:.*]] = arith.constant 16 : index
-        %c16 = arith.constant 16 : index
-        // CHECK: %[[VAR33:.*]] = arith.index_cast %[[C8]] : index to i32
-        // CHECK: %[[OLD_OFFSET_H:.*]] = vector.extract %[[PAYLOAD]][5] : i32 from vector<8xi32>
-        // CHECK: %[[NEW_OFFSET_H:.*]] = arith.addi %[[OLD_OFFSET_H]], %[[VAR33]] : i32
-        // CHECK: %[[NEW_PAYLOAD:.*]] = vector.insert %[[NEW_OFFSET_H]], %[[PAYLOAD]] [5] : i32 into vector<8xi32>
-        // CHECK: %[[VAR37:.*]] = arith.index_cast %[[C16]] : index to i32
-        // CHECK: %[[OLD_OFFSET_H:.*]] = vector.extract %[[NEW_PAYLOAD]][4] : i32 from vector<8xi32>
-        // CHECK: %[[NEW_OFFSET_H:.*]] = arith.addi %[[OLD_OFFSET_H]], %[[VAR37]] : i32
-        // CHECK: %[[FINAL_PAYLOAD:.*]] = vector.insert %[[NEW_OFFSET_H]], %[[NEW_PAYLOAD]] [4] : i32 into vector<8xi32>
-        %updated_tdesc = xegpu.update_nd_offset %src_tdesc, [%c8, %c16] : !xegpu.tensor_desc<8x16xf32>
         gpu.return
     }
 }
diff --git a/mlir/test/Conversion/XeGPUToXeVM/loadstoreprefetch.mlir b/mlir/test/Conversion/XeGPUToXeVM/loadstoreprefetch.mlir
index 0f67dc290689b..0b150e9d58153 100644
--- a/mlir/test/Conversion/XeGPUToXeVM/loadstoreprefetch.mlir
+++ b/mlir/test/Conversion/XeGPUToXeVM/loadstoreprefetch.mlir
@@ -1,239 +1,73 @@
 // RUN: mlir-opt %s --split-input-file -convert-xegpu-to-xevm | FileCheck %s
 
 gpu.module @test {
-// CHECK-LABEL: @load_gather_ui64_src_constant_offset
-// CHECK-SAME: %[[ARG0:.*]]: ui64
-gpu.func @load_gather_ui64_src_constant_offset(%src: ui64) {
-  // CHECK: %[[VAR0:.*]] = index.castu %[[ARG0]] : ui64 to index
-  // CHECK: %[[VAR1:.*]] = arith.index_castui %[[VAR0]] : index to i64
-  // CHECK: %[[CST:.*]] = arith.constant dense<0> : vector<1xindex>
-  // CHECK: %[[VAR2:.*]] = vector.extract %[[CST]][0] : index from vector<1xindex>
-  // CHECK: %[[VAR3:.*]] = arith.index_castui %[[VAR2]] : index to i64
-  %0 = arith.constant dense<0> : vector<1xindex>
-  // CHECK: %[[CST_0:.*]] = arith.constant dense<true> : vector<1xi1>
-  // CHECK: %[[VAR4:.*]] = vector.extract %[[CST_0]][0] : i1 from vector<1xi1>
-  %1 = arith.constant dense<1>: vector<1xi1>
-  // CHECK: %[[C4_I64:.*]] = arith.constant 4 : i64
-  // CHECK: %[[VAR5:.*]] = arith.muli %[[VAR3]], %[[C4_I64]] : i64
-  // CHECK: %[[VAR6:.*]] = arith.addi %[[VAR1]], %[[VAR5]] : i64
-  %2 = xegpu.create_tdesc %src, %0 : ui64, vector<1xindex>
-      -> !xegpu.tensor_desc<1x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2>>
-  // CHECK: %[[VAR7:.*]] = llvm.inttoptr %[[VAR6]] : i64 to !llvm.ptr<1>
-  // CHECK: %[[VAR8:.*]] = scf.if %[[VAR4]] -> (vector<2xf32>) {
-  // CHECK:      %[[VAR9:.*]] = llvm.load %[[VAR7]] {cache_control = #xevm.load_cache_control<L1c_L2uc_L3uc>}
-  // CHECK-SAME:     : !llvm.ptr<1> -> vector<2xf32>
-  // CHECK:      scf.yield %[[VAR9]] : vector<2xf32>
-  // CHECK:    } else {
-  // CHECK:      %[[CST_1:.*]] = arith.constant dense<0.000000e+00> : vector<2xf32>
-  // CHECK:      scf.yield %[[CST_1]] : vector<2xf32>
-  %3 = xegpu.load %2, %1 <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}>
-      : !xegpu.tensor_desc<1x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2>>, vector<1xi1> -> vector<2xf32>
-  gpu.return
-}
-}
-// -----
-
-gpu.module @test {
-// CHECK-LABEL: @load_gather_memref_src_constant_offset
-// CHECK-SAME: %[[ARG0:.*]]: memref<256xf32>
-gpu.func @load_gather_memref_src_constant_offset(%src: memref<256xf32>) {
-  // CHECK: %[[INTPTR:.*]] = memref.extract_aligned_pointer_as_index %[[ARG0]] : memref<256xf32> -> index
-  // CHECK: %[[VAR3:.*]] = arith.index_castui %[[INTPTR]] : index to i64
-  // CHECK: %[[CST:.*]] = arith.constant dense<0> : vector<1xindex>
-  // CHECK: %[[VAR0:.*]] = vector.extract %[[CST]][0] : index from vector<1xindex>
-  // CHECK: %[[VAR1:.*]] = arith.index_castui %[[VAR0]] : index to i64
-  %0 = arith.constant dense<0> : vector<1xindex>
-  // CHECK: %[[CST_0:.*]] = arith.constant dense<true> : vector<1xi1>
-  // CHECK: %[[VAR2:.*]] = vector.extract %[[CST_0]][0] : i1 from vector<1xi1>
-  %1 = arith.constant dense<1>: vector<1xi1>
-  // CHECK: %[[C4_I64:.*]] = arith.constant 4 : i64
-  // CHECK: %[[VAR4:.*]] = arith.muli %[[VAR1]], %[[C4_I64]] : i64
-  // CHECK: %[[VAR5:.*]] = arith.addi %[[VAR3]], %[[VAR4]] : i64
-  %2 = xegpu.create_tdesc %src, %0 : memref<256xf32>, vector<1xindex>
-      -> !xegpu.tensor_desc<1xf32, #xegpu.scatter_tdesc_attr<>>
-  // CHECK: %[[VAR6:.*]] = llvm.inttoptr %[[VAR5]] : i64 to !llvm.ptr<1>
-  // CHECK: %[[VAR7:.*]] = scf.if %[[VAR2]] -> (f32) {
-  // CHECK:      %[[VAR8:.*]] = llvm.load %[[VAR6]] {cache_control = #xevm.load_cache_control<L1c_L2uc_L3uc>}
-  // CHECK-SAME:     : !llvm.ptr<1> -> vector<1xf32>
-  // CHECK:      %[[VAR9:.*]] = vector.extract %[[VAR8]][0] : f32 from vector<1xf32>
-  // CHECK:      scf.yield %[[VAR9]] : f32
-  // CHECK:    } else {
-  // CHECK:      %[[CST_1:.*]] = arith.constant dense<0.000000e+00> : vector<1xf32>
-  // CHECK:      %[[VAR8:.*]] = vector.extract %[[CST_1:.*]][0] : f32 from vector<1xf32>
-  // CHECK:      scf.yield %[[VAR8]] : f32
-  %3 = xegpu.load %2, %1 <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}>
-      : !xegpu.tensor_desc<1xf32, #xegpu.scatter_tdesc_attr<>>, vector<1xi1> -> vector<1xf32>
-  gpu.return
-}
-}
-// -----
-
-gpu.module @test {
-// CHECK-LABEL: @load_gather_memref_src_value_offset
-// CHECK-SAME: %[[ARG0:.*]]: memref<256xf16>, %[[ARG1:.*]]: vector<1xindex>
-gpu.func @load_gather_memref_src_value_offset(%src: memref<256xf16>, %offset: vector<1xindex>) {
+// CHECK-LABEL: @load_gather_i64_src_value_offset
+// CHECK-SAME: %[[ARG0:.*]]: i64, %[[ARG1:.*]]: vector<1xindex>
+gpu.func @load_gather_i64_src_value_offset(%src: i64, %offset: ve...
[truncated]

``````````

</details>


https://github.com/llvm/llvm-project/pull/157550


More information about the Mlir-commits mailing list