Flatten Binary Tree to Linked List

Given a binary tree, flatten it to a linked list in-place.

Example:
Given

         1
        / \
       2   5
      / \   \
     3   4   6

The flattened tree should look like:

   1
    \
     2
      \
       3
        \
         4
          \
           5
            \
             6

Note that the left child of all nodes should be NULL.
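The flattened order in the example (1, 2, 3, 4, 5, 6) is the preorder traversal of the original tree. Here is a minimal sketch of a checker built on that observation. It assumes the `TreeNode` class from the solution's definition comment. The class `FlattenCheck` and its methods `preorder` and `isFlattenedPreorder` are hypothetical helpers introduced only for illustration.

```java
import java.util.ArrayList;
import java.util.List;

// Same shape as the TreeNode definition used by the solution.
class TreeNode {
    int val;
    TreeNode left;
    TreeNode right;
    TreeNode(int x) { val = x; }
}

// Hypothetical helper class, not part of the original problem.
public class FlattenCheck {
    // Collect values in preorder: node, then left subtree, then right subtree.
    static List<Integer> preorder(TreeNode node) {
        List<Integer> out = new ArrayList<>();
        collect(node, out);
        return out;
    }

    private static void collect(TreeNode node, List<Integer> out) {
        if (node == null) return;
        out.add(node.val);
        collect(node.left, out);
        collect(node.right, out);
    }

    // True if every node on the right-pointer chain has a null left child
    // and the chain visits exactly the expected values in order.
    static boolean isFlattenedPreorder(TreeNode head, List<Integer> expected) {
        List<Integer> seen = new ArrayList<>();
        for (TreeNode n = head; n != null; n = n.right) {
            if (n.left != null) return false;
            seen.add(n.val);
        }
        return seen.equals(expected);
    }

    public static void main(String[] args) {
        // Build the example tree from the problem statement.
        TreeNode root = new TreeNode(1);
        root.left = new TreeNode(2);
        root.left.left = new TreeNode(3);
        root.left.right = new TreeNode(4);
        root.right = new TreeNode(5);
        root.right.right = new TreeNode(6);
        System.out.println(preorder(root)); // [1, 2, 3, 4, 5, 6]
    }
}
```

To check a solution, record `preorder(root)` before flattening. Then pass that list together with the flattened head to `isFlattenedPreorder`.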

Method:

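Walk the tree along its right pointers. At each node that has a left child, find the rightmost node of the left subtree by following right pointers from the left child. Attach the current node's right subtree to that rightmost node. Then move the left subtree into the right pointer and set the left pointer to null. Continue with the (new) right child until the chain ends.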
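The following sketch applies the method to the root of the example tree only, which makes a single rewiring step concrete. `TreeNode` mirrors the definition used by the solution. `FlattenStep` and `rewire` are hypothetical names used only for this illustration. After one step, node 1 has no left child and its right child is 2. Node 4, the rightmost node of the old left subtree, now points to 5. Node 2 still holds 3 as its left child, so a later iteration has to handle it.

```java
// Same shape as the TreeNode definition used by the solution.
class TreeNode {
    int val;
    TreeNode left;
    TreeNode right;
    TreeNode(int x) { val = x; }
}

// Hypothetical illustration class, not part of the original solution.
public class FlattenStep {
    // Apply the method to one node: splice its right subtree onto the
    // rightmost node of its left subtree, then move the left subtree right.
    static void rewire(TreeNode curr) {
        if (curr.left == null) return;
        TreeNode rightmost = curr.left;
        while (rightmost.right != null) {
            rightmost = rightmost.right;
        }
        rightmost.right = curr.right;
        curr.right = curr.left;
        curr.left = null;
    }

    public static void main(String[] args) {
        TreeNode root = new TreeNode(1);
        root.left = new TreeNode(2);
        root.left.left = new TreeNode(3);
        root.left.right = new TreeNode(4);
        root.right = new TreeNode(5);
        root.right.right = new TreeNode(6);

        rewire(root); // one step, at the root only

        // Walk the right pointers; node 3 is still attached as 2.left.
        StringBuilder chain = new StringBuilder();
        for (TreeNode n = root; n != null; n = n.right) {
            chain.append(n.val).append(' ');
        }
        System.out.println(chain.toString().trim()); // 1 2 4 5 6
        System.out.println(root.right.left.val);     // 3
    }
}
```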

Solution:

Time: O(n)
Space: O(1) extra (no recursion or auxiliary stack; pointers are rewired in place)

/**
 * Definition for binary tree
 * class TreeNode {
 *     int val;
 *     TreeNode left;
 *     TreeNode right;
 *     TreeNode(int x) { val = x; }
 * }
 */
/*
 * Example input tree:
 *
 *          1
 *         / \
 *        2   5
 *       / \   \
 *      3   4   6
 */
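public class Solution {
    // Flattens the tree rooted at a into a right-pointer chain in preorder,
    // rewiring the existing nodes in place, and returns the same root.
    public TreeNode flatten(TreeNode a) {
        TreeNode curr = a;
        while (curr != null) {
            if (curr.left != null) {
                // Find the rightmost node of the left subtree
                // by following right pointers only.
                TreeNode rightmost = curr.left;
                while (rightmost.right != null) {
                    rightmost = rightmost.right;
                }
                // Splice the original right subtree after it.
                rightmost.right = curr.right;
                // Move the left subtree into the right slot and clear left.
                curr.right = curr.left;
                curr.left = null;
            }
            // Advance along the (possibly rewritten) right chain.
            curr = curr.right;
        }
        return a;
    }
}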
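A small driver for the solution follows. The `TreeNode` definition from the comment above and the `Solution` logic are repeated so that the file compiles on its own. `FlattenDemo`, `exampleTree`, and `chain` are hypothetical names for this harness. The driver builds the example tree, flattens it, and prints the right-pointer chain. It throws if any left child survives.

```java
// TreeNode as described in the definition comment.
class TreeNode {
    int val;
    TreeNode left;
    TreeNode right;
    TreeNode(int x) { val = x; }
}

// Same algorithm as the solution above (non-public so this file compiles).
class Solution {
    public TreeNode flatten(TreeNode a) {
        TreeNode curr = a;
        while (curr != null) {
            if (curr.left != null) {
                TreeNode rightmost = curr.left;
                while (rightmost.right != null) {
                    rightmost = rightmost.right;
                }
                rightmost.right = curr.right;
                curr.right = curr.left;
                curr.left = null;
            }
            curr = curr.right;
        }
        return a;
    }
}

// Hypothetical demo harness.
public class FlattenDemo {
    // The example tree from the problem statement.
    static TreeNode exampleTree() {
        TreeNode root = new TreeNode(1);
        root.left = new TreeNode(2);
        root.left.left = new TreeNode(3);
        root.left.right = new TreeNode(4);
        root.right = new TreeNode(5);
        root.right.right = new TreeNode(6);
        return root;
    }

    // Render the right-pointer chain, failing if any left child survived.
    static String chain(TreeNode head) {
        StringBuilder sb = new StringBuilder();
        for (TreeNode n = head; n != null; n = n.right) {
            if (n.left != null) {
                throw new IllegalStateException("left child not null at " + n.val);
            }
            if (sb.length() > 0) sb.append(" -> ");
            sb.append(n.val);
        }
        return sb.toString();
    }

    public static void main(String[] args) {
        TreeNode head = new Solution().flatten(exampleTree());
        System.out.println(chain(head)); // 1 -> 2 -> 3 -> 4 -> 5 -> 6
    }
}
```

An empty tree is handled trivially: `flatten(null)` skips the loop and returns `null`.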