[Mlir-commits] [mlir] [mlir][linalg] Enable CollapseLinalgDimensions to collapse linalg::CopyOp (PR #68526)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Sun Oct 8 09:00:21 PDT 2023


================
@@ -1467,80 +1490,97 @@ FailureOr<SmallVector<Value>> mlir::linalg::collapseGenericOpIterationDims(
                opFoldIsConstantValue(range.stride, 1);
       })) {
     return rewriter.notifyMatchFailure(
-        genericOp,
-        "expected all loop ranges to have zero start and unit stride");
+        op, "expected all loop ranges to have zero start and unit stride");
   }
 
   // Get the iterator types for the operand.
-  SmallVector<utils::IteratorType> iteratorTypes = getCollapsedOpIteratorTypes(
-      genericOp.getIteratorTypesArray(), collapsingInfo);
+  SmallVector<utils::IteratorType> iteratorTypes =
+      getCollapsedOpIteratorTypes(op.getIteratorTypesArray(), collapsingInfo);
 
   // Get the indexing maps.
   auto indexingMaps = llvm::to_vector(
-      llvm::map_range(genericOp.getIndexingMapsArray(), [&](AffineMap map) {
+      llvm::map_range(op.getIndexingMapsArray(), [&](AffineMap map) {
         return getCollapsedOpIndexingMap(map, collapsingInfo);
       }));
 
-  Location loc = genericOp->getLoc();
+  Location loc = op->getLoc();
 
   // Get the input operands.
-  auto inputOperands = llvm::to_vector(llvm::map_range(
-      genericOp.getDpsInputOperands(), [&](OpOperand *opOperand) {
-        return getCollapsedOpOperand(loc, genericOp, opOperand, collapsingInfo,
+  auto inputOperands = llvm::to_vector(
+      llvm::map_range(op.getDpsInputOperands(), [&](OpOperand *opOperand) {
+        return getCollapsedOpOperand(loc, op, opOperand, collapsingInfo,
                                      rewriter);
       }));
 
   // Get the output operands and result types.
   SmallVector<Type> resultTypes;
   SmallVector<Value> outputOperands;
-  resultTypes.reserve(genericOp.getNumDpsInits());
-  outputOperands.reserve(genericOp.getNumDpsInits());
-  for (OpOperand &output : genericOp.getDpsInitsMutable()) {
-    Value newOutput = getCollapsedOpOperand(loc, genericOp, &output,
-                                            collapsingInfo, rewriter);
+  resultTypes.reserve(op.getNumDpsInits());
+  outputOperands.reserve(op.getNumDpsInits());
+  for (OpOperand &output : op.getDpsInitsMutable()) {
+    Value newOutput =
+        getCollapsedOpOperand(loc, op, &output, collapsingInfo, rewriter);
     outputOperands.push_back(newOutput);
-    resultTypes.push_back(newOutput.getType());
+    // If the op has "buffer semantics", then the init operands are ranked
+    // memrefs and the op has no results.
+    if (!hasBufferSemantics)
+      resultTypes.push_back(newOutput.getType());
   }
 
   // Create the generic op.
-  auto collapsedGenericOp = rewriter.create<linalg::GenericOp>(
-      loc, resultTypes, inputOperands, outputOperands, indexingMaps,
-      iteratorTypes, [](OpBuilder &builder, Location loc, ValueRange args) {});
-  Block *origOpBlock = &genericOp->getRegion(0).front();
-  Block *collapsedOpBlock = &collapsedGenericOp->getRegion(0).front();
-  rewriter.mergeBlocks(origOpBlock, collapsedOpBlock,
-                       collapsedOpBlock->getArguments());
-
-  if (collapsedGenericOp.hasIndexSemantics()) {
+  Operation *collapsedOp;
----------------
MaheshRavishankar wrote:

Could you move this into a utility function?

https://github.com/llvm/llvm-project/pull/68526


More information about the Mlir-commits mailing list