[Mlir-commits] [mlir] 5cfe24e - [mlir][Vector] Add nontemporal attribute, mirroring memref (#76752)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Tue Jan 9 09:05:25 PST 2024


Author: Krzysztof Drewniak
Date: 2024-01-09T11:05:20-06:00
New Revision: 5cfe24eee49dfb9f6f72e73142e075dbbadd3089

URL: https://github.com/llvm/llvm-project/commit/5cfe24eee49dfb9f6f72e73142e075dbbadd3089
DIFF: https://github.com/llvm/llvm-project/commit/5cfe24eee49dfb9f6f72e73142e075dbbadd3089.diff

LOG: [mlir][Vector] Add nontemporal attribute, mirroring memref (#76752)

Since vector loads and stores from scalar memrefs translate to
llvm.load/store, add the ability to tag said loads and stores as
nontemporal. This mirrors functionality available in memref.load/store.

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
    mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
    mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
index 40d874dc99dd90..8e333def338691 100644
--- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
@@ -1626,7 +1626,8 @@ def Vector_LoadOp : Vector_Op<"load"> {
 
   let arguments = (ins Arg<AnyMemRef, "the reference to load from",
       [MemRead]>:$base,
-      Variadic<Index>:$indices);
+      Variadic<Index>:$indices,
+      DefaultValuedOptionalAttr<BoolAttr, "false">:$nontemporal);
   let results = (outs AnyVectorOfAnyRank:$result);
 
   let extraClassDeclaration = [{
@@ -1710,7 +1711,8 @@ def Vector_StoreOp : Vector_Op<"store"> {
       AnyVectorOfAnyRank:$valueToStore,
       Arg<AnyMemRef, "the reference to store to",
       [MemWrite]>:$base,
-      Variadic<Index>:$indices
+      Variadic<Index>:$indices,
+      DefaultValuedOptionalAttr<BoolAttr, "false">:$nontemporal
   );
 
   let extraClassDeclaration = [{

diff  --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
index a24fb6f8391538..b66b55ae8d57ff 100644
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
@@ -192,7 +192,9 @@ static void replaceLoadOrStoreOp(vector::LoadOp loadOp,
                                  vector::LoadOpAdaptor adaptor,
                                  VectorType vectorTy, Value ptr, unsigned align,
                                  ConversionPatternRewriter &rewriter) {
-  rewriter.replaceOpWithNewOp<LLVM::LoadOp>(loadOp, vectorTy, ptr, align);
+  rewriter.replaceOpWithNewOp<LLVM::LoadOp>(loadOp, vectorTy, ptr, align,
+                                            /*volatile_=*/false,
+                                            loadOp.getNontemporal());
 }
 
 static void replaceLoadOrStoreOp(vector::MaskedLoadOp loadOp,
@@ -208,7 +210,8 @@ static void replaceLoadOrStoreOp(vector::StoreOp storeOp,
                                  VectorType vectorTy, Value ptr, unsigned align,
                                  ConversionPatternRewriter &rewriter) {
   rewriter.replaceOpWithNewOp<LLVM::StoreOp>(storeOp, adaptor.getValueToStore(),
-                                             ptr, align);
+                                             ptr, align, /*volatile_=*/false,
+                                             storeOp.getNontemporal());
 }
 
 static void replaceLoadOrStoreOp(vector::MaskedStoreOp storeOp,

diff  --git a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
index b13c266609ae81..09108ab3179998 100644
--- a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
+++ b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
@@ -2023,6 +2023,20 @@ func.func @vector_load_op(%memref : memref<200x100xf32>, %i : index, %j : index)
 
 // -----
 
+func.func @vector_load_op_nontemporal(%memref : memref<200x100xf32>, %i : index, %j : index) -> vector<8xf32> {
+  %0 = vector.load %memref[%i, %j] {nontemporal = true} : memref<200x100xf32>, vector<8xf32>
+  return %0 : vector<8xf32>
+}
+
+// CHECK-LABEL: func @vector_load_op_nontemporal
+// CHECK: %[[c100:.*]] = llvm.mlir.constant(100 : index) : i64
+// CHECK: %[[mul:.*]] = llvm.mul %{{.*}}, %[[c100]]  : i64
+// CHECK: %[[add:.*]] = llvm.add %[[mul]], %{{.*}}  : i64
+// CHECK: %[[gep:.*]] = llvm.getelementptr %{{.*}}[%[[add]]] : (!llvm.ptr, i64) -> !llvm.ptr, f32
+// CHECK: llvm.load %[[gep]] {alignment = 4 : i64, nontemporal} : !llvm.ptr -> vector<8xf32>
+
+// -----
+
 func.func @vector_load_op_index(%memref : memref<200x100xindex>, %i : index, %j : index) -> vector<8xindex> {
   %0 = vector.load %memref[%i, %j] : memref<200x100xindex>, vector<8xindex>
   return %0 : vector<8xindex>
@@ -2049,6 +2063,21 @@ func.func @vector_store_op(%memref : memref<200x100xf32>, %i : index, %j : index
 
 // -----
 
+func.func @vector_store_op_nontemporal(%memref : memref<200x100xf32>, %i : index, %j : index) {
+  %val = arith.constant dense<11.0> : vector<4xf32>
+  vector.store %val, %memref[%i, %j] {nontemporal = true} : memref<200x100xf32>, vector<4xf32>
+  return
+}
+
+// CHECK-LABEL: func @vector_store_op_nontemporal
+// CHECK: %[[c100:.*]] = llvm.mlir.constant(100 : index) : i64
+// CHECK: %[[mul:.*]] = llvm.mul %{{.*}}, %[[c100]]  : i64
+// CHECK: %[[add:.*]] = llvm.add %[[mul]], %{{.*}}  : i64
+// CHECK: %[[gep:.*]] = llvm.getelementptr %{{.*}}[%[[add]]] : (!llvm.ptr, i64) -> !llvm.ptr, f32
+// CHECK: llvm.store %{{.*}}, %[[gep]] {alignment = 4 : i64, nontemporal} :  vector<4xf32>, !llvm.ptr
+
+// -----
+
 func.func @vector_store_op_index(%memref : memref<200x100xindex>, %i : index, %j : index) {
   %val = arith.constant dense<11> : vector<4xindex>
   vector.store %val, %memref[%i, %j] : memref<200x100xindex>, vector<4xindex>


        


More information about the Mlir-commits mailing list