[Mlir-commits] [mlir] a2a4bc5 - [mlir][linalg] All StructuredOp parameters are inputs or outputs.

Tobias Gysi llvmlistbot at llvm.org
Tue Jun 29 00:46:51 PDT 2021


Author: Tobias Gysi
Date: 2021-06-29T07:45:50Z
New Revision: a2a4bc561ddf61bd5104674072c79fede3380ab1

URL: https://github.com/llvm/llvm-project/commit/a2a4bc561ddf61bd5104674072c79fede3380ab1
DIFF: https://github.com/llvm/llvm-project/commit/a2a4bc561ddf61bd5104674072c79fede3380ab1.diff

LOG: [mlir][linalg] All StructuredOp parameters are inputs or outputs.

Adapt the StructuredOp verifier to ensure all operands are either in the input or the output group. The change is possible after adding support for scalar input operands (https://reviews.llvm.org/D104220).

Differential Revision: https://reviews.llvm.org/D104783

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td
    mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
    mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
    mlir/lib/Dialect/Linalg/Transforms/Bufferize.cpp
    mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp
    mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp
    mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
    mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td
index ad91e23607141..e1f096d194b2a 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td
@@ -253,7 +253,7 @@ def LinalgStructuredInterface : OpInterface<"LinalgOp"> {
       /*args=*/(ins),
       /*methodBody=*/"",
       /*defaultImplementation=*/[{
-        return getNumInputs() + getNumOutputs();
+        return this->getOperation()->getNumOperands();
       }]
     >,
     //===------------------------------------------------------------------===//
@@ -346,8 +346,7 @@ def LinalgStructuredInterface : OpInterface<"LinalgOp"> {
         result.reserve(numOutputs);
         llvm::transform(
           this->getOperation()->getOpOperands()
-            .drop_front(getNumInputs())
-            .take_front(numOutputs),
+            .take_back(numOutputs),
           std::back_inserter(result),
           [](OpOperand &opOperand) { return &opOperand; });
         return result;
@@ -458,8 +457,7 @@ def LinalgStructuredInterface : OpInterface<"LinalgOp"> {
         OpOperandVector result;
         result.reserve(numInputsAndOutputs);
         llvm::transform(
-          this->getOperation()->getOpOperands()
-            .take_front(numInputsAndOutputs),
+          this->getOperation()->getOpOperands(),
           std::back_inserter(result),
           [](OpOperand &opOperand) { return &opOperand; });
         return result;
@@ -928,22 +926,6 @@ def LinalgStructuredInterface : OpInterface<"LinalgOp"> {
     /// `createFlatListOfOperandStaticDims`.
     SmallVector<int64_t, 4> computeStaticLoopSizes();
 
-    /// Returns all the operands past the inputs, output_buffers and
-    /// init_tensors operands. Asserts that these operands are value types to
-    /// allow transformations like tiling to just use the values when cloning
-    /// `linalgOp`.
-    Operation::operand_range getAssumedNonShapedOperands() {
-      Operation::operand_range res{
-        getOperation()->getOperands().begin() + getNumInputsAndOutputs(),
-        getOperation()->getOperands().end()};
-      for (Type t : TypeRange{res}) {
-        (void)t;
-        assert((t.isSignlessIntOrIndexOrFloat() || t.template isa<VectorType>())
-               &&"expected scalar or vector type");
-      }
-      return res;
-    }
-
     /// Returns the value that expresses the shape of the output in terms of
     /// shape of the input operands where possible
     LogicalResult reifyReturnTypeShapesPerResultDim(OpBuilder &b,

diff  --git a/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
index 45a9f8eb15c7e..e83c62425af4d 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
@@ -318,14 +318,15 @@ LogicalResult LinalgOp::reifyReturnTypeShapesPerResultDim(
 
 LogicalResult mlir::linalg::detail::verifyStructuredOpInterface(Operation *op) {
   LinalgOp linalgOp = cast<LinalgOp>(op);
-  // Expect at least one input/output operand.
+  // Expect at least one output operand.
   // This means an op that constructs a tensor out of indices cannot be a
   // LinalgOp at the moment. For now this will have to be a special op until we
   // have output shape operands that are not tensors.
-  int64_t numInputsAndOutputs = linalgOp.getNumInputsAndOutputs();
-  if (numInputsAndOutputs == 0)
-    return op->emitOpError("expected at least one input/output operand");
-  if (failed(OpTrait::impl::verifyAtLeastNOperands(op, numInputsAndOutputs)))
+  int64_t numInputs = linalgOp.getNumInputs();
+  int64_t numOutputs = linalgOp.getNumOutputs();
+  if (numOutputs == 0)
+    return op->emitOpError("expected at least one output operand");
+  if (failed(OpTrait::impl::verifyNOperands(op, numInputs + numOutputs)))
     return failure();
   // Should have at least one output tensor per result tensor.
  // Can also have output buffers that do not correspond to results.

diff  --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index 11cb3e15c0e0c..f4524f19f3f14 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -3038,8 +3038,6 @@ struct FoldTensorCastOp : public OpInterfaceRewritePattern<LinalgOp> {
                                  : opOperand->get());
       newResultTypes.push_back(newOperands.back().getType());
     }
-    auto extraOperands = op.getAssumedNonShapedOperands();
-    newOperands.append(extraOperands.begin(), extraOperands.end());
     // Clone op.
     Operation *newOp =
         op.clone(rewriter, op->getLoc(), newResultTypes, newOperands);
@@ -3109,7 +3107,6 @@ struct DeduplicateInputs : public OpInterfaceRewritePattern<LinalgOp> {
         newOperands.push_back(opOperand->get());
     SmallVector<Value> outputOperands = op.getOutputOperands();
     llvm::append_range(newOperands, outputOperands);
-    llvm::append_range(newOperands, op.getAssumedNonShapedOperands());
 
     // Repair the indexing maps by filtering out the ones that have been
     // eliminated.

diff  --git a/mlir/lib/Dialect/Linalg/Transforms/Bufferize.cpp b/mlir/lib/Dialect/Linalg/Transforms/Bufferize.cpp
index 414aa632d4e86..fba709a871525 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Bufferize.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Bufferize.cpp
@@ -119,8 +119,6 @@ static void finalizeBufferAllocation(ConversionPatternRewriter &rewriter,
   assert(!isa<linalg::GenericOp>(linalgOp.getOperation()));
   SmallVector<Value, 8> newOperands = inputs;
   newOperands.append(outputs.begin(), outputs.end());
-  auto otherOperands = linalgOp.getAssumedNonShapedOperands();
-  newOperands.append(otherOperands.begin(), otherOperands.end());
   linalgOp.clone(rewriter, linalgOp.getLoc(),
                  /*resultTypes=*/ArrayRef<Type>{}, newOperands);
   // Replace the results of the old op with the new output buffers.

diff  --git a/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp b/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp
index c951e70f18d83..287d2d47ca7fe 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp
@@ -1241,8 +1241,6 @@ static LogicalResult bufferize(OpBuilder &b, LinalgOp op,
   // Clone the newly bufferized op.
   SmallVector<Value> newOperands = newInputBuffers;
   newOperands.append(newOutputBuffers.begin(), newOutputBuffers.end());
-  auto otherOperands = op.getAssumedNonShapedOperands();
-  newOperands.append(otherOperands.begin(), otherOperands.end());
   op.clone(b, loc, /*resultTypes=*/TypeRange{}, newOperands);
 
   // Replace the results of the old op with the new output buffers.

diff  --git a/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp
index 0ff0594168a0b..d5964951f3c36 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp
@@ -205,10 +205,6 @@ static LinalgOp fuse(OpBuilder &b, LinalgOp producer,
                                       getTiledOperands(b, producer), ivs,
                                       tileSizes, sizeBounds));
 
-  // Append the other operands.
-  auto operands = producer.getAssumedNonShapedOperands();
-  clonedShapes.append(operands.begin(), operands.end());
-
   // Iterate over the results in order.
   // Extract the subtensor type from the linearized range.
   // Since we do not enforce any canonicalizations on the fly, this is always

diff  --git a/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp b/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
index a9366d1a271d3..b6420f7b104bc 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
@@ -242,8 +242,6 @@ tileLinalgOpImpl(OpBuilder &b, LinalgOp op, ValueRange tileSizes,
         applyMapToValues(b, loc, shapeSizesToLoopsMap, allShapeSizes);
     SmallVector<Value, 4> tiledOperands = makeTiledShapes(
         b, loc, op, operands, interchangedIvs, tileSizes, sizeBounds);
-    auto nonShapedOperands = op.getAssumedNonShapedOperands();
-    tiledOperands.append(nonShapedOperands.begin(), nonShapedOperands.end());
 
     // TODO: use an interface/adaptor to avoid leaking position in
     // `tiledOperands`.

diff  --git a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
index 79335c3629d70..f1c8a6f7b0fd6 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
@@ -190,8 +190,6 @@ static LogicalResult rewriteAsPaddedOp(PatternRewriter &rewriter,
   // Clone `opToPad` to operate on the statically padded shapes.
   auto resultTensorTypes =
       ValueRange(newOperands).take_back(opToPad.getNumOutputs()).getTypes();
-  ValueRange otherOperands = opToPad.getAssumedNonShapedOperands();
-  newOperands.append(otherOperands.begin(), otherOperands.end());
   linalg::LinalgOp paddedOp =
       opToPad.clone(rewriter, loc, resultTensorTypes, newOperands);
 


        


More information about the Mlir-commits mailing list