[Mlir-commits] [mlir] 1cb91b4 - [mlir] Add nontemporal field to memref.load/store and convey to llvm.load/store
Guray Ozen
llvmlistbot at llvm.org
Fri Feb 3 05:03:45 PST 2023
Author: Guray Ozen
Date: 2023-02-03T14:03:38+01:00
New Revision: 1cb91b421e7d73029c0c4bc44d33d006371f7adb
URL: https://github.com/llvm/llvm-project/commit/1cb91b421e7d73029c0c4bc44d33d006371f7adb
DIFF: https://github.com/llvm/llvm-project/commit/1cb91b421e7d73029c0c4bc44d33d006371f7adb.diff
LOG: [mlir] Add nontemporal field to memref.load/store and convey to llvm.load/store
The `llvm.load` and `llvm.store` ops carry a `nontemporal` attribute that has no counterpart on `memref.load` and `memref.store`. This revision first adds a `nontemporal` boolean attribute (defaulting to `false`) to the memref load/store ops, then lowers it to the corresponding `llvm.load`/`llvm.store` ops.
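To illustrate, a minimal usage sketch mirroring the tests added below (the function name is invented for this example; the syntax follows the added test cases):

func.func @nontemporal_example(%src : memref<32xf32>, %dst : memref<32xf32>) {
  // nontemporal is an optional BoolAttr defaulting to false, so it only
  // prints when explicitly set and existing IR is unaffected.
  %i = arith.constant 7 : index
  %v = memref.load %src[%i] {nontemporal = true} : memref<32xf32>
  memref.store %v, %dst[%i] {nontemporal = true} : memref<32xf32>
  func.return
}

When run through the MemRefToLLVM conversion patterns, the load and store above lower to `llvm.load ... {nontemporal}` and `llvm.store ... {nontemporal}`, as checked in memref-to-llvm.mlir below.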
Reviewed By: nicolasvasilache
Differential Revision: https://reviews.llvm.org/D142616
Added:
Modified:
mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
mlir/lib/Dialect/MemRef/Transforms/EmulateWideInt.cpp
mlir/lib/Dialect/MemRef/Transforms/FoldMemRefAliasOps.cpp
mlir/test/Conversion/MemRefToLLVM/memref-to-llvm.mlir
mlir/test/Dialect/MemRef/canonicalize.mlir
mlir/test/Dialect/MemRef/emulate-wide-int.mlir
mlir/test/Dialect/MemRef/fold-memref-alias-ops.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
index f5dab426cb9db..d9f4384546321 100644
--- a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
+++ b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
@@ -1155,7 +1155,8 @@ def LoadOp : MemRef_Op<"load",
let arguments = (ins Arg<AnyMemRef, "the reference to load from",
[MemRead]>:$memref,
- Variadic<Index>:$indices);
+ Variadic<Index>:$indices,
+ DefaultValuedOptionalAttr<BoolAttr, "false">:$nontemporal);
let results = (outs AnyType:$result);
let extraClassDeclaration = [{
@@ -1690,7 +1691,8 @@ def MemRef_StoreOp : MemRef_Op<"store",
let arguments = (ins AnyType:$value,
Arg<AnyMemRef, "the reference to store to",
[MemWrite]>:$memref,
- Variadic<Index>:$indices);
+ Variadic<Index>:$indices,
+ DefaultValuedOptionalAttr<BoolAttr, "false">:$nontemporal);
let builders = [
OpBuilder<(ins "Value":$valueToStore, "Value":$memref), [{
diff --git a/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp b/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
index f8f1dd1be09b5..7318f1ab56e02 100644
--- a/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
+++ b/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
@@ -731,7 +731,8 @@ struct LoadOpLowering : public LoadStoreOpLowering<memref::LoadOp> {
Value dataPtr =
getStridedElementPtr(loadOp.getLoc(), type, adaptor.getMemref(),
adaptor.getIndices(), rewriter);
- rewriter.replaceOpWithNewOp<LLVM::LoadOp>(loadOp, dataPtr);
+ rewriter.replaceOpWithNewOp<LLVM::LoadOp>(loadOp, dataPtr, 0, false,
+ loadOp.getNontemporal());
return success();
}
};
@@ -748,7 +749,8 @@ struct StoreOpLowering : public LoadStoreOpLowering<memref::StoreOp> {
Value dataPtr = getStridedElementPtr(op.getLoc(), type, adaptor.getMemref(),
adaptor.getIndices(), rewriter);
- rewriter.replaceOpWithNewOp<LLVM::StoreOp>(op, adaptor.getValue(), dataPtr);
+ rewriter.replaceOpWithNewOp<LLVM::StoreOp>(op, adaptor.getValue(), dataPtr,
+ 0, false, op.getNontemporal());
return success();
}
};
diff --git a/mlir/lib/Dialect/MemRef/Transforms/EmulateWideInt.cpp b/mlir/lib/Dialect/MemRef/Transforms/EmulateWideInt.cpp
index b0e884c285f0c..81b9b279a6500 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/EmulateWideInt.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/EmulateWideInt.cpp
@@ -66,7 +66,8 @@ struct ConvertMemRefLoad final : OpConversionPattern<memref::LoadOp> {
op.getMemRefType()));
rewriter.replaceOpWithNewOp<memref::LoadOp>(
- op, newResTy, adaptor.getMemref(), adaptor.getIndices());
+ op, newResTy, adaptor.getMemref(), adaptor.getIndices(),
+ op.getNontemporal());
return success();
}
};
@@ -88,7 +89,8 @@ struct ConvertMemRefStore final : OpConversionPattern<memref::StoreOp> {
op.getMemRefType()));
rewriter.replaceOpWithNewOp<memref::StoreOp>(
- op, adaptor.getValue(), adaptor.getMemref(), adaptor.getIndices());
+ op, adaptor.getValue(), adaptor.getMemref(), adaptor.getIndices(),
+ op.getNontemporal());
return success();
}
};
diff --git a/mlir/lib/Dialect/MemRef/Transforms/FoldMemRefAliasOps.cpp b/mlir/lib/Dialect/MemRef/Transforms/FoldMemRefAliasOps.cpp
index 33e9ee71ee3b5..129ac41233058 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/FoldMemRefAliasOps.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/FoldMemRefAliasOps.cpp
@@ -384,10 +384,14 @@ LogicalResult LoadOpOfSubViewOpFolder<OpTy>::matchAndRewrite(
return failure();
llvm::TypeSwitch<Operation *, void>(loadOp)
- .Case<AffineLoadOp, memref::LoadOp>([&](auto op) {
- rewriter.replaceOpWithNewOp<decltype(op)>(loadOp, subViewOp.getSource(),
+ .Case([&](AffineLoadOp op) {
+ rewriter.replaceOpWithNewOp<AffineLoadOp>(loadOp, subViewOp.getSource(),
sourceIndices);
})
+ .Case([&](memref::LoadOp op) {
+ rewriter.replaceOpWithNewOp<memref::LoadOp>(
+ loadOp, subViewOp.getSource(), sourceIndices, op.getNontemporal());
+ })
.Case([&](vector::TransferReadOp transferReadOp) {
rewriter.replaceOpWithNewOp<vector::TransferReadOp>(
transferReadOp, transferReadOp.getVectorType(),
@@ -490,10 +494,15 @@ LogicalResult StoreOpOfSubViewOpFolder<OpTy>::matchAndRewrite(
return failure();
llvm::TypeSwitch<Operation *, void>(storeOp)
- .Case<AffineStoreOp, memref::StoreOp>([&](auto op) {
- rewriter.replaceOpWithNewOp<decltype(op)>(
+ .Case([&](AffineStoreOp op) {
+ rewriter.replaceOpWithNewOp<AffineStoreOp>(
storeOp, storeOp.getValue(), subViewOp.getSource(), sourceIndices);
})
+ .Case([&](memref::StoreOp op) {
+ rewriter.replaceOpWithNewOp<memref::StoreOp>(
+ storeOp, storeOp.getValue(), subViewOp.getSource(), sourceIndices,
+ op.getNontemporal());
+ })
.Case([&](vector::TransferWriteOp op) {
rewriter.replaceOpWithNewOp<vector::TransferWriteOp>(
op, op.getValue(), subViewOp.getSource(), sourceIndices,
diff --git a/mlir/test/Conversion/MemRefToLLVM/memref-to-llvm.mlir b/mlir/test/Conversion/MemRefToLLVM/memref-to-llvm.mlir
index e1ab72234dc89..b6c73e9a917bc 100644
--- a/mlir/test/Conversion/MemRefToLLVM/memref-to-llvm.mlir
+++ b/mlir/test/Conversion/MemRefToLLVM/memref-to-llvm.mlir
@@ -537,3 +537,24 @@ func.func @extract_strided_metadata(
return
}
+
+// -----
+
+// CHECK-LABEL: func @load_non_temporal(
+func.func @load_non_temporal(%arg0 : memref<32xf32, affine_map<(d0) -> (d0)>>) {
+ %1 = arith.constant 7 : index
+ // CHECK: llvm.load %{{.*}} {nontemporal} : !llvm.ptr<f32>
+ %2 = memref.load %arg0[%1] {nontemporal = true} : memref<32xf32, affine_map<(d0) -> (d0)>>
+ func.return
+}
+
+// -----
+
+// CHECK-LABEL: func @store_non_temporal(
+func.func @store_non_temporal(%input : memref<32xf32, affine_map<(d0) -> (d0)>>, %output : memref<32xf32, affine_map<(d0) -> (d0)>>) {
+ %1 = arith.constant 7 : index
+ %2 = memref.load %input[%1] {nontemporal = true} : memref<32xf32, affine_map<(d0) -> (d0)>>
+ // CHECK: llvm.store %{{.*}}, %{{.*}} {nontemporal} : !llvm.ptr<f32>
+ memref.store %2, %output[%1] {nontemporal = true} : memref<32xf32, affine_map<(d0) -> (d0)>>
+ func.return
+}
diff --git a/mlir/test/Dialect/MemRef/canonicalize.mlir b/mlir/test/Dialect/MemRef/canonicalize.mlir
index 3d9f71e260550..14c570759b4c1 100644
--- a/mlir/test/Dialect/MemRef/canonicalize.mlir
+++ b/mlir/test/Dialect/MemRef/canonicalize.mlir
@@ -894,3 +894,15 @@ func.func @fold_trivial_subviews(%m: memref<?xf32, strided<[?], offset: ?>>,
to memref<?xf32, strided<[?], offset: ?>>
return %1 : memref<?xf32, strided<[?], offset: ?>>
}
+
+// -----
+
+// CHECK-LABEL: func @load_store_nontemporal(
+func.func @load_store_nontemporal(%input : memref<32xf32, affine_map<(d0) -> (d0)>>, %output : memref<32xf32, affine_map<(d0) -> (d0)>>) {
+ %1 = arith.constant 7 : index
+ // CHECK: memref.load %{{.*}}[%{{.*}}] {nontemporal = true} : memref<32xf32>
+ %2 = memref.load %input[%1] {nontemporal = true} : memref<32xf32, affine_map<(d0) -> (d0)>>
+ // CHECK: memref.store %{{.*}}, %{{.*}}[%{{.*}}] {nontemporal = true} : memref<32xf32>
+ memref.store %2, %output[%1] {nontemporal = true} : memref<32xf32, affine_map<(d0) -> (d0)>>
+ func.return
+}
diff --git a/mlir/test/Dialect/MemRef/emulate-wide-int.mlir b/mlir/test/Dialect/MemRef/emulate-wide-int.mlir
index de1cba5c0477f..65ac5beed0a1d 100644
--- a/mlir/test/Dialect/MemRef/emulate-wide-int.mlir
+++ b/mlir/test/Dialect/MemRef/emulate-wide-int.mlir
@@ -44,3 +44,19 @@ func.func @alloc_load_store_i64() {
memref.store %c1, %m[%c0] : memref<4xi64, 1>
return
}
+
+
+// CHECK-LABEL: func @alloc_load_store_i64_nontemporal
+// CHECK: [[C1:%.+]] = arith.constant dense<[1, 0]> : vector<2xi32>
+// CHECK-NEXT: [[M:%.+]] = memref.alloc() : memref<4xvector<2xi32>, 1>
+// CHECK-NEXT: [[V:%.+]] = memref.load [[M]][{{%.+}}] {nontemporal = true} : memref<4xvector<2xi32>, 1>
+// CHECK-NEXT: memref.store [[C1]], [[M]][{{%.+}}] {nontemporal = true} : memref<4xvector<2xi32>, 1>
+// CHECK-NEXT: return
+func.func @alloc_load_store_i64_nontemporal() {
+ %c0 = arith.constant 0 : index
+ %c1 = arith.constant 1 : i64
+ %m = memref.alloc() : memref<4xi64, 1>
+ %v = memref.load %m[%c0] {nontemporal = true} : memref<4xi64, 1>
+ memref.store %c1, %m[%c0] {nontemporal = true} : memref<4xi64, 1>
+ return
+}
diff --git a/mlir/test/Dialect/MemRef/fold-memref-alias-ops.mlir b/mlir/test/Dialect/MemRef/fold-memref-alias-ops.mlir
index c2ecc90be8ddf..f0d5008f991aa 100644
--- a/mlir/test/Dialect/MemRef/fold-memref-alias-ops.mlir
+++ b/mlir/test/Dialect/MemRef/fold-memref-alias-ops.mlir
@@ -502,3 +502,25 @@ func.func @subview_of_subview_rank_reducing(%m: memref<?x?x?xf32>,
to memref<f32, strided<[], offset: ?>>
return %1 : memref<f32, strided<[], offset: ?>>
}
+
+// -----
+
+// CHECK-LABEL: func @fold_load_keep_nontemporal(
+// CHECK: memref.load %{{.+}}[%{{.+}}, %{{.+}}] {nontemporal = true}
+func.func @fold_load_keep_nontemporal(%arg0 : memref<12x32xf32>, %arg1 : index, %arg2 : index, %arg3 : index, %arg4 : index) -> f32 {
+ %0 = memref.subview %arg0[%arg1, %arg2][4, 4][2, 3] : memref<12x32xf32> to memref<4x4xf32, strided<[64, 3], offset: ?>>
+ %1 = memref.load %0[%arg3, %arg4] {nontemporal = true }: memref<4x4xf32, strided<[64, 3], offset: ?>>
+ return %1 : f32
+}
+
+
+// -----
+
+// CHECK-LABEL: func @fold_store_keep_nontemporal(
+// CHECK: memref.store %{{.+}}, %{{.+}}[%{{.+}}, %{{.+}}] {nontemporal = true} : memref<12x32xf32>
+func.func @fold_store_keep_nontemporal(%arg0 : memref<12x32xf32>, %arg1 : index, %arg2 : index, %arg3 : index, %arg4 : index, %arg5 : f32) {
+ %0 = memref.subview %arg0[%arg1, %arg2][4, 4][2, 3] :
+ memref<12x32xf32> to memref<4x4xf32, strided<[64, 3], offset: ?>>
+ memref.store %arg5, %0[%arg3, %arg4] {nontemporal=true}: memref<4x4xf32, strided<[64, 3], offset: ?>>
+ return
+}