[flang-commits] [flang] 49bd444 - [flang][hlfir] add hlfir.matmul_transpose operation

Tom Eccles via flang-commits flang-commits at lists.llvm.org
Fri Mar 17 02:30:57 PDT 2023


Author: Tom Eccles
Date: 2023-03-17T09:30:04Z
New Revision: 49bd444fc3617a140ef67047d756c4d652a2a835

URL: https://github.com/llvm/llvm-project/commit/49bd444fc3617a140ef67047d756c4d652a2a835
DIFF: https://github.com/llvm/llvm-project/commit/49bd444fc3617a140ef67047d756c4d652a2a835.diff

LOG: [flang][hlfir] add hlfir.matmul_transpose operation

This operation will be used to transform MATMUL(TRANSPOSE(a), b). The
transformation will proceed in the following stages:
        1. Lower to hlfir.transpose and hlfir.matmul
        2. Canonicalise to hlfir.matmul_transpose
        3. Lower hlfir.matmul_transpose to FIR as a call to a new runtime
          library function
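As a rough illustration of stage 2, the C++ sketch below models the rewrite on a hypothetical toy expression tree. The `Expr`, `OpKind`, `canonicalize`, and helper names are invented for this example; this is not the MLIR pattern-rewrite API or the actual HLFIR canonicalization. It fuses a matmul whose left operand is a transpose into a single matmul_transpose node and rebuilds every other node unchanged.

```cpp
#include <memory>
#include <string>
#include <utility>

// Hypothetical toy IR (not MLIR): just enough structure to show the stage-2
// rewrite of matmul(transpose(a), b) into matmul_transpose(a, b).
enum class OpKind { Value, Transpose, Matmul, MatmulTranspose };

struct Expr;
using ExprPtr = std::shared_ptr<Expr>;

struct Expr {
  OpKind kind;
  std::string name; // only used by OpKind::Value leaves
  ExprPtr lhs;      // operand of Transpose, left operand of the matmuls
  ExprPtr rhs;      // right operand of the matmuls, null otherwise
};

ExprPtr makeNode(OpKind kind, ExprPtr lhs, ExprPtr rhs) {
  return std::make_shared<Expr>(Expr{kind, "", std::move(lhs), std::move(rhs)});
}
ExprPtr value(std::string name) {
  return std::make_shared<Expr>(
      Expr{OpKind::Value, std::move(name), nullptr, nullptr});
}
ExprPtr transpose(ExprPtr a) {
  return makeNode(OpKind::Transpose, std::move(a), nullptr);
}
ExprPtr matmul(ExprPtr a, ExprPtr b) {
  return makeNode(OpKind::Matmul, std::move(a), std::move(b));
}

// Rewrites operands first (bottom-up), then fuses matmul(transpose(x), y)
// into matmul_transpose(x, y). Leaves are returned as-is; all other nodes
// are rebuilt with their rewritten operands.
ExprPtr canonicalize(const ExprPtr &e) {
  if (!e || e->kind == OpKind::Value)
    return e;
  ExprPtr lhs = canonicalize(e->lhs);
  ExprPtr rhs = canonicalize(e->rhs);
  if (e->kind == OpKind::Matmul && lhs->kind == OpKind::Transpose)
    return makeNode(OpKind::MatmulTranspose, lhs->lhs, rhs);
  return makeNode(e->kind, lhs, rhs);
}
```

A real rewrite pattern would likely also have to account for details the toy model ignores, such as other uses of the transposed value.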

Step 2 (and this operation) are included for consistency with the other
hlfir intrinsic operations and to avoid mixing concerns in the intrinsic
lowering pass.

In step 3, a new runtime library call is used because this operation is
most easily implemented in one go (the transposed indexing actually
makes the indexing simpler than for a normal matrix multiplication). In
the long run, it is intended that HLFIR will allow the same buffer
to be shared between different runtime calls without temporary
allocations, but in this specific case we can do even better than that
with a dedicated implementation.
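To illustrate why the fused form is convenient, here is a minimal C++ reference sketch of MATMUL(TRANSPOSE(a), b) for real matrices stored in Fortran (column-major) order. It is not the Flang runtime implementation; `Matrix` and `matmulTranspose` are hypothetical names. Each result element result(i, j) is the dot product of column i of a with column j of b, so both operands are read contiguously and TRANSPOSE(a) is never materialized.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical dense real matrix stored in Fortran (column-major) order:
// element (r, c) lives at data[r + c * rows].
struct Matrix {
  std::size_t rows, cols;
  std::vector<double> data;
  Matrix(std::size_t r, std::size_t c) : rows(r), cols(c), data(r * c, 0.0) {}
  double &operator()(std::size_t r, std::size_t c) { return data[r + c * rows]; }
  double operator()(std::size_t r, std::size_t c) const {
    return data[r + c * rows];
  }
};

// MATMUL(TRANSPOSE(a), b) without building TRANSPOSE(a).
// a is k x m, b is k x n, and the result is m x n with
//   result(i, j) = sum over p of a(p, i) * b(p, j).
Matrix matmulTranspose(const Matrix &a, const Matrix &b) {
  // Mirrors the verifier's extent check: the first extents must agree.
  assert(a.rows == b.rows && "first dimension of a must match that of b");
  const std::size_t k = a.rows;
  Matrix result(a.cols, b.cols);
  for (std::size_t j = 0; j < b.cols; ++j) {
    const double *colB = b.data.data() + j * k; // column j of b, contiguous
    for (std::size_t i = 0; i < a.cols; ++i) {
      const double *colA = a.data.data() + i * k; // column i of a, contiguous
      double sum = 0.0;
      for (std::size_t p = 0; p < k; ++p)
        sum += colA[p] * colB[p];
      result(i, j) = sum;
    }
  }
  return result;
}
```

A rank-1 b corresponds to the n = 1 case. Logical arguments, which MATMUL combines with .AND. and .OR. instead of multiplication and addition, are not covered by this sketch.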

This should speed up galgel from SPEC2000, although this has not been
tested yet. The same optimization is implemented in Classic Flang.

Reviewed By: vzakhari

Differential Revision: https://reviews.llvm.org/D145957

Added: 
    flang/test/HLFIR/matmul_transpose.fir

Modified: 
    flang/docs/HighLevelFIR.md
    flang/include/flang/Optimizer/HLFIR/HLFIROps.td
    flang/lib/Optimizer/HLFIR/IR/HLFIROps.cpp
    flang/test/HLFIR/invalid.fir

Removed: 
    


################################################################################
diff  --git a/flang/docs/HighLevelFIR.md b/flang/docs/HighLevelFIR.md
index 033ac4102ba07..8c671ae216678 100644
--- a/flang/docs/HighLevelFIR.md
+++ b/flang/docs/HighLevelFIR.md
@@ -652,7 +652,6 @@ Syntax:
 %element = hlfir.apply %array_expr %i, %j: (hlfir.expr<?x?xi32>) -> i32
 ```
 
-
 #### Introducing operations for transformational intrinsic functions
 
 Motivation: Represent transformational intrinsics functions at a high-level so
@@ -701,6 +700,39 @@ call will probably be used since there is little point to keep them high level:
 - selected_char_kind, selected_int_kind, selected_real_kind that returns scalar
   integers
 
+#### Introducing operations for composed intrinsic functions
+
+Motivation: optimize commonly composed intrinsic functions (e.g.
+MATMUL(TRANSPOSE(a), b)). This optimization is implemented in Classic Flang.
+
+An operation and runtime function will be added for each commonly used
+composition of intrinsic functions. The operation will be the canonical way to
+write this chained operation (the MLIR canonicalization pass will rewrite the
+operations for the composed intrinsics into this one operation).
+
+These new operations will be treated as though they were standard
+transformational intrinsic functions.
+
+The composed intrinsic operation will return a hlfir.expr<T>. The arguments
+may be hlfir.expr<T>, boxed arrays, simple scalar types (e.g. i32, f32), or
+variables.
+
+To keep things simple, these operations will only match one form of the composed
+intrinsic functions: therefore there will be no optional arguments.
+
+Syntax:
+```
+%res = hlfir."intrinsic_name" %expr_or_var, ...
+```
+
+The composed intrinsic operation will be lowered to a `fir.call` to the newly
+added runtime implementation of the operation.
+
+These operations should not be added where the only improvement is to avoid
+creating a temporary intermediate buffer which would otherwise be removed by
+intelligent bufferization of a hlfir.expr. Similarly, these should not replace
+profitable uses of hlfir.elemental.
+
 #### Introducing operations for character operations and elemental intrinsic functions
 
 

diff  --git a/flang/include/flang/Optimizer/HLFIR/HLFIROps.td b/flang/include/flang/Optimizer/HLFIR/HLFIROps.td
index b797cd2cd17e1..86318f64bfe9e 100644
--- a/flang/include/flang/Optimizer/HLFIR/HLFIROps.td
+++ b/flang/include/flang/Optimizer/HLFIR/HLFIROps.td
@@ -402,6 +402,29 @@ def hlfir_TransposeOp : hlfir_Op<"transpose", []> {
   let hasVerifier = 1;
 }
 
+def hlfir_MatmulTransposeOp : hlfir_Op<"matmul_transpose",
+    [DeclareOpInterfaceMethods<ArithFastMathInterface>]> {
+  let summary = "Optimized MATMUL(TRANSPOSE(...), ...)";
+  let description = [{
+    Matrix multiplication where the left hand side is transposed
+  }];
+
+  let arguments = (ins
+    AnyFortranNumericalOrLogicalArrayObject:$lhs,
+    AnyFortranNumericalOrLogicalArrayObject:$rhs,
+    DefaultValuedAttr<Arith_FastMathAttr,
+                      "::mlir::arith::FastMathFlags::none">:$fastmath
+  );
+
+  let results = (outs hlfir_ExprType);
+
+  let assemblyFormat = [{
+    $lhs $rhs attr-dict `:` functional-type(operands, results)
+  }];
+
+  let hasVerifier = 1;
+}
+
 def hlfir_AssociateOp : hlfir_Op<"associate", [AttrSizedOperandSegments,
     DeclareOpInterfaceMethods<fir_FortranVariableOpInterface>]> {
   let summary = "Create a variable from an expression value";

diff  --git a/flang/lib/Optimizer/HLFIR/IR/HLFIROps.cpp b/flang/lib/Optimizer/HLFIR/IR/HLFIROps.cpp
index 13103c50d0dfe..0114c12117e29 100644
--- a/flang/lib/Optimizer/HLFIR/IR/HLFIROps.cpp
+++ b/flang/lib/Optimizer/HLFIR/IR/HLFIROps.cpp
@@ -668,6 +668,71 @@ mlir::LogicalResult hlfir::TransposeOp::verify() {
   return mlir::success();
 }
 
+//===----------------------------------------------------------------------===//
+// MatmulTransposeOp
+//===----------------------------------------------------------------------===//
+
+mlir::LogicalResult hlfir::MatmulTransposeOp::verify() {
+  mlir::Value lhs = getLhs();
+  mlir::Value rhs = getRhs();
+  fir::SequenceType lhsTy =
+      hlfir::getFortranElementOrSequenceType(lhs.getType())
+          .cast<fir::SequenceType>();
+  fir::SequenceType rhsTy =
+      hlfir::getFortranElementOrSequenceType(rhs.getType())
+          .cast<fir::SequenceType>();
+  llvm::ArrayRef<int64_t> lhsShape = lhsTy.getShape();
+  llvm::ArrayRef<int64_t> rhsShape = rhsTy.getShape();
+  std::size_t lhsRank = lhsShape.size();
+  std::size_t rhsRank = rhsShape.size();
+  mlir::Type lhsEleTy = lhsTy.getEleTy();
+  mlir::Type rhsEleTy = rhsTy.getEleTy();
+  hlfir::ExprType resultTy = getResult().getType().cast<hlfir::ExprType>();
+  llvm::ArrayRef<int64_t> resultShape = resultTy.getShape();
+  mlir::Type resultEleTy = resultTy.getEleTy();
+
+  // lhs must have rank 2 for the transpose to be valid
+  if ((lhsRank != 2) || ((rhsRank != 1) && (rhsRank != 2)))
+    return emitOpError("array must have either rank 1 or rank 2");
+
+  if (mlir::isa<fir::LogicalType>(lhsEleTy) !=
+      mlir::isa<fir::LogicalType>(rhsEleTy))
+    return emitOpError("if one array is logical, so should the other be");
+
+  // for matmul we compare the last dimension of lhs with the first dimension of
+  // rhs, but for MatmulTranspose, dimensions of lhs are inverted by the
+  // transpose
+  int64_t firstLhsDim = lhsShape[0];
+  int64_t firstRhsDim = rhsShape[0];
+  constexpr int64_t unknownExtent = fir::SequenceType::getUnknownExtent();
+  if (firstLhsDim != firstRhsDim)
+    if ((firstLhsDim != unknownExtent) && (firstRhsDim != unknownExtent))
+      return emitOpError(
+          "the first dimension of LHS should match the first dimension of RHS");
+
+  if (mlir::isa<fir::LogicalType>(lhsEleTy) !=
+      mlir::isa<fir::LogicalType>(resultEleTy))
+    return emitOpError("the result type should be a logical only if the "
+                       "argument types are logical");
+
+  llvm::SmallVector<int64_t, 2> expectedResultShape;
+  if (rhsRank == 2) {
+    expectedResultShape.push_back(lhsShape[1]);
+    expectedResultShape.push_back(rhsShape[1]);
+  } else {
+    // rhsRank == 1
+    expectedResultShape.push_back(lhsShape[1]);
+  }
+  if (resultShape.size() != expectedResultShape.size())
+    return emitOpError("incorrect result shape");
+  if (resultShape[0] != expectedResultShape[0])
+    return emitOpError("incorrect result shape");
+  if (resultShape.size() == 2 && resultShape[1] != expectedResultShape[1])
+    return emitOpError("incorrect result shape");
+
+  return mlir::success();
+}
+
 //===----------------------------------------------------------------------===//
 // AssociateOp
 //===----------------------------------------------------------------------===//

diff  --git a/flang/test/HLFIR/invalid.fir b/flang/test/HLFIR/invalid.fir
index 2ec7c689b5ea1..a8ba337ad8b5a 100644
--- a/flang/test/HLFIR/invalid.fir
+++ b/flang/test/HLFIR/invalid.fir
@@ -397,6 +397,48 @@ func.func @bad_transpose3(%arg0: !hlfir.expr<2x3xi32>) {
   return
 }
 
+// -----
+func.func @bad_matmultranspose1(%arg0: !hlfir.expr<?x?x?xi32>, %arg1: !hlfir.expr<?x?xi32>) {
+  // expected-error@+1 {{'hlfir.matmul_transpose' op array must have either rank 1 or rank 2}}
+  %0 = hlfir.matmul_transpose %arg0 %arg1 : (!hlfir.expr<?x?x?xi32>, !hlfir.expr<?x?xi32>) -> !hlfir.expr<?x?xi32>
+  return
+}
+
+// -----
+func.func @bad_matmultranspose2(%arg0: !hlfir.expr<?xi32>, %arg1: !hlfir.expr<?xi32>) {
+  // expected-error@+1 {{'hlfir.matmul_transpose' op array must have either rank 1 or rank 2}}
+  %0 = hlfir.matmul_transpose %arg0 %arg1 : (!hlfir.expr<?xi32>, !hlfir.expr<?xi32>) -> !hlfir.expr<?x?xi32>
+  return
+}
+
+// -----
+func.func @bad_matmultranspose3(%arg0: !hlfir.expr<?x?x!fir.logical<4>>, %arg1: !hlfir.expr<?x?xi32>) {
+  // expected-error@+1 {{'hlfir.matmul_transpose' op if one array is logical, so should the other be}}
+  %0 = hlfir.matmul_transpose %arg0 %arg1 : (!hlfir.expr<?x?x!fir.logical<4>>, !hlfir.expr<?x?xi32>) -> !hlfir.expr<?x?xi32>
+  return
+}
+
+// -----
+func.func @bad_matmultranspose5(%arg0: !hlfir.expr<?x?xi32>, %arg1: !hlfir.expr<?x?xi32>) {
+  // expected-error@+1 {{'hlfir.matmul_transpose' op the result type should be a logical only if the argument types are logical}}
+  %0 = hlfir.matmul_transpose %arg0 %arg1 : (!hlfir.expr<?x?xi32>, !hlfir.expr<?x?xi32>) -> !hlfir.expr<?x?x!fir.logical<4>>
+  return
+}
+
+// -----
+func.func @bad_matmultranspose6(%arg0: !hlfir.expr<2x1xi32>, %arg1: !hlfir.expr<2x3xi32>) {
+  // expected-error@+1 {{'hlfir.matmul_transpose' op incorrect result shape}}
+  %0 = hlfir.matmul_transpose %arg0 %arg1 : (!hlfir.expr<2x1xi32>, !hlfir.expr<2x3xi32>) -> !hlfir.expr<10x30xi32>
+  return
+}
+
+// -----
+func.func @bad_matmultranspose7(%arg0: !hlfir.expr<2x1xi32>, %arg1: !hlfir.expr<2xi32>) {
+  // expected-error@+1 {{'hlfir.matmul_transpose' op incorrect result shape}}
+  %0 = hlfir.matmul_transpose %arg0 %arg1 : (!hlfir.expr<2x1xi32>, !hlfir.expr<2xi32>) -> !hlfir.expr<1x3xi32>
+  return
+}
+
 // -----
 func.func @bad_assign_1(%arg0: !fir.box<!fir.array<?xi32>>, %arg1: !fir.box<!fir.array<?xi32>>) {
   // expected-error@+1 {{'hlfir.assign' op lhs must be an allocatable when `realloc` is set}}

diff  --git a/flang/test/HLFIR/matmul_transpose.fir b/flang/test/HLFIR/matmul_transpose.fir
new file mode 100644
index 0000000000000..967edecb463b2
--- /dev/null
+++ b/flang/test/HLFIR/matmul_transpose.fir
@@ -0,0 +1,87 @@
+// Test hlfir.matmul_transpose operation parse, verify (no errors), and unparse
+
+// RUN: fir-opt %s | fir-opt | FileCheck %s
+
+// arguments are expressions of known shape
+func.func @matmul_transpose0(%arg0: !hlfir.expr<2x2xi32>, %arg1: !hlfir.expr<2x2xi32>) {
+  %res = hlfir.matmul_transpose %arg0 %arg1 : (!hlfir.expr<2x2xi32>, !hlfir.expr<2x2xi32>) -> !hlfir.expr<2x2xi32>
+  return
+}
+// CHECK-LABEL: func.func @matmul_transpose0
+// CHECK:           %[[ARG0:.*]]: !hlfir.expr<2x2xi32>,
+// CHECK:           %[[ARG1:.*]]: !hlfir.expr<2x2xi32>) {
+// CHECK-NEXT:    %[[RES:.*]] = hlfir.matmul_transpose %[[ARG0]] %[[ARG1]] : (!hlfir.expr<2x2xi32>, !hlfir.expr<2x2xi32>) -> !hlfir.expr<2x2xi32>
+// CHECK-NEXT:    return
+// CHECK-NEXT:  }
+
+// arguments are expressions of assumed shape
+func.func @matmul_transpose1(%arg0: !hlfir.expr<?x?xi32>, %arg1: !hlfir.expr<?x?xi32>) {
+  %res = hlfir.matmul_transpose %arg0 %arg1 : (!hlfir.expr<?x?xi32>, !hlfir.expr<?x?xi32>) -> !hlfir.expr<?x?xi32>
+  return
+}
+// CHECK-LABEL: func.func @matmul_transpose1
+// CHECK:           %[[ARG0:.*]]: !hlfir.expr<?x?xi32>,
+// CHECK:           %[[ARG1:.*]]: !hlfir.expr<?x?xi32>) {
+// CHECK-NEXT:    %[[RES:.*]] = hlfir.matmul_transpose %[[ARG0]] %[[ARG1]] : (!hlfir.expr<?x?xi32>, !hlfir.expr<?x?xi32>) -> !hlfir.expr<?x?xi32>
+// CHECK-NEXT:    return
+// CHECK-NEXT:  }
+
+// arguments are expressions where only some dimensions are known #1
+func.func @matmul_transpose2(%arg0: !hlfir.expr<?x2xi32>, %arg1: !hlfir.expr<?x2xi32>) {
+  %res = hlfir.matmul_transpose %arg0 %arg1 : (!hlfir.expr<?x2xi32>, !hlfir.expr<?x2xi32>) -> !hlfir.expr<2x2xi32>
+  return
+}
+// CHECK-LABEL: func.func @matmul_transpose2
+// CHECK:           %[[ARG0:.*]]: !hlfir.expr<?x2xi32>,
+// CHECK:           %[[ARG1:.*]]: !hlfir.expr<?x2xi32>) {
+// CHECK-NEXT:    %[[RES:.*]] = hlfir.matmul_transpose %[[ARG0]] %[[ARG1]] : (!hlfir.expr<?x2xi32>, !hlfir.expr<?x2xi32>) -> !hlfir.expr<2x2xi32>
+// CHECK-NEXT:    return
+// CHECK-NEXT:  }
+
+// arguments are expressions where only some dimensions are known #2
+func.func @matmul_transpose3(%arg0: !hlfir.expr<2x?xi32>, %arg1: !hlfir.expr<2x?xi32>) {
+  %res = hlfir.matmul_transpose %arg0 %arg1 : (!hlfir.expr<2x?xi32>, !hlfir.expr<2x?xi32>) -> !hlfir.expr<?x?xi32>
+  return
+}
+// CHECK-LABEL: func.func @matmul_transpose3
+// CHECK:           %[[ARG0:.*]]: !hlfir.expr<2x?xi32>,
+// CHECK:           %[[ARG1:.*]]: !hlfir.expr<2x?xi32>) {
+// CHECK-NEXT:    %[[RES:.*]] = hlfir.matmul_transpose %[[ARG0]] %[[ARG1]] : (!hlfir.expr<2x?xi32>, !hlfir.expr<2x?xi32>) -> !hlfir.expr<?x?xi32>
+// CHECK-NEXT:    return
+// CHECK-NEXT:  }
+
+// arguments are logicals
+func.func @matmul_transpose4(%arg0: !hlfir.expr<?x?x!fir.logical<4>>, %arg1: !hlfir.expr<?x?x!fir.logical<4>>) {
+  %res = hlfir.matmul_transpose %arg0 %arg1 : (!hlfir.expr<?x?x!fir.logical<4>>, !hlfir.expr<?x?x!fir.logical<4>>) -> !hlfir.expr<?x?x!fir.logical<4>>
+  return
+}
+// CHECK-LABEL: func.func @matmul_transpose4
+// CHECK:           %[[ARG0:.*]]: !hlfir.expr<?x?x!fir.logical<4>>,
+// CHECK:           %[[ARG1:.*]]: !hlfir.expr<?x?x!fir.logical<4>>) {
+// CHECK-NEXT:    %[[RES:.*]] = hlfir.matmul_transpose %[[ARG0]] %[[ARG1]] : (!hlfir.expr<?x?x!fir.logical<4>>, !hlfir.expr<?x?x!fir.logical<4>>) -> !hlfir.expr<?x?x!fir.logical<4>>
+// CHECK-NEXT:    return
+// CHECK-NEXT:  }
+
+// rhs is rank 1
+func.func @matmul_transpose6(%arg0: !hlfir.expr<?x?xi32>, %arg1: !hlfir.expr<?xi32>) {
+  %res = hlfir.matmul_transpose %arg0 %arg1 : (!hlfir.expr<?x?xi32>, !hlfir.expr<?xi32>) -> !hlfir.expr<?xi32>
+  return
+}
+// CHECK-LABEL: func.func @matmul_transpose6
+// CHECK:           %[[ARG0:.*]]: !hlfir.expr<?x?xi32>,
+// CHECK:           %[[ARG1:.*]]: !hlfir.expr<?xi32>) {
+// CHECK-NEXT:    %[[RES:.*]] = hlfir.matmul_transpose %[[ARG0]] %[[ARG1]] : (!hlfir.expr<?x?xi32>, !hlfir.expr<?xi32>) -> !hlfir.expr<?xi32>
+// CHECK-NEXT:    return
+// CHECK-NEXT:  }
+
+// arguments are boxed arrays
+func.func @matmul_transpose7(%arg0: !fir.box<!fir.array<2x2xf32>>, %arg1: !fir.box<!fir.array<2x2xf32>>) {
+  %res = hlfir.matmul_transpose %arg0 %arg1 : (!fir.box<!fir.array<2x2xf32>>, !fir.box<!fir.array<2x2xf32>>) -> !hlfir.expr<2x2xf32>
+  return
+}
+// CHECK-LABEL: func.func @matmul_transpose7
+// CHECK:           %[[ARG0:.*]]: !fir.box<!fir.array<2x2xf32>>,
+// CHECK:           %[[ARG1:.*]]: !fir.box<!fir.array<2x2xf32>>) {
+// CHECK-NEXT:    %[[RES:.*]] = hlfir.matmul_transpose %[[ARG0]] %[[ARG1]] : (!fir.box<!fir.array<2x2xf32>>, !fir.box<!fir.array<2x2xf32>>) -> !hlfir.expr<2x2xf32>
+// CHECK-NEXT:    return
+// CHECK-NEXT:  }


        

