[Mlir-commits] [mlir] 53ffeea - [mlir][Linalg] Reduction dimensions specified in TC definition of ConvOps.
Jakub Lichman
llvmlistbot at llvm.org
Wed Sep 9 08:17:40 PDT 2020
Author: Jakub Lichman
Date: 2020-09-09T15:17:07Z
New Revision: 53ffeea6d59ae5ba78b8c85a31c06677c3ab7719
URL: https://github.com/llvm/llvm-project/commit/53ffeea6d59ae5ba78b8c85a31c06677c3ab7719
DIFF: https://github.com/llvm/llvm-project/commit/53ffeea6d59ae5ba78b8c85a31c06677c3ab7719.diff
LOG: [mlir][Linalg] Reduction dimensions specified in TC definition of ConvOps.
This commit specifies reduction dimensions for ConvOps. This prevents
reduction loops from being run in parallel and makes kernel dimensions easier
to detect, which we will need later on.
Differential Revision: https://reviews.llvm.org/D87288
Added:
Modified:
mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOpsSpec.tc
mlir/test/Dialect/Linalg/loops.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOpsSpec.tc b/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOpsSpec.tc
index 27d4330a54d5..9c54a5f0c3c7 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOpsSpec.tc
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOpsSpec.tc
@@ -20,52 +20,50 @@ def batch_matmul(A: f32(Batch, M, K), B: f32(Batch, K, N)) -> (C: f32(Batch, M,
ods_def<ConvWOp>:
def conv_1d(I: f32(W), K: f32(KW)) -> (O: f32(W)) {
- O(w) = std_addf(O(w), std_mulf(I(w + kw), K(kw)));
+ O(w) = std_addf<kw>(std_mulf(I(w + kw), K(kw)));
}
ods_def<ConvNWCOp>:
def conv_1d_nwc(I: f32(N, W, C), K: f32(F, KW, C)) -> (O: f32(N, W, F)) {
- O(n, w, f) = std_addf(O(n, w, f),
- std_mulf(I(n, w + kw, c), K(f, kw, c)));
+ O(n, w, f) = std_addf<kw>(std_mulf(I(n, w + kw, c), K(f, kw, c)));
}
ods_def<ConvNCWOp>:
def conv_1d_ncw(I: f32(N, C, W), K: f32(F, C, KW)) -> (O: f32(N, F, W)) {
- O(n, f, w) = std_addf(O(n, f, w),
- std_mulf(I(n, c, w + kw), K(f, c, kw)));
+ O(n, f, w) = std_addf<kw>(std_mulf(I(n, c, w + kw), K(f, c, kw)));
}
ods_def<ConvHWOp>:
def conv_2d(I: f32(H, W), K: f32(KH, KW)) -> (O: f32(H, W)) {
- O(h, w) = std_addf(O(h, w), std_mulf(I(h + kh, w + kw), K(kh, kw)));
+ O(h, w) = std_addf<kh, kw>(std_mulf(I(h + kh, w + kw), K(kh, kw)));
}
ods_def<ConvNHWCOp>:
def conv_2d_nhwc(I: f32(N, H, W, C), K: f32(F, KH, KW, C)) -> (O: f32(N, H, W, F)) {
- O(n, h, w, f) = std_addf(O(n, h, w, f),
- std_mulf(I(n, h + kh, w + kw, c), K(f, kh, kw, c)));
+ O(n, h, w, f) = std_addf<kh, kw>(std_mulf(
+ I(n, h + kh, w + kw, c), K(f, kh, kw, c)));
}
ods_def<ConvNCHWOp>:
def conv_2d_nchw(I: f32(N, C, H, W), K: f32(F, C, KH, KW)) -> (O: f32(N, F, H, W)) {
- O(n, f, h, w) = std_addf(O(n, f, h, w),
- std_mulf(I(n, c, h + kh, w + kw), K(f, c, kh, kw)));
+ O(n, f, h, w) = std_addf<kh, kw>(std_mulf(
+ I(n, c, h + kh, w + kw), K(f, c, kh, kw)));
}
ods_def<ConvDHWOp>:
def conv_3d(I: f32(D, H, W), K: f32(KD, KH, KW)) -> (O: f32(D, H, W)) {
- O(d, h, w) = std_addf(O(d, h, w),
- std_mulf(I(d + kd, h + kh, w + kw), K(kd, kh, kw)));
+ O(d, h, w) = std_addf<kd, kh, kw>(std_mulf(
+ I(d + kd, h + kh, w + kw), K(kd, kh, kw)));
}
ods_def<ConvNDHWCOp>:
def conv_3d_ndhwc(I: f32(N, D, H, W, C), K: f32(F, KD, KH, KW, C)) -> (O: f32(N, D, H, W, F)) {
- O(n, d, h, w, f) = std_addf(O(n, d, h, w, f),
- std_mulf(I(n, d + kd, h + kh, w + kw, c), K(f, kd, kh, kw, c)));
+ O(n, d, h, w, f) = std_addf<kd, kh, kw>(std_mulf(
+ I(n, d + kd, h + kh, w + kw, c), K(f, kd, kh, kw, c)));
}
ods_def<ConvNCDHWOp>:
def conv_3d_ncdhw(I: f32(N, C, D, H, W), K: f32(F, C, KD, KH, KW)) -> (O: f32(N, F, D, H, W)) {
- O(n, f, d, h, w) = std_addf(O(n, f, d, h, w),
- std_mulf(I(n, c, d + kd, h + kh, w + kw), K(f, c, kd, kh, kw)));
+ O(n, f, d, h, w) = std_addf<kd, kh, kw>(std_mulf(
+ I(n, c, d + kd, h + kh, w + kw), K(f, c, kd, kh, kw)));
}
\ No newline at end of file
diff --git a/mlir/test/Dialect/Linalg/loops.mlir b/mlir/test/Dialect/Linalg/loops.mlir
index 6af53a2b8d22..1e10e036ee2d 100644
--- a/mlir/test/Dialect/Linalg/loops.mlir
+++ b/mlir/test/Dialect/Linalg/loops.mlir
@@ -1318,14 +1318,15 @@ func @conv1d_no_symbols(%in : memref<?xf32>, %filter : memref<?xf32>, %out : mem
// CHECKPARALLEL: %[[c1:.*]] = constant 1 : index
// CHECKPARALLEL: %[[dim0:.*]] = dim %[[arg1]], %[[c0]] : memref<?xf32>
// CHECKPARALLEL: %[[dim1:.*]] = dim %[[arg2]], %[[c0]] : memref<?xf32>
-// CHECKPARALLEL: scf.parallel (%[[b:.*]], %[[m:.*]]) = (%[[c0]], %[[c0]]) to (%[[dim1]], %[[dim0]]) step (%[[c1]], %[[c1]]) {
-// CHECKPARALLEL: %[[aff:.*]] = affine.apply #[[$stride1Dilation1]](%[[b]], %[[m]])
-// CHECKPARALLEL: %[[vb:.*]] = load %[[arg0]][%[[aff]]] : memref<?xf32>
-// CHECKPARALLEL: %[[va:.*]] = load %[[arg1]][%[[m]]] : memref<?xf32>
-// CHECKPARALLEL: %[[vc:.*]] = load %[[arg2]][%[[b]]] : memref<?xf32>
-// CHECKPARALLEL: %[[inc:.*]] = mulf %[[vb]], %[[va]] : f32
-// CHECKPARALLEL: %[[res:.*]] = addf %[[vc]], %[[inc]] : f32
-// CHECKPARALLEL: store %[[res]], %[[arg2]][%[[b]]] : memref<?xf32>
+// CHECKPARALLEL: scf.parallel (%[[b:.*]]) = (%[[c0]]) to (%[[dim1]]) step (%[[c1]]) {
+// CHECKPARALLEL: scf.for %[[m:.*]] = %[[c0]] to %[[dim0]] step %[[c1]] {
+// CHECKPARALLEL: %[[aff:.*]] = affine.apply #[[$stride1Dilation1]](%[[b]], %[[m]])
+// CHECKPARALLEL: %[[vb:.*]] = load %[[arg0]][%[[aff]]] : memref<?xf32>
+// CHECKPARALLEL: %[[va:.*]] = load %[[arg1]][%[[m]]] : memref<?xf32>
+// CHECKPARALLEL: %[[vc:.*]] = load %[[arg2]][%[[b]]] : memref<?xf32>
+// CHECKPARALLEL: %[[inc:.*]] = mulf %[[vb]], %[[va]] : f32
+// CHECKPARALLEL: %[[res:.*]] = addf %[[vc]], %[[inc]] : f32
+// CHECKPARALLEL: store %[[res]], %[[arg2]][%[[b]]] : memref<?xf32>
func @conv2d_no_symbols(%in : memref<?x?xf32>, %filter : memref<?x?xf32>, %out : memref<?x?xf32>) -> () {
@@ -1367,15 +1368,17 @@ func @conv2d_no_symbols(%in : memref<?x?xf32>, %filter : memref<?x?xf32>, %out :
// CHECKPARALLEL: %[[dim1:.*]] = dim %[[arg1]], %[[c1]] : memref<?x?xf32>
// CHECKPARALLEL: %[[dim2:.*]] = dim %[[arg2]], %[[c0]] : memref<?x?xf32>
// CHECKPARALLEL: %[[dim3:.*]] = dim %[[arg2]], %[[c1]] : memref<?x?xf32>
-// CHECKPARALLEL: scf.parallel (%[[arg3:.*]], %[[arg4:.*]], %[[arg5:.*]], %[[arg6:.*]]) = (%[[c0]], %[[c0]], %[[c0]], %[[c0]]) to (%[[dim2]], %[[dim3]], %[[dim0]], %[[dim1]]) step (%[[c1]], %[[c1]], %[[c1]], %[[c1]]) {
-// CHECKPARALLEL: %[[aff:.*]] = affine.apply #[[$stride1Dilation1]](%[[arg3]], %[[arg5]])
-// CHECKPARALLEL: %[[aff2:.*]] = affine.apply #[[$stride1Dilation1]](%[[arg4]], %[[arg6]])
-// CHECKPARALLEL: %[[vb:.*]] = load %[[arg0]][%[[aff]], %[[aff2]]] : memref<?x?xf32>
-// CHECKPARALLEL: %[[va:.*]] = load %[[arg1]][%[[arg5]], %[[arg6]]] : memref<?x?xf32>
-// CHECKPARALLEL: %[[vc:.*]] = load %[[arg2]][%[[arg3]], %[[arg4]]] : memref<?x?xf32>
-// CHECKPARALLEL: %[[inc:.*]] = mulf %[[vb]], %[[va]] : f32
-// CHECKPARALLEL: %[[res:.*]] = addf %[[vc]], %[[inc]] : f32
-// CHECKPARALLEL: store %[[res]], %[[arg2]][%[[arg3]], %[[arg4]]] : memref<?x?xf32>
+// CHECKPARALLEL: scf.parallel (%[[arg3:.*]], %[[arg4:.*]]) = (%[[c0]], %[[c0]]) to (%[[dim2]], %[[dim3]]) step (%[[c1]], %[[c1]]) {
+// CHECKPARALLEL: scf.for %[[arg5:.*]] = %[[c0]] to %[[dim0]] step %[[c1]] {
+// CHECKPARALLEL: scf.for %[[arg6:.*]] = %[[c0]] to %[[dim1]] step %[[c1]] {
+// CHECKPARALLEL: %[[aff:.*]] = affine.apply #[[$stride1Dilation1]](%[[arg3]], %[[arg5]])
+// CHECKPARALLEL: %[[aff2:.*]] = affine.apply #[[$stride1Dilation1]](%[[arg4]], %[[arg6]])
+// CHECKPARALLEL: %[[vb:.*]] = load %[[arg0]][%[[aff]], %[[aff2]]] : memref<?x?xf32>
+// CHECKPARALLEL: %[[va:.*]] = load %[[arg1]][%[[arg5]], %[[arg6]]] : memref<?x?xf32>
+// CHECKPARALLEL: %[[vc:.*]] = load %[[arg2]][%[[arg3]], %[[arg4]]] : memref<?x?xf32>
+// CHECKPARALLEL: %[[inc:.*]] = mulf %[[vb]], %[[va]] : f32
+// CHECKPARALLEL: %[[res:.*]] = addf %[[vc]], %[[inc]] : f32
+// CHECKPARALLEL: store %[[res]], %[[arg2]][%[[arg3]], %[[arg4]]] : memref<?x?xf32>
func @conv3d_no_symbols(%in : memref<?x?x?xf32>, %filter : memref<?x?x?xf32>, %out : memref<?x?x?xf32>) -> () {
@@ -1427,13 +1430,16 @@ func @conv3d_no_symbols(%in : memref<?x?x?xf32>, %filter : memref<?x?x?xf32>, %o
// CHECKPARALLEL: %[[dim3:.*]] = dim %[[arg2]], %[[c0]] : memref<?x?x?xf32>
// CHECKPARALLEL: %[[dim4:.*]] = dim %[[arg2]], %[[c1]] : memref<?x?x?xf32>
// CHECKPARALLEL: %[[dim5:.*]] = dim %[[arg2]], %[[c2]] : memref<?x?x?xf32>
-// CHECKPARALLEL: scf.parallel (%[[arg3:.*]], %[[arg4:.*]], %[[arg5:.*]], %[[arg6:.*]], %[[arg7:.*]], %[[arg8:.*]]) = (%[[c0]], %[[c0]], %[[c0]], %[[c0]], %[[c0]], %[[c0]]) to (%[[dim3]], %[[dim4]], %[[dim5]], %[[dim0]], %[[dim1]], %[[dim2]]) step (%[[c1]], %[[c1]], %[[c1]], %[[c1]], %[[c1]], %[[c1]]) {
-// CHECKPARALLEL: %[[aff:.*]] = affine.apply #[[$stride1Dilation1]](%[[arg3]], %[[arg6]])
-// CHECKPARALLEL: %[[aff2:.*]] = affine.apply #[[$stride1Dilation1]](%[[arg4]], %[[arg7]])
-// CHECKPARALLEL: %[[aff3:.*]] = affine.apply #[[$stride1Dilation1]](%[[arg5]], %[[arg8]])
-// CHECKPARALLEL: %[[vb:.*]] = load %[[arg0]][%[[aff]], %[[aff2]], %[[aff3]]] : memref<?x?x?xf32>
-// CHECKPARALLEL: %[[va:.*]] = load %[[arg1]][%[[arg6]], %[[arg7]], %[[arg8]]] : memref<?x?x?xf32>
-// CHECKPARALLEL: %[[vc:.*]] = load %[[arg2]][%[[arg3]], %[[arg4]], %[[arg5]]] : memref<?x?x?xf32>
-// CHECKPARALLEL: %[[inc:.*]] = mulf %[[vb]], %[[va]] : f32
-// CHECKPARALLEL: %[[res:.*]] = addf %[[vc]], %[[inc]] : f32
-// CHECKPARALLEL: store %[[res]], %[[arg2]][%[[arg3]], %[[arg4]], %[[arg5]]] : memref<?x?x?xf32>
+// CHECKPARALLEL: scf.parallel (%[[arg3:.*]], %[[arg4:.*]], %[[arg5:.*]]) = (%[[c0]], %[[c0]], %[[c0]]) to (%[[dim3]], %[[dim4]], %[[dim5]]) step (%[[c1]], %[[c1]], %[[c1]]) {
+// CHECKPARALLEL: scf.for %[[arg6:.*]] = %[[c0]] to %[[dim0]] step %[[c1]] {
+// CHECKPARALLEL: scf.for %[[arg7:.*]] = %[[c0]] to %[[dim1]] step %[[c1]] {
+// CHECKPARALLEL: scf.for %[[arg8:.*]] = %[[c0]] to %[[dim2]] step %[[c1]] {
+// CHECKPARALLEL: %[[aff:.*]] = affine.apply #[[$stride1Dilation1]](%[[arg3]], %[[arg6]])
+// CHECKPARALLEL: %[[aff2:.*]] = affine.apply #[[$stride1Dilation1]](%[[arg4]], %[[arg7]])
+// CHECKPARALLEL: %[[aff3:.*]] = affine.apply #[[$stride1Dilation1]](%[[arg5]], %[[arg8]])
+// CHECKPARALLEL: %[[vb:.*]] = load %[[arg0]][%[[aff]], %[[aff2]], %[[aff3]]] : memref<?x?x?xf32>
+// CHECKPARALLEL: %[[va:.*]] = load %[[arg1]][%[[arg6]], %[[arg7]], %[[arg8]]] : memref<?x?x?xf32>
+// CHECKPARALLEL: %[[vc:.*]] = load %[[arg2]][%[[arg3]], %[[arg4]], %[[arg5]]] : memref<?x?x?xf32>
+// CHECKPARALLEL: %[[inc:.*]] = mulf %[[vb]], %[[va]] : f32
+// CHECKPARALLEL: %[[res:.*]] = addf %[[vc]], %[[inc]] : f32
+// CHECKPARALLEL: store %[[res]], %[[arg2]][%[[arg3]], %[[arg4]], %[[arg5]]] : memref<?x?x?xf32>
More information about the Mlir-commits
mailing list