Solution: Flatten Binary Tree to Linked List

Let's solve the Flatten Binary Tree to Linked List problem using the Tree Depth-First Search pattern.

Statement

Given the root of a binary tree, flatten the tree into a linked list using the same TreeNode class. In the linked list, the left child pointer of each node should always be NULL, and the right child pointer should point to the next node in the list. The nodes should appear in the same order as in a preorder traversal of the given binary tree.

Constraints:

  • $-100 \leq \texttt{Node.data} \leq 100$.
  • The number of nodes in the tree is in the range $[1, 500]$.

Solution

So far, you’ve probably brainstormed some approaches and have an idea of how to solve this problem. Let’s explore some of these approaches and figure out which one to follow based on considerations such as time complexity and any implementation constraints.

Naive approach

The naive approach to flatten a binary tree into a linked list is to perform a preorder traversal of the tree and store the visited nodes in a Queue. After the traversal, dequeue the nodes one by one and relink them: set the right pointer of the previously dequeued node to the node just dequeued, and set each node’s left pointer to NULL. The right pointer of the last node is set to NULL.

However, this approach requires extra memory for the Queue, which holds all $n$ nodes of the tree, so the space complexity is $O(n)$. Can the problem be solved without additional data structures?
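A minimal sketch of this naive approach is shown below. It assumes a simplified, non-generic TreeNode stand-in (the article’s own TreeNode class is generic and not shown), and the names NaiveFlatten and flattenWithQueue are hypothetical. Besides the queue, the recursive preorder walk also uses call-stack space proportional to the tree’s height.

```java
import java.util.ArrayDeque;
import java.util.Queue;

public class NaiveFlatten {
    // Hypothetical stand-in for the article's TreeNode class: an int value plus left/right links.
    static class TreeNode {
        int data;
        TreeNode left, right;

        TreeNode(int data) {
            this.data = data;
        }
    }

    // Enqueues the nodes of the subtree in preorder: node, left subtree, right subtree.
    static void preorder(TreeNode node, Queue<TreeNode> queue) {
        if (node == null) {
            return;
        }
        queue.add(node);
        preorder(node.left, queue);
        preorder(node.right, queue);
    }

    // Flattens the tree by relinking the queued nodes in the order they were enqueued.
    static TreeNode flattenWithQueue(TreeNode root) {
        if (root == null) {
            return null;
        }
        Queue<TreeNode> queue = new ArrayDeque<>();
        preorder(root, queue);
        TreeNode prev = queue.poll(); // the root comes out first
        while (!queue.isEmpty()) {
            TreeNode next = queue.poll();
            prev.left = null;  // flattened nodes have no left child
            prev.right = next; // the previous node points to the next node in preorder
            prev = next;
        }
        prev.left = null;  // the last node ends the list
        prev.right = null;
        return root;
    }

    public static void main(String[] args) {
        //        1
        //       / \
        //      2   5
        //     / \   \
        //    3   4   6
        TreeNode root = new TreeNode(1);
        root.left = new TreeNode(2);
        root.right = new TreeNode(5);
        root.left.left = new TreeNode(3);
        root.left.right = new TreeNode(4);
        root.right.right = new TreeNode(6);

        StringBuilder out = new StringBuilder();
        for (TreeNode n = flattenWithQueue(root); n != null; n = n.right) {
            out.append(n.data).append(' ');
        }
        System.out.println(out.toString().trim()); // prints "1 2 3 4 5 6"
    }
}
```

The traversal finishes before any pointer is changed, so relinking the nodes cannot disturb the order recorded in the queue.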

Optimized approach using depth-first search

The optimized solution rearranges the tree’s nodes in place, following depth-first (preorder) order. The idea can be stated recursively: flatten the left and right subtrees, attach the flattened left subtree to the current node’s right, and append the flattened original right subtree to the end of this newly attached left subtree. Applying this rule at every node turns the tree into a singly linked list whose nodes appear in preorder. A direct recursive implementation, however, needs call-stack space proportional to the tree’s height. The iterative version used in this lesson performs the same rewiring with only a few pointer variables, so it needs only a constant amount of extra space.
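The recursive formulation can be sketched as follows. This is an illustration under assumptions, not the lesson’s implementation: it uses a hypothetical non-generic TreeNode stand-in, and the helper name flattenAndGetTail is made up for this example. Each call returns the last node of the list it produced, so the parent can append its right subtree there without searching.

```java
public class RecursiveFlatten {
    // Hypothetical stand-in for the article's TreeNode class.
    static class TreeNode {
        int data;
        TreeNode left, right;

        TreeNode(int data) {
            this.data = data;
        }
    }

    // Flattens the subtree rooted at node in place and returns the last node of the resulting list.
    static TreeNode flattenAndGetTail(TreeNode node) {
        if (node == null) {
            return null;
        }
        TreeNode leftTail = flattenAndGetTail(node.left);   // node.left now heads a flattened list
        TreeNode rightTail = flattenAndGetTail(node.right); // node.right now heads a flattened list
        if (node.left != null) {
            leftTail.right = node.right; // append the flattened right subtree after the left one
            node.right = node.left;      // move the flattened left subtree to the right
            node.left = null;
        }
        // The list ends in the right part if there is one, otherwise in the left part, otherwise at node.
        if (rightTail != null) {
            return rightTail;
        }
        if (leftTail != null) {
            return leftTail;
        }
        return node;
    }

    static TreeNode flattenTree(TreeNode root) {
        flattenAndGetTail(root);
        return root;
    }

    public static void main(String[] args) {
        //        1
        //       / \
        //      2   5
        //     / \   \
        //    3   4   6
        TreeNode root = new TreeNode(1);
        root.left = new TreeNode(2);
        root.right = new TreeNode(5);
        root.left.left = new TreeNode(3);
        root.left.right = new TreeNode(4);
        root.right.right = new TreeNode(6);

        StringBuilder out = new StringBuilder();
        for (TreeNode n = flattenTree(root); n != null; n = n.right) {
            out.append(n.data).append(' ');
        }
        System.out.println(out.toString().trim()); // prints "1 2 3 4 5 6"
    }
}
```

Each node is visited once, but the recursion depth equals the tree’s height, which is the extra space the iterative version avoids.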

The iterative version works as follows. We start at the root and, for each node that has a left child, find the rightmost node in its left subtree. We set the right pointer of that rightmost node to the current node’s right child. Next, we set the current node’s right pointer to its left child. Finally, we set the current node’s left pointer to NULL and move on to the current node’s new right child. We repeat this process until every node in the binary tree has been processed.
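The per-node rewiring can be isolated into a single helper to see what one step does. This is a sketch under assumptions: the non-generic TreeNode stand-in and the names rewire and chainFrom are hypothetical, not part of the lesson’s code.

```java
public class RewireStep {
    // Hypothetical stand-in for the article's TreeNode class.
    static class TreeNode {
        int data;
        TreeNode left, right;

        TreeNode(int data) {
            this.data = data;
        }
    }

    // Performs the rewiring for a single node.
    static void rewire(TreeNode current) {
        if (current.left == null) {
            return; // nothing to move
        }
        TreeNode last = current.left;
        while (last.right != null) { // rightmost node of the left subtree
            last = last.right;
        }
        last.right = current.right;   // hang the old right subtree off the rightmost node
        current.right = current.left; // move the left subtree to the right
        current.left = null;
    }

    // Returns the values reached by following right pointers from node, separated by spaces.
    static String chainFrom(TreeNode node) {
        StringBuilder sb = new StringBuilder();
        for (TreeNode n = node; n != null; n = n.right) {
            if (sb.length() > 0) {
                sb.append(' ');
            }
            sb.append(n.data);
        }
        return sb.toString();
    }

    public static void main(String[] args) {
        //        1
        //       / \
        //      2   5
        //     / \   \
        //    3   4   6
        TreeNode root = new TreeNode(1);
        root.left = new TreeNode(2);
        root.right = new TreeNode(5);
        root.left.left = new TreeNode(3);
        root.left.right = new TreeNode(4);
        root.right.right = new TreeNode(6);

        rewire(root);
        System.out.println(chainFrom(root));      // prints "1 2 4 5 6"
        System.out.println(root.right.left.data); // prints "3": still the left child of node 2
    }
}
```

After one step, node 3 is still a left child; it is spliced into the chain when the traversal reaches node 2 and applies the same rewiring there.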

Note: In the following section, we will gradually build the solution. Alternatively, you can skip straight to the “Just the code” section.

Step-by-step solution construction

Starting from the tree’s root, we traverse the tree in a depth-first manner. At each node, we check whether it has a left child. If it does, we find the rightmost node of the left subtree: starting at the left child, we repeatedly move to the right child until we reach a node that does not have a right child.
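The search for the rightmost node follows only right pointers, which the small sketch below makes explicit. The TreeNode stand-in and the helper name rightmostOfLeft are hypothetical, chosen for illustration.

```java
public class RightmostOfLeft {
    // Hypothetical stand-in for the article's TreeNode class.
    static class TreeNode {
        int data;
        TreeNode left, right;

        TreeNode(int data) {
            this.data = data;
        }
    }

    // Returns the rightmost node of node's left subtree, or null if node has no left child.
    static TreeNode rightmostOfLeft(TreeNode node) {
        if (node == null || node.left == null) {
            return null;
        }
        TreeNode last = node.left;
        while (last.right != null) { // only right pointers are followed
            last = last.right;
        }
        return last;
    }

    public static void main(String[] args) {
        //   1
        //  /
        // 2
        //  \
        //   4
        //  /
        // 7
        TreeNode root = new TreeNode(1);
        root.left = new TreeNode(2);
        root.left.right = new TreeNode(4);
        root.left.right.left = new TreeNode(7);
        System.out.println(rightmostOfLeft(root).data); // prints "4", not 7
    }
}
```

The search stops at node 4 even though node 7 lies deeper. That is still correct for the full algorithm: node 7 is moved into place later, when the traversal reaches node 4 and processes its own left child.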

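Solution.java (uses the TreeNode and BinaryTree classes from TreeNode.java and BinaryTree.java, plus a Print helper; none of these are shown here)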
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

public class Solution {
    public static TreeNode<Integer> flattenTree(TreeNode<Integer> root) {
        if (root == null) {
            return null;
        }
        // Start the traversal at the root
        TreeNode<Integer> current = root;
        TreeNode<Integer> last = null;
        // Walk down the tree along right pointers
        System.out.println("\n\tTraversing the tree:");
        while (current != null) {
            // Print the tree, highlighting the current node
            Print.displayTree(root, current);
            if (current.left != null) {
                System.out.println("\n\tThe current node has a left child.");
                last = current.left;
                // Print the tree, highlighting the root of the left subtree
                Print.displayTree(root, last);
                // Follow right pointers to the rightmost node of the left subtree
                while (last.right != null) {
                    System.out.println("\n\tThis node has a right child.");
                    // Print the tree, highlighting the next node on the path
                    Print.displayTree(root, last.right);
                    last = last.right;
                }
                System.out.println("\n\tThis node does not have a right child, so it is the rightmost node of the left subtree.");
                System.out.println("\tWe'll move back to the current node.");
                Print.displayTree(root, current);
                // Note: this intermediate step does not yet relink the left subtree;
                // that is added in the next step of the construction
                System.out.println("\n\tWe'll set the left pointer of the current node to null.");
                current.left = null;
                Print.displayTree(root, current);
            }
            System.out.println("\n\tMoving to the right child.");
            current = current.right;
        }
        // Return the root of the processed tree
        return root;
    }

    public static void main(String[] args) {
        // Each inner list describes one binary tree and is passed to the BinaryTree constructor
        List<List<TreeNode<Integer>>> listOfTrees = Arrays.asList(
            Arrays.asList(new TreeNode<Integer>(3), new TreeNode<Integer>(2), new TreeNode<Integer>(17), new TreeNode<Integer>(1), new TreeNode<Integer>(4), new TreeNode<Integer>(19), new TreeNode<Integer>(5)),
            Arrays.asList(new TreeNode<Integer>(7), new TreeNode<Integer>(6), new TreeNode<Integer>(5), new TreeNode<Integer>(4), new TreeNode<Integer>(3), new TreeNode<Integer>(2), null, new TreeNode<Integer>(1)),
            Arrays.asList(new TreeNode<Integer>(5), new TreeNode<Integer>(4), new TreeNode<Integer>(6), new TreeNode<Integer>(3), new TreeNode<Integer>(2), new TreeNode<Integer>(7), new TreeNode<Integer>(8), new TreeNode<Integer>(1), new TreeNode<Integer>(9)),
            Arrays.asList(new TreeNode<Integer>(5), new TreeNode<Integer>(2), new TreeNode<Integer>(1), new TreeNode<Integer>(6), new TreeNode<Integer>(10), new TreeNode<Integer>(11), new TreeNode<Integer>(44)),
            Arrays.asList(new TreeNode<Integer>(1), new TreeNode<Integer>(2), new TreeNode<Integer>(5), new TreeNode<Integer>(3), new TreeNode<Integer>(4), new TreeNode<Integer>(6)),
            Arrays.asList(new TreeNode<Integer>(-1), new TreeNode<Integer>(-2), null, new TreeNode<Integer>(-5), new TreeNode<Integer>(1), new TreeNode<Integer>(2), null, new TreeNode<Integer>(-6))
        );
        // Build the binary trees using the BinaryTree class
        List<BinaryTree<Integer>> inputTrees = new ArrayList<BinaryTree<Integer>>();
        for (List<TreeNode<Integer>> listOfNodes : listOfTrees) {
            BinaryTree<Integer> tree = new BinaryTree<Integer>(listOfNodes);
            inputTrees.add(tree);
        }
        // Print each input tree, then run the traversal on it
        int x = 1;
        for (BinaryTree<Integer> tree : inputTrees) {
            System.out.println(x + ".\tBinary tree:");
            Print.displayTree(tree.root, null);
            flattenTree(tree.root);
            x++;
            System.out.println(new String(new char[100]).replace('\0', '-'));
        }
    }
}

Once we reach the rightmost node, we point its right pointer to the right child of the current node. After making this connection, we point the current node’s right pointer to the current node’s left child. Finally, we set the current node’s left pointer to NULL. We then move to the current node’s right child and repeat this process until all nodes of the tree have been traversed.
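To see the repetition concretely, the following sketch runs the complete loop on a small tree and records the chain of right pointers after each rewiring. The non-generic TreeNode stand-in and the helpers flattenWithTrace and chainFrom are hypothetical; they replace the lesson’s Print-based visualization with plain strings.

```java
import java.util.ArrayList;
import java.util.List;

public class TraceFlatten {
    // Hypothetical stand-in for the article's TreeNode class.
    static class TreeNode {
        int data;
        TreeNode left, right;

        TreeNode(int data) {
            this.data = data;
        }
    }

    // Returns the values reached by following right pointers from node, separated by spaces.
    static String chainFrom(TreeNode node) {
        StringBuilder sb = new StringBuilder();
        for (TreeNode n = node; n != null; n = n.right) {
            if (sb.length() > 0) {
                sb.append(' ');
            }
            sb.append(n.data);
        }
        return sb.toString();
    }

    // Runs the iterative flattening loop and records the right-pointer chain after each rewiring.
    static List<String> flattenWithTrace(TreeNode root) {
        List<String> snapshots = new ArrayList<>();
        TreeNode current = root;
        while (current != null) {
            if (current.left != null) {
                TreeNode last = current.left;
                while (last.right != null) {
                    last = last.right;
                }
                last.right = current.right;
                current.right = current.left;
                current.left = null;
                snapshots.add("after processing " + current.data + ": " + chainFrom(root));
            }
            current = current.right;
        }
        return snapshots;
    }

    public static void main(String[] args) {
        //        1
        //       / \
        //      2   5
        //     / \   \
        //    3   4   6
        TreeNode root = new TreeNode(1);
        root.left = new TreeNode(2);
        root.right = new TreeNode(5);
        root.left.left = new TreeNode(3);
        root.left.right = new TreeNode(4);
        root.right.right = new TreeNode(6);

        for (String snapshot : flattenWithTrace(root)) {
            System.out.println(snapshot);
        }
        // Prints:
        // after processing 1: 1 2 4 5 6
        // after processing 2: 1 2 3 4 5 6
    }
}
```

Only nodes 1 and 2 have left children, so only two rewirings happen; nodes 3 through 6 are simply passed over along the right pointers.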

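Solution.java (uses the TreeNode and BinaryTree classes from TreeNode.java and BinaryTree.java, plus a Print helper; none of these are shown here)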
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

public class Solution {
    public static TreeNode<Integer> flattenTree(TreeNode<Integer> root) {
        if (root == null) {
            return null;
        }
        // Start the traversal at the root
        TreeNode<Integer> current = root;
        // Walk down the tree along right pointers
        System.out.println("\n\tTraversing the tree:");
        while (current != null) {
            // Print the tree, highlighting the current node
            Print.displayTree(root, current);
            if (current.left != null) {
                System.out.println("\n\tThe current node has a left child.");
                TreeNode<Integer> last = current.left;
                // Print the tree, highlighting the root of the left subtree
                Print.displayTree(root, last);
                // Follow right pointers to the rightmost node of the left subtree
                while (last.right != null) {
                    System.out.println("\n\tThis node has a right child.");
                    // Print the tree, highlighting the next node on the path
                    Print.displayTree(root, last.right);
                    last = last.right;
                }
                System.out.println("\n\tThis node does not have a right child.");
                System.out.println("\tWe'll attach the current node's right subtree to it.");
                // Rewire: the old right subtree goes after the rightmost node,
                // the left subtree moves to the right, and the left pointer is cleared
                last.right = current.right;
                current.right = current.left;
                current.left = null;
                System.out.print("\n\tOur tree now looks like this:\n");
                Print.displayTree(root, current);
            }
            if (current.right != null) {
                System.out.println("\n\tMoving to the right child.\n");
            }
            current = current.right;
        }
        // Return the root of the flattened tree
        return root;
    }

    public static void main(String[] args) {
        // Each inner list describes one binary tree and is passed to the BinaryTree constructor
        List<List<TreeNode<Integer>>> listOfTrees = Arrays.asList(
            Arrays.asList(new TreeNode<Integer>(3), new TreeNode<Integer>(2), new TreeNode<Integer>(17), new TreeNode<Integer>(1), new TreeNode<Integer>(4), new TreeNode<Integer>(19), new TreeNode<Integer>(5)),
            Arrays.asList(new TreeNode<Integer>(7), new TreeNode<Integer>(6), new TreeNode<Integer>(5), new TreeNode<Integer>(4), new TreeNode<Integer>(3), new TreeNode<Integer>(2), null, new TreeNode<Integer>(1)),
            Arrays.asList(new TreeNode<Integer>(5), new TreeNode<Integer>(4), new TreeNode<Integer>(6), new TreeNode<Integer>(3), new TreeNode<Integer>(2), new TreeNode<Integer>(7), new TreeNode<Integer>(8), new TreeNode<Integer>(1), new TreeNode<Integer>(9)),
            Arrays.asList(new TreeNode<Integer>(5), new TreeNode<Integer>(2), new TreeNode<Integer>(1), new TreeNode<Integer>(6), new TreeNode<Integer>(10), new TreeNode<Integer>(11), new TreeNode<Integer>(44)),
            Arrays.asList(new TreeNode<Integer>(1), new TreeNode<Integer>(2), new TreeNode<Integer>(5), new TreeNode<Integer>(3), new TreeNode<Integer>(4), new TreeNode<Integer>(6)),
            Arrays.asList(new TreeNode<Integer>(-1), new TreeNode<Integer>(-2), null, new TreeNode<Integer>(-5), new TreeNode<Integer>(1), new TreeNode<Integer>(2), null, new TreeNode<Integer>(-6))
        );
        // Build the binary trees using the BinaryTree class
        List<BinaryTree<Integer>> inputTrees = new ArrayList<BinaryTree<Integer>>();
        for (List<TreeNode<Integer>> listOfNodes : listOfTrees) {
            BinaryTree<Integer> tree = new BinaryTree<Integer>(listOfNodes);
            inputTrees.add(tree);
        }
        // Print each input tree, then flatten it step by step
        int x = 1;
        for (BinaryTree<Integer> tree : inputTrees) {
            System.out.println(x + ".\tBinary tree:");
            Print.displayTree(tree.root, null);
            flattenTree(tree.root);
            x++;
            System.out.println(new String(new char[100]).replace('\0', '-'));
        }
    }
}

Just the code

Here’s the complete solution to this problem:

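Solution.java (uses the TreeNode and BinaryTree classes from TreeNode.java and BinaryTree.java, plus a Print helper; none of these are shown here)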
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

public class Solution {
    public static TreeNode<Integer> flattenTree(TreeNode<Integer> root) {
        if (root == null) {
            return null;
        }
        TreeNode<Integer> current = root;
        // Walk down the tree along right pointers
        while (current != null) {
            if (current.left != null) {
                // Find the rightmost node of the left subtree
                TreeNode<Integer> last = current.left;
                while (last.right != null) {
                    last = last.right;
                }
                // Attach the old right subtree after it, move the left subtree
                // to the right, and clear the left pointer
                last.right = current.right;
                current.right = current.left;
                current.left = null;
            }
            current = current.right;
        }
        return root;
    }

    public static void main(String[] args) {
        // Each inner list describes one binary tree and is passed to the BinaryTree constructor
        List<List<TreeNode<Integer>>> listOfTrees = Arrays.asList(
            Arrays.asList(new TreeNode<Integer>(3), new TreeNode<Integer>(2), new TreeNode<Integer>(17), new TreeNode<Integer>(1), new TreeNode<Integer>(4), new TreeNode<Integer>(19), new TreeNode<Integer>(5)),
            Arrays.asList(new TreeNode<Integer>(7), new TreeNode<Integer>(6), new TreeNode<Integer>(5), new TreeNode<Integer>(4), new TreeNode<Integer>(3), new TreeNode<Integer>(2), null, new TreeNode<Integer>(1)),
            Arrays.asList(new TreeNode<Integer>(5), new TreeNode<Integer>(4), new TreeNode<Integer>(6), new TreeNode<Integer>(3), new TreeNode<Integer>(2), new TreeNode<Integer>(7), new TreeNode<Integer>(8), new TreeNode<Integer>(1), new TreeNode<Integer>(9)),
            Arrays.asList(new TreeNode<Integer>(5), new TreeNode<Integer>(2), new TreeNode<Integer>(1), new TreeNode<Integer>(6), new TreeNode<Integer>(10), new TreeNode<Integer>(11), new TreeNode<Integer>(44)),
            Arrays.asList(new TreeNode<Integer>(1), new TreeNode<Integer>(2), new TreeNode<Integer>(5), new TreeNode<Integer>(3), new TreeNode<Integer>(4), new TreeNode<Integer>(6)),
            Arrays.asList(new TreeNode<Integer>(-1), new TreeNode<Integer>(-2), null, new TreeNode<Integer>(-5), new TreeNode<Integer>(1), new TreeNode<Integer>(2), null, new TreeNode<Integer>(-6))
        );
        // Build the binary trees using the BinaryTree class
        List<BinaryTree<Integer>> inputTrees = new ArrayList<BinaryTree<Integer>>();
        for (List<TreeNode<Integer>> listOfNodes : listOfTrees) {
            BinaryTree<Integer> tree = new BinaryTree<Integer>(listOfNodes);
            inputTrees.add(tree);
        }
        // Print each input tree and its flattened version
        int x = 1;
        for (BinaryTree<Integer> tree : inputTrees) {
            System.out.println(x + ".\tBinary tree:");
            Print.displayTree(tree.root, null);
            System.out.println("\n\tFlattened tree:");
            Print.displayTree(flattenTree(tree.root), null);
            x++;
            System.out.println(new String(new char[100]).replace('\0', '-'));
        }
    }
}

Solution summary

  • Start at the root and walk the tree along right pointers; for each node, check whether it has a left child.
  • If the left child exists, find the rightmost node in the left subtree.
  • Point the right pointer of that rightmost node to the right child of the current node.
  • Set the current node’s right pointer to its left child.
  • Set the current node’s left pointer to NULL.
  • Move to the current node’s right child and repeat the steps above until no nodes remain.
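Because the listings in this lesson depend on TreeNode, BinaryTree, and Print classes that are not shown, here is a self-contained sketch of the summarized steps that compiles on its own. It uses a hypothetical non-generic TreeNode stand-in instead of the article’s generic class, and it checks the result against a preorder traversal taken before flattening; the helper names preorder and readList are illustrative.

```java
import java.util.ArrayList;
import java.util.List;

public class FlattenSelfContained {
    // Hypothetical stand-in for the article's TreeNode class.
    static class TreeNode {
        int data;
        TreeNode left, right;

        TreeNode(int data) {
            this.data = data;
        }
    }

    // Iterative in-place flattening, following the steps in the summary.
    static TreeNode flattenTree(TreeNode root) {
        TreeNode current = root;
        while (current != null) {
            if (current.left != null) {
                TreeNode last = current.left;
                while (last.right != null) { // rightmost node of the left subtree
                    last = last.right;
                }
                last.right = current.right;
                current.right = current.left;
                current.left = null;
            }
            current = current.right;
        }
        return root;
    }

    // Records values in preorder; used to compute the expected order before flattening.
    static void preorder(TreeNode node, List<Integer> out) {
        if (node == null) {
            return;
        }
        out.add(node.data);
        preorder(node.left, out);
        preorder(node.right, out);
    }

    // Reads the flattened list, failing if any left pointer is still set.
    static List<Integer> readList(TreeNode head) {
        List<Integer> out = new ArrayList<>();
        for (TreeNode n = head; n != null; n = n.right) {
            if (n.left != null) {
                throw new IllegalStateException("left pointer not null at node " + n.data);
            }
            out.add(n.data);
        }
        return out;
    }

    public static void main(String[] args) {
        //        3
        //      /   \
        //     2     17
        //    / \   /  \
        //   1   4 19   5
        TreeNode root = new TreeNode(3);
        root.left = new TreeNode(2);
        root.right = new TreeNode(17);
        root.left.left = new TreeNode(1);
        root.left.right = new TreeNode(4);
        root.right.left = new TreeNode(19);
        root.right.right = new TreeNode(5);

        List<Integer> expected = new ArrayList<>();
        preorder(root, expected);
        List<Integer> actual = readList(flattenTree(root));
        System.out.println(actual);                  // prints "[3, 2, 1, 4, 17, 19, 5]"
        System.out.println(actual.equals(expected)); // prints "true"
    }
}
```

The loop also handles an empty tree: with a null root, the while condition is false immediately and null is returned, so the explicit null check in the lesson’s version is a convenience rather than a requirement.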

Time complexity

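The time complexity is $O(n)$, where $n$ is the number of nodes in the tree. The current pointer visits each node once, and the searches for rightmost nodes walk along each right edge of the original tree at most once, so the total work is linear.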
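To check the linear bound empirically, the following sketch counts every pointer move the iterative loop makes (moves of current plus steps of the rightmost-node search) and compares the total with $2n$. The class name, the counting scheme, and the tree builders complete and leftSkewed are hypothetical helpers for illustration.

```java
public class FlattenStepCount {
    // Hypothetical stand-in for the article's TreeNode class.
    static class TreeNode {
        int data;
        TreeNode left, right;

        TreeNode(int data) {
            this.data = data;
        }
    }

    // Iterative flatten that returns how many pointer moves it made:
    // one per advance of current, plus one per step of the rightmost-node search.
    static long flattenCountingMoves(TreeNode root) {
        long moves = 0;
        TreeNode current = root;
        while (current != null) {
            if (current.left != null) {
                TreeNode last = current.left;
                while (last.right != null) {
                    last = last.right;
                    moves++;
                }
                last.right = current.right;
                current.right = current.left;
                current.left = null;
            }
            current = current.right;
            moves++;
        }
        return moves;
    }

    // Builds a complete binary tree whose nodes are numbered i..n in level order.
    static TreeNode complete(int i, int n) {
        if (i > n) {
            return null;
        }
        TreeNode node = new TreeNode(i);
        node.left = complete(2 * i, n);
        node.right = complete(2 * i + 1, n);
        return node;
    }

    // Builds a chain of n nodes linked only through left pointers.
    static TreeNode leftSkewed(int n) {
        TreeNode root = null;
        for (int i = n; i >= 1; i--) {
            TreeNode node = new TreeNode(i);
            node.left = root;
            root = node;
        }
        return root;
    }

    public static void main(String[] args) {
        int[] sizes = {1, 15, 500};
        for (int n : sizes) {
            long c = flattenCountingMoves(complete(1, n));
            long s = flattenCountingMoves(leftSkewed(n));
            System.out.println("n=" + n + "  complete=" + c + "  leftSkewed=" + s + "  bound 2n=" + 2L * n);
        }
    }
}
```

For a left-skewed tree, every left child has no right child, so the searches never take a step and the count equals $n$; for the complete trees, the searches add fewer than $n$ extra moves.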

Space complexity

The space complexity is $O(1)$, because the tree is rearranged in place using only a constant number of pointer variables (current and last).
