[Mlir-commits] [mlir] 2e2c073 - [mlir:Transforms] Move NormalizeMemRefs to MemRef/Transforms/

River Riddle llvmlistbot at llvm.org
Mon Jan 24 19:30:16 PST 2022


Author: River Riddle
Date: 2022-01-24T19:25:53-08:00
New Revision: 2e2c0738e80e9c2b7c1413ca4719d5be2df4c6b5

URL: https://github.com/llvm/llvm-project/commit/2e2c0738e80e9c2b7c1413ca4719d5be2df4c6b5
DIFF: https://github.com/llvm/llvm-project/commit/2e2c0738e80e9c2b7c1413ca4719d5be2df4c6b5.diff

LOG: [mlir:Transforms] Move NormalizeMemRefs to MemRef/Transforms/

Transforms/ should only contain transformations that are dialect-independent, and
this pass interacts with MemRef operations (making it a better fit for living in that
dialect).

Differential Revision: https://reviews.llvm.org/D117841

Added: 
    mlir/lib/Dialect/MemRef/Transforms/NormalizeMemRefs.cpp
    mlir/lib/Dialect/MemRef/Transforms/PassDetail.h

Modified: 
    mlir/include/mlir/Dialect/MemRef/Transforms/Passes.h
    mlir/include/mlir/Dialect/MemRef/Transforms/Passes.td
    mlir/include/mlir/Transforms/Passes.h
    mlir/include/mlir/Transforms/Passes.td
    mlir/lib/Dialect/MemRef/Transforms/CMakeLists.txt
    mlir/lib/Dialect/MemRef/Transforms/ResolveShapedTypeResultDims.cpp
    mlir/lib/Transforms/CMakeLists.txt

Removed: 
    mlir/lib/Transforms/NormalizeMemRefs.cpp


################################################################################
diff  --git a/mlir/include/mlir/Dialect/MemRef/Transforms/Passes.h b/mlir/include/mlir/Dialect/MemRef/Transforms/Passes.h
index 186782c6efdb..23d12508b65c 100644
--- a/mlir/include/mlir/Dialect/MemRef/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/MemRef/Transforms/Passes.h
@@ -55,6 +55,10 @@ void populateResolveShapedTypeResultDimsPatterns(RewritePatternSet &patterns);
 /// load/store ops into `patterns`.
 std::unique_ptr<Pass> createFoldSubViewOpsPass();
 
+/// Creates an interprocedural pass to normalize memrefs to have a trivial
+/// (identity) layout map.
+std::unique_ptr<OperationPass<ModuleOp>> createNormalizeMemRefsPass();
+
 /// Creates an operation pass to resolve `memref.dim` operations with values
 /// that are defined by operations that implement the
 /// `ReifyRankedShapeTypeShapeOpInterface`, in terms of shapes of its input

diff  --git a/mlir/include/mlir/Dialect/MemRef/Transforms/Passes.td b/mlir/include/mlir/Dialect/MemRef/Transforms/Passes.td
index 29984c4fc385..d67746b9c603 100644
--- a/mlir/include/mlir/Dialect/MemRef/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/MemRef/Transforms/Passes.td
@@ -23,6 +23,122 @@ def FoldSubViewOps : Pass<"fold-memref-subview-ops"> {
   ];
 }
 
+def NormalizeMemRefs : Pass<"normalize-memrefs", "ModuleOp"> {
+  let summary = "Normalize memrefs";
+  let description = [{
+    This pass transforms memref types with a non-trivial
+    [layout map](https://mlir.llvm.org/docs/LangRef/#layout-map) into
+    memref types with an identity layout map, e.g. (i, j) -> (i, j). This
+    pass is inter-procedural, in the sense that it can modify function
+    interfaces and call sites that pass memref types. In order to modify
+    memref types while preserving the original behavior, users of those
+    memref types are also modified to incorporate the resulting layout map.
+    For instance, an
+    [AffineLoadOp](https://mlir.llvm.org/docs/Dialects/Affine/#affineload-affineloadop)
+    will be updated to compose the layout map with the affine expression
+    contained in the op. Operations marked with the
+    [MemRefsNormalizable](https://mlir.llvm.org/docs/Traits/#memrefsnormalizable) trait are
+    expected to be normalizable. Supported operations include affine
+    operations, memref.alloc, memref.dealloc, and std.return.
+
+    Given an appropriate layout map specified in the code, this transformation
+    can express tiled or linearized access to multi-dimensional data
+    structures, but will not modify memref types without an explicit layout
+    map.
+
+    Currently this pass is limited to only modify
+    functions where all memref types can be normalized. If a function
+    contains any operations that are not MemRefsNormalizable, then the function
+    and any functions that it calls or that call it will not be modified.
+
+    Input
+
+    ```mlir
+    #tile = affine_map<(i) -> (i floordiv 4, i mod 4)>
+    func @matmul(%A: memref<16xf64, #tile>,
+                 %B: index, %C: memref<16xf64>) -> (memref<16xf64, #tile>) {
+      affine.for %arg3 = 0 to 16 {
+        %a = affine.load %A[%arg3] : memref<16xf64, #tile>
+        %p = arith.mulf %a, %a : f64
+        affine.store %p, %A[%arg3] : memref<16xf64, #tile>
+      }
+      %c = memref.alloc() : memref<16xf64, #tile>
+      %d = affine.load %c[0] : memref<16xf64, #tile>
+      return %A: memref<16xf64, #tile>
+    }
+    ```
+
+    Output
+
+    ```mlir
+    func @matmul(%arg0: memref<4x4xf64>, %arg1: index, %arg2: memref<16xf64>)
+      -> memref<4x4xf64> {
+      affine.for %arg3 = 0 to 16 {
+        %3 = affine.load %arg0[%arg3 floordiv 4, %arg3 mod 4]: memref<4x4xf64>
+        %4 = arith.mulf %3, %3 : f64
+        affine.store %4, %arg0[%arg3 floordiv 4, %arg3 mod 4]: memref<4x4xf64>
+      }
+      %0 = memref.alloc() : memref<4x4xf64>
+      %1 = affine.apply #map1()
+      %2 = affine.load %0[0, 0] : memref<4x4xf64>
+      return %arg0 : memref<4x4xf64>
+    }
+    ```
+
+    Input
+
+    ```mlir
+    #linear8 = affine_map<(i, j) -> (i * 8 + j)>
+    func @linearize(%arg0: memref<8x8xi32, #linear8>,
+                    %arg1: memref<8x8xi32, #linear8>,
+                    %arg2: memref<8x8xi32, #linear8>) {
+      %c8 = arith.constant 8 : index
+      %c0 = arith.constant 0 : index
+      %c1 = arith.constant 1 : index
+      affine.for %arg3 = %c0 to %c8 {
+        affine.for %arg4 = %c0 to %c8 {
+          affine.for %arg5 = %c0 to %c8 {
+            %0 = affine.load %arg0[%arg3, %arg5] : memref<8x8xi32, #linear8>
+            %1 = affine.load %arg1[%arg5, %arg4] : memref<8x8xi32, #linear8>
+            %2 = affine.load %arg2[%arg3, %arg4] : memref<8x8xi32, #linear8>
+            %3 = arith.muli %0, %1 : i32
+            %4 = arith.addi %2, %3 : i32
+            affine.store %4, %arg2[%arg3, %arg4] : memref<8x8xi32, #linear8>
+          }
+        }
+      }
+      return
+    }
+    ```
+
+    Output
+
+    ```mlir
+    func @linearize(%arg0: memref<64xi32>,
+                    %arg1: memref<64xi32>,
+                    %arg2: memref<64xi32>) {
+      %c8 = arith.constant 8 : index
+      %c0 = arith.constant 0 : index
+      affine.for %arg3 = %c0 to %c8 {
+        affine.for %arg4 = %c0 to %c8 {
+          affine.for %arg5 = %c0 to %c8 {
+            %0 = affine.load %arg0[%arg3 * 8 + %arg5] : memref<64xi32>
+            %1 = affine.load %arg1[%arg5 * 8 + %arg4] : memref<64xi32>
+            %2 = affine.load %arg2[%arg3 * 8 + %arg4] : memref<64xi32>
+            %3 = arith.muli %0, %1 : i32
+            %4 = arith.addi %2, %3 : i32
+            affine.store %4, %arg2[%arg3 * 8 + %arg4] : memref<64xi32>
+          }
+        }
+      }
+      return
+    }
+    ```
+  }];
+  let constructor = "mlir::memref::createNormalizeMemRefsPass()";
+  let dependentDialects = ["AffineDialect"];
+}
+
 def ResolveRankedShapeTypeResultDims :
     Pass<"resolve-ranked-shaped-type-result-dims"> {
   let summary = "Resolve memref.dim of result values of ranked shape type";

diff  --git a/mlir/include/mlir/Transforms/Passes.h b/mlir/include/mlir/Transforms/Passes.h
index 222434324451..4876d705afcb 100644
--- a/mlir/include/mlir/Transforms/Passes.h
+++ b/mlir/include/mlir/Transforms/Passes.h
@@ -113,10 +113,6 @@ std::unique_ptr<Pass> createSCCPPass();
 /// pass may *only* be scheduled on an operation that defines a SymbolTable.
 std::unique_ptr<Pass> createSymbolDCEPass();
 
-/// Creates an interprocedural pass to normalize memrefs to have a trivial
-/// (identity) layout map.
-std::unique_ptr<OperationPass<ModuleOp>> createNormalizeMemRefsPass();
-
 //===----------------------------------------------------------------------===//
 // Registration
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/include/mlir/Transforms/Passes.td b/mlir/include/mlir/Transforms/Passes.td
index 9942c0fc8892..44bf475af24c 100644
--- a/mlir/include/mlir/Transforms/Passes.td
+++ b/mlir/include/mlir/Transforms/Passes.td
@@ -351,122 +351,6 @@ def LoopInvariantCodeMotion : Pass<"loop-invariant-code-motion"> {
   let constructor = "mlir::createLoopInvariantCodeMotionPass()";
 }
 
-def NormalizeMemRefs : Pass<"normalize-memrefs", "ModuleOp"> {
-  let summary = "Normalize memrefs";
-   let description = [{
-    This pass transforms memref types with a non-trivial
-    [layout map](https://mlir.llvm.org/docs/LangRef/#layout-map) into
-    memref types with an identity layout map, e.g. (i, j) -> (i, j). This
-    pass is inter-procedural, in the sense that it can modify function
-    interfaces and call sites that pass memref types. In order to modify
-    memref types while preserving the original behavior, users of those
-    memref types are also modified to incorporate the resulting layout map.
-    For instance, an [AffineLoadOp]
-    (https://mlir.llvm.org/docs/Dialects/Affine/#affineload-affineloadop)
-    will be updated to compose the layout map with with the affine expression
-    contained in the op. Operations marked with the [MemRefsNormalizable]
-    (https://mlir.llvm.org/docs/Traits/#memrefsnormalizable) trait are
-    expected to be normalizable. Supported operations include affine
-    operations, memref.alloc, memref.dealloc, and std.return.
-
-    Given an appropriate layout map specified in the code, this transformation
-    can express tiled or linearized access to multi-dimensional data
-    structures, but will not modify memref types without an explicit layout
-    map.
-
-    Currently this pass is limited to only modify
-    functions where all memref types can be normalized. If a function
-    contains any operations that are not MemRefNormalizable, then the function
-    and any functions that call or call it will not be modified.
-
-    Input
-
-    ```mlir
-    #tile = affine_map<(i) -> (i floordiv 4, i mod 4)>
-    func @matmul(%A: memref<16xf64, #tile>,
-                 %B: index, %C: memref<16xf64>) -> (memref<16xf64, #tile>) {
-      affine.for %arg3 = 0 to 16 {
-            %a = affine.load %A[%arg3] : memref<16xf64, #tile>
-            %p = arith.mulf %a, %a : f64
-            affine.store %p, %A[%arg3] : memref<16xf64, #tile>
-      }
-      %c = memref.alloc() : memref<16xf64, #tile>
-      %d = affine.load %c[0] : memref<16xf64, #tile>
-      return %A: memref<16xf64, #tile>
-    }
-    ```
-
-    Output
-
-    ```mlir
-    func @matmul(%arg0: memref<4x4xf64>, %arg1: index, %arg2: memref<16xf64>)
-      -> memref<4x4xf64> {
-      affine.for %arg3 = 0 to 16 {
-        %3 = affine.load %arg0[%arg3 floordiv 4, %arg3 mod 4]: memref<4x4xf64>
-        %4 = arith.mulf %3, %3 : f64
-        affine.store %4, %arg0[%arg3 floordiv 4, %arg3 mod 4]: memref<4x4xf64>
-      }
-      %0 = memref.alloc() : memref<4x4xf64>
-      %1 = affine.apply #map1()
-      %2 = affine.load %0[0, 0] : memref<4x4xf64>
-      return %arg0 : memref<4x4xf64>
-    }
-    ```
-
-    Input
-
-    ```
-    #linear8 = affine_map<(i, j) -> (i * 8 + j)>
-    func @linearize(%arg0: memref<8x8xi32, #linear8>,
-                    %arg1: memref<8x8xi32, #linear8>,
-                    %arg2: memref<8x8xi32, #linear8>) {
-      %c8 = arith.constant 8 : index
-      %c0 = arith.constant 0 : index
-      %c1 = arith.constant 1 : index
-      affine.for %arg3 = %c0 to %c8  {
-      affine.for %arg4 = %c0 to %c8  {
-        affine.for %arg5 = %c0 to %c8 {
-          %0 = affine.load %arg0[%arg3, %arg5] : memref<8x8xi32, #linear8>
-          %1 = affine.load %arg1[%arg5, %arg4] : memref<8x8xi32, #linear8>
-          %2 = affine.load %arg2[%arg3, %arg4] : memref<8x8xi32, #linear8>
-          %3 = arith.muli %0, %1 : i32
-          %4 = arith.addi %2, %3 : i32
-          affine.store %4, %arg2[%arg3, %arg4] : memref<8x8xi32, #linear8>
-        }
-      }
-      }
-      return
-    }
-    ```
-
-    Output
-
-    ```mlir
-    func @linearize(%arg0: memref<64xi32>,
-                    %arg1: memref<64xi32>,
-                    %arg2: memref<64xi32>) {
-    %c8 = arith.constant 8 : index
-    %c0 = arith.constant 0 : index
-    affine.for %arg3 = %c0 to %c8 {
-      affine.for %arg4 = %c0 to %c8 {
-        affine.for %arg5 = %c0 to %c8 {
-          %0 = affine.load %arg0[%arg3 * 8 + %arg5] : memref<64xi32>
-          %1 = affine.load %arg1[%arg5 * 8 + %arg4] : memref<64xi32>
-          %2 = affine.load %arg2[%arg3 * 8 + %arg4] : memref<64xi32>
-          %3 = arith.muli %0, %1 : i32
-          %4 = arith.addi %2, %3 : i32
-          affine.store %4, %arg2[%arg3 * 8 + %arg4] : memref<64xi32>
-        }
-      }
-    }
-    return
-  }
-  ```
-  }];
-  let constructor = "mlir::createNormalizeMemRefsPass()";
-  let dependentDialects = ["AffineDialect"];
-}
-
 def ParallelLoopCollapsing : Pass<"parallel-loop-collapsing"> {
   let summary = "Collapse parallel loops to use less induction variables";
   let constructor = "mlir::createParallelLoopCollapsingPass()";

diff  --git a/mlir/lib/Dialect/MemRef/Transforms/CMakeLists.txt b/mlir/lib/Dialect/MemRef/Transforms/CMakeLists.txt
index 3eda2ded018f..319f9bbb95a3 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/MemRef/Transforms/CMakeLists.txt
@@ -1,5 +1,6 @@
 add_mlir_dialect_library(MLIRMemRefTransforms
   FoldSubViewOps.cpp
+  NormalizeMemRefs.cpp
   ResolveShapedTypeResultDims.cpp
 
   ADDITIONAL_HEADER_DIRS

diff  --git a/mlir/lib/Transforms/NormalizeMemRefs.cpp b/mlir/lib/Dialect/MemRef/Transforms/NormalizeMemRefs.cpp
similarity index 99%
rename from mlir/lib/Transforms/NormalizeMemRefs.cpp
rename to mlir/lib/Dialect/MemRef/Transforms/NormalizeMemRefs.cpp
index 5119c4526364..0b5e49b2df52 100644
--- a/mlir/lib/Transforms/NormalizeMemRefs.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/NormalizeMemRefs.cpp
@@ -14,7 +14,7 @@
 #include "PassDetail.h"
 #include "mlir/Dialect/Affine/IR/AffineOps.h"
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
-#include "mlir/Transforms/Passes.h"
+#include "mlir/Dialect/MemRef/Transforms/Passes.h"
 #include "mlir/Transforms/Utils.h"
 #include "llvm/ADT/SmallSet.h"
 #include "llvm/Support/Debug.h"
@@ -43,7 +43,8 @@ struct NormalizeMemRefs : public NormalizeMemRefsBase<NormalizeMemRefs> {
 
 } // namespace
 
-std::unique_ptr<OperationPass<ModuleOp>> mlir::createNormalizeMemRefsPass() {
+std::unique_ptr<OperationPass<ModuleOp>>
+mlir::memref::createNormalizeMemRefsPass() {
   return std::make_unique<NormalizeMemRefs>();
 }
 

diff  --git a/mlir/lib/Dialect/MemRef/Transforms/PassDetail.h b/mlir/lib/Dialect/MemRef/Transforms/PassDetail.h
new file mode 100644
index 000000000000..d15631526817
--- /dev/null
+++ b/mlir/lib/Dialect/MemRef/Transforms/PassDetail.h
@@ -0,0 +1,43 @@
+//===- PassDetail.h - MemRef Pass class details -----------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef DIALECT_MEMREF_TRANSFORMS_PASSDETAIL_H_
+#define DIALECT_MEMREF_TRANSFORMS_PASSDETAIL_H_
+
+#include "mlir/Pass/Pass.h"
+
+namespace mlir {
+
+class AffineDialect;
+
+// Forward declaration from Dialect.h
+template <typename ConcreteDialect>
+void registerDialect(DialectRegistry &registry);
+
+namespace arith {
+class ArithmeticDialect;
+} // namespace arith
+
+namespace memref {
+class MemRefDialect;
+} // namespace memref
+
+namespace tensor {
+class TensorDialect;
+} // namespace tensor
+
+namespace vector {
+class VectorDialect;
+} // namespace vector
+
+#define GEN_PASS_CLASSES
+#include "mlir/Dialect/MemRef/Transforms/Passes.h.inc"
+
+} // namespace mlir
+
+#endif // DIALECT_MEMREF_TRANSFORMS_PASSDETAIL_H_

diff  --git a/mlir/lib/Dialect/MemRef/Transforms/ResolveShapedTypeResultDims.cpp b/mlir/lib/Dialect/MemRef/Transforms/ResolveShapedTypeResultDims.cpp
index 60f82f3b9e4b..3f6aeeb69641 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/ResolveShapedTypeResultDims.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/ResolveShapedTypeResultDims.cpp
@@ -11,6 +11,7 @@
 //
 //===----------------------------------------------------------------------===//
 
+#include "PassDetail.h"
 #include "mlir/Dialect/Affine/IR/AffineOps.h"
 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
@@ -107,9 +108,6 @@ struct DimOfReifyRankedShapedTypeOpInterface : public OpRewritePattern<OpTy> {
 //===----------------------------------------------------------------------===//
 
 namespace {
-#define GEN_PASS_CLASSES
-#include "mlir/Dialect/MemRef/Transforms/Passes.h.inc"
-
 struct ResolveRankedShapeTypeResultDimsPass final
     : public ResolveRankedShapeTypeResultDimsBase<
           ResolveRankedShapeTypeResultDimsPass> {

diff  --git a/mlir/lib/Transforms/CMakeLists.txt b/mlir/lib/Transforms/CMakeLists.txt
index 1eba1ad94e5a..7826650f5747 100644
--- a/mlir/lib/Transforms/CMakeLists.txt
+++ b/mlir/lib/Transforms/CMakeLists.txt
@@ -9,7 +9,6 @@ add_mlir_library(MLIRTransforms
   LoopCoalescing.cpp
   LoopFusion.cpp
   LoopInvariantCodeMotion.cpp
-  NormalizeMemRefs.cpp
   OpStats.cpp
   ParallelLoopCollapsing.cpp
   PipelineDataTransfer.cpp


        


More information about the Mlir-commits mailing list