Geant4 11.1.1
Toolkit for the simulation of the passage of particles through matter
mulmod.h
Go to the documentation of this file.
1//
2// -*- C++ -*-
3//
4// -----------------------------------------------------------------------
5// HEP Random
6// --- RanluxppEngine ---
7// helper implementation file
8// -----------------------------------------------------------------------
9
10#ifndef RANLUXPP_MULMOD_H
11#define RANLUXPP_MULMOD_H
12
13#include "helpers.h"
14
15#include <cstdint>
16
17/// Multiply two 576 bit numbers, stored as 9 numbers of 64 bits each
18///
19/// \param[in] in1 first factor as 9 numbers of 64 bits each
20/// \param[in] in2 second factor as 9 numbers of 64 bits each
21/// \param[out] out result with 18 numbers of 64 bits each
22static void multiply9x9(const uint64_t *in1, const uint64_t *in2,
23 uint64_t *out) {
24 uint64_t next = 0;
25 unsigned nextCarry = 0;
26
27#if defined(__clang__) || defined(__INTEL_COMPILER)
28#pragma unroll
29#elif defined(__GNUC__) && __GNUC__ >= 8
30// This pragma was introduced in GCC version 8.
31#pragma GCC unroll 18
32#endif
33 for (int i = 0; i < 18; i++) {
34 uint64_t current = next;
35 unsigned carry = nextCarry;
36
37 next = 0;
38 nextCarry = 0;
39
40#if defined(__clang__) || defined(__INTEL_COMPILER)
41#pragma unroll
42#elif defined(__GNUC__) && __GNUC__ >= 8
43// This pragma was introduced in GCC version 8.
44#pragma GCC unroll 9
45#endif
46 for (int j = 0; j < 9; j++) {
47 int k = i - j;
48 if (k < 0 || k >= 9)
49 continue;
50
51 uint64_t fac1 = in1[j];
52 uint64_t fac2 = in2[k];
53#if defined(__SIZEOF_INT128__) && !defined(CLHEP_NO_INT128)
54#ifdef __GNUC__
55#pragma GCC diagnostic push
56#pragma GCC diagnostic ignored "-Wpedantic"
57#endif
58 unsigned __int128 prod = fac1;
59 prod = prod * fac2;
60
61 uint64_t upper = prod >> 64;
62 uint64_t lower = static_cast<uint64_t>(prod);
63#ifdef __GNUC__
64#pragma GCC diagnostic pop
65#endif
66#else
67 uint64_t upper1 = fac1 >> 32;
68 uint64_t lower1 = static_cast<uint32_t>(fac1);
69
70 uint64_t upper2 = fac2 >> 32;
71 uint64_t lower2 = static_cast<uint32_t>(fac2);
72
73 // Multiply 32-bit parts, each product has a maximum value of
74 // (2 ** 32 - 1) ** 2 = 2 ** 64 - 2 * 2 ** 32 + 1.
75 uint64_t upper = upper1 * upper2;
76 uint64_t middle1 = upper1 * lower2;
77 uint64_t middle2 = lower1 * upper2;
78 uint64_t lower = lower1 * lower2;
79
80 // When adding the two products, the maximum value for middle is
81 // 2 * 2 ** 64 - 4 * 2 ** 32 + 2, which exceeds a uint64_t.
82 unsigned overflow;
83 uint64_t middle = add_overflow(middle1, middle2, overflow);
84 // Handling the overflow by a multiplication with 0 or 1 is cheaper
85 // than branching with an if statement, which the compiler does not
86 // optimize to this equivalent code. Note that we could do entirely
87 // without this overflow handling when summing up the intermediate
88 // products differently as described in the following SO answer:
89 // https://stackoverflow.com/a/51587262
90 // However, this approach takes at least the same amount of thinking
91 // why a) the code gives the same results without b) overflowing due
92 // to the mixture of 32 bit arithmetic. Moreover, my tests show that
93 // the scheme implemented here is actually slightly more performant.
94 uint64_t overflow_add = overflow * (uint64_t(1) << 32);
95 // This addition can never overflow because the maximum value of upper
96 // is 2 ** 64 - 2 * 2 ** 32 + 1 (see above). When now adding another
97 // 2 ** 32, the result is 2 ** 64 - 2 ** 32 + 1 and still smaller than
98 // the maximum 2 ** 64 - 1 that can be stored in a uint64_t.
99 upper += overflow_add;
100
101 uint64_t middle_upper = middle >> 32;
102 uint64_t middle_lower = middle << 32;
103
104 lower = add_overflow(lower, middle_lower, overflow);
105 upper += overflow;
106
107 // This still can't overflow since the maximum of middle_upper is
108 // - 2 ** 32 - 4 if there was an overflow for middle above, bringing
109 // the maximum value of upper to 2 ** 64 - 2.
110 // - otherwise upper still has the initial maximum value given above
111 // and the addition of a value smaller than 2 ** 32 brings it to
112 // a maximum value of 2 ** 64 - 2 ** 32 + 2.
113 // (Both cases include the increment to handle the overflow in lower.)
114 //
115 // All the reasoning makes perfect sense given that the product of two
116 // 64 bit numbers is smaller than or equal to
117 // (2 ** 64 - 1) ** 2 = 2 ** 128 - 2 * 2 ** 64 + 1
118 // with the upper bits matching the 2 ** 64 - 2 of the first case.
119 upper += middle_upper;
120#endif
121
122 // Add to current, remember carry.
123 current = add_carry(current, lower, carry);
124
125 // Add to next, remember nextCarry.
126 next = add_carry(next, upper, nextCarry);
127 }
128
129 next = add_carry(next, carry, nextCarry);
130
131 out[i] = current;
132 }
133}
134
135/// Compute a value congruent to mul modulo m less than 2 ** 576
136///
137/// \param[in] mul product from multiply9x9 with 18 numbers of 64 bits each
138/// \param[out] out result with 9 numbers of 64 bits each
139///
140/// \f$ m = 2^{576} - 2^{240} + 1 \f$
141///
142/// The result in out is guaranteed to be smaller than the modulus.
143static void mod_m(const uint64_t *mul, uint64_t *out) {
144 uint64_t r[9];
145 // Assign r = t0
146 for (int i = 0; i < 9; i++) {
147 r[i] = mul[i];
148 }
149
150 int64_t c = compute_r(mul + 9, r);
151
152 // To update r = r - c * m, it suffices to know c * (-2 ** 240 + 1)
153 // because the 2 ** 576 will cancel out. Also note that c may be zero, but
154 // the operation is still performed to avoid branching.
155
156 // c * (-2 ** 240 + 1) in 576 bits looks as follows, depending on c:
157 // - if c = 0, the number is zero.
158 // - if c = 1: bits 576 to 240 are set,
159 // bits 239 to 1 are zero, and
160 // the last one is set
161 // - if c = -1, which corresponds to all bits set (signed int64_t):
162 // bits 576 to 240 are zero and the rest is set.
163 // Note that all bits except the last are exactly complimentary (unless c = 0)
164 // and the last byte is conveniently represented by c already.
165 // Now construct the three bit patterns from c, their names correspond to the
166 // assembly implementation by Alexei Sibidanov.
167
168 // c = 0 -> t0 = 0; c = 1 -> t0 = 0; c = -1 -> all bits set (sign extension)
169 // (The assembly implementation shifts by 63, which gives the same result.)
170 int64_t t0 = c >> 1;
171
172 // Left shifting negative values is undefined behavior until C++20, cast to
173 // unsigned.
174 uint64_t c_unsigned = static_cast<uint64_t>(c);
175
176 // c = 0 -> t2 = 0; c = 1 -> upper 16 bits set; c = -1 -> lower 48 bits set
177 int64_t t2 = t0 - (c_unsigned << 48);
178
179 // c = 0 -> t1 = 0; c = 1 -> all bits set (sign extension); c = -1 -> t1 = 0
180 // (The assembly implementation shifts by 63, which gives the same result.)
181 int64_t t1 = t2 >> 48;
182
183 unsigned carry = 0;
184 {
185 uint64_t r_0 = r[0];
186
187 uint64_t out_0 = sub_carry(r_0, c, carry);
188 out[0] = out_0;
189 }
190 for (int i = 1; i < 3; i++) {
191 uint64_t r_i = r[i];
192 r_i = sub_overflow(r_i, carry, carry);
193
194 uint64_t out_i = sub_carry(r_i, t0, carry);
195 out[i] = out_i;
196 }
197 {
198 uint64_t r_3 = r[3];
199 r_3 = sub_overflow(r_3, carry, carry);
200
201 uint64_t out_3 = sub_carry(r_3, t2, carry);
202 out[3] = out_3;
203 }
204 for (int i = 4; i < 9; i++) {
205 uint64_t r_i = r[i];
206 r_i = sub_overflow(r_i, carry, carry);
207
208 uint64_t out_i = sub_carry(r_i, t1, carry);
209 out[i] = out_i;
210 }
211}
212
213/// Combine multiply9x9 and mod_m with internal temporary storage
214///
215/// \param[in] in1 first factor with 9 numbers of 64 bits each
216/// \param[inout] inout second factor and also the output of the same size
217///
218/// The result in inout is guaranteed to be smaller than the modulus.
219static void mulmod(const uint64_t *in1, uint64_t *inout) {
220 uint64_t mul[2 * 9] = {0};
221 multiply9x9(in1, inout, mul);
222 mod_m(mul, inout);
223}
224
225/// Compute base to the n modulo m
226///
227/// \param[in] base with 9 numbers of 64 bits each
228/// \param[out] res output with 9 numbers of 64 bits each
229/// \param[in] n exponent
230///
231/// The arguments base and res may point to the same location.
232static void powermod(const uint64_t *base, uint64_t *res, uint64_t n) {
233 uint64_t fac[9] = {0};
234 fac[0] = base[0];
235 res[0] = 1;
236 for (int i = 1; i < 9; i++) {
237 fac[i] = base[i];
238 res[i] = 0;
239 }
240
241 uint64_t mul[18] = {0};
242 while (n) {
243 if (n & 1) {
244 multiply9x9(res, fac, mul);
245 mod_m(mul, res);
246 }
247 n >>= 1;
248 if (!n)
249 break;
250 multiply9x9(fac, fac, mul);
251 mod_m(mul, fac);
252 }
253}
254
255#endif