[Mlir-commits] [mlir] 1694765 - [mlir][linalg] Extend linalg vectorization to support non-identity input maps
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Thu Mar 18 12:33:00 PDT 2021
Author: thomasraoux
Date: 2021-03-18T12:32:35-07:00
New Revision: 16947650d5ca602d63d5cd64e68bb0bb0f3674b7
URL: https://github.com/llvm/llvm-project/commit/16947650d5ca602d63d5cd64e68bb0bb0f3674b7
DIFF: https://github.com/llvm/llvm-project/commit/16947650d5ca602d63d5cd64e68bb0bb0f3674b7.diff
LOG: [mlir][linalg] Extend linalg vectorization to support non-identity input maps
This propagates the affine map to the transfer_read op when it is not a
minor identity map.
Differential Revision: https://reviews.llvm.org/D98523
Added:
Modified:
mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
mlir/lib/Dialect/Vector/VectorOps.cpp
mlir/lib/IR/AffineMap.cpp
mlir/test/Dialect/Linalg/vectorization.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
index 880e7f385724..dab32d2e2727 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -87,11 +87,14 @@ static VectorType extractVectorTypeFromShapedValue(Value v) {
/// Build a vector.transfer_read from `source` at indices set to all `0`.
/// If source has rank zero, build an memref.load.
/// Return the produced value.
-static Value buildVectorRead(OpBuilder &builder, Value source) {
+static Value buildVectorRead(OpBuilder &builder, Value source,
+ VectorType vectorType, AffineMap map) {
edsc::ScopedContext scope(builder);
auto shapedType = source.getType().cast<ShapedType>();
- if (VectorType vectorType = extractVectorTypeFromShapedValue(source)) {
+ if (vectorType) {
SmallVector<Value> indices(shapedType.getRank(), std_constant_index(0));
+ if (map)
+ return vector_transfer_read(vectorType, source, indices, map);
return vector_transfer_read(vectorType, source, indices);
}
return memref_load(source);
@@ -238,6 +241,51 @@ vectorizeOneOp(OpBuilder &builder, Operation *op,
builder.createOperation(state)};
}
+/// Return true if `r` consists of a single block in which every operation is
+/// a ConstantOp, a linalg::YieldOp, or carries the ElementwiseMappable traits,
+/// and every value those operations produce is an int, index, or float scalar.
+static bool hasOnlyScalarElementwiseOp(Region &r) {
+  if (!llvm::hasSingleElement(r))
+    return false;
+  auto isNonScalar = [](Type type) { return !type.isIntOrIndexOrFloat(); };
+  for (Operation &op : r.front()) {
+    bool supportedOp = isa<ConstantOp, linalg::YieldOp>(op) ||
+                       OpTrait::hasElementwiseMappableTraits(&op);
+    if (!supportedOp)
+      return false;
+    if (llvm::any_of(op.getResultTypes(), isNonScalar))
+      return false;
+  }
+  return true;
+}
+
+// Return true if the op is an element-wise linalg op that the generic
+// vectorizer can handle:
+//   - all loops are parallel,
+//   - every output indexing map is the identity,
+//   - every input indexing map is a projected permutation,
+//   - the single region contains only scalar elementwise ops.
+static bool isElementwise(Operation *op) {
+  auto linalgOp = dyn_cast<linalg::LinalgOp>(op);
+  if (!linalgOp)
+    return false;
+  if (linalgOp.getNumLoops() != linalgOp.getNumParallelLoops())
+    return false;
+  // TODO: relax the restrictions on indexing map.
+  for (unsigned i = 0, e = linalgOp.getNumOutputs(); i < e; i++) {
+    if (!linalgOp.getOutputIndexingMap(i).isIdentity())
+      return false;
+  }
+  // Input maps that are not minor identities are turned into a transfer_read
+  // permutation map by getTransferReadMap, which calls getDimPosition on every
+  // result. That only works when each result is a distinct loop dimension,
+  // i.e. the map is a projected permutation. Reject anything else (e.g.
+  // `d0 + d1` or constant results) instead of crashing during vectorization.
+  for (unsigned i = 0, e = linalgOp.getNumInputs(); i < e; i++) {
+    if (!linalgOp.getInputIndexingMap(i).isProjectedPermutation())
+      return false;
+  }
+  if (linalgOp->getNumRegions() != 1)
+    return false;
+  return hasOnlyScalarElementwiseOp(linalgOp->getRegion(0));
+}
+
+// Calculate the map to apply to transfer_read to convert the input shape into
+// the output shape.
+//
+// The linalg indexing map of operand `argIndex` maps loop dimensions to operand
+// dimensions. The returned map goes the other way, as transfer_read expects:
+//   - its dims are the operand dimensions;
+//   - its results are the vector dimensions, one per loop dimension;
+//   - loop dimensions the operand does not index become constant-0 results.
+// Example: (d0, d1, d2, d3) -> (d1, d0) yields (d0, d1) -> (d1, d0, 0, 0).
+static AffineMap getTransferReadMap(LinalgOp linalgOp, unsigned argIndex) {
+  AffineMap linalgMap = linalgOp.getIndexingMap(argIndex);
+  // getDimPosition below requires every result to be a distinct loop dim.
+  assert(linalgMap.isProjectedPermutation() &&
+         "expected a projected permutation indexing map");
+  MLIRContext *context = linalgMap.getContext();
+  AffineExpr zero = mlir::getAffineConstantExpr(0, context);
+  // One result per loop dimension; symbols are not vector dimensions.
+  SmallVector<AffineExpr, 4> exprs(linalgMap.getNumDims(), zero);
+  for (unsigned i : llvm::seq(unsigned(0), linalgMap.getNumResults()))
+    exprs[linalgMap.getDimPosition(i)] = getAffineDimExpr(i, context);
+  return AffineMap::get(linalgMap.getNumResults(), /*symbolCount=*/0, exprs,
+                        context);
+}
+
/// Generic vectorization function that rewrites the body of a `linalgOp` into
/// vector form. Generic vectorization proceeds as follows:
/// 1. The region for the linalg op is created if necessary.
@@ -282,7 +330,19 @@ LogicalResult vectorizeAsLinalgGeneric(
SmallVector<AffineMap> indexings;
for (auto bbarg : block->getArguments()) {
Value vectorArg = linalgOp.getShapedOperand(bbarg.getArgNumber());
- Value vectorRead = buildVectorRead(builder, vectorArg);
+ AffineMap map;
+ VectorType vectorType = extractVectorTypeFromShapedValue(vectorArg);
+ if (isElementwise(linalgOp) &&
+ !linalgOp.getIndexingMap(bbarg.getArgNumber()).isMinorIdentity()) {
+ // Currently assume we don't support output permutations.
+ assert(linalgOp.getNumOutputs() > 0 &&
+ linalgOp.getOutputIndexingMap(0).isIdentity());
+ ArrayRef<int64_t> outputShape =
+ linalgOp.getOutputShapedType(0).getShape();
+ vectorType = VectorType::get(outputShape, vectorType.getElementType());
+ map = getTransferReadMap(linalgOp, bbarg.getArgNumber());
+ }
+ Value vectorRead = buildVectorRead(builder, vectorArg, vectorType, map);
LLVM_DEBUG(dbgs() << "\n[" DEBUG_TYPE "]: new vectorized bbarg("
<< bbarg.getArgNumber() << "): " << vectorRead);
bvm.map(bbarg, vectorRead);
@@ -316,44 +376,6 @@ LogicalResult vectorizeAsLinalgGeneric(
return success();
}
-/// Detect whether `r` has only ConstantOp, ElementwiseMappable and YieldOp.
-static bool hasOnlyScalarElementwiseOp(Region &r) {
- if (!llvm::hasSingleElement(r))
- return false;
- for (Operation &op : r.front()) {
- if (!(isa<ConstantOp, linalg::YieldOp>(op) ||
- OpTrait::hasElementwiseMappableTraits(&op)) ||
- llvm::any_of(op.getResultTypes(),
- [](Type type) { return !type.isIntOrIndexOrFloat(); }))
- return false;
- }
- return true;
-}
-
-// Return true if the op is an element-wise linalg op.
-static bool isElementwise(Operation *op) {
- auto linalgOp = dyn_cast<linalg::LinalgOp>(op);
- if (!linalgOp)
- return false;
- if (linalgOp.getNumLoops() != linalgOp.getNumParallelLoops())
- return false;
- // TODO: relax the restrictions on indexing map.
- for (unsigned i = 0, e = linalgOp.getNumOutputs(); i < e; i++) {
- if (!linalgOp.getOutputIndexingMap(i).isIdentity())
- return false;
- }
- // Currently bound the input indexing map to minor identity as other
- // permutations might require adding transpose ops to convert the vector read
- // to the right shape.
- for (unsigned i = 0, e = linalgOp.getNumInputs(); i < e; i++) {
- if (!linalgOp.getInputIndexingMap(i).isMinorIdentity())
- return false;
- }
- if (linalgOp->getNumRegions() != 1)
- return false;
- return hasOnlyScalarElementwiseOp(linalgOp->getRegion(0));
-}
-
static LogicalResult vectorizeContraction(OpBuilder &builder, LinalgOp linalgOp,
SmallVectorImpl<Value> &newResults) {
assert(isaContractionOpInterface(linalgOp) &&
diff --git a/mlir/lib/Dialect/Vector/VectorOps.cpp b/mlir/lib/Dialect/Vector/VectorOps.cpp
index 6ca28ba681ef..08bf7628e8c0 100644
--- a/mlir/lib/Dialect/Vector/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/VectorOps.cpp
@@ -2294,8 +2294,7 @@ void TransferReadOp::build(OpBuilder &builder, OperationState &result,
static void printTransferAttrs(OpAsmPrinter &p, VectorTransferOpInterface op) {
SmallVector<StringRef, 2> elidedAttrs;
- if (op.permutation_map() ==
- getTransferMinorIdentityMap(op.getShapedType(), op.getVectorType()))
+ if (op.permutation_map().isMinorIdentity())
elidedAttrs.push_back(op.getPermutationMapAttrName());
bool elideMasked = true;
if (auto maybeMasked = op.masked()) {
diff --git a/mlir/lib/IR/AffineMap.cpp b/mlir/lib/IR/AffineMap.cpp
index 9de80e96d451..98ca45bbb6f6 100644
--- a/mlir/lib/IR/AffineMap.cpp
+++ b/mlir/lib/IR/AffineMap.cpp
@@ -106,8 +106,9 @@ AffineMap AffineMap::getMinorIdentityMap(unsigned dims, unsigned results,
}
+/// Returns true if this map equals the map produced by
+/// getMinorIdentityMap(getNumDims(), getNumResults(), getContext()).
bool AffineMap::isMinorIdentity() const {
- return *this ==
- getMinorIdentityMap(getNumDims(), getNumResults(), getContext());
+ // Check the rank relation first: a map with more results than dims can never
+ // be a minor identity, and getMinorIdentityMap should not be asked to build
+ // one (presumably it asserts dims >= results -- confirm in its definition).
+ return getNumDims() >= getNumResults() &&
+ *this ==
+ getMinorIdentityMap(getNumDims(), getNumResults(), getContext());
}
/// Returns true if this affine map is a minor identity up to broadcasted
diff --git a/mlir/test/Dialect/Linalg/vectorization.mlir b/mlir/test/Dialect/Linalg/vectorization.mlir
index c43bf07d775d..74ff4367724e 100644
--- a/mlir/test/Dialect/Linalg/vectorization.mlir
+++ b/mlir/test/Dialect/Linalg/vectorization.mlir
@@ -341,6 +341,42 @@ func @generic_vectorize_tensor(%arg0: tensor<4x256xf32>,
// -----
+// Test different input maps.
+// All-parallel elementwise add whose inputs are read through transposing and
+// broadcasting indexing maps (this is not a matmul).
+#transpose_trait = {
+  indexing_maps = [
+    affine_map<(d0, d1, d2, d3) -> (d1, d0)>,
+    affine_map<(d0, d1, d2, d3) -> (d3, d1)>,
+    affine_map<(d0, d1, d2, d3) -> (d3, d1, d0, d2)>,
+    affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
+  ],
+  iterator_types = ["parallel", "parallel", "parallel", "parallel"]
+}
+
+// CHECK-DAG: #[[MAP0:.*]] = affine_map<(d0, d1) -> (d1, d0, 0, 0)>
+// CHECK-DAG: #[[MAP1:.*]] = affine_map<(d0, d1) -> (0, d1, 0, d0)>
+// CHECK-DAG: #[[MAP2:.*]] = affine_map<(d0, d1, d2, d3) -> (d2, d1, d3, d0)>
+// CHECK: func @vectorization_transpose
+// CHECK: vector.transfer_read {{.*}}{permutation_map = #[[MAP0]]} : memref<14x7xf32>, vector<7x14x8x16xf32>
+// CHECK: vector.transfer_read {{.*}}{permutation_map = #[[MAP1]]} : memref<16x14xf32>, vector<7x14x8x16xf32>
+// CHECK: vector.transfer_read {{.*}}{permutation_map = #[[MAP2]]} : memref<16x14x7x8xf32>, vector<7x14x8x16xf32>
+// CHECK: addf {{.*}} : vector<7x14x8x16xf32>
+// CHECK: addf {{.*}} : vector<7x14x8x16xf32>
+// CHECK: vector.transfer_write {{.*}} : vector<7x14x8x16xf32>, memref<7x14x8x16xf32>
+func @vectorization_transpose(%A: memref<14x7xf32>, %B: memref<16x14xf32>,
+                              %C: memref<16x14x7x8xf32>, %D: memref<7x14x8x16xf32>) {
+  linalg.generic #transpose_trait
+    ins(%A, %B, %C : memref<14x7xf32>, memref<16x14xf32>, memref<16x14x7x8xf32>)
+    outs(%D : memref<7x14x8x16xf32>) {
+    ^bb(%a: f32, %b: f32, %c: f32, %d: f32) :
+      %e = addf %a, %b: f32
+      %f = addf %e, %c: f32
+      linalg.yield %f : f32
+  }
+  return
+}
+
+// -----
+
// CHECK-LABEL: func @matmul_tensors
// CHECK-SAME: (%[[ARG0:.*]]: tensor<8x4xf32>, %[[ARG1:.*]]: tensor<4x12xf32>,
// CHECK-SAME: %[[ARG2:.*]]: tensor<8x12xf32>) -> tensor<8x12xf32>
More information about the Mlir-commits
mailing list