[Mlir-commits] [mlir] ff00b58 - [MLIR] Normalize memrefs in LoadOp and StoreOp of Standard Ops
Uday Bondhugula
llvmlistbot at llvm.org
Thu Sep 24 06:30:56 PDT 2020
Author: Haruki Imai
Date: 2020-09-24T18:57:15+05:30
New Revision: ff00b58392527419ea32d0b97575ef973c1bd085
URL: https://github.com/llvm/llvm-project/commit/ff00b58392527419ea32d0b97575ef973c1bd085
DIFF: https://github.com/llvm/llvm-project/commit/ff00b58392527419ea32d0b97575ef973c1bd085.diff
LOG: [MLIR] Normalize memrefs in LoadOp and StoreOp of Standard Ops
Added a trait, `MemRefsNormalizable`, to LoadOp and StoreOp of Standard Ops
so that the memrefs they take as operands can be normalized.
Related revision: https://reviews.llvm.org/D86236
Differential Revision: https://reviews.llvm.org/D88156
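Why the trait matters: roughly, the normalization pass from the related revision (D86236) only rewrites the memrefs in a function when every op that uses them carries `MemRefsNormalizable`, so a single `load` or `store` lacking the trait is enough to leave the whole function unnormalized. The Python sketch below models that gating rule for the users in the new `@test_load_store` test. It is an illustration under stated assumptions, not MLIR's C++ API: the op-name strings, the helper names and both trait sets are hypothetical.

```python
from typing import Dict, List, Set


def memref_is_normalizable(users: List[str], trait_ops: Set[str]) -> bool:
    """A memref may be rewritten only if every user op has the trait."""
    return all(op in trait_ops for op in users)


def function_is_normalizable(memref_users: Dict[str, List[str]],
                             trait_ops: Set[str]) -> bool:
    """The function is rewritten only if all of its memrefs qualify."""
    return all(memref_is_normalizable(users, trait_ops)
               for users in memref_users.values())


# Users of each memref in @test_load_store (op names are illustrative).
TEST_LOAD_STORE_USERS = {
    "%arg0": ["std.store"],
    "%0": ["test.op_norm", "std.dealloc"],
    "%1": ["test.op_norm", "std.load", "std.dealloc"],
}

# Hypothetical trait sets before and after this commit.
BEFORE = {"test.op_norm", "std.dealloc"}
AFTER = BEFORE | {"std.load", "std.store"}

print(function_is_normalizable(TEST_LOAD_STORE_USERS, BEFORE))  # False
print(function_is_normalizable(TEST_LOAD_STORE_USERS, AFTER))   # True
```

Under these assumptions, the store to `%arg0` alone blocks normalization of `%0` until `std.store` gains the trait, which is why the test can check the normalized type of `%0` even though the load and store themselves only touch identity-layout memrefs.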
Added:
Modified:
mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
mlir/test/Transforms/normalize-memrefs-ops.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
index dd49635ed8fc..649e941050a3 100644
--- a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
+++ b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
@@ -1831,7 +1831,8 @@ def IndexCastOp : CastOp<"index_cast">, Arguments<(ins AnyType:$in)> {
def LoadOp : Std_Op<"load",
[TypesMatchWith<"result type matches element type of 'memref'",
"memref", "result",
- "$_self.cast<MemRefType>().getElementType()">]> {
+ "$_self.cast<MemRefType>().getElementType()">,
+ MemRefsNormalizable]> {
let summary = "load operation";
let description = [{
The `load` op reads an element from a memref specified by an index list. The
@@ -2580,7 +2581,8 @@ def SqrtOp : FloatUnaryOp<"sqrt"> {
def StoreOp : Std_Op<"store",
[TypesMatchWith<"type of 'value' matches element type of 'memref'",
"memref", "value",
- "$_self.cast<MemRefType>().getElementType()">]> {
+ "$_self.cast<MemRefType>().getElementType()">,
+ MemRefsNormalizable]> {
let summary = "store operation";
let description = [{
Store a value to a memref location given by indices. The value stored should
diff --git a/mlir/test/Transforms/normalize-memrefs-ops.mlir b/mlir/test/Transforms/normalize-memrefs-ops.mlir
index 8ce841e0d692..0c6715764492 100644
--- a/mlir/test/Transforms/normalize-memrefs-ops.mlir
+++ b/mlir/test/Transforms/normalize-memrefs-ops.mlir
@@ -55,3 +55,37 @@ func @test_norm_mix(%arg0 : memref<1x16x1x1x32x64xf32>) -> () {
// CHECK: dealloc %[[v0]] : memref<1x16x1x1x32x64xf32>
return
}
+
+// Test with maps in load and store ops.
+
+#map_tile = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2 floordiv 32, d3 floordiv 32, d2 mod 32, d3 mod 32)>
+
+// CHECK-LABEL: test_load_store
+// CHECK-SAME: (%[[ARG0:[a-z0-9]*]]: memref<1x16x14x14xf32>
+func @test_load_store(%arg0 : memref<1x16x14x14xf32>) -> () {
+ %0 = alloc() : memref<1x16x14x14xf32, #map_tile>
+ // CHECK: %[[v0:[a-z0-9]*]] = alloc() : memref<1x16x1x1x32x32xf32>
+ %1 = alloc() : memref<1x16x14x14xf32>
+ // CHECK: %[[v1:[a-z0-9]*]] = alloc() : memref<1x16x14x14xf32>
+ "test.op_norm"(%0, %1) : (memref<1x16x14x14xf32, #map_tile>, memref<1x16x14x14xf32>) -> ()
+ // CHECK: "test.op_norm"(%[[v0]], %[[v1]]) : (memref<1x16x1x1x32x32xf32>, memref<1x16x14x14xf32>) -> ()
+ %cst = constant 3.0 : f32
+ affine.for %i = 0 to 1 {
+ affine.for %j = 0 to 16 {
+ affine.for %k = 0 to 14 {
+ affine.for %l = 0 to 14 {
+ %2 = load %1[%i, %j, %k, %l] : memref<1x16x14x14xf32>
+ // CHECK: memref<1x16x14x14xf32>
+ %3 = addf %2, %cst : f32
+ store %3, %arg0[%i, %j, %k, %l] : memref<1x16x14x14xf32>
+ // CHECK: memref<1x16x14x14xf32>
+ }
+ }
+ }
+ }
+ dealloc %0 : memref<1x16x14x14xf32, #map_tile>
+ // CHECK: dealloc %[[v0]] : memref<1x16x1x1x32x32xf32>
+ dealloc %1 : memref<1x16x14x14xf32>
+ // CHECK: dealloc %[[v1]] : memref<1x16x14x14xf32>
+ return
+}
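The new test tiles the two innermost dimensions of a 1x16x14x14 memref by 32 through `#map_tile`, and the CHECK lines expect the normalized type `memref<1x16x1x1x32x32xf32>`. As a rough sanity check of that shape, the Python sketch below computes one extent per map result. It is not the pass's own bound computation: the tuple encoding of the map and the `normalized_shape` helper are hypothetical, and treating each `mod 32` result as spanning all 32 values (instead of tightening it to the 14-element input) is an assumption chosen to reproduce the CHECK line.

```python
from typing import List, Tuple

# Each result of #map_tile is modeled as (kind, input_dim, constant):
#   ("id", d, 0)        -> d
#   ("floordiv", d, c)  -> d floordiv c
#   ("mod", d, c)       -> d mod c
MapResult = Tuple[str, int, int]

MAP_TILE: List[MapResult] = [
    ("id", 0, 0), ("id", 1, 0),
    ("floordiv", 2, 32), ("floordiv", 3, 32),
    ("mod", 2, 32), ("mod", 3, 32),
]


def normalized_shape(shape: List[int], results: List[MapResult]) -> List[int]:
    """Return one extent per map result for an input of the given shape."""
    out = []
    for kind, d, c in results:
        max_index = shape[d] - 1
        if kind == "id":
            out.append(shape[d])
        elif kind == "floordiv":
            # Largest quotient plus one, e.g. 13 // 32 + 1 == 1.
            out.append(max_index // c + 1)
        elif kind == "mod":
            # Assumption: use the full range [0, c - 1] of `mod c`.
            out.append(c)
        else:
            raise ValueError(f"unknown map result kind: {kind}")
    return out


print(normalized_shape([1, 16, 14, 14], MAP_TILE))  # [1, 16, 1, 1, 32, 32]
```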