[Mlir-commits] [mlir] 5b3cb31 - [mlir][linalg] Purge linalg.indexed_generic.

Alexander Belyaev llvmlistbot at llvm.org
Thu Jun 17 05:46:12 PDT 2021


Author: Alexander Belyaev
Date: 2021-06-17T14:45:37+02:00
New Revision: 5b3cb31edbcf99ef15c2de2d29ad0ff9927ba136

URL: https://github.com/llvm/llvm-project/commit/5b3cb31edbcf99ef15c2de2d29ad0ff9927ba136
DIFF: https://github.com/llvm/llvm-project/commit/5b3cb31edbcf99ef15c2de2d29ad0ff9927ba136.diff

LOG: [mlir][linalg] Purge linalg.indexed_generic.

Differential Revision: https://reviews.llvm.org/D104449

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
    mlir/lib/Conversion/LinalgToStandard/LinalgToStandard.cpp
    mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
    mlir/lib/Dialect/Linalg/Transforms/Bufferize.cpp
    mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp
    mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp
    mlir/lib/Dialect/Linalg/Transforms/Generalization.cpp
    mlir/lib/Dialect/Linalg/Transforms/Loops.cpp
    mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
    mlir/test/Dialect/Linalg/canonicalize-duplicate-inputs.mlir
    mlir/test/Dialect/Linalg/canonicalize.mlir
    mlir/test/Dialect/Linalg/invalid.mlir
    mlir/test/Dialect/Linalg/roundtrip.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
index 54dda4c4fcb2b..a8e4bbdd69b35 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
@@ -32,7 +32,7 @@ class LinalgStructuredBase_Op<string mnemonic, list<OpTrait> props>
     // always be 0 for index-free linalg ops. For IndexedGeneric, this must be
     // equal to numLoops.
     unsigned getNumPayloadInductionVariables() {
-      return isa<IndexedGenericOp>(this->getOperation()) ? getNumLoops() : 0;
+      return 0;
     }
 
     // Return whether the op accesses the iteration indices.
@@ -671,140 +671,6 @@ def GenericOp : GenericOpBase<"generic"> {
   let hasFolder = 1;
 }
 
-/// GenericOp with Indexing (i.e. multi-for style in which the region is passed
-/// the enclosing loop induction variables)
-def IndexedGenericOp : GenericOpBase<"indexed_generic"> {
-  let description = [{
-    Indexed Generic Linalg op form where the key properties of the computation
-    are specified as attributes. In pretty form, a `linalg.indexed_generic` op
-    is written as:
-
-      ```mlir
-      linalg.indexed_generic #trait_attribute
-          ins(%A, %B : memref<?x?xf32, stride_specification>,
-                       memref<?x?xf32, stride_specification>)
-          outs(%C : memref<?x?xf32, stride_specification>)
-          attrs = {other-optional-attributes}
-          {region}
-      ```
-
-    Where #trait_attributes is an alias of a dictionary attribute containing:
-      - doc [optional]: a documentation string
-      - indexing_maps: a list of AffineMapAttr, one AffineMapAttr per each input
-        and output view. Such AffineMapAttr specifies the mapping between the
-        loops and the indexing within each view.
-      - library_call [optional]: a StringAttr containing the name of an
-        external library function that the linalg.indexed_generic operation
-        maps to.  The external library is assumed to be dynamically linked and
-        no strong compile-time guarantees are provided. In the absence of such
-        a library call, linalg.indexed_generic will always lower to loops.
-      - iterator_types: an ArrayAttr they type of the enclosing loops; Each
-        element of the list represents and iterator of one of the following
-        types:
-          parallel, reduction, window
-
-    Example:
-    Defining a #matmul_trait attribute in MLIR can be done as follows:
-
-    ```mlir
-    #matmul_accesses = [
-      (m, n, k) -> (m, k),
-      (m, n, k) -> (k, n),
-      (m, n, k) -> (m, n)
-    ]
-    #matmul_trait = {
-      doc = "C(m, n) += A(m, k) * B(k, n)",
-      indexing_maps = #matmul_accesses,
-      library_call = "linalg_matmul",
-      iterator_types = ["parallel", "parallel", "reduction"]
-    }
-    ```
-
-    And can be reused in multiple places as:
-
-    ```mlir
-      linalg.indexed_generic #matmul_trait
-         ins(%A, %B : memref<?x?xf32, stride_specification>,
-                      memref<?x?xf32, stride_specification>)
-        outs(%C : memref<?x?xf32, stride_specification>) {
-      (%offset_m: index, %offset_n: index, %offset_k: index,
-       %a: f32, %b: f32, %c: f32) :
-        "some_optional_computation"(%offset_m, %offset_n, %offset_k)
-        %d = mulf %a, %b: f32
-        %e = addf %c, %d: f32
-        linalg_yield %e : f32
-    }
-    ```
-
-    This may lower to either:
-
-    ```mlir
-    call @linalg_matmul(%offset_m, %offset_n, %offset_k, %A, %B, %C) :
-      (index, index, index,
-       memref<?x?xf32, stride_specification>,
-       memref<?x?xf32, stride_specification>,
-       memref<?x?xf32, stride_specification>)
-      -> ()
-    ```
-
-    or IR resembling:
-
-    ```mlir
-    scf.for %m = %c0 to %M step %c1 {
-      scf.for %n = %c0 to %N step %c1 {
-        scf.for %k = %c0 to %K step %c1 {
-          %a = load %A[%m, %k] : memref<?x?xf32, stride_specification>
-          %b = load %B[%k, %n] : memref<?x?xf32, stride_specification>
-          %c = load %C[%m, %n] : memref<?x?xf32, stride_specification>
-          "some_optional_computation"(%m, %n, %k)
-          %d = mulf %a, %b: f32
-          %e = addf %c, %d: f32
-          store %d, %C[%m, %n] : memref<?x?x?xf32, stride_specification>
-        }
-      }
-    }
-    ```
-
-    To allow progressive lowering from the value world (a.k.a tensor values) to
-    the buffer world (a.k.a memref values), a `linalg.indexed_generic` op
-    allows mixing tensors and buffers operands and tensor results.
-
-    ```mlir
-    %C = linalg.indexed_generic #trait_attribute
-       ins(%A, %B : tensor<?x?xf32>, memref<?x?xf32, stride_specification>)
-      outs(%C : tensor<?x?xf32>)
-      {other-optional-attributes}
-      {region_with_index_arguments}
-      -> (tensor<?x?xf32>)
-    ```
-  }];
-
-  let builders = [
-    OpBuilder<(ins "TypeRange":$resultTensorTypes, "ValueRange":$inputs,
-      "ValueRange":$outputs, "ArrayRef<AffineMap>":$indexingMaps,
-      "ArrayRef<StringRef>":$iteratorTypes, "StringRef":$doc,
-      "StringRef":$libraryCall,
-      CArg<"function_ref<void(OpBuilder &, Location, ValueRange, ValueRange)>",
-           "nullptr">)>,
-    OpBuilder<(ins "ValueRange":$inputs, "ValueRange":$outputBuffers,
-      "ArrayRef<AffineMap>":$indexingMaps, "ArrayRef<StringRef>":$iteratorTypes,
-      "StringRef":$doc, "StringRef":$libraryCall,
-      CArg<"function_ref<void(OpBuilder &, Location, ValueRange, ValueRange)>",
-           "nullptr">)>,
-    OpBuilder<(ins "TypeRange":$resultTensorTypes, "ValueRange":$inputs,
-      "ValueRange":$outputs, "ArrayRef<AffineMap>":$indexingMaps,
-      "ArrayRef<StringRef>":$iteratorTypes,
-      CArg<"function_ref<void(OpBuilder &, Location, ValueRange, ValueRange)>",
-           "nullptr">)>,
-    OpBuilder<(ins "ValueRange":$inputs, "ValueRange":$outputBuffers,
-      "ArrayRef<AffineMap>":$indexingMaps, "ArrayRef<StringRef>":$iteratorTypes,
-      CArg<"function_ref<void(OpBuilder &, Location, ValueRange, ValueRange)>",
-           "nullptr">)>
-  ];
-  let verifier = [{ return ::verify(*this); }];
-
-  let hasCanonicalizer = 1;
-}
 
 //===----------------------------------------------------------------------===//
 // Named Linalg ops, implemented as a declarative configurations of generic ops.

diff  --git a/mlir/lib/Conversion/LinalgToStandard/LinalgToStandard.cpp b/mlir/lib/Conversion/LinalgToStandard/LinalgToStandard.cpp
index 28dd7bb860c78..6f422e5f629fb 100644
--- a/mlir/lib/Conversion/LinalgToStandard/LinalgToStandard.cpp
+++ b/mlir/lib/Conversion/LinalgToStandard/LinalgToStandard.cpp
@@ -100,10 +100,6 @@ LogicalResult mlir::linalg::LinalgOpToLibraryCallRewrite::matchAndRewrite(
   if (isa<CopyOp>(op))
     return failure();
 
-  // Canonicalize indexed generic operations before library call conversion.
-  if (isa<IndexedGenericOp>(op))
-    return failure();
-
   auto libraryCallName = getLibraryCallSymbolRef(op, rewriter);
   if (!libraryCallName)
     return failure();

diff  --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index 985a9f7a09a2f..b05a1477982cd 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -525,69 +525,8 @@ void GenericOp::build(
         /*doc=*/"",
         /*libraryCall=*/"", bodyBuild);
 }
-void IndexedGenericOp::build(
-    OpBuilder &builder, OperationState &result, TypeRange resultTensorTypes,
-    ValueRange inputs, ValueRange outputs, ArrayRef<AffineMap> indexingMaps,
-    ArrayRef<StringRef> iteratorTypes, StringRef doc, StringRef libraryCall,
-    function_ref<void(OpBuilder &, Location, ValueRange, ValueRange)>
-        bodyBuild) {
-  build(builder, result, resultTensorTypes, inputs, outputs,
-        builder.getAffineMapArrayAttr(indexingMaps),
-        builder.getStrArrayAttr(iteratorTypes),
-        doc.empty() ? StringAttr() : builder.getStringAttr(doc),
-        libraryCall.empty() ? StringAttr()
-                            : builder.getStringAttr(libraryCall));
-  if (!bodyBuild)
-    return;
-
-  unsigned nLoops = iteratorTypes.size();
-  SmallVector<Type, 4> blockArgTypes(nLoops, builder.getIndexType());
-  for (ValueRange container : {inputs, outputs})
-    for (Value v : container)
-      blockArgTypes.push_back(getElementTypeOrSelf(v));
-
-  OpBuilder::InsertionGuard guard(builder);
-  auto &region = *result.regions.front();
-  Block *bodyBlock = builder.createBlock(&region, region.end(), blockArgTypes);
-  bodyBuild(builder, result.location,
-            bodyBlock->getArguments().take_front(nLoops),
-            bodyBlock->getArguments().drop_front(nLoops));
-}
-
-void IndexedGenericOp::build(
-    OpBuilder &builder, OperationState &result, ValueRange inputs,
-    ValueRange outputs, ArrayRef<AffineMap> indexingMaps,
-    ArrayRef<StringRef> iteratorTypes, StringRef doc, StringRef libraryCall,
-    function_ref<void(OpBuilder &, Location, ValueRange, ValueRange)>
-        bodyBuild) {
-  build(builder, result, TypeRange{}, inputs, outputs, indexingMaps,
-        iteratorTypes, doc, libraryCall, bodyBuild);
-}
-
-void IndexedGenericOp::build(
-    OpBuilder &builder, OperationState &result, ValueRange inputs,
-    ValueRange outputs, ArrayRef<AffineMap> indexingMaps,
-    ArrayRef<StringRef> iteratorTypes,
-    function_ref<void(OpBuilder &, Location, ValueRange, ValueRange)>
-        bodyBuild) {
-  build(builder, result, inputs, outputs, indexingMaps, iteratorTypes,
-        /*doc=*/"", /*libraryCall=*/"", bodyBuild);
-}
-
-void IndexedGenericOp::build(
-    OpBuilder &builder, OperationState &result, TypeRange resultTensorTypes,
-    ValueRange inputs, ValueRange outputs, ArrayRef<AffineMap> indexingMaps,
-    ArrayRef<StringRef> iteratorTypes,
-    function_ref<void(OpBuilder &, Location, ValueRange, ValueRange)>
-        bodyBuild) {
-  build(builder, result, resultTensorTypes, inputs, outputs, indexingMaps,
-        iteratorTypes,
-        /*doc=*/"",
-        /*libraryCall=*/"", bodyBuild);
-}
 
-template <typename GenericOpType>
-static void printGenericOp(OpAsmPrinter &p, GenericOpType op) {
+static void print(OpAsmPrinter &p, GenericOp op) {
   p << op.getOperationName() << " ";
 
   // Print extra attributes.
@@ -628,12 +567,6 @@ static void printGenericOp(OpAsmPrinter &p, GenericOpType op) {
   printNamedStructuredOpResults(p, op.result_tensors().getTypes());
 }
 
-static void print(OpAsmPrinter &p, GenericOp op) { printGenericOp(p, op); }
-
-static void print(OpAsmPrinter &p, IndexedGenericOp op) {
-  printGenericOp(p, op);
-}
-
 static ParseResult parseGenericOp(OpAsmParser &parser, OperationState &result) {
   DictionaryAttr dictAttr;
   // Parse the core linalg traits that must check into a dictAttr.
@@ -704,15 +637,6 @@ void GenericOp::getEffects(
                         outputBuffers);
 }
 
-void IndexedGenericOp::getEffects(
-    SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
-        &effects) {
-  SmallVector<Value> inputBuffers = getInputBufferOperands();
-  SmallVector<Value> outputBuffers = getOutputBufferOperands();
-  getGenericEffectsImpl(effects, getOperation()->getResults(), inputBuffers,
-                        outputBuffers);
-}
-
 template <typename GenericOpType>
 static LogicalResult verifyGenericOp(GenericOpType op) {
   return success();
@@ -720,52 +644,6 @@ static LogicalResult verifyGenericOp(GenericOpType op) {
 
 static LogicalResult verify(GenericOp op) { return verifyGenericOp(op); }
 
-static LogicalResult verify(IndexedGenericOp op) { return verifyGenericOp(op); }
-
-namespace {
-
-/// Replace indexed_generic ops by generic ops that access the iteration indices
-/// using index operation calls.
-struct ConvertIndexedToGenericOp : OpRewritePattern<IndexedGenericOp> {
-  using OpRewritePattern<IndexedGenericOp>::OpRewritePattern;
-  LogicalResult matchAndRewrite(IndexedGenericOp indexedOp,
-                                PatternRewriter &rewriter) const override {
-    // Replace all uses of the index block arguments.
-    BlockAndValueMapping bvm;
-    if (Block *body = indexedOp.getBody()) {
-      rewriter.setInsertionPointToStart(body);
-      for (const auto &en : llvm::enumerate(
-               body->getArguments().take_front(indexedOp.getNumLoops()))) {
-        Value index = rewriter.create<IndexOp>(indexedOp.getLoc(), en.index());
-        bvm.map(en.value(), index);
-      }
-    }
-
-    // Create a generic replacement operation and clone the body.
-    rewriter.setInsertionPointAfter(indexedOp);
-    SmallVector<Value> inputOperands = indexedOp.getInputOperands();
-    SmallVector<Value> outputOperands = indexedOp.getOutputOperands();
-    SmallVector<StringRef> iterators = llvm::to_vector<4>(
-        indexedOp.iterator_types().getAsValueRange<StringAttr>());
-    GenericOp genericOp = rewriter.create<GenericOp>(
-        indexedOp.getLoc(), indexedOp->getResultTypes(), inputOperands,
-        outputOperands, indexedOp.getIndexingMaps(), iterators);
-    Region &genericRegion = genericOp.region();
-    Region &indexedRegion = indexedOp.region();
-    rewriter.cloneRegionBefore(indexedRegion, genericRegion,
-                               genericRegion.begin(), bvm);
-
-    rewriter.replaceOp(indexedOp, genericOp->getResults());
-    return success();
-  }
-};
-} // namespace
-
-void IndexedGenericOp::getCanonicalizationPatterns(RewritePatternSet &results,
-                                                   MLIRContext *context) {
-  results.add<ConvertIndexedToGenericOp>(context);
-}
-
 //===----------------------------------------------------------------------===//
 // InitTensorOp
 //===----------------------------------------------------------------------===//
@@ -3230,7 +3108,7 @@ struct DeduplicateInputs : public OpInterfaceRewritePattern<LinalgOp> {
                                 PatternRewriter &rewriter) const override {
     // This pattern reduces the number of arguments of an op, which breaks
     // the invariants of semantically charged named ops.
-    if (!isa<GenericOp, IndexedGenericOp>(op))
+    if (!isa<GenericOp>(op))
       return failure();
 
     // Associate each input to an equivalent "canonical" input that has the same
@@ -3290,10 +3168,6 @@ struct DeduplicateInputs : public OpInterfaceRewritePattern<LinalgOp> {
     // the value from the original op.
     newLinalgOp.setNumInputs(canonicalInput.size());
 
-    // linalg.indexed_generic payloads have additional arguments prepended to
-    // the block arg list.
-    int bbArgBaseOffset = newLinalgOp.getNumPayloadInductionVariables();
-
     // Repair the payload entry block by RAUW'ing redundant arguments and
     // erasing them.
     Block &payload = newOp->getRegion(0).front();
@@ -3305,10 +3179,10 @@ struct DeduplicateInputs : public OpInterfaceRewritePattern<LinalgOp> {
       unsigned operandNumber = opOperand->getOperandNumber();
       if (canonicalInputIndices[operandNumber] == operandNumber)
         continue;
-      payload.getArgument(bbArgBaseOffset + operandNumber)
-          .replaceAllUsesWith(payload.getArgument(
-              bbArgBaseOffset + canonicalInputIndices[operandNumber]));
-      payload.eraseArgument(bbArgBaseOffset + operandNumber);
+      payload.getArgument(operandNumber)
+          .replaceAllUsesWith(
+              payload.getArgument(canonicalInputIndices[operandNumber]));
+      payload.eraseArgument(operandNumber);
     }
 
     rewriter.replaceOp(op, newOp->getResults());
@@ -3316,7 +3190,7 @@ struct DeduplicateInputs : public OpInterfaceRewritePattern<LinalgOp> {
   }
 };
 
-/// Remove generic/indexed_generic operations (on tensors) that are just copying
+/// Remove generic operations (on tensors) that are just copying
 /// the values from inputs to the results. Requirements are
 /// 1) All iterator types are parallel
 /// 2) The body contains just a yield operation with the yielded values being
@@ -3335,7 +3209,7 @@ struct RemoveIdentityLinalgOps : public OpInterfaceRewritePattern<LinalgOp> {
       }
     }
 
-    if (!isa<GenericOp, IndexedGenericOp>(op))
+    if (!isa<GenericOp>(op))
       return failure();
     if (!op.hasTensorSemantics())
       return failure();

diff  --git a/mlir/lib/Dialect/Linalg/Transforms/Bufferize.cpp b/mlir/lib/Dialect/Linalg/Transforms/Bufferize.cpp
index 757f336a22c0b..13a03f601336c 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Bufferize.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Bufferize.cpp
@@ -202,10 +202,6 @@ class BufferizeAnyLinalgOp : public OpInterfaceConversionPattern<LinalgOp> {
   LogicalResult
   matchAndRewrite(LinalgOp op, ArrayRef<Value> operands,
                   ConversionPatternRewriter &rewriter) const final {
-    // Canonicalize indexed generic operations before bufferization.
-    if (isa<IndexedGenericOp>(op))
-      return failure();
-
     // GenericOpAdaptor below expects an `operand_segment_sizes` attribute.
     if (!op->hasAttr("operand_segment_sizes"))
       return failure();

diff  --git a/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp
index 098442cf149e1..1a9767a556d15 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp
@@ -230,7 +230,6 @@ static LinalgOp fuse(OpBuilder &b, LinalgOp producer,
   // When the producer has index semantics, we have to transform the indices of
   // the producer according to the tiling of the consumer, i.e. offset them by
   // the values computed in `loopRanges`.
-  assert(!isa<IndexedGenericOp>(producer) && "unexpected op");
   if (producer.hasIndexSemantics()) {
     assert(clonedOp->getNumRegions() == 1 &&
            clonedOp->getRegion(0).getBlocks().size() == 1 &&
@@ -426,10 +425,6 @@ mlir::linalg::fuseProducerOfBuffer(OpBuilder &b, OpOperand &consumerOpOperand,
   if (!fusableDependence)
     return llvm::None;
 
-  // Canonicalize indexed generic ops before fusion.
-  if (isa<IndexedGenericOp>(fusableDependence->getDependentOp()))
-    return llvm::None;
-
   LinalgOp producerOp = dyn_cast<LinalgOp>(fusableDependence->getDependentOp());
   if (!producerOp)
     return llvm::None;
@@ -507,10 +502,6 @@ mlir::linalg::fuseProducerOfTensor(OpBuilder &b, OpOperand &consumerOpOperand) {
 Optional<FusionInfo>
 mlir::linalg::fuseProducerOfTensor(OpBuilder &b, OpResult producerOpResult,
                                    OpOperand &consumerOpOperand) {
-  // Canonicalize indexed generic ops before fusion.
-  if (isa<IndexedGenericOp>(producerOpResult.getOwner()))
-    return llvm::None;
-
   auto producerOp = dyn_cast<LinalgOp>(producerOpResult.getOwner());
   if (!producerOp)
     return llvm::None;
@@ -766,9 +757,6 @@ FusableOpDependencesTy mlir::linalg::findAllFusableDependences(
           fusableDependence = findFusableProducer(*opOperand, dependenceGraph);
       if (!fusableDependence)
         continue;
-      // Canonicalize indexed generic ops before fusion.
-      if (isa<IndexedGenericOp>(fusableDependence->getDependentOp()))
-        continue;
       LinalgOp producerOp =
           dyn_cast<LinalgOp>(fusableDependence->getDependentOp());
       if (!producerOp)

diff  --git a/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp b/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp
index 5e76361324a08..bfe03b64c5bb2 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp
@@ -1402,7 +1402,6 @@ void mlir::linalg::populateElementwiseOpsFusionPatterns(
                                             options.controlFoldingReshapesFn);
   AffineApplyOp::getCanonicalizationPatterns(patterns, context);
   GenericOp::getCanonicalizationPatterns(patterns, context);
-  IndexedGenericOp::getCanonicalizationPatterns(patterns, context);
   TensorExpandShapeOp::getCanonicalizationPatterns(patterns, context);
   TensorCollapseShapeOp::getCanonicalizationPatterns(patterns, context);
   context->getLoadedDialect<LinalgDialect>()->getCanonicalizationPatterns(

diff  --git a/mlir/lib/Dialect/Linalg/Transforms/Generalization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Generalization.cpp
index b74e3829e1ad3..d5e619719fd7f 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Generalization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Generalization.cpp
@@ -130,8 +130,8 @@ struct LinalgNamedOpGeneralizationPattern : RewritePattern {
     if (failed(marker.checkAndNotify(rewriter, linalgOp)))
       return failure();
 
-    // No nothing to do for linalg.generic and linalg.indexed_generic.
-    if (isa<GenericOp, IndexedGenericOp>(rootOp))
+    // No nothing to do for linalg.generic.
+    if (isa<GenericOp>(rootOp))
       return failure();
 
     GenericOp genericOp = createGenericOpFromNamedOp(linalgOp, rewriter);

diff  --git a/mlir/lib/Dialect/Linalg/Transforms/Loops.cpp b/mlir/lib/Dialect/Linalg/Transforms/Loops.cpp
index 49870da402cf8..d8c930e4c28bf 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Loops.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Loops.cpp
@@ -418,10 +418,6 @@ static Optional<LinalgLoops> linalgOpToLoopsImpl(PatternRewriter &rewriter,
       typename std::conditional<std::is_same<LoopTy, AffineForOp>::value,
                                 AffineStoreOp, memref::StoreOp>::type;
 
-  // Canonicalize indexed_generic operations before lowering them to loops.
-  if (isa<IndexedGenericOp>(linalgOp))
-    return llvm::None;
-
   // The flattened loopToOperandRangesMaps is expected to be an invertible
   // permutation map (which is asserted in the inverse calculation).
   assert(linalgOp.hasBufferSemantics() &&

diff  --git a/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp b/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
index ab80be520c9c6..98c1454858d1d 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
@@ -163,10 +163,6 @@ tileLinalgOpImpl(OpBuilder &b, LinalgOp op, ValueRange tileSizes,
   if (llvm::all_of(tileSizes, isZero))
     return llvm::None;
 
-  // Canonicalize indexed generic operations before tiling.
-  if (isa<IndexedGenericOp>(op))
-    return llvm::None;
-
   if (auto convOp = dyn_cast<linalg::ConvOp>(op.getOperation())) {
     // For conv op only support tiling along batch dimension (which is the first
     // loop).

diff  --git a/mlir/test/Dialect/Linalg/canonicalize-duplicate-inputs.mlir b/mlir/test/Dialect/Linalg/canonicalize-duplicate-inputs.mlir
index e314bc45744bb..73a1031aa68b6 100644
--- a/mlir/test/Dialect/Linalg/canonicalize-duplicate-inputs.mlir
+++ b/mlir/test/Dialect/Linalg/canonicalize-duplicate-inputs.mlir
@@ -89,26 +89,3 @@ func @multiple_different_redundant_args(%arg0: tensor<?xf32>, %arg1: tensor<?xf3
   } -> tensor<?xf32>
   return %0 : tensor<?xf32>
 }
-
-// -----
-
-// Test case: linalg.indexed_generic.
-// Other than the payload argument handling, everything else is the same.
-
-#map = affine_map<(d0) -> (d0)>
-
-// CHECK: #[[$MAP:.*]] = affine_map<(d0) -> (d0)>
-// CHECK-LABEL: @indexed_generic
-func @indexed_generic(%arg0: tensor<?xf32>) -> tensor<?xf32> {
-  // CHECK: linalg.generic
-  // CHECK:   ^bb0(%[[BBARG:.*]]: f32, %{{[a-zA-Z0-9]+}}: f32):
-  // CHECK:     addf %[[BBARG]], %[[BBARG]]
-  %0 = linalg.indexed_generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel"]}
-      ins(%arg0, %arg0 : tensor<?xf32>, tensor<?xf32>)
-     outs(%arg0 : tensor<?xf32>) {
-  ^bb0(%index: index, %arg1: f32, %arg2: f32, %arg3: f32):
-    %1 = addf %arg1, %arg2 : f32
-    linalg.yield %1 : f32
-  } -> tensor<?xf32>
-  return %0 : tensor<?xf32>
-}

diff  --git a/mlir/test/Dialect/Linalg/canonicalize.mlir b/mlir/test/Dialect/Linalg/canonicalize.mlir
index 16895590a55fd..a3796c15ff21c 100644
--- a/mlir/test/Dialect/Linalg/canonicalize.mlir
+++ b/mlir/test/Dialect/Linalg/canonicalize.mlir
@@ -842,39 +842,6 @@ func @fold_tiled_loop_inputs(%A: memref<192xf32>, %A_tensor: tensor<192xf32>,
 
 // -----
 
-#map = affine_map<(d0, d1) -> (d0, d1)>
-
-func @indexed_generic(%arg0: memref<?x?xindex>, %arg1: memref<?x?xindex>) {
-  linalg.indexed_generic {
-      indexing_maps = [#map, #map],
-      iterator_types = ["parallel", "parallel"]}
-    ins(%arg0 : memref<?x?xindex>)
-   outs(%arg1 : memref<?x?xindex>) {
-  ^bb0(%arg4: index, %arg5: index, %arg6: index, %arg7: index):
-    %0 = addi %arg4, %arg5 : index
-    %1 = addi %0, %arg6 : index
-    %2 = addi %1, %arg7 : index
-    linalg.yield %2 : index
-  }
-  return
-}
-
-//      CHECK: #[[MAP:.+]] = affine_map<(d0, d1) -> (d0, d1)>
-//      CHECK: func @indexed_generic
-// CHECK-NEXT:   linalg.generic {
-// CHECK-SAME:     indexing_maps = [#[[MAP]], #[[MAP]]], iterator_types = ["parallel", "parallel"]}
-// CHECK-SAME:      ins(%[[ARG0:[A-Za-z0-9_]+]] : memref<?x?xindex>)
-// CHECK-SAME:     outs(%[[ARG1:[A-Za-z0-9_]+]] : memref<?x?xindex>)
-//      CHECK:   ^bb0(%[[ARG2:[A-Za-z0-9_]+]]: index, %[[ARG3:[A-Za-z0-9_]+]]: index):
-// CHECK-NEXT:     %[[IDX0:.+]] = linalg.index 0 : index
-// CHECK-NEXT:     %[[IDX1:.+]] = linalg.index 1 : index
-// CHECK-NEXT:     %[[SUM0:.+]] = addi %[[IDX0]], %[[IDX1]] : index
-// CHECK-NEXT:     %[[SUM1:.+]] = addi %[[SUM0]], %[[ARG2]] : index
-// CHECK-NEXT:     %[[SUM2:.+]] = addi %[[SUM1]], %[[ARG3]] : index
-// CHECK-NEXT:     linalg.yield %[[SUM2]] : index
-
-// -----
-
 func @tensor_pad_cast_fold(%arg0: tensor<4x4xf32>) -> tensor<4x4xf32> {
   %c0 = constant 0 : index
   %cst = constant 0.0 : f32

diff  --git a/mlir/test/Dialect/Linalg/invalid.mlir b/mlir/test/Dialect/Linalg/invalid.mlir
index a9041e2203c87..aed7e080deaf6 100644
--- a/mlir/test/Dialect/Linalg/invalid.mlir
+++ b/mlir/test/Dialect/Linalg/invalid.mlir
@@ -227,72 +227,6 @@ func @generic_scalar_operand_block_arg_type(%arg0: f32) {
 
 // -----
 
-func @indexed_generic_block_arg_count(%arg0: memref<?xf32>) {
-  // expected-error @+1 {{expected as many non-induction variable region arguments as the number of input/output operands}}
-  linalg.indexed_generic {
-    indexing_maps =  [ affine_map<(i) -> (i)> ],
-    iterator_types = ["parallel"]}
-      outs(%arg0 : memref<?xf32>) {
-    ^bb(%f: f32):
-      linalg.yield %f : f32
-  }
-}
-
-// -----
-
-func @indexed_generic_block_induction_var_arg_type(%arg0: memref<?xf32>) {
-  // expected-error @+1 {{op expected index block argument #0}}
-  linalg.indexed_generic {
-    indexing_maps =  [ affine_map<(d0) -> (d0)> ],
-    iterator_types = ["parallel"]}
-      outs(%arg0 : memref<?xf32>) {
-    ^bb(%i: f64, %f: f32):
-    linalg.yield %f: f32
-  }
-}
-
-// -----
-
-func @indexed_generic_block_arg_type(%arg0: memref<?xf32>) {
-  // expected-error @+1 {{expected type of bb argument #1 ('i1') to match element or self type of the corresponding operand ('f32')}}
-  linalg.indexed_generic {
-    indexing_maps =  [ affine_map<(d0) -> (d0)> ],
-    iterator_types = ["parallel"]}
-      outs(%arg0 : memref<?xf32>) {
-    ^bb(%i: index, %f: i1):
-    linalg.yield %i: index
-  }
-}
-
-// -----
-
-func @indexed_generic_arg_count(%arg0: memref<f32>) {
-  // expected-error @+1 {{expected as many non-induction variable region arguments as the number of input/output operands}}
-  linalg.indexed_generic {
-    indexing_maps =  [ affine_map<()[] -> ()> ],
-    iterator_types = []}
-      outs(%arg0 : memref<f32>) {
-    ^bb(%0: index, %1: f32):
-      linalg.yield %1: f32
-  }
-  return
-}
-
-// -----
-
-func @indexed_generic_result_count(%arg0: memref<?xf32>) {
-  // expected-error @+6 {{op expected number of yield values (1) to match the number of operands of the enclosing LinalgOp (2)}}
-  linalg.indexed_generic {
-    indexing_maps =  [ affine_map<(d0) -> (d0)> ],
-    iterator_types = ["parallel"]}
-      outs(%arg0 : memref<?xf32>) {
-    ^bb(%i: index, %val: f32):
-      linalg.yield %val, %val: f32, f32
-  }
-}
-
-// -----
-
 func @generic_result_0_element_type(%arg0: memref<?xf32, affine_map<(i)[off]->(off + i)>>) {
   // expected-error @+7 {{type of yield operand 1 ('i1') doesn't match the element type of the enclosing linalg.generic op ('f32')}}
   linalg.generic {

diff  --git a/mlir/test/Dialect/Linalg/roundtrip.mlir b/mlir/test/Dialect/Linalg/roundtrip.mlir
index b0954016cb75f..8fe0f95451dcd 100644
--- a/mlir/test/Dialect/Linalg/roundtrip.mlir
+++ b/mlir/test/Dialect/Linalg/roundtrip.mlir
@@ -1,7 +1,7 @@
 // RUN: mlir-opt %s | mlir-opt | FileCheck %s
 // RUN: mlir-opt %s --mlir-print-op-generic | mlir-opt | FileCheck %s
 
-// TODO: Re-enable LLVM lowering test after IndexedGenericOp is lowered.
+// TODO: Re-enable LLVM lowering test.
 //
 // Test that we can lower all the way to LLVM without crashing, don't check results here.
 // DISABLED: mlir-opt %s --convert-linalg-to-llvm -o=/dev/null 2>&1
@@ -457,43 +457,6 @@ func @generic_with_multiple_tensor_outputs(
 
 // -----
 
-#accesses_2 = [
-  affine_map<(i, j, k) -> (j, i)>,
-  affine_map<(i, j, k) -> (i, k, i + j)>,
-  affine_map<(i, j, k) -> (i, k, i + j)>
-]
-
-#trait_2 = {
-  indexing_maps = #accesses_2,
-  iterator_types = ["parallel", "parallel", "parallel"],
-  library_call = "some_external_function_name_1"
-}
-
-func @indexed_generic_with_tensor_input_and_output(
-    %arg0: tensor<?x?xvector<3x4xi4>>, %arg1: tensor<?x?x?xf32>)
-    -> (tensor<?x?x?xf32>) {
-  %0 = linalg.indexed_generic #trait_2
-       ins(%arg0, %arg1 : tensor<?x?xvector<3x4xi4>>, tensor<?x?x?xf32>)
-      outs(%arg1 : tensor<?x?x?xf32>)
-      attrs = {foo = 1} {
-    ^bb(%i: index, %j: index, %k: index, %0: vector<3x4xi4>, %1: f32, %2: f32) :
-      %f0 = constant 0.0 : f32
-      linalg.yield %f0 : f32
-  } -> tensor<?x?x?xf32>
-  return %0 : tensor<?x?x?xf32>
-}
-// CHECK-LABEL: func @indexed_generic_with_tensor_input_and_output
-//       CHECK:   linalg.indexed_generic {
-//  CHECK-SAME:     indexing_maps = [#{{.*}}, #{{.*}}], iterator_types = ["parallel", "parallel", "parallel"],
-//  CHECK-SAME:     library_call = "some_external_function_name_1"}
-//  CHECK-SAME:      ins({{.*}} : tensor<?x?xvector<3x4xi4>>, tensor<?x?x?xf32>)
-//  CHECK-SAME:     outs({{.*}} : tensor<?x?x?xf32>)
-//  CHECK-SAME:     {foo = 1 : i64}
-//       CHECK:     -> tensor<?x?x?xf32>
-//       CHECK:   return {{.*}} : tensor<?x?x?xf32>
-
-// -----
-
 #broadcast_access = [
   affine_map<(i, j) -> ()>,
   affine_map<(i, j) -> (i, j)>
@@ -516,17 +479,6 @@ func @generic_op_zero_rank(%arg0: tensor<f32>, %arg1 : tensor<3x4xf32>) -> (tens
   return %0 : tensor<3x4xf32>
 }
 
-func @indexed_generic_op_zero_rank(%arg0: tensor<f32>, %arg1 : tensor<3x4xf32>) -> (tensor<3x4xf32>)
-{
-  %0 = linalg.indexed_generic #trait_broadcast
-       ins(%arg0 : tensor<f32>)
-      outs(%arg1 : tensor<3x4xf32>) {
-    ^bb(%i: index, %j: index, %a: f32, %b: f32) :
-      linalg.yield %a : f32
-  } -> tensor<3x4xf32>
-  return %0 : tensor<3x4xf32>
-}
-
 // -----
 
 
@@ -569,29 +521,6 @@ func @generic_region(%arg0: memref<?x?xvector<3x4xi4>, offset: ?, strides: [?, 1
 //       CHECK:    %{{.*}} = linalg.index 2 : index
 //       CHECK:    linalg.yield %{{.*}} : f32
 
-func @indexed_generic(%arg0: memref<?x?xvector<3x4xi4>, offset: ?, strides: [?, 1]>,
-                      %arg1: memref<?x?x?xf32, offset: ?, strides: [?, ?, 1]>) {
-  linalg.indexed_generic #trait_3
-       ins(%arg0 : memref<?x?xvector<3x4xi4>, offset: ?, strides: [?, 1]>)
-      outs(%arg1 : memref<?x?x?xf32, offset: ?, strides: [?, ?, 1]>)
-      attrs = {foo = 1} {
-    ^bb(%i: index, %j: index, %k: index, %a: vector<3x4xi4>, %b: f32) :
-      linalg.yield %b : f32
-  }
-  return
-}
-// CHECK-LABEL: func @indexed_generic
-//       CHECK:   linalg.indexed_generic {
-//  CHECK-SAME:     indexing_maps = [#{{[0-9a-z]*}}, #{{[0-9a-z]*}}],
-//  CHECK-SAME:     iterator_types = ["parallel", "parallel", "parallel"],
-//  CHECK-SAME:     library_call = "some_external_function_name_2"
-//  CHECK-SAME:      ins({{.*}} : memref<?x?xvector<3x4xi4>, #[[$strided2D]]>)
-//  CHECK-SAME:     outs({{.*}} : memref<?x?x?xf32, #[[$strided3D]]>)
-//  CHECK-SAME:     {foo = 1 : i64}
-//       CHECK:    ^{{.*}}(%{{.*}}: index, %{{.*}}: index, %{{.*}}: index, %{{.*}}: vector<3x4xi4>, %{{.*}}: f32):
-//       CHECK:      linalg.yield %{{.*}} : f32
-//       CHECK:    }
-
 // -----
 
 func @reshape_static(%arg0: memref<3x4x5xf32>, %arg1: tensor<3x4x5xf32>,


        


More information about the Mlir-commits mailing list