[Mlir-commits] [mlir] 9e54d5e - [mlir] NFC - Basic improvements to IndexingUtils (product and sum)
Nicolas Vasilache
llvmlistbot at llvm.org
Fri Jul 14 07:44:42 PDT 2023
Author: Nicolas Vasilache
Date: 2023-07-14T16:41:31+02:00
New Revision: 9e54d5e7786c5c901cf6c129ab44e90718fce1eb
URL: https://github.com/llvm/llvm-project/commit/9e54d5e7786c5c901cf6c129ab44e90718fce1eb
DIFF: https://github.com/llvm/llvm-project/commit/9e54d5e7786c5c901cf6c129ab44e90718fce1eb.diff
LOG: [mlir] NFC - Basic improvements to IndexingUtils (product and sum)
Added:
Modified:
mlir/include/mlir/Dialect/Utils/IndexingUtils.h
mlir/lib/Dialect/Utils/IndexingUtils.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Utils/IndexingUtils.h b/mlir/include/mlir/Dialect/Utils/IndexingUtils.h
index 39ae6dc015651e..72becd8cc01c43 100644
--- a/mlir/include/mlir/Dialect/Utils/IndexingUtils.h
+++ b/mlir/include/mlir/Dialect/Utils/IndexingUtils.h
@@ -52,13 +52,21 @@ inline SmallVector<int64_t> computeStrides(ArrayRef<int64_t> sizes) {
SmallVector<int64_t> computeElementwiseMul(ArrayRef<int64_t> v1,
ArrayRef<int64_t> v2);
+/// Return the sum of the elements of `basis`.
+/// Return `0` if `basis` is empty.
+///
+/// `basis` elements are expected to be non-negative.
+int64_t computeSum(ArrayRef<int64_t> basis);
+
+/// Return the product of the elements of `basis`.
+/// Return `0` if `basis` is empty (note: this is NOT the mathematical empty
+/// product `1`; `computeMaxLinearIndex` forwards here and documents `0`).
+///
+/// `basis` elements are expected to be non-negative.
+int64_t computeProduct(ArrayRef<int64_t> basis);
+
/// Return the number of elements of basis (i.e. the max linear index).
/// Return `0` if `basis` is empty.
///
/// `basis` elements are asserted to be non-negative.
///
/// Return `0` if `basis` is empty.
-int64_t computeMaxLinearIndex(ArrayRef<int64_t> basis);
+inline int64_t computeMaxLinearIndex(ArrayRef<int64_t> basis) {
+  // The max linear index is the total element count, i.e. the product of the
+  // sizes held in `basis`.
+  const int64_t maxLinearIndex = computeProduct(basis);
+  return maxLinearIndex;
+}
/// Return the linearized index of 'offsets' w.r.t. 'basis'.
///
@@ -130,6 +138,12 @@ inline SmallVector<AffineExpr> computeStrides(ArrayRef<AffineExpr> sizes) {
SmallVector<AffineExpr> computeElementwiseMul(ArrayRef<AffineExpr> v1,
ArrayRef<AffineExpr> v2);
+/// Return the sum of the AffineExprs in `basis`.
+/// Return the `0` AffineConstantExpr if `basis` is empty.
+/// `ctx` is used to build the constant expressions.
+AffineExpr computeSum(MLIRContext *ctx, ArrayRef<AffineExpr> basis);
+
+/// Return the product of the AffineExprs in `basis`.
+/// Return the `0` AffineConstantExpr if `basis` is empty (note: this is NOT
+/// the multiplicative identity `1`).
+/// `ctx` is used to build the constant expressions.
+AffineExpr computeProduct(MLIRContext *ctx, ArrayRef<AffineExpr> basis);
+
/// Return the number of elements of basis (i.e. the max linear index).
/// Return `0` if `basis` is empty.
///
@@ -140,7 +154,10 @@ SmallVector<AffineExpr> computeElementwiseMul(ArrayRef<AffineExpr> v1,
/// `basis` elements are expected to bind to non-negative values.
///
/// Return the `0` AffineConstantExpr if `basis` is empty.
-AffineExpr computeMaxLinearIndex(MLIRContext *ctx, ArrayRef<AffineExpr> basis);
+inline AffineExpr computeMaxLinearIndex(MLIRContext *ctx,
+                                        ArrayRef<AffineExpr> basis) {
+  // Total element count == product of the sizes held in `basis`.
+  AffineExpr maxLinearIndex = computeProduct(ctx, basis);
+  return maxLinearIndex;
+}
/// Return the linearized index of 'offsets' w.r.t. 'basis'.
///
diff --git a/mlir/lib/Dialect/Utils/IndexingUtils.cpp b/mlir/lib/Dialect/Utils/IndexingUtils.cpp
index e3efa9ca97e306..5821876139d064 100644
--- a/mlir/lib/Dialect/Utils/IndexingUtils.cpp
+++ b/mlir/lib/Dialect/Utils/IndexingUtils.cpp
@@ -81,7 +81,15 @@ SmallVector<int64_t> mlir::computeElementwiseMul(ArrayRef<int64_t> v1,
return computeElementwiseMulImpl(v1, v2);
}
-int64_t mlir::computeMaxLinearIndex(ArrayRef<int64_t> basis) {
+/// Return the sum of the elements of `basis`; an empty `basis` sums to `0`.
+int64_t mlir::computeSum(ArrayRef<int64_t> basis) {
+  // Check `>= 0` so the predicate agrees with the assertion message: zero
+  // entries are legitimate summands.
+  assert(llvm::all_of(basis, [](int64_t s) { return s >= 0; }) &&
+         "basis must be nonnegative");
+  if (basis.empty())
+    return 0;
+  // Seed with an `int64_t` zero. std::accumulate's accumulator type is the
+  // seed's type, so an `int` literal would truncate every partial sum to
+  // `int`; and the seed must be the additive identity `0` (not `1`).
+  return std::accumulate(basis.begin(), basis.end(), int64_t{0},
+                         std::plus<int64_t>());
+}
+
+int64_t mlir::computeProduct(ArrayRef<int64_t> basis) {
assert(llvm::all_of(basis, [](int64_t s) { return s > 0; }) &&
"basis must be nonnegative");
if (basis.empty())
@@ -149,8 +157,15 @@ SmallVector<AffineExpr> mlir::computeElementwiseMul(ArrayRef<AffineExpr> v1,
return computeElementwiseMulImpl(v1, v2);
}
-AffineExpr mlir::computeMaxLinearIndex(MLIRContext *ctx,
- ArrayRef<AffineExpr> basis) {
+/// Return the sum of the AffineExprs in `basis`; an empty `basis` yields the
+/// `0` AffineConstantExpr built in `ctx`.
+AffineExpr mlir::computeSum(MLIRContext *ctx, ArrayRef<AffineExpr> basis) {
+  if (basis.empty())
+    return getAffineConstantExpr(0, ctx);
+  // Fold from the additive identity `0`; seeding with `1` would add a
+  // spurious `+ 1` to every result.
+  return std::accumulate(basis.begin(), basis.end(),
+                         getAffineConstantExpr(0, ctx),
+                         std::plus<AffineExpr>());
+}
+
+AffineExpr mlir::computeProduct(MLIRContext *ctx, ArrayRef<AffineExpr> basis) {
if (basis.empty())
return getAffineConstantExpr(0, ctx);
return std::accumulate(basis.begin(), basis.end(),
More information about the Mlir-commits
mailing list