洛谷 P4061 [Code+#1]大吉大利,晚上吃鸡!

题目大概是给一张无向图然后给你起点和终点让你找到一个点对 $(a,b)$ 满足 $a$ 在最短路上并且 $a$ 和 $b$ 不在任意一条最短路径上

刚开始拿到题读错题以为边 $(a,b)\in E_{G}$ 然后光荣爆零…

思路看起来和其他题解不同,首先跑两边最短路确定在最短路上的点,这个时候我们顺便记录一下经过这个点的的最短路径中点的个数,这个怎么处理呢,一个 $ bitset$ 暴力莽就完事儿了,最后再来找到个数( $bitset$ 去重也挺方便的)

实际上找到经过这个点的最短路径的点的个数并不是一次 bfs 就能完成(至少我太菜了想不到),我的做法是两次 dij 的时候这个点只要进队就刷一次(如果相等就或上),最后统计答案

我们记第一次 dfs 得到的 $ind1[t]$ 表示从 $t$ 出发到 $ed$ 能经过的最短路的点的个数(此时图已定向),$ind2[t]$ 表示从 $t$ 出发到 $st$ 经过的最短路的点的个数 (图已定向)

然后考虑答案,他们都用了一个 $dp$ (奇怪的是我并没用),只需要大论一下答案即可:

1.在最短路上的,一定可以和所有不在最短路上的点结合,即$ans1+=n-ind1[t]-ind2[t]-1$

2.不在最短路上的,只能找最短路结合,即 $ans2+=n-tot$ ( $tot $为$st$ 到 $ed$ 的所有最短路上的点的个数)

最后统计答案 $ans=(ans1+ans2)/2$ (因为每个点被算了两次)

对了,如果$st,ed$不连通的话,那就说输出答案 $\frac{n*(n-1)}{2}$ (鬼知道为啥这个要给 $45$ 分)

另外吐槽一下数据是真的水,我的代码过了几个 $hack$ 数据就没管啥了,如果有锅记得提醒我

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
#include<touwenjian.h>

#define int long long

using namespace std;

const int maxn=50050;
int vis[maxn],st,ed,d1[maxn],d2[maxn];
int n,m,size;
int bj[maxn],head[maxn];
int ans,ans1,ans2;
int ind1[maxn],ind2[maxn];
int dep[maxn],szft[maxn];
int tot;

bitset <50050> lk[maxn];

struct edge{
int next,to,dis;
}e[maxn*2],looker[maxn*2];

inline void addedge(int next,int to,int dis)
{
e[++size].to=to;
e[size].dis=dis;
e[size].next=head[next];
head[next]=size;
}

inline void dij(int st,int *dis,int *ind)
{
priority_queue < pair< int , int > > q;
q.push(make_pair(0,st));
memset(dis,0x3f,sizeof(d1));
memset(vis,0,sizeof(vis));
dis[st]=0;
for(int i=1;i<=n;i++) lk[i].reset();
while(!q.empty())
{
int t=q.top().second;
q.pop();
if(vis[t]) continue;
vis[t]=1;
int i,j,k;
for(i=head[t];i;i=e[i].next)
{
j=e[i].to;
k=e[i].dis;
if(dis[t]+k<dis[j])
{
lk[j].reset();
lk[j]=lk[t];
lk[j].set(t);
dis[j]=dis[t]+k;
q.push(make_pair(-dis[j],j));
}
else if(dis[t]+k==dis[j])
{
lk[j]|=lk[t];
lk[j].set(t);
}
}
}
for(int i=1;i<=n;i++) ind[i]=lk[i].count();
}

signed main()
{
ios::sync_with_stdio(false);
register int i,j;
cin>>n>>m>>st>>ed;
int t1,t2,t3;
for(i=1;i<=m;i++)
{
cin>>t1>>t2>>t3;
addedge(t2,t1,t3);
addedge(t1,t2,t3);
}
dij(st,d1,ind1);
dij(ed,d2,ind2);
if(d1[ed]==0x3f3f3f3f3f3f3f3f)
{
cout<<(n*(n-1))/2<<endl;
return 0;
}
for(i=1;i<=n;i++) if(d1[i]+d2[i]==d1[ed]) bj[i]=1,tot++;
for(i=1;i<=n;i++)
if(bj[i]) ans1+=(n-ind1[i]-ind2[i]-1);
else ans2+=tot;
ans=(ans1+ans2)/2;
cout<<ans<<endl;
return 0;
}