[clang] [llvm] [HLSL][Matrix] Support row and column indexing modes for MatrixSubscriptExpr (PR #171564)
Farzon Lotfi via llvm-commits
llvm-commits at lists.llvm.org
Fri Dec 19 11:12:06 PST 2025
https://github.com/farzonl updated https://github.com/llvm/llvm-project/pull/171564
>From d7f655f5dfc86edcc68bc35916d13ee5fbd6abe8 Mon Sep 17 00:00:00 2001
From: Farzon Lotfi <farzonlotfi at microsoft.com>
Date: Wed, 10 Dec 2025 00:41:04 -0500
Subject: [PATCH 1/2] [HLSL][Matrix] Support row and column indexing modes for
MatrixSubscriptExpr
fixes #167617
In DXC, HLSL supports different indexing modes via codegen for its
equivalent of the MatrixSubscriptExpr when the /Zpr and /Zpc flags are
used; see https://godbolt.org/z/bz5Y5WG36.
This change modifies EmitMatrixSubscriptExpr to consider the
MatrixRowMajor/MatrixColMajor layout flags before generating an index.
Similarly, it introduces `createRowMajorIndex` and
`createColumnMajorIndex` in `MatrixBuilder.h` for use in
`VisitMatrixSubscriptExpr`.
---
clang/lib/CodeGen/CGExpr.cpp | 18 ++++--
clang/lib/CodeGen/CGExprScalar.cpp | 13 +++-
clang/test/CodeGen/matrix-type-indexing.c | 60 +++++++++++++++++++
.../test/CodeGenCXX/matrix-type-indexing.cpp | 36 +++++++++++
.../BasicFeatures/matrix-type-indexing.hlsl | 52 ++++++++++++++++
llvm/include/llvm/IR/MatrixBuilder.h | 17 +++++-
6 files changed, 187 insertions(+), 9 deletions(-)
create mode 100644 clang/test/CodeGen/matrix-type-indexing.c
create mode 100644 clang/test/CodeGenCXX/matrix-type-indexing.cpp
create mode 100644 clang/test/CodeGenHLSL/BasicFeatures/matrix-type-indexing.hlsl
diff --git a/clang/lib/CodeGen/CGExpr.cpp b/clang/lib/CodeGen/CGExpr.cpp
index 27ee96cb6dc82..3cf835c1ba516 100644
--- a/clang/lib/CodeGen/CGExpr.cpp
+++ b/clang/lib/CodeGen/CGExpr.cpp
@@ -4977,11 +4977,19 @@ LValue CodeGenFunction::EmitMatrixSubscriptExpr(const MatrixSubscriptExpr *E) {
llvm::Value *RowIdx = EmitMatrixIndexExpr(E->getRowIdx());
llvm::Value *ColIdx = EmitMatrixIndexExpr(E->getColumnIdx());
- llvm::Value *NumRows = Builder.getIntN(
- RowIdx->getType()->getScalarSizeInBits(),
- E->getBase()->getType()->castAs<ConstantMatrixType>()->getNumRows());
- llvm::Value *FinalIdx =
- Builder.CreateAdd(Builder.CreateMul(ColIdx, NumRows), RowIdx);
+ llvm::Value *FinalIdx;
+ if (getLangOpts().getDefaultMatrixMemoryLayout() ==
+ LangOptions::MatrixMemoryLayout::MatrixRowMajor) {
+ llvm::Value *NumCols = Builder.getIntN(
+ RowIdx->getType()->getScalarSizeInBits(),
+ E->getBase()->getType()->castAs<ConstantMatrixType>()->getNumColumns());
+ FinalIdx = Builder.CreateAdd(Builder.CreateMul(RowIdx, NumCols), ColIdx);
+ } else {
+ llvm::Value *NumRows = Builder.getIntN(
+ RowIdx->getType()->getScalarSizeInBits(),
+ E->getBase()->getType()->castAs<ConstantMatrixType>()->getNumRows());
+ FinalIdx = Builder.CreateAdd(Builder.CreateMul(ColIdx, NumRows), RowIdx);
+ }
return LValue::MakeMatrixElt(
MaybeConvertMatrixAddress(Base.getAddress(), *this), FinalIdx,
E->getBase()->getType(), Base.getBaseInfo(), TBAAAccessInfo());
diff --git a/clang/lib/CodeGen/CGExprScalar.cpp b/clang/lib/CodeGen/CGExprScalar.cpp
index 1de1378d25249..72aad9707a67e 100644
--- a/clang/lib/CodeGen/CGExprScalar.cpp
+++ b/clang/lib/CodeGen/CGExprScalar.cpp
@@ -2155,9 +2155,18 @@ Value *ScalarExprEmitter::VisitMatrixSubscriptExpr(MatrixSubscriptExpr *E) {
Value *ColumnIdx = CGF.EmitMatrixIndexExpr(E->getColumnIdx());
const auto *MatrixTy = E->getBase()->getType()->castAs<ConstantMatrixType>();
- unsigned NumRows = MatrixTy->getNumRows();
llvm::MatrixBuilder MB(Builder);
- Value *Idx = MB.CreateIndex(RowIdx, ColumnIdx, NumRows);
+
+ Value *Idx;
+ if (CGF.getLangOpts().getDefaultMatrixMemoryLayout() ==
+ LangOptions::MatrixMemoryLayout::MatrixRowMajor) {
+ unsigned NumCols = MatrixTy->getNumColumns();
+ Idx = MB.createRowMajorIndex(RowIdx, ColumnIdx, NumCols);
+ } else {
+ unsigned NumRows = MatrixTy->getNumRows();
+ Idx = MB.createColumnMajorIndex(RowIdx, ColumnIdx, NumRows);
+ }
+
if (CGF.CGM.getCodeGenOpts().OptimizationLevel > 0)
MB.CreateIndexAssumption(Idx, MatrixTy->getNumElementsFlattened());
diff --git a/clang/test/CodeGen/matrix-type-indexing.c b/clang/test/CodeGen/matrix-type-indexing.c
new file mode 100644
index 0000000000000..d76d14c3f67ef
--- /dev/null
+++ b/clang/test/CodeGen/matrix-type-indexing.c
@@ -0,0 +1,60 @@
+// RUN: %clang_cc1 -fenable-matrix -fmatrix-memory-layout=row-major -triple x86_64-apple-darwin %s -emit-llvm -disable-llvm-passes -o - | FileCheck %s --check-prefixes=CHECK,ROW-CHECK
+// RUN: %clang_cc1 -fenable-matrix -fmatrix-memory-layout=column-major -triple x86_64-apple-darwin %s -emit-llvm -disable-llvm-passes -o - | FileCheck %s --check-prefixes=CHECK,COL-CHECK
+// RUN: %clang_cc1 -fenable-matrix -triple x86_64-apple-darwin %s -emit-llvm -disable-llvm-passes -o - | FileCheck %s --check-prefixes=CHECK,COL-CHECK
+
+typedef float fx2x3_t __attribute__((matrix_type(2, 3)));
+ float Out[6];
+
+ fx2x3_t gM;
+
+void binaryOpMatrixSubscriptExpr(int index, fx2x3_t M) {
+ // CHECK-LABEL: binaryOpMatrixSubscriptExpr
+ // CHECK: %row = alloca i32, align 4
+ // CHECK: %col = alloca i32, align 4
+ // CHECK: [[row_load:%.*]] = load i32, ptr %row, align 4
+ // CHECK-NEXT: [[row_load_zext:%.*]] = zext i32 [[row_load]] to i64
+ // CHECK-NEXT: [[col_load:%.*]] = load i32, ptr %col, align 4
+ // CHECK-NEXT: [[col_load_zext:%.*]] = zext i32 [[col_load]] to i64
+ // COL-CHECK-NEXT: [[col_offset:%.*]] = mul i64 [[col_load_zext]], 2
+ // COL-CHECK-NEXT: [[col_major_index:%.*]] = add i64 [[col_offset]], [[row_load_zext]]
+ // ROW-CHECK-NEXT: [[row_offset:%.*]] = mul i64 [[row_load_zext]], 3
+ // ROW-CHECK-NEXT: [[row_major_index:%.*]] = add i64 [[row_offset]], [[col_load_zext]]
+ // CHECK-NEXT: [[matrix_as_vec:%.*]] = load <6 x float>, ptr %M.addr, align 4
+ // COL-CHECK-NEXT: %matrixext = extractelement <6 x float> [[matrix_as_vec]], i64 [[col_major_index]]
+ // ROW-CHECK-NEXT: %matrixext = extractelement <6 x float> [[matrix_as_vec]], i64 [[row_major_index]]
+ const unsigned int COLS = 3;
+ unsigned int row = index / COLS;
+ unsigned int col = index % COLS;
+ Out[index] = M[row][col];
+}
+
+float returnMatrixSubscriptExpr(int row, int col, fx2x3_t M) {
+ // CHECK-LABEL: returnMatrixSubscriptExpr
+ // CHECK: [[row_load:%.*]] = load i32, ptr [[row_ptr:%.*]], align 4
+ // CHECK-NEXT: [[row_load_sext:%.*]] = sext i32 [[row_load]] to i64
+ // CHECK-NEXT: [[col_load:%.*]] = load i32, ptr [[col_ptr:%.*]], align 4
+ // CHECK-NEXT: [[col_load_sext:%.*]] = sext i32 [[col_load]] to i64
+ // COL-CHECK-NEXT: [[col_offset:%.*]] = mul i64 [[col_load_sext]], 2
+ // COL-CHECK-NEXT: [[col_major_index:%.*]] = add i64 [[col_offset]], [[row_load_sext]]
+ // ROW-CHECK-NEXT: [[row_offset:%.*]] = mul i64 [[row_load_sext]], 3
+ // ROW-CHECK-NEXT: [[row_major_index:%.*]] = add i64 [[row_offset]], [[col_load_sext]]
+ // CHECK-NEXT: [[matrix_as_vec:%.*]] = load <6 x float>, ptr %M.addr, align 4
+ // COL-CHECK-NEXT: [[matrix_after_extract:%.*]] = extractelement <6 x float> [[matrix_as_vec]], i64 [[col_major_index]]
+ // ROW-CHECK-NEXT: [[matrix_after_extract:%.*]] = extractelement <6 x float> [[matrix_as_vec]], i64 [[row_major_index]]
+ // CHECK-NEXT: ret float [[matrix_after_extract]]
+ return M[row][col];
+}
+
+void storeAtMatrixSubscriptExpr(int row, int col, float value) {
+ // CHECK-LABEL: storeAtMatrixSubscriptExpr
+ // CHECK: [[value_load:%.*]] = load float, ptr [[value_ptr:%.*]], align 4
+ // ROW-CHECK: [[row_offset:%.*]] = mul i64 [[row_load:%.*]], 3
+ // ROW-CHECK-NEXT: [[row_major_index:%.*]] = add i64 [[row_offset]], [[col_load:%.*]]
+ // COL-CHECK: [[col_offset:%.*]] = mul i64 [[col_load:%.*]], 2
+ // COL-CHECK-NEXT: [[col_major_index:%.*]] = add i64 [[col_offset]], [[row_load:%.*]]
+ // CHECK-NEXT: [[matrix_as_vec:%.*]] = load <6 x float>, ptr @gM, align 4
+ // ROW-CHECK-NEXT: [[matrix_after_insert:%.*]] = insertelement <6 x float> [[matrix_as_vec]], float [[value_load]], i64 [[row_major_index]]
+ // COL-CHECK-NEXT: [[matrix_after_insert:%.*]] = insertelement <6 x float> [[matrix_as_vec]], float [[value_load]], i64 [[col_major_index]]
+ // CHECK-NEXT: store <6 x float> [[matrix_after_insert]], ptr @gM, align 4
+ gM[row][col] = value;
+}
diff --git a/clang/test/CodeGenCXX/matrix-type-indexing.cpp b/clang/test/CodeGenCXX/matrix-type-indexing.cpp
new file mode 100644
index 0000000000000..a390bcaf89119
--- /dev/null
+++ b/clang/test/CodeGenCXX/matrix-type-indexing.cpp
@@ -0,0 +1,36 @@
+// RUN: %clang_cc1 -fenable-matrix -fmatrix-memory-layout=column-major -triple x86_64-apple-darwin %s -emit-llvm -disable-llvm-passes -o - -std=c++11 | FileCheck %s
+// RUN: %clang_cc1 -fenable-matrix -triple x86_64-apple-darwin %s -emit-llvm -disable-llvm-passes -o - -std=c++11 | FileCheck %s
+
+ typedef float fx2x3_t __attribute__((matrix_type(2, 3)));
+ float Out[6];
+
+void binaryOpMatrixSubscriptExpr(int index, fx2x3_t M) {
+ // CHECK-LABEL: binaryOpMatrixSubscriptExpr
+ // CHECK: %row = alloca i32, align 4
+ // CHECK: %col = alloca i32, align 4
+ // CHECK: [[row_load:%.*]] = load i32, ptr %row, align 4
+ // CHECK-NEXT: [[row_load_zext:%.*]] = zext i32 [[row_load]] to i64
+ // CHECK-NEXT: [[col_load:%.*]] = load i32, ptr %col, align 4
+ // CHECK-NEXT: [[col_load_zext:%.*]] = zext i32 [[col_load]] to i64
+ // CHECK-NEXT: [[col_offset:%.*]] = mul i64 [[col_load_zext]], 2
+ // CHECK-NEXT: [[col_major_index:%.*]] = add i64 [[col_offset]], [[row_load_zext]]
+ // CHECK-NEXT: [[matrix_as_vec:%.*]] = load <6 x float>, ptr %M.addr, align 4
+ // CHECK-NEXT: %matrixext = extractelement <6 x float> [[matrix_as_vec]], i64 [[col_major_index]]
+ const unsigned int COLS = 3;
+ unsigned int row = index / COLS;
+ unsigned int col = index % COLS;
+ Out[index] = M[row][col];
+}
+
+float returnMatrixSubscriptExpr(int row, int col, fx2x3_t M) {
+ // CHECK-LABEL: returnMatrixSubscriptExpr
+ // CHECK: [[row_load:%.*]] = load i32, ptr %row.addr, align 4
+ // CHECK-NEXT: [[row_load_sext:%.*]] = sext i32 [[row_load]] to i64
+ // CHECK-NEXT: [[col_load:%.*]] = load i32, ptr %col.addr, align 4
+ // CHECK-NEXT: [[col_load_sext:%.*]] = sext i32 [[col_load]] to i64
+ // CHECK-NEXT: [[col_offset:%.*]] = mul i64 [[col_load_sext]], 2
+ // CHECK-NEXT: [[col_major_index:%.*]] = add i64 [[col_offset]], [[row_load_sext]]
+ // CHECK-NEXT: [[matrix_as_vec:%.*]] = load <6 x float>, ptr %M.addr, align 4
+ // CHECK-NEXT: %matrixext = extractelement <6 x float> [[matrix_as_vec]], i64 [[col_major_index]]
+ return M[row][col];
+}
diff --git a/clang/test/CodeGenHLSL/BasicFeatures/matrix-type-indexing.hlsl b/clang/test/CodeGenHLSL/BasicFeatures/matrix-type-indexing.hlsl
new file mode 100644
index 0000000000000..7a63bbb45ecf7
--- /dev/null
+++ b/clang/test/CodeGenHLSL/BasicFeatures/matrix-type-indexing.hlsl
@@ -0,0 +1,52 @@
+// RUN: %clang_cc1 -finclude-default-header -x hlsl -triple dxil-pc-shadermodel6.3-library %s -fnative-half-type -emit-llvm -disable-llvm-passes -fmatrix-memory-layout=row-major -o - | FileCheck %s --check-prefixes=CHECK,ROW-CHECK
+// RUN: %clang_cc1 -finclude-default-header -x hlsl -triple dxil-pc-shadermodel6.3-library %s -fnative-half-type -emit-llvm -disable-llvm-passes -fmatrix-memory-layout=column-major -o - | FileCheck %s --check-prefixes=CHECK,COL-CHECK
+// RUN: %clang_cc1 -finclude-default-header -x hlsl -triple dxil-pc-shadermodel6.3-library %s -fnative-half-type -emit-llvm -disable-llvm-passes -o - | FileCheck %s --check-prefixes=CHECK,COL-CHECK
+
+RWBuffer<int> Out : register(u1);
+half2x3 gM;
+
+
+void binaryOpMatrixSubscriptExpr(int index, half2x3 M) {
+ // CHECK-LABEL: binaryOpMatrixSubscriptExpr
+ // CHECK: %row = alloca i32, align 4
+ // CHECK: %col = alloca i32, align 4
+ // CHECK: [[row_load:%.*]] = load i32, ptr %row, align 4
+ // CHECK-NEXT: [[col_load:%.*]] = load i32, ptr %col, align 4
+ // ROW-CHECK-NEXT: [[row_offset:%.*]] = mul i32 [[row_load]], 3
+ // ROW-CHECK-NEXT: [[row_major_index:%.*]] = add i32 [[row_offset]], [[col_load]]
+ // COL-CHECK-NEXT: [[col_offset:%.*]] = mul i32 [[col_load]], 2
+ // COL-CHECK-NEXT: [[col_major_index:%.*]] = add i32 [[col_offset]], [[row_load]]
+ // CHECK-NEXT: [[matrix_as_vec:%.*]] = load <6 x half>, ptr %M.addr, align 2
+ // ROW-CHECK-NEXT: %matrixext = extractelement <6 x half> [[matrix_as_vec]], i32 [[row_major_index]]
+ // COL-CHECK-NEXT: %matrixext = extractelement <6 x half> [[matrix_as_vec]], i32 [[col_major_index]]
+ const uint COLS = 3;
+ uint row = index / COLS;
+ uint col = index % COLS;
+ Out[index] = M[row][col];
+}
+
+half returnMatrixSubscriptExpr(int row, int col, half2x3 M) {
+ // CHECK-LABEL: returnMatrixSubscriptExpr
+ // ROW-CHECK: [[row_offset:%.*]] = mul i32 [[row_load:%.*]], 3
+ // ROW-CHECK-NEXT: [[row_major_index:%.*]] = add i32 [[row_offset]], [[col_load:%.*]]
+ // COL-CHECK: [[col_offset:%.*]] = mul i32 [[col_load:%.*]], 2
+ // COL-CHECK-NEXT: [[col_major_index:%.*]] = add i32 [[col_offset]], [[row_load:%.*]]
+ // CHECK-NEXT: [[matrix_as_vec:%.*]] = load <6 x half>, ptr %M.addr, align 2
+ // ROW-CHECK-NEXT: %matrixext = extractelement <6 x half> [[matrix_as_vec]], i32 [[row_major_index]]
+ // COL-CHECK-NEXT: %matrixext = extractelement <6 x half> [[matrix_as_vec]], i32 [[col_major_index]]
+ return M[row][col];
+}
+
+void storeAtMatrixSubscriptExpr(int row, int col, half value) {
+ // CHECK-LABEL: storeAtMatrixSubscriptExpr
+ // CHECK: [[value_load:%.*]] = load half, ptr [[value_ptr:%.*]], align 2
+ // ROW-CHECK: [[row_offset:%.*]] = mul i32 [[row_load:%.*]], 3
+ // ROW-CHECK-NEXT: [[row_major_index:%.*]] = add i32 [[row_offset]], [[col_load:%.*]]
+ // COL-CHECK: [[col_offset:%.*]] = mul i32 [[col_load:%.*]], 2
+ // COL-CHECK-NEXT: [[col_major_index:%.*]] = add i32 [[col_offset]], [[row_load:%.*]]
+ // CHECK-NEXT: [[matrix_as_vec:%.*]] = load <6 x half>, ptr addrspace(2) @gM, align 2
+ // ROW-CHECK-NEXT: [[matrix_after_insert:%.*]] = insertelement <6 x half> [[matrix_as_vec]], half [[value_load]], i32 [[row_major_index]]
+ // COL-CHECK-NEXT: [[matrix_after_insert:%.*]] = insertelement <6 x half> [[matrix_as_vec]], half [[value_load]], i32 [[col_major_index]]
+ // CHECK-NEXT: store <6 x half> [[matrix_after_insert]], ptr addrspace(2) @gM, align 2
+ gM[row][col] = value;
+}
diff --git a/llvm/include/llvm/IR/MatrixBuilder.h b/llvm/include/llvm/IR/MatrixBuilder.h
index 3a04ca87f2b55..0cea0f0ddafa3 100644
--- a/llvm/include/llvm/IR/MatrixBuilder.h
+++ b/llvm/include/llvm/IR/MatrixBuilder.h
@@ -241,8 +241,8 @@ class MatrixBuilder {
/// Compute the index to access the element at (\p RowIdx, \p ColumnIdx) from
/// a matrix with \p NumRows embedded in a vector.
- Value *CreateIndex(Value *RowIdx, Value *ColumnIdx, unsigned NumRows,
- Twine const &Name = "") {
+ Value *createColumnMajorIndex(Value *RowIdx, Value *ColumnIdx,
+ unsigned NumRows, Twine const &Name = "") {
unsigned MaxWidth = std::max(RowIdx->getType()->getScalarSizeInBits(),
ColumnIdx->getType()->getScalarSizeInBits());
Type *IntTy = IntegerType::get(RowIdx->getType()->getContext(), MaxWidth);
@@ -251,6 +251,19 @@ class MatrixBuilder {
Value *NumRowsV = B.getIntN(MaxWidth, NumRows);
return B.CreateAdd(B.CreateMul(ColumnIdx, NumRowsV), RowIdx);
}
+
+ /// Compute the index to access the element at (\p RowIdx, \p ColumnIdx) from
+ /// a matrix with \p NumCols embedded in a vector.
+ Value *createRowMajorIndex(Value *RowIdx, Value *ColumnIdx, unsigned NumCols,
+ Twine const &Name = "") {
+ unsigned MaxWidth = std::max(RowIdx->getType()->getScalarSizeInBits(),
+ ColumnIdx->getType()->getScalarSizeInBits());
+ Type *IntTy = IntegerType::get(RowIdx->getType()->getContext(), MaxWidth);
+ RowIdx = B.CreateZExt(RowIdx, IntTy);
+ ColumnIdx = B.CreateZExt(ColumnIdx, IntTy);
+ Value *NumColsV = B.getIntN(MaxWidth, NumCols);
+ return B.CreateAdd(B.CreateMul(RowIdx, NumColsV), ColumnIdx);
+ }
};
} // end namespace llvm
>From 66d1eb9e69b03bcfa108c8c7187a2df9bd36f7eb Mon Sep 17 00:00:00 2001
From: Farzon Lotfi <farzonlotfi at microsoft.com>
Date: Fri, 19 Dec 2025 13:58:31 -0500
Subject: [PATCH 2/2] address the new CreateIndex cases
---
clang/lib/CodeGen/CGExpr.cpp | 31 ++++++++++++++--------------
clang/lib/CodeGen/CGExprScalar.cpp | 18 ++++++++--------
llvm/include/llvm/IR/MatrixBuilder.h | 30 +++++++++++++++------------
3 files changed, 42 insertions(+), 37 deletions(-)
diff --git a/clang/lib/CodeGen/CGExpr.cpp b/clang/lib/CodeGen/CGExpr.cpp
index 3cf835c1ba516..518cb50bb7722 100644
--- a/clang/lib/CodeGen/CGExpr.cpp
+++ b/clang/lib/CodeGen/CGExpr.cpp
@@ -2482,7 +2482,10 @@ RValue CodeGenFunction::EmitLoadOfLValue(LValue LV, SourceLocation Loc) {
for (unsigned Col = 0; Col < NumCols; ++Col) {
llvm::Value *ColIdx = llvm::ConstantInt::get(Row->getType(), Col);
- llvm::Value *EltIndex = MB.CreateIndex(Row, ColIdx, NumRows);
+ bool IsMatrixRowMajor = getLangOpts().getDefaultMatrixMemoryLayout() ==
+ LangOptions::MatrixMemoryLayout::MatrixRowMajor;
+ llvm::Value *EltIndex =
+ MB.createIndex(Row, ColIdx, NumRows, NumCols, IsMatrixRowMajor);
llvm::Value *Elt = Builder.CreateExtractElement(MatrixVec, EltIndex);
llvm::Value *Lane = llvm::ConstantInt::get(Builder.getInt32Ty(), Col);
Result = Builder.CreateInsertElement(Result, Elt, Lane);
@@ -2733,7 +2736,10 @@ void CodeGenFunction::EmitStoreThroughLValue(RValue Src, LValue Dst,
for (unsigned Col = 0; Col < NumCols; ++Col) {
llvm::Value *ColIdx = llvm::ConstantInt::get(Row->getType(), Col);
- llvm::Value *EltIndex = MB.CreateIndex(Row, ColIdx, NumRows);
+ bool IsMatrixRowMajor = getLangOpts().getDefaultMatrixMemoryLayout() ==
+ LangOptions::MatrixMemoryLayout::MatrixRowMajor;
+ llvm::Value *EltIndex =
+ MB.createIndex(Row, ColIdx, NumRows, NumCols, IsMatrixRowMajor);
llvm::Value *Lane = llvm::ConstantInt::get(Builder.getInt32Ty(), Col);
llvm::Value *NewElt = Builder.CreateExtractElement(RowVal, Lane);
MatrixVec = Builder.CreateInsertElement(MatrixVec, NewElt, EltIndex);
@@ -4976,20 +4982,15 @@ LValue CodeGenFunction::EmitMatrixSubscriptExpr(const MatrixSubscriptExpr *E) {
// Extend or truncate the index type to 32 or 64-bits if needed.
llvm::Value *RowIdx = EmitMatrixIndexExpr(E->getRowIdx());
llvm::Value *ColIdx = EmitMatrixIndexExpr(E->getColumnIdx());
+ llvm::MatrixBuilder MB(Builder);
+ const auto *MatrixTy = E->getBase()->getType()->castAs<ConstantMatrixType>();
+ unsigned NumCols = MatrixTy->getNumColumns();
+ unsigned NumRows = MatrixTy->getNumRows();
+ bool IsMatrixRowMajor = getLangOpts().getDefaultMatrixMemoryLayout() ==
+ LangOptions::MatrixMemoryLayout::MatrixRowMajor;
+ llvm::Value *FinalIdx =
+ MB.createIndex(RowIdx, ColIdx, NumRows, NumCols, IsMatrixRowMajor);
- llvm::Value *FinalIdx;
- if (getLangOpts().getDefaultMatrixMemoryLayout() ==
- LangOptions::MatrixMemoryLayout::MatrixRowMajor) {
- llvm::Value *NumCols = Builder.getIntN(
- RowIdx->getType()->getScalarSizeInBits(),
- E->getBase()->getType()->castAs<ConstantMatrixType>()->getNumColumns());
- FinalIdx = Builder.CreateAdd(Builder.CreateMul(RowIdx, NumCols), ColIdx);
- } else {
- llvm::Value *NumRows = Builder.getIntN(
- RowIdx->getType()->getScalarSizeInBits(),
- E->getBase()->getType()->castAs<ConstantMatrixType>()->getNumRows());
- FinalIdx = Builder.CreateAdd(Builder.CreateMul(ColIdx, NumRows), RowIdx);
- }
return LValue::MakeMatrixElt(
MaybeConvertMatrixAddress(Base.getAddress(), *this), FinalIdx,
E->getBase()->getType(), Base.getBaseInfo(), TBAAAccessInfo());
diff --git a/clang/lib/CodeGen/CGExprScalar.cpp b/clang/lib/CodeGen/CGExprScalar.cpp
index 72aad9707a67e..c4c97dc3773ae 100644
--- a/clang/lib/CodeGen/CGExprScalar.cpp
+++ b/clang/lib/CodeGen/CGExprScalar.cpp
@@ -2136,7 +2136,10 @@ Value *ScalarExprEmitter::VisitMatrixSingleSubscriptExpr(
for (unsigned Col = 0; Col != NumColumns; ++Col) {
Value *ColVal = llvm::ConstantInt::get(RowIdx->getType(), Col);
- Value *EltIdx = MB.CreateIndex(RowIdx, ColVal, NumRows, "matrix_row_idx");
+ bool IsMatrixRowMajor = CGF.getLangOpts().getDefaultMatrixMemoryLayout() ==
+ LangOptions::MatrixMemoryLayout::MatrixRowMajor;
+ Value *EltIdx = MB.createIndex(RowIdx, ColVal, NumRows, NumColumns,
+ IsMatrixRowMajor, "matrix_row_idx");
Value *Elt =
Builder.CreateExtractElement(FlatMatrix, EltIdx, "matrix_elem");
Value *Lane = llvm::ConstantInt::get(Builder.getInt32Ty(), Col);
@@ -2158,14 +2161,11 @@ Value *ScalarExprEmitter::VisitMatrixSubscriptExpr(MatrixSubscriptExpr *E) {
llvm::MatrixBuilder MB(Builder);
Value *Idx;
- if (CGF.getLangOpts().getDefaultMatrixMemoryLayout() ==
- LangOptions::MatrixMemoryLayout::MatrixRowMajor) {
- unsigned NumCols = MatrixTy->getNumColumns();
- Idx = MB.createRowMajorIndex(RowIdx, ColumnIdx, NumCols);
- } else {
- unsigned NumRows = MatrixTy->getNumRows();
- Idx = MB.createColumnMajorIndex(RowIdx, ColumnIdx, NumRows);
- }
+ unsigned NumCols = MatrixTy->getNumColumns();
+ unsigned NumRows = MatrixTy->getNumRows();
+ bool IsMatrixRowMajor = CGF.getLangOpts().getDefaultMatrixMemoryLayout() ==
+ LangOptions::MatrixMemoryLayout::MatrixRowMajor;
+ Idx = MB.createIndex(RowIdx, ColumnIdx, NumRows, NumCols, IsMatrixRowMajor);
if (CGF.CGM.getCodeGenOpts().OptimizationLevel > 0)
MB.CreateIndexAssumption(Idx, MatrixTy->getNumElementsFlattened());
diff --git a/llvm/include/llvm/IR/MatrixBuilder.h b/llvm/include/llvm/IR/MatrixBuilder.h
index 0cea0f0ddafa3..6b6926c2d00a7 100644
--- a/llvm/include/llvm/IR/MatrixBuilder.h
+++ b/llvm/include/llvm/IR/MatrixBuilder.h
@@ -238,30 +238,34 @@ class MatrixBuilder {
else
B.CreateAssumption(Cmp);
}
-
- /// Compute the index to access the element at (\p RowIdx, \p ColumnIdx) from
- /// a matrix with \p NumRows embedded in a vector.
- Value *createColumnMajorIndex(Value *RowIdx, Value *ColumnIdx,
- unsigned NumRows, Twine const &Name = "") {
+ Value *createIndex(Value *RowIdx, Value *ColumnIdx, unsigned NumRows,
+ unsigned NumCols, bool IsMatrixRowMajor = false,
+ Twine const &Name = "") {
unsigned MaxWidth = std::max(RowIdx->getType()->getScalarSizeInBits(),
ColumnIdx->getType()->getScalarSizeInBits());
Type *IntTy = IntegerType::get(RowIdx->getType()->getContext(), MaxWidth);
RowIdx = B.CreateZExt(RowIdx, IntTy);
ColumnIdx = B.CreateZExt(ColumnIdx, IntTy);
+ if (IsMatrixRowMajor) {
+ Value *NumColsV = B.getIntN(MaxWidth, NumCols);
+ return createRowMajorIndex(RowIdx, ColumnIdx, NumColsV, Name);
+ }
Value *NumRowsV = B.getIntN(MaxWidth, NumRows);
+ return createColumnMajorIndex(RowIdx, ColumnIdx, NumRowsV, Name);
+ }
+
+private:
+ /// Compute the index to access the element at (\p RowIdx, \p ColumnIdx) from
+ /// a matrix with \p NumRows embedded in a vector.
+ Value *createColumnMajorIndex(Value *RowIdx, Value *ColumnIdx,
+ Value *NumRowsV, Twine const &Name) {
return B.CreateAdd(B.CreateMul(ColumnIdx, NumRowsV), RowIdx);
}
/// Compute the index to access the element at (\p RowIdx, \p ColumnIdx) from
/// a matrix with \p NumCols embedded in a vector.
- Value *createRowMajorIndex(Value *RowIdx, Value *ColumnIdx, unsigned NumCols,
- Twine const &Name = "") {
- unsigned MaxWidth = std::max(RowIdx->getType()->getScalarSizeInBits(),
- ColumnIdx->getType()->getScalarSizeInBits());
- Type *IntTy = IntegerType::get(RowIdx->getType()->getContext(), MaxWidth);
- RowIdx = B.CreateZExt(RowIdx, IntTy);
- ColumnIdx = B.CreateZExt(ColumnIdx, IntTy);
- Value *NumColsV = B.getIntN(MaxWidth, NumCols);
+ Value *createRowMajorIndex(Value *RowIdx, Value *ColumnIdx, Value *NumColsV,
+ Twine const &Name) {
return B.CreateAdd(B.CreateMul(RowIdx, NumColsV), ColumnIdx);
}
};
More information about the llvm-commits
mailing list