跳转至

拉格朗日插值

约 349 个字 100 行代码 预计阅读时间 2 分钟

拉格朗日插值

原版

n 个不同点 (x_i,y_i) 确定唯一 n-1 次多项式 g

对于第 k 个点,构造一个多项式,满足:

F_k(x_k)=y_k

(k\ne i)F_k(x_i)=0

如果满足上述式子,则 g(X)=\sum\limits_{i=1}^n F_k(X)

构造0点,T_k(X)=\prod\limits_{i=1}^n[k\ne i](X-x_i)

构造 y_k 点,把 x_k 代入,T_k(x_k)=\prod\limits_{i=1}^n[k\ne i](x_k-x_i),这是个常数

修正这个多项式,使 F_k(x_k)=y_k,即 F_k(X)=\frac{y_kT_k(X)}{T_k(x_k)}

因此,g(X)=\sum\limits_{k=1}^ny_k\prod\limits_{i=1,k\ne i}^n\frac{X-x_i}{x_k-x_i}

根据上式,已知 n 个点值,可以 O(n^2) 预处理,O(nlog(mod)) 求出 g(k)

const int N=2e3+3;
const ll mod=998244353;
struct lagrange
{
    int n;
    ll x[N],y[N],w[N];
    long long inv(long long base)
    {
        long long result=1,exponent=mod-2;
        for(;exponent>0;exponent>>=1)
        {
            if(exponent&1)
                result=result*base%mod;
            base=base*base%mod;
        }
        return result;
    }
    void init()
    {
        n=0;
    }
    void insert(ll xx,ll yy)
    {
        x[++n]=xx;
        y[n]=yy;
        ll s=1;
        for(int i=1;i<n;++i)
        {
            w[i]=w[i]*(x[i]-xx)%mod;
            s=s*(xx-x[i])%mod;
        }
        w[n]=s;
    }
    void init2()
    {
        for(int i=1;i<=n;++i)
            w[i]=y[i]*inv(w[i])%mod;
    }
    ll lag(ll k)
    {
        ll s=1;
        for(int i=1;i<=n;++i)
        {
            if(k==x[i])
                return y[i];
            s=s*(k-x[i])%mod;
        }
        ll ans=0;
        for(int i=1;i<=n;++i)
            ans=(ans+s*inv(k-x[i])%mod*w[i]%mod)%mod;
        return (ans+mod)%mod;
    }
}la;
la.init();
la.insert(x,y);
la.init2();
la.lag(k);

特化版

当给定点值 x_i 连续,可以 O(n) 预处理,O(n)g(k)

g(n)=\sum\limits_{i=1}^{k}y_i\prod\limits_{j=1,j\ne i}^{k}\frac{n-j}{i-j}

g(n)=\sum\limits_{i=1}^{k}y_i\left(\prod\limits_{j=1,j\ne i}^{k}(i-j)\right)^{-1}\prod\limits_{j=1,j\ne i}^{k}(n-j)

g(n)=\sum\limits_{i=1}^{k}y_i\left(\prod\limits_{j=1}^{i-1}(i-j)\prod\limits_{j=i+1}^{k}(i-j)\right)^{-1}\prod\limits_{j=1}^{i-1}(n-j)\prod\limits_{j=i+1}^{k}(n-j)

g(n)=\sum\limits_{i=1}^{k}y_i\left(\prod\limits_{j=1}^{i-1}j\prod\limits_{j=i-k}^{-1}j\right)^{-1}\prod\limits_{j=1}^{i-1}(n-j)\prod\limits_{j=i+1}^{k}(n-j)

pre_i=\prod\limits_{j=1}^i(n-j),suf_i=\prod\limits_{j=i}^k(n-j)

g(n)=\sum\limits_{i=1}^{k}y_i(-1)^{k-i}\left((i-1)!(k-i)!\right)^{-1}pre_{i-1}suf_{i+1}

const int N=1e6+3;
const ll mod=1e9+7;
long long fac[N],inv[N];
long long fast_power(long long base, long long exponent)
{
    long long result = 1;
    for (; exponent > 0; exponent >>= 1)
    {
        if (exponent & 1)
            result = result * base % mod;
        base = base * base % mod;
    }
    return result;
}
void getC()
{
    fac[0]=1;
    for(int i=1;i<N;++i)
        fac[i]=fac[i-1]*i%mod;
    inv[N-1]=fast_power(fac[N-1],mod-2);
    for(int i=N-1;i;--i)
        inv[i-1]=inv[i]*i%mod;
}
struct lagrange
{
    ll y[N],pre[N],suf[N];
    ll lag(ll m,int k)//g(m),k个点
    {
        if(m<=k)
            return y[m];
        pre[0]=suf[k+1]=1;
        for(int i=1;i<k;++i)
            pre[i]=pre[i-1]*(m-i)%mod;
        for(int i=k;i>1;--i)
            suf[i]=suf[i+1]*(m-i)%mod;
        ll ans=0;
        for(int i=1;i<=k;++i)
            ans=(ans+y[i]*inv[i-1]%mod*inv[k-i]%mod*pre[i-1]%mod*suf[i+1]%mod*((k-i)&1?-1:1))%mod;
        return (ans+mod)%mod;
    }
}la;
la.y[i]=y;
la.lag(m,k);