Skip to content

feature #384: Specify custom map key sort function #385

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 2 additions & 4 deletions api_tests/marshal_json_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,10 @@ import (
"bytes"
"encoding/json"
"github.com/json-iterator/go"
"testing"
"github.com/stretchr/testify/require"
"testing"
)


type Foo struct {
Bar interface{}
}
Expand All @@ -19,11 +18,10 @@ func (f Foo) MarshalJSON() ([]byte, error) {
return buf.Bytes(), err
}


// Standard Encoder has trailing newline.
func TestEncodeMarshalJSON(t *testing.T) {

foo := Foo {
foo := Foo{
Bar: 123,
}
should := require.New(t)
Expand Down
73 changes: 73 additions & 0 deletions extension_tests/extension_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,79 @@ func Test_customize_map_key_encoder(t *testing.T) {
should.Equal(map[int]int{1: 2}, m)
}

// Test using custom encoder with sorted map keys.
// Keys should be numerically sorted AFTER the custom key encoder runs.
func Test_customize_map_key_encoder_with_sorted_keys(t *testing.T) {
should := require.New(t)
cfg := jsoniter.Config{
SortMapKeys: true,
}.Froze()
cfg.RegisterExtension(&testMapKeyExtension{})
m := map[int]int{
3: 3,
1: 9,
}
output, err := cfg.MarshalToString(m)
should.NoError(err)
should.Equal(`{"2":9,"4":3}`, output)
m2 := map[int]int{}
should.NoError(cfg.UnmarshalFromString(output, &m2))
should.Equal(map[int]int{
1: 9,
3: 3,
}, m2)
}

func Test_customize_map_key_sorter(t *testing.T) {
should := require.New(t)
cfg := jsoniter.Config{
SortMapKeys: true,
}.Froze()

cfg.RegisterExtension(&testMapKeySorterExtension{
sorter: &testKeySorter{},
})

m := map[string]int{
"a": 1,
"foo": 2,
"b": 3,
}
output, err := cfg.MarshalToString(m)
should.NoError(err)
should.Equal(`{"foo":2,"a":1,"b":3}`, output)
m = map[string]int{}
should.NoError(cfg.UnmarshalFromString(output, &m))
should.Equal(map[string]int{
"foo": 2,
"a": 1,
"b": 3,
}, m)
}

type testKeySorter struct {
}

func (sorter *testKeySorter) Sort(keyA string, keyB string) bool {
// Prioritize "foo" over other keys, otherwise alpha-sort
if keyA == "foo" {
return true
} else if keyB == "foo" {
return false
} else {
return keyA < keyB
}
}

type testMapKeySorterExtension struct {
jsoniter.DummyExtension
sorter jsoniter.MapKeySorter
}

func (extension *testMapKeySorterExtension) CreateMapKeySorter() jsoniter.MapKeySorter {
return extension.sorter
}

type testMapKeyExtension struct {
jsoniter.DummyExtension
}
Expand Down
5 changes: 5 additions & 0 deletions reflect.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,11 @@ type ValEncoder interface {
Encode(ptr unsafe.Pointer, stream *Stream)
}

// MapKeySorter is used to define a custom function for sorting the keys of maps
type MapKeySorter interface {
Sort(keyA string, keyB string) bool
}

type checkIsEmpty interface {
IsEmpty(ptr unsafe.Pointer) bool
}
Expand Down
16 changes: 16 additions & 0 deletions reflect_extension.go
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ type Extension interface {
UpdateStructDescriptor(structDescriptor *StructDescriptor)
CreateMapKeyDecoder(typ reflect2.Type) ValDecoder
CreateMapKeyEncoder(typ reflect2.Type) ValEncoder
CreateMapKeySorter() MapKeySorter
CreateDecoder(typ reflect2.Type) ValDecoder
CreateEncoder(typ reflect2.Type) ValEncoder
DecorateDecoder(typ reflect2.Type, decoder ValDecoder) ValDecoder
Expand All @@ -73,6 +74,11 @@ func (extension *DummyExtension) CreateMapKeyEncoder(typ reflect2.Type) ValEncod
return nil
}

// CreateMapKeySorter No-op
func (extension *DummyExtension) CreateMapKeySorter() MapKeySorter {
return nil
}

// CreateDecoder No-op
func (extension *DummyExtension) CreateDecoder(typ reflect2.Type) ValDecoder {
return nil
Expand Down Expand Up @@ -119,6 +125,11 @@ func (extension EncoderExtension) CreateMapKeyEncoder(typ reflect2.Type) ValEnco
return nil
}

// CreateMapKeySorter No-op
func (extension EncoderExtension) CreateMapKeySorter() MapKeySorter {
return nil
}

// DecorateDecoder No-op
func (extension EncoderExtension) DecorateDecoder(typ reflect2.Type, decoder ValDecoder) ValDecoder {
return decoder
Expand All @@ -145,6 +156,11 @@ func (extension DecoderExtension) CreateMapKeyEncoder(typ reflect2.Type) ValEnco
return nil
}

// CreateMapKeySorter No-op
func (extension DecoderExtension) CreateMapKeySorter() MapKeySorter {
return nil
}

// CreateDecoder get decoder from map
func (extension DecoderExtension) CreateDecoder(typ reflect2.Type) ValDecoder {
return extension[typ]
Expand Down
40 changes: 34 additions & 6 deletions reflect_map.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ func encoderOfMap(ctx *ctx, typ reflect2.Type) ValEncoder {
return &sortKeysMapEncoder{
mapType: mapType,
keyEncoder: encoderOfMapKey(ctx.append("[mapKey]"), mapType.Key()),
keySorter: sorterOfMapKey(ctx),
elemEncoder: encoderOfType(ctx.append("[mapElem]"), mapType.Elem()),
}
}
Expand All @@ -38,6 +39,23 @@ func encoderOfMap(ctx *ctx, typ reflect2.Type) ValEncoder {
}
}

type defaultMapKeySorter struct {
}

func (sorter defaultMapKeySorter) Sort(keyA string, keyB string) bool {
return keyA < keyB
}

func sorterOfMapKey(ctx *ctx) MapKeySorter {
for _, extension := range ctx.extraExtensions {
sorter := extension.CreateMapKeySorter()
if sorter != nil {
return sorter
}
}
return defaultMapKeySorter{}
}

func decoderOfMapKey(ctx *ctx, typ reflect2.Type) ValDecoder {
decoder := ctx.decoderExtension.CreateMapKeyDecoder(typ)
if decoder != nil {
Expand Down Expand Up @@ -275,6 +293,7 @@ func (encoder *mapEncoder) IsEmpty(ptr unsafe.Pointer) bool {
type sortKeysMapEncoder struct {
mapType *reflect2.UnsafeMapType
keyEncoder ValEncoder
keySorter MapKeySorter
elemEncoder ValEncoder
}

Expand All @@ -287,7 +306,7 @@ func (encoder *sortKeysMapEncoder) Encode(ptr unsafe.Pointer, stream *Stream) {
mapIter := encoder.mapType.UnsafeIterate(ptr)
subStream := stream.cfg.BorrowStream(nil)
subIter := stream.cfg.BorrowIterator(nil)
keyValues := encodedKeyValues{}
keyValues := []encodedKV{}
for mapIter.HasNext() {
subStream.buf = make([]byte, 0, 64)
key, elem := mapIter.UnsafeNext()
Expand All @@ -309,7 +328,11 @@ func (encoder *sortKeysMapEncoder) Encode(ptr unsafe.Pointer, stream *Stream) {
keyValue: subStream.Buffer(),
})
}
sort.Sort(keyValues)
keyValueWrapper := encodedKeyValues{
keySorter: encoder.keySorter,
values: keyValues,
}
sort.Sort(keyValueWrapper)
for i, keyValue := range keyValues {
if i != 0 {
stream.WriteMore()
Expand All @@ -326,13 +349,18 @@ func (encoder *sortKeysMapEncoder) IsEmpty(ptr unsafe.Pointer) bool {
return !iter.HasNext()
}

type encodedKeyValues []encodedKV
type encodedKeyValues struct {
keySorter MapKeySorter
values []encodedKV
}

type encodedKV struct {
key string
keyValue []byte
}

func (sv encodedKeyValues) Len() int { return len(sv) }
func (sv encodedKeyValues) Swap(i, j int) { sv[i], sv[j] = sv[j], sv[i] }
func (sv encodedKeyValues) Less(i, j int) bool { return sv[i].key < sv[j].key }
func (sv encodedKeyValues) Len() int { return len(sv.values) }
func (sv encodedKeyValues) Swap(i, j int) { sv.values[i], sv.values[j] = sv.values[j], sv.values[i] }
func (sv encodedKeyValues) Less(i, j int) bool {
return sv.keySorter.Sort(sv.values[i].key, sv.values[j].key)
}