[Mlir-commits] [mlir] 137415a - [mlir][EDSC][Linalg] Compose linalg_matmul and vector.contract

Nicolas Vasilache llvmlistbot at llvm.org
Wed Feb 12 10:55:01 PST 2020


Author: Nicolas Vasilache
Date: 2020-02-12T13:50:50-05:00
New Revision: 137415ad285b1d3c1fa1dfb8f44c2ac62d3ebbe4

URL: https://github.com/llvm/llvm-project/commit/137415ad285b1d3c1fa1dfb8f44c2ac62d3ebbe4
DIFF: https://github.com/llvm/llvm-project/commit/137415ad285b1d3c1fa1dfb8f44c2ac62d3ebbe4.diff

LOG: [mlir][EDSC][Linalg] Compose linalg_matmul and vector.contract

Summary:
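This revision allows the model builder to create a linalg_matmul whose body
is a vector.contract, by letting callers pass an optional region builder to
linalg_matmul (the previous scalar multiply / multiply-accumulate bodies remain
the defaults). This shows that the two abstractions compose nicely.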

Differential Revision: https://reviews.llvm.org/D74457

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Linalg/EDSC/Builders.h
    mlir/lib/Dialect/Linalg/EDSC/Builders.cpp
    mlir/test/EDSC/builder-api-test.cpp

Removed: 
    


################################################################################
diff --git a/mlir/include/mlir/Dialect/Linalg/EDSC/Builders.h b/mlir/include/mlir/Dialect/Linalg/EDSC/Builders.h
index 3bd83b433589..2cd80259d33e 100644
--- a/mlir/include/mlir/Dialect/Linalg/EDSC/Builders.h
+++ b/mlir/include/mlir/Dialect/Linalg/EDSC/Builders.h
@@ -18,6 +18,9 @@
 #include "mlir/Dialect/StandardOps/Ops.h"
 #include "mlir/Dialect/Utils/StructuredOpsUtils.h"
 #include "mlir/EDSC/Builders.h"
+#include "mlir/EDSC/Intrinsics.h"
+#include "mlir/IR/AffineExpr.h"
+#include "mlir/IR/Builders.h"
 
 namespace mlir {
 class AffineForOp;
@@ -127,8 +130,12 @@ using edsc::ValueHandle;
 // EDSC builders for linalg generic operations.
 //===----------------------------------------------------------------------===//
 
-/// Build the body of a region to compute a multiply-accumulate, under the
-/// current ScopedContext, at the current insert point.
+/// Build the body of a region to compute a scalar multiply, under the current
+/// ScopedContext, at the current insert point.
+void mulRegionBuilder(ArrayRef<BlockArgument> args);
+
+/// Build the body of a region to compute a scalar multiply-accumulate, under
+/// the current ScopedContext, at the current insert point.
 void macRegionBuilder(ArrayRef<BlockArgument> args);
 
 /// TODO(ntv): In the future we should tie these implementations to something in
@@ -182,6 +189,8 @@ Operation *linalg_pointwise_max(StructuredIndexed I1, StructuredIndexed I2,
 
 // TODO(ntv): Implement more useful pointwise operations on a per-need basis.
 
+using MatmulRegionBuilder = function_ref<void(ArrayRef<BlockArgument> args)>;
+
 /// Build a linalg.generic, under the current ScopedContext, at the current
 /// insert point, that computes:
 /// ```
@@ -189,7 +198,8 @@ Operation *linalg_pointwise_max(StructuredIndexed I1, StructuredIndexed I2,
 ///    |
 ///    |  C(m, n) += A(m, k) * B(k, n)
 /// ```
-Operation *linalg_matmul(ValueHandle vA, ValueHandle vB, ValueHandle vC);
+Operation *linalg_matmul(ValueHandle vA, ValueHandle vB, ValueHandle vC,
+                         MatmulRegionBuilder regionBuilder = macRegionBuilder);
 
 /// Build a linalg.generic, under the current ScopedContext, at the current
 /// insert point, that computes:
@@ -199,7 +209,8 @@ Operation *linalg_matmul(ValueHandle vA, ValueHandle vB, ValueHandle vC);
 ///    |  C(m, n) = sum_k(A(m, k) * B(k, n))
 /// ```
 /// and returns the tensor `C`.
-Operation *linalg_matmul(ValueHandle vA, ValueHandle vB, RankedTensorType tC);
+Operation *linalg_matmul(ValueHandle vA, ValueHandle vB, RankedTensorType tC,
+                         MatmulRegionBuilder regionBuilder = mulRegionBuilder);
 
 /// Build a linalg.generic, under the current ScopedContext, at the current
 /// insert point, that computes:
@@ -210,11 +221,14 @@ Operation *linalg_matmul(ValueHandle vA, ValueHandle vB, RankedTensorType tC);
 /// ```
 /// and returns the tensor `D`.
 Operation *linalg_matmul(ValueHandle vA, ValueHandle vB, ValueHandle vC,
-                         RankedTensorType tD);
+                         RankedTensorType tD,
+                         MatmulRegionBuilder regionBuilder = macRegionBuilder);
 
-template <typename Container> Operation *linalg_matmul(Container values) {
+template <typename Container>
+Operation *linalg_matmul(Container values,
+                         MatmulRegionBuilder regionBuilder = macRegionBuilder) {
   assert(values.size() == 3 && "Expected exactly 3 values");
-  return linalg_matmul(values[0], values[1], values[2]);
+  return linalg_matmul(values[0], values[1], values[2], regionBuilder);
 }
 
 /// Build a linalg.generic, under the current ScopedContext, at the current

diff --git a/mlir/lib/Dialect/Linalg/EDSC/Builders.cpp b/mlir/lib/Dialect/Linalg/EDSC/Builders.cpp
index 296370af03e1..25d6091f9fed 100644
--- a/mlir/lib/Dialect/Linalg/EDSC/Builders.cpp
+++ b/mlir/lib/Dialect/Linalg/EDSC/Builders.cpp
@@ -212,6 +212,14 @@ static void mulRegionBuilder(ArrayRef<BlockArgument> args) {
   linalg_yield((a * b).getValue());
 }
 
+void mlir::edsc::ops::mulRegionBuilder(ArrayRef<BlockArgument> args) {
+  using edsc::op::operator+;
+  using edsc::op::operator*;
+  assert(args.size() == 2 && "expected 2 block arguments");
+  ValueHandle a(args[0]), b(args[1]);
+  linalg_yield((a * b).getValue());
+}
+
 void mlir::edsc::ops::macRegionBuilder(ArrayRef<BlockArgument> args) {
   using edsc::op::operator+;
   using edsc::op::operator*;
@@ -291,7 +299,8 @@ Operation *mlir::edsc::ops::linalg_pointwise_max(StructuredIndexed I1,
 }
 
 Operation *mlir::edsc::ops::linalg_matmul(ValueHandle vA, ValueHandle vB,
-                                          ValueHandle vC) {
+                                          ValueHandle vC,
+                                          MatmulRegionBuilder regionBuilder) {
   // clang-format off
   AffineExpr m, n, k;
   bindDims(ScopedContext::getContext(), m, n, k);
@@ -300,12 +309,13 @@ Operation *mlir::edsc::ops::linalg_matmul(ValueHandle vA, ValueHandle vB,
     {IteratorType::Parallel, IteratorType::Parallel, IteratorType::Reduction},
     {A({m, k}), B({k, n})},
     {C({m, n})},
-    macRegionBuilder);
+    regionBuilder);
   // clang-format on
 }
 
 Operation *mlir::edsc::ops::linalg_matmul(ValueHandle vA, ValueHandle vB,
-                                          RankedTensorType tC) {
+                                          RankedTensorType tC,
+                                          MatmulRegionBuilder regionBuilder) {
   // clang-format off
   AffineExpr m, n, k;
   bindDims(ScopedContext::getContext(), m, n, k);
@@ -314,12 +324,13 @@ Operation *mlir::edsc::ops::linalg_matmul(ValueHandle vA, ValueHandle vB,
     {IteratorType::Parallel, IteratorType::Parallel, IteratorType::Reduction},
     {A({m, k}), B({k, n})},
     {C({m, n})},
-    mulRegionBuilder);
+    regionBuilder);
   // clang-format on
 }
 
 Operation *mlir::edsc::ops::linalg_matmul(ValueHandle vA, ValueHandle vB,
-                                          ValueHandle vC, RankedTensorType tD) {
+                                          ValueHandle vC, RankedTensorType tD,
+                                          MatmulRegionBuilder regionBuilder) {
   // clang-format off
   AffineExpr m, n, k;
   bindDims(ScopedContext::getContext(), m, n, k);
@@ -328,7 +339,7 @@ Operation *mlir::edsc::ops::linalg_matmul(ValueHandle vA, ValueHandle vB,
     {IteratorType::Parallel, IteratorType::Parallel, IteratorType::Reduction},
     {A({m, k}), B({k, n}), C({m, n})},
     {D({m, n})},
-    macRegionBuilder);
+    regionBuilder);
   // clang-format on
 }
 

diff --git a/mlir/test/EDSC/builder-api-test.cpp b/mlir/test/EDSC/builder-api-test.cpp
index ea7c0b530ff8..ea9dde9f8906 100644
--- a/mlir/test/EDSC/builder-api-test.cpp
+++ b/mlir/test/EDSC/builder-api-test.cpp
@@ -990,17 +990,20 @@ TEST_FUNC(linalg_tensors_test) {
   f.erase();
 }
 
-// CHECK-LABEL: func @vector_matmul_test(
-//  CHECK-SAME:   %[[A:.*]]: vector<4x16xf32>,
-//  CHECK-SAME:   %[[B:.*]]: vector<16x8xf32>,
-//  CHECK-SAME:   %[[C:.*]]: vector<4x8xf32>)
-//  CHECK:   vector.contract {{.*}}[affine_map<(d0, d1, d2) -> (d0, d2)>,
-//  CHECK-SAME:                     affine_map<(d0, d1, d2) -> (d2, d1)>,
-//  CHECK-SAME:                     affine_map<(d0, d1, d2) -> (d0, d1)>],
-//  CHECK-SAME:              {{.*}}["parallel", "parallel", "reduction"]
-//  CHECK-SAME: %[[A]], %[[B]], %[[C]]
-//  CHECK-SAME:   vector<4x16xf32>, vector<16x8xf32> into vector<4x8xf32>
-TEST_FUNC(vector_matmul_test) {
+// CHECK-LABEL: func @memref_vector_matmul_test(
+//  CHECK-SAME:   %[[A:.*]]: memref<?x?xvector<4x16xf32>>,
+//  CHECK-SAME:   %[[B:.*]]: memref<?x?xvector<16x8xf32>>,
+//  CHECK-SAME:   %[[C:.*]]: memref<?x?xvector<4x8xf32>>)
+//       CHECK:   linalg.generic {{.*}} %[[A]], %[[B]], %[[C]]
+//       CHECK:     vector.contract{{.*}}[affine_map<(d0, d1, d2) -> (d0,
+//  d2)>,
+//  CHECK-SAME:                       affine_map<(d0, d1, d2) -> (d2, d1)>,
+//  CHECK-SAME:                       affine_map<(d0, d1, d2) -> (d0, d1)>],
+//  CHECK-SAME:                {{.*}}["parallel", "parallel", "reduction"]
+//  CHECK-SAME:     vector<4x16xf32>, vector<16x8xf32> into vector<4x8xf32>
+//       CHECK:   memref<?x?xvector<4x16xf32>>, memref<?x?xvector<16x8xf32>>,
+//  CHECK-SAME:   memref<?x?xvector<4x8xf32>>
+TEST_FUNC(memref_vector_matmul_test) {
   using namespace edsc;
   using namespace edsc::ops;
 
@@ -1009,13 +1012,26 @@ TEST_FUNC(vector_matmul_test) {
   auto mkVectorType = VectorType::get({M, K}, f32Type);
   auto knVectorType = VectorType::get({K, N}, f32Type);
   auto mnVectorType = VectorType::get({M, N}, f32Type);
-  auto f = makeFunction("vector_matmul_test", {},
-                        {mkVectorType, knVectorType, mnVectorType});
+  auto typeA =
+      MemRefType::get({ShapedType::kDynamicSize, ShapedType::kDynamicSize},
+                      mkVectorType, {}, 0);
+  auto typeB =
+      MemRefType::get({ShapedType::kDynamicSize, ShapedType::kDynamicSize},
+                      knVectorType, {}, 0);
+  auto typeC =
+      MemRefType::get({ShapedType::kDynamicSize, ShapedType::kDynamicSize},
+                      mnVectorType, {}, 0);
+  auto f = makeFunction("memref_vector_matmul_test", {}, {typeA, typeB, typeC});
 
   OpBuilder builder(f.getBody());
   ScopedContext scope(builder, f.getLoc());
   ValueHandle A(f.getArgument(0)), B(f.getArgument(1)), C(f.getArgument(2));
-  vector_matmul(A, B, C);
+  auto contractionBuilder = [](ArrayRef<BlockArgument> args) {
+    assert(args.size() == 3 && "expected 3 block arguments");
+    (linalg_yield(vector_matmul(args[0], args[1], args[2])));
+  };
+  linalg_matmul(A, B, C, contractionBuilder);
+
   f.print(llvm::outs());
   f.erase();
 }


        


More information about the Mlir-commits mailing list