[Mlir-commits] [mlir] [MLIR][XeGPU] Add unroll patterns for scatter ops (PR #143602)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Tue Jun 10 13:52:58 PDT 2025


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir-gpu

Author: Jianhui Li (Jianhui-Li)

<details>
<summary>Changes</summary>

Add unrolling support for the XeGPU scatter ops: create_tdesc, load (gather), store (scatter), prefetch, and update_offset.

---
Full diff: https://github.com/llvm/llvm-project/pull/143602.diff


3 Files Affected:

- (modified) mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp (+189-2) 
- (modified) mlir/test/Dialect/XeGPU/xegpu-unroll-patterns.mlir (+143) 
- (modified) mlir/test/lib/Dialect/XeGPU/TestXeGPUTransforms.cpp (+23) 


``````````diff
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp
index 885477fe4cbd5..900ade8c171d5 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp
@@ -396,11 +396,198 @@ struct UnrollDpasOp : public UnrollPattern<xegpu::DpasOp> {
   }
 };
 
+struct UnrollCreateDescOp : public UnrollPattern<xegpu::CreateDescOp> {
+  using UnrollPattern<xegpu::CreateDescOp>::UnrollPattern;
+
+  /// Splits a create_tdesc over a scattered tensor descriptor into several
+  /// create_tdesc ops, one per `targetShape`-sized slice of the offsets
+  /// vector, all sharing the original source, and reassembles the resulting
+  /// sub-descriptors into the original descriptor type.
+  LogicalResult matchAndRewrite(xegpu::CreateDescOp op,
+                                PatternRewriter &rewriter) const override {
+    Location loc = op.getLoc();
+    xegpu::TensorDescType tdescTy = op.getType();
+
+    // The offsets are sliced with the same target shape as the descriptor.
+    // Assumes (as in the tests) that the offsets vector is 1-D. A rank > 1
+    // descriptor (e.g. one with chunk_size > 1) would yield a target shape
+    // whose rank does not match the offsets, so such descriptors are not
+    // handled by this pattern.
+    if (tdescTy.getRank() > 1)
+      return failure();
+
+    std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op);
+    if (!targetShape)
+      return failure();
+
+    // Every piece has the same (first) unrolled descriptor type.
+    auto newTdescTy = getUnrolledTypes(tdescTy, *targetShape)[0];
+
+    TypedValue<::mlir::VectorType> indiceVec = op.getOffsets();
+    VectorType indiceVecTy = indiceVec.getType();
+    SmallVector<Type> convertedIndiceTypes =
+        getUnrolledTypes(indiceVecTy, *targetShape);
+    SmallVector<Value> convertedIndiceVec =
+        pack(indiceVec, convertedIndiceTypes, *targetShape, loc, rewriter);
+
+    SmallVector<Value> newOps;
+    newOps.reserve(convertedIndiceVec.size());
+    for (Value indice : convertedIndiceVec) {
+      auto newOp = rewriter.create<xegpu::CreateDescOp>(loc, newTdescTy,
+                                                        op.getSource(), indice);
+      newOps.push_back(newOp);
+    }
+
+    Value castOp = unpack(newOps, tdescTy, *targetShape, loc, rewriter);
+    rewriter.replaceOp(op, castOp);
+
+    return success();
+  }
+};
+
+struct UnrollLoadGatherOp : public UnrollPattern<xegpu::LoadGatherOp> {
+  using UnrollPattern<xegpu::LoadGatherOp>::UnrollPattern;
+
+  /// Splits a gather load through a scattered tensor descriptor into
+  /// `targetShape`-sized loads. The descriptor and the mask are sliced in
+  /// lockstep, and the partial results are reassembled into the original
+  /// result type.
+  LogicalResult matchAndRewrite(xegpu::LoadGatherOp op,
+                                PatternRewriter &rewriter) const override {
+    Location loc = op.getLoc();
+    xegpu::TensorDescType tdescTy = op.getTensorDescType();
+
+    // The mask is sliced with the same target shape as the descriptor.
+    // Assumes (as in the tests) a 1-D mask. A rank > 1 descriptor (e.g. one
+    // with chunk_size > 1) would yield a target shape whose rank does not
+    // match the mask, so such descriptors are not handled by this pattern.
+    if (tdescTy.getRank() > 1)
+      return failure();
+
+    // dyn_cast may return null; bail out instead of dereferencing it below.
+    VectorType valueTy = llvm::dyn_cast<VectorType>(op.getValue().getType());
+    VectorType maskTy = llvm::dyn_cast<VectorType>(op.getMask().getType());
+    if (!valueTy || !maskTy)
+      return failure();
+
+    std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op);
+    if (!targetShape)
+      return failure();
+
+    // Each partial load yields a target-shaped vector of the descriptor's
+    // element type.
+    Type elemTy = tdescTy.getElementType();
+    VectorType newValueTy = valueTy.cloneWith(*targetShape, elemTy);
+
+    SmallVector<Type> convertedTdescTypes =
+        getUnrolledTypes(tdescTy, *targetShape);
+    SmallVector<Value> convertedTdescs = pack(
+        op.getTensorDesc(), convertedTdescTypes, *targetShape, loc, rewriter);
+
+    SmallVector<Type> convertedMaskTypes =
+        getUnrolledTypes(maskTy, *targetShape);
+    SmallVector<Value> convertedMasks =
+        pack(op.getMask(), convertedMaskTypes, *targetShape, loc, rewriter);
+
+    SmallVector<Value> newOps;
+    newOps.reserve(convertedTdescs.size());
+    for (auto [t, m] : llvm::zip(convertedTdescs, convertedMasks)) {
+      // Transpose and cache-hint attributes are forwarded unchanged.
+      auto newOp = rewriter.create<xegpu::LoadGatherOp>(
+          loc, newValueTy, t, m, op.getTransposeAttr(), op.getL1HintAttr(),
+          op.getL2HintAttr(), op.getL3HintAttr());
+      newOps.push_back(newOp);
+    }
+
+    Value castOp = unpack(newOps, op.getType(), *targetShape, loc, rewriter);
+
+    rewriter.replaceOp(op, castOp);
+    return success();
+  }
+};
+
+struct UnrollPrefetchOp : public UnrollPattern<xegpu::PrefetchOp> {
+  using UnrollPattern<xegpu::PrefetchOp>::UnrollPattern;
+
+  /// Replaces one prefetch of a scattered descriptor with one prefetch per
+  /// `targetShape`-sized sub-descriptor. The op produces no results, so the
+  /// original is simply erased once the pieces have been emitted.
+  LogicalResult matchAndRewrite(xegpu::PrefetchOp op,
+                                PatternRewriter &rewriter) const override {
+    std::optional<SmallVector<int64_t>> tileShape = getTargetShape(op);
+    if (!tileShape)
+      return failure();
+
+    Location loc = op.getLoc();
+    SmallVector<Type> tileTdescTypes =
+        getUnrolledTypes(op.getTensorDescType(), *tileShape);
+    SmallVector<Value> tileTdescs = pack(op.getTensorDesc(), tileTdescTypes,
+                                         *tileShape, loc, rewriter);
+
+    // Each piece carries over every attribute of the original op.
+    for (Value tileTdesc : tileTdescs)
+      rewriter.create<xegpu::PrefetchOp>(loc, TypeRange(), tileTdesc,
+                                         op->getAttrs());
+
+    rewriter.eraseOp(op);
+    return success();
+  }
+};
+
+struct UnrollStoreScatterOp : public UnrollPattern<xegpu::StoreScatterOp> {
+  using UnrollPattern<xegpu::StoreScatterOp>::UnrollPattern;
+
+  /// Splits a scatter store through a scattered tensor descriptor into
+  /// `targetShape`-sized stores. The stored value, the descriptor and (when
+  /// present) the mask are sliced in lockstep; the original op is erased.
+  LogicalResult matchAndRewrite(xegpu::StoreScatterOp op,
+                                PatternRewriter &rewriter) const override {
+    Location loc = op.getLoc();
+    xegpu::TensorDescType tdescTy = op.getTensorDescType();
+
+    // The mask is sliced with the same target shape as the descriptor.
+    // Assumes (as in the tests) a 1-D mask. A rank > 1 descriptor (e.g. one
+    // with chunk_size > 1) would yield a target shape whose rank does not
+    // match the mask, so such descriptors are not handled by this pattern.
+    if (tdescTy.getRank() > 1)
+      return failure();
+
+    // dyn_cast may return null; bail out instead of dereferencing it below.
+    VectorType valueTy = llvm::dyn_cast<VectorType>(op.getValue().getType());
+    if (!valueTy)
+      return failure();
+
+    // The mask may be absent; it is only inspected and sliced when present.
+    Value mask = op.getMask();
+    VectorType maskTy;
+    if (mask) {
+      maskTy = llvm::dyn_cast<VectorType>(mask.getType());
+      if (!maskTy)
+        return failure();
+    }
+
+    std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op);
+    if (!targetShape)
+      return failure();
+
+    SmallVector<Type> convertedValTypes =
+        getUnrolledTypes(valueTy, *targetShape);
+    SmallVector<Type> convertedTdescTypes =
+        getUnrolledTypes(tdescTy, *targetShape);
+
+    SmallVector<Value> convertedValues =
+        pack(op.getValue(), convertedValTypes, *targetShape, loc, rewriter);
+    SmallVector<Value> convertedTdescs = pack(
+        op.getTensorDesc(), convertedTdescTypes, *targetShape, loc, rewriter);
+
+    SmallVector<Value> convertedMasks;
+    if (mask) {
+      SmallVector<Type> convertedMaskTypes =
+          getUnrolledTypes(maskTy, *targetShape);
+      convertedMasks =
+          pack(mask, convertedMaskTypes, *targetShape, loc, rewriter);
+    }
+
+    for (size_t i = 0, e = convertedValues.size(); i != e; ++i) {
+      // Without a mask on the original op, each piece gets a null mask too.
+      Value m = mask ? convertedMasks[i] : Value();
+      rewriter.create<xegpu::StoreScatterOp>(
+          loc, convertedValues[i], convertedTdescs[i], m,
+          op.getTransposeAttr(), op.getL1HintAttr(), op.getL2HintAttr(),
+          op.getL3HintAttr());
+    }
+
+    rewriter.eraseOp(op);
+    return success();
+  }
+};
+
+struct UnrollUpdateOffsetOp : public UnrollPattern<xegpu::UpdateOffsetOp> {
+  using UnrollPattern<xegpu::UpdateOffsetOp>::UnrollPattern;
+
+  /// Splits an update_offset on a scattered tensor descriptor into one
+  /// update_offset per `targetShape`-sized sub-descriptor, slicing the
+  /// offset deltas in lockstep, and reassembles the updated sub-descriptors
+  /// into the original result type.
+  LogicalResult matchAndRewrite(xegpu::UpdateOffsetOp op,
+                                PatternRewriter &rewriter) const override {
+    Location loc = op.getLoc();
+    xegpu::TensorDescType tdescTy = op.getTensorDescType();
+
+    // The offsets are sliced with the same target shape as the descriptor.
+    // Assumes (as in the tests) that the offsets vector is 1-D. A rank > 1
+    // descriptor (e.g. one with chunk_size > 1) would yield a target shape
+    // whose rank does not match the offsets, so such descriptors are not
+    // handled by this pattern.
+    if (tdescTy.getRank() > 1)
+      return failure();
+
+    std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op);
+    if (!targetShape)
+      return failure();
+
+    SmallVector<Type> convertedTdescTypes =
+        getUnrolledTypes(tdescTy, *targetShape);
+    SmallVector<Value> convertedTdesc = pack(
+        op.getTensorDesc(), convertedTdescTypes, *targetShape, loc, rewriter);
+
+    TypedValue<::mlir::VectorType> offsetVec = op.getOffsets();
+    VectorType offsetVecTy = offsetVec.getType();
+    SmallVector<Type> convertedOffsetTypes =
+        getUnrolledTypes(offsetVecTy, *targetShape);
+    SmallVector<Value> convertedOffsetVec =
+        pack(offsetVec, convertedOffsetTypes, *targetShape, loc, rewriter);
+
+    SmallVector<Value> newOps;
+    newOps.reserve(convertedTdesc.size());
+    for (auto [t, o] : llvm::zip(convertedTdesc, convertedOffsetVec)) {
+      // Each piece keeps the type of the sub-descriptor it updates.
+      auto newOp =
+          rewriter.create<xegpu::UpdateOffsetOp>(loc, t.getType(), t, o);
+      newOps.push_back(newOp);
+    }
+    Value castOp = unpack(newOps, op.getType(), *targetShape, loc, rewriter);
+    rewriter.replaceOp(op, castOp);
+    return success();
+  }
+};
+
 } // namespace
 
 void mlir::xegpu::populateXeGPUUnrollPatterns(
     RewritePatternSet &patterns, const xegpu::UnrollOptions &options) {
+  // Register every unroll pattern listed below; all of them share `options`.
   patterns.add<UnrollCreateNdOp, UnrollUpdateNdOffsetOp, UnrollPrefetchNdOp,
-               UnrollLoadNdOp, UnrollStoreNdOp, UnrollDpasOp>(
-      patterns.getContext(), options);
+               UnrollLoadNdOp, UnrollStoreNdOp, UnrollDpasOp,
+               UnrollCreateDescOp, UnrollLoadGatherOp, UnrollStoreScatterOp,
+               UnrollPrefetchOp, UnrollUpdateOffsetOp>(patterns.getContext(),
+                                                       options);
 }
diff --git a/mlir/test/Dialect/XeGPU/xegpu-unroll-patterns.mlir b/mlir/test/Dialect/XeGPU/xegpu-unroll-patterns.mlir
index b911bb3bbdc1c..47c54bfcb89d0 100644
--- a/mlir/test/Dialect/XeGPU/xegpu-unroll-patterns.mlir
+++ b/mlir/test/Dialect/XeGPU/xegpu-unroll-patterns.mlir
@@ -158,4 +158,147 @@ gpu.module @test {
     %c = xegpu.dpas %a, %b : vector<32x32xf16>, vector<32x32xf16> -> vector<32x32xf32>
     gpu.return %c : vector<32x32xf32>
   }
+
+//-----
+
+  // A 32-offset create_tdesc with inst_data = [16] is expected to become two
+  // 16-wide create_tdesc ops, both using the function argument as source.
+  // CHECK-LABEL: test_create_tdesc_vec
+  // CHECK-SAME: [[arg0:%.+]]: ui64
+  // CHECK-COUNT-2: xegpu.create_tdesc [[arg0]], {{.*}} : ui64, vector<16xindex> -> !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr<>>
+  gpu.func @test_create_tdesc_vec(%base: ui64) -> !xegpu.tensor_desc<32xf32, #xegpu.scatter_tdesc_attr<>, #xegpu.layout<inst_data = [16]>> {
+    %offsets = arith.constant dense<[
+        0,   8,  16,  24,  32,  40,  48,  56,
+       64,  72,  80,  88,  96, 104, 112, 120,
+      128, 136, 144, 152, 160, 168, 176, 184,
+      192, 200, 208, 216, 224, 232, 240, 248
+    ]> : vector<32xindex>
+    %desc = xegpu.create_tdesc %base, %offsets : ui64, vector<32xindex> -> !xegpu.tensor_desc<32xf32, #xegpu.scatter_tdesc_attr<>, #xegpu.layout<inst_data = [16]>>
+    gpu.return %desc : !xegpu.tensor_desc<32xf32, #xegpu.scatter_tdesc_attr<>, #xegpu.layout<inst_data = [16]>>
+  }
+
+//-----
+
+  // Same split as test_create_tdesc_vec, but the offsets are computed
+  // (vector.step scaled by 8) rather than given as a constant.
+  // CHECK-LABEL: test_create_tdesc_step
+  // CHECK-SAME: [[arg0:%.+]]: ui64
+  // CHECK-COUNT-2: xegpu.create_tdesc [[arg0]], {{.*}} : ui64, vector<16xindex> -> !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr<>>
+  gpu.func @test_create_tdesc_step(%base: ui64) -> !xegpu.tensor_desc<32xf32, #xegpu.scatter_tdesc_attr<>, #xegpu.layout<inst_data = [16]>> {
+    %stride = arith.constant dense<8> : vector<32xindex>
+    %lanes = vector.step : vector<32xindex>
+    %offsets = arith.muli %lanes, %stride : vector<32xindex>
+    %desc = xegpu.create_tdesc %base, %offsets : ui64, vector<32xindex> -> !xegpu.tensor_desc<32xf32, #xegpu.scatter_tdesc_attr<>, #xegpu.layout<inst_data = [16]>>
+    gpu.return %desc : !xegpu.tensor_desc<32xf32, #xegpu.scatter_tdesc_attr<>, #xegpu.layout<inst_data = [16]>>
+  }
+
+//-----
+
+  // A masked 32-lane gather load: both the descriptor creation and the load
+  // are expected to be split into two 16-lane ops.
+  // CHECK-LABEL: test_load
+  // CHECK-SAME: [[arg0:%.+]]: ui64
+  // CHECK-COUNT-2: xegpu.create_tdesc [[arg0]], {{.*}} : ui64, vector<16xindex> -> !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr<>>
+  // CHECK-COUNT-2: xegpu.load  {{.*}} : !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr<>>, vector<16xi1> -> vector<16xf32>
+  gpu.func @test_load(%base: ui64) -> vector<32xf32> {
+    %offsets = arith.constant dense<[
+        0,   8,  16,  24,  32,  40,  48,  56,
+       64,  72,  80,  88,  96, 104, 112, 120,
+      128, 136, 144, 152, 160, 168, 176, 184,
+      192, 200, 208, 216, 224, 232, 240, 248
+    ]> : vector<32xindex>
+
+    %num_active = arith.constant 17 : index
+    %lane_mask = vector.create_mask %num_active : vector<32xi1>
+
+    %desc = xegpu.create_tdesc %base, %offsets : ui64, vector<32xindex> -> !xegpu.tensor_desc<32xf32, #xegpu.scatter_tdesc_attr<>, #xegpu.layout<inst_data = [16]>>
+    %loaded = xegpu.load %desc, %lane_mask : !xegpu.tensor_desc<32xf32, #xegpu.scatter_tdesc_attr<>, #xegpu.layout<inst_data = [16]>>, vector<32xi1> -> vector<32xf32>
+
+    gpu.return %loaded : vector<32xf32>
+  }
+
+//-----
+
+  // A prefetch through a 32-lane descriptor is expected to become two
+  // prefetches through 16-lane descriptors.
+  // CHECK-LABEL: test_prefetch
+  // CHECK-SAME: [[arg0:%.+]]: ui64
+  // CHECK-COUNT-2: xegpu.create_tdesc [[arg0]], {{.*}} : ui64, vector<16xindex> -> !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr<>>
+  // CHECK-COUNT-2: xegpu.prefetch {{.*}} : !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr<>>
+  gpu.func @test_prefetch(%base: ui64) {
+    %offsets = arith.constant dense<[
+        0,   8,  16,  24,  32,  40,  48,  56,
+       64,  72,  80,  88,  96, 104, 112, 120,
+      128, 136, 144, 152, 160, 168, 176, 184,
+      192, 200, 208, 216, 224, 232, 240, 248
+    ]> : vector<32xindex>
+
+    %desc = xegpu.create_tdesc %base, %offsets : ui64, vector<32xindex> -> !xegpu.tensor_desc<32xf32, #xegpu.scatter_tdesc_attr<>, #xegpu.layout<inst_data = [16]>>
+    xegpu.prefetch %desc : !xegpu.tensor_desc<32xf32, #xegpu.scatter_tdesc_attr<>, #xegpu.layout<inst_data = [16]>>
+
+    gpu.return
+  }
+
+//-----
+
+  // A masked 32-lane scatter store of a constant vector is expected to become
+  // two 16-lane stores.
+  // CHECK-LABEL: test_store
+  // CHECK-SAME: [[arg0:%.+]]: ui64
+  // CHECK-COUNT-2: xegpu.create_tdesc [[arg0]], {{.*}} : ui64, vector<16xindex> -> !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr<>>
+  // CHECK-COUNT-2: xegpu.store  {{.*}} : vector<16xf32>, !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr<>>, vector<16xi1>
+  gpu.func @test_store(%base: ui64) {
+    %offsets = arith.constant dense<[
+        0,   8,  16,  24,  32,  40,  48,  56,
+       64,  72,  80,  88,  96, 104, 112, 120,
+      128, 136, 144, 152, 160, 168, 176, 184,
+      192, 200, 208, 216, 224, 232, 240, 248
+    ]> : vector<32xindex>
+
+    %num_active = arith.constant 17 : index
+    %lane_mask = vector.create_mask %num_active : vector<32xi1>
+
+    %values = arith.constant dense<1023.0> : vector<32xf32>
+    %desc = xegpu.create_tdesc %base, %offsets : ui64, vector<32xindex> -> !xegpu.tensor_desc<32xf32, #xegpu.scatter_tdesc_attr<>, #xegpu.layout<inst_data = [16]>>
+    xegpu.store %values, %desc, %lane_mask : vector<32xf32>, !xegpu.tensor_desc<32xf32, #xegpu.scatter_tdesc_attr<>, #xegpu.layout<inst_data = [16]>>, vector<32xi1>
+
+    gpu.return
+  }
+
+//-----
+
+  // Chains prefetch, update_offset, load and store on a 32-lane descriptor;
+  // each of these ops is expected to be split into two 16-lane ops.
+  // CHECK-LABEL: test_prefetch_load_store_update
+  // CHECK-SAME: [[arg0:%.+]]: ui64
+  // CHECK-COUNT-2: xegpu.create_tdesc [[arg0]], {{.*}} : ui64, vector<16xindex> -> !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr<>>
+  // CHECK-COUNT-2: xegpu.prefetch {{.*}} : !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr<>>
+  // CHECK-COUNT-2: xegpu.update_offset {{.*}} : !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr<>>, vector<16xindex>
+  // CHECK-COUNT-2: xegpu.load  {{.*}} : !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr<>>, vector<16xi1> -> vector<16xf32>
+  // CHECK-COUNT-2: xegpu.store  {{.*}} : vector<16xf32>, !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr<>>, vector<16xi1>
+
+  gpu.func @test_prefetch_load_store_update(%base: ui64) {
+    %offsets = arith.constant dense<[
+        0,   8,  16,  24,  32,  40,  48,  56,
+       64,  72,  80,  88,  96, 104, 112, 120,
+      128, 136, 144, 152, 160, 168, 176, 184,
+      192, 200, 208, 216, 224, 232, 240, 248
+    ]> : vector<32xindex>
+
+    %desc = xegpu.create_tdesc %base, %offsets : ui64, vector<32xindex> -> !xegpu.tensor_desc<32xf32, #xegpu.scatter_tdesc_attr<>, #xegpu.layout<inst_data = [16]>>
+
+    xegpu.prefetch %desc : !xegpu.tensor_desc<32xf32, #xegpu.scatter_tdesc_attr<>, #xegpu.layout<inst_data = [16]>>
+
+    %shift = arith.constant dense<[
+       32,  32,  32,  32,  32,  32,  32,  32,
+       32,  32,  32,  32,  32,  32,  32,  64,
+      128, 128, 128, 128, 128, 128, 128, 128,
+      128, 128, 128, 128, 128, 128, 128, 256
+    ]> : vector<32xindex>
+    %moved_desc = xegpu.update_offset %desc, %shift
+        : !xegpu.tensor_desc<32xf32, #xegpu.scatter_tdesc_attr<>, #xegpu.layout<inst_data = [16]>>, vector<32xindex>
+
+    %num_active = arith.constant 17 : index
+    %lane_mask = vector.create_mask %num_active : vector<32xi1>
+
+    %loaded = xegpu.load %moved_desc, %lane_mask : !xegpu.tensor_desc<32xf32, #xegpu.scatter_tdesc_attr<>, #xegpu.layout<inst_data = [16]>>, vector<32xi1> -> vector<32xf32>
+
+    %doubled = arith.addf %loaded, %loaded : vector<32xf32>
+    // NOTE(review): the store goes through the original descriptor, not the
+    // updated one; confirm this is intended.
+    xegpu.store %doubled, %desc, %lane_mask
+        : vector<32xf32>,
+          !xegpu.tensor_desc<32xf32, #xegpu.scatter_tdesc_attr<>, #xegpu.layout<inst_data = [16]>>,
+          vector<32xi1>
+
+    gpu.return
+  }
 }
diff --git a/mlir/test/lib/Dialect/XeGPU/TestXeGPUTransforms.cpp b/mlir/test/lib/Dialect/XeGPU/TestXeGPUTransforms.cpp
index 3f3461e92bc08..57aaecbd7962f 100644
--- a/mlir/test/lib/Dialect/XeGPU/TestXeGPUTransforms.cpp
+++ b/mlir/test/lib/Dialect/XeGPU/TestXeGPUTransforms.cpp
@@ -71,6 +71,29 @@ struct TestXeGPUUnrollingPatterns
             }
           }
 
+          // Scatter/gather descriptor ops: when the tensor descriptor's layout
+          // carries inst_data and layout.isSgLayout() holds, use inst_data as
+          // the unroll target shape.
+          if (isa<xegpu::CreateDescOp, xegpu::UpdateOffsetOp, xegpu::PrefetchOp,
+                  xegpu::LoadGatherOp, xegpu::StoreScatterOp>(op)) {
+            // Each listed op exposes its descriptor type through a different
+            // accessor; the isa<> guard above ensures one branch matches.
+            xegpu::TensorDescType tdescTy;
+            if (auto createOp = dyn_cast<xegpu::CreateDescOp>(op)) {
+              tdescTy = createOp.getType();
+            } else if (auto updateOp = dyn_cast<xegpu::UpdateOffsetOp>(op)) {
+              tdescTy = updateOp.getTensorDescType();
+            } else if (auto prefetchOp = dyn_cast<xegpu::PrefetchOp>(op)) {
+              tdescTy = prefetchOp.getTensorDescType();
+            } else if (auto loadOp = dyn_cast<xegpu::LoadGatherOp>(op)) {
+              tdescTy = loadOp.getTensorDescType();
+            } else if (auto storeOp = dyn_cast<xegpu::StoreScatterOp>(op)) {
+              tdescTy = storeOp.getTensorDescType();
+            }
+
+            if (auto layout = tdescTy.getLayoutAttr()) {
+              auto inst_data = layout.getInstData();
+              if (inst_data && layout.isSgLayout())
+                return SmallVector<int64_t>(inst_data.asArrayRef().begin(),
+                                            inst_data.asArrayRef().end());
+            }
+            // NOTE(review): with no usable inst_data, control falls through to
+            // the checks that follow; confirm what the enclosing lambda
+            // returns in that case.
+          }
+
           if (isa<xegpu::DpasOp>(op))
             return SmallVector<int64_t>{8, 16, 16};
 

``````````

</details>


https://github.com/llvm/llvm-project/pull/143602


More information about the Mlir-commits mailing list