[Mlir-commits] [mlir] 4c79766 - [mlir][linalg] Convert input type to accumulator type in im2col patterns
Quinn Dawkins
llvmlistbot at llvm.org
Fri Mar 24 15:42:03 PDT 2023
Author: Quinn Dawkins
Date: 2023-03-24T18:35:06-04:00
New Revision: 4c79766689f83055858acbdc2c9f5d652d0a46c8
URL: https://github.com/llvm/llvm-project/commit/4c79766689f83055858acbdc2c9f5d652d0a46c8
DIFF: https://github.com/llvm/llvm-project/commit/4c79766689f83055858acbdc2c9f5d652d0a46c8.diff
LOG: [mlir][linalg] Convert input type to accumulator type in im2col patterns
When the input types don't match the accumulator type in named
convolution ops, the inputs are supposed to be converted to the accumulator
type before the multiply and accumulate.
Differential Revision: https://reviews.llvm.org/D146824
Added:
Modified:
mlir/lib/Dialect/Linalg/Transforms/ConvertConv2DToImg2Col.cpp
mlir/test/Dialect/Linalg/convert-conv2d-to-img2col.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/Linalg/Transforms/ConvertConv2DToImg2Col.cpp b/mlir/lib/Dialect/Linalg/Transforms/ConvertConv2DToImg2Col.cpp
index 491c533fc85e..3341f8a91644 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/ConvertConv2DToImg2Col.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/ConvertConv2DToImg2Col.cpp
@@ -35,11 +35,16 @@ static Value createAdd(Location loc, Value x, Value y, OpBuilder &builder) {
return builder.create<arith::AddFOp>(loc, x, y);
}
-static Value createMul(Location loc, Value x, Value y, OpBuilder &builder) {
- bool isInt = x.getType().isa<IntegerType>();
- if (isInt)
- return builder.create<arith::MulIOp>(loc, x, y);
- return builder.create<arith::MulFOp>(loc, x, y);
+static Value createMul(Location loc, Value x, Value y, Type accType,
+ OpBuilder &builder) {
+ // Linalg named ops convert both operands to accType as signed values.
+ Value xConvert =
+ convertScalarToDtype(builder, loc, x, accType, /*isUnsignedCast=*/false);
+ Value yConvert =
+ convertScalarToDtype(builder, loc, y, accType, /*isUnsignedCast=*/false);
+ if (accType.isa<IntegerType>())
+ return builder.create<arith::MulIOp>(loc, xConvert, yConvert);
+ return builder.create<arith::MulFOp>(loc, xConvert, yConvert);
}
// Delinearizes the given composite `index` by the basis specified in `factors`.
@@ -185,7 +190,8 @@ rewriteInIm2Col(RewriterBase &rewriter, linalg::Conv2DNhwcHwcfOp convOp) {
/*outputs=*/ValueRange{reshapedOutput},
ArrayRef<AffineMap>{lhsMap, rhsMap, resultMap}, genericIterators,
[&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange args) {
- Value mul = createMul(loc, args[0], args[1], nestedBuilder);
+ Value mul =
+ createMul(loc, args[0], args[1], args[2].getType(), nestedBuilder);
Value add = createAdd(loc, mul, args[2], nestedBuilder);
nestedBuilder.create<linalg::YieldOp>(nestedLoc, add);
});
@@ -468,7 +474,8 @@ rewriteInIm2Col(RewriterBase &rewriter, linalg::Conv2DNchwFchwOp convOp) {
/*outputs=*/ValueRange{reshapedOutput},
ArrayRef<AffineMap>{lhsMap, rhsMap, resultMap}, genericIterators,
[&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange args) {
- Value mul = createMul(loc, args[0], args[1], nestedBuilder);
+ Value mul =
+ createMul(loc, args[0], args[1], args[2].getType(), nestedBuilder);
Value add = createAdd(loc, mul, args[2], nestedBuilder);
nestedBuilder.create<linalg::YieldOp>(nestedLoc, add);
});
diff --git a/mlir/test/Dialect/Linalg/convert-conv2d-to-img2col.mlir b/mlir/test/Dialect/Linalg/convert-conv2d-to-img2col.mlir
index 38c63490cf44..4888d4aca593 100644
--- a/mlir/test/Dialect/Linalg/convert-conv2d-to-img2col.mlir
+++ b/mlir/test/Dialect/Linalg/convert-conv2d-to-img2col.mlir
@@ -276,3 +276,41 @@ transform.sequence failures(propagate) {
%0 = transform.structured.match ops{["linalg.conv_2d_nchw_fchw"]} in %arg1 : (!pdl.operation) -> !pdl.operation
%1:2 = transform.structured.convert_conv2d_to_img2col %0 : (!pdl.operation) -> (!pdl.operation, !pdl.operation)
}
+
+// -----
+
+// Check that the inputs are sign-extended when the input type is narrower than the accumulator type.
+// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>
+// CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0, d1, d2, d3) -> (d3, d2)>
+// CHECK-DAG: #[[MAP3:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
+// CHECK: @conv_integer_extend
+// CHECK: %[[COLLAPSED_OUT:.+]] = tensor.collapse_shape %{{.*}} {{\[}}[0], [1, 2], [3]] : tensor<1x14x14x16xi32> into tensor<1x196x16xi32>
+// CHECK: %[[MATMUL_RESULT:.+]] = linalg.generic {indexing_maps = [#[[MAP1]], #[[MAP2]], #[[MAP3]]]
+// CHECK-SAME: ins(%{{.*}}, %{{.*}} : tensor<1x196x36xi8>, tensor<36x16xi8>)
+// CHECK-SAME: outs(%[[COLLAPSED_OUT]] : tensor<1x196x16xi32>)
+// CHECK: ^bb0(%[[ARG0:.+]]: i8, %[[ARG1:.+]]: i8, %[[ARG2:.+]]: i32)
+// CHECK: %[[EXT0:.+]] = arith.extsi %[[ARG0]] : i8 to i32
+// CHECK: %[[EXT1:.+]] = arith.extsi %[[ARG1]] : i8 to i32
+// CHECK: %[[MUL:.+]] = arith.muli %[[EXT0]], %[[EXT1]] : i32
+// CHECK: %[[ADD:.+]] = arith.addi %[[MUL]], %[[ARG2]] : i32
+// CHECK: linalg.yield %[[ADD]] : i32
+// CHECK: } -> tensor<1x196x16xi32>
+// CHECK: %[[RESULT:.+]] = tensor.expand_shape %[[MATMUL_RESULT]] {{\[}}[0], [1, 2], [3]] : tensor<1x196x16xi32> into tensor<1x14x14x16xi32>
+// CHECK: return %[[RESULT]]
+
+func.func @conv_integer_extend(%arg0: tensor<1x16x16x4xi8>, %arg1: tensor<3x3x4x16xi8>, %arg2: tensor<1x14x14x16xi32>) -> tensor<1x14x14x16xi32> {
+ %0 = linalg.conv_2d_nhwc_hwcf
+ {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64> }
+ ins(%arg0, %arg1: tensor<1x16x16x4xi8>, tensor<3x3x4x16xi8>)
+ outs(%arg2: tensor<1x14x14x16xi32>) -> tensor<1x14x14x16xi32>
+ return %0 : tensor<1x14x14x16xi32>
+}
+
+transform.sequence failures(propagate) {
+^bb1(%arg1: !pdl.operation):
+ // Only the final rewritten IR is checked for this case; printing the
+ // intermediate handles would add IR that no CHECK line here verifies.
+ %0 = transform.structured.match ops{["linalg.conv_2d_nhwc_hwcf"]} in %arg1 : (!pdl.operation) -> !pdl.operation
+ %1:2 = transform.structured.convert_conv2d_to_img2col %0 : (!pdl.operation) -> (!pdl.operation, !pdl.operation)
+}
More information about the Mlir-commits
mailing list