diff --git a/src/uu/stty/src/flags.rs b/src/uu/stty/src/flags.rs index c2a82198a95..c346cbe7c5c 100644 --- a/src/uu/stty/src/flags.rs +++ b/src/uu/stty/src/flags.rs @@ -27,6 +27,14 @@ use nix::sys::termios::{ SpecialCharacterIndices as S, }; +#[derive(Debug)] +#[cfg_attr(test, derive(PartialEq))] +pub enum BaudType { + Input, + Output, + Both, +} + #[derive(Debug)] #[cfg_attr(test, derive(PartialEq))] pub enum AllFlags<'a> { @@ -38,7 +46,7 @@ pub enum AllFlags<'a> { target_os = "netbsd", target_os = "openbsd" ))] - Baud(u32), + Baud(u32, BaudType), #[cfg(not(any( target_os = "freebsd", target_os = "dragonfly", @@ -47,7 +55,7 @@ pub enum AllFlags<'a> { target_os = "netbsd", target_os = "openbsd" )))] - Baud(BaudRate), + Baud(BaudRate, BaudType), ControlFlags((&'a Flag, bool)), InputFlags((&'a Flag, bool)), LocalFlags((&'a Flag, bool)), diff --git a/src/uu/stty/src/stty.rs b/src/uu/stty/src/stty.rs index 8808857b630..f34f9b498e5 100644 --- a/src/uu/stty/src/stty.rs +++ b/src/uu/stty/src/stty.rs @@ -10,7 +10,7 @@ // spell-checker:ignore isig icanon iexten echoe crterase echok echonl noflsh xcase tostop echoprt prterase echoctl ctlecho echoke crtkill flusho extproc // spell-checker:ignore lnext rprnt susp swtch vdiscard veof veol verase vintr vkill vlnext vquit vreprint vstart vstop vsusp vswtc vwerase werase // spell-checker:ignore sigquit sigtstp -// spell-checker:ignore cbreak decctlq evenp litout oddp tcsadrain exta extb NCCS +// spell-checker:ignore cbreak decctlq evenp litout oddp tcsadrain exta extb NCCS cfsetispeed // spell-checker:ignore notaflag notacombo notabaud mod flags; @@ -21,7 +21,7 @@ use clap::{Arg, ArgAction, ArgMatches, Command}; use nix::libc::{O_NONBLOCK, TIOCGWINSZ, TIOCSWINSZ, c_ushort}; use nix::sys::termios::{ ControlFlags, InputFlags, LocalFlags, OutputFlags, SetArg, SpecialCharacterIndices as S, - Termios, cfgetospeed, cfsetospeed, tcgetattr, tcsetattr, + Termios, cfgetospeed, cfsetispeed, cfsetospeed, tcgetattr, tcsetattr, }; use nix::{ioctl_read_bad, ioctl_write_ptr_bad}; use std::cmp::Ordering; @@ -274,19 +274,24 @@ fn stty(opts: &Options) -> UResult<()> { let mut args_iter = args.iter(); while let Some(&arg) = args_iter.next() { match arg { - "ispeed" | "ospeed" => match args_iter.next() { + "ispeed" => match args_iter.next() { Some(speed) => { - if let Some(baud_flag) = string_to_baud(speed) { + if let Some(baud_flag) = string_to_baud(speed, flags::BaudType::Input) { valid_args.push(ArgOptions::Flags(baud_flag)); } else { - return Err(USimpleError::new( - 1, - translate!( - "stty-error-invalid-speed", - "arg" => *arg, - "speed" => *speed, - ), - )); + return invalid_speed(arg, speed); + } + } + None => { + return missing_arg(arg); + } + }, + "ospeed" => match args_iter.next() { + Some(speed) => { + if let Some(baud_flag) = string_to_baud(speed, flags::BaudType::Output) { + valid_args.push(ArgOptions::Flags(baud_flag)); + } else { + return invalid_speed(arg, speed); } } None => { @@ -383,12 +388,12 @@ fn stty(opts: &Options) -> UResult<()> { return missing_arg(arg); } // baud rate - } else if let Some(baud_flag) = string_to_baud(arg) { + } else if let Some(baud_flag) = string_to_baud(arg, flags::BaudType::Both) { valid_args.push(ArgOptions::Flags(baud_flag)); // non control char flag } else if let Some(flag) = string_to_flag(arg) { let remove_group = match flag { - AllFlags::Baud(_) => false, + AllFlags::Baud(_, _) => false, AllFlags::ControlFlags((flag, remove)) => { check_flag_group(flag, remove) } @@ -417,7 +422,7 @@ fn stty(opts: &Options) -> UResult<()> { for arg in &valid_args { match arg { ArgOptions::Mapping(mapping) => apply_char_mapping(&mut termios, mapping), - ArgOptions::Flags(flag) => apply_setting(&mut termios, flag), + ArgOptions::Flags(flag) => apply_setting(&mut termios, flag)?, ArgOptions::Special(setting) => { apply_special_setting(&mut termios, setting, opts.file.as_raw_fd())?; } @@ -468,6 +473,17 @@ fn invalid_integer_arg(arg: &str) -> Result> { )) } +fn invalid_speed(arg: &str, speed: &str) -> Result> { + Err(UUsageError::new( + 1, + translate!( + "stty-error-invalid-speed", + "arg" => arg, + "speed" => speed, + ), + )) +} + /// GNU uses different error messages if values overflow or underflow a u8, /// this function returns the appropriate error message in the case of overflow or underflow, or u8 on success fn parse_u8_or_err(arg: &str) -> Result { @@ -719,7 +735,7 @@ fn parse_baud_with_rounding(normalized: &str) -> Option { Some(value) } -fn string_to_baud(arg: &str) -> Option> { +fn string_to_baud(arg: &str, baud_type: flags::BaudType) -> Option> { // Reject invalid formats if arg != arg.trim_end() || arg.trim().starts_with('-') @@ -744,7 +760,7 @@ fn string_to_baud(arg: &str) -> Option> { target_os = "netbsd", target_os = "openbsd" ))] - return Some(AllFlags::Baud(value)); + return Some(AllFlags::Baud(value, baud_type)); #[cfg(not(any( target_os = "freebsd", @@ -757,7 +773,7 @@ fn string_to_baud(arg: &str) -> Option> { { for (text, baud_rate) in BAUD_RATES { if text.parse::().ok() == Some(value) { - return Some(AllFlags::Baud(*baud_rate)); + return Some(AllFlags::Baud(*baud_rate, baud_type)); } } None @@ -940,9 +956,9 @@ fn print_flags( } /// Apply a single setting -fn apply_setting(termios: &mut Termios, setting: &AllFlags) { +fn apply_setting(termios: &mut Termios, setting: &AllFlags) -> nix::Result<()> { match setting { - AllFlags::Baud(_) => apply_baud_rate_flag(termios, setting), + AllFlags::Baud(_, _) => apply_baud_rate_flag(termios, setting)?, AllFlags::ControlFlags((setting, disable)) => { setting.flag.apply(termios, !disable); } @@ -956,34 +972,21 @@ fn apply_setting(termios: &mut Termios, setting: &AllFlags) { setting.flag.apply(termios, !disable); } } + Ok(()) } -fn apply_baud_rate_flag(termios: &mut Termios, input: &AllFlags) { - // BSDs use a u32 for the baud rate, so any decimal number applies. - #[cfg(any( - target_os = "freebsd", - target_os = "dragonfly", - target_os = "ios", - target_os = "macos", - target_os = "netbsd", - target_os = "openbsd" - ))] - if let AllFlags::Baud(n) = input { - cfsetospeed(termios, *n).expect("Failed to set baud rate"); - } - - // Other platforms use an enum. - #[cfg(not(any( - target_os = "freebsd", - target_os = "dragonfly", - target_os = "ios", - target_os = "macos", - target_os = "netbsd", - target_os = "openbsd" - )))] - if let AllFlags::Baud(br) = input { - cfsetospeed(termios, *br).expect("Failed to set baud rate"); +fn apply_baud_rate_flag(termios: &mut Termios, input: &AllFlags) -> nix::Result<()> { + if let AllFlags::Baud(rate, baud_type) = input { + match baud_type { + flags::BaudType::Input => cfsetispeed(termios, *rate)?, + flags::BaudType::Output => cfsetospeed(termios, *rate)?, + flags::BaudType::Both => { + cfsetispeed(termios, *rate)?; + cfsetospeed(termios, *rate)?; + } + } } + Ok(()) } fn apply_char_mapping(termios: &mut Termios, mapping: &(S, u8)) { @@ -1446,10 +1449,10 @@ mod tests { target_os = "openbsd" )))] { - assert!(string_to_baud("9600").is_some()); - assert!(string_to_baud("115200").is_some()); - assert!(string_to_baud("38400").is_some()); - assert!(string_to_baud("19200").is_some()); + assert!(string_to_baud("9600", flags::BaudType::Both).is_some()); + assert!(string_to_baud("115200", flags::BaudType::Both).is_some()); + assert!(string_to_baud("38400", flags::BaudType::Both).is_some()); + assert!(string_to_baud("19200", flags::BaudType::Both).is_some()); } #[cfg(any( @@ -1461,10 +1464,10 @@ mod tests { target_os = "openbsd" ))] { - assert!(string_to_baud("9600").is_some()); - assert!(string_to_baud("115200").is_some()); - assert!(string_to_baud("1000000").is_some()); - assert!(string_to_baud("0").is_some()); + assert!(string_to_baud("9600", flags::BaudType::Both).is_some()); + assert!(string_to_baud("115200", flags::BaudType::Both).is_some()); + assert!(string_to_baud("1000000", flags::BaudType::Both).is_some()); + assert!(string_to_baud("0", flags::BaudType::Both).is_some()); } } @@ -1479,10 +1482,10 @@ mod tests { target_os = "openbsd" )))] { - assert_eq!(string_to_baud("995"), None); - assert_eq!(string_to_baud("invalid"), None); - assert_eq!(string_to_baud(""), None); - assert_eq!(string_to_baud("abc"), None); + assert_eq!(string_to_baud("995", flags::BaudType::Both), None); + assert_eq!(string_to_baud("invalid", flags::BaudType::Both), None); + assert_eq!(string_to_baud("", flags::BaudType::Both), None); + assert_eq!(string_to_baud("abc", flags::BaudType::Both), None); } } diff --git a/tests/by-util/test_stty.rs b/tests/by-util/test_stty.rs index ae64eb6aeb5..c2b8a77e5a2 100644 --- a/tests/by-util/test_stty.rs +++ b/tests/by-util/test_stty.rs @@ -1627,6 +1627,71 @@ fn test_stty_uses_stdin() { .stdout_contains("columns 100"); } +#[test] +#[cfg(unix)] +fn test_ispeed_ospeed_valid_speeds() { + let (path, _controller, _replica) = pty_path(); + let (_at, ts) = at_and_ts!(); + + // Test various valid baud rates for both ispeed and ospeed + let test_cases = [ + ("ispeed", "50"), + ("ispeed", "9600"), + ("ispeed", "19200"), + ("ospeed", "1200"), + ("ospeed", "9600"), + ("ospeed", "38400"), + ]; + + for (arg, speed) in test_cases { + let result = ts.ucmd().args(&["--file", &path, arg, speed]).run(); + let exp_result = unwrap_or_return!(expected_result(&ts, &["--file", &path, arg, speed])); + let normalized_stderr = normalize_stderr(result.stderr_str()); + + result + .stdout_is(exp_result.stdout_str()) + .code_is(exp_result.code()); + assert_eq!(normalized_stderr, exp_result.stderr_str()); + } +} + +#[test] +#[cfg(all( + unix, + not(any( + target_os = "freebsd", + target_os = "dragonfly", + target_os = "ios", + target_os = "macos", + target_os = "netbsd", + target_os = "openbsd" + )) +))] +#[ignore = "Issue: #9547"] +fn test_ispeed_ospeed_invalid_speeds() { + let (path, _controller, _replica) = pty_path(); + let (_at, ts) = at_and_ts!(); + + // Test invalid speed values (non-standard baud rates) + let test_cases = [ + ("ispeed", "12345"), + ("ospeed", "99999"), + ("ispeed", "abc"), + ("ospeed", "xyz"), + ]; + + for (arg, speed) in test_cases { + let result = ts.ucmd().args(&["--file", &path, arg, speed]).run(); + let exp_result = unwrap_or_return!(expected_result(&ts, &["--file", &path, arg, speed])); + let normalized_stderr = normalize_stderr(result.stderr_str()); + + result + .stdout_is(exp_result.stdout_str()) + .code_is(exp_result.code()); + assert_eq!(normalized_stderr, exp_result.stderr_str()); + } +} + #[test] #[cfg(unix)] fn test_columns_env_wrapping() {