[Mlir-commits] [mlir] [mlir][xegpu] add support for structured control flow ops in workgroup to subgroup distribution (PR #142618)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Wed Jun 4 10:35:30 PDT 2025


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir

Author: Chao Chen (chencha3)

<details>
<summary>Changes</summary>



---

Patch is 22.24 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/142618.diff


5 Files Affected:

- (modified) mlir/include/mlir/Dialect/XeGPU/Utils/XeGPUUtils.h (+3) 
- (modified) mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp (+136-24) 
- (modified) mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp (+3-3) 
- (modified) mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-rr.mlir (+25-1) 
- (modified) mlir/test/Dialect/XeGPU/xegpu-wg-to-sg.mlir (+67-3) 


``````````diff
diff --git a/mlir/include/mlir/Dialect/XeGPU/Utils/XeGPUUtils.h b/mlir/include/mlir/Dialect/XeGPU/Utils/XeGPUUtils.h
index f9327d63869c0..6fea10185402a 100644
--- a/mlir/include/mlir/Dialect/XeGPU/Utils/XeGPUUtils.h
+++ b/mlir/include/mlir/Dialect/XeGPU/Utils/XeGPUUtils.h
@@ -26,6 +26,9 @@ class TensorDescType;
 
 namespace xegpu {
 
+/// Flatten a set of ValueRange into a single SmallVector<Value>
+SmallVector<Value> flattenValues(ArrayRef<ValueRange> values);
+
 /// If tensor descriptor has a layout attribute it is used in SIMT mode.
 /// In this mode, the distributed vector shape is determined as follows:
 /// Definitions:
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
index 3bf76af674ba0..e29da76898c58 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
@@ -16,6 +16,7 @@
 #include "mlir/Dialect/Utils/IndexingUtils.h"
 #include "mlir/Dialect/XeGPU/IR/XeGPU.h"
 #include "mlir/Dialect/XeGPU/Transforms/Transforms.h"
+#include "mlir/Dialect/XeGPU/Utils/XeGPUUtils.h"
 #include "mlir/Transforms/DialectConversion.h"
 
 namespace mlir {
@@ -29,6 +30,29 @@ using namespace mlir;
 
 namespace {
 
+static std::pair<SmallVector<int64_t>, int>
+getSgShapeAndCount(ArrayRef<int64_t> shape, xegpu::LayoutAttr layout) {
+  int count = 1;
+  SmallVector<int64_t> sgShape(shape);
+
+  if (layout && layout.isWgLayout()) {
+    DenseI32ArrayAttr sgLayoutAttr = layout.getSgLayout();
+    auto sgLayout = llvm::to_vector_of<int64_t>(sgLayoutAttr.asArrayRef());
+    if (DenseI32ArrayAttr sgDataAttr = layout.getSgData())
+      sgShape = llvm::to_vector_of<int64_t>(sgDataAttr.asArrayRef());
+    else
+      sgShape = computeShapeRatio(shape, sgLayout).value_or(sgShape);
+    SmallVector<int64_t> distUnit = computeElementwiseMul(sgLayout, sgShape);
+    // Clamp distUnit to the original shape to handle cases where data is
+    // shared among subgroups, which may cause distUnit to exceed the original
+    // shape.
+    for (size_t i = 0; i < distUnit.size(); ++i)
+      distUnit[i] = std::min(shape[i], distUnit[i]);
+    count = computeProduct(shape) / computeProduct(distUnit);
+  }
+  return std::make_pair(sgShape, count);
+};
+
 /// This pattern transforms the CreateNdDescOp to create a subgroup descriptor
 /// from a workgroup descriptor. It replaces the offsets and sizes with
 /// appropriate values for the subgroup.
@@ -129,18 +153,7 @@ struct WgToSgCreateNdOp : public OpConversionPattern<xegpu::CreateNdDescOp> {
       return rewriter.notifyMatchFailure(
           op, "sgLayout attribute is required in layout");
 
-    SmallVector<int64_t> sgShape;
-    if (auto sgDataAttr = layout.getSgData()) {
-      sgShape = llvm::to_vector_of<int64_t>(sgDataAttr.asArrayRef());
-    } else {
-      assert(wgShape.size() == sgLayout.size() &&
-             "sgLayout and wgShape must have the same rank");
-      sgShape.reserve(wgShape.size());
-      for (size_t i = 0; i < wgShape.size(); ++i) {
-        assert(sgLayout[i] != 0 && "sgLayout elements must be non-zero");
-        sgShape.push_back(wgShape[i] / sgLayout[i]);
-      }
-    }
+    SmallVector<int64_t> sgShape = getSgShapeAndCount(wgShape, layout).first;
 
     // TODO : Handle order attribute
     // Get the subgroup ID
@@ -266,15 +279,15 @@ struct WgToSgDpasOp : public OpConversionPattern<xegpu::DpasOp> {
     if (resultTy.getRank() != 2)
       return failure();
 
-    auto originalLayout =
-        llvm::dyn_cast_or_null<xegpu::LayoutAttr>(op->getAttr("layout"));
+    auto originalLayout = xegpu::getLayoutAttr(op.getResult());
     if (!originalLayout)
       return failure();
 
-    SmallVector<Value> newDpasOps;
     size_t i = 0;
+    SmallVector<Value> newDpasOps;
     for (auto aVec : adaptor.getLhs()) {
       for (auto bVec : adaptor.getRhs()) {
+
         llvm::SmallVector<Value> operands({aVec, bVec});
         Value tmpC;
         if (op.getAcc()) {
@@ -288,10 +301,10 @@ struct WgToSgDpasOp : public OpConversionPattern<xegpu::DpasOp> {
             llvm::cast<VectorType>(bVec.getType()).getShape();
         VectorType resTy = VectorType::get({aVecShape[0], bVecShape[1]},
                                            resultTy.getElementType());
-        tmpC = rewriter.create<xegpu::DpasOp>(
-            loc, resTy, operands,
-            llvm::ArrayRef<NamedAttribute>(
-                {"layout_result_0", originalLayout.dropSgLayoutAndData()}));
+        tmpC = rewriter.create<xegpu::DpasOp>(loc, resTy, operands);
+        xegpu::setLayoutAttr(cast<OpResult>(tmpC),
+                             originalLayout.dropSgLayoutAndData());
+
         newDpasOps.push_back(tmpC);
       }
     }
@@ -314,14 +327,69 @@ struct WgToSgPrefetchNdOp : public OpConversionPattern<xegpu::PrefetchNdOp> {
   }
 };
 
+// Handles UnrealizedConversionCastOp generated during
+// SCFStructuralTypeConversions (step 1). This op may appear as either a
+// target or source materialization for Vector or TensorDesc, e.g.:
+// 1. unrealized_conversion_cast %1 : tensor_desc<16xf16> to
+// tensor_desc<128xf16, ...>
+// 2. unrealized_conversion_cast %1 : vector<256xf32> to vector<16xf32>, ...
+// 3. unrealized_conversion_cast %1 : vector<16xf32>, ... to vector<256xf32>
+// In all cases, the pattern simply forwards the inputs to the outputs with
+// one-to-one or one-to-n patterns.
+struct UnrealizedConversionCastOpPattern
+    : public OpConversionPattern<mlir::UnrealizedConversionCastOp> {
+  using OpConversionPattern<
+      mlir::UnrealizedConversionCastOp>::OpConversionPattern;
+
+  mlir::LogicalResult
+  matchAndRewrite(mlir::UnrealizedConversionCastOp op, OneToNOpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    SmallVector<Value> inputs = xegpu::flattenValues(adaptor.getInputs());
+
+    // Handles the case where cast %1 : tensor_desc<16xf16> to
+    // tensor_desc<128xf16, ...> The input values provided by the adaptor should
+    // already be distributed.
+    if (op.getNumOperands() == 1 && op.getNumResults() == 1 &&
+        isa<xegpu::TensorDescType>(op->getOperand(0).getType()) &&
+        isa<xegpu::TensorDescType>(op->getResult(0).getType())) {
+      rewriter.replaceOp(op, inputs);
+      return success();
+    }
+
+    // Handles the case where cast %1 : vector<256xf32> to vector<16xf32>, ...
+    // the input values provided by the adaptor should already be distributed,
+    // and their types should correspond exactly to the result types of the
+    // operation.
+    if (op.getNumOperands() == 1 &&
+        llvm::equal(ValueRange(inputs).getTypes(), op->getResultTypes())) {
+      rewriter.replaceOp(op, inputs);
+      return success();
+    }
+
+    // Handles the case where cast %1 : vector<16xf32>, ... to vector<256xf32>.
+    // All input values must have the same vector type, and their shape must be
+    // evenly divisible by the output vector's shape.
+    auto inputTy = dyn_cast<VectorType>(inputs[0].getType());
+    auto outputTy = dyn_cast<VectorType>(op->getOpResult(0).getType());
+    if (op.getNumResults() == 1 && inputTy && outputTy &&
+        llvm::all_equal(ValueRange(inputs).getTypes()) &&
+        computeShapeRatio(outputTy.getShape(), inputTy.getShape())) {
+      rewriter.replaceOpWithMultiple(op, {inputs});
+      return success();
+    }
+
+    return mlir::failure();
+  }
+};
+
 } // namespace
 
 namespace mlir {
 namespace xegpu {
 void populateXeGPUWgToSgDistributePatterns(RewritePatternSet &patterns) {
   patterns.add<WgToSgCreateNdOp, WgToSgLoadNdOp, WgToSgStoreNdOp,
-               WgToSgUpdateNdOffsetOp, WgToSgDpasOp, WgToSgPrefetchNdOp>(
-      patterns.getContext());
+               WgToSgUpdateNdOffsetOp, WgToSgDpasOp, WgToSgPrefetchNdOp,
+               UnrealizedConversionCastOpPattern>(patterns.getContext());
 }
 } // namespace xegpu
 } // namespace mlir
@@ -334,6 +402,47 @@ struct XeGPUWgToSgDistributePass
 } // namespace
 
 void XeGPUWgToSgDistributePass::runOnOperation() {
+  TypeConverter converter;
+  converter.addConversion([&](Type type) -> Type { return type; });
+  converter.addConversion(
+      [&](RankedTensorType type,
+          SmallVectorImpl<Type> &result) -> std::optional<LogicalResult> {
+        Type elemTy = type.getElementType();
+        ArrayRef<int64_t> shape = type.getShape();
+
+        int count;
+        SmallVector<int64_t> subShape;
+        std::tie(subShape, count) = getSgShapeAndCount(
+            shape, dyn_cast<xegpu::LayoutAttr>(type.getEncoding()));
+
+        auto newTy = VectorType::get(subShape, elemTy);
+        result.append(count, newTy);
+        return success();
+      });
+
+  converter.addConversion(
+      [&](xegpu::TensorDescType type,
+          SmallVectorImpl<Type> &result) -> std::optional<LogicalResult> {
+        Type elemTy = type.getElementType();
+        ArrayRef<int64_t> shape = type.getShape();
+
+        int count;
+        SmallVector<int64_t> subShape;
+        xegpu::LayoutAttr layout = type.getLayoutAttr();
+        std::tie(subShape, count) = getSgShapeAndCount(shape, layout);
+
+        if (layout)
+          layout = layout.dropSgLayoutAndData();
+
+        auto newTy = xegpu::TensorDescType::get(
+            type.getContext(), subShape, elemTy, type.getEncoding(), layout);
+        result.append(count, newTy);
+        return success();
+      });
+
+  // step1: perform SCFStructuralTypeConversions on SCF ops
+  xegpu::doSCFStructuralTypeConversionWithTensorType(getOperation(), converter);
+
   MLIRContext *ctx = &getContext();
   RewritePatternSet patterns(ctx);
   ConversionTarget target(*ctx);
@@ -353,24 +462,27 @@ void XeGPUWgToSgDistributePass::runOnOperation() {
   };
 
   auto isLegal = [&](xegpu::LayoutAttr layout) -> bool {
-    return !layout || layout.getSgLayout() == nullptr;
+    return !layout || !layout.isWgLayout();
   };
 
   target.addDynamicallyLegalOp<xegpu::CreateNdDescOp, xegpu::LoadNdOp,
                                xegpu::StoreNdOp, xegpu::UpdateNdOffsetOp,
                                xegpu::PrefetchNdOp>([=](Operation *op) -> bool {
     auto tdescTy = getTensorDescType(op);
-    auto layout = dyn_cast_or_null<xegpu::LayoutAttr>(tdescTy.getLayout());
+    auto layout = dyn_cast_if_present<xegpu::LayoutAttr>(tdescTy.getLayout());
     return isLegal(layout);
   });
 
   target.addDynamicallyLegalOp<xegpu::DpasOp>([=](xegpu::DpasOp op) -> bool {
-    auto layout = dyn_cast_or_null<xegpu::LayoutAttr>(op->getAttr("layout"));
+    auto layout = xegpu::getLayoutAttr(op.getResult());
     return isLegal(layout);
   });
 
+  target.addIllegalOp<UnrealizedConversionCastOp>();
+
   target.markUnknownOpDynamicallyLegal([](Operation *) { return true; });
 
+  // step2: Perform workgroup to subgroup distribution for the remaining ops
   xegpu::populateXeGPUWgToSgDistributePatterns(patterns);
   if (failed(
           applyPartialConversion(getOperation(), target, std::move(patterns))))
diff --git a/mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp b/mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp
index dcaf4e85a82c5..6b85a66a8bd36 100644
--- a/mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp
+++ b/mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp
@@ -27,7 +27,7 @@
 using namespace mlir;
 
 /// convert ArrayRef<ValueRange> into SmallVector<Value>
-static SmallVector<Value> flattenValues(ArrayRef<ValueRange> values) {
+SmallVector<Value> xegpu::flattenValues(ArrayRef<ValueRange> values) {
   SmallVector<Value> result;
   for (const auto &vals : values)
     llvm::append_range(result, vals);
@@ -271,7 +271,7 @@ void xegpu::doSCFStructuralTypeConversionWithTensorType(
       auto resultTy = dyn_cast<RankedTensorType>(result.getType());
 
       // Only look at ops casting from VectorType to RankedTensorType
-      if (!isa<VectorType>(inputTy) || !isa<RankedTensorType>(resultTy))
+      if (!inputTy || !resultTy)
         return WalkResult::skip();
 
       xegpu::LayoutAttr layout = xegpu::getLayoutAttr(input);
@@ -342,7 +342,7 @@ void xegpu::doSCFStructuralTypeConversionWithTensorType(
         }
 
         if (isa<RankedTensorType>(inputTy) && isa<VectorType>(outputTy)) {
-          SmallVector<Value> values = flattenValues(adaptor.getInputs());
+          SmallVector<Value> values = xegpu::flattenValues(adaptor.getInputs());
           auto newOp = rewriter.create<UnrealizedConversionCastOp>(
               op.getLoc(), outputTy, values);
           rewriter.replaceOp(op, newOp);
diff --git a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-rr.mlir b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-rr.mlir
index bee026eb2084d..ff86e65300bb8 100644
--- a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-rr.mlir
+++ b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-rr.mlir
@@ -85,7 +85,7 @@ gpu.module @test_round_robin_assignment {
     %tdesc_c = xegpu.create_nd_tdesc %c[0, 0] : memref<8x8xf32>
       -> !xegpu.tensor_desc<8x8xf32, #xegpu.layout<sg_layout = [2, 2], sg_data = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>>
     %dpas = xegpu.dpas %load_a, %load_b
-      {layout =  #xegpu.layout<sg_layout = [2, 2], sg_data = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>}
+      {layout_result_0 =  #xegpu.layout<sg_layout = [2, 2], sg_data = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>}
       : vector<8x8xf32>, vector<8x8xf32> -> vector<8x8xf32>
     gpu.return
   }
@@ -102,4 +102,28 @@ gpu.module @test_round_robin_assignment {
       : !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [4, 4], sg_data = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>>
     gpu.return
   }
+
+  gpu.func @test_scf_while_and_condition(%arg0: memref<1024xf32>, %arg1: memref<1024xf32>) {
+    %c1_i32 = arith.constant 1 : i32
+    %c10_i32 = arith.constant 10 : i32
+    %c0_i32 = arith.constant 0 : i32
+    %0 = xegpu.create_nd_tdesc %arg0[0] : memref<1024xf32> -> !xegpu.tensor_desc<256xf32, #xegpu.layout<sg_layout = [8], sg_data = [16]>>
+    %1 = xegpu.load_nd %0  : !xegpu.tensor_desc<256xf32, #xegpu.layout<sg_layout = [8], sg_data = [16]>> -> vector<256xf32>
+    %2 = xegpu.create_nd_tdesc %arg1[0] : memref<1024xf32> -> !xegpu.tensor_desc<256xf32, #xegpu.layout<sg_layout = [8], sg_data = [16]>>
+    //CHECK: scf.while ({{.*}}) : (vector<16xf32>, vector<16xf32>, i32) -> (vector<16xf32>, vector<16xf32>, i32)
+    %3:2 = scf.while (%arg2 = %1, %arg3 = %c0_i32) : (vector<256xf32>, i32) -> (vector<256xf32>, i32) {
+      %4 = arith.cmpi slt, %arg3, %c10_i32 : i32
+      //CHECK: scf.condition{{.*}} : vector<16xf32>, vector<16xf32>, i32
+      scf.condition(%4) %arg2, %arg3 : vector<256xf32>, i32
+    } do {
+    // CHECK: ([[arg2:%.+]]: vector<16xf32>, [[arg3:%.+]]: vector<16xf32>, [[arg4:%.+]]: i32)
+    ^bb0(%arg2: vector<256xf32>, %arg3: i32):
+      xegpu.store_nd %arg2, %2  : vector<256xf32>, !xegpu.tensor_desc<256xf32, #xegpu.layout<sg_layout = [8], sg_data = [16]>>
+      %4 = arith.addi %arg3, %c1_i32 : i32
+      %5 = xegpu.update_nd_offset %0, [256] : !xegpu.tensor_desc<256xf32, #xegpu.layout<sg_layout = [8], sg_data = [16]>>
+      %6 = xegpu.load_nd %5  : !xegpu.tensor_desc<256xf32, #xegpu.layout<sg_layout = [8], sg_data = [16]>> -> vector<256xf32>
+      scf.yield %6, %4 : vector<256xf32>, i32
+    }
+    gpu.return
+  }
 }
diff --git a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg.mlir b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg.mlir
index 7e89ada934071..d016d3a30a339 100644
--- a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg.mlir
+++ b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg.mlir
@@ -5,7 +5,7 @@
 gpu.module @test_1_1_assignment {
   // CHECK-LABEL: test_create_nd_tdesc
   // CHECK-SAME: %[[ARG_0:.*]]: memref<24x32xf32>
-  gpu.func @test_create_nd_tdesc(%src: memref<24x32xf32>) {  
+  gpu.func @test_create_nd_tdesc(%src: memref<24x32xf32>) {
   // CHECK: %[[SGID:.*]] = gpu.subgroup_id
   // CHECK: %[[C12:.*]] = arith.constant 12 : index
   // CHECK: %[[C4:.*]] = arith.constant 4 : index
@@ -108,7 +108,7 @@ gpu.func @test_dpas(%a: memref<24x32xf32>, %b: memref<32x24xf32>) {
       : !xegpu.tensor_desc<32x24xf32, #xegpu.layout<sg_layout = [4, 2], sg_data = [8, 12], lane_layout = [8, 2], lane_data = [1, 1]>>
       -> vector<32x24xf32>
     %dpas = xegpu.dpas %load_a, %load_b
-      {layout =  #xegpu.layout<sg_layout = [2, 2], sg_data = [12, 12], lane_layout = [2, 2], lane_data = [1, 1]>}
+      {layout_result_0 =  #xegpu.layout<sg_layout = [2, 2], sg_data = [12, 12], lane_layout = [2, 2], lane_data = [1, 1]>}
       : vector<24x32xf32>, vector<32x24xf32> -> vector<24x24xf32>
     gpu.return
   }
@@ -142,7 +142,7 @@ gpu.func @test_dpas_no_sg_data(%a: memref<24x32xf32>, %b: memref<32x24xf32>) {
       : !xegpu.tensor_desc<32x24xf32, #xegpu.layout<sg_layout = [4, 2], lane_layout = [8, 2], lane_data = [1, 1]>>
       -> vector<32x24xf32>
     %dpas = xegpu.dpas %load_a, %load_b
-      {layout =  #xegpu.layout<sg_layout = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>}
+      {layout_result_0 =  #xegpu.layout<sg_layout = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>}
       : vector<24x32xf32>, vector<32x24xf32> -> vector<24x24xf32>
     gpu.return
   }
@@ -169,4 +169,68 @@ gpu.func @test_dpas_no_sg_data(%a: memref<24x32xf32>, %b: memref<32x24xf32>) {
       : vector<24x32xf32>, vector<32x24xf32> -> vector<24x24xf32>
     gpu.return
   }
+
+  gpu.func @test_scf_for(%arg0: memref<1024x1024xf16>, %arg1: memref<1024x1024xf16>, %arg2: memref<1024x1024xf32>) {
+    //CHECK: [[c0:%.+]] = arith.constant 0 : index
+    //CHECK: [[c128:%.+]] = arith.constant 128 : index
+    //CHECK: [[c1024:%.+]] = arith.constant 1024 : index
+    %c0 = arith.constant 0 : index
+    %c128 = arith.constant 128 : index
+    %c1024 = arith.constant 1024 : index
+    %block_id_x = gpu.block_id  x
+    %block_id_y = gpu.block_id  y
+    %0 = arith.muli %block_id_x, %c128 : index
+    %1 = arith.muli %block_id_y, %c128 : index
+    %2 = xegpu.create_nd_tdesc %arg2[%0, %1] : memref<1024x1024xf32> -> !xegpu.tensor_desc<128x128xf32, #xegpu.layout<sg_layout = [8, 8], sg_data = [16, 16]>>
+    %3 = xegpu.load_nd %2 : !xegpu.tensor_desc<128x128xf32, #xegpu.layout<sg_layout = [8, 8], sg_data = [16, 16]>> -> vector<128x128xf32>
+    %4 = xegpu.create_nd_tdesc %arg0[%0, %c0] : memref<1024x1024xf16> -> !xegpu.tensor_desc<128x128xf16, #xegpu.layout<sg_layout = [8, 8], sg_data = [16, 128]>>
+    %5 = xegpu.create_nd_tdesc %arg1[%c0, %1] : memref<1024x1024xf16> -> !xegpu.tensor_desc<128x128xf16, #xegpu.layout<sg_layout = [8, 8], sg_data = [128, 16]>>
+
+    //CHECK: [[scf:%.+]]:3 = scf.for [[arg3:%.+]] = [[c0]] to [[c1024]] step [[c128]] iter_args([[arg4:%.+]] = {{.*}}, [[arg5:%.+]] = {{.*}}, [[arg6:%.+]] = {{.*}}) -> (!xegpu.tensor_desc<16x128xf16>, !xegpu.tensor_desc<128x16xf16>, vector<16x16xf32>)
+    //CHECK: [[a:%.+]] = xegpu.load_nd [[arg4]] : !xegpu.tensor_desc<16x128xf16> -> vector<16x128xf16>
+    //CHECK: [[b:%.+]] = xegpu.load_nd [[arg5]] : !xegpu.tensor_desc<128x16xf16> -> vector<128x16xf16>
+    //CHECK: [[c:%.+]] = xegpu.dpas [[a]], [[b]], [[arg6]] : vector<16x128xf16>, vector<128x16xf16>, vector<16x16xf32> -> vector<16x16xf32>
+    //CHECK: [[at:%.+]] = xegpu.update_nd_offset [[arg4]], [[[c0]], [[c128]]] : !xegpu.tensor_desc<16x128xf16>
+    //CHECK: [[bt:%.+]] = xegpu.update_nd_offset [[arg5]], [[[c128]], [[c0]]] : !xegpu.tensor_desc<128x16xf16>
+    //CHECK: scf.yield [[at]], [[bt]], [[c]] : !xegpu.tensor_desc<16x128xf16>, !xegpu.tensor_desc<128x16xf16>, vector<16x16xf32>
+    %6:3 = scf.for %arg3 = %c0 to %c1024 step %c128 iter_args(%arg4 = %4, %arg5 = %5, %arg6 = %3) -> (!xegpu.tensor_desc<128x128xf16, #xegpu.layout<sg_layout = [8, 8], sg_data = [16, 128]>>, !xegpu.tensor_desc<128x128xf16, #xegpu.layout<sg_layout = [8, 8], sg_data = [128, 16]>>, vector<128x128xf32>) {
+      %8 = xegpu.load_nd %arg4  : !xegpu.tensor_desc<128x128xf16, #xegpu.layout<sg_layout = [8, 8], sg_data = [16, 128]>> -> vector<128x128xf16>
+      %9 = xegpu.load_nd %arg5  : !xegpu.tensor_desc<128x128xf16, #xegpu.layout<sg_layout = [8, 8], sg_data = [128, 16]>> -> vector<128x128xf16>
+      %10 = xegpu.dpas %8, %9, %arg6 {layout_result_0 = #xegpu.layout<sg_layout = [8, 8], sg_data = [16, 16]>} : vector<128x128xf16>, vector<128x128xf16>, vector<128x128xf32> -> vector<128x128xf32>
+      %11 = xegpu.update_nd_offset %arg4, [%c0, %c128] : !xegpu.tensor_desc<128x128xf16, #xegpu.layout<sg_layout = [8, 8], sg_data = [16, 128]>>
+      %12 = xegpu.upd...
[truncated]

``````````

</details>


https://github.com/llvm/llvm-project/pull/142618


More information about the Mlir-commits mailing list