[Mlir-commits] [mlir] ab15417 - [mlir] Support dimension permutations in ProgressiveVectorToSCF
Matthias Springer
llvmlistbot at llvm.org
Fri Apr 23 01:47:11 PDT 2021
Author: Matthias Springer
Date: 2021-04-23T17:46:35+09:00
New Revision: ab154176bfc7891979b9cd406d2e952a1764f406
URL: https://github.com/llvm/llvm-project/commit/ab154176bfc7891979b9cd406d2e952a1764f406
DIFF: https://github.com/llvm/llvm-project/commit/ab154176bfc7891979b9cd406d2e952a1764f406.diff
LOG: [mlir] Support dimension permutations in ProgressiveVectorToSCF
This commit adds support for dimension permutations in permutation maps of vector transfer ops.
Differential Revision: https://reviews.llvm.org/D101007
Added:
Modified:
mlir/lib/Conversion/VectorToSCF/ProgressiveVectorToSCF.cpp
mlir/test/Integration/Dialect/Vector/CPU/test-transfer-read-2d.mlir
mlir/test/Integration/Dialect/Vector/CPU/test-transfer-read-3d.mlir
mlir/test/Integration/Dialect/Vector/CPU/test-transfer-to-loops.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Conversion/VectorToSCF/ProgressiveVectorToSCF.cpp b/mlir/lib/Conversion/VectorToSCF/ProgressiveVectorToSCF.cpp
index 785da3b330b9..2e4a0725b4de 100644
--- a/mlir/lib/Conversion/VectorToSCF/ProgressiveVectorToSCF.cpp
+++ b/mlir/lib/Conversion/VectorToSCF/ProgressiveVectorToSCF.cpp
@@ -71,8 +71,22 @@ static Value setAllocAtFunctionEntry(MemRefType type, Operation *op) {
/// Given a vector transfer op, calculate which dimension of the `source`
/// memref should be unpacked in the next application of TransferOpConversion.
template <typename OpTy>
-static int64_t unpackedDim(OpTy xferOp) {
- return xferOp.getShapedType().getRank() - xferOp.getVectorType().getRank();
+static unsigned unpackedDim(OpTy xferOp) {
+ auto map = xferOp.permutation_map();
+ // TODO: Handle broadcast
+ auto expr = map.getResult(0).template dyn_cast<AffineDimExpr>();
+ assert(expr && "Expected AffineDimExpr in permutation map result");
+ return expr.getPosition();
+}
+
+/// Compute the permutation map for the new (N-1)-D vector transfer op. This
+/// map is identical to the current permutation map, but the first result is
+/// omitted.
+template <typename OpTy>
+static AffineMap unpackedPermutationMap(OpTy xferOp, OpBuilder &builder) {
+ auto map = xferOp.permutation_map();
+ return AffineMap::get(map.getNumDims(), 0, map.getResults().drop_front(),
+ builder.getContext());
}
/// Calculate the indices for the new vector transfer op.
@@ -93,8 +107,8 @@ static void getXferIndices(OpTy xferOp, Value iv,
indices[dim] = adaptor.indices()[dim] + iv;
}
-static void maybeYieldValue(bool hasRetVal, OpBuilder builder, Location loc,
- Value value) {
+static void maybeYieldValue(
+ bool hasRetVal, OpBuilder builder, Location loc, Value value) {
if (hasRetVal) {
builder.create<scf::YieldOp>(loc, value);
} else {
@@ -124,7 +138,7 @@ static void maybeYieldValue(bool hasRetVal, OpBuilder builder, Location loc,
/// `resultTypes`.
template <typename OpTy>
static Value generateInBoundsCheck(
- OpTy xferOp, Value iv, OpBuilder &builder, int64_t dim,
+ OpTy xferOp, Value iv, OpBuilder &builder, unsigned dim,
TypeRange resultTypes,
function_ref<Value(OpBuilder &, Location)> inBoundsCase,
function_ref<Value(OpBuilder &, Location)> outOfBoundsCase = nullptr) {
@@ -136,19 +150,15 @@ static Value generateInBoundsCheck(
auto cond = std_cmpi_sgt(memrefDim.value, memrefIdx);
auto check = builder.create<scf::IfOp>(
xferOp.getLoc(), resultTypes, cond,
- /*thenBuilder=*/
- [&](OpBuilder &builder, Location loc) {
- maybeYieldValue(hasRetVal, builder, loc, inBoundsCase(builder, loc));
- },
- /*elseBuilder=*/
- [&](OpBuilder &builder, Location loc) {
- if (outOfBoundsCase) {
- maybeYieldValue(hasRetVal, builder, loc,
- outOfBoundsCase(builder, loc));
- } else {
- builder.create<scf::YieldOp>(loc);
- }
- });
+ /*thenBuilder=*/[&](OpBuilder &builder, Location loc) {
+ maybeYieldValue(hasRetVal, builder, loc, inBoundsCase(builder, loc));
+ }, /*elseBuilder=*/[&](OpBuilder &builder, Location loc) {
+ if (outOfBoundsCase) {
+ maybeYieldValue(hasRetVal, builder, loc, outOfBoundsCase(builder, loc));
+ } else {
+ builder.create<scf::YieldOp>(loc);
+ }
+ });
return hasRetVal ? check.getResult(0) : Value();
}
@@ -166,15 +176,13 @@ static void generateInBoundsCheck(
function_ref<void(OpBuilder &, Location)> outOfBoundsCase = nullptr) {
generateInBoundsCheck(
xferOp, iv, builder, dim, /*resultTypes=*/TypeRange(),
- /*inBoundsCase=*/
- [&](OpBuilder &builder, Location loc) {
+ /*inBoundsCase=*/[&](OpBuilder &builder, Location loc) {
inBoundsCase(builder, loc);
return Value();
},
- /*outOfBoundsCase=*/
- [&](OpBuilder &builder, Location loc) {
+ /*outOfBoundsCase=*/[&](OpBuilder &builder, Location loc) {
if (outOfBoundsCase)
- outOfBoundsCase(builder, loc);
+ outOfBoundsCase(builder, loc);
return Value();
});
}
@@ -182,7 +190,7 @@ static void generateInBoundsCheck(
/// Given an ArrayAttr, return a copy where the first element is dropped.
static ArrayAttr dropFirstElem(PatternRewriter &rewriter, ArrayAttr attr) {
if (!attr)
- return attr;
+ return attr;
return ArrayAttr::get(rewriter.getContext(), attr.getValue().drop_front());
}
@@ -191,13 +199,14 @@ template <typename OpTy>
struct Strategy;
/// Code strategy for vector TransferReadOp.
-template <>
+template<>
struct Strategy<TransferReadOp> {
/// Find the StoreOp that is used for writing the current TransferReadOp's
/// result to the temporary buffer allocation.
static memref::StoreOp getStoreOp(TransferReadOp xferOp) {
assert(xferOp->hasOneUse() && "Expected exactly one use of TransferReadOp");
- auto storeOp = dyn_cast<memref::StoreOp>((*xferOp->use_begin()).getOwner());
+ auto storeOp = dyn_cast<memref::StoreOp>(
+ (*xferOp->use_begin()).getOwner());
assert(storeOp && "Expected TransferReadOp result used by StoreOp");
return storeOp;
}
@@ -215,7 +224,7 @@ struct Strategy<TransferReadOp> {
/// Retrieve the indices of the current StoreOp.
static void getStoreIndices(TransferReadOp xferOp,
- SmallVector<Value, 8> &indices) {
+ SmallVector<Value, 8> &indices) {
auto storeOp = getStoreOp(xferOp);
auto prevIndices = memref::StoreOpAdaptor(storeOp).indices();
indices.append(prevIndices.begin(), prevIndices.end());
@@ -258,24 +267,25 @@ struct Strategy<TransferReadOp> {
auto bufferType = buffer.getType().dyn_cast<ShapedType>();
auto vecType = bufferType.getElementType().dyn_cast<VectorType>();
- auto map = getTransferMinorIdentityMap(xferOp.getShapedType(), vecType);
auto inBoundsAttr = dropFirstElem(rewriter, xferOp.in_boundsAttr());
- auto newXfer = vector_transfer_read(vecType, xferOp.source(), xferIndices,
- AffineMapAttr::get(map),
- xferOp.padding(), Value(), inBoundsAttr)
- .value;
+ auto newXfer =
+ vector_transfer_read(
+ vecType, xferOp.source(), xferIndices,
+ AffineMapAttr::get(unpackedPermutationMap(xferOp, rewriter)),
+ xferOp.padding(), Value(), inBoundsAttr)
+ .value;
if (vecType.getRank() > kTargetRank)
- newXfer.getDefiningOp()->setAttr(kPassLabel, rewriter.getUnitAttr());
+ newXfer.getDefiningOp()->setAttr(kPassLabel, rewriter.getUnitAttr());
memref_store(newXfer, buffer, storeIndices);
}
/// Handle out-of-bounds accesses on the to-be-unpacked dimension: Write
/// padding value to the temporary buffer.
- static void handleOutOfBoundsDim(PatternRewriter &rewriter,
- TransferReadOp xferOp, Value buffer,
- Value iv) {
+ static void handleOutOfBoundsDim(
+ PatternRewriter &rewriter, TransferReadOp xferOp, Value buffer,
+ Value iv) {
SmallVector<Value, 8> storeIndices;
getStoreIndices(xferOp, storeIndices);
storeIndices.push_back(iv);
@@ -294,7 +304,7 @@ struct Strategy<TransferReadOp> {
};
/// Codegen strategy for vector TransferWriteOp.
-template <>
+template<>
struct Strategy<TransferWriteOp> {
/// Find the temporary buffer allocation. All labeled TransferWriteOps are
/// used like this, where %buf is either the buffer allocation or a type cast
@@ -337,20 +347,20 @@ struct Strategy<TransferWriteOp> {
auto vec = memref_load(buffer, loadIndices);
auto vecType = vec.value.getType().dyn_cast<VectorType>();
- auto map = getTransferMinorIdentityMap(xferOp.getShapedType(), vecType);
auto inBoundsAttr = dropFirstElem(rewriter, xferOp.in_boundsAttr());
- auto newXfer =
- vector_transfer_write(Type(), vec, xferOp.source(), xferIndices,
- AffineMapAttr::get(map), Value(), inBoundsAttr);
+ auto newXfer = vector_transfer_write(
+ Type(), vec, xferOp.source(), xferIndices,
+ AffineMapAttr::get(unpackedPermutationMap(xferOp, rewriter)), Value(),
+ inBoundsAttr);
if (vecType.getRank() > kTargetRank)
- newXfer.op->setAttr(kPassLabel, rewriter.getUnitAttr());
+ newXfer.op->setAttr(kPassLabel, rewriter.getUnitAttr());
}
/// Handle out-of-bounds accesses on the to-be-unpacked dimension.
- static void handleOutOfBoundsDim(PatternRewriter &rewriter,
- TransferWriteOp xferOp, Value buffer,
- Value iv) {}
+ static void handleOutOfBoundsDim(
+ PatternRewriter &rewriter, TransferWriteOp xferOp, Value buffer,
+ Value iv) {}
/// Cleanup after rewriting the op.
static void cleanup(PatternRewriter &rewriter, TransferWriteOp xferOp) {
@@ -361,13 +371,11 @@ struct Strategy<TransferWriteOp> {
template <typename OpTy>
LogicalResult checkPrepareXferOp(OpTy xferOp) {
if (xferOp->hasAttr(kPassLabel))
- return failure();
+ return failure();
if (xferOp.getVectorType().getRank() <= kTargetRank)
- return failure();
+ return failure();
if (xferOp.mask())
- return failure();
- if (!xferOp.permutation_map().isMinorIdentity())
- return failure();
+ return failure();
return success();
}
@@ -392,7 +400,8 @@ LogicalResult checkPrepareXferOp(OpTy xferOp) {
/// memref.store %1, %0[] : memref<vector<5x4xf32>>
/// %vec = memref.load %0[] : memref<vector<5x4xf32>>
/// ```
-struct PrepareTransferReadConversion : public OpRewritePattern<TransferReadOp> {
+struct PrepareTransferReadConversion
+ : public OpRewritePattern<TransferReadOp> {
using OpRewritePattern<TransferReadOp>::OpRewritePattern;
LogicalResult matchAndRewrite(TransferReadOp xferOp,
@@ -481,7 +490,7 @@ struct TransferOpConversion : public OpRewritePattern<OpTy> {
LogicalResult matchAndRewrite(OpTy xferOp,
PatternRewriter &rewriter) const override {
if (!xferOp->hasAttr(kPassLabel))
- return failure();
+ return failure();
ScopedContext scope(rewriter, xferOp.getLoc());
// How the buffer can be found depends on OpTy.
@@ -491,20 +500,16 @@ struct TransferOpConversion : public OpRewritePattern<OpTy> {
auto casted = vector_type_cast(castedType, buffer);
auto lb = std_constant_index(0).value;
- auto ub =
- std_constant_index(castedType.getDimSize(castedType.getRank() - 1))
- .value;
+ auto ub = std_constant_index(
+ castedType.getDimSize(castedType.getRank() - 1)).value;
affineLoopBuilder(lb, ub, 1, [&](Value iv) {
generateInBoundsCheck(
xferOp, iv, rewriter, unpackedDim(xferOp),
- /*inBoundsCase=*/
- [&](OpBuilder & /*b*/, Location loc) {
- Strategy<OpTy>::rewriteOp(rewriter, xferOp, casted, iv);
- },
- /*outOfBoundsCase=*/
- [&](OpBuilder & /*b*/, Location loc) {
- Strategy<OpTy>::handleOutOfBoundsDim(rewriter, xferOp, casted, iv);
- });
+ /*inBoundsCase=*/[&](OpBuilder& /*b*/, Location loc) {
+ Strategy<OpTy>::rewriteOp(rewriter, xferOp, casted, iv);
+ }, /*outOfBoundsCase=*/[&](OpBuilder& /*b*/, Location loc) {
+ Strategy<OpTy>::handleOutOfBoundsDim(rewriter, xferOp, casted, iv);
+ });
});
Strategy<OpTy>::cleanup(rewriter, xferOp);
@@ -541,25 +546,25 @@ struct Strategy1d;
/// Codegen strategy for TransferReadOp.
template <>
struct Strategy1d<TransferReadOp> {
- static void generateForLoopBody(OpBuilder &builder, Location loc,
- TransferReadOp xferOp, Value iv,
- ValueRange loopState) {
+ static void generateForLoopBody(
+ OpBuilder &builder, Location loc, TransferReadOp xferOp, Value iv,
+ ValueRange loopState) {
SmallVector<Value, 8> indices;
auto dim = get1dMemrefIndices(xferOp, iv, indices);
- auto ivI32 = std_index_cast(IntegerType::get(builder.getContext(), 32), iv);
+ auto ivI32 = std_index_cast(
+ IntegerType::get(builder.getContext(), 32), iv);
auto vec = loopState[0];
// In case of out-of-bounds access, leave `vec` as is (was initialized with
// padding value).
auto nextVec = generateInBoundsCheck(
xferOp, iv, builder, dim, TypeRange(xferOp.getVectorType()),
- /*inBoundsCase=*/
- [&](OpBuilder & /*b*/, Location loc) {
- auto val = memref_load(xferOp.source(), indices);
- return vector_insert_element(val, vec, ivI32.value).value;
- },
- /*outOfBoundsCase=*/
- [&](OpBuilder & /*b*/, Location loc) { return vec; });
+ /*inBoundsCase=*/[&](OpBuilder& /*b*/, Location loc) {
+ auto val = memref_load(xferOp.source(), indices);
+ return vector_insert_element(val, vec, ivI32.value).value;
+ }, /*outOfBoundsCase=*/[&](OpBuilder& /*b*/, Location loc) {
+ return vec;
+ });
builder.create<scf::YieldOp>(loc, nextVec);
}
@@ -572,24 +577,27 @@ struct Strategy1d<TransferReadOp> {
/// Codegen strategy for TransferWriteOp.
template <>
struct Strategy1d<TransferWriteOp> {
- static void generateForLoopBody(OpBuilder &builder, Location loc,
- TransferWriteOp xferOp, Value iv,
- ValueRange /*loopState*/) {
+ static void generateForLoopBody(
+ OpBuilder &builder, Location loc, TransferWriteOp xferOp, Value iv,
+ ValueRange /*loopState*/) {
SmallVector<Value, 8> indices;
auto dim = get1dMemrefIndices(xferOp, iv, indices);
- auto ivI32 = std_index_cast(IntegerType::get(builder.getContext(), 32), iv);
+ auto ivI32 = std_index_cast(
+ IntegerType::get(builder.getContext(), 32), iv);
// Nothing to do in case of out-of-bounds access.
generateInBoundsCheck(
xferOp, iv, builder, dim,
- /*inBoundsCase=*/[&](OpBuilder & /*b*/, Location loc) {
- auto val = vector_extract_element(xferOp.vector(), ivI32.value);
- memref_store(val, xferOp.source(), indices);
- });
+ /*inBoundsCase=*/[&](OpBuilder& /*b*/, Location loc) {
+ auto val = vector_extract_element(xferOp.vector(), ivI32.value);
+ memref_store(val, xferOp.source(), indices);
+ });
builder.create<scf::YieldOp>(loc);
}
- static Value initialLoopState(TransferWriteOp xferOp) { return Value(); }
+ static Value initialLoopState(TransferWriteOp xferOp) {
+ return Value();
+ }
};
/// Lower a 1D vector transfer op that operates on a dimension different from
@@ -623,11 +631,11 @@ struct Strided1dTransferOpConversion : public OpRewritePattern<OpTy> {
auto map = xferOp.permutation_map();
if (xferOp.getVectorType().getRank() != 1)
- return failure();
- if (map.isMinorIdentity()) // Handled by ConvertVectorToLLVM
- return failure();
+ return failure();
+ if (map.isMinorIdentity()) // Handled by ConvertVectorToLLVM
+ return failure();
if (xferOp.mask())
- return failure();
+ return failure();
// Loop bounds, step, state...
auto vecType = xferOp.getVectorType();
@@ -640,16 +648,16 @@ struct Strided1dTransferOpConversion : public OpRewritePattern<OpTy> {
rewriter.replaceOpWithNewOp<scf::ForOp>(
xferOp, lb, ub, step, loopState ? ValueRange(loopState) : ValueRange(),
[&](OpBuilder &builder, Location loc, Value iv, ValueRange loopState) {
- ScopedContext nestedScope(builder, loc);
- Strategy1d<OpTy>::generateForLoopBody(builder, loc, xferOp, iv,
- loopState);
- });
+ ScopedContext nestedScope(builder, loc);
+ Strategy1d<OpTy>::generateForLoopBody(
+ builder, loc, xferOp, iv, loopState);
+ });
return success();
}
};
-} // namespace
+} // namespace
namespace mlir {
@@ -657,10 +665,13 @@ void populateProgressiveVectorToSCFConversionPatterns(
RewritePatternSet &patterns) {
patterns.add<PrepareTransferReadConversion, PrepareTransferWriteConversion,
TransferOpConversion<TransferReadOp>,
- TransferOpConversion<TransferWriteOp>,
- Strided1dTransferOpConversion<TransferReadOp>,
- Strided1dTransferOpConversion<TransferWriteOp>>(
- patterns.getContext());
+ TransferOpConversion<TransferWriteOp>>(patterns.getContext());
+
+ if (kTargetRank == 1) {
+ patterns.add<Strided1dTransferOpConversion<TransferReadOp>,
+ Strided1dTransferOpConversion<TransferWriteOp>>(
+ patterns.getContext());
+ }
}
struct ConvertProgressiveVectorToSCFPass
@@ -672,8 +683,9 @@ struct ConvertProgressiveVectorToSCFPass
}
};
-} // namespace mlir
+} // namespace mlir
-std::unique_ptr<Pass> mlir::createProgressiveConvertVectorToSCFPass() {
+std::unique_ptr<Pass>
+mlir::createProgressiveConvertVectorToSCFPass() {
return std::make_unique<ConvertProgressiveVectorToSCFPass>();
}
diff --git a/mlir/test/Integration/Dialect/Vector/CPU/test-transfer-read-2d.mlir b/mlir/test/Integration/Dialect/Vector/CPU/test-transfer-read-2d.mlir
index a1e81a3cc6cd..fbcf94a6233c 100644
--- a/mlir/test/Integration/Dialect/Vector/CPU/test-transfer-read-2d.mlir
+++ b/mlir/test/Integration/Dialect/Vector/CPU/test-transfer-read-2d.mlir
@@ -17,6 +17,16 @@ func @transfer_read_2d(%A : memref<?x?xf32>, %base1: index, %base2: index) {
return
}
+func @transfer_read_2d_transposed(
+ %A : memref<?x?xf32>, %base1: index, %base2: index) {
+ %fm42 = constant -42.0: f32
+ %f = vector.transfer_read %A[%base1, %base2], %fm42
+ {permutation_map = affine_map<(d0, d1) -> (d1, d0)>} :
+ memref<?x?xf32>, vector<4x9xf32>
+ vector.print %f: vector<4x9xf32>
+ return
+}
+
func @transfer_write_2d(%A : memref<?x?xf32>, %base1: index, %base2: index) {
%fn1 = constant -1.0 : f32
%vf0 = splat %fn1 : vector<1x4xf32>
@@ -53,12 +63,20 @@ func @entry() {
// On input, memory contains [[ 0, 1, 2, ...], [10, 11, 12, ...], ...]
// Read shifted by 2 and pad with -42:
call @transfer_read_2d(%A, %c1, %c2) : (memref<?x?xf32>, index, index) -> ()
+ // Same as above, but transposed
+ call @transfer_read_2d_transposed(%A, %c1, %c2)
+ : (memref<?x?xf32>, index, index) -> ()
// Write into memory shifted by 3
call @transfer_write_2d(%A, %c3, %c1) : (memref<?x?xf32>, index, index) -> ()
// Read shifted by 0 and pad with -42:
call @transfer_read_2d(%A, %c0, %c0) : (memref<?x?xf32>, index, index) -> ()
+ // Same as above, but transposed
+ call @transfer_read_2d_transposed(%A, %c0, %c0)
+ : (memref<?x?xf32>, index, index) -> ()
return
}
// CHECK: ( ( 12, 13, -42, -42, -42, -42, -42, -42, -42 ), ( 22, 23, -42, -42, -42, -42, -42, -42, -42 ), ( -42, -42, -42, -42, -42, -42, -42, -42, -42 ), ( -42, -42, -42, -42, -42, -42, -42, -42, -42 ) )
+// CHECK: ( ( 12, 22, -42, -42, -42, -42, -42, -42, -42 ), ( 13, 23, -42, -42, -42, -42, -42, -42, -42 ), ( -42, -42, -42, -42, -42, -42, -42, -42, -42 ), ( -42, -42, -42, -42, -42, -42, -42, -42, -42 ) )
// CHECK: ( ( 0, 1, 2, 3, -42, -42, -42, -42, -42 ), ( 10, 11, 12, 13, -42, -42, -42, -42, -42 ), ( 20, 21, 22, 23, -42, -42, -42, -42, -42 ), ( -42, -42, -42, -42, -42, -42, -42, -42, -42 ) )
+// CHECK: ( ( 0, 10, 20, -42, -42, -42, -42, -42, -42 ), ( 1, 11, 21, -42, -42, -42, -42, -42, -42 ), ( 2, 12, 22, -42, -42, -42, -42, -42, -42 ), ( 3, 13, 23, -42, -42, -42, -42, -42, -42 ) )
diff --git a/mlir/test/Integration/Dialect/Vector/CPU/test-transfer-read-3d.mlir b/mlir/test/Integration/Dialect/Vector/CPU/test-transfer-read-3d.mlir
index ef6857337775..7ecac4a38938 100644
--- a/mlir/test/Integration/Dialect/Vector/CPU/test-transfer-read-3d.mlir
+++ b/mlir/test/Integration/Dialect/Vector/CPU/test-transfer-read-3d.mlir
@@ -19,6 +19,16 @@ func @transfer_read_3d(%A : memref<?x?x?x?xf32>,
return
}
+func @transfer_read_3d_transposed(%A : memref<?x?x?x?xf32>,
+ %o: index, %a: index, %b: index, %c: index) {
+ %fm42 = constant -42.0: f32
+ %f = vector.transfer_read %A[%o, %a, %b, %c], %fm42
+ {permutation_map = affine_map<(d0, d1, d2, d3) -> (d3, d0, d1)>}
+ : memref<?x?x?x?xf32>, vector<3x5x3xf32>
+ vector.print %f: vector<3x5x3xf32>
+ return
+}
+
func @transfer_write_3d(%A : memref<?x?x?x?xf32>,
%o: index, %a: index, %b: index, %c: index) {
%fn1 = constant -1.0 : f32
@@ -66,8 +76,11 @@ func @entry() {
: (memref<?x?x?x?xf32>, index, index, index, index) -> ()
call @transfer_read_3d(%A, %c0, %c0, %c0, %c0)
: (memref<?x?x?x?xf32>, index, index, index, index) -> ()
+ call @transfer_read_3d_transposed(%A, %c0, %c0, %c0, %c0)
+ : (memref<?x?x?x?xf32>, index, index, index, index) -> ()
return
}
// CHECK: ( ( ( 0, 0, -42 ), ( 2, 3, -42 ), ( 4, 6, -42 ), ( 6, 9, -42 ), ( -42, -42, -42 ) ), ( ( 20, 30, -42 ), ( 22, 33, -42 ), ( 24, 36, -42 ), ( 26, 39, -42 ), ( -42, -42, -42 ) ) )
// CHECK: ( ( ( 0, 0, -42 ), ( 2, -1, -42 ), ( 4, -1, -42 ), ( 6, -1, -42 ), ( -42, -42, -42 ) ), ( ( 20, 30, -42 ), ( 22, -1, -42 ), ( 24, -1, -42 ), ( 26, -1, -42 ), ( -42, -42, -42 ) ) )
+// CHECK: ( ( ( 0, 20, 40 ), ( 0, 20, 40 ), ( 0, 20, 40 ), ( 0, 20, 40 ), ( 0, 20, 40 ) ), ( ( 0, 30, 60 ), ( 0, 30, 60 ), ( 0, 30, 60 ), ( 0, 30, 60 ), ( 0, 30, 60 ) ), ( ( -42, -42, -42 ), ( -42, -42, -42 ), ( -42, -42, -42 ), ( -42, -42, -42 ), ( -42, -42, -42 ) ) )
diff --git a/mlir/test/Integration/Dialect/Vector/CPU/test-transfer-to-loops.mlir b/mlir/test/Integration/Dialect/Vector/CPU/test-transfer-to-loops.mlir
index 19bf28b57c3f..9488534d3e93 100644
--- a/mlir/test/Integration/Dialect/Vector/CPU/test-transfer-to-loops.mlir
+++ b/mlir/test/Integration/Dialect/Vector/CPU/test-transfer-to-loops.mlir
@@ -3,6 +3,11 @@
// RUN: -shared-libs=%mlir_integration_test_dir/libmlir_runner_utils%shlibext,%mlir_integration_test_dir/libmlir_c_runner_utils%shlibext | \
// RUN: FileCheck %s
+// RUN: mlir-opt %s -test-progressive-convert-vector-to-scf -lower-affine -convert-scf-to-std -convert-vector-to-llvm -convert-std-to-llvm | \
+// RUN: mlir-cpu-runner -e main -entry-point-result=void \
+// RUN: -shared-libs=%mlir_integration_test_dir/libmlir_runner_utils%shlibext,%mlir_integration_test_dir/libmlir_c_runner_utils%shlibext | \
+// RUN: FileCheck %s
+
#map0 = affine_map<(d0, d1) -> (d1, d0)>
#map1 = affine_map<(d0, d1) -> (d1)>
More information about the Mlir-commits
mailing list