[Mlir-commits] [mlir] 137415a - [mlir][EDSC][Linalg] Compose linalg_matmul and vector.contract
Nicolas Vasilache
llvmlistbot at llvm.org
Wed Feb 12 10:55:01 PST 2020
Author: Nicolas Vasilache
Date: 2020-02-12T13:50:50-05:00
New Revision: 137415ad285b1d3c1fa1dfb8f44c2ac62d3ebbe4
URL: https://github.com/llvm/llvm-project/commit/137415ad285b1d3c1fa1dfb8f44c2ac62d3ebbe4
DIFF: https://github.com/llvm/llvm-project/commit/137415ad285b1d3c1fa1dfb8f44c2ac62d3ebbe4.diff
LOG: [mlir][EDSC][Linalg] Compose linalg_matmul and vector.contract
Summary:
This revision allows the model builder to create a linalg_matmul whose body
is a vector.contract. This shows that the abstractions compose nicely.
Differential Revision: https://reviews.llvm.org/D74457
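
For illustration, here is a minimal usage sketch of the composition this
enables, adapted from the new test case below. It is only a sketch: it assumes
the builder-api-test harness (the makeFunction helper, ScopedContext, and the
EDSC intrinsics in scope), and f stands for a FuncOp whose three arguments are
memrefs of vectors.

  using namespace mlir;
  using namespace mlir::edsc;
  using namespace mlir::edsc::ops;

  // f: FuncOp with arguments of type memref<?x?xvector<...>>.
  OpBuilder builder(f.getBody());
  ScopedContext scope(builder, f.getLoc());
  ValueHandle A(f.getArgument(0)), B(f.getArgument(1)), C(f.getArgument(2));

  // The region of the generated linalg.generic is a single vector.contract:
  // vector_matmul builds the contraction, linalg_yield terminates the body.
  auto contractionBuilder = [](ArrayRef<BlockArgument> args) {
    assert(args.size() == 3 && "expected 3 block arguments");
    linalg_yield(vector_matmul(args[0], args[1], args[2]));
  };
  linalg_matmul(A, B, C, contractionBuilder);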
Added:
Modified:
mlir/include/mlir/Dialect/Linalg/EDSC/Builders.h
mlir/lib/Dialect/Linalg/EDSC/Builders.cpp
mlir/test/EDSC/builder-api-test.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Linalg/EDSC/Builders.h b/mlir/include/mlir/Dialect/Linalg/EDSC/Builders.h
index 3bd83b433589..2cd80259d33e 100644
--- a/mlir/include/mlir/Dialect/Linalg/EDSC/Builders.h
+++ b/mlir/include/mlir/Dialect/Linalg/EDSC/Builders.h
@@ -18,6 +18,9 @@
#include "mlir/Dialect/StandardOps/Ops.h"
#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
#include "mlir/EDSC/Builders.h"
+#include "mlir/EDSC/Intrinsics.h"
+#include "mlir/IR/AffineExpr.h"
+#include "mlir/IR/Builders.h"
namespace mlir {
class AffineForOp;
@@ -127,8 +130,12 @@ using edsc::ValueHandle;
// EDSC builders for linalg generic operations.
//===----------------------------------------------------------------------===//
-/// Build the body of a region to compute a multiply-accumulate, under the
-/// current ScopedContext, at the current insert point.
+/// Build the body of a region to compute a scalar multiply, under the current
+/// ScopedContext, at the current insert point.
+void mulRegionBuilder(ArrayRef<BlockArgument> args);
+
+/// Build the body of a region to compute a scalar multiply-accumulate, under
+/// the current ScopedContext, at the current insert point.
void macRegionBuilder(ArrayRef<BlockArgument> args);
/// TODO(ntv): In the future we should tie these implementations to something in
@@ -182,6 +189,8 @@ Operation *linalg_pointwise_max(StructuredIndexed I1, StructuredIndexed I2,
// TODO(ntv): Implement more useful pointwise operations on a per-need basis.
+using MatmulRegionBuilder = function_ref<void(ArrayRef<BlockArgument> args)>;
+
/// Build a linalg.generic, under the current ScopedContext, at the current
/// insert point, that computes:
/// ```
@@ -189,7 +198,8 @@ Operation *linalg_pointwise_max(StructuredIndexed I1, StructuredIndexed I2,
/// |
/// | C(m, n) += A(m, k) * B(k, n)
/// ```
-Operation *linalg_matmul(ValueHandle vA, ValueHandle vB, ValueHandle vC);
+Operation *linalg_matmul(ValueHandle vA, ValueHandle vB, ValueHandle vC,
+ MatmulRegionBuilder regionBuilder = macRegionBuilder);
/// Build a linalg.generic, under the current ScopedContext, at the current
/// insert point, that computes:
@@ -199,7 +209,8 @@ Operation *linalg_matmul(ValueHandle vA, ValueHandle vB, ValueHandle vC);
/// | C(m, n) = sum_k(A(m, k) * B(k, n))
/// ```
/// and returns the tensor `C`.
-Operation *linalg_matmul(ValueHandle vA, ValueHandle vB, RankedTensorType tC);
+Operation *linalg_matmul(ValueHandle vA, ValueHandle vB, RankedTensorType tC,
+ MatmulRegionBuilder regionBuilder = mulRegionBuilder);
/// Build a linalg.generic, under the current ScopedContext, at the current
/// insert point, that computes:
@@ -210,11 +221,14 @@ Operation *linalg_matmul(ValueHandle vA, ValueHandle vB, RankedTensorType tC);
/// ```
/// and returns the tensor `D`.
Operation *linalg_matmul(ValueHandle vA, ValueHandle vB, ValueHandle vC,
- RankedTensorType tD);
+ RankedTensorType tD,
+ MatmulRegionBuilder regionBuilder = macRegionBuilder);
-template <typename Container> Operation *linalg_matmul(Container values) {
+template <typename Container>
+Operation *linalg_matmul(Container values,
+ MatmulRegionBuilder regionBuilder = macRegionBuilder) {
assert(values.size() == 3 && "Expected exactly 3 values");
- return linalg_matmul(values[0], values[1], values[2]);
+ return linalg_matmul(values[0], values[1], values[2], regionBuilder);
}
/// Build a linalg.generic, under the current ScopedContext, at the current
diff --git a/mlir/lib/Dialect/Linalg/EDSC/Builders.cpp b/mlir/lib/Dialect/Linalg/EDSC/Builders.cpp
index 296370af03e1..25d6091f9fed 100644
--- a/mlir/lib/Dialect/Linalg/EDSC/Builders.cpp
+++ b/mlir/lib/Dialect/Linalg/EDSC/Builders.cpp
@@ -212,6 +212,14 @@ static void mulRegionBuilder(ArrayRef<BlockArgument> args) {
linalg_yield((a * b).getValue());
}
+void mlir::edsc::ops::mulRegionBuilder(ArrayRef<BlockArgument> args) {
+ using edsc::op::operator+;
+ using edsc::op::operator*;
+ assert(args.size() == 2 && "expected 2 block arguments");
+ ValueHandle a(args[0]), b(args[1]);
+ linalg_yield((a * b).getValue());
+}
+
void mlir::edsc::ops::macRegionBuilder(ArrayRef<BlockArgument> args) {
using edsc::op::operator+;
using edsc::op::operator*;
@@ -291,7 +299,8 @@ Operation *mlir::edsc::ops::linalg_pointwise_max(StructuredIndexed I1,
}
Operation *mlir::edsc::ops::linalg_matmul(ValueHandle vA, ValueHandle vB,
- ValueHandle vC) {
+ ValueHandle vC,
+ MatmulRegionBuilder regionBuilder) {
// clang-format off
AffineExpr m, n, k;
bindDims(ScopedContext::getContext(), m, n, k);
@@ -300,12 +309,13 @@ Operation *mlir::edsc::ops::linalg_matmul(ValueHandle vA, ValueHandle vB,
{IteratorType::Parallel, IteratorType::Parallel, IteratorType::Reduction},
{A({m, k}), B({k, n})},
{C({m, n})},
- macRegionBuilder);
+ regionBuilder);
// clang-format on
}
Operation *mlir::edsc::ops::linalg_matmul(ValueHandle vA, ValueHandle vB,
- RankedTensorType tC) {
+ RankedTensorType tC,
+ MatmulRegionBuilder regionBuilder) {
// clang-format off
AffineExpr m, n, k;
bindDims(ScopedContext::getContext(), m, n, k);
@@ -314,12 +324,13 @@ Operation *mlir::edsc::ops::linalg_matmul(ValueHandle vA, ValueHandle vB,
{IteratorType::Parallel, IteratorType::Parallel, IteratorType::Reduction},
{A({m, k}), B({k, n})},
{C({m, n})},
- mulRegionBuilder);
+ regionBuilder);
// clang-format on
}
Operation *mlir::edsc::ops::linalg_matmul(ValueHandle vA, ValueHandle vB,
- ValueHandle vC, RankedTensorType tD) {
+ ValueHandle vC, RankedTensorType tD,
+ MatmulRegionBuilder regionBuilder) {
// clang-format off
AffineExpr m, n, k;
bindDims(ScopedContext::getContext(), m, n, k);
@@ -328,7 +339,7 @@ Operation *mlir::edsc::ops::linalg_matmul(ValueHandle vA, ValueHandle vB,
{IteratorType::Parallel, IteratorType::Parallel, IteratorType::Reduction},
{A({m, k}), B({k, n}), C({m, n})},
{D({m, n})},
- macRegionBuilder);
+ regionBuilder);
// clang-format on
}
diff --git a/mlir/test/EDSC/builder-api-test.cpp b/mlir/test/EDSC/builder-api-test.cpp
index ea7c0b530ff8..ea9dde9f8906 100644
--- a/mlir/test/EDSC/builder-api-test.cpp
+++ b/mlir/test/EDSC/builder-api-test.cpp
@@ -990,17 +990,20 @@ TEST_FUNC(linalg_tensors_test) {
f.erase();
}
-// CHECK-LABEL: func @vector_matmul_test(
-// CHECK-SAME: %[[A:.*]]: vector<4x16xf32>,
-// CHECK-SAME: %[[B:.*]]: vector<16x8xf32>,
-// CHECK-SAME: %[[C:.*]]: vector<4x8xf32>)
-// CHECK: vector.contract {{.*}}[affine_map<(d0, d1, d2) -> (d0, d2)>,
-// CHECK-SAME: affine_map<(d0, d1, d2) -> (d2, d1)>,
-// CHECK-SAME: affine_map<(d0, d1, d2) -> (d0, d1)>],
-// CHECK-SAME: {{.*}}["parallel", "parallel", "reduction"]
-// CHECK-SAME: %[[A]], %[[B]], %[[C]]
-// CHECK-SAME: vector<4x16xf32>, vector<16x8xf32> into vector<4x8xf32>
-TEST_FUNC(vector_matmul_test) {
+// CHECK-LABEL: func @memref_vector_matmul_test(
+// CHECK-SAME: %[[A:.*]]: memref<?x?xvector<4x16xf32>>,
+// CHECK-SAME: %[[B:.*]]: memref<?x?xvector<16x8xf32>>,
+// CHECK-SAME: %[[C:.*]]: memref<?x?xvector<4x8xf32>>)
+// CHECK: linalg.generic {{.*}} %[[A]], %[[B]], %[[C]]
+// CHECK: vector.contract{{.*}}[affine_map<(d0, d1, d2) -> (d0, d2)>,
+// CHECK-SAME: affine_map<(d0, d1, d2) -> (d2, d1)>,
+// CHECK-SAME: affine_map<(d0, d1, d2) -> (d0, d1)>],
+// CHECK-SAME: {{.*}}["parallel", "parallel", "reduction"]
+// CHECK-SAME: vector<4x16xf32>, vector<16x8xf32> into vector<4x8xf32>
+// CHECK: memref<?x?xvector<4x16xf32>>, memref<?x?xvector<16x8xf32>>,
+// CHECK-SAME: memref<?x?xvector<4x8xf32>>
+TEST_FUNC(memref_vector_matmul_test) {
using namespace edsc;
using namespace edsc::ops;
@@ -1009,13 +1012,26 @@ TEST_FUNC(vector_matmul_test) {
auto mkVectorType = VectorType::get({M, K}, f32Type);
auto knVectorType = VectorType::get({K, N}, f32Type);
auto mnVectorType = VectorType::get({M, N}, f32Type);
- auto f = makeFunction("vector_matmul_test", {},
- {mkVectorType, knVectorType, mnVectorType});
+ auto typeA =
+ MemRefType::get({ShapedType::kDynamicSize, ShapedType::kDynamicSize},
+ mkVectorType, {}, 0);
+ auto typeB =
+ MemRefType::get({ShapedType::kDynamicSize, ShapedType::kDynamicSize},
+ knVectorType, {}, 0);
+ auto typeC =
+ MemRefType::get({ShapedType::kDynamicSize, ShapedType::kDynamicSize},
+ mnVectorType, {}, 0);
+ auto f = makeFunction("memref_vector_matmul_test", {}, {typeA, typeB, typeC});
OpBuilder builder(f.getBody());
ScopedContext scope(builder, f.getLoc());
ValueHandle A(f.getArgument(0)), B(f.getArgument(1)), C(f.getArgument(2));
- vector_matmul(A, B, C);
+ auto contractionBuilder = [](ArrayRef<BlockArgument> args) {
+ assert(args.size() == 3 && "expected 3 block arguments");
+ (linalg_yield(vector_matmul(args[0], args[1], args[2])));
+ };
+ linalg_matmul(A, B, C, contractionBuilder);
+
f.print(llvm::outs());
f.erase();
}
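
A note on API compatibility (illustrative, not part of the commit): the new
regionBuilder parameters are defaulted, so existing call sites keep building
the scalar bodies. Only callers that want a different body, such as the
vector.contract above, need to pass a builder. A sketch, assuming ValueHandles
A, B, C for memref operands under a ScopedContext:

  linalg_matmul(A, B, C);                     // scalar mac via macRegionBuilder
  linalg_matmul(A, B, C, macRegionBuilder);   // equivalent, builder spelled out
  linalg_matmul(A, B, C, contractionBuilder); // custom body, e.g. the
                                              // vector.contract lambda above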