[Mlir-commits] [mlir] 7356404 - [mlir] Delete most of the ops from the quant dialect.
Stella Laurenzo
llvmlistbot at llvm.org
Wed Jul 27 17:50:58 PDT 2022
Author: Stella Laurenzo
Date: 2022-07-27T17:50:42-07:00
New Revision: 7356404ace4bdb09e8a81eb2d10e0f5e7a9ab3c0
URL: https://github.com/llvm/llvm-project/commit/7356404ace4bdb09e8a81eb2d10e0f5e7a9ab3c0
DIFF: https://github.com/llvm/llvm-project/commit/7356404ace4bdb09e8a81eb2d10e0f5e7a9ab3c0.diff
LOG: [mlir] Delete most of the ops from the quant dialect.
* https://discourse.llvm.org/t/rfc-removing-the-quant-dialect/3643/8
* Removes most ops. Leaves the cast ops in place, per the final comment on the RFC thread (more can be removed in a follow-up).
* A few uses in Tosa keep some of the utilities alive. In a follow-up, I will probably move simplified versions of them into Tosa itself instead of keeping this quasi-library dependency.
Differential Revision: https://reviews.llvm.org/D120204
Added:
Modified:
mlir/include/mlir/Dialect/Quant/CMakeLists.txt
mlir/include/mlir/Dialect/Quant/QuantOps.td
mlir/include/mlir/InitAllPasses.h
mlir/lib/Dialect/Quant/CMakeLists.txt
mlir/lib/Dialect/Quant/IR/QuantOps.cpp
mlir/lib/Dialect/Quant/Utils/CMakeLists.txt
mlir/unittests/Dialect/CMakeLists.txt
utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
Removed:
mlir/include/mlir/Dialect/Quant/Passes.h
mlir/include/mlir/Dialect/Quant/Passes.td
mlir/include/mlir/Dialect/Quant/QuantizeUtils.h
mlir/lib/Dialect/Quant/Transforms/CMakeLists.txt
mlir/lib/Dialect/Quant/Transforms/ConvertConst.cpp
mlir/lib/Dialect/Quant/Transforms/ConvertSimQuant.cpp
mlir/lib/Dialect/Quant/Transforms/PassDetail.h
mlir/lib/Dialect/Quant/Utils/QuantizeUtils.cpp
mlir/test/Dialect/Quant/convert-const.mlir
mlir/test/Dialect/Quant/convert-fakequant-invalid.mlir
mlir/test/Dialect/Quant/convert-fakequant.mlir
mlir/test/Dialect/Quant/parse-ops-invalid.mlir
mlir/test/Dialect/Quant/parse-ops.mlir
mlir/test/Dialect/Quant/quant_region.mlir
mlir/unittests/Dialect/Quant/CMakeLists.txt
mlir/unittests/Dialect/Quant/QuantizationUtilsTest.cpp
################################################################################
diff --git a/mlir/include/mlir/Dialect/Quant/CMakeLists.txt b/mlir/include/mlir/Dialect/Quant/CMakeLists.txt
index c5b6a15df396b..a13240ed865df 100644
--- a/mlir/include/mlir/Dialect/Quant/CMakeLists.txt
+++ b/mlir/include/mlir/Dialect/Quant/CMakeLists.txt
@@ -1,8 +1,2 @@
add_mlir_dialect(QuantOps quant)
add_mlir_doc(QuantOps QuantDialect Dialects/ -gen-dialect-doc)
-
-set(LLVM_TARGET_DEFINITIONS Passes.td)
-mlir_tablegen(Passes.h.inc -gen-pass-decls -name Quant)
-add_public_tablegen_target(MLIRQuantPassIncGen)
-
-add_mlir_doc(Passes QuantPasses ./ -gen-pass-doc)
diff --git a/mlir/include/mlir/Dialect/Quant/Passes.h b/mlir/include/mlir/Dialect/Quant/Passes.h
deleted file mode 100644
index ada9c8cee8b4e..0000000000000
--- a/mlir/include/mlir/Dialect/Quant/Passes.h
+++ /dev/null
@@ -1,50 +0,0 @@
-//===- Passes.h - Quantization Passes ------ --------------------*- C++ -*-===//
-//
-// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
-// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-//
-//===----------------------------------------------------------------------===//
-//
-// This file defines all of the passes owned by the quantization dialect. As
-// things mature, it is expected that passes specific to certain frontend or
-// backend dialects will move to those dialects directly. For now, they are
-// incubated here.
-//
-//===----------------------------------------------------------------------===//
-
-#ifndef MLIR_DIALECT_QUANT_PASSES_H
-#define MLIR_DIALECT_QUANT_PASSES_H
-
-#include "mlir/Pass/Pass.h"
-
-namespace mlir {
-namespace func {
-class FuncOp;
-} // namespace func
-
-namespace quant {
-
-/// Creates a pass that converts quantization simulation operations (i.e.
-/// FakeQuant and those like it) to casts into/out of supported QuantizedTypes.
-std::unique_ptr<OperationPass<func::FuncOp>> createConvertSimulatedQuantPass();
-
-/// Creates a pass that converts constants followed by a qbarrier to a
-/// constant whose value is quantized. This is typically one of the last
-/// passes done when lowering to express actual quantized arithmetic in a
-/// low level representation. Because it modifies the constant, it is
-/// destructive and cannot be undone.
-std::unique_ptr<OperationPass<func::FuncOp>> createConvertConstPass();
-
-//===----------------------------------------------------------------------===//
-// Registration
-//===----------------------------------------------------------------------===//
-
-/// Generate the code for registering passes.
-#define GEN_PASS_REGISTRATION
-#include "mlir/Dialect/Quant/Passes.h.inc"
-
-} // namespace quant
-} // namespace mlir
-
-#endif // MLIR_DIALECT_QUANT_PASSES_H
diff --git a/mlir/include/mlir/Dialect/Quant/Passes.td b/mlir/include/mlir/Dialect/Quant/Passes.td
deleted file mode 100644
index a1afda4a89293..0000000000000
--- a/mlir/include/mlir/Dialect/Quant/Passes.td
+++ /dev/null
@@ -1,27 +0,0 @@
-//===-- Passes.td - Quant pass definition file -------------*- tablegen -*-===//
-//
-// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
-// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-//
-//===----------------------------------------------------------------------===//
-
-#ifndef MLIR_DIALECT_QUANT_PASSES
-#define MLIR_DIALECT_QUANT_PASSES
-
-include "mlir/Pass/PassBase.td"
-
-def QuantConvertConst : Pass<"quant-convert-const", "func::FuncOp"> {
- let summary = "Converts constants followed by qbarrier to actual quantized "
- "values";
- let constructor = "mlir::quant::createConvertConstPass()";
-}
-
-def QuantConvertSimulatedQuant
- : Pass<"quant-convert-simulated-quantization", "func::FuncOp"> {
- let summary = "Converts training-time simulated quantization ops to "
- "corresponding quantize/dequantize casts";
- let constructor = "mlir::quant::createConvertSimulatedQuantPass()";
-}
-
-#endif // MLIR_DIALECT_QUANT_PASSES
diff --git a/mlir/include/mlir/Dialect/Quant/QuantOps.td b/mlir/include/mlir/Dialect/Quant/QuantOps.td
index 1bbe832184cc6..de26f018b2e77 100644
--- a/mlir/include/mlir/Dialect/Quant/QuantOps.td
+++ b/mlir/include/mlir/Dialect/Quant/QuantOps.td
@@ -84,170 +84,4 @@ def quant_StorageCastOp : quant_Op<"scast", [NoSideEffect]> {
let hasFolder = 1;
}
-// A QuantizeRegion (region) represents a quantization unit which wraps
-// high-precision ops with quantization specifications for all the inputs
-// and outputs. Some quantization specifications can be undetermined and
-// derived from other ports by the target specification of the kernel.
-def quant_QuantizeRegionOp : quant_Op<"region", [
- NoSideEffect,
- IsolatedFromAbove,
- SingleBlockImplicitTerminator<"ReturnOp">]> {
- let summary = [{
- The `region` operation wraps high-precision ops as a logical low-precision
- quantized kernel.
- }];
-
- let arguments = (ins Variadic<AnyType>:$inputs,
- TypeArrayAttr:$input_specs,
- TypeArrayAttr:$output_specs,
- StrAttr:$logical_kernel);
- let results = (outs Variadic<AnyType>:$outputs);
- let regions = (region SizedRegion<1>:$body);
- let hasVerifier = 1;
-}
-
-def quant_ReturnOp : quant_Op<"return", [Terminator]> {
- let summary = [{
- The `return` operation terminates a quantize region and returns values.
- }];
-
- let arguments = (ins Variadic<AnyTensor>:$results);
-}
-
-//===----------------------------------------------------------------------===//
-// Training integration and instrumentation ops
-//===----------------------------------------------------------------------===//
-
-def quant_ConstFakeQuant : quant_Op<"const_fake_quant",
- [SameOperandsAndResultType, NoSideEffect]> {
- let summary = [{
- Simulates the effect of uniform quantization with const range.
- }];
-
- let description = [{
- Given a const min, max, num_bits and narrow_range attribute, applies the
- same uniform quantization simulation as is done by the TensorFlow
- fake_quant_with_min_max_args op. See the fakeQuantAttrsToType() utility
- method and the quant-convert-simulated-quantization pass for further details.
- }];
-
- let arguments = (ins
- F32Tensor:$inputs,
- F32Attr:$min,
- F32Attr:$max,
- // The bitwidth of the quantization; between 2 and 16, inclusive.
- I64Attr:$num_bits,
- // Quantization range starts from 0 or 1; starts from 1 if true.
- DefaultValuedAttr<BoolAttr, "false">:$narrow_range,
- // The sign of the quantization.
- DefaultValuedAttr<BoolAttr, "false">:$is_signed
- );
-
- let results = (outs
- F32Tensor:$outputs
- );
-}
-
-def quant_ConstFakeQuantPerAxis : quant_Op<"const_fake_quant_per_axis",
- [SameOperandsAndResultType, NoSideEffect]> {
- let summary = [{
- Simulates the effect of per axis uniform quantization with const range.
- }];
-
- let description = [{
- Given a const min, max, num_bits and narrow_range attribute, applies the
- same per axis uniform quantization simulation as is done by the TensorFlow
- fake_quant_with_min_max_vars_per_channel op. See the fakeQuantAttrsToType()
- utility method and the quant-convert-simulated-quantization pass for further
- details.
- }];
-
- let arguments = (ins
- F32Tensor:$inputs,
- F32ArrayAttr:$min,
- F32ArrayAttr:$max,
- // The quantized dimension of the inputs tensor.
- I64Attr:$axis,
- // The bitwidth of the quantization; between 2 and 16, inclusive.
- I64Attr:$num_bits,
- // Quantization range starts from 0 or 1; starts from 1 if true.
- DefaultValuedAttr<BoolAttr, "false">:$narrow_range,
- // The sign of the quantization.
- DefaultValuedAttr<BoolAttr, "false">:$is_signed
- );
-
- let results = (outs
- F32Tensor:$outputs
- );
-}
-
-def quant_StatisticsRefOp : quant_Op<"stats_ref", [SameOperandsAndResultType]> {
- let summary = "Indicates that statistics are resolved by reference.";
-
- let description = [{
- This op acts as an identity that, when encountered at runtime, should result
- in statistics being collected about about the value of its operand/result.
- Such statistics will be stored with the provided key, allowing this node
- to later be converted to a 'stats' op if statistics with that key have been
- encountered.
- }];
-
- let arguments = (ins
- quant_RealValueType:$arg,
- StrAttr:$statsKey
- );
- let results = (outs quant_RealValueType);
-}
-
-def quant_StatisticsOp : quant_Op<"stats", [SameOperandsAndResultType]> {
- let summary = "Identity op which associates statistics with the value.";
-
- let description = [{
- Associates statistics about the runtime ranges of values observed for
- evaluations of this node.
-
- Statistics about the entire type are reported in the 'layerStats' attribute
- and those for each axis, in the (optional) `axisStats` attribute. The
- interpretation of each is determined by the last dimension of its shape.
- Currently, only dim=2 is supported, which is interpreted as [min, max].
-
- `layerStats` must be a rank 1 tensor: [2]
- `axisStats` must be a rank 2 tensor: [N, 2], where N=the slice size
- splitted by the `axis` dimension. For example:
-
- ```
- <?x?x3x2>, axis=3 => N=2
- <?x?x3x2>, axis=2 => N=6
- ```
- }];
-
- let arguments = (ins
- quant_RealValueType:$arg,
- ElementsAttr:$layerStats,
- OptionalAttr<ElementsAttr>:$axisStats,
- OptionalAttr<I64Attr>:$axis);
- let results = (outs quant_RealValueType);
- let hasVerifier = 1;
-}
-
-def quant_CoupledRefOp : quant_Op<"coupled_ref", [SameOperandsAndResultType]> {
- let summary = [{
- Indicates that one point of the computation is coupled to another.
- }];
-
- let description = [{
- Ordinarily, relationships between ops for the purposes of determining
- compatible quantized types is explicit based on the use-def chain. However,
- in some situations, a use may be separated from its def by arbitrary
- external connections. In such a case, during analysis, all coupled_ref
- nodes in a module which share a coupledKey will be considered to be
- directly connected as via an identity op for the purpose of type inference.
- }];
-
- let arguments = (ins
- quant_RealValueType:$arg,
- StrAttr:$coupledKey);
- let results = (outs quant_RealValueType);
-}
-
#endif // DIALECT_QUANT_QUANT_OPS_
diff --git a/mlir/include/mlir/Dialect/Quant/QuantizeUtils.h b/mlir/include/mlir/Dialect/Quant/QuantizeUtils.h
deleted file mode 100644
index 4f4714f3aadf4..0000000000000
--- a/mlir/include/mlir/Dialect/Quant/QuantizeUtils.h
+++ /dev/null
@@ -1,61 +0,0 @@
-//===- QuantizeUtils.h - Support utilities for quantization -----*- C++ -*-===//
-//
-// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
-// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-//
-//===----------------------------------------------------------------------===//
-
-#ifndef MLIR_DIALECT_QUANT_QUANTIZEUTILS_H_
-#define MLIR_DIALECT_QUANT_QUANTIZEUTILS_H_
-
-namespace mlir {
-class Attribute;
-class Type;
-
-namespace quant {
-class QuantizedType;
-class UniformQuantizedType;
-class UniformQuantizedValueConverter;
-
-/// Converts an attribute from a type based on
-/// quantizedElementType.getExpressedType() to one based on
-/// quantizedElementType.getStorageType(), where quantizedElementType is as from
-/// QuantizedType::getQuantizedElementType().
-/// Returns nullptr if the conversion is not supported. On success, stores the
-/// converted type in outConvertedType.
-///
-/// Examples:
-/// 1. realValue is a primitive value attribute:
-/// (realValue: FloatAttr, quantizedElementType: UniformQuantizedType[i8:f32])
-/// -> (IntegerAttr, outConvertedType: i8)
-/// 2. realValue is an elements attribute:
-/// (realValue: DenseElementsAttr[tensor<2x2xf32>],
-/// quantizedElementType: UniformQuantizedType[i8:f32])
-/// -> (DenseElementsAttr[tensor<2x2xi8>], outConvertedType: tensor<2x2xi8>)
-Attribute quantizeAttr(Attribute realValue, QuantizedType quantizedElementType,
- Type &outConvertedType);
-
-/// Converts an attribute from a type based on
-/// quantizedElementType.getExpressedType() to one based on
-/// quantizedElementType.getStorageType(), where quantizedElementType is as from
-/// QuantizedType::getQuantizedElementType() and casted to an
-/// UniformQuantizedType. Returns nullptr if the conversion is not supported. On
-/// success, stores the converted type in outConvertedType.
-///
-/// Examples:
-/// 1. realValue is a primitive value attribute:
-/// (realValue: FloatAttr, quantizedElementType: UniformQuantizedType[i8:f32])
-/// -> (IntegerAttr, outConvertedType: i8)
-/// 2. realValue is an elements attribute:
-/// (realValue: DenseElementsAttr[tensor<2x2xf32>],
-/// quantizedElementType: UniformQuantizedType[i8:f32])
-/// -> (DenseElementsAttr[tensor<2x2xi8>], outConvertedType: tensor<2x2xi8>)
-Attribute quantizeAttrUniform(Attribute realValue,
- UniformQuantizedType quantizedElementType,
- const UniformQuantizedValueConverter &converter,
- Type &outConvertedType);
-} // namespace quant
-} // namespace mlir
-
-#endif // MLIR_DIALECT_QUANT_QUANTIZEUTILS_H_
diff --git a/mlir/include/mlir/InitAllPasses.h b/mlir/include/mlir/InitAllPasses.h
index cdc743f98539e..7644f4c3552cb 100644
--- a/mlir/include/mlir/InitAllPasses.h
+++ b/mlir/include/mlir/InitAllPasses.h
@@ -25,7 +25,6 @@
#include "mlir/Dialect/Linalg/Passes.h"
#include "mlir/Dialect/MemRef/Transforms/Passes.h"
#include "mlir/Dialect/NVGPU/Passes.h"
-#include "mlir/Dialect/Quant/Passes.h"
#include "mlir/Dialect/SCF/Transforms/Passes.h"
#include "mlir/Dialect/SPIRV/Transforms/Passes.h"
#include "mlir/Dialect/Shape/Transforms/Passes.h"
@@ -69,7 +68,6 @@ inline void registerAllPasses() {
registerSparseTensorPasses();
LLVM::registerLLVMPasses();
memref::registerMemRefPasses();
- quant::registerQuantPasses();
registerSCFPasses();
registerShapePasses();
spirv::registerSPIRVPasses();
diff --git a/mlir/lib/Dialect/Quant/CMakeLists.txt b/mlir/lib/Dialect/Quant/CMakeLists.txt
index 31167e6af908b..037bba8dcb5c9 100644
--- a/mlir/lib/Dialect/Quant/CMakeLists.txt
+++ b/mlir/lib/Dialect/Quant/CMakeLists.txt
@@ -1,3 +1,2 @@
add_subdirectory(IR)
-add_subdirectory(Transforms)
add_subdirectory(Utils)
diff --git a/mlir/lib/Dialect/Quant/IR/QuantOps.cpp b/mlir/lib/Dialect/Quant/IR/QuantOps.cpp
index b48d16cceb3c6..063f41e8e4e13 100644
--- a/mlir/lib/Dialect/Quant/IR/QuantOps.cpp
+++ b/mlir/lib/Dialect/Quant/IR/QuantOps.cpp
@@ -43,93 +43,5 @@ OpFoldResult StorageCastOp::fold(ArrayRef<Attribute> operands) {
return srcScastOp.getArg();
}
-/// The quantization specification should match the expressed type.
-static bool isValidQuantizationSpec(Attribute quantSpec, Type expressed) {
- if (auto typeAttr = quantSpec.dyn_cast<TypeAttr>()) {
- Type spec = typeAttr.getValue();
- if (spec.isa<TensorType, VectorType>())
- return false;
-
- // The spec should be either a quantized type which is compatible to the
- // expressed type, or a primitive type which is as same as the
- // (element type of) the expressed type.
- if (auto quantizedType = spec.dyn_cast<QuantizedType>())
- return quantizedType.isCompatibleExpressedType(expressed);
-
- if (auto tensorType = expressed.dyn_cast<TensorType>())
- return spec == tensorType.getElementType();
-
- if (auto vectorType = expressed.dyn_cast<VectorType>())
- return spec == vectorType.getElementType();
- }
- return false;
-}
-
-LogicalResult QuantizeRegionOp::verify() {
- // There are specifications for both inputs and outputs.
- if (getNumOperands() != getInputSpecs().size() ||
- getNumResults() != getOutputSpecs().size())
- return emitOpError(
- "has unmatched operands/results number and spec attributes number");
-
- // Verify that quantization specifications are valid.
- for (auto input : llvm::zip(getOperandTypes(), getInputSpecs())) {
- Type inputType = std::get<0>(input);
- Attribute inputSpec = std::get<1>(input);
- if (!isValidQuantizationSpec(inputSpec, inputType)) {
- return emitOpError() << "has incompatible specification " << inputSpec
- << " and input type " << inputType;
- }
- }
-
- for (auto result : llvm::zip(getResultTypes(), getOutputSpecs())) {
- Type outputType = std::get<0>(result);
- Attribute outputSpec = std::get<1>(result);
- if (!isValidQuantizationSpec(outputSpec, outputType)) {
- return emitOpError() << "has incompatible specification " << outputSpec
- << " and output type " << outputType;
- }
- }
- return success();
-}
-
-LogicalResult StatisticsOp::verify() {
- auto tensorArg = getArg().getType().dyn_cast<TensorType>();
- if (!tensorArg)
- return emitOpError("arg needs to be tensor type.");
-
- // Verify layerStats attribute.
- {
- auto layerStatsType = getLayerStats().getType();
- if (!layerStatsType.getElementType().isa<FloatType>()) {
- return emitOpError("layerStats must have a floating point element type");
- }
- if (layerStatsType.getRank() != 1 || layerStatsType.getDimSize(0) != 2) {
- return emitOpError("layerStats must have shape [2]");
- }
- }
- // Verify axisStats (optional) attribute.
- if (getAxisStats()) {
- if (!getAxis())
- return emitOpError("axis must be specified for axisStats");
-
- auto shape = tensorArg.getShape();
- auto argSliceSize =
- std::accumulate(std::next(shape.begin(), *getAxis()), shape.end(), 1,
- std::multiplies<int64_t>());
-
- auto axisStatsType = getAxisStats()->getType();
- if (!axisStatsType.getElementType().isa<FloatType>()) {
- return emitOpError("axisStats must have a floating point element type");
- }
- if (axisStatsType.getRank() != 2 || axisStatsType.getDimSize(1) != 2 ||
- axisStatsType.getDimSize(0) != argSliceSize) {
- return emitOpError("axisStats must have shape [N,2] "
- "where N = the slice size defined by the axis dim");
- }
- }
- return success();
-}
-
#define GET_OP_CLASSES
#include "mlir/Dialect/Quant/QuantOps.cpp.inc"
diff --git a/mlir/lib/Dialect/Quant/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Quant/Transforms/CMakeLists.txt
deleted file mode 100644
index 099ab7537e7ba..0000000000000
--- a/mlir/lib/Dialect/Quant/Transforms/CMakeLists.txt
+++ /dev/null
@@ -1,20 +0,0 @@
-add_mlir_dialect_library(MLIRQuantTransforms
- ConvertConst.cpp
- ConvertSimQuant.cpp
-
- ADDITIONAL_HEADER_DIRS
- ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/QuantOps/Transforms
-
- DEPENDS
- MLIRQuantPassIncGen
-
- LINK_LIBS PUBLIC
- MLIRArithmeticDialect
- MLIRFuncDialect
- MLIRIR
- MLIRQuantDialect
- MLIRQuantUtils
- MLIRPass
- MLIRSupport
- MLIRTransformUtils
- )
diff --git a/mlir/lib/Dialect/Quant/Transforms/ConvertConst.cpp b/mlir/lib/Dialect/Quant/Transforms/ConvertConst.cpp
deleted file mode 100644
index ece8b101d63af..0000000000000
--- a/mlir/lib/Dialect/Quant/Transforms/ConvertConst.cpp
+++ /dev/null
@@ -1,104 +0,0 @@
-//===- ConvertConst.cpp - Quantizes constant ops --------------------------===//
-//
-// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
-// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-//
-//===----------------------------------------------------------------------===//
-
-#include "PassDetail.h"
-#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
-#include "mlir/Dialect/Quant/Passes.h"
-#include "mlir/Dialect/Quant/QuantOps.h"
-#include "mlir/Dialect/Quant/QuantizeUtils.h"
-#include "mlir/Dialect/Quant/UniformSupport.h"
-#include "mlir/IR/BuiltinTypes.h"
-#include "mlir/IR/Matchers.h"
-#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
-
-using namespace mlir;
-using namespace mlir::quant;
-
-namespace {
-struct ConvertConstPass : public QuantConvertConstBase<ConvertConstPass> {
- void runOnOperation() override;
-};
-
-struct QuantizedConstRewrite : public OpRewritePattern<QuantizeCastOp> {
- using OpRewritePattern<QuantizeCastOp>::OpRewritePattern;
-
- LogicalResult matchAndRewrite(QuantizeCastOp qbarrier,
- PatternRewriter &rewriter) const override;
-};
-
-} // namespace
-
-/// Matches a [constant] -> [qbarrier] where the qbarrier results type is
-/// quantized and the operand type is quantizable.
-
-LogicalResult
-QuantizedConstRewrite::matchAndRewrite(QuantizeCastOp qbarrier,
- PatternRewriter &rewriter) const {
- Attribute value;
-
- // Is the operand a constant?
- if (!matchPattern(qbarrier.getArg(), m_Constant(&value))) {
- return failure();
- }
-
- // Does the qbarrier convert to a quantized type. This will not be true
- // if a quantized type has not yet been chosen or if the cast to an equivalent
- // storage type is not supported.
- Type qbarrierResultType = qbarrier.getResult().getType();
- QuantizedType quantizedElementType =
- QuantizedType::getQuantizedElementType(qbarrierResultType);
- if (!quantizedElementType) {
- return failure();
- }
- if (!QuantizedType::castToStorageType(qbarrierResultType)) {
- return failure();
- }
-
- // Is the operand type compatible with the expressed type of the quantized
- // type? This will not be true if the qbarrier is superfluous (converts
- // from and to a quantized type).
- if (!quantizedElementType.isCompatibleExpressedType(
- qbarrier.getArg().getType())) {
- return failure();
- }
-
- // Is the constant value a type expressed in a way that we support?
- if (!value.isa<FloatAttr, DenseElementsAttr, SparseElementsAttr>()) {
- return failure();
- }
-
- Type newConstValueType;
- auto newConstValue =
- quantizeAttr(value, quantizedElementType, newConstValueType);
- if (!newConstValue) {
- return failure();
- }
-
- // When creating the new const op, use a fused location that combines the
- // original const and the qbarrier that led to the quantization.
- auto fusedLoc = rewriter.getFusedLoc(
- {qbarrier.getArg().getDefiningOp()->getLoc(), qbarrier.getLoc()});
- auto newConstOp = rewriter.create<arith::ConstantOp>(
- fusedLoc, newConstValueType, newConstValue);
- rewriter.replaceOpWithNewOp<StorageCastOp>(qbarrier, qbarrier.getType(),
- newConstOp);
- return success();
-}
-
-void ConvertConstPass::runOnOperation() {
- RewritePatternSet patterns(&getContext());
- auto func = getOperation();
- auto *context = &getContext();
- patterns.add<QuantizedConstRewrite>(context);
- (void)applyPatternsAndFoldGreedily(func, std::move(patterns));
-}
-
-std::unique_ptr<OperationPass<func::FuncOp>>
-mlir::quant::createConvertConstPass() {
- return std::make_unique<ConvertConstPass>();
-}
diff --git a/mlir/lib/Dialect/Quant/Transforms/ConvertSimQuant.cpp b/mlir/lib/Dialect/Quant/Transforms/ConvertSimQuant.cpp
deleted file mode 100644
index ca7f303bebf3b..0000000000000
--- a/mlir/lib/Dialect/Quant/Transforms/ConvertSimQuant.cpp
+++ /dev/null
@@ -1,140 +0,0 @@
-//===- ConvertSimQuant.cpp - Converts simulated quant ops------------------===//
-//
-// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
-// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-//
-//===----------------------------------------------------------------------===//
-
-#include "PassDetail.h"
-#include "mlir/Dialect/Quant/FakeQuantSupport.h"
-#include "mlir/Dialect/Quant/Passes.h"
-#include "mlir/Dialect/Quant/QuantOps.h"
-#include "mlir/Dialect/Quant/UniformSupport.h"
-#include "mlir/IR/BuiltinTypes.h"
-#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
-
-using namespace mlir;
-using namespace mlir::quant;
-
-namespace {
-struct ConvertSimulatedQuantPass
- : public QuantConvertSimulatedQuantBase<ConvertSimulatedQuantPass> {
- void runOnOperation() override;
-};
-
-/// Base class rewrites ConstFakeQuant into a qbarrier/dbarrier pair.
-template <typename ConcreteRewriteClass, typename FakeQuantOp>
-class FakeQuantRewrite : public OpRewritePattern<FakeQuantOp> {
-public:
- using OpRewritePattern<FakeQuantOp>::OpRewritePattern;
-
- FakeQuantRewrite(MLIRContext *ctx, bool *hadFailure)
- : OpRewritePattern<FakeQuantOp>(ctx), hadFailure(hadFailure) {}
-
- LogicalResult matchAndRewrite(FakeQuantOp op,
- PatternRewriter &rewriter) const override {
- // TODO: If this pattern comes up more frequently, consider adding core
- // support for failable rewrites.
- if (failableRewrite(op, rewriter)) {
- *hadFailure = true;
- return failure();
- }
-
- return success();
- }
-
-private:
- bool *hadFailure;
-
- bool failableRewrite(FakeQuantOp op, PatternRewriter &rewriter) const {
- auto converter = ExpressedToQuantizedConverter::forInputType(op.getType());
- if (!converter) {
- return (op.emitError("unsupported quantized type conversion"), true);
- }
-
- QuantizedType elementType =
- static_cast<const ConcreteRewriteClass *>(this)
- ->convertFakeQuantAttrsToType(op, converter.expressedType);
-
- if (!elementType) {
- // Note that the fakeQuantAttrsToType will have emitted the error.
- return true;
- }
-
- Type quantizedType = converter.convert(elementType);
- assert(quantizedType &&
- "Converter accepted a type that it did not convert");
-
- // TODO: Map to a qbarrier with an attribute like [Forced] to signal that
- // this is a forced/hard-coded constraint.
- auto qbarrier = rewriter.create<QuantizeCastOp>(op.getLoc(), quantizedType,
- op.getInputs());
- rewriter.replaceOpWithNewOp<DequantizeCastOp>(op, converter.inputType,
- qbarrier.getResult());
-
- return false;
- }
-};
-
-class ConstFakeQuantRewrite
- : public FakeQuantRewrite<ConstFakeQuantRewrite, ConstFakeQuant> {
-public:
- using BaseRewrite = FakeQuantRewrite<ConstFakeQuantRewrite, ConstFakeQuant>;
-
- ConstFakeQuantRewrite(MLIRContext *ctx, bool *hadFailure)
- : BaseRewrite(ctx, hadFailure) {}
-
- QuantizedType convertFakeQuantAttrsToType(ConstFakeQuant fqOp,
- Type expressedType) const {
- return fakeQuantAttrsToType(
- fqOp.getLoc(), fqOp.getNumBits(), fqOp.getMin().convertToFloat(),
- fqOp.getMax().convertToFloat(), fqOp.getNarrowRange(), expressedType,
- fqOp.getIsSigned());
- }
-};
-
-class ConstFakeQuantPerAxisRewrite
- : public FakeQuantRewrite<ConstFakeQuantPerAxisRewrite,
- ConstFakeQuantPerAxis> {
-public:
- using BaseRewrite =
- FakeQuantRewrite<ConstFakeQuantPerAxisRewrite, ConstFakeQuantPerAxis>;
-
- ConstFakeQuantPerAxisRewrite(MLIRContext *ctx, bool *hadFailure)
- : BaseRewrite(ctx, hadFailure) {}
-
- QuantizedType convertFakeQuantAttrsToType(ConstFakeQuantPerAxis fqOp,
- Type expressedType) const {
- SmallVector<double, 4> min, max;
- min.reserve(fqOp.getMin().size());
- max.reserve(fqOp.getMax().size());
- for (auto m : fqOp.getMin())
- min.push_back(m.cast<FloatAttr>().getValueAsDouble());
- for (auto m : fqOp.getMax())
- max.push_back(m.cast<FloatAttr>().getValueAsDouble());
-
- return fakeQuantAttrsToType(fqOp.getLoc(), fqOp.getNumBits(),
- fqOp.getAxis(), min, max, fqOp.getNarrowRange(),
- expressedType, fqOp.getIsSigned());
- }
-};
-
-} // namespace
-
-void ConvertSimulatedQuantPass::runOnOperation() {
- bool hadFailure = false;
- auto func = getOperation();
- RewritePatternSet patterns(func.getContext());
- auto *ctx = func.getContext();
- patterns.add<ConstFakeQuantRewrite, ConstFakeQuantPerAxisRewrite>(
- ctx, &hadFailure);
- (void)applyPatternsAndFoldGreedily(func, std::move(patterns));
- if (hadFailure)
- signalPassFailure();
-}
-
-std::unique_ptr<OperationPass<func::FuncOp>>
-mlir::quant::createConvertSimulatedQuantPass() {
- return std::make_unique<ConvertSimulatedQuantPass>();
-}
diff --git a/mlir/lib/Dialect/Quant/Transforms/PassDetail.h b/mlir/lib/Dialect/Quant/Transforms/PassDetail.h
deleted file mode 100644
index 358b6e078d587..0000000000000
--- a/mlir/lib/Dialect/Quant/Transforms/PassDetail.h
+++ /dev/null
@@ -1,22 +0,0 @@
-//===- PassDetail.h - Quant Pass class details ------------------*- C++ -*-===//
-//
-// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
-// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-//
-//===----------------------------------------------------------------------===//
-
-#ifndef DIALECT_QUANT_TRANSFORMS_PASSDETAIL_H_
-#define DIALECT_QUANT_TRANSFORMS_PASSDETAIL_H_
-
-#include "mlir/Dialect/Func/IR/FuncOps.h"
-#include "mlir/Pass/Pass.h"
-
-namespace mlir {
-
-#define GEN_PASS_CLASSES
-#include "mlir/Dialect/Quant/Passes.h.inc"
-
-} // namespace mlir
-
-#endif // DIALECT_QUANT_TRANSFORMS_PASSDETAIL_H_
diff --git a/mlir/lib/Dialect/Quant/Utils/CMakeLists.txt b/mlir/lib/Dialect/Quant/Utils/CMakeLists.txt
index 0a1d9ea2546fd..50381f053f85c 100644
--- a/mlir/lib/Dialect/Quant/Utils/CMakeLists.txt
+++ b/mlir/lib/Dialect/Quant/Utils/CMakeLists.txt
@@ -1,5 +1,4 @@
add_mlir_dialect_library(MLIRQuantUtils
- QuantizeUtils.cpp
UniformSupport.cpp
FakeQuantSupport.cpp
diff --git a/mlir/lib/Dialect/Quant/Utils/QuantizeUtils.cpp b/mlir/lib/Dialect/Quant/Utils/QuantizeUtils.cpp
deleted file mode 100644
index 66885fb7a5fc1..0000000000000
--- a/mlir/lib/Dialect/Quant/Utils/QuantizeUtils.cpp
+++ /dev/null
@@ -1,147 +0,0 @@
-//===- QuantizeUtils.cpp - Support utilities for quantization -------------===//
-//
-// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
-// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-//
-//===----------------------------------------------------------------------===//
-
-#include "mlir/Dialect/Quant/QuantizeUtils.h"
-#include "mlir/Dialect/Quant/UniformSupport.h"
-#include "mlir/IR/Attributes.h"
-#include "mlir/IR/BuiltinTypes.h"
-
-using namespace mlir;
-using namespace mlir::quant;
-
-/// Converts a possible primitive, real expressed value attribute to a
-/// corresponding storage attribute (typically FloatAttr -> IntegerAttr).
-/// quantizedElementType is the QuantizedType that describes the expressed
-/// origValue.
-/// Returns a converter Attribute or nullptr if conversion is not possible.
-static Attribute convertPrimitiveValueAttr(
- Attribute origRealValue, QuantizedType quantizedElementType,
- const UniformQuantizedValueConverter &converter, Type &outConvertedType) {
- if (origRealValue.isa<FloatAttr>()) {
- FloatAttr floatAttr = origRealValue.cast<FloatAttr>();
- outConvertedType = quantizedElementType.getStorageType();
- return IntegerAttr::get(quantizedElementType.getStorageType(),
- converter.quantizeFloatToInt(floatAttr.getValue()));
- }
-
- return nullptr;
-}
-
-/// Converts a real expressed DenseFPElementsAttr to a corresponding
-/// DenseElementsAttr (typically DenseIntElementsAttr) containing quantized
-/// storage values assuming the given quantizedElementType and converter.
-static DenseElementsAttr
-convertDenseFPElementsAttr(DenseFPElementsAttr realFPElementsAttr,
- QuantizedType quantizedElementType,
- const UniformQuantizedValueConverter &converter) {
- // Convert to corresponding quantized value attributes.
- SmallVector<APInt, 8> quantValues;
- if (realFPElementsAttr.isSplat()) {
- quantValues.push_back(
- converter.quantizeFloatToInt(*realFPElementsAttr.begin()));
- } else {
- quantValues.reserve(realFPElementsAttr.getNumElements());
- for (APFloat realVal : realFPElementsAttr) {
- quantValues.push_back(converter.quantizeFloatToInt(realVal));
- }
- }
-
- // Cast from an expressed-type-based type to storage-type-based type,
- // preserving the dense shape (i.e. tensor<4xf32> -> tensor<4xi8>).
- ShapedType newDenseType =
- quantizedElementType
- .castExpressedToStorageType(realFPElementsAttr.getType())
- .dyn_cast_or_null<ShapedType>();
- if (!newDenseType) {
- return nullptr;
- }
- return DenseIntElementsAttr::get(newDenseType, quantValues);
-}
-
-/// Converts a real expressed SplatElementsAttr to a corresponding
-/// SplatElementsAttr containing quantized storage values assuming the given
-/// quantizedElementType and converter.
-static SparseElementsAttr
-convertSparseElementsAttr(SparseElementsAttr realSparseAttr,
- QuantizedType quantizedElementType,
- const UniformQuantizedValueConverter &converter) {
- DenseElementsAttr realDenseAttr = realSparseAttr.getValues();
- if (!realDenseAttr.isa<DenseFPElementsAttr>()) {
- return nullptr;
- }
- DenseElementsAttr quantDenseAttr =
- convertDenseFPElementsAttr(realDenseAttr.cast<DenseFPElementsAttr>(),
- quantizedElementType, converter);
- if (!quantDenseAttr) {
- return nullptr;
- }
-
- // Cast from an expressed-type-based type to storage-type-based type,
- // preserving the sparse shape (i.e. tensor<4xf32> -> tensor<4xi8>).
- ShapedType newSparseType =
- quantizedElementType.castExpressedToStorageType(realSparseAttr.getType())
- .dyn_cast_or_null<ShapedType>();
- if (!newSparseType) {
- return nullptr;
- }
- return SparseElementsAttr::get(newSparseType, realSparseAttr.getIndices(),
- quantDenseAttr);
-}
-
-/// Converts a real expressed Attribute to a corresponding Attribute containing
-/// quantized storage values assuming the given uniform quantizedElementType and
-/// converter.
-Attribute mlir::quant::quantizeAttrUniform(
- Attribute realValue, UniformQuantizedType quantizedElementType,
- const UniformQuantizedValueConverter &converter, Type &outConvertedType) {
- // Fork to handle different variants of constants supported.
- if (realValue.isa<DenseFPElementsAttr>()) {
- // Dense tensor or vector constant.
- auto converted = convertDenseFPElementsAttr(
- realValue.cast<DenseFPElementsAttr>(), quantizedElementType, converter);
- outConvertedType = converted.getType();
- return converted;
- }
- if (realValue.isa<SparseElementsAttr>()) {
- // Sparse tensor or vector constant.
- auto converted = convertSparseElementsAttr(
- realValue.cast<SparseElementsAttr>(), quantizedElementType, converter);
- outConvertedType = converted.getType();
- return converted;
- }
- // Nothing else matched: try to convert a primitive.
- return convertPrimitiveValueAttr(realValue, quantizedElementType, converter,
- outConvertedType);
-}
-
-/// Convert an attribute from a type based on
-/// quantizedElementType.getExpressedType() to one based on
-/// quantizedElementType.getStorageType().
-/// Returns nullptr if the conversion is not supported.
-/// On success, stores the converted type in outConvertedType.
-Attribute mlir::quant::quantizeAttr(Attribute realValue,
- QuantizedType quantizedElementType,
- Type &outConvertedType) {
- if (auto uniformQuantized =
- quantizedElementType.dyn_cast<UniformQuantizedType>()) {
- UniformQuantizedValueConverter converter(uniformQuantized);
- return quantizeAttrUniform(realValue, uniformQuantized, converter,
- outConvertedType);
- }
- if (auto uniformQuantizedPerAxis =
- quantizedElementType.dyn_cast<UniformQuantizedPerAxisType>()) {
- UniformQuantizedPerAxisValueConverter converter(uniformQuantizedPerAxis);
- auto converted = converter.convert(realValue);
- // TODO: why we need this outConvertedType? remove it?
- if (converted) {
- outConvertedType = converted.getType();
- }
- return converted;
- }
- return nullptr;
-}
diff --git a/mlir/test/Dialect/Quant/convert-const.mlir b/mlir/test/Dialect/Quant/convert-const.mlir
deleted file mode 100644
index 78fe85d561000..0000000000000
--- a/mlir/test/Dialect/Quant/convert-const.mlir
+++ /dev/null
@@ -1,193 +0,0 @@
-// RUN: mlir-opt %s -split-input-file -quant-convert-const | FileCheck %s
-
-// Magic numbers:
-// 7.8125e-03 = 1/128 = 2/256 : real range = [-1.0, 0.9921875] (for 8bit, zeroPoint=128)
-// 1.250000e-01 = 1/8 = 2/16 : real range = [-1.0, 0.875] (for 4bit, zeroPoint=8)
-
-// -----
-// Verifies u8 affine quantization on a splat tensor.
-// Note that MLIR prints int attributes as signed, so the constant, when
-// quantized, is the signed printed version of an unsigned quantity
-// (-64 signed == 192 unsigned).
-// CHECK-LABEL: constant_splat_tensor_u8_affine
-func.func @constant_splat_tensor_u8_affine() -> tensor<4xf32> {
- // CHECK: %cst = arith.constant dense<-64> : tensor<4xi8>
- // CHECK-NEXT: %0 = "quant.scast"(%cst) : (tensor<4xi8>) -> tensor<4x!quant.uniform<u8:f32, 7.812500e-03:128>>
- %cst = arith.constant dense<0.5> : tensor<4xf32>
- %1 = "quant.qcast"(%cst) : (tensor<4xf32>) -> tensor<4x!quant.uniform<u8:f32, 7.812500e-03:128>>
- %2 = "quant.dcast"(%1) : (tensor<4x!quant.uniform<u8:f32, 7.812500e-03:128>>) -> (tensor<4xf32>)
- return %2 : tensor<4xf32>
-}
-
-// -----
-// Verifies i8 affine quantization on a splat tensor.
-// CHECK-LABEL: constant_splat_tensor_i8_affine
-func.func @constant_splat_tensor_i8_affine() -> tensor<4xf32> {
- // CHECK: %cst = arith.constant dense<63> : tensor<4xi8>
- // CHECK-NEXT: %0 = "quant.scast"(%cst) : (tensor<4xi8>) -> tensor<4x!quant.uniform<i8:f32, 7.812500e-03:-1>>
- %cst = arith.constant dense<0.5> : tensor<4xf32>
- %1 = "quant.qcast"(%cst) : (tensor<4xf32>) -> tensor<4x!quant.uniform<i8:f32, 7.812500e-03:-1>>
- %2 = "quant.dcast"(%1) : (tensor<4x!quant.uniform<i8:f32, 7.812500e-03:-1>>) -> (tensor<4xf32>)
- return %2 : tensor<4xf32>
-}
-
-// -----
-// Verifies i8 fixedpoint quantization on a splat tensor.
-// CHECK-LABEL: const_splat_tensor_i8_fixedpoint
-func.func @const_splat_tensor_i8_fixedpoint() -> tensor<4xf32> {
- // CHECK: %cst = arith.constant dense<64> : tensor<4xi8>
- // CHECK-NEXT: %0 = "quant.scast"(%cst) : (tensor<4xi8>) -> tensor<4x!quant.uniform<i8:f32, 7.812500e-03>>
- %cst = arith.constant dense<0.5> : tensor<4xf32>
- %1 = "quant.qcast"(%cst) : (tensor<4xf32>) -> tensor<4x!quant.uniform<i8:f32, 7.812500e-03>>
- %2 = "quant.dcast"(%1) : (tensor<4x!quant.uniform<i8:f32, 7.812500e-03>>) -> (tensor<4xf32>)
- return %2 : tensor<4xf32>
-}
-
-// -----
-// Verifies i8 fixedpoint quantization on a splat tensor resulting in a negative storage value.
-// CHECK-LABEL: const_splat_tensor_i8_fixedpoint_neg
-func.func @const_splat_tensor_i8_fixedpoint_neg() -> tensor<4xf32> {
- // CHECK: %cst = arith.constant dense<-64> : tensor<4xi8>
- %cst = arith.constant dense<-0.5> : tensor<4xf32>
- %1 = "quant.qcast"(%cst) : (tensor<4xf32>) -> tensor<4x!quant.uniform<i8:f32, 7.812500e-03>>
- %2 = "quant.dcast"(%1) : (tensor<4x!quant.uniform<i8:f32, 7.812500e-03>>) -> (tensor<4xf32>)
- return %2 : tensor<4xf32>
-}
-
-// -----
-// Verifies i8 fixedpoint quantization on a dense tensor, sweeping values.
-// CHECK-LABEL: const_dense_tensor_i8_fixedpoint
-func.func @const_dense_tensor_i8_fixedpoint() -> tensor<7xf32> {
- // CHECK: %cst = arith.constant dense<[-128, -128, -64, 0, 64, 127, 127]> : tensor<7xi8>
- %cst = arith.constant dense<[-2.0, -1.0, -0.5, 0.0, 0.5, 1.0, 2.0]> : tensor<7xf32>
- %1 = "quant.qcast"(%cst) : (tensor<7xf32>) -> tensor<7x!quant.uniform<i8:f32, 7.812500e-03>>
- %2 = "quant.dcast"(%1) : (tensor<7x!quant.uniform<i8:f32, 7.812500e-03>>) -> (tensor<7xf32>)
- return %2 : tensor<7xf32>
-}
-
-// -----
-// Verifies i8 fixedpoint quantization on a sparse tensor, sweeping values.
-// CHECK-LABEL: const_sparse_tensor_i8_fixedpoint
-func.func @const_sparse_tensor_i8_fixedpoint() -> tensor<2x7xf32> {
- // NOTE: Ugly regex match pattern for opening "[[" of indices tensor.
- // CHECK: %cst = arith.constant sparse<{{\[}}[0, 0], [0, 1], [0, 2], [0, 3], [0, 4], [0, 5], [0, 6]], [-128, -128, -64, 0, 64, 127, 127]> : tensor<2x7xi8>
- %cst = arith.constant sparse<
- [[0, 0], [0, 1], [0, 2], [0, 3], [0, 4], [0, 5], [0, 6]],
- [-2.0, -1.0, -0.5, 0.0, 0.5, 1.0, 2.0]> : tensor<2x7xf32>
- %1 = "quant.qcast"(%cst) : (tensor<2x7xf32>) -> tensor<2x7x!quant.uniform<i8:f32, 7.812500e-03>>
- %2 = "quant.dcast"(%1) : (tensor<2x7x!quant.uniform<i8:f32, 7.812500e-03>>) -> (tensor<2x7xf32>)
- return %2 : tensor<2x7xf32>
-}
-
-// -----
-// Verifies i8 fixedpoint quantization on a primitive const.
-// CHECK-LABEL: const_primitive_float_i8_fixedpoint
-func.func @const_primitive_float_i8_fixedpoint() -> f32 {
- // CHECK: %c64_i8 = arith.constant 64 : i8
- // CHECK-NEXT: %0 = "quant.scast"(%c64_i8) : (i8) -> !quant.uniform<i8:f32, 7.812500e-03>
- %cst = arith.constant 0.5 : f32
- %1 = "quant.qcast"(%cst) : (f32) -> !quant.uniform<i8:f32, 7.812500e-03>
- %2 = "quant.dcast"(%1) : (!quant.uniform<i8:f32, 7.812500e-03>) -> (f32)
- return %2 : f32
-}
-
-// -----
-// Verifies u4 affine quantization on a dense tensor, sweeping values.
-// CHECK-LABEL: const_dense_tensor_u4_affine
-func.func @const_dense_tensor_u4_affine() -> tensor<7xf32> {
- // NOTE: Unsigned quantities printed by MLIR as signed.
- // CHECK: %cst = arith.constant dense<[0, 0, 4, -8, -4, -1, -1]> : tensor<7xi4>
- %cst = arith.constant dense<[-2.0, -1.0, -0.5, 0.0, 0.5, 1.0, 2.0]> : tensor<7xf32>
- %1 = "quant.qcast"(%cst) : (tensor<7xf32>) -> tensor<7x!quant.uniform<u4:f32, 1.250000e-01:8>>
- %2 = "quant.dcast"(%1) : (tensor<7x!quant.uniform<u4:f32, 1.250000e-01:8>>) -> (tensor<7xf32>)
- return %2 : tensor<7xf32>
-}
-
-// -----
-// Verifies i4 affine quantization on a dense tensor, sweeping values.
-// CHECK-LABEL: const_dense_tensor_i4_affine
-func.func @const_dense_tensor_i4_affine() -> tensor<7xf32> {
- // NOTE: Unsigned quantities printed by MLIR as signed.
- // CHECK: %cst = arith.constant dense<[-8, -8, -5, -1, 3, 7, 7]> : tensor<7xi4>
- %cst = arith.constant dense<[-2.0, -1.0, -0.5, 0.0, 0.5, 1.0, 2.0]> : tensor<7xf32>
- %1 = "quant.qcast"(%cst) : (tensor<7xf32>) -> tensor<7x!quant.uniform<i4:f32, 1.250000e-01:-1>>
- %2 = "quant.dcast"(%1) : (tensor<7x!quant.uniform<i4:f32, 1.250000e-01:-1>>) -> (tensor<7xf32>)
- return %2 : tensor<7xf32>
-}
-
-// -----
-// Verifies i4 fixed point quantization on a dense tensor, sweeping values.
-// CHECK-LABEL: const_dense_tensor_i4_fixedpoint
-func.func @const_dense_tensor_i4_fixedpoint() -> tensor<7xf32> {
- // CHECK: %cst = arith.constant dense<[-8, -8, -4, 0, 4, 7, 7]> : tensor<7xi4>
- %cst = arith.constant dense<[-2.0, -1.0, -0.5, 0.0, 0.5, 1.0, 2.0]> : tensor<7xf32>
- %1 = "quant.qcast"(%cst) : (tensor<7xf32>) -> tensor<7x!quant.uniform<i4:f32, 1.250000e-01>>
- %2 = "quant.dcast"(%1) : (tensor<7x!quant.uniform<i4:f32, 1.250000e-01>>) -> (tensor<7xf32>)
- return %2 : tensor<7xf32>
-}
-
-// -----
-// Verifies i8 fixedpoint quantization on a dense tensor, sweeping values, and
-// custom storage range. (the -128 should be clamped to -100, and the 127 should
-// be clamped to 100).
-// CHECK-LABEL: const_custom_storage_range_i8_fixedpoint
-func.func @const_custom_storage_range_i8_fixedpoint() -> tensor<7xf32> {
- // CHECK: %cst = arith.constant dense<[-100, -100, -64, 0, 64, 100, 100]> : tensor<7xi8>
- %cst = arith.constant dense<[-2.0, -1.0, -0.5, 0.0, 0.5, 1.0, 2.0]> : tensor<7xf32>
- %1 = "quant.qcast"(%cst) : (tensor<7xf32>) -> tensor<7x!quant.uniform<i8<-100:100>:f32, 7.812500e-03>>
- %2 = "quant.dcast"(%1) : (tensor<7x!quant.uniform<i8<-100:100>:f32, 7.812500e-03>>) -> (tensor<7xf32>)
- return %2 : tensor<7xf32>
-}
-
-// -----
-// Verifies quantization results of all-0.0 tensors are quantized to zero points.
-// CHECK-LABEL: zero_tensors_to_zero_points
-func.func @zero_tensors_to_zero_points() -> (tensor<7xf32>, tensor<7xf32>, tensor<7xf32>, tensor<7xf32>) {
-
-// CHECK-DAG: %[[cst1:.*]] = arith.constant dense<1> : tensor<7xi8>
-// CHECK-DAG: %[[cst:.*]] = arith.constant dense<-127> : tensor<7xi8>
-// CHECK-DAG: %[[cst0:.*]] = arith.constant dense<0> : tensor<7xi8>
-// CHECK: "quant.scast"(%[[cst0]]) : (tensor<7xi8>) -> tensor<7x!quant.uniform<i8:f32, 1.000000e+00>>
-// CHECK: "quant.scast"(%[[cst]]) : (tensor<7xi8>) -> tensor<7x!quant.uniform<i8<-127:127>:f32, 1.000000e+00:-127>>
-// CHECK: "quant.scast"(%[[cst0]]) : (tensor<7xi8>) -> tensor<7x!quant.uniform<u8:f32, 1.000000e+00>>
-// CHECK: "quant.scast"(%[[cst1]]) : (tensor<7xi8>) -> tensor<7x!quant.uniform<u8<1:255>:f32, 1.000000e+00:1>>
-
- %cst = arith.constant dense<0.0> : tensor<7xf32>
- %1 = "quant.qcast"(%cst) : (tensor<7xf32>) -> tensor<7x!quant.uniform<i8:f32, 1.0>>
- %2 = "quant.dcast"(%1) : (tensor<7x!quant.uniform<i8:f32, 1.0>>) -> (tensor<7xf32>)
-
- %cst0 = arith.constant dense<0.0> : tensor<7xf32>
- %3 = "quant.qcast"(%cst0) : (tensor<7xf32>) -> tensor<7x!quant.uniform<i8<-127:127>:f32, 1.0:-127>>
- %4 = "quant.dcast"(%3) : (tensor<7x!quant.uniform<i8<-127:127>:f32, 1.0:-127>>) -> (tensor<7xf32>)
-
- %cst1 = arith.constant dense<0.0> : tensor<7xf32>
- %5 = "quant.qcast"(%cst1) : (tensor<7xf32>) -> tensor<7x!quant.uniform<u8:f32, 1.0>>
- %6 = "quant.dcast"(%5) : (tensor<7x!quant.uniform<u8:f32, 1.0>>) -> (tensor<7xf32>)
-
- %cst2 = arith.constant dense<0.0> : tensor<7xf32>
- %7 = "quant.qcast"(%cst2) : (tensor<7xf32>) -> tensor<7x!quant.uniform<u8<1:255>:f32, 1.0:1>>
- %8 = "quant.dcast"(%7) : (tensor<7x!quant.uniform<u8<1:255>:f32, 1.0:1>>) -> (tensor<7xf32>)
-
- return %2, %4, %6, %8 : tensor<7xf32>, tensor<7xf32>, tensor<7xf32>, tensor<7xf32>
-}
-
-// -----
-// Verifies per-axis quantization results for dense.
-// CHECK-LABEL: per_axis_dense_quantization
-func.func @per_axis_dense_quantization() -> (tensor<2x3xf32>, tensor<2x3xf32>) {
-
-// CHECK-DAG: %[[cst0:.*]] = arith.constant dense<{{\[}}[-128, -1, 1], [127, 1, 3]]> : tensor<2x3xi8>
-// CHECK-DAG: %[[cst:.*]] = arith.constant dense<{{\[}}[-128, 64, 127], [0, 1, 2]]> : tensor<2x3xi8>
-// CHECK: "quant.scast"(%[[cst]]) : (tensor<2x3xi8>) -> tensor<2x3x!quant.uniform<i8:f32:0, {7.812500e-03:128,1.000000e+00}>>
-// CHECK: "quant.scast"(%[[cst0]]) : (tensor<2x3xi8>) -> tensor<2x3x!quant.uniform<i8:f32:1, {7.812500e-03:128,1.000000e+00,1.000000e+00:1}>>
-
- %cst = arith.constant dense<[[-2.0, -0.5, 0.0], [0.0, 1.0, 2.0]]> : tensor<2x3xf32>
- %1 = "quant.qcast"(%cst) : (tensor<2x3xf32>) -> tensor<2x3x!quant.uniform<i8:f32:0, {7.812500e-03:128, 1.0}>>
- %2 = "quant.dcast"(%1) : (tensor<2x3x!quant.uniform<i8:f32:0, {7.812500e-03:128, 1.0}>>) -> (tensor<2x3xf32>)
-
- %cst0 = arith.constant dense<[[-2.0, -0.5, 0.0], [0.0, 1.0, 2.0]]> : tensor<2x3xf32>
- %3 = "quant.qcast"(%cst0) : (tensor<2x3xf32>) -> tensor<2x3x!quant.uniform<i8:f32:1, {7.812500e-03:128, 1.0, 1.0:1}>>
- %4 = "quant.dcast"(%3) : (tensor<2x3x!quant.uniform<i8:f32:1, {7.812500e-03:128, 1.0, 1.0:1}>>) -> (tensor<2x3xf32>)
-
- return %2, %4 : tensor<2x3xf32>, tensor<2x3xf32>
-}
diff --git a/mlir/test/Dialect/Quant/convert-fakequant-invalid.mlir b/mlir/test/Dialect/Quant/convert-fakequant-invalid.mlir
deleted file mode 100644
index bd4a0f96ababa..0000000000000
--- a/mlir/test/Dialect/Quant/convert-fakequant-invalid.mlir
+++ /dev/null
@@ -1,12 +0,0 @@
-// RUN: mlir-opt %s -split-input-file -verify-diagnostics -quant-convert-simulated-quantization
-
-// -----
-// Unsupported quantizable type (i1 is currently not a supported element type).
-func.func @fakeQuantArgs(tensor<8x4x3xi1>) -> tensor<8x4x3xi1> {
-^bb0(%arg0: tensor<8x4x3xi1>):
- // expected-error@+1 {{op operand #0 must be tensor of 32-bit float values}}
- %0 = "quant.const_fake_quant"(%arg0) {
- min = 1.1 : f32, max = 1.0 : f32, num_bits = 8
- } : (tensor<8x4x3xi1>) -> tensor<8x4x3xi1>
- return %0 : tensor<8x4x3xi1>
-}
diff --git a/mlir/test/Dialect/Quant/convert-fakequant.mlir b/mlir/test/Dialect/Quant/convert-fakequant.mlir
deleted file mode 100644
index 14983591fac7c..0000000000000
--- a/mlir/test/Dialect/Quant/convert-fakequant.mlir
+++ /dev/null
@@ -1,233 +0,0 @@
-// RUN: mlir-opt %s -split-input-file -quant-convert-simulated-quantization | FileCheck %s
-
-// -----
-// Verifies a quint8 single point.
-// CHECK-LABEL: fakeQuantArgs_Quint8_0
-func.func @fakeQuantArgs_Quint8_0(tensor<8x4x3xf32>) -> tensor<8x4x3xf32> {
-^bb0(%arg0: tensor<8x4x3xf32>):
- // CHECK: %[[qc:.*]] = "quant.qcast"(%arg0) : (tensor<8x4x3xf32>)
- // CHECK-SAME: -> tensor<8x4x3x!quant.uniform<u8:f32, 1.000000e+00>>
- // CHECK-NEXT: "quant.dcast"(%[[qc]]) : (tensor<8x4x3x!quant.uniform<u8:f32, 1.000000e+00>>)
- // CHECK-SAME: -> tensor<8x4x3xf32>
- %0 = "quant.const_fake_quant"(%arg0) {
- min = 0.0 : f32, max = 0.0 : f32, num_bits = 8
- } : (tensor<8x4x3xf32>) -> tensor<8x4x3xf32>
- return %0 : tensor<8x4x3xf32>
-}
-
-// -----
-// Verifies a quint8 single point (with narrow_range = true).
-// CHECK-LABEL: fakeQuantArgs_Quint8_0_NarrowRange
-func.func @fakeQuantArgs_Quint8_0_NarrowRange(tensor<8x4x3xf32>) -> tensor<8x4x3xf32> {
-^bb0(%arg0: tensor<8x4x3xf32>):
- // CHECK: %[[qc:.*]] = "quant.qcast"(%arg0) : (tensor<8x4x3xf32>)
- // CHECK-SAME: -> tensor<8x4x3x!quant.uniform<u8<1:255>:f32, 1.000000e+00:1>>
- // CHECK-NEXT: "quant.dcast"(%[[qc]]) : (tensor<8x4x3x!quant.uniform<u8<1:255>:f32, 1.000000e+00:1>>)
- // CHECK-SAME: -> tensor<8x4x3xf32>
- %0 = "quant.const_fake_quant"(%arg0) {
- min = 0.0 : f32, max = 0.0 : f32, num_bits = 8, narrow_range = true
- } : (tensor<8x4x3xf32>) -> tensor<8x4x3xf32>
- return %0 : tensor<8x4x3xf32>
-}
-
-// -----
-// Verifies a quint8 asymmetric 0..1 range.
-// CHECK-LABEL: fakeQuantArgs_Quint8_0_1
-func.func @fakeQuantArgs_Quint8_0_1(tensor<8x4x3xf32>) -> tensor<8x4x3xf32> {
-^bb0(%arg0: tensor<8x4x3xf32>):
- // CHECK: %0 = "quant.qcast"(%arg0) : (tensor<8x4x3xf32>)
- // CHECK-SAME: -> tensor<8x4x3x!quant.uniform<u8:f32, 0.0039215686274509803>>
- // CHECK-NEXT: %1 = "quant.dcast"(%0) : (tensor<8x4x3x!quant.uniform<u8:f32, 0.0039215686274509803>>)
- // CHECK-SAME: -> tensor<8x4x3xf32>
- %0 = "quant.const_fake_quant"(%arg0) {
- min = 0.0 : f32, max = 1.0 : f32, num_bits = 8
- } : (tensor<8x4x3xf32>) -> tensor<8x4x3xf32>
- return %0 : tensor<8x4x3xf32>
-}
-
-// -----
-// Verifies a quint8 asymmetric 0..1 range (with narrow_range = true).
-// CHECK-LABEL: fakeQuantArgs_Quint8_NarrowRange
-func.func @fakeQuantArgs_Quint8_NarrowRange(tensor<8x4x3xf32>) -> tensor<8x4x3xf32> {
-^bb0(%arg0: tensor<8x4x3xf32>):
- // CHECK: %0 = "quant.qcast"(%arg0) : (tensor<8x4x3xf32>)
- // CHECK-SAME: -> tensor<8x4x3x!quant.uniform<u8<1:255>:f32, 0.003937007874015748:1>>
- // CHECK-NEXT: %1 = "quant.dcast"(%0) : (tensor<8x4x3x!quant.uniform<u8<1:255>:f32, 0.003937007874015748:1>>)
- // CHECK-SAME: -> tensor<8x4x3xf32>
- %0 = "quant.const_fake_quant"(%arg0) {
- min = 0.0 : f32, max = 1.0 : f32, num_bits = 8, narrow_range = true
- } : (tensor<8x4x3xf32>) -> tensor<8x4x3xf32>
- return %0 : tensor<8x4x3xf32>
-}
-
-// -----
-// Verifies a quint8 symmetric range of -1..127/128.
-// CHECK-LABEL: fakeQuantArgs_Quint8_SymmetricRange
-func.func @fakeQuantArgs_Quint8_SymmetricRange(tensor<8x4x3xf32>) -> tensor<8x4x3xf32> {
-^bb0(%arg0: tensor<8x4x3xf32>):
- // CHECK: %0 = "quant.qcast"(%arg0) : (tensor<8x4x3xf32>)
- // CHECK-SAME: -> tensor<8x4x3x!quant.uniform<u8:f32, 7.812500e-03:128>>
- // CHECK-NEXT: %1 = "quant.dcast"(%0) : (tensor<8x4x3x!quant.uniform<u8:f32, 7.812500e-03:128>>)
- // CHECK-SAME: -> tensor<8x4x3xf32>
- %0 = "quant.const_fake_quant"(%arg0) {
- min = -1.0 : f32, max = 0.9921875 : f32, num_bits = 8, narrow_range = false
- } : (tensor<8x4x3xf32>) -> tensor<8x4x3xf32>
- return %0 : tensor<8x4x3xf32>
-}
-
-// -----
-// Verifies a qint8 single point.
-// CHECK-LABEL: fakeQuantArgs_Qint8_0
-func.func @fakeQuantArgs_Qint8_0(tensor<8x4x3xf32>) -> tensor<8x4x3xf32> {
-^bb0(%arg0: tensor<8x4x3xf32>):
- // CHECK: %[[qc:.*]] = "quant.qcast"(%arg0) : (tensor<8x4x3xf32>)
- // CHECK-SAME: -> tensor<8x4x3x!quant.uniform<i8:f32, 1.000000e+00:-128>>
- // CHECK-NEXT: "quant.dcast"(%[[qc]]) : (tensor<8x4x3x!quant.uniform<i8:f32, 1.000000e+00:-128>>)
- // CHECK-SAME: -> tensor<8x4x3xf32>
- %0 = "quant.const_fake_quant"(%arg0) {
- min = 0.0 : f32, max = 0.0 : f32, num_bits = 8, is_signed = true
- } : (tensor<8x4x3xf32>) -> tensor<8x4x3xf32>
- return %0 : tensor<8x4x3xf32>
-}
-
-// -----
-// Verifies a qint8 single point (with narrow_range = true).
-// CHECK-LABEL: fakeQuantArgs_Qint8_0_NarrowRange
-func.func @fakeQuantArgs_Qint8_0_NarrowRange(tensor<8x4x3xf32>) -> tensor<8x4x3xf32> {
-^bb0(%arg0: tensor<8x4x3xf32>):
- // CHECK: %[[qc:.*]] = "quant.qcast"(%arg0) : (tensor<8x4x3xf32>)
- // CHECK-SAME: -> tensor<8x4x3x!quant.uniform<i8<-127:127>:f32, 1.000000e+00:-127>>
- // CHECK-NEXT: "quant.dcast"(%[[qc]]) : (tensor<8x4x3x!quant.uniform<i8<-127:127>:f32, 1.000000e+00:-127>>)
- // CHECK-SAME: -> tensor<8x4x3xf32>
- %0 = "quant.const_fake_quant"(%arg0) {
- min = 0.0 : f32, max = 0.0 : f32, num_bits = 8, narrow_range = true, is_signed = true
- } : (tensor<8x4x3xf32>) -> tensor<8x4x3xf32>
- return %0 : tensor<8x4x3xf32>
-}
-
-// -----
-// Verifies a qint8 asymmetric 0..1 range.
-// CHECK-LABEL: fakeQuantArgs_Qint8_0_1
-func.func @fakeQuantArgs_Qint8_0_1(tensor<8x4x3xf32>) -> tensor<8x4x3xf32> {
-^bb0(%arg0: tensor<8x4x3xf32>):
- // CHECK: %0 = "quant.qcast"(%arg0) : (tensor<8x4x3xf32>)
- // CHECK-SAME: -> tensor<8x4x3x!quant.uniform<i8:f32, 0.0039215686274509803:-128>>
- // CHECK-NEXT: %1 = "quant.dcast"(%0) : (tensor<8x4x3x!quant.uniform<i8:f32, 0.0039215686274509803:-128>>)
- // CHECK-SAME: -> tensor<8x4x3xf32>
- %0 = "quant.const_fake_quant"(%arg0) {
- min = 0.0 : f32, max = 1.0 : f32, num_bits = 8, is_signed = true
- } : (tensor<8x4x3xf32>) -> tensor<8x4x3xf32>
- return %0 : tensor<8x4x3xf32>
-}
-
-// -----
-// Verifies a qint8 asymmetric 0..1 range (with narrow_range = true).
-// CHECK-LABEL: fakeQuantArgs_Qint8_NarrowRange
-func.func @fakeQuantArgs_Qint8_NarrowRange(tensor<8x4x3xf32>) -> tensor<8x4x3xf32> {
-^bb0(%arg0: tensor<8x4x3xf32>):
- // CHECK: %0 = "quant.qcast"(%arg0) : (tensor<8x4x3xf32>)
- // CHECK-SAME: -> tensor<8x4x3x!quant.uniform<i8<-127:127>:f32, 0.003937007874015748:-127>>
- // CHECK-NEXT: %1 = "quant.dcast"(%0) : (tensor<8x4x3x!quant.uniform<i8<-127:127>:f32, 0.003937007874015748:-127>>)
- // CHECK-SAME: -> tensor<8x4x3xf32>
- %0 = "quant.const_fake_quant"(%arg0) {
- min = 0.0 : f32, max = 1.0 : f32, num_bits = 8, narrow_range = true, is_signed = true
- } : (tensor<8x4x3xf32>) -> tensor<8x4x3xf32>
- return %0 : tensor<8x4x3xf32>
-}
-
-// -----
-// Verifies a qint8 symmetric range of -1..127/128.
-// CHECK-LABEL: fakeQuantArgs_Qint8_SymmetricRange
-func.func @fakeQuantArgs_Qint8_SymmetricRange(tensor<8x4x3xf32>) -> tensor<8x4x3xf32> {
-^bb0(%arg0: tensor<8x4x3xf32>):
- // CHECK: %0 = "quant.qcast"(%arg0) : (tensor<8x4x3xf32>)
- // CHECK-SAME: -> tensor<8x4x3x!quant.uniform<i8:f32, 7.812500e-03>>
- // CHECK-NEXT: %1 = "quant.dcast"(%0) : (tensor<8x4x3x!quant.uniform<i8:f32, 7.812500e-03>>)
- // CHECK-SAME: -> tensor<8x4x3xf32>
- %0 = "quant.const_fake_quant"(%arg0) {
- min = -1.0 : f32, max = 0.9921875 : f32, num_bits = 8, narrow_range = false, is_signed = true
- } : (tensor<8x4x3xf32>) -> tensor<8x4x3xf32>
- return %0 : tensor<8x4x3xf32>
-}
-
-// -----
-// Verifies a commonly used -1..1 symmetric 16bit range with a zero point of
-// 0 and range -1.0 .. 32767/32768.
-// CHECK-LABEL: fakeQuantArgs_Qint16_Symmetric
-func.func @fakeQuantArgs_Qint16_Symmetric(tensor<8x4x3xf32>) -> tensor<8x4x3xf32> {
-^bb0(%arg0: tensor<8x4x3xf32>):
- // CHECK: %0 = "quant.qcast"(%arg0) : (tensor<8x4x3xf32>)
- // CHECK-SAME: -> tensor<8x4x3x!quant.uniform<i16:f32, 3.0517578125E-5>>
- // CHECK-NEXT: %1 = "quant.dcast"(%0) : (tensor<8x4x3x!quant.uniform<i16:f32, 3.0517578125E-5>>)
- // CHECK-SAME: -> tensor<8x4x3xf32>
- %0 = "quant.const_fake_quant"(%arg0) {
- min = -1.0 : f32, max = 0.999969482 : f32, num_bits = 16, is_signed = true
- } : (tensor<8x4x3xf32>) -> tensor<8x4x3xf32>
- return %0 : tensor<8x4x3xf32>
-}
-
-// -----
-// Verify that lowering to barriers of unranked tensors functions.
-// CHECK-LABEL: fakeQuantArgs_UnrankedTensor
-func.func @fakeQuantArgs_UnrankedTensor(tensor<f32>) -> tensor<f32> {
-^bb0(%arg0: tensor<f32>):
- // CHECK: %0 = "quant.qcast"(%arg0) : (tensor<f32>)
- // CHECK-SAME: -> tensor<!quant.uniform<u8:f32, 0.0039215686274509803>>
- // CHECK-NEXT: %1 = "quant.dcast"(%0) : (tensor<!quant.uniform<u8:f32, 0.0039215686274509803>>)
- // CHECK-SAME: -> tensor<f32>
- %0 = "quant.const_fake_quant"(%arg0) {
- min = 0.0 : f32, max = 1.0 : f32, num_bits = 8
- } : (tensor<f32>) -> tensor<f32>
- return %0 : tensor<f32>
-}
-
-// -----
-// CHECK-LABEL: fakeQuantArgs_all_positive
-func.func @fakeQuantArgs_all_positive(tensor<8x4x3xf32>) -> tensor<8x4x3xf32> {
-^bb0(%arg0: tensor<8x4x3xf32>):
-
- // CHECK: %[[qc:.*]] = "quant.qcast"(%arg0) : (tensor<8x4x3xf32>)
- // CHECK-SAME: -> tensor<8x4x3x!quant.uniform<i8:f32, 0.0039215686274509803:-128>>
- // CHECK-NEXT: "quant.dcast"(%[[qc]]) : (tensor<8x4x3x!quant.uniform<i8:f32, 0.0039215686274509803:-128>>)
- // CHECK-SAME: -> tensor<8x4x3xf32>
-
- %0 = "quant.const_fake_quant"(%arg0) {
- min = 0.5 : f32, max = 1.5 : f32, num_bits = 8, narrow_range = false, is_signed = true
- } : (tensor<8x4x3xf32>) -> tensor<8x4x3xf32>
- return %0 : tensor<8x4x3xf32>
-}
-
-// -----
-// CHECK-LABEL: fakeQuantArgs_all_negative
-func.func @fakeQuantArgs_all_negative(tensor<8x4x3xf32>) -> tensor<8x4x3xf32> {
-^bb0(%arg0: tensor<8x4x3xf32>):
-
- // CHECK: %[[qc:.*]] = "quant.qcast"(%arg0) : (tensor<8x4x3xf32>)
- // CHECK-SAME: -> tensor<8x4x3x!quant.uniform<i8:f32, 0.0039215686274509803:127>>
- // CHECK-NEXT: "quant.dcast"(%[[qc]]) : (tensor<8x4x3x!quant.uniform<i8:f32, 0.0039215686274509803:127>>)
- // CHECK-SAME: -> tensor<8x4x3xf32>
-
- %0 = "quant.const_fake_quant"(%arg0) {
- min = -1.5 : f32, max = -0.5 : f32, num_bits = 8, narrow_range = false, is_signed = true
- } : (tensor<8x4x3xf32>) -> tensor<8x4x3xf32>
- return %0 : tensor<8x4x3xf32>
-}
-
-// -----
-// Verifies a qint8 per axis
-// CHECK-LABEL: fakeQuantPerAxis
-func.func @fakeQuantPerAxis(tensor<8x4x3xf32>) -> tensor<8x4x3xf32> {
-^bb0(%arg0: tensor<8x4x3xf32>):
-
- // CHECK: %[[q:.*]] = "quant.qcast"(%arg0) : (tensor<8x4x3xf32>)
- // CHECK-SAME: -> tensor<8x4x3x!quant.uniform<i8:f32:2, {7.812500e-03,1.000000e+00:-128,0.0039215686274509803:-128}>>
- // CHECK: %[[d:.*]] = "quant.dcast"(%[[q]])
- // CHECK-SAME: (tensor<8x4x3x!quant.uniform<i8:f32:2, {7.812500e-03,1.000000e+00:-128,0.0039215686274509803:-128}>>)
-
- %0 = "quant.const_fake_quant_per_axis"(%arg0) {
- min = [-1.0 : f32, 0.0 : f32, 0.0 : f32],
- max = [0.9921875 : f32, 0.0: f32, 1.0 : f32],
- num_bits = 8, narrow_range = false, is_signed = true, axis = 2
- } : (tensor<8x4x3xf32>) -> tensor<8x4x3xf32>
- return %0 : tensor<8x4x3xf32>
-}
diff --git a/mlir/test/Dialect/Quant/parse-ops-invalid.mlir b/mlir/test/Dialect/Quant/parse-ops-invalid.mlir
deleted file mode 100644
index 2b2a9eed84806..0000000000000
--- a/mlir/test/Dialect/Quant/parse-ops-invalid.mlir
+++ /dev/null
@@ -1,93 +0,0 @@
-// RUN: mlir-opt -allow-unregistered-dialect %s -split-input-file -verify-diagnostics
-
-// -----
-func.func @invalidStatisticsMismatchedLayerType(%arg0: tensor<8x4x3xf32>) ->
- tensor<8x4x3xf32> {
- // expected-error@+1 {{layerStats must have a floating point element type}}
- %0 = "quant.stats"(%arg0) {
- layerStats = dense<[-1, 1]> : tensor<2xi8>
- } : (tensor<8x4x3xf32>) -> tensor<8x4x3xf32>
- return %0 : tensor<8x4x3xf32>
-}
-
-// -----
-func.func @invalidStatisticsMismatchedLayerRank(%arg0: tensor<8x4x3xf32>) ->
- tensor<8x4x3xf32> {
- // expected-error@+1 {{layerStats must have shape [2]}}
- %0 = "quant.stats"(%arg0) {
- layerStats = dense<[[-1.0, 1.0]]> : tensor<1x2xf32>
- } : (tensor<8x4x3xf32>) -> tensor<8x4x3xf32>
- return %0 : tensor<8x4x3xf32>
-}
-
-// -----
-func.func @invalidStatisticsMismatchedLayerShape(%arg0: tensor<8x4x3xf32>) ->
- tensor<8x4x3xf32> {
- // expected-error@+1 {{layerStats must have shape [2]}}
- %0 = "quant.stats"(%arg0) {
- layerStats = dense<[-1.0, 1.0, 2.0]> : tensor<3xf32>
- } : (tensor<8x4x3xf32>) -> tensor<8x4x3xf32>
- return %0 : tensor<8x4x3xf32>
-}
-
-// -----
-// CHECK-LABEL: validStatistics
-func.func @invalidStatisticsMismatchedAxisType(%arg0: tensor<8x4x3xf32>) -> tensor<8x4x3xf32> {
- // expected-error@+1 {{axisStats must have a floating point element type}}
- %0 = "quant.stats"(%0) {
- layerStats = dense<[-1.0, 1.0]> : tensor<2xf32>,
- axisStats = dense<[
- [-1, 1],
- [-8, 8],
- [-1, 0]
- ]> : tensor<3x2xi8>, axis = 3 : i64
- } : (tensor<8x4x3xf32>) -> tensor<8x4x3xf32>
- return %0 : tensor<8x4x3xf32>
-}
-
-// -----
-func.func @invalidStatisticsMismatchedAxisSize(%arg0: tensor<8x4x3xf32>) ->
- tensor<8x4x3xf32> {
- // expected-error@+1 {{axisStats must have shape [N,2] where N = the slice size defined by the axis dim}}
- %0 = "quant.stats"(%arg0) {
- layerStats = dense<[-1.0, 1.0]> : tensor<2xf32>,
- axisStats = dense<[
- [-1.0, 1.0],
- [-8.0, 8.0],
- [-0.5, 0.5],
- [-2.0, 3.5]
- ]> : tensor<4x2xf32>, axis = 3 : i64
- } : (tensor<8x4x3xf32>) -> tensor<8x4x3xf32>
- return %0 : tensor<8x4x3xf32>
-}
-
-// -----
-func.func @invalidStatisticsMismatchedAxisShape(%arg0: tensor<8x4x3xf32>) ->
- tensor<8x4x3xf32> {
- // expected-error@+1 {{axisStats must have shape [N,2] where N = the slice size defined by the axis dim}}
- %0 = "quant.stats"(%arg0) {
- layerStats = dense<[-1.0, 1.0]> : tensor<2xf32>,
- axisStats = dense<[
- [-1.0, 1.0, 1.0],
- [-8.0, 8.0, 1.0],
- [-0.5, 0.5, 1.0]
- ]> : tensor<3x3xf32>, axis = 3 : i64
- } : (tensor<8x4x3xf32>) -> tensor<8x4x3xf32>
- return %0 : tensor<8x4x3xf32>
-}
-
-// -----
-func.func @axisIsRequiredForAxisStats(%arg0: tensor<8x4x3xf32>) -> tensor<8x4x3xf32> {
- // expected-error@+1 {{axis must be specified for axisStats}}
- %1 = "quant.stats"(%arg0) {
- layerStats = dense<[-1.0, 1.0]> : tensor<2xf32>,
- axisStats = dense<[
- [-1.0, 1.0],
- [-8.0, 8.0],
- [-0.5, 0.5]
- ]> : tensor<3x2xf32>
- } : (tensor<8x4x3xf32>) -> tensor<8x4x3xf32>
- return %1 : tensor<8x4x3xf32>
-}
-
-// -----
diff --git a/mlir/test/Dialect/Quant/parse-ops.mlir b/mlir/test/Dialect/Quant/parse-ops.mlir
deleted file mode 100644
index c20b0deb49865..0000000000000
--- a/mlir/test/Dialect/Quant/parse-ops.mlir
+++ /dev/null
@@ -1,64 +0,0 @@
-// RUN: mlir-opt %s -split-input-file | FileCheck %s
-
-// -----
-// CHECK-LABEL: validConstFakeQuant
-func.func @validConstFakeQuant(%arg0: tensor<8x4x3xf32>) -> tensor<8x4x3xf32> {
- %0 = "quant.const_fake_quant"(%arg0) {
- min = 0.0 : f32, max = 1.0 : f32, num_bits = 8, narrow_range = true
- } : (tensor<8x4x3xf32>) -> tensor<8x4x3xf32>
- %1 = "quant.const_fake_quant"(%0) {
- min = 0.0 : f32, max = 1.0 : f32, num_bits = 8, narrow_range = false
- } : (tensor<8x4x3xf32>) -> tensor<8x4x3xf32>
- %2 = "quant.const_fake_quant"(%1) {
- min = 0.0 : f32, max = 1.0 : f32, num_bits = 8
- } : (tensor<8x4x3xf32>) -> tensor<8x4x3xf32>
- return %2 : tensor<8x4x3xf32>
-}
-
-// -----
-// CHECK-LABEL: validConstFakeQuantPerAxis
-func.func @validConstFakeQuantPerAxis(%arg0: tensor<8x4x2xf32>) -> tensor<8x4x2xf32> {
- %0 = "quant.const_fake_quant_per_axis"(%arg0) {
- min = [0.0 : f32, 1.0 : f32], max = [2.0 : f32, 3.0 : f32], axis = 2, num_bits = 8, narrow_range = true
- } : (tensor<8x4x2xf32>) -> tensor<8x4x2xf32>
- %1 = "quant.const_fake_quant_per_axis"(%0) {
- min = [0.0 : f32, 1.0 : f32], max = [2.0 : f32, 3.0 : f32], axis = 2, num_bits = 8, narrow_range = false
- } : (tensor<8x4x2xf32>) -> tensor<8x4x2xf32>
- %2 = "quant.const_fake_quant_per_axis"(%1) {
- min = [0.0 : f32, 1.0 : f32], max = [2.0 : f32, 3.0 : f32], axis = 2, num_bits = 8
- } : (tensor<8x4x2xf32>) -> tensor<8x4x2xf32>
- return %2 : tensor<8x4x2xf32>
-}
-
-// -----
-// CHECK-LABEL: validStatisticsRef
-func.func @validStatisticsRef(%arg0: tensor<8x4x3xf32>) -> tensor<8x4x3xf32> {
- %0 = "quant.stats_ref"(%arg0) { statsKey = "foobar" } :
- (tensor<8x4x3xf32>) -> tensor<8x4x3xf32>
- return %0 : tensor<8x4x3xf32>
-}
-
-// -----
-// CHECK-LABEL: validStatistics
-func.func @validStatistics(%arg0: tensor<8x4x3xf32>) -> tensor<8x4x3xf32> {
- %0 = "quant.stats"(%arg0) {
- layerStats = dense<[-1.0, 1.0]> : tensor<2xf32>
- } : (tensor<8x4x3xf32>) -> tensor<8x4x3xf32>
- %1 = "quant.stats"(%0) {
- layerStats = dense<[-1.0, 1.0]> : tensor<2xf32>,
- axisStats = dense<[
- [-1.0, 1.0],
- [-8.0, 8.0],
- [-0.5, 0.5]
- ]> : tensor<3x2xf32>, axis = 2 : i64
- } : (tensor<8x4x3xf32>) -> tensor<8x4x3xf32>
- return %1 : tensor<8x4x3xf32>
-}
-
-// -----
-// CHECK-LABEL: validCoupledRef
-func.func @validCoupledRef(%arg0: tensor<8x4x3xf32>) -> tensor<8x4x3xf32> {
- %0 = "quant.coupled_ref"(%arg0) { coupledKey = "foobar" } :
- (tensor<8x4x3xf32>) -> tensor<8x4x3xf32>
- return %0 : tensor<8x4x3xf32>
-}
diff --git a/mlir/test/Dialect/Quant/quant_region.mlir b/mlir/test/Dialect/Quant/quant_region.mlir
deleted file mode 100644
index 437edf8f3e04f..0000000000000
--- a/mlir/test/Dialect/Quant/quant_region.mlir
+++ /dev/null
@@ -1,131 +0,0 @@
-// RUN: mlir-opt -allow-unregistered-dialect -split-input-file -verify-diagnostics %s | FileCheck %s
-
-// CHECK-LABEL: @source
-func.func @source(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>, %arg2: tensor<4xf32>) -> (tensor<4xf32>) {
- %0 = "quant.region"(%arg0, %arg1, %arg2) ({
- ^bb0(%10: tensor<4xf32>, %11: tensor<4xf32>, %12: tensor<4xf32>):
- %13 = "foo"(%10, %11) : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
- %14 = "bar"(%13, %12) : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
- "quant.return"(%14) : (tensor<4xf32>) -> ()
- }) {input_specs = [f32, f32, f32], output_specs = [f32], logical_kernel = "xyz"}
- : (tensor<4xf32>, tensor<4xf32>, tensor<4xf32>) -> (tensor<4xf32>)
- return %0 : tensor<4xf32>
-}
-
-// CHECK-LABEL: @annotated
-func.func @annotated(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>, %arg2: tensor<4xf32>) -> (tensor<4xf32>) {
- %0 = "quant.region"(%arg0, %arg1, %arg2) ({
- ^bb0(%10: tensor<4xf32>, %11: tensor<4xf32>, %12: tensor<4xf32>):
- %13 = "foo"(%10, %11) : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
- %14 = "bar"(%13, %12) : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
- "quant.return"(%14) : (tensor<4xf32>) -> ()
- }) {input_specs = [!quant.uniform<i8:f32, 1.0>, !quant.uniform<i8:f32, 2.0>, f32],
- output_specs = [!quant.uniform<i8:f32, 4.0>], logical_kernel = "xyz"}
- : (tensor<4xf32>, tensor<4xf32>, tensor<4xf32>) -> (tensor<4xf32>)
- return %0 : tensor<4xf32>
-}
-
-// CHECK-LABEL: @quantized
-func.func @quantized(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>, %arg2: tensor<4xf32>) -> (tensor<4xf32>) {
- %0 = "quant.region"(%arg0, %arg1, %arg2) ({
- ^bb0(%10: tensor<4xf32>, %11: tensor<4xf32>, %12: tensor<4xf32>):
- %13 = "foo"(%10, %11) : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
- %14 = "bar"(%13, %12) : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
- "quant.return"(%14) : (tensor<4xf32>) -> ()
- }) {input_specs = [!quant.uniform<i8:f32, 1.0>, !quant.uniform<i8:f32, 2.0>, !quant.uniform<i32:f32, 2.0>],
- output_specs = [!quant.uniform<i8:f32, 4.0>], logical_kernel = "xyz"}
- : (tensor<4xf32>, tensor<4xf32>, tensor<4xf32>) -> (tensor<4xf32>)
- return %0 : tensor<4xf32>
-}
-
-// -----
-
-func.func @unmatched_quantize(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>, %arg2: tensor<4xf32>) -> (tensor<4xf32>) {
- // @expected-error @+1 {{'quant.region' op has incompatible specification !quant.uniform<i32:f16, 3.000000e+00> and input type 'tensor<4xf32>'}}
- %0 = "quant.region"(%arg0, %arg1, %arg2) ({
- ^bb0(%10: tensor<4xf32>, %11: tensor<4xf32>, %12: tensor<4xf32>):
- %13 = "foo"(%10, %11) : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
- %14 = "bar"(%13, %12) : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
- "quant.return"(%14) : (tensor<4xf32>) -> ()
- }) {input_specs = [!quant.uniform<i8:f32, 1.0>, !quant.uniform<i8:f32, 2.0>, !quant.uniform<i32:f16, 3.0>],
- output_specs = [!quant.uniform<i8:f32, 4.0>], logical_kernel = "xyz"}
- : (tensor<4xf32>, tensor<4xf32>, tensor<4xf32>) -> (tensor<4xf32>)
- return %0 : tensor<4xf32>
-}
-
-// -----
-
-func.func @unmatched_primitive(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>, %arg2: tensor<4xf32>) -> (tensor<4xf32>) {
- // @expected-error @+1 {{'quant.region' op has incompatible specification i32 and input type 'tensor<4xf32>'}}
- %0 = "quant.region"(%arg0, %arg1, %arg2) ({
- ^bb0(%10: tensor<4xf32>, %11: tensor<4xf32>, %12: tensor<4xf32>):
- %13 = "foo"(%10, %11) : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
- %14 = "bar"(%13, %12) : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
- "quant.return"(%14) : (tensor<4xf32>) -> ()
- }) {input_specs = [!quant.uniform<i8:f32, 1.0>, !quant.uniform<i8:f32, 2.0>, i32],
- output_specs = [!quant.uniform<i8:f32, 4.0>], logical_kernel = "xyz"}
- : (tensor<4xf32>, tensor<4xf32>, tensor<4xf32>) -> (tensor<4xf32>)
- return %0 : tensor<4xf32>
-}
-
-// -----
-
-func.func @unmatched_quantize(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>, %arg2: tensor<4xf32>) -> (tensor<4xf32>) {
- // @expected-error @+1 {{'quant.region' op has incompatible specification !quant.uniform<i32:f16, 4.000000e+00> and output type 'tensor<4xf32>'}}
- %0 = "quant.region"(%arg0, %arg1, %arg2) ({
- ^bb0(%10: tensor<4xf32>, %11: tensor<4xf32>, %12: tensor<4xf32>):
- %13 = "foo"(%10, %11) : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
- %14 = "bar"(%13, %12) : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
- "quant.return"(%14) : (tensor<4xf32>) -> ()
- }) {input_specs = [!quant.uniform<i8:f32, 1.0>, !quant.uniform<i8:f32, 2.0>, !quant.uniform<i8:f32, 3.0>],
- output_specs = [!quant.uniform<i32:f16, 4.0>], logical_kernel = "xyz"}
- : (tensor<4xf32>, tensor<4xf32>, tensor<4xf32>) -> (tensor<4xf32>)
- return %0 : tensor<4xf32>
-}
-
-// -----
-
-func.func @unmatched_primitive(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>, %arg2: tensor<4xf32>) -> (tensor<4xf32>) {
- // @expected-error @+1 {{'quant.region' op has incompatible specification i32 and output type 'tensor<4xf32>'}}
- %0 = "quant.region"(%arg0, %arg1, %arg2) ({
- ^bb0(%10: tensor<4xf32>, %11: tensor<4xf32>, %12: tensor<4xf32>):
- %13 = "foo"(%10, %11) : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
- %14 = "bar"(%13, %12) : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
- "quant.return"(%14) : (tensor<4xf32>) -> ()
- }) {input_specs = [!quant.uniform<i8:f32, 1.0>, !quant.uniform<i8:f32, 2.0>, !quant.uniform<i32:f32, 2.0>],
- output_specs = [i32], logical_kernel = "xyz"}
- : (tensor<4xf32>, tensor<4xf32>, tensor<4xf32>) -> (tensor<4xf32>)
- return %0 : tensor<4xf32>
-}
-
-// -----
-
-func.func @unmatched_number(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>, %arg2: tensor<4xf32>) -> (tensor<4xf32>) {
- // @expected-error @+1 {{'quant.region' op has unmatched operands/results number and spec attributes number}}
- %0 = "quant.region"(%arg0, %arg1, %arg2) ({
- ^bb0(%10: tensor<4xf32>, %11: tensor<4xf32>, %12: tensor<4xf32>):
- %13 = "foo"(%10, %11) : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
- %14 = "bar"(%13, %12) : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
- "quant.return"(%14) : (tensor<4xf32>) -> ()
- }) {input_specs = [!quant.uniform<i8:f32, 1.0>, !quant.uniform<i8:f32, 2.0>],
- output_specs = [!quant.uniform<i8:f32, 4.0>], logical_kernel = "xyz"}
- : (tensor<4xf32>, tensor<4xf32>, tensor<4xf32>) -> (tensor<4xf32>)
- return %0 : tensor<4xf32>
-}
-
-// -----
-
-func.func @isolated(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>, %arg2: tensor<4xf32>) -> (tensor<4xf32>) {
- // @expected-note @+1 {{required by region isolation constraints}}
- %0 = "quant.region"(%arg0, %arg1) ({
- ^bb0(%10: tensor<4xf32>, %11: tensor<4xf32>):
- %13 = "foo"(%10, %11) : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
- // @expected-error @+1 {{'bar' op using value defined outside the region}}
- %14 = "bar"(%13, %arg2) : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
- "quant.return"(%14) : (tensor<4xf32>) -> ()
- }) {input_specs = [!quant.uniform<i8:f32, 1.0>, !quant.uniform<i8:f32, 2.0>],
- output_specs = [!quant.uniform<i8:f32, 4.0>], logical_kernel = "xyz"}
- : (tensor<4xf32>, tensor<4xf32>) -> (tensor<4xf32>)
- return %0 : tensor<4xf32>
-}
-
diff --git a/mlir/unittests/Dialect/CMakeLists.txt b/mlir/unittests/Dialect/CMakeLists.txt
index cdff10654635b..befbffcf07561 100644
--- a/mlir/unittests/Dialect/CMakeLists.txt
+++ b/mlir/unittests/Dialect/CMakeLists.txt
@@ -9,7 +9,6 @@ target_link_libraries(MLIRDialectTests
add_subdirectory(Affine)
add_subdirectory(LLVMIR)
add_subdirectory(MemRef)
-add_subdirectory(Quant)
add_subdirectory(SparseTensor)
add_subdirectory(SPIRV)
add_subdirectory(Transform)
diff --git a/mlir/unittests/Dialect/Quant/CMakeLists.txt b/mlir/unittests/Dialect/Quant/CMakeLists.txt
deleted file mode 100644
index 8b4e76cdd1ff2..0000000000000
--- a/mlir/unittests/Dialect/Quant/CMakeLists.txt
+++ /dev/null
@@ -1,8 +0,0 @@
-add_mlir_unittest(MLIRQuantTests
- QuantizationUtilsTest.cpp
-)
-target_link_libraries(MLIRQuantTests
- PRIVATE
- MLIRQuantDialect
- MLIRQuantUtils
- )
diff --git a/mlir/unittests/Dialect/Quant/QuantizationUtilsTest.cpp b/mlir/unittests/Dialect/Quant/QuantizationUtilsTest.cpp
deleted file mode 100644
index 0b4085911675f..0000000000000
--- a/mlir/unittests/Dialect/Quant/QuantizationUtilsTest.cpp
+++ /dev/null
@@ -1,172 +0,0 @@
-//===- QuantizationUtilsTest.cpp - unit tests for quantization utils ------===//
-//
-// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
-// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-//
-//===----------------------------------------------------------------------===//
-
-#include "mlir/Dialect/Quant/QuantOps.h"
-#include "mlir/Dialect/Quant/QuantizeUtils.h"
-#include "mlir/Dialect/Quant/UniformSupport.h"
-#include "mlir/IR/Attributes.h"
-#include "mlir/IR/BuiltinTypes.h"
-#include "gmock/gmock.h"
-#include "gtest/gtest.h"
-
-using namespace mlir;
-using namespace mlir::quant;
-
-namespace {
-
-// Test UniformQuantizedValueConverter converts all APFloat to a magic number 5.
-class TestUniformQuantizedValueConverter
- : public UniformQuantizedValueConverter {
-public:
- TestUniformQuantizedValueConverter(UniformQuantizedType type)
- : UniformQuantizedValueConverter(type), qtype(type) {}
- APInt quantizeFloatToInt(APFloat expressedValue) const override {
- return APInt(qtype.getStorageType().cast<IntegerType>().getWidth(), 5L);
- }
-
-private:
- UniformQuantizedType qtype;
-};
-
-Attribute getTestFloatAttr(double value, MLIRContext *ctx) {
- return FloatAttr::get(FloatType::getF32(ctx), value);
-}
-
-template <typename ConcreteAttrClass, typename... Arg>
-ConcreteAttrClass getTestElementsAttr(MLIRContext *ctx, ArrayRef<int64_t> shape,
- Arg... value) {
- auto eleType = FloatType::getF32(ctx);
- ShapedType tensorType;
- if (shape.size() == 1 && shape[0] == -1) {
- tensorType = UnrankedTensorType::get(eleType);
- } else {
- tensorType = RankedTensorType::get(shape, eleType);
- }
- return ConcreteAttrClass::get(tensorType, value...);
-}
-
-ElementsAttr getTestSparseElementsAttr(MLIRContext *ctx,
- ArrayRef<int64_t> shape) {
- auto eleType = FloatType::getF32(ctx);
- ShapedType tensorType;
- if (shape.size() == 1 && shape[0] == -1) {
- tensorType = UnrankedTensorType::get(eleType);
- } else {
- tensorType = RankedTensorType::get(shape, eleType);
- }
- auto indicesType = RankedTensorType::get({1, 2}, IntegerType::get(ctx, 64));
- auto indices =
- DenseIntElementsAttr::get(indicesType, {APInt(64, 0), APInt(64, 0)});
- auto valuesType = RankedTensorType::get({1}, eleType);
- auto values = DenseFPElementsAttr::get(valuesType, {APFloat(0.0f)});
- return SparseElementsAttr::get(tensorType, indices, values);
-}
-
-UniformQuantizedType getTestQuantizedType(Type storageType, MLIRContext *ctx) {
- return UniformQuantizedType::get(/*flags=*/false, storageType,
- FloatType::getF32(ctx), /*scale=*/1.0,
- /*zeroPoint=*/0, /*storageTypeMin=*/0,
- /*storageTypeMax=*/255);
-}
-
-TEST(QuantizationUtilsTest, convertFloatAttrUniform) {
- MLIRContext ctx;
- ctx.getOrLoadDialect<QuantizationDialect>();
- IntegerType convertedType = IntegerType::get(&ctx, 8);
- auto quantizedType = getTestQuantizedType(convertedType, &ctx);
- TestUniformQuantizedValueConverter converter(quantizedType);
-
- auto realValue = getTestFloatAttr(1.0, &ctx);
- Type typeResult;
- auto valueResult =
- quantizeAttrUniform(realValue, quantizedType, converter, typeResult);
-
- EXPECT_EQ(valueResult.cast<IntegerAttr>().getInt(), 5);
- EXPECT_EQ(
- valueResult.cast<IntegerAttr>().getType().cast<IntegerType>().getWidth(),
- convertedType.getWidth());
-}
-
-TEST(QuantizationUtilsTest, convertRankedDenseAttrUniform) {
- MLIRContext ctx;
- ctx.getOrLoadDialect<QuantizationDialect>();
- IntegerType convertedType = IntegerType::get(&ctx, 8);
- auto quantizedType = getTestQuantizedType(convertedType, &ctx);
- TestUniformQuantizedValueConverter converter(quantizedType);
- auto realValue = getTestElementsAttr<DenseElementsAttr, ArrayRef<Attribute>>(
- &ctx, {1, 2}, {getTestFloatAttr(1.0, &ctx), getTestFloatAttr(2.0, &ctx)});
-
- Type returnedType;
- auto returnedValue =
- quantizeAttrUniform(realValue, quantizedType, converter, returnedType);
-
- // Check Elements attribute shape and kind are not changed.
- auto tensorType = returnedType.cast<TensorType>();
- auto expectedTensorType = realValue.getType().cast<TensorType>();
- EXPECT_EQ(tensorType.getShape(), expectedTensorType.getShape());
- EXPECT_EQ(tensorType.getElementType(), convertedType);
- EXPECT_TRUE(returnedValue.isa<DenseIntElementsAttr>());
-
- // Check Elements attribute element value is expected.
- auto firstValue =
- returnedValue.cast<ElementsAttr>().getValues<Attribute>()[{0, 0}];
- EXPECT_EQ(firstValue.cast<IntegerAttr>().getInt(), 5);
-}
-
-TEST(QuantizationUtilsTest, convertRankedSplatAttrUniform) {
- MLIRContext ctx;
- ctx.getOrLoadDialect<QuantizationDialect>();
- IntegerType convertedType = IntegerType::get(&ctx, 8);
- auto quantizedType = getTestQuantizedType(convertedType, &ctx);
- TestUniformQuantizedValueConverter converter(quantizedType);
- auto realValue = getTestElementsAttr<DenseElementsAttr, Attribute>(
- &ctx, {1, 2}, getTestFloatAttr(1.0, &ctx));
-
- Type returnedType;
- auto returnedValue =
- quantizeAttrUniform(realValue, quantizedType, converter, returnedType);
-
- // Check Elements attribute shape and kind are not changed.
- auto tensorType = returnedType.cast<TensorType>();
- auto expectedTensorType = realValue.getType().cast<TensorType>();
- EXPECT_EQ(tensorType.getShape(), expectedTensorType.getShape());
- EXPECT_EQ(tensorType.getElementType(), convertedType);
- EXPECT_TRUE(returnedValue.isa<SplatElementsAttr>());
-
- // Check Elements attribute element value is expected.
- auto firstValue =
- returnedValue.cast<ElementsAttr>().getValues<Attribute>()[{0, 0}];
- EXPECT_EQ(firstValue.cast<IntegerAttr>().getInt(), 5);
-}
-
-TEST(QuantizationUtilsTest, convertRankedSparseAttrUniform) {
- MLIRContext ctx;
- ctx.getOrLoadDialect<QuantizationDialect>();
- IntegerType convertedType = IntegerType::get(&ctx, 8);
- auto quantizedType = getTestQuantizedType(convertedType, &ctx);
- TestUniformQuantizedValueConverter converter(quantizedType);
- auto realValue = getTestSparseElementsAttr(&ctx, {1, 2});
-
- Type returnedType;
- auto returnedValue =
- quantizeAttrUniform(realValue, quantizedType, converter, returnedType);
-
- // Check Elements attribute shape and kind are not changed.
- auto tensorType = returnedType.cast<TensorType>();
- auto expectedTensorType = realValue.getType().cast<TensorType>();
- EXPECT_EQ(tensorType.getShape(), expectedTensorType.getShape());
- EXPECT_EQ(tensorType.getElementType(), convertedType);
- EXPECT_TRUE(returnedValue.isa<SparseElementsAttr>());
-
- // Check Elements attribute element value is expected.
- auto firstValue =
- returnedValue.cast<ElementsAttr>().getValues<Attribute>()[{0, 0}];
- EXPECT_EQ(firstValue.cast<IntegerAttr>().getInt(), 5);
-}
-
-} // namespace
diff --git a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
index fd56236ec9590..4fcc75f1547ea 100644
--- a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
+++ b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
@@ -6264,7 +6264,6 @@ cc_library(
":PDLInterpDialect",
":PDLToPDLInterp",
":QuantOps",
- ":QuantPassIncGen",
":ROCDLDialect",
":ReconcileUnrealizedCasts",
":SCFDialect",
@@ -7003,23 +7002,6 @@ gentbl_cc_library(
deps = [":QuantizationOpsTdFiles"],
)
-gentbl_cc_library(
- name = "QuantPassIncGen",
- strip_include_prefix = "include",
- tbl_outs = [
- (
- [
- "-gen-pass-decls",
- "-name=Quant",
- ],
- "include/mlir/Dialect/Quant/Passes.h.inc",
- ),
- ],
- tblgen = ":mlir-tblgen",
- td_file = "include/mlir/Dialect/Quant/Passes.td",
- deps = [":PassBaseTdFiles"],
-)
-
cc_library(
name = "QuantOps",
srcs = [
@@ -7027,11 +7009,7 @@ cc_library(
"lib/Dialect/Quant/IR/QuantTypes.cpp",
"lib/Dialect/Quant/IR/TypeDetail.h",
"lib/Dialect/Quant/IR/TypeParser.cpp",
- "lib/Dialect/Quant/Transforms/ConvertConst.cpp",
- "lib/Dialect/Quant/Transforms/ConvertSimQuant.cpp",
- "lib/Dialect/Quant/Transforms/PassDetail.h",
"lib/Dialect/Quant/Utils/FakeQuantSupport.cpp",
- "lib/Dialect/Quant/Utils/QuantizeUtils.cpp",
"lib/Dialect/Quant/Utils/UniformSupport.cpp",
],
hdrs = [
@@ -7039,7 +7017,6 @@ cc_library(
"include/mlir/Dialect/Quant/Passes.h",
"include/mlir/Dialect/Quant/QuantOps.h",
"include/mlir/Dialect/Quant/QuantTypes.h",
- "include/mlir/Dialect/Quant/QuantizeUtils.h",
"include/mlir/Dialect/Quant/UniformSupport.h",
],
includes = ["include"],
@@ -7050,7 +7027,6 @@ cc_library(
":InferTypeOpInterface",
":Pass",
":QuantOpsIncGen",
- ":QuantPassIncGen",
":SideEffectInterfaces",
":TransformUtils",
"//llvm:Support",
More information about the Mlir-commits
mailing list