teflon: Use correct convolution params struct

Use TfLiteDepthwiseConvParams for kTfLiteBuiltinDepthwiseConv2d.
The layout of stride_width, stride_height, and padding struct members
happens to be the same, but we shouldn't depend on that.
This prepares for using the activation, dilation_width_factor, and
dilation_height_factor members, which are at different offsets.

Reviewed-by: Tomeu Vizoso <tomeu@tomeuvizoso.net>
Signed-off-by: Philipp Zabel <p.zabel@pengutronix.de>
Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/31981>
This commit is contained in:
Philipp Zabel 2024-09-16 17:05:46 +02:00 committed by Marge Bot
parent e3257f7461
commit 319c56b10a

View File

@ -109,14 +109,24 @@ fill_operation(struct teflon_delegate *delegate, TfLiteContext *tf_context, TfLi
switch(node_registration->builtin_code) {
case kTfLiteBuiltinConv2d:
case kTfLiteBuiltinDepthwiseConv2d: {
TfLiteConvParams* params = (TfLiteConvParams*)node->builtin_data;
operation->type = PIPE_ML_OPERATION_TYPE_CONVOLUTION;
operation->conv.weight_tensor = &tensors[node->inputs->data[1]];
operation->conv.bias_tensor = &tensors[node->inputs->data[2]];
operation->conv.stride_x = params->stride_width;
operation->conv.stride_y = params->stride_height;
operation->conv.padding_same = params->padding == kTfLitePaddingSame;
operation->conv.depthwise = node_registration->builtin_code == kTfLiteBuiltinDepthwiseConv2d;
if (node_registration->builtin_code == kTfLiteBuiltinConv2d) {
TfLiteConvParams* params = (TfLiteConvParams*)node->builtin_data;
operation->conv.stride_x = params->stride_width;
operation->conv.stride_y = params->stride_height;
operation->conv.padding_same = params->padding == kTfLitePaddingSame;
operation->conv.depthwise = false;
} else {
TfLiteDepthwiseConvParams* params = (TfLiteDepthwiseConvParams*)node->builtin_data;
operation->conv.stride_x = params->stride_width;
operation->conv.stride_y = params->stride_height;
operation->conv.padding_same = params->padding == kTfLitePaddingSame;
operation->conv.depthwise = true;
}
operation->conv.pointwise = operation->conv.weight_tensor->dims[1] == 1 && \
operation->conv.weight_tensor->dims[2] == 1;
break;