[Mlir-commits] [mlir] [mlir][xegpu] SIMT distribution patterns for XeGPU CreateNdTdesc, LoadNd, StoreNd and Dpas Ops. (PR #135271)

Charitha Saumya llvmlistbot at llvm.org
Wed Apr 16 13:13:49 PDT 2025


================
@@ -634,6 +678,643 @@ void RunSGMapPropagation::printAnalysisResult(llvm::raw_ostream &os) {
   }
 }
 
+/// Attaches `layout` to every operation that uses `v`. The attribute name is
+/// `operandLayoutNamePrefix` followed by the index of the operand through
+/// which the user consumes `v`.
+static void attachLayoutAttributeToUsers(Value v, xegpu::LayoutAttr layout) {
+  for (OpOperand &use : v.getUses()) {
+    /// The generic, index-based name keeps later lookups of the layout simple.
+    std::string name =
+        operandLayoutNamePrefix + std::to_string(use.getOperandNumber());
+    use.getOwner()->setAttr(name, layout);
+  }
+}
+
+/// Walks `top` and materializes the layouts returned by `getPropagatedLayout`
+/// in the IR:
+///  * for function ops, the layout of each argument is attached to the users
+///    of that argument as an operand layout attribute;
+///  * for ops whose first result is a tensor descriptor, the op is cloned with
+///    the layout embedded in the result type and the original op is erased;
+///  * for any other op that has at least one non-scalar result, each result
+///    with an assigned layout gets a result layout attribute on the op and an
+///    operand layout attribute on its users.
+/// Returns failure if a tensor descriptor result has no assigned layout.
+static LogicalResult attachLayoutAttributes(
+    Operation *top, llvm::function_ref<LayoutInfo(Value)> getPropagatedLayout) {
+  /// Helper to convert the layout info to the xegpu::LayoutAttr. Returns a
+  /// null attribute when no layout was assigned to the value.
+  auto getLayoutInfoForResult = [&](Value r) -> xegpu::LayoutAttr {
+    auto propagated = getPropagatedLayout(r);
+    if (!propagated.isAssigned())
+      return {};
+    SmallVector<int, 2> laneLayout, laneData;
+    /// The loop variables deliberately do not reuse the name of the layout
+    /// info they iterate over, so nothing is shadowed.
+    for (auto [lanes, data] : llvm::zip_equal(propagated.getLayoutAsArrayRef(),
+                                              propagated.getDataAsArrayRef())) {
+      laneLayout.push_back(static_cast<int>(lanes));
+      laneData.push_back(static_cast<int>(data));
+    }
+    return xegpu::LayoutAttr::get(r.getContext(), laneLayout, laneData);
+  };
+  /// Attach the layout attributes to the results of the operations.
+  auto walkResult = top->walk([&](Operation *op) {
+    /// For function ops, propagate the argument layout to the users.
+    if (auto func = dyn_cast<FunctionOpInterface>(op)) {
+      for (auto arg : func.getArguments()) {
+        auto layoutInfo = getLayoutInfoForResult(arg);
+        if (layoutInfo) {
+          attachLayoutAttributeToUsers(arg, layoutInfo);
+        }
+      }
+      return WalkResult::advance();
+    }
+    /// If no results, move on.
+    if (op->getNumResults() == 0)
+      return WalkResult::advance();
+    /// If all the results are scalars, move on.
+    if (llvm::all_of(op->getResultTypes(),
+                     [](Type t) { return t.isIntOrIndexOrFloat(); }))
+      return WalkResult::advance();
+
+    if (auto tensorDescTy =
+            dyn_cast<xegpu::TensorDescType>(op->getResult(0).getType())) {
+      auto layoutInfo = getLayoutInfoForResult(op->getResult(0));
+      if (!layoutInfo) {
+        LLVM_DEBUG(DBGS() << "No layout for result of " << *op << "\n");
+        return WalkResult::interrupt();
+      }
+
+      /// Clone the op, attach the layout to the result tensor descriptor, and
+      /// remove the original op.
+      OpBuilder builder(op);
+      auto *newOp = builder.clone(*op);
+      auto newTensorDescTy = xegpu::TensorDescType::get(
+          tensorDescTy.getContext(), tensorDescTy.getShape(),
+          tensorDescTy.getElementType(), tensorDescTy.getEncoding(),
+          layoutInfo);
+      newOp->getResult(0).setType(newTensorDescTy);
+      op->replaceAllUsesWith(newOp->getResults());
+      op->erase();
+      return WalkResult::advance();
+    }
+    /// Otherwise simply attach the layout to the op itself.
+    for (auto [i, r] : llvm::enumerate(op->getResults())) {
+      auto layoutInfo = getLayoutInfoForResult(r);
+      if (layoutInfo) {
+        auto attrName = resultLayoutNamePrefix + std::to_string(i);
+        op->setAttr(attrName, layoutInfo);
+        /// Attach the layout attribute to the users of the result.
+        attachLayoutAttributeToUsers(r, layoutInfo);
+      }
+    }
+    return WalkResult::advance();
+  });
+
+  /// An interrupted walk means some tensor descriptor had no layout.
+  return failure(walkResult.wasInterrupted());
+}
+
+static LogicalResult resolveLayoutConflicts(Operation *top) {
+  /// TODO: Implement layout conflict resolution; for now `top` is ignored.
+  return success();
+}
+
+namespace {
+
+///===----------------------------------------------------------------------===///
+/// SIMT Distribution Patterns
+///===----------------------------------------------------------------------===///
+
+/// Computes the per-lane vector type obtained by distributing `originalType`
+/// according to the lane_layout of `layout`. The trailing dimensions of the
+/// vector (as many as lane_layout has entries) are each divided by the
+/// matching lane_layout entry; any leading dimensions are left untouched.
+/// Fails if `layout` is null or if a distributed dimension is not evenly
+/// divisible by its lane_layout entry.
+///
+/// Examples:
+/// | original vector shape | lane_layout | distributed vector shape |
+/// |-----------------------|-------------|--------------------------|
+/// | 32x16                 | [1, 16]     | 32x1                     |
+/// | 32x16                 | [2, 8]      | 16x2                     |
+/// | 2x32x16               | [1, 16]     | 2x32x1                   |
+FailureOr<VectorType> getDistVecTypeBasedOnLaneLayout(xegpu::LayoutAttr layout,
+                                                      VectorType originalType) {
+  if (!layout)
+    return failure();
+
+  auto lanes = layout.getLaneLayout().asArrayRef();
+  auto origShape = originalType.getShape();
+  assert(origShape.size() >= lanes.size() &&
+         "Rank of the original vector type should be greater or equal to the "
+         "size of the lane layout to distribute the vector type.");
+  SmallVector<int64_t> perLaneShape(origShape);
+  /// Index of the first distributed dimension; every dimension before it is
+  /// copied over unchanged.
+  unsigned firstDistributedDim = originalType.getRank() - lanes.size();
+  for (auto [idx, extent] : llvm::enumerate(origShape)) {
+    if (idx >= firstDistributedDim) {
+      auto laneCount = lanes[idx - firstDistributedDim];
+      /// Give up unless the dimension splits evenly across the lanes.
+      if (extent % laneCount != 0)
+        return failure();
+      perLaneShape[idx] = extent / laneCount;
+    }
+  }
+  return VectorType::get(perLaneShape, originalType.getElementType());
+}
+
+/// Returns the distributed vector type for `originalType` under `layout`. A
+/// temporary tensor descriptor type with the same shape and element type is
+/// built and asked for its distributed vector type; the computation is
+/// asserted to succeed.
+static VectorType getDistributedVectorType(xegpu::LayoutAttr layout,
+                                           VectorType originalType) {
+  auto helperDescTy = xegpu::TensorDescType::get(
+      originalType.getShape(), originalType.getElementType(),
+      /*array_length=*/1, /*boundary_check=*/true,
+      /*memory_space=*/xegpu::MemorySpace::Global, layout);
+  auto distributedOrFailure = helperDescTy.getDistributedVectorType();
+  assert(llvm::succeeded(distributedOrFailure) &&
+         "Failed to compute distributed vector type for the given vector type");
+  return distributedOrFailure.value();
+}
+
+/// Returns a tensor descriptor type with the same shape, element type and
+/// encoding as `tensorDesc`, but with a null layout attribute.
+static xegpu::TensorDescType dropLayouts(xegpu::TensorDescType tensorDesc) {
+  auto *ctx = tensorDesc.getContext();
+  auto shape = tensorDesc.getShape();
+  auto elementTy = tensorDesc.getElementType();
+  auto encoding = tensorDesc.getEncoding();
+  return xegpu::TensorDescType::get(ctx, shape, elementTy, encoding,
+                                    xegpu::LayoutAttr());
+}
+
+/// Reconciles the type of `orig` with the `expected` type by inserting a cast
+/// when they differ:
+///  * identical types: `orig` is returned unchanged;
+///  * `orig` is a vector: a vector.shape_cast to `expected` is created;
+///  * `orig` is a tensor descriptor: an unrealized_conversion_cast to
+///    `expected` is created.
+/// Any other type of `orig` is a programming error.
+template <typename T>
+static Value resolveDistributedTy(Value orig, T expected,
+                                  PatternRewriter &rewriter) {
+  /// If orig and expected types are the same, return orig.
+  if (orig.getType() == expected)
+    return orig;
+  /// If orig is a vector type, create a shape cast op to reconcile the types.
+  if (isa<VectorType>(orig.getType())) {
+    auto castOp =
+        rewriter.create<vector::ShapeCastOp>(orig.getLoc(), expected, orig);
+    return castOp.getResult();
+  }
+  /// If orig is a tensor descriptor type, create an unrealized conversion cast
+  /// op to reconcile the types.
+  if (isa<xegpu::TensorDescType>(orig.getType())) {
+    auto castOp = rewriter.create<UnrealizedConversionCastOp>(orig.getLoc(),
+                                                              expected, orig);
+    return castOp.getResult(0);
+  }
+  llvm_unreachable("Unsupported type for reconciliation");
+}
+
+// static Value reconcileDistributedTensorDescTy(Value orig,
+//                                               xegpu::TensorDescType expected,
+//                                               PatternRewriter &rewriter) {
+//   assert(isa<xegpu::TensorDescType>(orig.getType()) &&
+//          "expecting tensor descriptor type");
+//   auto origTensorDescTy = cast<xegpu::TensorDescType>(orig.getType());
+//   /// No need to reconcile if the types are the same.
+//   if (origTensorDescTy == expected)
+//     return orig;
+//   auto castOp = rewriter.create<UnrealizedConversionCastOp>(orig.getLoc(),
+//                                                             expected, orig);
+//   return castOp.getResult(0);
+// }
+
+// // unify above 2 functions with a template
+// template <typename T>
+// static Value reconcileDistributedType(Value orig, T expected,
+//                                        PatternRewriter &rewriter) {
+//   if constexpr (std::is_same_v<T, VectorType>) {
+//     return reconcileDistributedVecType(orig, expected, rewriter);
+//   } else if constexpr (std::is_same_v<T, xegpu::TensorDescType>) {
+//     return reconcileDistributedTensorDescTy(orig, expected, rewriter);
+//   } else {
+//     static_assert(llvm::is_one_of<T, VectorType,
+//     xegpu::TensorDescType>::value,
+//                   "Unsupported type for reconciliation");
+//   }
+//   return orig;
+// }
+
+/// Returns a copy of `attrs` without the temporary layout attributes, i.e.
+/// without any attribute whose name contains `operandLayoutNamePrefix` or
+/// `resultLayoutNamePrefix`. The relative order of the kept attributes is
+/// preserved.
+static SmallVector<NamedAttribute>
+filterTemporaryLayoutAttributes(ArrayRef<NamedAttribute> attrs) {
+  SmallVector<NamedAttribute> kept;
+  for (auto attr : attrs) {
+    auto name = attr.getName().strref();
+    bool isTemporaryLayout = name.contains(operandLayoutNamePrefix) ||
+                             name.contains(resultLayoutNamePrefix);
+    if (!isTemporaryLayout)
+      kept.push_back(attr);
+  }
+  return kept;
+}
+
+/// Given a GPUFuncOp, this pattern creates a new GPUFuncOp and moves the body
+/// of the original GPUFuncOp to the new GPUFuncOp such that the entire body is
+/// contained within a WarpExecuteOnLane0Op.
+/// Example:
+///
+/// ```
+///   gpu.func @foo(%arg0: memref<*xf16>) -> vector<8x16xf32> {
+///     ...
+///     ...
+///     gpu.return %result: vector<8x16xf32>
+///   }
+/// ```
+/// To
+/// ```
+///   gpu.func @foo(%arg0: memref<*xf16>) -> vector<8x16xf32> {
+///     %laneid = gpu.lane_id : index
+///     %0 = gpu.warp_execute_on_lane_0(%laneid) -> vector<8x16xf32> {
+///       ...
+///       gpu.yield %result: vector<8x16xf32>
+///     }
+///     gpu.return %0
+///   }
+/// ```
+struct MoveFuncBodyToWarpExecuteOnLane0
+    : public OpRewritePattern<gpu::GPUFuncOp> {
+  using OpRewritePattern<gpu::GPUFuncOp>::OpRewritePattern;
+  LogicalResult matchAndRewrite(gpu::GPUFuncOp gpuFuncOp,
+                                PatternRewriter &rewriter) const override {
+    /// If the function only contains a single void return, skip.
+    if (llvm::all_of(gpuFuncOp.getBody().getOps(), [](Operation &op) {
+          return isa<gpu::ReturnOp>(op) && !op.getNumOperands();
+        }))
+      return failure();
+    /// If the body was already moved inside a warp_execute_on_lane_0, skip.
+    if (llvm::any_of(gpuFuncOp.getBody().getOps(), [](Operation &op) {
+          return isa<gpu::WarpExecuteOnLane0Op>(op);
+        }))
+      return failure();
+    /// Create a new function with the same name and signature.
+    auto newGpuFunc = rewriter.create<gpu::GPUFuncOp>(
+        gpuFuncOp.getLoc(), gpuFuncOp.getName(), gpuFuncOp.getFunctionType());
+    /// Create a WarpExecuteOnLane0Op with same arguments and results as the
+    /// original gpuFuncOp.
+    rewriter.setInsertionPointToEnd(&newGpuFunc.getFunctionBody().front());
+    auto laneId = rewriter.create<gpu::LaneIdOp>(
+        newGpuFunc.getLoc(), rewriter.getIndexType(),
+        /** upperBound = **/ mlir::IntegerAttr());
+    auto gpuFuncResultType = gpuFuncOp.getFunctionType().getResults();
+    auto warpOp = rewriter.create<gpu::WarpExecuteOnLane0Op>(
+        laneId.getLoc(), gpuFuncResultType, laneId, subgroupSize,
+        newGpuFunc.getArguments(), newGpuFunc.getArgumentTypes());
+    auto &warpBodyBlock = warpOp.getBodyRegion().front();
+    /// Replace the ReturnOp of the original gpu function with a YieldOp.
+    auto origRetunOp =
+        cast<gpu::ReturnOp>(gpuFuncOp.getBlocks().back().getTerminator());
+    rewriter.setInsertionPointAfter(origRetunOp);
+    rewriter.create<gpu::YieldOp>(origRetunOp.getLoc(),
+                                  origRetunOp.getOperands());
+    rewriter.eraseOp(origRetunOp);
+    /// Move the original function body to the WarpExecuteOnLane0Op body.
+    rewriter.inlineRegionBefore(gpuFuncOp.getBody(), warpOp.getBodyRegion(),
+                                warpOp.getBodyRegion().begin());
+    rewriter.eraseBlock(&warpBodyBlock);
+    /// Insert a new ReturnOp after the WarpExecuteOnLane0Op.
+    rewriter.setInsertionPointAfter(warpOp);
+    rewriter.create<gpu::ReturnOp>(newGpuFunc.getLoc(), warpOp.getResults());
+    rewriter.replaceOp(gpuFuncOp, newGpuFunc);
+    return success();
+  }
+};
+
+/// Clone a create_nd_tdesc feeding into gpu.yield op for the enclosing
+/// `gpu.warp_execute_on_lane_0` and put it after the warp op. The warp op
+/// will still contain the original op that will not be used by the yield op
+/// (and should be cleaned up later with dce). The yield op will bypass the
+/// create_nd_tdesc's arguments. Tensor descriptor is not distributed because
+/// it is a uniform value across all work items within the subgroup.
+///
+/// Example:
+///
+/// ```
+///   #sg_map_8 = #xegpu.sg_map<wi_layout = [1, 8], wi_data = [1, 1]>
+///   %r = gpu.warp_execute_on_lane_0(%laneid) ->
+///                   (!xegpu.tensor_desc<4x8xf32>) {
+///     ...
+///     %td = xegpu.create_nd_tdesc %arg0[0, 0]
+///               : memref<4x8xf32> -> !xegpu.tensor_desc<4x8xf32>
+///     gpu.yield %td
+///   }
+/// ```
+/// To
+/// ```
+///   %r:2 = gpu.warp_execute_on_lane_0(%laneid) -> () {
+///     ...
+///     %dead = xegpu.create_nd_tdesc %arg0[0, 0]
+///               : memref<4x8xf32> -> !xegpu.tensor_desc<4x8xf32>
+///     gpu.yield %arg0, %dead
+///   }
+///   %td = xegpu.create_nd_tdesc %r#0[0, 0]: memref<4x8xf32>
+///                                 -> !xegpu.tensor_desc<4x8xf32>
+///
+/// ```
+struct CreateNdDescDistribution final : public gpu::WarpDistributionPattern {
+  using gpu::WarpDistributionPattern::WarpDistributionPattern;
+  LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op subgroupOp,
+                                PatternRewriter &rewriter) const override {
+    OpOperand *operand =
+        getWarpResult(subgroupOp, llvm::IsaPred<xegpu::CreateNdDescOp>);
+    if (!operand)
+      return rewriter.notifyMatchFailure(
+          subgroupOp, "warp result is not a xegpu::CreateNdDesc op");
+    auto descOp = operand->get().getDefiningOp<xegpu::CreateNdDescOp>();
+    unsigned operandIdx = operand->getOperandNumber();
+
+    auto srcTypedVal = dyn_cast<TypedValue<MemRefType>>(descOp.getSource());
+    if (!srcTypedVal)
+      return rewriter.notifyMatchFailure(
+          descOp, "expecting a memref typed value as the source");
+
+    auto descOffsets = descOp.getMixedOffsets();
+
+    xegpu::LayoutAttr layout = descOp.getType().getLayoutAttr();
+    if (!layout)
+      return rewriter.notifyMatchFailure(
+          descOp, "the tensor descriptor lacks sg_map attribute");
+
+    SmallVector<size_t> newRetIndices;
+    SmallVector<Value> newYieldValues;
+    SmallVector<Type> newYieldTypes;
+
+    for (auto arg : descOp->getOperands()) {
+      newYieldValues.push_back(arg);
+      newYieldTypes.push_back(arg.getType());
+    }
+    rewriter.setInsertionPoint(subgroupOp);
+    gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
+        rewriter, subgroupOp, /* new yielded values = */ newYieldValues,
+        /* new yielded types = */ newYieldTypes, newRetIndices);
+
+    SmallVector<Value> newDescOperands;
+    for (auto i : newRetIndices) {
+      newDescOperands.push_back(newWarpOp.getResult(i));
+    }
+    rewriter.setInsertionPointAfter(newWarpOp);
+    auto distributedTensorDescTy =
+        dropLayouts(descOp.getType()); /// Distributed tensor descriptor type
+                                       /// does not contain layout info.
+    auto newDescOp = rewriter.create<xegpu::CreateNdDescOp>(
+        newWarpOp.getLoc(), distributedTensorDescTy, newDescOperands,
+        descOp->getAttrs());
+
+    Value distributedVal = newWarpOp.getResult(operandIdx);
+    rewriter.replaceAllUsesWith(distributedVal, newDescOp);
+    return success();
+  }
+};
+
+/// Sink a store_nd op at the end of enclosing `gpu.warp_execute_on_lane_0`.
+/// In case arguments for the store are passed through the warp op interface
+/// they would be propagated as returned values. Only the source vector for
+/// the store is distributed according to sg_map attribute.
+///
+/// Example:
+///
+/// ```
+///   #sg_map_8 = #xegpu.sg_map<wi_layout = [1, 8], wi_data = [1, 1]>
+///   gpu.warp_execute_on_lane_0(%laneid) -> () {
+///     ...
+///     xegpu.store_nd %arg0, %arg1: vector<4x8xf32>,
+///                                 !xegpu.tensor_desc<4x8xf32>
+///   }
+/// ```
+/// To
+/// ```
+///   %r:2 = gpu.warp_execute_on_lane_0(%laneid) -> () {
+///     gpu.yield %arg0, %arg1: vector<4x8xf32>, !xegpu.tensor_desc<4x8xf32>
+///   }
+///   xegpu.store_nd %r#0, %r#1: vector<4x1xf32>,
+///     !xegpu.tensor_desc<4x8xf32>
+///
+/// ```
+struct StoreNdDistribution final : public gpu::WarpDistributionPattern {
+  using gpu::WarpDistributionPattern::WarpDistributionPattern;
+  LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op subgroupOp,
+                                PatternRewriter &rewriter) const override {
+    auto yield = cast<gpu::YieldOp>(
+        subgroupOp.getBodyRegion().getBlocks().begin()->getTerminator());
+    Operation *lastNode = yield->getPrevNode();
+    auto storeOp = dyn_cast_or_null<xegpu::StoreNdOp>(lastNode);
+    if (!storeOp)
+      return failure();
+
+    auto tensorDescTy = storeOp.getTensorDescType();
+    xegpu::LayoutAttr layout = tensorDescTy.getLayoutAttr();
+    if (!layout)
+      return rewriter.notifyMatchFailure(
+          storeOp, "the source tensor descriptor lacks sg_map attribute");
+
+    auto distributedTypeByWarpOpOrFailure =
+        getDistVecTypeBasedOnLaneLayout(layout, storeOp.getValueType());
+    if (failed(distributedTypeByWarpOpOrFailure))
+      return rewriter.notifyMatchFailure(storeOp,
+                                         "Failed to distribute the type");
+    VectorType distributedTypeByWarpOp =
+        distributedTypeByWarpOpOrFailure.value();
+
+    SmallVector<size_t> newRetIndices;
+    gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
+        rewriter, subgroupOp,
+        /* new yielded values = */
+        ValueRange{storeOp.getValue(), storeOp.getTensorDesc()},
+        /* new yielded types = */
+        TypeRange{distributedTypeByWarpOp, storeOp.getTensorDescType()},
+        newRetIndices);
+    /// Create a new store op outside the warp op with the distributed vector
+    /// type. Tensor descriptor is not distributed.
+    rewriter.setInsertionPointAfter(newWarpOp);
+    SmallVector<Value> newStoreOperands;
+
+    /// For the value operand, there can be a mismatch between the vector type
+    /// distributed by the warp op and (xegpu-specific) distributed type
+    /// supported by the store op. Type mismatch must be resolved using
+    /// appropriate cast op.
+    auto storeNdDistributedValueTyOrFailure =
+        storeOp.getTensorDescType().getDistributedVectorType();
+    if (failed(storeNdDistributedValueTyOrFailure))
+      return rewriter.notifyMatchFailure(
+          storeOp, "Failed to get distributed vector type for the store op");
+    newStoreOperands.push_back(resolveDistributedTy(
+        newWarpOp.getResult(newRetIndices[0]),
+        storeNdDistributedValueTyOrFailure.value(), rewriter));
+    /// For the tensor descriptor operand, the layout attribute is dropped
+    /// after distribution. Types need to be resolved in this case also.
+    auto distributedTensorDescTy = dropLayouts(storeOp.getTensorDescType());
+    newStoreOperands.push_back(
+        resolveDistributedTy(newWarpOp.getResult(newRetIndices[1]),
+                             distributedTensorDescTy, rewriter));
+
+    rewriter.create<xegpu::StoreNdOp>(
+        newWarpOp.getLoc(), TypeRange{}, newStoreOperands,
+        filterTemporaryLayoutAttributes(storeOp->getAttrs()));
+    rewriter.eraseOp(storeOp);
+    return success();
+  }
+};
+
+/// Clone a load_nd feeding into gpu.yield op for the enclosing
+/// `gpu.warp_execute_on_lane_0` and put it after the warp op.
+/// The warp op will still contain the original op that will not be used by
+/// the yield op (and should be cleaned up later with dce). The yield op will
+/// bypass the load's arguments. Only the loaded vector is distributed
+/// according to sg_map attribute; the tensor descriptor type is not
+/// distributed.
+///
+/// Example:
+///
+/// ```
+///   #sg_map_8 = #xegpu.sg_map<wi_layout = [1, 8], wi_data = [1, 1]>
+///   %r = gpu.warp_execute_on_lane_0(%laneid) ->
+///                   (vector<4x1xf32>) {
+///     ...
+///     %ld = xegpu.load_nd %arg0, %arg1: !xegpu.tensor_desc<4x8xf32> ->
+///       vector<4x8xf32>
+///     gpu.yield %ld
+///   }
+/// ```
+/// To
+/// ```
+///   %r:2 = gpu.warp_execute_on_lane_0(%laneid) -> () {
+///     ...
+///     %dead = xegpu.load_nd %arg0: !xegpu.tensor_desc<4x8xf32> ->
+///       vector<4x8xf32>
+///     gpu.yield %dead, %arg0
+///   }
+///   %ld = xegpu.load_nd %r#1: !xegpu.tensor_desc<4x8xf32> -> vector<4x1xf32>
+/// ```
+struct LoadNdDistribution final : public gpu::WarpDistributionPattern {
+  using gpu::WarpDistributionPattern::WarpDistributionPattern;
+  LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op subgroupOp,
+                                PatternRewriter &rewriter) const override {
+    OpOperand *operand =
+        getWarpResult(subgroupOp, llvm::IsaPred<xegpu::LoadNdOp>);
+    if (!operand)
+      return rewriter.notifyMatchFailure(
+          subgroupOp, "warp result is not a xegpu::LoadNd op");
+
+    auto loadOp = operand->get().getDefiningOp<xegpu::LoadNdOp>();
+    xegpu::TensorDescType tensorDescTy = loadOp.getTensorDescType();
+    xegpu::LayoutAttr layout = tensorDescTy.getLayoutAttr();
+    if (!layout)
+      return rewriter.notifyMatchFailure(
+          loadOp, "the source tensor descriptor lacks sg_map attribute");
+
+    unsigned operandIdx = operand->getOperandNumber();
+    VectorType distributedTypeByWarpOp =
+        cast<VectorType>(subgroupOp.getResult(operandIdx).getType());
+
+    SmallVector<size_t> newRetIndices;
+    gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
+        rewriter, subgroupOp,
+        /* new yielded values = */ loadOp.getTensorDesc(),
+        /* new yielded types = */ tensorDescTy, newRetIndices);
+
+    /// Create a new load op outside the warp op with the distributed vector
+    /// type.
+    rewriter.setInsertionPointAfter(newWarpOp);
+    auto loadNdDistValueTyOrFailure =
+        loadOp.getTensorDescType().getDistributedVectorType();
+    if (failed(loadNdDistValueTyOrFailure))
+      return rewriter.notifyMatchFailure(
+          loadOp, "Failed to get distributed vector type for the load op");
+    auto distributedTensorDescTy =
+        dropLayouts(loadOp.getTensorDescType()); /// Distributed tensor
+                                                 /// descriptor type does not
+                                                 /// contain layout info.
+    Value newLoadOp = rewriter.create<xegpu::LoadNdOp>(
+        newWarpOp.getLoc(), loadNdDistValueTyOrFailure.value(),
+        resolveDistributedTy(newWarpOp->getResult(newRetIndices[0]),
+                             distributedTensorDescTy, rewriter),
+        filterTemporaryLayoutAttributes(loadOp->getAttrs()));
+    Value distributedVal = newWarpOp.getResult(operandIdx);
+    /// There can be a conflict between the vector type distributed by the
+    /// warp op and (xegpu-specific) distributed type supported by the load
+    /// op. Resolve these mismatches by inserting a cast.
+    newLoadOp =
+        resolveDistributedTy(newLoadOp, distributedTypeByWarpOp, rewriter);
+    rewriter.replaceAllUsesWith(distributedVal, newLoadOp);
+    return success();
+  }
+};
+
+struct DpasDistribution final : public gpu::WarpDistributionPattern {
----------------
charithaintc wrote:

added all the missing comments and refactored code for clarity. 

https://github.com/llvm/llvm-project/pull/135271


More information about the Mlir-commits mailing list