Add specialized SIMD line seeking routines (#408)
Some checks are pending
CI / check (ubuntu-latest) (push) Waiting to run
CI / check (windows-latest) (push) Waiting to run

The previous `memchr` loop had the fatal flaw that it would break out
of the SIMD routines every time it hit a newline. This resulted in a
throughput drop down to ~250MB/s on my system in the worst case.
By writing SIMD routines specific to newline seeking, we can bump
that up by >500x. Navigating through a 1GB of text now takes ~16ms
independent of the contents.
This commit is contained in:
Leonard Hecker 2025-06-05 19:34:07 +02:00 committed by GitHub
parent 6a7ff206a2
commit 065fa748cf
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
9 changed files with 643 additions and 334 deletions

View file

@ -3,7 +3,7 @@
use std::hint::black_box;
use std::io::Cursor;
use std::mem;
use std::{mem, vec};
use criterion::{BenchmarkId, Criterion, Throughput, criterion_group, criterion_main};
use edit::helpers::*;
@ -133,18 +133,36 @@ fn bench_oklab(c: &mut Criterion) {
});
}
fn bench_simd_lines_fwd(c: &mut Criterion) {
let mut group = c.benchmark_group("simd");
let buf = vec![b'\n'; 128 * MEBI];
for &lines in &[1, 8, 128, KIBI, 128 * KIBI, 128 * MEBI] {
group.throughput(Throughput::Bytes(lines as u64)).bench_with_input(
BenchmarkId::new("lines_fwd", lines),
&lines,
|b, &lines| {
b.iter(|| simd::lines_fwd(black_box(&buf), 0, 0, lines as CoordType));
},
);
}
}
fn bench_simd_memchr2(c: &mut Criterion) {
let mut group = c.benchmark_group("simd");
let mut buffer_u8 = [0u8; 2048];
let mut buf = vec![0u8; 128 * MEBI + KIBI];
for &bytes in &[8usize, 32 + 8, 64 + 8, KIBI + 8] {
// For small sizes we add a small offset of +8,
// to ensure we also benchmark the non-SIMD tail handling.
// For large sizes, its relative impact is negligible.
for &bytes in &[8usize, 128 + 8, KIBI, 128 * KIBI, 128 * MEBI] {
group.throughput(Throughput::Bytes(bytes as u64 + 1)).bench_with_input(
BenchmarkId::new("memchr2", bytes),
&bytes,
|b, &size| {
buffer_u8.fill(b'a');
buffer_u8[size] = b'\n';
b.iter(|| simd::memchr2(b'\n', b'\r', black_box(&buffer_u8), 0));
buf.fill(b'a');
buf[size] = b'\n';
b.iter(|| simd::memchr2(b'\n', b'\r', black_box(&buf), 0));
},
);
}
@ -154,9 +172,12 @@ fn bench_simd_memset<T: MemsetSafe + Copy + Default>(c: &mut Criterion) {
let mut group = c.benchmark_group("simd");
let name = format!("memset<{}>", std::any::type_name::<T>());
let size = mem::size_of::<T>();
let mut buf: Vec<T> = vec![Default::default(); 2048 / size];
let mut buf: Vec<T> = vec![Default::default(); 128 * MEBI / size];
for &bytes in &[8usize, 32 + 8, 64 + 8, KIBI + 8] {
// For small sizes we add a small offset of +8,
// to ensure we also benchmark the non-SIMD tail handling.
// For large sizes, its relative impact is negligible.
for &bytes in &[8usize, 128 + 8, KIBI, 128 * KIBI, 128 * MEBI] {
group.throughput(Throughput::Bytes(bytes as u64)).bench_with_input(
BenchmarkId::new(&name, bytes),
&bytes,
@ -206,6 +227,7 @@ fn bench(c: &mut Criterion) {
bench_buffer(c);
bench_hash(c);
bench_oklab(c);
bench_simd_lines_fwd(c);
bench_simd_memchr2(c);
bench_simd_memset::<u32>(c);
bench_simd_memset::<u8>(c);

View file

@ -8,7 +8,6 @@ use std::path::{Path, PathBuf};
use edit::buffer::{RcTextBuffer, TextBuffer};
use edit::helpers::{CoordType, Point};
use edit::simd::memrchr2;
use edit::{apperr, path, sys};
use crate::state::DisplayablePathBuf;
@ -244,8 +243,12 @@ impl DocumentManager {
Some(num)
}
fn find_colon_rev(bytes: &[u8], offset: usize) -> Option<usize> {
(0..offset.min(bytes.len())).rev().find(|&i| bytes[i] == b':')
}
let bytes = path.as_os_str().as_encoded_bytes();
let colend = match memrchr2(b':', b':', bytes, bytes.len()) {
let colend = match find_colon_rev(bytes, bytes.len()) {
// Reject filenames that would result in an empty filename after stripping off the :line:char suffix.
// For instance, a filename like ":123:456" will not be processed by this function.
Some(colend) if colend > 0 => colend,
@ -260,7 +263,7 @@ impl DocumentManager {
let mut len = colend;
let mut goto = Point { x: 0, y: last };
if let Some(colbeg) = memrchr2(b':', b':', bytes, colend) {
if let Some(colbeg) = find_colon_rev(bytes, colend) {
// Same here: Don't allow empty filenames.
if colbeg != 0
&& let Some(first) = parse(&bytes[colbeg + 1..colend])

View file

@ -44,7 +44,7 @@ use crate::helpers::*;
use crate::oklab::oklab_blend;
use crate::simd::memchr2;
use crate::unicode::{self, Cursor, MeasurementConfig};
use crate::{apperr, icu};
use crate::{apperr, icu, simd};
/// The margin template is used for line numbers.
/// The max. line number we should ever expect is probably 64-bit,
@ -341,7 +341,7 @@ impl TextBuffer {
break 'outer;
}
let (delta, line) = unicode::newlines_forward(chunk, 0, 0, 1);
let (delta, line) = simd::lines_fwd(chunk, 0, 0, 1);
off += delta;
if line == 1 {
break;
@ -684,7 +684,7 @@ impl TextBuffer {
}
}
(offset, lines) = unicode::newlines_forward(chunk, offset, lines, lines + 1);
(offset, lines) = simd::lines_fwd(chunk, offset, lines, lines + 1);
// Check if the preceding line ended in CRLF.
if offset >= 2 && &chunk[offset - 2..offset] == b"\r\n" {
@ -723,7 +723,7 @@ impl TextBuffer {
// If the file has more than 1000 lines, figure out how many are remaining.
if offset < chunk.len() {
(_, lines) = unicode::newlines_forward(chunk, offset, lines, CoordType::MAX);
(_, lines) = simd::lines_fwd(chunk, offset, lines, CoordType::MAX);
}
let final_newline = chunk.ends_with(b"\n");
@ -1219,7 +1219,7 @@ impl TextBuffer {
break;
}
let (delta, line) = unicode::newlines_forward(chunk, 0, result.logical_pos.y, y);
let (delta, line) = simd::lines_fwd(chunk, 0, result.logical_pos.y, y);
result.offset += delta;
result.logical_pos.y = line;
}
@ -1239,8 +1239,7 @@ impl TextBuffer {
break;
}
let (delta, line) =
unicode::newlines_backward(chunk, chunk.len(), result.logical_pos.y, y);
let (delta, line) = simd::lines_bwd(chunk, chunk.len(), result.logical_pos.y, y);
result.offset -= chunk.len() - delta;
result.logical_pos.y = line;
if delta > 0 {
@ -2082,7 +2081,7 @@ impl TextBuffer {
selection_end.x -= remove as CoordType;
}
(offset, y) = unicode::newlines_forward(&replacement, offset, y, y + 1);
(offset, y) = simd::lines_fwd(&replacement, offset, y, y + 1);
}
if replacement.len() == initial_len {
@ -2376,7 +2375,7 @@ impl TextBuffer {
let mut offset = cursor.offset;
while beg < added.len() {
let (end, line) = unicode::newlines_forward(added, beg, 0, 1);
let (end, line) = simd::lines_fwd(added, beg, 0, 1);
let has_newline = line != 0;
let link = &added[beg..end];
let line = unicode::strip_newline(link);

283
src/simd/lines_bwd.rs Normal file
View file

@ -0,0 +1,283 @@
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT License.
use std::ptr;
use crate::helpers::CoordType;
/// Starting from the `offset` in `haystack` with a current line index of
/// `line`, this seeks backwards to the `line_stop`-nth line and returns the
/// new offset and the line index at that point.
///
/// Note that this function differs from `lines_fwd` in that it
/// seeks backwards even if the `line` is already at `line_stop`.
/// This allows you to ensure (or test) whether `offset` is at a line start.
///
/// It returns an offset *past* a newline and thus at the start of a line.
pub fn lines_bwd(
haystack: &[u8],
offset: usize,
line: CoordType,
line_stop: CoordType,
) -> (usize, CoordType) {
unsafe {
let beg = haystack.as_ptr();
let it = beg.add(offset.min(haystack.len()));
let (it, line) = lines_bwd_raw(beg, it, line, line_stop);
(it.offset_from_unsigned(beg), line)
}
}
unsafe fn lines_bwd_raw(
beg: *const u8,
end: *const u8,
line: CoordType,
line_stop: CoordType,
) -> (*const u8, CoordType) {
#[cfg(target_arch = "x86_64")]
return unsafe { LINES_BWD_DISPATCH(beg, end, line, line_stop) };
#[cfg(target_arch = "aarch64")]
return unsafe { lines_bwd_neon(beg, end, line, line_stop) };
#[allow(unreachable_code)]
return unsafe { lines_bwd_fallback(beg, end, line, line_stop) };
}
unsafe fn lines_bwd_fallback(
beg: *const u8,
mut end: *const u8,
mut line: CoordType,
line_stop: CoordType,
) -> (*const u8, CoordType) {
unsafe {
while !ptr::eq(end, beg) {
let n = end.sub(1);
if *n == b'\n' {
if line <= line_stop {
break;
}
line -= 1;
}
end = n;
}
(end, line)
}
}
#[cfg(target_arch = "x86_64")]
static mut LINES_BWD_DISPATCH: unsafe fn(
beg: *const u8,
end: *const u8,
line: CoordType,
line_stop: CoordType,
) -> (*const u8, CoordType) = lines_bwd_dispatch;
#[cfg(target_arch = "x86_64")]
unsafe fn lines_bwd_dispatch(
beg: *const u8,
end: *const u8,
line: CoordType,
line_stop: CoordType,
) -> (*const u8, CoordType) {
let func = if is_x86_feature_detected!("avx2") { lines_bwd_avx2 } else { lines_bwd_fallback };
unsafe { LINES_BWD_DISPATCH = func };
unsafe { func(beg, end, line, line_stop) }
}
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2")]
unsafe fn lines_bwd_avx2(
beg: *const u8,
mut end: *const u8,
mut line: CoordType,
line_stop: CoordType,
) -> (*const u8, CoordType) {
unsafe {
use std::arch::x86_64::*;
#[inline(always)]
unsafe fn horizontal_sum_i64(v: __m256i) -> i64 {
unsafe {
let hi = _mm256_extracti128_si256::<1>(v);
let lo = _mm256_castsi256_si128(v);
let sum = _mm_add_epi64(lo, hi);
let shuf = _mm_shuffle_epi32::<0b11_10_11_10>(sum);
let sum = _mm_add_epi64(sum, shuf);
_mm_cvtsi128_si64(sum)
}
}
let lf = _mm256_set1_epi8(b'\n' as i8);
let line_stop = line_stop.min(line);
let mut remaining = end.offset_from_unsigned(beg);
while remaining >= 128 {
let chunk_start = end.sub(128);
let v1 = _mm256_loadu_si256(chunk_start.add(0) as *const _);
let v2 = _mm256_loadu_si256(chunk_start.add(32) as *const _);
let v3 = _mm256_loadu_si256(chunk_start.add(64) as *const _);
let v4 = _mm256_loadu_si256(chunk_start.add(96) as *const _);
let mut sum = _mm256_setzero_si256();
sum = _mm256_sub_epi8(sum, _mm256_cmpeq_epi8(v1, lf));
sum = _mm256_sub_epi8(sum, _mm256_cmpeq_epi8(v2, lf));
sum = _mm256_sub_epi8(sum, _mm256_cmpeq_epi8(v3, lf));
sum = _mm256_sub_epi8(sum, _mm256_cmpeq_epi8(v4, lf));
let sum = _mm256_sad_epu8(sum, _mm256_setzero_si256());
let sum = horizontal_sum_i64(sum);
let line_next = line - sum as CoordType;
if line_next <= line_stop {
break;
}
end = chunk_start;
remaining -= 128;
line = line_next;
}
while remaining >= 32 {
let chunk_start = end.sub(32);
let v = _mm256_loadu_si256(chunk_start as *const _);
let c = _mm256_cmpeq_epi8(v, lf);
let ones = _mm256_and_si256(c, _mm256_set1_epi8(0x01));
let sum = _mm256_sad_epu8(ones, _mm256_setzero_si256());
let sum = horizontal_sum_i64(sum);
let line_next = line - sum as CoordType;
if line_next <= line_stop {
break;
}
end = chunk_start;
remaining -= 32;
line = line_next;
}
lines_bwd_fallback(beg, end, line, line_stop)
}
}
#[cfg(target_arch = "aarch64")]
unsafe fn lines_bwd_neon(
beg: *const u8,
mut end: *const u8,
mut line: CoordType,
line_stop: CoordType,
) -> (*const u8, CoordType) {
unsafe {
use std::arch::aarch64::*;
let lf = vdupq_n_u8(b'\n');
let line_stop = line_stop.min(line);
let mut remaining = end.offset_from_unsigned(beg);
while remaining >= 64 {
let chunk_start = end.sub(64);
let v1 = vld1q_u8(chunk_start.add(0));
let v2 = vld1q_u8(chunk_start.add(16));
let v3 = vld1q_u8(chunk_start.add(32));
let v4 = vld1q_u8(chunk_start.add(48));
let mut sum = vdupq_n_u8(0);
sum = vsubq_u8(sum, vceqq_u8(v1, lf));
sum = vsubq_u8(sum, vceqq_u8(v2, lf));
sum = vsubq_u8(sum, vceqq_u8(v3, lf));
sum = vsubq_u8(sum, vceqq_u8(v4, lf));
let sum = vaddvq_u8(sum);
let line_next = line - sum as CoordType;
if line_next <= line_stop {
break;
}
end = chunk_start;
remaining -= 64;
line = line_next;
}
while remaining >= 16 {
let chunk_start = end.sub(16);
let v = vld1q_u8(chunk_start);
let c = vceqq_u8(v, lf);
let c = vandq_u8(c, vdupq_n_u8(0x01));
let sum = vaddvq_u8(c);
let line_next = line - sum as CoordType;
if line_next <= line_stop {
break;
}
end = chunk_start;
remaining -= 16;
line = line_next;
}
lines_bwd_fallback(beg, end, line, line_stop)
}
}
#[cfg(test)]
mod test {
use super::*;
use crate::helpers::CoordType;
use crate::simd::test::*;
#[test]
fn pseudo_fuzz() {
let text = generate_random_text(1024);
let lines = count_lines(&text);
let mut offset_rng = make_rng();
let mut line_rng = make_rng();
let mut line_distance_rng = make_rng();
for _ in 0..1000 {
let offset = offset_rng() % (text.len() + 1);
let line_stop = line_distance_rng() % (lines + 1);
let line = line_stop + line_rng() % 100;
let line = line as CoordType;
let line_stop = line_stop as CoordType;
let expected = reference_lines_bwd(text.as_bytes(), offset, line, line_stop);
let actual = lines_bwd(text.as_bytes(), offset, line, line_stop);
assert_eq!(expected, actual);
}
}
fn reference_lines_bwd(
haystack: &[u8],
mut offset: usize,
mut line: CoordType,
line_stop: CoordType,
) -> (usize, CoordType) {
if line >= line_stop {
while offset > 0 {
let c = haystack[offset - 1];
if c == b'\n' {
if line == line_stop {
break;
}
line -= 1;
}
offset -= 1;
}
}
(offset, line)
}
#[test]
fn seeks_to_start() {
for i in 6..=11 {
let (off, line) = lines_bwd(b"Hello\nWorld\n", i, 123, 456);
assert_eq!(off, 6); // After "Hello\n"
assert_eq!(line, 123); // Still on the same line
}
}
}

281
src/simd/lines_fwd.rs Normal file
View file

@ -0,0 +1,281 @@
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT License.
use std::ptr;
use crate::helpers::CoordType;
/// Starting from the `offset` in `haystack` with a current line index of
/// `line`, this seeks to the `line_stop`-nth line and returns the
/// new offset and the line index at that point.
///
/// It returns an offset *past* the newline.
/// If `line` is already at or past `line_stop`, it returns immediately.
pub fn lines_fwd(
haystack: &[u8],
offset: usize,
line: CoordType,
line_stop: CoordType,
) -> (usize, CoordType) {
unsafe {
let beg = haystack.as_ptr();
let end = beg.add(haystack.len());
let it = beg.add(offset.min(haystack.len()));
let (it, line) = lines_fwd_raw(it, end, line, line_stop);
(it.offset_from_unsigned(beg), line)
}
}
unsafe fn lines_fwd_raw(
beg: *const u8,
end: *const u8,
line: CoordType,
line_stop: CoordType,
) -> (*const u8, CoordType) {
#[cfg(target_arch = "x86_64")]
return unsafe { LINES_FWD_DISPATCH(beg, end, line, line_stop) };
#[cfg(target_arch = "aarch64")]
return unsafe { lines_fwd_neon(beg, end, line, line_stop) };
#[allow(unreachable_code)]
return unsafe { lines_fwd_fallback(beg, end, line, line_stop) };
}
unsafe fn lines_fwd_fallback(
mut beg: *const u8,
end: *const u8,
mut line: CoordType,
line_stop: CoordType,
) -> (*const u8, CoordType) {
unsafe {
if line < line_stop {
while !ptr::eq(beg, end) {
let c = *beg;
beg = beg.add(1);
if c == b'\n' {
line += 1;
if line == line_stop {
break;
}
}
}
}
(beg, line)
}
}
#[cfg(target_arch = "x86_64")]
static mut LINES_FWD_DISPATCH: unsafe fn(
beg: *const u8,
end: *const u8,
line: CoordType,
line_stop: CoordType,
) -> (*const u8, CoordType) = lines_fwd_dispatch;
#[cfg(target_arch = "x86_64")]
unsafe fn lines_fwd_dispatch(
beg: *const u8,
end: *const u8,
line: CoordType,
line_stop: CoordType,
) -> (*const u8, CoordType) {
let func = if is_x86_feature_detected!("avx2") { lines_fwd_avx2 } else { lines_fwd_fallback };
unsafe { LINES_FWD_DISPATCH = func };
unsafe { func(beg, end, line, line_stop) }
}
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2")]
unsafe fn lines_fwd_avx2(
mut beg: *const u8,
end: *const u8,
mut line: CoordType,
line_stop: CoordType,
) -> (*const u8, CoordType) {
unsafe {
use std::arch::x86_64::*;
#[inline(always)]
unsafe fn horizontal_sum_i64(v: __m256i) -> i64 {
unsafe {
let hi = _mm256_extracti128_si256::<1>(v);
let lo = _mm256_castsi256_si128(v);
let sum = _mm_add_epi64(lo, hi);
let shuf = _mm_shuffle_epi32::<0b11_10_11_10>(sum);
let sum = _mm_add_epi64(sum, shuf);
_mm_cvtsi128_si64(sum)
}
}
let lf = _mm256_set1_epi8(b'\n' as i8);
let mut remaining = end.offset_from_unsigned(beg);
if line < line_stop {
// Unrolling the loop by 4x speeds things up by >3x.
// It allows us to accumulate matches before doing a single `vpsadbw`.
while remaining >= 128 {
let v1 = _mm256_loadu_si256(beg.add(0) as *const _);
let v2 = _mm256_loadu_si256(beg.add(32) as *const _);
let v3 = _mm256_loadu_si256(beg.add(64) as *const _);
let v4 = _mm256_loadu_si256(beg.add(96) as *const _);
// `vpcmpeqb` leaves each comparison result byte as 0 or -1 (0xff).
// This allows us to accumulate the comparisons by subtracting them.
let mut sum = _mm256_setzero_si256();
sum = _mm256_sub_epi8(sum, _mm256_cmpeq_epi8(v1, lf));
sum = _mm256_sub_epi8(sum, _mm256_cmpeq_epi8(v2, lf));
sum = _mm256_sub_epi8(sum, _mm256_cmpeq_epi8(v3, lf));
sum = _mm256_sub_epi8(sum, _mm256_cmpeq_epi8(v4, lf));
// Calculate the total number of matches in this chunk.
let sum = _mm256_sad_epu8(sum, _mm256_setzero_si256());
let sum = horizontal_sum_i64(sum);
let line_next = line + sum as CoordType;
if line_next >= line_stop {
break;
}
beg = beg.add(128);
remaining -= 128;
line = line_next;
}
while remaining >= 32 {
let v = _mm256_loadu_si256(beg as *const _);
let c = _mm256_cmpeq_epi8(v, lf);
// If you ask an LLM, the best way to do this is
// to do a `vpmovmskb` followed by `popcnt`.
// One contemporary hardware that's a bad idea though.
let ones = _mm256_and_si256(c, _mm256_set1_epi8(0x01));
let sum = _mm256_sad_epu8(ones, _mm256_setzero_si256());
let sum = horizontal_sum_i64(sum);
let line_next = line + sum as CoordType;
if line_next >= line_stop {
break;
}
beg = beg.add(32);
remaining -= 32;
line = line_next;
}
}
lines_fwd_fallback(beg, end, line, line_stop)
}
}
#[cfg(target_arch = "aarch64")]
unsafe fn lines_fwd_neon(
mut beg: *const u8,
end: *const u8,
mut line: CoordType,
line_stop: CoordType,
) -> (*const u8, CoordType) {
unsafe {
use std::arch::aarch64::*;
let lf = vdupq_n_u8(b'\n');
let mut remaining = end.offset_from_unsigned(beg);
if line < line_stop {
while remaining >= 64 {
let v1 = vld1q_u8(beg.add(0));
let v2 = vld1q_u8(beg.add(16));
let v3 = vld1q_u8(beg.add(32));
let v4 = vld1q_u8(beg.add(48));
// `vceqq_u8` leaves each comparison result byte as 0 or -1 (0xff).
// This allows us to accumulate the comparisons by subtracting them.
let mut sum = vdupq_n_u8(0);
sum = vsubq_u8(sum, vceqq_u8(v1, lf));
sum = vsubq_u8(sum, vceqq_u8(v2, lf));
sum = vsubq_u8(sum, vceqq_u8(v3, lf));
sum = vsubq_u8(sum, vceqq_u8(v4, lf));
let sum = vaddvq_u8(sum);
let line_next = line + sum as CoordType;
if line_next >= line_stop {
break;
}
beg = beg.add(64);
remaining -= 64;
line = line_next;
}
while remaining >= 16 {
let v = vld1q_u8(beg);
let c = vceqq_u8(v, lf);
let c = vandq_u8(c, vdupq_n_u8(0x01));
let sum = vaddvq_u8(c);
let line_next = line + sum as CoordType;
if line_next >= line_stop {
break;
}
beg = beg.add(16);
remaining -= 16;
line = line_next;
}
}
lines_fwd_fallback(beg, end, line, line_stop)
}
}
#[cfg(test)]
mod test {
use super::*;
use crate::helpers::CoordType;
use crate::simd::test::*;
#[test]
fn pseudo_fuzz() {
let text = generate_random_text(1024);
let lines = count_lines(&text);
let mut offset_rng = make_rng();
let mut line_rng = make_rng();
let mut line_distance_rng = make_rng();
for _ in 0..1000 {
let offset = offset_rng() % (text.len() + 1);
let line = line_rng() % 100;
let line_stop = line + line_distance_rng() % (lines + 1);
let line = line as CoordType;
let line_stop = line_stop as CoordType;
let expected = reference_lines_fwd(text.as_bytes(), offset, line, line_stop);
let actual = lines_fwd(text.as_bytes(), offset, line, line_stop);
assert_eq!(expected, actual);
}
}
fn reference_lines_fwd(
haystack: &[u8],
mut offset: usize,
mut line: CoordType,
line_stop: CoordType,
) -> (usize, CoordType) {
if line < line_stop {
while offset < haystack.len() {
let c = haystack[offset];
offset += 1;
if c == b'\n' {
line += 1;
if line == line_stop {
break;
}
}
}
}
(offset, line)
}
}

View file

@ -1,194 +0,0 @@
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT License.
//! `memchr`, but with two needles.
use std::ptr;
/// `memchr`, but with two needles.
///
/// If no needle is found, 0 is returned.
/// Unlike `memchr2` (or `memrchr`), an offset PAST the hit is returned.
/// This is because this function is primarily used for
/// `ucd::newlines_backward`, which needs exactly that.
pub fn memrchr2(needle1: u8, needle2: u8, haystack: &[u8], offset: usize) -> Option<usize> {
unsafe {
let beg = haystack.as_ptr();
let it = beg.add(offset.min(haystack.len()));
let it = memrchr2_raw(needle1, needle2, beg, it);
if it.is_null() { None } else { Some(it.offset_from_unsigned(beg)) }
}
}
unsafe fn memrchr2_raw(needle1: u8, needle2: u8, beg: *const u8, end: *const u8) -> *const u8 {
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
return unsafe { MEMRCHR2_DISPATCH(needle1, needle2, beg, end) };
#[cfg(target_arch = "aarch64")]
return unsafe { memrchr2_neon(needle1, needle2, beg, end) };
#[allow(unreachable_code)]
return unsafe { memrchr2_fallback(needle1, needle2, beg, end) };
}
unsafe fn memrchr2_fallback(
needle1: u8,
needle2: u8,
beg: *const u8,
mut end: *const u8,
) -> *const u8 {
unsafe {
while !ptr::eq(end, beg) {
end = end.sub(1);
let ch = *end;
if ch == needle1 || needle2 == ch {
return end;
}
}
ptr::null()
}
}
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
static mut MEMRCHR2_DISPATCH: unsafe fn(
needle1: u8,
needle2: u8,
beg: *const u8,
end: *const u8,
) -> *const u8 = memrchr2_dispatch;
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
unsafe fn memrchr2_dispatch(needle1: u8, needle2: u8, beg: *const u8, end: *const u8) -> *const u8 {
let func = if is_x86_feature_detected!("avx2") { memrchr2_avx2 } else { memrchr2_fallback };
unsafe { MEMRCHR2_DISPATCH = func };
unsafe { func(needle1, needle2, beg, end) }
}
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[target_feature(enable = "avx2")]
unsafe fn memrchr2_avx2(needle1: u8, needle2: u8, beg: *const u8, mut end: *const u8) -> *const u8 {
unsafe {
#[cfg(target_arch = "x86")]
use std::arch::x86::*;
#[cfg(target_arch = "x86_64")]
use std::arch::x86_64::*;
if end.offset_from_unsigned(beg) >= 32 {
let n1 = _mm256_set1_epi8(needle1 as i8);
let n2 = _mm256_set1_epi8(needle2 as i8);
loop {
end = end.sub(32);
let v = _mm256_loadu_si256(end as *const _);
let a = _mm256_cmpeq_epi8(v, n1);
let b = _mm256_cmpeq_epi8(v, n2);
let c = _mm256_or_si256(a, b);
let m = _mm256_movemask_epi8(c) as u32;
if m != 0 {
return end.add(31 - m.leading_zeros() as usize);
}
if end.offset_from_unsigned(beg) < 32 {
break;
}
}
}
memrchr2_fallback(needle1, needle2, beg, end)
}
}
#[cfg(target_arch = "aarch64")]
unsafe fn memrchr2_neon(needle1: u8, needle2: u8, beg: *const u8, mut end: *const u8) -> *const u8 {
unsafe {
use std::arch::aarch64::*;
if end.offset_from_unsigned(beg) >= 16 {
let n1 = vdupq_n_u8(needle1);
let n2 = vdupq_n_u8(needle2);
loop {
end = end.sub(16);
let v = vld1q_u8(end as *const _);
let a = vceqq_u8(v, n1);
let b = vceqq_u8(v, n2);
let c = vorrq_u8(a, b);
// https://community.arm.com/arm-community-blogs/b/servers-and-cloud-computing-blog/posts/porting-x86-vector-bitmask-optimizations-to-arm-neon
let m = vreinterpretq_u16_u8(c);
let m = vshrn_n_u16(m, 4);
let m = vreinterpret_u64_u8(m);
let m = vget_lane_u64(m, 0);
if m != 0 {
return end.add(15 - (m.leading_zeros() as usize >> 2));
}
if end.offset_from_unsigned(beg) < 16 {
break;
}
}
}
memrchr2_fallback(needle1, needle2, beg, end)
}
}
#[cfg(test)]
mod tests {
use std::slice;
use super::*;
use crate::sys;
#[test]
fn test_empty() {
assert_eq!(memrchr2(b'a', b'b', b"", 0), None);
}
#[test]
fn test_basic() {
let haystack = b"abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ";
let haystack = &haystack[..43];
assert_eq!(memrchr2(b'Q', b'P', haystack, 43), Some(42));
assert_eq!(memrchr2(b'p', b'o', haystack, 43), Some(15));
assert_eq!(memrchr2(b'a', b'b', haystack, 43), Some(1));
assert_eq!(memrchr2(b'0', b'9', haystack, 43), None);
}
// Test that it doesn't match before/after the start offset respectively.
#[test]
fn test_with_offset() {
let haystack = b"abcdefghabcdefghabcdefghabcdefghabcdefgh";
assert_eq!(memrchr2(b'h', b'g', haystack, 40), Some(39));
assert_eq!(memrchr2(b'h', b'g', haystack, 39), Some(38));
assert_eq!(memrchr2(b'a', b'b', haystack, 9), Some(8));
assert_eq!(memrchr2(b'a', b'b', haystack, 1), Some(0));
assert_eq!(memrchr2(b'a', b'b', haystack, 0), None);
}
// Test memory access safety at page boundaries.
// The test is a success if it doesn't segfault.
#[test]
fn test_page_boundary() {
let page = unsafe {
const PAGE_SIZE: usize = 64 * 1024; // 64 KiB to cover many architectures.
// 3 pages: uncommitted, committed, uncommitted
let ptr = sys::virtual_reserve(PAGE_SIZE * 3).unwrap();
sys::virtual_commit(ptr.add(PAGE_SIZE), PAGE_SIZE).unwrap();
slice::from_raw_parts_mut(ptr.add(PAGE_SIZE).as_ptr(), PAGE_SIZE)
};
page.fill(b'a');
// Same as above, but for memrchr2 (hence reversed).
assert_eq!(memrchr2(b'\0', b'\0', &page[page.len() - 10..], 10), None);
assert_eq!(memrchr2(b'\0', b'\0', &page[..40], 40), None);
}
}

View file

@ -3,10 +3,41 @@
//! Provides various high-throughput utilities.
pub mod lines_bwd;
pub mod lines_fwd;
mod memchr2;
mod memrchr2;
mod memset;
pub use lines_bwd::*;
pub use lines_fwd::*;
pub use memchr2::*;
pub use memrchr2::*;
pub use memset::*;
#[cfg(test)]
mod test {
// Knuth's MMIX LCG
pub fn make_rng() -> impl FnMut() -> usize {
let mut state = 1442695040888963407u64;
move || {
state = state.wrapping_mul(6364136223846793005).wrapping_add(1442695040888963407);
state as usize
}
}
pub fn generate_random_text(len: usize) -> String {
const ALPHABET: &[u8; 20] = b"0123456789abcdef\n\n\n\n";
let mut rng = make_rng();
let mut res = String::new();
for _ in 0..len {
res.push(ALPHABET[rng() % ALPHABET.len()] as char);
}
res
}
pub fn count_lines(text: &str) -> usize {
text.lines().count()
}
}

View file

@ -157,7 +157,7 @@ use crate::framebuffer::{Attributes, Framebuffer, INDEXED_COLORS_COUNT, IndexedC
use crate::hash::*;
use crate::helpers::*;
use crate::input::{InputKeyMod, kbmod, vk};
use crate::{apperr, arena_format, input, unicode};
use crate::{apperr, arena_format, input, simd, unicode};
const ROOT_ID: u64 = 0x14057B7EF767814F; // Knuth's MMIX constant
const SHIFT_TAB: InputKey = vk::TAB.with_modifiers(kbmod::SHIFT);
@ -2690,7 +2690,7 @@ impl<'a> Context<'a, '_> {
}
if single_line && !write.is_empty() {
let (end, _) = unicode::newlines_forward(write, 0, 0, 1);
let (end, _) = simd::lines_fwd(write, 0, 0, 1);
write = unicode::strip_newline(&write[..end]);
}
if !write.is_empty() {

View file

@ -7,7 +7,6 @@ use super::Utf8Chars;
use super::tables::*;
use crate::document::ReadableDocument;
use crate::helpers::{CoordType, Point};
use crate::simd::{memchr2, memrchr2};
// On one hand it's disgusting that I wrote this as a global variable, but on the
// other hand, this isn't a public library API, and it makes the code a lot cleaner,
@ -478,104 +477,6 @@ impl<'doc> MeasurementConfig<'doc> {
}
}
/// Seeks forward to the given line start.
///
/// If given a piece of `text`, and assuming you're currently at `offset` which
/// is on the logical line `line`, this will seek forward until the logical line
/// `line_stop` is reached. For instance, if `line` is 0 and `line_stop` is 2,
/// it'll seek forward past 2 line feeds.
///
/// This function always stops exactly past a line feed
/// and thus returns a position at the start of a line.
///
/// # Warning
///
/// If the end of `text` is hit before reaching `line_stop`, the function
/// will return an offset of `text.len()`, not at the start of a line.
///
/// # Parameters
///
/// * `text`: The text to search in.
/// * `offset`: The offset to start searching from.
/// * `line`: The current line.
/// * `line_stop`: The line to stop at.
///
/// # Returns
///
/// A tuple consisting of:
/// * The new offset.
/// * The line number that was reached.
pub fn newlines_forward(
text: &[u8],
mut offset: usize,
mut line: CoordType,
line_stop: CoordType,
) -> (usize, CoordType) {
// Leaving the cursor at the beginning of the current line when the limit
// is 0 makes this function behave identical to ucd_newlines_backward.
if line >= line_stop {
return newlines_backward(text, offset, line, line_stop);
}
let len = text.len();
offset = offset.min(len);
loop {
// TODO: This code could be optimized by replacing memchr with manual line counting.
//
// If `line_stop` is very far away, we could accumulate newline counts horizontally
// in a AVX2 register (= 32 u8 slots). Then, every 256 bytes we compute the horizontal
// sum via `_mm256_sad_epu8` yielding us the newline count in the last block.
//
// We could also just use `_mm256_sad_epu8` on each fetch as-is.
offset = memchr2(b'\n', b'\n', text, offset);
if offset >= len {
break;
}
offset += 1;
line += 1;
if line >= line_stop {
break;
}
}
(offset, line)
}
/// Seeks backward to the given line start.
///
/// See [`newlines_forward`] for details.
/// This function does almost the same thing, but in reverse.
///
/// # Warning
///
/// In addition to the notes in [`newlines_forward`]:
///
/// No matter what parameters are given, [`newlines_backward`] only returns an
/// offset at the start of a line. Put differently, even if `line == line_stop`,
/// it'll seek backward to the line start.
pub fn newlines_backward(
text: &[u8],
mut offset: usize,
mut line: CoordType,
line_stop: CoordType,
) -> (usize, CoordType) {
offset = offset.min(text.len());
loop {
offset = match memrchr2(b'\n', b'\n', text, offset) {
Some(i) => i,
None => return (0, line),
};
if line <= line_stop {
// +1: Past the newline, at the start of the current line.
return (offset + 1, line);
}
line -= 1;
}
}
/// Returns an offset past a newline.
///
/// If `offset` is right in front of a newline,
@ -1152,23 +1053,6 @@ mod test {
);
}
#[test]
fn test_newlines_and_strip() {
// Offset line 0: 0
// Offset line 1: 6
// Offset line 2: 13
// Offset line 3: 18
let text = "line1\nline2\r\nline3".as_bytes();
assert_eq!(newlines_forward(text, 0, 0, 2), (13, 2));
assert_eq!(newlines_forward(text, 0, 0, 0), (0, 0));
assert_eq!(newlines_forward(text, 100, 2, 100), (18, 2));
assert_eq!(newlines_backward(text, 18, 2, 1), (6, 1));
assert_eq!(newlines_backward(text, 18, 2, 0), (0, 0));
assert_eq!(newlines_backward(text, 100, 2, 1), (6, 1));
}
#[test]
fn test_strip_newline() {
assert_eq!(strip_newline(b"hello\n"), b"hello");