[Mlir-commits] [mlir] 4c79766 - [mlir][linalg] Convert input type to accumulator type in im2col patterns

Quinn Dawkins llvmlistbot at llvm.org
Fri Mar 24 15:42:03 PDT 2023


Author: Quinn Dawkins
Date: 2023-03-24T18:35:06-04:00
New Revision: 4c79766689f83055858acbdc2c9f5d652d0a46c8

URL: https://github.com/llvm/llvm-project/commit/4c79766689f83055858acbdc2c9f5d652d0a46c8
DIFF: https://github.com/llvm/llvm-project/commit/4c79766689f83055858acbdc2c9f5d652d0a46c8.diff

LOG: [mlir][linalg] Convert input type to accumulator type in im2col patterns

When the input types of a named convolution op don't match its
accumulator type, the inputs are supposed to be converted to the
accumulator type before the multiply-accumulate.

Differential Revision: https://reviews.llvm.org/D146824

Added: 
    

Modified: 
    mlir/lib/Dialect/Linalg/Transforms/ConvertConv2DToImg2Col.cpp
    mlir/test/Dialect/Linalg/convert-conv2d-to-img2col.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/Linalg/Transforms/ConvertConv2DToImg2Col.cpp b/mlir/lib/Dialect/Linalg/Transforms/ConvertConv2DToImg2Col.cpp
index 491c533fc85e..3341f8a91644 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/ConvertConv2DToImg2Col.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/ConvertConv2DToImg2Col.cpp
@@ -35,11 +35,16 @@ static Value createAdd(Location loc, Value x, Value y, OpBuilder &builder) {
   return builder.create<arith::AddFOp>(loc, x, y);
 }
 
-static Value createMul(Location loc, Value x, Value y, OpBuilder &builder) {
-  bool isInt = x.getType().isa<IntegerType>();
-  if (isInt)
-    return builder.create<arith::MulIOp>(loc, x, y);
-  return builder.create<arith::MulFOp>(loc, x, y);
+static Value createMul(Location loc, Value x, Value y, Type accType,
+                       OpBuilder &builder) {
+  // Cast both operands to accType; Linalg named ops use signed extension.
+  Value xConvert =
+      convertScalarToDtype(builder, loc, x, accType, /*isUnsignedCast=*/false);
+  Value yConvert =
+      convertScalarToDtype(builder, loc, y, accType, /*isUnsignedCast=*/false);
+  if (accType.isa<IntegerType>())
+    return builder.create<arith::MulIOp>(loc, xConvert, yConvert);
+  return builder.create<arith::MulFOp>(loc, xConvert, yConvert);
 }
 
 // Delinearizes the given composite `index` by the basis specified in `factors`.
@@ -185,7 +190,8 @@ rewriteInIm2Col(RewriterBase &rewriter, linalg::Conv2DNhwcHwcfOp convOp) {
       /*outputs=*/ValueRange{reshapedOutput},
       ArrayRef<AffineMap>{lhsMap, rhsMap, resultMap}, genericIterators,
       [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange args) {
-        Value mul = createMul(loc, args[0], args[1], nestedBuilder);
+        Value mul =
+            createMul(loc, args[0], args[1], args[2].getType(), nestedBuilder);
         Value add = createAdd(loc, mul, args[2], nestedBuilder);
         nestedBuilder.create<linalg::YieldOp>(nestedLoc, add);
       });
@@ -468,7 +474,8 @@ rewriteInIm2Col(RewriterBase &rewriter, linalg::Conv2DNchwFchwOp convOp) {
       /*outputs=*/ValueRange{reshapedOutput},
       ArrayRef<AffineMap>{lhsMap, rhsMap, resultMap}, genericIterators,
       [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange args) {
-        Value mul = createMul(loc, args[0], args[1], nestedBuilder);
+        Value mul =
+            createMul(loc, args[0], args[1], args[2].getType(), nestedBuilder);
         Value add = createAdd(loc, mul, args[2], nestedBuilder);
         nestedBuilder.create<linalg::YieldOp>(nestedLoc, add);
       });

diff  --git a/mlir/test/Dialect/Linalg/convert-conv2d-to-img2col.mlir b/mlir/test/Dialect/Linalg/convert-conv2d-to-img2col.mlir
index 38c63490cf44..4888d4aca593 100644
--- a/mlir/test/Dialect/Linalg/convert-conv2d-to-img2col.mlir
+++ b/mlir/test/Dialect/Linalg/convert-conv2d-to-img2col.mlir
@@ -276,3 +276,41 @@ transform.sequence failures(propagate) {
   %0 = transform.structured.match ops{["linalg.conv_2d_nchw_fchw"]} in %arg1 : (!pdl.operation) -> !pdl.operation
   %1:2 = transform.structured.convert_conv2d_to_img2col %0 : (!pdl.operation) -> (!pdl.operation, !pdl.operation)
 }
+
+// -----
+
+// Check that i8 inputs are sign-extended to the i32 accumulator type before the multiply.
+
+// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>
+// CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0, d1, d2, d3) -> (d3, d2)>
+// CHECK-DAG: #[[MAP3:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
+//      CHECK: @conv_integer_extend
+//      CHECK: %[[COLLAPSED_OUT:.+]] = tensor.collapse_shape %{{.*}} {{\[}}[0], [1, 2], [3]] : tensor<1x14x14x16xi32> into tensor<1x196x16xi32>
+//      CHECK: %[[MATMUL_RESULT:.+]] = linalg.generic {indexing_maps = [#[[MAP1]], #[[MAP2]], #[[MAP3]]]
+//           CHECK-SAME: ins(%{{.*}}, %{{.*}} : tensor<1x196x36xi8>, tensor<36x16xi8>)
+//           CHECK-SAME: outs(%[[COLLAPSED_OUT]] : tensor<1x196x16xi32>)
+//                CHECK: ^bb0(%[[ARG0:.+]]: i8, %[[ARG1:.+]]: i8, %[[ARG2:.+]]: i32)
+//                CHECK:     %[[EXT0:.+]] = arith.extsi %[[ARG0]] : i8 to i32
+//                CHECK:     %[[EXT1:.+]] = arith.extsi %[[ARG1]] : i8 to i32
+//                CHECK:     %[[MUL:.+]] = arith.muli %[[EXT0]], %[[EXT1]] : i32
+//                CHECK:     %[[ADD:.+]] = arith.addi %[[MUL]], %[[ARG2]] : i32
+//                CHECK:     linalg.yield %[[ADD]] : i32
+//                CHECK: } -> tensor<1x196x16xi32>
+//      CHECK: %[[RESULT:.+]] = tensor.expand_shape %[[MATMUL_RESULT]] {{\[}}[0], [1, 2], [3]] : tensor<1x196x16xi32> into tensor<1x14x14x16xi32>
+//      CHECK: return %[[RESULT]]
+
+func.func @conv_integer_extend(%arg0: tensor<1x16x16x4xi8>, %arg1: tensor<3x3x4x16xi8>, %arg2: tensor<1x14x14x16xi32>) -> tensor<1x14x14x16xi32> {
+    %0 = linalg.conv_2d_nhwc_hwcf
+      {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64> }
+       ins(%arg0, %arg1: tensor<1x16x16x4xi8>, tensor<3x3x4x16xi8>)
+      outs(%arg2: tensor<1x14x14x16xi32>) -> tensor<1x14x14x16xi32>
+    return %0 : tensor<1x14x14x16xi32>
+}
+
+transform.sequence failures(propagate) {
+^bb1(%arg1: !pdl.operation):
+  %0 = transform.structured.match ops{["linalg.conv_2d_nhwc_hwcf"]} in %arg1 : (!pdl.operation) -> !pdl.operation
+  %img2col_tensor_producer, %transformed = transform.structured.convert_conv2d_to_img2col %0 : (!pdl.operation) -> (!pdl.operation, !pdl.operation)
+  transform.print %img2col_tensor_producer {name = "tensor_producer"}: !pdl.operation
+  transform.print %transformed {name = "transformed"}: !pdl.operation
+}


        


More information about the Mlir-commits mailing list