[Mlir-commits] [mlir] 323fd11 - [mlir][nfc] Add a func to compute numElements of a shape in Std -> LLVM.
Alexander Belyaev
llvmlistbot at llvm.org
Tue Oct 13 12:41:43 PDT 2020
Author: Alexander Belyaev
Date: 2020-10-13T21:41:25+02:00
New Revision: 323fd11df7718e68c37f9220a8e1056bb56778cf
URL: https://github.com/llvm/llvm-project/commit/323fd11df7718e68c37f9220a8e1056bb56778cf
DIFF: https://github.com/llvm/llvm-project/commit/323fd11df7718e68c37f9220a8e1056bb56778cf.diff
LOG: [mlir][nfc] Add a func to compute numElements of a shape in Std -> LLVM.
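For some reason, the variable `cumulativeSizeInBytes` in
`getCumulativeSizeInBytes` was actually storing the number of elements. I
decided to fix it: the element-count computation is now a separate
`getNumElements` function, and `getCumulativeSizeInBytes` is refactored to
multiply that count by the element size.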
Differential Revision: https://reviews.llvm.org/D89336
Added:
Modified:
mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h
mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h b/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h
index 645f4cd26581..36734f809175 100644
--- a/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h
+++ b/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h
@@ -470,6 +470,10 @@ class ConvertToLLVMPattern : public ConversionPattern {
Value getSizeInBytes(Location loc, Type type,
ConversionPatternRewriter &rewriter) const;
+ /// Computes total number of elements for the given shape.
+ Value getNumElements(Location loc, ArrayRef<Value> shape,
+ ConversionPatternRewriter &rewriter) const;
+
/// Computes total size in bytes of to store the given shape.
Value getCumulativeSizeInBytes(Location loc, Type elementType,
ArrayRef<Value> shape,
diff --git a/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp b/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
index e042fc3d1c4e..3fe60f5e88d4 100644
--- a/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
+++ b/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
@@ -979,18 +979,23 @@ Value ConvertToLLVMPattern::getSizeInBytes(
return rewriter.create<LLVM::PtrToIntOp>(loc, getIndexType(), gep);
}
-Value ConvertToLLVMPattern::getCumulativeSizeInBytes(
- Location loc, Type elementType, ArrayRef<Value> shape,
+Value ConvertToLLVMPattern::getNumElements(
+ Location loc, ArrayRef<Value> shape,
ConversionPatternRewriter &rewriter) const {
// Compute the total number of memref elements.
- Value cumulativeSizeInBytes =
+ Value numElements =
shape.empty() ? createIndexConstant(rewriter, loc, 1) : shape.front();
for (unsigned i = 1, e = shape.size(); i < e; ++i)
- cumulativeSizeInBytes = rewriter.create<LLVM::MulOp>(
- loc, getIndexType(), ArrayRef<Value>{cumulativeSizeInBytes, shape[i]});
- auto elementSize = this->getSizeInBytes(loc, elementType, rewriter);
- return rewriter.create<LLVM::MulOp>(
- loc, getIndexType(), ArrayRef<Value>{cumulativeSizeInBytes, elementSize});
+ numElements = rewriter.create<LLVM::MulOp>(loc, numElements, shape[i]);
+ return numElements;
+}
+
+Value ConvertToLLVMPattern::getCumulativeSizeInBytes(
+ Location loc, Type elementType, ArrayRef<Value> shape,
+ ConversionPatternRewriter &rewriter) const {
+ Value numElements = this->getNumElements(loc, shape, rewriter);
+ Value elementSize = this->getSizeInBytes(loc, elementType, rewriter);
+ return rewriter.create<LLVM::MulOp>(loc, numElements, elementSize);
}
/// Creates and populates the memref descriptor struct given all its fields.
More information about the Mlir-commits
mailing list