[Mlir-commits] [mlir] 47df8c5 - [MLIR] Updates around MemRef Normalization

Stephen Neuendorffer llvmlistbot at llvm.org
Thu Oct 1 21:24:30 PDT 2020


Author: Stephen Neuendorffer
Date: 2020-10-01T21:11:41-07:00
New Revision: 47df8c57e4ed01fa0101aa0b320fc7cf5a90df28

URL: https://github.com/llvm/llvm-project/commit/47df8c57e4ed01fa0101aa0b320fc7cf5a90df28
DIFF: https://github.com/llvm/llvm-project/commit/47df8c57e4ed01fa0101aa0b320fc7cf5a90df28.diff

LOG: [MLIR] Updates around MemRef Normalization

The documentation for the NormalizeMemRefs pass and the associated MemRefsNormalizable
trait was confusing and not on the website.  This update clarifies the difference
between a MemRef type and an operation that accesses a value of MemRef type, and
better documents the limitations of the current implementation.
This patch also includes some basic debugging information for the pass so people
might have a chance of figuring out why it doesn't work on their code.

Differential Revision: https://reviews.llvm.org/D88532

Added: 
    

Modified: 
    mlir/docs/Traits.md
    mlir/include/mlir/IR/OpDefinition.h
    mlir/include/mlir/Transforms/Passes.td
    mlir/lib/Transforms/NormalizeMemRefs.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/docs/Traits.md b/mlir/docs/Traits.md
index 3fa56249ae42..488da39e6504 100644
--- a/mlir/docs/Traits.md
+++ b/mlir/docs/Traits.md
@@ -251,13 +251,15 @@ to have [passes](PassManagement.md) scheduled under them.
 
 * `OpTrait::MemRefsNormalizable` -- `MemRefsNormalizable`
 
-This trait is used to flag operations that can accommodate `MemRefs` with
-non-identity memory-layout specifications. This trait indicates that the
-normalization of memory layout can be performed for such operations.
-`MemRefs` normalization consists of replacing an original memory reference
-with layout specifications to an equivalent memory reference where
-the specified memory layout is applied by rewritting accesses and types
-associated with that memory reference.
+This trait is used to flag operations that consume or produce
+values of `MemRef` type where those references can be 'normalized'.
+In cases where an associated `MemRef` has a
+non-identity memory-layout specification, such normalizable operations can be
+modified so that the `MemRef` has an identity layout specification.
+This can be implemented by associating the operation with its own
+index expression that can express the equivalent of the memory-layout
+specification of the `MemRef` type. See the
+[-normalize-memrefs pass](https://mlir.llvm.org/docs/Passes/#-normalize-memrefs-normalize-memrefs).
 
 ### Single Block with Implicit Terminator
 

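The arithmetic behind normalization can be checked outside MLIR. The following is a minimal, hypothetical C++ sketch. It does not use the MLIR `AffineMap` API, and `floorDiv`, `floorMod`, `applyTileMap`, and `tileMapIsBijective` are illustrative names. It evaluates `affine_map<(i) -> (i floordiv 4, i mod 4)>` by hand to show how an op that indexes a `memref<16xf64>` with this layout could instead index a `memref<4x4xf64>` with an identity layout, by folding the map into its own subscripts.

```cpp
#include <array>
#include <cassert>
#include <cstdint>

// MLIR's `floordiv` rounds toward negative infinity and `mod` is non-negative
// for a positive divisor. C++ `/` and `%` truncate toward zero, so correct the
// quotient when the operands have opposite signs and the division is inexact.
int64_t floorDiv(int64_t a, int64_t b) {
  int64_t q = a / b;
  if ((a % b != 0) && ((a < 0) != (b < 0)))
    --q;
  return q;
}

int64_t floorMod(int64_t a, int64_t b) { return a - floorDiv(a, b) * b; }

// Hypothetical helper: applies the tiled layout map
//   (i) -> (i floordiv 4, i mod 4)
// to one logical index, yielding subscripts into the normalized 4x4 buffer.
std::array<int64_t, 2> applyTileMap(int64_t i) {
  return {floorDiv(i, 4), floorMod(i, 4)};
}

// Checks that all 16 logical indices land on distinct, in-bounds elements of
// the 4x4 buffer, i.e. the rewritten accesses neither alias nor go out of range.
bool tileMapIsBijective() {
  bool seen[4][4] = {};
  for (int64_t i = 0; i < 16; ++i) {
    auto [r, c] = applyTileMap(i);
    if (r < 0 || r >= 4 || c < 0 || c >= 4 || seen[r][c])
      return false;
    seen[r][c] = true;
  }
  return true;
}
```

For non-negative indices, C++ `/` and `%` already agree with `floordiv` and `mod`. The correction in `floorDiv` only matters for negative operands.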
diff  --git a/mlir/include/mlir/IR/OpDefinition.h b/mlir/include/mlir/IR/OpDefinition.h
index 9f3df4343261..6861523e0d04 100644
--- a/mlir/include/mlir/IR/OpDefinition.h
+++ b/mlir/include/mlir/IR/OpDefinition.h
@@ -1212,13 +1212,8 @@ struct NoRegionArguments : public TraitBase<ConcrentType, NoRegionArguments> {
   }
 };
 
-/// This trait is used to flag operations that can accommodate MemRefs with
-/// non-identity memory-layout specifications. This trait indicates that the
-/// normalization of memory layout can be performed for such operations.
-/// MemRefs normalization consists of replacing an original memory reference
-/// with layout specifications to an equivalent memory reference where the
-/// specified memory layout is applied by rewritting accesses and types
-/// associated with that memory reference.
+/// This trait is used to flag operations that consume or produce
+/// values of `MemRef` type where those references can be 'normalized'.
 // TODO: Right now, the operands of an operation are either all normalizable,
 // or not. In the future, we may want to allow some of the operands to be
 // normalizable.

diff  --git a/mlir/include/mlir/Transforms/Passes.td b/mlir/include/mlir/Transforms/Passes.td
index 3292d5e7dec2..367e19cfcd55 100644
--- a/mlir/include/mlir/Transforms/Passes.td
+++ b/mlir/include/mlir/Transforms/Passes.td
@@ -313,6 +313,116 @@ def MemRefDataFlowOpt : FunctionPass<"memref-dataflow-opt"> {
 
 def NormalizeMemRefs : Pass<"normalize-memrefs", "ModuleOp"> {
   let summary = "Normalize memrefs";
+  let description = [{
+    This pass transforms memref types with a non-trivial
+    [layout map](https://mlir.llvm.org/docs/LangRef/#layout-map) into
+    memref types with an identity layout map, e.g. (i, j) -> (i, j). This
+    pass is inter-procedural, in the sense that it can modify function
+    interfaces and call sites that pass memref types. In order to modify
+    memref types while preserving the original behavior, users of those
+    memref types are also modified to incorporate the resulting layout map.
+    For instance, an
+    [AffineLoadOp](https://mlir.llvm.org/docs/Dialects/Affine/#affineload-affineloadop)
+    will be updated to compose the layout map with the affine expression
+    contained in the op. Operations marked with the
+    [MemRefsNormalizable](https://mlir.llvm.org/docs/Traits/#memrefsnormalizable) trait are
+    expected to be normalizable. Supported operations include affine
+    operations, std.alloc, std.dealloc, and std.return.
+
+    Given an appropriate layout map specified in the code, this transformation
+    can express tiled or linearized access to multi-dimensional data
+    structures, but will not modify memref types without an explicit layout
+    map.
+
+    Currently, this pass only modifies functions in which all memref types
+    can be normalized. If a function contains any operations that are not
+    MemRefsNormalizable, then neither that function nor any function it
+    calls or that calls it will be modified.
+
+    Input
+
+    ```mlir
+    #tile = affine_map<(i) -> (i floordiv 4, i mod 4)>
+    func @matmul(%A: memref<16xf64, #tile>,
+                 %B: index, %C: memref<16xf64>) -> (memref<16xf64, #tile>) {
+      affine.for %arg3 = 0 to 16 {
+        %a = affine.load %A[%arg3] : memref<16xf64, #tile>
+        %p = mulf %a, %a : f64
+        affine.store %p, %A[%arg3] : memref<16xf64, #tile>
+      }
+      %c = alloc() : memref<16xf64, #tile>
+      %d = affine.load %c[0] : memref<16xf64, #tile>
+      return %A: memref<16xf64, #tile>
+    }
+    ```
+
+    Output
+
+    ```mlir
+    func @matmul(%arg0: memref<4x4xf64>, %arg1: index, %arg2: memref<16xf64>)
+      -> memref<4x4xf64> {
+      affine.for %arg3 = 0 to 16 {
+        %3 = affine.load %arg0[%arg3 floordiv 4, %arg3 mod 4]: memref<4x4xf64>
+        %4 = mulf %3, %3 : f64
+        affine.store %4, %arg0[%arg3 floordiv 4, %arg3 mod 4]: memref<4x4xf64>
+      }
+      %0 = alloc() : memref<4x4xf64>
+      %1 = affine.apply #map1()
+      %2 = affine.load %0[0, 0] : memref<4x4xf64>
+      return %arg0 : memref<4x4xf64>
+    }
+    ```
+
+    Input
+
+    ```mlir
+    #linear8 = affine_map<(i, j) -> (i * 8 + j)>
+    func @linearize(%arg0: memref<8x8xi32, #linear8>,
+                    %arg1: memref<8x8xi32, #linear8>,
+                    %arg2: memref<8x8xi32, #linear8>) {
+      %c8 = constant 8 : index
+      %c0 = constant 0 : index
+      %c1 = constant 1 : index
+      affine.for %arg3 = %c0 to %c8 {
+        affine.for %arg4 = %c0 to %c8 {
+          affine.for %arg5 = %c0 to %c8 {
+            %0 = affine.load %arg0[%arg3, %arg5] : memref<8x8xi32, #linear8>
+            %1 = affine.load %arg1[%arg5, %arg4] : memref<8x8xi32, #linear8>
+            %2 = affine.load %arg2[%arg3, %arg4] : memref<8x8xi32, #linear8>
+            %3 = muli %0, %1 : i32
+            %4 = addi %2, %3 : i32
+            affine.store %4, %arg2[%arg3, %arg4] : memref<8x8xi32, #linear8>
+          }
+        }
+      }
+      return
+    }
+    ```
+
+    Output
+
+    ```mlir
+    func @linearize(%arg0: memref<64xi32>,
+                    %arg1: memref<64xi32>,
+                    %arg2: memref<64xi32>) {
+      %c8 = constant 8 : index
+      %c0 = constant 0 : index
+      affine.for %arg3 = %c0 to %c8 {
+        affine.for %arg4 = %c0 to %c8 {
+          affine.for %arg5 = %c0 to %c8 {
+            %0 = affine.load %arg0[%arg3 * 8 + %arg5] : memref<64xi32>
+            %1 = affine.load %arg1[%arg5 * 8 + %arg4] : memref<64xi32>
+            %2 = affine.load %arg2[%arg3 * 8 + %arg4] : memref<64xi32>
+            %3 = muli %0, %1 : i32
+            %4 = addi %2, %3 : i32
+            affine.store %4, %arg2[%arg3 * 8 + %arg4] : memref<64xi32>
+          }
+        }
+      }
+      return
+    }
+    ```
+  }];
   let constructor = "mlir::createNormalizeMemRefsPass()";
 }
 

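The `@linearize` example in the pass description implies that the rewritten loop nest computes the same result as the original. A minimal C++ model of that claim can run both loop nests on identical data and compare the results. The types and functions below (`Matrix2D`, `Matrix1D`, `linear8`, `matmul2D`, `matmul1D`, `linearizedMatchesOriginal`) are hypothetical stand-ins that model the memrefs as `std::array`. This is not MLIR code.

```cpp
#include <array>
#include <cassert>
#include <cstdint>

constexpr int N = 8;

// Logical view of memref<8x8xi32, #linear8>: indexed by (row, col).
using Matrix2D = std::array<std::array<int32_t, N>, N>;

// Normalized view: the same data stored as a flat memref<64xi32>.
using Matrix1D = std::array<int32_t, N * N>;

// The layout map #linear8 = affine_map<(i, j) -> (i * 8 + j)>.
constexpr int linear8(int i, int j) { return i * N + j; }

// Loop nest of the "Input" example: C[i][j] += A[i][k] * B[k][j].
void matmul2D(const Matrix2D &a, const Matrix2D &b, Matrix2D &c) {
  for (int i = 0; i < N; ++i)
    for (int j = 0; j < N; ++j)
      for (int k = 0; k < N; ++k)
        c[i][j] += a[i][k] * b[k][j];
}

// Loop nest of the "Output" example: every subscript goes through linear8.
void matmul1D(const Matrix1D &a, const Matrix1D &b, Matrix1D &c) {
  for (int i = 0; i < N; ++i)
    for (int j = 0; j < N; ++j)
      for (int k = 0; k < N; ++k)
        c[linear8(i, j)] += a[linear8(i, k)] * b[linear8(k, j)];
}

// Fills both representations with the same values, runs both loop nests, and
// reports whether every element of the results agrees.
bool linearizedMatchesOriginal() {
  Matrix2D a2{}, b2{}, c2{};
  Matrix1D a1{}, b1{}, c1{};
  for (int i = 0; i < N; ++i)
    for (int j = 0; j < N; ++j) {
      a2[i][j] = a1[linear8(i, j)] = i + j;
      b2[i][j] = b1[linear8(i, j)] = i - j;
      c2[i][j] = c1[linear8(i, j)] = i * j;
    }
  matmul2D(a2, b2, c2);
  matmul1D(a1, b1, c1);
  for (int i = 0; i < N; ++i)
    for (int j = 0; j < N; ++j)
      if (c2[i][j] != c1[linear8(i, j)])
        return false;
  return true;
}
```

The check relies only on `linear8` mapping the 8x8 index space one-to-one onto 0..63, which is the property that makes this particular rewrite from `memref<8x8xi32, #linear8>` to `memref<64xi32>` safe.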
diff  --git a/mlir/lib/Transforms/NormalizeMemRefs.cpp b/mlir/lib/Transforms/NormalizeMemRefs.cpp
index ac02f0e6ba97..44b3ccbd2c3f 100644
--- a/mlir/lib/Transforms/NormalizeMemRefs.cpp
+++ b/mlir/lib/Transforms/NormalizeMemRefs.cpp
@@ -29,34 +29,6 @@ namespace {
 /// such functions as normalizable. Also, if a normalizable function is known
 /// to call a non-normalizable function, we treat that function as
 /// non-normalizable as well. We assume external functions to be normalizable.
-///
-/// Input :-
-/// #tile = affine_map<(i) -> (i floordiv 4, i mod 4)>
-/// func @matmul(%A: memref<16xf64, #tile>, %B: index, %C: memref<16xf64>) ->
-/// (memref<16xf64, #tile>) {
-///   affine.for %arg3 = 0 to 16 {
-///         %a = affine.load %A[%arg3] : memref<16xf64, #tile>
-///         %p = mulf %a, %a : f64
-///         affine.store %p, %A[%arg3] : memref<16xf64, #tile>
-///   }
-///   %c = alloc() : memref<16xf64, #tile>
-///   %d = affine.load %c[0] : memref<16xf64, #tile>
-///   return %A: memref<16xf64, #tile>
-/// }
-///
-/// Output :-
-///   func @matmul(%arg0: memref<4x4xf64>, %arg1: index, %arg2: memref<16xf64>)
-///   -> memref<4x4xf64> {
-///     affine.for %arg3 = 0 to 16 {
-///       %2 = affine.load %arg0[%arg3 floordiv 4, %arg3 mod 4] :
-///       memref<4x4xf64> %3 = mulf %2, %2 : f64 affine.store %3, %arg0[%arg3
-///       floordiv 4, %arg3 mod 4] : memref<4x4xf64>
-///     }
-///     %0 = alloc() : memref<16xf64, #map0>
-///     %1 = affine.load %0[0] : memref<16xf64, #map0>
-///     return %arg0 : memref<4x4xf64>
-///   }
-///
 struct NormalizeMemRefs : public NormalizeMemRefsBase<NormalizeMemRefs> {
   void runOnOperation() override;
   void normalizeFuncOpMemRefs(FuncOp funcOp, ModuleOp moduleOp);
@@ -73,6 +45,7 @@ std::unique_ptr<OperationPass<ModuleOp>> mlir::createNormalizeMemRefsPass() {
 }
 
 void NormalizeMemRefs::runOnOperation() {
+  LLVM_DEBUG(llvm::dbgs() << "Normalizing MemRefs...\n");
   ModuleOp moduleOp = getOperation();
   // We maintain all normalizable FuncOps in a DenseSet. It is initialized
   // with all the functions within a module and then functions which are not
@@ -92,6 +65,9 @@ void NormalizeMemRefs::runOnOperation() {
   moduleOp.walk([&](FuncOp funcOp) {
     if (normalizableFuncs.contains(funcOp)) {
       if (!areMemRefsNormalizable(funcOp)) {
+        LLVM_DEBUG(llvm::dbgs()
+                   << "@" << funcOp.getName()
+                   << " contains ops that cannot normalize MemRefs\n");
         // Since this function is not normalizable, we set all the caller
         // functions and the callees of this function as not normalizable.
         // TODO: Drop this conservative assumption in the future.
@@ -101,6 +77,8 @@ void NormalizeMemRefs::runOnOperation() {
     }
   });
 
+  LLVM_DEBUG(llvm::dbgs() << "Normalizing " << normalizableFuncs.size()
+                          << " functions\n");
   // Those functions which can be normalized are subjected to normalization.
   for (FuncOp &funcOp : normalizableFuncs)
     normalizeFuncOpMemRefs(funcOp, moduleOp);
@@ -127,6 +105,9 @@ void NormalizeMemRefs::setCalleesAndCallersNonNormalizable(
   if (!normalizableFuncs.contains(funcOp))
     return;
 
+  LLVM_DEBUG(
+      llvm::dbgs() << "@" << funcOp.getName()
+                   << " calls or is called by non-normalizable function\n");
   normalizableFuncs.erase(funcOp);
   // Caller of the function.
   Optional<SymbolTable::UseRange> symbolUses = funcOp.getSymbolUses(moduleOp);

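The conservative policy in `runOnOperation` and `setCalleesAndCallersNonNormalizable` can be modelled outside MLIR. Under this policy, a function that fails `areMemRefsNormalizable` removes its callers and callees from the normalizable set, recursively. The sketch below is a hypothetical plain-C++ stand-in, not pass code. Functions are strings, a `CallGraph` map replaces the symbol-use queries and `CallOp` walks, and `markNonNormalizable` and `computeNormalizable` are illustrative names.

```cpp
#include <cassert>
#include <map>
#include <set>
#include <string>
#include <vector>

// Hypothetical, simplified call graph: caller name -> names of its callees.
using CallGraph = std::map<std::string, std::vector<std::string>>;

// Removes `func` from the normalizable set, then recurses into every callee
// and every caller. Returning early for functions already removed is what
// terminates the recursion on cyclic call graphs.
void markNonNormalizable(const std::string &func, const CallGraph &calls,
                         std::set<std::string> &normalizable) {
  if (normalizable.erase(func) == 0)
    return; // Already marked non-normalizable.
  // Functions called by `func`.
  auto it = calls.find(func);
  if (it != calls.end())
    for (const std::string &callee : it->second)
      markNonNormalizable(callee, calls, normalizable);
  // Functions that call `func`; a linear scan stands in for symbol uses.
  for (const auto &[caller, callees] : calls)
    for (const std::string &callee : callees)
      if (callee == func)
        markNonNormalizable(caller, calls, normalizable);
}

// Starts with every function marked normalizable, then disqualifies each one
// containing an op that cannot be normalized (`hasBadOp`), plus everything
// reachable from it through call edges in either direction.
std::set<std::string> computeNormalizable(const std::set<std::string> &allFuncs,
                                          const CallGraph &calls,
                                          const std::set<std::string> &hasBadOp) {
  std::set<std::string> normalizable = allFuncs;
  for (const std::string &f : allFuncs)
    if (hasBadOp.count(f))
      markNonNormalizable(f, calls, normalizable);
  return normalizable;
}
```

Because the recursion follows call edges in both directions, one function with an unsupported op disqualifies its whole connected component of the call graph. This is why a single such op can leave most of a module untouched. The `LLVM_DEBUG` messages added in this patch report these removals.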