[Mlir-commits] [mlir] e31e8f1 - [MLIR][Linalg] Retire C++ MatvecOp in favor of a linalg-ods-gen'd op

Alex Zinenko llvmlistbot at llvm.org
Thu Jun 18 02:36:59 PDT 2020


Author: lorenzo chelini
Date: 2020-06-18T11:36:49+02:00
New Revision: e31e8f1ed57eb25584903f1a67040babf2c48eda

URL: https://github.com/llvm/llvm-project/commit/e31e8f1ed57eb25584903f1a67040babf2c48eda
DIFF: https://github.com/llvm/llvm-project/commit/e31e8f1ed57eb25584903f1a67040babf2c48eda.diff

LOG: [MLIR][Linalg] Retire C++ MatvecOp in favor of a linalg-ods-gen'd op

Replace the hand-written C++ MatvecOp with a named op generated by linalg-ods-gen, now that the DRR rules have been dropped.

Differential Revision: https://reviews.llvm.org/D82007

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOpsSpec.tc
    mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
    mlir/lib/Conversion/LinalgToStandard/LinalgToStandard.cpp
    mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
    mlir/lib/Dialect/Linalg/Transforms/Loops.cpp
    mlir/test/Dialect/Linalg/loops.mlir
    mlir/test/Dialect/Linalg/roundtrip.mlir
    mlir/test/Dialect/Linalg/tile.mlir
    mlir/test/Dialect/Linalg/transform-patterns.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOpsSpec.tc b/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOpsSpec.tc
index c510010acd0b..bbd398585e5f 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOpsSpec.tc
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOpsSpec.tc
@@ -3,6 +3,11 @@ def matmul(A: f32(M, K), B: f32(K, N)) -> (C: f32(M, N)) {
   C(m, n) = std_addf<k>(std_mulf(A(m, k), B(k, n)));
 }
 
+ods_def<MatvecOp>:
+def matvec(A: f32(M, N), y: f32(N)) -> (x: f32(M)) {
+  x(m) = std_addf<n>(std_mulf(A(m, n), y(n)));
+}
+
 ods_def<BatchMatmulOp>:
 def batch_matmul(A: f32(Batch, M, K), B: f32(Batch, K, N)) -> (C: f32(Batch, M, N)) {
   C(b, m, n) = std_addf<k>(std_mulf(A(b, m, k), B(b, k, n)));

diff  --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
index 1292344484b0..cddd4f9b22f8 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
@@ -197,34 +197,6 @@ def DotOp : LinalgStructured_Op<"dot", [NInputs<2>, NOutputs<1>]> {
   let hasFolder = 1;
 }
 
-def MatvecOp : LinalgStructured_Op<"matvec", [NInputs<2>, NOutputs<1>]> {
-
-  let arguments = (ins AnyStridedMemRefOfRank<2>,
-                       AnyStridedMemRefOfRank<1>,
-                       AnyStridedMemRefOfRank<1>);
-
-  let extraClassDeclaration = libraryCallName # [{
-    llvm::Optional<SmallVector<StringRef, 8>> referenceIterators() {
-      return SmallVector<StringRef, 8>{
-        getParallelIteratorTypeName(), getReductionIteratorTypeName()};
-    }
-
-    // A(i, r_j) * B(r_j) -> C(i)
-    llvm::Optional<SmallVector<AffineMap, 8>> referenceIndexingMaps() {
-      MLIRContext *context = getContext();
-      AffineExpr i, r_j;
-      bindDims(context, i, r_j);
-      return SmallVector<AffineMap, 8>{
-        AffineMap::get(2, 0, {i, r_j}, context),
-        AffineMap::get(2, 0, {r_j}, context),
-        AffineMap::get(2, 0, {i}, context)
-      };
-    }
-  }];
-
-  let hasFolder = 1;
-}
-
 /// A base class for pooling operation such as conv. The arguments must contain
 /// optional arguments `strides`, `dilations` and `padding` with following type:
 ///   OptionalAttr<I64ArrayAttr>:$strides

diff  --git a/mlir/lib/Conversion/LinalgToStandard/LinalgToStandard.cpp b/mlir/lib/Conversion/LinalgToStandard/LinalgToStandard.cpp
index 7b30646136cb..7f13a7a609e9 100644
--- a/mlir/lib/Conversion/LinalgToStandard/LinalgToStandard.cpp
+++ b/mlir/lib/Conversion/LinalgToStandard/LinalgToStandard.cpp
@@ -240,11 +240,11 @@ void mlir::populateLinalgToStandardConversionPatterns(
       LinalgOpConversion<DotOp>,
       LinalgOpConversion<FillOp>,
       LinalgOpConversion<GenericOp>,
-      LinalgOpConversion<IndexedGenericOp>,
-      LinalgOpConversion<MatvecOp>>(ctx);
+      LinalgOpConversion<IndexedGenericOp>>(ctx);
   // TODO: collect all auto-generated named ops with a tblgen directive.
   patterns.insert<
       LinalgOpConversion<BatchMatmulOp>,
+      LinalgOpConversion<MatvecOp>,
       LinalgOpConversion<MatmulOp>>(ctx);
   // clang-format on
 }

diff  --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index 2342cae661c8..c8401977d612 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -1124,10 +1124,6 @@ LogicalResult IndexedGenericOp::fold(ArrayRef<Attribute>,
                                      SmallVectorImpl<OpFoldResult> &) {
   return foldMemRefCast(*this);
 }
-LogicalResult MatvecOp::fold(ArrayRef<Attribute>,
-                             SmallVectorImpl<OpFoldResult> &) {
-  return foldMemRefCast(*this);
-}
 OpFoldResult ReshapeOp::fold(ArrayRef<Attribute>) {
   if (succeeded(foldMemRefCast(*this)))
     return getResult();
@@ -1242,3 +1238,7 @@ LogicalResult MatmulOp::fold(ArrayRef<Attribute>,
                              SmallVectorImpl<OpFoldResult> &) {
   return foldMemRefCast(*this);
 }
+LogicalResult MatvecOp::fold(ArrayRef<Attribute>,
+                             SmallVectorImpl<OpFoldResult> &) {
+  return foldMemRefCast(*this);
+}

diff  --git a/mlir/lib/Dialect/Linalg/Transforms/Loops.cpp b/mlir/lib/Dialect/Linalg/Transforms/Loops.cpp
index 5ccf2a469dce..d031712ce5a9 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Loops.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Loops.cpp
@@ -242,17 +242,6 @@ void emitScalarImplementation(ArrayRef<Value> allIvs, DotOp dotOp) {
   // Emit scalar form.
   C() = C() + A(r_i) * B(r_i);
 }
-template <typename IndexedValueType>
-void emitScalarImplementation(ArrayRef<Value> allIvs, MatvecOp matvecOp) {
-  assert(matvecOp.hasBufferSemantics() &&
-         "expected linalg op with buffer semantics");
-  assert(allIvs.size() == 2);
-  Value i(allIvs[0]), r_j(allIvs[1]);
-  IndexedValueType A(matvecOp.getInput(0)), B(matvecOp.getInput(1)),
-      C(matvecOp.getOutputBuffer(0));
-  // Emit scalar form.
-  C(i) = C(i) + A(i, r_j) * B(r_j);
-}
 
 template <typename IndexedValueType>
 Value getConvOpInput(ConvOp convOp, StdIndexedValue im,
@@ -624,8 +613,6 @@ Optional<LinalgLoops> linalgOpToLoopsImplSwitch(Operation *op,
     return linalgOpToLoopsImpl<LoopTy, FillOp>(op, builder);
   if (isa<DotOp>(op))
     return linalgOpToLoopsImpl<LoopTy, DotOp>(op, builder);
-  if (isa<MatvecOp>(op))
-    return linalgOpToLoopsImpl<LoopTy, MatvecOp>(op, builder);
   if (isa<ConvOp>(op))
     return linalgOpToLoopsImpl<LoopTy, ConvOp>(op, builder);
   if (isa<PoolingMaxOp>(op))
@@ -642,6 +629,8 @@ Optional<LinalgLoops> linalgOpToLoopsImplSwitch(Operation *op,
     return linalgOpToLoopsImpl<LoopTy, GenericOp>(op, builder);
   if (isa<MatmulOp>(op))
     return linalgOpToLoopsImpl<LoopTy, MatmulOp>(op, builder);
+  if (isa<MatvecOp>(op))
+    return linalgOpToLoopsImpl<LoopTy, MatvecOp>(op, builder);
   if (isa<BatchMatmulOp>(op))
     return linalgOpToLoopsImpl<LoopTy, BatchMatmulOp>(op, builder);
   llvm_unreachable("Unexpected op in linalgOpToLoopsImpl");

diff  --git a/mlir/test/Dialect/Linalg/loops.mlir b/mlir/test/Dialect/Linalg/loops.mlir
index 020e43da00b4..f03129c4d8be 100644
--- a/mlir/test/Dialect/Linalg/loops.mlir
+++ b/mlir/test/Dialect/Linalg/loops.mlir
@@ -77,7 +77,7 @@ func @matvec(%arg0: memref<?xi8>, %M: index, %N: index) {
   %2 = view %arg0[%c0][%M, %N] : memref<?xi8> to memref<?x?xf32>
   %3 = view %arg0[%c0][%M] : memref<?xi8> to memref<?xf32>
   %4 = view %arg0[%c0][%N] : memref<?xi8> to memref<?xf32>
-  linalg.matvec(%2, %3, %4) : memref<?x?xf32>, memref<?xf32>, memref<?xf32>
+  linalg.matvec %2, %3, %4 : (memref<?x?xf32>, memref<?xf32>, memref<?xf32>)
   return
 }
 // CHECKLOOP-LABEL: func @matvec(%{{.*}}: memref<?xi8>,

diff  --git a/mlir/test/Dialect/Linalg/roundtrip.mlir b/mlir/test/Dialect/Linalg/roundtrip.mlir
index aaa2890060e6..f210b185331c 100644
--- a/mlir/test/Dialect/Linalg/roundtrip.mlir
+++ b/mlir/test/Dialect/Linalg/roundtrip.mlir
@@ -86,9 +86,9 @@ func @ops(%arg0: memref<?x?xf32, offset: ?, strides: [?, 1]>,
   linalg.matmul %arg0, %arg0, %arg0 : (memref<?x?xf32, offset: ?, strides: [?, 1]>,
                                        memref<?x?xf32, offset: ?, strides: [?, 1]>,
                                        memref<?x?xf32, offset: ?, strides: [?, 1]>)
-  linalg.matvec(%arg0, %arg1, %arg2) : memref<?x?xf32, offset: ?, strides: [?, 1]>,
+  linalg.matvec %arg0, %arg1, %arg2 : (memref<?x?xf32, offset: ?, strides: [?, 1]>,
                                        memref<?xf32, offset: ?, strides: [1]>,
-                                       memref<?xf32, offset: ?, strides: [1]>
+                                       memref<?xf32, offset: ?, strides: [1]>)
   linalg.dot(%arg1, %arg2, %arg3) : memref<?xf32, offset: ?, strides: [1]>,
                                     memref<?xf32, offset: ?, strides: [1]>,
                                     memref<f32>
@@ -99,10 +99,10 @@ func @ops(%arg0: memref<?x?xf32, offset: ?, strides: [?, 1]>,
 //  CHECK-SAME:    (memref<?x?xf32, #[[$strided2D]]>,
 //  CHECK-SAME:     memref<?x?xf32, #[[$strided2D]]>,
 //  CHECK-SAME:     memref<?x?xf32, #[[$strided2D]]>)
-//  CHECK-NEXT:  linalg.matvec(%{{.*}}, %{{.*}}, %{{.*}}) :
-//  CHECK-SAME:     memref<?x?xf32, #[[$strided2D]]>,
+//  CHECK-NEXT:  linalg.matvec %{{.*}}, %{{.*}}, %{{.*}} :
+//  CHECK-SAME:    (memref<?x?xf32, #[[$strided2D]]>,
 //  CHECK-SAME:     memref<?xf32, #[[$strided1D]]>,
-//  CHECK-SAME:     memref<?xf32, #[[$strided1D]]>
+//  CHECK-SAME:     memref<?xf32, #[[$strided1D]]>)
 //  CHECK-NEXT:  linalg.dot(%{{.*}}, %{{.*}}, %{{.*}}) :
 //  CHECK-SAME:     memref<?xf32, #[[$strided1D]]>,
 //  CHECK-SAME:     memref<?xf32, #[[$strided1D]]>,

diff  --git a/mlir/test/Dialect/Linalg/tile.mlir b/mlir/test/Dialect/Linalg/tile.mlir
index a36adf242d63..049fb571bd51 100644
--- a/mlir/test/Dialect/Linalg/tile.mlir
+++ b/mlir/test/Dialect/Linalg/tile.mlir
@@ -199,7 +199,10 @@ func @matmul_static(%arg0: memref<10x16xf32, offset: ?, strides: [?, 1]>,
 //       TILE-234:           memref<?x?xf32, #[[$strided2D]]>)
 
 func @matvec(%arg0: memref<?x?xf32, offset: ?, strides: [?, 1]>, %arg1: memref<?xf32, offset: ?, strides: [1]>, %arg2: memref<?xf32, offset: ?, strides: [1]>) {
-  linalg.matvec(%arg0, %arg1, %arg2) : memref<?x?xf32, offset: ?, strides: [?, 1]>, memref<?xf32, offset: ?, strides: [1]>, memref<?xf32, offset: ?, strides: [1]>
+  linalg.matvec %arg0, %arg1, %arg2 : (
+    memref<?x?xf32, offset: ?, strides: [?, 1]>, 
+    memref<?xf32, offset: ?, strides: [1]>, 
+    memref<?xf32, offset: ?, strides: [1]>)
   return
 }
 // TILE-2-LABEL: func @matvec(
@@ -217,7 +220,7 @@ func @matvec(%arg0: memref<?x?xf32, offset: ?, strides: [?, 1]>, %arg1: memref<?
 //       TILE-2:   %[[localN:.*]] = dim %{{.*}}, %c0
 //       TILE-2:   %[[szN:.*]] = affine.min #[[$bound_map]](%[[I]])[%[[localN]]]
 //       TILE-2:   %[[sCi:.*]] = subview %{{.*}}[%[[I]]] [%[[szN]]] [1] : memref<?xf32, #[[$strided1D]]> to memref<?xf32, #[[$strided1D]]>
-//       TILE-2:   linalg.matvec(%[[sAi]], %{{.*}}, %[[sCi]]) : memref<?x?xf32, #[[$strided2D]]>, memref<?xf32, #[[$strided1D]]>, memref<?xf32, #[[$strided1D]]>
+//       TILE-2:   linalg.matvec %[[sAi]], %{{.*}}, %[[sCi]] : (memref<?x?xf32, #[[$strided2D]]>, memref<?xf32, #[[$strided1D]]>, memref<?xf32, #[[$strided1D]]>)
 
 // TILE-02-LABEL: func @matvec(
 // TILE-02-SAME: %[[ARG0:[0-9a-zA-Z]*]]: memref
@@ -234,7 +237,7 @@ func @matvec(%arg0: memref<?x?xf32, offset: ?, strides: [?, 1]>, %arg1: memref<?
 //       TILE-02:   %[[localN:.*]] = dim %{{.*}}, %c0
 //       TILE-02:   %[[szN:.*]] = affine.min #[[$bound_map]](%[[J]])[%[[localN]]]
 //       TILE-02:   %[[sBj:.*]] = subview %{{.*}}[%[[J]]] [%[[szN]]] [1] : memref<?xf32, #[[$strided1D]]> to memref<?xf32, #[[$strided1D]]>
-//       TILE-02:   linalg.matvec(%[[sAj]], %[[sBj]], %{{.*}}) : memref<?x?xf32, #[[$strided2D]]>, memref<?xf32, #[[$strided1D]]>, memref<?xf32, #[[$strided1D]]>
+//       TILE-02:   linalg.matvec %[[sAj]], %[[sBj]], %{{.*}} : (memref<?x?xf32, #[[$strided2D]]>, memref<?xf32, #[[$strided1D]]>, memref<?xf32, #[[$strided1D]]>)
 
 // TILE-002-LABEL: func @matvec(
 // TILE-002-SAME: %[[ARG0:[0-9a-zA-Z]*]]: memref
@@ -265,7 +268,7 @@ func @matvec(%arg0: memref<?x?xf32, offset: ?, strides: [?, 1]>, %arg1: memref<?
 //       TILE-234:      %[[szM:.*]] = affine.min #[[$bound_map_2]](%[[I]])[%[[localM]]]
 //       TILE-234:      %[[sCi:.*]] = subview %{{.*}}[%[[I]]] [%[[szM]]] [1] : memref<?xf32, #[[$strided1D]]> to memref<?xf32, #[[$strided1D]]>
 //
-//       TILE-234:      linalg.matvec(%[[sAij]], %[[sBj]], %[[sCi]]) : memref<?x?xf32, #[[$strided2D]]>, memref<?xf32, #[[$strided1D]]>, memref<?xf32, #[[$strided1D]]>
+//       TILE-234:      linalg.matvec %[[sAij]], %[[sBj]], %[[sCi]] : (memref<?x?xf32, #[[$strided2D]]>, memref<?xf32, #[[$strided1D]]>, memref<?xf32, #[[$strided1D]]>)
 
 func @dot(%arg0: memref<?xf32, offset: ?, strides: [1]>, %arg1: memref<?xf32, offset: ?, strides: [1]>, %arg2: memref<f32>) {
   linalg.dot(%arg0, %arg1, %arg2) : memref<?xf32, offset: ?, strides: [1]>, memref<?xf32, offset: ?, strides: [1]>, memref<f32>

diff  --git a/mlir/test/Dialect/Linalg/transform-patterns.mlir b/mlir/test/Dialect/Linalg/transform-patterns.mlir
index f2ae0ba76ed0..9eedc31ef43a 100644
--- a/mlir/test/Dialect/Linalg/transform-patterns.mlir
+++ b/mlir/test/Dialect/Linalg/transform-patterns.mlir
@@ -36,10 +36,10 @@ func @dot(%x: memref<?xf32, offset: ?, strides: [1]>,
 func @matvec(%A: memref<?x?xf32, offset: ?, strides: [?, 1]>,
              %x: memref<?xf32, offset: ?, strides: [1]>,
              %y: memref<?xf32, offset: ?, strides: [1]>) {
-  linalg.matvec(%A, %x, %y) :
-                memref<?x?xf32, offset: ?, strides: [?, 1]>,
-                memref<?xf32, offset: ?, strides: [1]>,
-                memref<?xf32, offset: ?, strides: [1]>
+  linalg.matvec %A, %x, %y :
+                (memref<?x?xf32, offset: ?, strides: [?, 1]>,
+                 memref<?xf32, offset: ?, strides: [1]>,
+                 memref<?xf32, offset: ?, strides: [1]>)
   return
 }
 // CHECK-LABEL: func @matvec
@@ -48,7 +48,7 @@ func @matvec(%A: memref<?x?xf32, offset: ?, strides: [?, 1]>,
 // CHECK-DAG:     %[[c6:.*]] = constant 6 : index
 // CHECK:         scf.parallel {{.*}} step (%[[c5]])
 // CHECK:           scf.for {{.*}} step %[[c6]]
-// CHECK:             linalg.matvec({{.*}}, {{.*}}, {{.*}}) : memref<?x?xf32, #[[$STRIDED_2D]]>, memref<?xf32, #[[$STRIDED_1D]]>, memref<?xf32, #[[$STRIDED_1D]]>
+// CHECK:             linalg.matvec {{.*}}, {{.*}}, {{.*}} : (memref<?x?xf32, #[[$STRIDED_2D]]>, memref<?xf32, #[[$STRIDED_1D]]>, memref<?xf32, #[[$STRIDED_1D]]>)
 
 func @matmul(%A: memref<?x?xf32, offset: ?, strides: [?, 1]>,
              %B: memref<?x?xf32, offset: ?, strides: [?, 1]>,
@@ -202,10 +202,10 @@ func @permute_generic_indexed(
 func @matvec_perm(%A: memref<?x?xf32, offset: ?, strides: [?, 1]>,
              %x: memref<?xf32, offset: ?, strides: [1]>,
              %y: memref<?xf32, offset: ?, strides: [1]>) {
-  linalg.matvec(%A, %x, %y) {__internal_linalg_transform__ = "__with_perm__"} :
-               memref<?x?xf32, offset: ?, strides: [?, 1]>,
-               memref<?xf32, offset: ?, strides: [1]>,
-               memref<?xf32, offset: ?, strides: [1]>
+  linalg.matvec %A, %x, %y {__internal_linalg_transform__ = "__with_perm__"} :
+               (memref<?x?xf32, offset: ?, strides: [?, 1]>,
+                memref<?xf32, offset: ?, strides: [1]>,
+                memref<?xf32, offset: ?, strides: [1]>)
   return
 }
 // CHECK-LABEL: func @matvec_perm
@@ -214,7 +214,7 @@ func @matvec_perm(%A: memref<?x?xf32, offset: ?, strides: [?, 1]>,
 // CHECK-DAG:     %[[c6:.*]] = constant 6 : index
 // CHECK:         scf.for {{.*}} = %[[c0]] to {{.*}} step %[[c6]]
 // CHECK:           scf.for {{.*}} = %[[c0]] to {{.*}} step %[[c5]]
-// CHECK:             linalg.matvec({{.*}}, {{.*}}, {{.*}}) : memref<?x?xf32, #[[$STRIDED_2D]]>, memref<?xf32, #[[$STRIDED_1D]]>, memref<?xf32, #[[$STRIDED_1D]]>
+// CHECK:             linalg.matvec {{.*}}, {{.*}}, {{.*}} : (memref<?x?xf32, #[[$STRIDED_2D]]>, memref<?xf32, #[[$STRIDED_1D]]>, memref<?xf32, #[[$STRIDED_1D]]>)
 
 func @matmul_perm(%A: memref<?x?xf32, offset: ?, strides: [?, 1]>,
              %B: memref<?x?xf32, offset: ?, strides: [?, 1]>,


        


More information about the Mlir-commits mailing list