Table of Contents

一、BFS 概念

摘自:BFS (图论)

BFS 全称是 Breadth First Search,中文名是宽度优先搜索,也叫广度优先搜索,是图上最基础、最重要的搜索算法之一。

所谓宽度优先。就是每次都尝试访问同一层的节点。 如果同一层都访问完了,再访问下一层。

这样做的结果是,BFS 算法找到的路径是从起点开始的 最短 合法路径。换言之,这条路径所包含的边数最小。

在 BFS 结束时,每个节点都是通过从起点到该点的最短路径访问的。

算法过程可以看做是图上火苗传播的过程:最开始只有起点着火了,在每一时刻,有火的节点都向它相邻的所有节点传播火苗。

对于一般的BFS 问题,BFS 的核心思想应该不难理解的,就是把一些问题抽象成图,从一个点开始,向四周开始扩散。一般来说,我们写 BFS 算法都是用「队列」这种数据结构,每次将一个节点周围的所有节点加入队列。

二、BFS 算法框架

要说框架的话,我们先举例一下 BFS 出现的常见场景好吧,问题的本质就是让你在一幅「图」中找到从起点 start 到终点 target 的最近距离

框架如下:

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
// 计算从起点 start 到终点 target 的最近距离
int BFS(Node start, Node target) {
    Queue<Node> q; // 核心数据结构
    Set<Node> visited; // 避免走回头路
    
    q.offer(start); // 将起点加入队列
    visited.add(start);
    int step = 0; // 记录扩散的步数

    while (q not empty) {
        int sz = q.size();
        /* 将当前队列中的所有节点向四周扩散 */
        for (int i = 0; i < sz; i++) {
            Node cur = q.poll();
            /* 划重点:这里判断是否到达终点 */
            if (cur is target)
                return step;
            /* 将 cur 的相邻节点加入队列 */
            for (Node x : cur.adj())
                if (x not in visited) {
                    q.offer(x);
                    visited.add(x);
                }
        }
        /* 划重点:更新步数在这里 */
        step++;
    }
}

三、例题

752. 打开转盘锁

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
class Solution {
    public int openLock(String[] deadends, String target) {
        Queue<String> q = new ArrayDeque<String>();
        HashSet<String> set = new HashSet<>();
        HashSet<String> passwordSet = new HashSet<>();
        for (String s : deadends) {
            set.add(s);
        }
        q.add("0000");
        passwordSet.add("0000");
        int res = 0;
        while(!q.isEmpty()) {
            int sz = q.size();
            for (int i = 0; i < sz; i++) {
                String str = q.poll();
                if (set.contains(str))
                    continue;
                if (str.equals(target))
                    return res;
                for (int j = 0; j < 8; j++) {
                    String temp;
                    if (j % 2 == 0) {
                        temp = minusOne(str, j/2);
                    } else {
                        temp = plusOne(str, j/2);
                    }
                    if (!passwordSet.contains(temp)) {
                        passwordSet.add(temp);
                        q.add(temp);
                    }
                }
            }
            res++;
        }
        return -1;
    }

    private String minusOne(String str, int i) {
        char[] ch = str.toCharArray();
        if (ch[i] == '9')
            ch[i] = '0';
        else
            ch[i] += 1;
        return new String(ch);
    }

    private String plusOne(String str, int i) {
        char[] ch = str.toCharArray();
        if (ch[i] == '0')
            ch[i] = '9';
        else
            ch[i] -= 1;
        return new String(ch);
    }
}

双向BFS

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
int openLock(String[] deadends, String target) {
    Set<String> deads = new HashSet<>();
    for (String s : deadends) deads.add(s);
    // 用集合不用队列,可以快速判断元素是否存在
    Set<String> q1 = new HashSet<>();
    Set<String> q2 = new HashSet<>();
    Set<String> visited = new HashSet<>();
    
    int step = 0;
    q1.add("0000");
    q2.add(target);
    
    while (!q1.isEmpty() && !q2.isEmpty()) {
        // 哈希集合在遍历的过程中不能修改,用 temp 存储扩散结果
        Set<String> temp = new HashSet<>();

        /* 将 q1 中的所有节点向周围扩散 */
        for (String cur : q1) {
            /* 判断是否到达终点 */
            if (deads.contains(cur))
                continue;
            if (q2.contains(cur))
                return step;
            visited.add(cur);

            /* 将一个节点的未遍历相邻节点加入集合 */
            for (int j = 0; j < 4; j++) {
                String up = plusOne(cur, j);
                if (!visited.contains(up))
                    temp.add(up);
                String down = minusOne(cur, j);
                if (!visited.contains(down))
                    temp.add(down);
            }
        }
        /* 在这里增加步数 */
        step++;
        // temp 相当于 q1
        // 这里交换 q1 q2,下一轮 while 就是扩散 q2
        q1 = q2;
        q2 = temp;
    }
    return -1;
}

1091. 二进制矩阵中的最短路径

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
class Solution {
    private HashSet<Integer> set = new HashSet<>();
    private ArrayDeque<Integer> qOfx = new ArrayDeque<>();
    private ArrayDeque<Integer> qOfy = new ArrayDeque<>();
    public int shortestPathBinaryMatrix(int[][] grid) {
        if (grid == null || grid.length == 0) {
            return -1;
        }
        int m = grid.length;
        int n = grid[0].length;
        int res = 0;
        qOfx.add(0);
        qOfy.add(0);
        set.add(0);
        while (!qOfx.isEmpty()) {
            int sz = qOfx.size();
            res++;
            for (int i = 0; i < sz; i++) {
                int x = qOfx.poll();
                int y = qOfy.poll();
                if (grid[x][y] == 1)
                    continue;
                if (x == m-1 && y == n-1)
                    return res;
                search(x-1, y-1, m, n, grid);
                search(x-1, y, m, n, grid);
                search(x-1, y+1, m, n, grid);
                search(x, y-1, m, n, grid);
                search(x, y+1, m, n, grid);
                search(x+1, y-1, m, n, grid);
                search(x+1, y, m, n, grid);
                search(x+1, y+1, m, n, grid);
            }
        }
        return -1;
    }

    private void search(int x, int y, int m, int n, int[][] grid) {
        if (x >= 0 && y >= 0 && x < m && y < n && !set.contains(x*n+y)) {
            qOfx.add(x);
            qOfy.add(y);
            set.add(x*n+y);
        }
    }
}

上面的写法很low

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
public int shortestPathBinaryMatrix(int[][] grids) {
        if (grids == null || grids.length == 0 || grids[0].length == 0) {
            return -1;
        }
        int[][] direction = {{1, -1}, {1, 0}, {1, 1}, {0, -1}, {0, 1}, {-1, -1}, {-1, 0}, {-1, 1}};
        int m = grids.length, n = grids[0].length;
        Queue<Pair<Integer, Integer>> queue = new LinkedList<>();
        queue.add(new Pair<>(0, 0));
        int pathLength = 0;
        while (!queue.isEmpty()) {
            int size = queue.size();
            pathLength++;
            while (size-- > 0) {
                Pair<Integer, Integer> cur = queue.poll();
                int cr = cur.getKey(), cc = cur.getValue();
                if (grids[cr][cc] == 1) {
                    continue;
                }
                if (cr == m - 1 && cc == n - 1) {
                    return pathLength;
                }
                grids[cr][cc] = 1; // 标记
                for (int[] d : direction) {
                    int nr = cr + d[0], nc = cc + d[1];
                    if (nr < 0 || nr >= m || nc < 0 || nc >= n) {
                        continue;
                    }
                    queue.add(new Pair<>(nr, nc));
                }
            }
        }
        return -1;
    }