[Mlir-commits] [mlir] 5576579 - Update affine.load folding hook to fold global splat constant loads

Uday Bondhugula llvmlistbot at llvm.org
Fri Mar 25 18:19:50 PDT 2022


Author: Uday Bondhugula
Date: 2022-03-26T06:44:03+05:30
New Revision: 5576579c865d481a4f32fe3d183e32d8807432e4

URL: https://github.com/llvm/llvm-project/commit/5576579c865d481a4f32fe3d183e32d8807432e4
DIFF: https://github.com/llvm/llvm-project/commit/5576579c865d481a4f32fe3d183e32d8807432e4.diff

LOG: Update affine.load folding hook to fold global splat constant loads

Enhance the affine.load folding hook to fold loads from global constant
memrefs whose initial value is a splat; such loads fold irrespective of the
load indices.

Differential Revision: https://reviews.llvm.org/D122292

Added: 
    

Modified: 
    mlir/lib/Dialect/Affine/IR/AffineOps.cpp
    mlir/test/Dialect/Affine/canonicalize.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
index 829949180539b..6d4d041594ac9 100644
--- a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
+++ b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
@@ -2410,17 +2410,22 @@ OpFoldResult AffineLoadOp::fold(ArrayRef<Attribute> cstOperands) {
       SymbolTable::lookupSymbolIn(symbolTableOp, getGlobalOp.nameAttr()));
   if (!global)
     return {};
-  if (auto cstAttr =
-          global.getConstantInitValue().dyn_cast_or_null<DenseElementsAttr>()) {
-    // We can fold only if we know the indices.
-    if (!getAffineMap().isConstant())
-      return {};
-    auto indices = llvm::to_vector<4>(
-        llvm::map_range(getAffineMap().getConstantResults(),
-                        [](int64_t v) -> uint64_t { return v; }));
-    return cstAttr.getValues<Attribute>()[indices];
-  }
-  return {};
+
+  // Check if the global memref is a constant.
+  auto cstAttr =
+      global.getConstantInitValue().dyn_cast_or_null<DenseElementsAttr>();
+  if (!cstAttr)
+    return {};
+  // If it's a splat constant, we can fold irrespective of indices.
+  if (auto splatAttr = cstAttr.dyn_cast<SplatElementsAttr>())
+    return splatAttr.getSplatValue<Attribute>();
+  // Otherwise, we can fold only if we know the indices.
+  if (!getAffineMap().isConstant())
+    return {};
+  auto indices = llvm::to_vector<4>(
+      llvm::map_range(getAffineMap().getConstantResults(),
+                      [](int64_t v) -> uint64_t { return v; }));
+  return cstAttr.getValues<Attribute>()[indices];
 }
 
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/test/Dialect/Affine/canonicalize.mlir b/mlir/test/Dialect/Affine/canonicalize.mlir
index 8b4a5ffaba1ab..7cebf8af9c19b 100644
--- a/mlir/test/Dialect/Affine/canonicalize.mlir
+++ b/mlir/test/Dialect/Affine/canonicalize.mlir
@@ -1101,6 +1101,7 @@ func @canonicalize_multi_min_max(%i0: index, %i1: index) -> (index, index) {
 
 module {
   memref.global "private" constant @__constant_1x5x1xf32 : memref<1x5x1xf32> = dense<[[[6.250000e-02], [2.500000e-01], [3.750000e-01], [2.500000e-01], [6.250000e-02]]]>
+  memref.global "private" constant @__constant_32x64xf32 : memref<32x64xf32> = dense<0.000000e+00>
   // CHECK-LABEL: func @fold_const_init_global_memref
   func @fold_const_init_global_memref() -> (f32, f32) {
     %m = memref.get_global @__constant_1x5x1xf32 : memref<1x5x1xf32>
@@ -1109,8 +1110,21 @@ module {
     return %v0, %v1 : f32, f32
     // CHECK-DAG: %[[C0:.*]] = arith.constant 6.250000e-02 : f32
     // CHECK-DAG: %[[C1:.*]] = arith.constant 2.500000e-01 : f32
-    // CHECK-NEXT: return
-    // CHECK-SAME: %[[C0]]
-    // CHECK-SAME: %[[C1]]
+    // CHECK-NEXT: return %[[C0]], %[[C1]]
+  }
+
+  // CHECK-LABEL: func @fold_const_splat_global
+  func @fold_const_splat_global() -> memref<32x64xf32> {
+    // CHECK-NEXT: %[[CST:.*]] = arith.constant 0.000000e+00 : f32
+    %m = memref.get_global @__constant_32x64xf32 : memref<32x64xf32>
+    %s = memref.alloc() : memref<32x64xf32>
+    affine.for %i = 0 to 32 {
+      affine.for %j = 0 to 64 {
+        %v = affine.load %m[%i, %j] : memref<32x64xf32>
+        affine.store %v, %s[%i, %j] : memref<32x64xf32>
+        // CHECK: affine.store %[[CST]], %{{.*}}
+      }
+    }
+    return %s: memref<32x64xf32>
   }
 }


        


More information about the Mlir-commits mailing list