[Mlir-commits] [mlir] [mlir] Improvements to the 'quant' dialect (PR #100667)
llvmlistbot at llvm.org
Thu Jul 25 15:51:46 PDT 2024
github-actions[bot] wrote:
:warning: The C/C++ code formatter, clang-format, found issues in your code. :warning:
<details>
<summary>
You can test this locally with the following command:
</summary>
``````````bash
git-clang-format --diff 38d0b2d174efe05504a18988299b4d78d37999b7 cce8171c6d016d823e514ec304f94d2e8c4085c0 --extensions cpp,h -- mlir/include/mlir/Dialect/Quant/Transforms/Passes.h mlir/lib/Dialect/Quant/Transforms/LowerQuantOps.cpp mlir/lib/Dialect/Quant/Transforms/StripFuncQuantTypes.cpp mlir/include/mlir/Dialect/Tosa/Utils/QuantUtils.h mlir/include/mlir/InitAllDialects.h mlir/include/mlir/InitAllPasses.h mlir/lib/CAPI/Dialect/Quant.cpp mlir/lib/Dialect/Quant/IR/QuantDialectBytecode.cpp mlir/lib/Dialect/Quant/IR/QuantDialectBytecode.h mlir/lib/Dialect/Quant/IR/QuantOps.cpp mlir/lib/Dialect/Quant/IR/QuantTypes.cpp mlir/lib/Dialect/Quant/IR/TypeParser.cpp mlir/lib/Dialect/Quant/Utils/FakeQuantSupport.cpp mlir/lib/Dialect/Quant/Utils/UniformSupport.cpp mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp mlir/lib/Dialect/Tosa/IR/TosaOps.cpp mlir/include/mlir/Dialect/Quant/IR/Quant.h mlir/include/mlir/Dialect/Quant/IR/QuantTypes.h mlir/include/mlir/Dialect/Quant/Utils/FakeQuantSupport.h mlir/include/mlir/Dialect/Quant/Utils/UniformSupport.h
``````````
</details>
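If the suggested changes look right, the same tool can apply them instead of just printing them. A minimal sketch, assuming the PR branch is checked out locally and relying on git-clang-format's single-commit mode, which rewrites the changed lines in place rather than emitting a diff (the commit message below is only an example):

``````````bash
# Reformat only the lines that changed since the base commit used in the
# check above; without --diff, git-clang-format edits the files in place.
git-clang-format --extensions cpp,h 38d0b2d174efe05504a18988299b4d78d37999b7

# Review the formatting-only changes, then commit them to the PR branch.
git diff
git commit -am "[mlir] Apply clang-format fixes"
``````````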
<details>
<summary>
View the diff from clang-format here.
</summary>
``````````diff
diff --git a/mlir/lib/Dialect/Quant/IR/QuantOps.cpp b/mlir/lib/Dialect/Quant/IR/QuantOps.cpp
index 6a709488ce..0c68e31900 100644
--- a/mlir/lib/Dialect/Quant/IR/QuantOps.cpp
+++ b/mlir/lib/Dialect/Quant/IR/QuantOps.cpp
@@ -23,7 +23,6 @@
#include "mlir/Dialect/Quant/IR/QuantOpsDialect.cpp.inc"
-
namespace mlir {
namespace quant {
@@ -43,7 +42,8 @@ namespace {
LogicalResult verifyPerAxisQuantization(Operation *op,
QuantizedType quantizedType,
Type containerType) {
- auto quantizedPerAxisType = dyn_cast<UniformQuantizedPerAxisType>(quantizedType);
+ auto quantizedPerAxisType =
+ dyn_cast<UniformQuantizedPerAxisType>(quantizedType);
if (!quantizedPerAxisType)
return success();
@@ -60,7 +60,8 @@ LogicalResult verifyPerAxisQuantization(Operation *op,
int64_t quantizedDimensionSize = tensorType.getDimSize(quantizedDimension);
if (quantizedDimensionSize != ShapedType::kDynamic &&
- quantizedDimensionSize != (int64_t)quantizedPerAxisType.getScales().size())
+ quantizedDimensionSize !=
+ (int64_t)quantizedPerAxisType.getScales().size())
return op->emitError(
"quantized dimension size does not match number of scales");
@@ -90,8 +91,7 @@ LogicalResult verifyQuantizationOp(Operation *op, QuantizedType quantizedType,
return verifyPerAxisQuantization(op, quantizedType, containerType);
}
-} // namespace
-
+} // namespace
//===----------------------------------------------------------------------===//
// Dialect
@@ -107,7 +107,6 @@ void QuantDialect::initialize() {
detail::addBytecodeInterface(this);
}
-
//===----------------------------------------------------------------------===//
// DequantizeCastOp
//===----------------------------------------------------------------------===//
@@ -136,7 +135,6 @@ QuantizedType DequantizeCastOp::getQuantizedType() {
return cast<QuantizedType>(getElementTypeOrSelf(getInput().getType()));
}
-
//===----------------------------------------------------------------------===//
// QuantizeCastOp
//===----------------------------------------------------------------------===//
@@ -166,7 +164,6 @@ QuantizedType QuantizeCastOp::getQuantizedType() {
return cast<QuantizedType>(getElementTypeOrSelf(getResult().getType()));
}
-
//===----------------------------------------------------------------------===//
// StorageCastOp
//===----------------------------------------------------------------------===//
@@ -211,10 +208,8 @@ QuantizedType StorageCastOp::getQuantizedType() {
return cast<QuantizedType>(resultScalarType);
}
-
} // namespace quant
} // namespace mlir
#define GET_OP_CLASSES
#include "mlir/Dialect/Quant/IR/QuantOps.cpp.inc"
-
diff --git a/mlir/lib/Dialect/Quant/IR/QuantTypes.cpp b/mlir/lib/Dialect/Quant/IR/QuantTypes.cpp
index 2038a86bec..e5df4b5eab 100644
--- a/mlir/lib/Dialect/Quant/IR/QuantTypes.cpp
+++ b/mlir/lib/Dialect/Quant/IR/QuantTypes.cpp
@@ -6,9 +6,9 @@
//
//===----------------------------------------------------------------------===//
+#include "mlir/Dialect/Quant/IR/QuantTypes.h"
#include "TypeDetail.h"
#include "mlir/Dialect/Quant/IR/Quant.h"
-#include "mlir/Dialect/Quant/IR/QuantTypes.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/MLIRContext.h"
diff --git a/mlir/lib/Dialect/Quant/Transforms/LowerQuantOps.cpp b/mlir/lib/Dialect/Quant/Transforms/LowerQuantOps.cpp
index 4adeb9218f..6929d8861e 100644
--- a/mlir/lib/Dialect/Quant/Transforms/LowerQuantOps.cpp
+++ b/mlir/lib/Dialect/Quant/Transforms/LowerQuantOps.cpp
@@ -38,11 +38,11 @@ Type getScalarType(Type inputType) {
return inputType;
}
-// Return the shape of an input value as a list of attributes (static dimensions)
-// and values (dynamic dimensions). If 'input' is a scalar, an empty list is
-// returned. If 'input' is a tensor, its shape is returned.
-SmallVector<OpFoldResult>
-getScalarOrTensorShape(OpBuilder &builder, Location loc, Value input) {
+// Return the shape of an input value as a list of attributes (static
+// dimensions) and values (dynamic dimensions). If 'input' is a scalar, an empty
+// list is returned. If 'input' is a tensor, its shape is returned.
+SmallVector<OpFoldResult> getScalarOrTensorShape(OpBuilder &builder,
+ Location loc, Value input) {
if (isa<TensorType>(input.getType()))
return tensor::getMixedSizes(builder, loc, input);
return {};
@@ -100,16 +100,16 @@ std::pair<Value, Value> flattenUnrankedTensor(OpBuilder &builder, Location loc,
// Turn input size into 1D tensor
auto flatShapeType = shape::getExtentTensorType(context, 1);
- auto flatInputShape = builder.create<tensor::FromElementsOp>(
- loc, flatShapeType, inputSize);
+ auto flatInputShape =
+ builder.create<tensor::FromElementsOp>(loc, flatShapeType, inputSize);
// Reshape input tensor into 1D
auto inputType = cast<UnrankedTensorType>(input.getType());
auto elementType = inputType.getElementType();
auto flatInputType =
RankedTensorType::get({ShapedType::kDynamic}, elementType);
- auto flatInput = builder.create<tensor::ReshapeOp>(
- loc, flatInputType, input, flatInputShape);
+ auto flatInput = builder.create<tensor::ReshapeOp>(loc, flatInputType, input,
+ flatInputShape);
return std::make_pair(flatInput, inputShape);
}
@@ -135,11 +135,9 @@ std::pair<Value, Value> flattenUnrankedTensor(OpBuilder &builder, Location loc,
// - inputShape
// 1D extent tensor containing the shape of the original unranked input.
//
-std::pair<Value, Value> flattenUnrankedTensorAroundAxis(OpBuilder &builder,
- Location loc,
- Value input,
- int64_t axis,
- int64_t axisSize) {
+std::pair<Value, Value>
+flattenUnrankedTensorAroundAxis(OpBuilder &builder, Location loc, Value input,
+ int64_t axis, int64_t axisSize) {
// Get full tensor shape
auto *context = builder.getContext();
auto indexType = builder.getIndexType();
@@ -149,16 +147,20 @@ std::pair<Value, Value> flattenUnrankedTensorAroundAxis(OpBuilder &builder,
// Get shape and sizes on left and right of axis
auto axisValue = builder.create<arith::ConstantIndexOp>(loc, axis);
auto axisNextValue = builder.create<arith::ConstantIndexOp>(loc, axis + 1);
- auto shapeLeft = builder.create<shape::SplitAtOp>(
- loc, TypeRange{shapeType, shapeType}, inputShape, axisValue)
- .getResult(0);
- auto sizeLeft = builder.create<shape::NumElementsOp>(
- loc, indexType, shapeLeft);
- auto shapeRight = builder.create<shape::SplitAtOp>(
- loc, TypeRange{shapeType, shapeType}, inputShape, axisNextValue)
- .getResult(1);
- auto sizeRight = builder.create<shape::NumElementsOp>(
- loc, indexType, shapeRight);
+ auto shapeLeft =
+ builder
+ .create<shape::SplitAtOp>(loc, TypeRange{shapeType, shapeType},
+ inputShape, axisValue)
+ .getResult(0);
+ auto sizeLeft =
+ builder.create<shape::NumElementsOp>(loc, indexType, shapeLeft);
+ auto shapeRight =
+ builder
+ .create<shape::SplitAtOp>(loc, TypeRange{shapeType, shapeType},
+ inputShape, axisNextValue)
+ .getResult(1);
+ auto sizeRight =
+ builder.create<shape::NumElementsOp>(loc, indexType, shapeRight);
// Compute flat input shape as a 3-element 1D tensor
auto axisSizeValue = builder.create<arith::ConstantIndexOp>(loc, axisSize);
@@ -171,8 +173,8 @@ std::pair<Value, Value> flattenUnrankedTensorAroundAxis(OpBuilder &builder,
auto elementType = inputType.getElementType();
auto flatInputType = RankedTensorType::get(
{ShapedType::kDynamic, axisSize, ShapedType::kDynamic}, elementType);
- auto flatInput = builder.create<tensor::ReshapeOp>(
- loc, flatInputType, input, flatInputShape);
+ auto flatInput = builder.create<tensor::ReshapeOp>(loc, flatInputType, input,
+ flatInputShape);
return std::make_pair(flatInput, inputShape);
}
@@ -190,7 +192,8 @@ Value restoreUnrankedTensorShape(OpBuilder &builder, Location loc, Value input,
auto inputType = cast<RankedTensorType>(input.getType());
auto elementType = inputType.getElementType();
auto unrankedType = UnrankedTensorType::get(elementType);
- return builder.create<tensor::ReshapeOp>(loc, unrankedType, input, inputShape);
+ return builder.create<tensor::ReshapeOp>(loc, unrankedType, input,
+ inputShape);
}
// Create a tensor constant containing all scales in a per-channel quantized
@@ -209,7 +212,8 @@ Value materializePerChannelScales(OpBuilder &builder, Location loc,
auto scaleAttrs = llvm::map_to_vector(scales, [&](double scale) -> Attribute {
return builder.getFloatAttr(expressedType, scale);
});
- auto tensorType = RankedTensorType::get({(int64_t) scales.size()}, expressedType);
+ auto tensorType =
+ RankedTensorType::get({(int64_t)scales.size()}, expressedType);
auto scalesAttr = DenseElementsAttr::get(tensorType, scaleAttrs);
return builder.create<arith::ConstantOp>(loc, tensorType, scalesAttr);
}
@@ -228,9 +232,8 @@ Value materializePerChannelZeroPoints(
UniformQuantizedPerAxisType quantizedType) {
auto zeroPoints = quantizedType.getZeroPoints();
auto storageType = quantizedType.getStorageType();
- auto zeroPointAttrs = llvm::map_to_vector(
- zeroPoints,
- [&](int64_t zeroPoint) -> Attribute {
+ auto zeroPointAttrs =
+ llvm::map_to_vector(zeroPoints, [&](int64_t zeroPoint) -> Attribute {
return builder.getIntegerAttr(storageType, zeroPoint);
});
auto tensorType =
@@ -299,7 +302,7 @@ Value convertIntegerToFloat(OpBuilder &builder, Location loc, Value input,
return builder.create<arith::UIToFPOp>(loc, resultType, input);
}
-// Quantize a scalar or ranked tensor value. The stored value is clamped using
+// Quantize a scalar or ranked tensor value. The stored value is clamped using
// the storage bounds encoded in the given quantized type.
//
// See function 'convertRanked()' below for a description of the arguments.
@@ -308,8 +311,7 @@ Value quantizeValue(OpBuilder &builder, Location loc, Value input,
Value zeroPoint, QuantizedType quantizedType) {
// Convert scale to tensor if necessary
auto inputType = input.getType();
- scale = getScalarOrTensorConstant(
- builder, loc, scale, inputType, inputShape);
+ scale = getScalarOrTensorConstant(builder, loc, scale, inputType, inputShape);
// Scale input
auto scaledValue = builder.create<arith::DivFOp>(loc, input, scale);
@@ -322,8 +324,7 @@ Value quantizeValue(OpBuilder &builder, Location loc, Value input,
inputShape);
// Convert zero point from storage to expressed type
- zeroPoint = convertIntegerToFloat(builder, loc, zeroPoint,
- scale.getType(),
+ zeroPoint = convertIntegerToFloat(builder, loc, zeroPoint, scale.getType(),
quantizedType.isSigned());
// Add zero point to stored value
@@ -334,9 +335,9 @@ Value quantizeValue(OpBuilder &builder, Location loc, Value input,
// Convert stored value to storage type
auto storageScalarOrTensorType =
getScalarOrTensorType(quantizedType.getStorageType(), inputType);
- auto storedValueInt = convertFloatToInteger(
- builder, loc, storedValueFloat, storageScalarOrTensorType,
- quantizedType.isSigned());
+ auto storedValueInt = convertFloatToInteger(builder, loc, storedValueFloat,
+ storageScalarOrTensorType,
+ quantizedType.isSigned());
// Clamp stored value it if the storage type is bound
auto storedValueClamped = clampScalarOrTensor(builder, loc, storedValueInt,
@@ -352,12 +353,11 @@ Value dequantizeValue(OpBuilder &builder, Location loc, Value input,
Value zeroPoint, QuantizedType quantizedType) {
// Convert scale to tensor if necessary
auto inputType = input.getType();
- scale = getScalarOrTensorConstant(
- builder, loc, scale, inputType, inputShape);
+ scale = getScalarOrTensorConstant(builder, loc, scale, inputType, inputShape);
// Convert stored value to float
- auto result = convertIntegerToFloat(
- builder, loc, input, scale.getType(), quantizedType.isSigned());
+ auto result = convertIntegerToFloat(builder, loc, input, scale.getType(),
+ quantizedType.isSigned());
// Skip unnecessary computations if no zero point is given
if (!matchPattern(zeroPoint, m_Zero())) {
@@ -366,8 +366,7 @@ Value dequantizeValue(OpBuilder &builder, Location loc, Value input,
inputShape);
// Convert zero point from storage to expressed type
- zeroPoint = convertIntegerToFloat(builder, loc, zeroPoint,
- scale.getType(),
+ zeroPoint = convertIntegerToFloat(builder, loc, zeroPoint, scale.getType(),
quantizedType.isSigned());
// Subtract zero point to stored value
@@ -501,35 +500,33 @@ Value convertPerChannelRanked(OpBuilder &builder, Location loc, Operation *op,
auto initShape = tensor::getMixedSizes(builder, loc, input);
Value init = builder.create<tensor::EmptyOp>(loc, initShape, elementType);
- SmallVector<utils::IteratorType> iteratorTypes(
- inputRank, utils::IteratorType::parallel);
+ SmallVector<utils::IteratorType> iteratorTypes(inputRank,
+ utils::IteratorType::parallel);
auto channelAxisAffineMap = AffineMap::get(
inputRank, 0, builder.getAffineDimExpr(channelAxis), context);
SmallVector<AffineMap> indexingMaps{
- builder.getMultiDimIdentityMap(inputRank),
- channelAxisAffineMap,
- channelAxisAffineMap,
- builder.getMultiDimIdentityMap(inputRank)
- };
- auto result = builder.create<linalg::GenericOp>(
- loc,
- init.getType(), // resultType
- ValueRange{input, scales, zeroPoints}, // inputs
- ValueRange{init}, // outputs
- indexingMaps,
- iteratorTypes,
- [&](OpBuilder& builder, Location loc, ValueRange args) {
- assert(args.size() == 4);
- auto input = args[0];
- auto scale = args[1];
- auto zeroPoint = args[2];
-
- auto result = convertRanked(builder, loc, op, input, {}, scale,
- zeroPoint, quantizedType);
-
- builder.create<linalg::YieldOp>(loc, result);
- })
- .getResult(0);
+ builder.getMultiDimIdentityMap(inputRank), channelAxisAffineMap,
+ channelAxisAffineMap, builder.getMultiDimIdentityMap(inputRank)};
+ auto result = builder
+ .create<linalg::GenericOp>(
+ loc,
+ init.getType(), // resultType
+ ValueRange{input, scales, zeroPoints}, // inputs
+ ValueRange{init}, // outputs
+ indexingMaps, iteratorTypes,
+ [&](OpBuilder &builder, Location loc, ValueRange args) {
+ assert(args.size() == 4);
+ auto input = args[0];
+ auto scale = args[1];
+ auto zeroPoint = args[2];
+
+ auto result =
+ convertRanked(builder, loc, op, input, {}, scale,
+ zeroPoint, quantizedType);
+
+ builder.create<linalg::YieldOp>(loc, result);
+ })
+ .getResult(0);
return result;
}
@@ -551,7 +548,7 @@ Value convertPerChannel(OpBuilder &builder, Location loc, Operation *op,
// Flatten unranked tensor into a 3D ranked tensor if necessary
bool isUnranked = isa<UnrankedTensorType>(input.getType());
int64_t channelAxis = quantizedType.getQuantizedDimension();
- int64_t channelAxisSize = (int64_t) quantizedType.getScales().size();
+ int64_t channelAxisSize = (int64_t)quantizedType.getScales().size();
Value inputShape;
if (isUnranked) {
std::tie(input, inputShape) = flattenUnrankedTensorAroundAxis(
@@ -597,7 +594,8 @@ Value convertQuantized(OpBuilder &builder, Location loc, Operation *op,
}
// Lowering pattern for 'quant.dcast'
-struct DequantizeCastOpConversion : public OpConversionPattern<quant::DequantizeCastOp> {
+struct DequantizeCastOpConversion
+ : public OpConversionPattern<quant::DequantizeCastOp> {
using OpConversionPattern<quant::DequantizeCastOp>::OpConversionPattern;
LogicalResult
@@ -622,7 +620,8 @@ struct DequantizeCastOpConversion : public OpConversionPattern<quant::Dequantize
};
// Lowering pattern for 'quant.qcast'
-struct QuantizeCastOpConversion : public OpConversionPattern<quant::QuantizeCastOp> {
+struct QuantizeCastOpConversion
+ : public OpConversionPattern<quant::QuantizeCastOp> {
using OpConversionPattern<quant::QuantizeCastOp>::OpConversionPattern;
LogicalResult
@@ -650,12 +649,8 @@ struct LowerQuantOps : public impl::LowerQuantOpsBase<LowerQuantOps> {
ConversionTarget target(getContext());
target.addLegalOp<quant::StorageCastOp>();
target.addIllegalDialect<quant::QuantDialect>();
- target.addLegalDialect<
- arith::ArithDialect,
- linalg::LinalgDialect,
- shape::ShapeDialect,
- tensor::TensorDialect
- >();
+ target.addLegalDialect<arith::ArithDialect, linalg::LinalgDialect,
+ shape::ShapeDialect, tensor::TensorDialect>();
if (failed(applyPartialConversion(getOperation(), target,
std::move(patterns))))
@@ -666,10 +661,8 @@ struct LowerQuantOps : public impl::LowerQuantOpsBase<LowerQuantOps> {
} // namespace
void populateLowerQuantOpsPatterns(RewritePatternSet &patterns) {
- patterns.add<
- DequantizeCastOpConversion,
- QuantizeCastOpConversion
- >(patterns.getContext());
+ patterns.add<DequantizeCastOpConversion, QuantizeCastOpConversion>(
+ patterns.getContext());
}
} // namespace quant
diff --git a/mlir/lib/Dialect/Quant/Transforms/StripFuncQuantTypes.cpp b/mlir/lib/Dialect/Quant/Transforms/StripFuncQuantTypes.cpp
index 8996eff61a..6191272266 100644
--- a/mlir/lib/Dialect/Quant/Transforms/StripFuncQuantTypes.cpp
+++ b/mlir/lib/Dialect/Quant/Transforms/StripFuncQuantTypes.cpp
@@ -36,9 +36,10 @@ class QuantizedTypeConverter : public TypeConverter {
static Type convertQuantizedType(QuantizedType quantizedType) {
return quantizedType.getStorageType();
}
-
+
static Type convertTensorType(TensorType tensorType) {
- if (auto quantizedType = dyn_cast<QuantizedType>(tensorType.getElementType()))
+ if (auto quantizedType =
+ dyn_cast<QuantizedType>(tensorType.getElementType()))
return tensorType.clone(convertQuantizedType(quantizedType));
return tensorType;
}
@@ -50,7 +51,6 @@ class QuantizedTypeConverter : public TypeConverter {
}
public:
-
explicit QuantizedTypeConverter() {
addConversion([](Type type) { return type; });
addConversion(convertQuantizedType);
@@ -63,7 +63,8 @@ public:
};
// Conversion pass
-class StripFuncQuantTypes : public impl::StripFuncQuantTypesBase<StripFuncQuantTypes> {
+class StripFuncQuantTypes
+ : public impl::StripFuncQuantTypesBase<StripFuncQuantTypes> {
// Return whether a type is considered legal when occurring in the header of
// a function or as an operand to a 'return' op.
@@ -74,11 +75,10 @@ class StripFuncQuantTypes : public impl::StripFuncQuantTypesBase<StripFuncQuantT
}
public:
-
void runOnOperation() override {
-
+
auto moduleOp = cast<ModuleOp>(getOperation());
- auto* context = &getContext();
+ auto *context = &getContext();
QuantizedTypeConverter typeConverter;
ConversionTarget target(*context);
@@ -111,4 +111,3 @@ public:
} // namespace quant
} // namespace mlir
-
diff --git a/mlir/lib/Dialect/Quant/Utils/FakeQuantSupport.cpp b/mlir/lib/Dialect/Quant/Utils/FakeQuantSupport.cpp
index fb27640bfd..308ff35e01 100644
--- a/mlir/lib/Dialect/Quant/Utils/FakeQuantSupport.cpp
+++ b/mlir/lib/Dialect/Quant/Utils/FakeQuantSupport.cpp
@@ -6,8 +6,8 @@
//
//===----------------------------------------------------------------------===//
-#include "mlir/Dialect/Quant/IR/QuantTypes.h"
#include "mlir/Dialect/Quant/Utils/FakeQuantSupport.h"
+#include "mlir/Dialect/Quant/IR/QuantTypes.h"
using namespace mlir;
using namespace mlir::quant;
``````````
</details>
https://github.com/llvm/llvm-project/pull/100667