[Mlir-commits] [mlir] 13f15e8 - [mlir][Vector] Fix vectorization of generic ops with transposed outputs
Diego Caballero
llvmlistbot at llvm.org
Mon Jun 26 13:24:48 PDT 2023
Author: Diego Caballero
Date: 2023-06-26T20:24:29Z
New Revision: 13f15e8f14cfd0b5795ef0c0164a913fdaa5378c
URL: https://github.com/llvm/llvm-project/commit/13f15e8f14cfd0b5795ef0c0164a913fdaa5378c
DIFF: https://github.com/llvm/llvm-project/commit/13f15e8f14cfd0b5795ef0c0164a913fdaa5378c.diff
LOG: [mlir][Vector] Fix vectorization of generic ops with transposed outputs
This patch fixes a bug in the way we compute the vector type for vector
transfer writes when the value to store needs to be transposed.
Reviewed By: nicolasvasilache, mravishankar
Differential Revision: https://reviews.llvm.org/D153687
Added:
Modified:
mlir/include/mlir/IR/AffineMap.h
mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
mlir/lib/IR/AffineMap.cpp
mlir/test/Dialect/Linalg/vectorization.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/IR/AffineMap.h b/mlir/include/mlir/IR/AffineMap.h
index 01cd7183e43c7..3430db2b99c3f 100644
--- a/mlir/include/mlir/IR/AffineMap.h
+++ b/mlir/include/mlir/IR/AffineMap.h
@@ -78,6 +78,20 @@ class AffineMap {
static AffineMap getMinorIdentityMap(unsigned dims, unsigned results,
MLIRContext *context);
+ /// Returns an identity affine map with `numDims` input dimensions and
+ /// filtered results using `keepDimFilter`. If `keepDimFilter` returns true
+ /// for a dimension, the dimension is kept in the affine map results.
+ /// Otherwise, the dimension is dropped from the results.
+ ///
+ /// Examples:
+ /// * getFilteredIdentityMap(4, [false, true, false, true])
+ /// -> affine_map<(d0, d1, d2, d3) -> (d1, d3)>
+ /// * getFilteredIdentityMap(3, [false, false, true])
+ /// -> affine_map<(d0, d1, d2) -> (d2)>
+ static AffineMap
+ getFilteredIdentityMap(MLIRContext *ctx, unsigned numDims,
+ llvm::function_ref<bool(AffineDimExpr)> keepDimFilter);
+
/// Returns an AffineMap representing a permutation.
/// The permutation is expressed as a non-empty vector of integers.
/// E.g. the permutation `(i,j,k) -> (j,k,i)` will be expressed with
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
index bbcde44f08618..45b35d3e01a6a 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -605,8 +605,18 @@ static Value buildVectorWrite(RewriterBase &rewriter, Value value,
Location loc = value.getLoc();
auto linalgOp = cast<LinalgOp>(outputOperand->getOwner());
AffineMap opOperandMap = linalgOp.getMatchingIndexingMap(outputOperand);
+
+ // Compute the vector type of the value to store. This type should be an
+ // identity or projection of the canonical vector type without any permutation
+ // applied, given that any permutation in a transfer write happens as part of
+ // the write itself.
+ AffineMap vectorTypeMap = AffineMap::getFilteredIdentityMap(
+ opOperandMap.getContext(), opOperandMap.getNumInputs(),
+ [&](AffineDimExpr dimExpr) -> bool {
+ return llvm::is_contained(opOperandMap.getResults(), dimExpr);
+ });
auto vectorType = state.getCanonicalVecType(
- getElementTypeOrSelf(outputOperand->get().getType()), opOperandMap);
+ getElementTypeOrSelf(outputOperand->get().getType()), vectorTypeMap);
Operation *write;
if (vectorType.getRank() > 0) {
@@ -614,13 +624,14 @@ static Value buildVectorWrite(RewriterBase &rewriter, Value value,
SmallVector<Value> indices(linalgOp.getRank(outputOperand),
rewriter.create<arith::ConstantIndexOp>(loc, 0));
value = broadcastIfNeeded(rewriter, value, vectorType);
+ assert(value.getType() == vectorType && "Incorrect type");
write = rewriter.create<vector::TransferWriteOp>(
loc, value, outputOperand->get(), indices, writeMap);
} else {
// 0-d case is still special: do not invert the reindexing writeMap.
if (!isa<VectorType>(value.getType()))
value = rewriter.create<vector::BroadcastOp>(loc, vectorType, value);
- assert(value.getType() == vectorType && "incorrect type");
+ assert(value.getType() == vectorType && "Incorrect type");
write = rewriter.create<vector::TransferWriteOp>(
loc, value, outputOperand->get(), ValueRange{});
}
diff --git a/mlir/lib/IR/AffineMap.cpp b/mlir/lib/IR/AffineMap.cpp
index 4a67010b7a3a0..9cdac964710ca 100644
--- a/mlir/lib/IR/AffineMap.cpp
+++ b/mlir/lib/IR/AffineMap.cpp
@@ -113,6 +113,19 @@ AffineMap AffineMap::getMinorIdentityMap(unsigned dims, unsigned results,
return AffineMap::get(dims, 0, id.getResults().take_back(results), context);
}
+AffineMap AffineMap::getFilteredIdentityMap(
+ MLIRContext *ctx, unsigned numDims,
+ llvm::function_ref<bool(AffineDimExpr)> keepDimFilter) {
+ auto identityMap = getMultiDimIdentityMap(numDims, ctx);
+
+ // Apply filter to results.
+ llvm::SmallBitVector dropDimResults(numDims);
+ for (auto [idx, resultExpr] : llvm::enumerate(identityMap.getResults()))
+ dropDimResults[idx] = !keepDimFilter(resultExpr.cast<AffineDimExpr>());
+
+ return identityMap.dropResults(dropDimResults);
+}
+
bool AffineMap::isMinorIdentity() const {
return getNumDims() >= getNumResults() &&
*this ==
diff --git a/mlir/test/Dialect/Linalg/vectorization.mlir b/mlir/test/Dialect/Linalg/vectorization.mlir
index 130c6bcc11abb..933c9c7864988 100644
--- a/mlir/test/Dialect/Linalg/vectorization.mlir
+++ b/mlir/test/Dialect/Linalg/vectorization.mlir
@@ -1751,3 +1751,38 @@ transform.sequence failures(propagate) {
// CHECK: vector.broadcast %{{.*}} : f32 to vector<f32>
// CHECK: vector.transfer_write {{.*}} : vector<f32>, tensor<f32>
+// -----
+
+// Make sure we generate the right transfer writes for multi-output generic ops
+// with different permutation maps.
+
+func.func @multi_output_generic_different_perm_maps(%in0: tensor<4x1xf32>,
+ %out0: tensor<4x1xf32>,
+ %out1: tensor<1x4xf32>) -> (tensor<4x1xf32>, tensor<1x4xf32>) {
+ %13:2 = linalg.generic {indexing_maps = [ affine_map<(d0, d1) -> (d1, d0)>,
+ affine_map<(d0, d1) -> (d1, d0)>,
+ affine_map<(d0, d1) -> (d0, d1)> ],
+ iterator_types = ["parallel", "parallel"]}
+ ins(%in0 : tensor<4x1xf32>)
+ outs(%out0, %out1 : tensor<4x1xf32>, tensor<1x4xf32>) {
+ ^bb0(%in: f32, %out: f32, %out_2: f32):
+ %16 = arith.addf %in, %in : f32
+ linalg.yield %16, %16 : f32, f32
+ } -> (tensor<4x1xf32>, tensor<1x4xf32>)
+ return %13#0, %13#1 : tensor<4x1xf32>, tensor<1x4xf32>
+}
+
+transform.sequence failures(propagate) {
+^bb1(%arg1: !transform.any_op):
+ %3 = transform.structured.match ops{["linalg.generic"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ %4 = get_closest_isolated_parent %3 : (!transform.any_op) -> !transform.any_op
+ %5 = transform.structured.vectorize %4 : (!transform.any_op) -> !transform.any_op
+}
+
+// CHECK-LABEL: func @multi_output_generic_different_perm_maps
+// CHECK: %[[VAL_5:.*]] = vector.transfer_read %{{.*}} {in_bounds = [true, true]} : tensor<4x1xf32>, vector<4x1xf32>
+// CHECK: %[[VAL_6:.*]] = arith.addf %[[VAL_5]], %[[VAL_5]] : vector<4x1xf32>
+// CHECK: %[[VAL_7:.*]] = vector.transpose %[[VAL_6]], [1, 0] : vector<4x1xf32> to vector<1x4xf32>
+// CHECK: %[[VAL_8:.*]] = vector.transpose %[[VAL_7]], [1, 0] : vector<1x4xf32> to vector<4x1xf32>
+// CHECK: vector.transfer_write %[[VAL_8]], %{{.*}} {in_bounds = [true, true]} : vector<4x1xf32>, tensor<4x1xf32>
+// CHECK: vector.transfer_write %[[VAL_7]], %{{.*}} {in_bounds = [true, true]} : vector<1x4xf32>, tensor<1x4xf32>
More information about the Mlir-commits
mailing list