[Mlir-commits] [mlir] d25e022 - [MLIR][Affine] Fix assumption on int type in memref elt size method

Uday Bondhugula llvmlistbot at llvm.org
Wed Mar 22 03:58:24 PDT 2023


Author: Uday Bondhugula
Date: 2023-03-22T16:23:59+05:30
New Revision: d25e022cd19b83c22a6022edb78c4b97a5fc1b49

URL: https://github.com/llvm/llvm-project/commit/d25e022cd19b83c22a6022edb78c4b97a5fc1b49
DIFF: https://github.com/llvm/llvm-project/commit/d25e022cd19b83c22a6022edb78c4b97a5fc1b49.diff

LOG: [MLIR][Affine] Fix assumption on int type in memref elt size method

Fix the assumption that a memref's element type is always int/float in the
memref element size utility and in affine data copy generation.

Fixes https://github.com/llvm/llvm-project/issues/61310

Differential Revision: https://reviews.llvm.org/D146495

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Affine/Analysis/Utils.h
    mlir/include/mlir/Dialect/Affine/LoopUtils.h
    mlir/lib/Dialect/Affine/Analysis/Utils.cpp
    mlir/lib/Dialect/Affine/Transforms/LoopFusion.cpp
    mlir/lib/Dialect/Affine/Utils/LoopUtils.cpp
    mlir/test/Dialect/Affine/affine-data-copy.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Affine/Analysis/Utils.h b/mlir/include/mlir/Dialect/Affine/Analysis/Utils.h
index 7f6eced071b15..99e511f152618 100644
--- a/mlir/include/mlir/Dialect/Affine/Analysis/Utils.h
+++ b/mlir/include/mlir/Dialect/Affine/Analysis/Utils.h
@@ -359,9 +359,9 @@ struct MemRefRegion {
   FlatAffineValueConstraints cst;
 };
 
-/// Returns the size of memref data in bytes if it's statically shaped,
-/// std::nullopt otherwise.
-std::optional<uint64_t> getMemRefSizeInBytes(MemRefType memRefType);
+/// Returns the size in bytes of a statically shaped memref whose element type
+/// is an int or float, or a vector of such types; std::nullopt otherwise.
+std::optional<uint64_t> getIntOrFloatMemRefSizeInBytes(MemRefType memRefType);
 
 /// Checks a load or store op for an out of bound access; returns failure if the
 /// access is out of bounds along any of the dimensions, success otherwise.
@@ -378,6 +378,10 @@ unsigned getNumCommonSurroundingLoops(Operation &a, Operation &b);
 std::optional<int64_t> getMemoryFootprintBytes(AffineForOp forOp,
                                                int memorySpace = -1);
 
+/// Returns the size in bytes of the memref's element type if it is an int or
+/// float or a vector of such types, std::nullopt otherwise.
+std::optional<int64_t> getMemRefIntOrFloatEltSizeInBytes(MemRefType memRefType);
+
 /// Simplify the integer set by simplifying the underlying affine expressions by
 /// flattening and some simple inference. Also, drop any duplicate constraints.
 /// Returns the simplified integer set. This method runs in time linear in the

diff  --git a/mlir/include/mlir/Dialect/Affine/LoopUtils.h b/mlir/include/mlir/Dialect/Affine/LoopUtils.h
index 828f06129167c..8bab83bc0d992 100644
--- a/mlir/include/mlir/Dialect/Affine/LoopUtils.h
+++ b/mlir/include/mlir/Dialect/Affine/LoopUtils.h
@@ -184,7 +184,9 @@ struct AffineCopyOptions {
 /// available for processing this block range. When 'filterMemRef' is specified,
 /// copies are only generated for the provided MemRef. Returns success if the
 /// explicit copying succeeded for all memrefs on which affine load/stores were
-/// encountered.
+/// encountered. For memrefs whose element type has no computable size in bytes
+/// (e.g., `index`), their footprint is not accounted for, and the
+/// `fastMemCapacityBytes` copy option is non-functional in such cases.
 LogicalResult affineDataCopyGenerate(Block::iterator begin, Block::iterator end,
                                      const AffineCopyOptions &copyOptions,
                                      std::optional<Value> filterMemRef,

diff  --git a/mlir/lib/Dialect/Affine/Analysis/Utils.cpp b/mlir/lib/Dialect/Affine/Analysis/Utils.cpp
index 8ab219af98c9d..db4fa354d4c2d 100644
--- a/mlir/lib/Dialect/Affine/Analysis/Utils.cpp
+++ b/mlir/lib/Dialect/Affine/Analysis/Utils.cpp
@@ -594,16 +594,21 @@ LogicalResult MemRefRegion::compute(Operation *op, unsigned loopDepth,
   return success();
 }
 
-static unsigned getMemRefEltSizeInBytes(MemRefType memRefType) {
+std::optional<int64_t>
+mlir::getMemRefIntOrFloatEltSizeInBytes(MemRefType memRefType) {
   auto elementType = memRefType.getElementType();
 
   unsigned sizeInBits;
   if (elementType.isIntOrFloat()) {
     sizeInBits = elementType.getIntOrFloatBitWidth();
+  } else if (auto vectorType = elementType.dyn_cast<VectorType>()) {
+    if (vectorType.getElementType().isIntOrFloat())
+      sizeInBits =
+          vectorType.getElementTypeBitWidth() * vectorType.getNumElements();
+    else
+      return std::nullopt;
   } else {
-    auto vectorType = elementType.cast<VectorType>();
-    sizeInBits =
-        vectorType.getElementTypeBitWidth() * vectorType.getNumElements();
+    return std::nullopt;
   }
   return llvm::divideCeil(sizeInBits, 8);
 }
@@ -629,23 +634,29 @@ std::optional<int64_t> MemRefRegion::getRegionSize() {
     LLVM_DEBUG(llvm::dbgs() << "Dynamic shapes not yet supported\n");
     return std::nullopt;
   }
-  return getMemRefEltSizeInBytes(memRefType) * *numElements;
+  auto eltSize = getMemRefIntOrFloatEltSizeInBytes(memRefType);
+  if (!eltSize)
+    return std::nullopt;
+  return *eltSize * *numElements;
 }
 
 /// Returns the size of memref data in bytes if it's statically shaped,
 /// std::nullopt otherwise.  If the element of the memref has vector type, takes
 /// into account size of the vector as well.
 //  TODO: improve/complete this when we have target data.
-std::optional<uint64_t> mlir::getMemRefSizeInBytes(MemRefType memRefType) {
+std::optional<uint64_t>
+mlir::getIntOrFloatMemRefSizeInBytes(MemRefType memRefType) {
   if (!memRefType.hasStaticShape())
     return std::nullopt;
   auto elementType = memRefType.getElementType();
   if (!elementType.isIntOrFloat() && !elementType.isa<VectorType>())
     return std::nullopt;
 
-  uint64_t sizeInBytes = getMemRefEltSizeInBytes(memRefType);
+  auto sizeInBytes = getMemRefIntOrFloatEltSizeInBytes(memRefType);
+  if (!sizeInBytes)
+    return std::nullopt;
   for (unsigned i = 0, e = memRefType.getRank(); i < e; i++) {
-    sizeInBytes = sizeInBytes * memRefType.getDimSize(i);
+    sizeInBytes = *sizeInBytes * memRefType.getDimSize(i);
   }
   return sizeInBytes;
 }

diff  --git a/mlir/lib/Dialect/Affine/Transforms/LoopFusion.cpp b/mlir/lib/Dialect/Affine/Transforms/LoopFusion.cpp
index 79e8949c92a56..f398526da34b5 100644
--- a/mlir/lib/Dialect/Affine/Transforms/LoopFusion.cpp
+++ b/mlir/lib/Dialect/Affine/Transforms/LoopFusion.cpp
@@ -901,21 +901,6 @@ static void sinkSequentialLoops(MemRefDependenceGraph::Node *node) {
   node->op = newRootForOp;
 }
 
-//  TODO: improve/complete this when we have target data.
-static unsigned getMemRefEltSizeInBytes(MemRefType memRefType) {
-  auto elementType = memRefType.getElementType();
-
-  unsigned sizeInBits;
-  if (elementType.isIntOrFloat()) {
-    sizeInBits = elementType.getIntOrFloatBitWidth();
-  } else {
-    auto vectorType = elementType.cast<VectorType>();
-    sizeInBits =
-        vectorType.getElementTypeBitWidth() * vectorType.getNumElements();
-  }
-  return llvm::divideCeil(sizeInBits, 8);
-}
-
 // Creates and returns a private (single-user) memref for fused loop rooted
 // at 'forOp', with (potentially reduced) memref size based on the
 // MemRefRegion written to by 'srcStoreOpInst' at depth 'dstLoopDepth'.
@@ -976,7 +961,9 @@ static Value createPrivateMemRef(AffineForOp forOp, Operation *srcStoreOpInst,
 
   // Create 'newMemRefType' using 'newShape' from MemRefRegion accessed
   // by 'srcStoreOpInst'.
-  uint64_t bufSize = getMemRefEltSizeInBytes(oldMemRefType) * *numElements;
+  auto eltSize = getMemRefIntOrFloatEltSizeInBytes(oldMemRefType);
+  assert(eltSize && "memrefs with size elt types expected");
+  uint64_t bufSize = *eltSize * *numElements;
   unsigned newMemSpace;
   if (bufSize <= localBufSizeThreshold && fastMemorySpace.has_value()) {
     newMemSpace = *fastMemorySpace;

diff  --git a/mlir/lib/Dialect/Affine/Utils/LoopUtils.cpp b/mlir/lib/Dialect/Affine/Utils/LoopUtils.cpp
index bcd87fcc570a3..38d660d4ff90b 100644
--- a/mlir/lib/Dialect/Affine/Utils/LoopUtils.cpp
+++ b/mlir/lib/Dialect/Affine/Utils/LoopUtils.cpp
@@ -2181,7 +2181,11 @@ static LogicalResult generateCopy(
     // Record it.
     fastBufferMap[memref] = fastMemRef;
     // fastMemRefType is a constant shaped memref.
-    *sizeInBytes = *getMemRefSizeInBytes(fastMemRefType);
+    auto maySizeInBytes = getIntOrFloatMemRefSizeInBytes(fastMemRefType);
+    // We don't account for things of unknown size.
+    if (!maySizeInBytes)
+      maySizeInBytes = 0;
+
     LLVM_DEBUG(emitRemarkForBlock(*block)
                << "Creating fast buffer of type " << fastMemRefType
                << " and size " << llvm::divideCeil(*sizeInBytes, 1024)

diff  --git a/mlir/test/Dialect/Affine/affine-data-copy.mlir b/mlir/test/Dialect/Affine/affine-data-copy.mlir
index 22fbd7306d253..fe3b4a206e2b9 100644
--- a/mlir/test/Dialect/Affine/affine-data-copy.mlir
+++ b/mlir/test/Dialect/Affine/affine-data-copy.mlir
@@ -310,3 +310,26 @@ func.func @affine_parallel(%85:memref<2x5x4x2xi64>) {
   // CHECK-NEXT:    affine.parallel
   return
 }
+
+// CHECK-LABEL: func @index_elt_type
+func.func @index_elt_type(%arg0: memref<1x2x4x8xindex>) {
+  affine.for %arg1 = 0 to 1 {
+    affine.for %arg2 = 0 to 2 {
+      affine.for %arg3 = 0 to 4 {
+        affine.for %arg4 = 0 to 8 {
+          affine.store %arg4, %arg0[%arg1, %arg2, %arg3, %arg4] : memref<1x2x4x8xindex>
+        }
+      }
+    }
+  }
+
+  // CHECK:     affine.for %{{.*}} = 0 to 1
+  // CHECK-NEXT:  affine.for %{{.*}} = 0 to 2
+  // CHECK-NEXT:    affine.for %{{.*}} = 0 to 4
+  // CHECK-NEXT:      affine.for %{{.*}} = 0 to 8
+
+  // CHECK:     affine.for %{{.*}} = 0 to 2
+  // CHECK-NEXT:  affine.for %{{.*}} = 0 to 4
+  // CHECK-NEXT:    affine.for %{{.*}} = 0 to 8
+  return
+}


        


More information about the Mlir-commits mailing list