[Mlir-commits] [mlir] 7b007c0 - [mlir][tosa-to-linalg] Add acc_type lowering Support (#134267)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Mon Apr 7 03:37:29 PDT 2025


Author: Jack Frankland
Date: 2025-04-07T11:37:26+01:00
New Revision: 7b007c092d665bcb3f00ff937e04b20e6ec32c55

URL: https://github.com/llvm/llvm-project/commit/7b007c092d665bcb3f00ff937e04b20e6ec32c55
DIFF: https://github.com/llvm/llvm-project/commit/7b007c092d665bcb3f00ff937e04b20e6ec32c55.diff

LOG: [mlir][tosa-to-linalg] Add acc_type lowering Support (#134267)

Add support for lowering convolution operations whose `acc_type`
attribute differs from the result type of the operation. In the TOSA-v1.0
specification, the only such case for convolutions is an fp16
convolution that internally uses an fp32 accumulator; all other
operations have accumulator types that match their output/result types.

Add lit tests for the fp16 convolution with fp32 accumulator operators
described above.

Signed-off-by: Jack Frankland <jack.frankland at arm.com>

Added: 
    

Modified: 
    mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp
    mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp
index fc1cad2423450..86f5e9baf4a94 100644
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp
@@ -119,10 +119,11 @@ static AffineMap getBroadcastingMap(PatternRewriter &rewriter, Value source,
 }
 
 // Broadcast the source value to all the outer dimensions of the result value.
-// If required, the element type is expanded using an arith.extsi operation.
-static mlir::Value linalgBroadcastAndMaybeExtSI(PatternRewriter &rewriter,
-                                                Location loc, Value source,
-                                                Value result) {
+// If required, the element type is expanded using an arith.extsi or arith.extf
+// operation as appropriate.
+static mlir::Value linalgBroadcastAndMaybeExt(PatternRewriter &rewriter,
+                                              Location loc, Value source,
+                                              Value result) {
   ShapedType resultTy = cast<ShapedType>(result.getType());
   const int64_t resultRank = resultTy.getRank();
   // Creating maps for the input and output of the broacast-like generic op.
@@ -135,11 +136,16 @@ static mlir::Value linalgBroadcastAndMaybeExtSI(PatternRewriter &rewriter,
       .create<linalg::GenericOp>(
           loc, resultTy, ValueRange({source}), result, indexingMaps,
           getNParallelLoopsAttrs(resultTy.getRank()),
-          [](OpBuilder &builder, Location loc, ValueRange args) {
+          [&resultTy](OpBuilder &builder, Location loc, ValueRange args) {
             Value biasVal = args[0];
             Type resType = args[1].getType();
             if (resType != biasVal.getType()) {
-              biasVal = builder.create<arith::ExtSIOp>(loc, resType, biasVal);
+              biasVal =
+                  resultTy.getElementType().isFloat()
+                      ? builder.create<arith::ExtFOp>(loc, resType, biasVal)
+                            .getResult()
+                      : builder.create<arith::ExtSIOp>(loc, resType, biasVal)
+                            .getResult();
             }
             builder.create<linalg::YieldOp>(loc, biasVal);
           })
@@ -253,12 +259,14 @@ class ConvConverter : public OpConversionPattern<TosaConvOp> {
     ShapedType resultTy = cast<ShapedType>(op->getResult(0).getType());
 
     Type inputETy = inputTy.getElementType();
-    Type resultETy = resultTy.getElementType();
 
     DenseI64ArrayAttr padAttr = op.getPadAttr();
     DenseI64ArrayAttr strideTosaAttr = op.getStrideAttr();
     DenseI64ArrayAttr dilationTosaAttr = op.getDilationAttr();
 
+    Type accETy = op.getAccType();
+    Type accTy = RankedTensorType::get(resultTy.getShape(), accETy);
+
     // Get and verify zero points.
     FailureOr<int64_t> maybeIZp = op.getInputZeroPoint();
     if (failed(maybeIZp))
@@ -385,10 +393,10 @@ class ConvConverter : public OpConversionPattern<TosaConvOp> {
     auto dilationAttr = rewriter.getI64TensorAttr(dilation);
 
     Value biasEmptyTensor = rewriter.create<tensor::EmptyOp>(
-        loc, resultTy.getShape(), resultETy, filteredDims);
+        loc, resultTy.getShape(), accETy, filteredDims);
 
     Value broadcastBias =
-        linalgBroadcastAndMaybeExtSI(rewriter, loc, bias, biasEmptyTensor);
+        linalgBroadcastAndMaybeExt(rewriter, loc, bias, biasEmptyTensor);
 
     if (hasZp) {
       auto iZp = rewriter.getI32IntegerAttr(inputZpVal);
@@ -410,10 +418,15 @@ class ConvConverter : public OpConversionPattern<TosaConvOp> {
 
     Value conv = rewriter
                      .create<LinalgConvOp>(
-                         loc, resultTy, ValueRange{input, weight},
+                         loc, accTy, ValueRange{input, weight},
                          ValueRange{broadcastBias}, strideAttr, dilationAttr)
                      ->getResult(0);
 
+    // We may need to truncate back to the result type if the accumulator was
+    // wider than the result.
+    if (resultTy != accTy)
+      conv = rewriter.create<tosa::CastOp>(loc, resultTy, conv);
+
     rewriter.replaceOp(op, conv);
     return success();
   }
@@ -444,6 +457,8 @@ class DepthwiseConvConverter
     auto strideTosaAttr = cast<DenseI64ArrayAttr>(op->getAttr("stride"));
     auto dilationTosaAttr = cast<DenseI64ArrayAttr>(op->getAttr("dilation"));
 
+    Type accETy = op.getAccType();
+
     if (!weightTy.hasStaticShape() || !biasTy.hasStaticShape())
       return rewriter.notifyMatchFailure(
           op, "tosa.depthwise_conv ops require static shapes");
@@ -516,11 +531,11 @@ class DepthwiseConvConverter
     ShapedType linalgConvTy =
         RankedTensorType::get({resultShape[0], resultShape[1], resultShape[2],
                                weightShape[2], weightShape[3]},
-                              resultETy);
+                              accETy);
 
-    auto resultZeroAttr = rewriter.getZeroAttr(resultETy);
+    auto resultZeroAttr = rewriter.getZeroAttr(accETy);
     Value emptyTensor = rewriter.create<tensor::EmptyOp>(
-        loc, linalgConvTy.getShape(), resultETy, filteredDims);
+        loc, linalgConvTy.getShape(), accETy, filteredDims);
     Value zero = rewriter.create<arith::ConstantOp>(loc, resultZeroAttr);
     Value zeroTensor = rewriter
                            .create<linalg::FillOp>(loc, ValueRange{zero},
@@ -543,6 +558,15 @@ class DepthwiseConvConverter
                            ValueRange{zeroTensor}, strideAttr, dilationAttr)
                        .getResult(0);
 
+      // We may need to truncate back to the result type if the accumulator was
+      // wider than the result.
+      if (accETy != resultETy)
+        conv = rewriter.create<tosa::CastOp>(
+            loc,
+            RankedTensorType::get(cast<ShapedType>(conv.getType()).getShape(),
+                                  resultETy),
+            conv);
+
       SmallVector<ReassociationExprs, 4> reassociationMap;
       createDepthwiseConvCollapseMap(resultRank, reassociationMap, rewriter);
       Value convReshape = rewriter.create<tensor::CollapseShapeOp>(

diff  --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir
index 19c12ba3edbd4..242772fe5cdcf 100644
--- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir
+++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir
@@ -658,6 +658,20 @@ func.func @conv2d_quant(%arg0 : tensor<1x12x12x1xi8>, %arg1 : tensor<1024x3x3x1x
 
 // -----
 
+// CHECK-LABEL: @conv2d_f16_f32_acc
+func.func @conv2d_f16_f32_acc(%input: tensor<1x49x42x27xf16>, %weights: tensor<28x3x3x27xf16>, %bias: tensor<28xf16>) -> () {
+  %input_zp = "tosa.const"() <{values = dense<0.0> : tensor<1xf16>}> : () -> tensor<1xf16>
+  %weight_zp = "tosa.const"() <{values = dense<0.0> : tensor<1xf16>}> : () -> tensor<1xf16>
+  // CHECK: linalg.generic {{{.*}}} ins(%{{.*}} : tensor<28xf16>) outs(%{{.*}} : tensor<1x45x40x28xf32>)
+  // CHECK: arith.extf %{{.*}} : f16 to f32
+  // CHECK: %[[CONV:.*]] = linalg.conv_2d_nhwc_fhwc {{{.*}}} ins(%{{.*}}, %{{.*}} : tensor<1x49x42x27xf16>, tensor<28x3x3x27xf16>) outs(%{{.*}} : tensor<1x45x40x28xf32>) -> tensor<1x45x40x28xf32>
+  // CHECK: tosa.cast %[[CONV]] : (tensor<1x45x40x28xf32>) -> tensor<1x45x40x28xf16>
+  %0 = tosa.conv2d %input, %weights, %bias, %input_zp, %weight_zp {acc_type = f32, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>, dilation = array<i64: 2, 1>} : (tensor<1x49x42x27xf16>, tensor<28x3x3x27xf16>, tensor<28xf16>, tensor<1xf16>, tensor<1xf16>) -> tensor<1x45x40x28xf16>
+  return
+}
+
+// -----
+
 // CHECK: #[[$MAP0:.*]] = affine_map<(d0, d1, d2, d3) -> (d3)>
 // CHECK: #[[$MAP1:.*]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
 
@@ -848,6 +862,18 @@ func.func @depthwise_int_conv_zero_zp(%arg0 : tensor<1x7x5x3xi8>, %arg1 : tensor
 
 // -----
 
+// CHECK-LABEL: @depthwise_conv2d_f16_f32_acc
+func.func @depthwise_conv2d_f16_f32_acc(%arg0 : tensor<1x7x5x3xf16>, %arg1 : tensor<3x1x3x11xf16>, %arg2 : tensor<33xf16>) -> () {
+  %input_zp = "tosa.const"() <{values = dense<0.0> : tensor<1xf16>}> : () -> tensor<1xf16>
+  %weight_zp = "tosa.const"() <{values = dense<0.0> : tensor<1xf16>}> : () -> tensor<1xf16>
+  // CHECK: %[[CONV:.*]] = linalg.depthwise_conv_2d_nhwc_hwcm {{{.*}}} ins(%{{.*}}, %{{.*}} : tensor<1x7x5x3xf16>, tensor<3x1x3x11xf16>) outs(%{{.*}} : tensor<1x5x5x3x11xf32>) -> tensor<1x5x5x3x11xf32>
+  // CHECK: tosa.cast %[[CONV]] : (tensor<1x5x5x3x11xf32>) -> tensor<1x5x5x3x11xf16>
+  %2 = tosa.depthwise_conv2d %arg0, %arg1, %arg2, %input_zp, %weight_zp {acc_type = f32, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>, dilation = array<i64: 1, 1> } : (tensor<1x7x5x3xf16>, tensor<3x1x3x11xf16>, tensor<33xf16>, tensor<1xf16>, tensor<1xf16>) -> tensor<1x5x5x33xf16>
+  return
+}
+
+// -----
+
 // CHECK: #[[$MAP1:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d4)>
 // CHECK: #[[$MAP2:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3, d4)>
 
@@ -918,6 +944,20 @@ func.func @conv3d_i8(%input: tensor<1x49x48x47x27xi8>, %weights: tensor<28x3x4x5
 
 // -----
 
+// CHECK-LABEL: @conv3d_f16_f32_acc
+func.func @conv3d_f16_f32_acc(%input: tensor<1x49x48x47x27xf16>, %weights: tensor<28x3x4x5x27xf16>, %bias: tensor<28xf16>) -> () {
+  %input_zp = "tosa.const"() <{values = dense<0.0> : tensor<1xf16>}> : () -> tensor<1xf16>
+  %weight_zp = "tosa.const"() <{values = dense<0.0> : tensor<1xf16>}> : () -> tensor<1xf16>
+  // CHECK: linalg.generic {{{.*}}} ins(%{{.*}} : tensor<28xf16>) outs(%{{.*}} : tensor<1x47x45x43x28xf32>)
+  // CHECK: arith.extf %{{.*}} : f16 to f32
+  // CHECK: %[[CONV:.*]] = linalg.conv_3d_ndhwc_dhwcf {{{.*}}} ins(%{{.*}}, %{{.*}} : tensor<1x49x48x47x27xf16>, tensor<3x4x5x27x28xf16>) outs(%{{.*}} : tensor<1x47x45x43x28xf32>) -> tensor<1x47x45x43x28xf32>
+  // CHECK: tosa.cast %[[CONV]] : (tensor<1x47x45x43x28xf32>) -> tensor<1x47x45x43x28xf16>
+  %0 = tosa.conv3d %input, %weights, %bias, %input_zp, %weight_zp {acc_type = f32, pad = array<i64: 0, 0, 0, 0, 0, 0>, stride = array<i64: 1, 1, 1>, dilation = array<i64: 1, 1, 1>} : (tensor<1x49x48x47x27xf16>, tensor<28x3x4x5x27xf16>, tensor<28xf16>, tensor<1xf16>, tensor<1xf16>) -> tensor<1x47x45x43x28xf16>
+  return
+}
+
+// -----
+
 // CHECK-LABEL: @test_transpose
 // CHECK-SAME: (%[[ARG0:.+]]: tensor<1x2x3xi32>)
 func.func @test_transpose(%arg0: tensor<1x2x3xi32>) -> () {


        


More information about the Mlir-commits mailing list