拉格朗日插值
约 349 个字 100 行代码 预计阅读时间 2 分钟
拉格朗日插值¶
原版¶
n 个不同点 (x_i,y_i) 确定唯一 n-1 次多项式 g
对于第 k 个点,构造一个多项式,满足:
F_k(x_k)=y_k
(k\ne i)F_k(x_i)=0
如果满足上述式子,则 g(X)=\sum\limits_{i=1}^n F_k(X)
构造0点,T_k(X)=\prod\limits_{i=1}^n[k\ne i](X-x_i)
构造 y_k 点,把 x_k 代入,T_k(x_k)=\prod\limits_{i=1}^n[k\ne i](x_k-x_i),这是个常数
修正这个多项式,使 F_k(x_k)=y_k,即 F_k(X)=\frac{y_kT_k(X)}{T_k(x_k)}
因此,g(X)=\sum\limits_{k=1}^ny_k\prod\limits_{i=1,k\ne i}^n\frac{X-x_i}{x_k-x_i}
根据上式,已知 n 个点值,可以 O(n^2) 预处理,O(nlog(mod)) 求出 g(k)
const int N=2e3+3;
const ll mod=998244353;
struct lagrange
{
int n;
ll x[N],y[N],w[N];
long long inv(long long base)
{
long long result=1,exponent=mod-2;
for(;exponent>0;exponent>>=1)
{
if(exponent&1)
result=result*base%mod;
base=base*base%mod;
}
return result;
}
void init()
{
n=0;
}
void insert(ll xx,ll yy)
{
x[++n]=xx;
y[n]=yy;
ll s=1;
for(int i=1;i<n;++i)
{
w[i]=w[i]*(x[i]-xx)%mod;
s=s*(xx-x[i])%mod;
}
w[n]=s;
}
void init2()
{
for(int i=1;i<=n;++i)
w[i]=y[i]*inv(w[i])%mod;
}
ll lag(ll k)
{
ll s=1;
for(int i=1;i<=n;++i)
{
if(k==x[i])
return y[i];
s=s*(k-x[i])%mod;
}
ll ans=0;
for(int i=1;i<=n;++i)
ans=(ans+s*inv(k-x[i])%mod*w[i]%mod)%mod;
return (ans+mod)%mod;
}
}la;
la.init();
la.insert(x,y);
la.init2();
la.lag(k);
特化版¶
当给定点值 x_i 连续,可以 O(n) 预处理,O(n) 求 g(k)
g(n)=\sum\limits_{i=1}^{k}y_i\prod\limits_{j=1,j\ne i}^{k}\frac{n-j}{i-j}
g(n)=\sum\limits_{i=1}^{k}y_i\left(\prod\limits_{j=1,j\ne i}^{k}(i-j)\right)^{-1}\prod\limits_{j=1,j\ne i}^{k}(n-j)
g(n)=\sum\limits_{i=1}^{k}y_i\left(\prod\limits_{j=1}^{i-1}(i-j)\prod\limits_{j=i+1}^{k}(i-j)\right)^{-1}\prod\limits_{j=1}^{i-1}(n-j)\prod\limits_{j=i+1}^{k}(n-j)
g(n)=\sum\limits_{i=1}^{k}y_i\left(\prod\limits_{j=1}^{i-1}j\prod\limits_{j=i-k}^{-1}j\right)^{-1}\prod\limits_{j=1}^{i-1}(n-j)\prod\limits_{j=i+1}^{k}(n-j)
pre_i=\prod\limits_{j=1}^i(n-j),suf_i=\prod\limits_{j=i}^k(n-j)
g(n)=\sum\limits_{i=1}^{k}y_i(-1)^{k-i}\left((i-1)!(k-i)!\right)^{-1}pre_{i-1}suf_{i+1}
const int N=1e6+3;
const ll mod=1e9+7;
long long fac[N],inv[N];
long long fast_power(long long base, long long exponent)
{
long long result = 1;
for (; exponent > 0; exponent >>= 1)
{
if (exponent & 1)
result = result * base % mod;
base = base * base % mod;
}
return result;
}
void getC()
{
fac[0]=1;
for(int i=1;i<N;++i)
fac[i]=fac[i-1]*i%mod;
inv[N-1]=fast_power(fac[N-1],mod-2);
for(int i=N-1;i;--i)
inv[i-1]=inv[i]*i%mod;
}
struct lagrange
{
ll y[N],pre[N],suf[N];
ll lag(ll m,int k)//g(m),k个点
{
if(m<=k)
return y[m];
pre[0]=suf[k+1]=1;
for(int i=1;i<k;++i)
pre[i]=pre[i-1]*(m-i)%mod;
for(int i=k;i>1;--i)
suf[i]=suf[i+1]*(m-i)%mod;
ll ans=0;
for(int i=1;i<=k;++i)
ans=(ans+y[i]*inv[i-1]%mod*inv[k-i]%mod*pre[i-1]%mod*suf[i+1]%mod*((k-i)&1?-1:1))%mod;
return (ans+mod)%mod;
}
}la;
la.y[i]=y;
la.lag(m,k);