[llvm-branch-commits] [mlir] 74cd9e5 - [mlir][sparse] hoist loop invariant tensor loads in sparse compiler

Aart Bik via llvm-branch-commits llvm-branch-commits at lists.llvm.org
Mon Dec 7 12:04:53 PST 2020


Author: Aart Bik
Date: 2020-12-07T11:59:48-08:00
New Revision: 74cd9e587d80063381242006d0690231d756aa7a

URL: https://github.com/llvm/llvm-project/commit/74cd9e587d80063381242006d0690231d756aa7a
DIFF: https://github.com/llvm/llvm-project/commit/74cd9e587d80063381242006d0690231d756aa7a.diff

LOG: [mlir][sparse] hoist loop invariant tensor loads in sparse compiler

After bufferization, the backend has much more trouble hoisting loop-invariant
loads out of the loops generated by the sparse compiler. Therefore, the
hoisting is now done during sparse code generation itself. Note that we don't
bother hoisting derived invariant expressions on SSA values, since the backend
already handles those very well.
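
For example, for a kernel like x(i,j) = a(i) * b(j) over dense memrefs, the
generated loops change roughly as follows. This is a simplified sketch with
made-up value names (the index constants %c0, %c1, etc. are assumed to be
defined earlier), not the exact test output:

  // Before: a(i) is re-loaded on every iteration of the inner j-loop.
  scf.for %i = %c0 to %c10 step %c1 {
    scf.for %j = %c0 to %c20 step %c1 {
      %a_i = load %a[%i] : memref<10xf32>
      %b_j = load %b[%j] : memref<20xf32>
      %p = mulf %a_i, %b_j : f32
      store %p, %x[%i, %j] : memref<10x20xf32>
    }
  }

  // After: the invariant load is hoisted to the highest loop level at which
  // all of its indices are available.
  scf.for %i = %c0 to %c10 step %c1 {
    %a_i = load %a[%i] : memref<10xf32>
    scf.for %j = %c0 to %c20 step %c1 {
      %b_j = load %b[%j] : memref<20xf32>
      %p = mulf %a_i, %b_j : f32
      store %p, %x[%i, %j] : memref<10x20xf32>
    }
  }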

Still TBD: scalarize reductions to avoid load-add-store cycles
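
For reference, a matvec-like reduction x(i) = SUM_k A(i,k) * b(k) currently
still yields a load-add-store cycle on x in the innermost loop; the TODO in
genInvariants hints at carrying the accumulator through scf.yield instead.
Below is an illustrative sketch of both forms, with made-up names and dense
memrefs; the scalarized version is only one possible shape, not something
this patch implements:

  // Current: x(i) is re-read and written back on every iteration over k.
  scf.for %k = %c0 to %c32 step %c1 {
    %x_i = load %x[%i] : memref<16xf32>
    %a_ik = load %A[%i, %k] : memref<16x32xf32>
    %b_k = load %b[%k] : memref<32xf32>
    %p = mulf %a_ik, %b_k : f32
    %s = addf %x_i, %p : f32
    store %s, %x[%i] : memref<16xf32>
  }

  // TBD: scalarized form, with the running sum carried as a loop-carried
  // value and written back once after the loop.
  %init = load %x[%i] : memref<16xf32>
  %sum = scf.for %k = %c0 to %c32 step %c1 iter_args(%acc = %init) -> (f32) {
    %a_ik = load %A[%i, %k] : memref<16x32xf32>
    %b_k = load %b[%k] : memref<32xf32>
    %p = mulf %a_ik, %b_k : f32
    %s = addf %acc, %p : f32
    scf.yield %s : f32
  }
  store %sum, %x[%i] : memref<16xf32>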

Reviewed By: penpornk

Differential Revision: https://reviews.llvm.org/D92534

Added: 
    

Modified: 
    mlir/lib/Dialect/Linalg/Transforms/Sparsification.cpp
    mlir/test/Dialect/Linalg/sparse_2d.mlir
    mlir/test/Dialect/Linalg/sparse_3d.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/Linalg/Transforms/Sparsification.cpp b/mlir/lib/Dialect/Linalg/Transforms/Sparsification.cpp
index 07a3e1569622..cfdb371e3234 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Sparsification.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Sparsification.cpp
@@ -59,14 +59,21 @@ enum class Kind { kTensor, kInvariant, kMulF, kMulI, kAddF, kAddI };
 /// children tensor expressions.
 struct TensorExp {
   TensorExp(Kind k, unsigned x, unsigned y, Value v)
-      : kind(k), e0(x), e1(y), val(v) {}
+      : kind(k), e0(x), e1(y), val(v) {
+    assert((kind == Kind::kTensor && e0 != -1u && e1 == -1u && !val) ||
+           (kind == Kind::kInvariant && e0 == -1u && e1 == -1u && val) ||
+           (kind >= Kind::kMulF && e0 != -1u && e1 != -1u && !val));
+  }
   Kind kind;
+  /// Indices of the child expressions (if any).
   unsigned e0;
   unsigned e1;
+  /// Direct link to IR for an invariant. During code generation,
+  /// this field is used to cache "hoisted" loop invariant tensor loads.
   Value val;
 };
 
-/// Lattice point. Each lattice point consist of a conjunction of tensor
+/// Lattice point. Each lattice point consists of a conjunction of tensor
 /// loop indices (encoded in a bitvector) and the index of the corresponding
 /// tensor expression.
 struct LatPoint {
@@ -74,7 +81,9 @@ struct LatPoint {
     bits.set(b);
   }
   LatPoint(const llvm::BitVector &b, unsigned e) : bits(b), exp(e) {}
+  /// Conjunction of tensor loop indices as bitvector.
   llvm::BitVector bits;
+  /// Index of the tensor expression.
   unsigned exp;
 };
 
@@ -502,8 +511,16 @@ static void genBuffers(Merger &merger, CodeGen &codegen,
 /// Generates a load on a dense or sparse tensor.
 static Value genTensorLoad(Merger &merger, CodeGen &codegen,
                            PatternRewriter &rewriter, linalg::GenericOp op,
-                           unsigned tensor) {
+                           unsigned exp) {
+  // Test if the load was hoisted to a higher loop nest.
+  Value val = merger.exp(exp).val;
+  if (val) {
+    merger.exp(exp).val = Value(); // reset
+    return val;
+  }
+  // Actual load.
   SmallVector<Value, 4> args;
+  unsigned tensor = merger.exp(exp).e0;
   auto map = op.getIndexingMap(tensor);
   bool sparse = false;
   for (unsigned i = 0, m = map.getNumResults(); i < m; ++i) {
@@ -515,7 +532,9 @@ static Value genTensorLoad(Merger &merger, CodeGen &codegen,
       args.push_back(codegen.pidxs[tensor][idx]); // position index
     }
   }
-  return rewriter.create<LoadOp>(op.getLoc(), codegen.buffers[tensor], args);
+  Location loc = op.getLoc();
+  Value ptr = codegen.buffers[tensor];
+  return rewriter.create<LoadOp>(loc, ptr, args);
 }
 
 /// Generates a store on a dense tensor.
@@ -528,25 +547,33 @@ static void genTensorStore(Merger &merger, CodeGen &codegen,
     unsigned idx = map.getDimPosition(i);
     args.push_back(codegen.loops[idx]); // universal dense index
   }
-  rewriter.create<StoreOp>(op.getLoc(), rhs, codegen.buffers[tensor], args);
+  Location loc = op.getLoc();
+  Value ptr = codegen.buffers[tensor];
+  rewriter.create<StoreOp>(loc, rhs, ptr, args);
 }
 
 /// Generates a pointer/index load from the sparse storage scheme.
-static Value genIntLoad(PatternRewriter &rewriter, Location loc, Value ptr,
-                        Value s) {
+static Value genLoad(PatternRewriter &rewriter, Location loc, Value ptr,
+                     Value s) {
   Value load = rewriter.create<LoadOp>(loc, ptr, s);
   return load.getType().isa<IndexType>()
              ? load
              : rewriter.create<IndexCastOp>(loc, load, rewriter.getIndexType());
 }
 
+/// Generates an invariant value.
+static Value genInvariantValue(Merger &merger, CodeGen &codegen,
+                               PatternRewriter &rewriter, unsigned exp) {
+  return merger.exp(exp).val;
+}
+
 /// Recursively generates tensor expression.
 static Value genExp(Merger &merger, CodeGen &codegen, PatternRewriter &rewriter,
                     linalg::GenericOp op, unsigned exp) {
   if (merger.exp(exp).kind == Kind::kTensor)
-    return genTensorLoad(merger, codegen, rewriter, op, merger.exp(exp).e0);
+    return genTensorLoad(merger, codegen, rewriter, op, exp);
   else if (merger.exp(exp).kind == Kind::kInvariant)
-    return merger.exp(exp).val;
+    return genInvariantValue(merger, codegen, rewriter, exp);
   Value v0 = genExp(merger, codegen, rewriter, op, merger.exp(exp).e0);
   Value v1 = genExp(merger, codegen, rewriter, op, merger.exp(exp).e1);
   switch (merger.exp(exp).kind) {
@@ -564,6 +591,33 @@ static Value genExp(Merger &merger, CodeGen &codegen, PatternRewriter &rewriter,
   }
 }
 
+/// Hoists loop invariant tensor loads for which indices have been exhausted.
+static void genInvariants(Merger &merger, CodeGen &codegen,
+                          PatternRewriter &rewriter, linalg::GenericOp op,
+                          unsigned exp) {
+  if (merger.exp(exp).kind == Kind::kTensor) {
+    unsigned lhs = op.getNumInputsAndOutputs() - 1;
+    unsigned tensor = merger.exp(exp).e0;
+    if (tensor == lhs)
+      return; // TODO: scalarize reduction as well (using scf.yield)
+    auto map = op.getIndexingMap(tensor);
+    for (unsigned i = 0, m = map.getNumResults(); i < m; ++i) {
+      unsigned idx = map.getDimPosition(i);
+      if (!codegen.loops[idx])
+        return; // still in play
+    }
+    // All exhausted at this level.
+    merger.exp(exp).val = genTensorLoad(merger, codegen, rewriter, op, exp);
+
+  } else if (merger.exp(exp).kind != Kind::kInvariant) {
+    // Traverse into the binary operations. Note that we only hoist
+    // tensor loads, since subsequent MLIR/LLVM passes know how to
+    // deal with all other kinds of derived loop invariants.
+    genInvariants(merger, codegen, rewriter, op, merger.exp(exp).e0);
+    genInvariants(merger, codegen, rewriter, op, merger.exp(exp).e1);
+  }
+}
+
 /// Generates initialization code for the subsequent loop sequence at
 /// current index level. Returns true if the loop sequence needs to
 /// maintain the universal index.
@@ -590,9 +644,9 @@ static bool genInit(Merger &merger, CodeGen &codegen, PatternRewriter &rewriter,
         Value one = rewriter.create<ConstantIndexOp>(loc, 1);
         Value p0 = (pat == 0) ? rewriter.create<ConstantIndexOp>(loc, 0)
                               : codegen.pidxs[tensor][topSort[pat - 1]];
-        codegen.pidxs[tensor][idx] = genIntLoad(rewriter, loc, ptr, p0);
+        codegen.pidxs[tensor][idx] = genLoad(rewriter, loc, ptr, p0);
         Value p1 = rewriter.create<AddIOp>(loc, p0, one);
-        codegen.highs[tensor][idx] = genIntLoad(rewriter, loc, ptr, p1);
+        codegen.highs[tensor][idx] = genLoad(rewriter, loc, ptr, p1);
       } else {
         // Dense index still in play.
         needsUniv = true;
@@ -608,7 +662,8 @@ static bool genInit(Merger &merger, CodeGen &codegen, PatternRewriter &rewriter,
 /// Generates a for-loop on a single index.
 static Operation *genFor(Merger &merger, CodeGen &codegen,
                          PatternRewriter &rewriter, linalg::GenericOp op,
-                         bool isOuter, unsigned idx, llvm::BitVector &indices) {
+                         bool isOuter, bool isInner, unsigned idx,
+                         llvm::BitVector &indices) {
   unsigned fb = indices.find_first();
   unsigned tensor = merger.tensor(fb);
   assert(idx == merger.index(fb));
@@ -725,10 +780,15 @@ static Operation *genWhile(Merger &merger, CodeGen &codegen,
 /// singleton iteration or co-iteration over the given conjunction.
 static Operation *genLoop(Merger &merger, CodeGen &codegen,
                           PatternRewriter &rewriter, linalg::GenericOp op,
-                          bool isOuter, unsigned idx, bool needsUniv,
-                          llvm::BitVector &indices) {
-  if (indices.count() == 1)
-    return genFor(merger, codegen, rewriter, op, isOuter, idx, indices);
+                          std::vector<unsigned> &topSort, unsigned at,
+                          bool needsUniv, llvm::BitVector &indices) {
+  unsigned idx = topSort[at];
+  if (indices.count() == 1) {
+    bool isOuter = at == 0;
+    bool isInner = at == topSort.size() - 1;
+    return genFor(merger, codegen, rewriter, op, isOuter, isInner, idx,
+                  indices);
+  }
   return genWhile(merger, codegen, rewriter, op, idx, needsUniv, indices);
 }
 
@@ -749,7 +809,7 @@ static void genLocals(Merger &merger, CodeGen &codegen,
       assert(idx == merger.index(b));
       Value ptr = codegen.indices[tensor][idx];
       Value s = codegen.pidxs[tensor][idx];
-      Value load = genIntLoad(rewriter, loc, ptr, s);
+      Value load = genLoad(rewriter, loc, ptr, s);
       codegen.idxs[tensor][idx] = load;
       if (!needsUniv) {
         if (min) {
@@ -886,6 +946,7 @@ static void genStmt(Merger &merger, CodeGen &codegen, PatternRewriter &rewriter,
   assert(lsize != 0);
   unsigned l0 = merger.set(lts)[0];
   LatPoint lat0 = merger.lat(l0);
+  genInvariants(merger, codegen, rewriter, op, exp);
   bool needsUniv =
       genInit(merger, codegen, rewriter, op, topSort, at, lat0.bits) &&
       lsize > 1;
@@ -897,9 +958,8 @@ static void genStmt(Merger &merger, CodeGen &codegen, PatternRewriter &rewriter,
     // Emit loop.
     llvm::BitVector indices = lati.bits;
     optimizeIndices(merger, lsize, indices);
-    bool isOuter = at == 0;
-    Operation *loop = genLoop(merger, codegen, rewriter, op, isOuter, idx,
-                              needsUniv, indices);
+    Operation *loop =
+        genLoop(merger, codegen, rewriter, op, topSort, at, needsUniv, indices);
     genLocals(merger, codegen, rewriter, op, topSort, at, needsUniv, lati.bits);
 
     // Visit all lattices points with Li >= Lj to generate the
@@ -931,6 +991,7 @@ static void genStmt(Merger &merger, CodeGen &codegen, PatternRewriter &rewriter,
     }
     rewriter.setInsertionPointAfter(loop);
   }
+  codegen.loops[idx] = Value();
 }
 
 namespace {

diff  --git a/mlir/test/Dialect/Linalg/sparse_2d.mlir b/mlir/test/Dialect/Linalg/sparse_2d.mlir
index 874417c25446..bdd2de5e437a 100644
--- a/mlir/test/Dialect/Linalg/sparse_2d.mlir
+++ b/mlir/test/Dialect/Linalg/sparse_2d.mlir
@@ -1071,8 +1071,8 @@ func @matvec(%argA: tensor<16x32xf32>, %argb: tensor<32xf32>, %argx: tensor<16xf
 }
 
 // CHECK-LABEL:   func @sum_reduction(
-// CHECK-SAME:                %[[VAL_0:.*]]: tensor<10x20xf32>,
-// CHECK-SAME:                %[[VAL_1:.*]]: tensor<f32>) -> tensor<f32> {
+// CHECK-SAME:                        %[[VAL_0:.*0]]: tensor<10x20xf32>,
+// CHECK-SAME:                        %[[VAL_1:.*1]]: tensor<f32>) -> tensor<f32> {
 // CHECK:           %[[VAL_2:.*]] = constant 999 : index
 // CHECK:           %[[VAL_3:.*]] = constant 10 : index
 // CHECK:           %[[VAL_4:.*]] = constant 0 : index
@@ -1200,19 +1200,19 @@ func @scale(%arga: tensor<?x?xf64>) -> tensor<?x?xf64> {
 // CHECK:           scf.for %[[VAL_23:.*]] = %[[VAL_21]] to %[[VAL_22]] step %[[VAL_6]] {
 // CHECK:             %[[VAL_24:.*]] = load %[[VAL_8]]{{\[}}%[[VAL_23]]] : memref<?xindex>
 // CHECK:             scf.for %[[VAL_25:.*]] = %[[VAL_5]] to %[[VAL_15]] step %[[VAL_6]] {
-// CHECK:               %[[VAL_26:.*]] = load %[[VAL_9]]{{\[}}%[[VAL_23]]] : memref<?xindex>
-// CHECK:               %[[VAL_27:.*]] = addi %[[VAL_23]], %[[VAL_6]] : index
-// CHECK:               %[[VAL_28:.*]] = load %[[VAL_9]]{{\[}}%[[VAL_27]]] : memref<?xindex>
-// CHECK:               scf.for %[[VAL_29:.*]] = %[[VAL_26]] to %[[VAL_28]] step %[[VAL_6]] {
-// CHECK:                 %[[VAL_30:.*]] = load %[[VAL_10]]{{\[}}%[[VAL_29]]] : memref<?xindex>
-// CHECK:                 %[[VAL_31:.*]] = load %[[VAL_20]]{{\[}}%[[VAL_24]], %[[VAL_30]]] : memref<?x?xf32>
-// CHECK:                 %[[VAL_32:.*]] = load %[[VAL_11]]{{\[}}%[[VAL_29]]] : memref<?xf32>
-// CHECK:                 %[[VAL_33:.*]] = load %[[VAL_14]]{{\[}}%[[VAL_24]], %[[VAL_25]]] : memref<?x?xf32>
-// CHECK:                 %[[VAL_34:.*]] = load %[[VAL_17]]{{\[}}%[[VAL_25]], %[[VAL_30]]] : memref<?x?xf32>
-// CHECK:                 %[[VAL_35:.*]] = mulf %[[VAL_33]], %[[VAL_34]] : f32
-// CHECK:                 %[[VAL_36:.*]] = mulf %[[VAL_32]], %[[VAL_35]] : f32
-// CHECK:                 %[[VAL_37:.*]] = addf %[[VAL_31]], %[[VAL_36]] : f32
-// CHECK:                 store %[[VAL_37]], %[[VAL_20]]{{\[}}%[[VAL_24]], %[[VAL_30]]] : memref<?x?xf32>
+// CHECK:               %[[VAL_26:.*]] = load %[[VAL_14]]{{\[}}%[[VAL_24]], %[[VAL_25]]] : memref<?x?xf32>
+// CHECK:               %[[VAL_27:.*]] = load %[[VAL_9]]{{\[}}%[[VAL_23]]] : memref<?xindex>
+// CHECK:               %[[VAL_28:.*]] = addi %[[VAL_23]], %[[VAL_6]] : index
+// CHECK:               %[[VAL_29:.*]] = load %[[VAL_9]]{{\[}}%[[VAL_28]]] : memref<?xindex>
+// CHECK:               scf.for %[[VAL_30:.*]] = %[[VAL_27]] to %[[VAL_29]] step %[[VAL_6]] {
+// CHECK:                 %[[VAL_31:.*]] = load %[[VAL_10]]{{\[}}%[[VAL_30]]] : memref<?xindex>
+// CHECK:                 %[[VAL_32:.*]] = load %[[VAL_20]]{{\[}}%[[VAL_24]], %[[VAL_31]]] : memref<?x?xf32>
+// CHECK:                 %[[VAL_33:.*]] = load %[[VAL_11]]{{\[}}%[[VAL_30]]] : memref<?xf32>
+// CHECK:                 %[[VAL_34:.*]] = load %[[VAL_17]]{{\[}}%[[VAL_25]], %[[VAL_31]]] : memref<?x?xf32>
+// CHECK:                 %[[VAL_35:.*]] = mulf %[[VAL_26]], %[[VAL_34]] : f32
+// CHECK:                 %[[VAL_36:.*]] = mulf %[[VAL_33]], %[[VAL_35]] : f32
+// CHECK:                 %[[VAL_37:.*]] = addf %[[VAL_32]], %[[VAL_36]] : f32
+// CHECK:                 store %[[VAL_37]], %[[VAL_20]]{{\[}}%[[VAL_24]], %[[VAL_31]]] : memref<?x?xf32>
 // CHECK:               }
 // CHECK:             }
 // CHECK:           }

diff  --git a/mlir/test/Dialect/Linalg/sparse_3d.mlir b/mlir/test/Dialect/Linalg/sparse_3d.mlir
index a6794f2e6487..a5cb834fd431 100644
--- a/mlir/test/Dialect/Linalg/sparse_3d.mlir
+++ b/mlir/test/Dialect/Linalg/sparse_3d.mlir
@@ -1192,15 +1192,15 @@ func @mul_sss(%arga: tensor<32x16x8xf32>, %argb: tensor<32x16x8xf32>) -> tensor<
 // CHECK:               %[[VAL_25:.*]] = load %[[VAL_7]]{{\[}}%[[VAL_24]]] : memref<?xindex>
 // CHECK:               scf.for %[[VAL_26:.*]] = %[[VAL_23]] to %[[VAL_25]] step %[[VAL_6]] {
 // CHECK:                 %[[VAL_27:.*]] = load %[[VAL_8]]{{\[}}%[[VAL_26]]] : memref<?xindex>
-// CHECK:                 scf.for %[[VAL_28:.*]] = %[[VAL_5]] to %[[VAL_17]] step %[[VAL_6]] {
-// CHECK:                   %[[VAL_29:.*]] = load %[[VAL_9]]{{\[}}%[[VAL_26]]] : memref<?xf32>
-// CHECK:                   %[[VAL_30:.*]] = load %[[VAL_12]]{{\[}}%[[VAL_20]], %[[VAL_28]]] : memref<?x?xf32>
-// CHECK:                   %[[VAL_31:.*]] = mulf %[[VAL_29]], %[[VAL_30]] : f32
-// CHECK:                   %[[VAL_32:.*]] = load %[[VAL_15]]{{\[}}%[[VAL_27]], %[[VAL_28]]] : memref<?x?xf32>
+// CHECK:                 %[[VAL_28:.*]] = load %[[VAL_9]]{{\[}}%[[VAL_26]]] : memref<?xf32>
+// CHECK:                 scf.for %[[VAL_29:.*]] = %[[VAL_5]] to %[[VAL_17]] step %[[VAL_6]] {
+// CHECK:                   %[[VAL_30:.*]] = load %[[VAL_12]]{{\[}}%[[VAL_20]], %[[VAL_29]]] : memref<?x?xf32>
+// CHECK:                   %[[VAL_31:.*]] = mulf %[[VAL_28]], %[[VAL_30]] : f32
+// CHECK:                   %[[VAL_32:.*]] = load %[[VAL_15]]{{\[}}%[[VAL_27]], %[[VAL_29]]] : memref<?x?xf32>
 // CHECK:                   %[[VAL_33:.*]] = mulf %[[VAL_31]], %[[VAL_32]] : f32
-// CHECK:                   %[[VAL_34:.*]] = load %[[VAL_18]]{{\[}}%[[VAL_19]], %[[VAL_28]]] : memref<?x?xf32>
+// CHECK:                   %[[VAL_34:.*]] = load %[[VAL_18]]{{\[}}%[[VAL_19]], %[[VAL_29]]] : memref<?x?xf32>
 // CHECK:                   %[[VAL_35:.*]] = addf %[[VAL_33]], %[[VAL_34]] : f32
-// CHECK:                   store %[[VAL_35]], %[[VAL_18]]{{\[}}%[[VAL_19]], %[[VAL_28]]] : memref<?x?xf32>
+// CHECK:                   store %[[VAL_35]], %[[VAL_18]]{{\[}}%[[VAL_19]], %[[VAL_29]]] : memref<?x?xf32>
 // CHECK:                 }
 // CHECK:               }
 // CHECK:             }
@@ -1281,3 +1281,61 @@ func @sum_reduction(%arga: tensor<10x20x30xf32>, %argx: tensor<f32>) -> tensor<f
   } -> tensor<f32>
   return %0 : tensor<f32>
 }
+
+#trait_invariants = {
+  indexing_maps = [
+    affine_map<(i,j,k) -> (i)>,      // a
+    affine_map<(i,j,k) -> (j)>,      // b
+    affine_map<(i,j,k) -> (k)>,      // c
+    affine_map<(i,j,k) -> (i,j,k)>   // x
+  ],
+  sparse = [
+    [ "D" ],           // a
+    [ "D" ],           // b
+    [ "D" ],           // c
+    [ "D", "D", "D" ]  // x
+  ],
+  iterator_types = ["parallel", "parallel", "parallel"],
+  doc = "x(i,j,k) = a(i) * b(j) * c(k)"
+}
+
+// CHECK-LABEL:   func @invariants(
+// CHECK-SAME:                     %[[VAL_0:.*]]: tensor<10xf32>,
+// CHECK-SAME:                     %[[VAL_1:.*]]: tensor<20xf32>,
+// CHECK-SAME:                     %[[VAL_2:.*]]: tensor<30xf32>) -> tensor<10x20x30xf32> {
+// CHECK:           %[[VAL_3:.*]] = constant 10 : index
+// CHECK:           %[[VAL_4:.*]] = constant 20 : index
+// CHECK:           %[[VAL_5:.*]] = constant 30 : index
+// CHECK:           %[[VAL_6:.*]] = constant 0 : index
+// CHECK:           %[[VAL_7:.*]] = constant 1 : index
+// CHECK:           %[[VAL_8:.*]] = alloca() : memref<10xf32>
+// CHECK:           %[[VAL_9:.*]] = alloca() : memref<20xf32>
+// CHECK:           %[[VAL_10:.*]] = alloca() : memref<30xf32>
+// CHECK:           %[[VAL_11:.*]] = alloca() : memref<10x20x30xf32>
+// CHECK:           scf.for %[[VAL_12:.*]] = %[[VAL_6]] to %[[VAL_3]] step %[[VAL_7]] {
+// CHECK:             %[[VAL_13:.*]] = load %[[VAL_8]]{{\[}}%[[VAL_12]]] : memref<10xf32>
+// CHECK:             scf.for %[[VAL_14:.*]] = %[[VAL_6]] to %[[VAL_4]] step %[[VAL_7]] {
+// CHECK:               %[[VAL_15:.*]] = load %[[VAL_9]]{{\[}}%[[VAL_14]]] : memref<20xf32>
+// CHECK:               scf.for %[[VAL_16:.*]] = %[[VAL_6]] to %[[VAL_5]] step %[[VAL_7]] {
+// CHECK:                 %[[VAL_17:.*]] = mulf %[[VAL_13]], %[[VAL_15]] : f32
+// CHECK:                 %[[VAL_18:.*]] = load %[[VAL_10]]{{\[}}%[[VAL_16]]] : memref<30xf32>
+// CHECK:                 %[[VAL_19:.*]] = mulf %[[VAL_17]], %[[VAL_18]] : f32
+// CHECK:                 store %[[VAL_19]], %[[VAL_11]]{{\[}}%[[VAL_12]], %[[VAL_14]], %[[VAL_16]]] : memref<10x20x30xf32>
+// CHECK:               }
+// CHECK:             }
+// CHECK:           }
+// CHECK:           %[[VAL_20:.*]] = tensor_load %[[VAL_11]] : memref<10x20x30xf32>
+// CHECK:           return %[[VAL_20]] : tensor<10x20x30xf32>
+// CHECK:         }
+func @invariants(%arga: tensor<10xf32>,
+                 %argb: tensor<20xf32>,
+                 %argc: tensor<30xf32>) -> tensor<10x20x30xf32> {
+  %0 = linalg.generic #trait_invariants
+    ins(%arga, %argb, %argc : tensor<10xf32>, tensor<20xf32>, tensor<30xf32>) {
+      ^bb(%a : f32, %b : f32, %c : f32):
+        %0 = mulf %a, %b  : f32
+        %1 = mulf %0, %c  : f32
+        linalg.yield %1: f32
+  } -> tensor<10x20x30xf32>
+  return %0 : tensor<10x20x30xf32>
+}


        

