如何在客户端连接到tokio-proto服务器时显示欢迎消息/横幅?

时间:2017-10-19 13:00:00

标签: tcp rust rust-tokio

SMTP服务器应在建立连接时显示欢迎消息(220 service ready),这是客户端开始发送命令的信号。这似乎与tokio-proto的请求 - 响应范例相冲突。

我可以想象协议可以完全颠倒,例如服务器发送请求和客户端响应(弃用TURN),但目前我只关注连接时的欢迎消息,即横幅。之后,客户请求=>服务器响应将得到维护。

我一直试图想出把它挂钩的地方,但w_±_± w_±_± 1 w_˚ w_˚ 1 w_ฌ w_ฌ 1 w_ℓ w_ℓ 1 w_㎡ w_㎡ 1 bind_server对我来说是超级神秘的。我需要实施传输吗?

我在编解码器的bind_transport方法中有这个。问题是除非有可用于解码哪种有意义的数据,否则不会调用decode方法。我希望有一些连接初始化方法可以挂钩,但我什么都没找到。

decode

我的work-in-progress study project is on GitHub,我也opened an issue with tokio-proto

1 个答案:

答案 0 :(得分:0)

实现我自己的有状态传输装饰器(SmtpConnectTransport)就可以了。它将在初始化时注入给定的帧。我想通过将initframe类型作为参数可以将其制作成通用解决方案。除了解析和序列化之外,编解码器最终不必做任何异常的事情。

在连接时框架正确,服务可以生成所需的欢迎消息或横幅。我已将SmtpCommand::Connect中的本地和远程套接字地址包含在服务中,因为它将用于垃圾邮件检测。

我的预感是正确的,但是锻炼确实感觉就像生锈的金属磨削一样:D我现在很开心samotop is coming together。这是一些代码:

use std::io;
use std::str;
use bytes::Bytes;
use model::response::SmtpReply;
use model::request::SmtpCommand;
use protocol::codec::SmtpCodec;
use tokio_proto::streaming::pipeline::{Frame, Transport, ServerProto};
use tokio_io::codec::Framed;
use futures::{Stream, Sink, StartSend, Poll, Async};
use protocol::parser::SmtpParser;
use protocol::writer::SmtpSerializer;

type Error = io::Error;
type CmdFrame = Frame<SmtpCommand, Bytes, Error>;
type RplFrame = Frame<SmtpReply, (), Error>;

pub struct SmtpProto;

impl<TIO: NetSocket + 'static> ServerProto<TIO> for SmtpProto {
    type Error = Error;
    type Request = SmtpCommand;
    type RequestBody = Bytes;
    type Response = SmtpReply;
    type ResponseBody = ();
    type Transport = SmtpConnectTransport<Framed<TIO, SmtpCodec<'static>>>;
    type BindTransport = io::Result<Self::Transport>;

    fn bind_transport(&self, io: TIO) -> Self::BindTransport {
        // save local and remote socket address so we can use it as the first frame
        let initframe = Frame::Message {
            body: false,
            message: SmtpCommand::Connect {
                local_addr: io.local_addr().ok(),
                peer_addr: io.peer_addr().ok(),
            },
        };
        let codec = SmtpCodec::new(
            SmtpParser::session_parser(),
            SmtpSerializer::answer_serializer(),
        );
        let upstream = io.framed(codec);
        let transport = SmtpConnectTransport::new(upstream, initframe);
        Ok(transport)
    }
}

pub struct SmtpConnectTransport<TT> {
    initframe: Option<CmdFrame>,
    upstream: TT,
}

impl<TT> SmtpConnectTransport<TT> {
    pub fn new(upstream: TT, initframe: CmdFrame) -> Self {
        Self {
            upstream,
            initframe: Some(initframe),
        }
    }
}

impl<TT> Stream for SmtpConnectTransport<TT>
where
    TT: 'static + Stream<Error = Error, Item = CmdFrame>,
{
    type Error = Error;
    type Item = CmdFrame;

    fn poll(&mut self) -> Poll<Option<Self::Item>, Self::Error> {
        match self.initframe.take() {
            Some(frame) => {
                println!("transport initializing");
                Ok(Async::Ready(Some(frame)))
            }
            None => self.upstream.poll(),
        }
    }
}

impl<TT> Sink for SmtpConnectTransport<TT>
where
    TT: 'static + Sink<SinkError = Error, SinkItem = RplFrame>,
{
    type SinkError = Error;
    type SinkItem = RplFrame;

    fn start_send(&mut self, request: Self::SinkItem) -> StartSend<Self::SinkItem, io::Error> {
        self.upstream.start_send(request)
    }

    fn poll_complete(&mut self) -> Poll<(), io::Error> {
        self.upstream.poll_complete()
    }

    fn close(&mut self) -> Poll<(), io::Error> {
        self.upstream.close()
    }
}

impl<TT> Transport for SmtpConnectTransport<TT>
where
    TT: 'static,
    TT: Stream<Error = Error, Item = CmdFrame>,
    TT: Sink<SinkError = Error, SinkItem = RplFrame>,
{
}


pub trait NetSocket: AsyncRead + AsyncWrite {
    fn peer_addr(&self) -> Result<SocketAddr>;
    fn local_addr(&self) -> Result<SocketAddr>;
}

impl NetSocket for TcpStream {
    fn peer_addr(&self) -> Result<SocketAddr> {
        self.peer_addr()
    }
    fn local_addr(&self) -> Result<SocketAddr> {
        self.local_addr()
    }
}