[Mlir-commits] [mlir] 4882cac - [mlir][linalg] Adapt FillOp to use a scalar operand.

Tobias Gysi llvmlistbot at llvm.org
Mon Jun 21 23:45:20 PDT 2021


Author: Tobias Gysi
Date: 2021-06-22T06:44:52Z
New Revision: 4882cacf129b31129a3d7ffdc0c2ed8d0fb67673

URL: https://github.com/llvm/llvm-project/commit/4882cacf129b31129a3d7ffdc0c2ed8d0fb67673
DIFF: https://github.com/llvm/llvm-project/commit/4882cacf129b31129a3d7ffdc0c2ed8d0fb67673.diff

LOG: [mlir][linalg] Adapt FillOp to use a scalar operand.

Adapt the FillOp definition to use a scalar operand instead of a capture. This patch is a follow-up to https://reviews.llvm.org/D104109. As the input operands are in front of the output operands, the patch changes the internal operand order of the FillOp. The pretty-printed version of the operation remains unchanged, though. The patch also adapts the linalg-to-standard lowering to ensure the C signature of the FillOp remains unchanged as well.

Differential Revision: https://reviews.llvm.org/D104121

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
    mlir/lib/Conversion/LinalgToStandard/LinalgToStandard.cpp
    mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
    mlir/python/mlir/dialects/_linalg_ops_ext.py
    mlir/test/Dialect/Linalg/fusion-tensor-pattern.mlir
    mlir/test/Dialect/Linalg/generalize-named-ops.mlir
    mlir/test/Dialect/Linalg/invalid.mlir
    mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
index 43fc077e2f064..453f6d37f499a 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
@@ -175,14 +175,14 @@ def CopyOp : LinalgStructured_Op<"copy", [CopyOpInterface]> {
 }
 
 def FillOp : LinalgStructured_Op<"fill", []> {
-  let arguments = (ins AnyShaped:$output,
-                   AnyTypeOf<[AnyComplex, AnyFloat, AnySignlessInteger,
-                              AnyVector]>:$value);
+  let arguments = (ins
+    AnyTypeOf<[AnyComplex, AnyFloat, AnySignlessInteger, AnyVector]>:$value,
+    AnyShaped:$output);
   let results = (outs Optional<AnyRankedTensor>:$result);
   let regions = (region AnyRegion:$region);
   let extraClassDeclaration = structuredOpsDecls # [{
-    ValueRange inputs() { return {}; }
-    ValueRange outputs() { return getOperands().take_front(); }
+    ValueRange inputs() { return getOperands().take_front(); }
+    ValueRange outputs() { return getOperands().take_back(); }
 
     // Rank-polymorphic.
     //   filling_value -> O(ivs) with parallel iterators.
@@ -196,6 +196,7 @@ def FillOp : LinalgStructured_Op<"fill", []> {
       MLIRContext *context = getContext();
       // filling_value -> O(ivs)
       return Builder(getContext()).getAffineMapArrayAttr({
+          AffineMap::get(getNumParallelLoops(), 0, {}, getContext()),
           extractOrIdentityMap(llvm::None, getNumParallelLoops(), context)});
     }
 
@@ -206,13 +207,13 @@ def FillOp : LinalgStructured_Op<"fill", []> {
     getRegionBuilder() {
      return &regionBuilder;
     }
-    static unsigned getNumRegionArgs() { return 1; }
+    static unsigned getNumRegionArgs() { return 2; }
   }];
 
   let assemblyFormat = [{
     `(` $output `,` $value `)` attr-dict `:`
         type($output) `,` type($value) (`->` type($result)^)?
-      custom<FillOpRegion>($region, ref(type($output)), ref($value))
+      custom<FillOpRegion>($region, ref(type($output)), ref(type($value)))
   }];
 
   let builders = [

diff  --git a/mlir/lib/Conversion/LinalgToStandard/LinalgToStandard.cpp b/mlir/lib/Conversion/LinalgToStandard/LinalgToStandard.cpp
index 6f422e5f629fb..67a935d4bf0fe 100644
--- a/mlir/lib/Conversion/LinalgToStandard/LinalgToStandard.cpp
+++ b/mlir/lib/Conversion/LinalgToStandard/LinalgToStandard.cpp
@@ -100,9 +100,21 @@ LogicalResult mlir::linalg::LinalgOpToLibraryCallRewrite::matchAndRewrite(
   if (isa<CopyOp>(op))
     return failure();
 
+  // Swap the operand order of the FillOp to maintain the pretty printed
+  // signature that takes an output buffer followed by the fill value.
+  SmallVector<Value> originalOperandOrder = op->getOperands();
+  if (auto fillOp = dyn_cast<FillOp>(op.getOperation())) {
+    Value value = fillOp.value();
+    Value output = fillOp.output();
+    op->setOperands(ValueRange{output, value});
+  }
+
   auto libraryCallName = getLibraryCallSymbolRef(op, rewriter);
-  if (!libraryCallName)
+  if (!libraryCallName) {
+    // Restore the operand order in case it has been modified.
+    op->setOperands(originalOperandOrder);
     return failure();
+  }
 
   // TODO: Add support for more complex library call signatures that include
   // indices or captured values.

diff  --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index ac3c6776134da..1012fb0528e84 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -421,32 +421,29 @@ void CopyOp::getEffects(
 //===----------------------------------------------------------------------===//
 void FillOp::regionBuilder(ImplicitLocOpBuilder &b, Block &block,
                            ValueRange captures) {
-  assert(captures.size() == 1 && "FillOp regionBuilder expects 1 capture");
-  b.create<linalg::YieldOp>(captures);
+  assert(block.getNumArguments() == 2 && "FillOp regionBuilder expects 2 args");
+  b.create<linalg::YieldOp>(block.getArgument(0));
 }
 
 void FillOp::build(OpBuilder &builder, OperationState &result, Value output,
                    Value value) {
-  build(builder, result, output.getType().dyn_cast<RankedTensorType>(), output,
-        value);
-  fillStructuredOpRegion<FillOp>(builder, *result.regions.front(), TypeRange{},
-                                 TypeRange{output.getType()}, value);
+  build(builder, result, output.getType().dyn_cast<RankedTensorType>(), value,
+        output);
+  fillStructuredOpRegion<FillOp>(builder, *result.regions.front(),
+                                 TypeRange{value.getType()},
+                                 TypeRange{output.getType()}, {});
 }
 
 ParseResult parseFillOpRegion(OpAsmParser &parser, Region &r, Type outputType,
-                              OpAsmParser::OperandType valueRef) {
+                              Type valueType) {
   OpBuilder opBuilder(parser.getBuilder().getContext());
-  // Resolve `valueRef` into `value` at parse time so we can build the region
-  // with captures.
-  SmallVector<Value> value;
-  parser.resolveOperand(valueRef, getElementTypeOrSelf(outputType), value);
-  fillStructuredOpRegion<FillOp>(opBuilder, r, TypeRange{},
-                                 TypeRange{outputType}, value);
+  fillStructuredOpRegion<FillOp>(opBuilder, r, TypeRange{valueType},
+                                 TypeRange{outputType});
   return success();
 }
 
 /// FillOp region is elided when printing.
-void printFillOpRegion(OpAsmPrinter &, Operation *, Region &, Type, Value) {}
+void printFillOpRegion(OpAsmPrinter &, Operation *, Region &, Type, Type) {}
 
 static LogicalResult verify(FillOp op) {
   OpOperand *output = op.getOutputOperand(0);

diff  --git a/mlir/python/mlir/dialects/_linalg_ops_ext.py b/mlir/python/mlir/dialects/_linalg_ops_ext.py
index 0aea4e603b9e2..c7ddfb962375d 100644
--- a/mlir/python/mlir/dialects/_linalg_ops_ext.py
+++ b/mlir/python/mlir/dialects/_linalg_ops_ext.py
@@ -10,7 +10,7 @@
 from _mlir.dialects.linalg import fill_builtin_region
 
 
-def isa(cls : Type, ty : Type):
+def isa(cls: Type, ty: Type):
   try:
     cls(ty)
     return True
@@ -21,23 +21,19 @@ def isa(cls : Type, ty : Type):
 class FillOp:
   """Extends the linalg.fill op."""
 
-  def __init__(self,
-               output: Value,
-               value: Value,
-               *,
-               loc=None,
-               ip=None):
+  def __init__(self, output: Value, value: Value, *, loc=None, ip=None):
     results = []
     if isa(RankedTensorType, output.type):
       results = [output.type]
-    op = self.build_generic(results=results,
-                            operands=[output, value],
-                            attributes=None,
-                            loc=loc,
-                            ip=ip)
+    op = self.build_generic(
+        results=results,
+        operands=[value, output],
+        attributes=None,
+        loc=loc,
+        ip=ip)
     OpView.__init__(self, op)
     linalgDialect = Context.current.get_dialect_descriptor("linalg")
-    fill_builtin_region(linalgDialect, self.operation, [value])
+    fill_builtin_region(linalgDialect, self.operation, [])
     # TODO: self.result is None. When len(results) == 1 we expect it to be
     # results[0] as per _linalg_ops_gen.py. This seems like an orthogonal bug
     # in the generator of _linalg_ops_gen.py where we have:
@@ -78,11 +74,12 @@ def __init__(self,
     attributes["static_sizes"] = ArrayAttr.get(
         [IntegerAttr.get(i64_type, s) for s in static_size_ints],
         context=context)
-    op = self.build_generic(results=[result_type],
-                            operands=operands,
-                            attributes=attributes,
-                            loc=loc,
-                            ip=ip)
+    op = self.build_generic(
+        results=[result_type],
+        operands=operands,
+        attributes=attributes,
+        loc=loc,
+        ip=ip)
     OpView.__init__(self, op)
 
 
@@ -91,10 +88,11 @@ class StructuredOpMixin:
 
   def __init__(self, inputs, outputs=(), results=(), loc=None, ip=None):
     super().__init__(
-        self.build_generic(results=list(results),
-                           operands=[list(inputs), list(outputs)],
-                           loc=loc,
-                           ip=ip))
+        self.build_generic(
+            results=list(results),
+            operands=[list(inputs), list(outputs)],
+            loc=loc,
+            ip=ip))
 
 
 def select_opview_mixin(parent_opview_cls):

diff  --git a/mlir/test/Dialect/Linalg/fusion-tensor-pattern.mlir b/mlir/test/Dialect/Linalg/fusion-tensor-pattern.mlir
index 4ea25dffdcef1..b7821aa430aa7 100644
--- a/mlir/test/Dialect/Linalg/fusion-tensor-pattern.mlir
+++ b/mlir/test/Dialect/Linalg/fusion-tensor-pattern.mlir
@@ -296,14 +296,15 @@ module {
 // TLOOP-SAME: (%[[C0]], %[[C0]]) to (%[[DIM_A_0]], %[[DIM_B_1]])
 // TLOOP-SAME: step (%[[C32]], %[[C64]])
 // TLOOP-SAME: ins (%[[A_:.*]] = %[[A]]: [[TY]],
-// TLOOP-SAME:      %[[B_:.*]] = %[[B]]: [[TY]])
+// TLOOP-SAME:      %[[B_:.*]] = %[[B]]: [[TY]],
+// TLOOP-SAME:      %[[C0_F32_:.*]] = %[[C0_F32]]
 // TLOOP-SAME: outs (%[[OUT_:.*]] = %[[OUT]]: [[TY]]) {
 
 // TLOOP:    %[[DIM_A__1:.*]] = memref.dim %[[A_]], %[[C1]] : [[TY]]
 // TLOOP:    %[[A_SUB:.*]] = subtensor %[[A_]][%[[I]], 0]
 // TLOOP:    %[[B_SUB:.*]] = subtensor %[[B_]][0, %[[J]]]
 // TLOOP:    %[[OUT_SUB:.*]] = subtensor %[[OUT_]][%[[I]], %[[J]]]
-// TLOOP:    %[[INIT_SUB:.*]] = linalg.fill(%[[OUT_SUB]], %[[C0_F32]])
+// TLOOP:    %[[INIT_SUB:.*]] = linalg.fill(%[[OUT_SUB]], %[[C0_F32_]])
 
 // TLOOP:    %[[AB_SUB:.*]] = linalg.tiled_loop (%[[K:.*]]) = (%[[C0]])
 // TLOOP-SAME: to (%[[DIM_A__1]]) step (%[[C16]])
@@ -398,3 +399,4 @@ module {
 // TLOOP:    linalg.yield %[[SUB_RESULT]] : [[TY]]
 // TLOOP:  }
 // TLOOP:  return %[[AB]] : [[TY]]
+

diff  --git a/mlir/test/Dialect/Linalg/generalize-named-ops.mlir b/mlir/test/Dialect/Linalg/generalize-named-ops.mlir
index 7e8d1584d38dd..38efd01e04a8f 100644
--- a/mlir/test/Dialect/Linalg/generalize-named-ops.mlir
+++ b/mlir/test/Dialect/Linalg/generalize-named-ops.mlir
@@ -476,15 +476,17 @@ func @generalize_fill(%output: memref<?x?xf32>, %value : f32) {
   return
 }
 
-// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1) -> (d0, d1)>
+// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1) -> ()>
+// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1) -> (d0, d1)>
 
 // CHECK: func @generalize_fill
 // CHECK-SAME: (%[[ARG0:.+]]: memref<?x?xf32>, %[[VAL:.+]]: f32)
 
 // CHECK: linalg.generic
-// CHECK-SAME: indexing_maps = [#[[MAP0]]]
+// CHECK-SAME: indexing_maps = [#[[MAP0]], #[[MAP1]]]
 // CHECK-SAME: iterator_types = ["parallel", "parallel"]}
+// CHECK-SAME: ins(%[[VAL]] : f32)
 // CHECK-SAME: outs(%{{.+}} : memref<?x?xf32>)
 
-// CHECK:         ^{{.+}}(%[[BBARG0:.+]]: f32)
-// CHECK-NEXT:      linalg.yield %[[VAL]] : f32
+// CHECK:         ^{{.+}}(%[[BBARG0:.+]]: f32, %[[BBARG1:.+]]: f32)
+// CHECK-NEXT:      linalg.yield %[[BBARG0]] : f32

diff  --git a/mlir/test/Dialect/Linalg/invalid.mlir b/mlir/test/Dialect/Linalg/invalid.mlir
index aed7e080deaf6..01e09c0968282 100644
--- a/mlir/test/Dialect/Linalg/invalid.mlir
+++ b/mlir/test/Dialect/Linalg/invalid.mlir
@@ -668,7 +668,7 @@ func @illegal_fill_memref_with_tensor_return
 func @illegal_fill_tensor_with_memref_return
   (%arg0 : tensor<?x?xf32>, %arg1 : f32) -> memref<?x?xf32>
 {
-  // expected-error @+1 {{expected type of operand #0 ('tensor<?x?xf32>') to match type of corresponding result ('memref<?x?xf32>')}}
+  // expected-error @+1 {{expected type of operand #1 ('tensor<?x?xf32>') to match type of corresponding result ('memref<?x?xf32>')}}
   %0 = linalg.fill(%arg0, %arg1) : tensor<?x?xf32>, f32 -> memref<?x?xf32>
   return %0 : memref<?x?xf32>
 }

diff  --git a/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp b/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp
index 402a26475a503..2f4c1189c7303 100644
--- a/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp
+++ b/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp
@@ -235,8 +235,8 @@ static void applyPatterns(FuncOp funcOp) {
   patterns.add<LinalgPromotionPattern<FillOp>>(
       ctx,
       LinalgPromotionOptions()
-          .setOperandsToPromote({0})
-          .setUseFullTileBuffers({true})
+          .setOperandsToPromote({1})
+          .setUseFullTileBuffers({false, true})
           .setAlignment(32),
       LinalgTransformationFilter(
           Identifier::get("_promote_views_aligned_", ctx),


        


More information about the Mlir-commits mailing list