// Two-SAT (2-SAT)

#include<bits/stdc++.h>
#define ll long long

// 2-SAT solver based on an implication graph and its strongly connected
// components (SCCs).
//
// NOTE:
//      Variable indices must be 1-BASED. A positive integer i denotes the
//      literal x_i and a negative integer -i denotes its negation NOT x_i.
//      Index 0 is not a valid variable.
// Usage:
//      SAT2 sat;
//      sat.init(n);                      // n variables; discards old clauses
//      sat.addOR/addImply etc
//      bool feasible = sat.computeSolution();
//      result is sat.sol[1..n]. 0 is false, 1 is true.
//      (sat.sol[n+1..2n] holds the values of the negated literals.)
struct SAT2
{
    int n = 0;                                // number of variables
    std::vector<std::vector<int> > imply;     // implication graph over literal slots 0..2n
    std::vector<int> comp;                    // SCC id (1-based) of each literal slot
    std::vector<int> sol;                     // truth value per literal slot, -1 = unset

    // Prepare an empty instance with n_ variables. Any previously added
    // clauses and any previous solution are discarded.
    void init(int n_) {
        n = n_;
        imply.assign(n * 2 + 1, std::vector<int>());
        comp.assign(n * 2 + 1, -1);
        sol.assign(n * 2 + 1, -1);
    }

    // Map a literal to its graph vertex: x_i -> i, NOT x_i -> n + i.
    int mask(int i) {
        if (i >= 0) return i;
        return (-i) + n;
    }

    // Clause (u -> v). The contrapositive (NOT v -> NOT u) is added as well:
    // the SCC-based decision procedure is only correct when the implication
    // graph is closed under contraposition.
    void addImply(int u, int v) {
        imply[mask(u)].push_back(mask(v));
        imply[mask(-v)].push_back(mask(-u));
    }

    // Clause (u OR v), encoded as NOT u -> v and NOT v -> u.
    void addOR(int u , int v) {
        imply[mask(-u)].push_back(mask(v));
        imply[mask(-v)].push_back(mask(u));
    }

    // Clause NOT (u AND v), encoded as u -> NOT v and v -> NOT u.
    void addNAND(int u, int v) {
        imply[mask(u)].push_back(mask(-v));
        imply[mask(v)].push_back(mask(-u));
    }

    // Force literal u to be true (NOT u -> u).
    void addTRUE(int u) {
        imply[mask(-u)].push_back(mask(u));
    }

    // Force literal u to be false.
    void addNOT(int u) {
        addTRUE(-u);
    }

    // Exactly one of u, v is true.
    void addXOR(int u, int v) {
        addOR(u, v);
        addNAND(u, v);
    }

    // Both u and v are true.
    void addAND(int u, int v) {
        addTRUE(u);
        addTRUE(v);
    }

    // Both u and v are false.
    void addNOR(int u, int v) {
        addTRUE(-u);
        addTRUE(-v);
    }

    // u and v have the same value (u XOR NOT v).
    void addXNOR(int u, int v) {
        addXOR(u, -v);
    }

    // Decide satisfiability of the clauses added so far.
    // Returns false if, for some variable, x_i and NOT x_i lie in the same SCC;
    // in that case sol is left untouched. Otherwise stores a satisfying
    // assignment in sol and returns true.
    bool computeSolution() {
        computeComponents();

        for (int i = 1; i <= n; ++i)
            if (comp[mask(i)] == comp[mask(-i)])
                return false;

        // Components are numbered in the order Tarjan's algorithm completes
        // them, which is reverse topological order of the condensation.
        // Picking the literal whose component has the smaller id (i.e. comes
        // later topologically) can never force a contradiction.
        for (int i = 1; i <= n; ++i) {
            sol[mask(i)] = comp[mask(i)] < comp[mask(-i)];
            sol[mask(-i)] = !sol[mask(i)];
        }
        return true;
    }

private:
    // Iterative Tarjan SCC over vertices 1..2n. Fills comp[] with 1-based
    // component ids in order of completion (reverse topological order).
    // An explicit frame stack is used instead of recursion so that large
    // instances cannot overflow the call stack.
    void computeComponents() {
        const int total = n * 2 + 1;
        std::vector<int> order(total, -1), low(total, 0);
        std::vector<char> onStack(total, 0);
        std::vector<int> pending;                                  // Tarjan's vertex stack
        std::vector<std::pair<int, std::size_t> > frames;          // (vertex, next edge index)
        int counter = 0;
        int compCount = 0;

        std::fill(comp.begin(), comp.end(), -1);

        for (int root = 1; root < total; ++root) {
            if (order[root] != -1) continue;
            order[root] = low[root] = counter++;
            pending.push_back(root);
            onStack[root] = 1;
            frames.emplace_back(root, 0);

            while (!frames.empty()) {
                const int u = frames.back().first;
                if (frames.back().second < imply[u].size()) {
                    const int v = imply[u][frames.back().second++];
                    if (order[v] == -1) {
                        // Tree edge: descend into v.
                        order[v] = low[v] = counter++;
                        pending.push_back(v);
                        onStack[v] = 1;
                        frames.emplace_back(v, 0);
                    } else if (onStack[v]) {
                        // Edge back into the SCC currently being built.
                        low[u] = std::min(low[u], order[v]);
                    }
                    continue;
                }

                // All edges of u explored: close its SCC if u is the root.
                frames.pop_back();
                if (low[u] == order[u]) {
                    ++compCount;
                    int w;
                    do {
                        w = pending.back();
                        pending.pop_back();
                        onStack[w] = 0;
                        comp[w] = compCount;
                    } while (w != u);
                }
                // Propagate the low-link to the DFS parent.
                if (!frames.empty()) {
                    const int parent = frames.back().first;
                    low[parent] = std::min(low[parent], low[u]);
                }
            }
        }
    }
};

Leave a Reply