[Mlir-commits] [mlir] [mlir] Introduce `trailingNDimsContiguous` for MemRefs (PR #78247)
Andrzej Warzyński
llvmlistbot at llvm.org
Tue Jan 16 12:33:37 PST 2024
https://github.com/banach-space updated https://github.com/llvm/llvm-project/pull/78247
>From 86e8587e1ca1aaeb3494bb0c8d922ab33a18fe86 Mon Sep 17 00:00:00 2001
From: Andrzej Warzynski <andrzej.warzynski at arm.com>
Date: Tue, 16 Jan 2024 09:44:09 +0000
Subject: [PATCH 1/2] [mlir] Introduce `trailingNDimsContiguous` for MemRefs
Extracts the logic that checks whether the trailing dims of a memref are
contiguous into a dedicated hook in BuiltinTypes.{h|cpp}.
Follow-up for https://github.com/llvm/llvm-project/pull/76848.
---
mlir/include/mlir/IR/BuiltinTypes.h | 10 ++++++
mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp | 33 +++----------------
mlir/lib/IR/BuiltinTypes.cpp | 28 ++++++++++++++++
3 files changed, 42 insertions(+), 29 deletions(-)
diff --git a/mlir/include/mlir/IR/BuiltinTypes.h b/mlir/include/mlir/IR/BuiltinTypes.h
index 92ce053ad5c829b..2361cf1371237bc 100644
--- a/mlir/include/mlir/IR/BuiltinTypes.h
+++ b/mlir/include/mlir/IR/BuiltinTypes.h
@@ -518,6 +518,16 @@ bool isStrided(MemRefType t);
/// stride. Also return "true" for types with no strides.
bool isLastMemrefDimUnitStride(MemRefType type);
+/// Return "true" if the last N dimensions of the given type are contiguous.
+///
+/// Examples:
+/// - memref<5x4x3x2xi8, strided<[24, 6, 2, 1]>> is contiguous when
+/// considering both _all_ and _only_ the trailing 3 dims,
+/// - memref<5x4x3x2xi8, strided<[48, 6, 2, 1]>> is _only_ contiguous when
+/// considering the trailing 3 dims.
+///
+bool trailingNDimsContiguous(MemRefType type, int64_t n);
+
} // namespace mlir
#endif // MLIR_IR_BUILTINTYPES_H
diff --git a/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp b/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp
index 377f3d8c557474b..cfa4a6e93a4a7c7 100644
--- a/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp
+++ b/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp
@@ -257,38 +257,13 @@ bool vector::isContiguousSlice(MemRefType memrefType, VectorType vectorType) {
ArrayRef<int64_t> vectorShape = vectorType.getShape();
auto vecRank = vectorType.getRank();
- // Extract the trailing dims and strides of the input memref
- auto memrefShape = memrefType.getShape().take_back(vecRank);
- int64_t offset;
- SmallVector<int64_t> stridesFull;
- if (!succeeded(getStridesAndOffset(memrefType, stridesFull, offset)))
- return false;
- auto strides = ArrayRef<int64_t>(stridesFull).take_back(vecRank);
- memrefType.getLayout().isIdentity();
-
- // TODO: Add support for memref with trailing dynamic shapes. Memrefs
- // with leading dynamic dimensions are already supported.
- if (ShapedType::isDynamicShape(memrefShape))
+ if (!trailingNDimsContiguous(memrefType, vecRank))
return false;
- // Cond 1: Check whether `memrefType` is contiguous.
- if (!strides.empty()) {
- // Cond 1.1: A contiguous memref will always have a unit trailing stride.
- if (strides.back() != 1)
- return false;
-
- // Cond 1.2: Strides of a contiguous memref have to match the flattened
- // dims.
- strides = strides.drop_back(1);
- SmallVector<int64_t> flattenedDims;
- for (size_t i = 1; i < memrefShape.size(); i++)
- flattenedDims.push_back(mlir::computeProduct(memrefShape.take_back(i)));
-
- if (!llvm::equal(strides, llvm::reverse(flattenedDims)))
- return false;
- }
+ // Extract the trailing dims and strides of the input memref
+ auto memrefShape = memrefType.getShape().take_back(vecRank);
- // Cond 2: Compare the dims of `vectorType` against `memrefType` (in reverse).
+ // Compare the dims of `vectorType` against `memrefType` (in reverse).
// In the most basic case, all dims will match.
auto firstNonMatchingDim =
std::mismatch(vectorShape.rbegin(), vectorShape.rend(),
diff --git a/mlir/lib/IR/BuiltinTypes.cpp b/mlir/lib/IR/BuiltinTypes.cpp
index 9b8ee3d45280353..c6a919a90554c49 100644
--- a/mlir/lib/IR/BuiltinTypes.cpp
+++ b/mlir/lib/IR/BuiltinTypes.cpp
@@ -8,6 +8,7 @@
#include "mlir/IR/BuiltinTypes.h"
#include "TypeDetail.h"
+#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/BuiltinAttributes.h"
@@ -967,3 +968,30 @@ bool mlir::isLastMemrefDimUnitStride(MemRefType type) {
auto successStrides = getStridesAndOffset(type, strides, offset);
return succeeded(successStrides) && (strides.empty() || strides.back() == 1);
}
+
+bool mlir::trailingNDimsContiguous(MemRefType type, int64_t n) {
+ if (!isLastMemrefDimUnitStride(type))
+ return false;
+
+ auto memrefShape = type.getShape().take_back(n);
+ int64_t offset;
+ SmallVector<int64_t> stridesFull;
+ if (!succeeded(getStridesAndOffset(type, stridesFull, offset)))
+ return false;
+ auto strides = ArrayRef<int64_t>(stridesFull).take_back(n);
+
+ if (ShapedType::isDynamicShape(memrefShape))
+ return false;
+
+ if (strides.empty())
+ return true;
+
+ // Strides of a contiguous memref have to match the flattened
+ // dims.
+ strides = strides.drop_back(1);
+ SmallVector<int64_t> flattenedDims;
+ for (size_t i = 1; i < memrefShape.size(); i++)
+ flattenedDims.push_back(mlir::computeProduct(memrefShape.take_back(i)));
+
+ return llvm::equal(strides, llvm::reverse(flattenedDims));
+}
>From 02c208ccc9969b620efb08eae4e31c298d6e0e1e Mon Sep 17 00:00:00 2001
From: Andrzej Warzynski <andrzej.warzynski at arm.com>
Date: Tue, 16 Jan 2024 20:14:56 +0000
Subject: [PATCH 2/2] fixup! [mlir] Introduce `trailingNDimsContiguous` for
MemRefs
Fix Windows build
---
mlir/lib/IR/BuiltinTypes.cpp | 13 +++++++------
1 file changed, 7 insertions(+), 6 deletions(-)
diff --git a/mlir/lib/IR/BuiltinTypes.cpp b/mlir/lib/IR/BuiltinTypes.cpp
index c6a919a90554c49..8b0da432cecb981 100644
--- a/mlir/lib/IR/BuiltinTypes.cpp
+++ b/mlir/lib/IR/BuiltinTypes.cpp
@@ -8,7 +8,6 @@
#include "mlir/IR/BuiltinTypes.h"
#include "TypeDetail.h"
-#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/BuiltinAttributes.h"
@@ -986,12 +985,14 @@ bool mlir::trailingNDimsContiguous(MemRefType type, int64_t n) {
if (strides.empty())
return true;
- // Strides of a contiguous memref have to match the flattened
- // dims.
- strides = strides.drop_back(1);
+ // Check whether strides match "flattened" dims.
SmallVector<int64_t> flattenedDims;
- for (size_t i = 1; i < memrefShape.size(); i++)
- flattenedDims.push_back(mlir::computeProduct(memrefShape.take_back(i)));
+ auto dimProduct = 1;
+ for (auto dim : llvm::reverse(memrefShape.drop_front(1))) {
+ dimProduct *= dim;
+ flattenedDims.push_back(dimProduct);
+ }
+ strides = strides.drop_back(1);
return llvm::equal(strides, llvm::reverse(flattenedDims));
}
More information about the Mlir-commits
mailing list