// Aho-Corasick automaton: multi-pattern string matching.

#include<bits/stdc++.h>
#define ll long long

// Usage:
//      AhoCorasick aho;
//      aho.addString(s, id);   (call once per pattern; id indexes the cnt array)
//      aho.compile();
//      aho.compute(text, cnt); (adapt compute() to the specific problem if needed)
/// Aho-Corasick automaton over the lowercase alphabet 'a'..'z'.
///
/// Patterns are inserted into a trie with addString(); suffix links and
/// transitions are then resolved lazily (and memoised) by getLink()/getGo().
/// compile() simply forces all of them up front.
///
/// Restrictions:
///  * All patterns must be added before the first call to compile(),
///    compute(), getLink() or getGo(); memoised links/transitions are not
///    invalidated by later insertions. Call clear() to start over.
///  * Every node stores the ids of all patterns ending at it (its own plus
///    those inherited through the suffix-link chain), so memory is
///    proportional to the total number of (node, matching pattern) pairs.
struct AhoCorasick
{
    static const int alphabet_size = 26;

    struct Nodie {
        int ch;                     // edge label (0-based letter) from the parent; '$' for the root
        int pa;                     // index of the parent node
        int next[alphabet_size];    // trie children, -1 when absent
        int leaf;                   // number of patterns ending exactly at this node
        int link;                   // suffix link, -1 until computed
        int go[alphabet_size];      // memoised automaton transitions, -1 until computed
        // ==== more variables for specific problem
        std::vector<int> end_here;  // ids of patterns that end at this node
        // ====

        Nodie(int ch_ = '$', int pa_ = 0)
            : ch(ch_), pa(pa_), leaf(0), link(-1) {
            std::fill(std::begin(next), std::end(next), -1);
            std::fill(std::begin(go), std::end(go), -1);
            // ==== init additional variables
            // ====
        }
    };

    std::vector<Nodie> trie;  // trie[0] is the root

    AhoCorasick() {
        trie.emplace_back();
    }

    /// Drops every pattern and resets the automaton to a lone root node.
    void clear() {
        trie.clear();
        trie.emplace_back();
    }

    /// Inserts pattern `s` and tags its terminal node with `s_id`.
    ///
    /// @throws std::invalid_argument if `s` contains a character outside
    ///         'a'..'z'. The trie is left untouched in that case.
    void addString(const std::string& s, int s_id) {
        // Validate first so that a bad pattern cannot leave partial nodes behind.
        for (char c : s) {
            const int idx = c - 'a';
            if (idx < 0 || idx >= alphabet_size)
                throw std::invalid_argument("AhoCorasick::addString: character outside 'a'..'z'");
        }

        int u = 0;  // start at the root
        for (char c : s) {
            const int idx = c - 'a';
            if (trie[u].next[idx] == -1) {
                // Read size() before emplace_back: the new node gets that index.
                trie[u].next[idx] = static_cast<int>(trie.size());
                trie.emplace_back(idx, u);
            }
            u = trie[u].next[idx];
        }

        trie[u].leaf++;
        // ==== actions for specific problem
        trie[u].end_here.emplace_back(s_id);
        // ====
    }

    /// Returns the suffix link of node `u`, computing and caching it on first use.
    ///
    /// When the link of a non-root node is first computed, the pattern ids
    /// stored at the link target are appended to `u`'s own `end_here`, so
    /// afterwards `end_here` lists every pattern that is a suffix of `u`'s string.
    int getLink(int u) {
        if (trie[u].link != -1)
            return trie[u].link;

        if (u == 0) {
            trie[u].link = 0;  // root links to itself; nothing to inherit
            return 0;
        }

        const int parent = trie[u].pa;
        // Children of the root link to the root; deeper nodes follow the
        // parent's link and then take the same edge label.
        const int target = (parent == 0) ? 0 : getGo(getLink(parent), trie[u].ch);
        trie[u].link = target;

        //==== actions for specific problem
        // Make sure the target has already inherited from *its* chain, then
        // copy. `target` is strictly shallower than `u`, so it is a different
        // vector and the insert below is safe.
        getLink(target);
        const std::vector<int>& inherited = trie[target].end_here;
        trie[u].end_here.insert(trie[u].end_here.end(), inherited.begin(), inherited.end());
        //====

        return target;
    }

    /// Automaton transition: the node reached from `u` on letter `ch`
    /// (0-based), i.e. the deepest trie node whose string is a suffix of
    /// string(u) + ch. Results are cached in `go`.
    ///
    /// A `ch` outside [0, alphabet_size) cannot occur in any pattern, so the
    /// automaton falls back to the root.
    int getGo(int u, int ch) {
        if (ch < 0 || ch >= alphabet_size)
            return 0;

        if (trie[u].go[ch] == -1) {
            if (trie[u].next[ch] != -1)
                trie[u].go[ch] = trie[u].next[ch];
            else
                trie[u].go[ch] = (u == 0) ? 0 : getGo(getLink(u), ch);
        }

        return trie[u].go[ch];
    }

    /// Eagerly resolves every suffix link and transition so that later
    /// queries are pure table lookups.
    void compile() {
        const int total = static_cast<int>(trie.size());
        for (int node = 0; node < total; ++node) {
            getLink(node);
            for (int letter = 0; letter < alphabet_size; ++letter)
                getGo(node, letter);
        }
    }

    //==== more functions for specific problem
    /// Scans `s` and, for every pattern occurrence ending at each position,
    /// increments `cnt[id]` where `id` is the value passed to addString().
    ///
    /// `cnt` must point to an array large enough to be indexed by every
    /// pattern id. Characters outside 'a'..'z' reset the automaton to the
    /// root. An empty pattern, if one was added, is counted once per
    /// character of `s`.
    void compute(const std::string& s, int *cnt) {
        int u = 0;

        for (char c : s) {
            u = getGo(u, c - 'a');
            // Force the link so end_here is complete even if compile() was skipped.
            getLink(u);

            for (int id : trie[u].end_here)
                cnt[id]++;
        }
    }
    //====

};

Leave a Reply