[Mlir-commits] [mlir] 0210750 - [MLIR][XeGPU] Add unroll patterns and blocking pass for XeGPU [2/N] (#140163)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Mon Jun 2 12:02:48 PDT 2025


Author: Chao Chen
Date: 2025-06-02T14:02:45-05:00
New Revision: 0210750d5a5b4cfc8d2b6a9e94ace24d31d65ddc

URL: https://github.com/llvm/llvm-project/commit/0210750d5a5b4cfc8d2b6a9e94ace24d31d65ddc
DIFF: https://github.com/llvm/llvm-project/commit/0210750d5a5b4cfc8d2b6a9e94ace24d31d65ddc.diff

LOG: [MLIR][XeGPU] Add unroll patterns and blocking pass for XeGPU [2/N] (#140163)

This PR introduces the initial implementation of a blocking pass for
XeGPU programs. The pass leverages unroll patterns from both the XeGPU
and Vector dialects. 

---------

Co-authored-by: Adam Siemieniuk <adam.siemieniuk at intel.com>

Added: 
    mlir/lib/Dialect/XeGPU/Transforms/XeGPUBlocking.cpp
    mlir/test/Dialect/XeGPU/xegpu-blocking.mlir

Modified: 
    mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td
    mlir/include/mlir/Dialect/XeGPU/Transforms/Passes.td
    mlir/include/mlir/Dialect/XeGPU/Utils/XeGPUUtils.h
    mlir/lib/Dialect/XeGPU/Transforms/CMakeLists.txt
    mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
    mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp
    mlir/lib/Dialect/XeGPU/Utils/CMakeLists.txt
    mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td
index 032ce5bc18334..84c1dc1373ee5 100644
--- a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td
+++ b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td
@@ -295,11 +295,17 @@ def XeGPU_LayoutAttr : XeGPUAttr<"Layout", "layout"> {
     }
 
     LayoutAttr dropSgLayoutAndData() {
+      // Return null if neither inst_data nor lane_layout would be kept.
+      if (!getInstData() && !getLaneLayout())
+        return nullptr;
       return LayoutAttr::get(getContext(), nullptr, nullptr, getInstData(),
                              getLaneLayout(), getLaneData(), getOrder());
     }
 
     LayoutAttr dropInstData() {
+      // Return null if neither sg_layout nor lane_layout would be kept.
+      if (!getSgLayout() && !getLaneLayout())
+        return nullptr;
       return LayoutAttr::get(getContext(), getSgLayout(), getSgData(), nullptr,
                              getLaneLayout(), getLaneData(), getOrder());
     }

diff  --git a/mlir/include/mlir/Dialect/XeGPU/Transforms/Passes.td b/mlir/include/mlir/Dialect/XeGPU/Transforms/Passes.td
index 6f585f9ceb29b..8bdf19ac0e47d 100644
--- a/mlir/include/mlir/Dialect/XeGPU/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/XeGPU/Transforms/Passes.td
@@ -45,4 +45,17 @@ def XeGPUWgToSgDistribute : Pass<"xegpu-wg-to-sg-distribute"> {
                            "gpu::GPUDialect", "index::IndexDialect"];
 }
 
+def XeGPUBlocking: Pass<"xegpu-blocking"> {
+  let summary = "Block XeGPU ops into smaller size.";
+  let description = [{
+    This pass partitions operations that process large shapes into multiple
+    operations on smaller shapes, as specified by the inst_data in the layout
+    attribute. This enables each resulting operation to be efficiently mapped
+    to a hardware instruction.
+  }];
+  let dependentDialects = [
+      "memref::MemRefDialect", "xegpu::XeGPUDialect", "vector::VectorDialect"
+  ];
+}
+
 #endif // MLIR_DIALECT_XEGPU_TRANSFORMS_PASSES_TD

diff  --git a/mlir/include/mlir/Dialect/XeGPU/Utils/XeGPUUtils.h b/mlir/include/mlir/Dialect/XeGPU/Utils/XeGPUUtils.h
index 3616fa614e7f9..f9327d63869c0 100644
--- a/mlir/include/mlir/Dialect/XeGPU/Utils/XeGPUUtils.h
+++ b/mlir/include/mlir/Dialect/XeGPU/Utils/XeGPUUtils.h
@@ -13,6 +13,12 @@
 namespace mlir {
 
 class VectorType;
+class OpOperand;
+class OpResult;
+class OpBuilder;
+class ValueRange;
+class TypeConverter;
+
 namespace xegpu {
 class LayoutAttr;
 class TensorDescType;
@@ -50,6 +56,59 @@ FailureOr<VectorType> getDistributedVectorType(xegpu::TensorDescType tdescTy);
 FailureOr<VectorType> getDistributedVectorType(VectorType originalType,
                                                LayoutAttr layout);
 
+/// Return the attribute name used to attach a LayoutAttr to an OpOperand.
+std::string getLayoutName(const OpOperand &operand);
+
+/// Return the attribute name used to attach a LayoutAttr to an OpResult.
+std::string getLayoutName(const OpResult result);
+
+/// Retrieves the LayoutAttr associated with a given Value. For TensorDescType
+/// values, it is taken from the type; for LoadNd results, from the tensor
+/// descriptor; for loop iter_args, from the tied init; otherwise from the
+/// defining op's attributes. Returns nullptr if no LayoutAttr is found.
+LayoutAttr getLayoutAttr(const Value value);
+
+/// Retrieves the LayoutAttr associated with a given OpOperand. It will
+/// first check the layout_operand_{id} attribute of the owner operation. If
+/// not found, it will check the operand value itself and its defining op.
+LayoutAttr getLayoutAttr(const OpOperand &opr);
+
+/// Sets the LayoutAttr for a given OpOperand or OpResult by attaching
+/// it to the owner's dictionary attributes.
+template <typename T,
+          typename = std::enable_if_t<std::is_same_v<T, OpOperand> ||
+                                      std::is_same_v<T, OpResult>>>
+void setLayoutAttr(const T &operandOrResult, const LayoutAttr layout);
+
+/// Set the LayoutAttr for each OpOperand and OpResult of the given operation.
+/// If the operation contains regions, it is also applied recursively to the
+/// contained operations.
+void setLayoutAttrs(Operation *op,
+                    function_ref<LayoutAttr(Value)> getLayoutImpl);
+
+/// Extract a set of small vectors of the given shape from a value using
+/// vector.extract_strided_slice.
+SmallVector<Value> extractVectorsWithShapeFromValue(OpBuilder &builder,
+                                                    Location loc, Value value,
+                                                    ArrayRef<int64_t> shape);
+
+/// Create a vector of the given shape from a set of values using
+/// vector.insert_strided_slice.
+Value createVectorWithShapeFromValues(OpBuilder &builder, Location loc,
+                                      ValueRange values,
+                                      ArrayRef<int64_t> shape);
+
+/// Do type conversion for SCF structural ops, e.g., scf.for, using the SCF
+/// structural type conversion patterns. Since VectorType cannot carry the
+/// layout attribute, which is needed to guide the type conversion for XeGPU,
+/// they are first converted into RankedTensorType, where the layout attribute
+/// can be attached. Then the upstream SCF structural type conversion patterns
+/// are applied with the provided converter.
+/// TODO: This is a temporary solution. We should refactor it when context-aware
+/// type conversion is available.
+void doSCFStructuralTypeConversionWithTensorType(Operation *op,
+                                                 TypeConverter converter);
+
 } // namespace xegpu
 
 } // namespace mlir

diff  --git a/mlir/lib/Dialect/XeGPU/Transforms/CMakeLists.txt b/mlir/lib/Dialect/XeGPU/Transforms/CMakeLists.txt
index 7d9b5584b0b2b..af0d7f6bd9070 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/XeGPU/Transforms/CMakeLists.txt
@@ -1,4 +1,5 @@
 add_mlir_dialect_library(MLIRXeGPUTransforms
+  XeGPUBlocking.cpp
   XeGPUFoldAliasOps.cpp
   XeGPUSubgroupDistribute.cpp
   XeGPUUnroll.cpp

diff  --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUBlocking.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUBlocking.cpp
new file mode 100644
index 0000000000000..6e736cb7e6972
--- /dev/null
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUBlocking.cpp
@@ -0,0 +1,337 @@
+//===---- XeGPUBlocking.cpp ---- XeGPU Blocking Pass ----------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/XeGPU/Transforms/Passes.h"
+
+#include "mlir/Dialect/GPU/IR/GPUDialect.h"
+#include "mlir/Dialect/Vector/Transforms/VectorTransforms.h"
+#include "mlir/Dialect/XeGPU/IR/XeGPU.h"
+#include "mlir/Dialect/XeGPU/Transforms/Transforms.h"
+#include "mlir/Dialect/XeGPU/Utils/XeGPUUtils.h"
+#include "mlir/Interfaces/LoopLikeInterface.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Pass/PassManager.h"
+#include "mlir/Transforms/DialectConversion.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+#include "llvm/ADT/STLExtras.h"
+
+namespace mlir {
+namespace xegpu {
+#define GEN_PASS_DEF_XEGPUBLOCKING
+#include "mlir/Dialect/XeGPU/Transforms/Passes.h.inc"
+} // namespace xegpu
+} // namespace mlir
+
+#define DEBUG_TYPE "xegpu-blocking"
+#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")
+#define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n")
+
+using namespace mlir;
+
+namespace {
+
+// Resolve the unrealized conversion cast ops generated by the SCF structural
+// type conversion. They come in two forms: N:1 vector casts, rebuilt with
+// vector::insert_strided_slice ops, and 1:N vector casts, rebuilt with
+// vector::extract_strided_slice ops. A resolved cast is erased, so the caller
+// must not access it afterwards. Other casts are left untouched.
+static void
+resolveUnrealizedConversionCastOp(UnrealizedConversionCastOp castOp) {
+  ValueRange inputs = castOp.getInputs();
+  ValueRange outputs = castOp.getOutputs();
+
+  auto hasIdenticalVectorTypes = [](ValueRange values) {
+    auto types = values.getTypes();
+    return llvm::all_of(types, [&](Type type) {
+      return isa<VectorType>(type) && type == types.front();
+    });
+  };
+
+  // We are only interested in casts whose inputs all share one VectorType and
+  // whose outputs all share one VectorType.
+  if (!hasIdenticalVectorTypes(inputs) || !hasIdenticalVectorTypes(outputs)) {
+    LDBG("skip unrealized conversion cast op not emulating pack/unpack.");
+    return;
+  }
+
+  VectorType outputTy = dyn_cast<VectorType>(outputs[0].getType());
+  OpBuilder builder(castOp);
+  if (inputs.size() > 1 && outputs.size() == 1) {
+    // N:1 cast: the castOp is emulating an unpack op.
+    ArrayRef<int64_t> shape = outputTy.getShape();
+    Value result = xegpu::createVectorWithShapeFromValues(
+        builder, castOp.getLoc(), inputs, shape);
+    castOp->replaceAllUsesWith(ValueRange(result));
+    castOp->erase();
+  } else if (castOp.getNumResults() > 1 && castOp.getNumOperands() == 1) {
+    // 1:N cast: the castOp is emulating a pack op.
+    ArrayRef<int64_t> tileShape = outputTy.getShape();
+    SmallVector<Value> results = xegpu::extractVectorsWithShapeFromValue(
+        builder, castOp.getLoc(), inputs[0], tileShape);
+    castOp->replaceAllUsesWith(results);
+    castOp->erase();
+  }
+}
+
+//===------------------------------------------------------------------------===//
+// The XeGPUBlockingPass leverages the unroll patterns for XeGPU and Vector ops
+// to partition operations that process large shapes into multiple operations on
+// smaller shapes, as specified by the inst_data in the layout attribute. This
+// enables each resulting operation to be efficiently mapped to a hardware
+// instruction.
+//===------------------------------------------------------------------------===//
+
+class XeGPUBlockingPass final
+    : public xegpu::impl::XeGPUBlockingBase<XeGPUBlockingPass> {
+public:
+  void runOnOperation() override;
+
+private:
+  // Get the tile shape for a given OpOperand or OpResult from its layout: the
+  // inst_data if present, else the value's own shape. Returns std::nullopt if
+  // there is no subgroup-level layout or no shape can be derived.
+  template <typename T,
+            typename = std::enable_if_t<std::is_same_v<T, OpOperand> ||
+                                        std::is_same_v<T, OpResult>>>
+  std::optional<SmallVector<int64_t>>
+  getTileShape(const T &operandOrResult) const;
+
+  // Get the tile shape for a given operation ({M, K, N} for DpasOp).
+  std::optional<SmallVector<int64_t>> getTileShape(Operation *op) const;
+
+  // Determine if the operation requires unrolling. Return false if any operand
+  // or result has a workgroup-level layout, or if none is unrollable (a tdesc
+  // with inst_data, or a shaped value whose tile shape differs from its type).
+  bool needsUnroll(Operation *op) const;
+};
+} // namespace
+
+template <typename T, typename>
+std::optional<SmallVector<int64_t>>
+XeGPUBlockingPass::getTileShape(const T &operandOrResult) const {
+  Value value;
+  if constexpr (std::is_same_v<T, OpOperand>)
+    value = operandOrResult.get();
+  else
+    value = static_cast<Value>(operandOrResult);
+
+  xegpu::LayoutAttr layout = xegpu::getLayoutAttr(operandOrResult);
+  if (layout && layout.isSgLayout()) {
+    if (auto instData = layout.getInstData())
+      return llvm::to_vector_of<int64_t>(instData.asArrayRef());
+
+    if (auto type = dyn_cast<ShapedType>(value.getType()))
+      return llvm::to_vector(type.getShape());
+  }
+  LDBG("failed to getTileShape for: " << value);
+  return std::nullopt;
+}
+
+std::optional<SmallVector<int64_t>>
+XeGPUBlockingPass::getTileShape(Operation *op) const {
+  if (isa<xegpu::CreateNdDescOp, xegpu::UpdateNdOffsetOp>(op))
+    return getTileShape(op->getOpResult(0));
+  if (isa<xegpu::PrefetchNdOp, xegpu::LoadNdOp>(op))
+    return getTileShape(op->getOpOperand(0));
+  if (isa<xegpu::StoreNdOp>(op))
+    return getTileShape(op->getOpOperand(1));
+
+  if (isa<xegpu::DpasOp>(op)) {
+    std::optional<SmallVector<int64_t>> aTile =
+        getTileShape(op->getOpOperand(0));
+    std::optional<SmallVector<int64_t>> bTile =
+        getTileShape(op->getOpOperand(1));
+
+    if (!aTile || aTile->size() != 2 || !bTile || bTile->size() != 2)
+      return std::nullopt;
+
+    // A is MxK and B is KxN, so their K dimensions must match.
+    if ((*aTile)[1] != (*bTile)[0])
+      return std::nullopt;
+
+    // If C is present (3rd operand), its tile must be MxN.
+    if (op->getNumOperands() == 3) {
+      std::optional<SmallVector<int64_t>> cTile =
+          getTileShape(op->getOpOperand(2));
+      int64_t expectedCTile[2] = {(*aTile)[0], (*bTile)[1]};
+      if (!cTile || !llvm::equal(*cTile, expectedCTile))
+        return std::nullopt;
+    }
+
+    return SmallVector<int64_t>({(*aTile)[0], (*aTile)[1], (*bTile)[1]});
+  }
+
+  if (OpTrait::hasElementwiseMappableTraits(op) && op->getNumResults() == 1)
+    return getTileShape(op->getOpResult(0));
+
+  return std::nullopt;
+}
+
+bool XeGPUBlockingPass::needsUnroll(Operation *op) const {
+  // Skip ops where any operand or result has a workgroup-level layout.
+  bool hasWgLayoutOperands =
+      llvm::any_of(op->getOpOperands(), [](OpOperand &opr) {
+        xegpu::LayoutAttr layout = xegpu::getLayoutAttr(opr);
+        return layout && layout.isWgLayout();
+      });
+  bool hasWgLayoutResults =
+      llvm::any_of(op->getOpResults(), [](OpResult result) {
+        xegpu::LayoutAttr layout = xegpu::getLayoutAttr(result);
+        return layout && layout.isWgLayout();
+      });
+  if (hasWgLayoutOperands || hasWgLayoutResults) {
+    LDBG("skip unrolling for op with workgroup level layout: " << *op);
+    return false;
+  }
+
+  auto isUnrollable = [](Value value, ArrayRef<int64_t> tileShape) {
+    Type valTy = value.getType();
+    if (auto tdescTy = dyn_cast<xegpu::TensorDescType>(valTy)) {
+      xegpu::LayoutAttr layout = tdescTy.getLayoutAttr();
+      return layout && layout.getInstData();
+    }
+    auto shapedType = dyn_cast<ShapedType>(valTy);
+    return shapedType && !llvm::equal(tileShape, shapedType.getShape());
+  };
+
+  bool hasUnrollableOperands =
+      llvm::any_of(op->getOpOperands(), [&](OpOperand &opr) {
+        std::optional<SmallVector<int64_t>> tileShape = getTileShape(opr);
+        return tileShape.has_value() && isUnrollable(opr.get(), *tileShape);
+      });
+  bool hasUnrollableResults =
+      llvm::any_of(op->getOpResults(), [&](OpResult result) {
+        std::optional<SmallVector<int64_t>> tileShape = getTileShape(result);
+        return tileShape.has_value() && isUnrollable(result, *tileShape);
+      });
+  return hasUnrollableOperands || hasUnrollableResults;
+}
+
+void XeGPUBlockingPass::runOnOperation() {
+  MLIRContext *ctx = &getContext();
+  Operation *op = getOperation();
+
+  // Cache the LayoutAttr of each operand and result in the owner's attributes.
+  // This ensures that the LayoutAttr remains accessible even if the defining
+  // operation is replaced.
+  xegpu::setLayoutAttrs(op, [](Value v) { return xegpu::getLayoutAttr(v); });
+
+  auto getTileShapeAndCount = [](llvm::ArrayRef<int64_t> shape,
+                                 xegpu::LayoutAttr layout) {
+    int count = 1;
+    SmallVector<int64_t> tileShape(shape);
+    if (layout && layout.getInstData()) {
+      DenseI32ArrayAttr instData = layout.getInstData();
+      tileShape = llvm::to_vector_of<int64_t>(instData.asArrayRef());
+      count = computeProduct(shape) / computeProduct(tileShape);
+    }
+    return std::make_pair(tileShape, count);
+  };
+
+  // Perform type conversion for SCF control flow ops.
+  TypeConverter converter;
+  converter.addConversion([](Type type) -> Type { return type; });
+  converter.addConversion(
+      [&](RankedTensorType type,
+          SmallVectorImpl<Type> &result) -> std::optional<LogicalResult> {
+        Type elemTy = type.getElementType();
+        ArrayRef<int64_t> shape = type.getShape();
+
+        auto layout =
+            llvm::dyn_cast_if_present<xegpu::LayoutAttr>(type.getEncoding());
+        if (layout && layout.isWgLayout())
+          return failure();
+
+        int count;
+        SmallVector<int64_t> subShape;
+        std::tie(subShape, count) = getTileShapeAndCount(shape, layout);
+        auto newTy = VectorType::get(subShape, elemTy);
+        result.append(count, newTy);
+        return success();
+      });
+  converter.addConversion(
+      [&](xegpu::TensorDescType type,
+          SmallVectorImpl<Type> &result) -> std::optional<LogicalResult> {
+        Type elemTy = type.getElementType();
+        ArrayRef<int64_t> shape = type.getShape();
+
+        xegpu::LayoutAttr layout = type.getLayoutAttr();
+        if (layout && layout.isWgLayout())
+          return failure();
+
+        int count;
+        SmallVector<int64_t> subShape;
+        std::tie(subShape, count) = getTileShapeAndCount(shape, layout);
+
+        if (layout)
+          layout = layout.dropInstData();
+
+        auto newTy = xegpu::TensorDescType::get(
+            type.getContext(), subShape, elemTy, type.getEncoding(), layout);
+        result.append(count, newTy);
+        return success();
+      });
+
+  xegpu::doSCFStructuralTypeConversionWithTensorType(op, converter);
+
+  xegpu::UnrollOptions options;
+  options.setFilterConstraint(
+      [&](Operation *op) -> LogicalResult { return success(needsUnroll(op)); });
+
+  options.setNativeShapeFn([&](Operation *op) { return getTileShape(op); });
+
+  options.setUnrolledTypesFn([&](ShapedType type, ArrayRef<int64_t> tileShape) {
+    Type elemTy = type.getElementType();
+    Type newTy;
+
+    if (auto tdescTy = dyn_cast<xegpu::TensorDescType>(type))
+      newTy = xegpu::TensorDescType::get(
+          ctx, tileShape, elemTy, tdescTy.getEncoding(),
+          tdescTy.getLayoutAttr().dropInstData());
+    else
+      newTy = type.clone(tileShape, elemTy);
+
+    std::optional<SmallVector<int64_t>> ratio =
+        computeShapeRatio(type.getShape(), tileShape);
+    assert(ratio && "The shape of the type must be a multiple of tileShape.");
+    return SmallVector<Type>(computeProduct(*ratio), newTy);
+  });
+
+  RewritePatternSet patterns(ctx);
+
+  vector::UnrollVectorOptions vectorOptions;
+  vectorOptions.setNativeShapeFn(options.nativeShape);
+
+  populateXeGPUUnrollPatterns(patterns, options);
+  vector::populateVectorUnrollPatterns(patterns, vectorOptions);
+
+  (void)applyPatternsGreedily(op, std::move(patterns));
+
+  op->walk([](Operation *op) {
+    // Remove the layout attributes cached per operands.
+    for (OpOperand &opr : op->getOpOperands()) {
+      std::string name = xegpu::getLayoutName(opr);
+      if (op->hasAttrOfType<xegpu::LayoutAttr>(name))
+        op->removeAttr(name);
+    }
+
+    // Update the layout attributes per result.
+    for (OpResult result : op->getOpResults()) {
+      std::string name = xegpu::getLayoutName(result);
+      if (auto layout = op->getAttrOfType<xegpu::LayoutAttr>(name)) {
+        op->removeAttr(name);
+        if (!isa<LoopLikeOpInterface>(op))
+          xegpu::setLayoutAttr(result, layout.dropInstData());
+      }
+    }
+
+    // Resolve cast ops emulating pack/unpack last, since this may erase `op`.
+    if (auto castOp = dyn_cast<UnrealizedConversionCastOp>(op))
+      resolveUnrealizedConversionCastOp(castOp);
+  });
+}

diff  --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
index 992700524146a..c84906cc45568 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
@@ -62,8 +62,6 @@ constexpr unsigned packedSizeInBitsForDefault =
     16; // Minimum packing size per register for DPAS A.
 constexpr unsigned packedSizeInBitsForDpasB =
     32; // Minimum packing size per register for DPAS B.
-static const char *const operandLayoutNamePrefix = "layout_operand_";
-static const char *const resultLayoutNamePrefix = "layout_result_";
 
 namespace {
 
@@ -729,10 +727,7 @@ class LayoutAttrAssignment {
 void LayoutAttrAssignment::assignToUsers(Value v, xegpu::LayoutAttr layout) {
   for (OpOperand &user : v.getUses()) {
     Operation *owner = user.getOwner();
-    unsigned operandNumber = user.getOperandNumber();
-    // Use a generic name for ease of querying the layout attribute later.
-    std::string attrName =
-        operandLayoutNamePrefix + std::to_string(operandNumber);
+    std::string attrName = xegpu::getLayoutName(user);
     owner->setAttr(attrName, layout);
   }
 }
@@ -806,10 +801,10 @@ LogicalResult LayoutAttrAssignment::assign(Operation *op) {
     return success();
   }
   // Otherwise simply attach the layout to the op itself.
-  for (auto [i, r] : llvm::enumerate(op->getResults())) {
+  for (auto r : op->getOpResults()) {
     xegpu::LayoutAttr layoutInfo = getLayoutAttrForValue(r);
     if (layoutInfo) {
-      std::string attrName = resultLayoutNamePrefix + std::to_string(i);
+      std::string attrName = xegpu::getLayoutName(r);
       op->setAttr(attrName, layoutInfo);
       // Attach the layout attribute to the users of the result.
       assignToUsers(r, layoutInfo);
@@ -929,11 +924,8 @@ static SmallVector<NamedAttribute>
 removeTemporaryLayoutAttributes(ArrayRef<NamedAttribute> attrs) {
   SmallVector<NamedAttribute> newAttrs;
   for (NamedAttribute attr : attrs) {
-    if (attr.getName().strref().contains(operandLayoutNamePrefix) ||
-        attr.getName().strref().contains(resultLayoutNamePrefix)) {
-      continue;
-    }
-    newAttrs.push_back(attr);
+    if (!isa<xegpu::LayoutAttr>(attr.getValue()))
+      newAttrs.push_back(attr);
   }
   return newAttrs;
 }
@@ -1336,11 +1328,10 @@ struct DpasDistribution final : public gpu::WarpDistributionPattern {
 
     auto dpasOp = operand->get().getDefiningOp<xegpu::DpasOp>();
     unsigned operandIdx = operand->getOperandNumber();
-    std::string layoutAName =
-        llvm::formatv("{0}{1}", operandLayoutNamePrefix, 0).str();
-    std::string layoutBName =
-        llvm::formatv("{0}{1}", operandLayoutNamePrefix, 1).str();
-    auto layoutCName = llvm::formatv("{0}{1}", resultLayoutNamePrefix, 0).str();
+    std::string layoutAName = xegpu::getLayoutName(dpasOp->getOpOperand(0));
+    std::string layoutBName = xegpu::getLayoutName(dpasOp->getOpOperand(1));
+    std::string layoutCName = xegpu::getLayoutName(dpasOp->getOpResult(0));
+
     xegpu::LayoutAttr layoutA =
         dpasOp->getAttrOfType<xegpu::LayoutAttr>(layoutAName);
     xegpu::LayoutAttr layoutB =

diff  --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp
index 44d45dd2eaec0..885477fe4cbd5 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp
@@ -17,6 +17,7 @@
 #include "mlir/Dialect/Utils/IndexingUtils.h"
 #include "mlir/Dialect/XeGPU/IR/XeGPU.h"
 #include "mlir/Dialect/XeGPU/Transforms/Transforms.h"
+#include "mlir/Dialect/XeGPU/Utils/XeGPUUtils.h"
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
 #include "llvm/ADT/STLExtras.h"
 #include "llvm/Support/Debug.h"
@@ -74,17 +75,7 @@ struct UnrollPattern : public OpRewritePattern<SourceOp> {
       assert(vecTy.getRank() == static_cast<int64_t>(blockSize.size()) &&
              "Expecting blockSize size to match the rank of destTy.");
       auto shape = vecTy.getShape();
-      auto zeroAttr = rewriter.getZeroAttr(vecTy.getElementType());
-
-      Value result = rewriter.create<arith::ConstantOp>(
-          loc, vecTy, DenseElementsAttr::get(vecTy, zeroAttr));
-      for (auto [src, offsets] :
-           llvm::zip_equal(srcs, StaticTileOffsetRange(shape, blockSize))) {
-        SmallVector<int64_t> staticStrides(offsets.size(), 1);
-        result = rewriter.create<vector::InsertStridedSliceOp>(
-            loc, src, result, offsets, staticStrides);
-      }
-      return result;
+      return xegpu::createVectorWithShapeFromValues(rewriter, loc, srcs, shape);
     }
 
     if (isa<xegpu::TensorDescType>(destTy)) {
@@ -109,16 +100,8 @@ struct UnrollPattern : public OpRewritePattern<SourceOp> {
     if (auto vecTy = dyn_cast<VectorType>(src.getType())) {
       assert(vecTy.getRank() == static_cast<int64_t>(blockSize.size()) &&
              "Expecting blockSize size to match the rank of src.");
-      auto shape = vecTy.getShape();
-      SmallVector<Value> results;
-      for (SmallVector<int64_t> offsets :
-           StaticTileOffsetRange(shape, blockSize)) {
-        SmallVector<int64_t> staticStrides(offsets.size(), 1);
-        auto slice = rewriter.create<vector::ExtractStridedSliceOp>(
-            loc, src, offsets, blockSize, staticStrides);
-        results.push_back(slice);
-      }
-      return results;
+      return xegpu::extractVectorsWithShapeFromValue(rewriter, loc, src,
+                                                     blockSize);
     }
 
     if (isa<xegpu::TensorDescType>(src.getType())) {
@@ -153,7 +136,7 @@ struct UnrollCreateNdOp : public UnrollPattern<xegpu::CreateNdDescOp> {
     ArrayRef<int64_t> shape = tdescTy.getShape();
 
     std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op);
-    if (!targetShape || llvm::equal(*targetShape, shape))
+    if (!targetShape)
       return failure();
 
     auto newTdescTy = getUnrolledTypes(tdescTy, *targetShape)[0];
@@ -204,10 +187,9 @@ struct UnrollUpdateNdOffsetOp : public UnrollPattern<xegpu::UpdateNdOffsetOp> {
                                 PatternRewriter &rewriter) const override {
     Location loc = op.getLoc();
     xegpu::TensorDescType tdescTy = op.getTensorDescType();
-    ArrayRef<int64_t> shape = tdescTy.getShape();
 
     std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op);
-    if (!targetShape || llvm::equal(*targetShape, shape))
+    if (!targetShape)
       return failure();
 
     SmallVector<Type> convertedTdescTypes =
@@ -233,10 +215,9 @@ struct UnrollPrefetchNdOp : public UnrollPattern<xegpu::PrefetchNdOp> {
                                 PatternRewriter &rewriter) const override {
     Location loc = op.getLoc();
     xegpu::TensorDescType tdescTy = op.getTensorDescType();
-    ArrayRef<int64_t> shape = tdescTy.getShape();
 
     std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op);
-    if (!targetShape || llvm::equal(*targetShape, shape))
+    if (!targetShape)
       return failure();
 
     SmallVector<Type> convertedTdescTypes =
@@ -260,10 +241,9 @@ struct UnrollLoadNdOp : public UnrollPattern<xegpu::LoadNdOp> {
     Location loc = op.getLoc();
     VectorType valueTy = op.getType();
     xegpu::TensorDescType tdescTy = op.getTensorDescType();
-    ArrayRef<int64_t> shape = tdescTy.getShape();
 
     std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op);
-    if (!targetShape || llvm::equal(*targetShape, shape))
+    if (!targetShape)
       return failure();
 
     Type elemTy = tdescTy.getElementType();
@@ -295,10 +275,9 @@ struct UnrollStoreNdOp : public UnrollPattern<xegpu::StoreNdOp> {
     Location loc = op.getLoc();
     VectorType valueTy = op.getValueType();
     xegpu::TensorDescType tdescTy = op.getTensorDescType();
-    ArrayRef<int64_t> shape = tdescTy.getShape();
 
     std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op);
-    if (!targetShape || llvm::equal(*targetShape, shape))
+    if (!targetShape)
       return failure();
 
     SmallVector<Type> convertedValTypes =

diff  --git a/mlir/lib/Dialect/XeGPU/Utils/CMakeLists.txt b/mlir/lib/Dialect/XeGPU/Utils/CMakeLists.txt
index afd8e2d5c4df3..98e84a4420722 100644
--- a/mlir/lib/Dialect/XeGPU/Utils/CMakeLists.txt
+++ b/mlir/lib/Dialect/XeGPU/Utils/CMakeLists.txt
@@ -6,5 +6,6 @@ add_mlir_dialect_library(MLIRXeGPUUtils
 
   LINK_LIBS PUBLIC
   MLIRIR
+  MLIRSCFTransforms
   MLIRXeGPUDialect
   )

diff  --git a/mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp b/mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp
index 6b45ed0ae4ced..974aac94f9699 100644
--- a/mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp
+++ b/mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp
@@ -11,12 +11,29 @@
 //===----------------------------------------------------------------------===//
 
 #include "mlir/Dialect/XeGPU/Utils/XeGPUUtils.h"
+#include "mlir/Dialect/SCF/Transforms/Patterns.h"
+#include "mlir/Dialect/Utils/IndexingUtils.h"
 #include "mlir/Dialect/XeGPU/IR/XeGPU.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/Operation.h"
+#include "mlir/IR/ValueRange.h"
+#include "mlir/Interfaces/LoopLikeInterface.h"
+#include "mlir/Transforms/DialectConversion.h"
+#include "llvm/Support/Debug.h"
+#include "llvm/Support/FormatVariadic.h"
 #include <cstdint>
 #include <numeric>
 
 using namespace mlir;
 
+/// Concatenate the values of all given ranges, in order, into one vector.
+static SmallVector<Value> flattenValues(ArrayRef<ValueRange> values) {
+  SmallVector<Value> flat;
+  for (ValueRange range : values)
+    flat.append(range.begin(), range.end());
+  return flat;
+}
+
 FailureOr<VectorType>
 mlir::xegpu::getDistributedVectorType(xegpu::TensorDescType tdescTy) {
   auto layout = llvm::dyn_cast_if_present<LayoutAttr>(tdescTy.getLayout());
@@ -83,3 +100,268 @@ mlir::xegpu::getDistributedVectorType(VectorType originalType,
       /*memory_space=*/xegpu::MemorySpace::Global, layout);
   return xegpu::getDistributedVectorType(helperTdescTy);
 }
+
+/// Name of the attribute holding the layout of `operand` on its owner op,
+/// i.e. "layout_operand_" followed by the operand index.
+std::string xegpu::getLayoutName(const OpOperand &operand) {
+  // getOperandNumber() is presumably not const-qualified, so the const is
+  // cast away purely to query the index; nothing is mutated.
+  const unsigned operandNo =
+      const_cast<OpOperand &>(operand).getOperandNumber();
+  const StringRef prefix = "layout_operand_";
+  return llvm::formatv("{0}{1}", prefix, operandNo).str();
+}
+
+/// Name of the attribute holding the layout of `result` on its defining op,
+/// i.e. "layout_result_" followed by the result index.
+std::string xegpu::getLayoutName(const OpResult result) {
+  const unsigned resultNo = result.getResultNumber();
+  const StringRef prefix("layout_result_");
+  return llvm::formatv("{0}{1}", prefix, resultNo).str();
+}
+
+/// Look up the XeGPU layout associated with `value`.
+///
+/// Resolution order:
+///   - a TensorDescType value: the layout stored in the type;
+///   - an op result: for LoadNdOp, the layout of its tensor descriptor;
+///     otherwise the `layout_result_<N>` attribute on the defining op;
+///   - a loop region iter_arg: the layout of the tied loop init value.
+/// Returns nullptr for a null value or when no layout can be found.
+xegpu::LayoutAttr xegpu::getLayoutAttr(const Value value) {
+  if (!value)
+    return nullptr;
+
+  if (auto tdescTy =
+          dyn_cast_if_present<xegpu::TensorDescType>(value.getType()))
+    return tdescTy.getLayoutAttr();
+
+  if (auto result = dyn_cast<OpResult>(value)) {
+    Operation *defOp = result.getDefiningOp();
+    assert(defOp && "result must have a defining op");
+
+    // for LoadNdOp, the layout is stored in the tensor descriptor
+    if (auto loadNd = dyn_cast<xegpu::LoadNdOp>(defOp))
+      return getLayoutAttr(loadNd.getTensorDesc());
+
+    std::string layoutName = getLayoutName(result);
+    if (defOp->hasAttr(layoutName))
+      return defOp->getAttrOfType<xegpu::LayoutAttr>(layoutName);
+  }
+
+  if (auto arg = dyn_cast<BlockArgument>(value)) {
+    // The parent op is null for a block in a detached region, so use the
+    // null-tolerant dyn_cast_if_present.
+    Operation *parentOp = arg.getOwner()->getParentOp();
+    if (auto loop = dyn_cast_if_present<LoopLikeOpInterface>(parentOp)) {
+      // getTiedLoopInit returns nullptr for block arguments that are not
+      // region iter_args (e.g. the scf.for induction variable); such
+      // arguments carry no layout.
+      if (OpOperand *tiedInit = loop.getTiedLoopInit(arg))
+        return getLayoutAttr(tiedInit->get());
+    }
+  }
+
+  return nullptr;
+}
+
+/// Layout for a specific operand use. An explicit `layout_operand_<N>`
+/// attribute on the owning op wins; otherwise fall back to the layout
+/// associated with the operand's value.
+xegpu::LayoutAttr xegpu::getLayoutAttr(const OpOperand &opr) {
+  Operation *owner = opr.getOwner();
+  const std::string attrName = xegpu::getLayoutName(opr);
+  if (!owner->hasAttr(attrName))
+    return getLayoutAttr(opr.get());
+  return owner->getAttrOfType<xegpu::LayoutAttr>(attrName);
+}
+
+/// Record `layout` on the owner of `operandOrResult` under the matching
+/// layout_operand_<N> / layout_result_<N> name. A null layout is ignored,
+/// and an existing LayoutAttr under that name is never overwritten.
+template <typename T, typename>
+void xegpu::setLayoutAttr(const T &operandOrResult, const LayoutAttr layout) {
+  if (!layout)
+    return;
+  Operation *owner = operandOrResult.getOwner();
+  const std::string attrName = xegpu::getLayoutName(operandOrResult);
+  if (owner->hasAttrOfType<LayoutAttr>(attrName))
+    return;
+  owner->setAttr(attrName, layout);
+}
+
+/// Walk `op` and every op nested in it. For each operand and each result,
+/// query `getLayoutImpl` and attach the returned layout (if any) as a
+/// layout_operand_<N> / layout_result_<N> attribute on the owning op.
+void xegpu::setLayoutAttrs(Operation *op,
+                           function_ref<LayoutAttr(Value)> getLayoutImpl) {
+  op->walk([&](Operation *nested) {
+    for (OpOperand &operand : nested->getOpOperands())
+      setLayoutAttr(operand, getLayoutImpl(operand.get()));
+    for (OpResult res : nested->getOpResults())
+      setLayoutAttr(res, getLayoutImpl(res));
+  });
+}
+
+/// Split `value` into unit-stride `shape`-sized slices, one per tile, in the
+/// order produced by StaticTileOffsetRange. Non-vector values, and vectors
+/// whose shape is not a multiple of `shape` (per computeShapeRatio), are
+/// returned unchanged as a single-element list.
+SmallVector<Value>
+xegpu::extractVectorsWithShapeFromValue(OpBuilder &builder, Location loc,
+                                        Value value, ArrayRef<int64_t> shape) {
+  auto vecTy = dyn_cast<VectorType>(value.getType());
+  if (!vecTy || !computeShapeRatio(vecTy.getShape(), shape))
+    return {value};
+
+  SmallVector<Value> slices;
+  for (SmallVector<int64_t> tileOffsets :
+       StaticTileOffsetRange(vecTy.getShape(), shape)) {
+    SmallVector<int64_t> unitStrides(tileOffsets.size(), 1);
+    slices.push_back(builder.create<vector::ExtractStridedSliceOp>(
+        loc, value, tileOffsets, shape, unitStrides));
+  }
+  return slices;
+}
+
+/// Assemble a vector of the given `shape` from equally-typed tiles in
+/// `values`. Tiles are inserted into a zero-initialized vector at the offsets
+/// produced by StaticTileOffsetRange(shape, tileShape), the same ordering that
+/// extractVectorsWithShapeFromValue uses to split a vector.
+Value xegpu::createVectorWithShapeFromValues(OpBuilder &builder, Location loc,
+                                             ValueRange values,
+                                             ArrayRef<int64_t> shape) {
+  assert(!values.empty() && "expected at least one value");
+  // cast<> rather than dyn_cast<>: a non-vector input is a caller error and
+  // should be diagnosed here instead of dereferencing a null type below.
+  auto inputTy = cast<VectorType>(values[0].getType());
+  assert(llvm::all_of(values.getTypes(),
+                      [&](Type type) { return type == inputTy; }) &&
+         "values must be of the same VectorType");
+
+  Type elemTy = inputTy.getElementType();
+  ArrayRef<int64_t> tileShape = inputTy.getShape();
+
+  VectorType resultTy = VectorType::get(shape, elemTy);
+  auto zeroAttr = builder.getZeroAttr(elemTy);
+  Value result = builder.create<arith::ConstantOp>(
+      loc, resultTy, DenseElementsAttr::get(resultTy, zeroAttr));
+
+  // zip_equal asserts (in assertion builds) that there is exactly one value
+  // per tile.
+  for (auto [src, offsets] :
+       llvm::zip_equal(values, StaticTileOffsetRange(shape, tileShape))) {
+    SmallVector<int64_t> staticStrides(offsets.size(), 1);
+    result = builder.create<vector::InsertStridedSliceOp>(
+        loc, src, result, offsets, staticStrides);
+  }
+  return result;
+}
+
+/// Apply `converter` to the SCF structural ops nested in `op`, carrying XeGPU
+/// layouts across region boundaries. Runs in three phases:
+///   1. VectorType values flowing through SCF structural ops are converted to
+///      RankedTensorType, with UnrealizedConversionCastOps at the boundaries.
+///   2. For each vector->tensor cast whose input has a LayoutAttr, the layout
+///      becomes the tensor type's encoding. The new type is pushed into the
+///      tied loop region arguments and, via scf.yield, into parent results.
+///   3. `converter` is applied to the now tensor-typed SCF ops. Vector->tensor
+///      casts forward their converted inputs, and tensor->vector casts are
+///      rebuilt over the converted inputs.
+/// Failures reported by applyPartialConversion in phases 1 and 3 are ignored
+/// (results cast to void).
+void xegpu::doSCFStructuralTypeConversionWithTensorType(
+    Operation *op, TypeConverter converter) {
+  MLIRContext *context = op->getContext();
+
+  auto materializeCast = [](OpBuilder &builder, Type type, ValueRange inputs,
+                            Location loc) -> Value {
+    return builder.create<UnrealizedConversionCastOp>(loc, type, inputs)
+        .getResult(0);
+  };
+
+  { // Phase 1: convert VectorType to RankedTensorType for SCF structural ops.
+    // Named distinctly so it does not shadow the caller's `converter`.
+    TypeConverter vectorToTensor;
+    vectorToTensor.addConversion([](Type type) -> Type { return type; });
+    vectorToTensor.addConversion([](VectorType type) -> Type {
+      return RankedTensorType::get(type.getShape(), type.getElementType());
+    });
+    vectorToTensor.addSourceMaterialization(materializeCast);
+    vectorToTensor.addTargetMaterialization(materializeCast);
+
+    mlir::ConversionTarget target(*context);
+    target.addLegalOp<UnrealizedConversionCastOp>();
+
+    mlir::RewritePatternSet patterns(context);
+    scf::populateSCFStructuralTypeConversionsAndLegality(vectorToTensor,
+                                                         patterns, target);
+    (void)mlir::applyPartialConversion(op, target, std::move(patterns));
+  }
+
+  { // Phase 2: propagate the layout attribute into the RankedTensorType
+    // encoding by inspecting vector->tensor UnrealizedConversionCastOps.
+    op->walk([](UnrealizedConversionCastOp castOp) {
+      if (castOp.getNumOperands() != 1 || castOp.getNumResults() != 1)
+        return WalkResult::skip();
+
+      Value input = castOp.getInputs()[0];
+      Value result = castOp.getResults()[0];
+      auto inputTy = dyn_cast<VectorType>(input.getType());
+      auto resultTy = dyn_cast<RankedTensorType>(result.getType());
+
+      // Only look at ops casting from VectorType to RankedTensorType. The
+      // dyn_casts above yield null types for any other cast; test them
+      // directly, because isa<> on a null type asserts.
+      if (!inputTy || !resultTy)
+        return WalkResult::skip();
+
+      xegpu::LayoutAttr layout = xegpu::getLayoutAttr(input);
+      if (!layout)
+        return WalkResult::skip();
+
+      RankedTensorType newTy = resultTy.cloneWithEncoding(layout);
+      result.setType(newTy);
+
+      // update the arguments if user is a LoopLike op.
+      for (OpOperand &use : result.getUses()) {
+        if (auto loop = dyn_cast<LoopLikeOpInterface>(use.getOwner())) {
+          // Null when the use is not a loop init, i.e. it has no tied
+          // region iter_arg.
+          if (BlockArgument arg = loop.getTiedLoopRegionIterArg(&use))
+            arg.setType(newTy);
+        }
+        // whileOp has two regions, the BlockArgument of the after region
+        // is not exposed by LoopLikeOpInterface.
+        // NOTE(review): this assumes the after-region arguments line up with
+        // the init operands; the bounds check keeps a mismatch from reading
+        // past the end.
+        if (auto whileOp = dyn_cast<scf::WhileOp>(use.getOwner())) {
+          unsigned idx = use.getOperandNumber();
+          auto afterArgs = whileOp.getAfterArguments();
+          if (idx < afterArgs.size())
+            afterArgs[idx].setType(newTy);
+        }
+      }
+      return WalkResult::advance();
+    });
+
+    // using yieldOp as anchor to update the result type of its ParentOp
+    op->walk([](scf::YieldOp yieldOp) {
+      Operation *parentOp = yieldOp->getParentOp();
+      ValueRange yielded = yieldOp.getResults();
+      for (OpResult r : parentOp->getOpResults()) {
+        unsigned idx = r.getResultNumber();
+        // A yield need not match its parent's result count (e.g. the
+        // after-region yield of scf.while); never index past its operands.
+        if (idx >= yielded.size())
+          break;
+        Type resultTy = r.getType();
+        Type yieldTy = yielded[idx].getType();
+        if (isa<RankedTensorType>(resultTy) && yieldTy != resultTy)
+          r.setType(yieldTy);
+      }
+    });
+  }
+
+  { // Phase 3: perform the conversion from RankedTensorType to VectorType
+    // based on the LayoutAttr, using the caller-supplied converter.
+
+    // Handle the UnrealizedConversionCastOp introduced by the first step.
+    // For vector->RankedTensorType, it will simply forward the inputs.
+    // For RankedTensorType->vector, it will update the inputs with the
+    // one from the adaptor.
+    class UnrealizedConversionCastOpPattern
+        : public OpConversionPattern<mlir::UnrealizedConversionCastOp> {
+      using OpConversionPattern<
+          mlir::UnrealizedConversionCastOp>::OpConversionPattern;
+
+      mlir::LogicalResult
+      matchAndRewrite(mlir::UnrealizedConversionCastOp op,
+                      OneToNOpAdaptor adaptor,
+                      ConversionPatternRewriter &rewriter) const override {
+        auto inputs = op.getOperands();
+        auto outputs = op.getOutputs();
+
+        if (inputs.size() != 1 || outputs.size() != 1)
+          return failure();
+
+        auto inputTy = inputs[0].getType();
+        auto outputTy = outputs[0].getType();
+
+        if (isa<VectorType>(inputTy) && isa<RankedTensorType>(outputTy)) {
+          rewriter.replaceOpWithMultiple(op, adaptor.getInputs());
+          return success();
+        }
+
+        if (isa<RankedTensorType>(inputTy) && isa<VectorType>(outputTy)) {
+          SmallVector<Value> values = flattenValues(adaptor.getInputs());
+          auto newOp = rewriter.create<UnrealizedConversionCastOp>(
+              op.getLoc(), outputTy, values);
+          rewriter.replaceOp(op, newOp);
+          return success();
+        }
+        return failure();
+      }
+    };
+
+    converter.addSourceMaterialization(materializeCast);
+    // Stateless 1:N materialization; nothing needs to be captured.
+    converter.addTargetMaterialization([](OpBuilder &builder, TypeRange type,
+                                          ValueRange inputs, Location loc) {
+      return builder.create<UnrealizedConversionCastOp>(loc, type, inputs)
+          .getResults();
+    });
+
+    mlir::ConversionTarget target(*context);
+    // Casts are legal only once no tensor type remains on either side.
+    target.addDynamicallyLegalOp<UnrealizedConversionCastOp>(
+        [](UnrealizedConversionCastOp op) {
+          auto isTensorTy = [](Type type) {
+            return isa<RankedTensorType>(type);
+          };
+          return llvm::none_of(op->getOperandTypes(), isTensorTy) &&
+                 llvm::none_of(op->getResultTypes(), isTensorTy);
+        });
+    mlir::RewritePatternSet patterns(context);
+    patterns.insert<UnrealizedConversionCastOpPattern>(context);
+    scf::populateSCFStructuralTypeConversionsAndLegality(converter, patterns,
+                                                         target);
+    (void)mlir::applyPartialConversion(op, target, std::move(patterns));
+  }
+}

diff  --git a/mlir/test/Dialect/XeGPU/xegpu-blocking.mlir b/mlir/test/Dialect/XeGPU/xegpu-blocking.mlir
new file mode 100644
index 0000000000000..f9114988686c8
--- /dev/null
+++ b/mlir/test/Dialect/XeGPU/xegpu-blocking.mlir
@@ -0,0 +1,248 @@
+// RUN: mlir-opt --xegpu-blocking -split-input-file %s | FileCheck %s
+
+#a = #xegpu.layout<inst_data = [8, 16], lane_layout = [1, 16], lane_data = [8, 1]>
+#b = #xegpu.layout<inst_data = [16, 16], lane_layout = [1, 16], lane_data = [16, 1]>
+#c = #xegpu.layout<inst_data = [8, 16], lane_layout = [1, 16], lane_data = [8, 1]>
+// GEMM (A 16x32, B 32x32, C 16x32) whose tiles are carried through scf.for
+// iter_args and blocked 1:N. inst_data [8, 16] / [16, 16] splits each tile
+// into four blocks, giving 4 A loads, 4 B loads and 8 dpas per iteration.
+// The blocked types keep lane_layout / lane_data but drop inst_data.
+gpu.module @test_kernel {
+  gpu.func @test_gemm_with_one_to_n_lowering(%A: memref<1024x1024xf16>, %B: memref<1024x1024xf16>, %C: memref<1024x1024xf32>) {
+    %c0 = arith.constant 0 : index
+    %c16 = arith.constant 16 : index
+    %c32 = arith.constant 32 : index
+    %c1024 = arith.constant 1024 : index
+    %block_id_x = gpu.block_id x
+    %block_id_y = gpu.block_id y
+    %m = arith.muli %block_id_x, %c16 : index
+    %n = arith.muli %block_id_y, %c32 : index
+
+    %c_tdesc = xegpu.create_nd_tdesc %C[%m, %n] : memref<1024x1024xf32> -> !xegpu.tensor_desc<16x32xf32, #c>
+    %c_init = xegpu.load_nd %c_tdesc : !xegpu.tensor_desc<16x32xf32, #c> -> vector<16x32xf32>
+
+    %a_tdesc = xegpu.create_nd_tdesc %A[%m, %c0] : memref<1024x1024xf16> -> !xegpu.tensor_desc<16x32xf16, #a>
+    %b_tdesc = xegpu.create_nd_tdesc %B[%c0, %n] : memref<1024x1024xf16> -> !xegpu.tensor_desc<32x32xf16, #b>
+    %out:3 = scf.for %k = %c0 to %c1024 step %c32
+      iter_args(%arg0 = %a_tdesc, %arg1 = %b_tdesc, %arg2 = %c_init)
+      -> (!xegpu.tensor_desc<16x32xf16, #a>, !xegpu.tensor_desc<32x32xf16, #b>, vector<16x32xf32>) {
+      //CHECK-COUNT-4: xegpu.load_nd {{.*}} -> vector<8x16xf16>
+      %a = xegpu.load_nd %arg0 : !xegpu.tensor_desc<16x32xf16, #a> -> vector<16x32xf16>
+      //CHECK-COUNT-4: xegpu.load_nd {{.*}} -> vector<16x16xf16>
+      %b = xegpu.load_nd %arg1 : !xegpu.tensor_desc<32x32xf16, #b> -> vector<32x32xf16>
+      //CHECK-COUNT-8: xegpu.dpas {{.*}} {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [8, 1]>} : vector<8x16xf16>, vector<16x16xf16>, vector<8x16xf32> -> vector<8x16xf32>
+      %c = xegpu.dpas %a, %b, %arg2 {layout_result_0 = #c}: vector<16x32xf16>, vector<32x32xf16>, vector<16x32xf32> -> vector<16x32xf32>
+      //CHECK-COUNT-4: xegpu.update_nd_offset {{.*}} : !xegpu.tensor_desc<8x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [8, 1]>>
+      %a_next_tdesc = xegpu.update_nd_offset %arg0, [%c0, %c32] : !xegpu.tensor_desc<16x32xf16, #a>
+      //CHECK-COUNT-4: xegpu.update_nd_offset {{.*}} : !xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [16, 1]>>
+      %b_next_tdesc = xegpu.update_nd_offset %arg1, [%c32, %c0] : !xegpu.tensor_desc<32x32xf16, #b>
+      scf.yield %a_next_tdesc, %b_next_tdesc, %c
+        : !xegpu.tensor_desc<16x32xf16, #a>, !xegpu.tensor_desc<32x32xf16, #b>, vector<16x32xf32>
+    }
+    //CHECK-COUNT-4: xegpu.store_nd {{.*}} : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [8, 1]>>
+    xegpu.store_nd %out#2, %c_tdesc: vector<16x32xf32>, !xegpu.tensor_desc<16x32xf32, #c>
+    gpu.return
+  }
+}
+
+// -----
+#l1 = #xegpu.layout<inst_data = [8, 16]>
+#l2 = #xegpu.layout<inst_data = [16, 16]>
+// Same GEMM as the previous case, but the layouts carry only inst_data.
+// After blocking, the tensor descriptor types carry no layout at all.
+gpu.module @test_kernel {
+  gpu.func @test_gemm_with_inst_data_only_attribute(%A: memref<1024x1024xf16>, %B: memref<1024x1024xf16>, %C: memref<1024x1024xf32>) {
+    %c0 = arith.constant 0 : index
+    %c16 = arith.constant 16 : index
+    %c32 = arith.constant 32 : index
+    %c1024 = arith.constant 1024 : index
+    %block_id_x = gpu.block_id x
+    %block_id_y = gpu.block_id y
+    %m = arith.muli %block_id_x, %c16 : index
+    %n = arith.muli %block_id_y, %c32 : index
+
+    %c_tdesc = xegpu.create_nd_tdesc %C[%m, %n] : memref<1024x1024xf32> -> !xegpu.tensor_desc<16x32xf32, #l1>
+    %c_init = xegpu.load_nd %c_tdesc : !xegpu.tensor_desc<16x32xf32, #l1> -> vector<16x32xf32>
+
+    %a_tdesc = xegpu.create_nd_tdesc %A[%m, %c0] : memref<1024x1024xf16> -> !xegpu.tensor_desc<16x32xf16, #l1>
+    %b_tdesc = xegpu.create_nd_tdesc %B[%c0, %n] : memref<1024x1024xf16> -> !xegpu.tensor_desc<32x32xf16, #l2>
+    %out:3 = scf.for %k = %c0 to %c1024 step %c32
+      iter_args(%arg0 = %a_tdesc, %arg1 = %b_tdesc, %arg2 = %c_init)
+      -> (!xegpu.tensor_desc<16x32xf16, #l1>, !xegpu.tensor_desc<32x32xf16, #l2>, vector<16x32xf32>) {
+      //CHECK-COUNT-4: xegpu.load_nd {{.*}} -> vector<8x16xf16>
+      %a = xegpu.load_nd %arg0 : !xegpu.tensor_desc<16x32xf16, #l1> -> vector<16x32xf16>
+      //CHECK-COUNT-4: xegpu.load_nd {{.*}} -> vector<16x16xf16>
+      %b = xegpu.load_nd %arg1 : !xegpu.tensor_desc<32x32xf16, #l2> -> vector<32x32xf16>
+      //CHECK-COUNT-8: xegpu.dpas {{.*}} : vector<8x16xf16>, vector<16x16xf16>, vector<8x16xf32> -> vector<8x16xf32>
+      %c = xegpu.dpas %a, %b, %arg2 {layout_result_0 = #l1}: vector<16x32xf16>, vector<32x32xf16>, vector<16x32xf32> -> vector<16x32xf32>
+      //CHECK-COUNT-4: xegpu.update_nd_offset {{.*}} : !xegpu.tensor_desc<8x16xf16>
+      %a_next_tdesc = xegpu.update_nd_offset %arg0, [%c0, %c32] : !xegpu.tensor_desc<16x32xf16, #l1>
+      //CHECK-COUNT-4: xegpu.update_nd_offset {{.*}} : !xegpu.tensor_desc<16x16xf16>
+      %b_next_tdesc = xegpu.update_nd_offset %arg1, [%c32, %c0] : !xegpu.tensor_desc<32x32xf16, #l2>
+      scf.yield %a_next_tdesc, %b_next_tdesc, %c
+        : !xegpu.tensor_desc<16x32xf16, #l1>, !xegpu.tensor_desc<32x32xf16, #l2>, vector<16x32xf32>
+    }
+    //CHECK-COUNT-4: xegpu.store_nd {{.*}} : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32>
+    xegpu.store_nd %out#2, %c_tdesc: vector<16x32xf32>, !xegpu.tensor_desc<16x32xf32, #l1>
+    gpu.return
+  }
+}
+
+// -----
+#l1 = #xegpu.layout<inst_data = [8, 16]>
+#l2 = #xegpu.layout<inst_data = [16, 16]>
+// GEMM where A (8x16) already matches its inst_data and is expected to stay a
+// single, unsplit op. B (16x32) and C (8x32) are each split into two blocks.
+gpu.module @test_kernel {
+  gpu.func @test_gemm_with_one_to_one_lowering(%A: memref<1024x1024xf16>, %B: memref<1024x1024xf16>, %C: memref<1024x1024xf32>) {
+    %c0 = arith.constant 0 : index
+    %c8 = arith.constant 8 : index
+    %c16 = arith.constant 16 : index
+    %c32 = arith.constant 32 : index
+    %c1024 = arith.constant 1024 : index
+    %block_id_x = gpu.block_id x
+    %block_id_y = gpu.block_id y
+    %m = arith.muli %block_id_x, %c8 : index
+    %n = arith.muli %block_id_y, %c32 : index
+
+    %c_tdesc = xegpu.create_nd_tdesc %C[%m, %n] : memref<1024x1024xf32> -> !xegpu.tensor_desc<8x32xf32, #l1>
+
+    //CHECK-COUNT-2: xegpu.load_nd {{.*}} : !xegpu.tensor_desc<8x16xf32> -> vector<8x16xf32>
+    %c_init = xegpu.load_nd %c_tdesc : !xegpu.tensor_desc<8x32xf32, #l1> -> vector<8x32xf32>
+
+    %a_tdesc = xegpu.create_nd_tdesc %A[%m, %c0] : memref<1024x1024xf16> -> !xegpu.tensor_desc<8x16xf16, #l1>
+    %b_tdesc = xegpu.create_nd_tdesc %B[%c0, %n] : memref<1024x1024xf16> -> !xegpu.tensor_desc<16x32xf16, #l2>
+    %out:3 = scf.for %k = %c0 to %c1024 step %c16
+      iter_args(%arg0 = %a_tdesc, %arg1 = %b_tdesc, %arg2 = %c_init)
+      -> (!xegpu.tensor_desc<8x16xf16, #l1>, !xegpu.tensor_desc<16x32xf16, #l2>, vector<8x32xf32>) {
+      //CHECK: xegpu.load_nd {{.*}} : !xegpu.tensor_desc<8x16xf16> -> vector<8x16xf16>
+      %a = xegpu.load_nd %arg0 : !xegpu.tensor_desc<8x16xf16, #l1> -> vector<8x16xf16>
+      //CHECK-COUNT-2: xegpu.load_nd {{.*}} : !xegpu.tensor_desc<16x16xf16> -> vector<16x16xf16>
+      %b = xegpu.load_nd %arg1 : !xegpu.tensor_desc<16x32xf16, #l2> -> vector<16x32xf16>
+      %c = xegpu.dpas %a, %b, %arg2 {layout_result_0 = #l1}: vector<8x16xf16>, vector<16x32xf16>, vector<8x32xf32> -> vector<8x32xf32>
+      //CHECK: xegpu.update_nd_offset {{.*}} [%c0, %c32] : !xegpu.tensor_desc<8x16xf16>
+      %a_next_tdesc = xegpu.update_nd_offset %arg0, [%c0, %c32] : !xegpu.tensor_desc<8x16xf16, #l1>
+      //CHECK-COUNT-2: xegpu.update_nd_offset {{.*}} [%c32, %c0] : !xegpu.tensor_desc<16x16xf16>
+      %b_next_tdesc = xegpu.update_nd_offset %arg1, [%c32, %c0] : !xegpu.tensor_desc<16x32xf16, #l2>
+      scf.yield %a_next_tdesc, %b_next_tdesc, %c
+        : !xegpu.tensor_desc<8x16xf16, #l1>, !xegpu.tensor_desc<16x32xf16, #l2>, vector<8x32xf32>
+    }
+    //CHECK-COUNT-2: xegpu.store_nd {{.*}} : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32>
+    xegpu.store_nd %out#2, %c_tdesc: vector<8x32xf32>, !xegpu.tensor_desc<8x32xf32, #l1>
+    gpu.return
+  }
+}
+
+// -----
+#a = #xegpu.layout<inst_data = [8, 16], lane_layout = [1, 16], lane_data = [8, 1]>
+#b = #xegpu.layout<inst_data = [16, 16], lane_layout = [1, 16], lane_data = [16, 1]>
+#c = #xegpu.layout<inst_data = [8, 16], lane_layout = [1, 16], lane_data = [8, 1]>
+// Same as the first GEMM case, plus an elementwise math.exp on A before the
+// dpas. The exp is expected to be blocked into four 8x16 ops.
+gpu.module @test_kernel {
+  gpu.func @test_gemm_with_elemwise_preop(%A: memref<1024x1024xf16>, %B: memref<1024x1024xf16>, %C: memref<1024x1024xf32>) {
+    %c0 = arith.constant 0 : index
+    %c16 = arith.constant 16 : index
+    %c32 = arith.constant 32 : index
+    %c1024 = arith.constant 1024 : index
+    %block_id_x = gpu.block_id x
+    %block_id_y = gpu.block_id y
+    %m = arith.muli %block_id_x, %c16 : index
+    %n = arith.muli %block_id_y, %c32 : index
+
+    %c_tdesc = xegpu.create_nd_tdesc %C[%m, %n] : memref<1024x1024xf32> -> !xegpu.tensor_desc<16x32xf32, #c>
+    %c_init = xegpu.load_nd %c_tdesc : !xegpu.tensor_desc<16x32xf32, #c> -> vector<16x32xf32>
+
+    %a_tdesc = xegpu.create_nd_tdesc %A[%m, %c0] : memref<1024x1024xf16> -> !xegpu.tensor_desc<16x32xf16, #a>
+    %b_tdesc = xegpu.create_nd_tdesc %B[%c0, %n] : memref<1024x1024xf16> -> !xegpu.tensor_desc<32x32xf16, #b>
+    %out:3 = scf.for %k = %c0 to %c1024 step %c32
+      iter_args(%arg0 = %a_tdesc, %arg1 = %b_tdesc, %arg2 = %c_init)
+      -> (!xegpu.tensor_desc<16x32xf16, #a>, !xegpu.tensor_desc<32x32xf16, #b>, vector<16x32xf32>) {
+      //CHECK-COUNT-4: xegpu.load_nd {{.*}} -> vector<8x16xf16>
+      %a = xegpu.load_nd %arg0 : !xegpu.tensor_desc<16x32xf16, #a> -> vector<16x32xf16>
+      //CHECK-COUNT-4: xegpu.load_nd {{.*}} -> vector<16x16xf16>
+      %b = xegpu.load_nd %arg1 : !xegpu.tensor_desc<32x32xf16, #b> -> vector<32x32xf16>
+      //CHECK-COUNT-4: math.exp {{.*}} : vector<8x16xf16>
+      %e = math.exp %a {layout_result_0 = #a} : vector<16x32xf16>
+      //CHECK-COUNT-8: xegpu.dpas {{.*}} {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [8, 1]>} : vector<8x16xf16>, vector<16x16xf16>, vector<8x16xf32> -> vector<8x16xf32>
+      %c = xegpu.dpas %e, %b, %arg2 {layout_result_0 = #c}: vector<16x32xf16>, vector<32x32xf16>, vector<16x32xf32> -> vector<16x32xf32>
+      //CHECK-COUNT-4: xegpu.update_nd_offset {{.*}} : !xegpu.tensor_desc<8x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [8, 1]>>
+      %a_next_tdesc = xegpu.update_nd_offset %arg0, [%c0, %c32] : !xegpu.tensor_desc<16x32xf16, #a>
+      //CHECK-COUNT-4: xegpu.update_nd_offset {{.*}} : !xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [16, 1]>>
+      %b_next_tdesc = xegpu.update_nd_offset %arg1, [%c32, %c0] : !xegpu.tensor_desc<32x32xf16, #b>
+      scf.yield %a_next_tdesc, %b_next_tdesc, %c
+        : !xegpu.tensor_desc<16x32xf16, #a>, !xegpu.tensor_desc<32x32xf16, #b>, vector<16x32xf32>
+    }
+    //CHECK-COUNT-4: xegpu.store_nd {{.*}} : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [8, 1]>>
+    xegpu.store_nd %out#2, %c_tdesc: vector<16x32xf32>, !xegpu.tensor_desc<16x32xf32, #c>
+    gpu.return
+  }
+}
+
+// -----
+#l = #xegpu.layout<inst_data = [8, 16]>
+// Elementwise addf on 16x32 tiles carried through scf.for, with an
+// inst_data-only layout. Each tile is split into four 8x16 blocks: two
+// descriptors loaded (8 loads), 4 addf, 4 stores, and three descriptors
+// advanced (12 offset updates).
+gpu.module @test_kernel {
+  gpu.func @test_elementwise_with_inst_data_only(%A: memref<1024x1024xf16>, %B: memref<1024x1024xf16>, %C: memref<1024x1024xf16>) {
+    %c0 = arith.constant 0 : index
+    %c32 = arith.constant 32 : index
+    %c1024 = arith.constant 1024 : index
+    %block_id_x = gpu.block_id x
+    %block_id_y = gpu.block_id y
+    %m = arith.muli %block_id_x, %c32 : index
+
+    %a_tdesc = xegpu.create_nd_tdesc %A[%m, %c0] : memref<1024x1024xf16> -> !xegpu.tensor_desc<16x32xf16, #l>
+    %b_tdesc = xegpu.create_nd_tdesc %B[%m, %c0] : memref<1024x1024xf16> -> !xegpu.tensor_desc<16x32xf16, #l>
+    %c_tdesc = xegpu.create_nd_tdesc %C[%m, %c0] : memref<1024x1024xf16> -> !xegpu.tensor_desc<16x32xf16, #l>
+
+    %out:3 = scf.for %k = %c0 to %c1024 step %c32
+      iter_args(%arg0 = %a_tdesc, %arg1 = %b_tdesc, %arg2 = %c_tdesc)
+      -> (!xegpu.tensor_desc<16x32xf16, #l>, !xegpu.tensor_desc<16x32xf16, #l>, !xegpu.tensor_desc<16x32xf16, #l>) {
+      //CHECK-COUNT-8: xegpu.load_nd {{.*}}  : !xegpu.tensor_desc<8x16xf16> -> vector<8x16xf16>
+      %a = xegpu.load_nd %arg0 : !xegpu.tensor_desc<16x32xf16, #l> -> vector<16x32xf16>
+      %b = xegpu.load_nd %arg1 : !xegpu.tensor_desc<16x32xf16, #l> -> vector<16x32xf16>
+
+      //CHECK-COUNT-4: arith.addf {{.*}} : vector<8x16xf16>
+      %c = arith.addf %a, %b {layout_result_0 = #l} : vector<16x32xf16>
+
+      //CHECK-COUNT-4: xegpu.store_nd {{.*}} : vector<8x16xf16>, !xegpu.tensor_desc<8x16xf16>
+      xegpu.store_nd %c, %arg2: vector<16x32xf16>, !xegpu.tensor_desc<16x32xf16, #l>
+
+      //CHECK-COUNT-12: xegpu.update_nd_offset {{.*}} : !xegpu.tensor_desc<8x16xf16>
+      %a_next_tdesc = xegpu.update_nd_offset %arg0, [%c0, %c32] : !xegpu.tensor_desc<16x32xf16, #l>
+      %b_next_tdesc = xegpu.update_nd_offset %arg1, [%c0, %c32] : !xegpu.tensor_desc<16x32xf16, #l>
+      %c_next_tdesc = xegpu.update_nd_offset %arg2, [%c0, %c32] : !xegpu.tensor_desc<16x32xf16, #l>
+      scf.yield %a_next_tdesc, %b_next_tdesc, %c_next_tdesc
+        : !xegpu.tensor_desc<16x32xf16, #l>, !xegpu.tensor_desc<16x32xf16, #l>, !xegpu.tensor_desc<16x32xf16, #l>
+    }
+    gpu.return
+  }
+}
+
+// -----
+#l = #xegpu.layout<inst_data = [8]>
+// 1-D variant of the elementwise case. 32-element descriptors with
+// inst_data [8] are blocked into four 8-element pieces each.
+gpu.module @test_kernel {
+  gpu.func @test_elementwise_1D(%A: memref<1024x1024xf16>, %B: memref<1024x1024xf16>, %C: memref<1024x1024xf16>) {
+    %c0 = arith.constant 0 : index
+    %c32 = arith.constant 32 : index
+    %c1024 = arith.constant 1024 : index
+    %block_id_x = gpu.block_id x
+    %block_id_y = gpu.block_id y
+    %m = arith.muli %block_id_x, %c32 : index
+
+    %a_tdesc = xegpu.create_nd_tdesc %A[%m, %c0] : memref<1024x1024xf16> -> !xegpu.tensor_desc<32xf16, #l>
+    %b_tdesc = xegpu.create_nd_tdesc %B[%m, %c0] : memref<1024x1024xf16> -> !xegpu.tensor_desc<32xf16, #l>
+    %c_tdesc = xegpu.create_nd_tdesc %C[%m, %c0] : memref<1024x1024xf16> -> !xegpu.tensor_desc<32xf16, #l>
+
+    %out:3 = scf.for %k = %c0 to %c1024 step %c32
+      iter_args(%arg0 = %a_tdesc, %arg1 = %b_tdesc, %arg2 = %c_tdesc)
+      -> (!xegpu.tensor_desc<32xf16, #l>, !xegpu.tensor_desc<32xf16, #l>, !xegpu.tensor_desc<32xf16, #l>) {
+      //CHECK-COUNT-8: xegpu.load_nd {{.*}}  : !xegpu.tensor_desc<8xf16> -> vector<8xf16>
+      %a = xegpu.load_nd %arg0 : !xegpu.tensor_desc<32xf16, #l> -> vector<32xf16>
+      %b = xegpu.load_nd %arg1 : !xegpu.tensor_desc<32xf16, #l> -> vector<32xf16>
+
+      //CHECK-COUNT-4: arith.addf {{.*}} : vector<8xf16>
+      %c = arith.addf %a, %b {layout_result_0 = #l} : vector<32xf16>
+
+      //CHECK-COUNT-4: xegpu.store_nd {{.*}} : vector<8xf16>, !xegpu.tensor_desc<8xf16>
+      xegpu.store_nd %c, %arg2: vector<32xf16>, !xegpu.tensor_desc<32xf16, #l>
+
+      //CHECK-COUNT-12: xegpu.update_nd_offset {{.*}} : !xegpu.tensor_desc<8xf16>
+      %a_next_tdesc = xegpu.update_nd_offset %arg0, [%c32] : !xegpu.tensor_desc<32xf16, #l>
+      %b_next_tdesc = xegpu.update_nd_offset %arg1, [%c32] : !xegpu.tensor_desc<32xf16, #l>
+      %c_next_tdesc = xegpu.update_nd_offset %arg2, [%c32] : !xegpu.tensor_desc<32xf16, #l>
+      scf.yield %a_next_tdesc, %b_next_tdesc, %c_next_tdesc
+        : !xegpu.tensor_desc<32xf16, #l>, !xegpu.tensor_desc<32xf16, #l>, !xegpu.tensor_desc<32xf16, #l>
+    }
+    gpu.return
+  }
+}


        


More information about the Mlir-commits mailing list