[Mlir-commits] [mlir] 046922e - [mlir][linalg] Add support for scalar input operands.
Tobias Gysi
llvmlistbot at llvm.org
Sun Jun 13 23:28:16 PDT 2021
Author: Tobias Gysi
Date: 2021-06-14T06:27:16Z
New Revision: 046922e1003795d67df89721e6b76c01b214d408
URL: https://github.com/llvm/llvm-project/commit/046922e1003795d67df89721e6b76c01b214d408
DIFF: https://github.com/llvm/llvm-project/commit/046922e1003795d67df89721e6b76c01b214d408.diff
LOG: [mlir][linalg] Add support for scalar input operands.
Up to now, all structured op operands have been assumed to be shaped. This patch relaxes that assumption and allows scalar input operands. In contrast to shaped operands, scalar operands are not indexed; they are forwarded directly to the body of the operation. Like all other operands, scalar operands are associated with an indexing map, which in the case of a scalar (or a 0-D operand) has an empty range.
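For illustration, here is a minimal sketch of the new form, adapted from the tests added below (the function and value names are illustrative). The scalar %scale is passed through ins like any other input, its indexing map has an empty range, and the corresponding block argument carries the value directly:

  #map0 = affine_map<(d0, d1) -> (d0, d1)>
  #map1 = affine_map<(d0, d1) -> ()>  // scalar operand: empty range

  func @scale(%arg0: tensor<?x?xf32>, %scale: f32) -> tensor<?x?xf32> {
    %c0 = constant 0 : index
    %c1 = constant 1 : index
    %d0 = memref.dim %arg0, %c0 : tensor<?x?xf32>
    %d1 = memref.dim %arg0, %c1 : tensor<?x?xf32>
    %init = linalg.init_tensor [%d0, %d1] : tensor<?x?xf32>
    %0 = linalg.generic {indexing_maps = [#map0, #map1, #map0],
                         iterator_types = ["parallel", "parallel"]}
        ins(%arg0, %scale : tensor<?x?xf32>, f32)
        outs(%init : tensor<?x?xf32>) {
      ^bb0(%in: f32, %s: f32, %out: f32):
        %1 = mulf %in, %s : f32  // the scalar arrives unindexed in the body
        linalg.yield %1 : f32
    } -> tensor<?x?xf32>
    return %0 : tensor<?x?xf32>
  }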
We will use scalar operands as a replacement for the capture mechanism. In contrast to captures, this approach ensures that the function signature can be generated from the operand list, and it prevents stale capture values when a transformation updates only the capture operand but not the hidden body of a named operation.
Removing captures and updating existing operations such as linalg.fill are left for a later patch.
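As a hypothetical sketch of that direction (this patch does not touch linalg.fill itself; the form below is only illustrative), a fill could then be expressed with a scalar input operand instead of a capture:

  #scalar = affine_map<(d0, d1) -> ()>
  #id = affine_map<(d0, d1) -> (d0, d1)>

  func @fill_like(%init: tensor<?x?xf32>, %value: f32) -> tensor<?x?xf32> {
    %0 = linalg.generic {indexing_maps = [#scalar, #id],
                         iterator_types = ["parallel", "parallel"]}
        ins(%value : f32)
        outs(%init : tensor<?x?xf32>) {
      ^bb0(%v: f32, %out: f32):
        linalg.yield %v : f32  // broadcast the scalar to every element
    } -> tensor<?x?xf32>
    return %0 : tensor<?x?xf32>
  }

Because the scalar is an explicit operand, transformations that rewrite the operand list (for example, the tiling and fusion patterns updated in this patch) see it and can thread it through automatically.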
The patch depends on https://reviews.llvm.org/D103891 and https://reviews.llvm.org/D103890.
Differential Revision: https://reviews.llvm.org/D104109
Added:
Modified:
mlir/include/mlir/Dialect/Linalg/IR/LinalgBase.td
mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td
mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp
mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp
mlir/lib/Dialect/Linalg/Transforms/Loops.cpp
mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
mlir/test/Dialect/Linalg/drop-unit-extent-dims.mlir
mlir/test/Dialect/Linalg/fusion-tensor-pattern.mlir
mlir/test/Dialect/Linalg/fusion-tensor.mlir
mlir/test/Dialect/Linalg/invalid.mlir
mlir/test/Dialect/Linalg/loops.mlir
mlir/test/Dialect/Linalg/reshape_fusion.mlir
mlir/test/Dialect/Linalg/roundtrip.mlir
mlir/test/Dialect/Linalg/tile-and-pad-tensors.mlir
mlir/test/Dialect/Linalg/vectorization.mlir
mlir/test/lib/Dialect/Linalg/TestLinalgFusionTransforms.cpp
mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgBase.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgBase.td
index 6dabb6de2cfb7..bc9acee6fd6e7 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgBase.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgBase.td
@@ -15,8 +15,6 @@
include "mlir/IR/OpBase.td"
-def LinalgOperand: AnyTypeOf<[AnyRankedTensor, AnyStridedMemRef]>;
-
def Linalg_Dialect : Dialect {
let name = "linalg";
let description = [{
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td
index 73e5570be4627..e53f2f35ca66e 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td
@@ -584,6 +584,19 @@ def LinalgStructuredInterface : OpInterface<"LinalgOp"> {
return {};
}]
>,
+ InterfaceMethod<
+ /*desc=*/[{
+ Return true if the `opOperand` is a scalar value.
+ }],
+ /*retTy=*/"bool",
+ /*methodName=*/"isScalar",
+ /*args=*/(ins "OpOperand*":$opOperand),
+ /*methodBody=*/"",
+ /*defaultImplementation=*/[{
+ assert(opOperand->getOwner() == this->getOperation());
+ return !opOperand->get().getType().template isa<ShapedType>();
+ }]
+ >,
InterfaceMethod<
/*desc=*/[{
Return the input or output indexing map for `opOperand`.
@@ -694,10 +707,13 @@ def LinalgStructuredInterface : OpInterface<"LinalgOp"> {
/*methodBody=*/"",
/*defaultImplementation=*/[{
return this->getOperation()->getNumResults() == 0 &&
- llvm::all_of(getInputAndOutputOperands(),
- [](OpOperand *opOperand) {
- return opOperand->get().getType().template isa<MemRefType>();
- });
+ llvm::all_of(getInputOperands(), [&](OpOperand *opOperand) {
+ return isScalar(opOperand) ||
+ opOperand->get().getType().template isa<MemRefType>();
+ }) &&
+ llvm::all_of(getOutputOperands(), [](OpOperand *opOperand) {
+ return opOperand->get().getType().template isa<MemRefType>();
+ });
}]
>,
InterfaceMethod<
@@ -709,8 +725,12 @@ def LinalgStructuredInterface : OpInterface<"LinalgOp"> {
/*args=*/(ins),
/*methodBody=*/"",
/*defaultImplementation=*/[{
- return llvm::all_of(getInputAndOutputOperands(),
- [](OpOperand *opOperand) {
+ return
+ llvm::all_of(getInputOperands(), [&](OpOperand *opOperand) {
+ return isScalar(opOperand) ||
+ opOperand->get().getType().template isa<RankedTensorType>();
+ }) &&
+ llvm::all_of(getOutputOperands(), [](OpOperand *opOperand) {
return opOperand->get().getType().template isa<RankedTensorType>();
});
}]
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
index 5d7e2cc5e64d0..9b6120e61eeda 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
@@ -640,8 +640,8 @@ def Linalg_TiledLoopOp : Linalg_Op<"tiled_loop", [
let arguments = (ins Variadic<Index>:$lowerBound,
Variadic<Index>:$upperBound,
Variadic<Index>:$step,
- Variadic<LinalgOperand>:$inputs,
- Variadic<LinalgOperand>:$outputs,
+ Variadic<AnyType>:$inputs,
+ Variadic<AnyShaped>:$outputs,
ArrayAttr:$iterator_types,
OptionalAttr<ArrayAttr>:$distribution_types);
let results = (outs Variadic<AnyRankedTensor>:$results);
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
index 41fcc2495e658..2b70572ac1e1c 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
@@ -517,17 +517,12 @@ def PoolingSumOp: SingleInputPoolingBase_Op<"pooling_sum"> {
//===----------------------------------------------------------------------===//
// Generic Linalg ops.
//===----------------------------------------------------------------------===//
-class LinalgOperandOfRank<int rank>: Type<
- And<[
- LinalgOperand.predicate,
- CPred<"$_self.cast<ShapedType>().getRank() == " # rank>]
- >>;
class GenericOpBase<string mnemonic> : LinalgStructuredBase_Op<mnemonic, [
AttrSizedOperandSegments,
DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
SingleBlockImplicitTerminator<"YieldOp">]> {
- let arguments = (ins Variadic<AnyShaped>:$inputs,
+ let arguments = (ins Variadic<AnyType>:$inputs,
Variadic<AnyShaped>:$outputs,
AffineMapArrayAttr:$indexing_maps,
ArrayAttr:$iterator_types,
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
index 8a48e89cda530..18717e9820d34 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
@@ -338,7 +338,7 @@ LogicalResult mlir::linalg::detail::verifyStructuredOpInterface(Operation *op) {
if (failed(linalgOp.verifyIndexingMapRequiredAttributes()))
return failure();
- // All shaped operands must be indexed.
+ // All input/output operands must be indexed.
if (static_cast<int64_t>(linalgOp.indexing_maps().size()) !=
linalgOp.getNumInputsAndOutputs())
return op->emitOpError("expected the number of indexing_map (")
@@ -363,7 +363,7 @@ LogicalResult mlir::linalg::detail::verifyStructuredOpInterface(Operation *op) {
int64_t rank = linalgOp.getRank(opOperand);
if (indexingMap.getNumResults() != rank)
- return op->emitOpError("expected shaped value rank (")
+ return op->emitOpError("expected operand rank (")
<< rank << ") to match the result rank of indexing_map #"
<< opOperand->getOperandNumber() << " ("
<< indexingMap.getNumResults() << ")";
@@ -444,7 +444,7 @@ LogicalResult mlir::linalg::detail::verifyStructuredOpInterface(Operation *op) {
if (linalgOp.getNumInputsAndOutputs() + numBBIvs != block.getNumArguments())
return op->emitOpError("expected as many non-induction variable region "
- "arguments as the number of shaped operands");
+ "arguments as the number of input/output operands");
// Note: the number and type of yield values are checked in the YieldOp.
for (unsigned i = 0; i < numBBIvs; ++i)
@@ -452,14 +452,14 @@ LogicalResult mlir::linalg::detail::verifyStructuredOpInterface(Operation *op) {
return op->emitOpError("expected index block argument #") << i;
for (OpOperand *opOperand : linalgOp.getInputAndOutputOperands()) {
- Type elementType = getElementTypeOrSelf(opOperand->get().getType());
+ Type elementType = getElementTypeOrSelf(opOperand->get());
Type argType =
block.getArgument(numBBIvs + opOperand->getOperandNumber()).getType();
if (elementType != argType)
return op->emitOpError("expected type of bb argument #")
<< numBBIvs + opOperand->getOperandNumber() << " (" << argType
<< ")"
- << " to match element type of corresponding shaped operand ("
+ << " to match element or self type of the corresponding operand ("
<< elementType << ")";
}
@@ -489,10 +489,11 @@ LogicalResult mlir::linalg::detail::verifyStructuredOpInterface(Operation *op) {
// The first index or last index should be the maximum or the minimum in
// the inferred index ranges since the range is increasing or
- // decreasing. The size of dimensions of shaped operands and the maximum
- // value + 1 in the inferred range should be the same. But, for now we
- // check if the inferred ranges are in boundary of shaped operands' size
- // or not in case that Affine Expressions are complicated such as d0 * 3
+ // decreasing. The size of dimensions of input/output operands and the
+ // maximum value + 1 in the inferred range should be the same. But, for
+ // now we check if the inferred ranges are in boundary of input/output
+ // operands' size or not in case that Affine Expressions are complicated
+ // such as d0 * 3
// + d1 since it is not easy to handle the issues.
// Found the case that this solution can't check, for example, (d0, d1)
// -> (d1 - d0)
@@ -510,14 +511,14 @@ LogicalResult mlir::linalg::detail::verifyStructuredOpInterface(Operation *op) {
}
if (indexingMap.getResult(dim).dyn_cast<AffineDimExpr>()) {
if (inferredDimSize != shape[dim]) {
- return op->emitOpError("inferred shaped operand #")
+ return op->emitOpError("inferred input/output operand #")
<< opOperand->getOperandNumber()
<< " has shape's dimension #" << dim << " to be "
<< inferredDimSize << ", but found " << shape[dim];
}
} else {
if (inferredDimSize > shape[dim]) {
- return op->emitOpError("inferred shaped operand #")
+ return op->emitOpError("inferred input/output operand #")
<< opOperand->getOperandNumber()
<< " has shape's dimension #" << dim
<< " to be greater than or equal to " << inferredDimSize
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index 6aa0ed15fc945..6eef6b0b48efc 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -377,8 +377,7 @@ void printCopyOpRegion(OpAsmPrinter &, Operation *, Region &, Type, Type) {}
static LogicalResult verify(CopyOp op) {
OpOperand *output = op.getOutputOperand(0);
OpOperand *input = op.getInputOperand(0);
- if (getElementTypeOrSelf(input->get().getType()) !=
- getElementTypeOrSelf(output->get().getType()))
+ if (getElementTypeOrSelf(input->get()) != getElementTypeOrSelf(output->get()))
return op.emitOpError("expects views of the same type");
if (op.getRank(input) != op.getRank(output))
return op.emitOpError("expects views of the same rank");
@@ -452,7 +451,7 @@ void printFillOpRegion(OpAsmPrinter &, Operation *, Region &, Type, Value) {}
static LogicalResult verify(FillOp op) {
OpOperand *output = op.getOutputOperand(0);
Type fillType = op.value().getType();
- if (getElementTypeOrSelf(output->get().getType()) != fillType)
+ if (getElementTypeOrSelf(output->get()) != fillType)
return op.emitOpError("expects fill type to match view elemental type");
if (!op.getNumResults() && !output->get().getType().isa<MemRefType>()) {
return op.emitOpError(
@@ -489,7 +488,7 @@ void GenericOp::build(
SmallVector<Type, 4> blockArgTypes;
for (ValueRange container : {inputs, outputs})
for (Value v : container)
- blockArgTypes.push_back(v.getType().cast<ShapedType>().getElementType());
+ blockArgTypes.push_back(getElementTypeOrSelf(v));
OpBuilder::InsertionGuard guard(builder);
auto &region = *result.regions.front();
@@ -545,7 +544,7 @@ void IndexedGenericOp::build(
SmallVector<Type, 4> blockArgTypes(nLoops, builder.getIndexType());
for (ValueRange container : {inputs, outputs})
for (Value v : container)
- blockArgTypes.push_back(v.getType().cast<ShapedType>().getElementType());
+ blockArgTypes.push_back(getElementTypeOrSelf(v));
OpBuilder::InsertionGuard guard(builder);
auto &region = *result.regions.front();
@@ -2949,7 +2948,6 @@ fillStructuredOpRegion(OpBuilder &opBuilder, Region ®ion,
TypeRange inputTypes, TypeRange outputTypes,
ValueRange captures,
std::function<void(unsigned, unsigned)> errorHandler) {
- assert(llvm::all_of(inputTypes, [](Type t) { return t.isa<ShapedType>(); }));
assert(llvm::all_of(outputTypes, [](Type t) { return t.isa<ShapedType>(); }));
// TODO: atm all operands go through getElementTypeOrSelf,
diff --git a/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp b/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp
index 1c2d32bba8183..3b63bf1f38d5c 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp
@@ -484,18 +484,21 @@ static LogicalResult bufferize(OpBuilder &b, LinalgOp op,
b.setInsertionPoint(op);
Location loc = op.getLoc();
- SmallVector<Value, 2> newInputBuffers;
- newInputBuffers.reserve(op.getNumInputs());
+ SmallVector<Value> newInputs;
+ newInputs.reserve(op.getNumInputs());
for (OpOperand *opOperand : op.getInputOperands()) {
- Value v = lookup(bvm, opOperand->get());
- if (!v)
+ if (op.isScalar(opOperand)) {
+ newInputs.push_back(opOperand->get());
+ continue;
+ }
+ newInputs.push_back(lookup(bvm, opOperand->get()));
+ if (!newInputs.back())
return failure();
- newInputBuffers.push_back(v);
}
- SmallVector<Value, 2> newOutputBuffers;
+ SmallVector<Value> newOutputBuffers;
if (failed(allocateBuffersForResults(b, loc, op, newOutputBuffers, bvm)))
return failure();
- finalizeBufferAllocation(b, op, newInputBuffers, newOutputBuffers, bvm);
+ finalizeBufferAllocation(b, op, newInputs, newOutputBuffers, bvm);
return success();
}
diff --git a/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp b/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
index 102dbdb4e2c36..1deea94766746 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
@@ -301,7 +301,7 @@ static UnitExtentReplacementInfo replaceUnitExtents(GenericOp genericOp,
++dim;
}
// Compute the tensor or scalar replacement type.
- Type elementType = getElementTypeOrSelf(opOperand->get().getType());
+ Type elementType = getElementTypeOrSelf(opOperand->get());
Type replacementType = elementType == opOperand->get().getType()
? elementType
: RankedTensorType::get(newShape, elementType);
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp
index 0263bcb708448..098442cf149e1 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp
@@ -129,14 +129,14 @@ static SmallVector<Value> getTiledOperands(OpBuilder &b, LinalgOp producer) {
assert(producer.hasTensorSemantics() &&
"only fusion on tensors is currently supported for TiledLinalgOp");
- for (OpOperand *producerInput : producer.getInputTensorOperands()) {
+ for (OpOperand *producerInput : producer.getInputOperands()) {
OpOperand *addedInput = tiledLoop.findInputOperand(producerInput->get());
if (addedInput == nullptr)
addedInput = &tiledLoop.appendInputOperand(b, producerInput->get());
BlockArgument addedBlockArg = tiledLoop.getTiedBlockArgument(*addedInput);
tiledOperands.push_back(addedBlockArg);
}
- for (OpOperand *producerOutput : producer.getOutputTensorOperands()) {
+ for (OpOperand *producerOutput : producer.getOutputOperands()) {
OpResult result = producer.getTiedOpResult(producerOutput);
OpOperand *resultInputOperand = tiledLoop.findInputOperand(result);
OpOperand *resultOutputOperand = tiledLoop.findOutputOperand(result);
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Loops.cpp b/mlir/lib/Dialect/Linalg/Transforms/Loops.cpp
index 530c02466fc12..49870da402cf8 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Loops.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Loops.cpp
@@ -126,8 +126,12 @@ static void emitScalarImplementation(OpBuilder &b, Location loc,
// TODO: Avoid the loads if the corresponding argument of the
// region has no uses.
- // 1.a. Emit load from input views.
+ // 1.a. Emit load from input operand or for scalars access the operand itself.
for (OpOperand *inputOperand : linalgOp.getInputOperands()) {
+ if (linalgOp.isScalar(inputOperand)) {
+ indexedValues.push_back(inputOperand->get());
+ continue;
+ }
auto indexing = makeCanonicalAffineApplies(
b, loc, linalgOp.getTiedIndexingMap(inputOperand), allIvsPlusDims);
indexedValues.push_back(
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
index 44acac1069707..efd0c3b2079d1 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
@@ -149,7 +149,7 @@ static LogicalResult padOperandToSmallestStaticBoundingBox(
}
Value pad = options.paddingValueComputationFunction(rewriter, *opOperand);
auto staticTensorType = RankedTensorType::get(
- staticSizes, getElementTypeOrSelf(opOperand->get().getType()));
+ staticSizes, getElementTypeOrSelf(opOperand->get()));
result = linalg::PadTensorOp::createPadHighOp(
staticTensorType, opOperand->get(), pad, opToPad->getLoc(), rewriter);
return success();
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
index 689aae1e7df4a..96846bfb66335 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -479,6 +479,10 @@ LogicalResult vectorizeAsLinalgGeneric(
SmallVector<AffineMap> indexings;
for (OpOperand *opOperand : linalgOp.getInputAndOutputOperands()) {
BlockArgument bbarg = block.getArgument(opOperand->getOperandNumber());
+ if (linalgOp.isScalar(opOperand)) {
+ bvm.map(bbarg, opOperand->get());
+ continue;
+ }
// TODO: 0-d vectors.
if (linalgOp.getShape(opOperand).empty()) {
Value loaded =
@@ -494,14 +498,13 @@ LogicalResult vectorizeAsLinalgGeneric(
if (broadcastToMaximalCommonShape) {
map = inverseAndBroadcastProjectedPermuation(
linalgOp.getTiedIndexingMap(opOperand));
- vectorType = VectorType::get(
- commonVectorShape, getElementTypeOrSelf(opOperand->get().getType()));
+ vectorType = VectorType::get(commonVectorShape,
+ getElementTypeOrSelf(opOperand->get()));
} else {
map = inversePermutation(
reindexIndexingMap(linalgOp.getTiedIndexingMap(opOperand)));
- vectorType =
- VectorType::get(map.compose(linalgOp.getShape(opOperand)),
- getElementTypeOrSelf(opOperand->get().getType()));
+ vectorType = VectorType::get(map.compose(linalgOp.getShape(opOperand)),
+ getElementTypeOrSelf(opOperand->get()));
}
Value vectorRead = buildVectorRead(b, opOperand->get(), vectorType, map);
LLVM_DEBUG(dbgs() << "\n[" DEBUG_TYPE "]: new vectorized bbarg("
@@ -1157,7 +1160,7 @@ LogicalResult ConvOpVectorization<ConvOp, N>::matchAndRewrite(
int64_t rank = op.getRank(input);
int64_t numDims = mapping.size();
- Type elemType = getElementTypeOrSelf(input->get().getType());
+ Type elemType = getElementTypeOrSelf(input->get());
auto map = AffineMap::get(rank, 0, mapping, context);
SmallVector<Value, 4> zeros(rank, rewriter.create<ConstantIndexOp>(loc, 0));
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
index c99aafb29c3ee..4460ae88c6c22 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
@@ -1372,6 +1372,8 @@ struct GenericOpSparsifier : public OpRewritePattern<linalg::GenericOp> {
// Detects sparse annotations and translate the per-dimension sparsity
// information for all tensors to loop indices in the kernel.
assert(op.getNumOutputs() == 1);
+ assert(llvm::none_of(op.getInputAndOutputOperands(),
+ [&](OpOperand *t) { return op.isScalar(t); }));
unsigned numTensors = op.getNumInputsAndOutputs();
unsigned numLoops = op.iterator_types().getValue().size();
Merger merger(numTensors, numLoops);
diff --git a/mlir/test/Dialect/Linalg/drop-unit-extent-dims.mlir b/mlir/test/Dialect/Linalg/drop-unit-extent-dims.mlir
index bd4f66defe6ee..5a53c228bea5e 100644
--- a/mlir/test/Dialect/Linalg/drop-unit-extent-dims.mlir
+++ b/mlir/test/Dialect/Linalg/drop-unit-extent-dims.mlir
@@ -2,6 +2,7 @@
#accesses = [
affine_map<(i, j, k, l, m) -> (i, k, m)>,
+ affine_map<(i, j, k, l, m) -> ()>,
affine_map<(i, j, k, l, m) -> (i, k, j, l, m)>
]
@@ -11,21 +12,22 @@
library_call = "some_external_func"
}
-func @drop_one_trip_loops(%arg0 : tensor<?x1x?xf32>, %shape: tensor<?x1x?x1x?xf32>) -> tensor<?x1x?x1x?xf32> {
+func @drop_one_trip_loops(%arg0 : tensor<?x1x?xf32>, %arg1 : f32, %shape: tensor<?x1x?x1x?xf32>) -> tensor<?x1x?x1x?xf32> {
%0 = linalg.generic #trait
- ins(%arg0 : tensor<?x1x?xf32>)
+ ins(%arg0, %arg1 : tensor<?x1x?xf32>, f32)
outs(%shape : tensor<?x1x?x1x?xf32>) {
- ^bb0(%arg2 : f32, %arg3 : f32) :
- linalg.yield %arg2 : f32
+ ^bb0(%arg2 : f32, %arg3 : f32, %arg4 : f32) :
+ linalg.yield %arg3 : f32
} -> tensor<?x1x?x1x?xf32>
return %0 : tensor<?x1x?x1x?xf32>
}
-// CHECK-DAG: #[[$MAP2:.*]] = affine_map<(d0, d1, d2) -> (d0, d2)>
+// CHECK-DAG: #[[$MAP1:.*]] = affine_map<(d0, d1, d2) -> (d0, d2)>
+// CHECK-DAG: #[[$MAP2:.*]] = affine_map<(d0, d1, d2) -> ()>
// CHECK-DAG: #[[$MAP3:.*]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
// CHECK-LABEL: func @drop_one_trip_loops
// CHECK: linalg.tensor_collapse_shape %{{.*}} {{\[}}[0, 1], [2]]
// CHECK: linalg.generic
-// CHECK-SAME: indexing_maps = [#[[$MAP2]], #[[$MAP3]]]
+// CHECK-SAME: indexing_maps = [#[[$MAP1]], #[[$MAP2]], #[[$MAP3]]]
// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel"]
// CHECK: linalg.tensor_expand_shape %{{.*}} {{\[}}[0, 1], [2, 3], [4]]
diff --git a/mlir/test/Dialect/Linalg/fusion-tensor-pattern.mlir b/mlir/test/Dialect/Linalg/fusion-tensor-pattern.mlir
index 9f1566c8fb590..730482d957578 100644
--- a/mlir/test/Dialect/Linalg/fusion-tensor-pattern.mlir
+++ b/mlir/test/Dialect/Linalg/fusion-tensor-pattern.mlir
@@ -292,7 +292,7 @@ module {
// TLOOP: %[[DIM_A_0:.*]] = memref.dim %[[A]], %[[C0]] : [[TY:.*]]
// TLOOP: %[[DIM_B_1:.*]] = memref.dim %[[B]], %[[C1]] : [[TY]]
-// TLOOP: %[[AB:.*]] = linalg.tiled_loop (%[[I:.*]], %[[J:.*]]) =
+// TLOOP: %[[AB:.*]] = linalg.tiled_loop (%[[I:.*]], %[[J:.*]]) =
// TLOOP-SAME: (%[[C0]], %[[C0]]) to (%[[DIM_A_0]], %[[DIM_B_1]])
// TLOOP-SAME: step (%[[C32]], %[[C64]])
// TLOOP-SAME: ins (%[[A_:.*]] = %[[A]]: [[TY]],
@@ -305,7 +305,80 @@ module {
// TLOOP: %[[OUT_SUB:.*]] = subtensor %[[OUT_]][%[[I]], %[[J]]]
// TLOOP: %[[INIT_SUB:.*]] = linalg.fill(%[[OUT_SUB]], %[[C0_F32]])
-// TLOOP: %[[AB_SUB:.*]] = linalg.tiled_loop (%[[K:.*]]) = (%[[C0]])
+// TLOOP: %[[AB_SUB:.*]] = linalg.tiled_loop (%[[K:.*]]) = (%[[C0]])
+// TLOOP-SAME: to (%[[DIM_A__1]]) step (%[[C16]])
+// TLOOP-SAME: ins (%[[A_SUB_:.*]] = %[[A_SUB]]: [[TY]],
+// TLOOP-SAME: %[[B_SUB_:.*]] = %[[B_SUB]]: [[TY]])
+// TLOOP-SAME: outs (%[[INIT_SUB_:.*]] = %[[INIT_SUB]]: [[TY]])
+// TLOOP-SAME: iterators["reduction"] {
+
+// TLOOP: %[[A_SUB_SUB:.*]] = subtensor %[[A_SUB_]][0, %[[K]]]
+// TLOOP: %[[B_SUB_SUB:.*]] = subtensor %[[B_SUB_]][%[[K]], 0]
+
+// TLOOP: %[[AB_SUB_SUB:.*]] = linalg.matmul
+// TLOOP-SAME: ins(%[[A_SUB_SUB]], %[[B_SUB_SUB]] : [[TY]], [[TY]])
+// TLOOP-SAME: outs(%[[INIT_SUB_]] : [[TY]]) -> [[TY]]
+// TLOOP: linalg.yield %[[AB_SUB_SUB]] : [[TY]]
+// TLOOP: }
+// TLOOP: %[[SUB_RESULT:.*]] = subtensor_insert %[[AB_SUB]]
+// TLOOP-SAME: into %[[OUT_]][%[[I]], %[[J]]]
+// TLOOP: linalg.yield %[[SUB_RESULT]] : [[TY]]
+// TLOOP: }
+// TLOOP: return %[[AB]] : [[TY]]
+
+// -----
+
+module {
+ func @generic_plus_matmul(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>,
+ %arg2: tensor<?x?xf32>) -> tensor<?x?xf32> {
+ %c0 = constant 0.0 : f32
+ %0 = linalg.generic {
+ indexing_maps = [affine_map<(m, n) -> ()>, affine_map<(m, n) -> (m, n)>],
+ iterator_types = ["parallel", "parallel"]}
+ ins(%c0 : f32)
+ outs(%arg0: tensor<?x?xf32>) {
+ ^bb(%0: f32, %1: f32) :
+ linalg.yield %0 : f32
+ } -> tensor<?x?xf32>
+ %1 = linalg.matmul {__internal_linalg_transform__ = "out_fusion"}
+ ins(%arg1, %arg2 : tensor<?x?xf32>, tensor<?x?xf32>)
+ outs(%0 : tensor<?x?xf32>) -> tensor<?x?xf32>
+ return %1 : tensor<?x?xf32>
+ }
+}
+
+// TLOOP-LABEL: func @generic_plus_matmul(
+// TLOOP-SAME: %[[OUT:[a-zA-Z0-9_]+]]: tensor<?x?xf32>
+// TLOOP-SAME: %[[A:[a-zA-Z0-9_]+]]: tensor<?x?xf32>
+// TLOOP-SAME: %[[B:[a-zA-Z0-9_]+]]: tensor<?x?xf32>
+
+// TLOOP-DAG: %[[C0_F32:.*]] = constant 0.0
+// TLOOP-DAG: %[[C32:.*]] = constant 32 : index
+// TLOOP-DAG: %[[C64:.*]] = constant 64 : index
+// TLOOP-DAG: %[[C16:.*]] = constant 16 : index
+// TLOOP-DAG: %[[C0:.*]] = constant 0 : index
+// TLOOP-DAG: %[[C1:.*]] = constant 1 : index
+
+// TLOOP: %[[DIM_A_0:.*]] = memref.dim %[[A]], %[[C0]] : [[TY:.*]]
+// TLOOP: %[[DIM_B_1:.*]] = memref.dim %[[B]], %[[C1]] : [[TY]]
+
+// TLOOP: %[[AB:.*]] = linalg.tiled_loop (%[[I:.*]], %[[J:.*]]) =
+// TLOOP-SAME: (%[[C0]], %[[C0]]) to (%[[DIM_A_0]], %[[DIM_B_1]])
+// TLOOP-SAME: step (%[[C32]], %[[C64]])
+// TLOOP-SAME: ins (%[[A_:.*]] = %[[A]]: [[TY]],
+// TLOOP-SAME: %[[B_:.*]] = %[[B]]: [[TY]],
+// TLOOP-SAME: %[[C0_F32_:.*]] = %[[C0_F32]]
+// TLOOP-SAME: outs (%[[OUT_:.*]] = %[[OUT]]: [[TY]]) {
+
+// TLOOP: %[[DIM_A__1:.*]] = memref.dim %[[A_]], %[[C1]] : [[TY]]
+// TLOOP: %[[A_SUB:.*]] = subtensor %[[A_]][%[[I]], 0]
+// TLOOP: %[[B_SUB:.*]] = subtensor %[[B_]][0, %[[J]]]
+// TLOOP: %[[OUT_SUB:.*]] = subtensor %[[OUT_]][%[[I]], %[[J]]]
+// TLOOP: %[[INIT_SUB:.*]] = linalg.generic
+// TLOOP-SAME: ins(%[[C0_F32_]]
+// TLOOP-SAME: outs(%[[OUT_SUB]]
+
+// TLOOP: %[[AB_SUB:.*]] = linalg.tiled_loop (%[[K:.*]]) = (%[[C0]])
// TLOOP-SAME: to (%[[DIM_A__1]]) step (%[[C16]])
// TLOOP-SAME: ins (%[[A_SUB_:.*]] = %[[A_SUB]]: [[TY]],
// TLOOP-SAME: %[[B_SUB_:.*]] = %[[B_SUB]]: [[TY]])
diff --git a/mlir/test/Dialect/Linalg/fusion-tensor.mlir b/mlir/test/Dialect/Linalg/fusion-tensor.mlir
index 3146b4194f828..068b875f3b7cd 100644
--- a/mlir/test/Dialect/Linalg/fusion-tensor.mlir
+++ b/mlir/test/Dialect/Linalg/fusion-tensor.mlir
@@ -40,6 +40,48 @@ func @add_mul_fusion(%arg0: tensor<?x?xf32>, %arg1 : tensor<?x?xf32>, %arg2 : te
// -----
+// CHECK-DAG: [[$MAP0:#[a-zA-Z0-9_]*]] = affine_map<(d0, d1) -> (d0, d1)>
+// CHECK-DAG: [[$MAP1:#[a-zA-Z0-9_]*]] = affine_map<(d0, d1) -> ()>
+#map0 = affine_map<(d0, d1) -> (d0, d1)>
+#map1 = affine_map<(d0, d1) -> ()>
+
+// CHECK-LABEL: @scalar_add_mul_fusion
+func @scalar_add_mul_fusion(%arg0: tensor<?x?xf32>, %arg1 : f32, %arg2 : f32) -> tensor<?x?xf32>
+{
+ %c0 = constant 0 : index
+ %c1 = constant 1 : index
+ %0 = memref.dim %arg0, %c0 : tensor<?x?xf32>
+ %1 = memref.dim %arg0, %c1 : tensor<?x?xf32>
+ %2 = linalg.init_tensor [%0, %1] : tensor<?x?xf32>
+ %3 = linalg.generic {indexing_maps = [#map0, #map1, #map0], iterator_types = ["parallel", "parallel"]}
+ ins(%arg0, %arg1 : tensor<?x?xf32>, f32)
+ outs(%2 : tensor<?x?xf32>) {
+ ^bb0(%arg3: f32, %arg4: f32, %arg5: f32): // no predecessors
+ %4 = addf %arg3, %arg4 : f32
+ linalg.yield %4 : f32
+ } -> tensor<?x?xf32>
+ // CHECK: linalg.generic {
+ // CHECK-SAME: indexing_maps = {{\[}}[[$MAP0]], [[$MAP1]], [[$MAP1]], [[$MAP0]]{{\]}}
+ %4 = linalg.generic {indexing_maps = [#map0, #map1, #map0], iterator_types = ["parallel", "parallel"]}
+ ins(%3, %arg2 : tensor<?x?xf32>, f32)
+ outs(%2 : tensor<?x?xf32>) {
+ // CHECK: ^{{[a-zA-Z0-9_]*}}
+ // CHECK-SAME: [[ARG3:%[a-zA-Z0-9_]*]]
+ // CHECK-SAME: [[ARG4:%[a-zA-Z0-9_]*]]
+ // CHECK-SAME: [[ARG5:%[a-zA-Z0-9_]*]]
+ ^bb0(%arg5: f32, %arg6: f32, %arg7: f32): // no predecessors
+ // CHECK: [[T1:%[a-zA-Z0-9_]*]] = addf [[ARG3]], [[ARG4]]
+ // CHECK-NOT: linalg.yield
+ // CHECK: mulf [[T1]], [[ARG5]]
+ // CHECK: linalg.yield
+ %5 = mulf %arg5, %arg6 : f32
+ linalg.yield %5 : f32
+ } -> tensor<?x?xf32>
+ return %4 : tensor<?x?xf32>
+}
+
+// -----
+
// CHECK-DAG: [[$MAP0:#[a-zA-Z0-9_]*]] = affine_map<(d0, d1) -> (d0, d1)>
// CHECK-DAG: [[$MAP1:#[a-zA-Z0-9_]*]] = affine_map<(d0, d1) -> (d1, d0)>
#map0 = affine_map<(d0, d1) -> (d0, d1)>
diff --git a/mlir/test/Dialect/Linalg/invalid.mlir b/mlir/test/Dialect/Linalg/invalid.mlir
index 9dc20a46dde0a..ac56add661bcc 100644
--- a/mlir/test/Dialect/Linalg/invalid.mlir
+++ b/mlir/test/Dialect/Linalg/invalid.mlir
@@ -96,7 +96,7 @@ func @generic_wrong_dim_in_map(%arg0: memref<1xi32>) {
// -----
func @generic_one_d_view(%arg0: memref<?xf32, affine_map<(i)[off]->(off + i)>>) {
- // expected-error @+1 {{expected shaped value rank (1) to match the result rank of indexing_map #0 (2)}}
+ // expected-error @+1 {{expected operand rank (1) to match the result rank of indexing_map #0 (2)}}
linalg.generic {
indexing_maps = [ affine_map<() -> (0, 0)> ],
iterator_types = []}
@@ -108,6 +108,21 @@ func @generic_one_d_view(%arg0: memref<?xf32, affine_map<(i)[off]->(off + i)>>)
// -----
+func @generic_scalar_view(%arg0: memref<?xf32, affine_map<(i)[off]->(off + i)>>) {
+ %cst = constant 0.0 : f32
+ // expected-error @+1 {{expected operand rank (0) to match the result rank of indexing_map #0 (1)}}
+ linalg.generic {
+ indexing_maps = [ affine_map<() -> (0)>, affine_map<() -> (0, 0)> ],
+ iterator_types = []}
+ ins(%cst : f32)
+ outs(%arg0 : memref<?xf32, affine_map<(i)[off]->(off + i)>>) {
+ ^bb(%0 : f32, %1 : f32):
+ linalg.yield %0: f32
+ }
+}
+
+// -----
+
func @generic_result_0_element_type(%arg0: memref<?xf32, affine_map<(i)[off]->(off + i)>>) {
// expected-error @+7 {{'linalg.yield' op type of yield operand 1 ('i4') doesn't match the element type of the enclosing linalg.generic op ('f32')}}
linalg.generic {
@@ -174,7 +189,7 @@ func @generic_empty_region(%arg0: memref<f32>) {
// -----
func @generic_mismatched_num_arguments(%arg0: memref<f32>) {
- // expected-error @+1 {{expected as many non-induction variable region arguments as the number of shaped operands}}
+ // expected-error @+1 {{expected as many non-induction variable region arguments as the number of input/output operands}}
linalg.generic {
indexing_maps = [ affine_map<() -> ()>, affine_map<() -> ()> ],
iterator_types = []}
@@ -186,8 +201,8 @@ func @generic_mismatched_num_arguments(%arg0: memref<f32>) {
// -----
-func @generic_block_arg_type(%arg0: memref<f32>) {
- // expected-error @+1 {{expected type of bb argument #0 ('i1') to match element type of corresponding shaped operand ('f32')}}
+func @generic_shaped_operand_block_arg_type(%arg0: memref<f32>) {
+ // expected-error @+1 {{expected type of bb argument #0 ('i1') to match element or self type of the corresponding operand ('f32')}}
linalg.generic {
indexing_maps = [ affine_map<() -> ()> ],
iterator_types = []}
@@ -199,8 +214,21 @@ func @generic_block_arg_type(%arg0: memref<f32>) {
// -----
+func @generic_scalar_operand_block_arg_type(%arg0: f32) {
+ // expected-error @+1 {{expected type of bb argument #0 ('i1') to match element or self type of the corresponding operand ('f32')}}
+ linalg.generic {
+ indexing_maps = [ affine_map<() -> ()> ],
+ iterator_types = []}
+ outs(%arg0 : f32) {
+ ^bb(%i: i1):
+ linalg.yield %i : i1
+ }
+}
+
+// -----
+
func @indexed_generic_block_arg_count(%arg0: memref<?xf32>) {
- // expected-error @+1 {{expected as many non-induction variable region arguments as the number of shaped operands}}
+ // expected-error @+1 {{expected as many non-induction variable region arguments as the number of input/output operands}}
linalg.indexed_generic {
indexing_maps = [ affine_map<(i) -> (i)> ],
iterator_types = ["parallel"]}
@@ -226,7 +254,7 @@ func @indexed_generic_block_induction_var_arg_type(%arg0: memref<?xf32>) {
// -----
func @indexed_generic_block_arg_type(%arg0: memref<?xf32>) {
- // expected-error @+1 {{expected type of bb argument #1 ('i1') to match element type of corresponding shaped operand ('f32')}}
+ // expected-error @+1 {{expected type of bb argument #1 ('i1') to match element or self type of the corresponding operand ('f32')}}
linalg.indexed_generic {
indexing_maps = [ affine_map<(d0) -> (d0)> ],
iterator_types = ["parallel"]}
@@ -239,7 +267,7 @@ func @indexed_generic_block_arg_type(%arg0: memref<?xf32>) {
// -----
func @indexed_generic_arg_count(%arg0: memref<f32>) {
- // expected-error @+1 {{expected as many non-induction variable region arguments as the number of shaped operands}}
+ // expected-error @+1 {{expected as many non-induction variable region arguments as the number of input/output operands}}
linalg.indexed_generic {
indexing_maps = [ affine_map<()[] -> ()> ],
iterator_types = []}
@@ -401,7 +429,7 @@ func @reshape(%arg0: memref<?x?x?xf32>) {
func @pooling_rank_mismatch(%arg0: memref<?x?x?xf32>,
%arg1: memref<2x3xf32>,
%arg2: memref<?x?x?xf32>) {
- // expected-error @+1 {{expected shaped value rank (2) to match the result rank of indexing_map #1 (3)}}
+ // expected-error @+1 {{expected operand rank (2) to match the result rank of indexing_map #1 (3)}}
linalg.pooling_max(%arg0, %arg1, %arg2) {strides = [2, 1, 2]}:
memref<?x?x?xf32>, memref<2x3xf32>, memref<?x?x?xf32>
return
@@ -410,7 +438,7 @@ func @pooling_rank_mismatch(%arg0: memref<?x?x?xf32>,
// -----
func @named_ops(%a3: memref<?x?x?xf32>, %b3: memref<?x?xf32>, %c3: memref<?x?x?xf32>) {
- // expected-error @+1 {{expected shaped value rank (2) to match the result rank of indexing_map #1 (3)}}
+ // expected-error @+1 {{expected operand rank (2) to match the result rank of indexing_map #1 (3)}}
linalg.batch_matmul ins(%a3, %b3: memref<?x?x?xf32>, memref<?x?xf32>)
outs(%c3 : memref<?x?x?xf32>)
return
@@ -714,7 +742,7 @@ func @illegal_fill_tensor_with_memref_return
// -----
func @invalid_static_matmul(%arg0: memref<2x4xf32>, %arg1: memref<3x4xf32>, %arg2: memref<2x4xf32>) {
- // expected-error @+1 {{inferred shaped operand #1 has shape's dimension #0 to be 4, but found 3}}
+ // expected-error @+1 {{inferred input/output operand #1 has shape's dimension #0 to be 4, but found 3}}
linalg.matmul ins(%arg0, %arg1 : memref<2x4xf32>, memref<3x4xf32>)
outs(%arg2 :memref<2x4xf32>)
return
@@ -723,7 +751,7 @@ func @invalid_static_matmul(%arg0: memref<2x4xf32>, %arg1: memref<3x4xf32>, %arg
// -----
func @invalid_static_2d_conv(%input : memref<1x3x4x2xf32>, %filter: memref<3x2x2x1xf32>, %output: memref<1x2x3x1xf32>) {
- // expected-error @+1 {{inferred shaped operand #0 has shape's dimension #1 to be greater than or equal to 4, but found 3}}
+ // expected-error @+1 {{inferred input/output operand #0 has shape's dimension #1 to be greater than or equal to 4, but found 3}}
linalg.conv_2d_input_nhwc_filter_hwcf
{ dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>}
ins(%input, %filter : memref<1x3x4x2xf32>, memref<3x2x2x1xf32>)
diff --git a/mlir/test/Dialect/Linalg/loops.mlir b/mlir/test/Dialect/Linalg/loops.mlir
index c469160d6e864..7bc8dc5b7acd7 100644
--- a/mlir/test/Dialect/Linalg/loops.mlir
+++ b/mlir/test/Dialect/Linalg/loops.mlir
@@ -975,6 +975,30 @@ func @generic_op_zero_rank(%arg0: memref<f32>, %arg1: memref<3x4xf32>)
// CHECKPARALLEL: %[[a:.*]] = memref.load %[[ARG0]][]
// CHECKPARALLEL: store %[[a]], %[[ARG1]][%[[i]], %[[j]]]
+func @generic_op_scalar(%arg0: f32, %arg1: memref<3x4xf32>)
+{
+ linalg.generic #trait_broadcast
+ ins(%arg0 : f32)
+ outs(%arg1 : memref<3x4xf32>) {
+ ^bb(%a: f32, %b: f32) :
+ linalg.yield %a : f32
+ }
+ return
+}
+
+// CHECK-LABEL: @generic_op_scalar
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]*]]: f32
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]*]]: memref<3x4xf32>
+// CHECK: scf.for %[[i:.*]] = {{.*}}
+// CHECK: scf.for %[[j:.*]] = {{.*}}
+// CHECK: store %[[ARG0]], %[[ARG1]][%[[i]], %[[j]]]
+
+// CHECKPARALLEL-LABEL: @generic_op_scalar
+// CHECKPARALLEL-SAME: %[[ARG0:[a-zA-Z0-9_]*]]: f32
+// CHECKPARALLEL-SAME: %[[ARG1:[a-zA-Z0-9_]*]]: memref<3x4xf32>
+// CHECKPARALLEL: scf.parallel (%[[i:[a-zA-Z0-9_]*]], %[[j:[a-zA-Z0-9_]*]])
+// CHECKPARALLEL: store %[[ARG0]], %[[ARG1]][%[[i]], %[[j]]]
+
func @generic_index_op_zero_rank(%arg0: memref<i32>, %arg1: memref<3x4xi32>)
{
linalg.generic #trait_broadcast
diff --git a/mlir/test/Dialect/Linalg/reshape_fusion.mlir b/mlir/test/Dialect/Linalg/reshape_fusion.mlir
index 4a3d54c3dc3a4..1cf4e6bbe9489 100644
--- a/mlir/test/Dialect/Linalg/reshape_fusion.mlir
+++ b/mlir/test/Dialect/Linalg/reshape_fusion.mlir
@@ -2,37 +2,42 @@
// RUN: mlir-opt %s -linalg-fusion-for-tensor-ops="allow-folding-unit-dim-reshapes=true" -split-input-file | FileCheck %s --check-prefix=FOLDUNITDIM
#map0 = affine_map<(d0, d1, d2) -> (d2, d0, d1)>
#map1 = affine_map<(d0, d1, d2) -> (d1, d2, d0)>
+#map2 = affine_map<(d0, d1, d2) -> ()>
func @generic_op_reshape_producer_fusion(%arg0 : tensor<?x?x4x?xf32>,
- %arg1 : tensor<?x?x?xf32>) ->
+ %arg1 : tensor<?x?x?xf32>,
+ %arg2 : f32) ->
tensor<?x?x?xf32>
{
%0 = linalg.tensor_collapse_shape %arg0 [[0], [1, 2], [3]] :
tensor<?x?x4x?xf32> into tensor<?x?x?xf32>
%1 = linalg.generic {
- indexing_maps = [#map0, #map1, #map1],
+ indexing_maps = [#map0, #map1, #map2, #map1],
iterator_types = ["parallel", "parallel", "parallel"]}
- ins(%0, %arg1 : tensor<?x?x?xf32>, tensor<?x?x?xf32>)
+ ins(%0, %arg1, %arg2 : tensor<?x?x?xf32>, tensor<?x?x?xf32>, f32)
outs(%0 : tensor<?x?x?xf32>) {
- ^bb0(%arg3: f32, %arg4: f32, %s: f32): // no predecessors
+ ^bb0(%arg3: f32, %arg4: f32, %arg5: f32, %s: f32): // no predecessors
%1 = mulf %arg3, %arg4 : f32
- linalg.yield %1 : f32
+ %2 = addf %1, %arg5 : f32
+ linalg.yield %2 : f32
} -> tensor<?x?x?xf32>
return %1 : tensor<?x?x?xf32>
}
// CHECK-DAG: #[[MAP5:.+]] = affine_map<(d0, d1, d2, d3) -> (d3, d0, d1, d2)>
// CHECK-DAG: #[[MAP6:.+]] = affine_map<(d0, d1, d2, d3) -> (d2, d3, d0, d1)>
+// CHECK-DAG: #[[MAP7:.+]] = affine_map<(d0, d1, d2, d3) -> ()>
// CHECK: func @generic_op_reshape_producer_fusion
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: tensor<?x?x4x?xf32>
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: tensor<?x?x?xf32>
+// CHECK-SAME: %[[ARG2:[a-zA-Z0-9_]+]]: f32
// CHECK: %[[T0:.+]] = linalg.tensor_collapse_shape %[[ARG0]]
// CHECK-SAME: [0], [1, 2], [3]
// CHECK: %[[T1:.+]] = linalg.tensor_expand_shape %[[ARG1]]
// CHECK-SAME: [0], [1], [2, 3]
// CHECK: %[[T3:.+]] = linalg.generic
-// CHECK-SAME: indexing_maps = [#[[MAP5]], #[[MAP6]], #[[MAP6]]]
+// CHECK-SAME: indexing_maps = [#[[MAP5]], #[[MAP6]], #[[MAP7]], #[[MAP6]]]
// CHECK-SAME: ["parallel", "parallel", "parallel", "parallel"]
-// CHECK-SAME: ins(%[[ARG0]], %[[T1]] : tensor<?x?x4x?xf32>, tensor<?x?x?x4xf32>)
+// CHECK-SAME: ins(%[[ARG0]], %[[T1]], %[[ARG2]] : tensor<?x?x4x?xf32>, tensor<?x?x?x4xf32>, f32)
// CHECK-SAME: outs(%{{.+}} : tensor<?x?x?x4xf32>)
// CHECK: %[[T4:.+]] = linalg.tensor_collapse_shape %[[T3]]
// CHECK-SAME: [0], [1], [2, 3]
@@ -42,18 +47,21 @@ func @generic_op_reshape_producer_fusion(%arg0 : tensor<?x?x4x?xf32>,
// -----
#map0 = affine_map<(d0, d1) -> (d0, d1)>
+#map1 = affine_map<(d0, d1) -> ()>
func @generic_op_reshape_consumer_fusion(%arg0 : tensor<?x?xf32>,
- %arg1 : tensor<?x?xf32>) ->
+ %arg1 : tensor<?x?xf32>,
+ %arg2 : f32) ->
tensor<?x4x?x5xf32>
{
%0 = linalg.generic {
- indexing_maps = [#map0, #map0, #map0],
+ indexing_maps = [#map0, #map0, #map1, #map0],
iterator_types = ["parallel", "parallel"]}
- ins(%arg0, %arg1 : tensor<?x?xf32>, tensor<?x?xf32>)
+ ins(%arg0, %arg1, %arg2 : tensor<?x?xf32>, tensor<?x?xf32>, f32)
outs(%arg0 : tensor<?x?xf32>) {
- ^bb0(%arg3: f32, %arg4: f32, %s: f32): // no predecessors
+ ^bb0(%arg3: f32, %arg4: f32, %arg5: f32, %s: f32): // no predecessors
%1 = mulf %arg3, %arg4 : f32
- linalg.yield %1 : f32
+ %2 = addf %1, %arg5 : f32
+ linalg.yield %2 : f32
} -> tensor<?x?xf32>
%1 = linalg.tensor_expand_shape %0 [[0], [1, 2, 3]] :
tensor<?x?xf32> into tensor<?x4x?x5xf32>
@@ -61,9 +69,12 @@ func @generic_op_reshape_consumer_fusion(%arg0 : tensor<?x?xf32>,
}
// CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
+// CHECK-DAG: #[[MAP3:.+]] = affine_map<(d0, d1, d2, d3) -> ()>
+
// CHECK: func @generic_op_reshape_consumer_fusion
-// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: tensor<?x?xf32>,
-// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: tensor<?x?xf32>)
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: tensor<?x?xf32>
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: tensor<?x?xf32>
+// CHECK-SAME: %[[ARG2:[a-zA-Z0-9_]+]]: f32
// CHECK: %[[T0:.+]] = linalg.tensor_expand_shape %[[ARG0]]
// CHECK-SAME: [0], [1, 2, 3]
// CHECK-SAME: tensor<?x?xf32> into tensor<?x4x?x5xf32>
@@ -71,9 +82,9 @@ func @generic_op_reshape_consumer_fusion(%arg0 : tensor<?x?xf32>,
// CHECK-SAME: [0], [1, 2, 3]
// CHECK-SAME: tensor<?x?xf32> into tensor<?x4x?x5xf32>
// CHECK: %[[T3:.+]] = linalg.generic
-// CHECK-SAME: indexing_maps = [#[[MAP2]], #[[MAP2]], #[[MAP2]]]
+// CHECK-SAME: indexing_maps = [#[[MAP2]], #[[MAP2]], #[[MAP3]], #[[MAP2]]]
// CHECK-SAME: ["parallel", "parallel", "parallel", "parallel"]
-// CHECK-SAME: ins(%[[T0]], %[[T1]] : tensor<?x4x?x5xf32>, tensor<?x4x?x5xf32>)
+// CHECK-SAME: ins(%[[T0]], %[[T1]], %[[ARG2]] : tensor<?x4x?x5xf32>, tensor<?x4x?x5xf32>, f32)
// CHECK-SAME: outs(%{{.+}} : tensor<?x4x?x5xf32>)
// CHECK: return %[[T3]] : tensor<?x4x?x5xf32>
diff --git a/mlir/test/Dialect/Linalg/roundtrip.mlir b/mlir/test/Dialect/Linalg/roundtrip.mlir
index 9ca383a66cc13..4675f82852b7d 100644
--- a/mlir/test/Dialect/Linalg/roundtrip.mlir
+++ b/mlir/test/Dialect/Linalg/roundtrip.mlir
@@ -316,6 +316,7 @@ func @pooling_sum(%arg0: memref<?x?x?xf32>,
#accesses_0 = [
affine_map<(i, j, k) -> (j, i)>,
+ affine_map<(i, j, k) -> ()>,
affine_map<(i, j, k) -> (i, k, i + j)>
]
@@ -327,34 +328,34 @@ func @pooling_sum(%arg0: memref<?x?x?xf32>,
func @generic(%arg0: memref<?x?xvector<3x4xi4>, offset: ?, strides: [?, 1]>,
%arg1: memref<?x?x?xf32, offset: ?, strides: [?, ?, 1]>) {
+ %cst = constant 0.0 : f32
linalg.generic #trait_0
- ins(%arg0 : memref<?x?xvector<3x4xi4>, offset: ?, strides: [?, 1]>)
+ ins(%arg0, %cst : memref<?x?xvector<3x4xi4>, offset: ?, strides: [?, 1]>, f32)
outs(%arg1 : memref<?x?x?xf32, offset: ?, strides: [?, ?, 1]>)
attrs = {foo = 1} {
- ^bb(%0: vector<3x4xi4>, %1: f32) :
- %f0 = constant 0.0 : f32
- linalg.yield %f0 : f32
+ ^bb(%0: vector<3x4xi4>, %1: f32, %2: f32) :
+ linalg.yield %1 : f32
}
return
}
// CHECK-LABEL: func @generic
// CHECK: linalg.generic {
-// CHECK-SAME: indexing_maps = [#{{[0-9a-z]*}}, #{{[0-9a-z]*}}],
+// CHECK-SAME: indexing_maps = [#{{[0-9a-z]*}}, #{{[0-9a-z]*}}, #{{[0-9a-z]*}}],
// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel"],
// CHECK-SAME: library_call = "some_external_function_name_1"}
-// CHECK-SAME: ins({{.*}} : memref<?x?xvector<3x4xi4>, #[[$strided2D]]>)
+// CHECK-SAME: ins({{.*}}, {{.*}} : memref<?x?xvector<3x4xi4>, #[[$strided2D]]>, f32)
// CHECK-SAME: outs({{.*}} : memref<?x?x?xf32, #[[$strided3D]]>)
// CHECK-SAME: {foo = 1 : i64}
func @generic_with_tensor_input(%arg0: tensor<?x?xvector<3x4xi4>>,
%arg1: memref<?x?x?xf32, offset: ?, strides: [?, ?, 1]>) {
+ %cst = constant 0.0 : f32
linalg.generic #trait_0
- ins(%arg0 : tensor<?x?xvector<3x4xi4>>)
+ ins(%arg0, %cst : tensor<?x?xvector<3x4xi4>>, f32)
outs(%arg1 : memref<?x?x?xf32, offset: ?, strides: [?, ?, 1]>)
attrs = {foo = 1} {
- ^bb(%0: vector<3x4xi4>, %1: f32) :
- %f0 = constant 0.0 : f32
- linalg.yield %f0 : f32
+ ^bb(%0: vector<3x4xi4>, %1: f32, %2: f32) :
+ linalg.yield %1 : f32
}
return
}
@@ -362,7 +363,7 @@ func @generic_with_tensor_input(%arg0: tensor<?x?xvector<3x4xi4>>,
// CHECK: linalg.generic {
// CHECK-SAME: indexing_maps = [#{{.*}}, #{{.*}}], iterator_types = ["parallel", "parallel", "parallel"],
// CHECK-SAME: library_call = "some_external_function_name_1"}
-// CHECK-SAME: ins({{.*}} : tensor<?x?xvector<3x4xi4>>)
+// CHECK-SAME: ins({{.*}}, {{.*}} : tensor<?x?xvector<3x4xi4>>, f32)
// CHECK-SAME: outs({{.*}} : memref<?x?x?xf32, #[[$strided3D]]>)
// CHECK-SAME: {foo = 1 : i64}
diff --git a/mlir/test/Dialect/Linalg/tile-and-pad-tensors.mlir b/mlir/test/Dialect/Linalg/tile-and-pad-tensors.mlir
index 63dc9fba7a85b..a7322a7a34a90 100644
--- a/mlir/test/Dialect/Linalg/tile-and-pad-tensors.mlir
+++ b/mlir/test/Dialect/Linalg/tile-and-pad-tensors.mlir
@@ -8,6 +8,7 @@
func @matmul_tensors(
%arg0: tensor<?x?xi8>, %arg1: tensor<?x?xi8>, %arg2: tensor<?x?xi32>)
-> tensor<?x?xi32> {
+// CHECK: %[[C0:.*]] = constant 0 : index
// CHECK: %[[TD0:.*]] = scf.for {{.*}} to {{.*}} step {{.*}} iter_args(%[[TC0:.*]] = %[[TC]]) -> (tensor<?x?xi32>) {
// CHECK: %[[TD1:.*]] = scf.for {{.*}} to {{.*}} step {{.*}} iter_args(%[[TC1:.*]] = %[[TC0]]) -> (tensor<?x?xi32>) {
// CHECK: %[[TD2:.*]] = scf.for {{.*}} to {{.*}} step {{.*}} iter_args(%[[TC2:.*]] = %[[TC1]]) -> (tensor<?x?xi32>) {
@@ -19,11 +20,11 @@ func @matmul_tensors(
// CHECK-NOT: linalg.matmul {{.*}} tensor<?x?xi8>
// Padding injects static information.
-// CHECK: %[[pA:.*]] = linalg.pad_tensor %[[sTA]] low[%c0, %c0] high[%{{.*}}, %{{.*}}]
+// CHECK: %[[pA:.*]] = linalg.pad_tensor %[[sTA]] low[%[[C0]], %[[C0]]] high[%{{.*}}, %{{.*}}]
// CHECK: : tensor<?x?xi8> to tensor<2x4xi8>
-// CHECK: %[[pB:.*]] = linalg.pad_tensor %[[sTB]] low[%c0, %c0] high[%{{.*}}, %{{.*}}]
+// CHECK: %[[pB:.*]] = linalg.pad_tensor %[[sTB]] low[%[[C0]], %[[C0]]] high[%{{.*}}, %{{.*}}]
// CHECK: : tensor<?x?xi8> to tensor<4x3xi8>
-// CHECK: %[[pC:.*]] = linalg.pad_tensor %[[sTC]] low[%c0, %c0] high[%{{.*}}, %{{.*}}]
+// CHECK: %[[pC:.*]] = linalg.pad_tensor %[[sTC]] low[%[[C0]], %[[C0]]] high[%{{.*}}, %{{.*}}]
// CHECK: : tensor<?x?xi32> to tensor<2x3xi32>
// CHECK: %[[pD:.*]] = linalg.matmul_i8_i8_i32 ins(%[[pA]], %[[pB]] : tensor<2x4xi8>, tensor<4x3xi8>)
// CHECK-SAME: outs(%[[pC]] : tensor<2x3xi32>) -> tensor<2x3xi32>
@@ -41,6 +42,41 @@ func @matmul_tensors(
return %0 : tensor<?x?xi32>
}
+// CHECK-LABEL: func @generic_scalar_and_tensor(
+// CHECK-SAME: %[[TC:[0-9a-z]+]]: tensor<?x?x?xf32>
+// CHECK-SAME: %[[VAL:[0-9a-z]+]]: f32) -> tensor<?x?x?xf32> {
+func @generic_scalar_and_tensor(
+ %arg0: tensor<?x?x?xf32>, %arg1: f32)
+ -> tensor<?x?x?xf32> {
+// CHECK: %[[C0:.*]] = constant 0 : index
+// CHECK: %[[TD0:.*]] = scf.for {{.*}} to {{.*}} step {{.*}} iter_args(%[[TC0:.*]] = %[[TC]]) -> (tensor<?x?x?xf32>) {
+// CHECK: %[[TD1:.*]] = scf.for {{.*}} to {{.*}} step {{.*}} iter_args(%[[TC1:.*]] = %[[TC0]]) -> (tensor<?x?x?xf32>) {
+// CHECK: %[[TD2:.*]] = scf.for {{.*}} to {{.*}} step {{.*}} iter_args(%[[TC2:.*]] = %[[TC1]]) -> (tensor<?x?x?xf32>) {
+// CHECK: %[[sTC:.*]] = subtensor %[[TC2]][{{.*}}] : tensor<?x?x?xf32> to tensor<?x?x?xf32>
+
+// Padding injects static information.
+// CHECK: %[[pC:.*]] = linalg.pad_tensor %[[sTC]] low[%[[C0]], %[[C0]], %[[C0]]] high[%{{.*}}, %{{.*}}, %{{.*}}]
+// CHECK: : tensor<?x?x?xf32> to tensor<2x3x4xf32>
+// CHECK: %[[pD:.*]] = linalg.generic
+// CHECK-SAME: ins(%[[VAL]] : f32) outs(%[[pC]] : tensor<2x3x4xf32>)
+// CHECK: %[[sTD:.*]] = subtensor %[[pD]][0, 0, 0] [%{{.*}}, %{{.*}}, %{{.*}}] [1, 1, 1] : tensor<2x3x4xf32> to tensor<?x?x?xf32>
+// CHECK: %[[TD:.*]] = subtensor_insert %[[sTD]] into %[[TC2]][{{.*}}] : tensor<?x?x?xf32> into tensor<?x?x?xf32>
+// CHECK: scf.yield %[[TD]] : tensor<?x?x?xf32>
+// CHECK: scf.yield %[[TD2]] : tensor<?x?x?xf32>
+// CHECK: scf.yield %[[TD1]] : tensor<?x?x?xf32>
+ %0 = linalg.generic {
+ indexing_maps = [ affine_map<(d0, d1, d2) -> ()>,
+ affine_map<(d0, d1, d2) -> (d0, d1, d2)> ],
+ iterator_types = ["parallel", "parallel", "parallel"]}
+ {__internal_linalg_transform__ = "tile-and-pad"}
+ ins(%arg1 : f32)
+ outs(%arg0: tensor<?x?x?xf32>) {
+ ^bb(%0: f32, %1: f32) :
+ linalg.yield %0 : f32
+ } -> tensor<?x?x?xf32>
+ return %0 : tensor<?x?x?xf32>
+}
+
// CHECK-1DIM-TILE: func @matmul_tensors(
// CHECK-1DIM-TILE: %[[TA:[0-9a-z]+]]: tensor<?x?xi8>
// CHECK-1DIM-TILE: %[[TB:[0-9a-z]+]]: tensor<?x?xi8>
@@ -65,6 +101,7 @@ func @matmul_partially_padded_tensors(
// CHECK-1DIM-TILE-SAME: %[[TA:[0-9a-z]+]]: tensor<?x8xi8>
// CHECK-1DIM-TILE-SAME: %[[TB:[0-9a-z]+]]: tensor<8x?xi8>
// CHECK-1DIM-TILE-SAME: %[[TC:[0-9a-z]+]]: tensor<?x?xi32>) -> tensor<?x?xi32> {
+// CHECK-1DIM-TILE: %[[C0:.*]] = constant 0 : index
// CHECK-1DIM-TILE: %[[TD0:.*]] = scf.for {{.*}} to {{.*}} step {{.*}} iter_args(%[[TC0:.*]] = %[[TC]]) -> (tensor<?x?xi32>) {
// CHECK-1DIM-TILE: %[[TD1:.*]] = scf.for {{.*}} to {{.*}} step {{.*}} iter_args(%[[TC1:.*]] = %[[TC0]]) -> (tensor<?x?xi32>) {
// CHECK-1DIM-TILE: %[[sTA:.*]] = subtensor %[[TA]][{{.*}}] : tensor<?x8xi8> to tensor<?x8xi8>
@@ -72,11 +109,11 @@ func @matmul_partially_padded_tensors(
// CHECK-1DIM-TILE: %[[sTB:.*]] = subtensor %[[TB]][{{.*}}] : tensor<8x?xi8> to tensor<8x?xi8>
// CHECK-1DIM-TILE: %[[sTBc:.*]] = tensor.cast %[[sTB]] : tensor<8x?xi8> to tensor<?x?xi8>
// CHECK-1DIM-TILE: %[[sTC:.*]] = subtensor %[[TC1]][{{.*}}] : tensor<?x?xi32> to tensor<?x?xi32>
-// CHECK-1DIM-TILE: %[[pA:.*]] = linalg.pad_tensor %[[sTAc]] low[%c0, %c0] high[%{{.*}}, %{{.*}}]
+// CHECK-1DIM-TILE: %[[pA:.*]] = linalg.pad_tensor %[[sTAc]] low[%[[C0]], %[[C0]]] high[%{{.*}}, %{{.*}}]
// CHECK-1DIM-TILE: : tensor<?x?xi8> to tensor<2x8xi8>
-// CHECK-1DIM-TILE: %[[pB:.*]] = linalg.pad_tensor %[[sTBc]] low[%c0, %c0] high[%{{.*}}, %{{.*}}]
+// CHECK-1DIM-TILE: %[[pB:.*]] = linalg.pad_tensor %[[sTBc]] low[%[[C0]], %[[C0]]] high[%{{.*}}, %{{.*}}]
// CHECK-1DIM-TILE: : tensor<?x?xi8> to tensor<8x3xi8>
-// CHECK-1DIM-TILE: %[[pC:.*]] = linalg.pad_tensor %[[sTC]] low[%c0, %c0] high[%{{.*}}, %{{.*}}]
+// CHECK-1DIM-TILE: %[[pC:.*]] = linalg.pad_tensor %[[sTC]] low[%[[C0]], %[[C0]]] high[%{{.*}}, %{{.*}}]
// CHECK-1DIM-TILE: : tensor<?x?xi32> to tensor<2x3xi32>
// CHECK-1DIM-TILE: %[[pD:.*]] = linalg.matmul_i8_i8_i32 ins(%[[pA]], %[[pB]] : tensor<2x8xi8>, tensor<8x3xi8>)
// CHECK-1DIM-TILE: outs(%[[pC]] : tensor<2x3xi32>) -> tensor<2x3xi32>
diff --git a/mlir/test/Dialect/Linalg/vectorization.mlir b/mlir/test/Dialect/Linalg/vectorization.mlir
index 5a7110fec18bf..ff594a0f32e79 100644
--- a/mlir/test/Dialect/Linalg/vectorization.mlir
+++ b/mlir/test/Dialect/Linalg/vectorization.mlir
@@ -136,6 +136,23 @@ func @vectorization_test_2(%A: memref<8x16xf32>, %B: memref<16x32xf32>,
// -----
+// CHECK-LABEL: func @test_vectorize_scalar_input
+func @test_vectorize_scalar_input(%A : memref<8x16xf32>, %arg0 : f32) {
+ // CHECK: %[[V:.*]] = vector.broadcast {{.*}} : f32 to vector<8x16xf32>
+ // CHECK: vector.transfer_write %[[V]], {{.*}} : vector<8x16xf32>, memref<8x16xf32>
+ linalg.generic {
+ indexing_maps = [affine_map<(m, n) -> ()>, affine_map<(m, n) -> (m, n)>],
+ iterator_types = ["parallel", "parallel"]}
+ ins(%arg0 : f32)
+ outs(%A: memref<8x16xf32>) {
+ ^bb(%0: f32, %1: f32) :
+ linalg.yield %0 : f32
+ }
+ return
+}
+
+// -----
+
// CHECK-LABEL: func @test_vectorize_fill
func @test_vectorize_fill(%A : memref<8x16xf32>, %arg0 : f32) {
// CHECK: %[[V:.*]] = vector.broadcast {{.*}} : f32 to vector<8x16xf32>
diff --git a/mlir/test/lib/Dialect/Linalg/TestLinalgFusionTransforms.cpp b/mlir/test/lib/Dialect/Linalg/TestLinalgFusionTransforms.cpp
index 26e4f6a95b02e..1e68c6c927716 100644
--- a/mlir/test/lib/Dialect/Linalg/TestLinalgFusionTransforms.cpp
+++ b/mlir/test/lib/Dialect/Linalg/TestLinalgFusionTransforms.cpp
@@ -162,8 +162,7 @@ static LogicalResult fuseLinalgOpsGreedily(FuncOp f) {
*originalOpInLinalgOpsVector = info->fusedProducer.getOperation();
changed = true;
}
- } else {
- assert(opOperand->get().getType().isa<RankedTensorType>());
+ } else if (opOperand->get().getType().isa<RankedTensorType>()) {
// Tile and Fuse tensor input.
if (opOperand->getOperandNumber() >= linalgOp.getNumInputs())
continue;
diff --git a/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp b/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp
index 16eb79ffe4fb6..9159dcfe5d642 100644
--- a/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp
+++ b/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp
@@ -533,7 +533,7 @@ static void applyAffineMinSCFCanonicalizationPatterns(FuncOp funcOp) {
// For now, just assume it is the zero of type.
// In the future, it should be the zero of type + op.
static Value getNeutralOfLinalgOp(OpBuilder &b, OpOperand &op) {
- auto t = getElementTypeOrSelf(op.get().getType());
+ auto t = getElementTypeOrSelf(op.get());
return b.create<ConstantOp>(op.getOwner()->getLoc(), t, b.getZeroAttr(t));
}
@@ -544,7 +544,8 @@ static void applyTileAndPadPattern(FuncOp funcOp, ArrayRef<int64_t> tileSizes) {
linalg::LinalgTilingOptions()
.setTileSizes(tileSizes)
.setPaddingValueComputationFunction(getNeutralOfLinalgOp);
- tilingPattern.add<linalg::LinalgTilingPattern<linalg::MatmulI8I8I32Op>>(
+ tilingPattern.add<linalg::LinalgTilingPattern<linalg::MatmulI8I8I32Op>,
+ linalg::LinalgTilingPattern<linalg::GenericOp>>(
context, linalgTilingOptions,
linalg::LinalgTransformationFilter(
Identifier::get("tile-and-pad", context)));