[Mlir-commits] [mlir] a819e73 - [mlir] Support broadcast dimensions in ProgressiveVectorToSCF
Matthias Springer
llvmlistbot at llvm.org
Fri Apr 23 02:01:59 PDT 2021
Author: Matthias Springer
Date: 2021-04-23T18:01:32+09:00
New Revision: a819e7339315687f06f686971a649f614afbd987
URL: https://github.com/llvm/llvm-project/commit/a819e7339315687f06f686971a649f614afbd987
DIFF: https://github.com/llvm/llvm-project/commit/a819e7339315687f06f686971a649f614afbd987.diff
LOG: [mlir] Support broadcast dimensions in ProgressiveVectorToSCF
This commit adds support for broadcast dimensions in permutation maps of vector transfer ops.
Also fixes a bug in VectorToSCF that generated incorrect in-bounds checks for broadcast dimensions.
Differential Revision: https://reviews.llvm.org/D101019
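
For illustration, a broadcast shows up as the constant expression 0 in a
result of the transfer op's permutation map (mirroring the integration
tests added below); such a result does not index any memref dimension:

  %f = vector.transfer_read %A[%base1, %base2], %fm42
      {permutation_map = affine_map<(d0, d1) -> (0)>}
      : memref<?x?xf32>, vector<9xf32>

Here every lane of the 9-element result is filled from the same memref
element, so no loop index or in-bounds check applies to that vector
dimension.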
Added:
Modified:
mlir/lib/Conversion/VectorToSCF/ProgressiveVectorToSCF.cpp
mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp
mlir/test/Integration/Dialect/Vector/CPU/test-transfer-read-1d.mlir
mlir/test/Integration/Dialect/Vector/CPU/test-transfer-read-2d.mlir
mlir/test/Integration/Dialect/Vector/CPU/test-transfer-read-3d.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Conversion/VectorToSCF/ProgressiveVectorToSCF.cpp b/mlir/lib/Conversion/VectorToSCF/ProgressiveVectorToSCF.cpp
index 05f8dc4b4856..3eb6072e7979 100644
--- a/mlir/lib/Conversion/VectorToSCF/ProgressiveVectorToSCF.cpp
+++ b/mlir/lib/Conversion/VectorToSCF/ProgressiveVectorToSCF.cpp
@@ -70,13 +70,16 @@ static Value setAllocAtFunctionEntry(MemRefType type, Operation *op) {
/// Given a vector transfer op, calculate which dimension of the `source`
/// memref should be unpacked in the next application of TransferOpConversion.
+/// A return value of None indicates a broadcast.
template <typename OpTy>
-static unsigned unpackedDim(OpTy xferOp) {
+static Optional<int64_t> unpackedDim(OpTy xferOp) {
auto map = xferOp.permutation_map();
- // TODO: Handle broadcast
- auto expr = map.getResult(0).template dyn_cast<AffineDimExpr>();
- assert(expr && "Expected AffineDimExpr in permutation map result");
- return expr.getPosition();
+ if (auto expr = map.getResult(0).template dyn_cast<AffineDimExpr>())
+ return expr.getPosition();
+
+ assert(map.getResult(0).template isa<AffineConstantExpr>() &&
+ "Expected AffineDimExpr or AffineConstantExpr");
+ return None;
}
/// Compute the permutation map for the new (N-1)-D vector transfer op. This
@@ -103,8 +106,12 @@ static void getXferIndices(OpTy xferOp, Value iv,
auto dim = unpackedDim(xferOp);
auto prevIndices = adaptor.indices();
indices.append(prevIndices.begin(), prevIndices.end());
- using edsc::op::operator+;
- indices[dim] = adaptor.indices()[dim] + iv;
+
+ bool isBroadcast = !dim.hasValue();
+ if (!isBroadcast) {
+ using edsc::op::operator+;
+ indices[dim.getValue()] = adaptor.indices()[dim.getValue()] + iv;
+ }
}
static void maybeYieldValue(bool hasRetVal, OpBuilder builder, Location loc,
@@ -116,7 +123,7 @@ static void maybeYieldValue(bool hasRetVal, OpBuilder builder, Location loc,
}
}
-/// Helper function TransferOpConversion and Strided1dTransferOpConversion.
+/// Helper function TransferOpConversion and TransferOp1dConversion.
/// Generate an in-bounds check if the transfer op may go out-of-bounds on the
/// specified dimension `dim` with the loop iteration variable `iv`.
/// E.g., when unpacking dimension 0 from:
@@ -138,15 +145,17 @@ static void maybeYieldValue(bool hasRetVal, OpBuilder builder, Location loc,
/// `resultTypes`.
template <typename OpTy>
static Value generateInBoundsCheck(
- OpTy xferOp, Value iv, OpBuilder &builder, unsigned dim,
+ OpTy xferOp, Value iv, OpBuilder &builder, Optional<int64_t> dim,
TypeRange resultTypes,
function_ref<Value(OpBuilder &, Location)> inBoundsCase,
function_ref<Value(OpBuilder &, Location)> outOfBoundsCase = nullptr) {
bool hasRetVal = !resultTypes.empty();
- if (!xferOp.isDimInBounds(0)) {
- auto memrefDim = memref_dim(xferOp.source(), std_constant_index(dim));
+ bool isBroadcast = !dim.hasValue(); // No in-bounds check for broadcasts.
+ if (!xferOp.isDimInBounds(0) && !isBroadcast) {
+ auto memrefDim =
+ memref_dim(xferOp.source(), std_constant_index(dim.getValue()));
using edsc::op::operator+;
- auto memrefIdx = xferOp.indices()[dim] + iv;
+ auto memrefIdx = xferOp.indices()[dim.getValue()] + iv;
auto cond = std_cmpi_sgt(memrefDim.value, memrefIdx);
auto check = builder.create<scf::IfOp>(
xferOp.getLoc(), resultTypes, cond,
@@ -175,7 +184,7 @@ static Value generateInBoundsCheck(
/// a return value. Consequently, this function does not have a return value.
template <typename OpTy>
static void generateInBoundsCheck(
- OpTy xferOp, Value iv, OpBuilder &builder, int64_t dim,
+ OpTy xferOp, Value iv, OpBuilder &builder, Optional<int64_t> dim,
function_ref<void(OpBuilder &, Location)> inBoundsCase,
function_ref<void(OpBuilder &, Location)> outOfBoundsCase = nullptr) {
generateInBoundsCheck(
@@ -534,27 +543,31 @@ struct TransferOpConversion : public OpRewritePattern<OpTy> {
};
/// Compute the indices into the memref for the LoadOp/StoreOp generated as
-/// part of Strided1dTransferOpConversion. Return the memref dimension on which
-/// the transfer is operating.
+/// part of TransferOp1dConversion. Return the memref dimension on which
+/// the transfer is operating. A return value of None indicates a broadcast.
template <typename OpTy>
-static unsigned get1dMemrefIndices(OpTy xferOp, Value iv,
- SmallVector<Value, 8> &memrefIndices) {
+static Optional<int64_t>
+get1dMemrefIndices(OpTy xferOp, Value iv,
+ SmallVector<Value, 8> &memrefIndices) {
auto indices = xferOp.indices();
auto map = xferOp.permutation_map();
memrefIndices.append(indices.begin(), indices.end());
assert(map.getNumResults() == 1 &&
"Expected 1 permutation map result for 1D transfer");
- // TODO: Handle broadcast
- auto expr = map.getResult(0).template dyn_cast<AffineDimExpr>();
- assert(expr && "Expected AffineDimExpr in permutation map result");
- auto dim = expr.getPosition();
- using edsc::op::operator+;
- memrefIndices[dim] = memrefIndices[dim] + iv;
- return dim;
+ if (auto expr = map.getResult(0).template dyn_cast<AffineDimExpr>()) {
+ auto dim = expr.getPosition();
+ using edsc::op::operator+;
+ memrefIndices[dim] = memrefIndices[dim] + iv;
+ return dim;
+ }
+
+ assert(map.getResult(0).template isa<AffineConstantExpr>() &&
+ "Expected AffineDimExpr or AffineConstantExpr");
+ return None;
}
-/// Codegen strategy for Strided1dTransferOpConversion, depending on the
+/// Codegen strategy for TransferOp1dConversion, depending on the
/// operation.
template <typename OpTy>
struct Strategy1d;
@@ -613,14 +626,24 @@ struct Strategy1d<TransferWriteOp> {
static Value initialLoopState(TransferWriteOp xferOp) { return Value(); }
};
-/// Lower a 1D vector transfer op that operates on a dimension different from
-/// the last one. Instead of accessing contiguous chunks (vectors) of memory,
-/// such ops access memory in a strided fashion.
+/// Lower a 1D vector transfer op to SCF using scalar loads/stores. This is
+/// necessary in cases where a 1D vector transfer op cannot be lowered into
+/// vector load/stores due to non-unit strides or broadcasts:
+///
+/// * Transfer dimension is not the last memref dimension
+/// * Transfer dimension is a broadcast (i.e., scalar load + broadcast)
+/// * Memref has a layout map with non-unit stride on the last dimension
+///
+/// This pattern generates IR as follows:
///
/// 1. Generate a for loop iterating over each vector element.
/// 2. Inside the loop, generate a InsertElementOp or ExtractElementOp,
/// depending on OpTy.
///
+/// TODO: In some cases (no masking, etc.), LLVM::MatrixColumnMajorLoadOp
+/// can be generated instead of TransferOp1dConversion. Add such a pattern
+/// to ConvertVectorToLLVM.
+///
/// E.g.:
/// ```
/// vector.transfer_write %vec, %A[%a, %b]
@@ -635,7 +658,7 @@ struct Strategy1d<TransferWriteOp> {
/// }
/// ```
template <typename OpTy>
-struct Strided1dTransferOpConversion : public OpRewritePattern<OpTy> {
+struct TransferOp1dConversion : public OpRewritePattern<OpTy> {
using OpRewritePattern<OpTy>::OpRewritePattern;
LogicalResult matchAndRewrite(OpTy xferOp,
@@ -681,8 +704,8 @@ void populateProgressiveVectorToSCFConversionPatterns(
TransferOpConversion<TransferWriteOp>>(patterns.getContext());
if (kTargetRank == 1) {
- patterns.add<Strided1dTransferOpConversion<TransferReadOp>,
- Strided1dTransferOpConversion<TransferWriteOp>>(
+ patterns.add<TransferOp1dConversion<TransferReadOp>,
+ TransferOp1dConversion<TransferWriteOp>>(
patterns.getContext());
}
}
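
To make the new 1-D broadcast path concrete, here is a rough sketch (not
taken from the commit; the exact op spellings and padding handling are
assumptions) of the loop that TransferOp1dConversion produces for a
broadcast read:

  func @read_1d_broadcast_lowered(%A : memref<?x?xf32>,
                                  %base1 : index, %base2 : index) -> vector<9xf32> {
    %c0 = constant 0 : index
    %c1 = constant 1 : index
    %c9 = constant 9 : index
    %fm42 = constant -42.0 : f32
    // Initial loop state: the result vector filled with the padding value.
    %v0 = splat %fm42 : vector<9xf32>
    %vec = scf.for %i = %c0 to %c9 step %c1
        iter_args(%v = %v0) -> (vector<9xf32>) {
      // The permutation map result is the constant 0, so get1dMemrefIndices
      // does not advance the memref indices by %i: every iteration loads
      // the same scalar, and no in-bounds check is generated for the
      // broadcast dimension.
      %s = memref.load %A[%base1, %base2] : memref<?x?xf32>
      %v2 = vector.insertelement %s, %v[%i : index] : vector<9xf32>
      scf.yield %v2 : vector<9xf32>
    }
    return %vec : vector<9xf32>
  }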
diff --git a/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp b/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp
index 72d32d071e49..4f13e7d8e5af 100644
--- a/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp
+++ b/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp
@@ -230,7 +230,10 @@ emitInBoundsCondition(PatternRewriter &rewriter,
Value iv = std::get<0>(it), off = std::get<1>(it), ub = std::get<2>(it);
using namespace mlir::edsc::op;
majorIvsPlusOffsets.push_back(iv + off);
- if (!xferOp.isDimInBounds(leadingRank + idx)) {
+ auto affineConstExpr =
+ xferOp.permutation_map().getResult(idx).dyn_cast<AffineConstantExpr>();
+ bool isBroadcast = affineConstExpr && affineConstExpr.getValue() == 0;
+ if (!xferOp.isDimInBounds(leadingRank + idx) && !isBroadcast) {
Value inBoundsCond = onTheFlyFoldSLT(majorIvsPlusOffsets.back(), ub);
if (inBoundsCond)
inBoundsCondition = (inBoundsCondition)
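
The skipped check matters because a constant-0 map result has no
corresponding memref dimension to compare against. Taking the 2-D
integration test added below:

  %f = vector.transfer_read %A[%base1, %base2], %fm42
      {permutation_map = affine_map<(d0, d1) -> (d1, 0)>}
      : memref<?x?xf32>, vector<4x9xf32>

The second vector dimension (size 9) never advances a memref index, so
bounds-checking its induction variable against a memref size would
spuriously replace in-bounds lanes with the padding value. After the fix,
only vector dimension 0 (which walks d1) is checked against the memref's
size.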
diff --git a/mlir/test/Integration/Dialect/Vector/CPU/test-transfer-read-1d.mlir b/mlir/test/Integration/Dialect/Vector/CPU/test-transfer-read-1d.mlir
index 17f635f7b78a..b6bd8c404158 100644
--- a/mlir/test/Integration/Dialect/Vector/CPU/test-transfer-read-1d.mlir
+++ b/mlir/test/Integration/Dialect/Vector/CPU/test-transfer-read-1d.mlir
@@ -10,6 +10,15 @@
// Test for special cases of 1D vector transfer ops.
+func @transfer_read_2d(%A : memref<?x?xf32>, %base1 : index, %base2 : index) {
+ %fm42 = constant -42.0: f32
+ %f = vector.transfer_read %A[%base1, %base2], %fm42
+ {permutation_map = affine_map<(d0, d1) -> (d0, d1)>}
+ : memref<?x?xf32>, vector<5x6xf32>
+ vector.print %f: vector<5x6xf32>
+ return
+}
+
func @transfer_read_1d(%A : memref<?x?xf32>, %base1 : index, %base2 : index) {
%fm42 = constant -42.0: f32
%f = vector.transfer_read %A[%base1, %base2], %fm42
@@ -19,6 +28,16 @@ func @transfer_read_1d(%A : memref<?x?xf32>, %base1 : index, %base2 : index) {
return
}
+func @transfer_read_1d_broadcast(
+ %A : memref<?x?xf32>, %base1 : index, %base2 : index) {
+ %fm42 = constant -42.0: f32
+ %f = vector.transfer_read %A[%base1, %base2], %fm42
+ {permutation_map = affine_map<(d0, d1) -> (0)>}
+ : memref<?x?xf32>, vector<9xf32>
+ vector.print %f: vector<9xf32>
+ return
+}
+
func @transfer_write_1d(%A : memref<?x?xf32>, %base1 : index, %base2 : index) {
%fn1 = constant -1.0 : f32
%vf0 = splat %fn1 : vector<7xf32>
@@ -53,8 +72,11 @@ func @entry() {
call @transfer_read_1d(%A, %c1, %c2) : (memref<?x?xf32>, index, index) -> ()
call @transfer_write_1d(%A, %c3, %c2) : (memref<?x?xf32>, index, index) -> ()
call @transfer_read_1d(%A, %c0, %c2) : (memref<?x?xf32>, index, index) -> ()
+ call @transfer_read_1d_broadcast(%A, %c1, %c2)
+ : (memref<?x?xf32>, index, index) -> ()
return
}
// CHECK: ( 12, 22, 32, 42, -42, -42, -42, -42, -42 )
// CHECK: ( 2, 12, 22, -1, -1, -42, -42, -42, -42 )
+// CHECK: ( 12, 12, 12, 12, 12, 12, 12, 12, 12 )
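
The new CHECK line follows from the semantics of a broadcast read: with
all indices in bounds, it is equivalent to a scalar load followed by a
splat (a sketch, assuming the values that @entry stores into %A):

  %s = memref.load %A[%c1, %c2] : memref<?x?xf32>   // 12
  %f = splat %s : vector<9xf32>                     // ( 12, 12, ..., 12 )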
diff --git a/mlir/test/Integration/Dialect/Vector/CPU/test-transfer-read-2d.mlir b/mlir/test/Integration/Dialect/Vector/CPU/test-transfer-read-2d.mlir
index fbcf94a6233c..cbe0aa52a437 100644
--- a/mlir/test/Integration/Dialect/Vector/CPU/test-transfer-read-2d.mlir
+++ b/mlir/test/Integration/Dialect/Vector/CPU/test-transfer-read-2d.mlir
@@ -27,6 +27,16 @@ func @transfer_read_2d_transposed(
return
}
+func @transfer_read_2d_broadcast(
+ %A : memref<?x?xf32>, %base1: index, %base2: index) {
+ %fm42 = constant -42.0: f32
+ %f = vector.transfer_read %A[%base1, %base2], %fm42
+ {permutation_map = affine_map<(d0, d1) -> (d1, 0)>} :
+ memref<?x?xf32>, vector<4x9xf32>
+ vector.print %f: vector<4x9xf32>
+ return
+}
+
func @transfer_write_2d(%A : memref<?x?xf32>, %base1: index, %base2: index) {
%fn1 = constant -1.0 : f32
%vf0 = splat %fn1 : vector<1x4xf32>
@@ -73,6 +83,9 @@ func @entry() {
// Same as above, but transposed
call @transfer_read_2d_transposed(%A, %c0, %c0)
: (memref<?x?xf32>, index, index) -> ()
+ // Second vector dimension is a broadcast
+ call @transfer_read_2d_broadcast(%A, %c1, %c2)
+ : (memref<?x?xf32>, index, index) -> ()
return
}
@@ -80,3 +93,4 @@ func @entry() {
// CHECK: ( ( 12, 22, -42, -42, -42, -42, -42, -42, -42 ), ( 13, 23, -42, -42, -42, -42, -42, -42, -42 ), ( -42, -42, -42, -42, -42, -42, -42, -42, -42 ), ( -42, -42, -42, -42, -42, -42, -42, -42, -42 ) )
// CHECK: ( ( 0, 1, 2, 3, -42, -42, -42, -42, -42 ), ( 10, 11, 12, 13, -42, -42, -42, -42, -42 ), ( 20, 21, 22, 23, -42, -42, -42, -42, -42 ), ( -42, -42, -42, -42, -42, -42, -42, -42, -42 ) )
// CHECK: ( ( 0, 10, 20, -42, -42, -42, -42, -42, -42 ), ( 1, 11, 21, -42, -42, -42, -42, -42, -42 ), ( 2, 12, 22, -42, -42, -42, -42, -42, -42 ), ( 3, 13, 23, -42, -42, -42, -42, -42, -42 ) )
+// CHECK: ( ( 12, 12, 12, 12, 12, 12, 12, 12, 12 ), ( 13, 13, 13, 13, 13, 13, 13, 13, 13 ), ( -42, -42, -42, -42, -42, -42, -42, -42, -42 ), ( -42, -42, -42, -42, -42, -42, -42, -42, -42 ) )
diff --git a/mlir/test/Integration/Dialect/Vector/CPU/test-transfer-read-3d.mlir b/mlir/test/Integration/Dialect/Vector/CPU/test-transfer-read-3d.mlir
index 7ecac4a38938..ae7fee3c9110 100644
--- a/mlir/test/Integration/Dialect/Vector/CPU/test-transfer-read-3d.mlir
+++ b/mlir/test/Integration/Dialect/Vector/CPU/test-transfer-read-3d.mlir
@@ -19,6 +19,16 @@ func @transfer_read_3d(%A : memref<?x?x?x?xf32>,
return
}
+func @transfer_read_3d_broadcast(%A : memref<?x?x?x?xf32>,
+ %o: index, %a: index, %b: index, %c: index) {
+ %fm42 = constant -42.0: f32
+ %f = vector.transfer_read %A[%o, %a, %b, %c], %fm42
+ {permutation_map = affine_map<(d0, d1, d2, d3) -> (d1, 0, d3)>}
+ : memref<?x?x?x?xf32>, vector<2x5x3xf32>
+ vector.print %f: vector<2x5x3xf32>
+ return
+}
+
func @transfer_read_3d_transposed(%A : memref<?x?x?x?xf32>,
%o: index, %a: index, %b: index, %c: index) {
%fm42 = constant -42.0: f32
@@ -78,9 +88,12 @@ func @entry() {
: (memref<?x?x?x?xf32>, index, index, index, index) -> ()
call @transfer_read_3d_transposed(%A, %c0, %c0, %c0, %c0)
: (memref<?x?x?x?xf32>, index, index, index, index) -> ()
+ call @transfer_read_3d_broadcast(%A, %c0, %c0, %c0, %c0)
+ : (memref<?x?x?x?xf32>, index, index, index, index) -> ()
return
}
// CHECK: ( ( ( 0, 0, -42 ), ( 2, 3, -42 ), ( 4, 6, -42 ), ( 6, 9, -42 ), ( -42, -42, -42 ) ), ( ( 20, 30, -42 ), ( 22, 33, -42 ), ( 24, 36, -42 ), ( 26, 39, -42 ), ( -42, -42, -42 ) ) )
// CHECK: ( ( ( 0, 0, -42 ), ( 2, -1, -42 ), ( 4, -1, -42 ), ( 6, -1, -42 ), ( -42, -42, -42 ) ), ( ( 20, 30, -42 ), ( 22, -1, -42 ), ( 24, -1, -42 ), ( 26, -1, -42 ), ( -42, -42, -42 ) ) )
// CHECK: ( ( ( 0, 20, 40 ), ( 0, 20, 40 ), ( 0, 20, 40 ), ( 0, 20, 40 ), ( 0, 20, 40 ) ), ( ( 0, 30, 60 ), ( 0, 30, 60 ), ( 0, 30, 60 ), ( 0, 30, 60 ), ( 0, 30, 60 ) ), ( ( -42, -42, -42 ), ( -42, -42, -42 ), ( -42, -42, -42 ), ( -42, -42, -42 ), ( -42, -42, -42 ) ) )
+// CHECK: ( ( ( 0, 0, -42 ), ( 0, 0, -42 ), ( 0, 0, -42 ), ( 0, 0, -42 ), ( 0, 0, -42 ) ), ( ( 20, 30, -42 ), ( 20, 30, -42 ), ( 20, 30, -42 ), ( 20, 30, -42 ), ( 20, 30, -42 ) ) )