文章目录
-
-
- 背景
- 扩展
- AC自动机
-
背景
最近参与了某业务系统的开发, 需要根据城市的名字简称,找到其官方的完整名称。比如云南的
大理
,其实其完整的名称是
大理白族自治州
。可以参考官方的行政区划,点这里。
通常来说,城市的简称,都是其完整名称的前缀。所以任务就转化成了:根据前缀,在一堆字符串中,找出满足条件的字符串。
Trie树可以派上用场,只需要对全国所有城市的完整名称,建一颗Trie树即可。这种前缀补全的功能,也有其他的一些经典应用,比如在命令行下,输入一个命令的前缀,或文件名的前缀,敲下
Tab
,能够进行自动补全,也是利用了Trie树这种数据结构。
Trie树的原理参考我之前的这篇文章
代码如下
import java.util.*;
/**
* @Author yogurtzzz
* @Date 2022/3/31 15:57
*
* 利用 Trie 树实现字符串的存储和快速查找
*
* 该类的功能是, 给定一个字符串数组 list , 给定一个字符串 s
* 返回在 list 中所有前缀为 s 的字符串
*
* 简单地说, 该类的功能是做字符串前缀匹配, 或者自动补全
*
* 比如, 输入 "内蒙", 能够查找到完整的名称为 "内蒙古自治区"
**/
public class TrieTree {
private Node root = new Node('0');
/**
* 添加一个字符串到 [集合] 中
* **/
public void addString(String s) {
Node cur = root;
for (int i = 0; i < s.length(); i++) {
char c = s.charAt(i);
cur = cur.getOrCreateSon(c);
}
cur.mark = true; // 在末尾打上标记
}
/**
* 返回 [集合] 中所有以 @param prefix 为前缀的字符串
* **/
public List<String> findStringFuzzy(String prefix) {
List<String> ans = new ArrayList<>();
char[] chars = prefix.toCharArray();
find(root, chars, 0, ans);
return ans;
}
/**
* 查找
* **/
private void find(Node cur, char[] prefix, int index, List<String> ans) {
int n = prefix.length;
if (index == n && cur != null) dfsAns(cur, prefix, ans); // 前缀匹配完成, 开始搜寻答案
else {
if (cur == null || cur.son == null || !cur.son.containsKey(prefix[index])) return;
find(cur.son.get(prefix[index]), prefix, index + 1, ans);
}
}
/**
* 返回这个节点下所有可能的字符串
* **/
private void dfsAns(Node node, char[] prefix, List<String> ans) {
StringBuilder sb = new StringBuilder(new String(prefix));
if (node.mark) ans.add(sb.toString());
if (node.son != null) {
for (Node s : node.son.values()) dfsFinal(s, sb, ans);
}
}
private void dfsFinal(Node cur, StringBuilder sb, List<String> ans) {
sb.append(cur.c);
if (cur.mark) ans.add(sb.toString());
if (cur.son != null) {
for (Node n : cur.son.values()) dfsFinal(n, sb, ans);
}
sb.deleteCharAt(sb.length() - 1); // 深搜恢复现场
}
static class Node {
char c;
boolean mark; // 标记是否为终点
Map<Character, Node> son; // 子节点
Node(char c) {
this.c = c;
mark = false;
}
void addSon(char c) {
if (son == null) son = new HashMap<>();
son.put(c, new Node(c));
}
Node getOrCreateSon(char c) {
boolean isSon = son != null && son.containsKey(c);
if (!isSon) addSon(c);
return son.get(c);
}
}
/**
* 测试代码
* **/
public static void main(String[] args) {
List<String> list = Arrays.asList("海南市", "海北市", "辽宁哈哈市", "辽宁市", "辽宁哈嘻嘻市", "绵阳市", "绵花市", "绵阳咩咩市", "内蒙古自治区");
TrieTree trie = new TrieTree();
list.forEach(trie::addString);
Scanner scanner = new Scanner(System.in);
String line = null;
while (!(line = scanner.nextLine()).equals("quit")) {
List<String> result = trie.findStringFuzzy(line);
result.forEach(s -> System.out.printf("%s,", s));
System.out.println();
}
}
}
扩展
上面的功能完成后,我回顾了一下,实现的就是一个前缀模糊查找。(查找有固定前缀的字符串)
在 mysql 中就相当于
like 's%'
,但是像后缀
like '%s'
和 中缀
like '%s%'
又要怎么实现呢?
对于后缀模糊查找,容易想到的一个方法是,对所有字符串,按照字符逆序建一颗 Trie 树即可(可以理解为先把字符串反转一下,再建 Trie)
但是对于
%s%
好像就无能为力了。
查了很多的资料,发现一个比较接近的解决方案是:AC自动机。但是这个方案还是不太适用于
%s%
这种场景。
我们知道,在字符串的模式匹配中,一般有两种情形:
- 单模式匹配:给定一个主串
和一个模式串s
,查找p
中s
出现的位置,方法有BF(暴力),KMPp
- 多模式匹配:给定一个主串
和多个模式串,查找s
中出现的模式串都有哪些,经典的应用场景是敏感词过滤。s
但是我们发现,这两种情形都是在一个主串中,根据一个或多个模式串,来查找主串中是否有匹配的部分,且都是精确匹配。
这和我们上面说的情形不太一样。上面的情形是:给定一堆字符串(主串),再给一个模式串(带通配符),在这一堆主串中,查找出所有满足模糊匹配条件的串。
其实是两个不太一样的问题。
由于对于
%s%
这样中缀模糊查找,仍然没找到解决方案,于是便暂时搁置,先研究一下利用AC自动机实现敏感词过滤,也挺有意思的。(有朋友说去研究一下ElasticSearch中的倒排索引的原理,手写一下,就能解决
%s%
这个问题,这个留在之后再去做了)
AC自动机
对于在一个主串中查找多个模式串,一个简单粗暴的做法是,将每个模式串拿出来,单独和主串用KMP做匹配。由于KMP的时间复杂度是 O ( m + n ) O(m+n) O(m+n) 。
关于KMP算法的原理,参考我之前的这篇文章
关于KMP的时间复杂度如何计算,参考思否的这篇文章。
关于KMP在回溯指针
j
时,如何保证不漏掉正确答案?-> 用反证法证明即可,参考这篇文章
大概是,在KMP的匹配过程中,当匹配到
i
位置时,若下一个位置不匹配,则
j
最多回退
i-1
次。极端的例子是:T=“aaaabaaaab”,P=“aaaaa”。则在对主串进行一次遍历时(假设主串长度为
m
),则最多会遍历
2m
次(实际,上界到不了
2m
),而我们构造
next
数组时,是对模式串进行了一次同样的遍历匹配操作(假设模式串长度为
n
),那么构造
next
数组时最多要
2n
次,加到一起,就是
2m + 2n
,所以复杂度是 O ( m + n ) O(m+n) O(m+n)。
那么如果对所有模式串,依次与主串做一次KMP,则开销是非常大的。假设主串长度为
m
,模式串平均长度为
n
,模式串个数为
N
。容易算得,这种方式的时间复杂度是 O ( N × ( m + n ) ) O(N×(m+n)) O(N×(m+n))。
这样肯定是不行的,那我们来看看用朴素的 Trie 树可以怎么做?我们先对多个模式串,建一颗 Trie 树,然后从主串
s
的第一个位置开始,在 Trie 树中进行查找,查找完毕后(匹配到或者没匹配到),则把主串中的指针
i
往后移一位,从第二个位置开始查找,再从第三个位置进行查找。如此以来就能找到
s
中包含的全部模式串。假设模式串的平均长度为
n
,主串长度为
m
,最坏情况下,在主串的每一个位置都要进行
n
次匹配,则时间复杂度是 O ( m × n ) O(m × n) O(m×n)
这种效率也无法达到我们的需求。
此时轮到AC自动机登场了。
多模式匹配中,朴素的 Trie 和 AC自动机的关系;就和单模式匹配中,BF法 和 KMP 的关系一样。
AC自动机仅仅是在 Trie 树上运用了 KMP 的思想,增加了一个类似
next
数组的东西,叫做
fail
指针。
整个匹配的过程,主串只会被扫描一次,而不会不断回退,当在主串的某个位置匹配失败后,根据 Trie 树上该节点的父节点的
fail
指针,找到 Trie 树上下一个要匹配的节点,继续进行匹配即可。
可以这样简单的理解
fail
指针,假设Trie树上的一个节点
p
,其到根节点
root
,构成的字符串为
abc
,节点
p
的
fail
指针,指向节点
q
,而节点
q
到
root
节点,构成的字符串是
bc
,这两个节点代表的字符串,存在一个公共的前后缀(这和KMP中next数组的含义一样)。
在AC自动机匹配时,假设主串为
abcd
,则在匹配第四个位置
d
时,Trie树上是到节点
pc
,此时匹配失败,此时查看
pc
节点的父节点
p
的
fail
指针为
q
,则继续查看
q
的子节点中有没有字符为
d
的,发现有,则成功匹配到模式串
bcd
。
(图片来源极客时间)
由于每个节点的
fail
指针,都一定依赖于其上层的节点,那么我们构造
fail
指针的时候,需要从根节点往下,一层一层构造,所以需要使用层序遍历BFS。根节点
root
的
fail
指针为
null
,第一层节点的
fail
指针为
root
。因为第一个位置不存在公共前后缀,需要从头开始匹配。(因为公共前后缀的长度一定要小于当前长度,才能构成公共前后缀,至少从第二个位置开始,才可能存在公共前后缀)。
比如
a
,是没有公共前后缀的,
aa
的公共前后缀长度为1。
我们用BFS,每次处理一个节点,并为这个节点的所有子节点,填充
fail
指针。
假设当前节点为
p
,它有一个子节点
s
,那么填充
s
的
fail
指针的过程如下:
- 取
,看这个节点的子节点中,有没有和p.fail
节点字符相同的,若有,假设为s
,则填充s'
s.fail = s’
- 若
的子节点中,没有和p.fail
字符相同的节点,则更新s
,继续上面的判断,直到p = p.fail
。p = null
- 若
,说明在p = null
节点的子节点中,都没有发现和root
节点字符相同的节点(才会在s
时将p = p.fail
更新为p
),那么在null
节点就不存在公共前后缀,需要从头匹配,所以此时s
s.fail = root
代码如下:
import java.util.*;
/**
* @Author yogurtzzz
* @Date 2022/4/26 10:30
*
* AC 自动机
**/
public class AcMachine {
private AcNode root;
/**
* 用一个敏感词集合构建一个 AC 自动机
* **/
public AcMachine(Set<String> words) {
root = new AcNode('0');
buildTrieTree(words); // 先把 Trie 树建起来
fillFailPointer(); // 再填充每个节点的 fail 指针
}
/**
* 进行敏感词过滤
* @param s 原字符串
* @return 脱敏后的字符串
* **/
public String filterWithSensitiveWords(String s) {
char[] cs = s.toCharArray();
List<Integer> begins = new ArrayList<>();
List<Integer> lens = new ArrayList<>();
// 开始查找
AcNode cur = root;
for (int j = 0; j < cs.length; j++) {
char c = cs[j];
if (cur.hasChild(c)) {
cur = cur.getOrCreateChild(c);
if (cur.len != -1) {
// 该节点为结束节点
lens.add(cur.len);
begins.add(j - cur.len + 1);
}
} else {
while (cur.fail != null) {
cur = cur.fail;
if (cur.hasChild(c)) {
cur = cur.getOrCreateChild(c);
if (cur.len != -1) {
lens.add(cur.len);
begins.add(j - cur.len + 1);
}
break;
}
}
}
}
// 查找出所有敏感词出现的起始位置和长度后, 对原字符串进行敏感词屏蔽
StringBuilder sb = new StringBuilder();
int i = 0;
for (int j = 0; j < begins.size(); j++) {
int begin = begins.get(j);
int len = lens.get(j);
for (; i < begin; i++) sb.append(cs[i]);
for (; i < begin + len; i++) sb.append('*');
}
while (i < cs.length) {
sb.append(cs[i]);
i++;
}
return sb.toString();
}
private void buildTrieTree(Set<String> words) {
words.forEach(this::addWord);
}
private void addWord(String word) {
AcNode p = root;
char[] cs = word.toCharArray();
for (char c : cs) {
p = p.getOrCreateChild(c);
}
p.len = cs.length;
}
private void fillFailPointer() {
// BFS 层序遍历填充 fail 指针
Queue<AcNode> q = new LinkedList<>();
q.offer(root);
while (!q.isEmpty()) {
AcNode x = q.poll();
if (x.children == null) continue;
// 处理当前节点的子节点的fail指针
x.children.values().forEach(node -> {
q.offer(node); // 加入队列
if (x == root) node.fail = root;
else {
AcNode last = x.fail;
while (last != null) {
if (last.hasChild(node.c)) {
node.fail = last.children.get(node.c);
break;
} else {
last = last.fail;
}
}
if (last == null) node.fail = root;
}
});
}
}
private static class AcNode {
private char c;
private int len = -1; // 如果是结尾节点, 记录长度
private Map<Character, AcNode> children;
private AcNode fail;
AcNode(char c) {
this.c = c;
}
AcNode getOrCreateChild(char c) {
if (children == null) children = new HashMap<>();
if (children.containsKey(c)) return children.get(c);
AcNode newNode = new AcNode(c);
children.put(c, newNode);
return newNode;
}
boolean hasChild(char c) {
return children != null && children.containsKey(c);
}
}
public static void main(String[] args) {
testEn();
System.out.println();
testZh();
}
private static void testEn() {
Set<String> wordSet = new HashSet<>(Arrays.asList("she", "bleed", "dog", "doggy", "hurt"));
AcMachine acMachine = new AcMachine(wordSet);
String s = "A girl is bitten by a doggy, and she is badly hurt, bleeding along the road";
System.out.printf("sensitive words : ");
for (String x : wordSet) System.out.printf("%s ", x);
System.out.println();
System.out.println(s);
String filteredString = acMachine.filterWithSensitiveWords(s);
System.out.println(filteredString);
}
private static void testZh() {
Set<String> wordSet = new HashSet<>(Arrays.asList("血", "政府", "暴力"));
AcMachine acMachine = new AcMachine(wordSet);
System.out.printf("sensitive words : ");
for (String x : wordSet) System.out.printf("%s ", x);
System.out.println();
String s = "美国政府呼吁民众不要暴力, 因为暴力会造成流血";
System.out.println(s);
System.out.println(acMachine.filterWithSensitiveWords(s));
}
}
关于AC自动机的原理,参考极客时间的这篇文章,以及这篇文章