Welcome to OGeek Q&A Community for programmer and developer-Open, Learning and Share
Welcome To Ask or Share your Answers For Others


multiplication - Translation from Complex-FFT to Finite-Field-FFT

Good afternoon!

I am trying to develop an NTT algorithm based on the naive recursive FFT implementation I already have.

Consider the following code (the length of coefficients, call it m, is an exact power of two):

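/// <summary>
/// Calculates the result of the recursive Number Theoretic Transform.
/// </summary>
/// <param name="coefficients">Input coefficients; their count m is a power of two.</param>
/// <param name="rootsOfUnity">Precomputed powers omega^0, omega^1, ... of a primitive m-th root of unity modulo NTT_MODULUS.</param>
/// <param name="step">Stride between the elements handled at this level of recursion (1 at the top level).</param>
/// <param name="offset">Index of the first element handled at this level (0 at the top level).</param>
/// <returns>The transformed sub-sequence of length n = m / step.</returns>
/// <remarks>NTT_MODULUS is a BigInteger constant defined elsewhere in the class.
/// Requires: using System.Collections.Generic; using System.Numerics;</remarks>
private static BigInteger[] Recursive_NTT_Skeleton(
    IList<BigInteger> coefficients, 
    IList<BigInteger> rootsOfUnity, 
    int step, 
    int offset)
{
    // Calculate the length of vectors at the current step of recursion.
    int n = coefficients.Count / step - offset / step;

    if (n == 1)
    {
        return new BigInteger[] { coefficients[offset] };
    }

    BigInteger[] results = new BigInteger[n];

    IList<BigInteger> resultEvens = 
        Recursive_NTT_Skeleton(coefficients, rootsOfUnity, step * 2, offset);

    IList<BigInteger> resultOdds = 
        Recursive_NTT_Skeleton(coefficients, rootsOfUnity, step * 2, offset + step);

    for (int k = 0; k < n / 2; k++)
    {
        BigInteger bfly = (rootsOfUnity[k * step] * resultOdds[k]) % NTT_MODULUS;

        results[k]          = (resultEvens[k] + bfly) % NTT_MODULUS;
        // BigInteger's % keeps the sign of the dividend, so add the modulus
        // to keep the difference in the range [0, NTT_MODULUS).
        results[k + n / 2]  = (resultEvens[k] - bfly + NTT_MODULUS) % NTT_MODULUS;
    }

    return results;
}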

It worked for the complex FFT (with BigInteger replaced by a complex numeric type; I had my own). It doesn't work here, even though I changed the procedure for finding the primitive roots of unity accordingly.

Supposedly, the problem is this: the rootsOfUnity parameter that I originally passed contained only the first half of the m-th complex roots of unity, in this order:

omega^0 = 1, omega^1, omega^2, ..., omega^(m/2 - 1)

That was enough because, in these three lines of code:

BigInteger bfly = (rootsOfUnity[k * step] * resultOdds[k]) % NTT_MODULUS;        

results[k]          = (resultEvens[k] + bfly) % NTT_MODULUS;
results[k + n / 2]  = (resultEvens[k] - bfly) % NTT_MODULUS;

I made use of the fact that, at any level of recursion (for any n and i), the complex roots of unity satisfy -omega^i = omega^(i + n/2).

However, that property obviously doesn't hold in finite fields. But is there any analogue of it which would allow me to still compute only the first half of the roots?

Or should I extend the loop from n/2 to n and precompute all the m-th roots of unity?

Or are there other problems with this code?

Thank you very much in advance!



1 Reply


I recently wanted to implement NTT for fast multiplication instead of DFFT too. I read a lot of confusing material, with different notation everywhere and no simple solution, and my finite-field knowledge is rusty, but today I finally got it right (after 2 days of trying and drawing analogies with the DFT coefficients), so here are my insights on NTT:

  1. Computation

    X(i) = sum(j=0..n-1) of ( Wn^(i*j)*x(j) ) mod p;

    where X[] is the NTT of x[] (both of size n) and Wn is the NTT basis. All computations are done in integer arithmetic modulo p; no complex numbers appear anywhere.

  2. Important values

    Wn = r ^ L mod p is the basis for NTT
    Wn = r ^ (p-1-L) mod p is the basis for INTT (the inverse of r^L, since r^(p-1) mod p == 1)
    Rn = n ^ (p-2) mod p is the scaling multiplicative constant for INTT ~(1/n), i.e. the modular inverse of n (Fermat's little theorem)
    p is a prime such that p mod n == 1 and p > max'
    max is the max value of x[i] for NTT or of X[i] for INTT
    r = <1,p)
    L = <1,p) and also divides p-1
    r,L must be combined so that r^(L*i) mod p == 1 if i=0 or i=n
    r,L must be combined so that r^(L*i) mod p != 1 if 0 < i < n (together: Wn is a primitive n-th root of unity mod p)
    max' is the max value of a sub-result and depends on n and on the type of computation. For a single (I)NTT it is max' = n*max, but for the convolution of two n-sized vectors it is max' = n*max*max, etc. See Implementing FFT over finite fields for more info about it.

  3. The valid combination of r,L,p is different for different n

    This is important: you have to recompute the parameters (or select them from a table) before each NTT layer, because n is always half of the previous recursion. Squaring the current Wn gives a valid basis for n/2, which is what the fast version does with w2.
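To make items 1 to 3 concrete, here is a small sketch that can be checked by hand, using p = 17 and n = 4 with r = 2, L = 2 (one valid combination picked for illustration, so Wn = 4, the INTT basis is 2^14 mod 17 = 13 and Rn = 4^15 mod 17 = 13). With x[i] <= 4 we get max' = n*max = 16 < p. The helper names pow_mod, ntt_direct and is_primitive_root are hypothetical, and the code assumes p < 2^32 so products fit in 64 bits.

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

using u64 = std::uint64_t;

// b^e mod p by square-and-multiply (assumes p < 2^32).
u64 pow_mod(u64 b, u64 e, u64 p) {
    u64 r = 1;
    b %= p;
    while (e) {
        if (e & 1) r = r * b % p;
        b = b * b % p;
        e >>= 1;
    }
    return r;
}

// Direct O(n^2) evaluation of X(i) = sum(j) W^(i*j) * x(j) mod p (item 1).
std::vector<u64> ntt_direct(const std::vector<u64>& x, u64 W, u64 p) {
    const u64 n = x.size();
    std::vector<u64> X(n, 0);
    for (u64 i = 0; i < n; i++)
        for (u64 j = 0; j < n; j++)
            X[i] = (X[i] + pow_mod(W, i * j, p) * x[j]) % p;
    return X;
}

// The two r,L conditions of item 2: W^n == 1 and W^i != 1 for 0 < i < n.
bool is_primitive_root(u64 W, u64 n, u64 p) {
    u64 w = 1;
    for (u64 i = 1; i < n; i++) {
        w = w * W % p;            // w = W^i
        if (w == 1) return false;
    }
    return w * W % p == 1;        // W^n
}
```

For x = {1,2,3,4} the forward transform with Wn = 4 gives X = {10,7,15,6}; transforming X with the INTT basis 13 and multiplying every element by Rn returns the original vector.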

Here is my C++ code that finds the r,L,p parameters. It needs modular arithmetic routines (modpow, modmul, ...), which are not included; you can replace them with (a+b)%c, (a-b)%c, (a*b)%c, ..., but in that case beware of overflows, especially in modpow and modmul (and note that (a-b)%c can be negative in C++). The code is not optimized yet; there are ways to speed it up considerably. Also, the prime table is fairly limited, so use the Sieve of Eratosthenes (SoE) or another algorithm to obtain primes up to max' in order to work safely.

DWORD _arithmetics_primes[]=
    {
    2,3,5,7,11,13,17,19,23,29,31,37,41,43,47,53,59,61,67,71,73,79,83,89,97,101,103,107,109,113,127,131,137,139,149,151,157,163,167,173,
    179,181,191,193,197,199,211,223,227,229,233,239,241,251,257,263,269,271,277,281,283,293,307,311,313,317,331,337,347,349,353,359,367,373,379,383,389,397,401,409,
    419,421,431,433,439,443,449,457,461,463,467,479,487,491,499,503,509,521,523,541,547,557,563,569,571,577,587,593,599,601,607,613,617,619,631,641,643,647,653,659,
    661,673,677,683,691,701,709,719,727,733,739,743,751,757,761,769,773,787,797,809,811,821,823,827,829,839,853,857,859,863,877,881,883,887,907,911,919,929,937,941,
    947,953,967,971,977,983,991,997,1009,1013,1019,1021,1031,1033,1039,1049,1051,1061,1063,1069,1087,1091,1093,1097,1103,1109,1117,1123,1129,1151,
    0}; // end of table is 0; the more primes in the table, the bigger the numbers and n that can be used
// compute NTT consts W=r^L%p for n
int i,j,k,n=16;
long w,W,iW,p,r,L,l,e;
long max=81*n;  // edit1: max num for NTT for my multiplication purposes (digits 0..9, 9*9=81)
                // note: 81*16 = 1296 is larger than the last table entry (1151),
                // so the table must be extended (e.g. with 1297) for n=16
for (e=1,j=0;e;j++)             // find prime p such that p%n==1 AND p>max
    {
    p=_arithmetics_primes[j];
    if (!p) break;              // end of table reached
    if ((p>max)&&(p%n==1))
     for (r=2;r<p;r++)          // check all r
        {
        for (l=1;l<p;l++)       // check all L=(p-1)/l that divide p-1
            {
            L=(p-1);
            if (L%l!=0) continue;
            L/=l;
            W=modpow(r,L,p);
            e=0;
            // W^0 and W^n must be 1, W^i for 0<i<n must not be 1
            for (w=1,i=0;i<=n;i++,w=modmul(w,W,p))
                {
                if ((i==0)      &&(w!=1)) { e=1; break; }
                if ((i==n)      &&(w!=1)) { e=1; break; }
                if ((i>0)&&(i<n)&&(w==1)) { e=1; break; }
                }
            if (!e) break;
            }
        if (!e) break;
        }
    }
if (e) { /* error: no combination of r,L,p found for n */ }
 W=modpow(r,    L,p);   // Wn for NTT
iW=modpow(r,p-1-L,p);   // Wn for INTT
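The code above relies on modadd, modsub, modmul and modpow, and on a large enough prime table. Here is a minimal sketch of both, assuming 0 <= a, b < m and m < 2^31 so every product fits in a 64-bit long long; these bodies (and the names primes_up_to and find_ntt_prime) are hypothetical stand-ins, not the author's routines. The prime search uses the Sieve of Eratosthenes suggested above.

```cpp
#include <cassert>
#include <vector>

// Hypothetical stand-ins for the missing helpers.
// Assumes 0 <= a, b < m and m < 2^31, so a*b fits in long long.
long modadd(long a, long b, long m) {
    long long s = (long long)a + b;
    return (long)(s >= m ? s - m : s);
}
long modsub(long a, long b, long m) {
    return a >= b ? a - b : a - b + m;   // never negative
}
long modmul(long a, long b, long m) {
    return (long)((long long)a * b % m);
}
long modpow(long a, long e, long m) {    // square-and-multiply
    long r = 1 % m;
    a %= m;
    while (e > 0) {
        if (e & 1) r = modmul(r, a, m);
        a = modmul(a, a, m);
        e >>= 1;
    }
    return r;
}

// Sieve of Eratosthenes: all primes <= limit.
std::vector<long> primes_up_to(long limit) {
    std::vector<bool> composite(limit + 1, false);
    std::vector<long> primes;
    for (long i = 2; i <= limit; i++) {
        if (composite[i]) continue;
        primes.push_back(i);
        for (long long j = (long long)i * i; j <= limit; j += i) composite[j] = true;
    }
    return primes;
}

// Smallest prime p <= limit with p > maxv and p % n == 1, or 0 if none.
long find_ntt_prime(long n, long maxv, long limit) {
    for (long p : primes_up_to(limit))
        if (p > maxv && p % n == 1) return p;
    return 0;
}
```

For example, find_ntt_prime(16, 81*16, 5000) returns 1297, which can then be fed to the r,L search.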

And here are my slow NTT and INTT implementations (I haven't got to fast NTT/INTT yet). Both were tested successfully with Schönhage–Strassen multiplication.

//---------------------------------------------------------------------------
// slow O(n^2) NTT: dst[j] = sum(i=0..n-1) of ( w^(i*j)*src[i] ) mod m
void NTT(long *dst,long *src,long n,long m,long w)
    {
    long i,j,wj,wi,a;
    for (wj=1,j=0;j<n;j++)          // wj = w^j
        {
        a=0;
        for (wi=1,i=0;i<n;i++)      // wi = w^(i*j)
            {
            a=modadd(a,modmul(wi,src[i],m),m);
            wi=modmul(wi,wj,m);
            }
        dst[j]=a;
        wj=modmul(wj,w,m);
        }
    }
//---------------------------------------------------------------------------
// slow O(n^2) INTT: same sum with the INTT basis, then scaled by rN = 1/n mod m
void INTT(long *dst,long *src,long n,long m,long w)
    {
    long i,j,wi,wj,rN,a;
    rN=modpow(n,m-2,m);             // n^(m-2) = 1/n mod m (m is prime)
    for (wj=1,j=0;j<n;j++)
        {
        a=0;
        for (wi=1,i=0;i<n;i++)
            {
            a=modadd(a,modmul(wi,src[i],m),m);
            wi=modmul(wi,wj,m);
            }
        dst[j]=modmul(a,rN,m);
        wj=modmul(wj,w,m);
        }
    }
//---------------------------------------------------------------------------


dst is the destination array
src is the source array
n is the array size
m is the modulus (p)
w is the basis (Wn for NTT, iW for INTT)
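A minimal usage sketch: multiplying two decimal numbers by convolving their digit vectors with the slow transform. This is plain NTT convolution, not the full Schönhage–Strassen algorithm. Instead of the r,L search above it assumes the well-known prime p = 998244353 = 119*2^23+1 with primitive root 3, so Wn = 3^((p-1)/n) for any power-of-two n up to 2^23. The names P, G, pw, ntt_slow and multiply are hypothetical.

```cpp
#include <cassert>
#include <cstdint>
#include <string>
#include <vector>

using u64 = std::uint64_t;
const u64 P = 998244353;  // 119 * 2^23 + 1
const u64 G = 3;          // primitive root mod P

u64 pw(u64 b, u64 e) {    // b^e mod P
    u64 r = 1;
    b %= P;
    while (e) {
        if (e & 1) r = r * b % P;
        b = b * b % P;
        e >>= 1;
    }
    return r;
}

// Same O(n^2) algorithm as NTT() above: dst[j] = sum(i) w^(i*j) * src[i] mod P.
std::vector<u64> ntt_slow(const std::vector<u64>& src, u64 w) {
    const size_t n = src.size();
    std::vector<u64> dst(n);
    u64 wj = 1;
    for (size_t j = 0; j < n; j++) {
        u64 a = 0, wi = 1;
        for (size_t i = 0; i < n; i++) {
            a = (a + wi * src[i]) % P;
            wi = wi * wj % P;
        }
        dst[j] = a;
        wj = wj * w % P;
    }
    return dst;
}

// Multiply two non-negative decimal strings via NTT convolution of digits.
std::string multiply(const std::string& a, const std::string& b) {
    size_t n = 1;
    while (n < a.size() + b.size()) n <<= 1;      // room for the full product
    std::vector<u64> x(n, 0), y(n, 0);
    for (size_t i = 0; i < a.size(); i++) x[i] = a[a.size() - 1 - i] - '0';
    for (size_t i = 0; i < b.size(); i++) y[i] = b[b.size() - 1 - i] - '0';
    const u64 w = pw(G, (P - 1) / n);             // Wn for NTT
    const u64 iw = pw(w, P - 2);                  // Wn for INTT
    const u64 rn = pw(n, P - 2);                  // 1/n mod P
    std::vector<u64> X = ntt_slow(x, w), Y = ntt_slow(y, w);
    for (size_t i = 0; i < n; i++) X[i] = X[i] * Y[i] % P;  // pointwise product
    std::vector<u64> z = ntt_slow(X, iw);
    std::string out;                              // digits, least significant first
    u64 carry = 0;
    for (size_t i = 0; i < n; i++) {
        const u64 v = z[i] * rn % P + carry;
        out.push_back(char('0' + v % 10));
        carry = v / 10;
    }
    while (carry) { out.push_back(char('0' + carry % 10)); carry /= 10; }
    while (out.size() > 1 && out.back() == '0') out.pop_back();
    return std::string(out.rbegin(), out.rend());
}
```

Padding n to a power of two at least as long as the product keeps the cyclic convolution from wrapping around, and every convolution coefficient stays far below p, which is the p > max' requirement.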

Hope this helps someone. If I forgot something, please write...

[edit1: fast NTT/INTT]

I finally managed to get fast NTT/INTT working. It was a little bit trickier than the normal FFT:

//---------------------------------------------------------------------------
// fast recursive NTT, O(n log n); n must be a power of 2 and w a primitive n-th root mod m
// note: src is used as scratch space and gets overwritten
void _NFTT(long *dst,long *src,long n,long m,long w)
    {
    if (n<=1) { if (n==1) dst[0]=src[0]; return; }
    long i,j,a0,a1,n2=n>>1,w2=modmul(w,w,m);  // w2=w^2 is a primitive (n/2)-th root
    // reorder even,odd
    for (i=0,j=0;i<n2;i++,j+=2) dst[i]=src[j];
    for (    j=1;i<n ;i++,j+=2) dst[i]=src[j];
    // recursion (sub-results are written back into src)
    _NFTT(src   ,dst   ,n2,m,w2);   // even
    _NFTT(src+n2,dst+n2,n2,m,w2);   // odd
    // restore results (w2 is reused here as the running power w^i)
    for (w2=1,i=0,j=n2;i<n2;i++,j++,w2=modmul(w2,w,m))
        {
        a0=src[i];
        a1=modmul(src[j],w2,m);
        dst[i]=modadd(a0,a1,m);
        dst[j]=modsub(a0,a1,m);     // valid because w^(n/2) == -1 mod m
        }
    }
//---------------------------------------------------------------------------
// fast INTT: call with the INTT basis (iW), then scale by rN = 1/n mod m
void _INFTT(long *dst,long *src,long n,long m,long w)
    {
    long i,rN;
    rN=modpow(n,m-2,m);
    _NFTT(dst,src,n,m,w);
    for (i=0;i<n;i++) dst[i]=modmul(dst[i],rN,m);
    }
//---------------------------------------------------------------------------

[edit3]

I have optimized my code (3x faster than the code above), but I was still not satisfied with it, so I started a new question with it. There I optimized my code even further (about 40x faster than the code above), so it's almost the same speed as a floating-point FFT of the same bit size. Link to it is here:



...