Added pivoting to qwp solver

This commit is contained in:
Clifford Wolf 2015-09-24 22:16:37 +02:00
parent 69071bbc5f
commit ec92c89659

View file

@ -255,33 +255,62 @@ struct QwpWorker
// (least squares fit for "A*x = y")
//
// Using gaussian elimination to get M := [Id x]
// (no pivoting, so let's hope for the best..)
// eliminate to upper triangular matrix
vector<int> pivot_cache;
vector<int> queue;
for (int i = 0; i < N; i++)
queue.push_back(i);
// gaussian elimination
for (int i = 0; i < N; i++)
{
// find best row
int best_row = queue.front();
int best_row_queue_idx = 0;
double best_row_absval = 0;
for (int k = 0; k < GetSize(queue); k++) {
int row = queue[k];
double absval = fabs(M[i + row*N1]);
if (absval > best_row_absval) {
best_row = row;
best_row_queue_idx = k;
best_row_absval = absval;
}
}
int row = best_row;
pivot_cache.push_back(row);
queue[best_row_queue_idx] = queue.back();
queue.pop_back();
// normalize row
for (int j = i+1; j < N+1; j++)
M[(N+1)*i + j] /= M[(N+1)*i + i];
M[(N+1)*i + i] = 1.0;
for (int k = i+1; k < N1; k++)
M[k + row*N1] /= M[i + row*N1];
M[i + row*N1] = 1.0;
// elimination
for (int j = i+1; j < N; j++) {
double d = M[(N+1)*j + i];
for (int k = 0; k < N+1; k++)
if (k > i)
M[(N+1)*j + k] -= d*M[(N+1)*i + k];
else
M[(N+1)*j + k] = 0.0;
for (int other_row : queue) {
double d = M[i + other_row*N1];
for (int k = i+1; k < N1; k++)
M[k + other_row*N1] -= d*M[k + row*N1];
M[i + other_row*N1] = 0.0;
}
}
log_assert(queue.empty());
log_assert(GetSize(pivot_cache) == N);
// back substitution
for (int i = N-1; i >= 0; i--)
for (int j = i+1; j < N; j++)
{
M[(N+1)*i + N] -= M[(N+1)*i + j] * M[(N+1)*j + N];
M[(N+1)*i + j] = 0.0;
int row = pivot_cache[i];
int other_row = pivot_cache[j];
M[N + row*N1] -= M[j + row*N1] * M[N + other_row*N1];
M[j + row*N1] = 0.0;
}
#ifdef LOG_MATRICES