[Mlir-commits] [mlir] 58d2347 - [MLIR][XeGPU] Add unroll patterns for scatter ops (#143602)

llvmlistbot at llvm.org
Mon Jun 16 08:48:45 PDT 2025


Author: Jianhui Li
Date: 2025-06-16T10:48:41-05:00
New Revision: 58d23476f0ce76c847497a880f975550a645c796

URL: https://github.com/llvm/llvm-project/commit/58d23476f0ce76c847497a880f975550a645c796
DIFF: https://github.com/llvm/llvm-project/commit/58d23476f0ce76c847497a880f975550a645c796.diff

LOG: [MLIR][XeGPU] Add unroll patterns for scatter ops (#143602)

Add unrolling support for the scatter ops create_tdesc, load, store, prefetch, and update_offset.
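
For anyone who wants to drive the new patterns outside the test pass, a
minimal sketch follows. It is not part of this commit: the setNativeShapeFn
hook is the UnrollOptions API exercised by the test pass further below, and
the 16-lane target shape is an illustrative choice, not something the
patterns mandate.

  // Sketch only (assumptions: UnrollOptions::setNativeShapeFn as used by
  // the test pass below; 16 lanes is an arbitrary illustrative shape).
  #include "mlir/Dialect/XeGPU/IR/XeGPU.h"
  #include "mlir/Dialect/XeGPU/Transforms/Transforms.h"
  #include "mlir/Transforms/GreedyPatternRewriteDriver.h"

  using namespace mlir;

  static void unrollXeGPUScatterOps(Operation *root) {
    xegpu::UnrollOptions options;
    options.setNativeShapeFn(
        [](Operation *op) -> std::optional<SmallVector<int64_t>> {
          // Split 1-D scatter ops into 16-lane pieces; skip other ops.
          if (isa<xegpu::CreateDescOp, xegpu::LoadGatherOp,
                  xegpu::StoreScatterOp, xegpu::PrefetchOp,
                  xegpu::UpdateOffsetOp>(op))
            return SmallVector<int64_t>{16};
          return std::nullopt;
        });

    RewritePatternSet patterns(root->getContext());
    xegpu::populateXeGPUUnrollPatterns(patterns, options);
    (void)applyPatternsGreedily(root, std::move(patterns));
  }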

---------

Co-authored-by: Adam Siemieniuk <adam.siemieniuk at intel.com>
Co-authored-by: Chao Chen <chao.chen at intel.com>

Added: 
    

Modified: 
    mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp
    mlir/test/Dialect/XeGPU/xegpu-unroll-patterns.mlir
    mlir/test/lib/Dialect/XeGPU/TestXeGPUTransforms.cpp

Removed: 
    


################################################################################
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp
index 885477fe4cbd5..9c234c1e866b9 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp
@@ -396,11 +396,214 @@ struct UnrollDpasOp : public UnrollPattern<xegpu::DpasOp> {
   }
 };
 
+struct UnrollCreateDescOp : public UnrollPattern<xegpu::CreateDescOp> {
+  using UnrollPattern<xegpu::CreateDescOp>::UnrollPattern;
+  LogicalResult matchAndRewrite(xegpu::CreateDescOp op,
+                                PatternRewriter &rewriter) const override {
+    Location loc = op.getLoc();
+    xegpu::TensorDescType tdescTy = op.getType();
+
+    // Only 1-D tensor descriptors are supported.
+    if (tdescTy.getRank() > 1)
+      return failure();
+
+    std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op);
+    if (!targetShape)
+      return failure();
+
+    auto newTdescTy = getUnrolledTypes(tdescTy, *targetShape)[0];
+
+    TypedValue<::mlir::VectorType> indexVec = op.getOffsets();
+    VectorType indexVecTy = indexVec.getType();
+
+    SmallVector<Type> convertedIndexTypes =
+        getUnrolledTypes(indexVecTy, *targetShape);
+    SmallVector<Value> convertedIndexVec =
+        pack(indexVec, convertedIndexTypes, *targetShape, loc, rewriter);
+
+    SmallVector<Value> newOps;
+    for (auto indices : convertedIndexVec) {
+      auto newOp = rewriter.create<xegpu::CreateDescOp>(
+          loc, newTdescTy, op.getSource(), indices);
+      newOps.push_back(newOp);
+    }
+
+    Value castOp = unpack(newOps, tdescTy, *targetShape, loc, rewriter);
+    rewriter.replaceOp(op, castOp);
+
+    return success();
+  }
+};
+
+struct UnrollLoadGatherOp : public UnrollPattern<xegpu::LoadGatherOp> {
+  using UnrollPattern<xegpu::LoadGatherOp>::UnrollPattern;
+  LogicalResult matchAndRewrite(xegpu::LoadGatherOp op,
+                                PatternRewriter &rewriter) const override {
+
+    Location loc = op.getLoc();
+    VectorType valueTy = llvm::dyn_cast<VectorType>(op.getValue().getType());
+    xegpu::TensorDescType tdescTy = op.getTensorDescType();
+
+    // Only 1-D tensor descriptors are supported.
+    if (tdescTy.getRank() > 1)
+      return failure();
+
+    VectorType maskTy = llvm::dyn_cast<VectorType>(op.getMask().getType());
+
+    std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op);
+    if (!targetShape)
+      return failure();
+
+    Type elemTy = tdescTy.getElementType();
+    VectorType newValueTy = valueTy.cloneWith(*targetShape, elemTy);
+
+    SmallVector<Type> convertedTdescTypes =
+        getUnrolledTypes(tdescTy, *targetShape);
+    SmallVector<Value> convertedTdescs = pack(
+        op.getTensorDesc(), convertedTdescTypes, *targetShape, loc, rewriter);
+
+    SmallVector<Type> convertedMaskTypes =
+        getUnrolledTypes(maskTy, *targetShape);
+    SmallVector<Value> convertedMasks =
+        pack(op.getMask(), convertedMaskTypes, *targetShape, loc, rewriter);
+
+    SmallVector<Value> newOps;
+    for (auto [t, m] : llvm::zip(convertedTdescs, convertedMasks)) {
+      auto newOp = rewriter.create<xegpu::LoadGatherOp>(
+          loc, newValueTy, t, m, op.getTransposeAttr(), op.getL1HintAttr(),
+          op.getL2HintAttr(), op.getL3HintAttr());
+      newOps.push_back(newOp);
+    }
+
+    Value castOp = unpack(newOps, op.getType(), *targetShape, loc, rewriter);
+
+    rewriter.replaceOp(op, castOp);
+    return success();
+  }
+};
+
+struct UnrollPrefetchOp : public UnrollPattern<xegpu::PrefetchOp> {
+  using UnrollPattern<xegpu::PrefetchOp>::UnrollPattern;
+  LogicalResult matchAndRewrite(xegpu::PrefetchOp op,
+                                PatternRewriter &rewriter) const override {
+    Location loc = op.getLoc();
+    xegpu::TensorDescType tdescTy = op.getTensorDescType();
+
+    // Only 1-D tensor descriptors are supported.
+    if (tdescTy.getRank() > 1)
+      return failure();
+
+    std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op);
+    if (!targetShape)
+      return failure();
+
+    SmallVector<Type> convertedTdescTypes =
+        getUnrolledTypes(tdescTy, *targetShape);
+    SmallVector<Value> convertedTdesc = pack(
+        op.getTensorDesc(), convertedTdescTypes, *targetShape, loc, rewriter);
+
+    for (auto t : convertedTdesc)
+      rewriter.create<xegpu::PrefetchOp>(loc, TypeRange(), t, op->getAttrs());
+
+    rewriter.eraseOp(op);
+    return success();
+  }
+};
+
+struct UnrollStoreScatterOp : public UnrollPattern<xegpu::StoreScatterOp> {
+  using UnrollPattern<xegpu::StoreScatterOp>::UnrollPattern;
+  LogicalResult matchAndRewrite(xegpu::StoreScatterOp op,
+                                PatternRewriter &rewriter) const override {
+
+    Location loc = op.getLoc();
+    VectorType valueTy = llvm::dyn_cast<VectorType>(op.getValue().getType());
+    xegpu::TensorDescType tdescTy = op.getTensorDescType();
+
+    // Only 1-D tensor descriptors are supported.
+    if (tdescTy.getRank() > 1)
+      return failure();
+
+    VectorType maskTy = llvm::dyn_cast<VectorType>(op.getMask().getType());
+
+    std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op);
+    if (!targetShape)
+      return failure();
+
+    SmallVector<Type> convertedValTypes =
+        getUnrolledTypes(valueTy, *targetShape);
+    SmallVector<Type> convertedTdescTypes =
+        getUnrolledTypes(tdescTy, *targetShape);
+
+    SmallVector<Value> convertedValues =
+        pack(op.getValue(), convertedValTypes, *targetShape, loc, rewriter);
+    SmallVector<Value> convertedTdescs = pack(
+        op.getTensorDesc(), convertedTdescTypes, *targetShape, loc, rewriter);
+
+    SmallVector<Type> convertedMaskTypes =
+        getUnrolledTypes(maskTy, *targetShape);
+    SmallVector<Value> convertedMasks =
+        pack(op.getMask(), convertedMaskTypes, *targetShape, loc, rewriter);
+
+    for (size_t i = 0; i < convertedValues.size(); ++i) {
+      Value v = convertedValues[i];
+      Value t = convertedTdescs[i];
+      Value m = op.getMask() ? convertedMasks[i] : nullptr;
+      rewriter.create<xegpu::StoreScatterOp>(
+          loc, v, t, m, op.getTransposeAttr(), op.getL1HintAttr(),
+          op.getL2HintAttr(), op.getL3HintAttr());
+    }
+
+    rewriter.eraseOp(op);
+    return success();
+  }
+};
+
+struct UnrollUpdateOffsetOp : public UnrollPattern<xegpu::UpdateOffsetOp> {
+  using UnrollPattern<xegpu::UpdateOffsetOp>::UnrollPattern;
+  LogicalResult matchAndRewrite(xegpu::UpdateOffsetOp op,
+                                PatternRewriter &rewriter) const override {
+    Location loc = op.getLoc();
+    xegpu::TensorDescType tdescTy = op.getTensorDescType();
+
+    // Only 1-D tensor descriptors are supported.
+    if (tdescTy.getRank() > 1)
+      return failure();
+
+    std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op);
+    if (!targetShape)
+      return failure();
+
+    SmallVector<Type> convertedTdescTypes =
+        getUnrolledTypes(tdescTy, *targetShape);
+    SmallVector<Value> convertedTdesc = pack(
+        op.getTensorDesc(), convertedTdescTypes, *targetShape, loc, rewriter);
+
+    TypedValue<::mlir::VectorType> offsetVec = op.getOffsets();
+    VectorType offsetVecTy = offsetVec.getType();
+    SmallVector<Type> convertedOffsetTypes =
+        getUnrolledTypes(offsetVecTy, *targetShape);
+    SmallVector<Value> convertedOffsetVec =
+        pack(offsetVec, convertedOffsetTypes, *targetShape, loc, rewriter);
+
+    SmallVector<Value> newOps;
+    for (auto [t, o] : llvm::zip(convertedTdesc, convertedOffsetVec)) {
+      auto newOp =
+          rewriter.create<xegpu::UpdateOffsetOp>(loc, t.getType(), t, o);
+      newOps.push_back(newOp);
+    }
+    Value castOp = unpack(newOps, op.getType(), *targetShape, loc, rewriter);
+    rewriter.replaceOp(op, castOp);
+    return success();
+  }
+};
+
 } // namespace
 
 void mlir::xegpu::populateXeGPUUnrollPatterns(
     RewritePatternSet &patterns, const xegpu::UnrollOptions &options) {
   patterns.add<UnrollCreateNdOp, UnrollUpdateNdOffsetOp, UnrollPrefetchNdOp,
-               UnrollLoadNdOp, UnrollStoreNdOp, UnrollDpasOp>(
-      patterns.getContext(), options);
+               UnrollLoadNdOp, UnrollStoreNdOp, UnrollDpasOp,
+               UnrollCreateDescOp, UnrollLoadGatherOp, UnrollStoreScatterOp,
+               UnrollPrefetchOp, UnrollUpdateOffsetOp>(patterns.getContext(),
+                                                       options);
 }
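
A note on the shared structure: every pattern above follows the same recipe.
Compute the per-instruction target shape, pack each operand into per-tile
values, create one small op per tile, then unpack the results back to the
original type. The pack/unpack helpers are defined earlier in
XeGPUUnroll.cpp and are not part of this diff; the sketch below shows their
contracts as implied by the call sites (signatures assumed, not copied from
the source):

  // Split src into one value per tile of targetShape; the i-th result
  // has type dstTypes[i].
  SmallVector<Value> pack(Value src, TypeRange dstTypes,
                          ArrayRef<int64_t> targetShape, Location loc,
                          PatternRewriter &rewriter);

  // Inverse of pack: reassemble per-tile values into a single value of
  // type destTy.
  Value unpack(ValueRange srcs, Type destTy, ArrayRef<int64_t> targetShape,
               Location loc, PatternRewriter &rewriter);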

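One stylistic observation on UnrollStoreScatterOp: unlike the other patterns
it uses an indexed loop rather than llvm::zip, so that the mask operand can
be skipped when absent. When a mask is known to be present, an equivalent
zip-based form would be (a sketch, not part of the commit):

  for (auto [v, t, m] :
       llvm::zip(convertedValues, convertedTdescs, convertedMasks))
    rewriter.create<xegpu::StoreScatterOp>(loc, v, t, m,
                                           op.getTransposeAttr(),
                                           op.getL1HintAttr(),
                                           op.getL2HintAttr(),
                                           op.getL3HintAttr());
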
diff --git a/mlir/test/Dialect/XeGPU/xegpu-unroll-patterns.mlir b/mlir/test/Dialect/XeGPU/xegpu-unroll-patterns.mlir
index b911bb3bbdc1c..52ec3b856da49 100644
--- a/mlir/test/Dialect/XeGPU/xegpu-unroll-patterns.mlir
+++ b/mlir/test/Dialect/XeGPU/xegpu-unroll-patterns.mlir
@@ -158,4 +158,145 @@ gpu.module @test {
     %c = xegpu.dpas %a, %b : vector<32x32xf16>, vector<32x32xf16> -> vector<32x32xf32>
     gpu.return %c : vector<32x32xf32>
   }
+
+//-----
+
+  // CHECK-LABEL: test_create_tdesc_vec
+  // CHECK-SAME: [[arg0:%.+]]: ui64
+  // CHECK-COUNT-2: xegpu.create_tdesc [[arg0]], {{.*}} : ui64, vector<16xindex> -> !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr<>>
+  gpu.func @test_create_tdesc_vec(%src: ui64) -> !xegpu.tensor_desc<32xf32, #xegpu.scatter_tdesc_attr<>, #xegpu.layout<inst_data = [16]>> {
+    %cst = arith.constant dense<[
+    0,   8,  16,  24,  32,  40,  48,  56,
+    64,  72,  80,  88,  96, 104, 112, 120,
+    128, 136, 144, 152, 160, 168, 176, 184,
+    192, 200, 208, 216, 224, 232, 240, 248 
+    ]> : vector<32xindex>
+    %tdesc = xegpu.create_tdesc %src, %cst : ui64, vector<32xindex> -> !xegpu.tensor_desc<32xf32,  #xegpu.scatter_tdesc_attr<>, #xegpu.layout<inst_data = [16]>>
+    gpu.return %tdesc : !xegpu.tensor_desc<32xf32,  #xegpu.scatter_tdesc_attr<>,  #xegpu.layout<inst_data = [16]>>
+  }
+
+//-----
+
+  // CHECK-LABEL: test_create_tdesc_step
+  // CHECK-SAME: [[arg0:%.+]]: ui64
+  // CHECK-COUNT-2: xegpu.create_tdesc [[arg0]], {{.*}} : ui64, vector<16xindex> -> !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr<>>
+  gpu.func @test_create_tdesc_step(%src: ui64) -> !xegpu.tensor_desc<32xf32, #xegpu.scatter_tdesc_attr<>, #xegpu.layout<inst_data = [16]>> {
+    %step = arith.constant dense<8> : vector<32xindex>
+    %seq = vector.step  : vector<32xindex>
+    %cst = arith.muli %seq, %step : vector<32xindex>
+    %tdesc = xegpu.create_tdesc %src, %cst : ui64, vector<32xindex> -> !xegpu.tensor_desc<32xf32,  #xegpu.scatter_tdesc_attr<>, #xegpu.layout<inst_data = [16]>>
+    gpu.return %tdesc : !xegpu.tensor_desc<32xf32,  #xegpu.scatter_tdesc_attr<>, #xegpu.layout<inst_data = [16]>>
+  }
+
+//-----
+
+  // CHECK-LABEL: test_load
+  // CHECK-SAME: [[arg0:%.+]]: ui64
+  // CHECK-COUNT-2: xegpu.create_tdesc [[arg0]], {{.*}} : ui64, vector<16xindex> -> !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr<>>
+  // CHECK-COUNT-2: xegpu.load  {{.*}} : !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr<>>, vector<16xi1> -> vector<16xf32>
+  gpu.func @test_load(%src: ui64) -> vector<32xf32> {
+    %cst = arith.constant dense<[
+    0,   8,  16,  24,  32,  40,  48,  56,
+    64,  72,  80,  88,  96, 104, 112, 120,
+    128, 136, 144, 152, 160, 168, 176, 184,
+    192, 200, 208, 216, 224, 232, 240, 248 
+    ]> : vector<32xindex>
+      
+    %c17 = arith.constant 17: index
+    %mask = vector.create_mask %c17: vector<32xi1>
+    %tdesc = xegpu.create_tdesc %src, %cst : ui64, vector<32xindex> -> !xegpu.tensor_desc<32xf32,  #xegpu.scatter_tdesc_attr<>, #xegpu.layout<inst_data = [16]>>
+    %ld = xegpu.load %tdesc, %mask: !xegpu.tensor_desc<32xf32, #xegpu.scatter_tdesc_attr<>, #xegpu.layout<inst_data = [16]>>, vector<32xi1> -> vector<32xf32>
+      
+    gpu.return %ld : vector<32xf32> 
+  }
+
+//-----
+
+  // CHECK-LABEL: test_prefetch
+  // CHECK-SAME: [[arg0:%.+]]: ui64
+  // CHECK-COUNT-2: xegpu.create_tdesc [[arg0]], {{.*}} : ui64, vector<16xindex> -> !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr<>>
+  // CHECK-COUNT-2: xegpu.prefetch {{.*}} : !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr<>>
+  gpu.func @test_prefetch(%src: ui64)  {
+
+    %cst = arith.constant dense<[
+    0,   8,  16,  24,  32,  40,  48,  56,
+    64,  72,  80,  88,  96, 104, 112, 120,
+    128, 136, 144, 152, 160, 168, 176, 184,
+    192, 200, 208, 216, 224, 232, 240, 248 
+    ]> : vector<32xindex>
+
+    %tdesc = xegpu.create_tdesc %src, %cst : ui64, vector<32xindex> -> !xegpu.tensor_desc<32xf32,  #xegpu.scatter_tdesc_attr<>, #xegpu.layout<inst_data = [16]>>
+
+    xegpu.prefetch %tdesc: !xegpu.tensor_desc<32xf32,  #xegpu.scatter_tdesc_attr<>, #xegpu.layout<inst_data = [16]>>
+    gpu.return
+  }
+
+//-----
+
+  // CHECK-LABEL: test_store
+  // CHECK-SAME: [[arg0:%.+]]: ui64
+  // CHECK-COUNT-2: xegpu.create_tdesc [[arg0]], {{.*}} : ui64, vector<16xindex> -> !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr<>>
+  // CHECK-COUNT-2: xegpu.store  {{.*}} : vector<16xf32>, !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr<>>, vector<16xi1>
+  gpu.func @test_store(%src: ui64) {
+    %cst = arith.constant dense<[
+    0,   8,  16,  24,  32,  40,  48,  56,
+    64,  72,  80,  88,  96, 104, 112, 120,
+    128, 136, 144, 152, 160, 168, 176, 184,
+    192, 200, 208, 216, 224, 232, 240, 248 
+    ]> : vector<32xindex>
+    
+    %c17 = arith.constant 17: index
+    %mask = vector.create_mask %c17: vector<32xi1>
+
+    %st_vec = arith.constant dense<1023.0>: vector<32xf32>
+    %tdesc = xegpu.create_tdesc %src, %cst : ui64, vector<32xindex> -> !xegpu.tensor_desc<32xf32,  #xegpu.scatter_tdesc_attr<>, #xegpu.layout<inst_data = [16]>>
+    xegpu.store %st_vec, %tdesc, %mask: vector<32xf32>, !xegpu.tensor_desc<32xf32, #xegpu.scatter_tdesc_attr<>, #xegpu.layout<inst_data = [16]>>, vector<32xi1>
+    
+    gpu.return
+  }
+
+//-----
+
+  // CHECK-LABEL: test_prefetch_load_store_update
+  // CHECK-SAME: [[arg0:%.+]]: ui64
+  // CHECK-COUNT-2: xegpu.create_tdesc [[arg0]], {{.*}} : ui64, vector<16xindex> -> !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr<>>
+  // CHECK-COUNT-2: xegpu.prefetch {{.*}} : !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr<>>
+  // CHECK-COUNT-2: xegpu.update_offset {{.*}} : !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr<>>, vector<16xindex>
+  // CHECK-COUNT-2: xegpu.load  {{.*}} : !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr<>>, vector<16xi1> -> vector<16xf32>
+  // CHECK-COUNT-2: xegpu.store  {{.*}} : vector<16xf32>, !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr<>>, vector<16xi1>
+
+  gpu.func @test_prefetch_load_store_update(%src: ui64)  {
+
+    %cst = arith.constant dense<[
+    0,   8,  16,  24,  32,  40,  48,  56,
+    64,  72,  80,  88,  96, 104, 112, 120,
+    128, 136, 144, 152, 160, 168, 176, 184,
+    192, 200, 208, 216, 224, 232, 240, 248 
+    ]> : vector<32xindex>
+
+    %tdesc = xegpu.create_tdesc %src, %cst : ui64, vector<32xindex> -> !xegpu.tensor_desc<32xf32,  #xegpu.scatter_tdesc_attr<>, #xegpu.layout<inst_data = [16]>>
+    xegpu.prefetch %tdesc: !xegpu.tensor_desc<32xf32,  #xegpu.scatter_tdesc_attr<>, #xegpu.layout<inst_data = [16]>>
+   
+    %delta = arith.constant dense<[
+    32,   32,  32,  32,  32,  32,  32,  32,
+    32,   32,  32,  32,  32,  32,  32,  64,
+    128, 128, 128, 128, 128, 128, 128, 128,
+    128, 128, 128, 128, 128, 128, 128, 256 
+    ]> : vector<32xindex>
+    %new_tdesc = xegpu.update_offset %tdesc, %delta
+              : !xegpu.tensor_desc<32xf32, #xegpu.scatter_tdesc_attr<>, #xegpu.layout<inst_data = [16]>>, vector<32xindex>     
+ 
+    %c17 = arith.constant 17: index
+    %mask = vector.create_mask %c17: vector<32xi1>
+
+    %ld_vec = xegpu.load %new_tdesc, %mask: !xegpu.tensor_desc<32xf32, #xegpu.scatter_tdesc_attr<>, #xegpu.layout<inst_data = [16]>>, vector<32xi1> -> vector<32xf32>
+
+    %st_vec = arith.addf %ld_vec, %ld_vec : vector<32xf32>
+    xegpu.store %st_vec, %tdesc, %mask: 
+                 vector<32xf32>, 
+                 !xegpu.tensor_desc<32xf32, #xegpu.scatter_tdesc_attr<>, #xegpu.layout<inst_data = [16]>>, 
+                 vector<32xi1>
+  
+    gpu.return
+  }
 }

diff --git a/mlir/test/lib/Dialect/XeGPU/TestXeGPUTransforms.cpp b/mlir/test/lib/Dialect/XeGPU/TestXeGPUTransforms.cpp
index 3f3461e92bc08..57aaecbd7962f 100644
--- a/mlir/test/lib/Dialect/XeGPU/TestXeGPUTransforms.cpp
+++ b/mlir/test/lib/Dialect/XeGPU/TestXeGPUTransforms.cpp
@@ -71,6 +71,29 @@ struct TestXeGPUUnrollingPatterns
             }
           }
 
+          if (isa<xegpu::CreateDescOp, xegpu::UpdateOffsetOp, xegpu::PrefetchOp,
+                  xegpu::LoadGatherOp, xegpu::StoreScatterOp>(op)) {
+            xegpu::TensorDescType tdescTy;
+            if (auto createOp = dyn_cast<xegpu::CreateDescOp>(op)) {
+              tdescTy = createOp.getType();
+            } else if (auto updateOp = dyn_cast<xegpu::UpdateOffsetOp>(op)) {
+              tdescTy = updateOp.getTensorDescType();
+            } else if (auto prefetchOp = dyn_cast<xegpu::PrefetchOp>(op)) {
+              tdescTy = prefetchOp.getTensorDescType();
+            } else if (auto loadOp = dyn_cast<xegpu::LoadGatherOp>(op)) {
+              tdescTy = loadOp.getTensorDescType();
+            } else if (auto storeOp = dyn_cast<xegpu::StoreScatterOp>(op)) {
+              tdescTy = storeOp.getTensorDescType();
+            }
+
+            if (auto layout = tdescTy.getLayoutAttr()) {
+              auto instData = layout.getInstData();
+              if (instData && layout.isSgLayout())
+                return SmallVector<int64_t>(instData.asArrayRef().begin(),
+                                            instData.asArrayRef().end());
+            }
+          }
+
           if (isa<xegpu::DpasOp>(op))
             return SmallVector<int64_t>{8, 16, 16};
 


        

