[Mlir-commits] [mlir] f543122 - [mlir][Linalg] Drop function attribute from generic ops.

Nicolas Vasilache llvmlistbot at llvm.org
Thu Apr 16 06:51:37 PDT 2020


Author: Nicolas Vasilache
Date: 2020-04-16T09:47:08-04:00
New Revision: f54312277cdbc9e52657ec904ca4c1c333208c43

URL: https://github.com/llvm/llvm-project/commit/f54312277cdbc9e52657ec904ca4c1c333208c43
DIFF: https://github.com/llvm/llvm-project/commit/f54312277cdbc9e52657ec904ca4c1c333208c43.diff

LOG: [mlir][Linalg] Drop function attribute from generic ops.

The `fun` function attribute on generic ops does not pay for itself.
A region is the more standard way of specifying a custom computation.
If needed, the region can call a function directly.
This is deemed more natural than managing a dedicated function attribute.

This also simplifies the generation of named ops by removing unnecessary complexity.

Differential Revision: https://reviews.llvm.org/D78266

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
    mlir/include/mlir/Dialect/Utils/StructuredOpsUtils.h
    mlir/lib/Dialect/Linalg/EDSC/Builders.cpp
    mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
    mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp
    mlir/lib/Dialect/Linalg/Transforms/LinalgToLoops.cpp
    mlir/test/Dialect/Linalg/invalid.mlir
    mlir/test/Dialect/Linalg/loops.mlir
    mlir/test/Dialect/Linalg/roundtrip.mlir
    mlir/test/Dialect/Linalg/transform-patterns.mlir
    mlir/test/lib/DeclarativeTransforms/TestLinalgTransformPatterns.td

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
index 0ff455391cb4..61d909139f1b 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
@@ -523,7 +523,6 @@ class GenericOpBase<string mnemonic> : LinalgStructuredBase_Op<mnemonic, []> {
                    AffineMapArrayAttr:$indexing_maps,
                    ArrayAttr:$iterator_types,
                    OptionalAttr<StrAttr>:$doc,
-                   OptionalAttr<FlatSymbolRefAttr>:$fun,
                    OptionalAttr<StrAttr>:$library_call);
   let results = (outs Variadic<AnyRankedTensor>:$output_tensors);
   let regions = (region AnyRegion:$region);
@@ -531,7 +530,7 @@ class GenericOpBase<string mnemonic> : LinalgStructuredBase_Op<mnemonic, []> {
     SmallVector<StringRef, 8> linalgTraitAttrNames() {
       return SmallVector<StringRef, 8>{
         getArgsInAttrName(), getArgsOutAttrName(), getDocAttrName(),
-        getFunAttrName(), getIndexingMapsAttrName(), getLibraryCallAttrName(),
+        getIndexingMapsAttrName(), getLibraryCallAttrName(),
         getIteratorTypesAttrName()
       };
     }
@@ -540,12 +539,6 @@ class GenericOpBase<string mnemonic> : LinalgStructuredBase_Op<mnemonic, []> {
 
     unsigned getNumOutputs() { return args_out().getSExtValue(); }
 
-    FuncOp getFunction() {
-      auto moduleOp = getParentOfType<ModuleOp>();
-      return fun().hasValue() ?
-        moduleOp.lookupSymbol<FuncOp>(fun().getValue()) : FuncOp();
-    }
-
     StringRef getLibraryCallName() {
       return library_call().hasValue() ? library_call().getValue() : "";
     }
@@ -581,13 +574,6 @@ def GenericOp : GenericOpBase<"generic"> {
       - args_in: an I64Attr representing the number of input (readonly) views
       - args_out: an I64Attr representing the number of output (readwrite) views
       - doc [optional]: a documentation string
-      - fun: a FlatSymbolRefAttr that must resolve to an existing function
-        symbol. To support inplace updates in a generic fashion, the signature
-        of the function must be:
-        ```
-          fun([input views element types], [output views element types])
-            -> ([output views element types])
-        ```
       - indexing_maps: a list of AffineMapAttr, one AffineMapAttr per each input
         and output view. Such AffineMapAttr specifies the mapping between the
         loops and the indexing within each view.
@@ -604,11 +590,6 @@ def GenericOp : GenericOpBase<"generic"> {
     Example:
     Defining a #matmul_trait attribute in MLIR can be done as follows:
       ```mlir
-      func @fma(%a: f32, %b: f32, %c: f32) -> f32 {
-        %d = mulf %a, %b: f32
-        %e = addf %c, %d: f32
-        return %e: f32
-      }
       #matmul_accesses = [
         (m, n, k) -> (m, k),
         (m, n, k) -> (k, n),
@@ -616,7 +597,6 @@ def GenericOp : GenericOpBase<"generic"> {
       ]
       #matmul_trait = {
         doc = "C(m, n) += A(m, k) * B(k, n)",
-        fun = @fma,
         indexing_maps = #matmul_accesses,
         library_call = "linalg_matmul",
         n_views = [2, 1],
@@ -626,10 +606,14 @@ def GenericOp : GenericOpBase<"generic"> {
 
     And can be reused in multiple places as:
       ```mlir
-      linalg.generic #matmul_trait %A, %B, %C [other-attributes] :
-        memref<?x?xf32, stride_specification>,
-        memref<?x?xf32, stride_specification>,
-        memref<?x?xf32, stride_specification>
+      linalg.generic #matmul_trait %A, %B, %C [other-attributes] {
+        (%a: f32, %b: f32, %c: f32) :
+          %d = mulf %a, %b: f32
+          %e = addf %c, %d: f32
+          linalg_yield %e : f32
+      } : memref<?x?xf32, stride_specification>,
+          memref<?x?xf32, stride_specification>,
+          memref<?x?xf32, stride_specification>
       ```
 
     This may lower to either:
@@ -649,9 +633,9 @@ def GenericOp : GenericOpBase<"generic"> {
           %a = load %A[%m, %k] : memref<?x?xf32, stride_specification>
           %b = load %B[%k, %n] : memref<?x?xf32, stride_specification>
           %c = load %C[%m, %n] : memref<?x?xf32, stride_specification>
-          %d = call @func_of_elements(%a, %b, %c)
-                 : (f32, f32, f32) -> (f32)
-          store %d, %C[%m, %n] : memref<?x?x?xf32, stride_specification>
+          %d = mulf %a, %b: f32
+          %e = addf %c, %d: f32
+          store %e, %C[%m, %n] : memref<?x?x?xf32, stride_specification>
         }
       }
     }
@@ -662,7 +646,7 @@ def GenericOp : GenericOpBase<"generic"> {
     mixing input and output ranked tensor values with input and output memrefs.
 
     ```mlir
-    %C = linalg.generic #trait_attribute %A, %B {other-attributes} :
+    %C = linalg.generic #trait_attribute %A, %B {other-attributes} {region} :
       tensor<?x?xf32>,
       memref<?x?xf32, stride_specification>
       -> (tensor<?x?xf32>)
@@ -708,13 +692,6 @@ def IndexedGenericOp : GenericOpBase<"indexed_generic"> {
       - args_in: an I64Attr representing the number of input (readonly) views
       - args_out: an I64Attr representing the number of output (readwrite) views
       - doc [optional]: a documentation string
-      - fun: a FlatSymbolRefAttr that must resolve to an existing function
-        symbol. To support inplace updates in a generic fashion, the signature
-        of the function must be:
-        ```
-          fun([index types of induction variables], [input views element types],
-              [output views element types]) -> ([output views element types])
-        ```
       - indexing_maps: a list of AffineMapAttr, one AffineMapAttr per each input
         and output view. Such AffineMapAttr specifies the mapping between the
         loops and the indexing within each view.
@@ -732,15 +709,6 @@ def IndexedGenericOp : GenericOpBase<"indexed_generic"> {
     Defining a #matmul_trait attribute in MLIR can be done as follows:
 
     ```mlir
-    func @fma(%offset_m: index, %offset_n: index, %offset_k: index,
-              %a: f32, %b: f32, %c: f32)
-      -> f32
-    {
-      "some_optional_condition"(%offset_m, %offset_n, %offset_k)
-      %d = mulf %a, %b: f32
-      %e = addf %c, %d: f32
-      return %e: f32
-    }
     #matmul_accesses = [
       (m, n, k) -> (m, k),
       (m, n, k) -> (k, n),
@@ -748,7 +716,6 @@ def IndexedGenericOp : GenericOpBase<"indexed_generic"> {
     ]
     #matmul_trait = {
       doc = "C(m, n) += A(m, k) * B(k, n)",
-      fun = @fma,
       indexing_maps = #matmul_accesses,
       library_call = "linalg_matmul",
       n_views = [2, 1],
@@ -759,10 +726,16 @@ def IndexedGenericOp : GenericOpBase<"indexed_generic"> {
     And can be reused in multiple places as:
 
     ```mlir
-    linalg.indexed_generic #matmul_trait %A, %B, %C [other-attributes] :
-      memref<?x?xf32, stride_specification>,
-      memref<?x?xf32, stride_specification>,
-      memref<?x?xf32, stride_specification>
+    linalg.indexed_generic #matmul_trait %A, %B, %C [other-attributes] {
+      (%offset_m: index, %offset_n: index, %offset_k: index,
+       %a: f32, %b: f32, %c: f32) :
+        "some_optional_computation"(%offset_m, %offset_n, %offset_k)
+        %d = mulf %a, %b: f32
+        %e = addf %c, %d: f32
+        linalg_yield %e : f32
+    } : memref<?x?xf32, stride_specification>,
+        memref<?x?xf32, stride_specification>,
+        memref<?x?xf32, stride_specification>
     ```
 
     This may lower to either:
@@ -784,8 +757,9 @@ def IndexedGenericOp : GenericOpBase<"indexed_generic"> {
           %a = load %A[%m, %k] : memref<?x?xf32, stride_specification>
           %b = load %B[%k, %n] : memref<?x?xf32, stride_specification>
           %c = load %C[%m, %n] : memref<?x?xf32, stride_specification>
-          %d = call @func_of_elements_and_indices(%m, %n, %k, %a, %b, %c)
-                 : (index, index, index, f32, f32, f32) -> (f32)
+          "some_optional_computation"(%m, %n, %k)
+          %d = mulf %a, %b: f32
+          %e = addf %c, %d: f32
           store %d, %C[%m, %n] : memref<?x?x?xf32, stride_specification>
         }
       }

diff  --git a/mlir/include/mlir/Dialect/Utils/StructuredOpsUtils.h b/mlir/include/mlir/Dialect/Utils/StructuredOpsUtils.h
index 6262e7757c6c..5a36aabfab75 100644
--- a/mlir/include/mlir/Dialect/Utils/StructuredOpsUtils.h
+++ b/mlir/include/mlir/Dialect/Utils/StructuredOpsUtils.h
@@ -66,10 +66,6 @@ constexpr StringRef getArgsOutAttrName() { return "args_out"; }
 /// string of the structured op.
 constexpr StringRef getDocAttrName() { return "doc"; }
 
-/// Attribute name for the StrArrayAttr which encodes the SymbolAttr for the
-/// MLIR function that implements the body of the structured op.
-constexpr StringRef getFunAttrName() { return "fun"; }
-
 /// Attribute name for the StrArrayAttr which encodes the external library
 /// function that implements the structured op.
 constexpr StringRef getLibraryCallAttrName() { return "library_call"; }

diff  --git a/mlir/lib/Dialect/Linalg/EDSC/Builders.cpp b/mlir/lib/Dialect/Linalg/EDSC/Builders.cpp
index 59a565c5e395..2fa09b7422a9 100644
--- a/mlir/lib/Dialect/Linalg/EDSC/Builders.cpp
+++ b/mlir/lib/Dialect/Linalg/EDSC/Builders.cpp
@@ -177,7 +177,6 @@ Operation *mlir::edsc::makeGenericLinalgOp(
               builder.getAffineMapArrayAttr(maps),
               builder.getStrArrayAttr(iteratorStrTypes),
               StringAttr() /*doc*/,
-              FlatSymbolRefAttr() /*fun*/,
               StringAttr() /*library_call*/
               /* TODO: other attributes in op */
               )

diff  --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index df1a957d344c..9f664586453a 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -133,10 +133,11 @@ static void printGenericOp(OpAsmPrinter &p, GenericOpType op) {
       attrs.push_back(attr);
 
   auto dictAttr = DictionaryAttr::get(attrs, op.getContext());
-  p << op.getOperationName() << " " << dictAttr << " " << op.getOperands();
+  p << op.getOperationName() << " " << dictAttr;
+  p.printOptionalAttrDict(op.getAttrs(), attrNames);
+  p << " " << op.getOperands();
   if (!op.region().empty())
     p.printRegion(op.region());
-  p.printOptionalAttrDict(op.getAttrs(), attrNames);
   p << ": " << op.getOperandTypes();
   auto outputTensorTypes = op.getResultTypes();
   if (!outputTensorTypes.empty())
@@ -156,21 +157,21 @@ static ParseResult parseGenericOp(OpAsmParser &parser, OperationState &result) {
   // The name is unimportant as we will overwrite result.attributes.
   // The core linalg traits must contain the information necessary to pass the
   // verifier.
-  if (parser.parseAttribute(dictAttr, "_", result.attributes) ||
-      parser.parseOperandList(operandsInfo))
+  if (parser.parseAttribute(dictAttr, "_", result.attributes))
     return failure();
   result.attributes.assign(dictAttr.getValue().begin(),
                            dictAttr.getValue().end());
 
+  // Optional attributes may be added.
+  if (parser.parseOptionalAttrDict(result.attributes) ||
+      parser.parseOperandList(operandsInfo))
+    return failure();
+
   Region &region = *result.addRegion();
   SmallVector<Type, 8> operandTypes, regionTypes;
-  // Optional attributes may be added.
-  // Either Optional getFunAttrName() attribute or region must be specified.
-  if (!dictAttr.get(getFunAttrName()) &&
-      parser.parseOptionalRegion(region, regionOperandsInfo, regionTypes))
+  if (parser.parseRegion(region, regionOperandsInfo, regionTypes))
     return failure();
-  if (parser.parseOptionalAttrDict(result.attributes) ||
-      parser.parseColonTypeList(operandTypes))
+  if (parser.parseColonTypeList(operandTypes))
     return failure();
   // Generic ops may specify that a subset of its outputs are tensors. Such
   // outputs are specified in the result type.
@@ -183,10 +184,7 @@ static ParseResult parseGenericOp(OpAsmParser &parser, OperationState &result) {
                                 parser.getCurrentLocation(), result.operands);
 }
 
-template <typename GenericOpType>
-static LogicalResult verifyBlockArgs(GenericOpType op, Block &block);
-
-template <> LogicalResult verifyBlockArgs(GenericOp op, Block &block) {
+LogicalResult verifyBlockArgs(GenericOp op, Block &block) {
   auto nOperands = op.getNumOperands();
   if (block.getNumArguments() != nOperands)
     return op.emitOpError("expected number of block arguments to match number "
@@ -205,7 +203,7 @@ template <> LogicalResult verifyBlockArgs(GenericOp op, Block &block) {
   return success();
 }
 
-template <> LogicalResult verifyBlockArgs(IndexedGenericOp op, Block &block) {
+LogicalResult verifyBlockArgs(IndexedGenericOp op, Block &block) {
   auto nInputViews = op.getNumInputs();
   auto nLoops = op.getNumLoops();
   auto nOperands = op.getNumOperands();
@@ -234,81 +232,6 @@ template <> LogicalResult verifyBlockArgs(IndexedGenericOp op, Block &block) {
   return success();
 }
 
-template <typename GenericOpType>
-static LogicalResult verifyFuncArgs(GenericOpType op, FunctionType funType);
-
-template <typename GenericOpType>
-static LogicalResult verifyFuncArgsGeneric(GenericOpType op,
-                                           FunctionType funType) {
-  auto res = verifyFuncArgs(op, funType);
-  if (failed(res))
-    return res;
-
-  auto nInputs = op.getNumInputs();
-  auto nOutputs = op.getNumOutputs();
-  // linalg.generic output element types are exactly the function results.
-  for (unsigned idx = 0; idx < nOutputs; ++idx) {
-    ShapedType shapedType = op.getShapedType(nInputs + idx);
-    if (funType.getResult(idx) != shapedType.getElementType())
-      return op.emitOpError("expected function result ")
-             << (idx + 1) << " of the same type as elemental type "
-             << shapedType.getElementType() << " of output " << (idx + 1);
-  }
-  return success();
-}
-
-template <> LogicalResult verifyFuncArgs(GenericOp op, FunctionType funType) {
-  auto nOperands = op.getNumOperands();
-  if (funType.getNumInputs() != nOperands)
-    return op.emitOpError(
-        "expected function arguments to match number of operands");
-  if (funType.getNumResults() != op.getNumOutputs())
-    return op.emitOpError("expected function results(")
-           << funType.getNumResults() << ") to match number of outputs("
-           << op.getNumOutputs() << ")";
-
-  // linalg.generic operands element types are exactly the first function
-  // arguments.
-  for (unsigned idx = 0; idx < nOperands; ++idx) {
-    ShapedType shapedType = op.getShapedType(idx);
-    if (funType.getInput(idx) != shapedType.getElementType())
-      return op.emitOpError("expected function argument ")
-             << (idx + 1) << " of the same type as elemental type "
-             << shapedType.getElementType() << " of operand " << (idx + 1);
-  }
-
-  return success();
-}
-
-template <>
-LogicalResult verifyFuncArgs(IndexedGenericOp op, FunctionType funType) {
-  auto nLoops = op.getNumLoops();
-  auto nOutputs = op.getNumOutputs();
-  auto nOperands = op.getNumOperands();
-  if (funType.getNumInputs() != nOperands + nLoops)
-    return op.emitOpError("expected function arguments to match number of "
-                          "loops + number of operands");
-  if (funType.getNumResults() != nOutputs)
-    return op.emitOpError(
-        "expected function results to match number of outputs");
-  for (unsigned i = 0; i < nLoops; ++i)
-    if (!funType.getInput(i).isIndex())
-      return op.emitOpError("expected function argument ")
-             << (i + 1) << " to be an index";
-
-  // linalg.generic operands element types are exactly the first function
-  // arguments.
-  for (unsigned idx = 0; idx < nOperands; ++idx) {
-    ShapedType shapedType = op.getShapedType(idx);
-    if (funType.getInput(idx + nLoops) != shapedType.getElementType())
-      return op.emitOpError("expected function argument ")
-             << (idx + nLoops + 1) << " of the same type as elemental type "
-             << shapedType.getElementType() << " of input " << (idx + 1);
-  }
-
-  return success();
-}
-
 template <typename GenericOpType>
 static LogicalResult verifyGenericOp(GenericOpType op) {
   auto nInputViews = op.getNumInputs();
@@ -320,20 +243,10 @@ static LogicalResult verifyGenericOp(GenericOpType op) {
            << " inputs (tensor or buffer) and output buffer operands";
 
   auto &region = op.region();
-  auto funOp = op.getFunction();
-  auto funType = funOp ? funOp.getType() : FunctionType();
-  if (!region.empty()) {
-    if (region.getBlocks().size() != 1)
-      return op.emitOpError("expected region with 1 block");
-    if (failed(verifyBlockArgs(op, region.getBlocks().front())))
-      return failure();
-  } else {
-    if (!funOp || !funOp.getType())
-      return op.emitOpError(
-          "expected function attribute to refer to a defined symbol");
-    if (failed(verifyFuncArgsGeneric(op, funType)))
-      return failure();
-  }
+  if (region.getBlocks().size() != 1)
+    return op.emitOpError("expected region with 1 block");
+  if (failed(verifyBlockArgs(op, region.getBlocks().front())))
+    return failure();
 
   SmallVector<AffineMap, 4> indexingMaps;
   indexingMaps.reserve(op.indexing_maps().size());

diff  --git a/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp
index 96cbdab5ac47..a5f4cd9e4592 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp
@@ -382,8 +382,7 @@ static bool areTensorOpsFusible(LinalgOp producer, LinalgOp consumer,
   // - only handle ops that use regions for specifying the scalar operations.
   if (!producerOp || !consumerOp || producerOp.getNumOutputs() != 1 ||
       producerOp.getResult(0) != consumerOp.getOperand(consumerIdx) ||
-      producerOp.getNumParallelLoops() != producerOp.getNumLoops() ||
-      producerOp.fun() || consumerOp.fun())
+      producerOp.getNumParallelLoops() != producerOp.getNumLoops())
     return false;
 
   // Get the consumer index map. The number of results of the consumer index map
@@ -472,7 +471,6 @@ Optional<LinalgOp> mlir::linalg::fuseTensorOps(OpBuilder &b, LinalgOp producer,
       b.getI64IntegerAttr(fusedArgsIn), b.getI64IntegerAttr(fusedArgsOut),
       b.getArrayAttr(fusedIndexingMapAttrs), consumerOp.iterator_types(),
       /*doc=*/nullptr,
-      /*fun=*/nullptr,
       /*library_call=*/nullptr);
 
   // Build the region of the fused op.

diff  --git a/mlir/lib/Dialect/Linalg/Transforms/LinalgToLoops.cpp b/mlir/lib/Dialect/Linalg/Transforms/LinalgToLoops.cpp
index 529448497728..07a2c370a152 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/LinalgToLoops.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/LinalgToLoops.cpp
@@ -400,21 +400,6 @@ class LinalgScopedEmitter<IndexedValueType, GenericOp> {
       indexedValues[nInputs + i] = std_load(output, indexing);
     }
 
-    auto funcOp = genericOp.getFunction();
-    if (funcOp) {
-      // 2. Emit call.
-      Operation *callOp = std_call(funcOp, indexedValues);
-      assert(callOp->getNumResults() == genericOp.getNumOutputs());
-
-      // 3. Emit std_store.
-      for (unsigned i = 0; i < nOutputs; ++i) {
-        Value output = genericOp.getOutputBuffer(i);
-        ValueHandleArray indexing(makeCanonicalAffineApplies(
-            b, loc, genericOp.getOutputIndexingMap(i), allIvs));
-        std_store(callOp->getResult(i), output, indexing);
-      }
-      return;
-    }
     // TODO(ntv): When a region inliner exists, use it.
     // 2. Inline region, currently only works for a single basic block.
     // 3. Emit std_store.
@@ -495,20 +480,6 @@ class LinalgScopedEmitter<IndexedValueType, IndexedGenericOp> {
       indexedValues[nLoops + nInputs + i] = std_load(output, indexing);
     }
 
-    if (auto funcOp = indexedGenericOp.getFunction()) {
-      // 2. Emit call.
-      Operation *callOp = std_call(funcOp, indexedValues);
-      assert(callOp->getNumResults() == indexedGenericOp.getNumOutputs());
-
-      // 3. Emit std_store.
-      for (unsigned i = 0; i < nOutputs; ++i) {
-        Value output = indexedGenericOp.getOutputBuffer(i);
-        ValueHandleArray indexing(makeCanonicalAffineApplies(
-            b, loc, indexedGenericOp.getOutputIndexingMap(i), allIvs));
-        std_store(callOp->getResult(i), output, indexing);
-      }
-      return;
-    }
     // TODO(ntv): When a region inliner exists, use it.
     // 2. Inline region, currently only works for a single basic block.
     // 3. Emit std_store.

diff  --git a/mlir/test/Dialect/Linalg/invalid.mlir b/mlir/test/Dialect/Linalg/invalid.mlir
index 0041f97d7eea..e6414a0fbd78 100644
--- a/mlir/test/Dialect/Linalg/invalid.mlir
+++ b/mlir/test/Dialect/Linalg/invalid.mlir
@@ -54,200 +54,131 @@ func @yield_parent(%arg0: memref<?xf32, affine_map<(i)[off]->(off + i)>>) {
 
 // -----
 
-func @generic_at_least_2_operands(%arg0: memref<f32>) {
-  // expected-error @+1 {{op expected 2 or more operands}}
+func @generic_no_region(%arg0: memref<f32>) {
+  // expected-error @+6 {{expected '{' to begin a region}}
   linalg.generic {
     args_in = 1,
     args_out = 1,
-    fun = @foo,
     indexing_maps =  [ affine_map<() -> (0)> ],
     iterator_types = []
-  } %arg0: memref<f32>
+  } %arg0 : memref<f32>
 }
 
 // -----
 
-func @generic_exactly_2_views(%arg0: memref<f32>) {
-  // expected-error @+1 {{op expected exactly 2 inputs (tensor or buffer) and output buffer operands}}
+func @generic_at_least_2_operands(%arg0: memref<f32>) {
+  // expected-error @+1 {{op expected 2 or more operands}}
   linalg.generic {
     args_in = 1,
     args_out = 1,
-    fun = @foo,
     indexing_maps =  [ affine_map<() -> (0)> ],
     iterator_types = []
-  } %arg0, %arg0, %arg0: memref<f32>, memref<f32>, memref<f32>
+  } %arg0 {} : memref<f32>
 }
 
 // -----
 
-func @generic_undefined_fun(%arg0: memref<f32>) {
-  // expected-error @+1 {{op expected function attribute to refer to a defined symbol}}
+func @generic_exactly_2_views(%arg0: memref<f32>) {
+  // expected-error @+1 {{op expected exactly 2 inputs (tensor or buffer) and output buffer operands}}
   linalg.generic {
     args_in = 1,
     args_out = 1,
-    fun = @foo,
-    indexing_maps =  [ affine_map<() -> (0)> ],
-    iterator_types = []
-  } %arg0, %arg0: memref<f32>, memref<f32>
-}
-
-// -----
-
-func @foo() { return }
-
-func @generic_mismatched_num_arguments(%arg0: memref<f32>) {
-  // expected-error @+1 {{op expected function arguments to match number of operands}}
-  linalg.generic {
-    args_in = 0,
-    args_out = 1,
-    fun = @foo,
     indexing_maps =  [ affine_map<() -> (0)> ],
     iterator_types = []
-  } %arg0: memref<f32>
+  } %arg0, %arg0, %arg0 {}: memref<f32>, memref<f32>, memref<f32>
 }
 
 // -----
 
-func @foo(%0: i32) { return }
-
 func @generic_mismatched_num_returns(%arg0: memref<f32>) {
-  // expected-error @+1 {{op expected function results(0) to match number of outputs(1)}}
+  // expected-error @+8 {{op expected number of yield values (1) to match the number of operands of the enclosing linalg.generic op (0)}}
   linalg.generic {
     args_in = 0,
     args_out = 1,
-    fun = @foo,
-    indexing_maps =  [ affine_map<() -> (0)> ],
-    iterator_types = []
-  } %arg0: memref<f32>
-}
-
-// -----
-
-func @foo(%0: i32, %1: i32, %2: i32) { return }
-
-func @generic_mismatched_num_returns(%0: memref<i32>, %1: memref<f32>) {
-  // expected-error @+1 {{op expected function argument 2 of the same type as elemental type 'f32' of operand 2}}
-  linalg.generic {
-    args_in = 3,
-    args_out = 0,
-    fun = @foo,
-    indexing_maps =  [ affine_map<() -> (0)> ],
-    iterator_types = []
-  } %0, %1, %1: memref<i32>, memref<f32>, memref<f32>
-}
-
-// -----
-
-func @foo(%0: i32, %1: i32, %2: f32) -> i32 { return %1: i32}
-
-func @generic_mismatched_num_returns(%0: memref<i32>, %1: memref<f32>) {
-  // expected-error @+1 {{op expected function result 1 of the same type as elemental type 'f32' of output 1}}
-  linalg.generic {
-    args_in = 2,
-    args_out = 1,
-    fun = @foo,
-    indexing_maps =  [ affine_map<() -> (0)> ],
+    indexing_maps =  [ affine_map<() -> ()> ],
     iterator_types = []
-  } %0, %0, %1: memref<i32>, memref<i32>, memref<f32>
+  } %arg0 {
+    ^bb(%0: f32):
+      linalg.yield
+  }: memref<f32>
 }
 
 // -----
 
-func @foo(%0: i32) -> i32 { return %0: i32 }
-
 func @generic_symbol_in_map(%arg0: memref<i32>) {
   // expected-error @+1 {{op expected indexing_map #0 to have no symbols}}
   linalg.generic {
     args_in = 0,
     args_out = 1,
-    fun = @foo,
     indexing_maps =  [ affine_map<()[N] -> (0)> ],
     iterator_types = ["parallel"]
-  } %arg0: memref<i32>
+  } %arg0 {
+    ^bb(%i : i32):
+  }: memref<i32>
 }
 
 // -----
 
-func @foo(%0: i32) -> i32 { return %0: i32 }
-
 func @generic_wrong_dim_in_map(%arg0: memref<1xi32>) {
   // expected-error @+1 {{op expected indexing_map #0 to have 1 dim(s) to match the number of loops}}
   linalg.generic {
     args_in = 0,
     args_out = 1,
-    fun = @foo,
     indexing_maps =  [ affine_map<() -> (0)> ],
     iterator_types = ["parallel"]
-  } %arg0: memref<1xi32>
+  } %arg0 {
+    ^bb(%i : i32):
+  }: memref<1xi32>
 }
 
 // -----
 
-func @foo(%0: f32) -> f32 { return %0: f32 }
-
 func @generic_one_d_view(%arg0: memref<?xf32, affine_map<(i)[off]->(off + i)>>) {
   // expected-error @+1 {{op expected indexing_map #0 results to match view rank: 'memref<?xf32, affine_map<(d0)[s0] -> (d0 + s0)>>'}}
   linalg.generic {
     args_in = 0,
     args_out = 1,
-    fun = @foo,
     indexing_maps =  [ affine_map<() -> (0, 0)> ],
     iterator_types = []
-  } %arg0: memref<?xf32, affine_map<(i)[off]->(off + i)>>
-}
-
-// -----
-
-func @foo(%0: i32) -> f32 {
-  %1 = constant 0.0: f32
-  return %1: f32
-}
-
-func @generic_fun_arg_0_element_type(%arg0: memref<?xf32, affine_map<(i)[off]->(off + i)>>) {
-  // expected-error @+1 {{op expected function argument 1 of the same type as elemental type 'f32' of operand 1}}
-  linalg.generic {
-    args_in = 0,
-    args_out = 1,
-    fun = @foo,
-    indexing_maps =  [ affine_map<() -> (0)> ],
-    iterator_types = []
-  } %arg0: memref<?xf32, affine_map<(i)[off]->(off + i)>>
+  } %arg0 {
+    ^bb(%f : f32):
+      linalg.yield %f: f32
+  }: memref<?xf32, affine_map<(i)[off]->(off + i)>>
 }
 
 // -----
 
-func @foo(%0: f32) -> i4 {
-  %1 = constant 1: i4
-  return %1: i4
-}
-
-func @generic_fun_result_0_element_type(%arg0: memref<?xf32, affine_map<(i)[off]->(off + i)>>) {
-  // expected-error @+1 {{op expected function result 1 of the same type as elemental type 'f32' of output 1}}
+func @generic_result_0_element_type(%arg0: memref<?xf32, affine_map<(i)[off]->(off + i)>>) {
+  // expected-error @+9 {{'linalg.yield' op type of yield operand 1 ('i4') doesn't match the element type of the enclosing linalg.generic op ('f32')}}
   linalg.generic {
     args_in = 0,
     args_out = 1,
-    fun = @foo,
-    indexing_maps =  [ affine_map<() -> (0)> ],
-    iterator_types = []
-  } %arg0: memref<?xf32, affine_map<(i)[off]->(off + i)>>
+    indexing_maps =  [ affine_map<(i) -> (i)> ],
+    iterator_types = ["parallel"]
+  } %arg0 {
+    ^bb(%0: f32):
+      %1 = constant 1: i4
+      linalg.yield %1: i4
+  }: memref<?xf32, affine_map<(i)[off]->(off + i)>>
 }
 
 // -----
 
-func @foo(%0: f32, %1: f32) -> f32 { return %1: f32 }
-
 func @generic_singular_maps(%arg0: memref<?xf32, affine_map<(i)[off]->(off + i)>>, %arg1: memref<?xf32, affine_map<(i)[off]->(off + i)>>) {
   // expected-error @+1 {{op expected the concatenation of maps in indexing_map to be invertible}}
   linalg.generic {
     args_in = 1,
     args_out = 1,
-    fun = @foo,
     indexing_maps =  [
       affine_map<(i, j) -> (i + j)>,
       affine_map<(i, j) -> (i + j)>
     ],
     iterator_types = ["parallel","parallel"]
-  } %arg0, %arg1: memref<?xf32, affine_map<(i)[off]->(off + i)>>, memref<?xf32, affine_map<(i)[off]->(off + i)>>
+  } %arg0, %arg1 {
+    ^bb(%0: f32, %1: f32):
+      linalg.yield %1: f32
+  }: memref<?xf32, affine_map<(i)[off]->(off + i)>>,
+     memref<?xf32, affine_map<(i)[off]->(off + i)>>
 }
 
 ////////////////////////////////////////////////////////////////////////////////
@@ -341,88 +272,53 @@ func @indexed_generic_block_arg_type(%arg0: memref<f32>) {
 
 // -----
 
-func @foo(%f: f32) -> (f32) {
-  return %f : f32
-}
-func @indexed_generic_fun_arg_count(%arg0: memref<f32>) {
-  // expected-error @+1 {{op expected function arguments to match number of loops + number of operands}}
-  linalg.indexed_generic {
-    args_in = 0,
-    args_out = 1,
-    indexing_maps =  [ affine_map<(d0) -> (d0)> ],
-    iterator_types = ["parallel"],
-    fun = @foo
-  } %arg0:  memref<f32>
-}
-
-// -----
-
-func @foo(%i: i32, %val: f32) -> (f32) {
-  return %val : f32
-}
-func @indexed_generic_fun_induction_var_arg_type(%arg0: memref<f32>) {
-  // expected-error @+1 {{op expected function argument 1 to be an index}}
-  linalg.indexed_generic {
-    args_in = 0,
-    args_out = 1,
-    iterator_types = ["parallel"],
-    indexing_maps = [ affine_map<(i) -> (i)> ],
-    fun = @foo
-  } %arg0 : memref<f32>
-}
-
-// -----
-
-func @foo(%i: index, %val: i1) -> (i1) {
-  return %val : i1
-}
-func @indexed_generic_fun_arg_type(%arg0: memref<f32>) {
-  // expected-error @+1 {{op expected function argument 2 of the same type as elemental type 'f32' of input 1}}
+func @indexed_generic_arg_count(%arg0: memref<f32>) {
+  // expected-error @+1 {{op expected number of block arguments to match number of operands + number of loops}}
   linalg.indexed_generic {
     args_in = 0,
     args_out = 1,
-    indexing_maps =  [ affine_map<(d0) -> (d0)> ],
-    iterator_types = ["parallel"],
-    fun = @foo
-  } %arg0: memref<f32>
+    indexing_maps =  [ affine_map<()[] -> ()> ],
+    iterator_types = []
+  } %arg0 {
+    ^bb(%0: index, %1: f32):
+      linalg.yield %1: f32
+  } :  memref<f32>
+  return
 }
 
 // -----
 
-func @foo(%i: index, %val: i1) -> (i1, i1) {
-  return %val, %val : i1, i1
-}
-func @indexed_generic_fun_result_count(%arg0: memref<f32>) {
-  // expected-error @+1 {{op expected function results to match number of outputs}}
+func @indexed_generic_induction_var_arg_type(%arg0: memref<f32>) {
+  // expected-error @+1 {{op expected block argument 1 to be an index}}
   linalg.indexed_generic {
     args_in = 0,
     args_out = 1,
-    indexing_maps =  [ affine_map<(d0) -> (d0)> ],
     iterator_types = ["parallel"],
-    fun = @foo
-  } %arg0: memref<f32>
+    indexing_maps = [ affine_map<(i) -> (i)> ]
+  } %arg0 {
+    ^bb(%0: i32, %1: f32):
+      linalg.yield %1: f32
+  } : memref<f32>
 }
 
 // -----
 
-func @foo(%i: index, %val: i32) -> (f32) {
-  %val_float = sitofp %val : i32 to f32
-  return %val_float : f32
-}
-func @indexed_generic_fun_result_count(%arg0: memref<i32>) {
-  // expected-error @+1 {{op expected function result 1 of the same type as elemental type 'i32' of output 1}}
+func @indexed_generic_result_count(%arg0: memref<?xf32>) {
+  // expected-error @+8 {{op expected number of yield values (1) to match the number of operands of the enclosing linalg.generic op (2)}}
   linalg.indexed_generic {
     args_in = 0,
     args_out = 1,
     indexing_maps =  [ affine_map<(d0) -> (d0)> ],
-    iterator_types = ["parallel"],
-    fun = @foo
-  } %arg0: memref<i32>
+    iterator_types = ["parallel"]
+  } %arg0 {
+    ^bb(%i: index, %val: f32):
+      linalg.yield %val, %val: f32, f32
+  }: memref<?xf32>
 }
 
 // -----
 
-func @generic_fun_result_0_element_type(%arg0: memref<?xf32, affine_map<(i)[off]->(off + i)>>) {
+func @generic_result_0_element_type(%arg0: memref<?xf32, affine_map<(i)[off]->(off + i)>>) {
   // expected-error @+9 {{type of yield operand 1 ('i1') doesn't match the element type of the enclosing linalg.generic op ('f32')}}
   linalg.generic {
     args_in = 0,
@@ -453,7 +349,7 @@ func @generic_result_tensor_type(%arg0: memref<?xf32, affine_map<(i)[off]->(off
 
 // -----
 
-func @generic_fun_result_0_element_type(%arg0: memref<?xf32>) {
+func @generic_result_0_element_type(%arg0: memref<?xf32>) {
   // expected-error @+1 {{'linalg.dot' op expected 3 operands, but found 2}}
   linalg.dot(%arg0, %arg0): memref<?xf32>, memref<?xf32>
 }

diff  --git a/mlir/test/Dialect/Linalg/loops.mlir b/mlir/test/Dialect/Linalg/loops.mlir
index 48e4b6ecd10d..3751c105f310 100644
--- a/mlir/test/Dialect/Linalg/loops.mlir
+++ b/mlir/test/Dialect/Linalg/loops.mlir
@@ -533,51 +533,11 @@ func @pooling_sum(%arg0: memref<?x?xf32>,
 //       CHECKPARALLEL:         %[[RES:.*]] = addf %[[LHS]], %[[RHS]] : f32
 //       CHECKPARALLEL:         store %[[RES]], %{{.*}}[%{{.*}}, %{{.*}}] : memref<?x?xf32>
 
-func @foo(%0: f32, %1: f32, %2: f32) -> (f32, f32) {
-  %f0 = constant 0.0 : f32
-  return %f0, %f0 : f32, f32
-}
 #accesses = [
   affine_map<(i, j, k) -> (i, j)>,
   affine_map<(i, j, k) -> (i, j, k)>,
   affine_map<(i, j, k) -> (i, k, j)>
 ]
-#trait = {
-  args_in = 1,
-  args_out = 2,
-  iterator_types = ["parallel", "parallel", "parallel"],
-  indexing_maps = #accesses,
-  fun = @foo,
-  library_call = "some_external_function_name_1",
-  doc = "B(i,j,k), C(i,k,j) = foo(A(i, j), B(i,j,k), C(i,k,j))"
-}
-func @generic_function(%arg0: memref<?x?xf32, offset: ?, strides: [?, 1]>, %arg1: memref<?x?x?xf32, offset: ?, strides: [?, ?, 1]>, %arg2: memref<?x?x?xf32, offset: ?, strides: [?, ?, 1]>) {
-  linalg.generic #trait %arg0, %arg1, %arg2:
-    memref<?x?xf32, offset: ?, strides: [?, 1]>, memref<?x?x?xf32, offset: ?, strides: [?, ?, 1]>, memref<?x?x?xf32, offset: ?, strides: [?, ?, 1]>
-  return
-}
-// CHECKLOOP-LABEL: @foo
-// CHECKLOOP-LABEL: @generic_function
-//       CHECKLOOP: loop.for %[[i:.*]] = {{.*}}
-//       CHECKLOOP:   loop.for %[[j:.*]] = {{.*}}
-//       CHECKLOOP:     loop.for %[[k:.*]] = {{.*}}
-//       CHECKLOOP:       %[[a:.*]] = load %{{.*}}[%[[i]], %[[j]]] : memref<?x?xf32, #[[strided2D]]>
-//       CHECKLOOP:       %[[b:.*]] = load %{{.*}}[%[[i]], %[[j]], %[[k]]] : memref<?x?x?xf32, #[[strided3D]]>
-//       CHECKLOOP:       %[[c:.*]] = load %{{.*}}[%[[i]], %[[k]], %[[j]]] : memref<?x?x?xf32, #[[strided3D]]>
-//       CHECKLOOP:       %[[res:.*]]:2 = call @foo(%[[a]], %[[b]], %[[c]]) : (f32, f32, f32) -> (f32, f32)
-//       CHECKLOOP:       store %[[res]]#0, %{{.*}}[%[[i]], %[[j]], %[[k]]] : memref<?x?x?xf32, #[[strided3D]]>
-//       CHECKLOOP:       store %[[res]]#1, %{{.*}}[%[[i]], %[[k]], %[[j]]] : memref<?x?x?xf32, #[[strided3D]]>
-
-// CHECKPARALLEL-LABEL: @foo
-// CHECKPARALLEL-LABEL: @generic_function
-//       CHECKPARALLEL: loop.parallel (%[[i:[a-zA-Z0-9_]*]], %[[j:[a-zA-Z0-9_]*]], %[[k:[a-zA-Z0-9_]*]])
-//       CHECKPARALLEL:   %[[a:.*]] = load %{{.*}}[%[[i]], %[[j]]] : memref<?x?xf32, #[[strided2D]]>
-//       CHECKPARALLEL:   %[[b:.*]] = load %{{.*}}[%[[i]], %[[j]], %[[k]]] : memref<?x?x?xf32, #[[strided3D]]>
-//       CHECKPARALLEL:   %[[c:.*]] = load %{{.*}}[%[[i]], %[[k]], %[[j]]] : memref<?x?x?xf32, #[[strided3D]]>
-//       CHECKPARALLEL:   %[[res:.*]]:2 = call @foo(%[[a]], %[[b]], %[[c]]) : (f32, f32, f32) -> (f32, f32)
-//       CHECKPARALLEL:   store %[[res]]#0, %{{.*}}[%[[i]], %[[j]], %[[k]]] : memref<?x?x?xf32, #[[strided3D]]>
-//       CHECKPARALLEL:   store %[[res]]#1, %{{.*}}[%[[i]], %[[k]], %[[j]]] : memref<?x?x?xf32, #[[strided3D]]>
-
 #trait2 = {
   args_in = 1,
   args_out = 2,
@@ -617,52 +577,6 @@ func @generic_region(%arg0: memref<?x?xf32, offset: ?, strides: [?, 1]>, %arg1:
 //       CHECKPARALLEL:   store %[[d]], %{{.*}}[%[[i]], %[[j]], %[[k]]] : memref<?x?x?xf32, #[[strided3D]]>
 //       CHECKPARALLEL:   store %[[e]], %{{.*}}[%[[i]], %[[k]], %[[j]]] : memref<?x?x?xf32, #[[strided3D]]>
 
-func @indexed_foo(%i: index, %j: index, %k: index, %0: f32, %1: f32, %2: f32) -> (f32, f32) {
-  %i_int = index_cast %i: index to i32
-  %i_float = sitofp %i_int : i32 to f32
-  return %i_float, %i_float : f32, f32
-}
-#trait3 = {
-  args_in = 1,
-  args_out = 2,
-  iterator_types = ["parallel", "parallel", "parallel"],
-  indexing_maps = #accesses,
-  fun = @indexed_foo,
-  library_call = "some_external_function_name_1",
-  doc = "b(i,j,k), c(i,k,j) = foo(a(i, j), b(i,j,k), c(i,k,j))"
-}
-func @indexed_generic_function(
-         %arg0: memref<?x?xf32, offset: ?, strides: [?, 1]>,
-         %arg1: memref<?x?x?xf32, offset: ?, strides: [?, ?, 1]>,
-         %arg2: memref<?x?x?xf32, offset: ?, strides: [?, ?, 1]>) {
-  linalg.indexed_generic #trait3 %arg0, %arg1, %arg2:
-    memref<?x?xf32, offset: ?, strides: [?, 1]>,
-    memref<?x?x?xf32, offset: ?, strides: [?, ?, 1]>,
-    memref<?x?x?xf32, offset: ?, strides: [?, ?, 1]>
-  return
-}
-// CHECKLOOP-LABEL: @indexed_foo
-// CHECKLOOP-LABEL: @indexed_generic_function
-//       CHECKLOOP: loop.for %[[i:.*]] = {{.*}}
-//       CHECKLOOP:   loop.for %[[j:.*]] = {{.*}}
-//       CHECKLOOP:     loop.for %[[k:.*]] = {{.*}}
-//       CHECKLOOP:       %[[a:.*]] = load %{{.*}}[%[[i]], %[[j]]] : memref<?x?xf32, #[[strided2D]]>
-//       CHECKLOOP:       %[[b:.*]] = load %{{.*}}[%[[i]], %[[j]], %[[k]]] : memref<?x?x?xf32, #[[strided3D]]>
-//       CHECKLOOP:       %[[c:.*]] = load %{{.*}}[%[[i]], %[[k]], %[[j]]] : memref<?x?x?xf32, #[[strided3D]]>
-//       CHECKLOOP:       %[[res:.*]]:2 = call @indexed_foo(%[[i]], %[[j]], %[[k]], %[[a]], %[[b]], %[[c]]) : (index, index, index, f32, f32, f32) -> (f32, f32)
-//       CHECKLOOP:       store %[[res]]#0, %{{.*}}[%[[i]], %[[j]], %[[k]]] : memref<?x?x?xf32, #[[strided3D]]>
-//       CHECKLOOP:       store %[[res]]#1, %{{.*}}[%[[i]], %[[k]], %[[j]]] : memref<?x?x?xf32, #[[strided3D]]>
-
-// CHECKPARALLEL-LABEL: @indexed_foo
-// CHECKPARALLEL-LABEL: @indexed_generic_function
-//       CHECKPARALLEL: loop.parallel (%[[i:[a-zA-Z0-9_]*]], %[[j:[a-zA-Z0-9_]*]], %[[k:[a-zA-Z0-9_]*]])
-//       CHECKPARALLEL:   %[[a:.*]] = load %{{.*}}[%[[i]], %[[j]]] : memref<?x?xf32, #[[strided2D]]>
-//       CHECKPARALLEL:   %[[b:.*]] = load %{{.*}}[%[[i]], %[[j]], %[[k]]] : memref<?x?x?xf32, #[[strided3D]]>
-//       CHECKPARALLEL:   %[[c:.*]] = load %{{.*}}[%[[i]], %[[k]], %[[j]]] : memref<?x?x?xf32, #[[strided3D]]>
-//       CHECKPARALLEL:   %[[res:.*]]:2 = call @indexed_foo(%[[i]], %[[j]], %[[k]], %[[a]], %[[b]], %[[c]]) : (index, index, index, f32, f32, f32) -> (f32, f32)
-//       CHECKPARALLEL:   store %[[res]]#0, %{{.*}}[%[[i]], %[[j]], %[[k]]] : memref<?x?x?xf32, #[[strided3D]]>
-//       CHECKPARALLEL:   store %[[res]]#1, %{{.*}}[%[[i]], %[[k]], %[[j]]] : memref<?x?x?xf32, #[[strided3D]]>
-
 #trait4 = {
   args_in = 1,
   args_out = 2,

diff  --git a/mlir/test/Dialect/Linalg/roundtrip.mlir b/mlir/test/Dialect/Linalg/roundtrip.mlir
index c28c671d2885..89b910e7b04a 100644
--- a/mlir/test/Dialect/Linalg/roundtrip.mlir
+++ b/mlir/test/Dialect/Linalg/roundtrip.mlir
@@ -289,11 +289,6 @@ func @pooling_sum(%arg0: memref<?x?x?xf32>,
 // CHECK-DAG: #[[strided2D:.*]] = affine_map<(d0, d1)[s0, s1] -> (d0 * s1 + s0 + d1)>
 // CHECK-DAG: #[[strided3D:.*]] = affine_map<(d0, d1, d2)[s0, s1, s2] -> (d0 * s1 + s0 + d1 * s2 + d2)>
 
-func @foo(%0: vector<3x4xi4>, %1: f32) -> f32 {
-  %f0 = constant 0.0 : f32
-  return %f0 : f32
-}
-
 #accesses = [
   affine_map<(i, j, k) -> (j, i)>,
   affine_map<(i, j, k) -> (i, k, i + j)>
@@ -304,46 +299,45 @@ func @foo(%0: vector<3x4xi4>, %1: f32) -> f32 {
   args_out = 1,
   indexing_maps = #accesses,
   iterator_types = ["parallel", "parallel", "parallel"],
-  fun = @foo,
   library_call = "some_external_function_name_1"
 }
 
 func @generic(%arg0: memref<?x?xvector<3x4xi4>, offset: ?, strides: [?, 1]>,
               %arg1: memref<?x?x?xf32, offset: ?, strides: [?, ?, 1]>) {
-  linalg.generic #trait %arg0, %arg1 {foo = 1} :
-    memref<?x?xvector<3x4xi4>, offset: ?, strides: [?, 1]>,
-    memref<?x?x?xf32, offset: ?, strides: [?, ?, 1]>
+  linalg.generic #trait {foo = 1} %arg0, %arg1 {
+    ^bb(%0: vector<3x4xi4>, %1: f32) :
+      %f0 = constant 0.0 : f32
+      linalg.yield %f0 : f32
+  } : memref<?x?xvector<3x4xi4>, offset: ?, strides: [?, 1]>,
+      memref<?x?x?xf32, offset: ?, strides: [?, ?, 1]>
   return
 }
-// CHECK-LABEL: func @foo
 // CHECK-LABEL: func @generic
-//       CHECK:   linalg.generic {args_in = 1 : i64, args_out = 1 : i64, fun = @foo,
+//       CHECK:   linalg.generic {args_in = 1 : i64, args_out = 1 : i64,
 //  CHECK-SAME:     indexing_maps = [#{{.*}}, #{{.*}}], iterator_types = ["parallel", "parallel", "parallel"],
 //  CHECK-SAME:     library_call = "some_external_function_name_1"
-//  CHECK-SAME:     {foo = 1 : i64}:
-//  CHECK-SAME:     memref<?x?xvector<3x4xi4>, #[[strided2D]]>, memref<?x?x?xf32, #[[strided3D]]>
+//  CHECK-SAME:     {foo = 1 : i64}
+//       CHECK:     memref<?x?xvector<3x4xi4>, #[[strided2D]]>, memref<?x?x?xf32, #[[strided3D]]>
 
 func @generic_with_tensor_input(%arg0: tensor<?x?xvector<3x4xi4>>,
                                 %arg1: memref<?x?x?xf32, offset: ?, strides: [?, ?, 1]>) {
-  linalg.generic #trait %arg0, %arg1 {foo = 1} :
-    tensor<?x?xvector<3x4xi4>>,
-    memref<?x?x?xf32, offset: ?, strides: [?, ?, 1]>
+  linalg.generic #trait {foo = 1} %arg0, %arg1 {
+    ^bb(%0: vector<3x4xi4>, %1: f32) :
+      %f0 = constant 0.0 : f32
+      linalg.yield %f0 : f32
+  } : tensor<?x?xvector<3x4xi4>>,
+      memref<?x?x?xf32, offset: ?, strides: [?, ?, 1]>
   return
 }
 // CHECK-LABEL: func @generic_with_tensor_input
-//       CHECK:   linalg.generic {args_in = 1 : i64, args_out = 1 : i64, fun = @foo,
+//       CHECK:   linalg.generic {args_in = 1 : i64, args_out = 1 : i64,
 //  CHECK-SAME:     indexing_maps = [#{{.*}}, #{{.*}}], iterator_types = ["parallel", "parallel", "parallel"],
 //  CHECK-SAME:     library_call = "some_external_function_name_1"}
-//  CHECK-SAME:     {foo = 1 : i64}:
-//  CHECK-SAME:     tensor<?x?xvector<3x4xi4>>, memref<?x?x?xf32, #[[strided3D]]>
+//  CHECK-SAME:     {foo = 1 : i64}
+//       CHECK:     tensor<?x?xvector<3x4xi4>>, memref<?x?x?xf32, #[[strided3D]]>
 
 // -----
 
-func @foo(%0: vector<3x4xi4>, %1: f32) -> f32 {
-  %f0 = constant 0.0 : f32
-  return %f0 : f32
-}
-
 #accesses = [
   affine_map<(i, j, k) -> (j, i)>,
   affine_map<(i, j, k) -> (i, k, i + j)>
@@ -354,31 +348,30 @@ func @foo(%0: vector<3x4xi4>, %1: f32) -> f32 {
   args_out = 1,
   indexing_maps = #accesses,
   iterator_types = ["parallel", "parallel", "parallel"],
-  fun = @foo,
   library_call = "some_external_function_name_1"
 }
 
 func @generic_with_tensor_input_and_output(
     %arg0: tensor<?x?xvector<3x4xi4>>, %arg1: tensor<?x?x?xf32>)
     -> (tensor<?x?x?xf32>) {
-  %0 = linalg.generic #trait2 %arg0, %arg1 {foo = 1} :
-    tensor<?x?xvector<3x4xi4>>, tensor<?x?x?xf32> -> tensor<?x?x?xf32>
+  %0 = linalg.generic #trait2 {foo = 1} %arg0, %arg1 {
+    ^bb(%0: vector<3x4xi4>, %1: f32) :
+      %f0 = constant 0.0 : f32
+      linalg.yield %f0 : f32
+  } : tensor<?x?xvector<3x4xi4>>, tensor<?x?x?xf32> -> tensor<?x?x?xf32>
   return %0 : tensor<?x?x?xf32>
 }
 // CHECK-LABEL: func @generic_with_tensor_input_and_output
-//       CHECK:   linalg.generic {args_in = 2 : i64, args_out = 1 : i64, fun = @foo,
+//       CHECK:   linalg.generic {args_in = 2 : i64, args_out = 1 : i64,
 //  CHECK-SAME:     indexing_maps = [#{{.*}}, #{{.*}}], iterator_types = ["parallel", "parallel", "parallel"],
-//  CHECK-SAME:     library_call = "some_external_function_name_1"} %{{.*}}, %{{.*}} {foo = 1 : i64}:
-//  CHECK-SAME:     tensor<?x?xvector<3x4xi4>>, tensor<?x?x?xf32> -> tensor<?x?x?xf32>
+//  CHECK-SAME:     library_call = "some_external_function_name_1"}
+//  CHECK-SAME:     {foo = 1 : i64}
+//  CHECK-SAME:     %{{.*}}, %{{.*}}
+//       CHECK:     tensor<?x?xvector<3x4xi4>>, tensor<?x?x?xf32> -> tensor<?x?x?xf32>
 //       CHECK:   return {{.*}} : tensor<?x?x?xf32>
 
 // -----
 
-func @foo(%i: index, %j: index, %k: index, %0: vector<3x4xi4>, %1: f32) -> f32 {
-  %f0 = constant 0.0 : f32
-  return %f0 : f32
-}
-
 #accesses = [
   affine_map<(i, j, k) -> (j, i)>,
   affine_map<(i, j, k) -> (i, k, i + j)>
@@ -389,22 +382,26 @@ func @foo(%i: index, %j: index, %k: index, %0: vector<3x4xi4>, %1: f32) -> f32 {
   args_out = 1,
   indexing_maps = #accesses,
   iterator_types = ["parallel", "parallel", "parallel"],
-  fun = @foo,
   library_call = "some_external_function_name_1"
 }
 
 func @indexed_generic_with_tensor_input_and_output(
     %arg0: tensor<?x?xvector<3x4xi4>>, %arg1: tensor<?x?x?xf32>)
     -> (tensor<?x?x?xf32>) {
-  %0 = linalg.indexed_generic #trait2 %arg0, %arg1 {foo = 1} :
-    tensor<?x?xvector<3x4xi4>>, tensor<?x?x?xf32> -> tensor<?x?x?xf32>
+  %0 = linalg.indexed_generic #trait2 {foo = 1} %arg0, %arg1 {
+    ^bb(%i: index, %j: index, %k: index, %0: vector<3x4xi4>, %1: f32) :
+      %f0 = constant 0.0 : f32
+      linalg.yield %f0 : f32
+  } : tensor<?x?xvector<3x4xi4>>, tensor<?x?x?xf32> -> tensor<?x?x?xf32>
   return %0 : tensor<?x?x?xf32>
 }
 // CHECK-LABEL: func @indexed_generic_with_tensor_input_and_output
-//       CHECK:   linalg.indexed_generic {args_in = 2 : i64, args_out = 1 : i64, fun = @foo,
+//       CHECK:   linalg.indexed_generic {args_in = 2 : i64, args_out = 1 : i64,
 //  CHECK-SAME:     indexing_maps = [#{{.*}}, #{{.*}}], iterator_types = ["parallel", "parallel", "parallel"],
-//  CHECK-SAME:     library_call = "some_external_function_name_1"} %{{.*}}, %{{.*}} {foo = 1 : i64}:
-//  CHECK-SAME:     tensor<?x?xvector<3x4xi4>>, tensor<?x?x?xf32> -> tensor<?x?x?xf32>
+//  CHECK-SAME:     library_call = "some_external_function_name_1"}
+//  CHECK-SAME:     {foo = 1 : i64}
+//  CHECK-SAME:     %{{.*}}, %{{.*}}
+//       CHECK:     tensor<?x?xvector<3x4xi4>>, tensor<?x?x?xf32> -> tensor<?x?x?xf32>
 //       CHECK:   return {{.*}} : tensor<?x?x?xf32>
 
 // -----
@@ -460,10 +457,10 @@ func @indexed_generic_op_zero_rank(%arg0: tensor<f32>) -> (tensor<3x4xf32>)
 
 func @generic_region(%arg0: memref<?x?xvector<3x4xi4>, offset: ?, strides: [?, 1]>,
                      %arg1: memref<?x?x?xf32, offset: ?, strides: [?, ?, 1]>) {
-  linalg.generic #trait3 %arg0, %arg1 {
+  linalg.generic #trait3 {foo = 1} %arg0, %arg1 {
     ^bb(%a: vector<3x4xi4>, %b: f32) :
       linalg.yield %b : f32
-  } {foo = 1}: memref<?x?xvector<3x4xi4>, offset: ?, strides: [?, 1]>,
+  } : memref<?x?xvector<3x4xi4>, offset: ?, strides: [?, 1]>,
                memref<?x?x?xf32, offset: ?, strides: [?, ?, 1]>
   return
 }
@@ -471,17 +468,18 @@ func @generic_region(%arg0: memref<?x?xvector<3x4xi4>, offset: ?, strides: [?, 1
 //       CHECK:   linalg.generic {args_in = 1 : i64, args_out = 1 : i64,
 //  CHECK-SAME:     indexing_maps = [#{{.*}}, #{{.*}}], iterator_types = ["parallel", "parallel", "parallel"],
 //  CHECK-SAME:     library_call = "some_external_function_name_2"
+//  CHECK-SAME:     {foo = 1 : i64}
 //       CHECK:    ^{{.*}}(%{{.*}}: vector<3x4xi4>, %{{.*}}: f32):
 //       CHECK:      linalg.yield %{{.*}} : f32
-//       CHECK:    } {foo = 1 : i64}: memref<?x?xvector<3x4xi4>, #[[strided2D]]>,
-//  CHECK-SAME:                       memref<?x?x?xf32, #[[strided3D]]>
+//       CHECK:    memref<?x?xvector<3x4xi4>, #[[strided2D]]>,
+//  CHECK-SAME:    memref<?x?x?xf32, #[[strided3D]]>
 
 func @indexed_generic(%arg0: memref<?x?xvector<3x4xi4>, offset: ?, strides: [?, 1]>,
                       %arg1: memref<?x?x?xf32, offset: ?, strides: [?, ?, 1]>) {
-  linalg.indexed_generic #trait3 %arg0, %arg1 {
+  linalg.indexed_generic #trait3 {foo = 1} %arg0, %arg1 {
   ^bb(%i: index, %j: index, %k: index, %a: vector<3x4xi4>, %b: f32) :
       linalg.yield %b : f32
-  } {foo = 1}: memref<?x?xvector<3x4xi4>, offset: ?, strides: [?, 1]>,
+  }: memref<?x?xvector<3x4xi4>, offset: ?, strides: [?, 1]>,
                memref<?x?x?xf32, offset: ?, strides: [?, ?, 1]>
   return
 }
@@ -489,9 +487,10 @@ func @indexed_generic(%arg0: memref<?x?xvector<3x4xi4>, offset: ?, strides: [?,
 //       CHECK:   linalg.indexed_generic {args_in = 1 : i64, args_out = 1 : i64,
 //  CHECK-SAME:     indexing_maps = [#{{.*}}, #{{.*}}], iterator_types = ["parallel", "parallel", "parallel"],
 //  CHECK-SAME:     library_call = "some_external_function_name_2"
+//  CHECK-SAME:     {foo = 1 : i64}
 //       CHECK:    ^{{.*}}(%{{.*}}: index, %{{.*}}: index, %{{.*}}: index, %{{.*}}: vector<3x4xi4>, %{{.*}}: f32):
 //       CHECK:      linalg.yield %{{.*}} : f32
-//       CHECK:    } {foo = 1 : i64}: memref<?x?xvector<3x4xi4>, #[[strided2D]]>,
+//       CHECK:    }: memref<?x?xvector<3x4xi4>, #[[strided2D]]>,
 //  CHECK-SAME:                       memref<?x?x?xf32, #[[strided3D]]>
 
 // -----

diff  --git a/mlir/test/Dialect/Linalg/transform-patterns.mlir b/mlir/test/Dialect/Linalg/transform-patterns.mlir
index a0a7b74d4257..7f76819b0849 100644
--- a/mlir/test/Dialect/Linalg/transform-patterns.mlir
+++ b/mlir/test/Dialect/Linalg/transform-patterns.mlir
@@ -212,57 +212,71 @@ func @test_vectorize_fill(%A : memref<8x16xf32>, %arg0 : f32) {
 // CHECK-LABEL: func @test_vectorize_fill
 //       CHECK: vector.broadcast {{.*}} : f32 to vector<8x16xf32>
 
-func @fma(%a: f32, %b: f32, %c: f32) -> f32 {
-          %d = mulf %a, %b: f32
-          %e = addf %c, %d: f32
-          return %e: f32
-        }
 #matmul_accesses = [
-          affine_map<(m, n, k) -> (m, k)>,
-          affine_map<(m, n, k) -> (k, n)>,
-          affine_map<(m, n, k) -> (m, n)>
+  affine_map<(m, n, k) -> (m, k)>,
+  affine_map<(m, n, k) -> (k, n)>,
+  affine_map<(m, n, k) -> (m, n)>
 ]
 #generic_matmul_trait = {
-          args_in = 2,
-          args_out = 1,
-          fun = @fma,
-          indexing_maps = #matmul_accesses,
-          library_call = "linalg_matmul",
-          iterator_types = ["parallel", "parallel", "reduction"]
-        }
+  args_in = 2,
+  args_out = 1,
+  indexing_maps = #matmul_accesses,
+  library_call = "linalg_matmul",
+  iterator_types = ["parallel", "parallel", "reduction"]
+}
 func @permute_generic(%A: memref<?x?xf32, offset: ?, strides: [?, 1]>,
            %B: memref<?x?xf32, offset: ?, strides: [?, 1]>,
            %C: memref<?x?xf32, offset: ?, strides: [?, 1]>) {
-  linalg.generic #generic_matmul_trait %A, %B, %C : memref<?x?xf32, offset: ?, strides: [?, 1]>, memref<?x?xf32, offset: ?, strides: [?, 1]>, memref<?x?xf32, offset: ?, strides: [?, 1]>
-
+  linalg.generic #generic_matmul_trait %A, %B, %C {
+    ^bb(%a: f32, %b: f32, %c: f32):
+      %d = mulf %a, %b: f32
+      %e = addf %c, %d: f32
+      linalg.yield %e: f32
+  }: memref<?x?xf32, offset: ?, strides: [?, 1]>,
+     memref<?x?xf32, offset: ?, strides: [?, 1]>,
+     memref<?x?xf32, offset: ?, strides: [?, 1]>
   return
 }
 // CHECK-LABEL : func @fma
 // CHECK-LABEL : func @permute_generic
-// CHECK       : linalg.generic {args_in = 2, args_out = 1, fun = @fma, indexing_maps = [#[[kn]], #[[nm]], #[[km]]], iterator_types = ["parallel", "reduction", "parallel"], library_call = "linalg_matmul"} %{{.*}}, %{{.*}}, %{{.*}} : memref<?x?xf32, #[[STRIDED_2D]]>, memref<?x?xf32, #[[STRIDED_2D]]>, memref<?x?xf32, #[[STRIDED_2D]]>
+// CHECK       : linalg.generic {args_in = 2, args_out = 1,
+// CHECK-SAME  : indexing_maps = [#[[kn]], #[[nm]], #[[km]]],
+// CHECK-SAME  : iterator_types = ["parallel", "reduction", "parallel"],
+// CHECK-SAME  : library_call = "linalg_matmul"} %{{.*}}, %{{.*}}, %{{.*}}
+// CHECK       :   memref<?x?xf32, #[[STRIDED_2D]]>,
+// CHECK-SAME  :   memref<?x?xf32, #[[STRIDED_2D]]>,
+// CHECK-SAME  :   memref<?x?xf32, #[[STRIDED_2D]]>
 
-func @fma_indexed(%i: index, %j: index, %k: index, %a: f32, %b: f32, %c: f32) -> f32 {
-          %d = mulf %a, %b: f32
-          %e = addf %c, %d: f32
-          return %e: f32
-}
 #indexed_matmul_trait = {
-          args_in = 2,
-          args_out = 1,
-          fun = @fma_indexed,
-          indexing_maps = #matmul_accesses,
-          library_call = "linalg_matmul_indexed",
-          iterator_types = ["parallel", "parallel", "reduction"]
+  args_in = 2,
+  args_out = 1,
+  indexing_maps = #matmul_accesses,
+  library_call = "linalg_matmul_indexed",
+  iterator_types = ["parallel", "parallel", "reduction"]
 }
-func @permute_generic_indexed(%A: memref<?x?xf32, offset: ?, strides: [?, 1]>,
-           %B: memref<?x?xf32, offset: ?, strides: [?, 1]>,
-           %C: memref<?x?xf32, offset: ?, strides: [?, 1]>) {
-  linalg.indexed_generic #indexed_matmul_trait %A, %B, %C : memref<?x?xf32, offset: ?, strides: [?, 1]>, memref<?x?xf32, offset: ?, strides: [?, 1]>, memref<?x?xf32, offset: ?, strides: [?, 1]>
+func @permute_generic_indexed(
+    %A: memref<?x?xf32, offset: ?, strides: [?, 1]>,
+    %B: memref<?x?xf32, offset: ?, strides: [?, 1]>,
+    %C: memref<?x?xf32, offset: ?, strides: [?, 1]>) {
+  linalg.indexed_generic #indexed_matmul_trait %A, %B, %C {
+    ^bb(%i: index, %j: index, %k: index, %a: f32, %b: f32, %c: f32):
+      %d = mulf %a, %b: f32
+      %e = addf %c, %d: f32
+      linalg.yield %e: f32
+  } : memref<?x?xf32, offset: ?, strides: [?, 1]>,
+      memref<?x?xf32, offset: ?, strides: [?, 1]>,
+      memref<?x?xf32, offset: ?, strides: [?, 1]>
   return
 }
 // CHECK-LABEL : func @fma_indexed
 // CHECK-LABEL : func @permute_generic_indexed
-// CHECK       : linalg.indexed_generic {args_in = 2, args_out = 1, fun = @fma, indexing_maps = [#[[kn]], #[[nm]], #[[km]]], iterator_types = ["parallel", "reduction", "parallel"], library_call = "linalg_matmul_indexed"} %{{.*}}, %{{.*}}, %{{.*}} : memref<?x?xf32, #[[STRIDED_2D]]>, memref<?x?xf32, #[[STRIDED_2D]]>, memref<?x?xf32, #[[STRIDED_2D]]>
+// CHECK       : linalg.indexed_generic {args_in = 2, args_out = 1,
+// CHECK-SAME  :   indexing_maps = [#[[kn]], #[[nm]], #[[km]]],
+// CHECK-SAME  :   iterator_types = ["parallel", "reduction", "parallel"],
+// CHECK-SAME  :   library_call = "linalg_matmul_indexed"} %{{.*}}, %{{.*}}, %{{.*}} :
+// CHECK       :     memref<?x?xf32, #[[STRIDED_2D]]>,
+// CHECK-SAME  :     memref<?x?xf32, #[[STRIDED_2D]]>,
+// CHECK-SAME  :     memref<?x?xf32, #[[STRIDED_2D]]>
 
 func @dot_perm(%x: memref<?xf32, offset: ?, strides: [1]>,
           %y: memref<?xf32, offset: ?, strides: [1]>,

diff  --git a/mlir/test/lib/DeclarativeTransforms/TestLinalgTransformPatterns.td b/mlir/test/lib/DeclarativeTransforms/TestLinalgTransformPatterns.td
index 795247ad2dff..a55cdbffbdb6 100644
--- a/mlir/test/lib/DeclarativeTransforms/TestLinalgTransformPatterns.td
+++ b/mlir/test/lib/DeclarativeTransforms/TestLinalgTransformPatterns.td
@@ -111,7 +111,7 @@ def : Pattern<(FillOp:$op $_, $_),
                 HasLinalgTransformMarker<"VECTORIZE">,
                 PreconditionVectorizeLinalgOp
                ]>>)]>;
-def : Pattern<(GenericOp:$op $_, $_, $_, $_, $_, $_, $_, $_),
+def : Pattern<(GenericOp:$op $_, $_, $_, $_, $_, $_, $_),
               [(VectorizeLinalgOp)],
               [(Constraint<And<[
                 HasLinalgTransformMarker<"VECTORIZE">,
@@ -122,7 +122,7 @@ def : Pattern<(GenericOp:$op $_, $_, $_, $_, $_, $_, $_, $_),
 //===----------------------------------------------------------------------===//
 // Linalg generic permutation patterns.
 //===----------------------------------------------------------------------===//
-def : Pat<(GenericOp:$op $_, $_, $_, $_, $_, $_, $_, $_),
+def : Pat<(GenericOp:$op $_, $_, $_, $_, $_, $_, $_),
               (PermuteGenericLinalgOp<[1, 2, 0], "PERMUTE"> $op),
               [(Constraint<And<[
                 HasNoLinalgTransformMarker,
@@ -130,7 +130,7 @@ def : Pat<(GenericOp:$op $_, $_, $_, $_, $_, $_, $_, $_),
                 PreconditionPermuteGenericLinalgOp<[1, 2, 0]>
               ]>>)]>;
 
-def : Pat<(IndexedGenericOp:$op $_, $_, $_, $_, $_, $_, $_, $_),
+def : Pat<(IndexedGenericOp:$op $_, $_, $_, $_, $_, $_, $_),
               (PermuteGenericLinalgOp<[1, 2, 0], "PERMUTE"> $op),
               [(Constraint<And<[
                 HasNoLinalgTransformMarker,


        


More information about the Mlir-commits mailing list