[Mlir-commits] [mlir] [mlir][xegpu] Convert Vector load and store to XeGPU (PR #110826)
Adam Siemieniuk
llvmlistbot at llvm.org
Wed Oct 2 04:11:23 PDT 2024
https://github.com/adam-smnk created https://github.com/llvm/llvm-project/pull/110826
Adds patterns to lower vector.load|store to XeGPU operations.
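As a rough illustration (adapted from the tests added below; the tensor
descriptor attributes such as boundary_check are elided here), a 2D load

  %0 = vector.load %source[%off, %off, %off]
    : memref<8x16x32xf32>, vector<8x16xf32>

is rewritten into a descriptor creation plus a block load along the lines of

  %desc = xegpu.create_nd_tdesc %source[%off, %off, %off]
    : memref<8x16x32xf32> -> !xegpu.tensor_desc<8x16xf32, ...>
  %0 = xegpu.load_nd %desc : !xegpu.tensor_desc<8x16xf32, ...> -> vector<8x16xf32>

vector.store is lowered symmetrically through xegpu.store_nd.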
From 81e28a345457db829f9b5e2482a7e642b3bd1ea1 Mon Sep 17 00:00:00 2001
From: Adam Siemieniuk <adam.siemieniuk at intel.com>
Date: Wed, 2 Oct 2024 12:56:31 +0200
Subject: [PATCH] [mlir][xegpu] Convert Vector load and store to XeGPU
Adds patterns to lower vector.load|store to XeGPU operations.
---
.../VectorToXeGPU/VectorToXeGPU.cpp | 85 +++++++++++++-
.../VectorToXeGPU/load-to-xegpu.mlir | 105 +++++++++++++++++
.../VectorToXeGPU/store-to-xegpu.mlir | 106 ++++++++++++++++++
3 files changed, 291 insertions(+), 5 deletions(-)
create mode 100644 mlir/test/Conversion/VectorToXeGPU/load-to-xegpu.mlir
create mode 100644 mlir/test/Conversion/VectorToXeGPU/store-to-xegpu.mlir
diff --git a/mlir/lib/Conversion/VectorToXeGPU/VectorToXeGPU.cpp b/mlir/lib/Conversion/VectorToXeGPU/VectorToXeGPU.cpp
index 0e21e96cc3fbb9..e9acda657b3c27 100644
--- a/mlir/lib/Conversion/VectorToXeGPU/VectorToXeGPU.cpp
+++ b/mlir/lib/Conversion/VectorToXeGPU/VectorToXeGPU.cpp
@@ -33,6 +33,7 @@ using namespace mlir;
namespace {
+// Return true if the value represents a zero constant.
static bool isZeroConstant(Value val) {
auto constant = val.getDefiningOp<arith::ConstantOp>();
if (!constant)
@@ -46,6 +47,17 @@ static bool isZeroConstant(Value val) {
.Default([](auto) { return false; });
}
+static LogicalResult storeLoadPreconditions(PatternRewriter &rewriter,
+ Operation *op, VectorType vecTy) {
+ // Validate only the vector type, as the basic vector load and store ops
+ // guarantee an XeGPU-compatible memref source.
+ unsigned vecRank = vecTy.getRank();
+ if (!(vecRank == 1 || vecRank == 2))
+ return rewriter.notifyMatchFailure(op, "Expects 1D or 2D vector");
+
+ return success();
+}
+
static LogicalResult transferPreconditions(PatternRewriter &rewriter,
VectorTransferOpInterface xferOp) {
if (xferOp.getMask())
@@ -55,11 +67,13 @@ static LogicalResult transferPreconditions(PatternRewriter &rewriter,
auto srcTy = dyn_cast<MemRefType>(xferOp.getShapedType());
if (!srcTy)
return rewriter.notifyMatchFailure(xferOp, "Expects memref source");
+
+ // Perform common data transfer checks.
VectorType vecTy = xferOp.getVectorType();
- unsigned vecRank = vecTy.getRank();
- if (!(vecRank == 1 || vecRank == 2))
- return rewriter.notifyMatchFailure(xferOp, "Expects 1D or 2D vector");
+ if (failed(storeLoadPreconditions(rewriter, xferOp, vecTy)))
+ return failure();
+ // Validate further transfer op semantics.
SmallVector<int64_t> strides;
int64_t offset;
if (failed(getStridesAndOffset(srcTy, strides, offset)) ||
@@ -67,6 +81,7 @@ static LogicalResult transferPreconditions(PatternRewriter &rewriter,
return rewriter.notifyMatchFailure(
xferOp, "Buffer must be contiguous in the innermost dimension");
+ unsigned vecRank = vecTy.getRank();
AffineMap map = xferOp.getPermutationMap();
if (!map.isProjectedPermutation(/*allowZeroInResults=*/false))
return rewriter.notifyMatchFailure(xferOp, "Unsupported permutation map");
@@ -232,6 +247,66 @@ struct TransferWriteLowering
}
};
+struct LoadLowering : public OpRewritePattern<vector::LoadOp> {
+ using OpRewritePattern<vector::LoadOp>::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(vector::LoadOp loadOp,
+ PatternRewriter &rewriter) const override {
+ Location loc = loadOp.getLoc();
+
+ VectorType vecTy = loadOp.getResult().getType();
+ if (failed(storeLoadPreconditions(rewriter, loadOp, vecTy)))
+ return failure();
+
+ auto descType = xegpu::TensorDescType::get(
+ vecTy.getShape(), vecTy.getElementType(), /*array_length=*/1,
+ /*boundary_check=*/true, xegpu::MemorySpace::Global);
+ xegpu::CreateNdDescOp ndDesc = createNdDescriptor(
+ rewriter, loc, descType, loadOp.getBase(), loadOp.getIndices());
+
+ // By default, no specific caching policy is assigned.
+ xegpu::CachePolicyAttr hint = nullptr;
+ auto loadNdOp = rewriter.create<xegpu::LoadNdOp>(
+ loc, vecTy, ndDesc, /*packed=*/nullptr, /*transpose=*/nullptr,
+ /*l1_hint=*/hint,
+ /*l2_hint=*/hint, /*l3_hint=*/hint);
+ rewriter.replaceOp(loadOp, loadNdOp);
+
+ return success();
+ }
+};
+
+struct StoreLowering : public OpRewritePattern<vector::StoreOp> {
+ using OpRewritePattern<vector::StoreOp>::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(vector::StoreOp storeOp,
+ PatternRewriter &rewriter) const override {
+ Location loc = storeOp.getLoc();
+
+ TypedValue<VectorType> vector = storeOp.getValueToStore();
+ VectorType vecTy = vector.getType();
+ if (failed(storeLoadPreconditions(rewriter, storeOp, vecTy)))
+ return failure();
+
+ auto descType =
+ xegpu::TensorDescType::get(vecTy.getShape(), vecTy.getElementType(),
+ /*array_length=*/1, /*boundary_check=*/true,
+ xegpu::MemorySpace::Global);
+ xegpu::CreateNdDescOp ndDesc = createNdDescriptor(
+ rewriter, loc, descType, storeOp.getBase(), storeOp.getIndices());
+
+ // By default, no specific caching policy is assigned.
+ xegpu::CachePolicyAttr hint = nullptr;
+ auto storeNdOp =
+ rewriter.create<xegpu::StoreNdOp>(loc, vector, ndDesc,
+ /*l1_hint=*/hint,
+ /*l2_hint=*/hint, /*l3_hint=*/hint);
+ rewriter.replaceOp(storeOp, storeNdOp);
+
+ return success();
+ }
+};
+
struct ConvertVectorToXeGPUPass
: public impl::ConvertVectorToXeGPUBase<ConvertVectorToXeGPUPass> {
void runOnOperation() override {
@@ -247,8 +322,8 @@ struct ConvertVectorToXeGPUPass
void mlir::populateVectorToXeGPUConversionPatterns(
RewritePatternSet &patterns) {
- patterns.add<TransferReadLowering, TransferWriteLowering>(
- patterns.getContext());
+ patterns.add<TransferReadLowering, TransferWriteLowering, LoadLowering,
+ StoreLowering>(patterns.getContext());
}
std::unique_ptr<Pass> mlir::createConvertVectorToXeGPUPass() {
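For anyone who wants to exercise the new patterns outside the dedicated pass,
here is a minimal sketch assuming the usual greedy rewrite driver; the helper
name runVectorToXeGPU is hypothetical, while the populate function and driver
call are the in-tree APIs:

  #include "mlir/Conversion/VectorToXeGPU/VectorToXeGPU.h"
  #include "mlir/IR/PatternMatch.h"
  #include "mlir/Transforms/GreedyPatternRewriteDriver.h"

  // Hypothetical helper: collect the Vector-to-XeGPU lowering patterns and
  // apply them greedily to the given operation.
  static void runVectorToXeGPU(mlir::Operation *op) {
    mlir::RewritePatternSet patterns(op->getContext());
    mlir::populateVectorToXeGPUConversionPatterns(patterns);
    (void)mlir::applyPatternsAndFoldGreedily(op, std::move(patterns));
  }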
diff --git a/mlir/test/Conversion/VectorToXeGPU/load-to-xegpu.mlir b/mlir/test/Conversion/VectorToXeGPU/load-to-xegpu.mlir
new file mode 100644
index 00000000000000..e2a506f8ad5abd
--- /dev/null
+++ b/mlir/test/Conversion/VectorToXeGPU/load-to-xegpu.mlir
@@ -0,0 +1,105 @@
+// RUN: mlir-opt %s -convert-vector-to-xegpu -split-input-file | FileCheck %s
+
+func.func @load_1D_vector(%source: memref<8x16x32xf32>, %offset: index) -> vector<8xf32> {
+ %0 = vector.load %source[%offset, %offset, %offset]
+ : memref<8x16x32xf32>, vector<8xf32>
+ return %0 : vector<8xf32>
+}
+
+// CHECK-LABEL: @load_1D_vector(
+// CHECK-SAME: %[[SRC:.+]]: memref<8x16x32xf32>,
+// CHECK-SAME: %[[OFFSET:.+]]: index
+// CHECK: %[[DESC:.+]] = xegpu.create_nd_tdesc
+// CHECK-SAME: %[[SRC]][%[[OFFSET]], %[[OFFSET]], %[[OFFSET]]]
+// CHECK-SAME: memref<8x16x32xf32> -> !xegpu.tensor_desc<8xf32,
+// CHECK-SAME: boundary_check = true
+// CHECK: %[[VEC:.+]] = xegpu.load_nd %[[DESC]]{{.*}}-> vector<8xf32>
+// CHECK: return %[[VEC]]
+
+// -----
+
+func.func @load_2D_vector(%source: memref<8x16x32xf32>,
+ %offset: index) -> vector<8x16xf32> {
+ %0 = vector.load %source[%offset, %offset, %offset]
+ : memref<8x16x32xf32>, vector<8x16xf32>
+ return %0 : vector<8x16xf32>
+}
+
+// CHECK-LABEL: @load_2D_vector(
+// CHECK-SAME: %[[SRC:.+]]: memref<8x16x32xf32>,
+// CHECK-SAME: %[[OFFSET:.+]]: index
+// CHECK: %[[DESC:.+]] = xegpu.create_nd_tdesc
+// CHECK-SAME: %[[SRC]][%[[OFFSET]], %[[OFFSET]], %[[OFFSET]]]
+// CHECK-SAME: memref<8x16x32xf32> -> !xegpu.tensor_desc<8x16xf32,
+// CHECK-SAME: boundary_check = true
+// CHECK: %[[VEC:.+]] = xegpu.load_nd %[[DESC]]{{.*}}-> vector<8x16xf32>
+// CHECK: return %[[VEC]]
+
+// -----
+
+func.func @load_dynamic_source(%source: memref<?x?x?xf32>,
+ %offset: index) -> vector<8x16xf32> {
+ %0 = vector.load %source[%offset, %offset, %offset]
+ : memref<?x?x?xf32>, vector<8x16xf32>
+ return %0 : vector<8x16xf32>
+}
+
+// CHECK-LABEL: @load_dynamic_source(
+// CHECK-SAME: %[[SRC:.+]]: memref<?x?x?xf32>,
+// CHECK-SAME: %[[OFFSET:.+]]: index
+// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
+// CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index
+// CHECK-DAG: %[[C2:.+]] = arith.constant 2 : index
+// CHECK-DAG: %[[DIM_0:.+]] = memref.dim %[[SRC]], %[[C0]]
+// CHECK-DAG: %[[DIM_1:.+]] = memref.dim %[[SRC]], %[[C1]]
+// CHECK-DAG: %[[DIM_2:.+]] = memref.dim %[[SRC]], %[[C2]]
+// CHECK: %[[DIM_0_STRIDE:.+]] = arith.muli %[[DIM_2]], %[[DIM_1]]
+// CHECK: %[[DESC:.+]] = xegpu.create_nd_tdesc %[[SRC]][%[[OFFSET]], %[[OFFSET]], %[[OFFSET]]]
+// CHECK-SAME: [%[[DIM_0]], %[[DIM_1]], %[[DIM_2]]], [%[[DIM_0_STRIDE]], %[[DIM_2]], 1]
+// CHECK-SAME: memref<?x?x?xf32> -> !xegpu.tensor_desc<8x16xf32,
+// CHECK-SAME: boundary_check = true
+// CHECK: %[[VEC:.+]] = xegpu.load_nd %[[DESC]]{{.*}}-> vector<8x16xf32>
+// CHECK: return %[[VEC]]
+
+// -----
+
+func.func @load_out_of_bounds(%source: memref<7x15xf32>,
+ %offset: index) -> vector<8x16xf32> {
+ %0 = vector.load %source[%offset, %offset]
+ : memref<7x15xf32>, vector<8x16xf32>
+ return %0 : vector<8x16xf32>
+}
+
+// CHECK-LABEL: @load_out_of_bounds(
+// CHECK-SAME: %[[SRC:.+]]: memref<7x15xf32>,
+// CHECK-SAME: %[[OFFSET:.+]]: index
+// CHECK: %[[DESC:.+]] = xegpu.create_nd_tdesc
+// CHECK-SAME: %[[SRC]][%[[OFFSET]], %[[OFFSET]]]
+// CHECK-SAME: memref<7x15xf32> -> !xegpu.tensor_desc<8x16xf32,
+// CHECK-SAME: boundary_check = true
+// CHECK: %[[VEC:.+]] = xegpu.load_nd %[[DESC]]{{.*}}-> vector<8x16xf32>
+// CHECK: return %[[VEC]]
+
+// -----
+
+func.func @no_load_high_dim_vector(%source: memref<16x32x64xf32>,
+ %offset: index) -> vector<8x16x32xf32> {
+ %0 = vector.load %source[%offset, %offset, %offset]
+ : memref<16x32x64xf32>, vector<8x16x32xf32>
+ return %0 : vector<8x16x32xf32>
+}
+
+// CHECK-LABEL: @no_load_high_dim_vector(
+// CHECK: vector.load
+
+// -----
+
+func.func @no_load_zero_dim_vector(%source: memref<64xf32>,
+ %offset: index) -> vector<f32> {
+ %0 = vector.load %source[%offset]
+ : memref<64xf32>, vector<f32>
+ return %0 : vector<f32>
+}
+
+// CHECK-LABEL: @no_load_zero_dim_vector(
+// CHECK: vector.load
diff --git a/mlir/test/Conversion/VectorToXeGPU/store-to-xegpu.mlir b/mlir/test/Conversion/VectorToXeGPU/store-to-xegpu.mlir
new file mode 100644
index 00000000000000..3d45407c2486d6
--- /dev/null
+++ b/mlir/test/Conversion/VectorToXeGPU/store-to-xegpu.mlir
@@ -0,0 +1,106 @@
+// RUN: mlir-opt %s -convert-vector-to-xegpu -split-input-file | FileCheck %s
+
+func.func @store_1D_vector(%vec: vector<8xf32>,
+ %source: memref<8x16x32xf32>, %offset: index) {
+ vector.store %vec, %source[%offset, %offset, %offset]
+ : memref<8x16x32xf32>, vector<8xf32>
+ return
+}
+
+// CHECK-LABEL: @store_1D_vector(
+// CHECK-SAME: %[[VEC:.+]]: vector<8xf32>,
+// CHECK-SAME: %[[SRC:.+]]: memref<8x16x32xf32>,
+// CHECK-SAME: %[[OFFSET:.+]]: index
+// CHECK: %[[DESC:.+]] = xegpu.create_nd_tdesc
+// CHECK-SAME: %[[SRC]][%[[OFFSET]], %[[OFFSET]], %[[OFFSET]]]
+// CHECK-SAME: memref<8x16x32xf32> -> !xegpu.tensor_desc<8xf32,
+// CHECK-SAME: boundary_check = true
+// CHECK: xegpu.store_nd %[[VEC]], %[[DESC]] : vector<8xf32>
+
+// -----
+
+func.func @store_2D_vector(%vec: vector<8x16xf32>,
+ %source: memref<8x16x32xf32>, %offset: index) {
+ vector.store %vec, %source[%offset, %offset, %offset]
+ : memref<8x16x32xf32>, vector<8x16xf32>
+ return
+}
+
+// CHECK-LABEL: @store_2D_vector(
+// CHECK-SAME: %[[VEC:.+]]: vector<8x16xf32>,
+// CHECK-SAME: %[[SRC:.+]]: memref<8x16x32xf32>,
+// CHECK-SAME: %[[OFFSET:.+]]: index
+// CHECK: %[[DESC:.+]] = xegpu.create_nd_tdesc
+// CHECK-SAME: %[[SRC]][%[[OFFSET]], %[[OFFSET]], %[[OFFSET]]]
+// CHECK-SAME: memref<8x16x32xf32> -> !xegpu.tensor_desc<8x16xf32,
+// CHECK-SAME: boundary_check = true
+// CHECK: xegpu.store_nd %[[VEC]], %[[DESC]] : vector<8x16xf32>
+
+// -----
+
+func.func @store_dynamic_source(%vec: vector<8x16xf32>,
+ %source: memref<?x?x?xf32>, %offset: index) {
+ vector.store %vec, %source[%offset, %offset, %offset]
+ : memref<?x?x?xf32>, vector<8x16xf32>
+ return
+}
+
+// CHECK-LABEL: @store_dynamic_source(
+// CHECK-SAME: %[[VEC:.+]]: vector<8x16xf32>,
+// CHECK-SAME: %[[SRC:.+]]: memref<?x?x?xf32>,
+// CHECK-SAME: %[[OFFSET:.+]]: index
+// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
+// CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index
+// CHECK-DAG: %[[C2:.+]] = arith.constant 2 : index
+// CHECK-DAG: %[[DIM_0:.+]] = memref.dim %[[SRC]], %[[C0]]
+// CHECK-DAG: %[[DIM_1:.+]] = memref.dim %[[SRC]], %[[C1]]
+// CHECK-DAG: %[[DIM_2:.+]] = memref.dim %[[SRC]], %[[C2]]
+// CHECK: %[[DIM_0_STRIDE:.+]] = arith.muli %[[DIM_2]], %[[DIM_1]]
+// CHECK: %[[DESC:.+]] = xegpu.create_nd_tdesc %[[SRC]][%[[OFFSET]], %[[OFFSET]], %[[OFFSET]]]
+// CHECK-SAME: [%[[DIM_0]], %[[DIM_1]], %[[DIM_2]]], [%[[DIM_0_STRIDE]], %[[DIM_2]], 1]
+// CHECK-SAME: memref<?x?x?xf32> -> !xegpu.tensor_desc<8x16xf32,
+// CHECK-SAME: boundary_check = true
+// CHECK: xegpu.store_nd %[[VEC]], %[[DESC]] : vector<8x16xf32>
+
+// -----
+
+func.func @store_out_of_bounds(%vec: vector<8x16xf32>,
+ %source: memref<7x64xf32>, %offset: index) {
+ vector.store %vec, %source[%offset, %offset]
+ : memref<7x64xf32>, vector<8x16xf32>
+ return
+}
+
+// CHECK-LABEL: @store_out_of_bounds(
+// CHECK-SAME: %[[VEC:.+]]: vector<8x16xf32>,
+// CHECK-SAME: %[[SRC:.+]]: memref<7x64xf32>,
+// CHECK-SAME: %[[OFFSET:.+]]: index
+// CHECK: %[[DESC:.+]] = xegpu.create_nd_tdesc
+// CHECK-SAME: %[[SRC]][%[[OFFSET]], %[[OFFSET]]]
+// CHECK-SAME: memref<7x64xf32> -> !xegpu.tensor_desc<8x16xf32,
+// CHECK-SAME: boundary_check = true
+// CHECK: xegpu.store_nd %[[VEC]], %[[DESC]] : vector<8x16xf32>
+
+// -----
+
+func.func @no_store_high_dim_vector(%vec: vector<8x16x32xf32>,
+ %source: memref<16x32x64xf32>, %offset: index) {
+ vector.store %vec, %source[%offset, %offset, %offset]
+ : memref<16x32x64xf32>, vector<8x16x32xf32>
+ return
+}
+
+// CHECK-LABEL: @no_store_high_dim_vector(
+// CHECK: vector.store
+
+// -----
+
+func.func @no_store_zero_dim_vector(%vec: vector<f32>,
+ %source: memref<64xf32>, %offset: index) {
+ vector.store %vec, %source[%offset]
+ : memref<64xf32>, vector<f32>
+ return
+}
+
+// CHECK-LABEL: @no_store_zero_dim_vector(
+// CHECK: vector.store