[Mlir-commits] [mlir] 0312b25 - [mlir][tosa] Add tosa.table lowering to linalg.generic
Rob Suderman
llvmlistbot at llvm.org
Tue Apr 6 13:58:30 PDT 2021
Author: Rob Suderman
Date: 2021-04-06T13:57:18-07:00
New Revision: 0312b25df0a872295f8db203fbebfb4a0d7f0f3e
URL: https://github.com/llvm/llvm-project/commit/0312b25df0a872295f8db203fbebfb4a0d7f0f3e
DIFF: https://github.com/llvm/llvm-project/commit/0312b25df0a872295f8db203fbebfb4a0d7f0f3e.diff
LOG: [mlir][tosa] Add tosa.table lowering to linalg.generic
Table op lowering to linalg.generic for both i8 (behaves like a single gather)
and i16 (a pair of gathers with a quantized interpolation).
Differential Revision: https://reviews.llvm.org/D99756
Added:
Modified:
mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
index 3fdbc6f89733b..a6271f7097563 100644
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
@@ -1407,37 +1407,178 @@ class ArgMaxConverter : public OpRewritePattern<tosa::ArgMaxOp> {
}
};
+// Lowers the TableOp to a linalg.generic whose body gathers from the table and
+// performs integer arithmetic on the gathered values:
+//  * i8 input/table/result: a single gather, result = table[value + 128].
+//  * i16 input/table with i32 result: two gathers (table[index] and
+//    table[index + 1]) followed by a fixed-point interpolation between them.
+// The element-type combination is validated before any IR is created, so an
+// unsupported op is left untouched when the pattern reports a match failure.
+class TableConverter : public OpRewritePattern<tosa::TableOp> {
+public:
+  using OpRewritePattern<tosa::TableOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(tosa::TableOp op,
+                                PatternRewriter &rewriter) const final {
+    auto loc = op.getLoc();
+    Value input = op.input();
+    Value table = op.table();
+    auto inputTy = input.getType().cast<ShapedType>();
+    auto tableTy = table.getType().cast<ShapedType>();
+    auto resultTy = op.getType().cast<ShapedType>();
+
+    if (!inputTy.hasStaticShape())
+      return rewriter.notifyMatchFailure(
+          op, "require input type to have static shape");
+
+    auto inputElementTy = inputTy.getElementType();
+    auto tableElementTy = tableTy.getElementType();
+    auto resultElementTy = resultTy.getElementType();
+
+    // Decide which lowering applies before touching the IR: a pattern that
+    // returns failure must not have created or replaced any operations.
+    bool isI8Lookup = inputElementTy.isInteger(8) &&
+                      tableElementTy.isInteger(8) &&
+                      resultElementTy.isInteger(8);
+    bool isI16Lookup = inputElementTy.isInteger(16) &&
+                       tableElementTy.isInteger(16) &&
+                       resultElementTy.isInteger(32);
+    if (!isI8Lookup && !isI16Lookup)
+      return rewriter.notifyMatchFailure(
+          op, "unable to create body for tosa.table op");
+
+    auto initTensor =
+        rewriter
+            .create<linalg::InitTensorOp>(loc, ArrayRef<Value>{},
+                                          resultTy.getShape(), resultElementTy)
+            .result();
+
+    SmallVector<AffineMap, 2> affineMaps = {
+        rewriter.getMultiDimIdentityMap(resultTy.getRank()),
+        rewriter.getMultiDimIdentityMap(resultTy.getRank())};
+
+    auto genericOp = rewriter.create<linalg::GenericOp>(
+        loc, resultTy, ValueRange({input}), ValueRange{initTensor}, affineMaps,
+        getNParallelLoopsAttrs(resultTy.getRank()));
+
+    {
+      OpBuilder::InsertionGuard regionGuard(rewriter);
+      Block *block =
+          rewriter.createBlock(&genericOp.region(), genericOp.region().end(),
+                               TypeRange({inputElementTy, resultElementTy}));
+
+      auto inputValue = block->getArgument(0);
+      rewriter.setInsertionPointToStart(block);
+
+      if (isI8Lookup) {
+        // The input is a signed i8 in [-128, 127]; index_cast sign-extends, so
+        // add 128 to obtain a zero-based index into the table.
+        Value index = rewriter.create<IndexCastOp>(
+            loc, rewriter.getIndexType(), inputValue);
+        Value offset = rewriter.create<ConstantIndexOp>(loc, 128);
+        index = rewriter.create<AddIOp>(loc, index, offset);
+        Value extract =
+            rewriter.create<tensor::ExtractOp>(loc, table, ValueRange{index});
+        rewriter.create<linalg::YieldOp>(loc, extract);
+      } else {
+        // i16 input / i16 table / i32 result (guaranteed by the check above).
+        // Widen to i32 so the offset and interpolation arithmetic fit.
+        Value extend = rewriter.create<SignExtendIOp>(
+            loc, rewriter.getI32Type(), inputValue);
+
+        auto offset =
+            rewriter.create<ConstantOp>(loc, rewriter.getI32IntegerAttr(32768));
+        auto seven =
+            rewriter.create<ConstantOp>(loc, rewriter.getI32IntegerAttr(7));
+        auto one =
+            rewriter.create<ConstantOp>(loc, rewriter.getI32IntegerAttr(1));
+        auto fractionMask =
+            rewriter.create<ConstantOp>(loc, rewriter.getI32IntegerAttr(127));
+
+        // Compute the index and fractional part from the input value:
+        //   value    = value + 32768   (now in [0, 65535])
+        //   index    = value >> 7      (in [0, 511])
+        //   fraction = value & 0x7F    (low 7 bits)
+        auto extendAdd = rewriter.create<AddIOp>(loc, extend, offset);
+        Value index =
+            rewriter.create<UnsignedShiftRightOp>(loc, extendAdd, seven);
+        Value fraction =
+            rewriter.create<mlir::AndOp>(loc, extendAdd, fractionMask);
+
+        // Extract the base and next values from the table:
+        //   base = (int32_t) table[index];
+        //   next = (int32_t) table[index + 1];
+        Value indexPlusOne = rewriter.create<AddIOp>(loc, index, one);
+
+        index =
+            rewriter.create<IndexCastOp>(loc, rewriter.getIndexType(), index);
+        indexPlusOne = rewriter.create<IndexCastOp>(
+            loc, rewriter.getIndexType(), indexPlusOne);
+
+        Value base =
+            rewriter.create<tensor::ExtractOp>(loc, table, ValueRange{index});
+        Value next = rewriter.create<tensor::ExtractOp>(
+            loc, table, ValueRange{indexPlusOne});
+
+        base = rewriter.create<SignExtendIOp>(loc, rewriter.getI32Type(), base);
+        next = rewriter.create<SignExtendIOp>(loc, rewriter.getI32Type(), next);
+
+        // Use the fractional part to interpolate between the table values:
+        //   result = (base << 7) + (next - base) * fraction
+        Value baseScaled = rewriter.create<ShiftLeftOp>(loc, base, seven);
+        Value diff = rewriter.create<SubIOp>(loc, next, base);
+        Value diffScaled = rewriter.create<MulIOp>(loc, diff, fraction);
+        Value result = rewriter.create<AddIOp>(loc, baseScaled, diffScaled);
+
+        rewriter.create<linalg::YieldOp>(loc, result);
+      }
+    }
+
+    // Replace only once the generic op has a complete body.
+    rewriter.replaceOp(op, genericOp.getResult(0));
+    return success();
+  }
+};
+
} // namespace
void mlir::tosa::populateTosaToLinalgOnTensorsConversionPatterns(
RewritePatternSet *patterns) {
patterns->add<
- PointwiseConverter<tosa::AddOp>, PointwiseConverter<tosa::SubOp>,
- PointwiseConverter<tosa::MulOp>, PointwiseConverter<tosa::ReciprocalOp>,
- PointwiseConverter<tosa::NegateOp>, PointwiseConverter<tosa::PowOp>,
- PointwiseConverter<tosa::RsqrtOp>, PointwiseConverter<tosa::LogOp>,
- PointwiseConverter<tosa::ExpOp>, PointwiseConverter<tosa::AbsOp>,
- PointwiseConverter<tosa::TanhOp>, PointwiseConverter<tosa::BitwiseAndOp>,
+ // clang-format off
+ PointwiseConverter<tosa::AddOp>,
+ PointwiseConverter<tosa::SubOp>,
+ PointwiseConverter<tosa::MulOp>,
+ PointwiseConverter<tosa::NegateOp>,
+ PointwiseConverter<tosa::PowOp>,
+ PointwiseConverter<tosa::ReciprocalOp>,
+ PointwiseConverter<tosa::RsqrtOp>,
+ PointwiseConverter<tosa::LogOp>,
+ PointwiseConverter<tosa::ExpOp>,
+ PointwiseConverter<tosa::AbsOp>,
+ PointwiseConverter<tosa::TanhOp>,
+ PointwiseConverter<tosa::BitwiseAndOp>,
PointwiseConverter<tosa::BitwiseOrOp>,
PointwiseConverter<tosa::BitwiseNotOp>,
PointwiseConverter<tosa::BitwiseXorOp>,
PointwiseConverter<tosa::LogicalAndOp>,
PointwiseConverter<tosa::LogicalNotOp>,
PointwiseConverter<tosa::LogicalOrOp>,
- PointwiseConverter<tosa::LogicalXorOp>, PointwiseConverter<tosa::CastOp>,
+ PointwiseConverter<tosa::LogicalXorOp>,
+ PointwiseConverter<tosa::CastOp>,
PointwiseConverter<tosa::LogicalLeftShiftOp>,
PointwiseConverter<tosa::LogicalRightShiftOp>,
- PointwiseConverter<tosa::SelectOp>, PointwiseConverter<tosa::GreaterOp>,
+ PointwiseConverter<tosa::SelectOp>,
+ PointwiseConverter<tosa::GreaterOp>,
PointwiseConverter<tosa::GreaterEqualOp>,
- PointwiseConverter<tosa::MaximumOp>, PointwiseConverter<tosa::MinimumOp>,
- PointwiseConverter<tosa::CeilOp>, PointwiseConverter<tosa::FloorOp>,
- PointwiseConverter<tosa::ClampOp>, PointwiseConverter<tosa::ReluNOp>,
- PointwiseConverter<tosa::SigmoidOp>, IdentityNConverter<tosa::IdentityOp>,
- IdentityNConverter<tosa::IdentityNOp>, ReduceConverter<tosa::ReduceAllOp>,
- ReduceConverter<tosa::ReduceAnyOp>, ReduceConverter<tosa::ReduceMinOp>,
- ReduceConverter<tosa::ReduceMaxOp>, ReduceConverter<tosa::ReduceSumOp>,
- ReduceConverter<tosa::ReduceProdOp>, ArgMaxConverter, ConcatConverter,
- PadConverter, ReshapeConverter, RescaleConverter, ReverseConverter,
- TileConverter, TransposeConverter, MatMulConverter,
+ PointwiseConverter<tosa::MaximumOp>,
+ PointwiseConverter<tosa::MinimumOp>,
+ PointwiseConverter<tosa::CeilOp>,
+ PointwiseConverter<tosa::FloorOp>,
+ PointwiseConverter<tosa::ClampOp>,
+ PointwiseConverter<tosa::ReluNOp>,
+ PointwiseConverter<tosa::SigmoidOp>,
+ IdentityNConverter<tosa::IdentityOp>,
+ IdentityNConverter<tosa::IdentityNOp>,
+ ReduceConverter<tosa::ReduceAllOp>,
+ ReduceConverter<tosa::ReduceAnyOp>,
+ ReduceConverter<tosa::ReduceMinOp>,
+ ReduceConverter<tosa::ReduceMaxOp>,
+ ReduceConverter<tosa::ReduceSumOp>,
+ ReduceConverter<tosa::ReduceProdOp>,
+ ArgMaxConverter,
+ ConcatConverter,
+ PadConverter,
+ ReshapeConverter,
+ RescaleConverter,
+ ReverseConverter,
+ TableConverter,
+ TileConverter,
+ TransposeConverter,
+ MatMulConverter,
FullyConnectedConverter>(patterns->getContext());
+ // clang-format on
}
diff --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
index 1bc4d6a8d0132..5d77c932bf121 100644
--- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
+++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
@@ -830,3 +830,46 @@ func @argmax(%arg0 : tensor<3x2xi32>, %arg1 : tensor<6xf32>) -> () {
return
}
+
+// -----
+
+// CHECK-LABEL: @table8
+func @table8(%arg0: tensor<6xi8>, %arg1: tensor<256xi8>) -> () {
+  // CHECK: %[[INIT:.+]] = linalg.init_tensor [6]
+  // CHECK: %[[GENERIC:.+]] = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel"]} ins(%arg0 : tensor<6xi8>) outs(%[[INIT]] : tensor<6xi8>)
+  // CHECK: ^bb0(%[[ARG_IN:.+]]: i8, %[[ARG_INIT:.+]]: i8)
+  // CHECK: %[[CAST:.+]] = index_cast %[[ARG_IN]]
+  // CHECK: %[[OFFSET:.+]] = constant 128
+  // CHECK: %[[ADD:.+]] = addi %[[CAST]], %[[OFFSET]]
+  // CHECK: %[[EXTRACT:.+]] = tensor.extract %arg1[%[[ADD]]]
+  // CHECK: linalg.yield %[[EXTRACT]]
+  %0 = "tosa.table"(%arg0, %arg1) : (tensor<6xi8>, tensor<256xi8>) -> (tensor<6xi8>)
+  return
+}
+
+// CHECK-LABEL: @table16
+func @table16(%arg0: tensor<6xi16>, %arg1: tensor<513xi16>) -> () {
+  // CHECK: %[[INIT:.+]] = linalg.init_tensor [6]
+  // CHECK: %[[GENERIC:.+]] = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel"]} ins(%arg0 : tensor<6xi16>) outs(%[[INIT]] : tensor<6xi32>)
+  // CHECK: ^bb0(%[[ARG_IN:.+]]: i16, %[[ARG_INIT:.+]]: i32)
+  // CHECK: %[[EXT_IN:.+]] = sexti %[[ARG_IN]]
+  // CHECK: %[[C32768:.+]] = constant 32768
+  // CHECK: %[[C7:.+]] = constant 7
+  // CHECK: %[[C1:.+]] = constant 1
+  // CHECK: %[[C127:.+]] = constant 127
+  // CHECK: %[[INADD:.+]] = addi %[[EXT_IN]], %[[C32768]]
+  // CHECK: %[[IDX:.+]] = shift_right_unsigned %[[INADD]], %[[C7]]
+  // CHECK: %[[FRACTION:.+]] = and %[[INADD]], %[[C127]]
+  // CHECK: %[[IDXPLUS1:.+]] = addi %[[IDX]], %[[C1]]
+  // CHECK: %[[IDX_CAST:.+]] = index_cast %[[IDX]]
+  // CHECK: %[[IDXPLUS1_CAST:.+]] = index_cast %[[IDXPLUS1]]
+  // CHECK: %[[BASE:.+]] = tensor.extract %arg1[%[[IDX_CAST]]]
+  // CHECK: %[[NEXT:.+]] = tensor.extract %arg1[%[[IDXPLUS1_CAST]]]
+  // CHECK: %[[BASE_EXT:.+]] = sexti %[[BASE]]
+  // CHECK: %[[NEXT_EXT:.+]] = sexti %[[NEXT]]
+  // CHECK: %[[BASE_MUL:.+]] = shift_left %[[BASE_EXT]], %[[C7]]
+  // CHECK: %[[DIFF:.+]] = subi %[[NEXT_EXT]], %[[BASE_EXT]]
+  // CHECK: %[[DIFF_MUL:.+]] = muli %[[DIFF]], %[[FRACTION]]
+  // CHECK: %[[RESULT:.+]] = addi %[[BASE_MUL]], %[[DIFF_MUL]]
+  // CHECK: linalg.yield %[[RESULT]]
+  %0 = "tosa.table"(%arg0, %arg1) : (tensor<6xi16>, tensor<513xi16>) -> (tensor<6xi32>)
+  return
+}
More information about the Mlir-commits
mailing list