Skip to content

Commit bedc124

Browse files
committed
Added support for frozen_default for dataclass_transform.
1 parent 9b44419 commit bedc124

File tree

8 files changed

+96
-10
lines changed

8 files changed

+96
-10
lines changed

packages/pyright-internal/src/analyzer/dataClasses.ts

Lines changed: 28 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -713,6 +713,7 @@ export function validateDataClassTransformDecorator(
713713
keywordOnlyParams: false,
714714
generateEq: true,
715715
generateOrder: false,
716+
frozen: false,
716717
fieldDescriptorNames: [],
717718
};
718719

@@ -780,6 +781,24 @@ export function validateDataClassTransformDecorator(
780781
break;
781782
}
782783

784+
case 'frozen_default': {
785+
const value = evaluateStaticBoolExpression(
786+
arg.valueExpression,
787+
fileInfo.executionEnvironment,
788+
fileInfo.definedConstants
789+
);
790+
if (value === undefined) {
791+
evaluator.addError(
792+
Localizer.Diagnostic.dataClassTransformExpectedBoolLiteral(),
793+
arg.valueExpression
794+
);
795+
return;
796+
}
797+
798+
behaviors.frozen = value;
799+
break;
800+
}
801+
783802
// Earlier versions of the dataclass_transform spec used the name "field_descriptors"
784803
// rather than "field_specifiers". The older name is now deprecated but still supported
785804
// for the time being because some libraries shipped with the older __dataclass_transform__
@@ -858,6 +877,7 @@ export function getDataclassDecoratorBehaviors(type: Type): DataClassBehaviors |
858877
keywordOnlyParams: false,
859878
generateEq: true,
860879
generateOrder: false,
880+
frozen: false,
861881
fieldDescriptorNames: ['dataclasses.field', 'dataclasses.Field'],
862882
};
863883
}
@@ -998,7 +1018,8 @@ export function applyDataClassClassBehaviorOverrides(
9981018
evaluator: TypeEvaluator,
9991019
errorNode: ParseNode,
10001020
classType: ClassType,
1001-
args: FunctionArgument[]
1021+
args: FunctionArgument[],
1022+
defaultBehaviors: DataClassBehaviors
10021023
) {
10031024
let sawFrozenArg = false;
10041025

@@ -1015,7 +1036,7 @@ export function applyDataClassClassBehaviorOverrides(
10151036
// If there was no frozen argument, it is implicitly false. This will
10161037
// validate that we're not overriding a frozen class with a non-frozen class.
10171038
if (!sawFrozenArg) {
1018-
applyDataClassBehaviorOverrideValue(evaluator, errorNode, classType, 'frozen', false);
1039+
applyDataClassBehaviorOverrideValue(evaluator, errorNode, classType, 'frozen', defaultBehaviors.frozen);
10191040
}
10201041
}
10211042

@@ -1034,6 +1055,10 @@ export function applyDataClassDefaultBehaviors(classType: ClassType, defaultBeha
10341055
if (defaultBehaviors.generateOrder) {
10351056
classType.details.flags |= ClassTypeFlags.SynthesizedDataClassOrder;
10361057
}
1058+
1059+
if (defaultBehaviors.frozen) {
1060+
classType.details.flags |= ClassTypeFlags.FrozenDataClass;
1061+
}
10371062
}
10381063

10391064
export function applyDataClassDecorator(
@@ -1045,7 +1070,5 @@ export function applyDataClassDecorator(
10451070
) {
10461071
applyDataClassDefaultBehaviors(classType, defaultBehaviors);
10471072

1048-
if (callNode?.arguments) {
1049-
applyDataClassClassBehaviorOverrides(evaluator, errorNode, classType, callNode.arguments);
1050-
}
1073+
applyDataClassClassBehaviorOverrides(evaluator, errorNode, classType, callNode?.arguments ?? [], defaultBehaviors);
10511074
}

packages/pyright-internal/src/analyzer/typeEvaluator.ts

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15576,7 +15576,13 @@ export function createTypeEvaluator(importLookup: ImportLookup, evaluatorOptions
1557615576

1557715577
if (dataClassBehaviors) {
1557815578
applyDataClassDefaultBehaviors(classType, dataClassBehaviors);
15579-
applyDataClassClassBehaviorOverrides(evaluatorInterface, node.name, classType, initSubclassArgs);
15579+
applyDataClassClassBehaviorOverrides(
15580+
evaluatorInterface,
15581+
node.name,
15582+
classType,
15583+
initSubclassArgs,
15584+
dataClassBehaviors
15585+
);
1558015586
}
1558115587

1558215588
// Run any class hooks that depend on this class.

packages/pyright-internal/src/analyzer/types.ts

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -461,6 +461,7 @@ export interface DataClassBehaviors {
461461
keywordOnlyParams: boolean;
462462
generateEq: boolean;
463463
generateOrder: boolean;
464+
frozen: boolean;
464465
fieldDescriptorNames: string[];
465466
}
466467

packages/pyright-internal/src/tests/samples/dataclassTransform1.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -69,3 +69,24 @@ class Customer2Subclass(Customer2, frozen=True):
6969
c2_2 = Customer2(0, "John")
7070

7171
v2 = c2_1 < c2_2
72+
73+
74+
@dataclass_transform(kw_only_default=True, order_default=True, frozen_default=True)
75+
def create_model_frozen(cls: _T) -> _T:
76+
...
77+
78+
@create_model_frozen
79+
class Customer3:
80+
id: int
81+
name: str
82+
83+
# This should generate an error because a non-frozen class
84+
# cannot inherit from a frozen class.
85+
@create_model
86+
class Customer3Subclass(Customer3):
87+
age: int
88+
89+
c3_1 = Customer3(id=2, name="hi")
90+
91+
# This should generate an error because Customer3 is frozen.
92+
c3_1.id = 4

packages/pyright-internal/src/tests/samples/dataclassTransform2.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -80,3 +80,22 @@ class Customer2(ModelBase, order=True):
8080
# This should generate an error because Customer2 supports
8181
# keyword-only parameters for its constructor.
8282
c2_3 = Customer2(0, "John")
83+
84+
85+
86+
@dataclass_transform(frozen_default=True)
87+
class ModelMetaFrozen(type):
88+
pass
89+
90+
class ModelBaseFrozen(metaclass=ModelMetaFrozen):
91+
...
92+
93+
class Customer3(ModelBaseFrozen):
94+
id: int
95+
name: str
96+
97+
98+
c3_1 = Customer3(id=2, name="hi")
99+
100+
# This should generate an error because Customer3 is frozen.
101+
c3_1.id = 4

packages/pyright-internal/src/tests/samples/dataclassTransform3.py

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ def __dataclass_transform__(
1111
eq_default: bool = True,
1212
order_default: bool = False,
1313
kw_only_default: bool = False,
14+
frozen_default: bool = False,
1415
field_specifiers: Tuple[Union[type, Callable[..., Any]], ...] = (()),
1516
) -> Callable[[_T], _T]:
1617
return lambda a: a
@@ -108,4 +109,18 @@ def __init_subclass__(
108109
class GenericCustomer(GenericModelBase[int]):
109110
id: int = model_field()
110111

111-
gc_1 = GenericCustomer(id=3)
112+
gc_1 = GenericCustomer(id=3)
113+
114+
@__dataclass_transform__(frozen_default=True)
115+
class ModelBaseFrozen:
116+
not_a_field: str
117+
118+
class Customer3(ModelBaseFrozen):
119+
id: int
120+
name: str
121+
122+
123+
c3_1 = Customer3(id=2, name="hi")
124+
125+
# This should generate an error because Customer3 is frozen.
126+
c3_1.id = 4

packages/pyright-internal/src/tests/typeEvaluator3.test.ts

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1383,19 +1383,19 @@ test('Decorator7', () => {
13831383
test('DataclassTransform1', () => {
13841384
const analysisResults = TestUtils.typeAnalyzeSampleFiles(['dataclassTransform1.py']);
13851385

1386-
TestUtils.validateResults(analysisResults, 4);
1386+
TestUtils.validateResults(analysisResults, 6);
13871387
});
13881388

13891389
test('DataclassTransform2', () => {
13901390
const analysisResults = TestUtils.typeAnalyzeSampleFiles(['dataclassTransform2.py']);
13911391

1392-
TestUtils.validateResults(analysisResults, 5);
1392+
TestUtils.validateResults(analysisResults, 6);
13931393
});
13941394

13951395
test('DataclassTransform3', () => {
13961396
const analysisResults = TestUtils.typeAnalyzeSampleFiles(['dataclassTransform3.py']);
13971397

1398-
TestUtils.validateResults(analysisResults, 5);
1398+
TestUtils.validateResults(analysisResults, 6);
13991399
});
14001400

14011401
test('DataclassTransform4', () => {

packages/pyright-internal/typeshed-fallback/stdlib/typing_extensions.pyi

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -227,6 +227,7 @@ else:
227227
eq_default: bool = ...,
228228
order_default: bool = ...,
229229
kw_only_default: bool = ...,
230+
frozen_default: bool = ...,
230231
field_specifiers: tuple[type[Any] | Callable[..., Any], ...] = ...,
231232
**kwargs: object,
232233
) -> IdentityFunction: ...

0 commit comments

Comments
 (0)