In [1]:
import numpy as np
import pandas as pd
import scipy.optimize
import time
from itertools import combinations
import matplotlib.pyplot as plt
%matplotlib inline
import seaborn as sns
In [2]:
# utility functions
def rand(a, sd):
    return a + np.random.normal(0, sd, a.shape)

def sse(params, data):
    sse = np.sum(np.power(model(params, data) - data['y'], 2))
    return sse

Broken Line Model Definition

In [3]:
# model definition

def pre(params, data):
    intercept = params[0]
    slope = params[1]
    return intercept + (slope * data['x1'])


def post(params, data):
    slope = params[2]
    switchPoint = switch(params, data)
    return (slope * (data['x1'] - switchPoint)) + (params[0] + (params[1] * switchPoint))


def switch(params, data):
    intercept = params[3]
    slope = params[4]
    return intercept + (slope * data['z1'])


def model(params, data):
    # params = [preIntercept, preSlope, postSlope, switchIntercept, switchSlope]

    # enforce parameter constraints to boost optimization efficiency
    bad = 9999999999999
    # pre-slope must be negative
    if params[1] >= 0:
        return bad
    # post-slope must be negative
    if params[2] >= 0:
        return bad
    # switch has wrong shape
    if params[1] <= params[2]:
        return bad

    switchPoint = switch(params, data)
    ypre = pre(params, data)
    ypost = post(params, data)
    isPre = data['x1'] < switchPoint
    y = (isPre * ypre) + ((1 - isPre) * ypost)
    return y

This is a broken line model. It yields two straight lines connected at a switch point. Before the switch point, the model is equivalent to pre(), after the switch point the model is equivalent to post(). The switch point itself is permitted to vary and is linearly related to z1. When the coefficient relating z1 to the switch point is positive, z1 can be considered a protective factor, with increasing z1 yielding later switch points. When the coefficient relating z1 to the switch point is negative, z1 can be considered a risk factor, with increasing z1 yielding earlier switch points.

Let's visualize what this model looks like.

In [4]:
# preIntercept, preSlope, postSlope, switchIntercept, switchSlope
params = [100, -.1, -1, 45, .5]

zs = np.arange(60,90,5)
plt.figure(figsize=(10,8))
plt.gca().set_facecolor('xkcd:gray')
plt.gca().set_prop_cycle('color',plt.cm.hot(np.linspace(0,1,len(zs))))

for z in zs:
    x1 = pd.Series(range(60,90,1), name='x1')
    z1 = pd.Series(z * np.ones(len(x1)), name='z1')
    df = pd.concat([x1, z1], axis=1)
    plt.plot(df['x1'],model(params, df), label='z='+str(z), linewidth=4)

plt.xlabel('x1', fontsize=18)
plt.ylabel('y', fontsize=18)
plt.gca().legend(prop={'size': 16}).get_frame().set_facecolor('xkcd:gray')
plt.show()

Here we can see that there is a general decline before the switch point that increases dramatically after the switch point. The switch point is positively related to z1, so those with lower values of z1 experience this post-switch decline much earlier than those with higher values of z1.

Simulated Data

Let's generate some synthetic data and see if we can recover the parameter values.

In [79]:
# generate some synthetic data

params = [100, -.1, -1, 45, .5]

nDataPoints = 1000
mn = 60
mx = 90
df = pd.DataFrame(data={'x1': np.random.randint(mn, high=mx, size=nDataPoints), \
                        'z1': np.random.randint(mn, high=mx, size=nDataPoints)})
noise = 0.9
#tempY = df.apply(model, args=(params,), axis=1)
df['y'] = rand(model(params, df).values, noise)
df['x1'] = rand(df.x1.values, noise)
df['z1'] = rand(df.z1.values, noise)
In [81]:
# plot the simulated data
plt.figure(figsize=(10,8))
plt.title('Broken Line Data', fontsize=20)
plt.scatter(df['x1'], df['y'], s=20, c=df['z1'], alpha=0.85, cmap='hot')
plt.xlabel('x1', fontsize=16)
plt.ylabel('y', fontsize=16)
plt.gca().set_facecolor('xkcd:gray')
plt.show()
In [82]:
# optimize
paramsInit = [0, -1, -2, 0, 0]
ranges = (slice(-100, 100, 30), slice(-10, 10, 3), slice(-10, 10, 3), slice(-100, 100, 10), slice(-10, 10, 1))
bounds = [(-200,200), (-100,100), (-100,100), (-200,200), (-100,100)]
bounds = [(0,200), (-1,1), (-1,1), (0,200), (-1,1)]
bounds = [(-200,200), (-10,10), (-10,10), (-200,200), (-10,10)]

start = time.time()
# hand the sse function to an optimization routine
#result = scipy.optimize.minimize(sse, paramsInit, method='BFGS', args=df)
#result = scipy.optimize.minimize(sse, paramsInit, method='Nelder-Mead', args=df)
#result = scipy.optimize.brute(sse, ranges, args=(df,))
#result = scipy.optimize.basinhopping(sse, paramsInit, minimizer_kwargs={'args':df}, T=10, disp=False)
#result = scipy.optimize.basinhopping(sse, paramsInit, minimizer_kwargs={'args':df})
#result = scipy.optimize.differential_evolution(sse, bounds, args=(df,), mutation=(.5,1.9), popsize=50)
#result = scipy.optimize.differential_evolution(sse, bounds, args=(df,), popsize=50)
result = scipy.optimize.differential_evolution(sse, bounds, args=(df,), popsize=50)
end = time.time()
print(f'Optimization took: {end - start:.2f}s')
Optimization took: 75.37s
In [83]:
print(result.message)

print('True parameter values:\t'+', '.join('{:.2f}'.format(f) for f in params))
print('Estimates:\t\t'+', '.join('{:.2f}'.format(f) for f in result.x))
print('Differences:\t\t'+', '.join('{:.2f}'.format(f) for f in (params - result.x)))
print('% Differences:\t\t'+', '.join('{:.1f}%'.format(f) for f in 100*(params - result.x)/params))
print(f'SSE:\t\t\t{result.fun}')
Optimization terminated successfully.
True parameter values:	100.00, -0.10, -1.00, 45.00, 0.50
Estimates:		100.03, -0.10, -0.93, 45.14, 0.50
Differences:		-0.03, -0.00, -0.07, -0.14, 0.00
% Differences:		-0.0%, 0.1%, 7.1%, -0.3%, 0.8%
SSE:			1076.5584354106873

Things look good so far. When calling the differential_evolution optimization routine, note that I cranked up the population size to 50. This is probably overkill (default is 15), but I wanted the estimates to be solid for this initial demonstration. Subsequent fits will probably use the default population size (because the population size strongly influences how long the optimization takes), but keep in mind this important parameter.

Now let's visualize the model's behavior with the estimated parameter values and compare it to the simulated data.

In [85]:
# plot simulated data
plt.figure(figsize=(10,8))
plt.title('Broken Line Fit to Broken Line Data', fontsize=20)
plt.scatter(df['x1'], df['y'], s=20, c=df['z1'], alpha=0.85, cmap='hot')
plt.xlabel('x1', fontsize=16)
plt.ylabel('y', fontsize=16)
plt.gca().set_facecolor('xkcd:gray')

# plot model fit
xs = np.arange(df['x1'].min(), df['x1'].max())
plt.plot(xs, model(result.x, pd.DataFrame(data={'x1': xs, 'z1': np.ones(len(xs))*df['z1'].min()})), color='k', linewidth=4)
plt.plot(xs, model(result.x, pd.DataFrame(data={'x1': xs, 'z1': np.ones(len(xs))*df['z1'].quantile(q=.5)})), color='r', linewidth=4)
plt.plot(xs, model(result.x, pd.DataFrame(data={'x1': xs, 'z1': np.ones(len(xs))*df['z1'].max()})), color='w', linewidth=4)

plt.show()
In [88]:
# plot residuals
plt.figure(figsize=(8,6))
plt.title('Residuals', fontsize=20)
plt.scatter(df['x1'], model(result.x, df)-df['y'], color='k', alpha=.5)
plt.xlabel('x', fontsize=16)
plt.axhline(y=0, color='k')
plt.ylabel('residual', fontsize=16)
plt.show()

Looks good. Can't even see where the switch point is among the residuals. That's a good sign.

Let's see whether the broken line model fits the data better than a line and, if so, how much better.

In [90]:
def plainOldLinear(params, data):
    intercept = params[0]
    slopeX = params[1]
    slopeZ = params[2]
    return intercept + (slopeX * data['x1']) + (slopeZ * data['z1'])

def ssePlainLinear(params, data):
    sse = np.sum(np.power(plainOldLinear(params, data) - data['y'], 2))
    return sse

bounds = [(-200,200), (-10,0), (0,10)]

start = time.time()
result = scipy.optimize.differential_evolution(ssePlainLinear, bounds, args=(df,), popsize=50)
end = time.time()
print(f'Optimization took: {end - start:.2f}s')
Optimization took: 5.51s
In [91]:
print(result.message)

print('Estimates:\t\t'+', '.join('{:.2f}'.format(f) for f in result.x))
print(f'SSE:\t\t\t{result.fun:.2f}')
Optimization terminated successfully.
Estimates:		102.72, -0.24, 0.09
SSE:			3614.32
In [92]:
# plot simulated data
plt.figure(figsize=(10,8))
plt.title('Linear Fit to Broken Line Data', fontsize=20)
plt.scatter(df['x1'], df['y'], s=20, c=df['z1'], alpha=0.85, cmap='hot')
plt.xlabel('x1', fontsize=16)
plt.ylabel('y', fontsize=16)
plt.gca().set_facecolor('xkcd:gray')

# plot model fit
xs = np.arange(df['x1'].min(), df['x1'].max())
plt.plot(xs, plainOldLinear(result.x, pd.DataFrame(data={'x1': xs, 'z1': np.ones(len(xs))*df['z1'].min()})), color='k', linewidth=4)
plt.plot(xs, plainOldLinear(result.x, pd.DataFrame(data={'x1': xs, 'z1': np.ones(len(xs))*df['z1'].quantile(q=.5)})), color='r', linewidth=4)
plt.plot(xs, plainOldLinear(result.x, pd.DataFrame(data={'x1': xs, 'z1': np.ones(len(xs))*df['z1'].max()})), color='w', linewidth=4)

plt.show()

The best fitting lines are...lines. The upper is the line estimated for those with the largest values of z1 (the bright points) and the bottom is the line estimated for those with the smallest values of z1 (the dark points). The middle line is the that associated with the median value of z1. This seems reasonable.

In [93]:
# Compare fits of broken line and pure linear models

sseBroken = 1048.48
sseLinear = 3779.56
sseTotal = np.sum(np.power(df['y'], 2))
print(f'Total SSE: {sseTotal:,.2f}')
Total SSE: 8,414,232.26

Simple linear model is much better than an intercept-only model, but substantially worse than the broken line model. Just as hoped!

Fit model(s) to synthetic linear data

In [94]:
# generate some synthetic data

params = [100, -.1]

nDataPoints = 1000
mn = 60
mx = 90
df = pd.DataFrame(data={'x1': np.random.randint(mn, high=mx, size=nDataPoints), \
                        'z1': np.random.randint(mn, high=mx, size=nDataPoints)})
noise = 0.75
df['y'] = rand(params[0]+(df['x1']*params[1]), noise)
df['x1'] = rand(df.x1.values, noise)
df['z1'] = rand(df.z1.values, noise)
In [96]:
# plot the simulated data
plt.figure(figsize=(10,8))
plt.title('Linear Data', fontsize=20)
plt.scatter(df['x1'], df['y'], s=20, alpha=0.85, c=df['z1'], cmap='hot')
plt.xlabel('x1', fontsize=16)
plt.ylabel('y', fontsize=16)
plt.gca().set_facecolor('xkcd:gray')
plt.show()
In [97]:
# optimize
bounds = [(-200,200), (-10,10), (-10,10), (-200,200), (-10,10)]

start = time.time()
# hand the sse function to an optimization routine
result = scipy.optimize.differential_evolution(sse, bounds, args=(df,), popsize=25)
end = time.time()
print(f'Optimization took: {end - start:.2f}s')
Optimization took: 8.77s
In [98]:
print(result.message)

print('Estimates:\t\t'+', '.join('{:.2f}'.format(f) for f in result.x))
print(f'SSE:\t\t\t{result.fun:.2f}')
Optimization terminated successfully.
Estimates:		99.84, -0.10, -2.33, 71.79, 4.54
SSE:			568.67
In [99]:
# plot simulated data
plt.figure(figsize=(10,8))
plt.title('Broken Line Fit to Linear Data', fontsize=20)
plt.scatter(df['x1'], df['y'], s=20, alpha=0.85, c=df['z1'], cmap='hot')
plt.xlabel('x1', fontsize=16)
plt.ylabel('y', fontsize=16)
plt.gca().set_facecolor('xkcd:gray')

# plot model fit
xs = np.arange(df['x1'].min(), df['x1'].max())
plt.plot(xs, model(result.x, pd.DataFrame(data={'x1': xs, 'z1': np.ones(len(xs))*df['z1'].min()})), color='k', linewidth=4)
plt.plot(xs, model(result.x, pd.DataFrame(data={'x1': xs, 'z1': np.ones(len(xs))*df['z1'].quantile(q=.5)})), color='k', linewidth=4)
plt.plot(xs, model(result.x, pd.DataFrame(data={'x1': xs, 'z1': np.ones(len(xs))*df['z1'].max()})), color='k', linewidth=4)

plt.show()