查找树中两个节点之间的最大成本边

时间:2018-05-20 14:56:46

标签: c++ algorithm graph tree disjoint-sets

任务:

给定加权树图和一组节点对。对于来自集合的每对(u,v),我需要(有效地)找到(u,v)之间的最大边缘。

我的方法:

对每对(u,v)使用Tarjan算法,我们可以找到最低共同祖先LCA(u,v)= a。然后我们可以将(u,v)之间的路径表示为(u,a)和(v,a)pathes的联合,并将(u,v)之间的最大边缘表示为max(max_edge(u,a),max_edge(v,一))。

问题:

我试图在LCA算法中添加max_edge save,但还没有取得任何成功。

问题是:如何在LCA Tarjan算法中添加对最大边缘保存的支持?

我的尝试代码:

int max_cost;

int dsu_find(int node)
{
    if (node == parent[node])
        return node;
    max_cost = std::max(max_cost, edges[node][parent[node]]);
    return parent[node] = dsu_find(parent[node]);
}
void lca_dfs(int node, std::vector<std::list<int>> &query_list)
{
    dsu_make(node);
    ancestor[node] = node;
    marks[node] = true;
    for(auto neighbour:adjacency_list[node])
    {
        if (!marks[neighbour.first])
        {
            lca_dfs(neighbour.first,query_list);
            dsu_unite(node, neighbour.first);
            ancestor[dsu_find(node)] = node;
        }
    }
    for (auto query_node : query_list[node])
        if (marks[query_node])
        {
            dsu_find(query_node);
            dsu_find(node);
            printf("%d %d -> %lld\n", node, query_node,max_cost);
            query_list[query_node].remove(node);
            max_cost = 0;
        }

}

但它的工作不正确。

我的完整lca实现(没有不正确的修改):

std::vector<int> parent;
std::vector<int> rank;
std::vector<int> ancestor;
std::vector<bool> marks;
std::vector<std::list<std::pair<int, long long>>> adjacency_list;

void lca_dfs(int node, std::vector<std::list<int>> &query_list)
{
    dsu_make(node);
    ancestor[node] = node;
    marks[node] = true;
    for(auto neighbour:adjacency_list[node])
    {
        if (!marks[neighbour.first])
        {
            lca_dfs(neighbour.first,query_list);
            dsu_unite(node, neighbour.first);
            ancestor[dsu_find(node)] = node;
        }
    }
    for (auto query_node : query_list[node])
        if (marks[query_node])
        {
            printf("LCA of %d %d is %d\n", node, query_node,ancestor[dsu_find(query_node)]);
            query_list[query_node].remove(node);
        }

}
//dsu operations
void dsu_make(int node)
{
    parent[node] = node;
    rank[node] = 0;
}

int dsu_find(int node)
{
    return node == parent[node] ? node : parent[node]=dsu_find(parent[node]);

}
void dsu_unite(int node_1,int node_2)
{
    int root_1 = dsu_find(node_1), root_2 = dsu_find(node_2);
    if(root_1!=root_2)
    {
        if(rank[root_1] < rank[root_2])
            std::swap(root_1, root_2);
        parent[root_2] = root_1;
        if (rank[root_1] == rank[root_2])
            rank[root_1]++;
    }
}

*对于每个节点,query_list [node]由v组成,例如(node,v)是需要的对。 我明白,我使用双内存(只是为了更方便访问)。

我会感激任何提示或实施修复。

1 个答案:

答案 0 :(得分:-1)

Hope this Implementation works for you.


#include <bits/stdc++.h>
#include <algorithm>
#include <ext/pb_ds/assoc_container.hpp>
#include <ext/pb_ds/tree_policy.hpp>
#define ass 1e18
#define MOD 1000000007
#define mp make_pair
#define pb push_back
#define pf push_front
#define pob pop_back
#define pof pop_front
#define fi first
#define se second
#define sz(x)   (ll)x.size()
#define present(c,x) ((c).find(x) != (c).end())
#define boost ios_base::sync_with_stdio(false);cin.tie(NULL);
#define debug(x) cout << #x << ": " << x << endl;
#define debug2(x,y) cout<<#x<<": "<< x<< ", "<< #y<< ": "<< y<< endl;
#define debug3(x,y,z) cout<<#x<<": "<< x<< ", "<< #y<< ": "<< y<<" "<<#z<<" : "<<z<< endl;
using namespace std;
typedef long long int ll;
#include <ext/pb_ds/assoc_container.hpp> 
#include <ext/pb_ds/tree_policy.hpp> 
using namespace __gnu_pbds;   
#define ordered_set tree<ll, null_type,less<ll>, rb_tree_tag,tree_order_statistics_node_update>
pair<int,int> parent[200005],par[200005],dpp[200005][19];
vector<pair< pair<int,int>,pair<int,int> > >v;
vector<pair<int,int> >vv[200005];
int level[200005],vis[200005];
ll ans[200005];

void dfs(int x,int p)
{
    level[x]=level[p]+1;
    for(int i=0;i<sz(vv[x]);i++)
    {
        if(vv[x][i].fi!=p)
        {
            par[vv[x][i].fi].fi=x;
            par[vv[x][i].fi].se=vv[x][i].se;
            dfs(vv[x][i].fi,x);
        }
    }
}

void computeparent(int n)
{
    for(int i=1;i<=n;i++)
        dpp[i][0]=par[i];
    for(int j=1;j<=18;j++)
    {
        for(int i=1;i<=n;i++)
        {
            dpp[i][j].fi=dpp[dpp[i][j-1].fi][j-1].fi;
            dpp[i][j].se=max(dpp[i][j-1].se,dpp[dpp[i][j-1].fi][j-1].se);
        }
    }
}

int lca(int a,int b)
{
    if(level[b]>level[a])
        swap(a,b);
    int diff=level[a]-level[b];
    int m=ceil(log2(diff));
    for(int i=m;i>=0;i--)
    {
        if(diff&(1LL<<i))
            a=dpp[a][i].fi;
    }
    if(a==b)
        return a;
    for(int i=m;i>=0;i--)
    {
        if(dpp[a][i].fi!=dpp[b][i].fi)
        {
            a=dpp[a][i].fi;
            b=dpp[b][i].fi;
        }
    }
    return dpp[a][0].fi;
}

int lca2(int a,int b)
{
    int c=a,d=b;
    if(level[b]>level[a])
        swap(a,b);
    int i,maxi=0,diff=level[a]-level[b];
    int m=ceil(log2(diff));
    for(int i=m;i>=0;i--)
    {
        if(diff&(1LL<<i))
        {
            maxi=max(maxi,dpp[a][i].se);
            a=dpp[a][i].fi;
        }
    }
    return maxi;
}

int finds(int a)
{
    while(parent[a].fi!=a)
    {
        a=parent[a].fi;
    }
    return a;
}   

void unions(int x,int y)
{
    if(parent[x].se>parent[y].se)
        parent[y].fi=x;
    else if(parent[x].se<parent[y].se)
        parent[x].fi=y;
    else
    {
        parent[x].fi=y;
        parent[x].se++;
    }
}

void solve()
{
    ll sum=0;
    int n,m,i,a,b,c;
    cin>>n>>m;
    for(i=1;i<=n;i++)
    {
        parent[i].fi=i;
        parent[i].se=0;
    }
    for(i=0;i<m;i++)
    {
        cin>>a>>b>>c;
        v.pb(mp(mp(c,i),mp(a,b)));
    }
    sort(v.begin(),v.end());
    for(i=0;i<sz(v);i++)
    {
        int a=v[i].se.fi,b=v[i].se.se;
        int x=finds(a),y=finds(b);
        if(x!=y)
        {
            vv[a].pb(mp(b,v[i].fi.fi));
            vv[b].pb(mp(a,v[i].fi.fi));
            ans[v[i].fi.se]=1;
            unions(x,y);
            sum+=v[i].fi.fi;
        }
    }
    dfs(1,0);
    computeparent(n);
    for(i=0;i<m;i++)
    {
        if(ans[v[i].fi.se]==0)
        {
            int a=lca(v[i].se.fi,v[i].se.se);
            ans[v[i].fi.se]=sum+v[i].fi.fi-max(lca2(a,v[i].se.fi),lca2(a,v[i].se.se));
        }
        else
            ans[v[i].fi.se]=sum;
    }
    for(i=0;i<m;i++)
        cout<<ans[i]<<"\n";
}

int main()
{
    boost
    int t=1;
    //cin>>t;
    while(t--)
    {
        solve();
    }
    return 0;
}