@@ -21,6 +21,8 @@ Reference: PhysRevLett 97, 170201 (2006)
2121
2222#include " minimizer_fire_box_change.cuh"
2323#include " utilities/gpu_macro.cuh"
24+ #include < algorithm>
25+ #include < cmath>
2426#include < cstring>
2527
2628namespace
@@ -151,8 +153,8 @@ void get_force_temp(
151153template <int N>
152154void solveLinearEquation (const double * A, const double * B, double * X)
153155{
154-
155156 double a[N][N], b[N][N];
157+
156158 for (int j = 0 ; j < N; ++j) {
157159 for (int i = 0 ; i < N; ++i) {
158160 a[i][j] = A[j * N + i];
@@ -161,27 +163,43 @@ void solveLinearEquation(const double* A, const double* B, double* X)
161163 }
162164
163165 for (int col = 0 ; col < N; ++col) {
164- for (int i = 0 ; i < N; ++i) {
165- if (i == col) {
166- double diag = a[i][col];
167- if (fabs (diag) < 1e-9 ) {
168- printf (" Matrix is singular or nearly singular!\n " );
169- return ;
170- }
171- for (int j = 0 ; j < N; ++j) {
172- a[i][j] /= diag;
173- b[i][j] /= diag;
174- }
175- } else {
176- double factor = a[i][col];
166+ int pivot_row = col;
167+ for (int i = col + 1 ; i < N; ++i) {
168+ if (fabs (a[i][col]) > fabs (a[pivot_row][col])) {
169+ pivot_row = i;
170+ }
171+ }
172+
173+ if (fabs (a[pivot_row][col]) < 1e-9 ) {
174+ printf (" Matrix is singular or nearly singular!\n " );
175+ return ;
176+ }
177+
178+ if (pivot_row != col) {
179+ for (int j = 0 ; j < N; ++j) {
180+ std::swap (a[col][j], a[pivot_row][j]);
181+ std::swap (b[col][j], b[pivot_row][j]);
182+ }
183+ }
184+
185+ double diag = a[col][col];
186+ for (int j = 0 ; j < N; ++j) {
187+ a[col][j] /= diag;
188+ b[col][j] /= diag;
189+ }
190+
191+ for (int row = 0 ; row < N; ++row) {
192+ if (row != col) {
193+ double factor = a[row][col];
177194 for (int j = 0 ; j < N; ++j) {
178- a[i ][j] -= factor * a[col][j];
179- b[i ][j] -= factor * b[col][j];
195+ a[row ][j] -= factor * a[col][j];
196+ b[row ][j] -= factor * b[col][j];
180197 }
181198 }
182199 }
183200 }
184201
202+ // 将计算得到的结果存储回 X 中(行主序)
185203 for (int i = 0 ; i < N; ++i) {
186204 for (int j = 0 ; j < N; ++j) {
187205 X[i * N + j] = b[i][j];
0 commit comments