[Mlir-commits] [mlir] Loop unroll jam for SCF (PR #98316)
Anirudh Sathish
llvmlistbot at llvm.org
Wed Jul 10 06:14:06 PDT 2024
https://github.com/Anirudh-Sathish created https://github.com/llvm/llvm-project/pull/98316
### Code for loop unroll-and-jam of SCF loops
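For context, here is a minimal before/after sketch (illustrative only, not code from this series) of what unroll-and-jam by a factor of 2 does to an `scf.for` nest. It assumes the factor evenly divides the outer trip count, so no epilogue loop is needed; `%c0`, `%c1`, `%c2`, `%c8`, and `%n` are assumed index values, and `"test.use"` is a placeholder op.

```mlir
// Before: a plain loop nest.
scf.for %i = %c0 to %c8 step %c1 {
  scf.for %j = %c0 to %n step %c1 {
    "test.use"(%i, %j) : (index, index) -> ()
  }
}

// After unroll-and-jam by 2: the outer step doubles, and the two copies of
// the outer body are fused ("jammed") into a single inner loop, rather than
// ending up as two separate inner loops as plain unrolling would produce.
scf.for %i = %c0 to %c8 step %c2 {
  scf.for %j = %c0 to %n step %c1 {
    "test.use"(%i, %j) : (index, index) -> ()
    %i1 = arith.addi %i, %c1 : index
    "test.use"(%i1, %j) : (index, index) -> ()
  }
}
```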
From 5080a4efbedfe63d2fce780c1beba895e746bc9a Mon Sep 17 00:00:00 2001
From: Anirudh-Sathish <anirudhsathish at gmail.com>
Date: Wed, 19 Jun 2024 16:42:09 +0000
Subject: [PATCH 01/13] Feat: Add code for matrix multiplication in the Toy language
---
mlir/examples/toy/Ch7/include/toy/Ops.td | 18 ++++++
mlir/examples/toy/Ch7/mlir/Dialect.cpp | 44 ++++++++++++++
.../toy/Ch7/mlir/LowerToAffineLoops.cpp | 57 ++++++++++++++++++-
mlir/examples/toy/Ch7/mlir/MLIRGen.cpp | 10 ++++
4 files changed, 128 insertions(+), 1 deletion(-)
diff --git a/mlir/examples/toy/Ch7/include/toy/Ops.td b/mlir/examples/toy/Ch7/include/toy/Ops.td
index cfd6859eb27bf..f629efb057230 100644
--- a/mlir/examples/toy/Ch7/include/toy/Ops.td
+++ b/mlir/examples/toy/Ch7/include/toy/Ops.td
@@ -450,4 +450,22 @@ def TransposeOp : Toy_Op<"transpose",
let hasVerifier = 1;
}
+//===----------------------------------------------------------------------===//
+// MatmulOp
+//===----------------------------------------------------------------------===//
+
+def MatmulOp : Toy_Op<"matmul",
+ [Pure, DeclareOpInterfaceMethods<ShapeInferenceOpInterface>]>
+{
+ let summary = "matrix multiplication operation";
+ let arguments = (ins F64Tensor:$lhs, F64Tensor:$rhs);
+ let results = (outs F64Tensor:$results);
+ let builders = [
+ OpBuilder<(ins "Value":$lhs, "Value":$rhs)>
+ ];
+ let assemblyFormat = "attr-dict $lhs `,` $rhs `:` type($lhs) `,` type($rhs) `->` type($results)";
+ let hasVerifier = 1;
+}
+
+
#endif // TOY_OPS
diff --git a/mlir/examples/toy/Ch7/mlir/Dialect.cpp b/mlir/examples/toy/Ch7/mlir/Dialect.cpp
index b268b1ef157f9..2668993da789c 100644
--- a/mlir/examples/toy/Ch7/mlir/Dialect.cpp
+++ b/mlir/examples/toy/Ch7/mlir/Dialect.cpp
@@ -498,6 +498,50 @@ mlir::LogicalResult TransposeOp::verify() {
return mlir::success();
}
+
+//===----------------------------------------------------------------------===//
+// MatmulOp
+//===----------------------------------------------------------------------===//
+void MatmulOp::build(mlir::OpBuilder &builder, mlir::OperationState &state,
+ mlir::Value lhs, mlir::Value rhs) {
+ state.addTypes(UnrankedTensorType::get(builder.getF64Type()));
+  state.addOperands({lhs, rhs});
+}
+
+void MatmulOp::inferShapes() {
+ auto lhsType = llvm::dyn_cast<RankedTensorType>(getOperand(0).getType());
+ auto rhsType = llvm::dyn_cast<RankedTensorType>(getOperand(1).getType());
+
+ // Ensure that both operands are ranked tensor types
+ if (!lhsType || !rhsType) {
+ emitError("Both operands of MatmulOp must be ranked tensors.");
+ return;
+ }
+ // Get the shapes of the operands
+ ArrayRef<int64_t> lhsShape = lhsType.getShape();
+ ArrayRef<int64_t> rhsShape = rhsType.getShape();
+ // Check that the dimensions are compatible for matrix multiplication
+ if (lhsShape.size() != 2 || rhsShape.size() != 2 || lhsShape[1] != rhsShape[0]) {
+ emitError("MatmulOp requires lhs of shape (m, k) and rhs of shape (k, n).");
+ return;
+ }
+
+ // Infer the shape of the result
+ SmallVector<int64_t, 2> resultShape = {lhsShape[0], rhsShape[1]};
+ getResult().setType(RankedTensorType::get(resultShape, lhsType.getElementType()));
+
+}
+
+mlir::LogicalResult MatmulOp::verify() {
+  auto inputType = llvm::dyn_cast<RankedTensorType>(getOperand(0).getType());
+  auto resultType = llvm::dyn_cast<RankedTensorType>(getType());
+  // If either type is still unranked, defer checking to shape inference.
+  if (!inputType || !resultType)
+    return mlir::success();
+  return mlir::success();
+}
+
+
//===----------------------------------------------------------------------===//
// Toy Types
//===----------------------------------------------------------------------===//
diff --git a/mlir/examples/toy/Ch7/mlir/LowerToAffineLoops.cpp b/mlir/examples/toy/Ch7/mlir/LowerToAffineLoops.cpp
index bded61542188e..f2b55cde6926e 100644
--- a/mlir/examples/toy/Ch7/mlir/LowerToAffineLoops.cpp
+++ b/mlir/examples/toy/Ch7/mlir/LowerToAffineLoops.cpp
@@ -316,6 +316,61 @@ struct TransposeOpLowering : public ConversionPattern {
}
};
+//===----------------------------------------------------------------------===//
+// ToyToAffine RewritePatterns: Matmul operations
+//===----------------------------------------------------------------------===//
+
+struct MatmulOpLowering : public ConversionPattern {
+ MatmulOpLowering(MLIRContext *ctx)
+ : ConversionPattern(toy::MatmulOp::getOperationName(), 1, ctx) {}
+
+ LogicalResult
+ matchAndRewrite(Operation *op, ArrayRef<Value> operands,
+ ConversionPatternRewriter &rewriter) const final {
+ auto loc = op->getLoc();
+ toy::MatmulOpAdaptor matmulAdaptor(operands);
+ Value lhs = matmulAdaptor.getLhs();
+ Value rhs = matmulAdaptor.getRhs();
+ auto resultType = dyn_cast<MemRefType>(lhs.getType());
+ Value result = rewriter.create<memref::AllocOp>(loc, resultType);
+ auto lhsType = dyn_cast<MemRefType>(lhs.getType());
+ auto rhsType = dyn_cast<MemRefType>(rhs.getType());
+ int64_t M = lhsType.getShape()[0];
+ int64_t N = rhsType.getShape()[1];
+ int64_t K = lhsType.getShape()[1];
+ for (int64_t i = 0; i < M; ++i) {
+ for (int64_t j = 0; j < N; ++j) {
+ // Initialize the sum to zero.
+ Value sum = rewriter.create<arith::ConstantOp>(loc, rewriter.getF64Type(), rewriter.getF64FloatAttr(0.0));
+
+ for (int64_t k = 0; k < K; ++k) {
+ // Load lhs[i, k] and rhs[k, j].
+ Value lhsVal = rewriter.create<affine::AffineLoadOp>(loc, lhs, ValueRange{
+ rewriter.create<arith::ConstantIndexOp>(loc, i),
+ rewriter.create<arith::ConstantIndexOp>(loc, k)
+ });
+ Value rhsVal = rewriter.create<affine::AffineLoadOp>(loc, rhs, ValueRange{
+ rewriter.create<arith::ConstantIndexOp>(loc, k),
+ rewriter.create<arith::ConstantIndexOp>(loc, j)
+ });
+
+ // Perform the multiplication and accumulate the result.
+ Value product = rewriter.create<arith::MulFOp>(loc, lhsVal, rhsVal);
+ sum = rewriter.create<arith::AddFOp>(loc, sum, product);
+ }
+
+ // Store the computed value into the result matrix.
+ rewriter.create<affine::AffineStoreOp>(loc, sum, result, ValueRange{
+ rewriter.create<arith::ConstantIndexOp>(loc, i),
+ rewriter.create<arith::ConstantIndexOp>(loc, j)
+ });
+ }
+ }
+ rewriter.replaceOp(op, result);
+ return success();
+ }
+};
+
} // namespace
//===----------------------------------------------------------------------===//
@@ -366,7 +421,7 @@ void ToyToAffineLoweringPass::runOnOperation() {
// the set of patterns that will lower the Toy operations.
RewritePatternSet patterns(&getContext());
patterns.add<AddOpLowering, ConstantOpLowering, FuncOpLowering, MulOpLowering,
- PrintOpLowering, ReturnOpLowering, TransposeOpLowering>(
+ PrintOpLowering, ReturnOpLowering, TransposeOpLowering, MatmulOpLowering>(
&getContext());
// With the target and rewrite patterns defined, we can now attempt the
diff --git a/mlir/examples/toy/Ch7/mlir/MLIRGen.cpp b/mlir/examples/toy/Ch7/mlir/MLIRGen.cpp
index 0f8e8df38525f..4a27f22e5cd6c 100644
--- a/mlir/examples/toy/Ch7/mlir/MLIRGen.cpp
+++ b/mlir/examples/toy/Ch7/mlir/MLIRGen.cpp
@@ -526,6 +526,16 @@ class MLIRGenImpl {
return builder.create<TransposeOp>(location, operands[0]);
}
+    // Handle calls to the builtin matmul.
+    if (callee == "matmul") {
+      if (call.getArgs().size() != 2) {
+        emitError(location, "MLIR codegen encountered an error: toy.matmul "
+                            "only accepts two arguments");
+        return nullptr;
+      }
+      return builder.create<MatmulOp>(location, operands[0], operands[1]);
+    }
+
// Otherwise this is a call to a user-defined function. Calls to
// user-defined functions are mapped to a custom call that takes the callee
// name as an attribute.
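With this first patch applied, a Toy call `matmul(a, b)` is emitted as a `toy.matmul` op in the form declared by the `assemblyFormat` above; the result type starts out as the unranked `tensor<*xf64>` produced by the builder and is later refined by `inferShapes`. The matmul1.mlir test added later in this series shows the exact round-trip form:

```mlir
%4 = toy.matmul %1, %3 : tensor<2x3xf64>, tensor<3x2xf64> -> tensor<*xf64>
```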
From 50476952b1d2bb906d010f714a5fbe7fb407c3f6 Mon Sep 17 00:00:00 2001
From: Anirudh-Sathish <anirudhsathish at gmail.com>
Date: Thu, 20 Jun 2024 05:20:49 +0000
Subject: [PATCH 02/13] Fix: Ensure the matmul result buffer has the correct shape
---
mlir/examples/toy/Ch7/include/toy/Ops.td | 18 ++++++
mlir/examples/toy/Ch7/mlir/Dialect.cpp | 44 +++++++++++++
.../toy/Ch7/mlir/LowerToAffineLoops.cpp | 64 ++++++++++++++++++-
mlir/examples/toy/Ch7/mlir/MLIRGen.cpp | 10 +++
4 files changed, 134 insertions(+), 2 deletions(-)
diff --git a/mlir/examples/toy/Ch7/include/toy/Ops.td b/mlir/examples/toy/Ch7/include/toy/Ops.td
index cfd6859eb27bf..f629efb057230 100644
--- a/mlir/examples/toy/Ch7/include/toy/Ops.td
+++ b/mlir/examples/toy/Ch7/include/toy/Ops.td
@@ -450,4 +450,22 @@ def TransposeOp : Toy_Op<"transpose",
let hasVerifier = 1;
}
+//===----------------------------------------------------------------------===//
+// MatmulOp
+//===----------------------------------------------------------------------===//
+
+def MatmulOp : Toy_Op<"matmul",
+ [Pure, DeclareOpInterfaceMethods<ShapeInferenceOpInterface>]>
+{
+ let summary = "matrix multiplication operation";
+ let arguments = (ins F64Tensor:$lhs, F64Tensor:$rhs);
+ let results = (outs F64Tensor:$results);
+ let builders = [
+ OpBuilder<(ins "Value":$lhs, "Value":$rhs)>
+ ];
+ let assemblyFormat = "attr-dict $lhs `,` $rhs `:` type($lhs) `,` type($rhs) `->` type($results)";
+ let hasVerifier = 1;
+}
+
+
#endif // TOY_OPS
diff --git a/mlir/examples/toy/Ch7/mlir/Dialect.cpp b/mlir/examples/toy/Ch7/mlir/Dialect.cpp
index b268b1ef157f9..2668993da789c 100644
--- a/mlir/examples/toy/Ch7/mlir/Dialect.cpp
+++ b/mlir/examples/toy/Ch7/mlir/Dialect.cpp
@@ -498,6 +498,50 @@ mlir::LogicalResult TransposeOp::verify() {
return mlir::success();
}
+
+//===----------------------------------------------------------------------===//
+// MatmulOp
+//===----------------------------------------------------------------------===//
+void MatmulOp::build(mlir::OpBuilder &builder, mlir::OperationState &state,
+ mlir::Value lhs, mlir::Value rhs) {
+ state.addTypes(UnrankedTensorType::get(builder.getF64Type()));
+  state.addOperands({lhs, rhs});
+}
+
+void MatmulOp::inferShapes() {
+ auto lhsType = llvm::dyn_cast<RankedTensorType>(getOperand(0).getType());
+ auto rhsType = llvm::dyn_cast<RankedTensorType>(getOperand(1).getType());
+
+ // Ensure that both operands are ranked tensor types
+ if (!lhsType || !rhsType) {
+ emitError("Both operands of MatmulOp must be ranked tensors.");
+ return;
+ }
+ // Get the shapes of the operands
+ ArrayRef<int64_t> lhsShape = lhsType.getShape();
+ ArrayRef<int64_t> rhsShape = rhsType.getShape();
+ // Check that the dimensions are compatible for matrix multiplication
+ if (lhsShape.size() != 2 || rhsShape.size() != 2 || lhsShape[1] != rhsShape[0]) {
+ emitError("MatmulOp requires lhs of shape (m, k) and rhs of shape (k, n).");
+ return;
+ }
+
+ // Infer the shape of the result
+ SmallVector<int64_t, 2> resultShape = {lhsShape[0], rhsShape[1]};
+ getResult().setType(RankedTensorType::get(resultShape, lhsType.getElementType()));
+
+}
+
+mlir::LogicalResult MatmulOp::verify() {
+  auto inputType = llvm::dyn_cast<RankedTensorType>(getOperand(0).getType());
+  auto resultType = llvm::dyn_cast<RankedTensorType>(getType());
+  // If either type is still unranked, defer checking to shape inference.
+  if (!inputType || !resultType)
+    return mlir::success();
+  return mlir::success();
+}
+
+
//===----------------------------------------------------------------------===//
// Toy Types
//===----------------------------------------------------------------------===//
diff --git a/mlir/examples/toy/Ch7/mlir/LowerToAffineLoops.cpp b/mlir/examples/toy/Ch7/mlir/LowerToAffineLoops.cpp
index bded61542188e..5c3e6a552855b 100644
--- a/mlir/examples/toy/Ch7/mlir/LowerToAffineLoops.cpp
+++ b/mlir/examples/toy/Ch7/mlir/LowerToAffineLoops.cpp
@@ -25,7 +25,7 @@
#include "mlir/Support/TypeID.h"
#include "toy/Dialect.h"
#include "toy/Passes.h"
-
+#include "llvm/Support/raw_ostream.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
@@ -316,6 +316,66 @@ struct TransposeOpLowering : public ConversionPattern {
}
};
+//===----------------------------------------------------------------------===//
+// ToyToAffine RewritePatterns: Matmul operations
+//===----------------------------------------------------------------------===//
+
+struct MatmulOpLowering : public ConversionPattern {
+ MatmulOpLowering(MLIRContext *ctx)
+ : ConversionPattern(toy::MatmulOp::getOperationName(), 1, ctx) {}
+
+ LogicalResult
+ matchAndRewrite(Operation *op, ArrayRef<Value> operands,
+ ConversionPatternRewriter &rewriter) const final {
+ auto loc = op->getLoc();
+ toy::MatmulOpAdaptor matmulAdaptor(operands);
+ Value lhs = matmulAdaptor.getLhs();
+ Value rhs = matmulAdaptor.getRhs();
+
+ auto lhsType = dyn_cast<MemRefType>(lhs.getType());
+ auto rhsType = dyn_cast<MemRefType>(rhs.getType());
+ if (!lhsType || !rhsType) {
+ return failure();
+ }
+ int64_t M = lhsType.getShape()[0];
+ int64_t N = rhsType.getShape()[1];
+ int64_t K = lhsType.getShape()[1];
+ auto elementType = lhsType.getElementType();
+ auto resultType = MemRefType::get({M, N}, elementType);
+ Value result = rewriter.create<memref::AllocOp>(loc, resultType);
+ for (int64_t i = 0; i < M; ++i) {
+ for (int64_t j = 0; j < N; ++j) {
+ // Initialize the sum to zero.
+ Value sum = rewriter.create<arith::ConstantOp>(loc, rewriter.getF64Type(), rewriter.getF64FloatAttr(0.0));
+
+ for (int64_t k = 0; k < K; ++k) {
+ // Load lhs[i, k] and rhs[k, j].
+ Value lhsVal = rewriter.create<affine::AffineLoadOp>(loc, lhs, ValueRange{
+ rewriter.create<arith::ConstantIndexOp>(loc, i),
+ rewriter.create<arith::ConstantIndexOp>(loc, k)
+ });
+ Value rhsVal = rewriter.create<affine::AffineLoadOp>(loc, rhs, ValueRange{
+ rewriter.create<arith::ConstantIndexOp>(loc, k),
+ rewriter.create<arith::ConstantIndexOp>(loc, j)
+ });
+
+ // Perform the multiplication and accumulate the result.
+ Value product = rewriter.create<arith::MulFOp>(loc, lhsVal, rhsVal);
+ sum = rewriter.create<arith::AddFOp>(loc, sum, product);
+ }
+
+ // Store the computed value into the result matrix.
+ rewriter.create<affine::AffineStoreOp>(loc, sum, result, ValueRange{
+ rewriter.create<arith::ConstantIndexOp>(loc, i),
+ rewriter.create<arith::ConstantIndexOp>(loc, j)
+ });
+ }
+ }
+ rewriter.replaceOp(op, result);
+ return success();
+ }
+};
+
} // namespace
//===----------------------------------------------------------------------===//
@@ -366,7 +426,7 @@ void ToyToAffineLoweringPass::runOnOperation() {
// the set of patterns that will lower the Toy operations.
RewritePatternSet patterns(&getContext());
patterns.add<AddOpLowering, ConstantOpLowering, FuncOpLowering, MulOpLowering,
- PrintOpLowering, ReturnOpLowering, TransposeOpLowering>(
+ PrintOpLowering, ReturnOpLowering, TransposeOpLowering, MatmulOpLowering>(
&getContext());
// With the target and rewrite patterns defined, we can now attempt the
diff --git a/mlir/examples/toy/Ch7/mlir/MLIRGen.cpp b/mlir/examples/toy/Ch7/mlir/MLIRGen.cpp
index 0f8e8df38525f..4a27f22e5cd6c 100644
--- a/mlir/examples/toy/Ch7/mlir/MLIRGen.cpp
+++ b/mlir/examples/toy/Ch7/mlir/MLIRGen.cpp
@@ -526,6 +526,16 @@ class MLIRGenImpl {
return builder.create<TransposeOp>(location, operands[0]);
}
+    // Handle calls to the builtin matmul.
+    if (callee == "matmul") {
+      if (call.getArgs().size() != 2) {
+        emitError(location, "MLIR codegen encountered an error: toy.matmul "
+                            "only accepts two arguments");
+        return nullptr;
+      }
+      return builder.create<MatmulOp>(location, operands[0], operands[1]);
+    }
+
// Otherwise this is a call to a user-defined function. Calls to
// user-defined functions are mapped to a custom call that takes the callee
// name as an attribute.
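The effect of this fix shows up in the result allocation: the first patch derived the result buffer type from the lhs memref, so a 2x3 by 3x2 product would have been written into a `memref<2x3xf64>`. With `MemRefType::get({M, N}, elementType)` the lowering now allocates the true MxN buffer, as seen in the matmul1-affine.mlir test added in the next patch:

```mlir
// Result buffer for multiplying memref<2x3xf64> by memref<3x2xf64>.
%alloc_7 = memref.alloc() : memref<2x2xf64>
```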
From d844cbe8753cac33fa19bc1b1f16b044154dc09e Mon Sep 17 00:00:00 2001
From: Anirudh-Sathish <anirudhsathish at gmail.com>
Date: Thu, 20 Jun 2024 05:28:43 +0000
Subject: [PATCH 03/13] Test: Add test cases for the matmul operation
---
.../Examples/Toy/Matmul/matmul1-affine.mlir | 82 +++
mlir/test/Examples/Toy/Matmul/matmul1.llvm | 239 ++++++++
mlir/test/Examples/Toy/Matmul/matmul1.mlir | 11 +
mlir/test/Examples/Toy/Matmul/matmul1.toy | 6 +
.../Examples/Toy/Matmul/matmul2-affine.mlir | 210 +++++++
mlir/test/Examples/Toy/Matmul/matmul2.llvm | 527 ++++++++++++++++++
mlir/test/Examples/Toy/Matmul/matmul2.mlir | 11 +
mlir/test/Examples/Toy/Matmul/matmul2.toy | 6 +
.../Examples/Toy/Matmul/matmul3-affine.mlir | 55 ++
mlir/test/Examples/Toy/Matmul/matmul3.llvm | 181 ++++++
mlir/test/Examples/Toy/Matmul/matmul3.mlir | 11 +
mlir/test/Examples/Toy/Matmul/matmul3.toy | 7 +
12 files changed, 1346 insertions(+)
create mode 100644 mlir/test/Examples/Toy/Matmul/matmul1-affine.mlir
create mode 100644 mlir/test/Examples/Toy/Matmul/matmul1.llvm
create mode 100644 mlir/test/Examples/Toy/Matmul/matmul1.mlir
create mode 100644 mlir/test/Examples/Toy/Matmul/matmul1.toy
create mode 100644 mlir/test/Examples/Toy/Matmul/matmul2-affine.mlir
create mode 100644 mlir/test/Examples/Toy/Matmul/matmul2.llvm
create mode 100644 mlir/test/Examples/Toy/Matmul/matmul2.mlir
create mode 100644 mlir/test/Examples/Toy/Matmul/matmul2.toy
create mode 100644 mlir/test/Examples/Toy/Matmul/matmul3-affine.mlir
create mode 100644 mlir/test/Examples/Toy/Matmul/matmul3.llvm
create mode 100644 mlir/test/Examples/Toy/Matmul/matmul3.mlir
create mode 100644 mlir/test/Examples/Toy/Matmul/matmul3.toy
diff --git a/mlir/test/Examples/Toy/Matmul/matmul1-affine.mlir b/mlir/test/Examples/Toy/Matmul/matmul1-affine.mlir
new file mode 100644
index 0000000000000..f1734ab60e3d4
--- /dev/null
+++ b/mlir/test/Examples/Toy/Matmul/matmul1-affine.mlir
@@ -0,0 +1,82 @@
+module {
+ func.func @main() {
+ %cst = arith.constant 0.000000e+00 : f64
+ %cst_0 = arith.constant 6.000000e+00 : f64
+ %cst_1 = arith.constant 5.000000e+00 : f64
+ %cst_2 = arith.constant 4.000000e+00 : f64
+ %cst_3 = arith.constant 3.000000e+00 : f64
+ %cst_4 = arith.constant 2.000000e+00 : f64
+ %cst_5 = arith.constant 1.000000e+00 : f64
+ %alloc = memref.alloc() : memref<3x2xf64>
+ %alloc_6 = memref.alloc() : memref<2x3xf64>
+ affine.store %cst_5, %alloc_6[0, 0] : memref<2x3xf64>
+ affine.store %cst_4, %alloc_6[0, 1] : memref<2x3xf64>
+ affine.store %cst_3, %alloc_6[0, 2] : memref<2x3xf64>
+ affine.store %cst_2, %alloc_6[1, 0] : memref<2x3xf64>
+ affine.store %cst_1, %alloc_6[1, 1] : memref<2x3xf64>
+ affine.store %cst_0, %alloc_6[1, 2] : memref<2x3xf64>
+ affine.store %cst_5, %alloc[0, 0] : memref<3x2xf64>
+ affine.store %cst_4, %alloc[0, 1] : memref<3x2xf64>
+ affine.store %cst_3, %alloc[1, 0] : memref<3x2xf64>
+ affine.store %cst_2, %alloc[1, 1] : memref<3x2xf64>
+ affine.store %cst_1, %alloc[2, 0] : memref<3x2xf64>
+ affine.store %cst_0, %alloc[2, 1] : memref<3x2xf64>
+ %alloc_7 = memref.alloc() : memref<2x2xf64>
+ %0 = affine.load %alloc_6[0, 0] : memref<2x3xf64>
+ %1 = affine.load %alloc[0, 0] : memref<3x2xf64>
+ %2 = arith.mulf %0, %1 : f64
+ %3 = arith.addf %2, %cst : f64
+ %4 = affine.load %alloc_6[0, 1] : memref<2x3xf64>
+ %5 = affine.load %alloc[1, 0] : memref<3x2xf64>
+ %6 = arith.mulf %4, %5 : f64
+ %7 = arith.addf %3, %6 : f64
+ %8 = affine.load %alloc_6[0, 2] : memref<2x3xf64>
+ %9 = affine.load %alloc[2, 0] : memref<3x2xf64>
+ %10 = arith.mulf %8, %9 : f64
+ %11 = arith.addf %7, %10 : f64
+ affine.store %11, %alloc_7[0, 0] : memref<2x2xf64>
+ %12 = affine.load %alloc_6[0, 0] : memref<2x3xf64>
+ %13 = affine.load %alloc[0, 1] : memref<3x2xf64>
+ %14 = arith.mulf %12, %13 : f64
+ %15 = arith.addf %14, %cst : f64
+ %16 = affine.load %alloc_6[0, 1] : memref<2x3xf64>
+ %17 = affine.load %alloc[1, 1] : memref<3x2xf64>
+ %18 = arith.mulf %16, %17 : f64
+ %19 = arith.addf %15, %18 : f64
+ %20 = affine.load %alloc_6[0, 2] : memref<2x3xf64>
+ %21 = affine.load %alloc[2, 1] : memref<3x2xf64>
+ %22 = arith.mulf %20, %21 : f64
+ %23 = arith.addf %19, %22 : f64
+ affine.store %23, %alloc_7[0, 1] : memref<2x2xf64>
+ %24 = affine.load %alloc_6[1, 0] : memref<2x3xf64>
+ %25 = affine.load %alloc[0, 0] : memref<3x2xf64>
+ %26 = arith.mulf %24, %25 : f64
+ %27 = arith.addf %26, %cst : f64
+ %28 = affine.load %alloc_6[1, 1] : memref<2x3xf64>
+ %29 = affine.load %alloc[1, 0] : memref<3x2xf64>
+ %30 = arith.mulf %28, %29 : f64
+ %31 = arith.addf %27, %30 : f64
+ %32 = affine.load %alloc_6[1, 2] : memref<2x3xf64>
+ %33 = affine.load %alloc[2, 0] : memref<3x2xf64>
+ %34 = arith.mulf %32, %33 : f64
+ %35 = arith.addf %31, %34 : f64
+ affine.store %35, %alloc_7[1, 0] : memref<2x2xf64>
+ %36 = affine.load %alloc_6[1, 0] : memref<2x3xf64>
+ %37 = affine.load %alloc[0, 1] : memref<3x2xf64>
+ %38 = arith.mulf %36, %37 : f64
+ %39 = arith.addf %38, %cst : f64
+ %40 = affine.load %alloc_6[1, 1] : memref<2x3xf64>
+ %41 = affine.load %alloc[1, 1] : memref<3x2xf64>
+ %42 = arith.mulf %40, %41 : f64
+ %43 = arith.addf %39, %42 : f64
+ %44 = affine.load %alloc_6[1, 2] : memref<2x3xf64>
+ %45 = affine.load %alloc[2, 1] : memref<3x2xf64>
+ %46 = arith.mulf %44, %45 : f64
+ %47 = arith.addf %43, %46 : f64
+ affine.store %47, %alloc_7[1, 1] : memref<2x2xf64>
+ toy.print %alloc_7 : memref<2x2xf64>
+ memref.dealloc %alloc_6 : memref<2x3xf64>
+ memref.dealloc %alloc : memref<3x2xf64>
+ return
+ }
+}
diff --git a/mlir/test/Examples/Toy/Matmul/matmul1.llvm b/mlir/test/Examples/Toy/Matmul/matmul1.llvm
new file mode 100644
index 0000000000000..65f4e07d37f0f
--- /dev/null
+++ b/mlir/test/Examples/Toy/Matmul/matmul1.llvm
@@ -0,0 +1,239 @@
+; ModuleID = 'LLVMDialectModule'
+source_filename = "LLVMDialectModule"
+target datalayout = "e-m:e-p270:32:32-p271:32:32-p272:64:64-i64:64-i128:128-f80:128-n8:16:32:64-S128"
+target triple = "x86_64-unknown-linux-gnu"
+
+@nl = internal constant [2 x i8] c"\0A\00"
+@frmt_spec = internal constant [4 x i8] c"%f \00"
+
+declare !dbg !3 void @free(ptr)
+
+declare !dbg !6 i32 @printf(ptr, ...)
+
+declare !dbg !7 ptr @malloc(i64)
+
+define void @main() !dbg !8 {
+ %1 = call ptr @malloc(i64 ptrtoint (ptr getelementptr (double, ptr null, i64 6) to i64)), !dbg !9
+ %2 = insertvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } undef, ptr %1, 0, !dbg !9
+ %3 = insertvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %2, ptr %1, 1, !dbg !9
+ %4 = insertvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %3, i64 0, 2, !dbg !9
+ %5 = insertvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %4, i64 3, 3, 0, !dbg !9
+ %6 = insertvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %5, i64 2, 3, 1, !dbg !9
+ %7 = insertvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %6, i64 2, 4, 0, !dbg !9
+ %8 = insertvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %7, i64 1, 4, 1, !dbg !9
+ %9 = call ptr @malloc(i64 ptrtoint (ptr getelementptr (double, ptr null, i64 6) to i64)), !dbg !10
+ %10 = insertvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } undef, ptr %9, 0, !dbg !10
+ %11 = insertvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %10, ptr %9, 1, !dbg !10
+ %12 = insertvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %11, i64 0, 2, !dbg !10
+ %13 = insertvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %12, i64 2, 3, 0, !dbg !10
+ %14 = insertvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %13, i64 3, 3, 1, !dbg !10
+ %15 = insertvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %14, i64 3, 4, 0, !dbg !10
+ %16 = insertvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %15, i64 1, 4, 1, !dbg !10
+ %17 = extractvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %16, 1, !dbg !10
+ %18 = getelementptr double, ptr %17, i64 0, !dbg !10
+ store double 1.000000e+00, ptr %18, align 8, !dbg !10
+ %19 = extractvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %16, 1, !dbg !10
+ %20 = getelementptr double, ptr %19, i64 1, !dbg !10
+ store double 2.000000e+00, ptr %20, align 8, !dbg !10
+ %21 = extractvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %16, 1, !dbg !10
+ %22 = getelementptr double, ptr %21, i64 2, !dbg !10
+ store double 3.000000e+00, ptr %22, align 8, !dbg !10
+ %23 = extractvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %16, 1, !dbg !10
+ %24 = getelementptr double, ptr %23, i64 3, !dbg !10
+ store double 4.000000e+00, ptr %24, align 8, !dbg !10
+ %25 = extractvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %16, 1, !dbg !10
+ %26 = getelementptr double, ptr %25, i64 4, !dbg !10
+ store double 5.000000e+00, ptr %26, align 8, !dbg !10
+ %27 = extractvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %16, 1, !dbg !10
+ %28 = getelementptr double, ptr %27, i64 5, !dbg !10
+ store double 6.000000e+00, ptr %28, align 8, !dbg !10
+ %29 = extractvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %8, 1, !dbg !9
+ %30 = getelementptr double, ptr %29, i64 0, !dbg !9
+ store double 1.000000e+00, ptr %30, align 8, !dbg !9
+ %31 = extractvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %8, 1, !dbg !9
+ %32 = getelementptr double, ptr %31, i64 1, !dbg !9
+ store double 2.000000e+00, ptr %32, align 8, !dbg !9
+ %33 = extractvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %8, 1, !dbg !9
+ %34 = getelementptr double, ptr %33, i64 2, !dbg !9
+ store double 3.000000e+00, ptr %34, align 8, !dbg !9
+ %35 = extractvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %8, 1, !dbg !9
+ %36 = getelementptr double, ptr %35, i64 3, !dbg !9
+ store double 4.000000e+00, ptr %36, align 8, !dbg !9
+ %37 = extractvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %8, 1, !dbg !9
+ %38 = getelementptr double, ptr %37, i64 4, !dbg !9
+ store double 5.000000e+00, ptr %38, align 8, !dbg !9
+ %39 = extractvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %8, 1, !dbg !9
+ %40 = getelementptr double, ptr %39, i64 5, !dbg !9
+ store double 6.000000e+00, ptr %40, align 8, !dbg !9
+ %41 = call ptr @malloc(i64 ptrtoint (ptr getelementptr (double, ptr null, i64 4) to i64)), !dbg !11
+ %42 = insertvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } undef, ptr %41, 0, !dbg !11
+ %43 = insertvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %42, ptr %41, 1, !dbg !11
+ %44 = insertvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %43, i64 0, 2, !dbg !11
+ %45 = insertvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %44, i64 2, 3, 0, !dbg !11
+ %46 = insertvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %45, i64 2, 3, 1, !dbg !11
+ %47 = insertvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %46, i64 2, 4, 0, !dbg !11
+ %48 = insertvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %47, i64 1, 4, 1, !dbg !11
+ %49 = extractvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %16, 1, !dbg !11
+ %50 = getelementptr double, ptr %49, i64 0, !dbg !11
+ %51 = load double, ptr %50, align 8, !dbg !11
+ %52 = extractvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %8, 1, !dbg !11
+ %53 = getelementptr double, ptr %52, i64 0, !dbg !11
+ %54 = load double, ptr %53, align 8, !dbg !11
+ %55 = fmul double %51, %54, !dbg !11
+ %56 = fadd double %55, 0.000000e+00, !dbg !11
+ %57 = extractvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %16, 1, !dbg !11
+ %58 = getelementptr double, ptr %57, i64 1, !dbg !11
+ %59 = load double, ptr %58, align 8, !dbg !11
+ %60 = extractvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %8, 1, !dbg !11
+ %61 = getelementptr double, ptr %60, i64 2, !dbg !11
+ %62 = load double, ptr %61, align 8, !dbg !11
+ %63 = fmul double %59, %62, !dbg !11
+ %64 = fadd double %56, %63, !dbg !11
+ %65 = extractvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %16, 1, !dbg !11
+ %66 = getelementptr double, ptr %65, i64 2, !dbg !11
+ %67 = load double, ptr %66, align 8, !dbg !11
+ %68 = extractvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %8, 1, !dbg !11
+ %69 = getelementptr double, ptr %68, i64 4, !dbg !11
+ %70 = load double, ptr %69, align 8, !dbg !11
+ %71 = fmul double %67, %70, !dbg !11
+ %72 = fadd double %64, %71, !dbg !11
+ %73 = extractvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %48, 1, !dbg !11
+ %74 = getelementptr double, ptr %73, i64 0, !dbg !11
+ store double %72, ptr %74, align 8, !dbg !11
+ %75 = extractvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %16, 1, !dbg !11
+ %76 = getelementptr double, ptr %75, i64 0, !dbg !11
+ %77 = load double, ptr %76, align 8, !dbg !11
+ %78 = extractvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %8, 1, !dbg !11
+ %79 = getelementptr double, ptr %78, i64 1, !dbg !11
+ %80 = load double, ptr %79, align 8, !dbg !11
+ %81 = fmul double %77, %80, !dbg !11
+ %82 = fadd double %81, 0.000000e+00, !dbg !11
+ %83 = extractvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %16, 1, !dbg !11
+ %84 = getelementptr double, ptr %83, i64 1, !dbg !11
+ %85 = load double, ptr %84, align 8, !dbg !11
+ %86 = extractvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %8, 1, !dbg !11
+ %87 = getelementptr double, ptr %86, i64 3, !dbg !11
+ %88 = load double, ptr %87, align 8, !dbg !11
+ %89 = fmul double %85, %88, !dbg !11
+ %90 = fadd double %82, %89, !dbg !11
+ %91 = extractvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %16, 1, !dbg !11
+ %92 = getelementptr double, ptr %91, i64 2, !dbg !11
+ %93 = load double, ptr %92, align 8, !dbg !11
+ %94 = extractvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %8, 1, !dbg !11
+ %95 = getelementptr double, ptr %94, i64 5, !dbg !11
+ %96 = load double, ptr %95, align 8, !dbg !11
+ %97 = fmul double %93, %96, !dbg !11
+ %98 = fadd double %90, %97, !dbg !11
+ %99 = extractvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %48, 1, !dbg !11
+ %100 = getelementptr double, ptr %99, i64 1, !dbg !11
+ store double %98, ptr %100, align 8, !dbg !11
+ %101 = extractvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %16, 1, !dbg !11
+ %102 = getelementptr double, ptr %101, i64 3, !dbg !11
+ %103 = load double, ptr %102, align 8, !dbg !11
+ %104 = extractvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %8, 1, !dbg !11
+ %105 = getelementptr double, ptr %104, i64 0, !dbg !11
+ %106 = load double, ptr %105, align 8, !dbg !11
+ %107 = fmul double %103, %106, !dbg !11
+ %108 = fadd double %107, 0.000000e+00, !dbg !11
+ %109 = extractvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %16, 1, !dbg !11
+ %110 = getelementptr double, ptr %109, i64 4, !dbg !11
+ %111 = load double, ptr %110, align 8, !dbg !11
+ %112 = extractvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %8, 1, !dbg !11
+ %113 = getelementptr double, ptr %112, i64 2, !dbg !11
+ %114 = load double, ptr %113, align 8, !dbg !11
+ %115 = fmul double %111, %114, !dbg !11
+ %116 = fadd double %108, %115, !dbg !11
+ %117 = extractvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %16, 1, !dbg !11
+ %118 = getelementptr double, ptr %117, i64 5, !dbg !11
+ %119 = load double, ptr %118, align 8, !dbg !11
+ %120 = extractvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %8, 1, !dbg !11
+ %121 = getelementptr double, ptr %120, i64 4, !dbg !11
+ %122 = load double, ptr %121, align 8, !dbg !11
+ %123 = fmul double %119, %122, !dbg !11
+ %124 = fadd double %116, %123, !dbg !11
+ %125 = extractvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %48, 1, !dbg !11
+ %126 = getelementptr double, ptr %125, i64 2, !dbg !11
+ store double %124, ptr %126, align 8, !dbg !11
+ %127 = extractvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %16, 1, !dbg !11
+ %128 = getelementptr double, ptr %127, i64 3, !dbg !11
+ %129 = load double, ptr %128, align 8, !dbg !11
+ %130 = extractvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %8, 1, !dbg !11
+ %131 = getelementptr double, ptr %130, i64 1, !dbg !11
+ %132 = load double, ptr %131, align 8, !dbg !11
+ %133 = fmul double %129, %132, !dbg !11
+ %134 = fadd double %133, 0.000000e+00, !dbg !11
+ %135 = extractvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %16, 1, !dbg !11
+ %136 = getelementptr double, ptr %135, i64 4, !dbg !11
+ %137 = load double, ptr %136, align 8, !dbg !11
+ %138 = extractvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %8, 1, !dbg !11
+ %139 = getelementptr double, ptr %138, i64 3, !dbg !11
+ %140 = load double, ptr %139, align 8, !dbg !11
+ %141 = fmul double %137, %140, !dbg !11
+ %142 = fadd double %134, %141, !dbg !11
+ %143 = extractvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %16, 1, !dbg !11
+ %144 = getelementptr double, ptr %143, i64 5, !dbg !11
+ %145 = load double, ptr %144, align 8, !dbg !11
+ %146 = extractvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %8, 1, !dbg !11
+ %147 = getelementptr double, ptr %146, i64 5, !dbg !11
+ %148 = load double, ptr %147, align 8, !dbg !11
+ %149 = fmul double %145, %148, !dbg !11
+ %150 = fadd double %142, %149, !dbg !11
+ %151 = extractvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %48, 1, !dbg !11
+ %152 = getelementptr double, ptr %151, i64 3, !dbg !11
+ store double %150, ptr %152, align 8, !dbg !11
+ br label %153, !dbg !12
+
+153: ; preds = %168, %0
+ %154 = phi i64 [ 0, %0 ], [ %170, %168 ]
+ %155 = icmp slt i64 %154, 2, !dbg !12
+ br i1 %155, label %156, label %171, !dbg !12
+
+156: ; preds = %153
+ br label %157, !dbg !12
+
+157: ; preds = %160, %156
+ %158 = phi i64 [ 0, %156 ], [ %167, %160 ]
+ %159 = icmp slt i64 %158, 2, !dbg !12
+ br i1 %159, label %160, label %168, !dbg !12
+
+160: ; preds = %157
+ %161 = extractvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %48, 1, !dbg !12
+ %162 = mul i64 %154, 2, !dbg !12
+ %163 = add i64 %162, %158, !dbg !12
+ %164 = getelementptr double, ptr %161, i64 %163, !dbg !12
+ %165 = load double, ptr %164, align 8, !dbg !12
+ %166 = call i32 (ptr, ...) @printf(ptr @frmt_spec, double %165), !dbg !12
+ %167 = add i64 %158, 1, !dbg !12
+ br label %157, !dbg !12
+
+168: ; preds = %157
+ %169 = call i32 (ptr, ...) @printf(ptr @nl), !dbg !12
+ %170 = add i64 %154, 1, !dbg !12
+ br label %153, !dbg !12
+
+171: ; preds = %153
+ %172 = extractvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %16, 0, !dbg !10
+ call void @free(ptr %172), !dbg !10
+ %173 = extractvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %8, 0, !dbg !9
+ call void @free(ptr %173), !dbg !9
+ ret void, !dbg !13
+}
+
+!llvm.module.flags = !{!0}
+!llvm.dbg.cu = !{!1}
+
+!0 = !{i32 2, !"Debug Info Version", i32 3}
+!1 = distinct !DICompileUnit(language: DW_LANG_C, file: !2, producer: "MLIR", isOptimized: true, runtimeVersion: 0, emissionKind: LineTablesOnly)
+!2 = !DIFile(filename: "matmul1.mlir", directory: "")
+!3 = !DISubprogram(name: "free", linkageName: "free", scope: !2, file: !2, line: 1, type: !4, scopeLine: 1, spFlags: DISPFlagOptimized)
+!4 = !DISubroutineType(cc: DW_CC_normal, types: !5)
+!5 = !{}
+!6 = !DISubprogram(name: "printf", linkageName: "printf", scope: !2, file: !2, line: 1, type: !4, scopeLine: 1, spFlags: DISPFlagOptimized)
+!7 = !DISubprogram(name: "malloc", linkageName: "malloc", scope: !2, file: !2, line: 1, type: !4, scopeLine: 1, spFlags: DISPFlagOptimized)
+!8 = distinct !DISubprogram(name: "main", linkageName: "main", scope: !2, file: !2, line: 2, type: !4, scopeLine: 3, spFlags: DISPFlagDefinition | DISPFlagOptimized, unit: !1)
+!9 = !DILocation(line: 6, column: 10, scope: !8)
+!10 = !DILocation(line: 4, column: 10, scope: !8)
+!11 = !DILocation(line: 7, column: 10, scope: !8)
+!12 = !DILocation(line: 8, column: 5, scope: !8)
+!13 = !DILocation(line: 9, column: 5, scope: !8)
+
diff --git a/mlir/test/Examples/Toy/Matmul/matmul1.mlir b/mlir/test/Examples/Toy/Matmul/matmul1.mlir
new file mode 100644
index 0000000000000..956da787f9b36
--- /dev/null
+++ b/mlir/test/Examples/Toy/Matmul/matmul1.mlir
@@ -0,0 +1,11 @@
+module {
+ toy.func @main() {
+ %0 = toy.constant dense<[1.000000e+00, 2.000000e+00, 3.000000e+00, 4.000000e+00, 5.000000e+00, 6.000000e+00]> : tensor<6xf64>
+ %1 = toy.reshape(%0 : tensor<6xf64>) to tensor<2x3xf64>
+ %2 = toy.constant dense<[1.000000e+00, 2.000000e+00, 3.000000e+00, 4.000000e+00, 5.000000e+00, 6.000000e+00]> : tensor<6xf64>
+ %3 = toy.reshape(%2 : tensor<6xf64>) to tensor<3x2xf64>
+ %4 = toy.matmul %1, %3 : tensor<2x3xf64>, tensor<3x2xf64> -> tensor<*xf64>
+ toy.print %4 : tensor<*xf64>
+ toy.return
+ }
+}
diff --git a/mlir/test/Examples/Toy/Matmul/matmul1.toy b/mlir/test/Examples/Toy/Matmul/matmul1.toy
new file mode 100644
index 0000000000000..823608a6a2f02
--- /dev/null
+++ b/mlir/test/Examples/Toy/Matmul/matmul1.toy
@@ -0,0 +1,6 @@
+def main() {
+ var a<2, 3> = [1, 2, 3, 4, 5, 6];
+ var b<3, 2> = [1, 2, 3, 4, 5, 6];
+ var c = matmul(a,b);
+ print(c);
+}
\ No newline at end of file
diff --git a/mlir/test/Examples/Toy/Matmul/matmul2-affine.mlir b/mlir/test/Examples/Toy/Matmul/matmul2-affine.mlir
new file mode 100644
index 0000000000000..88a235f662b15
--- /dev/null
+++ b/mlir/test/Examples/Toy/Matmul/matmul2-affine.mlir
@@ -0,0 +1,210 @@
+module {
+ func.func @main() {
+ %cst = arith.constant 0.000000e+00 : f64
+ %cst_0 = arith.constant 6.000000e+00 : f64
+ %cst_1 = arith.constant 5.000000e+00 : f64
+ %cst_2 = arith.constant 4.000000e+00 : f64
+ %cst_3 = arith.constant 3.000000e+00 : f64
+ %cst_4 = arith.constant 2.000000e+00 : f64
+ %cst_5 = arith.constant 1.000000e+00 : f64
+ %alloc = memref.alloc() : memref<1x6xf64>
+ %alloc_6 = memref.alloc() : memref<6x1xf64>
+ affine.store %cst_5, %alloc_6[0, 0] : memref<6x1xf64>
+ affine.store %cst_4, %alloc_6[1, 0] : memref<6x1xf64>
+ affine.store %cst_3, %alloc_6[2, 0] : memref<6x1xf64>
+ affine.store %cst_2, %alloc_6[3, 0] : memref<6x1xf64>
+ affine.store %cst_1, %alloc_6[4, 0] : memref<6x1xf64>
+ affine.store %cst_0, %alloc_6[5, 0] : memref<6x1xf64>
+ affine.store %cst_5, %alloc[0, 0] : memref<1x6xf64>
+ affine.store %cst_4, %alloc[0, 1] : memref<1x6xf64>
+ affine.store %cst_3, %alloc[0, 2] : memref<1x6xf64>
+ affine.store %cst_2, %alloc[0, 3] : memref<1x6xf64>
+ affine.store %cst_1, %alloc[0, 4] : memref<1x6xf64>
+ affine.store %cst_0, %alloc[0, 5] : memref<1x6xf64>
+ %alloc_7 = memref.alloc() : memref<6x6xf64>
+ %0 = affine.load %alloc_6[0, 0] : memref<6x1xf64>
+ %1 = affine.load %alloc[0, 0] : memref<1x6xf64>
+ %2 = arith.mulf %0, %1 : f64
+ %3 = arith.addf %2, %cst : f64
+ affine.store %3, %alloc_7[0, 0] : memref<6x6xf64>
+ %4 = affine.load %alloc_6[0, 0] : memref<6x1xf64>
+ %5 = affine.load %alloc[0, 1] : memref<1x6xf64>
+ %6 = arith.mulf %4, %5 : f64
+ %7 = arith.addf %6, %cst : f64
+ affine.store %7, %alloc_7[0, 1] : memref<6x6xf64>
+ %8 = affine.load %alloc_6[0, 0] : memref<6x1xf64>
+ %9 = affine.load %alloc[0, 2] : memref<1x6xf64>
+ %10 = arith.mulf %8, %9 : f64
+ %11 = arith.addf %10, %cst : f64
+ affine.store %11, %alloc_7[0, 2] : memref<6x6xf64>
+ %12 = affine.load %alloc_6[0, 0] : memref<6x1xf64>
+ %13 = affine.load %alloc[0, 3] : memref<1x6xf64>
+ %14 = arith.mulf %12, %13 : f64
+ %15 = arith.addf %14, %cst : f64
+ affine.store %15, %alloc_7[0, 3] : memref<6x6xf64>
+ %16 = affine.load %alloc_6[0, 0] : memref<6x1xf64>
+ %17 = affine.load %alloc[0, 4] : memref<1x6xf64>
+ %18 = arith.mulf %16, %17 : f64
+ %19 = arith.addf %18, %cst : f64
+ affine.store %19, %alloc_7[0, 4] : memref<6x6xf64>
+ %20 = affine.load %alloc_6[0, 0] : memref<6x1xf64>
+ %21 = affine.load %alloc[0, 5] : memref<1x6xf64>
+ %22 = arith.mulf %20, %21 : f64
+ %23 = arith.addf %22, %cst : f64
+ affine.store %23, %alloc_7[0, 5] : memref<6x6xf64>
+ %24 = affine.load %alloc_6[1, 0] : memref<6x1xf64>
+ %25 = affine.load %alloc[0, 0] : memref<1x6xf64>
+ %26 = arith.mulf %24, %25 : f64
+ %27 = arith.addf %26, %cst : f64
+ affine.store %27, %alloc_7[1, 0] : memref<6x6xf64>
+ %28 = affine.load %alloc_6[1, 0] : memref<6x1xf64>
+ %29 = affine.load %alloc[0, 1] : memref<1x6xf64>
+ %30 = arith.mulf %28, %29 : f64
+ %31 = arith.addf %30, %cst : f64
+ affine.store %31, %alloc_7[1, 1] : memref<6x6xf64>
+ %32 = affine.load %alloc_6[1, 0] : memref<6x1xf64>
+ %33 = affine.load %alloc[0, 2] : memref<1x6xf64>
+ %34 = arith.mulf %32, %33 : f64
+ %35 = arith.addf %34, %cst : f64
+ affine.store %35, %alloc_7[1, 2] : memref<6x6xf64>
+ %36 = affine.load %alloc_6[1, 0] : memref<6x1xf64>
+ %37 = affine.load %alloc[0, 3] : memref<1x6xf64>
+ %38 = arith.mulf %36, %37 : f64
+ %39 = arith.addf %38, %cst : f64
+ affine.store %39, %alloc_7[1, 3] : memref<6x6xf64>
+ %40 = affine.load %alloc_6[1, 0] : memref<6x1xf64>
+ %41 = affine.load %alloc[0, 4] : memref<1x6xf64>
+ %42 = arith.mulf %40, %41 : f64
+ %43 = arith.addf %42, %cst : f64
+ affine.store %43, %alloc_7[1, 4] : memref<6x6xf64>
+ %44 = affine.load %alloc_6[1, 0] : memref<6x1xf64>
+ %45 = affine.load %alloc[0, 5] : memref<1x6xf64>
+ %46 = arith.mulf %44, %45 : f64
+ %47 = arith.addf %46, %cst : f64
+ affine.store %47, %alloc_7[1, 5] : memref<6x6xf64>
+ %48 = affine.load %alloc_6[2, 0] : memref<6x1xf64>
+ %49 = affine.load %alloc[0, 0] : memref<1x6xf64>
+ %50 = arith.mulf %48, %49 : f64
+ %51 = arith.addf %50, %cst : f64
+ affine.store %51, %alloc_7[2, 0] : memref<6x6xf64>
+ %52 = affine.load %alloc_6[2, 0] : memref<6x1xf64>
+ %53 = affine.load %alloc[0, 1] : memref<1x6xf64>
+ %54 = arith.mulf %52, %53 : f64
+ %55 = arith.addf %54, %cst : f64
+ affine.store %55, %alloc_7[2, 1] : memref<6x6xf64>
+ %56 = affine.load %alloc_6[2, 0] : memref<6x1xf64>
+ %57 = affine.load %alloc[0, 2] : memref<1x6xf64>
+ %58 = arith.mulf %56, %57 : f64
+ %59 = arith.addf %58, %cst : f64
+ affine.store %59, %alloc_7[2, 2] : memref<6x6xf64>
+ %60 = affine.load %alloc_6[2, 0] : memref<6x1xf64>
+ %61 = affine.load %alloc[0, 3] : memref<1x6xf64>
+ %62 = arith.mulf %60, %61 : f64
+ %63 = arith.addf %62, %cst : f64
+ affine.store %63, %alloc_7[2, 3] : memref<6x6xf64>
+ %64 = affine.load %alloc_6[2, 0] : memref<6x1xf64>
+ %65 = affine.load %alloc[0, 4] : memref<1x6xf64>
+ %66 = arith.mulf %64, %65 : f64
+ %67 = arith.addf %66, %cst : f64
+ affine.store %67, %alloc_7[2, 4] : memref<6x6xf64>
+ %68 = affine.load %alloc_6[2, 0] : memref<6x1xf64>
+ %69 = affine.load %alloc[0, 5] : memref<1x6xf64>
+ %70 = arith.mulf %68, %69 : f64
+ %71 = arith.addf %70, %cst : f64
+ affine.store %71, %alloc_7[2, 5] : memref<6x6xf64>
+ %72 = affine.load %alloc_6[3, 0] : memref<6x1xf64>
+ %73 = affine.load %alloc[0, 0] : memref<1x6xf64>
+ %74 = arith.mulf %72, %73 : f64
+ %75 = arith.addf %74, %cst : f64
+ affine.store %75, %alloc_7[3, 0] : memref<6x6xf64>
+ %76 = affine.load %alloc_6[3, 0] : memref<6x1xf64>
+ %77 = affine.load %alloc[0, 1] : memref<1x6xf64>
+ %78 = arith.mulf %76, %77 : f64
+ %79 = arith.addf %78, %cst : f64
+ affine.store %79, %alloc_7[3, 1] : memref<6x6xf64>
+ %80 = affine.load %alloc_6[3, 0] : memref<6x1xf64>
+ %81 = affine.load %alloc[0, 2] : memref<1x6xf64>
+ %82 = arith.mulf %80, %81 : f64
+ %83 = arith.addf %82, %cst : f64
+ affine.store %83, %alloc_7[3, 2] : memref<6x6xf64>
+ %84 = affine.load %alloc_6[3, 0] : memref<6x1xf64>
+ %85 = affine.load %alloc[0, 3] : memref<1x6xf64>
+ %86 = arith.mulf %84, %85 : f64
+ %87 = arith.addf %86, %cst : f64
+ affine.store %87, %alloc_7[3, 3] : memref<6x6xf64>
+ %88 = affine.load %alloc_6[3, 0] : memref<6x1xf64>
+ %89 = affine.load %alloc[0, 4] : memref<1x6xf64>
+ %90 = arith.mulf %88, %89 : f64
+ %91 = arith.addf %90, %cst : f64
+ affine.store %91, %alloc_7[3, 4] : memref<6x6xf64>
+ %92 = affine.load %alloc_6[3, 0] : memref<6x1xf64>
+ %93 = affine.load %alloc[0, 5] : memref<1x6xf64>
+ %94 = arith.mulf %92, %93 : f64
+ %95 = arith.addf %94, %cst : f64
+ affine.store %95, %alloc_7[3, 5] : memref<6x6xf64>
+ %96 = affine.load %alloc_6[4, 0] : memref<6x1xf64>
+ %97 = affine.load %alloc[0, 0] : memref<1x6xf64>
+ %98 = arith.mulf %96, %97 : f64
+ %99 = arith.addf %98, %cst : f64
+ affine.store %99, %alloc_7[4, 0] : memref<6x6xf64>
+ %100 = affine.load %alloc_6[4, 0] : memref<6x1xf64>
+ %101 = affine.load %alloc[0, 1] : memref<1x6xf64>
+ %102 = arith.mulf %100, %101 : f64
+ %103 = arith.addf %102, %cst : f64
+ affine.store %103, %alloc_7[4, 1] : memref<6x6xf64>
+ %104 = affine.load %alloc_6[4, 0] : memref<6x1xf64>
+ %105 = affine.load %alloc[0, 2] : memref<1x6xf64>
+ %106 = arith.mulf %104, %105 : f64
+ %107 = arith.addf %106, %cst : f64
+ affine.store %107, %alloc_7[4, 2] : memref<6x6xf64>
+ %108 = affine.load %alloc_6[4, 0] : memref<6x1xf64>
+ %109 = affine.load %alloc[0, 3] : memref<1x6xf64>
+ %110 = arith.mulf %108, %109 : f64
+ %111 = arith.addf %110, %cst : f64
+ affine.store %111, %alloc_7[4, 3] : memref<6x6xf64>
+ %112 = affine.load %alloc_6[4, 0] : memref<6x1xf64>
+ %113 = affine.load %alloc[0, 4] : memref<1x6xf64>
+ %114 = arith.mulf %112, %113 : f64
+ %115 = arith.addf %114, %cst : f64
+ affine.store %115, %alloc_7[4, 4] : memref<6x6xf64>
+ %116 = affine.load %alloc_6[4, 0] : memref<6x1xf64>
+ %117 = affine.load %alloc[0, 5] : memref<1x6xf64>
+ %118 = arith.mulf %116, %117 : f64
+ %119 = arith.addf %118, %cst : f64
+ affine.store %119, %alloc_7[4, 5] : memref<6x6xf64>
+ %120 = affine.load %alloc_6[5, 0] : memref<6x1xf64>
+ %121 = affine.load %alloc[0, 0] : memref<1x6xf64>
+ %122 = arith.mulf %120, %121 : f64
+ %123 = arith.addf %122, %cst : f64
+ affine.store %123, %alloc_7[5, 0] : memref<6x6xf64>
+ %124 = affine.load %alloc_6[5, 0] : memref<6x1xf64>
+ %125 = affine.load %alloc[0, 1] : memref<1x6xf64>
+ %126 = arith.mulf %124, %125 : f64
+ %127 = arith.addf %126, %cst : f64
+ affine.store %127, %alloc_7[5, 1] : memref<6x6xf64>
+ %128 = affine.load %alloc_6[5, 0] : memref<6x1xf64>
+ %129 = affine.load %alloc[0, 2] : memref<1x6xf64>
+ %130 = arith.mulf %128, %129 : f64
+ %131 = arith.addf %130, %cst : f64
+ affine.store %131, %alloc_7[5, 2] : memref<6x6xf64>
+ %132 = affine.load %alloc_6[5, 0] : memref<6x1xf64>
+ %133 = affine.load %alloc[0, 3] : memref<1x6xf64>
+ %134 = arith.mulf %132, %133 : f64
+ %135 = arith.addf %134, %cst : f64
+ affine.store %135, %alloc_7[5, 3] : memref<6x6xf64>
+ %136 = affine.load %alloc_6[5, 0] : memref<6x1xf64>
+ %137 = affine.load %alloc[0, 4] : memref<1x6xf64>
+ %138 = arith.mulf %136, %137 : f64
+ %139 = arith.addf %138, %cst : f64
+ affine.store %139, %alloc_7[5, 4] : memref<6x6xf64>
+ %140 = affine.load %alloc_6[5, 0] : memref<6x1xf64>
+ %141 = affine.load %alloc[0, 5] : memref<1x6xf64>
+ %142 = arith.mulf %140, %141 : f64
+ %143 = arith.addf %142, %cst : f64
+ affine.store %143, %alloc_7[5, 5] : memref<6x6xf64>
+ toy.print %alloc_7 : memref<6x6xf64>
+ memref.dealloc %alloc_6 : memref<6x1xf64>
+ memref.dealloc %alloc : memref<1x6xf64>
+ return
+ }
+}
diff --git a/mlir/test/Examples/Toy/Matmul/matmul2.llvm b/mlir/test/Examples/Toy/Matmul/matmul2.llvm
new file mode 100644
index 0000000000000..0d777aa2819ed
--- /dev/null
+++ b/mlir/test/Examples/Toy/Matmul/matmul2.llvm
@@ -0,0 +1,527 @@
+; ModuleID = 'LLVMDialectModule'
+source_filename = "LLVMDialectModule"
+target datalayout = "e-m:e-p270:32:32-p271:32:32-p272:64:64-i64:64-i128:128-f80:128-n8:16:32:64-S128"
+target triple = "x86_64-unknown-linux-gnu"
+
+@nl = internal constant [2 x i8] c"\0A\00"
+@frmt_spec = internal constant [4 x i8] c"%f \00"
+
+declare !dbg !3 void @free(ptr)
+
+declare !dbg !6 i32 @printf(ptr, ...)
+
+declare !dbg !7 ptr @malloc(i64)
+
+define void @main() !dbg !8 {
+ %1 = call ptr @malloc(i64 ptrtoint (ptr getelementptr (double, ptr null, i64 6) to i64)), !dbg !9
+ %2 = insertvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } undef, ptr %1, 0, !dbg !9
+ %3 = insertvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %2, ptr %1, 1, !dbg !9
+ %4 = insertvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %3, i64 0, 2, !dbg !9
+ %5 = insertvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %4, i64 1, 3, 0, !dbg !9
+ %6 = insertvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %5, i64 6, 3, 1, !dbg !9
+ %7 = insertvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %6, i64 6, 4, 0, !dbg !9
+ %8 = insertvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %7, i64 1, 4, 1, !dbg !9
+ %9 = call ptr @malloc(i64 ptrtoint (ptr getelementptr (double, ptr null, i64 6) to i64)), !dbg !10
+ %10 = insertvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } undef, ptr %9, 0, !dbg !10
+ %11 = insertvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %10, ptr %9, 1, !dbg !10
+ %12 = insertvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %11, i64 0, 2, !dbg !10
+ %13 = insertvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %12, i64 6, 3, 0, !dbg !10
+ %14 = insertvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %13, i64 1, 3, 1, !dbg !10
+ %15 = insertvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %14, i64 1, 4, 0, !dbg !10
+ %16 = insertvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %15, i64 1, 4, 1, !dbg !10
+ %17 = extractvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %16, 1, !dbg !10
+ %18 = getelementptr double, ptr %17, i64 0, !dbg !10
+ store double 1.000000e+00, ptr %18, align 8, !dbg !10
+ %19 = extractvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %16, 1, !dbg !10
+ %20 = getelementptr double, ptr %19, i64 1, !dbg !10
+ store double 2.000000e+00, ptr %20, align 8, !dbg !10
+ %21 = extractvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %16, 1, !dbg !10
+ %22 = getelementptr double, ptr %21, i64 2, !dbg !10
+ store double 3.000000e+00, ptr %22, align 8, !dbg !10
+ %23 = extractvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %16, 1, !dbg !10
+ %24 = getelementptr double, ptr %23, i64 3, !dbg !10
+ store double 4.000000e+00, ptr %24, align 8, !dbg !10
+ %25 = extractvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %16, 1, !dbg !10
+ %26 = getelementptr double, ptr %25, i64 4, !dbg !10
+ store double 5.000000e+00, ptr %26, align 8, !dbg !10
+ %27 = extractvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %16, 1, !dbg !10
+ %28 = getelementptr double, ptr %27, i64 5, !dbg !10
+ store double 6.000000e+00, ptr %28, align 8, !dbg !10
+ %29 = extractvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %8, 1, !dbg !9
+ %30 = getelementptr double, ptr %29, i64 0, !dbg !9
+ store double 1.000000e+00, ptr %30, align 8, !dbg !9
+ %31 = extractvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %8, 1, !dbg !9
+ %32 = getelementptr double, ptr %31, i64 1, !dbg !9
+ store double 2.000000e+00, ptr %32, align 8, !dbg !9
+ %33 = extractvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %8, 1, !dbg !9
+ %34 = getelementptr double, ptr %33, i64 2, !dbg !9
+ store double 3.000000e+00, ptr %34, align 8, !dbg !9
+ %35 = extractvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %8, 1, !dbg !9
+ %36 = getelementptr double, ptr %35, i64 3, !dbg !9
+ store double 4.000000e+00, ptr %36, align 8, !dbg !9
+ %37 = extractvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %8, 1, !dbg !9
+ %38 = getelementptr double, ptr %37, i64 4, !dbg !9
+ store double 5.000000e+00, ptr %38, align 8, !dbg !9
+ %39 = extractvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %8, 1, !dbg !9
+ %40 = getelementptr double, ptr %39, i64 5, !dbg !9
+ store double 6.000000e+00, ptr %40, align 8, !dbg !9
+ %41 = call ptr @malloc(i64 ptrtoint (ptr getelementptr (double, ptr null, i64 36) to i64)), !dbg !11
+ %42 = insertvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } undef, ptr %41, 0, !dbg !11
+ %43 = insertvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %42, ptr %41, 1, !dbg !11
+ %44 = insertvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %43, i64 0, 2, !dbg !11
+ %45 = insertvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %44, i64 6, 3, 0, !dbg !11
+ %46 = insertvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %45, i64 6, 3, 1, !dbg !11
+ %47 = insertvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %46, i64 6, 4, 0, !dbg !11
+ %48 = insertvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %47, i64 1, 4, 1, !dbg !11
+ %49 = extractvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %16, 1, !dbg !11
+ %50 = getelementptr double, ptr %49, i64 0, !dbg !11
+ %51 = load double, ptr %50, align 8, !dbg !11
+ %52 = extractvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %8, 1, !dbg !11
+ %53 = getelementptr double, ptr %52, i64 0, !dbg !11
+ %54 = load double, ptr %53, align 8, !dbg !11
+ %55 = fmul double %51, %54, !dbg !11
+ %56 = fadd double %55, 0.000000e+00, !dbg !11
+ %57 = extractvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %48, 1, !dbg !11
+ %58 = getelementptr double, ptr %57, i64 0, !dbg !11
+ store double %56, ptr %58, align 8, !dbg !11
+ %59 = extractvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %16, 1, !dbg !11
+ %60 = getelementptr double, ptr %59, i64 0, !dbg !11
+ %61 = load double, ptr %60, align 8, !dbg !11
+ %62 = extractvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %8, 1, !dbg !11
+ %63 = getelementptr double, ptr %62, i64 1, !dbg !11
+ %64 = load double, ptr %63, align 8, !dbg !11
+ %65 = fmul double %61, %64, !dbg !11
+ %66 = fadd double %65, 0.000000e+00, !dbg !11
+ %67 = extractvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %48, 1, !dbg !11
+ %68 = getelementptr double, ptr %67, i64 1, !dbg !11
+ store double %66, ptr %68, align 8, !dbg !11
+ %69 = extractvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %16, 1, !dbg !11
+ %70 = getelementptr double, ptr %69, i64 0, !dbg !11
+ %71 = load double, ptr %70, align 8, !dbg !11
+ %72 = extractvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %8, 1, !dbg !11
+ %73 = getelementptr double, ptr %72, i64 2, !dbg !11
+ %74 = load double, ptr %73, align 8, !dbg !11
+ %75 = fmul double %71, %74, !dbg !11
+ %76 = fadd double %75, 0.000000e+00, !dbg !11
+ %77 = extractvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %48, 1, !dbg !11
+ %78 = getelementptr double, ptr %77, i64 2, !dbg !11
+ store double %76, ptr %78, align 8, !dbg !11
+ %79 = extractvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %16, 1, !dbg !11
+ %80 = getelementptr double, ptr %79, i64 0, !dbg !11
+ %81 = load double, ptr %80, align 8, !dbg !11
+ %82 = extractvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %8, 1, !dbg !11
+ %83 = getelementptr double, ptr %82, i64 3, !dbg !11
+ %84 = load double, ptr %83, align 8, !dbg !11
+ %85 = fmul double %81, %84, !dbg !11
+ %86 = fadd double %85, 0.000000e+00, !dbg !11
+ %87 = extractvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %48, 1, !dbg !11
+ %88 = getelementptr double, ptr %87, i64 3, !dbg !11
+ store double %86, ptr %88, align 8, !dbg !11
+ %89 = extractvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %16, 1, !dbg !11
+ %90 = getelementptr double, ptr %89, i64 0, !dbg !11
+ %91 = load double, ptr %90, align 8, !dbg !11
+ %92 = extractvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %8, 1, !dbg !11
+ %93 = getelementptr double, ptr %92, i64 4, !dbg !11
+ %94 = load double, ptr %93, align 8, !dbg !11
+ %95 = fmul double %91, %94, !dbg !11
+ %96 = fadd double %95, 0.000000e+00, !dbg !11
+ %97 = extractvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %48, 1, !dbg !11
+ %98 = getelementptr double, ptr %97, i64 4, !dbg !11
+ store double %96, ptr %98, align 8, !dbg !11
+ %99 = extractvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %16, 1, !dbg !11
+ %100 = getelementptr double, ptr %99, i64 0, !dbg !11
+ %101 = load double, ptr %100, align 8, !dbg !11
+ %102 = extractvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %8, 1, !dbg !11
+ %103 = getelementptr double, ptr %102, i64 5, !dbg !11
+ %104 = load double, ptr %103, align 8, !dbg !11
+ %105 = fmul double %101, %104, !dbg !11
+ %106 = fadd double %105, 0.000000e+00, !dbg !11
+ %107 = extractvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %48, 1, !dbg !11
+ %108 = getelementptr double, ptr %107, i64 5, !dbg !11
+ store double %106, ptr %108, align 8, !dbg !11
+ %109 = extractvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %16, 1, !dbg !11
+ %110 = getelementptr double, ptr %109, i64 1, !dbg !11
+ %111 = load double, ptr %110, align 8, !dbg !11
+ %112 = extractvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %8, 1, !dbg !11
+ %113 = getelementptr double, ptr %112, i64 0, !dbg !11
+ %114 = load double, ptr %113, align 8, !dbg !11
+ %115 = fmul double %111, %114, !dbg !11
+ %116 = fadd double %115, 0.000000e+00, !dbg !11
+ %117 = extractvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %48, 1, !dbg !11
+ %118 = getelementptr double, ptr %117, i64 6, !dbg !11
+ store double %116, ptr %118, align 8, !dbg !11
+ %119 = extractvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %16, 1, !dbg !11
+ %120 = getelementptr double, ptr %119, i64 1, !dbg !11
+ %121 = load double, ptr %120, align 8, !dbg !11
+ %122 = extractvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %8, 1, !dbg !11
+ %123 = getelementptr double, ptr %122, i64 1, !dbg !11
+ %124 = load double, ptr %123, align 8, !dbg !11
+ %125 = fmul double %121, %124, !dbg !11
+ %126 = fadd double %125, 0.000000e+00, !dbg !11
+ %127 = extractvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %48, 1, !dbg !11
+ %128 = getelementptr double, ptr %127, i64 7, !dbg !11
+ store double %126, ptr %128, align 8, !dbg !11
+ %129 = extractvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %16, 1, !dbg !11
+ %130 = getelementptr double, ptr %129, i64 1, !dbg !11
+ %131 = load double, ptr %130, align 8, !dbg !11
+ %132 = extractvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %8, 1, !dbg !11
+ %133 = getelementptr double, ptr %132, i64 2, !dbg !11
+ %134 = load double, ptr %133, align 8, !dbg !11
+ %135 = fmul double %131, %134, !dbg !11
+ %136 = fadd double %135, 0.000000e+00, !dbg !11
+ %137 = extractvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %48, 1, !dbg !11
+ %138 = getelementptr double, ptr %137, i64 8, !dbg !11
+ store double %136, ptr %138, align 8, !dbg !11
+ %139 = extractvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %16, 1, !dbg !11
+ %140 = getelementptr double, ptr %139, i64 1, !dbg !11
+ %141 = load double, ptr %140, align 8, !dbg !11
+ %142 = extractvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %8, 1, !dbg !11
+ %143 = getelementptr double, ptr %142, i64 3, !dbg !11
+ %144 = load double, ptr %143, align 8, !dbg !11
+ %145 = fmul double %141, %144, !dbg !11
+ %146 = fadd double %145, 0.000000e+00, !dbg !11
+ %147 = extractvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %48, 1, !dbg !11
+ %148 = getelementptr double, ptr %147, i64 9, !dbg !11
+ store double %146, ptr %148, align 8, !dbg !11
+ %149 = extractvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %16, 1, !dbg !11
+ %150 = getelementptr double, ptr %149, i64 1, !dbg !11
+ %151 = load double, ptr %150, align 8, !dbg !11
+ %152 = extractvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %8, 1, !dbg !11
+ %153 = getelementptr double, ptr %152, i64 4, !dbg !11
+ %154 = load double, ptr %153, align 8, !dbg !11
+ %155 = fmul double %151, %154, !dbg !11
+ %156 = fadd double %155, 0.000000e+00, !dbg !11
+ %157 = extractvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %48, 1, !dbg !11
+ %158 = getelementptr double, ptr %157, i64 10, !dbg !11
+ store double %156, ptr %158, align 8, !dbg !11
+ %159 = extractvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %16, 1, !dbg !11
+ %160 = getelementptr double, ptr %159, i64 1, !dbg !11
+ %161 = load double, ptr %160, align 8, !dbg !11
+ %162 = extractvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %8, 1, !dbg !11
+ %163 = getelementptr double, ptr %162, i64 5, !dbg !11
+ %164 = load double, ptr %163, align 8, !dbg !11
+ %165 = fmul double %161, %164, !dbg !11
+ %166 = fadd double %165, 0.000000e+00, !dbg !11
+ %167 = extractvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %48, 1, !dbg !11
+ %168 = getelementptr double, ptr %167, i64 11, !dbg !11
+ store double %166, ptr %168, align 8, !dbg !11
+ %169 = extractvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %16, 1, !dbg !11
+ %170 = getelementptr double, ptr %169, i64 2, !dbg !11
+ %171 = load double, ptr %170, align 8, !dbg !11
+ %172 = extractvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %8, 1, !dbg !11
+ %173 = getelementptr double, ptr %172, i64 0, !dbg !11
+ %174 = load double, ptr %173, align 8, !dbg !11
+ %175 = fmul double %171, %174, !dbg !11
+ %176 = fadd double %175, 0.000000e+00, !dbg !11
+ %177 = extractvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %48, 1, !dbg !11
+ %178 = getelementptr double, ptr %177, i64 12, !dbg !11
+ store double %176, ptr %178, align 8, !dbg !11
+ %179 = extractvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %16, 1, !dbg !11
+ %180 = getelementptr double, ptr %179, i64 2, !dbg !11
+ %181 = load double, ptr %180, align 8, !dbg !11
+ %182 = extractvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %8, 1, !dbg !11
+ %183 = getelementptr double, ptr %182, i64 1, !dbg !11
+ %184 = load double, ptr %183, align 8, !dbg !11
+ %185 = fmul double %181, %184, !dbg !11
+ %186 = fadd double %185, 0.000000e+00, !dbg !11
+ %187 = extractvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %48, 1, !dbg !11
+ %188 = getelementptr double, ptr %187, i64 13, !dbg !11
+ store double %186, ptr %188, align 8, !dbg !11
+ %189 = extractvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %16, 1, !dbg !11
+ %190 = getelementptr double, ptr %189, i64 2, !dbg !11
+ %191 = load double, ptr %190, align 8, !dbg !11
+ %192 = extractvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %8, 1, !dbg !11
+ %193 = getelementptr double, ptr %192, i64 2, !dbg !11
+ %194 = load double, ptr %193, align 8, !dbg !11
+ %195 = fmul double %191, %194, !dbg !11
+ %196 = fadd double %195, 0.000000e+00, !dbg !11
+ %197 = extractvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %48, 1, !dbg !11
+ %198 = getelementptr double, ptr %197, i64 14, !dbg !11
+ store double %196, ptr %198, align 8, !dbg !11
+ %199 = extractvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %16, 1, !dbg !11
+ %200 = getelementptr double, ptr %199, i64 2, !dbg !11
+ %201 = load double, ptr %200, align 8, !dbg !11
+ %202 = extractvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %8, 1, !dbg !11
+ %203 = getelementptr double, ptr %202, i64 3, !dbg !11
+ %204 = load double, ptr %203, align 8, !dbg !11
+ %205 = fmul double %201, %204, !dbg !11
+ %206 = fadd double %205, 0.000000e+00, !dbg !11
+ %207 = extractvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %48, 1, !dbg !11
+ %208 = getelementptr double, ptr %207, i64 15, !dbg !11
+ store double %206, ptr %208, align 8, !dbg !11
+ %209 = extractvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %16, 1, !dbg !11
+ %210 = getelementptr double, ptr %209, i64 2, !dbg !11
+ %211 = load double, ptr %210, align 8, !dbg !11
+ %212 = extractvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %8, 1, !dbg !11
+ %213 = getelementptr double, ptr %212, i64 4, !dbg !11
+ %214 = load double, ptr %213, align 8, !dbg !11
+ %215 = fmul double %211, %214, !dbg !11
+ %216 = fadd double %215, 0.000000e+00, !dbg !11
+ %217 = extractvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %48, 1, !dbg !11
+ %218 = getelementptr double, ptr %217, i64 16, !dbg !11
+ store double %216, ptr %218, align 8, !dbg !11
+ %219 = extractvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %16, 1, !dbg !11
+ %220 = getelementptr double, ptr %219, i64 2, !dbg !11
+ %221 = load double, ptr %220, align 8, !dbg !11
+ %222 = extractvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %8, 1, !dbg !11
+ %223 = getelementptr double, ptr %222, i64 5, !dbg !11
+ %224 = load double, ptr %223, align 8, !dbg !11
+ %225 = fmul double %221, %224, !dbg !11
+ %226 = fadd double %225, 0.000000e+00, !dbg !11
+ %227 = extractvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %48, 1, !dbg !11
+ %228 = getelementptr double, ptr %227, i64 17, !dbg !11
+ store double %226, ptr %228, align 8, !dbg !11
+ %229 = extractvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %16, 1, !dbg !11
+ %230 = getelementptr double, ptr %229, i64 3, !dbg !11
+ %231 = load double, ptr %230, align 8, !dbg !11
+ %232 = extractvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %8, 1, !dbg !11
+ %233 = getelementptr double, ptr %232, i64 0, !dbg !11
+ %234 = load double, ptr %233, align 8, !dbg !11
+ %235 = fmul double %231, %234, !dbg !11
+ %236 = fadd double %235, 0.000000e+00, !dbg !11
+ %237 = extractvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %48, 1, !dbg !11
+ %238 = getelementptr double, ptr %237, i64 18, !dbg !11
+ store double %236, ptr %238, align 8, !dbg !11
+ %239 = extractvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %16, 1, !dbg !11
+ %240 = getelementptr double, ptr %239, i64 3, !dbg !11
+ %241 = load double, ptr %240, align 8, !dbg !11
+ %242 = extractvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %8, 1, !dbg !11
+ %243 = getelementptr double, ptr %242, i64 1, !dbg !11
+ %244 = load double, ptr %243, align 8, !dbg !11
+ %245 = fmul double %241, %244, !dbg !11
+ %246 = fadd double %245, 0.000000e+00, !dbg !11
+ %247 = extractvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %48, 1, !dbg !11
+ %248 = getelementptr double, ptr %247, i64 19, !dbg !11
+ store double %246, ptr %248, align 8, !dbg !11
+ %249 = extractvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %16, 1, !dbg !11
+ %250 = getelementptr double, ptr %249, i64 3, !dbg !11
+ %251 = load double, ptr %250, align 8, !dbg !11
+ %252 = extractvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %8, 1, !dbg !11
+ %253 = getelementptr double, ptr %252, i64 2, !dbg !11
+ %254 = load double, ptr %253, align 8, !dbg !11
+ %255 = fmul double %251, %254, !dbg !11
+ %256 = fadd double %255, 0.000000e+00, !dbg !11
+ %257 = extractvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %48, 1, !dbg !11
+ %258 = getelementptr double, ptr %257, i64 20, !dbg !11
+ store double %256, ptr %258, align 8, !dbg !11
+ %259 = extractvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %16, 1, !dbg !11
+ %260 = getelementptr double, ptr %259, i64 3, !dbg !11
+ %261 = load double, ptr %260, align 8, !dbg !11
+ %262 = extractvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %8, 1, !dbg !11
+ %263 = getelementptr double, ptr %262, i64 3, !dbg !11
+ %264 = load double, ptr %263, align 8, !dbg !11
+ %265 = fmul double %261, %264, !dbg !11
+ %266 = fadd double %265, 0.000000e+00, !dbg !11
+ %267 = extractvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %48, 1, !dbg !11
+ %268 = getelementptr double, ptr %267, i64 21, !dbg !11
+ store double %266, ptr %268, align 8, !dbg !11
+ %269 = extractvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %16, 1, !dbg !11
+ %270 = getelementptr double, ptr %269, i64 3, !dbg !11
+ %271 = load double, ptr %270, align 8, !dbg !11
+ %272 = extractvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %8, 1, !dbg !11
+ %273 = getelementptr double, ptr %272, i64 4, !dbg !11
+ %274 = load double, ptr %273, align 8, !dbg !11
+ %275 = fmul double %271, %274, !dbg !11
+ %276 = fadd double %275, 0.000000e+00, !dbg !11
+ %277 = extractvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %48, 1, !dbg !11
+ %278 = getelementptr double, ptr %277, i64 22, !dbg !11
+ store double %276, ptr %278, align 8, !dbg !11
+ %279 = extractvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %16, 1, !dbg !11
+ %280 = getelementptr double, ptr %279, i64 3, !dbg !11
+ %281 = load double, ptr %280, align 8, !dbg !11
+ %282 = extractvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %8, 1, !dbg !11
+ %283 = getelementptr double, ptr %282, i64 5, !dbg !11
+ %284 = load double, ptr %283, align 8, !dbg !11
+ %285 = fmul double %281, %284, !dbg !11
+ %286 = fadd double %285, 0.000000e+00, !dbg !11
+ %287 = extractvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %48, 1, !dbg !11
+ %288 = getelementptr double, ptr %287, i64 23, !dbg !11
+ store double %286, ptr %288, align 8, !dbg !11
+ %289 = extractvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %16, 1, !dbg !11
+ %290 = getelementptr double, ptr %289, i64 4, !dbg !11
+ %291 = load double, ptr %290, align 8, !dbg !11
+ %292 = extractvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %8, 1, !dbg !11
+ %293 = getelementptr double, ptr %292, i64 0, !dbg !11
+ %294 = load double, ptr %293, align 8, !dbg !11
+ %295 = fmul double %291, %294, !dbg !11
+ %296 = fadd double %295, 0.000000e+00, !dbg !11
+ %297 = extractvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %48, 1, !dbg !11
+ %298 = getelementptr double, ptr %297, i64 24, !dbg !11
+ store double %296, ptr %298, align 8, !dbg !11
+ %299 = extractvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %16, 1, !dbg !11
+ %300 = getelementptr double, ptr %299, i64 4, !dbg !11
+ %301 = load double, ptr %300, align 8, !dbg !11
+ %302 = extractvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %8, 1, !dbg !11
+ %303 = getelementptr double, ptr %302, i64 1, !dbg !11
+ %304 = load double, ptr %303, align 8, !dbg !11
+ %305 = fmul double %301, %304, !dbg !11
+ %306 = fadd double %305, 0.000000e+00, !dbg !11
+ %307 = extractvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %48, 1, !dbg !11
+ %308 = getelementptr double, ptr %307, i64 25, !dbg !11
+ store double %306, ptr %308, align 8, !dbg !11
+ %309 = extractvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %16, 1, !dbg !11
+ %310 = getelementptr double, ptr %309, i64 4, !dbg !11
+ %311 = load double, ptr %310, align 8, !dbg !11
+ %312 = extractvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %8, 1, !dbg !11
+ %313 = getelementptr double, ptr %312, i64 2, !dbg !11
+ %314 = load double, ptr %313, align 8, !dbg !11
+ %315 = fmul double %311, %314, !dbg !11
+ %316 = fadd double %315, 0.000000e+00, !dbg !11
+ %317 = extractvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %48, 1, !dbg !11
+ %318 = getelementptr double, ptr %317, i64 26, !dbg !11
+ store double %316, ptr %318, align 8, !dbg !11
+ %319 = extractvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %16, 1, !dbg !11
+ %320 = getelementptr double, ptr %319, i64 4, !dbg !11
+ %321 = load double, ptr %320, align 8, !dbg !11
+ %322 = extractvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %8, 1, !dbg !11
+ %323 = getelementptr double, ptr %322, i64 3, !dbg !11
+ %324 = load double, ptr %323, align 8, !dbg !11
+ %325 = fmul double %321, %324, !dbg !11
+ %326 = fadd double %325, 0.000000e+00, !dbg !11
+ %327 = extractvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %48, 1, !dbg !11
+ %328 = getelementptr double, ptr %327, i64 27, !dbg !11
+ store double %326, ptr %328, align 8, !dbg !11
+ %329 = extractvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %16, 1, !dbg !11
+ %330 = getelementptr double, ptr %329, i64 4, !dbg !11
+ %331 = load double, ptr %330, align 8, !dbg !11
+ %332 = extractvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %8, 1, !dbg !11
+ %333 = getelementptr double, ptr %332, i64 4, !dbg !11
+ %334 = load double, ptr %333, align 8, !dbg !11
+ %335 = fmul double %331, %334, !dbg !11
+ %336 = fadd double %335, 0.000000e+00, !dbg !11
+ %337 = extractvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %48, 1, !dbg !11
+ %338 = getelementptr double, ptr %337, i64 28, !dbg !11
+ store double %336, ptr %338, align 8, !dbg !11
+ %339 = extractvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %16, 1, !dbg !11
+ %340 = getelementptr double, ptr %339, i64 4, !dbg !11
+ %341 = load double, ptr %340, align 8, !dbg !11
+ %342 = extractvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %8, 1, !dbg !11
+ %343 = getelementptr double, ptr %342, i64 5, !dbg !11
+ %344 = load double, ptr %343, align 8, !dbg !11
+ %345 = fmul double %341, %344, !dbg !11
+ %346 = fadd double %345, 0.000000e+00, !dbg !11
+ %347 = extractvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %48, 1, !dbg !11
+ %348 = getelementptr double, ptr %347, i64 29, !dbg !11
+ store double %346, ptr %348, align 8, !dbg !11
+ %349 = extractvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %16, 1, !dbg !11
+ %350 = getelementptr double, ptr %349, i64 5, !dbg !11
+ %351 = load double, ptr %350, align 8, !dbg !11
+ %352 = extractvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %8, 1, !dbg !11
+ %353 = getelementptr double, ptr %352, i64 0, !dbg !11
+ %354 = load double, ptr %353, align 8, !dbg !11
+ %355 = fmul double %351, %354, !dbg !11
+ %356 = fadd double %355, 0.000000e+00, !dbg !11
+ %357 = extractvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %48, 1, !dbg !11
+ %358 = getelementptr double, ptr %357, i64 30, !dbg !11
+ store double %356, ptr %358, align 8, !dbg !11
+ %359 = extractvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %16, 1, !dbg !11
+ %360 = getelementptr double, ptr %359, i64 5, !dbg !11
+ %361 = load double, ptr %360, align 8, !dbg !11
+ %362 = extractvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %8, 1, !dbg !11
+ %363 = getelementptr double, ptr %362, i64 1, !dbg !11
+ %364 = load double, ptr %363, align 8, !dbg !11
+ %365 = fmul double %361, %364, !dbg !11
+ %366 = fadd double %365, 0.000000e+00, !dbg !11
+ %367 = extractvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %48, 1, !dbg !11
+ %368 = getelementptr double, ptr %367, i64 31, !dbg !11
+ store double %366, ptr %368, align 8, !dbg !11
+ %369 = extractvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %16, 1, !dbg !11
+ %370 = getelementptr double, ptr %369, i64 5, !dbg !11
+ %371 = load double, ptr %370, align 8, !dbg !11
+ %372 = extractvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %8, 1, !dbg !11
+ %373 = getelementptr double, ptr %372, i64 2, !dbg !11
+ %374 = load double, ptr %373, align 8, !dbg !11
+ %375 = fmul double %371, %374, !dbg !11
+ %376 = fadd double %375, 0.000000e+00, !dbg !11
+ %377 = extractvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %48, 1, !dbg !11
+ %378 = getelementptr double, ptr %377, i64 32, !dbg !11
+ store double %376, ptr %378, align 8, !dbg !11
+ %379 = extractvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %16, 1, !dbg !11
+ %380 = getelementptr double, ptr %379, i64 5, !dbg !11
+ %381 = load double, ptr %380, align 8, !dbg !11
+ %382 = extractvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %8, 1, !dbg !11
+ %383 = getelementptr double, ptr %382, i64 3, !dbg !11
+ %384 = load double, ptr %383, align 8, !dbg !11
+ %385 = fmul double %381, %384, !dbg !11
+ %386 = fadd double %385, 0.000000e+00, !dbg !11
+ %387 = extractvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %48, 1, !dbg !11
+ %388 = getelementptr double, ptr %387, i64 33, !dbg !11
+ store double %386, ptr %388, align 8, !dbg !11
+ %389 = extractvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %16, 1, !dbg !11
+ %390 = getelementptr double, ptr %389, i64 5, !dbg !11
+ %391 = load double, ptr %390, align 8, !dbg !11
+ %392 = extractvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %8, 1, !dbg !11
+ %393 = getelementptr double, ptr %392, i64 4, !dbg !11
+ %394 = load double, ptr %393, align 8, !dbg !11
+ %395 = fmul double %391, %394, !dbg !11
+ %396 = fadd double %395, 0.000000e+00, !dbg !11
+ %397 = extractvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %48, 1, !dbg !11
+ %398 = getelementptr double, ptr %397, i64 34, !dbg !11
+ store double %396, ptr %398, align 8, !dbg !11
+ %399 = extractvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %16, 1, !dbg !11
+ %400 = getelementptr double, ptr %399, i64 5, !dbg !11
+ %401 = load double, ptr %400, align 8, !dbg !11
+ %402 = extractvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %8, 1, !dbg !11
+ %403 = getelementptr double, ptr %402, i64 5, !dbg !11
+ %404 = load double, ptr %403, align 8, !dbg !11
+ %405 = fmul double %401, %404, !dbg !11
+ %406 = fadd double %405, 0.000000e+00, !dbg !11
+ %407 = extractvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %48, 1, !dbg !11
+ %408 = getelementptr double, ptr %407, i64 35, !dbg !11
+ store double %406, ptr %408, align 8, !dbg !11
+ br label %409, !dbg !12
+
+409: ; preds = %424, %0
+ %410 = phi i64 [ 0, %0 ], [ %426, %424 ]
+ %411 = icmp slt i64 %410, 6, !dbg !12
+ br i1 %411, label %412, label %427, !dbg !12
+
+412: ; preds = %409
+ br label %413, !dbg !12
+
+413: ; preds = %416, %412
+ %414 = phi i64 [ 0, %412 ], [ %423, %416 ]
+ %415 = icmp slt i64 %414, 6, !dbg !12
+ br i1 %415, label %416, label %424, !dbg !12
+
+416: ; preds = %413
+ %417 = extractvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %48, 1, !dbg !12
+ %418 = mul i64 %410, 6, !dbg !12
+ %419 = add i64 %418, %414, !dbg !12
+ %420 = getelementptr double, ptr %417, i64 %419, !dbg !12
+ %421 = load double, ptr %420, align 8, !dbg !12
+ %422 = call i32 (ptr, ...) @printf(ptr @frmt_spec, double %421), !dbg !12
+ %423 = add i64 %414, 1, !dbg !12
+ br label %413, !dbg !12
+
+424: ; preds = %413
+ %425 = call i32 (ptr, ...) @printf(ptr @nl), !dbg !12
+ %426 = add i64 %410, 1, !dbg !12
+ br label %409, !dbg !12
+
+427: ; preds = %409
+ %428 = extractvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %16, 0, !dbg !10
+ call void @free(ptr %428), !dbg !10
+ %429 = extractvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %8, 0, !dbg !9
+ call void @free(ptr %429), !dbg !9
+ ret void, !dbg !13
+}
+
+!llvm.module.flags = !{!0}
+!llvm.dbg.cu = !{!1}
+
+!0 = !{i32 2, !"Debug Info Version", i32 3}
+!1 = distinct !DICompileUnit(language: DW_LANG_C, file: !2, producer: "MLIR", isOptimized: true, runtimeVersion: 0, emissionKind: LineTablesOnly)
+!2 = !DIFile(filename: "matmul2.mlir", directory: "")
+!3 = !DISubprogram(name: "free", linkageName: "free", scope: !2, file: !2, line: 1, type: !4, scopeLine: 1, spFlags: DISPFlagOptimized)
+!4 = !DISubroutineType(cc: DW_CC_normal, types: !5)
+!5 = !{}
+!6 = !DISubprogram(name: "printf", linkageName: "printf", scope: !2, file: !2, line: 1, type: !4, scopeLine: 1, spFlags: DISPFlagOptimized)
+!7 = !DISubprogram(name: "malloc", linkageName: "malloc", scope: !2, file: !2, line: 1, type: !4, scopeLine: 1, spFlags: DISPFlagOptimized)
+!8 = distinct !DISubprogram(name: "main", linkageName: "main", scope: !2, file: !2, line: 2, type: !4, scopeLine: 3, spFlags: DISPFlagDefinition | DISPFlagOptimized, unit: !1)
+!9 = !DILocation(line: 6, column: 10, scope: !8)
+!10 = !DILocation(line: 4, column: 10, scope: !8)
+!11 = !DILocation(line: 7, column: 10, scope: !8)
+!12 = !DILocation(line: 8, column: 5, scope: !8)
+!13 = !DILocation(line: 9, column: 5, scope: !8)
+
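The IR above is the fully lowered matmul2 case: a 6x1 column times a 1x6 row is an outer product, so each output element needs exactly one multiply,

    c[i][j] = a[i] * b[j],   0 <= i, j < 6

which is why each of the 36 stores is fed by a single fmul plus an fadd with 0.000000e+00 (the reduction dimension K is 1), followed by the two nested 6x6 print loops.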
diff --git a/mlir/test/Examples/Toy/Matmul/matmul2.mlir b/mlir/test/Examples/Toy/Matmul/matmul2.mlir
new file mode 100644
index 0000000000000..e4145c8bb5dfa
--- /dev/null
+++ b/mlir/test/Examples/Toy/Matmul/matmul2.mlir
@@ -0,0 +1,11 @@
+module {
+ toy.func @main() {
+ %0 = toy.constant dense<[1.000000e+00, 2.000000e+00, 3.000000e+00, 4.000000e+00, 5.000000e+00, 6.000000e+00]> : tensor<6xf64>
+ %1 = toy.reshape(%0 : tensor<6xf64>) to tensor<6x1xf64>
+ %2 = toy.constant dense<[1.000000e+00, 2.000000e+00, 3.000000e+00, 4.000000e+00, 5.000000e+00, 6.000000e+00]> : tensor<6xf64>
+ %3 = toy.reshape(%2 : tensor<6xf64>) to tensor<1x6xf64>
+ %4 = toy.matmul %1, %3 : tensor<6x1xf64>, tensor<1x6xf64> -> tensor<*xf64>
+ toy.print %4 : tensor<*xf64>
+ toy.return
+ }
+}
diff --git a/mlir/test/Examples/Toy/Matmul/matmul2.toy b/mlir/test/Examples/Toy/Matmul/matmul2.toy
new file mode 100644
index 0000000000000..6c37cc1af1f6a
--- /dev/null
+++ b/mlir/test/Examples/Toy/Matmul/matmul2.toy
@@ -0,0 +1,6 @@
+def main() {
+ var a<6, 1> = [1, 2, 3, 4, 5, 6];
+ var b<1, 6> = [1, 2, 3, 4, 5, 6];
+ var c = matmul(a,b);
+ print(c);
+}
diff --git a/mlir/test/Examples/Toy/Matmul/matmul3-affine.mlir b/mlir/test/Examples/Toy/Matmul/matmul3-affine.mlir
new file mode 100644
index 0000000000000..46ac739c241b4
--- /dev/null
+++ b/mlir/test/Examples/Toy/Matmul/matmul3-affine.mlir
@@ -0,0 +1,55 @@
+module {
+ func.func @main() {
+ %cst = arith.constant 0.000000e+00 : f64
+ %cst_0 = arith.constant 6.000000e+00 : f64
+ %cst_1 = arith.constant 5.000000e+00 : f64
+ %cst_2 = arith.constant 4.000000e+00 : f64
+ %cst_3 = arith.constant 3.000000e+00 : f64
+ %cst_4 = arith.constant 2.000000e+00 : f64
+ %cst_5 = arith.constant 1.000000e+00 : f64
+ %alloc = memref.alloc() : memref<1x6xf64>
+ %alloc_6 = memref.alloc() : memref<6x1xf64>
+ affine.store %cst_5, %alloc_6[0, 0] : memref<6x1xf64>
+ affine.store %cst_4, %alloc_6[1, 0] : memref<6x1xf64>
+ affine.store %cst_3, %alloc_6[2, 0] : memref<6x1xf64>
+ affine.store %cst_2, %alloc_6[3, 0] : memref<6x1xf64>
+ affine.store %cst_1, %alloc_6[4, 0] : memref<6x1xf64>
+ affine.store %cst_0, %alloc_6[5, 0] : memref<6x1xf64>
+ affine.store %cst_5, %alloc[0, 0] : memref<1x6xf64>
+ affine.store %cst_4, %alloc[0, 1] : memref<1x6xf64>
+ affine.store %cst_3, %alloc[0, 2] : memref<1x6xf64>
+ affine.store %cst_2, %alloc[0, 3] : memref<1x6xf64>
+ affine.store %cst_1, %alloc[0, 4] : memref<1x6xf64>
+ affine.store %cst_0, %alloc[0, 5] : memref<1x6xf64>
+ %alloc_7 = memref.alloc() : memref<1x1xf64>
+ %0 = affine.load %alloc[0, 0] : memref<1x6xf64>
+ %1 = affine.load %alloc_6[0, 0] : memref<6x1xf64>
+ %2 = arith.mulf %0, %1 : f64
+ %3 = arith.addf %2, %cst : f64
+ %4 = affine.load %alloc[0, 1] : memref<1x6xf64>
+ %5 = affine.load %alloc_6[1, 0] : memref<6x1xf64>
+ %6 = arith.mulf %4, %5 : f64
+ %7 = arith.addf %3, %6 : f64
+ %8 = affine.load %alloc[0, 2] : memref<1x6xf64>
+ %9 = affine.load %alloc_6[2, 0] : memref<6x1xf64>
+ %10 = arith.mulf %8, %9 : f64
+ %11 = arith.addf %7, %10 : f64
+ %12 = affine.load %alloc[0, 3] : memref<1x6xf64>
+ %13 = affine.load %alloc_6[3, 0] : memref<6x1xf64>
+ %14 = arith.mulf %12, %13 : f64
+ %15 = arith.addf %11, %14 : f64
+ %16 = affine.load %alloc[0, 4] : memref<1x6xf64>
+ %17 = affine.load %alloc_6[4, 0] : memref<6x1xf64>
+ %18 = arith.mulf %16, %17 : f64
+ %19 = arith.addf %15, %18 : f64
+ %20 = affine.load %alloc[0, 5] : memref<1x6xf64>
+ %21 = affine.load %alloc_6[5, 0] : memref<6x1xf64>
+ %22 = arith.mulf %20, %21 : f64
+ %23 = arith.addf %19, %22 : f64
+ affine.store %23, %alloc_7[0, 0] : memref<1x1xf64>
+ toy.print %alloc_7 : memref<1x1xf64>
+ memref.dealloc %alloc_6 : memref<6x1xf64>
+ memref.dealloc %alloc : memref<1x6xf64>
+ return
+ }
+}
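matmul3 is the transposed pairing: a 1x6 row times a 6x1 column, so the single output element is a dot product. Because the lowering at this point in the patch series emits one load/mul/add chain per (i, j, k) triple at compile time, the loops disappear entirely and the file above is the straight-line reduction

    c[0][0] = 1*1 + 2*2 + 3*3 + 4*4 + 5*5 + 6*6 = 91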
diff --git a/mlir/test/Examples/Toy/Matmul/matmul3.llvm b/mlir/test/Examples/Toy/Matmul/matmul3.llvm
new file mode 100644
index 0000000000000..d61f8da3b2101
--- /dev/null
+++ b/mlir/test/Examples/Toy/Matmul/matmul3.llvm
@@ -0,0 +1,181 @@
+; ModuleID = 'LLVMDialectModule'
+source_filename = "LLVMDialectModule"
+target datalayout = "e-m:e-p270:32:32-p271:32:32-p272:64:64-i64:64-i128:128-f80:128-n8:16:32:64-S128"
+target triple = "x86_64-unknown-linux-gnu"
+
+ at nl = internal constant [2 x i8] c"\0A\00"
+ at frmt_spec = internal constant [4 x i8] c"%f \00"
+
+declare !dbg !3 void @free(ptr)
+
+declare !dbg !6 i32 @printf(ptr, ...)
+
+declare !dbg !7 ptr @malloc(i64)
+
+define void @main() !dbg !8 {
+ %1 = call ptr @malloc(i64 ptrtoint (ptr getelementptr (double, ptr null, i64 6) to i64)), !dbg !9
+ %2 = insertvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } undef, ptr %1, 0, !dbg !9
+ %3 = insertvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %2, ptr %1, 1, !dbg !9
+ %4 = insertvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %3, i64 0, 2, !dbg !9
+ %5 = insertvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %4, i64 1, 3, 0, !dbg !9
+ %6 = insertvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %5, i64 6, 3, 1, !dbg !9
+ %7 = insertvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %6, i64 6, 4, 0, !dbg !9
+ %8 = insertvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %7, i64 1, 4, 1, !dbg !9
+ %9 = call ptr @malloc(i64 ptrtoint (ptr getelementptr (double, ptr null, i64 6) to i64)), !dbg !10
+ %10 = insertvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } undef, ptr %9, 0, !dbg !10
+ %11 = insertvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %10, ptr %9, 1, !dbg !10
+ %12 = insertvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %11, i64 0, 2, !dbg !10
+ %13 = insertvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %12, i64 6, 3, 0, !dbg !10
+ %14 = insertvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %13, i64 1, 3, 1, !dbg !10
+ %15 = insertvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %14, i64 1, 4, 0, !dbg !10
+ %16 = insertvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %15, i64 1, 4, 1, !dbg !10
+ %17 = extractvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %16, 1, !dbg !10
+ %18 = getelementptr double, ptr %17, i64 0, !dbg !10
+ store double 1.000000e+00, ptr %18, align 8, !dbg !10
+ %19 = extractvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %16, 1, !dbg !10
+ %20 = getelementptr double, ptr %19, i64 1, !dbg !10
+ store double 2.000000e+00, ptr %20, align 8, !dbg !10
+ %21 = extractvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %16, 1, !dbg !10
+ %22 = getelementptr double, ptr %21, i64 2, !dbg !10
+ store double 3.000000e+00, ptr %22, align 8, !dbg !10
+ %23 = extractvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %16, 1, !dbg !10
+ %24 = getelementptr double, ptr %23, i64 3, !dbg !10
+ store double 4.000000e+00, ptr %24, align 8, !dbg !10
+ %25 = extractvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %16, 1, !dbg !10
+ %26 = getelementptr double, ptr %25, i64 4, !dbg !10
+ store double 5.000000e+00, ptr %26, align 8, !dbg !10
+ %27 = extractvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %16, 1, !dbg !10
+ %28 = getelementptr double, ptr %27, i64 5, !dbg !10
+ store double 6.000000e+00, ptr %28, align 8, !dbg !10
+ %29 = extractvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %8, 1, !dbg !9
+ %30 = getelementptr double, ptr %29, i64 0, !dbg !9
+ store double 1.000000e+00, ptr %30, align 8, !dbg !9
+ %31 = extractvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %8, 1, !dbg !9
+ %32 = getelementptr double, ptr %31, i64 1, !dbg !9
+ store double 2.000000e+00, ptr %32, align 8, !dbg !9
+ %33 = extractvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %8, 1, !dbg !9
+ %34 = getelementptr double, ptr %33, i64 2, !dbg !9
+ store double 3.000000e+00, ptr %34, align 8, !dbg !9
+ %35 = extractvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %8, 1, !dbg !9
+ %36 = getelementptr double, ptr %35, i64 3, !dbg !9
+ store double 4.000000e+00, ptr %36, align 8, !dbg !9
+ %37 = extractvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %8, 1, !dbg !9
+ %38 = getelementptr double, ptr %37, i64 4, !dbg !9
+ store double 5.000000e+00, ptr %38, align 8, !dbg !9
+ %39 = extractvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %8, 1, !dbg !9
+ %40 = getelementptr double, ptr %39, i64 5, !dbg !9
+ store double 6.000000e+00, ptr %40, align 8, !dbg !9
+ %41 = call ptr @malloc(i64 ptrtoint (ptr getelementptr (double, ptr null, i64 1) to i64)), !dbg !11
+ %42 = insertvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } undef, ptr %41, 0, !dbg !11
+ %43 = insertvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %42, ptr %41, 1, !dbg !11
+ %44 = insertvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %43, i64 0, 2, !dbg !11
+ %45 = insertvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %44, i64 1, 3, 0, !dbg !11
+ %46 = insertvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %45, i64 1, 3, 1, !dbg !11
+ %47 = insertvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %46, i64 1, 4, 0, !dbg !11
+ %48 = insertvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %47, i64 1, 4, 1, !dbg !11
+ %49 = extractvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %8, 1, !dbg !11
+ %50 = getelementptr double, ptr %49, i64 0, !dbg !11
+ %51 = load double, ptr %50, align 8, !dbg !11
+ %52 = extractvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %16, 1, !dbg !11
+ %53 = getelementptr double, ptr %52, i64 0, !dbg !11
+ %54 = load double, ptr %53, align 8, !dbg !11
+ %55 = fmul double %51, %54, !dbg !11
+ %56 = fadd double %55, 0.000000e+00, !dbg !11
+ %57 = extractvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %8, 1, !dbg !11
+ %58 = getelementptr double, ptr %57, i64 1, !dbg !11
+ %59 = load double, ptr %58, align 8, !dbg !11
+ %60 = extractvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %16, 1, !dbg !11
+ %61 = getelementptr double, ptr %60, i64 1, !dbg !11
+ %62 = load double, ptr %61, align 8, !dbg !11
+ %63 = fmul double %59, %62, !dbg !11
+ %64 = fadd double %56, %63, !dbg !11
+ %65 = extractvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %8, 1, !dbg !11
+ %66 = getelementptr double, ptr %65, i64 2, !dbg !11
+ %67 = load double, ptr %66, align 8, !dbg !11
+ %68 = extractvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %16, 1, !dbg !11
+ %69 = getelementptr double, ptr %68, i64 2, !dbg !11
+ %70 = load double, ptr %69, align 8, !dbg !11
+ %71 = fmul double %67, %70, !dbg !11
+ %72 = fadd double %64, %71, !dbg !11
+ %73 = extractvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %8, 1, !dbg !11
+ %74 = getelementptr double, ptr %73, i64 3, !dbg !11
+ %75 = load double, ptr %74, align 8, !dbg !11
+ %76 = extractvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %16, 1, !dbg !11
+ %77 = getelementptr double, ptr %76, i64 3, !dbg !11
+ %78 = load double, ptr %77, align 8, !dbg !11
+ %79 = fmul double %75, %78, !dbg !11
+ %80 = fadd double %72, %79, !dbg !11
+ %81 = extractvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %8, 1, !dbg !11
+ %82 = getelementptr double, ptr %81, i64 4, !dbg !11
+ %83 = load double, ptr %82, align 8, !dbg !11
+ %84 = extractvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %16, 1, !dbg !11
+ %85 = getelementptr double, ptr %84, i64 4, !dbg !11
+ %86 = load double, ptr %85, align 8, !dbg !11
+ %87 = fmul double %83, %86, !dbg !11
+ %88 = fadd double %80, %87, !dbg !11
+ %89 = extractvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %8, 1, !dbg !11
+ %90 = getelementptr double, ptr %89, i64 5, !dbg !11
+ %91 = load double, ptr %90, align 8, !dbg !11
+ %92 = extractvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %16, 1, !dbg !11
+ %93 = getelementptr double, ptr %92, i64 5, !dbg !11
+ %94 = load double, ptr %93, align 8, !dbg !11
+ %95 = fmul double %91, %94, !dbg !11
+ %96 = fadd double %88, %95, !dbg !11
+ %97 = extractvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %48, 1, !dbg !11
+ %98 = getelementptr double, ptr %97, i64 0, !dbg !11
+ store double %96, ptr %98, align 8, !dbg !11
+ br label %99, !dbg !12
+
+99: ; preds = %113, %0
+ %100 = phi i64 [ 0, %0 ], [ %115, %113 ]
+ %101 = icmp slt i64 %100, 1, !dbg !12
+ br i1 %101, label %102, label %116, !dbg !12
+
+102: ; preds = %99
+ br label %103, !dbg !12
+
+103: ; preds = %106, %102
+ %104 = phi i64 [ 0, %102 ], [ %112, %106 ]
+ %105 = icmp slt i64 %104, 1, !dbg !12
+ br i1 %105, label %106, label %113, !dbg !12
+
+106: ; preds = %103
+ %107 = extractvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %48, 1, !dbg !12
+ %108 = add i64 %100, %104, !dbg !12
+ %109 = getelementptr double, ptr %107, i64 %108, !dbg !12
+ %110 = load double, ptr %109, align 8, !dbg !12
+ %111 = call i32 (ptr, ...) @printf(ptr @frmt_spec, double %110), !dbg !12
+ %112 = add i64 %104, 1, !dbg !12
+ br label %103, !dbg !12
+
+113: ; preds = %103
+ %114 = call i32 (ptr, ...) @printf(ptr @nl), !dbg !12
+ %115 = add i64 %100, 1, !dbg !12
+ br label %99, !dbg !12
+
+116: ; preds = %99
+ %117 = extractvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %16, 0, !dbg !10
+ call void @free(ptr %117), !dbg !10
+ %118 = extractvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %8, 0, !dbg !9
+ call void @free(ptr %118), !dbg !9
+ ret void, !dbg !13
+}
+
+!llvm.module.flags = !{!0}
+!llvm.dbg.cu = !{!1}
+
+!0 = !{i32 2, !"Debug Info Version", i32 3}
+!1 = distinct !DICompileUnit(language: DW_LANG_C, file: !2, producer: "MLIR", isOptimized: true, runtimeVersion: 0, emissionKind: LineTablesOnly)
+!2 = !DIFile(filename: "matmul3.mlir", directory: "")
+!3 = !DISubprogram(name: "free", linkageName: "free", scope: !2, file: !2, line: 1, type: !4, scopeLine: 1, spFlags: DISPFlagOptimized)
+!4 = !DISubroutineType(cc: DW_CC_normal, types: !5)
+!5 = !{}
+!6 = !DISubprogram(name: "printf", linkageName: "printf", scope: !2, file: !2, line: 1, type: !4, scopeLine: 1, spFlags: DISPFlagOptimized)
+!7 = !DISubprogram(name: "malloc", linkageName: "malloc", scope: !2, file: !2, line: 1, type: !4, scopeLine: 1, spFlags: DISPFlagOptimized)
+!8 = distinct !DISubprogram(name: "main", linkageName: "main", scope: !2, file: !2, line: 2, type: !4, scopeLine: 3, spFlags: DISPFlagDefinition | DISPFlagOptimized, unit: !1)
+!9 = !DILocation(line: 6, column: 10, scope: !8)
+!10 = !DILocation(line: 4, column: 10, scope: !8)
+!11 = !DILocation(line: 7, column: 10, scope: !8)
+!12 = !DILocation(line: 8, column: 5, scope: !8)
+!13 = !DILocation(line: 9, column: 5, scope: !8)
+
diff --git a/mlir/test/Examples/Toy/Matmul/matmul3.mlir b/mlir/test/Examples/Toy/Matmul/matmul3.mlir
new file mode 100644
index 0000000000000..5e8fd5eba2398
--- /dev/null
+++ b/mlir/test/Examples/Toy/Matmul/matmul3.mlir
@@ -0,0 +1,11 @@
+module {
+ toy.func @main() {
+ %0 = toy.constant dense<[1.000000e+00, 2.000000e+00, 3.000000e+00, 4.000000e+00, 5.000000e+00, 6.000000e+00]> : tensor<6xf64>
+ %1 = toy.reshape(%0 : tensor<6xf64>) to tensor<6x1xf64>
+ %2 = toy.constant dense<[1.000000e+00, 2.000000e+00, 3.000000e+00, 4.000000e+00, 5.000000e+00, 6.000000e+00]> : tensor<6xf64>
+ %3 = toy.reshape(%2 : tensor<6xf64>) to tensor<1x6xf64>
+ %4 = toy.matmul %3, %1 : tensor<1x6xf64>, tensor<6x1xf64> -> tensor<*xf64>
+ toy.print %4 : tensor<*xf64>
+ toy.return
+ }
+}
diff --git a/mlir/test/Examples/Toy/Matmul/matmul3.toy b/mlir/test/Examples/Toy/Matmul/matmul3.toy
new file mode 100644
index 0000000000000..d268735176e7f
--- /dev/null
+++ b/mlir/test/Examples/Toy/Matmul/matmul3.toy
@@ -0,0 +1,7 @@
+def main() {
+ var a<6, 1> = [1, 2, 3, 4, 5, 6];
+ var b<1, 6> = [1, 2, 3, 4, 5, 6];
+ var c = matmul(b,a);
+ print(c);
+}
+
>From 8c9165319f8c9bb4e90fada68d90077cb60ba30f Mon Sep 17 00:00:00 2001
From: Anirudh-Sathish <anirudhsathish at gmail.com>
Date: Fri, 21 Jun 2024 04:56:45 +0000
Subject: [PATCH 04/13] Feat: Convert lowering of toy matmul to affine loops
---
.../toy/Ch7/mlir/LowerToAffineLoops.cpp | 64 +++++++++----------
1 file changed, 32 insertions(+), 32 deletions(-)
diff --git a/mlir/examples/toy/Ch7/mlir/LowerToAffineLoops.cpp b/mlir/examples/toy/Ch7/mlir/LowerToAffineLoops.cpp
index 5c3e6a552855b..e5cd3168e77cb 100644
--- a/mlir/examples/toy/Ch7/mlir/LowerToAffineLoops.cpp
+++ b/mlir/examples/toy/Ch7/mlir/LowerToAffineLoops.cpp
@@ -331,46 +331,46 @@ struct MatmulOpLowering : public ConversionPattern {
toy::MatmulOpAdaptor matmulAdaptor(operands);
Value lhs = matmulAdaptor.getLhs();
Value rhs = matmulAdaptor.getRhs();
-
+
auto lhsType = dyn_cast<MemRefType>(lhs.getType());
auto rhsType = dyn_cast<MemRefType>(rhs.getType());
- if (!lhsType || !rhsType) {
- return failure();
- }
+ if (!lhsType || !rhsType) return failure();
+
int64_t M = lhsType.getShape()[0];
int64_t N = rhsType.getShape()[1];
int64_t K = lhsType.getShape()[1];
auto elementType = lhsType.getElementType();
auto resultType = MemRefType::get({M, N}, elementType);
Value result = rewriter.create<memref::AllocOp>(loc, resultType);
- for (int64_t i = 0; i < M; ++i) {
- for (int64_t j = 0; j < N; ++j) {
- // Initialize the sum to zero.
- Value sum = rewriter.create<arith::ConstantOp>(loc, rewriter.getF64Type(), rewriter.getF64FloatAttr(0.0));
-
- for (int64_t k = 0; k < K; ++k) {
- // Load lhs[i, k] and rhs[k, j].
- Value lhsVal = rewriter.create<affine::AffineLoadOp>(loc, lhs, ValueRange{
- rewriter.create<arith::ConstantIndexOp>(loc, i),
- rewriter.create<arith::ConstantIndexOp>(loc, k)
- });
- Value rhsVal = rewriter.create<affine::AffineLoadOp>(loc, rhs, ValueRange{
- rewriter.create<arith::ConstantIndexOp>(loc, k),
- rewriter.create<arith::ConstantIndexOp>(loc, j)
- });
-
- // Perform the multiplication and accumulate the result.
- Value product = rewriter.create<arith::MulFOp>(loc, lhsVal, rhsVal);
- sum = rewriter.create<arith::AddFOp>(loc, sum, product);
- }
-
- // Store the computed value into the result matrix.
- rewriter.create<affine::AffineStoreOp>(loc, sum, result, ValueRange{
- rewriter.create<arith::ConstantIndexOp>(loc, i),
- rewriter.create<arith::ConstantIndexOp>(loc, j)
- });
- }
- }
+
+    rewriter.setInsertionPoint(op);
+    Value zero = rewriter.create<arith::ConstantOp>(
+        op->getLoc(), rewriter.getF64FloatAttr(0.0));
+
+    // Build the (i, j, k) affine loop nest.
+    auto outerLoop = rewriter.create<affine::AffineForOp>(op->getLoc(), 0, M, 1);
+    rewriter.setInsertionPointToStart(outerLoop.getBody());
+    Value i = outerLoop.getInductionVar();
+
+    auto middleLoop = rewriter.create<affine::AffineForOp>(op->getLoc(), 0, N, 1);
+    rewriter.setInsertionPointToStart(middleLoop.getBody());
+    Value j = middleLoop.getInductionVar();
+
+    // Zero-initialize result[i, j] before the reduction loop.
+    rewriter.create<affine::AffineStoreOp>(op->getLoc(), zero, result,
+                                           ValueRange{i, j});
+
+    auto innerLoop = rewriter.create<affine::AffineForOp>(op->getLoc(), 0, K, 1);
+    rewriter.setInsertionPointToStart(innerLoop.getBody());
+    Value k = innerLoop.getInductionVar();
+
+    // result[i, j] += lhs[i, k] * rhs[k, j], with the accumulator in memory.
+    Value lhsValue = rewriter.create<affine::AffineLoadOp>(op->getLoc(), lhs,
+                                                           ValueRange{i, k});
+    Value rhsValue = rewriter.create<affine::AffineLoadOp>(op->getLoc(), rhs,
+                                                           ValueRange{k, j});
+    Value product = rewriter.create<arith::MulFOp>(op->getLoc(), lhsValue, rhsValue);
+    Value sum = rewriter.create<arith::AddFOp>(op->getLoc(), zero, product);
+    Value currentValue = rewriter.create<affine::AffineLoadOp>(
+        op->getLoc(), result, ValueRange{i, j});
+    Value totalSum = rewriter.create<arith::AddFOp>(op->getLoc(), sum, currentValue);
+    rewriter.setInsertionPoint(innerLoop.getBody()->getTerminator());
+    rewriter.create<affine::AffineStoreOp>(op->getLoc(), totalSum, result,
+                                           ValueRange{i, j});
+
rewriter.replaceOp(op, result);
return success();
}
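With this change the pattern emits a genuine (i, j, k) affine loop nest instead of straight-line code generated by compile-time C++ loops. For the 2x3 times 3x2 case the output has the following shape (a sketch matching the FileCheck expectations added in the later test patches; SSA names are illustrative):

    affine.for %i = 0 to 2 {
      affine.for %j = 0 to 2 {
        affine.store %cst, %alloc_7[%i, %j] : memref<2x2xf64>
        affine.for %k = 0 to 3 {
          %0 = affine.load %alloc_6[%i, %k] : memref<2x3xf64>
          %1 = affine.load %alloc[%k, %j] : memref<3x2xf64>
          %2 = arith.mulf %0, %1 : f64
          %3 = arith.addf %2, %cst : f64
          %4 = affine.load %alloc_7[%i, %j] : memref<2x2xf64>
          %5 = arith.addf %3, %4 : f64
          affine.store %5, %alloc_7[%i, %j] : memref<2x2xf64>
        }
      }
    }

Note the accumulator lives in memory: each inner iteration reloads result[i, j], adds the new product, and stores it back.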
>From 239f031ec9d499d4fa794e01b8ce85e9955cfe6b Mon Sep 17 00:00:00 2001
From: Anirudh-Sathish <anirudhsathish at gmail.com>
Date: Tue, 25 Jun 2024 13:31:51 +0000
Subject: [PATCH 05/13] Test: Add FileCheck commands for test cases
---
mlir/test/Examples/Toy/Matmul/matmul1.toy | 15 +++++++-
mlir/test/Examples/Toy/Matmul/matmul2.mlir | 45 ++++++++++++++++++++++
mlir/test/Examples/Toy/Matmul/matmul2.toy | 12 ++++++
mlir/test/Examples/Toy/Matmul/matmul3.mlir | 45 ++++++++++++++++++++++
mlir/test/Examples/Toy/Matmul/matmul3.toy | 11 ++++++
5 files changed, 127 insertions(+), 1 deletion(-)
diff --git a/mlir/test/Examples/Toy/Matmul/matmul1.toy b/mlir/test/Examples/Toy/Matmul/matmul1.toy
index 823608a6a2f02..136985d26a073 100644
--- a/mlir/test/Examples/Toy/Matmul/matmul1.toy
+++ b/mlir/test/Examples/Toy/Matmul/matmul1.toy
@@ -1,6 +1,19 @@
+# RUN: toyc-ch7 %s -emit=mlir 2>&1 | FileCheck %s
+
+# Function that performs matrix multiplication
def main() {
var a<2, 3> = [1, 2, 3, 4, 5, 6];
var b<3, 2> = [1, 2, 3, 4, 5, 6];
var c = matmul(a,b);
print(c);
-}
\ No newline at end of file
+}
+
+#CHECK-LABEL: toy.func @main() {
+# CHECK-NEXT: %0 = toy.constant dense<[1.000000e+00, 2.000000e+00, 3.000000e+00, 4.000000e+00, 5.000000e+00, 6.000000e+00]> : tensor<6xf64>
+# CHECK-NEXT: %1 = toy.reshape(%0 : tensor<6xf64>) to tensor<2x3xf64>
+# CHECK-NEXT: %2 = toy.constant dense<[1.000000e+00, 2.000000e+00, 3.000000e+00, 4.000000e+00, 5.000000e+00, 6.000000e+00]> : tensor<6xf64>
+# CHECK-NEXT: %3 = toy.reshape(%2 : tensor<6xf64>) to tensor<3x2xf64>
+# CHECK-NEXT: %4 = toy.matmul %1, %3 : tensor<2x3xf64>, tensor<3x2xf64> -> tensor<*xf64>
+# CHECK-NEXT: toy.print %4 : tensor<*xf64>
+# CHECK-NEXT: toy.return
+# CHECK-NEXT: }
diff --git a/mlir/test/Examples/Toy/Matmul/matmul2.mlir b/mlir/test/Examples/Toy/Matmul/matmul2.mlir
index e4145c8bb5dfa..a1a605817b575 100644
--- a/mlir/test/Examples/Toy/Matmul/matmul2.mlir
+++ b/mlir/test/Examples/Toy/Matmul/matmul2.mlir
@@ -1,3 +1,5 @@
+// RUN: toyc-ch7 %s -emit=mlir-affine 2>&1 | FileCheck %s
+
module {
toy.func @main() {
%0 = toy.constant dense<[1.000000e+00, 2.000000e+00, 3.000000e+00, 4.000000e+00, 5.000000e+00, 6.000000e+00]> : tensor<6xf64>
@@ -9,3 +11,46 @@ module {
toy.return
}
}
+
+//CHECK-LABEL: func @main
+//CHECK-NEXT: %cst = arith.constant 0.000000e+00 : f64
+//CHECK-NEXT: %cst_0 = arith.constant 6.000000e+00 : f64
+//CHECK-NEXT: %cst_1 = arith.constant 5.000000e+00 : f64
+//CHECK-NEXT: %cst_2 = arith.constant 4.000000e+00 : f64
+//CHECK-NEXT: %cst_3 = arith.constant 3.000000e+00 : f64
+//CHECK-NEXT: %cst_4 = arith.constant 2.000000e+00 : f64
+//CHECK-NEXT: %cst_5 = arith.constant 1.000000e+00 : f64
+//CHECK-NEXT: %alloc = memref.alloc() : memref<1x6xf64>
+//CHECK-NEXT: %alloc_6 = memref.alloc() : memref<6x1xf64>
+//CHECK-NEXT: affine.store %cst_5, %alloc_6[0, 0] : memref<6x1xf64>
+//CHECK-NEXT: affine.store %cst_4, %alloc_6[1, 0] : memref<6x1xf64>
+//CHECK-NEXT: affine.store %cst_3, %alloc_6[2, 0] : memref<6x1xf64>
+//CHECK-NEXT: affine.store %cst_2, %alloc_6[3, 0] : memref<6x1xf64>
+//CHECK-NEXT: affine.store %cst_1, %alloc_6[4, 0] : memref<6x1xf64>
+//CHECK-NEXT: affine.store %cst_0, %alloc_6[5, 0] : memref<6x1xf64>
+//CHECK-NEXT: affine.store %cst_5, %alloc[0, 0] : memref<1x6xf64>
+//CHECK-NEXT: affine.store %cst_4, %alloc[0, 1] : memref<1x6xf64>
+//CHECK-NEXT: affine.store %cst_3, %alloc[0, 2] : memref<1x6xf64>
+//CHECK-NEXT: affine.store %cst_2, %alloc[0, 3] : memref<1x6xf64>
+//CHECK-NEXT: affine.store %cst_1, %alloc[0, 4] : memref<1x6xf64>
+//CHECK-NEXT: affine.store %cst_0, %alloc[0, 5] : memref<1x6xf64>
+//CHECK-NEXT: %alloc_7 = memref.alloc() : memref<6x6xf64>
+//CHECK-NEXT: affine.for %arg0 = 0 to 6 {
+//CHECK-NEXT: affine.for %arg1 = 0 to 6 {
+//CHECK-NEXT: affine.store %cst, %alloc_7[%arg0, %arg1] : memref<6x6xf64>
+//CHECK-NEXT: affine.for %arg2 = 0 to 1 {
+//CHECK-NEXT: %0 = affine.load %alloc_6[%arg0, %arg2] : memref<6x1xf64>
+//CHECK-NEXT: %1 = affine.load %alloc[%arg2, %arg1] : memref<1x6xf64>
+//CHECK-NEXT: %2 = arith.mulf %0, %1 : f64
+//CHECK-NEXT: %3 = arith.addf %2, %cst : f64
+//CHECK-NEXT: %4 = affine.load %alloc_7[%arg0, %arg1] : memref<6x6xf64>
+//CHECK-NEXT: %5 = arith.addf %3, %4 : f64
+//CHECK-NEXT: affine.store %5, %alloc_7[%arg0, %arg1] : memref<6x6xf64>
+//CHECK-NEXT: }
+//CHECK-NEXT: }
+//CHECK-NEXT: }
+//CHECK-NEXT: toy.print %alloc_7 : memref<6x6xf64>
+//CHECK-NEXT: memref.dealloc %alloc_6 : memref<6x1xf64>
+//CHECK-NEXT: memref.dealloc %alloc : memref<1x6xf64>
+//CHECK-NEXT: return
+//CHECK-NEXT: }
diff --git a/mlir/test/Examples/Toy/Matmul/matmul2.toy b/mlir/test/Examples/Toy/Matmul/matmul2.toy
index 6c37cc1af1f6a..d72c35be4c692 100644
--- a/mlir/test/Examples/Toy/Matmul/matmul2.toy
+++ b/mlir/test/Examples/Toy/Matmul/matmul2.toy
@@ -1,6 +1,18 @@
+# RUN: toyc-ch7 %s -emit=mlir 2>&1 | FileCheck %s
+
def main() {
var a<6, 1> = [1, 2, 3, 4, 5, 6];
var b<1, 6> = [1, 2, 3, 4, 5, 6];
var c = matmul(a,b);
print(c);
}
+
+#CHECK-LABEL: toy.func @main() {
+# CHECK-NEXT: %0 = toy.constant dense<[1.000000e+00, 2.000000e+00, 3.000000e+00, 4.000000e+00, 5.000000e+00, 6.000000e+00]> : tensor<6xf64>
+# CHECK-NEXT: %1 = toy.reshape(%0 : tensor<6xf64>) to tensor<6x1xf64>
+# CHECK-NEXT: %2 = toy.constant dense<[1.000000e+00, 2.000000e+00, 3.000000e+00, 4.000000e+00, 5.000000e+00, 6.000000e+00]> : tensor<6xf64>
+# CHECK-NEXT: %3 = toy.reshape(%2 : tensor<6xf64>) to tensor<1x6xf64>
+# CHECK-NEXT: %4 = toy.matmul %1, %3 : tensor<6x1xf64>, tensor<1x6xf64> -> tensor<*xf64>
+# CHECK-NEXT: toy.print %4 : tensor<*xf64>
+# CHECK-NEXT: toy.return
+# CHECK-NEXT: }
diff --git a/mlir/test/Examples/Toy/Matmul/matmul3.mlir b/mlir/test/Examples/Toy/Matmul/matmul3.mlir
index 5e8fd5eba2398..8573534375f8d 100644
--- a/mlir/test/Examples/Toy/Matmul/matmul3.mlir
+++ b/mlir/test/Examples/Toy/Matmul/matmul3.mlir
@@ -1,3 +1,4 @@
+// RUN: toyc-ch7 %s -emit=mlir-affine 2>&1 | FileCheck %s
module {
toy.func @main() {
%0 = toy.constant dense<[1.000000e+00, 2.000000e+00, 3.000000e+00, 4.000000e+00, 5.000000e+00, 6.000000e+00]> : tensor<6xf64>
@@ -9,3 +10,47 @@ module {
toy.return
}
}
+
+//CHECK-LABEL: func @main
+//CHECK-NEXT: %cst = arith.constant 0.000000e+00 : f64
+//CHECK-NEXT: %cst_0 = arith.constant 6.000000e+00 : f64
+//CHECK-NEXT: %cst_1 = arith.constant 5.000000e+00 : f64
+//CHECK-NEXT: %cst_2 = arith.constant 4.000000e+00 : f64
+//CHECK-NEXT: %cst_3 = arith.constant 3.000000e+00 : f64
+//CHECK-NEXT: %cst_4 = arith.constant 2.000000e+00 : f64
+//CHECK-NEXT: %cst_5 = arith.constant 1.000000e+00 : f64
+//CHECK-NEXT: %alloc = memref.alloc() : memref<1x6xf64>
+//CHECK-NEXT: %alloc_6 = memref.alloc() : memref<6x1xf64>
+//CHECK-NEXT: affine.store %cst_5, %alloc_6[0, 0] : memref<6x1xf64>
+//CHECK-NEXT: affine.store %cst_4, %alloc_6[1, 0] : memref<6x1xf64>
+//CHECK-NEXT: affine.store %cst_3, %alloc_6[2, 0] : memref<6x1xf64>
+//CHECK-NEXT: affine.store %cst_2, %alloc_6[3, 0] : memref<6x1xf64>
+//CHECK-NEXT: affine.store %cst_1, %alloc_6[4, 0] : memref<6x1xf64>
+//CHECK-NEXT: affine.store %cst_0, %alloc_6[5, 0] : memref<6x1xf64>
+//CHECK-NEXT: affine.store %cst_5, %alloc[0, 0] : memref<1x6xf64>
+//CHECK-NEXT: affine.store %cst_4, %alloc[0, 1] : memref<1x6xf64>
+//CHECK-NEXT: affine.store %cst_3, %alloc[0, 2] : memref<1x6xf64>
+//CHECK-NEXT: affine.store %cst_2, %alloc[0, 3] : memref<1x6xf64>
+//CHECK-NEXT: affine.store %cst_1, %alloc[0, 4] : memref<1x6xf64>
+//CHECK-NEXT: affine.store %cst_0, %alloc[0, 5] : memref<1x6xf64>
+//CHECK-NEXT: %alloc_7 = memref.alloc() : memref<1x1xf64>
+//CHECK-NEXT: affine.for %arg0 = 0 to 1 {
+//CHECK-NEXT: affine.for %arg1 = 0 to 1 {
+//CHECK-NEXT: affine.store %cst, %alloc_7[%arg0, %arg1] : memref<1x1xf64>
+//CHECK-NEXT: affine.for %arg2 = 0 to 6 {
+//CHECK-NEXT: %0 = affine.load %alloc[%arg0, %arg2] : memref<1x6xf64>
+//CHECK-NEXT: %1 = affine.load %alloc_6[%arg2, %arg1] : memref<6x1xf64>
+//CHECK-NEXT: %2 = arith.mulf %0, %1 : f64
+//CHECK-NEXT: %3 = arith.addf %2, %cst : f64
+//CHECK-NEXT: %4 = affine.load %alloc_7[%arg0, %arg1] : memref<1x1xf64>
+//CHECK-NEXT: %5 = arith.addf %3, %4 : f64
+//CHECK-NEXT: affine.store %5, %alloc_7[%arg0, %arg1] : memref<1x1xf64>
+//CHECK-NEXT: }
+//CHECK-NEXT: }
+//CHECK-NEXT: }
+//CHECK-NEXT: toy.print %alloc_7 : memref<1x1xf64>
+//CHECK-NEXT: memref.dealloc %alloc_6 : memref<6x1xf64>
+//CHECK-NEXT: memref.dealloc %alloc : memref<1x6xf64>
+//CHECK-NEXT: return
+//CHECK-NEXT: }
+
diff --git a/mlir/test/Examples/Toy/Matmul/matmul3.toy b/mlir/test/Examples/Toy/Matmul/matmul3.toy
index d268735176e7f..a56c6e31cacb8 100644
--- a/mlir/test/Examples/Toy/Matmul/matmul3.toy
+++ b/mlir/test/Examples/Toy/Matmul/matmul3.toy
@@ -1,3 +1,5 @@
+# RUN: toyc-ch7 %s -emit=mlir 2>&1 | FileCheck %s
+
def main() {
var a<6, 1> = [1, 2, 3, 4, 5, 6];
var b<1, 6> = [1, 2, 3, 4, 5, 6];
@@ -5,3 +7,12 @@ def main() {
print(c);
}
+#CHECK-LABEL: toy.func @main() {
+# CHECK-NEXT: %0 = toy.constant dense<[1.000000e+00, 2.000000e+00, 3.000000e+00, 4.000000e+00, 5.000000e+00, 6.000000e+00]> : tensor<6xf64>
+# CHECK-NEXT: %1 = toy.reshape(%0 : tensor<6xf64>) to tensor<6x1xf64>
+# CHECK-NEXT: %2 = toy.constant dense<[1.000000e+00, 2.000000e+00, 3.000000e+00, 4.000000e+00, 5.000000e+00, 6.000000e+00]> : tensor<6xf64>
+# CHECK-NEXT: %3 = toy.reshape(%2 : tensor<6xf64>) to tensor<1x6xf64>
+# CHECK-NEXT: %4 = toy.matmul %3, %1 : tensor<1x6xf64>, tensor<6x1xf64> -> tensor<*xf64>
+# CHECK-NEXT: toy.print %4 : tensor<*xf64>
+# CHECK-NEXT: toy.return
+# CHECK-NEXT: }
>From 68ea36a89ea3ab60a5371d978488d7c04b7fd7cf Mon Sep 17 00:00:00 2001
From: Anirudh-Sathish <anirudhsathish at gmail.com>
Date: Tue, 25 Jun 2024 13:48:41 +0000
Subject: [PATCH 06/13] Feat: Add FileCheck for matmul1.mlir
---
mlir/test/Examples/Toy/Matmul/matmul1.mlir | 45 ++++++++++++++++++++++
1 file changed, 45 insertions(+)
diff --git a/mlir/test/Examples/Toy/Matmul/matmul1.mlir b/mlir/test/Examples/Toy/Matmul/matmul1.mlir
index 956da787f9b36..0a6fbf5be7087 100644
--- a/mlir/test/Examples/Toy/Matmul/matmul1.mlir
+++ b/mlir/test/Examples/Toy/Matmul/matmul1.mlir
@@ -1,3 +1,5 @@
+// RUN: toyc-ch7 %s -emit=mlir-affine 2>&1 | FileCheck %s
+
module {
toy.func @main() {
%0 = toy.constant dense<[1.000000e+00, 2.000000e+00, 3.000000e+00, 4.000000e+00, 5.000000e+00, 6.000000e+00]> : tensor<6xf64>
@@ -9,3 +11,46 @@ module {
toy.return
}
}
+
+//CHECK-LABEL: func @main
+//CHECK-NEXT: %cst = arith.constant 0.000000e+00 : f64
+//CHECK-NEXT: %cst_0 = arith.constant 6.000000e+00 : f64
+//CHECK-NEXT: %cst_1 = arith.constant 5.000000e+00 : f64
+//CHECK-NEXT: %cst_2 = arith.constant 4.000000e+00 : f64
+//CHECK-NEXT: %cst_3 = arith.constant 3.000000e+00 : f64
+//CHECK-NEXT: %cst_4 = arith.constant 2.000000e+00 : f64
+//CHECK-NEXT: %cst_5 = arith.constant 1.000000e+00 : f64
+//CHECK-NEXT: %alloc = memref.alloc() : memref<3x2xf64>
+//CHECK-NEXT: %alloc_6 = memref.alloc() : memref<2x3xf64>
+//CHECK-NEXT: affine.store %cst_5, %alloc_6[0, 0] : memref<2x3xf64>
+//CHECK-NEXT: affine.store %cst_4, %alloc_6[0, 1] : memref<2x3xf64>
+//CHECK-NEXT: affine.store %cst_3, %alloc_6[0, 2] : memref<2x3xf64>
+//CHECK-NEXT: affine.store %cst_2, %alloc_6[1, 0] : memref<2x3xf64>
+//CHECK-NEXT: affine.store %cst_1, %alloc_6[1, 1] : memref<2x3xf64>
+//CHECK-NEXT: affine.store %cst_0, %alloc_6[1, 2] : memref<2x3xf64>
+//CHECK-NEXT: affine.store %cst_5, %alloc[0, 0] : memref<3x2xf64>
+//CHECK-NEXT: affine.store %cst_4, %alloc[0, 1] : memref<3x2xf64>
+//CHECK-NEXT: affine.store %cst_3, %alloc[1, 0] : memref<3x2xf64>
+//CHECK-NEXT: affine.store %cst_2, %alloc[1, 1] : memref<3x2xf64>
+//CHECK-NEXT: affine.store %cst_1, %alloc[2, 0] : memref<3x2xf64>
+//CHECK-NEXT: affine.store %cst_0, %alloc[2, 1] : memref<3x2xf64>
+//CHECK-NEXT: %alloc_7 = memref.alloc() : memref<2x2xf64>
+//CHECK-NEXT: affine.for %arg0 = 0 to 2 {
+//CHECK-NEXT: affine.for %arg1 = 0 to 2 {
+//CHECK-NEXT: affine.store %cst, %alloc_7[%arg0, %arg1] : memref<2x2xf64>
+//CHECK-NEXT: affine.for %arg2 = 0 to 3 {
+//CHECK-NEXT: %0 = affine.load %alloc_6[%arg0, %arg2] : memref<2x3xf64>
+//CHECK-NEXT: %1 = affine.load %alloc[%arg2, %arg1] : memref<3x2xf64>
+//CHECK-NEXT: %2 = arith.mulf %0, %1 : f64
+//CHECK-NEXT: %3 = arith.addf %2, %cst : f64
+//CHECK-NEXT: %4 = affine.load %alloc_7[%arg0, %arg1] : memref<2x2xf64>
+//CHECK-NEXT: %5 = arith.addf %3, %4 : f64
+//CHECK-NEXT: affine.store %5, %alloc_7[%arg0, %arg1] : memref<2x2xf64>
+//CHECK-NEXT: }
+//CHECK-NEXT: }
+//CHECK-NEXT: }
+//CHECK-NEXT: toy.print %alloc_7 : memref<2x2xf64>
+//CHECK-NEXT: memref.dealloc %alloc_6 : memref<2x3xf64>
+//CHECK-NEXT: memref.dealloc %alloc : memref<3x2xf64>
+//CHECK-NEXT: return
+//CHECK-NEXT: }
>From 283c8c927bbf2d1f5d1cb27efe81f406ec407d89 Mon Sep 17 00:00:00 2001
From: Anirudh-Sathish <anirudhsathish at gmail.com>
Date: Thu, 4 Jul 2024 07:25:41 +0000
Subject: [PATCH 07/13] Feat: Code for Loop unroll in SCF
---
.../mlir/Dialect/SCF/Transforms/Passes.h | 8 +
.../mlir/Dialect/SCF/Transforms/Passes.td | 17 ++
.../lib/Dialect/SCF/Transforms/LoopUnroll.cpp | 183 ++++++++++++++++++
3 files changed, 208 insertions(+)
create mode 100644 mlir/lib/Dialect/SCF/Transforms/LoopUnroll.cpp
diff --git a/mlir/include/mlir/Dialect/SCF/Transforms/Passes.h b/mlir/include/mlir/Dialect/SCF/Transforms/Passes.h
index fb8411418ff9a..d325fa515a0e3 100644
--- a/mlir/include/mlir/Dialect/SCF/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/SCF/Transforms/Passes.h
@@ -62,12 +62,20 @@ std::unique_ptr<Pass> createForLoopRangeFoldingPass();
/// Creates a pass that converts SCF forall loops to SCF for loops.
std::unique_ptr<Pass> createForallToForLoopPass();
+/// Creates a pass that counts the number of operations in SCF.
+std::unique_ptr<Pass> createCounterPass();
+
+/// Creates a pass that unrolls scf.for loops by the given factor.
+std::unique_ptr<Pass> createLoopUnroll(
+    int unrollFactor = 4, bool unrollFull = false);
+
/// Creates a pass that converts SCF forall loops to SCF parallel loops.
std::unique_ptr<Pass> createForallToParallelLoopPass();
// Creates a pass which lowers for loops into while loops.
std::unique_ptr<Pass> createForToWhileLoopPass();
+
//===----------------------------------------------------------------------===//
// Registration
//===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Dialect/SCF/Transforms/Passes.td b/mlir/include/mlir/Dialect/SCF/Transforms/Passes.td
index 9b29affb97c43..694e75bdd49b6 100644
--- a/mlir/include/mlir/Dialect/SCF/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/SCF/Transforms/Passes.td
@@ -125,6 +125,22 @@ def SCFForallToForLoop : Pass<"scf-forall-to-for"> {
let constructor = "mlir::createForallToForLoopPass()";
}
+def SCFCounter : Pass<"scf-counter"> {
+ let summary = "Count operations for SCF";
+ let constructor = "mlir::createCounterPass()";
+}
+
+def SCFLoopUnroll : Pass<"scf-loop-unroll"> {
+  let summary = "Unroll scf.for loops";
+ let constructor = "mlir::createLoopUnroll()";
+ let options = [
+ Option<"unrollFactor", "unroll-factor", "int", /*default=*/"4",
+ "Use this unroll factor for all loops being unrolled">,
+ Option<"unrollFull", "unroll-full", "bool", /*default=*/"false",
+ "Fully unroll loops">,
+ ];
+}
+
def SCFForallToParallelLoop : Pass<"scf-forall-to-parallel"> {
let summary = "Convert SCF forall loops to SCF parallel loops";
let constructor = "mlir::createForallToParallelLoopPass()";
@@ -164,4 +180,5 @@ def SCFForToWhileLoop : Pass<"scf-for-to-while"> {
}];
}
+
#endif // MLIR_DIALECT_SCF_PASSES
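Assuming the usual tablegen option plumbing, the pass is driven from mlir-opt like the other SCF passes. A minimal sketch (this test function is illustrative, not part of the PR):

    // RUN: mlir-opt %s -scf-loop-unroll="unroll-factor=2" | FileCheck %s
    func.func @sum(%buf: memref<10xf32>) -> f32 {
      %c0 = arith.constant 0 : index
      %c10 = arith.constant 10 : index
      %c1 = arith.constant 1 : index
      %zero = arith.constant 0.0 : f32
      %r = scf.for %i = %c0 to %c10 step %c1 iter_args(%acc = %zero) -> (f32) {
        %v = memref.load %buf[%i] : memref<10xf32>
        %s = arith.addf %acc, %v : f32
        scf.yield %s : f32
      }
      return %r : f32
    }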
diff --git a/mlir/lib/Dialect/SCF/Transforms/LoopUnroll.cpp b/mlir/lib/Dialect/SCF/Transforms/LoopUnroll.cpp
new file mode 100644
index 0000000000000..38a8349793d0e
--- /dev/null
+++ b/mlir/lib/Dialect/SCF/Transforms/LoopUnroll.cpp
@@ -0,0 +1,183 @@
+// Unrolls scf.for loops in the SCF dialect.
+
+#include "mlir/Dialect/SCF/Transforms/Passes.h"
+
+#include "mlir/Dialect/SCF/IR/SCF.h"
+#include "mlir/Dialect/SCF/Transforms/Transforms.h"
+#include "mlir/IR/BuiltinOps.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Support/LLVM.h"
+#include "mlir/Pass/PassRegistry.h"
+#include "mlir/Pass/PassManager.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/IR/PatternMatch.h"
+namespace mlir {
+#define GEN_PASS_DEF_SCFLOOPUNROLL
+#include "mlir/Dialect/SCF/Transforms/Passes.h.inc"
+} // namespace mlir
+using namespace llvm;
+using namespace mlir;
+using namespace mlir::scf;
+using scf::ForOp;
+
+namespace {
+struct LoopUnroll : public impl::SCFLoopUnrollBase<LoopUnroll> {
+private:
+  int unrollFactor = 4;
+  bool unrollFull = false;
+
+public:
+  LoopUnroll() = default;
+  LoopUnroll(int unrollFactor, bool unrollFull)
+      : unrollFactor(unrollFactor), unrollFull(unrollFull) {}
+ void runOnOperation() override {
+ Operation *parentOp = getOperation();
+ OpBuilder builder(parentOp->getContext());
+ SmallVector<ForOp, 4> loops;
+    gatherInnermostLoops(parentOp, loops);
+    if (loops.empty())
+      return;
+    for (ForOp forOp : loops)
+      (void)runOnSCFForOp(builder, forOp);
+ }
+  Value createUpdatedInductionVar(unsigned i, Value iv, OpBuilder &builder,
+                                  int64_t step) {
+ Location loc = iv.getLoc();
+ Value constantI = builder.create<arith::ConstantIndexOp>(loc, i);
+ Value constantStep = builder.create<arith::ConstantIndexOp>(loc, step);
+ Value increment = builder.create<arith::MulIOp>(loc, constantI, constantStep);
+ Value updatedIv = builder.create<arith::AddIOp>(loc, iv, increment);
+ return updatedIv;
+ }
+
+  void generateUnroll(ForOp forOp, int64_t step) {
+    Block *loopBody = forOp.getBody();
+    auto returnValues = forOp.getBody()->getTerminator()->getOperands();
+    Value forOpInductionVar = forOp.getInductionVar();
+    ValueRange iterArgs(forOp.getRegionIterArgs());
+    auto builder = OpBuilder::atBlockTerminator(loopBody);
+    // Last op of the original body; the terminator itself sits at
+    // std::prev(loopBody->end(), 1).
+    Block::iterator originalBlockEnd = std::prev(loopBody->end(), 2);
+    SmallVector<Value, 4> lastReturnValues(returnValues);
+    for (int i = 1; i < unrollFactor; i++) {
+      IRMapping mapper;
+      // Feed the previous copy's yielded values into this copy's iter_args.
+      mapper.map(iterArgs, lastReturnValues);
+      if (!forOpInductionVar.use_empty()) {
+        // Remap the induction variable to iv + i * step for the i-th copy.
+        Value updatedInductionVar =
+            createUpdatedInductionVar(i, forOpInductionVar, builder, step);
+        mapper.map(forOpInductionVar, updatedInductionVar);
+      }
+      // Clone every op of the original body except the terminator.
+      for (auto it = loopBody->begin(); it != std::next(originalBlockEnd); it++)
+        builder.clone(*it, mapper);
+      // Track the values this copy yields.
+      for (unsigned j = 0, e = lastReturnValues.size(); j < e; j++) {
+        Operation *defOp = returnValues[j].getDefiningOp();
+        if (defOp && defOp->getBlock() == loopBody)
+          lastReturnValues[j] = mapper.lookup(returnValues[j]);
+      }
+    }
+    loopBody->getTerminator()->setOperands(lastReturnValues);
+  }
+  void fullUnroll(ForOp forOp, int64_t tripCount, int64_t step) {
+    IRRewriter rewriter(forOp.getContext());
+    if (tripCount == 0)
+      return;
+    if (tripCount == 1) {
+      (void)forOp.promoteIfSingleIteration(rewriter);
+      return;
+    }
+    // Fully unrolling is unrolling by the entire trip count.
+    unrollFactor = tripCount;
+    generateUnroll(forOp, step);
+  }
+ LogicalResult runOnSCFForOp(OpBuilder &builder, ForOp forOp) {
+ IRRewriter rewriter(forOp.getContext());
+ Value lowerBound = forOp.getLowerBound();
+ Value upperBound = forOp.getUpperBound();
+ Value step = forOp.getStep();
+ Value newStepValue,upperBoundUnrolled;
+ bool createHandlerLoop = false;
+ auto lowerBoundConst = lowerBound.getDefiningOp<arith::ConstantIndexOp>();
+ auto upperBoundConst = upperBound.getDefiningOp<arith::ConstantIndexOp>();
+ auto stepConst = step.getDefiningOp<arith::ConstantIndexOp>();
+ if (!lowerBoundConst || !upperBoundConst || !stepConst) {
+ forOp.emitError("Expected constant bounds and step for unrolling.");
+ return failure();
+ }
+ int64_t lowerBoundValue = lowerBoundConst.value();
+ int64_t upperBoundValue = upperBoundConst.value();
+ int64_t stepValue = stepConst.value();
+ int64_t tripCount = (upperBoundValue-lowerBoundValue)/stepValue;
+ int64_t multipliedStepValue = stepValue * unrollFactor;
+ int64_t tripCountUnrolled = tripCount -(tripCount%unrollFactor);
+ int64_t unrolledUpperBound = lowerBoundValue +(tripCountUnrolled*stepValue);
+ createHandlerLoop = unrolledUpperBound < upperBoundValue;
+ if (Block *prevBlock = forOp->getBlock()->getPrevNode())builder.setInsertionPointToEnd(prevBlock);
+ else builder.setInsertionPoint(forOp);
+ if(createHandlerLoop)
+ {
+ upperBoundUnrolled = builder.create<arith::ConstantIndexOp>(forOp.getLoc(),unrolledUpperBound);
+ newStepValue = builder.create<arith::ConstantIndexOp>(
+ forOp.getLoc(), multipliedStepValue);
+ builder.setInsertionPoint(forOp);
+ }
+ else{
+ upperBoundUnrolled = upperBound;
+ newStepValue = step;
+ }
+ if(createHandlerLoop)
+ {
+ OpBuilder loopBuilder(forOp->getContext());
+ loopBuilder.setInsertionPoint(forOp->getBlock(),std::next(Block::iterator(forOp)));
+ auto handlerForOp = cast<scf::ForOp>(loopBuilder.clone(*forOp));
+ handlerForOp.setLowerBound(upperBoundUnrolled);
+
+ auto results = forOp.getResults();
+ auto handlerResults = handlerForOp.getResults();
+ for (auto element : llvm::zip(results, handlerResults)) {
+ std::get<0>(element).replaceAllUsesWith(std::get<1>(element));
+ }
+ handlerForOp->setOperands(handlerForOp.getNumControlOperands(),
+ handlerForOp.getInitArgs().size(), results);
+ (void)handlerForOp.promoteIfSingleIteration(rewriter);
+
+ }
+      forOp.setUpperBound(upperBoundUnrolled);
+ forOp.setStep(newStepValue);
+ if(unrollFull == false)generateUnroll(forOp,stepValue);
+ else fullUnroll(forOp,tripCount,stepValue);
+ (void)forOp.promoteIfSingleIteration(rewriter);
+ llvm::errs() << "Completed unroll\n";
+ return success();
+ }
+ static bool isInnermostSCFForOp(ForOp op) {
+ return !op.getBody()
+ ->walk([&](ForOp nestedForOp) {
+ return WalkResult::interrupt();
+ })
+ .wasInterrupted();
+ }
+
+ static void gatherInnermostLoops(Operation *op,
+ SmallVectorImpl<ForOp> &loops) {
+ op->walk([&](ForOp forOp) {
+ if (isInnermostSCFForOp(forOp))
+ loops.push_back(forOp);
+ });
+ }
+
+ virtual ~LoopUnroll() = default;
+};
+} // end anonymous namespace
+
+
+std::unique_ptr<Pass> mlir::createLoopUnroll(int unrollFactor, bool unrollFull) {
+ return std::make_unique<LoopUnroll>(unrollFactor,unrollFull);
+}
\ No newline at end of file
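A quick worked example of the bound computation in runOnSCFForOp above, using the values from the first test below (lb = 0, ub = 10, step = 1, unroll factor 4):

    tripCount          = (10 - 0) / 1  = 10
    tripCountUnrolled  = 10 - (10 % 4) = 8
    unrolledUpperBound = 0 + 8 * 1     = 8
    multipliedStep     = 1 * 4         = 4

so the unrolled loop iterates over [0, 8) with step 4, and the cloned handler loop picks up the remainder [8, 10) with the original step.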
>From e3c1232602825974772e6d9bf7577c6f2ba2633f Mon Sep 17 00:00:00 2001
From: Anirudh-Sathish <anirudhsathish at gmail.com>
Date: Thu, 4 Jul 2024 10:01:18 +0000
Subject: [PATCH 08/13] Test: Add test cases
---
mlir/test/Dialect/SCF/loop-unroll-scf.mlir | 38 +++++++++
mlir/test/Dialect/SCF/loop-unroll-scf2.mlir | 90 +++++++++++++++++++++
mlir/test/Dialect/SCF/loop-unroll-scf3.mlir | 50 ++++++++++++
3 files changed, 178 insertions(+)
create mode 100644 mlir/test/Dialect/SCF/loop-unroll-scf.mlir
create mode 100644 mlir/test/Dialect/SCF/loop-unroll-scf2.mlir
create mode 100644 mlir/test/Dialect/SCF/loop-unroll-scf3.mlir
diff --git a/mlir/test/Dialect/SCF/loop-unroll-scf.mlir b/mlir/test/Dialect/SCF/loop-unroll-scf.mlir
new file mode 100644
index 0000000000000..d0fb3fd0aa41c
--- /dev/null
+++ b/mlir/test/Dialect/SCF/loop-unroll-scf.mlir
@@ -0,0 +1,38 @@
+// RUN: mlir-opt %s --scf-loop-unroll --split-input-file | FileCheck %s
+module {
+ func.func @main() -> f32 {
+ %sum = arith.constant 0.0 : f32
+ %val = arith.constant 2.0 : f32
+ %N = arith.constant 10 : index
+
+ %c0 = arith.constant 0 : index
+ %c1 = arith.constant 1 : index
+
+ %result = scf.for %i = %c0 to %N step %c1 iter_args(%iter_sum = %sum) -> (f32) {
+ %new_sum = arith.mulf %iter_sum, %val : f32
+ scf.yield %new_sum : f32
+ }
+ return %result : f32
+ }
+}
+//CHECK-LABEL: func.func @main() -> f32 {
+//CHECK-NEXT: %cst = arith.constant 0.000000e+00 : f32
+//CHECK-NEXT: %cst_0 = arith.constant 2.000000e+00 : f32
+//CHECK-NEXT: %c10 = arith.constant 10 : index
+//CHECK-NEXT: %c0 = arith.constant 0 : index
+//CHECK-NEXT: %c1 = arith.constant 1 : index
+//CHECK-NEXT: %c8 = arith.constant 8 : index
+//CHECK-NEXT: %c4 = arith.constant 4 : index
+//CHECK-NEXT:    %0 = scf.for %arg0 = %c0 to %c8 step %c4 iter_args(%arg1 = %cst) -> (f32) {
+//CHECK-NEXT: %2 = arith.mulf %arg1, %cst_0 : f32
+//CHECK-NEXT: %3 = arith.mulf %2, %cst_0 : f32
+//CHECK-NEXT: %4 = arith.mulf %3, %cst_0 : f32
+//CHECK-NEXT: %5 = arith.mulf %4, %cst_0 : f32
+//CHECK-NEXT: scf.yield %5 : f32
+//CHECK-NEXT: }
+//CHECK-NEXT: %1 = scf.for %arg0 = %c8 to %c10 step %c1 iter_args(%arg1 = %0) -> (f32) {
+//CHECK-NEXT: %2 = arith.mulf %arg1, %cst_0 : f32
+//CHECK-NEXT: scf.yield %2 : f32
+//CHECK-NEXT: }
+//CHECK-NEXT: return
+//CHECK-NEXT: }
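A note on the CHECK lines here and in the following tests: they pin the exact SSA names that mlir-opt prints for these inputs (%0, %c8, %cst_0), which is brittle if unrelated changes renumber values. If reviewers prefer, a capture-based spelling would be more robust, e.g.:

    // CHECK: %[[UB:.*]] = arith.constant 8 : index
    // CHECK: scf.for %[[IV:.*]] = %{{.*}} to %[[UB]] step %{{.*}}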
diff --git a/mlir/test/Dialect/SCF/loop-unroll-scf2.mlir b/mlir/test/Dialect/SCF/loop-unroll-scf2.mlir
new file mode 100644
index 0000000000000..552b8f5fb67ff
--- /dev/null
+++ b/mlir/test/Dialect/SCF/loop-unroll-scf2.mlir
@@ -0,0 +1,90 @@
+// RUN: mlir-opt %s --scf-loop-unroll --split-input-file | FileCheck %s
+
+module {
+ func.func @main() -> f32 {
+ %N = arith.constant 10 : index
+ %val = arith.constant 2.0 : f32
+ %c0 = arith.constant 0 : index
+ %c1 = arith.constant 1 : index
+
+ %array = memref.alloc() : memref<10xf32>
+
+ // Initialize array with %val
+ scf.for %i = %c0 to %N step %c1 {
+ memref.store %val, %array[%i] : memref<10xf32>
+ }
+
+ %sum = arith.constant 0.0 : f32
+
+ %result = scf.for %j = %c0 to %N step %c1 iter_args(%iter_sum = %sum) -> (f32) {
+ %current_val = memref.load %array[%j] : memref<10xf32>
+ %new_sum = arith.addf %iter_sum, %current_val : f32
+ scf.yield %new_sum : f32
+ }
+
+ return %result : f32
+ }
+}
+
+//CHECK-LABEL: func.func @main() -> f32 {
+//CHECK-NEXT: %c10 = arith.constant 10 : index
+//CHECK-NEXT: %cst = arith.constant 2.000000e+00 : f32
+//CHECK-NEXT: %c0 = arith.constant 0 : index
+//CHECK-NEXT: %c1 = arith.constant 1 : index
+//CHECK-NEXT: %alloc = memref.alloc() : memref<10xf32>
+//CHECK-NEXT: %c8 = arith.constant 8 : index
+//CHECK-NEXT: %c4 = arith.constant 4 : index
+//CHECK-NEXT:    scf.for %arg0 = %c0 to %c8 step %c4 {
+//CHECK-NEXT: memref.store %cst, %alloc[%arg0] : memref<10xf32>
+//CHECK-NEXT: %c1_3 = arith.constant 1 : index
+//CHECK-NEXT: %c1_4 = arith.constant 1 : index
+//CHECK-NEXT: %2 = arith.muli %c1_3, %c1_4 : index
+//CHECK-NEXT: %3 = arith.addi %arg0, %2 : index
+//CHECK-NEXT: memref.store %cst, %alloc[%3] : memref<10xf32>
+//CHECK-NEXT: %c2 = arith.constant 2 : index
+//CHECK-NEXT: %c1_5 = arith.constant 1 : index
+//CHECK-NEXT: %4 = arith.muli %c2, %c1_5 : index
+//CHECK-NEXT: %5 = arith.addi %arg0, %4 : index
+//CHECK-NEXT: memref.store %cst, %alloc[%5] : memref<10xf32>
+//CHECK-NEXT: %c3 = arith.constant 3 : index
+//CHECK-NEXT: %c1_6 = arith.constant 1 : index
+//CHECK-NEXT: %6 = arith.muli %c3, %c1_6 : index
+//CHECK-NEXT: %7 = arith.addi %arg0, %6 : index
+//CHECK-NEXT: memref.store %cst, %alloc[%7] : memref<10xf32>
+//CHECK-NEXT: }
+//CHECK-NEXT: scf.for %arg0 = %c8 to %c10 step %c1 {
+//CHECK-NEXT: memref.store %cst, %alloc[%arg0] : memref<10xf32>
+//CHECK-NEXT: }
+//CHECK-NEXT: %cst_0 = arith.constant 0.000000e+00 : f32
+//CHECK-NEXT: %c8_1 = arith.constant 8 : index
+//CHECK-NEXT: %c4_2 = arith.constant 4 : index
+//CHECK-NEXT:    %0 = scf.for %arg0 = %c0 to %c8_1 step %c4_2 iter_args(%arg1 = %cst_0) -> (f32) {
+//CHECK-NEXT: %2 = memref.load %alloc[%arg0] : memref<10xf32>
+//CHECK-NEXT: %3 = arith.addf %arg1, %2 : f32
+//CHECK-NEXT: %c1_3 = arith.constant 1 : index
+//CHECK-NEXT: %c1_4 = arith.constant 1 : index
+//CHECK-NEXT: %4 = arith.muli %c1_3, %c1_4 : index
+//CHECK-NEXT: %5 = arith.addi %arg0, %4 : index
+//CHECK-NEXT: %6 = memref.load %alloc[%5] : memref<10xf32>
+//CHECK-NEXT: %7 = arith.addf %3, %6 : f32
+//CHECK-NEXT: %c2 = arith.constant 2 : index
+//CHECK-NEXT: %c1_5 = arith.constant 1 : index
+//CHECK-NEXT: %8 = arith.muli %c2, %c1_5 : index
+//CHECK-NEXT: %9 = arith.addi %arg0, %8 : index
+//CHECK-NEXT: %10 = memref.load %alloc[%9] : memref<10xf32>
+//CHECK-NEXT: %11 = arith.addf %7, %10 : f32
+//CHECK-NEXT: %c3 = arith.constant 3 : index
+//CHECK-NEXT: %c1_6 = arith.constant 1 : index
+//CHECK-NEXT: %12 = arith.muli %c3, %c1_6 : index
+//CHECK-NEXT: %13 = arith.addi %arg0, %12 : index
+//CHECK-NEXT: %14 = memref.load %alloc[%13] : memref<10xf32>
+//CHECK-NEXT: %15 = arith.addf %11, %14 : f32
+//CHECK-NEXT: scf.yield %15 : f32
+//CHECK-NEXT: }
+//CHECK-NEXT: %1 = scf.for %arg0 = %c8_1 to %c10 step %c1 iter_args(%arg1 = %0) -> (f32) {
+//CHECK-NEXT: %2 = memref.load %alloc[%arg0] : memref<10xf32>
+//CHECK-NEXT: %3 = arith.addf %arg1, %2 : f32
+//CHECK-NEXT: scf.yield %3 : f32
+//CHECK-NEXT: }
+//CHECK-NEXT: return %1 : f32
+//CHECK-NEXT: }
diff --git a/mlir/test/Dialect/SCF/loop-unroll-scf3.mlir b/mlir/test/Dialect/SCF/loop-unroll-scf3.mlir
new file mode 100644
index 0000000000000..7c2f4bdb76cce
--- /dev/null
+++ b/mlir/test/Dialect/SCF/loop-unroll-scf3.mlir
@@ -0,0 +1,50 @@
+// RUN: mlir-opt %s --scf-loop-unroll --split-input-file | FileCheck %s
+module {
+ func.func @main() -> f32 {
+ %sum = arith.constant 0.0 : f32
+ %val = arith.constant 2.0 : f32
+ %N = arith.constant 4 : index
+ %num = arith.constant 10 : index
+ %c0 = arith.constant 0 : index
+ %c1 = arith.constant 1 : index
+
+ %result = scf.for %i = %c0 to %N step %c1 iter_args(%iter_sum = %sum) -> (f32) {
+ %new_sum = arith.addf %iter_sum, %val : f32
+ %result2 = scf.for %j = %c0 to %num step %c1 iter_args(%iter_sum2 = %new_sum) -> (f32) {
+ %new_sum2 = arith.addf %iter_sum2, %val : f32
+ scf.yield %new_sum2 : f32
+ }
+ %new_sum3 = arith.addf %result2, %val : f32
+ scf.yield %new_sum : f32
+ }
+ return %result : f32
+ }
+}
+
+//CHECK-LABEL: func.func @main() -> f32 {
+//CHECK-NEXT: %cst = arith.constant 0.000000e+00 : f32
+//CHECK-NEXT: %cst_0 = arith.constant 2.000000e+00 : f32
+//CHECK-NEXT: %c4 = arith.constant 4 : index
+//CHECK-NEXT: %c10 = arith.constant 10 : index
+//CHECK-NEXT: %c0 = arith.constant 0 : index
+//CHECK-NEXT: %c1 = arith.constant 1 : index
+//CHECK-NEXT: %0 = scf.for %arg0 = %c0 to %c4 step %c1 iter_args(%arg1 = %cst) -> (f32) {
+//CHECK-NEXT: %1 = arith.addf %arg1, %cst_0 : f32
+//CHECK-NEXT: %c8 = arith.constant 8 : index
+//CHECK-NEXT: %c4_1 = arith.constant 4 : index
+//CHECK-NEXT:    %2 = scf.for %arg2 = %c0 to %c8 step %c4_1 iter_args(%arg3 = %1) -> (f32) {
+//CHECK-NEXT: %5 = arith.addf %arg3, %cst_0 : f32
+//CHECK-NEXT: %6 = arith.addf %5, %cst_0 : f32
+//CHECK-NEXT: %7 = arith.addf %6, %cst_0 : f32
+//CHECK-NEXT: %8 = arith.addf %7, %cst_0 : f32
+//CHECK-NEXT: scf.yield %8 : f32
+//CHECK-NEXT: }
+//CHECK-NEXT: %3 = scf.for %arg2 = %c8 to %c10 step %c1 iter_args(%arg3 = %2) -> (f32) {
+//CHECK-NEXT: %5 = arith.addf %arg3, %cst_0 : f32
+//CHECK-NEXT: scf.yield %5 : f32
+//CHECK-NEXT: }
+//CHECK-NEXT: %4 = arith.addf %3, %cst_0 : f32
+//CHECK-NEXT: scf.yield %1 : f32
+//CHECK-NEXT: }
+//CHECK-NEXT: return %0 : f32
+//CHECK-NEXT: }
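To make the nested-loop test easier to read: only the innermost scf.for of a nest is unrolled, because gatherInnermostLoops skips any loop whose body contains another scf.for. Schematically (an illustrative sketch, not FileCheck output):

    scf.for %i = %c0 to %N step %c1 ... {        // outer loop: left untouched
      scf.for %j = %c0 to %num step %c4 ... {    // innermost: step widened, body cloned 4x
        ...
      }
      // handler loop for the remainder iterations of %j follows here
    }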
>From e2e749087678ac53d3bbe7ef23d2d9319714d049 Mon Sep 17 00:00:00 2001
From: Anirudh-Sathish <anirudhsathish at gmail.com>
Date: Sat, 6 Jul 2024 13:49:54 +0000
Subject: [PATCH 09/13] Refactor: Make code formatted and readable
---
.../lib/Dialect/SCF/Transforms/LoopUnroll.cpp | 307 +++++++++---------
1 file changed, 157 insertions(+), 150 deletions(-)
diff --git a/mlir/lib/Dialect/SCF/Transforms/LoopUnroll.cpp b/mlir/lib/Dialect/SCF/Transforms/LoopUnroll.cpp
index 38a8349793d0e..f07c6d12f4c79 100644
--- a/mlir/lib/Dialect/SCF/Transforms/LoopUnroll.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/LoopUnroll.cpp
@@ -2,22 +2,21 @@
// Unrolls loop in the SCF Dialect
// */
-
#include "mlir/Dialect/SCF/Transforms/Passes.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/SCF/Transforms/Transforms.h"
#include "mlir/IR/BuiltinOps.h"
+#include "mlir/IR/PatternMatch.h"
#include "mlir/Pass/Pass.h"
-#include "mlir/Support/LLVM.h"
+#include "mlir/Pass/PassManager.h"
#include "mlir/Pass/PassRegistry.h"
-#include "mlir/Pass/PassManager.h"
-#include "mlir/Dialect/Func/IR/FuncOps.h"
-#include "mlir/IR/PatternMatch.h"
-namespace mlir{
- #define GEN_PASS_DEF_SCFLOOPUNROLL
- #include "mlir/Dialect/SCF/Transforms/Passes.h.inc"
-}
+#include "mlir/Support/LLVM.h"
+namespace mlir {
+#define GEN_PASS_DEF_SCFLOOPUNROLL
+#include "mlir/Dialect/SCF/Transforms/Passes.h.inc"
+} // namespace mlir
using namespace llvm;
using namespace mlir;
using namespace mlir::scf;
@@ -25,159 +24,167 @@ using scf::ForOp;
namespace {
struct LoopUnroll : public impl::SCFLoopUnrollBase<LoopUnroll> {
- private:
- int unrollFactor;
- bool unrollFull;
- public:
- LoopUnroll() = default;
- LoopUnroll(int unrollFactor, bool unrollFull)
- : unrollFactor(unrollFactor), unrollFull(unrollFull) {}
- void runOnOperation() override {
- Operation *parentOp = getOperation();
- OpBuilder builder(parentOp->getContext());
- SmallVector<ForOp, 4> loops;
- gatherInnermostLoops(parentOp,loops);
- if(loops.empty())return;
- for(auto forOp: loops)
- {
- auto result = runOnSCFForOp(builder, forOp);
- }
- }
- Value createUpdatedInductionVar(unsigned i,Value iv,OpBuilder &builder,int64_t step){
- Location loc = iv.getLoc();
- Value constantI = builder.create<arith::ConstantIndexOp>(loc, i);
- Value constantStep = builder.create<arith::ConstantIndexOp>(loc, step);
- Value increment = builder.create<arith::MulIOp>(loc, constantI, constantStep);
- Value updatedIv = builder.create<arith::AddIOp>(loc, iv, increment);
- return updatedIv;
+private:
+ int unrollFactor;
+ bool unrollFull;
+
+public:
+ LoopUnroll() = default;
+ LoopUnroll(int unrollFactor, bool unrollFull)
+ : unrollFactor(unrollFactor), unrollFull(unrollFull) {}
+ void runOnOperation() override {
+ Operation *parentOp = getOperation();
+ OpBuilder builder(parentOp->getContext());
+ SmallVector<ForOp, 4> loops;
+ gatherInnermostLoops(parentOp, loops);
+ if (loops.empty())
+ return;
+ for (auto forOp : loops) {
+ auto result = runOnSCFForOp(builder, forOp);
}
+ }
- void generateUnroll(ForOp forOp, int64_t step)
- {
- Block *loopBody = forOp.getBody();
- auto returnValues = forOp.getBody()->getTerminator()->getOperands();
- Value forOpInductionVar = forOp.getInductionVar();
- ValueRange iterArgs(forOp.getRegionIterArgs());
- auto builder = OpBuilder::atBlockTerminator(loopBody);
- Block::iterator originalBlockEnd = std::prev(loopBody->end(),2);
- SmallVector<Value,4> lastReturnValues(returnValues);
- for(int i = 1 ; i < unrollFactor; i++)
- {
- IRMapping mapper;
- mapper.map(iterArgs,lastReturnValues);
- if(!forOpInductionVar.use_empty())
- {
- Value updatedInductionVar = createUpdatedInductionVar(i, forOpInductionVar,builder,step);
- mapper.map(forOpInductionVar,updatedInductionVar);
- }
- for (auto it = loopBody->begin(); it != std::next(originalBlockEnd); it++) {
- Operation *clonedOp = builder.clone(*it, mapper);
- }
- for(int j = 0; j < lastReturnValues.size(); j++)
- {
- Operation *defOp = returnValues[j].getDefiningOp();
- if(defOp && defOp->getBlock()==loopBody)
- {
- lastReturnValues[j] = mapper.lookup(returnValues[j]);
- }
- }
+ Value createUpdatedInductionVar(unsigned i, Value iv, OpBuilder &builder,
+ int64_t step) {
+    /*Writes IR to update the induction variable*/
+ Location loc = iv.getLoc();
+ Value constantI = builder.create<arith::ConstantIndexOp>(loc, i);
+ Value constantStep = builder.create<arith::ConstantIndexOp>(loc, step);
+ Value increment =
+ builder.create<arith::MulIOp>(loc, constantI, constantStep);
+ Value updatedIv = builder.create<arith::AddIOp>(loc, iv, increment);
+ return updatedIv;
+ }
+
+ void generateUnroll(ForOp forOp, int64_t step) {
+ /*Function unrolls a given loop by the given step*/
+ Block *loopBody = forOp.getBody();
+ auto returnValues = forOp.getBody()->getTerminator()->getOperands();
+ Value forOpInductionVar = forOp.getInductionVar();
+ ValueRange iterArgs(forOp.getRegionIterArgs());
+ auto builder = OpBuilder::atBlockTerminator(loopBody);
+ Block::iterator originalBlockEnd = std::prev(loopBody->end(), 2);
+ SmallVector<Value, 4> lastReturnValues(returnValues);
+ for (int i = 1; i < unrollFactor; i++) {
+ IRMapping mapper;
+ mapper.map(iterArgs, lastReturnValues);
+ if (!forOpInductionVar.use_empty()) {
+ Value updatedInductionVar =
+ createUpdatedInductionVar(i, forOpInductionVar, builder, step);
+ mapper.map(forOpInductionVar, updatedInductionVar);
+ }
+ for (auto it = loopBody->begin(); it != std::next(originalBlockEnd);
+ it++) {
+ Operation *clonedOp = builder.clone(*it, mapper);
+ }
+ for (int j = 0; j < lastReturnValues.size(); j++) {
+ Operation *defOp = returnValues[j].getDefiningOp();
+ if (defOp && defOp->getBlock() == loopBody) {
+ lastReturnValues[j] = mapper.lookup(returnValues[j]);
}
- loopBody->getTerminator()->setOperands(lastReturnValues);
+ }
}
- void fullUnroll(ForOp forOp, int64_t tripCount, int64_t step)
- {
- IRRewriter rewriter(forOp.getContext());
- if(tripCount == 0)return;
- if(tripCount ==1)
- {
- (void)forOp.promoteIfSingleIteration(rewriter);
- return;
- }
- unrollFactor = tripCount;
- generateUnroll(forOp,step);
- return;
+ loopBody->getTerminator()->setOperands(lastReturnValues);
+ }
+ void fullUnroll(ForOp forOp, int64_t tripCount, int64_t step) {
+ /*Function to fully unroll a given scf for Op*/
+ IRRewriter rewriter(forOp.getContext());
+ if (tripCount == 0)
+ return;
+ if (tripCount == 1) {
+ (void)forOp.promoteIfSingleIteration(rewriter);
+ return;
}
- LogicalResult runOnSCFForOp(OpBuilder &builder, ForOp forOp) {
- IRRewriter rewriter(forOp.getContext());
- Value lowerBound = forOp.getLowerBound();
- Value upperBound = forOp.getUpperBound();
- Value step = forOp.getStep();
- Value newStepValue,upperBoundUnrolled;
- bool createHandlerLoop = false;
- auto lowerBoundConst = lowerBound.getDefiningOp<arith::ConstantIndexOp>();
- auto upperBoundConst = upperBound.getDefiningOp<arith::ConstantIndexOp>();
- auto stepConst = step.getDefiningOp<arith::ConstantIndexOp>();
- if (!lowerBoundConst || !upperBoundConst || !stepConst) {
- forOp.emitError("Expected constant bounds and step for unrolling.");
- return failure();
- }
- int64_t lowerBoundValue = lowerBoundConst.value();
- int64_t upperBoundValue = upperBoundConst.value();
- int64_t stepValue = stepConst.value();
- int64_t tripCount = (upperBoundValue-lowerBoundValue)/stepValue;
- int64_t multipliedStepValue = stepValue * unrollFactor;
- int64_t tripCountUnrolled = tripCount -(tripCount%unrollFactor);
- int64_t unrolledUpperBound = lowerBoundValue +(tripCountUnrolled*stepValue);
- createHandlerLoop = unrolledUpperBound < upperBoundValue;
- if (Block *prevBlock = forOp->getBlock()->getPrevNode())builder.setInsertionPointToEnd(prevBlock);
- else builder.setInsertionPoint(forOp);
- if(createHandlerLoop)
- {
- upperBoundUnrolled = builder.create<arith::ConstantIndexOp>(forOp.getLoc(),unrolledUpperBound);
- newStepValue = builder.create<arith::ConstantIndexOp>(
- forOp.getLoc(), multipliedStepValue);
- builder.setInsertionPoint(forOp);
- }
- else{
- upperBoundUnrolled = upperBound;
- newStepValue = step;
- }
- if(createHandlerLoop)
- {
- OpBuilder loopBuilder(forOp->getContext());
- loopBuilder.setInsertionPoint(forOp->getBlock(),std::next(Block::iterator(forOp)));
- auto handlerForOp = cast<scf::ForOp>(loopBuilder.clone(*forOp));
- handlerForOp.setLowerBound(upperBoundUnrolled);
-
- auto results = forOp.getResults();
- auto handlerResults = handlerForOp.getResults();
- for (auto element : llvm::zip(results, handlerResults)) {
- std::get<0>(element).replaceAllUsesWith(std::get<1>(element));
- }
- handlerForOp->setOperands(handlerForOp.getNumControlOperands(),
- handlerForOp.getInitArgs().size(), results);
- (void)handlerForOp.promoteIfSingleIteration(rewriter);
-
- }
-      forOp.setUpperBound(upperBoundUnrolled);
- forOp.setStep(newStepValue);
- if(unrollFull == false)generateUnroll(forOp,stepValue);
- else fullUnroll(forOp,tripCount,stepValue);
- (void)forOp.promoteIfSingleIteration(rewriter);
- llvm::errs() << "Completed unroll\n";
- return success();
+ unrollFactor = tripCount;
+ generateUnroll(forOp, step);
+ return;
+ }
+ LogicalResult runOnSCFForOp(OpBuilder &builder, ForOp forOp) {
+    /*Processes a single SCF ForOp*/
+ IRRewriter rewriter(forOp.getContext());
+ Value lowerBound = forOp.getLowerBound();
+ Value upperBound = forOp.getUpperBound();
+ Value step = forOp.getStep();
+ Value newStepValue, upperBoundUnrolled;
+ bool createHandlerLoop = false;
+ auto lowerBoundConst = lowerBound.getDefiningOp<arith::ConstantIndexOp>();
+ auto upperBoundConst = upperBound.getDefiningOp<arith::ConstantIndexOp>();
+ auto stepConst = step.getDefiningOp<arith::ConstantIndexOp>();
+ if (!lowerBoundConst || !upperBoundConst || !stepConst) {
+ forOp.emitError("Expected constant bounds and step for unrolling.");
+ return failure();
+ }
+ int64_t lowerBoundValue = lowerBoundConst.value();
+ int64_t upperBoundValue = upperBoundConst.value();
+ int64_t stepValue = stepConst.value();
+ int64_t tripCount = (upperBoundValue - lowerBoundValue) / stepValue;
+ int64_t multipliedStepValue = stepValue * unrollFactor;
+ int64_t tripCountUnrolled = tripCount - (tripCount % unrollFactor);
+ int64_t unrolledUpperBound =
+ lowerBoundValue + (tripCountUnrolled * stepValue);
+ createHandlerLoop = unrolledUpperBound < upperBoundValue;
+ if (Block *prevBlock = forOp->getBlock()->getPrevNode())
+ builder.setInsertionPointToEnd(prevBlock);
+ else
+ builder.setInsertionPoint(forOp);
+ if (createHandlerLoop) {
+ upperBoundUnrolled = builder.create<arith::ConstantIndexOp>(
+ forOp.getLoc(), unrolledUpperBound);
+ newStepValue = builder.create<arith::ConstantIndexOp>(
+ forOp.getLoc(), multipliedStepValue);
+ builder.setInsertionPoint(forOp);
+ } else {
+ upperBoundUnrolled = upperBound;
+ newStepValue = step;
}
- static bool isInnermostSCFForOp(ForOp op) {
- return !op.getBody()
- ->walk([&](ForOp nestedForOp) {
- return WalkResult::interrupt();
- })
- .wasInterrupted();
+ if (createHandlerLoop) {
+ OpBuilder loopBuilder(forOp->getContext());
+ loopBuilder.setInsertionPoint(forOp->getBlock(),
+ std::next(Block::iterator(forOp)));
+ auto handlerForOp = cast<scf::ForOp>(loopBuilder.clone(*forOp));
+ handlerForOp.setLowerBound(upperBoundUnrolled);
+
+ auto results = forOp.getResults();
+ auto handlerResults = handlerForOp.getResults();
+ for (auto element : llvm::zip(results, handlerResults)) {
+ std::get<0>(element).replaceAllUsesWith(std::get<1>(element));
+ }
+ handlerForOp->setOperands(handlerForOp.getNumControlOperands(),
+ handlerForOp.getInitArgs().size(), results);
+ (void)handlerForOp.promoteIfSingleIteration(rewriter);
}
+    forOp.setUpperBound(upperBoundUnrolled);
+ forOp.setStep(newStepValue);
+ if (unrollFull == false)
+ generateUnroll(forOp, stepValue);
+ else
+ fullUnroll(forOp, tripCount, stepValue);
+ (void)forOp.promoteIfSingleIteration(rewriter);
+ llvm::errs() << "Completed unroll\n";
+ return success();
+ }
+ static bool isInnermostSCFForOp(ForOp op) {
+    /*Identifies if a given for loop is an innermost scf ForOp*/
+ return !op.getBody()
+ ->walk(
+ [&](ForOp nestedForOp) { return WalkResult::interrupt(); })
+ .wasInterrupted();
+ }
- static void gatherInnermostLoops(Operation *op,
- SmallVectorImpl<ForOp> &loops) {
- op->walk([&](ForOp forOp) {
- if (isInnermostSCFForOp(forOp))
+ static void gatherInnermostLoops(Operation *op,
+ SmallVectorImpl<ForOp> &loops) {
+ /*Function gathers all the innermost scf for loops*/
+ op->walk([&](ForOp forOp) {
+ if (isInnermostSCFForOp(forOp))
loops.push_back(forOp);
});
- }
+ }
- virtual ~LoopUnroll() = default;
+ virtual ~LoopUnroll() = default;
};
} // end anonymous namespace
-
-std::unique_ptr<Pass> mlir::createLoopUnroll(int unrollFactor, bool unrollFull) {
- return std::make_unique<LoopUnroll>(unrollFactor,unrollFull);
+std::unique_ptr<Pass> mlir::createLoopUnroll(int unrollFactor,
+ bool unrollFull) {
+ return std::make_unique<LoopUnroll>(unrollFactor, unrollFull);
}
\ No newline at end of file
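This commit is formatting-only; no functional change is intended. The layout follows what clang-format produces with the in-tree LLVM style, so it should be reproducible (assuming clang-format is installed and the repo's .clang-format is picked up) with:

    git clang-format HEAD~1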
>From 2b835b41d21cb208fb21551fd8d0f8f73f3afaee Mon Sep 17 00:00:00 2001
From: Anirudh-Sathish <anirudhsathish at gmail.com>
Date: Tue, 9 Jul 2024 09:11:07 +0000
Subject: [PATCH 10/13] Feat: Update LoopUnrollJam
---
.../mlir/Dialect/SCF/Transforms/Passes.h | 4 +
.../mlir/Dialect/SCF/Transforms/Passes.td | 10 +
.../lib/Dialect/SCF/Transforms/CMakeLists.txt | 3 +
.../Dialect/SCF/Transforms/LoopUnrollJam.cpp | 220 ++++++++++++++++++
4 files changed, 237 insertions(+)
create mode 100644 mlir/lib/Dialect/SCF/Transforms/LoopUnrollJam.cpp
diff --git a/mlir/include/mlir/Dialect/SCF/Transforms/Passes.h b/mlir/include/mlir/Dialect/SCF/Transforms/Passes.h
index d325fa515a0e3..a83925a322834 100644
--- a/mlir/include/mlir/Dialect/SCF/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/SCF/Transforms/Passes.h
@@ -69,6 +69,10 @@ std::unique_ptr<Pass> createCounterPass();
std::unique_ptr<Pass> createLoopUnroll(
int unrollFactor = 4, bool unrollFull = false );
+/// Creates a pass for loop unroll and jam.
+std::unique_ptr<Pass> createLoopUnrollJam(
+    int unrollJamFactor = -1);
+
/// Creates a pass that converts SCF forall loops to SCF parallel loops.
std::unique_ptr<Pass> createForallToParallelLoopPass();
diff --git a/mlir/include/mlir/Dialect/SCF/Transforms/Passes.td b/mlir/include/mlir/Dialect/SCF/Transforms/Passes.td
index 694e75bdd49b6..c5b389e379112 100644
--- a/mlir/include/mlir/Dialect/SCF/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/SCF/Transforms/Passes.td
@@ -141,6 +141,16 @@ def SCFLoopUnroll : Pass<"scf-loop-unroll">{
];
}
+def SCFLoopUnrollJam : Pass<"scf-loop-unroll-jam">
+{
+  let summary = "Unroll and jam SCF ForOps";
+ let constructor = "mlir::createLoopUnrollJam()";
+ let options = [
+ Option<"unrollJamFactor", "unroll-jam-factor", "unsigned","4",
+ "Use this unroll jam factor for all loops (default 4)">,
+ ];
+}
+
def SCFForallToParallelLoop : Pass<"scf-forall-to-parallel"> {
let summary = "Convert SCF forall loops to SCF parallel loops";
let constructor = "mlir::createForallToParallelLoopPass()";
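For reference, the option registered above is driven from mlir-opt the same way as the unroll pass; the test added later in this series uses:

    mlir-opt input.mlir --scf-loop-unroll-jam="unroll-jam-factor=2"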
diff --git a/mlir/lib/Dialect/SCF/Transforms/CMakeLists.txt b/mlir/lib/Dialect/SCF/Transforms/CMakeLists.txt
index d363ffe941fce..a3856080efcd0 100644
--- a/mlir/lib/Dialect/SCF/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/SCF/Transforms/CMakeLists.txt
@@ -17,6 +17,9 @@ add_mlir_dialect_library(MLIRSCFTransforms
TileUsingInterface.cpp
WrapInZeroTripCheck.cpp
UpliftWhileToFor.cpp
+ counter.cpp
+ LoopUnroll.cpp
+ LoopUnrollJam.cpp
ADDITIONAL_HEADER_DIRS
${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/SCF
diff --git a/mlir/lib/Dialect/SCF/Transforms/LoopUnrollJam.cpp b/mlir/lib/Dialect/SCF/Transforms/LoopUnrollJam.cpp
new file mode 100644
index 0000000000000..82ea65f247e8f
--- /dev/null
+++ b/mlir/lib/Dialect/SCF/Transforms/LoopUnrollJam.cpp
@@ -0,0 +1,220 @@
+/*
+File containing pass for loop unroll and jam transformation
+ on scf dialect forOp
+*/
+
+#include "mlir/Dialect/SCF/Transforms/Passes.h"
+
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/SCF/IR/SCF.h"
+#include "mlir/Dialect/SCF/Transforms/Transforms.h"
+#include "mlir/IR/BuiltinOps.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Pass/PassManager.h"
+#include "mlir/Pass/PassRegistry.h"
+#include "mlir/Support/LLVM.h"
+#include "mlir/Interfaces/SideEffectInterfaces.h"
+#include "mlir/Interfaces/LoopLikeInterface.h"
+namespace mlir {
+ #define GEN_PASS_DEF_SCFLOOPUNROLLJAM
+ #include "mlir/Dialect/SCF/Transforms/Passes.h.inc"
+} // namespace mlir
+using namespace llvm;
+using namespace mlir;
+using namespace mlir::scf;
+using scf::ForOp;
+
+
+namespace{
+ struct LoopReduction {
+ arith::AtomicRMWKind kind;
+ unsigned position;
+ Value reducedValue;
+ };
+ struct ForBlockGatherer{
+ SmallVector<std::pair<Block::iterator, Block::iterator>> subBlocks;
+ void walk(Operation *op) {
+      for (Region &region : op->getRegions())
+ for (Block &block : region)
+ walk(block);
+ }
+ void walk(Block &block) {
+ assert(!block.empty() && block.back().hasTrait<OpTrait::IsTerminator>() &&
+ "expected block to have a terminator");
+ for (Block::iterator it = block.begin(), e = std::prev(block.end());
+ it != e;) {
+ Block::iterator subBlockStart = it;
+ while (it != e && !isa<ForOp>(&*it))
+ ++it;
+ if (it != subBlockStart)
+ subBlocks.emplace_back(subBlockStart, std::prev(it));
+ // Process all for ops that appear next.
+ while (it != e && isa<ForOp>(&*it))
+ walk(&*it++);
+ }
+ }
+ };
+ struct LoopUnrollJam : public impl::SCFLoopUnrollJamBase<LoopUnrollJam>
+ {
+ explicit LoopUnrollJam(
+ std::optional<unsigned> unrollJamFactor= std::nullopt){
+ if(unrollJamFactor) this->unrollJamFactor = *unrollJamFactor;
+ }
+ void runOnOperation() override
+ {
+ auto *op = getOperation();
+ op->walk([&](ForOp forOp)
+ {
+ (void)loopUnrollJamByFactor(forOp);
+ });
+ }
+ std::optional<uint64_t> getConstantTripCount(scf::ForOp forOp) {
+ // This is a placeholder implementation. You need to adapt it to your use case.
+ auto lowerBoundConst = forOp.getLowerBound().getDefiningOp<arith::ConstantIndexOp>();
+ auto upperBoundConst = forOp.getUpperBound().getDefiningOp<arith::ConstantIndexOp>();
+ auto stepConst = forOp.getStep().getDefiningOp<arith::ConstantIndexOp>();
+
+ if (!lowerBoundConst || !upperBoundConst || !stepConst) {
+ return std::nullopt;
+ }
+
+ int64_t lowerBoundValue = lowerBoundConst.value();
+ int64_t upperBoundValue = upperBoundConst.value();
+ int64_t stepValue = stepConst.value();
+ return (upperBoundValue - lowerBoundValue) / stepValue;
+ }
+ static bool invariantBounds(scf::ForOp forOp) {
+ auto walkResult = forOp.walk([&](scf::ForOp innerForOp) {
+ if (!forOp.isDefinedOutsideOfLoop(innerForOp.getLowerBound()) ||
+ !forOp.isDefinedOutsideOfLoop(innerForOp.getUpperBound()) ||
+ !forOp.isDefinedOutsideOfLoop(innerForOp.getStep()))
+ return WalkResult::interrupt();
+
+ return WalkResult::advance();
+ });
+ return !walkResult.wasInterrupted();
+ }
+
+ void duplicateArgs(SmallVector<IRMapping> &mapper, SmallVector<ForOp> &newInnerLoops,
+ ForOp &forOp, IRRewriter &rewriter,
+ const SmallVector<ForOp> &innerLoops, unsigned unrollJamFactor) {
+ for (scf::ForOp currentForOp : innerLoops) {
+ SmallVector<Value> iterOperandsList, yeildOperandsList;
+ ValueRange previousIterOperands = currentForOp.getInits();
+ ValueRange previousIterArgs = currentForOp.getRegionIterArgs();
+ ValueRange previousYeildOperands = cast<YieldOp>(currentForOp.getBody()->getTerminator()).getOperands();
+ for (unsigned i = unrollJamFactor - 1; i >= 1; --i) {
+ iterOperandsList.append(previousIterOperands.begin(), previousIterOperands.end());
+ yeildOperandsList.append(previousYeildOperands.begin(), previousYeildOperands.end());
+ }
+ bool forOpReplaced = currentForOp == forOp;
+ scf::ForOp newForOp =
+ cast<scf::ForOp>(*currentForOp.replaceWithAdditionalYields(
+ rewriter, iterOperandsList, /*replaceInitOperandUsesInLoop=*/false,
+ [&](OpBuilder &b, Location loc, ArrayRef<BlockArgument> newBbArgs) {
+ return yeildOperandsList;
+ }));
+ newInnerLoops.push_back(newForOp);
+
+ if (forOpReplaced) forOp = newForOp;
+ ValueRange newIterArgs = newForOp.getRegionIterArgs();
+ unsigned oldNumIterArgs = previousIterArgs.size();
+ ValueRange newResults = newForOp.getResults();
+ unsigned oldNumResults = newResults.size() / unrollJamFactor;
+ for (unsigned i = unrollJamFactor - 1; i >= 1; --i) {
+ for (unsigned j = 0; j < oldNumIterArgs; ++j) {
+ mapper[i - 1].map(newIterArgs[j],
+ newIterArgs[i * oldNumIterArgs + j]);
+ mapper[i - 1].map(newResults[j],
+ newResults[i * oldNumResults + j]);
+ }
+ }
+ }
+ }
+
+ Value createUpdatedInductionVar(unsigned i, Value iv, OpBuilder &builder,
+ Value step) {
+ /*Function writes ir to update the induction variable*/
+ Location loc = iv.getLoc();
+ Value constantI = builder.create<arith::ConstantIndexOp>(loc, i);
+ Value increment =builder.create<arith::MulIOp>(loc, constantI, step);
+ Value updatedIv = builder.create<arith::AddIOp>(loc, iv, increment);
+ return updatedIv;
+ }
+
+ void updateStep(ForOp forOp,IRRewriter &rewriter)
+ {
+ if (Block *prevBlock = forOp->getBlock()->getPrevNode())
+ rewriter.setInsertionPointToEnd(prevBlock);
+ else
+ rewriter.setInsertionPoint(forOp);
+ // rewriter.setInsertionPoint(forOp);
+ auto newStep = rewriter.createOrFold<arith::MulIOp>(
+ forOp.getLoc(), forOp.getStep(),
+ rewriter.createOrFold<arith::ConstantOp>(
+ forOp.getLoc(), rewriter.getIndexAttr(unrollJamFactor)));
+ forOp.setStep(newStep);
+ }
+ void finalUnrollJam(Value forOpInductionVar,
+ SmallVector<std::pair<Block::iterator, Block::iterator>> &subBlocks,
+ SmallVector<IRMapping> &mapper, const SmallVector<ForOp> &newInnerLoops, ForOp forOp) {
+ for (unsigned i = unrollJamFactor - 1; i >= 1; --i) {
+ for (auto &subBlock : subBlocks) {
+ OpBuilder builder(subBlock.first->getBlock(), std::next(subBlock.second));
+ if (!forOpInductionVar.use_empty()) {
+ auto updatedInductionVar = createUpdatedInductionVar(i, forOpInductionVar, builder, forOp.getStep());
+ mapper[i - 1].map(forOpInductionVar, updatedInductionVar);
+ }
+ for (auto it = subBlock.first; it != std::next(subBlock.second); ++it)
+ builder.clone(*it, mapper[i - 1]);
+ }
+ for (auto newForOp : newInnerLoops) {
+ unsigned oldNumIterOperands = newForOp.getNumRegionIterArgs() / unrollJamFactor;
+ unsigned numControlOperands = newForOp.getNumControlOperands();
+ auto yieldOp = cast<scf::YieldOp>(newForOp.getBody()->getTerminator());
+ unsigned oldNumYieldOperands = yieldOp.getNumOperands() / unrollJamFactor;
+ for (unsigned j = 0; j < oldNumIterOperands; ++j) {
+ newForOp.setOperand(numControlOperands + i * oldNumIterOperands + j,
+ mapper[i - 1].lookupOrDefault(newForOp.getOperand(numControlOperands + j)));
+ yieldOp.setOperand(i * oldNumYieldOperands + j,
+ mapper[i - 1].lookupOrDefault(yieldOp.getOperand(j)));
+ }
+ }
+ }
+ }
+ LogicalResult loopUnrollJamByFactor(ForOp forOp)
+ {
+ if(unrollJamFactor ==1) return success();
+ if(!invariantBounds(forOp))return failure();
+      std::optional<uint64_t> tripCount = getConstantTripCount(forOp);
+      if (tripCount && unrollJamFactor > *tripCount) unrollJamFactor = *tripCount;
+      else if (!tripCount || *tripCount % unrollJamFactor != 0) return failure();
+ if (llvm::hasSingleElement(forOp.getBody()->getOperations())) return success();
+
+ auto bg = ForBlockGatherer();
+ bg.walk(forOp);
+ auto &subBlocks = bg.subBlocks;
+
+ SmallVector<ForOp> innerLoops;
+ forOp.walk([&](ForOp innerForOp) { innerLoops.push_back(innerForOp); });
+
+ SmallVector<IRMapping> mapper(unrollJamFactor - 1);
+ IRRewriter rewriter(forOp.getContext());
+ SmallVector<ForOp> newInnerLoops;
+ duplicateArgs(mapper, newInnerLoops, forOp, rewriter, innerLoops, unrollJamFactor);
+ updateStep(forOp,rewriter);
+ auto forOpInductionVar = forOp.getInductionVar();
+ finalUnrollJam(forOpInductionVar,subBlocks,mapper,newInnerLoops,forOp);
+ (void)forOp.promoteIfSingleIteration(rewriter);
+ return success();
+ }
+
+ };
+} // end anonymous namespace
+
+std::unique_ptr<Pass> mlir::createLoopUnrollJam(int unrollJamFactor){
+ return std::make_unique<LoopUnrollJam>(
+ unrollJamFactor == -1 ? std::nullopt
+ : std::optional<unsigned>(unrollJamFactor));
+ }
\ No newline at end of file
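To summarize what loopUnrollJamByFactor does, here is a reduced before/after sketch for unroll-jam-factor = 2 (illustrative only; the exact output is pinned by the test added in the next commit). The outer step is doubled, the ops around the inner loop are cloned once with a shifted induction variable, and the inner loop is jammed by widening its iter_args instead of being emitted twice:

    // before
    scf.for %i = %c0 to %c4 step %c1 iter_args(%s = %cst) -> (f32) {
      %a = arith.addf %s, %val : f32
      scf.for %j = %c0 to %c10 step %c1 iter_args(%t = %a) -> (f32) { ... }
      ...
    }

    // after: doubled step, duplicated iter_args, jammed inner loop
    scf.for %i = %c0 to %c4 step %c2 iter_args(%s0 = %cst, %s1 = %cst) -> (f32, f32) {
      %a0 = arith.addf %s0, %val : f32
      %a1 = arith.addf %s1, %val : f32
      scf.for %j = %c0 to %c10 step %c1 iter_args(%t0 = %a0, %t1 = %a1) -> (f32, f32) { ... }
      ...
    }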
>From 5bfa0af50037f1ecc32aaf07242a0c8870611069 Mon Sep 17 00:00:00 2001
From: Anirudh-Sathish <anirudhsathish at gmail.com>
Date: Wed, 10 Jul 2024 03:46:09 +0000
Subject: [PATCH 11/13] Fix: Add code to run loop unroll jam without
 reductions and add test cases
---
.../Dialect/SCF/Transforms/LoopUnrollJam.cpp | 39 +++++++++++----
.../test/Dialect/SCF/loop-unroll-jam-scf.mlir | 47 +++++++++++++++++++
2 files changed, 78 insertions(+), 8 deletions(-)
create mode 100644 mlir/test/Dialect/SCF/loop-unroll-jam-scf.mlir
diff --git a/mlir/lib/Dialect/SCF/Transforms/LoopUnrollJam.cpp b/mlir/lib/Dialect/SCF/Transforms/LoopUnrollJam.cpp
index 82ea65f247e8f..6679816e9d77a 100644
--- a/mlir/lib/Dialect/SCF/Transforms/LoopUnrollJam.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/LoopUnrollJam.cpp
@@ -16,6 +16,7 @@ File containing pass for loop unroll and jam transformation
#include "mlir/Support/LLVM.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"
#include "mlir/Interfaces/LoopLikeInterface.h"
+#include <map>
namespace mlir {
#define GEN_PASS_DEF_SCFLOOPUNROLLJAM
#include "mlir/Dialect/SCF/Transforms/Passes.h.inc"
@@ -27,11 +28,6 @@ using scf::ForOp;
namespace{
- struct LoopReduction {
- arith::AtomicRMWKind kind;
- unsigned position;
- Value reducedValue;
- };
struct ForBlockGatherer{
SmallVector<std::pair<Block::iterator, Block::iterator>> subBlocks;
void walk(Operation *op) {
@@ -63,11 +59,25 @@ namespace{
}
void runOnOperation() override
{
+ std::map<ForOp, int> outerLoopMap;
+ std::map<ForOp, int> innerLoopMap;
auto *op = getOperation();
op->walk([&](ForOp forOp)
{
- (void)loopUnrollJamByFactor(forOp);
+ outerLoopMap[forOp]=0;
+ innerLoopMap[forOp]=0;
});
+ op->walk([&](ForOp forOp)
+ {
+ loopGatherer(forOp,outerLoopMap,innerLoopMap);
+ });
+ for(auto ele :innerLoopMap)
+ {
+ if(ele.second ==0)
+ {
+ (void)loopUnrollJamByFactor(ele.first);
+ }
+ }
}
std::optional<uint64_t> getConstantTripCount(scf::ForOp forOp) {
// This is a placeholder implementation. You need to adapt it to your use case.
@@ -95,6 +105,21 @@ namespace{
});
return !walkResult.wasInterrupted();
}
+ void loopNestMapper(ForOp op,std::map<ForOp, int> &outerLoopMap,
+ std::map<ForOp, int> &innerLoopMap) {
+ /*Identifies if a given for loop is the inner most scf ForOp*/
+ op.getBody()->walk([&](ForOp nestedForOp){
+ outerLoopMap[op]+=1;
+ innerLoopMap[nestedForOp]+=1;
+ });
+ }
+ void loopGatherer(Operation *op, std::map<ForOp, int> &outerLoopMap,
+ std::map<ForOp, int> &innerLoopMap) {
+ /*Function gathers all the innermost scf for loops*/
+ op->walk([&](ForOp forOp) {
+ loopNestMapper(forOp,outerLoopMap,innerLoopMap);
+ });
+ }
void duplicateArgs(SmallVector<IRMapping> &mapper, SmallVector<ForOp> &newInnerLoops,
ForOp &forOp, IRRewriter &rewriter,
@@ -149,7 +174,6 @@ namespace{
rewriter.setInsertionPointToEnd(prevBlock);
else
rewriter.setInsertionPoint(forOp);
- // rewriter.setInsertionPoint(forOp);
auto newStep = rewriter.createOrFold<arith::MulIOp>(
forOp.getLoc(), forOp.getStep(),
rewriter.createOrFold<arith::ConstantOp>(
@@ -186,7 +210,6 @@ namespace{
LogicalResult loopUnrollJamByFactor(ForOp forOp)
{
if(unrollJamFactor ==1) return success();
- if(!invariantBounds(forOp))return failure();
       std::optional<uint64_t> tripCount = getConstantTripCount(forOp);
       if (tripCount && unrollJamFactor > *tripCount) unrollJamFactor = *tripCount;
       else if (!tripCount || *tripCount % unrollJamFactor != 0) return failure();
diff --git a/mlir/test/Dialect/SCF/loop-unroll-jam-scf.mlir b/mlir/test/Dialect/SCF/loop-unroll-jam-scf.mlir
new file mode 100644
index 0000000000000..2e4b331c59c03
--- /dev/null
+++ b/mlir/test/Dialect/SCF/loop-unroll-jam-scf.mlir
@@ -0,0 +1,47 @@
+// RUN: mlir-opt %s --scf-loop-unroll-jam="unroll-jam-factor=2" --split-input-file | FileCheck %s
+
+module {
+ func.func @main() -> f32 {
+ %sum = arith.constant 0.0 : f32
+ %val = arith.constant 2.0 : f32
+ %N = arith.constant 4 : index
+ %num = arith.constant 10 : index
+ %c0 = arith.constant 0 : index
+ %c1 = arith.constant 1 : index
+
+ %result = scf.for %i = %c0 to %N step %c1 iter_args(%iter_sum = %sum) -> (f32) {
+ %new_sum = arith.addf %iter_sum, %val : f32
+ %result2 = scf.for %j = %c0 to %num step %c1 iter_args(%iter_sum2 = %new_sum) -> (f32) {
+ %new_sum2 = arith.addf %iter_sum2, %val : f32
+ scf.yield %new_sum2 : f32
+ }
+ %new_sum3 = arith.addf %result2, %val : f32
+ scf.yield %new_sum : f32
+ }
+ return %result : f32
+ }
+}
+
+// CHECK-LABEL: func.func @main() -> f32 {
+// CHECK-NEXT: %cst = arith.constant 0.000000e+00 : f32
+// CHECK-NEXT: %cst_0 = arith.constant 2.000000e+00 : f32
+// CHECK-NEXT: %c4 = arith.constant 4 : index
+// CHECK-NEXT: %c10 = arith.constant 10 : index
+// CHECK-NEXT: %c0 = arith.constant 0 : index
+// CHECK-NEXT: %c1 = arith.constant 1 : index
+// CHECK-NEXT: %c2 = arith.constant 2 : index
+// CHECK-NEXT: %c2_1 = arith.constant 2 : index
+// CHECK-NEXT: %0:2 = scf.for %arg0 = %c0 to %c4 step %c2_1 iter_args(%arg1 = %cst, %arg2 = %cst) -> (f32, f32) {
+// CHECK-NEXT: %1 = arith.addf %arg1, %cst_0 : f32
+// CHECK-NEXT: %2 = arith.addf %arg2, %cst_0 : f32
+// CHECK-NEXT: %3:2 = scf.for %arg3 = %c0 to %c10 step %c1 iter_args(%arg4 = %1, %arg5 = %2) -> (f32, f32) {
+// CHECK-NEXT: %6 = arith.addf %arg4, %cst_0 : f32
+// CHECK-NEXT: %7 = arith.addf %arg5, %cst_0 : f32
+// CHECK-NEXT: scf.yield %6, %7 : f32, f32
+// CHECK-NEXT: }
+// CHECK-NEXT: %4 = arith.addf %3#0, %cst_0 : f32
+// CHECK-NEXT: %5 = arith.addf %3#1, %cst_0 : f32
+// CHECK-NEXT: scf.yield %1, %2 : f32, f32
+// CHECK-NEXT: }
+// CHECK-NEXT: return %0#0 : f32
+// CHECK-NEXT: }
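To clarify how ForBlockGatherer drives the jamming in the test above: the ops before and after the inner scf.for form two sub-blocks that are cloned once per extra copy, while the inner loop itself is never duplicated; its iter_args are widened instead. Schematically:

    %new_sum  = arith.addf ...   <- sub-block 1: cloned with remapped iter_args / shifted IV
    scf.for %j ... { ... }       <- inner loop: recursed into, iter_args doubled in place
    %new_sum3 = arith.addf ...   <- sub-block 2: cloned likewise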
>From 025240d160d0376232af864d73fda8bd84aa288f Mon Sep 17 00:00:00 2001
From: Anirudh-Sathish <anirudhsathish at gmail.com>
Date: Wed, 10 Jul 2024 11:33:44 +0000
Subject: [PATCH 12/13] Feat: Add code to support reductions
---
.../Dialect/SCF/Transforms/LoopUnrollJam.cpp | 97 ++++++++++++++++++-
1 file changed, 92 insertions(+), 5 deletions(-)
diff --git a/mlir/lib/Dialect/SCF/Transforms/LoopUnrollJam.cpp b/mlir/lib/Dialect/SCF/Transforms/LoopUnrollJam.cpp
index 6679816e9d77a..7319af950ed3f 100644
--- a/mlir/lib/Dialect/SCF/Transforms/LoopUnrollJam.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/LoopUnrollJam.cpp
@@ -16,7 +16,10 @@ File containing pass for loop unroll and jam transformation
#include "mlir/Support/LLVM.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"
#include "mlir/Interfaces/LoopLikeInterface.h"
+#include "mlir/Dialect/Affine/Analysis/AffineAnalysis.h"
#include <map>
+#include "../lib/Analysis/SliceAnalysis.cpp"
+#include "llvm/ADT/TypeSwitch.h"
namespace mlir {
#define GEN_PASS_DEF_SCFLOOPUNROLLJAM
#include "mlir/Dialect/SCF/Transforms/Passes.h.inc"
@@ -24,6 +27,7 @@ namespace mlir {
using namespace llvm;
using namespace mlir;
using namespace mlir::scf;
+using namespace mlir::affine;
using scf::ForOp;
@@ -80,7 +84,6 @@ namespace{
}
}
std::optional<uint64_t> getConstantTripCount(scf::ForOp forOp) {
- // This is a placeholder implementation. You need to adapt it to your use case.
auto lowerBoundConst = forOp.getLowerBound().getDefiningOp<arith::ConstantIndexOp>();
auto upperBoundConst = forOp.getUpperBound().getDefiningOp<arith::ConstantIndexOp>();
auto stepConst = forOp.getStep().getDefiningOp<arith::ConstantIndexOp>();
@@ -107,7 +110,6 @@ namespace{
}
void loopNestMapper(ForOp op,std::map<ForOp, int> &outerLoopMap,
std::map<ForOp, int> &innerLoopMap) {
- /*Identifies if a given for loop is the inner most scf ForOp*/
op.getBody()->walk([&](ForOp nestedForOp){
outerLoopMap[op]+=1;
innerLoopMap[nestedForOp]+=1;
@@ -115,12 +117,10 @@ namespace{
}
void loopGatherer(Operation *op, std::map<ForOp, int> &outerLoopMap,
std::map<ForOp, int> &innerLoopMap) {
- /*Function gathers all the innermost scf for loops*/
op->walk([&](ForOp forOp) {
loopNestMapper(forOp,outerLoopMap,innerLoopMap);
});
}
-
void duplicateArgs(SmallVector<IRMapping> &mapper, SmallVector<ForOp> &newInnerLoops,
ForOp &forOp, IRRewriter &rewriter,
const SmallVector<ForOp> &innerLoops, unsigned unrollJamFactor) {
@@ -160,7 +160,6 @@ namespace{
Value createUpdatedInductionVar(unsigned i, Value iv, OpBuilder &builder,
Value step) {
- /*Function writes ir to update the induction variable*/
Location loc = iv.getLoc();
Value constantI = builder.create<arith::ConstantIndexOp>(loc, i);
Value increment =builder.create<arith::MulIOp>(loc, constantI, step);
@@ -207,6 +206,85 @@ namespace{
}
}
}
+ void getReductions(ForOp forOp, SmallVectorImpl<LoopReduction> &reductions) {
+ ValueRange iterArgs = forOp.getRegionIterArgs();
+ unsigned numIterArgs = iterArgs.size();
+
+ if (numIterArgs == 0)
+ {
+ llvm::errs()<<"Iter args == 0";
+ return;
+ }
+ llvm::errs()<<"Iter args"<<numIterArgs;
+ reductions.reserve(numIterArgs);
+ for (unsigned i = 0; i < numIterArgs; ++i) {
+ arith::AtomicRMWKind kind;
+ if (Value value = getReduction(forOp, i, kind))
+ {
+ llvm::errs()<<"Inside here \n";
+ reductions.emplace_back(LoopReduction{kind, i, value});
+ llvm::errs()<<"\nValue "<<value;
+ }
+
+ }
+ }
+ Value getReduction(ForOp forOp, unsigned pos,
+ arith::AtomicRMWKind &kind) {
+ SmallVector<Operation *> combinerOps;
+ Value reducedVal = matchReduction(forOp.getRegionIterArgs(), pos, combinerOps);
+ if (!reducedVal)return nullptr;
+
+ if (combinerOps.size() > 1)return nullptr;
+
+ Operation *combinerOp = combinerOps.back();
+ std::optional<arith::AtomicRMWKind> maybeKind =
+ mlir::TypeSwitch<Operation *, std::optional<arith::AtomicRMWKind>>(combinerOp)
+ .Case([](arith::AddFOp) { return arith::AtomicRMWKind::addf; })
+ .Case([](arith::MulFOp) { return arith::AtomicRMWKind::mulf; })
+ .Case([](arith::AddIOp) { return arith::AtomicRMWKind::addi; })
+ .Case([](arith::AndIOp) { return arith::AtomicRMWKind::andi; })
+ .Case([](arith::OrIOp) { return arith::AtomicRMWKind::ori; })
+ .Case([](arith::MulIOp) { return arith::AtomicRMWKind::muli; })
+ .Case(
+ [](arith::MinimumFOp) { return arith::AtomicRMWKind::minimumf; })
+ .Case(
+ [](arith::MaximumFOp) { return arith::AtomicRMWKind::maximumf; })
+ .Case([](arith::MinSIOp) { return arith::AtomicRMWKind::mins; })
+ .Case([](arith::MaxSIOp) { return arith::AtomicRMWKind::maxs; })
+ .Case([](arith::MinUIOp) { return arith::AtomicRMWKind::minu; })
+ .Case([](arith::MaxUIOp) { return arith::AtomicRMWKind::maxu; })
+ .Default([](Operation *) -> std::optional<arith::AtomicRMWKind> {
+ return std::nullopt;
+ });
+ if (!maybeKind)
+ return nullptr;
+
+ kind = *maybeKind;
+ return reducedVal;
+ }
+ void updateReductions(ForOp forOp, IRRewriter &rewriter,
+ SmallVector<LoopReduction> &reductions)
+ {
+ rewriter.setInsertionPointAfter(forOp);
+ auto loc = forOp.getLoc();
+ unsigned oldNumResults = forOp.getNumResults() / unrollJamFactor;
+ for (LoopReduction &reduction : reductions) {
+ unsigned pos = reduction.iterArgPosition;
+ Value lhs = forOp.getResult(pos);
+ Value rhs;
+ SmallPtrSet<Operation *, 4> newOps;
+ for (unsigned i = unrollJamFactor - 1; i >= 1; --i) {
+ rhs = forOp.getResult(i * oldNumResults + pos);
+ lhs = arith::getReductionOp(reduction.kind, rewriter, loc, lhs, rhs);
+ if (!lhs)
+ return;
+ Operation *op = lhs.getDefiningOp();
+ assert(op && "Reduction op should have been created");
+ newOps.insert(op);
+ }
+ forOp.getResult(pos).replaceAllUsesExcept(lhs, newOps);
+ }
+ }
LogicalResult loopUnrollJamByFactor(ForOp forOp)
{
if(unrollJamFactor ==1) return success();
@@ -222,6 +300,12 @@ namespace{
SmallVector<ForOp> innerLoops;
forOp.walk([&](ForOp innerForOp) { innerLoops.push_back(innerForOp); });
+ SmallVector<LoopReduction> reductions;
+ ValueRange iterArgs = forOp.getRegionIterArgs();
+ unsigned numIterOperands = iterArgs.size();
+ if (numIterOperands > 0)
+ getReductions(forOp, reductions);
+
SmallVector<IRMapping> mapper(unrollJamFactor - 1);
IRRewriter rewriter(forOp.getContext());
SmallVector<ForOp> newInnerLoops;
@@ -229,6 +313,9 @@ namespace{
updateStep(forOp,rewriter);
auto forOpInductionVar = forOp.getInductionVar();
finalUnrollJam(forOpInductionVar,subBlocks,mapper,newInnerLoops,forOp);
+ if (forOp.getNumResults() > 0) {
+ updateReductions(forOp,rewriter,reductions);
+ }
(void)forOp.promoteIfSingleIteration(rewriter);
return success();
}
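With the reduction support above, an unroll-jammed reduction loop produces one partial result per copy, and updateReductions stitches them back together after the loop. A sketch for an addf reduction with unroll-jam-factor = 2 (illustrative SSA names):

    %r:2 = scf.for %i = ... iter_args(%s0 = %cst, %s1 = %cst) -> (f32, f32) { ... }
    // partial results combined after the loop via arith::getReductionOp:
    %total = arith.addf %r#0, %r#1 : f32

All uses of the original single result are redirected to %total; replaceAllUsesExcept keeps the combining ops themselves pointing at the loop results.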
>From 933f41944dd5416aa85b62626fd4b15a51495533 Mon Sep 17 00:00:00 2001
From: Anirudh-Sathish <anirudhsathish at gmail.com>
Date: Wed, 10 Jul 2024 12:35:20 +0000
Subject: [PATCH 13/13] Refactor and Test: Refactor the code and test case
---
.../Dialect/SCF/Transforms/LoopUnrollJam.cpp | 583 +++++++++---------
.../test/Dialect/SCF/loop-unroll-jam-scf.mlir | 35 +-
2 files changed, 311 insertions(+), 307 deletions(-)
diff --git a/mlir/lib/Dialect/SCF/Transforms/LoopUnrollJam.cpp b/mlir/lib/Dialect/SCF/Transforms/LoopUnrollJam.cpp
index 7319af950ed3f..ef1d1f291f43d 100644
--- a/mlir/lib/Dialect/SCF/Transforms/LoopUnrollJam.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/LoopUnrollJam.cpp
@@ -5,24 +5,24 @@ File containing pass for loop unroll and jam transformation
#include "mlir/Dialect/SCF/Transforms/Passes.h"
+#include "../lib/Analysis/SliceAnalysis.cpp"
+#include "mlir/Dialect/Affine/Analysis/AffineAnalysis.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/SCF/Transforms/Transforms.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/PatternMatch.h"
+#include "mlir/Interfaces/LoopLikeInterface.h"
+#include "mlir/Interfaces/SideEffectInterfaces.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Pass/PassManager.h"
#include "mlir/Pass/PassRegistry.h"
#include "mlir/Support/LLVM.h"
-#include "mlir/Interfaces/SideEffectInterfaces.h"
-#include "mlir/Interfaces/LoopLikeInterface.h"
-#include "mlir/Dialect/Affine/Analysis/AffineAnalysis.h"
-#include <map>
-#include "../lib/Analysis/SliceAnalysis.cpp"
#include "llvm/ADT/TypeSwitch.h"
+#include <map>
namespace mlir {
- #define GEN_PASS_DEF_SCFLOOPUNROLLJAM
- #include "mlir/Dialect/SCF/Transforms/Passes.h.inc"
+#define GEN_PASS_DEF_SCFLOOPUNROLLJAM
+#include "mlir/Dialect/SCF/Transforms/Passes.h.inc"
} // namespace mlir
using namespace llvm;
using namespace mlir;
@@ -30,301 +30,304 @@ using namespace mlir::scf;
using namespace mlir::affine;
using scf::ForOp;
+namespace {
+struct ForBlockGatherer {
+ SmallVector<std::pair<Block::iterator, Block::iterator>> subBlocks;
+ void walk(Operation *op) {
+    for (Region &region : op->getRegions())
+ for (Block &block : region)
+ walk(block);
+ }
+ void walk(Block &block) {
+ assert(!block.empty() && block.back().hasTrait<OpTrait::IsTerminator>() &&
+ "expected block to have a terminator");
+ for (Block::iterator it = block.begin(), e = std::prev(block.end());
+ it != e;) {
+ Block::iterator subBlockStart = it;
+ while (it != e && !isa<ForOp>(&*it))
+ ++it;
+ if (it != subBlockStart)
+ subBlocks.emplace_back(subBlockStart, std::prev(it));
+ while (it != e && isa<ForOp>(&*it))
+ walk(&*it++);
+ }
+ }
+};
+struct LoopUnrollJam : public impl::SCFLoopUnrollJamBase<LoopUnrollJam> {
+ explicit LoopUnrollJam(
+ std::optional<unsigned> unrollJamFactor = std::nullopt) {
+ if (unrollJamFactor)
+ this->unrollJamFactor = *unrollJamFactor;
+ }
+ void runOnOperation() override {
+ std::map<ForOp, int> outerLoopMap;
+ std::map<ForOp, int> innerLoopMap;
+ auto *op = getOperation();
+ op->walk([&](ForOp forOp) {
+ outerLoopMap[forOp] = 0;
+ innerLoopMap[forOp] = 0;
+ });
+ op->walk(
+ [&](ForOp forOp) { loopGatherer(forOp, outerLoopMap, innerLoopMap); });
+ for (auto ele : innerLoopMap) {
+ if (ele.second == 0) {
+ (void)loopUnrollJamByFactor(ele.first);
+ }
+ }
+ }
+
+ // Obtains trip count for a given for op
+ std::optional<uint64_t> getTripCount(scf::ForOp forOp) {
+ auto lowerBoundConst =
+ forOp.getLowerBound().getDefiningOp<arith::ConstantIndexOp>();
+ auto upperBoundConst =
+ forOp.getUpperBound().getDefiningOp<arith::ConstantIndexOp>();
+ auto stepConst = forOp.getStep().getDefiningOp<arith::ConstantIndexOp>();
+
+ if (!lowerBoundConst || !upperBoundConst || !stepConst) {
+ return std::nullopt;
+ }
+
+ int64_t lowerBoundValue = lowerBoundConst.value();
+ int64_t upperBoundValue = upperBoundConst.value();
+ int64_t stepValue = stepConst.value();
+ return (upperBoundValue - lowerBoundValue) / stepValue;
+ }
+ // Maps the nesting levels of the loops
+ void loopNestMapper(ForOp op, std::map<ForOp, int> &outerLoopMap,
+ std::map<ForOp, int> &innerLoopMap) {
+ op.getBody()->walk([&](ForOp nestedForOp) {
+ outerLoopMap[op] += 1;
+ innerLoopMap[nestedForOp] += 1;
+ });
+ }
-namespace{
- struct ForBlockGatherer{
- SmallVector<std::pair<Block::iterator, Block::iterator>> subBlocks;
- void walk(Operation *op) {
-      for (Region &region : op->getRegions())
- for (Block &block : region)
- walk(block);
+ // Gathers loops to map their nesting levels
+ void loopGatherer(Operation *op, std::map<ForOp, int> &outerLoopMap,
+ std::map<ForOp, int> &innerLoopMap) {
+ op->walk([&](ForOp forOp) {
+ loopNestMapper(forOp, outerLoopMap, innerLoopMap);
+ });
+ }
+
+ // duplicates loop argument for unrolling and jamming
+ void duplicateArgs(SmallVector<IRMapping> &mapper,
+ SmallVector<ForOp> &newInnerLoops, ForOp &forOp,
+ IRRewriter &rewriter, const SmallVector<ForOp> &innerLoops,
+ unsigned unrollJamFactor) {
+ for (scf::ForOp currentForOp : innerLoops) {
+      SmallVector<Value> iterOperandsList, yieldOperandsList;
+ ValueRange previousIterOperands = currentForOp.getInits();
+ ValueRange previousIterArgs = currentForOp.getRegionIterArgs();
+      ValueRange previousYieldOperands =
+ cast<YieldOp>(currentForOp.getBody()->getTerminator()).getOperands();
+ for (unsigned i = unrollJamFactor - 1; i >= 1; --i) {
+ iterOperandsList.append(previousIterOperands.begin(),
+ previousIterOperands.end());
+        yieldOperandsList.append(previousYieldOperands.begin(),
+                                 previousYieldOperands.end());
+ }
+ bool forOpReplaced = currentForOp == forOp;
+ scf::ForOp newForOp =
+ cast<scf::ForOp>(*currentForOp.replaceWithAdditionalYields(
+ rewriter, iterOperandsList,
+ /*replaceInitOperandUsesInLoop=*/false,
+ [&](OpBuilder &b, Location loc,
+ ArrayRef<BlockArgument> newBbArgs) {
+                return yieldOperandsList;
+ }));
+ newInnerLoops.push_back(newForOp);
+
+ if (forOpReplaced)
+ forOp = newForOp;
+ ValueRange newIterArgs = newForOp.getRegionIterArgs();
+ unsigned oldNumIterArgs = previousIterArgs.size();
+ ValueRange newResults = newForOp.getResults();
+ unsigned oldNumResults = newResults.size() / unrollJamFactor;
+ for (unsigned i = unrollJamFactor - 1; i >= 1; --i) {
+ for (unsigned j = 0; j < oldNumIterArgs; ++j) {
+ mapper[i - 1].map(newIterArgs[j],
+ newIterArgs[i * oldNumIterArgs + j]);
+ mapper[i - 1].map(newResults[j], newResults[i * oldNumResults + j]);
}
- void walk(Block &block) {
- assert(!block.empty() && block.back().hasTrait<OpTrait::IsTerminator>() &&
- "expected block to have a terminator");
- for (Block::iterator it = block.begin(), e = std::prev(block.end());
- it != e;) {
- Block::iterator subBlockStart = it;
- while (it != e && !isa<ForOp>(&*it))
- ++it;
- if (it != subBlockStart)
- subBlocks.emplace_back(subBlockStart, std::prev(it));
- // Process all for ops that appear next.
- while (it != e && isa<ForOp>(&*it))
- walk(&*it++);
- }
+ }
+ }
+ }
+ // creates an updated induction variable
+ Value createUpdatedInductionVar(unsigned i, Value iv, OpBuilder &builder,
+ Value step) {
+ Location loc = iv.getLoc();
+ Value constantI = builder.create<arith::ConstantIndexOp>(loc, i);
+ Value increment = builder.create<arith::MulIOp>(loc, constantI, step);
+ Value updatedIv = builder.create<arith::AddIOp>(loc, iv, increment);
+ return updatedIv;
+ }
+ // updates the step of a given loop
+ void updateStep(ForOp forOp, IRRewriter &rewriter) {
+ if (Block *prevBlock = forOp->getBlock()->getPrevNode())
+ rewriter.setInsertionPointToEnd(prevBlock);
+ else
+ rewriter.setInsertionPoint(forOp);
+ auto newStep = rewriter.createOrFold<arith::MulIOp>(
+ forOp.getLoc(), forOp.getStep(),
+ rewriter.createOrFold<arith::ConstantOp>(
+ forOp.getLoc(), rewriter.getIndexAttr(unrollJamFactor)));
+ forOp.setStep(newStep);
+ }
+
+ // performs the final unroll and jam operations, cloning and updating the
+ // subblocks
+ void finalUnrollJam(
+ Value forOpInductionVar,
+ SmallVector<std::pair<Block::iterator, Block::iterator>> &subBlocks,
+ SmallVector<IRMapping> &mapper, const SmallVector<ForOp> &newInnerLoops,
+ ForOp forOp) {
+ for (unsigned i = unrollJamFactor - 1; i >= 1; --i) {
+ for (auto &subBlock : subBlocks) {
+ OpBuilder builder(subBlock.first->getBlock(),
+ std::next(subBlock.second));
+ if (!forOpInductionVar.use_empty()) {
+ auto updatedInductionVar = createUpdatedInductionVar(
+ i, forOpInductionVar, builder, forOp.getStep());
+ mapper[i - 1].map(forOpInductionVar, updatedInductionVar);
}
- };
- struct LoopUnrollJam : public impl::SCFLoopUnrollJamBase<LoopUnrollJam>
- {
- explicit LoopUnrollJam(
- std::optional<unsigned> unrollJamFactor= std::nullopt){
- if(unrollJamFactor) this->unrollJamFactor = *unrollJamFactor;
- }
- void runOnOperation() override
- {
- std::map<ForOp, int> outerLoopMap;
- std::map<ForOp, int> innerLoopMap;
- auto *op = getOperation();
- op->walk([&](ForOp forOp)
- {
- outerLoopMap[forOp]=0;
- innerLoopMap[forOp]=0;
- });
- op->walk([&](ForOp forOp)
- {
- loopGatherer(forOp,outerLoopMap,innerLoopMap);
- });
- for(auto ele :innerLoopMap)
- {
- if(ele.second ==0)
- {
- (void)loopUnrollJamByFactor(ele.first);
- }
- }
+ for (auto it = subBlock.first; it != std::next(subBlock.second); ++it)
+ builder.clone(*it, mapper[i - 1]);
+ }
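+      // Rewire the jammed inner loops: copy i's duplicated init operands and
+      // yields must use the values cloned for copy i.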
+ for (auto newForOp : newInnerLoops) {
+ unsigned oldNumIterOperands =
+ newForOp.getNumRegionIterArgs() / unrollJamFactor;
+ unsigned numControlOperands = newForOp.getNumControlOperands();
+ auto yieldOp = cast<scf::YieldOp>(newForOp.getBody()->getTerminator());
+ unsigned oldNumYieldOperands =
+ yieldOp.getNumOperands() / unrollJamFactor;
+ for (unsigned j = 0; j < oldNumIterOperands; ++j) {
+ newForOp.setOperand(numControlOperands + i * oldNumIterOperands + j,
+ mapper[i - 1].lookupOrDefault(
+ newForOp.getOperand(numControlOperands + j)));
+ yieldOp.setOperand(
+ i * oldNumYieldOperands + j,
+ mapper[i - 1].lookupOrDefault(yieldOp.getOperand(j)));
}
- std::optional<uint64_t> getConstantTripCount(scf::ForOp forOp) {
- auto lowerBoundConst = forOp.getLowerBound().getDefiningOp<arith::ConstantIndexOp>();
- auto upperBoundConst = forOp.getUpperBound().getDefiningOp<arith::ConstantIndexOp>();
- auto stepConst = forOp.getStep().getDefiningOp<arith::ConstantIndexOp>();
+ }
+ }
+ }
- if (!lowerBoundConst || !upperBoundConst || !stepConst) {
- return std::nullopt;
- }
+  // Gathers the reductions carried by the loop's iter_args.
+ void getReductions(ForOp forOp, SmallVectorImpl<LoopReduction> &reductions) {
+ ValueRange iterArgs = forOp.getRegionIterArgs();
+ unsigned numIterArgs = iterArgs.size();
- int64_t lowerBoundValue = lowerBoundConst.value();
- int64_t upperBoundValue = upperBoundConst.value();
- int64_t stepValue = stepConst.value();
- return (upperBoundValue - lowerBoundValue) / stepValue;
- }
- static bool invariantBounds(scf::ForOp forOp) {
- auto walkResult = forOp.walk([&](scf::ForOp innerForOp) {
- if (!forOp.isDefinedOutsideOfLoop(innerForOp.getLowerBound()) ||
- !forOp.isDefinedOutsideOfLoop(innerForOp.getUpperBound()) ||
- !forOp.isDefinedOutsideOfLoop(innerForOp.getStep()))
- return WalkResult::interrupt();
-
- return WalkResult::advance();
- });
- return !walkResult.wasInterrupted();
- }
- void loopNestMapper(ForOp op,std::map<ForOp, int> &outerLoopMap,
- std::map<ForOp, int> &innerLoopMap) {
- op.getBody()->walk([&](ForOp nestedForOp){
- outerLoopMap[op]+=1;
- innerLoopMap[nestedForOp]+=1;
- });
- }
- void loopGatherer(Operation *op, std::map<ForOp, int> &outerLoopMap,
- std::map<ForOp, int> &innerLoopMap) {
- op->walk([&](ForOp forOp) {
- loopNestMapper(forOp,outerLoopMap,innerLoopMap);
+    if (numIterArgs == 0)
+      return;
+    reductions.reserve(numIterArgs);
+    for (unsigned i = 0; i < numIterArgs; ++i) {
+      arith::AtomicRMWKind kind;
+      if (Value value = getReduction(forOp, i, kind))
+        reductions.emplace_back(LoopReduction{kind, i, value});
+    }
+ }
+
+  // Matches the reduction for the iter_arg at position `pos`; returns the
+  // reduced value and sets `kind` on success, nullptr otherwise.
+ Value getReduction(ForOp forOp, unsigned pos, arith::AtomicRMWKind &kind) {
+ SmallVector<Operation *> combinerOps;
+ Value reducedVal =
+ matchReduction(forOp.getRegionIterArgs(), pos, combinerOps);
+ if (!reducedVal)
+ return nullptr;
+
+ if (combinerOps.size() > 1)
+ return nullptr;
+
+ Operation *combinerOp = combinerOps.back();
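+    // Map the single combiner op to its atomic RMW kind; unsupported
+    // combiners are not treated as reductions.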
+ std::optional<arith::AtomicRMWKind> maybeKind =
+ mlir::TypeSwitch<Operation *, std::optional<arith::AtomicRMWKind>>(
+ combinerOp)
+ .Case([](arith::AddFOp) { return arith::AtomicRMWKind::addf; })
+ .Case([](arith::MulFOp) { return arith::AtomicRMWKind::mulf; })
+ .Case([](arith::AddIOp) { return arith::AtomicRMWKind::addi; })
+ .Case([](arith::AndIOp) { return arith::AtomicRMWKind::andi; })
+ .Case([](arith::OrIOp) { return arith::AtomicRMWKind::ori; })
+ .Case([](arith::MulIOp) { return arith::AtomicRMWKind::muli; })
+ .Default([](Operation *) -> std::optional<arith::AtomicRMWKind> {
+ return std::nullopt;
});
- }
- void duplicateArgs(SmallVector<IRMapping> &mapper, SmallVector<ForOp> &newInnerLoops,
- ForOp &forOp, IRRewriter &rewriter,
- const SmallVector<ForOp> &innerLoops, unsigned unrollJamFactor) {
- for (scf::ForOp currentForOp : innerLoops) {
- SmallVector<Value> iterOperandsList, yeildOperandsList;
- ValueRange previousIterOperands = currentForOp.getInits();
- ValueRange previousIterArgs = currentForOp.getRegionIterArgs();
- ValueRange previousYeildOperands = cast<YieldOp>(currentForOp.getBody()->getTerminator()).getOperands();
- for (unsigned i = unrollJamFactor - 1; i >= 1; --i) {
- iterOperandsList.append(previousIterOperands.begin(), previousIterOperands.end());
- yeildOperandsList.append(previousYeildOperands.begin(), previousYeildOperands.end());
- }
- bool forOpReplaced = currentForOp == forOp;
- scf::ForOp newForOp =
- cast<scf::ForOp>(*currentForOp.replaceWithAdditionalYields(
- rewriter, iterOperandsList, /*replaceInitOperandUsesInLoop=*/false,
- [&](OpBuilder &b, Location loc, ArrayRef<BlockArgument> newBbArgs) {
- return yeildOperandsList;
- }));
- newInnerLoops.push_back(newForOp);
+ if (!maybeKind)
+ return nullptr;
- if (forOpReplaced) forOp = newForOp;
- ValueRange newIterArgs = newForOp.getRegionIterArgs();
- unsigned oldNumIterArgs = previousIterArgs.size();
- ValueRange newResults = newForOp.getResults();
- unsigned oldNumResults = newResults.size() / unrollJamFactor;
- for (unsigned i = unrollJamFactor - 1; i >= 1; --i) {
- for (unsigned j = 0; j < oldNumIterArgs; ++j) {
- mapper[i - 1].map(newIterArgs[j],
- newIterArgs[i * oldNumIterArgs + j]);
- mapper[i - 1].map(newResults[j],
- newResults[i * oldNumResults + j]);
- }
- }
- }
- }
+ kind = *maybeKind;
+ return reducedVal;
+ }
- Value createUpdatedInductionVar(unsigned i, Value iv, OpBuilder &builder,
- Value step) {
- Location loc = iv.getLoc();
- Value constantI = builder.create<arith::ConstantIndexOp>(loc, i);
- Value increment =builder.create<arith::MulIOp>(loc, constantI, step);
- Value updatedIv = builder.create<arith::AddIOp>(loc, iv, increment);
- return updatedIv;
- }
+  // Combines the duplicated partial reduction results after unrolling.
+ void updateReductions(ForOp forOp, IRRewriter &rewriter,
+ SmallVector<LoopReduction> &reductions) {
+ rewriter.setInsertionPointAfter(forOp);
+ auto loc = forOp.getLoc();
+ unsigned oldNumResults = forOp.getNumResults() / unrollJamFactor;
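+    // Chain the duplicated partial results into the original one, e.g. for an
+    // addf reduction with factor 2: result = result#0 + result#1.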
+ for (LoopReduction &reduction : reductions) {
+ unsigned pos = reduction.iterArgPosition;
+ Value lhs = forOp.getResult(pos);
+ Value rhs;
+ SmallPtrSet<Operation *, 4> newOps;
+ for (unsigned i = unrollJamFactor - 1; i >= 1; --i) {
+ rhs = forOp.getResult(i * oldNumResults + pos);
+ lhs = arith::getReductionOp(reduction.kind, rewriter, loc, lhs, rhs);
+ if (!lhs)
+ return;
+        Operation *op = lhs.getDefiningOp();
+        assert(op && "Reduction op should have been created");
+        newOps.insert(op);
+ }
+ forOp.getResult(pos).replaceAllUsesExcept(lhs, newOps);
+ }
+ }
- void updateStep(ForOp forOp,IRRewriter &rewriter)
- {
- if (Block *prevBlock = forOp->getBlock()->getPrevNode())
- rewriter.setInsertionPointToEnd(prevBlock);
- else
- rewriter.setInsertionPoint(forOp);
- auto newStep = rewriter.createOrFold<arith::MulIOp>(
- forOp.getLoc(), forOp.getStep(),
- rewriter.createOrFold<arith::ConstantOp>(
- forOp.getLoc(), rewriter.getIndexAttr(unrollJamFactor)));
- forOp.setStep(newStep);
- }
- void finalUnrollJam(Value forOpInductionVar,
- SmallVector<std::pair<Block::iterator, Block::iterator>> &subBlocks,
- SmallVector<IRMapping> &mapper, const SmallVector<ForOp> &newInnerLoops, ForOp forOp) {
- for (unsigned i = unrollJamFactor - 1; i >= 1; --i) {
- for (auto &subBlock : subBlocks) {
- OpBuilder builder(subBlock.first->getBlock(), std::next(subBlock.second));
- if (!forOpInductionVar.use_empty()) {
- auto updatedInductionVar = createUpdatedInductionVar(i, forOpInductionVar, builder, forOp.getStep());
- mapper[i - 1].map(forOpInductionVar, updatedInductionVar);
- }
- for (auto it = subBlock.first; it != std::next(subBlock.second); ++it)
- builder.clone(*it, mapper[i - 1]);
- }
- for (auto newForOp : newInnerLoops) {
- unsigned oldNumIterOperands = newForOp.getNumRegionIterArgs() / unrollJamFactor;
- unsigned numControlOperands = newForOp.getNumControlOperands();
- auto yieldOp = cast<scf::YieldOp>(newForOp.getBody()->getTerminator());
- unsigned oldNumYieldOperands = yieldOp.getNumOperands() / unrollJamFactor;
- for (unsigned j = 0; j < oldNumIterOperands; ++j) {
- newForOp.setOperand(numControlOperands + i * oldNumIterOperands + j,
- mapper[i - 1].lookupOrDefault(newForOp.getOperand(numControlOperands + j)));
- yieldOp.setOperand(i * oldNumYieldOperands + j,
- mapper[i - 1].lookupOrDefault(yieldOp.getOperand(j)));
- }
- }
- }
- }
- void getReductions(ForOp forOp, SmallVectorImpl<LoopReduction> &reductions) {
- ValueRange iterArgs = forOp.getRegionIterArgs();
- unsigned numIterArgs = iterArgs.size();
+  // Unrolls and jams the given loop by unrollJamFactor.
+ LogicalResult loopUnrollJamByFactor(ForOp forOp) {
+ if (unrollJamFactor == 1)
+ return success();
+    std::optional<uint64_t> tripCount = getTripCount(forOp);
+    // Bail out on non-constant trip counts; the optional must not be
+    // dereferenced otherwise.
+    if (!tripCount)
+      return failure();
+    if (unrollJamFactor > *tripCount)
+      unrollJamFactor = *tripCount;
+    else if (*tripCount % unrollJamFactor != 0)
+      return failure();
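+    // Nothing to unroll if the body holds only its terminator.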
+ if (llvm::hasSingleElement(forOp.getBody()->getOperations()))
+ return success();
- if (numIterArgs == 0)
- {
- llvm::errs()<<"Iter args == 0";
- return;
- }
- llvm::errs()<<"Iter args"<<numIterArgs;
- reductions.reserve(numIterArgs);
- for (unsigned i = 0; i < numIterArgs; ++i) {
- arith::AtomicRMWKind kind;
- if (Value value = getReduction(forOp, i, kind))
- {
- llvm::errs()<<"Inside here \n";
- reductions.emplace_back(LoopReduction{kind, i, value});
- llvm::errs()<<"\nValue "<<value;
- }
-
- }
- }
- Value getReduction(ForOp forOp, unsigned pos,
- arith::AtomicRMWKind &kind) {
- SmallVector<Operation *> combinerOps;
- Value reducedVal = matchReduction(forOp.getRegionIterArgs(), pos, combinerOps);
- if (!reducedVal)return nullptr;
-
- if (combinerOps.size() > 1)return nullptr;
-
- Operation *combinerOp = combinerOps.back();
- std::optional<arith::AtomicRMWKind> maybeKind =
- mlir::TypeSwitch<Operation *, std::optional<arith::AtomicRMWKind>>(combinerOp)
- .Case([](arith::AddFOp) { return arith::AtomicRMWKind::addf; })
- .Case([](arith::MulFOp) { return arith::AtomicRMWKind::mulf; })
- .Case([](arith::AddIOp) { return arith::AtomicRMWKind::addi; })
- .Case([](arith::AndIOp) { return arith::AtomicRMWKind::andi; })
- .Case([](arith::OrIOp) { return arith::AtomicRMWKind::ori; })
- .Case([](arith::MulIOp) { return arith::AtomicRMWKind::muli; })
- .Case(
- [](arith::MinimumFOp) { return arith::AtomicRMWKind::minimumf; })
- .Case(
- [](arith::MaximumFOp) { return arith::AtomicRMWKind::maximumf; })
- .Case([](arith::MinSIOp) { return arith::AtomicRMWKind::mins; })
- .Case([](arith::MaxSIOp) { return arith::AtomicRMWKind::maxs; })
- .Case([](arith::MinUIOp) { return arith::AtomicRMWKind::minu; })
- .Case([](arith::MaxUIOp) { return arith::AtomicRMWKind::maxu; })
- .Default([](Operation *) -> std::optional<arith::AtomicRMWKind> {
- return std::nullopt;
- });
- if (!maybeKind)
- return nullptr;
-
- kind = *maybeKind;
- return reducedVal;
- }
- void updateReductions(ForOp forOp, IRRewriter &rewriter,
- SmallVector<LoopReduction> &reductions)
- {
- rewriter.setInsertionPointAfter(forOp);
- auto loc = forOp.getLoc();
- unsigned oldNumResults = forOp.getNumResults() / unrollJamFactor;
- for (LoopReduction &reduction : reductions) {
- unsigned pos = reduction.iterArgPosition;
- Value lhs = forOp.getResult(pos);
- Value rhs;
- SmallPtrSet<Operation *, 4> newOps;
- for (unsigned i = unrollJamFactor - 1; i >= 1; --i) {
- rhs = forOp.getResult(i * oldNumResults + pos);
- lhs = arith::getReductionOp(reduction.kind, rewriter, loc, lhs, rhs);
- if (!lhs)
- return;
- Operation *op = lhs.getDefiningOp();
- assert(op && "Reduction op should have been created");
- newOps.insert(op);
- }
- forOp.getResult(pos).replaceAllUsesExcept(lhs, newOps);
- }
- }
- LogicalResult loopUnrollJamByFactor(ForOp forOp)
- {
- if(unrollJamFactor ==1) return success();
- std::optional<uint64_t> tripCount = getConstantTripCount(forOp);
- if (unrollJamFactor > *tripCount) unrollJamFactor = *tripCount;
- else if (*tripCount % unrollJamFactor != 0) return failure();
- if (llvm::hasSingleElement(forOp.getBody()->getOperations())) return success();
-
- auto bg = ForBlockGatherer();
- bg.walk(forOp);
- auto &subBlocks = bg.subBlocks;
-
- SmallVector<ForOp> innerLoops;
- forOp.walk([&](ForOp innerForOp) { innerLoops.push_back(innerForOp); });
+ auto bg = ForBlockGatherer();
+ bg.walk(forOp);
+ auto &subBlocks = bg.subBlocks;
- SmallVector<LoopReduction> reductions;
- ValueRange iterArgs = forOp.getRegionIterArgs();
- unsigned numIterOperands = iterArgs.size();
- if (numIterOperands > 0)
- getReductions(forOp, reductions);
+ SmallVector<ForOp> innerLoops;
+ forOp.walk([&](ForOp innerForOp) { innerLoops.push_back(innerForOp); });
- SmallVector<IRMapping> mapper(unrollJamFactor - 1);
- IRRewriter rewriter(forOp.getContext());
- SmallVector<ForOp> newInnerLoops;
- duplicateArgs(mapper, newInnerLoops, forOp, rewriter, innerLoops, unrollJamFactor);
- updateStep(forOp,rewriter);
- auto forOpInductionVar = forOp.getInductionVar();
- finalUnrollJam(forOpInductionVar,subBlocks,mapper,newInnerLoops,forOp);
- if (forOp.getNumResults() > 0) {
- updateReductions(forOp,rewriter,reductions);
- }
- (void)forOp.promoteIfSingleIteration(rewriter);
- return success();
- }
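+    // Detect reductions over the loop's iter_args up front so the duplicated
+    // results can be combined after unrolling.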
+ SmallVector<LoopReduction> reductions;
+ ValueRange iterArgs = forOp.getRegionIterArgs();
+ unsigned numIterOperands = iterArgs.size();
+ if (numIterOperands > 0)
+ getReductions(forOp, reductions);
- };
-} // namespace ends
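+    // One IRMapping per extra body copy: mapper[i - 1] carries the value
+    // remapping for copy i.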
+ SmallVector<IRMapping> mapper(unrollJamFactor - 1);
+ IRRewriter rewriter(forOp.getContext());
+ SmallVector<ForOp> newInnerLoops;
+ duplicateArgs(mapper, newInnerLoops, forOp, rewriter, innerLoops,
+ unrollJamFactor);
+ updateStep(forOp, rewriter);
+ auto forOpInductionVar = forOp.getInductionVar();
+ finalUnrollJam(forOpInductionVar, subBlocks, mapper, newInnerLoops, forOp);
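+    // Fold the duplicated partial reduction results back into the originals.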
+    if (forOp.getNumResults() > 0)
+      updateReductions(forOp, rewriter, reductions);
+ (void)forOp.promoteIfSingleIteration(rewriter);
+ return success();
+ }
+};
+} // namespace
-std::unique_ptr<Pass> mlir::createLoopUnrollJam(int unrollJamFactor){
- return std::make_unique<LoopUnrollJam>(
- unrollJamFactor == -1 ? std::nullopt
- : std::optional<unsigned>(unrollJamFactor));
- }
\ No newline at end of file
+std::unique_ptr<Pass> mlir::createLoopUnrollJam(int unrollJamFactor) {
+ return std::make_unique<LoopUnrollJam>(
+ unrollJamFactor == -1 ? std::nullopt
+ : std::optional<unsigned>(unrollJamFactor));
+}
\ No newline at end of file
diff --git a/mlir/test/Dialect/SCF/loop-unroll-jam-scf.mlir b/mlir/test/Dialect/SCF/loop-unroll-jam-scf.mlir
index 2e4b331c59c03..4366026600811 100644
--- a/mlir/test/Dialect/SCF/loop-unroll-jam-scf.mlir
+++ b/mlir/test/Dialect/SCF/loop-unroll-jam-scf.mlir
@@ -4,14 +4,14 @@ module {
func.func @main() -> f32 {
%sum = arith.constant 0.0 : f32
%val = arith.constant 2.0 : f32
- %N = arith.constant 4 : index
- %num = arith.constant 10 : index
+ %N = arith.constant 16 : index
+ %num = arith.constant 16 : index
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
%result = scf.for %i = %c0 to %N step %c1 iter_args(%iter_sum = %sum) -> (f32) {
%new_sum = arith.addf %iter_sum, %val : f32
- %result2 = scf.for %j = %c0 to %num step %c1 iter_args(%iter_sum2 = %new_sum) -> (f32) {
+ %result2 = scf.for %j = %c0 to %num step %c1 iter_args(%iter_sum2 = %val) -> (f32) {
%new_sum2 = arith.addf %iter_sum2, %val : f32
scf.yield %new_sum2 : f32
}
@@ -25,23 +25,24 @@ module {
// CHECK-LABEL: func.func @main() -> f32 {
// CHECK-NEXT: %cst = arith.constant 0.000000e+00 : f32
// CHECK-NEXT: %cst_0 = arith.constant 2.000000e+00 : f32
-// CHECK-NEXT: %c4 = arith.constant 4 : index
-// CHECK-NEXT: %c10 = arith.constant 10 : index
+// CHECK-NEXT: %c16 = arith.constant 16 : index
+// CHECK-NEXT: %c16_1 = arith.constant 16 : index
// CHECK-NEXT: %c0 = arith.constant 0 : index
// CHECK-NEXT: %c1 = arith.constant 1 : index
// CHECK-NEXT: %c2 = arith.constant 2 : index
-// CHECK-NEXT: %c2_1 = arith.constant 2 : index
-// CHECK-NEXT: %0:2 = scf.for %arg0 = %c0 to %c4 step %c2_1 iter_args(%arg1 = %cst, %arg2 = %cst) -> (f32, f32) {
-// CHECK-NEXT: %1 = arith.addf %arg1, %cst_0 : f32
-// CHECK-NEXT: %2 = arith.addf %arg2, %cst_0 : f32
-// CHECK-NEXT: %3:2 = scf.for %arg3 = %c0 to %c10 step %c1 iter_args(%arg4 = %1, %arg5 = %2) -> (f32, f32) {
-// CHECK-NEXT: %6 = arith.addf %arg4, %cst_0 : f32
-// CHECK-NEXT: %7 = arith.addf %arg5, %cst_0 : f32
-// CHECK-NEXT: scf.yield %6, %7 : f32, f32
+// CHECK-NEXT: %c2_2 = arith.constant 2 : index
+// CHECK-NEXT: %0:2 = scf.for %arg0 = %c0 to %c16 step %c2_2 iter_args(%arg1 = %cst, %arg2 = %cst) -> (f32, f32) {
+// CHECK-NEXT: %2 = arith.addf %arg1, %cst_0 : f32
+// CHECK-NEXT: %3 = arith.addf %arg2, %cst_0 : f32
+// CHECK-NEXT: %4:2 = scf.for %arg3 = %c0 to %c16_1 step %c1 iter_args(%arg4 = %cst_0, %arg5 = %cst_0) -> (f32, f32) {
+// CHECK-NEXT: %7 = arith.addf %arg4, %cst_0 : f32
+// CHECK-NEXT: %8 = arith.addf %arg5, %cst_0 : f32
+// CHECK-NEXT: scf.yield %7, %8 : f32, f32
// CHECK-NEXT: }
-// CHECK-NEXT: %4 = arith.addf %3#0, %cst_0 : f32
-// CHECK-NEXT: %5 = arith.addf %3#1, %cst_0 : f32
-// CHECK-NEXT: scf.yield %1, %2 : f32, f32
+// CHECK-NEXT: %5 = arith.addf %4#0, %cst_0 : f32
+// CHECK-NEXT: %6 = arith.addf %4#1, %cst_0 : f32
+// CHECK-NEXT: scf.yield %2, %3 : f32, f32
// CHECK-NEXT: }
-// CHECK-NEXT: return %0#0 : f32
+// CHECK-NEXT: %1 = arith.addf %0#0, %0#1 : f32
+// CHECK-NEXT: return %1 : f32
// CHECK-NEXT: }
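
For reviewers, a minimal sketch of how the new entry point might be exercised
programmatically. This is illustrative only and not part of the patch: it
assumes the pass can be nested on func.func, that createLoopUnrollJam is
declared in the SCF transforms headers, and that module is an already-parsed
ModuleOp.

  #include "mlir/Dialect/Func/IR/FuncOps.h"
  #include "mlir/Pass/PassManager.h"
  #include "llvm/Support/raw_ostream.h"

  // Hypothetical driver: unroll-and-jam every scf.for nest in the module by 2.
  void runUnrollJam(mlir::ModuleOp module) {
    mlir::PassManager pm(module.getContext());
    pm.addNestedPass<mlir::func::FuncOp>(
        mlir::createLoopUnrollJam(/*unrollJamFactor=*/2));
    if (mlir::failed(pm.run(module)))
      llvm::errs() << "loop unroll-jam failed\n";
  }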