[Mlir-commits] [mlir] 8513ff0 - [mlir][VectorOps][EDSC] Add EDSC for VectorOps
Nicolas Vasilache
llvmlistbot at llvm.org
Mon Feb 10 12:01:49 PST 2020
Author: Nicolas Vasilache
Date: 2020-02-10T15:01:14-05:00
New Revision: 8513ff05c81e60f771aa58846b37840f979a2777
URL: https://github.com/llvm/llvm-project/commit/8513ff05c81e60f771aa58846b37840f979a2777
DIFF: https://github.com/llvm/llvm-project/commit/8513ff05c81e60f771aa58846b37840f979a2777.diff
LOG: [mlir][VectorOps][EDSC] Add EDSC for VectorOps
Summary:
This revision adds EDSC support for VectorOps to enable the creation of a `vector_matmul` declaratively. The `vector_matmul` is a simple configuration
of the `vector.contract` op that follows the StructuredOps abstraction.
Differential Revision: https://reviews.llvm.org/D74284
Added:
mlir/include/mlir/Dialect/VectorOps/EDSC/Builders.h
mlir/include/mlir/Dialect/VectorOps/EDSC/Intrinsics.h
mlir/lib/Dialect/VectorOps/EDSC/Builders.cpp
Modified:
mlir/include/mlir/Dialect/VectorOps/VectorOps.td
mlir/include/mlir/EDSC/Builders.h
mlir/include/mlir/IR/AffineMap.h
mlir/lib/Dialect/Linalg/EDSC/Builders.cpp
mlir/lib/Dialect/VectorOps/CMakeLists.txt
mlir/lib/Dialect/VectorOps/VectorOps.cpp
mlir/lib/IR/AffineMap.cpp
mlir/test/EDSC/CMakeLists.txt
mlir/test/EDSC/builder-api-test.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/VectorOps/EDSC/Builders.h b/mlir/include/mlir/Dialect/VectorOps/EDSC/Builders.h
new file mode 100644
index 000000000000..ba6e6b1ebc9e
--- /dev/null
+++ b/mlir/include/mlir/Dialect/VectorOps/EDSC/Builders.h
@@ -0,0 +1,53 @@
+//===- Builders.h - MLIR Declarative Vector Builders ------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// Provides intuitive composable interfaces for building structured MLIR
+// snippets in a declarative fashion.
+//
+//===----------------------------------------------------------------------===//
+#ifndef MLIR_DIALECT_VECTOR_EDSC_BUILDERS_H_
+#define MLIR_DIALECT_VECTOR_EDSC_BUILDERS_H_
+
+#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
+#include "mlir/Dialect/VectorOps/VectorOps.h"
+#include "mlir/EDSC/Builders.h"
+#include "mlir/EDSC/Intrinsics.h"
+#include "mlir/IR/AffineExpr.h"
+#include "mlir/IR/Builders.h"
+
+namespace mlir {
+namespace edsc {
+namespace ops {
+
+/// Build a generic vector contraction, that is a `vector.contract` op with
+/// specified `iteratorTypes`. The client is responsible for specifying proper
+/// indexings when creating the StructuredIndexed.
+/// The computation represents a notional (A * B + C) where indexings specify
+/// which dimensions are reduced and reordered.
+/// Return the result of the `vector.contract` op.
+///
+/// Prerequisites:
+/// A, B and C capture values of proper vector types, and indexing expressions
+/// that match semantics of the `vector.contract` op.
+Value vector_contraction(StructuredIndexed A, StructuredIndexed B,
+ StructuredIndexed C,
+ ArrayRef<IteratorType> iteratorTypes);
+
+/// Build a generic vector contraction that computes a matmul on vectors.
+/// Return the result of C(i, j) + sum_k {A(i, k) * B(k, j)} on vectors.
+///
+/// Prerequisites:
+/// A, B and C capture values of proper vector types. For instance
+/// `A: vector<4x8xf32>`, `B: vector<8x16xf32>` and `C: vector<4x16xf32>`.
+Value vector_matmul(Value A, Value B, Value C);
+
+} // namespace ops
+} // namespace edsc
+} // namespace mlir
+
+#endif // MLIR_DIALECT_VECTOR_EDSC_BUILDERS_H_
diff --git a/mlir/include/mlir/Dialect/VectorOps/EDSC/Intrinsics.h b/mlir/include/mlir/Dialect/VectorOps/EDSC/Intrinsics.h
new file mode 100644
index 000000000000..a8ffd7648b22
--- /dev/null
+++ b/mlir/include/mlir/Dialect/VectorOps/EDSC/Intrinsics.h
@@ -0,0 +1,23 @@
+//===- Intrinsics.h - MLIR EDSC Intrinsics for VectorOps --------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+#ifndef MLIR_DIALECT_VECTOROPS_EDSC_INTRINSICS_H_
+#define MLIR_DIALECT_VECTOROPS_EDSC_INTRINSICS_H_
+
+#include "mlir/Dialect/VectorOps/EDSC/Builders.h"
+
+namespace mlir {
+namespace edsc {
+namespace intrinsics {
+
+using vector_contract = ValueBuilder<vector::ContractionOp>;
+
+} // namespace intrinsics
+} // namespace edsc
+} // namespace mlir
+
+#endif // MLIR_DIALECT_VECTOROPS_EDSC_INTRINSICS_H_
diff --git a/mlir/include/mlir/Dialect/VectorOps/VectorOps.td b/mlir/include/mlir/Dialect/VectorOps/VectorOps.td
index e7e165aa381f..de7007e3b509 100644
--- a/mlir/include/mlir/Dialect/VectorOps/VectorOps.td
+++ b/mlir/include/mlir/Dialect/VectorOps/VectorOps.td
@@ -141,7 +141,11 @@ def Vector_ContractionOp :
}];
let builders = [OpBuilder<
"Builder *builder, OperationState &result, Value lhs, Value rhs, "
- "Value acc, ArrayAttr indexingMaps, ArrayAttr iteratorTypes">];
+ "Value acc, ArrayAttr indexingMaps, ArrayAttr iteratorTypes">,
+ OpBuilder<
+ "Builder *builder, OperationState &result, Value lhs, Value rhs, "
+ "Value acc, ArrayRef<ArrayRef<AffineExpr>> indexingExprs, "
+ "ArrayRef<StringRef> iteratorTypes">];
let extraClassDeclaration = [{
VectorType getLhsType() {
return lhs().getType().cast<VectorType>();
diff --git a/mlir/include/mlir/EDSC/Builders.h b/mlir/include/mlir/EDSC/Builders.h
index dafa09bc6628..955b6c81cc66 100644
--- a/mlir/include/mlir/EDSC/Builders.h
+++ b/mlir/include/mlir/EDSC/Builders.h
@@ -436,8 +436,9 @@ struct StructuredIndexed : public ValueHandle {
StructuredIndexed(Value v, ArrayRef<AffineExpr> indexings)
: ValueHandle(v), exprs(indexings.begin(), indexings.end()) {
assert((v.getType().isa<MemRefType>() ||
- v.getType().isa<RankedTensorType>()) &&
- "MemRef or RankedTensor expected");
+ v.getType().isa<RankedTensorType>() ||
+ v.getType().isa<VectorType>()) &&
+ "MemRef, RankedTensor or Vector expected");
}
StructuredIndexed(ValueHandle vh, ArrayRef<AffineExpr> indexings)
: ValueHandle(vh), exprs(indexings.begin(), indexings.end()) {}
diff --git a/mlir/include/mlir/IR/AffineMap.h b/mlir/include/mlir/IR/AffineMap.h
index 1311566da323..0bf52e32f3ab 100644
--- a/mlir/include/mlir/IR/AffineMap.h
+++ b/mlir/include/mlir/IR/AffineMap.h
@@ -63,6 +63,14 @@ class AffineMap {
static AffineMap getPermutationMap(ArrayRef<unsigned> permutation,
MLIRContext *context);
+ /// Returns one AffineMap per entry of `exprsList`, with as many results as
+ /// that entry has exprs. All maps share a dim (resp. symbol) count equal to
+ /// one plus the largest dim (resp. symbol) position used in `exprsList`.
+ static SmallVector<AffineMap, 4>
+ inferFromExprList(ArrayRef<ArrayRef<AffineExpr>> exprsList);
+ static SmallVector<AffineMap, 4>
+ inferFromExprList(ArrayRef<SmallVector<AffineExpr, 4>> exprsList);
+
MLIRContext *getContext() const;
explicit operator bool() { return map != nullptr; }
diff --git a/mlir/lib/Dialect/Linalg/EDSC/Builders.cpp b/mlir/lib/Dialect/Linalg/EDSC/Builders.cpp
index f2f8e5551522..296370af03e1 100644
--- a/mlir/lib/Dialect/Linalg/EDSC/Builders.cpp
+++ b/mlir/lib/Dialect/Linalg/EDSC/Builders.cpp
@@ -130,18 +130,6 @@ GenericLoopNestRangeBuilder<loop::ParallelOp>::GenericLoopNestRangeBuilder(
} // namespace edsc
} // namespace mlir
-static void getMaxDimIndex(ArrayRef<StructuredIndexed> structuredIndices,
- unsigned &pos) {
- for (auto sidx : structuredIndices) {
- for (auto expr : sidx.getExprs()) {
- expr.walk([&pos](AffineExpr e) {
- if (auto d = e.dyn_cast<AffineDimExpr>())
- pos = std::max(pos, d.getPosition());
- });
- }
- }
-}
-
Operation *mlir::edsc::makeGenericLinalgOp(
ArrayRef<IteratorType> iteratorTypes, ArrayRef<StructuredIndexed> inputs,
ArrayRef<StructuredIndexed> outputs,
@@ -155,20 +143,16 @@ Operation *mlir::edsc::makeGenericLinalgOp(
auto *ctx = builder.getContext();
unsigned nInputs = inputs.size();
unsigned nOutputs = outputs.size();
- unsigned maxPos = 0;
- getMaxDimIndex(inputs, maxPos);
- getMaxDimIndex(outputs, maxPos);
- // maxPos is 0 indexed, need to turn this into a count (i.e. +1)
- unsigned nDims = maxPos + 1;
-
- SmallVector<AffineMap, 4> maps;
- maps.reserve(nInputs + nOutputs);
- for (auto in : inputs)
- maps.push_back(
- AffineMap::get(/*dimCount=*/nDims, /*symbolCount=*/0, in.getExprs()));
- for (auto out : outputs)
- maps.push_back(
- AffineMap::get(/*dimCount=*/nDims, /*symbolCount=*/0, out.getExprs()));
+
+ SmallVector<SmallVector<AffineExpr, 4>, 4> exprsList;
+ exprsList.reserve(nInputs + nOutputs);
+ for (auto structuredIndexed : inputs)
+ exprsList.emplace_back(structuredIndexed.getExprs().begin(),
+ structuredIndexed.getExprs().end());
+ for (auto structuredIndexed : outputs)
+ exprsList.emplace_back(structuredIndexed.getExprs().begin(),
+ structuredIndexed.getExprs().end());
+ auto maps = AffineMap::inferFromExprList(exprsList);
unsigned nViews = nInputs + nOutputs;
SmallVector<Value, 4> values;
diff --git a/mlir/lib/Dialect/VectorOps/CMakeLists.txt b/mlir/lib/Dialect/VectorOps/CMakeLists.txt
index 41db7fafe03e..2a9071036331 100644
--- a/mlir/lib/Dialect/VectorOps/CMakeLists.txt
+++ b/mlir/lib/Dialect/VectorOps/CMakeLists.txt
@@ -3,6 +3,7 @@ add_llvm_library(MLIRVectorOps
VectorOps.cpp
VectorTransforms.cpp
VectorUtils.cpp
+ EDSC/Builders.cpp
ADDITIONAL_HEADER_DIRS
${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/VectorOps
diff --git a/mlir/lib/Dialect/VectorOps/EDSC/Builders.cpp b/mlir/lib/Dialect/VectorOps/EDSC/Builders.cpp
new file mode 100644
index 000000000000..163000f1cac1
--- /dev/null
+++ b/mlir/lib/Dialect/VectorOps/EDSC/Builders.cpp
@@ -0,0 +1,41 @@
+//===- Builders.cpp - MLIR Declarative Vector Builders --------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/VectorOps/EDSC/Builders.h"
+#include "mlir/Dialect/VectorOps/EDSC/Intrinsics.h"
+#include "mlir/Dialect/VectorOps/VectorOps.h"
+#include "mlir/EDSC/Builders.h"
+#include "mlir/EDSC/Intrinsics.h"
+#include "mlir/IR/AffineExpr.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/Support/Functional.h"
+
+using namespace mlir;
+using namespace mlir::edsc;
+using namespace mlir::edsc::intrinsics;
+using namespace mlir::edsc::ops;
+
+Value mlir::edsc::ops::vector_contraction(
+ StructuredIndexed A, StructuredIndexed B, StructuredIndexed C,
+ ArrayRef<IteratorType> iteratorTypes) {
+ using IndexingExprs = ArrayRef<ArrayRef<AffineExpr>>;
+ return vector_contract(
+ A.getValue(), B.getValue(), C.getValue(),
+ IndexingExprs{A.getExprs(), B.getExprs(), C.getExprs()},
+ ArrayRef<StringRef>{functional::map(toString, iteratorTypes)});
+}
+
+Value mlir::edsc::ops::vector_matmul(Value A, Value B, Value C) {
+ AffineExpr m, n, k;
+ bindDims(ScopedContext::getContext(), m, n, k);
+ return vector_contraction(StructuredIndexed(A, {m, k}),
+ StructuredIndexed(B, {k, n}),
+ StructuredIndexed(C, {m, n}),
+ {IteratorType::Parallel, IteratorType::Parallel,
+ IteratorType::Reduction});
+}
diff --git a/mlir/lib/Dialect/VectorOps/VectorOps.cpp b/mlir/lib/Dialect/VectorOps/VectorOps.cpp
index 35de10270f97..a987a54f5ea1 100644
--- a/mlir/lib/Dialect/VectorOps/VectorOps.cpp
+++ b/mlir/lib/Dialect/VectorOps/VectorOps.cpp
@@ -64,6 +64,19 @@ ArrayAttr vector::getVectorSubscriptAttr(Builder &builder,
// ContractionOp
//===----------------------------------------------------------------------===//
+void vector::ContractionOp::build(Builder *builder, OperationState &result,
+ Value lhs, Value rhs, Value acc,
+ ArrayRef<ArrayRef<AffineExpr>> indexingExprs,
+ ArrayRef<StringRef> iteratorTypes) {
+ result.addOperands({lhs, rhs, acc});
+ result.addTypes(acc.getType());
+ result.addAttribute(getIndexingMapsAttrName(),
+ builder->getAffineMapArrayAttr(
+ AffineMap::inferFromExprList(indexingExprs)));
+ result.addAttribute(getIteratorTypesAttrName(),
+ builder->getStrArrayAttr(iteratorTypes));
+}
+
void vector::ContractionOp::build(Builder *builder, OperationState &result,
Value lhs, Value rhs, Value acc,
ArrayAttr indexingMaps,
diff --git a/mlir/lib/IR/AffineMap.cpp b/mlir/lib/IR/AffineMap.cpp
index 9dd0e96f2cae..72b9fa792fc2 100644
--- a/mlir/lib/IR/AffineMap.cpp
+++ b/mlir/lib/IR/AffineMap.cpp
@@ -111,6 +111,44 @@ AffineMap AffineMap::getPermutationMap(ArrayRef<unsigned> permutation,
return permutationMap;
}
+template <typename AffineExprContainer>
+static void getMaxDimAndSymbol(ArrayRef<AffineExprContainer> exprsList,
+ int64_t &maxDim, int64_t &maxSym) {
+ for (const auto &exprs : exprsList) {
+ for (auto expr : exprs) {
+ expr.walk([&maxDim, &maxSym](AffineExpr e) {
+ if (auto d = e.dyn_cast<AffineDimExpr>())
+ maxDim = std::max(maxDim, static_cast<int64_t>(d.getPosition()));
+ if (auto s = e.dyn_cast<AffineSymbolExpr>())
+ maxSym = std::max(maxSym, static_cast<int64_t>(s.getPosition()));
+ });
+ }
+ }
+}
+
+template <typename AffineExprContainer>
+SmallVector<AffineMap, 4>
+inferFromExprList(ArrayRef<AffineExprContainer> exprsList) {
+ int64_t maxDim = -1, maxSym = -1;
+ getMaxDimAndSymbol(exprsList, maxDim, maxSym);
+ SmallVector<AffineMap, 4> maps;
+ maps.reserve(exprsList.size());
+ for (const auto &exprs : exprsList)
+ maps.push_back(AffineMap::get(/*dimCount=*/maxDim + 1,
+ /*symbolCount=*/maxSym + 1, exprs));
+ return maps;
+}
+
+SmallVector<AffineMap, 4>
+AffineMap::inferFromExprList(ArrayRef<ArrayRef<AffineExpr>> exprsList) {
+ return ::inferFromExprList(exprsList);
+}
+
+SmallVector<AffineMap, 4>
+AffineMap::inferFromExprList(ArrayRef<SmallVector<AffineExpr, 4>> exprsList) {
+ return ::inferFromExprList(exprsList);
+}
+
AffineMap AffineMap::getMultiDimIdentityMap(unsigned numDims,
MLIRContext *context) {
SmallVector<AffineExpr, 4> dimExprs;
diff --git a/mlir/test/EDSC/CMakeLists.txt b/mlir/test/EDSC/CMakeLists.txt
index 0bdf1653bbc1..b922c594dffd 100644
--- a/mlir/test/EDSC/CMakeLists.txt
+++ b/mlir/test/EDSC/CMakeLists.txt
@@ -14,6 +14,7 @@ target_link_libraries(mlir-edsc-builder-api-test
MLIRLoopOps
MLIRStandardOps
MLIRTransforms
+ MLIRVectorOps
LLVMCore
LLVMSupport
)
@@ -25,5 +26,6 @@ whole_archive_link(mlir-edsc-builder-api-test
MLIRLinalgOps
MLIRLoopOps
MLIRStandardOps
+ MLIRVectorOps
MLIRTransforms
)
diff --git a/mlir/test/EDSC/builder-api-test.cpp b/mlir/test/EDSC/builder-api-test.cpp
index 68b7759b036b..92f05680e432 100644
--- a/mlir/test/EDSC/builder-api-test.cpp
+++ b/mlir/test/EDSC/builder-api-test.cpp
@@ -12,6 +12,7 @@
#include "mlir/Dialect/Linalg/EDSC/Intrinsics.h"
#include "mlir/Dialect/LoopOps/EDSC/Builders.h"
#include "mlir/Dialect/StandardOps/EDSC/Intrinsics.h"
+#include "mlir/Dialect/VectorOps/EDSC/Intrinsics.h"
#include "mlir/EDSC/Builders.h"
#include "mlir/EDSC/Intrinsics.h"
#include "mlir/IR/AffineExpr.h"
@@ -981,6 +982,36 @@ TEST_FUNC(linalg_tensors_test) {
f.erase();
}
+// CHECK-LABEL: func @vector_matmul_test(
+// CHECK-SAME: %[[A:.*]]: vector<4x16xf32>,
+// CHECK-SAME: %[[B:.*]]: vector<16x8xf32>,
+// CHECK-SAME: %[[C:.*]]: vector<4x8xf32>)
+// CHECK: vector.contract {{.*}}[affine_map<(d0, d1, d2) -> (d0, d2)>,
+// CHECK-SAME: affine_map<(d0, d1, d2) -> (d2, d1)>,
+// CHECK-SAME: affine_map<(d0, d1, d2) -> (d0, d1)>],
+// CHECK-SAME: {{.*}}["parallel", "parallel", "reduction"]
+// CHECK-SAME: %[[A]], %[[B]], %[[C]]
+// CHECK-SAME: vector<4x16xf32>, vector<16x8xf32> into vector<4x8xf32>
+TEST_FUNC(vector_matmul_test) {
+ using namespace edsc;
+ using namespace edsc::ops;
+
+ int64_t M = 4, N = 8, K = 16;
+ auto f32Type = FloatType::getF32(&globalContext());
+ auto mkVectorType = VectorType::get({M, K}, f32Type);
+ auto knVectorType = VectorType::get({K, N}, f32Type);
+ auto mnVectorType = VectorType::get({M, N}, f32Type);
+ auto f = makeFunction("vector_matmul_test", {},
+ {mkVectorType, knVectorType, mnVectorType});
+
+ OpBuilder builder(f.getBody());
+ ScopedContext scope(builder, f.getLoc());
+ ValueHandle A(f.getArgument(0)), B(f.getArgument(1)), C(f.getArgument(2));
+ vector_matmul(A, B, C);
+ f.print(llvm::outs());
+ f.erase();
+}
+
int main() {
RUN_TESTS();
return 0;
More information about the Mlir-commits
mailing list