Quellcodebibliothek Statistik Leitseite products/Sources/formale Sprachen/C/Firefox/third_party/rust/neqo-crypto/src/   (Browser von der Mozilla Stiftung Version 136.0.1©)  Datei vom 10.2.2025 mit Größe 10 kB image not shown  

Quelle  agentio.rs   Sprache: unbekannt

 
// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
// http://www.apache.org/licenses/LICENSE-2.0> or the MIT license
// <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your
// option. This file may not be copied, modified, or distributed
// except according to those terms.

use std::{
    cmp::min,
    fmt, mem,
    ops::Deref,
    os::raw::{c_uint, c_void},
    pin::Pin,
    ptr::{null, null_mut},
};

use neqo_common::{hex, hex_with_len, qtrace};

use crate::{
    constants::{ContentType, Epoch},
    err::{nspr, Error, PR_SetError, Res},
    null_safe_slice, prio, ssl,
};

// Alias common types.
type PrFd = *mut prio::PRFileDesc;
type PrStatus = prio::PRStatus::Type;
const PR_SUCCESS: PrStatus = prio::PRStatus::PR_SUCCESS;
const PR_FAILURE: PrStatus = prio::PRStatus::PR_FAILURE;

/// Convert a pinned, boxed object into a void pointer.
pub fn as_c_void<T: Unpin>(pin: &mut Pin<Box<T>>) -> *mut c_void {
    (std::ptr::from_mut::<T>(Pin::into_inner(pin.as_mut()))).cast()
}

/// A slice of the output.
#[derive(Default)]
pub struct Record {
    pub epoch: Epoch,
    pub ct: ContentType,
    pub data: Vec<u8>,
}

impl Record {
    #[must_use]
    pub fn new(epoch: Epoch, ct: ContentType, data: &[u8]) -> Self {
        Self {
            epoch,
            ct,
            data: data.to_vec(),
        }
    }

    // Shoves this record into the socket, returns true if blocked.
    pub(crate) fn write(self, fd: *mut ssl::PRFileDesc) -> Res<()> {
        qtrace!("write {:?}", self);
        unsafe {
            ssl::SSL_RecordLayerData(
                fd,
                self.epoch,
                ssl::SSLContentType::Type::from(self.ct),
                self.data.as_ptr(),
                c_uint::try_from(self.data.len())?,
            )
        }
    }
}

impl fmt::Debug for Record {
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        write!(
            f,
            "Record {:?}:{:?} {}",
            self.epoch,
            self.ct,
            hex_with_len(&self.data[..])
        )
    }
}

#[derive(Debug, Default)]
pub struct RecordList {
    records: Vec<Record>,
}

impl RecordList {
    fn append(&mut self, epoch: Epoch, ct: ContentType, data: &[u8]) {
        self.records.push(Record::new(epoch, ct, data));
    }

    unsafe extern "C" fn ingest(
        _fd: *mut ssl::PRFileDesc,
        epoch: ssl::PRUint16,
        ct: ssl::SSLContentType::Type,
        data: *const ssl::PRUint8,
        len: c_uint,
        arg: *mut c_void,
    ) -> ssl::SECStatus {
        let records = arg.cast::<Self>().as_mut().unwrap();

        let slice = null_safe_slice(data, len);
        records.append(epoch, ContentType::try_from(ct).unwrap(), slice);
        ssl::SECSuccess
    }

    /// Create a new record list.
    pub(crate) fn setup(fd: *mut ssl::PRFileDesc) -> Res<Pin<Box<Self>>> {
        let mut records = Box::pin(Self::default());
        unsafe {
            ssl::SSL_RecordLayerWriteCallback(fd, Some(Self::ingest), as_c_void(&mut records))
        }?;
        Ok(records)
    }
}

impl Deref for RecordList {
    type Target = Vec<Record>;
    #[must_use]
    fn deref(&self) -> &Vec<Record> {
        &self.records
    }
}

pub struct RecordListIter(std::vec::IntoIter<Record>);

impl Iterator for RecordListIter {
    type Item = Record;
    fn next(&mut self) -> Option<Self::Item> {
        self.0.next()
    }
}

impl IntoIterator for RecordList {
    type Item = Record;
    type IntoIter = RecordListIter;
    #[must_use]
    fn into_iter(self) -> Self::IntoIter {
        RecordListIter(self.records.into_iter())
    }
}

pub struct AgentIoInputContext<'a> {
    input: &'a mut AgentIoInput,
}

impl Drop for AgentIoInputContext<'_> {
    fn drop(&mut self) {
        self.input.reset();
    }
}

#[derive(Debug)]
struct AgentIoInput {
    // input is data that is read by TLS.
    input: *const u8,
    // input_available is how much data is left for reading.
    available: usize,
}

impl AgentIoInput {
    fn wrap<'a: 'c, 'b: 'c, 'c>(&'a mut self, input: &'b [u8]) -> AgentIoInputContext<'c> {
        assert!(self.input.is_null());
        self.input = input.as_ptr();
        self.available = input.len();
        qtrace!("AgentIoInput wrap {:p}", self.input);
        AgentIoInputContext { input: self }
    }

    // Take the data provided as input and provide it to the TLS stack.
    fn read_input(&mut self, buf: *mut u8, count: usize) -> Res<usize> {
        let amount = min(self.available, count);
        if amount == 0 {
            unsafe {
                PR_SetError(nspr::PR_WOULD_BLOCK_ERROR, 0);
            }
            return Err(Error::NoDataAvailable);
        }

        #[allow(clippy::disallowed_methods)] // We just checked if this was empty.
        let src = unsafe { std::slice::from_raw_parts(self.input, amount) };
        qtrace!([self], "read {}", hex(src));
        let dst = unsafe { std::slice::from_raw_parts_mut(buf, amount) };
        dst.copy_from_slice(src);
        self.input = self.input.wrapping_add(amount);
        self.available -= amount;
        Ok(amount)
    }

    fn reset(&mut self) {
        qtrace!([self], "reset");
        self.input = null();
        self.available = 0;
    }
}

impl ::std::fmt::Display for AgentIoInput {
    fn fmt(&self, f: &mut ::std::fmt::Formatter) -> ::std::fmt::Result {
        write!(f, "AgentIoInput {:p}", self.input)
    }
}

#[derive(Debug)]
pub struct AgentIo {
    // input collects the input we might provide to TLS.
    input: AgentIoInput,

    // output contains data that is written by TLS.
    output: Vec<u8>,
}

impl AgentIo {
    pub const fn new() -> Self {
        Self {
            input: AgentIoInput {
                input: null(),
                available: 0,
            },
            output: Vec::new(),
        }
    }

    unsafe fn borrow(fd: &mut PrFd) -> &mut Self {
        (**fd).secret.cast::<Self>().as_mut().unwrap()
    }

    pub fn wrap<'a: 'c, 'b: 'c, 'c>(&'a mut self, input: &'b [u8]) -> AgentIoInputContext<'c> {
        assert_eq!(self.output.len(), 0);
        self.input.wrap(input)
    }

    // Stage output from TLS into the output buffer.
    fn save_output(&mut self, buf: *const u8, count: usize) {
        let slice = unsafe { null_safe_slice(buf, count) };
        qtrace!([self], "save output {}", hex(slice));
        self.output.extend_from_slice(slice);
    }

    pub fn take_output(&mut self) -> Vec<u8> {
        qtrace!([self], "take output");
        mem::take(&mut self.output)
    }
}

impl ::std::fmt::Display for AgentIo {
    fn fmt(&self, f: &mut ::std::fmt::Formatter) -> ::std::fmt::Result {
        write!(f, "AgentIo")
    }
}

unsafe extern "C" fn agent_close(fd: PrFd) -> PrStatus {
    (*fd).secret = null_mut();
    if let Some(dtor) = (*fd).dtor {
        dtor(fd);
    }
    PR_SUCCESS
}

unsafe extern "C" fn agent_read(mut fd: PrFd, buf: *mut c_void, amount: prio::PRInt32) -> PrStatus {
    let io = AgentIo::borrow(&mut fd);
    if let Ok(a) = usize::try_from(amount) {
        match io.input.read_input(buf.cast(), a) {
            Ok(_) => PR_SUCCESS,
            Err(_) => PR_FAILURE,
        }
    } else {
        PR_FAILURE
    }
}

unsafe extern "C" fn agent_recv(
    mut fd: PrFd,
    buf: *mut c_void,
    amount: prio::PRInt32,
    flags: prio::PRIntn,
    _timeout: prio::PRIntervalTime,
) -> prio::PRInt32 {
    let io = AgentIo::borrow(&mut fd);
    if flags != 0 {
        return PR_FAILURE;
    }
    if let Ok(a) = usize::try_from(amount) {
        io.input.read_input(buf.cast(), a).map_or(PR_FAILURE, |v| {
            prio::PRInt32::try_from(v).unwrap_or(PR_FAILURE)
        })
    } else {
        PR_FAILURE
    }
}

unsafe extern "C" fn agent_write(
    mut fd: PrFd,
    buf: *const c_void,
    amount: prio::PRInt32,
) -> PrStatus {
    let io = AgentIo::borrow(&mut fd);
    usize::try_from(amount).map_or(PR_FAILURE, |a| {
        io.save_output(buf.cast(), a);
        amount
    })
}

unsafe extern "C" fn agent_send(
    mut fd: PrFd,
    buf: *const c_void,
    amount: prio::PRInt32,
    flags: prio::PRIntn,
    _timeout: prio::PRIntervalTime,
) -> prio::PRInt32 {
    let io = AgentIo::borrow(&mut fd);

    if flags != 0 {
        return PR_FAILURE;
    }
    usize::try_from(amount).map_or(PR_FAILURE, |a| {
        io.save_output(buf.cast(), a);
        amount
    })
}

unsafe extern "C" fn agent_available(mut fd: PrFd) -> prio::PRInt32 {
    let io = AgentIo::borrow(&mut fd);
    io.input.available.try_into().unwrap_or(PR_FAILURE)
}

unsafe extern "C" fn agent_available64(mut fd: PrFd) -> prio::PRInt64 {
    let io = AgentIo::borrow(&mut fd);
    io.input
        .available
        .try_into()
        .unwrap_or_else(|_| PR_FAILURE.into())
}

#[allow(clippy::cast_possible_truncation)]
unsafe extern "C" fn agent_getname(_fd: PrFd, addr: *mut prio::PRNetAddr) -> PrStatus {
    let a = addr.as_mut().unwrap();
    // Cast is safe because prio::PR_AF_INET is 2
    a.inet.family = prio::PR_AF_INET as prio::PRUint16;
    a.inet.port = 0;
    a.inet.ip = 0;
    PR_SUCCESS
}

unsafe extern "C" fn agent_getsockopt(_fd: PrFd, opt: *mut prio::PRSocketOptionData) -> PrStatus {
    let o = opt.as_mut().unwrap();
    if o.option == prio::PRSockOption::PR_SockOpt_Nonblocking {
        o.value.non_blocking = 1;
        return PR_SUCCESS;
    }
    PR_FAILURE
}

pub const METHODS: &prio::PRIOMethods = &prio::PRIOMethods {
    file_type: prio::PRDescType::PR_DESC_LAYERED,
    close: Some(agent_close),
    read: Some(agent_read),
    write: Some(agent_write),
    available: Some(agent_available),
    available64: Some(agent_available64),
    fsync: None,
    seek: None,
    seek64: None,
    fileInfo: None,
    fileInfo64: None,
    writev: None,
    connect: None,
    accept: None,
    bind: None,
    listen: None,
    shutdown: None,
    recv: Some(agent_recv),
    send: Some(agent_send),
    recvfrom: None,
    sendto: None,
    poll: None,
    acceptread: None,
    transmitfile: None,
    getsockname: Some(agent_getname),
    getpeername: Some(agent_getname),
    reserved_fn_6: None,
    reserved_fn_5: None,
    getsocketoption: Some(agent_getsockopt),
    setsocketoption: None,
    sendfile: None,
    connectcontinue: None,
    reserved_fn_3: None,
    reserved_fn_2: None,
    reserved_fn_1: None,
    reserved_fn_0: None,
};

[ Dauer der Verarbeitung: 0.3 Sekunden  (vorverarbeitet)  ]