[Mlir-commits] [mlir] 56c638b - [mlir][Linalg] Generalize Vectorization of Linalg contractions
Nicolas Vasilache
llvmlistbot at llvm.org
Fri Jul 10 07:29:46 PDT 2020
Author: Nicolas Vasilache
Date: 2020-07-10T10:28:34-04:00
New Revision: 56c638b5c1caf018da3fa1a95b603267e607c89c
URL: https://github.com/llvm/llvm-project/commit/56c638b5c1caf018da3fa1a95b603267e607c89c
DIFF: https://github.com/llvm/llvm-project/commit/56c638b5c1caf018da3fa1a95b603267e607c89c.diff
LOG: [mlir][Linalg] Generalize Vectorization of Linalg contractions
This revision adds support for vectorizing named and generic contraction ops to vector.contract. Operands backed by 0-D memrefs are special-cased to use std.load/std.store instead of vector.transfer_read/vector.transfer_write. Relevant tests are added.
Differential revision: https://reviews.llvm.org/D83307
Added:
Modified:
mlir/include/mlir/IR/Attributes.h
mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
mlir/test/Dialect/Linalg/transform-patterns-matmul-to-vector.mlir
mlir/test/lib/Transforms/TestLinalgTransforms.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/IR/Attributes.h b/mlir/include/mlir/IR/Attributes.h
index 6436bb9550e8..89dad2ec40cf 100644
--- a/mlir/include/mlir/IR/Attributes.h
+++ b/mlir/include/mlir/IR/Attributes.h
@@ -286,6 +286,12 @@ class ArrayAttr : public Attribute::AttrBase<ArrayAttr, Attribute,
return llvm::make_range(attr_value_iterator<AttrTy>(begin()),
attr_value_iterator<AttrTy>(end()));
}
+ /// Return the elements of this array as a range of `UnderlyingTy`: each
+ /// element is viewed as an `AttrTy` (via `getAsRange<AttrTy>()`) and its
+ /// `getValue()` is `static_cast` to `UnderlyingTy`. The range is built with
+ /// `llvm::map_range`, so the conversion happens as the range is iterated.
+ /// Example: `getAsRange<AffineMapAttr, AffineMap>()`.
+ template <typename AttrTy, typename UnderlyingTy>
+ auto getAsRange() {
+ return llvm::map_range(getAsRange<AttrTy>(), [](AttrTy attr) {
+ return static_cast<UnderlyingTy>(attr.getValue());
+ });
+ }
};
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
index bba7b2a10030..bbdb8e7b46b4 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -36,8 +36,7 @@ using llvm::dbgs;
#define DEBUG_TYPE "linalg-vectorization"
-static bool hasMultiplyAddBody(linalg::GenericOp op) {
- auto &r = op.region();
+/// Returns true if region `r` is a multiply-add body. It takes the region
+/// directly (instead of a linalg.generic op) so any op's region can be tested.
+/// Structurally, `r` must be a single block holding exactly three operations.
+static bool hasMultiplyAddBody(Region &r) {
if (!llvm::hasSingleElement(r))
return false;
if (!llvm::hasNItems(r.front().begin(), r.front().end(), 3))
@@ -59,14 +58,26 @@ static bool hasMultiplyAddBody(linalg::GenericOp op) {
}
// TODO: Should be Tablegen'd from a single source that generates the op itself.
-static bool isRowMajorMatmul(linalg::GenericOp genericOp) {
- return genericOp.getNumInputs() == 2 && genericOp.getNumOutputs() == 1 &&
- isRowMajorMatmul(genericOp.indexing_maps()) &&
- hasMultiplyAddBody(genericOp);
+/// Returns success if `op` can be rewritten as a vector.contract. Two cases
+/// qualify:
+///   - one of the named contraction ops (batch_matmul, matmul, matvec, dot);
+///   - a linalg.generic with two inputs, one output, at least one reduction
+///     loop, projected-permutation indexing maps and a multiply-add body.
+static LogicalResult isContraction(Operation *op) {
+ // TODO: interface for named ops.
+ if (isa<linalg::BatchMatmulOp, linalg::MatmulOp, linalg::MatvecOp,
+ linalg::DotOp>(op))
+ return success();
+
+ auto genericOp = dyn_cast<linalg::GenericOp>(op);
+ if (!genericOp)
+ return failure();
+
+ auto mapRange =
+ genericOp.indexing_maps().getAsRange<AffineMapAttr, AffineMap>();
+
+ // vector.contract needs at least one contracting (reduction) dimension. A
+ // generic op with only parallel loops, such as an elementwise `c += a * b`,
+ // is not a contraction and must not be rewritten as one.
+ return success(
+ genericOp.getNumInputs() == 2 && genericOp.getNumOutputs() == 1 &&
+ genericOp.getNumReductionLoops() > 0 &&
+ llvm::all_of(mapRange,
+ [](AffineMap m) { return m.isProjectedPermutation(); }) &&
+ hasMultiplyAddBody(genericOp.region()));
}
-// TODO: This is in fact much more general than just vectorization for matmul
-// and fill ops.
LogicalResult mlir::linalg::vectorizeLinalgOpPrecondition(Operation *op) {
auto linalgOp = cast<linalg::LinalgOp>(op);
// All types must be static shape to go to vector.
@@ -76,33 +87,16 @@ LogicalResult mlir::linalg::vectorizeLinalgOpPrecondition(Operation *op) {
for (Type outputTensorType : linalgOp.getOutputTensorTypes())
if (!outputTensorType.cast<ShapedType>().hasStaticShape())
return failure();
- if (isa<linalg::MatmulOp, linalg::FillOp>(op))
- return success();
- auto genericOp = dyn_cast<linalg::GenericOp>(op);
- if (!genericOp || !::isRowMajorMatmul(genericOp))
- return failure();
+ // linalg.fill has no contraction structure to check; it is always accepted.
+ if (isa<linalg::FillOp>(op))
+ return success();
- // TODO: non-identity layout.
- auto isStaticMemRefWithIdentityLayout = [](Value v) {
- auto m = v.getType().dyn_cast<MemRefType>();
- if (!m || !m.hasStaticShape() || !m.getAffineMaps().empty())
- return false;
- return true;
- };
- return success(llvm::all_of(genericOp.getInputsAndOutputBuffers(),
- isStaticMemRefWithIdentityLayout));
+ // Any other op is accepted only if it can be rewritten as vector.contract.
+ // NOTE(review): buffers with non-identity layouts are no longer rejected
+ // here; vectorizeLinalgOp accesses them through vector.transfer ops. Confirm
+ // this is intended.
+ return isContraction(op);
}
void mlir::linalg::vectorizeLinalgOp(OpBuilder &builder, Operation *op) {
assert(succeeded(vectorizeLinalgOpPrecondition(op)));
- if (auto convOp = dyn_cast<linalg::ConvOp>(op)) {
- // TODO: add a level of indirection to linalg.generic.
- if (convOp.padding())
- llvm_unreachable("Unexpected conv with padding");
- }
-
StringRef dbgPref = "\n[" DEBUG_TYPE "]: ";
(void)dbgPref;
edsc::ScopedContext scope(builder, op->getLoc());
@@ -117,33 +111,47 @@ void mlir::linalg::vectorizeLinalgOp(OpBuilder &builder, Operation *op) {
return;
}
- // Vectorize other ops as vector contraction (currently only matmul).
+ assert(succeeded(isContraction(op)) && "Expected contraction");
+
+ // Vectorize other ops as vector contraction.
+ // TODO: interface.
LLVM_DEBUG(dbgs() << dbgPref
<< "Rewrite linalg op as vector.contract: " << *op);
+ // In the case of 0-D memrefs, return null and special case to scalar load or
+ // store later.
auto extractVectorTypeFromScalarView = [](Value v) {
MemRefType mt = v.getType().cast<MemRefType>();
- return VectorType::get(mt.getShape(), mt.getElementType());
+ return mt.getShape().empty()
+ ? VectorType()
+ : VectorType::get(mt.getShape(), mt.getElementType());
};
auto linalgOp = cast<linalg::LinalgOp>(op);
Value viewA = linalgOp.getInput(0);
Value viewB = linalgOp.getInput(1);
Value viewC = linalgOp.getOutputBuffer(0);
+ // A null VectorType marks a 0-D memref. Such operands are read with std_load
+ // and written with std_store below instead of vector transfers.
+ VectorType vtA = extractVectorTypeFromScalarView(viewA);
+ VectorType vtB = extractVectorTypeFromScalarView(viewB);
+ VectorType vtC = extractVectorTypeFromScalarView(viewC);
Value zero = std_constant_index(0);
- SmallVector<Value, 4> indicesA(linalgOp.getInputShapedType(0).getRank(),
- zero);
- SmallVector<Value, 4> indicesB(linalgOp.getInputShapedType(1).getRank(),
- zero);
- SmallVector<Value, 4> indicesC(linalgOp.getOutputShapedType(0).getRank(),
- zero);
- Value a = vector_transfer_read(extractVectorTypeFromScalarView(viewA), viewA,
- indicesA);
- Value b = vector_transfer_read(extractVectorTypeFromScalarView(viewB), viewB,
- indicesB);
- Value c = vector_transfer_read(extractVectorTypeFromScalarView(viewC), viewC,
- indicesC);
+ // Operands that are not 0-D are accessed at the origin (one zero index per
+ // dimension). 0-D operands keep an empty index list.
+ SmallVector<Value, 4> indicesA, indicesB, indicesC;
+ if (vtA)
+ indicesA = SmallVector<Value, 4>(vtA.getRank(), zero);
+ if (vtB)
+ indicesB = SmallVector<Value, 4>(vtB.getRank(), zero);
+ if (vtC)
+ indicesC = SmallVector<Value, 4>(vtC.getRank(), zero);
+ Value a = vtA ? vector_transfer_read(vtA, viewA, indicesA).value
+ : std_load(viewA, indicesA).value;
+ Value b = vtB ? vector_transfer_read(vtB, viewB, indicesB).value
+ : std_load(viewB, indicesB).value;
+ Value c = vtC ? vector_transfer_read(vtC, viewC, indicesC).value
+ : std_load(viewC, indicesC).value;
+ // Contract using the linalg op's own indexing maps and iterator types.
Value res = vector_contract(a, b, c, linalgOp.indexing_maps(),
linalgOp.iterator_types());
- vector_transfer_write(res, viewC, indicesC);
+ // Write the result back to C with the same kind of access used to read it.
+ if (vtC)
+ vector_transfer_write(res, viewC, indicesC);
+ else
+ std_store(res, viewC, indicesC);
}
/// Check whether there is any interleaved use of any `values` between `firstOp`
diff --git a/mlir/test/Dialect/Linalg/transform-patterns-matmul-to-vector.mlir b/mlir/test/Dialect/Linalg/transform-patterns-matmul-to-vector.mlir
index cf75ee5691d0..b0702f9fdcfd 100644
--- a/mlir/test/Dialect/Linalg/transform-patterns-matmul-to-vector.mlir
+++ b/mlir/test/Dialect/Linalg/transform-patterns-matmul-to-vector.mlir
@@ -1,5 +1,6 @@
// RUN: mlir-opt %s -test-linalg-transform-patterns=test-matmul-to-vector-patterns-tile-1d | FileCheck %s
// RUN: mlir-opt %s -test-linalg-transform-patterns=test-matmul-to-vector-patterns-tile-2d | FileCheck %s
+// RUN: mlir-opt %s -test-linalg-transform-patterns=test-contraction-to-vector-patterns | FileCheck %s --check-prefix=VECTOR-CONTRACTION
func @matmul(%A: memref<1584x1584xf32, offset: 0, strides: [1584, 1]>,
%B: memref<1584x1584xf32, offset: 0, strides: [1584, 1]>,
@@ -30,3 +31,38 @@ func @matmul(%A: memref<1584x1584xf32, offset: 0, strides: [1584, 1]>,
// CHECK-SAME: : vector<8x16xf32>, vector<16x12xf32> into vector<8x12xf32>
//
// CHECK: linalg.copy
+
+// VECTOR-CONTRACTION-LABEL: contraction_dot
+func @contraction_dot(%A: memref<1584xf32>, %B: memref<1584xf32>, %C: memref<f32>) {
+ // The 1-D inputs are read with vector transfers. The 0-D output buffer is
+ // read and written back through a scalar load/store.
+ // VECTOR-CONTRACTION: vector.transfer_read
+ // VECTOR-CONTRACTION: vector.transfer_read
+ // VECTOR-CONTRACTION: load
+ // VECTOR-CONTRACTION: vector.contract
+ // VECTOR-CONTRACTION-SAME: vector<1584xf32>, vector<1584xf32> into f32
+ // VECTOR-CONTRACTION: store
+ linalg.dot(%A, %B, %C) : memref<1584xf32>, memref<1584xf32>, memref<f32>
+ return
+}
+
+// VECTOR-CONTRACTION-LABEL: contraction_matvec
+func @contraction_matvec(%A: memref<1584x1584xf32>, %B: memref<1584xf32>, %C: memref<1584xf32>) {
+ // VECTOR-CONTRACTION: vector.contract
+ // VECTOR-CONTRACTION-SAME: vector<1584x1584xf32>, vector<1584xf32> into vector<1584xf32>
+ linalg.matvec %A, %B, %C :
+ (memref<1584x1584xf32>, memref<1584xf32>, memref<1584xf32>)
+ return
+}
+
+// VECTOR-CONTRACTION-LABEL: contraction_matmul
+func @contraction_matmul(%A: memref<1584x1584xf32>, %B: memref<1584x1584xf32>, %C: memref<1584x1584xf32>) {
+ // VECTOR-CONTRACTION: vector.contract
+ // VECTOR-CONTRACTION-SAME: vector<1584x1584xf32>, vector<1584x1584xf32> into vector<1584x1584xf32>
+ linalg.matmul %A, %B, %C :
+ (memref<1584x1584xf32>, memref<1584x1584xf32>, memref<1584x1584xf32>)
+ return
+}
+
+// VECTOR-CONTRACTION-LABEL: contraction_batch_matmul
+func @contraction_batch_matmul(%A: memref<1584x1584x1584xf32>, %B: memref<1584x1584x1584xf32>, %C: memref<1584x1584x1584xf32>) {
+ // VECTOR-CONTRACTION: vector.contract
+ // VECTOR-CONTRACTION-SAME: vector<1584x1584x1584xf32>, vector<1584x1584x1584xf32> into vector<1584x1584x1584xf32>
+ linalg.batch_matmul %A, %B, %C :
+ (memref<1584x1584x1584xf32>, memref<1584x1584x1584xf32>, memref<1584x1584x1584xf32>)
+ return
+}
diff --git a/mlir/test/lib/Transforms/TestLinalgTransforms.cpp b/mlir/test/lib/Transforms/TestLinalgTransforms.cpp
index f93cd9faa504..4fb378c5ab8a 100644
--- a/mlir/test/lib/Transforms/TestLinalgTransforms.cpp
+++ b/mlir/test/lib/Transforms/TestLinalgTransforms.cpp
@@ -54,6 +54,11 @@ struct TestLinalgTransforms
llvm::cl::desc(
"Test a fused pass that forwards linalg.copy to vector.transfer"),
llvm::cl::init(false)};
+ Option<bool> testContractionToVectorPatterns{
+ *this, "test-contraction-to-vector-patterns",
+ llvm::cl::desc("Test a set of patterns that rewrite a linalg contraction "
+ "in vector.contract form"),
+ llvm::cl::init(false)};
};
} // end anonymous namespace
@@ -300,6 +305,16 @@ static void applyVectorTransferForwardingPatterns(FuncOp funcOp) {
applyPatternsAndFoldGreedily(funcOp, forwardPattern);
}
+/// Greedily applies LinalgVectorizationPattern to the named contraction ops
+/// (batch_matmul, matmul, matvec, dot) and to linalg.generic ops in `funcOp`.
+static void applyContractionToVectorPatterns(FuncOp funcOp) {
+ OwningRewritePatternList patterns;
+ patterns.insert<LinalgVectorizationPattern<BatchMatmulOp>,
+ LinalgVectorizationPattern<MatmulOp>,
+ LinalgVectorizationPattern<MatvecOp>,
+ LinalgVectorizationPattern<DotOp>,
+ LinalgVectorizationPattern<GenericOp>>(funcOp.getContext());
+ applyPatternsAndFoldGreedily(funcOp, patterns);
+}
+
/// Apply transformations specified as patterns.
void TestLinalgTransforms::runOnFunction() {
auto lambda = [&](void *) {
@@ -323,6 +338,8 @@ void TestLinalgTransforms::runOnFunction() {
testMatmulToVectorPatterns2dTiling);
if (testVectorTransferForwardingPatterns)
return applyVectorTransferForwardingPatterns(getFunction());
+ if (testContractionToVectorPatterns)
+ return applyContractionToVectorPatterns(getFunction());
}
namespace mlir {
More information about the Mlir-commits mailing list