Piège des nombres à virgules flottante

J’ai l’habitude d’être très prudent avec la gestion des nombres à virgule flottante, notamment lorsqu’il y a des allers-retours avec des nombres entiers, mais je me suis récemment fait avoir comme un bleu sous R, probablement parce qu’une interruption de plusieurs jours dans mon écriture du script m’a fait oublier l’origine des données, ainsi qu’une autre problématique ayant détourné mon attention à un moment critique.

Afin de bien présenter cette problématique, je vais exposer tous les éléments en jeu en essayant de ne pas demander trop de pré-requis.

Représentation binaire d’un nombre entier

Commençons par décrire la représentation décimale des nombres entiers qu’on apprend à l’école primaire. Un nombre comme 532 comporte les trois chiffres « cinq », « trois » et « deux » dans le système décimal. Le nombre 532 est simplement égal à 5×100 + 3×10 + 2×1 = 5×10^2 + 3×10^1 + 2×10^0. De droite à gauche, les trois chiffres ont un « poids » croissant, selon les puissances de dix. Le chiffre 2 a un poids d’une unité, alors que le chiffre 3 représente les dizaines et le chiffre 5 représente les centaines. On peut représenter n’importe quel nombre entier positif ainsi. Le choix d’avoir dix chiffres (les chiffres de 0 à 9) plutôt que douze ou n’importe quel autre nombre est arbitraire. Une représentation bien plus simple pour les ordinateurs est la représentation binaire dans laquelle il n’y a que deux chiffres (0 et 1) et dans laquelle le poids des différents chiffres d’une représentation numérique correspondant à des puissances de deux. Ainsi, la représentation numérique binaire 11010 sera interprétée comme 1×2^4 + 1×2^3 + 0×2^2 + 1×2^1 + 0×2^0 = 1×16 + 1×8 + 0×4 + 1×2 + 0×0 = 16+8+2 = 26. Une autre représentation, assez populaire en informatique, est la représentation hexadécimale dans laquelle on a seize chiffres : 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, A, B, C, D, E, F. Par exemple, la représentation numérique hexadécimale B16 correspondra au nombre entier 11×16^2 + 1×16^1 + 6×16^0 = 11×256 + 1×16 + 6×1 = 2838.

Pour la plupart de leurs calculs, les ordinateurs utiliseront, en interne, un format binaire avec un nombre fixé de chiffres. Par exemple, un ordinateur 32 bits manipulera majoritairement des nombres avec exactement 32 chiffres binaires, pas un de plus, ni un de moins. Le terme de bit est ici synonyme de chiffre binaire. Pour représenter les petits nombres, on utilisera des zéros comme préfixe. Par exemple, la valeur trois sera représentée par la séquence de 32 chiffres binaires suivante:

0000 0000 0000 0000 0000 0000 0000 0011

(les espaces ici sont seulement présentés pour faciliter la lecture mais ne sont pas signifiants)

En lisant cette représentation de droite à gauche, on peut calculer qu’elle représente le nombre 1×2^0 + 1×2^1 + 0×2^2 + 0×2^3 + … + 0×2^31 = 1 + 2 + 0 + 0 + … + 0 = 3

Le plus grand nombre que l’on puisse représenter ainsi ne contient que des 1 et est égal à la somme de toutes les puissances entières de deux, de 2^0 jusqu’à 2^31. On peut aisément montrer par récurrence que ce nombre est égal à 2^32 – 1, soit 4294967295 ; ce qui est suffisamment grand pour les décomptes de tous les jours. D’une manière générale, avec N bits, on peut représenter tous les nombre entiers de 0 jusqu’à 2^N-1, pour un total de 2^N nombres distincts. Il faut dire qu’une séquence ordonnée de N chiffres binaires dispose d’exactement 2^N combinaisons possibles (cela se démontre aisément par récurrence) si on considère que l’ordre a de l’importance, c’est-à-dire que 0011 n’est pas équivalent à 1100 ni à 1010. Pour les nombres entiers relatifs (positifs ou négatifs), il existe une représentation particulièrement astucieuse (représentation en complément à deux) si on connaît ℤ/nℤ, que je ne décrirai pas ici, car ce n’est pas nécessaire à mon exposé.

Un ordinateur moderne 64 bits, va principalement manipuler des nombres entiers 8 bits, 32 bits et 64 bits, selon le contexte.

Nombres à virgule flottante

Représentation

R, comme à peu près tous les logiciels et langages de programmation gèrent principalement les nombres réels de manière approximative sous forme de nombres à virgule flottante, communément appelés floating point values, ou FP. Cela permet de représenter approximativement des nombres comme 0.537877 ou 8.11215×10^56.

Les FP sont omniprésents en informatique : que ce soit sur une calculatrice, le processeur graphique (GPU) qui affiche un jeu vidéo sur mobile, PC, Mac ou console, n’importe quelle vidéo sur n’importe quelle plateforme postérieure à l’an 2000, ainsi que de nombreux formats de compression de sons ou d’image avec perte ; enfin, c’est encore utilisé par les réseaux neuronaux convolutionnels. C’est pourquoi, sur une vie entière au XXIème siècle, on est forcément confronté à un moment ou à un autre aux problèmes de ces représentations numériques. Enfin, un statisticien est amené à manipuler des FP dès qu’il utilise des variables quantitatives continues ou qu’il calcule des moyennes, pourcentages, odds ratio, hazard ratio ou toute autre statistique qui ne soit pas structurellement un nombre entier. Même si on peut utiliser des FP pendant des années sans en connaître les problèmes, on se cassera forcément une dent dessus à un moment où à un autre de sa carrière, parce que les FP sont structurellement imparfaits : ils tentent de résoudre au moins mal le problème insoluble de la représentation des nombres réels, appartenant à un ensemble infini indénombrable, dans un espace numérique fini et très borné.

Comme pour les nombres entiers, les représentations utiliseront typiquement des formats avec un nombre de bits bien prédéfini. R utilisera majoritairement des nombres à virgule flottante à double précision sur 64 bits, connus sous le nom de FP64 (conformes à la norme IEEE-754). Ces 64 bits sont divisés en trois partie : la mantisse, sur 52 bits, l’exposant sur 11 bits et le signe, sur 1 bit, pour un total de 52+11+1 = 64 bits. Avec ces trois composants, on représentera le nombre : signe × mantisse × 2^exposant. Le signe vaudra -1 ou +1. La mantisse sera un nombre entier positif. L’exposant sera un nombre entier positif ou négatif. Par exemple on pourra représenter -0,75 comme -1×3×2^-2. On constate, de prime abord, qu’il y a plusieurs représentations possibles pour le même nombre. On pourra représenter 16 avec une mantisse à 16 et un exposant à 0 (16×2^0), ou avec une mantisse à 8 et un exposant à 1 (16×2^1), ou encore, avec une mantisse à 4 et un exposant à 2 (4×2^2). Il existe un processus de normalisation, qui force l’usage d’une seule représentation avec une mantisse qui débute par le chiffre binaire 1 (sauf valeur trop petite, appelée dénormale). Ce bit à 1 est implicite dans la représentation binaire, ce qui permet d’avoir une précision de 53 bits sur la mantisse quand bien même elle n’est représentée que sur 52 bits. En bref, en dédoublonnant les représentations, on gagne en espace de représentation. C’est assez technique, mais ça ne change pas le fond du problème.

Problèmes (bugs insolubles) des nombres à virgules flottante

Cette représentation de nombre à virgule flottante souffre des mêmes forces et défaut que la représentation « scientifique » à quatre chiffres significatifs que vous avez peut-être apprise au lycée. Dans cette représentation, on va représenter n’importe quel nombre sous la forme A,BCD×10^EF où A, B, C, D, E et F représentent des chiffres de 0 à 9, sauf A qui vaudra au moins 1 ; sauf pour représenter zéro. Cette représentation est très précise, en absolu, pour les tous petits nombres. Par exemple, on pourra exprimer un volume de 1,567 microlitre comme 1,567×10^-6 litre ; ici, on a la précision du nanolitre. Cette même représentation sera beaucoup moins précise pour les gros nombres. Par exemple, mille deux-cent trois mètres cubes seront exprimés comme 1,203×10^6 litres, avec une précision du mètre cube. Si on veut ajouter quinze litres aux mille deux-cent trois mètres cubes, on va obtenir 1,203015×10^6 litres qui seront arrondis à 1,203×10^6 litres pour ne garder que les quatre chiffres significatifs selon le principe d’arrondi à la plus proche valeur représentable. Au final, le nombre est inchangé par l’addition des quinze litres, en raison de l’erreur d’arrondi finale. Ce problème d’arrondi touche aussi bien les nombres binaires à virgule flottante tels que les FP64. Même s’ils sont capables de représenter des nombres très proches de zéro, tels que 1/(2^1000), soit 2^-1000, et des nombres très grands, tels que 2^1000, l’addition de chiffres qui sont d’ordre de grandeur très différent, conduit à une perte de précision dommageable. Notamment, si on ajoute un nombre très grand à un nombre très petit, puis qu’on retranche le nombre très grand, alors le nombre très petit peu perdre énormément en précision, voire être annulé.

Ci-dessous, des exemples sous R:

> 1e9 + 42 - 1e9 # ça reste tolérable
[1] 42
> 1e17 + 42 - 1e17 # on commence à perdre
[1] 48
> 5e17 + 42 - 5e17 # ça fait mal
[1] 64
> 1e20 + 42 - 1e20 # on perd tout
[1] 0

L’ordre dans lequel on fait les calculs se met à avoir une importance !

> 5e20 - 5e20 + 42
[1] 42
> 1e20 + 42 - 1e20
[1] 0

En effet, la première ligne commence par le calcul (5e20 – 5e20) qui fait zéro, qui est ensuite additionné à 42, alors que la seconde ligne commence par 1e20+42 qui fait 1e20, puis retranche 1e20, ce qui conduit à la valeur zéro.

Un autre problème, c’est l’impossibilité de représenter de manière exacte certaines fractions pourtant simples, telles qu’un tiers. En effet, un tiers est approximé, en décimal, par 0.3333333333. Pour une représentation exacte, il faudrait une infinité de chiffres après la virgule. Même si 1/5 peut être représenté en décimal, il ne peut pas être représenté en binaire, parce que 5 est premier avec 2.

C’est ainsi, que des calculs très simples peuvent conduire à des résultats légèrement erronés:

> (1/5-1+1)*5 - 1
[1] -2.220446e-16

On peut se dire que le résultat est « suffisamment » proche de la bonne réponse (zéro) pour que ce soit tolérable, mais ça peut, en réalité, conduire à des réponses totalement erronées dès qu’on utilise l’opérateur d’égalité:

> (1/5-1+1)*5 == 1
[1] FALSE

Si R disposait d’un moteur de calcul exact, il afficherait TRUE ici. Sa réponse, d’une certaine manière, est complètement erronée.

Pour rendre les choses plus compliquées, R n’est pas honnête sur le vrai contenue des valeurs numériques, les arrondissant à quelques chiffres après la virgule (7 chiffres significatifs par défaut) pour l’affichage comme le montre le code ci-dessous:

> x=(1/5-1+1)*5
> x # on dirait qu'on a bien la valeur 1
[1] 1
> 1 == x # mais ce n'est pas vraiment 1
[1] FALSE
> x-1 # en fait, c'est un poil inférieur à 1
[1] -2.220446e-16
> 1-1 # alors que le "vrai" 1 est bien égal à 1
[1] 0

Ce comportement peut être corrigé par l’option digits, comme décrit ci-dessous:

> options(digits=22)
> (1/5-1+1)*5
[1] 0.99999999999999978
> 1/5
[1] 0.20000000000000001
> 1/5-1+1
[1] 0.19999999999999996

Là, on comprend beaucoup mieux les erreurs d’arrondi !

Il faut donc considérer que TOUT calcul avec nombre à virgule flottante est approximatif et que deux manières très légèrement différentes d’arriver au même calcul conduiront à deux approximations légèrement différentes. En fait, dans certains langages de programmation (p.e. le langage C), on n’a même pas la garantie qu’en réalisant deux fois exactement le même calcul on arrive au même résultat, parce que le compilateur peut traduire le code de manière différente.

Pour « tester » l’égalité, il faut toujours se donner des marges de manoeuvre, en comparant la valeur absolue de la différence à une valeur seuil correspondant à l’erreur maximale attendue.

Par exemple, plutôt que de tester (1/5-1+1)*5 == 1, on testera abs((1/5-1+1)*5 – 1)<1e-14.

On ne peut pas donner de valeur seuil universelle, parce que la bonne constante dépend de l’échelle sur laquelle on se trouve et de l’erreur cumulée des calculs. C’est-à-dire, que quand on manipule des nombres de l’ordre de 10^20, une erreur de plusieurs unités, voire plusieurs dizaines est rapidement arrivée, alors que quand on manipule correctement des nombres de l’ordre de 10^-20, alors on s’attend à avoir des erreurs inférieures à 10^-30. Parfois, on peut comparer le rapport entre les deux nombres à une valeur précise, mais ça ne marche que si les nombres restent écartés, d’une certaine manière, de la valeur zéro. Ce problème de connaissance de la bonne échelle se retrouve dans des problématiques apparentées, telles que la détection de convergence par stabilisation d’un paramètre (cf https://bugs.r-project.org/bugzilla/show_bug.cgi?id=17885).

Il faut aussi savoir que même lorsqu’on manipule des nombres du même ordre de grandeur, les erreurs d’arrondi ont tendance à se cumuler. Même si les erreurs positives compensent en partie les erreurs négatives, grace à la méthode d’arrondi au plus proche, plus on fera d’opérations, plus l’erreur sur le résultat final sera importante.

Il existe une seule situation, fréquente dans R, pour laquelle les calculs peuvent être considérés comme exacts : lorsqu’on manipule des nombres entiers suffisamment petits (< 2^52) sous forme de nombre à virgule flottante, en veillant à ne faire aucune opération qui les fasse passer en nombre non entiers. En effet, dans R, si on oublie de préciser un suffixe ‘L’ derrière une constante numérique, elle représente un nombre à virgule flottante.

> storage.mode(42L) # vrai nombre entier
[1] "integer"
> storage.mode(42) # nombre à virgule flottante
[1] "double"

La manipulation du nombre à virgule flottante 42 est relativement sûre. J’appelerai ces nombres à virgule flottante représentant des entiers, des pseudo-entiers.

> 42*2 - 42 == 42 # ça reste sûr
[1] TRUE

Les opérations acceptables avec les pseudo-entiers sont : l’addition, la soustraction, la multiplication. L’exponentiation par l’opérateur ^ est acceptable si on reste dans des valeurs entières pas trop grandes (en restant en dessous de 2^52 pour le résultat final) et qu’on reste sur plateforme PC. En effet, elle repose sur la fonction pow() en langage C dont je crains qu’elle soit fortement approximative sur beaucoup de plateformes, en passant par des fonctions logarithme et exponentielles qui n’aient pas toute la précision des FP64. Cet opérateur ^ est toujours risqué car il n’existe pas en version vraiment entière mais seulement sur les nombres à virgule flottante, comme le montre le code suivant:

> storage.mode(3L^3L)
[1] "double"

Les pseudo-entiers ne marchent plus correctement à partir de 2^54:

> x=18014398509481984
> x+1 == x
[1] TRUE

R fournit quelques points de repères sur la précision des nombres à virgule flottantes dans l’objet .Machine.

Par exemple la constante .Machine$double.eps environ égale à 2.22e-16 est la plus petite valeur x telle que 1 + x != 1 selon la documentation. En réalité, ce n’est pas tout à fait vrai, il s’agirait plutôt de la plus petite valeur telle que 1 + x -1 == x.

Représentations spéciales

Outre les problèmes d’arrondi, les FP ont aussi des problèmes d’overflow et underflow parce que l’exposant est limité (variant de -1023 à +1024). La valeur 1,8×10^308 est le plus grand nombre représentable sur FP64 alors que 5×10^-324 est le nombre le plus petit (proche de zéro) représentable. Si on dépasse le plus grand nombre, on fait un « overflow » qui conduit à une valeur spéciale Inf représentant l’infini. Si on passe en dessous du plus petit nombre représentable, alors on atteint la valeur zéro par un « underflow ».

> 1e308*100
[1] Inf
> 1e-324/100 == 0
[1] TRUE

Il existe deux zéros différents et deux infinis différents : positifs et négatifs, avec des relations assez attendues entre les quatre:

> 1/0
[1] Inf
> 1/-0
[1] -Inf
> 1/Inf
[1] 0
> 1/(1/Inf)
[1] Inf
> 1/-Inf
[1] 0
> 1/(1/-Inf)
[1] -Inf

Néanmoins, les deux zéros sont égaux sans que leurs inverses le soient !

> pos=0
> neg=-0
> pos == neg
[1] TRUE
> 1/pos == 1/neg
[1] FALSE
> sign(pos)
[1] 0
> sign(neg)
[1] 0

La valeur Inf représente donc un nombre supérieur à tout ce qui est représentable mais inconnu. Alors que le +0 représente un nombre positif inférieur à tout ce qui est représentable mais inconnu. Certains calculs, tels que Inf – Inf conduisent à un résultat totalement inconnu qui est représenté par une valeur spéciale NaN. Cette valeur NaN se lit « Not a Number ».

> Inf - Inf
[1] NaN

Ce NaN se propage un peu de la même manière que les données manquantes (NA). La fonction is.na reconnaît d’ailleurs ces NaN comme des données manquantes. La fonction is.nan, par contre est spécifique des NaN.

Cas rapporté

Revenons à mon cas, qui illustre bien la perversité des FP si on n’est pas attentif. J’avais plusieurs fichiers de données correspondant à la capture de mouvement de plusieurs mires sur un mouvement de marche de plusieurs sujets. La capture de mouvement était faite cent fois par seconde par des caméras. Le timing de certains événements était enregistré en secondes, avec deux chiffres après la virgule. Ainsi, le temps 1,75 seconde correspondait à la frame numéro 175 de l’enregistrement. Une simple multiplication par 100 permettait de passer du timing enregistré au numéro de frame.

Pour un des calculs, j’avais besoin de calculer, sur un intervalle temporelle, la différence entre la coordonnée Z (hauteur) minimale observée sur l’intervalle [T1 ; T2] et la coordonnée Z observée à la fin de l’intervalle, c’est-à-dire au temps T2. Étant donné que T2 fait partie de l’intervalle, cette différence était obligatoirement positive ou nulle. J’observais pourtant des valeurs négatives ! Cela me conduit à un bout de code ressemblant à peu près à :

> min(Z[T1:T2])
[1] 23.78
> Z[T2]
[1] 21.14

(j’ai changé les noms des variables et les valeurs numériques, mais l’idée reste là)

Voilà à peu près ce qu’ont donné les investigations:

> T1 # ok
[1] 10
> T2 # ok
[1] 15
> T1:T2 # semble ok
[1] 10 11 12 13 14 15
> Z=1:15 # ça va simplifier l'interprétation
> Z[T2] # toujours ok
[1] 15
> Z[T1:T2] # tout est décalé vers le bas !
[1]  9 10 11 12 13 14
> Z[10:15]# mais là ça marche 
[1] 10 11 12 13 14 15

Pouvez-vous deviner le problème ?

Il vient de T1 et T2 qui ne sont pas vraiment entiers mais flottant, très proches de nombres entiers, d’une manière particulièrement perverse:

> T1-10
[1] -1.065814e-14
> T2-15
[1] 1.065814e-14

Et voilà !

Avec options(digits=22) on comprend encore mieux ce qui se passe:

> options(digits=22)
> T1
[1] 9.9999999999999893
> T2
[1] 15.000000000000011
> T1:T2
[1]  9.9999999999999893 10.9999999999999893 11.9999999999999893
[4] 12.9999999999999893 13.9999999999999893 14.9999999999999893
> as.integer(T2)
[1] 15
> as.integer(T1:T2)
[1]  9 10 11 12 13 14

Il faut comprendre que l’opérateur deux-points est un raccourci pour la fonction seq(from,to,by) qui marche avec des nombres non entiers. C’est ainsi, qu’on peut générer des séquences de nombres, en additionnant répététivement by à from, sans dépasser to. Ainsi, seq(0, 1, 0.40) va générer la sequence c(0, 0.40, 0.80). Eh bien, A:B est équivalent à seq(A, B, 1). Ici, le décalage vers le bas de T1 est répercuté sur la séquence entière. Ensuite, lorsqu’on indexe un vecteur par des indices non entiers, ils sont implicitement convertis par la fonction as.integer() qui tronque la virgule, c’est-à-dire, qui arrondit vers le bas pour les nombres positifs et vers le haut pour les nombres négatifs.

Le problème provenait des timings T1 et T2 que j’avais oublié d’arrondir au plus proche après la multiplication par 100 pour passer des secondes aux centisecondes.

Une fois détecté, le problème est trivial à résoudre, avec un simple appel à la fonction round().

Conséquences

Cette erreur a eu des conséquences non négligeables. J’ai rendu un rapport d’analyse statistique préliminaire dont les résultats étaient très légèrement erronés, car ils étaient fondés sur des valeurs qui avaient été calculés avec des décalages accidentels d’un centième de seconde. Ce n’est que lorsque j’ai voulu calculer des statistiques post hoc, que j’ai découvert l’incohérence ; celle-ci m’a ensuite conduit à découvrir le bug. Heureusement, une fois l’incohérence découverte l’investigation du bug n’a demandé que quelques minutes.

Ce qu’il faut en retenir

Il faut bien connaître les FP, parce qu’on va forcément se les prendre dans la face un jour ou l’autre. Je conseille aussi d’avoir une approche très dichotomique et systématique au débogage. Non seulement, cela permet de corriger son bug, mais parfois aussi d’en découvrir dans le logiciel et de faire un beau bug report !

Est-ce un bug de R ? Certains diraient que it works as designed ce à que je répondrais it is badly designed. Plutôt qu’un bug report, cela mériterait une feature request, pour demander au beau message d’erreur pour des indiçages aussi foireux.

Pour aller plus loin

Je conseille le site Web suivant: https://randomascii.wordpress.com/2012/04/21/exceptional-floating-point/

Update (14/05/2021)

Le logiciel R utilise pour certaines routines, sur plateforme PC (i386 ou x86_64), les FP80 de précision étendue offerts par l’interface x87 des microprocesseurs des PC plutôt que l’interface SSE2 qui implémente des FP64. C’est ainsi, que certaines routines fonctionnent avec une précision excédentaire.

> x=0;for(i in 1:1e9) {x=x+1/3};print(x); # FP64
[1] 333333332.6651181
> sum(rep(1/3,1e9)) # FP80 en interne
[1] 333333333.33365959
> 1/3*1e9 # pas de cumul d'erreur (en relatif)
[1] 333333333.33333331

Avec les FP80, l’erreur cumulée est 2048 fois plus faible, révélant une mantisse 63 bits plutôt que 52 bits, soit 11 bits supplémentaires, conduisant à une erreur 2^11 = 2048 fois plus faible.

Les fonctions rowSums, colSums, rowMeans, colMeans fonctionnent aussi en FP80 sur plateforme PC. On peut imaginer qu’on gagnerait nettement en performance si R utilisait des FP64 avec du SSE2 (SIMD 2×FP64) ou de l’AVX2 (SIMD 4×FP64), voire de l’AVX-512 (SIMD 8×FP64), mais en réalité, même pour un logiciel numérique comme R, le gain n’en vaut pas la chandelle par rapport au gain de précision fourni par les vieux FP80 des années 80 qui ont fait la renommée du Pentium. Pour ceux qui voudraient absolument booster les performances au détriment de la précision, il existe rowsums dans le package Rfast.

> Rfast::rowsums(matrix(nrow=1,rep(1/3,1e9)))
[1] 333333332.6651181

Mais le gain de performances est loin d’être extraordinaire sur des énormes vecteurs qui débordent largement de la cache L3, même lorsque la préparation de la matrice est minimaliste (testé sur AMD Ryzen 1700 qui dispose de 2 unités 128 bits capables de faire des additions 2×FP64):

> system.time(Rfast::colsums(matrix(nrow=1e9,ncol=1,1/3)))
   user  system elapsed 
   1.58    0.41    1.99 
> system.time(colSums(matrix(nrow=1e9,ncol=1,1/3)))
   user  system elapsed 
   2.56    0.49    3.05

Sur des vecteurs de taille modeste, tenant en cache L1, on peut perdre en performance:

> system.time(replicate(1e5,Rfast::colsums(matrix(nrow=1e3,ncol=1,1/3))))
   user  system elapsed 
   1.53    0.00    1.54 
> system.time(replicate(1e5,colSums(matrix(nrow=1e3,ncol=1,1/3))))
   user  system elapsed 
   1.13    0.00    1.13 

C’est probablement dû à l’inertie de lancement de la fonction Rfast::colsums. Cela nous montre que même sur des microbenchmarks, le gain SIMD est loin d’être extraordinaire.

Si on veut gagner un facteur deux en performances, on est obligé de créer un scenario spécifiquement optimisé pour ça:

> system.time(replicate(1e4,Rfast::colsums(matrix(nrow=1e5,ncol=1,1/3))))
   user  system elapsed 
   1.26    0.00    1.27 
> system.time(replicate(1e4,colSums(matrix(nrow=1e5,ncol=1,1/3))))
   user  system elapsed 
   2.33    0.00    2.33

Sur microarchitecture Intel, le gain devrait être nettement supérieur (https://www.agner.org/optimize/blog/read.php?i=838) depuis Haswell grace aux deux units 256 bits capables toutes deux de faire des additions de 4 FP64.

Update 2 (14/05/2021)

La fonction Rfast::colsums semble ne pas utiliser AVX2. Afin de vraiment évaluer les performances d’AVX2, j’ai dû développer un petit package:

#include <Rinternals.h>

#define NBYTES 32
#define NFP64 (NBYTES/8)
typedef double multifp __attribute__((vector_size (NBYTES)));

static double fpaddpar(double  *v, R_xlen_t n) {
        multifp psum1={0};
        multifp psum2={0};
        R_xlen_t i=0;
        double asum=0;
        
        /* ensure alignment */
        size_t uv = (size_t)v;
        while ((uv % NBYTES)!=0) {
                asum += *v++;
                uv = (size_t)v;
                n--;
        }

        multifp * restrict w=(multifp * restrict)v;
        R_xlen_t n2 = n/NFP64;
        for (i=0; i <= n2-2; i+=2) {
                psum1 += w[i];
                psum2 += w[i+1];
        }
                
        psum1 += psum2;
        #if NBYTES == 16
        	asum += psum1[0] + psum1[1];
        #elif NBYTES == 32
        	asum += psum1[0] + psum1[1] + psum1[2] + psum1[3];
        #else
        	#error number of bytes unsupported
        #endif
        
        for(i=i*NFP64; i < n; i++) {
                asum += v[i];
        }

        return asum;
}

static double slow_fpaddpar(double  *v, R_xlen_t n) {
  long double out=0;
  for(R_xlen_t i=0;i<n;i++) {
   out += v[i];
  }
  return out;
}

static SEXP fastsum(SEXP v) {
  SEXP result = PROTECT(allocVector(REALSXP, 1));
  REAL(result)[0] = fpaddpar(REAL(v), XLENGTH(v));

  UNPROTECT(1);
  return result;
}
static SEXP slowsum(SEXP v) {
  SEXP result = PROTECT(allocVector(REALSXP, 1));
  REAL(result)[0] = slow_fpaddpar(REAL(v), XLENGTH(v));

  UNPROTECT(1);
  return result;
}
static const R_CallMethodDef callMethods[]  = {
  {"fastsum", (DL_FUNC) &fastsum, 1},
  {"slowsum", (DL_FUNC) &slowsum, 1},
  {NULL, NULL, 0}
};

__declspec(dllexport) void R_init_fastsum(DllInfo *info) {
  R_registerRoutines(info, NULL, callMethods, NULL, NULL);
}

Le coeur de l’algorithme peut être retrouvé avec objdump -S pour avoir le code assembleur (usage de GCC 8.3.0 64 bits avec les options -O2 -march=haswell -mtune=haswell) :

 120:   48 83 c2 02             add    $0x2,%rdx
 124:   c4 c1 7d 58 00          vaddpd (%r8),%ymm0,%ymm0
 129:   c4 c1 75 58 48 20       vaddpd 0x20(%r8),%ymm1,%ymm1
 12f:   49 83 c0 40             add    $0x40,%r8
 133:   48 39 d1                cmp    %rdx,%rcx
 136:   7f e8                   jg     120 <fastsum+0x90>

Ainsi, AVX2 est bien utilisé, avec une parallélisation qui permet d’exploiter à fond les deux unités 256 bits du Haswell.

En créant des conditions artificielles absolument idéales, on voit effectivement un gain:

> v=rep(1/3,1e5) # AMD Ryzen 1700 à 3.00 Ghz
> system.time(replicate(1e4,.Call(C_fastsum,v)))
   user  system elapsed 
   0.14    0.00    0.14 
> system.time(replicate(1e4,sum(v)))
   user  system elapsed 
   1.61    0.00    1.61 

Mais dès qu’on prend en compte le temps de génération du vecteur, même simple, ce bénéfice s’estompe énormément:

> system.time(replicate(1e4,.Call(C_fastsum,rep(1/3, 1e5))))
   user  system elapsed 
   2.53    0.94    3.47 
> system.time(replicate(1e4,sum(rep(1/3, 1e5))))
   user  system elapsed 
   3.67    0.89    4.56

De même sur Haswell (core i5-4460 à 3.20 Ghz)

> v=rep(1/3,1e5)
> system.time(replicate(1e4,.Call(C_fastsum,v)))
   user  system elapsed 
   0.21    0.00    0.21 
> system.time(replicate(1e4,sum(v)))
   user  system elapsed 
   0.91    0.00    0.91 
> system.time(replicate(1e4,.Call(C_fastsum,rep(1/3,1e5))))
   user  system elapsed 
   2.28    1.30    3.58 
> system.time(replicate(1e4,sum(rep(1/3,1e5))))
   user  system elapsed 
   2.93    1.26    4.22

Étonnamment, le vieux Haswell core i5-4460 de 2014 est plus performant en FP80 que le Ryzen de 2017. En tout cas, les gains du SIMD se sentent toujours sur les microbenchmarks mais sont peu susceptibles d’avoir une influence dans du code R réel.

Update 3 (14/05/2021)

Après un déroulage de la boucle des FP80, la différence de performance s’amenuise:


#define NBYTES 32
#define NFP64 (NBYTES/8)
typedef double multifp __attribute__((vector_size (NBYTES)));

static double fpaddpar(double  *v, R_xlen_t n) {
        multifp psum1={0};
        multifp psum2={0};
        multifp psum3={0};
        multifp psum4={0};
        R_xlen_t i=0;
        double asum=0;
        
        /* ensure alignment */
        size_t uv = (size_t)v;
        while ((uv % NBYTES)!=0) {
                asum += *v++;
                uv = (size_t)v;
                n--;
        }

        multifp * restrict w=(multifp * restrict)v;
        R_xlen_t n2 = n/NFP64;
        for (i=0; i <= n2-4; i+=4) {
                psum1 += w[i];
                psum2 += w[i+1];
                psum3 += w[i+2];
                psum4 += w[i+3];
        }
                
        psum1 += psum2 + psum3 + psum4;
        #if NBYTES == 16
        	asum += psum1[0] + psum1[1];
        #elif NBYTES == 32
        	asum += psum1[0] + psum1[1] + psum1[2] + psum1[3];
        #else
        	#error number of bytes unsupported
        #endif
        
        for(i=i*NFP64; i < n; i++) {
                asum += v[i];
        }

        return asum;
}

static double slow_fpaddpar(double  *v, R_xlen_t n) {
  long double out1=0;
  long double out2=0;
  long double out3=0;
  long double out4=0;
  long double out5=0;
  long double out6=0;

  R_xlen_t i=0;
  for(;i<n;i+=6) {
   out1 += v[i];
   out2 += v[i+1];
   out3 += v[i+2];
   out4 += v[i+3];
   out5 += v[i+4];
   out6 += v[i+5];
  }
  out1+=out2+out3+out4+out5+out6;
  for(;i<n;i++) {
    out1+=v[i];
  }
  return out1;
}

Sur le core i5-4460 à 3.2 Ghz:

> v=rep(1/3, 1e5)
> system.time(replicate(1e4,.Call(C_fastsum,v)))
   user  system elapsed 
   0.22    0.00    0.22 
> system.time(replicate(1e4,.Call(C_slowsum,v)))
   user  system elapsed 
   0.42    0.02    0.44

Et sur le Ryzen 1700 à 3.0 Ghz (qui monte en réalité à 3.2 Ghz bien ventilé) :

> v=rep(1/3, 1e5)
> system.time(replicate(1e4,.Call(C_fastsum,v)))
   user  system elapsed 
   0.14    0.00    0.14 
> system.time(replicate(1e4,.Call(C_slowsum,v)))
   user  system elapsed 
   0.33    0.00    0.33 

Après analyse plus fine, les unités SIMD 128 bits (Ryzen) et 256 bits (Haswell) n’arrivent pas à saturation à cause du temps de load/store dans la cache L1/L2. Si on veut artificiellement montrer la supériorité du SIMD, il faut bidouiller le micro-benchmark pour que tout tienne de justesse en cache L1 et minimiser strictement les overheads d’appels en s’aidant de l’ALTREP de R 4.0 avec la boucle for.

> v=rep(1/3, 3e3) # Ryzen 1700
> A_fastsum=C_fastsum$address
> A_slowsum = C_slowsum$address
> system.time(for(i in 1:1e6) .Call(A_fastsum,v))
   user  system elapsed 
   0.33    0.00    0.33 
> system.time(for(i in 1:1e6) .Call(A_slowsum,v))
   user  system elapsed 
   1.08    0.00    1.08 

Même en trichant au maximum, on voit que c’est difficile de montrer un gros bénéfice du SIMD. L’architecture Zen étant limitée à 2 loads 128 bits/cycle depuis la cache L1 vers un registre.