ProtobufWriter.java (15470B) download
1package protobuf;
2
3import java.io.IOException;
4import java.io.InputStream;
5import java.nio.charset.StandardCharsets;
6import java.util.Iterator;
7import java.util.function.BinaryOperator;
8import java.util.function.Consumer;
9import java.util.function.Function;
10import java.util.function.Supplier;
11import java.util.function.UnaryOperator;
12
13import protobuf.exception.InputException;
14import protobuf.exception.OverflowException;
15import protobuf.exception.UnexpectedTagException;
16import protobuf.exception.WireTypeException;
17
18/**
19 * Represents an interface for parsing Protobuf wire elements, providing methods
20 * for reading various Protobuf data types.
21 */
22public class ProtobufWriter {
23 private final MessageIterator message;
24 private final WireType type;
25 private final int tag;
26 private boolean resetType = false;
27
28 public ProtobufReader(MessageIterator message, WireType type, int tag) {
29 this.message = message;
30 this.type = type;
31 this.tag = tag;
32 }
33
34 /**
35 * Gets the underlying input stream.
36 *
37 * @return The input stream.
38 */
39 public InputStream getInputStream() {
40 return message.input;
41 }
42
43 /**
44 * Gets the Protobuf wire type.
45 *
46 * @return The wire type.
47 */
48 public WireType getType() {
49 return resetType ? null : type;
50 }
51
52 /**
53 * Resets the Protobuf wire type.
54 * Useful for parsing packed streams.
55 */
56 public void resetType() {
57 resetType = true;
58 }
59
60 /**
61 * Gets the tag associated with the wire element.
62 *
63 * @return The wire tag.
64 */
65 public int tag() {
66 return tag;
67 }
68
69 /**
70 * Reads a signed variable-length integer as a 64-bit integer.
71 *
72 * @return The parsed signed 64-bit integer.
73 */
74 public long svarint64() {
75 long n = varint64();
76 return (n & 0x01) == 0
77 ? (n >> 1)
78 : -(n >> 1) - 1;
79 }
80
81 /**
82 * Parses a variable-length integer as a signed 64-bit integer.
83 *
84 * @return The parsed signed 64-bit integer.
85 */
86 public long varint64() {
87 return varint64(false);
88 }
89
90 /**
91 * Reads a signed variable-length integer as a 32-bit integer.
92 *
93 * @return The parsed signed 32-bit integer.
94 */
95 public long varint64(boolean ignoreType) {
96 if (!ignoreType && getType() != null && getType() != WireType.VARINT)
97 throw new WireTypeException(WireType.VARINT, getType());
98
99 long result = 0;
100 long b = 0;
101 int shift = 0;
102 while (shift < 64 && message.length > 0) {
103 try {
104 b = message.input.read();
105 } catch (IOException exc) {
106 throw new InputException(exc);
107 }
108 if (b == -1)
109 break;
110
111 message.length--;
112
113 result |= (b & 0x7f) << shift;
114 shift += 7;
115 if ((b & 0x80) == 0)
116 return result;
117 }
118
119 throw new OverflowException("input exceed");
120 }
121
122 public long svarint32() {
123 int n = varint32();
124 return (n & 0x01) == 0
125 ? (n >> 1)
126 : -(n >> 1);
127
128 }
129
130 /**
131 * Parses a variable-length integer as a signed 32-bit integer.
132 *
133 * @return The parsed signed 32-bit integer.
134 */
135 public int varint32() {
136 return varint32(false);
137 }
138
139 private int varint32(boolean ignoreType) {
140 if (!ignoreType && getType() != null && getType() != WireType.VARINT)
141 throw new WireTypeException(WireType.VARINT, getType());
142
143 int result = 0;
144 int b = 0;
145 int shift = 0;
146 while (shift < 32 && message.length > 0) {
147 try {
148 b = message.input.read();
149 } catch (IOException exc) {
150 throw new InputException(exc);
151 }
152 if (b == -1)
153 break;
154
155 message.length--;
156
157 result |= (b & 0x7f) << shift;
158 shift += 7;
159 if ((b & 0x80) == 0)
160 return result;
161 }
162 throw new OverflowException("input exceed");
163 }
164
165 /**
166 * Skips the variable-length integer.
167 */
168 public void skipVarint() {
169 if (getType() != null && getType() != WireType.VARINT)
170 throw new WireTypeException(WireType.VARINT, getType());
171
172 int b = 0;
173 while (message.length > 0) {
174 try {
175 b = message.input.read();
176 } catch (IOException exc) {
177 throw new InputException(exc);
178 }
179 if (b == -1)
180 break;
181
182 message.length--;
183 if ((b & 0x80) == 0)
184 return;
185 }
186 throw new OverflowException("input exceed");
187 }
188
189 /**
190 * Reads a fixed 64-bit integer.
191 *
192 * @return The parsed 64-bit integer.
193 */
194 public long fixed64() {
195 if (getType() != null && getType() != WireType.I64)
196 throw new WireTypeException(WireType.I64, getType());
197
198 if (message.length < 8)
199 throw new OverflowException("input exceed");
200
201 byte[] bytes;
202 try {
203 bytes = message.input.readNBytes(8);
204 } catch (IOException exc) {
205 throw new InputException(exc);
206 }
207 long result = 0;
208
209 for (int i = bytes.length - 1; i >= 0; i--) {
210 result <<= 8;
211 result |= bytes[i];
212 }
213
214 return result;
215 }
216
217 /**
218 * Skips a fixed 64-bit integer.
219 */
220 public void skip64() {
221 if (getType() != null && getType() != WireType.I64)
222 throw new WireTypeException(WireType.I64, getType());
223
224 if (message.length < 8)
225 throw new OverflowException("input exceed");
226
227 message.length -= 8;
228 try {
229 message.input.skipNBytes(8);
230 } catch (IOException exc) {
231 throw new InputException(exc);
232 }
233 }
234
235 /**
236 * Reads a fixed 32-bit integer.
237 *
238 * @return The parsed 32-bit integer.
239 */
240 public int fixed32() {
241 if (getType() != null && getType() != WireType.I32)
242 throw new WireTypeException(WireType.I32, getType());
243
244 if (message.length < 4)
245 throw new OverflowException("input exceed");
246
247 byte[] bytes;
248 try {
249 bytes = message.input.readNBytes(4);
250 } catch (IOException exc) {
251 throw new InputException(exc);
252 }
253 int result = 0;
254
255 for (int i = bytes.length - 1; i >= 0; i--) {
256 result <<= 8;
257 result |= bytes[i];
258 }
259
260 return result;
261 }
262
263 /**
264 * Skips a fixed 32-bit integer.
265 */
266 public void skip32() {
267 if (getType() != null && getType() != WireType.I32)
268 throw new WireTypeException(WireType.I32, getType());
269
270 if (message.length < 4)
271 throw new OverflowException("input exceed");
272
273 message.length -= 4;
274 try {
275 message.input.skipNBytes(4);
276 } catch (IOException exc) {
277 throw new InputException(exc);
278 }
279 }
280
281 /**
282 * Reads a byte array.
283 *
284 * @return The read byte array.
285 */
286 public byte[] bytes() {
287 if (getType() != null && getType() != WireType.LEN)
288 throw new WireTypeException(WireType.LEN, getType());
289
290 int len = varint32(true);
291 if (message.length < len)
292 throw new OverflowException("input exceed");
293
294 message.length -= len;
295 try {
296 return message.input.readNBytes(len);
297 } catch (IOException exc) {
298 throw new InputException(exc);
299 }
300 }
301
302 /**
303 * Skips a byte array.
304 */
305 public void skipBytes() {
306 if (getType() != null && getType() != WireType.LEN)
307 throw new WireTypeException(WireType.LEN, getType());
308
309 int len = varint32(true);
310 if (message.length < len)
311 throw new OverflowException("input exceed");
312
313 message.length -= len;
314 try {
315 message.input.skipNBytes(len);
316 } catch (IOException exc) {
317 throw new InputException(exc);
318 }
319 }
320
321 /**
322 * Reads a string.
323 *
324 * @return The read string.
325 */
326 public String string() {
327 return new String(bytes(), StandardCharsets.UTF_8);
328 }
329
330 /**
331 * Reads a message using the provided handler.
332 *
333 * @param handler The message handler.
334 * @param <T> The type of the parsed message.
335 * @return The parsed message.
336 */
337 public <T> T message(Message<T> handler) {
338 if (getType() != null && getType() != WireType.LEN)
339 throw new WireTypeException(WireType.LEN, getType());
340
341 int len = varint32(true);
342 if (message.length < len)
343 throw new OverflowException("input exceed");
344
345 message.length -= len;
346
347 return handler.parse(message.input, len);
348 }
349
350 /**
351 * Reads a message using the provided handler and applies a mapping function to
352 * the byte array.
353 *
354 * @param handler The message handler.
355 * @param map The mapping function for the byte array.
356 * @param <T> The type of the parsed message.
357 * @return The parsed message.
358 */
359 public <T> T message(Message<T> handler, UnaryOperator<byte[]> map) {
360 byte[] buffer = bytes();
361 return handler.parse(map.apply(buffer));
362 }
363
364 /**
365 * Creates an iterator for packed values.
366 *
367 * @param scalar The scalar supplier.
368 * @param <T> The type of the scalar values.
369 * @return The iterator for packed values.
370 */
371 public <T> Iterator<T> packed(Supplier<T> scalar) {
372 return packed(scalar, v -> v);
373 }
374
375 /**
376 * Creates an iterator for packed values, applying a mapping function to each
377 * scalar value.
378 *
379 * @param scalar The scalar supplier.
380 * @param map The mapping function for scalar values.
381 * @param <T> The type of the scalar values.
382 * @param <M> The type of the mapped values.
383 * @return The iterator for packed values.
384 */
385 public <T, M> Iterator<M> packed(Supplier<T> scalar, Function<T, M> map) {
386 if (getType() != null && getType() != WireType.LEN)
387 throw new WireTypeException(WireType.LEN, getType());
388
389 int len = varint32(true);
390 if (message.length < len)
391 throw new OverflowException("input exceed");
392
393 int end = message.length - len;
394
395 resetType();
396
397 return new Iterator<>() {
398 public boolean hasNext() {
399 if (message.length < end)
400 throw new OverflowException("packed string overused");
401 return message.length > end;
402 }
403
404 public M next() {
405 return map.apply(scalar.get());
406 }
407 };
408 }
409
410 /**
411 * Creates an iterator for packed values with an initial value and a binary
412 * operator.
413 *
414 * @param scalar The scalar supplier.
415 * @param init The initial value.
416 * @param operator The binary operator.
417 * @param <T> The type of the scalar values.
418 * @return The iterator for packed values.
419 */
420 public <T> Iterator<T> packed(Supplier<T> scalar, T init, BinaryOperator<T> operator) {
421 return packed(scalar, v -> v, init, operator);
422 }
423
424 /**
425 * Creates an iterator for packed values with an initial value, a binary
426 * operator, and a mapping function.
427 *
428 * @param scalar The scalar supplier.
429 * @param map The mapping function for scalar values.
430 * @param init The initial value.
431 * @param operator The binary operator.
432 * @param <T> The type of the scalar values.
433 * @param <M> The type of the mapped values.
434 * @return The iterator for packed values.
435 */
436 public <T, M> Iterator<M> packed(Supplier<T> scalar, Function<T, M> map, M init, BinaryOperator<M> operator) {
437 if (getType() != null && getType() != WireType.LEN)
438 throw new WireTypeException(WireType.LEN, getType());
439
440 int len = varint32(true);
441 if (message.length < len)
442 throw new OverflowException("input exceed");
443
444 int end = message.length - len;
445
446 resetType();
447
448 return new Iterator<>() {
449 M value = init;
450
451 public boolean hasNext() {
452 if (message.length < end)
453 throw new OverflowException("packed string overused");
454 return message.length > end;
455 }
456
457 public M next() {
458 return value = operator.apply(value, map.apply(scalar.get()));
459 }
460 };
461 }
462
463 /**
464 * Defers the execution of a consumer with a supplied value.
465 *
466 * @param supplier The value supplier.
467 * @param defer The consumer to be deferred.
468 * @param <T> The type of the supplied value.
469 */
470 public <T> void delayed(Supplier<T> supplier, Consumer<T> defer) {
471 T buffer = supplier.get();
472 message.delayed.add(() -> defer.accept(buffer));
473 }
474
475 /**
476 * Defers the execution of a consumer with a mapped supplied value.
477 *
478 * @param supplier The value supplier.
479 * @param map The mapping function for the supplied value.
480 * @param defer The consumer to be deferred.
481 * @param <T> The type of the supplied value.
482 */
483 public <T> void delayed(Supplier<T> supplier, UnaryOperator<T> map, Consumer<T> defer) {
484 T buffer = supplier.get();
485 message.delayed.add(() -> defer.accept(map.apply(buffer)));
486 }
487
488 /**
489 * Defers the execution of a consumer with a parsed message using the provided
490 * handler.
491 *
492 * @param handler The message handler.
493 * @param defer The consumer to be deferred.
494 * @param <T> The type of the parsed message.
495 */
496 public <T> void delayed(Message<T> handler, Consumer<T> defer) {
497 byte[] buffer = bytes();
498 message.delayed.add(() -> defer.accept(handler.parse(buffer)));
499 }
500
501 /**
502 * Defers the execution of a consumer with a mapped parsed message using the
503 * provided handler.
504 *
505 * @param handler The message handler.
506 * @param map The mapping function for the parsed message.
507 * @param defer The consumer to be deferred.
508 * @param <T> The type of the parsed message.
509 */
510 public <T> void delayed(Message<T> handler, UnaryOperator<byte[]> map, Consumer<T> defer) {
511 byte[] buffer = bytes();
512 message.delayed.add(() -> defer.accept(handler.parse(map.apply(buffer))));
513 }
514
515 /**
516 * Skips the current wire element based on its type.
517 */
518 public void skip() {
519 switch (getType()) {
520 case VARINT:
521 skipVarint();
522 break;
523 case I64:
524 skip64();
525 break;
526 case LEN:
527 skipBytes();
528 break;
529 case I32:
530 skip32();
531 break;
532 case SGROUP:
533 case EGROUP:
534 throw new UnsupportedOperationException("cannot skip sgroup of egroup");
535 }
536 }
537
538 /**
539 * Throws an {@link UnexpectedTagException} for the current tag.
540 */
541 public void throwUnexpected() {
542 throw new UnexpectedTagException(tag);
543 }
544}
545}