[flang-commits] [flang] 5be0f0c - [Flang] Lower Matmul intrinsic

Kiran Chandramohan via flang-commits flang-commits at lists.llvm.org
Thu Mar 17 05:33:02 PDT 2022


Author: Kiran Chandramohan
Date: 2022-03-17T12:31:03Z
New Revision: 5be0f0c83d39310d1c6c078ec5b0731c48d34080

URL: https://github.com/llvm/llvm-project/commit/5be0f0c83d39310d1c6c078ec5b0731c48d34080
DIFF: https://github.com/llvm/llvm-project/commit/5be0f0c83d39310d1c6c078ec5b0731c48d34080.diff

LOG: [Flang] Lower Matmul intrinsic

The MATMUL intrinsic performs matrix multiplication on arrays of rank 1 or 2;
the result has rank 1 when either argument has rank 1.
The intrinsic is lowered to a runtime call.

This is part of the upstreaming effort from the fir-dev branch in [1].
[1] https://github.com/flang-compiler/f18-llvm-project

Reviewed By: clementval

Differential Revision: https://reviews.llvm.org/D121904

Co-authored-by: Jean Perier <jperier at nvidia.com>
Co-authored-by: Valentin Clement <clementval at gmail.com>

Added: 
    flang/test/Lower/Intrinsics/matmul.f90

Modified: 
    flang/lib/Lower/IntrinsicCall.cpp

Removed: 
    


################################################################################
diff  --git a/flang/lib/Lower/IntrinsicCall.cpp b/flang/lib/Lower/IntrinsicCall.cpp
index 8039cef87a625..b7c02e2c2cacf 100644
--- a/flang/lib/Lower/IntrinsicCall.cpp
+++ b/flang/lib/Lower/IntrinsicCall.cpp
@@ -477,6 +477,7 @@ struct IntrinsicLibrary {
   fir::ExtendedValue genLbound(mlir::Type, llvm::ArrayRef<fir::ExtendedValue>);
   fir::ExtendedValue genLen(mlir::Type, llvm::ArrayRef<fir::ExtendedValue>);
   fir::ExtendedValue genLenTrim(mlir::Type, llvm::ArrayRef<fir::ExtendedValue>);
+  fir::ExtendedValue genMatmul(mlir::Type, llvm::ArrayRef<fir::ExtendedValue>);
   fir::ExtendedValue genMaxloc(mlir::Type, llvm::ArrayRef<fir::ExtendedValue>);
   fir::ExtendedValue genMaxval(mlir::Type, llvm::ArrayRef<fir::ExtendedValue>);
   fir::ExtendedValue genMinloc(mlir::Type, llvm::ArrayRef<fir::ExtendedValue>);
@@ -692,6 +693,10 @@ static constexpr IntrinsicHandler handlers[]{
     {"lgt", &I::genCharacterCompare<mlir::arith::CmpIPredicate::sgt>},
     {"lle", &I::genCharacterCompare<mlir::arith::CmpIPredicate::sle>},
     {"llt", &I::genCharacterCompare<mlir::arith::CmpIPredicate::slt>},
+    {"matmul",
+     &I::genMatmul,
+     {{{"matrix_a", asAddr}, {"matrix_b", asAddr}}},
+     /*isElemental=*/false},
     {"max", &I::genExtremum<Extremum::Max, ExtremumBehavior::MinMaxss>},
     {"maxloc",
      &I::genMaxloc,
@@ -2374,6 +2379,34 @@ IntrinsicLibrary::genCharacterCompare(mlir::Type type,
       fir::getBase(args[1]), fir::getLen(args[1]));
 }
 
+// MATMUL: box both operands and call the runtime, which allocates the result.
+fir::ExtendedValue
+IntrinsicLibrary::genMatmul(mlir::Type resultType,
+                            llvm::ArrayRef<fir::ExtendedValue> args) {
+  assert(args.size() == 2);
+
+  // Box the two required arguments; a rank-1 operand yields a rank-1 result.
+  fir::BoxValue matrixTmpA = builder.createBox(loc, args[0]);
+  mlir::Value matrixA = fir::getBase(matrixTmpA);
+  fir::BoxValue matrixTmpB = builder.createBox(loc, args[1]);
+  mlir::Value matrixB = fir::getBase(matrixTmpB);
+  unsigned resultRank =
+      (matrixTmpA.rank() == 1 || matrixTmpB.rank() == 1) ? 1 : 2;
+
+  // Create a temporary mutable fir.box of rank resultRank for the result.
+  mlir::Type resultArrayType = builder.getVarLenSeqTy(resultType, resultRank);
+  fir::MutableBoxValue resultMutableBox =
+      fir::factory::createTempMutableBox(builder, loc, resultArrayType);
+  mlir::Value resultIrBox =
+      fir::factory::getMutableIRBox(builder, loc, resultMutableBox);
+  // Call the runtime with the result box; the runtime allocates the result.
+  fir::runtime::genMatmul(builder, loc, resultIrBox, matrixA, matrixB);
+  // Read the result back from the mutable fir.box and register it as a temp
+  // to be finalized by the StatementContext.
+  return readAndAddCleanUp(resultMutableBox, resultType,
+                           "unexpected result for MATMUL");
+}
+
 // Compare two FIR values and return boolean result as i1.
 template <Extremum extremum, ExtremumBehavior behavior>
 static mlir::Value createExtremumCompare(mlir::Location loc,

diff  --git a/flang/test/Lower/Intrinsics/matmul.f90 b/flang/test/Lower/Intrinsics/matmul.f90
new file mode 100644
index 0000000000000..6c3c721063c9f
--- /dev/null
+++ b/flang/test/Lower/Intrinsics/matmul.f90
@@ -0,0 +1,68 @@
+! RUN: bbc -emit-fir %s -o - | FileCheck %s
+! RUN: %flang_fc1 -emit-fir %s -o - | FileCheck %s
+
+! Test matmul intrinsic
+
+! CHECK-LABEL: matmul_test
+! CHECK-SAME: (%[[X:.*]]: !fir.ref<!fir.array<3x1xf32>>{{.*}}, %[[Y:.*]]: !fir.ref<!fir.array<1x3xf32>>{{.*}}, %[[Z:.*]]: !fir.ref<!fir.array<2x2xf32>>{{.*}})
+! CHECK:  %[[RESULT_BOX_ADDR:.*]] = fir.alloca !fir.box<!fir.heap<!fir.array<?x?xf32>>>
+! CHECK:  %[[C3:.*]] = arith.constant 3 : index
+! CHECK:  %[[C1:.*]] = arith.constant 1 : index
+! CHECK:  %[[C1_0:.*]] = arith.constant 1 : index
+! CHECK:  %[[C3_1:.*]] = arith.constant 3 : index
+! CHECK:  %[[Z_BOX:.*]] = fir.array_load %[[Z]]({{.*}}) : (!fir.ref<!fir.array<2x2xf32>>, !fir.shape<2>) -> !fir.array<2x2xf32>
+! CHECK:  %[[X_SHAPE:.*]] = fir.shape %[[C3]], %[[C1]] : (index, index) -> !fir.shape<2>
+! CHECK:  %[[X_BOX:.*]] = fir.embox %[[X]](%[[X_SHAPE]]) : (!fir.ref<!fir.array<3x1xf32>>, !fir.shape<2>) -> !fir.box<!fir.array<3x1xf32>>
+! CHECK:  %[[Y_SHAPE:.*]] = fir.shape %[[C1_0]], %[[C3_1]] : (index, index) -> !fir.shape<2>
+! CHECK:  %[[Y_BOX:.*]] = fir.embox %[[Y]](%[[Y_SHAPE]]) : (!fir.ref<!fir.array<1x3xf32>>, !fir.shape<2>) -> !fir.box<!fir.array<1x3xf32>>
+! CHECK:  %[[ZERO_INIT:.*]] = fir.zero_bits !fir.heap<!fir.array<?x?xf32>>
+! CHECK:  %[[C0:.*]] = arith.constant 0 : index
+! CHECK:  %[[RESULT_SHAPE:.*]] = fir.shape %[[C0]], %[[C0]] : (index, index) -> !fir.shape<2>
+! CHECK:  %[[RESULT_BOX_VAL:.*]] = fir.embox %[[ZERO_INIT]](%[[RESULT_SHAPE]]) : (!fir.heap<!fir.array<?x?xf32>>, !fir.shape<2>) -> !fir.box<!fir.heap<!fir.array<?x?xf32>>>
+! CHECK:  fir.store %[[RESULT_BOX_VAL]] to %[[RESULT_BOX_ADDR]] : !fir.ref<!fir.box<!fir.heap<!fir.array<?x?xf32>>>>
+! CHECK:  %[[RESULT_BOX_ADDR_RUNTIME:.*]] = fir.convert %[[RESULT_BOX_ADDR]] : (!fir.ref<!fir.box<!fir.heap<!fir.array<?x?xf32>>>>) -> !fir.ref<!fir.box<none>>
+! CHECK:  %[[X_BOX_RUNTIME:.*]] = fir.convert %[[X_BOX]] : (!fir.box<!fir.array<3x1xf32>>) -> !fir.box<none>
+! CHECK:  %[[Y_BOX_RUNTIME:.*]] = fir.convert %[[Y_BOX]] : (!fir.box<!fir.array<1x3xf32>>) -> !fir.box<none>
+! CHECK:  {{.*}}fir.call @_FortranAMatmul(%[[RESULT_BOX_ADDR_RUNTIME]], %[[X_BOX_RUNTIME]], %[[Y_BOX_RUNTIME]], {{.*}}, {{.*}}) : (!fir.ref<!fir.box<none>>, !fir.box<none>, !fir.box<none>, !fir.ref<i8>, i32) -> none
+! CHECK:  %[[RESULT_BOX:.*]] = fir.load %[[RESULT_BOX_ADDR]] : !fir.ref<!fir.box<!fir.heap<!fir.array<?x?xf32>>>>
+! CHECK:  %[[RESULT_TMP:.*]] = fir.box_addr %[[RESULT_BOX]] : (!fir.box<!fir.heap<!fir.array<?x?xf32>>>) -> !fir.heap<!fir.array<?x?xf32>>
+! CHECK:  %[[Z_COPY_FROM_RESULT:.*]] = fir.do_loop
+! CHECK:    {{.*}}fir.array_fetch
+! CHECK:    {{.*}}fir.array_update
+! CHECK:    fir.result
+! CHECK:  }
+! CHECK:  fir.array_merge_store %[[Z_BOX]], %[[Z_COPY_FROM_RESULT]] to %[[Z]] : !fir.array<2x2xf32>, !fir.array<2x2xf32>, !fir.ref<!fir.array<2x2xf32>>
+! CHECK:  fir.freemem %[[RESULT_TMP]] : <!fir.array<?x?xf32>>
+subroutine matmul_test(x,y,z)
+  real :: x(3,1), y(1,3), z(2,2)
+  z = matmul(x,y)
+end subroutine
+
+! CHECK-LABEL: matmul_test2
+! CHECK-SAME: (%[[X_BOX:.*]]: !fir.box<!fir.array<?x?x!fir.logical<4>>>{{.*}}, %[[Y_BOX:.*]]: !fir.box<!fir.array<?x!fir.logical<4>>>{{.*}}, %[[Z_BOX:.*]]: !fir.box<!fir.array<?x!fir.logical<4>>>{{.*}})
+!CHECK:  %[[RESULT_BOX_ADDR:.*]] = fir.alloca !fir.box<!fir.heap<!fir.array<?x!fir.logical<4>>>>
+!CHECK:  %[[Z:.*]] = fir.array_load %[[Z_BOX]] : (!fir.box<!fir.array<?x!fir.logical<4>>>) -> !fir.array<?x!fir.logical<4>>
+!CHECK:  %[[ZERO_INIT:.*]] = fir.zero_bits !fir.heap<!fir.array<?x!fir.logical<4>>>
+!CHECK:  %[[C0:.*]] = arith.constant 0 : index
+!CHECK:  %[[RESULT_SHAPE:.*]] = fir.shape %[[C0]] : (index) -> !fir.shape<1>
+!CHECK:  %[[RESULT_BOX:.*]] = fir.embox %[[ZERO_INIT]](%[[RESULT_SHAPE]]) : (!fir.heap<!fir.array<?x!fir.logical<4>>>, !fir.shape<1>) -> !fir.box<!fir.heap<!fir.array<?x!fir.logical<4>>>>
+!CHECK:  fir.store %[[RESULT_BOX]] to %[[RESULT_BOX_ADDR]] : !fir.ref<!fir.box<!fir.heap<!fir.array<?x!fir.logical<4>>>>>
+!CHECK:  %[[RESULT_BOX_RUNTIME:.*]] = fir.convert %[[RESULT_BOX_ADDR]] : (!fir.ref<!fir.box<!fir.heap<!fir.array<?x!fir.logical<4>>>>>) -> !fir.ref<!fir.box<none>>
+!CHECK:  %[[X_BOX_RUNTIME:.*]] = fir.convert %[[X_BOX]] : (!fir.box<!fir.array<?x?x!fir.logical<4>>>) -> !fir.box<none>
+!CHECK:  %[[Y_BOX_RUNTIME:.*]] = fir.convert %[[Y_BOX]] : (!fir.box<!fir.array<?x!fir.logical<4>>>) -> !fir.box<none>
+!CHECK:  {{.*}}fir.call @_FortranAMatmul(%[[RESULT_BOX_RUNTIME]], %[[X_BOX_RUNTIME]], %[[Y_BOX_RUNTIME]], {{.*}}, {{.*}}) : (!fir.ref<!fir.box<none>>, !fir.box<none>, !fir.box<none>, !fir.ref<i8>, i32) -> none
+!CHECK:  %[[RESULT_BOX:.*]] = fir.load %[[RESULT_BOX_ADDR]] : !fir.ref<!fir.box<!fir.heap<!fir.array<?x!fir.logical<4>>>>>
+!CHECK:  %[[RESULT_TMP:.*]] = fir.box_addr %[[RESULT_BOX]] : (!fir.box<!fir.heap<!fir.array<?x!fir.logical<4>>>>) -> !fir.heap<!fir.array<?x!fir.logical<4>>>
+!CHECK:  %[[Z_COPY_FROM_RESULT:.*]] = fir.do_loop
+!CHECK:    {{.*}}fir.array_fetch
+!CHECK:    {{.*}}fir.array_update
+!CHECK:    fir.result
+!CHECK:  }
+!CHECK:  fir.array_merge_store %[[Z]], %[[Z_COPY_FROM_RESULT]] to %[[Z_BOX]] : !fir.array<?x!fir.logical<4>>, !fir.array<?x!fir.logical<4>>, !fir.box<!fir.array<?x!fir.logical<4>>>
+!CHECK:  fir.freemem %[[RESULT_TMP]] : <!fir.array<?x!fir.logical<4>>>
+subroutine matmul_test2(X, Y, Z)
+  logical :: X(:,:)
+  logical :: Y(:)
+  logical :: Z(:)
+  Z = matmul(X, Y)
+end subroutine


        


More information about the flang-commits mailing list