[Mlir-commits] [mlir] 335d2df - [mlir][Python][Linalg] Add missing attributes to linalg ops

Nicolas Vasilache llvmlistbot at llvm.org
Thu Apr 1 01:21:00 PDT 2021


Author: Nicolas Vasilache
Date: 2021-04-01T08:16:50Z
New Revision: 335d2df5335f95d49c864ecdba4fd5731c7c3e89

URL: https://github.com/llvm/llvm-project/commit/335d2df5335f95d49c864ecdba4fd5731c7c3e89
DIFF: https://github.com/llvm/llvm-project/commit/335d2df5335f95d49c864ecdba4fd5731c7c3e89.diff

LOG: [mlir][Python][Linalg] Add missing attributes to linalg ops

This revision tightens up the handling of attributes for both named
and generic linalg ops.
To demonstrate the IR validity, a working e2e Linalg example is added.

Differential Revision: https://reviews.llvm.org/D99430

Added: 
    mlir/test/Bindings/Python/dialects/linalg/opsrun.py

Modified: 
    mlir/include/mlir-c/AffineMap.h
    mlir/include/mlir/IR/AffineMap.h
    mlir/lib/Bindings/Python/IRAffine.cpp
    mlir/lib/Bindings/Python/mlir/dialects/linalg/opdsl/lang/emitter.py
    mlir/lib/CAPI/IR/AffineMap.cpp
    mlir/lib/IR/AffineMap.cpp
    mlir/test/Bindings/Python/dialects/linalg/opdsl/emit_structured_generic.py
    mlir/test/Bindings/Python/dialects/linalg/ops.py

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir-c/AffineMap.h b/mlir/include/mlir-c/AffineMap.h
index de4f42f09f249..e35b7cc6b51d5 100644
--- a/mlir/include/mlir-c/AffineMap.h
+++ b/mlir/include/mlir-c/AffineMap.h
@@ -169,6 +169,17 @@ mlirAffineMapGetMajorSubMap(MlirAffineMap affineMap, intptr_t numResults);
 MLIR_CAPI_EXPORTED MlirAffineMap
 mlirAffineMapGetMinorSubMap(MlirAffineMap affineMap, intptr_t numResults);
 
+/// Computes the simplified affine maps resulting from dropping the symbols
+/// that do not appear in any of the individual maps in `affineMaps`.
+/// Asserts that all maps in `affineMaps` are normalized to the same number of
+/// dims and symbols.
+/// Takes a callback `populateResult` that is invoked once per compressed map
+/// `m` to store it at entry `idx` of the `result` container (passed as `res`).
+/// This allows returning without worrying about ownership considerations.
+MLIR_CAPI_EXPORTED void mlirAffineMapCompressUnusedSymbols(
+    MlirAffineMap *affineMaps, intptr_t size, void *result,
+    void (*populateResult)(void *res, intptr_t idx, MlirAffineMap m));
+
 #ifdef __cplusplus
 }
 #endif

diff  --git a/mlir/include/mlir/IR/AffineMap.h b/mlir/include/mlir/IR/AffineMap.h
index abc3e1b4a6fe7..2463452a8f06d 100644
--- a/mlir/include/mlir/IR/AffineMap.h
+++ b/mlir/include/mlir/IR/AffineMap.h
@@ -340,6 +340,11 @@ AffineMap simplifyAffineMap(AffineMap map);
 /// Drop the dims that are not used.
 AffineMap compressUnusedDims(AffineMap map);
 
+/// Drop the dims that are not used by any of the individual maps in `maps`
+/// and return one compressed map per input map. Asserts that all maps in
+/// `maps` are normalized to the same number of dims and symbols.
+SmallVector<AffineMap> compressUnusedDims(ArrayRef<AffineMap> maps);
+
 /// Drop the dims that are not listed in `unusedDims`.
 AffineMap compressDims(AffineMap map,
                        const llvm::SmallDenseSet<unsigned> &unusedDims);
@@ -347,6 +352,11 @@ AffineMap compressDims(AffineMap map,
 /// Drop the symbols that are not used.
 AffineMap compressUnusedSymbols(AffineMap map);
 
+/// Drop the symbols that are not used by any of the individual maps in `maps`
+/// and return one compressed map per input map. Asserts that all maps in
+/// `maps` are normalized to the same number of dims and symbols.
+SmallVector<AffineMap> compressUnusedSymbols(ArrayRef<AffineMap> maps);
+
 /// Drop the symbols that are not listed in `unusedSymbols`.
 AffineMap compressSymbols(AffineMap map,
                           const llvm::SmallDenseSet<unsigned> &unusedSymbols);

diff  --git a/mlir/lib/Bindings/Python/IRAffine.cpp b/mlir/lib/Bindings/Python/IRAffine.cpp
index 73a57d95e1586..5d3b790b35d0e 100644
--- a/mlir/lib/Bindings/Python/IRAffine.cpp
+++ b/mlir/lib/Bindings/Python/IRAffine.cpp
@@ -538,6 +538,23 @@ void mlir::python::populateIRAffine(py::module &m) {
              printAccum.parts.append(")");
              return printAccum.join();
            })
+      .def_static("compress_unused_symbols",
+                  [](py::list affineMaps, DefaultingPyMlirContext context) {
+                    SmallVector<MlirAffineMap> maps;
+                    pyListToVector<PyAffineMap, MlirAffineMap>(
+                        affineMaps, maps, "attempting to create an AffineMap");
+                    std::vector<MlirAffineMap> compressed(affineMaps.size());
+                    auto populate = [](void *result, intptr_t idx,
+                                       MlirAffineMap m) {
+                      static_cast<MlirAffineMap *>(result)[idx] = (m);
+                    };
+                    mlirAffineMapCompressUnusedSymbols(
+                        maps.data(), maps.size(), compressed.data(), populate);
+                    std::vector<PyAffineMap> res;
+                    for (auto m : compressed)
+                      res.push_back(PyAffineMap(context->getRef(), m));
+                    return res;
+                  })
       .def_property_readonly(
           "context",
           [](PyAffineMap &self) { return self.getContext().getObject(); },

diff  --git a/mlir/lib/Bindings/Python/mlir/dialects/linalg/opdsl/lang/emitter.py b/mlir/lib/Bindings/Python/mlir/dialects/linalg/opdsl/lang/emitter.py
index 2395a422e354b..682f191387017 100644
--- a/mlir/lib/Bindings/Python/mlir/dialects/linalg/opdsl/lang/emitter.py
+++ b/mlir/lib/Bindings/Python/mlir/dialects/linalg/opdsl/lang/emitter.py
@@ -19,6 +19,13 @@
     "emit_named_structured_op",
 ]
 
+def isa(cls : Type, ty : Type):
+  try:
+    cls(ty)  # Attempt the cast; a ValueError is treated as "not an instance".
+    return True
+  except ValueError:
+    return False
+
 def prepare_common_structured_op(op_config: LinalgStructuredOpConfig,
                                  *ins: Value,
                                  outs: Value):
@@ -37,6 +44,8 @@ def prepare_common_structured_op(op_config: LinalgStructuredOpConfig,
   outs, out_types = _infer_structured_outs(op_config, in_arg_defs, ins,
                                            out_arg_defs, outs)
 
+  result_types = [t for t in out_types if isa(RankedTensorType, t)]
+
   # Extract type vars for input/output based types.
   type_mapping = dict()  # type: Dict[str, Type]
   for arg_def, arg_element_type in zip(
@@ -48,30 +57,37 @@ def prepare_common_structured_op(op_config: LinalgStructuredOpConfig,
   # Emit the generic op.
   # TODO: Support emission of pure memref form.
   indexing_maps_attr = ArrayAttr.get(
-      [AffineMapAttr.get(am) for am in op_config.indexing_maps])
+      [AffineMapAttr.get(am)
+       # TODO: linalg verification does not currently allow symbols.
+       # Compress them for now.
+       for am in AffineMap.compress_unused_symbols(op_config.indexing_maps, Context.current)])
   iterator_types_attr = ArrayAttr.get(
       [StringAttr.get(s) for s in op_config.iterator_types])
+  sparse_attr = ArrayAttr.get(
+      [BoolAttr.get(False) for s in list(ins) + list(outs) if isa(RankedTensorType, s.type)])
+  if len(sparse_attr) == 0:
+    sparse_attr = None
 
-  return (all_arg_defs, in_arg_defs, out_arg_defs, outs, out_types,
-          type_mapping, indexing_maps_attr, iterator_types_attr)
+  return (all_arg_defs, in_arg_defs, out_arg_defs, outs, result_types,
+          type_mapping, indexing_maps_attr, iterator_types_attr, sparse_attr)
 
 
 def emit_generic_structured_op(op_config: LinalgStructuredOpConfig,
                                *ins: Value,
                                outs: Value = ()):
-  all_arg_defs, in_arg_defs, out_arg_defs, outs, out_types, \
-  type_mapping, indexing_maps_attr, iterator_types_attr =   \
+  all_arg_defs, in_arg_defs, out_arg_defs, outs, result_types, \
+  type_mapping, indexing_maps_attr, iterator_types_attr, sparse_attr =   \
      prepare_common_structured_op(op_config, *ins, outs = outs)
 
   generic_op = linalg.GenericOp(
-      result_tensors=out_types,
+      result_tensors=result_types,
       inputs=ins,
       outputs=outs,
       indexing_maps=indexing_maps_attr,
       iterator_types=iterator_types_attr,
       doc=None,  # TODO: Make optional.
       library_call=None,  # TODO: Make optional.
-      sparse=BoolAttr.get(False))  # TODO: Make optional.
+      sparse=sparse_attr)  # TODO: Make optional.
 
   # Construct the body.
   block_arg_names = _get_tensor_def_names(*in_arg_defs, *out_arg_defs)
@@ -84,7 +100,7 @@ def emit_generic_structured_op(op_config: LinalgStructuredOpConfig,
       body_builder.assign(assignment)
     body_builder.yield_outputs(*_get_tensor_def_names(*out_arg_defs))
 
-  if len(out_arg_defs) == 1:
+  if len(result_types) == 1:
     return generic_op.result
   else:
     return generic_op.results
@@ -95,8 +111,8 @@ def emit_named_structured_op(op_config: LinalgStructuredOpConfig,
                              op_class_name: str,
                              *ins: Value,
                              outs: Value = ()):
-  all_arg_defs, in_arg_defs, out_arg_defs, outs, out_types, \
-  type_mapping, indexing_maps_attr, iterator_types_attr =   \
+  all_arg_defs, in_arg_defs, out_arg_defs, outs, result_types, \
+  type_mapping, indexing_maps_attr, iterator_types_attr, sparse_attr =   \
      prepare_common_structured_op(op_config, *ins, outs = outs)
 
   # If we get here, there must exist a builtin class `op_class_name`.
@@ -107,11 +123,16 @@ def emit_named_structured_op(op_config: LinalgStructuredOpConfig,
     raise NotImplementedError(
         f"Unknown named op_name / op_class_name: {op_name} / {op_class_name}")
 
-  named_op = getattr(linalg, op_class_name)(ins, outs, out_types)
+  named_op = getattr(linalg, op_class_name)(ins, outs, result_types)
   linalgDialect = ctx.get_dialect_descriptor("linalg")
   fill_builtin_region(linalgDialect, named_op.operation)
+  # Note: mlir-linalg-ods-yaml-gen.cpp uses a special linalg.memoized_indexing_maps
+  # attribute that the non-yaml path does not. The non-yaml path hardcodes the 
+  # indexing_maps in C++ directly.
+  named_op.operation.attributes["linalg.memoized_indexing_maps"] = indexing_maps_attr
+  # iterator_types are hardcoded in C++ both in the yaml and non-yaml path.
 
-  if len(out_arg_defs) == 1:
+  if len(result_types) == 1:
     return named_op.result
   else:
     return named_op.results

diff  --git a/mlir/lib/CAPI/IR/AffineMap.cpp b/mlir/lib/CAPI/IR/AffineMap.cpp
index f532d5dae72e0..e0c07afc3b75e 100644
--- a/mlir/lib/CAPI/IR/AffineMap.cpp
+++ b/mlir/lib/CAPI/IR/AffineMap.cpp
@@ -137,3 +137,14 @@ MlirAffineMap mlirAffineMapGetMinorSubMap(MlirAffineMap affineMap,
                                           intptr_t numResults) {
   return wrap(unwrap(affineMap).getMinorSubMap(numResults));
 }
+
+void mlirAffineMapCompressUnusedSymbols(
+    MlirAffineMap *affineMaps, intptr_t size, void *result,
+    void (*populateResult)(void *res, intptr_t idx, MlirAffineMap m)) {
+  SmallVector<AffineMap> maps; // C++ counterparts of the `size` C handles.
+  for (intptr_t idx = 0; idx < size; ++idx)
+    maps.push_back(unwrap(affineMaps[idx]));
+  intptr_t idx = 0; // Output position reported to `populateResult`.
+  for (auto m : mlir::compressUnusedSymbols(maps))
+    populateResult(result, idx++, wrap(m)); // One call per compressed map.
+}

diff  --git a/mlir/lib/IR/AffineMap.cpp b/mlir/lib/IR/AffineMap.cpp
index dc9a5c54c7ff7..4e25f8c4deea5 100644
--- a/mlir/lib/IR/AffineMap.cpp
+++ b/mlir/lib/IR/AffineMap.cpp
@@ -543,6 +543,41 @@ AffineMap mlir::compressUnusedDims(AffineMap map) {
   return compressDims(map, unusedDims);
 }
 
+static SmallVector<AffineMap> // Joint compression of `maps` (shared helper).
+compressUnusedImpl(ArrayRef<AffineMap> maps,
+                   llvm::function_ref<AffineMap(AffineMap)> compressionFun) {
+  if (maps.empty()) // Nothing to compress.
+    return SmallVector<AffineMap>();
+  SmallVector<AffineExpr> allExprs; // Results of every map, concatenated.
+  allExprs.reserve(maps.size() * maps.front().getNumResults()); // Estimate.
+  unsigned numDims = maps.front().getNumDims(),
+           numSymbols = maps.front().getNumSymbols();
+  for (auto m : maps) {
+    assert(numDims == m.getNumDims() && numSymbols == m.getNumSymbols() &&
+           "expected maps with same num dims and symbols");
+    llvm::append_range(allExprs, m.getResults());
+  }
+  AffineMap unifiedMap = compressionFun( // Applied once to all results.
+      AffineMap::get(numDims, numSymbols, allExprs, maps.front().getContext()));
+  unsigned unifiedNumDims = unifiedMap.getNumDims(),
+           unifiedNumSymbols = unifiedMap.getNumSymbols();
+  ArrayRef<AffineExpr> unifiedResults = unifiedMap.getResults(); // Unsliced.
+  SmallVector<AffineMap> res;
+  res.reserve(maps.size());
+  for (auto m : maps) { // Slice the unified results back, one map per input.
+    res.push_back(AffineMap::get(unifiedNumDims, unifiedNumSymbols,
+                                 unifiedResults.take_front(m.getNumResults()),
+                                 m.getContext()));
+    unifiedResults = unifiedResults.drop_front(m.getNumResults()); // Advance.
+  }
+  return res;
+}
+
+SmallVector<AffineMap> mlir::compressUnusedDims(ArrayRef<AffineMap> maps) {
+  return compressUnusedImpl(maps, // Lambda forwards to the single-map overload.
+                            [](AffineMap m) { return compressUnusedDims(m); });
+}
+
 AffineMap
 mlir::compressSymbols(AffineMap map,
                       const llvm::SmallDenseSet<unsigned> &unusedSymbols) {
@@ -576,6 +611,11 @@ AffineMap mlir::compressUnusedSymbols(AffineMap map) {
   return compressSymbols(map, unusedSymbols);
 }
 
+SmallVector<AffineMap> mlir::compressUnusedSymbols(ArrayRef<AffineMap> maps) {
+  return compressUnusedImpl( // Lambda forwards to the single-map overload.
+      maps, [](AffineMap m) { return compressUnusedSymbols(m); });
+}
+
 AffineMap mlir::simplifyAffineMap(AffineMap map) {
   SmallVector<AffineExpr, 8> exprs;
   for (auto e : map.getResults()) {

diff  --git a/mlir/test/Bindings/Python/dialects/linalg/opdsl/emit_structured_generic.py b/mlir/test/Bindings/Python/dialects/linalg/opdsl/emit_structured_generic.py
index f27f79a4fb037..5445daefa49f0 100644
--- a/mlir/test/Bindings/Python/dialects/linalg/opdsl/emit_structured_generic.py
+++ b/mlir/test/Bindings/Python/dialects/linalg/opdsl/emit_structured_generic.py
@@ -37,9 +37,9 @@ def matmul_poly(A=TensorDef(TV.T1, S.M, S.K),
     # Note that these all have the same indexing maps. We verify the first and
     # then do more permutation tests on casting and body generation
     # behavior.
-    # CHECK: #[[$MAPA:.+]] = affine_map<(d0, d1, d2)[s0, s1, s2] -> (d0, d2)>
-    # CHECK: #[[$MAPB:.+]] = affine_map<(d0, d1, d2)[s0, s1, s2] -> (d2, d1)>
-    # CHECK: #[[$MAPC:.+]] = affine_map<(d0, d1, d2)[s0, s1, s2] -> (d0, d1)>
+    # CHECK: #[[$MAPA:.+]] = affine_map<(d0, d1, d2) -> (d0, d2)>
+    # CHECK: #[[$MAPB:.+]] = affine_map<(d0, d1, d2) -> (d2, d1)>
+    # CHECK: #[[$MAPC:.+]] = affine_map<(d0, d1, d2) -> (d0, d1)>
 
     # CHECK-LABEL: func @test_matmul_mono
     # CHECK-SAME:  %[[A:.+]]: tensor<4x16xf32>

diff  --git a/mlir/test/Bindings/Python/dialects/linalg/ops.py b/mlir/test/Bindings/Python/dialects/linalg/ops.py
index 489aa63f47fdf..afcb5820a2216 100644
--- a/mlir/test/Bindings/Python/dialects/linalg/ops.py
+++ b/mlir/test/Bindings/Python/dialects/linalg/ops.py
@@ -94,6 +94,7 @@ def named_form(lhs, rhs):
         init_result = linalg.InitTensorOp([4, 8], f32)
         # First check the named form with custom format
         #      CHECK: linalg.matmul
+        #  CHECK-NOT: linalg.memoized_indexing_maps
         # CHECK-SAME:    ins(%{{.*}} : tensor<4x16xf32>, tensor<16x8xf32>)
         # CHECK-SAME:   outs(%{{.*}} : tensor<4x8xf32>)
         # CHECK-SAME:   -> tensor<4x8xf32>
@@ -118,7 +119,7 @@ def named_form(lhs, rhs):
         # CHECK-NEXT:    std.mulf{{.*}} (f32, f32) -> f32
         # CHECK-NEXT:    std.addf{{.*}} (f32, f32) -> f32
         # CHECK-NEXT:    linalg.yield{{.*}} (f32) -> ()
-        # CHECK-NEXT:    {operand_segment_sizes = dense<[2, 1]> : vector<2xi32>} : 
+        # CHECK-NEXT:    {linalg.memoized_indexing_maps{{.*}}operand_segment_sizes = dense<[2, 1]> : vector<2xi32>} : 
         # CHECK-SAME: (tensor<4x16xf32>, tensor<16x8xf32>, tensor<4x8xf32>) -> tensor<4x8xf32>
         return linalg.matmul(lhs, rhs, outs=[init_result.result])
 

diff  --git a/mlir/test/Bindings/Python/dialects/linalg/opsrun.py b/mlir/test/Bindings/Python/dialects/linalg/opsrun.py
new file mode 100644
index 0000000000000..c46863ba90360
--- /dev/null
+++ b/mlir/test/Bindings/Python/dialects/linalg/opsrun.py
@@ -0,0 +1,105 @@
+# RUN: %PYTHON %s 2>&1 | FileCheck %s
+
+import sys
+from mlir.ir import *
+from mlir.dialects import builtin
+from mlir.dialects import linalg
+from mlir.dialects import std
+from mlir.passmanager import *
+from mlir.execution_engine import *
+
+# Log everything to stderr and flush so that we have a unified stream to match
+# errors/info emitted by MLIR to stderr.
+def log(*args):
+  print(*args, file=sys.stderr)
+  sys.stderr.flush()  # Flush immediately to keep a single, ordered stderr stream.
+
+boilerplate = """
+func @main() -> f32 attributes {llvm.emit_c_interface} {
+  %v0 = constant 0.0 : f32
+  %v1 = constant 1.0 : f32
+  %v2 = constant 2.0 : f32
+
+  %A = memref.alloc() : memref<4x16xf32>
+  %B = memref.alloc() : memref<16x8xf32>
+  %C = memref.alloc() : memref<4x8xf32>
+  linalg.fill(%A, %v1) : memref<4x16xf32>, f32
+  linalg.fill(%B, %v2) : memref<16x8xf32>, f32
+  linalg.fill(%C, %v0) : memref<4x8xf32>, f32
+
+  call @matmul_on_buffers(%A, %B, %C) : 
+    (memref<4x16xf32>, memref<16x8xf32>, memref<4x8xf32>) -> ()
+
+  %c0 = constant 0 : index
+  %0 = memref.load %C[%c0, %c0] : memref<4x8xf32>
+
+  // TODO: FFI-based solution to allow testing and printing with python code.
+  return %0 : f32
+}
+"""
+
+def transform(module):
+  import mlir.conversions  # NOTE(review): these three modules are never used by
+  import mlir.dialects.linalg.passes  # name; presumably imported to register the
+  import mlir.transforms  # passes named in the pipeline string below — confirm.
+
+  # TODO: Allow cloning functions from one module to another.
+  # At the moment we have to resort to string concatenation.
+  mod = Module.parse(
+    str(module.operation.regions[0].blocks[0].operations[0].operation) +  # First op in the module body, as text.
+    boilerplate)
+  pm = PassManager.parse("func(convert-linalg-to-loops, convert-scf-to-std)," + 
+                         "convert-vector-to-llvm," + 
+                         "convert-std-to-llvm")
+  pm.run(mod)  # NOTE(review): `mod` is returned right after, so this appears to rewrite it in place — confirm.
+  return mod
+
+def test_builtin():
+  with Context() as ctx, Location.unknown():
+    module = Module.create()
+    f32 = F32Type.get()
+    with InsertionPoint(module.body):
+      @builtin.FuncOp.from_py_func(MemRefType.get((4, 16), f32),
+                                   MemRefType.get((16, 8), f32),
+                                   MemRefType.get((4, 8), f32))
+      def matmul_on_buffers(lhs, rhs, out):
+        linalg.matmul(lhs, rhs, outs=[out])  # No emit_generic flag: presumably the named linalg.matmul form — confirm.
+    
+    execution_engine = ExecutionEngine(transform(module))
+
+    # TODO: FFI-based solution to allow testing and printing with python code.
+    # Prepare arguments: one result f32.
+    # Arguments must be passed as pointers.
+    c_float_p = ctypes.c_float * 1  # NOTE(review): no explicit `import ctypes` in this file; presumably supplied by a star import above — confirm.
+    res = c_float_p(-1.)  # Starts at -1.0, which differs from the expected 32.0.
+    execution_engine.invoke("main", res)
+
+    log('RESULT: ', res[0])  # main returns C[0, 0]: 16 products of 1.0 * 2.0 per the fills in `boilerplate`.
+    # CHECK: RESULT: 32.0
+
+test_builtin()
+
+def test_generic():
+  with Context() as ctx, Location.unknown():
+    module = Module.create()
+    f32 = F32Type.get()
+    with InsertionPoint(module.body):
+      @builtin.FuncOp.from_py_func(MemRefType.get((4, 16), f32),
+                                   MemRefType.get((16, 8), f32),
+                                   MemRefType.get((4, 8), f32))
+      def matmul_on_buffers(lhs, rhs, out):
+        linalg.matmul(lhs, rhs, outs=[out], emit_generic=True)  # Presumably emits the generic-op form of the same matmul — confirm.
+    
+    execution_engine = ExecutionEngine(transform(module))
+
+    # TODO: FFI-based solution to allow testing and printing with python code.
+    # Prepare arguments: one result f32.
+    # Arguments must be passed as pointers.
+    c_float_p = ctypes.c_float * 1  # NOTE(review): no explicit `import ctypes` in this file; presumably supplied by a star import above — confirm.
+    res = c_float_p(-1.)  # Starts at -1.0, which differs from the expected 32.0.
+    execution_engine.invoke("main", res)
+
+    log('RESULT: ', res[0])  # main returns C[0, 0]: 16 products of 1.0 * 2.0 per the fills in `boilerplate`.
+    # CHECK: RESULT: 32.0
+
+test_generic()


        


More information about the Mlir-commits mailing list