树的直径两种求法
- 贪心。任取一点 a 作为根节点,找到树上距离点 a 最远的点 b;再以点 b 为根节点,找到树上距离点 b 最远的点 c,则点 b 与点 c 之间的距离就是树的直径。注意:该方法要求所有边权非负。
- 树形DP。任取一点作为根节点,对每个节点求出向下的最深和次深两条路径,用每个节点这两条路径的长度之和更新最大值 $s_{\max}$,最终树的直径就是 $s_{\max}$。
- 时间复杂度 :两种做法都是 $O(n)$ 。
法一:树形DP
C++ 112ms
#include <iostream>
#include <algorithm>
#include <cstring>
#include <vector>
#include <cstdio>
#include <cmath>
#include <unordered_map>
#include <map>
#include <queue>
using namespace std;

// Short type aliases used by this solution.
using LL = long long;
using PII = pair<int, int>;  // NOTE(review): not used by this program
// NOTE(review): these macros rewrite every later identifier `x`/`y`;
// nothing in this program relies on them.
#define x first
#define y second

// Capacity for node ids; every undirected edge is stored twice, hence M = 2 * N.
// Presumably node ids are 1..n with n < N — TODO confirm against input bounds.
constexpr int N = 100010;
constexpr int M = N * 2;

int n;   // number of nodes, read in main
LL ans;  // best "deepest + second-deepest" sum seen by dfs (the diameter)

// Linked-list adjacency storage: h[v] is v's first edge index (main fills it
// with -1, meaning "none"), ne[i] the next edge, e[i] the endpoint, w[i] the weight.
int h[N];
int ne[M];
int e[M];
int w[M];
int idx;  // next free edge slot
// Store a directed edge a -> b with weight c, making it the new head of a's edge list.
void add(int a, int b, int c)
{
    e[idx] = b;
    w[idx] = c;
    ne[idx] = h[a];  // chain to whatever edge was previously first for a
    h[a] = idx;
    ++idx;
}
// Returns 10*u + (1 + 2 + ... + u), i.e. the sum of (10 + k) for k = 1..u.
// NOTE(review): looks like a cost that grows by 1 per unit of distance — confirm
// against the problem statement.
LL get_post(LL u)
{
    const LL linear = 10 * u;
    const LL triangular = u * (u + 1) / 2;  // u*(u+1) is always even, so this is exact
    return linear + triangular;
}
// Post-order DFS of the subtree rooted at s; f is s's parent (-1 at the root).
// Returns the longest downward path length starting at s (0 for a leaf) and
// updates the global `ans` with the longest path that bends at s, i.e. the sum of
// the deepest and second-deepest child branches.
LL dfs(int s, int f)
{
    LL best = 0;
    LL second = 0;
    for (int i = h[s]; i != -1; i = ne[i])
    {
        const int child = e[i];
        if (child == f) continue;  // don't walk back up to the parent
        const LL branch = dfs(child, s) + w[i];
        if (branch > best)
        {
            second = best;
            best = branch;
        }
        else if (branch > second)
        {
            second = branch;
        }
    }
    ans = max(ans, best + second);
    return best;
}
// Reads a tree with n nodes and n - 1 weighted undirected edges, computes its
// diameter with the tree-DP dfs, and prints get_post(diameter).
int main()
{
    cin >> n;
    memset(h, -1, sizeof h);  // -1 marks "no edges yet" for every node
    // A tree on n nodes has exactly n - 1 edges. The previous `i <= n` read one
    // line too many: that scanf failed and add() then ran on uninitialized a, b, c.
    for (int i = 1; i < n; i ++ )
    {
        int a, b, c;
        if (scanf("%d%d%d", &a, &b, &c) != 3) return 1;  // malformed / truncated input
        add(a, b, c), add(b, a, c);
    }
    dfs(1, -1);
    cout << get_post(ans) << '\n';
    return 0;
}
法二:贪心(BFS 实现)
C++ 177ms。如果用数组模拟 queue,用时会更少。
#include <iostream>
#include <algorithm>
#include <cstring>
#include <vector>
#include <cstdio>
#include <cmath>
#include <unordered_map>
#include <map>
#include <queue>
using namespace std;

// Short type aliases used by this solution.
using LL = long long;
using PII = pair<int, int>;  // NOTE(review): not used by this program
// NOTE(review): these macros rewrite every later identifier `x`/`y`;
// nothing in this program relies on them.
#define x first
#define y second

// Capacity for node ids; every undirected edge is stored twice, hence M = 2 * N.
// Presumably node ids are 1..n with n < N — TODO confirm against input bounds.
constexpr int N = 100010;
constexpr int M = N * 2;

int n;    // number of nodes, read in main
int ans;  // NOTE(review): not used by this program

// Linked-list adjacency storage: h[v] is v's first edge index (main fills it
// with -1, meaning "none"), ne[i] the next edge, e[i] the endpoint, w[i] the weight.
int h[N];
int ne[M];
int e[M];
int w[M];
int idx;  // next free edge slot

LL dist[N];  // dist[v]: weighted distance from the latest bfs source to v
bool st[N];  // st[v]: v has already been reached by the current bfs
// Store a directed edge a -> b with weight c, making it the new head of a's edge list.
void add(int a, int b, int c)
{
    e[idx] = b;
    w[idx] = c;
    ne[idx] = h[a];  // chain to whatever edge was previously first for a
    h[a] = idx;
    ++idx;
}
void bfs(int u)
{
queue<int> q;
q.push(u);
memset(st, false, sizeof st);
st[u] = true;
dist[u] = 0;
while (q.size())
{
int v = q.front();
q.pop();
for (int i = h[v]; ~i; i = ne[i])
{
int j = e[i];
if (st[j]) continue;
st[j] = true;
dist[j] = dist[v] + w[i];
q.push(j);
}
}
return ;
}
// Returns 10*u + (1 + 2 + ... + u), i.e. the sum of (10 + k) for k = 1..u.
// NOTE(review): looks like a cost that grows by 1 per unit of distance — confirm
// against the problem statement.
LL get_post(LL u)
{
    const LL linear = 10 * u;
    const LL triangular = u * (u + 1) / 2;  // u*(u+1) is always even, so this is exact
    return linear + triangular;
}
int main()
{
cin >> n;
memset(h, -1, sizeof h);
for (int i = 1; i < n; i ++ )
{
int a, b, c;
scanf("%d%d%d", &a, &b, &c);
add(a, b, c), add(b, a, c);
}
bfs(1);
LL st = 1;
for (int i = 2; i <= n; i ++ )
{
if (dist[i] > dist[st]) st = i;
}
bfs(st);
st = 0;
for (int i = 1; i <= n; i ++ )
{
st = max(st, dist[i]);
}
cout << get_post(st);
return 0;
}
法三:贪心(DFS 实现)
C++ 119ms
#include <iostream>
#include <algorithm>
#include <cstring>
#include <vector>
#include <cstdio>
#include <cmath>
#include <unordered_map>
#include <map>
#include <queue>
using namespace std;

// Short type aliases used by this solution.
using LL = long long;
using PII = pair<int, int>;  // NOTE(review): not used by this program
// NOTE(review): these macros rewrite every later identifier `x`/`y`;
// nothing in this program relies on them.
#define x first
#define y second

// Capacity for node ids; every undirected edge is stored twice, hence M = 2 * N.
// Presumably node ids are 1..n with n < N — TODO confirm against input bounds.
constexpr int N = 100010;
constexpr int M = N * 2;

int n;    // number of nodes, read in main
int ans;  // NOTE(review): not used by this program

// Linked-list adjacency storage: h[v] is v's first edge index (main fills it
// with -1, meaning "none"), ne[i] the next edge, e[i] the endpoint, w[i] the weight.
int h[N];
int ne[M];
int e[M];
int w[M];
int idx;  // next free edge slot

LL dist[N];  // dist[v]: weighted distance from the latest dfs root to v
// Store a directed edge a -> b with weight c, making it the new head of a's edge list.
void add(int a, int b, int c)
{
    e[idx] = b;
    w[idx] = c;
    ne[idx] = h[a];  // chain to whatever edge was previously first for a
    h[a] = idx;
    ++idx;
}
// Sets dist[s] = distance, then recurses into every neighbour of s except its
// parent f (-1 at the root). Each child gets dist[s] plus the connecting edge
// weight, so dist ends up holding the weighted distance from the traversal root.
void dfs(int s, int f, LL distance)
{
    dist[s] = distance;
    for (int i = h[s]; i != -1; i = ne[i])
    {
        const int child = e[i];
        if (child != f)
        {
            dfs(child, s, dist[s] + w[i]);
        }
    }
}
// Returns 10*u + (1 + 2 + ... + u), i.e. the sum of (10 + k) for k = 1..u.
// NOTE(review): looks like a cost that grows by 1 per unit of distance — confirm
// against the problem statement.
LL get_post(LL u)
{
    const LL linear = 10 * u;
    const LL triangular = u * (u + 1) / 2;  // u*(u+1) is always even, so this is exact
    return linear + triangular;
}
int main()
{
cin >> n;
memset(h, -1, sizeof h);
for (int i = 1; i < n; i ++ )
{
int a, b, c;
scanf("%d%d%d", &a, &b, &c);
add(a, b, c), add(b, a, c);
}
dfs(1, -1, 0);
LL st = 1;
for (int i = 2; i <= n; i ++ )
{
if (dist[i] > dist[st]) st = i;
}
dfs(st, -1, 0);
st = 0;
for (int i = 1; i <= n; i ++ )
{
st = max(st, dist[i]);
}
cout << get_post(st);
return 0;
}