diff --git a/prover-ray/crypto/koalabear/fiatshamir/poseidon2.go b/prover-ray/crypto/koalabear/fiatshamir/poseidon2.go index 08fe59edb7c..e8c1a8a4046 100644 --- a/prover-ray/crypto/koalabear/fiatshamir/poseidon2.go +++ b/prover-ray/crypto/koalabear/fiatshamir/poseidon2.go @@ -38,7 +38,7 @@ func (fs *FiatShamir) UpdateExt(vec ...field.Ext) { if len(vec) == 0 { return } - vElems := unsafe.Slice((*field.Element)(unsafe.Pointer(&vec[0])), 4*len(vec)) //nolint + vElems := unsafe.Slice((*field.Element)(unsafe.Pointer(&vec[0])), field.ExtensionDegree*len(vec)) //nolint fs.h.WriteElements(vElems...) } @@ -96,6 +96,8 @@ func (fs *FiatShamir) RandomField() field.Octuplet { } // RandomFext samples a random extension field element from the transcript. +// Uses the first ExtensionDegree (=6) of the 8 octuplet coordinates; the +// remaining two are discarded. func (fs *FiatShamir) RandomFext() field.Ext { s := fs.RandomField() // already calls safeguardUpdate() var res field.Ext @@ -103,6 +105,8 @@ func (fs *FiatShamir) RandomFext() field.Ext { res.B0.A1 = s[1] res.B1.A0 = s[2] res.B1.A1 = s[3] + res.B2.A0 = s[4] + res.B2.A1 = s[5] return res } diff --git a/prover-ray/crypto/koalabear/reedsolomon/reedsolomon.go b/prover-ray/crypto/koalabear/reedsolomon/reedsolomon.go index 90e57adfc1e..f7d39ddd254 100644 --- a/prover-ray/crypto/koalabear/reedsolomon/reedsolomon.go +++ b/prover-ray/crypto/koalabear/reedsolomon/reedsolomon.go @@ -146,7 +146,7 @@ func (r *RsParams) rsEncodeExt(v []field.Ext) []field.Ext { smallDomain := r.Domains[0] largeDomain := r.Domains[1] - smallDomain.FFTInverseExt(expandedCoeffs[:n], fft.DIF, fft.WithNbTasks(1)) + smallDomain.FFTInverseExt6(expandedCoeffs[:n], fft.DIF, fft.WithNbTasks(1)) // this loop dispatches the values that are all located at the beginning // of the domain to the entire domain by homothety @@ -155,16 +155,16 @@ func (r *RsParams) rsEncodeExt(v []field.Ext) []field.Ext { expandedCoeffs[j] = field.Ext{} } - 
largeDomain.FFTExt(expandedCoeffs, fft.DIT, fft.WithNbTasks(1)) + largeDomain.FFTExt6(expandedCoeffs, fft.DIT, fft.WithNbTasks(1)) return expandedCoeffs } // fast path; we avoid the bit reverse operations and work on the smaller domain. - inputCoeffs := extensions.Vector(expandedCoeffs[:r.NbColumns()]) - r.Domains[0].FFTInverseExt(inputCoeffs, fft.DIF, fft.WithNbTasks(1)) + inputCoeffs := extensions.VectorE6(expandedCoeffs[:r.NbColumns()]) + r.Domains[0].FFTInverseExt6(inputCoeffs, fft.DIF, fft.WithNbTasks(1)) inputCoeffs.MulByElement(inputCoeffs, r.CosetTableBitReverse) - r.Domains[0].FFTExt(inputCoeffs, fft.DIT, fft.WithNbTasks(1)) + r.Domains[0].FFTExt6(inputCoeffs, fft.DIT, fft.WithNbTasks(1)) for j := r.NbColumns() - 1; j >= 0; j-- { expandedCoeffs[rho*j+1] = expandedCoeffs[j] expandedCoeffs[rho*j] = v[j] @@ -184,7 +184,7 @@ func (r *RsParams) IsCodewordExt(v []field.Ext) error { coeffs := make([]field.Ext, r.NbEncodedColumns()) copy(coeffs, v) - r.Domains[1].FFTInverseExt(coeffs, fft.DIF, fft.WithNbTasks(1)) + r.Domains[1].FFTInverseExt6(coeffs, fft.DIF, fft.WithNbTasks(1)) utils.BitReverse(coeffs) for i := r.NbColumns(); i < r.NbEncodedColumns(); i++ { c := coeffs[i] @@ -210,6 +210,6 @@ func (r *RsParams) EncodeFromMonomials(coefficients []field.Ext) []field.Ext { copy(buf, coefficients) // DIT FFT expects bit-reversed input; natural-order evaluations come out. 
utils.BitReverse(buf) - r.Domains[1].FFTExt(buf, fft.DIT) + r.Domains[1].FFTExt6(buf, fft.DIT) return buf } diff --git a/prover-ray/crypto/koalabear/ringsis/ringsis.go b/prover-ray/crypto/koalabear/ringsis/ringsis.go index aab02c8e088..ea26c00c04b 100644 --- a/prover-ray/crypto/koalabear/ringsis/ringsis.go +++ b/prover-ray/crypto/koalabear/ringsis/ringsis.go @@ -197,7 +197,7 @@ func (k Key) HashModXnMinus1(limbs []field.Ext) []field.Ext { inputReader = nil } - domain.FFTExt(polyK, fft.DIF) + domain.FFTExt6(polyK, fft.DIF) domain.FFT(a, fft.DIF) var tmp field.Ext @@ -212,7 +212,7 @@ func (k Key) HashModXnMinus1(limbs []field.Ext) []field.Ext { } // by linearity, we defer the fft inverse at the end - domain.FFTInverseExt(r, fft.DIT) + domain.FFTInverseExt6(r, fft.DIT) // also account for the Montgommery issue : in gnark's implementation // the key is implictly multiplied by RInv diff --git a/prover-ray/crypto/koalabear/vortex/utils.go b/prover-ray/crypto/koalabear/vortex/utils.go index 7369b51f023..892c8a7d027 100644 --- a/prover-ray/crypto/koalabear/vortex/utils.go +++ b/prover-ray/crypto/koalabear/vortex/utils.go @@ -20,8 +20,9 @@ func init() { } var ( - // ErrSizeNotAMultipleOfFour is returned when the input size is not a multiple of 4. - ErrSizeNotAMultipleOfFour = errors.New("size(inputs) should be a multiple of 4") + // ErrSizeNotAMultipleOfExtDegree is returned when the input size is not a + // multiple of field.ExtensionDegree. + ErrSizeNotAMultipleOfExtDegree = fmt.Errorf("size(inputs) should be a multiple of %d", field.ExtensionDegree) // ErrSizesDontMatch is returned when input and output slices have different sizes. ErrSizesDontMatch = errors.New("size(inputs) should be equal to size(outputs)") // ErrNotAPowerOfTwo is returned when the input size is not a power of two. 
@@ -32,47 +33,53 @@ func fftExtInvHintEmulated(_ *big.Int, inputs []*big.Int, output []*big.Int) err return emulated.UnwrapHint(inputs, output, fftExtInvHintNative) } -// Each chunk of 4 inputs corresponds to a E4 element. +// Each chunk of field.ExtensionDegree (=6) inputs corresponds to one Ext element. func fftExtInvHintNative(_ *big.Int, inputs, outputs []*big.Int) error { - if len(inputs)%4 != 0 { - return ErrSizeNotAMultipleOfFour + const d = field.ExtensionDegree + + if len(inputs)%d != 0 { + return ErrSizeNotAMultipleOfExtDegree } - n := len(inputs) / 4 + n := len(inputs) / d _n := ecc.NextPowerOfTwo(uint64(n)) if _n != uint64(n) { return ErrNotAPowerOfTwo } - d := fft.NewDomain(uint64(n)) + dom := fft.NewDomain(uint64(n)) _res := make([]field.Ext, n) for i := 0; i < n; i++ { - _res[i].B0.A0.SetBigInt(inputs[4*i]) - _res[i].B0.A1.SetBigInt(inputs[4*i+1]) - _res[i].B1.A0.SetBigInt(inputs[4*i+2]) - _res[i].B1.A1.SetBigInt(inputs[4*i+3]) + _res[i].B0.A0.SetBigInt(inputs[d*i+0]) + _res[i].B0.A1.SetBigInt(inputs[d*i+1]) + _res[i].B1.A0.SetBigInt(inputs[d*i+2]) + _res[i].B1.A1.SetBigInt(inputs[d*i+3]) + _res[i].B2.A0.SetBigInt(inputs[d*i+4]) + _res[i].B2.A1.SetBigInt(inputs[d*i+5]) } - d.FFTInverseExt(_res, fft.DIF) + dom.FFTInverseExt6(_res, fft.DIF) utils.BitReverse(_res) // we're supposed to have tail of zeros. Check - for i := len(outputs) / 4; i < len(_res); i++ { + for i := len(outputs) / d; i < len(_res); i++ { if !_res[i].IsZero() { return fmt.Errorf("fftExtInvHintNative: expected zero at position %d, got %s", i, _res[i].String()) } } // now truncate to avoid returning the zeros. 
In non-native we range check // the results so it would lead to overhead - _res = _res[:len(outputs)/4] + _res = _res[:len(outputs)/d] for i := range _res { - _res[i].B0.A0.BigInt(outputs[4*i]) - _res[i].B0.A1.BigInt(outputs[4*i+1]) - _res[i].B1.A0.BigInt(outputs[4*i+2]) - _res[i].B1.A1.BigInt(outputs[4*i+3]) + _res[i].B0.A0.BigInt(outputs[d*i+0]) + _res[i].B0.A1.BigInt(outputs[d*i+1]) + _res[i].B1.A0.BigInt(outputs[d*i+2]) + _res[i].B1.A1.BigInt(outputs[d*i+3]) + _res[i].B2.A0.BigInt(outputs[d*i+4]) + _res[i].B2.A1.BigInt(outputs[d*i+5]) } return nil diff --git a/prover-ray/crypto/koalabear/vortex/verifier_common.go b/prover-ray/crypto/koalabear/vortex/verifier_common.go index b97dd57b4ed..a2c7e127e37 100644 --- a/prover-ray/crypto/koalabear/vortex/verifier_common.go +++ b/prover-ray/crypto/koalabear/vortex/verifier_common.go @@ -77,7 +77,7 @@ func CheckLinComb( y := polynomials.EvalCanonical(field.VecFromBase(fullCol), field.ElemFromExt(alpha)).AsExt() other := evals[selectedColID] - if y != other { + if !y.Equal(&other) { return fmt.Errorf("the linear combination is inconsistent %v : %v", y.String(), other.String()) } } diff --git a/prover-ray/go.mod b/prover-ray/go.mod index efab8ee83a4..83c5c406a3f 100644 --- a/prover-ray/go.mod +++ b/prover-ray/go.mod @@ -4,7 +4,7 @@ go 1.25.7 require ( github.com/consensys/gnark v0.14.1-0.20260219004710-bbfb2f70a565 - github.com/consensys/gnark-crypto v0.20.1 + github.com/consensys/gnark-crypto v0.20.2-0.20260514182922-df0578435b08 github.com/consensys/go-corset v1.2.10 github.com/go-playground/assert/v2 v2.2.0 github.com/sirupsen/logrus v1.9.4 diff --git a/prover-ray/go.sum b/prover-ray/go.sum index 441e3693950..9a7849aaf0b 100644 --- a/prover-ray/go.sum +++ b/prover-ray/go.sum @@ -13,8 +13,8 @@ github.com/chzyer/test v0.0.0-20180213035817-a1ea475d72b1/go.mod h1:Q3SI9o4m/ZMn github.com/chzyer/test v1.0.0/go.mod h1:2JlltgoNkt4TW/z9V/IzDdFaMTM2JPIi26O1pF38GC8= github.com/consensys/gnark v0.14.1-0.20260219004710-bbfb2f70a565 
h1:NlOAmbLYsVb/hcuOBxza6CAA+233tB0eFiunGVEMyv4= github.com/consensys/gnark v0.14.1-0.20260219004710-bbfb2f70a565/go.mod h1:EoWWbEboQRydCqJDSA7zrFxucIeoy/5R+MDx04oFpF4= -github.com/consensys/gnark-crypto v0.20.1 h1:PXDUBvk8AzhvWowHLWBEAfUQcV1/aZgWIqD6eMpXmDg= -github.com/consensys/gnark-crypto v0.20.1/go.mod h1:RBWrSgy+IDbGR69RRV313th3M/aZU1ubk2om+qHuTSc= +github.com/consensys/gnark-crypto v0.20.2-0.20260514182922-df0578435b08 h1:EdljQKaHxACX5JMSTXlVM9R8qASU/W1husqDsB2b5tU= +github.com/consensys/gnark-crypto v0.20.2-0.20260514182922-df0578435b08/go.mod h1:NzeBHSZ49bIM7RtrNTYYR2kymTqwvI/A4eTgQlyQc+Q= github.com/consensys/go-corset v1.2.10 h1:uKUICiHmERuMWzDRiRJr285fV2WncNGiCENSdNcQodY= github.com/consensys/go-corset v1.2.10/go.mod h1:QKFoNJZHdCrDslg9XFjk+GoFMgrhKSVdBNnx4hq7WJA= github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E= diff --git a/prover-ray/maths/koalabear/circuit/api_test.go b/prover-ray/maths/koalabear/circuit/api_test.go index 52c5c1d61de..a34d85e1f08 100644 --- a/prover-ray/maths/koalabear/circuit/api_test.go +++ b/prover-ray/maths/koalabear/circuit/api_test.go @@ -97,13 +97,14 @@ func TestVarEmulated(t *testing.T) { // TestExtCircuit tests Ext operations type TestExtCircuit struct { - A, B Ext - AddAB Ext - SubAB Ext - MulAB Ext - SquareA Ext - DivAB Ext - InvA Ext + A, B Ext + AddAB Ext + SubAB Ext + MulAB Ext + SquareA Ext + MulByNonResidueA Ext + DivAB Ext + InvA Ext } func (c *TestExtCircuit) Define(api frontend.API) error { @@ -121,6 +122,9 @@ func (c *TestExtCircuit) Define(api frontend.API) error { squareA := f.SquareExt(c.A) f.AssertIsEqualExt(squareA, c.SquareA) + mulByNonResidueA := f.MulByNonResidueExt(c.A) + f.AssertIsEqualExt(mulByNonResidueA, c.MulByNonResidueA) + divAB := f.DivExt(c.A, c.B) f.AssertIsEqualExt(divAB, c.DivAB) @@ -131,29 +135,33 @@ func (c *TestExtCircuit) Define(api frontend.API) error { } func getExtWitness() *TestExtCircuit { - var a, b, addab, subab, mulab, squarea, inva, divab 
field.Ext + var a, b, addab, subab, mulab, squarea, mulByNonResidueA, inva, divab field.Ext if _, err := a.SetRandom(); err != nil { panic(err) } if _, err := b.SetRandom(); err != nil { panic(err) } + var nonResidueRoot field.Ext + nonResidueRoot.B1.A0.SetOne() addab.Add(&a, &b) subab.Sub(&a, &b) mulab.Mul(&a, &b) squarea.Square(&a) + mulByNonResidueA.Mul(&a, &nonResidueRoot) divab.Div(&a, &b) inva.Inverse(&a) return &TestExtCircuit{ - A: NewExt(a), - B: NewExt(b), - AddAB: NewExt(addab), - SubAB: NewExt(subab), - MulAB: NewExt(mulab), - SquareA: NewExt(squarea), - DivAB: NewExt(divab), - InvA: NewExt(inva), + A: NewExt(a), + B: NewExt(b), + AddAB: NewExt(addab), + SubAB: NewExt(subab), + MulAB: NewExt(mulab), + SquareA: NewExt(squarea), + MulByNonResidueA: NewExt(mulByNonResidueA), + DivAB: NewExt(divab), + InvA: NewExt(inva), } } diff --git a/prover-ray/maths/koalabear/circuit/doc.go b/prover-ray/maths/koalabear/circuit/doc.go index 5158b55ecbb..d20438ef61a 100644 --- a/prover-ray/maths/koalabear/circuit/doc.go +++ b/prover-ray/maths/koalabear/circuit/doc.go @@ -2,7 +2,7 @@ // // This package contains two main types: // - [Var]: A circuit variable over the KoalaBear base field -// - [Ext]: A circuit variable over the degree-4 extension field +// - [Ext]: A circuit variable over the degree-6 extension field // // These types abstract over native and emulated arithmetic, allowing the same // circuit code to work in both native KoalaBear circuits and emulated circuits diff --git a/prover-ray/maths/koalabear/circuit/ext.go b/prover-ray/maths/koalabear/circuit/ext.go index c1c05e36f7b..721d68d78cd 100644 --- a/prover-ray/maths/koalabear/circuit/ext.go +++ b/prover-ray/maths/koalabear/circuit/ext.go @@ -19,16 +19,17 @@ func init() { mulExtHintNative, mulExtHintEmulated) } -// E2 is a quadratic extension element . +// E2 is a quadratic extension element. // It represents an element of F_p^2 = F_p[u] / (u^2 - 3). 
type E2 struct { A0, A1 Element } -// Ext is a circuit variable over the degree-4 extension field. -// It represents an element of F_p^4 = F_p^2[v] / (v^2 - u). +// Ext is a circuit variable over the degree-6 extension field. +// It represents an element of F_p^6 = F_p^2[v] / (v^3 - (u+1)), i.e. each +// element is stored as (B0, B1, B2) with each Bi in E2. type Ext struct { - B0, B1 E2 + B0, B1, B2 E2 } // --- Ext Constructors (for witness assignment) --- @@ -38,6 +39,7 @@ func NewExt(v field.Ext) Ext { return Ext{ B0: newE2(v.B0), B1: newE2(v.B1), + B2: newE2(v.B2), } } @@ -54,6 +56,7 @@ func NewFromBaseExt(v any) Ext { return Ext{ B0: E2{A0: NewElement(v), A1: z}, B1: E2{A0: z, A1: z}, + B2: E2{A0: z, A1: z}, } } @@ -63,21 +66,23 @@ func NewExtFromFrontendVar(v frontend.Variable) Ext { return Ext{ B0: E2{A0: WrapFrontendVariable(v), A1: z}, B1: E2{A0: z, A1: z}, + B2: E2{A0: z, A1: z}, } } -// NewExtFrom4FrontendVars creates an Ext from 4 frontend.Variable values. -// The order is: B0.A0, B0.A1, B1.A0, B1.A1. -func NewExtFrom4FrontendVars(b0a0, b0a1, b1a0, b1a1 frontend.Variable) Ext { +// NewExtFrom6FrontendVars creates an Ext from 6 frontend.Variable values, one +// per coordinate (in the order B0.A0, B0.A1, B1.A0, B1.A1, B2.A0, B2.A1). +func NewExtFrom6FrontendVars(b0a0, b0a1, b1a0, b1a1, b2a0, b2a1 frontend.Variable) Ext { return Ext{ B0: E2{A0: WrapFrontendVariable(b0a0), A1: WrapFrontendVariable(b0a1)}, B1: E2{A0: WrapFrontendVariable(b1a0), A1: WrapFrontendVariable(b1a1)}, + B2: E2{A0: WrapFrontendVariable(b2a0), A1: WrapFrontendVariable(b2a1)}, } } -// Coordinates returns all 4 base field coordinates. -func (x Ext) Coordinates() (b0a0, b0a1, b1a0, b1a1 Element) { - return x.B0.A0, x.B0.A1, x.B1.A0, x.B1.A1 +// Coordinates returns all 6 base field coordinates. 
+func (x Ext) Coordinates() (b0a0, b0a1, b1a0, b1a1, b2a0, b2a1 Element) { + return x.B0.A0, x.B0.A1, x.B1.A0, x.B1.A1, x.B2.A0, x.B2.A1 } // FromBaseVar creates an Ext from a Var (for in-circuit conversion). @@ -87,6 +92,7 @@ func FromBaseVar(v Element) Ext { return Ext{ B0: E2{A0: v, A1: z}, B1: E2{A0: z, A1: z}, + B2: E2{A0: z, A1: z}, } } @@ -95,22 +101,22 @@ func FromBaseVar(v Element) Ext { // ZeroExt returns the additive identity in the extension field. func (a *API) ZeroExt() Ext { z := a.Zero() - return Ext{B0: E2{A0: z, A1: z}, B1: E2{A0: z, A1: z}} + return Ext{B0: E2{A0: z, A1: z}, B1: E2{A0: z, A1: z}, B2: E2{A0: z, A1: z}} } // OneExt returns the multiplicative identity in the extension field. func (a *API) OneExt() Ext { z, o := a.Zero(), a.One() - return Ext{B0: E2{A0: o, A1: z}, B1: E2{A0: z, A1: z}} + return Ext{B0: E2{A0: o, A1: z}, B1: E2{A0: z, A1: z}, B2: E2{A0: z, A1: z}} } // FromBaseExt creates an Ext element with a base field value in the constant term. func (a *API) FromBaseExt(x Element) Ext { z := a.Zero() - return Ext{B0: E2{A0: x, A1: z}, B1: E2{A0: z, A1: z}} + return Ext{B0: E2{A0: x, A1: z}, B1: E2{A0: z, A1: z}, B2: E2{A0: z, A1: z}} } -// ConstExt creates a constant Ext element from an field.Ext. +// ConstExt creates a constant Ext element from a field.Ext. // This should be used during circuit definition to create constant extension field values. // For witness assignment, use NewExt instead. func (a *API) ConstExt(v field.Ext) Ext { @@ -123,6 +129,10 @@ func (a *API) ConstExt(v field.Ext) Ext { A0: a.Const(int64(v.B1.A0.Uint64())), A1: a.Const(int64(v.B1.A1.Uint64())), }, + B2: E2{ + A0: a.Const(int64(v.B2.A0.Uint64())), + A1: a.Const(int64(v.B2.A1.Uint64())), + }, } } @@ -131,25 +141,29 @@ func (a *API) ConstExt(v field.Ext) Ext { // AddExt returns x + y in the extension field. 
func (a *API) AddExt(x, y Ext) Ext { return Ext{ - B0: E2{A0: a.Add(x.B0.A0, y.B0.A0), A1: a.Add(x.B0.A1, y.B0.A1)}, - B1: E2{A0: a.Add(x.B1.A0, y.B1.A0), A1: a.Add(x.B1.A1, y.B1.A1)}, + B0: a.e2Add(x.B0, y.B0), + B1: a.e2Add(x.B1, y.B1), + B2: a.e2Add(x.B2, y.B2), } } // SubExt returns x - y in the extension field. func (a *API) SubExt(x, y Ext) Ext { return Ext{ - B0: E2{A0: a.Sub(x.B0.A0, y.B0.A0), A1: a.Sub(x.B0.A1, y.B0.A1)}, - B1: E2{A0: a.Sub(x.B1.A0, y.B1.A0), A1: a.Sub(x.B1.A1, y.B1.A1)}, + B0: a.e2Sub(x.B0, y.B0), + B1: a.e2Sub(x.B1, y.B1), + B2: a.e2Sub(x.B2, y.B2), } } // NegExt returns -x in the extension field. func (a *API) NegExt(x Ext) Ext { z := a.Zero() + zero := E2{A0: z, A1: z} return Ext{ - B0: E2{A0: a.Sub(z, x.B0.A0), A1: a.Sub(z, x.B0.A1)}, - B1: E2{A0: a.Sub(z, x.B1.A0), A1: a.Sub(z, x.B1.A1)}, + B0: a.e2Sub(zero, x.B0), + B1: a.e2Sub(zero, x.B1), + B2: a.e2Sub(zero, x.B2), } } @@ -157,22 +171,22 @@ func (a *API) NegExt(x Ext) Ext { func (a *API) DoubleExt(x Ext) Ext { two := big.NewInt(2) return Ext{ - B0: E2{A0: a.MulConst(x.B0.A0, two), A1: a.MulConst(x.B0.A1, two)}, - B1: E2{A0: a.MulConst(x.B1.A0, two), A1: a.MulConst(x.B1.A1, two)}, + B0: a.e2MulConst(x.B0, two), + B1: a.e2MulConst(x.B1, two), + B2: a.e2MulConst(x.B2, two), } } -// qnrE2 is the non-residue constant for E2 extension (value 3). -// Used with MulConst to avoid unnecessary range checks. +// qnrE2 is the quadratic non-residue constant for E2: u^2 = 3. var qnrE2 = big.NewInt(3) -// e2MulByNonResidue multiplies an E2 by the non-residue u (where u^2 = 3). -// Returns (3*a1, a0). -func (a *API) e2MulByNonResidue(x E2) E2 { - return E2{ - A0: a.MulConst(x.A1, qnrE2), - A1: x.A0, - } +// e2MulByCubicNonResidue multiplies an E2 by the cubic non-residue (u+1). +// Given x = a0 + a1*u, (a0 + a1*u)*(1+u) = (a0+3*a1) + (a0+a1)*u (because u^2=3). 
+func (a *API) e2MulByCubicNonResidue(x E2) E2 { + z1 := a.Add(x.A0, x.A1) + z0 := a.MulConst(x.A1, qnrE2) // 3*a1 + z0 = a.Add(z0, x.A0) // a0 + 3*a1 + return E2{A0: z0, A1: z1} } // e2Add returns x + y in E2. @@ -231,22 +245,41 @@ func (a *API) e2MulConst(x E2, c *big.Int) E2 { } } -// MulExt returns x * y in the extension field using Karatsuba. -// (B0 + B1*v) * (C0 + C1*v) where v^2 = u +// MulExt returns x * y in the extension field using Karatsuba over E2. +// Implements Algorithm 13 from https://eprint.iacr.org/2010/354.pdf, specialized +// for E6 = E2[v]/(v^3 - (u+1)). Costs 6 E2 multiplications (≈18 variable base muls). func (a *API) MulExt(x, y Ext, more ...*Ext) Ext { - l1 := a.e2Add(x.B0, x.B1) - l2 := a.e2Add(y.B0, y.B1) - u := a.e2Mul(l1, l2) // (B0+B1)(C0+C1) - ac := a.e2Mul(x.B0, y.B0) // B0*C0 - bd := a.e2Mul(x.B1, y.B1) // B1*C1 - - sum := a.e2Add(ac, bd) - b1 := a.e2Sub(u, sum) // (B0+B1)(C0+C1) - B0*C0 - B1*C1 - - bdNR := a.e2MulByNonResidue(bd) - b0 := a.e2Add(ac, bdNR) - - result := Ext{B0: b0, B1: b1} + t0 := a.e2Mul(x.B0, y.B0) + t1 := a.e2Mul(x.B1, y.B1) + t2 := a.e2Mul(x.B2, y.B2) + + // z0 = ((B1+B2)*(C1+C2) - t1 - t2) * (u+1) + t0 + c0 := a.e2Add(x.B1, x.B2) + tmp := a.e2Add(y.B1, y.B2) + c0 = a.e2Mul(c0, tmp) + c0 = a.e2Sub(c0, t1) + c0 = a.e2Sub(c0, t2) + c0 = a.e2MulByCubicNonResidue(c0) + c0 = a.e2Add(c0, t0) + + // z1 = (B0+B1)*(C0+C1) - t0 - t1 + t2*(u+1) + c1 := a.e2Add(x.B0, x.B1) + tmp = a.e2Add(y.B0, y.B1) + c1 = a.e2Mul(c1, tmp) + c1 = a.e2Sub(c1, t0) + c1 = a.e2Sub(c1, t1) + t2NR := a.e2MulByCubicNonResidue(t2) + c1 = a.e2Add(c1, t2NR) + + // z2 = (B0+B2)*(C0+C2) - t0 - t2 + t1 + c2 := a.e2Add(x.B0, x.B2) + tmp = a.e2Add(y.B0, y.B2) + c2 = a.e2Mul(c2, tmp) + c2 = a.e2Sub(c2, t0) + c2 = a.e2Sub(c2, t2) + c2 = a.e2Add(c2, t1) + + result := Ext{B0: c0, B1: c1, B2: c2} if len(more) > 0 { return a.MulExt(result, *more[0], more[1:]...) 
@@ -254,17 +287,47 @@ func (a *API) MulExt(x, y Ext, more ...*Ext) Ext { return result } -// SquareExt returns x^2 in the extension field. +// SquareExt returns x^2 in the extension field, following Algorithm 16 from +// https://eprint.iacr.org/2010/354.pdf, specialized for E6 = E2[v]/(v^3-(u+1)). func (a *API) SquareExt(x Ext) Ext { - sum := a.e2Add(x.B0, x.B1) - d := a.e2Square(x.B0) - c := a.e2Square(x.B1) - sum = a.e2Square(sum) - bc := a.e2Add(d, c) - b1 := a.e2Sub(sum, bc) - cNR := a.e2MulByNonResidue(c) - b0 := a.e2Add(cNR, d) - return Ext{B0: b0, B1: b1} + // c4 = 2*B0*B1 + c4 := a.e2Mul(x.B0, x.B1) + c4 = a.e2MulConst(c4, big.NewInt(2)) + + // c5 = B2^2 + c5 := a.e2Square(x.B2) + + // c1 = c5*(u+1) + c4 + c1 := a.e2MulByCubicNonResidue(c5) + c1 = a.e2Add(c1, c4) + + // c2 = c4 - c5 + c2 := a.e2Sub(c4, c5) + + // c3 = B0^2 + c3 := a.e2Square(x.B0) + + // c4 = B0 - B1 + B2 + c4 = a.e2Sub(x.B0, x.B1) + c4 = a.e2Add(c4, x.B2) + + // c5 = 2*B1*B2 + c5 = a.e2Mul(x.B1, x.B2) + c5 = a.e2MulConst(c5, big.NewInt(2)) + + // c4 = c4^2 + c4 = a.e2Square(c4) + + // c0 = c5*(u+1) + c3 + c0 := a.e2MulByCubicNonResidue(c5) + c0 = a.e2Add(c0, c3) + + // z.B2 = c2 + c4 + c5 - c3 + b2 := a.e2Add(c2, c4) + b2 = a.e2Add(b2, c5) + b2 = a.e2Sub(b2, c3) + + return Ext{B0: c0, B1: c1, B2: b2} } // MulByE2Ext multiplies an Ext by an E2 element. 
@@ -272,6 +335,7 @@ func (a *API) MulByE2Ext(x Ext, c E2) Ext { return Ext{ B0: a.e2Mul(x.B0, c), B1: a.e2Mul(x.B1, c), + B2: a.e2Mul(x.B2, c), } } @@ -280,6 +344,7 @@ func (a *API) MulByFpExt(x Ext, c Element) Ext { return Ext{ B0: a.e2MulByFp(x.B0, c), B1: a.e2MulByFp(x.B1, c), + B2: a.e2MulByFp(x.B2, c), } } @@ -288,6 +353,7 @@ func (a *API) MulConstExt(x Ext, c *big.Int) Ext { return Ext{ B0: a.e2MulConst(x.B0, c), B1: a.e2MulConst(x.B1, c), + B2: a.e2MulConst(x.B2, c), } } @@ -298,14 +364,9 @@ func (a *API) ModReduceExt(x Ext) Ext { return x } return Ext{ - B0: E2{ - A0: a.ModReduce(x.B0.A0), - A1: a.ModReduce(x.B0.A1), - }, - B1: E2{ - A0: a.ModReduce(x.B1.A0), - A1: a.ModReduce(x.B1.A1), - }, + B0: E2{A0: a.ModReduce(x.B0.A0), A1: a.ModReduce(x.B0.A1)}, + B1: E2{A0: a.ModReduce(x.B1.A0), A1: a.ModReduce(x.B1.A1)}, + B2: E2{A0: a.ModReduce(x.B2.A0), A1: a.ModReduce(x.B2.A1)}, } } @@ -314,58 +375,47 @@ func (a *API) AddByBaseExt(x Ext, y Element) Ext { return Ext{ B0: E2{A0: a.Add(x.B0.A0, y), A1: x.B0.A1}, B1: x.B1, + B2: x.B2, } } // SumExt returns x + y + z... func (a *API) SumExt(xs ...Ext) Ext { + // One scratch slice reused across the six coordinate-wise reductions to + // avoid allocating 6× len(xs) Elements per call on the hot witness path. + coords := make([]Element, len(xs)) - res := Ext{} - - // summing the B0.A0 terms using gnark's optimized [Sum] function. - b0A0s := make([]Element, len(xs)) - for i := range xs { - b0A0s[i] = xs[i].B0.A0 - } - res.B0.A0 = a.Sum(b0A0s...) - - // summing the B0.A1 terms using gnark's optimized [Sum] function. - b0A1s := make([]Element, len(xs)) - for i := range xs { - b0A1s[i] = xs[i].B0.A1 - } - res.B0.A1 = a.Sum(b0A1s...) - - // summing the B0.A0 terms using gnark's optimized [Sum] function. - b1A0s := make([]Element, len(xs)) - for i := range xs { - b1A0s[i] = xs[i].B1.A0 + sumCoord := func(get func(x Ext) Element) Element { + for i := range xs { + coords[i] = get(xs[i]) + } + return a.Sum(coords...) 
} - res.B1.A0 = a.Sum(b1A0s...) - // summing the B1.A1 terms using gnark's optimized [Sum] function. - b1A1s := make([]Element, len(xs)) - for i := range xs { - b1A1s[i] = xs[i].B1.A1 + return Ext{ + B0: E2{ + A0: sumCoord(func(x Ext) Element { return x.B0.A0 }), + A1: sumCoord(func(x Ext) Element { return x.B0.A1 }), + }, + B1: E2{ + A0: sumCoord(func(x Ext) Element { return x.B1.A0 }), + A1: sumCoord(func(x Ext) Element { return x.B1.A1 }), + }, + B2: E2{ + A0: sumCoord(func(x Ext) Element { return x.B2.A0 }), + A1: sumCoord(func(x Ext) Element { return x.B2.A1 }), + }, } - res.B1.A1 = a.Sum(b1A1s...) - - return res } -// MulByNonResidueExt multiplies by the non-residue v (where v^2 = u). +// MulByNonResidueExt multiplies x by v, where v is the irreducible cubic root +// generator (v^3 = u+1). Equivalent to a single-coordinate cyclic shift with +// (u+1) wrap on the highest slot. func (a *API) MulByNonResidueExt(x Ext) Ext { return Ext{ - B0: a.e2MulByNonResidue(x.B1), + B0: a.e2MulByCubicNonResidue(x.B2), B1: x.B0, - } -} - -// ConjugateExt returns the conjugate of x. -func (a *API) ConjugateExt(x Ext) Ext { - return Ext{ - B0: x.B0, - B1: E2{A0: a.Neg(x.B1.A0), A1: a.Neg(x.B1.A1)}, + B2: x.B1, } } @@ -375,7 +425,8 @@ func (a *API) ConjugateExt(x Ext) Ext { func (a *API) IsZeroExt(x Ext) frontend.Variable { b0Zero := a.And(a.IsZero(x.B0.A0), a.IsZero(x.B0.A1)) b1Zero := a.And(a.IsZero(x.B1.A0), a.IsZero(x.B1.A1)) - return a.And(b0Zero, b1Zero) + b2Zero := a.And(a.IsZero(x.B2.A0), a.IsZero(x.B2.A1)) + return a.And(a.And(b0Zero, b1Zero), b2Zero) } // SelectExt returns x if sel=1, y otherwise. 
@@ -389,6 +440,10 @@ func (a *API) SelectExt(sel frontend.Variable, x, y Ext) Ext { A0: a.Select(sel, x.B1.A0, y.B1.A0), A1: a.Select(sel, x.B1.A1, y.B1.A1), }, + B2: E2{ + A0: a.Select(sel, x.B2.A0, y.B2.A0), + A1: a.Select(sel, x.B2.A1, y.B2.A1), + }, } } @@ -398,6 +453,8 @@ func (a *API) AssertIsEqualExt(x, y Ext) { a.AssertIsEqual(x.B0.A1, y.B0.A1) a.AssertIsEqual(x.B1.A0, y.B1.A0) a.AssertIsEqual(x.B1.A1, y.B1.A1) + a.AssertIsEqual(x.B2.A0, y.B2.A0) + a.AssertIsEqual(x.B2.A1, y.B2.A1) } // --- Ext Division and Inverse --- @@ -405,7 +462,8 @@ func (a *API) AssertIsEqualExt(x, y Ext) { // InverseExt returns 1/x in the extension field. func (a *API) InverseExt(x Ext) Ext { hint := a.inverseExtHint() - res, err := a.NewHint(hint, 4, x.B0.A0, x.B0.A1, x.B1.A0, x.B1.A1) + res, err := a.NewHint(hint, extDegree, + x.B0.A0, x.B0.A1, x.B1.A0, x.B1.A1, x.B2.A0, x.B2.A1) if err != nil { panic(err) } @@ -420,9 +478,9 @@ func (a *API) InverseExt(x Ext) Ext { // DivExt returns x / y in the extension field. func (a *API) DivExt(x, y Ext) Ext { hint := a.divExtHint() - res, err := a.NewHint(hint, 4, - x.B0.A0, x.B0.A1, x.B1.A0, x.B1.A1, - y.B0.A0, y.B0.A1, y.B1.A0, y.B1.A1) + res, err := a.NewHint(hint, extDegree, + x.B0.A0, x.B0.A1, x.B1.A0, x.B1.A1, x.B2.A0, x.B2.A1, + y.B0.A0, y.B0.A1, y.B1.A0, y.B1.A1, y.B2.A0, y.B2.A1) if err != nil { panic(err) } @@ -439,14 +497,18 @@ func (a *API) DivByBaseExt(x Ext, y Element) Ext { return Ext{ B0: E2{A0: a.Div(x.B0.A0, y), A1: a.Div(x.B0.A1, y)}, B1: E2{A0: a.Div(x.B1.A0, y), A1: a.Div(x.B1.A1, y)}, + B2: E2{A0: a.Div(x.B2.A0, y), A1: a.Div(x.B2.A1, y)}, } } -// extFromVars creates an Ext from 4 Vars. +const extDegree = field.ExtensionDegree + +// extFromVars creates an Ext from 6 Vars. 
func (a *API) extFromVars(v []Element) Ext { return Ext{ B0: E2{A0: v[0], A1: v[1]}, B1: E2{A0: v[2], A1: v[3]}, + B2: E2{A0: v[4], A1: v[5]}, } } @@ -506,7 +568,7 @@ func (a *API) ExpVariableExponentExt(x Ext, exp frontend.Variable, expNumBits in // PrintlnExt prints Ext variables for debugging. func (a *API) PrintlnExt(vars ...Ext) { for i := range vars { - a.Println(vars[i].B0.A0, vars[i].B0.A1, vars[i].B1.A0, vars[i].B1.A1) + a.Println(vars[i].B0.A0, vars[i].B0.A1, vars[i].B1.A0, vars[i].B1.A1, vars[i].B2.A0, vars[i].B2.A1) } } @@ -515,43 +577,49 @@ func (a *API) PrintlnExt(vars ...Ext) { // NewHintExt calls a hint function with Ext inputs and outputs. func (a *API) NewHintExt(f solver.Hint, nbOutputs int, inputs ...Ext) ([]Ext, error) { if a.IsNative() { - flatInputs := make([]frontend.Variable, 4*len(inputs)) + flatInputs := make([]frontend.Variable, extDegree*len(inputs)) for i, r := range inputs { - flatInputs[4*i] = r.B0.A0.Native() - flatInputs[4*i+1] = r.B0.A1.Native() - flatInputs[4*i+2] = r.B1.A0.Native() - flatInputs[4*i+3] = r.B1.A1.Native() + flatInputs[extDegree*i+0] = r.B0.A0.Native() + flatInputs[extDegree*i+1] = r.B0.A1.Native() + flatInputs[extDegree*i+2] = r.B1.A0.Native() + flatInputs[extDegree*i+3] = r.B1.A1.Native() + flatInputs[extDegree*i+4] = r.B2.A0.Native() + flatInputs[extDegree*i+5] = r.B2.A1.Native() } - flatRes, err := a.nativeAPI.NewHint(f, 4*nbOutputs, flatInputs...) + flatRes, err := a.nativeAPI.NewHint(f, extDegree*nbOutputs, flatInputs...) 
if err != nil { return nil, err } res := make([]Ext, nbOutputs) for i := range res { res[i] = Ext{ - B0: E2{A0: Element{V: flatRes[4*i]}, A1: Element{V: flatRes[4*i+1]}}, - B1: E2{A0: Element{V: flatRes[4*i+2]}, A1: Element{V: flatRes[4*i+3]}}, + B0: E2{A0: Element{V: flatRes[extDegree*i+0]}, A1: Element{V: flatRes[extDegree*i+1]}}, + B1: E2{A0: Element{V: flatRes[extDegree*i+2]}, A1: Element{V: flatRes[extDegree*i+3]}}, + B2: E2{A0: Element{V: flatRes[extDegree*i+4]}, A1: Element{V: flatRes[extDegree*i+5]}}, } } return res, nil } - flatInputs := make([]*emulated.Element[emulated.KoalaBear], 4*len(inputs)) + flatInputs := make([]*emulated.Element[emulated.KoalaBear], extDegree*len(inputs)) for i, r := range inputs { - flatInputs[4*i] = r.B0.A0.Emulated() - flatInputs[4*i+1] = r.B0.A1.Emulated() - flatInputs[4*i+2] = r.B1.A0.Emulated() - flatInputs[4*i+3] = r.B1.A1.Emulated() - } - flatRes, err := a.emulatedAPI.NewHint(f, 4*nbOutputs, flatInputs...) + flatInputs[extDegree*i+0] = r.B0.A0.Emulated() + flatInputs[extDegree*i+1] = r.B0.A1.Emulated() + flatInputs[extDegree*i+2] = r.B1.A0.Emulated() + flatInputs[extDegree*i+3] = r.B1.A1.Emulated() + flatInputs[extDegree*i+4] = r.B2.A0.Emulated() + flatInputs[extDegree*i+5] = r.B2.A1.Emulated() + } + flatRes, err := a.emulatedAPI.NewHint(f, extDegree*nbOutputs, flatInputs...) 
if err != nil { return nil, err } res := make([]Ext, nbOutputs) for i := range res { res[i] = Ext{ - B0: E2{A0: Element{EV: *flatRes[4*i]}, A1: Element{EV: *flatRes[4*i+1]}}, - B1: E2{A0: Element{EV: *flatRes[4*i+2]}, A1: Element{EV: *flatRes[4*i+3]}}, + B0: E2{A0: Element{EV: *flatRes[extDegree*i+0]}, A1: Element{EV: *flatRes[extDegree*i+1]}}, + B1: E2{A0: Element{EV: *flatRes[extDegree*i+2]}, A1: Element{EV: *flatRes[extDegree*i+3]}}, + B2: E2{A0: Element{EV: *flatRes[extDegree*i+4]}, A1: Element{EV: *flatRes[extDegree*i+5]}}, } } return res, nil @@ -569,17 +637,30 @@ func inverseE2Hint(_ *big.Int, inputs []*big.Int, res []*big.Int) error { return nil } +func extFromInputs(inputs []*big.Int) (e field.Ext) { + e.B0.A0.SetBigInt(inputs[0]) + e.B0.A1.SetBigInt(inputs[1]) + e.B1.A0.SetBigInt(inputs[2]) + e.B1.A1.SetBigInt(inputs[3]) + e.B2.A0.SetBigInt(inputs[4]) + e.B2.A1.SetBigInt(inputs[5]) + return e +} + +func extToOutputs(e field.Ext, out []*big.Int) { + e.B0.A0.BigInt(out[0]) + e.B0.A1.BigInt(out[1]) + e.B1.A0.BigInt(out[2]) + e.B1.A1.BigInt(out[3]) + e.B2.A0.BigInt(out[4]) + e.B2.A1.BigInt(out[5]) +} + func inverseExtHintNative(_ *big.Int, inputs []*big.Int, res []*big.Int) error { - var a, c field.Ext - a.B0.A0.SetBigInt(inputs[0]) - a.B0.A1.SetBigInt(inputs[1]) - a.B1.A0.SetBigInt(inputs[2]) - a.B1.A1.SetBigInt(inputs[3]) + a := extFromInputs(inputs) + var c field.Ext c.Inverse(&a) - c.B0.A0.BigInt(res[0]) - c.B0.A1.BigInt(res[1]) - c.B1.A0.BigInt(res[2]) - c.B1.A1.BigInt(res[3]) + extToOutputs(c, res) return nil } @@ -595,20 +676,11 @@ func (a *API) inverseExtHint() solver.Hint { } func divExtHintNative(_ *big.Int, inputs []*big.Int, res []*big.Int) error { - var x, y, c field.Ext - x.B0.A0.SetBigInt(inputs[0]) - x.B0.A1.SetBigInt(inputs[1]) - x.B1.A0.SetBigInt(inputs[2]) - x.B1.A1.SetBigInt(inputs[3]) - y.B0.A0.SetBigInt(inputs[4]) - y.B0.A1.SetBigInt(inputs[5]) - y.B1.A0.SetBigInt(inputs[6]) - y.B1.A1.SetBigInt(inputs[7]) + x := 
extFromInputs(inputs[:extDegree]) + y := extFromInputs(inputs[extDegree : 2*extDegree]) + var c field.Ext c.Div(&x, &y) - c.B0.A0.BigInt(res[0]) - c.B0.A1.BigInt(res[1]) - c.B1.A0.BigInt(res[2]) - c.B1.A1.BigInt(res[3]) + extToOutputs(c, res) return nil } @@ -624,20 +696,11 @@ func (a *API) divExtHint() solver.Hint { } func mulExtHintNative(_ *big.Int, inputs []*big.Int, res []*big.Int) error { - var x, y, c field.Ext - x.B0.A0.SetBigInt(inputs[0]) - x.B0.A1.SetBigInt(inputs[1]) - x.B1.A0.SetBigInt(inputs[2]) - x.B1.A1.SetBigInt(inputs[3]) - y.B0.A0.SetBigInt(inputs[4]) - y.B0.A1.SetBigInt(inputs[5]) - y.B1.A0.SetBigInt(inputs[6]) - y.B1.A1.SetBigInt(inputs[7]) + x := extFromInputs(inputs[:extDegree]) + y := extFromInputs(inputs[extDegree : 2*extDegree]) + var c field.Ext c.Mul(&x, &y) - c.B0.A0.BigInt(res[0]) - c.B0.A1.BigInt(res[1]) - c.B1.A0.BigInt(res[2]) - c.B1.A1.BigInt(res[3]) + extToOutputs(c, res) return nil } @@ -650,24 +713,21 @@ func (a *API) IsConstantZeroExt(e Ext) bool { return a.IsConstantZero(e.B0.A0) && a.IsConstantZero(e.B0.A1) && a.IsConstantZero(e.B1.A0) && - a.IsConstantZero(e.B1.A1) + a.IsConstantZero(e.B1.A1) && + a.IsConstantZero(e.B2.A0) && + a.IsConstantZero(e.B2.A1) } // BaseValueOfElement returns true if the Ext element actually represents a -// a base field element and returns it as an [Element]. Namely, the function -// checks if the non-constant terms of the extension element are zero constants -// and returns the constant term if so. +// base field element and returns it as an [Element]. The function checks that +// every non-constant coordinate is a zero constant. 
func (a *API) BaseValueOfElement(e Ext) (*Element, bool) { - - var ( - b1a0IsConst = a.IsConstantZero(e.B1.A0) - b0a1IsConst = a.IsConstantZero(e.B0.A1) - b1a1IsConst = a.IsConstantZero(e.B1.A1) - ) - - if !b1a0IsConst || !b0a1IsConst || !b1a1IsConst { + if !a.IsConstantZero(e.B0.A1) || + !a.IsConstantZero(e.B1.A0) || + !a.IsConstantZero(e.B1.A1) || + !a.IsConstantZero(e.B2.A0) || + !a.IsConstantZero(e.B2.A1) { return nil, false } - return &e.B0.A0, true } diff --git a/prover-ray/maths/koalabear/field/ext.go b/prover-ray/maths/koalabear/field/ext.go index d1b98cafd20..65cb38d9d52 100644 --- a/prover-ray/maths/koalabear/field/ext.go +++ b/prover-ray/maths/koalabear/field/ext.go @@ -1,7 +1,6 @@ package field import ( - "encoding/binary" "errors" "fmt" "math/big" @@ -10,31 +9,28 @@ import ( "reflect" "runtime" - "github.com/consensys/gnark-crypto/field/koalabear" "github.com/consensys/gnark-crypto/field/koalabear/extensions" "github.com/consensys/linea-monorepo/prover-ray/utils/parallel" ) // ExtensionDegree is the degree of the field extension over the base field. -const ExtensionDegree int = 4 +const ExtensionDegree int = 6 -// Ext is the degree-4 extension field element type alias. -type Ext = extensions.E4 +// Ext is the degree-6 extension field element type alias. +// +// Layout: ๐”ฝ_{p^6} = ๐”ฝ_{p^2}[v] / (v^3 โˆ’ (u+1)) where ๐”ฝ_{p^2} = ๐”ฝ_p[u] / (u^2 โˆ’ 3). +// An element is stored as (B0, B1, B2) of three E2 = (A0, A1) pairs, packing +// six base-field coordinates contiguously in memory. +type Ext = extensions.E6 -// NewExtFromString only sets the first coordinate of the field extension -func NewExtFromString(s string) (res Ext) { - res.B0.A0 = NewFromString(s) - return res -} - -// RootPowers stores the non-quadratic residue used to defined the extension. -// [1, 3] encoding the relations v^2=u and u^2=3. -var RootPowers = []int{1, 3} +// RootPowers stores the irreducible polynomial coefficients used to define the +// tower. 
The first triple encodes v^3 = u + 1, the trailing 3 encodes u^2 = 3. +var RootPowers = []int{1, 1, 0, 3} // BatchInvertExt compute the inverses of all elements in the provided slice // using the Montgommery trick. Zeroes are ignored. func BatchInvertExt(a []Ext) []Ext { - return extensions.BatchInvertE4(a) + return extensions.BatchInvertE6(a) } // BatchInvertExtInto computes the inverses of all elements in a and writes the result into res. @@ -71,20 +67,31 @@ func BatchInvertExtInto(a, res []Ext) { } } -// PseudoRandExt returns a random field extension element. +// NewExtFromString only sets the first coordinate of the field extension +func NewExtFromString(s string) (res Ext) { + res.B0.A0 = NewFromString(s) + return res +} + +// PseudoRandExt returns a random field extension element with all six +// coordinates drawn from the same RNG. func PseudoRandExt(rng *rand.Rand) Ext { - result := new(Ext).SetZero() - result.B0.A0 = PseudoRand(rng) - result.B0.A1 = PseudoRand(rng) - result.B1.A0 = PseudoRand(rng) - result.B1.A1 = PseudoRand(rng) - return *result + var res Ext + res.B0.A0 = PseudoRand(rng) + res.B0.A1 = PseudoRand(rng) + res.B1.A0 = PseudoRand(rng) + res.B1.A1 = PseudoRand(rng) + res.B2.A0 = PseudoRand(rng) + res.B2.A1 = PseudoRand(rng) + return res } -// IsBase checks if the field extensionElement is a base element. An Element is -// considered a base element if all coordinates are 0 except for the first one. +// IsBase checks if the field extensionElement is a base element. An element +// lies in the base field iff every non-constant coordinate is zero. 
func IsBase(z *Ext) bool { - return z.B0.A1[0] == 0 && z.B1.A0[0] == 0 && z.B1.A1[0] == 0 + return z.B0.A1[0] == 0 && + z.B1.A0[0] == 0 && z.B1.A1[0] == 0 && + z.B2.A0[0] == 0 && z.B2.A1[0] == 0 } // GetBase attempts to unlift a field extension element into a base field @@ -109,6 +116,8 @@ func MulByBase(z *Ext, first *Ext, second *Element) *Ext { z.B0.A1.Mul(&first.B0.A1, second) z.B1.A0.Mul(&first.B1.A0, second) z.B1.A1.Mul(&first.B1.A1, second) + z.B2.A0.Mul(&first.B2.A0, second) + z.B2.A1.Mul(&first.B2.A1, second) return z } @@ -118,6 +127,8 @@ func DivByBase(z *Ext, first *Ext, second *Element) *Ext { z.B0.A1.Div(&first.B0.A1, second) z.B1.A0.Div(&first.B1.A0, second) z.B1.A1.Div(&first.B1.A1, second) + z.B2.A0.Div(&first.B2.A0, second) + z.B2.A1.Div(&first.B2.A1, second) return z } @@ -169,15 +180,22 @@ func SetInterface(z *Ext, i1 interface{}) (*Ext, error) { } z.B0.A1.SetZero() z.B1.SetZero() + z.B2.SetZero() return z, nil case *big.Int: if c1 == nil { return nil, errors.New("can't set fr.Element with ") } z.B0.A0.SetBigInt(c1) + z.B0.A1.SetZero() + z.B1.SetZero() + z.B2.SetZero() return z, nil case big.Int: z.B0.A0.SetBigInt(&c1) + z.B0.A1.SetZero() + z.B1.SetZero() + z.B2.SetZero() return z, nil case []byte: z := BytesToExt(c1) @@ -196,9 +214,10 @@ func ExtToText(z *Ext, base int) string { return "" } - res := fmt.Sprintf("%s + %s*u + (%s + %s*u)*v", z.B0.A0.Text(base), z.B0.A1.Text(base), z.B1.A0.Text(base), - z.B1.A1.Text(base)) - return res + return fmt.Sprintf("%s + %s*u + (%s + %s*u)*v + (%s + %s*u)*v^2", + z.B0.A0.Text(base), z.B0.A1.Text(base), + z.B1.A0.Text(base), z.B1.A1.Text(base), + z.B2.A0.Text(base), z.B2.A1.Text(base)) } // ParBatchInvertExt computes inverses of all elements in a in parallel using numCPU goroutines. @@ -236,12 +255,14 @@ func ZeroExt() Ext { return res } -// ExtToUint64s returns the value of z as a tuple of 4 uint64. 
-func ExtToUint64s(z *Ext) (uint64, uint64, uint64, uint64) { +// ExtToUint64s returns the value of z as a tuple of 6 uint64. +func ExtToUint64s(z *Ext) (uint64, uint64, uint64, uint64, uint64, uint64) { return uint64(z.B0.A0.Bits()[0]), uint64(z.B0.A1.Bits()[0]), uint64(z.B1.A0.Bits()[0]), - uint64(z.B1.A1.Bits()[0]) + uint64(z.B1.A1.Bits()[0]), + uint64(z.B2.A0.Bits()[0]), + uint64(z.B2.A1.Bits()[0]) } // SetExtFromUInt sets z to v and returns z. After conversion, z is in the base @@ -249,14 +270,20 @@ func ExtToUint64s(z *Ext) (uint64, uint64, uint64, uint64) { func SetExtFromUInt(z *Ext, v uint64) *Ext { // sets z LSB to v (non-Montgomery form) and convert z to Montgomery form z.B0.A0.SetUint64(v) - return z // z.toMont() + z.B0.A1.SetZero() + z.B1.SetZero() + z.B2.SetZero() + return z } // SetExtFromInt sets z to v and returns z. After conversion, z is in the base // field and can safely be converted into an [Element]. func SetExtFromInt(z *Ext, v int64) *Ext { z.B0.A0.SetInt64(v) - return z // z.toMont() + z.B0.A1.SetZero() + z.B1.SetZero() + z.B2.SetZero() + return z } // SetExtFromBase sets z to v and returns z. After conversion, z is in the base @@ -266,6 +293,8 @@ func SetExtFromBase(z *Ext, x *Element) *Ext { z.B0.A1[0] = 0 z.B1.A0[0] = 0 z.B1.A1[0] = 0 + z.B2.A0[0] = 0 + z.B2.A1[0] = 0 return z } @@ -277,23 +306,27 @@ func Uint64ToExt(b uint64) Ext { return res } -// UintsToExt constructs a field extension from 4 uint64 -func UintsToExt(v1, v2, v3, v4 uint64) Ext { +// UintsToExt constructs a field extension from 6 uint64. +func UintsToExt(v1, v2, v3, v4, v5, v6 uint64) Ext { var z Ext z.B0.A0.SetUint64(v1) z.B0.A1.SetUint64(v2) z.B1.A0.SetUint64(v3) z.B1.A1.SetUint64(v4) + z.B2.A0.SetUint64(v5) + z.B2.A1.SetUint64(v6) return z } -// IntsToExt constructs a field extension from 4 int64 -func IntsToExt(v1, v2, v3, v4 int64) Ext { +// IntsToExt constructs a field extension from 6 int64. 
+func IntsToExt(v1, v2, v3, v4, v5, v6 int64) Ext { var z Ext z.B0.A0.SetInt64(v1) z.B0.A1.SetInt64(v2) z.B1.A0.SetInt64(v3) z.B1.A1.SetInt64(v4) + z.B2.A0.SetInt64(v5) + z.B2.A1.SetInt64(v6) return z } @@ -336,15 +369,16 @@ func ExpByIntExt(z *Ext, x Ext, k int) *Ext { return z } -// ExtToBytes returns the value of z as a big-endian byte array (16 bytes). -func ExtToBytes(z *Ext) (res [Bytes * 4]byte) { - var result [Bytes * 4]byte +// ExtToBytes returns the value of z as a big-endian byte array (one block per +// coordinate, totalling 6 * Bytes bytes). +func ExtToBytes(z *Ext) (res [Bytes * ExtensionDegree]byte) { + var result [Bytes * ExtensionDegree]byte valBytes := z.B0.A0.Bytes() - copy(result[0:Bytes], valBytes[:]) + copy(result[0*Bytes:1*Bytes], valBytes[:]) valBytes = z.B0.A1.Bytes() - copy(result[Bytes:2*Bytes], valBytes[:]) + copy(result[1*Bytes:2*Bytes], valBytes[:]) valBytes = z.B1.A0.Bytes() copy(result[2*Bytes:3*Bytes], valBytes[:]) @@ -352,15 +386,24 @@ func ExtToBytes(z *Ext) (res [Bytes * 4]byte) { valBytes = z.B1.A1.Bytes() copy(result[3*Bytes:4*Bytes], valBytes[:]) + valBytes = z.B2.A0.Bytes() + copy(result[4*Bytes:5*Bytes], valBytes[:]) + + valBytes = z.B2.A1.Bytes() + copy(result[5*Bytes:6*Bytes], valBytes[:]) + return result } -// BytesToExt constructs an extension field element from a 16-byte big-endian encoding. +// BytesToExt constructs an extension field element from a (6 * Bytes)-byte +// big-endian encoding produced by [ExtToBytes]. 
func BytesToExt(data []byte) Ext { var res Ext - res.B0.A0 = koalabear.Element{binary.BigEndian.Uint32(data[0:4])} - res.B0.A1 = koalabear.Element{binary.BigEndian.Uint32(data[4:8])} - res.B1.A0 = koalabear.Element{binary.BigEndian.Uint32(data[8:12])} - res.B1.A1 = koalabear.Element{binary.BigEndian.Uint32(data[12:16])} + res.B0.A0.SetBytes(data[0*Bytes : 1*Bytes]) + res.B0.A1.SetBytes(data[1*Bytes : 2*Bytes]) + res.B1.A0.SetBytes(data[2*Bytes : 3*Bytes]) + res.B1.A1.SetBytes(data[3*Bytes : 4*Bytes]) + res.B2.A0.SetBytes(data[4*Bytes : 5*Bytes]) + res.B2.A1.SetBytes(data[5*Bytes : 6*Bytes]) return res } diff --git a/prover-ray/maths/koalabear/field/ext_bench_test.go b/prover-ray/maths/koalabear/field/ext_bench_test.go new file mode 100644 index 00000000000..16e6cd44474 --- /dev/null +++ b/prover-ray/maths/koalabear/field/ext_bench_test.go @@ -0,0 +1,135 @@ +package field + +import ( + "math/rand/v2" + "testing" +) + +// Benchmarks intended to be diffed across the E4โ†’E6 migration. + +const benchN = 1 << 16 + +func benchSeededRng() *rand.Rand { + // #nosec G404 -- deterministic seed for reproducible benchmarks. + return rand.New(rand.NewPCG(1, 2)) +} + +func benchExts(n int) []Ext { + rng := benchSeededRng() + v := make([]Ext, n) + for i := range v { + v[i] = PseudoRandExt(rng) + } + return v +} + +func benchElems(n int) []Element { + rng := benchSeededRng() + v := make([]Element, n) + for i := range v { + v[i] = PseudoRand(rng) + } + return v +} + +// BenchmarkExtMul measures cost of a single full extension multiplication. +func BenchmarkExtMul(b *testing.B) { + xs := benchExts(2) + x, y := xs[0], xs[1] + var z Ext + b.ResetTimer() + for i := 0; i < b.N; i++ { + z.Mul(&x, &y) + } + _ = z +} + +// BenchmarkExtSquare measures cost of a single extension squaring. 
+func BenchmarkExtSquare(b *testing.B) { + xs := benchExts(1) + x := xs[0] + var z Ext + b.ResetTimer() + for i := 0; i < b.N; i++ { + z.Square(&x) + } + _ = z +} + +// BenchmarkExtMulByBase measures cost of multiplying an extension element by a +// base field scalar (the common "scale" path). +func BenchmarkExtMulByBase(b *testing.B) { + x := benchExts(1)[0] + s := benchElems(1)[0] + var z Ext + b.ResetTimer() + for i := 0; i < b.N; i++ { + z.MulByElement(&x, &s) + } + _ = z +} + +// BenchmarkExtInverse measures cost of a single extension inversion. +func BenchmarkExtInverse(b *testing.B) { + x := benchExts(1)[0] + var z Ext + b.ResetTimer() + for i := 0; i < b.N; i++ { + z.Inverse(&x) + } + _ = z +} + +// BenchmarkBatchInvertExt measures the cost of the Montgomery batch inversion +// over a fixed-size vector. +func BenchmarkBatchInvertExt(b *testing.B) { + a := benchExts(benchN) + b.ResetTimer() + for i := 0; i < b.N; i++ { + _ = BatchInvertExt(a) + } +} + +// BenchmarkParBatchInvertExt measures parallel batch inversion across all CPU +// cores; the only difference from the sequential one is the goroutine fan-out. +func BenchmarkParBatchInvertExt(b *testing.B) { + a := benchExts(benchN) + b.ResetTimer() + for i := 0; i < b.N; i++ { + _ = ParBatchInvertExt(a, 0) + } +} + +// BenchmarkVecAddExtExt measures element-wise extension addition over a vector. +func BenchmarkVecAddExtExt(b *testing.B) { + a := benchExts(benchN) + c := benchExts(benchN) + out := make([]Ext, benchN) + b.ResetTimer() + for i := 0; i < b.N; i++ { + VecAddExtExt(out, a, c) + } +} + +// BenchmarkVecMulExtExt measures element-wise extension multiplication. +func BenchmarkVecMulExtExt(b *testing.B) { + a := benchExts(benchN) + c := benchExts(benchN) + out := make([]Ext, benchN) + b.ResetTimer() + for i := 0; i < b.N; i++ { + VecMulExtExt(out, a, c) + } +} + +// BenchmarkVecScaleBaseExt measures the cost of scaling an extension vector +// by a single base scalar (uses MulByElement under the hood). 
+func BenchmarkVecScaleBaseExt(b *testing.B) { + a := benchExts(benchN) + s := benchElems(1)[0] + out := make([]Ext, benchN) + b.ResetTimer() + for i := 0; i < b.N; i++ { + VecScaleBaseExt(out, s, a) + } +} diff --git a/prover-ray/maths/koalabear/field/ext_test.go b/prover-ray/maths/koalabear/field/ext_test.go index c74439aada4..2ccc822046b 100644 --- a/prover-ray/maths/koalabear/field/ext_test.go +++ b/prover-ray/maths/koalabear/field/ext_test.go @@ -5,35 +5,45 @@ import ( "testing" ) -// TestNewExtFromUints verifies that NewExtFromUints correctly sets all four +// TestNewExtFromUints verifies that UintsToExt correctly sets all six // extension coordinates to the given canonical values. func TestNewExtFromUints(t *testing.T) { - e := UintsToExt(10, 20, 30, 40) - if got := e.B0.A0.Uint64(); got != 10 { - t.Errorf("B0.A0 = %d, want 10", got) + e := UintsToExt(10, 20, 30, 40, 50, 60) + checks := []struct { + name string + got uint64 + want uint64 + }{ + {"B0.A0", e.B0.A0.Uint64(), 10}, + {"B0.A1", e.B0.A1.Uint64(), 20}, + {"B1.A0", e.B1.A0.Uint64(), 30}, + {"B1.A1", e.B1.A1.Uint64(), 40}, + {"B2.A0", e.B2.A0.Uint64(), 50}, + {"B2.A1", e.B2.A1.Uint64(), 60}, } - if got := e.B0.A1.Uint64(); got != 20 { - t.Errorf("B0.A1 = %d, want 20", got) - } - if got := e.B1.A0.Uint64(); got != 30 { - t.Errorf("B1.A0 = %d, want 30", got) - } - if got := e.B1.A1.Uint64(); got != 40 { - t.Errorf("B1.A1 = %d, want 40", got) + for _, c := range checks { + if c.got != c.want { + t.Errorf("%s = %d, want %d", c.name, c.got, c.want) + } } } // TestNewExtFromInt verifies positive and negative int64 inputs. // Negative values are reduced mod p, consistent with Element.SetInt64. 
func TestNewExtFromInt(t *testing.T) { - e := IntsToExt(1, 2, 3, 4) - for i, got := range []uint64{e.B0.A0.Uint64(), e.B0.A1.Uint64(), e.B1.A0.Uint64(), e.B1.A1.Uint64()} { - if got != uint64(i+1) { - t.Errorf("coordinate %d = %d, want %d", i, got, i+1) + e := IntsToExt(1, 2, 3, 4, 5, 6) + got := []uint64{ + e.B0.A0.Uint64(), e.B0.A1.Uint64(), + e.B1.A0.Uint64(), e.B1.A1.Uint64(), + e.B2.A0.Uint64(), e.B2.A1.Uint64(), + } + for i, v := range got { + if v != uint64(i+1) { + t.Errorf("coordinate %d = %d, want %d", i, v, i+1) } } // Negative: SetInt64(-1) gives p-1. - e2 := IntsToExt(-1, 0, 0, 0) + e2 := IntsToExt(-1, 0, 0, 0, 0, 0) var want Element want.SetInt64(-1) checkElem(t, want, e2.B0.A0) @@ -45,7 +55,7 @@ func TestNewExtFromString(t *testing.T) { e := NewExtFromString("42") want := NewFromString("42") checkElem(t, want, e.B0.A0) - if !e.B0.A1.IsZero() || !e.B1.A0.IsZero() || !e.B1.A1.IsZero() { + if !extUpperIsZero(e) { t.Error("NewExtFromString: extension coordinates should be zero") } } @@ -117,7 +127,7 @@ func TestIsBaseFunc(t *testing.T) { for range testN { e := PseudoRandExt(rng) // Skip the rare case where all extension coordinates happen to be zero. 
- if e.B0.A1.IsZero() && e.B1.A0.IsZero() && e.B1.A1.IsZero() { + if extUpperIsZero(e) { continue } if IsBase(&e) { @@ -141,7 +151,7 @@ func TestGetBase(t *testing.T) { } for range testN { e := PseudoRandExt(rng) - if e.B0.A1.IsZero() && e.B1.A0.IsZero() && e.B1.A1.IsZero() { + if extUpperIsZero(e) { continue } if _, isBase := GetBase(&e); isBase { @@ -335,13 +345,43 @@ func TestSetInterface(t *testing.T) { } }) + t.Run("baseInputsClearUpperCoordinates", func(t *testing.T) { + cases := []struct { + name string + input any + want uint64 + }{ + {name: "string", input: "123", want: 123}, + {name: "*big.Int", input: big.NewInt(999), want: 999}, + {name: "big.Int", input: *big.NewInt(777), want: 777}, + } + + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + z := PseudoRandExt(rng) + if _, err := SetInterface(&z, tc.input); err != nil { + t.Fatalf("unexpected error: %v", err) + } + if z.B0.A0.Uint64() != tc.want { + t.Errorf("B0.A0 = %d, want %d", z.B0.A0.Uint64(), tc.want) + } + if !extUpperIsZero(z) { + t.Error("SetInterface base input should clear all extension coordinates") + } + }) + } + }) + t.Run("[]byte", func(t *testing.T) { // The []byte case in SetInterface returns a new *Ext built via // BytesToExt rather than modifying the receiver. This is an // inconsistency with every other type case (all of which modify the // receiver z), but it is the current behaviour, so the test checks // the returned pointer. - data := []byte{0, 0, 0, 5, 0, 0, 0, 6, 0, 0, 0, 7, 0, 0, 0, 8} + data := make([]byte, ExtensionDegree*Bytes) + for i := range data { + data[i] = byte((i*17 + 3) & 0xff) + } var z Ext result, err := SetInterface(&z, data) if err != nil { @@ -416,22 +456,20 @@ func TestZeroExtFunc(t *testing.T) { } } -// TestToUint64s verifies consistency of ToUint64s: zero produces (0,0,0,0) and +// TestToUint64s verifies consistency of ToUint64s: zero produces all-zeros and // two copies of the same Ext produce identical tuples. 
func TestToUint64s(t *testing.T) { z := ZeroExt() - u1, u2, u3, u4 := ExtToUint64s(&z) - if u1 != 0 || u2 != 0 || u3 != 0 || u4 != 0 { - t.Errorf("ToUint64s(ZeroExt) = (%d,%d,%d,%d), want (0,0,0,0)", u1, u2, u3, u4) + u0, u1, u2, u3, u4, u5 := ExtToUint64s(&z) + if u0 != 0 || u1 != 0 || u2 != 0 || u3 != 0 || u4 != 0 || u5 != 0 { + t.Errorf("ToUint64s(ZeroExt) = (%d,%d,%d,%d,%d,%d), want all zeros", u0, u1, u2, u3, u4, u5) } rng := newRng() for range testN { a := PseudoRandExt(rng) b := a // value copy - a1, a2, a3, a4 := ExtToUint64s(&a) - b1, b2, b3, b4 := ExtToUint64s(&b) - if a1 != b1 || a2 != b2 || a3 != b3 || a4 != b4 { + if extToUint64sTuple(&a) != extToUint64sTuple(&b) { t.Error("ToUint64s: equal Ext values should produce equal tuples") } } @@ -445,7 +483,7 @@ func TestSetExtFromUInt(t *testing.T) { if z.B0.A0.Uint64() != 42 { t.Errorf("B0.A0 = %d, want 42", z.B0.A0.Uint64()) } - if !z.B0.A1.IsZero() || !z.B1.A0.IsZero() || !z.B1.A1.IsZero() { + if !extUpperIsZero(z) { t.Error("SetExtFromUInt: extension coordinates should be zero") } } @@ -474,21 +512,21 @@ func TestSetExtFromBase(t *testing.T) { var z Ext SetExtFromBase(&z, &e) checkElem(t, e, z.B0.A0) - if !z.B0.A1.IsZero() || !z.B1.A0.IsZero() || !z.B1.A1.IsZero() { + if !extUpperIsZero(z) { t.Error("SetExtFromBase: extension coordinates should be zero") } } } -// TestNewExtFromUint verifies that NewExtFromUint sets B0.A0 to the given +// TestNewExtFromUint verifies that Uint64ToExt sets B0.A0 to the given // value and leaves the remaining coordinates at zero. func TestNewExtFromUint(t *testing.T) { e := Uint64ToExt(55) if e.B0.A0.Uint64() != 55 { t.Errorf("B0.A0 = %d, want 55", e.B0.A0.Uint64()) } - if !e.B0.A1.IsZero() || !e.B1.A0.IsZero() || !e.B1.A1.IsZero() { - t.Error("NewExtFromUint: extension coordinates should be zero") + if !extUpperIsZero(e) { + t.Error("Uint64ToExt: extension coordinates should be zero") } } @@ -546,13 +584,11 @@ func TestExpByIntExt(t *testing.T) { // 1. 
All-zero bytes produce the zero extension element. // 2. BytesToExt and SetInterface([]byte) agree on their output. // -// Note: BytesToExt stores the input uint32 values as raw Montgomery-form -// internal representations. ToUint64s uses Bits()[0] which converts *from* -// Montgomery form (multiplying by Rโปยน = MontConstantInv), so -// ToUint64s(BytesToExt(data)) โ‰  the raw numeric bytes โ€” this is expected. +// Note: BytesToExt expects the canonical coordinate encoding produced by +// Element.Bytes. func TestBytesToExt(t *testing.T) { // All-zero bytes produce the zero element (zero is zero in any representation). - zeroData := make([]byte, 16) + zeroData := make([]byte, ExtensionDegree*Bytes) if e := BytesToExt(zeroData); !e.IsZero() { t.Error("BytesToExt(zeros) should be the zero extension element") } @@ -563,19 +599,21 @@ func TestBytesToExt(t *testing.T) { src := PseudoRandExt(rng) b := ExtToBytes(&src) fromFunc := BytesToExt(b[:]) + if !extEq(src, fromFunc) { + t.Error("BytesToExt(ExtToBytes(src)) should round-trip to src") + } var dummy Ext result, err := SetInterface(&dummy, b[:]) if err != nil { t.Fatalf("SetInterface([]byte): %v", err) } - // SetInterface for []byte returns a new *Ext (see SetInterface/[]byte test). if !extEq(fromFunc, *result) { t.Error("BytesToExt and SetInterface([]byte) should produce equal results") } } } -// TestExtFromBytes verifies that ExtFromBytes extracts the canonical byte +// TestExtFromBytes verifies that ExtToBytes extracts the canonical byte // representation of each coordinate, consistent with Element.Bytes(). 
func TestExtFromBytes(t *testing.T) { rng := newRng() @@ -583,23 +621,13 @@ func TestExtFromBytes(t *testing.T) { e := PseudoRandExt(rng) b := ExtToBytes(&e) - b00 := e.B0.A0.Bytes() - b01 := e.B0.A1.Bytes() - b10 := e.B1.A0.Bytes() - b11 := e.B1.A1.Bytes() - - for i := 0; i < Bytes; i++ { - if b[i] != b00[i] { - t.Errorf("B0.A0 byte %d: got %d, want %d", i, b[i], b00[i]) - } - if b[Bytes+i] != b01[i] { - t.Errorf("B0.A1 byte %d: got %d, want %d", i, b[Bytes+i], b01[i]) - } - if b[2*Bytes+i] != b10[i] { - t.Errorf("B1.A0 byte %d: got %d, want %d", i, b[2*Bytes+i], b10[i]) - } - if b[3*Bytes+i] != b11[i] { - t.Errorf("B1.A1 byte %d: got %d, want %d", i, b[3*Bytes+i], b11[i]) + coords := []Element{e.B0.A0, e.B0.A1, e.B1.A0, e.B1.A1, e.B2.A0, e.B2.A1} + for k, c := range coords { + cb := c.Bytes() + for i := 0; i < Bytes; i++ { + if b[k*Bytes+i] != cb[i] { + t.Errorf("coord %d byte %d: got %d, want %d", k, i, b[k*Bytes+i], cb[i]) + } } } } @@ -610,3 +638,16 @@ func TestExtFromBytes(t *testing.T) { func TestRandomElementExt(_ *testing.T) { _ = RandomElementExt() } + +// extUpperIsZero reports whether every coordinate except B0.A0 is zero. +func extUpperIsZero(e Ext) bool { + return e.B0.A1.IsZero() && + e.B1.A0.IsZero() && e.B1.A1.IsZero() && + e.B2.A0.IsZero() && e.B2.A1.IsZero() +} + +// extToUint64sTuple wraps ExtToUint64s into a comparable struct for tests. +func extToUint64sTuple(e *Ext) [6]uint64 { + a, b, c, d, f, g := ExtToUint64s(e) + return [6]uint64{a, b, c, d, f, g} +} diff --git a/prover-ray/maths/koalabear/field/gen.go b/prover-ray/maths/koalabear/field/gen.go index 6b4e50b5dbb..fb8b7545d31 100644 --- a/prover-ray/maths/koalabear/field/gen.go +++ b/prover-ray/maths/koalabear/field/gen.go @@ -4,7 +4,7 @@ package field import "math/rand/v2" // Gen is a union type that holds either a base field element ([Element]) -// or a degree-4 extension field element ([Ext]). The embedded [Ext] is the +// or a degree-6 extension field element ([Ext]). 
The embedded [Ext] is the // canonical storage in both cases; [Gen.IsBase] tracks whether the value // was constructed from a base element and has remained in the base field // through subsequent operations. @@ -76,8 +76,8 @@ func (e Gen) Sub(b Gen) Gen { } // Mul returns e * b. When both operands are base, it uses the cheaper base -// field multiplication (1 mul vs ~9). When exactly one is base, it uses -// [Ext.MulByElement] (4 muls vs ~9). +// field multiplication (1 mul vs ~24). When exactly one is base, it uses +// [Ext.MulByElement] (6 muls vs ~24). // The result is tagged base iff both operands are base. func (e Gen) Mul(b Gen) Gen { if e.isBase && b.isBase { diff --git a/prover-ray/maths/koalabear/field/gen_test.go b/prover-ray/maths/koalabear/field/gen_test.go index f44094b12d7..31e0a4f045c 100644 --- a/prover-ray/maths/koalabear/field/gen_test.go +++ b/prover-ray/maths/koalabear/field/gen_test.go @@ -22,7 +22,9 @@ func extEq(a, b Ext) bool { return a.B0.A0.Equal(&b.B0.A0) && a.B0.A1.Equal(&b.B0.A1) && a.B1.A0.Equal(&b.B1.A0) && - a.B1.A1.Equal(&b.B1.A1) + a.B1.A1.Equal(&b.B1.A1) && + a.B2.A0.Equal(&b.B2.A0) && + a.B2.A1.Equal(&b.B2.A1) } // checkExt marks t as failed if want != got. diff --git a/prover-ray/maths/koalabear/field/vec.go b/prover-ray/maths/koalabear/field/vec.go index 50e3b9d91be..51b31c19d47 100644 --- a/prover-ray/maths/koalabear/field/vec.go +++ b/prover-ray/maths/koalabear/field/vec.go @@ -7,7 +7,7 @@ import ( ) // Vec holds a vector of field elements in either the base field ๐”ฝ_p or -// the degree-4 extension ๐”ฝ_{p^4}. Exactly one of base or ext is non-nil; +// the degree-6 extension ๐”ฝ_{p^6}. Exactly one of base or ext is non-nil; // this invariant is the caller's responsibility when constructing via // [VecFromBase] or [VecFromExt]. // @@ -85,7 +85,7 @@ func VecAddBaseBase(res, a, b []Element) { // VecAddExtBase sets res[i] = a[i] + b[i] where a is an extension vector and // b is a base vector. 
Only the first coordinate of each a[i] is updated; the -// remaining three are copied unchanged. Cost: 1 base addition per element. +// remaining five are copied unchanged. Cost: 1 base addition per element. // All slices must have equal length. func VecAddExtBase(res []Ext, a []Ext, b []Element) { mustEqualLen(len(res), len(a), len(b)) @@ -103,7 +103,7 @@ func VecAddBaseExt(res []Ext, a []Element, b []Ext) { } // VecAddExtExt sets res[i] = a[i] + b[i] over the extension field. -// Cost: 4 base additions per element. All slices must have equal length. +// Cost: 6 base additions per element. All slices must have equal length. func VecAddExtExt(res, a, b []Ext) { mustEqualLen(len(res), len(a), len(b)) for i := range res { @@ -156,20 +156,20 @@ func VecSubExtBase(res []Ext, a []Ext, b []Element) { // VecSubBaseExt sets res[i] = a[i] - b[i] where a is a base vector and b is // an extension vector. Note that subtraction is not commutative, so this // cannot simply delegate to VecSubExtBase. -// Cost: 4 base negations + 1 base addition per element. +// Cost: 6 base negations + 1 base addition per element. // All slices must have equal length. func VecSubBaseExt(res []Ext, a []Element, b []Ext) { mustEqualLen(len(res), len(a), len(b)) for i := range res { // Compute Lift(a[i]) - b[i] without allocating: - // negate all four components of b[i], then add a[i] to the first. + // negate all six components of b[i], then add a[i] to the first. res[i].Neg(&b[i]) res[i].B0.A0.Add(&res[i].B0.A0, &a[i]) } } // VecSubExtExt sets res[i] = a[i] - b[i] over the extension field. -// Cost: 4 base subtractions per element. All slices must have equal length. +// Cost: 6 base subtractions per element. All slices must have equal length. 
func VecSubExtExt(res, a, b []Ext) { mustEqualLen(len(res), len(a), len(b)) for i := range res { @@ -212,7 +212,7 @@ func VecMulBaseBase(res, a, b []Element) { // VecMulExtBase sets res[i] = a[i] * b[i] where a is an extension vector and // b is a base vector. Uses [Ext.MulByElement] which exploits the base-field // structure of b[i]. -// Cost: 4 base multiplications per element (vs ~9 for full extension mul). +// Cost: 6 base multiplications per element (vs ~24 for full extension mul). // All slices must have equal length. func VecMulExtBase(res []Ext, a []Ext, b []Element) { mustEqualLen(len(res), len(a), len(b)) @@ -224,14 +224,14 @@ func VecMulExtBase(res []Ext, a []Ext, b []Element) { // VecMulBaseExt sets res[i] = a[i] * b[i] where a is a base vector and b is // an extension vector. Multiplication is commutative, so this delegates to // [VecMulExtBase]. -// Cost: 4 base multiplications per element. +// Cost: 6 base multiplications per element. // All slices must have equal length. func VecMulBaseExt(res []Ext, a []Element, b []Ext) { VecMulExtBase(res, b, a) } // VecMulExtExt sets res[i] = a[i] * b[i] over the extension field. -// Cost: ~9 base multiplications per element (Karatsuba over E2). +// Cost: ~24 base multiplications per element (Karatsuba over E2 for E6). // All slices must have equal length. func VecMulExtExt(res, a, b []Ext) { mustEqualLen(len(res), len(a), len(b)) @@ -309,7 +309,7 @@ func VecScaleBaseBase(res []Element, s Element, a []Element) { // VecScaleBaseExt sets res[i] = s * a[i] where s is a base scalar and a is an // extension vector. Uses [Ext.MulByElement] to exploit the base structure of s. -// Cost: 4 base multiplications per element. +// Cost: 6 base multiplications per element. // res and a must have equal length. 
func VecScaleBaseExt(res []Ext, s Element, a []Ext) { mustEqualLen2(len(res), len(a)) @@ -321,7 +321,7 @@ func VecScaleBaseExt(res []Ext, s Element, a []Ext) { // VecScaleExtBase sets res[i] = s * a[i] where s is an extension scalar and a // is a base vector. Uses [Ext.MulByElement] to exploit the base structure of // each a[i]. -// Cost: 4 base multiplications per element. +// Cost: 6 base multiplications per element. // res and a must have equal length. func VecScaleExtBase(res []Ext, s Ext, a []Element) { mustEqualLen2(len(res), len(a)) @@ -331,7 +331,7 @@ func VecScaleExtBase(res []Ext, s Ext, a []Element) { } // VecScaleExtExt sets res[i] = s * a[i] where s and a are both extension-field. -// Cost: ~9 base multiplications per element. +// Cost: ~24 base multiplications per element. // res and a must have equal length. func VecScaleExtExt(res []Ext, s Ext, a []Ext) { mustEqualLen2(len(res), len(a)) diff --git a/prover-ray/maths/koalabear/polynomials/canonical.go b/prover-ray/maths/koalabear/polynomials/canonical.go index f3690069f8a..7025f6d658a 100644 --- a/prover-ray/maths/koalabear/polynomials/canonical.go +++ b/prover-ray/maths/koalabear/polynomials/canonical.go @@ -44,7 +44,7 @@ func LCEvalsToCoefficients(d *fft.Domain, v []field.Ext) []field.Ext { } coeffs := make([]field.Ext, n) copy(coeffs, v) - d.FFTInverseExt(coeffs, fft.DIF) + d.FFTInverseExt6(coeffs, fft.DIF) utils.BitReverse(coeffs) return coeffs } diff --git a/prover-ray/maths/koalabear/polynomials/canonical_test.go b/prover-ray/maths/koalabear/polynomials/canonical_test.go index b6177a70c07..1876bf2eda8 100644 --- a/prover-ray/maths/koalabear/polynomials/canonical_test.go +++ b/prover-ray/maths/koalabear/polynomials/canonical_test.go @@ -26,27 +26,16 @@ func randBase(_ *rand.Rand) field.Element { } func randExt(_ *rand.Rand) field.Ext { - var e field.Ext - if _, err := e.B0.A0.SetRandom(); err != nil { - panic(err) - } - if _, err := e.B0.A1.SetRandom(); err != nil { - panic(err) - } - if _, err 
:= e.B1.A0.SetRandom(); err != nil { - panic(err) - } - if _, err := e.B1.A1.SetRandom(); err != nil { - panic(err) - } - return e + return field.RandomElementExt() } func extEq(a, b field.Ext) bool { return a.B0.A0.Equal(&b.B0.A0) && a.B0.A1.Equal(&b.B0.A1) && a.B1.A0.Equal(&b.B1.A0) && - a.B1.A1.Equal(&b.B1.A1) + a.B1.A1.Equal(&b.B1.A1) && + a.B2.A0.Equal(&b.B2.A0) && + a.B2.A1.Equal(&b.B2.A1) } // hornerExt evaluates p(X) = ฮฃแตข p[i]ยทXโฑ at x using Horner's method directly diff --git a/prover-ray/maths/koalabear/polynomials/doc.go b/prover-ray/maths/koalabear/polynomials/doc.go index 484abdcc002..171645410be 100644 --- a/prover-ray/maths/koalabear/polynomials/doc.go +++ b/prover-ray/maths/koalabear/polynomials/doc.go @@ -1,5 +1,5 @@ // Package polynomials provides native polynomial evaluation utilities over the -// KoalaBear field and its degree-4 extension, using the union types +// KoalaBear field and its degree-6 extension, using the union types // [field.Vec] and [field.Gen] for type-aware dispatch. // // Two evaluation bases are supported: diff --git a/prover-ray/maths/koalabear/polynomials/lagrange_bench_test.go b/prover-ray/maths/koalabear/polynomials/lagrange_bench_test.go new file mode 100644 index 00000000000..ad0c4b0977e --- /dev/null +++ b/prover-ray/maths/koalabear/polynomials/lagrange_bench_test.go @@ -0,0 +1,36 @@ +package polynomials + +import ( + "testing" + + "github.com/consensys/linea-monorepo/prover-ray/maths/koalabear/field" +) + +const benchSize = 1 << 14 + +// BenchmarkEvalLagrangeExtExt evaluates a Lagrange-basis extension polynomial +// at an extension point โ€” this is the hottest path that depends on extension +// multiplication and batch inversion together. 
+func BenchmarkEvalLagrangeExtExt(b *testing.B) { + rng := newRng() + evals := make([]field.Ext, benchSize) + for i := range evals { + evals[i] = randExt(rng) + } + z := field.ElemFromExt(randExt(rng)) + poly := field.VecFromExt(evals) + b.ResetTimer() + for i := 0; i < b.N; i++ { + _ = EvalLagrange(poly, z) + } +} + +// BenchmarkComputeLagrangeAtZExt builds the entire Lแตข(z) vector for z โˆˆ F_{p^k}. +func BenchmarkComputeLagrangeAtZExt(b *testing.B) { + rng := newRng() + z := field.ElemFromExt(randExt(rng)) + b.ResetTimer() + for i := 0; i < b.N; i++ { + _ = ComputeLagrangeAtZ(z, uint64(benchSize)) + } +} diff --git a/prover-ray/maths/koalabear/polynomials/lagrange_test.go b/prover-ray/maths/koalabear/polynomials/lagrange_test.go index 0dcd6538386..e40982c510d 100644 --- a/prover-ray/maths/koalabear/polynomials/lagrange_test.go +++ b/prover-ray/maths/koalabear/polynomials/lagrange_test.go @@ -8,7 +8,7 @@ import ( "github.com/consensys/linea-monorepo/prover-ray/maths/koalabear/field" ) -// fftExtInplace applies the FFT to each of the 4 coordinates of an []Ext slice +// fftExtInplace applies the FFT to each of the 6 coordinates of an []Ext slice // individually, using the base-field FFT provided by *fft.Domain. // After the call, poly holds the Lagrange (evaluation) form of the input. 
func fftExtInplace(poly []field.Ext, d *fft.Domain) { @@ -30,6 +30,8 @@ func fftExtInplace(poly []field.Ext, d *fft.Domain) { copyCoord(func(i int) *field.Element { return &poly[i].B0.A1 }) copyCoord(func(i int) *field.Element { return &poly[i].B1.A0 }) copyCoord(func(i int) *field.Element { return &poly[i].B1.A1 }) + copyCoord(func(i int) *field.Element { return &poly[i].B2.A0 }) + copyCoord(func(i int) *field.Element { return &poly[i].B2.A1 }) } // fftBaseInplace applies the FFT to a []Element slice, converting from @@ -59,6 +61,8 @@ func ifftExtInplace(poly []field.Ext, d *fft.Domain) { copyCoord(func(i int) *field.Element { return &poly[i].B0.A1 }) copyCoord(func(i int) *field.Element { return &poly[i].B1.A0 }) copyCoord(func(i int) *field.Element { return &poly[i].B1.A1 }) + copyCoord(func(i int) *field.Element { return &poly[i].B2.A0 }) + copyCoord(func(i int) *field.Element { return &poly[i].B2.A1 }) } func TestEvalLagrange(t *testing.T) { diff --git a/prover-ray/transition_fp6_worklog.md b/prover-ray/transition_fp6_worklog.md new file mode 100644 index 00000000000..b2c5c5c2236 --- /dev/null +++ b/prover-ray/transition_fp6_worklog.md @@ -0,0 +1,287 @@ +# transition_fp6_worklog.md — KoalaBear E4 → E6 migration + +Worklog for the migration of the `prover-ray` package from the degree-4 +KoalaBear extension field (`extensions.E4`) to the degree-6 extension +(`extensions.E6`) shipped in `gnark-crypto@v0.20.2-0.20260514182922-df0578435b08`. + +The target was every reference to the degree-4 extension under `prover-ray/` +and all of its sub-packages (`maths`, `crypto`, `wiop`). This is not a +source- or binary-compatible change: constructors, public byte encodings, and +some circuit helper names now expose six extension coordinates. + +--- + +## 1. Tower & layout + +Before (E4): + +- 𝔽_{p^4} = 𝔽_{p^2}[v] / (v² − u) +- 𝔽_{p^2} = 𝔽_p[u] / (u² − 3) +- Storage layout: `{B0, B1 E2}` where `E2 = {A0, A1 fr.Element}` → 4 base coords.
+ +After (E6): + +- 𝔽_{p^6} = 𝔽_{p^2}[v] / (v³ − (u+1)) *(cubic non-residue: `u+1`)* +- 𝔽_{p^2} = 𝔽_p[u] / (u² − 3) *(quadratic non-residue: 3)* +- Storage layout: `{B0, B1, B2 E2}` → 6 base coords; gnark-crypto names the + type `extensions.E6`. + +The `field.ExtensionDegree` constant moved from `4` → `6`. `field.RootPowers` +was repurposed to encode the new tower: `[1, 1, 0, 3]` packs the cubic +relation `v³ = u + 1` (first three entries: u + 1) followed by `u² = 3`. + +--- + +## 2. File-by-file summary + +### `maths/koalabear/field/ext.go` + +Full rewrite for E6. + +- `Ext = extensions.E6` +- `ExtensionDegree = 6` +- All coordinate-reading helpers (`IsBase`, `GetBase`, `MulByBase`, + `DivByBase`, `SetExtFromBase`, `PseudoRandExt`, `ExtToUint64s`, + `UintsToExt`, `IntsToExt`) updated to scan/write six coordinates. +- `BatchInvertExt` delegates to `extensions.BatchInvertE6`. +- `BatchInvertExtInto`, `ParBatchInvertExt` unchanged in shape — they + multiplex over the new `Ext` alias. +- `ExtToBytes` / `BytesToExt` now operate on `ExtensionDegree * Bytes` + bytes (24) per element. +- `ExtToText` formats the six coordinates. + +### `maths/koalabear/field/{vec.go, gen.go}` + +Comment-level updates: the cost annotations went from "4 muls / ~9 muls" to +"6 muls / ~24 muls" to reflect E6 arithmetic. No behavioral change. + +### `maths/koalabear/field/{ext_test.go, gen_test.go, vec_test.go}` + +- Test helpers updated to cover six coordinates (`extEq`, `extUpperIsZero`). +- `UintsToExt(10,20,30,40)` → `UintsToExt(10,20,30,40,50,60)`, etc. +- `TestExtFromBytes`: switched to a coordinate-table loop instead of four + hardcoded `B0.A0` / `B0.A1` / `B1.A0` / `B1.A1` blocks (now six). + +### `maths/koalabear/field/ext_bench_test.go` *(new)* + +Micro-benchmarks for diffing E4 vs E6. Covers `Mul`, `Square`, `MulByElement`, +`Inverse`, `BatchInvert`, `ParBatchInvert`, vector add/mul/scale.
+ +### `maths/koalabear/circuit/ext.go` + +Complete rewrite of the gnark circuit layer. + +- `Ext = { B0, B1, B2 E2 }` where each `E2 = {A0, A1 Element}`. +- New helper `e2MulByCubicNonResidue` for multiplying an E2 by `(u + 1)`. +- `MulExt` reimplemented via Algorithm 13 of Beuchat et al. (https://eprint.iacr.org/2010/354): + 6 E2 multiplications (≈24 base muls) instead of 3 (≈9 base muls). +- `SquareExt` reimplemented via Algorithm 16 of the same paper. +- `MulByNonResidueExt` updated: multiplying by `v` shifts coordinates and + wraps the last slot through `e2MulByCubicNonResidue`. +- `InverseExt` / `DivExt` hint signatures expanded from 4 to 6 inputs/outputs + per element. The duplicated hint coordinate marshalling was factored into + `extFromInputs` / `extToOutputs`. +- `NewExtFrom4FrontendVars` renamed to `NewExtFrom6FrontendVars`. +- `BaseValueOfElement`, `IsConstantZeroExt`, `IsZeroExt`, `SelectExt`, + `AssertIsEqualExt`, `SumExt` extended to all six coordinates. +- `ConjugateExt` removed (was unused; the natural cubic-extension Frobenius + is non-trivial and would be misleading to expose without callers). + +### `maths/koalabear/polynomials/{canonical_test.go, lagrange_test.go}` + +- `randExt` delegates to `field.RandomElementExt()` (which already covers + six coordinates). +- `extEq` and the local FFT coordinate-extract loop in + `lagrange_test.go` updated for six coordinates. + +### `maths/koalabear/polynomials/lagrange_bench_test.go` *(new)* + +Benchmarks for `EvalLagrange` and `ComputeLagrangeAtZ` on extension inputs. + +### `crypto/koalabear/fiatshamir/poseidon2.go` + +- `UpdateExt` constant `4` replaced by `field.ExtensionDegree`. +- `RandomFext` now consumes 6 of the 8 hashed Poseidon outputs (positions + 4–5 added) instead of 4. The trailing two octuplet slots are discarded. + +### `crypto/koalabear/reedsolomon/reedsolomon.go` + +- `FFTExt` → `FFTExt6` and `FFTInverseExt` → `FFTInverseExt6` everywhere.
+- The vector cast in the fast path is now `extensions.VectorE6` instead of + `extensions.Vector` (which still aliases `[]E4` in upstream). + +### `crypto/koalabear/ringsis/ringsis.go` + +- `FFTExt` → `FFTExt6`, `FFTInverseExt` → `FFTInverseExt6`. + +### `crypto/koalabear/vortex/utils.go` + +- The hint function now expects chunks of `field.ExtensionDegree (= 6)` big + ints instead of 4. Renamed `ErrSizeNotAMultipleOfFour` → + `ErrSizeNotAMultipleOfExtDegree` and updated the error message. +- `FFTInverseExt` → `FFTInverseExt6`. + +### `crypto/koalabear/vortex/verifier_common.go` + +- The verifier now uses the local `polynomials.EvalCanonical` helper, whose + `field.Vec` / `field.Gen` dispatch supports both base-field and E6 + coefficients. +- The opening-check comparison switched from `!=` (compile-error on E6) to + `!y.Equal(&other)`. + +### `wiop/compilers/global/global.go` + +- Replaced the coordinate-by-coordinate `applyBaseFFT4` helper (4 disjoint + base FFTs) by a direct `largeDomain.FFTInverseExt6(..., DIF, OnCoset())` + / `domain.FFTExt6(..., DIT)` call. This: + - eliminates 4 coordinate-slice allocations per proof (the + `scratchC0..C3` fields and their `AllocField` calls are gone); + - traverses the E6 array once per FFT pass instead of six times for the + six coordinates, which is more cache-friendly on HPC nodes. +- `proverBucket` lost the `scratchC0..C3` fields; `Plan` no longer + allocates them. + +### `wiop` query tests and helpers + +Files such as `query_lagrange_eval.go`, `query_lagrange_eval_test.go`, and +`query_vanishing_test.go` did not need structural changes. + +The reason: the test helpers and production code only +access `B0.A0` (the lifted base slot). Those references work unchanged on +E6 because the layout retains the same first coordinate.
+ +### `go.mod` + +Dependency was already updated to +`github.com/consensys/gnark-crypto v0.20.2-0.20260514182922-df0578435b08` +which exposes `extensions.E6`, `BatchInvertE6`, `VectorE6`, `FFTExt6`, +`FFTInverseExt6`. + +--- + +## 3. Performance impact + +Microbenchmarks (Apple M5 Max, `-count=3`, `benchstat`). +Baseline = E4 main, after = E6 migration. Times in ns / µs / ms as appropriate. + +``` + │ baseline (E4) │ after (E6) │ delta +ExtMul-18 │ 5.34 ns │ 22.0 ns │ +311 % +ExtSquare-18 │ 4.77 ns │ 15.7 ns │ +229 % +ExtMulByBase-18 │ 2.04 ns │ 2.89 ns │ +42 % +ExtInverse-18 │ 57.7 ns │ 83.6 ns │ +45 % +BatchInvertExt n=65536 │ 1.81 ms │ 5.70 ms │ +215 % +ParBatchInvertExt n=65536 │ 436 µs │ 891 µs │ +104 % +VecAddExtExt n=65536 │ 109 µs │ 115 µs │ +5 % +VecMulExtExt n=65536 │ 393 µs │ 1820 µs │ +363 % +VecScaleBaseExt n=65536 │ 142 µs │ 212 µs │ +49 % +EvalLagrangeExtExt n=16k │ 669 µs │ 2047 µs │ +206 % +ComputeLagrangeAtZExt 16k │ 4023 µs │ 5100 µs │ +27 % +geomean (sec/op) │ 9.18 µs │ 19.9 µs │ +117 % +``` + +### Reading these numbers + +- **Single `Mul` slowed ≈4×.** The E4 `Mul` in gnark-crypto is a custom + `uint64`-accumulating kernel using 4 Montgomery reductions; the E6 `Mul` + is the textbook Karatsuba-over-E2 (6 E2 muls = ≈24 base muls) without an + inlined accumulator. The reduction-per-coordinate cost is real and + unavoidable in the chosen tower. +- **`VecMulExtExt` slowed ≈4.6×.** The E4 vector kernel has an AVX-512 path + (`vectorMul_avx512`); E6 has no AVX path in gnark-crypto yet. On Apple + silicon that asm is not used anyway (ARM), but on Linux/x86 HPC the gap + will be even larger until an E6 SIMD kernel ships. Future work for + upstream gnark-crypto. +- **`VecAddExtExt` is almost a wash** (+5%).
Addition is bandwidth-bound; + growing the working set from 16 bytes/elem to 24 bytes/elem barely + registers because the loop body is already cheap. +- **`ExtMulByBase`** only sees +42% because it's 6 base muls instead of 4. +- **Inversion path** is dominated by `Inverse(E2)` (one) + several E2 muls + (six for the cofactor expansion); +45% closely matches the proportional + arithmetic cost change. +- **Allocation footprint grew ≈50%** as expected (24 B/element vs 16 B + before). + +### Implications for HPC throughput + +The hottest paths in the prover are: + +1. **Vortex `TransversalHash`** — pure base-field SIS; unaffected. +2. **Quotient computation** in `wiop/compilers/global` — replaces 4 base FFTs + with 1 E6 FFT (`FFTExt6`). Wall-time should be neutral-to-better thanks + to cache locality, even though the arithmetic is +50% per coord (six + coordinates × E6 butterflies, vs four base butterflies). +3. **Reed-Solomon encoding on extension columns** — same `FFTExt6` switch. +4. **Lagrange evaluation** — the new degree dominates: ≈3× slower per call. + These calls are not in the innermost prover loop, so end-to-end impact + should remain modest, but worth re-benchmarking on the full prover when + ready. + +--- + +## 4. Behavioural compatibility + +- `field.Ext` is a struct value, so direct pointer comparison `!=` (which + existed in `verifier_common.go`) had to be changed to `!Equal(&...)`. + All other call sites only read/write `B0.A0` and similar named accessors, + which carry over cleanly. +- `ConjugateExt` was removed from the circuit API — the natural Frobenius + conjugation in a cubic tower is not the same shape as the previous + "negate B1" definition. No callers were using it. +- Public byte-encoding sizes changed: + - `len(ExtToBytes(z))` = 24 bytes (was 16). + - `BytesToExt` consumes 24 bytes.
+ This means any persisted artifact serialized with the previous E4 layout + is **not** loadable as E6 — re-encode any saved transcripts or proofs. + +--- + +## 5. Cleanup / simplify pass + +Ran code-reuse, code-quality and efficiency audits on the diff. Findings and +resolutions: + +- **Test-only helper exported**: `ExtToUint64sTuple` in `ext_test.go` was + capitalized despite being test-internal. Renamed to `extToUint64sTuple`. +- **Allocation thrashing in `SumExt`**: the closure-based getter was + allocating six `[]Element` slices per call. Hoisted the scratch slice to + a single allocation reused across the six coordinate reductions. +- **Reuse suggestion: `EvalCanonical` for Horner kernels in vortex**: + after the merge from `main`, `EvalCanonical` is `field.Gen`-aware, so the + verifier uses it directly instead of local Horner kernels. +- **`NewHintExt` native vs emulated DRY-up**: kept the two branches + separate. The element constructors and input/output element types + differ; unifying via reflection or type parameters would obscure rather + than help. + +After the simplify pass: `go vet ./...` clean, `go build ./...` clean, +`go test ./...` passes. + +--- + +## 6. Verification matrix + +``` +go build ./... # OK +go vet ./... # OK +go test ./... -count=1 -timeout=600s + maths/koalabear/{circuit,field,polynomials} OK + crypto/koalabear/{poseidon2,reedsolomon,ringsis,smt,vortex} OK + wiop, wiop/compilers/{global,logderivativesum,lookuptologderivsum,rangecheck}, + wiop/codegen, zkcdriver OK +go test -bench=Benchmark... -count=3 -run=^$ OK +``` + +--- + +## 7. Open follow-ups (not in scope) + +- Upstream gnark-crypto SIMD kernels for E6 vector operations (`Mul`, + `ScalarMul`, `MulByElement`, `Butterfly`). Today, E4 has AVX-512 fast + paths and E6 does not; the gap is most visible on x86 Linux HPC. When + these land, regen baselines and revisit.
+- Profile end-to-end prover throughput on a representative trace to confirm + the FFT cache-locality win (moving from 4 base FFTs to 1 E6 FFT) holds + outside the microbench harness. +- Decide whether `field.RootPowers` is still useful; nothing in the repo + consumes it as of this migration. diff --git a/prover-ray/wiop/compilers/global/global.go b/prover-ray/wiop/compilers/global/global.go index 0c54515f1f1..2a3a0a6216a 100644 --- a/prover-ray/wiop/compilers/global/global.go +++ b/prover-ray/wiop/compilers/global/global.go @@ -97,11 +97,7 @@ type proverBucket struct { // Pre-allocated scratch slices populated by Plan; nil until Plan is called. // When non-nil, Run uses these instead of allocating fresh memory. - scratchAgg []field.Ext // aggregate[j], length N = n*ratio - scratchC0 []field.Element // coordinate slice 0 for applyBaseFFT4, length N - scratchC1 []field.Element // coordinate slice 1 for applyBaseFFT4, length N - scratchC2 []field.Element // coordinate slice 2 for applyBaseFFT4, length N - scratchC3 []field.Element // coordinate slice 3 for applyBaseFFT4, length N + scratchAgg []field.Ext // aggregate[j], length N = n*ratio } // verifierBucket holds everything the verifier needs for one ratio bucket. 
@@ -375,10 +371,6 @@ func (a *QuotientProverAction) Plan(ctx *wiop.PlanningContext) { bkt := &a.buckets[i] N := n * bkt.ratio bkt.scratchAgg = ctx.AllocExt(N) - bkt.scratchC0 = ctx.AllocField(N) - bkt.scratchC1 = ctx.AllocField(N) - bkt.scratchC2 = ctx.AllocField(N) - bkt.scratchC3 = ctx.AllocField(N) } } @@ -389,7 +381,11 @@ func (a *QuotientProverAction) Run(rt wiop.Runtime) { n := a.m.RuntimeSize(rt) if !a.m.IsDynamic() && n != a.m.Size() { - panic(fmt.Sprintf("wiop/compilers: global quotient prover action called with runtime size %d but module size is %d", n, a.m.Size())) + panic(fmt.Sprintf( + "wiop/compilers: global quotient prover action called with runtime size %d but module size is %d", + n, + a.m.Size(), + )) } coinExt := rt.GetCoinValue(a.mergeCoin).Ext @@ -485,11 +481,8 @@ func (a *QuotientProverAction) Run(rt wiop.Runtime) { } // --- IFFT on the large coset: coset evals → canonical coefficients --- - // Operates component-wise on the 4 base-field components of Ext. - // Use pre-allocated coordinate scratch buffers when available. - applyBaseFFT4(largeDomain, aggregate[:N], func(d *fft.Domain, c []field.Element) { - d.FFTInverse(c, fft.DIF, fft.OnCoset()) - }, bkt.scratchC0, bkt.scratchC1, bkt.scratchC2, bkt.scratchC3) + // FFTInverseExt6 operates directly on the contiguous E6 layout.
+ largeDomain.FFTInverseExt6(aggregate[:N], fft.DIF, fft.OnCoset()) // --- Split into ratio chunks and FFT each to standard Lagrange form --- for k := range ratio { @@ -571,7 +564,11 @@ func (gv *Verifier) Check(rt wiop.Runtime) error { n := gv.m.RuntimeSize(rt) if !gv.m.IsDynamic() && n != gv.m.Size() { - panic(fmt.Sprintf("wiop/compilers: global quotient Check called with runtime size %d but module size is %d", n, gv.m.Size())) + panic(fmt.Sprintf( + "wiop/compilers: global quotient Check called with runtime size %d but module size is %d", + n, + gv.m.Size(), + )) } r := rt.GetCoinValue(gv.evalCoin) coinExt := rt.GetCoinValue(gv.mergeCoin).Ext @@ -860,47 +857,9 @@ func computeRatio(v *wiop.Vanishing) int { // Extension-field FFT helpers // --------------------------------------------------------------------------- -// extFFT applies the forward standard-domain FFT to the extension-field slice v, -// operating component-wise. After the call, v contains standard Lagrange -// evaluations. +// extFFT applies the forward standard-domain FFT to the extension-field slice +// v. The gnark-crypto FFTExt6 implementation handles the six E6 coordinates +// directly on the contiguous layout. func extFFT(d *fft.Domain, v []field.Ext) { - applyBaseFFT4(d, v, func(d *fft.Domain, c []field.Element) { - d.FFT(c, fft.DIT) - }, nil, nil, nil, nil) -} - -// applyBaseFFT4 deinterleaves v into four base-field coordinate slices, applies -// fn to each, then reassembles the result back into v. If any of c0..c3 is -// non-nil and long enough, it is used as scratch instead of allocating. 
-func applyBaseFFT4(d *fft.Domain, v []field.Ext, fn func(*fft.Domain, []field.Element), - c0, c1, c2, c3 []field.Element) { - n := len(v) - if len(c0) < n { - c0 = make([]field.Element, n) - } - if len(c1) < n { - c1 = make([]field.Element, n) - } - if len(c2) < n { - c2 = make([]field.Element, n) - } - if len(c3) < n { - c3 = make([]field.Element, n) - } - for i, e := range v { - c0[i] = e.B0.A0 - c1[i] = e.B0.A1 - c2[i] = e.B1.A0 - c3[i] = e.B1.A1 - } - fn(d, c0) - fn(d, c1) - fn(d, c2) - fn(d, c3) - for i := range v { - v[i].B0.A0 = c0[i] - v[i].B0.A1 = c1[i] - v[i].B1.A0 = c2[i] - v[i].B1.A1 = c3[i] - } + d.FFTExt6(v, fft.DIT) }