[Mlir-commits] [mlir] 0312b25 - [mlir][tosa] Add tosa.table lowering to linalg.generic

Rob Suderman llvmlistbot at llvm.org
Tue Apr 6 13:58:30 PDT 2021


Author: Rob Suderman
Date: 2021-04-06T13:57:18-07:00
New Revision: 0312b25df0a872295f8db203fbebfb4a0d7f0f3e

URL: https://github.com/llvm/llvm-project/commit/0312b25df0a872295f8db203fbebfb4a0d7f0f3e
DIFF: https://github.com/llvm/llvm-project/commit/0312b25df0a872295f8db203fbebfb4a0d7f0f3e.diff

LOG: [mlir][tosa] Add tosa.table lowering to linalg.generic

Table op lowering to linalg.generic for both i8 (behaves like a gather) and a
pair of gathers with a quantized interpolation.

Differential Revision: https://reviews.llvm.org/D99756

Added: 
    

Modified: 
    mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
    mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir

Removed: 
    


################################################################################
diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
index 3fdbc6f89733b..a6271f7097563 100644
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
@@ -1407,37 +1407,178 @@ class ArgMaxConverter : public OpRewritePattern<tosa::ArgMaxOp> {
   }
 };
 
+// Lowers the TableOp to a series of gathers and numeric operations. This
+// includes interpolation between the high/low values. For the I8 variant, this
+// simplifies to a single gather operation.
+class TableConverter : public OpRewritePattern<tosa::TableOp> {
+public:
+  using OpRewritePattern<tosa::TableOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(tosa::TableOp op,
+                                PatternRewriter &rewriter) const final {
+    auto loc = op.getLoc();
+    Value input = op.input();
+    Value table = op.table();
+    auto inputTy = input.getType().cast<ShapedType>();
+    auto tableTy = table.getType().cast<ShapedType>();
+    auto resultTy = op.getType().cast<ShapedType>();
+
+    if (!inputTy.hasStaticShape())
+      return rewriter.notifyMatchFailure(
+          op, "require input type to have static shape");
+
+    auto inputElementTy = inputTy.getElementType();
+    auto tableElementTy = tableTy.getElementType();
+    auto resultElementTy = resultTy.getElementType();
+
+    auto initTensor =
+        rewriter
+            .create<linalg::InitTensorOp>(loc, ArrayRef<Value>{},
+                                          resultTy.getShape(), resultElementTy)
+            .result();
+
+    SmallVector<AffineMap, 2> affineMaps = {
+        rewriter.getMultiDimIdentityMap(resultTy.getRank()),
+        rewriter.getMultiDimIdentityMap(resultTy.getRank())};
+
+    auto genericOp = rewriter.create<linalg::GenericOp>(
+        loc, resultTy, ValueRange({input}), ValueRange{initTensor}, affineMaps,
+        getNParallelLoopsAttrs(resultTy.getRank()));
+    rewriter.replaceOp(op, genericOp.getResult(0));
+
+    {
+      OpBuilder::InsertionGuard regionGuard(rewriter);
+      Block *block =
+          rewriter.createBlock(&genericOp.region(), genericOp.region().end(),
+                               TypeRange({inputElementTy, resultElementTy}));
+
+      auto inputValue = block->getArgument(0);
+      rewriter.setInsertionPointToStart(block);
+      if (inputElementTy.isInteger(8) && tableElementTy.isInteger(8) &&
+          resultElementTy.isInteger(8)) {
+        Value index = rewriter.create<IndexCastOp>(loc, rewriter.getIndexType(),
+                                                   inputValue);
+        Value extract =
+            rewriter.create<tensor::ExtractOp>(loc, table, ValueRange{index});
+        rewriter.create<linalg::YieldOp>(loc, extract);
+        return success();
+      }
+
+      if (inputElementTy.isInteger(16) && tableElementTy.isInteger(16) &&
+          resultElementTy.isInteger(32)) {
+        Value extend = rewriter.create<SignExtendIOp>(
+            loc, rewriter.getI32Type(), inputValue);
+
+        auto offset =
+            rewriter.create<ConstantOp>(loc, rewriter.getI32IntegerAttr(32768));
+        auto seven =
+            rewriter.create<ConstantOp>(loc, rewriter.getI32IntegerAttr(7));
+        auto one =
+            rewriter.create<ConstantOp>(loc, rewriter.getI32IntegerAttr(1));
+        auto b1111111 =
+            rewriter.create<ConstantOp>(loc, rewriter.getI32IntegerAttr(127));
+
+        // Compute the index and fractional part from the input value:
+        // value = value + 32768
+        // index = value >> 7;
+        // fraction = 0b1111111 & value
+        auto extendAdd = rewriter.create<AddIOp>(loc, extend, offset);
+        Value index =
+            rewriter.create<UnsignedShiftRightOp>(loc, extendAdd, seven);
+        Value fraction = rewriter.create<mlir::AndOp>(loc, extendAdd, b1111111);
+
+        // Extract the base and next values from the table.
+        // base = (int32_t) table[index];
+        // next = (int32_t) table[index + 1];
+        Value indexPlusOne = rewriter.create<AddIOp>(loc, index, one);
+
+        index =
+            rewriter.create<IndexCastOp>(loc, rewriter.getIndexType(), index);
+        indexPlusOne = rewriter.create<IndexCastOp>(
+            loc, rewriter.getIndexType(), indexPlusOne);
+
+        Value base =
+            rewriter.create<tensor::ExtractOp>(loc, table, ValueRange{index});
+        Value next = rewriter.create<tensor::ExtractOp>(
+            loc, table, ValueRange{indexPlusOne});
+
+        base = rewriter.create<SignExtendIOp>(loc, rewriter.getI32Type(), base);
+        next = rewriter.create<SignExtendIOp>(loc, rewriter.getI32Type(), next);
+
+        // Use the fractional part to interpolate between the input values:
+        // result = (base << 7) + (next - base) * fraction
+        Value baseScaled = rewriter.create<ShiftLeftOp>(loc, base, seven);
+        Value diff = rewriter.create<SubIOp>(loc, next, base);
+        Value diffScaled = rewriter.create<MulIOp>(loc, diff, fraction);
+        Value result = rewriter.create<AddIOp>(loc, baseScaled, diffScaled);
+
+        rewriter.create<linalg::YieldOp>(loc, result);
+
+        return success();
+      }
+    }
+
+    return rewriter.notifyMatchFailure(
+        op, "unable to create body for tosa.table op");
+  }
+};
+
 } // namespace
 
 void mlir::tosa::populateTosaToLinalgOnTensorsConversionPatterns(
     RewritePatternSet *patterns) {
   patterns->add<
-      PointwiseConverter<tosa::AddOp>, PointwiseConverter<tosa::SubOp>,
-      PointwiseConverter<tosa::MulOp>, PointwiseConverter<tosa::ReciprocalOp>,
-      PointwiseConverter<tosa::NegateOp>, PointwiseConverter<tosa::PowOp>,
-      PointwiseConverter<tosa::RsqrtOp>, PointwiseConverter<tosa::LogOp>,
-      PointwiseConverter<tosa::ExpOp>, PointwiseConverter<tosa::AbsOp>,
-      PointwiseConverter<tosa::TanhOp>, PointwiseConverter<tosa::BitwiseAndOp>,
+      // clang-format off
+      PointwiseConverter<tosa::AddOp>,
+      PointwiseConverter<tosa::SubOp>,
+      PointwiseConverter<tosa::MulOp>,
+      PointwiseConverter<tosa::NegateOp>,
+      PointwiseConverter<tosa::PowOp>,
+      PointwiseConverter<tosa::ReciprocalOp>,
+      PointwiseConverter<tosa::RsqrtOp>,
+      PointwiseConverter<tosa::LogOp>,
+      PointwiseConverter<tosa::ExpOp>,
+      PointwiseConverter<tosa::AbsOp>,
+      PointwiseConverter<tosa::TanhOp>,
+      PointwiseConverter<tosa::BitwiseAndOp>,
       PointwiseConverter<tosa::BitwiseOrOp>,
       PointwiseConverter<tosa::BitwiseNotOp>,
       PointwiseConverter<tosa::BitwiseXorOp>,
       PointwiseConverter<tosa::LogicalAndOp>,
       PointwiseConverter<tosa::LogicalNotOp>,
       PointwiseConverter<tosa::LogicalOrOp>,
-      PointwiseConverter<tosa::LogicalXorOp>, PointwiseConverter<tosa::CastOp>,
+      PointwiseConverter<tosa::LogicalXorOp>,
+      PointwiseConverter<tosa::CastOp>,
       PointwiseConverter<tosa::LogicalLeftShiftOp>,
       PointwiseConverter<tosa::LogicalRightShiftOp>,
-      PointwiseConverter<tosa::SelectOp>, PointwiseConverter<tosa::GreaterOp>,
+      PointwiseConverter<tosa::SelectOp>,
+      PointwiseConverter<tosa::GreaterOp>,
       PointwiseConverter<tosa::GreaterEqualOp>,
-      PointwiseConverter<tosa::MaximumOp>, PointwiseConverter<tosa::MinimumOp>,
-      PointwiseConverter<tosa::CeilOp>, PointwiseConverter<tosa::FloorOp>,
-      PointwiseConverter<tosa::ClampOp>, PointwiseConverter<tosa::ReluNOp>,
-      PointwiseConverter<tosa::SigmoidOp>, IdentityNConverter<tosa::IdentityOp>,
-      IdentityNConverter<tosa::IdentityNOp>, ReduceConverter<tosa::ReduceAllOp>,
-      ReduceConverter<tosa::ReduceAnyOp>, ReduceConverter<tosa::ReduceMinOp>,
-      ReduceConverter<tosa::ReduceMaxOp>, ReduceConverter<tosa::ReduceSumOp>,
-      ReduceConverter<tosa::ReduceProdOp>, ArgMaxConverter, ConcatConverter,
-      PadConverter, ReshapeConverter, RescaleConverter, ReverseConverter,
-      TileConverter, TransposeConverter, MatMulConverter,
+      PointwiseConverter<tosa::MaximumOp>,
+      PointwiseConverter<tosa::MinimumOp>,
+      PointwiseConverter<tosa::CeilOp>,
+      PointwiseConverter<tosa::FloorOp>,
+      PointwiseConverter<tosa::ClampOp>,
+      PointwiseConverter<tosa::ReluNOp>,
+      PointwiseConverter<tosa::SigmoidOp>,
+      IdentityNConverter<tosa::IdentityOp>,
+      IdentityNConverter<tosa::IdentityNOp>,
+      ReduceConverter<tosa::ReduceAllOp>,
+      ReduceConverter<tosa::ReduceAnyOp>,
+      ReduceConverter<tosa::ReduceMinOp>,
+      ReduceConverter<tosa::ReduceMaxOp>,
+      ReduceConverter<tosa::ReduceSumOp>,
+      ReduceConverter<tosa::ReduceProdOp>,
+      ArgMaxConverter,
+      ConcatConverter,
+      PadConverter,
+      ReshapeConverter,
+      RescaleConverter,
+      ReverseConverter,
+      TableConverter,
+      TileConverter,
+      TransposeConverter,
+      MatMulConverter,
       FullyConnectedConverter>(patterns->getContext());
+      // clang-format on
 }

diff --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
index 1bc4d6a8d0132..5d77c932bf121 100644
--- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
+++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
@@ -830,3 +830,46 @@ func @argmax(%arg0 : tensor<3x2xi32>, %arg1 : tensor<6xf32>) -> () {
 
   return
 }
+
+// -----
+
+// CHECK-LABEL: @table8
+func @table8(%arg0: tensor<6xi8>, %arg1: tensor<513xi8>) -> () {
+  // CHECK: %[[INIT:.+]] = linalg.init_tensor [6]
+  // CHECK: %[[GENERIC:.+]] = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel"]} ins(%arg0 : tensor<6xi8>) outs(%[[INIT]] : tensor<6xi8>)
+  // CHECK: ^bb0(%[[ARG_IN:.+]]: i8, %[[ARG_INIT:.+]]: i8)
+  // CHECK:   %[[CAST:.+]] = index_cast %[[ARG_IN]]
+  // CHECK:   %[[EXTRACT:.+]] = tensor.extract %arg1[%[[CAST]]]
+  // CHECK:   linalg.yield %[[EXTRACT]]
+  %0 = "tosa.table"(%arg0, %arg1)  : (tensor<6xi8>, tensor<513xi8>)  -> (tensor<6xi8>)
+  return
+}
+
+// CHECK-LABEL: @table16
+func @table16(%arg0: tensor<6xi16>, %arg1: tensor<513xi16>) -> () {
+  // CHECK: %[[INIT:.+]] = linalg.init_tensor [6]
+  // CHECK: %[[GENERIC:.+]] = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel"]} ins(%arg0 : tensor<6xi16>) outs(%[[INIT]] : tensor<6xi32>)
+  // CHECK: ^bb0(%arg2: i16, %arg3: i32)
+  // CHECK: %[[EXT_IN:.+]] = sexti %arg2
+  // CHECK: %[[C32768:.+]] = constant 32768
+  // CHECK: %[[C7:.+]] = constant 7
+  // CHECK: %[[C1:.+]] = constant 1
+  // CHECK: %[[C127:.+]] = constant 127
+  // CHECK: %[[INADD:.+]] = addi %[[EXT_IN]], %[[C32768]]
+  // CHECK: %[[IDX:.+]] = shift_right_unsigned %[[INADD]], %[[C7]]
+  // CHECK: %[[FRACTION:.+]] = and %[[INADD]], %[[C127]]
+  // CHECK: %[[IDXPLUS1:.+]] = addi %[[IDX]], %[[C1]]
+  // CHECK: %[[IDX_CAST:.+]] = index_cast %[[IDX]]
+  // CHECK: %[[IDXPLUS1_CAST:.+]] = index_cast %[[IDXPLUS1]]
+  // CHECK: %[[BASE:.+]] = tensor.extract %arg1[%[[IDX_CAST]]]
+  // CHECK: %[[NEXT:.+]] = tensor.extract %arg1[%[[IDXPLUS1_CAST]]]
+  // CHECK: %[[BASE_EXT:.+]] = sexti %[[BASE]]
+  // CHECK: %[[NEXT_EXT:.+]] = sexti %[[NEXT]]
+  // CHECK: %[[BASE_MUL:.+]] = shift_left %[[BASE_EXT]], %[[C7]]
+  // CHECK: %[[DIFF:.+]] = subi %[[NEXT_EXT]], %[[BASE_EXT]]
+  // CHECK: %[[DIFF_MUL:.+]] = muli %[[DIFF]], %[[FRACTION]]
+  // CHECK: %[[RESULT:.+]] = addi %[[BASE_MUL]], %[[DIFF_MUL]]
+  // CHECK: linalg.yield %[[RESULT]]
+  %0 = "tosa.table"(%arg0, %arg1)  : (tensor<6xi16>, tensor<513xi16>)  -> (tensor<6xi32>)
+  return
+}


        


More information about the Mlir-commits mailing list