Palindrome Pairs

Given a list of unique words, find all pairs of distinct indices (i, j) such that the concatenation of the two words, i.e. words[i] + words[j], is a palindrome.

Example 1:

Input: ["abcd","dcba","lls","s","sssll"]
Output: [[0,1],[1,0],[3,2],[2,4]] 
Explanation: The palindromes are ["dcbaabcd","abcddcba","slls","llssssll"]

Example 2:

Input: ["bat","tab","cat"]
Output: [[0,1],[1,0]] 
Explanation: The palindromes are ["battab","tabbat"]

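Solutions (two independent alternatives follow, each a standalone `Solution` class: first a trie of reversed words, then a hash map from word to index):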

http://www.allenlipeng47.com/blog/index.php/2016/03/15/palindrome-pairs/

class Solution {
    /**
     * Node of a trie built over the REVERSED words.
     *
     * <ul>
     *   <li>{@code pos}: index of the word whose reversal ends exactly at this node, or -1.</li>
     *   <li>{@code children}: one slot per letter 'a'..'z', indexed by {@code c - 'a'}.</li>
     *   <li>{@code palins}: indices of words whose reversal passes through or ends at this node
     *       and whose remaining, not-yet-consumed prefix is a palindrome (the empty prefix at
     *       the terminal node counts).</li>
     * </ul>
     */
    static class Trie {
        int pos;
        Trie[] children;
        List<Integer> palins;

        public Trie() {
            pos = -1;
            children = new Trie[26];
            palins = new ArrayList<>();
        }
    }

    /**
     * Inserts {@code word} (stored under index {@code pos}) into the trie, last letter first.
     * Before consuming {@code word.charAt(i)}, records {@code pos} in the current node's
     * {@code palins} if {@code word[0..i]} is a palindrome. The terminal node gets
     * {@code pos} both as its {@code pos} and in its {@code palins}.
     *
     * <p>Assumes only lowercase 'a'..'z' characters; anything else indexes outside
     * {@code children} and throws {@link ArrayIndexOutOfBoundsException}.
     */
    private void add(Trie root, String word, int pos) {
        Trie node = root;
        for (int i = word.length() - 1; i >= 0; i--) {
            if (isPalindrome(word, 0, i)) {
                node.palins.add(pos);
            }
            int slot = word.charAt(i) - 'a';
            if (node.children[slot] == null) {
                node.children[slot] = new Trie();
            }
            node = node.children[slot];
        }
        node.pos = pos;
        node.palins.add(pos);
    }

    /**
     * Appends to {@code ans} every pair {@code [i, k]} (k != i) for which
     * {@code words[i] + words[k]} is a palindrome.
     */
    private void search(Trie root, String[] words, int i, List<List<Integer>> ans) {
        String word = words[i];
        int len = word.length();
        Trie node = root;
        for (int j = 0; j < len && node != null; j++) {
            // A shorter word k ends here: reverse(words[k]) == word[0..j-1]. If the rest of
            // word (word[j..]) is a palindrome, word + words[k] is a palindrome.
            if (node.pos >= 0 && node.pos != i && isPalindrome(word, j, len - 1)) {
                ans.add(Arrays.asList(i, node.pos));
            }
            node = node.children[word.charAt(j) - 'a'];
        }
        // word is fully matched. Every word k listed here ends with reverse(word) and has a
        // palindromic leftover prefix, so word + words[k] is a palindrome.
        if (node != null) {
            for (int k : node.palins) {
                if (k != i) {
                    ans.add(Arrays.asList(i, k));
                }
            }
        }
    }

    /**
     * Returns true if {@code str[i..j]} (inclusive) reads the same in both directions.
     * Empty or single-character ranges ({@code i >= j}) are palindromes.
     */
    public boolean isPalindrome(String str, int i, int j) {
        while (i < j) {
            if (str.charAt(i++) != str.charAt(j--)) {
                return false;
            }
        }
        return true;
    }

    /**
     * Finds all pairs of distinct indices {@code [i, j]} such that
     * {@code words[i] + words[j]} is a palindrome.
     *
     * @param words unique words (lowercase 'a'..'z' only, see {@link #add})
     * @return list of {@code [i, j]} pairs, each reported once
     */
    public List<List<Integer>> palindromePairs(String[] words) {
        List<List<Integer>> ans = new ArrayList<>();
        Trie trie = new Trie();
        for (int i = 0; i < words.length; i++) {
            add(trie, words[i], i);
        }
        for (int i = 0; i < words.length; i++) {
            search(trie, words, i, ans);
        }
        return ans;
    }
}

class Solution {
    /**
     * Finds all pairs of distinct indices {@code [i, j]} such that
     * {@code words[i] + words[j]} is a palindrome, using a word-to-index hash map.
     *
     * <p>Every split point of every word {@code w = left + right} is tried:
     * <ul>
     *   <li>If {@code left} is a palindrome, then {@code reverse(right) + w} is a palindrome.
     *       A word equal to {@code reverse(right)} goes first.</li>
     *   <li>If {@code right} is a non-empty palindrome, then {@code w + reverse(left)} is a
     *       palindrome. A word equal to {@code reverse(left)} goes second.</li>
     * </ul>
     *
     * @param words unique words
     * @return list of {@code [i, j]} pairs, each reported once
     */
    public List<List<Integer>> palindromePairs(String[] words) {
        Map<String, Integer> index = new HashMap<>();
        for (int i = 0; i < words.length; i++) {
            index.put(words[i], i);
        }
        List<List<Integer>> result = new ArrayList<>();
        for (int i = 0; i < words.length; i++) {
            String w = words[i];
            for (int j = 0; j <= w.length(); j++) {
                String left = w.substring(0, j);
                String right = w.substring(j);

                // left is the palindromic middle; the partner reverse(right) goes before w.
                if (isP(left)) {
                    Integer k = index.get(reverse(right));
                    if (k != null && k != i) {
                        result.add(Arrays.asList(k, i));
                    }
                }
                // right is the palindromic middle; the partner reverse(left) goes after w.
                // right must be non-empty. The whole-word split (right == "") would report
                // [i, k] a second time, because that pair is already found when word k is
                // processed at split j == 0 in the branch above.
                if (!right.isEmpty() && isP(right)) {
                    Integer k = index.get(reverse(left));
                    if (k != null && k != i) {
                        result.add(Arrays.asList(i, k));
                    }
                }
            }
        }
        return result;
    }

    /** Returns {@code s} with its characters in reverse order. */
    private static String reverse(String s) {
        return new StringBuilder(s).reverse().toString();
    }

    /** Returns true if {@code s} reads the same in both directions ("" counts as a palindrome). */
    private boolean isP(String s) {
        int left = 0, right = s.length() - 1;
        while (left < right) {
            if (s.charAt(left++) != s.charAt(right--)) {
                return false;
            }
        }
        return true;
    }
}