[Mlir-commits] [mlir] 662f9bf - [mlir][linalg][python] Adapt the OpDSL to use scalars.

Tobias Gysi llvmlistbot at llvm.org
Tue Jun 15 06:20:19 PDT 2021


Author: Tobias Gysi
Date: 2021-06-15T12:54:00Z
New Revision: 662f9bff337b99819301113fc8634eb5123b9e23

URL: https://github.com/llvm/llvm-project/commit/662f9bff337b99819301113fc8634eb5123b9e23
DIFF: https://github.com/llvm/llvm-project/commit/662f9bff337b99819301113fc8634eb5123b9e23.diff

LOG: [mlir][linalg][python] Adapt the OpDSL to use scalars.

The patch replaces the existing capture functionality with the scalar operands introduced by https://reviews.llvm.org/D104109. Scalar operands behave like tensor operands except that they are not indexed. As a result, ScalarDefs can be accessed directly, since no indexing expression is needed.
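
For example, the fill definition from the updated OpDSL tests in this patch passes the fill value as a scalar operand that is used directly in the body, while the output tensor is still subscripted with affine dimensions:

  @linalg_structured_op
  def fill(value=ScalarDef(T), O=TensorDef(T, S.M, S.K, output=True)):
    O[D.m, D.n] = value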

The patch only updates the OpDSL. The C++ side is updated by a follow-up patch.
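
At the call site, scalar operands are passed as leading positional inputs instead of through the removed captures keyword argument, as in the updated emit_structured_generic.py test:

  # Before: fill_rng(outs=[init_result], captures=[min, max, seed])
  fill_rng(min, max, seed, outs=[init_result])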

Differential Revision: https://reviews.llvm.org/D104220

Added: 
    

Modified: 
    mlir/python/mlir/dialects/linalg/opdsl/lang/comprehension.py
    mlir/python/mlir/dialects/linalg/opdsl/lang/config.py
    mlir/python/mlir/dialects/linalg/opdsl/lang/dsl.py
    mlir/python/mlir/dialects/linalg/opdsl/lang/emitter.py
    mlir/python/mlir/dialects/linalg/opdsl/lang/scalar_expr.py
    mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py
    mlir/test/python/dialects/linalg/opdsl/arguments.py
    mlir/test/python/dialects/linalg/opdsl/assignments.py
    mlir/test/python/dialects/linalg/opdsl/emit_structured_generic.py
    mlir/test/python/dialects/linalg/opsrun.py

Removed: 
    


################################################################################
diff --git a/mlir/python/mlir/dialects/linalg/opdsl/lang/comprehension.py b/mlir/python/mlir/dialects/linalg/opdsl/lang/comprehension.py
index 2ac0641a309f7..fe067d6947138 100644
--- a/mlir/python/mlir/dialects/linalg/opdsl/lang/comprehension.py
+++ b/mlir/python/mlir/dialects/linalg/opdsl/lang/comprehension.py
@@ -8,7 +8,7 @@
 represent actual op definitions (i.e. YAML).
 """
 
-from typing import Any, Dict, List, Optional, Sequence, Set, Tuple, Union
+from typing import Any, Dict, List, Optional, Sequence, Set, Tuple
 
 from mlir import ir as _ir
 
@@ -50,7 +50,7 @@ def visit_affine_exprs(expr):
     self.visit_tensor_exprs(visit_affine_exprs)
     return results
 
-  def collect_uses(self, uses: Set["TensorUse"]):
+  def collect_tensor_uses(self, uses: Set["TensorUse"]):
     """Collects all TensorUses reachable through this expression."""
 
     def visit_tensor_use(expr):
@@ -68,14 +68,14 @@ def visit_index(expr):
 
     self.visit_tensor_exprs(visit_index)
 
-  def collect_captures(self, captures: Set["CaptureDef"]):
-    """Collects all CaptureDefs reachable through this expression."""
+  def collect_scalar_uses(self, uses: Set["ScalarDef"]):
+    """Collects all ScalarDefs reachable through this expression."""
 
-    def visit_capture_def(expr):
-      if isinstance(expr, CaptureDef):
-        captures.add(expr)
+    def visit_scalar_def(expr):
+      if isinstance(expr, ScalarDef):
+        uses.add(expr)
 
-    self.visit_tensor_exprs(visit_capture_def)
+    self.visit_tensor_exprs(visit_scalar_def)
 
   def __add__(self, rhs: "TensorExpression") -> "TensorExpression":
     return PrimFn.add(self, rhs)
@@ -101,19 +101,19 @@ class TensorUse(TensorExpression):
     TensorDef.__setitem__
   """
 
-  def __init__(self, tensor_def: "TensorDef", indices: Sequence[AffineExprDef]):
-    self.tensor_def = tensor_def
+  def __init__(self, operand_def: "OperandDef",
+               indices: Sequence[AffineExprDef]):
+    self.operand_def = operand_def
     self.indices = tuple(indices)
 
   def to_scalar_expression(self) -> ScalarExpression:
-    assert self.tensor_def.tensor_name is not None
-    return ScalarArg(self.tensor_def.tensor_name).expr()
+    return ScalarArg(self.tensor_name).expr()
 
   @property
   def tensor_name(self) -> str:
-    n = self.tensor_def.tensor_name
-    assert n is not None, "TensorDef not attached"
-    return n
+    name = self.operand_def.name
+    assert name is not None, "TensorDef not attached"
+    return name
 
   def __iadd__(self, rhs: TensorExpression) -> TensorExpression:
     return ReduceFn.add(*self._compute_reduce_dims(rhs))(rhs)
@@ -133,40 +133,57 @@ def __repr__(self):
     return f"{self.tensor_name}[{', '.join([repr(i) for i in self.indices])}]"
 
 
-class TensorDef:
-  """Bookkeeping of a single registered tensor, held in dict by name."""
+class OperandDef:
+  """Definition of a Tensor or Scalar operand passed to an operation."""
 
-  def __init__(self,
-               type_var: TypeVar,
-               *shape: AffineExprDef,
-               indexing_map: Optional[_ir.AffineMap] = None,
-               output: bool = False):
+  def __init__(self, type_var: TypeVar, shape: Sequence[AffineExprDef],
+               scalar: bool, output: bool):
     if not isinstance(type_var, TypeVar):
-      raise ValueError(f"TensorDef requires a TypeVar. Got: {repr(type_var)}")
+      raise ValueError(f"OperandDef requires a TypeVar. Got: {repr(type_var)}")
     self.owner = None  # type: Optional["LinalgOpDef"]
     self.type_var = type_var
     self.shape = shape
-    self.indexing_map = indexing_map
+    self.scalar = scalar
     self.output = output
-    self.tensor_name = None  # type: Optional[str]
+    self.name = None  # type: Optional[str]
     self.registered_index = -1  # type: int
 
-  @property
-  def rank(self) -> int:
-    """The rank of the tensor."""
-    return len(self.shape)
-
-  def attach(self, index: int, tensor_name: str, owner: "LinalgOpDef"):
+  def attach(self, index: int, name: str, owner: "LinalgOpDef"):
     if self.owner:
-      raise ValueError(f"TensorDef already registered with op: {self}")
+      raise ValueError(f"OperandDef already registered with op: {self}")
     self.registered_index = index
-    self.tensor_name = tensor_name
+    self.name = name
     self.owner = owner
 
+  def __hash__(self):
+    return hash(id(self))
+
+  def __repr__(self):
+    output = "OUTPUT " if self.output else ""
+    scalar = "SCALAR " if self.scalar else ""
+    return (f"{self.name}:OperandDef({output}{scalar}"
+            f"{repr(self.type_var)}, shape={self.shape})")
+
+
+class TensorDef:
+  """Tensor operand definition.
+
+  Tensor operands are indexed using the associated indexing_map when forwarded
+  to the body of the structured op. A unique name identifies the tensor operands
+  and an index determines their position in the operation's parameter list.
+  """
+
+  def __init__(self,
+               type_var: TypeVar,
+               *shape: AffineExprDef,
+               output: bool = False):
+    self.operand_def = OperandDef(type_var, shape, False, output)
+
   def __getitem__(self, dims) -> TensorUse:
-    assert self.owner, "TensorDef is not attached to an op"
+    assert self.operand_def.owner, "TensorDef is not attached to an op"
     state = AffineBuildState(
-        global_state=self.owner._affine_state, allow_new_symbols=False)
+        global_state=self.operand_def.owner._affine_state,
+        allow_new_symbols=False)
     if not isinstance(dims, tuple):
       dims = (dims,)  # Handle single subscript case.
     # Special case: (None) is a 0d-scalar use.
@@ -179,7 +196,7 @@ def __getitem__(self, dims) -> TensorUse:
         raise KeyError(
             "A TensorDef can only be subscripted by a tuple of affine dims")
       exprs.append(expr_def)
-    return TensorUse(self, exprs)
+    return TensorUse(self.operand_def, exprs)
 
   def __setitem__(self, dims, value):
     """Creates a new 1:1 comprehension by binding this tensor to an expression.
@@ -192,46 +209,28 @@ def __setitem__(self, dims, value):
                        f"Got: {repr(value)}")
     use = self[dims]
     comp = Comprehension((use, value))
-    self.owner.comprehensions.append(comp)
+    self.operand_def.owner.comprehensions.append(comp)
 
-  def __hash__(self):
-    return hash(id(self))
 
-  def __repr__(self):
-    output = "OUTPUT " if self.output else ""
-    return (f"{self.tensor_name}:TensorDef({output}{repr(self.type_var)}, "
-            f"shape={self.shape})")
-
-
-class CaptureDef(TensorExpression):
-  """Defines an SSA value captured by the operation.
+class ScalarDef(TensorExpression):
+  """Scalar operand definition.
 
-  The captured SSA values are not indexed by the indexing_maps of the
-  structured op (as opposed to memrefs and tensors). A unique name
-  identifies the captures and an index determines their position the
-  operation's parameter list.
+  Scalar operands are forwarded to the body of the structured op as they are.
+  A unique name identifies the scalars and an index determines their position in
+  the operation's parameter list.
   """
 
   def __init__(self, type_var: TypeVar):
-    if not isinstance(type_var, TypeVar):
-      raise ValueError(f"CaptureDef requires a TypeVar. Got: {repr(type_var)}")
-    self.owner = None  # type: Optional["LinalgOpDef"]
-    self.type_var = type_var
-    self.capture_name = None  # type: Optional[str]
-    self.registered_index = -1  # type: int
+    self.operand_def = OperandDef(type_var, (), True, False)
 
-  def attach(self, index: int, capture_name: str, owner: "LinalgOpDef"):
-    if self.owner:
-      raise ValueError(f"CaptureDef already registered with op: {self}")
-    self.registered_index = index
-    self.capture_name = capture_name
-    self.owner = owner
+  @property
+  def scalar_name(self) -> str:
+    name = self.operand_def.name
+    assert name is not None, "ScalarDef not attached"
+    return name
 
   def to_scalar_expression(self) -> ScalarExpression:
-    return ScalarCapture(self.capture_name).expr()
-
-  def __repr__(self):
-    return (f"{self.capture_name}:CaptureDef({repr(self.type_var)})")
+    return ScalarArg(self.scalar_name).expr()
 
 
 class Comprehension:
@@ -472,43 +471,34 @@ def __init__(self,
                doc: Optional[str] = None):
     self.metadata = OpMetadataDef(
         name=name, cpp_class_name=cpp_class_name, doc=doc)
-    self.registered_tensors = dict()  # type: Dict[str, TensorDef]
-    self.registered_captures = dict()  # type: Dict[str, CaptureDef]
+    self.registered_operands = dict()  # type: Dict[str, OperandDef]
     self.comprehensions = list()  # type: List[Comprehension]
     self._affine_state = AffineBuildState()
 
   @property
-  def inputs(self) -> Sequence[TensorDef]:
-    return [t for t in self.registered_tensors.values() if not t.output]
+  def outputs(self) -> Sequence[OperandDef]:
+    return [
+        operand for operand in self.registered_operands.values()
+        if operand.output
+    ]
 
-  @property
-  def outputs(self) -> Sequence[TensorDef]:
-    return [t for t in self.registered_tensors.values() if t.output]
-
-  def add_tensor(self, tensor_name: str, tensor: TensorDef):
-    """Registers a tensor."""
-    if tensor_name in self.registered_tensors:
-      raise ValueError(f"Tensor {tensor_name} is already registered "
-                       f"to {self.registered_tensors['tensor_name']}")
-    tensor.attach(len(self.registered_tensors), tensor_name, self)
-    self.registered_tensors[tensor_name] = tensor
-
-  def add_capture(self, capture_name: str, capture: CaptureDef):
-    """Registers a capture."""
-    if capture_name in self.registered_captures:
-      raise ValueError(f"Capture {capture_name} is already registered "
-                       f"to {self.registered_captures['capture_name']}")
-    capture.attach(len(self.registered_captures), capture_name, self)
-    self.registered_captures[capture_name] = capture
+  def add_operand(self, name: str, operand: OperandDef):
+    """Registers an operand."""
+    if name in self.registered_operands:
+      raise ValueError(f"The operand {name} is already registered "
+                       f"to {self.registered_operands['name']}")
+    if not operand.output and self.outputs:
+      raise ValueError(f"The operand {name} is an input registered after "
+                       f"the output {self.outputs[-1]}")
+    operand.attach(len(self.registered_operands), name, self)
+    self.registered_operands[name] = operand
 
   def __repr__(self):
     lines = [
         f"LinalgOpDef({self.metadata.name} -> {self.metadata.cpp_class_name},"
     ]
-    for name, tensor in self.registered_tensors.items():
-      lines.append(f"  {tensor}")
-    for name, capture in self.registered_captures.items():
-      lines.append(f"  {capture}")
+    for name, operand in self.registered_operands.items():
+      lines.append(f"  {operand}")
     if self.comprehensions:
       lines[-1] += " {"
       for comprehension in self.comprehensions:

diff --git a/mlir/python/mlir/dialects/linalg/opdsl/lang/config.py b/mlir/python/mlir/dialects/linalg/opdsl/lang/config.py
index 9026e2030e1f2..6dd86334b95a5 100644
--- a/mlir/python/mlir/dialects/linalg/opdsl/lang/config.py
+++ b/mlir/python/mlir/dialects/linalg/opdsl/lang/config.py
@@ -18,11 +18,7 @@
 from .comprehension import *
 from .yaml_helper import *
 
-__all__ = [
-    "LinalgStructuredOpConfig",
-    "LinalgOpConfig",
-    "TensorDefConfig",
-]
+__all__ = ["LinalgStructuredOpConfig", "LinalgOpConfig", "OperandDefConfig"]
 
 
 def _serialize_affine_map(affine_map: _ir.AffineMap) -> str:
@@ -43,49 +39,42 @@ def __repr__(self):
     return f"Use({self.tensor_use}, indexing_map={self.indexing_map})"
 
 
-class TensorDefConfig(YAMLObject):
-  """Wrapper around a TensorDef with additional context-bound state."""
-  yaml_tag = "LinalgTensorDef"
+class OperandDefConfig(YAMLObject):
+  """Wrapper containing an operand definition with additional state."""
+  yaml_tag = "!LinalgOperandDefConfig"
 
-  def __init__(self, tensor_def: TensorDef, shape_map: _ir.AffineMap):
-    self.tensor_def = tensor_def
-    self.shape_map = shape_map
+  def __init__(self,
+               operand_def: OperandDef,
+               shape_map: Optional[_ir.AffineMap] = None):
+    self.operand_def = operand_def
+    self.shape_map = shape_map  # type: Optional[_ir.AffineMap]
     self.indexing_map = None  # type: Optional[_ir.AffineMap]
 
   @property
-  def usage(self) -> str:
-    if self.tensor_def.output:
-      return "output"
-    else:
-      return "input"
-
-  def to_yaml_custom_dict(self):
-    return dict(
-        name=self.tensor_def.tensor_name,
-        usage=self.usage,
-        shape=_serialize_affine_map(self.shape_map),
-        element_type_var=self.tensor_def.type_var.name,
-    )
-
-  def __repr__(self):
-    return f"Def({self.tensor_def}, shape_map={self.shape_map}, indexing_map={self.indexing_map})"
+  def name(self) -> str:
+    return self.operand_def.name
 
+  @property
+  def type_var(self) -> TypeVar:
+    return self.operand_def.type_var
 
-class CaptureDefConfig(YAMLObject):
-  """Wrapper around a CaptureDef."""
-  yaml_tag = "LinalgCaptureDef"
-
-  def __init__(self, capture_def: CaptureDef):
-    self.capture_def = capture_def
+  @property
+  def usage(self) -> str:
+    if self.operand_def.output:
+      return "output"
+    return "input"
 
   def to_yaml_custom_dict(self):
-    return dict(
-        name=self.capture_def.capture_name,
-        type_var=self.capture_def.type_var.name,
-    )
+    self_dict = dict(name=self.name)
+    self_dict["usage"] = self.usage
+    if not self.operand_def.scalar:
+      self_dict["shape"] = _serialize_affine_map(self.shape_map)
+    self_dict["type_var"] = self.type_var.name
+    return self_dict
 
   def __repr__(self):
-    return f"Def({self.capture_def})"
+    return (f"OperandDefConfig({self.operand_def}, "
+            f"shape_map={self.shape_map}, indexing_map={self.indexing_map})")
 
 
 class LinalgIndexingMapsConfig(YAMLObject):
@@ -124,67 +113,73 @@ def __init__(self,
     self.context = context if context is not None else _ir.Context()
     self.affine_state = AffineBuildState()
     self.writes = list()  # type: List[Tuple[TensorUse, TensorExpression]]
-    self.tensor_args = dict()  # type: Dict[TensorDef, TensorDefConfig]
-    self.capture_args = dict()  # type: Dict[CaptureDef, CaptureDefConfig]
+    self.operands = dict()  # type: Dict[OperandDef, OperandDefConfig]
     self.uses = dict()  # type: Dict[TensorUse, TensorUseConfig]
 
     # Compute the ordered set of writes and collect the tensor, capture, and
     # index uses.
-    collected_uses = set()
-    collected_captures = set()
+    collected_tensor_uses = set()
+    collected_scalar_uses = set()
     collected_indices = set()
     for write_use, read_use in zip(comprehension.definitions,
                                    comprehension.values):
       self.writes.append((write_use, read_use))
 
     for write_use, read_use in self.writes:
-      collected_uses.add(write_use)
-      read_use.collect_uses(collected_uses)
-      read_use.collect_captures(collected_captures)
+      collected_tensor_uses.add(write_use)
+      read_use.collect_tensor_uses(collected_tensor_uses)
+      read_use.collect_scalar_uses(collected_scalar_uses)
       read_use.collect_indices(collected_indices)
 
     # Need to add all definitions before uses, so process twice.
-    for use in collected_uses:
-      self.add_tensor_arg(use.tensor_def)
-    for capture in collected_captures:
-      self.add_capture_arg(capture)
-    for use in collected_uses:
-      self.add_use(use)
+    for use in collected_tensor_uses:
+      self.add_operand(use.operand_def)
+    for use in collected_scalar_uses:
+      self.add_operand(use.operand_def)
+    for use in collected_tensor_uses:
+      self.add_tensor_use(use)
 
     # Now normalize all defs and uses indexing maps now that full count of
     # dims and symbols are known.
     for cuse in self.uses.values():
       cuse.indexing_map = self._normalize_affine_map(cuse.indexing_map)
-    for cdef in self.tensor_args.values():
-      cdef.shape_map = self._normalize_affine_map(
-          cdef.shape_map, with_dims=False)
+    for cdef in self.operands.values():
+      if not cdef.operand_def.scalar:
+        cdef.shape_map = self._normalize_affine_map(
+            cdef.shape_map, with_dims=False)
 
     # Now for each write use, propagate the indexing maps from the use to the
     # tensor, ensuring that there are not conflicts.
     for write_use, _ in self.writes:
-      write_tensor_def = self.tensor_args[write_use.tensor_def]
-      if write_tensor_def.indexing_map:
+      write_tensor_config = self.operands[write_use.operand_def]
+      if write_tensor_config.indexing_map:
         raise ValueError(
-            f"Unexpected multi-write to a single tensor: {write_tensor_def}")
-      write_tensor_def.indexing_map = self.uses[write_use].indexing_map
+            f"Unexpected multi-write to a single tensor: {write_tensor_config}")
+      write_tensor_config.indexing_map = self.uses[write_use].indexing_map
 
     # For each read use, propagate the indexing maps from the use to the
     # tensor, ensuring that there are not conflicts.
     for _, read_expr in self.writes:
       read_uses = set()  # type: Set[TensorUse]
-      read_expr.collect_uses(read_uses)
+      read_expr.collect_tensor_uses(read_uses)
       for read_use in read_uses:
-        read_tensor_def = self.tensor_args[read_use.tensor_def]
-        if (read_tensor_def.indexing_map and
-            read_tensor_def.indexing_map != self.uses[read_use].indexing_map):
+        read_operand_config = self.operands[read_use.operand_def]
+        if (read_operand_config.indexing_map and
+            read_operand_config.indexing_map !=
+            self.uses[read_use].indexing_map):
           raise ValueError(
               f"Unexpected multi-read of a tensor with 
diff erent accesses:"
-              f"{read_tensor_def} vs {read_use}")
-        read_tensor_def.indexing_map = self.uses[read_use].indexing_map
+              f"{read_operand_config} vs {read_use}")
+        read_operand_config.indexing_map = self.uses[read_use].indexing_map
+
+    # Set the indexing map of all scalar uses to the empty map.
+    for operand_config in self.operands.values():
+      if operand_config.operand_def.scalar:
+        operand_config.indexing_map = self._create_empty_affine_map()
 
     # Sanity check that all defs have an indexing map.
-    assert all(d.indexing_map for d in self.tensor_args.values()), (
-        f"Missing indexing map on TensorDef: {self.tensor_args}")
+    assert all(d.indexing_map for d in self.operands.values()), (
+        f"Missing indexing map on OperandConfigDef: {self.operands}")
 
     # Collect reduction dims and ensure all the same.
     all_reduction_dims = set(comprehension.all_reduction_dims)
@@ -209,22 +204,10 @@ def __init__(self,
     ]
 
   @property
-  def ordered_tensor_args(self) -> Sequence[TensorDefConfig]:
+  def ordered_operands(self) -> Sequence[OperandDefConfig]:
     return sorted(
-        self.tensor_args.values(),
-        key=lambda tdc: tdc.tensor_def.registered_index)
-
-  @property
-  def ordered_tensor_uses(self) -> Sequence[TensorUseConfig]:
-    return sorted(
-        self.uses.values(),
-        key=lambda tuc: tuc.tensor_use.tensor_def.registered_index)
-
-  @property
-  def ordered_capture_args(self) -> Sequence[CaptureDefConfig]:
-    return sorted(
-        self.capture_args.values(),
-        key=lambda cdc: cdc.capture_def.registered_index)
+        self.operands.values(),
+        key=lambda operand: operand.operand_def.registered_index)
 
   @property
   def ordered_dims(self) -> Sequence[Tuple[str, int]]:
@@ -238,7 +221,7 @@ def ordered_dims(self) -> Sequence[Tuple[str, int]]:
 
   @property
   def indexing_maps(self) -> Sequence[_ir.AffineMap]:
-    return [use.indexing_map for use in self.ordered_tensor_uses]
+    return [d.indexing_map for d in self.ordered_operands]
 
   @property
   def iterator_types(self) -> Sequence[str]:
@@ -251,23 +234,25 @@ def get_type(symbolic_name, position):
 
     return [get_type(*dim) for dim in self.ordered_dims]
 
-  def add_tensor_arg(self, tensor_def: TensorDef):
-    if tensor_def in self.tensor_args:
+  def add_operand(self, operand_def: OperandDef):
+    if operand_def in self.operands:
+      return
+    if operand_def.scalar:
+      self.operands[operand_def] = OperandDefConfig(operand_def)
       return
     with self.context:
       local_state = AffineBuildState(
           global_state=self.affine_state, allow_new_dims=False)
       exprs = []
-      for expr in tensor_def.shape:
+      for expr in operand_def.shape:
         exprs.append(expr.build(state=local_state))
       assert local_state.local_dim_count == 0
-      indexing_map = _ir.AffineMap.get(
+      shape_map = _ir.AffineMap.get(
           dim_count=0, symbol_count=local_state.symbol_count, exprs=exprs)
+      def_config = OperandDefConfig(operand_def, shape_map)
+      self.operands[operand_def] = def_config
 
-      def_config = TensorDefConfig(tensor_def, indexing_map)
-      self.tensor_args[tensor_def] = def_config
-
-  def add_use(self, tensor_use: TensorUse):
+  def add_tensor_use(self, tensor_use: TensorUse):
     if tensor_use in self.uses:
       return
     with self.context:
@@ -285,11 +270,13 @@ def add_use(self, tensor_use: TensorUse):
       use_config = TensorUseConfig(tensor_use, indexing_map)
       self.uses[tensor_use] = use_config
 
-  def add_capture_arg(self, capture_def: CaptureDef):
-    if capture_def in self.capture_args:
-      return
-    def_config = CaptureDefConfig(capture_def)
-    self.capture_args[capture_def] = def_config
+  def _create_empty_affine_map(self) -> _ir.AffineMap:
+    """Create an affine map with an empty range."""
+    with self.context:
+      return _ir.AffineMap.get(
+          dim_count=self.affine_state.dim_count,
+          symbol_count=self.affine_state.symbol_count,
+          exprs=list())
 
   def _normalize_affine_map(self,
                             affine_map: _ir.AffineMap,
@@ -302,9 +289,7 @@ def _normalize_affine_map(self,
           exprs=list(affine_map.results))
 
   def to_yaml_custom_dict(self):
-    self_dict = dict(args=self.ordered_tensor_args)
-    if self.ordered_capture_args:
-      self_dict["captures"] = self.ordered_capture_args
+    self_dict = dict(args=self.ordered_operands)
     # TODO: Refactor the hierarchy internally when supporting more
     # than static (preserving this serialized form).
     self_dict["indexing_maps"] = LinalgIndexingMapsConfig(
@@ -315,11 +300,8 @@ def to_yaml_custom_dict(self):
 
   def __repr__(self):
     lines = [f"LinalgGenericOpConfig(reduction_dims={self.reduction_dims},"]
-    lines.append("tensor_args=[")
-    for def_config in self.ordered_tensor_args:
-      lines.append(f"  {repr(def_config)}")
-    lines.append("], capture_args=[")
-    for def_config in self.ordered_capture_args:
+    lines.append("operands=[")
+    for def_config in self.ordered_operands:
       lines.append(f"  {repr(def_config)}")
     lines.append("], indexing_maps=[")
     for m in self.indexing_maps:

diff --git a/mlir/python/mlir/dialects/linalg/opdsl/lang/dsl.py b/mlir/python/mlir/dialects/linalg/opdsl/lang/dsl.py
index 428eadfe01681..191b1b34fd836 100644
--- a/mlir/python/mlir/dialects/linalg/opdsl/lang/dsl.py
+++ b/mlir/python/mlir/dialects/linalg/opdsl/lang/dsl.py
@@ -53,8 +53,8 @@ def __call__(self, *args, emit_generic: bool = False, **kwargs):
         False, a named form is emitted (which must have been built in to the
         compiler).
     """
-    op_configs = LinalgOpConfig.from_linalg_op_def(self.model,
-                                                   context=ir.Context.current)
+    op_configs = LinalgOpConfig.from_linalg_op_def(
+        self.model, context=ir.Context.current)
 
     if len(op_configs) != 1:
       # TODO: Support composite ops.
@@ -63,8 +63,9 @@ def __call__(self, *args, emit_generic: bool = False, **kwargs):
 
     ctx = ir.Context.current
     linalgDialect = ctx.get_dialect_descriptor("linalg")
-    fully_qualified_name = 'linalg.' + self.op_name
-    emit_generic = (emit_generic or not ctx.is_registered_operation(fully_qualified_name))
+    fully_qualified_name = "linalg." + self.op_name
+    emit_generic = (
+        emit_generic or not ctx.is_registered_operation(fully_qualified_name))
 
     op_config = op_configs[0]
     if op_config.structured_op:
@@ -72,9 +73,9 @@ def __call__(self, *args, emit_generic: bool = False, **kwargs):
         return emit_generic_structured_op(op_config.structured_op, *args,
                                           **kwargs)
       else:
-        return emit_named_structured_op(
-          op_config.structured_op, self.op_name,
-          self.model.metadata.cpp_class_name, *args, **kwargs)
+        return emit_named_structured_op(op_config.structured_op, self.op_name,
+                                        self.model.metadata.cpp_class_name,
+                                        *args, **kwargs)
 
     raise NotImplementedError(
         f"Emission of linalg op type not supported: {op_config}")
@@ -86,9 +87,8 @@ def linalg_structured_op(dsl_func=None,
                          op_class_name=None) -> DefinedOpCallable:
   if dsl_func is None:
     # Curry the keyword args in for delayed application.
-    return functools.partial(tc_def_op,
-                             op_name=op_name,
-                             op_class_name=op_class_name)
+    return functools.partial(
+        tc_def_op, op_name=op_name, op_class_name=op_class_name)
   # Determine default names by introspecting the function.
   if op_name is None:
     op_name = dsl_func.__name__
@@ -96,9 +96,8 @@ def linalg_structured_op(dsl_func=None,
     # Camel case it.
     op_class_name = f"{''.join(x.title() for x in op_name.split('_'))}Op"
 
-  tc_model = LinalgOpDef(name=op_name,
-                         cpp_class_name=op_class_name,
-                         doc=inspect.getdoc(dsl_func))
+  tc_model = LinalgOpDef(
+      name=op_name, cpp_class_name=op_class_name, doc=inspect.getdoc(dsl_func))
 
   # Extract arguments and TensorDefs from the signature.
   dsl_func_args = list()
@@ -106,12 +105,12 @@ def linalg_structured_op(dsl_func=None,
   for param_name, param in sig.parameters.items():
     param_default = param.default
     if isinstance(param_default, TensorDef):
-      tc_model.add_tensor(param_name, param_default)
-    elif isinstance(param_default, CaptureDef):
-      tc_model.add_capture(param_name, param_default)
+      tc_model.add_operand(param_name, param_default.operand_def)
+    elif isinstance(param_default, ScalarDef):
+      tc_model.add_operand(param_name, param_default.operand_def)
     else:
       raise ValueError(f"@tc_def_op function parameters must be defaulted as "
-                       f"TensorDef(...) or CaptureDef(...): Found {param_name}"
+                       f"TensorDef(...) or ScalarDef(...): Found {param_name}"
                        f": {param_default}")
     dsl_func_args.append(param_default)
 

diff --git a/mlir/python/mlir/dialects/linalg/opdsl/lang/emitter.py b/mlir/python/mlir/dialects/linalg/opdsl/lang/emitter.py
index 5538a9e42e102..2b8b910507cec 100644
--- a/mlir/python/mlir/dialects/linalg/opdsl/lang/emitter.py
+++ b/mlir/python/mlir/dialects/linalg/opdsl/lang/emitter.py
@@ -29,20 +29,15 @@ def isa(cls: Type, ty: Type):
 
 
 def prepare_common_structured_op(op_config: LinalgStructuredOpConfig,
-                                 *ins: Value, outs: Sequence[Value],
-                                 captures: Sequence[Value]):
-  all_arg_defs = op_config.ordered_tensor_args
+                                 *ins: Value, outs: Sequence[Value]):
+  all_arg_defs = op_config.ordered_operands
   in_arg_defs = [arg for arg in all_arg_defs if arg.usage == "input"]
   out_arg_defs = [arg for arg in all_arg_defs if arg.usage == "output"]
-  capture_arg_defs = op_config.ordered_capture_args
 
   # Verify outs and captures are sequences.
   if not isinstance(outs, Sequence):
     raise ValueError(f"Expected named argument outs to have type Sequence "
                      f"but got {type(outs)}")
-  if not isinstance(captures, Sequence):
-    raise ValueError(f"Expected named argument captures to have type Sequence "
-                     f"but got {type(outs)}")
 
   # Arity validation.
   if len(ins) != len(in_arg_defs):
@@ -51,9 +46,6 @@ def prepare_common_structured_op(op_config: LinalgStructuredOpConfig,
   if outs and len(outs) != len(out_arg_defs):
     raise ValueError(f"Expected {len(out_arg_defs)} outputs but got "
                      f"{len(outs)} for {op_config}")
-  if captures and len(captures) != len(capture_arg_defs):
-    raise ValueError(f"Expected {len(capture_arg_defs)} captures but got "
-                     f"{len(captures)} for {op_config}")
 
   outs, out_types = _infer_structured_outs(op_config, in_arg_defs, ins,
                                            out_arg_defs, outs)
@@ -68,18 +60,10 @@ def prepare_common_structured_op(op_config: LinalgStructuredOpConfig,
   type_mapping["I64"] = IntegerType.get_signless(64)
 
   # Extract type vars for input/output based types.
-  for arg_def, arg_element_type in zip(
-      in_arg_defs + out_arg_defs,
-      _get_shaped_element_types_from_values(*ins, *outs)):
-    _add_type_mapping(arg_def.tensor_def.type_var.name, arg_element_type,
-                      type_mapping)
-
-  # Extract type vars for captures and compute capture argument mapping.
-  capture_arg_mapping = dict()  # type: Dict[str, Value]
-  for arg_def, capture_value in zip(capture_arg_defs, captures):
-    _add_type_mapping(arg_def.capture_def.type_var.name, capture_value.type,
-                      type_mapping)
-    capture_arg_mapping[arg_def.capture_def.capture_name] = capture_value
+  block_arg_types = list()  # type: List[Type]
+  for arg_def, arg_element_type in zip(in_arg_defs + out_arg_defs,
+                                       _get_types_from_values(*ins, *outs)):
+    _add_type_mapping(arg_def, arg_element_type, type_mapping, block_arg_types)
 
   # Emit the generic op.
   # TODO: Support emission of pure memref form.
@@ -94,18 +78,16 @@ def prepare_common_structured_op(op_config: LinalgStructuredOpConfig,
       [StringAttr.get(s) for s in op_config.iterator_types])
 
   return (all_arg_defs, in_arg_defs, out_arg_defs, outs, result_types,
-          type_mapping, capture_arg_mapping, indexing_maps_attr,
-          iterator_types_attr)
+          type_mapping, indexing_maps_attr, iterator_types_attr,
+          block_arg_types)
 
 
 def emit_generic_structured_op(op_config: LinalgStructuredOpConfig,
                                *ins: Value,
-                               outs: Sequence[Value] = (),
-                               captures: Sequence[Value] = ()):
+                               outs: Sequence[Value] = ()):
   all_arg_defs, in_arg_defs, out_arg_defs, outs, result_types, type_mapping, \
-  capture_arg_mapping, indexing_maps_attr, iterator_types_attr = \
-     prepare_common_structured_op(op_config, *ins, outs = outs,
-                                  captures=captures)
+  indexing_maps_attr, iterator_types_attr, block_arg_types = \
+     prepare_common_structured_op(op_config, *ins, outs = outs)
 
   generic_op = linalg.GenericOp(
       result_tensors=result_types,
@@ -117,16 +99,14 @@ def emit_generic_structured_op(op_config: LinalgStructuredOpConfig,
       library_call=None)  # TODO: Make optional.
 
   # Construct the body.
-  block_arg_names = _get_tensor_def_names(*in_arg_defs, *out_arg_defs)
-  block_arg_types = _get_shaped_element_types_from_values(*ins, *outs)
+  block_arg_names = _get_operand_def_names(*in_arg_defs, *out_arg_defs)
   block = generic_op.regions[0].blocks.append(*block_arg_types)
   block_arg_mapping = dict(zip(block_arg_names, block.arguments))
   with InsertionPoint(block):
-    body_builder = _BodyBuilder(type_mapping, block_arg_mapping,
-                                capture_arg_mapping)
+    body_builder = _BodyBuilder(type_mapping, block_arg_mapping)
     for assignment in op_config.assignments:
       body_builder.assign(assignment)
-    body_builder.yield_outputs(*_get_tensor_def_names(*out_arg_defs))
+    body_builder.yield_outputs(*_get_operand_def_names(*out_arg_defs))
 
   if len(result_types) == 1:
     return generic_op.result
@@ -138,12 +118,10 @@ def emit_named_structured_op(op_config: LinalgStructuredOpConfig,
                              op_name: str,
                              op_class_name: str,
                              *ins: Value,
-                             outs: Sequence[Value] = (),
-                             captures: Sequence[Value] = ()):
+                             outs: Sequence[Value] = ()):
   all_arg_defs, in_arg_defs, out_arg_defs, outs, result_types, type_mapping, \
-  capture_arg_mapping, indexing_maps_attr, iterator_types_attr = \
-     prepare_common_structured_op(op_config, *ins, outs = outs,
-                                  captures = captures)
+  indexing_maps_attr, iterator_types_attr, block_arg_types = \
+     prepare_common_structured_op(op_config, *ins, outs = outs)
 
   # If we get here, there must exist a builtin class `op_class_name`.
   ctx = Context.current
@@ -173,11 +151,9 @@ class _BodyBuilder:
   """Constructs a structured op body by evaluating assignments."""
 
   def __init__(self, type_mapping: Dict[str, Type],
-               block_arg_mapping: Dict[str, Value],
-               capture_arg_mapping: Dict[str, Value]):
+               block_arg_mapping: Dict[str, Value]):
     self.type_mapping = type_mapping
     self.block_arg_mapping = block_arg_mapping
-    self.capture_arg_mapping = capture_arg_mapping
     self.yield_mapping = dict()  # type: Dict[str, Value]
 
   def assign(self, assignment: ScalarAssign):
@@ -194,13 +170,6 @@ def expression(self, expr: ScalarExpression) -> Value:
       except KeyError:
         raise ValueError(f"Argument {expr.scalar_arg.arg} is not bound for "
                          f"this structured op.")
-    elif expr.scalar_capture:
-      try:
-        return self.capture_arg_mapping[expr.scalar_capture.capture]
-      except KeyError:
-        raise ValueError(
-            f"Capture {expr.scalar_capture.capture} is not bound for "
-            f"this structured op.")
     elif expr.scalar_const:
       value_attr = Attribute.parse(expr.scalar_const.value)
       return std.ConstantOp(value_attr.type, value_attr).result
@@ -229,7 +198,7 @@ def cast(self, type_var_name: str, operand: Value) -> Value:
       to_type = self.type_mapping[type_var_name]
     except KeyError:
       raise ValueError(f"Unbound type variable '{type_var_name}' ("
-                       f"expected one of {self.type_mappings.keys()}")
+                       f"expected one of {self.type_mapping.keys()}")
     if operand.type == to_type:
       return operand
     if _is_integer_type(to_type):
@@ -300,9 +269,9 @@ def _eval_mul(self, lhs: Value, rhs: Value) -> Value:
 
 
 def _infer_structured_outs(op_config: LinalgStructuredOpConfig,
-                           in_arg_defs: Sequence[TensorDefConfig],
+                           in_arg_defs: Sequence[OperandDefConfig],
                            ins: Sequence[Value],
-                           out_arg_defs: Sequence[TensorDefConfig],
+                           out_arg_defs: Sequence[OperandDefConfig],
                            outs: Sequence[Value]):
   """Infers implicit outs and output types.
 
@@ -319,28 +288,34 @@ def _infer_structured_outs(op_config: LinalgStructuredOpConfig,
                             "structured ops")
 
 
-def _get_shaped_element_types_from_values(*values: Value) -> Sequence[Type]:
+def _get_types_from_values(*values: Value) -> Sequence[Type]:
   types = []
   for v in values:
-    try:
-      t = ShapedType(v.type)
-    except Exception as e:
-      raise ValueError(f"Expected ShapedType but got {v}") from e
-    types.append(t.element_type)
+    types.append(v.type)
   return types
 
 
-def _get_tensor_def_names(
-    *tensor_def_configs: TensorDefConfig) -> Sequence[str]:
-  return [tdc.tensor_def.tensor_name for tdc in tensor_def_configs]
+def _get_operand_def_names(*operand_configs: OperandDefConfig) -> Sequence[str]:
+  return [odc.operand_def.name for odc in operand_configs]
 
 
-def _add_type_mapping(name: str, type: Type, type_mapping: Dict[str, Type]):
+def _add_type_mapping(operand_config: OperandDefConfig, operand_type: Type,
+                      type_mapping: Dict[str, Type],
+                      block_arg_types: Sequence[Type]):
+  element_or_self_type = operand_type
+  # Get the element type for tensor operands and the type itself for scalars.
+  if operand_config.operand_def.shape:
+    try:
+      element_or_self_type = ShapedType(operand_type).element_type
+    except Exception as e:
+      raise ValueError(f"Expected ShapedType but got {operand_type}") from e
+  name = operand_config.type_var.name
   if name in type_mapping:
-    if type_mapping[name] != type:
+    if type_mapping[name] != element_or_self_type:
       raise ValueError(f"Cannot overwrite type mapping {name} = "
-                       f"{type_mapping[name]} by type {type}")
-  type_mapping[name] = type
+                       f"{type_mapping[name]} by type {element_or_self_type}")
+  type_mapping[name] = element_or_self_type
+  block_arg_types.append(element_or_self_type)
 
 
 def _is_floating_point_type(t: Type) -> bool:

diff --git a/mlir/python/mlir/dialects/linalg/opdsl/lang/scalar_expr.py b/mlir/python/mlir/dialects/linalg/opdsl/lang/scalar_expr.py
index 2cc426b6211a0..48627bfab544c 100644
--- a/mlir/python/mlir/dialects/linalg/opdsl/lang/scalar_expr.py
+++ b/mlir/python/mlir/dialects/linalg/opdsl/lang/scalar_expr.py
@@ -22,7 +22,6 @@
     "ScalarAssign",
     "ScalarApplyFn",
     "ScalarArg",
-    "ScalarCapture",
     "ScalarConst",
     "ScalarIndex",
     "ScalarExpression",
@@ -57,19 +56,6 @@ def __repr__(self):
     return f"(ScalarArg({self.arg})"
 
 
-class ScalarCapture:
-  """A type of ScalarExpression that references a named capture."""
-
-  def __init__(self, capture: str):
-    self.capture = capture
-
-  def expr(self) -> "ScalarExpression":
-    return ScalarExpression(scalar_capture=self)
-
-  def __repr__(self):
-    return f"(ScalarCapture({self.capture})"
-
-
 class ScalarConst:
   """A type of ScalarExpression representing a constant."""
 
@@ -116,7 +102,6 @@ class ScalarExpression(YAMLObject):
   Can be one of:
     - ScalarApplyFn
     - ScalarArg
-    - ScalarCapture
     - ScalarConst
     - ScalarIndex
     - ScalarSymbolicCast
@@ -126,18 +111,15 @@ class ScalarExpression(YAMLObject):
   def __init__(self,
                scalar_apply: Optional[ScalarApplyFn] = None,
                scalar_arg: Optional[ScalarArg] = None,
-               scalar_capture: Optional[ScalarCapture] = None,
                scalar_const: Optional[ScalarConst] = None,
                scalar_index: Optional[ScalarIndex] = None,
                symbolic_cast: Optional[ScalarSymbolicCast] = None):
-    if (bool(scalar_apply) + bool(scalar_arg) + bool(scalar_capture) +
-        bool(scalar_const) + bool(scalar_index) + bool(symbolic_cast)) != 1:
-      raise ValueError(
-          "One of 'scalar_apply', 'scalar_arg', 'scalar_capture', 'scalar_const', "
-          "'scalar_index', 'symbolic_cast' must be specified")
+    if (bool(scalar_apply) + bool(scalar_arg) + bool(scalar_const) +
+        bool(scalar_index) + bool(symbolic_cast)) != 1:
+      raise ValueError("One of 'scalar_apply', 'scalar_arg', 'scalar_const', "
+                       "'scalar_index', 'symbolic_cast' must be specified")
     self.scalar_apply = scalar_apply
     self.scalar_arg = scalar_arg
-    self.scalar_capture = scalar_capture
     self.scalar_const = scalar_const
     self.scalar_index = scalar_index
     self.symbolic_cast = symbolic_cast
@@ -151,8 +133,6 @@ def to_yaml_custom_dict(self):
           ))
     elif self.scalar_arg:
       return dict(scalar_arg=self.scalar_arg.arg)
-    elif self.scalar_capture:
-      return dict(scalar_capture=self.scalar_capture.capture)
     elif self.scalar_const:
       return dict(scalar_const=self.scalar_const.value)
     elif self.scalar_index:

diff --git a/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py b/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py
index ad79963450cee..c6586824a840e 100644
--- a/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py
+++ b/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py
@@ -75,7 +75,11 @@ def dot(
 
 
 @linalg_structured_op
-def fill_rng_2d(O=TensorDef(T, S.M, S.N, output=True)):
+def fill_rng_2d(
+    min=ScalarDef(F64),
+    max=ScalarDef(F64),
+    seed=ScalarDef(I32),
+    O=TensorDef(T, S.M, S.N, output=True)):
   """Fills the output tensor with pseudo random numbers.
 
   The operation generations pseudo random numbers using a linear congruential
@@ -85,13 +89,7 @@ def fill_rng_2d(O=TensorDef(T, S.M, S.N, output=True)):
   and runs them in parallel. The seed operand and the indices of the data
   element seed the random number generation. The min and max operands limit
   the range of the generated random numbers.
-
-  Note: The captures are hard-coded till there is capture support on the C++
-  side.
   """
-  min = cast(F64, const(-1000))
-  max = cast(F64, const(+1000))
-  seed = cast(I32, const(42))
   multiplier = cast(I32, const(1103515245))
   increment = cast(I32, const(12345))
   rand1 = (cast(I32, index(D.m)) + seed) * multiplier + increment

diff --git a/mlir/test/python/dialects/linalg/opdsl/arguments.py b/mlir/test/python/dialects/linalg/opdsl/arguments.py
index ce11188ba32dc..f9a0b019034b3 100644
--- a/mlir/test/python/dialects/linalg/opdsl/arguments.py
+++ b/mlir/test/python/dialects/linalg/opdsl/arguments.py
@@ -9,15 +9,15 @@
 # CHECK:     name: A
 # CHECK:     usage: input
 # CHECK:     shape: affine_map<()[s0, s1, s2] -> (s0, s2)>
-# CHECK:     element_type_var: T
+# CHECK:     type_var: T
 # CHECK:     name: B
 # CHECK:     usage: input
 # CHECK:     shape: affine_map<()[s0, s1, s2] -> (s2, s1)>
-# CHECK:     element_type_var: T
+# CHECK:     type_var: T
 # CHECK:     name: C
 # CHECK:     usage: output
 # CHECK:     shape: affine_map<()[s0, s1, s2] -> (s0, s1)>
-# CHECK:     element_type_var: U
+# CHECK:     type_var: U
 @linalg_structured_op
 def matmul(
     A=TensorDef(T, S.M, S.K),
@@ -28,10 +28,11 @@ def matmul(
 
 # CHECK: ---
 # CHECK-LABEL: fill
-# CHECK: captures:
-# CHECK: - !<LinalgCaptureDef>
-# CHECK:   name: value
-# CHECK:   type_var: T
+# CHECK: args:
+# CHECK:     name: value
+# CHECK:     usage: input
+# CHECK-NOT: shape:
+# CHECK:     type_var: T
 @linalg_structured_op
-def fill(O=TensorDef(T, S.M, S.K, output=True), value=CaptureDef(T)):
+def fill(value=ScalarDef(T), O=TensorDef(T, S.M, S.K, output=True)):
   O[D.m, D.n] = value

diff --git a/mlir/test/python/dialects/linalg/opdsl/assignments.py b/mlir/test/python/dialects/linalg/opdsl/assignments.py
index 32c56d1649ad2..508e240e5916a 100644
--- a/mlir/test/python/dialects/linalg/opdsl/assignments.py
+++ b/mlir/test/python/dialects/linalg/opdsl/assignments.py
@@ -82,7 +82,7 @@ def indices(O=TensorDef(T, S.M, S.K, output=True)):
 # CHECK: assignments:
 # CHECK:  -
 # CHECK:    arg: O
-# CHECK:      scalar_capture: value
+# CHECK:      scalar_arg: value
 @linalg_structured_op
-def fill(O=TensorDef(T, S.M, S.K, output=True), value=CaptureDef(T)):
+def fill(value=ScalarDef(T), O=TensorDef(T, S.M, S.K, output=True)):
   O[D.m, D.n] = value

diff --git a/mlir/test/python/dialects/linalg/opdsl/emit_structured_generic.py b/mlir/test/python/dialects/linalg/opdsl/emit_structured_generic.py
index f84db9b407a70..6b12dc1167730 100644
--- a/mlir/test/python/dialects/linalg/opdsl/emit_structured_generic.py
+++ b/mlir/test/python/dialects/linalg/opdsl/emit_structured_generic.py
@@ -26,10 +26,10 @@ def matmul_poly(
 
 @linalg_structured_op
 def fill_rng(
-    O=TensorDef(T, S.M, S.N, output=True),
-    min=CaptureDef(F64),
-    max=CaptureDef(F64),
-    seed=CaptureDef(I32)):
+    min=ScalarDef(F64),
+    max=ScalarDef(F64),
+    seed=ScalarDef(I32),
+    O=TensorDef(T, S.M, S.N, output=True)):
   multiplier = cast(I32, const(1103515245))
   increment = cast(I32, const(12345))
   rand1 = (cast(I32, index(D.m)) + seed) * multiplier + increment
@@ -159,7 +159,7 @@ def test_f64f64f32_matmul(lhs, rhs, init_result):
       return matmul_poly(lhs, rhs, outs=[init_result])
 
     # CHECK-LABEL: @test_fill_rng
-    # CHECK-SAME:  %{{.*}} tensor<4x16xi32>, %[[MIN:.+]]: f64, %[[MAX:.+]]: f64, %[[SEED:.+]]: i32
+    # CHECK:      ^{{.*}}(%[[MIN:.+]]: f64, %[[MAX:.+]]: f64, %[[SEED:.+]]: i32, %{{.*}}
     # CHECK-DAG:    %[[IDX0:.+]] = linalg.index 0 : index
     # CHECK-DAG:    %[[IDX1:.+]] = linalg.index 1 : index
     # CHECK-DAG:    %[[IDX0_CAST:.+]] = index_cast %[[IDX0]] : index to i32
@@ -178,10 +178,10 @@ def test_f64f64f32_matmul(lhs, rhs, init_result):
     # CHECK-DAG:    %[[RND4:.+]] = mulf %{{.+}}, %[[FACT]] : f64
     # CHECK-DAG:    %[[RND5:.+]] = addf %[[RND4]], %[[MIN]] : f64
     # CHECK-DAG:    %{{.*}} = fptosi %[[RND5]] : f64 to i32
-    @builtin.FuncOp.from_py_func(
-        RankedTensorType.get((4, 16), i32), f64, f64, i32)
-    def test_fill_rng(init_result, min, max, seed):
-      return fill_rng(outs=[init_result], captures=[min, max, seed])
+    @builtin.FuncOp.from_py_func(f64, f64, i32,
+                                 RankedTensorType.get((4, 16), i32))
+    def test_fill_rng(min, max, seed, init_result):
+      return fill_rng(min, max, seed, outs=[init_result])
 
 
 print(module)

diff --git a/mlir/test/python/dialects/linalg/opsrun.py b/mlir/test/python/dialects/linalg/opsrun.py
index 8d48f0a340620..2b58f38f36319 100644
--- a/mlir/test/python/dialects/linalg/opsrun.py
+++ b/mlir/test/python/dialects/linalg/opsrun.py
@@ -43,9 +43,12 @@ def log(*args):
 fill_boiler = """
 func @main() -> i32 attributes {llvm.emit_c_interface} {
   %O = memref.alloc() : memref<4x16xi32>
+  %min = constant -1000.0 : f64
+  %max = constant 1000.0 : f64
+  %seed = constant 42 : i32
 
-  call @fill_on_buffers(%O) :
-    (memref<4x16xi32>) -> ()
+  call @fill_on_buffers(%min, %max, %seed, %O) :
+    (f64, f64, i32, memref<4x16xi32>) -> ()
 
   %c0 = constant 0 : index
   %0 = memref.load %O[%c0, %c0] : memref<4x16xi32>
@@ -128,33 +131,6 @@ def matmul_on_buffers(lhs, rhs, out):
 test_matmul_generic()
 
 
-def test_fill_builtin():
-  with Context() as ctx, Location.unknown():
-    module = Module.create()
-    f64 = F64Type.get()
-    i32 = IntegerType.get_signless(32)
-    with InsertionPoint(module.body):
-
-      @builtin.FuncOp.from_py_func(MemRefType.get((4, 16), i32))
-      def fill_on_buffers(out):
-        linalg.fill_rng_2d(outs=[out])
-
-    execution_engine = ExecutionEngine(transform(module, fill_boiler))
-
-    # TODO: FFI-based solution to allow testing and printing with python code.
-    # Prepare arguments: one result i32.
-    # Arguments must be passed as pointers.
-    c_int_p = ctypes.c_int * 1
-    res = c_int_p(-1)
-    execution_engine.invoke("main", res)
-
-    log("RESULT: ", res[0])
-    # CHECK: RESULT: -480
-
-
-test_fill_builtin()
-
-
 def test_fill_generic():
   with Context() as ctx, Location.unknown():
     module = Module.create()
@@ -162,9 +138,9 @@ def test_fill_generic():
     i32 = IntegerType.get_signless(32)
     with InsertionPoint(module.body):
 
-      @builtin.FuncOp.from_py_func(MemRefType.get((4, 16), i32))
-      def fill_on_buffers(out):
-        linalg.fill_rng_2d(outs=[out])
+      @builtin.FuncOp.from_py_func(f64, f64, i32, MemRefType.get((4, 16), i32))
+      def fill_on_buffers(min, max, seed, out):
+        linalg.fill_rng_2d(min, max, seed, outs=[out], emit_generic=True)
 
     execution_engine = ExecutionEngine(transform(module, fill_boiler))
 


        

