1
2
3
4
5 package rsa
6
7
8
9 import (
10 "bytes"
11 "crypto"
12 "errors"
13 "hash"
14 "io"
15 "math/big"
16 )
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31 func emsaPSSEncode(mHash []byte, emBits int, salt []byte, hash hash.Hash) ([]byte, error) {
32
33
34 hLen := hash.Size()
35 sLen := len(salt)
36 emLen := (emBits + 7) / 8
37
38
39
40
41
42
43
44 if len(mHash) != hLen {
45 return nil, errors.New("crypto/rsa: input must be hashed with given hash")
46 }
47
48
49
50 if emLen < hLen+sLen+2 {
51 return nil, errors.New("crypto/rsa: key size too small for PSS signature")
52 }
53
54 em := make([]byte, emLen)
55 psLen := emLen - sLen - hLen - 2
56 db := em[:psLen+1+sLen]
57 h := em[psLen+1+sLen : emLen-1]
58
59
60
61
62
63
64
65
66
67
68
69
70 var prefix [8]byte
71
72 hash.Write(prefix[:])
73 hash.Write(mHash)
74 hash.Write(salt)
75
76 h = hash.Sum(h[:0])
77 hash.Reset()
78
79
80
81
82
83
84
85 db[psLen] = 0x01
86 copy(db[psLen+1:], salt)
87
88
89
90
91
92 mgf1XOR(db, hash, h)
93
94
95
96
97 db[0] &= 0xff >> (8*emLen - emBits)
98
99
100 em[emLen-1] = 0xbc
101
102
103 return em, nil
104 }
105
106 func emsaPSSVerify(mHash, em []byte, emBits, sLen int, hash hash.Hash) error {
107
108
109 hLen := hash.Size()
110 if sLen == PSSSaltLengthEqualsHash {
111 sLen = hLen
112 }
113 emLen := (emBits + 7) / 8
114 if emLen != len(em) {
115 return errors.New("rsa: internal error: inconsistent length")
116 }
117
118
119
120
121
122
123 if hLen != len(mHash) {
124 return ErrVerification
125 }
126
127
128 if emLen < hLen+sLen+2 {
129 return ErrVerification
130 }
131
132
133
134 if em[emLen-1] != 0xbc {
135 return ErrVerification
136 }
137
138
139
140 db := em[:emLen-hLen-1]
141 h := em[emLen-hLen-1 : emLen-1]
142
143
144
145
146 var bitMask byte = 0xff >> (8*emLen - emBits)
147 if em[0] & ^bitMask != 0 {
148 return ErrVerification
149 }
150
151
152
153
154 mgf1XOR(db, hash, h)
155
156
157
158 db[0] &= bitMask
159
160
161 if sLen == PSSSaltLengthAuto {
162 psLen := bytes.IndexByte(db, 0x01)
163 if psLen < 0 {
164 return ErrVerification
165 }
166 sLen = len(db) - psLen - 1
167 }
168
169
170
171
172
173 psLen := emLen - hLen - sLen - 2
174 for _, e := range db[:psLen] {
175 if e != 0x00 {
176 return ErrVerification
177 }
178 }
179 if db[psLen] != 0x01 {
180 return ErrVerification
181 }
182
183
184 salt := db[len(db)-sLen:]
185
186
187
188
189
190
191
192 var prefix [8]byte
193 hash.Write(prefix[:])
194 hash.Write(mHash)
195 hash.Write(salt)
196
197 h0 := hash.Sum(nil)
198
199
200 if !bytes.Equal(h0, h) {
201 return ErrVerification
202 }
203 return nil
204 }
205
206
207
208
209
210 func signPSSWithSalt(rand io.Reader, priv *PrivateKey, hash crypto.Hash, hashed, salt []byte) ([]byte, error) {
211 emBits := priv.N.BitLen() - 1
212 em, err := emsaPSSEncode(hashed, emBits, salt, hash.New())
213 if err != nil {
214 return nil, err
215 }
216 m := new(big.Int).SetBytes(em)
217 c, err := decryptAndCheck(rand, priv, m)
218 if err != nil {
219 return nil, err
220 }
221 s := make([]byte, priv.Size())
222 return c.FillBytes(s), nil
223 }
224
// Special values accepted by PSSOptions.SaltLength.
const (
	// PSSSaltLengthAuto asks for automatic salt sizing. SignPSS uses the
	// largest salt that fits the key. VerifyPSS infers the salt length from
	// the encoded message.
	PSSSaltLengthAuto = 0

	// PSSSaltLengthEqualsHash asks for a salt exactly as long as the hash
	// output.
	PSSSaltLengthEqualsHash = -1
)

// PSSOptions holds the parameters for creating and verifying PSS signatures.
// SignPSS and VerifyPSS also accept a nil *PSSOptions, which behaves like the
// zero value.
type PSSOptions struct {
	// SaltLength is the salt size in bytes, or one of the special
	// PSSSaltLengthAuto / PSSSaltLengthEqualsHash values.
	SaltLength int

	// Hash, when non-zero, overrides the hash argument given to SignPSS.
	// VerifyPSS does not read it.
	Hash crypto.Hash
}

// HashFunc returns opts.Hash. This lets a *PSSOptions be used as a
// crypto.SignerOpts. Unlike saltLength, it must not be called on a nil
// receiver.
func (opts *PSSOptions) HashFunc() crypto.Hash {
	return opts.Hash
}

// saltLength returns the configured salt length. A nil receiver reports
// PSSSaltLengthAuto.
func (opts *PSSOptions) saltLength() int {
	if opts != nil {
		return opts.SaltLength
	}
	return PSSSaltLengthAuto
}
258
259
260
261
262
263
264 func SignPSS(rand io.Reader, priv *PrivateKey, hash crypto.Hash, digest []byte, opts *PSSOptions) ([]byte, error) {
265 if opts != nil && opts.Hash != 0 {
266 hash = opts.Hash
267 }
268
269 saltLength := opts.saltLength()
270 switch saltLength {
271 case PSSSaltLengthAuto:
272 saltLength = (priv.N.BitLen()-1+7)/8 - 2 - hash.Size()
273 case PSSSaltLengthEqualsHash:
274 saltLength = hash.Size()
275 }
276
277 salt := make([]byte, saltLength)
278 if _, err := io.ReadFull(rand, salt); err != nil {
279 return nil, err
280 }
281 return signPSSWithSalt(rand, priv, hash, digest, salt)
282 }
283
284
285
286
287
288
289
290 func VerifyPSS(pub *PublicKey, hash crypto.Hash, digest []byte, sig []byte, opts *PSSOptions) error {
291 if len(sig) != pub.Size() {
292 return ErrVerification
293 }
294 s := new(big.Int).SetBytes(sig)
295 m := encrypt(new(big.Int), pub, s)
296 emBits := pub.N.BitLen() - 1
297 emLen := (emBits + 7) / 8
298 if m.BitLen() > emLen*8 {
299 return ErrVerification
300 }
301 em := m.FillBytes(make([]byte, emLen))
302 return emsaPSSVerify(digest, em, emBits, opts.saltLength(), hash.New())
303 }
304
View as plain text