decode.go 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514
  1. package yu_proto_old
import (
	"encoding"
	"encoding/binary"
	"errors"
	"fmt"
	"math"
	"reflect"
	"sort"
	"strings"
	"time"
)
  11. // Constructors represents a map defining how to instantiate any interface
  12. // types that Decode() might encounter while reading and decoding structured
  13. // data. The keys are reflect.Type values denoting interface types. The
  14. // corresponding values are functions expected to instantiate, and initialize
  15. // as necessary, an appropriate concrete object type supporting that
  16. // interface. A caller could use this capability to support
  17. // dynamic instantiation of objects of the concrete type
  18. // appropriate for a given abstract type.
  19. type Constructors map[reflect.Type]func() interface{}
  20. // String returns an easy way to visualize what you have in your constructors.
  21. func (c *Constructors) String() string {
  22. var s string
  23. for k := range *c {
  24. s += k.String() + "=>" + "(func() interface {})" + "\t"
  25. }
  26. return s
  27. }
  28. // Decoder is the main struct used to decode a protobuf blob.
  29. type decoder struct {
  30. nm Constructors
  31. }
  32. // Decode a protocol buffer into a Go struct.
  33. // The caller must pass a pointer to the struct to decode into.
  34. //
  35. // Decode() currently does not explicitly check that all 'required' fields
  36. // are actually present in the input buffer being decoded.
  37. // If required fields are missing, then the corresponding fields
  38. // will be left unmodified, meaning they will take on
  39. // their default Go zero values if Decode() is passed a fresh struct.
  40. func Decode(buf []byte, structPtr interface{}) error {
  41. return DecodeWithConstructors(buf, structPtr, nil)
  42. }
  43. // DecodeWithConstructors is like Decode, but you can pass a map of
  44. // constructors with which to instantiate interface types.
  45. func DecodeWithConstructors(buf []byte, structPtr interface{}, cons Constructors) (err error) {
  46. defer func() {
  47. if r := recover(); r != nil {
  48. switch e := r.(type) {
  49. case string:
  50. err = errors.New(e)
  51. case error:
  52. err = e
  53. default:
  54. err = errors.New("Failed to decode the field")
  55. }
  56. }
  57. }()
  58. if structPtr == nil {
  59. return nil
  60. }
  61. if bu, ok := structPtr.(encoding.BinaryUnmarshaler); ok {
  62. return bu.UnmarshalBinary(buf)
  63. }
  64. de := decoder{cons}
  65. val := reflect.ValueOf(structPtr)
  66. // if its NOT a pointer, it is bad return an error
  67. if val.Kind() != reflect.Ptr {
  68. return errors.New("Decode has been given a non pointer type")
  69. }
  70. return de.message(buf, val.Elem())
  71. }
  72. // Decode a Protocol Buffers message into a Go struct.
  73. // The Kind of the passed value v must be Struct.
  74. func (de *decoder) message(buf []byte, sval reflect.Value) error {
  75. if sval.Kind() != reflect.Struct {
  76. return errors.New("not a struct")
  77. }
  78. for i := 0; i < sval.NumField(); i++ {
  79. switch field := sval.Field(i); field.Kind() {
  80. case reflect.Interface:
  81. // Interface are not reset because the decoder won't
  82. // be able to instantiate it again in some scenarios.
  83. default:
  84. if field.CanSet() {
  85. field.Set(reflect.Zero(field.Type()))
  86. }
  87. }
  88. }
  89. // Decode all the fields
  90. fields := ProtoFields(sval.Type())
  91. fieldi := 0
  92. for len(buf) > 0 {
  93. // Parse the key
  94. key, n := binary.Uvarint(buf)
  95. if n <= 0 {
  96. return errors.New("bad protobuf field key")
  97. }
  98. buf = buf[n:]
  99. wiretype := int(key & 7)
  100. fieldnum := key >> 3
  101. // Lookup the corresponding struct field.
  102. // Leave field with a zero Value if fieldnum is out-of-range.
  103. // In this case, as well as for blank fields,
  104. // value() will just skip over and discard the field content.
  105. var field reflect.Value
  106. for fieldi < len(fields) && fields[fieldi].ID < int64(fieldnum) {
  107. fieldi++
  108. }
  109. if fieldi < len(fields) && fields[fieldi].ID == int64(fieldnum) {
  110. // For fields within embedded structs, ensure the embedded values aren't nil.
  111. index := fields[fieldi].Index
  112. path := make([]int, 0, len(index))
  113. for _, id := range index {
  114. path = append(path, id)
  115. field = sval.FieldByIndex(path)
  116. if field.Kind() == reflect.Ptr && field.IsNil() {
  117. field.Set(reflect.New(field.Type().Elem()))
  118. }
  119. }
  120. }
  121. // For more debugging output, uncomment the following three lines.
  122. // if fieldi < len(fields){
  123. // fmt.Printf("Decoding FieldName %+v\n", fields[fieldi].Field)
  124. // }
  125. // Decode the field's value
  126. rem, err := de.value(wiretype, buf, field)
  127. if err != nil {
  128. if fieldi < len(fields) && fields[fieldi] != nil {
  129. return fmt.Errorf("Error while decoding field %+v: %v", fields[fieldi].Field, err)
  130. }
  131. return err
  132. }
  133. buf = rem
  134. }
  135. return nil
  136. }
  137. // Pull a value from the buffer and put it into a reflective Value.
  138. func (de *decoder) value(wiretype int, buf []byte,
  139. val reflect.Value) ([]byte, error) {
  140. // Break out the value from the buffer based on the wire type
  141. var v uint64
  142. var n int
  143. var vb []byte
  144. switch wiretype {
  145. case 0: // varint
  146. v, n = binary.Uvarint(buf)
  147. if n <= 0 {
  148. return nil, errors.New("bad protobuf varint value")
  149. }
  150. buf = buf[n:]
  151. case 5: // 32-bit
  152. if len(buf) < 4 {
  153. return nil, errors.New("bad protobuf 32-bit value")
  154. }
  155. v = uint64(buf[0]) |
  156. uint64(buf[1])<<8 |
  157. uint64(buf[2])<<16 |
  158. uint64(buf[3])<<24
  159. buf = buf[4:]
  160. case 1: // 64-bit
  161. if len(buf) < 8 {
  162. return nil, errors.New("bad protobuf 64-bit value")
  163. }
  164. v = uint64(buf[0]) |
  165. uint64(buf[1])<<8 |
  166. uint64(buf[2])<<16 |
  167. uint64(buf[3])<<24 |
  168. uint64(buf[4])<<32 |
  169. uint64(buf[5])<<40 |
  170. uint64(buf[6])<<48 |
  171. uint64(buf[7])<<56
  172. buf = buf[8:]
  173. case 2: // length-delimited
  174. v, n = binary.Uvarint(buf)
  175. if n <= 0 || v > uint64(len(buf)-n) {
  176. return nil, errors.New(
  177. "bad protobuf length-delimited value")
  178. }
  179. vb = buf[n : n+int(v) : n+int(v)]
  180. buf = buf[n+int(v):]
  181. default:
  182. return nil, errors.New("unknown protobuf wire-type")
  183. }
  184. // We've gotten the value out of the buffer,
  185. // now put it into the appropriate reflective Value.
  186. if err := de.putvalue(wiretype, val, v, vb); err != nil {
  187. return nil, err
  188. }
  189. return buf, nil
  190. }
  191. func (de *decoder) decodeSignedInt(wiretype int, v uint64) (int64, error) {
  192. if wiretype == 0 { // encoded as varint
  193. sv := int64(v) >> 1
  194. if v&1 != 0 {
  195. sv = ^sv
  196. }
  197. return sv, nil
  198. } else if wiretype == 5 { // sfixed32
  199. return int64(int32(v)), nil
  200. } else if wiretype == 1 { // sfixed64
  201. return int64(v), nil
  202. } else {
  203. return -1, errors.New("bad wiretype for sint")
  204. }
  205. }
  206. func (de *decoder) putvalue(wiretype int, val reflect.Value,
  207. v uint64, vb []byte) error {
  208. // If val is not settable, it either represents an out-of-range field
  209. // or an in-range but blank (padding) field in the struct.
  210. // In this case, simply ignore and discard the field's content.
  211. if !val.CanSet() {
  212. return nil
  213. }
  214. switch val.Kind() {
  215. case reflect.Bool:
  216. if wiretype != 0 {
  217. return errors.New("bad wiretype for bool")
  218. }
  219. if v > 1 {
  220. return errors.New("invalid bool value")
  221. }
  222. val.SetBool(v != 0)
  223. case reflect.Int, reflect.Int32, reflect.Int64:
  224. // Signed integers may be encoded either zigzag-varint or fixed
  225. // Note that protobufs don't support 8- or 16-bit ints.
  226. if val.Kind() == reflect.Int && val.Type().Size() < 8 {
  227. return errors.New("detected a 32bit machine, please use either int64 or int32")
  228. }
  229. sv, err := de.decodeSignedInt(wiretype, v)
  230. if err != nil {
  231. fmt.Println("Error Reflect.Int for v=", v, "wiretype=", wiretype, "for Value=", val.Type().Name())
  232. return err
  233. }
  234. val.SetInt(sv)
  235. case reflect.Uint, reflect.Uint32, reflect.Uint64:
  236. // Varint-encoded 32-bit and 64-bit unsigned integers.
  237. if val.Kind() == reflect.Uint && val.Type().Size() < 8 {
  238. return errors.New("detected a 32bit machine, please use either uint64 or uint32")
  239. }
  240. if wiretype == 0 {
  241. val.SetUint(v)
  242. } else if wiretype == 5 { // ufixed32
  243. val.SetUint(uint64(uint32(v)))
  244. } else if wiretype == 1 { // ufixed64
  245. val.SetUint(uint64(v))
  246. } else {
  247. return errors.New("bad wiretype for uint")
  248. }
  249. case reflect.Float32:
  250. // Fixed-length 32-bit floats.
  251. if wiretype != 5 {
  252. return errors.New("bad wiretype for float32")
  253. }
  254. val.SetFloat(float64(math.Float32frombits(uint32(v))))
  255. case reflect.Float64:
  256. // Fixed-length 64-bit floats.
  257. if wiretype != 1 {
  258. return errors.New("bad wiretype for float64")
  259. }
  260. val.SetFloat(math.Float64frombits(v))
  261. case reflect.String:
  262. // Length-delimited string.
  263. if wiretype != 2 {
  264. return errors.New("bad wiretype for string")
  265. }
  266. val.SetString(string(vb))
  267. case reflect.Struct:
  268. // Embedded message
  269. if val.Type() == timeType {
  270. sv, err := de.decodeSignedInt(wiretype, v)
  271. if err != nil {
  272. return err
  273. }
  274. t := time.Unix(sv/int64(time.Second), sv%int64(time.Second))
  275. val.Set(reflect.ValueOf(t))
  276. return nil
  277. } else if enc, ok := val.Addr().Interface().(encoding.BinaryUnmarshaler); ok {
  278. return enc.UnmarshalBinary(vb[:])
  279. }
  280. if wiretype != 2 {
  281. return errors.New("bad wiretype for embedded message")
  282. }
  283. return de.message(vb, val)
  284. case reflect.Ptr:
  285. // Optional field
  286. // Instantiate pointer's element type.
  287. if val.IsNil() {
  288. val.Set(de.instantiate(val.Type().Elem()))
  289. }
  290. return de.putvalue(wiretype, val.Elem(), v, vb)
  291. case reflect.Slice, reflect.Array:
  292. // Repeated field or byte-slice
  293. if wiretype != 2 {
  294. return errors.New("bad wiretype for repeated field")
  295. }
  296. return de.slice(val, vb)
  297. case reflect.Map:
  298. if wiretype != 2 {
  299. return errors.New("bad wiretype for repeated field")
  300. }
  301. if val.IsNil() {
  302. // make(map[k]v):
  303. val.Set(reflect.MakeMap(val.Type()))
  304. }
  305. return de.mapEntry(val, vb)
  306. case reflect.Interface:
  307. data := vb[:]
  308. // Abstract field: instantiate via dynamic constructor.
  309. if val.IsNil() {
  310. id := GeneratorID{}
  311. var g InterfaceGeneratorFunc
  312. if len(id) < len(vb) {
  313. copy(id[:], vb[:len(id)])
  314. g = generators.get(id)
  315. }
  316. if g == nil {
  317. // Backwards compatible usage of the default constructors
  318. val.Set(de.instantiate(val.Type()))
  319. } else {
  320. // As pointers to interface are discouraged in Go, we use
  321. // the generator only for interface types
  322. data = vb[len(id):]
  323. val.Set(reflect.ValueOf(g()))
  324. }
  325. }
  326. // If the object support self-decoding, use that.
  327. if enc, ok := val.Interface().(encoding.BinaryUnmarshaler); ok {
  328. if wiretype != 2 {
  329. return errors.New("bad wiretype for bytes")
  330. }
  331. return enc.UnmarshalBinary(data)
  332. }
  333. // Decode into the object the interface points to.
  334. // XXX perhaps better ONLY to support self-decoding
  335. // for interface fields?
  336. return Decode(vb, val.Interface())
  337. default:
  338. panic("unsupported value kind " + val.Kind().String())
  339. }
  340. return nil
  341. }
  342. // Instantiate an arbitrary type, handling dynamic interface types.
  343. // Returns a Ptr value.
  344. func (de *decoder) instantiate(t reflect.Type) reflect.Value {
  345. // If it's an interface type, lookup a dynamic constructor for it.
  346. if t.Kind() == reflect.Interface {
  347. newfunc, ok := de.nm[t]
  348. if !ok {
  349. panic("no constructor for interface " + t.String())
  350. }
  351. return reflect.ValueOf(newfunc())
  352. }
  353. // Otherwise, for all concrete types, just instantiate directly.
  354. return reflect.New(t)
  355. }
  356. var sfixed32type = reflect.TypeOf(Sfixed32(0))
  357. var sfixed64type = reflect.TypeOf(Sfixed64(0))
  358. var ufixed32type = reflect.TypeOf(Ufixed32(0))
  359. var ufixed64type = reflect.TypeOf(Ufixed64(0))
  360. // Handle decoding of slices
  361. func (de *decoder) slice(slval reflect.Value, vb []byte) error {
  362. // Find the element type, and create a temporary instance of it.
  363. eltype := slval.Type().Elem()
  364. val := reflect.New(eltype).Elem()
  365. // Decide on the wiretype to use for decoding.
  366. var wiretype int
  367. switch eltype.Kind() {
  368. case reflect.Bool, reflect.Int32, reflect.Int64, reflect.Int,
  369. reflect.Uint32, reflect.Uint64, reflect.Uint:
  370. if (eltype.Kind() == reflect.Int || eltype.Kind() == reflect.Uint) && eltype.Size() < 8 {
  371. return errors.New("detected a 32bit machine, please either use (u)int64 or (u)int32")
  372. }
  373. switch eltype {
  374. case sfixed32type:
  375. wiretype = 5 // Packed 32-bit representation
  376. case sfixed64type:
  377. wiretype = 1 // Packed 64-bit representation
  378. case ufixed32type:
  379. wiretype = 5 // Packed 32-bit representation
  380. case ufixed64type:
  381. wiretype = 1 // Packed 64-bit representation
  382. default:
  383. wiretype = 0 // Packed varint representation
  384. }
  385. case reflect.Float32:
  386. wiretype = 5 // Packed 32-bit representation
  387. case reflect.Float64:
  388. wiretype = 1 // Packed 64-bit representation
  389. case reflect.Uint8: // Unpacked byte-slice
  390. if slval.Kind() == reflect.Array {
  391. if slval.Len() != len(vb) {
  392. return errors.New("array length and buffer length differ")
  393. }
  394. for i := 0; i < slval.Len(); i++ {
  395. // no SetByte method in reflect so has to pass down by uint64
  396. slval.Index(i).SetUint(uint64(vb[i]))
  397. }
  398. } else {
  399. slval.SetBytes(vb)
  400. }
  401. return nil
  402. default: // Other unpacked repeated types
  403. // Just unpack and append one value from vb.
  404. if err := de.putvalue(2, val, 0, vb); err != nil {
  405. return err
  406. }
  407. if slval.Kind() != reflect.Slice {
  408. return errors.New("append to non-slice")
  409. }
  410. slval.Set(reflect.Append(slval, val))
  411. return nil
  412. }
  413. // Decode packed values from the buffer and append them to the slice.
  414. for len(vb) > 0 {
  415. rem, err := de.value(wiretype, vb, val)
  416. if err != nil {
  417. return err
  418. }
  419. slval.Set(reflect.Append(slval, val))
  420. vb = rem
  421. }
  422. return nil
  423. }
  424. // Handles the entry k,v of a map[K]V
  425. func (de *decoder) mapEntry(slval reflect.Value, vb []byte) error {
  426. mKey := reflect.New(slval.Type().Key())
  427. mVal := reflect.New(slval.Type().Elem())
  428. k := mKey.Elem()
  429. v := mVal.Elem()
  430. key, n := binary.Uvarint(vb)
  431. if n <= 0 {
  432. return errors.New("bad protobuf field key")
  433. }
  434. buf := vb[n:]
  435. wiretype := int(key & 7)
  436. var err error
  437. buf, err = de.value(wiretype, buf, k)
  438. if err != nil {
  439. return err
  440. }
  441. for len(buf) > 0 { // for repeated values (slices etc)
  442. key, n = binary.Uvarint(buf)
  443. if n <= 0 {
  444. return errors.New("bad protobuf field key")
  445. }
  446. buf = buf[n:]
  447. wiretype = int(key & 7)
  448. buf, err = de.value(wiretype, buf, v)
  449. if err != nil {
  450. return err
  451. }
  452. }
  453. if !k.IsValid() || !v.IsValid() {
  454. // We did not decode the key or the value in the map entry.
  455. // Either way, it's an invalid map entry.
  456. return errors.New("proto: bad map data: missing key/val")
  457. }
  458. slval.SetMapIndex(k, v)
  459. return nil
  460. }