用于计算树半径的Rust代码超时

时间:2016-04-23 16:45:59

标签: algorithm rust

我正在学习Rust,我正在为练习解决a coding puzzle。这个难题可以通过计算树半径来解决 我实施了this algorithm

use std::io;
use std::collections::*;

type Graph<T> = HashMap<T, HashSet<T>>;

macro_rules! print_err {
    ($($arg:tt)*) => (
        {
            use std::io::Write;
            writeln!(&mut ::std::io::stderr(), $($arg)*).ok();
        }
    )
}

macro_rules! parse_input {
    ($x:expr, $t:ident) => ($x.trim().parse::<$t>().unwrap())
}

macro_rules! add_dedge{
    ($g: ident, $u: ident, $v: ident) => (
        {
            let mut adj = match $g.get(&$u) {
                Some(adj) => adj.clone(),
                None => HashSet::new()
            };
            adj.insert($v);
            $g.insert($u, adj);
        }
    )
}

fn tree_rad(graph: Graph<i32>) -> i32 {
    match graph.keys().last() {
        Some(u) => {
            let (v, _) = farthest_node(graph.clone(), *u, *u);
            let (_, d) = farthest_node(graph.clone(), v, v);
            (d + 1) / 2
        }
        None => 0
    }
}

fn farthest_node(graph: Graph<i32>, u: i32, prev_u: i32) -> (i32, i32) {
    match graph.get(&u) {
        Some(adj) => {
            let mut v = u;
            let mut d = 0;
            for w in adj {
                if *w != prev_u {
                    let (x, e) = farthest_node(graph.clone(), *w, u);
                    if d <=  e {
                        v = x; d = e + 1;
                    }
                }
            }
            (v, d)
        }
        None => (u, 0)
    }
}

/**
 * Auto-generated code below aims at helping you parse
 * the standard input according to the problem statement.
 **/
fn main() {
    let mut input_line = String::new();
    io::stdin().read_line(&mut input_line).unwrap();
    let n = parse_input!(input_line, usize); // the number of adjacency relations
    let mut graph: HashMap<i32, HashSet<i32>> = HashMap::new();
    for _ in 0..n as usize {
        let mut input_line = String::new();
        io::stdin().read_line(&mut input_line).unwrap();
        let inputs = input_line.split(" ").collect::<Vec<_>>();
        let xi = parse_input!(inputs[0], i32); // the ID of a person which is adjacent to yi
        let yi = parse_input!(inputs[1], i32); // the ID of a person which is adjacent to xi
        add_dedge!(graph, xi, yi);
        add_dedge!(graph, yi, xi);
    }
    print_err!("{:?}", graph);
    let rad = tree_rad(graph);

    // The minimal amount of steps required to completely propagate the advertisement
    println!("{}", rad);
}

这段代码传递了一些小测试用例,但是对于一个大型测试用例来说却很快(测试4)。 我怀疑我没能正确实现算法, 所以我在Python中重新实现如下,它通过了所有测试用例。

import sys
import math
from sets import Set

def farthest_node(graph, u, pu):
    v, d = u, 0
    for w in graph[u]:
        if w != pu:
            (x, e) = farthest_node(graph, w, u)
            if d <= e:
                v, d = x, e + 1
    return (v, d)

# Auto-generated code below aims at helping you parse
# the standard input according to the problem statement.
graph = {}
n = int(raw_input())  # the number of adjacency relations
for i in xrange(n):
    # xi: the ID of a person which is adjacent to yi
    # yi: the ID of a person which is adjacent to xi
    xi, yi = [int(j) for j in raw_input().split()]
    if not xi in graph: graph[xi] = Set()
    graph[xi].add(yi)
    if not yi in graph: graph[yi] = Set()
    graph[yi].add(xi)    

# Write an action using print
# To debug: print >> sys.stderr, "Debug messages..."    
u = graph.keys()[0]
v, d = farthest_node(graph, u, u)
w, e = farthest_node(graph, v, v)

# The minimal amount of steps required to completely propagate the advertisement
print (e + 1) >> 1

我的代码有什么问题,我该如何解决?

0 个答案:

没有答案