[Mlir-commits] [mlir] 1cb91b4 - [mlir] Add nontemporal field to memref.load/store and convey to llvm.load/store

Guray Ozen llvmlistbot at llvm.org
Fri Feb 3 05:03:45 PST 2023


Author: Guray Ozen
Date: 2023-02-03T14:03:38+01:00
New Revision: 1cb91b421e7d73029c0c4bc44d33d006371f7adb

URL: https://github.com/llvm/llvm-project/commit/1cb91b421e7d73029c0c4bc44d33d006371f7adb
DIFF: https://github.com/llvm/llvm-project/commit/1cb91b421e7d73029c0c4bc44d33d006371f7adb.diff

LOG: [mlir] Add nontemporal field to memref.load/store and convey to llvm.load/store

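The `llvm.load` and `llvm.store` ops have a `nontemporal` attribute, which is missing from `memref.load` and `memref.store`. This revision first adds an optional `nontemporal` attribute (defaulting to `false`) to memref's load/store ops, then lowers it to the corresponding attribute on the `llvm.load`/`llvm.store` ops.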

Reviewed By: nicolasvasilache

Differential Revision: https://reviews.llvm.org/D142616

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
    mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
    mlir/lib/Dialect/MemRef/Transforms/EmulateWideInt.cpp
    mlir/lib/Dialect/MemRef/Transforms/FoldMemRefAliasOps.cpp
    mlir/test/Conversion/MemRefToLLVM/memref-to-llvm.mlir
    mlir/test/Dialect/MemRef/canonicalize.mlir
    mlir/test/Dialect/MemRef/emulate-wide-int.mlir
    mlir/test/Dialect/MemRef/fold-memref-alias-ops.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
index f5dab426cb9db..d9f4384546321 100644
--- a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
+++ b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
@@ -1155,7 +1155,8 @@ def LoadOp : MemRef_Op<"load",
 
   let arguments = (ins Arg<AnyMemRef, "the reference to load from",
                            [MemRead]>:$memref,
-                       Variadic<Index>:$indices);
+                       Variadic<Index>:$indices,                       
+                       DefaultValuedOptionalAttr<BoolAttr, "false">:$nontemporal);
   let results = (outs AnyType:$result);
 
   let extraClassDeclaration = [{
@@ -1690,7 +1691,8 @@ def MemRef_StoreOp : MemRef_Op<"store",
   let arguments = (ins AnyType:$value,
                        Arg<AnyMemRef, "the reference to store to",
                            [MemWrite]>:$memref,
-                       Variadic<Index>:$indices);
+                       Variadic<Index>:$indices,                       
+                       DefaultValuedOptionalAttr<BoolAttr, "false">:$nontemporal);
 
   let builders = [
     OpBuilder<(ins "Value":$valueToStore, "Value":$memref), [{

diff  --git a/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp b/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
index f8f1dd1be09b5..7318f1ab56e02 100644
--- a/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
+++ b/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
@@ -731,7 +731,8 @@ struct LoadOpLowering : public LoadStoreOpLowering<memref::LoadOp> {
     Value dataPtr =
         getStridedElementPtr(loadOp.getLoc(), type, adaptor.getMemref(),
                              adaptor.getIndices(), rewriter);
-    rewriter.replaceOpWithNewOp<LLVM::LoadOp>(loadOp, dataPtr);
+    rewriter.replaceOpWithNewOp<LLVM::LoadOp>(loadOp, dataPtr, 0, false,
+                                              loadOp.getNontemporal());
     return success();
   }
 };
@@ -748,7 +749,8 @@ struct StoreOpLowering : public LoadStoreOpLowering<memref::StoreOp> {
 
     Value dataPtr = getStridedElementPtr(op.getLoc(), type, adaptor.getMemref(),
                                          adaptor.getIndices(), rewriter);
-    rewriter.replaceOpWithNewOp<LLVM::StoreOp>(op, adaptor.getValue(), dataPtr);
+    rewriter.replaceOpWithNewOp<LLVM::StoreOp>(op, adaptor.getValue(), dataPtr,
+                                               0, false, op.getNontemporal());
     return success();
   }
 };

diff  --git a/mlir/lib/Dialect/MemRef/Transforms/EmulateWideInt.cpp b/mlir/lib/Dialect/MemRef/Transforms/EmulateWideInt.cpp
index b0e884c285f0c..81b9b279a6500 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/EmulateWideInt.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/EmulateWideInt.cpp
@@ -66,7 +66,8 @@ struct ConvertMemRefLoad final : OpConversionPattern<memref::LoadOp> {
                                       op.getMemRefType()));
 
     rewriter.replaceOpWithNewOp<memref::LoadOp>(
-        op, newResTy, adaptor.getMemref(), adaptor.getIndices());
+        op, newResTy, adaptor.getMemref(), adaptor.getIndices(),
+        op.getNontemporal());
     return success();
   }
 };
@@ -88,7 +89,8 @@ struct ConvertMemRefStore final : OpConversionPattern<memref::StoreOp> {
                                       op.getMemRefType()));
 
     rewriter.replaceOpWithNewOp<memref::StoreOp>(
-        op, adaptor.getValue(), adaptor.getMemref(), adaptor.getIndices());
+        op, adaptor.getValue(), adaptor.getMemref(), adaptor.getIndices(),
+        op.getNontemporal());
     return success();
   }
 };

diff  --git a/mlir/lib/Dialect/MemRef/Transforms/FoldMemRefAliasOps.cpp b/mlir/lib/Dialect/MemRef/Transforms/FoldMemRefAliasOps.cpp
index 33e9ee71ee3b5..129ac41233058 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/FoldMemRefAliasOps.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/FoldMemRefAliasOps.cpp
@@ -384,10 +384,14 @@ LogicalResult LoadOpOfSubViewOpFolder<OpTy>::matchAndRewrite(
     return failure();
 
   llvm::TypeSwitch<Operation *, void>(loadOp)
-      .Case<AffineLoadOp, memref::LoadOp>([&](auto op) {
-        rewriter.replaceOpWithNewOp<decltype(op)>(loadOp, subViewOp.getSource(),
+      .Case([&](AffineLoadOp op) {
+        rewriter.replaceOpWithNewOp<AffineLoadOp>(loadOp, subViewOp.getSource(),
                                                   sourceIndices);
       })
+      .Case([&](memref::LoadOp op) {
+        rewriter.replaceOpWithNewOp<memref::LoadOp>(
+            loadOp, subViewOp.getSource(), sourceIndices, op.getNontemporal());
+      })
       .Case([&](vector::TransferReadOp transferReadOp) {
         rewriter.replaceOpWithNewOp<vector::TransferReadOp>(
             transferReadOp, transferReadOp.getVectorType(),
@@ -490,10 +494,15 @@ LogicalResult StoreOpOfSubViewOpFolder<OpTy>::matchAndRewrite(
     return failure();
 
   llvm::TypeSwitch<Operation *, void>(storeOp)
-      .Case<AffineStoreOp, memref::StoreOp>([&](auto op) {
-        rewriter.replaceOpWithNewOp<decltype(op)>(
+      .Case([&](AffineStoreOp op) {
+        rewriter.replaceOpWithNewOp<AffineStoreOp>(
             storeOp, storeOp.getValue(), subViewOp.getSource(), sourceIndices);
       })
+      .Case([&](memref::StoreOp op) {
+        rewriter.replaceOpWithNewOp<memref::StoreOp>(
+            storeOp, storeOp.getValue(), subViewOp.getSource(), sourceIndices,
+            op.getNontemporal());
+      })
       .Case([&](vector::TransferWriteOp op) {
         rewriter.replaceOpWithNewOp<vector::TransferWriteOp>(
             op, op.getValue(), subViewOp.getSource(), sourceIndices,

diff  --git a/mlir/test/Conversion/MemRefToLLVM/memref-to-llvm.mlir b/mlir/test/Conversion/MemRefToLLVM/memref-to-llvm.mlir
index e1ab72234dc89..b6c73e9a917bc 100644
--- a/mlir/test/Conversion/MemRefToLLVM/memref-to-llvm.mlir
+++ b/mlir/test/Conversion/MemRefToLLVM/memref-to-llvm.mlir
@@ -537,3 +537,24 @@ func.func @extract_strided_metadata(
 
   return
 }
+
+// -----
+
+// CHECK-LABEL: func @load_non_temporal(
+func.func @load_non_temporal(%arg0 : memref<32xf32, affine_map<(d0) -> (d0)>>) {  
+  %1 = arith.constant 7 : index
+  // CHECK: llvm.load %{{.*}} {nontemporal} : !llvm.ptr<f32>
+  %2 = memref.load %arg0[%1] {nontemporal = true} : memref<32xf32, affine_map<(d0) -> (d0)>>
+  func.return
+}
+
+// -----
+
+// CHECK-LABEL: func @store_non_temporal(
+func.func @store_non_temporal(%input : memref<32xf32, affine_map<(d0) -> (d0)>>, %output : memref<32xf32, affine_map<(d0) -> (d0)>>) {
+  %1 = arith.constant 7 : index
+  %2 = memref.load %input[%1] {nontemporal = true} : memref<32xf32, affine_map<(d0) -> (d0)>>
+  // CHECK: llvm.store %{{.*}}, %{{.*}}  {nontemporal} : !llvm.ptr<f32>
+  memref.store %2, %output[%1] {nontemporal = true} : memref<32xf32, affine_map<(d0) -> (d0)>>
+  func.return
+}

diff  --git a/mlir/test/Dialect/MemRef/canonicalize.mlir b/mlir/test/Dialect/MemRef/canonicalize.mlir
index 3d9f71e260550..14c570759b4c1 100644
--- a/mlir/test/Dialect/MemRef/canonicalize.mlir
+++ b/mlir/test/Dialect/MemRef/canonicalize.mlir
@@ -894,3 +894,15 @@ func.func @fold_trivial_subviews(%m: memref<?xf32, strided<[?], offset: ?>>,
         to memref<?xf32, strided<[?], offset: ?>>
   return %1 : memref<?xf32, strided<[?], offset: ?>>
 }
+
+// -----
+
+// CHECK-LABEL: func @load_store_nontemporal(
+func.func @load_store_nontemporal(%input : memref<32xf32, affine_map<(d0) -> (d0)>>, %output : memref<32xf32, affine_map<(d0) -> (d0)>>) {
+  %1 = arith.constant 7 : index
+  // CHECK: memref.load %{{.*}}[%{{.*}}] {nontemporal = true} : memref<32xf32>
+  %2 = memref.load %input[%1] {nontemporal = true} : memref<32xf32, affine_map<(d0) -> (d0)>>
+  // CHECK: memref.store %{{.*}}, %{{.*}}[%{{.*}}] {nontemporal = true} : memref<32xf32>
+  memref.store %2, %output[%1] {nontemporal = true} : memref<32xf32, affine_map<(d0) -> (d0)>>
+  func.return
+}

diff  --git a/mlir/test/Dialect/MemRef/emulate-wide-int.mlir b/mlir/test/Dialect/MemRef/emulate-wide-int.mlir
index de1cba5c0477f..65ac5beed0a1d 100644
--- a/mlir/test/Dialect/MemRef/emulate-wide-int.mlir
+++ b/mlir/test/Dialect/MemRef/emulate-wide-int.mlir
@@ -44,3 +44,19 @@ func.func @alloc_load_store_i64() {
     memref.store %c1, %m[%c0] : memref<4xi64, 1>
     return
 }
+
+
+// CHECK-LABEL: func @alloc_load_store_i64_nontemporal
+// CHECK:         [[C1:%.+]] = arith.constant dense<[1, 0]> : vector<2xi32>
+// CHECK-NEXT:    [[M:%.+]]  = memref.alloc() : memref<4xvector<2xi32>, 1>
+// CHECK-NEXT:    [[V:%.+]]  = memref.load [[M]][{{%.+}}] {nontemporal = true} : memref<4xvector<2xi32>, 1>
+// CHECK-NEXT:    memref.store [[C1]], [[M]][{{%.+}}] {nontemporal = true} : memref<4xvector<2xi32>, 1>
+// CHECK-NEXT:    return
+func.func @alloc_load_store_i64_nontemporal() {
+    %c0 = arith.constant 0 : index
+    %c1 = arith.constant 1 : i64
+    %m = memref.alloc() : memref<4xi64, 1>
+    %v = memref.load %m[%c0] {nontemporal = true} : memref<4xi64, 1>
+    memref.store %c1, %m[%c0] {nontemporal = true} : memref<4xi64, 1>
+    return
+}

diff  --git a/mlir/test/Dialect/MemRef/fold-memref-alias-ops.mlir b/mlir/test/Dialect/MemRef/fold-memref-alias-ops.mlir
index c2ecc90be8ddf..f0d5008f991aa 100644
--- a/mlir/test/Dialect/MemRef/fold-memref-alias-ops.mlir
+++ b/mlir/test/Dialect/MemRef/fold-memref-alias-ops.mlir
@@ -502,3 +502,25 @@ func.func @subview_of_subview_rank_reducing(%m: memref<?x?x?xf32>,
         to memref<f32, strided<[], offset: ?>>
   return %1 : memref<f32, strided<[], offset: ?>>
 }
+
+// -----
+
+// CHECK-LABEL: func @fold_load_keep_nontemporal(
+//      CHECK:   memref.load %{{.+}}[%{{.+}}, %{{.+}}] {nontemporal = true}
+func.func @fold_load_keep_nontemporal(%arg0 : memref<12x32xf32>, %arg1 : index, %arg2 : index, %arg3 : index, %arg4 : index) -> f32 {
+  %0 = memref.subview %arg0[%arg1, %arg2][4, 4][2, 3] : memref<12x32xf32> to memref<4x4xf32, strided<[64, 3], offset: ?>>
+  %1 = memref.load %0[%arg3, %arg4] {nontemporal = true }: memref<4x4xf32, strided<[64, 3], offset: ?>>
+  return %1 : f32
+}
+
+
+// -----
+
+// CHECK-LABEL: func @fold_store_keep_nontemporal(
+//      CHECK:   memref.store %{{.+}}, %{{.+}}[%{{.+}}, %{{.+}}]  {nontemporal = true} : memref<12x32xf32> 
+func.func @fold_store_keep_nontemporal(%arg0 : memref<12x32xf32>, %arg1 : index, %arg2 : index, %arg3 : index, %arg4 : index, %arg5 : f32) {
+  %0 = memref.subview %arg0[%arg1, %arg2][4, 4][2, 3] :
+    memref<12x32xf32> to memref<4x4xf32, strided<[64, 3], offset: ?>>
+  memref.store %arg5, %0[%arg3, %arg4] {nontemporal=true}: memref<4x4xf32, strided<[64, 3], offset: ?>>
+  return
+}


        


More information about the Mlir-commits mailing list