[Mlir-commits] [mlir] 7517e24 - [mlir][linalg] Promote operands for convolution vectorization
Lei Zhang
llvmlistbot at llvm.org
Mon Apr 17 16:37:24 PDT 2023
Author: Lei Zhang
Date: 2023-04-17T16:37:06-07:00
New Revision: 7517e246aca2c410dba01ff1419596d7eef4a7e6
URL: https://github.com/llvm/llvm-project/commit/7517e246aca2c410dba01ff1419596d7eef4a7e6
DIFF: https://github.com/llvm/llvm-project/commit/7517e246aca2c410dba01ff1419596d7eef4a7e6.diff
LOG: [mlir][linalg] Promote operands for convolution vectorization
We are already doing this for depthwise convolution and pooling.
This helps to preserve the promotion semantics from Linalg op
definitions to lower layers.
Along the way, this also fixes a type mismatch issue in the existing
`promote` implementation.
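
For reference, a rough hand-written sketch of the promoted IR that one
w/kw slice of the i8 x i8 -> i32 convolution in the updated test should
now lower to (the function name, SSA value names, and inline affine maps
below are illustrative, not copied from the commit):

  func.func @promoted_conv_slice(%lhs: vector<4x1x3xi8>, %rhs: vector<3x8xi8>,
                                 %acc: vector<4x1x8xi32>) -> vector<4x1x8xi32> {
    // Operands are sign-extended to the i32 accumulator element type before
    // forming the contraction, instead of feeding i8 vectors directly.
    %lhs_ext = arith.extsi %lhs : vector<4x1x3xi8> to vector<4x1x3xi32>
    %rhs_ext = arith.extsi %rhs : vector<3x8xi8> to vector<3x8xi32>
    // Maps encode lhs{n, w, c} * rhs{c, f} -> res{n, w, f}, with d0..d3 = n, w, f, c.
    %res = vector.contract {
        indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>,
                         affine_map<(d0, d1, d2, d3) -> (d3, d2)>,
                         affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>],
        iterator_types = ["parallel", "parallel", "parallel", "reduction"],
        kind = #vector.kind<add>
      } %lhs_ext, %rhs_ext, %acc
      : vector<4x1x3xi32>, vector<3x8xi32> into vector<4x1x8xi32>
    return %res : vector<4x1x8xi32>
  }
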
Reviewed By: kuhar
Differential Revision: https://reviews.llvm.org/D148471
Added:
Modified:
mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
mlir/test/Dialect/Linalg/vectorize-convolution.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
index 14726a8c579f9..eff424f021eb6 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -2512,6 +2512,29 @@ struct Conv1DGenerator
.getOperation();
}
+ // Take a value and widen to have the same element type as `ty`.
+ Value promote(RewriterBase &rewriter, Location loc, Value val, Type ty) {
+ const Type srcElementType = getElementTypeOrSelf(val.getType());
+ const Type dstElementType = getElementTypeOrSelf(ty);
+ assert(isa<IntegerType>(dstElementType) || isa<FloatType>(dstElementType));
+ if (srcElementType == dstElementType)
+ return val;
+
+ const int64_t srcWidth = srcElementType.getIntOrFloatBitWidth();
+ const int64_t dstWidth = dstElementType.getIntOrFloatBitWidth();
+ const Type dstType =
+ cast<ShapedType>(val.getType()).cloneWith(std::nullopt, dstElementType);
+
+ if (isa<FloatType>(dstElementType) && srcWidth < dstWidth)
+ return rewriter.create<arith::ExtFOp>(loc, dstType, val);
+
+ if (isa<IntegerType>(dstElementType) && srcWidth < dstWidth)
+ return rewriter.create<arith::ExtSIOp>(loc, dstType, val);
+
+ assert(false && "unhandled promotion case");
+ return nullptr;
+ }
+
// Create a contraction: lhs{n, w, c} * rhs{c, f} -> res{n, w, f}
Value conv1dSliceAsContraction(RewriterBase &rewriter, Location loc,
Value lhs, Value rhs, Value res) {
@@ -2519,6 +2542,8 @@ struct Conv1DGenerator
vector::IteratorType red = vector::IteratorType::reduction;
AffineExpr n, w, f, c;
bindDims(ctx, n, w, f, c);
+ lhs = promote(rewriter, loc, lhs, res.getType());
+ rhs = promote(rewriter, loc, rhs, res.getType());
return rewriter.create<vector::ContractionOp>(
loc, lhs, rhs, res,
/*indexingMaps=*/MapList{{n, w, c}, {c, f}, {n, w, f}},
@@ -2666,24 +2691,6 @@ struct Conv1DGenerator
.getOperation();
}
- // Take a value of element type T and widen to the destination type.
- Value promote(RewriterBase &rewriter, Location loc, Value val, Type ty) {
- if (val.getType() == ty)
- return val;
-
- const int64_t srcWidth =
- getElementTypeOrSelf(val.getType()).getIntOrFloatBitWidth();
- const int64_t destWidth = getElementTypeOrSelf(ty).getIntOrFloatBitWidth();
-
- if (getElementTypeOrSelf(ty).isa<FloatType>() && srcWidth < destWidth)
- return rewriter.create<arith::ExtFOp>(loc, ty, val);
-
- if (getElementTypeOrSelf(ty).isa<IntegerType>() && srcWidth < destWidth)
- return rewriter.create<arith::ExtSIOp>(loc, ty, val);
-
- return nullptr;
- }
-
/// Lower lhs{n, w, c} * rhs{c} -> res{n, w, c} to MulAcc
Value depthwiseConv1dSliceAsMulAcc(RewriterBase &rewriter, Location loc,
Value lhs, Value rhs, Value res) {
diff --git a/mlir/test/Dialect/Linalg/vectorize-convolution.mlir b/mlir/test/Dialect/Linalg/vectorize-convolution.mlir
index db235fad7d2ee..88b2c2242ecd9 100644
--- a/mlir/test/Dialect/Linalg/vectorize-convolution.mlir
+++ b/mlir/test/Dialect/Linalg/vectorize-convolution.mlir
@@ -100,18 +100,22 @@ func.func @conv1d_nwc_4x2x8_i8i8i32_memref(%input: memref<4x6x3xi8>, %filter: me
// CHECK-SAME: {offsets = [0, 1, 0], sizes = [4, 1, 8], strides = [1, 1, 1]} : vector<4x2x8xi32> to vector<4x1x8xi32>
/// w == 0, kw == 0
+// CHECK: %[[EXT_LHS_0:.+]] = arith.extsi %[[V_INPUT_0]] : vector<4x1x3xi8> to vector<4x1x3xi32>
+// CHECK: %[[EXT_RHS_0:.+]] = arith.extsi %[[V_FILTER]] : vector<3x8xi8> to vector<3x8xi32>
// CHECK: %[[CONTRACT_0:.+]] = vector.contract {
// CHECK-SAME: indexing_maps = [#[[INPUT_MAP]], #[[FILTER_MAP]], #[[OUTPUT_MAP]]],
// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel", "reduction"]
-// CHECK-SAME: %[[V_INPUT_0]], %[[V_FILTER]], %[[V_OUTPUT_0]]
-// CHECK-SAME: : vector<4x1x3xi8>, vector<3x8xi8> into vector<4x1x8xi32>
+// CHECK-SAME: %[[EXT_LHS_0]], %[[EXT_RHS_0]], %[[V_OUTPUT_0]]
+// CHECK-SAME: : vector<4x1x3xi32>, vector<3x8xi32> into vector<4x1x8xi32>
/// w == 1, kw == 0
+// CHECK: %[[EXT_LHS_1:.+]] = arith.extsi %[[V_INPUT_1]] : vector<4x1x3xi8> to vector<4x1x3xi32>
+// CHECK: %[[EXT_RHS_1:.+]] = arith.extsi %[[V_FILTER]] : vector<3x8xi8> to vector<3x8xi32>
// CHECK: %[[CONTRACT_1:.+]] = vector.contract {
// CHECK-SAME: indexing_maps = [#[[INPUT_MAP]], #[[FILTER_MAP]], #[[OUTPUT_MAP]]],
// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel", "reduction"]
-// CHECK-SAME: %[[V_INPUT_1]], %[[V_FILTER]], %[[V_OUTPUT_1]]
-// CHECK-SAME: : vector<4x1x3xi8>, vector<3x8xi8> into vector<4x1x8xi32>
+// CHECK-SAME: %[[EXT_LHS_1]], %[[EXT_RHS_1]], %[[V_OUTPUT_1]]
+// CHECK-SAME: : vector<4x1x3xi32>, vector<3x8xi32> into vector<4x1x8xi32>
/// w == 0, kw == 0
// CHECK: %[[RES_0:.+]] = vector.insert_strided_slice %[[CONTRACT_0]], %[[V_OUTPUT_R]]