[llvm-branch-commits] [mlir] cf21667 - [mlir][linalg] Add vectorization for linalg on tensor ops
Thomas Raoux via llvm-branch-commits
llvm-branch-commits at lists.llvm.org
Tue Dec 29 09:06:50 PST 2020
Author: Thomas Raoux
Date: 2020-12-29T09:02:23-08:00
New Revision: cf216670a0bd1f2ce561a315e00649740f117e1c
URL: https://github.com/llvm/llvm-project/commit/cf216670a0bd1f2ce561a315e00649740f117e1c
DIFF: https://github.com/llvm/llvm-project/commit/cf216670a0bd1f2ce561a315e00649740f117e1c.diff
LOG: [mlir][linalg] Add vectorization for linalg on tensor ops
Support vectorization of linalg ops using tensor inputs/outputs.
Differential Revision: https://reviews.llvm.org/D93890
Added:
Modified:
mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
mlir/lib/Dialect/Vector/VectorOps.cpp
mlir/test/Dialect/Linalg/vectorization.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
index 23e452df9184..2a1d4cd2ef57 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -119,34 +119,38 @@ static bool isElementwise(Operation *op) {
return hasOnlyScalarElementwiseOp(genericOp.getRegion());
}
-static VectorType extractVectorTypeFromScalarView(Value v) {
- MemRefType mt = v.getType().cast<MemRefType>();
- return mt.getShape().empty()
- ? VectorType()
- : VectorType::get(mt.getShape(), mt.getElementType());
+static VectorType extractVectorTypeFromShapedValue(Value v) {
+ auto st = v.getType().cast<ShapedType>();
+ if (st.isa<MemRefType>() && st.getShape().empty())
+ return VectorType();
+ return VectorType::get(st.getShape(), st.getElementType());
}
-static Value transferReadVector(OpBuilder &builder, Value memref) {
+static Value transferReadVector(OpBuilder &builder, Value source) {
edsc::ScopedContext scope(builder);
- auto memrefType = memref.getType().cast<MemRefType>();
- if (VectorType vectorType = extractVectorTypeFromScalarView(memref)) {
- SmallVector<Value, 4> indices(memrefType.getRank(), std_constant_index(0));
- return vector_transfer_read(vectorType, memref, indices);
+ auto shapedType = source.getType().cast<ShapedType>();
+ if (VectorType vectorType = extractVectorTypeFromShapedValue(source)) {
+ SmallVector<Value, 4> indices(shapedType.getRank(), std_constant_index(0));
+ return vector_transfer_read(vectorType, source, indices);
}
- return std_load(memref);
+ return std_load(source);
}
-static void transferWriteVector(OpBuilder &builder, Value value, Value memref) {
+static Value transferWriteVector(OpBuilder &builder, Value value, Value dest) {
edsc::ScopedContext scope(builder);
- auto memrefType = memref.getType().cast<MemRefType>();
- if (VectorType vectorType = extractVectorTypeFromScalarView(memref)) {
- SmallVector<Value, 4> indices(memrefType.getRank(), std_constant_index(0));
+ Operation *write;
+ auto shapedType = dest.getType().cast<ShapedType>();
+ if (VectorType vectorType = extractVectorTypeFromShapedValue(dest)) {
+ SmallVector<Value, 4> indices(shapedType.getRank(), std_constant_index(0));
if (vectorType != value.getType())
value = vector_broadcast(vectorType, value);
- vector_transfer_write(value, memref, indices);
+ write = vector_transfer_write(value, dest, indices);
} else {
- std_store(value, memref);
+ write = std_store(value, dest);
}
+ if (!write->getResults().empty())
+ return write->getResult(0);
+ return Value();
}
namespace {
@@ -167,10 +171,12 @@ class GenericVectorizer {
void vectorize(Operation &scalarOp) {
auto yieldOp = dyn_cast<linalg::YieldOp>(scalarOp);
if (yieldOp) {
- for (auto outputAndMemref :
- llvm::zip(yieldOp.values(), generic.getOutputBuffers())) {
- Value vectorValue = vectorize(std::get<0>(outputAndMemref));
- transferWriteVector(builder, vectorValue, std::get<1>(outputAndMemref));
+ for (auto outputs : llvm::enumerate(yieldOp.values())) {
+ Value vectorValue = vectorize(outputs.value());
+ Value result = transferWriteVector(builder, vectorValue,
+ generic.getOutput(outputs.index()));
+ if (result)
+ results.push_back(result);
}
return;
}
@@ -182,6 +188,8 @@ class GenericVectorizer {
}
}
+ llvm::ArrayRef<Value> getResults() { return results; }
+
private:
// Transforms a scalar value into its vectorized counterpart, recursively
// vectorizing operations as necessary using the underlying builder.
@@ -261,6 +269,7 @@ class GenericVectorizer {
OpBuilder &builder;
linalg::GenericOp generic;
llvm::DenseMap<Value, Value> valueCache;
+ SmallVector<Value, 8> results;
};
} // namespace
@@ -271,6 +280,8 @@ static void vectorizeElementwise(linalg::GenericOp op, OpBuilder &builder) {
for (Operation &scalarOp : op.region().front()) {
vectorizer.vectorize(scalarOp);
}
+ if (!op->getResults().empty())
+ op->replaceAllUsesWith(vectorizer.getResults());
}
LogicalResult mlir::linalg::vectorizeLinalgOpPrecondition(Operation *op) {
@@ -331,32 +342,14 @@ void mlir::linalg::vectorizeLinalgOp(OpBuilder &builder, Operation *op) {
LLVM_DEBUG(dbgs() << dbgPref
<< "Rewrite linalg op as vector.contract: " << *op);
auto linalgOp = cast<linalg::LinalgOp>(op);
- Value viewA = linalgOp.getInput(0);
- Value viewB = linalgOp.getInput(1);
- Value viewC = linalgOp.getOutputBuffer(0);
- VectorType vtA = extractVectorTypeFromScalarView(viewA);
- VectorType vtB = extractVectorTypeFromScalarView(viewB);
- VectorType vtC = extractVectorTypeFromScalarView(viewC);
- Value zero = std_constant_index(0);
- SmallVector<Value, 4> indicesA, indicesB, indicesC;
- if (vtA)
- indicesA = SmallVector<Value, 4>(vtA.getRank(), zero);
- if (vtB)
- indicesB = SmallVector<Value, 4>(vtB.getRank(), zero);
- if (vtC)
- indicesC = SmallVector<Value, 4>(vtC.getRank(), zero);
- Value a = vtA ? vector_transfer_read(vtA, viewA, indicesA).value
- : std_load(viewA, indicesA).value;
- Value b = vtB ? vector_transfer_read(vtB, viewB, indicesB).value
- : std_load(viewB, indicesB).value;
- Value c = vtC ? vector_transfer_read(vtC, viewC, indicesC).value
- : std_load(viewC, indicesC).value;
+ Value a = transferReadVector(builder, linalgOp.getInput(0));
+ Value b = transferReadVector(builder, linalgOp.getInput(1));
+ Value c = transferReadVector(builder, linalgOp.getOutput(0));
Value res = vector_contract(a, b, c, linalgOp.indexing_maps(),
linalgOp.iterator_types());
- if (vtC)
- vector_transfer_write(res, viewC, indicesC);
- else
- std_store(res, viewC, indicesC);
+ Value writeResult = transferWriteVector(builder, res, linalgOp.getOutput(0));
+ if (writeResult)
+ linalgOp->replaceAllUsesWith(ArrayRef<Value>(writeResult));
}
/// Check whether there is any interleaved use of any `values` between `firstOp`
diff --git a/mlir/lib/Dialect/Vector/VectorOps.cpp b/mlir/lib/Dialect/Vector/VectorOps.cpp
index a6a353c97977..5acc5f5a7ce2 100644
--- a/mlir/lib/Dialect/Vector/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/VectorOps.cpp
@@ -2039,28 +2039,28 @@ static LogicalResult verifyTransferOp(Operation *op, ShapedType shapedType,
/// Builder that sets padding to zero.
void TransferReadOp::build(OpBuilder &builder, OperationState &result,
- VectorType vector, Value memref, ValueRange indices,
+ VectorType vector, Value source, ValueRange indices,
AffineMap permutationMap,
ArrayRef<bool> maybeMasked) {
- Type elemType = memref.getType().cast<MemRefType>().getElementType();
+ Type elemType = source.getType().cast<ShapedType>().getElementType();
Value padding = builder.create<ConstantOp>(result.location, elemType,
builder.getZeroAttr(elemType));
if (maybeMasked.empty())
- return build(builder, result, vector, memref, indices, permutationMap,
+ return build(builder, result, vector, source, indices, permutationMap,
padding, ArrayAttr());
ArrayAttr maskedArrayAttr = builder.getBoolArrayAttr(maybeMasked);
- build(builder, result, vector, memref, indices, permutationMap, padding,
+ build(builder, result, vector, source, indices, permutationMap, padding,
maskedArrayAttr);
}
/// Builder that sets permutation map (resp. padding) to 'getMinorIdentityMap'
/// (resp. zero).
void TransferReadOp::build(OpBuilder &builder, OperationState &result,
- VectorType vectorType, Value memref,
+ VectorType vectorType, Value source,
ValueRange indices, ArrayRef<bool> maybeMasked) {
auto permMap = getTransferMinorIdentityMap(
- memref.getType().cast<MemRefType>(), vectorType);
- build(builder, result, vectorType, memref, indices, permMap, maybeMasked);
+ source.getType().cast<ShapedType>(), vectorType);
+ build(builder, result, vectorType, source, indices, permMap, maybeMasked);
}
static void printTransferAttrs(OpAsmPrinter &p, VectorTransferOpInterface op) {
@@ -2251,7 +2251,7 @@ void TransferWriteOp::build(OpBuilder &builder, OperationState &result,
ArrayRef<bool> maybeMasked) {
auto vectorType = vector.getType().cast<VectorType>();
auto permMap = getTransferMinorIdentityMap(
- source.getType().cast<MemRefType>(), vectorType);
+ source.getType().cast<ShapedType>(), vectorType);
if (maybeMasked.empty())
return build(builder, result, vector, source, indices, permMap,
ArrayAttr());
@@ -2327,7 +2327,7 @@ static void print(OpAsmPrinter &p, TransferWriteOp op) {
}
static LogicalResult verify(TransferWriteOp op) {
- // Consistency of elemental types in memref and vector.
+ // Consistency of elemental types in shape and vector.
ShapedType shapedType = op.getShapedType();
VectorType vectorType = op.getVectorType();
auto permutationMap = op.permutation_map();
diff --git a/mlir/test/Dialect/Linalg/vectorization.mlir b/mlir/test/Dialect/Linalg/vectorization.mlir
index 6019dde49983..cf49e85a4a28 100644
--- a/mlir/test/Dialect/Linalg/vectorization.mlir
+++ b/mlir/test/Dialect/Linalg/vectorization.mlir
@@ -189,7 +189,7 @@ func @generic_vectorize(%arg0: memref<4x256xf32>, %arg1: memref<4x256xf32>,
// CHECK: %[[ADD:.*]] = addf %[[V0B]], %[[V1]] : vector<4x256xf32>
// CHECK: %[[V2:.*]] = vector.transfer_read %[[ARG1]][%[[C0]], %[[C0]]], {{.*}} : memref<4x256xf32>, vector<4x256xf32>
// CHECK: %[[CMP:.*]] = cmpf "ogt", %[[V2]], %[[V1]] : vector<4x256xf32>
-// CHECK: %[[V3:.*]] = vector.transfer_read %[[ARG0]][%[[C0]], %[[C0]]], {{.*}} : memref<4x256xf32>, vector<4x256xf32>
+// CHECK: %[[V3:.*]] = vector.transfer_read %[[ARG0]][%[[C0]], %[[C0]]], {{.*}} : memref<4x256xf32>, vector<4x256xf32>
// CHECK: %[[ARG3B:.*]] = vector.broadcast %[[ARG3]] : f32 to vector<4x256xf32>
// CHECK: %[[DIV:.*]] = divf %[[V3]], %[[ARG3B]] : vector<4x256xf32>
// CHECK: %[[EXP:.*]] = exp2 %[[V3]] : vector<4x256xf32>
@@ -209,3 +209,108 @@ func @generic_vectorize(%arg0: memref<4x256xf32>, %arg1: memref<4x256xf32>,
// CHECK: vector.transfer_write %[[SEL]], %[[ARG0]][%[[C0]], %[[C0]]] {{.*}} : vector<4x256xf32>, memref<4x256xf32>
// CHECK: vector.transfer_write %[[SUB]], %[[ARG0]][%[[C0]], %[[C0]]] {{.*}} : vector<4x256xf32>, memref<4x256xf32>
// CHECK: vector.transfer_write %[[TAN]], %[[ARG0]][%[[C0]], %[[C0]]] {{.*}} : vector<4x256xf32>, memref<4x256xf32>
+
+func @generic_vectorize_tensor(%arg0: tensor<4x256xf32>,
+ %arg1: tensor<4x256xf32>, %arg2: tensor<256xf32>,
+ %i: f32) -> (tensor<4x256xf32>, tensor<4x256xf32>, tensor<4x256xf32>,
+ tensor<4x256xf32>, tensor<4x256xf32>, tensor<4x256xf32>, tensor<4x256xf32>,
+ tensor<4x256xf32>, tensor<4x256xf32>, tensor<4x256xf32>) {
+ %c1_f32 = constant 1.0 : f32
+ %r:10 = linalg.generic {
+ indexing_maps = [
+ affine_map<(d0, d1) -> (d0, d1)>,
+ affine_map<(d0, d1) -> (d1)>,
+ affine_map<(d0, d1) -> (d0, d1)>,
+ affine_map<(d0, d1) -> (d0, d1)>,
+ affine_map<(d0, d1) -> (d0, d1)>,
+ affine_map<(d0, d1) -> (d0, d1)>,
+ affine_map<(d0, d1) -> (d0, d1)>,
+ affine_map<(d0, d1) -> (d0, d1)>,
+ affine_map<(d0, d1) -> (d0, d1)>,
+ affine_map<(d0, d1) -> (d0, d1)>,
+ affine_map<(d0, d1) -> (d0, d1)>,
+ affine_map<(d0, d1) -> (d0, d1)>],
+ iterator_types = ["parallel", "parallel"]}
+ ins(%arg1, %arg2: tensor<4x256xf32>, tensor<256xf32>)
+ outs(
+ %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0 :
+ tensor<4x256xf32>, tensor<4x256xf32>, tensor<4x256xf32>, tensor<4x256xf32>,
+ tensor<4x256xf32>, tensor<4x256xf32>, tensor<4x256xf32>, tensor<4x256xf32>,
+ tensor<4x256xf32>, tensor<4x256xf32>) {
+ ^bb0(%arg3 : f32, %arg4 : f32, %arg5: f32, %arg6: f32, %arg7: f32, %arg8: f32,
+ %arg9 : f32, %arg10 : f32, %arg11 : f32, %arg12 : f32, %arg13 : f32,
+ %arg14 : f32):
+ %6 = addf %arg4, %arg6 : f32
+ %7 = cmpf "ogt", %arg3, %arg6 : f32
+ %8 = constant 2.0 : f32
+ %9 = divf %arg5, %i : f32
+ %10 = exp2 %arg5 : f32
+ %11 = mulf %arg5, %8 : f32
+ %12 = rsqrt %arg5 : f32
+ %13 = select %7, %arg5, %arg6 : f32
+ %14 = subf %arg5, %arg4 : f32
+ %15 = tanh %arg5 : f32
+ linalg.yield %6, %8, %c1_f32, %9, %10, %11, %12, %13, %14, %15 : f32, f32,
+ f32, f32, f32, f32, f32, f32, f32, f32
+ } -> tensor<4x256xf32>, tensor<4x256xf32>, tensor<4x256xf32>,
+ tensor<4x256xf32>, tensor<4x256xf32>, tensor<4x256xf32>, tensor<4x256xf32>,
+ tensor<4x256xf32>, tensor<4x256xf32>, tensor<4x256xf32>
+ return %r#0, %r#1, %r#2, %r#3, %r#4, %r#5, %r#6, %r#7, %r#8, %r#9:
+ tensor<4x256xf32>, tensor<4x256xf32>, tensor<4x256xf32>,
+ tensor<4x256xf32>, tensor<4x256xf32>, tensor<4x256xf32>, tensor<4x256xf32>,
+ tensor<4x256xf32>, tensor<4x256xf32>, tensor<4x256xf32>
+}
+
+// CHECK-LABEL: func @generic_vectorize_tensor
+// CHECK-SAME: (%[[ARG0:.*]]: tensor<4x256xf32>, %[[ARG1:.*]]: tensor<4x256xf32>,
+// CHECK-SAME: %[[ARG2:.*]]: tensor<256xf32>, %[[ARG3:.*]]: f32)
+// CHECK-DAG: %[[CST0:.*]] = constant dense<2.000000e+00> : vector<4x256xf32>
+// CHECK-DAG: %[[CST1:.*]] = constant dense<1.000000e+00> : vector<4x256xf32>
+// CHECK-DAG: %[[C0:.*]] = constant 0 : index
+// CHECK: %[[V0:.*]] = vector.transfer_read %[[ARG2]][%[[C0]]], {{.*}} : tensor<256xf32>, vector<256xf32>
+// CHECK: %[[V1:.*]] = vector.transfer_read %[[ARG0]][%[[C0]], %[[C0]]], {{.*}} : tensor<4x256xf32>, vector<4x256xf32>
+// CHECK: %[[V0B:.*]] = vector.broadcast %[[V0]] : vector<256xf32> to vector<4x256xf32>
+// CHECK: %[[ADD:.*]] = addf %[[V0B]], %[[V1]] : vector<4x256xf32>
+// CHECK: %[[V2:.*]] = vector.transfer_read %[[ARG1]][%[[C0]], %[[C0]]], {{.*}} : tensor<4x256xf32>, vector<4x256xf32>
+// CHECK: %[[CMP:.*]] = cmpf "ogt", %[[V2]], %[[V1]] : vector<4x256xf32>
+// CHECK: %[[V3:.*]] = vector.transfer_read %[[ARG0]][%[[C0]], %[[C0]]], {{.*}} : tensor<4x256xf32>, vector<4x256xf32>
+// CHECK: %[[ARG3B:.*]] = vector.broadcast %[[ARG3]] : f32 to vector<4x256xf32>
+// CHECK: %[[DIV:.*]] = divf %[[V3]], %[[ARG3B]] : vector<4x256xf32>
+// CHECK: %[[EXP:.*]] = exp2 %[[V3]] : vector<4x256xf32>
+// CHECK: %[[MUL:.*]] = mulf %[[V3]], %[[CST0]] : vector<4x256xf32>
+// CHECK: %[[RSQRT:.*]] = rsqrt %[[V3]] : vector<4x256xf32>
+// CHECK: %[[SEL:.*]] = select %[[CMP]], %[[V3]], %[[V1]] : vector<4x256xi1>, vector<4x256xf32>
+// CHECK: %[[V0B:.*]] = vector.broadcast %[[V0]] : vector<256xf32> to vector<4x256xf32>
+// CHECK: %[[SUB:.*]] = subf %[[V3]], %[[V0B]] : vector<4x256xf32>
+// CHECK: %[[TAN:.*]] = tanh %[[V3]] : vector<4x256xf32>
+// CHECK: %[[R0:.*]] = vector.transfer_write %[[ADD]], %[[ARG0]][%[[C0]], %[[C0]]] {{.*}} : vector<4x256xf32>, tensor<4x256xf32>
+// CHECK: %[[R1:.*]] = vector.transfer_write %[[CST0]], %[[ARG0]][%[[C0]], %[[C0]]] {{.*}} : vector<4x256xf32>, tensor<4x256xf32>
+// CHECK: %[[R2:.*]] = vector.transfer_write %[[CST1]], %[[ARG0]][%[[C0]], %[[C0]]] {{.*}} : vector<4x256xf32>, tensor<4x256xf32>
+// CHECK: %[[R3:.*]] = vector.transfer_write %[[DIV]], %[[ARG0]][%[[C0]], %[[C0]]] {{.*}} : vector<4x256xf32>, tensor<4x256xf32>
+// CHECK: %[[R4:.*]] = vector.transfer_write %[[EXP]], %[[ARG0]][%[[C0]], %[[C0]]] {{.*}} : vector<4x256xf32>, tensor<4x256xf32>
+// CHECK: %[[R5:.*]] = vector.transfer_write %[[MUL]], %[[ARG0]][%[[C0]], %[[C0]]] {{.*}} : vector<4x256xf32>, tensor<4x256xf32>
+// CHECK: %[[R6:.*]] = vector.transfer_write %[[RSQRT]], %[[ARG0]][%[[C0]], %[[C0]]] {{.*}} : vector<4x256xf32>, tensor<4x256xf32>
+// CHECK: %[[R7:.*]] = vector.transfer_write %[[SEL]], %[[ARG0]][%[[C0]], %[[C0]]] {{.*}} : vector<4x256xf32>, tensor<4x256xf32>
+// CHECK: %[[R8:.*]] = vector.transfer_write %[[SUB]], %[[ARG0]][%[[C0]], %[[C0]]] {{.*}} : vector<4x256xf32>, tensor<4x256xf32>
+// CHECK: %[[R9:.*]] = vector.transfer_write %[[TAN]], %[[ARG0]][%[[C0]], %[[C0]]] {{.*}} : vector<4x256xf32>, tensor<4x256xf32>
+// CHECK: return %[[R0]], %[[R1]], %[[R2]], %[[R3]], %[[R4]], %[[R5]], %[[R6]], %[[R7]], %[[R8]], %[[R9]] : tensor<4x256xf32>, tensor<4x256xf32>, tensor<4x256xf32>, tensor<4x256xf32>, tensor<4x256xf32>, tensor<4x256xf32>, tensor<4x256xf32>, tensor<4x256xf32>, tensor<4x256xf32>, tensor<4x256xf32>
+
+func @matmul_tensors(
+ %arg0: tensor<8x4xf32>, %arg1: tensor<4x12xf32>, %arg2: tensor<8x12xf32>)
+ -> tensor<8x12xf32> {
+ %0 = linalg.matmul ins(%arg0, %arg1: tensor<8x4xf32>, tensor<4x12xf32>)
+ outs(%arg2: tensor<8x12xf32>)
+ -> tensor<8x12xf32>
+ return %0 : tensor<8x12xf32>
+}
+
+// CHECK-LABEL: func @matmul_tensors
+// CHECK-SAME: (%[[ARG0:.*]]: tensor<8x4xf32>, %[[ARG1:.*]]: tensor<4x12xf32>,
+// CHECK-SAME: %[[ARG2:.*]]: tensor<8x12xf32>) -> tensor<8x12xf32>
+// CHECK: %[[C0:.*]] = constant 0 : index
+// CHECK: %[[V0:.*]] = vector.transfer_read %[[ARG0]][%[[C0]], %[[C0]]], {{.*}} : tensor<8x4xf32>, vector<8x4xf32>
+// CHECK: %[[V1:.*]] = vector.transfer_read %[[ARG1]][%[[C0]], %[[C0]]], {{.*}} : tensor<4x12xf32>, vector<4x12xf32>
+// CHECK: %[[V2:.*]] = vector.transfer_read %[[ARG2]][%[[C0]], %[[C0]]], {{.*}} : tensor<8x12xf32>, vector<8x12xf32>
+// CHECK: %[[C:.*]] = vector.contract {{.*}} iterator_types = ["parallel", "parallel", "reduction"]} %[[V0]], %[[V1]], %[[V2]] : vector<8x4xf32>, vector<4x12xf32> into vector<8x12xf32>
+// CHECK: %[[W:.*]] = vector.transfer_write %[[C]], %[[ARG2]][%[[C0]], %[[C0]]] {masked = [false, false]} : vector<8x12xf32>, tensor<8x12xf32>
+// CHECK: return %[[W]] : tensor<8x12xf32>
More information about the llvm-branch-commits
mailing list