DarthGizka - 20 days ago
C# Question

Square pyramidal number of Int32 modulo some M using only Int64 intermediates

Computing the square pyramidal number

n (n + 1) (2 n + 1) / 6 mod M
for values of n up to 10^9 (and prime M) poses a bit of a challenge, because the intermediate result before modulo reduction can exceed 10^27, which is far more than a 64-bit integer can hold (2^64 ≈ 1.8 * 10^19).

Reducing the factors modulo M before the multiplication creates a problem with the division by 6 because performing that division after reduction modulo M would give nonsensical results, obviously.
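A quick way to see the problem is to compare an exact reference with the naive approach that reduces each factor modulo M first and divides by 6 afterwards. The sketch below is a C# 9+ top-level program; `Reference` and `Naive` are hypothetical names, and BigInteger is used only as a test oracle. A small modulus (7) keeps the failure easy to verify by hand:

```csharp
using System;
using System.Numerics;

// Exact reference: full-precision product, exact division by 6, then reduction.
int Reference(int n, int m)
{
    BigInteger p = (BigInteger)n * (n + 1) * (2 * (BigInteger)n + 1) / 6;
    return (int)(p % m);
}

// Naive: reduce each factor modulo m first, then divide by 6.
// The division no longer acts on the true product, so the result is wrong in general.
int Naive(int n, int m)
{
    long p = (long)(n % m) * ((n + 1) % m) % m;
    p = p * ((2L * n + 1) % m) % m;
    return (int)(p / 6 % m);
}

// n = 4, m = 7: 4 * 5 * 9 / 6 = 30 and 30 mod 7 = 2,
// but the reduced product is 5, and 5 / 6 = 0.
Console.WriteLine($"{Reference(4, 7)} vs {Naive(4, 7)}");   // prints "2 vs 0"
```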

At the moment I'm using a workaround based on the fact that n (n + 1) must be even for any n and that either n (n + 1) or (2 n + 1) must be divisible by 3:

const int M = 1000000007;

static int modular_square_pyramidal_number (int n)
{
    var a = (Int64)n * (n + 1) / 2;   // n (n + 1) is always even
    var b = 2 * n + 1;
    var q = a / 3;

    // divide whichever factor is divisible by 3, then reduce modulo M
    var p = q * 3 == a ? (q % M) * b : (a % M) * (b / 3);

    return (int)(p % M);
}
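The two divisibility facts the workaround relies on can be spot-checked by brute force. This sketch (a C# 9+ top-level program; `FactsHold` is a hypothetical helper) also confirms that whenever 3 does not divide a = n (n + 1) / 2 it divides b = 2 n + 1, which is exactly the case split made by the conditional expression in the code above:

```csharp
using System;

// For every n in [from, to]: n (n + 1) is even, and 3 divides either n (n + 1) / 2 or 2 n + 1.
bool FactsHold(long from, long to)
{
    for (long n = from; n <= to; n++)
    {
        long product = n * (n + 1);          // at most about 1e18, fits into a long
        if (product % 2 != 0) return false;
        long a = product / 2, b = 2 * n + 1;
        if (a % 3 != 0 && b % 3 != 0) return false;
    }
    return true;
}

Console.WriteLine(FactsHold(0, 100000) && FactsHold(1000000000 - 100000, 1000000000));
```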

As you can see, this is really awkward. Is there a more elegant/efficient way of performing this computation without resorting to BigInteger or Decimal, perhaps using intermediate reduction modulo 3 M in some way?

Background: the problem came up in solving the Tic Tac Toe practice problem at HackerEarth. The submission based on my awkward hack was accepted but I'm not satisfied with this half-baked solution. That's the whole point of these practice problems, isn't it: I'm not going to learn anything if I accept any half-baked solution based on pre-existing knowledge that sort of scrapes by the robot judge. Hence I'm always aiming to improve the solutions until they achieve a state of simplicity and grace...


My intuition about reduction modulo 3 M panned out - it just took a while to pin it down mathematically after testing showed that it worked.

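The key is the identity

(x / q) mod p = ((x mod pq) / q) mod p

which holds for nonnegative x and positive p and q, where / denotes integer division. Writing x = k pq + r with 0 ≤ r < pq gives x / q = k p + r / q, and the term k p vanishes modulo p.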
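The identity is easy to spot-check by brute force. The sketch below (a C# 9+ top-level program; `IdentityHolds` is a hypothetical helper) tries all small nonnegative x for a range of p and q, including pairs that are not coprime:

```csharp
using System;

// Checks (x / q) mod p == ((x mod pq) / q) mod p with integer division,
// for all 0 <= x < limit and 1 <= p, q <= maxPq.
bool IdentityHolds(int limit, int maxPq)
{
    for (long p = 1; p <= maxPq; p++)
        for (long q = 1; q <= maxPq; q++)
            for (long x = 0; x < limit; x++)
                if ((x / q) % p != ((x % (p * q)) / q) % p)
                    return false;
    return true;
}

Console.WriteLine(IdentityHolds(2000, 12));   // prints "True"
```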

Let's take the same split of the formula to be computed as in my question:

n (n + 1) (2 n + 1) / 6 mod M = a b / 3 mod M

a = n (n + 1) / 2
b = 2 n + 1

Either a or b must be divisible by 3, but it is not known which one, and a * b can be too big to fit into a 64-bit integer (around 90 bits, given the original constraint of n ≤ 1e9).

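However, with M = 1000000007 (i.e. the usual 1e9 + 7), 3 M = 3000000021 requires only 32 bits, and the same holds for a reduced modulo 3 M. Since b fits into 31 bits, the product (a mod 3 M) * b stays below about 6.0 * 10^18, safely under 2^63 ≈ 9.2 * 10^18, so it can be computed with 64-bit arithmetic. The product is not fully reduced modulo 3 M, but the identity still applies because it is congruent to a * b modulo 3 M: adding t pq to x adds exactly t p to x / q. With p = M and q = 3 this gives: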

((a mod 3 M) * b) / 3 mod M
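The bound can be checked with a few lines of arithmetic. This sketch (a C# 9+ top-level program) evaluates the worst-case product in a `checked` context, so an overflow would be reported instead of silently wrapping, and then applies the formula to the small case n = 10:

```csharp
using System;

const long M = 1000000007;
const long N = 1000000000;

// Largest possible operands: a mod 3M is at most 3M - 1, and b = 2n + 1 is at most 2N + 1.
long worst = checked((3 * M - 1) * (2 * N + 1));

// Small worked case, n = 10: a = 55, b = 21, and 55 * 21 / 3 = 385 (= 1 + 4 + ... + 100).
long n = 10;
long a = n * (n + 1) / 2 % (3 * M);
long b = 2 * n + 1;
long result = a * b / 3 % M;

Console.WriteLine($"worst = {worst}, long.MaxValue = {long.MaxValue}, result = {result}");
```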

Changed code:

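static int v1 (int i)
{
    var n = (uint)i;
    // a = n (n + 1) / 2, reduced modulo 3 M so that it fits into 32 bits
    var a = ((UInt64)n * (n + 1) >> 1) % (M * 3U);
    var b = 2 * n + 1;   // fits into 31 bits

    // a * b < 2^63, and the division by 3 is exact
    return (int)((a * b / 3) % M);
}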

This uses unsigned arithmetic which is appropriate here and also more efficient, since signed arithmetic usually requires extra effort by the compiler (read: emission of additional instructions) in order to realise the signed arithmetic semantics.
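v1 can be checked against a BigInteger oracle (used only for testing) over small n and n just below 10^9. A sketch as a C# 9+ top-level program; `Reference` and `CountMismatches` are hypothetical helpers, and `V1` is a copy of v1:

```csharp
using System;
using System.Numerics;

const int M = 1000000007;

// Copy of v1: unsigned arithmetic with intermediate reduction modulo 3 M.
int V1(int i)
{
    var n = (uint)i;
    var a = ((ulong)n * (n + 1) >> 1) % (M * 3U);
    var b = 2 * n + 1;
    return (int)((a * b / 3) % M);
}

// Full-precision oracle, for testing only.
int Reference(int n) =>
    (int)((BigInteger)n * (n + 1) * (2 * (BigInteger)n + 1) / 6 % M);

// Number of disagreements for n in [0, 10^4] and in [10^9 - 10^4, 10^9].
int CountMismatches()
{
    int bad = 0;
    for (int n = 0; n <= 10000; n++)
        if (V1(n) != Reference(n)) bad++;
    for (int n = 1000000000 - 10000; n <= 1000000000; n++)
        if (V1(n) != Reference(n)) bad++;
    return bad;
}

Console.WriteLine($"mismatches: {CountMismatches()}");
```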

Benchmarks show this to be more than twice as fast as the original code from my question, but only under old framework versions (up to 3.5). Starting with version 4.0 the JIT compiler no longer turns unsigned division by constants into multiplication + shift. Division instructions tend to be at least an order of magnitude slower than multiplications, and so the code becomes a lot slower than the original code on systems with the newer compiler.

On such systems it is better to go with the flow and use the inefficient - but politically correct - signed integers:

static int v2 (int n)
{
    // same as v1, but with signed 64-bit arithmetic
    var a = ((Int64)n * (n + 1) >> 1) % (M * 3L);
    var b = 2 * n + 1;

    return (int)((a * b / 3) % M);
}
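For reference, the "multiplication + shift" replacement mentioned above works like this for unsigned 64-bit division by 3: multiply by the magic constant 0xAAAAAAAAAAAAAAAB, which equals (2^65 + 1) / 3, keep the high 64 bits of the 128-bit product and shift right by one more bit. This is a sketch of the general technique, not the exact code any particular JIT emits; it assumes .NET 5 or later for `Math.BigMul(ulong, ulong, out ulong)`, and `DivideBy3` is a hypothetical name:

```csharp
using System;

// Unsigned division by 3 without a division instruction:
// x / 3 == floor(x * ((2^65 + 1) / 3) / 2^65) for every 64-bit x.
ulong DivideBy3(ulong x)
{
    ulong high = Math.BigMul(x, 0xAAAAAAAAAAAAAAABUL, out _);   // high 64 bits of the product
    return high >> 1;                                          // total shift of 65 bits
}

ulong[] samples = { 0, 1, 2, 3, 1155, 6000000043000000020UL, ulong.MaxValue };
foreach (var x in samples)
    Console.WriteLine($"{x} / 3 = {DivideBy3(x)} (expected {x / 3})");
```

The rounding error x / (3 * 2^65) is below 1/6, while the fractional part of x / 3 is at most 2/3, so the floor is never pushed past the next integer.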

Here are the benchmarks for 1000000 calls on my aging Haswell laptop for framework version 2.0:

IntPtr.Size = 8, Environment.Version = 2.0.50727.8009
bench 1000000:    8,407 v0    3,413 v1    4,653 v2
bench 1000000:    8,017 v0    3,179 v1    5,038 v2
bench 1000000:    8,641 v0    3,114 v1    4,801 v2

Times are in milliseconds, and v0 stands for the original code from my question. It is easy to see how the overhead of signed semantics makes v2 measurably slower than v1 which uses unsigned arithmetic internally.

Environment.Version and timings are exactly the same for framework versions up to 3.5, so they are evidently all using the same runtime and JIT compiler (framework versions 3.0 and 3.5 build on the 2.0 CLR).

And now the timings for Microsoft's new and ‘improved’ compilers that come with framework 4.0 and newer:

IntPtr.Size = 8, Environment.Version = 4.0.30319.42000
bench 1000000:    9,518 v0   20,479 v1    5,687 v2
bench 1000000:    9,225 v0   20,251 v1    5,540 v2
bench 1000000:    9,133 v0   20,333 v1    5,389 v2

Environment.Version and timings are exactly the same for framework versions 4.0 through 4.6.1.
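The benchmark code itself is not part of the post. A minimal harness along the following lines is one way to time such functions; it is a hypothetical sketch (C# 9+ top-level program, `Bench` is a made-up helper) that uses `System.Diagnostics.Stopwatch` and returns a checksum of the results so that the calls cannot be optimised away:

```csharp
using System;
using System.Diagnostics;

const int M = 1000000007;

int V1(int i)
{
    var n = (uint)i;
    var a = ((ulong)n * (n + 1) >> 1) % (M * 3U);
    var b = 2 * n + 1;
    return (int)((a * b / 3) % M);
}

int V2(int n)
{
    var a = ((long)n * (n + 1) >> 1) % (M * 3L);
    var b = 2 * n + 1;
    return (int)((a * b / 3) % M);
}

// Calls f on `calls` inputs spread over [0, 10^9], prints the elapsed time and
// returns a checksum of the results.
long Bench(string name, Func<int, int> f, int calls)
{
    long checksum = 0;
    var sw = Stopwatch.StartNew();
    for (int k = 0; k < calls; k++)
        checksum += f((int)(k * 997L % 1000000001));
    sw.Stop();
    Console.WriteLine($"{name}: {sw.Elapsed.TotalMilliseconds:F3} ms");
    return checksum;
}

// Equal checksums show that both variants computed the same results.
for (int round = 0; round < 3; round++)
    Console.WriteLine(Bench("v1", V1, 1000000) == Bench("v2", V2, 1000000));
```

Calling through a `Func<int, int>` delegate adds overhead and prevents inlining, so absolute numbers will differ from the tables above; a separate direct loop per variant would be closer to a real benchmark.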

POST SCRIPTUM - using the modular multiplicative inverse

An alternative solution would be to use the modular multiplicative inverse of the divisor. In the present case this works because the final product is known to be evenly divisible by the divisor (i.e. 3); if it weren't then the result would be wildly inaccurate. Example (333333336 being the multiplicative inverse of 3 modulo 1000000007):

7 * 333333336 % 1000000007 = 333333338  // 7 mod 3 != 0
8 * 333333336 % 1000000007 = 666666674  // 8 mod 3 != 0
9 * 333333336 % 1000000007 =         3  // 9 mod 3 == 0

The raison d'être for this topic was that integer division is potentially lossy, since it drops the remainder (if any), and so the result of the square pyramidal calculation would be wrong if the wrong factor were divided by 3.

Modular division (i.e. multiplication with the multiplicative inverse) is not lossy, and so it doesn't matter which factor is multiplied with the inverse. This can readily be seen in the example just shown, where the outlandish residues for 7 and 8 (remainders 1 and 2 modulo 3) effectively encode the fractional part, and adding them, which corresponds to computing 7/3 + 8/3, gives 1000000012, which equals 5 mod 1000000007 just as expected.
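Since M is prime, the inverse can be computed with Fermat's little theorem: the inverse of 3 is 3^(M - 2) mod M. The sketch below (a C# 9+ top-level program; `PowMod` is a hypothetical square-and-multiply helper) reproduces the example residues and shows that adding the "fractional" residues for 7/3 and 8/3 yields 5:

```csharp
using System;

const long M = 1000000007;

// Modular exponentiation by repeated squaring; all intermediates stay below M^2 < 2^63.
long PowMod(long b, long e, long m)
{
    long result = 1;
    b %= m;
    while (e > 0)
    {
        if ((e & 1) == 1) result = result * b % m;
        b = b * b % m;
        e >>= 1;
    }
    return result;
}

long inv3 = PowMod(3, M - 2, M);    // 333333336

long r7 = 7 * inv3 % M;             // 333333338, encodes 7/3
long r8 = 8 * inv3 % M;             // 666666674, encodes 8/3
long r9 = 9 * inv3 % M;             // 3, since 9 is divisible by 3

Console.WriteLine($"{inv3} {r7} {r8} {r9} {(r7 + r8) % M}");   // 333333336 333333338 666666674 3 5
```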

Hence the crux of the matter is that the final product be evenly divisible by the divisor but it doesn't matter when and where the ‘division’ (multiplication with the inverse) occurs. The resulting code is slightly less efficient than v1 and roughly on par with v2 because it requires an additional reduction modulo M after the multiplication with the inverse. However, I'm showing it anyway because the approach could come in handy on occasion:

static int v3 (int n)
{
    var a = n * (n + 1L) % M;                   // n (n + 1) mod M
    var b = (2 * n + 1L) * 166666668 % M;       // (2 n + 1) / 6 mod M, via the inverse of 6

    return (int)(a * b % M);
}

Note: I dropped the right shift and incorporated the divisor 2 into the inverse (166666668 is the inverse of 6 modulo 1000000007), since the separate division by 2 no longer serves any purpose here. The timings are the same as for v2.
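A final sanity check for v3: 166666668 should satisfy 6 * 166666668 mod M = 1, and v3 should agree with an exact BigInteger oracle. A sketch as a C# 9+ top-level program; `Reference` is a hypothetical test helper and `V3` is a copy of v3:

```csharp
using System;
using System.Numerics;

const int M = 1000000007;

int V3(int n)
{
    var a = n * (n + 1L) % M;                   // n (n + 1) mod M
    var b = (2 * n + 1L) * 166666668 % M;       // (2 n + 1) / 6 mod M
    return (int)(a * b % M);
}

// Full-precision oracle, for testing only.
int Reference(int n) =>
    (int)((BigInteger)n * (n + 1) * (2 * (BigInteger)n + 1) / 6 % M);

bool inverseOk = 6L * 166666668 % M == 1;       // 1000000008 mod M == 1

int bad = 0;
for (int n = 0; n <= 10000; n++)
    if (V3(n) != Reference(n)) bad++;
for (int n = 1000000000 - 10000; n <= 1000000000; n++)
    if (V3(n) != Reference(n)) bad++;

Console.WriteLine($"inverse ok: {inverseOk}, mismatches: {bad}");
```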