[Mlir-commits] [mlir] 619bfe8 - [mlir][sparse] support new kind of scalar in sparse linalg generic op

Aart Bik llvmlistbot at llvm.org
Wed Jun 16 11:01:04 PDT 2021


Author: Aart Bik
Date: 2021-06-16T11:00:49-07:00
New Revision: 619bfe8bd23f76b22f0a53fedafbfc8c97a15f12

URL: https://github.com/llvm/llvm-project/commit/619bfe8bd23f76b22f0a53fedafbfc8c97a15f12
DIFF: https://github.com/llvm/llvm-project/commit/619bfe8bd23f76b22f0a53fedafbfc8c97a15f12.diff

LOG: [mlir][sparse] support new kind of scalar in sparse linalg generic op

We have several ways of introducing a scalar invariant value into
linalg generic ops (should we limit this somewhat?). This revision
makes sure we handle all of them correctly in the sparse compiler.

Reviewed By: gysit

Differential Revision: https://reviews.llvm.org/D104335

Added: 
    mlir/test/Dialect/SparseTensor/sparse_scalars.mlir

Modified: 
    mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
index 679afcfed509b..684a97580fda8 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
@@ -458,11 +458,17 @@ static Optional<unsigned> buildTensorExp(Merger &merger, linalg::GenericOp op,
                                          Value val) {
   if (auto arg = val.dyn_cast<BlockArgument>()) {
     unsigned argN = arg.getArgNumber();
-    // Any parameter of the generic op is considered a tensor,
-    // indexed by the implicit loop bounds.
-    if (arg.getOwner()->getParentOp() == op)
-      return merger.addExp(Kind::kTensor, argN);
-    // Any parameter of a higher op is invariant.
+    // Any argument of the generic op that is not marked as a scalar
+    // argument is considered a tensor, indexed by the implicit loop
+    // bounds. This includes rank-0 tensor arguments.
+    if (arg.getOwner()->getParentOp() == op) {
+      OpOperand *t = op.getInputAndOutputOperands()[argN];
+      if (!op.isScalar(t))
+        return merger.addExp(Kind::kTensor, argN);
+      val = t->get(); // get scalar value
+    }
+    // Any other argument (marked as scalar argument for the generic op
+    // or belonging to an enveloping op) is considered invariant.
     return merger.addExp(Kind::kInvariant, val);
   }
   Operation *def = val.getDefiningOp();
@@ -719,9 +725,7 @@ static Value genTensorLoad(Merger &merger, CodeGen &codegen,
   }
   // Actual load.
   SmallVector<Value, 4> args;
-  OpOperand *t = merger.exp(exp).e0 < op.getNumInputs()
-                     ? op.getInputOperand(merger.exp(exp).e0)
-                     : op.getOutputOperand(0);
+  OpOperand *t = op.getInputAndOutputOperands()[merger.exp(exp).e0];
   unsigned tensor = t->getOperandNumber();
   auto map = op.getTiedIndexingMap(t);
   auto enc = getSparseTensorEncoding(t->get().getType());
@@ -919,11 +923,9 @@ static void genInvariants(Merger &merger, CodeGen &codegen,
   if (merger.exp(exp).kind == Kind::kTensor) {
     // Inspect tensor indices.
     bool atLevel = ldx == -1u;
-    OpOperand *tensor = merger.exp(exp).e0 < op.getNumInputs()
-                            ? op.getInputOperand(merger.exp(exp).e0)
-                            : op.getOutputOperand(0);
-    auto map = op.getTiedIndexingMap(tensor);
-    auto enc = getSparseTensorEncoding(tensor->get().getType());
+    OpOperand *t = op.getInputAndOutputOperands()[merger.exp(exp).e0];
+    auto map = op.getTiedIndexingMap(t);
+    auto enc = getSparseTensorEncoding(t->get().getType());
     for (unsigned d = 0, rank = map.getNumResults(); d < rank; d++) {
       unsigned idx = map.getDimPosition(perm(enc, d));
       if (!codegen.loops[idx])
@@ -933,7 +935,7 @@ static void genInvariants(Merger &merger, CodeGen &codegen,
     }
     // All exhausted at this level (atLevel denotes exactly at this level).
     OpOperand *lhs = op.getOutputOperand(0);
-    if (lhs == tensor) {
+    if (lhs == t) {
       codegen.redExp = hoist ? exp : -1u;
     } else if (atLevel) {
       merger.exp(exp).val =
@@ -1413,8 +1415,6 @@ struct GenericOpSparsifier : public OpRewritePattern<linalg::GenericOp> {
     // Detects sparse annotations and translate the per-dimension sparsity
     // information for all tensors to loop indices in the kernel.
     assert(op.getNumOutputs() == 1);
-    assert(llvm::none_of(op.getInputAndOutputOperands(),
-                         [&](OpOperand *t) { return op.isScalar(t); }));
     unsigned numTensors = op.getNumInputsAndOutputs();
     unsigned numLoops = op.iterator_types().getValue().size();
     Merger merger(numTensors, numLoops);

diff  --git a/mlir/test/Dialect/SparseTensor/sparse_scalars.mlir b/mlir/test/Dialect/SparseTensor/sparse_scalars.mlir
new file mode 100644
index 0000000000000..a70b70289411e
--- /dev/null
+++ b/mlir/test/Dialect/SparseTensor/sparse_scalars.mlir
@@ -0,0 +1,83 @@
+// NOTE: Assertions have been autogenerated by utils/generate-test-checks.py
+// RUN: mlir-opt %s -sparsification | FileCheck %s
+
+#SparseMatrix = #sparse_tensor.encoding<{ dimLevelType = [ "compressed", "compressed" ] }>
+
+// A contrived example that demonstrates the many different ways
+// in which scalar values can be involved in a sparse kernel
+// through the linalg generic op.
+
+#trait = {
+  indexing_maps = [
+    affine_map<(i,j) -> (i,j)>,  // A (sparse tensor)
+    affine_map<(i,j) -> ()>,     // p (scalar tensor)
+    affine_map<(i,j) -> ()>,     // q (true scalar)
+    affine_map<(i,j) -> (i,j)>   // X (dense tensor out)
+  ],
+  iterator_types = ["parallel", "parallel"],
+  doc = "X(i,j) += A(i,j) * p * q * r * s * 2.2"
+}
+
+// CHECK-LABEL:   func @mul(
+// CHECK-SAME:              %[[VAL_0:.*0]]: tensor<32x16xf32, #sparse_tensor.encoding<{{.*}}>>,
+// CHECK-SAME:              %[[VAL_1:.*1]]: tensor<f32>,
+// CHECK-SAME:              %[[VAL_2:.*2]]: f32,
+// CHECK-SAME:              %[[VAL_3:.*3]]: f32,
+// CHECK-SAME:              %[[VAL_4:.*4]]: tensor<32x16xf32> {linalg.inplaceable = true}) -> tensor<32x16xf32> {
+// CHECK:           %[[VAL_5:.*]] = constant 2.200000e+00 : f32
+// CHECK:           %[[VAL_6:.*]] = constant 0 : index
+// CHECK:           %[[VAL_7:.*]] = constant 1 : index
+// CHECK:           %[[VAL_8:.*]] = addf %[[VAL_2]], %[[VAL_3]] : f32
+// CHECK:           %[[VAL_9:.*]] = sparse_tensor.pointers %[[VAL_0]], %[[VAL_6]] : tensor<32x16xf32, #sparse_tensor.encoding<{{.*}}>> to memref<?xindex>
+// CHECK:           %[[VAL_10:.*]] = sparse_tensor.indices %[[VAL_0]], %[[VAL_6]] : tensor<32x16xf32, #sparse_tensor.encoding<{{.*}}>> to memref<?xindex>
+// CHECK:           %[[VAL_11:.*]] = sparse_tensor.pointers %[[VAL_0]], %[[VAL_7]] : tensor<32x16xf32, #sparse_tensor.encoding<{{.*}}>> to memref<?xindex>
+// CHECK:           %[[VAL_12:.*]] = sparse_tensor.indices %[[VAL_0]], %[[VAL_7]] : tensor<32x16xf32, #sparse_tensor.encoding<{{.*}}>> to memref<?xindex>
+// CHECK:           %[[VAL_13:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<32x16xf32, #sparse_tensor.encoding<{{.*}}>> to memref<?xf32>
+// CHECK:           %[[VAL_14:.*]] = memref.buffer_cast %[[VAL_1]] : memref<f32>
+// CHECK:           %[[VAL_15:.*]] = memref.buffer_cast %[[VAL_4]] : memref<32x16xf32>
+// CHECK:           %[[VAL_16:.*]] = memref.load %[[VAL_14]][] : memref<f32>
+// CHECK:           %[[VAL_17:.*]] = memref.load %[[VAL_9]]{{\[}}%[[VAL_6]]] : memref<?xindex>
+// CHECK:           %[[VAL_18:.*]] = memref.load %[[VAL_9]]{{\[}}%[[VAL_7]]] : memref<?xindex>
+// CHECK:           scf.for %[[VAL_19:.*]] = %[[VAL_17]] to %[[VAL_18]] step %[[VAL_7]] {
+// CHECK:             %[[VAL_20:.*]] = memref.load %[[VAL_10]]{{\[}}%[[VAL_19]]] : memref<?xindex>
+// CHECK:             %[[VAL_21:.*]] = memref.load %[[VAL_11]]{{\[}}%[[VAL_19]]] : memref<?xindex>
+// CHECK:             %[[VAL_22:.*]] = addi %[[VAL_19]], %[[VAL_7]] : index
+// CHECK:             %[[VAL_23:.*]] = memref.load %[[VAL_11]]{{\[}}%[[VAL_22]]] : memref<?xindex>
+// CHECK:             scf.for %[[VAL_24:.*]] = %[[VAL_21]] to %[[VAL_23]] step %[[VAL_7]] {
+// CHECK:               %[[VAL_25:.*]] = memref.load %[[VAL_12]]{{\[}}%[[VAL_24]]] : memref<?xindex>
+// CHECK:               %[[VAL_26:.*]] = memref.load %[[VAL_13]]{{\[}}%[[VAL_24]]] : memref<?xf32>
+// CHECK:               %[[VAL_27:.*]] = mulf %[[VAL_26]], %[[VAL_16]] : f32
+// CHECK:               %[[VAL_28:.*]] = mulf %[[VAL_27]], %[[VAL_2]] : f32
+// CHECK:               %[[VAL_29:.*]] = mulf %[[VAL_28]], %[[VAL_3]] : f32
+// CHECK:               %[[VAL_30:.*]] = mulf %[[VAL_29]], %[[VAL_8]] : f32
+// CHECK:               %[[VAL_31:.*]] = mulf %[[VAL_30]], %[[VAL_5]] : f32
+// CHECK:               %[[VAL_32:.*]] = memref.load %[[VAL_15]]{{\[}}%[[VAL_20]], %[[VAL_25]]] : memref<32x16xf32>
+// CHECK:               %[[VAL_33:.*]] = addf %[[VAL_31]], %[[VAL_32]] : f32
+// CHECK:               memref.store %[[VAL_33]], %[[VAL_15]]{{\[}}%[[VAL_20]], %[[VAL_25]]] : memref<32x16xf32>
+// CHECK:             }
+// CHECK:           }
+// CHECK:           %[[VAL_34:.*]] = memref.tensor_load %[[VAL_15]] : memref<32x16xf32>
+// CHECK:           return %[[VAL_34]] : tensor<32x16xf32>
+// CHECK:         }
+func @mul(%arga: tensor<32x16xf32, #SparseMatrix>,
+          %argp: tensor<f32>,
+          %argq: f32,
+          %argr: f32,
+          %argx: tensor<32x16xf32> {linalg.inplaceable = true}) -> tensor<32x16xf32> {
+  %s = addf %argq, %argr : f32
+  %c = constant 2.2 : f32
+  %0 = linalg.generic #trait
+     ins(%arga, %argp, %argq: tensor<32x16xf32, #SparseMatrix>, tensor<f32>, f32)
+    outs(%argx: tensor<32x16xf32>) {
+      ^bb(%a: f32, %p: f32, %q: f32, %x: f32):
+        %0 = mulf %a, %p : f32     // scalar tensor argument
+        %1 = mulf %0, %q : f32     // scalar argument
+        %2 = mulf %1, %argr : f32  // scalar argument from outside block
+        %3 = mulf %2, %s : f32     // scalar value from outside block
+        %4 = mulf %3, %c : f32     // direct constant from outside block
+        %5 = addf %4, %x : f32
+        linalg.yield %5  : f32
+  } -> tensor<32x16xf32>
+
+  return %0 : tensor<32x16xf32>
+}


        


More information about the Mlir-commits mailing list