[Mlir-commits] [mlir] 9478bf0 - [mlir] Introduce `trailingNDimsContiguous` for MemRefs (#78247)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Sat Feb 17 00:47:13 PST 2024


Author: Andrzej Warzyński
Date: 2024-02-17T08:47:10Z
New Revision: 9478bf0ce625a5845139b0c9e3bb41ef88d2f005

URL: https://github.com/llvm/llvm-project/commit/9478bf0ce625a5845139b0c9e3bb41ef88d2f005
DIFF: https://github.com/llvm/llvm-project/commit/9478bf0ce625a5845139b0c9e3bb41ef88d2f005.diff

LOG: [mlir] Introduce `trailingNDimsContiguous` for MemRefs (#78247)

Extracts the logic from `vector::isContiguousSlice` that checks whether
the trailing dims of a memref are contiguous into a dedicated hook
in BuiltinTypes.{h|cpp}.

Follow-up for https://github.com/llvm/llvm-project/pull/76848.

Added: 
    

Modified: 
    mlir/include/mlir/IR/BuiltinTypes.h
    mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp
    mlir/lib/IR/BuiltinTypes.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/IR/BuiltinTypes.h b/mlir/include/mlir/IR/BuiltinTypes.h
index 92ce053ad5c829..2361cf1371237b 100644
--- a/mlir/include/mlir/IR/BuiltinTypes.h
+++ b/mlir/include/mlir/IR/BuiltinTypes.h
@@ -518,6 +518,16 @@ bool isStrided(MemRefType t);
 /// stride. Also return "true" for types with no strides.
 bool isLastMemrefDimUnitStride(MemRefType type);
 
+/// Return "true" if the last `n` dimensions of the given type are static and
+/// contiguous, i.e. the innermost stride is 1 and every other stride is the
+/// product of the sizes to its right. Examples:
+///   - memref<5x4x3x2xi8, strided<[24, 6, 2, 1]>> is contiguous when
+///   considering both _all_ and _only_ the trailing 3 dims,
+///   - memref<5x4x3x2xi8, strided<[48, 6, 2, 1]>> is _only_ contiguous when
+///   considering the trailing 3 dims.
+///
+bool trailingNDimsContiguous(MemRefType type, int64_t n);
+
 } // namespace mlir
 
 #endif // MLIR_IR_BUILTINTYPES_H

diff  --git a/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp b/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp
index 377f3d8c557474..cfa4a6e93a4a7c 100644
--- a/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp
+++ b/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp
@@ -257,38 +257,13 @@ bool vector::isContiguousSlice(MemRefType memrefType, VectorType vectorType) {
   ArrayRef<int64_t> vectorShape = vectorType.getShape();
   auto vecRank = vectorType.getRank();
 
-  // Extract the trailing dims and strides of the input memref
-  auto memrefShape = memrefType.getShape().take_back(vecRank);
-  int64_t offset;
-  SmallVector<int64_t> stridesFull;
-  if (!succeeded(getStridesAndOffset(memrefType, stridesFull, offset)))
-    return false;
-  auto strides = ArrayRef<int64_t>(stridesFull).take_back(vecRank);
-  memrefType.getLayout().isIdentity();
-
-  // TODO: Add support for memref with trailing dynamic shapes. Memrefs
-  // with leading dynamic dimensions are already supported.
-  if (ShapedType::isDynamicShape(memrefShape))
+  if (!trailingNDimsContiguous(memrefType, vecRank))
     return false;
 
-  // Cond 1: Check whether `memrefType` is contiguous.
-  if (!strides.empty()) {
-    // Cond 1.1: A contiguous memref will always have a unit trailing stride.
-    if (strides.back() != 1)
-      return false;
-
-    // Cond 1.2: Strides of a contiguous memref have to match the flattened
-    // dims.
-    strides = strides.drop_back(1);
-    SmallVector<int64_t> flattenedDims;
-    for (size_t i = 1; i < memrefShape.size(); i++)
-      flattenedDims.push_back(mlir::computeProduct(memrefShape.take_back(i)));
-
-    if (!llvm::equal(strides, llvm::reverse(flattenedDims)))
-      return false;
-  }
+  // Extract the trailing dims and strides of the input memref
+  auto memrefShape = memrefType.getShape().take_back(vecRank);
 
-  // Cond 2: Compare the dims of `vectorType` against `memrefType` (in reverse).
+  // Compare the dims of `vectorType` against `memrefType` (in reverse).
   // In the most basic case, all dims will match.
   auto firstNonMatchingDim =
       std::mismatch(vectorShape.rbegin(), vectorShape.rend(),

diff  --git a/mlir/lib/IR/BuiltinTypes.cpp b/mlir/lib/IR/BuiltinTypes.cpp
index 1794b38478a72d..a2738946de410e 100644
--- a/mlir/lib/IR/BuiltinTypes.cpp
+++ b/mlir/lib/IR/BuiltinTypes.cpp
@@ -967,3 +967,35 @@ bool mlir::isLastMemrefDimUnitStride(MemRefType type) {
   auto successStrides = getStridesAndOffset(type, strides, offset);
   return succeeded(successStrides) && (strides.empty() || strides.back() == 1);
 }
+
+/// Returns true iff the trailing `n` dims of `type` are static, the innermost
+/// stride is 1 and every other stride is the product of the sizes to its right.
+bool mlir::trailingNDimsContiguous(MemRefType type, int64_t n) {
+  if (!isLastMemrefDimUnitStride(type))
+    return false;
+  // Strides can only be checked against statically known trailing sizes.
+  auto memrefShape = type.getShape().take_back(n);
+  if (ShapedType::isDynamicShape(memrefShape))
+    return false;
+  // An identity layout implies canonical (row-major) strides, hence contiguity.
+  if (type.getLayout().isIdentity())
+    return true;
+
+  int64_t offset;
+  SmallVector<int64_t> stridesFull;
+  if (!succeeded(getStridesAndOffset(type, stridesFull, offset)))
+    return false;
+  auto strides = ArrayRef<int64_t>(stridesFull).take_back(n);
+  if (strides.empty())
+    return true;
+  // Check whether strides match "flattened" dims. The running product is an
+  // int64_t (like the sizes) so large shapes are not truncated to `int`.
+  SmallVector<int64_t> flattenedDims;
+  int64_t dimProduct = 1;
+  for (auto dim : llvm::reverse(memrefShape.drop_front(1))) {
+    dimProduct *= dim;
+    flattenedDims.push_back(dimProduct);
+  }
+  strides = strides.drop_back(1);
+  return llvm::equal(strides, llvm::reverse(flattenedDims));
+}


        


More information about the Mlir-commits mailing list