Added pivoting to qwp solver
This commit is contained in:
parent
69071bbc5f
commit
ec92c89659
1 changed files with 43 additions and 14 deletions
|
@ -255,33 +255,62 @@ struct QwpWorker
|
|||
// (least squares fit for "A*x = y")
|
||||
//
|
||||
// Using gaussian elimination to get M := [Id x]
|
||||
// (no pivoting, so let's hope for the best..)
|
||||
|
||||
// eliminate to upper triangular matrix
|
||||
vector<int> pivot_cache;
|
||||
vector<int> queue;
|
||||
|
||||
for (int i = 0; i < N; i++)
|
||||
queue.push_back(i);
|
||||
|
||||
// gaussian elimination
|
||||
for (int i = 0; i < N; i++)
|
||||
{
|
||||
// find best row
|
||||
int best_row = queue.front();
|
||||
int best_row_queue_idx = 0;
|
||||
double best_row_absval = 0;
|
||||
|
||||
for (int k = 0; k < GetSize(queue); k++) {
|
||||
int row = queue[k];
|
||||
double absval = fabs(M[i + row*N1]);
|
||||
if (absval > best_row_absval) {
|
||||
best_row = row;
|
||||
best_row_queue_idx = k;
|
||||
best_row_absval = absval;
|
||||
}
|
||||
}
|
||||
|
||||
int row = best_row;
|
||||
pivot_cache.push_back(row);
|
||||
|
||||
queue[best_row_queue_idx] = queue.back();
|
||||
queue.pop_back();
|
||||
|
||||
// normalize row
|
||||
for (int j = i+1; j < N+1; j++)
|
||||
M[(N+1)*i + j] /= M[(N+1)*i + i];
|
||||
M[(N+1)*i + i] = 1.0;
|
||||
for (int k = i+1; k < N1; k++)
|
||||
M[k + row*N1] /= M[i + row*N1];
|
||||
M[i + row*N1] = 1.0;
|
||||
|
||||
// elimination
|
||||
for (int j = i+1; j < N; j++) {
|
||||
double d = M[(N+1)*j + i];
|
||||
for (int k = 0; k < N+1; k++)
|
||||
if (k > i)
|
||||
M[(N+1)*j + k] -= d*M[(N+1)*i + k];
|
||||
else
|
||||
M[(N+1)*j + k] = 0.0;
|
||||
for (int other_row : queue) {
|
||||
double d = M[i + other_row*N1];
|
||||
for (int k = i+1; k < N1; k++)
|
||||
M[k + other_row*N1] -= d*M[k + row*N1];
|
||||
M[i + other_row*N1] = 0.0;
|
||||
}
|
||||
}
|
||||
|
||||
log_assert(queue.empty());
|
||||
log_assert(GetSize(pivot_cache) == N);
|
||||
|
||||
// back substitution
|
||||
for (int i = N-1; i >= 0; i--)
|
||||
for (int j = i+1; j < N; j++)
|
||||
{
|
||||
M[(N+1)*i + N] -= M[(N+1)*i + j] * M[(N+1)*j + N];
|
||||
M[(N+1)*i + j] = 0.0;
|
||||
int row = pivot_cache[i];
|
||||
int other_row = pivot_cache[j];
|
||||
M[N + row*N1] -= M[j + row*N1] * M[N + other_row*N1];
|
||||
M[j + row*N1] = 0.0;
|
||||
}
|
||||
|
||||
#ifdef LOG_MATRICES
|
||||
|
|
Loading…
Add table
Reference in a new issue