[Mlir-commits] [mlir] ff00b58 - [MLIR] Normalize memrefs in LoadOp and StoreOp of Standard Ops

Uday Bondhugula llvmlistbot at llvm.org
Thu Sep 24 06:30:56 PDT 2020


Author: Haruki Imai
Date: 2020-09-24T18:57:15+05:30
New Revision: ff00b58392527419ea32d0b97575ef973c1bd085

URL: https://github.com/llvm/llvm-project/commit/ff00b58392527419ea32d0b97575ef973c1bd085
DIFF: https://github.com/llvm/llvm-project/commit/ff00b58392527419ea32d0b97575ef973c1bd085.diff

LOG: [MLIR] Normalize memrefs in LoadOp and StoreOp of Standard Ops

Added the `MemRefsNormalizable` trait to LoadOp and StoreOp of Standard Ops
so that the memrefs they operate on can be normalized.

Related revision: https://reviews.llvm.org/D86236

Differential Revision: https://reviews.llvm.org/D88156

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
    mlir/test/Transforms/normalize-memrefs-ops.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
index dd49635ed8fc..649e941050a3 100644
--- a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
+++ b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
@@ -1831,7 +1831,8 @@ def IndexCastOp : CastOp<"index_cast">, Arguments<(ins AnyType:$in)> {
 def LoadOp : Std_Op<"load",
      [TypesMatchWith<"result type matches element type of 'memref'",
                      "memref", "result",
-                     "$_self.cast<MemRefType>().getElementType()">]> {
+                     "$_self.cast<MemRefType>().getElementType()">,
+                     MemRefsNormalizable]> {
   let summary = "load operation";
   let description = [{
     The `load` op reads an element from a memref specified by an index list. The
@@ -2580,7 +2581,8 @@ def SqrtOp : FloatUnaryOp<"sqrt"> {
 def StoreOp : Std_Op<"store",
      [TypesMatchWith<"type of 'value' matches element type of 'memref'",
                      "memref", "value",
-                     "$_self.cast<MemRefType>().getElementType()">]> {
+                     "$_self.cast<MemRefType>().getElementType()">,
+                     MemRefsNormalizable]> {
   let summary = "store operation";
   let description = [{
     Store a value to a memref location given by indices. The value stored should

diff  --git a/mlir/test/Transforms/normalize-memrefs-ops.mlir b/mlir/test/Transforms/normalize-memrefs-ops.mlir
index 8ce841e0d692..0c6715764492 100644
--- a/mlir/test/Transforms/normalize-memrefs-ops.mlir
+++ b/mlir/test/Transforms/normalize-memrefs-ops.mlir
@@ -55,3 +55,37 @@ func @test_norm_mix(%arg0 : memref<1x16x1x1x32x64xf32>) -> () {
     // CHECK: dealloc %[[v0]] : memref<1x16x1x1x32x64xf32>
     return
 }
+
+// Test with maps in load and store ops.
+
+#map_tile = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2 floordiv 32, d3 floordiv 32, d2 mod 32, d3 mod 32)>
+
+// CHECK-LABEL: test_load_store
+// CHECK-SAME: (%[[ARG0:[a-z0-9]*]]: memref<1x16x14x14xf32>
+func @test_load_store(%arg0 : memref<1x16x14x14xf32>) -> () {
+    %0 = alloc() : memref<1x16x14x14xf32, #map_tile>
+    // CHECK: %[[v0:[a-z0-9]*]] = alloc() : memref<1x16x1x1x32x32xf32>
+    %1 = alloc() : memref<1x16x14x14xf32>
+    // CHECK: %[[v1:[a-z0-9]*]] = alloc() : memref<1x16x14x14xf32>
+    "test.op_norm"(%0, %1) : (memref<1x16x14x14xf32, #map_tile>, memref<1x16x14x14xf32>) -> ()
+    // CHECK: "test.op_norm"(%[[v0]], %[[v1]]) : (memref<1x16x1x1x32x32xf32>, memref<1x16x14x14xf32>) -> ()
+    %cst = constant 3.0 : f32
+    affine.for %i = 0 to 1 {
+      affine.for %j = 0 to 16 {
+        affine.for %k = 0 to 14 {
+          affine.for %l = 0 to 14 {
+            %2 = load %1[%i, %j, %k, %l] : memref<1x16x14x14xf32>
+            // CHECK: memref<1x16x14x14xf32>
+            %3 = addf %2, %cst : f32
+            store %3, %arg0[%i, %j, %k, %l] : memref<1x16x14x14xf32>
+            // CHECK: memref<1x16x14x14xf32>
+          }
+        }
+      }
+    }
+    dealloc %0 :  memref<1x16x14x14xf32, #map_tile>
+    // CHECK: dealloc %[[v0]] : memref<1x16x1x1x32x32xf32>
+    dealloc %1 :  memref<1x16x14x14xf32>
+    // CHECK: dealloc %[[v1]] : memref<1x16x14x14xf32>
+    return
+}


        


More information about the Mlir-commits mailing list