我正在学习 Rust，并通过解一道编程题（coding puzzle）来练习。这道题可以通过计算树的半径来解决，我实现了这个算法（this algorithm）：
use std::io;
use std::collections::*;
/// Adjacency-set graph: each vertex maps to the set of its neighbours.
type Graph<T> = HashMap<T, HashSet<T>>;
/// Writes one formatted line to stderr, keeping debug output separate from
/// the answer printed on stdout. A failed write is deliberately ignored.
macro_rules! print_err {
    ($($arg:tt)*) => {{
        use std::io::Write;
        let _ = writeln!(::std::io::stderr(), $($arg)*);
    }};
}
/// Parses the string-like value `$x` as `$t` after trimming surrounding
/// whitespace (such as the newline kept by `read_line`).
/// Panics if the trimmed text is not a valid `$t`.
macro_rules! parse_input {
    ($x:expr, $t:ident) => {
        <$t as ::std::str::FromStr>::from_str($x.trim()).unwrap()
    };
}
/// Adds the directed edge `$u -> $v` to the adjacency-set graph `$g`,
/// creating `$u`'s neighbour set on first use. Call it in both directions
/// for an undirected edge.
macro_rules! add_dedge {
    ($g:ident, $u:ident, $v:ident) => {{
        // Insert in place via the entry API. The previous version cloned the
        // whole neighbour set on every call, which is quadratic in a vertex's
        // degree (e.g. the centre of a star-shaped tree).
        $g.entry($u).or_default().insert($v);
    }};
}
/// Returns the radius of the tree `graph`: `ceil(diameter / 2)` edges, the
/// number of steps something spreading one edge per step needs to reach every
/// vertex from the best starting vertex. Returns 0 for an empty graph.
///
/// Uses the double sweep: the vertex farthest from any start is an endpoint
/// of a diameter, and the farthest distance from that endpoint is the
/// diameter. Assumes `graph` is a connected tree.
fn tree_rad(graph: Graph<i32>) -> i32 {
    // Any vertex works as the start; `next()` avoids walking the whole key
    // set the way `last()` does.
    let start = match graph.keys().next() {
        Some(&u) => u,
        None => return 0,
    };
    let (end, _) = farthest_node(graph.clone(), start, start);
    // Last use of `graph`: move it instead of cloning a second time.
    let (_, diameter) = farthest_node(graph, end, end);
    (diameter + 1) / 2
}
fn farthest_node(graph: Graph<i32>, u: i32, prev_u: i32) -> (i32, i32) {
match graph.get(&u) {
Some(adj) => {
let mut v = u;
let mut d = 0;
for w in adj {
if *w != prev_u {
let (x, e) = farthest_node(graph.clone(), *w, u);
if d <= e {
v = x; d = e + 1;
}
}
}
(v, d)
}
None => (u, 0)
}
}
/// Entry point: reads the relations from stdin, builds the undirected graph
/// and prints the minimal number of propagation steps (via `tree_rad`).
///
/// Expected input, as parsed below: a line with the number of relations `n`,
/// then `n` lines each holding two whitespace-separated person IDs.
fn main() {
    use std::io::BufRead;

    // Lock stdin once and reuse one line buffer instead of re-locking and
    // allocating a new `String` for every relation.
    let stdin = io::stdin();
    let mut input = stdin.lock();
    let mut input_line = String::new();
    input.read_line(&mut input_line).unwrap();
    let n = parse_input!(input_line, usize); // the number of adjacency relations
    let mut graph: HashMap<i32, HashSet<i32>> = HashMap::new();
    for _ in 0..n {
        input_line.clear();
        input.read_line(&mut input_line).unwrap();
        // Like Python's `split()`, this tolerates repeated spaces and tabs,
        // which made `split(" ")` yield empty, unparsable tokens.
        let mut ids = input_line.split_whitespace();
        let xi = parse_input!(ids.next().unwrap(), i32); // the ID of a person which is adjacent to yi
        let yi = parse_input!(ids.next().unwrap(), i32); // the ID of a person which is adjacent to xi
        add_dedge!(graph, xi, yi);
        add_dedge!(graph, yi, xi);
    }
    // Dumping the whole map with `{:?}` to (unbuffered) stderr is very slow on
    // large inputs; log only its size.
    print_err!("{} people, {} relations", graph.len(), n);
    let rad = tree_rad(graph);
    // The minimal amount of steps required to completely propagate the advertisement
    println!("{}", rad);
}
这段代码能通过一些小的测试用例，但在一个大型测试用例（测试 4）上超时了。我怀疑自己没有正确实现这个算法，所以用 Python 重新实现了一遍（如下），它通过了所有测试用例。
import sys
import math
from sets import Set
def farthest_node(graph, u, pu):
    """Return (vertex, distance) for a vertex farthest from u in the tree graph.

    graph maps each vertex to an iterable of its neighbours. pu is the vertex
    u was reached from and is never walked back into; pass pu == u to search
    everything reachable from u. Only a strictly farther candidate replaces
    the current best, so ties keep the earlier-visited vertex.

    Uses an explicit stack instead of recursion, so trees deeper than
    Python's recursion limit do not raise. Raises KeyError if a visited
    vertex has no entry in graph.
    """
    done = object()  # sentinel: the frame has no neighbours left
    # Each frame: [vertex, parent, neighbour iterator, best vertex, best distance]
    stack = [[u, pu, iter(graph[u]), u, 0]]
    while True:
        frame = stack[-1]
        w = next(frame[2], done)
        if w is done:
            stack.pop()
            x, e = frame[3], frame[4]
            if not stack:
                return (x, e)
            parent = stack[-1]
            # x is e edges below this vertex, hence e + 1 below its parent.
            if parent[4] <= e:
                parent[3], parent[4] = x, e + 1
        elif w != frame[1]:
            stack.append([w, frame[0], iter(graph[w]), w, 0])
# Auto-generated code below aims at helping you parse
# the standard input according to the problem statement.
graph = {}
n = int(raw_input())  # the number of adjacency relations
for i in xrange(n):
    # xi: the ID of a person which is adjacent to yi
    # yi: the ID of a person which is adjacent to xi
    xi, yi = [int(j) for j in raw_input().split()]
    if xi not in graph:
        graph[xi] = Set()
    graph[xi].add(yi)
    if yi not in graph:
        graph[yi] = Set()
    graph[yi].add(xi)
# Write an action using print
# To debug: print >> sys.stderr, "Debug messages..."
if graph:
    # Any vertex works as the start of the double sweep; next(iter(graph))
    # avoids building the full key list that graph.keys()[0] creates.
    u = next(iter(graph))
    v, d = farthest_node(graph, u, u)  # v is one end of a diameter
    w, e = farthest_node(graph, v, v)  # e is the diameter
else:
    e = 0  # no relations: nothing to propagate
# The minimal amount of steps required to completely propagate the advertisement
# is ceil(e / 2). Parenthesised so it prints the same number whether print is a
# statement (Python 2) or a function (Python 3).
print((e + 1) >> 1)
我的代码有什么问题？我该如何修复？