import numpy as np
import pandas as pd
import scipy.optimize
import time
from itertools import combinations
import matplotlib.pyplot as plt
%matplotlib inline
import seaborn as sns
# utility functions
def rand(a, sd):
    """Add Gaussian noise with standard deviation sd to array a."""
    return a + np.random.normal(0, sd, a.shape)

def sse(params, data):
    """Sum of squared errors between model predictions and observed y."""
    return np.sum(np.power(model(params, data) - data['y'], 2))
# model definition
def pre(params, data):
    """Line before the switch point."""
    intercept = params[0]
    slope = params[1]
    return intercept + (slope * data['x1'])

def post(params, data):
    """Line after the switch point, anchored so the two lines meet there."""
    slope = params[2]
    switchPoint = switch(params, data)
    return (slope * (data['x1'] - switchPoint)) + (params[0] + (params[1] * switchPoint))

def switch(params, data):
    """Switch point as a linear function of z1."""
    intercept = params[3]
    slope = params[4]
    return intercept + (slope * data['z1'])
def model(params, data):
    # params = [preIntercept, preSlope, postSlope, switchIntercept, switchSlope]
    # enforce parameter constraints to boost optimization efficiency by
    # returning a huge penalty value for out-of-bounds parameters
    bad = 9999999999999
    # pre-slope must be negative
    if params[1] >= 0:
        return bad
    # post-slope must be negative
    if params[2] >= 0:
        return bad
    # pre-slope must be shallower than post-slope, or the break has the wrong shape
    if params[1] <= params[2]:
        return bad
    switchPoint = switch(params, data)
    ypre = pre(params, data)
    ypost = post(params, data)
    # piecewise combination: pre() below the switch point, post() at or above it
    isPre = data['x1'] < switchPoint
    y = (isPre * ypre) + ((1 - isPre) * ypost)
    return y
This is a broken line model. It yields two straight lines connected at a switch point. Before the switch point, the model is equivalent to pre(); after the switch point, it is equivalent to post(). The switch point itself is permitted to vary and is linearly related to z1. When the coefficient relating z1 to the switch point is positive, z1 can be considered a protective factor, with increasing z1 yielding later switch points. When that coefficient is negative, z1 can be considered a risk factor, with increasing z1 yielding earlier switch points.
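In equation form (a direct restatement of the code above, using $b_0, b_1, b_2$ as shorthand for preIntercept, preSlope, postSlope and $c_0, c_1$ for switchIntercept, switchSlope):

$$
y = \begin{cases}
b_0 + b_1 x_1 & x_1 < s(z_1) \\
b_0 + b_1 s(z_1) + b_2 \, (x_1 - s(z_1)) & x_1 \ge s(z_1)
\end{cases}
\qquad
s(z_1) = c_0 + c_1 z_1
$$

Because post() is anchored at pre()'s value at the switch point, the two line segments always meet there.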
Let's visualize what this model looks like.
# preIntercept, preSlope, postSlope, switchIntercept, switchSlope
params = [100, -.1, -1, 45, .5]
zs = np.arange(60,90,5)
plt.figure(figsize=(10,8))
plt.gca().set_facecolor('xkcd:gray')
plt.gca().set_prop_cycle('color',plt.cm.hot(np.linspace(0,1,len(zs))))
for z in zs:
    x1 = pd.Series(range(60,90,1), name='x1')
    z1 = pd.Series(z * np.ones(len(x1)), name='z1')
    df = pd.concat([x1, z1], axis=1)
    plt.plot(df['x1'], model(params, df), label='z='+str(z), linewidth=4)
plt.xlabel('x1', fontsize=18)
plt.ylabel('y', fontsize=18)
plt.gca().legend(prop={'size': 16}).get_frame().set_facecolor('xkcd:gray')
plt.show()
Here we can see a gradual decline before the switch point that steepens dramatically after it. The switch point is positively related to z1, so those with lower values of z1 experience the steep post-switch decline much earlier than those with higher values of z1.
Let's generate some synthetic data and see if we can recover the parameter values.
# generate some synthetic data
params = [100, -.1, -1, 45, .5]
nDataPoints = 1000
mn = 60
mx = 90
df = pd.DataFrame(data={'x1': np.random.randint(mn, high=mx, size=nDataPoints), \
'z1': np.random.randint(mn, high=mx, size=nDataPoints)})
noise = 0.9
df['y'] = rand(model(params, df).values, noise)
df['x1'] = rand(df.x1.values, noise)
df['z1'] = rand(df.z1.values, noise)
# plot the simulated data
plt.figure(figsize=(10,8))
plt.title('Broken Line Data', fontsize=20)
plt.scatter(df['x1'], df['y'], s=20, c=df['z1'], alpha=0.85, cmap='hot')
plt.xlabel('x1', fontsize=16)
plt.ylabel('y', fontsize=16)
plt.gca().set_facecolor('xkcd:gray')
plt.show()
# optimize
# initial guess for the gradient-based and basinhopping routines (commented out below)
paramsInit = [0, -1, -2, 0, 0]
# grid for the brute-force search (used only by the brute routine below)
ranges = (slice(-100, 100, 30), slice(-10, 10, 3), slice(-10, 10, 3), slice(-100, 100, 10), slice(-10, 10, 1))
# parameter bounds for differential evolution (the final of several widths tried)
bounds = [(-200,200), (-10,10), (-10,10), (-200,200), (-10,10)]
start = time.time()
# hand the sse function to an optimization routine;
# other routines tried (kept for reference):
#result = scipy.optimize.minimize(sse, paramsInit, method='BFGS', args=df)
#result = scipy.optimize.minimize(sse, paramsInit, method='Nelder-Mead', args=df)
#result = scipy.optimize.brute(sse, ranges, args=(df,))
#result = scipy.optimize.basinhopping(sse, paramsInit, minimizer_kwargs={'args':df}, T=10, disp=False)
#result = scipy.optimize.basinhopping(sse, paramsInit, minimizer_kwargs={'args':df})
#result = scipy.optimize.differential_evolution(sse, bounds, args=(df,), mutation=(.5,1.9), popsize=50)
result = scipy.optimize.differential_evolution(sse, bounds, args=(df,), popsize=50)
end = time.time()
print(f'Optimization took: {end - start:.2f}s')
print(result.message)
print('True parameter values:\t'+', '.join('{:.2f}'.format(f) for f in params))
print('Estimates:\t\t'+', '.join('{:.2f}'.format(f) for f in result.x))
print('Differences:\t\t'+', '.join('{:.2f}'.format(f) for f in (params - result.x)))
print('% Differences:\t\t'+', '.join('{:.1f}%'.format(f) for f in 100*(params - result.x)/params))
print(f'SSE:\t\t\t{result.fun:.2f}')
Things look good so far. When calling the differential_evolution optimization routine, note that I cranked the population size up to 50. That is probably overkill (the default is 15), but I wanted the estimates to be solid for this initial demonstration. Subsequent fits will mostly use the default population size, since the population size strongly influences how long the optimization takes, but keep this important parameter in mind.
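If you want to see the trade-off for yourself, here is a minimal sketch (my addition, reusing the sse, bounds, and df defined above) that refits with the default population size and compares wall time and SSE; exact numbers will vary from run to run.
# refit with the default popsize (15) to gauge the speed/accuracy trade-off
start = time.time()
resultDefault = scipy.optimize.differential_evolution(sse, bounds, args=(df,))
print(f'Default popsize took: {time.time() - start:.2f}s, SSE: {resultDefault.fun:.2f}')
print(f'popsize=50 SSE (from above): {result.fun:.2f}')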
Now let's visualize the model's behavior with the estimated parameter values and compare it to the simulated data.
# plot simulated data
plt.figure(figsize=(10,8))
plt.title('Broken Line Fit to Broken Line Data', fontsize=20)
plt.scatter(df['x1'], df['y'], s=20, c=df['z1'], alpha=0.85, cmap='hot')
plt.xlabel('x1', fontsize=16)
plt.ylabel('y', fontsize=16)
plt.gca().set_facecolor('xkcd:gray')
# plot model fit
xs = np.arange(df['x1'].min(), df['x1'].max())
plt.plot(xs, model(result.x, pd.DataFrame(data={'x1': xs, 'z1': np.ones(len(xs))*df['z1'].min()})), color='k', linewidth=4)
plt.plot(xs, model(result.x, pd.DataFrame(data={'x1': xs, 'z1': np.ones(len(xs))*df['z1'].quantile(q=.5)})), color='r', linewidth=4)
plt.plot(xs, model(result.x, pd.DataFrame(data={'x1': xs, 'z1': np.ones(len(xs))*df['z1'].max()})), color='w', linewidth=4)
plt.show()
# plot residuals
plt.figure(figsize=(8,6))
plt.title('Residuals', fontsize=20)
plt.scatter(df['x1'], df['y'] - model(result.x, df), color='k', alpha=.5)
plt.xlabel('x1', fontsize=16)
plt.axhline(y=0, color='k')
plt.ylabel('residual', fontsize=16)
plt.show()
Looks good. Can't even see where the switch point is among the residuals. That's a good sign.
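To make that check explicit, here is a small sketch (my addition, reusing switch, model, result.x, and df from above) that plots each residual against its point's distance from the estimated switch point; a visible kink at zero would betray a misplaced break.
# plot residuals against distance from each point's estimated switch point
switchEst = switch(result.x, df)
plt.figure(figsize=(8,6))
plt.scatter(df['x1'] - switchEst, df['y'] - model(result.x, df), color='k', alpha=.5)
plt.axvline(x=0, color='r')
plt.axhline(y=0, color='k')
plt.xlabel('x1 - estimated switch point', fontsize=16)
plt.ylabel('residual', fontsize=16)
plt.show()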
Let's see whether the broken line model fits the data better than a line and, if so, how much better.
def plainOldLinear(params, data):
    """Plain linear model: y as a linear function of x1 and z1."""
    intercept = params[0]
    slopeX = params[1]
    slopeZ = params[2]
    return intercept + (slopeX * data['x1']) + (slopeZ * data['z1'])

def ssePlainLinear(params, data):
    """Sum of squared errors for the plain linear model."""
    return np.sum(np.power(plainOldLinear(params, data) - data['y'], 2))

# bounds: intercept, x1 slope (negative), z1 slope (positive)
bounds = [(-200,200), (-10,0), (0,10)]
start = time.time()
result = scipy.optimize.differential_evolution(ssePlainLinear, bounds, args=(df,), popsize=50)
end = time.time()
print(f'Optimization took: {end - start:.2f}s')
print(result.message)
print('Estimates:\t\t'+', '.join('{:.2f}'.format(f) for f in result.x))
print(f'SSE:\t\t\t{result.fun:.2f}')
# plot simulated data
plt.figure(figsize=(10,8))
plt.title('Linear Fit to Broken Line Data', fontsize=20)
plt.scatter(df['x1'], df['y'], s=20, c=df['z1'], alpha=0.85, cmap='hot')
plt.xlabel('x1', fontsize=16)
plt.ylabel('y', fontsize=16)
plt.gca().set_facecolor('xkcd:gray')
# plot model fit
xs = np.arange(df['x1'].min(), df['x1'].max())
plt.plot(xs, plainOldLinear(result.x, pd.DataFrame(data={'x1': xs, 'z1': np.ones(len(xs))*df['z1'].min()})), color='k', linewidth=4)
plt.plot(xs, plainOldLinear(result.x, pd.DataFrame(data={'x1': xs, 'z1': np.ones(len(xs))*df['z1'].quantile(q=.5)})), color='r', linewidth=4)
plt.plot(xs, plainOldLinear(result.x, pd.DataFrame(data={'x1': xs, 'z1': np.ones(len(xs))*df['z1'].max()})), color='w', linewidth=4)
plt.show()
The best fitting lines are...lines. The upper line is the one estimated for those with the largest values of z1 (the bright points), and the lower line is the one estimated for those with the smallest values of z1 (the dark points). The middle line is the one associated with the median value of z1. This seems reasonable.
# Compare fits of the broken line and plain linear models
# (SSE values copied from the fits above; they will vary run to run)
sseBroken = 1048.48
sseLinear = 3779.56
# intercept-only baseline: squared deviations of y around its mean
sseTotal = np.sum(np.power(df['y'] - df['y'].mean(), 2))
print(f'Broken line SSE:\t{sseBroken:,.2f}')
print(f'Linear SSE:\t\t{sseLinear:,.2f}')
print(f'Intercept-only SSE:\t{sseTotal:,.2f}')
The simple linear model is much better than an intercept-only model but substantially worse than the broken line model. Just as hoped!
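As a quick summary, here is a minimal sketch (my addition) that converts those SSE values into the proportion of variance explained by each model relative to the intercept-only baseline.
# R^2 relative to the intercept-only baseline
for name, s in [('broken line', sseBroken), ('linear', sseLinear)]:
    print(f'{name} R^2: {1 - s/sseTotal:.3f}')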
Now let's flip the experiment: generate purely linear data, with no switch point at all, and see what the broken line model does when there is nothing to find.
# generate some synthetic data
params = [100, -.1]
nDataPoints = 1000
mn = 60
mx = 90
df = pd.DataFrame(data={'x1': np.random.randint(mn, high=mx, size=nDataPoints), \
'z1': np.random.randint(mn, high=mx, size=nDataPoints)})
noise = 0.75
df['y'] = rand(params[0]+(df['x1']*params[1]), noise)
df['x1'] = rand(df.x1.values, noise)
df['z1'] = rand(df.z1.values, noise)
# plot the simulated data
plt.figure(figsize=(10,8))
plt.title('Linear Data', fontsize=20)
plt.scatter(df['x1'], df['y'], s=20, alpha=0.85, c=df['z1'], cmap='hot')
plt.xlabel('x1', fontsize=16)
plt.ylabel('y', fontsize=16)
plt.gca().set_facecolor('xkcd:gray')
plt.show()
# optimize
bounds = [(-200,200), (-10,10), (-10,10), (-200,200), (-10,10)]
start = time.time()
# hand the sse function to an optimization routine
result = scipy.optimize.differential_evolution(sse, bounds, args=(df,), popsize=25)
end = time.time()
print(f'Optimization took: {end - start:.2f}s')
print(result.message)
print('Estimates:\t\t'+', '.join('{:.2f}'.format(f) for f in result.x))
print(f'SSE:\t\t\t{result.fun:.2f}')
# plot simulated data
plt.figure(figsize=(10,8))
plt.title('Broken Line Fit to Linear Data', fontsize=20)
plt.scatter(df['x1'], df['y'], s=20, alpha=0.85, c=df['z1'], cmap='hot')
plt.xlabel('x1', fontsize=16)
plt.ylabel('y', fontsize=16)
plt.gca().set_facecolor('xkcd:gray')
# plot model fit
xs = np.arange(df['x1'].min(), df['x1'].max())
plt.plot(xs, model(result.x, pd.DataFrame(data={'x1': xs, 'z1': np.ones(len(xs))*df['z1'].min()})), color='k', linewidth=4)
plt.plot(xs, model(result.x, pd.DataFrame(data={'x1': xs, 'z1': np.ones(len(xs))*df['z1'].quantile(q=.5)})), color='r', linewidth=4)
plt.plot(xs, model(result.x, pd.DataFrame(data={'x1': xs, 'z1': np.ones(len(xs))*df['z1'].max()})), color='w', linewidth=4)
plt.show()