[Mlir-commits] [mlir] 323fd11 - [mlir][nfc] Add a func to compute numElements of a shape in Std -> LLVM.

Alexander Belyaev llvmlistbot at llvm.org
Tue Oct 13 12:41:43 PDT 2020


Author: Alexander Belyaev
Date: 2020-10-13T21:41:25+02:00
New Revision: 323fd11df7718e68c37f9220a8e1056bb56778cf

URL: https://github.com/llvm/llvm-project/commit/323fd11df7718e68c37f9220a8e1056bb56778cf
DIFF: https://github.com/llvm/llvm-project/commit/323fd11df7718e68c37f9220a8e1056bb56778cf.diff

LOG: [mlir][nfc] Add a func to compute numElements of a shape in Std -> LLVM.

For some reason, the variable `cumulativeSizeInBytes` in
`getCumulativeSizeInBytes` actually stored the number of elements rather than a
size in bytes. I renamed it and refactored the function, moving the
element-count computation into a new helper, `getNumElements`.

Differential Revision: https://reviews.llvm.org/D89336

Added: 
    

Modified: 
    mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h
    mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h b/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h
index 645f4cd26581..36734f809175 100644
--- a/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h
+++ b/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h
@@ -470,6 +470,10 @@ class ConvertToLLVMPattern : public ConversionPattern {
   Value getSizeInBytes(Location loc, Type type,
                        ConversionPatternRewriter &rewriter) const;
 
+  /// Computes total number of elements for the given shape.
+  Value getNumElements(Location loc, ArrayRef<Value> shape,
+                       ConversionPatternRewriter &rewriter) const;
+
   /// Computes total size in bytes of to store the given shape.
   Value getCumulativeSizeInBytes(Location loc, Type elementType,
                                  ArrayRef<Value> shape,

diff  --git a/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp b/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
index e042fc3d1c4e..3fe60f5e88d4 100644
--- a/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
+++ b/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
@@ -979,18 +979,23 @@ Value ConvertToLLVMPattern::getSizeInBytes(
   return rewriter.create<LLVM::PtrToIntOp>(loc, getIndexType(), gep);
 }
 
-Value ConvertToLLVMPattern::getCumulativeSizeInBytes(
-    Location loc, Type elementType, ArrayRef<Value> shape,
+Value ConvertToLLVMPattern::getNumElements(
+    Location loc, ArrayRef<Value> shape,
     ConversionPatternRewriter &rewriter) const {
   // Compute the total number of memref elements.
-  Value cumulativeSizeInBytes =
+  Value numElements =
       shape.empty() ? createIndexConstant(rewriter, loc, 1) : shape.front();
   for (unsigned i = 1, e = shape.size(); i < e; ++i)
-    cumulativeSizeInBytes = rewriter.create<LLVM::MulOp>(
-        loc, getIndexType(), ArrayRef<Value>{cumulativeSizeInBytes, shape[i]});
-  auto elementSize = this->getSizeInBytes(loc, elementType, rewriter);
-  return rewriter.create<LLVM::MulOp>(
-      loc, getIndexType(), ArrayRef<Value>{cumulativeSizeInBytes, elementSize});
+    numElements = rewriter.create<LLVM::MulOp>(loc, numElements, shape[i]);
+  return numElements;
+}
+
+Value ConvertToLLVMPattern::getCumulativeSizeInBytes(
+    Location loc, Type elementType, ArrayRef<Value> shape,
+    ConversionPatternRewriter &rewriter) const {
+  Value numElements = this->getNumElements(loc, shape, rewriter);
+  Value elementSize = this->getSizeInBytes(loc, elementType, rewriter);
+  return rewriter.create<LLVM::MulOp>(loc, numElements, elementSize);
 }
 
 /// Creates and populates the memref descriptor struct given all its fields.


        


More information about the Mlir-commits mailing list