[Mlir-commits] [mlir] [mlir][xegpu] Convert Vector load and store to XeGPU (PR #110826)

Adam Siemieniuk llvmlistbot at llvm.org
Wed Oct 2 04:11:23 PDT 2024


https://github.com/adam-smnk created https://github.com/llvm/llvm-project/pull/110826

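Adds patterns that lower 1-D and 2-D `vector.load` and `vector.store` ops to XeGPU n-d descriptor operations (`xegpu.create_nd_tdesc` with `xegpu.load_nd` / `xegpu.store_nd`).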

>From 81e28a345457db829f9b5e2482a7e642b3bd1ea1 Mon Sep 17 00:00:00 2001
From: Adam Siemieniuk <adam.siemieniuk at intel.com>
Date: Wed, 2 Oct 2024 12:56:31 +0200
Subject: [PATCH] [mlir][xegpu] Convert Vector load and store to XeGPU

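Adds patterns that lower 1-D and 2-D `vector.load` and `vector.store` ops to XeGPU n-d descriptor operations (`xegpu.create_nd_tdesc` with `xegpu.load_nd` / `xegpu.store_nd`).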
---
 .../VectorToXeGPU/VectorToXeGPU.cpp           |  85 +++++++++++++-
 .../VectorToXeGPU/load-to-xegpu.mlir          | 105 +++++++++++++++++
 .../VectorToXeGPU/store-to-xegpu.mlir         | 106 ++++++++++++++++++
 3 files changed, 291 insertions(+), 5 deletions(-)
 create mode 100644 mlir/test/Conversion/VectorToXeGPU/load-to-xegpu.mlir
 create mode 100644 mlir/test/Conversion/VectorToXeGPU/store-to-xegpu.mlir

diff --git a/mlir/lib/Conversion/VectorToXeGPU/VectorToXeGPU.cpp b/mlir/lib/Conversion/VectorToXeGPU/VectorToXeGPU.cpp
index 0e21e96cc3fbb9..e9acda657b3c27 100644
--- a/mlir/lib/Conversion/VectorToXeGPU/VectorToXeGPU.cpp
+++ b/mlir/lib/Conversion/VectorToXeGPU/VectorToXeGPU.cpp
@@ -33,6 +33,7 @@ using namespace mlir;
 
 namespace {
 
+// Return true if value represents a zero constant.
 static bool isZeroConstant(Value val) {
   auto constant = val.getDefiningOp<arith::ConstantOp>();
   if (!constant)
@@ -46,6 +47,17 @@ static bool isZeroConstant(Value val) {
       .Default([](auto) { return false; });
 }
 
+// Check the vector-type constraints shared by the vector.load/store and the
+// vector.transfer_read/write lowerings. Only the vector type is inspected
+// here; transfer ops validate their memref source separately.
+//
+// Returns failure (after reporting a match failure on `op`) when `vecTy`
+// cannot be represented by an XeGPU n-d tensor descriptor.
+static LogicalResult storeLoadPreconditions(PatternRewriter &rewriter,
+                                            Operation *op, VectorType vecTy) {
+  // Tensor descriptors are built from `vecTy.getShape()`, which does not
+  // carry scalability, so a scalable vector would be silently lowered as a
+  // fixed-size one. Reject it instead.
+  if (vecTy.isScalable())
+    return rewriter.notifyMatchFailure(op, "Expects fixed-size vector");
+
+  // Validate only vector as the basic vector store and load ops guarantee
+  // XeGPU-compatible memref source.
+  unsigned vecRank = vecTy.getRank();
+  if (!(vecRank == 1 || vecRank == 2))
+    return rewriter.notifyMatchFailure(op, "Expects 1D or 2D vector");
+
+  return success();
+}
+
 static LogicalResult transferPreconditions(PatternRewriter &rewriter,
                                            VectorTransferOpInterface xferOp) {
   if (xferOp.getMask())
@@ -55,11 +67,13 @@ static LogicalResult transferPreconditions(PatternRewriter &rewriter,
   auto srcTy = dyn_cast<MemRefType>(xferOp.getShapedType());
   if (!srcTy)
     return rewriter.notifyMatchFailure(xferOp, "Expects memref source");
+
+  // Perform common data transfer checks.
   VectorType vecTy = xferOp.getVectorType();
-  unsigned vecRank = vecTy.getRank();
-  if (!(vecRank == 1 || vecRank == 2))
-    return rewriter.notifyMatchFailure(xferOp, "Expects 1D or 2D vector");
+  if (failed(storeLoadPreconditions(rewriter, xferOp, vecTy)))
+    return failure();
 
+  // Validate further transfer op semantics.
   SmallVector<int64_t> strides;
   int64_t offset;
   if (failed(getStridesAndOffset(srcTy, strides, offset)) ||
@@ -67,6 +81,7 @@ static LogicalResult transferPreconditions(PatternRewriter &rewriter,
     return rewriter.notifyMatchFailure(
         xferOp, "Buffer must be contiguous in the innermost dimension");
 
+  unsigned vecRank = vecTy.getRank();
   AffineMap map = xferOp.getPermutationMap();
   if (!map.isProjectedPermutation(/*allowZeroInResults=*/false))
     return rewriter.notifyMatchFailure(xferOp, "Unsupported permutation map");
@@ -232,6 +247,66 @@ struct TransferWriteLowering
   }
 };
 
+// Lowers a 1-D or 2-D `vector.load` into an XeGPU n-d tensor descriptor on
+// the load's base memref and indices, followed by an `xegpu.load_nd` that
+// produces the loaded vector.
+struct LoadLowering : public OpRewritePattern<vector::LoadOp> {
+  using OpRewritePattern<vector::LoadOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(vector::LoadOp loadOp,
+                                PatternRewriter &rewriter) const override {
+    VectorType resultTy = loadOp.getResult().getType();
+    if (failed(storeLoadPreconditions(rewriter, loadOp, resultTy)))
+      return failure();
+
+    Location loc = loadOp.getLoc();
+    // The descriptor always enables boundary checking; no static in-bounds
+    // analysis is attempted for plain loads.
+    auto tdescTy = xegpu::TensorDescType::get(
+        resultTy.getShape(), resultTy.getElementType(), /*array_length=*/1,
+        /*boundary_check=*/true, xegpu::MemorySpace::Global);
+    xegpu::CreateNdDescOp desc = createNdDescriptor(
+        rewriter, loc, tdescTy, loadOp.getBase(), loadOp.getIndices());
+
+    // No caching policy is requested at any cache level.
+    xegpu::CachePolicyAttr noHint = nullptr;
+    rewriter.replaceOpWithNewOp<xegpu::LoadNdOp>(
+        loadOp, resultTy, desc, /*packed=*/nullptr, /*transpose=*/nullptr,
+        /*l1_hint=*/noHint, /*l2_hint=*/noHint, /*l3_hint=*/noHint);
+    return success();
+  }
+};
+
+// Lowers a 1-D or 2-D `vector.store` into an XeGPU n-d tensor descriptor on
+// the store's base memref and indices, followed by an `xegpu.store_nd` that
+// writes the stored vector through it.
+struct StoreLowering : public OpRewritePattern<vector::StoreOp> {
+  using OpRewritePattern<vector::StoreOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(vector::StoreOp storeOp,
+                                PatternRewriter &rewriter) const override {
+    TypedValue<VectorType> valueToStore = storeOp.getValueToStore();
+    VectorType valueTy = valueToStore.getType();
+    if (failed(storeLoadPreconditions(rewriter, storeOp, valueTy)))
+      return failure();
+
+    Location loc = storeOp.getLoc();
+    // The descriptor always enables boundary checking; no static in-bounds
+    // analysis is attempted for plain stores.
+    auto tdescTy = xegpu::TensorDescType::get(
+        valueTy.getShape(), valueTy.getElementType(), /*array_length=*/1,
+        /*boundary_check=*/true, xegpu::MemorySpace::Global);
+    xegpu::CreateNdDescOp desc = createNdDescriptor(
+        rewriter, loc, tdescTy, storeOp.getBase(), storeOp.getIndices());
+
+    // No caching policy is requested at any cache level.
+    xegpu::CachePolicyAttr noHint = nullptr;
+    rewriter.replaceOpWithNewOp<xegpu::StoreNdOp>(
+        storeOp, valueToStore, desc, /*l1_hint=*/noHint, /*l2_hint=*/noHint,
+        /*l3_hint=*/noHint);
+    return success();
+  }
+};
+
 struct ConvertVectorToXeGPUPass
     : public impl::ConvertVectorToXeGPUBase<ConvertVectorToXeGPUPass> {
   void runOnOperation() override {
@@ -247,8 +322,8 @@ struct ConvertVectorToXeGPUPass
 
 void mlir::populateVectorToXeGPUConversionPatterns(
     RewritePatternSet &patterns) {
-  patterns.add<TransferReadLowering, TransferWriteLowering>(
-      patterns.getContext());
+  // Register both the vector.transfer_read/write lowerings and the plain
+  // vector.load/store lowerings.
+  patterns.add<TransferReadLowering, TransferWriteLowering, LoadLowering,
+               StoreLowering>(patterns.getContext());
 }
 
 std::unique_ptr<Pass> mlir::createConvertVectorToXeGPUPass() {
diff --git a/mlir/test/Conversion/VectorToXeGPU/load-to-xegpu.mlir b/mlir/test/Conversion/VectorToXeGPU/load-to-xegpu.mlir
new file mode 100644
index 00000000000000..e2a506f8ad5abd
--- /dev/null
+++ b/mlir/test/Conversion/VectorToXeGPU/load-to-xegpu.mlir
@@ -0,0 +1,105 @@
+// RUN: mlir-opt %s -convert-vector-to-xegpu -split-input-file | FileCheck %s
+
+// A 1-D vector.load from a 3-D memref becomes a 1-D tensor descriptor with
+// boundary checks, followed by xegpu.load_nd.
+func.func @load_1D_vector(%source: memref<8x16x32xf32>, %offset: index) -> vector<8xf32> {
+  %0 = vector.load %source[%offset, %offset, %offset]
+    : memref<8x16x32xf32>, vector<8xf32>
+  return %0 : vector<8xf32>
+}
+
+// CHECK-LABEL: @load_1D_vector(
+// CHECK-SAME:  %[[SRC:.+]]: memref<8x16x32xf32>,
+// CHECK-SAME:  %[[OFFSET:.+]]: index
+// CHECK:       %[[DESC:.+]] = xegpu.create_nd_tdesc
+// CHECK-SAME:    %[[SRC]][%[[OFFSET]], %[[OFFSET]], %[[OFFSET]]]
+// CHECK-SAME:    memref<8x16x32xf32> -> !xegpu.tensor_desc<8xf32,
+// CHECK-SAME:    boundary_check = true
+// CHECK:       %[[VEC:.+]] = xegpu.load_nd %[[DESC]]{{.*}}-> vector<8xf32>
+// CHECK:       return %[[VEC]]
+
+// -----
+
+// A 2-D vector.load from a 3-D memref becomes a 2-D tensor descriptor with
+// boundary checks, followed by xegpu.load_nd.
+func.func @load_2D_vector(%source: memref<8x16x32xf32>,
+    %offset: index) -> vector<8x16xf32> {
+  %0 = vector.load %source[%offset, %offset, %offset]
+    : memref<8x16x32xf32>, vector<8x16xf32>
+  return %0 : vector<8x16xf32>
+}
+
+// CHECK-LABEL: @load_2D_vector(
+// CHECK-SAME:  %[[SRC:.+]]: memref<8x16x32xf32>,
+// CHECK-SAME:  %[[OFFSET:.+]]: index
+// CHECK:       %[[DESC:.+]] = xegpu.create_nd_tdesc
+// CHECK-SAME:    %[[SRC]][%[[OFFSET]], %[[OFFSET]], %[[OFFSET]]]
+// CHECK-SAME:    memref<8x16x32xf32> -> !xegpu.tensor_desc<8x16xf32,
+// CHECK-SAME:    boundary_check = true
+// CHECK:       %[[VEC:.+]] = xegpu.load_nd %[[DESC]]{{.*}}-> vector<8x16xf32>
+// CHECK:       return %[[VEC]]
+
+// -----
+
+// With a dynamically shaped memref, the descriptor's sizes come from
+// memref.dim and the outermost stride is computed with arith.muli.
+func.func @load_dynamic_source(%source: memref<?x?x?xf32>,
+    %offset: index) -> vector<8x16xf32> {
+  %0 = vector.load %source[%offset, %offset, %offset]
+    : memref<?x?x?xf32>, vector<8x16xf32>
+  return %0 : vector<8x16xf32>
+}
+
+// CHECK-LABEL: @load_dynamic_source(
+// CHECK-SAME:  %[[SRC:.+]]: memref<?x?x?xf32>,
+// CHECK-SAME:  %[[OFFSET:.+]]: index
+// CHECK-DAG:   %[[C0:.+]] = arith.constant 0 : index
+// CHECK-DAG:   %[[C1:.+]] = arith.constant 1 : index
+// CHECK-DAG:   %[[C2:.+]] = arith.constant 2 : index
+// CHECK-DAG:   %[[DIM_0:.+]] = memref.dim %[[SRC]], %[[C0]]
+// CHECK-DAG:   %[[DIM_1:.+]] = memref.dim %[[SRC]], %[[C1]]
+// CHECK-DAG:   %[[DIM_2:.+]] = memref.dim %[[SRC]], %[[C2]]
+// CHECK:       %[[DIM_0_STRIDE:.+]] = arith.muli %[[DIM_2]], %[[DIM_1]]
+// CHECK:       %[[DESC:.+]] = xegpu.create_nd_tdesc %[[SRC]][%[[OFFSET]], %[[OFFSET]], %[[OFFSET]]]
+// CHECK-SAME:    [%[[DIM_0]], %[[DIM_1]], %[[DIM_2]]], [%[[DIM_0_STRIDE]], %[[DIM_2]], 1]
+// CHECK-SAME:    memref<?x?x?xf32> -> !xegpu.tensor_desc<8x16xf32,
+// CHECK-SAME:    boundary_check = true
+// CHECK:       %[[VEC:.+]] = xegpu.load_nd %[[DESC]]{{.*}}-> vector<8x16xf32>
+// CHECK:       return %[[VEC]]
+
+// -----
+
+// The 8x16 vector is larger than the 7x15 memref; the load is still lowered,
+// with boundary_check = true on the descriptor.
+func.func @load_out_of_bounds(%source: memref<7x15xf32>,
+    %offset: index) -> vector<8x16xf32> {
+  %0 = vector.load %source[%offset, %offset]
+    : memref<7x15xf32>, vector<8x16xf32>
+  return %0 : vector<8x16xf32>
+}
+
+// CHECK-LABEL: @load_out_of_bounds(
+// CHECK-SAME:  %[[SRC:.+]]: memref<7x15xf32>,
+// CHECK-SAME:  %[[OFFSET:.+]]: index
+// CHECK:       %[[DESC:.+]] = xegpu.create_nd_tdesc
+// CHECK-SAME:    %[[SRC]][%[[OFFSET]], %[[OFFSET]]]
+// CHECK-SAME:    memref<7x15xf32> -> !xegpu.tensor_desc<8x16xf32,
+// CHECK-SAME:    boundary_check = true
+// CHECK:       %[[VEC:.+]] = xegpu.load_nd %[[DESC]]{{.*}}-> vector<8x16xf32>
+// CHECK:       return %[[VEC]]
+
+// -----
+
+// 3-D vectors are outside the supported 1-D/2-D range, so the vector.load
+// is left unconverted.
+func.func @no_load_high_dim_vector(%source: memref<16x32x64xf32>,
+    %offset: index) -> vector<8x16x32xf32> {
+  %0 = vector.load %source[%offset, %offset, %offset]
+    : memref<16x32x64xf32>, vector<8x16x32xf32>
+  return %0 : vector<8x16x32xf32>
+}
+
+// CHECK-LABEL: @no_load_high_dim_vector(
+// CHECK:       vector.load
+
+// -----
+
+// 0-D vectors are outside the supported 1-D/2-D range, so the vector.load
+// is left unconverted.
+func.func @no_load_zero_dim_vector(%source: memref<64xf32>,
+    %offset: index) -> vector<f32> {
+  %0 = vector.load %source[%offset]
+    : memref<64xf32>, vector<f32>
+  return %0 : vector<f32>
+}
+
+// CHECK-LABEL: @no_load_zero_dim_vector(
+// CHECK:       vector.load
diff --git a/mlir/test/Conversion/VectorToXeGPU/store-to-xegpu.mlir b/mlir/test/Conversion/VectorToXeGPU/store-to-xegpu.mlir
new file mode 100644
index 00000000000000..3d45407c2486d6
--- /dev/null
+++ b/mlir/test/Conversion/VectorToXeGPU/store-to-xegpu.mlir
@@ -0,0 +1,106 @@
+// RUN: mlir-opt %s -convert-vector-to-xegpu -split-input-file | FileCheck %s
+
+// A 1-D vector.store into a 3-D memref becomes a 1-D tensor descriptor with
+// boundary checks, followed by xegpu.store_nd of the vector operand.
+func.func @store_1D_vector(%vec: vector<8xf32>,
+    %source: memref<8x16x32xf32>, %offset: index) {
+  vector.store %vec, %source[%offset, %offset, %offset]
+    : memref<8x16x32xf32>, vector<8xf32>
+  return
+}
+
+// CHECK-LABEL: @store_1D_vector(
+// CHECK-SAME:  %[[VEC:.+]]: vector<8xf32>,
+// CHECK-SAME:  %[[SRC:.+]]: memref<8x16x32xf32>,
+// CHECK-SAME:  %[[OFFSET:.+]]: index
+// CHECK:       %[[DESC:.+]] = xegpu.create_nd_tdesc
+// CHECK-SAME:    %[[SRC]][%[[OFFSET]], %[[OFFSET]], %[[OFFSET]]]
+// CHECK-SAME:    memref<8x16x32xf32> -> !xegpu.tensor_desc<8xf32,
+// CHECK-SAME:    boundary_check = true
+// CHECK:       xegpu.store_nd %[[VEC]], %[[DESC]] : vector<8xf32>
+
+// -----
+
+// A 2-D vector.store into a 3-D memref becomes a 2-D tensor descriptor with
+// boundary checks, followed by xegpu.store_nd of the vector operand.
+func.func @store_2D_vector(%vec: vector<8x16xf32>,
+    %source: memref<8x16x32xf32>, %offset: index) {
+  vector.store %vec, %source[%offset, %offset, %offset]
+    : memref<8x16x32xf32>, vector<8x16xf32>
+  return
+}
+
+// CHECK-LABEL: @store_2D_vector(
+// CHECK-SAME:  %[[VEC:.+]]: vector<8x16xf32>,
+// CHECK-SAME:  %[[SRC:.+]]: memref<8x16x32xf32>,
+// CHECK-SAME:  %[[OFFSET:.+]]: index
+// CHECK:       %[[DESC:.+]] = xegpu.create_nd_tdesc
+// CHECK-SAME:    %[[SRC]][%[[OFFSET]], %[[OFFSET]], %[[OFFSET]]]
+// CHECK-SAME:    memref<8x16x32xf32> -> !xegpu.tensor_desc<8x16xf32,
+// CHECK-SAME:    boundary_check = true
+// CHECK:       xegpu.store_nd %[[VEC]], %[[DESC]] : vector<8x16xf32>
+
+// -----
+
+// With a dynamically shaped destination memref, the descriptor's sizes come
+// from memref.dim and the outermost stride is computed with arith.muli.
+func.func @store_dynamic_source(%vec: vector<8x16xf32>,
+    %source: memref<?x?x?xf32>, %offset: index) {
+  vector.store %vec, %source[%offset, %offset, %offset]
+    : memref<?x?x?xf32>, vector<8x16xf32>
+  return
+}
+
+// CHECK-LABEL: @store_dynamic_source(
+// CHECK-SAME:  %[[VEC:.+]]: vector<8x16xf32>,
+// CHECK-SAME:  %[[SRC:.+]]: memref<?x?x?xf32>,
+// CHECK-SAME:  %[[OFFSET:.+]]: index
+// CHECK-DAG:   %[[C0:.+]] = arith.constant 0 : index
+// CHECK-DAG:   %[[C1:.+]] = arith.constant 1 : index
+// CHECK-DAG:   %[[C2:.+]] = arith.constant 2 : index
+// CHECK-DAG:   %[[DIM_0:.+]] = memref.dim %[[SRC]], %[[C0]]
+// CHECK-DAG:   %[[DIM_1:.+]] = memref.dim %[[SRC]], %[[C1]]
+// CHECK-DAG:   %[[DIM_2:.+]] = memref.dim %[[SRC]], %[[C2]]
+// CHECK:       %[[DIM_0_STRIDE:.+]] = arith.muli %[[DIM_2]], %[[DIM_1]]
+// CHECK:       %[[DESC:.+]] = xegpu.create_nd_tdesc %[[SRC]][%[[OFFSET]], %[[OFFSET]], %[[OFFSET]]]
+// CHECK-SAME:    [%[[DIM_0]], %[[DIM_1]], %[[DIM_2]]], [%[[DIM_0_STRIDE]], %[[DIM_2]], 1]
+// CHECK-SAME:    memref<?x?x?xf32> -> !xegpu.tensor_desc<8x16xf32,
+// CHECK-SAME:    boundary_check = true
+// CHECK:       xegpu.store_nd %[[VEC]], %[[DESC]] : vector<8x16xf32>
+
+// -----
+
+// The 8x16 vector has more rows than the 7x64 memref; the store is still
+// lowered, with boundary_check = true on the descriptor.
+func.func @store_out_of_bounds(%vec: vector<8x16xf32>,
+    %source: memref<7x64xf32>, %offset: index) {
+  vector.store %vec, %source[%offset, %offset]
+    : memref<7x64xf32>, vector<8x16xf32>
+  return
+}
+
+// CHECK-LABEL:   @store_out_of_bounds(
+// CHECK-SAME:  %[[VEC:.+]]: vector<8x16xf32>,
+// CHECK-SAME:  %[[SRC:.+]]: memref<7x64xf32>,
+// CHECK-SAME:  %[[OFFSET:.+]]: index
+// CHECK:       %[[DESC:.+]] = xegpu.create_nd_tdesc
+// CHECK-SAME:    %[[SRC]][%[[OFFSET]], %[[OFFSET]]]
+// CHECK-SAME:    memref<7x64xf32> -> !xegpu.tensor_desc<8x16xf32,
+// CHECK-SAME:    boundary_check = true
+// CHECK:       xegpu.store_nd %[[VEC]], %[[DESC]] : vector<8x16xf32>
+
+// -----
+
+// 3-D vectors are outside the supported 1-D/2-D range, so the vector.store
+// is left unconverted.
+func.func @no_store_high_dim_vector(%vec: vector<8x16x32xf32>,
+    %source: memref<16x32x64xf32>, %offset: index) {
+  vector.store %vec, %source[%offset, %offset, %offset]
+    : memref<16x32x64xf32>, vector<8x16x32xf32>
+  return
+}
+
+// CHECK-LABEL: @no_store_high_dim_vector(
+// CHECK:       vector.store
+
+// -----
+
+// 0-D vectors are outside the supported 1-D/2-D range, so the vector.store
+// is left unconverted.
+func.func @no_store_zero_dim_vector(%vec: vector<f32>,
+    %source: memref<64xf32>, %offset: index) {
+  vector.store %vec, %source[%offset]
+    : memref<64xf32>, vector<f32>
+  return
+}
+
+// CHECK-LABEL: @no_store_zero_dim_vector(
+// CHECK:       vector.store



More information about the Mlir-commits mailing list