[Mlir-commits] [mlir] a14a280 - [MLIR] MemRef Normalization for Dialects

Uday Bondhugula llvmlistbot at llvm.org
Thu Aug 27 07:59:04 PDT 2020


Author: Alexandre E. Eichenberger
Date: 2020-08-27T20:26:59+05:30
New Revision: a14a2805b04d49bfbbff6f79f141738c67ad14fd

URL: https://github.com/llvm/llvm-project/commit/a14a2805b04d49bfbbff6f79f141738c67ad14fd
DIFF: https://github.com/llvm/llvm-project/commit/a14a2805b04d49bfbbff6f79f141738c67ad14fd.diff

LOG: [MLIR] MemRef Normalization for Dialects

When dealing with dialects that will result in function calls to
external libraries, it is important to be able to handle maps, as some
dialects may require mapped data.  Before this patch, to detect
whether normalization can apply or not, operations were compared to an
explicit list of operations (`dealloc`, `call`, `return`) or checked for
specific operation interfaces and ops (`AffineReadOpInterface`,
`AffineWriteOpInterface`, `AffineDmaStartOp`, or `AffineDmaWaitOp`).

This patch adds a trait, `MemRefsNormalizable`, to determine whether an
operation can have its `memrefs` normalized.

This trait can be used in turn by dialects to assert that such
operations are compatible with normalization of `memrefs` with
nontrivial memory layout specifications. An example is given in the
lit tests.

Differential Revision: https://reviews.llvm.org/D86236

Added: 
    mlir/test/Transforms/normalize-memrefs-ops.mlir

Modified: 
    mlir/docs/Traits.md
    mlir/include/mlir/Dialect/Affine/IR/AffineOps.h
    mlir/include/mlir/Dialect/Affine/IR/AffineOps.td
    mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
    mlir/include/mlir/IR/OpBase.td
    mlir/include/mlir/IR/OpDefinition.h
    mlir/lib/Transforms/NormalizeMemRefs.cpp
    mlir/lib/Transforms/Utils/Utils.cpp
    mlir/test/lib/Dialect/Test/TestOps.td

Removed: 
    


################################################################################
diff  --git a/mlir/docs/Traits.md b/mlir/docs/Traits.md
index c9ef132717b8..5867f220e97b 100644
--- a/mlir/docs/Traits.md
+++ b/mlir/docs/Traits.md
@@ -247,6 +247,18 @@ foo.region_op {
 This trait is an important structural property of the IR, and enables operations
 to have [passes](PassManagement.md) scheduled under them.
 
+### MemRefsNormalizable
+
+* `OpTrait::MemRefsNormalizable` -- `MemRefsNormalizable`
+
+This trait is used to flag operations that can accommodate `MemRefs` with
+non-identity memory-layout specifications. This trait indicates that the
+normalization of memory layout can be performed for such operations.
+`MemRefs` normalization consists of replacing an original memory reference
+that has a layout specification with an equivalent memory reference in which
+the specified memory layout is applied, by rewriting the accesses and types
+associated with that memory reference.
+
 ### Single Block with Implicit Terminator
 
 *   `OpTrait::SingleBlockImplicitTerminator<typename TerminatorOpType>` :

diff  --git a/mlir/include/mlir/Dialect/Affine/IR/AffineOps.h b/mlir/include/mlir/Dialect/Affine/IR/AffineOps.h
index b88028115a6f..b8b29ff63355 100644
--- a/mlir/include/mlir/Dialect/Affine/IR/AffineOps.h
+++ b/mlir/include/mlir/Dialect/Affine/IR/AffineOps.h
@@ -80,8 +80,9 @@ bool isTopLevelValue(Value value);
 // multiple stride levels (possibly using AffineMaps to specify multiple levels
 // of striding).
 // TODO: Consider replacing src/dst memref indices with view memrefs.
-class AffineDmaStartOp : public Op<AffineDmaStartOp, OpTrait::VariadicOperands,
-                                   OpTrait::ZeroResult> {
+class AffineDmaStartOp
+    : public Op<AffineDmaStartOp, OpTrait::MemRefsNormalizable,
+                OpTrait::VariadicOperands, OpTrait::ZeroResult> {
 public:
   using Op::Op;
 
@@ -268,8 +269,9 @@ class AffineDmaStartOp : public Op<AffineDmaStartOp, OpTrait::VariadicOperands,
 //   ...
 //   affine.dma_wait %tag[%index], %num_elements : memref<1xi32, 2>
 //
-class AffineDmaWaitOp : public Op<AffineDmaWaitOp, OpTrait::VariadicOperands,
-                                  OpTrait::ZeroResult> {
+class AffineDmaWaitOp
+    : public Op<AffineDmaWaitOp, OpTrait::MemRefsNormalizable,
+                OpTrait::VariadicOperands, OpTrait::ZeroResult> {
 public:
   using Op::Op;
 

diff  --git a/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td b/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td
index e5273a8c9662..480e1717c588 100644
--- a/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td
+++ b/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td
@@ -405,7 +405,8 @@ def AffineIfOp : Affine_Op<"if",
 
 class AffineLoadOpBase<string mnemonic, list<OpTrait> traits = []> :
     Affine_Op<mnemonic, !listconcat(traits,
-        [DeclareOpInterfaceMethods<AffineReadOpInterface>])> {
+        [DeclareOpInterfaceMethods<AffineReadOpInterface>,
+        MemRefsNormalizable])> {
   let arguments = (ins Arg<AnyMemRef, "the reference to load from",
       [MemRead]>:$memref,
       Variadic<Index>:$indices);
@@ -732,7 +733,8 @@ def AffinePrefetchOp : Affine_Op<"prefetch"> {
 
 class AffineStoreOpBase<string mnemonic, list<OpTrait> traits = []> :
     Affine_Op<mnemonic, !listconcat(traits,
-        [DeclareOpInterfaceMethods<AffineWriteOpInterface>])> {
+    [DeclareOpInterfaceMethods<AffineWriteOpInterface>,
+    MemRefsNormalizable])> {
   code extraClassDeclarationBase = [{
     /// Returns the operand index of the value to be stored.
     unsigned getStoredValOperandIndex() { return 0; }

diff  --git a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
index b80da2958019..063c34ceedbd 100644
--- a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
+++ b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
@@ -658,7 +658,7 @@ def BranchOp : Std_Op<"br",
 // CallOp
 //===----------------------------------------------------------------------===//
 
-def CallOp : Std_Op<"call", [CallOpInterface]> {
+def CallOp : Std_Op<"call", [CallOpInterface, MemRefsNormalizable]> {
   let summary = "call operation";
   let description = [{
     The `call` operation represents a direct call to a function that is within
@@ -1388,7 +1388,8 @@ def SinOp : FloatUnaryOp<"sin"> {
 // DeallocOp
 //===----------------------------------------------------------------------===//
 
-def DeallocOp : Std_Op<"dealloc", [MemoryEffects<[MemFree]>]> {
+def DeallocOp : Std_Op<"dealloc",
+    [MemoryEffects<[MemFree]>, MemRefsNormalizable]> {
   let summary = "memory deallocation operation";
   let description = [{
     The `dealloc` operation frees the region of memory referenced by a memref
@@ -2144,8 +2145,8 @@ def RemFOp : FloatArithmeticOp<"remf"> {
 // ReturnOp
 //===----------------------------------------------------------------------===//
 
-def ReturnOp : Std_Op<"return", [NoSideEffect, HasParent<"FuncOp">, ReturnLike,
-                                 Terminator]> {
+def ReturnOp : Std_Op<"return", [NoSideEffect, HasParent<"FuncOp">,
+                                MemRefsNormalizable, ReturnLike, Terminator]> {
   let summary = "return operation";
   let description = [{
     The `return` operation represents a return operation within a function.

diff  --git a/mlir/include/mlir/IR/OpBase.td b/mlir/include/mlir/IR/OpBase.td
index a28410f028d5..8375f2416062 100644
--- a/mlir/include/mlir/IR/OpBase.td
+++ b/mlir/include/mlir/IR/OpBase.td
@@ -1698,6 +1698,9 @@ def SameOperandsAndResultElementType :
   NativeOpTrait<"SameOperandsAndResultElementType">;
 // Op is a terminator.
 def Terminator : NativeOpTrait<"IsTerminator">;
+// Op can be safely normalized in the presence of MemRefs with
+// non-identity maps.
+def MemRefsNormalizable : NativeOpTrait<"MemRefsNormalizable">;
 
 // Op's regions have a single block with the specified terminator.
 class SingleBlockImplicitTerminator<string op>

diff  --git a/mlir/include/mlir/IR/OpDefinition.h b/mlir/include/mlir/IR/OpDefinition.h
index db77935e1325..9579c8121536 100644
--- a/mlir/include/mlir/IR/OpDefinition.h
+++ b/mlir/include/mlir/IR/OpDefinition.h
@@ -1212,6 +1212,20 @@ struct NoRegionArguments : public TraitBase<ConcrentType, NoRegionArguments> {
   }
 };
 
+/// This trait is used to flag operations that can accommodate MemRefs with
+/// non-identity memory-layout specifications. This trait indicates that the
+/// normalization of memory layout can be performed for such operations.
+/// MemRefs normalization consists of replacing an original memory reference
+/// that has a layout specification with an equivalent memory reference in
+/// which the specified memory layout is applied, by rewriting the accesses
+/// and types associated with that memory reference.
+// TODO: Right now, the operands of an operation are either all normalizable,
+// or not. In the future, we may want to allow some of the operands to be
+// normalizable.
+template <typename ConcrentType>
+struct MemRefsNormalizable
+    : public TraitBase<ConcrentType, MemRefsNormalizable> {};
+
 } // end namespace OpTrait
 
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/lib/Transforms/NormalizeMemRefs.cpp b/mlir/lib/Transforms/NormalizeMemRefs.cpp
index 1736fa989a83..c4f91eb7a9d2 100644
--- a/mlir/lib/Transforms/NormalizeMemRefs.cpp
+++ b/mlir/lib/Transforms/NormalizeMemRefs.cpp
@@ -106,23 +106,15 @@ void NormalizeMemRefs::runOnOperation() {
     normalizeFuncOpMemRefs(funcOp, moduleOp);
 }
 
-/// Return true if this operation dereferences one or more memref's.
-/// TODO: Temporary utility, will be replaced when this is modeled through
-/// side-effects/op traits.
-static bool isMemRefDereferencingOp(Operation &op) {
-  return isa<AffineReadOpInterface, AffineWriteOpInterface, AffineDmaStartOp,
-             AffineDmaWaitOp>(op);
-}
-
 /// Check whether all the uses of oldMemRef are either dereferencing uses or the
 /// op is of type : DeallocOp, CallOp or ReturnOp. Only if these constraints
 /// are satisfied will the value become a candidate for replacement.
 /// TODO: Extend this for DimOps.
 static bool isMemRefNormalizable(Value::user_range opUsers) {
   if (llvm::any_of(opUsers, [](Operation *op) {
-        if (isMemRefDereferencingOp(*op))
+        if (op->hasTrait<OpTrait::MemRefsNormalizable>())
           return false;
-        return !isa<DeallocOp, CallOp, ReturnOp>(*op);
+        return true;
       }))
     return false;
   return true;

diff  --git a/mlir/lib/Transforms/Utils/Utils.cpp b/mlir/lib/Transforms/Utils/Utils.cpp
index c310702745a2..516f8c060a93 100644
--- a/mlir/lib/Transforms/Utils/Utils.cpp
+++ b/mlir/lib/Transforms/Utils/Utils.cpp
@@ -279,7 +279,7 @@ LogicalResult mlir::replaceAllMemRefUsesWith(
       // Currently we support the following non-dereferencing ops to be a
       // candidate for replacement: Dealloc, CallOp and ReturnOp.
       // TODO: Add support for other kinds of ops.
-      if (!isa<DeallocOp, CallOp, ReturnOp>(*op))
+      if (!op->hasTrait<OpTrait::MemRefsNormalizable>())
         return failure();
     }
 

diff  --git a/mlir/test/Transforms/normalize-memrefs-ops.mlir b/mlir/test/Transforms/normalize-memrefs-ops.mlir
new file mode 100644
index 000000000000..8ce841e0d692
--- /dev/null
+++ b/mlir/test/Transforms/normalize-memrefs-ops.mlir
@@ -0,0 +1,57 @@
+// RUN: mlir-opt -normalize-memrefs %s | FileCheck %s
+
+// For all these cases, we test whether MemRefs normalization works with the
+// test operations.
+// * test.op_norm: this operation has the MemRefsNormalizable trait. The tests
+//   that include this operation are constructed so that the normalization should
+//   happen.
+// * test.op_nonnorm: this operation does not have the MemRefsNormalizable
+//   trait. The tests that include this operation are constructed so that the
+//   normalization should not happen.
+
+#map0 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2 floordiv 32, d3 floordiv 64, d2 mod 32, d3 mod 64)>
+
+// Test with op_norm and maps in arguments and in the operations in the function.
+
+// CHECK-LABEL: test_norm
+// CHECK-SAME: (%[[ARG0:[a-z0-9]*]]: memref<1x16x1x1x32x64xf32>)
+func @test_norm(%arg0 : memref<1x16x14x14xf32, #map0>) -> () {
+    %0 = alloc() : memref<1x16x14x14xf32, #map0>
+    "test.op_norm"(%arg0, %0) : (memref<1x16x14x14xf32, #map0>, memref<1x16x14x14xf32, #map0>) -> ()
+    dealloc %0 :  memref<1x16x14x14xf32, #map0>
+
+    // CHECK: %[[v0:[a-z0-9]*]] = alloc() : memref<1x16x1x1x32x64xf32>
+    // CHECK: "test.op_norm"(%[[ARG0]], %[[v0]]) : (memref<1x16x1x1x32x64xf32>, memref<1x16x1x1x32x64xf32>) -> ()
+    // CHECK: dealloc %[[v0]] : memref<1x16x1x1x32x64xf32>
+    return
+}
+
+// Same test with op_nonnorm, with maps in the arguments and the operations in the function.
+
+// CHECK-LABEL: test_nonnorm
+// CHECK-SAME: (%[[ARG0:[a-z0-9]*]]: memref<1x16x14x14xf32, #map0>)
+func @test_nonnorm(%arg0 : memref<1x16x14x14xf32, #map0>) -> () {
+    %0 = alloc() : memref<1x16x14x14xf32, #map0>
+    "test.op_nonnorm"(%arg0, %0) : (memref<1x16x14x14xf32, #map0>, memref<1x16x14x14xf32, #map0>) -> ()
+    dealloc %0 :  memref<1x16x14x14xf32, #map0>
+
+    // CHECK: %[[v0:[a-z0-9]*]] = alloc() : memref<1x16x14x14xf32, #map0>
+    // CHECK: "test.op_nonnorm"(%[[ARG0]], %[[v0]]) : (memref<1x16x14x14xf32, #map0>, memref<1x16x14x14xf32, #map0>) -> ()
+    // CHECK: dealloc %[[v0]] : memref<1x16x14x14xf32, #map0>
+    return
+}
+
+// Test with op_norm, with maps in the operations in the function.
+
+// CHECK-LABEL: test_norm_mix
+// CHECK-SAME: (%[[ARG0:[a-z0-9]*]]: memref<1x16x1x1x32x64xf32>
+func @test_norm_mix(%arg0 : memref<1x16x1x1x32x64xf32>) -> () {
+    %0 = alloc() : memref<1x16x14x14xf32, #map0>
+    "test.op_norm"(%arg0, %0) : (memref<1x16x1x1x32x64xf32>, memref<1x16x14x14xf32, #map0>) -> ()
+    dealloc %0 :  memref<1x16x14x14xf32, #map0>
+
+    // CHECK: %[[v0:[a-z0-9]*]] = alloc() : memref<1x16x1x1x32x64xf32>
+    // CHECK: "test.op_norm"(%[[ARG0]], %[[v0]]) : (memref<1x16x1x1x32x64xf32>, memref<1x16x1x1x32x64xf32>) -> ()
+    // CHECK: dealloc %[[v0]] : memref<1x16x1x1x32x64xf32>
+    return
+}

diff  --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td
index c938081b1642..022732d55016 100644
--- a/mlir/test/lib/Dialect/Test/TestOps.td
+++ b/mlir/test/lib/Dialect/Test/TestOps.td
@@ -618,6 +618,16 @@ def OpM : TEST_Op<"op_m"> {
   let arguments = (ins I32, OptionalAttr<I32Attr>:$optional_attr);
   let results = (outs I32);
 }
+
+// Test for memrefs normalization of an op with normalizable memrefs.
+def OpNorm : TEST_Op<"op_norm", [MemRefsNormalizable]> {
+  let arguments = (ins AnyMemRef:$X, AnyMemRef:$Y);
+}
+// Test for memrefs normalization of an op without normalizable memrefs.
+def OpNonNorm : TEST_Op<"op_nonnorm"> {
+  let arguments = (ins AnyMemRef:$X, AnyMemRef:$Y);
+}
+
 // Pattern add the argument plus a increasing static number hidden in
 // OpMTest function. That value is set into the optional argument.
 // That way, we will know if operations is called once or twice.


        


More information about the Mlir-commits mailing list