[Mlir-commits] [mlir] 02d34d8 - [mlir][vector][xegpu] Vector to XeGPU conversion pass (#107419)
llvmlistbot at llvm.org
Thu Sep 19 13:16:27 PDT 2024
Author: Adam Siemieniuk
Date: 2024-09-19T15:16:23-05:00
New Revision: 02d34d800b94937c42fb7cff2db2b2836d918ac6
URL: https://github.com/llvm/llvm-project/commit/02d34d800b94937c42fb7cff2db2b2836d918ac6
DIFF: https://github.com/llvm/llvm-project/commit/02d34d800b94937c42fb7cff2db2b2836d918ac6.diff
LOG: [mlir][vector][xegpu] Vector to XeGPU conversion pass (#107419)
Add a pass for Vector to XeGPU dialect conversion and initial conversion
patterns for the vector.transfer_read and vector.transfer_write operations.
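As a rough sketch of the intended lowering (simplified from the tests added in this patch; value names are illustrative and attribute/assembly details are elided, see the CHECK lines in the new .mlir tests for the precise form), an in-bounds 2D read such as

  %v = vector.transfer_read %src[%off, %off, %off], %pad {in_bounds = [true, true]}
    : memref<8x16x32xf32>, vector<8x16xf32>

is rewritten into a tensor descriptor creation followed by a block load

  %desc = xegpu.create_nd_tdesc %src[%off, %off, %off]
    : memref<8x16x32xf32> -> !xegpu.tensor_desc<8x16xf32, ...>
  %v = xegpu.load_nd %desc ... : !xegpu.tensor_desc<8x16xf32, ...> -> vector<8x16xf32>

while vector.transfer_write is lowered analogously to xegpu.create_nd_tdesc plus xegpu.store_nd.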
Added:
mlir/include/mlir/Conversion/VectorToXeGPU/VectorToXeGPU.h
mlir/lib/Conversion/VectorToXeGPU/CMakeLists.txt
mlir/lib/Conversion/VectorToXeGPU/VectorToXeGPU.cpp
mlir/test/Conversion/VectorToXeGPU/transfer-read-to-xegpu.mlir
mlir/test/Conversion/VectorToXeGPU/transfer-write-to-xegpu.mlir
Modified:
mlir/include/mlir/Conversion/Passes.h
mlir/include/mlir/Conversion/Passes.td
mlir/lib/Conversion/CMakeLists.txt
Removed:
################################################################################
diff --git a/mlir/include/mlir/Conversion/Passes.h b/mlir/include/mlir/Conversion/Passes.h
index 208f26489d6c39..2ab32836c80b1c 100644
--- a/mlir/include/mlir/Conversion/Passes.h
+++ b/mlir/include/mlir/Conversion/Passes.h
@@ -78,6 +78,7 @@
#include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.h"
#include "mlir/Conversion/VectorToSCF/VectorToSCF.h"
#include "mlir/Conversion/VectorToSPIRV/VectorToSPIRVPass.h"
+#include "mlir/Conversion/VectorToXeGPU/VectorToXeGPU.h"
namespace mlir {
diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td
index 2915cf7d5bb018..4d272ba219c6f1 100644
--- a/mlir/include/mlir/Conversion/Passes.td
+++ b/mlir/include/mlir/Conversion/Passes.td
@@ -1421,4 +1421,18 @@ def ConvertVectorToSPIRV : Pass<"convert-vector-to-spirv"> {
let dependentDialects = ["spirv::SPIRVDialect"];
}
+//===----------------------------------------------------------------------===//
+// VectorToXeGPU
+//===----------------------------------------------------------------------===//
+
+def ConvertVectorToXeGPU : Pass<"convert-vector-to-xegpu"> {
+ let summary = "Lower the operations from the vector dialect into the XeGPU "
+ "dialect";
+ let constructor = "mlir::createConvertVectorToXeGPUPass()";
+ let dependentDialects = [
+ "memref::MemRefDialect", "arith::ArithDialect",
+ "vector::VectorDialect", "xegpu::XeGPUDialect"
+ ];
+}
+
#endif // MLIR_CONVERSION_PASSES
diff --git a/mlir/include/mlir/Conversion/VectorToXeGPU/VectorToXeGPU.h b/mlir/include/mlir/Conversion/VectorToXeGPU/VectorToXeGPU.h
new file mode 100644
index 00000000000000..ac4915901fdeca
--- /dev/null
+++ b/mlir/include/mlir/Conversion/VectorToXeGPU/VectorToXeGPU.h
@@ -0,0 +1,29 @@
+//===- VectorToXeGPU.h - Convert vector to XeGPU dialect --------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_CONVERSION_VECTORTOXEGPU_VECTORTOXEGPU_H
+#define MLIR_CONVERSION_VECTORTOXEGPU_VECTORTOXEGPU_H
+
+#include "mlir/IR/PatternMatch.h"
+
+namespace mlir {
+class Pass;
+class RewritePatternSet;
+
+#define GEN_PASS_DECL_CONVERTVECTORTOXEGPU
+#include "mlir/Conversion/Passes.h.inc"
+
+/// Collect a set of patterns to convert ops from the vector dialect to the XeGPU dialect.
+void populateVectorToXeGPUConversionPatterns(RewritePatternSet &patterns);
+
+/// Create a pass to convert ops from vector to XeGPU.
+std::unique_ptr<Pass> createConvertVectorToXeGPUPass();
+
+} // namespace mlir
+
+#endif // MLIR_CONVERSION_VECTORTOXEGPU_VECTORTOXEGPU_H
diff --git a/mlir/lib/Conversion/CMakeLists.txt b/mlir/lib/Conversion/CMakeLists.txt
index 813f700c5556e1..6651d87162257f 100644
--- a/mlir/lib/Conversion/CMakeLists.txt
+++ b/mlir/lib/Conversion/CMakeLists.txt
@@ -69,3 +69,4 @@ add_subdirectory(VectorToGPU)
add_subdirectory(VectorToLLVM)
add_subdirectory(VectorToSCF)
add_subdirectory(VectorToSPIRV)
+add_subdirectory(VectorToXeGPU)
diff --git a/mlir/lib/Conversion/VectorToXeGPU/CMakeLists.txt b/mlir/lib/Conversion/VectorToXeGPU/CMakeLists.txt
new file mode 100644
index 00000000000000..567083da002390
--- /dev/null
+++ b/mlir/lib/Conversion/VectorToXeGPU/CMakeLists.txt
@@ -0,0 +1,16 @@
+add_mlir_conversion_library(MLIRVectorToXeGPU
+ VectorToXeGPU.cpp
+
+ ADDITIONAL_HEADER_DIRS
+ ${MLIR_MAIN_INCLUDE_DIR}/mlir/Conversion/VectorToXeGPU
+
+ DEPENDS
+ MLIRConversionPassIncGen
+
+ LINK_LIBS PUBLIC
+ MLIRArithDialect
+ MLIRMemRefDialect
+ MLIRTransforms
+ MLIRVectorDialect
+ MLIRXeGPUDialect
+ )
diff --git a/mlir/lib/Conversion/VectorToXeGPU/VectorToXeGPU.cpp b/mlir/lib/Conversion/VectorToXeGPU/VectorToXeGPU.cpp
new file mode 100644
index 00000000000000..be1581d619a8b1
--- /dev/null
+++ b/mlir/lib/Conversion/VectorToXeGPU/VectorToXeGPU.cpp
@@ -0,0 +1,257 @@
+//===- VectorToXeGPU.cpp - Convert vector to XeGPU dialect ------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements lowering of vector operations to XeGPU dialect ops.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Conversion/VectorToXeGPU/VectorToXeGPU.h"
+
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/Dialect/XeGPU/IR/XeGPU.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+#include "mlir/Transforms/Passes.h"
+#include "llvm/ADT/TypeSwitch.h"
+
+#include <algorithm>
+#include <optional>
+
+namespace mlir {
+#define GEN_PASS_DEF_CONVERTVECTORTOXEGPU
+#include "mlir/Conversion/Passes.h.inc"
+} // namespace mlir
+
+using namespace mlir;
+
+namespace {
+
+static bool isZeroConstant(Value val) {
+ auto constant = val.getDefiningOp<arith::ConstantOp>();
+ if (!constant)
+ return false;
+
+ return TypeSwitch<Attribute, bool>(constant.getValue())
+ .Case<FloatAttr>(
+ [](auto floatAttr) { return floatAttr.getValue().isZero(); })
+ .Case<IntegerAttr>(
+ [](auto intAttr) { return intAttr.getValue().isZero(); })
+ .Default([](auto) { return false; });
+}
+
+static LogicalResult transferPreconditions(PatternRewriter &rewriter,
+ VectorTransferOpInterface xferOp) {
+ if (xferOp.getMask())
+ return rewriter.notifyMatchFailure(xferOp,
+ "Masked transfer is not supported");
+
+ auto srcTy = dyn_cast<MemRefType>(xferOp.getShapedType());
+ if (!srcTy)
+ return rewriter.notifyMatchFailure(xferOp, "Expects memref source");
+ VectorType vecTy = xferOp.getVectorType();
+ unsigned vecRank = vecTy.getRank();
+ if (!(vecRank == 1 || vecRank == 2))
+ return rewriter.notifyMatchFailure(xferOp, "Expects 1D or 2D vector");
+
+ SmallVector<int64_t> strides;
+ int64_t offset;
+ if (failed(getStridesAndOffset(srcTy, strides, offset)) ||
+ strides.back() != 1)
+ return rewriter.notifyMatchFailure(
+ xferOp, "Buffer must be contiguous in the innermost dimension");
+
+ AffineMap map = xferOp.getPermutationMap();
+ if (!map.isProjectedPermutation(/*allowZeroInResults=*/false))
+ return rewriter.notifyMatchFailure(xferOp, "Unsupported permutation map");
+ unsigned numInputDims = map.getNumInputs();
+ for (AffineExpr expr : map.getResults().take_back(vecRank)) {
+ auto dim = dyn_cast<AffineDimExpr>(expr);
+ if (dim.getPosition() < (numInputDims - vecRank))
+ return rewriter.notifyMatchFailure(
+ xferOp, "Only the innermost dimensions can be accessed");
+ }
+
+ return success();
+}
+
+static xegpu::CreateNdDescOp
+createNdDescriptor(PatternRewriter &rewriter, Location loc,
+ xegpu::TensorDescType descType, TypedValue<MemRefType> src,
+ Operation::operand_range offsets) {
+ MemRefType srcTy = src.getType();
+ auto [strides, offset] = getStridesAndOffset(srcTy);
+
+ xegpu::CreateNdDescOp ndDesc;
+ if (srcTy.hasStaticShape()) {
+ ndDesc = rewriter.create<xegpu::CreateNdDescOp>(loc, descType, src,
+ getAsOpFoldResult(offsets));
+ } else {
+ // In case of any dynamic shapes, the source's shape and strides have to be
+ // provided explicitly.
+ SmallVector<Value> sourceDims;
+ unsigned srcRank = srcTy.getRank();
+ for (unsigned i = 0; i < srcRank; ++i)
+ sourceDims.push_back(rewriter.create<memref::DimOp>(loc, src, i));
+
+ SmallVector<int64_t> constOffsets;
+ SmallVector<Value> dynOffsets;
+ for (Value offset : offsets) {
+ std::optional<int64_t> staticVal = getConstantIntValue(offset);
+ if (!staticVal)
+ dynOffsets.push_back(offset);
+ constOffsets.push_back(staticVal ? *staticVal : ShapedType::kDynamic);
+ }
+
+ SmallVector<Value> dynShapes;
+ for (auto [idx, shape] : llvm::enumerate(srcTy.getShape())) {
+ if (shape == ShapedType::kDynamic)
+ dynShapes.push_back(sourceDims[idx]);
+ }
+
+ // Compute strides in reverse order.
+ SmallVector<Value> dynStrides;
+ Value accStride = rewriter.create<arith::ConstantIndexOp>(loc, 1);
+ // Last stride is guaranteed to be static and unit.
+ for (int i = static_cast<int>(strides.size()) - 2; i >= 0; --i) {
+ accStride =
+ rewriter.create<arith::MulIOp>(loc, accStride, sourceDims[i + 1]);
+ if (strides[i] == ShapedType::kDynamic)
+ dynStrides.push_back(accStride);
+ }
+ std::reverse(dynStrides.begin(), dynStrides.end());
+
+ ndDesc = rewriter.create<xegpu::CreateNdDescOp>(
+ loc, descType, src, dynOffsets, dynShapes, dynStrides,
+ DenseI64ArrayAttr::get(rewriter.getContext(), constOffsets),
+ DenseI64ArrayAttr::get(rewriter.getContext(), srcTy.getShape()),
+ DenseI64ArrayAttr::get(rewriter.getContext(), strides));
+ }
+
+ return ndDesc;
+}
+
+struct TransferReadLowering : public OpRewritePattern<vector::TransferReadOp> {
+ using OpRewritePattern<vector::TransferReadOp>::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(vector::TransferReadOp readOp,
+ PatternRewriter &rewriter) const override {
+ Location loc = readOp.getLoc();
+
+ if (failed(transferPreconditions(rewriter, readOp)))
+ return failure();
+
+ bool isOutOfBounds = readOp.hasOutOfBoundsDim();
+ if (isOutOfBounds && !isZeroConstant(readOp.getPadding()))
+ return rewriter.notifyMatchFailure(
+ readOp, "Unsupported non-zero padded out-of-bounds read");
+
+ AffineMap readMap = readOp.getPermutationMap();
+ bool isTransposeLoad = !readMap.isMinorIdentity();
+
+ VectorType vecTy = readOp.getVectorType();
+ Type elementType = vecTy.getElementType();
+ unsigned minTransposeBitWidth = 32;
+ if (isTransposeLoad &&
+ elementType.getIntOrFloatBitWidth() < minTransposeBitWidth)
+ return rewriter.notifyMatchFailure(
+ readOp, "Unsupported data type for tranposition");
+
+ // If load is transposed, get the base shape for the tensor descriptor.
+ SmallVector<int64_t> descShape{vecTy.getShape()};
+ if (isTransposeLoad)
+ std::reverse(descShape.begin(), descShape.end());
+ auto descType = xegpu::TensorDescType::get(
+ descShape, elementType, /*scattered=*/false, /*array_length=*/1,
+ xegpu::MemoryScope::Global,
+ /*boundary_check=*/isOutOfBounds);
+
+ xegpu::CreateNdDescOp ndDesc =
+ createNdDescriptor(rewriter, loc, descType,
+ dyn_cast<TypedValue<MemRefType>>(readOp.getSource()),
+ readOp.getIndices());
+
+ DenseI64ArrayAttr transposeAttr =
+ !isTransposeLoad ? nullptr
+ : DenseI64ArrayAttr::get(rewriter.getContext(),
+ ArrayRef<int64_t>{1, 0});
+ // By default, no specific caching policy is assigned.
+ xegpu::CachePolicyAttr hint = nullptr;
+ auto loadOp = rewriter.create<xegpu::LoadNdOp>(
+ loc, vecTy, ndDesc, /*packed=*/nullptr, transposeAttr,
+ /*l1_hint=*/hint,
+ /*l2_hint=*/hint, /*l3_hint=*/hint);
+ rewriter.replaceOp(readOp, loadOp);
+
+ return success();
+ }
+};
+
+struct TransferWriteLowering
+ : public OpRewritePattern<vector::TransferWriteOp> {
+ using OpRewritePattern<vector::TransferWriteOp>::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(vector::TransferWriteOp writeOp,
+ PatternRewriter &rewriter) const override {
+ Location loc = writeOp.getLoc();
+
+ if (failed(transferPreconditions(rewriter, writeOp)))
+ return failure();
+
+ if (writeOp.hasOutOfBoundsDim())
+ return rewriter.notifyMatchFailure(writeOp,
+ "Unsupported out-of-bounds write");
+ AffineMap map = writeOp.getPermutationMap();
+ if (!map.isMinorIdentity())
+ return rewriter.notifyMatchFailure(writeOp, "Expects identity map");
+
+ VectorType vecTy = writeOp.getVectorType();
+ auto descType = xegpu::TensorDescType::get(
+ vecTy.getShape(), vecTy.getElementType(),
+ /*scattered=*/false, /*array_length=*/1, xegpu::MemoryScope::Global,
+ /*boundary_check=*/false);
+ xegpu::CreateNdDescOp ndDesc = createNdDescriptor(
+ rewriter, loc, descType,
+ dyn_cast<TypedValue<MemRefType>>(writeOp.getSource()),
+ writeOp.getIndices());
+
+ // By default, no specific caching policy is assigned.
+ xegpu::CachePolicyAttr hint = nullptr;
+ auto storeOp =
+ rewriter.create<xegpu::StoreNdOp>(loc, writeOp.getVector(), ndDesc,
+ /*l1_hint=*/hint,
+ /*l2_hint=*/hint, /*l3_hint=*/hint);
+ rewriter.replaceOp(writeOp, storeOp);
+
+ return success();
+ }
+};
+
+struct ConvertVectorToXeGPUPass
+ : public impl::ConvertVectorToXeGPUBase<ConvertVectorToXeGPUPass> {
+ void runOnOperation() override {
+ RewritePatternSet patterns(&getContext());
+ populateVectorToXeGPUConversionPatterns(patterns);
+ if (failed(
+ applyPatternsAndFoldGreedily(getOperation(), std::move(patterns))))
+ return signalPassFailure();
+ }
+};
+
+} // namespace
+
+void mlir::populateVectorToXeGPUConversionPatterns(
+ RewritePatternSet &patterns) {
+ patterns.add<TransferReadLowering, TransferWriteLowering>(
+ patterns.getContext());
+}
+
+std::unique_ptr<Pass> mlir::createConvertVectorToXeGPUPass() {
+ return std::make_unique<ConvertVectorToXeGPUPass>();
+}
diff --git a/mlir/test/Conversion/VectorToXeGPU/transfer-read-to-xegpu.mlir b/mlir/test/Conversion/VectorToXeGPU/transfer-read-to-xegpu.mlir
new file mode 100644
index 00000000000000..4841ecbb62e807
--- /dev/null
+++ b/mlir/test/Conversion/VectorToXeGPU/transfer-read-to-xegpu.mlir
@@ -0,0 +1,200 @@
+// RUN: mlir-opt %s -convert-vector-to-xegpu -split-input-file | FileCheck %s
+
+func.func @load_1D_vector(%source: memref<8x16x32xf32>, %offset: index) -> vector<8xf32> {
+ %c0 = arith.constant 0.0 : f32
+ %0 = vector.transfer_read %source[%offset, %offset, %offset], %c0
+ {in_bounds = [true]} : memref<8x16x32xf32>, vector<8xf32>
+ return %0 : vector<8xf32>
+}
+
+// CHECK-LABEL: @load_1D_vector(
+// CHECK-SAME: %[[SRC:.+]]: memref<8x16x32xf32>,
+// CHECK-SAME: %[[OFFSET:.+]]: index
+// CHECK: %[[DESC:.+]] = xegpu.create_nd_tdesc
+// CHECK-SAME: %[[SRC]][%[[OFFSET]], %[[OFFSET]], %[[OFFSET]]]
+// CHECK-SAME: memref<8x16x32xf32> -> !xegpu.tensor_desc<8xf32,
+// CHECK-SAME: boundary_check = false
+// CHECK: %[[VEC:.+]] = xegpu.load_nd %[[DESC]]{{.*}}-> vector<8xf32>
+// CHECK: return %[[VEC]]
+
+// -----
+
+func.func @load_2D_vector(%source: memref<8x16x32xf32>,
+ %offset: index) -> vector<8x16xf32> {
+ %c0 = arith.constant 0.0 : f32
+ %0 = vector.transfer_read %source[%offset, %offset, %offset], %c0
+ {in_bounds = [true, true]} : memref<8x16x32xf32>, vector<8x16xf32>
+ return %0 : vector<8x16xf32>
+}
+
+// CHECK-LABEL: @load_2D_vector(
+// CHECK-SAME: %[[SRC:.+]]: memref<8x16x32xf32>,
+// CHECK-SAME: %[[OFFSET:.+]]: index
+// CHECK: %[[DESC:.+]] = xegpu.create_nd_tdesc
+// CHECK-SAME: %[[SRC]][%[[OFFSET]], %[[OFFSET]], %[[OFFSET]]]
+// CHECK-SAME: memref<8x16x32xf32> -> !xegpu.tensor_desc<8x16xf32,
+// CHECK-SAME: boundary_check = false
+// CHECK: %[[VEC:.+]] = xegpu.load_nd %[[DESC]]{{.*}}-> vector<8x16xf32>
+// CHECK: return %[[VEC]]
+
+// -----
+
+func.func @load_zero_pad_out_of_bounds(%source: memref<32x64xf32>,
+ %offset: index) -> vector<8x16xf32> {
+ %c0 = arith.constant 0.0 : f32
+ %0 = vector.transfer_read %source[%offset, %offset], %c0
+ {in_bounds = [false, true]} : memref<32x64xf32>, vector<8x16xf32>
+ return %0 : vector<8x16xf32>
+}
+
+// CHECK-LABEL: @load_zero_pad_out_of_bounds(
+// CHECK-SAME: %[[SRC:.+]]: memref<32x64xf32>,
+// CHECK-SAME: %[[OFFSET:.+]]: index
+// CHECK: %[[DESC:.+]] = xegpu.create_nd_tdesc %[[SRC]][%[[OFFSET]], %[[OFFSET]]]
+// CHECK-SAME: memref<32x64xf32> -> !xegpu.tensor_desc<8x16xf32,
+// CHECK-SAME: boundary_check = true
+// CHECK: %[[VEC:.+]] = xegpu.load_nd %[[DESC]]{{.*}}-> vector<8x16xf32>
+// CHECK: return %[[VEC]]
+
+// -----
+
+func.func @load_transposed(%source: memref<32x64xf32>,
+ %offset: index) -> vector<8x16xf32> {
+ %c0 = arith.constant 0.0 : f32
+ %0 = vector.transfer_read %source[%offset, %offset], %c0
+ {permutation_map = affine_map<(d0, d1) -> (d1, d0)>,
+ in_bounds = [true, true]} : memref<32x64xf32>, vector<8x16xf32>
+ return %0 : vector<8x16xf32>
+}
+
+// CHECK-LABEL: @load_transposed(
+// CHECK-SAME: %[[SRC:.+]]: memref<32x64xf32>,
+// CHECK-SAME: %[[OFFSET:.+]]: index
+// CHECK: %[[DESC:.+]] = xegpu.create_nd_tdesc %[[SRC]][%[[OFFSET]], %[[OFFSET]]]
+// CHECK-SAME: memref<32x64xf32> -> !xegpu.tensor_desc<16x8xf32
+// CHECK: %[[VEC:.+]] = xegpu.load_nd %[[DESC]] <{transpose = array<i64: 1, 0>}>
+// CHECK-SAME: -> vector<8x16xf32>
+// CHECK: return %[[VEC]]
+
+// -----
+
+func.func @load_dynamic_source(%source: memref<?x?x?xf32>,
+ %offset: index) -> vector<8x16xf32> {
+ %c0 = arith.constant 0.0 : f32
+ %0 = vector.transfer_read %source[%offset, %offset, %offset], %c0
+ {in_bounds = [true, true]} : memref<?x?x?xf32>, vector<8x16xf32>
+ return %0 : vector<8x16xf32>
+}
+
+// CHECK-LABEL: @load_dynamic_source(
+// CHECK-SAME: %[[SRC:.+]]: memref<?x?x?xf32>,
+// CHECK-SAME: %[[OFFSET:.+]]: index
+// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
+// CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index
+// CHECK-DAG: %[[C2:.+]] = arith.constant 2 : index
+// CHECK-DAG: %[[DIM_0:.+]] = memref.dim %[[SRC]], %[[C0]]
+// CHECK-DAG: %[[DIM_1:.+]] = memref.dim %[[SRC]], %[[C1]]
+// CHECK-DAG: %[[DIM_2:.+]] = memref.dim %[[SRC]], %[[C2]]
+// CHECK: %[[DIM_0_STRIDE:.+]] = arith.muli %[[DIM_2]], %[[DIM_1]]
+// CHECK: %[[DESC:.+]] = xegpu.create_nd_tdesc %[[SRC]][%[[OFFSET]], %[[OFFSET]], %[[OFFSET]]]
+// CHECK-SAME: [%[[DIM_0]], %[[DIM_1]], %[[DIM_2]]], [%[[DIM_0_STRIDE]], %[[DIM_2]], 1]
+// CHECK-SAME: memref<?x?x?xf32> -> !xegpu.tensor_desc<8x16xf32
+// CHECK: %[[VEC:.+]] = xegpu.load_nd %[[DESC]]{{.*}}-> vector<8x16xf32>
+// CHECK: return %[[VEC]]
+
+// -----
+
+func.func @no_load_out_of_bounds_non_zero_pad(%source: memref<32x64xf32>,
+ %offset: index, %arg2: index, %pad: f32) -> (vector<8x16xf32>, vector<8x16xf32>) {
+ %c1 = arith.constant 1.0 : f32
+ %0 = vector.transfer_read %source[%offset, %arg2], %c1
+ {in_bounds = [true, false]} : memref<32x64xf32>, vector<8x16xf32>
+ %1 = vector.transfer_read %source[%arg2, %offset], %pad
+ {in_bounds = [false, true]} : memref<32x64xf32>, vector<8x16xf32>
+ return %0, %1 : vector<8x16xf32>, vector<8x16xf32>
+}
+
+// CHECK-LABEL: @no_load_out_of_bounds_non_zero_pad(
+// CHECK-COUNT-2: vector.transfer_read
+
+// -----
+
+func.func @no_load_masked(%source : memref<4xf32>,
+ %offset : index) -> vector<4xf32> {
+ %c0 = arith.constant 0.0 : f32
+ %mask = arith.constant dense<[0, 1, 0, 1]> : vector<4xi1>
+ %0 = vector.transfer_read %source[%offset], %c0, %mask
+ {in_bounds = [true]} : memref<4xf32>, vector<4xf32>
+ return %0 : vector<4xf32>
+}
+
+// CHECK-LABEL: @no_load_masked(
+// CHECK: vector.transfer_read
+
+// -----
+
+func.func @no_load_tensor(%source: tensor<32x64xf32>,
+ %offset: index, %arg2: index) -> vector<8x16xf32> {
+ %c0 = arith.constant 0.0 : f32
+ %0 = vector.transfer_read %source[%offset, %arg2], %c0
+ {in_bounds = [true, true]} : tensor<32x64xf32>, vector<8x16xf32>
+ return %0 : vector<8x16xf32>
+}
+
+// CHECK-LABEL: @no_load_tensor(
+// CHECK: vector.transfer_read
+
+// -----
+
+func.func @no_load_high_dim_vector(%source: memref<16x32x64xf32>,
+ %offset: index, %arg2: index) -> vector<8x16x32xf32> {
+ %c0 = arith.constant 0.0 : f32
+ %0 = vector.transfer_read %source[%offset, %arg2, %offset], %c0
+ {in_bounds = [true, true, true]} : memref<16x32x64xf32>, vector<8x16x32xf32>
+ return %0 : vector<8x16x32xf32>
+}
+
+// CHECK-LABEL: @no_load_high_dim_vector(
+// CHECK: vector.transfer_read
+
+// -----
+
+func.func @no_load_non_unit_inner_stride(
+ %source: memref<32xf32, strided<[?], offset: ?>>,
+ %offset: index) -> vector<8xf32> {
+ %c0 = arith.constant 0.0 : f32
+ %0 = vector.transfer_read %source[%offset], %c0 {in_bounds = [true]}
+ : memref<32xf32, strided<[?], offset: ?>>, vector<8xf32>
+ return %0 : vector<8xf32>
+}
+
+// CHECK-LABEL: @no_load_non_unit_inner_stride(
+// CHECK: vector.transfer_read
+
+// -----
+
+func.func @no_load_unsupported_map(%source: memref<16x32x64xf32>,
+ %offset: index) -> vector<8x16xf32> {
+ %c0 = arith.constant 0.0 : f32
+ %0 = vector.transfer_read %source[%offset, %offset, %offset], %c0
+ {permutation_map = affine_map<(d0, d1, d2) -> (d0, d2)>,
+ in_bounds = [true, true]} : memref<16x32x64xf32>, vector<8x16xf32>
+ return %0 : vector<8x16xf32>
+}
+
+// CHECK-LABEL: @no_load_unsupported_map(
+// CHECK: vector.transfer_read
+
+// -----
+
+func.func @no_load_transpose_unsupported_data_type(%source: memref<32x64xf16>,
+ %offset: index) -> vector<8x16xf16> {
+ %c0 = arith.constant 0.0 : f16
+ %0 = vector.transfer_read %source[%offset, %offset], %c0
+ {permutation_map = affine_map<(d0, d1) -> (d1, d0)>,
+ in_bounds = [true, true]} : memref<32x64xf16>, vector<8x16xf16>
+ return %0 : vector<8x16xf16>
+}
+
+// CHECK-LABEL: @no_load_transpose_unsupported_data_type(
+// CHECK: vector.transfer_read
diff --git a/mlir/test/Conversion/VectorToXeGPU/transfer-write-to-xegpu.mlir b/mlir/test/Conversion/VectorToXeGPU/transfer-write-to-xegpu.mlir
new file mode 100644
index 00000000000000..361919c47b097d
--- /dev/null
+++ b/mlir/test/Conversion/VectorToXeGPU/transfer-write-to-xegpu.mlir
@@ -0,0 +1,159 @@
+// RUN: mlir-opt %s -convert-vector-to-xegpu -split-input-file | FileCheck %s
+
+func.func @store_1D_vector(%vec: vector<8xf32>,
+ %source: memref<8x16x32xf32>, %offset: index) {
+ vector.transfer_write %vec, %source[%offset, %offset, %offset]
+ {in_bounds = [true]}
+ : vector<8xf32>, memref<8x16x32xf32>
+ return
+}
+
+// CHECK-LABEL: @store_1D_vector(
+// CHECK-SAME: %[[VEC:.+]]: vector<8xf32>,
+// CHECK-SAME: %[[SRC:.+]]: memref<8x16x32xf32>,
+// CHECK-SAME: %[[OFFSET:.+]]: index
+// CHECK: %[[DESC:.+]] = xegpu.create_nd_tdesc
+// CHECK-SAME: %[[SRC]][%[[OFFSET]], %[[OFFSET]], %[[OFFSET]]]
+// CHECK-SAME: memref<8x16x32xf32> -> !xegpu.tensor_desc<8xf32,
+// CHECK-SAME: boundary_check = false
+// CHECK: xegpu.store_nd %[[VEC]], %[[DESC]] : vector<8xf32>
+
+// -----
+
+func.func @store_2D_vector(%vec: vector<8x16xf32>,
+ %source: memref<8x16x32xf32>, %offset: index) {
+ vector.transfer_write %vec, %source[%offset, %offset, %offset]
+ {in_bounds = [true, true]}
+ : vector<8x16xf32>, memref<8x16x32xf32>
+ return
+}
+
+// CHECK-LABEL: @store_2D_vector(
+// CHECK-SAME: %[[VEC:.+]]: vector<8x16xf32>,
+// CHECK-SAME: %[[SRC:.+]]: memref<8x16x32xf32>,
+// CHECK-SAME: %[[OFFSET:.+]]: index
+// CHECK: %[[DESC:.+]] = xegpu.create_nd_tdesc
+// CHECK-SAME: %[[SRC]][%[[OFFSET]], %[[OFFSET]], %[[OFFSET]]]
+// CHECK-SAME: memref<8x16x32xf32> -> !xegpu.tensor_desc<8x16xf32,
+// CHECK-SAME: boundary_check = false
+// CHECK: xegpu.store_nd %[[VEC]], %[[DESC]] : vector<8x16xf32>
+
+// -----
+
+func.func @store_dynamic_source(%vec: vector<8x16xf32>,
+ %source: memref<?x?x?xf32>, %offset: index) {
+ vector.transfer_write %vec, %source[%offset, %offset, %offset]
+ {in_bounds = [true, true]}
+ : vector<8x16xf32>, memref<?x?x?xf32>
+ return
+}
+
+// CHECK-LABEL: @store_dynamic_source(
+// CHECK-SAME: %[[VEC:.+]]: vector<8x16xf32>,
+// CHECK-SAME: %[[SRC:.+]]: memref<?x?x?xf32>,
+// CHECK-SAME: %[[OFFSET:.+]]: index
+// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
+// CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index
+// CHECK-DAG: %[[C2:.+]] = arith.constant 2 : index
+// CHECK-DAG: %[[DIM_0:.+]] = memref.dim %[[SRC]], %[[C0]]
+// CHECK-DAG: %[[DIM_1:.+]] = memref.dim %[[SRC]], %[[C1]]
+// CHECK-DAG: %[[DIM_2:.+]] = memref.dim %[[SRC]], %[[C2]]
+// CHECK: %[[DIM_0_STRIDE:.+]] = arith.muli %[[DIM_2]], %[[DIM_1]]
+// CHECK: %[[DESC:.+]] = xegpu.create_nd_tdesc %[[SRC]][%[[OFFSET]], %[[OFFSET]], %[[OFFSET]]]
+// CHECK-SAME: [%[[DIM_0]], %[[DIM_1]], %[[DIM_2]]], [%[[DIM_0_STRIDE]], %[[DIM_2]], 1]
+// CHECK-SAME: memref<?x?x?xf32> -> !xegpu.tensor_desc<8x16xf32
+// CHECK: xegpu.store_nd %[[VEC]], %[[DESC]] : vector<8x16xf32>
+
+// -----
+
+func.func @no_store_transposed(%vec: vector<8x16xf32>,
+ %source: memref<32x64xf32>, %offset: index) {
+ vector.transfer_write %vec, %source[%offset, %offset]
+ {permutation_map = affine_map<(d0, d1) -> (d1, d0)>,
+ in_bounds = [true, true]}
+ : vector<8x16xf32>, memref<32x64xf32>
+ return
+}
+
+// CHECK-LABEL: @no_store_transposed(
+// CHECK: vector.transfer_write
+
+// -----
+
+func.func @no_store_out_of_bounds(%vec: vector<8x16xf32>,
+ %source: memref<32x64xf32>, %offset: index) {
+ vector.transfer_write %vec, %source[%offset, %offset]
+ {in_bounds = [false, true]}
+ : vector<8x16xf32>, memref<32x64xf32>
+ return
+}
+
+// CHECK-LABEL: @no_store_out_of_bounds(
+// CHECK: vector.transfer_write
+
+// -----
+
+func.func @no_store_masked(%vec: vector<4xf32>,
+ %source: memref<4xf32>, %offset: index) {
+ %mask = arith.constant dense<[0, 1, 0, 1]> : vector<4xi1>
+ vector.transfer_write %vec, %source[%offset], %mask
+ {in_bounds = [true]}
+ : vector<4xf32>, memref<4xf32>
+ return
+}
+
+// CHECK-LABEL: @no_store_masked(
+// CHECK: vector.transfer_write
+
+// -----
+
+func.func @no_store_tensor(%vec: vector<8x16xf32>,
+ %source: tensor<32x64xf32>, %offset: index) -> tensor<32x64xf32> {
+ %0 = vector.transfer_write %vec, %source[%offset, %offset]
+ {in_bounds = [true, true]}
+ : vector<8x16xf32>, tensor<32x64xf32>
+ return %0 : tensor<32x64xf32>
+}
+
+// CHECK-LABEL: @no_store_tensor(
+// CHECK: vector.transfer_write
+
+// -----
+
+func.func @no_store_high_dim_vector(%vec: vector<8x16x32xf32>,
+ %source: memref<16x32x64xf32>, %offset: index) {
+ vector.transfer_write %vec, %source[%offset, %offset, %offset]
+ {in_bounds = [true, true, true]}
+ : vector<8x16x32xf32>, memref<16x32x64xf32>
+ return
+}
+
+// CHECK-LABEL: @no_store_high_dim_vector(
+// CHECK: vector.transfer_write
+
+// -----
+
+func.func @no_store_non_unit_inner_stride(%vec: vector<8xf32>,
+ %source: memref<32xf32, strided<[?], offset: ?>>, %offset: index) {
+ vector.transfer_write %vec, %source[%offset]
+ {in_bounds = [true]}
+ : vector<8xf32>, memref<32xf32, strided<[?], offset: ?>>
+ return
+}
+
+// CHECK-LABEL: @no_store_non_unit_inner_stride(
+// CHECK: vector.transfer_write
+
+// -----
+
+func.func @no_store_unsupported_map(%vec: vector<8x16xf32>,
+ %source: memref<16x32x64xf32>, %offset: index) {
+ vector.transfer_write %vec, %source[%offset, %offset, %offset]
+ {permutation_map = affine_map<(d0, d1, d2) -> (d0, d2)>,
+ in_bounds = [true, true]}
+ : vector<8x16xf32>, memref<16x32x64xf32>
+ return
+}
+
+// CHECK-LABEL: @no_store_unsupported_map(
+// CHECK: vector.transfer_write