Lagrangian Duality
Preamble: Run the cells below to import the necessary Python packages.
This notebook was created by William Gilpin. Consult the course website for all content and the GitHub repository for raw files and runnable online code.
# Import plotting functions and enable in-notebook display
import matplotlib.pyplot as plt
%matplotlib inline
# Import linear algebra module
import numpy as np
We start by stating our optimization problem in the form of a minimization problem:
$$ \begin{aligned} \text{minimize} \quad & f(x) \\ \text{subject to} \quad & x \geq 0, \end{aligned} $$
where the univariate $f: \mathbb{R} \to \mathbb{R}$ is a convex function. Any value of $x$ that satisfies the constraint $x \geq 0$ is called a feasible solution. The set of all feasible solutions is called the feasible set.
As an example, below we choose the function
$$ f(x) = x^2 - 2x $$
x = np.linspace(-4, 4, 100)
f = lambda x: x**2 - 2 * x
plt.plot(x, f(x))
plt.xlabel('x')
plt.ylabel('f(x)')
We start by writing this problem in terms of a Lagrangian. Writing the constraint as $-x \leq 0$ and introducing a Lagrange multiplier $\lambda \geq 0$ gives
$$ L(x, \lambda) = f(x) - \lambda\, x. $$
For a fixed $\lambda \geq 0$, we minimize the Lagrangian over $x$ without the constraint. Since $f$ is convex, so is $L$ in $x$, and its minimizer is found by setting the derivative with respect to $x$ equal to zero:
$$ \frac{d}{dx} L(x, \lambda) = f'(x) - \lambda = 0. $$
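The sketch below is a quick numerical check of this stationarity condition, using the example $f(x) = x^2 - 2x$. For a few fixed values of $\lambda \geq 0$, it minimizes $L(x, \lambda)$ by brute force over a dense grid of $x$. It then compares that minimizer with the root of $f'(x) - \lambda = 2x - 2 - \lambda = 0$, namely $x = (\lambda + 2)/2$. The grid range, the resolution, and the chosen $\lambda$ values are arbitrary illustrative assumptions.

```python
import numpy as np

# Example objective from the text, f(x) = x^2 - 2x, and its derivative
f = lambda x: x**2 - 2 * x
fprime = lambda x: 2 * x - 2

# Lagrangian for the constraint x >= 0 (written as -x <= 0)
lagrangian = lambda x, lam: f(x) - lam * x

# Dense grid over x; range and spacing are arbitrary illustration choices
x_grid = np.linspace(-10, 10, 200001)

results = {}
for lam in [0.0, 1.0, 2.5]:
    # Brute-force minimizer of L(., lam) on the grid
    x_grid_min = x_grid[np.argmin(lagrangian(x_grid, lam))]
    # Root of the stationarity condition f'(x) - lam = 0
    x_stationary = (lam + 2) / 2
    results[lam] = (x_grid_min, x_stationary)
    print(f"lambda={lam}: grid minimizer={x_grid_min:.4f}, "
          f"stationary point={x_stationary:.4f}, "
          f"f'(x) - lambda={fprime(x_stationary) - lam:.1e}")
```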
# Lagrangian L(x, lambda) = f(x) - lambda * x for the constraint x >= 0
lagrange_primal = lambda x, lam: f(x) - lam * x
xx = np.linspace(-4, 4, 100)
lamvals = np.linspace(0, 4, 100)
# create uniform meshgrid over (x, lambda)
x_mesh, lam_mesh = np.meshgrid(xx, lamvals)
langrange_primal_values = lagrange_primal(x_mesh, lam_mesh)
plt.figure()
plt.scatter(x_mesh.ravel(), lam_mesh.ravel(),
            c=langrange_primal_values.ravel(),
            cmap='viridis'
            )
plt.ylim(np.min(lamvals), np.max(lamvals))
plt.xlim(np.min(xx), np.max(xx))
plt.xlabel('x')
plt.ylabel(r'$\lambda$')
plt.colorbar(label=r'$L(x, \lambda) = f(x) - \lambda x$')
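The resulting relationship $\lambda = f'(x)$ is the stationarity condition, one of the Karush-Kuhn-Tucker (KKT) conditions for this problem. The others are primal feasibility ($x \geq 0$), dual feasibility ($\lambda \geq 0$, which was implicit above), and complementary slackness ($\lambda x = 0$). Complementary slackness says that one of two cases holds. Either $\lambda > 0$ and the constraint is active, so $x = 0$. Or $\lambda = 0$ and the constraint is inactive, so the solution is the unconstrained stationary point where $f'(x) = 0$.
For our example, we have $f'(x) = 2x - 2$, so the stationarity condition becomes $$ \lambda = 2x - 2, $$ together with the constraint $\lambda \geq 0$.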
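The full set of KKT conditions for this problem has four parts:

- stationarity ($f'(x) - \lambda = 0$)
- primal feasibility ($x \geq 0$)
- dual feasibility ($\lambda \geq 0$)
- complementary slackness ($\lambda x = 0$)

The helper `kkt_violations` below is hypothetical and not part of the original notebook. It reports how far a candidate pair $(x, \lambda)$ is from satisfying each condition. For the example, $(x, \lambda) = (1, 0)$ satisfies all four. The boundary point $x = 0$ would need $\lambda = f'(0) = -2$, which violates dual feasibility.

```python
fprime = lambda x: 2 * x - 2   # derivative of the example f(x) = x^2 - 2x

def kkt_violations(x, lam):
    """Hypothetical helper: amount by which (x, lam) violates each KKT
    condition for 'minimize f(x) subject to x >= 0' (0 means satisfied)."""
    return {
        "stationarity": abs(fprime(x) - lam),     # f'(x) - lambda = 0
        "primal feasibility": max(0.0, -x),       # x >= 0
        "dual feasibility": max(0.0, -lam),       # lambda >= 0
        "complementary slackness": abs(lam * x),  # lambda * x = 0
    }

print(kkt_violations(1.0, 0.0))   # interior candidate
print(kkt_violations(0.0, -2.0))  # boundary candidate with lambda = f'(0)
```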
# The KKT condition
lam_kkt = lambda x: 2 * x - 2
plt.figure()
plt.scatter(x_mesh.ravel(), lam_mesh.ravel(),
c=langrange_primal_values.ravel(),
cmap='viridis'
)
plt.plot(xx, lam_kkt(xx), '--r', dashes=(8, 3))  # KKT curve lambda = f'(x)
plt.ylim(np.min(lamvals), np.max(lamvals))
plt.xlim(np.min(xx), np.max(xx))
plt.xlabel('x')
plt.ylabel(r'$\lambda$')
Now we write the dual function. We revisit the Lagrangian and use the KKT condition to eliminate $x$, leaving a function of $\lambda$ only:
$$ q(\lambda) = \min_x L(x, \lambda) = \min_x \left[ f(x) - \lambda x \right]. $$
For our example, we have $f(x) = x^2 - 2x$, so the Lagrangian becomes $$ L(x, \lambda) = x^2 - 2x - \lambda x. $$
We insert the KKT condition solved for $x$, namely $x = \frac{\lambda + 2}{2}$, into the Lagrangian to get the dual function:
$$ q(\lambda) = \left(\frac{\lambda + 2}{2}\right)^2 - (2 + \lambda)\left(\frac{\lambda + 2}{2}\right), $$ which reduces to $$ q(\lambda) = -\frac{1}{4} (\lambda + 2)^2. $$
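The closed form can be checked directly against the definition $q(\lambda) = \min_x L(x, \lambda)$. The sketch below evaluates the Lagrangian on a dense grid of $x$ for several values of $\lambda$ and takes the minimum. The grid bounds are an assumption, chosen so that the minimizer $x = (\lambda + 2)/2$ lies well inside them.

```python
import numpy as np

f = lambda x: x**2 - 2 * x
lagrangian = lambda x, lam: f(x) - lam * x
q_closed_form = lambda lam: -0.25 * (lam + 2)**2

# Dense grid; bounds chosen so that the minimizer (lam + 2) / 2 lies inside
x_grid = np.linspace(-20, 20, 400001)
lam_vals = np.linspace(0, 4, 9)

# Dual function by brute force: minimize the Lagrangian over x for each lambda
q_numeric = np.array([lagrangian(x_grid, lam).min() for lam in lam_vals])

max_err = np.max(np.abs(q_numeric - q_closed_form(lam_vals)))
print("max |q_numeric - q_closed_form| =", max_err)
```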
The dual optimization problem is then
$$ \begin{aligned} \text{maximize} \quad & q(\lambda) \\ \text{subject to} \quad & \lambda \geq 0. \end{aligned} $$
# Dual function q(lambda) derived above (stored in the variable g)
g = lambda lam: -(1 / 4) * (lam + 2)**2
plt.figure()
plt.plot(lamvals, g(lamvals))
plt.xlabel(r'$\lambda$')
plt.ylabel(r'Dual function $q(\lambda)$')
We now have two formulations of our optimization problem.
The primal problem is to minimize $f(x)$ subject to $x \geq 0$. $$ \begin{aligned} \text{minimize} \quad & f(x) \\ \text{subject to} \quad & x \geq 0. \end{aligned} $$
The dual problem is to maximize $q(\lambda)$ subject to $\lambda \geq 0$. $$ \begin{aligned} \text{maximize} \quad & q(\lambda) \\ \text{subject to} \quad & \lambda \geq 0. \end{aligned} $$
Once we solve for $\lambda^*$, we can find the optimal $x^*$ by using the KKT condition, and vice versa.
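import scipy.optimize
# Dual route: maximize q (the variable g) by minimizing -q over lambda in [0, 6],
# then recover x from the KKT condition x = (lambda + 2) / 2
lam_opt = scipy.optimize.minimize(lambda lam: -g(lam), 0, bounds=[(0, 6)]).x[0]
x_from_kkt = (2 + lam_opt) / 2
print('Optimal lambda:', lam_opt)
print('Optimal x:', x_from_kkt)
# Primal route: minimize f over x in [0, 6], then recover lambda = f'(x) = 2x - 2
x_opt = scipy.optimize.minimize(f, 0, bounds=[(0, 6)]).x[0]
lam_from_kkt = 2 * x_opt - 2
print('Optimal lambda:', lam_from_kkt)
print('Optimal x:', x_opt)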
Optimal lambda: 0.0
Optimal x: 1.0
Optimal lambda: -1.829984008772101e-08
Optimal x: 0.99999999085008
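The two routes agree. Solving the dual gives $\lambda^* = 0$ and, through the KKT condition, $x^* = 1$. Solving the primal gives the same pair up to optimizer tolerance. The optimal values also agree: $f(x^*) = -1 = q(\lambda^*)$. This equality of the primal and dual optimal values is known as strong duality.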
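Strong duality can also be checked at the level of objective values, $f(x^*) = f(1) = -1$ and $q(\lambda^*) = q(0) = -1$. Weak duality says something broader. For any $x \geq 0$ and $\lambda \geq 0$,
$$ q(\lambda) \leq L(x, \lambda) = f(x) - \lambda x \leq f(x). $$
The sketch below evaluates both optimal values and checks this inequality on a grid. The grid limits are arbitrary.

```python
import numpy as np

f = lambda x: x**2 - 2 * x             # primal objective
q = lambda lam: -0.25 * (lam + 2)**2   # dual function derived above

x_star, lam_star = 1.0, 0.0             # optimizers found numerically above
print("primal optimal value f(x*)   =", f(x_star))
print("dual optimal value   q(lam*) =", q(lam_star))

# Weak duality on a grid: q(lam) <= f(x) for all x >= 0 and lam >= 0
xs = np.linspace(0, 5, 501)
lams = np.linspace(0, 5, 501)
gap_matrix = f(xs)[None, :] - q(lams)[:, None]   # entry [i, j] = f(xs[j]) - q(lams[i])
print("smallest f(x) - q(lam) on the grid:", gap_matrix.min())
```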
When is duality broken?
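Strong duality can fail when the primal problem is not convex. For a convex problem, a point that satisfies the KKT conditions is optimal. Conversely, every optimum satisfies them under a constraint qualification, which holds here because the constraint $x \geq 0$ is linear. Without convexity, the KKT conditions remain necessary (under the same constraint qualification) but are no longer sufficient. A point that satisfies them may be only a local minimum, or not a minimum at all. Let's try a different function,
$$ f(x) = g x^3 + x^2 - 2x. $$ The parameter $g$ is a constant that we can vary. For $g = 0$ the function is convex. For any $g \neq 0$ it is not convex on all of $\mathbb{R}$, because $f''(x) = 6gx + 2$ changes sign. The Lagrangian is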
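A negative value of $g$ shows that a KKT point need not be the global minimum once convexity is lost. The choice $g = -0.1$ is arbitrary and purely illustrative. With $\lambda = 0$, the root $x_1 > 0$ of $f'(x) = 0$ satisfies all KKT conditions and has $f''(x_1) > 0$, so it is a local minimum. However, $g x^3 \to -\infty$ as $x \to \infty$, so $f$ is unbounded below on $x \geq 0$ and the constrained problem has no global minimum at all.

```python
import numpy as np

g = -0.1                                  # illustrative negative value (assumption)
f = lambda x: g * x**3 + x**2 - 2 * x
fprime = lambda x: 3 * g * x**2 + 2 * x - 2
fsecond = lambda x: 6 * g * x + 2

# KKT point with lambda = 0: a feasible root of f'(x) = 0
x1 = (-1 + np.sqrt(1 + 6 * g)) / (3 * g)
lam1 = 0.0

print("x1 =", x1)
print("stationarity residual f'(x1) - lambda:", fprime(x1) - lam1)
print("f''(x1) =", fsecond(x1))
print("f(x1) =", f(x1), " versus f(10) =", f(10.0))
```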
$$ L(x, \lambda) = g x^3 + x^2 - 2x - \lambda x. $$
The KKT condition $\lambda = f'(x)$ becomes $$ \lambda = 3 g x^2 + 2x - 2. $$
Solving this expression for $x$ gives
$$ x = \frac{-1\pm\sqrt{3 g \lambda +6 g+1}}{3 g} $$
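Both signs give genuine solutions of the stationarity condition, but they play different roles. Because $L$ differs from $f$ only by a linear term, $\partial^2 L/\partial x^2 = f''(x) = 6 g x + 2$. Let $s = \sqrt{3g\lambda + 6g + 1}$. Substituting $x = (-1 \pm s)/(3g)$ gives $\partial^2 L/\partial x^2 = \pm 2 s$. So the "+" root is a local minimum of the Lagrangian in $x$, and the "-" root is a local maximum. The sketch below checks this for $g = 0.5$ and $\lambda = 1$, which are illustrative values and not from the text. The helper name `kkt_roots` is our own.

```python
import numpy as np

def kkt_roots(lam, g):
    """Both solutions x of lambda = 3 g x^2 + 2x - 2 (requires g != 0)."""
    s = np.sqrt(3 * g * lam + 6 * g + 1)
    return (-1 + s) / (3 * g), (-1 - s) / (3 * g)

g, lam = 0.5, 1.0                     # illustrative values (assumption)
fprime = lambda x: 3 * g * x**2 + 2 * x - 2
fsecond = lambda x: 6 * g * x + 2     # equals d^2 L / dx^2, since L = f - lambda x

x_plus, x_minus = kkt_roots(lam, g)
for label, x in [("+ root", x_plus), ("- root", x_minus)]:
    print(f"{label}: x = {x:.4f}, f'(x) - lambda = {fprime(x) - lam:.1e}, "
          f"d2L/dx2 = {fsecond(x):.4f}")
```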
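We insert this expression into the Lagrangian to obtain $q(\lambda)$. Strictly, this gives the value of $L$ at a stationary point in $x$. It equals the dual function $\min_x L(x, \lambda)$ only when that stationary point is the global minimizer over $x$, which is guaranteed in the convex case $g = 0$. Taking the upper sign to match the upper sign in the expression for $x$,
$$ q(\lambda) = \frac{\left(\sqrt{3 g (\lambda +2)+1} \mp 1\right) \left(\sqrt{3 g (\lambda +2)+1} \mp 1 \mp 6 g (\lambda +2)\right)}{27 g^2}. $$
To resolve the $\pm$ ambiguity, we take the limit $g \to 0$. Only the upper root, $x = \frac{-1+\sqrt{3 g \lambda + 6 g + 1}}{3 g}$, stays finite in this limit. It tends to $x = \frac{\lambda + 2}{2}$, and the corresponding $q(\lambda)$ tends to the convex-case dual function $-\frac{1}{4}(\lambda + 2)^2$. This is the root used in the code below.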
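One practical caveat applies for very small $g$. The upper root and $q(\lambda)$ involve $\sqrt{1 + 3g(\lambda+2)} - 1$, a difference of nearly equal numbers that loses floating-point precision. Both expressions also divide by a power of $g$, so they cannot be evaluated at $g = 0$ itself.

Write $s = \sqrt{1 + 3g(\lambda+2)}$ and use $s - 1 = 3g(\lambda+2)/(s+1)$. This gives the algebraically equivalent forms
$$ x = \frac{\lambda + 2}{1 + s}, \qquad q(\lambda) = -\frac{(\lambda+2)^2 (2s + 1)}{3 (1 + s)^2}. $$
These are well defined at $g = 0$. There $s = 1$, and they reduce to $x = (\lambda+2)/2$ and $q = -(\lambda+2)^2/4$.

The sketch below compares this rewritten form with the closed form used in the notebook. The function names are our own, and the rewrite is not part of the original notebook.

```python
import numpy as np

def q_naive(lam, g):
    """Closed form as written in the notebook (undefined at g = 0)."""
    s = np.sqrt(1 + 3 * g * (2 + lam))
    return (s - 1) * (s - 1 - 6 * g * (2 + lam)) / (27 * g**2)

def q_stable(lam, g):
    """Same function after rewriting s - 1 = 3 g (lam + 2) / (s + 1)."""
    a = lam + 2
    s = np.sqrt(1 + 3 * g * a)
    return -a**2 * (2 * s + 1) / (3 * (s + 1)**2)

def x_stable(lam, g):
    """Upper root of the KKT condition, x = (lam + 2) / (1 + s)."""
    a = lam + 2
    return a / (1 + np.sqrt(1 + 3 * g * a))

lam = 1.0
for g in [1.0, 1e-4, 1e-12]:
    print(f"g={g:g}: naive q={q_naive(lam, g):.12f}, stable q={q_stable(lam, g):.12f}")
print("g=0: stable q =", q_stable(lam, 0.0), ", x =", x_stable(lam, 0.0))
```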
Let's plot the primal and dual functions to see what happens.
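import numpy as np
import matplotlib.pyplot as plt
def f(x, g=0):
    # Cubic objective; convex only when g = 0
    return g * x**3 + x**2 - 2*x
def L(x, lam, g=0):
    # Lagrangian for the constraint x >= 0
    return f(x, g) - lam * x
def q(lam, g=0.0001):
    # L evaluated at the upper root of the KKT condition; divides by g**2, so g must be nonzero
    return ((-1 + np.sqrt(1 + 3*g*(2 + lam)))*(-1 - 6*g*(2 + lam) + np.sqrt(1 + 3*g*(2 + lam))))/(27.*g**2)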
x = np.linspace(-1, 3, 100)
lam = np.linspace(-3, 3, 100)
plt.figure()
plt.plot(x, f(x), label='f(x)')
plt.xlabel('x')
plt.ylabel('f(x)')
plt.legend()
plt.figure()
plt.plot(lam, q(lam), label=r'$q(\lambda)$')
plt.xlabel(r'$\lambda$')
plt.ylabel(r'$q(\lambda)$')
plt.legend()
Our two optimization problems remain the same:
The primal problem is to minimize $f(x)$ subject to $x \geq 0$. $$ \begin{aligned} \text{minimize} \quad & f(x) \\ \text{subject to} \quad & x \geq 0. \end{aligned} $$
The dual problem is to maximize $q(\lambda)$ subject to $\lambda \geq 0$. $$ \begin{aligned} \text{maximize} \quad & q(\lambda) \\ \text{subject to} \quad & \lambda \geq 0. \end{aligned} $$
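We'll start with $g \approx 0$ to confirm that we get the same results as before. Because the closed-form $q(\lambda)$ divides by $g^2$, we use the small value $g = 10^{-8}$ as a stand-in for $g = 0$.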
import scipy.optimize
gval = 1e-8  # small stand-in for g = 0, since q divides by g**2
# Dual route: maximize q over lambda in [0, 5], then recover x from the upper KKT root
lam_opt = scipy.optimize.minimize(lambda lam: -q(lam, g=gval), 0, bounds=[(0, 5)]).x[0]
x_from_kkt = (-1 + np.sqrt(1 + 6 * gval + 3 * gval * lam_opt))/(3. * gval)
print('Optimal lambda:', lam_opt)
print('Optimal x:', x_from_kkt)
# Primal route: minimize f over x in [0, 5], then recover lambda = f'(x)
x_opt = scipy.optimize.minimize(lambda x: f(x, g=gval), 0, bounds=[(0, 5)]).x[0]
lam_from_kkt = 3 * gval * x_opt**2 + 2 * x_opt - 2
print('Optimal lambda:', lam_from_kkt)
print('Optimal x:', x_opt)
Optimal lambda: 0.0
Optimal x: 0.999999986521042
Optimal lambda: 5.9544609243289415e-09
Optimal x: 0.9999999879772308
The duality gap
The difference between the optimal values of the primal and dual problems is called the duality gap. Weak duality, the statement that the dual optimal value never exceeds the primal optimal value, always holds, so the gap is nonnegative. When the gap is zero, the problem is said to have strong duality. When it is positive, strong duality fails and only weak duality holds.
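For this cubic, the gap can be probed numerically. Strictly, the Lagrange dual function is the infimum of $L(x, \lambda)$ over all real $x$. For $g > 0$, the term $g x^3$ drives $L(x, \lambda) \to -\infty$ as $x \to -\infty$. The true dual function is therefore $-\infty$ for every $\lambda$, and the gap is infinite, even though the root-substitution formula for $q(\lambda)$ stays finite.

The sketch below uses a hypothetical helper, `grid_duality_gap`, introduced only for illustration. It approximates that infimum by a minimum over $x \in [-x_{\max}, x_{\max}]$, and it restricts $\lambda$ to $[0, 5]$. For $g = 0$ the estimated gap is zero up to grid resolution. For $g = 1$ it keeps growing as the range widens.

```python
import numpy as np

def grid_duality_gap(g, x_max, lam_max=5.0, n_x=4001, n_lam=201):
    """Hypothetical helper: grid estimate of p* - d* for
        minimize g x^3 + x^2 - 2x  subject to  x >= 0.
    The dual function is approximated by the minimum of L(x, lam) over
    x in [-x_max, x_max], a truncated stand-in for the infimum over all x."""
    f = lambda x: g * x**3 + x**2 - 2 * x
    x = np.linspace(-x_max, x_max, n_x)
    lam = np.linspace(0.0, lam_max, n_lam)
    p_star = f(x[x >= 0]).min()                    # primal: feasible x only
    L = f(x)[None, :] - lam[:, None] * x[None, :]  # L[i, j] = L(x[j], lam[i])
    d_star = L.min(axis=1).max()                   # max over lam of min over x
    return p_star - d_star

print("g = 0, x in [-5, 5]:", grid_duality_gap(0.0, x_max=5.0))
for x_max in [5.0, 10.0, 20.0]:
    print(f"g = 1, x in [-{x_max:g}, {x_max:g}]:", grid_duality_gap(1.0, x_max))
```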
gvals = np.logspace(-14, 10, 50)
# Dual route: solve for lambda*, then recover x from the KKT condition
lam_via_dual, x_via_dual = [], []
# Primal route: solve for x*, then recover lambda from the KKT condition
lam_via_primal, x_via_primal = [], []
for gval in gvals:
    lam_opt = scipy.optimize.minimize(lambda lam: -q(lam, g=gval), 0, bounds=[(0, 5)]).x[0]
    x_from_kkt = (-1 + np.sqrt(1 + 6 * gval + 3 * gval * lam_opt))/(3. * gval)
    lam_via_dual.append(lam_opt)
    x_via_dual.append(x_from_kkt)
    x_opt = scipy.optimize.minimize(lambda x: f(x, g=gval), 0, bounds=[(0, 5)]).x[0]
    lam_from_kkt = 3 * gval * x_opt**2 + 2 * x_opt - 2
    lam_via_primal.append(lam_from_kkt)
    x_via_primal.append(x_opt)
plt.figure()
plt.semilogx(gvals, x_via_primal, label='x primal')
plt.semilogx(gvals, x_via_dual, '--', label='x dual (via KKT)')
plt.xlabel('g')
plt.ylabel('optimal x')
plt.legend()