[Mlir-commits] [mlir] 7517e24 - [mlir][linalg] Promote operands for convolution vectorization

Lei Zhang llvmlistbot at llvm.org
Mon Apr 17 16:37:24 PDT 2023


Author: Lei Zhang
Date: 2023-04-17T16:37:06-07:00
New Revision: 7517e246aca2c410dba01ff1419596d7eef4a7e6

URL: https://github.com/llvm/llvm-project/commit/7517e246aca2c410dba01ff1419596d7eef4a7e6
DIFF: https://github.com/llvm/llvm-project/commit/7517e246aca2c410dba01ff1419596d7eef4a7e6.diff

LOG: [mlir][linalg] Promote operands for convolution vectorization

We are already doing this for depthwise convolution and pooling.
This helps preserve the promotion semantics from the Linalg op
definitions down to lower layers.

Along the way, this fixes the type mismatch in the existing
`promote` implementation: the extension ops were created with the
destination type as-is, which has the wrong shape whenever the
operand and destination shapes differ; the helper now only swaps
in the destination element type while keeping the operand's shape.

Reviewed By: kuhar

Differential Revision: https://reviews.llvm.org/D148471
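
For illustration, a minimal sketch of the IR this now produces for the
mixed-precision i8 -> i32 case (mirroring the
conv1d_nwc_4x2x8_i8i8i32_memref test updated below); the SSA names
%lhs, %rhs, and %acc are made up for readability, and the indexing
maps are written inline rather than via #map aliases:

  // Both contraction operands are sign-extended to the i32 element
  // type of the accumulator before the contraction is formed.
  %lhs_ext = arith.extsi %lhs : vector<4x1x3xi8> to vector<4x1x3xi32>
  %rhs_ext = arith.extsi %rhs : vector<3x8xi8> to vector<3x8xi32>
  %res = vector.contract {
           indexing_maps = [affine_map<(n, w, f, c) -> (n, w, c)>,
                            affine_map<(n, w, f, c) -> (c, f)>,
                            affine_map<(n, w, f, c) -> (n, w, f)>],
           iterator_types = ["parallel", "parallel", "parallel", "reduction"]
         } %lhs_ext, %rhs_ext, %acc
       : vector<4x1x3xi32>, vector<3x8xi32> into vector<4x1x8xi32>

Floating-point operands take the arith.extf path instead, and operands
that already match the accumulator element type pass through unchanged.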

Added: 
    

Modified: 
    mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
    mlir/test/Dialect/Linalg/vectorize-convolution.mlir

Removed: 
    


################################################################################
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
index 14726a8c579f9..eff424f021eb6 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -2512,6 +2512,29 @@ struct Conv1DGenerator
         .getOperation();
   }
 
+  // Take a value and widen to have the same element type as `ty`.
+  Value promote(RewriterBase &rewriter, Location loc, Value val, Type ty) {
+    const Type srcElementType = getElementTypeOrSelf(val.getType());
+    const Type dstElementType = getElementTypeOrSelf(ty);
+    assert(isa<IntegerType>(dstElementType) || isa<FloatType>(dstElementType));
+    if (srcElementType == dstElementType)
+      return val;
+
+    const int64_t srcWidth = srcElementType.getIntOrFloatBitWidth();
+    const int64_t dstWidth = dstElementType.getIntOrFloatBitWidth();
+    const Type dstType =
+        cast<ShapedType>(val.getType()).cloneWith(std::nullopt, dstElementType);
+
+    if (isa<FloatType>(dstElementType) && srcWidth < dstWidth)
+      return rewriter.create<arith::ExtFOp>(loc, dstType, val);
+
+    if (isa<IntegerType>(dstElementType) && srcWidth < dstWidth)
+      return rewriter.create<arith::ExtSIOp>(loc, dstType, val);
+
+    assert(false && "unhandled promotion case");
+    return nullptr;
+  }
+
   // Create a contraction: lhs{n, w, c} * rhs{c, f} -> res{n, w, f}
   Value conv1dSliceAsContraction(RewriterBase &rewriter, Location loc,
                                  Value lhs, Value rhs, Value res) {
@@ -2519,6 +2542,8 @@ struct Conv1DGenerator
     vector::IteratorType red = vector::IteratorType::reduction;
     AffineExpr n, w, f, c;
     bindDims(ctx, n, w, f, c);
+    lhs = promote(rewriter, loc, lhs, res.getType());
+    rhs = promote(rewriter, loc, rhs, res.getType());
     return rewriter.create<vector::ContractionOp>(
         loc, lhs, rhs, res,
         /*indexingMaps=*/MapList{{n, w, c}, {c, f}, {n, w, f}},
@@ -2666,24 +2691,6 @@ struct Conv1DGenerator
         .getOperation();
   }
 
-  // Take a value of element type T and widen to the destination type.
-  Value promote(RewriterBase &rewriter, Location loc, Value val, Type ty) {
-    if (val.getType() == ty)
-      return val;
-
-    const int64_t srcWidth =
-        getElementTypeOrSelf(val.getType()).getIntOrFloatBitWidth();
-    const int64_t destWidth = getElementTypeOrSelf(ty).getIntOrFloatBitWidth();
-
-    if (getElementTypeOrSelf(ty).isa<FloatType>() && srcWidth < destWidth)
-      return rewriter.create<arith::ExtFOp>(loc, ty, val);
-
-    if (getElementTypeOrSelf(ty).isa<IntegerType>() && srcWidth < destWidth)
-      return rewriter.create<arith::ExtSIOp>(loc, ty, val);
-
-    return nullptr;
-  }
-
   /// Lower lhs{n, w, c} * rhs{c} -> res{n, w, c} to MulAcc
   Value depthwiseConv1dSliceAsMulAcc(RewriterBase &rewriter, Location loc,
                                      Value lhs, Value rhs, Value res) {

diff --git a/mlir/test/Dialect/Linalg/vectorize-convolution.mlir b/mlir/test/Dialect/Linalg/vectorize-convolution.mlir
index db235fad7d2ee..88b2c2242ecd9 100644
--- a/mlir/test/Dialect/Linalg/vectorize-convolution.mlir
+++ b/mlir/test/Dialect/Linalg/vectorize-convolution.mlir
@@ -100,18 +100,22 @@ func.func @conv1d_nwc_4x2x8_i8i8i32_memref(%input: memref<4x6x3xi8>, %filter: me
 // CHECK-SAME:     {offsets = [0, 1, 0], sizes = [4, 1, 8], strides = [1, 1, 1]} : vector<4x2x8xi32> to vector<4x1x8xi32>
 
 /// w == 0, kw == 0
+//      CHECK:   %[[EXT_LHS_0:.+]] = arith.extsi %[[V_INPUT_0]] : vector<4x1x3xi8> to vector<4x1x3xi32>
+//      CHECK:   %[[EXT_RHS_0:.+]] = arith.extsi %[[V_FILTER]] : vector<3x8xi8> to vector<3x8xi32>
 //      CHECK:   %[[CONTRACT_0:.+]] = vector.contract {
 // CHECK-SAME:       indexing_maps = [#[[INPUT_MAP]], #[[FILTER_MAP]], #[[OUTPUT_MAP]]],
 // CHECK-SAME:       iterator_types = ["parallel", "parallel", "parallel", "reduction"]
-// CHECK-SAME:     %[[V_INPUT_0]], %[[V_FILTER]], %[[V_OUTPUT_0]]
-// CHECK-SAME:     : vector<4x1x3xi8>, vector<3x8xi8> into vector<4x1x8xi32>
+// CHECK-SAME:     %[[EXT_LHS_0]], %[[EXT_RHS_0]], %[[V_OUTPUT_0]]
+// CHECK-SAME:     : vector<4x1x3xi32>, vector<3x8xi32> into vector<4x1x8xi32>
 
 /// w == 1, kw == 0
+//      CHECK:   %[[EXT_LHS_1:.+]] = arith.extsi %[[V_INPUT_1]] : vector<4x1x3xi8> to vector<4x1x3xi32>
+//      CHECK:   %[[EXT_RHS_1:.+]] = arith.extsi %[[V_FILTER]] : vector<3x8xi8> to vector<3x8xi32>
 //      CHECK:   %[[CONTRACT_1:.+]] = vector.contract {
 // CHECK-SAME:       indexing_maps = [#[[INPUT_MAP]], #[[FILTER_MAP]], #[[OUTPUT_MAP]]],
 // CHECK-SAME:       iterator_types = ["parallel", "parallel", "parallel", "reduction"]
-// CHECK-SAME:     %[[V_INPUT_1]], %[[V_FILTER]], %[[V_OUTPUT_1]]
-// CHECK-SAME:     : vector<4x1x3xi8>, vector<3x8xi8> into vector<4x1x8xi32>
+// CHECK-SAME:     %[[EXT_LHS_1]], %[[EXT_RHS_1]], %[[V_OUTPUT_1]]
+// CHECK-SAME:     : vector<4x1x3xi32>, vector<3x8xi32> into vector<4x1x8xi32>
 
 /// w == 0, kw == 0
 //      CHECK:   %[[RES_0:.+]] = vector.insert_strided_slice %[[CONTRACT_0]], %[[V_OUTPUT_R]]

