[Mlir-commits] [mlir] 54691a5 - [MLIR] Add affine.load fold hook on global constant memrefs

Uday Bondhugula llvmlistbot at llvm.org
Thu Mar 17 10:59:48 PDT 2022


Author: Uday Bondhugula
Date: 2022-03-17T23:27:43+05:30
New Revision: 54691a58db55cb4ca3ced4ede93bca23eb4bd3c6

URL: https://github.com/llvm/llvm-project/commit/54691a58db55cb4ca3ced4ede93bca23eb4bd3c6
DIFF: https://github.com/llvm/llvm-project/commit/54691a58db55cb4ca3ced4ede93bca23eb4bd3c6.diff

LOG: [MLIR] Add affine.load fold hook on global constant memrefs

Fold affine.load ops on global constant memrefs when indices are all
constant.

Reviewed By: ayzhuang

Differential Revision: https://reviews.llvm.org/D120612

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
    mlir/lib/Dialect/Affine/IR/AffineOps.cpp
    mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
    mlir/test/Dialect/Affine/canonicalize.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
index dc04c461ec785..5c4fbe291b5d2 100644
--- a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
+++ b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
@@ -825,7 +825,7 @@ def MemRef_GlobalOp : MemRef_Op<"global", [Symbol]> {
     The `memref.global` operation declares or defines a named global memref
     variable. The backing memory for the variable is allocated statically and is
     described by the type of the variable (which should be a statically shaped
-    memref type). The operation is a declaration if no `inital_value` is
+    memref type). The operation is a declaration if no `initial_value` is
     specified, else it is a definition. The `initial_value` can either be a unit
     attribute to represent a definition of an uninitialized global variable, or
     an elements attribute to represent the definition of a global variable with
@@ -878,6 +878,9 @@ def MemRef_GlobalOp : MemRef_Op<"global", [Symbol]> {
      bool isUninitialized() {
        return !isExternal() && initial_value().getValue().isa<UnitAttr>();
      }
+     /// Returns the constant initial value if the memref.global is a constant,
+     /// or null otherwise.
+     ElementsAttr getConstantInitValue();
   }];
   let hasVerifier = 1;
 }

diff  --git a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
index 6a2e2430549fd..a43da29a3f431 100644
--- a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
+++ b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
@@ -8,17 +8,13 @@
 
 #include "mlir/Dialect/Affine/IR/AffineOps.h"
 #include "mlir/Dialect/Affine/IR/AffineValueMap.h"
-#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
 #include "mlir/IR/BlockAndValueMapping.h"
-#include "mlir/IR/BuiltinOps.h"
 #include "mlir/IR/IntegerSet.h"
 #include "mlir/IR/Matchers.h"
-#include "mlir/IR/OpImplementation.h"
 #include "mlir/IR/PatternMatch.h"
 #include "mlir/Transforms/InliningUtils.h"
-#include "llvm/ADT/SetVector.h"
 #include "llvm/ADT/SmallBitVector.h"
 #include "llvm/ADT/TypeSwitch.h"
 #include "llvm/Support/Debug.h"
@@ -2400,7 +2396,30 @@ OpFoldResult AffineLoadOp::fold(ArrayRef<Attribute> cstOperands) {
   /// load(memrefcast) -> load
   if (succeeded(foldMemRefCast(*this)))
     return getResult();
-  return OpFoldResult();
+
+  // Fold load from a global constant memref.
+  auto getGlobalOp = memref().getDefiningOp<memref::GetGlobalOp>();
+  if (!getGlobalOp)
+    return {};
+  // Get to the memref.global defining the symbol.
+  auto *symbolTableOp = getGlobalOp->getParentWithTrait<OpTrait::SymbolTable>();
+  if (!symbolTableOp)
+    return {};
+  auto global = dyn_cast_or_null<memref::GlobalOp>(
+      SymbolTable::lookupSymbolIn(symbolTableOp, getGlobalOp.nameAttr()));
+  if (!global)
+    return {};
+  if (auto cstAttr =
+          global.getConstantInitValue().dyn_cast_or_null<DenseElementsAttr>()) {
+    // We can fold only if we know the indices.
+    if (!getAffineMap().isConstant())
+      return {};
+    auto indices = llvm::to_vector<4>(
+        llvm::map_range(getAffineMap().getConstantResults(),
+                        [](int64_t v) -> uint64_t { return v; }));
+    return cstAttr.getValues<Attribute>()[indices];
+  }
+  return {};
 }
 
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
index 18d1b7db744a7..b5c771e9f276c 100644
--- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
@@ -1275,6 +1275,13 @@ LogicalResult GlobalOp::verify() {
   return success();
 }
 
+ElementsAttr GlobalOp::getConstantInitValue() {
+  auto initVal = initial_value();
+  if (constant() && initVal.hasValue())
+    return initVal.getValue().cast<ElementsAttr>();
+  return {};
+}
+
 //===----------------------------------------------------------------------===//
 // GetGlobalOp
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/test/Dialect/Affine/canonicalize.mlir b/mlir/test/Dialect/Affine/canonicalize.mlir
index d0e9d1dbf95d5..c351414dc9274 100644
--- a/mlir/test/Dialect/Affine/canonicalize.mlir
+++ b/mlir/test/Dialect/Affine/canonicalize.mlir
@@ -1080,3 +1080,21 @@ func @canonicalize_single_min_max(%i0: index, %i1: index) -> (index, index) {
 
   return %0, %1: index, index
 }
+
+// -----
+
+module {
+  memref.global "private" constant @__constant_1x5x1xf32 : memref<1x5x1xf32> = dense<[[[6.250000e-02], [2.500000e-01], [3.750000e-01], [2.500000e-01], [6.250000e-02]]]>
+  // CHECK-LABEL: func @fold_const_init_global_memref
+  func @fold_const_init_global_memref() -> (f32, f32) {
+    %m = memref.get_global @__constant_1x5x1xf32 : memref<1x5x1xf32>
+    %v0 = affine.load %m[0, 0, 0] : memref<1x5x1xf32>
+    %v1 = affine.load %m[0, 1, 0] : memref<1x5x1xf32>
+    return %v0, %v1 : f32, f32
+    // CHECK-DAG: %[[C0:.*]] = arith.constant 6.250000e-02 : f32
+    // CHECK-DAG: %[[C1:.*]] = arith.constant 2.500000e-01 : f32
+    // CHECK-NEXT: return
+    // CHECK-SAME: %[[C0]]
+    // CHECK-SAME: %[[C1]]
+  }
+}


        


More information about the Mlir-commits mailing list