Concatenated Words

Given a list of words (without duplicates), write a program that returns all concatenated words in the list.
A concatenated word is a string composed entirely of at least two shorter words from the given array.

Example:

Input: ["cat","cats","catsdogcats","dog","dogcatsdog","hippopotamuses","rat","ratcatdogcat"]

Output: ["catsdogcats","dogcatsdog","ratcatdogcat"]

Explanation: "catsdogcats" is the concatenation of "cats", "dog" and "cats";
"dogcatsdog" is the concatenation of "dog", "cats" and "dog";
"ratcatdogcat" is the concatenation of "rat", "cat", "dog" and "cat".


Note:

  1. The given array contains at most 10,000 elements.
  2. The total length of all elements in the given array does not exceed 600,000.
  3. All input strings consist only of lowercase letters.
  4. The order of the returned elements does not matter.

Solution:

/**
 * Finds every word in a dictionary that can be formed by concatenating at least two non-empty
 * words from the same dictionary.
 */
class Solution {
    /** Trie node labelled with a single character; {@code isWord} marks the end of an inserted word. */
    static class TrieNode {
        final char c;
        final List<TrieNode> children = new ArrayList<>();
        boolean isWord = false;

        public TrieNode(char c) {
            this.c = c;
        }

        /** Returns the child labelled {@code ch}, or {@code null} if there is none. */
        TrieNode child(char ch) {
            // insert() never adds two children with the same label, so the first match is the only one.
            for (TrieNode next : children) {
                if (next.c == ch) {
                    return next;
                }
            }
            return null;
        }

        @Override
        public String toString() {
            return String.valueOf(c);
        }
    }

    /** A prefix tree of words, used to test whether a string splits into stored words. */
    static class Trie {
        final TrieNode root = new TrieNode('\0');

        /**
         * Adds {@code word} to the trie. Inserting the same word twice has no further effect.
         *
         * @param word the word to insert
         */
        public void insert(String word) {
            TrieNode curr = root;
            for (int i = 0; i < word.length(); i++) {
                char ch = word.charAt(i);
                TrieNode next = curr.child(ch);
                if (next == null) {
                    next = new TrieNode(ch);
                    curr.children.add(next);
                }
                curr = next;
            }
            curr.isWord = true;
        }

        /**
         * Tests whether {@code word} can be split into one or more non-empty words stored in this
         * trie such that {@code wordCount} plus the number of pieces is greater than 1. An empty
         * {@code word} yields {@code wordCount > 1}.
         *
         * <p>Callers pass {@code wordCount = 0} to ask whether the word is made of at least two
         * stored words.
         *
         * <p>Implemented as an iterative dynamic program rather than backtracking recursion. This
         * bounds the running time and avoids stack overflow on words made of many short pieces.
         *
         * @param word the candidate word
         * @param wordCount number of pieces already counted before {@code word}
         * @return whether a qualifying split exists
         */
        public boolean isConcatenated(String word, int wordCount) {
            int n = word.length();
            if (n == 0) {
                return wordCount > 1;
            }
            // maxPieces[i] = largest number of stored words that exactly cover word[0, i), or -1
            // if the prefix cannot be covered at all.
            int[] maxPieces = new int[n + 1];
            Arrays.fill(maxPieces, -1);
            maxPieces[0] = 0;
            for (int start = 0; start < n; start++) {
                if (maxPieces[start] < 0) {
                    continue; // no split reaches this position
                }
                TrieNode curr = root;
                for (int end = start; end < n; end++) {
                    curr = curr.child(word.charAt(end));
                    if (curr == null) {
                        break; // no stored word continues with this character
                    }
                    if (curr.isWord) {
                        maxPieces[end + 1] = Math.max(maxPieces[end + 1], maxPieces[start] + 1);
                    }
                }
            }
            // Widen to long so an extreme wordCount cannot overflow the comparison.
            return maxPieces[n] >= 0 && (long) wordCount + maxPieces[n] > 1;
        }
    }

    /**
     * Returns all words in {@code words} that are concatenations of at least two non-empty words
     * from {@code words}.
     *
     * @param words the dictionary; it is not modified
     * @return the concatenated words, in ascending length order (ties keep their input order);
     *     an empty list if {@code words} is {@code null} or empty
     */
    public List<String> findAllConcatenatedWordsInADict(String[] words) {
        List<String> result = new ArrayList<>();
        if (words == null) {
            return result;
        }
        // Sort a copy so the caller's array is left untouched. Shorter words come first, so every
        // possible piece is already in the trie when a word is checked.
        String[] sorted = words.clone();
        Arrays.sort(sorted, (a, b) -> Integer.compare(a.length(), b.length()));
        Trie trie = new Trie();
        for (String word : sorted) {
            // Check before inserting: the word's own trie path cannot then lengthen the walks.
            // isConcatenated(word, 0) still requires at least two pieces, so a duplicate entry
            // never counts as a concatenation of itself.
            if (trie.isConcatenated(word, 0)) {
                result.add(word);
            }
            trie.insert(word);
        }
        return result;
    }
}