[Mlir-commits] [mlir] Loop unroll jam for SCF (PR #98316)

Anirudh Sathish llvmlistbot at llvm.org
Wed Jul 10 06:14:06 PDT 2024


https://github.com/Anirudh-Sathish created https://github.com/llvm/llvm-project/pull/98316

### Add loop unroll-and-jam for SCF loops


>From 5080a4efbedfe63d2fce780c1beba895e746bc9a Mon Sep 17 00:00:00 2001
From: Anirudh-Sathish <anirudhsathish at gmail.com>
Date: Wed, 19 Jun 2024 16:42:09 +0000
Subject: [PATCH 01/13] Feat: Add code for matrix multiplication in toy
 language

---
 mlir/examples/toy/Ch7/include/toy/Ops.td      | 18 ++++++
 mlir/examples/toy/Ch7/mlir/Dialect.cpp        | 44 ++++++++++++++
 .../toy/Ch7/mlir/LowerToAffineLoops.cpp       | 57 ++++++++++++++++++-
 mlir/examples/toy/Ch7/mlir/MLIRGen.cpp        | 10 ++++
 4 files changed, 128 insertions(+), 1 deletion(-)

diff --git a/mlir/examples/toy/Ch7/include/toy/Ops.td b/mlir/examples/toy/Ch7/include/toy/Ops.td
index cfd6859eb27bf..f629efb057230 100644
--- a/mlir/examples/toy/Ch7/include/toy/Ops.td
+++ b/mlir/examples/toy/Ch7/include/toy/Ops.td
@@ -450,4 +450,22 @@ def TransposeOp : Toy_Op<"transpose",
   let hasVerifier = 1;
 }
 
+//===----------------------------------------------------------------------===//
+// MatmulOp
+//===----------------------------------------------------------------------===//
+
+// Matrix multiplication of two f64 tensors producing a single f64 tensor.
+def MatmulOp : Toy_Op<"matmul",
+    [Pure, DeclareOpInterfaceMethods<ShapeInferenceOpInterface>]>
+{
+  let summary = "matrix multiplication operation";
+  let arguments = (ins F64Tensor:$lhs, F64Tensor:$rhs);
+  let results = (outs F64Tensor:$results);
+  let builders = [
+    OpBuilder<(ins "Value":$lhs, "Value":$rhs)>
+  ];
+  let assemblyFormat = "attr-dict $lhs `,` $rhs `:` type($lhs) `,` type($rhs) `->` type($results)";
+  let hasVerifier = 1;
+}
+
 #endif // TOY_OPS
diff --git a/mlir/examples/toy/Ch7/mlir/Dialect.cpp b/mlir/examples/toy/Ch7/mlir/Dialect.cpp
index b268b1ef157f9..2668993da789c 100644
--- a/mlir/examples/toy/Ch7/mlir/Dialect.cpp
+++ b/mlir/examples/toy/Ch7/mlir/Dialect.cpp
@@ -498,6 +498,50 @@ mlir::LogicalResult TransposeOp::verify() {
   return mlir::success();
 }
 
+//===----------------------------------------------------------------------===//
+// MatmulOp
+//===----------------------------------------------------------------------===//
+void MatmulOp::build(mlir::OpBuilder &builder, mlir::OperationState &state,
+                     mlir::Value lhs, mlir::Value rhs) {
+  // The result starts out unranked; inferShapes() assigns the ranked type.
+  state.addOperands({lhs, rhs});
+  state.addTypes(UnrankedTensorType::get(builder.getF64Type()));
+}
+
+/// Infers the result type: an (m, k) lhs times a (k, n) rhs yields (m, n).
+/// On invalid operands an error is emitted and the result type is unchanged.
+void MatmulOp::inferShapes() {
+  auto lhsTensor = llvm::dyn_cast<RankedTensorType>(getOperand(0).getType());
+  auto rhsTensor = llvm::dyn_cast<RankedTensorType>(getOperand(1).getType());
+  // Shapes can only be derived from ranked operands.
+  if (!(lhsTensor && rhsTensor)) {
+    emitError("Both operands of MatmulOp must be ranked tensors.");
+    return;
+  }
+  // Both operands must be 2-D and agree on the contracted dimension.
+  ArrayRef<int64_t> lhsDims = lhsTensor.getShape();
+  ArrayRef<int64_t> rhsDims = rhsTensor.getShape();
+  bool isCompatible = lhsDims.size() == 2 && rhsDims.size() == 2 &&
+                      lhsDims[1] == rhsDims[0];
+  if (!isCompatible) {
+    emitError("MatmulOp requires lhs of shape (m, k) and rhs of shape (k, n).");
+    return;
+  }
+  // Rows come from lhs, columns from rhs; the element type follows lhs.
+  getResult().setType(RankedTensorType::get({lhsDims[0], rhsDims[1]},
+                                            lhsTensor.getElementType()));
+}
+
+/// Checks that ranked operands form an (m, k) x (k, n) pair; unranked ones pass.
+mlir::LogicalResult MatmulOp::verify() {
+  auto lhsType = llvm::dyn_cast<RankedTensorType>(getOperand(0).getType());
+  auto rhsType = llvm::dyn_cast<RankedTensorType>(getOperand(1).getType());
+  if (lhsType && rhsType && (lhsType.getRank() != 2 || rhsType.getRank() != 2 ||
+                             lhsType.getDimSize(1) != rhsType.getDimSize(0)))
+    return emitOpError("requires lhs of shape (m, k) and rhs of shape (k, n)");
+  return mlir::success();
+}
+
 //===----------------------------------------------------------------------===//
 // Toy Types
 //===----------------------------------------------------------------------===//
diff --git a/mlir/examples/toy/Ch7/mlir/LowerToAffineLoops.cpp b/mlir/examples/toy/Ch7/mlir/LowerToAffineLoops.cpp
index bded61542188e..f2b55cde6926e 100644
--- a/mlir/examples/toy/Ch7/mlir/LowerToAffineLoops.cpp
+++ b/mlir/examples/toy/Ch7/mlir/LowerToAffineLoops.cpp
@@ -316,6 +316,61 @@ struct TransposeOpLowering : public ConversionPattern {
   }
 };
 
+//===----------------------------------------------------------------------===//
+// ToyToAffine RewritePatterns: Matmul operations
+//===----------------------------------------------------------------------===//
+
+/// Lowers toy.matmul to fully unrolled affine loads/stores at constant indices.
+struct MatmulOpLowering : public ConversionPattern {
+  MatmulOpLowering(MLIRContext *ctx)
+      : ConversionPattern(toy::MatmulOp::getOperationName(), 1, ctx) {}
+
+  LogicalResult
+  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
+                  ConversionPatternRewriter &rewriter) const final {
+    auto loc = op->getLoc();
+    toy::MatmulOpAdaptor matmulAdaptor(operands);
+    Value lhs = matmulAdaptor.getLhs();
+    Value rhs = matmulAdaptor.getRhs();
+    // Unrolling needs static 2-D memrefs whose inner dimensions agree.
+    auto lhsType = dyn_cast<MemRefType>(lhs.getType());
+    auto rhsType = dyn_cast<MemRefType>(rhs.getType());
+    if (!lhsType || !rhsType || lhsType.getRank() != 2 ||
+        rhsType.getRank() != 2 || !lhsType.hasStaticShape() ||
+        !rhsType.hasStaticShape() ||
+        lhsType.getDimSize(1) != rhsType.getDimSize(0))
+      return failure();
+    int64_t M = lhsType.getDimSize(0);
+    int64_t N = rhsType.getDimSize(1);
+    int64_t K = lhsType.getDimSize(1);
+    // Result is M x N, not lhs-shaped. NOTE(review): never deallocated here.
+    auto resultType = MemRefType::get({M, N}, lhsType.getElementType());
+    Value result = rewriter.create<memref::AllocOp>(loc, resultType);
+    auto index = [&](int64_t v) -> Value {
+      return rewriter.create<arith::ConstantIndexOp>(loc, v);
+    };
+    for (int64_t i = 0; i < M; ++i) {
+      for (int64_t j = 0; j < N; ++j) {
+        // result[i, j] = sum over k of lhs[i, k] * rhs[k, j], starting at 0.
+        Value sum = rewriter.create<arith::ConstantOp>(
+            loc, rewriter.getF64Type(), rewriter.getF64FloatAttr(0.0));
+        for (int64_t k = 0; k < K; ++k) {
+          Value lhsVal = rewriter.create<affine::AffineLoadOp>(
+              loc, lhs, ValueRange{index(i), index(k)});
+          Value rhsVal = rewriter.create<affine::AffineLoadOp>(
+              loc, rhs, ValueRange{index(k), index(j)});
+          Value product = rewriter.create<arith::MulFOp>(loc, lhsVal, rhsVal);
+          sum = rewriter.create<arith::AddFOp>(loc, sum, product);
+        }
+        rewriter.create<affine::AffineStoreOp>(loc, sum, result,
+                                               ValueRange{index(i), index(j)});
+      }
+    }
+    rewriter.replaceOp(op, result);
+    return success();
+  }
+};
+
 } // namespace
 
 //===----------------------------------------------------------------------===//
@@ -366,7 +421,7 @@ void ToyToAffineLoweringPass::runOnOperation() {
   // the set of patterns that will lower the Toy operations.
   RewritePatternSet patterns(&getContext());
   patterns.add<AddOpLowering, ConstantOpLowering, FuncOpLowering, MulOpLowering,
-               PrintOpLowering, ReturnOpLowering, TransposeOpLowering>(
+               PrintOpLowering, ReturnOpLowering, TransposeOpLowering, MatmulOpLowering>(
       &getContext());
 
   // With the target and rewrite patterns defined, we can now attempt the
diff --git a/mlir/examples/toy/Ch7/mlir/MLIRGen.cpp b/mlir/examples/toy/Ch7/mlir/MLIRGen.cpp
index 0f8e8df38525f..4a27f22e5cd6c 100644
--- a/mlir/examples/toy/Ch7/mlir/MLIRGen.cpp
+++ b/mlir/examples/toy/Ch7/mlir/MLIRGen.cpp
@@ -526,6 +526,16 @@ class MLIRGenImpl {
       return builder.create<TransposeOp>(location, operands[0]);
     }
 
+    // The builtin 'matmul' is emitted as a MatmulOp with exactly two operands.
+    if (callee == "matmul") {
+      if (call.getArgs().size() != 2) {
+        emitError(location, "MLIR codegen encountered an error: toy.matmul "
+                            "only accepts two arguments");
+        return nullptr;
+      }
+      return builder.create<MatmulOp>(location, operands[0], operands[1]);
+    }
+
     // Otherwise this is a call to a user-defined function. Calls to
     // user-defined functions are mapped to a custom call that takes the callee
     // name as an attribute.

>From 50476952b1d2bb906d010f714a5fbe7fb407c3f6 Mon Sep 17 00:00:00 2001
From: Anirudh-Sathish <anirudhsathish at gmail.com>
Date: Thu, 20 Jun 2024 05:20:49 +0000
Subject: [PATCH 02/13] Fix: Add code to ensure result is in correct shape

---
 mlir/examples/toy/Ch7/include/toy/Ops.td      | 18 ++++++
 mlir/examples/toy/Ch7/mlir/Dialect.cpp        | 44 +++++++++++++
 .../toy/Ch7/mlir/LowerToAffineLoops.cpp       | 64 ++++++++++++++++++-
 mlir/examples/toy/Ch7/mlir/MLIRGen.cpp        | 10 +++
 4 files changed, 134 insertions(+), 2 deletions(-)

diff --git a/mlir/examples/toy/Ch7/include/toy/Ops.td b/mlir/examples/toy/Ch7/include/toy/Ops.td
index cfd6859eb27bf..f629efb057230 100644
--- a/mlir/examples/toy/Ch7/include/toy/Ops.td
+++ b/mlir/examples/toy/Ch7/include/toy/Ops.td
@@ -450,4 +450,22 @@ def TransposeOp : Toy_Op<"transpose",
   let hasVerifier = 1;
 }
 
+//===----------------------------------------------------------------------===//
+// MatmulOp
+//===----------------------------------------------------------------------===//
+
+// Matrix multiplication of two f64 tensors producing a single f64 tensor.
+def MatmulOp : Toy_Op<"matmul",
+    [Pure, DeclareOpInterfaceMethods<ShapeInferenceOpInterface>]>
+{
+  let summary = "matrix multiplication operation";
+  let arguments = (ins F64Tensor:$lhs, F64Tensor:$rhs);
+  let results = (outs F64Tensor:$results);
+  let builders = [
+    OpBuilder<(ins "Value":$lhs, "Value":$rhs)>
+  ];
+  let assemblyFormat = "attr-dict $lhs `,` $rhs `:` type($lhs) `,` type($rhs) `->` type($results)";
+  let hasVerifier = 1;
+}
+
 #endif // TOY_OPS
diff --git a/mlir/examples/toy/Ch7/mlir/Dialect.cpp b/mlir/examples/toy/Ch7/mlir/Dialect.cpp
index b268b1ef157f9..2668993da789c 100644
--- a/mlir/examples/toy/Ch7/mlir/Dialect.cpp
+++ b/mlir/examples/toy/Ch7/mlir/Dialect.cpp
@@ -498,6 +498,50 @@ mlir::LogicalResult TransposeOp::verify() {
   return mlir::success();
 }
 
+//===----------------------------------------------------------------------===//
+// MatmulOp
+//===----------------------------------------------------------------------===//
+void MatmulOp::build(mlir::OpBuilder &builder, mlir::OperationState &state,
+                     mlir::Value lhs, mlir::Value rhs) {
+  // The result starts out unranked; inferShapes() assigns the ranked type.
+  state.addOperands({lhs, rhs});
+  state.addTypes(UnrankedTensorType::get(builder.getF64Type()));
+}
+
+/// Infers the result type: an (m, k) lhs times a (k, n) rhs yields (m, n).
+/// On invalid operands an error is emitted and the result type is unchanged.
+void MatmulOp::inferShapes() {
+  auto lhsTensor = llvm::dyn_cast<RankedTensorType>(getOperand(0).getType());
+  auto rhsTensor = llvm::dyn_cast<RankedTensorType>(getOperand(1).getType());
+  // Shapes can only be derived from ranked operands.
+  if (!(lhsTensor && rhsTensor)) {
+    emitError("Both operands of MatmulOp must be ranked tensors.");
+    return;
+  }
+  // Both operands must be 2-D and agree on the contracted dimension.
+  ArrayRef<int64_t> lhsDims = lhsTensor.getShape();
+  ArrayRef<int64_t> rhsDims = rhsTensor.getShape();
+  bool isCompatible = lhsDims.size() == 2 && rhsDims.size() == 2 &&
+                      lhsDims[1] == rhsDims[0];
+  if (!isCompatible) {
+    emitError("MatmulOp requires lhs of shape (m, k) and rhs of shape (k, n).");
+    return;
+  }
+  // Rows come from lhs, columns from rhs; the element type follows lhs.
+  getResult().setType(RankedTensorType::get({lhsDims[0], rhsDims[1]},
+                                            lhsTensor.getElementType()));
+}
+
+/// Checks that ranked operands form an (m, k) x (k, n) pair; unranked ones pass.
+mlir::LogicalResult MatmulOp::verify() {
+  auto lhsType = llvm::dyn_cast<RankedTensorType>(getOperand(0).getType());
+  auto rhsType = llvm::dyn_cast<RankedTensorType>(getOperand(1).getType());
+  if (lhsType && rhsType && (lhsType.getRank() != 2 || rhsType.getRank() != 2 ||
+                             lhsType.getDimSize(1) != rhsType.getDimSize(0)))
+    return emitOpError("requires lhs of shape (m, k) and rhs of shape (k, n)");
+  return mlir::success();
+}
+
 //===----------------------------------------------------------------------===//
 // Toy Types
 //===----------------------------------------------------------------------===//
diff --git a/mlir/examples/toy/Ch7/mlir/LowerToAffineLoops.cpp b/mlir/examples/toy/Ch7/mlir/LowerToAffineLoops.cpp
index bded61542188e..5c3e6a552855b 100644
--- a/mlir/examples/toy/Ch7/mlir/LowerToAffineLoops.cpp
+++ b/mlir/examples/toy/Ch7/mlir/LowerToAffineLoops.cpp
@@ -25,7 +25,7 @@
 #include "mlir/Support/TypeID.h"
 #include "toy/Dialect.h"
 #include "toy/Passes.h"
-
+#include "llvm/Support/raw_ostream.h"
 #include "mlir/Dialect/Affine/IR/AffineOps.h"
 #include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/Func/IR/FuncOps.h"
@@ -316,6 +316,66 @@ struct TransposeOpLowering : public ConversionPattern {
   }
 };
 
+//===----------------------------------------------------------------------===//
+// ToyToAffine RewritePatterns: Matmul operations
+//===----------------------------------------------------------------------===//
+
+/// Lowers toy.matmul to fully unrolled affine loads/stores at constant indices;
+/// no loops are emitted, one multiply-add chain is built per result element.
+struct MatmulOpLowering : public ConversionPattern {
+  MatmulOpLowering(MLIRContext *ctx)
+      : ConversionPattern(toy::MatmulOp::getOperationName(), 1, ctx) {}
+
+  LogicalResult
+  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
+                  ConversionPatternRewriter &rewriter) const final {
+    auto loc = op->getLoc();
+    toy::MatmulOpAdaptor matmulAdaptor(operands);
+    Value lhs = matmulAdaptor.getLhs();
+    Value rhs = matmulAdaptor.getRhs();
+
+    // Unrolling needs static 2-D memrefs whose inner dimensions agree.
+    auto lhsType = dyn_cast<MemRefType>(lhs.getType());
+    auto rhsType = dyn_cast<MemRefType>(rhs.getType());
+    if (!lhsType || !rhsType || lhsType.getRank() != 2 ||
+        rhsType.getRank() != 2 || !lhsType.hasStaticShape() ||
+        !rhsType.hasStaticShape() ||
+        lhsType.getDimSize(1) != rhsType.getDimSize(0))
+      return failure();
+    int64_t M = lhsType.getDimSize(0);
+    int64_t N = rhsType.getDimSize(1);
+    int64_t K = lhsType.getDimSize(1);
+
+    // The result buffer is M x N with the element type of lhs.
+    // NOTE(review): nothing in this pattern deallocates the buffer.
+    auto resultType = MemRefType::get({M, N}, lhsType.getElementType());
+    Value result = rewriter.create<memref::AllocOp>(loc, resultType);
+    auto index = [&](int64_t v) -> Value {
+      return rewriter.create<arith::ConstantIndexOp>(loc, v);
+    };
+
+    for (int64_t i = 0; i < M; ++i) {
+      for (int64_t j = 0; j < N; ++j) {
+        // result[i, j] = sum over k of lhs[i, k] * rhs[k, j], starting at 0.
+        Value sum = rewriter.create<arith::ConstantOp>(
+            loc, rewriter.getF64Type(), rewriter.getF64FloatAttr(0.0));
+        for (int64_t k = 0; k < K; ++k) {
+          Value lhsVal = rewriter.create<affine::AffineLoadOp>(
+              loc, lhs, ValueRange{index(i), index(k)});
+          Value rhsVal = rewriter.create<affine::AffineLoadOp>(
+              loc, rhs, ValueRange{index(k), index(j)});
+          Value product = rewriter.create<arith::MulFOp>(loc, lhsVal, rhsVal);
+          sum = rewriter.create<arith::AddFOp>(loc, sum, product);
+        }
+        rewriter.create<affine::AffineStoreOp>(loc, sum, result,
+                                               ValueRange{index(i), index(j)});
+      }
+    }
+    rewriter.replaceOp(op, result);
+    return success();
+  }
+};
+
 } // namespace
 
 //===----------------------------------------------------------------------===//
@@ -366,7 +426,7 @@ void ToyToAffineLoweringPass::runOnOperation() {
   // the set of patterns that will lower the Toy operations.
   RewritePatternSet patterns(&getContext());
   patterns.add<AddOpLowering, ConstantOpLowering, FuncOpLowering, MulOpLowering,
-               PrintOpLowering, ReturnOpLowering, TransposeOpLowering>(
+               PrintOpLowering, ReturnOpLowering, TransposeOpLowering, MatmulOpLowering>(
       &getContext());
 
   // With the target and rewrite patterns defined, we can now attempt the
diff --git a/mlir/examples/toy/Ch7/mlir/MLIRGen.cpp b/mlir/examples/toy/Ch7/mlir/MLIRGen.cpp
index 0f8e8df38525f..4a27f22e5cd6c 100644
--- a/mlir/examples/toy/Ch7/mlir/MLIRGen.cpp
+++ b/mlir/examples/toy/Ch7/mlir/MLIRGen.cpp
@@ -526,6 +526,16 @@ class MLIRGenImpl {
       return builder.create<TransposeOp>(location, operands[0]);
     }
 
+    // The builtin 'matmul' is emitted as a MatmulOp with exactly two operands.
+    if (callee == "matmul") {
+      if (call.getArgs().size() != 2) {
+        emitError(location, "MLIR codegen encountered an error: toy.matmul "
+                            "only accepts two arguments");
+        return nullptr;
+      }
+      return builder.create<MatmulOp>(location, operands[0], operands[1]);
+    }
+
     // Otherwise this is a call to a user-defined function. Calls to
     // user-defined functions are mapped to a custom call that takes the callee
     // name as an attribute.

>From d844cbe8753cac33fa19bc1b1f16b044154dc09e Mon Sep 17 00:00:00 2001
From: Anirudh-Sathish <anirudhsathish at gmail.com>
Date: Thu, 20 Jun 2024 05:28:43 +0000
Subject: [PATCH 03/13] Test: Add test cases for matmul operation

---
 .../Examples/Toy/Matmul/matmul1-affine.mlir   |  82 +++
 mlir/test/Examples/Toy/Matmul/matmul1.llvm    | 239 ++++++++
 mlir/test/Examples/Toy/Matmul/matmul1.mlir    |  11 +
 mlir/test/Examples/Toy/Matmul/matmul1.toy     |   6 +
 .../Examples/Toy/Matmul/matmul2-affine.mlir   | 210 +++++++
 mlir/test/Examples/Toy/Matmul/matmul2.llvm    | 527 ++++++++++++++++++
 mlir/test/Examples/Toy/Matmul/matmul2.mlir    |  11 +
 mlir/test/Examples/Toy/Matmul/matmul2.toy     |   6 +
 .../Examples/Toy/Matmul/matmul3-affine.mlir   |  55 ++
 mlir/test/Examples/Toy/Matmul/matmul3.llvm    | 181 ++++++
 mlir/test/Examples/Toy/Matmul/matmul3.mlir    |  11 +
 mlir/test/Examples/Toy/Matmul/matmul3.toy     |   7 +
 12 files changed, 1346 insertions(+)
 create mode 100644 mlir/test/Examples/Toy/Matmul/matmul1-affine.mlir
 create mode 100644 mlir/test/Examples/Toy/Matmul/matmul1.llvm
 create mode 100644 mlir/test/Examples/Toy/Matmul/matmul1.mlir
 create mode 100644 mlir/test/Examples/Toy/Matmul/matmul1.toy
 create mode 100644 mlir/test/Examples/Toy/Matmul/matmul2-affine.mlir
 create mode 100644 mlir/test/Examples/Toy/Matmul/matmul2.llvm
 create mode 100644 mlir/test/Examples/Toy/Matmul/matmul2.mlir
 create mode 100644 mlir/test/Examples/Toy/Matmul/matmul2.toy
 create mode 100644 mlir/test/Examples/Toy/Matmul/matmul3-affine.mlir
 create mode 100644 mlir/test/Examples/Toy/Matmul/matmul3.llvm
 create mode 100644 mlir/test/Examples/Toy/Matmul/matmul3.mlir
 create mode 100644 mlir/test/Examples/Toy/Matmul/matmul3.toy

diff --git a/mlir/test/Examples/Toy/Matmul/matmul1-affine.mlir b/mlir/test/Examples/Toy/Matmul/matmul1-affine.mlir
new file mode 100644
index 0000000000000..f1734ab60e3d4
--- /dev/null
+++ b/mlir/test/Examples/Toy/Matmul/matmul1-affine.mlir
@@ -0,0 +1,82 @@
+module {
+  func.func @main() {
+    %cst = arith.constant 0.000000e+00 : f64
+    %cst_0 = arith.constant 6.000000e+00 : f64
+    %cst_1 = arith.constant 5.000000e+00 : f64
+    %cst_2 = arith.constant 4.000000e+00 : f64
+    %cst_3 = arith.constant 3.000000e+00 : f64
+    %cst_4 = arith.constant 2.000000e+00 : f64
+    %cst_5 = arith.constant 1.000000e+00 : f64
+    %alloc = memref.alloc() : memref<3x2xf64>
+    %alloc_6 = memref.alloc() : memref<2x3xf64>
+    affine.store %cst_5, %alloc_6[0, 0] : memref<2x3xf64>
+    affine.store %cst_4, %alloc_6[0, 1] : memref<2x3xf64>
+    affine.store %cst_3, %alloc_6[0, 2] : memref<2x3xf64>
+    affine.store %cst_2, %alloc_6[1, 0] : memref<2x3xf64>
+    affine.store %cst_1, %alloc_6[1, 1] : memref<2x3xf64>
+    affine.store %cst_0, %alloc_6[1, 2] : memref<2x3xf64>
+    affine.store %cst_5, %alloc[0, 0] : memref<3x2xf64>
+    affine.store %cst_4, %alloc[0, 1] : memref<3x2xf64>
+    affine.store %cst_3, %alloc[1, 0] : memref<3x2xf64>
+    affine.store %cst_2, %alloc[1, 1] : memref<3x2xf64>
+    affine.store %cst_1, %alloc[2, 0] : memref<3x2xf64>
+    affine.store %cst_0, %alloc[2, 1] : memref<3x2xf64>
+    %alloc_7 = memref.alloc() : memref<2x2xf64>
+    %0 = affine.load %alloc_6[0, 0] : memref<2x3xf64>
+    %1 = affine.load %alloc[0, 0] : memref<3x2xf64>
+    %2 = arith.mulf %0, %1 : f64
+    %3 = arith.addf %2, %cst : f64
+    %4 = affine.load %alloc_6[0, 1] : memref<2x3xf64>
+    %5 = affine.load %alloc[1, 0] : memref<3x2xf64>
+    %6 = arith.mulf %4, %5 : f64
+    %7 = arith.addf %3, %6 : f64
+    %8 = affine.load %alloc_6[0, 2] : memref<2x3xf64>
+    %9 = affine.load %alloc[2, 0] : memref<3x2xf64>
+    %10 = arith.mulf %8, %9 : f64
+    %11 = arith.addf %7, %10 : f64
+    affine.store %11, %alloc_7[0, 0] : memref<2x2xf64>
+    %12 = affine.load %alloc_6[0, 0] : memref<2x3xf64>
+    %13 = affine.load %alloc[0, 1] : memref<3x2xf64>
+    %14 = arith.mulf %12, %13 : f64
+    %15 = arith.addf %14, %cst : f64
+    %16 = affine.load %alloc_6[0, 1] : memref<2x3xf64>
+    %17 = affine.load %alloc[1, 1] : memref<3x2xf64>
+    %18 = arith.mulf %16, %17 : f64
+    %19 = arith.addf %15, %18 : f64
+    %20 = affine.load %alloc_6[0, 2] : memref<2x3xf64>
+    %21 = affine.load %alloc[2, 1] : memref<3x2xf64>
+    %22 = arith.mulf %20, %21 : f64
+    %23 = arith.addf %19, %22 : f64
+    affine.store %23, %alloc_7[0, 1] : memref<2x2xf64>
+    %24 = affine.load %alloc_6[1, 0] : memref<2x3xf64>
+    %25 = affine.load %alloc[0, 0] : memref<3x2xf64>
+    %26 = arith.mulf %24, %25 : f64
+    %27 = arith.addf %26, %cst : f64
+    %28 = affine.load %alloc_6[1, 1] : memref<2x3xf64>
+    %29 = affine.load %alloc[1, 0] : memref<3x2xf64>
+    %30 = arith.mulf %28, %29 : f64
+    %31 = arith.addf %27, %30 : f64
+    %32 = affine.load %alloc_6[1, 2] : memref<2x3xf64>
+    %33 = affine.load %alloc[2, 0] : memref<3x2xf64>
+    %34 = arith.mulf %32, %33 : f64
+    %35 = arith.addf %31, %34 : f64
+    affine.store %35, %alloc_7[1, 0] : memref<2x2xf64>
+    %36 = affine.load %alloc_6[1, 0] : memref<2x3xf64>
+    %37 = affine.load %alloc[0, 1] : memref<3x2xf64>
+    %38 = arith.mulf %36, %37 : f64
+    %39 = arith.addf %38, %cst : f64
+    %40 = affine.load %alloc_6[1, 1] : memref<2x3xf64>
+    %41 = affine.load %alloc[1, 1] : memref<3x2xf64>
+    %42 = arith.mulf %40, %41 : f64
+    %43 = arith.addf %39, %42 : f64
+    %44 = affine.load %alloc_6[1, 2] : memref<2x3xf64>
+    %45 = affine.load %alloc[2, 1] : memref<3x2xf64>
+    %46 = arith.mulf %44, %45 : f64
+    %47 = arith.addf %43, %46 : f64
+    affine.store %47, %alloc_7[1, 1] : memref<2x2xf64>
+    toy.print %alloc_7 : memref<2x2xf64>
+    memref.dealloc %alloc_6 : memref<2x3xf64>
+    memref.dealloc %alloc : memref<3x2xf64>
+    return
+  }
+}
diff --git a/mlir/test/Examples/Toy/Matmul/matmul1.llvm b/mlir/test/Examples/Toy/Matmul/matmul1.llvm
new file mode 100644
index 0000000000000..65f4e07d37f0f
--- /dev/null
+++ b/mlir/test/Examples/Toy/Matmul/matmul1.llvm
@@ -0,0 +1,239 @@
+; ModuleID = 'LLVMDialectModule'
+source_filename = "LLVMDialectModule"
+target datalayout = "e-m:e-p270:32:32-p271:32:32-p272:64:64-i64:64-i128:128-f80:128-n8:16:32:64-S128"
+target triple = "x86_64-unknown-linux-gnu"
+
+ at nl = internal constant [2 x i8] c"\0A\00"
+ at frmt_spec = internal constant [4 x i8] c"%f \00"
+
+declare !dbg !3 void @free(ptr)
+
+declare !dbg !6 i32 @printf(ptr, ...)
+
+declare !dbg !7 ptr @malloc(i64)
+
+define void @main() !dbg !8 {
+  %1 = call ptr @malloc(i64 ptrtoint (ptr getelementptr (double, ptr null, i64 6) to i64)), !dbg !9
+  %2 = insertvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } undef, ptr %1, 0, !dbg !9
+  %3 = insertvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %2, ptr %1, 1, !dbg !9
+  %4 = insertvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %3, i64 0, 2, !dbg !9
+  %5 = insertvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %4, i64 3, 3, 0, !dbg !9
+  %6 = insertvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %5, i64 2, 3, 1, !dbg !9
+  %7 = insertvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %6, i64 2, 4, 0, !dbg !9
+  %8 = insertvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %7, i64 1, 4, 1, !dbg !9
+  %9 = call ptr @malloc(i64 ptrtoint (ptr getelementptr (double, ptr null, i64 6) to i64)), !dbg !10
+  %10 = insertvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } undef, ptr %9, 0, !dbg !10
+  %11 = insertvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %10, ptr %9, 1, !dbg !10
+  %12 = insertvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %11, i64 0, 2, !dbg !10
+  %13 = insertvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %12, i64 2, 3, 0, !dbg !10
+  %14 = insertvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %13, i64 3, 3, 1, !dbg !10
+  %15 = insertvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %14, i64 3, 4, 0, !dbg !10
+  %16 = insertvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %15, i64 1, 4, 1, !dbg !10
+  %17 = extractvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %16, 1, !dbg !10
+  %18 = getelementptr double, ptr %17, i64 0, !dbg !10
+  store double 1.000000e+00, ptr %18, align 8, !dbg !10
+  %19 = extractvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %16, 1, !dbg !10
+  %20 = getelementptr double, ptr %19, i64 1, !dbg !10
+  store double 2.000000e+00, ptr %20, align 8, !dbg !10
+  %21 = extractvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %16, 1, !dbg !10
+  %22 = getelementptr double, ptr %21, i64 2, !dbg !10
+  store double 3.000000e+00, ptr %22, align 8, !dbg !10
+  %23 = extractvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %16, 1, !dbg !10
+  %24 = getelementptr double, ptr %23, i64 3, !dbg !10
+  store double 4.000000e+00, ptr %24, align 8, !dbg !10
+  %25 = extractvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %16, 1, !dbg !10
+  %26 = getelementptr double, ptr %25, i64 4, !dbg !10
+  store double 5.000000e+00, ptr %26, align 8, !dbg !10
+  %27 = extractvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %16, 1, !dbg !10
+  %28 = getelementptr double, ptr %27, i64 5, !dbg !10
+  store double 6.000000e+00, ptr %28, align 8, !dbg !10
+  %29 = extractvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %8, 1, !dbg !9
+  %30 = getelementptr double, ptr %29, i64 0, !dbg !9
+  store double 1.000000e+00, ptr %30, align 8, !dbg !9
+  %31 = extractvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %8, 1, !dbg !9
+  %32 = getelementptr double, ptr %31, i64 1, !dbg !9
+  store double 2.000000e+00, ptr %32, align 8, !dbg !9
+  %33 = extractvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %8, 1, !dbg !9
+  %34 = getelementptr double, ptr %33, i64 2, !dbg !9
+  store double 3.000000e+00, ptr %34, align 8, !dbg !9
+  %35 = extractvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %8, 1, !dbg !9
+  %36 = getelementptr double, ptr %35, i64 3, !dbg !9
+  store double 4.000000e+00, ptr %36, align 8, !dbg !9
+  %37 = extractvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %8, 1, !dbg !9
+  %38 = getelementptr double, ptr %37, i64 4, !dbg !9
+  store double 5.000000e+00, ptr %38, align 8, !dbg !9
+  %39 = extractvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %8, 1, !dbg !9
+  %40 = getelementptr double, ptr %39, i64 5, !dbg !9
+  store double 6.000000e+00, ptr %40, align 8, !dbg !9
+  %41 = call ptr @malloc(i64 ptrtoint (ptr getelementptr (double, ptr null, i64 4) to i64)), !dbg !11
+  %42 = insertvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } undef, ptr %41, 0, !dbg !11
+  %43 = insertvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %42, ptr %41, 1, !dbg !11
+  %44 = insertvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %43, i64 0, 2, !dbg !11
+  %45 = insertvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %44, i64 2, 3, 0, !dbg !11
+  %46 = insertvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %45, i64 2, 3, 1, !dbg !11
+  %47 = insertvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %46, i64 2, 4, 0, !dbg !11
+  %48 = insertvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %47, i64 1, 4, 1, !dbg !11
+  %49 = extractvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %16, 1, !dbg !11
+  %50 = getelementptr double, ptr %49, i64 0, !dbg !11
+  %51 = load double, ptr %50, align 8, !dbg !11
+  %52 = extractvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %8, 1, !dbg !11
+  %53 = getelementptr double, ptr %52, i64 0, !dbg !11
+  %54 = load double, ptr %53, align 8, !dbg !11
+  %55 = fmul double %51, %54, !dbg !11
+  %56 = fadd double %55, 0.000000e+00, !dbg !11
+  %57 = extractvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %16, 1, !dbg !11
+  %58 = getelementptr double, ptr %57, i64 1, !dbg !11
+  %59 = load double, ptr %58, align 8, !dbg !11
+  %60 = extractvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %8, 1, !dbg !11
+  %61 = getelementptr double, ptr %60, i64 2, !dbg !11
+  %62 = load double, ptr %61, align 8, !dbg !11
+  %63 = fmul double %59, %62, !dbg !11
+  %64 = fadd double %56, %63, !dbg !11
+  %65 = extractvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %16, 1, !dbg !11
+  %66 = getelementptr double, ptr %65, i64 2, !dbg !11
+  %67 = load double, ptr %66, align 8, !dbg !11
+  %68 = extractvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %8, 1, !dbg !11
+  %69 = getelementptr double, ptr %68, i64 4, !dbg !11
+  %70 = load double, ptr %69, align 8, !dbg !11
+  %71 = fmul double %67, %70, !dbg !11
+  %72 = fadd double %64, %71, !dbg !11
+  %73 = extractvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %48, 1, !dbg !11
+  %74 = getelementptr double, ptr %73, i64 0, !dbg !11
+  store double %72, ptr %74, align 8, !dbg !11
+  %75 = extractvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %16, 1, !dbg !11
+  %76 = getelementptr double, ptr %75, i64 0, !dbg !11
+  %77 = load double, ptr %76, align 8, !dbg !11
+  %78 = extractvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %8, 1, !dbg !11
+  %79 = getelementptr double, ptr %78, i64 1, !dbg !11
+  %80 = load double, ptr %79, align 8, !dbg !11
+  %81 = fmul double %77, %80, !dbg !11
+  %82 = fadd double %81, 0.000000e+00, !dbg !11
+  %83 = extractvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %16, 1, !dbg !11
+  %84 = getelementptr double, ptr %83, i64 1, !dbg !11
+  %85 = load double, ptr %84, align 8, !dbg !11
+  %86 = extractvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %8, 1, !dbg !11
+  %87 = getelementptr double, ptr %86, i64 3, !dbg !11
+  %88 = load double, ptr %87, align 8, !dbg !11
+  %89 = fmul double %85, %88, !dbg !11
+  %90 = fadd double %82, %89, !dbg !11
+  %91 = extractvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %16, 1, !dbg !11
+  %92 = getelementptr double, ptr %91, i64 2, !dbg !11
+  %93 = load double, ptr %92, align 8, !dbg !11
+  %94 = extractvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %8, 1, !dbg !11
+  %95 = getelementptr double, ptr %94, i64 5, !dbg !11
+  %96 = load double, ptr %95, align 8, !dbg !11
+  %97 = fmul double %93, %96, !dbg !11
+  %98 = fadd double %90, %97, !dbg !11
+  %99 = extractvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %48, 1, !dbg !11
+  %100 = getelementptr double, ptr %99, i64 1, !dbg !11
+  store double %98, ptr %100, align 8, !dbg !11
+  %101 = extractvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %16, 1, !dbg !11
+  %102 = getelementptr double, ptr %101, i64 3, !dbg !11
+  %103 = load double, ptr %102, align 8, !dbg !11
+  %104 = extractvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %8, 1, !dbg !11
+  %105 = getelementptr double, ptr %104, i64 0, !dbg !11
+  %106 = load double, ptr %105, align 8, !dbg !11
+  %107 = fmul double %103, %106, !dbg !11
+  %108 = fadd double %107, 0.000000e+00, !dbg !11
+  %109 = extractvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %16, 1, !dbg !11
+  %110 = getelementptr double, ptr %109, i64 4, !dbg !11
+  %111 = load double, ptr %110, align 8, !dbg !11
+  %112 = extractvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %8, 1, !dbg !11
+  %113 = getelementptr double, ptr %112, i64 2, !dbg !11
+  %114 = load double, ptr %113, align 8, !dbg !11
+  %115 = fmul double %111, %114, !dbg !11
+  %116 = fadd double %108, %115, !dbg !11
+  %117 = extractvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %16, 1, !dbg !11
+  %118 = getelementptr double, ptr %117, i64 5, !dbg !11
+  %119 = load double, ptr %118, align 8, !dbg !11
+  %120 = extractvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %8, 1, !dbg !11
+  %121 = getelementptr double, ptr %120, i64 4, !dbg !11
+  %122 = load double, ptr %121, align 8, !dbg !11
+  %123 = fmul double %119, %122, !dbg !11
+  %124 = fadd double %116, %123, !dbg !11
+  %125 = extractvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %48, 1, !dbg !11
+  %126 = getelementptr double, ptr %125, i64 2, !dbg !11
+  store double %124, ptr %126, align 8, !dbg !11
+  %127 = extractvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %16, 1, !dbg !11
+  %128 = getelementptr double, ptr %127, i64 3, !dbg !11
+  %129 = load double, ptr %128, align 8, !dbg !11
+  %130 = extractvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %8, 1, !dbg !11
+  %131 = getelementptr double, ptr %130, i64 1, !dbg !11
+  %132 = load double, ptr %131, align 8, !dbg !11
+  %133 = fmul double %129, %132, !dbg !11
+  %134 = fadd double %133, 0.000000e+00, !dbg !11
+  %135 = extractvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %16, 1, !dbg !11
+  %136 = getelementptr double, ptr %135, i64 4, !dbg !11
+  %137 = load double, ptr %136, align 8, !dbg !11
+  %138 = extractvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %8, 1, !dbg !11
+  %139 = getelementptr double, ptr %138, i64 3, !dbg !11
+  %140 = load double, ptr %139, align 8, !dbg !11
+  %141 = fmul double %137, %140, !dbg !11
+  %142 = fadd double %134, %141, !dbg !11
+  %143 = extractvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %16, 1, !dbg !11
+  %144 = getelementptr double, ptr %143, i64 5, !dbg !11
+  %145 = load double, ptr %144, align 8, !dbg !11
+  %146 = extractvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %8, 1, !dbg !11
+  %147 = getelementptr double, ptr %146, i64 5, !dbg !11
+  %148 = load double, ptr %147, align 8, !dbg !11
+  %149 = fmul double %145, %148, !dbg !11
+  %150 = fadd double %142, %149, !dbg !11
+  %151 = extractvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %48, 1, !dbg !11
+  %152 = getelementptr double, ptr %151, i64 3, !dbg !11
+  store double %150, ptr %152, align 8, !dbg !11
+  br label %153, !dbg !12
+
+153:                                              ; preds = %168, %0
+  %154 = phi i64 [ 0, %0 ], [ %170, %168 ]
+  %155 = icmp slt i64 %154, 2, !dbg !12
+  br i1 %155, label %156, label %171, !dbg !12
+
+156:                                              ; preds = %153
+  br label %157, !dbg !12
+
+157:                                              ; preds = %160, %156
+  %158 = phi i64 [ 0, %156 ], [ %167, %160 ]
+  %159 = icmp slt i64 %158, 2, !dbg !12
+  br i1 %159, label %160, label %168, !dbg !12
+
+160:                                              ; preds = %157
+  %161 = extractvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %48, 1, !dbg !12
+  %162 = mul i64 %154, 2, !dbg !12
+  %163 = add i64 %162, %158, !dbg !12
+  %164 = getelementptr double, ptr %161, i64 %163, !dbg !12
+  %165 = load double, ptr %164, align 8, !dbg !12
+  %166 = call i32 (ptr, ...) @printf(ptr @frmt_spec, double %165), !dbg !12
+  %167 = add i64 %158, 1, !dbg !12
+  br label %157, !dbg !12
+
+168:                                              ; preds = %157
+  %169 = call i32 (ptr, ...) @printf(ptr @nl), !dbg !12
+  %170 = add i64 %154, 1, !dbg !12
+  br label %153, !dbg !12
+
+171:                                              ; preds = %153
+  %172 = extractvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %16, 0, !dbg !10
+  call void @free(ptr %172), !dbg !10
+  %173 = extractvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %8, 0, !dbg !9
+  call void @free(ptr %173), !dbg !9
+  ret void, !dbg !13
+}
+
+!llvm.module.flags = !{!0} ; module flag list: the single entry !0
+!llvm.dbg.cu = !{!1} ; compile-unit list: the single unit !1
+
+!0 = !{i32 2, !"Debug Info Version", i32 3} ; "Debug Info Version" flag with value 3
+!1 = distinct !DICompileUnit(language: DW_LANG_C, file: !2, producer: "MLIR", isOptimized: true, runtimeVersion: 0, emissionKind: LineTablesOnly) ; producer is "MLIR"; line tables only
+!2 = !DIFile(filename: "matmul1.mlir", directory: "") ; every location below is a line/column in matmul1.mlir
+!3 = !DISubprogram(name: "free", linkageName: "free", scope: !2, file: !2, line: 1, type: !4, scopeLine: 1, spFlags: DISPFlagOptimized) ; entry for "free" (no DISPFlagDefinition)
+!4 = !DISubroutineType(cc: DW_CC_normal, types: !5) ; subroutine type shared by all subprograms here; its type list is the empty node !5
+!5 = !{}
+!6 = !DISubprogram(name: "printf", linkageName: "printf", scope: !2, file: !2, line: 1, type: !4, scopeLine: 1, spFlags: DISPFlagOptimized) ; entry for "printf" (no DISPFlagDefinition)
+!7 = !DISubprogram(name: "malloc", linkageName: "malloc", scope: !2, file: !2, line: 1, type: !4, scopeLine: 1, spFlags: DISPFlagOptimized) ; entry for "malloc" (no DISPFlagDefinition)
+!8 = distinct !DISubprogram(name: "main", linkageName: "main", scope: !2, file: !2, line: 2, type: !4, scopeLine: 3, spFlags: DISPFlagDefinition | DISPFlagOptimized, unit: !1) ; the only definition; scope of all DILocations below
+!9 = !DILocation(line: 6, column: 10, scope: !8) ; in the matmul1.mlir added by this patch, line 6 col 10 is the second toy.reshape (%3)
+!10 = !DILocation(line: 4, column: 10, scope: !8) ; line 4 col 10 there is the first toy.reshape (%1)
+!11 = !DILocation(line: 7, column: 10, scope: !8) ; line 7 col 10 there is toy.matmul; attached to the multiply/add/store sequence above
+!12 = !DILocation(line: 8, column: 5, scope: !8) ; line 8 col 5 there is toy.print; attached to the printf loops above
+!13 = !DILocation(line: 9, column: 5, scope: !8) ; line 9 col 5 there is toy.return; attached to the final ret
+
diff --git a/mlir/test/Examples/Toy/Matmul/matmul1.mlir b/mlir/test/Examples/Toy/Matmul/matmul1.mlir
new file mode 100644
index 0000000000000..956da787f9b36
--- /dev/null
+++ b/mlir/test/Examples/Toy/Matmul/matmul1.mlir
@@ -0,0 +1,11 @@
+module {
+  toy.func @main() {  // builds a 2x3 and a 3x2 tensor from the same six values, multiplies them, prints the result
+    %0 = toy.constant dense<[1.000000e+00, 2.000000e+00, 3.000000e+00, 4.000000e+00, 5.000000e+00, 6.000000e+00]> : tensor<6xf64>
+    %1 = toy.reshape(%0 : tensor<6xf64>) to tensor<2x3xf64>  // left operand: the six values viewed as 2x3
+    %2 = toy.constant dense<[1.000000e+00, 2.000000e+00, 3.000000e+00, 4.000000e+00, 5.000000e+00, 6.000000e+00]> : tensor<6xf64>
+    %3 = toy.reshape(%2 : tensor<6xf64>) to tensor<3x2xf64>  // right operand: the six values viewed as 3x2
+    %4 = toy.matmul %1, %3 : tensor<2x3xf64>, tensor<3x2xf64> -> tensor<*xf64>  // result type is still unranked here; presumably refined by shape inference later (TODO confirm)
+    toy.print %4 : tensor<*xf64>
+    toy.return
+  }
+}
diff --git a/mlir/test/Examples/Toy/Matmul/matmul1.toy b/mlir/test/Examples/Toy/Matmul/matmul1.toy
new file mode 100644
index 0000000000000..823608a6a2f02
--- /dev/null
+++ b/mlir/test/Examples/Toy/Matmul/matmul1.toy
@@ -0,0 +1,6 @@
+def main() {
+  var a<2, 3> = [1, 2, 3, 4, 5, 6];  # left operand, shaped 2x3
+  var b<3, 2> = [1, 2, 3, 4, 5, 6];  # right operand, shaped 3x2
+  var c = matmul(a,b);  # matrix product; corresponds to the toy.matmul op in matmul1.mlir
+  print(c);
+}
diff --git a/mlir/test/Examples/Toy/Matmul/matmul2-affine.mlir b/mlir/test/Examples/Toy/Matmul/matmul2-affine.mlir
new file mode 100644
index 0000000000000..88a235f662b15
--- /dev/null
+++ b/mlir/test/Examples/Toy/Matmul/matmul2-affine.mlir
@@ -0,0 +1,210 @@
+module {
+  func.func @main() {  // straight-line, constant-index code (no loops): fills two buffers, forms a 6x6 product, prints it
+    %cst = arith.constant 0.000000e+00 : f64  // zero addend used by every arith.addf below
+    %cst_0 = arith.constant 6.000000e+00 : f64
+    %cst_1 = arith.constant 5.000000e+00 : f64
+    %cst_2 = arith.constant 4.000000e+00 : f64
+    %cst_3 = arith.constant 3.000000e+00 : f64
+    %cst_4 = arith.constant 2.000000e+00 : f64
+    %cst_5 = arith.constant 1.000000e+00 : f64
+    %alloc = memref.alloc() : memref<1x6xf64>  // 1x6 operand buffer, read below as %alloc[0, j]
+    %alloc_6 = memref.alloc() : memref<6x1xf64>  // 6x1 operand buffer, read below as %alloc_6[i, 0]
+    affine.store %cst_5, %alloc_6[0, 0] : memref<6x1xf64>  // %alloc_6 receives 1.0 .. 6.0 in index order
+    affine.store %cst_4, %alloc_6[1, 0] : memref<6x1xf64>
+    affine.store %cst_3, %alloc_6[2, 0] : memref<6x1xf64>
+    affine.store %cst_2, %alloc_6[3, 0] : memref<6x1xf64>
+    affine.store %cst_1, %alloc_6[4, 0] : memref<6x1xf64>
+    affine.store %cst_0, %alloc_6[5, 0] : memref<6x1xf64>
+    affine.store %cst_5, %alloc[0, 0] : memref<1x6xf64>  // %alloc receives 1.0 .. 6.0 in index order
+    affine.store %cst_4, %alloc[0, 1] : memref<1x6xf64>
+    affine.store %cst_3, %alloc[0, 2] : memref<1x6xf64>
+    affine.store %cst_2, %alloc[0, 3] : memref<1x6xf64>
+    affine.store %cst_1, %alloc[0, 4] : memref<1x6xf64>
+    affine.store %cst_0, %alloc[0, 5] : memref<1x6xf64>
+    %alloc_7 = memref.alloc() : memref<6x6xf64>  // 6x6 result buffer
+    %0 = affine.load %alloc_6[0, 0] : memref<6x1xf64>  // 36 five-op groups follow, one per (i, j): %alloc_7[i, j] = %alloc_6[i, 0] * %alloc[0, j] + %cst
+    %1 = affine.load %alloc[0, 0] : memref<1x6xf64>
+    %2 = arith.mulf %0, %1 : f64
+    %3 = arith.addf %2, %cst : f64
+    affine.store %3, %alloc_7[0, 0] : memref<6x6xf64>
+    %4 = affine.load %alloc_6[0, 0] : memref<6x1xf64>
+    %5 = affine.load %alloc[0, 1] : memref<1x6xf64>
+    %6 = arith.mulf %4, %5 : f64
+    %7 = arith.addf %6, %cst : f64
+    affine.store %7, %alloc_7[0, 1] : memref<6x6xf64>
+    %8 = affine.load %alloc_6[0, 0] : memref<6x1xf64>
+    %9 = affine.load %alloc[0, 2] : memref<1x6xf64>
+    %10 = arith.mulf %8, %9 : f64
+    %11 = arith.addf %10, %cst : f64
+    affine.store %11, %alloc_7[0, 2] : memref<6x6xf64>
+    %12 = affine.load %alloc_6[0, 0] : memref<6x1xf64>
+    %13 = affine.load %alloc[0, 3] : memref<1x6xf64>
+    %14 = arith.mulf %12, %13 : f64
+    %15 = arith.addf %14, %cst : f64
+    affine.store %15, %alloc_7[0, 3] : memref<6x6xf64>
+    %16 = affine.load %alloc_6[0, 0] : memref<6x1xf64>
+    %17 = affine.load %alloc[0, 4] : memref<1x6xf64>
+    %18 = arith.mulf %16, %17 : f64
+    %19 = arith.addf %18, %cst : f64
+    affine.store %19, %alloc_7[0, 4] : memref<6x6xf64>
+    %20 = affine.load %alloc_6[0, 0] : memref<6x1xf64>
+    %21 = affine.load %alloc[0, 5] : memref<1x6xf64>
+    %22 = arith.mulf %20, %21 : f64
+    %23 = arith.addf %22, %cst : f64
+    affine.store %23, %alloc_7[0, 5] : memref<6x6xf64>
+    %24 = affine.load %alloc_6[1, 0] : memref<6x1xf64>
+    %25 = affine.load %alloc[0, 0] : memref<1x6xf64>
+    %26 = arith.mulf %24, %25 : f64
+    %27 = arith.addf %26, %cst : f64
+    affine.store %27, %alloc_7[1, 0] : memref<6x6xf64>
+    %28 = affine.load %alloc_6[1, 0] : memref<6x1xf64>
+    %29 = affine.load %alloc[0, 1] : memref<1x6xf64>
+    %30 = arith.mulf %28, %29 : f64
+    %31 = arith.addf %30, %cst : f64
+    affine.store %31, %alloc_7[1, 1] : memref<6x6xf64>
+    %32 = affine.load %alloc_6[1, 0] : memref<6x1xf64>
+    %33 = affine.load %alloc[0, 2] : memref<1x6xf64>
+    %34 = arith.mulf %32, %33 : f64
+    %35 = arith.addf %34, %cst : f64
+    affine.store %35, %alloc_7[1, 2] : memref<6x6xf64>
+    %36 = affine.load %alloc_6[1, 0] : memref<6x1xf64>
+    %37 = affine.load %alloc[0, 3] : memref<1x6xf64>
+    %38 = arith.mulf %36, %37 : f64
+    %39 = arith.addf %38, %cst : f64
+    affine.store %39, %alloc_7[1, 3] : memref<6x6xf64>
+    %40 = affine.load %alloc_6[1, 0] : memref<6x1xf64>
+    %41 = affine.load %alloc[0, 4] : memref<1x6xf64>
+    %42 = arith.mulf %40, %41 : f64
+    %43 = arith.addf %42, %cst : f64
+    affine.store %43, %alloc_7[1, 4] : memref<6x6xf64>
+    %44 = affine.load %alloc_6[1, 0] : memref<6x1xf64>
+    %45 = affine.load %alloc[0, 5] : memref<1x6xf64>
+    %46 = arith.mulf %44, %45 : f64
+    %47 = arith.addf %46, %cst : f64
+    affine.store %47, %alloc_7[1, 5] : memref<6x6xf64>
+    %48 = affine.load %alloc_6[2, 0] : memref<6x1xf64>
+    %49 = affine.load %alloc[0, 0] : memref<1x6xf64>
+    %50 = arith.mulf %48, %49 : f64
+    %51 = arith.addf %50, %cst : f64
+    affine.store %51, %alloc_7[2, 0] : memref<6x6xf64>
+    %52 = affine.load %alloc_6[2, 0] : memref<6x1xf64>
+    %53 = affine.load %alloc[0, 1] : memref<1x6xf64>
+    %54 = arith.mulf %52, %53 : f64
+    %55 = arith.addf %54, %cst : f64
+    affine.store %55, %alloc_7[2, 1] : memref<6x6xf64>
+    %56 = affine.load %alloc_6[2, 0] : memref<6x1xf64>
+    %57 = affine.load %alloc[0, 2] : memref<1x6xf64>
+    %58 = arith.mulf %56, %57 : f64
+    %59 = arith.addf %58, %cst : f64
+    affine.store %59, %alloc_7[2, 2] : memref<6x6xf64>
+    %60 = affine.load %alloc_6[2, 0] : memref<6x1xf64>
+    %61 = affine.load %alloc[0, 3] : memref<1x6xf64>
+    %62 = arith.mulf %60, %61 : f64
+    %63 = arith.addf %62, %cst : f64
+    affine.store %63, %alloc_7[2, 3] : memref<6x6xf64>
+    %64 = affine.load %alloc_6[2, 0] : memref<6x1xf64>
+    %65 = affine.load %alloc[0, 4] : memref<1x6xf64>
+    %66 = arith.mulf %64, %65 : f64
+    %67 = arith.addf %66, %cst : f64
+    affine.store %67, %alloc_7[2, 4] : memref<6x6xf64>
+    %68 = affine.load %alloc_6[2, 0] : memref<6x1xf64>
+    %69 = affine.load %alloc[0, 5] : memref<1x6xf64>
+    %70 = arith.mulf %68, %69 : f64
+    %71 = arith.addf %70, %cst : f64
+    affine.store %71, %alloc_7[2, 5] : memref<6x6xf64>
+    %72 = affine.load %alloc_6[3, 0] : memref<6x1xf64>
+    %73 = affine.load %alloc[0, 0] : memref<1x6xf64>
+    %74 = arith.mulf %72, %73 : f64
+    %75 = arith.addf %74, %cst : f64
+    affine.store %75, %alloc_7[3, 0] : memref<6x6xf64>
+    %76 = affine.load %alloc_6[3, 0] : memref<6x1xf64>
+    %77 = affine.load %alloc[0, 1] : memref<1x6xf64>
+    %78 = arith.mulf %76, %77 : f64
+    %79 = arith.addf %78, %cst : f64
+    affine.store %79, %alloc_7[3, 1] : memref<6x6xf64>
+    %80 = affine.load %alloc_6[3, 0] : memref<6x1xf64>
+    %81 = affine.load %alloc[0, 2] : memref<1x6xf64>
+    %82 = arith.mulf %80, %81 : f64
+    %83 = arith.addf %82, %cst : f64
+    affine.store %83, %alloc_7[3, 2] : memref<6x6xf64>
+    %84 = affine.load %alloc_6[3, 0] : memref<6x1xf64>
+    %85 = affine.load %alloc[0, 3] : memref<1x6xf64>
+    %86 = arith.mulf %84, %85 : f64
+    %87 = arith.addf %86, %cst : f64
+    affine.store %87, %alloc_7[3, 3] : memref<6x6xf64>
+    %88 = affine.load %alloc_6[3, 0] : memref<6x1xf64>
+    %89 = affine.load %alloc[0, 4] : memref<1x6xf64>
+    %90 = arith.mulf %88, %89 : f64
+    %91 = arith.addf %90, %cst : f64
+    affine.store %91, %alloc_7[3, 4] : memref<6x6xf64>
+    %92 = affine.load %alloc_6[3, 0] : memref<6x1xf64>
+    %93 = affine.load %alloc[0, 5] : memref<1x6xf64>
+    %94 = arith.mulf %92, %93 : f64
+    %95 = arith.addf %94, %cst : f64
+    affine.store %95, %alloc_7[3, 5] : memref<6x6xf64>
+    %96 = affine.load %alloc_6[4, 0] : memref<6x1xf64>
+    %97 = affine.load %alloc[0, 0] : memref<1x6xf64>
+    %98 = arith.mulf %96, %97 : f64
+    %99 = arith.addf %98, %cst : f64
+    affine.store %99, %alloc_7[4, 0] : memref<6x6xf64>
+    %100 = affine.load %alloc_6[4, 0] : memref<6x1xf64>
+    %101 = affine.load %alloc[0, 1] : memref<1x6xf64>
+    %102 = arith.mulf %100, %101 : f64
+    %103 = arith.addf %102, %cst : f64
+    affine.store %103, %alloc_7[4, 1] : memref<6x6xf64>
+    %104 = affine.load %alloc_6[4, 0] : memref<6x1xf64>
+    %105 = affine.load %alloc[0, 2] : memref<1x6xf64>
+    %106 = arith.mulf %104, %105 : f64
+    %107 = arith.addf %106, %cst : f64
+    affine.store %107, %alloc_7[4, 2] : memref<6x6xf64>
+    %108 = affine.load %alloc_6[4, 0] : memref<6x1xf64>
+    %109 = affine.load %alloc[0, 3] : memref<1x6xf64>
+    %110 = arith.mulf %108, %109 : f64
+    %111 = arith.addf %110, %cst : f64
+    affine.store %111, %alloc_7[4, 3] : memref<6x6xf64>
+    %112 = affine.load %alloc_6[4, 0] : memref<6x1xf64>
+    %113 = affine.load %alloc[0, 4] : memref<1x6xf64>
+    %114 = arith.mulf %112, %113 : f64
+    %115 = arith.addf %114, %cst : f64
+    affine.store %115, %alloc_7[4, 4] : memref<6x6xf64>
+    %116 = affine.load %alloc_6[4, 0] : memref<6x1xf64>
+    %117 = affine.load %alloc[0, 5] : memref<1x6xf64>
+    %118 = arith.mulf %116, %117 : f64
+    %119 = arith.addf %118, %cst : f64
+    affine.store %119, %alloc_7[4, 5] : memref<6x6xf64>
+    %120 = affine.load %alloc_6[5, 0] : memref<6x1xf64>
+    %121 = affine.load %alloc[0, 0] : memref<1x6xf64>
+    %122 = arith.mulf %120, %121 : f64
+    %123 = arith.addf %122, %cst : f64
+    affine.store %123, %alloc_7[5, 0] : memref<6x6xf64>
+    %124 = affine.load %alloc_6[5, 0] : memref<6x1xf64>
+    %125 = affine.load %alloc[0, 1] : memref<1x6xf64>
+    %126 = arith.mulf %124, %125 : f64
+    %127 = arith.addf %126, %cst : f64
+    affine.store %127, %alloc_7[5, 1] : memref<6x6xf64>
+    %128 = affine.load %alloc_6[5, 0] : memref<6x1xf64>
+    %129 = affine.load %alloc[0, 2] : memref<1x6xf64>
+    %130 = arith.mulf %128, %129 : f64
+    %131 = arith.addf %130, %cst : f64
+    affine.store %131, %alloc_7[5, 2] : memref<6x6xf64>
+    %132 = affine.load %alloc_6[5, 0] : memref<6x1xf64>
+    %133 = affine.load %alloc[0, 3] : memref<1x6xf64>
+    %134 = arith.mulf %132, %133 : f64
+    %135 = arith.addf %134, %cst : f64
+    affine.store %135, %alloc_7[5, 3] : memref<6x6xf64>
+    %136 = affine.load %alloc_6[5, 0] : memref<6x1xf64>
+    %137 = affine.load %alloc[0, 4] : memref<1x6xf64>
+    %138 = arith.mulf %136, %137 : f64
+    %139 = arith.addf %138, %cst : f64
+    affine.store %139, %alloc_7[5, 4] : memref<6x6xf64>
+    %140 = affine.load %alloc_6[5, 0] : memref<6x1xf64>
+    %141 = affine.load %alloc[0, 5] : memref<1x6xf64>
+    %142 = arith.mulf %140, %141 : f64
+    %143 = arith.addf %142, %cst : f64
+    affine.store %143, %alloc_7[5, 5] : memref<6x6xf64>
+    toy.print %alloc_7 : memref<6x6xf64>
+    memref.dealloc %alloc_6 : memref<6x1xf64>  // NOTE(review): only %alloc_6 and %alloc are deallocated; %alloc_7 has no matching memref.dealloc in this function -- confirm whether the matmul lowering should emit one
+    memref.dealloc %alloc : memref<1x6xf64>
+    return
+  }
+}
diff --git a/mlir/test/Examples/Toy/Matmul/matmul2.llvm b/mlir/test/Examples/Toy/Matmul/matmul2.llvm
new file mode 100644
index 0000000000000..0d777aa2819ed
--- /dev/null
+++ b/mlir/test/Examples/Toy/Matmul/matmul2.llvm
@@ -0,0 +1,527 @@
+; ModuleID = 'LLVMDialectModule'
+source_filename = "LLVMDialectModule"
+target datalayout = "e-m:e-p270:32:32-p271:32:32-p272:64:64-i64:64-i128:128-f80:128-n8:16:32:64-S128"
+target triple = "x86_64-unknown-linux-gnu"
+
+@nl = internal constant [2 x i8] c"\0A\00" ; NUL-terminated "\n"
+@frmt_spec = internal constant [4 x i8] c"%f \00" ; NUL-terminated "%f " format string
+
+declare !dbg !3 void @free(ptr)
+
+declare !dbg !6 i32 @printf(ptr, ...)
+
+declare !dbg !7 ptr @malloc(i64) ; called by @main below once per buffer
+
+define void @main() !dbg !8 {
+  %1 = call ptr @malloc(i64 ptrtoint (ptr getelementptr (double, ptr null, i64 6) to i64)), !dbg !9
+  %2 = insertvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } undef, ptr %1, 0, !dbg !9
+  %3 = insertvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %2, ptr %1, 1, !dbg !9
+  %4 = insertvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %3, i64 0, 2, !dbg !9
+  %5 = insertvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %4, i64 1, 3, 0, !dbg !9
+  %6 = insertvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %5, i64 6, 3, 1, !dbg !9
+  %7 = insertvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %6, i64 6, 4, 0, !dbg !9
+  %8 = insertvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %7, i64 1, 4, 1, !dbg !9
+  %9 = call ptr @malloc(i64 ptrtoint (ptr getelementptr (double, ptr null, i64 6) to i64)), !dbg !10
+  %10 = insertvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } undef, ptr %9, 0, !dbg !10
+  %11 = insertvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %10, ptr %9, 1, !dbg !10
+  %12 = insertvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %11, i64 0, 2, !dbg !10
+  %13 = insertvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %12, i64 6, 3, 0, !dbg !10
+  %14 = insertvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %13, i64 1, 3, 1, !dbg !10
+  %15 = insertvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %14, i64 1, 4, 0, !dbg !10
+  %16 = insertvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %15, i64 1, 4, 1, !dbg !10
+  %17 = extractvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %16, 1, !dbg !10
+  %18 = getelementptr double, ptr %17, i64 0, !dbg !10
+  store double 1.000000e+00, ptr %18, align 8, !dbg !10
+  %19 = extractvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %16, 1, !dbg !10
+  %20 = getelementptr double, ptr %19, i64 1, !dbg !10
+  store double 2.000000e+00, ptr %20, align 8, !dbg !10
+  %21 = extractvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %16, 1, !dbg !10
+  %22 = getelementptr double, ptr %21, i64 2, !dbg !10
+  store double 3.000000e+00, ptr %22, align 8, !dbg !10
+  %23 = extractvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %16, 1, !dbg !10
+  %24 = getelementptr double, ptr %23, i64 3, !dbg !10
+  store double 4.000000e+00, ptr %24, align 8, !dbg !10
+  %25 = extractvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %16, 1, !dbg !10
+  %26 = getelementptr double, ptr %25, i64 4, !dbg !10
+  store double 5.000000e+00, ptr %26, align 8, !dbg !10
+  %27 = extractvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %16, 1, !dbg !10
+  %28 = getelementptr double, ptr %27, i64 5, !dbg !10
+  store double 6.000000e+00, ptr %28, align 8, !dbg !10
+  %29 = extractvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %8, 1, !dbg !9
+  %30 = getelementptr double, ptr %29, i64 0, !dbg !9
+  store double 1.000000e+00, ptr %30, align 8, !dbg !9
+  %31 = extractvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %8, 1, !dbg !9
+  %32 = getelementptr double, ptr %31, i64 1, !dbg !9
+  store double 2.000000e+00, ptr %32, align 8, !dbg !9
+  %33 = extractvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %8, 1, !dbg !9
+  %34 = getelementptr double, ptr %33, i64 2, !dbg !9
+  store double 3.000000e+00, ptr %34, align 8, !dbg !9
+  %35 = extractvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %8, 1, !dbg !9
+  %36 = getelementptr double, ptr %35, i64 3, !dbg !9
+  store double 4.000000e+00, ptr %36, align 8, !dbg !9
+  %37 = extractvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %8, 1, !dbg !9
+  %38 = getelementptr double, ptr %37, i64 4, !dbg !9
+  store double 5.000000e+00, ptr %38, align 8, !dbg !9
+  %39 = extractvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %8, 1, !dbg !9
+  %40 = getelementptr double, ptr %39, i64 5, !dbg !9
+  store double 6.000000e+00, ptr %40, align 8, !dbg !9
+  %41 = call ptr @malloc(i64 ptrtoint (ptr getelementptr (double, ptr null, i64 36) to i64)), !dbg !11
+  %42 = insertvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } undef, ptr %41, 0, !dbg !11
+  %43 = insertvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %42, ptr %41, 1, !dbg !11
+  %44 = insertvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %43, i64 0, 2, !dbg !11
+  %45 = insertvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %44, i64 6, 3, 0, !dbg !11
+  %46 = insertvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %45, i64 6, 3, 1, !dbg !11
+  %47 = insertvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %46, i64 6, 4, 0, !dbg !11
+  %48 = insertvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %47, i64 1, 4, 1, !dbg !11
+  %49 = extractvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %16, 1, !dbg !11
+  %50 = getelementptr double, ptr %49, i64 0, !dbg !11
+  %51 = load double, ptr %50, align 8, !dbg !11
+  %52 = extractvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %8, 1, !dbg !11
+  %53 = getelementptr double, ptr %52, i64 0, !dbg !11
+  %54 = load double, ptr %53, align 8, !dbg !11
+  %55 = fmul double %51, %54, !dbg !11
+  %56 = fadd double %55, 0.000000e+00, !dbg !11
+  %57 = extractvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %48, 1, !dbg !11
+  %58 = getelementptr double, ptr %57, i64 0, !dbg !11
+  store double %56, ptr %58, align 8, !dbg !11
+  %59 = extractvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %16, 1, !dbg !11
+  %60 = getelementptr double, ptr %59, i64 0, !dbg !11
+  %61 = load double, ptr %60, align 8, !dbg !11
+  %62 = extractvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %8, 1, !dbg !11
+  %63 = getelementptr double, ptr %62, i64 1, !dbg !11
+  %64 = load double, ptr %63, align 8, !dbg !11
+  %65 = fmul double %61, %64, !dbg !11
+  %66 = fadd double %65, 0.000000e+00, !dbg !11
+  %67 = extractvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %48, 1, !dbg !11
+  %68 = getelementptr double, ptr %67, i64 1, !dbg !11
+  store double %66, ptr %68, align 8, !dbg !11
+  %69 = extractvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %16, 1, !dbg !11
+  %70 = getelementptr double, ptr %69, i64 0, !dbg !11
+  %71 = load double, ptr %70, align 8, !dbg !11
+  %72 = extractvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %8, 1, !dbg !11
+  %73 = getelementptr double, ptr %72, i64 2, !dbg !11
+  %74 = load double, ptr %73, align 8, !dbg !11
+  %75 = fmul double %71, %74, !dbg !11
+  %76 = fadd double %75, 0.000000e+00, !dbg !11
+  %77 = extractvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %48, 1, !dbg !11
+  %78 = getelementptr double, ptr %77, i64 2, !dbg !11
+  store double %76, ptr %78, align 8, !dbg !11
+  %79 = extractvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %16, 1, !dbg !11
+  %80 = getelementptr double, ptr %79, i64 0, !dbg !11
+  %81 = load double, ptr %80, align 8, !dbg !11
+  %82 = extractvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %8, 1, !dbg !11
+  %83 = getelementptr double, ptr %82, i64 3, !dbg !11
+  %84 = load double, ptr %83, align 8, !dbg !11
+  %85 = fmul double %81, %84, !dbg !11
+  %86 = fadd double %85, 0.000000e+00, !dbg !11
+  %87 = extractvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %48, 1, !dbg !11
+  %88 = getelementptr double, ptr %87, i64 3, !dbg !11
+  store double %86, ptr %88, align 8, !dbg !11
+  %89 = extractvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %16, 1, !dbg !11
+  %90 = getelementptr double, ptr %89, i64 0, !dbg !11
+  %91 = load double, ptr %90, align 8, !dbg !11
+  %92 = extractvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %8, 1, !dbg !11
+  %93 = getelementptr double, ptr %92, i64 4, !dbg !11
+  %94 = load double, ptr %93, align 8, !dbg !11
+  %95 = fmul double %91, %94, !dbg !11
+  %96 = fadd double %95, 0.000000e+00, !dbg !11
+  %97 = extractvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %48, 1, !dbg !11
+  %98 = getelementptr double, ptr %97, i64 4, !dbg !11
+  store double %96, ptr %98, align 8, !dbg !11
+  %99 = extractvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %16, 1, !dbg !11
+  %100 = getelementptr double, ptr %99, i64 0, !dbg !11
+  %101 = load double, ptr %100, align 8, !dbg !11
+  %102 = extractvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %8, 1, !dbg !11
+  %103 = getelementptr double, ptr %102, i64 5, !dbg !11
+  %104 = load double, ptr %103, align 8, !dbg !11
+  %105 = fmul double %101, %104, !dbg !11
+  %106 = fadd double %105, 0.000000e+00, !dbg !11
+  %107 = extractvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %48, 1, !dbg !11
+  %108 = getelementptr double, ptr %107, i64 5, !dbg !11
+  store double %106, ptr %108, align 8, !dbg !11
+  %109 = extractvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %16, 1, !dbg !11
+  %110 = getelementptr double, ptr %109, i64 1, !dbg !11
+  %111 = load double, ptr %110, align 8, !dbg !11
+  %112 = extractvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %8, 1, !dbg !11
+  %113 = getelementptr double, ptr %112, i64 0, !dbg !11
+  %114 = load double, ptr %113, align 8, !dbg !11
+  %115 = fmul double %111, %114, !dbg !11
+  %116 = fadd double %115, 0.000000e+00, !dbg !11
+  %117 = extractvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %48, 1, !dbg !11
+  %118 = getelementptr double, ptr %117, i64 6, !dbg !11
+  store double %116, ptr %118, align 8, !dbg !11
+  %119 = extractvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %16, 1, !dbg !11
+  %120 = getelementptr double, ptr %119, i64 1, !dbg !11
+  %121 = load double, ptr %120, align 8, !dbg !11
+  %122 = extractvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %8, 1, !dbg !11
+  %123 = getelementptr double, ptr %122, i64 1, !dbg !11
+  %124 = load double, ptr %123, align 8, !dbg !11
+  %125 = fmul double %121, %124, !dbg !11
+  %126 = fadd double %125, 0.000000e+00, !dbg !11
+  %127 = extractvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %48, 1, !dbg !11
+  %128 = getelementptr double, ptr %127, i64 7, !dbg !11
+  store double %126, ptr %128, align 8, !dbg !11
+  %129 = extractvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %16, 1, !dbg !11
+  %130 = getelementptr double, ptr %129, i64 1, !dbg !11
+  %131 = load double, ptr %130, align 8, !dbg !11
+  %132 = extractvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %8, 1, !dbg !11
+  %133 = getelementptr double, ptr %132, i64 2, !dbg !11
+  %134 = load double, ptr %133, align 8, !dbg !11
+  %135 = fmul double %131, %134, !dbg !11
+  %136 = fadd double %135, 0.000000e+00, !dbg !11
+  %137 = extractvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %48, 1, !dbg !11
+  %138 = getelementptr double, ptr %137, i64 8, !dbg !11
+  store double %136, ptr %138, align 8, !dbg !11
+  %139 = extractvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %16, 1, !dbg !11
+  %140 = getelementptr double, ptr %139, i64 1, !dbg !11
+  %141 = load double, ptr %140, align 8, !dbg !11
+  %142 = extractvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %8, 1, !dbg !11
+  %143 = getelementptr double, ptr %142, i64 3, !dbg !11
+  %144 = load double, ptr %143, align 8, !dbg !11
+  %145 = fmul double %141, %144, !dbg !11
+  %146 = fadd double %145, 0.000000e+00, !dbg !11
+  %147 = extractvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %48, 1, !dbg !11
+  %148 = getelementptr double, ptr %147, i64 9, !dbg !11
+  store double %146, ptr %148, align 8, !dbg !11
+  %149 = extractvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %16, 1, !dbg !11
+  %150 = getelementptr double, ptr %149, i64 1, !dbg !11
+  %151 = load double, ptr %150, align 8, !dbg !11
+  %152 = extractvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %8, 1, !dbg !11
+  %153 = getelementptr double, ptr %152, i64 4, !dbg !11
+  %154 = load double, ptr %153, align 8, !dbg !11
+  %155 = fmul double %151, %154, !dbg !11
+  %156 = fadd double %155, 0.000000e+00, !dbg !11
+  %157 = extractvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %48, 1, !dbg !11
+  %158 = getelementptr double, ptr %157, i64 10, !dbg !11
+  store double %156, ptr %158, align 8, !dbg !11
+  %159 = extractvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %16, 1, !dbg !11
+  %160 = getelementptr double, ptr %159, i64 1, !dbg !11
+  %161 = load double, ptr %160, align 8, !dbg !11
+  %162 = extractvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %8, 1, !dbg !11
+  %163 = getelementptr double, ptr %162, i64 5, !dbg !11
+  %164 = load double, ptr %163, align 8, !dbg !11
+  %165 = fmul double %161, %164, !dbg !11
+  %166 = fadd double %165, 0.000000e+00, !dbg !11
+  %167 = extractvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %48, 1, !dbg !11
+  %168 = getelementptr double, ptr %167, i64 11, !dbg !11
+  store double %166, ptr %168, align 8, !dbg !11
+  %169 = extractvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %16, 1, !dbg !11
+  %170 = getelementptr double, ptr %169, i64 2, !dbg !11
+  %171 = load double, ptr %170, align 8, !dbg !11
+  %172 = extractvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %8, 1, !dbg !11
+  %173 = getelementptr double, ptr %172, i64 0, !dbg !11
+  %174 = load double, ptr %173, align 8, !dbg !11
+  %175 = fmul double %171, %174, !dbg !11
+  %176 = fadd double %175, 0.000000e+00, !dbg !11
+  %177 = extractvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %48, 1, !dbg !11
+  %178 = getelementptr double, ptr %177, i64 12, !dbg !11
+  store double %176, ptr %178, align 8, !dbg !11
+  %179 = extractvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %16, 1, !dbg !11
+  %180 = getelementptr double, ptr %179, i64 2, !dbg !11
+  %181 = load double, ptr %180, align 8, !dbg !11
+  %182 = extractvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %8, 1, !dbg !11
+  %183 = getelementptr double, ptr %182, i64 1, !dbg !11
+  %184 = load double, ptr %183, align 8, !dbg !11
+  %185 = fmul double %181, %184, !dbg !11
+  %186 = fadd double %185, 0.000000e+00, !dbg !11
+  %187 = extractvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %48, 1, !dbg !11
+  %188 = getelementptr double, ptr %187, i64 13, !dbg !11
+  store double %186, ptr %188, align 8, !dbg !11
+  %189 = extractvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %16, 1, !dbg !11
+  %190 = getelementptr double, ptr %189, i64 2, !dbg !11
+  %191 = load double, ptr %190, align 8, !dbg !11
+  %192 = extractvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %8, 1, !dbg !11
+  %193 = getelementptr double, ptr %192, i64 2, !dbg !11
+  %194 = load double, ptr %193, align 8, !dbg !11
+  %195 = fmul double %191, %194, !dbg !11
+  %196 = fadd double %195, 0.000000e+00, !dbg !11
+  %197 = extractvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %48, 1, !dbg !11
+  %198 = getelementptr double, ptr %197, i64 14, !dbg !11
+  store double %196, ptr %198, align 8, !dbg !11
+  %199 = extractvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %16, 1, !dbg !11
+  %200 = getelementptr double, ptr %199, i64 2, !dbg !11
+  %201 = load double, ptr %200, align 8, !dbg !11
+  %202 = extractvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %8, 1, !dbg !11
+  %203 = getelementptr double, ptr %202, i64 3, !dbg !11
+  %204 = load double, ptr %203, align 8, !dbg !11
+  %205 = fmul double %201, %204, !dbg !11
+  %206 = fadd double %205, 0.000000e+00, !dbg !11
+  %207 = extractvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %48, 1, !dbg !11
+  %208 = getelementptr double, ptr %207, i64 15, !dbg !11
+  store double %206, ptr %208, align 8, !dbg !11
+  %209 = extractvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %16, 1, !dbg !11
+  %210 = getelementptr double, ptr %209, i64 2, !dbg !11
+  %211 = load double, ptr %210, align 8, !dbg !11
+  %212 = extractvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %8, 1, !dbg !11
+  %213 = getelementptr double, ptr %212, i64 4, !dbg !11
+  %214 = load double, ptr %213, align 8, !dbg !11
+  %215 = fmul double %211, %214, !dbg !11
+  %216 = fadd double %215, 0.000000e+00, !dbg !11
+  %217 = extractvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %48, 1, !dbg !11
+  %218 = getelementptr double, ptr %217, i64 16, !dbg !11
+  store double %216, ptr %218, align 8, !dbg !11
+  %219 = extractvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %16, 1, !dbg !11
+  %220 = getelementptr double, ptr %219, i64 2, !dbg !11
+  %221 = load double, ptr %220, align 8, !dbg !11
+  %222 = extractvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %8, 1, !dbg !11
+  %223 = getelementptr double, ptr %222, i64 5, !dbg !11
+  %224 = load double, ptr %223, align 8, !dbg !11
+  %225 = fmul double %221, %224, !dbg !11
+  %226 = fadd double %225, 0.000000e+00, !dbg !11
+  %227 = extractvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %48, 1, !dbg !11
+  %228 = getelementptr double, ptr %227, i64 17, !dbg !11
+  store double %226, ptr %228, align 8, !dbg !11
+  %229 = extractvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %16, 1, !dbg !11
+  %230 = getelementptr double, ptr %229, i64 3, !dbg !11
+  %231 = load double, ptr %230, align 8, !dbg !11
+  %232 = extractvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %8, 1, !dbg !11
+  %233 = getelementptr double, ptr %232, i64 0, !dbg !11
+  %234 = load double, ptr %233, align 8, !dbg !11
+  %235 = fmul double %231, %234, !dbg !11
+  %236 = fadd double %235, 0.000000e+00, !dbg !11
+  %237 = extractvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %48, 1, !dbg !11
+  %238 = getelementptr double, ptr %237, i64 18, !dbg !11
+  store double %236, ptr %238, align 8, !dbg !11
+  %239 = extractvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %16, 1, !dbg !11
+  %240 = getelementptr double, ptr %239, i64 3, !dbg !11
+  %241 = load double, ptr %240, align 8, !dbg !11
+  %242 = extractvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %8, 1, !dbg !11
+  %243 = getelementptr double, ptr %242, i64 1, !dbg !11
+  %244 = load double, ptr %243, align 8, !dbg !11
+  %245 = fmul double %241, %244, !dbg !11
+  %246 = fadd double %245, 0.000000e+00, !dbg !11
+  %247 = extractvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %48, 1, !dbg !11
+  %248 = getelementptr double, ptr %247, i64 19, !dbg !11
+  store double %246, ptr %248, align 8, !dbg !11
+  %249 = extractvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %16, 1, !dbg !11
+  %250 = getelementptr double, ptr %249, i64 3, !dbg !11
+  %251 = load double, ptr %250, align 8, !dbg !11
+  %252 = extractvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %8, 1, !dbg !11
+  %253 = getelementptr double, ptr %252, i64 2, !dbg !11
+  %254 = load double, ptr %253, align 8, !dbg !11
+  %255 = fmul double %251, %254, !dbg !11
+  %256 = fadd double %255, 0.000000e+00, !dbg !11
+  %257 = extractvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %48, 1, !dbg !11
+  %258 = getelementptr double, ptr %257, i64 20, !dbg !11
+  store double %256, ptr %258, align 8, !dbg !11
+  %259 = extractvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %16, 1, !dbg !11
+  %260 = getelementptr double, ptr %259, i64 3, !dbg !11
+  %261 = load double, ptr %260, align 8, !dbg !11
+  %262 = extractvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %8, 1, !dbg !11
+  %263 = getelementptr double, ptr %262, i64 3, !dbg !11
+  %264 = load double, ptr %263, align 8, !dbg !11
+  %265 = fmul double %261, %264, !dbg !11
+  %266 = fadd double %265, 0.000000e+00, !dbg !11
+  %267 = extractvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %48, 1, !dbg !11
+  %268 = getelementptr double, ptr %267, i64 21, !dbg !11
+  store double %266, ptr %268, align 8, !dbg !11
+  %269 = extractvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %16, 1, !dbg !11
+  %270 = getelementptr double, ptr %269, i64 3, !dbg !11
+  %271 = load double, ptr %270, align 8, !dbg !11
+  %272 = extractvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %8, 1, !dbg !11
+  %273 = getelementptr double, ptr %272, i64 4, !dbg !11
+  %274 = load double, ptr %273, align 8, !dbg !11
+  %275 = fmul double %271, %274, !dbg !11
+  %276 = fadd double %275, 0.000000e+00, !dbg !11
+  %277 = extractvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %48, 1, !dbg !11
+  %278 = getelementptr double, ptr %277, i64 22, !dbg !11
+  store double %276, ptr %278, align 8, !dbg !11
+  %279 = extractvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %16, 1, !dbg !11
+  %280 = getelementptr double, ptr %279, i64 3, !dbg !11
+  %281 = load double, ptr %280, align 8, !dbg !11
+  %282 = extractvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %8, 1, !dbg !11
+  %283 = getelementptr double, ptr %282, i64 5, !dbg !11
+  %284 = load double, ptr %283, align 8, !dbg !11
+  %285 = fmul double %281, %284, !dbg !11
+  %286 = fadd double %285, 0.000000e+00, !dbg !11
+  %287 = extractvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %48, 1, !dbg !11
+  %288 = getelementptr double, ptr %287, i64 23, !dbg !11
+  store double %286, ptr %288, align 8, !dbg !11
+  %289 = extractvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %16, 1, !dbg !11
+  %290 = getelementptr double, ptr %289, i64 4, !dbg !11
+  %291 = load double, ptr %290, align 8, !dbg !11
+  %292 = extractvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %8, 1, !dbg !11
+  %293 = getelementptr double, ptr %292, i64 0, !dbg !11
+  %294 = load double, ptr %293, align 8, !dbg !11
+  %295 = fmul double %291, %294, !dbg !11
+  %296 = fadd double %295, 0.000000e+00, !dbg !11
+  %297 = extractvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %48, 1, !dbg !11
+  %298 = getelementptr double, ptr %297, i64 24, !dbg !11
+  store double %296, ptr %298, align 8, !dbg !11
+  %299 = extractvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %16, 1, !dbg !11
+  %300 = getelementptr double, ptr %299, i64 4, !dbg !11
+  %301 = load double, ptr %300, align 8, !dbg !11
+  %302 = extractvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %8, 1, !dbg !11
+  %303 = getelementptr double, ptr %302, i64 1, !dbg !11
+  %304 = load double, ptr %303, align 8, !dbg !11
+  %305 = fmul double %301, %304, !dbg !11
+  %306 = fadd double %305, 0.000000e+00, !dbg !11
+  %307 = extractvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %48, 1, !dbg !11
+  %308 = getelementptr double, ptr %307, i64 25, !dbg !11
+  store double %306, ptr %308, align 8, !dbg !11
+  %309 = extractvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %16, 1, !dbg !11
+  %310 = getelementptr double, ptr %309, i64 4, !dbg !11
+  %311 = load double, ptr %310, align 8, !dbg !11
+  %312 = extractvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %8, 1, !dbg !11
+  %313 = getelementptr double, ptr %312, i64 2, !dbg !11
+  %314 = load double, ptr %313, align 8, !dbg !11
+  %315 = fmul double %311, %314, !dbg !11
+  %316 = fadd double %315, 0.000000e+00, !dbg !11
+  %317 = extractvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %48, 1, !dbg !11
+  %318 = getelementptr double, ptr %317, i64 26, !dbg !11
+  store double %316, ptr %318, align 8, !dbg !11
+  %319 = extractvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %16, 1, !dbg !11
+  %320 = getelementptr double, ptr %319, i64 4, !dbg !11
+  %321 = load double, ptr %320, align 8, !dbg !11
+  %322 = extractvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %8, 1, !dbg !11
+  %323 = getelementptr double, ptr %322, i64 3, !dbg !11
+  %324 = load double, ptr %323, align 8, !dbg !11
+  %325 = fmul double %321, %324, !dbg !11
+  %326 = fadd double %325, 0.000000e+00, !dbg !11
+  %327 = extractvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %48, 1, !dbg !11
+  %328 = getelementptr double, ptr %327, i64 27, !dbg !11
+  store double %326, ptr %328, align 8, !dbg !11
+  %329 = extractvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %16, 1, !dbg !11
+  %330 = getelementptr double, ptr %329, i64 4, !dbg !11
+  %331 = load double, ptr %330, align 8, !dbg !11
+  %332 = extractvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %8, 1, !dbg !11
+  %333 = getelementptr double, ptr %332, i64 4, !dbg !11
+  %334 = load double, ptr %333, align 8, !dbg !11
+  %335 = fmul double %331, %334, !dbg !11
+  %336 = fadd double %335, 0.000000e+00, !dbg !11
+  %337 = extractvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %48, 1, !dbg !11
+  %338 = getelementptr double, ptr %337, i64 28, !dbg !11
+  store double %336, ptr %338, align 8, !dbg !11
+  %339 = extractvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %16, 1, !dbg !11
+  %340 = getelementptr double, ptr %339, i64 4, !dbg !11
+  %341 = load double, ptr %340, align 8, !dbg !11
+  %342 = extractvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %8, 1, !dbg !11
+  %343 = getelementptr double, ptr %342, i64 5, !dbg !11
+  %344 = load double, ptr %343, align 8, !dbg !11
+  %345 = fmul double %341, %344, !dbg !11
+  %346 = fadd double %345, 0.000000e+00, !dbg !11
+  %347 = extractvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %48, 1, !dbg !11
+  %348 = getelementptr double, ptr %347, i64 29, !dbg !11
+  store double %346, ptr %348, align 8, !dbg !11
+  %349 = extractvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %16, 1, !dbg !11
+  %350 = getelementptr double, ptr %349, i64 5, !dbg !11
+  %351 = load double, ptr %350, align 8, !dbg !11
+  %352 = extractvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %8, 1, !dbg !11
+  %353 = getelementptr double, ptr %352, i64 0, !dbg !11
+  %354 = load double, ptr %353, align 8, !dbg !11
+  %355 = fmul double %351, %354, !dbg !11
+  %356 = fadd double %355, 0.000000e+00, !dbg !11
+  %357 = extractvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %48, 1, !dbg !11
+  %358 = getelementptr double, ptr %357, i64 30, !dbg !11
+  store double %356, ptr %358, align 8, !dbg !11
+  %359 = extractvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %16, 1, !dbg !11
+  %360 = getelementptr double, ptr %359, i64 5, !dbg !11
+  %361 = load double, ptr %360, align 8, !dbg !11
+  %362 = extractvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %8, 1, !dbg !11
+  %363 = getelementptr double, ptr %362, i64 1, !dbg !11
+  %364 = load double, ptr %363, align 8, !dbg !11
+  %365 = fmul double %361, %364, !dbg !11
+  %366 = fadd double %365, 0.000000e+00, !dbg !11
+  %367 = extractvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %48, 1, !dbg !11
+  %368 = getelementptr double, ptr %367, i64 31, !dbg !11
+  store double %366, ptr %368, align 8, !dbg !11
+  %369 = extractvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %16, 1, !dbg !11
+  %370 = getelementptr double, ptr %369, i64 5, !dbg !11
+  %371 = load double, ptr %370, align 8, !dbg !11
+  %372 = extractvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %8, 1, !dbg !11
+  %373 = getelementptr double, ptr %372, i64 2, !dbg !11
+  %374 = load double, ptr %373, align 8, !dbg !11
+  %375 = fmul double %371, %374, !dbg !11
+  %376 = fadd double %375, 0.000000e+00, !dbg !11
+  %377 = extractvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %48, 1, !dbg !11
+  %378 = getelementptr double, ptr %377, i64 32, !dbg !11
+  store double %376, ptr %378, align 8, !dbg !11
+  %379 = extractvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %16, 1, !dbg !11
+  %380 = getelementptr double, ptr %379, i64 5, !dbg !11
+  %381 = load double, ptr %380, align 8, !dbg !11
+  %382 = extractvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %8, 1, !dbg !11
+  %383 = getelementptr double, ptr %382, i64 3, !dbg !11
+  %384 = load double, ptr %383, align 8, !dbg !11
+  %385 = fmul double %381, %384, !dbg !11
+  %386 = fadd double %385, 0.000000e+00, !dbg !11
+  %387 = extractvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %48, 1, !dbg !11
+  %388 = getelementptr double, ptr %387, i64 33, !dbg !11
+  store double %386, ptr %388, align 8, !dbg !11
+  %389 = extractvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %16, 1, !dbg !11
+  %390 = getelementptr double, ptr %389, i64 5, !dbg !11
+  %391 = load double, ptr %390, align 8, !dbg !11
+  %392 = extractvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %8, 1, !dbg !11
+  %393 = getelementptr double, ptr %392, i64 4, !dbg !11
+  %394 = load double, ptr %393, align 8, !dbg !11
+  %395 = fmul double %391, %394, !dbg !11
+  %396 = fadd double %395, 0.000000e+00, !dbg !11
+  %397 = extractvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %48, 1, !dbg !11
+  %398 = getelementptr double, ptr %397, i64 34, !dbg !11
+  store double %396, ptr %398, align 8, !dbg !11
+  %399 = extractvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %16, 1, !dbg !11
+  %400 = getelementptr double, ptr %399, i64 5, !dbg !11
+  %401 = load double, ptr %400, align 8, !dbg !11
+  %402 = extractvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %8, 1, !dbg !11
+  %403 = getelementptr double, ptr %402, i64 5, !dbg !11
+  %404 = load double, ptr %403, align 8, !dbg !11
+  %405 = fmul double %401, %404, !dbg !11
+  %406 = fadd double %405, 0.000000e+00, !dbg !11
+  %407 = extractvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %48, 1, !dbg !11
+  %408 = getelementptr double, ptr %407, i64 35, !dbg !11
+  store double %406, ptr %408, align 8, !dbg !11
+  br label %409, !dbg !12
+
+409:                                              ; preds = %424, %0
+  %410 = phi i64 [ 0, %0 ], [ %426, %424 ]
+  %411 = icmp slt i64 %410, 6, !dbg !12
+  br i1 %411, label %412, label %427, !dbg !12
+
+412:                                              ; preds = %409
+  br label %413, !dbg !12
+
+413:                                              ; preds = %416, %412
+  %414 = phi i64 [ 0, %412 ], [ %423, %416 ]
+  %415 = icmp slt i64 %414, 6, !dbg !12
+  br i1 %415, label %416, label %424, !dbg !12
+
+416:                                              ; preds = %413
+  %417 = extractvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %48, 1, !dbg !12
+  %418 = mul i64 %410, 6, !dbg !12
+  %419 = add i64 %418, %414, !dbg !12
+  %420 = getelementptr double, ptr %417, i64 %419, !dbg !12
+  %421 = load double, ptr %420, align 8, !dbg !12
+  %422 = call i32 (ptr, ...) @printf(ptr @frmt_spec, double %421), !dbg !12
+  %423 = add i64 %414, 1, !dbg !12
+  br label %413, !dbg !12
+
+424:                                              ; preds = %413
+  %425 = call i32 (ptr, ...) @printf(ptr @nl), !dbg !12
+  %426 = add i64 %410, 1, !dbg !12
+  br label %409, !dbg !12
+
+427:                                              ; preds = %409
+  %428 = extractvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %16, 0, !dbg !10
+  call void @free(ptr %428), !dbg !10
+  %429 = extractvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %8, 0, !dbg !9
+  call void @free(ptr %429), !dbg !9
+  ret void, !dbg !13
+}
+
+!llvm.module.flags = !{!0}
+!llvm.dbg.cu = !{!1}
+
+!0 = !{i32 2, !"Debug Info Version", i32 3}
+!1 = distinct !DICompileUnit(language: DW_LANG_C, file: !2, producer: "MLIR", isOptimized: true, runtimeVersion: 0, emissionKind: LineTablesOnly)
+!2 = !DIFile(filename: "matmul2.mlir", directory: "")
+!3 = !DISubprogram(name: "free", linkageName: "free", scope: !2, file: !2, line: 1, type: !4, scopeLine: 1, spFlags: DISPFlagOptimized)
+!4 = !DISubroutineType(cc: DW_CC_normal, types: !5)
+!5 = !{}
+!6 = !DISubprogram(name: "printf", linkageName: "printf", scope: !2, file: !2, line: 1, type: !4, scopeLine: 1, spFlags: DISPFlagOptimized)
+!7 = !DISubprogram(name: "malloc", linkageName: "malloc", scope: !2, file: !2, line: 1, type: !4, scopeLine: 1, spFlags: DISPFlagOptimized)
+!8 = distinct !DISubprogram(name: "main", linkageName: "main", scope: !2, file: !2, line: 2, type: !4, scopeLine: 3, spFlags: DISPFlagDefinition | DISPFlagOptimized, unit: !1)
+!9 = !DILocation(line: 6, column: 10, scope: !8)
+!10 = !DILocation(line: 4, column: 10, scope: !8)
+!11 = !DILocation(line: 7, column: 10, scope: !8)
+!12 = !DILocation(line: 8, column: 5, scope: !8)
+!13 = !DILocation(line: 9, column: 5, scope: !8)
+
diff --git a/mlir/test/Examples/Toy/Matmul/matmul2.mlir b/mlir/test/Examples/Toy/Matmul/matmul2.mlir
new file mode 100644
index 0000000000000..e4145c8bb5dfa
--- /dev/null
+++ b/mlir/test/Examples/Toy/Matmul/matmul2.mlir
@@ -0,0 +1,11 @@
+module {
+  toy.func @main() {
+    %0 = toy.constant dense<[1.000000e+00, 2.000000e+00, 3.000000e+00, 4.000000e+00, 5.000000e+00, 6.000000e+00]> : tensor<6xf64>
+    %1 = toy.reshape(%0 : tensor<6xf64>) to tensor<6x1xf64>
+    %2 = toy.constant dense<[1.000000e+00, 2.000000e+00, 3.000000e+00, 4.000000e+00, 5.000000e+00, 6.000000e+00]> : tensor<6xf64>
+    %3 = toy.reshape(%2 : tensor<6xf64>) to tensor<1x6xf64>
+    %4 = toy.matmul %1, %3 : tensor<6x1xf64>, tensor<1x6xf64> -> tensor<*xf64>
+    toy.print %4 : tensor<*xf64>
+    toy.return
+  }
+}
diff --git a/mlir/test/Examples/Toy/Matmul/matmul2.toy b/mlir/test/Examples/Toy/Matmul/matmul2.toy
new file mode 100644
index 0000000000000..6c37cc1af1f6a
--- /dev/null
+++ b/mlir/test/Examples/Toy/Matmul/matmul2.toy
@@ -0,0 +1,6 @@
+def main() {
+  var a<6, 1> = [1, 2, 3, 4, 5, 6];
+  var b<1, 6> = [1, 2, 3, 4, 5, 6];
+  var c = matmul(a,b);
+  print(c);
+}
diff --git a/mlir/test/Examples/Toy/Matmul/matmul3-affine.mlir b/mlir/test/Examples/Toy/Matmul/matmul3-affine.mlir
new file mode 100644
index 0000000000000..46ac739c241b4
--- /dev/null
+++ b/mlir/test/Examples/Toy/Matmul/matmul3-affine.mlir
@@ -0,0 +1,55 @@
+module {
+  func.func @main() {
+    %cst = arith.constant 0.000000e+00 : f64
+    %cst_0 = arith.constant 6.000000e+00 : f64
+    %cst_1 = arith.constant 5.000000e+00 : f64
+    %cst_2 = arith.constant 4.000000e+00 : f64
+    %cst_3 = arith.constant 3.000000e+00 : f64
+    %cst_4 = arith.constant 2.000000e+00 : f64
+    %cst_5 = arith.constant 1.000000e+00 : f64
+    %alloc = memref.alloc() : memref<1x6xf64>
+    %alloc_6 = memref.alloc() : memref<6x1xf64>
+    affine.store %cst_5, %alloc_6[0, 0] : memref<6x1xf64>
+    affine.store %cst_4, %alloc_6[1, 0] : memref<6x1xf64>
+    affine.store %cst_3, %alloc_6[2, 0] : memref<6x1xf64>
+    affine.store %cst_2, %alloc_6[3, 0] : memref<6x1xf64>
+    affine.store %cst_1, %alloc_6[4, 0] : memref<6x1xf64>
+    affine.store %cst_0, %alloc_6[5, 0] : memref<6x1xf64>
+    affine.store %cst_5, %alloc[0, 0] : memref<1x6xf64>
+    affine.store %cst_4, %alloc[0, 1] : memref<1x6xf64>
+    affine.store %cst_3, %alloc[0, 2] : memref<1x6xf64>
+    affine.store %cst_2, %alloc[0, 3] : memref<1x6xf64>
+    affine.store %cst_1, %alloc[0, 4] : memref<1x6xf64>
+    affine.store %cst_0, %alloc[0, 5] : memref<1x6xf64>
+    %alloc_7 = memref.alloc() : memref<1x1xf64>
+    %0 = affine.load %alloc[0, 0] : memref<1x6xf64>
+    %1 = affine.load %alloc_6[0, 0] : memref<6x1xf64>
+    %2 = arith.mulf %0, %1 : f64
+    %3 = arith.addf %2, %cst : f64
+    %4 = affine.load %alloc[0, 1] : memref<1x6xf64>
+    %5 = affine.load %alloc_6[1, 0] : memref<6x1xf64>
+    %6 = arith.mulf %4, %5 : f64
+    %7 = arith.addf %3, %6 : f64
+    %8 = affine.load %alloc[0, 2] : memref<1x6xf64>
+    %9 = affine.load %alloc_6[2, 0] : memref<6x1xf64>
+    %10 = arith.mulf %8, %9 : f64
+    %11 = arith.addf %7, %10 : f64
+    %12 = affine.load %alloc[0, 3] : memref<1x6xf64>
+    %13 = affine.load %alloc_6[3, 0] : memref<6x1xf64>
+    %14 = arith.mulf %12, %13 : f64
+    %15 = arith.addf %11, %14 : f64
+    %16 = affine.load %alloc[0, 4] : memref<1x6xf64>
+    %17 = affine.load %alloc_6[4, 0] : memref<6x1xf64>
+    %18 = arith.mulf %16, %17 : f64
+    %19 = arith.addf %15, %18 : f64
+    %20 = affine.load %alloc[0, 5] : memref<1x6xf64>
+    %21 = affine.load %alloc_6[5, 0] : memref<6x1xf64>
+    %22 = arith.mulf %20, %21 : f64
+    %23 = arith.addf %19, %22 : f64
+    affine.store %23, %alloc_7[0, 0] : memref<1x1xf64>
+    toy.print %alloc_7 : memref<1x1xf64>
+    memref.dealloc %alloc_6 : memref<6x1xf64>
+    memref.dealloc %alloc : memref<1x6xf64>
+    return
+  }
+}
diff --git a/mlir/test/Examples/Toy/Matmul/matmul3.llvm b/mlir/test/Examples/Toy/Matmul/matmul3.llvm
new file mode 100644
index 0000000000000..d61f8da3b2101
--- /dev/null
+++ b/mlir/test/Examples/Toy/Matmul/matmul3.llvm
@@ -0,0 +1,181 @@
+; ModuleID = 'LLVMDialectModule'
+source_filename = "LLVMDialectModule"
+target datalayout = "e-m:e-p270:32:32-p271:32:32-p272:64:64-i64:64-i128:128-f80:128-n8:16:32:64-S128"
+target triple = "x86_64-unknown-linux-gnu"
+
+@nl = internal constant [2 x i8] c"\0A\00"
+@frmt_spec = internal constant [4 x i8] c"%f \00"
+
+declare !dbg !3 void @free(ptr)
+
+declare !dbg !6 i32 @printf(ptr, ...)
+
+declare !dbg !7 ptr @malloc(i64)
+
+define void @main() !dbg !8 {
+  %1 = call ptr @malloc(i64 ptrtoint (ptr getelementptr (double, ptr null, i64 6) to i64)), !dbg !9
+  %2 = insertvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } undef, ptr %1, 0, !dbg !9
+  %3 = insertvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %2, ptr %1, 1, !dbg !9
+  %4 = insertvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %3, i64 0, 2, !dbg !9
+  %5 = insertvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %4, i64 1, 3, 0, !dbg !9
+  %6 = insertvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %5, i64 6, 3, 1, !dbg !9
+  %7 = insertvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %6, i64 6, 4, 0, !dbg !9
+  %8 = insertvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %7, i64 1, 4, 1, !dbg !9
+  %9 = call ptr @malloc(i64 ptrtoint (ptr getelementptr (double, ptr null, i64 6) to i64)), !dbg !10
+  %10 = insertvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } undef, ptr %9, 0, !dbg !10
+  %11 = insertvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %10, ptr %9, 1, !dbg !10
+  %12 = insertvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %11, i64 0, 2, !dbg !10
+  %13 = insertvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %12, i64 6, 3, 0, !dbg !10
+  %14 = insertvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %13, i64 1, 3, 1, !dbg !10
+  %15 = insertvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %14, i64 1, 4, 0, !dbg !10
+  %16 = insertvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %15, i64 1, 4, 1, !dbg !10
+  %17 = extractvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %16, 1, !dbg !10
+  %18 = getelementptr double, ptr %17, i64 0, !dbg !10
+  store double 1.000000e+00, ptr %18, align 8, !dbg !10
+  %19 = extractvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %16, 1, !dbg !10
+  %20 = getelementptr double, ptr %19, i64 1, !dbg !10
+  store double 2.000000e+00, ptr %20, align 8, !dbg !10
+  %21 = extractvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %16, 1, !dbg !10
+  %22 = getelementptr double, ptr %21, i64 2, !dbg !10
+  store double 3.000000e+00, ptr %22, align 8, !dbg !10
+  %23 = extractvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %16, 1, !dbg !10
+  %24 = getelementptr double, ptr %23, i64 3, !dbg !10
+  store double 4.000000e+00, ptr %24, align 8, !dbg !10
+  %25 = extractvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %16, 1, !dbg !10
+  %26 = getelementptr double, ptr %25, i64 4, !dbg !10
+  store double 5.000000e+00, ptr %26, align 8, !dbg !10
+  %27 = extractvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %16, 1, !dbg !10
+  %28 = getelementptr double, ptr %27, i64 5, !dbg !10
+  store double 6.000000e+00, ptr %28, align 8, !dbg !10
+  %29 = extractvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %8, 1, !dbg !9
+  %30 = getelementptr double, ptr %29, i64 0, !dbg !9
+  store double 1.000000e+00, ptr %30, align 8, !dbg !9
+  %31 = extractvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %8, 1, !dbg !9
+  %32 = getelementptr double, ptr %31, i64 1, !dbg !9
+  store double 2.000000e+00, ptr %32, align 8, !dbg !9
+  %33 = extractvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %8, 1, !dbg !9
+  %34 = getelementptr double, ptr %33, i64 2, !dbg !9
+  store double 3.000000e+00, ptr %34, align 8, !dbg !9
+  %35 = extractvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %8, 1, !dbg !9
+  %36 = getelementptr double, ptr %35, i64 3, !dbg !9
+  store double 4.000000e+00, ptr %36, align 8, !dbg !9
+  %37 = extractvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %8, 1, !dbg !9
+  %38 = getelementptr double, ptr %37, i64 4, !dbg !9
+  store double 5.000000e+00, ptr %38, align 8, !dbg !9
+  %39 = extractvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %8, 1, !dbg !9
+  %40 = getelementptr double, ptr %39, i64 5, !dbg !9
+  store double 6.000000e+00, ptr %40, align 8, !dbg !9
+  %41 = call ptr @malloc(i64 ptrtoint (ptr getelementptr (double, ptr null, i64 1) to i64)), !dbg !11
+  %42 = insertvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } undef, ptr %41, 0, !dbg !11
+  %43 = insertvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %42, ptr %41, 1, !dbg !11
+  %44 = insertvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %43, i64 0, 2, !dbg !11
+  %45 = insertvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %44, i64 1, 3, 0, !dbg !11
+  %46 = insertvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %45, i64 1, 3, 1, !dbg !11
+  %47 = insertvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %46, i64 1, 4, 0, !dbg !11
+  %48 = insertvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %47, i64 1, 4, 1, !dbg !11
+  %49 = extractvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %8, 1, !dbg !11
+  %50 = getelementptr double, ptr %49, i64 0, !dbg !11
+  %51 = load double, ptr %50, align 8, !dbg !11
+  %52 = extractvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %16, 1, !dbg !11
+  %53 = getelementptr double, ptr %52, i64 0, !dbg !11
+  %54 = load double, ptr %53, align 8, !dbg !11
+  %55 = fmul double %51, %54, !dbg !11
+  %56 = fadd double %55, 0.000000e+00, !dbg !11
+  %57 = extractvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %8, 1, !dbg !11
+  %58 = getelementptr double, ptr %57, i64 1, !dbg !11
+  %59 = load double, ptr %58, align 8, !dbg !11
+  %60 = extractvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %16, 1, !dbg !11
+  %61 = getelementptr double, ptr %60, i64 1, !dbg !11
+  %62 = load double, ptr %61, align 8, !dbg !11
+  %63 = fmul double %59, %62, !dbg !11
+  %64 = fadd double %56, %63, !dbg !11
+  %65 = extractvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %8, 1, !dbg !11
+  %66 = getelementptr double, ptr %65, i64 2, !dbg !11
+  %67 = load double, ptr %66, align 8, !dbg !11
+  %68 = extractvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %16, 1, !dbg !11
+  %69 = getelementptr double, ptr %68, i64 2, !dbg !11
+  %70 = load double, ptr %69, align 8, !dbg !11
+  %71 = fmul double %67, %70, !dbg !11
+  %72 = fadd double %64, %71, !dbg !11
+  %73 = extractvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %8, 1, !dbg !11
+  %74 = getelementptr double, ptr %73, i64 3, !dbg !11
+  %75 = load double, ptr %74, align 8, !dbg !11
+  %76 = extractvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %16, 1, !dbg !11
+  %77 = getelementptr double, ptr %76, i64 3, !dbg !11
+  %78 = load double, ptr %77, align 8, !dbg !11
+  %79 = fmul double %75, %78, !dbg !11
+  %80 = fadd double %72, %79, !dbg !11
+  %81 = extractvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %8, 1, !dbg !11
+  %82 = getelementptr double, ptr %81, i64 4, !dbg !11
+  %83 = load double, ptr %82, align 8, !dbg !11
+  %84 = extractvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %16, 1, !dbg !11
+  %85 = getelementptr double, ptr %84, i64 4, !dbg !11
+  %86 = load double, ptr %85, align 8, !dbg !11
+  %87 = fmul double %83, %86, !dbg !11
+  %88 = fadd double %80, %87, !dbg !11
+  %89 = extractvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %8, 1, !dbg !11
+  %90 = getelementptr double, ptr %89, i64 5, !dbg !11
+  %91 = load double, ptr %90, align 8, !dbg !11
+  %92 = extractvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %16, 1, !dbg !11
+  %93 = getelementptr double, ptr %92, i64 5, !dbg !11
+  %94 = load double, ptr %93, align 8, !dbg !11
+  %95 = fmul double %91, %94, !dbg !11
+  %96 = fadd double %88, %95, !dbg !11
+  %97 = extractvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %48, 1, !dbg !11
+  %98 = getelementptr double, ptr %97, i64 0, !dbg !11
+  store double %96, ptr %98, align 8, !dbg !11
+  br label %99, !dbg !12
+
+99:                                               ; preds = %113, %0
+  %100 = phi i64 [ 0, %0 ], [ %115, %113 ]
+  %101 = icmp slt i64 %100, 1, !dbg !12
+  br i1 %101, label %102, label %116, !dbg !12
+
+102:                                              ; preds = %99
+  br label %103, !dbg !12
+
+103:                                              ; preds = %106, %102
+  %104 = phi i64 [ 0, %102 ], [ %112, %106 ]
+  %105 = icmp slt i64 %104, 1, !dbg !12
+  br i1 %105, label %106, label %113, !dbg !12
+
+106:                                              ; preds = %103
+  %107 = extractvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %48, 1, !dbg !12
+  %108 = add i64 %100, %104, !dbg !12
+  %109 = getelementptr double, ptr %107, i64 %108, !dbg !12
+  %110 = load double, ptr %109, align 8, !dbg !12
+  %111 = call i32 (ptr, ...) @printf(ptr @frmt_spec, double %110), !dbg !12
+  %112 = add i64 %104, 1, !dbg !12
+  br label %103, !dbg !12
+
+113:                                              ; preds = %103
+  %114 = call i32 (ptr, ...) @printf(ptr @nl), !dbg !12
+  %115 = add i64 %100, 1, !dbg !12
+  br label %99, !dbg !12
+
+116:                                              ; preds = %99
+  %117 = extractvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %16, 0, !dbg !10
+  call void @free(ptr %117), !dbg !10
+  %118 = extractvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %8, 0, !dbg !9
+  call void @free(ptr %118), !dbg !9
+  ret void, !dbg !13
+}
+
+!llvm.module.flags = !{!0}
+!llvm.dbg.cu = !{!1}
+
+!0 = !{i32 2, !"Debug Info Version", i32 3}
+!1 = distinct !DICompileUnit(language: DW_LANG_C, file: !2, producer: "MLIR", isOptimized: true, runtimeVersion: 0, emissionKind: LineTablesOnly)
+!2 = !DIFile(filename: "matmul3.mlir", directory: "")
+!3 = !DISubprogram(name: "free", linkageName: "free", scope: !2, file: !2, line: 1, type: !4, scopeLine: 1, spFlags: DISPFlagOptimized)
+!4 = !DISubroutineType(cc: DW_CC_normal, types: !5)
+!5 = !{}
+!6 = !DISubprogram(name: "printf", linkageName: "printf", scope: !2, file: !2, line: 1, type: !4, scopeLine: 1, spFlags: DISPFlagOptimized)
+!7 = !DISubprogram(name: "malloc", linkageName: "malloc", scope: !2, file: !2, line: 1, type: !4, scopeLine: 1, spFlags: DISPFlagOptimized)
+!8 = distinct !DISubprogram(name: "main", linkageName: "main", scope: !2, file: !2, line: 2, type: !4, scopeLine: 3, spFlags: DISPFlagDefinition | DISPFlagOptimized, unit: !1)
+!9 = !DILocation(line: 6, column: 10, scope: !8)
+!10 = !DILocation(line: 4, column: 10, scope: !8)
+!11 = !DILocation(line: 7, column: 10, scope: !8)
+!12 = !DILocation(line: 8, column: 5, scope: !8)
+!13 = !DILocation(line: 9, column: 5, scope: !8)
+
diff --git a/mlir/test/Examples/Toy/Matmul/matmul3.mlir b/mlir/test/Examples/Toy/Matmul/matmul3.mlir
new file mode 100644
index 0000000000000..5e8fd5eba2398
--- /dev/null
+++ b/mlir/test/Examples/Toy/Matmul/matmul3.mlir
@@ -0,0 +1,11 @@
+module {
+  toy.func @main() {
+    %0 = toy.constant dense<[1.000000e+00, 2.000000e+00, 3.000000e+00, 4.000000e+00, 5.000000e+00, 6.000000e+00]> : tensor<6xf64>
+    %1 = toy.reshape(%0 : tensor<6xf64>) to tensor<6x1xf64>
+    %2 = toy.constant dense<[1.000000e+00, 2.000000e+00, 3.000000e+00, 4.000000e+00, 5.000000e+00, 6.000000e+00]> : tensor<6xf64>
+    %3 = toy.reshape(%2 : tensor<6xf64>) to tensor<1x6xf64>
+    %4 = toy.matmul %3, %1 : tensor<1x6xf64>, tensor<6x1xf64> -> tensor<*xf64>
+    toy.print %4 : tensor<*xf64>
+    toy.return
+  }
+}
diff --git a/mlir/test/Examples/Toy/Matmul/matmul3.toy b/mlir/test/Examples/Toy/Matmul/matmul3.toy
new file mode 100644
index 0000000000000..d268735176e7f
--- /dev/null
+++ b/mlir/test/Examples/Toy/Matmul/matmul3.toy
@@ -0,0 +1,7 @@
+def main() {
+  var a<6, 1> = [1, 2, 3, 4, 5, 6];
+  var b<1, 6> = [1, 2, 3, 4, 5, 6];
+  var c = matmul(b,a);
+  print(c);
+}
+

>From 8c9165319f8c9bb4e90fada68d90077cb60ba30f Mon Sep 17 00:00:00 2001
From: Anirudh-Sathish <anirudhsathish at gmail.com>
Date: Fri, 21 Jun 2024 04:56:45 +0000
Subject: [PATCH 04/13] Feat: Convert lowering of toy matmul to affine loops

---
 .../toy/Ch7/mlir/LowerToAffineLoops.cpp       | 64 +++++++++----------
 1 file changed, 32 insertions(+), 32 deletions(-)

diff --git a/mlir/examples/toy/Ch7/mlir/LowerToAffineLoops.cpp b/mlir/examples/toy/Ch7/mlir/LowerToAffineLoops.cpp
index 5c3e6a552855b..e5cd3168e77cb 100644
--- a/mlir/examples/toy/Ch7/mlir/LowerToAffineLoops.cpp
+++ b/mlir/examples/toy/Ch7/mlir/LowerToAffineLoops.cpp
@@ -331,46 +331,46 @@ struct MatmulOpLowering : public ConversionPattern {
     toy::MatmulOpAdaptor matmulAdaptor(operands);
     Value lhs = matmulAdaptor.getLhs();
     Value rhs = matmulAdaptor.getRhs();
-    
+
     auto lhsType = dyn_cast<MemRefType>(lhs.getType());
     auto rhsType = dyn_cast<MemRefType>(rhs.getType());
-    if (!lhsType || !rhsType) {
-      return failure();
-    }
+    if (!lhsType || !rhsType) return failure();
+
     int64_t M = lhsType.getShape()[0];
     int64_t N = rhsType.getShape()[1];
     int64_t K = lhsType.getShape()[1];
     auto elementType = lhsType.getElementType();
     auto resultType = MemRefType::get({M, N}, elementType);
     Value result = rewriter.create<memref::AllocOp>(loc, resultType);
-     for (int64_t i = 0; i < M; ++i) {
-      for (int64_t j = 0; j < N; ++j) {
-        // Initialize the sum to zero.
-        Value sum = rewriter.create<arith::ConstantOp>(loc, rewriter.getF64Type(), rewriter.getF64FloatAttr(0.0));
-
-        for (int64_t k = 0; k < K; ++k) {
-          // Load lhs[i, k] and rhs[k, j].
-          Value lhsVal = rewriter.create<affine::AffineLoadOp>(loc, lhs, ValueRange{
-            rewriter.create<arith::ConstantIndexOp>(loc, i),
-            rewriter.create<arith::ConstantIndexOp>(loc, k)
-          });
-          Value rhsVal = rewriter.create<affine::AffineLoadOp>(loc, rhs, ValueRange{
-            rewriter.create<arith::ConstantIndexOp>(loc, k),
-            rewriter.create<arith::ConstantIndexOp>(loc, j)
-          });
-
-          // Perform the multiplication and accumulate the result.
-          Value product = rewriter.create<arith::MulFOp>(loc, lhsVal, rhsVal);
-          sum = rewriter.create<arith::AddFOp>(loc, sum, product);
-        }
-
-        // Store the computed value into the result matrix.
-        rewriter.create<affine::AffineStoreOp>(loc, sum, result, ValueRange{
-          rewriter.create<arith::ConstantIndexOp>(loc, i),
-          rewriter.create<arith::ConstantIndexOp>(loc, j)
-        });
-      }
-    }
+
+    rewriter.setInsertionPoint(op);
+    Value zero = rewriter.create<arith::ConstantOp>(op->getLoc(), rewriter.getF64FloatAttr(0.0));
+
+    auto outerLoop = rewriter.create<affine::AffineForOp>(op->getLoc(), 0, M, 1);
+    rewriter.setInsertionPointToStart(outerLoop.getBody());
+    Value i = outerLoop.getInductionVar();
+
+    auto middleLoop = rewriter.create<affine::AffineForOp>(op->getLoc(),0,N,1);
+    rewriter.setInsertionPointToStart(middleLoop.getBody());
+    Value j = middleLoop.getInductionVar();
+
+    rewriter.create<affine::AffineStoreOp>(op->getLoc(),zero,result,ValueRange{i,j});
+    Value sum = zero;
+    Value total_sum = zero;
+
+    auto innerLoop = rewriter.create<affine::AffineForOp>(op->getLoc(),0,K,1);
+    rewriter.setInsertionPointToStart(innerLoop.getBody());
+    Value k = innerLoop.getInductionVar();
+
+    Value lhsValue = rewriter.create<affine::AffineLoadOp>(op->getLoc(),lhs,ValueRange{i,k});
+    Value rhsValue = rewriter.create<affine::AffineLoadOp>(op->getLoc(),rhs,ValueRange{k,j});
+    Value product = rewriter.create<arith::MulFOp>(op->getLoc(),lhsValue,rhsValue);
+    sum = rewriter.create<arith::AddFOp>(op->getLoc(),sum,product);
+    Value currentValue = rewriter.create<affine::AffineLoadOp>(op->getLoc(), result, ValueRange{i, j});
+    total_sum = rewriter.create<arith::AddFOp>(op->getLoc(),sum,currentValue);
+    rewriter.setInsertionPoint(innerLoop.getBody()->getTerminator());
+    rewriter.create<affine::AffineStoreOp>(op->getLoc(),total_sum,result,ValueRange{i,j});
+
     rewriter.replaceOp(op, result);
     return success();
   }

>From 239f031ec9d499d4fa794e01b8ce85e9955cfe6b Mon Sep 17 00:00:00 2001
From: Anirudh-Sathish <anirudhsathish at gmail.com>
Date: Tue, 25 Jun 2024 13:31:51 +0000
Subject: [PATCH 05/13] Test: Add FileCheck commands for test cases

---
 mlir/test/Examples/Toy/Matmul/matmul1.toy  | 15 +++++++-
 mlir/test/Examples/Toy/Matmul/matmul2.mlir | 45 ++++++++++++++++++++++
 mlir/test/Examples/Toy/Matmul/matmul2.toy  | 12 ++++++
 mlir/test/Examples/Toy/Matmul/matmul3.mlir | 45 ++++++++++++++++++++++
 mlir/test/Examples/Toy/Matmul/matmul3.toy  | 11 ++++++
 5 files changed, 127 insertions(+), 1 deletion(-)

diff --git a/mlir/test/Examples/Toy/Matmul/matmul1.toy b/mlir/test/Examples/Toy/Matmul/matmul1.toy
index 823608a6a2f02..136985d26a073 100644
--- a/mlir/test/Examples/Toy/Matmul/matmul1.toy
+++ b/mlir/test/Examples/Toy/Matmul/matmul1.toy
@@ -1,6 +1,19 @@
+# RUN: toyc-ch7 %s -emit=mlir 2>&1 | FileCheck %s
+
+# Function that performs matrix multiplication
 def main() {
   var a<2, 3> = [1, 2, 3, 4, 5, 6];
   var b<3, 2> = [1, 2, 3, 4, 5, 6];
   var c = matmul(a,b);
   print(c);
-}
\ No newline at end of file
+}
+
+#CHECK-LABEL: toy.func @main() {
+# CHECK-NEXT:    %0 = toy.constant dense<[1.000000e+00, 2.000000e+00, 3.000000e+00, 4.000000e+00, 5.000000e+00, 6.000000e+00]> : tensor<6xf64>
+# CHECK-NEXT:    %1 = toy.reshape(%0 : tensor<6xf64>) to tensor<2x3xf64>
+# CHECK-NEXT:    %2 = toy.constant dense<[1.000000e+00, 2.000000e+00, 3.000000e+00, 4.000000e+00, 5.000000e+00, 6.000000e+00]> : tensor<6xf64>
+# CHECK-NEXT:    %3 = toy.reshape(%2 : tensor<6xf64>) to tensor<3x2xf64>
+# CHECK-NEXT:    %4 = toy.matmul %1, %3 : tensor<2x3xf64>, tensor<3x2xf64> -> tensor<*xf64>
+# CHECK-NEXT:    toy.print %4 : tensor<*xf64>
+# CHECK-NEXT:    toy.return
+# CHECK-NEXT:  }
diff --git a/mlir/test/Examples/Toy/Matmul/matmul2.mlir b/mlir/test/Examples/Toy/Matmul/matmul2.mlir
index e4145c8bb5dfa..a1a605817b575 100644
--- a/mlir/test/Examples/Toy/Matmul/matmul2.mlir
+++ b/mlir/test/Examples/Toy/Matmul/matmul2.mlir
@@ -1,3 +1,5 @@
+// RUN: toyc-ch7 %s -emit=mlir-affine 2>&1 | FileCheck %s
+
 module {
   toy.func @main() {
     %0 = toy.constant dense<[1.000000e+00, 2.000000e+00, 3.000000e+00, 4.000000e+00, 5.000000e+00, 6.000000e+00]> : tensor<6xf64>
@@ -9,3 +11,46 @@ module {
     toy.return
   }
 }
+
+//CHECK-LABEL: func @main
+//CHECK-NEXT: %cst = arith.constant 0.000000e+00 : f64
+//CHECK-NEXT: %cst_0 = arith.constant 6.000000e+00 : f64
+//CHECK-NEXT: %cst_1 = arith.constant 5.000000e+00 : f64
+//CHECK-NEXT: %cst_2 = arith.constant 4.000000e+00 : f64
+//CHECK-NEXT: %cst_3 = arith.constant 3.000000e+00 : f64
+//CHECK-NEXT: %cst_4 = arith.constant 2.000000e+00 : f64
+//CHECK-NEXT: %cst_5 = arith.constant 1.000000e+00 : f64
+//CHECK-NEXT: %alloc = memref.alloc() : memref<1x6xf64>
+//CHECK-NEXT: %alloc_6 = memref.alloc() : memref<6x1xf64>
+//CHECK-NEXT: affine.store %cst_5, %alloc_6[0, 0] : memref<6x1xf64>
+//CHECK-NEXT: affine.store %cst_4, %alloc_6[1, 0] : memref<6x1xf64>
+//CHECK-NEXT: affine.store %cst_3, %alloc_6[2, 0] : memref<6x1xf64>
+//CHECK-NEXT: affine.store %cst_2, %alloc_6[3, 0] : memref<6x1xf64>
+//CHECK-NEXT: affine.store %cst_1, %alloc_6[4, 0] : memref<6x1xf64>
+//CHECK-NEXT: affine.store %cst_0, %alloc_6[5, 0] : memref<6x1xf64>
+//CHECK-NEXT: affine.store %cst_5, %alloc[0, 0] : memref<1x6xf64>
+//CHECK-NEXT: affine.store %cst_4, %alloc[0, 1] : memref<1x6xf64>
+//CHECK-NEXT: affine.store %cst_3, %alloc[0, 2] : memref<1x6xf64>
+//CHECK-NEXT: affine.store %cst_2, %alloc[0, 3] : memref<1x6xf64>
+//CHECK-NEXT: affine.store %cst_1, %alloc[0, 4] : memref<1x6xf64>
+//CHECK-NEXT: affine.store %cst_0, %alloc[0, 5] : memref<1x6xf64>
+//CHECK-NEXT: %alloc_7 = memref.alloc() : memref<6x6xf64>
+//CHECK-NEXT: affine.for %arg0 = 0 to 6 {
+//CHECK-NEXT:   affine.for %arg1 = 0 to 6 {
+//CHECK-NEXT:     affine.store %cst, %alloc_7[%arg0, %arg1] : memref<6x6xf64>
+//CHECK-NEXT:     affine.for %arg2 = 0 to 1 {
+//CHECK-NEXT:       %0 = affine.load %alloc_6[%arg0, %arg2] : memref<6x1xf64>
+//CHECK-NEXT:       %1 = affine.load %alloc[%arg2, %arg1] : memref<1x6xf64>
+//CHECK-NEXT:       %2 = arith.mulf %0, %1 : f64
+//CHECK-NEXT:       %3 = arith.addf %2, %cst : f64
+//CHECK-NEXT:       %4 = affine.load %alloc_7[%arg0, %arg1] : memref<6x6xf64>
+//CHECK-NEXT:       %5 = arith.addf %3, %4 : f64
+//CHECK-NEXT:       affine.store %5, %alloc_7[%arg0, %arg1] : memref<6x6xf64>
+//CHECK-NEXT:     }
+//CHECK-NEXT:   }
+//CHECK-NEXT: }
+//CHECK-NEXT: toy.print %alloc_7 : memref<6x6xf64>
+//CHECK-NEXT: memref.dealloc %alloc_6 : memref<6x1xf64>
+//CHECK-NEXT: memref.dealloc %alloc : memref<1x6xf64>
+//CHECK-NEXT: return
+//CHECK-NEXT: }
diff --git a/mlir/test/Examples/Toy/Matmul/matmul2.toy b/mlir/test/Examples/Toy/Matmul/matmul2.toy
index 6c37cc1af1f6a..d72c35be4c692 100644
--- a/mlir/test/Examples/Toy/Matmul/matmul2.toy
+++ b/mlir/test/Examples/Toy/Matmul/matmul2.toy
@@ -1,6 +1,18 @@
+# RUN: toyc-ch7 %s -emit=mlir 2>&1 | FileCheck %s
+
 def main() {
   var a<6, 1> = [1, 2, 3, 4, 5, 6];
   var b<1, 6> = [1, 2, 3, 4, 5, 6];
   var c = matmul(a,b);
   print(c);
 }
+
+#CHECK-LABEL: toy.func @main() {
+# CHECK-NEXT:    %0 = toy.constant dense<[1.000000e+00, 2.000000e+00, 3.000000e+00, 4.000000e+00, 5.000000e+00, 6.000000e+00]> : tensor<6xf64>
+# CHECK-NEXT:    %1 = toy.reshape(%0 : tensor<6xf64>) to tensor<6x1xf64>
+# CHECK-NEXT:    %2 = toy.constant dense<[1.000000e+00, 2.000000e+00, 3.000000e+00, 4.000000e+00, 5.000000e+00, 6.000000e+00]> : tensor<6xf64>
+# CHECK-NEXT:    %3 = toy.reshape(%2 : tensor<6xf64>) to tensor<1x6xf64>
+# CHECK-NEXT:    %4 = toy.matmul %1, %3 : tensor<6x1xf64>, tensor<1x6xf64> -> tensor<*xf64>
+# CHECK-NEXT:    toy.print %4 : tensor<*xf64>
+# CHECK-NEXT:    toy.return
+# CHECK-NEXT:  }
diff --git a/mlir/test/Examples/Toy/Matmul/matmul3.mlir b/mlir/test/Examples/Toy/Matmul/matmul3.mlir
index 5e8fd5eba2398..8573534375f8d 100644
--- a/mlir/test/Examples/Toy/Matmul/matmul3.mlir
+++ b/mlir/test/Examples/Toy/Matmul/matmul3.mlir
@@ -1,3 +1,4 @@
+// RUN: toyc-ch7 %s -emit=mlir-affine 2>&1 | FileCheck %s
 module {
   toy.func @main() {
     %0 = toy.constant dense<[1.000000e+00, 2.000000e+00, 3.000000e+00, 4.000000e+00, 5.000000e+00, 6.000000e+00]> : tensor<6xf64>
@@ -9,3 +10,47 @@ module {
     toy.return
   }
 }
+
+//CHECK-LABEL: func @main
+//CHECK-NEXT: %cst = arith.constant 0.000000e+00 : f64
+//CHECK-NEXT: %cst_0 = arith.constant 6.000000e+00 : f64
+//CHECK-NEXT: %cst_1 = arith.constant 5.000000e+00 : f64
+//CHECK-NEXT: %cst_2 = arith.constant 4.000000e+00 : f64
+//CHECK-NEXT: %cst_3 = arith.constant 3.000000e+00 : f64
+//CHECK-NEXT: %cst_4 = arith.constant 2.000000e+00 : f64
+//CHECK-NEXT: %cst_5 = arith.constant 1.000000e+00 : f64
+//CHECK-NEXT: %alloc = memref.alloc() : memref<1x6xf64>
+//CHECK-NEXT: %alloc_6 = memref.alloc() : memref<6x1xf64>
+//CHECK-NEXT: affine.store %cst_5, %alloc_6[0, 0] : memref<6x1xf64>
+//CHECK-NEXT: affine.store %cst_4, %alloc_6[1, 0] : memref<6x1xf64>
+//CHECK-NEXT: affine.store %cst_3, %alloc_6[2, 0] : memref<6x1xf64>
+//CHECK-NEXT: affine.store %cst_2, %alloc_6[3, 0] : memref<6x1xf64>
+//CHECK-NEXT: affine.store %cst_1, %alloc_6[4, 0] : memref<6x1xf64>
+//CHECK-NEXT: affine.store %cst_0, %alloc_6[5, 0] : memref<6x1xf64>
+//CHECK-NEXT: affine.store %cst_5, %alloc[0, 0] : memref<1x6xf64>
+//CHECK-NEXT: affine.store %cst_4, %alloc[0, 1] : memref<1x6xf64>
+//CHECK-NEXT: affine.store %cst_3, %alloc[0, 2] : memref<1x6xf64>
+//CHECK-NEXT: affine.store %cst_2, %alloc[0, 3] : memref<1x6xf64>
+//CHECK-NEXT: affine.store %cst_1, %alloc[0, 4] : memref<1x6xf64>
+//CHECK-NEXT: affine.store %cst_0, %alloc[0, 5] : memref<1x6xf64>
+//CHECK-NEXT: %alloc_7 = memref.alloc() : memref<1x1xf64>
+//CHECK-NEXT: affine.for %arg0 = 0 to 1 {
+//CHECK-NEXT:   affine.for %arg1 = 0 to 1 {
+//CHECK-NEXT:     affine.store %cst, %alloc_7[%arg0, %arg1] : memref<1x1xf64>
+//CHECK-NEXT:     affine.for %arg2 = 0 to 6 {
+//CHECK-NEXT:       %0 = affine.load %alloc[%arg0, %arg2] : memref<1x6xf64>
+//CHECK-NEXT:       %1 = affine.load %alloc_6[%arg2, %arg1] : memref<6x1xf64>
+//CHECK-NEXT:       %2 = arith.mulf %0, %1 : f64
+//CHECK-NEXT:       %3 = arith.addf %2, %cst : f64
+//CHECK-NEXT:       %4 = affine.load %alloc_7[%arg0, %arg1] : memref<1x1xf64>
+//CHECK-NEXT:       %5 = arith.addf %3, %4 : f64
+//CHECK-NEXT:       affine.store %5, %alloc_7[%arg0, %arg1] : memref<1x1xf64>
+//CHECK-NEXT:     }
+//CHECK-NEXT:   }
+//CHECK-NEXT: }
+//CHECK-NEXT: toy.print %alloc_7 : memref<1x1xf64>
+//CHECK-NEXT: memref.dealloc %alloc_6 : memref<6x1xf64>
+//CHECK-NEXT: memref.dealloc %alloc : memref<1x6xf64>
+//CHECK-NEXT: return
+//CHECK-NEXT: }
+
diff --git a/mlir/test/Examples/Toy/Matmul/matmul3.toy b/mlir/test/Examples/Toy/Matmul/matmul3.toy
index d268735176e7f..a56c6e31cacb8 100644
--- a/mlir/test/Examples/Toy/Matmul/matmul3.toy
+++ b/mlir/test/Examples/Toy/Matmul/matmul3.toy
@@ -1,3 +1,5 @@
+# RUN: toyc-ch7 %s -emit=mlir 2>&1 | FileCheck %s
+
 def main() {
   var a<6, 1> = [1, 2, 3, 4, 5, 6];
   var b<1, 6> = [1, 2, 3, 4, 5, 6];
@@ -5,3 +7,12 @@ def main() {
   print(c);
 }
 
+#CHECK-LABEL: toy.func @main() {
+# CHECK-NEXT:    %0 = toy.constant dense<[1.000000e+00, 2.000000e+00, 3.000000e+00, 4.000000e+00, 5.000000e+00, 6.000000e+00]> : tensor<6xf64>
+# CHECK-NEXT:    %1 = toy.reshape(%0 : tensor<6xf64>) to tensor<6x1xf64>
+# CHECK-NEXT:    %2 = toy.constant dense<[1.000000e+00, 2.000000e+00, 3.000000e+00, 4.000000e+00, 5.000000e+00, 6.000000e+00]> : tensor<6xf64>
+# CHECK-NEXT:    %3 = toy.reshape(%2 : tensor<6xf64>) to tensor<1x6xf64>
+# CHECK-NEXT:    %4 = toy.matmul %3, %1 : tensor<1x6xf64>, tensor<6x1xf64> -> tensor<*xf64>
+# CHECK-NEXT:    toy.print %4 : tensor<*xf64>
+# CHECK-NEXT:    toy.return
+# CHECK-NEXT:  }

>From 68ea36a89ea3ab60a5371d978488d7c04b7fd7cf Mon Sep 17 00:00:00 2001
From: Anirudh-Sathish <anirudhsathish at gmail.com>
Date: Tue, 25 Jun 2024 13:48:41 +0000
Subject: [PATCH 06/13] Feat: Add FileCheck for matmul1.mlir

---
 mlir/test/Examples/Toy/Matmul/matmul1.mlir | 45 ++++++++++++++++++++++
 1 file changed, 45 insertions(+)

diff --git a/mlir/test/Examples/Toy/Matmul/matmul1.mlir b/mlir/test/Examples/Toy/Matmul/matmul1.mlir
index 956da787f9b36..0a6fbf5be7087 100644
--- a/mlir/test/Examples/Toy/Matmul/matmul1.mlir
+++ b/mlir/test/Examples/Toy/Matmul/matmul1.mlir
@@ -1,3 +1,5 @@
+// RUN: toyc-ch7 %s -emit=mlir-affine 2>&1 | FileCheck %s
+
 module {
   toy.func @main() {
     %0 = toy.constant dense<[1.000000e+00, 2.000000e+00, 3.000000e+00, 4.000000e+00, 5.000000e+00, 6.000000e+00]> : tensor<6xf64>
@@ -9,3 +11,46 @@ module {
     toy.return
   }
 }
+
+//CHECK-LABEL: func @main
+//CHECK-NEXT: %cst = arith.constant 0.000000e+00 : f64
+//CHECK-NEXT: %cst_0 = arith.constant 6.000000e+00 : f64
+//CHECK-NEXT: %cst_1 = arith.constant 5.000000e+00 : f64
+//CHECK-NEXT: %cst_2 = arith.constant 4.000000e+00 : f64
+//CHECK-NEXT: %cst_3 = arith.constant 3.000000e+00 : f64
+//CHECK-NEXT: %cst_4 = arith.constant 2.000000e+00 : f64
+//CHECK-NEXT: %cst_5 = arith.constant 1.000000e+00 : f64
+//CHECK-NEXT: %alloc = memref.alloc() : memref<3x2xf64>
+//CHECK-NEXT: %alloc_6 = memref.alloc() : memref<2x3xf64>
+//CHECK-NEXT: affine.store %cst_5, %alloc_6[0, 0] : memref<2x3xf64>
+//CHECK-NEXT: affine.store %cst_4, %alloc_6[0, 1] : memref<2x3xf64>
+//CHECK-NEXT: affine.store %cst_3, %alloc_6[0, 2] : memref<2x3xf64>
+//CHECK-NEXT: affine.store %cst_2, %alloc_6[1, 0] : memref<2x3xf64>
+//CHECK-NEXT: affine.store %cst_1, %alloc_6[1, 1] : memref<2x3xf64>
+//CHECK-NEXT: affine.store %cst_0, %alloc_6[1, 2] : memref<2x3xf64>
+//CHECK-NEXT: affine.store %cst_5, %alloc[0, 0] : memref<3x2xf64>
+//CHECK-NEXT: affine.store %cst_4, %alloc[0, 1] : memref<3x2xf64>
+//CHECK-NEXT: affine.store %cst_3, %alloc[1, 0] : memref<3x2xf64>
+//CHECK-NEXT: affine.store %cst_2, %alloc[1, 1] : memref<3x2xf64>
+//CHECK-NEXT: affine.store %cst_1, %alloc[2, 0] : memref<3x2xf64>
+//CHECK-NEXT: affine.store %cst_0, %alloc[2, 1] : memref<3x2xf64>
+//CHECK-NEXT: %alloc_7 = memref.alloc() : memref<2x2xf64>
+//CHECK-NEXT: affine.for %arg0 = 0 to 2 {
+//CHECK-NEXT:   affine.for %arg1 = 0 to 2 {
+//CHECK-NEXT:     affine.store %cst, %alloc_7[%arg0, %arg1] : memref<2x2xf64>
+//CHECK-NEXT:     affine.for %arg2 = 0 to 3 {
+//CHECK-NEXT:       %0 = affine.load %alloc_6[%arg0, %arg2] : memref<2x3xf64>
+//CHECK-NEXT:       %1 = affine.load %alloc[%arg2, %arg1] : memref<3x2xf64>
+//CHECK-NEXT:       %2 = arith.mulf %0, %1 : f64
+//CHECK-NEXT:       %3 = arith.addf %2, %cst : f64
+//CHECK-NEXT:       %4 = affine.load %alloc_7[%arg0, %arg1] : memref<2x2xf64>
+//CHECK-NEXT:       %5 = arith.addf %3, %4 : f64
+//CHECK-NEXT:       affine.store %5, %alloc_7[%arg0, %arg1] : memref<2x2xf64>
+//CHECK-NEXT:     }
+//CHECK-NEXT:   }
+//CHECK-NEXT: }
+//CHECK-NEXT: toy.print %alloc_7 : memref<2x2xf64>
+//CHECK-NEXT: memref.dealloc %alloc_6 : memref<2x3xf64>
+//CHECK-NEXT: memref.dealloc %alloc : memref<3x2xf64>
+//CHECK-NEXT: return
+//CHECK-NEXT: }

>From 283c8c927bbf2d1f5d1cb27efe81f406ec407d89 Mon Sep 17 00:00:00 2001
From: Anirudh-Sathish <anirudhsathish at gmail.com>
Date: Thu, 4 Jul 2024 07:25:41 +0000
Subject: [PATCH 07/13] Feat: Code for Loop unroll in SCF

---
 .../mlir/Dialect/SCF/Transforms/Passes.h      |   8 +
 .../mlir/Dialect/SCF/Transforms/Passes.td     |  17 ++
 .../lib/Dialect/SCF/Transforms/LoopUnroll.cpp | 183 ++++++++++++++++++
 3 files changed, 208 insertions(+)
 create mode 100644 mlir/lib/Dialect/SCF/Transforms/LoopUnroll.cpp

diff --git a/mlir/include/mlir/Dialect/SCF/Transforms/Passes.h b/mlir/include/mlir/Dialect/SCF/Transforms/Passes.h
index fb8411418ff9a..d325fa515a0e3 100644
--- a/mlir/include/mlir/Dialect/SCF/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/SCF/Transforms/Passes.h
@@ -62,12 +62,20 @@ std::unique_ptr<Pass> createForLoopRangeFoldingPass();
 /// Creates a pass that converts SCF forall loops to SCF for loops.
 std::unique_ptr<Pass> createForallToForLoopPass();
 
+/// Creates a pass that counts the number of operations in SCF.
+std::unique_ptr<Pass> createCounterPass();
+
+/// Creates a pass that unrolls innermost scf.for loops (by a factor or fully).
+std::unique_ptr<Pass> createLoopUnroll(
+    int unrollFactor = 4, bool unrollFull = false );
+
 /// Creates a pass that converts SCF forall loops to SCF parallel loops.
 std::unique_ptr<Pass> createForallToParallelLoopPass();
 
 // Creates a pass which lowers for loops into while loops.
 std::unique_ptr<Pass> createForToWhileLoopPass();
 
+
 //===----------------------------------------------------------------------===//
 // Registration
 //===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Dialect/SCF/Transforms/Passes.td b/mlir/include/mlir/Dialect/SCF/Transforms/Passes.td
index 9b29affb97c43..694e75bdd49b6 100644
--- a/mlir/include/mlir/Dialect/SCF/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/SCF/Transforms/Passes.td
@@ -125,6 +125,22 @@ def SCFForallToForLoop : Pass<"scf-forall-to-for"> {
   let constructor = "mlir::createForallToForLoopPass()";
 }
 
+def SCFCounter : Pass<"scf-counter"> {
+  let summary = "Count operations for SCF";
+  let constructor = "mlir::createCounterPass()";
+}
+
+// Unrolls innermost scf.for loops by a fixed factor, or fully on request.
+def SCFLoopUnroll : Pass<"scf-loop-unroll"> {
+  let summary = "Unroll innermost SCF for loops";
+  let constructor = "mlir::createLoopUnroll()";
+  let options = [
+    Option<"unrollFactor", "unroll-factor", "int", /*default=*/"4",
+           "Use this unroll factor for all loops being unrolled">,
+    Option<"unrollFull", "unroll-full", "bool", /*default=*/"false", "Fully unroll loops">
+  ];
+}
+
 def SCFForallToParallelLoop : Pass<"scf-forall-to-parallel"> {
   let summary = "Convert SCF forall loops to SCF parallel loops";
   let constructor = "mlir::createForallToParallelLoopPass()";
@@ -164,4 +180,5 @@ def SCFForToWhileLoop : Pass<"scf-for-to-while"> {
   }];
 }
 
+
 #endif // MLIR_DIALECT_SCF_PASSES
diff --git a/mlir/lib/Dialect/SCF/Transforms/LoopUnroll.cpp b/mlir/lib/Dialect/SCF/Transforms/LoopUnroll.cpp
new file mode 100644
index 0000000000000..38a8349793d0e
--- /dev/null
+++ b/mlir/lib/Dialect/SCF/Transforms/LoopUnroll.cpp
@@ -0,0 +1,183 @@
+//
+// Implements a pass that unrolls innermost `scf.for` loops of the SCF dialect.
+//
+
+
+#include "mlir/Dialect/SCF/Transforms/Passes.h"
+
+#include "mlir/Dialect/SCF/IR/SCF.h"
+#include "mlir/Dialect/SCF/Transforms/Transforms.h"
+#include "mlir/IR/BuiltinOps.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Support/LLVM.h"
+#include "mlir/Pass/PassRegistry.h"
+#include "mlir/Pass/PassManager.h" 
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/IR/PatternMatch.h"
+namespace mlir{
+    #define GEN_PASS_DEF_SCFLOOPUNROLL
+    #include "mlir/Dialect/SCF/Transforms/Passes.h.inc"
+}
+using namespace llvm;
+using namespace mlir;
+using namespace mlir::scf;
+using scf::ForOp;
+
+namespace {
+/// Unrolls the innermost `scf.for` loops whose bounds and step are constants.
+///
+/// A loop is unrolled by the `unroll-factor` option, or by its whole trip count
+/// when `unroll-full` is set. Iterations that do not fill a complete unrolled
+/// body are peeled into a cleanup ("handler") loop cloned right after the loop.
+struct LoopUnroll : public impl::SCFLoopUnrollBase<LoopUnroll> {
+    LoopUnroll() = default;
+
+    /// Stores the arguments in the pass options inherited from the generated
+    /// base class. Redeclaring same-named members here would shadow those
+    /// options, and values given on the command line would be ignored.
+    LoopUnroll(int unrollFactor, bool unrollFull) {
+        this->unrollFactor = unrollFactor;
+        this->unrollFull = unrollFull;
+    }
+
+    /// Unrolls every innermost loop nested under the operation the pass runs on.
+    void runOnOperation() override {
+        OpBuilder builder(&getContext());
+        // Gather first: unrolling clones loops, and clones must not be revisited.
+        SmallVector<ForOp, 4> loops;
+        gatherInnermostLoops(getOperation(), loops);
+        for (ForOp forOp : loops)
+            (void)runOnSCFForOp(builder, forOp); // Ineligible loops stay as-is.
+    }
+
+    /// Builds `iv + i * step` at the builder's insertion point and returns it.
+    Value createUpdatedInductionVar(int64_t i, Value iv, OpBuilder &builder, int64_t step) {
+        Location loc = iv.getLoc();
+        Value constantI = builder.create<arith::ConstantIndexOp>(loc, i);
+        Value constantStep = builder.create<arith::ConstantIndexOp>(loc, step);
+        Value increment = builder.create<arith::MulIOp>(loc, constantI, constantStep);
+        return builder.create<arith::AddIOp>(loc, iv, increment);
+    }
+
+    /// Appends `factor - 1` copies of the loop body before the terminator.
+    ///
+    /// Copy `i` reads `iv + i * step` instead of the induction variable and the
+    /// values yielded by the previous copy instead of the iteration arguments.
+    /// The terminator finally yields the values produced by the last copy. The
+    /// body must hold at least one operation besides its terminator.
+    void generateUnroll(ForOp forOp, int64_t factor, int64_t step) {
+        Block *loopBody = forOp.getBody();
+        auto returnValues = loopBody->getTerminator()->getOperands();
+        Value forOpInductionVar = forOp.getInductionVar();
+        ValueRange iterArgs(forOp.getRegionIterArgs());
+        auto builder = OpBuilder::atBlockTerminator(loopBody);
+        // Last operation of the original body; clones are appended after it.
+        Block::iterator originalBlockEnd = std::prev(loopBody->end(), 2);
+        SmallVector<Value, 4> lastReturnValues(returnValues);
+        for (int64_t i = 1; i < factor; ++i) {
+            IRMapping mapper;
+            mapper.map(iterArgs, lastReturnValues);
+            if (!forOpInductionVar.use_empty()) {
+                Value updatedInductionVar = createUpdatedInductionVar(i, forOpInductionVar, builder, step);
+                mapper.map(forOpInductionVar, updatedInductionVar);
+            }
+            for (auto it = loopBody->begin(); it != std::next(originalBlockEnd); ++it)
+                builder.clone(*it, mapper);
+            // Values defined outside the body (or block arguments) are not
+            // cloned, so only yields produced inside the body are remapped.
+            for (size_t j = 0, e = lastReturnValues.size(); j != e; ++j) {
+                Operation *defOp = returnValues[j].getDefiningOp();
+                if (defOp && defOp->getBlock() == loopBody)
+                    lastReturnValues[j] = mapper.lookup(returnValues[j]);
+            }
+        }
+        loopBody->getTerminator()->setOperands(lastReturnValues);
+    }
+
+    /// Unrolls `forOp` in place, creating a cleanup loop when the trip count is
+    /// not a multiple of the unroll factor. Fails when the bounds or the step are
+    /// not `arith.constant` index values or when the step is not positive.
+    LogicalResult runOnSCFForOp(OpBuilder &builder, ForOp forOp) {
+        auto lowerBoundConst = forOp.getLowerBound().getDefiningOp<arith::ConstantIndexOp>();
+        auto upperBoundConst = forOp.getUpperBound().getDefiningOp<arith::ConstantIndexOp>();
+        auto stepConst = forOp.getStep().getDefiningOp<arith::ConstantIndexOp>();
+        if (!lowerBoundConst || !upperBoundConst || !stepConst) {
+            forOp.emitError("Expected constant bounds and step for unrolling.");
+            return failure();
+        }
+        int64_t lowerBoundValue = lowerBoundConst.value();
+        int64_t upperBoundValue = upperBoundConst.value();
+        int64_t stepValue = stepConst.value();
+        if (stepValue <= 0) return failure();
+        // Number of iterations: ceil((ub - lb) / step), zero for an empty range.
+        int64_t tripCount = upperBoundValue <= lowerBoundValue ? 0
+            : (upperBoundValue - lowerBoundValue + stepValue - 1) / stepValue;
+        // Kept local so that fully unrolling one loop cannot leak its trip
+        // count into the factor used for the loops processed afterwards.
+        int64_t factor = unrollFactor;
+        if (unrollFull) factor = tripCount;
+        IRRewriter rewriter(forOp.getContext());
+        if (unrollFull && tripCount == 1) {
+            (void)forOp.promoteIfSingleIteration(rewriter);
+            return success();
+        }
+        // Nothing to replicate: factor of one (or invalid), fewer iterations
+        // than the factor, or a body that holds nothing but its terminator.
+        if (factor <= 1 || tripCount < factor || forOp.getBody()->without_terminator().empty())
+            return success();
+
+        int64_t tripCountUnrolled = tripCount - (tripCount % factor);
+        int64_t unrolledUpperBound = lowerBoundValue + tripCountUnrolled * stepValue;
+        bool createHandlerLoop = unrolledUpperBound < upperBoundValue;
+        // The new bound and step are materialized right in front of the loop.
+        builder.setInsertionPoint(forOp);
+        Value upperBoundUnrolled = forOp.getUpperBound();
+        if (createHandlerLoop)
+            upperBoundUnrolled = builder.create<arith::ConstantIndexOp>(forOp.getLoc(), unrolledUpperBound);
+        // The step grows even without a handler loop: every iteration of the
+        // unrolled loop now executes `factor` original iterations.
+        Value newStepValue = builder.create<arith::ConstantIndexOp>(forOp.getLoc(), stepValue * factor);
+        if (createHandlerLoop) {
+            // Clone before touching the bounds so the handler keeps the
+            // original upper bound and step.
+            builder.setInsertionPointAfter(forOp);
+            auto handlerForOp = cast<scf::ForOp>(builder.clone(*forOp));
+            handlerForOp.setLowerBound(upperBoundUnrolled);
+            // Users of the original results now read the handler loop, which
+            // in turn continues from the results of the unrolled loop.
+            auto results = forOp.getResults();
+            auto handlerResults = handlerForOp.getResults();
+            for (auto element : llvm::zip(results, handlerResults))
+                std::get<0>(element).replaceAllUsesWith(std::get<1>(element));
+            handlerForOp->setOperands(handlerForOp.getNumControlOperands(),
+                                      handlerForOp.getInitArgs().size(), results);
+            (void)handlerForOp.promoteIfSingleIteration(rewriter);
+        }
+        forOp.setUpperBound(upperBoundUnrolled);
+        forOp.setStep(newStepValue);
+        generateUnroll(forOp, factor, stepValue);
+        (void)forOp.promoteIfSingleIteration(rewriter);
+        return success();
+    }
+
+    /// Returns true when no other `scf.for` is nested inside `op`.
+    static bool isInnermostSCFForOp(ForOp op) {
+        return !op.getBody()
+                    ->walk([&](ForOp nestedForOp) { return WalkResult::interrupt(); })
+                    .wasInterrupted();
+    }
+
+    /// Appends every innermost `scf.for` nested under `op` to `loops`.
+    static void gatherInnermostLoops(Operation *op, SmallVectorImpl<ForOp> &loops) {
+        op->walk([&](ForOp forOp) {
+            if (isInnermostSCFForOp(forOp))
+                loops.push_back(forOp);
+        });
+    }
+};
+} // end anonymous namespace
+
+/// Creates the SCF loop unrolling pass for the given factor and full-unroll flag.
+std::unique_ptr<Pass> mlir::createLoopUnroll(int unrollFactor, bool unrollFull) {
+    return std::make_unique<LoopUnroll>(unrollFactor,unrollFull);
+}
\ No newline at end of file

>From e3c1232602825974772e6d9bf7577c6f2ba2633f Mon Sep 17 00:00:00 2001
From: Anirudh-Sathish <anirudhsathish at gmail.com>
Date: Thu, 4 Jul 2024 10:01:18 +0000
Subject: [PATCH 08/13] Test: Add test cases

---
 mlir/test/Dialect/SCF/loop-unroll-scf.mlir  | 38 +++++++++
 mlir/test/Dialect/SCF/loop-unroll-scf2.mlir | 90 +++++++++++++++++++++
 mlir/test/Dialect/SCF/loop-unroll-scf3.mlir | 49 +++++++++++
 3 files changed, 177 insertions(+)
 create mode 100644 mlir/test/Dialect/SCF/loop-unroll-scf.mlir
 create mode 100644 mlir/test/Dialect/SCF/loop-unroll-scf2.mlir
 create mode 100644 mlir/test/Dialect/SCF/loop-unroll-scf3.mlir

diff --git a/mlir/test/Dialect/SCF/loop-unroll-scf.mlir b/mlir/test/Dialect/SCF/loop-unroll-scf.mlir
new file mode 100644
index 0000000000000..d0fb3fd0aa41c
--- /dev/null
+++ b/mlir/test/Dialect/SCF/loop-unroll-scf.mlir
@@ -0,0 +1,38 @@
+// RUN: mlir-opt %s --scf-loop-unroll --split-input-file | FileCheck %s
+module {
+  func.func @main() -> f32 {
+    %sum = arith.constant 0.0 : f32
+    %val = arith.constant 2.0 : f32
+    %N = arith.constant 10 : index
+
+    %c0 = arith.constant 0 : index
+    %c1 = arith.constant 1 : index
+
+    %result = scf.for %i = %c0 to %N step %c1 iter_args(%iter_sum = %sum) -> (f32) {
+      %new_sum = arith.mulf %iter_sum, %val : f32
+      scf.yield %new_sum : f32
+    }
+    return %result : f32
+  }
+}
+//CHECK-LABEL: func.func @main() -> f32 {
+//CHECK-NEXT: %cst = arith.constant 0.000000e+00 : f32
+//CHECK-NEXT: %cst_0 = arith.constant 2.000000e+00 : f32
+//CHECK-NEXT: %c10 = arith.constant 10 : index
+//CHECK-NEXT: %c0 = arith.constant 0 : index
+//CHECK-NEXT: %c1 = arith.constant 1 : index
+//CHECK-NEXT: %c8 = arith.constant 8 : index
+//CHECK-NEXT: %c4 = arith.constant 4 : index
+//CHECK-NEXT: %0 = scf.for %arg0 = %c0 to %c10 step %c4 iter_args(%arg1 = %cst) -> (f32) {
+//CHECK-NEXT:   %2 = arith.mulf %arg1, %cst_0 : f32
+//CHECK-NEXT:   %3 = arith.mulf %2, %cst_0 : f32
+//CHECK-NEXT:   %4 = arith.mulf %3, %cst_0 : f32
+//CHECK-NEXT:   %5 = arith.mulf %4, %cst_0 : f32
+//CHECK-NEXT:   scf.yield %5 : f32
+//CHECK-NEXT: }
+//CHECK-NEXT: %1 = scf.for %arg0 = %c8 to %c10 step %c1 iter_args(%arg1 = %0) -> (f32) {
+//CHECK-NEXT:   %2 = arith.mulf %arg1, %cst_0 : f32
+//CHECK-NEXT:   scf.yield %2 : f32
+//CHECK-NEXT:  }
+//CHECK-NEXT: return
+//CHECK-NEXT: }
diff --git a/mlir/test/Dialect/SCF/loop-unroll-scf2.mlir b/mlir/test/Dialect/SCF/loop-unroll-scf2.mlir
new file mode 100644
index 0000000000000..552b8f5fb67ff
--- /dev/null
+++ b/mlir/test/Dialect/SCF/loop-unroll-scf2.mlir
@@ -0,0 +1,90 @@
+// RUN: mlir-opt %s --scf-loop-unroll --split-input-file | FileCheck %s
+
+module {
+  func.func @main() -> f32 {
+    %N = arith.constant 10 : index
+    %val = arith.constant 2.0 : f32
+    %c0 = arith.constant 0 : index
+    %c1 = arith.constant 1 : index
+
+    %array = memref.alloc() : memref<10xf32>
+
+    // Initialize array with %val
+    scf.for %i = %c0 to %N step %c1 {
+      memref.store %val, %array[%i] : memref<10xf32>
+    }
+
+    %sum = arith.constant 0.0 : f32
+
+    %result = scf.for %j = %c0 to %N step %c1 iter_args(%iter_sum = %sum) -> (f32) {
+      %current_val = memref.load %array[%j] : memref<10xf32>
+      %new_sum = arith.addf %iter_sum, %current_val : f32
+      scf.yield %new_sum : f32
+    }
+
+    return %result : f32
+  }
+}
+
+//CHECK-LABEL: func.func @main() -> f32 {
+//CHECK-NEXT: %c10 = arith.constant 10 : index
+//CHECK-NEXT: %cst = arith.constant 2.000000e+00 : f32
+//CHECK-NEXT: %c0 = arith.constant 0 : index
+//CHECK-NEXT: %c1 = arith.constant 1 : index
+//CHECK-NEXT: %alloc = memref.alloc() : memref<10xf32>
+//CHECK-NEXT: %c8 = arith.constant 8 : index
+//CHECK-NEXT: %c4 = arith.constant 4 : index
+//CHECK-NEXT: scf.for %arg0 = %c0 to %c10 step %c4 {
+//CHECK-NEXT:   memref.store %cst, %alloc[%arg0] : memref<10xf32>
+//CHECK-NEXT:   %c1_3 = arith.constant 1 : index
+//CHECK-NEXT:   %c1_4 = arith.constant 1 : index
+//CHECK-NEXT:   %2 = arith.muli %c1_3, %c1_4 : index
+//CHECK-NEXT:   %3 = arith.addi %arg0, %2 : index
+//CHECK-NEXT:   memref.store %cst, %alloc[%3] : memref<10xf32>
+//CHECK-NEXT:   %c2 = arith.constant 2 : index
+//CHECK-NEXT:   %c1_5 = arith.constant 1 : index
+//CHECK-NEXT:   %4 = arith.muli %c2, %c1_5 : index
+//CHECK-NEXT:   %5 = arith.addi %arg0, %4 : index
+//CHECK-NEXT:   memref.store %cst, %alloc[%5] : memref<10xf32>
+//CHECK-NEXT:   %c3 = arith.constant 3 : index
+//CHECK-NEXT:   %c1_6 = arith.constant 1 : index
+//CHECK-NEXT:   %6 = arith.muli %c3, %c1_6 : index
+//CHECK-NEXT:   %7 = arith.addi %arg0, %6 : index
+//CHECK-NEXT:   memref.store %cst, %alloc[%7] : memref<10xf32>
+//CHECK-NEXT: }
+//CHECK-NEXT: scf.for %arg0 = %c8 to %c10 step %c1 {
+//CHECK-NEXT:   memref.store %cst, %alloc[%arg0] : memref<10xf32>
+//CHECK-NEXT: }
+//CHECK-NEXT: %cst_0 = arith.constant 0.000000e+00 : f32
+//CHECK-NEXT: %c8_1 = arith.constant 8 : index
+//CHECK-NEXT: %c4_2 = arith.constant 4 : index
+//CHECK-NEXT: %0 = scf.for %arg0 = %c0 to %c10 step %c4_2 iter_args(%arg1 = %cst_0) -> (f32) {
+//CHECK-NEXT:   %2 = memref.load %alloc[%arg0] : memref<10xf32>
+//CHECK-NEXT:   %3 = arith.addf %arg1, %2 : f32
+//CHECK-NEXT:   %c1_3 = arith.constant 1 : index
+//CHECK-NEXT:   %c1_4 = arith.constant 1 : index
+//CHECK-NEXT:   %4 = arith.muli %c1_3, %c1_4 : index
+//CHECK-NEXT:   %5 = arith.addi %arg0, %4 : index
+//CHECK-NEXT:   %6 = memref.load %alloc[%5] : memref<10xf32>
+//CHECK-NEXT:   %7 = arith.addf %3, %6 : f32
+//CHECK-NEXT:   %c2 = arith.constant 2 : index
+//CHECK-NEXT:   %c1_5 = arith.constant 1 : index
+//CHECK-NEXT:   %8 = arith.muli %c2, %c1_5 : index
+//CHECK-NEXT:   %9 = arith.addi %arg0, %8 : index
+//CHECK-NEXT:   %10 = memref.load %alloc[%9] : memref<10xf32>
+//CHECK-NEXT:   %11 = arith.addf %7, %10 : f32
+//CHECK-NEXT:   %c3 = arith.constant 3 : index
+//CHECK-NEXT:   %c1_6 = arith.constant 1 : index
+//CHECK-NEXT:   %12 = arith.muli %c3, %c1_6 : index
+//CHECK-NEXT:   %13 = arith.addi %arg0, %12 : index
+//CHECK-NEXT:   %14 = memref.load %alloc[%13] : memref<10xf32>
+//CHECK-NEXT:   %15 = arith.addf %11, %14 : f32
+//CHECK-NEXT:   scf.yield %15 : f32
+//CHECK-NEXT: }
+//CHECK-NEXT: %1 = scf.for %arg0 = %c8_1 to %c10 step %c1 iter_args(%arg1 = %0) -> (f32) {
+//CHECK-NEXT:   %2 = memref.load %alloc[%arg0] : memref<10xf32>
+//CHECK-NEXT:   %3 = arith.addf %arg1, %2 : f32
+//CHECK-NEXT:   scf.yield %3 : f32
+//CHECK-NEXT: }
+//CHECK-NEXT: return %1 : f32
+//CHECK-NEXT: }
diff --git a/mlir/test/Dialect/SCF/loop-unroll-scf3.mlir b/mlir/test/Dialect/SCF/loop-unroll-scf3.mlir
new file mode 100644
index 0000000000000..7c2f4bdb76cce
--- /dev/null
+++ b/mlir/test/Dialect/SCF/loop-unroll-scf3.mlir
@@ -0,0 +1,49 @@
+// RUN: mlir-opt %s -scf-loop-unroll | FileCheck %s
+
+// Nested scf.for loops that both carry an f32 accumulator. The CHECK lines
+// below expect the inner loop (trip count 10) to be unrolled by 4 with a
+// remainder loop starting at 8, while the outer loop is left as it is.
+// NOTE(review): the RUN line assumes the pass defaults to an unroll factor of 4
+// - confirm against the pass options.
+module {
+  func.func @main() -> f32 {
+    %sum = arith.constant 0.0 : f32
+    %val = arith.constant 2.0 : f32
+    %N = arith.constant 4 : index
+    %num = arith.constant 10 : index
+    %c0 = arith.constant 0 : index
+    %c1 = arith.constant 1 : index
+
+    %result = scf.for %i = %c0 to %N step %c1 iter_args(%iter_sum = %sum) -> (f32) {
+      %new_sum = arith.addf %iter_sum, %val : f32
+      %result2 = scf.for %j = %c0 to %num step %c1 iter_args(%iter_sum2 = %new_sum) -> (f32) {
+        %new_sum2 = arith.addf %iter_sum2, %val : f32
+        scf.yield %new_sum2 : f32
+      }
+      %new_sum3 = arith.addf %result2, %val : f32
+      scf.yield %new_sum : f32
+    }
+    return %result : f32
+  }
+}
+
+//CHECK-LABEL: func.func @main() -> f32 {
+//CHECK-NEXT: %cst = arith.constant 0.000000e+00 : f32
+//CHECK-NEXT: %cst_0 = arith.constant 2.000000e+00 : f32
+//CHECK-NEXT: %c4 = arith.constant 4 : index
+//CHECK-NEXT: %c10 = arith.constant 10 : index
+//CHECK-NEXT: %c0 = arith.constant 0 : index
+//CHECK-NEXT: %c1 = arith.constant 1 : index
+//CHECK-NEXT: %0 = scf.for %arg0 = %c0 to %c4 step %c1 iter_args(%arg1 = %cst) -> (f32) {
+//CHECK-NEXT:   %1 = arith.addf %arg1, %cst_0 : f32
+//CHECK-NEXT:   %c8 = arith.constant 8 : index
+//CHECK-NEXT:   %c4_1 = arith.constant 4 : index
+//CHECK-NEXT:   %2 = scf.for %arg2 = %c0 to %c10 step %c4_1 iter_args(%arg3 = %1) -> (f32) {
+//CHECK-NEXT:     %5 = arith.addf %arg3, %cst_0 : f32
+//CHECK-NEXT:     %6 = arith.addf %5, %cst_0 : f32
+//CHECK-NEXT:     %7 = arith.addf %6, %cst_0 : f32
+//CHECK-NEXT:     %8 = arith.addf %7, %cst_0 : f32
+//CHECK-NEXT:     scf.yield %8 : f32
+//CHECK-NEXT:   }
+//CHECK-NEXT:   %3 = scf.for %arg2 = %c8 to %c10 step %c1 iter_args(%arg3 = %2) -> (f32) {
+//CHECK-NEXT:     %5 = arith.addf %arg3, %cst_0 : f32
+//CHECK-NEXT:     scf.yield %5 : f32
+//CHECK-NEXT:   }
+//CHECK-NEXT:   %4 = arith.addf %3, %cst_0 : f32
+//CHECK-NEXT:   scf.yield %1 : f32
+//CHECK-NEXT: }
+//CHECK-NEXT: return %0 : f32
+//CHECK-NEXT: }

>From e2e749087678ac53d3bbe7ef23d2d9319714d049 Mon Sep 17 00:00:00 2001
From: Anirudh-Sathish <anirudhsathish at gmail.com>
Date: Sat, 6 Jul 2024 13:49:54 +0000
Subject: [PATCH 09/13] Refactor:Make code formatted and readable

---
 .../lib/Dialect/SCF/Transforms/LoopUnroll.cpp | 307 +++++++++---------
 1 file changed, 157 insertions(+), 150 deletions(-)

diff --git a/mlir/lib/Dialect/SCF/Transforms/LoopUnroll.cpp b/mlir/lib/Dialect/SCF/Transforms/LoopUnroll.cpp
index 38a8349793d0e..f07c6d12f4c79 100644
--- a/mlir/lib/Dialect/SCF/Transforms/LoopUnroll.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/LoopUnroll.cpp
@@ -2,22 +2,21 @@
 // Unrolls loop in the SCF Dialect
 // */
 
-
 #include "mlir/Dialect/SCF/Transforms/Passes.h"
 
+#include "mlir/Dialect/Func/IR/FuncOps.h"
 #include "mlir/Dialect/SCF/IR/SCF.h"
 #include "mlir/Dialect/SCF/Transforms/Transforms.h"
 #include "mlir/IR/BuiltinOps.h"
+#include "mlir/IR/PatternMatch.h"
 #include "mlir/Pass/Pass.h"
-#include "mlir/Support/LLVM.h"
+#include "mlir/Pass/PassManager.h"
 #include "mlir/Pass/PassRegistry.h"
-#include "mlir/Pass/PassManager.h" 
-#include "mlir/Dialect/Func/IR/FuncOps.h"
-#include "mlir/IR/PatternMatch.h"
-namespace mlir{
-    #define GEN_PASS_DEF_SCFLOOPUNROLL
-    #include "mlir/Dialect/SCF/Transforms/Passes.h.inc"
-}
+#include "mlir/Support/LLVM.h"
+namespace mlir {
+#define GEN_PASS_DEF_SCFLOOPUNROLL
+#include "mlir/Dialect/SCF/Transforms/Passes.h.inc"
+} // namespace mlir
 using namespace llvm;
 using namespace mlir;
 using namespace mlir::scf;
@@ -25,159 +24,167 @@ using scf::ForOp;
 
 namespace {
 struct LoopUnroll : public impl::SCFLoopUnrollBase<LoopUnroll> {
-    private:
-        int unrollFactor;
-        bool unrollFull;
-    public:
-    LoopUnroll() = default;
-    LoopUnroll(int unrollFactor, bool unrollFull)
-        : unrollFactor(unrollFactor), unrollFull(unrollFull) {}
-    void runOnOperation() override {
-        Operation *parentOp = getOperation();
-        OpBuilder builder(parentOp->getContext());
-        SmallVector<ForOp, 4> loops;
-        gatherInnermostLoops(parentOp,loops);
-        if(loops.empty())return;
-        for(auto forOp: loops)
-        {
-            auto result = runOnSCFForOp(builder, forOp);
-        }
-    }
-    Value createUpdatedInductionVar(unsigned i,Value iv,OpBuilder &builder,int64_t step){
-        Location loc = iv.getLoc();
-        Value constantI = builder.create<arith::ConstantIndexOp>(loc, i);
-        Value constantStep = builder.create<arith::ConstantIndexOp>(loc, step);
-        Value increment = builder.create<arith::MulIOp>(loc, constantI, constantStep);
-        Value updatedIv = builder.create<arith::AddIOp>(loc, iv, increment);
-        return updatedIv;
+private:
+  int unrollFactor;
+  bool unrollFull;
+
+public:
+  LoopUnroll() = default;
+  LoopUnroll(int unrollFactor, bool unrollFull)
+      : unrollFactor(unrollFactor), unrollFull(unrollFull) {}
+  void runOnOperation() override {
+    Operation *parentOp = getOperation();
+    OpBuilder builder(parentOp->getContext());
+    SmallVector<ForOp, 4> loops;
+    gatherInnermostLoops(parentOp, loops);
+    if (loops.empty())
+      return;
+    for (auto forOp : loops) {
+      auto result = runOnSCFForOp(builder, forOp);
     }
+  }
 
-    void generateUnroll(ForOp forOp, int64_t step)
-    {
-        Block *loopBody = forOp.getBody();
-        auto returnValues = forOp.getBody()->getTerminator()->getOperands();
-        Value forOpInductionVar = forOp.getInductionVar();
-        ValueRange iterArgs(forOp.getRegionIterArgs());
-        auto builder = OpBuilder::atBlockTerminator(loopBody);
-        Block::iterator originalBlockEnd = std::prev(loopBody->end(),2);
-        SmallVector<Value,4> lastReturnValues(returnValues);
-        for(int i = 1 ; i < unrollFactor; i++)
-        {
-            IRMapping mapper;
-            mapper.map(iterArgs,lastReturnValues);
-            if(!forOpInductionVar.use_empty())
-            {
-                Value updatedInductionVar = createUpdatedInductionVar(i, forOpInductionVar,builder,step);
-                mapper.map(forOpInductionVar,updatedInductionVar);
-            }
-            for (auto it = loopBody->begin(); it != std::next(originalBlockEnd); it++) {
-                Operation *clonedOp = builder.clone(*it, mapper);
-            }
-            for(int j = 0; j < lastReturnValues.size(); j++)
-            {
-                Operation *defOp = returnValues[j].getDefiningOp();
-                if(defOp && defOp->getBlock()==loopBody)
-                {
-                    lastReturnValues[j] = mapper.lookup(returnValues[j]);
-                }
-            }
+  /// Emits, at the insertion point of `builder`, the IR that computes
+  /// `iv + i * step` and returns the resulting value. Both `i` and `step` are
+  /// materialized as index constants located at `iv`'s location.
+  Value createUpdatedInductionVar(unsigned i, Value iv, OpBuilder &builder,
+                                  int64_t step) {
+    const Location ivLoc = iv.getLoc();
+    // Constants are created in the order (iteration number, step) so the
+    // emitted IR keeps the same operation order.
+    Value iterationNumber = builder.create<arith::ConstantIndexOp>(ivLoc, i);
+    Value stepConstant = builder.create<arith::ConstantIndexOp>(ivLoc, step);
+    Value offset =
+        builder.create<arith::MulIOp>(ivLoc, iterationNumber, stepConstant);
+    return builder.create<arith::AddIOp>(ivLoc, iv, offset);
+  }
+
+  void generateUnroll(ForOp forOp, int64_t step) {
+    /*Function unrolls a given loop by the given step*/
+    Block *loopBody = forOp.getBody();
+    auto returnValues = forOp.getBody()->getTerminator()->getOperands();
+    Value forOpInductionVar = forOp.getInductionVar();
+    ValueRange iterArgs(forOp.getRegionIterArgs());
+    auto builder = OpBuilder::atBlockTerminator(loopBody);
+    Block::iterator originalBlockEnd = std::prev(loopBody->end(), 2);
+    SmallVector<Value, 4> lastReturnValues(returnValues);
+    for (int i = 1; i < unrollFactor; i++) {
+      IRMapping mapper;
+      mapper.map(iterArgs, lastReturnValues);
+      if (!forOpInductionVar.use_empty()) {
+        Value updatedInductionVar =
+            createUpdatedInductionVar(i, forOpInductionVar, builder, step);
+        mapper.map(forOpInductionVar, updatedInductionVar);
+      }
+      for (auto it = loopBody->begin(); it != std::next(originalBlockEnd);
+           it++) {
+        Operation *clonedOp = builder.clone(*it, mapper);
+      }
+      for (int j = 0; j < lastReturnValues.size(); j++) {
+        Operation *defOp = returnValues[j].getDefiningOp();
+        if (defOp && defOp->getBlock() == loopBody) {
+          lastReturnValues[j] = mapper.lookup(returnValues[j]);
         }
-        loopBody->getTerminator()->setOperands(lastReturnValues);
+      }
     }
-    void fullUnroll(ForOp forOp, int64_t tripCount, int64_t step)
-    {
-        IRRewriter rewriter(forOp.getContext());
-        if(tripCount == 0)return;
-        if(tripCount ==1)
-        {
-            (void)forOp.promoteIfSingleIteration(rewriter);
-            return;
-        }
-        unrollFactor = tripCount;
-        generateUnroll(forOp,step);
-        return;
+    loopBody->getTerminator()->setOperands(lastReturnValues);
+  }
+  void fullUnroll(ForOp forOp, int64_t tripCount, int64_t step) {
+    /*Function to fully unroll a given scf for Op*/
+    IRRewriter rewriter(forOp.getContext());
+    if (tripCount == 0)
+      return;
+    if (tripCount == 1) {
+      (void)forOp.promoteIfSingleIteration(rewriter);
+      return;
     }
-    LogicalResult runOnSCFForOp(OpBuilder &builder, ForOp forOp) {
-        IRRewriter rewriter(forOp.getContext());
-        Value lowerBound = forOp.getLowerBound();
-        Value upperBound = forOp.getUpperBound();
-        Value step = forOp.getStep();
-        Value newStepValue,upperBoundUnrolled;
-        bool createHandlerLoop = false;
-        auto lowerBoundConst = lowerBound.getDefiningOp<arith::ConstantIndexOp>();
-        auto upperBoundConst = upperBound.getDefiningOp<arith::ConstantIndexOp>();
-        auto stepConst = step.getDefiningOp<arith::ConstantIndexOp>();
-        if (!lowerBoundConst || !upperBoundConst || !stepConst) {
-            forOp.emitError("Expected constant bounds and step for unrolling.");
-            return failure();
-        }
-        int64_t lowerBoundValue = lowerBoundConst.value();
-        int64_t upperBoundValue = upperBoundConst.value();
-        int64_t stepValue = stepConst.value();
-        int64_t tripCount = (upperBoundValue-lowerBoundValue)/stepValue;
-        int64_t multipliedStepValue = stepValue * unrollFactor;
-        int64_t tripCountUnrolled = tripCount -(tripCount%unrollFactor);
-        int64_t unrolledUpperBound = lowerBoundValue +(tripCountUnrolled*stepValue);
-        createHandlerLoop = unrolledUpperBound < upperBoundValue;
-        if (Block *prevBlock = forOp->getBlock()->getPrevNode())builder.setInsertionPointToEnd(prevBlock);
-        else builder.setInsertionPoint(forOp);
-        if(createHandlerLoop)
-        {
-            upperBoundUnrolled = builder.create<arith::ConstantIndexOp>(forOp.getLoc(),unrolledUpperBound);
-            newStepValue = builder.create<arith::ConstantIndexOp>(
-                forOp.getLoc(), multipliedStepValue);
-            builder.setInsertionPoint(forOp);
-        }
-        else{
-            upperBoundUnrolled = upperBound;
-            newStepValue = step;
-        }
-        if(createHandlerLoop)
-        {
-            OpBuilder loopBuilder(forOp->getContext());
-            loopBuilder.setInsertionPoint(forOp->getBlock(),std::next(Block::iterator(forOp)));
-            auto handlerForOp = cast<scf::ForOp>(loopBuilder.clone(*forOp));
-            handlerForOp.setLowerBound(upperBoundUnrolled);
-        
-            auto results = forOp.getResults();
-            auto handlerResults = handlerForOp.getResults();
-            for (auto element : llvm::zip(results, handlerResults)) {
-                std::get<0>(element).replaceAllUsesWith(std::get<1>(element));
-            }
-            handlerForOp->setOperands(handlerForOp.getNumControlOperands(),
-                                        handlerForOp.getInitArgs().size(), results);
-            (void)handlerForOp.promoteIfSingleIteration(rewriter);
-        
-        }
-        forOp.setStep(upperBoundUnrolled);
-        forOp.setStep(newStepValue);
-        if(unrollFull == false)generateUnroll(forOp,stepValue);
-        else fullUnroll(forOp,tripCount,stepValue);
-        (void)forOp.promoteIfSingleIteration(rewriter);
-        llvm::errs() << "Completed unroll\n";
-        return success();
+    unrollFactor = tripCount;
+    generateUnroll(forOp, step);
+    return;
+  }
+  LogicalResult runOnSCFForOp(OpBuilder &builder, ForOp forOp) {
+    /*Function process SCF ForOps*/
+    IRRewriter rewriter(forOp.getContext());
+    Value lowerBound = forOp.getLowerBound();
+    Value upperBound = forOp.getUpperBound();
+    Value step = forOp.getStep();
+    Value newStepValue, upperBoundUnrolled;
+    bool createHandlerLoop = false;
+    auto lowerBoundConst = lowerBound.getDefiningOp<arith::ConstantIndexOp>();
+    auto upperBoundConst = upperBound.getDefiningOp<arith::ConstantIndexOp>();
+    auto stepConst = step.getDefiningOp<arith::ConstantIndexOp>();
+    if (!lowerBoundConst || !upperBoundConst || !stepConst) {
+      forOp.emitError("Expected constant bounds and step for unrolling.");
+      return failure();
+    }
+    int64_t lowerBoundValue = lowerBoundConst.value();
+    int64_t upperBoundValue = upperBoundConst.value();
+    int64_t stepValue = stepConst.value();
+    int64_t tripCount = (upperBoundValue - lowerBoundValue) / stepValue;
+    int64_t multipliedStepValue = stepValue * unrollFactor;
+    int64_t tripCountUnrolled = tripCount - (tripCount % unrollFactor);
+    int64_t unrolledUpperBound =
+        lowerBoundValue + (tripCountUnrolled * stepValue);
+    createHandlerLoop = unrolledUpperBound < upperBoundValue;
+    if (Block *prevBlock = forOp->getBlock()->getPrevNode())
+      builder.setInsertionPointToEnd(prevBlock);
+    else
+      builder.setInsertionPoint(forOp);
+    if (createHandlerLoop) {
+      upperBoundUnrolled = builder.create<arith::ConstantIndexOp>(
+          forOp.getLoc(), unrolledUpperBound);
+      newStepValue = builder.create<arith::ConstantIndexOp>(
+          forOp.getLoc(), multipliedStepValue);
+      builder.setInsertionPoint(forOp);
+    } else {
+      upperBoundUnrolled = upperBound;
+      newStepValue = step;
     }
-    static bool isInnermostSCFForOp(ForOp op) {
-        return !op.getBody()
-                    ->walk([&](ForOp nestedForOp) {
-                        return WalkResult::interrupt();
-                    })
-                    .wasInterrupted();
+    if (createHandlerLoop) {
+      OpBuilder loopBuilder(forOp->getContext());
+      loopBuilder.setInsertionPoint(forOp->getBlock(),
+                                    std::next(Block::iterator(forOp)));
+      auto handlerForOp = cast<scf::ForOp>(loopBuilder.clone(*forOp));
+      handlerForOp.setLowerBound(upperBoundUnrolled);
+
+      auto results = forOp.getResults();
+      auto handlerResults = handlerForOp.getResults();
+      for (auto element : llvm::zip(results, handlerResults)) {
+        std::get<0>(element).replaceAllUsesWith(std::get<1>(element));
+      }
+      handlerForOp->setOperands(handlerForOp.getNumControlOperands(),
+                                handlerForOp.getInitArgs().size(), results);
+      (void)handlerForOp.promoteIfSingleIteration(rewriter);
     }
+    forOp.setStep(upperBoundUnrolled);
+    forOp.setStep(newStepValue);
+    if (unrollFull == false)
+      generateUnroll(forOp, stepValue);
+    else
+      fullUnroll(forOp, tripCount, stepValue);
+    (void)forOp.promoteIfSingleIteration(rewriter);
+    llvm::errs() << "Completed unroll\n";
+    return success();
+  }
+  /// Returns true when walking the body of `op` finds no scf.for operation.
+  static bool isInnermostSCFForOp(ForOp op) {
+    // The walk is interrupted as soon as the first nested loop is visited.
+    WalkResult nestedLoopSearch =
+        op.getBody()->walk([](ForOp) { return WalkResult::interrupt(); });
+    const bool containsNestedLoop = nestedLoopSearch.wasInterrupted();
+    return !containsNestedLoop;
+  }
 
-    static void gatherInnermostLoops(Operation *op,
-                                    SmallVectorImpl<ForOp> &loops) {
-        op->walk([&](ForOp forOp) {
-        if (isInnermostSCFForOp(forOp))
+  static void gatherInnermostLoops(Operation *op,
+                                   SmallVectorImpl<ForOp> &loops) {
+    /*Function gathers all the innermost scf for loops*/
+    op->walk([&](ForOp forOp) {
+      if (isInnermostSCFForOp(forOp))
         loops.push_back(forOp);
     });
-    }
+  }
 
-    virtual ~LoopUnroll() = default;
+  virtual ~LoopUnroll() = default;
 };
 } // end anonymous namespace
 
-  
-std::unique_ptr<Pass> mlir::createLoopUnroll(int unrollFactor, bool unrollFull) {
-    return std::make_unique<LoopUnroll>(unrollFactor,unrollFull);
+std::unique_ptr<Pass> mlir::createLoopUnroll(int unrollFactor,
+                                             bool unrollFull) {
+  return std::make_unique<LoopUnroll>(unrollFactor, unrollFull);
 }
\ No newline at end of file

>From 2b835b41d21cb208fb21551fd8d0f8f73f3afaee Mon Sep 17 00:00:00 2001
From: Anirudh-Sathish <anirudhsathish at gmail.com>
Date: Tue, 9 Jul 2024 09:11:07 +0000
Subject: [PATCH 10/13] Feat: Update LoopUnrollJam

---
 .../mlir/Dialect/SCF/Transforms/Passes.h      |   4 +
 .../mlir/Dialect/SCF/Transforms/Passes.td     |  10 +
 .../lib/Dialect/SCF/Transforms/CMakeLists.txt |   3 +
 .../Dialect/SCF/Transforms/LoopUnrollJam.cpp  | 220 ++++++++++++++++++
 4 files changed, 237 insertions(+)
 create mode 100644 mlir/lib/Dialect/SCF/Transforms/LoopUnrollJam.cpp

diff --git a/mlir/include/mlir/Dialect/SCF/Transforms/Passes.h b/mlir/include/mlir/Dialect/SCF/Transforms/Passes.h
index d325fa515a0e3..a83925a322834 100644
--- a/mlir/include/mlir/Dialect/SCF/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/SCF/Transforms/Passes.h
@@ -69,6 +69,10 @@ std::unique_ptr<Pass> createCounterPass();
 std::unique_ptr<Pass> createLoopUnroll(
     int unrollFactor = 4, bool unrollFull = false );
 
+// Create pass for loop unroll jam
+std::unique_ptr<Pass> createLoopUnrollJam(
+    int unrollJamFactor=-1);
+
 /// Creates a pass that converts SCF forall loops to SCF parallel loops.
 std::unique_ptr<Pass> createForallToParallelLoopPass();
 
diff --git a/mlir/include/mlir/Dialect/SCF/Transforms/Passes.td b/mlir/include/mlir/Dialect/SCF/Transforms/Passes.td
index 694e75bdd49b6..c5b389e379112 100644
--- a/mlir/include/mlir/Dialect/SCF/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/SCF/Transforms/Passes.td
@@ -141,6 +141,16 @@ def SCFLoopUnroll : Pass<"scf-loop-unroll">{
   ];
 }
 
+def SCFLoopUnrollJam : Pass<"scf-loop-unroll-jam">
+{
+  let summary= "Unroll and Jam SCF ForOps";
+  let constructor = "mlir::createLoopUnrollJam()";
+  let options = [
+    Option<"unrollJamFactor", "unroll-jam-factor", "unsigned","4",
+    "Use this unroll jam factor for all loops (default 4)">,
+  ];
+}
+
 def SCFForallToParallelLoop : Pass<"scf-forall-to-parallel"> {
   let summary = "Convert SCF forall loops to SCF parallel loops";
   let constructor = "mlir::createForallToParallelLoopPass()";
diff --git a/mlir/lib/Dialect/SCF/Transforms/CMakeLists.txt b/mlir/lib/Dialect/SCF/Transforms/CMakeLists.txt
index d363ffe941fce..a3856080efcd0 100644
--- a/mlir/lib/Dialect/SCF/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/SCF/Transforms/CMakeLists.txt
@@ -17,6 +17,9 @@ add_mlir_dialect_library(MLIRSCFTransforms
   TileUsingInterface.cpp
   WrapInZeroTripCheck.cpp
   UpliftWhileToFor.cpp
+  counter.cpp
+  LoopUnroll.cpp
+  LoopUnrollJam.cpp
 
   ADDITIONAL_HEADER_DIRS
   ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/SCF
diff --git a/mlir/lib/Dialect/SCF/Transforms/LoopUnrollJam.cpp b/mlir/lib/Dialect/SCF/Transforms/LoopUnrollJam.cpp
new file mode 100644
index 0000000000000..82ea65f247e8f
--- /dev/null
+++ b/mlir/lib/Dialect/SCF/Transforms/LoopUnrollJam.cpp
@@ -0,0 +1,220 @@
+/*
+File containing pass for loop unroll and jam transformation
+ on scf dialect forOp
+*/
+
+#include "mlir/Dialect/SCF/Transforms/Passes.h"
+
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/SCF/IR/SCF.h"
+#include "mlir/Dialect/SCF/Transforms/Transforms.h"
+#include "mlir/IR/BuiltinOps.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Pass/PassManager.h"
+#include "mlir/Pass/PassRegistry.h"
+#include "mlir/Support/LLVM.h"
+#include "mlir/Interfaces/SideEffectInterfaces.h"
+#include "mlir/Interfaces/LoopLikeInterface.h"
+namespace mlir {
+    #define GEN_PASS_DEF_SCFLOOPUNROLLJAM
+    #include "mlir/Dialect/SCF/Transforms/Passes.h.inc"
+} // namespace mlir
+using namespace llvm;
+using namespace mlir;
+using namespace mlir::scf;
+using scf::ForOp;
+
+
+namespace{
+    struct LoopReduction { // Plain record describing one loop reduction; no other code in this file uses it.
+        arith::AtomicRMWKind kind; // presumably the kind of reduction operation - TODO confirm
+        unsigned position; // presumably the index of the reduced value among the loop's iter_args - TODO confirm
+        Value reducedValue; // SSA value associated with the reduction
+    };
+    /// Collects the maximal runs of consecutive operations that are not scf.for
+    /// loops, for every block reachable from the walked operation. A block's
+    /// terminator is never part of a run. Each run is recorded in `subBlocks`
+    /// as an inclusive (first operation, last operation) iterator pair.
+    struct ForBlockGatherer{
+        SmallVector<std::pair<Block::iterator, Block::iterator>> subBlocks;
+
+        /// Visits every block of every region attached to `op`.
+        void walk(Operation *op) {
+            for (Region &region : op->getRegions()) {
+                for (Block &block : region) {
+                    walk(block);
+                }
+            }
+        }
+
+        /// Splits `block` into runs of non-loop operations and recurses into
+        /// each scf.for found between those runs.
+        void walk(Block &block) {
+            assert(!block.empty() && block.back().hasTrait<OpTrait::IsTerminator>() &&
+                "expected block to have a terminator");
+            const Block::iterator terminatorIt = std::prev(block.end());
+            Block::iterator cursor = block.begin();
+            while (cursor != terminatorIt) {
+                // Skip forward over operations that are not scf.for loops.
+                const Block::iterator runBegin = cursor;
+                for (; cursor != terminatorIt && !isa<ForOp>(&*cursor); ++cursor) {
+                }
+                if (cursor != runBegin) {
+                    subBlocks.emplace_back(runBegin, std::prev(cursor));
+                }
+                // Recurse into all scf.for loops that directly follow the run.
+                while (cursor != terminatorIt && isa<ForOp>(&*cursor)) {
+                    Operation *loopOp = &*cursor;
+                    ++cursor;
+                    walk(loopOp);
+                }
+            }
+        }
+    };
+    struct LoopUnrollJam : public impl::SCFLoopUnrollJamBase<LoopUnrollJam>
+    {
+        /// Builds the pass. When a factor is supplied it is stored into the
+        /// `unrollJamFactor` member; otherwise the member is left untouched.
+        explicit LoopUnrollJam(
+            std::optional<unsigned> unrollJamFactor= std::nullopt){
+                if (!unrollJamFactor.has_value())
+                    return;
+                this->unrollJamFactor = unrollJamFactor.value();
+            }
+        /// Pass entry point: applies loopUnrollJamByFactor to every scf.for
+        /// visited by walking the operation the pass runs on.
+        void runOnOperation() override
+        {
+            auto *root = getOperation();
+            auto unrollJamLoop = [&](ForOp loop) {
+                // The result is discarded, so a loop that cannot be
+                // transformed does not stop the walk.
+                (void)loopUnrollJamByFactor(loop);
+            };
+            root->walk(unrollJamLoop);
+        }
+        /// Returns the number of iterations executed by `forOp` when its lower
+        /// bound, upper bound and step are all defined by arith constant index
+        /// operations. Returns std::nullopt when any of the three is not such
+        /// a constant, or when the step is not positive.
+        std::optional<uint64_t> getConstantTripCount(scf::ForOp forOp) {
+            auto lowerBoundConst = forOp.getLowerBound().getDefiningOp<arith::ConstantIndexOp>();
+            auto upperBoundConst = forOp.getUpperBound().getDefiningOp<arith::ConstantIndexOp>();
+            auto stepConst = forOp.getStep().getDefiningOp<arith::ConstantIndexOp>();
+
+            if (!lowerBoundConst || !upperBoundConst || !stepConst) {
+                return std::nullopt;
+            }
+
+            int64_t lowerBoundValue = lowerBoundConst.value();
+            int64_t upperBoundValue = upperBoundConst.value();
+            int64_t stepValue = stepConst.value();
+
+            // A zero step would divide by zero below, and scf.for requires a
+            // positive step, so no trip count is reported for such loops.
+            if (stepValue <= 0) {
+                return std::nullopt;
+            }
+            // Empty iteration space. Without this check the negative span
+            // would wrap to a huge value once converted to uint64_t.
+            if (upperBoundValue <= lowerBoundValue) {
+                return uint64_t{0};
+            }
+            // Subtract in unsigned arithmetic so extreme bounds cannot
+            // overflow a signed integer.
+            const uint64_t span = static_cast<uint64_t>(upperBoundValue) -
+                                  static_cast<uint64_t>(lowerBoundValue);
+            const uint64_t stride = static_cast<uint64_t>(stepValue);
+            // Round up: a final partial stride still runs one iteration.
+            return span / stride + (span % stride != 0 ? 1 : 0);
+        }
+        /// Returns true when every scf.for reached by `forOp.walk` has its
+        /// lower bound, upper bound and step defined outside of `forOp`.
+        static bool invariantBounds(scf::ForOp forOp) {
+            WalkResult walkStatus = forOp.walk([&](scf::ForOp nestedLoop) -> WalkResult {
+                const bool allBoundsInvariant =
+                    forOp.isDefinedOutsideOfLoop(nestedLoop.getLowerBound()) &&
+                    forOp.isDefinedOutsideOfLoop(nestedLoop.getUpperBound()) &&
+                    forOp.isDefinedOutsideOfLoop(nestedLoop.getStep());
+                // Stop at the first loop whose bounds depend on `forOp`.
+                if (allBoundsInvariant)
+                    return WalkResult::advance();
+                return WalkResult::interrupt();
+            });
+            const bool foundVariantBound = walkStatus.wasInterrupted();
+            return !foundVariantBound;
+        }
+
+        /// For each loop in `innerLoops`, calls replaceWithAdditionalYields with
+        /// `unrollJamFactor - 1` extra copies of the loop's init operands and of
+        /// its yield operands, and appends the resulting loop to `newInnerLoops`.
+        /// If the replaced loop is `forOp`, `forOp` is updated to the new loop.
+        /// For every copy i in [1, unrollJamFactor), `mapper[i - 1]` is taught
+        /// to map the first group of iter_args/results of the new loop to the
+        /// i-th group.
+        void duplicateArgs(SmallVector<IRMapping> &mapper, SmallVector<ForOp> &newInnerLoops,
+         ForOp &forOp, IRRewriter &rewriter,
+          const SmallVector<ForOp> &innerLoops, unsigned unrollJamFactor) {
+            for (scf::ForOp oldLoop : innerLoops) {
+                // Snapshot everything that is needed from the old loop before
+                // it gets replaced.
+                ValueRange oldInits = oldLoop.getInits();
+                ValueRange oldYielded =
+                    cast<YieldOp>(oldLoop.getBody()->getTerminator()).getOperands();
+                const unsigned oldNumIterArgs = oldLoop.getRegionIterArgs().size();
+                const bool replacesOuterLoop = (oldLoop == forOp);
+
+                // One extra copy of the inits and yields per additional jammed
+                // iteration.
+                SmallVector<Value> extraInits;
+                SmallVector<Value> extraYields;
+                for (unsigned copy = unrollJamFactor - 1; copy >= 1; --copy) {
+                    extraInits.append(oldInits.begin(), oldInits.end());
+                    extraYields.append(oldYielded.begin(), oldYielded.end());
+                }
+
+                auto yieldExtraValues = [&](OpBuilder &b, Location loc,
+                                            ArrayRef<BlockArgument> newBbArgs) {
+                    return extraYields;
+                };
+                scf::ForOp rebuiltLoop =
+                    cast<scf::ForOp>(*oldLoop.replaceWithAdditionalYields(
+                        rewriter, extraInits, /*replaceInitOperandUsesInLoop=*/false,
+                        yieldExtraValues));
+                newInnerLoops.push_back(rebuiltLoop);
+                if (replacesOuterLoop)
+                    forOp = rebuiltLoop;
+
+                ValueRange rebuiltIterArgs = rebuiltLoop.getRegionIterArgs();
+                ValueRange rebuiltResults = rebuiltLoop.getResults();
+                const unsigned oldNumResults = rebuiltResults.size() / unrollJamFactor;
+                for (unsigned copy = unrollJamFactor - 1; copy >= 1; --copy) {
+                    IRMapping &copyMapping = mapper[copy - 1];
+                    for (unsigned argIdx = 0; argIdx < oldNumIterArgs; ++argIdx) {
+                        copyMapping.map(rebuiltIterArgs[argIdx],
+                                        rebuiltIterArgs[copy * oldNumIterArgs + argIdx]);
+                        copyMapping.map(rebuiltResults[argIdx],
+                                        rebuiltResults[copy * oldNumResults + argIdx]);
+                    }
+                }
+            }
+        }
+
+        /// Emits, at the insertion point of `builder`, the IR that computes
+        /// `iv + i * step` and returns the resulting value. `i` is materialized
+        /// as an index constant located at `iv`'s location.
+        Value createUpdatedInductionVar(unsigned i, Value iv, OpBuilder &builder,
+                                  Value step) {
+            const Location ivLoc = iv.getLoc();
+            Value iterationNumber = builder.create<arith::ConstantIndexOp>(ivLoc, i);
+            Value offset = builder.create<arith::MulIOp>(ivLoc, iterationNumber, step);
+            return builder.create<arith::AddIOp>(ivLoc, iv, offset);
+        }
+
+        /// Replaces the step of `forOp` with `step * unrollJamFactor`. The
+        /// operations computing the new step are created (or folded) directly
+        /// in front of `forOp`.
+        void updateStep(ForOp forOp,IRRewriter &rewriter)
+        {
+            // Always insert right before the loop. Appending to the end of the
+            // previous block would place the new operations after that block's
+            // terminator, and that block is not guaranteed to dominate the loop.
+            rewriter.setInsertionPoint(forOp);
+            Value factor = rewriter.createOrFold<arith::ConstantOp>(
+                forOp.getLoc(), rewriter.getIndexAttr(unrollJamFactor));
+            Value newStep = rewriter.createOrFold<arith::MulIOp>(
+                forOp.getLoc(), forOp.getStep(), factor);
+            forOp.setStep(newStep);
+        }
+        /// Performs the jamming step. For every copy i in
+        /// [1, unrollJamFactor), visited from the highest i down to 1:
+        ///  * each sub-block is cloned right after its last original operation
+        ///    using `mapper[i - 1]`; when the induction variable has uses, it
+        ///    is first mapped to `iv + i * forOp.getStep()`;
+        ///  * for each loop in `newInnerLoops`, the i-th group of init operands
+        ///    and of yield operands is set to the mapped counterpart of the
+        ///    first group.
+        void finalUnrollJam(Value forOpInductionVar,
+         SmallVector<std::pair<Block::iterator, Block::iterator>> &subBlocks,
+         SmallVector<IRMapping> &mapper, const SmallVector<ForOp> &newInnerLoops, ForOp forOp) {
+            for (unsigned copy = unrollJamFactor - 1; copy >= 1; --copy) {
+                IRMapping &copyMapping = mapper[copy - 1];
+                for (auto &subBlock : subBlocks) {
+                    Block::iterator firstOp = subBlock.first;
+                    Block::iterator lastOp = subBlock.second;
+                    OpBuilder builder(firstOp->getBlock(), std::next(lastOp));
+                    if (!forOpInductionVar.use_empty()) {
+                        Value shiftedIv = createUpdatedInductionVar(
+                            copy, forOpInductionVar, builder, forOp.getStep());
+                        copyMapping.map(forOpInductionVar, shiftedIv);
+                    }
+                    // std::next(lastOp) must be re-evaluated on every iteration:
+                    // the operations created above are inserted right after
+                    // `lastOp`, so the end of the range moves as they appear.
+                    for (Block::iterator it = firstOp; it != std::next(lastOp); ++it) {
+                        builder.clone(*it, copyMapping);
+                    }
+                }
+                for (auto jammedLoop : newInnerLoops) {
+                    const unsigned oldNumIterOperands =
+                        jammedLoop.getNumRegionIterArgs() / unrollJamFactor;
+                    const unsigned numControlOperands = jammedLoop.getNumControlOperands();
+                    auto yieldOp = cast<scf::YieldOp>(jammedLoop.getBody()->getTerminator());
+                    const unsigned oldNumYieldOperands =
+                        yieldOp.getNumOperands() / unrollJamFactor;
+                    for (unsigned argIdx = 0; argIdx < oldNumIterOperands; ++argIdx) {
+                        Value mappedInit = copyMapping.lookupOrDefault(
+                            jammedLoop.getOperand(numControlOperands + argIdx));
+                        jammedLoop.setOperand(
+                            numControlOperands + copy * oldNumIterOperands + argIdx, mappedInit);
+                        Value mappedYield =
+                            copyMapping.lookupOrDefault(yieldOp.getOperand(argIdx));
+                        yieldOp.setOperand(copy * oldNumYieldOperands + argIdx, mappedYield);
+                    }
+                }
+            }
+        }
+        /// Unroll-and-jams `forOp` by `unrollJamFactor`.
+        ///
+        /// Returns success when the loop was transformed or when there was
+        /// nothing to do (factor of 1, zero trip count, or a body holding only
+        /// its terminator). Returns failure when the factor is 0, the bounds of
+        /// a visited loop are not invariant, the trip count is not a known
+        /// constant, or the trip count is not a multiple of the factor.
+        ///
+        /// When the factor exceeds the trip count it is clamped to the trip
+        /// count for this loop only; the configured factor is restored before
+        /// returning.
+        LogicalResult loopUnrollJamByFactor(ForOp forOp)
+        {
+            if(unrollJamFactor ==1) return success();
+            // A factor of 0 would be used as a modulo divisor below.
+            if(unrollJamFactor ==0) return failure();
+            if(!invariantBounds(forOp))return failure();
+            std::optional<uint64_t> tripCount = getConstantTripCount(forOp);
+            // Without a constant trip count the divisibility requirement
+            // cannot be checked, so the loop is left alone.
+            if (!tripCount) return failure();
+            // Nothing to jam. Clamping the factor to 0 would also make
+            // `factor - 1` wrap around further down.
+            if (*tripCount == 0) return success();
+
+            const unsigned configuredFactor = unrollJamFactor;
+            unsigned factor = configuredFactor;
+            if (factor > *tripCount) factor = static_cast<unsigned>(*tripCount);
+            else if (*tripCount % factor != 0) return failure();
+            if (llvm::hasSingleElement(forOp.getBody()->getOperations())) return success();
+
+            // updateStep and finalUnrollJam read the member, so publish the
+            // per-loop factor for the duration of this transformation.
+            unrollJamFactor = factor;
+
+            auto bg = ForBlockGatherer();
+            bg.walk(forOp);
+            auto &subBlocks = bg.subBlocks;
+
+            SmallVector<ForOp> innerLoops;
+            forOp.walk([&](ForOp innerForOp) { innerLoops.push_back(innerForOp); });
+
+            SmallVector<IRMapping> mapper(factor - 1);
+            IRRewriter rewriter(forOp.getContext());
+            SmallVector<ForOp> newInnerLoops;
+            duplicateArgs(mapper, newInnerLoops, forOp, rewriter, innerLoops, factor);
+            auto forOpInductionVar = forOp.getInductionVar();
+            // Jam before scaling the step: finalUnrollJam reads forOp.getStep()
+            // to build `iv + i * step` and must see the original step, not the
+            // one already multiplied by the factor.
+            finalUnrollJam(forOpInductionVar,subBlocks,mapper,newInnerLoops,forOp);
+            updateStep(forOp,rewriter);
+            (void)forOp.promoteIfSingleIteration(rewriter);
+
+            // Do not let a clamp made for this loop leak into later loops.
+            unrollJamFactor = configuredFactor;
+            return success();
+        } 
+
+    };
+} // namespace ends
+
+/// Creates the unroll-and-jam pass. A factor of -1 means "no explicit
+/// factor": the pass is then constructed without one.
+std::unique_ptr<Pass> mlir::createLoopUnrollJam(int unrollJamFactor){
+        std::optional<unsigned> explicitFactor;
+        if (unrollJamFactor != -1)
+            explicitFactor = static_cast<unsigned>(unrollJamFactor);
+        return std::make_unique<LoopUnrollJam>(explicitFactor);
+    }
\ No newline at end of file

>From 5bfa0af50037f1ecc32aaf07242a0c8870611069 Mon Sep 17 00:00:00 2001
From: Anirudh-Sathish <anirudhsathish at gmail.com>
Date: Wed, 10 Jul 2024 03:46:09 +0000
Subject: [PATCH 11/13] Fix: Add code to run loop unroll jam without
 reductions and add test cases

---
 .../Dialect/SCF/Transforms/LoopUnrollJam.cpp  | 39 +++++++++++----
 .../test/Dialect/SCF/loop-unroll-jam-scf.mlir | 47 +++++++++++++++++++
 2 files changed, 78 insertions(+), 8 deletions(-)
 create mode 100644 mlir/test/Dialect/SCF/loop-unroll-jam-scf.mlir

diff --git a/mlir/lib/Dialect/SCF/Transforms/LoopUnrollJam.cpp b/mlir/lib/Dialect/SCF/Transforms/LoopUnrollJam.cpp
index 82ea65f247e8f..6679816e9d77a 100644
--- a/mlir/lib/Dialect/SCF/Transforms/LoopUnrollJam.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/LoopUnrollJam.cpp
@@ -16,6 +16,7 @@ File containing pass for loop unroll and jam transformation
 #include "mlir/Support/LLVM.h"
 #include "mlir/Interfaces/SideEffectInterfaces.h"
 #include "mlir/Interfaces/LoopLikeInterface.h"
+#include <map>
 namespace mlir {
     #define GEN_PASS_DEF_SCFLOOPUNROLLJAM
     #include "mlir/Dialect/SCF/Transforms/Passes.h.inc"
@@ -27,11 +28,6 @@ using scf::ForOp;
 
 
 namespace{
-    struct LoopReduction {
-        arith::AtomicRMWKind kind;
-        unsigned position;
-        Value reducedValue;
-    };
     struct ForBlockGatherer{
         SmallVector<std::pair<Block::iterator, Block::iterator>> subBlocks;
         void walk(Operation *op) {
@@ -63,11 +59,25 @@ namespace{
             }
         void runOnOperation() override
         {
+            std::map<ForOp, int> outerLoopMap;
+            std::map<ForOp, int> innerLoopMap;
             auto *op = getOperation();
             op->walk([&](ForOp forOp)
             {
-                (void)loopUnrollJamByFactor(forOp);
+                outerLoopMap[forOp]=0;
+                innerLoopMap[forOp]=0;
             });
+            op->walk([&](ForOp forOp)
+            {
+                loopGatherer(forOp,outerLoopMap,innerLoopMap);
+            });
+            for(auto ele :innerLoopMap)
+            {
+                if(ele.second ==0)
+                {
+                    (void)loopUnrollJamByFactor(ele.first);
+                }
+            }
         }
         std::optional<uint64_t> getConstantTripCount(scf::ForOp forOp) {
             // This is a placeholder implementation. You need to adapt it to your use case.
@@ -95,6 +105,21 @@ namespace{
             });
             return !walkResult.wasInterrupted();
         }
+        void loopNestMapper(ForOp op,std::map<ForOp, int> &outerLoopMap,
+        std::map<ForOp, int> &innerLoopMap) {
+            /*Identifies if a given for loop is the inner most scf ForOp*/
+            op.getBody()->walk([&](ForOp nestedForOp){
+                outerLoopMap[op]+=1;
+                innerLoopMap[nestedForOp]+=1;
+                });
+        }
+        void loopGatherer(Operation *op, std::map<ForOp, int> &outerLoopMap,
+        std::map<ForOp, int> &innerLoopMap) {
+            /*Function gathers all the innermost scf for loops*/
+            op->walk([&](ForOp forOp) {
+                loopNestMapper(forOp,outerLoopMap,innerLoopMap);
+            });
+        }
 
         void duplicateArgs(SmallVector<IRMapping> &mapper, SmallVector<ForOp> &newInnerLoops,
          ForOp &forOp, IRRewriter &rewriter,
@@ -149,7 +174,6 @@ namespace{
                 rewriter.setInsertionPointToEnd(prevBlock);
             else
                 rewriter.setInsertionPoint(forOp);
-            // rewriter.setInsertionPoint(forOp);
             auto newStep = rewriter.createOrFold<arith::MulIOp>(
                 forOp.getLoc(), forOp.getStep(),
                 rewriter.createOrFold<arith::ConstantOp>(
@@ -186,7 +210,6 @@ namespace{
         LogicalResult loopUnrollJamByFactor(ForOp forOp)
         {
             if(unrollJamFactor ==1) return success();
-            if(!invariantBounds(forOp))return failure();
             std::optional<uint64_t> tripCount = getConstantTripCount(forOp);
             if (unrollJamFactor > *tripCount) unrollJamFactor = *tripCount;
             else if (*tripCount % unrollJamFactor != 0) return failure();
diff --git a/mlir/test/Dialect/SCF/loop-unroll-jam-scf.mlir b/mlir/test/Dialect/SCF/loop-unroll-jam-scf.mlir
new file mode 100644
index 0000000000000..2e4b331c59c03
--- /dev/null
+++ b/mlir/test/Dialect/SCF/loop-unroll-jam-scf.mlir
@@ -0,0 +1,47 @@
+// RUN: mlir-opt %s --scf-loop-unroll-jam="unroll-jam-factor=2"  --split-input-file | FileCheck %s
+
+module {
+  func.func @main() -> f32 {
+    %sum = arith.constant 0.0 : f32
+    %val = arith.constant 2.0 : f32
+    %N = arith.constant 4 : index
+    %num = arith.constant 10 : index
+    %c0 = arith.constant 0 : index
+    %c1 = arith.constant 1 : index
+
+    %result = scf.for %i = %c0 to %N step %c1 iter_args(%iter_sum = %sum) -> (f32) {
+      %new_sum = arith.addf %iter_sum, %val : f32
+      %result2 = scf.for %j = %c0 to %num step %c1 iter_args(%iter_sum2 = %new_sum) -> (f32) {
+        %new_sum2 = arith.addf %iter_sum2, %val : f32
+        scf.yield %new_sum2 : f32
+      }
+      %new_sum3 = arith.addf %result2, %val : f32
+      scf.yield %new_sum : f32
+    }
+    return %result : f32
+  }
+}
+
+// CHECK-LABEL: func.func @main() -> f32 {
+// CHECK-NEXT:   %cst = arith.constant 0.000000e+00 : f32
+// CHECK-NEXT:   %cst_0 = arith.constant 2.000000e+00 : f32
+// CHECK-NEXT:   %c4 = arith.constant 4 : index
+// CHECK-NEXT:   %c10 = arith.constant 10 : index
+// CHECK-NEXT:   %c0 = arith.constant 0 : index
+// CHECK-NEXT:   %c1 = arith.constant 1 : index
+// CHECK-NEXT:   %c2 = arith.constant 2 : index
+// CHECK-NEXT:   %c2_1 = arith.constant 2 : index
+// CHECK-NEXT:   %0:2 = scf.for %arg0 = %c0 to %c4 step %c2_1 iter_args(%arg1 = %cst, %arg2 = %cst) -> (f32, f32) {
+// CHECK-NEXT:     %1 = arith.addf %arg1, %cst_0 : f32
+// CHECK-NEXT:     %2 = arith.addf %arg2, %cst_0 : f32
+// CHECK-NEXT:     %3:2 = scf.for %arg3 = %c0 to %c10 step %c1 iter_args(%arg4 = %1, %arg5 = %2) -> (f32, f32) {
+// CHECK-NEXT:       %6 = arith.addf %arg4, %cst_0 : f32
+// CHECK-NEXT:       %7 = arith.addf %arg5, %cst_0 : f32
+// CHECK-NEXT:       scf.yield %6, %7 : f32, f32
+// CHECK-NEXT:     }
+// CHECK-NEXT:     %4 = arith.addf %3#0, %cst_0 : f32
+// CHECK-NEXT:     %5 = arith.addf %3#1, %cst_0 : f32
+// CHECK-NEXT:     scf.yield %1, %2 : f32, f32
+// CHECK-NEXT:   }
+// CHECK-NEXT:   return %0#0 : f32
+// CHECK-NEXT: }

>From 025240d160d0376232af864d73fda8bd84aa288f Mon Sep 17 00:00:00 2001
From: Anirudh-Sathish <anirudhsathish at gmail.com>
Date: Wed, 10 Jul 2024 11:33:44 +0000
Subject: [PATCH 12/13] Feat: Add code to support reductions

---
 .../Dialect/SCF/Transforms/LoopUnrollJam.cpp  | 97 ++++++++++++++++++-
 1 file changed, 92 insertions(+), 5 deletions(-)

diff --git a/mlir/lib/Dialect/SCF/Transforms/LoopUnrollJam.cpp b/mlir/lib/Dialect/SCF/Transforms/LoopUnrollJam.cpp
index 6679816e9d77a..7319af950ed3f 100644
--- a/mlir/lib/Dialect/SCF/Transforms/LoopUnrollJam.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/LoopUnrollJam.cpp
@@ -16,7 +16,10 @@ File containing pass for loop unroll and jam transformation
 #include "mlir/Support/LLVM.h"
 #include "mlir/Interfaces/SideEffectInterfaces.h"
 #include "mlir/Interfaces/LoopLikeInterface.h"
+#include "mlir/Dialect/Affine/Analysis/AffineAnalysis.h"
 #include <map>
+#include "../lib/Analysis/SliceAnalysis.cpp"
+#include "llvm/ADT/TypeSwitch.h"
 namespace mlir {
     #define GEN_PASS_DEF_SCFLOOPUNROLLJAM
     #include "mlir/Dialect/SCF/Transforms/Passes.h.inc"
@@ -24,6 +27,7 @@ namespace mlir {
 using namespace llvm;
 using namespace mlir;
 using namespace mlir::scf;
+using namespace mlir::affine;
 using scf::ForOp;
 
 
@@ -80,7 +84,6 @@ namespace{
             }
         }
         std::optional<uint64_t> getConstantTripCount(scf::ForOp forOp) {
-            // This is a placeholder implementation. You need to adapt it to your use case.
             auto lowerBoundConst = forOp.getLowerBound().getDefiningOp<arith::ConstantIndexOp>();
             auto upperBoundConst = forOp.getUpperBound().getDefiningOp<arith::ConstantIndexOp>();
             auto stepConst = forOp.getStep().getDefiningOp<arith::ConstantIndexOp>();
@@ -107,7 +110,6 @@ namespace{
         }
         void loopNestMapper(ForOp op,std::map<ForOp, int> &outerLoopMap,
         std::map<ForOp, int> &innerLoopMap) {
-            /*Identifies if a given for loop is the inner most scf ForOp*/
             op.getBody()->walk([&](ForOp nestedForOp){
                 outerLoopMap[op]+=1;
                 innerLoopMap[nestedForOp]+=1;
@@ -115,12 +117,10 @@ namespace{
         }
         void loopGatherer(Operation *op, std::map<ForOp, int> &outerLoopMap,
         std::map<ForOp, int> &innerLoopMap) {
-            /*Function gathers all the innermost scf for loops*/
             op->walk([&](ForOp forOp) {
                 loopNestMapper(forOp,outerLoopMap,innerLoopMap);
             });
         }
-
         void duplicateArgs(SmallVector<IRMapping> &mapper, SmallVector<ForOp> &newInnerLoops,
          ForOp &forOp, IRRewriter &rewriter,
           const SmallVector<ForOp> &innerLoops, unsigned unrollJamFactor) {
@@ -160,7 +160,6 @@ namespace{
 
         Value createUpdatedInductionVar(unsigned i, Value iv, OpBuilder &builder,
                                   Value step) {
-            /*Function writes ir to update the induction variable*/
             Location loc = iv.getLoc();
             Value constantI = builder.create<arith::ConstantIndexOp>(loc, i);
             Value increment =builder.create<arith::MulIOp>(loc, constantI, step);
@@ -207,6 +206,85 @@ namespace{
                 }
             }
         }
+        void getReductions(ForOp forOp, SmallVectorImpl<LoopReduction> &reductions) {
+            ValueRange iterArgs = forOp.getRegionIterArgs();
+            unsigned numIterArgs = iterArgs.size();
+
+            if (numIterArgs == 0)
+            {
+                llvm::errs()<<"Iter args == 0";
+                return;   
+            }
+            llvm::errs()<<"Iter args"<<numIterArgs;
+            reductions.reserve(numIterArgs);
+            for (unsigned i = 0; i < numIterArgs; ++i) {
+                arith::AtomicRMWKind kind;
+                if (Value value = getReduction(forOp, i, kind))
+                {
+                    llvm::errs()<<"Inside here \n";
+                    reductions.emplace_back(LoopReduction{kind, i, value});
+                    llvm::errs()<<"\nValue "<<value;
+                }
+                
+            }
+        }
+        Value getReduction(ForOp forOp, unsigned pos,
+                                    arith::AtomicRMWKind &kind) {
+            SmallVector<Operation *> combinerOps;
+            Value reducedVal = matchReduction(forOp.getRegionIterArgs(), pos, combinerOps);
+            if (!reducedVal)return nullptr;
+    
+            if (combinerOps.size() > 1)return nullptr;
+            
+            Operation *combinerOp = combinerOps.back();
+            std::optional<arith::AtomicRMWKind> maybeKind =
+                mlir::TypeSwitch<Operation *, std::optional<arith::AtomicRMWKind>>(combinerOp)
+                    .Case([](arith::AddFOp) { return arith::AtomicRMWKind::addf; })
+                    .Case([](arith::MulFOp) { return arith::AtomicRMWKind::mulf; })
+                    .Case([](arith::AddIOp) { return arith::AtomicRMWKind::addi; })
+                    .Case([](arith::AndIOp) { return arith::AtomicRMWKind::andi; })
+                    .Case([](arith::OrIOp) { return arith::AtomicRMWKind::ori; })
+                    .Case([](arith::MulIOp) { return arith::AtomicRMWKind::muli; })
+                    .Case(
+                        [](arith::MinimumFOp) { return arith::AtomicRMWKind::minimumf; })
+                    .Case(
+                        [](arith::MaximumFOp) { return arith::AtomicRMWKind::maximumf; })
+                    .Case([](arith::MinSIOp) { return arith::AtomicRMWKind::mins; })
+                    .Case([](arith::MaxSIOp) { return arith::AtomicRMWKind::maxs; })
+                    .Case([](arith::MinUIOp) { return arith::AtomicRMWKind::minu; })
+                    .Case([](arith::MaxUIOp) { return arith::AtomicRMWKind::maxu; })
+                    .Default([](Operation *) -> std::optional<arith::AtomicRMWKind> {
+                        return std::nullopt;
+                    });
+            if (!maybeKind)
+                return nullptr;
+            
+            kind = *maybeKind;
+            return reducedVal;
+        }
+        void updateReductions(ForOp forOp, IRRewriter &rewriter,
+        SmallVector<LoopReduction> &reductions)
+        {
+            rewriter.setInsertionPointAfter(forOp);
+            auto loc = forOp.getLoc();
+            unsigned oldNumResults = forOp.getNumResults() / unrollJamFactor;
+            for (LoopReduction &reduction : reductions) {
+                unsigned pos = reduction.iterArgPosition;
+                Value lhs = forOp.getResult(pos);
+                Value rhs;
+                SmallPtrSet<Operation *, 4> newOps;
+                for (unsigned i = unrollJamFactor - 1; i >= 1; --i) {
+                    rhs = forOp.getResult(i * oldNumResults + pos);
+                    lhs = arith::getReductionOp(reduction.kind, rewriter, loc, lhs, rhs);
+                    if (!lhs)
+                    return;
+                    Operation *op = lhs.getDefiningOp();
+                    assert(op && "Reduction op should have been created");
+                    newOps.insert(op);
+                }
+            forOp.getResult(pos).replaceAllUsesExcept(lhs, newOps);
+            }
+        }   
         LogicalResult loopUnrollJamByFactor(ForOp forOp)
         {
             if(unrollJamFactor ==1) return success();
@@ -222,6 +300,12 @@ namespace{
             SmallVector<ForOp> innerLoops;
             forOp.walk([&](ForOp innerForOp) { innerLoops.push_back(innerForOp); });
 
+            SmallVector<LoopReduction> reductions;
+            ValueRange iterArgs = forOp.getRegionIterArgs();
+            unsigned numIterOperands = iterArgs.size();
+            if (numIterOperands > 0)
+                getReductions(forOp, reductions);
+
             SmallVector<IRMapping> mapper(unrollJamFactor - 1);
             IRRewriter rewriter(forOp.getContext());
             SmallVector<ForOp> newInnerLoops;
@@ -229,6 +313,9 @@ namespace{
             updateStep(forOp,rewriter); 
             auto forOpInductionVar = forOp.getInductionVar();
             finalUnrollJam(forOpInductionVar,subBlocks,mapper,newInnerLoops,forOp);
+            if (forOp.getNumResults() > 0) {
+                updateReductions(forOp,rewriter,reductions);
+            }
             (void)forOp.promoteIfSingleIteration(rewriter);
             return success();
         } 

>From 933f41944dd5416aa85b62626fd4b15a51495533 Mon Sep 17 00:00:00 2001
From: Anirudh-Sathish <anirudhsathish at gmail.com>
Date: Wed, 10 Jul 2024 12:35:20 +0000
Subject: [PATCH 13/13] Refactor and Test: Refactor the code and test case

---
 .../Dialect/SCF/Transforms/LoopUnrollJam.cpp  | 583 +++++++++---------
 .../test/Dialect/SCF/loop-unroll-jam-scf.mlir |  35 +-
 2 files changed, 311 insertions(+), 307 deletions(-)

diff --git a/mlir/lib/Dialect/SCF/Transforms/LoopUnrollJam.cpp b/mlir/lib/Dialect/SCF/Transforms/LoopUnrollJam.cpp
index 7319af950ed3f..ef1d1f291f43d 100644
--- a/mlir/lib/Dialect/SCF/Transforms/LoopUnrollJam.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/LoopUnrollJam.cpp
@@ -5,24 +5,24 @@ File containing pass for loop unroll and jam transformation
 
 #include "mlir/Dialect/SCF/Transforms/Passes.h"
 
+#include "../lib/Analysis/SliceAnalysis.cpp"
+#include "mlir/Dialect/Affine/Analysis/AffineAnalysis.h"
 #include "mlir/Dialect/Func/IR/FuncOps.h"
 #include "mlir/Dialect/SCF/IR/SCF.h"
 #include "mlir/Dialect/SCF/Transforms/Transforms.h"
 #include "mlir/IR/BuiltinOps.h"
 #include "mlir/IR/PatternMatch.h"
+#include "mlir/Interfaces/LoopLikeInterface.h"
+#include "mlir/Interfaces/SideEffectInterfaces.h"
 #include "mlir/Pass/Pass.h"
 #include "mlir/Pass/PassManager.h"
 #include "mlir/Pass/PassRegistry.h"
 #include "mlir/Support/LLVM.h"
-#include "mlir/Interfaces/SideEffectInterfaces.h"
-#include "mlir/Interfaces/LoopLikeInterface.h"
-#include "mlir/Dialect/Affine/Analysis/AffineAnalysis.h"
-#include <map>
-#include "../lib/Analysis/SliceAnalysis.cpp"
 #include "llvm/ADT/TypeSwitch.h"
+#include <map>
 namespace mlir {
-    #define GEN_PASS_DEF_SCFLOOPUNROLLJAM
-    #include "mlir/Dialect/SCF/Transforms/Passes.h.inc"
+#define GEN_PASS_DEF_SCFLOOPUNROLLJAM
+#include "mlir/Dialect/SCF/Transforms/Passes.h.inc"
 } // namespace mlir
 using namespace llvm;
 using namespace mlir;
@@ -30,301 +30,304 @@ using namespace mlir::scf;
 using namespace mlir::affine;
 using scf::ForOp;
 
+namespace {
+struct ForBlockGatherer {
+  SmallVector<std::pair<Block::iterator, Block::iterator>> subBlocks;
+  void walk(Operation *op) {
+    for (Region &region : op->getRegions())
+      for (Block &block : region)
+        walk(block);
+  }
+  void walk(Block &block) {
+    assert(!block.empty() && block.back().hasTrait<OpTrait::IsTerminator>() &&
+           "expected block to have a terminator");
+    for (Block::iterator it = block.begin(), e = std::prev(block.end());
+         it != e;) {
+      Block::iterator subBlockStart = it;
+      while (it != e && !isa<ForOp>(&*it))
+        ++it;
+      if (it != subBlockStart)
+        subBlocks.emplace_back(subBlockStart, std::prev(it));
+      while (it != e && isa<ForOp>(&*it))
+        walk(&*it++);
+    }
+  }
+};
+struct LoopUnrollJam : public impl::SCFLoopUnrollJamBase<LoopUnrollJam> {
+  explicit LoopUnrollJam(
+      std::optional<unsigned> unrollJamFactor = std::nullopt) {
+    if (unrollJamFactor)
+      this->unrollJamFactor = *unrollJamFactor;
+  }
+  void runOnOperation() override {
+    std::map<ForOp, int> outerLoopMap;
+    std::map<ForOp, int> innerLoopMap;
+    auto *op = getOperation();
+    op->walk([&](ForOp forOp) {
+      outerLoopMap[forOp] = 0;
+      innerLoopMap[forOp] = 0;
+    });
+    op->walk(
+        [&](ForOp forOp) { loopGatherer(forOp, outerLoopMap, innerLoopMap); });
+    for (auto ele : innerLoopMap) {
+      if (ele.second == 0) {
+        (void)loopUnrollJamByFactor(ele.first);
+      }
+    }
+  }
+
+  // Obtains trip count for a given for op
+  std::optional<uint64_t> getTripCount(scf::ForOp forOp) {
+    auto lowerBoundConst =
+        forOp.getLowerBound().getDefiningOp<arith::ConstantIndexOp>();
+    auto upperBoundConst =
+        forOp.getUpperBound().getDefiningOp<arith::ConstantIndexOp>();
+    auto stepConst = forOp.getStep().getDefiningOp<arith::ConstantIndexOp>();
+
+    if (!lowerBoundConst || !upperBoundConst || !stepConst) {
+      return std::nullopt;
+    }
+
+    int64_t lowerBoundValue = lowerBoundConst.value();
+    int64_t upperBoundValue = upperBoundConst.value();
+    int64_t stepValue = stepConst.value();
+    return (upperBoundValue - lowerBoundValue) / stepValue;
+  }
+  // Maps the nesting levels of the loops
+  void loopNestMapper(ForOp op, std::map<ForOp, int> &outerLoopMap,
+                      std::map<ForOp, int> &innerLoopMap) {
+    op.getBody()->walk([&](ForOp nestedForOp) {
+      outerLoopMap[op] += 1;
+      innerLoopMap[nestedForOp] += 1;
+    });
+  }
 
-namespace{
-    struct ForBlockGatherer{
-        SmallVector<std::pair<Block::iterator, Block::iterator>> subBlocks;
-        void walk(Operation *op) {
-            for (Region &region : op->getRegions())
-            for (Block &block : region)
-                walk(block);
+  // Gathers loops to map their nesting levels
+  void loopGatherer(Operation *op, std::map<ForOp, int> &outerLoopMap,
+                    std::map<ForOp, int> &innerLoopMap) {
+    op->walk([&](ForOp forOp) {
+      loopNestMapper(forOp, outerLoopMap, innerLoopMap);
+    });
+  }
+
+  // duplicates loop argument for unrolling and jamming
+  void duplicateArgs(SmallVector<IRMapping> &mapper,
+                     SmallVector<ForOp> &newInnerLoops, ForOp &forOp,
+                     IRRewriter &rewriter, const SmallVector<ForOp> &innerLoops,
+                     unsigned unrollJamFactor) {
+    for (scf::ForOp currentForOp : innerLoops) {
+      SmallVector<Value> iterOperandsList, yeildOperandsList;
+      ValueRange previousIterOperands = currentForOp.getInits();
+      ValueRange previousIterArgs = currentForOp.getRegionIterArgs();
+      ValueRange previousYeildOperands =
+          cast<YieldOp>(currentForOp.getBody()->getTerminator()).getOperands();
+      for (unsigned i = unrollJamFactor - 1; i >= 1; --i) {
+        iterOperandsList.append(previousIterOperands.begin(),
+                                previousIterOperands.end());
+        yeildOperandsList.append(previousYeildOperands.begin(),
+                                 previousYeildOperands.end());
+      }
+      bool forOpReplaced = currentForOp == forOp;
+      scf::ForOp newForOp =
+          cast<scf::ForOp>(*currentForOp.replaceWithAdditionalYields(
+              rewriter, iterOperandsList,
+              /*replaceInitOperandUsesInLoop=*/false,
+              [&](OpBuilder &b, Location loc,
+                  ArrayRef<BlockArgument> newBbArgs) {
+                return yeildOperandsList;
+              }));
+      newInnerLoops.push_back(newForOp);
+
+      if (forOpReplaced)
+        forOp = newForOp;
+      ValueRange newIterArgs = newForOp.getRegionIterArgs();
+      unsigned oldNumIterArgs = previousIterArgs.size();
+      ValueRange newResults = newForOp.getResults();
+      unsigned oldNumResults = newResults.size() / unrollJamFactor;
+      for (unsigned i = unrollJamFactor - 1; i >= 1; --i) {
+        for (unsigned j = 0; j < oldNumIterArgs; ++j) {
+          mapper[i - 1].map(newIterArgs[j],
+                            newIterArgs[i * oldNumIterArgs + j]);
+          mapper[i - 1].map(newResults[j], newResults[i * oldNumResults + j]);
         }
-        void walk(Block &block) {
-            assert(!block.empty() && block.back().hasTrait<OpTrait::IsTerminator>() &&
-                "expected block to have a terminator");
-            for (Block::iterator it = block.begin(), e = std::prev(block.end());
-                it != e;) {
-            Block::iterator subBlockStart = it;
-            while (it != e && !isa<ForOp>(&*it))
-                ++it;
-            if (it != subBlockStart)
-                subBlocks.emplace_back(subBlockStart, std::prev(it));
-            // Process all for ops that appear next.
-            while (it != e && isa<ForOp>(&*it))
-                walk(&*it++);
-            }
+      }
+    }
+  }
+  // creates an updated induction variable
+  Value createUpdatedInductionVar(unsigned i, Value iv, OpBuilder &builder,
+                                  Value step) {
+    Location loc = iv.getLoc();
+    Value constantI = builder.create<arith::ConstantIndexOp>(loc, i);
+    Value increment = builder.create<arith::MulIOp>(loc, constantI, step);
+    Value updatedIv = builder.create<arith::AddIOp>(loc, iv, increment);
+    return updatedIv;
+  }
+  // updates the step of a given loop
+  void updateStep(ForOp forOp, IRRewriter &rewriter) {
+    if (Block *prevBlock = forOp->getBlock()->getPrevNode())
+      rewriter.setInsertionPointToEnd(prevBlock);
+    else
+      rewriter.setInsertionPoint(forOp);
+    auto newStep = rewriter.createOrFold<arith::MulIOp>(
+        forOp.getLoc(), forOp.getStep(),
+        rewriter.createOrFold<arith::ConstantOp>(
+            forOp.getLoc(), rewriter.getIndexAttr(unrollJamFactor)));
+    forOp.setStep(newStep);
+  }
+
+  // performs the final unroll and jam operations, cloning and updating the
+  // subblocks
+  void finalUnrollJam(
+      Value forOpInductionVar,
+      SmallVector<std::pair<Block::iterator, Block::iterator>> &subBlocks,
+      SmallVector<IRMapping> &mapper, const SmallVector<ForOp> &newInnerLoops,
+      ForOp forOp) {
+    for (unsigned i = unrollJamFactor - 1; i >= 1; --i) {
+      for (auto &subBlock : subBlocks) {
+        OpBuilder builder(subBlock.first->getBlock(),
+                          std::next(subBlock.second));
+        if (!forOpInductionVar.use_empty()) {
+          auto updatedInductionVar = createUpdatedInductionVar(
+              i, forOpInductionVar, builder, forOp.getStep());
+          mapper[i - 1].map(forOpInductionVar, updatedInductionVar);
         }
-        };
-    struct LoopUnrollJam : public impl::SCFLoopUnrollJamBase<LoopUnrollJam>
-    {
-        explicit LoopUnrollJam(
-            std::optional<unsigned> unrollJamFactor= std::nullopt){
-                if(unrollJamFactor) this->unrollJamFactor = *unrollJamFactor;
-            }
-        void runOnOperation() override
-        {
-            std::map<ForOp, int> outerLoopMap;
-            std::map<ForOp, int> innerLoopMap;
-            auto *op = getOperation();
-            op->walk([&](ForOp forOp)
-            {
-                outerLoopMap[forOp]=0;
-                innerLoopMap[forOp]=0;
-            });
-            op->walk([&](ForOp forOp)
-            {
-                loopGatherer(forOp,outerLoopMap,innerLoopMap);
-            });
-            for(auto ele :innerLoopMap)
-            {
-                if(ele.second ==0)
-                {
-                    (void)loopUnrollJamByFactor(ele.first);
-                }
-            }
+        for (auto it = subBlock.first; it != std::next(subBlock.second); ++it)
+          builder.clone(*it, mapper[i - 1]);
+      }
+      for (auto newForOp : newInnerLoops) {
+        unsigned oldNumIterOperands =
+            newForOp.getNumRegionIterArgs() / unrollJamFactor;
+        unsigned numControlOperands = newForOp.getNumControlOperands();
+        auto yieldOp = cast<scf::YieldOp>(newForOp.getBody()->getTerminator());
+        unsigned oldNumYieldOperands =
+            yieldOp.getNumOperands() / unrollJamFactor;
+        for (unsigned j = 0; j < oldNumIterOperands; ++j) {
+          newForOp.setOperand(numControlOperands + i * oldNumIterOperands + j,
+                              mapper[i - 1].lookupOrDefault(
+                                  newForOp.getOperand(numControlOperands + j)));
+          yieldOp.setOperand(
+              i * oldNumYieldOperands + j,
+              mapper[i - 1].lookupOrDefault(yieldOp.getOperand(j)));
         }
-        std::optional<uint64_t> getConstantTripCount(scf::ForOp forOp) {
-            auto lowerBoundConst = forOp.getLowerBound().getDefiningOp<arith::ConstantIndexOp>();
-            auto upperBoundConst = forOp.getUpperBound().getDefiningOp<arith::ConstantIndexOp>();
-            auto stepConst = forOp.getStep().getDefiningOp<arith::ConstantIndexOp>();
+      }
+    }
+  }
 
-            if (!lowerBoundConst || !upperBoundConst || !stepConst) {
-                return std::nullopt;
-            }
+  // Gathers reduction operations in the loop
+  void getReductions(ForOp forOp, SmallVectorImpl<LoopReduction> &reductions) {
+    ValueRange iterArgs = forOp.getRegionIterArgs();
+    unsigned numIterArgs = iterArgs.size();
 
-            int64_t lowerBoundValue = lowerBoundConst.value();
-            int64_t upperBoundValue = upperBoundConst.value();
-            int64_t stepValue = stepConst.value();
-            return (upperBoundValue - lowerBoundValue) / stepValue;
-            }
-        static bool invariantBounds(scf::ForOp forOp) {
-            auto walkResult = forOp.walk([&](scf::ForOp innerForOp) {
-                if (!forOp.isDefinedOutsideOfLoop(innerForOp.getLowerBound()) ||
-                    !forOp.isDefinedOutsideOfLoop(innerForOp.getUpperBound()) ||
-                    !forOp.isDefinedOutsideOfLoop(innerForOp.getStep()))
-                return WalkResult::interrupt();
-            
-                return WalkResult::advance();
-            });
-            return !walkResult.wasInterrupted();
-        }
-        void loopNestMapper(ForOp op,std::map<ForOp, int> &outerLoopMap,
-        std::map<ForOp, int> &innerLoopMap) {
-            op.getBody()->walk([&](ForOp nestedForOp){
-                outerLoopMap[op]+=1;
-                innerLoopMap[nestedForOp]+=1;
-                });
-        }
-        void loopGatherer(Operation *op, std::map<ForOp, int> &outerLoopMap,
-        std::map<ForOp, int> &innerLoopMap) {
-            op->walk([&](ForOp forOp) {
-                loopNestMapper(forOp,outerLoopMap,innerLoopMap);
+    if (numIterArgs == 0) {
+      return;
+    }
+    reductions.reserve(numIterArgs);
+    for (unsigned i = 0; i < numIterArgs; ++i) {
+      arith::AtomicRMWKind kind;
+      if (Value value = getReduction(forOp, i, kind)) {
+        reductions.emplace_back(LoopReduction{kind, i, value});
+      }
+    }
+  }
+
+  // Matches the reduction operation within the loop
+  Value getReduction(ForOp forOp, unsigned pos, arith::AtomicRMWKind &kind) {
+    SmallVector<Operation *> combinerOps;
+    Value reducedVal =
+        matchReduction(forOp.getRegionIterArgs(), pos, combinerOps);
+    if (!reducedVal)
+      return nullptr;
+
+    if (combinerOps.size() > 1)
+      return nullptr;
+
+    Operation *combinerOp = combinerOps.back();
+    std::optional<arith::AtomicRMWKind> maybeKind =
+        mlir::TypeSwitch<Operation *, std::optional<arith::AtomicRMWKind>>(
+            combinerOp)
+            .Case([](arith::AddFOp) { return arith::AtomicRMWKind::addf; })
+            .Case([](arith::MulFOp) { return arith::AtomicRMWKind::mulf; })
+            .Case([](arith::AddIOp) { return arith::AtomicRMWKind::addi; })
+            .Case([](arith::AndIOp) { return arith::AtomicRMWKind::andi; })
+            .Case([](arith::OrIOp) { return arith::AtomicRMWKind::ori; })
+            .Case([](arith::MulIOp) { return arith::AtomicRMWKind::muli; })
+            .Default([](Operation *) -> std::optional<arith::AtomicRMWKind> {
+              return std::nullopt;
             });
-        }
-        void duplicateArgs(SmallVector<IRMapping> &mapper, SmallVector<ForOp> &newInnerLoops,
-         ForOp &forOp, IRRewriter &rewriter,
-          const SmallVector<ForOp> &innerLoops, unsigned unrollJamFactor) {
-            for (scf::ForOp currentForOp : innerLoops) {
-                SmallVector<Value> iterOperandsList, yeildOperandsList;
-                ValueRange previousIterOperands = currentForOp.getInits();
-                ValueRange previousIterArgs = currentForOp.getRegionIterArgs();
-                ValueRange previousYeildOperands = cast<YieldOp>(currentForOp.getBody()->getTerminator()).getOperands();
-                for (unsigned i = unrollJamFactor - 1; i >= 1; --i) {
-                    iterOperandsList.append(previousIterOperands.begin(), previousIterOperands.end());
-                    yeildOperandsList.append(previousYeildOperands.begin(), previousYeildOperands.end());
-                }
-                bool forOpReplaced = currentForOp == forOp;
-                scf::ForOp newForOp =
-                    cast<scf::ForOp>(*currentForOp.replaceWithAdditionalYields(
-                        rewriter, iterOperandsList, /*replaceInitOperandUsesInLoop=*/false,
-                        [&](OpBuilder &b, Location loc, ArrayRef<BlockArgument> newBbArgs) {
-                            return yeildOperandsList;
-                        }));
-                newInnerLoops.push_back(newForOp);
+    if (!maybeKind)
+      return nullptr;
 
-                if (forOpReplaced) forOp = newForOp;
-                ValueRange newIterArgs = newForOp.getRegionIterArgs();
-                unsigned oldNumIterArgs = previousIterArgs.size();
-                ValueRange newResults = newForOp.getResults();
-                unsigned oldNumResults = newResults.size() / unrollJamFactor;
-                for (unsigned i = unrollJamFactor - 1; i >= 1; --i) {
-                    for (unsigned j = 0; j < oldNumIterArgs; ++j) {
-                        mapper[i - 1].map(newIterArgs[j],
-                                        newIterArgs[i * oldNumIterArgs + j]);
-                        mapper[i - 1].map(newResults[j],
-                                        newResults[i * oldNumResults + j]);
-                    }
-                }
-            }
-        }
+    kind = *maybeKind;
+    return reducedVal;
+  }
 
-        Value createUpdatedInductionVar(unsigned i, Value iv, OpBuilder &builder,
-                                  Value step) {
-            Location loc = iv.getLoc();
-            Value constantI = builder.create<arith::ConstantIndexOp>(loc, i);
-            Value increment =builder.create<arith::MulIOp>(loc, constantI, step);
-            Value updatedIv = builder.create<arith::AddIOp>(loc, iv, increment);
-            return updatedIv;
-        }
+  // Emits ops after forOp that combine each reduction's unrollJamFactor results.
+  void updateReductions(ForOp forOp, IRRewriter &rewriter,
+                        SmallVector<LoopReduction> &reductions) {
+    rewriter.setInsertionPointAfter(forOp);
+    auto loc = forOp.getLoc();
+    unsigned oldNumResults = forOp.getNumResults() / unrollJamFactor;
+    for (LoopReduction &reduction : reductions) {
+      unsigned pos = reduction.iterArgPosition;
+      Value lhs = forOp.getResult(pos);
+      Value rhs;
+      SmallPtrSet<Operation *, 4> newOps;
+      for (unsigned i = unrollJamFactor - 1; i >= 1; --i) {
+        rhs = forOp.getResult(i * oldNumResults + pos);
+        lhs = arith::getReductionOp(reduction.kind, rewriter, loc, lhs, rhs);
+        if (!lhs)
+          return; // getReductionOp gave no value: stop, this result's uses stay
+        Operation *op = lhs.getDefiningOp();
+        newOps.insert(op);
+      }
+      forOp.getResult(pos).replaceAllUsesExcept(lhs, newOps); // newOps keep theirs
+    }
+  }
 
-        void updateStep(ForOp forOp,IRRewriter &rewriter)
-        {
-            if (Block *prevBlock = forOp->getBlock()->getPrevNode())
-                rewriter.setInsertionPointToEnd(prevBlock);
-            else
-                rewriter.setInsertionPoint(forOp);
-            auto newStep = rewriter.createOrFold<arith::MulIOp>(
-                forOp.getLoc(), forOp.getStep(),
-                rewriter.createOrFold<arith::ConstantOp>(
-                    forOp.getLoc(), rewriter.getIndexAttr(unrollJamFactor)));
-            forOp.setStep(newStep);
-        }
-        void finalUnrollJam(Value forOpInductionVar,
-         SmallVector<std::pair<Block::iterator, Block::iterator>> &subBlocks,
-         SmallVector<IRMapping> &mapper, const SmallVector<ForOp> &newInnerLoops, ForOp forOp) {
-            for (unsigned i = unrollJamFactor - 1; i >= 1; --i) {
-                for (auto &subBlock : subBlocks) {
-                    OpBuilder builder(subBlock.first->getBlock(), std::next(subBlock.second));
-                    if (!forOpInductionVar.use_empty()) {
-                        auto updatedInductionVar = createUpdatedInductionVar(i, forOpInductionVar, builder, forOp.getStep());
-                        mapper[i - 1].map(forOpInductionVar, updatedInductionVar);
-                    }
-                    for (auto it = subBlock.first; it != std::next(subBlock.second); ++it)
-                        builder.clone(*it, mapper[i - 1]);
-                }
-                for (auto newForOp : newInnerLoops) {
-                    unsigned oldNumIterOperands = newForOp.getNumRegionIterArgs() / unrollJamFactor;
-                    unsigned numControlOperands = newForOp.getNumControlOperands();
-                    auto yieldOp = cast<scf::YieldOp>(newForOp.getBody()->getTerminator());
-                    unsigned oldNumYieldOperands = yieldOp.getNumOperands() / unrollJamFactor;
-                    for (unsigned j = 0; j < oldNumIterOperands; ++j) {
-                        newForOp.setOperand(numControlOperands + i * oldNumIterOperands + j,
-                                            mapper[i - 1].lookupOrDefault(newForOp.getOperand(numControlOperands + j)));
-                        yieldOp.setOperand(i * oldNumYieldOperands + j,
-                                        mapper[i - 1].lookupOrDefault(yieldOp.getOperand(j)));
-                    }
-                }
-            }
-        }
-        void getReductions(ForOp forOp, SmallVectorImpl<LoopReduction> &reductions) {
-            ValueRange iterArgs = forOp.getRegionIterArgs();
-            unsigned numIterArgs = iterArgs.size();
+  // Unrolls and jams forOp by unrollJamFactor.
+  LogicalResult loopUnrollJamByFactor(ForOp forOp) {
+    if (unrollJamFactor == 1)
+      return success();
+    std::optional<uint64_t> tripCount = getTripCount(forOp);
+    if (unrollJamFactor > *tripCount)
+      unrollJamFactor = *tripCount;
+    else if (*tripCount % unrollJamFactor != 0)
+      return failure();
+    if (llvm::hasSingleElement(forOp.getBody()->getOperations()))
+      return success();
 
-            if (numIterArgs == 0)
-            {
-                llvm::errs()<<"Iter args == 0";
-                return;   
-            }
-            llvm::errs()<<"Iter args"<<numIterArgs;
-            reductions.reserve(numIterArgs);
-            for (unsigned i = 0; i < numIterArgs; ++i) {
-                arith::AtomicRMWKind kind;
-                if (Value value = getReduction(forOp, i, kind))
-                {
-                    llvm::errs()<<"Inside here \n";
-                    reductions.emplace_back(LoopReduction{kind, i, value});
-                    llvm::errs()<<"\nValue "<<value;
-                }
-                
-            }
-        }
-        Value getReduction(ForOp forOp, unsigned pos,
-                                    arith::AtomicRMWKind &kind) {
-            SmallVector<Operation *> combinerOps;
-            Value reducedVal = matchReduction(forOp.getRegionIterArgs(), pos, combinerOps);
-            if (!reducedVal)return nullptr;
-    
-            if (combinerOps.size() > 1)return nullptr;
-            
-            Operation *combinerOp = combinerOps.back();
-            std::optional<arith::AtomicRMWKind> maybeKind =
-                mlir::TypeSwitch<Operation *, std::optional<arith::AtomicRMWKind>>(combinerOp)
-                    .Case([](arith::AddFOp) { return arith::AtomicRMWKind::addf; })
-                    .Case([](arith::MulFOp) { return arith::AtomicRMWKind::mulf; })
-                    .Case([](arith::AddIOp) { return arith::AtomicRMWKind::addi; })
-                    .Case([](arith::AndIOp) { return arith::AtomicRMWKind::andi; })
-                    .Case([](arith::OrIOp) { return arith::AtomicRMWKind::ori; })
-                    .Case([](arith::MulIOp) { return arith::AtomicRMWKind::muli; })
-                    .Case(
-                        [](arith::MinimumFOp) { return arith::AtomicRMWKind::minimumf; })
-                    .Case(
-                        [](arith::MaximumFOp) { return arith::AtomicRMWKind::maximumf; })
-                    .Case([](arith::MinSIOp) { return arith::AtomicRMWKind::mins; })
-                    .Case([](arith::MaxSIOp) { return arith::AtomicRMWKind::maxs; })
-                    .Case([](arith::MinUIOp) { return arith::AtomicRMWKind::minu; })
-                    .Case([](arith::MaxUIOp) { return arith::AtomicRMWKind::maxu; })
-                    .Default([](Operation *) -> std::optional<arith::AtomicRMWKind> {
-                        return std::nullopt;
-                    });
-            if (!maybeKind)
-                return nullptr;
-            
-            kind = *maybeKind;
-            return reducedVal;
-        }
-        void updateReductions(ForOp forOp, IRRewriter &rewriter,
-        SmallVector<LoopReduction> &reductions)
-        {
-            rewriter.setInsertionPointAfter(forOp);
-            auto loc = forOp.getLoc();
-            unsigned oldNumResults = forOp.getNumResults() / unrollJamFactor;
-            for (LoopReduction &reduction : reductions) {
-                unsigned pos = reduction.iterArgPosition;
-                Value lhs = forOp.getResult(pos);
-                Value rhs;
-                SmallPtrSet<Operation *, 4> newOps;
-                for (unsigned i = unrollJamFactor - 1; i >= 1; --i) {
-                    rhs = forOp.getResult(i * oldNumResults + pos);
-                    lhs = arith::getReductionOp(reduction.kind, rewriter, loc, lhs, rhs);
-                    if (!lhs)
-                    return;
-                    Operation *op = lhs.getDefiningOp();
-                    assert(op && "Reduction op should have been created");
-                    newOps.insert(op);
-                }
-            forOp.getResult(pos).replaceAllUsesExcept(lhs, newOps);
-            }
-        }   
-        LogicalResult loopUnrollJamByFactor(ForOp forOp)
-        {
-            if(unrollJamFactor ==1) return success();
-            std::optional<uint64_t> tripCount = getConstantTripCount(forOp);
-            if (unrollJamFactor > *tripCount) unrollJamFactor = *tripCount;
-            else if (*tripCount % unrollJamFactor != 0) return failure();
-            if (llvm::hasSingleElement(forOp.getBody()->getOperations())) return success();
-            
-            auto bg = ForBlockGatherer();
-            bg.walk(forOp);
-            auto &subBlocks = bg.subBlocks;
-            
-            SmallVector<ForOp> innerLoops;
-            forOp.walk([&](ForOp innerForOp) { innerLoops.push_back(innerForOp); });
+    auto bg = ForBlockGatherer();
+    bg.walk(forOp);
+    auto &subBlocks = bg.subBlocks;
 
-            SmallVector<LoopReduction> reductions;
-            ValueRange iterArgs = forOp.getRegionIterArgs();
-            unsigned numIterOperands = iterArgs.size();
-            if (numIterOperands > 0)
-                getReductions(forOp, reductions);
+    SmallVector<ForOp> innerLoops;
+    forOp.walk([&](ForOp innerForOp) { innerLoops.push_back(innerForOp); });
 
-            SmallVector<IRMapping> mapper(unrollJamFactor - 1);
-            IRRewriter rewriter(forOp.getContext());
-            SmallVector<ForOp> newInnerLoops;
-            duplicateArgs(mapper, newInnerLoops, forOp, rewriter, innerLoops, unrollJamFactor);
-            updateStep(forOp,rewriter); 
-            auto forOpInductionVar = forOp.getInductionVar();
-            finalUnrollJam(forOpInductionVar,subBlocks,mapper,newInnerLoops,forOp);
-            if (forOp.getNumResults() > 0) {
-                updateReductions(forOp,rewriter,reductions);
-            }
-            (void)forOp.promoteIfSingleIteration(rewriter);
-            return success();
-        } 
+    SmallVector<LoopReduction> reductions;
+    ValueRange iterArgs = forOp.getRegionIterArgs();
+    unsigned numIterOperands = iterArgs.size();
+    if (numIterOperands > 0)
+      getReductions(forOp, reductions);
 
-    };
-} // namespace ends
+    SmallVector<IRMapping> mapper(unrollJamFactor - 1);
+    IRRewriter rewriter(forOp.getContext());
+    SmallVector<ForOp> newInnerLoops;
+    duplicateArgs(mapper, newInnerLoops, forOp, rewriter, innerLoops,
+                  unrollJamFactor);
+    updateStep(forOp, rewriter);
+    auto forOpInductionVar = forOp.getInductionVar();
+    finalUnrollJam(forOpInductionVar, subBlocks, mapper, newInnerLoops, forOp);
+    if (forOp.getNumResults() > 0) {
+      updateReductions(forOp, rewriter, reductions);
+    }
+    (void)forOp.promoteIfSingleIteration(rewriter);
+    return success();
+  }
+};
+} // namespace
 
-std::unique_ptr<Pass> mlir::createLoopUnrollJam(int unrollJamFactor){
-        return std::make_unique<LoopUnrollJam>(
-            unrollJamFactor == -1 ? std::nullopt 
-            : std::optional<unsigned>(unrollJamFactor));
-    }
\ No newline at end of file
+std::unique_ptr<Pass> mlir::createLoopUnrollJam(int unrollJamFactor) {
+  return std::make_unique<LoopUnrollJam>(
+      unrollJamFactor == -1 ? std::nullopt // -1 is forwarded as "no factor"
+                            : std::optional<unsigned>(unrollJamFactor));
+}
\ No newline at end of file
diff --git a/mlir/test/Dialect/SCF/loop-unroll-jam-scf.mlir b/mlir/test/Dialect/SCF/loop-unroll-jam-scf.mlir
index 2e4b331c59c03..4366026600811 100644
--- a/mlir/test/Dialect/SCF/loop-unroll-jam-scf.mlir
+++ b/mlir/test/Dialect/SCF/loop-unroll-jam-scf.mlir
@@ -4,14 +4,14 @@ module {
   func.func @main() -> f32 {
     %sum = arith.constant 0.0 : f32
     %val = arith.constant 2.0 : f32
-    %N = arith.constant 4 : index
-    %num = arith.constant 10 : index
+    %N = arith.constant 16 : index
+    %num = arith.constant 16 : index
     %c0 = arith.constant 0 : index
     %c1 = arith.constant 1 : index
 
     %result = scf.for %i = %c0 to %N step %c1 iter_args(%iter_sum = %sum) -> (f32) {
       %new_sum = arith.addf %iter_sum, %val : f32
-      %result2 = scf.for %j = %c0 to %num step %c1 iter_args(%iter_sum2 = %new_sum) -> (f32) {
+      %result2 = scf.for %j = %c0 to %num step %c1 iter_args(%iter_sum2 = %val) -> (f32) {
         %new_sum2 = arith.addf %iter_sum2, %val : f32
         scf.yield %new_sum2 : f32
       }
@@ -25,23 +25,24 @@ module {
 // CHECK-LABEL: func.func @main() -> f32 {
 // CHECK-NEXT:   %cst = arith.constant 0.000000e+00 : f32
 // CHECK-NEXT:   %cst_0 = arith.constant 2.000000e+00 : f32
-// CHECK-NEXT:   %c4 = arith.constant 4 : index
-// CHECK-NEXT:   %c10 = arith.constant 10 : index
+// CHECK-NEXT:   %c16 = arith.constant 16 : index
+// CHECK-NEXT:   %c16_1 = arith.constant 16 : index
 // CHECK-NEXT:   %c0 = arith.constant 0 : index
 // CHECK-NEXT:   %c1 = arith.constant 1 : index
 // CHECK-NEXT:   %c2 = arith.constant 2 : index
-// CHECK-NEXT:   %c2_1 = arith.constant 2 : index
-// CHECK-NEXT:   %0:2 = scf.for %arg0 = %c0 to %c4 step %c2_1 iter_args(%arg1 = %cst, %arg2 = %cst) -> (f32, f32) {
-// CHECK-NEXT:     %1 = arith.addf %arg1, %cst_0 : f32
-// CHECK-NEXT:     %2 = arith.addf %arg2, %cst_0 : f32
-// CHECK-NEXT:     %3:2 = scf.for %arg3 = %c0 to %c10 step %c1 iter_args(%arg4 = %1, %arg5 = %2) -> (f32, f32) {
-// CHECK-NEXT:       %6 = arith.addf %arg4, %cst_0 : f32
-// CHECK-NEXT:       %7 = arith.addf %arg5, %cst_0 : f32
-// CHECK-NEXT:       scf.yield %6, %7 : f32, f32
+// CHECK-NEXT:   %c2_2 = arith.constant 2 : index
+// CHECK-NEXT:   %0:2 = scf.for %arg0 = %c0 to %c16 step %c2_2 iter_args(%arg1 = %cst, %arg2 = %cst) -> (f32, f32) {
+// CHECK-NEXT:     %2 = arith.addf %arg1, %cst_0 : f32
+// CHECK-NEXT:     %3 = arith.addf %arg2, %cst_0 : f32
+// CHECK-NEXT:     %4:2 = scf.for %arg3 = %c0 to %c16_1 step %c1 iter_args(%arg4 = %cst_0, %arg5 = %cst_0) -> (f32, f32) {
+// CHECK-NEXT:       %7 = arith.addf %arg4, %cst_0 : f32
+// CHECK-NEXT:       %8 = arith.addf %arg5, %cst_0 : f32
+// CHECK-NEXT:       scf.yield %7, %8 : f32, f32
 // CHECK-NEXT:     }
-// CHECK-NEXT:     %4 = arith.addf %3#0, %cst_0 : f32
-// CHECK-NEXT:     %5 = arith.addf %3#1, %cst_0 : f32
-// CHECK-NEXT:     scf.yield %1, %2 : f32, f32
+// CHECK-NEXT:     %5 = arith.addf %4#0, %cst_0 : f32
+// CHECK-NEXT:     %6 = arith.addf %4#1, %cst_0 : f32
+// CHECK-NEXT:     scf.yield %2, %3 : f32, f32
 // CHECK-NEXT:   }
-// CHECK-NEXT:   return %0#0 : f32
+// CHECK-NEXT:   %1 = arith.addf %0#0, %0#1 : f32
+// CHECK-NEXT:   return %1 : f32
 // CHECK-NEXT: }



More information about the Mlir-commits mailing list