[Mlir-commits] [mlir] e31e8f1 - [MLIR][Linalg] Retire C++ MatvecOp in favor of a linalg-ods-gen'd op
Alex Zinenko
llvmlistbot at llvm.org
Thu Jun 18 02:36:59 PDT 2020
Author: lorenzo chelini
Date: 2020-06-18T11:36:49+02:00
New Revision: e31e8f1ed57eb25584903f1a67040babf2c48eda
URL: https://github.com/llvm/llvm-project/commit/e31e8f1ed57eb25584903f1a67040babf2c48eda
DIFF: https://github.com/llvm/llvm-project/commit/e31e8f1ed57eb25584903f1a67040babf2c48eda.diff
LOG: [MLIR][Linalg] Retire C++ MatvecOp in favor of a linalg-ods-gen'd op
Replace the hand-written C++ MatvecOp with a linalg-ods-gen'd op, now that the DRR rules have been dropped.
Differential Revision: https://reviews.llvm.org/D82007
Added:
Modified:
mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOpsSpec.tc
mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
mlir/lib/Conversion/LinalgToStandard/LinalgToStandard.cpp
mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
mlir/lib/Dialect/Linalg/Transforms/Loops.cpp
mlir/test/Dialect/Linalg/loops.mlir
mlir/test/Dialect/Linalg/roundtrip.mlir
mlir/test/Dialect/Linalg/tile.mlir
mlir/test/Dialect/Linalg/transform-patterns.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOpsSpec.tc b/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOpsSpec.tc
index c510010acd0b..bbd398585e5f 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOpsSpec.tc
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOpsSpec.tc
@@ -3,6 +3,11 @@ def matmul(A: f32(M, K), B: f32(K, N)) -> (C: f32(M, N)) {
C(m, n) = std_addf<k>(std_mulf(A(m, k), B(k, n)));
}
+ods_def<MatvecOp>:
+def matvec(A: f32(M, N), y: f32(N)) -> (x: f32(M)) {
+ x(m) = std_addf<n>(std_mulf(A(m, n), y(n)));
+}
+
ods_def<BatchMatmulOp>:
def batch_matmul(A: f32(Batch, M, K), B: f32(Batch, K, N)) -> (C: f32(Batch, M, N)) {
C(b, m, n) = std_addf<k>(std_mulf(A(b, m, k), B(b, k, n)));
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
index 1292344484b0..cddd4f9b22f8 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
@@ -197,34 +197,6 @@ def DotOp : LinalgStructured_Op<"dot", [NInputs<2>, NOutputs<1>]> {
let hasFolder = 1;
}
-def MatvecOp : LinalgStructured_Op<"matvec", [NInputs<2>, NOutputs<1>]> {
-
- let arguments = (ins AnyStridedMemRefOfRank<2>,
- AnyStridedMemRefOfRank<1>,
- AnyStridedMemRefOfRank<1>);
-
- let extraClassDeclaration = libraryCallName # [{
- llvm::Optional<SmallVector<StringRef, 8>> referenceIterators() {
- return SmallVector<StringRef, 8>{
- getParallelIteratorTypeName(), getReductionIteratorTypeName()};
- }
-
- // A(i, r_j) * B(r_j) -> C(i)
- llvm::Optional<SmallVector<AffineMap, 8>> referenceIndexingMaps() {
- MLIRContext *context = getContext();
- AffineExpr i, r_j;
- bindDims(context, i, r_j);
- return SmallVector<AffineMap, 8>{
- AffineMap::get(2, 0, {i, r_j}, context),
- AffineMap::get(2, 0, {r_j}, context),
- AffineMap::get(2, 0, {i}, context)
- };
- }
- }];
-
- let hasFolder = 1;
-}
-
/// A base class for pooling operation such as conv. The arguments must contain
/// optional arguments `strides`, `dilations` and `padding` with following type:
/// OptionalAttr<I64ArrayAttr>:$strides
diff --git a/mlir/lib/Conversion/LinalgToStandard/LinalgToStandard.cpp b/mlir/lib/Conversion/LinalgToStandard/LinalgToStandard.cpp
index 7b30646136cb..7f13a7a609e9 100644
--- a/mlir/lib/Conversion/LinalgToStandard/LinalgToStandard.cpp
+++ b/mlir/lib/Conversion/LinalgToStandard/LinalgToStandard.cpp
@@ -240,11 +240,11 @@ void mlir::populateLinalgToStandardConversionPatterns(
LinalgOpConversion<DotOp>,
LinalgOpConversion<FillOp>,
LinalgOpConversion<GenericOp>,
- LinalgOpConversion<IndexedGenericOp>,
- LinalgOpConversion<MatvecOp>>(ctx);
+ LinalgOpConversion<IndexedGenericOp>>(ctx);
// TODO: collect all auto-generated named ops with a tblgen directive.
patterns.insert<
LinalgOpConversion<BatchMatmulOp>,
+ LinalgOpConversion<MatvecOp>,
LinalgOpConversion<MatmulOp>>(ctx);
// clang-format on
}
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index 2342cae661c8..c8401977d612 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -1124,10 +1124,6 @@ LogicalResult IndexedGenericOp::fold(ArrayRef<Attribute>,
SmallVectorImpl<OpFoldResult> &) {
return foldMemRefCast(*this);
}
-LogicalResult MatvecOp::fold(ArrayRef<Attribute>,
- SmallVectorImpl<OpFoldResult> &) {
- return foldMemRefCast(*this);
-}
OpFoldResult ReshapeOp::fold(ArrayRef<Attribute>) {
if (succeeded(foldMemRefCast(*this)))
return getResult();
@@ -1242,3 +1238,7 @@ LogicalResult MatmulOp::fold(ArrayRef<Attribute>,
SmallVectorImpl<OpFoldResult> &) {
return foldMemRefCast(*this);
}
+LogicalResult MatvecOp::fold(ArrayRef<Attribute>,
+ SmallVectorImpl<OpFoldResult> &) {
+ return foldMemRefCast(*this);
+}
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Loops.cpp b/mlir/lib/Dialect/Linalg/Transforms/Loops.cpp
index 5ccf2a469dce..d031712ce5a9 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Loops.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Loops.cpp
@@ -242,17 +242,6 @@ void emitScalarImplementation(ArrayRef<Value> allIvs, DotOp dotOp) {
// Emit scalar form.
C() = C() + A(r_i) * B(r_i);
}
-template <typename IndexedValueType>
-void emitScalarImplementation(ArrayRef<Value> allIvs, MatvecOp matvecOp) {
- assert(matvecOp.hasBufferSemantics() &&
- "expected linalg op with buffer semantics");
- assert(allIvs.size() == 2);
- Value i(allIvs[0]), r_j(allIvs[1]);
- IndexedValueType A(matvecOp.getInput(0)), B(matvecOp.getInput(1)),
- C(matvecOp.getOutputBuffer(0));
- // Emit scalar form.
- C(i) = C(i) + A(i, r_j) * B(r_j);
-}
template <typename IndexedValueType>
Value getConvOpInput(ConvOp convOp, StdIndexedValue im,
@@ -624,8 +613,6 @@ Optional<LinalgLoops> linalgOpToLoopsImplSwitch(Operation *op,
return linalgOpToLoopsImpl<LoopTy, FillOp>(op, builder);
if (isa<DotOp>(op))
return linalgOpToLoopsImpl<LoopTy, DotOp>(op, builder);
- if (isa<MatvecOp>(op))
- return linalgOpToLoopsImpl<LoopTy, MatvecOp>(op, builder);
if (isa<ConvOp>(op))
return linalgOpToLoopsImpl<LoopTy, ConvOp>(op, builder);
if (isa<PoolingMaxOp>(op))
@@ -642,6 +629,8 @@ Optional<LinalgLoops> linalgOpToLoopsImplSwitch(Operation *op,
return linalgOpToLoopsImpl<LoopTy, GenericOp>(op, builder);
if (isa<MatmulOp>(op))
return linalgOpToLoopsImpl<LoopTy, MatmulOp>(op, builder);
+ if (isa<MatvecOp>(op))
+ return linalgOpToLoopsImpl<LoopTy, MatvecOp>(op, builder);
if (isa<BatchMatmulOp>(op))
return linalgOpToLoopsImpl<LoopTy, BatchMatmulOp>(op, builder);
llvm_unreachable("Unexpected op in linalgOpToLoopsImpl");
diff --git a/mlir/test/Dialect/Linalg/loops.mlir b/mlir/test/Dialect/Linalg/loops.mlir
index 020e43da00b4..f03129c4d8be 100644
--- a/mlir/test/Dialect/Linalg/loops.mlir
+++ b/mlir/test/Dialect/Linalg/loops.mlir
@@ -77,7 +77,7 @@ func @matvec(%arg0: memref<?xi8>, %M: index, %N: index) {
%2 = view %arg0[%c0][%M, %N] : memref<?xi8> to memref<?x?xf32>
%3 = view %arg0[%c0][%M] : memref<?xi8> to memref<?xf32>
%4 = view %arg0[%c0][%N] : memref<?xi8> to memref<?xf32>
- linalg.matvec(%2, %3, %4) : memref<?x?xf32>, memref<?xf32>, memref<?xf32>
+ linalg.matvec %2, %3, %4 : (memref<?x?xf32>, memref<?xf32>, memref<?xf32>)
return
}
// CHECKLOOP-LABEL: func @matvec(%{{.*}}: memref<?xi8>,
diff --git a/mlir/test/Dialect/Linalg/roundtrip.mlir b/mlir/test/Dialect/Linalg/roundtrip.mlir
index aaa2890060e6..f210b185331c 100644
--- a/mlir/test/Dialect/Linalg/roundtrip.mlir
+++ b/mlir/test/Dialect/Linalg/roundtrip.mlir
@@ -86,9 +86,9 @@ func @ops(%arg0: memref<?x?xf32, offset: ?, strides: [?, 1]>,
linalg.matmul %arg0, %arg0, %arg0 : (memref<?x?xf32, offset: ?, strides: [?, 1]>,
memref<?x?xf32, offset: ?, strides: [?, 1]>,
memref<?x?xf32, offset: ?, strides: [?, 1]>)
- linalg.matvec(%arg0, %arg1, %arg2) : memref<?x?xf32, offset: ?, strides: [?, 1]>,
+ linalg.matvec %arg0, %arg1, %arg2 : (memref<?x?xf32, offset: ?, strides: [?, 1]>,
memref<?xf32, offset: ?, strides: [1]>,
- memref<?xf32, offset: ?, strides: [1]>
+ memref<?xf32, offset: ?, strides: [1]>)
linalg.dot(%arg1, %arg2, %arg3) : memref<?xf32, offset: ?, strides: [1]>,
memref<?xf32, offset: ?, strides: [1]>,
memref<f32>
@@ -99,10 +99,10 @@ func @ops(%arg0: memref<?x?xf32, offset: ?, strides: [?, 1]>,
// CHECK-SAME: (memref<?x?xf32, #[[$strided2D]]>,
// CHECK-SAME: memref<?x?xf32, #[[$strided2D]]>,
// CHECK-SAME: memref<?x?xf32, #[[$strided2D]]>)
-// CHECK-NEXT: linalg.matvec(%{{.*}}, %{{.*}}, %{{.*}}) :
-// CHECK-SAME: memref<?x?xf32, #[[$strided2D]]>,
+// CHECK-NEXT: linalg.matvec %{{.*}}, %{{.*}}, %{{.*}} :
+// CHECK-SAME: (memref<?x?xf32, #[[$strided2D]]>,
// CHECK-SAME: memref<?xf32, #[[$strided1D]]>,
-// CHECK-SAME: memref<?xf32, #[[$strided1D]]>
+// CHECK-SAME: memref<?xf32, #[[$strided1D]]>)
// CHECK-NEXT: linalg.dot(%{{.*}}, %{{.*}}, %{{.*}}) :
// CHECK-SAME: memref<?xf32, #[[$strided1D]]>,
// CHECK-SAME: memref<?xf32, #[[$strided1D]]>,
diff --git a/mlir/test/Dialect/Linalg/tile.mlir b/mlir/test/Dialect/Linalg/tile.mlir
index a36adf242d63..049fb571bd51 100644
--- a/mlir/test/Dialect/Linalg/tile.mlir
+++ b/mlir/test/Dialect/Linalg/tile.mlir
@@ -199,7 +199,10 @@ func @matmul_static(%arg0: memref<10x16xf32, offset: ?, strides: [?, 1]>,
// TILE-234: memref<?x?xf32, #[[$strided2D]]>)
func @matvec(%arg0: memref<?x?xf32, offset: ?, strides: [?, 1]>, %arg1: memref<?xf32, offset: ?, strides: [1]>, %arg2: memref<?xf32, offset: ?, strides: [1]>) {
- linalg.matvec(%arg0, %arg1, %arg2) : memref<?x?xf32, offset: ?, strides: [?, 1]>, memref<?xf32, offset: ?, strides: [1]>, memref<?xf32, offset: ?, strides: [1]>
+ linalg.matvec %arg0, %arg1, %arg2 : (
+ memref<?x?xf32, offset: ?, strides: [?, 1]>,
+ memref<?xf32, offset: ?, strides: [1]>,
+ memref<?xf32, offset: ?, strides: [1]>)
return
}
// TILE-2-LABEL: func @matvec(
@@ -217,7 +220,7 @@ func @matvec(%arg0: memref<?x?xf32, offset: ?, strides: [?, 1]>, %arg1: memref<?
// TILE-2: %[[localN:.*]] = dim %{{.*}}, %c0
// TILE-2: %[[szN:.*]] = affine.min #[[$bound_map]](%[[I]])[%[[localN]]]
// TILE-2: %[[sCi:.*]] = subview %{{.*}}[%[[I]]] [%[[szN]]] [1] : memref<?xf32, #[[$strided1D]]> to memref<?xf32, #[[$strided1D]]>
-// TILE-2: linalg.matvec(%[[sAi]], %{{.*}}, %[[sCi]]) : memref<?x?xf32, #[[$strided2D]]>, memref<?xf32, #[[$strided1D]]>, memref<?xf32, #[[$strided1D]]>
+// TILE-2: linalg.matvec %[[sAi]], %{{.*}}, %[[sCi]] : (memref<?x?xf32, #[[$strided2D]]>, memref<?xf32, #[[$strided1D]]>, memref<?xf32, #[[$strided1D]]>)
// TILE-02-LABEL: func @matvec(
// TILE-02-SAME: %[[ARG0:[0-9a-zA-Z]*]]: memref
@@ -234,7 +237,7 @@ func @matvec(%arg0: memref<?x?xf32, offset: ?, strides: [?, 1]>, %arg1: memref<?
// TILE-02: %[[localN:.*]] = dim %{{.*}}, %c0
// TILE-02: %[[szN:.*]] = affine.min #[[$bound_map]](%[[J]])[%[[localN]]]
// TILE-02: %[[sBj:.*]] = subview %{{.*}}[%[[J]]] [%[[szN]]] [1] : memref<?xf32, #[[$strided1D]]> to memref<?xf32, #[[$strided1D]]>
-// TILE-02: linalg.matvec(%[[sAj]], %[[sBj]], %{{.*}}) : memref<?x?xf32, #[[$strided2D]]>, memref<?xf32, #[[$strided1D]]>, memref<?xf32, #[[$strided1D]]>
+// TILE-02: linalg.matvec %[[sAj]], %[[sBj]], %{{.*}} : (memref<?x?xf32, #[[$strided2D]]>, memref<?xf32, #[[$strided1D]]>, memref<?xf32, #[[$strided1D]]>)
// TILE-002-LABEL: func @matvec(
// TILE-002-SAME: %[[ARG0:[0-9a-zA-Z]*]]: memref
@@ -265,7 +268,7 @@ func @matvec(%arg0: memref<?x?xf32, offset: ?, strides: [?, 1]>, %arg1: memref<?
// TILE-234: %[[szM:.*]] = affine.min #[[$bound_map_2]](%[[I]])[%[[localM]]]
// TILE-234: %[[sCi:.*]] = subview %{{.*}}[%[[I]]] [%[[szM]]] [1] : memref<?xf32, #[[$strided1D]]> to memref<?xf32, #[[$strided1D]]>
//
-// TILE-234: linalg.matvec(%[[sAij]], %[[sBj]], %[[sCi]]) : memref<?x?xf32, #[[$strided2D]]>, memref<?xf32, #[[$strided1D]]>, memref<?xf32, #[[$strided1D]]>
+// TILE-234: linalg.matvec %[[sAij]], %[[sBj]], %[[sCi]] : (memref<?x?xf32, #[[$strided2D]]>, memref<?xf32, #[[$strided1D]]>, memref<?xf32, #[[$strided1D]]>)
func @dot(%arg0: memref<?xf32, offset: ?, strides: [1]>, %arg1: memref<?xf32, offset: ?, strides: [1]>, %arg2: memref<f32>) {
linalg.dot(%arg0, %arg1, %arg2) : memref<?xf32, offset: ?, strides: [1]>, memref<?xf32, offset: ?, strides: [1]>, memref<f32>
diff --git a/mlir/test/Dialect/Linalg/transform-patterns.mlir b/mlir/test/Dialect/Linalg/transform-patterns.mlir
index f2ae0ba76ed0..9eedc31ef43a 100644
--- a/mlir/test/Dialect/Linalg/transform-patterns.mlir
+++ b/mlir/test/Dialect/Linalg/transform-patterns.mlir
@@ -36,10 +36,10 @@ func @dot(%x: memref<?xf32, offset: ?, strides: [1]>,
func @matvec(%A: memref<?x?xf32, offset: ?, strides: [?, 1]>,
%x: memref<?xf32, offset: ?, strides: [1]>,
%y: memref<?xf32, offset: ?, strides: [1]>) {
- linalg.matvec(%A, %x, %y) :
- memref<?x?xf32, offset: ?, strides: [?, 1]>,
- memref<?xf32, offset: ?, strides: [1]>,
- memref<?xf32, offset: ?, strides: [1]>
+ linalg.matvec %A, %x, %y :
+ (memref<?x?xf32, offset: ?, strides: [?, 1]>,
+ memref<?xf32, offset: ?, strides: [1]>,
+ memref<?xf32, offset: ?, strides: [1]>)
return
}
// CHECK-LABEL: func @matvec
@@ -48,7 +48,7 @@ func @matvec(%A: memref<?x?xf32, offset: ?, strides: [?, 1]>,
// CHECK-DAG: %[[c6:.*]] = constant 6 : index
// CHECK: scf.parallel {{.*}} step (%[[c5]])
// CHECK: scf.for {{.*}} step %[[c6]]
-// CHECK: linalg.matvec({{.*}}, {{.*}}, {{.*}}) : memref<?x?xf32, #[[$STRIDED_2D]]>, memref<?xf32, #[[$STRIDED_1D]]>, memref<?xf32, #[[$STRIDED_1D]]>
+// CHECK: linalg.matvec {{.*}}, {{.*}}, {{.*}} : (memref<?x?xf32, #[[$STRIDED_2D]]>, memref<?xf32, #[[$STRIDED_1D]]>, memref<?xf32, #[[$STRIDED_1D]]>)
func @matmul(%A: memref<?x?xf32, offset: ?, strides: [?, 1]>,
%B: memref<?x?xf32, offset: ?, strides: [?, 1]>,
@@ -202,10 +202,10 @@ func @permute_generic_indexed(
func @matvec_perm(%A: memref<?x?xf32, offset: ?, strides: [?, 1]>,
%x: memref<?xf32, offset: ?, strides: [1]>,
%y: memref<?xf32, offset: ?, strides: [1]>) {
- linalg.matvec(%A, %x, %y) {__internal_linalg_transform__ = "__with_perm__"} :
- memref<?x?xf32, offset: ?, strides: [?, 1]>,
- memref<?xf32, offset: ?, strides: [1]>,
- memref<?xf32, offset: ?, strides: [1]>
+ linalg.matvec %A, %x, %y {__internal_linalg_transform__ = "__with_perm__"} :
+ (memref<?x?xf32, offset: ?, strides: [?, 1]>,
+ memref<?xf32, offset: ?, strides: [1]>,
+ memref<?xf32, offset: ?, strides: [1]>)
return
}
// CHECK-LABEL: func @matvec_perm
@@ -214,7 +214,7 @@ func @matvec_perm(%A: memref<?x?xf32, offset: ?, strides: [?, 1]>,
// CHECK-DAG: %[[c6:.*]] = constant 6 : index
// CHECK: scf.for {{.*}} = %[[c0]] to {{.*}} step %[[c6]]
// CHECK: scf.for {{.*}} = %[[c0]] to {{.*}} step %[[c5]]
-// CHECK: linalg.matvec({{.*}}, {{.*}}, {{.*}}) : memref<?x?xf32, #[[$STRIDED_2D]]>, memref<?xf32, #[[$STRIDED_1D]]>, memref<?xf32, #[[$STRIDED_1D]]>
+// CHECK: linalg.matvec {{.*}}, {{.*}}, {{.*}} : (memref<?x?xf32, #[[$STRIDED_2D]]>, memref<?xf32, #[[$STRIDED_1D]]>, memref<?xf32, #[[$STRIDED_1D]]>)
func @matmul_perm(%A: memref<?x?xf32, offset: ?, strides: [?, 1]>,
%B: memref<?x?xf32, offset: ?, strides: [?, 1]>,
More information about the Mlir-commits mailing list