[PATCH] D74705: [mlir][quantizer] Support quantizing sparse tensors

River Riddle via Phabricator via llvm-commits llvm-commits at lists.llvm.org
Tue Mar 10 15:54:13 PDT 2020


rriddle added inline comments.


================
Comment at: mlir/include/mlir/IR/StandardTypes.h:239
+  /// Returns the same kind of type with the same shape and new element type.
+  ShapedType withElementType(Type newElementType) const;
+
----------------
stellaraccident wrote:
> Nice - thank you for not just continuing to repeat this pattern at all call sites (note: when a lot of the code was written, there was no ShapedType in the hierarchy).
> 
> River: any objection to this addition?
Seems reasonable, thanks.


================
Comment at: mlir/lib/Dialect/QuantOps/Utils/UniformSupport.cpp:62
   }
-  // TODO(fengliuai): handles sparse elements attribute
+  if (auto attr = realValue.dyn_cast<SparseElementsAttr>()) {
+    return convert(attr);
----------------
nit: Please remove trivial braces.


================
Comment at: mlir/lib/Dialect/QuantOps/Utils/UniformSupport.cpp:98
+                                  ArrayRef<uint64_t> index) {
+  // Duplicates ElementsAttr::getFlattenedIndex logic
+  size_t rank = shape.size();
----------------
stellaraccident wrote:
> River: do you have any preference about duplicating this code vs updating the attribute API in some way?
Yeah, can we move this to Support/MathExtras.h and de-duplicate?


================
Comment at: mlir/lib/Dialect/QuantOps/Utils/UniformSupport.cpp:114
+    size_t dimSize = type.getDimSize(quantizationDim);
+    if (dimSize != scales.size()) {
+      return {};
----------------
nit: please drop trivial braces, here and below.
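The quoted check guards per-axis (per-channel) quantization: there must be exactly one scale for each position along the quantized dimension. A hypothetical standalone version of that validation is sketched below. The function name is an assumption; the bounds check on the dimension and the zero-point count comparison are additions for illustration, not taken from the patch; and treating a negative extent as a dynamic dimension assumes MLIR's -1 sentinel of that time. It also follows the nit by omitting braces around single-statement bodies.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical check: per-axis parameters must line up with the shape.
// The quantization dimension must exist, its extent must be static, and
// there must be one scale and one zero point per channel.
bool perAxisParamsMatchShape(const std::vector<int64_t> &shape,
                             int64_t quantizationDim, size_t numScales,
                             size_t numZeroPoints) {
  if (quantizationDim < 0 ||
      quantizationDim >= static_cast<int64_t>(shape.size()))
    return false;
  int64_t dimSize = shape[static_cast<size_t>(quantizationDim)];
  // Assumed convention: a negative extent marks a dynamic dimension, which
  // can never match a fixed number of scales.
  if (dimSize < 0)
    return false;
  return static_cast<size_t>(dimSize) == numScales &&
         numScales == numZeroPoints;
}
```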


================
Comment at: mlir/lib/Dialect/QuantOps/Utils/UniformSupport.cpp:127
+    SmallVector<APInt, 32> newValues;
+    // if all zero-points are 0, emit a sparse attribute with the same indices
+    if (llvm::all_of(zeroPoints,
----------------
nit: Start comments with a capital letter, and use punctuation. There are a few other places as well.
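The reasoning behind the quoted comment: under uniform affine quantization a real value r maps to round(r / scale) + zeroPoint, so an implicit zero in a sparse tensor quantizes to its channel's zero point. Only when every zero point is 0 do the implicit zeros stay zero, which lets the result reuse the input's indices unchanged. The standalone sketch below assumes a signed 8-bit storage type, round-half-away-from-zero, and precomputed channel coordinates; the function name and signature are hypothetical.

```cpp
#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstdint>
#include <optional>
#include <vector>

// Hypothetical helper: quantizes only the stored (explicit) values of a
// sparse tensor. channels[i] is value i's coordinate along the quantized
// dimension. Returns std::nullopt if any zero point is non-zero, because
// implicit zeros would then become non-zero and the indices could not be
// reused as-is.
std::optional<std::vector<int8_t>>
quantizeSparseValues(const std::vector<double> &values,
                     const std::vector<int64_t> &channels,
                     const std::vector<double> &scales,
                     const std::vector<int64_t> &zeroPoints) {
  assert(values.size() == channels.size() &&
         scales.size() == zeroPoints.size());
  if (!std::all_of(zeroPoints.begin(), zeroPoints.end(),
                   [](int64_t zp) { return zp == 0; }))
    return std::nullopt;
  std::vector<int8_t> result;
  result.reserve(values.size());
  for (size_t i = 0; i < values.size(); ++i) {
    size_t ch = static_cast<size_t>(channels[i]);
    assert(ch < scales.size() && "channel out of range");
    // Zero point is 0 on this path, so only scale, round, and clamp to the
    // assumed int8 storage range.
    double q = std::clamp(std::round(values[i] / scales[ch]), -128.0, 127.0);
    result.push_back(static_cast<int8_t>(q));
  }
  return result;
}
```

For scales {0.5, 1.0} and zero points {0, 0}, the values {1.0, -2.0, 0.5} on channels {0, 1, 0} quantize to {2, -2, 1}.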


================
Comment at: mlir/lib/Dialect/QuantOps/Utils/UniformSupport.cpp:135
+            indicesAttr
+                .getValue<APInt>({i, static_cast<uint64_t>(quantizationDim)})
+                .getZExtValue();
----------------
nit: The indices of a SparseElementsAttr are guaranteed to be i64, so you can simplify all of these by using `getValue<int64_t>`.


================
Comment at: mlir/lib/IR/StandardTypes.cpp:211
+  default:
+    return {};
+  }
----------------
nit: Use `llvm_unreachable(...)` here instead. This method should cover all of the derived classes.
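Assuming the switch at StandardTypes.cpp:211 is the body of the new withElementType, the suggestion amounts to treating an unhandled kind as a bug rather than silently returning a null type. The standalone sketch below models only three shaped kinds; `ShapedTypeSketch`, `ShapedKind`, and `reportUnreachable` (a stand-in for llvm_unreachable) are all hypothetical.

```cpp
#include <cassert>
#include <cstdint>
#include <cstdio>
#include <cstdlib>
#include <string>
#include <vector>

// Stand-in for llvm_unreachable: reports the message and aborts.
[[noreturn]] inline void reportUnreachable(const char *msg) {
  std::fprintf(stderr, "UNREACHABLE: %s\n", msg);
  std::abort();
}

enum class ShapedKind { RankedTensor, UnrankedTensor, Vector };

struct ShapedTypeSketch {
  ShapedKind kind;
  std::vector<int64_t> shape; // Empty for unranked tensors.
  std::string elementType;
};

// Same kind and shape, new element type; every kind must be handled.
ShapedTypeSketch withElementType(const ShapedTypeSketch &t,
                                 const std::string &newElementType) {
  switch (t.kind) {
  case ShapedKind::RankedTensor:
    return {ShapedKind::RankedTensor, t.shape, newElementType};
  case ShapedKind::UnrankedTensor:
    return {ShapedKind::UnrankedTensor, {}, newElementType};
  case ShapedKind::Vector:
    return {ShapedKind::Vector, t.shape, newElementType};
  default:
    reportUnreachable("withElementType: unhandled shaped type kind");
  }
}
```

Reaching the default aborts with a message, so a derived class that was added later but not handled here fails loudly in testing instead of handing a null type to callers.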


Repository:
  rG LLVM Github Monorepo

https://reviews.llvm.org/D74705
