Trivial ordinary least squares in C++ & Armadillo

Using plain C++ for complex numerical calculations can be a lot of work. Luckily, some libraries can do much of the heavy lifting. I'm going to show you how to use Armadillo to solve a simple ordinary least squares (OLS) fitting problem.

I'm assuming that you have the Armadillo library installed on your system and that it is available in the library search path for g++.

Ordinary least squares (OLS)

We consider the following simple linear expression

\[ y = ax + n \]

where \(a\) is a fixed unknown coefficient, \(x\) is the input, \(n\) is additive zero-mean noise and \(y\) is the observation. Given a set of sample points \(\mathbf{x} = [x_1, x_2, \ldots, x_N]^T\) and the corresponding observations \(\mathbf{y} = [y_1, y_2, \ldots, y_N]^T\), we want to find the value of \(a\) that minimizes the squared error between the observations and the model prediction \(a\mathbf{x}\)

\[ \|\mathbf{y} - a\mathbf{x}\|^2 \]

The noise is unknown, so it does not appear in the prediction. To find the minimum, we expand the squared norm

\[ \|\mathbf{y} - a\mathbf{x}\|^2 = \mathbf{y}^T\mathbf{y} - 2a\,\mathbf{x}^T\mathbf{y} + a^2\,\mathbf{x}^T\mathbf{x} \]

and differentiate it with respect to \(a\):

\[ {d \over da} \|\mathbf{y} - a\mathbf{x}\|^2 = -2\,\mathbf{x}^T\mathbf{y} + 2a\,\mathbf{x}^T\mathbf{x} \]

Setting this to zero gives us

\[-2 \mathbf{x}^T\mathbf{y} + 2a\mathbf{x}^T\mathbf{x} = 0 \Leftrightarrow a = (\mathbf{x}^T\mathbf{x})^{-1}\mathbf{x}^T\mathbf{y}\]

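This is our least squares estimate of \(a\). The second derivative \(2\mathbf{x}^T\mathbf{x}\) is positive whenever \(\mathbf{x} \neq \mathbf{0}\), so this stationary point is indeed a minimum.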
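As a small worked example with made-up numbers (purely illustrative, not produced by the program), take \(N = 3\) sample points \(\mathbf{x} = [1, 2, 3]^T\) and observations \(\mathbf{y} = [2.1, 3.9, 6.2]^T\). With a single coefficient, \(\mathbf{x}^T\mathbf{x}\) is a scalar, so the inverse is just a division:

\[ \mathbf{x}^T\mathbf{x} = 1 + 4 + 9 = 14, \quad \mathbf{x}^T\mathbf{y} = 2.1 + 7.8 + 18.6 = 28.5, \quad a = {28.5 \over 14} \approx 2.036 \]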

Solution in C++

We will write a program that computes OLS estimates for different numbers of samples and prints the resulting average errors in JSON format for later processing.

Along with some C++ standard library headers, we need to include the Armadillo header:

#include <armadillo>
#include <cmath>
#include <iostream>

To perform OLS for a given set of sample points and observations, we write a separate function that takes both as Armadillo matrices (arma::mat) and returns the estimate \((\mathbf{x}^T\mathbf{x})^{-1}\mathbf{x}^T\mathbf{y}\).


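/**
 Perform ordinary least squares estimation for given sample vectors

 \param x Sample points (one column per coefficient)
 \param y Observations

 \return least squares estimate of the coefficient(s)
 */
inline arma::mat OLS(const arma::mat& x, const arma::mat& y) {
    // a = (x^T x)^{-1} x^T y. arma::solve(x, y) yields the same least squares
    // solution without forming an explicit inverse and is usually more stable.
    return arma::inv(x.t() * x) * x.t() * y;
}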

In the main function, we generate the samples and compute the mean absolute error of the estimate, averaged over 100 repetitions (iterations), for each number of samples s between 1 and 50 (samplesNum).

The true value of the coefficient we estimate is a, and the noise variance is fixed at noiseVar.

/**
 * Entry point
 */
int main(int argc, char **argv) {
    // Input line a*x 
    const double a = 2; // Coefficient we are trying to estimate

    // Noise variance
    const double noiseVar = 2;

    // Maximum number of samples and number of repetitions per sample count
    const int samplesNum = 50;
    const int iterations = 100;

    std::cout << "{" << std::endl;
    std::cout << "\"error\": [" << std::endl;

    for (int s = 1; s <= samplesNum; s++) {
        arma::mat err = arma::zeros(1, 1);  // error accumulator must start at zero

        for (int i = 0; i < iterations; i++) {
            arma::mat noise = std::sqrt(noiseVar) * arma::randn(s, 1);  // standard deviation = sqrt(variance)

            // Sample the reference line with additive noise
            arma::mat x = arma::randn(s, 1);
            arma::mat y = a*x + noise; 

            arma::mat est = OLS(x, y);

            err += arma::abs(est - a);  // absolute estimation error
        }

        err /= iterations;

        std::cout << "[" << s << "," << err(0) << "]";
        if (s < samplesNum)
            std::cout << ',';
    }

    std::cout << "] }" << std::endl;

    return 0;
}

Save the complete Armadillo program (the includes, OLS() and main()) as sls.cc and compile it with g++:

g++ -Wall -Werror -O3 -std=c++11 sls.cc -larmadillo -o sls

To store the error data in data.json, run the program and redirect its standard output:

./sls > data.json

The JSON data can then be plotted in Python with Matplotlib:

import json
import matplotlib.pyplot as plt

# Load JSON data
with open('data.json') as f:
    data = json.load(f)

# Extract the data from JSON
samples = [x[0] for x in data['error']]
errors = [x[1] for x in data['error']]

# Plot the figure 
fig, ax = plt.subplots()
ax.plot(samples, errors)
ax.set(xlabel='Samples', ylabel='Error', title='Ordinary least squares error')
ax.grid()
plt.show()