7 Different Ways to Flatten a List of Lists in Python
Learn how to "unlist" (unnest) irregular lists and nested lists of tuples, ints or strings, with or without recursion, or using Python's itertools.
Ever wondered how you can flatten, or unnest, a 2D list of lists in Python?
In other words, how do you turn 2-D lists into 1-D ones:
[[1, 2], [3, 4], [5, 6, 7], [8]]
->[1, 2, 3, 4, 5, 6, 7, 8]
[[1, 2], [4, 5], [[[7]]], [[[[8]]]]]
->[1, 2, 4, 5, 7, 8]
[1, 2, 3, [4, 5, 6], [[7, 8]]]
->[1, 2, 3, 4, 5, 6, 7, 8]
What about lists with mixed types, such as lists of strings or lists of tuples?
[[1, 2], "three", ["four", "five"]]
->[1, 2, "three", "four", "five"]
[[1, 2], (3, 4), (5, 6, 7), [8]]
->[1, 2, 3, 4, 5, 6, 7, 8]
In this post, we'll see how we can unnest an arbitrarily nested list of lists in 7 different ways. Each method has pros and cons, and varies in performance. By going over each one, you'll learn how to identify the most appropriate solution for your problem by creating your own flatten() function in Python.
For all examples, we'll use Python 3, and for the tests, pytest.
By the end of this guide, you'll have learned:
- how to flatten / unnest a list of mixed types, including lists of strings, lists of tuples or ints
- the best way to flatten lists of lists with list comprehensions
- how to unfold a list and remove duplicates
- how to convert a nested list of lists using the built-in function sum from the standard library
- how to use numpy to flatten nested lists
- how to use itertools chain to create a flat list
- the best way to flatten a nested list using recursion or without recursion
Table of Contents
- Flattening a list of lists with list comprehensions
- How to flatten list of strings, tuples or mixed types
- How to flatten a list and remove duplicates
- Flattening a list of lists with the sum function
- Flattening using itertools.chain
- Flatten a regular list of lists with numpy
- Flattening irregular lists
- Which method is faster? A performance comparison
- Conclusion
Flattening a list of lists with list comprehensions
Let's imagine that we have a simple list of lists like this [[1, 3], [2, 5], [1]] and we want to flatten it.
In other words, we want to convert the original list into a flat list like this [1, 3, 2, 5, 1]. The first way of doing that is through list/generator comprehensions: we iterate through each sublist, then iterate over the items of that sublist, producing a single element each time.
The following function accepts a list of lists as an argument and returns a generator. The reason for that is to avoid building a whole list in memory. We can then pass the generator to list() to create a single flat list.
To make sure everything works as expected, we can assert the behavior with the test_flatten unit test.
from typing import List, Any, Iterable

def flatten_gen_comp(lst: List[Any]) -> Iterable[Any]:
    """Flatten a list using a generator comprehension."""
    return (item
            for sublist in lst
            for item in sublist)

def test_flatten():
    lst = [[1, 3], [2, 5], [1]]
    assert list(flatten_gen_comp(lst)) == [1, 3, 2, 5, 1]
This function returns a generator, so to get a list back we need to pass the generator to list().
When we run the test, we can see it passing...
============================= test session starts ==============================
flatten.py::test_flatten PASSED [100%]
============================== 1 passed in 0.01s ===============================
Process finished with exit code 0
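Since flatten_gen_comp returns a generator, keep in mind that it can be consumed only once. A quick illustration, with the function redefined so the snippet runs on its own:

```python
from typing import Any, Iterable, List


def flatten_gen_comp(lst: List[Any]) -> Iterable[Any]:
    """Flatten one level of nesting using a generator expression."""
    return (item for sublist in lst for item in sublist)


gen = flatten_gen_comp([[1, 3], [2, 5], [1]])
first_pass = list(gen)   # consumes every element
second_pass = list(gen)  # the generator is already exhausted

print(first_pass)   # [1, 3, 2, 5, 1]
print(second_pass)  # []
```

If you need to iterate more than once, either call the function again or store the result of list().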
If you prefer, you can make the code shorter by using a lambda function.
>>> flatten_lambda = lambda lst: (item for sublist in lst for item in sublist)
>>> lst = [[1, 3], [2, 5], [1]]
>>> list(flatten_lambda(lst))
[1, 3, 2, 5, 1]
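If you'd rather get a list back directly, the same double loop works as a list comprehension. flatten_list_comp is a hypothetical name used only for illustration; note that it builds the whole list in memory, which is exactly what the generator version avoids.

```python
from typing import Any, List


def flatten_list_comp(lst: List[List[Any]]) -> List[Any]:
    """Flatten one level of nesting with a list comprehension."""
    # The for clauses read left to right, like the equivalent nested loops:
    #   for sublist in lst:
    #       for item in sublist:
    #           result.append(item)
    return [item for sublist in lst for item in sublist]


print(flatten_list_comp([[1, 3], [2, 5], [1]]))  # [1, 3, 2, 5, 1]
```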
How to flatten list of strings, tuples or mixed types
The technique we've seen assumes every item of the outer list is an iterable, and it unpacks each one of them. Strings are iterables too, so they get unpacked into single characters. Let's see what happens if we pass in a list of lists and strings.
from typing import List, Any, Iterable

def flatten_gen_comp(lst: List[Any]) -> Iterable[Any]:
    """Flatten a list using a generator comprehension."""
    return (item
            for sublist in lst
            for item in sublist)

>>> lst = [['hello', 'world'], ['my', 'dear'], 'friend']
>>> list(flatten_gen_comp(lst))
['hello', 'world', 'my', 'dear', 'f', 'r', 'i', 'e', 'n', 'd']
Oops, that's not what we want! The reason is that 'friend' is itself iterable, so the function unfolds it into characters as well.
One way of preventing that is by checking if the item is a list or not. If it isn't, we yield it as is instead of iterating over it.
>>> from typing import List, Any, Iterable
>>> def flatten(lst: List[Any]) -> Iterable[Any]:
...     """Flatten one level of a list using a generator.
...     Returns a flattened version of list lst.
...     """
...     for sublist in lst:
...         if isinstance(sublist, list):
...             for item in sublist:
...                 yield item
...         else:
...             yield sublist
...
>>> lst = [['hello', 'world'], ['my', 'dear'], 'friend']
>>> list(flatten(lst))
['hello', 'world', 'my', 'dear', 'friend']
Since we only unpack items that are lists, everything else, including ints, strings and tuples (or any other iterable, for that matter), is kept as a single element.
>>> lst = [[1, 2], 3, (4, 5)]
>>> list(flatten(lst))
[1, 2, 3, (4, 5)]
Lastly, this flatten function works for multidimensional lists of mixed types.
>>> lst = [[1, 2], 3, (4, 5), ["string"], "hello"]
>>> list(flatten(lst))
[1, 2, 3, (4, 5), 'string', 'hello']
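If tuples should be unpacked too while strings stay whole, one option is to hand isinstance a tuple of the types you want to unpack. This is a sketch; flatten_types and its unpack parameter are hypothetical names, not part of the post's code.

```python
from typing import Any, Iterable, List, Tuple, Type


def flatten_types(lst: List[Any], unpack: Tuple[Type, ...] = (list, tuple)) -> Iterable[Any]:
    """Flatten one level, unpacking only items whose type is in `unpack`."""
    for sublist in lst:
        if isinstance(sublist, unpack):
            yield from sublist
        else:
            # Strings, ints and any type not listed in `unpack` are kept whole.
            yield sublist


lst = [[1, 2], 3, (4, 5), ["string"], "hello"]
print(list(flatten_types(lst)))           # [1, 2, 3, 4, 5, 'string', 'hello']
print(list(flatten_types(lst, (list,))))  # [1, 2, 3, (4, 5), 'string', 'hello']
```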
How to flatten a list and remove duplicates
To flatten a list of lists and return a list without duplicates, the simplest way is to convert the final output to a set. Keep in mind that a set doesn't preserve the original order, and every item must be hashable.
The only other downside is that if the list is big, there'll be a performance penalty, since we need to create the set from the generator and then convert the set back to a list.
>>> from typing import List, Any, Iterable
>>> def flatten(lst: List[Any]) -> Iterable[Any]:
...     """Flatten one level of a list using a generator.
...     Returns a flattened version of list lst.
...     """
...     for sublist in lst:
...         if isinstance(sublist, list):
...             for item in sublist:
...                 yield item
...         else:
...             yield sublist
...
>>> lst = [[1, 2], 3, (4, 5), ["string"], "hello", 3, 4, "hello"]
>>> list(set(flatten(lst)))
[1, 2, 3, 'hello', 4, (4, 5), 'string']
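If you need to keep the original order while removing duplicates, a common alternative to set is dict.fromkeys, since dicts preserve insertion order from Python 3.7 on. Like a set, it requires every item to be hashable and treats equal values (such as 1 and True) as duplicates. A minimal sketch reusing the flatten function from above:

```python
from typing import Any, Iterable, List


def flatten(lst: List[Any]) -> Iterable[Any]:
    """Flatten one level, unpacking only list items."""
    for sublist in lst:
        if isinstance(sublist, list):
            yield from sublist
        else:
            yield sublist


lst = [[1, 2], 3, (4, 5), ["string"], "hello", 3, 4, "hello"]
# Dict keys are unique and keep insertion order (Python 3.7+), so
# duplicates are dropped while each first occurrence keeps its place.
unique = list(dict.fromkeys(flatten(lst)))
print(unique)  # [1, 2, 3, (4, 5), 'string', 'hello', 4]
```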
Flattening a list of lists with the sum function
The second strategy is a bit unconventional and, truth be told, very "magical".
Did you know that we can use the built-in function sum to create flattened lists?
All we need to do is pass the list as an argument along with an empty list as the start value. The following code snippet illustrates that.
If you're curious about this approach, I discuss it in more detail in another post.
from typing import List, Any, Iterable

def flatten_sum(lst: List[Any]) -> Iterable[Any]:
    """Flatten a list using sum."""
    return sum(lst, [])

def test_flatten():
    lst = [[1, 3], [2, 5], [1]]
    assert list(flatten_sum(lst)) == [1, 3, 2, 5, 1]
And the tests pass too...
flatten.py::test_flatten PASSED
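Even though it seems clever, it's not a good idea to use it in production, IMHO. Each + creates a brand-new list and copies everything accumulated so far into it, so the running time grows quadratically with the number of sublists. As we'll see later, this is the worst way to flatten a list in terms of performance.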
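To see where the cost comes from, here's a simplified model of what sum(lst, []) does with lists: it starts from the empty list and repeatedly applies +. This is an illustration, not CPython's actual source; flatten_sum_explained is a hypothetical name.

```python
from typing import Any, List


def flatten_sum_explained(lst: List[List[Any]]) -> List[Any]:
    """Roughly what sum(lst, []) does: repeated list concatenation."""
    total: List[Any] = []  # the start value we pass to sum
    for sublist in lst:
        # `+` builds a new list and copies all of `total` into it,
        # so each step costs time proportional to the result so far.
        total = total + sublist
    return total


lst = [[1, 3], [2, 5], [1]]
print(flatten_sum_explained(lst))  # [1, 3, 2, 5, 1]
print(sum(lst, []))                # [1, 3, 2, 5, 1]
```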
Flattening using itertools.chain
The third alternative is to use the chain function from the itertools module. In a nutshell, chain creates a single iterator from a sequence of other iterables. This function is a perfect match for our use case.
An equivalent implementation of chain, taken from the official docs, looks like this.
def chain(*iterables):
    # chain('ABC', 'DEF') --> A B C D E F
    for it in iterables:
        for element in it:
            yield element

We can then flatten a list of lists like so:

import itertools
from typing import List, Any, Iterable

def flatten_chain(lst: List[Any]) -> Iterable[Any]:
    """Flatten a list using chain."""
    return itertools.chain(*lst)

def test_flatten():
    lst = [[1, 3], [2, 5], [1]]
    assert list(flatten_chain(lst)) == [1, 3, 2, 5, 1]
And the test passes...
flatten.py::test_flatten PASSED
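The itertools module also offers chain.from_iterable, which takes the outer iterable as a single argument instead of unpacking it with *. That means the outer iterable can be lazy too, such as a generator. A small sketch; the wrapper name flatten_chain_from_iterable is just for illustration:

```python
import itertools
from typing import Any, Iterable


def flatten_chain_from_iterable(lst: Iterable[Iterable[Any]]) -> Iterable[Any]:
    """Flatten one level with itertools.chain.from_iterable."""
    # Unlike chain(*lst), from_iterable receives the outer iterable itself
    # and pulls sublists lazily, so nothing is unpacked into arguments.
    return itertools.chain.from_iterable(lst)


print(list(flatten_chain_from_iterable([[1, 3], [2, 5], [1]])))  # [1, 3, 2, 5, 1]

# It also works when the outer iterable is itself a generator.
rows = ([n, n * 10] for n in range(3))
print(list(flatten_chain_from_iterable(rows)))  # [0, 0, 1, 10, 2, 20]
```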
Flatten a regular list of lists with numpy
Another option to create flat lists from nested ones is to use numpy. This library is mostly used to represent and perform operations on multidimensional arrays such as 2D and 3D arrays.
What most people don't know is that some of its functions also work with nested lists or other lists of iterables. For example, we can use the numpy.concatenate function to flatten a regular list of lists. Calling .tolist() on the resulting array converts its elements back to plain Python ints.
>>> import numpy as np
>>> lst = [[1, 3], [2, 5], [1], [7, 8]]
>>> np.concatenate(lst).tolist()
[1, 3, 2, 5, 1, 7, 8]
Flattening irregular lists
So far we've been flattening regular lists, but what happens if we have a list like this [[1, 3], [2, 5], 1] or this [1, [2, 3], [[4]], [], [[[[[[[[[5]]]]]]]]]]?
Unfortunately, the preceding approaches don't handle these well: they either raise an error on the bare integers, which aren't iterable, or remove only one level of nesting. In this section, we'll see two solutions for that, one recursive and the other iterative.
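To make the failure concrete, here's what the generator-comprehension version from the first section does with irregular input (the function is redefined so the snippet runs standalone):

```python
from typing import Any, Iterable, List, Optional


def flatten_gen_comp(lst: List[Any]) -> Iterable[Any]:
    """Flatten one level of nesting using a generator expression."""
    return (item for sublist in lst for item in sublist)


# A bare int among the sublists: the inner `for` tries to iterate over it.
error: Optional[str] = None
try:
    list(flatten_gen_comp([[1, 3], [2, 5], 1]))
except TypeError as exc:
    error = str(exc)
print(error)  # 'int' object is not iterable

# Deeper nesting raises no error, but only one level is removed.
one_level = list(flatten_gen_comp([[1], [[2]], [[[3]]]]))
print(one_level)  # [1, [2], [[3]]]
```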
The recursive approach
Solving the flattening problem recursively means iterating over each list element and deciding whether the item is already flattened or not. If it is, we yield it; otherwise, we call flatten_recursive on it. But better than words is code, so let's see some code.
from typing import List, Any, Iterable

def flatten_recursive(lst: List[Any]) -> Iterable[Any]:
    """Flatten a list using recursion."""
    for item in lst:
        if isinstance(item, list):
            yield from flatten_recursive(item)
        else:
            yield item

def test_flatten_recursive():
    lst = [[1, 3], [2, 5], 1]
    assert list(flatten_recursive(lst)) == [1, 3, 2, 5, 1]
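Bear in mind that the if isinstance(item, list) check means only lists are unpacked; tuples and other iterables are kept as single elements. On the other hand, it will flatten lists of mixed types with no trouble.
>>> lst = [[1, 2], 3, (4, 5), ["string"], "hello"]
>>> list(flatten_recursive(lst))
[1, 2, 3, (4, 5), 'string', 'hello']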
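If you'd like to unpack any iterable (tuples, sets, generators) rather than only lists, one possible variation on flatten_recursive is to check against collections.abc.Iterable and treat str and bytes as atoms. This is a sketch; flatten_iterables is a hypothetical name, and note that a dict would be flattened into its keys.

```python
from collections import abc
from typing import Any, Iterable, Iterator


def flatten_iterables(obj: Iterable[Any]) -> Iterator[Any]:
    """Recursively flatten any iterable, treating str and bytes as atoms."""
    for item in obj:
        # A 1-char string iterates to itself, so recursing into strings
        # would never terminate; str and bytes are yielded whole instead.
        if isinstance(item, abc.Iterable) and not isinstance(item, (str, bytes)):
            yield from flatten_iterables(item)
        else:
            yield item


lst = [[1, 2], 3, (4, (5, [6])), ["string"], "hello", {7}]
print(list(flatten_iterables(lst)))  # [1, 2, 3, 4, 5, 6, 'string', 'hello', 7]
```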
The iterative approach
The iterative approach is undoubtedly the most complex of all of them.
In this approach, we'll use a deque to flatten the irregular list. Quoting the official docs:
"Deques are a generalization of stacks and queues (the name is pronounced “deck” and is short for “double-ended queue”). Deques support thread-safe, memory efficient appends and pops from either side of the deque with approximately the same O(1) performance in either direction."
In other words, we can use a deque to simulate the stacking operation of the recursive solution.
To do that, we'll start just like we did in the recursive case: we'll iterate through each element and, if the element is not a list, we'll append it to the left of the deque. The appendleft method appends the element to the leftmost position of the deque, for example:
>>> from collections import deque
>>> l = deque()
>>> l.appendleft(2)
>>> l
deque([2])
>>> l.appendleft(7)
>>> l
deque([7, 2])
If the element is a list, though, then we need to reverse it first and pass it to the extendleft method, like this: my_deque.extendleft(reversed(item)). Similar to appendleft, extendleft adds each item to the leftmost position of the deque, one after another. As a result, the deque ends up holding the elements in reverse order. That's exactly why we need to reverse the sublist before extending left. To make things clearer, let's see an example.
>>> l = deque()
>>> l.extendleft([1, 2, 3])
>>> l
deque([3, 2, 1])
The final step is to iterate over the deque, removing the leftmost element and yielding it if it's not a list. If the element is a list, then we extend the deque to the left with its reversed contents. The full implementation goes like this:
from collections import deque
from typing import List, Any, Iterable

def flatten_deque(lst: List[Any]) -> Iterable[Any]:
    """Flatten a list using a deque."""
    q = deque()
    # Walk the top-level items from last to first: since each one is
    # pushed onto the left, the deque ends up in the original order.
    for item in reversed(lst):
        if isinstance(item, list):
            q.extendleft(reversed(item))
        else:
            q.appendleft(item)
    while q:
        elem = q.popleft()
        if isinstance(elem, list):
            q.extendleft(reversed(elem))
        else:
            yield elem

def test_flatten_super_irregular():
    lst = [1, [2, 3], [4], [], [[[[[[[[[5]]]]]]]]]]
    assert list(flatten_deque(lst)) == [1, 2, 3, 4, 5]
When we run the test, it happily passes...
============================= test session starts ==============================
...
flatten.py::test_flatten_super_irregular PASSED [100%]
============================== 1 passed in 0.01s ===============================
This approach can also flatten lists of mixed types with no trouble.
>>> lst = [[1, 2], 3, (4, 5), ["string"], "hello"]
>>> list(flatten_deque(lst))
[1, 2, 3, (4, 5), 'string', 'hello']
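One practical advantage of the iterative version is that it doesn't depend on Python's call stack, so it copes with nesting deeper than CPython's default recursion limit of 1,000, where the recursive generator would typically fail with a RecursionError. The sketch below seeds the deque directly with deque(lst) instead of using the initial loop, a simplification that yields the same order:

```python
from collections import deque
from typing import Any, Iterable, List


def flatten_deque(lst: List[Any]) -> Iterable[Any]:
    """Flatten an arbitrarily nested list without recursion."""
    q: deque = deque(lst)  # start with the top-level items, in order
    while q:
        elem = q.popleft()
        if isinstance(elem, list):
            # Push the sublist's items back to the front, keeping their order.
            q.extendleft(reversed(elem))
        else:
            yield elem


# Build [[[...[5]...]]] with 10,000 levels of nesting.
deep: List[Any] = [5]
for _ in range(10_000):
    deep = [deep]

print(list(flatten_deque(deep)))  # [5]
```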
Which method is faster? A performance comparison
As we've seen in the previous section, if our multi-level list is irregular, we have little choice. But assuming that this is not a frequent use case, how do these approaches compare in terms of performance?
In this last part, we’ll run a benchmark and compare which solutions perform best.
To do that, we can use the timeit module, which we can invoke in IPython with %timeit [code_to_measure]. The timings for each implementation are listed below. As we can see, flatten_chain is the fastest implementation of all: it flattened our list lst in 267 µs on average. The slowest implementation is flatten_sum, taking around 42 ms to flatten the same list.
PS: A special thanks to @hynekcer who pointed out a bug in this benchmark. Since most of the functions return a generator, we need to consume all elements in order to get a better assessment. We can either iterate over the generator or create a list out of it.
Flatten generator comprehension
In [1]: lst = [[1, 2, 3], [4, 5, 6], [7], [8, 9]] * 1_000
In [2]: %timeit list(flatten_gen_comp(lst))
615 µs ± 2.74 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
Flatten sum
In [3]: %timeit list(flatten_sum(lst))
42 ms ± 660 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
Flatten chain
In [4]: %timeit list(flatten_chain(lst))
267 µs ± 517 ns per loop (mean ± std. dev. of 7 runs, 1000 loops each)
Flatten numpy
In [5]: %timeit list(flatten_numpy(lst))
4.65 ms ± 14.3 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
Flatten recursive
In [6]: %timeit list(flatten_recursive(lst))
3.02 ms ± 174 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
Flatten deque
In [7]: %timeit list(flatten_deque(lst))
2.97 ms ± 21.9 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
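If you're not using IPython, a rough equivalent with the standard library's timeit.repeat is sketched below for three of the functions. The numbers it prints will vary with your machine and Python version, so don't expect them to match the figures above.

```python
import itertools
import timeit
from typing import Any, Dict, Iterable, List


def flatten_gen_comp(lst: List[Any]) -> Iterable[Any]:
    return (item for sublist in lst for item in sublist)


def flatten_chain(lst: List[Any]) -> Iterable[Any]:
    return itertools.chain(*lst)


def flatten_sum(lst: List[Any]) -> List[Any]:
    return sum(lst, [])


lst = [[1, 2, 3], [4, 5, 6], [7], [8, 9]] * 1_000
candidates = {
    "gen_comp": flatten_gen_comp,
    "chain": flatten_chain,
    "sum": flatten_sum,
}

NUMBER = 5  # calls per timing run
results: Dict[str, float] = {}
for name, func in candidates.items():
    # Wrapping in list(...) consumes generators fully, as in the %timeit runs.
    # The lambda is only called during this iteration, so late binding of
    # `func` is not an issue here.
    timings = timeit.repeat(lambda: list(func(lst)), number=NUMBER, repeat=3)
    # Best run divided by NUMBER gives seconds per call.
    results[name] = min(timings) / NUMBER
    print(f"{name:>8}: {results[name] * 1e3:.3f} ms per call")
```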
Conclusion
We can flatten a multi-level list in several different ways with Python. Each approach has its pros and cons. In this post, we looked at 7 different ways to create 1D lists from nested lists, covering:
- regular nested lists
- irregular lists
- list of mixed types
- list of strings
- list of tuples or ints
- recursive and iterative approaches
- using the itertools module
- removing duplicates
See you next time!