[Mlir-commits] [mlir] 2a00891 - [mlir][linalg] Fix linalg.conv vectorization for mixed int-fp types
Rob Suderman
llvmlistbot at llvm.org
Thu Jun 15 11:15:35 PDT 2023
Author: Rob Suderman
Date: 2023-06-15T11:13:18-07:00
New Revision: 2a00891107bc08225d5ba5183a672436db929752
URL: https://github.com/llvm/llvm-project/commit/2a00891107bc08225d5ba5183a672436db929752
DIFF: https://github.com/llvm/llvm-project/commit/2a00891107bc08225d5ba5183a672436db929752.diff
LOG: [mlir][linalg] Fix linalg.conv vectorization for mixed int-fp types
We previously assumed both operands had the same element type. Instead of ExtF or
ExtSI, we need SIToFP when integer values must be promoted to floating point.
Reviewed By: dcaballe
Differential Revision: https://reviews.llvm.org/D152982
Added:
Modified:
mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
mlir/test/Dialect/Linalg/vectorize-convolution.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
index c953a1a1879b1..685567d1631be 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -2750,10 +2750,16 @@ struct Conv1DGenerator
const Type dstType =
cast<ShapedType>(val.getType()).cloneWith(std::nullopt, dstElementType);
- if (isa<FloatType>(dstElementType) && srcWidth < dstWidth)
+ if (isa<IntegerType>(srcElementType) && isa<FloatType>(dstElementType)) {
+ return rewriter.create<arith::SIToFPOp>(loc, dstType, val);
+ }
+
+ if (isa<FloatType>(srcElementType) && isa<FloatType>(dstElementType) &&
+ srcWidth < dstWidth)
return rewriter.create<arith::ExtFOp>(loc, dstType, val);
- if (isa<IntegerType>(dstElementType) && srcWidth < dstWidth)
+ if (isa<IntegerType>(srcElementType) && isa<IntegerType>(dstElementType) &&
+ srcWidth < dstWidth)
return rewriter.create<arith::ExtSIOp>(loc, dstType, val);
assert(false && "unhandled promotion case");
diff --git a/mlir/test/Dialect/Linalg/vectorize-convolution.mlir b/mlir/test/Dialect/Linalg/vectorize-convolution.mlir
index 88b2c2242ecd9..29dd016b803a4 100644
--- a/mlir/test/Dialect/Linalg/vectorize-convolution.mlir
+++ b/mlir/test/Dialect/Linalg/vectorize-convolution.mlir
@@ -631,6 +631,31 @@ func.func @conv_1d_nwc_wcf_mixed_type_memref(%input: memref<1x2x3xf16>, %filter:
// -----
+func.func @conv_1d_nwc_wcf_mixed_int_fp_memref(%input: memref<1x2x3xi8>, %filter: memref<1x3x2xi8>, %output: memref<1x2x2xf32>) {
+ linalg.conv_1d_nwc_wcf
+ {dilations = dense<1> : vector<1xi64>, strides = dense<1> : vector<1xi64>}
+ ins(%input, %filter : memref<1x2x3xi8>, memref<1x3x2xi8>)
+ outs(%output : memref<1x2x2xf32>)
+ return
+}
+
+
+// CHECK-LABEL: func @conv_1d_nwc_wcf_mixed_int_fp_memref
+// CHECK-SAME: (%[[INPUT:[0-9a-z]+]]: memref<1x2x3xi8>, %[[FILTER:[0-9a-z]+]]: memref<1x3x2xi8>, %[[OUTPUT:[0-9a-z]+]]: memref<1x2x2xf32>)
+// CHECK-DAG: %[[CST:.+]] = arith.constant 0.000000e+00 : f32
+// CHECK-DAG: %[[I0:.+]] = arith.constant 0 : index
+// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : i8
+// CHECK: %[[READ0:.+]] = vector.transfer_read %arg0[%[[I0]], %[[I0]], %[[I0]]], %[[C0]]
+// CHECK: %[[READ1:.+]] = vector.transfer_read %arg1[%[[I0]], %[[I0]], %[[I0]]], %[[C0]]
+// CHECK: %[[READ2:.+]] = vector.transfer_read %arg2[%[[I0]], %[[I0]], %[[I0]]], %[[CST]]
+// CHECK: %[[EXT:.+]] = vector.extract %[[READ1]][0] : vector<1x3x2xi8>
+// CHECK: %[[CAST0:.+]] = arith.sitofp %[[READ0]] : vector<1x2x3xi8> to vector<1x2x3xf32>
+// CHECK: %[[CAST1:.+]] = arith.sitofp %[[EXT]] : vector<3x2xi8> to vector<3x2xf32>
+// CHECK: %[[CONTRACT:.+]] = vector.contract {indexing_maps = [#map, #map1, #map2], iterator_types = ["parallel", "parallel", "parallel", "reduction"], kind = #vector.kind<add>} %[[CAST0]], %[[CAST1]], %[[READ2]]
+// CHECK: vector.transfer_write %[[CONTRACT]], %arg2[%[[I0]], %[[I0]], %[[I0]]]
+
+// -----
+
func.func @pooling_nwc_sum_memref_1_2_1_3(%input: memref<4x4x3xf32>, %filter: memref<1xf32>, %output: memref<4x2x3xf32>) {
linalg.pooling_nwc_sum
{dilations = dense<1> : tensor<1xi64>, strides = dense<3> : tensor<1xi64>}
More information about the Mlir-commits
mailing list