diff options
author | Ian Lance Taylor <ian@gcc.gnu.org> | 2011-03-16 23:05:44 +0000 |
---|---|---|
committer | Ian Lance Taylor <ian@gcc.gnu.org> | 2011-03-16 23:05:44 +0000 |
commit | 5133f00ef8baab894d92de1e8b8baae59815a8b6 (patch) | |
tree | 44176975832a3faf1626836e70c97d5edd674122 /libgo/go/gob | |
parent | f617201f55938fc89b532f2240bdf77bea946471 (diff) | |
download | gcc-5133f00ef8baab894d92de1e8b8baae59815a8b6.tar.gz |
Update to current version of Go library (revision 94d654be2064).
From-SVN: r171076
Diffstat (limited to 'libgo/go/gob')
-rw-r--r-- | libgo/go/gob/codec_test.go | 74 | ||||
-rw-r--r-- | libgo/go/gob/decode.go | 205 | ||||
-rw-r--r-- | libgo/go/gob/decoder.go | 169 | ||||
-rw-r--r-- | libgo/go/gob/doc.go | 48 | ||||
-rw-r--r-- | libgo/go/gob/encode.go | 100 | ||||
-rw-r--r-- | libgo/go/gob/encoder.go | 93 | ||||
-rw-r--r-- | libgo/go/gob/encoder_test.go | 135 | ||||
-rw-r--r-- | libgo/go/gob/type.go | 204 |
8 files changed, 740 insertions, 288 deletions
diff --git a/libgo/go/gob/codec_test.go b/libgo/go/gob/codec_test.go index af941c629cd..c822d6863ac 100644 --- a/libgo/go/gob/codec_test.go +++ b/libgo/go/gob/codec_test.go @@ -58,7 +58,7 @@ func TestUintCodec(t *testing.T) { t.Errorf("encodeUint: %#x encode: expected % x got % x", tt.x, tt.b, b.Bytes()) } } - decState := newDecodeState(nil, &b) + decState := newDecodeState(nil, b) for u := uint64(0); ; u = (u + 1) * 7 { b.Reset() encState.encodeUint(u) @@ -77,7 +77,7 @@ func verifyInt(i int64, t *testing.T) { var b = new(bytes.Buffer) encState := newEncoderState(nil, b) encState.encodeInt(i) - decState := newDecodeState(nil, &b) + decState := newDecodeState(nil, b) decState.buf = make([]byte, 8) j := decState.decodeInt() if i != j { @@ -315,7 +315,7 @@ func execDec(typ string, instr *decInstr, state *decodeState, t *testing.T, p un func newDecodeStateFromData(data []byte) *decodeState { b := bytes.NewBuffer(data) - state := newDecodeState(nil, &b) + state := newDecodeState(nil, b) state.fieldnum = -1 return state } @@ -342,7 +342,7 @@ func TestScalarDecInstructions(t *testing.T) { var data struct { a int } - instr := &decInstr{decOpMap[reflect.Int], 6, 0, 0, ovfl} + instr := &decInstr{decOpTable[reflect.Int], 6, 0, 0, ovfl} state := newDecodeStateFromData(signedResult) execDec("int", instr, state, t, unsafe.Pointer(&data)) if data.a != 17 { @@ -355,7 +355,7 @@ func TestScalarDecInstructions(t *testing.T) { var data struct { a uint } - instr := &decInstr{decOpMap[reflect.Uint], 6, 0, 0, ovfl} + instr := &decInstr{decOpTable[reflect.Uint], 6, 0, 0, ovfl} state := newDecodeStateFromData(unsignedResult) execDec("uint", instr, state, t, unsafe.Pointer(&data)) if data.a != 17 { @@ -446,7 +446,7 @@ func TestScalarDecInstructions(t *testing.T) { var data struct { a uintptr } - instr := &decInstr{decOpMap[reflect.Uintptr], 6, 0, 0, ovfl} + instr := &decInstr{decOpTable[reflect.Uintptr], 6, 0, 0, ovfl} state := newDecodeStateFromData(unsignedResult) execDec("uintptr", instr, state, t, unsafe.Pointer(&data)) if data.a != 17 { @@ -511,7 +511,7 @@ func TestScalarDecInstructions(t *testing.T) { var data struct { a complex64 } - instr := &decInstr{decOpMap[reflect.Complex64], 6, 0, 0, ovfl} + instr := &decInstr{decOpTable[reflect.Complex64], 6, 0, 0, ovfl} state := newDecodeStateFromData(complexResult) execDec("complex", instr, state, t, unsafe.Pointer(&data)) if data.a != 17+19i { @@ -524,7 +524,7 @@ func TestScalarDecInstructions(t *testing.T) { var data struct { a complex128 } - instr := &decInstr{decOpMap[reflect.Complex128], 6, 0, 0, ovfl} + instr := &decInstr{decOpTable[reflect.Complex128], 6, 0, 0, ovfl} state := newDecodeStateFromData(complexResult) execDec("complex", instr, state, t, unsafe.Pointer(&data)) if data.a != 17+19i { @@ -973,18 +973,32 @@ func TestIgnoredFields(t *testing.T) { } } + +func TestBadRecursiveType(t *testing.T) { + type Rec ***Rec + var rec Rec + b := new(bytes.Buffer) + err := NewEncoder(b).Encode(&rec) + if err == nil { + t.Error("expected error; got none") + } else if strings.Index(err.String(), "recursive") < 0 { + t.Error("expected recursive type error; got", err) + } + // Can't test decode easily because we can't encode one, so we can't pass one to a Decoder. +} + type Bad0 struct { - ch chan int - c float64 + CH chan int + C float64 } -var nilEncoder *Encoder func TestInvalidField(t *testing.T) { var bad0 Bad0 - bad0.ch = make(chan int) + bad0.CH = make(chan int) b := new(bytes.Buffer) - err := nilEncoder.encode(b, reflect.NewValue(&bad0)) + var nilEncoder *Encoder + err := nilEncoder.encode(b, reflect.NewValue(&bad0), userType(reflect.Typeof(&bad0))) if err == nil { t.Error("expected error; got none") } else if strings.Index(err.String(), "type") < 0 { @@ -1088,11 +1102,11 @@ func (v Vector) Square() int { } type Point struct { - a, b int + X, Y int } func (p Point) Square() int { - return p.a*p.a + p.b*p.b + return p.X*p.X + p.Y*p.Y } // A struct with interfaces in it. @@ -1162,7 +1176,6 @@ func TestInterface(t *testing.T) { } } } - } // A struct with all basic types, stored in interfaces. @@ -1182,7 +1195,7 @@ func TestInterfaceBasic(t *testing.T) { int(1), int8(1), int16(1), int32(1), int64(1), uint(1), uint8(1), uint16(1), uint32(1), uint64(1), float32(1), 1.0, - complex64(0i), complex128(0i), + complex64(1i), complex128(1i), true, "hello", []byte("sailor"), @@ -1307,6 +1320,31 @@ func TestUnexportedFields(t *testing.T) { } } +var singletons = []interface{}{ + true, + 7, + 3.2, + "hello", + [3]int{11, 22, 33}, + []float32{0.5, 0.25, 0.125}, + map[string]int{"one": 1, "two": 2}, +} + +func TestDebugSingleton(t *testing.T) { + if debugFunc == nil { + return + } + b := new(bytes.Buffer) + // Accumulate a number of values and print them out all at once. + for _, x := range singletons { + err := NewEncoder(b).Encode(x) + if err != nil { + t.Fatal("encode:", err) + } + } + debugFunc(b) +} + // A type that won't be defined in the gob until we send it in an interface value. type OnTheFly struct { A int @@ -1325,7 +1363,7 @@ type DT struct { S []string } -func TestDebug(t *testing.T) { +func TestDebugStruct(t *testing.T) { if debugFunc == nil { return } diff --git a/libgo/go/gob/decode.go b/libgo/go/gob/decode.go index 2db75215c19..8f599e10041 100644 --- a/libgo/go/gob/decode.go +++ b/libgo/go/gob/decode.go @@ -30,15 +30,17 @@ type decodeState struct { dec *Decoder // The buffer is stored with an extra indirection because it may be replaced // if we load a type during decode (when reading an interface value). - b **bytes.Buffer + b *bytes.Buffer fieldnum int // the last field number read. buf []byte } -func newDecodeState(dec *Decoder, b **bytes.Buffer) *decodeState { +// We pass the bytes.Buffer separately for easier testing of the infrastructure +// without requiring a full Decoder. +func newDecodeState(dec *Decoder, buf *bytes.Buffer) *decodeState { d := new(decodeState) d.dec = dec - d.b = b + d.b = buf d.buf = make([]byte, uint64Size) return d } @@ -49,14 +51,15 @@ func overflow(name string) os.ErrorString { // decodeUintReader reads an encoded unsigned integer from an io.Reader. // Used only by the Decoder to read the message length. -func decodeUintReader(r io.Reader, buf []byte) (x uint64, err os.Error) { - _, err = r.Read(buf[0:1]) +func decodeUintReader(r io.Reader, buf []byte) (x uint64, width int, err os.Error) { + width = 1 + _, err = r.Read(buf[0:width]) if err != nil { return } b := buf[0] if b <= 0x7f { - return uint64(b), nil + return uint64(b), width, nil } nb := -int(int8(b)) if nb > uint64Size { @@ -75,6 +78,7 @@ func decodeUintReader(r io.Reader, buf []byte) (x uint64, err os.Error) { for i := 0; i < n; i++ { x <<= 8 x |= uint64(buf[i]) + width++ } return } @@ -405,10 +409,9 @@ func allocate(rtyp reflect.Type, p uintptr, indir int) uintptr { return *(*uintptr)(up) } -func (dec *Decoder) decodeSingle(engine *decEngine, rtyp reflect.Type, b **bytes.Buffer, p uintptr, indir int) (err os.Error) { - defer catchError(&err) - p = allocate(rtyp, p, indir) - state := newDecodeState(dec, b) +func (dec *Decoder) decodeSingle(engine *decEngine, ut *userTypeInfo, p uintptr) (err os.Error) { + p = allocate(ut.base, p, ut.indir) + state := newDecodeState(dec, &dec.buf) state.fieldnum = singletonField basep := p delta := int(state.decodeUint()) @@ -424,10 +427,13 @@ func (dec *Decoder) decodeSingle(engine *decEngine, rtyp reflect.Type, b **bytes return nil } -func (dec *Decoder) decodeStruct(engine *decEngine, rtyp *reflect.StructType, b **bytes.Buffer, p uintptr, indir int) (err os.Error) { - defer catchError(&err) - p = allocate(rtyp, p, indir) - state := newDecodeState(dec, b) +// Indir is for the value, not the type. At the time of the call it may +// differ from ut.indir, which was computed when the engine was built. +// This state cannot arise for decodeSingle, which is called directly +// from the user's value, not from the innards of an engine. +func (dec *Decoder) decodeStruct(engine *decEngine, ut *userTypeInfo, p uintptr, indir int) (err os.Error) { + p = allocate(ut.base.(*reflect.StructType), p, indir) + state := newDecodeState(dec, &dec.buf) state.fieldnum = -1 basep := p for state.b.Len() > 0 { @@ -454,9 +460,8 @@ func (dec *Decoder) decodeStruct(engine *decEngine, rtyp *reflect.StructType, b return nil } -func (dec *Decoder) ignoreStruct(engine *decEngine, b **bytes.Buffer) (err os.Error) { - defer catchError(&err) - state := newDecodeState(dec, b) +func (dec *Decoder) ignoreStruct(engine *decEngine) (err os.Error) { + state := newDecodeState(dec, &dec.buf) state.fieldnum = -1 for state.b.Len() > 0 { delta := int(state.decodeUint()) @@ -477,6 +482,18 @@ func (dec *Decoder) ignoreStruct(engine *decEngine, b **bytes.Buffer) (err os.Er return nil } +func (dec *Decoder) ignoreSingle(engine *decEngine) (err os.Error) { + state := newDecodeState(dec, &dec.buf) + state.fieldnum = singletonField + delta := int(state.decodeUint()) + if delta != 0 { + errorf("gob decode: corrupted data: non-zero delta for singleton") + } + instr := &engine.instr[singletonField] + instr.op(instr, state, unsafe.Pointer(nil)) + return nil +} + func (dec *Decoder) decodeArrayHelper(state *decodeState, p uintptr, elemOp decOp, elemWid uintptr, length, elemIndir int, ovfl os.ErrorString) { instr := &decInstr{elemOp, 0, elemIndir, 0, ovfl} for i := 0; i < length; i++ { @@ -501,7 +518,7 @@ func (dec *Decoder) decodeArray(atyp *reflect.ArrayType, state *decodeState, p u func decodeIntoValue(state *decodeState, op decOp, indir int, v reflect.Value, ovfl os.ErrorString) reflect.Value { instr := &decInstr{op, 0, indir, 0, ovfl} - up := unsafe.Pointer(v.Addr()) + up := unsafe.Pointer(v.UnsafeAddr()) if indir > 1 { up = decIndirect(up, indir) } @@ -612,9 +629,17 @@ func (dec *Decoder) decodeInterface(ityp *reflect.InterfaceType, state *decodeSt if !ok { errorf("gob: name not registered for interface: %q", name) } + // Read the type id of the concrete value. + concreteId := dec.decodeTypeSequence(true) + if concreteId < 0 { + error(dec.err) + } + // Byte count of value is next; we don't care what it is (it's there + // in case we want to ignore the value by skipping it completely). + state.decodeUint() // Read the concrete value. value := reflect.MakeZero(typ) - dec.decodeValueFromBuffer(value, false, true) + dec.decodeValue(concreteId, value) if dec.err != nil { error(dec.err) } @@ -637,14 +662,16 @@ func (dec *Decoder) ignoreInterface(state *decodeState) { if err != nil { error(err) } - dec.decodeValueFromBuffer(nil, true, true) - if dec.err != nil { - error(err) + id := dec.decodeTypeSequence(true) + if id < 0 { + error(dec.err) } + // At this point, the decoder buffer contains a delimited value. Just toss it. + state.b.Next(int(state.decodeUint())) } // Index by Go types. -var decOpMap = []decOp{ +var decOpTable = [...]decOp{ reflect.Bool: decBool, reflect.Int8: decInt8, reflect.Int16: decInt16, @@ -674,35 +701,43 @@ var decIgnoreOpMap = map[typeId]decOp{ // Return the decoding op for the base type under rt and // the indirection count to reach it. -func (dec *Decoder) decOpFor(wireId typeId, rt reflect.Type, name string) (decOp, int) { - typ, indir := indirect(rt) +func (dec *Decoder) decOpFor(wireId typeId, rt reflect.Type, name string, inProgress map[reflect.Type]*decOp) (*decOp, int) { + ut := userType(rt) + // If this type is already in progress, it's a recursive type (e.g. map[string]*T). + // Return the pointer to the op we're already building. + if opPtr := inProgress[rt]; opPtr != nil { + return opPtr, ut.indir + } + typ := ut.base + indir := ut.indir var op decOp k := typ.Kind() - if int(k) < len(decOpMap) { - op = decOpMap[k] + if int(k) < len(decOpTable) { + op = decOpTable[k] } if op == nil { + inProgress[rt] = &op // Special cases switch t := typ.(type) { case *reflect.ArrayType: name = "element of " + name elemId := dec.wireType[wireId].ArrayT.Elem - elemOp, elemIndir := dec.decOpFor(elemId, t.Elem(), name) + elemOp, elemIndir := dec.decOpFor(elemId, t.Elem(), name, inProgress) ovfl := overflow(name) op = func(i *decInstr, state *decodeState, p unsafe.Pointer) { - state.dec.decodeArray(t, state, uintptr(p), elemOp, t.Elem().Size(), t.Len(), i.indir, elemIndir, ovfl) + state.dec.decodeArray(t, state, uintptr(p), *elemOp, t.Elem().Size(), t.Len(), i.indir, elemIndir, ovfl) } case *reflect.MapType: name = "element of " + name keyId := dec.wireType[wireId].MapT.Key elemId := dec.wireType[wireId].MapT.Elem - keyOp, keyIndir := dec.decOpFor(keyId, t.Key(), name) - elemOp, elemIndir := dec.decOpFor(elemId, t.Elem(), name) + keyOp, keyIndir := dec.decOpFor(keyId, t.Key(), name, inProgress) + elemOp, elemIndir := dec.decOpFor(elemId, t.Elem(), name, inProgress) ovfl := overflow(name) op = func(i *decInstr, state *decodeState, p unsafe.Pointer) { up := unsafe.Pointer(p) - state.dec.decodeMap(t, state, uintptr(up), keyOp, elemOp, i.indir, keyIndir, elemIndir, ovfl) + state.dec.decodeMap(t, state, uintptr(up), *keyOp, *elemOp, i.indir, keyIndir, elemIndir, ovfl) } case *reflect.SliceType: @@ -717,10 +752,10 @@ func (dec *Decoder) decOpFor(wireId typeId, rt reflect.Type, name string) (decOp } else { elemId = dec.wireType[wireId].SliceT.Elem } - elemOp, elemIndir := dec.decOpFor(elemId, t.Elem(), name) + elemOp, elemIndir := dec.decOpFor(elemId, t.Elem(), name, inProgress) ovfl := overflow(name) op = func(i *decInstr, state *decodeState, p unsafe.Pointer) { - state.dec.decodeSlice(t, state, uintptr(p), elemOp, t.Elem().Size(), i.indir, elemIndir, ovfl) + state.dec.decodeSlice(t, state, uintptr(p), *elemOp, t.Elem().Size(), i.indir, elemIndir, ovfl) } case *reflect.StructType: @@ -730,8 +765,8 @@ func (dec *Decoder) decOpFor(wireId typeId, rt reflect.Type, name string) (decOp error(err) } op = func(i *decInstr, state *decodeState, p unsafe.Pointer) { - // indirect through enginePtr to delay evaluation for recursive structs - err = dec.decodeStruct(*enginePtr, t, state.b, uintptr(p), i.indir) + // indirect through enginePtr to delay evaluation for recursive structs. + err = dec.decodeStruct(*enginePtr, userType(typ), uintptr(p), i.indir) if err != nil { error(err) } @@ -745,7 +780,7 @@ func (dec *Decoder) decOpFor(wireId typeId, rt reflect.Type, name string) (decOp if op == nil { errorf("gob: decode can't handle type %s", rt.String()) } - return op, indir + return &op, indir } // Return the decoding op for a field that has no destination. @@ -796,7 +831,7 @@ func (dec *Decoder) decIgnoreOpFor(wireId typeId) decOp { } op = func(i *decInstr, state *decodeState, p unsafe.Pointer) { // indirect through enginePtr to delay evaluation for recursive structs - state.dec.ignoreStruct(*enginePtr, state.b) + state.dec.ignoreStruct(*enginePtr) } } } @@ -809,11 +844,15 @@ func (dec *Decoder) decIgnoreOpFor(wireId typeId) decOp { // Are these two gob Types compatible? // Answers the question for basic types, arrays, and slices. // Structs are considered ok; fields will be checked later. -func (dec *Decoder) compatibleType(fr reflect.Type, fw typeId) bool { - fr, _ = indirect(fr) +func (dec *Decoder) compatibleType(fr reflect.Type, fw typeId, inProgress map[reflect.Type]typeId) bool { + if rhs, ok := inProgress[fr]; ok { + return rhs == fw + } + inProgress[fr] = fw + fr = userType(fr).base switch t := fr.(type) { default: - // map, chan, etc: cannot handle. + // chan, etc: cannot handle. return false case *reflect.BoolType: return fw == tBool @@ -835,14 +874,14 @@ func (dec *Decoder) compatibleType(fr reflect.Type, fw typeId) bool { return false } array := wire.ArrayT - return t.Len() == array.Len && dec.compatibleType(t.Elem(), array.Elem) + return t.Len() == array.Len && dec.compatibleType(t.Elem(), array.Elem, inProgress) case *reflect.MapType: wire, ok := dec.wireType[fw] if !ok || wire.MapT == nil { return false } MapType := wire.MapT - return dec.compatibleType(t.Key(), MapType.Key) && dec.compatibleType(t.Elem(), MapType.Elem) + return dec.compatibleType(t.Key(), MapType.Key, inProgress) && dec.compatibleType(t.Elem(), MapType.Elem, inProgress) case *reflect.SliceType: // Is it an array of bytes? if t.Elem().Kind() == reflect.Uint8 { @@ -855,8 +894,8 @@ func (dec *Decoder) compatibleType(fr reflect.Type, fw typeId) bool { } else { sw = dec.wireType[fw].SliceT } - elem, _ := indirect(t.Elem()) - return sw != nil && dec.compatibleType(elem, sw.Elem) + elem := userType(t.Elem()).base + return sw != nil && dec.compatibleType(elem, sw.Elem, inProgress) case *reflect.StructType: return true } @@ -877,12 +916,22 @@ func (dec *Decoder) compileSingle(remoteId typeId, rt reflect.Type) (engine *dec engine = new(decEngine) engine.instr = make([]decInstr, 1) // one item name := rt.String() // best we can do - if !dec.compatibleType(rt, remoteId) { + if !dec.compatibleType(rt, remoteId, make(map[reflect.Type]typeId)) { return nil, os.ErrorString("gob: wrong type received for local value " + name + ": " + dec.typeString(remoteId)) } - op, indir := dec.decOpFor(remoteId, rt, name) + op, indir := dec.decOpFor(remoteId, rt, name, make(map[reflect.Type]*decOp)) ovfl := os.ErrorString(`value for "` + name + `" out of range`) - engine.instr[singletonField] = decInstr{op, singletonField, indir, 0, ovfl} + engine.instr[singletonField] = decInstr{*op, singletonField, indir, 0, ovfl} + engine.numInstr = 1 + return +} + +func (dec *Decoder) compileIgnoreSingle(remoteId typeId) (engine *decEngine, err os.Error) { + engine = new(decEngine) + engine.instr = make([]decInstr, 1) // one item + op := dec.decIgnoreOpFor(remoteId) + ovfl := overflow(dec.typeString(remoteId)) + engine.instr[0] = decInstr{op, 0, 0, 0, ovfl} engine.numInstr = 1 return } @@ -894,7 +943,6 @@ func isExported(name string) bool { } func (dec *Decoder) compileDec(remoteId typeId, rt reflect.Type) (engine *decEngine, err os.Error) { - defer catchError(&err) srt, ok := rt.(*reflect.StructType) if !ok { return dec.compileSingle(remoteId, rt) @@ -905,13 +953,18 @@ func (dec *Decoder) compileDec(remoteId typeId, rt reflect.Type) (engine *decEng if t, ok := builtinIdToType[remoteId]; ok { wireStruct, _ = t.(*structType) } else { - wireStruct = dec.wireType[remoteId].StructT + wire := dec.wireType[remoteId] + if wire == nil { + error(errBadType) + } + wireStruct = wire.StructT } if wireStruct == nil { errorf("gob: type mismatch in decoder: want struct type %s; got non-struct", rt.String()) } engine = new(decEngine) engine.instr = make([]decInstr, len(wireStruct.Field)) + seen := make(map[reflect.Type]*decOp) // Loop over the fields of the wire type. for fieldnum := 0; fieldnum < len(wireStruct.Field); fieldnum++ { wireField := wireStruct.Field[fieldnum] @@ -927,11 +980,11 @@ func (dec *Decoder) compileDec(remoteId typeId, rt reflect.Type) (engine *decEng engine.instr[fieldnum] = decInstr{op, fieldnum, 0, 0, ovfl} continue } - if !dec.compatibleType(localField.Type, wireField.Id) { + if !dec.compatibleType(localField.Type, wireField.Id, make(map[reflect.Type]typeId)) { errorf("gob: wrong type (%s) for received field %s.%s", localField.Type, wireStruct.Name, wireField.Name) } - op, indir := dec.decOpFor(wireField.Id, localField.Type, localField.Name) - engine.instr[fieldnum] = decInstr{op, fieldnum, indir, uintptr(localField.Offset), ovfl} + op, indir := dec.decOpFor(wireField.Id, localField.Type, localField.Name, seen) + engine.instr[fieldnum] = decInstr{*op, fieldnum, indir, uintptr(localField.Offset), ovfl} engine.numInstr++ } return @@ -966,7 +1019,12 @@ func (dec *Decoder) getIgnoreEnginePtr(wireId typeId) (enginePtr **decEngine, er // To handle recursive types, mark this engine as underway before compiling. enginePtr = new(*decEngine) dec.ignorerCache[wireId] = enginePtr - *enginePtr, err = dec.compileDec(wireId, emptyStructType) + wire := dec.wireType[wireId] + if wire != nil && wire.StructT != nil { + *enginePtr, err = dec.compileDec(wireId, emptyStructType) + } else { + *enginePtr, err = dec.compileIgnoreSingle(wireId) + } if err != nil { dec.ignorerCache[wireId] = nil, false } @@ -974,22 +1032,41 @@ func (dec *Decoder) getIgnoreEnginePtr(wireId typeId) (enginePtr **decEngine, er return } -func (dec *Decoder) decode(wireId typeId, val reflect.Value) os.Error { +func (dec *Decoder) decodeValue(wireId typeId, val reflect.Value) (err os.Error) { + defer catchError(&err) + // If the value is nil, it means we should just ignore this item. + if val == nil { + return dec.decodeIgnoredValue(wireId) + } // Dereference down to the underlying struct type. - rt, indir := indirect(val.Type()) - enginePtr, err := dec.getDecEnginePtr(wireId, rt) + ut := userType(val.Type()) + base := ut.base + indir := ut.indir + enginePtr, err := dec.getDecEnginePtr(wireId, base) if err != nil { return err } engine := *enginePtr - if st, ok := rt.(*reflect.StructType); ok { + if st, ok := base.(*reflect.StructType); ok { if engine.numInstr == 0 && st.NumField() > 0 && len(dec.wireType[wireId].StructT.Field) > 0 { - name := rt.Name() + name := base.Name() return os.ErrorString("gob: type mismatch: no fields matched compiling decoder for " + name) } - return dec.decodeStruct(engine, st, dec.state.b, uintptr(val.Addr()), indir) + return dec.decodeStruct(engine, ut, uintptr(val.UnsafeAddr()), indir) + } + return dec.decodeSingle(engine, ut, uintptr(val.UnsafeAddr())) +} + +func (dec *Decoder) decodeIgnoredValue(wireId typeId) os.Error { + enginePtr, err := dec.getIgnoreEnginePtr(wireId) + if err != nil { + return err + } + wire := dec.wireType[wireId] + if wire != nil && wire.StructT != nil { + return dec.ignoreStruct(*enginePtr) } - return dec.decodeSingle(engine, rt, dec.state.b, uintptr(val.Addr()), indir) + return dec.ignoreSingle(*enginePtr) } func init() { @@ -1004,8 +1081,8 @@ func init() { default: panic("gob: unknown size of int/uint") } - decOpMap[reflect.Int] = iop - decOpMap[reflect.Uint] = uop + decOpTable[reflect.Int] = iop + decOpTable[reflect.Uint] = uop // Finally uintptr switch reflect.Typeof(uintptr(0)).Bits() { @@ -1016,5 +1093,5 @@ func init() { default: panic("gob: unknown size of uintptr") } - decOpMap[reflect.Uintptr] = uop + decOpTable[reflect.Uintptr] = uop } diff --git a/libgo/go/gob/decoder.go b/libgo/go/gob/decoder.go index 664001a4b21..f7c994ffa78 100644 --- a/libgo/go/gob/decoder.go +++ b/libgo/go/gob/decoder.go @@ -17,14 +17,13 @@ import ( type Decoder struct { mutex sync.Mutex // each item must be received atomically r io.Reader // source of the data + buf bytes.Buffer // buffer for more efficient i/o from r wireType map[typeId]*wireType // map from remote ID to local description decoderCache map[reflect.Type]map[typeId]**decEngine // cache of compiled engines ignorerCache map[typeId]**decEngine // ditto for ignored objects - state *decodeState // reads data from in-memory buffer countState *decodeState // reads counts from wire - buf []byte - countBuf [9]byte // counts may be uint64s (unlikely!), require 9 bytes - byteBuffer *bytes.Buffer + countBuf []byte // used for decoding integers while parsing messages + tmp []byte // temporary storage for i/o; saves reallocating err os.Error } @@ -33,128 +32,160 @@ func NewDecoder(r io.Reader) *Decoder { dec := new(Decoder) dec.r = r dec.wireType = make(map[typeId]*wireType) - dec.state = newDecodeState(dec, &dec.byteBuffer) // buffer set in Decode() dec.decoderCache = make(map[reflect.Type]map[typeId]**decEngine) dec.ignorerCache = make(map[typeId]**decEngine) + dec.countBuf = make([]byte, 9) // counts may be uint64s (unlikely!), require 9 bytes return dec } -// recvType loads the definition of a type and reloads the Decoder's buffer. +// recvType loads the definition of a type. func (dec *Decoder) recvType(id typeId) { // Have we already seen this type? That's an error - if dec.wireType[id] != nil { + if id < firstUserId || dec.wireType[id] != nil { dec.err = os.ErrorString("gob: duplicate type received") return } // Type: wire := new(wireType) - dec.err = dec.decode(tWireType, reflect.NewValue(wire)) + dec.err = dec.decodeValue(tWireType, reflect.NewValue(wire)) if dec.err != nil { return } // Remember we've seen this type. dec.wireType[id] = wire - - // Load the next parcel. - dec.recv() } -// Decode reads the next value from the connection and stores -// it in the data represented by the empty interface value. -// The value underlying e must be the correct type for the next -// data item received, and must be a pointer. -func (dec *Decoder) Decode(e interface{}) os.Error { - value := reflect.NewValue(e) - // If e represents a value as opposed to a pointer, the answer won't - // get back to the caller. Make sure it's a pointer. - if value.Type().Kind() != reflect.Ptr { - dec.err = os.ErrorString("gob: attempt to decode into a non-pointer") - return dec.err +// recvMessage reads the next count-delimited item from the input. It is the converse +// of Encoder.writeMessage. It returns false on EOF or other error reading the message. +func (dec *Decoder) recvMessage() bool { + // Read a count. + nbytes, _, err := decodeUintReader(dec.r, dec.countBuf) + if err != nil { + dec.err = err + return false } - return dec.DecodeValue(value) + dec.readMessage(int(nbytes)) + return dec.err == nil } -// recv reads the next count-delimited item from the input. It is the converse -// of Encoder.send. -func (dec *Decoder) recv() { - // Read a count. - var nbytes uint64 - nbytes, dec.err = decodeUintReader(dec.r, dec.countBuf[0:]) - if dec.err != nil { - return - } +// readMessage reads the next nbytes bytes from the input. +func (dec *Decoder) readMessage(nbytes int) { // Allocate the buffer. - if nbytes > uint64(len(dec.buf)) { - dec.buf = make([]byte, nbytes+1000) + if cap(dec.tmp) < nbytes { + dec.tmp = make([]byte, nbytes+100) // room to grow } - dec.byteBuffer = bytes.NewBuffer(dec.buf[0:nbytes]) + dec.tmp = dec.tmp[:nbytes] // Read the data - _, dec.err = io.ReadFull(dec.r, dec.buf[0:nbytes]) + _, dec.err = io.ReadFull(dec.r, dec.tmp) if dec.err != nil { if dec.err == os.EOF { dec.err = io.ErrUnexpectedEOF } return } + dec.buf.Write(dec.tmp) } -// decodeValueFromBuffer grabs the next value from the input. The Decoder's -// buffer already contains data. If the next item in the buffer is a type -// descriptor, it may be necessary to reload the buffer, but recvType does that. -func (dec *Decoder) decodeValueFromBuffer(value reflect.Value, ignoreInterfaceValue, countPresent bool) { - for dec.state.b.Len() > 0 { - // Receive a type id. - id := typeId(dec.state.decodeInt()) +// toInt turns an encoded uint64 into an int, according to the marshaling rules. +func toInt(x uint64) int64 { + i := int64(x >> 1) + if x&1 != 0 { + i = ^i + } + return i +} + +func (dec *Decoder) nextInt() int64 { + n, _, err := decodeUintReader(&dec.buf, dec.countBuf) + if err != nil { + dec.err = err + } + return toInt(n) +} - // Is it a new type? - if id < 0 { // 0 is the error state, handled above - // If the id is negative, we have a type. - dec.recvType(-id) - if dec.err != nil { +func (dec *Decoder) nextUint() uint64 { + n, _, err := decodeUintReader(&dec.buf, dec.countBuf) + if err != nil { + dec.err = err + } + return n +} + +// decodeTypeSequence parses: +// TypeSequence +// (TypeDefinition DelimitedTypeDefinition*)? +// and returns the type id of the next value. It returns -1 at +// EOF. Upon return, the remainder of dec.buf is the value to be +// decoded. If this is an interface value, it can be ignored by +// simply resetting that buffer. +func (dec *Decoder) decodeTypeSequence(isInterface bool) typeId { + for dec.err == nil { + if dec.buf.Len() == 0 { + if !dec.recvMessage() { break } - continue } - - // Make sure the type has been defined already or is a builtin type (for - // top-level singleton values). - if dec.wireType[id] == nil && builtinIdToType[id] == nil { - dec.err = errBadType - break + // Receive a type id. + id := typeId(dec.nextInt()) + if id >= 0 { + // Value follows. + return id } - // An interface value is preceded by a byte count. - if countPresent { - count := int(dec.state.decodeUint()) - if ignoreInterfaceValue { - // An interface value is preceded by a byte count. Just skip that many bytes. - dec.state.b.Next(int(count)) + // Type definition for (-id) follows. + dec.recvType(-id) + // When decoding an interface, after a type there may be a + // DelimitedValue still in the buffer. Skip its count. + // (Alternatively, the buffer is empty and the byte count + // will be absorbed by recvMessage.) + if dec.buf.Len() > 0 { + if !isInterface { + dec.err = os.ErrorString("extra data in buffer") break } - // Otherwise fall through and decode it. + dec.nextUint() } - dec.err = dec.decode(id, value) - break } + return -1 +} + +// Decode reads the next value from the connection and stores +// it in the data represented by the empty interface value. +// If e is nil, the value will be discarded. Otherwise, +// the value underlying e must either be the correct type for the next +// data item received, and must be a pointer. +func (dec *Decoder) Decode(e interface{}) os.Error { + if e == nil { + return dec.DecodeValue(nil) + } + value := reflect.NewValue(e) + // If e represents a value as opposed to a pointer, the answer won't + // get back to the caller. Make sure it's a pointer. + if value.Type().Kind() != reflect.Ptr { + dec.err = os.ErrorString("gob: attempt to decode into a non-pointer") + return dec.err + } + return dec.DecodeValue(value) } // DecodeValue reads the next value from the connection and stores // it in the data represented by the reflection value. // The value must be the correct type for the next -// data item received. +// data item received, or it may be nil, which means the +// value will be discarded. func (dec *Decoder) DecodeValue(value reflect.Value) os.Error { // Make sure we're single-threaded through here. dec.mutex.Lock() defer dec.mutex.Unlock() + dec.buf.Reset() // In case data lingers from previous invocation. dec.err = nil - dec.recv() - if dec.err != nil { - return dec.err + id := dec.decodeTypeSequence(false) + if dec.err == nil { + dec.err = dec.decodeValue(id, value) } - dec.decodeValueFromBuffer(value, false, false) return dec.err } diff --git a/libgo/go/gob/doc.go b/libgo/go/gob/doc.go index 31253f16d09..613974a000f 100644 --- a/libgo/go/gob/doc.go +++ b/libgo/go/gob/doc.go @@ -220,6 +220,54 @@ be predefined or be defined before the value in the stream. package gob /* +Grammar: + +Tokens starting with a lower case letter are terminals; int(n) +and uint(n) represent the signed/unsigned encodings of the value n. + +GobStream: + DelimitedMessage* +DelimitedMessage: + uint(lengthOfMessage) Message +Message: + TypeSequence TypedValue +TypeSequence + (TypeDefinition DelimitedTypeDefinition*)? +DelimitedTypeDefinition: + uint(lengthOfTypeDefinition) TypeDefinition +TypedValue: + int(typeId) Value +TypeDefinition: + int(-typeId) encodingOfWireType +Value: + SingletonValue | StructValue +SingletonValue: + uint(0) FieldValue +FieldValue: + builtinValue | ArrayValue | MapValue | SliceValue | StructValue | InterfaceValue +InterfaceValue: + NilInterfaceValue | NonNilInterfaceValue +NilInterfaceValue: + uint(0) +NonNilInterfaceValue: + ConcreteTypeName TypeSequence InterfaceContents +ConcreteTypeName: + uint(lengthOfName) [already read=n] name +InterfaceContents: + int(concreteTypeId) DelimitedValue +DelimitedValue: + uint(length) Value +ArrayValue: + uint(n) FieldValue*n [n elements] +MapValue: + uint(n) (FieldValue FieldValue)*n [n (key, value) pairs] +SliceValue: + uint(n) FieldValue*n [n elements] +StructValue: + (uint(fieldDelta) FieldValue)* +*/ + +/* For implementers and the curious, here is an encoded example. Given type Point struct {x, y int} and the value diff --git a/libgo/go/gob/encode.go b/libgo/go/gob/encode.go index d286a7e00b8..e92db74ffdd 100644 --- a/libgo/go/gob/encode.go +++ b/libgo/go/gob/encode.go @@ -264,9 +264,6 @@ func encComplex128(i *encInstr, state *encoderState, p unsafe.Pointer) { } } -func encNoOp(i *encInstr, state *encoderState, p unsafe.Pointer) { -} - // Byte arrays are encoded as an unsigned count followed by the raw bytes. func encUint8Array(i *encInstr, state *encoderState, p unsafe.Pointer) { b := *(*[]byte)(p) @@ -359,7 +356,7 @@ func encodeReflectValue(state *encoderState, v reflect.Value, op encOp, indir in if v == nil { errorf("gob: encodeReflectValue: nil element") } - op(nil, state, unsafe.Pointer(v.Addr())) + op(nil, state, unsafe.Pointer(v.UnsafeAddr())) } func (enc *Encoder) encodeMap(b *bytes.Buffer, mv *reflect.MapValue, keyOp, elemOp encOp, keyIndir, elemIndir int) { @@ -387,10 +384,10 @@ func (enc *Encoder) encodeInterface(b *bytes.Buffer, iv *reflect.InterfaceValue) return } - typ, _ := indirect(iv.Elem().Type()) - name, ok := concreteTypeToName[typ] + ut := userType(iv.Elem().Type()) + name, ok := concreteTypeToName[ut.base] if !ok { - errorf("gob: type not registered for interface: %s", typ) + errorf("gob: type not registered for interface: %s", ut.base) } // Send the name. state.encodeUint(uint64(len(name))) @@ -398,22 +395,26 @@ func (enc *Encoder) encodeInterface(b *bytes.Buffer, iv *reflect.InterfaceValue) if err != nil { error(err) } - // Send (and maybe first define) the type id. - enc.sendTypeDescriptor(typ) - // Encode the value into a new buffer. + // Define the type id if necessary. + enc.sendTypeDescriptor(enc.writer(), state, ut) + // Send the type id. + enc.sendTypeId(state, ut) + // Encode the value into a new buffer. Any nested type definitions + // should be written to b, before the encoded value. + enc.pushWriter(b) data := new(bytes.Buffer) - err = enc.encode(data, iv.Elem()) + err = enc.encode(data, iv.Elem(), ut) if err != nil { error(err) } - state.encodeUint(uint64(data.Len())) - _, err = state.b.Write(data.Bytes()) - if err != nil { + enc.popWriter() + enc.writeMessage(b, data) + if enc.err != nil { error(err) } } -var encOpMap = []encOp{ +var encOpTable = [...]encOp{ reflect.Bool: encBool, reflect.Int: encInt, reflect.Int8: encInt8, @@ -433,16 +434,24 @@ var encOpMap = []encOp{ reflect.String: encString, } -// Return the encoding op for the base type under rt and +// Return (a pointer to) the encoding op for the base type under rt and // the indirection count to reach it. -func (enc *Encoder) encOpFor(rt reflect.Type) (encOp, int) { - typ, indir := indirect(rt) - var op encOp +func (enc *Encoder) encOpFor(rt reflect.Type, inProgress map[reflect.Type]*encOp) (*encOp, int) { + ut := userType(rt) + // If this type is already in progress, it's a recursive type (e.g. map[string]*T). + // Return the pointer to the op we're already building. + if opPtr := inProgress[rt]; opPtr != nil { + return opPtr, ut.indir + } + typ := ut.base + indir := ut.indir k := typ.Kind() - if int(k) < len(encOpMap) { - op = encOpMap[k] + var op encOp + if int(k) < len(encOpTable) { + op = encOpTable[k] } if op == nil { + inProgress[rt] = &op // Special cases switch t := typ.(type) { case *reflect.SliceType: @@ -451,25 +460,25 @@ func (enc *Encoder) encOpFor(rt reflect.Type) (encOp, int) { break } // Slices have a header; we decode it to find the underlying array. - elemOp, indir := enc.encOpFor(t.Elem()) + elemOp, indir := enc.encOpFor(t.Elem(), inProgress) op = func(i *encInstr, state *encoderState, p unsafe.Pointer) { slice := (*reflect.SliceHeader)(p) if !state.sendZero && slice.Len == 0 { return } state.update(i) - state.enc.encodeArray(state.b, slice.Data, elemOp, t.Elem().Size(), indir, int(slice.Len)) + state.enc.encodeArray(state.b, slice.Data, *elemOp, t.Elem().Size(), indir, int(slice.Len)) } case *reflect.ArrayType: // True arrays have size in the type. - elemOp, indir := enc.encOpFor(t.Elem()) + elemOp, indir := enc.encOpFor(t.Elem(), inProgress) op = func(i *encInstr, state *encoderState, p unsafe.Pointer) { state.update(i) - state.enc.encodeArray(state.b, uintptr(p), elemOp, t.Elem().Size(), indir, t.Len()) + state.enc.encodeArray(state.b, uintptr(p), *elemOp, t.Elem().Size(), indir, t.Len()) } case *reflect.MapType: - keyOp, keyIndir := enc.encOpFor(t.Key()) - elemOp, elemIndir := enc.encOpFor(t.Elem()) + keyOp, keyIndir := enc.encOpFor(t.Key(), inProgress) + elemOp, elemIndir := enc.encOpFor(t.Elem(), inProgress) op = func(i *encInstr, state *encoderState, p unsafe.Pointer) { // Maps cannot be accessed by moving addresses around the way // that slices etc. can. We must recover a full reflection value for @@ -480,7 +489,7 @@ func (enc *Encoder) encOpFor(rt reflect.Type) (encOp, int) { return } state.update(i) - state.enc.encodeMap(state.b, mv, keyOp, elemOp, keyIndir, elemIndir) + state.enc.encodeMap(state.b, mv, *keyOp, *elemOp, keyIndir, elemIndir) } case *reflect.StructType: // Generate a closure that calls out to the engine for the nested type. @@ -508,28 +517,31 @@ func (enc *Encoder) encOpFor(rt reflect.Type) (encOp, int) { if op == nil { errorf("gob enc: can't happen: encode type %s", rt.String()) } - return op, indir + return &op, indir } // The local Type was compiled from the actual value, so we know it's compatible. func (enc *Encoder) compileEnc(rt reflect.Type) *encEngine { srt, isStruct := rt.(*reflect.StructType) engine := new(encEngine) + seen := make(map[reflect.Type]*encOp) if isStruct { - engine.instr = make([]encInstr, srt.NumField()+1) // +1 for terminator - for fieldnum := 0; fieldnum < srt.NumField(); fieldnum++ { - f := srt.Field(fieldnum) - op, indir := enc.encOpFor(f.Type) + for fieldNum := 0; fieldNum < srt.NumField(); fieldNum++ { + f := srt.Field(fieldNum) if !isExported(f.Name) { - op = encNoOp + continue } - engine.instr[fieldnum] = encInstr{op, fieldnum, indir, uintptr(f.Offset)} + op, indir := enc.encOpFor(f.Type, seen) + engine.instr = append(engine.instr, encInstr{*op, fieldNum, indir, uintptr(f.Offset)}) + } + if srt.NumField() > 0 && len(engine.instr) == 0 { + errorf("type %s has no exported fields", rt) } - engine.instr[srt.NumField()] = encInstr{encStructTerminator, 0, 0, 0} + engine.instr = append(engine.instr, encInstr{encStructTerminator, 0, 0, 0}) } else { engine.instr = make([]encInstr, 1) - op, indir := enc.encOpFor(rt) - engine.instr[0] = encInstr{op, singletonField, indir, 0} // offset is zero + op, indir := enc.encOpFor(rt, seen) + engine.instr[0] = encInstr{*op, singletonField, indir, 0} // offset is zero } return engine } @@ -556,18 +568,16 @@ func (enc *Encoder) lockAndGetEncEngine(rt reflect.Type) *encEngine { return enc.getEncEngine(rt) } -func (enc *Encoder) encode(b *bytes.Buffer, value reflect.Value) (err os.Error) { +func (enc *Encoder) encode(b *bytes.Buffer, value reflect.Value, ut *userTypeInfo) (err os.Error) { defer catchError(&err) - // Dereference down to the underlying object. - rt, indir := indirect(value.Type()) - for i := 0; i < indir; i++ { + for i := 0; i < ut.indir; i++ { value = reflect.Indirect(value) } - engine := enc.lockAndGetEncEngine(rt) + engine := enc.lockAndGetEncEngine(ut.base) if value.Type().Kind() == reflect.Struct { - enc.encodeStruct(b, engine, value.Addr()) + enc.encodeStruct(b, engine, value.UnsafeAddr()) } else { - enc.encodeSingle(b, engine, value.Addr()) + enc.encodeSingle(b, engine, value.UnsafeAddr()) } return nil } diff --git a/libgo/go/gob/encoder.go b/libgo/go/gob/encoder.go index 8869b262982..92d036c11c3 100644 --- a/libgo/go/gob/encoder.go +++ b/libgo/go/gob/encoder.go @@ -16,9 +16,8 @@ import ( // other side of a connection. type Encoder struct { mutex sync.Mutex // each item must be sent atomically - w io.Writer // where to send the data + w []io.Writer // where to send the data sent map[reflect.Type]typeId // which types we've already sent - state *encoderState // so we can encode integers, strings directly countState *encoderState // stage for writing counts buf []byte // for collecting the output. err os.Error @@ -27,13 +26,27 @@ type Encoder struct { // NewEncoder returns a new encoder that will transmit on the io.Writer. func NewEncoder(w io.Writer) *Encoder { enc := new(Encoder) - enc.w = w + enc.w = []io.Writer{w} enc.sent = make(map[reflect.Type]typeId) - enc.state = newEncoderState(enc, new(bytes.Buffer)) enc.countState = newEncoderState(enc, new(bytes.Buffer)) return enc } +// writer() returns the innermost writer the encoder is using +func (enc *Encoder) writer() io.Writer { + return enc.w[len(enc.w)-1] +} + +// pushWriter adds a writer to the encoder. +func (enc *Encoder) pushWriter(w io.Writer) { + enc.w = append(enc.w, w) +} + +// popWriter pops the innermost writer. +func (enc *Encoder) popWriter() { + enc.w = enc.w[0 : len(enc.w)-1] +} + func (enc *Encoder) badType(rt reflect.Type) { enc.setError(os.ErrorString("gob: can't encode type " + rt.String())) } @@ -42,16 +55,14 @@ func (enc *Encoder) setError(err os.Error) { if enc.err == nil { // remember the first. enc.err = err } - enc.state.b.Reset() } -// Send the data item preceded by a unsigned count of its length. -func (enc *Encoder) send() { - // Encode the length. - enc.countState.encodeUint(uint64(enc.state.b.Len())) +// writeMessage sends the data item preceded by a unsigned count of its length. +func (enc *Encoder) writeMessage(w io.Writer, b *bytes.Buffer) { + enc.countState.encodeUint(uint64(b.Len())) // Build the buffer. countLen := enc.countState.b.Len() - total := countLen + enc.state.b.Len() + total := countLen + b.Len() if total > len(enc.buf) { enc.buf = make([]byte, total+1000) // extra for growth } @@ -59,17 +70,18 @@ func (enc *Encoder) send() { // TODO(r): avoid the extra copy here. enc.countState.b.Read(enc.buf[0:countLen]) // Now the data. - enc.state.b.Read(enc.buf[countLen:total]) + b.Read(enc.buf[countLen:total]) // Write the data. - _, err := enc.w.Write(enc.buf[0:total]) + _, err := w.Write(enc.buf[0:total]) if err != nil { enc.setError(err) } } -func (enc *Encoder) sendType(origt reflect.Type) (sent bool) { +func (enc *Encoder) sendType(w io.Writer, state *encoderState, origt reflect.Type) (sent bool) { // Drill down to the base type. - rt, _ := indirect(origt) + ut := userType(origt) + rt := ut.base switch rt := rt.(type) { default: @@ -112,10 +124,10 @@ func (enc *Encoder) sendType(origt reflect.Type) (sent bool) { } // Send the pair (-id, type) // Id: - enc.state.encodeInt(-int64(info.id)) + state.encodeInt(-int64(info.id)) // Type: - enc.encode(enc.state.b, reflect.NewValue(info.wire)) - enc.send() + enc.encode(state.b, reflect.NewValue(info.wire), wireTypeUserInfo) + enc.writeMessage(w, state.b) if enc.err != nil { return } @@ -128,10 +140,10 @@ func (enc *Encoder) sendType(origt reflect.Type) (sent bool) { switch st := rt.(type) { case *reflect.StructType: for i := 0; i < st.NumField(); i++ { - enc.sendType(st.Field(i).Type) + enc.sendType(w, state, st.Field(i).Type) } case reflect.ArrayOrSliceType: - enc.sendType(st.Elem()) + enc.sendType(w, state, st.Elem()) } return true } @@ -142,15 +154,16 @@ func (enc *Encoder) Encode(e interface{}) os.Error { return enc.EncodeValue(reflect.NewValue(e)) } -// sendTypeId makes sure the remote side knows about this type. +// sendTypeDescriptor makes sure the remote side knows about this type. // It will send a descriptor if this is the first time the type has been -// sent. Regardless, it sends the id. -func (enc *Encoder) sendTypeDescriptor(rt reflect.Type) { +// sent. +func (enc *Encoder) sendTypeDescriptor(w io.Writer, state *encoderState, ut *userTypeInfo) { // Make sure the type is known to the other side. - // First, have we already sent this type? - if _, alreadySent := enc.sent[rt]; !alreadySent { + // First, have we already sent this (base) type? + base := ut.base + if _, alreadySent := enc.sent[base]; !alreadySent { // No, so send it. - sent := enc.sendType(rt) + sent := enc.sendType(w, state, base) if enc.err != nil { return } @@ -159,18 +172,21 @@ func (enc *Encoder) sendTypeDescriptor(rt reflect.Type) { // need to send the type info but we do need to update enc.sent. if !sent { typeLock.Lock() - info, err := getTypeInfo(rt) + info, err := getTypeInfo(base) typeLock.Unlock() if err != nil { enc.setError(err) return } - enc.sent[rt] = info.id + enc.sent[base] = info.id } } +} +// sendTypeId sends the id, which must have already been defined. +func (enc *Encoder) sendTypeId(state *encoderState, ut *userTypeInfo) { // Identify the type of this top-level value. - enc.state.encodeInt(int64(enc.sent[rt])) + state.encodeInt(int64(enc.sent[ut.base])) } // EncodeValue transmits the data item represented by the reflection value, @@ -181,26 +197,29 @@ func (enc *Encoder) EncodeValue(value reflect.Value) os.Error { enc.mutex.Lock() defer enc.mutex.Unlock() - enc.err = nil - rt, _ := indirect(value.Type()) + // Remove any nested writers remaining due to previous errors. + enc.w = enc.w[0:1] - // Sanity check only: encoder should never come in with data present. - if enc.state.b.Len() > 0 || enc.countState.b.Len() > 0 { - enc.err = os.ErrorString("encoder: buffer not empty") - return enc.err + ut, err := validUserType(value.Type()) + if err != nil { + return err } - enc.sendTypeDescriptor(rt) + enc.err = nil + state := newEncoderState(enc, new(bytes.Buffer)) + + enc.sendTypeDescriptor(enc.writer(), state, ut) + enc.sendTypeId(state, ut) if enc.err != nil { return enc.err } // Encode the object. - err := enc.encode(enc.state.b, value) + err = enc.encode(state.b, value, ut) if err != nil { enc.setError(err) } else { - enc.send() + enc.writeMessage(enc.writer(), state.b) } return enc.err diff --git a/libgo/go/gob/encoder_test.go b/libgo/go/gob/encoder_test.go index c2309352a09..a0c713b81df 100644 --- a/libgo/go/gob/encoder_test.go +++ b/libgo/go/gob/encoder_test.go @@ -6,6 +6,7 @@ package gob import ( "bytes" + "fmt" "io" "os" "reflect" @@ -120,7 +121,7 @@ func corruptDataCheck(s string, err os.Error, t *testing.T) { dec := NewDecoder(b) err1 := dec.Decode(new(ET2)) if err1 != err { - t.Error("expected error", err, "got", err1) + t.Errorf("from %q expected error %s; got %s", s, err, err1) } } @@ -220,7 +221,7 @@ func TestSlice(t *testing.T) { func TestValueError(t *testing.T) { // Encode a *T, decode a T type Type4 struct { - a int + A int } t4p := &Type4{3} var t4 Type4 // note: not a pointer. @@ -248,6 +249,24 @@ func TestArray(t *testing.T) { } } +func TestRecursiveMapType(t *testing.T) { + type recursiveMap map[string]recursiveMap + r1 := recursiveMap{"A": recursiveMap{"B": nil, "C": nil}, "D": nil} + r2 := make(recursiveMap) + if err := encAndDec(r1, &r2); err != nil { + t.Error(err) + } +} + +func TestRecursiveSliceType(t *testing.T) { + type recursiveSlice []recursiveSlice + r1 := recursiveSlice{0: recursiveSlice{0: nil}, 1: nil} + r2 := make(recursiveSlice, 0) + if err := encAndDec(r1, &r2); err != nil { + t.Error(err) + } +} + // Regression test for bug: must send zero values inside arrays func TestDefaultsInArray(t *testing.T) { type Type7 struct { @@ -383,3 +402,115 @@ func TestInterfaceIndirect(t *testing.T) { t.Fatal("decode error:", err) } } + +// Now follow various tests that decode into things that can't represent the +// encoded value, all of which should be legal. + +// Also, when the ignored object contains an interface value, it may define +// types. Make sure that skipping the value still defines the types by using +// the encoder/decoder pair to send a value afterwards. If an interface +// is sent, its type in the test is always NewType0, so this checks that the +// encoder and decoder don't skew with respect to type definitions. + +type Struct0 struct { + I interface{} +} + +type NewType0 struct { + S string +} + +type ignoreTest struct { + in, out interface{} +} + +var ignoreTests = []ignoreTest{ + // Decode normal struct into an empty struct + {&struct{ A int }{23}, &struct{}{}}, + // Decode normal struct into a nil. + {&struct{ A int }{23}, nil}, + // Decode singleton string into a nil. + {"hello, world", nil}, + // Decode singleton slice into a nil. + {[]int{1, 2, 3, 4}, nil}, + // Decode struct containing an interface into a nil. + {&Struct0{&NewType0{"value0"}}, nil}, + // Decode singleton slice of interfaces into a nil. + {[]interface{}{"hi", &NewType0{"value1"}, 23}, nil}, +} + +func TestDecodeIntoNothing(t *testing.T) { + Register(new(NewType0)) + for i, test := range ignoreTests { + b := new(bytes.Buffer) + enc := NewEncoder(b) + err := enc.Encode(test.in) + if err != nil { + t.Errorf("%d: encode error %s:", i, err) + continue + } + dec := NewDecoder(b) + err = dec.Decode(test.out) + if err != nil { + t.Errorf("%d: decode error: %s", i, err) + continue + } + // Now see if the encoder and decoder are in a consistent state. + str := fmt.Sprintf("Value %d", i) + err = enc.Encode(&NewType0{str}) + if err != nil { + t.Fatalf("%d: NewType0 encode error: %s", i, err) + } + ns := new(NewType0) + err = dec.Decode(ns) + if err != nil { + t.Fatalf("%d: NewType0 decode error: %s", i, err) + } + if ns.S != str { + t.Fatalf("%d: expected %q got %q", i, str, ns.S) + } + } +} + +// Another bug from golang-nuts, involving nested interfaces. +type Bug0Outer struct { + Bug0Field interface{} +} + +type Bug0Inner struct { + A int +} + +func TestNestedInterfaces(t *testing.T) { + var buf bytes.Buffer + e := NewEncoder(&buf) + d := NewDecoder(&buf) + Register(new(Bug0Outer)) + Register(new(Bug0Inner)) + f := &Bug0Outer{&Bug0Outer{&Bug0Inner{7}}} + var v interface{} = f + err := e.Encode(&v) + if err != nil { + t.Fatal("Encode:", err) + } + err = d.Decode(&v) + if err != nil { + t.Fatal("Decode:", err) + } + // Make sure it decoded correctly. + outer1, ok := v.(*Bug0Outer) + if !ok { + t.Fatalf("v not Bug0Outer: %T", v) + } + outer2, ok := outer1.Bug0Field.(*Bug0Outer) + if !ok { + t.Fatalf("v.Bug0Field not Bug0Outer: %T", outer1.Bug0Field) + } + inner, ok := outer2.Bug0Field.(*Bug0Inner) + if !ok { + t.Fatalf("v.Bug0Field.Bug0Field not Bug0Inner: %T", outer2.Bug0Field) + } + if inner.A != 7 { + t.Fatalf("final value %d; expected %d", inner.A, 7) + } +} diff --git a/libgo/go/gob/type.go b/libgo/go/gob/type.go index 22502a6e6b9..6e3f148b4e7 100644 --- a/libgo/go/gob/type.go +++ b/libgo/go/gob/type.go @@ -11,13 +11,76 @@ import ( "sync" ) -// Reflection types are themselves interface values holding structs -// describing the type. Each type has a different struct so that struct can -// be the kind. For example, if typ is the reflect type for an int8, typ is -// a pointer to a reflect.Int8Type struct; if typ is the reflect type for a -// function, typ is a pointer to a reflect.FuncType struct; we use the type -// of that pointer as the kind. +// userTypeInfo stores the information associated with a type the user has handed +// to the package. It's computed once and stored in a map keyed by reflection +// type. +type userTypeInfo struct { + user reflect.Type // the type the user handed us + base reflect.Type // the base type after all indirections + indir int // number of indirections to reach the base type +} + +var ( + // Protected by an RWMutex because we read it a lot and write + // it only when we see a new type, typically when compiling. + userTypeLock sync.RWMutex + userTypeCache = make(map[reflect.Type]*userTypeInfo) +) + +// validType returns, and saves, the information associated with user-provided type rt. +// If the user type is not valid, err will be non-nil. To be used when the error handler +// is not set up. +func validUserType(rt reflect.Type) (ut *userTypeInfo, err os.Error) { + userTypeLock.RLock() + ut = userTypeCache[rt] + userTypeLock.RUnlock() + if ut != nil { + return + } + // Now set the value under the write lock. + userTypeLock.Lock() + defer userTypeLock.Unlock() + if ut = userTypeCache[rt]; ut != nil { + // Lost the race; not a problem. + return + } + ut = new(userTypeInfo) + ut.base = rt + ut.user = rt + // A type that is just a cycle of pointers (such as type T *T) cannot + // be represented in gobs, which need some concrete data. We use a + // cycle detection algorithm from Knuth, Vol 2, Section 3.1, Ex 6, + // pp 539-540. As we step through indirections, run another type at + // half speed. If they meet up, there's a cycle. + slowpoke := ut.base // walks half as fast as ut.base + for { + pt, ok := ut.base.(*reflect.PtrType) + if !ok { + break + } + ut.base = pt.Elem() + if ut.base == slowpoke { // ut.base lapped slowpoke + // recursive pointer type. + return nil, os.ErrorString("can't represent recursive pointer type " + ut.base.String()) + } + if ut.indir%2 == 0 { + slowpoke = slowpoke.(*reflect.PtrType).Elem() + } + ut.indir++ + } + userTypeCache[rt] = ut + return +} +// userType returns, and saves, the information associated with user-provided type rt. +// If the user type is not valid, it calls error. +func userType(rt reflect.Type) *userTypeInfo { + ut, err := validUserType(rt) + if err != nil { + error(err) + } + return ut +} // A typeId represents a gob Type as an integer that can be passed on the wire. // Internally, typeIds are used as keys to a map to recover the underlying type info. type typeId int32 @@ -93,7 +156,7 @@ var ( tBool = bootstrapType("bool", false, 1) tInt = bootstrapType("int", int(0), 2) tUint = bootstrapType("uint", uint(0), 3) - tFloat = bootstrapType("float", 0.0, 4) + tFloat = bootstrapType("float", float64(0), 4) tBytes = bootstrapType("bytes", make([]byte, 0), 5) tString = bootstrapType("string", "", 6) tComplex = bootstrapType("complex", 0+0i, 7) @@ -110,6 +173,7 @@ var ( // Predefined because it's needed by the Decoder var tWireType = mustGetTypeInfo(reflect.Typeof(wireType{})).id +var wireTypeUserInfo *userTypeInfo // userTypeInfo of (*wireType) func init() { // Some magic numbers to make sure there are no surprises. @@ -133,6 +197,7 @@ func init() { } nextId = firstUserId registerBasics() + wireTypeUserInfo = userType(reflect.Typeof((*wireType)(nil))) } // Array type @@ -142,12 +207,18 @@ type arrayType struct { Len int } -func newArrayType(name string, elem gobType, length int) *arrayType { - a := &arrayType{CommonType{Name: name}, elem.id(), length} - setTypeId(a) +func newArrayType(name string) *arrayType { + a := &arrayType{CommonType{Name: name}, 0, 0} return a } +func (a *arrayType) init(elem gobType, len int) { + // Set our type id before evaluating the element's, in case it's our own. + setTypeId(a) + a.Elem = elem.id() + a.Len = len +} + func (a *arrayType) safeString(seen map[typeId]bool) string { if seen[a.Id] { return a.Name @@ -165,12 +236,18 @@ type mapType struct { Elem typeId } -func newMapType(name string, key, elem gobType) *mapType { - m := &mapType{CommonType{Name: name}, key.id(), elem.id()} - setTypeId(m) +func newMapType(name string) *mapType { + m := &mapType{CommonType{Name: name}, 0, 0} return m } +func (m *mapType) init(key, elem gobType) { + // Set our type id before evaluating the element's, in case it's our own. + setTypeId(m) + m.Key = key.id() + m.Elem = elem.id() +} + func (m *mapType) safeString(seen map[typeId]bool) string { if seen[m.Id] { return m.Name @@ -189,12 +266,17 @@ type sliceType struct { Elem typeId } -func newSliceType(name string, elem gobType) *sliceType { - s := &sliceType{CommonType{Name: name}, elem.id()} - setTypeId(s) +func newSliceType(name string) *sliceType { + s := &sliceType{CommonType{Name: name}, 0} return s } +func (s *sliceType) init(elem gobType) { + // Set our type id before evaluating the element's, in case it's our own. + setTypeId(s) + s.Elem = elem.id() +} + func (s *sliceType) safeString(seen map[typeId]bool) string { if seen[s.Id] { return s.Name @@ -236,26 +318,26 @@ func (s *structType) string() string { return s.safeString(make(map[typeId]bool) func newStructType(name string) *structType { s := &structType{CommonType{Name: name}, nil} + // For historical reasons we set the id here rather than init. + // Se the comment in newTypeObject for details. setTypeId(s) return s } -// Step through the indirections on a type to discover the base type. -// Return the base type and the number of indirections. -func indirect(t reflect.Type) (rt reflect.Type, count int) { - rt = t - for { - pt, ok := rt.(*reflect.PtrType) - if !ok { - break - } - rt = pt.Elem() - count++ - } - return +func (s *structType) init(field []*fieldType) { + s.Field = field } func newTypeObject(name string, rt reflect.Type) (gobType, os.Error) { + var err os.Error + var type0, type1 gobType + defer func() { + if err != nil { + types[rt] = nil, false + } + }() + // Install the top-level type before the subtypes (e.g. struct before + // fields) so recursive types can be constructed safely. switch t := rt.(type) { // All basic types are easy: they are predefined. case *reflect.BoolType: @@ -280,47 +362,62 @@ func newTypeObject(name string, rt reflect.Type) (gobType, os.Error) { return tInterface.gobType(), nil case *reflect.ArrayType: - gt, err := getType("", t.Elem()) + at := newArrayType(name) + types[rt] = at + type0, err = getType("", t.Elem()) if err != nil { return nil, err } - return newArrayType(name, gt, t.Len()), nil + // Historical aside: + // For arrays, maps, and slices, we set the type id after the elements + // are constructed. This is to retain the order of type id allocation after + // a fix made to handle recursive types, which changed the order in + // which types are built. Delaying the setting in this way preserves + // type ids while allowing recursive types to be described. Structs, + // done below, were already handling recursion correctly so they + // assign the top-level id before those of the field. + at.init(type0, t.Len()) + return at, nil case *reflect.MapType: - kt, err := getType("", t.Key()) + mt := newMapType(name) + types[rt] = mt + type0, err = getType("", t.Key()) if err != nil { return nil, err } - vt, err := getType("", t.Elem()) + type1, err = getType("", t.Elem()) if err != nil { return nil, err } - return newMapType(name, kt, vt), nil + mt.init(type0, type1) + return mt, nil case *reflect.SliceType: // []byte == []uint8 is a special case if t.Elem().Kind() == reflect.Uint8 { return tBytes.gobType(), nil } - gt, err := getType(t.Elem().Name(), t.Elem()) + st := newSliceType(name) + types[rt] = st + type0, err = getType(t.Elem().Name(), t.Elem()) if err != nil { return nil, err } - return newSliceType(name, gt), nil + st.init(type0) + return st, nil case *reflect.StructType: - // Install the struct type itself before the fields so recursive - // structures can be constructed safely. - strType := newStructType(name) - types[rt] = strType - idToType[strType.id()] = strType + st := newStructType(name) + types[rt] = st + idToType[st.id()] = st field := make([]*fieldType, t.NumField()) for i := 0; i < t.NumField(); i++ { f := t.Field(i) - typ, _ := indirect(f.Type) + typ := userType(f.Type).base tname := typ.Name() if tname == "" { - t, _ := indirect(f.Type) + t := userType(f.Type).base tname = t.String() } gt, err := getType(tname, f.Type) @@ -329,8 +426,8 @@ func newTypeObject(name string, rt reflect.Type) (gobType, os.Error) { } field[i] = &fieldType{f.Name, gt.id()} } - strType.Field = field - return strType, nil + st.init(field) + return st, nil default: return nil, os.ErrorString("gob NewTypeObject can't handle type: " + rt.String()) @@ -341,7 +438,7 @@ func newTypeObject(name string, rt reflect.Type) (gobType, os.Error) { // getType returns the Gob type describing the given reflect.Type. // typeLock must be held. func getType(name string, rt reflect.Type) (gobType, os.Error) { - rt, _ = indirect(rt) + rt = userType(rt).base typ, present := types[rt] if present { return typ, nil @@ -371,6 +468,7 @@ func bootstrapType(name string, e interface{}, expect typeId) typeId { types[rt] = typ setTypeId(typ) checkId(expect, nextId) + userType(rt) // might as well cache it now return nextId } @@ -381,7 +479,7 @@ func bootstrapType(name string, e interface{}, expect typeId) typeId { // For bootstrapping purposes, we assume that the recipient knows how // to decode a wireType; it is exactly the wireType struct here, interpreted // using the gob rules for sending a structure, except that we assume the -// ids for wireType and structType are known. The relevant pieces +// ids for wireType and structType etc. are known. The relevant pieces // are built in encode.go's init() function. // To maintain binary compatibility, if you extend this type, always put // the new fields last. @@ -473,18 +571,18 @@ func RegisterName(name string, value interface{}) { // reserved for nil panic("attempt to register empty name") } - rt, _ := indirect(reflect.Typeof(value)) + base := userType(reflect.Typeof(value)).base // Check for incompatible duplicates. - if t, ok := nameToConcreteType[name]; ok && t != rt { + if t, ok := nameToConcreteType[name]; ok && t != base { panic("gob: registering duplicate types for " + name) } - if n, ok := concreteTypeToName[rt]; ok && n != name { - panic("gob: registering duplicate names for " + rt.String()) + if n, ok := concreteTypeToName[base]; ok && n != name { + panic("gob: registering duplicate names for " + base.String()) } // Store the name and type provided by the user.... nameToConcreteType[name] = reflect.Typeof(value) // but the flattened type in the type table, since that's what decode needs. - concreteTypeToName[rt] = name + concreteTypeToName[base] = name } // Register records a type, identified by a value for that type, under its @@ -530,7 +628,7 @@ func registerBasics() { Register(uint32(0)) Register(uint64(0)) Register(float32(0)) - Register(0.0) + Register(float64(0)) Register(complex64(0i)) Register(complex128(0i)) Register(false) |