[Mlir-commits] [mlir] 4c48f7e - [mlir][tosa] Create basic dynamic shape support for several ops.

Rob Suderman llvmlistbot at llvm.org
Wed Oct 6 10:37:27 PDT 2021


Author: natashaknk
Date: 2021-10-06T10:36:04-07:00
New Revision: 4c48f7e29b7014af5ba8292a508b8386e6b00f03

URL: https://github.com/llvm/llvm-project/commit/4c48f7e29b7014af5ba8292a508b8386e6b00f03
DIFF: https://github.com/llvm/llvm-project/commit/4c48f7e29b7014af5ba8292a508b8386e6b00f03.diff

LOG: [mlir][tosa] Create basic dynamic shape support for several ops.

Transpose, Matmul and Fully-connected dynamic shape support

Reviewed By: rsuderman

Differential Revision: https://reviews.llvm.org/D111167

Added: 
    

Modified: 
    mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
    mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
index 77e4c267bfbf..f24a849810e2 100644
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
@@ -91,6 +91,14 @@ static mlir::Value applyPad(Location loc, Value input, ArrayRef<int64_t> pad,
       .result();
 }
 
+/// Drops the null (static-dim) slots from `dynDims`, preserving order.
+static SmallVector<Value> filterDynamicDims(ArrayRef<Value> dynDims) {
+  SmallVector<Value> filteredDims;
+  llvm::copy_if(dynDims, std::back_inserter(filteredDims),
+                [](Value dim) { return static_cast<bool>(dim); });
+  return filteredDims;
+}
+
 static Value
 createLinalgBodyCalculationForElementwiseOp(Operation *op, ValueRange args,
                                             ArrayRef<Type> resultTypes,
@@ -690,10 +698,7 @@ elementwiseMatchAndRewriteHelper(Operation *operation,
     }
   }
 
-  SmallVector<Value> filteredDims;
-  for (auto dim : dynDims)
-    if (dim)
-      filteredDims.push_back(dim);
+  SmallVector<Value> filteredDims = filterDynamicDims(dynDims);
 
   for (auto result : results) {
     auto resultTy = result.getType().template cast<ShapedType>();
@@ -1355,10 +1360,31 @@ class MatMulConverter : public OpConversionPattern<tosa::MatMulOp> {
 
     auto outputTy = op.getType().cast<ShapedType>();
     auto outputElementTy = outputTy.getElementType();
+
+    // The result is [batch, lhs rows, rhs cols]. linalg.init_tensor needs one
+    // size operand per '?' in outputTy, so key the checks off the result type
+    // rather than the operand types, whose dynamic dims need not match it.
+    SmallVector<Value> dynDims;
+    dynDims.resize(outputTy.getRank());
+
+    if (outputTy.isDynamicDim(0)) {
+      dynDims[0] = rewriter.create<tensor::DimOp>(loc, op->getOperand(0), 0);
+    }
+
+    if (outputTy.isDynamicDim(1)) {
+      dynDims[1] = rewriter.create<tensor::DimOp>(loc, op->getOperand(0), 1);
+    }
+
+    if (outputTy.isDynamicDim(2)) {
+      dynDims[2] = rewriter.create<tensor::DimOp>(loc, op->getOperand(1), 2);
+    }
+
+    SmallVector<Value> filteredDims = filterDynamicDims(dynDims);
+
     auto zeroAttr = rewriter.getZeroAttr(outputElementTy);
     Value zero = rewriter.create<ConstantOp>(loc, zeroAttr);
     auto initTensor = rewriter.create<linalg::InitTensorOp>(
-        loc, outputTy.getShape(), outputTy.getElementType());
+        loc, filteredDims, outputTy.getShape(), outputTy.getElementType());
     Value zeroTensor =
         rewriter.create<linalg::FillOp>(loc, zero, initTensor).getResult(0);
     if (!op.quantization_info()) {
@@ -1393,14 +1419,29 @@ class FullyConnectedConverter
     Location loc = op.getLoc();
     auto outputTy = op.getType().cast<ShapedType>();
     auto input = op.input();
-    auto weight = op.weight();
+    SmallVector<Value> dynDims(outputTy.getRank());
+
     auto bias = op.bias();
 
+    auto weight = op.weight();
     auto weightTy = weight.getType().cast<ShapedType>();
     auto weightShape = weightTy.getShape();
 
     auto outputETy = outputTy.getElementType();
 
+    // Result dim 0 comes from input dim 0, result dim 1 from weight dim 0.
+    // Key the checks off the result type so the init_tensor size operands
+    // always match its '?' dims, whatever the operand types say.
+    if (outputTy.isDynamicDim(0)) {
+      dynDims[0] = rewriter.create<tensor::DimOp>(loc, input, 0);
+    }
+
+    if (outputTy.isDynamicDim(1)) {
+      dynDims[1] = rewriter.create<tensor::DimOp>(loc, weight, 0);
+    }
+
+    SmallVector<Value> filteredDims = filterDynamicDims(dynDims);
+
     // Creating maps for the output of MatMul and the bias
     SmallVector<AffineMap, 4> indexingMaps;
 
@@ -1413,7 +1454,7 @@ class FullyConnectedConverter
     indexingMaps.push_back(rewriter.getMultiDimIdentityMap(outputTy.getRank()));
 
     auto initTensor = rewriter.create<linalg::InitTensorOp>(
-        loc, outputTy.getShape(), outputTy.getElementType());
+        loc, filteredDims, outputTy.getShape(), outputTy.getElementType());
 
     // When quantized, the input elemeny type is not the same as the output
     Attribute resultZeroAttr = rewriter.getZeroAttr(outputETy);
@@ -1435,7 +1476,8 @@ class FullyConnectedConverter
 
     auto biasInitTensor =
         rewriter
-            .create<linalg::InitTensorOp>(loc, outputTy.getShape(), outputETy)
+            .create<linalg::InitTensorOp>(loc, filteredDims,
+                                          outputTy.getShape(), outputETy)
             ->getResults();
 
     if (!op.quantization_info()) {
@@ -1614,20 +1656,43 @@ class TransposeConverter : public OpRewritePattern<tosa::TransposeOp> {
       return failure();
     }
 
+    auto loc = op.getLoc();
+    auto input = op->getOperand(0);
     auto resultTy = op.getType().cast<ShapedType>();
-    if (!resultTy.hasStaticShape())
-      return failure();
+    if (!resultTy.hasRank())
+      return failure();
+
+    SmallVector<Value> dynDims;
+    dynDims.resize(resultTy.getRank());
 
     SmallVector<AffineExpr, 2> inputExprs;
     inputExprs.resize(resultTy.getRank());
     for (auto permutation : llvm::enumerate(perms.getValues<APInt>())) {
-      inputExprs[permutation.value().getZExtValue()] =
-          rewriter.getAffineDimExpr(permutation.index());
+      auto index = permutation.index();
+      auto value = permutation.value().getZExtValue();
+      if (value >= inputExprs.size())
+        return failure();
+      inputExprs[value] = rewriter.getAffineDimExpr(index);
     }
 
+    // Bail out unless perms is a permutation of [0, rank).
+    if (llvm::is_contained(inputExprs, AffineExpr()))
+      return failure();
+
+    // Result dim i has the extent of input dim perms[i]. inputExprs[j] is the
+    // result dim that input dim j feeds, so walk it in input-dim order and
+    // store each tensor.dim at its result position.
+    for (auto expr : llvm::enumerate(inputExprs)) {
+      unsigned resultDim = expr.value().cast<AffineDimExpr>().getPosition();
+      if (resultTy.isDynamicDim(resultDim))
+        dynDims[resultDim] =
+            rewriter.create<tensor::DimOp>(loc, input, expr.index());
+    }
+
+    SmallVector<Value> filteredDims = filterDynamicDims(dynDims);
+
     auto initTensor = rewriter.create<linalg::InitTensorOp>(
-        op.getLoc(), ArrayRef<Value>({}), resultTy.getShape(),
-        resultTy.getElementType());
+        loc, filteredDims, resultTy.getShape(), resultTy.getElementType());
 
     SmallVector<AffineMap, 2> affineMaps = {
         AffineMap::get(resultTy.getRank(), /*symbolCount=*/0, inputExprs,
@@ -1638,7 +1703,7 @@ class TransposeConverter : public OpRewritePattern<tosa::TransposeOp> {
         op, resultTy, op.input1(), ValueRange{initTensor}, affineMaps,
         getNParallelLoopsAttrs(resultTy.getRank()),
         [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange args) {
-          nestedBuilder.create<linalg::YieldOp>(op.getLoc(), *args.begin());
+          nestedBuilder.create<linalg::YieldOp>(loc, *args.begin());
         });
     return success();
   }

diff  --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
index ba9b0d25ab5c..1c81a2aa731c 100644
--- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
+++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
@@ -592,6 +592,48 @@ func @test_transpose(%arg0: tensor<1x2x3xi32>) -> () {
 
 // -----
 
+// CHECK: #[[$MAP0:.*]] = affine_map<(d0, d1, d2, d3) -> (d2, d0, d3, d1)>
+// CHECK: #[[$MAP1:.*]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
+
+// CHECK-LABEL: @test_transpose_dyn
+// CHECK-SAME: (%[[ARG0:.+]]: tensor<1x?x3x4xi32>)
+func @test_transpose_dyn(%arg0: tensor<1x?x3x4xi32>) -> () {
+  %0 = constant dense<[1, 3, 0, 2]> : tensor<4xi32>
+  // CHECK: %[[C1:.+]] = constant 1
+  // CHECK: %[[DIM:.+]] = tensor.dim %[[ARG0]], %[[C1]]
+  // CHECK: %[[INIT:.+]] = linalg.init_tensor [%[[DIM]], 4, 1, 3]
+  // CHECK: %[[GENERIC:.+]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]]], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%[[ARG0]] : tensor<1x?x3x4xi32>) outs(%[[INIT]] : tensor<?x4x1x3xi32>)
+  // CHECK: ^bb0([[ARG1:%.+]]: i32, [[ARG2:%.+]]: i32)
+  // CHECK:   linalg.yield [[ARG1]]
+  // CHECK: }
+  %1 = "tosa.transpose"(%arg0, %0) : (tensor<1x?x3x4xi32>, tensor<4xi32>) -> (tensor<?x4x1x3xi32>)
+  return
+}
+
+// -----
+
+// CHECK: #[[$MAP0:.*]] = affine_map<(d0, d1) -> (d1, d0)>
+// CHECK: #[[$MAP1:.*]] = affine_map<(d0, d1) -> (d0, d1)>
+
+// CHECK-LABEL: @test_transpose_dyn_multiple
+// CHECK-SAME: (%[[ARG0:.+]]: tensor<?x?xf32>)
+func @test_transpose_dyn_multiple(%arg0: tensor<?x?xf32>) -> () {
+  %0 = constant dense<[1, 0]> : tensor<2xi32>
+  // CHECK: %[[C0:.+]] = constant 0
+  // CHECK: %[[DIM0:.+]] = tensor.dim %[[ARG0]], %[[C0]]
+  // CHECK: %[[C1:.+]] = constant 1
+  // CHECK: %[[DIM1:.+]] = tensor.dim %[[ARG0]], %[[C1]]
+  // CHECK: %[[INIT:.+]] = linalg.init_tensor [%[[DIM1]], %[[DIM0]]]
+  // CHECK: %[[GENERIC:.+]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]]], iterator_types = ["parallel", "parallel"]} ins(%[[ARG0]] : tensor<?x?xf32>) outs(%[[INIT]] : tensor<?x?xf32>)
+  // CHECK: ^bb0([[ARG1:%.+]]: f32, [[ARG2:%.+]]: f32)
+  // CHECK:   linalg.yield [[ARG1]]
+  // CHECK: }
+  %1 = "tosa.transpose"(%arg0, %0) : (tensor<?x?xf32>, tensor<2xi32>) -> (tensor<?x?xf32>)
+  return
+}
+
+// -----
+
 // CHECK-DAG: #[[$MAP0:.*]] = affine_map<(d0, d1) -> (d0, d1)>
 // CHECK-DAG: #[[$MAP1:.*]] = affine_map<(d0, d1) -> (d1)>
 // CHECK-DAG: #[[$MAP2:.*]] = affine_map<(d0, d1) -> (d0)>
@@ -987,7 +1029,7 @@ func @tile(%arg0 : tensor<2x3xi8>) -> () {
 
 
 // CHECK-LABEL: @matmul
-func @matmul(%arg0: tensor<1x5x3xf32>, %arg1: tensor<1x3x6xf32>, %arg2: tensor<1x6xf32>) -> (tensor<1x5x6xf32>) {
+func @matmul(%arg0: tensor<1x5x3xf32>, %arg1: tensor<1x3x6xf32>) -> (tensor<1x5x6xf32>) {
   // CHECK: [[C0:%.+]] = constant 0
   // CHECK: [[INIT:%.+]] = linalg.init_tensor [1, 5, 6]
   // CHECK: [[FILLED:%.+]] = linalg.fill([[C0]], [[INIT]]) : f32, tensor<1x5x6xf32> -> tensor<1x5x6xf32>
@@ -1013,6 +1055,46 @@ func @matmul_quantized(%arg0: tensor<1x5x3xi8>, %arg1: tensor<1x3x6xi8>) -> (ten
 
 // -----
 
+// CHECK-LABEL: @matmul_dyn_batch
+func @matmul_dyn_batch(%arg0: tensor<?x5x3xf32>, %arg1: tensor<?x3x6xf32>) -> (tensor<?x5x6xf32>) {
+  // CHECK: %[[IDX0:.+]] = constant 0
+  // CHECK: %[[BATCH:.+]] = tensor.dim %arg0, %[[IDX0]]
+  // CHECK: %[[ZERO:.+]] = constant 0
+  // CHECK: %[[INIT:.+]] = linalg.init_tensor [%[[BATCH]], 5, 6]
+  // CHECK: %[[FILLED:.+]] = linalg.fill(%[[ZERO]], %[[INIT]]) : f32, tensor<?x5x6xf32> -> tensor<?x5x6xf32>
+  // CHECK: linalg.batch_matmul ins(%arg0, %arg1 : tensor<?x5x3xf32>, tensor<?x3x6xf32>) outs(%[[FILLED]] : tensor<?x5x6xf32>) -> tensor<?x5x6xf32>
+  %0 = "tosa.matmul"(%arg0, %arg1) : (tensor<?x5x3xf32>, tensor<?x3x6xf32>)  -> (tensor<?x5x6xf32>)
+  return %0 : tensor<?x5x6xf32>
+}
+
+// -----
+
+// CHECK-LABEL: @matmul_dyn_independent_dim
+func @matmul_dyn_independent_dim(%arg0: tensor<1x5x3xf32>, %arg1: tensor<1x3x?xf32>) -> (tensor<1x5x?xf32>) {
+  // CHECK: %[[IDX2:.+]] = constant 2
+  // CHECK: %[[COLS:.+]] = tensor.dim %arg1, %[[IDX2]]
+  // CHECK: %[[ZERO:.+]] = constant 0
+  // CHECK: %[[INIT:.+]] = linalg.init_tensor [1, 5, %[[COLS]]]
+  // CHECK: %[[FILLED:.+]] = linalg.fill(%[[ZERO]], %[[INIT]]) : f32, tensor<1x5x?xf32> -> tensor<1x5x?xf32>
+  // CHECK: linalg.batch_matmul ins(%arg0, %arg1 : tensor<1x5x3xf32>, tensor<1x3x?xf32>) outs(%[[FILLED]] : tensor<1x5x?xf32>) -> tensor<1x5x?xf32>
+  %0 = "tosa.matmul"(%arg0, %arg1) : (tensor<1x5x3xf32>, tensor<1x3x?xf32>)  -> (tensor<1x5x?xf32>)
+  return %0 : tensor<1x5x?xf32>
+}
+
+// -----
+
+// CHECK-LABEL: @matmul_dyn_reduction_dim
+func @matmul_dyn_reduction_dim(%arg0: tensor<1x5x?xf32>, %arg1: tensor<1x?x6xf32>) -> (tensor<1x5x6xf32>) {
+  // CHECK: %[[ZERO:.+]] = constant 0
+  // CHECK: %[[INIT:.+]] = linalg.init_tensor [1, 5, 6]
+  // CHECK: %[[FILLED:.+]] = linalg.fill(%[[ZERO]], %[[INIT]]) : f32, tensor<1x5x6xf32> -> tensor<1x5x6xf32>
+  // CHECK: linalg.batch_matmul ins(%arg0, %arg1 : tensor<1x5x?xf32>, tensor<1x?x6xf32>) outs(%[[FILLED]] : tensor<1x5x6xf32>) -> tensor<1x5x6xf32>
+  %0 = "tosa.matmul"(%arg0, %arg1) : (tensor<1x5x?xf32>, tensor<1x?x6xf32>)  -> (tensor<1x5x6xf32>)
+  return %0 : tensor<1x5x6xf32>
+}
+
+// -----
+
 // CHECK: #[[$MAP0:.*]] = affine_map<(d0, d1) -> (d1, d0)>
 // CHECK: #[[$MAP1:.*]] = affine_map<(d0, d1) -> (d0, d1)>
 // CHECK: #[[$MAP2:.*]] = affine_map<(d0, d1) -> (d1)>
@@ -1055,7 +1137,7 @@ func @quantized_fully_connected(%arg0: tensor<5x3xi8>, %arg1: tensor<6x3xi8>, %a
   // CHECK: ^bb0([[IN:%.+]]: i8, [[UNUSED:%.+]]: i8):
   // CHECK:   linalg.yield [[IN]] : i8
   // CHECK: [[INITB:%.+]] = linalg.init_tensor [5, 6]
-  // CHECK: [[ONE:%.+]] = constant 1 
+  // CHECK: [[ONE:%.+]] = constant 1
   // CHECK: [[TWO:%.+]] = constant 2
   // CHECK: [[MATMUL:%.+]] = linalg.quantized_matmul ins(%arg0, [[TRANSPOSE]], [[ONE]], [[TWO]] : tensor<5x3xi8>, tensor<3x6xi8>, i32, i32) outs([[FILL]] : tensor<5x6xi32>) -> tensor<5x6xi32>
   // CHECK: [[ADDED:%.+]] = linalg.generic {indexing_maps = [#[[$MAP2]], #[[$MAP1]], #[[$MAP1]]], iterator_types = ["parallel", "parallel"]} ins(%arg2, [[MATMUL]] : tensor<6xi32>, tensor<5x6xi32>) outs([[INITB]]
@@ -1068,6 +1150,31 @@ func @quantized_fully_connected(%arg0: tensor<5x3xi8>, %arg1: tensor<6x3xi8>, %a
 
 // -----
 
+// CHECK-LABEL: @fully_connected_dyn
+func @fully_connected_dyn(%arg0: tensor<?x3xf32>, %arg1: tensor<6x3xf32>, %arg2: tensor<6xf32>) -> (tensor<?x6xf32>) {
+  // CHECK: %[[IDX0:.+]] = constant 0
+  // CHECK: %[[BATCH:.+]] = tensor.dim %arg0, %[[IDX0]]
+  // CHECK: %[[INIT:.+]] = linalg.init_tensor [%[[BATCH]], 6]
+  // CHECK: %[[ZERO:.+]] = constant 0
+  // CHECK: %[[FILL:.+]] = linalg.fill(%[[ZERO]], %[[INIT]])
+  // CHECK: %[[PERM:.+]] = constant dense<[1, 0]>
+  // CHECK: %[[INITT:.+]] = linalg.init_tensor [3, 6]
+  // CHECK: %[[TRANSPOSE:.+]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]]], iterator_types = ["parallel", "parallel"]} ins(%arg1 : tensor<6x3xf32>) outs(%[[INITT]] : tensor<3x6xf32>) {
+  // CHECK: ^bb0(%[[IN:.+]]: f32, %[[UNUSED:.+]]: f32):
+  // CHECK:   linalg.yield %[[IN]] : f32
+  // CHECK: %[[INITB:.+]] = linalg.init_tensor [%[[BATCH]], 6]
+  // CHECK: %[[MATMUL:.+]] = linalg.matmul ins(%arg0, %[[TRANSPOSE]] : tensor<?x3xf32>, tensor<3x6xf32>) outs(%[[FILL]] : tensor<?x6xf32>) -> tensor<?x6xf32>
+  // CHECK: %[[ADDED:.+]] = linalg.generic {indexing_maps = [#[[$MAP2]], #[[$MAP1]], #[[$MAP1]]], iterator_types = ["parallel", "parallel"]} ins(%arg2, %[[MATMUL]] : tensor<6xf32>, tensor<?x6xf32>) outs(%[[INITB]] : tensor<?x6xf32>) {
+  // CHECK: ^bb0(%arg3: f32, %arg4: f32, %arg5: f32):
+  // CHECK:   %[[ADD:.+]] = addf %arg3, %arg4 : f32
+  // CHECK:   linalg.yield %[[ADD]] : f32
+
+  %0 = "tosa.fully_connected"(%arg0, %arg1, %arg2) : (tensor<?x3xf32>, tensor<6x3xf32>, tensor<6xf32>)  -> (tensor<?x6xf32>)
+  return %0 : tensor<?x6xf32>
+}
+
+// -----
+
 func @pad_float(%arg0 : tensor<1x2xf32>) -> (tensor<4x9xf32>) {
   %0 = constant dense<[[1, 2], [3, 4]]> : tensor<2x2xi32>
   // TODO: Output contains multiple "constant 1 : index".


        


More information about the Mlir-commits mailing list