[Mlir-commits] [mlir] f9c8feb - [mlir] Added support for symbols inside linalg.generic and map concatenation
Alex Zinenko
llvmlistbot at llvm.org
Mon Jul 20 10:20:56 PDT 2020
Author: Jakub Lichman
Date: 2020-07-20T19:20:47+02:00
New Revision: f9c8febc522c2d26a44d4881f015e0e11e4f9167
URL: https://github.com/llvm/llvm-project/commit/f9c8febc522c2d26a44d4881f015e0e11e4f9167
DIFF: https://github.com/llvm/llvm-project/commit/f9c8febc522c2d26a44d4881f015e0e11e4f9167.diff
LOG: [mlir] Added support for symbols inside linalg.generic and map concatenation
This commit adds functionality needed for implementation of convolutions with
linalg.generic op. Since linalg.generic right now expects indexing maps to be
just permutations, offset indexing needed in convolutions is not possible.
Therefore in this commit we address the issue by adding support for symbols inside
indexing maps which enables more advanced indexing. The upcoming commit will
solve the problem of computing loop bounds from such maps.
Differential Revision: https://reviews.llvm.org/D83158
Added:
Modified:
mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
mlir/include/mlir/Dialect/Utils/StructuredOpsUtils.h
mlir/include/mlir/IR/AffineExpr.h
mlir/lib/Dialect/Linalg/EDSC/Builders.cpp
mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp
mlir/lib/Dialect/Linalg/Transforms/Loops.cpp
mlir/lib/Dialect/Linalg/Transforms/TensorsToBuffers.cpp
mlir/lib/IR/AffineExpr.cpp
mlir/lib/IR/AffineMap.cpp
mlir/test/Dialect/Linalg/invalid.mlir
mlir/test/Dialect/Linalg/loops.mlir
mlir/test/lib/Transforms/TestBufferPlacement.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
index 7d259fde05e7..81f911a37cea 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
@@ -485,7 +485,9 @@ class GenericOpBase<string mnemonic> : LinalgStructuredBase_Op<mnemonic,
AffineMapArrayAttr:$indexing_maps,
ArrayAttr:$iterator_types,
OptionalAttr<StrAttr>:$doc,
- OptionalAttr<StrAttr>:$library_call);
+ OptionalAttr<StrAttr>:$library_call,
+ Confined<OptionalAttr<I64Attr>,
+ [IntMinValue<0>]>:$symbol_source);
let results = (outs Variadic<AnyRankedTensor>:$output_tensors);
let regions = (region AnyRegion:$region);
let extraClassDeclaration = [{
@@ -493,7 +495,7 @@ class GenericOpBase<string mnemonic> : LinalgStructuredBase_Op<mnemonic,
return SmallVector<StringRef, 8>{
getArgsInAttrName(), getArgsOutAttrName(), getDocAttrName(),
getIndexingMapsAttrName(), getLibraryCallAttrName(),
- getIteratorTypesAttrName()
+ getIteratorTypesAttrName(), getSymbolSourceAttrName()
};
}
@@ -508,12 +510,18 @@ class GenericOpBase<string mnemonic> : LinalgStructuredBase_Op<mnemonic,
llvm::Optional<SmallVector<StringRef, 8>> referenceIterators() {
llvm_unreachable(
"No such thing as reference iterator types for a generic op.");
- }
+ }
llvm::Optional<SmallVector<AffineMap, 8>> referenceIndexingMaps() {
llvm_unreachable(
"No such thing as reference indexing maps for a generic op.");
- }
+ }
+
+ llvm::Optional<unsigned> getSymbolSource() {
+ auto ss = symbol_source();
+ return ss.hasValue() ?
+ llvm::Optional<unsigned>(ss.getValue().getLimitedValue()) : llvm::None;
+ }
}];
let printer = [{ return ::print(p, *this); }];
@@ -549,6 +557,10 @@ def GenericOp : GenericOpBase<"generic"> {
Each element of the list represents and iterator of one of the following
types:
parallel, reduction, window
+ - symbol_source: index of the operand whose dimensions will be propagated
+ as symbols to the indexing maps. When specified, the number of symbols
+ in each of the indexing maps has to match the rank of the specified
+ operand.
Example:
Defining a #matmul_trait attribute in MLIR can be done as follows:
@@ -630,6 +642,35 @@ def GenericOp : GenericOpBase<"generic"> {
escape naturally. Still, transformations and rewrites that take advantage of
tensor SSA values are expected to be useful and will be added in the near
future.
+
+ Example of 1D convolution with symbols:
+ ```mlir
+ #conv_1d_accesses = [
+ affine_map<(m, n)[dimN] -> (m + n - dimN floordiv 2)>, // in
+ affine_map<(m, n)[dimN] -> (n)>, // filter
+ affine_map<(m, n)[dimN] -> (m)> // out
+ ]
+
+ #conv_1d_trait = {
+ doc = "O(m) += I(m + n - size(n) floordiv 2) * K(n)",
+ indexing_maps = #conv_1d_accesses,
+ library_call = "linalg_conv_1d",
+ iterator_types = ["parallel", "parallel"],
+ symbol_source = 1
+ }
+
+ linalg.generic #conv_1d_trait %in, %filter, %out {
+ ^bb0(%a: f32, %b: f32, %c: f32) :
+ %d = mulf %a, %b : f32
+ %e = addf %c, %d : f32
+ linalg.yield %e : f32
+ } : memref<?xf32>,
+ memref<?xf32>,
+ memref<?xf32>
+ ```
+ where symbol `dimN` (s0) will be substituted with `dim %filter, %c0`, i.e.
+ the first and only dimension of the second operand, as specified by the
+ symbol_source attribute.
}];
let builders = [
diff --git a/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h b/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
index ab85cebee178..76d570a50572 100644
--- a/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
+++ b/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
@@ -101,11 +101,28 @@ template <typename ConcreteOp>
SmallVector<Value, 8> getViewSizes(OpBuilder &builder, ConcreteOp linalgOp) {
auto loc = linalgOp.getLoc();
SmallVector<Value, 8> res;
+ SmallVector<unsigned, 4> ranks;
for (auto v : linalgOp.getInputsAndOutputBuffers()) {
MemRefType t = v.getType().template cast<MemRefType>();
+ ranks.push_back(t.getRank());
for (unsigned i = 0; i < t.getRank(); ++i)
res.push_back(builder.create<DimOp>(loc, v, i));
}
+
+ auto attr = linalgOp.template getAttrOfType<IntegerAttr>("symbol_source");
+ if (attr) {
+ // Find the correct position for inserting values for symbols.
+ unsigned numSymb = ranks[attr.getInt()], symbolsPos = 0;
+ for (unsigned idx = 0; idx < attr.getInt(); idx++)
+ symbolsPos += ranks[idx];
+
+ // Append the values that map to symbols to the end of the value list.
+ // Since symbols are repeated inside the concatenated map, the sizes
+ // have to be repeated as well.
+ for (unsigned idx = 0, s = ranks.size(); idx < s; ++idx)
+ for (unsigned idx2 = 0; idx2 < numSymb; ++idx2)
+ res.push_back(res[symbolsPos + idx2]);
+ }
return res;
}
diff --git a/mlir/include/mlir/Dialect/Utils/StructuredOpsUtils.h b/mlir/include/mlir/Dialect/Utils/StructuredOpsUtils.h
index 168e877e5056..41d614251936 100644
--- a/mlir/include/mlir/Dialect/Utils/StructuredOpsUtils.h
+++ b/mlir/include/mlir/Dialect/Utils/StructuredOpsUtils.h
@@ -46,6 +46,10 @@ inline bool isColumnMajorMatmul(ArrayAttr indexingMaps) {
return indexingMaps == maps;
}
+/// Attribute name for the IntegerAttr which encodes the index of the operand
+/// whose dimensions will be propagated as symbols to the indexing maps.
+constexpr StringRef getSymbolSourceAttrName() { return "symbol_source"; }
+
/// Attribute name for the AffineArrayAttr which encodes the relationship
/// between a structured op iterators' and its operands.
constexpr StringRef getIndexingMapsAttrName() { return "indexing_maps"; }
diff --git a/mlir/include/mlir/IR/AffineExpr.h b/mlir/include/mlir/IR/AffineExpr.h
index 2302abd8554d..2df16ee2bfc9 100644
--- a/mlir/include/mlir/IR/AffineExpr.h
+++ b/mlir/include/mlir/IR/AffineExpr.h
@@ -118,6 +118,10 @@ class AffineExpr {
AffineExpr replaceDimsAndSymbols(ArrayRef<AffineExpr> dimReplacements,
ArrayRef<AffineExpr> symReplacements) const;
+ /// Replace symbols[0 .. numSymbols - 1] by
+ /// symbols[shift .. shift + numSymbols - 1].
+ AffineExpr shiftSymbols(unsigned numSymbols, unsigned shift) const;
+
AffineExpr operator+(int64_t v) const;
AffineExpr operator+(AffineExpr other) const;
AffineExpr operator-() const;
diff --git a/mlir/lib/Dialect/Linalg/EDSC/Builders.cpp b/mlir/lib/Dialect/Linalg/EDSC/Builders.cpp
index b9ec01d3ec79..44eb5e723075 100644
--- a/mlir/lib/Dialect/Linalg/EDSC/Builders.cpp
+++ b/mlir/lib/Dialect/Linalg/EDSC/Builders.cpp
@@ -69,7 +69,8 @@ Operation *mlir::edsc::makeGenericLinalgOp(
builder.getAffineMapArrayAttr(maps),
builder.getStrArrayAttr(iteratorStrTypes),
StringAttr() /*doc*/,
- StringAttr() /*library_call*/
+ StringAttr() /*library_call*/,
+ IntegerAttr() /*symbol_source*/
/* TODO: other attributes in op */
)
.getOperation();
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index 528e856fe5bb..32e0b12c22b5 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -79,7 +79,8 @@ void GenericOp::build(
builder.getI64IntegerAttr(argsOut),
builder.getAffineMapArrayAttr(indexingMaps),
builder.getStrArrayAttr(iteratorTypes),
- /*doc=*/nullptr, /*library_call=*/nullptr);
+ /*doc=*/nullptr, /*library_call=*/nullptr,
+ /*symbol_source=*/nullptr);
if (!bodyBuild)
return;
@@ -103,7 +104,8 @@ void IndexedGenericOp::build(
builder.getI64IntegerAttr(argsOut),
builder.getAffineMapArrayAttr(indexingMaps),
builder.getStrArrayAttr(iteratorTypes),
- /*doc=*/nullptr, /*library_call=*/nullptr);
+ /*doc=*/nullptr, /*library_call=*/nullptr,
+ /*symbol_source=*/nullptr);
if (!bodyBuild)
return;
@@ -257,6 +259,15 @@ static LogicalResult verifyGenericOp(GenericOpType op) {
if (failed(BlockArgsVerifier<GenericOpType>::verify(op, region.front())))
return failure();
+ auto attr = op.template getAttrOfType<IntegerAttr>("symbol_source");
+ int64_t targetRank = 0;
+ if (attr) {
+ unsigned index = attr.getInt();
+ if (index >= op.getNumOperands())
+ return op.emitOpError("symbol_source index out of range");
+ targetRank = op.getShapedType(index).getRank();
+ }
+
SmallVector<AffineMap, 4> indexingMaps;
indexingMaps.reserve(op.indexing_maps().size());
for (auto en : llvm::enumerate(op.indexing_maps())) {
@@ -266,9 +277,9 @@ static LogicalResult verifyGenericOp(GenericOpType op) {
auto view = (idx < nInputViews) ? op.getInputShapedType(idx)
: op.getOutputShapedType(idx - nInputViews);
- if (m.getNumSymbols() != 0)
- return op.emitOpError("expected indexing_map #")
- << idx << " to have no symbols";
+ if (m.getNumSymbols() != targetRank)
+ return op.emitOpError("expected the number of symbols in indexing_map #")
+ << idx << " to match target rank";
if (m.getNumDims() != nLoops)
return op.emitOpError("expected indexing_map #")
@@ -281,8 +292,8 @@ static LogicalResult verifyGenericOp(GenericOpType op) {
}
auto concatMap = concatAffineMaps(indexingMaps);
- auto aggregateMap = inversePermutation(concatMap);
- if (!aggregateMap)
+ // TODO: Bound inference for maps with symbols
+ if (!concatMap.getNumSymbols() && !inversePermutation(concatMap))
return op.emitOpError("expected the concatenation of maps in indexing_map "
"to be invertible");
diff --git a/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp b/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
index e08c43d48ba0..65fd197a6661 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
@@ -319,7 +319,8 @@ struct ReplaceUnitExtentTensors : public OpRewritePattern<GenericOp> {
genericOp.args_out(), rewriter.getAffineMapArrayAttr(newIndexingMaps),
genericOp.iterator_types(),
/*doc = */ nullptr,
- /*library_call = */ nullptr);
+ /*library_call = */ nullptr,
+ /*symbol_source = */ nullptr);
rewriter.inlineRegionBefore(genericOp.region(), replacementOp.region(),
replacementOp.region().begin());
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp
index d67126c21f3e..82dfa75fc1f4 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp
@@ -510,7 +510,8 @@ struct FuseGenericOpsOnTensors {
rewriter.getArrayAttr(fusedIndexMaps),
consumer.iterator_types(),
/*doc=*/nullptr,
- /*library_call=*/nullptr)
+ /*library_call=*/nullptr,
+ /*symbol_source=*/nullptr)
.getOperation();
} else {
fusedOp =
@@ -524,7 +525,8 @@ struct FuseGenericOpsOnTensors {
rewriter.getArrayAttr(fusedIndexMaps),
consumer.iterator_types(),
/*doc=*/nullptr,
- /*library_call=*/nullptr)
+ /*library_call=*/nullptr,
+ /*symbol_source=*/nullptr)
.getOperation();
}
@@ -787,7 +789,8 @@ template <typename LinalgOpTy> struct FuseTensorReshapeOpAsProducer {
rewriter.getI64IntegerAttr(consumer.getNumResults()),
rewriter.getArrayAttr(indexMapAttrs), consumer.iterator_types(),
/*doc=*/nullptr,
- /*library_call=*/nullptr);
+ /*library_call=*/nullptr,
+ /*symbol_source=*/nullptr);
auto &fusedRegion = fusedOp.region();
rewriter.cloneRegionBefore(consumer.region(), fusedRegion,
fusedRegion.begin());
@@ -843,7 +846,8 @@ template <typename LinalgOpTy> struct FuseTensorReshapeOpAsConsumer {
rewriter.getI64IntegerAttr(1), rewriter.getArrayAttr(indexMapAttrs),
producer.iterator_types(),
/*doc=*/nullptr,
- /*library_call=*/nullptr);
+ /*library_call=*/nullptr,
+ /*symbol_source=*/nullptr);
auto &fusedRegion = fusedOp.region();
rewriter.cloneRegionBefore(producer.region(), fusedRegion,
fusedRegion.begin());
@@ -893,7 +897,8 @@ template <typename LinalgOpTy> struct FuseConstantOpAsProducer {
rewriter.getAffineMapArrayAttr(fusedIndexMaps),
consumer.iterator_types(),
/*doc=*/nullptr,
- /*library_call=*/nullptr);
+ /*library_call=*/nullptr,
+ /*symbol_source=*/nullptr);
// Map the block argument corresponding to the replaced argument with the
// scalar constant.
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Loops.cpp b/mlir/lib/Dialect/Linalg/Transforms/Loops.cpp
index 6a1d00fe620c..b5643e997da5 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Loops.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Loops.cpp
@@ -36,13 +36,13 @@ static SmallVector<Value, 8> makeCanonicalAffineApplies(OpBuilder &b,
ArrayRef<Value> vals) {
if (map.isEmpty())
return {};
- assert(map.getNumSymbols() == 0);
+
assert(map.getNumInputs() == vals.size());
SmallVector<Value, 8> res;
res.reserve(map.getNumResults());
auto dims = map.getNumDims();
for (auto e : map.getResults()) {
- auto exprMap = AffineMap::get(dims, 0, e);
+ auto exprMap = AffineMap::get(dims, map.getNumSymbols(), e);
SmallVector<Value, 4> operands(vals.begin(), vals.end());
canonicalizeMapAndOperands(&exprMap, &operands);
res.push_back(affine_apply(exprMap, operands));
@@ -165,19 +165,29 @@ void emitScalarImplementation(ArrayRef<Value> allIvs,
SmallVector<Value, 4> indexedValues;
indexedValues.reserve(nInputs + nOutputs);
+ auto attr = linalgOp.template getAttrOfType<IntegerAttr>("symbol_source");
+ auto allIvsPlusDims = SmallVector<Value, 4>(allIvs.begin(), allIvs.end());
+ if (attr) {
+ auto operand = linalgOp.getOperand(attr.getInt());
+ auto shapedType = operand.getType().template cast<ShapedType>();
+ allIvsPlusDims.reserve(allIvs.size() + shapedType.getRank());
+ for (unsigned idx = 0, e = shapedType.getRank(); idx < e; ++idx)
+ allIvsPlusDims.push_back(b.create<DimOp>(loc, operand, idx));
+ }
+
// TODO: Avoid the loads if the corresponding argument of the
// region has no uses.
// 1.a. Emit load from input views.
for (unsigned i = 0; i < nInputs; ++i) {
auto indexing = makeCanonicalAffineApplies(
- b, loc, linalgOp.getInputIndexingMap(i), allIvs);
+ b, loc, linalgOp.getInputIndexingMap(i), allIvsPlusDims);
// Passing through IndexedValueType emits the proper load operation.
indexedValues.push_back(IndexedValueType(linalgOp.getInput(i))(indexing));
}
// 1.b. Emit load from output views.
for (unsigned i = 0; i < nOutputs; ++i) {
auto indexing = makeCanonicalAffineApplies(
- b, loc, linalgOp.getOutputIndexingMap(i), allIvs);
+ b, loc, linalgOp.getOutputIndexingMap(i), allIvsPlusDims);
// Passing through IndexedValueType emits the proper load operation.
indexedValues.push_back(
IndexedValueType(linalgOp.getOutputBuffer(i))(indexing));
@@ -190,7 +200,7 @@ void emitScalarImplementation(ArrayRef<Value> allIvs,
SmallVector<Value, 8> outputBuffers;
for (unsigned i = 0; i < nOutputs; ++i) {
indexing.push_back(makeCanonicalAffineApplies(
- b, loc, linalgOp.getOutputIndexingMap(i), allIvs));
+ b, loc, linalgOp.getOutputIndexingMap(i), allIvsPlusDims));
outputBuffers.push_back(linalgOp.getOutputBuffer(i));
}
inlineRegionAndEmitStore<IndexedValueType>(linalgOp, indexedValues, indexing,
@@ -457,7 +467,24 @@ Optional<LinalgLoops> linalgOpToLoopsImpl(Operation *op, OpBuilder &builder) {
linalgOp.indexing_maps().template getAsRange<AffineMapAttr>();
auto maps = llvm::to_vector<8>(
llvm::map_range(mapsRange, [](AffineMapAttr a) { return a.getValue(); }));
- AffineMap invertedMap = inversePermutation(concatAffineMaps(maps));
+ SmallVector<Value, 8> sizes = getViewSizes(builder, linalgOp);
+ AffineMap map = concatAffineMaps(maps);
+ if (map.getNumSymbols()) {
+ // Ignore symbols for now as they are not supported by inversePermutation.
+ unsigned dims = map.getNumDims();
+ SmallVector<AffineExpr, 8> zeros(
+ map.getNumSymbols(), getAffineConstantExpr(0, map.getContext()));
+ SmallVector<AffineExpr, 8> res;
+ for (auto result : map.getResults())
+ res.push_back(result.replaceDimsAndSymbols({}, zeros));
+
+ map = AffineMap::get(dims, 0, res, map.getContext());
+
+ // Cut off values that would have been applied to symbols
+ sizes.resize(res.size());
+ }
+
+ AffineMap invertedMap = inversePermutation(map);
if (!invertedMap)
return {};
if (invertedMap.isEmpty()) {
@@ -466,9 +493,8 @@ Optional<LinalgLoops> linalgOpToLoopsImpl(Operation *op, OpBuilder &builder) {
}
SmallVector<Value, 4> allIvs;
- auto loopRanges =
- emitLoopRanges(scope.getBuilderRef(), scope.getLocation(), invertedMap,
- getViewSizes(builder, linalgOp));
+ auto loopRanges = emitLoopRanges(scope.getBuilderRef(), scope.getLocation(),
+ invertedMap, sizes);
GenerateLoopNest<LoopTy>::doit(
loopRanges, linalgOp.iterator_types().getValue(), [&](ValueRange ivs) {
allIvs.append(ivs.begin(), ivs.end());
diff --git a/mlir/lib/Dialect/Linalg/Transforms/TensorsToBuffers.cpp b/mlir/lib/Dialect/Linalg/Transforms/TensorsToBuffers.cpp
index afd94cc06c6e..04c1fbd5d565 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/TensorsToBuffers.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/TensorsToBuffers.cpp
@@ -65,7 +65,8 @@ class GenericOpConverter
auto linalgOp = rewriter.create<linalg::GenericOp>(
loc, llvm::None, newArgs, rewriter.getI64IntegerAttr(operands.size()),
rewriter.getI64IntegerAttr(results.size()), op.indexing_maps(),
- op.iterator_types(), op.docAttr(), op.library_callAttr());
+ op.iterator_types(), op.docAttr(), op.library_callAttr(),
+ op.symbol_sourceAttr());
// Create a new block in the region of the new Generic Op.
Block &oldBlock = op.getRegion().front();
diff --git a/mlir/lib/IR/AffineExpr.cpp b/mlir/lib/IR/AffineExpr.cpp
index e0c4b6b208f7..5ba9737a5245 100644
--- a/mlir/lib/IR/AffineExpr.cpp
+++ b/mlir/lib/IR/AffineExpr.cpp
@@ -93,6 +93,14 @@ AffineExpr::replaceDimsAndSymbols(ArrayRef<AffineExpr> dimReplacements,
llvm_unreachable("Unknown AffineExpr");
}
+/// Replace symbols[0 .. numSymbols - 1] by symbols[shift .. shift + numSymbols - 1].
+AffineExpr AffineExpr::shiftSymbols(unsigned numSymbols, unsigned shift) const {
+ SmallVector<AffineExpr, 4> symbols;
+ for (unsigned idx = 0; idx < numSymbols; ++idx)
+ symbols.push_back(getAffineSymbolExpr(idx + shift, getContext()));
+ return replaceDimsAndSymbols({}, symbols);
+}
+
/// Returns true if this expression is made out of only symbols and
/// constants (no dimensional identifiers).
bool AffineExpr::isSymbolicOrConstant() const {
diff --git a/mlir/lib/IR/AffineMap.cpp b/mlir/lib/IR/AffineMap.cpp
index 42dbef273b6d..ba76976a17c1 100644
--- a/mlir/lib/IR/AffineMap.cpp
+++ b/mlir/lib/IR/AffineMap.cpp
@@ -434,18 +434,19 @@ AffineMap mlir::inversePermutation(AffineMap map) {
}
AffineMap mlir::concatAffineMaps(ArrayRef<AffineMap> maps) {
- unsigned numResults = 0;
+ unsigned numResults = 0, numDims = 0, numSymbols = 0;
for (auto m : maps)
numResults += m.getNumResults();
- unsigned numDims = 0;
SmallVector<AffineExpr, 8> results;
results.reserve(numResults);
for (auto m : maps) {
- assert(m.getNumSymbols() == 0 && "expected map without symbols");
- results.append(m.getResults().begin(), m.getResults().end());
+ for (auto res : m.getResults())
+ results.push_back(res.shiftSymbols(m.getNumSymbols(), numSymbols));
+
+ numSymbols += m.getNumSymbols();
numDims = std::max(m.getNumDims(), numDims);
}
- return AffineMap::get(numDims, /*numSymbols=*/0, results,
+ return AffineMap::get(numDims, numSymbols, results,
maps.front().getContext());
}
diff --git a/mlir/test/Dialect/Linalg/invalid.mlir b/mlir/test/Dialect/Linalg/invalid.mlir
index 585dc36dcaa8..99b942389e41 100644
--- a/mlir/test/Dialect/Linalg/invalid.mlir
+++ b/mlir/test/Dialect/Linalg/invalid.mlir
@@ -106,7 +106,7 @@ func @generic_mismatched_num_returns(%arg0: memref<f32>) {
// -----
func @generic_symbol_in_map(%arg0: memref<i32>) {
- // expected-error @+1 {{op expected indexing_map #0 to have no symbols}}
+ // expected-error @+1 {{expected the number of symbols in indexing_map #0 to match target rank}}
linalg.generic {
args_in = 0,
args_out = 1,
@@ -120,6 +120,22 @@ func @generic_symbol_in_map(%arg0: memref<i32>) {
// -----
+func @generic_symbol_source_out_of_range(%arg0: memref<i32>) {
+ // expected-error @+1 {{symbol_source index out of range}}
+ linalg.generic {
+ args_in = 0,
+ args_out = 1,
+ indexing_maps = [ affine_map<()[N] -> (0)> ],
+ iterator_types = ["parallel"],
+ symbol_source = 1
+ } %arg0 {
+ ^bb(%i : i32):
+ linalg.yield %i : i32
+ }: memref<i32>
+}
+
+// -----
+
func @generic_wrong_dim_in_map(%arg0: memref<1xi32>) {
// expected-error @+1 {{op expected indexing_map #0 to have 1 dim(s) to match the number of loops}}
linalg.generic {
diff --git a/mlir/test/Dialect/Linalg/loops.mlir b/mlir/test/Dialect/Linalg/loops.mlir
index f03129c4d8be..b3f6160b17ed 100644
--- a/mlir/test/Dialect/Linalg/loops.mlir
+++ b/mlir/test/Dialect/Linalg/loops.mlir
@@ -14,6 +14,7 @@
// CHECKLOOP-DAG: #[[$stride2Dilation1:.*]] = affine_map<(d0, d1) -> (d0 * 2 + d1)>
// CHECKLOOP-DAG: #[[$stride2Dilation4:.*]] = affine_map<(d0, d1) -> (d0 * 2 + d1 * 4)>
// CHECKLOOP-DAG: #[[$stride3Dilation5:.*]] = affine_map<(d0, d1) -> (d0 * 3 + d1 * 5)>
+// CHECKLOOP-DAG: #[[$convMap:.*]] = affine_map<(d0, d1)[s0] -> (d0 + d1 - s0 floordiv 2)>
// CHECKPARALLEL-DAG: #[[$strided1D:.*]] = affine_map<(d0)[s0] -> (d0 + s0)>
// CHECKPARALLEL-DAG: #[[$strided2D:.*]] = affine_map<(d0, d1)[s0, s1] -> (d0 * s1 + s0 + d1)>
@@ -25,6 +26,7 @@
// CHECKPARALLEL-DAG: #[[$stride2Dilation1:.*]] = affine_map<(d0, d1) -> (d0 * 2 + d1)>
// CHECKPARALLEL-DAG: #[[$stride2Dilation4:.*]] = affine_map<(d0, d1) -> (d0 * 2 + d1 * 4)>
// CHECKPARALLEL-DAG: #[[$stride3Dilation5:.*]] = affine_map<(d0, d1) -> (d0 * 3 + d1 * 5)>
+// CHECKPARALLEL-DAG: #[[$convMap:.*]] = affine_map<(d0, d1)[s0] -> (d0 + d1 - s0 floordiv 2)>
func @matmul(%arg0: memref<?xi8>, %M: index, %N: index, %K: index) {
@@ -910,3 +912,331 @@ func @named_batch_matmul(%A: memref<?x?x?xf32>, %B: memref<?x?x?xf32>, %C: memre
// CHECKPARALLEL: %[[inc:.*]] = mulf %[[va]], %[[vb]] : f32
// CHECKPARALLEL: %[[res:.*]] = addf %[[vc]], %[[inc]] : f32
// CHECKPARALLEL: store %[[res]], %[[mC]][%[[b]], %[[m]], %[[n]]] : memref<?x?x?xf32>
+
+#conv_1d_accesses = [
+ affine_map<(m, n)[s0] -> (m + n - s0 floordiv 2)>, // in
+ affine_map<(m, n)[s0] -> (n)>, // filter
+ affine_map<(m, n)[s0] -> (m)> // out
+]
+
+#conv_1d_trait = {
+ args_in = 2,
+ args_out = 1,
+ doc = "C(m) += A(m) * B(n)",
+ indexing_maps = #conv_1d_accesses,
+ library_call = "linalg_conv_1d",
+ n_views = [2, 1],
+ iterator_types = ["parallel", "parallel"],
+ symbol_source = 1
+}
+
+func @conv1d(%in : memref<?xf32>, %filter : memref<?xf32>, %out : memref<?xf32>) -> () {
+ linalg.generic #conv_1d_trait %in, %filter, %out {
+ ^bb0(%a: f32, %b: f32, %c: f32) :
+ %d = mulf %a, %b : f32
+ %e = addf %c, %d : f32
+ linalg.yield %e : f32
+ } : memref<?xf32>,
+ memref<?xf32>,
+ memref<?xf32>
+ return
+}
+
+// CHECKLOOP-LABEL: @conv1d
+// CHECKLOOP-SAME: %[[arg0:[a-zA-Z0-9]+]]: memref<?xf32>
+// CHECKLOOP-SAME: %[[arg1:[a-zA-Z0-9]+]]: memref<?xf32>
+// CHECKLOOP-SAME: %[[arg2:[a-zA-Z0-9]+]]: memref<?xf32>
+// CHECKLOOP: %[[c0:.*]] = constant 0 : index
+// CHECKLOOP: %[[dim0:.*]] = dim %[[arg1]], %[[c0]] : memref<?xf32>
+// CHECKLOOP: %[[dim1:.*]] = dim %[[arg2]], %[[c0]] : memref<?xf32>
+// CHECKLOOP: scf.for %[[b:.*]] = %{{.*}} to %[[dim1]] step %{{.*}} {
+// CHECKLOOP: scf.for %[[m:.*]] = %{{.*}} to %[[dim0]] step %{{.*}} {
+// CHECKLOOP: %[[dim2:.*]] = dim %[[arg1]], %[[c0]] : memref<?xf32>
+// CHECKLOOP: %[[aff:.*]] = affine.apply #[[$convMap]](%{{.*}}, %{{.*}})[%[[dim2]]]
+// CHECKLOOP: %[[va:.*]] = load %[[arg0]][%[[aff]]] : memref<?xf32>
+// CHECKLOOP: %[[vb:.*]] = load %[[arg1]][%[[m]]] : memref<?xf32>
+// CHECKLOOP: %[[vc:.*]] = load %[[arg2]][%[[b]]] : memref<?xf32>
+// CHECKLOOP: %[[inc:.*]] = mulf %[[va]], %[[vb]] : f32
+// CHECKLOOP: %[[res:.*]] = addf %[[vc]], %[[inc]] : f32
+// CHECKLOOP: store %[[res]], %[[arg2]][%[[b]]] : memref<?xf32>
+
+// CHECKPARALLEL-LABEL: @conv1d
+// CHECKPARALLEL-SAME: %[[arg0:[a-zA-Z0-9]+]]: memref<?xf32>
+// CHECKPARALLEL-SAME: %[[arg1:[a-zA-Z0-9]+]]: memref<?xf32>
+// CHECKPARALLEL-SAME: %[[arg2:[a-zA-Z0-9]+]]: memref<?xf32>
+// CHECKPARALLEL: %[[c0:.*]] = constant 0 : index
+// CHECKPARALLEL: %[[dim0:.*]] = dim %[[arg1]], %[[c0]] : memref<?xf32>
+// CHECKPARALLEL: %[[dim1:.*]] = dim %[[arg2]], %[[c0]] : memref<?xf32>
+// CHECKPARALLEL: scf.parallel (%[[b:.*]], %[[m:.*]]) = (%{{.*}}, %{{.*}}) to (%[[dim1]], %[[dim0]]) step ({{.*}}) {
+// CHECKPARALLEL: %[[dim2:.*]] = dim %[[arg1]], %[[c0]] : memref<?xf32>
+// CHECKPARALLEL: %[[aff:.*]] = affine.apply #[[$convMap]](%{{.*}}, %{{.*}})[%[[dim2]]]
+// CHECKPARALLEL: %[[va:.*]] = load %[[arg0]][%[[aff]]] : memref<?xf32>
+// CHECKPARALLEL: %[[vb:.*]] = load %[[arg1]][%[[m]]] : memref<?xf32>
+// CHECKPARALLEL: %[[vc:.*]] = load %[[arg2]][%[[b]]] : memref<?xf32>
+// CHECKPARALLEL: %[[inc:.*]] = mulf %[[va]], %[[vb]] : f32
+// CHECKPARALLEL: %[[res:.*]] = addf %[[vc]], %[[inc]] : f32
+// CHECKPARALLEL: store %[[res]], %[[arg2]][%[[b]]] : memref<?xf32>
+
+#conv_2d_accesses = [
+ affine_map<(m, n, m1, n1)[s0, s1] -> (m + m1 - s0 floordiv 2, n + n1 - s1 floordiv 2)>, // in
+ affine_map<(m, n, m1, n1)[s0, s1] -> (m1, n1)>, // filter
+ affine_map<(m, n, m1, n1)[s0, s1] -> (m, n)> // out
+]
+
+#conv_2d_trait = {
+ args_in = 2,
+ args_out = 1,
+ doc = "C(m,n) += A(m,n) * B(m1,n1)",
+ indexing_maps = #conv_2d_accesses,
+ library_call = "linalg_conv_2d",
+ n_views = [2, 1],
+ iterator_types = ["parallel", "parallel", "parallel", "parallel"],
+ symbol_source = 1
+}
+
+func @conv2d(%in : memref<?x?xf32>, %filter : memref<?x?xf32>, %out : memref<?x?xf32>) -> () {
+ linalg.generic #conv_2d_trait %in, %filter, %out {
+ ^bb0(%a: f32, %b: f32, %c: f32) :
+ %d = mulf %a, %b : f32
+ %e = addf %c, %d : f32
+ linalg.yield %e : f32
+ } : memref<?x?xf32>,
+ memref<?x?xf32>,
+ memref<?x?xf32>
+ return
+}
+
+// CHECKLOOP-LABEL: @conv2d
+// CHECKLOOP-SAME: %[[arg0:[a-zA-Z0-9]+]]: memref<?x?xf32>
+// CHECKLOOP-SAME: %[[arg1:[a-zA-Z0-9]+]]: memref<?x?xf32>
+// CHECKLOOP-SAME: %[[arg2:[a-zA-Z0-9]+]]: memref<?x?xf32>
+// CHECKLOOP: %[[c0:.*]] = constant 0 : index
+// CHECKLOOP: %[[c1:.*]] = constant 1 : index
+// CHECKLOOP: %[[dim0:.*]] = dim %[[arg1]], %[[c0]] : memref<?x?xf32>
+// CHECKLOOP: %[[dim1:.*]] = dim %[[arg1]], %[[c1]] : memref<?x?xf32>
+// CHECKLOOP: %[[dim2:.*]] = dim %[[arg2]], %[[c0]] : memref<?x?xf32>
+// CHECKLOOP: %[[dim3:.*]] = dim %[[arg2]], %[[c1]] : memref<?x?xf32>
+// CHECKLOOP: scf.for %[[i0:.*]] = %{{.*}} to %[[dim2]] step %{{.*}} {
+// CHECKLOOP: scf.for %[[i1:.*]] = %{{.*}} to %[[dim3]] step %{{.*}} {
+// CHECKLOOP: scf.for %[[i2:.*]] = %{{.*}} to %[[dim0]] step %{{.*}} {
+// CHECKLOOP: scf.for %[[i3:.*]] = %{{.*}} to %[[dim1]] step %{{.*}} {
+// CHECKLOOP: %[[dim4:.*]] = dim %[[arg1]], %[[c0]] : memref<?x?xf32>
+// CHECKLOOP: %[[dim5:.*]] = dim %[[arg1]], %[[c1]] : memref<?x?xf32>
+// CHECKLOOP: %[[aff1:.*]] = affine.apply #[[$convMap]](%{{.*}}, %{{.*}})[%[[dim4]]]
+// CHECKLOOP: %[[aff2:.*]] = affine.apply #[[$convMap]](%{{.*}}, %{{.*}})[%[[dim5]]]
+// CHECKLOOP: %[[va:.*]] = load %[[arg0]][%[[aff1]], %[[aff2]]] : memref<?x?xf32>
+// CHECKLOOP: %[[vb:.*]] = load %[[arg1]][%[[i2]], %[[i3]]] : memref<?x?xf32>
+// CHECKLOOP: %[[vc:.*]] = load %[[arg2]][%[[i0]], %[[i1]]] : memref<?x?xf32>
+// CHECKLOOP: %[[inc:.*]] = mulf %[[va]], %[[vb]] : f32
+// CHECKLOOP: %[[res:.*]] = addf %[[vc]], %[[inc]] : f32
+// CHECKLOOP: store %[[res]], %[[arg2]][%[[i0]], %[[i1]]] : memref<?x?xf32>
+
+// CHECKPARALLEL-LABEL: @conv2d
+// CHECKPARALLEL-SAME: %[[arg0:[a-zA-Z0-9]+]]: memref<?x?xf32>
+// CHECKPARALLEL-SAME: %[[arg1:[a-zA-Z0-9]+]]: memref<?x?xf32>
+// CHECKPARALLEL-SAME: %[[arg2:[a-zA-Z0-9]+]]: memref<?x?xf32>
+// CHECKPARALLEL: %[[c0:.*]] = constant 0 : index
+// CHECKPARALLEL: %[[c1:.*]] = constant 1 : index
+// CHECKPARALLEL: %[[dim0:.*]] = dim %[[arg1]], %[[c0]] : memref<?x?xf32>
+// CHECKPARALLEL: %[[dim1:.*]] = dim %[[arg1]], %[[c1]] : memref<?x?xf32>
+// CHECKPARALLEL: %[[dim2:.*]] = dim %[[arg2]], %[[c0]] : memref<?x?xf32>
+// CHECKPARALLEL: %[[dim3:.*]] = dim %[[arg2]], %[[c1]] : memref<?x?xf32>
+// CHECKPARALLEL: scf.parallel (%[[i0:.*]], %[[i1:.*]], %[[i2:.*]], %[[i3:.*]]) = (%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}) to (%[[dim2]], %[[dim3]], %[[dim0]], %[[dim1]]) step ({{.*}}) {
+// CHECKPARALLEL: %[[dim4:.*]] = dim %[[arg1]], %[[c0]] : memref<?x?xf32>
+// CHECKPARALLEL: %[[dim5:.*]] = dim %[[arg1]], %[[c1]] : memref<?x?xf32>
+// CHECKPARALLEL: %[[aff1:.*]] = affine.apply #[[$convMap]](%{{.*}}, %{{.*}})[%[[dim4]]]
+// CHECKPARALLEL: %[[aff2:.*]] = affine.apply #[[$convMap]](%{{.*}}, %{{.*}})[%[[dim5]]]
+// CHECKPARALLEL: %[[va:.*]] = load %[[arg0]][%[[aff1]], %[[aff2]]] : memref<?x?xf32>
+// CHECKPARALLEL: %[[vb:.*]] = load %[[arg1]][%[[i2]], %[[i3]]] : memref<?x?xf32>
+// CHECKPARALLEL: %[[vc:.*]] = load %[[arg2]][%[[i0]], %[[i1]]] : memref<?x?xf32>
+// CHECKPARALLEL: %[[inc:.*]] = mulf %[[va]], %[[vb]] : f32
+// CHECKPARALLEL: %[[res:.*]] = addf %[[vc]], %[[inc]] : f32
+// CHECKPARALLEL: store %[[res]], %[[arg2]][%[[i0]], %[[i1]]] : memref<?x?xf32>
+
+#conv_3d_accesses = [
+ affine_map<(m, n, k, m1, n1, k1)[s0, s1, s2] -> (m + m1 - s0 floordiv 2, n + n1 - s1 floordiv 2, k + k1 - s2 floordiv 2)>, // in
+ affine_map<(m, n, k, m1, n1, k1)[s0, s1, s2] -> (m1, n1, k1)>, // filter
+ affine_map<(m, n, k, m1, n1, k1)[s0, s1, s2] -> (m, n, k)> // out
+]
+
+#conv_3d_trait = {
+ args_in = 2,
+ args_out = 1,
+ doc = "C(m,n,k) += A(m,n,k) * B(m1,n1,k1)",
+ indexing_maps = #conv_3d_accesses,
+ library_call = "linalg_conv_3d",
+ n_views = [2, 1],
+ iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel", "parallel"],
+ symbol_source = 1
+}
+
+func @conv3d(%in : memref<?x?x?xf32>, %filter : memref<?x?x?xf32>, %out : memref<?x?x?xf32>) -> () {
+ linalg.generic #conv_3d_trait %in, %filter, %out {
+ ^bb0(%a: f32, %b: f32, %c: f32) :
+ %d = mulf %a, %b : f32
+ %e = addf %c, %d : f32
+ linalg.yield %e : f32
+ } : memref<?x?x?xf32>,
+ memref<?x?x?xf32>,
+ memref<?x?x?xf32>
+ return
+}
+
+// CHECKLOOP-LABEL: @conv3d
+// CHECKLOOP-SAME: %[[arg0:[a-zA-Z0-9]+]]: memref<?x?x?xf32>
+// CHECKLOOP-SAME: %[[arg1:[a-zA-Z0-9]+]]: memref<?x?x?xf32>
+// CHECKLOOP-SAME: %[[arg2:[a-zA-Z0-9]+]]: memref<?x?x?xf32>
+// CHECKLOOP: %[[c0:.*]] = constant 0 : index
+// CHECKLOOP: %[[c1:.*]] = constant 1 : index
+// CHECKLOOP: %[[c2:.*]] = constant 2 : index
+// CHECKLOOP: %[[dim0:.*]] = dim %[[arg1]], %[[c0]] : memref<?x?x?xf32>
+// CHECKLOOP: %[[dim1:.*]] = dim %[[arg1]], %[[c1]] : memref<?x?x?xf32>
+// CHECKLOOP: %[[dim2:.*]] = dim %[[arg1]], %[[c2]] : memref<?x?x?xf32>
+// CHECKLOOP: %[[dim3:.*]] = dim %[[arg2]], %[[c0]] : memref<?x?x?xf32>
+// CHECKLOOP: %[[dim4:.*]] = dim %[[arg2]], %[[c1]] : memref<?x?x?xf32>
+// CHECKLOOP: %[[dim5:.*]] = dim %[[arg2]], %[[c2]] : memref<?x?x?xf32>
+// CHECKLOOP: scf.for %[[i0:.*]] = %{{.*}} to %[[dim3]] step %{{.*}} {
+// CHECKLOOP: scf.for %[[i1:.*]] = %{{.*}} to %[[dim4]] step %{{.*}} {
+// CHECKLOOP: scf.for %[[i2:.*]] = %{{.*}} to %[[dim5]] step %{{.*}} {
+// CHECKLOOP: scf.for %[[i3:.*]] = %{{.*}} to %[[dim0]] step %{{.*}} {
+// CHECKLOOP: scf.for %[[i4:.*]] = %{{.*}} to %[[dim1]] step %{{.*}} {
+// CHECKLOOP: scf.for %[[i5:.*]] = %{{.*}} to %[[dim2]] step %{{.*}} {
+// CHECKLOOP: %[[dim6:.*]] = dim %[[arg1]], %[[c0]] : memref<?x?x?xf32>
+// CHECKLOOP: %[[dim7:.*]] = dim %[[arg1]], %[[c1]] : memref<?x?x?xf32>
+// CHECKLOOP: %[[dim8:.*]] = dim %[[arg1]], %[[c2]] : memref<?x?x?xf32>
+// CHECKLOOP: %[[aff1:.*]] = affine.apply #[[$convMap]](%{{.*}}, %{{.*}})[%[[dim6]]]
+// CHECKLOOP: %[[aff2:.*]] = affine.apply #[[$convMap]](%{{.*}}, %{{.*}})[%[[dim7]]]
+// CHECKLOOP: %[[aff3:.*]] = affine.apply #[[$convMap]](%{{.*}}, %{{.*}})[%[[dim8]]]
+// CHECKLOOP: %[[va:.*]] = load %[[arg0]][%[[aff1]], %[[aff2]], %[[aff3]]] : memref<?x?x?xf32>
+// CHECKLOOP: %[[vb:.*]] = load %[[arg1]][%[[i3]], %[[i4]], %[[i5]]] : memref<?x?x?xf32>
+// CHECKLOOP: %[[vc:.*]] = load %[[arg2]][%[[i0]], %[[i1]], %[[i2]]] : memref<?x?x?xf32>
+// CHECKLOOP: %[[inc:.*]] = mulf %[[va]], %[[vb]] : f32
+// CHECKLOOP: %[[res:.*]] = addf %[[vc]], %[[inc]] : f32
+// CHECKLOOP: store %[[res]], %[[arg2]][%[[i0]], %[[i1]], %[[i2]]] : memref<?x?x?xf32>
+
+// CHECKPARALLEL-LABEL: @conv3d
+// CHECKPARALLEL-SAME: %[[arg0:[a-zA-Z0-9]+]]: memref<?x?x?xf32>
+// CHECKPARALLEL-SAME: %[[arg1:[a-zA-Z0-9]+]]: memref<?x?x?xf32>
+// CHECKPARALLEL-SAME: %[[arg2:[a-zA-Z0-9]+]]: memref<?x?x?xf32>
+// CHECKPARALLEL: %[[c0:.*]] = constant 0 : index
+// CHECKPARALLEL: %[[c1:.*]] = constant 1 : index
+// CHECKPARALLEL: %[[c2:.*]] = constant 2 : index
+// CHECKPARALLEL: %[[dim0:.*]] = dim %[[arg1]], %[[c0]] : memref<?x?x?xf32>
+// CHECKPARALLEL: %[[dim1:.*]] = dim %[[arg1]], %[[c1]] : memref<?x?x?xf32>
+// CHECKPARALLEL: %[[dim2:.*]] = dim %[[arg1]], %[[c2]] : memref<?x?x?xf32>
+// CHECKPARALLEL: %[[dim3:.*]] = dim %[[arg2]], %[[c0]] : memref<?x?x?xf32>
+// CHECKPARALLEL: %[[dim4:.*]] = dim %[[arg2]], %[[c1]] : memref<?x?x?xf32>
+// CHECKPARALLEL: %[[dim5:.*]] = dim %[[arg2]], %[[c2]] : memref<?x?x?xf32>
+// CHECKPARALLEL: scf.parallel (%[[i0:.*]], %[[i1:.*]], %[[i2:.*]], %[[i3:.*]], %[[i4:.*]], %[[i5:.*]]) = (%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}) to (%[[dim3]], %[[dim4]], %[[dim5]], %[[dim0]], %[[dim1]], %[[dim2]]) step ({{.*}}) {
+// CHECKPARALLEL: %[[dim6:.*]] = dim %[[arg1]], %[[c0]] : memref<?x?x?xf32>
+// CHECKPARALLEL: %[[dim7:.*]] = dim %[[arg1]], %[[c1]] : memref<?x?x?xf32>
+// CHECKPARALLEL: %[[dim8:.*]] = dim %[[arg1]], %[[c2]] : memref<?x?x?xf32>
+// CHECKPARALLEL: %[[aff1:.*]] = affine.apply #[[$convMap]](%{{.*}}, %{{.*}})[%[[dim6]]]
+// CHECKPARALLEL: %[[aff2:.*]] = affine.apply #[[$convMap]](%{{.*}}, %{{.*}})[%[[dim7]]]
+// CHECKPARALLEL: %[[aff3:.*]] = affine.apply #[[$convMap]](%{{.*}}, %{{.*}})[%[[dim8]]]
+// CHECKPARALLEL: %[[va:.*]] = load %[[arg0]][%[[aff1]], %[[aff2]], %[[aff3]]] : memref<?x?x?xf32>
+// CHECKPARALLEL: %[[vb:.*]] = load %[[arg1]][%[[i3]], %[[i4]], %[[i5]]] : memref<?x?x?xf32>
+// CHECKPARALLEL: %[[vc:.*]] = load %[[arg2]][%[[i0]], %[[i1]], %[[i2]]] : memref<?x?x?xf32>
+// CHECKPARALLEL: %[[inc:.*]] = mulf %[[va]], %[[vb]] : f32
+// CHECKPARALLEL: %[[res:.*]] = addf %[[vc]], %[[inc]] : f32
+// CHECKPARALLEL: store %[[res]], %[[arg2]][%[[i0]], %[[i1]], %[[i2]]] : memref<?x?x?xf32>
+
+#conv_4d_accesses = [ // indexing maps for @conv4d: 8 loop dims (m, n, k, l, m1, n1, k1, l1) and 4 symbols (s0, s1, s2, s3)
+ affine_map<(m, n, k, l, m1, n1, k1, l1)[s0, s1, s2, s3] -> (m + m1 - s0 floordiv 2, n + n1 - s1 floordiv 2, k + k1 - s2 floordiv 2, l + l1 - s3 floordiv 2)>, // in: each index is the sum of two loop dims minus `s floordiv 2` of the matching symbol
+ affine_map<(m, n, k, l, m1, n1, k1, l1)[s0, s1, s2, s3] -> (m1, n1, k1, l1)>, // filter: indexed by the last four loop dims only
+ affine_map<(m, n, k, l, m1, n1, k1, l1)[s0, s1, s2, s3] -> (m, n, k, l)> // out: indexed by the first four loop dims only
+]
+
+#conv_4d_trait = { // attribute dictionary passed to the linalg.generic op in @conv4d
+ args_in = 2, // two input operands
+ args_out = 1, // one output operand
+ doc = "C(m,n,k,l) += A(m,n,k,l) * B(m1,n1,k1,l1)", // NOTE(review): per the "in" map of #conv_4d_accesses, A is read at (m + m1 - s0 floordiv 2, ...), not at (m,n,k,l) — confirm the doc string
+ indexing_maps = #conv_4d_accesses,
+ library_call = "linalg_conv_4d",
+ n_views = [2, 1],
+ iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel", "parallel", "parallel", "parallel"], // NOTE(review): m1, n1, k1, l1 do not appear in the "out" map, so they look like reduction dims — confirm that "parallel" is intended
+ symbol_source = 1 // the expected lowering in this test feeds s0..s3 from the dims of operand 1 (%filter)
+}
+
+func @conv4d(%in : memref<?x?x?x?xf32>, %filter : memref<?x?x?x?xf32>, %out : memref<?x?x?x?xf32>) -> () { // test input: linalg.generic with symbol-bearing indexing maps on rank-4 dynamically shaped memrefs
+ linalg.generic #conv_4d_trait %in, %filter, %out {
+ ^bb0(%a: f32, %b: f32, %c: f32) : // per the expected loads in this test: %a comes from %in, %b from %filter, %c from %out
+ %d = mulf %a, %b : f32 // product of the input element and the filter element
+ %e = addf %c, %d : f32 // add the product to the current output element
+ linalg.yield %e : f32 // the expected lowering stores this value back into %out
+ } : memref<?x?x?x?xf32>,
+ memref<?x?x?x?xf32>,
+ memref<?x?x?x?xf32>
+ return
+}
+
+// CHECKLOOP-LABEL: @conv4d
+// CHECKLOOP-SAME: %[[arg0:[a-zA-Z0-9]+]]: memref<?x?x?x?xf32>
+// CHECKLOOP-SAME: %[[arg1:[a-zA-Z0-9]+]]: memref<?x?x?x?xf32>
+// CHECKLOOP-SAME: %[[arg2:[a-zA-Z0-9]+]]: memref<?x?x?x?xf32>
+// CHECKLOOP: %[[c0:.*]] = constant 0 : index
+// CHECKLOOP: %[[c1:.*]] = constant 1 : index
+// CHECKLOOP: %[[c2:.*]] = constant 2 : index
+// CHECKLOOP: %[[c3:.*]] = constant 3 : index
+// CHECKLOOP: %[[dim0:.*]] = dim %[[arg1]], %[[c0]] : memref<?x?x?x?xf32>
+// CHECKLOOP: %[[dim1:.*]] = dim %[[arg1]], %[[c1]] : memref<?x?x?x?xf32>
+// CHECKLOOP: %[[dim2:.*]] = dim %[[arg1]], %[[c2]] : memref<?x?x?x?xf32>
+// CHECKLOOP: %[[dim3:.*]] = dim %[[arg1]], %[[c3]] : memref<?x?x?x?xf32>
+// CHECKLOOP: %[[dim4:.*]] = dim %[[arg2]], %[[c0]] : memref<?x?x?x?xf32>
+// CHECKLOOP: %[[dim5:.*]] = dim %[[arg2]], %[[c1]] : memref<?x?x?x?xf32>
+// CHECKLOOP: %[[dim6:.*]] = dim %[[arg2]], %[[c2]] : memref<?x?x?x?xf32>
+// CHECKLOOP: %[[dim7:.*]] = dim %[[arg2]], %[[c3]] : memref<?x?x?x?xf32>
+// CHECKLOOP: scf.for %[[i0:.*]] = %{{.*}} to %[[dim4]] step %{{.*}} {
+// CHECKLOOP: scf.for %[[i1:.*]] = %{{.*}} to %[[dim5]] step %{{.*}} {
+// CHECKLOOP: scf.for %[[i2:.*]] = %{{.*}} to %[[dim6]] step %{{.*}} {
+// CHECKLOOP: scf.for %[[i3:.*]] = %{{.*}} to %[[dim7]] step %{{.*}} {
+// CHECKLOOP: scf.for %[[i4:.*]] = %{{.*}} to %[[dim0]] step %{{.*}} {
+// CHECKLOOP: scf.for %[[i5:.*]] = %{{.*}} to %[[dim1]] step %{{.*}} {
+// CHECKLOOP: scf.for %[[i6:.*]] = %{{.*}} to %[[dim2]] step %{{.*}} {
+// CHECKLOOP: scf.for %[[i7:.*]] = %{{.*}} to %[[dim3]] step %{{.*}} {
+// CHECKLOOP: %[[dim8:.*]] = dim %[[arg1]], %[[c0]] : memref<?x?x?x?xf32>
+// CHECKLOOP: %[[dim9:.*]] = dim %[[arg1]], %[[c1]] : memref<?x?x?x?xf32>
+// CHECKLOOP: %[[dim10:.*]] = dim %[[arg1]], %[[c2]] : memref<?x?x?x?xf32>
+// CHECKLOOP: %[[dim11:.*]] = dim %[[arg1]], %[[c3]] : memref<?x?x?x?xf32>
+// CHECKLOOP: %[[aff1:.*]] = affine.apply #[[$convMap]](%{{.*}}, %{{.*}})[%[[dim8]]]
+// CHECKLOOP: %[[aff2:.*]] = affine.apply #[[$convMap]](%{{.*}}, %{{.*}})[%[[dim9]]]
+// CHECKLOOP: %[[aff3:.*]] = affine.apply #[[$convMap]](%{{.*}}, %{{.*}})[%[[dim10]]]
+// CHECKLOOP: %[[aff4:.*]] = affine.apply #[[$convMap]](%{{.*}}, %{{.*}})[%[[dim11]]]
+// CHECKLOOP: %[[va:.*]] = load %[[arg0]][%[[aff1]], %[[aff2]], %[[aff3]], %[[aff4]]] : memref<?x?x?x?xf32>
+// CHECKLOOP: %[[vb:.*]] = load %[[arg1]][%[[i4]], %[[i5]], %[[i6]], %[[i7]]] : memref<?x?x?x?xf32>
+// CHECKLOOP: %[[vc:.*]] = load %[[arg2]][%[[i0]], %[[i1]], %[[i2]], %[[i3]]] : memref<?x?x?x?xf32>
+// CHECKLOOP: %[[inc:.*]] = mulf %[[va]], %[[vb]] : f32
+// CHECKLOOP: %[[res:.*]] = addf %[[vc]], %[[inc]] : f32
+// CHECKLOOP: store %[[res]], %[[arg2]][%[[i0]], %[[i1]], %[[i2]], %[[i3]]] : memref<?x?x?x?xf32>
+
+// CHECKPARALLEL-LABEL: @conv4d
+// CHECKPARALLEL-SAME: %[[arg0:[a-zA-Z0-9]+]]: memref<?x?x?x?xf32>
+// CHECKPARALLEL-SAME: %[[arg1:[a-zA-Z0-9]+]]: memref<?x?x?x?xf32>
+// CHECKPARALLEL-SAME: %[[arg2:[a-zA-Z0-9]+]]: memref<?x?x?x?xf32>
+// CHECKPARALLEL: %[[c0:.*]] = constant 0 : index
+// CHECKPARALLEL: %[[c1:.*]] = constant 1 : index
+// CHECKPARALLEL: %[[c2:.*]] = constant 2 : index
+// CHECKPARALLEL: %[[c3:.*]] = constant 3 : index
+// CHECKPARALLEL: %[[dim0:.*]] = dim %[[arg1]], %[[c0]] : memref<?x?x?x?xf32>
+// CHECKPARALLEL: %[[dim1:.*]] = dim %[[arg1]], %[[c1]] : memref<?x?x?x?xf32>
+// CHECKPARALLEL: %[[dim2:.*]] = dim %[[arg1]], %[[c2]] : memref<?x?x?x?xf32>
+// CHECKPARALLEL: %[[dim3:.*]] = dim %[[arg1]], %[[c3]] : memref<?x?x?x?xf32>
+// CHECKPARALLEL: %[[dim4:.*]] = dim %[[arg2]], %[[c0]] : memref<?x?x?x?xf32>
+// CHECKPARALLEL: %[[dim5:.*]] = dim %[[arg2]], %[[c1]] : memref<?x?x?x?xf32>
+// CHECKPARALLEL: %[[dim6:.*]] = dim %[[arg2]], %[[c2]] : memref<?x?x?x?xf32>
+// CHECKPARALLEL: %[[dim7:.*]] = dim %[[arg2]], %[[c3]] : memref<?x?x?x?xf32>
+// CHECKPARALLEL: scf.parallel (%[[i0:.*]], %[[i1:.*]], %[[i2:.*]], %[[i3:.*]], %[[i4:.*]], %[[i5:.*]], %[[i6:.*]], %[[i7:.*]]) = (%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}) to (%[[dim4]], %[[dim5]], %[[dim6]], %[[dim7]], %[[dim0]], %[[dim1]], %[[dim2]], %[[dim3]]) step ({{.*}}) {
+// CHECKPARALLEL: %[[dim8:.*]] = dim %[[arg1]], %[[c0]] : memref<?x?x?x?xf32>
+// CHECKPARALLEL: %[[dim9:.*]] = dim %[[arg1]], %[[c1]] : memref<?x?x?x?xf32>
+// CHECKPARALLEL: %[[dim10:.*]] = dim %[[arg1]], %[[c2]] : memref<?x?x?x?xf32>
+// CHECKPARALLEL: %[[dim11:.*]] = dim %[[arg1]], %[[c3]] : memref<?x?x?x?xf32>
+// CHECKPARALLEL: %[[aff1:.*]] = affine.apply #[[$convMap]](%{{.*}}, %{{.*}})[%[[dim8]]]
+// CHECKPARALLEL: %[[aff2:.*]] = affine.apply #[[$convMap]](%{{.*}}, %{{.*}})[%[[dim9]]]
+// CHECKPARALLEL: %[[aff3:.*]] = affine.apply #[[$convMap]](%{{.*}}, %{{.*}})[%[[dim10]]]
+// CHECKPARALLEL: %[[aff4:.*]] = affine.apply #[[$convMap]](%{{.*}}, %{{.*}})[%[[dim11]]]
+// CHECKPARALLEL: %[[va:.*]] = load %[[arg0]][%[[aff1]], %[[aff2]], %[[aff3]], %[[aff4]]] : memref<?x?x?x?xf32>
+// CHECKPARALLEL: %[[vb:.*]] = load %[[arg1]][%[[i4]], %[[i5]], %[[i6]], %[[i7]]] : memref<?x?x?x?xf32>
+// CHECKPARALLEL: %[[vc:.*]] = load %[[arg2]][%[[i0]], %[[i1]], %[[i2]], %[[i3]]] : memref<?x?x?x?xf32>
+// CHECKPARALLEL: %[[inc:.*]] = mulf %[[va]], %[[vb]] : f32
+// CHECKPARALLEL: %[[res:.*]] = addf %[[vc]], %[[inc]] : f32
+// CHECKPARALLEL: store %[[res]], %[[arg2]][%[[i0]], %[[i1]], %[[i2]], %[[i3]]] : memref<?x?x?x?xf32>
diff --git a/mlir/test/lib/Transforms/TestBufferPlacement.cpp b/mlir/test/lib/Transforms/TestBufferPlacement.cpp
index 2fbdfe989d05..5ad441aa15c3 100644
--- a/mlir/test/lib/Transforms/TestBufferPlacement.cpp
+++ b/mlir/test/lib/Transforms/TestBufferPlacement.cpp
@@ -77,7 +77,8 @@ struct TestBufferPlacementPreparationPass
auto linalgOp = rewriter.create<linalg::GenericOp>(
loc, llvm::None, newArgs, rewriter.getI64IntegerAttr(operands.size()),
rewriter.getI64IntegerAttr(results.size()), op.indexing_maps(),
- op.iterator_types(), op.docAttr(), op.library_callAttr());
+ op.iterator_types(), op.docAttr(), op.library_callAttr(),
+ op.symbol_sourceAttr());
// Create a new block in the region of the new Generic Op.
Block &oldBlock = op.getRegion().front();
More information about the Mlir-commits
mailing list