[Mlir-commits] [mlir] 2a00891 - [mlir][linalg] Fix linalg.conv vectorization for mixed int-fp types

Rob Suderman llvmlistbot at llvm.org
Thu Jun 15 11:15:35 PDT 2023


Author: Rob Suderman
Date: 2023-06-15T11:13:18-07:00
New Revision: 2a00891107bc08225d5ba5183a672436db929752

URL: https://github.com/llvm/llvm-project/commit/2a00891107bc08225d5ba5183a672436db929752
DIFF: https://github.com/llvm/llvm-project/commit/2a00891107bc08225d5ba5183a672436db929752.diff

LOG: [mlir][linalg] Fix linalg.conv vectorization for mixed int-fp types

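We previously assumed that mixed-type operands were always of the same kind
(both integer or both floating-point). When an integer value must be promoted
to a floating-point type, we need SIToFP instead of ExtF or ExtSI.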

Reviewed By: dcaballe

Differential Revision: https://reviews.llvm.org/D152982

Added: 
    

Modified: 
    mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
    mlir/test/Dialect/Linalg/vectorize-convolution.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
index c953a1a1879b1..685567d1631be 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -2750,10 +2750,16 @@ struct Conv1DGenerator
     const Type dstType =
         cast<ShapedType>(val.getType()).cloneWith(std::nullopt, dstElementType);
 
-    if (isa<FloatType>(dstElementType) && srcWidth < dstWidth)
+    if (isa<IntegerType>(srcElementType) && isa<FloatType>(dstElementType)) {
+      return rewriter.create<arith::SIToFPOp>(loc, dstType, val);
+    }
+
+    if (isa<FloatType>(srcElementType) && isa<FloatType>(dstElementType) &&
+        srcWidth < dstWidth)
       return rewriter.create<arith::ExtFOp>(loc, dstType, val);
 
-    if (isa<IntegerType>(dstElementType) && srcWidth < dstWidth)
+    if (isa<IntegerType>(srcElementType) && isa<IntegerType>(dstElementType) &&
+        srcWidth < dstWidth)
       return rewriter.create<arith::ExtSIOp>(loc, dstType, val);
 
     assert(false && "unhandled promotion case");

diff  --git a/mlir/test/Dialect/Linalg/vectorize-convolution.mlir b/mlir/test/Dialect/Linalg/vectorize-convolution.mlir
index 88b2c2242ecd9..29dd016b803a4 100644
--- a/mlir/test/Dialect/Linalg/vectorize-convolution.mlir
+++ b/mlir/test/Dialect/Linalg/vectorize-convolution.mlir
@@ -631,6 +631,31 @@ func.func @conv_1d_nwc_wcf_mixed_type_memref(%input: memref<1x2x3xf16>, %filter:
 
 // -----
 
+func.func @conv_1d_nwc_wcf_mixed_int_fp_memref(%input: memref<1x2x3xi8>, %filter: memref<1x3x2xi8>, %output: memref<1x2x2xf32>) {
+  linalg.conv_1d_nwc_wcf
+  {dilations = dense<1> : vector<1xi64>, strides = dense<1> : vector<1xi64>}
+   ins(%input, %filter : memref<1x2x3xi8>, memref<1x3x2xi8>)
+   outs(%output : memref<1x2x2xf32>)
+  return
+}
+
+
+// CHECK-LABEL: func @conv_1d_nwc_wcf_mixed_int_fp_memref
+// CHECK-SAME:   (%[[INPUT:[0-9a-z]+]]: memref<1x2x3xi8>, %[[FILTER:[0-9a-z]+]]: memref<1x3x2xi8>, %[[OUTPUT:[0-9a-z]+]]: memref<1x2x2xf32>)
+// CHECK-DAG: %[[CST:.+]] = arith.constant 0.000000e+00 : f32
+// CHECK-DAG: %[[I0:.+]] = arith.constant 0 : index
+// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : i8
+// CHECK: %[[READ0:.+]] = vector.transfer_read %arg0[%[[I0]], %[[I0]], %[[I0]]], %[[C0]]
+// CHECK: %[[READ1:.+]] = vector.transfer_read %arg1[%[[I0]], %[[I0]], %[[I0]]], %[[C0]]
+// CHECK: %[[READ2:.+]] = vector.transfer_read %arg2[%[[I0]], %[[I0]], %[[I0]]], %[[CST]]
+// CHECK: %[[EXT:.+]] = vector.extract %[[READ1]][0] : vector<1x3x2xi8>
+// CHECK: %[[CAST0:.+]] = arith.sitofp %[[READ0]] : vector<1x2x3xi8> to vector<1x2x3xf32>
+// CHECK: %[[CAST1:.+]] = arith.sitofp %[[EXT]] : vector<3x2xi8> to vector<3x2xf32>
+// CHECK: %[[CONTRACT:.+]] = vector.contract {indexing_maps = [#map, #map1, #map2], iterator_types = ["parallel", "parallel", "parallel", "reduction"], kind = #vector.kind<add>} %[[CAST0]], %[[CAST1]], %[[READ2]]
+// CHECK: vector.transfer_write %[[CONTRACT]], %arg2[%[[I0]], %[[I0]], %[[I0]]]
+
+// -----
+
 func.func @pooling_nwc_sum_memref_1_2_1_3(%input: memref<4x4x3xf32>, %filter: memref<1xf32>, %output: memref<4x2x3xf32>) {
   linalg.pooling_nwc_sum
     {dilations = dense<1> : tensor<1xi64>, strides = dense<3> : tensor<1xi64>}


        


More information about the Mlir-commits mailing list