[flang-commits] [flang] 09472ba - [flang] add hlfir.matmul operation
Tom Eccles via flang-commits
flang-commits at lists.llvm.org
Thu Feb 16 07:32:12 PST 2023
Author: Tom Eccles
Date: 2023-02-16T15:30:46Z
New Revision: 09472ba315043748db01b092fdaca54daf19a513
URL: https://github.com/llvm/llvm-project/commit/09472ba315043748db01b092fdaca54daf19a513
DIFF: https://github.com/llvm/llvm-project/commit/09472ba315043748db01b092fdaca54daf19a513.diff
LOG: [flang] add hlfir.matmul operation
Add an HLFIR operation for the MATMUL transformational intrinsic,
according to the design set out in flang/docs/HighLevelFIR.md
Differential Revision: https://reviews.llvm.org/D144094
Added:
flang/test/HLFIR/matmul.fir
Modified:
flang/include/flang/Optimizer/HLFIR/HLFIRDialect.h
flang/include/flang/Optimizer/HLFIR/HLFIROpBase.td
flang/include/flang/Optimizer/HLFIR/HLFIROps.td
flang/lib/Optimizer/HLFIR/IR/HLFIRDialect.cpp
flang/lib/Optimizer/HLFIR/IR/HLFIROps.cpp
flang/test/HLFIR/invalid.fir
Removed:
################################################################################
diff --git a/flang/include/flang/Optimizer/HLFIR/HLFIRDialect.h b/flang/include/flang/Optimizer/HLFIR/HLFIRDialect.h
index 6a9acb443f9d8..a0bcd33f57e06 100644
--- a/flang/include/flang/Optimizer/HLFIR/HLFIRDialect.h
+++ b/flang/include/flang/Optimizer/HLFIR/HLFIRDialect.h
@@ -71,6 +71,7 @@ inline bool isBoxAddressOrValueType(mlir::Type type) {
bool isFortranScalarNumericalType(mlir::Type);
bool isFortranNumericalArrayObject(mlir::Type);
+// True if the type is not a box address type and its Fortran
+// element-or-sequence type is an array whose elements are numerical or
+// logical (accepted forms include hlfir.expr and fir.box of such arrays).
+bool isFortranNumericalOrLogicalArrayObject(mlir::Type);
bool isPassByRefOrIntegerType(mlir::Type);
bool isI1Type(mlir::Type);
// scalar i1 or logical, or sequence of logical (via (boxed?) array or expr)
diff --git a/flang/include/flang/Optimizer/HLFIR/HLFIROpBase.td b/flang/include/flang/Optimizer/HLFIR/HLFIROpBase.td
index eca5f05d3d02d..e210de3ce33eb 100644
--- a/flang/include/flang/Optimizer/HLFIR/HLFIROpBase.td
+++ b/flang/include/flang/Optimizer/HLFIR/HLFIROpBase.td
@@ -113,6 +113,11 @@ def IsFortranNumericalArrayObjectPred
def AnyFortranNumericalArrayObject : Type<IsFortranNumericalArrayObjectPred,
"any array-like object containing a numerical type">;
+// Type constraint for array-like objects whose elements are numerical or
+// logical, as classified by hlfir::isFortranNumericalOrLogicalArrayObject.
+// hlfir.matmul uses it for both of its array operands.
+def IsFortranNumericalOrLogicalArrayObjectPred
+    : CPred<"::hlfir::isFortranNumericalOrLogicalArrayObject($_self)">;
+def AnyFortranNumericalOrLogicalArrayObject
+    : Type<IsFortranNumericalOrLogicalArrayObjectPred,
+           "any array-like object containing a numerical or logical type">;
+
def IsPassByRefOrIntegerTypePred
: CPred<"::hlfir::isPassByRefOrIntegerType($_self)">;
def AnyPassByRefOrIntegerType : Type<IsPassByRefOrIntegerTypePred,
diff --git a/flang/include/flang/Optimizer/HLFIR/HLFIROps.td b/flang/include/flang/Optimizer/HLFIR/HLFIROps.td
index d929850d9ff28..9f187ac6afca8 100644
--- a/flang/include/flang/Optimizer/HLFIR/HLFIROps.td
+++ b/flang/include/flang/Optimizer/HLFIR/HLFIROps.td
@@ -289,6 +289,33 @@ def hlfir_SumOp : hlfir_Op<"sum", [AttrSizedOperandSegments,
let hasVerifier = 1;
}
+def hlfir_MatmulOp : hlfir_Op<"matmul",
+    [DeclareOpInterfaceMethods<ArithFastMathInterface>]> {
+  let summary = "MATMUL transformational intrinsic";
+  let description = [{
+    Matrix multiplication, as computed by the Fortran MATMUL intrinsic.
+
+    Both operands are array-like objects with numerical or logical elements.
+    Each has rank 1 or rank 2, and at least one of them has rank 2. The
+    operands are either both logical or both non-logical, and the result
+    element type is logical exactly when the operands are.
+
+    The last extent of `lhs` must match the first extent of `rhs` when both
+    are statically known. The result shape follows from the operand shapes:
+
+      * (n, m) x (m, k) -> (n, k)
+      * (n, m) x (m)    -> (n)
+      * (m)    x (m, k) -> (k)
+
+    The optional `fastmath` attribute (default `none`) holds arith fast-math
+    flags for the operation.
+  }];
+
+  let arguments = (ins
+    AnyFortranNumericalOrLogicalArrayObject:$lhs,
+    AnyFortranNumericalOrLogicalArrayObject:$rhs,
+    DefaultValuedAttr<Arith_FastMathAttr,
+                      "::mlir::arith::FastMathFlags::none">:$fastmath
+  );
+
+  let results = (outs hlfir_ExprType);
+
+  let assemblyFormat = [{
+    $lhs $rhs attr-dict `:` functional-type(operands, results)
+  }];
+
+  // NOTE(review): this change adds no C++ definition for this custom builder
+  // (HLFIROps.cpp only gains MatmulOp::verify). ODS emits just the
+  // declaration for a body-less OpBuilder, so confirm a definition exists
+  // before calling it.
+  let builders = [OpBuilder<(ins "mlir::Value":$lhs,
+                                 "mlir::Value":$rhs,
+                                 "mlir::Type":$resultType)>];
+
+  let hasVerifier = 1;
+}
+
def hlfir_AssociateOp : hlfir_Op<"associate", [AttrSizedOperandSegments,
DeclareOpInterfaceMethods<fir_FortranVariableOpInterface>]> {
let summary = "Create a variable from an expression value";
diff --git a/flang/lib/Optimizer/HLFIR/IR/HLFIRDialect.cpp b/flang/lib/Optimizer/HLFIR/IR/HLFIRDialect.cpp
index f23be5de3be14..3dd757f993848 100644
--- a/flang/lib/Optimizer/HLFIR/IR/HLFIRDialect.cpp
+++ b/flang/lib/Optimizer/HLFIR/IR/HLFIRDialect.cpp
@@ -115,6 +115,18 @@ bool hlfir::isFortranNumericalArrayObject(mlir::Type type) {
return false;
}
+/// Return true if \p type is not a box address type and its Fortran
+/// element-or-sequence type (from getFortranElementOrSequenceType) is a
+/// fir::SequenceType whose element type is either a scalar numerical type or
+/// fir.logical.
+bool hlfir::isFortranNumericalOrLogicalArrayObject(mlir::Type type) {
+  if (isBoxAddressType(type))
+    return false;
+  auto seqTy =
+      getFortranElementOrSequenceType(type).dyn_cast<fir::SequenceType>();
+  if (!seqTy)
+    return false;
+  mlir::Type elementType = seqTy.getEleTy();
+  return mlir::isa<fir::LogicalType>(elementType) ||
+         isFortranScalarNumericalType(elementType);
+}
+
bool hlfir::isPassByRefOrIntegerType(mlir::Type type) {
mlir::Type unwrappedType = fir::unwrapPassByRefType(type);
return fir::isa_integer(unwrappedType);
diff --git a/flang/lib/Optimizer/HLFIR/IR/HLFIROps.cpp b/flang/lib/Optimizer/HLFIR/IR/HLFIROps.cpp
index 39b36e7c4a3a1..5fdc909f7500d 100644
--- a/flang/lib/Optimizer/HLFIR/IR/HLFIROps.cpp
+++ b/flang/lib/Optimizer/HLFIR/IR/HLFIROps.cpp
@@ -510,6 +510,76 @@ void hlfir::SumOp::build(mlir::OpBuilder &builder, mlir::OperationState &result,
build(builder, result, resultType, array, dim, mask);
}
+//===----------------------------------------------------------------------===//
+// MatmulOp
+//===----------------------------------------------------------------------===//
+
+/// Verify the hlfir.matmul constraints that the operand type constraints do
+/// not express on their own:
+///  - both operands have rank 1 or 2, and at least one has rank 2;
+///  - the operands are either both logical or both non-logical;
+///  - the last extent of LHS matches the first extent of RHS;
+///  - the result element type is logical iff the operands are logical;
+///  - the result has the rank and extents implied by the operand shapes.
+/// Extents are only compared when both sides are statically known: an unknown
+/// extent (fir::SequenceType::getUnknownExtent()) is compatible with any
+/// other extent, for the contracted dimension and the result shape alike.
+mlir::LogicalResult hlfir::MatmulOp::verify() {
+  // The ODS constraint AnyFortranNumericalOrLogicalArrayObject, checked before
+  // this custom verifier runs, only accepts types whose Fortran
+  // element-or-sequence type is a fir::SequenceType, so these casts are safe.
+  fir::SequenceType lhsTy =
+      hlfir::getFortranElementOrSequenceType(getLhs().getType())
+          .cast<fir::SequenceType>();
+  fir::SequenceType rhsTy =
+      hlfir::getFortranElementOrSequenceType(getRhs().getType())
+          .cast<fir::SequenceType>();
+  llvm::ArrayRef<int64_t> lhsShape = lhsTy.getShape();
+  llvm::ArrayRef<int64_t> rhsShape = rhsTy.getShape();
+  std::size_t lhsRank = lhsShape.size();
+  std::size_t rhsRank = rhsShape.size();
+  mlir::Type lhsEleTy = lhsTy.getEleTy();
+  mlir::Type rhsEleTy = rhsTy.getEleTy();
+  hlfir::ExprType resultTy = getResult().getType().cast<hlfir::ExprType>();
+  llvm::ArrayRef<int64_t> resultShape = resultTy.getShape();
+  mlir::Type resultEleTy = resultTy.getEleTy();
+
+  if (((lhsRank != 1) && (lhsRank != 2)) || ((rhsRank != 1) && (rhsRank != 2)))
+    return emitOpError("array must have either rank 1 or rank 2");
+
+  if ((lhsRank == 1) && (rhsRank == 1))
+    return emitOpError("at least one array must have rank 2");
+
+  bool lhsIsLogical = mlir::isa<fir::LogicalType>(lhsEleTy);
+  if (lhsIsLogical != mlir::isa<fir::LogicalType>(rhsEleTy))
+    return emitOpError("if one array is logical, so should the other be");
+
+  constexpr int64_t unknownExtent = fir::SequenceType::getUnknownExtent();
+  // Two extents are provably incompatible only when both are known and differ.
+  auto extentsConflict = [&](int64_t a, int64_t b) {
+    return a != b && a != unknownExtent && b != unknownExtent;
+  };
+
+  // Ranks were validated above, so both indices are in bounds.
+  if (extentsConflict(lhsShape[lhsRank - 1], rhsShape[0]))
+    return emitOpError(
+        "the last dimension of LHS should match the first dimension of RHS");
+
+  if (lhsIsLogical != mlir::isa<fir::LogicalType>(resultEleTy))
+    return emitOpError("the result type should be a logical only if the "
+                       "argument types are logical");
+
+  // MATMUL contracts the last dimension of LHS with the first dimension of
+  // RHS, so the result keeps the leading extent of a rank-2 LHS followed by
+  // the trailing extent of a rank-2 RHS:
+  //   (n, m) x (m, k) -> (n, k);  (n, m) x (m) -> (n);  (m) x (m, k) -> (k)
+  llvm::SmallVector<int64_t, 2> expectedResultShape;
+  if (lhsRank == 2)
+    expectedResultShape.push_back(lhsShape[0]);
+  if (rhsRank == 2)
+    expectedResultShape.push_back(rhsShape[1]);
+
+  if (resultShape.size() != expectedResultShape.size())
+    return emitOpError("incorrect result shape");
+  // Use the same leniency as the contracted-dimension check: the result type
+  // may be more (or less) precise than the operands' static shapes imply, so
+  // only a known-vs-known mismatch is an error.
+  for (std::size_t i = 0, e = expectedResultShape.size(); i < e; ++i)
+    if (extentsConflict(resultShape[i], expectedResultShape[i]))
+      return emitOpError("incorrect result shape");
+
+  return mlir::success();
+}
+
//===----------------------------------------------------------------------===//
// AssociateOp
//===----------------------------------------------------------------------===//
diff --git a/flang/test/HLFIR/invalid.fir b/flang/test/HLFIR/invalid.fir
index dd801634c6f7d..b4afc8fe48a11 100644
--- a/flang/test/HLFIR/invalid.fir
+++ b/flang/test/HLFIR/invalid.fir
@@ -319,3 +319,59 @@ func.func @bad_sum4(%arg0: !hlfir.expr<?xi32>, %arg1: i32, %arg2: !fir.box<!fir.
// expected-error at +1 {{'hlfir.sum' op result rank must be one less than ARRAY}}
%0 = hlfir.sum %arg0 dim %arg1 mask %arg2 : (!hlfir.expr<?xi32>, i32, !fir.box<!fir.logical<4>>) -> !hlfir.expr<?x?xi32>
}
+
+// -----
+// LHS has rank 3; MATMUL operands must have rank 1 or rank 2.
+func.func @bad_matmul1(%arg0: !hlfir.expr<?x?x?xi32>, %arg1: !hlfir.expr<?x?xi32>) {
+  // expected-error@+1 {{'hlfir.matmul' op array must have either rank 1 or rank 2}}
+  %0 = hlfir.matmul %arg0 %arg1 : (!hlfir.expr<?x?x?xi32>, !hlfir.expr<?x?xi32>) -> !hlfir.expr<?x?xi32>
+  return
+}
+
+// -----
+// Both operands have rank 1; at least one must have rank 2.
+func.func @bad_matmul2(%arg0: !hlfir.expr<?xi32>, %arg1: !hlfir.expr<?xi32>) {
+  // expected-error@+1 {{'hlfir.matmul' op at least one array must have rank 2}}
+  %0 = hlfir.matmul %arg0 %arg1 : (!hlfir.expr<?xi32>, !hlfir.expr<?xi32>) -> !hlfir.expr<?x?xi32>
+  return
+}
+
+// -----
+// Logical LHS with an integer RHS: operands must be both logical or both not.
+func.func @bad_matmul3(%arg0: !hlfir.expr<?x?x!fir.logical<4>>, %arg1: !hlfir.expr<?x?xi32>) {
+  // expected-error@+1 {{'hlfir.matmul' op if one array is logical, so should the other be}}
+  %0 = hlfir.matmul %arg0 %arg1 : (!hlfir.expr<?x?x!fir.logical<4>>, !hlfir.expr<?x?xi32>) -> !hlfir.expr<?x?xi32>
+  return
+}
+
+// -----
+// The contracted extents are both known and differ (2 vs 200).
+func.func @bad_matmul4(%arg0: !hlfir.expr<?x2xi32>, %arg1: !hlfir.expr<200x?xi32>) {
+  // expected-error@+1 {{'hlfir.matmul' op the last dimension of LHS should match the first dimension of RHS}}
+  %0 = hlfir.matmul %arg0 %arg1 : (!hlfir.expr<?x2xi32>, !hlfir.expr<200x?xi32>) -> !hlfir.expr<?x?xi32>
+  return
+}
+
+// -----
+// Integer operands cannot produce a logical result.
+func.func @bad_matmul5(%arg0: !hlfir.expr<?x?xi32>, %arg1: !hlfir.expr<?x?xi32>) {
+  // expected-error@+1 {{'hlfir.matmul' op the result type should be a logical only if the argument types are logical}}
+  %0 = hlfir.matmul %arg0 %arg1 : (!hlfir.expr<?x?xi32>, !hlfir.expr<?x?xi32>) -> !hlfir.expr<?x?x!fir.logical<4>>
+  return
+}
+
+// -----
+// (1x2) x (2x3) yields a 1x3 result, not 10x30.
+func.func @bad_matmul6(%arg0: !hlfir.expr<1x2xi32>, %arg1: !hlfir.expr<2x3xi32>) {
+  // expected-error@+1 {{'hlfir.matmul' op incorrect result shape}}
+  %0 = hlfir.matmul %arg0 %arg1 : (!hlfir.expr<1x2xi32>, !hlfir.expr<2x3xi32>) -> !hlfir.expr<10x30xi32>
+  return
+}
+
+// -----
+// (1x2) x (2) yields a rank-1 result, not a rank-2 one.
+func.func @bad_matmul7(%arg0: !hlfir.expr<1x2xi32>, %arg1: !hlfir.expr<2xi32>) {
+  // expected-error@+1 {{'hlfir.matmul' op incorrect result shape}}
+  %0 = hlfir.matmul %arg0 %arg1 : (!hlfir.expr<1x2xi32>, !hlfir.expr<2xi32>) -> !hlfir.expr<1x3xi32>
+  return
+}
+
+// -----
+// (2) x (2x3) yields a rank-1 result, not a rank-2 one.
+func.func @bad_matmul8(%arg0: !hlfir.expr<2xi32>, %arg1: !hlfir.expr<2x3xi32>) {
+  // expected-error@+1 {{'hlfir.matmul' op incorrect result shape}}
+  %0 = hlfir.matmul %arg0 %arg1 : (!hlfir.expr<2xi32>, !hlfir.expr<2x3xi32>) -> !hlfir.expr<1x3xi32>
+  return
+}
diff --git a/flang/test/HLFIR/matmul.fir b/flang/test/HLFIR/matmul.fir
new file mode 100644
index 0000000000000..eaf85c3969737
--- /dev/null
+++ b/flang/test/HLFIR/matmul.fir
@@ -0,0 +1,99 @@
+// Test hlfir.matmul operation parse, verify (no errors), and unparse
+
+// RUN: fir-opt %s | fir-opt | FileCheck %s
+
+// rank-2 x rank-2: both arguments are expressions with fully static extents,
+// so the result extents are static too
+func.func @matmul0(%arg0: !hlfir.expr<2x2xi32>, %arg1: !hlfir.expr<2x2xi32>) {
+ %res = hlfir.matmul %arg0 %arg1 : (!hlfir.expr<2x2xi32>, !hlfir.expr<2x2xi32>) -> !hlfir.expr<2x2xi32>
+ return
+}
+// CHECK-LABEL: func.func @matmul0
+// CHECK: %[[ARG0:.*]]: !hlfir.expr<2x2xi32>,
+// CHECK: %[[ARG1:.*]]: !hlfir.expr<2x2xi32>) {
+// CHECK-NEXT: %[[RES:.*]] = hlfir.matmul %[[ARG0]] %[[ARG1]] : (!hlfir.expr<2x2xi32>, !hlfir.expr<2x2xi32>) -> !hlfir.expr<2x2xi32>
+// CHECK-NEXT: return
+// CHECK-NEXT: }
+
+// rank-2 x rank-2: every extent of both argument expressions is unknown
+// (dynamic), and so is every extent of the result
+func.func @matmul1(%arg0: !hlfir.expr<?x?xi32>, %arg1: !hlfir.expr<?x?xi32>) {
+ %res = hlfir.matmul %arg0 %arg1 : (!hlfir.expr<?x?xi32>, !hlfir.expr<?x?xi32>) -> !hlfir.expr<?x?xi32>
+ return
+}
+// CHECK-LABEL: func.func @matmul1
+// CHECK: %[[ARG0:.*]]: !hlfir.expr<?x?xi32>,
+// CHECK: %[[ARG1:.*]]: !hlfir.expr<?x?xi32>) {
+// CHECK-NEXT: %[[RES:.*]] = hlfir.matmul %[[ARG0]] %[[ARG1]] : (!hlfir.expr<?x?xi32>, !hlfir.expr<?x?xi32>) -> !hlfir.expr<?x?xi32>
+// CHECK-NEXT: return
+// CHECK-NEXT: }
+
+// arguments are expressions where only some dimensions are known #1:
+// the contracted extents are unknown, while the outer extents (and hence
+// the 2x2 result) are static
+func.func @matmul2(%arg0: !hlfir.expr<2x?xi32>, %arg1: !hlfir.expr<?x2xi32>) {
+ %res = hlfir.matmul %arg0 %arg1 : (!hlfir.expr<2x?xi32>, !hlfir.expr<?x2xi32>) -> !hlfir.expr<2x2xi32>
+ return
+}
+// CHECK-LABEL: func.func @matmul2
+// CHECK: %[[ARG0:.*]]: !hlfir.expr<2x?xi32>,
+// CHECK: %[[ARG1:.*]]: !hlfir.expr<?x2xi32>) {
+// CHECK-NEXT: %[[RES:.*]] = hlfir.matmul %[[ARG0]] %[[ARG1]] : (!hlfir.expr<2x?xi32>, !hlfir.expr<?x2xi32>) -> !hlfir.expr<2x2xi32>
+// CHECK-NEXT: return
+// CHECK-NEXT: }
+
+// arguments are expressions where only some dimensions are known #2:
+// the contracted extents are static (2), while the outer extents (and hence
+// the result extents) are unknown
+func.func @matmul3(%arg0: !hlfir.expr<?x2xi32>, %arg1: !hlfir.expr<2x?xi32>) {
+ %res = hlfir.matmul %arg0 %arg1 : (!hlfir.expr<?x2xi32>, !hlfir.expr<2x?xi32>) -> !hlfir.expr<?x?xi32>
+ return
+}
+// CHECK-LABEL: func.func @matmul3
+// CHECK: %[[ARG0:.*]]: !hlfir.expr<?x2xi32>,
+// CHECK: %[[ARG1:.*]]: !hlfir.expr<2x?xi32>) {
+// CHECK-NEXT: %[[RES:.*]] = hlfir.matmul %[[ARG0]] %[[ARG1]] : (!hlfir.expr<?x2xi32>, !hlfir.expr<2x?xi32>) -> !hlfir.expr<?x?xi32>
+// CHECK-NEXT: return
+// CHECK-NEXT: }
+
+// arguments are logicals: both operands are !fir.logical arrays, so the
+// result element type must be !fir.logical as well
+func.func @matmul4(%arg0: !hlfir.expr<?x?x!fir.logical<4>>, %arg1: !hlfir.expr<?x?x!fir.logical<4>>) {
+ %res = hlfir.matmul %arg0 %arg1 : (!hlfir.expr<?x?x!fir.logical<4>>, !hlfir.expr<?x?x!fir.logical<4>>) -> !hlfir.expr<?x?x!fir.logical<4>>
+ return
+}
+// CHECK-LABEL: func.func @matmul4
+// CHECK: %[[ARG0:.*]]: !hlfir.expr<?x?x!fir.logical<4>>,
+// CHECK: %[[ARG1:.*]]: !hlfir.expr<?x?x!fir.logical<4>>) {
+// CHECK-NEXT: %[[RES:.*]] = hlfir.matmul %[[ARG0]] %[[ARG1]] : (!hlfir.expr<?x?x!fir.logical<4>>, !hlfir.expr<?x?x!fir.logical<4>>) -> !hlfir.expr<?x?x!fir.logical<4>>
+// CHECK-NEXT: return
+// CHECK-NEXT: }
+
+// lhs is rank 1: (m) x (m, k) -> (k), so the result has rank 1
+func.func @matmul5(%arg0: !hlfir.expr<?xi32>, %arg1: !hlfir.expr<?x?xi32>) {
+ %res = hlfir.matmul %arg0 %arg1 : (!hlfir.expr<?xi32>, !hlfir.expr<?x?xi32>) -> !hlfir.expr<?xi32>
+ return
+}
+// CHECK-LABEL: func.func @matmul5
+// CHECK: %[[ARG0:.*]]: !hlfir.expr<?xi32>,
+// CHECK: %[[ARG1:.*]]: !hlfir.expr<?x?xi32>) {
+// CHECK-NEXT: %[[RES:.*]] = hlfir.matmul %[[ARG0]] %[[ARG1]] : (!hlfir.expr<?xi32>, !hlfir.expr<?x?xi32>) -> !hlfir.expr<?xi32>
+// CHECK-NEXT: return
+// CHECK-NEXT: }
+
+// rhs is rank 1: (n, m) x (m) -> (n), so the result has rank 1
+func.func @matmul6(%arg0: !hlfir.expr<?x?xi32>, %arg1: !hlfir.expr<?xi32>) {
+ %res = hlfir.matmul %arg0 %arg1 : (!hlfir.expr<?x?xi32>, !hlfir.expr<?xi32>) -> !hlfir.expr<?xi32>
+ return
+}
+// CHECK-LABEL: func.func @matmul6
+// CHECK: %[[ARG0:.*]]: !hlfir.expr<?x?xi32>,
+// CHECK: %[[ARG1:.*]]: !hlfir.expr<?xi32>) {
+// CHECK-NEXT: %[[RES:.*]] = hlfir.matmul %[[ARG0]] %[[ARG1]] : (!hlfir.expr<?x?xi32>, !hlfir.expr<?xi32>) -> !hlfir.expr<?xi32>
+// CHECK-NEXT: return
+// CHECK-NEXT: }
+
+// arguments are boxed arrays (!fir.box) rather than hlfir.expr values; the
+// result is still an hlfir.expr
+func.func @matmul7(%arg0: !fir.box<!fir.array<2x2xf32>>, %arg1: !fir.box<!fir.array<2x2xf32>>) {
+ %res = hlfir.matmul %arg0 %arg1 : (!fir.box<!fir.array<2x2xf32>>, !fir.box<!fir.array<2x2xf32>>) -> !hlfir.expr<2x2xf32>
+ return
+}
+// CHECK-LABEL: func.func @matmul7
+// CHECK: %[[ARG0:.*]]: !fir.box<!fir.array<2x2xf32>>,
+// CHECK: %[[ARG1:.*]]: !fir.box<!fir.array<2x2xf32>>) {
+// CHECK-NEXT: %[[RES:.*]] = hlfir.matmul %[[ARG0]] %[[ARG1]] : (!fir.box<!fir.array<2x2xf32>>, !fir.box<!fir.array<2x2xf32>>) -> !hlfir.expr<2x2xf32>
+// CHECK-NEXT: return
+// CHECK-NEXT: }
More information about the flang-commits
mailing list