[Mlir-commits] [mlir] 02d34d8 - [mlir][vector][xegpu] Vector to XeGPU conversion pass (#107419)

llvmlistbot at llvm.org
Thu Sep 19 13:16:27 PDT 2024


Author: Adam Siemieniuk
Date: 2024-09-19T15:16:23-05:00
New Revision: 02d34d800b94937c42fb7cff2db2b2836d918ac6

URL: https://github.com/llvm/llvm-project/commit/02d34d800b94937c42fb7cff2db2b2836d918ac6
DIFF: https://github.com/llvm/llvm-project/commit/02d34d800b94937c42fb7cff2db2b2836d918ac6.diff

LOG: [mlir][vector][xegpu] Vector to XeGPU conversion pass (#107419)

Add pass for Vector to XeGPU dialect conversion and initial conversion
patterns for vector.transfer_read|write operations.
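
For illustration, a minimal sketch distilled from the tests added in this
patch (function and value names are illustrative; attributes elided with "..."
and exact printing may differ). Running mlir-opt with -convert-vector-to-xegpu
rewrites an in-bounds 2D read such as

    func.func @load_2D(%src: memref<8x16x32xf32>, %off: index) -> vector<8x16xf32> {
      %c0 = arith.constant 0.0 : f32
      %0 = vector.transfer_read %src[%off, %off, %off], %c0
        {in_bounds = [true, true]} : memref<8x16x32xf32>, vector<8x16xf32>
      return %0 : vector<8x16xf32>
    }

into roughly

    %desc = xegpu.create_nd_tdesc %src[%off, %off, %off]
      : memref<8x16x32xf32> -> !xegpu.tensor_desc<8x16xf32, ...>
    %0 = xegpu.load_nd %desc ... -> vector<8x16xf32>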

Added: 
    mlir/include/mlir/Conversion/VectorToXeGPU/VectorToXeGPU.h
    mlir/lib/Conversion/VectorToXeGPU/CMakeLists.txt
    mlir/lib/Conversion/VectorToXeGPU/VectorToXeGPU.cpp
    mlir/test/Conversion/VectorToXeGPU/transfer-read-to-xegpu.mlir
    mlir/test/Conversion/VectorToXeGPU/transfer-write-to-xegpu.mlir

Modified: 
    mlir/include/mlir/Conversion/Passes.h
    mlir/include/mlir/Conversion/Passes.td
    mlir/lib/Conversion/CMakeLists.txt

Removed: 
    


################################################################################
diff --git a/mlir/include/mlir/Conversion/Passes.h b/mlir/include/mlir/Conversion/Passes.h
index 208f26489d6c39..2ab32836c80b1c 100644
--- a/mlir/include/mlir/Conversion/Passes.h
+++ b/mlir/include/mlir/Conversion/Passes.h
@@ -78,6 +78,7 @@
 #include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.h"
 #include "mlir/Conversion/VectorToSCF/VectorToSCF.h"
 #include "mlir/Conversion/VectorToSPIRV/VectorToSPIRVPass.h"
+#include "mlir/Conversion/VectorToXeGPU/VectorToXeGPU.h"
 
 namespace mlir {
 

diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td
index 2915cf7d5bb018..4d272ba219c6f1 100644
--- a/mlir/include/mlir/Conversion/Passes.td
+++ b/mlir/include/mlir/Conversion/Passes.td
@@ -1421,4 +1421,18 @@ def ConvertVectorToSPIRV : Pass<"convert-vector-to-spirv"> {
   let dependentDialects = ["spirv::SPIRVDialect"];
 }
 
+//===----------------------------------------------------------------------===//
+// VectorToXeGPU
+//===----------------------------------------------------------------------===//
+
+def ConvertVectorToXeGPU : Pass<"convert-vector-to-xegpu"> {
+  let summary = "Lower the operations from the vector dialect into the XeGPU "
+                "dialect";
+  let constructor = "mlir::createConvertVectorToXeGPUPass()";
+  let dependentDialects = [
+    "memref::MemRefDialect", "arith::ArithDialect",
+    "vector::VectorDialect", "xegpu::XeGPUDialect"
+  ];
+}
+
 #endif // MLIR_CONVERSION_PASSES

diff --git a/mlir/include/mlir/Conversion/VectorToXeGPU/VectorToXeGPU.h b/mlir/include/mlir/Conversion/VectorToXeGPU/VectorToXeGPU.h
new file mode 100644
index 00000000000000..ac4915901fdeca
--- /dev/null
+++ b/mlir/include/mlir/Conversion/VectorToXeGPU/VectorToXeGPU.h
@@ -0,0 +1,29 @@
+//===- VectorToXeGPU.h - Convert vector to XeGPU dialect --------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_CONVERSION_VECTORTOXEGPU_VECTORTOXEGPU_H
+#define MLIR_CONVERSION_VECTORTOXEGPU_VECTORTOXEGPU_H
+
+#include "mlir/IR/PatternMatch.h"
+
+namespace mlir {
+class Pass;
+class RewritePatternSet;
+
+#define GEN_PASS_DECL_CONVERTVECTORTOXEGPU
+#include "mlir/Conversion/Passes.h.inc"
+
+/// Collect a set of patterns to convert from the vector to XeGPU ops.
+void populateVectorToXeGPUConversionPatterns(RewritePatternSet &patterns);
+
+/// Create a pass to convert ops from vector to XeGPU.
+std::unique_ptr<Pass> createConvertVectorToXeGPUPass();
+
+} // namespace mlir
+
+#endif // MLIR_CONVERSION_VECTORTOXEGPU_VECTORTOXEGPU_H

diff --git a/mlir/lib/Conversion/CMakeLists.txt b/mlir/lib/Conversion/CMakeLists.txt
index 813f700c5556e1..6651d87162257f 100644
--- a/mlir/lib/Conversion/CMakeLists.txt
+++ b/mlir/lib/Conversion/CMakeLists.txt
@@ -69,3 +69,4 @@ add_subdirectory(VectorToGPU)
 add_subdirectory(VectorToLLVM)
 add_subdirectory(VectorToSCF)
 add_subdirectory(VectorToSPIRV)
+add_subdirectory(VectorToXeGPU)

diff --git a/mlir/lib/Conversion/VectorToXeGPU/CMakeLists.txt b/mlir/lib/Conversion/VectorToXeGPU/CMakeLists.txt
new file mode 100644
index 00000000000000..567083da002390
--- /dev/null
+++ b/mlir/lib/Conversion/VectorToXeGPU/CMakeLists.txt
@@ -0,0 +1,16 @@
+add_mlir_conversion_library(MLIRVectorToXeGPU
+  VectorToXeGPU.cpp
+
+  ADDITIONAL_HEADER_DIRS
+  ${MLIR_MAIN_INCLUDE_DIR}/mlir/Conversion/VectorToXeGPU
+
+  DEPENDS
+  MLIRConversionPassIncGen
+
+  LINK_LIBS PUBLIC
+  MLIRArithDialect
+  MLIRMemRefDialect
+  MLIRTransforms
+  MLIRVectorDialect
+  MLIRXeGPUDialect
+  )

diff --git a/mlir/lib/Conversion/VectorToXeGPU/VectorToXeGPU.cpp b/mlir/lib/Conversion/VectorToXeGPU/VectorToXeGPU.cpp
new file mode 100644
index 00000000000000..be1581d619a8b1
--- /dev/null
+++ b/mlir/lib/Conversion/VectorToXeGPU/VectorToXeGPU.cpp
@@ -0,0 +1,257 @@
+//===- VectorToXeGPU.cpp - Convert vector to XeGPU dialect ------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements lowering of vector operations to XeGPU dialect ops.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Conversion/VectorToXeGPU/VectorToXeGPU.h"
+
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/Dialect/XeGPU/IR/XeGPU.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+#include "mlir/Transforms/Passes.h"
+#include "llvm/ADT/TypeSwitch.h"
+
+#include <algorithm>
+#include <optional>
+
+namespace mlir {
+#define GEN_PASS_DEF_CONVERTVECTORTOXEGPU
+#include "mlir/Conversion/Passes.h.inc"
+} // namespace mlir
+
+using namespace mlir;
+
+namespace {
+
+static bool isZeroConstant(Value val) {
+  auto constant = val.getDefiningOp<arith::ConstantOp>();
+  if (!constant)
+    return false;
+
+  return TypeSwitch<Attribute, bool>(constant.getValue())
+      .Case<FloatAttr>(
+          [](auto floatAttr) { return floatAttr.getValue().isZero(); })
+      .Case<IntegerAttr>(
+          [](auto intAttr) { return intAttr.getValue().isZero(); })
+      .Default([](auto) { return false; });
+}
+
+static LogicalResult transferPreconditions(PatternRewriter &rewriter,
+                                           VectorTransferOpInterface xferOp) {
+  if (xferOp.getMask())
+    return rewriter.notifyMatchFailure(xferOp,
+                                       "Masked transfer is not supported");
+
+  auto srcTy = dyn_cast<MemRefType>(xferOp.getShapedType());
+  if (!srcTy)
+    return rewriter.notifyMatchFailure(xferOp, "Expects memref source");
+  VectorType vecTy = xferOp.getVectorType();
+  unsigned vecRank = vecTy.getRank();
+  if (!(vecRank == 1 || vecRank == 2))
+    return rewriter.notifyMatchFailure(xferOp, "Expects 1D or 2D vector");
+
+  SmallVector<int64_t> strides;
+  int64_t offset;
+  if (failed(getStridesAndOffset(srcTy, strides, offset)) ||
+      strides.back() != 1)
+    return rewriter.notifyMatchFailure(
+        xferOp, "Buffer must be contiguous in the innermost dimension");
+
+  AffineMap map = xferOp.getPermutationMap();
+  if (!map.isProjectedPermutation(/*allowZeroInResults=*/false))
+    return rewriter.notifyMatchFailure(xferOp, "Unsupported permutation map");
+  unsigned numInputDims = map.getNumInputs();
+  for (AffineExpr expr : map.getResults().take_back(vecRank)) {
+    auto dim = dyn_cast<AffineDimExpr>(expr);
+    if (dim.getPosition() < (numInputDims - vecRank))
+      return rewriter.notifyMatchFailure(
+          xferOp, "Only the innermost dimensions can be accessed");
+  }
+
+  return success();
+}
+
+static xegpu::CreateNdDescOp
+createNdDescriptor(PatternRewriter &rewriter, Location loc,
+                   xegpu::TensorDescType descType, TypedValue<MemRefType> src,
+                   Operation::operand_range offsets) {
+  MemRefType srcTy = src.getType();
+  auto [strides, offset] = getStridesAndOffset(srcTy);
+
+  xegpu::CreateNdDescOp ndDesc;
+  if (srcTy.hasStaticShape()) {
+    ndDesc = rewriter.create<xegpu::CreateNdDescOp>(loc, descType, src,
+                                                    getAsOpFoldResult(offsets));
+  } else {
+    // In case of any dynamic shapes, source's shape and strides have to be
+    // explicitly provided.
+    SmallVector<Value> sourceDims;
+    unsigned srcRank = srcTy.getRank();
+    for (unsigned i = 0; i < srcRank; ++i)
+      sourceDims.push_back(rewriter.create<memref::DimOp>(loc, src, i));
+
+    SmallVector<int64_t> constOffsets;
+    SmallVector<Value> dynOffsets;
+    for (Value offset : offsets) {
+      std::optional<int64_t> staticVal = getConstantIntValue(offset);
+      if (!staticVal)
+        dynOffsets.push_back(offset);
+      constOffsets.push_back(staticVal ? *staticVal : ShapedType::kDynamic);
+    }
+
+    SmallVector<Value> dynShapes;
+    for (auto [idx, shape] : llvm::enumerate(srcTy.getShape())) {
+      if (shape == ShapedType::kDynamic)
+        dynShapes.push_back(sourceDims[idx]);
+    }
+
+    // Compute strides in reverse order.
+    SmallVector<Value> dynStrides;
+    Value accStride = rewriter.create<arith::ConstantIndexOp>(loc, 1);
+    // Last stride is guaranteed to be static and unit.
+    for (int i = static_cast<int>(strides.size()) - 2; i >= 0; --i) {
+      accStride =
+          rewriter.create<arith::MulIOp>(loc, accStride, sourceDims[i + 1]);
+      if (strides[i] == ShapedType::kDynamic)
+        dynStrides.push_back(accStride);
+    }
+    std::reverse(dynStrides.begin(), dynStrides.end());
+
+    ndDesc = rewriter.create<xegpu::CreateNdDescOp>(
+        loc, descType, src, dynOffsets, dynShapes, dynStrides,
+        DenseI64ArrayAttr::get(rewriter.getContext(), constOffsets),
+        DenseI64ArrayAttr::get(rewriter.getContext(), srcTy.getShape()),
+        DenseI64ArrayAttr::get(rewriter.getContext(), strides));
+  }
+
+  return ndDesc;
+}
+
+struct TransferReadLowering : public OpRewritePattern<vector::TransferReadOp> {
+  using OpRewritePattern<vector::TransferReadOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(vector::TransferReadOp readOp,
+                                PatternRewriter &rewriter) const override {
+    Location loc = readOp.getLoc();
+
+    if (failed(transferPreconditions(rewriter, readOp)))
+      return failure();
+
+    bool isOutOfBounds = readOp.hasOutOfBoundsDim();
+    if (isOutOfBounds && !isZeroConstant(readOp.getPadding()))
+      return rewriter.notifyMatchFailure(
+          readOp, "Unsupported non-zero padded out-of-bounds read");
+
+    AffineMap readMap = readOp.getPermutationMap();
+    bool isTransposeLoad = !readMap.isMinorIdentity();
+
+    VectorType vecTy = readOp.getVectorType();
+    Type elementType = vecTy.getElementType();
+    unsigned minTransposeBitWidth = 32;
+    if (isTransposeLoad &&
+        elementType.getIntOrFloatBitWidth() < minTransposeBitWidth)
+      return rewriter.notifyMatchFailure(
+          readOp, "Unsupported data type for transposition");
+
+    // If load is transposed, get the base shape for the tensor descriptor.
+    SmallVector<int64_t> descShape{vecTy.getShape()};
+    if (isTransposeLoad)
+      std::reverse(descShape.begin(), descShape.end());
+    auto descType = xegpu::TensorDescType::get(
+        descShape, elementType, /*scattered=*/false, /*array_length=*/1,
+        xegpu::MemoryScope::Global,
+        /*boundary_check=*/isOutOfBounds);
+
+    xegpu::CreateNdDescOp ndDesc =
+        createNdDescriptor(rewriter, loc, descType,
+                           dyn_cast<TypedValue<MemRefType>>(readOp.getSource()),
+                           readOp.getIndices());
+
+    DenseI64ArrayAttr transposeAttr =
+        !isTransposeLoad ? nullptr
+                         : DenseI64ArrayAttr::get(rewriter.getContext(),
+                                                  ArrayRef<int64_t>{1, 0});
+    // By default, no specific caching policy is assigned.
+    xegpu::CachePolicyAttr hint = nullptr;
+    auto loadOp = rewriter.create<xegpu::LoadNdOp>(
+        loc, vecTy, ndDesc, /*packed=*/nullptr, transposeAttr,
+        /*l1_hint=*/hint,
+        /*l2_hint=*/hint, /*l3_hint=*/hint);
+    rewriter.replaceOp(readOp, loadOp);
+
+    return success();
+  }
+};
+
+struct TransferWriteLowering
+    : public OpRewritePattern<vector::TransferWriteOp> {
+  using OpRewritePattern<vector::TransferWriteOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(vector::TransferWriteOp writeOp,
+                                PatternRewriter &rewriter) const override {
+    Location loc = writeOp.getLoc();
+
+    if (failed(transferPreconditions(rewriter, writeOp)))
+      return failure();
+
+    if (writeOp.hasOutOfBoundsDim())
+      return rewriter.notifyMatchFailure(writeOp,
+                                         "Unsupported out-of-bounds write");
+    AffineMap map = writeOp.getPermutationMap();
+    if (!map.isMinorIdentity())
+      return rewriter.notifyMatchFailure(writeOp, "Expects identity map");
+
+    VectorType vecTy = writeOp.getVectorType();
+    auto descType = xegpu::TensorDescType::get(
+        vecTy.getShape(), vecTy.getElementType(),
+        /*scattered=*/false, /*array_length=*/1, xegpu::MemoryScope::Global,
+        /*boundary_check=*/false);
+    xegpu::CreateNdDescOp ndDesc = createNdDescriptor(
+        rewriter, loc, descType,
+        dyn_cast<TypedValue<MemRefType>>(writeOp.getSource()),
+        writeOp.getIndices());
+
+    // By default, no specific caching policy is assigned.
+    xegpu::CachePolicyAttr hint = nullptr;
+    auto storeOp =
+        rewriter.create<xegpu::StoreNdOp>(loc, writeOp.getVector(), ndDesc,
+                                          /*l1_hint=*/hint,
+                                          /*l2_hint=*/hint, /*l3_hint=*/hint);
+    rewriter.replaceOp(writeOp, storeOp);
+
+    return success();
+  }
+};
+
+struct ConvertVectorToXeGPUPass
+    : public impl::ConvertVectorToXeGPUBase<ConvertVectorToXeGPUPass> {
+  void runOnOperation() override {
+    RewritePatternSet patterns(&getContext());
+    populateVectorToXeGPUConversionPatterns(patterns);
+    if (failed(
+            applyPatternsAndFoldGreedily(getOperation(), std::move(patterns))))
+      return signalPassFailure();
+  }
+};
+
+} // namespace
+
+void mlir::populateVectorToXeGPUConversionPatterns(
+    RewritePatternSet &patterns) {
+  patterns.add<TransferReadLowering, TransferWriteLowering>(
+      patterns.getContext());
+}
+
+std::unique_ptr<Pass> mlir::createConvertVectorToXeGPUPass() {
+  return std::make_unique<ConvertVectorToXeGPUPass>();
+}

diff --git a/mlir/test/Conversion/VectorToXeGPU/transfer-read-to-xegpu.mlir b/mlir/test/Conversion/VectorToXeGPU/transfer-read-to-xegpu.mlir
new file mode 100644
index 00000000000000..4841ecbb62e807
--- /dev/null
+++ b/mlir/test/Conversion/VectorToXeGPU/transfer-read-to-xegpu.mlir
@@ -0,0 +1,200 @@
+// RUN: mlir-opt %s -convert-vector-to-xegpu -split-input-file | FileCheck %s
+
+func.func @load_1D_vector(%source: memref<8x16x32xf32>, %offset: index) -> vector<8xf32> {
+  %c0 = arith.constant 0.0 : f32
+  %0 = vector.transfer_read %source[%offset, %offset, %offset], %c0
+    {in_bounds = [true]} : memref<8x16x32xf32>, vector<8xf32>
+  return %0 : vector<8xf32>
+}
+
+// CHECK-LABEL: @load_1D_vector(
+// CHECK-SAME:  %[[SRC:.+]]: memref<8x16x32xf32>,
+// CHECK-SAME:  %[[OFFSET:.+]]: index
+// CHECK:       %[[DESC:.+]] = xegpu.create_nd_tdesc
+// CHECK-SAME:    %[[SRC]][%[[OFFSET]], %[[OFFSET]], %[[OFFSET]]]
+// CHECK-SAME:    memref<8x16x32xf32> -> !xegpu.tensor_desc<8xf32,
+// CHECK-SAME:    boundary_check = false
+// CHECK:       %[[VEC:.+]] = xegpu.load_nd %[[DESC]]{{.*}}-> vector<8xf32>
+// CHECK:       return %[[VEC]]
+
+// -----
+
+func.func @load_2D_vector(%source: memref<8x16x32xf32>,
+    %offset: index) -> vector<8x16xf32> {
+  %c0 = arith.constant 0.0 : f32
+  %0 = vector.transfer_read %source[%offset, %offset, %offset], %c0
+    {in_bounds = [true, true]} : memref<8x16x32xf32>, vector<8x16xf32>
+  return %0 : vector<8x16xf32>
+}
+
+// CHECK-LABEL: @load_2D_vector(
+// CHECK-SAME:  %[[SRC:.+]]: memref<8x16x32xf32>,
+// CHECK-SAME:  %[[OFFSET:.+]]: index
+// CHECK:       %[[DESC:.+]] = xegpu.create_nd_tdesc
+// CHECK-SAME:    %[[SRC]][%[[OFFSET]], %[[OFFSET]], %[[OFFSET]]]
+// CHECK-SAME:    memref<8x16x32xf32> -> !xegpu.tensor_desc<8x16xf32,
+// CHECK-SAME:    boundary_check = false
+// CHECK:       %[[VEC:.+]] = xegpu.load_nd %[[DESC]]{{.*}}-> vector<8x16xf32>
+// CHECK:       return %[[VEC]]
+
+// -----
+
+func.func @load_zero_pad_out_of_bounds(%source: memref<32x64xf32>,
+    %offset: index) -> vector<8x16xf32> {
+  %c0 = arith.constant 0.0 : f32
+  %0 = vector.transfer_read %source[%offset, %offset], %c0
+    {in_bounds = [false, true]} : memref<32x64xf32>, vector<8x16xf32>
+  return %0 : vector<8x16xf32>
+}
+
+// CHECK-LABEL: @load_zero_pad_out_of_bounds(
+// CHECK-SAME:  %[[SRC:.+]]: memref<32x64xf32>,
+// CHECK-SAME:  %[[OFFSET:.+]]: index
+// CHECK:       %[[DESC:.+]] = xegpu.create_nd_tdesc %[[SRC]][%[[OFFSET]], %[[OFFSET]]]
+// CHECK-SAME:    memref<32x64xf32> -> !xegpu.tensor_desc<8x16xf32,
+// CHECK-SAME:    boundary_check = true
+// CHECK:       %[[VEC:.+]] = xegpu.load_nd %[[DESC]]{{.*}}-> vector<8x16xf32>
+// CHECK:       return %[[VEC]]
+
+// -----
+
+func.func @load_transposed(%source: memref<32x64xf32>,
+    %offset: index) -> vector<8x16xf32> {
+  %c0 = arith.constant 0.0 : f32
+  %0 = vector.transfer_read %source[%offset, %offset], %c0
+    {permutation_map = affine_map<(d0, d1) -> (d1, d0)>,
+    in_bounds = [true, true]} : memref<32x64xf32>, vector<8x16xf32>
+  return %0 : vector<8x16xf32>
+}
+
+// CHECK-LABEL: @load_transposed(
+// CHECK-SAME:  %[[SRC:.+]]: memref<32x64xf32>,
+// CHECK-SAME:  %[[OFFSET:.+]]: index
+// CHECK:       %[[DESC:.+]] = xegpu.create_nd_tdesc %[[SRC]][%[[OFFSET]], %[[OFFSET]]]
+// CHECK-SAME:    memref<32x64xf32> -> !xegpu.tensor_desc<16x8xf32
+// CHECK:       %[[VEC:.+]] = xegpu.load_nd %[[DESC]] <{transpose = array<i64: 1, 0>}>
+// CHECK-SAME:    -> vector<8x16xf32>
+// CHECK:       return %[[VEC]]
+
+// -----
+
+func.func @load_dynamic_source(%source: memref<?x?x?xf32>,
+    %offset: index) -> vector<8x16xf32> {
+  %c0 = arith.constant 0.0 : f32
+  %0 = vector.transfer_read %source[%offset, %offset, %offset], %c0
+    {in_bounds = [true, true]} : memref<?x?x?xf32>, vector<8x16xf32>
+  return %0 : vector<8x16xf32>
+}
+
+// CHECK-LABEL: @load_dynamic_source(
+// CHECK-SAME:  %[[SRC:.+]]: memref<?x?x?xf32>,
+// CHECK-SAME:  %[[OFFSET:.+]]: index
+// CHECK-DAG:   %[[C0:.+]] = arith.constant 0 : index
+// CHECK-DAG:   %[[C1:.+]] = arith.constant 1 : index
+// CHECK-DAG:   %[[C2:.+]] = arith.constant 2 : index
+// CHECK-DAG:   %[[DIM_0:.+]] = memref.dim %[[SRC]], %[[C0]]
+// CHECK-DAG:   %[[DIM_1:.+]] = memref.dim %[[SRC]], %[[C1]]
+// CHECK-DAG:   %[[DIM_2:.+]] = memref.dim %[[SRC]], %[[C2]]
+// CHECK:       %[[DIM_0_STRIDE:.+]] = arith.muli %[[DIM_2]], %[[DIM_1]]
+// CHECK:       %[[DESC:.+]] = xegpu.create_nd_tdesc %[[SRC]][%[[OFFSET]], %[[OFFSET]], %[[OFFSET]]]
+// CHECK-SAME:    [%[[DIM_0]], %[[DIM_1]], %[[DIM_2]]], [%[[DIM_0_STRIDE]], %[[DIM_2]], 1]
+// CHECK-SAME:    memref<?x?x?xf32> -> !xegpu.tensor_desc<8x16xf32
+// CHECK:       %[[VEC:.+]] = xegpu.load_nd %[[DESC]]{{.*}}-> vector<8x16xf32>
+// CHECK:       return %[[VEC]]
+
+// -----
+
+func.func @no_load_out_of_bounds_non_zero_pad(%source: memref<32x64xf32>,
+    %offset: index, %arg2: index, %pad: f32) -> (vector<8x16xf32>, vector<8x16xf32>) {
+  %c1 = arith.constant 1.0 : f32
+  %0 = vector.transfer_read %source[%offset, %arg2], %c1
+    {in_bounds = [true, false]} : memref<32x64xf32>, vector<8x16xf32>
+  %1 = vector.transfer_read %source[%arg2, %offset], %pad
+    {in_bounds = [false, true]} : memref<32x64xf32>, vector<8x16xf32>
+  return %0, %1 : vector<8x16xf32>, vector<8x16xf32>
+}
+
+// CHECK-LABEL:   @no_load_out_of_bounds_non_zero_pad(
+// CHECK-COUNT-2: vector.transfer_read
+
+// -----
+
+func.func @no_load_masked(%source : memref<4xf32>,
+    %offset : index) -> vector<4xf32> {
+  %c0 = arith.constant 0.0 : f32
+  %mask = arith.constant dense<[0, 1, 0, 1]> : vector<4xi1>
+  %0 = vector.transfer_read %source[%offset], %c0, %mask
+    {in_bounds = [true]} : memref<4xf32>, vector<4xf32>
+  return %0 : vector<4xf32>
+}
+
+// CHECK-LABEL: @no_load_masked(
+// CHECK:       vector.transfer_read
+
+// -----
+
+func.func @no_load_tensor(%source: tensor<32x64xf32>,
+    %offset: index, %arg2: index) -> vector<8x16xf32> {
+  %c0 = arith.constant 0.0 : f32
+  %0 = vector.transfer_read %source[%offset, %arg2], %c0
+    {in_bounds = [true, true]} : tensor<32x64xf32>, vector<8x16xf32>
+  return %0 : vector<8x16xf32>
+}
+
+// CHECK-LABEL: @no_load_tensor(
+// CHECK:       vector.transfer_read
+
+// -----
+
+func.func @no_load_high_dim_vector(%source: memref<16x32x64xf32>,
+    %offset: index, %arg2: index) -> vector<8x16x32xf32> {
+  %c0 = arith.constant 0.0 : f32
+  %0 = vector.transfer_read %source[%offset, %arg2, %offset], %c0
+    {in_bounds = [true, true, true]} : memref<16x32x64xf32>, vector<8x16x32xf32>
+  return %0 : vector<8x16x32xf32>
+}
+
+// CHECK-LABEL: @no_load_high_dim_vector(
+// CHECK:       vector.transfer_read
+
+// -----
+
+func.func @no_load_non_unit_inner_stride(
+    %source: memref<32xf32, strided<[?], offset: ?>>,
+    %offset: index) -> vector<8xf32> {
+  %c0 = arith.constant 0.0 : f32
+  %0 = vector.transfer_read %source[%offset], %c0 {in_bounds = [true]}
+    : memref<32xf32, strided<[?], offset: ?>>, vector<8xf32>
+  return %0 : vector<8xf32>
+}
+
+// CHECK-LABEL: @no_load_non_unit_inner_stride(
+// CHECK:       vector.transfer_read
+
+// -----
+
+func.func @no_load_unsupported_map(%source: memref<16x32x64xf32>,
+    %offset: index) -> vector<8x16xf32> {
+  %c0 = arith.constant 0.0 : f32
+  %0 = vector.transfer_read %source[%offset, %offset, %offset], %c0
+    {permutation_map = affine_map<(d0, d1, d2) -> (d0, d2)>,
+    in_bounds = [true, true]} : memref<16x32x64xf32>, vector<8x16xf32>
+  return %0 : vector<8x16xf32>
+}
+
+// CHECK-LABEL: @no_load_unsupported_map(
+// CHECK:       vector.transfer_read
+
+// -----
+
+func.func @no_load_transpose_unsupported_data_type(%source: memref<32x64xf16>,
+    %offset: index) -> vector<8x16xf16> {
+  %c0 = arith.constant 0.0 : f16
+  %0 = vector.transfer_read %source[%offset, %offset], %c0
+    {permutation_map = affine_map<(d0, d1) -> (d1, d0)>,
+    in_bounds = [true, true]} : memref<32x64xf16>, vector<8x16xf16>
+  return %0 : vector<8x16xf16>
+}
+
+// CHECK-LABEL: @no_load_transpose_unsupported_data_type(
+// CHECK:       vector.transfer_read

diff --git a/mlir/test/Conversion/VectorToXeGPU/transfer-write-to-xegpu.mlir b/mlir/test/Conversion/VectorToXeGPU/transfer-write-to-xegpu.mlir
new file mode 100644
index 00000000000000..361919c47b097d
--- /dev/null
+++ b/mlir/test/Conversion/VectorToXeGPU/transfer-write-to-xegpu.mlir
@@ -0,0 +1,159 @@
+// RUN: mlir-opt %s -convert-vector-to-xegpu -split-input-file | FileCheck %s
+
+func.func @store_1D_vector(%vec: vector<8xf32>,
+    %source: memref<8x16x32xf32>, %offset: index) {
+  vector.transfer_write %vec, %source[%offset, %offset, %offset]
+    {in_bounds = [true]}
+    : vector<8xf32>, memref<8x16x32xf32>
+  return
+}
+
+// CHECK-LABEL: @store_1D_vector(
+// CHECK-SAME:  %[[VEC:.+]]: vector<8xf32>,
+// CHECK-SAME:  %[[SRC:.+]]: memref<8x16x32xf32>,
+// CHECK-SAME:  %[[OFFSET:.+]]: index
+// CHECK:       %[[DESC:.+]] = xegpu.create_nd_tdesc
+// CHECK-SAME:    %[[SRC]][%[[OFFSET]], %[[OFFSET]], %[[OFFSET]]]
+// CHECK-SAME:    memref<8x16x32xf32> -> !xegpu.tensor_desc<8xf32,
+// CHECK-SAME:    boundary_check = false
+// CHECK:       xegpu.store_nd %[[VEC]], %[[DESC]] : vector<8xf32>
+
+// -----
+
+func.func @store_2D_vector(%vec: vector<8x16xf32>,
+    %source: memref<8x16x32xf32>, %offset: index) {
+  vector.transfer_write %vec, %source[%offset, %offset, %offset]
+    {in_bounds = [true, true]}
+    : vector<8x16xf32>, memref<8x16x32xf32>
+  return
+}
+
+// CHECK-LABEL: @store_2D_vector(
+// CHECK-SAME:  %[[VEC:.+]]: vector<8x16xf32>,
+// CHECK-SAME:  %[[SRC:.+]]: memref<8x16x32xf32>,
+// CHECK-SAME:  %[[OFFSET:.+]]: index
+// CHECK:       %[[DESC:.+]] = xegpu.create_nd_tdesc
+// CHECK-SAME:    %[[SRC]][%[[OFFSET]], %[[OFFSET]], %[[OFFSET]]]
+// CHECK-SAME:    memref<8x16x32xf32> -> !xegpu.tensor_desc<8x16xf32,
+// CHECK-SAME:    boundary_check = false
+// CHECK:       xegpu.store_nd %[[VEC]], %[[DESC]] : vector<8x16xf32>
+
+// -----
+
+func.func @store_dynamic_source(%vec: vector<8x16xf32>,
+    %source: memref<?x?x?xf32>, %offset: index) {
+  vector.transfer_write %vec, %source[%offset, %offset, %offset]
+    {in_bounds = [true, true]}
+    : vector<8x16xf32>, memref<?x?x?xf32>
+  return
+}
+
+// CHECK-LABEL: @store_dynamic_source(
+// CHECK-SAME:  %[[VEC:.+]]: vector<8x16xf32>,
+// CHECK-SAME:  %[[SRC:.+]]: memref<?x?x?xf32>,
+// CHECK-SAME:  %[[OFFSET:.+]]: index
+// CHECK-DAG:   %[[C0:.+]] = arith.constant 0 : index
+// CHECK-DAG:   %[[C1:.+]] = arith.constant 1 : index
+// CHECK-DAG:   %[[C2:.+]] = arith.constant 2 : index
+// CHECK-DAG:   %[[DIM_0:.+]] = memref.dim %[[SRC]], %[[C0]]
+// CHECK-DAG:   %[[DIM_1:.+]] = memref.dim %[[SRC]], %[[C1]]
+// CHECK-DAG:   %[[DIM_2:.+]] = memref.dim %[[SRC]], %[[C2]]
+// CHECK:       %[[DIM_0_STRIDE:.+]] = arith.muli %[[DIM_2]], %[[DIM_1]]
+// CHECK:       %[[DESC:.+]] = xegpu.create_nd_tdesc %[[SRC]][%[[OFFSET]], %[[OFFSET]], %[[OFFSET]]]
+// CHECK-SAME:    [%[[DIM_0]], %[[DIM_1]], %[[DIM_2]]], [%[[DIM_0_STRIDE]], %[[DIM_2]], 1]
+// CHECK-SAME:    memref<?x?x?xf32> -> !xegpu.tensor_desc<8x16xf32
+// CHECK:       xegpu.store_nd %[[VEC]], %[[DESC]] : vector<8x16xf32>
+
+// -----
+
+func.func @no_store_transposed(%vec: vector<8x16xf32>,
+    %source: memref<32x64xf32>, %offset: index) {
+  vector.transfer_write %vec, %source[%offset, %offset]
+    {permutation_map = affine_map<(d0, d1) -> (d1, d0)>,
+    in_bounds = [true, true]}
+    : vector<8x16xf32>, memref<32x64xf32>
+  return
+}
+
+// CHECK-LABEL: @no_store_transposed(
+// CHECK:       vector.transfer_write
+
+// -----
+
+func.func @no_store_out_of_bounds(%vec: vector<8x16xf32>,
+    %source: memref<32x64xf32>, %offset: index) {
+  vector.transfer_write %vec, %source[%offset, %offset]
+    {in_bounds = [false, true]}
+    : vector<8x16xf32>, memref<32x64xf32>
+  return
+}
+
+// CHECK-LABEL:   @no_store_out_of_bounds(
+// CHECK:         vector.transfer_write
+
+// -----
+
+func.func @no_store_masked(%vec: vector<4xf32>,
+    %source: memref<4xf32>, %offset: index) {
+  %mask = arith.constant dense<[0, 1, 0, 1]> : vector<4xi1>
+  vector.transfer_write %vec, %source[%offset], %mask
+    {in_bounds = [true]}
+    : vector<4xf32>, memref<4xf32>
+  return
+}
+
+// CHECK-LABEL: @no_store_masked(
+// CHECK:       vector.transfer_write
+
+// -----
+
+func.func @no_store_tensor(%vec: vector<8x16xf32>,
+    %source: tensor<32x64xf32>, %offset: index) -> tensor<32x64xf32> {
+  %0 = vector.transfer_write %vec, %source[%offset, %offset]
+    {in_bounds = [true, true]}
+    : vector<8x16xf32>, tensor<32x64xf32>
+  return %0 : tensor<32x64xf32>
+}
+
+// CHECK-LABEL: @no_store_tensor(
+// CHECK:       vector.transfer_write
+
+// -----
+
+func.func @no_store_high_dim_vector(%vec: vector<8x16x32xf32>,
+    %source: memref<16x32x64xf32>, %offset: index) {
+  vector.transfer_write %vec, %source[%offset, %offset, %offset]
+    {in_bounds = [true, true, true]}
+    : vector<8x16x32xf32>, memref<16x32x64xf32>
+  return
+}
+
+// CHECK-LABEL: @no_store_high_dim_vector(
+// CHECK:       vector.transfer_write
+
+// -----
+
+func.func @no_store_non_unit_inner_stride(%vec: vector<8xf32>,
+    %source: memref<32xf32, strided<[?], offset: ?>>, %offset: index) {
+  vector.transfer_write %vec, %source[%offset]
+    {in_bounds = [true]}
+    : vector<8xf32>, memref<32xf32, strided<[?], offset: ?>>
+  return
+}
+
+// CHECK-LABEL: @no_store_non_unit_inner_stride(
+// CHECK:       vector.transfer_write
+
+// -----
+
+func.func @no_store_unsupported_map(%vec: vector<8x16xf32>,
+    %source: memref<16x32x64xf32>, %offset: index) {
+  vector.transfer_write %vec, %source[%offset, %offset, %offset]
+    {permutation_map = affine_map<(d0, d1, d2) -> (d0, d2)>,
+    in_bounds = [true, true]}
+    : vector<8x16xf32>, memref<16x32x64xf32>
+  return
+}
+
+// CHECK-LABEL: @no_store_unsupported_map(
+// CHECK:       vector.transfer_write


More information about the Mlir-commits mailing list