Skip to content

Commit 4ee55d7

Browse files
anandoleecopybara-github
authored andcommitted
Raise warnings when assign bool to int/enum field in Python Proto. This will turn into error in 34.0 release.
PiperOrigin-RevId: 789619218
1 parent d9950d4 commit 4ee55d7

File tree

7 files changed

+220
-17
lines changed

7 files changed

+220
-17
lines changed

‎python/convert.c‎

Lines changed: 37 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,6 @@
66
// https://developers.google.com/open-source/licenses/bsd
77

88
#include "python/convert.h"
9-
109
#include "python/message.h"
1110
#include "python/protobuf.h"
1211
#include "upb/message/compare.h"
@@ -62,7 +61,22 @@ PyObject* PyUpb_UpbToPy(upb_MessageValue val, const upb_FieldDef* f,
6261
}
6362
}
6463

65-
static bool PyUpb_GetInt64(PyObject* obj, int64_t* val) {
64+
// TODO: raise error in 2026 Q1 release
65+
static void WarnBool(const upb_FieldDef* f) {
66+
static int bool_warning_count = 100;
67+
if (bool_warning_count > 0) {
68+
--bool_warning_count;
69+
PyErr_WarnFormat(PyExc_DeprecationWarning, 3,
70+
"Field %s: Expected an int, got a boolean. This "
71+
"will be rejected in 7.34.0, please fix it before that",
72+
upb_FieldDef_FullName(f));
73+
}
74+
}
75+
76+
static bool PyUpb_GetInt64(PyObject* obj, const upb_FieldDef* f, int64_t* val) {
77+
if (PyBool_Check(obj)) {
78+
WarnBool(f);
79+
}
6680
// We require that the value is either an integer or has an __index__
6781
// conversion.
6882
obj = PyNumber_Index(obj);
@@ -81,7 +95,11 @@ static bool PyUpb_GetInt64(PyObject* obj, int64_t* val) {
8195
return ok;
8296
}
8397

84-
static bool PyUpb_GetUint64(PyObject* obj, uint64_t* val) {
98+
static bool PyUpb_GetUint64(PyObject* obj, const upb_FieldDef* f,
99+
uint64_t* val) {
100+
if (PyBool_Check(obj)) {
101+
WarnBool(f);
102+
}
85103
// We require that the value is either an integer or has an __index__
86104
// conversion.
87105
obj = PyNumber_Index(obj);
@@ -98,9 +116,9 @@ static bool PyUpb_GetUint64(PyObject* obj, uint64_t* val) {
98116
return ok;
99117
}
100118

101-
static bool PyUpb_GetInt32(PyObject* obj, int32_t* val) {
119+
static bool PyUpb_GetInt32(PyObject* obj, const upb_FieldDef* f, int32_t* val) {
102120
int64_t i64;
103-
if (!PyUpb_GetInt64(obj, &i64)) return false;
121+
if (!PyUpb_GetInt64(obj, f, &i64)) return false;
104122
if (i64 < INT32_MIN || i64 > INT32_MAX) {
105123
PyErr_Format(PyExc_ValueError, "Value out of range: %S", obj);
106124
return false;
@@ -109,9 +127,10 @@ static bool PyUpb_GetInt32(PyObject* obj, int32_t* val) {
109127
return true;
110128
}
111129

112-
static bool PyUpb_GetUint32(PyObject* obj, uint32_t* val) {
130+
static bool PyUpb_GetUint32(PyObject* obj, const upb_FieldDef* f,
131+
uint32_t* val) {
113132
uint64_t u64;
114-
if (!PyUpb_GetUint64(obj, &u64)) return false;
133+
if (!PyUpb_GetUint64(obj, f, &u64)) return false;
115134
if (u64 > UINT32_MAX) {
116135
PyErr_Format(PyExc_ValueError, "Value out of range: %S", obj);
117136
return false;
@@ -164,8 +183,9 @@ const char* upb_FieldDef_TypeString(const upb_FieldDef* f) {
164183
UPB_UNREACHABLE();
165184
}
166185

167-
static bool PyUpb_PyToUpbEnum(PyObject* obj, const upb_EnumDef* e,
186+
static bool PyUpb_PyToUpbEnum(PyObject* obj, const upb_FieldDef* f,
168187
upb_MessageValue* val) {
188+
const upb_EnumDef* e = upb_FieldDef_EnumSubDef(f);
169189
if (PyUnicode_Check(obj)) {
170190
Py_ssize_t size;
171191
const char* name = PyUnicode_AsUTF8AndSize(obj, &size);
@@ -178,8 +198,11 @@ static bool PyUpb_PyToUpbEnum(PyObject* obj, const upb_EnumDef* e,
178198
val->int32_val = upb_EnumValueDef_Number(ev);
179199
return true;
180200
} else {
201+
if (PyBool_Check(obj)) {
202+
WarnBool(f);
203+
}
181204
int32_t i32;
182-
if (!PyUpb_GetInt32(obj, &i32)) return false;
205+
if (!PyUpb_GetInt32(obj, f, &i32)) return false;
183206
if (upb_EnumDef_IsClosed(e) && !upb_EnumDef_CheckNumber(e, i32)) {
184207
PyErr_Format(PyExc_ValueError, "invalid enumerator %d", (int)i32);
185208
return false;
@@ -238,15 +261,15 @@ bool PyUpb_PyToUpb(PyObject* obj, const upb_FieldDef* f, upb_MessageValue* val,
238261
upb_Arena* arena) {
239262
switch (upb_FieldDef_CType(f)) {
240263
case kUpb_CType_Enum:
241-
return PyUpb_PyToUpbEnum(obj, upb_FieldDef_EnumSubDef(f), val);
264+
return PyUpb_PyToUpbEnum(obj, f, val);
242265
case kUpb_CType_Int32:
243-
return PyUpb_GetInt32(obj, &val->int32_val);
266+
return PyUpb_GetInt32(obj, f, &val->int32_val);
244267
case kUpb_CType_Int64:
245-
return PyUpb_GetInt64(obj, &val->int64_val);
268+
return PyUpb_GetInt64(obj, f, &val->int64_val);
246269
case kUpb_CType_UInt32:
247-
return PyUpb_GetUint32(obj, &val->uint32_val);
270+
return PyUpb_GetUint32(obj, f, &val->uint32_val);
248271
case kUpb_CType_UInt64:
249-
return PyUpb_GetUint64(obj, &val->uint64_val);
272+
return PyUpb_GetUint64(obj, f, &val->uint64_val);
250273
case kUpb_CType_Float:
251274
if (!PyFloat_Check(obj) && PyUpb_IsNumpyNdarray(obj, f)) return false;
252275
val->float_val = PyFloat_AsDouble(obj);

‎python/google/protobuf/internal/message_test.py‎

Lines changed: 109 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1431,6 +1431,115 @@ def testMessageClassName(self, message_module):
14311431
'TestAllTypes.NestedMessage', nested.__class__.__qualname__
14321432
)
14331433

1434+
def testAssignBoolToEnum(self, message_module):
1435+
# TODO: change warning into error in 2026 Q1
1436+
# with self.assertRaises(TypeError):
1437+
with warnings.catch_warnings(record=True) as w:
1438+
m = message_module.TestAllTypes(optional_nested_enum=True)
1439+
self.assertIn('bool', str(w[0].message))
1440+
self.assertEqual(m.optional_nested_enum, 1)
1441+
1442+
m = message_module.TestAllTypes(optional_nested_enum=2)
1443+
with warnings.catch_warnings(record=True) as w:
1444+
m.optional_nested_enum = True
1445+
self.assertIn('bool', str(w[0].message))
1446+
self.assertEqual(m.optional_nested_enum, 1)
1447+
1448+
with warnings.catch_warnings(record=True) as w:
1449+
m.optional_nested_enum = 2
1450+
self.assertFalse(w)
1451+
self.assertEqual(m.optional_nested_enum, 2)
1452+
1453+
def testBoolToRepeatedEnum(self, message_module):
1454+
with warnings.catch_warnings(record=True) as w:
1455+
m = message_module.TestAllTypes(repeated_nested_enum=[True])
1456+
self.assertIn('bool', str(w[0].message))
1457+
self.assertEqual(m.repeated_nested_enum, [1])
1458+
1459+
m = message_module.TestAllTypes()
1460+
with warnings.catch_warnings(record=True) as w:
1461+
m.repeated_nested_enum.append(True)
1462+
self.assertIn('bool', str(w[0].message))
1463+
self.assertEqual(m.repeated_nested_enum, [1])
1464+
1465+
def testBoolToOneofEnum(self, message_module):
1466+
m = unittest_pb2.TestOneof2()
1467+
with warnings.catch_warnings(record=True) as w:
1468+
m.foo_enum = True
1469+
self.assertIn('bool', str(w[0].message))
1470+
self.assertEqual(m.foo_enum, 1)
1471+
1472+
def testBoolToMapEnum(self, message_module):
1473+
m = map_unittest_pb2.TestMap()
1474+
with warnings.catch_warnings(record=True) as w:
1475+
m.map_int32_enum[10] = True
1476+
self.assertIn('bool', str(w[0].message))
1477+
self.assertEqual(m.map_int32_enum[10], 1)
1478+
1479+
def testBoolToExtensionEnum(self, message_module):
1480+
m = unittest_pb2.TestAllExtensions()
1481+
with warnings.catch_warnings(record=True) as w:
1482+
m.Extensions[unittest_pb2.optional_nested_enum_extension] = True
1483+
self.assertIn('bool', str(w[0].message))
1484+
self.assertEqual(
1485+
m.Extensions[unittest_pb2.optional_nested_enum_extension], 1
1486+
)
1487+
1488+
def testAssignBoolToInt(self, message_module):
1489+
with warnings.catch_warnings(record=True) as w:
1490+
m = message_module.TestAllTypes(optional_int32=True)
1491+
self.assertIn('bool', str(w[0].message))
1492+
self.assertEqual(m.optional_int32, 1)
1493+
1494+
m = message_module.TestAllTypes(optional_uint32=123)
1495+
with warnings.catch_warnings(record=True) as w:
1496+
m.optional_uint32 = True
1497+
self.assertIn('bool', str(w[0].message))
1498+
self.assertEqual(m.optional_uint32, 1)
1499+
1500+
with warnings.catch_warnings(record=True) as w:
1501+
m.optional_uint32 = 321
1502+
self.assertFalse(w)
1503+
self.assertEqual(m.optional_uint32, 321)
1504+
1505+
def testAssignBoolToRepeatedInt(self, message_module):
1506+
with warnings.catch_warnings(record=True) as w:
1507+
m = message_module.TestAllTypes(repeated_int64=[True])
1508+
self.assertIn('bool', str(w[0].message))
1509+
self.assertEqual(m.repeated_int64, [1])
1510+
1511+
m = message_module.TestAllTypes()
1512+
with warnings.catch_warnings(record=True) as w:
1513+
m.repeated_int64.append(True)
1514+
self.assertIn('bool', str(w[0].message))
1515+
self.assertEqual(m.repeated_int64, [1])
1516+
1517+
def testAssignBoolToOneofInt(self, message_module):
1518+
m = unittest_pb2.TestOneof2()
1519+
with warnings.catch_warnings(record=True) as w:
1520+
m.foo_int = True
1521+
self.assertIn('bool', str(w[0].message))
1522+
self.assertEqual(m.foo_int, 1)
1523+
1524+
def testAssignBoolToMapInt(self, message_module):
1525+
m = map_unittest_pb2.TestMap()
1526+
with warnings.catch_warnings(record=True) as w:
1527+
m.map_int32_int32[10] = True
1528+
self.assertIn('bool', str(w[0].message))
1529+
self.assertEqual(m.map_int32_int32[10], 1)
1530+
1531+
with warnings.catch_warnings(record=True) as w:
1532+
m.map_int32_int32[True] = 1
1533+
self.assertIn('bool', str(w[0].message))
1534+
self.assertEqual(m.map_int32_int32[1], 1)
1535+
1536+
def testAssignBoolToExtensionInt(self, message_module):
1537+
m = unittest_pb2.TestAllExtensions()
1538+
with warnings.catch_warnings(record=True) as w:
1539+
m.Extensions[unittest_pb2.optional_int32_extension] = True
1540+
self.assertIn('bool', str(w[0].message))
1541+
self.assertEqual(m.Extensions[unittest_pb2.optional_int32_extension], 1)
1542+
14341543

14351544
@testing_refleaks.TestCase
14361545
class TestRecursiveGroup(unittest.TestCase):

‎python/google/protobuf/internal/type_checkers.py‎

Lines changed: 34 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -22,16 +22,20 @@
2222

2323
__author__ = 'robinson@google.com (Will Robinson)'
2424

25-
import struct
2625
import numbers
26+
import struct
27+
import warnings
2728

29+
from google.protobuf import descriptor
2830
from google.protobuf.internal import decoder
2931
from google.protobuf.internal import encoder
3032
from google.protobuf.internal import wire_format
31-
from google.protobuf import descriptor
3233

3334
_FieldDescriptor = descriptor.FieldDescriptor
34-
35+
# TODO: Remove this warning count after 34.0
36+
# Assign bool to int/enum warnings will print 100 times at most which should
37+
# be enough for users to notice and do not cause timeout.
38+
_BoolWarningCount = 100
3539

3640
def TruncateToFourByteFloat(original):
3741
return struct.unpack('<f', struct.pack('<f', original))[0]
@@ -141,6 +145,20 @@ class IntValueChecker(object):
141145
"""Checker used for integer fields. Performs type-check and range check."""
142146

143147
def CheckValue(self, proposed_value):
148+
if type(proposed_value) == bool and _BoolWarningCount > 0:
149+
--_BoolWarningCount
150+
message = (
151+
'%.1024r has type %s, but expected one of: %s. This warning '
152+
'will turn into error in 7.34.0, please fix it before that.'
153+
% (
154+
proposed_value,
155+
type(proposed_value),
156+
(int,),
157+
)
158+
)
159+
# TODO: Raise errors in 2026 Q1 release
160+
warnings.warn(message)
161+
144162
if not hasattr(proposed_value, '__index__') or (
145163
type(proposed_value).__module__ == 'numpy' and
146164
type(proposed_value).__name__ == 'ndarray'):
@@ -167,6 +185,19 @@ def __init__(self, enum_type):
167185
self._enum_type = enum_type
168186

169187
def CheckValue(self, proposed_value):
188+
if type(proposed_value) == bool and _BoolWarningCount > 0:
189+
--_BoolWarningCount
190+
message = (
191+
'%.1024r has type %s, but expected one of: %s. This warning '
192+
'will turn into error in 7.34.0, please fix it before that.'
193+
% (
194+
proposed_value,
195+
type(proposed_value),
196+
(int,),
197+
)
198+
)
199+
# TODO: Raise errors in 2026 Q1 release
200+
warnings.warn(message)
170201
if not isinstance(proposed_value, numbers.Integral):
171202
message = ('%.1024r has type %s, but expected one of: %s' %
172203
(proposed_value, type(proposed_value), (int,)))

‎python/google/protobuf/pyext/map_container.cc‎

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -108,21 +108,25 @@ static bool PythonToMapKey(MapContainer* self, PyObject* obj, MapKey* key,
108108
switch (field_descriptor->cpp_type()) {
109109
case FieldDescriptor::CPPTYPE_INT32: {
110110
PROTOBUF_CHECK_GET_INT32(obj, value, false);
111+
CheckIntegerWithBool(obj, field_descriptor);
111112
key->SetInt32Value(value);
112113
break;
113114
}
114115
case FieldDescriptor::CPPTYPE_INT64: {
115116
PROTOBUF_CHECK_GET_INT64(obj, value, false);
117+
CheckIntegerWithBool(obj, field_descriptor);
116118
key->SetInt64Value(value);
117119
break;
118120
}
119121
case FieldDescriptor::CPPTYPE_UINT32: {
120122
PROTOBUF_CHECK_GET_UINT32(obj, value, false);
123+
CheckIntegerWithBool(obj, field_descriptor);
121124
key->SetUInt32Value(value);
122125
break;
123126
}
124127
case FieldDescriptor::CPPTYPE_UINT64: {
125128
PROTOBUF_CHECK_GET_UINT64(obj, value, false);
129+
CheckIntegerWithBool(obj, field_descriptor);
126130
key->SetUInt64Value(value);
127131
break;
128132
}
@@ -210,21 +214,25 @@ static bool PythonToMapValueRef(MapContainer* self, PyObject* obj,
210214
switch (field_descriptor->cpp_type()) {
211215
case FieldDescriptor::CPPTYPE_INT32: {
212216
PROTOBUF_CHECK_GET_INT32(obj, value, false);
217+
CheckIntegerWithBool(obj, field_descriptor);
213218
value_ref->SetInt32Value(value);
214219
return true;
215220
}
216221
case FieldDescriptor::CPPTYPE_INT64: {
217222
PROTOBUF_CHECK_GET_INT64(obj, value, false);
223+
CheckIntegerWithBool(obj, field_descriptor);
218224
value_ref->SetInt64Value(value);
219225
return true;
220226
}
221227
case FieldDescriptor::CPPTYPE_UINT32: {
222228
PROTOBUF_CHECK_GET_UINT32(obj, value, false);
229+
CheckIntegerWithBool(obj, field_descriptor);
223230
value_ref->SetUInt32Value(value);
224231
return true;
225232
}
226233
case FieldDescriptor::CPPTYPE_UINT64: {
227234
PROTOBUF_CHECK_GET_UINT64(obj, value, false);
235+
CheckIntegerWithBool(obj, field_descriptor);
228236
value_ref->SetUInt64Value(value);
229237
return true;
230238
}
@@ -253,6 +261,7 @@ static bool PythonToMapValueRef(MapContainer* self, PyObject* obj,
253261
}
254262
case FieldDescriptor::CPPTYPE_ENUM: {
255263
PROTOBUF_CHECK_GET_INT32(obj, value, false);
264+
CheckIntegerWithBool(obj, field_descriptor);
256265
if (allow_unknown_enum_values) {
257266
value_ref->SetEnumValue(value);
258267
return true;

0 commit comments

Comments
 (0)