[flang-commits] [flang] 5be0f0c - [Flang] Lower Matmul intrinsic
Kiran Chandramohan via flang-commits
flang-commits at lists.llvm.org
Thu Mar 17 05:33:02 PDT 2022
Author: Kiran Chandramohan
Date: 2022-03-17T12:31:03Z
New Revision: 5be0f0c83d39310d1c6c078ec5b0731c48d34080
URL: https://github.com/llvm/llvm-project/commit/5be0f0c83d39310d1c6c078ec5b0731c48d34080
DIFF: https://github.com/llvm/llvm-project/commit/5be0f0c83d39310d1c6c078ec5b0731c48d34080.diff
LOG: [Flang] Lower Matmul intrinsic
The Matmul intrinsic performs matrix multiplication on rank 1 and rank 2 arrays (at least one argument must be rank 2).
The intrinsic is lowered to a runtime call.
This is part of the upstreaming effort from the fir-dev branch in [1].
[1] https://github.com/flang-compiler/f18-llvm-project
Reviewed By: clementval
Differential Revision: https://reviews.llvm.org/D121904
Co-authored-by: Jean Perier <jperier at nvidia.com>
Co-authored-by: Valentin Clement <clementval at gmail.com>
Added:
flang/test/Lower/Intrinsics/matmul.f90
Modified:
flang/lib/Lower/IntrinsicCall.cpp
Removed:
################################################################################
diff --git a/flang/lib/Lower/IntrinsicCall.cpp b/flang/lib/Lower/IntrinsicCall.cpp
index 8039cef87a625..b7c02e2c2cacf 100644
--- a/flang/lib/Lower/IntrinsicCall.cpp
+++ b/flang/lib/Lower/IntrinsicCall.cpp
@@ -477,6 +477,7 @@ struct IntrinsicLibrary {
fir::ExtendedValue genLbound(mlir::Type, llvm::ArrayRef<fir::ExtendedValue>);
fir::ExtendedValue genLen(mlir::Type, llvm::ArrayRef<fir::ExtendedValue>);
fir::ExtendedValue genLenTrim(mlir::Type, llvm::ArrayRef<fir::ExtendedValue>);
+ fir::ExtendedValue genMatmul(mlir::Type, llvm::ArrayRef<fir::ExtendedValue>);
fir::ExtendedValue genMaxloc(mlir::Type, llvm::ArrayRef<fir::ExtendedValue>);
fir::ExtendedValue genMaxval(mlir::Type, llvm::ArrayRef<fir::ExtendedValue>);
fir::ExtendedValue genMinloc(mlir::Type, llvm::ArrayRef<fir::ExtendedValue>);
@@ -692,6 +693,10 @@ static constexpr IntrinsicHandler handlers[]{
{"lgt", &I::genCharacterCompare<mlir::arith::CmpIPredicate::sgt>},
{"lle", &I::genCharacterCompare<mlir::arith::CmpIPredicate::sle>},
{"llt", &I::genCharacterCompare<mlir::arith::CmpIPredicate::slt>},
+ {"matmul",
+ &I::genMatmul,
+ {{{"matrix_a", asAddr}, {"matrix_b", asAddr}}},
+ /*isElemental=*/false},
{"max", &I::genExtremum<Extremum::Max, ExtremumBehavior::MinMaxss>},
{"maxloc",
&I::genMaxloc,
@@ -2374,6 +2379,34 @@ IntrinsicLibrary::genCharacterCompare(mlir::Type type,
fir::getBase(args[1]), fir::getLen(args[1]));
}
+// MATMUL: lowered to a runtime call that allocates and returns the result.
+fir::ExtendedValue
+IntrinsicLibrary::genMatmul(mlir::Type resultType,
+ llvm::ArrayRef<fir::ExtendedValue> args) {
+ assert(args.size() == 2);
+
+ // Box both required operands (matrix_a, matrix_b) for the runtime call.
+ fir::BoxValue matrixTmpA = builder.createBox(loc, args[0]);
+ mlir::Value matrixA = fir::getBase(matrixTmpA);
+ fir::BoxValue matrixTmpB = builder.createBox(loc, args[1]);
+ mlir::Value matrixB = fir::getBase(matrixTmpB);
+ unsigned resultRank =
+ (matrixTmpA.rank() == 1 || matrixTmpB.rank() == 1) ? 1 : 2;
+
+ // Create the mutable fir.box (of rank resultRank) passed to the runtime.
+ mlir::Type resultArrayType = builder.getVarLenSeqTy(resultType, resultRank);
+ fir::MutableBoxValue resultMutableBox =
+ fir::factory::createTempMutableBox(builder, loc, resultArrayType);
+ mlir::Value resultIrBox =
+ fir::factory::getMutableIRBox(builder, loc, resultMutableBox);
+ // Call the runtime, which allocates the result storage.
+ fir::runtime::genMatmul(builder, loc, resultIrBox, matrixA, matrixB);
+ // Read the result back from the mutable fir.box and register it as a
+ // temporary to be finalized by the StatementContext.
+ return readAndAddCleanUp(resultMutableBox, resultType,
+ "unexpected result for MATMUL");
+}
+
// Compare two FIR values and return boolean result as i1.
template <Extremum extremum, ExtremumBehavior behavior>
static mlir::Value createExtremumCompare(mlir::Location loc,
diff --git a/flang/test/Lower/Intrinsics/matmul.f90 b/flang/test/Lower/Intrinsics/matmul.f90
new file mode 100644
index 0000000000000..6c3c721063c9f
--- /dev/null
+++ b/flang/test/Lower/Intrinsics/matmul.f90
@@ -0,0 +1,68 @@
+! RUN: bbc -emit-fir %s -o - | FileCheck %s
+! RUN: %flang_fc1 -emit-fir %s -o - | FileCheck %s
+
+! Test matmul intrinsic
+
+! CHECK-LABEL: matmul_test
+! CHECK-SAME: (%[[X:.*]]: !fir.ref<!fir.array<3x1xf32>>{{.*}}, %[[Y:.*]]: !fir.ref<!fir.array<1x3xf32>>{{.*}}, %[[Z:.*]]: !fir.ref<!fir.array<2x2xf32>>{{.*}})
+! CHECK: %[[RESULT_BOX_ADDR:.*]] = fir.alloca !fir.box<!fir.heap<!fir.array<?x?xf32>>>
+! CHECK: %[[C3:.*]] = arith.constant 3 : index
+! CHECK: %[[C1:.*]] = arith.constant 1 : index
+! CHECK: %[[C1_0:.*]] = arith.constant 1 : index
+! CHECK: %[[C3_1:.*]] = arith.constant 3 : index
+! CHECK: %[[Z_BOX:.*]] = fir.array_load %[[Z]]({{.*}}) : (!fir.ref<!fir.array<2x2xf32>>, !fir.shape<2>) -> !fir.array<2x2xf32>
+! CHECK: %[[X_SHAPE:.*]] = fir.shape %[[C3]], %[[C1]] : (index, index) -> !fir.shape<2>
+! CHECK: %[[X_BOX:.*]] = fir.embox %[[X]](%[[X_SHAPE]]) : (!fir.ref<!fir.array<3x1xf32>>, !fir.shape<2>) -> !fir.box<!fir.array<3x1xf32>>
+! CHECK: %[[Y_SHAPE:.*]] = fir.shape %[[C1_0]], %[[C3_1]] : (index, index) -> !fir.shape<2>
+! CHECK: %[[Y_BOX:.*]] = fir.embox %[[Y]](%[[Y_SHAPE]]) : (!fir.ref<!fir.array<1x3xf32>>, !fir.shape<2>) -> !fir.box<!fir.array<1x3xf32>>
+! CHECK: %[[ZERO_INIT:.*]] = fir.zero_bits !fir.heap<!fir.array<?x?xf32>>
+! CHECK: %[[C0:.*]] = arith.constant 0 : index
+! CHECK: %[[RESULT_SHAPE:.*]] = fir.shape %[[C0]], %[[C0]] : (index, index) -> !fir.shape<2>
+! CHECK: %[[RESULT_BOX_VAL:.*]] = fir.embox %[[ZERO_INIT]](%[[RESULT_SHAPE]]) : (!fir.heap<!fir.array<?x?xf32>>, !fir.shape<2>) -> !fir.box<!fir.heap<!fir.array<?x?xf32>>>
+! CHECK: fir.store %[[RESULT_BOX_VAL]] to %[[RESULT_BOX_ADDR]] : !fir.ref<!fir.box<!fir.heap<!fir.array<?x?xf32>>>>
+! CHECK: %[[RESULT_BOX_ADDR_RUNTIME:.*]] = fir.convert %[[RESULT_BOX_ADDR]] : (!fir.ref<!fir.box<!fir.heap<!fir.array<?x?xf32>>>>) -> !fir.ref<!fir.box<none>>
+! CHECK: %[[X_BOX_RUNTIME:.*]] = fir.convert %[[X_BOX]] : (!fir.box<!fir.array<3x1xf32>>) -> !fir.box<none>
+! CHECK: %[[Y_BOX_RUNTIME:.*]] = fir.convert %[[Y_BOX]] : (!fir.box<!fir.array<1x3xf32>>) -> !fir.box<none>
+! CHECK: {{.*}}fir.call @_FortranAMatmul(%[[RESULT_BOX_ADDR_RUNTIME]], %[[X_BOX_RUNTIME]], %[[Y_BOX_RUNTIME]], {{.*}}, {{.*}}) : (!fir.ref<!fir.box<none>>, !fir.box<none>, !fir.box<none>, !fir.ref<i8>, i32) -> none
+! CHECK: %[[RESULT_BOX:.*]] = fir.load %[[RESULT_BOX_ADDR]] : !fir.ref<!fir.box<!fir.heap<!fir.array<?x?xf32>>>>
+! CHECK: %[[RESULT_TMP:.*]] = fir.box_addr %[[RESULT_BOX]] : (!fir.box<!fir.heap<!fir.array<?x?xf32>>>) -> !fir.heap<!fir.array<?x?xf32>>
+! CHECK: %[[Z_COPY_FROM_RESULT:.*]] = fir.do_loop
+! CHECK: {{.*}}fir.array_fetch
+! CHECK: {{.*}}fir.array_update
+! CHECK: fir.result
+! CHECK: }
+! CHECK: fir.array_merge_store %[[Z_BOX]], %[[Z_COPY_FROM_RESULT]] to %[[Z]] : !fir.array<2x2xf32>, !fir.array<2x2xf32>, !fir.ref<!fir.array<2x2xf32>>
+! CHECK: fir.freemem %[[RESULT_TMP]] : <!fir.array<?x?xf32>>
+subroutine matmul_test(x,y,z)
+ real :: x(3,1), y(1,3), z(2,2)
+ z = matmul(x,y) ! NOTE(review): matmul(x,y) has shape (3,3) but z is (2,2); confirm this is intended
+end subroutine
+
+! CHECK-LABEL: matmul_test2
+! CHECK-SAME: (%[[X_BOX:.*]]: !fir.box<!fir.array<?x?x!fir.logical<4>>>{{.*}}, %[[Y_BOX:.*]]: !fir.box<!fir.array<?x!fir.logical<4>>>{{.*}}, %[[Z_BOX:.*]]: !fir.box<!fir.array<?x!fir.logical<4>>>{{.*}})
+!CHECK: %[[RESULT_BOX_ADDR:.*]] = fir.alloca !fir.box<!fir.heap<!fir.array<?x!fir.logical<4>>>>
+!CHECK: %[[Z:.*]] = fir.array_load %[[Z_BOX]] : (!fir.box<!fir.array<?x!fir.logical<4>>>) -> !fir.array<?x!fir.logical<4>>
+!CHECK: %[[ZERO_INIT:.*]] = fir.zero_bits !fir.heap<!fir.array<?x!fir.logical<4>>>
+!CHECK: %[[C0:.*]] = arith.constant 0 : index
+!CHECK: %[[RESULT_SHAPE:.*]] = fir.shape %[[C0]] : (index) -> !fir.shape<1>
+!CHECK: %[[RESULT_BOX:.*]] = fir.embox %[[ZERO_INIT]](%[[RESULT_SHAPE]]) : (!fir.heap<!fir.array<?x!fir.logical<4>>>, !fir.shape<1>) -> !fir.box<!fir.heap<!fir.array<?x!fir.logical<4>>>>
+!CHECK: fir.store %[[RESULT_BOX]] to %[[RESULT_BOX_ADDR]] : !fir.ref<!fir.box<!fir.heap<!fir.array<?x!fir.logical<4>>>>>
+!CHECK: %[[RESULT_BOX_RUNTIME:.*]] = fir.convert %[[RESULT_BOX_ADDR]] : (!fir.ref<!fir.box<!fir.heap<!fir.array<?x!fir.logical<4>>>>>) -> !fir.ref<!fir.box<none>>
+!CHECK: %[[X_BOX_RUNTIME:.*]] = fir.convert %[[X_BOX]] : (!fir.box<!fir.array<?x?x!fir.logical<4>>>) -> !fir.box<none>
+!CHECK: %[[Y_BOX_RUNTIME:.*]] = fir.convert %[[Y_BOX]] : (!fir.box<!fir.array<?x!fir.logical<4>>>) -> !fir.box<none>
+!CHECK: {{.*}}fir.call @_FortranAMatmul(%[[RESULT_BOX_RUNTIME]], %[[X_BOX_RUNTIME]], %[[Y_BOX_RUNTIME]], {{.*}}, {{.*}}) : (!fir.ref<!fir.box<none>>, !fir.box<none>, !fir.box<none>, !fir.ref<i8>, i32) -> none
+!CHECK: %[[RESULT_BOX:.*]] = fir.load %[[RESULT_BOX_ADDR]] : !fir.ref<!fir.box<!fir.heap<!fir.array<?x!fir.logical<4>>>>>
+!CHECK: %[[RESULT_TMP:.*]] = fir.box_addr %[[RESULT_BOX]] : (!fir.box<!fir.heap<!fir.array<?x!fir.logical<4>>>>) -> !fir.heap<!fir.array<?x!fir.logical<4>>>
+!CHECK: %[[Z_COPY_FROM_RESULT:.*]] = fir.do_loop
+!CHECK: {{.*}}fir.array_fetch
+!CHECK: {{.*}}fir.array_update
+!CHECK: fir.result
+!CHECK: }
+!CHECK: fir.array_merge_store %[[Z]], %[[Z_COPY_FROM_RESULT]] to %[[Z_BOX]] : !fir.array<?x!fir.logical<4>>, !fir.array<?x!fir.logical<4>>, !fir.box<!fir.array<?x!fir.logical<4>>>
+!CHECK: fir.freemem %[[RESULT_TMP]] : <!fir.array<?x!fir.logical<4>>>
+subroutine matmul_test2(X, Y, Z)
+ logical :: X(:,:)
+ logical :: Y(:)
+ logical :: Z(:)
+ Z = matmul(X, Y) ! rank-2 X times rank-1 Y, assigned to rank-1 Z
+end subroutine
More information about the flang-commits
mailing list