[Mlir-commits] [mlir] 53ffeea - [mlir][Linalg] Reduction dimensions specified in TC definition of ConvOps.

Jakub Lichman llvmlistbot at llvm.org
Wed Sep 9 08:17:40 PDT 2020


Author: Jakub Lichman
Date: 2020-09-09T15:17:07Z
New Revision: 53ffeea6d59ae5ba78b8c85a31c06677c3ab7719

URL: https://github.com/llvm/llvm-project/commit/53ffeea6d59ae5ba78b8c85a31c06677c3ab7719
DIFF: https://github.com/llvm/llvm-project/commit/53ffeea6d59ae5ba78b8c85a31c06677c3ab7719.diff

LOG: [mlir][Linalg] Reduction dimensions specified in TC definition of ConvOps.

This commit specifies the reduction dimensions of the ConvOps in their TC
definitions. This prevents the reduction loops from being run in parallel and
makes it easier to detect kernel dimensions, which will be needed later on.

Differential Revision: https://reviews.llvm.org/D87288

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOpsSpec.tc
    mlir/test/Dialect/Linalg/loops.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOpsSpec.tc b/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOpsSpec.tc
index 27d4330a54d5..9c54a5f0c3c7 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOpsSpec.tc
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOpsSpec.tc
@@ -20,52 +20,50 @@ def batch_matmul(A: f32(Batch, M, K), B: f32(Batch, K, N)) -> (C: f32(Batch, M,
 
 ods_def<ConvWOp>:
 def conv_1d(I: f32(W), K: f32(KW)) -> (O: f32(W)) {
-  O(w) = std_addf(O(w), std_mulf(I(w + kw), K(kw)));
+  O(w) = std_addf<kw>(std_mulf(I(w + kw), K(kw)));
 }
 
 ods_def<ConvNWCOp>:
 def conv_1d_nwc(I: f32(N, W, C), K: f32(F, KW, C)) -> (O: f32(N, W, F)) {
-  O(n, w, f) = std_addf(O(n, w, f),
-    std_mulf(I(n, w + kw, c), K(f, kw, c)));
+  O(n, w, f) = std_addf<kw>(std_mulf(I(n, w + kw, c), K(f, kw, c)));
 }
 
 ods_def<ConvNCWOp>:
 def conv_1d_ncw(I: f32(N, C, W), K: f32(F, C, KW)) -> (O: f32(N, F, W)) {
-  O(n, f, w) = std_addf(O(n, f, w),
-    std_mulf(I(n, c, w + kw), K(f, c, kw)));
+  O(n, f, w) = std_addf<kw>(std_mulf(I(n, c, w + kw), K(f, c, kw)));
 }
 
 ods_def<ConvHWOp>:
 def conv_2d(I: f32(H, W), K: f32(KH, KW)) -> (O: f32(H, W)) {
-  O(h, w) = std_addf(O(h, w), std_mulf(I(h + kh, w + kw), K(kh, kw)));
+  O(h, w) = std_addf<kh, kw>(std_mulf(I(h + kh, w + kw), K(kh, kw)));
 }
 
 ods_def<ConvNHWCOp>:
 def conv_2d_nhwc(I: f32(N, H, W, C), K: f32(F, KH, KW, C)) -> (O: f32(N, H, W, F)) {
-  O(n, h, w, f) = std_addf(O(n, h, w, f),
-    std_mulf(I(n, h + kh, w + kw, c), K(f, kh, kw, c)));
+  O(n, h, w, f) = std_addf<kh, kw>(std_mulf(
+    I(n, h + kh, w + kw, c), K(f, kh, kw, c)));
 }
 
 ods_def<ConvNCHWOp>:
 def conv_2d_nchw(I: f32(N, C, H, W), K: f32(F, C, KH, KW)) -> (O: f32(N, F, H, W)) {
-  O(n, f, h, w) = std_addf(O(n, f, h, w),
-    std_mulf(I(n, c, h + kh, w + kw), K(f, c, kh, kw)));
+  O(n, f, h, w) = std_addf<kh, kw>(std_mulf(
+    I(n, c, h + kh, w + kw), K(f, c, kh, kw)));
 }
 
 ods_def<ConvDHWOp>:
 def conv_3d(I: f32(D, H, W), K: f32(KD, KH, KW)) -> (O: f32(D, H, W)) {
-  O(d, h, w) = std_addf(O(d, h, w),
-    std_mulf(I(d + kd, h + kh, w + kw), K(kd, kh, kw)));
+  O(d, h, w) = std_addf<kd, kh, kw>(std_mulf(
+    I(d + kd, h + kh, w + kw), K(kd, kh, kw)));
 }
 
 ods_def<ConvNDHWCOp>:
 def conv_3d_ndhwc(I: f32(N, D, H, W, C), K: f32(F, KD, KH, KW, C)) -> (O: f32(N, D, H, W, F)) {
-  O(n, d, h, w, f) = std_addf(O(n, d, h, w, f),
-    std_mulf(I(n, d + kd, h + kh, w + kw, c), K(f, kd, kh, kw, c)));
+  O(n, d, h, w, f) = std_addf<kd, kh, kw>(std_mulf(
+    I(n, d + kd, h + kh, w + kw, c), K(f, kd, kh, kw, c)));
 }
 
 ods_def<ConvNCDHWOp>:
 def conv_3d_ncdhw(I: f32(N, C, D, H, W), K: f32(F, C, KD, KH, KW)) -> (O: f32(N, F, D, H, W)) {
-  O(n, f, d, h, w) = std_addf(O(n, f, d, h, w),
-    std_mulf(I(n, c, d + kd, h + kh, w + kw), K(f, c, kd, kh, kw)));
+  O(n, f, d, h, w) = std_addf<kd, kh, kw>(std_mulf(
+    I(n, c, d + kd, h + kh, w + kw), K(f, c, kd, kh, kw)));
 }
\ No newline at end of file

diff  --git a/mlir/test/Dialect/Linalg/loops.mlir b/mlir/test/Dialect/Linalg/loops.mlir
index 6af53a2b8d22..1e10e036ee2d 100644
--- a/mlir/test/Dialect/Linalg/loops.mlir
+++ b/mlir/test/Dialect/Linalg/loops.mlir
@@ -1318,14 +1318,15 @@ func @conv1d_no_symbols(%in : memref<?xf32>, %filter : memref<?xf32>, %out : mem
 //       CHECKPARALLEL: %[[c1:.*]] = constant 1 : index
 //       CHECKPARALLEL: %[[dim0:.*]] = dim %[[arg1]], %[[c0]] : memref<?xf32>
 //       CHECKPARALLEL: %[[dim1:.*]] = dim %[[arg2]], %[[c0]] : memref<?xf32>
-//       CHECKPARALLEL: scf.parallel (%[[b:.*]], %[[m:.*]]) = (%[[c0]], %[[c0]]) to (%[[dim1]], %[[dim0]]) step (%[[c1]], %[[c1]]) {
-//       CHECKPARALLEL:   %[[aff:.*]] = affine.apply #[[$stride1Dilation1]](%[[b]], %[[m]])
-//       CHECKPARALLEL:   %[[vb:.*]] = load %[[arg0]][%[[aff]]] : memref<?xf32>
-//       CHECKPARALLEL:   %[[va:.*]] = load %[[arg1]][%[[m]]] : memref<?xf32>
-//       CHECKPARALLEL:   %[[vc:.*]] = load %[[arg2]][%[[b]]] : memref<?xf32>
-//       CHECKPARALLEL:   %[[inc:.*]] = mulf %[[vb]], %[[va]] : f32
-//       CHECKPARALLEL:   %[[res:.*]] = addf %[[vc]], %[[inc]] : f32
-//       CHECKPARALLEL:   store %[[res]], %[[arg2]][%[[b]]] : memref<?xf32>
+//       CHECKPARALLEL: scf.parallel (%[[b:.*]]) = (%[[c0]]) to (%[[dim1]]) step (%[[c1]]) {
+//       CHECKPARALLEL:   scf.for %[[m:.*]] = %[[c0]] to %[[dim0]] step %[[c1]] {
+//       CHECKPARALLEL:     %[[aff:.*]] = affine.apply #[[$stride1Dilation1]](%[[b]], %[[m]])
+//       CHECKPARALLEL:     %[[vb:.*]] = load %[[arg0]][%[[aff]]] : memref<?xf32>
+//       CHECKPARALLEL:     %[[va:.*]] = load %[[arg1]][%[[m]]] : memref<?xf32>
+//       CHECKPARALLEL:     %[[vc:.*]] = load %[[arg2]][%[[b]]] : memref<?xf32>
+//       CHECKPARALLEL:     %[[inc:.*]] = mulf %[[vb]], %[[va]] : f32
+//       CHECKPARALLEL:     %[[res:.*]] = addf %[[vc]], %[[inc]] : f32
+//       CHECKPARALLEL:     store %[[res]], %[[arg2]][%[[b]]] : memref<?xf32>
 
 
 func @conv2d_no_symbols(%in : memref<?x?xf32>, %filter : memref<?x?xf32>, %out : memref<?x?xf32>) -> () {
@@ -1367,15 +1368,17 @@ func @conv2d_no_symbols(%in : memref<?x?xf32>, %filter : memref<?x?xf32>, %out :
 //       CHECKPARALLEL: %[[dim1:.*]] = dim %[[arg1]], %[[c1]] : memref<?x?xf32>
 //       CHECKPARALLEL: %[[dim2:.*]] = dim %[[arg2]], %[[c0]] : memref<?x?xf32>
 //       CHECKPARALLEL: %[[dim3:.*]] = dim %[[arg2]], %[[c1]] : memref<?x?xf32>
-//       CHECKPARALLEL: scf.parallel (%[[arg3:.*]], %[[arg4:.*]], %[[arg5:.*]], %[[arg6:.*]]) = (%[[c0]], %[[c0]], %[[c0]], %[[c0]]) to (%[[dim2]], %[[dim3]], %[[dim0]], %[[dim1]]) step (%[[c1]], %[[c1]], %[[c1]], %[[c1]]) {
-//       CHECKPARALLEL:   %[[aff:.*]] = affine.apply #[[$stride1Dilation1]](%[[arg3]], %[[arg5]])
-//       CHECKPARALLEL:   %[[aff2:.*]] = affine.apply #[[$stride1Dilation1]](%[[arg4]], %[[arg6]])
-//       CHECKPARALLEL:   %[[vb:.*]] = load %[[arg0]][%[[aff]], %[[aff2]]] : memref<?x?xf32>
-//       CHECKPARALLEL:   %[[va:.*]] = load %[[arg1]][%[[arg5]], %[[arg6]]] : memref<?x?xf32>
-//       CHECKPARALLEL:   %[[vc:.*]] = load %[[arg2]][%[[arg3]], %[[arg4]]] : memref<?x?xf32>
-//       CHECKPARALLEL:   %[[inc:.*]] = mulf %[[vb]], %[[va]] : f32
-//       CHECKPARALLEL:   %[[res:.*]] = addf %[[vc]], %[[inc]] : f32
-//       CHECKPARALLEL:   store %[[res]], %[[arg2]][%[[arg3]], %[[arg4]]] : memref<?x?xf32>
+//       CHECKPARALLEL: scf.parallel (%[[arg3:.*]], %[[arg4:.*]]) = (%[[c0]], %[[c0]]) to (%[[dim2]], %[[dim3]]) step (%[[c1]], %[[c1]]) {
+//       CHECKPARALLEL:   scf.for %[[arg5:.*]] = %[[c0]] to %[[dim0]] step %[[c1]] {
+//       CHECKPARALLEL:     scf.for %[[arg6:.*]] = %[[c0]] to %[[dim1]] step %[[c1]] {
+//       CHECKPARALLEL:       %[[aff:.*]] = affine.apply #[[$stride1Dilation1]](%[[arg3]], %[[arg5]])
+//       CHECKPARALLEL:       %[[aff2:.*]] = affine.apply #[[$stride1Dilation1]](%[[arg4]], %[[arg6]])
+//       CHECKPARALLEL:       %[[vb:.*]] = load %[[arg0]][%[[aff]], %[[aff2]]] : memref<?x?xf32>
+//       CHECKPARALLEL:       %[[va:.*]] = load %[[arg1]][%[[arg5]], %[[arg6]]] : memref<?x?xf32>
+//       CHECKPARALLEL:       %[[vc:.*]] = load %[[arg2]][%[[arg3]], %[[arg4]]] : memref<?x?xf32>
+//       CHECKPARALLEL:       %[[inc:.*]] = mulf %[[vb]], %[[va]] : f32
+//       CHECKPARALLEL:       %[[res:.*]] = addf %[[vc]], %[[inc]] : f32
+//       CHECKPARALLEL:       store %[[res]], %[[arg2]][%[[arg3]], %[[arg4]]] : memref<?x?xf32>
 
 
 func @conv3d_no_symbols(%in : memref<?x?x?xf32>, %filter : memref<?x?x?xf32>, %out : memref<?x?x?xf32>) -> () {
@@ -1427,13 +1430,16 @@ func @conv3d_no_symbols(%in : memref<?x?x?xf32>, %filter : memref<?x?x?xf32>, %o
 //       CHECKPARALLEL: %[[dim3:.*]] = dim %[[arg2]], %[[c0]] : memref<?x?x?xf32>
 //       CHECKPARALLEL: %[[dim4:.*]] = dim %[[arg2]], %[[c1]] : memref<?x?x?xf32>
 //       CHECKPARALLEL: %[[dim5:.*]] = dim %[[arg2]], %[[c2]] : memref<?x?x?xf32>
-//       CHECKPARALLEL: scf.parallel (%[[arg3:.*]], %[[arg4:.*]], %[[arg5:.*]], %[[arg6:.*]], %[[arg7:.*]], %[[arg8:.*]]) = (%[[c0]], %[[c0]], %[[c0]], %[[c0]], %[[c0]], %[[c0]]) to (%[[dim3]], %[[dim4]], %[[dim5]], %[[dim0]], %[[dim1]], %[[dim2]]) step (%[[c1]], %[[c1]], %[[c1]], %[[c1]], %[[c1]], %[[c1]]) {
-//       CHECKPARALLEL:   %[[aff:.*]] = affine.apply #[[$stride1Dilation1]](%[[arg3]], %[[arg6]])
-//       CHECKPARALLEL:   %[[aff2:.*]] = affine.apply #[[$stride1Dilation1]](%[[arg4]], %[[arg7]])
-//       CHECKPARALLEL:   %[[aff3:.*]] = affine.apply #[[$stride1Dilation1]](%[[arg5]], %[[arg8]])
-//       CHECKPARALLEL:   %[[vb:.*]] = load %[[arg0]][%[[aff]], %[[aff2]], %[[aff3]]] : memref<?x?x?xf32>
-//       CHECKPARALLEL:   %[[va:.*]] = load %[[arg1]][%[[arg6]], %[[arg7]], %[[arg8]]] : memref<?x?x?xf32>
-//       CHECKPARALLEL:   %[[vc:.*]] = load %[[arg2]][%[[arg3]], %[[arg4]], %[[arg5]]] : memref<?x?x?xf32>
-//       CHECKPARALLEL:   %[[inc:.*]] = mulf %[[vb]], %[[va]] : f32
-//       CHECKPARALLEL:   %[[res:.*]] = addf %[[vc]], %[[inc]] : f32
-//       CHECKPARALLEL:   store %[[res]], %[[arg2]][%[[arg3]], %[[arg4]], %[[arg5]]] : memref<?x?x?xf32>
+//       CHECKPARALLEL: scf.parallel (%[[arg3:.*]], %[[arg4:.*]], %[[arg5:.*]]) = (%[[c0]], %[[c0]], %[[c0]]) to (%[[dim3]], %[[dim4]], %[[dim5]]) step (%[[c1]], %[[c1]], %[[c1]]) {
+//       CHECKPARALLEL:   scf.for %[[arg6:.*]] = %[[c0]] to %[[dim0]] step %[[c1]] {
+//       CHECKPARALLEL:     scf.for %[[arg7:.*]] = %[[c0]] to %[[dim1]] step %[[c1]] {
+//       CHECKPARALLEL:       scf.for %[[arg8:.*]] = %[[c0]] to %[[dim2]] step %[[c1]] {
+//       CHECKPARALLEL:         %[[aff:.*]] = affine.apply #[[$stride1Dilation1]](%[[arg3]], %[[arg6]])
+//       CHECKPARALLEL:         %[[aff2:.*]] = affine.apply #[[$stride1Dilation1]](%[[arg4]], %[[arg7]])
+//       CHECKPARALLEL:         %[[aff3:.*]] = affine.apply #[[$stride1Dilation1]](%[[arg5]], %[[arg8]])
+//       CHECKPARALLEL:         %[[vb:.*]] = load %[[arg0]][%[[aff]], %[[aff2]], %[[aff3]]] : memref<?x?x?xf32>
+//       CHECKPARALLEL:         %[[va:.*]] = load %[[arg1]][%[[arg6]], %[[arg7]], %[[arg8]]] : memref<?x?x?xf32>
+//       CHECKPARALLEL:         %[[vc:.*]] = load %[[arg2]][%[[arg3]], %[[arg4]], %[[arg5]]] : memref<?x?x?xf32>
+//       CHECKPARALLEL:         %[[inc:.*]] = mulf %[[vb]], %[[va]] : f32
+//       CHECKPARALLEL:         %[[res:.*]] = addf %[[vc]], %[[inc]] : f32
+//       CHECKPARALLEL:         store %[[res]], %[[arg2]][%[[arg3]], %[[arg4]], %[[arg5]]] : memref<?x?x?xf32>


        


More information about the Mlir-commits mailing list