@@ -10,6 +10,7 @@ namespace SimdUnicode
1010 public static class UTF8
1111 {
1212
13+
1314 // Returns &inputBuffer[inputLength] if the input buffer is valid.
1415 /// <summary>
1516 /// Given an input buffer <paramref name="pInputBuffer"/> of byte length <paramref name="inputLength"/>,
@@ -35,11 +36,10 @@ public static class UTF8
3536 {
3637 return GetPointerToFirstInvalidByteAvx512(pInputBuffer, inputLength);
3738 }*/
38- // if (Ssse3.IsSupported)
39- // {
40- // return GetPointerToFirstInvalidByteSse(pInputBuffer, inputLength);
41- // }
42- // return GetPointerToFirstInvalidByteScalar(pInputBuffer, inputLength);
39+ if ( Ssse3 . IsSupported )
40+ {
41+ return GetPointerToFirstInvalidByteSse ( pInputBuffer , inputLength , out Utf16CodeUnitCountAdjustment , out ScalarCodeUnitCountAdjustment ) ;
42+ }
4343
4444 return GetPointerToFirstInvalidByteScalar ( pInputBuffer , inputLength , out Utf16CodeUnitCountAdjustment , out ScalarCodeUnitCountAdjustment ) ;
4545
@@ -471,15 +471,13 @@ private unsafe static (int utfadjust, int scalaradjust) calculateErrorPathadjust
471471 return ( utfadjust , scalaradjust ) ;
472472 }
473473
474- public unsafe static byte * GetPointerToFirstInvalidByteSse ( byte * pInputBuffer , int inputLength )
474+ public unsafe static byte * GetPointerToFirstInvalidByteSse ( byte * pInputBuffer , int inputLength , out int utf16CodeUnitCountAdjustment , out int scalarCountAdjustment )
475475 {
476-
477476 int processedLength = 0 ;
478- int TempUtf16CodeUnitCountAdjustment = 0 ;
479- int TempScalarCountAdjustment = 0 ;
480-
481477 if ( pInputBuffer == null || inputLength <= 0 )
482478 {
479+ utf16CodeUnitCountAdjustment = 0 ;
480+ scalarCountAdjustment = 0 ;
483481 return pInputBuffer ;
484482 }
485483 if ( inputLength > 128 )
@@ -503,24 +501,24 @@ private unsafe static (int utfadjust, int scalaradjust) calculateErrorPathadjust
503501
504502 if ( processedLength + 16 < inputLength )
505503 {
506- // We still have work to do!
507504 Vector128 < byte > prevInputBlock = Vector128 < byte > . Zero ;
508505
509506 Vector128 < byte > maxValue = Vector128 . Create (
510507 255 , 255 , 255 , 255 , 255 , 255 , 255 , 255 ,
511508 255 , 255 , 255 , 255 , 255 , 0b11110000 - 1 , 0b11100000 - 1 , 0b11000000 - 1 ) ;
512- Vector128 < byte > prevIncomplete = Sse2 . SubtractSaturate ( prevInputBlock , maxValue ) ;
513-
509+ Vector128 < byte > prevIncomplete = Sse3 . SubtractSaturate ( prevInputBlock , maxValue ) ;
514510
515- Vector128 < byte > shuf1 = Vector128 . Create ( TOO_LONG , TOO_LONG , TOO_LONG , TOO_LONG ,
511+ Vector128 < byte > shuf1 = Vector128 . Create (
512+ TOO_LONG , TOO_LONG , TOO_LONG , TOO_LONG ,
516513 TOO_LONG , TOO_LONG , TOO_LONG , TOO_LONG ,
517514 TWO_CONTS , TWO_CONTS , TWO_CONTS , TWO_CONTS ,
518515 TOO_SHORT | OVERLONG_2 ,
519516 TOO_SHORT ,
520517 TOO_SHORT | OVERLONG_3 | SURROGATE ,
521518 TOO_SHORT | TOO_LARGE | TOO_LARGE_1000 | OVERLONG_4 ) ;
522519
523- Vector128 < byte > shuf2 = Vector128 . Create ( CARRY | OVERLONG_3 | OVERLONG_2 | OVERLONG_4 ,
520+ Vector128 < byte > shuf2 = Vector128 . Create (
521+ CARRY | OVERLONG_3 | OVERLONG_2 | OVERLONG_4 ,
524522 CARRY | OVERLONG_2 ,
525523 CARRY ,
526524 CARRY ,
@@ -536,7 +534,8 @@ private unsafe static (int utfadjust, int scalaradjust) calculateErrorPathadjust
536534 CARRY | TOO_LARGE | TOO_LARGE_1000 | SURROGATE ,
537535 CARRY | TOO_LARGE | TOO_LARGE_1000 ,
538536 CARRY | TOO_LARGE | TOO_LARGE_1000 ) ;
539- Vector128 < byte > shuf3 = Vector128 . Create ( TOO_SHORT , TOO_SHORT , TOO_SHORT , TOO_SHORT ,
537+ Vector128 < byte > shuf3 = Vector128 . Create (
538+ TOO_SHORT , TOO_SHORT , TOO_SHORT , TOO_SHORT ,
540539 TOO_SHORT , TOO_SHORT , TOO_SHORT , TOO_SHORT ,
541540 TOO_LONG | OVERLONG_2 | TWO_CONTS | OVERLONG_3 | TOO_LARGE_1000 | OVERLONG_4 ,
542541 TOO_LONG | OVERLONG_2 | TWO_CONTS | OVERLONG_3 | TOO_LARGE ,
@@ -548,24 +547,71 @@ private unsafe static (int utfadjust, int scalaradjust) calculateErrorPathadjust
548547 Vector128 < byte > fourthByte = Vector128 . Create ( ( byte ) ( 0b11110000u - 0x80 ) ) ;
549548 Vector128 < byte > v0f = Vector128 . Create ( ( byte ) 0x0F ) ;
550549 Vector128 < byte > v80 = Vector128 . Create ( ( byte ) 0x80 ) ;
550+ /****
551+ * So we want to count the number of 4-byte sequences,
552+ * the number of 4-byte sequences, 3-byte sequences, and
553+ * the number of 2-byte sequences.
554+ * We can do it indirectly. We know how many bytes in total
555+ * we have (length). Let us assume that the length covers
556+ * only complete sequences (we need to adjust otherwise).
557+ * We have that
558+ * length = 4 * n4 + 3 * n3 + 2 * n2 + n1
559+ * where n1 is the number of 1-byte sequences (ASCII),
560+ * n2 is the number of 2-byte sequences, n3 is the number
561+ * of 3-byte sequences, and n4 is the number of 4-byte sequences.
562+ *
563+ * Let ncon be the number of continuation bytes, then we have
564+ * length = n4 + n3 + n2 + ncon + n1
565+ *
566+ * We can solve for n2 and n3 in terms of the other variables:
567+ * n3 = n1 - 2 * n4 + 2 * ncon - length
568+ * n2 = -2 * n1 + n4 - 4 * ncon + 2 * length
569+ * Thus we only need to count the number of continuation bytes,
570+ * the number of ASCII bytes and the number of 4-byte sequences.
571+ */
572+ ////////////
573+ // The *block* here is what begins at processedLength and ends
574+ // at processedLength/16*16 or when an error occurs.
575+ ///////////
576+ int start_point = processedLength ;
577+
578+ // The block goes from processedLength to processedLength/16*16.
579+ int asciibytes = 0 ; // number of ascii bytes in the block (could also be called n1)
580+ int contbytes = 0 ; // number of continuation bytes in the block
581+ int n4 = 0 ; // number of 4-byte sequences that start in this block
551582 for ( ; processedLength + 16 <= inputLength ; processedLength += 16 )
552583 {
553584
554- Vector128 < byte > currentBlock = Sse2 . LoadVector128 ( pInputBuffer + processedLength ) ;
555-
556- int mask = Sse2 . MoveMask ( currentBlock ) ;
585+ Vector128 < byte > currentBlock = Avx . LoadVector128 ( pInputBuffer + processedLength ) ;
586+ int mask = Sse42 . MoveMask ( currentBlock ) ;
557587 if ( mask == 0 )
558588 {
559589 // We have an ASCII block, no need to process it, but
560590 // we need to check if the previous block was incomplete.
561- if ( Sse2 . MoveMask ( prevIncomplete ) != 0 )
591+ //
592+
593+ if ( ! Sse41 . TestZ ( prevIncomplete , prevIncomplete ) )
562594 {
563- return SimdUnicode . UTF8 . RewindAndValidateWithErrors ( processedLength , pInputBuffer + processedLength , inputLength - processedLength , ref TempUtf16CodeUnitCountAdjustment , ref TempScalarCountAdjustment ) ;
595+ int off = processedLength >= 3 ? processedLength - 3 : processedLength ;
596+ byte * invalidBytePointer = SimdUnicode . UTF8 . SimpleRewindAndValidateWithErrors ( 16 - 3 , pInputBuffer + processedLength - 3 , inputLength - processedLength + 3 ) ;
597+ // So the code is correct up to invalidBytePointer
598+ if ( invalidBytePointer < pInputBuffer + processedLength )
599+ {
600+ removeCounters ( invalidBytePointer , pInputBuffer + processedLength , ref asciibytes , ref n4 , ref contbytes ) ;
601+ }
602+ else
603+ {
604+ addCounters ( pInputBuffer + processedLength , invalidBytePointer , ref asciibytes , ref n4 , ref contbytes ) ;
605+ }
606+ int totalbyteasciierror = processedLength - start_point ;
607+ ( utf16CodeUnitCountAdjustment , scalarCountAdjustment ) = CalculateN2N3FinalSIMDAdjustments ( asciibytes , n4 , contbytes , totalbyteasciierror ) ;
608+ return invalidBytePointer ;
564609 }
565610 prevIncomplete = Vector128 < byte > . Zero ;
566611 }
567- else
612+ else // Contains non-ASCII characters, we need to do non-trivial processing
568613 {
614+ // Use SubtractSaturate to effectively compare if bytes in block are greater than markers.
569615 // Contains non-ASCII characters, we need to do non-trivial processing
570616 Vector128 < byte > prev1 = Ssse3 . AlignRight ( currentBlock , prevInputBlock , ( byte ) ( 16 - 1 ) ) ;
571617 Vector128 < byte > byte_1_high = Ssse3 . Shuffle ( shuf1 , Sse2 . ShiftRightLogical ( prev1 . AsUInt16 ( ) , 4 ) . AsByte ( ) & v0f ) ;
@@ -575,54 +621,93 @@ private unsafe static (int utfadjust, int scalaradjust) calculateErrorPathadjust
575621 Vector128 < byte > prev2 = Ssse3 . AlignRight ( currentBlock , prevInputBlock , ( byte ) ( 16 - 2 ) ) ;
576622 Vector128 < byte > prev3 = Ssse3 . AlignRight ( currentBlock , prevInputBlock , ( byte ) ( 16 - 3 ) ) ;
577623 prevInputBlock = currentBlock ;
624+
578625 Vector128 < byte > isThirdByte = Sse2 . SubtractSaturate ( prev2 , thirdByte ) ;
579626 Vector128 < byte > isFourthByte = Sse2 . SubtractSaturate ( prev3 , fourthByte ) ;
580627 Vector128 < byte > must23 = Sse2 . Or ( isThirdByte , isFourthByte ) ;
581628 Vector128 < byte > must23As80 = Sse2 . And ( must23 , v80 ) ;
582629 Vector128 < byte > error = Sse2 . Xor ( must23As80 , sc ) ;
583- if ( Sse2 . MoveMask ( error ) != 0 )
630+
631+ if ( ! Sse42 . TestZ ( error , error ) )
584632 {
585- return SimdUnicode . UTF8 . RewindAndValidateWithErrors ( processedLength , pInputBuffer + processedLength , inputLength - processedLength , ref TempUtf16CodeUnitCountAdjustment , ref TempScalarCountAdjustment ) ;
633+
634+ byte * invalidBytePointer ;
635+ if ( processedLength == 0 )
636+ {
637+ invalidBytePointer = SimdUnicode . UTF8 . SimpleRewindAndValidateWithErrors ( 0 , pInputBuffer + processedLength , inputLength - processedLength ) ;
638+ }
639+ else
640+ {
641+ invalidBytePointer = SimdUnicode . UTF8 . SimpleRewindAndValidateWithErrors ( processedLength - 3 , pInputBuffer + processedLength - 3 , inputLength - processedLength + 3 ) ;
642+ }
643+ if ( invalidBytePointer < pInputBuffer + processedLength )
644+ {
645+ removeCounters ( invalidBytePointer , pInputBuffer + processedLength , ref asciibytes , ref n4 , ref contbytes ) ;
646+ }
647+ else
648+ {
649+ addCounters ( pInputBuffer + processedLength , invalidBytePointer , ref asciibytes , ref n4 , ref contbytes ) ;
650+ }
651+ int total_bytes_processed = ( int ) ( invalidBytePointer - ( pInputBuffer + start_point ) ) ;
652+ ( utf16CodeUnitCountAdjustment , scalarCountAdjustment ) = CalculateN2N3FinalSIMDAdjustments ( asciibytes , n4 , contbytes , total_bytes_processed ) ;
653+ return invalidBytePointer ;
586654 }
587- prevIncomplete = Sse2 . SubtractSaturate ( currentBlock , maxValue ) ;
655+
656+ prevIncomplete = Sse3 . SubtractSaturate ( currentBlock , maxValue ) ;
657+
658+ contbytes += ( int ) Popcnt . PopCount ( ( uint ) Sse42 . MoveMask ( byte_2_high ) ) ;
659+ // We use two instructions (SubtractSaturate and MoveMask) to update n4, with one arithmetic operation.
660+ n4 += ( int ) Popcnt . PopCount ( ( uint ) Sse42 . MoveMask ( Sse42 . SubtractSaturate ( currentBlock , fourthByte ) ) ) ;
588661 }
662+
663+ // important: we just update asciibytes if there was no error.
664+ // We count the number of ascii bytes in the block using just some simple arithmetic
665+ // and no expensive operation:
666+ asciibytes += ( int ) ( 16 - Popcnt . PopCount ( ( uint ) mask ) ) ;
589667 }
590- }
591- }
592- // We have processed all the blocks using SIMD, we need to process the remaining bytes.
593668
594- // Process the remaining bytes with the scalar function
595- if ( processedLength < inputLength )
596- {
597- // We need to possibly backtrack to the start of the last code point
598- // worst possible case is 4 bytes, where we need to backtrack 3 bytes
599- // 11110xxxx 10xxxxxx 10xxxxxx 10xxxxxx <== we might be pointing at the last byte
600- if ( processedLength > 0 && ( sbyte ) pInputBuffer [ processedLength ] <= - 65 )
601- {
602- processedLength -= 1 ;
603- if ( processedLength > 0 && ( sbyte ) pInputBuffer [ processedLength ] <= - 65 )
669+
670+ // We may still have an error.
671+ if ( processedLength < inputLength || ! Sse42 . TestZ ( prevIncomplete , prevIncomplete ) )
604672 {
605- processedLength -= 1 ;
606- if ( processedLength > 0 && ( sbyte ) pInputBuffer [ processedLength ] <= - 65 )
673+ byte * invalidBytePointer ;
674+ if ( processedLength == 0 )
675+ {
676+ invalidBytePointer = SimdUnicode . UTF8 . SimpleRewindAndValidateWithErrors ( 0 , pInputBuffer + processedLength , inputLength - processedLength ) ;
677+ }
678+ else
679+ {
680+ invalidBytePointer = SimdUnicode . UTF8 . SimpleRewindAndValidateWithErrors ( processedLength - 3 , pInputBuffer + processedLength - 3 , inputLength - processedLength + 3 ) ;
681+
682+ }
683+ if ( invalidBytePointer != pInputBuffer + inputLength )
684+ {
685+ if ( invalidBytePointer < pInputBuffer + processedLength )
686+ {
687+ removeCounters ( invalidBytePointer , pInputBuffer + processedLength , ref asciibytes , ref n4 , ref contbytes ) ;
688+ }
689+ else
690+ {
691+ addCounters ( pInputBuffer + processedLength , invalidBytePointer , ref asciibytes , ref n4 , ref contbytes ) ;
692+ }
693+ int total_bytes_processed = ( int ) ( invalidBytePointer - ( pInputBuffer + start_point ) ) ;
694+ ( utf16CodeUnitCountAdjustment , scalarCountAdjustment ) = CalculateN2N3FinalSIMDAdjustments ( asciibytes , n4 , contbytes , total_bytes_processed ) ;
695+ return invalidBytePointer ;
696+ }
697+ else
607698 {
608- processedLength -= 1 ;
699+ addCounters ( pInputBuffer + processedLength , invalidBytePointer , ref asciibytes , ref n4 , ref contbytes ) ;
609700 }
610701 }
611- }
612- int TailScalarCodeUnitCountAdjustment = 0 ;
613- int TailUtf16CodeUnitCountAdjustment = 0 ;
614- byte * invalidBytePointer = SimdUnicode . UTF8 . GetPointerToFirstInvalidByteScalar ( pInputBuffer + processedLength , inputLength - processedLength , out TailUtf16CodeUnitCountAdjustment , out TailScalarCodeUnitCountAdjustment ) ;
615- if ( invalidBytePointer != pInputBuffer + inputLength )
616- {
617- // An invalid byte was found by the scalar function
618- return invalidBytePointer ;
702+ int final_total_bytes_processed = inputLength - start_point ;
703+ ( utf16CodeUnitCountAdjustment , scalarCountAdjustment ) = CalculateN2N3FinalSIMDAdjustments ( asciibytes , n4 , contbytes , final_total_bytes_processed ) ;
704+ return pInputBuffer + inputLength ;
619705 }
620706 }
621-
622- return pInputBuffer + inputLength ;
707+ return GetPointerToFirstInvalidByteScalar ( pInputBuffer + processedLength , inputLength - processedLength , out utf16CodeUnitCountAdjustment , out scalarCountAdjustment ) ;
623708 }
624709
625-
710+ //
626711 public unsafe static byte * GetPointerToFirstInvalidByteAvx2 ( byte * pInputBuffer , int inputLength , out int utf16CodeUnitCountAdjustment , out int scalarCountAdjustment )
627712 {
628713 int processedLength = 0 ;
0 commit comments