[Mlir-commits] [mlir] 7cef24e - [mlir][linalg] Adapt the FillOp builder signature.

Tobias Gysi llvmlistbot at llvm.org
Wed Jun 23 01:07:49 PDT 2021


Author: Tobias Gysi
Date: 2021-06-23T08:06:43Z
New Revision: 7cef24ee83cd85b0d7917e7d574287615a717e44

URL: https://github.com/llvm/llvm-project/commit/7cef24ee83cd85b0d7917e7d574287615a717e44
DIFF: https://github.com/llvm/llvm-project/commit/7cef24ee83cd85b0d7917e7d574287615a717e44.diff

LOG: [mlir][linalg] Adapt the FillOp builder signature.

Change the build operand order from output, value to value, output. The patch makes the argument order consistent with the pretty printed order updated by https://reviews.llvm.org/D104356.

Differential Revision: https://reviews.llvm.org/D104359

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
    mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
    mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
    mlir/lib/Dialect/Linalg/Transforms/Bufferize.cpp
    mlir/lib/Dialect/Linalg/Transforms/Promotion.cpp
    mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
    mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
    mlir/lib/Dialect/Vector/VectorTransforms.cpp
    mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp

Removed: 
    


################################################################################
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
index 4720582586d1..f83f484187c9 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
@@ -217,7 +217,7 @@ def FillOp : LinalgStructured_Op<"fill", []> {
   }];
 
   let builders = [
-    OpBuilder<(ins "Value":$output, "Value":$value)>
+    OpBuilder<(ins "Value":$value, "Value":$output)>
   ];
 
   let verifier = [{ return ::verify(*this); }];

diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
index 4553292e8699..be6b31d13465 100644
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
@@ -813,7 +813,7 @@ static LogicalResult reduceMatchAndRewriteHelper(Operation *op, uint64_t axis,
 
   auto fillValue = rewriter.create<ConstantOp>(loc, fillValueAttr);
   auto filledTensor =
-      rewriter.create<linalg::FillOp>(loc, initTensor, fillValue).result();
+      rewriter.create<linalg::FillOp>(loc, fillValue, initTensor).result();
 
   SmallVector<AffineExpr, 2> srcExprs;
   SmallVector<AffineExpr, 2> dstExprs;
@@ -1018,7 +1018,7 @@ class MatMulConverter : public OpConversionPattern<tosa::MatMulOp> {
     auto initTensor = rewriter.create<linalg::InitTensorOp>(
         loc, outputTy.getShape(), outputTy.getElementType());
     Value zeroTensor =
-        rewriter.create<linalg::FillOp>(loc, initTensor, zero).getResult(0);
+        rewriter.create<linalg::FillOp>(loc, zero, initTensor).getResult(0);
     rewriter.replaceOpWithNewOp<linalg::BatchMatmulOp>(
         op, TypeRange{op.getType()}, ValueRange{adaptor.a(), adaptor.b()},
         ValueRange{zeroTensor});
@@ -1093,7 +1093,6 @@ class FullyConnectedConverter
   }
 };
 
-
 class ReshapeConverter : public OpConversionPattern<tosa::ReshapeOp> {
 public:
   using OpConversionPattern<tosa::ReshapeOp>::OpConversionPattern;
@@ -1737,7 +1736,7 @@ struct ConcatConverter : public OpConversionPattern<tosa::ConcatOp> {
     Value zeroVal = rewriter.create<ConstantOp>(
         loc, rewriter.getZeroAttr(resultType.getElementType()));
     Value result =
-        rewriter.create<linalg::FillOp>(loc, init, zeroVal).getResult(0);
+        rewriter.create<linalg::FillOp>(loc, zeroVal, init).getResult(0);
 
     for (auto arg : args) {
       sizes[axis] = rewriter.create<memref::DimOp>(loc, arg, axisValue);
@@ -1981,7 +1980,7 @@ class ArgMaxConverter : public OpRewritePattern<tosa::ArgMaxOp> {
     auto fillValueIdx = rewriter.create<ConstantOp>(
         loc, rewriter.getIntegerAttr(outElementTy, 0));
     auto filledTensorIdx =
-        rewriter.create<linalg::FillOp>(loc, initTensorIdx, fillValueIdx)
+        rewriter.create<linalg::FillOp>(loc, fillValueIdx, initTensorIdx)
             .result();
 
     // Second fill the output buffer for the running max.
@@ -1999,7 +1998,7 @@ class ArgMaxConverter : public OpRewritePattern<tosa::ArgMaxOp> {
 
     auto fillValueMax = rewriter.create<ConstantOp>(loc, fillValueMaxAttr);
     auto filledTensorMax =
-        rewriter.create<linalg::FillOp>(loc, initTensorMax, fillValueMax)
+        rewriter.create<linalg::FillOp>(loc, fillValueMax, initTensorMax)
             .result();
 
     // We need to reduce along the arg-max axis, with parallel operations along
@@ -2288,7 +2287,7 @@ class Pool2dConverter : public OpRewritePattern<SrcOp> {
         loc, resultTy.getShape(), resultTy.getElementType());
 
     Value filledInitTensor =
-        rewriter.create<linalg::FillOp>(loc, initTensor, initialValue).result();
+        rewriter.create<linalg::FillOp>(loc, initialValue, initTensor).result();
 
     Value fakeWindowDims =
         rewriter.create<linalg::InitTensorOp>(loc, kernel, outElementTy);

diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index ba3066c5398b..9a1ceebba97d 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -426,8 +426,8 @@ void FillOp::regionBuilder(ImplicitLocOpBuilder &b, Block &block,
   b.create<linalg::YieldOp>(block.getArgument(0));
 }
 
-void FillOp::build(OpBuilder &builder, OperationState &result, Value output,
-                   Value value) {
+void FillOp::build(OpBuilder &builder, OperationState &result, Value value,
+                   Value output) {
   build(builder, result, output.getType().dyn_cast<RankedTensorType>(), value,
         output);
   fillStructuredOpRegion<FillOp>(builder, *result.regions.front(),
@@ -1966,7 +1966,7 @@ struct FoldFillWithTensorReshape : OpRewritePattern<TensorReshapeOp> {
     auto newInit = rewriter.create<TensorReshapeOp>(
         loc, reshapeOp.getResultType(), oldFill.output(),
         reshapeOp.reassociation());
-    rewriter.replaceOpWithNewOp<FillOp>(reshapeOp, newInit, oldFill.value());
+    rewriter.replaceOpWithNewOp<FillOp>(reshapeOp, oldFill.value(), newInit);
 
     return success();
   }

diff --git a/mlir/lib/Dialect/Linalg/Transforms/Bufferize.cpp b/mlir/lib/Dialect/Linalg/Transforms/Bufferize.cpp
index 577a79b73cd0..414aa632d4e8 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Bufferize.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Bufferize.cpp
@@ -187,7 +187,7 @@ class BufferizeFillOp : public OpConversionPattern<FillOp> {
       return rewriter.notifyMatchFailure(op,
                                          "operand must be of a tensor type");
 
-    rewriter.create<FillOp>(op.getLoc(), adaptor.output(), adaptor.value());
+    rewriter.create<FillOp>(op.getLoc(), adaptor.value(), adaptor.output());
     rewriter.replaceOp(op, adaptor.output());
 
     return success();

diff --git a/mlir/lib/Dialect/Linalg/Transforms/Promotion.cpp b/mlir/lib/Dialect/Linalg/Transforms/Promotion.cpp
index 5ab066a3471a..2ec183fec8e2 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Promotion.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Promotion.cpp
@@ -283,7 +283,7 @@ promoteSubViews(ImplicitLocOpBuilder &b,
             .Default([](auto) { return Value(); });
     if (!fillVal)
       return {};
-    b.create<linalg::FillOp>(promotionInfo->fullLocalView, fillVal);
+    b.create<linalg::FillOp>(fillVal, promotionInfo->fullLocalView);
   }
 
   // Copy data into the promoted buffers. Use callback if provided.

diff --git a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
index 0d6f672cd79d..79335c3629d7 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
@@ -676,7 +676,7 @@ LogicalResult PadTensorOpTransformationPattern::matchAndRewrite(
 
   // Initialize tensor with the pad value
   Value tmpTensor =
-      rewriter.create<linalg::FillOp>(loc, initTensor, padValue).result();
+      rewriter.create<linalg::FillOp>(loc, padValue, initTensor).result();
 
   // Copy original contents into new tensor
   // Uses linalg.generic, but could be done with tensor.insert_slice

diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
index 38fcf7d9c6ec..829b988dbad7 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -756,7 +756,7 @@ struct GenericPadTensorOpVectorizationPattern
     // pattern.)
     auto padValue = padOp.getConstantPaddingValue();
     if (padValue)
-      return rewriter.create<FillOp>(padOp.getLoc(), dest, padValue).result();
+      return rewriter.create<FillOp>(padOp.getLoc(), padValue, dest).result();
 
     // Fill could not be vectorized: Lower to tensor::GenerateOp with region.
     auto generateOp = rewriter.create<tensor::GenerateOp>(

diff --git a/mlir/lib/Dialect/Vector/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/VectorTransforms.cpp
index 27b1a00cfdde..e04d48d6ca84 100644
--- a/mlir/lib/Dialect/Vector/VectorTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/VectorTransforms.cpp
@@ -2440,7 +2440,7 @@ createFullPartialLinalgCopy(OpBuilder &b, vector::TransferReadOp xferOp,
         b.create<scf::YieldOp>(loc, viewAndIndices);
       },
       [&](OpBuilder &b, Location loc) {
-        b.create<linalg::FillOp>(loc, alloc, xferOp.padding());
+        b.create<linalg::FillOp>(loc, xferOp.padding(), alloc);
         // Take partial subview of memref which guarantees no dimension
         // overflows.
         Value memRefSubView = createSubViewIntersection(

diff --git a/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp b/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp
index 43eabd64bbe4..45e7c8235067 100644
--- a/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp
+++ b/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp
@@ -309,10 +309,11 @@ static LogicalResult copyCallBackFn(OpBuilder &b, Value src, Value dst,
   auto floatType = src.getType().cast<MemRefType>().getElementType();
   if (!floatType.isa<FloatType>())
     return failure();
-  if (!isOutput)
-    b.create<FillOp>(
-        src.getLoc(), dst,
-        b.create<ConstantOp>(src.getLoc(), FloatAttr::get(floatType, 42.0)));
+  if (!isOutput) {
+    Value cst =
+        b.create<ConstantOp>(src.getLoc(), FloatAttr::get(floatType, 42.0));
+    b.create<FillOp>(src.getLoc(), cst, dst);
+  }
   b.create<CopyOp>(src.getLoc(), src, dst);
   return success();
 }


        


More information about the Mlir-commits mailing list