[Mlir-commits] [mlir] cc1ae54 - [tosa][mlir] Fix FullyConnected to correctly order dimensions

Rob Suderman llvmlistbot at llvm.org
Tue Apr 27 17:33:22 PDT 2021


Author: Rob Suderman
Date: 2021-04-27T17:26:04-07:00
New Revision: cc1ae54ebcc4072ce19b610511f500f15c7acd8a

URL: https://github.com/llvm/llvm-project/commit/cc1ae54ebcc4072ce19b610511f500f15c7acd8a
DIFF: https://github.com/llvm/llvm-project/commit/cc1ae54ebcc4072ce19b610511f500f15c7acd8a.diff

LOG: [tosa][mlir] Fix FullyConnected to correctly order dimensions

MatMul and FullyConnected use transposed layouts for the weights, so the
FullyConnected weights are now transposed before being passed to linalg.matmul.
Also, removed the unneeded tensor reshape of the bias.

Differential Revision: https://reviews.llvm.org/D101220

Added: 
    

Modified: 
    mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
    mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
index 51de267170ad..b31656467c48 100644
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
@@ -745,26 +745,28 @@ class FullyConnectedConverter
   LogicalResult
   matchAndRewrite(tosa::FullyConnectedOp op, ArrayRef<Value> args,
                   ConversionPatternRewriter &rewriter) const final {
-    tosa::FullyConnectedOp::Adaptor adaptor(args);
-
     Location loc = op.getLoc();
     auto outputTy = op.getType().cast<ShapedType>();
-    auto biasTy = op->getOperand(2).getType().cast<ShapedType>();
+    auto input = op.input();
+    auto weight = op.weight();
+    auto bias = op.bias();
 
-    // Reshaping the bias from n to [1, n] for broadcasting
-    SmallVector<int64_t> biasShapeReshaped;
-    biasShapeReshaped.push_back(1);
-    biasShapeReshaped.push_back(biasTy.getShape()[0]);
+    auto weightTy = weight.getType().cast<ShapedType>();
+    auto biasTy = bias.getType().cast<ShapedType>();
 
-    RankedTensorType reshapedBias =
-        RankedTensorType::get(biasShapeReshaped, outputTy.getElementType());
-    auto reshapeResult =
-        rewriter.create<tosa::ReshapeOp>(loc, reshapedBias, args[2])
-            ->getResult(0);
+    auto weightShape = weightTy.getShape();
+
+    if (op.quantization_info())
+      return failure();
 
     // Creating maps for the output of MatMul and the bias
     SmallVector<AffineMap, 4> indexingMaps;
-    indexingMaps.push_back(createAffineMapForType(reshapedBias, rewriter));
+
+    // Broadcast the bias.
+    indexingMaps.push_back(AffineMap::get(/*dimCount=*/2, /*symbolCount=*/0,
+                                          {rewriter.getAffineDimExpr(1)},
+                                          rewriter.getContext()));
+
     indexingMaps.push_back(rewriter.getMultiDimIdentityMap(outputTy.getRank()));
 
     auto initTensor =
@@ -776,7 +778,7 @@ class FullyConnectedConverter
     auto linalgOp =
         rewriter
             .create<linalg::GenericOp>(
-                loc, outputTy, reshapeResult, initTensor, indexingMaps,
+                loc, outputTy, bias, initTensor, indexingMaps,
                 getNParallelLoopsAttrs(outputTy.getRank()),
                 [&](OpBuilder &nested_builder, Location nested_loc,
                     ValueRange args) {
@@ -784,9 +786,21 @@ class FullyConnectedConverter
                 })
             ->getResults();
 
+    SmallVector<int64_t> permutation{1, 0};
+    auto permutationAttr = DenseIntElementsAttr::get(
+        RankedTensorType::get({2}, rewriter.getI64Type()), permutation);
+    Value permutationValue = rewriter.create<ConstantOp>(loc, permutationAttr);
+
+    SmallVector<int64_t> newWeightShape{weightShape[1], weightShape[0]};
+    Type newWeightTy =
+        RankedTensorType::get(newWeightShape, biasTy.getElementType());
+
+    Value transposedWeight = rewriter.create<tosa::TransposeOp>(
+        loc, newWeightTy, weight, permutationValue);
+
     rewriter.replaceOpWithNewOp<linalg::MatmulOp>(
-        op, TypeRange{op.getType()},
-        ValueRange{adaptor.input(), adaptor.weight()}, linalgOp);
+        op, TypeRange{op.getType()}, ValueRange{input, transposedWeight},
+        linalgOp);
     return success();
   }
 };

diff  --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
index 3f9940e4a203..57d8c86fa25b 100644
--- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
+++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
@@ -756,17 +756,22 @@ func @matmul(%arg0: tensor<5x3xf32>, %arg1: tensor<3x6xf32>, %arg2: tensor<6xf32
 
 // -----
 
-// CHECK: #[[$MAP0:.*]] = affine_map<(d0, d1) -> (d0, d1)>
-// CHECK: #[[$MAP1:.*]] = affine_map<(d0, d1) -> (0, d1)>
+// CHECK: #[[$MAP0:.*]] = affine_map<(d0, d1) -> (d1)>
+// CHECK: #[[$MAP1:.*]] = affine_map<(d0, d1) -> (d0, d1)>
+// CHECK: #[[$MAP2:.*]] = affine_map<(d0, d1) -> (d1, d0)>
 
 // CHECK-LABEL: @fully_connected
-func @fully_connected(%arg0: tensor<5x3xf32>, %arg1: tensor<3x6xf32>, %arg2: tensor<6xf32>) -> (tensor<5x6xf32>) {
-  // CHECK: [[RS:%.+]] = linalg.tensor_reshape %arg2 [#[[$MAP0]]]
-  // CHECK: [[INIT:%.+]] = linalg.init_tensor [5, 6]
-  // CHECK: [[GENERIC:%.+]] = linalg.generic {indexing_maps = [#[[$MAP1]], #[[$MAP0]]], iterator_types = ["parallel", "parallel"]} ins([[RS]] : tensor<1x6xf32>) outs([[INIT]] : tensor<5x6xf32>) {
-  // CHECK: ^bb0([[IN:%.+]]: f32, [[MULTIPLIER:%.+]]: f32):
-  // CHECK: linalg.matmul ins(%arg0, %arg1 : tensor<5x3xf32>, tensor<3x6xf32>) outs([[GENERIC]] : tensor<5x6xf32>) -> tensor<5x6xf32>
-  %0 = "tosa.fully_connected"(%arg0, %arg1, %arg2) : (tensor<5x3xf32>, tensor<3x6xf32>, tensor<6xf32>)  -> (tensor<5x6xf32>)
+func @fully_connected(%arg0: tensor<5x3xf32>, %arg1: tensor<6x3xf32>, %arg2: tensor<6xf32>) -> (tensor<5x6xf32>) {
+  // CHECK: [[INITB:%.+]] = linalg.init_tensor [5, 6]
+  // CHECK: [[GENERIC:%.+]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]]], iterator_types = ["parallel", "parallel"]} ins(%arg2 : tensor<6xf32>) outs([[INITB]] : tensor<5x6xf32>) {
+  // CHECK: ^bb0([[IN:%.+]]: f32, [[UNUSED:%.+]]: f32):
+  // CHECK:   linalg.yield [[IN]] : f32
+  // CHECK: [[INITT:%.+]] = linalg.init_tensor [3, 6]
+  // CHECK: [[TRANSPOSE:%.+]] = linalg.generic {indexing_maps = [#[[$MAP2]], #[[$MAP1]]], iterator_types = ["parallel", "parallel"]} ins(%arg1 : tensor<6x3xf32>) outs([[INITT]]
+  // CHECK: ^bb0([[IN:%.+]]: f32, [[UNUSED:%.+]]: f32):
+  // CHECK:   linalg.yield [[IN]] : f32
+  // CHECK: linalg.matmul ins(%arg0, [[TRANSPOSE]] : tensor<5x3xf32>, tensor<3x6xf32>) outs([[GENERIC]] : tensor<5x6xf32>) -> tensor<5x6xf32>
+  %0 = "tosa.fully_connected"(%arg0, %arg1, %arg2) : (tensor<5x3xf32>, tensor<6x3xf32>, tensor<6xf32>)  -> (tensor<5x6xf32>)
   return %0 : tensor<5x6xf32>
 }
 


        


More information about the Mlir-commits mailing list