[Mlir-commits] [mlir] 5b3cb31 - [mlir][linalg] Purge linalg.indexed_generic.
Alexander Belyaev
llvmlistbot at llvm.org
Thu Jun 17 05:46:12 PDT 2021
Author: Alexander Belyaev
Date: 2021-06-17T14:45:37+02:00
New Revision: 5b3cb31edbcf99ef15c2de2d29ad0ff9927ba136
URL: https://github.com/llvm/llvm-project/commit/5b3cb31edbcf99ef15c2de2d29ad0ff9927ba136
DIFF: https://github.com/llvm/llvm-project/commit/5b3cb31edbcf99ef15c2de2d29ad0ff9927ba136.diff
LOG: [mlir][linalg] Purge linalg.indexed_generic.
Differential Revision: https://reviews.llvm.org/D104449
Added:
Modified:
mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
mlir/lib/Conversion/LinalgToStandard/LinalgToStandard.cpp
mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
mlir/lib/Dialect/Linalg/Transforms/Bufferize.cpp
mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp
mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp
mlir/lib/Dialect/Linalg/Transforms/Generalization.cpp
mlir/lib/Dialect/Linalg/Transforms/Loops.cpp
mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
mlir/test/Dialect/Linalg/canonicalize-duplicate-inputs.mlir
mlir/test/Dialect/Linalg/canonicalize.mlir
mlir/test/Dialect/Linalg/invalid.mlir
mlir/test/Dialect/Linalg/roundtrip.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
index 54dda4c4fcb2b..a8e4bbdd69b35 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
@@ -32,7 +32,7 @@ class LinalgStructuredBase_Op<string mnemonic, list<OpTrait> props>
// always be 0 for index-free linalg ops. For IndexedGeneric, this must be
// equal to numLoops.
unsigned getNumPayloadInductionVariables() {
- return isa<IndexedGenericOp>(this->getOperation()) ? getNumLoops() : 0;
+ return 0;
}
// Return whether the op accesses the iteration indices.
@@ -671,140 +671,6 @@ def GenericOp : GenericOpBase<"generic"> {
let hasFolder = 1;
}
-/// GenericOp with Indexing (i.e. multi-for style in which the region is passed
-/// the enclosing loop induction variables)
-def IndexedGenericOp : GenericOpBase<"indexed_generic"> {
- let description = [{
- Indexed Generic Linalg op form where the key properties of the computation
- are specified as attributes. In pretty form, a `linalg.indexed_generic` op
- is written as:
-
- ```mlir
- linalg.indexed_generic #trait_attribute
- ins(%A, %B : memref<?x?xf32, stride_specification>,
- memref<?x?xf32, stride_specification>)
- outs(%C : memref<?x?xf32, stride_specification>)
- attrs = {other-optional-attributes}
- {region}
- ```
-
- Where #trait_attributes is an alias of a dictionary attribute containing:
- - doc [optional]: a documentation string
- - indexing_maps: a list of AffineMapAttr, one AffineMapAttr per each input
- and output view. Such AffineMapAttr specifies the mapping between the
- loops and the indexing within each view.
- - library_call [optional]: a StringAttr containing the name of an
- external library function that the linalg.indexed_generic operation
- maps to. The external library is assumed to be dynamically linked and
- no strong compile-time guarantees are provided. In the absence of such
- a library call, linalg.indexed_generic will always lower to loops.
- - iterator_types: an ArrayAttr they type of the enclosing loops; Each
- element of the list represents and iterator of one of the following
- types:
- parallel, reduction, window
-
- Example:
- Defining a #matmul_trait attribute in MLIR can be done as follows:
-
- ```mlir
- #matmul_accesses = [
- (m, n, k) -> (m, k),
- (m, n, k) -> (k, n),
- (m, n, k) -> (m, n)
- ]
- #matmul_trait = {
- doc = "C(m, n) += A(m, k) * B(k, n)",
- indexing_maps = #matmul_accesses,
- library_call = "linalg_matmul",
- iterator_types = ["parallel", "parallel", "reduction"]
- }
- ```
-
- And can be reused in multiple places as:
-
- ```mlir
- linalg.indexed_generic #matmul_trait
- ins(%A, %B : memref<?x?xf32, stride_specification>,
- memref<?x?xf32, stride_specification>)
- outs(%C : memref<?x?xf32, stride_specification>) {
- (%offset_m: index, %offset_n: index, %offset_k: index,
- %a: f32, %b: f32, %c: f32) :
- "some_optional_computation"(%offset_m, %offset_n, %offset_k)
- %d = mulf %a, %b: f32
- %e = addf %c, %d: f32
- linalg_yield %e : f32
- }
- ```
-
- This may lower to either:
-
- ```mlir
- call @linalg_matmul(%offset_m, %offset_n, %offset_k, %A, %B, %C) :
- (index, index, index,
- memref<?x?xf32, stride_specification>,
- memref<?x?xf32, stride_specification>,
- memref<?x?xf32, stride_specification>)
- -> ()
- ```
-
- or IR resembling:
-
- ```mlir
- scf.for %m = %c0 to %M step %c1 {
- scf.for %n = %c0 to %N step %c1 {
- scf.for %k = %c0 to %K step %c1 {
- %a = load %A[%m, %k] : memref<?x?xf32, stride_specification>
- %b = load %B[%k, %n] : memref<?x?xf32, stride_specification>
- %c = load %C[%m, %n] : memref<?x?xf32, stride_specification>
- "some_optional_computation"(%m, %n, %k)
- %d = mulf %a, %b: f32
- %e = addf %c, %d: f32
- store %d, %C[%m, %n] : memref<?x?x?xf32, stride_specification>
- }
- }
- }
- ```
-
- To allow progressive lowering from the value world (a.k.a tensor values) to
- the buffer world (a.k.a memref values), a `linalg.indexed_generic` op
- allows mixing tensors and buffers operands and tensor results.
-
- ```mlir
- %C = linalg.indexed_generic #trait_attribute
- ins(%A, %B : tensor<?x?xf32>, memref<?x?xf32, stride_specification>)
- outs(%C : tensor<?x?xf32>)
- {other-optional-attributes}
- {region_with_index_arguments}
- -> (tensor<?x?xf32>)
- ```
- }];
-
- let builders = [
- OpBuilder<(ins "TypeRange":$resultTensorTypes, "ValueRange":$inputs,
- "ValueRange":$outputs, "ArrayRef<AffineMap>":$indexingMaps,
- "ArrayRef<StringRef>":$iteratorTypes, "StringRef":$doc,
- "StringRef":$libraryCall,
- CArg<"function_ref<void(OpBuilder &, Location, ValueRange, ValueRange)>",
- "nullptr">)>,
- OpBuilder<(ins "ValueRange":$inputs, "ValueRange":$outputBuffers,
- "ArrayRef<AffineMap>":$indexingMaps, "ArrayRef<StringRef>":$iteratorTypes,
- "StringRef":$doc, "StringRef":$libraryCall,
- CArg<"function_ref<void(OpBuilder &, Location, ValueRange, ValueRange)>",
- "nullptr">)>,
- OpBuilder<(ins "TypeRange":$resultTensorTypes, "ValueRange":$inputs,
- "ValueRange":$outputs, "ArrayRef<AffineMap>":$indexingMaps,
- "ArrayRef<StringRef>":$iteratorTypes,
- CArg<"function_ref<void(OpBuilder &, Location, ValueRange, ValueRange)>",
- "nullptr">)>,
- OpBuilder<(ins "ValueRange":$inputs, "ValueRange":$outputBuffers,
- "ArrayRef<AffineMap>":$indexingMaps, "ArrayRef<StringRef>":$iteratorTypes,
- CArg<"function_ref<void(OpBuilder &, Location, ValueRange, ValueRange)>",
- "nullptr">)>
- ];
- let verifier = [{ return ::verify(*this); }];
-
- let hasCanonicalizer = 1;
-}
//===----------------------------------------------------------------------===//
// Named Linalg ops, implemented as a declarative configurations of generic ops.
diff --git a/mlir/lib/Conversion/LinalgToStandard/LinalgToStandard.cpp b/mlir/lib/Conversion/LinalgToStandard/LinalgToStandard.cpp
index 28dd7bb860c78..6f422e5f629fb 100644
--- a/mlir/lib/Conversion/LinalgToStandard/LinalgToStandard.cpp
+++ b/mlir/lib/Conversion/LinalgToStandard/LinalgToStandard.cpp
@@ -100,10 +100,6 @@ LogicalResult mlir::linalg::LinalgOpToLibraryCallRewrite::matchAndRewrite(
if (isa<CopyOp>(op))
return failure();
- // Canonicalize indexed generic operations before library call conversion.
- if (isa<IndexedGenericOp>(op))
- return failure();
-
auto libraryCallName = getLibraryCallSymbolRef(op, rewriter);
if (!libraryCallName)
return failure();
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index 985a9f7a09a2f..b05a1477982cd 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -525,69 +525,8 @@ void GenericOp::build(
/*doc=*/"",
/*libraryCall=*/"", bodyBuild);
}
-void IndexedGenericOp::build(
- OpBuilder &builder, OperationState &result, TypeRange resultTensorTypes,
- ValueRange inputs, ValueRange outputs, ArrayRef<AffineMap> indexingMaps,
- ArrayRef<StringRef> iteratorTypes, StringRef doc, StringRef libraryCall,
- function_ref<void(OpBuilder &, Location, ValueRange, ValueRange)>
- bodyBuild) {
- build(builder, result, resultTensorTypes, inputs, outputs,
- builder.getAffineMapArrayAttr(indexingMaps),
- builder.getStrArrayAttr(iteratorTypes),
- doc.empty() ? StringAttr() : builder.getStringAttr(doc),
- libraryCall.empty() ? StringAttr()
- : builder.getStringAttr(libraryCall));
- if (!bodyBuild)
- return;
-
- unsigned nLoops = iteratorTypes.size();
- SmallVector<Type, 4> blockArgTypes(nLoops, builder.getIndexType());
- for (ValueRange container : {inputs, outputs})
- for (Value v : container)
- blockArgTypes.push_back(getElementTypeOrSelf(v));
-
- OpBuilder::InsertionGuard guard(builder);
- auto &region = *result.regions.front();
- Block *bodyBlock = builder.createBlock(&region, region.end(), blockArgTypes);
- bodyBuild(builder, result.location,
- bodyBlock->getArguments().take_front(nLoops),
- bodyBlock->getArguments().drop_front(nLoops));
-}
-
-void IndexedGenericOp::build(
- OpBuilder &builder, OperationState &result, ValueRange inputs,
- ValueRange outputs, ArrayRef<AffineMap> indexingMaps,
- ArrayRef<StringRef> iteratorTypes, StringRef doc, StringRef libraryCall,
- function_ref<void(OpBuilder &, Location, ValueRange, ValueRange)>
- bodyBuild) {
- build(builder, result, TypeRange{}, inputs, outputs, indexingMaps,
- iteratorTypes, doc, libraryCall, bodyBuild);
-}
-
-void IndexedGenericOp::build(
- OpBuilder &builder, OperationState &result, ValueRange inputs,
- ValueRange outputs, ArrayRef<AffineMap> indexingMaps,
- ArrayRef<StringRef> iteratorTypes,
- function_ref<void(OpBuilder &, Location, ValueRange, ValueRange)>
- bodyBuild) {
- build(builder, result, inputs, outputs, indexingMaps, iteratorTypes,
- /*doc=*/"", /*libraryCall=*/"", bodyBuild);
-}
-
-void IndexedGenericOp::build(
- OpBuilder &builder, OperationState &result, TypeRange resultTensorTypes,
- ValueRange inputs, ValueRange outputs, ArrayRef<AffineMap> indexingMaps,
- ArrayRef<StringRef> iteratorTypes,
- function_ref<void(OpBuilder &, Location, ValueRange, ValueRange)>
- bodyBuild) {
- build(builder, result, resultTensorTypes, inputs, outputs, indexingMaps,
- iteratorTypes,
- /*doc=*/"",
- /*libraryCall=*/"", bodyBuild);
-}
-template <typename GenericOpType>
-static void printGenericOp(OpAsmPrinter &p, GenericOpType op) {
+static void print(OpAsmPrinter &p, GenericOp op) {
p << op.getOperationName() << " ";
// Print extra attributes.
@@ -628,12 +567,6 @@ static void printGenericOp(OpAsmPrinter &p, GenericOpType op) {
printNamedStructuredOpResults(p, op.result_tensors().getTypes());
}
-static void print(OpAsmPrinter &p, GenericOp op) { printGenericOp(p, op); }
-
-static void print(OpAsmPrinter &p, IndexedGenericOp op) {
- printGenericOp(p, op);
-}
-
static ParseResult parseGenericOp(OpAsmParser &parser, OperationState &result) {
DictionaryAttr dictAttr;
// Parse the core linalg traits that must check into a dictAttr.
@@ -704,15 +637,6 @@ void GenericOp::getEffects(
outputBuffers);
}
-void IndexedGenericOp::getEffects(
- SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
- &effects) {
- SmallVector<Value> inputBuffers = getInputBufferOperands();
- SmallVector<Value> outputBuffers = getOutputBufferOperands();
- getGenericEffectsImpl(effects, getOperation()->getResults(), inputBuffers,
- outputBuffers);
-}
-
template <typename GenericOpType>
static LogicalResult verifyGenericOp(GenericOpType op) {
return success();
@@ -720,52 +644,6 @@ static LogicalResult verifyGenericOp(GenericOpType op) {
static LogicalResult verify(GenericOp op) { return verifyGenericOp(op); }
-static LogicalResult verify(IndexedGenericOp op) { return verifyGenericOp(op); }
-
-namespace {
-
-/// Replace indexed_generic ops by generic ops that access the iteration indices
-/// using index operation calls.
-struct ConvertIndexedToGenericOp : OpRewritePattern<IndexedGenericOp> {
- using OpRewritePattern<IndexedGenericOp>::OpRewritePattern;
- LogicalResult matchAndRewrite(IndexedGenericOp indexedOp,
- PatternRewriter &rewriter) const override {
- // Replace all uses of the index block arguments.
- BlockAndValueMapping bvm;
- if (Block *body = indexedOp.getBody()) {
- rewriter.setInsertionPointToStart(body);
- for (const auto &en : llvm::enumerate(
- body->getArguments().take_front(indexedOp.getNumLoops()))) {
- Value index = rewriter.create<IndexOp>(indexedOp.getLoc(), en.index());
- bvm.map(en.value(), index);
- }
- }
-
- // Create a generic replacement operation and clone the body.
- rewriter.setInsertionPointAfter(indexedOp);
- SmallVector<Value> inputOperands = indexedOp.getInputOperands();
- SmallVector<Value> outputOperands = indexedOp.getOutputOperands();
- SmallVector<StringRef> iterators = llvm::to_vector<4>(
- indexedOp.iterator_types().getAsValueRange<StringAttr>());
- GenericOp genericOp = rewriter.create<GenericOp>(
- indexedOp.getLoc(), indexedOp->getResultTypes(), inputOperands,
- outputOperands, indexedOp.getIndexingMaps(), iterators);
- Region &genericRegion = genericOp.region();
- Region &indexedRegion = indexedOp.region();
- rewriter.cloneRegionBefore(indexedRegion, genericRegion,
- genericRegion.begin(), bvm);
-
- rewriter.replaceOp(indexedOp, genericOp->getResults());
- return success();
- }
-};
-} // namespace
-
-void IndexedGenericOp::getCanonicalizationPatterns(RewritePatternSet &results,
- MLIRContext *context) {
- results.add<ConvertIndexedToGenericOp>(context);
-}
-
//===----------------------------------------------------------------------===//
// InitTensorOp
//===----------------------------------------------------------------------===//
@@ -3230,7 +3108,7 @@ struct DeduplicateInputs : public OpInterfaceRewritePattern<LinalgOp> {
PatternRewriter &rewriter) const override {
// This pattern reduces the number of arguments of an op, which breaks
// the invariants of semantically charged named ops.
- if (!isa<GenericOp, IndexedGenericOp>(op))
+ if (!isa<GenericOp>(op))
return failure();
// Associate each input to an equivalent "canonical" input that has the same
@@ -3290,10 +3168,6 @@ struct DeduplicateInputs : public OpInterfaceRewritePattern<LinalgOp> {
// the value from the original op.
newLinalgOp.setNumInputs(canonicalInput.size());
- // linalg.indexed_generic payloads have additional arguments prepended to
- // the block arg list.
- int bbArgBaseOffset = newLinalgOp.getNumPayloadInductionVariables();
-
// Repair the payload entry block by RAUW'ing redundant arguments and
// erasing them.
Block &payload = newOp->getRegion(0).front();
@@ -3305,10 +3179,10 @@ struct DeduplicateInputs : public OpInterfaceRewritePattern<LinalgOp> {
unsigned operandNumber = opOperand->getOperandNumber();
if (canonicalInputIndices[operandNumber] == operandNumber)
continue;
- payload.getArgument(bbArgBaseOffset + operandNumber)
- .replaceAllUsesWith(payload.getArgument(
- bbArgBaseOffset + canonicalInputIndices[operandNumber]));
- payload.eraseArgument(bbArgBaseOffset + operandNumber);
+ payload.getArgument(operandNumber)
+ .replaceAllUsesWith(
+ payload.getArgument(canonicalInputIndices[operandNumber]));
+ payload.eraseArgument(operandNumber);
}
rewriter.replaceOp(op, newOp->getResults());
@@ -3316,7 +3190,7 @@ struct DeduplicateInputs : public OpInterfaceRewritePattern<LinalgOp> {
}
};
-/// Remove generic/indexed_generic operations (on tensors) that are just copying
+/// Remove generic operations (on tensors) that are just copying
/// the values from inputs to the results. Requirements are
/// 1) All iterator types are parallel
/// 2) The body contains just a yield operation with the yielded values being
@@ -3335,7 +3209,7 @@ struct RemoveIdentityLinalgOps : public OpInterfaceRewritePattern<LinalgOp> {
}
}
- if (!isa<GenericOp, IndexedGenericOp>(op))
+ if (!isa<GenericOp>(op))
return failure();
if (!op.hasTensorSemantics())
return failure();
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Bufferize.cpp b/mlir/lib/Dialect/Linalg/Transforms/Bufferize.cpp
index 757f336a22c0b..13a03f601336c 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Bufferize.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Bufferize.cpp
@@ -202,10 +202,6 @@ class BufferizeAnyLinalgOp : public OpInterfaceConversionPattern<LinalgOp> {
LogicalResult
matchAndRewrite(LinalgOp op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const final {
- // Canonicalize indexed generic operations before bufferization.
- if (isa<IndexedGenericOp>(op))
- return failure();
-
// GenericOpAdaptor below expects an `operand_segment_sizes` attribute.
if (!op->hasAttr("operand_segment_sizes"))
return failure();
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp
index 098442cf149e1..1a9767a556d15 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp
@@ -230,7 +230,6 @@ static LinalgOp fuse(OpBuilder &b, LinalgOp producer,
// When the producer has index semantics, we have to transform the indices of
// the producer according to the tiling of the consumer, i.e. offset them by
// the values computed in `loopRanges`.
- assert(!isa<IndexedGenericOp>(producer) && "unexpected op");
if (producer.hasIndexSemantics()) {
assert(clonedOp->getNumRegions() == 1 &&
clonedOp->getRegion(0).getBlocks().size() == 1 &&
@@ -426,10 +425,6 @@ mlir::linalg::fuseProducerOfBuffer(OpBuilder &b, OpOperand &consumerOpOperand,
if (!fusableDependence)
return llvm::None;
- // Canonicalize indexed generic ops before fusion.
- if (isa<IndexedGenericOp>(fusableDependence->getDependentOp()))
- return llvm::None;
-
LinalgOp producerOp = dyn_cast<LinalgOp>(fusableDependence->getDependentOp());
if (!producerOp)
return llvm::None;
@@ -507,10 +502,6 @@ mlir::linalg::fuseProducerOfTensor(OpBuilder &b, OpOperand &consumerOpOperand) {
Optional<FusionInfo>
mlir::linalg::fuseProducerOfTensor(OpBuilder &b, OpResult producerOpResult,
OpOperand &consumerOpOperand) {
- // Canonicalize indexed generic ops before fusion.
- if (isa<IndexedGenericOp>(producerOpResult.getOwner()))
- return llvm::None;
-
auto producerOp = dyn_cast<LinalgOp>(producerOpResult.getOwner());
if (!producerOp)
return llvm::None;
@@ -766,9 +757,6 @@ FusableOpDependencesTy mlir::linalg::findAllFusableDependences(
fusableDependence = findFusableProducer(*opOperand, dependenceGraph);
if (!fusableDependence)
continue;
- // Canonicalize indexed generic ops before fusion.
- if (isa<IndexedGenericOp>(fusableDependence->getDependentOp()))
- continue;
LinalgOp producerOp =
dyn_cast<LinalgOp>(fusableDependence->getDependentOp());
if (!producerOp)
diff --git a/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp b/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp
index 5e76361324a08..bfe03b64c5bb2 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp
@@ -1402,7 +1402,6 @@ void mlir::linalg::populateElementwiseOpsFusionPatterns(
options.controlFoldingReshapesFn);
AffineApplyOp::getCanonicalizationPatterns(patterns, context);
GenericOp::getCanonicalizationPatterns(patterns, context);
- IndexedGenericOp::getCanonicalizationPatterns(patterns, context);
TensorExpandShapeOp::getCanonicalizationPatterns(patterns, context);
TensorCollapseShapeOp::getCanonicalizationPatterns(patterns, context);
context->getLoadedDialect<LinalgDialect>()->getCanonicalizationPatterns(
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Generalization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Generalization.cpp
index b74e3829e1ad3..d5e619719fd7f 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Generalization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Generalization.cpp
@@ -130,8 +130,8 @@ struct LinalgNamedOpGeneralizationPattern : RewritePattern {
if (failed(marker.checkAndNotify(rewriter, linalgOp)))
return failure();
- // No nothing to do for linalg.generic and linalg.indexed_generic.
- if (isa<GenericOp, IndexedGenericOp>(rootOp))
+ // Nothing to do for linalg.generic.
+ if (isa<GenericOp>(rootOp))
return failure();
GenericOp genericOp = createGenericOpFromNamedOp(linalgOp, rewriter);
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Loops.cpp b/mlir/lib/Dialect/Linalg/Transforms/Loops.cpp
index 49870da402cf8..d8c930e4c28bf 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Loops.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Loops.cpp
@@ -418,10 +418,6 @@ static Optional<LinalgLoops> linalgOpToLoopsImpl(PatternRewriter &rewriter,
typename std::conditional<std::is_same<LoopTy, AffineForOp>::value,
AffineStoreOp, memref::StoreOp>::type;
- // Canonicalize indexed_generic operations before lowering them to loops.
- if (isa<IndexedGenericOp>(linalgOp))
- return llvm::None;
-
// The flattened loopToOperandRangesMaps is expected to be an invertible
// permutation map (which is asserted in the inverse calculation).
assert(linalgOp.hasBufferSemantics() &&
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp b/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
index ab80be520c9c6..98c1454858d1d 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
@@ -163,10 +163,6 @@ tileLinalgOpImpl(OpBuilder &b, LinalgOp op, ValueRange tileSizes,
if (llvm::all_of(tileSizes, isZero))
return llvm::None;
- // Canonicalize indexed generic operations before tiling.
- if (isa<IndexedGenericOp>(op))
- return llvm::None;
-
if (auto convOp = dyn_cast<linalg::ConvOp>(op.getOperation())) {
// For conv op only support tiling along batch dimension (which is the first
// loop).
diff --git a/mlir/test/Dialect/Linalg/canonicalize-duplicate-inputs.mlir b/mlir/test/Dialect/Linalg/canonicalize-duplicate-inputs.mlir
index e314bc45744bb..73a1031aa68b6 100644
--- a/mlir/test/Dialect/Linalg/canonicalize-duplicate-inputs.mlir
+++ b/mlir/test/Dialect/Linalg/canonicalize-duplicate-inputs.mlir
@@ -89,26 +89,3 @@ func @multiple_different_redundant_args(%arg0: tensor<?xf32>, %arg1: tensor<?xf3
} -> tensor<?xf32>
return %0 : tensor<?xf32>
}
-
-// -----
-
-// Test case: linalg.indexed_generic.
-// Other than the payload argument handling, everything else is the same.
-
-#map = affine_map<(d0) -> (d0)>
-
-// CHECK: #[[$MAP:.*]] = affine_map<(d0) -> (d0)>
-// CHECK-LABEL: @indexed_generic
-func @indexed_generic(%arg0: tensor<?xf32>) -> tensor<?xf32> {
- // CHECK: linalg.generic
- // CHECK: ^bb0(%[[BBARG:.*]]: f32, %{{[a-zA-Z0-9]+}}: f32):
- // CHECK: addf %[[BBARG]], %[[BBARG]]
- %0 = linalg.indexed_generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel"]}
- ins(%arg0, %arg0 : tensor<?xf32>, tensor<?xf32>)
- outs(%arg0 : tensor<?xf32>) {
- ^bb0(%index: index, %arg1: f32, %arg2: f32, %arg3: f32):
- %1 = addf %arg1, %arg2 : f32
- linalg.yield %1 : f32
- } -> tensor<?xf32>
- return %0 : tensor<?xf32>
-}
diff --git a/mlir/test/Dialect/Linalg/canonicalize.mlir b/mlir/test/Dialect/Linalg/canonicalize.mlir
index 16895590a55fd..a3796c15ff21c 100644
--- a/mlir/test/Dialect/Linalg/canonicalize.mlir
+++ b/mlir/test/Dialect/Linalg/canonicalize.mlir
@@ -842,39 +842,6 @@ func @fold_tiled_loop_inputs(%A: memref<192xf32>, %A_tensor: tensor<192xf32>,
// -----
-#map = affine_map<(d0, d1) -> (d0, d1)>
-
-func @indexed_generic(%arg0: memref<?x?xindex>, %arg1: memref<?x?xindex>) {
- linalg.indexed_generic {
- indexing_maps = [#map, #map],
- iterator_types = ["parallel", "parallel"]}
- ins(%arg0 : memref<?x?xindex>)
- outs(%arg1 : memref<?x?xindex>) {
- ^bb0(%arg4: index, %arg5: index, %arg6: index, %arg7: index):
- %0 = addi %arg4, %arg5 : index
- %1 = addi %0, %arg6 : index
- %2 = addi %1, %arg7 : index
- linalg.yield %2 : index
- }
- return
-}
-
-// CHECK: #[[MAP:.+]] = affine_map<(d0, d1) -> (d0, d1)>
-// CHECK: func @indexed_generic
-// CHECK-NEXT: linalg.generic {
-// CHECK-SAME: indexing_maps = [#[[MAP]], #[[MAP]]], iterator_types = ["parallel", "parallel"]}
-// CHECK-SAME: ins(%[[ARG0:[A-Za-z0-9_]+]] : memref<?x?xindex>)
-// CHECK-SAME: outs(%[[ARG1:[A-Za-z0-9_]+]] : memref<?x?xindex>)
-// CHECK: ^bb0(%[[ARG2:[A-Za-z0-9_]+]]: index, %[[ARG3:[A-Za-z0-9_]+]]: index):
-// CHECK-NEXT: %[[IDX0:.+]] = linalg.index 0 : index
-// CHECK-NEXT: %[[IDX1:.+]] = linalg.index 1 : index
-// CHECK-NEXT: %[[SUM0:.+]] = addi %[[IDX0]], %[[IDX1]] : index
-// CHECK-NEXT: %[[SUM1:.+]] = addi %[[SUM0]], %[[ARG2]] : index
-// CHECK-NEXT: %[[SUM2:.+]] = addi %[[SUM1]], %[[ARG3]] : index
-// CHECK-NEXT: linalg.yield %[[SUM2]] : index
-
-// -----
-
func @tensor_pad_cast_fold(%arg0: tensor<4x4xf32>) -> tensor<4x4xf32> {
%c0 = constant 0 : index
%cst = constant 0.0 : f32
diff --git a/mlir/test/Dialect/Linalg/invalid.mlir b/mlir/test/Dialect/Linalg/invalid.mlir
index a9041e2203c87..aed7e080deaf6 100644
--- a/mlir/test/Dialect/Linalg/invalid.mlir
+++ b/mlir/test/Dialect/Linalg/invalid.mlir
@@ -227,72 +227,6 @@ func @generic_scalar_operand_block_arg_type(%arg0: f32) {
// -----
-func @indexed_generic_block_arg_count(%arg0: memref<?xf32>) {
- // expected-error @+1 {{expected as many non-induction variable region arguments as the number of input/output operands}}
- linalg.indexed_generic {
- indexing_maps = [ affine_map<(i) -> (i)> ],
- iterator_types = ["parallel"]}
- outs(%arg0 : memref<?xf32>) {
- ^bb(%f: f32):
- linalg.yield %f : f32
- }
-}
-
-// -----
-
-func @indexed_generic_block_induction_var_arg_type(%arg0: memref<?xf32>) {
- // expected-error @+1 {{op expected index block argument #0}}
- linalg.indexed_generic {
- indexing_maps = [ affine_map<(d0) -> (d0)> ],
- iterator_types = ["parallel"]}
- outs(%arg0 : memref<?xf32>) {
- ^bb(%i: f64, %f: f32):
- linalg.yield %f: f32
- }
-}
-
-// -----
-
-func @indexed_generic_block_arg_type(%arg0: memref<?xf32>) {
- // expected-error @+1 {{expected type of bb argument #1 ('i1') to match element or self type of the corresponding operand ('f32')}}
- linalg.indexed_generic {
- indexing_maps = [ affine_map<(d0) -> (d0)> ],
- iterator_types = ["parallel"]}
- outs(%arg0 : memref<?xf32>) {
- ^bb(%i: index, %f: i1):
- linalg.yield %i: index
- }
-}
-
-// -----
-
-func @indexed_generic_arg_count(%arg0: memref<f32>) {
- // expected-error @+1 {{expected as many non-induction variable region arguments as the number of input/output operands}}
- linalg.indexed_generic {
- indexing_maps = [ affine_map<()[] -> ()> ],
- iterator_types = []}
- outs(%arg0 : memref<f32>) {
- ^bb(%0: index, %1: f32):
- linalg.yield %1: f32
- }
- return
-}
-
-// -----
-
-func @indexed_generic_result_count(%arg0: memref<?xf32>) {
- // expected-error @+6 {{op expected number of yield values (1) to match the number of operands of the enclosing LinalgOp (2)}}
- linalg.indexed_generic {
- indexing_maps = [ affine_map<(d0) -> (d0)> ],
- iterator_types = ["parallel"]}
- outs(%arg0 : memref<?xf32>) {
- ^bb(%i: index, %val: f32):
- linalg.yield %val, %val: f32, f32
- }
-}
-
-// -----
-
func @generic_result_0_element_type(%arg0: memref<?xf32, affine_map<(i)[off]->(off + i)>>) {
// expected-error @+7 {{type of yield operand 1 ('i1') doesn't match the element type of the enclosing linalg.generic op ('f32')}}
linalg.generic {
diff --git a/mlir/test/Dialect/Linalg/roundtrip.mlir b/mlir/test/Dialect/Linalg/roundtrip.mlir
index b0954016cb75f..8fe0f95451dcd 100644
--- a/mlir/test/Dialect/Linalg/roundtrip.mlir
+++ b/mlir/test/Dialect/Linalg/roundtrip.mlir
@@ -1,7 +1,7 @@
// RUN: mlir-opt %s | mlir-opt | FileCheck %s
// RUN: mlir-opt %s --mlir-print-op-generic | mlir-opt | FileCheck %s
-// TODO: Re-enable LLVM lowering test after IndexedGenericOp is lowered.
+// TODO: Re-enable LLVM lowering test.
//
// Test that we can lower all the way to LLVM without crashing, don't check results here.
// DISABLED: mlir-opt %s --convert-linalg-to-llvm -o=/dev/null 2>&1
@@ -457,43 +457,6 @@ func @generic_with_multiple_tensor_outputs(
// -----
-#accesses_2 = [
- affine_map<(i, j, k) -> (j, i)>,
- affine_map<(i, j, k) -> (i, k, i + j)>,
- affine_map<(i, j, k) -> (i, k, i + j)>
-]
-
-#trait_2 = {
- indexing_maps = #accesses_2,
- iterator_types = ["parallel", "parallel", "parallel"],
- library_call = "some_external_function_name_1"
-}
-
-func @indexed_generic_with_tensor_input_and_output(
- %arg0: tensor<?x?xvector<3x4xi4>>, %arg1: tensor<?x?x?xf32>)
- -> (tensor<?x?x?xf32>) {
- %0 = linalg.indexed_generic #trait_2
- ins(%arg0, %arg1 : tensor<?x?xvector<3x4xi4>>, tensor<?x?x?xf32>)
- outs(%arg1 : tensor<?x?x?xf32>)
- attrs = {foo = 1} {
- ^bb(%i: index, %j: index, %k: index, %0: vector<3x4xi4>, %1: f32, %2: f32) :
- %f0 = constant 0.0 : f32
- linalg.yield %f0 : f32
- } -> tensor<?x?x?xf32>
- return %0 : tensor<?x?x?xf32>
-}
-// CHECK-LABEL: func @indexed_generic_with_tensor_input_and_output
-// CHECK: linalg.indexed_generic {
-// CHECK-SAME: indexing_maps = [#{{.*}}, #{{.*}}], iterator_types = ["parallel", "parallel", "parallel"],
-// CHECK-SAME: library_call = "some_external_function_name_1"}
-// CHECK-SAME: ins({{.*}} : tensor<?x?xvector<3x4xi4>>, tensor<?x?x?xf32>)
-// CHECK-SAME: outs({{.*}} : tensor<?x?x?xf32>)
-// CHECK-SAME: {foo = 1 : i64}
-// CHECK: -> tensor<?x?x?xf32>
-// CHECK: return {{.*}} : tensor<?x?x?xf32>
-
-// -----
-
#broadcast_access = [
affine_map<(i, j) -> ()>,
affine_map<(i, j) -> (i, j)>
@@ -516,17 +479,6 @@ func @generic_op_zero_rank(%arg0: tensor<f32>, %arg1 : tensor<3x4xf32>) -> (tens
return %0 : tensor<3x4xf32>
}
-func @indexed_generic_op_zero_rank(%arg0: tensor<f32>, %arg1 : tensor<3x4xf32>) -> (tensor<3x4xf32>)
-{
- %0 = linalg.indexed_generic #trait_broadcast
- ins(%arg0 : tensor<f32>)
- outs(%arg1 : tensor<3x4xf32>) {
- ^bb(%i: index, %j: index, %a: f32, %b: f32) :
- linalg.yield %a : f32
- } -> tensor<3x4xf32>
- return %0 : tensor<3x4xf32>
-}
-
// -----
@@ -569,29 +521,6 @@ func @generic_region(%arg0: memref<?x?xvector<3x4xi4>, offset: ?, strides: [?, 1
// CHECK: %{{.*}} = linalg.index 2 : index
// CHECK: linalg.yield %{{.*}} : f32
-func @indexed_generic(%arg0: memref<?x?xvector<3x4xi4>, offset: ?, strides: [?, 1]>,
- %arg1: memref<?x?x?xf32, offset: ?, strides: [?, ?, 1]>) {
- linalg.indexed_generic #trait_3
- ins(%arg0 : memref<?x?xvector<3x4xi4>, offset: ?, strides: [?, 1]>)
- outs(%arg1 : memref<?x?x?xf32, offset: ?, strides: [?, ?, 1]>)
- attrs = {foo = 1} {
- ^bb(%i: index, %j: index, %k: index, %a: vector<3x4xi4>, %b: f32) :
- linalg.yield %b : f32
- }
- return
-}
-// CHECK-LABEL: func @indexed_generic
-// CHECK: linalg.indexed_generic {
-// CHECK-SAME: indexing_maps = [#{{[0-9a-z]*}}, #{{[0-9a-z]*}}],
-// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel"],
-// CHECK-SAME: library_call = "some_external_function_name_2"
-// CHECK-SAME: ins({{.*}} : memref<?x?xvector<3x4xi4>, #[[$strided2D]]>)
-// CHECK-SAME: outs({{.*}} : memref<?x?x?xf32, #[[$strided3D]]>)
-// CHECK-SAME: {foo = 1 : i64}
-// CHECK: ^{{.*}}(%{{.*}}: index, %{{.*}}: index, %{{.*}}: index, %{{.*}}: vector<3x4xi4>, %{{.*}}: f32):
-// CHECK: linalg.yield %{{.*}} : f32
-// CHECK: }
-
// -----
func @reshape_static(%arg0: memref<3x4x5xf32>, %arg1: tensor<3x4x5xf32>,
More information about the Mlir-commits
mailing list