[Mlir-commits] [mlir] 166f83f - [QuantOps] Add the quant region definition
Feng Liu
llvmlistbot at llvm.org
Mon Mar 16 15:49:13 PDT 2020
Author: Feng Liu
Date: 2020-03-16T15:44:43-07:00
New Revision: 166f83f436608d599f05f0c3d4eb7b5920c0d2e6
URL: https://github.com/llvm/llvm-project/commit/166f83f436608d599f05f0c3d4eb7b5920c0d2e6
DIFF: https://github.com/llvm/llvm-project/commit/166f83f436608d599f05f0c3d4eb7b5920c0d2e6.diff
LOG: [QuantOps] Add the quant region definition
Summary:
This region op in the QuantOps dialect will be used to wrap
high-precision ops into atomic units for quantization. All values
used by the internal ops are captured explicitly as op inputs. The
quantization parameters of the inputs and outputs are stored in
attributes.
Subscribers: jfb, mehdi_amini, rriddle, jpienaar, burmako, shauheen, antiagainst, nicolasvasilache, arpith-jacob, mgester, lucyrfox, aartbik, Joonsoo, llvm-commits
Tags: #llvm
Differential Revision: https://reviews.llvm.org/D75972
Added:
mlir/test/Dialect/QuantOps/quant_region.mlir
Modified:
mlir/include/mlir/Dialect/QuantOps/QuantOps.td
mlir/lib/Dialect/QuantOps/IR/QuantOps.cpp
mlir/lib/Quantizer/Configurations/FxpMathConfig.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/QuantOps/QuantOps.td b/mlir/include/mlir/Dialect/QuantOps/QuantOps.td
index 92e1e1d813ed..0047b41efd60 100644
--- a/mlir/include/mlir/Dialect/QuantOps/QuantOps.td
+++ b/mlir/include/mlir/Dialect/QuantOps/QuantOps.td
@@ -83,6 +83,36 @@ def quant_StorageCastOp : quant_Op<"scast", [NoSideEffect]> {
let hasFolder = 1;
}
+// A QuantizeRegion (region) represents a quantization unit which wraps
+// high-precision ops with quantization specifications for all the inputs
+// and outputs. Some quantization specifications may be left undetermined and
+// derived from other ports according to the target specification of the
+// kernel.
+def quant_QuantizeRegionOp : quant_Op<"region", [
+    NoSideEffect,
+    IsolatedFromAbove,
+    SingleBlockImplicitTerminator<"ReturnOp">]> {
+  let summary = [{
+    The `region` operation wraps high-precision ops as a logical low-precision
+    quantized kernel.
+  }];
+
+  let description = [{
+    The single-block body is isolated from above: every value it uses must be
+    passed in through `inputs`, and the body is terminated by `quant.return`.
+    `input_specs` and `output_specs` hold one quantization specification (a
+    type) per input and per output respectively. `logical_kernel` names the
+    logical low-precision kernel that the region stands for.
+  }];
+
+  let arguments = (ins Variadic<AnyType>:$inputs,
+                   TypeArrayAttr:$input_specs,
+                   TypeArrayAttr:$output_specs,
+                   StrAttr:$logical_kernel);
+  let results = (outs Variadic<AnyType>:$outputs);
+  let regions = (region SizedRegion<1>:$body);
+  let verifier = [{ return verifyRegionOp(*this); }];
+}
+
+def quant_ReturnOp : quant_Op<"return", [
+    HasParent<"QuantizeRegionOp">,
+    Terminator]> {
+  let summary = [{
+    The `return` operation terminates a quantize region and returns values.
+  }];
+
+  // The returned values become the results of the enclosing `quant.region`.
+  // Those results are `Variadic<AnyType>`, so the terminator must accept any
+  // type as well; restricting it to tensors made non-tensor outputs
+  // impossible to return.
+  let arguments = (ins Variadic<AnyType>:$results);
+}
+
//===----------------------------------------------------------------------===//
// Training integration and instrumentation ops
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/QuantOps/IR/QuantOps.cpp b/mlir/lib/Dialect/QuantOps/IR/QuantOps.cpp
index 9a678260415a..f87330cff016 100644
--- a/mlir/lib/Dialect/QuantOps/IR/QuantOps.cpp
+++ b/mlir/lib/Dialect/QuantOps/IR/QuantOps.cpp
@@ -34,13 +34,63 @@ QuantizationDialect::QuantizationDialect(MLIRContext *context)
}
OpFoldResult StorageCastOp::fold(ArrayRef<Attribute> operands) {
- /// Matches x -> [scast -> scast] -> y, replacing the second scast with the
- /// value of x if the casts invert each other.
+ // Matches x -> [scast -> scast] -> y, replacing the second scast with the
+ // value of x if the casts invert each other, i.e. when the operand of the
+ // first scast already has this scast's result type. Otherwise no fold.
auto srcScastOp = dyn_cast_or_null<StorageCastOp>(arg().getDefiningOp());
if (!srcScastOp || srcScastOp.arg().getType() != getType())
return OpFoldResult();
return srcScastOp.arg();
}
+/// The quantization specification should match the expressed type.
+static bool isValidQuantizationSpec(Attribute quantSpec, Type expressed) {
+ if (auto typeAttr = quantSpec.dyn_cast<TypeAttr>()) {
+ Type spec = typeAttr.getValue();
+ if (spec.isa<TensorType>() || spec.isa<VectorType>())
+ return false;
+
+ // The spec should be either a quantized type which is compatible to the
+ // expressed type, or a primitive type which is as same as the
+ // (element type of) the expressed type.
+ if (auto quantizedType = spec.dyn_cast<QuantizedType>())
+ return quantizedType.isCompatibleExpressedType(expressed);
+
+ if (auto tensorType = expressed.dyn_cast<TensorType>())
+ return spec == tensorType.getElementType();
+
+ if (auto vectorType = expressed.dyn_cast<VectorType>())
+ return spec == vectorType.getElementType();
+ }
+ return false;
+}
+
+/// Verifies the invariants of a `quant.region` op that ODS does not express:
+///   * there is exactly one spec per operand (`input_specs`) and one per
+///     result (`output_specs`);
+///   * every spec is valid for the type of the value it annotates;
+///   * the body block takes one argument per operand, of the same type,
+///     because every value used inside the region is captured via the inputs;
+///   * the terminator returns one value per result, of the same type.
+static LogicalResult verifyRegionOp(QuantizeRegionOp op) {
+  // There are specifications for both inputs and outputs.
+  if (op.getNumOperands() != op.input_specs().size() ||
+      op.getNumResults() != op.output_specs().size())
+    return op.emitOpError(
+        "has unmatched operands/results number and spec attributes number");
+
+  // Verify that quantization specifications are valid.
+  for (auto input : llvm::zip(op.getOperandTypes(), op.input_specs())) {
+    Type inputType = std::get<0>(input);
+    Attribute inputSpec = std::get<1>(input);
+    if (!isValidQuantizationSpec(inputSpec, inputType)) {
+      return op.emitOpError() << "has incompatible specification " << inputSpec
+                              << " and input type " << inputType;
+    }
+  }
+
+  for (auto result : llvm::zip(op.getResultTypes(), op.output_specs())) {
+    Type outputType = std::get<0>(result);
+    Attribute outputSpec = std::get<1>(result);
+    if (!isValidQuantizationSpec(outputSpec, outputType)) {
+      return op.emitOpError() << "has incompatible specification " << outputSpec
+                              << " and output type " << outputType;
+    }
+  }
+
+  // The body has exactly one block (SizedRegion<1>). Each input must reach
+  // it through a block argument of the same type.
+  Block &body = op.body().front();
+  if (body.getNumArguments() != op.getNumOperands())
+    return op.emitOpError(
+        "has unmatched operands number and region arguments number");
+  for (auto input : llvm::zip(op.getOperandTypes(), body.getArguments())) {
+    Type inputType = std::get<0>(input);
+    Type argType = std::get<1>(input).getType();
+    if (inputType != argType)
+      return op.emitOpError() << "has mismatched input type " << inputType
+                              << " and region argument type " << argType;
+  }
+
+  // The block ends in `quant.return` (required by the
+  // SingleBlockImplicitTerminator<"ReturnOp"> trait). Its operands become
+  // the op's results.
+  Operation *terminator = body.getTerminator();
+  if (terminator->getNumOperands() != op.getNumResults())
+    return op.emitOpError(
+        "has unmatched results number and region return values number");
+  for (auto result :
+       llvm::zip(op.getResultTypes(), terminator->getOperandTypes())) {
+    Type outputType = std::get<0>(result);
+    Type returnType = std::get<1>(result);
+    if (outputType != returnType)
+      return op.emitOpError() << "has mismatched output type " << outputType
+                              << " and region return type " << returnType;
+  }
+  return success();
+}
+
#define GET_OP_CLASSES
#include "mlir/Dialect/QuantOps/QuantOps.cpp.inc"
diff --git a/mlir/lib/Quantizer/Configurations/FxpMathConfig.cpp b/mlir/lib/Quantizer/Configurations/FxpMathConfig.cpp
index 1dc9a0596a8b..d4b3b7404773 100644
--- a/mlir/lib/Quantizer/Configurations/FxpMathConfig.cpp
+++ b/mlir/lib/Quantizer/Configurations/FxpMathConfig.cpp
@@ -60,7 +60,7 @@ struct FxpMathTargetConfigImpl : public FxpMathTargetConfig {
// Op handlers.
addOpHandler<ConstantOp>(
std::bind(&FxpMathTargetConfigImpl::handleConstant, this, _1, _2));
- addOpHandler<ReturnOp>(
+ addOpHandler<mlir::ReturnOp>(
std::bind(&FxpMathTargetConfigImpl::handleTerminal, this, _1, _2));
addOpHandler<quant::StatisticsOp>(
std::bind(&FxpMathTargetConfigImpl::handleStats, this, _1, _2));
diff --git a/mlir/test/Dialect/QuantOps/quant_region.mlir b/mlir/test/Dialect/QuantOps/quant_region.mlir
new file mode 100644
index 000000000000..ee874211a7ac
--- /dev/null
+++ b/mlir/test/Dialect/QuantOps/quant_region.mlir
@@ -0,0 +1,101 @@
+// RUN: mlir-opt -split-input-file -verify-diagnostics %s | FileCheck %s
+
+// Valid: every spec is a primitive type equal to the element type (f32) of
+// the corresponding tensor<4xf32> input or output.
+// CHECK-LABEL: @source
+func @source(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>, %arg2: tensor<4xf32>) -> (tensor<4xf32>) {
+ %0 = "quant.region"(%arg0, %arg1, %arg2) ({
+ ^bb0(%10: tensor<4xf32>, %11: tensor<4xf32>, %12: tensor<4xf32>):
+ %13 = "foo"(%10, %11) : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
+ %14 = "bar"(%13, %12) : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
+ "quant.return"(%14) : (tensor<4xf32>) -> ()
+ }) {input_specs = [f32, f32, f32], output_specs = [f32], logical_kernel = "xyz"}
+ : (tensor<4xf32>, tensor<4xf32>, tensor<4xf32>) -> (tensor<4xf32>)
+ return %0 : tensor<4xf32>
+}
+
+// Valid: quantized specs whose expressed type is f32 may be mixed with a
+// primitive f32 spec (the last input is left unquantized).
+// CHECK-LABEL: @annotated
+func @annotated(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>, %arg2: tensor<4xf32>) -> (tensor<4xf32>) {
+ %0 = "quant.region"(%arg0, %arg1, %arg2) ({
+ ^bb0(%10: tensor<4xf32>, %11: tensor<4xf32>, %12: tensor<4xf32>):
+ %13 = "foo"(%10, %11) : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
+ %14 = "bar"(%13, %12) : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
+ "quant.return"(%14) : (tensor<4xf32>) -> ()
+ }) {input_specs = [!quant.uniform<i8:f32, 1.0>, !quant.uniform<i8:f32, 2.0>, f32],
+ output_specs = [!quant.uniform<i8:f32, 4.0>], logical_kernel = "xyz"}
+ : (tensor<4xf32>, tensor<4xf32>, tensor<4xf32>) -> (tensor<4xf32>)
+ return %0 : tensor<4xf32>
+}
+
+// Valid: every spec is quantized. Storage types may differ (i8 vs. i32) as
+// long as each expressed type is f32.
+// CHECK-LABEL: @quantized
+func @quantized(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>, %arg2: tensor<4xf32>) -> (tensor<4xf32>) {
+ %0 = "quant.region"(%arg0, %arg1, %arg2) ({
+ ^bb0(%10: tensor<4xf32>, %11: tensor<4xf32>, %12: tensor<4xf32>):
+ %13 = "foo"(%10, %11) : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
+ %14 = "bar"(%13, %12) : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
+ "quant.return"(%14) : (tensor<4xf32>) -> ()
+ }) {input_specs = [!quant.uniform<i8:f32, 1.0>, !quant.uniform<i8:f32, 2.0>, !quant.uniform<i32:f32, 2.0>],
+ output_specs = [!quant.uniform<i8:f32, 4.0>], logical_kernel = "xyz"}
+ : (tensor<4xf32>, tensor<4xf32>, tensor<4xf32>) -> (tensor<4xf32>)
+ return %0 : tensor<4xf32>
+}
+
+// -----
+
+// Invalid: the third spec's expressed type (f16) does not match the element
+// type (f32) of its tensor<4xf32> input.
+func @unmatched_quantize(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>, %arg2: tensor<4xf32>) -> (tensor<4xf32>) {
+  // expected-error @+1 {{'quant.region' op has incompatible specification !quant.uniform<i32:f16, 3.000000e+00> and input type 'tensor<4xf32>'}}
+  %0 = "quant.region"(%arg0, %arg1, %arg2) ({
+  ^bb0(%10: tensor<4xf32>, %11: tensor<4xf32>, %12: tensor<4xf32>):
+    %13 = "foo"(%10, %11) : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
+    %14 = "bar"(%13, %12) : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
+    "quant.return"(%14) : (tensor<4xf32>) -> ()
+  }) {input_specs = [!quant.uniform<i8:f32, 1.0>, !quant.uniform<i8:f32, 2.0>, !quant.uniform<i32:f16, 3.0>],
+      output_specs = [!quant.uniform<i8:f32, 4.0>], logical_kernel = "xyz"}
+    : (tensor<4xf32>, tensor<4xf32>, tensor<4xf32>) -> (tensor<4xf32>)
+  return %0 : tensor<4xf32>
+}
+
+// -----
+
+// Invalid: a primitive spec (i32) must equal the element type (f32) of its
+// tensor<4xf32> input.
+func @unmatched_primitive(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>, %arg2: tensor<4xf32>) -> (tensor<4xf32>) {
+  // expected-error @+1 {{'quant.region' op has incompatible specification i32 and input type 'tensor<4xf32>'}}
+  %0 = "quant.region"(%arg0, %arg1, %arg2) ({
+  ^bb0(%10: tensor<4xf32>, %11: tensor<4xf32>, %12: tensor<4xf32>):
+    %13 = "foo"(%10, %11) : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
+    %14 = "bar"(%13, %12) : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
+    "quant.return"(%14) : (tensor<4xf32>) -> ()
+  }) {input_specs = [!quant.uniform<i8:f32, 1.0>, !quant.uniform<i8:f32, 2.0>, i32],
+      output_specs = [!quant.uniform<i8:f32, 4.0>], logical_kernel = "xyz"}
+    : (tensor<4xf32>, tensor<4xf32>, tensor<4xf32>) -> (tensor<4xf32>)
+  return %0 : tensor<4xf32>
+}
+
+// -----
+
+// Invalid: the region has three inputs but only two input specs.
+func @unmatched_number(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>, %arg2: tensor<4xf32>) -> (tensor<4xf32>) {
+  // expected-error @+1 {{'quant.region' op has unmatched operands/results number and spec attributes number}}
+  %0 = "quant.region"(%arg0, %arg1, %arg2) ({
+  ^bb0(%10: tensor<4xf32>, %11: tensor<4xf32>, %12: tensor<4xf32>):
+    %13 = "foo"(%10, %11) : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
+    %14 = "bar"(%13, %12) : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
+    "quant.return"(%14) : (tensor<4xf32>) -> ()
+  }) {input_specs = [!quant.uniform<i8:f32, 1.0>, !quant.uniform<i8:f32, 2.0>],
+      output_specs = [!quant.uniform<i8:f32, 4.0>], logical_kernel = "xyz"}
+    : (tensor<4xf32>, tensor<4xf32>, tensor<4xf32>) -> (tensor<4xf32>)
+  return %0 : tensor<4xf32>
+}
+
+// -----
+
+// Invalid: the body uses %arg2, which is defined outside the isolated region
+// instead of being passed in as an input.
+func @isolated(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>, %arg2: tensor<4xf32>) -> (tensor<4xf32>) {
+  // expected-note @+1 {{required by region isolation constraints}}
+  %0 = "quant.region"(%arg0, %arg1) ({
+  ^bb0(%10: tensor<4xf32>, %11: tensor<4xf32>):
+    %13 = "foo"(%10, %11) : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
+    // expected-error @+1 {{'bar' op using value defined outside the region}}
+    %14 = "bar"(%13, %arg2) : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
+    "quant.return"(%14) : (tensor<4xf32>) -> ()
+  }) {input_specs = [!quant.uniform<i8:f32, 1.0>, !quant.uniform<i8:f32, 2.0>],
+      output_specs = [!quant.uniform<i8:f32, 4.0>], logical_kernel = "xyz"}
+    : (tensor<4xf32>, tensor<4xf32>) -> (tensor<4xf32>)
+  return %0 : tensor<4xf32>
+}
+
More information about the Mlir-commits
mailing list