[Mlir-commits] [mlir] 0df59f2 - [sparse][mlir] simplify lattice optimization logic
Aart Bik
llvmlistbot at llvm.org
Mon Feb 22 16:52:21 PST 2021
Author: Aart Bik
Date: 2021-02-22T16:52:06-08:00
New Revision: 0df59f234bf09dac203b98427f366be97288f636
URL: https://github.com/llvm/llvm-project/commit/0df59f234bf09dac203b98427f366be97288f636
DIFF: https://github.com/llvm/llvm-project/commit/0df59f234bf09dac203b98427f366be97288f636.diff
LOG: [sparse][mlir] simplify lattice optimization logic
Simplifies the way lattices are optimized with fewer, but more
powerful, rules. This also fixes an inaccuracy where too many
lattices resulted (expecting a non-existent universal index).
Also puts the NoSideEffect trait on all proper getters and unifies
the bufferization flag order in the integration tests (for future,
more complex use cases).
Reviewed By: bixia
Differential Revision: https://reviews.llvm.org/D97134
Added:
Modified:
mlir/include/mlir/Dialect/Linalg/IR/LinalgSparseOps.td
mlir/integration_test/Sparse/CPU/sparse_sampled_matmul.mlir
mlir/integration_test/Sparse/CPU/sparse_sum.mlir
mlir/lib/Dialect/Linalg/Transforms/Sparsification.cpp
mlir/test/Dialect/Linalg/sparse_1d.mlir
Removed:
################################################################################
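The essence of the new rule is the onlyDenseDiff helper added to Sparsification.cpp
below: a candidate lattice point is dropped when its iteration bits differ from an
already accepted point only in dense dimensions, since dense dimensions are always
exhausted and never need a conjunction of their own. The following standalone sketch
illustrates that check with hypothetical names and a plain std::bitset instead of the
actual Merger class and llvm::BitVector; it is an illustration of the idea, not the
patch itself.

  // Minimal sketch (hypothetical types): a lattice point is redundant when its
  // condition bits differ from an already kept point only in dense dimensions.
  #include <bitset>
  #include <cstdio>
  #include <vector>

  constexpr unsigned kNumBits = 8;

  enum class Dim { kSparse, kDense };

  struct LatPoint {
    std::bitset<kNumBits> bits; // which tensor/index pairs this point iterates
  };

  // Returns true if any set bit corresponds to a sparse dimension.
  static bool hasAnySparse(const std::bitset<kNumBits> &bits,
                           const std::vector<Dim> &dimOfBit) {
    for (unsigned b = 0; b < kNumBits; ++b)
      if (bits[b] && dimOfBit[b] == Dim::kSparse)
        return true;
    return false;
  }

  // Mirrors the patch's onlyDenseDiff: XOR the bit sets and check whether
  // any sparse dimension remains in the difference.
  static bool onlyDenseDiff(const LatPoint &a, const LatPoint &b,
                            const std::vector<Dim> &dimOfBit) {
    return !hasAnySparse(a.bits ^ b.bits, dimOfBit);
  }

  int main() {
    // Bits 0 and 2 are sparse; all other bits are dense.
    std::vector<Dim> dimOfBit(kNumBits, Dim::kDense);
    dimOfBit[0] = dimOfBit[2] = Dim::kSparse;

    LatPoint kept{0b0011};      // sparse bit 0 + dense bit 1
    LatPoint candidate{0b1011}; // same sparse bits, extra dense bit 3
    LatPoint other{0b0111};     // adds sparse bit 2

    // candidate is covered by kept (only dense bits differ); other is not.
    std::printf("%d %d\n", onlyDenseDiff(kept, candidate, dimOfBit),
                onlyDenseDiff(kept, other, dimOfBit));
    return 0;
  }

Under these assumptions the program prints "1 0": the first pair differs only in a
dense bit and is folded away, while the second pair differs in a sparse bit and must
keep its own conjunction.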
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgSparseOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgSparseOps.td
index 1bf39bd03200..9d9402478495 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgSparseOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgSparseOps.td
@@ -32,9 +32,11 @@
#define LINALG_SPARSE_OPS
include "mlir/Dialect/Linalg/IR/LinalgBase.td"
+include "mlir/Interfaces/SideEffectInterfaces.td"
// Base class.
-class Linalg_SparseOp<string mnemonic> : Op<Linalg_Dialect, mnemonic, []> {
+class Linalg_SparseOp<string mnemonic, list<OpTrait> traits = []>
+ : Op<Linalg_Dialect, mnemonic, traits> {
let printer = [{ return ::print(p, *this); }];
let verifier = ?;
let parser = [{ return ::parse$cppClass(parser, result); }];
@@ -65,7 +67,7 @@ def Linalg_SparseTensorFromPointerOp :
}
def Linalg_SparseTensorToPointersMemRefOp :
- Linalg_SparseOp<"sparse_pointers">,
+ Linalg_SparseOp<"sparse_pointers", [NoSideEffect]>,
Arguments<(ins AnyTensor:$tensor, Index:$dim)>,
Results<(outs AnyStridedMemRefOfRank<1>:$result)> {
let summary = "Extract pointers array at given dimension from a tensor";
@@ -89,7 +91,7 @@ def Linalg_SparseTensorToPointersMemRefOp :
}
def Linalg_SparseTensorToIndicesMemRefOp :
- Linalg_SparseOp<"sparse_indices">,
+ Linalg_SparseOp<"sparse_indices", [NoSideEffect]>,
Arguments<(ins AnyTensor:$tensor, Index:$dim)>,
Results<(outs AnyStridedMemRefOfRank<1>:$result)> {
let summary = "Extract indices array at given dimension from a tensor";
@@ -113,7 +115,7 @@ def Linalg_SparseTensorToIndicesMemRefOp :
}
def Linalg_SparseTensorToValuesMemRefOp :
- Linalg_SparseOp<"sparse_values">,
+ Linalg_SparseOp<"sparse_values", [NoSideEffect]>,
Arguments<(ins AnyTensor:$tensor)>,
Results<(outs AnyStridedMemRefOfRank<1>:$result)> {
let summary = "Extract numerical values array from a tensor";
diff --git a/mlir/integration_test/Sparse/CPU/sparse_sampled_matmul.mlir b/mlir/integration_test/Sparse/CPU/sparse_sampled_matmul.mlir
index 9d2dd7b9575f..5f2b199fb704 100644
--- a/mlir/integration_test/Sparse/CPU/sparse_sampled_matmul.mlir
+++ b/mlir/integration_test/Sparse/CPU/sparse_sampled_matmul.mlir
@@ -1,9 +1,9 @@
// RUN: mlir-opt %s \
// RUN: --test-sparsification="lower ptr-type=2 ind-type=2 fast-output" \
-// RUN: --convert-linalg-to-loops \
+// RUN: --convert-linalg-to-loops --convert-vector-to-scf --convert-scf-to-std \
// RUN: --func-bufferize --tensor-constant-bufferize --tensor-bufferize \
// RUN: --std-bufferize --finalizing-bufferize \
-// RUN: --convert-scf-to-std --convert-vector-to-llvm --convert-std-to-llvm | \
+// RUN: --convert-vector-to-llvm --convert-std-to-llvm | \
// RUN: TENSOR0="%mlir_integration_test_dir/data/test.mtx" \
// RUN: mlir-cpu-runner \
// RUN: -e entry -entry-point-result=void \
diff --git a/mlir/integration_test/Sparse/CPU/sparse_sum.mlir b/mlir/integration_test/Sparse/CPU/sparse_sum.mlir
index c6fea1a5d12b..78e5940eae60 100644
--- a/mlir/integration_test/Sparse/CPU/sparse_sum.mlir
+++ b/mlir/integration_test/Sparse/CPU/sparse_sum.mlir
@@ -1,8 +1,9 @@
// RUN: mlir-opt %s \
// RUN: --test-sparsification="lower" \
-// RUN: --convert-linalg-to-loops \
-// RUN: --func-bufferize --tensor-constant-bufferize --tensor-bufferize --finalizing-bufferize \
-// RUN: --convert-scf-to-std --convert-vector-to-llvm --convert-std-to-llvm | \
+// RUN: --convert-linalg-to-loops --convert-vector-to-scf --convert-scf-to-std \
+// RUN: --func-bufferize --tensor-constant-bufferize --tensor-bufferize \
+// RUN: --std-bufferize --finalizing-bufferize \
+// RUN: --convert-vector-to-llvm --convert-std-to-llvm | \
// RUN: TENSOR0="%mlir_integration_test_dir/data/test.mtx" \
// RUN: mlir-cpu-runner \
// RUN: -e entry -entry-point-result=void \
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Sparsification.cpp b/mlir/lib/Dialect/Linalg/Transforms/Sparsification.cpp
index 1898ba688c15..5306e6f40fa3 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Sparsification.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Sparsification.cpp
@@ -175,33 +175,25 @@ class Merger {
unsigned p0 = latSets[s0][0];
for (unsigned p1 : latSets[s0]) {
bool add = true;
- llvm::BitVector simple = simplifyCond(s0, p1);
if (p0 != p1) {
// Is this a straightforward copy?
unsigned e = latPoints[p1].exp;
if (exp(e).kind == Kind::kTensor && exp(e).e0 == outTensor)
continue;
- // Only dense exhausted?
- llvm::BitVector tmp = latPoints[p1].bits;
- tmp ^= latPoints[p0].bits;
- if (!hasAnyDimOf(tmp, Dim::kSparse))
- continue;
- // Duplication of an earlier conjunction?
+ // Conjunction already covered?
for (unsigned p2 : latSets[s]) {
- tmp = simple;
- tmp ^= latPoints[p2].simple;
- if (tmp.count() == 0) {
+ if (onlyDenseDiff(p2, p1)) {
add = false;
break;
}
}
assert(!add || latGT(p0, p1));
}
- if (add) {
+ if (add)
latSets[s].push_back(p1);
- latPoints[latSets[s].back()].simple = simple;
- }
}
+ for (unsigned p : latSets[s])
+ latPoints[p].simple = simplifyCond(s, p);
return s;
}
@@ -215,15 +207,8 @@ class Merger {
bool isSingleton = true;
for (unsigned p1 : latSets[s]) {
if (p0 != p1 && latGT(p0, p1)) {
- unsigned e = latPoints[p1].exp;
- if (exp(e).kind == Kind::kTensor && exp(e).e0 == outTensor)
- continue;
- llvm::BitVector tmp = latPoints[p1].bits;
- tmp ^= latPoints[p0].bits;
- if (hasAnyDimOf(tmp, Dim::kSparse)) {
- isSingleton = false;
- break;
- }
+ isSingleton = false;
+ break;
}
}
// Now apply the two basic rules.
@@ -253,6 +238,13 @@ class Merger {
return false;
}
+  /// Returns true if Li and Lj only differ in dense.
+ bool onlyDenseDiff(unsigned i, unsigned j) {
+ llvm::BitVector tmp = latPoints[j].bits;
+ tmp ^= latPoints[i].bits;
+ return !hasAnyDimOf(tmp, Dim::kSparse);
+ }
+
/// Bit translation.
unsigned tensor(unsigned b) const { return b % numTensors; }
unsigned index(unsigned b) const { return b / numTensors; }
@@ -274,12 +266,12 @@ class Merger {
return false;
}
- // Returns true if tensor has any sparse dimension.
+ /// Returns true if tensor has any sparse dimension.
bool isSparseTensor(unsigned t) const {
return llvm::any_of(dims[t], [](Dim d) { return d == Dim::kSparse; });
}
- // Setter
+ /// Setter
void setDim(unsigned t, unsigned i, Dim d) { dims[t][i] = d; }
/// Getters.
@@ -1193,12 +1185,6 @@ static void genStmt(Merger &merger, CodeGen &codegen, PatternRewriter &rewriter,
unsigned lj = merger.set(lts)[j];
unsigned ej = merger.lat(lj).exp;
if (li == lj || merger.latGT(li, lj)) {
- if (li != lj) {
- llvm::BitVector tmp = merger.lat(lj).bits;
- tmp ^= merger.lat(li).bits;
- if (!merger.hasAnyDimOf(tmp, Dim::kSparse))
- continue; // only dense exhausted within if/else
- }
// Recurse into body of each branch.
if (isWhile) {
scf::IfOp ifOp =
diff --git a/mlir/test/Dialect/Linalg/sparse_1d.mlir b/mlir/test/Dialect/Linalg/sparse_1d.mlir
index 5730b89d9a1c..320dd9596565 100644
--- a/mlir/test/Dialect/Linalg/sparse_1d.mlir
+++ b/mlir/test/Dialect/Linalg/sparse_1d.mlir
@@ -1153,3 +1153,190 @@ func @sum_reduction_inv(%arga: tensor<16xf32>,
} -> tensor<f32>
return %0 : tensor<f32>
}
+
+#trait_four_tensors = {
+ indexing_maps = [
+ affine_map<(i) -> (i)>, // A
+ affine_map<(i) -> (i)>, // B
+ affine_map<(i) -> (i)>, // C
+ affine_map<(i) -> (i)>, // D
+ affine_map<(i) -> (i)> // X (out)
+ ],
+ sparse = [
+ ["D"], // A
+ ["S"], // B
+ ["D"], // C
+ ["S"], // D
+ ["D"] // X
+ ],
+ iterator_types = ["parallel"],
+ doc = "X(i) = A(i) + B(i) + C(i) + D(i)"
+}
+
+// CHECK-LABEL: func @four_tensors_op(
+// CHECK-SAME: %[[VAL_0:.*0]]: tensor<?xf64>,
+// CHECK-SAME: %[[VAL_1:.*1]]: tensor<?xf64>,
+// CHECK-SAME: %[[VAL_2:.*2]]: tensor<?xf64>,
+// CHECK-SAME: %[[VAL_3:.*3]]: tensor<?xf64>,
+// CHECK-SAME: %[[VAL_4:.*4]]: tensor<?xf64>) -> tensor<?xf64> {
+// CHECK: %[[VAL_5:.*]] = constant 0 : index
+// CHECK: %[[VAL_6:.*]] = constant true
+// CHECK: %[[VAL_7:.*]] = constant 1 : index
+// CHECK: %[[VAL_8:.*]] = tensor_to_memref %[[VAL_0]] : memref<?xf64>
+// CHECK: %[[VAL_9:.*]] = linalg.sparse_pointers %[[VAL_1]], %[[VAL_5]] : tensor<?xf64> to memref<?xindex>
+// CHECK: %[[VAL_10:.*]] = linalg.sparse_indices %[[VAL_1]], %[[VAL_5]] : tensor<?xf64> to memref<?xindex>
+// CHECK: %[[VAL_11:.*]] = linalg.sparse_values %[[VAL_1]] : tensor<?xf64> to memref<?xf64>
+// CHECK: %[[VAL_12:.*]] = tensor_to_memref %[[VAL_2]] : memref<?xf64>
+// CHECK: %[[VAL_13:.*]] = linalg.sparse_pointers %[[VAL_3]], %[[VAL_5]] : tensor<?xf64> to memref<?xindex>
+// CHECK: %[[VAL_14:.*]] = linalg.sparse_indices %[[VAL_3]], %[[VAL_5]] : tensor<?xf64> to memref<?xindex>
+// CHECK: %[[VAL_15:.*]] = linalg.sparse_values %[[VAL_3]] : tensor<?xf64> to memref<?xf64>
+// CHECK: %[[VAL_16:.*]] = dim %[[VAL_4]], %[[VAL_5]] : tensor<?xf64>
+// CHECK: %[[VAL_17:.*]] = tensor_to_memref %[[VAL_4]] : memref<?xf64>
+// CHECK: %[[VAL_18:.*]] = alloc(%[[VAL_16]]) : memref<?xf64>
+// CHECK: linalg.copy(%[[VAL_17]], %[[VAL_18]]) : memref<?xf64>, memref<?xf64>
+// CHECK: %[[VAL_19:.*]] = load %[[VAL_9]]{{\[}}%[[VAL_5]]] : memref<?xindex>
+// CHECK: %[[VAL_20:.*]] = load %[[VAL_9]]{{\[}}%[[VAL_7]]] : memref<?xindex>
+// CHECK: %[[VAL_21:.*]] = load %[[VAL_13]]{{\[}}%[[VAL_5]]] : memref<?xindex>
+// CHECK: %[[VAL_22:.*]] = load %[[VAL_13]]{{\[}}%[[VAL_7]]] : memref<?xindex>
+// CHECK: %[[VAL_23:.*]]:3 = scf.while (%[[VAL_24:.*]] = %[[VAL_19]], %[[VAL_25:.*]] = %[[VAL_21]], %[[VAL_26:.*]] = %[[VAL_5]]) : (index, index, index) -> (index, index, index) {
+// CHECK: %[[VAL_27:.*]] = cmpi ult, %[[VAL_24]], %[[VAL_20]] : index
+// CHECK: %[[VAL_28:.*]] = cmpi ult, %[[VAL_25]], %[[VAL_22]] : index
+// CHECK: %[[VAL_29:.*]] = and %[[VAL_27]], %[[VAL_28]] : i1
+// CHECK: scf.condition(%[[VAL_29]]) %[[VAL_24]], %[[VAL_25]], %[[VAL_26]] : index, index, index
+// CHECK: } do {
+// CHECK: ^bb0(%[[VAL_30:.*]]: index, %[[VAL_31:.*]]: index, %[[VAL_32:.*]]: index):
+// CHECK: %[[VAL_33:.*]] = load %[[VAL_10]]{{\[}}%[[VAL_30]]] : memref<?xindex>
+// CHECK: %[[VAL_34:.*]] = load %[[VAL_14]]{{\[}}%[[VAL_31]]] : memref<?xindex>
+// CHECK: %[[VAL_35:.*]] = cmpi eq, %[[VAL_33]], %[[VAL_32]] : index
+// CHECK: %[[VAL_36:.*]] = cmpi eq, %[[VAL_34]], %[[VAL_32]] : index
+// CHECK: %[[VAL_37:.*]] = and %[[VAL_35]], %[[VAL_36]] : i1
+// CHECK: scf.if %[[VAL_37]] {
+// CHECK: %[[VAL_38:.*]] = load %[[VAL_8]]{{\[}}%[[VAL_32]]] : memref<?xf64>
+// CHECK: %[[VAL_39:.*]] = load %[[VAL_11]]{{\[}}%[[VAL_30]]] : memref<?xf64>
+// CHECK: %[[VAL_40:.*]] = addf %[[VAL_38]], %[[VAL_39]] : f64
+// CHECK: %[[VAL_41:.*]] = load %[[VAL_12]]{{\[}}%[[VAL_32]]] : memref<?xf64>
+// CHECK: %[[VAL_42:.*]] = load %[[VAL_15]]{{\[}}%[[VAL_31]]] : memref<?xf64>
+// CHECK: %[[VAL_43:.*]] = addf %[[VAL_41]], %[[VAL_42]] : f64
+// CHECK: %[[VAL_44:.*]] = addf %[[VAL_40]], %[[VAL_43]] : f64
+// CHECK: store %[[VAL_44]], %[[VAL_18]]{{\[}}%[[VAL_32]]] : memref<?xf64>
+// CHECK: } else {
+// CHECK: %[[VAL_45:.*]] = cmpi eq, %[[VAL_33]], %[[VAL_32]] : index
+// CHECK: scf.if %[[VAL_45]] {
+// CHECK: %[[VAL_46:.*]] = load %[[VAL_8]]{{\[}}%[[VAL_32]]] : memref<?xf64>
+// CHECK: %[[VAL_47:.*]] = load %[[VAL_11]]{{\[}}%[[VAL_30]]] : memref<?xf64>
+// CHECK: %[[VAL_48:.*]] = addf %[[VAL_46]], %[[VAL_47]] : f64
+// CHECK: %[[VAL_49:.*]] = load %[[VAL_12]]{{\[}}%[[VAL_32]]] : memref<?xf64>
+// CHECK: %[[VAL_50:.*]] = addf %[[VAL_48]], %[[VAL_49]] : f64
+// CHECK: store %[[VAL_50]], %[[VAL_18]]{{\[}}%[[VAL_32]]] : memref<?xf64>
+// CHECK: } else {
+// CHECK: %[[VAL_51:.*]] = cmpi eq, %[[VAL_34]], %[[VAL_32]] : index
+// CHECK: scf.if %[[VAL_51]] {
+// CHECK: %[[VAL_52:.*]] = load %[[VAL_8]]{{\[}}%[[VAL_32]]] : memref<?xf64>
+// CHECK: %[[VAL_53:.*]] = load %[[VAL_12]]{{\[}}%[[VAL_32]]] : memref<?xf64>
+// CHECK: %[[VAL_54:.*]] = load %[[VAL_15]]{{\[}}%[[VAL_31]]] : memref<?xf64>
+// CHECK: %[[VAL_55:.*]] = addf %[[VAL_53]], %[[VAL_54]] : f64
+// CHECK: %[[VAL_56:.*]] = addf %[[VAL_52]], %[[VAL_55]] : f64
+// CHECK: store %[[VAL_56]], %[[VAL_18]]{{\[}}%[[VAL_32]]] : memref<?xf64>
+// CHECK: } else {
+// CHECK: scf.if %[[VAL_6]] {
+// CHECK: %[[VAL_57:.*]] = load %[[VAL_8]]{{\[}}%[[VAL_32]]] : memref<?xf64>
+// CHECK: %[[VAL_58:.*]] = load %[[VAL_12]]{{\[}}%[[VAL_32]]] : memref<?xf64>
+// CHECK: %[[VAL_59:.*]] = addf %[[VAL_57]], %[[VAL_58]] : f64
+// CHECK: store %[[VAL_59]], %[[VAL_18]]{{\[}}%[[VAL_32]]] : memref<?xf64>
+// CHECK: } else {
+// CHECK: }
+// CHECK: }
+// CHECK: }
+// CHECK: }
+// CHECK: %[[VAL_60:.*]] = cmpi eq, %[[VAL_33]], %[[VAL_32]] : index
+// CHECK: %[[VAL_61:.*]] = addi %[[VAL_30]], %[[VAL_7]] : index
+// CHECK: %[[VAL_62:.*]] = select %[[VAL_60]], %[[VAL_61]], %[[VAL_30]] : index
+// CHECK: %[[VAL_63:.*]] = cmpi eq, %[[VAL_34]], %[[VAL_32]] : index
+// CHECK: %[[VAL_64:.*]] = addi %[[VAL_31]], %[[VAL_7]] : index
+// CHECK: %[[VAL_65:.*]] = select %[[VAL_63]], %[[VAL_64]], %[[VAL_31]] : index
+// CHECK: %[[VAL_66:.*]] = addi %[[VAL_32]], %[[VAL_7]] : index
+// CHECK: scf.yield %[[VAL_62]], %[[VAL_65]], %[[VAL_66]] : index, index, index
+// CHECK: }
+// CHECK: %[[VAL_67:.*]]:2 = scf.while (%[[VAL_68:.*]] = %[[VAL_69:.*]]#0, %[[VAL_70:.*]] = %[[VAL_69]]#2) : (index, index) -> (index, index) {
+// CHECK: %[[VAL_71:.*]] = cmpi ult, %[[VAL_68]], %[[VAL_20]] : index
+// CHECK: scf.condition(%[[VAL_71]]) %[[VAL_68]], %[[VAL_70]] : index, index
+// CHECK: } do {
+// CHECK: ^bb0(%[[VAL_72:.*]]: index, %[[VAL_73:.*]]: index):
+// CHECK: %[[VAL_74:.*]] = load %[[VAL_10]]{{\[}}%[[VAL_72]]] : memref<?xindex>
+// CHECK: %[[VAL_75:.*]] = cmpi eq, %[[VAL_74]], %[[VAL_73]] : index
+// CHECK: scf.if %[[VAL_75]] {
+// CHECK: %[[VAL_76:.*]] = load %[[VAL_8]]{{\[}}%[[VAL_73]]] : memref<?xf64>
+// CHECK: %[[VAL_77:.*]] = load %[[VAL_11]]{{\[}}%[[VAL_72]]] : memref<?xf64>
+// CHECK: %[[VAL_78:.*]] = addf %[[VAL_76]], %[[VAL_77]] : f64
+// CHECK: %[[VAL_79:.*]] = load %[[VAL_12]]{{\[}}%[[VAL_73]]] : memref<?xf64>
+// CHECK: %[[VAL_80:.*]] = addf %[[VAL_78]], %[[VAL_79]] : f64
+// CHECK: store %[[VAL_80]], %[[VAL_18]]{{\[}}%[[VAL_73]]] : memref<?xf64>
+// CHECK: } else {
+// CHECK: scf.if %[[VAL_6]] {
+// CHECK: %[[VAL_81:.*]] = load %[[VAL_8]]{{\[}}%[[VAL_73]]] : memref<?xf64>
+// CHECK: %[[VAL_82:.*]] = load %[[VAL_12]]{{\[}}%[[VAL_73]]] : memref<?xf64>
+// CHECK: %[[VAL_83:.*]] = addf %[[VAL_81]], %[[VAL_82]] : f64
+// CHECK: store %[[VAL_83]], %[[VAL_18]]{{\[}}%[[VAL_73]]] : memref<?xf64>
+// CHECK: } else {
+// CHECK: }
+// CHECK: }
+// CHECK: %[[VAL_84:.*]] = cmpi eq, %[[VAL_74]], %[[VAL_73]] : index
+// CHECK: %[[VAL_85:.*]] = addi %[[VAL_72]], %[[VAL_7]] : index
+// CHECK: %[[VAL_86:.*]] = select %[[VAL_84]], %[[VAL_85]], %[[VAL_72]] : index
+// CHECK: %[[VAL_87:.*]] = addi %[[VAL_73]], %[[VAL_7]] : index
+// CHECK: scf.yield %[[VAL_86]], %[[VAL_87]] : index, index
+// CHECK: }
+// CHECK: %[[VAL_88:.*]]:2 = scf.while (%[[VAL_89:.*]] = %[[VAL_90:.*]]#1, %[[VAL_91:.*]] = %[[VAL_92:.*]]#1) : (index, index) -> (index, index) {
+// CHECK: %[[VAL_93:.*]] = cmpi ult, %[[VAL_89]], %[[VAL_22]] : index
+// CHECK: scf.condition(%[[VAL_93]]) %[[VAL_89]], %[[VAL_91]] : index, index
+// CHECK: } do {
+// CHECK: ^bb0(%[[VAL_94:.*]]: index, %[[VAL_95:.*]]: index):
+// CHECK: %[[VAL_96:.*]] = load %[[VAL_14]]{{\[}}%[[VAL_94]]] : memref<?xindex>
+// CHECK: %[[VAL_97:.*]] = cmpi eq, %[[VAL_96]], %[[VAL_95]] : index
+// CHECK: scf.if %[[VAL_97]] {
+// CHECK: %[[VAL_98:.*]] = load %[[VAL_8]]{{\[}}%[[VAL_95]]] : memref<?xf64>
+// CHECK: %[[VAL_99:.*]] = load %[[VAL_12]]{{\[}}%[[VAL_95]]] : memref<?xf64>
+// CHECK: %[[VAL_100:.*]] = load %[[VAL_15]]{{\[}}%[[VAL_94]]] : memref<?xf64>
+// CHECK: %[[VAL_101:.*]] = addf %[[VAL_99]], %[[VAL_100]] : f64
+// CHECK: %[[VAL_102:.*]] = addf %[[VAL_98]], %[[VAL_101]] : f64
+// CHECK: store %[[VAL_102]], %[[VAL_18]]{{\[}}%[[VAL_95]]] : memref<?xf64>
+// CHECK: } else {
+// CHECK: scf.if %[[VAL_6]] {
+// CHECK: %[[VAL_103:.*]] = load %[[VAL_8]]{{\[}}%[[VAL_95]]] : memref<?xf64>
+// CHECK: %[[VAL_104:.*]] = load %[[VAL_12]]{{\[}}%[[VAL_95]]] : memref<?xf64>
+// CHECK: %[[VAL_105:.*]] = addf %[[VAL_103]], %[[VAL_104]] : f64
+// CHECK: store %[[VAL_105]], %[[VAL_18]]{{\[}}%[[VAL_95]]] : memref<?xf64>
+// CHECK: } else {
+// CHECK: }
+// CHECK: }
+// CHECK: %[[VAL_106:.*]] = cmpi eq, %[[VAL_96]], %[[VAL_95]] : index
+// CHECK: %[[VAL_107:.*]] = addi %[[VAL_94]], %[[VAL_7]] : index
+// CHECK: %[[VAL_108:.*]] = select %[[VAL_106]], %[[VAL_107]], %[[VAL_94]] : index
+// CHECK: %[[VAL_109:.*]] = addi %[[VAL_95]], %[[VAL_7]] : index
+// CHECK: scf.yield %[[VAL_108]], %[[VAL_109]] : index, index
+// CHECK: }
+// CHECK: scf.for %[[VAL_110:.*]] = %[[VAL_111:.*]]#1 to %[[VAL_16]] step %[[VAL_7]] {
+// CHECK: %[[VAL_112:.*]] = load %[[VAL_8]]{{\[}}%[[VAL_110]]] : memref<?xf64>
+// CHECK: %[[VAL_113:.*]] = load %[[VAL_12]]{{\[}}%[[VAL_110]]] : memref<?xf64>
+// CHECK: %[[VAL_114:.*]] = addf %[[VAL_112]], %[[VAL_113]] : f64
+// CHECK: store %[[VAL_114]], %[[VAL_18]]{{\[}}%[[VAL_110]]] : memref<?xf64>
+// CHECK: }
+// CHECK: %[[VAL_115:.*]] = tensor_load %[[VAL_18]] : memref<?xf64>
+// CHECK: return %[[VAL_115]] : tensor<?xf64>
+// CHECK: }
+func @four_tensors_op(%arga: tensor<?xf64>,
+ %argb: tensor<?xf64>,
+ %argc: tensor<?xf64>,
+ %argd: tensor<?xf64>,
+ %argx: tensor<?xf64>) -> tensor<?xf64> {
+ %r = linalg.generic #trait_four_tensors
+ ins(%arga, %argb, %argc, %argd: tensor<?xf64>, tensor<?xf64>, tensor<?xf64>, tensor<?xf64>)
+ outs(%argx: tensor<?xf64>) {
+ ^bb(%a: f64, %b: f64, %c: f64, %d: f64, %x: f64):
+ %0 = addf %a, %b : f64
+ %1 = addf %c, %d : f64
+ %2 = addf %0, %1 : f64
+ linalg.yield %2 : f64
+ } -> tensor<?xf64>
+ return %r : tensor<?xf64>
+}