[PATCH] D74705: [mlir][quantizer] Support quantizing sparse tensors
River Riddle via Phabricator via llvm-commits
llvm-commits at lists.llvm.org
Tue Mar 10 15:54:13 PDT 2020
rriddle added inline comments.
================
Comment at: mlir/include/mlir/IR/StandardTypes.h:239
+ /// Returns the same kind of type with the same shape and new element type.
+ ShapedType withElementType(Type newElementType) const;
+
----------------
stellaraccident wrote:
> Nice - thank you for not just continuing to repeat this pattern at all call sites (note: when a lot of the code was written, there was no ShapedType in the hierarchy).
>
> River: any objection to this addition?
Seems reasonable, thanks.
================
Comment at: mlir/lib/Dialect/QuantOps/Utils/UniformSupport.cpp:62
}
- // TODO(fengliuai): handles sparse elements attribute
+ if (auto attr = realValue.dyn_cast<SparseElementsAttr>()) {
+ return convert(attr);
----------------
nit: Please remove trivial braces.
================
Comment at: mlir/lib/Dialect/QuantOps/Utils/UniformSupport.cpp:98
+ ArrayRef<uint64_t> index) {
+ // Duplicates ElementsAttr::getFlattenedIndex logic
+ size_t rank = shape.size();
----------------
stellaraccident wrote:
> River: do you have any preference about duplicating this code vs updating the attribute API in some way?
Yeah, can we move this to Support/MathExtras.h and de-duplicate?
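For reference, the duplicated logic is a row-major linearization of a multi-dimensional index. Below is a minimal standalone sketch of such a shared helper. The name `getFlattenedIndex` and the use of `std::vector` instead of `ArrayRef` are assumptions for illustration; this is not an existing MathExtras.h API.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical shared helper: converts a multi-dimensional index into the
// offset of that element in a row-major (C-order) buffer of the given shape.
uint64_t getFlattenedIndex(const std::vector<int64_t> &shape,
                           const std::vector<uint64_t> &index) {
  assert(shape.size() == index.size() && "index rank must match shape rank");
  uint64_t flattened = 0;
  uint64_t stride = 1; // Product of the sizes of all inner dimensions.
  for (std::size_t i = shape.size(); i-- > 0;) {
    assert(index[i] < static_cast<uint64_t>(shape[i]) && "index out of bounds");
    flattened += index[i] * stride;
    stride *= static_cast<uint64_t>(shape[i]);
  }
  return flattened;
}
```

For example, element `(1, 2, 3)` of a `2x3x4` tensor is at offset `1*12 + 2*4 + 3 = 23`, the last element.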
================
Comment at: mlir/lib/Dialect/QuantOps/Utils/UniformSupport.cpp:114
+ size_t dimSize = type.getDimSize(quantizationDim);
+ if (dimSize != scales.size()) {
+ return {};
----------------
nit: please drop trivial braces, here and below.
================
Comment at: mlir/lib/Dialect/QuantOps/Utils/UniformSupport.cpp:127
+ SmallVector<APInt, 32> newValues;
+ // if all zero-points are 0, emit a sparse attribute with the same indices
+ if (llvm::all_of(zeroPoints,
----------------
nit: Start comments with a capital letter, and use punctuation. There are a few other places as well.
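Background on the all-zero-points branch quoted above: an element that a sparse attribute does not store is implicitly 0.0, and affine quantization maps 0.0 to `round(0 / scale) + zeroPoint = zeroPoint`. The implicit elements therefore remain zero, and the result can reuse the same indices, only when every zero point is 0. Below is a minimal sketch, assuming round-half-away-from-zero rounding and an int8 storage range. `quantizeValue` and `canStaySparse` are hypothetical names, not the UniformSupport.cpp API.

```cpp
#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstdint>
#include <vector>

// Hypothetical affine quantizer: q = clamp(round(x / scale) + zeroPoint).
// Assumes an int8 storage range; std::lround rounds half away from zero.
int64_t quantizeValue(double x, double scale, int64_t zeroPoint,
                      int64_t qmin = -128, int64_t qmax = 127) {
  int64_t q = std::lround(x / scale) + zeroPoint;
  return std::clamp(q, qmin, qmax);
}

// An unstored element of a sparse tensor is 0.0, which quantizes to exactly
// its zero point. The quantized tensor can keep the same sparse indices only
// if every zero point is 0; otherwise the implicit elements become non-zero
// and would have to be materialized (for example, as a dense attribute).
bool canStaySparse(const std::vector<int64_t> &zeroPoints) {
  return std::all_of(zeroPoints.begin(), zeroPoints.end(),
                     [](int64_t zp) { return zp == 0; });
}
```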
================
Comment at: mlir/lib/Dialect/QuantOps/Utils/UniformSupport.cpp:135
+ indicesAttr
+ .getValue<APInt>({i, static_cast<uint64_t>(quantizationDim)})
+ .getZExtValue();
----------------
nit: The indices of a SparseElementsAttr are guaranteed to be i64, so you can simplify all of these by using `getValue<int64_t>`.
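Below is a standalone sketch of the per-axis lookup. It assumes a COO layout like SparseElementsAttr's, where the indices form an `[N, rank]` row-major buffer of `int64_t`. Each stored value takes the scale and zero point of its coordinate along `quantizationDim`, which is read directly as an `int64_t` with no `APInt` round trip. `quantizeSparsePerAxis` is a hypothetical name, and clamping to the storage range is omitted for brevity.

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <cstdint>
#include <optional>
#include <vector>

// Hypothetical COO layout: `indices` holds values.size() rows of shape.size()
// int64 coordinates, stored row-major (row i = coordinates of values[i]).
std::optional<std::vector<int64_t>>
quantizeSparsePerAxis(const std::vector<int64_t> &shape,
                      const std::vector<int64_t> &indices,
                      const std::vector<double> &values,
                      std::size_t quantizationDim,
                      const std::vector<double> &scales,
                      const std::vector<int64_t> &zeroPoints) {
  const std::size_t rank = shape.size();
  assert(quantizationDim < rank && "quantization dim out of range");
  // One (scale, zeroPoint) pair is required per slice along the axis.
  const auto dimSize = static_cast<std::size_t>(shape[quantizationDim]);
  if (scales.size() != dimSize || zeroPoints.size() != dimSize)
    return std::nullopt;
  assert(indices.size() == values.size() * rank && "malformed COO indices");

  std::vector<int64_t> quantized;
  quantized.reserve(values.size());
  for (std::size_t i = 0; i < values.size(); ++i) {
    // Coordinate of the i-th stored value along the quantization dimension.
    const auto channel =
        static_cast<std::size_t>(indices[i * rank + quantizationDim]);
    quantized.push_back(std::lround(values[i] / scales[channel]) +
                        zeroPoints[channel]);
  }
  return quantized;
}
```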
================
Comment at: mlir/lib/IR/StandardTypes.cpp:211
+ default:
+ return {};
+ }
----------------
nit: Use `llvm_unreachable(...)` here instead. This method should cover all of the derived classes.
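Below is a standalone mock of the suggested pattern, assuming a kind-based switch like the one in the patch. Every shaped kind gets a case, and the `default` branch reports a programming error instead of returning a null type that callers may forget to check. `reportUnreachable` is a hypothetical stand-in for `llvm_unreachable` so the sketch builds without LLVM, and `MockType` replaces the real type classes. In StandardTypes.cpp, each case would instead call the matching builder, for example `VectorType::get(getShape(), newElementType)`.

```cpp
#include <cassert>
#include <cstdint>
#include <cstdio>
#include <cstdlib>
#include <string>
#include <vector>

// Hypothetical stand-in for llvm_unreachable: prints a message and aborts.
[[noreturn]] void reportUnreachable(const char *msg) {
  std::fprintf(stderr, "UNREACHABLE executed: %s\n", msg);
  std::abort();
}

// Mock type kinds; Integer represents a non-shaped kind in the same space.
enum Kind : unsigned { Vector, RankedTensor, UnrankedTensor, MemRef, Integer };

struct MockType {
  unsigned kind;
  std::vector<int64_t> shape; // Empty for unranked tensors and scalars.
  std::string elementType;
};

// Returns the same kind of type with the same shape and new element type.
MockType withElementType(const MockType &type,
                         const std::string &newElementType) {
  switch (type.kind) {
  case Vector:
  case RankedTensor:
  case UnrankedTensor:
  case MemRef:
    return MockType{type.kind, type.shape, newElementType};
  default:
    // Reaching this is a bug in the caller, so fail loudly.
    reportUnreachable("withElementType called on a non-shaped type");
  }
}
```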
Repository: rG LLVM Github Monorepo
https://reviews.llvm.org/D74705