Hyper Parameter Optimization with Optuna

In this short notebook we’ll give a brief introduction to Optuna, a Python package for efficient hyper parameter optimization.

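Optuna treats the problem to be optimized as a black box. It proposes parameter values, observes the score they produce, and uses those observations to decide what to try next. In that sense, the search itself resembles a machine learning problem.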

It is agnostic to the framework we use to implement our learning algorithms. As we will see, it does not even require a learning algorithm at all.

Running this file on Google Colab

In this workshop we will be using Google Colab. To open this notebook in Colab, use this link.

If you want your changes to be saved, you need to save a copy of this notebook to your own Google Drive.

Installing Optuna

If you are running this notebook on Google Colab, you will need to install Optuna first. Uncomment the line below and run the cell. There are various ways of installing Optuna; since this is a Colab notebook, we’ll use pip.

#!pip install optuna > /dev/null
import optuna

What’s your favorite color?

As a simple example, we’ll start by showing that Optuna can be used to optimize essentially anything. In this case, we’ll try to find the color that best suits our taste.

The study object

Optimization in Optuna is centered around a Study object. This object is the foundation of the whole hyper parameter experiment and the main interface you interact with.

Most of a study’s configuration is passed to the function optuna.create_study. At the very least, we have to tell Optuna whether to minimize or maximize the objective. The default is to minimize.

study = optuna.create_study(direction='maximize')
[I 2022-12-05 09:19:56,346] A new study created in memory with name: no-name-dc82f9dc-e6e5-41d5-9e0e-b11de38e4559

Finding the optimal color

This toy example uses IPython widgets to give us a simple interface we can interact with.

The goal is to search over three hyper parameters (red, green and blue) and find the color that maximizes the score that you, the human, assign to it. Your task is to rate each color on a continuous scale from 0 to 1.

The trial object

A trial in Optuna is a single evaluation of the objective, i.e. the retrieval of a single observation or data point. We can think of a Trial object as the context of one experiment. We use it to suggest hyper parameter values, and as a reference when we report the result of the experiment.

Optuna has two main interfaces for instantiating trials:

  • study.ask()/study.tell(): Useful for manually creating and managing trials, especially in interactive settings.

  • study.optimize(): Useful for automating the search procedure, especially for parallel batch tasks.

In this tutorial we will use the ask/tell interface to get to the basics.

We create a Trial object by calling study.ask()

# Create a new trial object
trial = study.ask()

Suggest hyper parameter values

Optuna is a very flexible framework that doesn’t require us to declare the search space beforehand. Instead, new dimensions are added to the parameter space dynamically whenever we call one of the suggest_*() methods of a Trial object. Each hyper parameter is identified by the name argument passed to the suggest method. When the same name is reused across trials, the sampler can build a model of how the value of each named parameter relates to the optimization goal.

Supported variable types

These suggest methods are the main interface you will use when performing hyper parameter optimization. Here is a list of the supported (non-deprecated) variable types:

  • suggest_categorical(name, choices): Suggest a value from a list of discrete objects. Very flexible interface.

  • suggest_float(name, low, high, *[, step, log]): Suggest a floating point value. The range is inclusive, so both endpoints can be sampled.

  • suggest_int(name, low, high[, step, log]): Suggest an integer value. The range is inclusive.

In this example, we will model the red, green and blue components as floating point values in the range $[0,1]$, so we will use suggest_float(name, 0, 1).

(detour) Sampling from the log space

Both suggest_float and suggest_int have a log flag. If it is set to True, values are sampled in a log-transformed space instead, which requires a positive lower bound. Roughly, this spreads the probability mass evenly over orders of magnitude. In a range of $[10^{-4}, 10^{-1}]$, for example, there is as much probability mass in $[10^{-4}, 10^{-3}]$ as in $[10^{-2}, 10^{-1}]$. This is very useful for hyper parameters such as the learning rate, where what we really want to search over is the order of magnitude.
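The effect of the log flag can be reproduced with the standard library alone. This sketch draws values uniformly in log10 space and compares them with plain uniform sampling over $[10^{-4}, 10^{-1}]$. It illustrates the idea behind log=True, but Optuna's samplers may implement it differently, and model-based samplers also reshape the distribution as trials accumulate.

```python
import math
import random

def sample_log_uniform(low, high, rng):
    """Draw a value whose log10 is uniform between log10(low) and log10(high)."""
    return 10 ** rng.uniform(math.log10(low), math.log10(high))

def decade_fractions(samples, decades):
    """Fraction of samples falling into each half-open interval [lo, hi)."""
    return [sum(lo <= x < hi for x in samples) / len(samples) for lo, hi in decades]

rng = random.Random(0)
decades = [(1e-4, 1e-3), (1e-3, 1e-2), (1e-2, 1e-1)]
n = 30_000

log_samples = [sample_log_uniform(1e-4, 1e-1, rng) for _ in range(n)]
linear_samples = [rng.uniform(1e-4, 1e-1) for _ in range(n)]

# Log sampling: roughly one third of the values per decade.
print([round(f, 2) for f in decade_fractions(log_samples, decades)])
# Uniform sampling: about 90% of the values land in the top decade.
print([round(f, 2) for f in decade_fractions(linear_samples, decades)])
```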

# We sample each component from the range [0,1].
r = trial.suggest_float('red', 0, 1)
g = trial.suggest_float('green', 0, 1)
b = trial.suggest_float('blue', 0, 1)

Interactive optimization

We’re now ready to run the optimization. The code below sets up a small ipywidgets interface. The important part is the rate_color function, a callback that is called whenever the Rate button is clicked.

# This cell mainly contains convenience functions for setting up the widgets

import ipywidgets as widgets

def rgb_to_hex(r, g, b):
  return "#{0:02x}{1:02x}{2:02x}".format(int(r*255), int(g*255), int(b*255))

def hex_to_rgb(hex):
  r = hex[1:3]
  g = hex[3:5]
  b = hex[5:7]
  return (int(r, base=16)/255, int(g, base=16)/255, int(b, base=16)/255)

# Header label
label = widgets.Label(value="Assign a rating, then click Rate to update")

# A disabled button without a label, used only to display the current color
color_viewer = widgets.Button(
    description='',
    disabled=True,
    button_style='', # 'success', 'info', 'warning', 'danger' or ''
    tooltip=rgb_to_hex(r,g,b),
)
color_viewer.style.button_color = rgb_to_hex(r,g,b)

# The float slider used to give a rating
rating_widget = widgets.FloatSlider(
    value=0.5,
    min=0,
    max=1,
    step=0.001,
    description='Rating:',
    disabled=False,
    continuous_update=False,
    orientation='horizontal',
    readout=True,
    readout_format='.3f',
)

# The update button widget
update_button = widgets.Button(
    description='Rate',
    disabled=False,
    button_style='', # 'success', 'info', 'warning', 'danger' or ''
    tooltip='How do you like my color?',
)


# The callback for the optimization. Called every time the user presses
# the "Rate" button
def rate_color(button):
  # r, g, b and trial are reassigned below, so they must be declared global.
  global r, g, b, trial
  # We take the value the user has supplied to the rating widget
  rating = rating_widget.value
  
  # This is where we tell Optuna what the value of
  # the goal function (the rating) is for this
  # trial. The trial object knows what values 
  # it has suggested.
  study.tell(trial, rating)

  # We create a new trial object from the study which will be used next time
  # rate_color() is called
  trial = study.ask()

  # We ask the new trial object to suggest new values.
  r = trial.suggest_float('red', 0, 1)
  g = trial.suggest_float('green', 0, 1)
  b = trial.suggest_float('blue', 0, 1)

  # Show the newly picked color by setting the color picker to the value
  color_viewer.style.button_color = rgb_to_hex(r,g,b)
  color_viewer.tooltip = rgb_to_hex(r,g,b)


# Bind the rate_color() callback to the Rate button
update_button.on_click(rate_color)


widgets.VBox([label, color_viewer, rating_widget, update_button])

Analyzing the study

Once we’re done rating colors, we can analyze the results of the study. In particular, we might wish to find the best performing color. The study maintains two attributes of interest: best_params and best_value.

best_params is a dictionary which maps each parameter name used in the trials to the value suggested by the best trial.

best_color = study.best_params
best_value = study.best_value
r,g,b = best_color['red'], best_color['green'], best_color['blue']
# A disabled button that displays the best color
button = widgets.Button(
    description='',
    disabled=True,
    tooltip=rgb_to_hex(r,g,b)
)
button.style.button_color = rgb_to_hex(r,g,b)
widgets.VBox([button, widgets.Label(value=f"Value: {best_value}")])

We can also get all of the trial data as a pandas DataFrame with the study.trials_dataframe() method (this requires pandas to be installed). Why do you think the last row of the DataFrame looks odd?

study.trials_dataframe()
   number  value              datetime_start           datetime_complete                duration  params_blue  params_green  params_red     state
0       0  0.500  2022-12-05 09:20:00.100308  2022-12-05 09:20:08.715630  0 days 00:00:08.615322     0.102017      0.924562    0.946926  COMPLETE
1       1  0.350  2022-12-05 09:20:08.715996  2022-12-05 09:20:10.873133  0 days 00:00:02.157137     0.640186      0.802962    0.585927  COMPLETE
2       2  0.968  2022-12-05 09:20:10.873389  2022-12-05 09:20:15.427273  0 days 00:00:04.553884     0.856165      0.189297    0.033233  COMPLETE
3       3    NaN  2022-12-05 09:20:15.427512                         NaT                    NaT     0.402359      0.036989    0.697580   RUNNING

Manually inspecting trials

We can also manually go through the Trial objects to summarize results using the study.get_trials() method.

summarized_params = []
# We use frozen_trial rather than trial as the loop variable: rate_color() relies on
# the global trial, and overwriting it would break the next click on Rate.
for frozen_trial in study.get_trials():
  # Note that these are "FrozenTrial" objects, which are not the same as the
  # trials we got from study.ask(). A FrozenTrial doesn't do any sampling and
  # contains the result of the optimization.
  # We use the trial's state attribute to skip trials that haven't finished.
  if frozen_trial.state.is_finished():
    p = frozen_trial.params
    trial_result = {'Optimization value': frozen_trial.value, 'red': p['red'], 'green': p['green'], 'blue': p['blue']}
    summarized_params.append(trial_result)
summarized_params
[{'Optimization value': 0.5,
  'red': 0.9469259095595486,
  'green': 0.9245623666310328,
  'blue': 0.10201741089324834},
 {'Optimization value': 0.35,
  'red': 0.5859267350402545,
  'green': 0.8029621088484467,
  'blue': 0.6401860069072317},
 {'Optimization value': 0.968,
  'red': 0.03323309482501857,
  'green': 0.18929742355225276,
  'blue': 0.8561651114296673}]

Visual Inspection

Optuna has many built-in functions for visually inspecting the results of an optimization. Here we’ll look at the most important ones. You can find the documentation of all visualization functions in the optuna.visualization module.

Slice plot

We’re often interested in how a single hyper parameter relates to the objective value. The slice plot shows each hyper parameter in isolation, plotting its sampled values against the objective value. This hides potential interactions between parameters, but is a good way to narrow down the search space of each parameter.

from optuna.visualization import plot_slice
plot_slice(study)
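Stripped of the plotting, a slice plot is one scatter of (parameter value, objective value) pairs per parameter. The sketch below builds those pairs from the summarized_params output above, rounded to four decimals. It is only a text analogue and not how plot_slice is implemented.

```python
# Completed trials from the summarized_params output above, rounded to 4 decimals.
rounded_params = [
    {"Optimization value": 0.5, "red": 0.9469, "green": 0.9246, "blue": 0.1020},
    {"Optimization value": 0.35, "red": 0.5859, "green": 0.8030, "blue": 0.6402},
    {"Optimization value": 0.968, "red": 0.0332, "green": 0.1893, "blue": 0.8562},
]

def slice_points(records, param, objective="Optimization value"):
    """(parameter value, objective value) pairs for one parameter, sorted by parameter value."""
    return sorted((record[param], record[objective]) for record in records)

for name in ("red", "green", "blue"):
    # Each line is what one panel of a slice plot shows, ignoring the other parameters.
    print(name, slice_points(rounded_params, name))
```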

Contour plot

When the number of hyper parameters is relatively low, we can look at how the objective value varies over pairs of hyper parameters. The contour plot interpolates the objective values of the completed trials over each pair of hyper parameters. Note that the colors of the contour plot have nothing to do with the colors we picked in this example.

A contour plot can help us spot unexpected interactions between pairs of parameters, which in turn helps us interpret the slice plot above.

# You can use Matplotlib instead of Plotly for visualization by simply replacing `optuna.visualization` with
# `optuna.visualization.matplotlib` in the following examples.
from optuna.visualization import plot_contour
plot_contour(study)
# If we only want to see the relationship between some specific parameters, we can pass their names
plot_contour(study, params=["red", "green"])

Parallel coordinate plot

While the slice plot and contour plot allow us to gauge how up to two variables interact, more variables are very difficult to visualize in a regular plot. A common tool for comparing interactions of more than two variables is the parallel coordinate plot.

Instead of using the parameters as the axes of a Cartesian space, it places one axis per parameter, plus one for the objective value, parallel to each other. Each trial is drawn as a line connecting its values on each axis, and the lines are colored by the objective value.

This plot is useful for visually gauging overall trends in how several parameters together affect the objective value.

from optuna.visualization import plot_parallel_coordinate
plot_parallel_coordinate(study)

Other parallel coordinate plotting tools

While the built-in plot in Optuna is useful, the HiPlot package from Facebook Research offers a more powerful, interactive alternative. Here’s a short example.

!pip install -U hiplot > /dev/null
import hiplot

# HiPlot supports many interfaces for reading data, but here we use an iterable 
# of dictionary records. Each record is a mapping from coordinate to value.

summarized_params = []
for frozen_trial in study.get_trials():
  # As before, these are FrozenTrial objects. Using a name other than `trial`
  # keeps the global trial used by rate_color() intact.
  if frozen_trial.state.is_finished():
    p = frozen_trial.params
    trial_result = {'Optimization value': frozen_trial.value, 'red': p['red'], 'green': p['green'], 'blue': p['blue']}
    summarized_params.append(trial_result)

hiplot.Experiment.from_iterable(summarized_params).display()

Exercise

In the code below, extend the example above to optimize how well two colors match. The places to edit are marked with YOUR CODE HERE.

import optuna
import ipywidgets as widgets

def rgb_to_hex(r, g, b):
  return "#{0:02x}{1:02x}{2:02x}".format(int(r*255), int(g*255), int(b*255))

def hex_to_rgb(hex):
  r = hex[1:3]
  g = hex[3:5]
  b = hex[5:7]
  return (int(r, base=16)/255, int(g, base=16)/255, int(b, base=16)/255)


study = optuna.create_study(direction='maximize')
trial = study.ask()

r1 = YOUR CODE HERE
g1 = YOUR CODE HERE
b1 = YOUR CODE HERE

r2 = YOUR CODE HERE
g2 = YOUR CODE HERE
b2 = YOUR CODE HERE

# Header label
label = widgets.Label(value="Assign a rating to this pair of colors, then click Rate to update")

color_viewer1 = widgets.Button(
    description='',
    disabled=True,
    tooltip=rgb_to_hex(r1,g1,b1),
)
color_viewer1.style.button_color = rgb_to_hex(r1,g1,b1)

color_viewer2 = widgets.Button(
    description='',
    disabled=True,
    tooltip=rgb_to_hex(r2,g2,b2),
)
color_viewer2.style.button_color = rgb_to_hex(r2,g2,b2)


# The float slider used to give a rating
rating_widget = widgets.FloatSlider(
    value=0.5,
    min=0,
    max=1,
    step=0.001,
    description='Rating:',
    disabled=False,
    continuous_update=False,
    orientation='horizontal',
    readout=True,
    readout_format='.3f',
)

# The update button widget
update_button = widgets.Button(
    description='Rate',
    disabled=False,
    button_style='', # 'success', 'info', 'warning', 'danger' or ''
    tooltip='How do you like the colors?',
)


# The callback for the optimization. Called every time the user presses
# the "Rate" button
def rate_color(button):
  # The colors and trial are reassigned below, so they must be declared global.
  global r1, g1, b1, r2, g2, b2, trial
  # We take the value the user has supplied to the rating widget
  rating = rating_widget.value
  
  # This is where we tell Optuna what the value of
  # the goal function (the rating) is for this
  # trial. The trial object knows what values 
  # it has suggested.
  study.tell(trial, rating)

  # We create a new trial object from the study which will be used next time
  # rate_color() is called
  trial = study.ask()

  # We ask the new trial object to suggest new values.

  r1 = YOUR CODE HERE
  g1 = YOUR CODE HERE
  b1 = YOUR CODE HERE

  r2 = YOUR CODE HERE
  g2 = YOUR CODE HERE
  b2 = YOUR CODE HERE


  # Show the newly picked color by setting the color picker to the value
  color_viewer1.style.button_color = rgb_to_hex(r1,g1,b1)
  color_viewer1.tooltip = rgb_to_hex(r1,g1,b1)
  color_viewer2.style.button_color = rgb_to_hex(r2,g2,b2)
  color_viewer2.tooltip = rgb_to_hex(r2,g2,b2)


# Bind the rate_color() callback to the Rate button
update_button.on_click(rate_color)


widgets.VBox([label, widgets.HBox([color_viewer1, color_viewer2]) , rating_widget, update_button])

Solution

import optuna
import ipywidgets as widgets

def rgb_to_hex(r, g, b):
  return "#{0:02x}{1:02x}{2:02x}".format(int(r*255), int(g*255), int(b*255))

def hex_to_rgb(hex):
  r = hex[1:3]
  g = hex[3:5]
  b = hex[5:7]
  return (int(r, base=16)/255, int(g, base=16)/255, int(b, base=16)/255)



study = optuna.create_study(direction='maximize')
trial = study.ask()

r1 = trial.suggest_float('red_1', 0,1)
g1 = trial.suggest_float('green_1', 0,1)
b1 = trial.suggest_float('blue_1', 0,1)

r2 = trial.suggest_float('red_2', 0,1)
g2 = trial.suggest_float('green_2', 0,1)
b2 = trial.suggest_float('blue_2', 0,1)

# Header label
label = widgets.Label(value="Assign a rating to this pair of colors, then click Rate to update")

color_viewer1 = widgets.Button(
    description='',
    disabled=True,
    tooltip=rgb_to_hex(r1,g1,b1),
)
color_viewer1.style.button_color = rgb_to_hex(r1,g1,b1)

color_viewer2 = widgets.Button(
    description='',
    disabled=True,
    tooltip=rgb_to_hex(r2,g2,b2),
)
color_viewer2.style.button_color = rgb_to_hex(r2,g2,b2)


# The float slider used to give a rating
rating_widget = widgets.FloatSlider(
    value=0.5,
    min=0,
    max=1,
    step=0.001,
    description='Rating:',
    disabled=False,
    continuous_update=False,
    orientation='horizontal',
    readout=True,
    readout_format='.3f',
)

# The update button widget
update_button = widgets.Button(
    description='Rate',
    disabled=False,
    button_style='', # 'success', 'info', 'warning', 'danger' or ''
    tooltip='How do you like the colors?',
)


# The callback for the optimization. Called every time the user presses
# the "Rate" button
def rate_color(button):
  # The colors and trial are reassigned below, so they must be declared global.
  global r1, g1, b1, r2, g2, b2, trial
  # We take the value the user has supplied to the rating widget
  rating = rating_widget.value
  
  # This is where we tell Optuna what the value of
  # the goal function (the rating) is for this
  # trial. The trial object knows what values 
  # it has suggested.
  study.tell(trial, rating)

  # We create a new trial object from the study which will be used next time
  # rate_color() is called
  trial = study.ask()

  # We ask the new trial object to suggest new values.

  r1 = trial.suggest_float('red_1', 0,1)
  g1 = trial.suggest_float('green_1', 0,1)
  b1 = trial.suggest_float('blue_1', 0,1)

  r2 = trial.suggest_float('red_2', 0,1)
  g2 = trial.suggest_float('green_2', 0,1)
  b2 = trial.suggest_float('blue_2', 0,1)


  # Show the newly picked color by setting the color picker to the value
  color_viewer1.style.button_color = rgb_to_hex(r1,g1,b1)
  color_viewer1.tooltip = rgb_to_hex(r1,g1,b1)
  color_viewer2.style.button_color = rgb_to_hex(r2,g2,b2)
  color_viewer2.tooltip = rgb_to_hex(r2,g2,b2)


# Bind the rate_color() callback to the Rate button
update_button.on_click(rate_color)


widgets.VBox([label, widgets.HBox([color_viewer1, color_viewer2]) , rating_widget, update_button])
[I 2022-12-05 10:37:11,639] A new study created in memory with name: no-name-ffb52d9a-6799-4b78-abd3-c27d221a9dba

Multi-objective optimization

In the exercise above, we only optimize a single value, into which we have tried to bake several different aspects: how pleasing we find each of the colors, but also how pleasing we find their combination.

Optimizing a single value is straightforward: the trial with the best value is the optimum. With multiple objectives, it isn’t as simple. Imagine you would like to find hyper parameters that maximize the classification accuracy of a model while also minimizing its number of parameters or the amount of compute it requires. These two objectives are very likely in conflict, since decreasing the allowed compute typically leads to lower accuracy. In the end, multi-objective optimization usually requires a human to decide which trade-off we are willing to accept, but Optuna can show us what the trade-off curves look like.

What pairs of colors do you like best?

In this toy example, we extend the above exercise to multi-objective optimization. We want to optimize the two colors individually, but also how well they match. We tell Optuna that we wish to perform multi-objective optimization by passing the keyword argument directions to optuna.create_study(). Its value is a sequence with one direction per objective.

In this case we will maximize three objectives: the rating of the color match, the rating of color 1 and the rating of color 2. The call to create_study will therefore be:

study = optuna.create_study(directions=['maximize', 'maximize', 'maximize'])
import optuna
import ipywidgets as widgets

def rgb_to_hex(r, g, b):
  return "#{0:02x}{1:02x}{2:02x}".format(int(r*255), int(g*255), int(b*255))

def hex_to_rgb(hex):
  r = hex[1:3]
  g = hex[3:5]
  b = hex[5:7]
  return (int(r, base=16)/255, int(g, base=16)/255, int(b, base=16)/255)

# We will have three objectives: how well the colors match, and the pleasantness of colors 1 and 2
study = optuna.create_study(directions=['maximize', 'maximize', 'maximize'])
trial = study.ask()

r1 = trial.suggest_float('red_1', 0,1)
g1 = trial.suggest_float('green_1', 0,1)
b1 = trial.suggest_float('blue_1', 0,1)

r2 = trial.suggest_float('red_2', 0,1)
g2 = trial.suggest_float('green_2', 0,1)
b2 = trial.suggest_float('blue_2', 0,1)

color_viewer1 = widgets.Button(
    description='',
    disabled=True,
    tooltip=rgb_to_hex(r1,g1,b1),
)
color_viewer1.style.button_color = rgb_to_hex(r1,g1,b1)

color_viewer2 = widgets.Button(
    description='',
    disabled=True,
    tooltip=rgb_to_hex(r2,g2,b2),
)
color_viewer2.style.button_color = rgb_to_hex(r2,g2,b2)


# The float slider used to give a rating
rating_color1 = widgets.FloatSlider(
    value=0.5,
    min=0,
    max=1,
    step=0.001,
    description='Color 1:',
    disabled=False,
    continuous_update=False,
    orientation='horizontal',
    readout=True,
    readout_format='.3f',
)
# The float slider used to give a rating
rating_color2 = widgets.FloatSlider(
    value=0.5,
    min=0,
    max=1,
    step=0.001,
    description='Color 2:',
    disabled=False,
    continuous_update=False,
    orientation='horizontal',
    readout=True,
    readout_format='.3f',
)
# The float slider used to give a rating
rating_color_match = widgets.FloatSlider(
    value=0.5,
    min=0,
    max=1,
    step=0.001,
    description='Match:',
    disabled=False,
    continuous_update=False,
    orientation='horizontal',
    readout=True,
    readout_format='.3f',
)

# The update button widget
update_button = widgets.Button(
    description='Rate',
    disabled=False,
    button_style='', # 'success', 'info', 'warning', 'danger' or ''
    tooltip='How do you like the colors?',
)


# The callback for the optimization. Called every time the user presses
# the "Rate" button
def rate_colors(button):
  # The colors and trial are reassigned below, so they must be declared global.
  global r1, g1, b1, r2, g2, b2, trial
  # We take the value the user has supplied to the rating widget
  v1 = rating_color1.value
  v2 = rating_color2.value
  m = rating_color_match.value
  
  # We now give the tell method a sequence of values. It's up to us to make sure
  # they're in the same order as the directions we gave create_study
  study.tell(trial, values=(m, v1, v2))

  # We create a new trial object from the study which will be used next time
  # rate_colors() is called
  trial = study.ask()

  # We ask the new trial object to suggest new values.

  r1 = trial.suggest_float('red_1', 0,1)
  g1 = trial.suggest_float('green_1', 0,1)
  b1 = trial.suggest_float('blue_1', 0,1)

  r2 = trial.suggest_float('red_2', 0,1)
  g2 = trial.suggest_float('green_2', 0,1)
  b2 = trial.suggest_float('blue_2', 0,1)


  # Show the newly picked color by setting the color picker to the value
  color_viewer1.style.button_color = rgb_to_hex(r1,g1,b1)
  color_viewer1.tooltip = rgb_to_hex(r1,g1,b1)
  color_viewer2.style.button_color = rgb_to_hex(r2,g2,b2)
  color_viewer2.tooltip = rgb_to_hex(r2,g2,b2)


# Bind the rate_colors() callback to the Rate button
update_button.on_click(rate_colors)


label = widgets.Label(value="Rate each color and how well they match, then click Rate to update")
widgets.VBox([label, widgets.HBox([rating_color1, color_viewer1]), widgets.HBox([rating_color2, color_viewer2]), rating_color_match, update_button])
[I 2022-12-05 11:20:10,902] A new study created in memory with name: no-name-49d69c8b-532b-46d9-810f-6bf9f8bd1c39
study.trials_dataframe()
   number  values_0  values_1  values_2              datetime_start           datetime_complete                duration  params_blue_1  params_blue_2  params_green_1  params_green_2  params_red_1  params_red_2  system_attrs_nsga2:generation     state
0       0     0.500     0.500     0.500  2022-12-05 11:20:10.906272  2022-12-05 11:20:12.420332  0 days 00:00:01.514060       0.297427       0.978796        0.785602        0.432933      0.615970      0.135565                              0  COMPLETE
1       1     0.277     0.500     0.500  2022-12-05 11:20:12.420666  2022-12-05 11:20:14.263354  0 days 00:00:01.842688       0.189571       0.755184        0.055841        0.859008      0.186524      0.923310                              0  COMPLETE
2       2     0.277     0.500     0.721  2022-12-05 11:20:14.263581  2022-12-05 11:20:16.045308  0 days 00:00:01.781727       0.120192       0.816243        0.621310        0.285793      0.494802      0.658099                              0  COMPLETE
3       3     0.655     0.253     0.721  2022-12-05 11:20:16.045524  2022-12-05 11:20:19.300460  0 days 00:00:03.254936       0.961877       0.611846        0.755017        0.982573      0.250001      0.745471                              0  COMPLETE
4       4     0.269     0.253     0.721  2022-12-05 11:20:19.300702  2022-12-05 11:20:22.382664  0 days 00:00:03.081962       0.086702       0.077736        0.093625        0.225985      0.806207      0.485622                              0  COMPLETE
5       5       NaN       NaN       NaN  2022-12-05 11:20:22.383007                         NaT                    NaT       0.181049       0.949569        0.109965        0.070579      0.840617      0.502335                              0   RUNNING

Evaluating multi-objective optimization

When we optimize multiple objectives, there’s typically a trade-off between the objectives. There’s often not a single point which optimizes all of them at the same time.

To help with evaluating multi-objective optimization, Optuna has tools to visualize the Pareto front of our optimization. The Pareto front is the set of trials that are not dominated by any other trial. One trial dominates another if it is at least as good in every objective and strictly better in at least one. Every trial on the front represents a different trade-off, which makes the front a useful starting point for comparing multiple objectives.
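The dominance rule can be checked by hand. This standard-library sketch applies it to the objective values of the completed trials in the DataFrame above, where all objectives are maximized. It is a naive pairwise comparison for illustration, not Optuna's implementation. It finds trials 0, 2 and 3, the same trials that study.best_trials reports.

```python
def dominates(a, b):
    """True if a is at least as good as b in every objective (all maximized)
    and strictly better in at least one."""
    return all(x >= y for x, y in zip(a, b)) and any(x > y for x, y in zip(a, b))

def pareto_front(values_by_trial):
    """Trial numbers whose values are not dominated by any other trial."""
    # A trial never dominates itself, so comparing against all trials is safe.
    return [
        number
        for number, values in values_by_trial.items()
        if not any(dominates(other, values) for other in values_by_trial.values())
    ]

# (match, color 1, color 2) for the completed trials in the DataFrame above.
observed = {
    0: (0.500, 0.500, 0.500),
    1: (0.277, 0.500, 0.500),
    2: (0.277, 0.500, 0.721),
    3: (0.655, 0.253, 0.721),
    4: (0.269, 0.253, 0.721),
}
print(pareto_front(observed))  # [0, 2, 3]
```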

from optuna.visualization import plot_pareto_front
plot_pareto_front(study, target_names=["Match", "Color 1", "Color 2"])

We can also get the trials on the Pareto front programmatically from the study.best_trials attribute.

study.best_trials
[FrozenTrial(number=0, values=[0.5, 0.5, 0.5], datetime_start=datetime.datetime(2022, 12, 5, 11, 20, 10, 906272), datetime_complete=datetime.datetime(2022, 12, 5, 11, 20, 12, 420332), params={'red_1': 0.6159704427300418, 'green_1': 0.7856022323342713, 'blue_1': 0.29742704732233693, 'red_2': 0.13556506763097065, 'green_2': 0.432932677987358, 'blue_2': 0.9787962371220519}, distributions={'red_1': FloatDistribution(high=1.0, log=False, low=0.0, step=None), 'green_1': FloatDistribution(high=1.0, log=False, low=0.0, step=None), 'blue_1': FloatDistribution(high=1.0, log=False, low=0.0, step=None), 'red_2': FloatDistribution(high=1.0, log=False, low=0.0, step=None), 'green_2': FloatDistribution(high=1.0, log=False, low=0.0, step=None), 'blue_2': FloatDistribution(high=1.0, log=False, low=0.0, step=None)}, user_attrs={}, system_attrs={'nsga2:generation': 0}, intermediate_values={}, trial_id=0, state=TrialState.COMPLETE, value=None),
 FrozenTrial(number=2, values=[0.277, 0.5, 0.721], datetime_start=datetime.datetime(2022, 12, 5, 11, 20, 14, 263581), datetime_complete=datetime.datetime(2022, 12, 5, 11, 20, 16, 45308), params={'red_1': 0.4948019257808193, 'green_1': 0.6213095226780458, 'blue_1': 0.12019236663102772, 'red_2': 0.6580991580283672, 'green_2': 0.28579265834935685, 'blue_2': 0.816243307707653}, distributions={'red_1': FloatDistribution(high=1.0, log=False, low=0.0, step=None), 'green_1': FloatDistribution(high=1.0, log=False, low=0.0, step=None), 'blue_1': FloatDistribution(high=1.0, log=False, low=0.0, step=None), 'red_2': FloatDistribution(high=1.0, log=False, low=0.0, step=None), 'green_2': FloatDistribution(high=1.0, log=False, low=0.0, step=None), 'blue_2': FloatDistribution(high=1.0, log=False, low=0.0, step=None)}, user_attrs={}, system_attrs={'nsga2:generation': 0}, intermediate_values={}, trial_id=2, state=TrialState.COMPLETE, value=None),
 FrozenTrial(number=3, values=[0.655, 0.253, 0.721], datetime_start=datetime.datetime(2022, 12, 5, 11, 20, 16, 45524), datetime_complete=datetime.datetime(2022, 12, 5, 11, 20, 19, 300460), params={'red_1': 0.25000087827320916, 'green_1': 0.7550169796801904, 'blue_1': 0.9618770117456149, 'red_2': 0.7454706204839532, 'green_2': 0.9825729572843858, 'blue_2': 0.6118455693701307}, distributions={'red_1': FloatDistribution(high=1.0, log=False, low=0.0, step=None), 'green_1': FloatDistribution(high=1.0, log=False, low=0.0, step=None), 'blue_1': FloatDistribution(high=1.0, log=False, low=0.0, step=None), 'red_2': FloatDistribution(high=1.0, log=False, low=0.0, step=None), 'green_2': FloatDistribution(high=1.0, log=False, low=0.0, step=None), 'blue_2': FloatDistribution(high=1.0, log=False, low=0.0, step=None)}, user_attrs={}, system_attrs={'nsga2:generation': 0}, intermediate_values={}, trial_id=3, state=TrialState.COMPLETE, value=None)]

We can use this list of FrozenTrial objects to pick out trials of interest, for example the trial with the best score on the color-match objective. That score is the first value in each values list, hence the key=lambda trial: trial.values[0] below.

best_color_match_trial = max(study.best_trials, key=lambda trial: trial.values[0])
print(f"Trial with best color match: ")
print(f"\tnumber: {best_color_match_trial.number}")
print(f"\tparams: {best_color_match_trial.params}")
print(f"\tvalues: {best_color_match_trial.values}")

r1, g1, b1 = best_color_match_trial.params['red_1'], best_color_match_trial.params['green_1'], best_color_match_trial.params['blue_1']
r2, g2, b2 = best_color_match_trial.params['red_2'], best_color_match_trial.params['green_2'], best_color_match_trial.params['blue_2']
match_value = best_color_match_trial.values[0]

# Two disabled buttons that display the best-matching pair of colors
button1 = widgets.Button(
    description='',
    disabled=True,
    tooltip=rgb_to_hex(r1,g1,b1)
)
button1.style.button_color = rgb_to_hex(r1,g1,b1)
button2 = widgets.Button(
    description='',
    disabled=True,
    tooltip=rgb_to_hex(r2,g2,b2)
)
button2.style.button_color = rgb_to_hex(r2,g2,b2)
widgets.VBox([widgets.HBox([button1, button2]), widgets.Label(value=f"Match value: {match_value}")])
Trial with best color match: 
	number: 3
	params: {'red_1': 0.25000087827320916, 'green_1': 0.7550169796801904, 'blue_1': 0.9618770117456149, 'red_2': 0.7454706204839532, 'green_2': 0.9825729572843858, 'blue_2': 0.6118455693701307}
	values: [0.655, 0.253, 0.721]

Choosing the best trial in multi-objective optimization is not automatic. The user has to look at the Pareto front and decide which trade-off is reasonable.
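One simple way to make that decision explicit is to weight the objectives and pick the Pareto-optimal trial with the highest weighted sum. This weighted-sum rule is a technique added here for illustration and is not something Optuna does for you. The weights are arbitrary stand-ins for your own preferences. The values are those of study.best_trials above.

```python
# Objective values (match, color 1, color 2) of the Pareto-optimal trials above.
front = {
    0: (0.5, 0.5, 0.5),
    2: (0.277, 0.5, 0.721),
    3: (0.655, 0.253, 0.721),
}

def pick_by_weights(candidates, weights):
    """Return the trial number with the highest weighted sum of objective values."""
    return max(
        candidates,
        key=lambda number: sum(w * v for w, v in zip(weights, candidates[number])),
    )

print(pick_by_weights(front, (1, 1, 1)))  # 3: all objectives count equally
print(pick_by_weights(front, (0, 1, 1)))  # 2: ignore how well the colors match
print(pick_by_weights(front, (1, 2, 0)))  # 0: care mostly about color 1
```

Different weights select different points on the front. This is exactly why the choice has to come from the user rather than from the optimizer.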

Learning outcomes

By now, you should have learned these concepts:

  • Creating an Optuna study for single objective optimization

  • Creating a trial object and sampling a set of hyper parameters from it

  • Running multiple trials to perform hyper parameter optimization

  • Inspecting the results of the trials and choosing a set of parameters

  • Understanding the difference between single-objective and multi-objective hyper parameter optimization, and how to perform both in Optuna