Number of Restricted Paths From First to Last Node

There is an undirected, weighted, connected graph. You are given a positive integer n, meaning the graph has n nodes labeled from 1 to n, and an array edges where each edges[i] = [u_i, v_i, weight_i] denotes an edge between nodes u_i and v_i with weight weight_i.

A path from node start to node end is a sequence of nodes [z_0, z_1, z_2, ..., z_k] such that z_0 = start, z_k = end, and there is an edge between z_i and z_{i+1} for every 0 <= i <= k-1.

The distance of a path is the sum of the weights on its edges. Let distanceToLastNode(x) denote the shortest distance of a path between node n and node x. A restricted path is a path that also satisfies distanceToLastNode(z_i) > distanceToLastNode(z_{i+1}) for every 0 <= i <= k-1.
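
The restricted condition only compares consecutive nodes, so it can be checked in a single pass. A minimal Java sketch, assuming the distanceToLastNode values are already stored in an array dist indexed by node label; the names RestrictedPathCheck and isRestricted are hypothetical, and the helper does not verify that the edges actually exist:

```java
// Hypothetical helper: checks only the "restricted" condition of a path.
// Assumes dist[x] already holds distanceToLastNode(x) for 1-indexed nodes,
// and that consecutive nodes in `path` are connected by an edge.
class RestrictedPathCheck {
    static boolean isRestricted(int[] path, long[] dist) {
        // Each step must move strictly closer to node n.
        for (int i = 0; i + 1 < path.length; i++) {
            if (dist[path[i]] <= dist[path[i + 1]]) {
                return false;
            }
        }
        return true;
    }
}
```

For instance, with distances 4, 2, 1, 6, 0 for nodes 1 to 5, the path 1 --> 2 --> 5 passes, while 1 --> 4 --> 5 fails because node 4 is farther from node n than node 1.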

Return the number of restricted paths from node 1 to node n. Since that number may be very large, return it modulo 10^9 + 7.
Example 1:

Input: n = 5, edges = [[1,2,3],[1,3,3],[2,3,1],[1,4,2],[5,2,2],[3,5,1],[5,4,10]]
Output: 3
Explanation: The distanceToLastNode values for nodes 1 through 5 are 4, 2, 1, 6 and 0. The three restricted paths are:
1) 1 --> 2 --> 5
2) 1 --> 2 --> 3 --> 5
3) 1 --> 3 --> 5
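
The distanceToLastNode values in this example come from one single-source shortest-path run started at node n. A sketch using Dijkstra's algorithm with a binary heap, assuming all weights are positive; the name DistanceToLastNode.compute is hypothetical:

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.PriorityQueue;

// Sketch: Dijkstra from node n, so dist[x] = distanceToLastNode(x).
class DistanceToLastNode {
    static long[] compute(int n, int[][] edges) {
        // Adjacency list indexed by node label 1..n; entries are {neighbor, weight}.
        List<List<int[]>> graph = new ArrayList<>();
        for (int i = 0; i <= n; i++) graph.add(new ArrayList<>());
        for (int[] e : edges) {
            graph.get(e[0]).add(new int[]{e[1], e[2]});
            graph.get(e[1]).add(new int[]{e[0], e[2]});
        }
        long[] dist = new long[n + 1];
        Arrays.fill(dist, Long.MAX_VALUE);
        dist[n] = 0;
        // Heap entries are {distance, node}, ordered by distance.
        PriorityQueue<long[]> pq = new PriorityQueue<>((a, b) -> Long.compare(a[0], b[0]));
        pq.offer(new long[]{0, n});
        while (!pq.isEmpty()) {
            long[] top = pq.poll();
            long d = top[0];
            int u = (int) top[1];
            if (d > dist[u]) continue; // stale entry: u was already settled with a shorter distance
            for (int[] edge : graph.get(u)) {
                int v = edge[0];
                long nd = d + edge[1];
                if (nd < dist[v]) {
                    dist[v] = nd;
                    pq.offer(new long[]{nd, v});
                }
            }
        }
        return dist;
    }
}
```

On the Example 1 input this yields 4, 2, 1, 6, 0 for nodes 1 to 5. Distances are kept as long here to sidestep any overflow concerns.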

Example 2:

Input: n = 7, edges = [[1,3,1],[4,1,2],[7,3,4],[2,5,3],[5,6,1],[6,7,2],[7,5,3],[2,6,4]]
Output: 1
Explanation: The distanceToLastNode values for nodes 1 through 7 are 5, 6, 4, 7, 3, 2 and 0. The only restricted path is 1 --> 3 --> 7.
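
For small inputs such as this one, a brute-force count is a useful cross-check. The sketch below is a hypothetical helper, not the intended solution: it computes distances with Bellman-Ford relaxation and then enumerates restricted paths with an unmemoized DFS, which terminates because distanceToLastNode strictly decreases at every step.

```java
import java.util.Arrays;

// Hypothetical brute-force checker for tiny graphs only.
class BruteForceRestrictedPaths {
    static long count(int n, int[][] edges) {
        long inf = Long.MAX_VALUE / 4; // large enough, and inf + weight cannot overflow
        long[] dist = new long[n + 1];
        Arrays.fill(dist, inf);
        dist[n] = 0;
        // Bellman-Ford: n - 1 rounds, relaxing each undirected edge in both directions.
        for (int round = 0; round < n - 1; round++) {
            for (int[] e : edges) {
                dist[e[0]] = Math.min(dist[e[0]], dist[e[1]] + e[2]);
                dist[e[1]] = Math.min(dist[e[1]], dist[e[0]] + e[2]);
            }
        }
        return dfs(1, n, edges, dist);
    }

    // Counts restricted paths from u to n by scanning every edge incident to u.
    private static long dfs(int u, int n, int[][] edges, long[] dist) {
        if (u == n) return 1;
        long total = 0;
        for (int[] e : edges) {
            int v = e[0] == u ? e[1] : (e[1] == u ? e[0] : -1);
            if (v != -1 && dist[v] < dist[u]) {
                total += dfs(v, n, edges, dist);
            }
        }
        return total;
    }
}
```

It returns 1 on Example 2 and 3 on Example 1, but its running time grows with the number of paths, so it only suits very small graphs.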
Constraints:


Solution:

import java.util.*;

class Solution {
    private static final int MOD = 1_000_000_007;
    // memo.get(x) caches the number of restricted paths from x to n (mod MOD).
    private Map<Integer, Long> memo = new HashMap<>();

    public int countRestrictedPaths(int n, int[][] edges) {
        memo = new HashMap<>(); // reset in case the same Solution instance is reused
        // Undirected adjacency list: node -> list of {neighbor, weight}.
        Map<Integer, List<int[]>> graph = new HashMap<>();
        for (int[] edge : edges) {
            graph.computeIfAbsent(edge[0], k -> new ArrayList<>()).add(new int[]{edge[1], edge[2]});
            graph.computeIfAbsent(edge[1], k -> new ArrayList<>()).add(new int[]{edge[0], edge[2]});
        }

        // Dijkstra from node n: dist[x] = distanceToLastNode(x).
        // Queue entries are {node, distance}, ordered by distance.
        PriorityQueue<int[]> queue = new PriorityQueue<>((a, b) -> Integer.compare(a[1], b[1]));
        queue.offer(new int[]{n, 0});
        int[] dist = new int[n + 1];
        Arrays.fill(dist, Integer.MAX_VALUE);
        dist[n] = 0;
        while (!queue.isEmpty()) {
            int[] curr = queue.poll();
            int node = curr[0], distance = curr[1];
            if (distance > dist[node]) continue; // stale entry: node already settled with a shorter distance
            for (int[] nei : graph.getOrDefault(node, Collections.emptyList())) {
                int next = nei[0], w = nei[1];
                if (dist[next] > distance + w) {
                    dist[next] = distance + w;
                    queue.offer(new int[]{next, distance + w});
                }
            }
        }
        return (int) (dfs(1, graph, n, dist) % MOD);
    }

    // Counts restricted paths from curr to n, stepping only to neighbors strictly closer to n.
    private long dfs(int curr, Map<Integer, List<int[]>> graph, int n, int[] dist) {
        if (curr == n) return 1;
        Long cached = memo.get(curr);
        if (cached != null) return cached;
        long res = 0;
        for (int[] nei : graph.getOrDefault(curr, Collections.emptyList())) {
            int next = nei[0];
            if (dist[next] < dist[curr]) {
                res = (res + dfs(next, graph, n, dist)) % MOD;
            }
        }
        memo.put(curr, res);
        return res;
    }
}
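
The solution above recurses once per node along a restricted path, so a long chain of nodes means an equally deep call stack. An alternative sketch (a swapped-in bottom-up formulation, not the author's code) fills the counts iteratively: after Dijkstra, nodes are sorted by distanceToLastNode and each node sums the counts of its strictly closer neighbors. The class name RestrictedPathsBottomUp is hypothetical.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.PriorityQueue;

// Alternative sketch: Dijkstra from n, then a bottom-up count in increasing distance order.
class RestrictedPathsBottomUp {
    static final int MOD = 1_000_000_007;

    static int countRestrictedPaths(int n, int[][] edges) {
        List<List<int[]>> graph = new ArrayList<>();
        for (int i = 0; i <= n; i++) graph.add(new ArrayList<>());
        for (int[] e : edges) {
            graph.get(e[0]).add(new int[]{e[1], e[2]});
            graph.get(e[1]).add(new int[]{e[0], e[2]});
        }

        // Dijkstra from node n: dist[x] = distanceToLastNode(x).
        long[] dist = new long[n + 1];
        Arrays.fill(dist, Long.MAX_VALUE);
        dist[n] = 0;
        PriorityQueue<long[]> pq = new PriorityQueue<>((a, b) -> Long.compare(a[0], b[0]));
        pq.offer(new long[]{0, n});
        while (!pq.isEmpty()) {
            long[] top = pq.poll();
            int u = (int) top[1];
            if (top[0] > dist[u]) continue; // skip stale heap entries
            for (int[] edge : graph.get(u)) {
                int v = edge[0];
                if (dist[u] + edge[1] < dist[v]) {
                    dist[v] = dist[u] + edge[1];
                    pq.offer(new long[]{dist[v], v});
                }
            }
        }

        // Visit nodes from closest to farthest: every strictly closer neighbor
        // already holds its final count when a node is processed.
        Integer[] order = new Integer[n];
        for (int i = 0; i < n; i++) order[i] = i + 1;
        Arrays.sort(order, (a, b) -> Long.compare(dist[a], dist[b]));

        long[] ways = new long[n + 1];
        ways[n] = 1;
        for (int u : order) {
            if (u == n) continue;
            for (int[] edge : graph.get(u)) {
                int v = edge[0];
                if (dist[v] < dist[u]) {
                    ways[u] = (ways[u] + ways[v]) % MOD;
                }
            }
        }
        return (int) ways[1];
    }
}
```

Sorting adds O(n log n) on top of Dijkstra's O((n + m) log n). Nodes with equal distance never contribute to each other, because the comparison is strict, so their relative order after sorting does not matter.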