[Mlir-commits] [mlir] 5576579 - Update affine.load folding hook to fold global splat constant loads
Uday Bondhugula
llvmlistbot at llvm.org
Fri Mar 25 18:19:50 PDT 2022
Author: Uday Bondhugula
Date: 2022-03-26T06:44:03+05:30
New Revision: 5576579c865d481a4f32fe3d183e32d8807432e4
URL: https://github.com/llvm/llvm-project/commit/5576579c865d481a4f32fe3d183e32d8807432e4
DIFF: https://github.com/llvm/llvm-project/commit/5576579c865d481a4f32fe3d183e32d8807432e4.diff
LOG: Update affine.load folding hook to fold global splat constant loads
Enhance the affine.load folding hook to fold loads from global splat
constant memrefs; such loads fold irrespective of the load indices.
Differential Revision: https://reviews.llvm.org/D122292
Added:
Modified:
mlir/lib/Dialect/Affine/IR/AffineOps.cpp
mlir/test/Dialect/Affine/canonicalize.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
index 829949180539b..6d4d041594ac9 100644
--- a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
+++ b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
@@ -2410,17 +2410,22 @@ OpFoldResult AffineLoadOp::fold(ArrayRef<Attribute> cstOperands) {
SymbolTable::lookupSymbolIn(symbolTableOp, getGlobalOp.nameAttr()));
if (!global)
return {};
- if (auto cstAttr =
- global.getConstantInitValue().dyn_cast_or_null<DenseElementsAttr>()) {
- // We can fold only if we know the indices.
- if (!getAffineMap().isConstant())
- return {};
- auto indices = llvm::to_vector<4>(
- llvm::map_range(getAffineMap().getConstantResults(),
- [](int64_t v) -> uint64_t { return v; }));
- return cstAttr.getValues<Attribute>()[indices];
- }
- return {};
+
+ // Check if the global memref is a constant.
+ auto cstAttr =
+ global.getConstantInitValue().dyn_cast_or_null<DenseElementsAttr>();
+ if (!cstAttr)
+ return {};
+ // If it's a splat constant, we can fold irrespective of indices.
+ if (auto splatAttr = cstAttr.dyn_cast<SplatElementsAttr>())
+ return splatAttr.getSplatValue<Attribute>();
+ // Otherwise, we can fold only if we know the indices.
+ if (!getAffineMap().isConstant())
+ return {};
+ auto indices = llvm::to_vector<4>(
+ llvm::map_range(getAffineMap().getConstantResults(),
+ [](int64_t v) -> uint64_t { return v; }));
+ return cstAttr.getValues<Attribute>()[indices];
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Affine/canonicalize.mlir b/mlir/test/Dialect/Affine/canonicalize.mlir
index 8b4a5ffaba1ab..7cebf8af9c19b 100644
--- a/mlir/test/Dialect/Affine/canonicalize.mlir
+++ b/mlir/test/Dialect/Affine/canonicalize.mlir
@@ -1101,6 +1101,7 @@ func @canonicalize_multi_min_max(%i0: index, %i1: index) -> (index, index) {
module {
memref.global "private" constant @__constant_1x5x1xf32 : memref<1x5x1xf32> = dense<[[[6.250000e-02], [2.500000e-01], [3.750000e-01], [2.500000e-01], [6.250000e-02]]]>
+ memref.global "private" constant @__constant_32x64xf32 : memref<32x64xf32> = dense<0.000000e+00>
// CHECK-LABEL: func @fold_const_init_global_memref
func @fold_const_init_global_memref() -> (f32, f32) {
%m = memref.get_global @__constant_1x5x1xf32 : memref<1x5x1xf32>
@@ -1109,8 +1110,21 @@ module {
return %v0, %v1 : f32, f32
// CHECK-DAG: %[[C0:.*]] = arith.constant 6.250000e-02 : f32
// CHECK-DAG: %[[C1:.*]] = arith.constant 2.500000e-01 : f32
- // CHECK-NEXT: return
- // CHECK-SAME: %[[C0]]
- // CHECK-SAME: %[[C1]]
+ // CHECK-NEXT: return %[[C0]], %[[C1]]
+ }
+
+ // CHECK-LABEL: func @fold_const_splat_global
+ func @fold_const_splat_global() -> memref<32x64xf32> {
+ // CHECK-NEXT: %[[CST:.*]] = arith.constant 0.000000e+00 : f32
+ %m = memref.get_global @__constant_32x64xf32 : memref<32x64xf32>
+ %s = memref.alloc() : memref<32x64xf32>
+ affine.for %i = 0 to 32 {
+ affine.for %j = 0 to 64 {
+ %v = affine.load %m[%i, %j] : memref<32x64xf32>
+ affine.store %v, %s[%i, %j] : memref<32x64xf32>
+ // CHECK: affine.store %[[CST]], %{{.*}}
+ }
+ }
+ return %s: memref<32x64xf32>
}
}
More information about the Mlir-commits
mailing list