[Mlir-commits] [mlir] f345f7e - [mlir][OpDSL] Support pointwise ops with rank zero inputs.

llvmlistbot at llvm.org llvmlistbot at llvm.org
Tue Mar 8 09:40:46 PST 2022


Author: gysit
Date: 2022-03-08T17:39:47Z
New Revision: f345f7e30bd3a8e15052f5669c1977aa088e468f

URL: https://github.com/llvm/llvm-project/commit/f345f7e30bd3a8e15052f5669c1977aa088e468f
DIFF: https://github.com/llvm/llvm-project/commit/f345f7e30bd3a8e15052f5669c1977aa088e468f.diff

LOG: [mlir][OpDSL] Support pointwise ops with rank zero inputs.

Allow pointwise operations to take rank-zero input tensors, just like scalar inputs. Use an empty indexing map to broadcast rank-zero tensors across the iteration domain of the operation.

Depends On D120734

Reviewed By: nicolasvasilache

Differential Revision: https://reviews.llvm.org/D120807

Added: 
    

Modified: 
    mlir/python/mlir/dialects/linalg/opdsl/lang/emitter.py
    mlir/test/Dialect/Linalg/generalize-named-polymorphic-ops.mlir
    mlir/test/python/dialects/linalg/opdsl/emit_fill.py
    mlir/test/python/integration/dialects/linalg/opsrun.py
    mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-yaml-gen.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/python/mlir/dialects/linalg/opdsl/lang/emitter.py b/mlir/python/mlir/dialects/linalg/opdsl/lang/emitter.py
index ff5c405d788a9..93baef14bc197 100644
--- a/mlir/python/mlir/dialects/linalg/opdsl/lang/emitter.py
+++ b/mlir/python/mlir/dialects/linalg/opdsl/lang/emitter.py
@@ -187,7 +187,11 @@ def emit_generic_structured_op(op_config: LinalgStructuredOpConfig, *ins: Value,
       if arg_def.operand_def.kind == OperandKind.SCALAR:
         indexing_maps.append(scalar_map)
       if arg_def.operand_def.is_tensor():
-        indexing_maps.append(tensor_map)
+        idx = arg_def.operand_def.registered_index
+        if idx < len(ins) and ShapedType(ins[idx].type).rank == 0:
+          indexing_maps.append(scalar_map)
+        else:
+          indexing_maps.append(tensor_map)
     indexing_maps_attr = ArrayAttr.get(
         [AffineMapAttr.get(am) for am in indexing_maps])
 

diff  --git a/mlir/test/Dialect/Linalg/generalize-named-polymorphic-ops.mlir b/mlir/test/Dialect/Linalg/generalize-named-polymorphic-ops.mlir
index 425eeb8373276..0c98629041c4e 100644
--- a/mlir/test/Dialect/Linalg/generalize-named-polymorphic-ops.mlir
+++ b/mlir/test/Dialect/Linalg/generalize-named-polymorphic-ops.mlir
@@ -320,3 +320,18 @@ func @generalize_elemwise_mul(%lhs : tensor<4x8xf32>, %rhs : tensor<4x8xf32>, %o
 
 // CHECK-LABEL: @generalize_elemwise_mul
 // CHECK:        = arith.mulf
+
+// -----
+
+// Verifies pointwise ops support rank zero input tensors
+func @generalize_elemwise_rank_zero(%lhs : tensor<f32>, %rhs : tensor<f32>, %output : tensor<4x8xf32>) -> tensor<4x8xf32> {
+  %0 = linalg.elemwise_binary {fun = #linalg.binary_fn<sub>}
+                              ins(%lhs, %rhs: tensor<f32>, tensor<f32>)
+                              outs(%output: tensor<4x8xf32>) -> tensor<4x8xf32>
+  return %0: tensor<4x8xf32>
+}
+
+// CHECK-LABEL: @generalize_elemwise_rank_zero
+// CHECK:       linalg.generic
+// CHECK-SAME:  iterator_types = ["parallel", "parallel"]
+// CHECK:        = arith.subf

diff  --git a/mlir/test/python/dialects/linalg/opdsl/emit_fill.py b/mlir/test/python/dialects/linalg/opdsl/emit_fill.py
index 814a6d2a6ccef..55ca50be5fad1 100644
--- a/mlir/test/python/dialects/linalg/opdsl/emit_fill.py
+++ b/mlir/test/python/dialects/linalg/opdsl/emit_fill.py
@@ -15,6 +15,9 @@
 def fill_poly(value=ScalarDef(T1), O=TensorDef(U, output=True)):
   O[None] = TypeFn.cast_signed(U, value)
 
+@linalg_structured_op
+def fill_rank_zero_poly(I=TensorDef(T1), O=TensorDef(U, output=True)):
+  O[None] = TypeFn.cast_signed(U, I[None])
 
 with Context() as ctx, Location.unknown():
   module = Module.create()
@@ -25,6 +28,8 @@ def fill_poly(value=ScalarDef(T1), O=TensorDef(U, output=True)):
     # CHECK-DAG: #[[$MAP0:.+]] = affine_map<() -> ()>
     # CHECK-DAG: #[[$MAP1:.+]] = affine_map<(d0, d1) -> ()>
     # CHECK-DAG: #[[$MAP2:.+]] = affine_map<(d0, d1) -> (d0, d1)>
+    # CHECK-DAG: #[[$MAP3:.+]] = affine_map<(d0, d1, d2) -> ()>
+    # CHECK-DAG: #[[$MAP4:.+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
 
     # CHECK-LABEL: @test_fill_0d
     # CHECK: linalg.generic
@@ -42,5 +47,13 @@ def test_fill_0d(value, init_result):
     def test_fill_2d(value, init_result):
       return fill_poly(value, outs=[init_result])
 
+    # CHECK-LABEL: @test_fill_rank_zero_3d
+    # CHECK: linalg.generic
+    # CHECK-SAME: indexing_maps = [#[[$MAP3]], #[[$MAP4]]]
+    # CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel"]
+    @builtin.FuncOp.from_py_func(
+        RankedTensorType.get([], f32), RankedTensorType.get([4, 8, 16], f32))
+    def test_fill_rank_zero_3d(input, init_result):
+      return fill_rank_zero_poly(input, outs=[init_result])
 
 print(module)

diff  --git a/mlir/test/python/integration/dialects/linalg/opsrun.py b/mlir/test/python/integration/dialects/linalg/opsrun.py
index 458780dc0eefc..279af2bf77b34 100644
--- a/mlir/test/python/integration/dialects/linalg/opsrun.py
+++ b/mlir/test/python/integration/dialects/linalg/opsrun.py
@@ -25,19 +25,19 @@ def log(*args):
   %v1 = arith.constant 1.0 : f32
   %v2 = arith.constant 2.0 : f32
 
-  %lhs = memref.alloc() : memref<4x8xf32>
+  %lhs = memref.alloc() : memref<f32>
   %rhs = memref.alloc() : memref<4x8xf32>
   %O0 = memref.alloc() : memref<4x8xf32>
   %O1 = memref.alloc() : memref<4x8xf32>
-  linalg.fill(%v1, %lhs) : f32, memref<4x8xf32>
+  linalg.fill(%v1, %lhs) : f32, memref<f32>
   linalg.fill(%v2, %rhs) : f32, memref<4x8xf32>
   linalg.fill(%v0, %O0) : f32, memref<4x8xf32>
   linalg.fill(%v0, %O1) : f32, memref<4x8xf32>
 
   call @elemwise_exp_add_on_buffers(%lhs, %rhs, %O0) :
-    (memref<4x8xf32>, memref<4x8xf32>, memref<4x8xf32>) -> ()
+    (memref<f32>, memref<4x8xf32>, memref<4x8xf32>) -> ()
   call @elemwise_log_mul_on_buffers(%lhs, %rhs, %O1) :
-    (memref<4x8xf32>, memref<4x8xf32>, memref<4x8xf32>) -> ()
+    (memref<f32>, memref<4x8xf32>, memref<4x8xf32>) -> ()
 
   %c0 = arith.constant 0 : index
   %res0 = memref.load %O0[%c0, %c0] : memref<4x8xf32>
@@ -212,14 +212,14 @@ def test_elemwise_builtin():
     with InsertionPoint(module.body):
 
       @builtin.FuncOp.from_py_func(
-          MemRefType.get((4, 8), f32), MemRefType.get((4, 8), f32),
+          MemRefType.get((), f32), MemRefType.get((4, 8), f32),
           MemRefType.get((4, 8), f32))
       def elemwise_exp_add_on_buffers(lhs, rhs, out):
         linalg.elemwise_unary(lhs, outs=[out])
         linalg.elemwise_binary(out, rhs, outs=[out])
 
       @builtin.FuncOp.from_py_func(
-          MemRefType.get((4, 8), f32), MemRefType.get((4, 8), f32),
+          MemRefType.get((), f32), MemRefType.get((4, 8), f32),
           MemRefType.get((4, 8), f32))
       def elemwise_log_mul_on_buffers(lhs, rhs, out):
         linalg.elemwise_unary(lhs, outs=[out], fun=UnaryFn.log)
@@ -251,14 +251,14 @@ def test_elemwise_generic():
     with InsertionPoint(module.body):
 
       @builtin.FuncOp.from_py_func(
-          MemRefType.get((4, 8), f32), MemRefType.get((4, 8), f32),
+          MemRefType.get((), f32), MemRefType.get((4, 8), f32),
           MemRefType.get((4, 8), f32))
       def elemwise_exp_add_on_buffers(lhs, rhs, out):
         linalg.elemwise_unary(lhs, outs=[out], emit_generic=True)
         linalg.elemwise_binary(out, rhs, outs=[out], emit_generic=True)
 
       @builtin.FuncOp.from_py_func(
-          MemRefType.get((4, 8), f32), MemRefType.get((4, 8), f32),
+          MemRefType.get((), f32), MemRefType.get((4, 8), f32),
           MemRefType.get((4, 8), f32))
       def elemwise_log_mul_on_buffers(lhs, rhs, out):
         linalg.elemwise_unary(

diff  --git a/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-yaml-gen.cpp b/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-yaml-gen.cpp
index 5cade2a24f430..a6963604e1408 100644
--- a/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-yaml-gen.cpp
+++ b/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-yaml-gen.cpp
@@ -678,7 +678,7 @@ ArrayAttr {0}::indexing_maps() {{
     getNumParallelLoops(), context);
   SmallVector<AffineMap> indexingMaps;
   for (OpOperand *opOperand : getInputAndOutputOperands())
-    indexingMaps.push_back(isScalar(opOperand) ? scalarMap : tensorMap);
+    indexingMaps.push_back(getRank(opOperand) == 0 ? scalarMap : tensorMap);
   return Builder(getContext()).getAffineMapArrayAttr(indexingMaps);
 }
 )FMT";


        


More information about the Mlir-commits mailing list