[Mlir-commits] [mlir] [mlir][vector] `LoadOp`/`StoreOp`: Allow 0-D vectors (PR #76134)

Matthias Springer llvmlistbot at llvm.org
Thu Dec 21 00:14:34 PST 2023


https://github.com/matthias-springer created https://github.com/llvm/llvm-project/pull/76134

Allow 0-D vectors in `vector.load`/`vector.store`, similar to `vector.transfer_read`/`vector.transfer_write`.

This commit fixes `mlir/test/Dialect/Vector/vector-transfer-to-vector-load-store.mlir`, which fails when the IR is verified after each pattern application (#74270) because the test produces a temporary 0-D load/store op.


>From c7fe2521d762984a3972eabb2b2dfc16d0ff7e06 Mon Sep 17 00:00:00 2001
From: Matthias Springer <springerm at google.com>
Date: Thu, 21 Dec 2023 17:09:24 +0900
Subject: [PATCH] [mlir][vector] `LoadOp`/`StoreOp`: Allow 0-D vectors

Similar to `vector.transfer_read`/`vector.transfer_write`, allow 0-D vectors.

This commit fixes `mlir/test/Dialect/Vector/vector-transfer-to-vector-load-store.mlir` when verifying the IR after each pattern (#74270). That test produces a temporary 0-D load/store op.
---
 .../mlir/Dialect/Vector/IR/VectorOps.td       | 42 ++++++++++++-------
 .../VectorToLLVM/vector-to-llvm.mlir          | 30 +++++++++++++
 mlir/test/Dialect/Vector/ops.mlir             | 10 +++++
 3 files changed, 67 insertions(+), 15 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
index 423118f79e733d..40d874dc99dd90 100644
--- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
@@ -1582,22 +1582,27 @@ def Vector_LoadOp : Vector_Op<"load"> {
     vector. If the memref element type is vector, it should match the result
     vector type.
 
-    Example 1: 1-D vector load on a scalar memref.
+    Example: 0-D vector load on a scalar memref.
+    ```mlir
+    %result = vector.load %base[%i, %j] : memref<100x100xf32>, vector<f32>
+    ```
+
+    Example: 1-D vector load on a scalar memref.
     ```mlir
     %result = vector.load %base[%i, %j] : memref<100x100xf32>, vector<8xf32>
     ```
 
-    Example 2: 1-D vector load on a vector memref.
+    Example: 1-D vector load on a vector memref.
     ```mlir
     %result = vector.load %memref[%i, %j] : memref<200x100xvector<8xf32>>, vector<8xf32>
     ```
 
-    Example 3:  2-D vector load on a scalar memref.
+    Example:  2-D vector load on a scalar memref.
     ```mlir
     %result = vector.load %memref[%i, %j] : memref<200x100xf32>, vector<4x8xf32>
     ```
 
-    Example 4:  2-D vector load on a vector memref.
+    Example:  2-D vector load on a vector memref.
     ```mlir
     %result = vector.load %memref[%i, %j] : memref<200x100xvector<4x8xf32>>, vector<4x8xf32>
     ```
@@ -1608,12 +1613,12 @@ def Vector_LoadOp : Vector_Op<"load"> {
     loaded out of bounds. Not all targets may support out-of-bounds vector
     loads.
 
-    Example 5:  Potential out-of-bound vector load.
+    Example:  Potential out-of-bound vector load.
     ```mlir
     %result = vector.load %memref[%index] : memref<?xf32>, vector<8xf32>
     ```
 
-    Example 6:  Explicit out-of-bound vector load.
+    Example:  Explicit out-of-bound vector load.
     ```mlir
     %result = vector.load %memref[%c0] : memref<7xf32>, vector<8xf32>
     ```
@@ -1622,7 +1627,7 @@ def Vector_LoadOp : Vector_Op<"load"> {
   let arguments = (ins Arg<AnyMemRef, "the reference to load from",
       [MemRead]>:$base,
       Variadic<Index>:$indices);
-  let results = (outs AnyVector:$result);
+  let results = (outs AnyVectorOfAnyRank:$result);
 
   let extraClassDeclaration = [{
     MemRefType getMemRefType() {
@@ -1660,22 +1665,27 @@ def Vector_StoreOp : Vector_Op<"store"> {
     to store. If the memref element type is vector, it should match the type
     of the value to store.
 
-    Example 1: 1-D vector store on a scalar memref.
+    Example: 0-D vector store on a scalar memref.
+    ```mlir
+    vector.store %valueToStore, %memref[%i, %j] : memref<200x100xf32>, vector<f32>
+    ```
+
+    Example: 1-D vector store on a scalar memref.
     ```mlir
     vector.store %valueToStore, %memref[%i, %j] : memref<200x100xf32>, vector<8xf32>
     ```
 
-    Example 2: 1-D vector store on a vector memref.
+    Example: 1-D vector store on a vector memref.
     ```mlir
     vector.store %valueToStore, %memref[%i, %j] : memref<200x100xvector<8xf32>>, vector<8xf32>
     ```
 
-    Example 3:  2-D vector store on a scalar memref.
+    Example:  2-D vector store on a scalar memref.
     ```mlir
     vector.store %valueToStore, %memref[%i, %j] : memref<200x100xf32>, vector<4x8xf32>
     ```
 
-    Example 4:  2-D vector store on a vector memref.
+    Example:  2-D vector store on a vector memref.
     ```mlir
     vector.store %valueToStore, %memref[%i, %j] : memref<200x100xvector<4x8xf32>>, vector<4x8xf32>
     ```
@@ -1685,21 +1695,23 @@ def Vector_StoreOp : Vector_Op<"store"> {
     target-specific. No assumptions should be made on the memory written out of
     bounds. Not all targets may support out-of-bounds vector stores.
 
-    Example 5:  Potential out-of-bounds vector store.
+    Example:  Potential out-of-bounds vector store.
     ```mlir
     vector.store %valueToStore, %memref[%index] : memref<?xf32>, vector<8xf32>
     ```
 
-    Example 6:  Explicit out-of-bounds vector store.
+    Example:  Explicit out-of-bounds vector store.
     ```mlir
     vector.store %valueToStore, %memref[%c0] : memref<7xf32>, vector<8xf32>
     ```
   }];
 
-  let arguments = (ins AnyVector:$valueToStore,
+  let arguments = (ins
+      AnyVectorOfAnyRank:$valueToStore,
       Arg<AnyMemRef, "the reference to store to",
       [MemWrite]>:$base,
-      Variadic<Index>:$indices);
+      Variadic<Index>:$indices
+  );
 
   let extraClassDeclaration = [{
     MemRefType getMemRefType() {
diff --git a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
index d80392ebd87b03..7ea0197bdecb36 100644
--- a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
+++ b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
@@ -2059,6 +2059,36 @@ func.func @vector_store_op_index(%memref : memref<200x100xindex>, %i : index, %j
 
 // -----
 
+func.func @vector_load_op_0d(%memref : memref<200x100xf32>, %i : index, %j : index) -> vector<f32> {
+  %0 = vector.load %memref[%i, %j] : memref<200x100xf32>, vector<f32>
+  return %0 : vector<f32>
+}
+
+// CHECK-LABEL: func @vector_load_op_0d
+// CHECK: %[[load:.*]] = memref.load %{{.*}}[%{{.*}}, %{{.*}}]
+// CHECK: %[[vec:.*]] = llvm.mlir.undef : vector<1xf32>
+// CHECK: %[[c0:.*]] = llvm.mlir.constant(0 : i32) : i32
+// CHECK: %[[inserted:.*]] = llvm.insertelement %[[load]], %[[vec]][%[[c0]] : i32] : vector<1xf32>
+// CHECK: %[[cast:.*]] = builtin.unrealized_conversion_cast %[[inserted]] : vector<1xf32> to vector<f32>
+// CHECK: return %[[cast]] : vector<f32>
+
+// -----
+
+func.func @vector_store_op_0d(%memref : memref<200x100xf32>, %i : index, %j : index) {
+  %val = arith.constant dense<11.0> : vector<f32>
+  vector.store %val, %memref[%i, %j] : memref<200x100xf32>, vector<f32>
+  return
+}
+
+// CHECK-LABEL: func @vector_store_op_0d
+// CHECK: %[[val:.*]] = arith.constant dense<1.100000e+01> : vector<f32>
+// CHECK: %[[cast:.*]] = builtin.unrealized_conversion_cast %[[val]] : vector<f32> to vector<1xf32>
+// CHECK: %[[c0:.*]] = llvm.mlir.constant(0 : index) : i64
+// CHECK: %[[extracted:.*]] = llvm.extractelement %[[cast]][%[[c0]] : i64] : vector<1xf32>
+// CHECK: memref.store %[[extracted]], %{{.*}}[%{{.*}}, %{{.*}}]
+
+// -----
+
 func.func @masked_load_op(%arg0: memref<?xf32>, %arg1: vector<16xi1>, %arg2: vector<16xf32>) -> vector<16xf32> {
   %c0 = arith.constant 0: index
   %0 = vector.maskedload %arg0[%c0], %arg1, %arg2 : memref<?xf32>, vector<16xi1>, vector<16xf32> into vector<16xf32>
diff --git a/mlir/test/Dialect/Vector/ops.mlir b/mlir/test/Dialect/Vector/ops.mlir
index c1ef8f2c30c05c..49f0af5d81e45e 100644
--- a/mlir/test/Dialect/Vector/ops.mlir
+++ b/mlir/test/Dialect/Vector/ops.mlir
@@ -714,6 +714,16 @@ func.func @flat_transpose_int(%arg0: vector<16xi32>) -> vector<16xi32> {
   return %0 : vector<16xi32>
 }
 
+// CHECK-LABEL: @vector_load_and_store_0d_scalar_memref
+func.func @vector_load_and_store_0d_scalar_memref(%memref : memref<200x100xf32>,
+                                                  %i : index, %j : index) {
+  // CHECK: %[[ld:.*]] = vector.load %{{.*}}[%{{.*}}] : memref<200x100xf32>, vector<f32>
+  %0 = vector.load %memref[%i, %j] : memref<200x100xf32>, vector<f32>
+  // CHECK: vector.store %[[ld]], %{{.*}}[%{{.*}}] : memref<200x100xf32>, vector<f32>
+  vector.store %0, %memref[%i, %j] : memref<200x100xf32>, vector<f32>
+  return
+}
+
 // CHECK-LABEL: @vector_load_and_store_1d_scalar_memref
 func.func @vector_load_and_store_1d_scalar_memref(%memref : memref<200x100xf32>,
                                              %i : index, %j : index) {



More information about the Mlir-commits mailing list