[Mlir-commits] [mlir] 78f37b7 - [mlir][Linalg] Miscellaneous enhancements to cover more fusion cases.

llvmlistbot at llvm.org llvmlistbot at llvm.org
Mon Oct 26 16:17:44 PDT 2020


Author: MaheshRavishankar
Date: 2020-10-26T16:17:24-07:00
New Revision: 78f37b74da60ccdca200e457df195d58d76b3b8f

URL: https://github.com/llvm/llvm-project/commit/78f37b74da60ccdca200e457df195d58d76b3b8f
DIFF: https://github.com/llvm/llvm-project/commit/78f37b74da60ccdca200e457df195d58d76b3b8f.diff

LOG: [mlir][Linalg] Miscellaneous enhancements to cover more fusion cases.

Adds support for
- Dropping unit dimension loops for indexed_generic ops.
- Folding consecutive collapsing (or expanding) reshapes when the result
  (or source) is a scalar.
- Fixes to indexed_generic -> generic fusion when zero-dim tensors are
  involved.

Differential Revision: https://reviews.llvm.org/D90118

Added: 
    

Modified: 
    mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
    mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
    mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp
    mlir/test/Dialect/Linalg/canonicalize.mlir
    mlir/test/Dialect/Linalg/drop-unit-extent-dims.mlir
    mlir/test/Dialect/Linalg/fusion-tensor.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index cd471d5b1648..5e3ba1f95d9b 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -461,6 +461,10 @@ static LogicalResult verify(IndexedGenericOp op) { return verifyGenericOp(op); }
 static ArrayAttr collapseReassociationMaps(ArrayRef<AffineMap> mapsProducer,
                                            ArrayRef<AffineMap> mapsConsumer,
                                            MLIRContext *context) {
+  // Handle the corner case of the result being a rank 0 shaped type. Return an
+  // emtpy ArrayAttr.
+  if (mapsConsumer.empty() && !mapsProducer.empty())
+    return ArrayAttr::get(ArrayRef<Attribute>(), context);
   if (mapsProducer.empty() || mapsConsumer.empty() ||
       mapsProducer[0].getNumDims() < mapsConsumer[0].getNumDims() ||
       mapsProducer.size() != mapsConsumer[0].getNumDims())
@@ -500,8 +504,7 @@ struct CollapseReshapeOps : public OpRewritePattern<ReshapeOpTy> {
                                     ShapedType intermediateType,
                                     ShapedType smallerType) -> bool {
       return largerType.getRank() > intermediateType.getRank() &&
-             intermediateType.getRank() > smallerType.getRank() &&
-             smallerType.getRank() > 0;
+             intermediateType.getRank() > smallerType.getRank();
     };
     // Check if producer and consumer are both expanding dims.
     if (areReshapeOpsFoldable(reshapeOp.getResultType(), reshapeOp.getSrcType(),

diff  --git a/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp b/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
index 611c938ab542..03fdfd4555f2 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
@@ -26,6 +26,8 @@
 #include "llvm/Support/CommandLine.h"
 #include "llvm/Support/Debug.h"
 
+#include <set>
+
 #define DEBUG_TYPE "linalg-drop-unit-dims"
 
 using namespace mlir;
@@ -145,15 +147,42 @@ static ArrayAttr replaceUnitDims(DenseSet<unsigned> &unitDims,
       context);
 }
 
+/// Modify the region of indexed generic op to drop arguments corresponding to
+/// loops that are unit trip count.
+template <typename OpTy>
+static LogicalResult
+replaceBlockArgForUnitDimLoops(OpTy op, const DenseSet<unsigned> &unitDims,
+                               PatternRewriter &rewriterp) {
+  return success();
+}
+
+template <>
+LogicalResult replaceBlockArgForUnitDimLoops<IndexedGenericOp>(
+    IndexedGenericOp op, const DenseSet<unsigned> &unitDims,
+    PatternRewriter &rewriter) {
+  OpBuilder::InsertionGuard guard(rewriter);
+  Block *entryBlock = &op.getOperation()->getRegion(0).front();
+  rewriter.setInsertionPointToStart(entryBlock);
+  Value zero = rewriter.create<ConstantIndexOp>(op.getLoc(), 0);
+  for (unsigned unitDimLoop : unitDims) {
+    entryBlock->getArgument(unitDimLoop).replaceAllUsesWith(zero);
+  }
+  std::set<unsigned> orderedUnitDims(unitDims.begin(), unitDims.end());
+  for (unsigned i : llvm::reverse(orderedUnitDims))
+    entryBlock->eraseArgument(i);
+  return success();
+}
+
 namespace {
 /// Pattern to fold unit-trip count loops in GenericOps.
 // TODO: Generalize this to indexed-generic as well by modifying the region args
 // as well.
-struct FoldUnitDimLoops : public OpRewritePattern<GenericOp> {
-  using OpRewritePattern<GenericOp>::OpRewritePattern;
-  LogicalResult matchAndRewrite(GenericOp genericOp,
+template <typename GenericOpTy>
+struct FoldUnitDimLoops : public OpRewritePattern<GenericOpTy> {
+  using OpRewritePattern<GenericOpTy>::OpRewritePattern;
+  LogicalResult matchAndRewrite(GenericOpTy op,
                                 PatternRewriter &rewriter) const override {
-    SmallVector<AffineMap, 4> indexingMaps = genericOp.getIndexingMaps();
+    SmallVector<AffineMap, 4> indexingMaps = op.getIndexingMaps();
     if (indexingMaps.empty())
       return failure();
 
@@ -164,10 +193,10 @@ struct FoldUnitDimLoops : public OpRewritePattern<GenericOp> {
     if (!invertedMap)
       return failure();
     SmallVector<int64_t, 4> dims;
-    for (ShapedType shapedType : genericOp.getInputOutputShapedTypes())
+    for (ShapedType shapedType : op.getInputOutputShapedTypes())
       dims.append(shapedType.getShape().begin(), shapedType.getShape().end());
     DenseSet<unsigned> unitDims;
-    ArrayAttr iteratorTypes = genericOp.iterator_types();
+    ArrayAttr iteratorTypes = op.iterator_types();
     for (auto expr : enumerate(invertedMap.getResults())) {
       if (AffineDimExpr dimExpr = expr.value().dyn_cast<AffineDimExpr>())
         if (dims[dimExpr.getPosition()] == 1 &&
@@ -183,7 +212,7 @@ struct FoldUnitDimLoops : public OpRewritePattern<GenericOp> {
     ArrayAttr newIndexingMapAttr =
         replaceUnitDims(unitDims, indexingMaps, context);
     if (!newIndexingMapAttr)
-      return genericOp.emitError("unable to compute modified indexing_maps");
+      return op.emitError("unable to compute modified indexing_maps");
 
     // Compute the iterator types of the modified op by dropping the one-trip
     // count loops.
@@ -193,10 +222,11 @@ struct FoldUnitDimLoops : public OpRewritePattern<GenericOp> {
         newIteratorTypes.push_back(attr.value());
     }
 
-    rewriter.startRootUpdate(genericOp);
-    genericOp.indexing_mapsAttr(newIndexingMapAttr);
-    genericOp.iterator_typesAttr(ArrayAttr::get(newIteratorTypes, context));
-    rewriter.finalizeRootUpdate(genericOp);
+    rewriter.startRootUpdate(op);
+    op.indexing_mapsAttr(newIndexingMapAttr);
+    op.iterator_typesAttr(ArrayAttr::get(newIteratorTypes, context));
+    replaceBlockArgForUnitDimLoops(op, unitDims, rewriter);
+    rewriter.finalizeRootUpdate(op);
     return success();
   }
 };
@@ -263,25 +293,27 @@ static UnitExtentReplacementInfo replaceUnitExtents(AffineMap indexMap,
 namespace {
 
 /// Pattern to replace tensors operands/results that are unit extents.
-struct ReplaceUnitExtentTensors : public OpRewritePattern<GenericOp> {
-  using OpRewritePattern<GenericOp>::OpRewritePattern;
-  LogicalResult matchAndRewrite(GenericOp genericOp,
+template <typename GenericOpTy>
+struct ReplaceUnitExtentTensors : public OpRewritePattern<GenericOpTy> {
+  using OpRewritePattern<GenericOpTy>::OpRewritePattern;
+  LogicalResult matchAndRewrite(GenericOpTy op,
                                 PatternRewriter &rewriter) const override {
     // TODO: support init_tensors and reductions.
-    if (!genericOp.hasTensorSemantics() || !genericOp.init_tensors().empty())
+    if (!op.hasTensorSemantics() || !op.init_tensors().empty())
       return failure();
 
     MLIRContext *context = rewriter.getContext();
-    Location loc = genericOp.getLoc();
+    Location loc = op.getLoc();
 
     SmallVector<AffineMap, 4> newIndexingMaps;
     SmallVector<ArrayAttr, 4> reassociationMaps;
     SmallVector<ShapedType, 4> newInputOutputTypes;
     bool doCanonicalization = false;
-    for (auto it : llvm::zip(genericOp.getIndexingMaps(),
-                             genericOp.getInputOutputShapedTypes())) {
+    for (auto it :
+         llvm::zip(op.getIndexingMaps(), op.getInputOutputShapedTypes())) {
       auto replacementInfo = replaceUnitExtents(
-          std::get<0>(it), std::get<1>(it).cast<RankedTensorType>(), context);
+          std::get<0>(it), std::get<1>(it).template cast<RankedTensorType>(),
+          context);
       reassociationMaps.push_back(replacementInfo.reassociation);
       newIndexingMaps.push_back(replacementInfo.indexMap);
       newInputOutputTypes.push_back(replacementInfo.type);
@@ -313,24 +345,23 @@ struct ReplaceUnitExtentTensors : public OpRewritePattern<GenericOp> {
       return res;
     };
 
-    SmallVector<Value, 4> newInputs = insertReshapes(genericOp.inputs());
+    SmallVector<Value, 4> newInputs = insertReshapes(op.inputs());
     SmallVector<Value, 4> newOutputBuffers =
-        insertReshapes(genericOp.output_buffers());
-    SmallVector<Value, 4> newInitTensors =
-        insertReshapes(genericOp.init_tensors());
+        insertReshapes(op.output_buffers());
+    SmallVector<Value, 4> newInitTensors = insertReshapes(op.init_tensors());
 
     // If any result type change, insert a reshape to convert from the original
     // type to the new type.
     SmallVector<Type, 4> resultTypes;
-    resultTypes.reserve(genericOp.getNumResults());
-    for (unsigned i : llvm::seq<unsigned>(0, genericOp.getNumResults()))
-      resultTypes.push_back(newInputOutputTypes[i + genericOp.getNumInputs()]);
-    GenericOp replacementOp = rewriter.create<GenericOp>(
+    resultTypes.reserve(op.getNumResults());
+    for (unsigned i : llvm::seq<unsigned>(0, op.getNumResults()))
+      resultTypes.push_back(newInputOutputTypes[i + op.getNumInputs()]);
+    GenericOpTy replacementOp = rewriter.create<GenericOpTy>(
         loc, resultTypes, newInputs, newOutputBuffers, newInitTensors,
         newIndexingMaps,
         llvm::to_vector<4>(
-            genericOp.iterator_types().getAsValueRange<StringAttr>()));
-    rewriter.inlineRegionBefore(genericOp.region(), replacementOp.region(),
+            op.iterator_types().template getAsValueRange<StringAttr>()));
+    rewriter.inlineRegionBefore(op.region(), replacementOp.region(),
                                 replacementOp.region().begin());
 
     // If any result tensor has a modified shape, then add reshape to recover
@@ -338,16 +369,16 @@ struct ReplaceUnitExtentTensors : public OpRewritePattern<GenericOp> {
     SmallVector<Value, 4> resultReplacements;
     for (auto result : llvm::enumerate(replacementOp.getResults())) {
       unsigned index = result.index() + replacementOp.getNumOperands();
-      RankedTensorType origResultType = genericOp.getResult(result.index())
+      RankedTensorType origResultType = op.getResult(result.index())
                                             .getType()
-                                            .cast<RankedTensorType>();
+                                            .template cast<RankedTensorType>();
       if (origResultType != result.value().getType())
         resultReplacements.push_back(rewriter.create<linalg::TensorReshapeOp>(
             loc, origResultType, result.value(), reassociationMaps[index]));
       else
         resultReplacements.push_back(result.value());
     }
-    rewriter.replaceOp(genericOp, resultReplacements);
+    rewriter.replaceOp(op, resultReplacements);
     return success();
   }
 };
@@ -467,7 +498,10 @@ struct FoldReshapeOpWithUnitExtent : OpRewritePattern<TensorReshapeOp> {
 /// broadcasting.
 void mlir::populateLinalgFoldUnitExtentDimsPatterns(
     MLIRContext *context, OwningRewritePatternList &patterns) {
-  patterns.insert<FoldUnitDimLoops, ReplaceUnitExtentTensors>(context);
+  patterns
+      .insert<FoldUnitDimLoops<GenericOp>, FoldUnitDimLoops<IndexedGenericOp>,
+              ReplaceUnitExtentTensors<GenericOp>,
+              ReplaceUnitExtentTensors<IndexedGenericOp>>(context);
   TensorReshapeOp::getCanonicalizationPatterns(patterns, context);
   patterns.insert<FoldReshapeOpWithUnitExtent>(context);
 }
@@ -481,7 +515,8 @@ struct LinalgFoldUnitExtentDimsPass
     FuncOp funcOp = getFunction();
     MLIRContext *context = funcOp.getContext();
     if (foldOneTripLoopsOnly)
-      patterns.insert<FoldUnitDimLoops>(context);
+      patterns.insert<FoldUnitDimLoops<GenericOp>,
+                      FoldUnitDimLoops<IndexedGenericOp>>(context);
     else
       populateLinalgFoldUnitExtentDimsPatterns(context, patterns);
     applyPatternsAndFoldGreedily(funcOp.getBody(), patterns);

diff  --git a/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp b/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp
index 52fcd54e13b9..cf56b0e551a0 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp
@@ -109,13 +109,19 @@ static void generateFusedTensorOpRegion(PatternRewriter &rewriter,
   // consumer's operand.
   // If both `numProducerIndices` and `numConsumerIndices` are zero, this is a
   // generic op. In this case, there are no indices in block arguments.
-  unsigned numProducerIndices =
-      isa<IndexedGenericOp>(producer.getOperation()) ? nloops : 0;
-  unsigned numConsumerIndices =
-      isa<IndexedGenericOp>(consumer.getOperation()) ? nloops : 0;
+  unsigned numProducerIndices = isa<IndexedGenericOp>(producer.getOperation())
+                                    ? producer.getNumLoops()
+                                    : 0;
+  unsigned numConsumerIndices = isa<IndexedGenericOp>(consumer.getOperation())
+                                    ? consumer.getNumLoops()
+                                    : 0;
+  unsigned numFusedOpIndices =
+      (isa<IndexedGenericOp>(producer.getOperation()) ||
+       isa<IndexedGenericOp>(consumer.getOperation()))
+          ? std::max(producer.getNumLoops(), consumer.getNumLoops())
+          : 0;
   // Firstly, add all the indices to the block arguments.
-  for (unsigned i = 0, e = std::max(numProducerIndices, numConsumerIndices);
-       i < e; ++i)
+  for (unsigned i = 0, e = numFusedOpIndices; i < e; ++i)
     fusedBlock->addArgument(rewriter.getIndexType());
   // Map the arguments for the unmodified args from the consumer.
   for (auto consumerArg : llvm::enumerate(consumerBlock.getArguments())) {
@@ -129,7 +135,7 @@ static void generateFusedTensorOpRegion(PatternRewriter &rewriter,
           auto newIndex = rewriter.create<mlir::AffineApplyOp>(
               producer.getLoc(),
               consumerToProducerLoopsMap.getSubMap(producerArg.index()),
-              fusedBlock->getArguments().take_front(nloops));
+              fusedBlock->getArguments().take_front(numFusedOpIndices));
           mapper.map(producerArg.value(), newIndex);
         } else {
           mapper.map(producerArg.value(),

diff  --git a/mlir/test/Dialect/Linalg/canonicalize.mlir b/mlir/test/Dialect/Linalg/canonicalize.mlir
index cf86a97f4fcd..d68aff4d270c 100644
--- a/mlir/test/Dialect/Linalg/canonicalize.mlir
+++ b/mlir/test/Dialect/Linalg/canonicalize.mlir
@@ -43,6 +43,34 @@ func @collapsing_tensor_reshapes(%arg0 : tensor<?x?x?x?x?xf32>) -> tensor<?x?xf3
 
 // -----
 
+// -----
+
+func @collapsing_tensor_reshapes_to_zero_dim(%arg0 : tensor<1x1x1xf32>)
+                                             -> tensor<f32> {
+  %0 = linalg.tensor_reshape %arg0 [affine_map<(d0, d1, d2) -> (d0, d1, d2)>] :
+       tensor<1x1x1xf32> into tensor<1xf32>
+  %1 = linalg.tensor_reshape %0 [] : tensor<1xf32> into tensor<f32>
+  return %1 : tensor<f32>
+}
+// CHECK-LABEL: collapsing_tensor_reshapes_to_zero
+//       CHECK:   linalg.tensor_reshape %{{.*}} []
+//  CHECK-SAME:     tensor<1x1x1xf32> into tensor<f32>
+
+// -----
+
+func @collapsing_memref_reshapes_to_zero_dim(%arg0 : memref<1x1x1xf32>)
+                                             -> memref<f32> {
+  %0 = linalg.reshape %arg0 [affine_map<(d0, d1, d2) -> (d0, d1, d2)>] :
+       memref<1x1x1xf32> into memref<1xf32>
+  %1 = linalg.reshape %0 [] : memref<1xf32> into memref<f32>
+  return %1 : memref<f32>
+}
+// CHECK-LABEL: collapsing_memref_reshapes_to_zero
+//       CHECK:   linalg.reshape %{{.*}} []
+//  CHECK-SAME:     memref<1x1x1xf32> into memref<f32>
+
+// -----
+
 func @expanding_tensor_reshapes(%arg0 : tensor<?x?xf32>) -> tensor<?x?x?x?x?xf32>
 {
   %0 = linalg.tensor_reshape %arg0
@@ -106,6 +134,33 @@ func @expanding_memref_reshapes(%arg0 : memref<?x?xf32>) -> memref<?x?x?x?x?xf32
 
 // -----
 
+func @expanding_tensor_reshapes_to_zero_dim(%arg0 : tensor<f32>)
+                                             -> tensor<1x1x1xf32> {
+  %0 = linalg.tensor_reshape %arg0 [] : tensor<f32> into tensor<1xf32>
+  %1 = linalg.tensor_reshape %0 [affine_map<(d0, d1, d2) -> (d0, d1, d2)>] :
+       tensor<1xf32> into tensor<1x1x1xf32>
+  return %1 : tensor<1x1x1xf32>
+}
+// CHECK-LABEL: expanding_tensor_reshapes_to_zero
+//       CHECK:   linalg.tensor_reshape %{{.*}} []
+//  CHECK-SAME:     tensor<f32> into tensor<1x1x1xf32>
+
+// -----
+
+func @expanding_memref_reshapes_to_zero_dim(%arg0 : memref<f32>)
+                                             -> memref<1x1x1xf32> {
+  %0 = linalg.reshape %arg0 [] : memref<f32> into memref<1xf32>
+  %1 = linalg.reshape %0
+         [affine_map<(d0, d1, d2) -> (d0, d1, d2)>] :
+       memref<1xf32> into memref<1x1x1xf32>
+  return %1 : memref<1x1x1xf32>
+}
+// CHECK-LABEL: expanding_memref_reshapes_to_zero
+//       CHECK:   linalg.reshape %{{.*}} []
+//  CHECK-SAME:     memref<f32> into memref<1x1x1xf32>
+
+// -----
+
 func @fold_tensor_reshape(%arg0 : tensor<12x4xf32>) -> tensor<12x4xf32>
 {
   %0 = linalg.tensor_reshape %arg0

diff  --git a/mlir/test/Dialect/Linalg/drop-unit-extent-dims.mlir b/mlir/test/Dialect/Linalg/drop-unit-extent-dims.mlir
index 1793d2b59b70..e04d03b4e493 100644
--- a/mlir/test/Dialect/Linalg/drop-unit-extent-dims.mlir
+++ b/mlir/test/Dialect/Linalg/drop-unit-extent-dims.mlir
@@ -36,6 +36,47 @@ func @drop_one_trip_loops(%arg0 : tensor<?x1x?xf32>) -> tensor<?x1x?x1x?xf32>
 
 // -----
 
+#accesses = [
+  affine_map<(i, j, k, l, m) -> (i, k, m)>,
+  affine_map<(i, j, k, l, m) -> (i, k, j, l, m)>
+]
+
+#trait = {
+  iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel"],
+  indexing_maps = #accesses,
+  library_call = "some_external_func"
+}
+
+func @drop_one_trip_loops_indexed_generic
+  (%arg0 : tensor<?x1x?xi32>) -> tensor<?x1x?x1x?xi32>
+{
+  %0 = linalg.indexed_generic #trait
+    ins(%arg0 : tensor<?x1x?xi32>) {
+       ^bb0(%arg1 : index, %arg2 : index, %arg3 : index, %arg4 : index,
+            %arg5 : index, %arg6 : i32) :
+	 %1 = addi %arg1, %arg2 : index
+	 %2 = addi %1, %arg3 : index
+	 %3 = addi %2, %arg4 : index
+	 %4 = addi %3, %arg5 : index
+	 %5 = index_cast %4 : index to i32
+	 %6 = addi %5, %arg6 : i32
+         linalg.yield %6 : i32
+       } -> tensor<?x1x?x1x?xi32>
+  return %0 : tensor<?x1x?x1x?xi32>
+}
+// CHECK-LABEL: func @drop_one_trip_loops_indexed_generic
+//       CHECK:   linalg.indexed_generic
+//       CHECK:   ^{{.+}}(
+//  CHECK-SAME:     %[[ARG1:[a-zA-Z0-9]+]]: index, %[[ARG2:[a-zA-Z0-9]+]]: index
+//  CHECK-SAME:     %[[ARG3:[a-zA-Z0-9]+]]: index, %[[ARG4:[a-zA-Z0-9]+]]: i32)
+//       CHECK:     %[[T3:.+]] = addi %[[ARG1]], %[[ARG2]]
+//       CHECK:     %[[T4:.+]] = addi %[[T3]], %[[ARG3]]
+//       CHECK:     %[[T5:.+]] = index_cast %[[T4]] : index to i32
+//       CHECK:     %[[T6:.+]] = addi %[[T5]], %[[ARG4]] : i32
+//       CHECK:     linalg.yield %[[T6]] : i32
+
+// -----
+
 #map0 = affine_map<(i, j) -> (i, j)>
 #access = [#map0, #map0]
 #trait = {
@@ -62,6 +103,35 @@ func @drop_all_loops(%arg0 : tensor<1x1xf32>) -> tensor<1x1xf32>
 
 // -----
 
+#map0 = affine_map<(i, j) -> (i, j)>
+#access = [#map0, #map0]
+#trait = {
+  iterator_types = ["parallel", "parallel"],
+  indexing_maps = #access,
+  library_call = "some_external_func"
+}
+
+func @drop_all_loops_indexed_generic
+  (%arg0 : tensor<1x1xi32>) -> tensor<1x1xi32>
+{
+  %0 = linalg.indexed_generic #trait
+    ins(%arg0 : tensor<1x1xi32>) {
+       ^bb0(%arg1 : index, %arg2 : index, %arg3: i32) :
+         %1 = addi %arg1, %arg2 : index
+	 %2 = index_cast %1 : index to i32
+	 %3 = addi %2, %arg3 : i32
+         linalg.yield %3 : i32
+       } -> tensor<1x1xi32>
+  return %0 : tensor<1x1xi32>
+}
+
+// CHECK-LABEL: func @drop_all_loops_indexed_generic
+//       CHECK:   linalg.indexed_generic
+//       CHECK:   ^{{.+}}(%[[ARG1:.+]]: i32)
+//       CHECK:     linalg.yield %[[ARG1]] : i32
+
+// -----
+
 #accesses = [
   affine_map<(d0) -> (0, d0)>,
   affine_map<(d0) -> (d0)>

diff  --git a/mlir/test/Dialect/Linalg/fusion-tensor.mlir b/mlir/test/Dialect/Linalg/fusion-tensor.mlir
index 40ef68d870ea..54d8bef9caf3 100644
--- a/mlir/test/Dialect/Linalg/fusion-tensor.mlir
+++ b/mlir/test/Dialect/Linalg/fusion-tensor.mlir
@@ -381,3 +381,43 @@ func @indexed_generic_op_fusion(%arg0: tensor<?x?xi32>) {
 //      CHECK:   %[[VAL4:.+]] = subi %[[VAL3]], %[[SUB_OPERAND2]] : i32
 //      CHECK:   linalg.yield %[[VAL4]] : i32
 //   CHECK-NOT: linalg.indexed_generic
+
+// -----
+
+func @scalar_indexed_generic_fusion
+  (%arg0: tensor<5x1x1xf32>, %arg1 : tensor<i32>) -> tensor<10xf32>
+{
+  %c0 = constant 0 : index  
+  %cst = constant dense<1.000000e+00> : tensor<10xf32>
+  %0 = linalg.indexed_generic
+    {indexing_maps = [affine_map<() -> ()>, affine_map<() -> ()>],
+     iterator_types = []}
+    ins(%arg1 : tensor<i32>) {
+    ^bb0(%arg2: i32):  // no predecessors
+      %3 = index_cast %arg2 : i32 to index
+      %4 = extract_element %arg0[%3, %c0, %c0] : tensor<5x1x1xf32>
+      linalg.yield %4 : f32
+    } -> tensor<f32>
+  %1 = linalg.generic
+   {indexing_maps = [affine_map<(d0) -> ()>, affine_map<(d0) -> (d0)>,
+                     affine_map<(d0) -> (d0)>],
+    iterator_types = ["parallel"]}
+    ins(%0, %cst : tensor<f32>, tensor<10xf32>) {
+    ^bb0(%arg2: f32, %arg3: f32):  // no predecessors
+      %3 = mulf %arg2, %arg3 : f32
+      linalg.yield %3 : f32
+    } -> tensor<10xf32>
+  return %1 : tensor<10xf32>
+}
+//   CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0) -> ()>
+//   CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0) -> (d0)>
+//       CHECK: func @scalar_indexed_generic_fusion
+//  CHECK-SAME:   %[[ARG0:[a-zA-Z0-9]+]]: tensor<5x1x1xf32>
+//  CHECK-SAME:   %[[ARG1:[a-zA-Z0-9]+]]: tensor<i32>
+//       CHECK:   %[[T0:.+]] = linalg.indexed_generic
+//  CHECK-SAME:     indexing_maps = [#[[MAP0]], #[[MAP1]]]
+//  CHECK-SAME:     iterator_types = ["parallel"]
+//  CHECK-SAME:     ins(%[[ARG1]] : tensor<i32>)
+//       CHECK:     extract_element %[[ARG0]]
+//       CHECK:     linalg.yield
+//       CHECK   return %[[T0]]
\ No newline at end of file


        


More information about the Mlir-commits mailing list