[Mlir-commits] [mlir] 0df59f2 - [sparse][mlir] simplify lattice optimization logic

Aart Bik llvmlistbot at llvm.org
Mon Feb 22 16:52:21 PST 2021


Author: Aart Bik
Date: 2021-02-22T16:52:06-08:00
New Revision: 0df59f234bf09dac203b98427f366be97288f636

URL: https://github.com/llvm/llvm-project/commit/0df59f234bf09dac203b98427f366be97288f636
DIFF: https://github.com/llvm/llvm-project/commit/0df59f234bf09dac203b98427f366be97288f636.diff

LOG: [sparse][mlir] simplify lattice optimization logic

Simplifies the way lattices are optimized with fewer, but more
powerful, rules. This also fixes an inaccuracy where too many
lattice points resulted (expecting a non-existent universal index).
Also puts the no-side-effects trait on all proper getters and
unifies the order of the bufferization flags in the integration
tests (for future, more complex use cases).

Reviewed By: bixia

Differential Revision: https://reviews.llvm.org/D97134
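
For readers outside the Merger implementation, the single rule that
replaces the old special cases is an XOR test over the iteration
condition bit vectors: a lattice point is already covered by a kept
point when the two conditions differ only in dense dimensions, since
dense dimensions are always present and never change the sparse
control flow. A minimal standalone sketch of that test, using
std::bitset and a hypothetical dimOf() classification in place of the
actual llvm::BitVector-based Merger code:

  #include <bitset>
  #include <cstdio>

  constexpr unsigned kBits = 8;
  enum class Dim { kSparse, kDense };

  // Hypothetical classification of bit positions into dimension kinds;
  // the real Merger derives this from the per-tensor annotations.
  Dim dimOf(unsigned b) { return (b % 2 == 0) ? Dim::kSparse : Dim::kDense; }

  // Lj is covered by a kept point Li when their condition bits differ
  // only in dense dimensions: the conjunction adds no new sparse case.
  bool onlyDenseDiff(std::bitset<kBits> li, std::bitset<kBits> lj) {
    std::bitset<kBits> diff = li ^ lj;
    for (unsigned b = 0; b < kBits; ++b)
      if (diff[b] && dimOf(b) == Dim::kSparse)
        return false;
    return true;
  }

  int main() {
    std::bitset<kBits> li("00001011");
    std::bitset<kBits> lj("00001001"); // differs from li only in dense bit 1
    std::printf("covered: %d\n", onlyDenseDiff(li, lj) ? 1 : 0); // prints 1
    return 0;
  }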

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Linalg/IR/LinalgSparseOps.td
    mlir/integration_test/Sparse/CPU/sparse_sampled_matmul.mlir
    mlir/integration_test/Sparse/CPU/sparse_sum.mlir
    mlir/lib/Dialect/Linalg/Transforms/Sparsification.cpp
    mlir/test/Dialect/Linalg/sparse_1d.mlir

Removed: 
    


################################################################################
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgSparseOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgSparseOps.td
index 1bf39bd03200..9d9402478495 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgSparseOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgSparseOps.td
@@ -32,9 +32,11 @@
 #define LINALG_SPARSE_OPS
 
 include "mlir/Dialect/Linalg/IR/LinalgBase.td"
+include "mlir/Interfaces/SideEffectInterfaces.td"
 
 // Base class.
-class Linalg_SparseOp<string mnemonic> : Op<Linalg_Dialect, mnemonic, []> {
+class Linalg_SparseOp<string mnemonic, list<OpTrait> traits = []>
+  : Op<Linalg_Dialect, mnemonic, traits> {
   let printer = [{ return ::print(p, *this); }];
   let verifier = ?;
   let parser = [{ return ::parse$cppClass(parser, result); }];
@@ -65,7 +67,7 @@ def Linalg_SparseTensorFromPointerOp :
 }
 
 def Linalg_SparseTensorToPointersMemRefOp :
-    Linalg_SparseOp<"sparse_pointers">,
+    Linalg_SparseOp<"sparse_pointers", [NoSideEffect]>,
     Arguments<(ins AnyTensor:$tensor, Index:$dim)>,
     Results<(outs AnyStridedMemRefOfRank<1>:$result)> {
   let summary = "Extract pointers array at given dimension from a tensor";
@@ -89,7 +91,7 @@ def Linalg_SparseTensorToPointersMemRefOp :
 }
 
 def Linalg_SparseTensorToIndicesMemRefOp :
-    Linalg_SparseOp<"sparse_indices">,
+    Linalg_SparseOp<"sparse_indices", [NoSideEffect]>,
     Arguments<(ins AnyTensor:$tensor, Index:$dim)>,
     Results<(outs AnyStridedMemRefOfRank<1>:$result)> {
   let summary = "Extract indices array at given dimension from a tensor";
@@ -113,7 +115,7 @@ def Linalg_SparseTensorToIndicesMemRefOp :
 }
 
 def Linalg_SparseTensorToValuesMemRefOp :
-    Linalg_SparseOp<"sparse_values">,
+    Linalg_SparseOp<"sparse_values", [NoSideEffect]>,
     Arguments<(ins AnyTensor:$tensor)>,
     Results<(outs AnyStridedMemRefOfRank<1>:$result)> {
   let summary = "Extract numerical values array from a tensor";

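The practical effect of the NoSideEffect trait added above is that
generic transformations may now treat these getters as pure: CSE can
merge identical sparse_pointers/sparse_indices/sparse_values ops, and
dead-code elimination can erase unused ones. A hedged sketch of the
mechanism (illustrative, not part of this commit), assuming
mlir::isOpTriviallyDead from SideEffectInterfaces behaves as documented:

  #include "mlir/IR/Operation.h"
  #include "mlir/Interfaces/SideEffectInterfaces.h"
  #include "llvm/ADT/SmallVector.h"

  // Sketch: collect and erase trivially dead ops, as canonicalization
  // does. isOpTriviallyDead holds only for ops known to be free of side
  // effects whose results are all unused; before this change the sparse
  // getters declared no effect information, so they never qualified.
  void eraseTriviallyDeadOps(mlir::Operation *root) {
    llvm::SmallVector<mlir::Operation *, 8> dead;
    root->walk([&](mlir::Operation *op) {
      if (op != root && mlir::isOpTriviallyDead(op))
        dead.push_back(op);
    });
    // The post-order walk visits nested ops first, so erasing in the
    // collected order removes children before any enclosing op.
    for (mlir::Operation *op : dead)
      op->erase();
  }
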
diff --git a/mlir/integration_test/Sparse/CPU/sparse_sampled_matmul.mlir b/mlir/integration_test/Sparse/CPU/sparse_sampled_matmul.mlir
index 9d2dd7b9575f..5f2b199fb704 100644
--- a/mlir/integration_test/Sparse/CPU/sparse_sampled_matmul.mlir
+++ b/mlir/integration_test/Sparse/CPU/sparse_sampled_matmul.mlir
@@ -1,9 +1,9 @@
 // RUN: mlir-opt %s \
 // RUN:   --test-sparsification="lower ptr-type=2 ind-type=2 fast-output" \
-// RUN:   --convert-linalg-to-loops \
+// RUN:   --convert-linalg-to-loops --convert-vector-to-scf --convert-scf-to-std \
 // RUN:   --func-bufferize --tensor-constant-bufferize --tensor-bufferize \
 // RUN:   --std-bufferize --finalizing-bufferize  \
-// RUN:   --convert-scf-to-std --convert-vector-to-llvm --convert-std-to-llvm | \
+// RUN:   --convert-vector-to-llvm --convert-std-to-llvm | \
 // RUN: TENSOR0="%mlir_integration_test_dir/data/test.mtx" \
 // RUN: mlir-cpu-runner \
 // RUN:  -e entry -entry-point-result=void  \

diff --git a/mlir/integration_test/Sparse/CPU/sparse_sum.mlir b/mlir/integration_test/Sparse/CPU/sparse_sum.mlir
index c6fea1a5d12b..78e5940eae60 100644
--- a/mlir/integration_test/Sparse/CPU/sparse_sum.mlir
+++ b/mlir/integration_test/Sparse/CPU/sparse_sum.mlir
@@ -1,8 +1,9 @@
 // RUN: mlir-opt %s \
 // RUN:   --test-sparsification="lower" \
-// RUN:   --convert-linalg-to-loops \
-// RUN:   --func-bufferize --tensor-constant-bufferize --tensor-bufferize --finalizing-bufferize  \
-// RUN:   --convert-scf-to-std --convert-vector-to-llvm --convert-std-to-llvm | \
+// RUN:   --convert-linalg-to-loops --convert-vector-to-scf --convert-scf-to-std \
+// RUN:   --func-bufferize --tensor-constant-bufferize --tensor-bufferize \
+// RUN:   --std-bufferize --finalizing-bufferize  \
+// RUN:   --convert-vector-to-llvm --convert-std-to-llvm | \
 // RUN: TENSOR0="%mlir_integration_test_dir/data/test.mtx" \
 // RUN: mlir-cpu-runner \
 // RUN:  -e entry -entry-point-result=void  \

diff --git a/mlir/lib/Dialect/Linalg/Transforms/Sparsification.cpp b/mlir/lib/Dialect/Linalg/Transforms/Sparsification.cpp
index 1898ba688c15..5306e6f40fa3 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Sparsification.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Sparsification.cpp
@@ -175,33 +175,25 @@ class Merger {
     unsigned p0 = latSets[s0][0];
     for (unsigned p1 : latSets[s0]) {
       bool add = true;
-      llvm::BitVector simple = simplifyCond(s0, p1);
       if (p0 != p1) {
         // Is this a straightforward copy?
         unsigned e = latPoints[p1].exp;
         if (exp(e).kind == Kind::kTensor && exp(e).e0 == outTensor)
           continue;
-        // Only dense exhausted?
-        llvm::BitVector tmp = latPoints[p1].bits;
-        tmp ^= latPoints[p0].bits;
-        if (!hasAnyDimOf(tmp, Dim::kSparse))
-          continue;
-        // Duplication of an earlier conjunction?
+        // Conjunction already covered?
         for (unsigned p2 : latSets[s]) {
-          tmp = simple;
-          tmp ^= latPoints[p2].simple;
-          if (tmp.count() == 0) {
+          if (onlyDenseDiff(p2, p1)) {
             add = false;
             break;
           }
         }
         assert(!add || latGT(p0, p1));
       }
-      if (add) {
+      if (add)
         latSets[s].push_back(p1);
-        latPoints[latSets[s].back()].simple = simple;
-      }
     }
+    for (unsigned p : latSets[s])
+      latPoints[p].simple = simplifyCond(s, p);
     return s;
   }
 
@@ -215,15 +207,8 @@ class Merger {
     bool isSingleton = true;
     for (unsigned p1 : latSets[s]) {
       if (p0 != p1 && latGT(p0, p1)) {
-        unsigned e = latPoints[p1].exp;
-        if (exp(e).kind == Kind::kTensor && exp(e).e0 == outTensor)
-          continue;
-        llvm::BitVector tmp = latPoints[p1].bits;
-        tmp ^= latPoints[p0].bits;
-        if (hasAnyDimOf(tmp, Dim::kSparse)) {
-          isSingleton = false;
-          break;
-        }
+        isSingleton = false;
+        break;
       }
     }
     // Now apply the two basic rules.
@@ -253,6 +238,13 @@ class Merger {
     return false;
   }
 
+  /// Returns true if Li and Lj only differ in dense.
+  bool onlyDenseDiff(unsigned i, unsigned j) {
+    llvm::BitVector tmp = latPoints[j].bits;
+    tmp ^= latPoints[i].bits;
+    return !hasAnyDimOf(tmp, Dim::kSparse);
+  }
+
   /// Bit translation.
   unsigned tensor(unsigned b) const { return b % numTensors; }
   unsigned index(unsigned b) const { return b / numTensors; }
@@ -274,12 +266,12 @@ class Merger {
     return false;
   }
 
-  // Returns true if tensor has any sparse dimension.
+  /// Returns true if tensor has any sparse dimension.
   bool isSparseTensor(unsigned t) const {
     return llvm::any_of(dims[t], [](Dim d) { return d == Dim::kSparse; });
   }
 
-  // Setter
+  /// Setter
   void setDim(unsigned t, unsigned i, Dim d) { dims[t][i] = d; }
 
   /// Getters.
@@ -1193,12 +1185,6 @@ static void genStmt(Merger &merger, CodeGen &codegen, PatternRewriter &rewriter,
       unsigned lj = merger.set(lts)[j];
       unsigned ej = merger.lat(lj).exp;
       if (li == lj || merger.latGT(li, lj)) {
-        if (li != lj) {
-          llvm::BitVector tmp = merger.lat(lj).bits;
-          tmp ^= merger.lat(li).bits;
-          if (!merger.hasAnyDimOf(tmp, Dim::kSparse))
-            continue; // only dense exhausted within if/else
-        }
         // Recurse into body of each branch.
         if (isWhile) {
           scf::IfOp ifOp =

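The last hunk above drops the dense-only test from genStmt because
optimizeSet now establishes it as an invariant: any lattice point whose
conditions differ from an already-kept point only in dense dimensions
is filtered out up front, so code generation never sees such
duplicates. An illustrative, self-contained check of that invariant
(sparseMask() and the bit layout are hypothetical stand-ins):

  #include <bitset>
  #include <cassert>
  #include <vector>

  constexpr unsigned kBits = 8;

  // Hypothetical mask of the bit positions that denote sparse dimensions.
  std::bitset<kBits> sparseMask() { return std::bitset<kBits>("01010101"); }

  // After the new optimizeSet, every pair of surviving lattice points
  // disagrees on at least one sparse bit, so the emitted if/else chain
  // no longer needs the old "only dense exhausted" guard.
  void checkOptimizedInvariant(const std::vector<std::bitset<kBits>> &kept) {
    for (size_t i = 0; i < kept.size(); ++i)
      for (size_t j = i + 1; j < kept.size(); ++j)
        assert(((kept[i] ^ kept[j]) & sparseMask()).any());
  }
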
diff --git a/mlir/test/Dialect/Linalg/sparse_1d.mlir b/mlir/test/Dialect/Linalg/sparse_1d.mlir
index 5730b89d9a1c..320dd9596565 100644
--- a/mlir/test/Dialect/Linalg/sparse_1d.mlir
+++ b/mlir/test/Dialect/Linalg/sparse_1d.mlir
@@ -1153,3 +1153,190 @@ func @sum_reduction_inv(%arga: tensor<16xf32>,
   } -> tensor<f32>
   return %0 : tensor<f32>
 }
+
+#trait_four_tensors = {
+  indexing_maps = [
+    affine_map<(i) -> (i)>,  // A
+    affine_map<(i) -> (i)>,  // B
+    affine_map<(i) -> (i)>,  // C
+    affine_map<(i) -> (i)>,  // D
+    affine_map<(i) -> (i)>   // X (out)
+  ],
+  sparse = [
+    ["D"], // A
+    ["S"], // B
+    ["D"], // C
+    ["S"], // D
+    ["D"]  // X
+  ],
+  iterator_types = ["parallel"],
+  doc = "X(i) = A(i) + B(i) + C(i) + D(i)"
+}
+
+// CHECK-LABEL:   func @four_tensors_op(
+// CHECK-SAME:                          %[[VAL_0:.*0]]: tensor<?xf64>,
+// CHECK-SAME:                          %[[VAL_1:.*1]]: tensor<?xf64>,
+// CHECK-SAME:                          %[[VAL_2:.*2]]: tensor<?xf64>,
+// CHECK-SAME:                          %[[VAL_3:.*3]]: tensor<?xf64>,
+// CHECK-SAME:                          %[[VAL_4:.*4]]: tensor<?xf64>) -> tensor<?xf64> {
+// CHECK:           %[[VAL_5:.*]] = constant 0 : index
+// CHECK:           %[[VAL_6:.*]] = constant true
+// CHECK:           %[[VAL_7:.*]] = constant 1 : index
+// CHECK:           %[[VAL_8:.*]] = tensor_to_memref %[[VAL_0]] : memref<?xf64>
+// CHECK:           %[[VAL_9:.*]] = linalg.sparse_pointers %[[VAL_1]], %[[VAL_5]] : tensor<?xf64> to memref<?xindex>
+// CHECK:           %[[VAL_10:.*]] = linalg.sparse_indices %[[VAL_1]], %[[VAL_5]] : tensor<?xf64> to memref<?xindex>
+// CHECK:           %[[VAL_11:.*]] = linalg.sparse_values %[[VAL_1]] : tensor<?xf64> to memref<?xf64>
+// CHECK:           %[[VAL_12:.*]] = tensor_to_memref %[[VAL_2]] : memref<?xf64>
+// CHECK:           %[[VAL_13:.*]] = linalg.sparse_pointers %[[VAL_3]], %[[VAL_5]] : tensor<?xf64> to memref<?xindex>
+// CHECK:           %[[VAL_14:.*]] = linalg.sparse_indices %[[VAL_3]], %[[VAL_5]] : tensor<?xf64> to memref<?xindex>
+// CHECK:           %[[VAL_15:.*]] = linalg.sparse_values %[[VAL_3]] : tensor<?xf64> to memref<?xf64>
+// CHECK:           %[[VAL_16:.*]] = dim %[[VAL_4]], %[[VAL_5]] : tensor<?xf64>
+// CHECK:           %[[VAL_17:.*]] = tensor_to_memref %[[VAL_4]] : memref<?xf64>
+// CHECK:           %[[VAL_18:.*]] = alloc(%[[VAL_16]]) : memref<?xf64>
+// CHECK:           linalg.copy(%[[VAL_17]], %[[VAL_18]]) : memref<?xf64>, memref<?xf64>
+// CHECK:           %[[VAL_19:.*]] = load %[[VAL_9]]{{\[}}%[[VAL_5]]] : memref<?xindex>
+// CHECK:           %[[VAL_20:.*]] = load %[[VAL_9]]{{\[}}%[[VAL_7]]] : memref<?xindex>
+// CHECK:           %[[VAL_21:.*]] = load %[[VAL_13]]{{\[}}%[[VAL_5]]] : memref<?xindex>
+// CHECK:           %[[VAL_22:.*]] = load %[[VAL_13]]{{\[}}%[[VAL_7]]] : memref<?xindex>
+// CHECK:           %[[VAL_23:.*]]:3 = scf.while (%[[VAL_24:.*]] = %[[VAL_19]], %[[VAL_25:.*]] = %[[VAL_21]], %[[VAL_26:.*]] = %[[VAL_5]]) : (index, index, index) -> (index, index, index) {
+// CHECK:             %[[VAL_27:.*]] = cmpi ult, %[[VAL_24]], %[[VAL_20]] : index
+// CHECK:             %[[VAL_28:.*]] = cmpi ult, %[[VAL_25]], %[[VAL_22]] : index
+// CHECK:             %[[VAL_29:.*]] = and %[[VAL_27]], %[[VAL_28]] : i1
+// CHECK:             scf.condition(%[[VAL_29]]) %[[VAL_24]], %[[VAL_25]], %[[VAL_26]] : index, index, index
+// CHECK:           } do {
+// CHECK:           ^bb0(%[[VAL_30:.*]]: index, %[[VAL_31:.*]]: index, %[[VAL_32:.*]]: index):
+// CHECK:             %[[VAL_33:.*]] = load %[[VAL_10]]{{\[}}%[[VAL_30]]] : memref<?xindex>
+// CHECK:             %[[VAL_34:.*]] = load %[[VAL_14]]{{\[}}%[[VAL_31]]] : memref<?xindex>
+// CHECK:             %[[VAL_35:.*]] = cmpi eq, %[[VAL_33]], %[[VAL_32]] : index
+// CHECK:             %[[VAL_36:.*]] = cmpi eq, %[[VAL_34]], %[[VAL_32]] : index
+// CHECK:             %[[VAL_37:.*]] = and %[[VAL_35]], %[[VAL_36]] : i1
+// CHECK:             scf.if %[[VAL_37]] {
+// CHECK:               %[[VAL_38:.*]] = load %[[VAL_8]]{{\[}}%[[VAL_32]]] : memref<?xf64>
+// CHECK:               %[[VAL_39:.*]] = load %[[VAL_11]]{{\[}}%[[VAL_30]]] : memref<?xf64>
+// CHECK:               %[[VAL_40:.*]] = addf %[[VAL_38]], %[[VAL_39]] : f64
+// CHECK:               %[[VAL_41:.*]] = load %[[VAL_12]]{{\[}}%[[VAL_32]]] : memref<?xf64>
+// CHECK:               %[[VAL_42:.*]] = load %[[VAL_15]]{{\[}}%[[VAL_31]]] : memref<?xf64>
+// CHECK:               %[[VAL_43:.*]] = addf %[[VAL_41]], %[[VAL_42]] : f64
+// CHECK:               %[[VAL_44:.*]] = addf %[[VAL_40]], %[[VAL_43]] : f64
+// CHECK:               store %[[VAL_44]], %[[VAL_18]]{{\[}}%[[VAL_32]]] : memref<?xf64>
+// CHECK:             } else {
+// CHECK:               %[[VAL_45:.*]] = cmpi eq, %[[VAL_33]], %[[VAL_32]] : index
+// CHECK:               scf.if %[[VAL_45]] {
+// CHECK:                 %[[VAL_46:.*]] = load %[[VAL_8]]{{\[}}%[[VAL_32]]] : memref<?xf64>
+// CHECK:                 %[[VAL_47:.*]] = load %[[VAL_11]]{{\[}}%[[VAL_30]]] : memref<?xf64>
+// CHECK:                 %[[VAL_48:.*]] = addf %[[VAL_46]], %[[VAL_47]] : f64
+// CHECK:                 %[[VAL_49:.*]] = load %[[VAL_12]]{{\[}}%[[VAL_32]]] : memref<?xf64>
+// CHECK:                 %[[VAL_50:.*]] = addf %[[VAL_48]], %[[VAL_49]] : f64
+// CHECK:                 store %[[VAL_50]], %[[VAL_18]]{{\[}}%[[VAL_32]]] : memref<?xf64>
+// CHECK:               } else {
+// CHECK:                 %[[VAL_51:.*]] = cmpi eq, %[[VAL_34]], %[[VAL_32]] : index
+// CHECK:                 scf.if %[[VAL_51]] {
+// CHECK:                   %[[VAL_52:.*]] = load %[[VAL_8]]{{\[}}%[[VAL_32]]] : memref<?xf64>
+// CHECK:                   %[[VAL_53:.*]] = load %[[VAL_12]]{{\[}}%[[VAL_32]]] : memref<?xf64>
+// CHECK:                   %[[VAL_54:.*]] = load %[[VAL_15]]{{\[}}%[[VAL_31]]] : memref<?xf64>
+// CHECK:                   %[[VAL_55:.*]] = addf %[[VAL_53]], %[[VAL_54]] : f64
+// CHECK:                   %[[VAL_56:.*]] = addf %[[VAL_52]], %[[VAL_55]] : f64
+// CHECK:                   store %[[VAL_56]], %[[VAL_18]]{{\[}}%[[VAL_32]]] : memref<?xf64>
+// CHECK:                 } else {
+// CHECK:                   scf.if %[[VAL_6]] {
+// CHECK:                     %[[VAL_57:.*]] = load %[[VAL_8]]{{\[}}%[[VAL_32]]] : memref<?xf64>
+// CHECK:                     %[[VAL_58:.*]] = load %[[VAL_12]]{{\[}}%[[VAL_32]]] : memref<?xf64>
+// CHECK:                     %[[VAL_59:.*]] = addf %[[VAL_57]], %[[VAL_58]] : f64
+// CHECK:                     store %[[VAL_59]], %[[VAL_18]]{{\[}}%[[VAL_32]]] : memref<?xf64>
+// CHECK:                   } else {
+// CHECK:                   }
+// CHECK:                 }
+// CHECK:               }
+// CHECK:             }
+// CHECK:             %[[VAL_60:.*]] = cmpi eq, %[[VAL_33]], %[[VAL_32]] : index
+// CHECK:             %[[VAL_61:.*]] = addi %[[VAL_30]], %[[VAL_7]] : index
+// CHECK:             %[[VAL_62:.*]] = select %[[VAL_60]], %[[VAL_61]], %[[VAL_30]] : index
+// CHECK:             %[[VAL_63:.*]] = cmpi eq, %[[VAL_34]], %[[VAL_32]] : index
+// CHECK:             %[[VAL_64:.*]] = addi %[[VAL_31]], %[[VAL_7]] : index
+// CHECK:             %[[VAL_65:.*]] = select %[[VAL_63]], %[[VAL_64]], %[[VAL_31]] : index
+// CHECK:             %[[VAL_66:.*]] = addi %[[VAL_32]], %[[VAL_7]] : index
+// CHECK:             scf.yield %[[VAL_62]], %[[VAL_65]], %[[VAL_66]] : index, index, index
+// CHECK:           }
+// CHECK:           %[[VAL_67:.*]]:2 = scf.while (%[[VAL_68:.*]] = %[[VAL_69:.*]]#0, %[[VAL_70:.*]] = %[[VAL_69]]#2) : (index, index) -> (index, index) {
+// CHECK:             %[[VAL_71:.*]] = cmpi ult, %[[VAL_68]], %[[VAL_20]] : index
+// CHECK:             scf.condition(%[[VAL_71]]) %[[VAL_68]], %[[VAL_70]] : index, index
+// CHECK:           } do {
+// CHECK:           ^bb0(%[[VAL_72:.*]]: index, %[[VAL_73:.*]]: index):
+// CHECK:             %[[VAL_74:.*]] = load %[[VAL_10]]{{\[}}%[[VAL_72]]] : memref<?xindex>
+// CHECK:             %[[VAL_75:.*]] = cmpi eq, %[[VAL_74]], %[[VAL_73]] : index
+// CHECK:             scf.if %[[VAL_75]] {
+// CHECK:               %[[VAL_76:.*]] = load %[[VAL_8]]{{\[}}%[[VAL_73]]] : memref<?xf64>
+// CHECK:               %[[VAL_77:.*]] = load %[[VAL_11]]{{\[}}%[[VAL_72]]] : memref<?xf64>
+// CHECK:               %[[VAL_78:.*]] = addf %[[VAL_76]], %[[VAL_77]] : f64
+// CHECK:               %[[VAL_79:.*]] = load %[[VAL_12]]{{\[}}%[[VAL_73]]] : memref<?xf64>
+// CHECK:               %[[VAL_80:.*]] = addf %[[VAL_78]], %[[VAL_79]] : f64
+// CHECK:               store %[[VAL_80]], %[[VAL_18]]{{\[}}%[[VAL_73]]] : memref<?xf64>
+// CHECK:             } else {
+// CHECK:               scf.if %[[VAL_6]] {
+// CHECK:                 %[[VAL_81:.*]] = load %[[VAL_8]]{{\[}}%[[VAL_73]]] : memref<?xf64>
+// CHECK:                 %[[VAL_82:.*]] = load %[[VAL_12]]{{\[}}%[[VAL_73]]] : memref<?xf64>
+// CHECK:                 %[[VAL_83:.*]] = addf %[[VAL_81]], %[[VAL_82]] : f64
+// CHECK:                 store %[[VAL_83]], %[[VAL_18]]{{\[}}%[[VAL_73]]] : memref<?xf64>
+// CHECK:               } else {
+// CHECK:               }
+// CHECK:             }
+// CHECK:             %[[VAL_84:.*]] = cmpi eq, %[[VAL_74]], %[[VAL_73]] : index
+// CHECK:             %[[VAL_85:.*]] = addi %[[VAL_72]], %[[VAL_7]] : index
+// CHECK:             %[[VAL_86:.*]] = select %[[VAL_84]], %[[VAL_85]], %[[VAL_72]] : index
+// CHECK:             %[[VAL_87:.*]] = addi %[[VAL_73]], %[[VAL_7]] : index
+// CHECK:             scf.yield %[[VAL_86]], %[[VAL_87]] : index, index
+// CHECK:           }
+// CHECK:           %[[VAL_88:.*]]:2 = scf.while (%[[VAL_89:.*]] = %[[VAL_90:.*]]#1, %[[VAL_91:.*]] = %[[VAL_92:.*]]#1) : (index, index) -> (index, index) {
+// CHECK:             %[[VAL_93:.*]] = cmpi ult, %[[VAL_89]], %[[VAL_22]] : index
+// CHECK:             scf.condition(%[[VAL_93]]) %[[VAL_89]], %[[VAL_91]] : index, index
+// CHECK:           } do {
+// CHECK:           ^bb0(%[[VAL_94:.*]]: index, %[[VAL_95:.*]]: index):
+// CHECK:             %[[VAL_96:.*]] = load %[[VAL_14]]{{\[}}%[[VAL_94]]] : memref<?xindex>
+// CHECK:             %[[VAL_97:.*]] = cmpi eq, %[[VAL_96]], %[[VAL_95]] : index
+// CHECK:             scf.if %[[VAL_97]] {
+// CHECK:               %[[VAL_98:.*]] = load %[[VAL_8]]{{\[}}%[[VAL_95]]] : memref<?xf64>
+// CHECK:               %[[VAL_99:.*]] = load %[[VAL_12]]{{\[}}%[[VAL_95]]] : memref<?xf64>
+// CHECK:               %[[VAL_100:.*]] = load %[[VAL_15]]{{\[}}%[[VAL_94]]] : memref<?xf64>
+// CHECK:               %[[VAL_101:.*]] = addf %[[VAL_99]], %[[VAL_100]] : f64
+// CHECK:               %[[VAL_102:.*]] = addf %[[VAL_98]], %[[VAL_101]] : f64
+// CHECK:               store %[[VAL_102]], %[[VAL_18]]{{\[}}%[[VAL_95]]] : memref<?xf64>
+// CHECK:             } else {
+// CHECK:               scf.if %[[VAL_6]] {
+// CHECK:                 %[[VAL_103:.*]] = load %[[VAL_8]]{{\[}}%[[VAL_95]]] : memref<?xf64>
+// CHECK:                 %[[VAL_104:.*]] = load %[[VAL_12]]{{\[}}%[[VAL_95]]] : memref<?xf64>
+// CHECK:                 %[[VAL_105:.*]] = addf %[[VAL_103]], %[[VAL_104]] : f64
+// CHECK:                 store %[[VAL_105]], %[[VAL_18]]{{\[}}%[[VAL_95]]] : memref<?xf64>
+// CHECK:               } else {
+// CHECK:               }
+// CHECK:             }
+// CHECK:             %[[VAL_106:.*]] = cmpi eq, %[[VAL_96]], %[[VAL_95]] : index
+// CHECK:             %[[VAL_107:.*]] = addi %[[VAL_94]], %[[VAL_7]] : index
+// CHECK:             %[[VAL_108:.*]] = select %[[VAL_106]], %[[VAL_107]], %[[VAL_94]] : index
+// CHECK:             %[[VAL_109:.*]] = addi %[[VAL_95]], %[[VAL_7]] : index
+// CHECK:             scf.yield %[[VAL_108]], %[[VAL_109]] : index, index
+// CHECK:           }
+// CHECK:           scf.for %[[VAL_110:.*]] = %[[VAL_111:.*]]#1 to %[[VAL_16]] step %[[VAL_7]] {
+// CHECK:             %[[VAL_112:.*]] = load %[[VAL_8]]{{\[}}%[[VAL_110]]] : memref<?xf64>
+// CHECK:             %[[VAL_113:.*]] = load %[[VAL_12]]{{\[}}%[[VAL_110]]] : memref<?xf64>
+// CHECK:             %[[VAL_114:.*]] = addf %[[VAL_112]], %[[VAL_113]] : f64
+// CHECK:             store %[[VAL_114]], %[[VAL_18]]{{\[}}%[[VAL_110]]] : memref<?xf64>
+// CHECK:           }
+// CHECK:           %[[VAL_115:.*]] = tensor_load %[[VAL_18]] : memref<?xf64>
+// CHECK:           return %[[VAL_115]] : tensor<?xf64>
+// CHECK:         }
+func @four_tensors_op(%arga: tensor<?xf64>,
+                      %argb: tensor<?xf64>,
+                      %argc: tensor<?xf64>,
+                      %argd: tensor<?xf64>,
+                      %argx: tensor<?xf64>) -> tensor<?xf64> {
+  %r = linalg.generic #trait_four_tensors
+    ins(%arga, %argb, %argc, %argd: tensor<?xf64>, tensor<?xf64>, tensor<?xf64>, tensor<?xf64>)
+    outs(%argx: tensor<?xf64>) {
+      ^bb(%a: f64, %b: f64, %c: f64, %d: f64, %x: f64):
+        %0 = addf %a, %b : f64
+        %1 = addf %c, %d : f64
+        %2 = addf %0, %1 : f64
+        linalg.yield %2 : f64
+  } -> tensor<?xf64>
+  return %r : tensor<?xf64>
+}
